feat: add sampler mapping
This commit is contained in:
parent
ed4a958da8
commit
63403d614e
|
|
@ -12,12 +12,43 @@ uov_list = [
|
|||
disabled, subtle_variation, strong_variation, upscale_15, upscale_2, upscale_fast
|
||||
]
|
||||
|
||||
KSAMPLER_NAMES = ["euler", "euler_ancestral", "heun", "heunpp2","dpm_2", "dpm_2_ancestral",
|
||||
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu",
|
||||
"dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm"]
|
||||
CIVITAI_NO_KARRAS = ["euler", "euler_ancestral", "heun", "dpm_fast", "dpm_adaptive", "ddim", "uni_pc"]
|
||||
|
||||
# fooocus: a1111 (Civitai)
|
||||
KSAMPLER = {
|
||||
"euler": "Euler",
|
||||
"euler_ancestral": "Euler a",
|
||||
"heun": "Heun",
|
||||
"heunpp2": "",
|
||||
"dpm_2": "DPM2",
|
||||
"dpm_2_ancestral": "DPM2 a",
|
||||
"lms": "LMS",
|
||||
"dpm_fast": "DPM fast",
|
||||
"dpm_adaptive": "DPM adaptive",
|
||||
"dpmpp_2s_ancestral": "DPM++ 2S a",
|
||||
"dpmpp_sde": "DPM++ SDE",
|
||||
"dpmpp_sde_gpu": "",
|
||||
"dpmpp_2m": "DPM++ 2M",
|
||||
"dpmpp_2m_sde": "DPM++ 2M SDE",
|
||||
"dpmpp_2m_sde_gpu": "",
|
||||
"dpmpp_3m_sde": "",
|
||||
"dpmpp_3m_sde_gpu": "",
|
||||
"ddpm": "",
|
||||
"lcm": "LCM"
|
||||
}
|
||||
|
||||
SAMPLER_EXTRA = {
|
||||
"ddim": "DDIM",
|
||||
"uni_pc": "UniPC",
|
||||
"uni_pc_bh2": ""
|
||||
}
|
||||
|
||||
SAMPLERS = KSAMPLER | SAMPLER_EXTRA
|
||||
|
||||
KSAMPLER_NAMES = list(KSAMPLER.keys())
|
||||
|
||||
SCHEDULER_NAMES = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform", "lcm", "turbo"]
|
||||
SAMPLER_NAMES = KSAMPLER_NAMES + ["ddim", "uni_pc", "uni_pc_bh2"]
|
||||
SAMPLER_NAMES = KSAMPLER_NAMES + list(SAMPLER_EXTRA.keys())
|
||||
|
||||
sampler_list = SAMPLER_NAMES
|
||||
scheduler_list = SCHEDULER_NAMES
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ from PIL import Image
|
|||
import modules.config
|
||||
import modules.sdxl_styles
|
||||
from modules.flags import MetadataScheme, Performance, Steps
|
||||
from modules.flags import lora_count
|
||||
from modules.flags import lora_count, SAMPLERS, CIVITAI_NO_KARRAS
|
||||
from modules.util import quote, unquote, extract_styles_from_prompt, is_json, calculate_sha256
|
||||
|
||||
re_param_code = r'\s*(\w[\w \-/]+):\s*("(?:\\.|[^\\"])+"|[^,]*)(?:,|$)'
|
||||
|
|
@ -246,6 +246,7 @@ class A1111MetadataParser(MetadataParser):
|
|||
'performance': 'Performance',
|
||||
'steps': 'Steps',
|
||||
'sampler': 'Sampler',
|
||||
'scheduler': 'Scheduler',
|
||||
'guidance_scale': 'CFG scale',
|
||||
'seed': 'Seed',
|
||||
'resolution': 'Size',
|
||||
|
|
@ -325,6 +326,12 @@ class A1111MetadataParser(MetadataParser):
|
|||
except Exception:
|
||||
pass
|
||||
|
||||
if 'sampler' in data:
|
||||
sampler = data['sampler'].replace(' Karras', '')
|
||||
# get key
|
||||
data['sampler'] = [k for k, v in SAMPLERS.items() if v == sampler][0]
|
||||
|
||||
|
||||
for key in ['base_model', 'refiner_model']:
|
||||
if key in data:
|
||||
for filename in modules.config.model_filenames:
|
||||
|
|
@ -351,9 +358,16 @@ class A1111MetadataParser(MetadataParser):
|
|||
|
||||
width, height = eval(data['resolution'])
|
||||
|
||||
sampler = data['sampler']
|
||||
scheduler = data['scheduler']
|
||||
if sampler in SAMPLERS and SAMPLERS[sampler] != '':
|
||||
sampler = SAMPLERS[sampler]
|
||||
if sampler not in CIVITAI_NO_KARRAS and scheduler == 'karras':
|
||||
sampler += f' Karras'
|
||||
|
||||
generation_params = {
|
||||
self.fooocus_to_a1111['steps']: self.steps,
|
||||
self.fooocus_to_a1111['sampler']: data['sampler'],
|
||||
self.fooocus_to_a1111['sampler']: sampler,
|
||||
self.fooocus_to_a1111['seed']: data['seed'],
|
||||
self.fooocus_to_a1111['resolution']: f'{width}x{height}',
|
||||
self.fooocus_to_a1111['guidance_scale']: data['guidance_scale'],
|
||||
|
|
@ -363,6 +377,7 @@ class A1111MetadataParser(MetadataParser):
|
|||
self.fooocus_to_a1111['base_model_hash']: self.base_model_hash,
|
||||
|
||||
self.fooocus_to_a1111['performance']: data['performance'],
|
||||
self.fooocus_to_a1111['scheduler']: scheduler,
|
||||
# workaround for multiline prompts
|
||||
self.fooocus_to_a1111['raw_prompt']: self.raw_prompt,
|
||||
self.fooocus_to_a1111['raw_negative_prompt']: self.raw_negative_prompt,
|
||||
|
|
|
|||
Loading…
Reference in New Issue