feat: add lora handling to A1111 scheme

This commit is contained in:
Manuel Schmid 2024-01-29 21:56:10 +01:00
parent 89c8e3a812
commit 78d1ad3962
No known key found for this signature in database
GPG Key ID: 32C4F7569B40B84B
1 changed files with 30 additions and 9 deletions

View File

@ -1,5 +1,6 @@
import json
import re
from pathlib import Path
from abc import ABC, abstractmethod
from PIL import Image
@ -40,6 +41,7 @@ class A1111MetadataParser(MetadataParser):
'refiner_model': 'Refiner',
'refiner_model_hash': 'Refiner hash',
'lora_hashes': 'Lora hashes',
'lora_weights': 'Lora weights',
'version': 'Version'
}
@ -92,6 +94,25 @@ class A1111MetadataParser(MetadataParser):
except Exception:
pass
if 'base_model' in data:
for filename in modules.config.model_filenames:
path = Path(filename)
if data['base_model'] == path.stem:
data['base_model'] = path.name
break
if 'lora_hashes' in data:
# TODO optimize by using hash for matching. Problem is speed of creating the hash per model, even on startup
lora_filenames = modules.config.lora_filenames.copy()
lora_filenames.remove(modules.config.downloading_sdxl_lcm_lora())
for li, lora in enumerate(data['lora_hashes'].split(', ')):
name, _, weight = lora.split(': ')
for filename in lora_filenames:
path = Path(filename)
if name == path.stem:
data[f'lora_combined_{li + 1}'] = f'{path.name} : {weight}'
break
return data
def parse_string(self, metadata: dict) -> str:
@ -101,30 +122,30 @@ class A1111MetadataParser(MetadataParser):
width, heigth = eval(data['resolution'])
lora_hashes = []
lora_weights = []
for index in range(lora_count_with_lcm):
key = f'lora_name_{index + 1}'
if key in data:
name = data[f'lora_name_{index + 1}']
# TODO handle LoRA weight
# weight = data[f'lora_weight_{index + 1}']
hash = data[f'lora_hash_{index + 1}']
lora_hashes.append(f'{name.split(".")[0]}: {hash}')
lora_hashes_string = ", ".join(lora_hashes)
lora_name = Path(data[f'lora_name_{index + 1}']).stem
lora_weight = data[f'lora_weight_{index + 1}']
lora_hash = data[f'lora_hash_{index + 1}']
# workaround for Fooocus not knowing LoRA name in LoRA metadata
lora_hashes.append(f'{lora_name}: {lora_hash}: {lora_weight}')
lora_hashes_string = ', '.join(lora_hashes)
generation_params = {
self.fooocus_to_a1111['steps']: data['steps'],
self.fooocus_to_a1111['sampler']: data['sampler'],
self.fooocus_to_a1111['seed']: data['seed'],
# TODO check resolution value, should be string
self.fooocus_to_a1111['resolution']: f'{width}x{heigth}',
# TODO load model by name / hash
self.fooocus_to_a1111['base_model']: data['base_model'].split('.')[0],
self.fooocus_to_a1111['base_model']: Path(data['base_model']).stem,
self.fooocus_to_a1111['base_model_hash']: data['base_model_hash']
}
if 'refiner_model' in data and data['refiner_model'] != 'None' and 'refiner_model_hash' in data:
generation_params |= {
self.fooocus_to_a1111['refiner_model']: data['refiner_model'].split('.')[0],
self.fooocus_to_a1111['refiner_model']: Path(data['refiner_model']).stem,
self.fooocus_to_a1111['refiner_model_hash']: data['refiner_model_hash'],
}