adds initial model loading flow
This commit is contained in:
parent
67c080db68
commit
c07a2b6241
9
BUILD
9
BUILD
|
|
@ -1,4 +1,13 @@
|
|||
load("@rules_python//python:defs.bzl", "py_binary")
|
||||
load("@subpar//:subpar.bzl", "par_binary")
|
||||
|
||||
package(default_visibility = ["//visibility:public"])
|
||||
|
||||
par_binary(
|
||||
name = 'main',
|
||||
srcs = ["main.py"],
|
||||
deps = [
|
||||
"//utilities:logger",
|
||||
"//utilities:memory",
|
||||
],
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
accelerate
|
||||
colorlog
|
||||
diffusers
|
||||
torch
|
||||
transformers
|
||||
|
|
|
|||
|
|
@ -1,3 +1,11 @@
|
|||
load("@rules_python//python:defs.bzl", "py_library", "py_test")
|
||||
|
||||
package(default_visibility = ["//visibility:public"])
|
||||
|
||||
py_library(
|
||||
name = "memory",
|
||||
srcs = ["memory.py"],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "logger",
|
||||
|
|
|
|||
|
|
@ -0,0 +1,10 @@
|
|||
import gc
|
||||
import torch
|
||||
|
||||
|
||||
def empty_memory():
|
||||
'''
|
||||
Performs garbage collection and empty cache in cuda device
|
||||
'''
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
Loading…
Reference in New Issue