diff --git a/backend.py b/backend.py index f720265..6fc2dd5 100644 --- a/backend.py +++ b/backend.py @@ -36,7 +36,7 @@ logger = Logger(name=LOGGER_NAME_BACKEND) database = Database(logger) -def load_model(logger: Logger, use_gpu: bool) -> Model: +def load_model(logger: Logger, use_gpu: bool, reduce_memory_usage: bool) -> Model: # model candidates: # "runwayml/stable-diffusion-v1-5" # "CompVis/stable-diffusion-v1-4" @@ -53,7 +53,7 @@ def load_model(logger: Logger, use_gpu: bool) -> Model: inpainting_model_name = "runwayml/stable-diffusion-inpainting" model = Model(model_name, inpainting_model_name, logger, use_gpu=use_gpu) - if use_gpu: + if use_gpu and reduce_memory_usage: model.set_low_memory_mode() model.load_all() @@ -151,7 +151,7 @@ def main(args): database.set_image_output_folder(args.image_output_folder) database.connect(args.db) - model = load_model(logger, args.gpu) + model = load_model(logger, args.gpu, args.reduce_memory_usage) backend(model, args.debug) database.safe_disconnect() @@ -171,6 +171,9 @@ if __name__ == "__main__": # Add an argument to set the path of the database file parser.add_argument("--gpu", action="store_true", help="Enable to use GPU device") + # Add an argument to reduce memory usage + parser.add_argument("--reduce-memory-usage", action="store_true", help="Reduce memory usage when using GPU") + # Add an argument to set the path of the database file parser.add_argument( "--image-output-folder", diff --git a/requirements.txt b/requirements.txt index b44b5f5..0321179 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,4 +8,3 @@ scikit-image torch transformers sentencepiece -fcntl