adds initial model loading flow
This commit is contained in:
parent
67c080db68
commit
c07a2b6241
9
BUILD
9
BUILD
|
|
@ -1,4 +1,13 @@
|
||||||
load("@rules_python//python:defs.bzl", "py_binary")
|
load("@rules_python//python:defs.bzl", "py_binary")
|
||||||
|
load("@subpar//:subpar.bzl", "par_binary")
|
||||||
|
|
||||||
package(default_visibility = ["//visibility:public"])
|
package(default_visibility = ["//visibility:public"])
|
||||||
|
|
||||||
|
par_binary(
|
||||||
|
name = 'main',
|
||||||
|
srcs = ["main.py"],
|
||||||
|
deps = [
|
||||||
|
"//utilities:logger",
|
||||||
|
"//utilities:memory",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,5 @@
|
||||||
|
accelerate
|
||||||
colorlog
|
colorlog
|
||||||
diffusers
|
diffusers
|
||||||
torch
|
torch
|
||||||
|
transformers
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,11 @@
|
||||||
|
load("@rules_python//python:defs.bzl", "py_library", "py_test")
|
||||||
|
|
||||||
|
package(default_visibility = ["//visibility:public"])
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "memory",
|
||||||
|
srcs = ["memory.py"],
|
||||||
|
)
|
||||||
|
|
||||||
py_library(
|
py_library(
|
||||||
name = "logger",
|
name = "logger",
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,10 @@
|
||||||
|
import gc
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
def empty_memory():
|
||||||
|
'''
|
||||||
|
Performs garbage collection and empty cache in cuda device
|
||||||
|
'''
|
||||||
|
gc.collect()
|
||||||
|
torch.cuda.empty_cache()
|
||||||
Loading…
Reference in New Issue