Merge branch 'hotfix/skip-free-memory-for-interrogate'

Manuel Schmid 2023-12-26 22:47:06 +01:00
commit 3503358dfd
2 changed files with 6 additions and 6 deletions

In the interrogator, the BLIP model is now loaded without forcing other models off the GPU:

@@ -47,7 +47,7 @@ class Interrogator:
         self.blip_model = ModelPatcher(model, load_device=self.load_device, offload_device=self.offload_device)
-        model_management.load_model_gpu(self.blip_model)
+        model_management.load_model_gpu(self.blip_model, should_free_memory=False)
         gpu_image = transforms.Compose([
             transforms.ToTensor(),
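Previously this call went through the default path, which frees memory on the load device and could evict already-loaded models before BLIP captioning ran; passing should_free_memory=False leaves resident models in place.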

In model_management, the new flag is threaded from load_model_gpu into load_models_gpu, where each free_memory call is now gated on it:

@@ -353,7 +353,7 @@ def free_memory(memory_required, device, keep_loaded=[]):
     if mem_free_torch > mem_free_total * 0.25:
         soft_empty_cache()
 
-def load_models_gpu(models, memory_required=0):
+def load_models_gpu(models, memory_required=0, should_free_memory=True):
     global vram_state
     inference_memory = minimum_inference_memory()
@@ -376,7 +376,7 @@ def load_models_gpu(models, memory_required=0):
     if len(models_to_load) == 0:
         devs = set(map(lambda a: a.device, models_already_loaded))
         for d in devs:
-            if d != torch.device("cpu"):
+            if d != torch.device("cpu") and should_free_memory:
                 free_memory(extra_mem, d, models_already_loaded)
         return
@@ -388,7 +388,7 @@ def load_models_gpu(models, memory_required=0):
         total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)
     for device in total_memory_required:
-        if device != torch.device("cpu"):
+        if device != torch.device("cpu") and should_free_memory:
             free_memory(total_memory_required[device] * 1.3 + extra_mem, device, models_already_loaded)
 
     for loaded_model in models_to_load:
@@ -416,8 +416,8 @@ def load_models_gpu(models, memory_required=0):
     return
 
-def load_model_gpu(model):
-    return load_models_gpu([model])
+def load_model_gpu(model, should_free_memory=True):
+    return load_models_gpu([model], should_free_memory=should_free_memory)
 
 def cleanup_models():
     to_delete = []
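For reference, a minimal standalone sketch of the opt-out pattern this commit introduces. The device strings and the print placeholder below are illustrative stand-ins, not the project's actual implementation:

def free_memory(memory_required, device):
    # Stand-in for the real eviction logic (unloading models, emptying caches).
    print(f"would free {memory_required} bytes on {device}")

def load_models_gpu(models, memory_required=0, should_free_memory=True):
    for device in ("cuda:0",):  # stand-in for the set of target devices
        if device != "cpu" and should_free_memory:
            free_memory(memory_required, device)
    # ... the real function would move `models` onto their devices here ...

def load_model_gpu(model, should_free_memory=True):
    # Defaulting to True keeps the old behavior for every existing caller;
    # only the interrogator opts out.
    return load_models_gpu([model], should_free_memory=should_free_memory)

load_model_gpu("blip", should_free_memory=False)  # no eviction happens
load_model_gpu("unet")                            # frees memory as before

Keeping the default at True means no other call site needs to change, which is why the diff touches only two files.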