From 63403d614e196b54d374058e992f92476d46d040 Mon Sep 17 00:00:00 2001 From: Manuel Schmid Date: Fri, 2 Feb 2024 23:44:47 +0100 Subject: [PATCH] feat: add sampler mapping --- modules/flags.py | 39 +++++++++++++++++++++++++++++++++++---- modules/meta_parser.py | 19 +++++++++++++++++-- 2 files changed, 52 insertions(+), 6 deletions(-) diff --git a/modules/flags.py b/modules/flags.py index db26394e..06ced601 100644 --- a/modules/flags.py +++ b/modules/flags.py @@ -12,12 +12,43 @@ uov_list = [ disabled, subtle_variation, strong_variation, upscale_15, upscale_2, upscale_fast ] -KSAMPLER_NAMES = ["euler", "euler_ancestral", "heun", "heunpp2","dpm_2", "dpm_2_ancestral", - "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu", - "dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm"] +CIVITAI_NO_KARRAS = ["euler", "euler_ancestral", "heun", "dpm_fast", "dpm_adaptive", "ddim", "uni_pc"] + +# fooocus: a1111 (Civitai) +KSAMPLER = { + "euler": "Euler", + "euler_ancestral": "Euler a", + "heun": "Heun", + "heunpp2": "", + "dpm_2": "DPM2", + "dpm_2_ancestral": "DPM2 a", + "lms": "LMS", + "dpm_fast": "DPM fast", + "dpm_adaptive": "DPM adaptive", + "dpmpp_2s_ancestral": "DPM++ 2S a", + "dpmpp_sde": "DPM++ SDE", + "dpmpp_sde_gpu": "", + "dpmpp_2m": "DPM++ 2M", + "dpmpp_2m_sde": "DPM++ 2M SDE", + "dpmpp_2m_sde_gpu": "", + "dpmpp_3m_sde": "", + "dpmpp_3m_sde_gpu": "", + "ddpm": "", + "lcm": "LCM" +} + +SAMPLER_EXTRA = { + "ddim": "DDIM", + "uni_pc": "UniPC", + "uni_pc_bh2": "" +} + +SAMPLERS = KSAMPLER | SAMPLER_EXTRA + +KSAMPLER_NAMES = list(KSAMPLER.keys()) SCHEDULER_NAMES = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform", "lcm", "turbo"] -SAMPLER_NAMES = KSAMPLER_NAMES + ["ddim", "uni_pc", "uni_pc_bh2"] +SAMPLER_NAMES = KSAMPLER_NAMES + list(SAMPLER_EXTRA.keys()) sampler_list = SAMPLER_NAMES scheduler_list = SCHEDULER_NAMES diff --git a/modules/meta_parser.py b/modules/meta_parser.py index aa0cd10e..1bbb1abf 100644 --- a/modules/meta_parser.py +++ b/modules/meta_parser.py @@ -10,7 +10,7 @@ from PIL import Image import modules.config import modules.sdxl_styles from modules.flags import MetadataScheme, Performance, Steps -from modules.flags import lora_count +from modules.flags import lora_count, SAMPLERS, CIVITAI_NO_KARRAS from modules.util import quote, unquote, extract_styles_from_prompt, is_json, calculate_sha256 re_param_code = r'\s*(\w[\w \-/]+):\s*("(?:\\.|[^\\"])+"|[^,]*)(?:,|$)' @@ -246,6 +246,7 @@ class A1111MetadataParser(MetadataParser): 'performance': 'Performance', 'steps': 'Steps', 'sampler': 'Sampler', + 'scheduler': 'Scheduler', 'guidance_scale': 'CFG scale', 'seed': 'Seed', 'resolution': 'Size', @@ -325,6 +326,12 @@ class A1111MetadataParser(MetadataParser): except Exception: pass + if 'sampler' in data: + sampler = data['sampler'].replace(' Karras', '') + # get key + data['sampler'] = [k for k, v in SAMPLERS.items() if v == sampler][0] + + for key in ['base_model', 'refiner_model']: if key in data: for filename in modules.config.model_filenames: @@ -351,9 +358,16 @@ class A1111MetadataParser(MetadataParser): width, height = eval(data['resolution']) + sampler = data['sampler'] + scheduler = data['scheduler'] + if sampler in SAMPLERS and SAMPLERS[sampler] != '': + sampler = SAMPLERS[sampler] + if sampler not in CIVITAI_NO_KARRAS and scheduler == 'karras': + sampler += f' Karras' + generation_params = { self.fooocus_to_a1111['steps']: self.steps, - self.fooocus_to_a1111['sampler']: data['sampler'], + self.fooocus_to_a1111['sampler']: sampler, self.fooocus_to_a1111['seed']: data['seed'], self.fooocus_to_a1111['resolution']: f'{width}x{height}', self.fooocus_to_a1111['guidance_scale']: data['guidance_scale'], @@ -363,6 +377,7 @@ class A1111MetadataParser(MetadataParser): self.fooocus_to_a1111['base_model_hash']: self.base_model_hash, self.fooocus_to_a1111['performance']: data['performance'], + self.fooocus_to_a1111['scheduler']: scheduler, # workaround for multiline prompts self.fooocus_to_a1111['raw_prompt']: self.raw_prompt, self.fooocus_to_a1111['raw_negative_prompt']: self.raw_negative_prompt,