[BE] supports to customize gpu device name when bootstrapping backend
This commit is contained in:
parent
54dc193f4e
commit
74c2acbe3c
21
backend.py
21
backend.py
|
|
@ -40,7 +40,9 @@ logger = Logger(name=LOGGER_NAME_BACKEND)
|
|||
database = Database(logger)
|
||||
|
||||
|
||||
def load_model(logger: Logger, use_gpu: bool, reduce_memory_usage: bool) -> Model:
|
||||
def load_model(
|
||||
logger: Logger, use_gpu: bool, gpu_device_name: str, reduce_memory_usage: bool
|
||||
) -> Model:
|
||||
# model candidates:
|
||||
# "runwayml/stable-diffusion-v1-5"
|
||||
# "CompVis/stable-diffusion-v1-4"
|
||||
|
|
@ -56,7 +58,13 @@ def load_model(logger: Logger, use_gpu: bool, reduce_memory_usage: bool) -> Mode
|
|||
# "runwayml/stable-diffusion-inpainting"
|
||||
inpainting_model_name = "runwayml/stable-diffusion-inpainting"
|
||||
|
||||
model = Model(model_name, inpainting_model_name, logger, use_gpu=use_gpu)
|
||||
model = Model(
|
||||
model_name,
|
||||
inpainting_model_name,
|
||||
logger,
|
||||
use_gpu=use_gpu,
|
||||
gpu_device_name=gpu_device_name,
|
||||
)
|
||||
if use_gpu and reduce_memory_usage:
|
||||
model.set_low_memory_mode()
|
||||
model.load_all()
|
||||
|
|
@ -172,7 +180,7 @@ def main(args):
|
|||
database.set_image_output_folder(args.image_output_folder)
|
||||
database.connect(args.db)
|
||||
|
||||
model = load_model(logger, args.gpu, args.reduce_memory_usage)
|
||||
model = load_model(logger, args.gpu, args.gpu_device, args.reduce_memory_usage)
|
||||
backend(model, args.gfpgan, args.debug)
|
||||
|
||||
database.safe_disconnect()
|
||||
|
|
@ -189,9 +197,14 @@ if __name__ == "__main__":
|
|||
"--db", type=str, default="happysd.db", help="Path to SQLite database file"
|
||||
)
|
||||
|
||||
# Add an argument to set the path of the database file
|
||||
# Add an argument to set the 'gpu' flag
|
||||
parser.add_argument("--gpu", action="store_true", help="Enable to use GPU device")
|
||||
|
||||
# Add an argument to set the gpu device name
|
||||
parser.add_argument(
|
||||
"--gpu-device", type=str, default="cuda", help="GPU device name"
|
||||
)
|
||||
|
||||
# Add an argument to reduce memory usage
|
||||
parser.add_argument(
|
||||
"--reduce-memory-usage",
|
||||
|
|
|
|||
|
|
@ -31,7 +31,7 @@ class Img2Img:
|
|||
logger: DummyLogger = DummyLogger(),
|
||||
):
|
||||
self.model = model
|
||||
self.__device = "cpu" if not self.model.use_gpu() else "cuda"
|
||||
self.__device = "cpu" if not self.model.use_gpu() else self.model.get_gpu_device_name()
|
||||
self.__output_folder = output_folder
|
||||
self.__logger = logger
|
||||
|
||||
|
|
|
|||
|
|
@ -31,7 +31,7 @@ class Inpainting:
|
|||
logger: DummyLogger = DummyLogger(),
|
||||
):
|
||||
self.model = model
|
||||
self.__device = "cpu" if not self.model.use_gpu() else "cuda"
|
||||
self.__device = "cpu" if not self.model.use_gpu() else self.model.get_gpu_device_name()
|
||||
self.__output_folder = output_folder
|
||||
self.__logger = logger
|
||||
|
||||
|
|
|
|||
|
|
@ -24,13 +24,15 @@ class Model:
|
|||
inpainting_model_name: str,
|
||||
logger: DummyLogger = DummyLogger(),
|
||||
use_gpu: bool = True,
|
||||
gpu_device_name: str = "cuda",
|
||||
):
|
||||
self.model_name = model_name
|
||||
self.inpainting_model_name = inpainting_model_name
|
||||
self.__use_gpu = False
|
||||
self.__gpu_device = gpu_device_name
|
||||
if use_gpu and torch.cuda.is_available():
|
||||
self.__use_gpu = True
|
||||
logger.info("running on {}".format(torch.cuda.get_device_name("cuda:0")))
|
||||
logger.info("running on {}".format(torch.cuda.get_device_name(self.__gpu_device)))
|
||||
else:
|
||||
logger.info("running on CPU (expect it to be verrry sloooow)")
|
||||
self.__logger = logger
|
||||
|
|
@ -44,6 +46,9 @@ class Model:
|
|||
def use_gpu(self):
|
||||
return self.__use_gpu
|
||||
|
||||
def get_gpu_device_name(self):
|
||||
return self.__gpu_device
|
||||
|
||||
def update_model_name(self, model_name:str):
|
||||
if not model_name or model_name == self.model_name:
|
||||
self.__logger.warn("model name empty or the same, not updated")
|
||||
|
|
@ -113,7 +118,7 @@ class Model:
|
|||
"failed to load model %s: %s" % (self.model_name, e)
|
||||
)
|
||||
if pipeline and self.use_gpu():
|
||||
pipeline.to("cuda")
|
||||
pipeline.to(self.get_gpu_device_name())
|
||||
|
||||
self.txt2img_pipeline = pipeline
|
||||
self.__default_txt2img_scheduler = pipeline.scheduler
|
||||
|
|
@ -154,7 +159,7 @@ class Model:
|
|||
% (self.inpainting_model_name, e)
|
||||
)
|
||||
if pipeline and self.use_gpu():
|
||||
pipeline.to("cuda")
|
||||
pipeline.to(self.get_gpu_device_name())
|
||||
self.inpaint_pipeline = pipeline
|
||||
self.__default_inpaint_scheduler = pipeline.scheduler
|
||||
empty_memory_cache()
|
||||
|
|
|
|||
|
|
@ -28,7 +28,7 @@ class Text2Img:
|
|||
logger: DummyLogger = DummyLogger(),
|
||||
):
|
||||
self.model = model
|
||||
self.__device = "cpu" if not self.model.use_gpu() else "cuda"
|
||||
self.__device = "cpu" if not self.model.use_gpu() else self.model.get_gpu_device_name()
|
||||
self.__output_folder = output_folder
|
||||
self.__logger = logger
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue