From 74c2acbe3cf8b6e9e8bf5cc6823cef5daead9fc8 Mon Sep 17 00:00:00 2001 From: HappyZ Date: Wed, 31 May 2023 18:30:21 -0700 Subject: [PATCH] [BE] supports to customize gpu device name when bootstrapping backend --- backend.py | 21 +++++++++++++++++---- utilities/img2img.py | 2 +- utilities/inpainting.py | 2 +- utilities/model.py | 11 ++++++++--- utilities/text2img.py | 2 +- 5 files changed, 28 insertions(+), 10 deletions(-) diff --git a/backend.py b/backend.py index 95b94a6..2131c42 100644 --- a/backend.py +++ b/backend.py @@ -40,7 +40,9 @@ logger = Logger(name=LOGGER_NAME_BACKEND) database = Database(logger) -def load_model(logger: Logger, use_gpu: bool, reduce_memory_usage: bool) -> Model: +def load_model( + logger: Logger, use_gpu: bool, gpu_device_name: str, reduce_memory_usage: bool +) -> Model: # model candidates: # "runwayml/stable-diffusion-v1-5" # "CompVis/stable-diffusion-v1-4" @@ -56,7 +58,13 @@ def load_model(logger: Logger, use_gpu: bool, reduce_memory_usage: bool) -> Mode # "runwayml/stable-diffusion-inpainting" inpainting_model_name = "runwayml/stable-diffusion-inpainting" - model = Model(model_name, inpainting_model_name, logger, use_gpu=use_gpu) + model = Model( + model_name, + inpainting_model_name, + logger, + use_gpu=use_gpu, + gpu_device_name=gpu_device_name, + ) if use_gpu and reduce_memory_usage: model.set_low_memory_mode() model.load_all() @@ -172,7 +180,7 @@ def main(args): database.set_image_output_folder(args.image_output_folder) database.connect(args.db) - model = load_model(logger, args.gpu, args.reduce_memory_usage) + model = load_model(logger, args.gpu, args.gpu_device, args.reduce_memory_usage) backend(model, args.gfpgan, args.debug) database.safe_disconnect() @@ -189,9 +197,14 @@ if __name__ == "__main__": "--db", type=str, default="happysd.db", help="Path to SQLite database file" ) - # Add an argument to set the path of the database file + # Add an argument to set the 'gpu' flag parser.add_argument("--gpu", action="store_true", help="Enable to use GPU device") + # Add an argument to set the gpu device name + parser.add_argument( + "--gpu-device", type=str, default="cuda", help="GPU device name" + ) + # Add an argument to reduce memory usage parser.add_argument( "--reduce-memory-usage", diff --git a/utilities/img2img.py b/utilities/img2img.py index 99989dc..38e942e 100644 --- a/utilities/img2img.py +++ b/utilities/img2img.py @@ -31,7 +31,7 @@ class Img2Img: logger: DummyLogger = DummyLogger(), ): self.model = model - self.__device = "cpu" if not self.model.use_gpu() else "cuda" + self.__device = "cpu" if not self.model.use_gpu() else self.model.get_gpu_device_name() self.__output_folder = output_folder self.__logger = logger diff --git a/utilities/inpainting.py b/utilities/inpainting.py index 1df64fd..8f0acc5 100644 --- a/utilities/inpainting.py +++ b/utilities/inpainting.py @@ -31,7 +31,7 @@ class Inpainting: logger: DummyLogger = DummyLogger(), ): self.model = model - self.__device = "cpu" if not self.model.use_gpu() else "cuda" + self.__device = "cpu" if not self.model.use_gpu() else self.model.get_gpu_device_name() self.__output_folder = output_folder self.__logger = logger diff --git a/utilities/model.py b/utilities/model.py index 25bcd98..bd00035 100644 --- a/utilities/model.py +++ b/utilities/model.py @@ -24,13 +24,15 @@ class Model: inpainting_model_name: str, logger: DummyLogger = DummyLogger(), use_gpu: bool = True, + gpu_device_name: str = "cuda", ): self.model_name = model_name self.inpainting_model_name = inpainting_model_name self.__use_gpu = False + self.__gpu_device = gpu_device_name if use_gpu and torch.cuda.is_available(): self.__use_gpu = True - logger.info("running on {}".format(torch.cuda.get_device_name("cuda:0"))) + logger.info("running on {}".format(torch.cuda.get_device_name(self.__gpu_device))) else: logger.info("running on CPU (expect it to be verrry sloooow)") self.__logger = logger @@ -44,6 +46,9 @@ class Model: def use_gpu(self): return self.__use_gpu + def get_gpu_device_name(self): + return self.__gpu_device + def update_model_name(self, model_name:str): if not model_name or model_name == self.model_name: self.__logger.warn("model name empty or the same, not updated") @@ -113,7 +118,7 @@ class Model: "failed to load model %s: %s" % (self.model_name, e) ) if pipeline and self.use_gpu(): - pipeline.to("cuda") + pipeline.to(self.get_gpu_device_name()) self.txt2img_pipeline = pipeline self.__default_txt2img_scheduler = pipeline.scheduler @@ -154,7 +159,7 @@ class Model: % (self.inpainting_model_name, e) ) if pipeline and self.use_gpu(): - pipeline.to("cuda") + pipeline.to(self.get_gpu_device_name()) self.inpaint_pipeline = pipeline self.__default_inpaint_scheduler = pipeline.scheduler empty_memory_cache() diff --git a/utilities/text2img.py b/utilities/text2img.py index 298d785..9be2585 100644 --- a/utilities/text2img.py +++ b/utilities/text2img.py @@ -28,7 +28,7 @@ class Text2Img: logger: DummyLogger = DummyLogger(), ): self.model = model - self.__device = "cpu" if not self.model.use_gpu() else "cuda" + self.__device = "cpu" if not self.model.use_gpu() else self.model.get_gpu_device_name() self.__output_folder = output_folder self.__logger = logger