adds model class to load pipelines structurally

This commit is contained in:
happy 2023-04-26 00:21:17 -07:00
parent c07a2b6241
commit e40a7c02f6
4 changed files with 108 additions and 43 deletions
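For reference, a minimal sketch of how the new Model class is driven, mirroring the updated prepare() in main.py below; the final generation call and prompt are illustrative and not part of this commit:

from utilities.logger import Logger
from utilities.model import Model

logger = Logger(name="rl_trader")
model = Model("darkstorm2150/Protogen_x3.4_Official_Release",
              "runwayml/stable-diffusion-inpainting",
              logger)
model.reduce_memory()  # switches the default dtype to float16 before loading
model.load()           # fills in sd_pipeline, img2img_pipeline and inpaint_pipeline
if model.sd_pipeline is not None:
    image = model.sd_pipeline("a watercolor landscape").images[0]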

main.py

@@ -1,55 +1,20 @@
import torch
from diffusers import StableDiffusionPipeline
from diffusers import StableDiffusionImg2ImgPipeline
from utilities.memory import empty_memory
from utilities.model import Model
from utilities.logger import Logger
def from_pretrained(model_name: str, logger: Logger):
rev = "diffusers-115k" if model_name == "naclbit/trinart_stable_diffusion_v2" else "fp16"
pipe = None
try:
pipe = StableDiffusionPipeline.from_pretrained(model_name, revision=rev, torch_dtype=torch.float16, safety_checker=None)
pipe.to("cuda")
except:
try:
pipe = StableDiffusionPipeline.from_pretrained(model_name, torch_dtype=torch.float16, safety_checker=None)
pipe.to("cuda")
except Exception as e:
logger.error("Failed to load model %s: %s" % (model_name, e))
return pipe
def prepare(logger: Logger):
empty_memory()
torch.set_default_dtype(torch.float16)
model_name = "darkstorm2150/Protogen_x3.4_Official_Release"
inpainting_model_name = "runwayml/stable-diffusion-inpainting"
if not torch.cuda.is_available():
logger.error("no GPU found, will not proceed")
return False
logger.info("running on {}".format(torch.cuda.get_device_name("cuda:0")))
logger.info("loading model: {}".format(model_name))
pipeline = from_pretrained(model_name, logger)
if pipeline is None:
return False
img2img = StableDiffusionImg2ImgPipeline(**pipeline.components)
default_pipe_scheduler = pipeline.scheduler
return True
model = Model(model_name, inpainting_model_name, logger)
model.reduce_memory()
model.load()
return model
def main():
logger = Logger(name="rl_trader")
if not prepare(logger):
return
model = prepare(logger)
input("confirm...")


@@ -7,6 +7,15 @@ py_library(
srcs = ["memory.py"],
)
py_library(
name = "model",
srcs = ["model.py"],
deps = [
":memory",
":logger",
],
)
py_library(
name = "logger",
srcs = ["logger.py"],

utilities/memory.py

@@ -4,7 +4,14 @@ import torch
def empty_memory():
'''
Performs garbage collection and empty cache in cuda device
Performs garbage collection and empties the cache on the CUDA device.
'''
gc.collect()
torch.cuda.empty_cache()
def tune_for_low_memory():
'''
Tunes PyTorch to use float16 to reduce memory footprint.
'''
torch.set_default_dtype(torch.float16)

utilities/model.py

@@ -0,0 +1,84 @@
import torch
from diffusers import StableDiffusionPipeline
from diffusers import StableDiffusionImg2ImgPipeline
from diffusers import StableDiffusionInpaintPipeline
from utilities.logger import Logger
from utilities.memory import empty_memory
from utilities.memory import tune_for_low_memory
class Model():
'''Wraps the Stable Diffusion text-to-image, img2img and inpainting pipelines behind one loader with shared dtype and GPU settings.'''
def __init__(self, model_name: str, inpainting_model_name: str, logger: Logger, use_gpu: bool=True):
self.model_name = model_name
self.inpainting_model_name = inpainting_model_name
self.__use_gpu = False
if use_gpu and torch.cuda.is_available():
self.__use_gpu = True
logger.info("running on {}".format(torch.cuda.get_device_name("cuda:0")))
self.logger = logger
self.sd_pipeline = None
self.img2img_pipeline = None
self.inpaint_pipeline = None
self.__torch_dtype = "auto"
def use_gpu(self):
return self.__use_gpu
def reduce_memory(self):
self.logger.info("reducing memory usage by switching the default dtype to float16")
tune_for_low_memory()
self.__torch_dtype = torch.float16
def load(self):
empty_memory()
if self.model_name:
revision = get_revision_from_model_name(self.model_name)
sd_pipeline = None
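# try the model-specific revision first; if that load fails, retry from the default branch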
try:
sd_pipeline = StableDiffusionPipeline.from_pretrained(
self.model_name,
revision=revision,
torch_dtype=self.__torch_dtype,
safety_checker=None)
except Exception:
try:
sd_pipeline = StableDiffusionPipeline.from_pretrained(
self.model_name,
torch_dtype=self.__torch_dtype,
safety_checker=None)
except Exception as e:
self.logger.error("failed to load model %s: %s" % (self.model_name, e))
if sd_pipeline and self.use_gpu():
sd_pipeline.to("cuda")
self.sd_pipeline = sd_pipeline
self.sd_pipeline_scheduler = sd_pipeline.scheduler
self.img2img_pipeline = StableDiffusionImg2ImgPipeline(**sd_pipeline.components)
if self.inpainting_model_name:
revision = get_revision_from_model_name(self.inpainting_model_name)
inpaint_pipeline = None
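# same two-step load for the inpainting model: pinned revision first, then the default branch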
try:
inpaint_pipeline = StableDiffusionInpaintPipeline.from_pretrained(
self.inpainting_model_name,
revision=revision,
torch_dtype=self.__torch_dtype,
safety_checker=None)
except Exception:
try:
inpaint_pipeline = StableDiffusionInpaintPipeline.from_pretrained(
self.inpainting_model_name,
torch_dtype=self.__torch_dtype,
safety_checker=None)
except Exception as e:
self.logger.error("failed to load inpaint model %s: %s" % (self.inpainting_model_name, e))
if inpaint_pipeline and self.use_gpu():
inpaint_pipeline.to("cuda")
self.inpaint_pipeline = inpaint_pipeline
self.inpaint_pipeline_scheduler = inpaint_pipeline.scheduler
def get_revision_from_model_name(model_name: str):
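'''Maps a model name to the Hub revision to download: trinart v2 ships its diffusers weights under "diffusers-115k", everything else uses the "fp16" branch.'''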
return "diffusers-115k" if model_name == "naclbit/trinart_stable_diffusion_v2" else "fp16"