[BE] support customizing GPU device name when bootstrapping backend

This commit is contained in:
HappyZ 2023-05-31 18:30:21 -07:00
parent 54dc193f4e
commit 74c2acbe3c
5 changed files with 28 additions and 10 deletions

View File

@ -40,7 +40,9 @@ logger = Logger(name=LOGGER_NAME_BACKEND)
database = Database(logger)
def load_model(logger: Logger, use_gpu: bool, reduce_memory_usage: bool) -> Model:
def load_model(
logger: Logger, use_gpu: bool, gpu_device_name: str, reduce_memory_usage: bool
) -> Model:
# model candidates:
# "runwayml/stable-diffusion-v1-5"
# "CompVis/stable-diffusion-v1-4"
@ -56,7 +58,13 @@ def load_model(logger: Logger, use_gpu: bool, reduce_memory_usage: bool) -> Model
# "runwayml/stable-diffusion-inpainting"
inpainting_model_name = "runwayml/stable-diffusion-inpainting"
model = Model(model_name, inpainting_model_name, logger, use_gpu=use_gpu)
model = Model(
model_name,
inpainting_model_name,
logger,
use_gpu=use_gpu,
gpu_device_name=gpu_device_name,
)
if use_gpu and reduce_memory_usage:
model.set_low_memory_mode()
model.load_all()
@ -172,7 +180,7 @@ def main(args):
database.set_image_output_folder(args.image_output_folder)
database.connect(args.db)
model = load_model(logger, args.gpu, args.reduce_memory_usage)
model = load_model(logger, args.gpu, args.gpu_device, args.reduce_memory_usage)
backend(model, args.gfpgan, args.debug)
database.safe_disconnect()
@ -189,9 +197,14 @@ if __name__ == "__main__":
"--db", type=str, default="happysd.db", help="Path to SQLite database file"
)
# Add an argument to set the path of the database file
# Add an argument to set the 'gpu' flag
parser.add_argument("--gpu", action="store_true", help="Enable to use GPU device")
# Add an argument to set the gpu device name
parser.add_argument(
"--gpu-device", type=str, default="cuda", help="GPU device name"
)
# Add an argument to reduce memory usage
parser.add_argument(
"--reduce-memory-usage",

View File

@ -31,7 +31,7 @@ class Img2Img:
logger: DummyLogger = DummyLogger(),
):
self.model = model
self.__device = "cpu" if not self.model.use_gpu() else "cuda"
self.__device = "cpu" if not self.model.use_gpu() else self.model.get_gpu_device_name()
self.__output_folder = output_folder
self.__logger = logger

View File

@ -31,7 +31,7 @@ class Inpainting:
logger: DummyLogger = DummyLogger(),
):
self.model = model
self.__device = "cpu" if not self.model.use_gpu() else "cuda"
self.__device = "cpu" if not self.model.use_gpu() else self.model.get_gpu_device_name()
self.__output_folder = output_folder
self.__logger = logger

View File

@ -24,13 +24,15 @@ class Model:
inpainting_model_name: str,
logger: DummyLogger = DummyLogger(),
use_gpu: bool = True,
gpu_device_name: str = "cuda",
):
self.model_name = model_name
self.inpainting_model_name = inpainting_model_name
self.__use_gpu = False
self.__gpu_device = gpu_device_name
if use_gpu and torch.cuda.is_available():
self.__use_gpu = True
logger.info("running on {}".format(torch.cuda.get_device_name("cuda:0")))
logger.info("running on {}".format(torch.cuda.get_device_name(self.__gpu_device)))
else:
logger.info("running on CPU (expect it to be verrry sloooow)")
self.__logger = logger
@ -44,6 +46,9 @@ class Model:
def use_gpu(self):
return self.__use_gpu
def get_gpu_device_name(self):
return self.__gpu_device
def update_model_name(self, model_name:str):
if not model_name or model_name == self.model_name:
self.__logger.warn("model name empty or the same, not updated")
@ -113,7 +118,7 @@ class Model:
"failed to load model %s: %s" % (self.model_name, e)
)
if pipeline and self.use_gpu():
pipeline.to("cuda")
pipeline.to(self.get_gpu_device_name())
self.txt2img_pipeline = pipeline
self.__default_txt2img_scheduler = pipeline.scheduler
@ -154,7 +159,7 @@ class Model:
% (self.inpainting_model_name, e)
)
if pipeline and self.use_gpu():
pipeline.to("cuda")
pipeline.to(self.get_gpu_device_name())
self.inpaint_pipeline = pipeline
self.__default_inpaint_scheduler = pipeline.scheduler
empty_memory_cache()

View File

@ -28,7 +28,7 @@ class Text2Img:
logger: DummyLogger = DummyLogger(),
):
self.model = model
self.__device = "cpu" if not self.model.use_gpu() else "cuda"
self.__device = "cpu" if not self.model.use_gpu() else self.model.get_gpu_device_name()
self.__output_folder = output_folder
self.__logger = logger