[BE] adds variable to allow bypassing memory usage reduction for better GPUs

This commit is contained in:
HappyZ 2023-05-19 17:16:55 -07:00
parent a8eeaec829
commit ffe0a0b688
2 changed files with 6 additions and 4 deletions

View File

@ -36,7 +36,7 @@ logger = Logger(name=LOGGER_NAME_BACKEND)
database = Database(logger) database = Database(logger)
def load_model(logger: Logger, use_gpu: bool) -> Model: def load_model(logger: Logger, use_gpu: bool, reduce_memory_usage: bool) -> Model:
# model candidates: # model candidates:
# "runwayml/stable-diffusion-v1-5" # "runwayml/stable-diffusion-v1-5"
# "CompVis/stable-diffusion-v1-4" # "CompVis/stable-diffusion-v1-4"
@ -53,7 +53,7 @@ def load_model(logger: Logger, use_gpu: bool) -> Model:
inpainting_model_name = "runwayml/stable-diffusion-inpainting" inpainting_model_name = "runwayml/stable-diffusion-inpainting"
model = Model(model_name, inpainting_model_name, logger, use_gpu=use_gpu) model = Model(model_name, inpainting_model_name, logger, use_gpu=use_gpu)
if use_gpu: if use_gpu and reduce_memory_usage:
model.set_low_memory_mode() model.set_low_memory_mode()
model.load_all() model.load_all()
@ -151,7 +151,7 @@ def main(args):
database.set_image_output_folder(args.image_output_folder) database.set_image_output_folder(args.image_output_folder)
database.connect(args.db) database.connect(args.db)
model = load_model(logger, args.gpu) model = load_model(logger, args.gpu, args.reduce_memory_usage)
backend(model, args.debug) backend(model, args.debug)
database.safe_disconnect() database.safe_disconnect()
@ -171,6 +171,9 @@ if __name__ == "__main__":
# Add an argument to set the path of the database file # Add an argument to set the path of the database file
parser.add_argument("--gpu", action="store_true", help="Enable to use GPU device") parser.add_argument("--gpu", action="store_true", help="Enable to use GPU device")
# Add an argument to reduce memory usage
parser.add_argument("--reduce-memory-usage", action="store_true", help="Reduce memory usage when using GPU")
# Add an argument to set the path of the database file # Add an argument to set the path of the database file
parser.add_argument( parser.add_argument(
"--image-output-folder", "--image-output-folder",

View File

@ -8,4 +8,3 @@ scikit-image
torch torch
transformers transformers
sentencepiece sentencepiece
fcntl