diff --git a/modules/async_worker.py b/modules/async_worker.py index 40abb7fa..37ab09ae 100644 --- a/modules/async_worker.py +++ b/modules/async_worker.py @@ -39,8 +39,17 @@ def worker(): from modules.sdxl_styles import apply_style, apply_wildcards, fooocus_expansion from modules.private_logger import log from extras.expansion import safe_str - from modules.util import remove_empty_str, HWC3, resize_image, \ - get_image_shape_ceil, set_image_shape_ceil, get_shape_ceil, resample_image, erode_or_dilate, ordinal_suffix + from modules.util import ( + remove_empty_str, + HWC3, resize_image, + get_image_shape_ceil, + set_image_shape_ceil, + get_shape_ceil, + resample_image, + erode_or_dilate, + ordinal_suffix, + parse_lora_references_from_prompt + ) from modules.upscaler import perform_upscale try: @@ -370,6 +379,10 @@ def worker(): extra_negative_prompts = negative_prompts[1:] if len(negative_prompts) > 1 else [] progressbar(async_task, 3, 'Loading models ...') + + # Parse lora references from prompt + loras = parse_lora_references_from_prompt(prompt, loras) + pipeline.refresh_everything(refiner_model_name=refiner_model_name, base_model_name=base_model_name, loras=loras, base_model_additional_loras=base_model_additional_loras, use_synthetic_refiner=use_synthetic_refiner) diff --git a/modules/util.py b/modules/util.py index c309480a..2355dae3 100644 --- a/modules/util.py +++ b/modules/util.py @@ -4,6 +4,8 @@ import random import math import os import cv2 +import re +from typing import List, Tuple, AnyStr from PIL import Image @@ -179,3 +181,23 @@ def get_files_from_folder(folder_path, exensions=None, name_filter=None): def ordinal_suffix(number: int) -> str: return 'th' if 10 <= number % 100 <= 20 else {1: 'st', 2: 'nd', 3: 'rd'}.get(number % 10, 'th') + + +def parse_lora_references_from_prompt(items: str, loras: List[Tuple[AnyStr, float]]): + pattern = re.compile(".*.*") + new_loras = [] + + for token in items.split(","): + print(token) + m = pattern.match(token) + + if m: + new_loras.append((f"{m.group(1)}.safetensors", float(m.group(2)))) + + updated_loras = [] + for lora in loras + new_loras: + + if lora[0] != "None": + updated_loras.append(lora) + + return updated_loras