From 78d1ad3962aedcb27629e9b78c8a06f1a6bc61bb Mon Sep 17 00:00:00 2001 From: Manuel Schmid Date: Mon, 29 Jan 2024 21:56:10 +0100 Subject: [PATCH] feat: add lora handling to A1111 scheme --- modules/metadata.py | 39 ++++++++++++++++++++++++++++++--------- 1 file changed, 30 insertions(+), 9 deletions(-) diff --git a/modules/metadata.py b/modules/metadata.py index 818494c2..f88f804f 100644 --- a/modules/metadata.py +++ b/modules/metadata.py @@ -1,5 +1,6 @@ import json import re +from pathlib import Path from abc import ABC, abstractmethod from PIL import Image @@ -40,6 +41,7 @@ class A1111MetadataParser(MetadataParser): 'refiner_model': 'Refiner', 'refiner_model_hash': 'Refiner hash', 'lora_hashes': 'Lora hashes', + 'lora_weights': 'Lora weights', 'version': 'Version' } @@ -92,6 +94,25 @@ class A1111MetadataParser(MetadataParser): except Exception: pass + if 'base_model' in data: + for filename in modules.config.model_filenames: + path = Path(filename) + if data['base_model'] == path.stem: + data['base_model'] = path.name + break + + if 'lora_hashes' in data: + # TODO optimize by using hash for matching. Problem is speed of creating the hash per model, even on startup + lora_filenames = modules.config.lora_filenames.copy() + lora_filenames.remove(modules.config.downloading_sdxl_lcm_lora()) + for li, lora in enumerate(data['lora_hashes'].split(', ')): + name, _, weight = lora.split(': ') + for filename in lora_filenames: + path = Path(filename) + if name == path.stem: + data[f'lora_combined_{li + 1}'] = f'{path.name} : {weight}' + break + return data def parse_string(self, metadata: dict) -> str: @@ -101,30 +122,30 @@ class A1111MetadataParser(MetadataParser): width, heigth = eval(data['resolution']) lora_hashes = [] + lora_weights = [] for index in range(lora_count_with_lcm): key = f'lora_name_{index + 1}' if key in data: - name = data[f'lora_name_{index + 1}'] - # TODO handle LoRA weight - # weight = data[f'lora_weight_{index + 1}'] - hash = data[f'lora_hash_{index + 1}'] - lora_hashes.append(f'{name.split(".")[0]}: {hash}') - lora_hashes_string = ", ".join(lora_hashes) + lora_name = Path(data[f'lora_name_{index + 1}']).stem + lora_weight = data[f'lora_weight_{index + 1}'] + lora_hash = data[f'lora_hash_{index + 1}'] + # workaround for Fooocus not knowing LoRA name in LoRA metadata + lora_hashes.append(f'{lora_name}: {lora_hash}: {lora_weight}') + lora_hashes_string = ', '.join(lora_hashes) generation_params = { self.fooocus_to_a1111['steps']: data['steps'], self.fooocus_to_a1111['sampler']: data['sampler'], self.fooocus_to_a1111['seed']: data['seed'], - # TODO check resolution value, should be string self.fooocus_to_a1111['resolution']: f'{width}x{heigth}', # TODO load model by name / hash - self.fooocus_to_a1111['base_model']: data['base_model'].split('.')[0], + self.fooocus_to_a1111['base_model']: Path(data['base_model']).stem, self.fooocus_to_a1111['base_model_hash']: data['base_model_hash'] } if 'refiner_model' in data and data['refiner_model'] != 'None' and 'refiner_model_hash' in data: generation_params |= { - self.fooocus_to_a1111['refiner_model']: data['refiner_model'].split('.')[0], + self.fooocus_to_a1111['refiner_model']: Path(data['refiner_model']).stem, self.fooocus_to_a1111['refiner_model_hash']: data['refiner_model_hash'], }