adds model class to load pipelines structually
This commit is contained in:
parent
c07a2b6241
commit
e40a7c02f6
49
main.py
49
main.py
|
|
@ -1,55 +1,20 @@
|
|||
import torch
|
||||
from diffusers import StableDiffusionPipeline
|
||||
from diffusers import StableDiffusionImg2ImgPipeline
|
||||
|
||||
from utilities.memory import empty_memory
|
||||
from utilities.model import Model
|
||||
from utilities.logger import Logger
|
||||
|
||||
def from_pretrained(model_name: str, logger: Logger):
|
||||
rev = "diffusers-115k" if model_name == "naclbit/trinart_stable_diffusion_v2" else "fp16"
|
||||
|
||||
pipe = None
|
||||
try:
|
||||
pipe = StableDiffusionPipeline.from_pretrained(model_name, revision=rev, torch_dtype=torch.float16, safety_checker=None)
|
||||
pipe.to("cuda")
|
||||
except:
|
||||
try:
|
||||
pipe = StableDiffusionPipeline.from_pretrained(model_name, torch_dtype=torch.float16, safety_checker=None)
|
||||
pipe.to("cuda")
|
||||
except Exception as e:
|
||||
logger.error("Failed to load model %s: %s" % (model_name, e))
|
||||
return pipe
|
||||
|
||||
def prepare(logger: Logger):
|
||||
empty_memory()
|
||||
|
||||
torch.set_default_dtype(torch.float16)
|
||||
|
||||
model_name = "darkstorm2150/Protogen_x3.4_Official_Release"
|
||||
inpainting_model_name = "runwayml/stable-diffusion-inpainting"
|
||||
|
||||
if not torch.cuda.is_available():
|
||||
logger.error("no GPU found, will not proceed")
|
||||
return False
|
||||
|
||||
logger.info("running on {}".format(torch.cuda.get_device_name("cuda:0")))
|
||||
|
||||
logger.info("loading model: {}".format(model_name))
|
||||
pipeline = from_pretrained(model_name, logger)
|
||||
|
||||
if pipeline is None:
|
||||
return False
|
||||
|
||||
img2img = StableDiffusionImg2ImgPipeline(**pipeline.components)
|
||||
default_pipe_scheduler = pipeline.scheduler
|
||||
|
||||
return True
|
||||
model = Model(model_name, inpainting_model_name, logger)
|
||||
model.reduce_memory()
|
||||
model.load()
|
||||
return model
|
||||
|
||||
|
||||
def main():
|
||||
logger = Logger(name="rl_trader")
|
||||
|
||||
if not prepare(logger):
|
||||
return
|
||||
model = prepare(logger)
|
||||
|
||||
input("confirm...")
|
||||
|
||||
|
|
|
|||
|
|
@ -7,6 +7,15 @@ py_library(
|
|||
srcs = ["memory.py"],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "model",
|
||||
srcs = ["model.py"],
|
||||
deps = [
|
||||
":memory",
|
||||
":logger",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "logger",
|
||||
srcs = ["logger.py"],
|
||||
|
|
|
|||
|
|
@ -4,7 +4,14 @@ import torch
|
|||
|
||||
def empty_memory():
|
||||
'''
|
||||
Performs garbage collection and empty cache in cuda device
|
||||
Performs garbage collection and empty cache in cuda device.
|
||||
'''
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
def tune_for_low_memory():
|
||||
'''
|
||||
Tunes PyTorch to use float16 to reduce memory footprint.
|
||||
'''
|
||||
torch.set_default_dtype(torch.float16)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,84 @@
|
|||
import torch
|
||||
from diffusers import StableDiffusionPipeline
|
||||
from diffusers import StableDiffusionImg2ImgPipeline
|
||||
from diffusers import StableDiffusionInpaintPipeline
|
||||
|
||||
from utilities.logger import Logger
|
||||
from utilities.memory import empty_memory
|
||||
from utilities.memory import tune_for_low_memory
|
||||
|
||||
class Model():
|
||||
'''Model class.'''
|
||||
def __init__(self, model_name: str, inpainting_model_name: str, logger: Logger, use_gpu: bool=True):
|
||||
self.model_name = model_name
|
||||
self.inpainting_model_name = inpainting_model_name
|
||||
self.__use_gpu = False
|
||||
if use_gpu and torch.cuda.is_available():
|
||||
self.__use_gpu = True
|
||||
logger.info("running on {}".format(torch.cuda.get_device_name("cuda:0")))
|
||||
self.logger = logger
|
||||
self.sd_pipeline = None
|
||||
self.img2img_pipeline = None
|
||||
self.inpaint_pipeline = None
|
||||
self.__torch_dtype = "auto"
|
||||
|
||||
def use_gpu(self):
|
||||
return self.__use_gpu
|
||||
|
||||
def reduce_memory(self):
|
||||
self.logger.info("reduces memory usage by using float16 dtype")
|
||||
tune_for_low_memory()
|
||||
self.__torch_dtype = torch.float16
|
||||
|
||||
def load(self):
|
||||
empty_memory()
|
||||
|
||||
if self.model_name:
|
||||
revision = get_revision_from_model_name(self.model_name)
|
||||
sd_pipeline = None
|
||||
try:
|
||||
sd_pipeline = StableDiffusionPipeline.from_pretrained(
|
||||
model_name,
|
||||
revision=revision,
|
||||
torch_dtype=self.__torch_dtype,
|
||||
safety_checker=None)
|
||||
except:
|
||||
try:
|
||||
sd_pipeline = StableDiffusionPipeline.from_pretrained(
|
||||
self.model_name,
|
||||
torch_dtype=self.__torch_dtype,
|
||||
safety_checker=None)
|
||||
except Exception as e:
|
||||
self.logger.error("failed to load model %s: %s" % (self.model_name, e))
|
||||
if sd_pipeline and self.use_gpu():
|
||||
sd_pipeline.to("cuda")
|
||||
self.sd_pipeline = sd_pipeline
|
||||
self.sd_pipeline_scheduler = sd_pipeline.scheduler
|
||||
|
||||
self.img2img_pipeline = StableDiffusionImg2ImgPipeline(**sd_pipeline.components)
|
||||
|
||||
if self.inpainting_model_name:
|
||||
revision = get_revision_from_model_name(self.inpainting_model_name)
|
||||
inpaint_pipeline = None
|
||||
try:
|
||||
inpaint_pipeline = StableDiffusionInpaintPipeline.from_pretrained(
|
||||
model_name,
|
||||
revision=revision,
|
||||
torch_dtype=self.__torch_dtype,
|
||||
safety_checker=None)
|
||||
except:
|
||||
try:
|
||||
inpaint_pipeline = StableDiffusionInpaintPipeline.from_pretrained(
|
||||
self.inpainting_model_name,
|
||||
torch_dtype=self.__torch_dtype,
|
||||
safety_checker=None)
|
||||
except Exception as e:
|
||||
self.logger.error("failed to load inpaint model %s: %s" % (self.inpainting_model_name, e))
|
||||
if inpaint_pipeline and self.use_gpu():
|
||||
inpaint_pipeline.to("cuda")
|
||||
self.inpaint_pipeline = inpaint_pipeline
|
||||
self.inpaint_pipeline_scheduler = inpaint_pipeline.scheduler
|
||||
|
||||
|
||||
def get_revision_from_model_name(model_name: str):
|
||||
return "diffusers-115k" if model_name == "naclbit/trinart_stable_diffusion_v2" else "fp16"
|
||||
Loading…
Reference in New Issue