From 158afe088d4d7c0580b0e7e7e0c32412e219e118 Mon Sep 17 00:00:00 2001 From: lllyasviel Date: Sat, 12 Aug 2023 17:43:39 -0700 Subject: [PATCH] 1.0.19 (#33) Unlock to allow changing model. --- fooocus_version.py | 3 +- modules/core.py | 16 +++- modules/default_pipeline.py | 154 +++++++++++++++++++++++++++++------- modules/path.py | 32 ++++++++ modules/sdxl_styles.py | 4 +- update_log.md | 4 + webui.py | 45 +++++++++-- 7 files changed, 217 insertions(+), 41 deletions(-) diff --git a/fooocus_version.py b/fooocus_version.py index f7018a52..4d0c90e0 100644 --- a/fooocus_version.py +++ b/fooocus_version.py @@ -1,2 +1 @@ -version = '1.0.17' - +version = '1.0.19' diff --git a/modules/core.py b/modules/core.py index 799314c0..6abd3216 100644 --- a/modules/core.py +++ b/modules/core.py @@ -29,6 +29,14 @@ class StableDiffusionModel: self.clip = clip self.clip_vision = clip_vision + def to_meta(self): + if self.unet is not None: + self.unet.model.to('meta') + if self.clip is not None: + self.clip.cond_stage_model.to('meta') + if self.vae is not None: + self.vae.first_stage_model.to('meta') + @torch.no_grad() def load_model(ckpt_filename): @@ -42,8 +50,8 @@ def load_lora(model, lora_filename, strength_model=1.0, strength_clip=1.0): return model lora = comfy.utils.load_torch_file(lora_filename, safe_load=True) - model.unet, model.clip = comfy.sd.load_lora_for_models(model.unet, model.clip, lora, strength_model, strength_clip) - return model + unet, clip = comfy.sd.load_lora_for_models(model.unet, model.clip, lora, strength_model, strength_clip) + return StableDiffusionModel(unet=unet, clip=clip, vae=model.vae, clip_vision=model.clip_vision) @torch.no_grad() @@ -92,7 +100,7 @@ def get_previewer(device, latent_format): @torch.no_grad() def ksampler(model, positive, negative, latent, seed=None, steps=30, cfg=7.0, sampler_name='dpmpp_2m_sde_gpu', scheduler='karras', denoise=1.0, disable_noise=False, start_step=None, last_step=None, - force_full_denoise=False): + force_full_denoise=False, callback_function=None): # SCHEDULERS = ["normal", "karras", "exponential", "simple", "ddim_uniform"] # SAMPLERS = ["euler", "euler_ancestral", "heun", "dpm_2", "dpm_2_ancestral", # "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu", @@ -118,6 +126,8 @@ def ksampler(model, positive, negative, latent, seed=None, steps=30, cfg=7.0, sa pbar = comfy.utils.ProgressBar(steps) def callback(step, x0, x, total_steps): + if callback_function is not None: + callback_function(step, x0, x, total_steps) if previewer and step % 3 == 0: previewer.preview(x0, step, total_steps) pbar.update_absolute(step + 1, total_steps, None) diff --git a/modules/default_pipeline.py b/modules/default_pipeline.py index da401e9f..9fb98535 100644 --- a/modules/default_pipeline.py +++ b/modules/default_pipeline.py @@ -1,46 +1,146 @@ import modules.core as core import os import torch +import modules.path -from modules.path import modelfile_path, lorafile_path +from comfy.model_base import SDXL, SDXLRefiner -xl_base_filename = os.path.join(modelfile_path, 'sd_xl_base_1.0_0.9vae.safetensors') -xl_refiner_filename = os.path.join(modelfile_path, 'sd_xl_refiner_1.0_0.9vae.safetensors') -xl_base_offset_lora_filename = os.path.join(lorafile_path, 'sd_xl_offset_example-lora_1.0.safetensors') +xl_base: core.StableDiffusionModel = None +xl_base_hash = '' -xl_base = core.load_model(xl_base_filename) -xl_base = core.load_lora(xl_base, xl_base_offset_lora_filename, strength_model=0.5, strength_clip=0.0) -del xl_base.vae +xl_refiner: core.StableDiffusionModel = None +xl_refiner_hash = '' -xl_refiner = core.load_model(xl_refiner_filename) +xl_base_patched: core.StableDiffusionModel = None +xl_base_patched_hash = '' + + +def refresh_base_model(name): + global xl_base, xl_base_hash, xl_base_patched, xl_base_patched_hash + if xl_base_hash == str(name): + return + + filename = os.path.join(modules.path.modelfile_path, name) + + if xl_base is not None: + xl_base.to_meta() + xl_base = None + + xl_base = core.load_model(filename) + if not isinstance(xl_base.unet.model, SDXL): + print('Model not supported. Fooocus only support SDXL model as the base model.') + xl_base = None + xl_base_hash = '' + refresh_base_model(modules.path.default_base_model_name) + xl_base_hash = name + xl_base_patched = xl_base + xl_base_patched_hash = '' + return + + xl_base_hash = name + xl_base_patched = xl_base + xl_base_patched_hash = '' + print(f'Base model loaded: {xl_base_hash}') + + return + + +def refresh_refiner_model(name): + global xl_refiner, xl_refiner_hash + if xl_refiner_hash == str(name): + return + + if name == 'None': + xl_refiner = None + xl_refiner_hash = '' + print(f'Refiner unloaded.') + return + + filename = os.path.join(modules.path.modelfile_path, name) + + if xl_refiner is not None: + xl_refiner.to_meta() + xl_refiner = None + + xl_refiner = core.load_model(filename) + if not isinstance(xl_refiner.unet.model, SDXLRefiner): + print('Model not supported. Fooocus only support SDXL refiner as the refiner.') + xl_refiner = None + xl_refiner_hash = '' + print(f'Refiner unloaded.') + return + + xl_refiner_hash = name + print(f'Refiner model loaded: {xl_refiner_hash}') + + xl_refiner.vae.first_stage_model.to('meta') + xl_refiner.vae = None + return + + +def refresh_loras(loras): + global xl_base, xl_base_patched, xl_base_patched_hash + if xl_base_patched_hash == str(loras): + return + + model = xl_base + for name, weight in loras: + if name == 'None': + continue + + filename = os.path.join(modules.path.lorafile_path, name) + model = core.load_lora(model, filename, strength_model=weight, strength_clip=weight) + xl_base_patched = model + xl_base_patched_hash = str(loras) + print(f'LoRAs loaded: {xl_base_patched_hash}') + + return + + +refresh_base_model(modules.path.default_base_model_name) +refresh_refiner_model(modules.path.default_refiner_model_name) +refresh_loras([(modules.path.default_lora_name, 0.5), ('None', 0.5), ('None', 0.5), ('None', 0.5), ('None', 0.5)]) @torch.no_grad() def process(positive_prompt, negative_prompt, steps, switch, width, height, image_seed, callback): - positive_conditions = core.encode_prompt_condition(clip=xl_base.clip, prompt=positive_prompt) - negative_conditions = core.encode_prompt_condition(clip=xl_base.clip, prompt=negative_prompt) - - positive_conditions_refiner = core.encode_prompt_condition(clip=xl_refiner.clip, prompt=positive_prompt) - negative_conditions_refiner = core.encode_prompt_condition(clip=xl_refiner.clip, prompt=negative_prompt) + positive_conditions = core.encode_prompt_condition(clip=xl_base_patched.clip, prompt=positive_prompt) + negative_conditions = core.encode_prompt_condition(clip=xl_base_patched.clip, prompt=negative_prompt) empty_latent = core.generate_empty_latent(width=width, height=height, batch_size=1) - sampled_latent = core.ksampler_with_refiner( - model=xl_base.unet, - positive=positive_conditions, - negative=negative_conditions, - refiner=xl_refiner.unet, - refiner_positive=positive_conditions_refiner, - refiner_negative=negative_conditions_refiner, - refiner_switch_step=switch, - latent=empty_latent, - steps=steps, start_step=0, last_step=steps, disable_noise=False, force_full_denoise=True, - seed=image_seed, - callback_function=callback - ) + if xl_refiner is not None: - decoded_latent = core.decode_vae(vae=xl_refiner.vae, latent_image=sampled_latent) + positive_conditions_refiner = core.encode_prompt_condition(clip=xl_refiner.clip, prompt=positive_prompt) + negative_conditions_refiner = core.encode_prompt_condition(clip=xl_refiner.clip, prompt=negative_prompt) + + sampled_latent = core.ksampler_with_refiner( + model=xl_base_patched.unet, + positive=positive_conditions, + negative=negative_conditions, + refiner=xl_refiner.unet, + refiner_positive=positive_conditions_refiner, + refiner_negative=negative_conditions_refiner, + refiner_switch_step=switch, + latent=empty_latent, + steps=steps, start_step=0, last_step=steps, disable_noise=False, force_full_denoise=True, + seed=image_seed, + callback_function=callback + ) + + else: + sampled_latent = core.ksampler( + model=xl_base_patched.unet, + positive=positive_conditions, + negative=negative_conditions, + latent=empty_latent, + steps=steps, start_step=0, last_step=steps, disable_noise=False, force_full_denoise=True, + seed=image_seed, + callback_function=callback + ) + + decoded_latent = core.decode_vae(vae=xl_base_patched.vae, latent_image=sampled_latent) images = core.image_to_numpy(decoded_latent) diff --git a/modules/path.py b/modules/path.py index d192062f..189f98fc 100644 --- a/modules/path.py +++ b/modules/path.py @@ -5,3 +5,35 @@ lorafile_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '../mode temp_outputs_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '../outputs/')) os.makedirs(temp_outputs_path, exist_ok=True) + +default_base_model_name = 'sd_xl_base_1.0_0.9vae.safetensors' +default_refiner_model_name = 'sd_xl_refiner_1.0_0.9vae.safetensors' +default_lora_name = 'sd_xl_offset_example-lora_1.0.safetensors' +default_lora_weight = 0.5 + +model_filenames = [] +lora_filenames = [] + + +def get_model_filenames(folder_path): + if not os.path.isdir(folder_path): + raise ValueError("Folder path is not a valid directory.") + + filenames = [] + for filename in os.listdir(folder_path): + if os.path.isfile(os.path.join(folder_path, filename)): + _, file_extension = os.path.splitext(filename) + if file_extension.lower() in ['.pth', '.ckpt', '.bin', '.safetensors']: + filenames.append(filename) + + return filenames + + +def update_all_model_names(): + global model_filenames, lora_filenames + model_filenames = get_model_filenames(modelfile_path) + lora_filenames = get_model_filenames(lorafile_path) + return + + +update_all_model_names() diff --git a/modules/sdxl_styles.py b/modules/sdxl_styles.py index 27da3f10..a1dfdebc 100644 --- a/modules/sdxl_styles.py +++ b/modules/sdxl_styles.py @@ -2,7 +2,7 @@ styles = [ { - "name": "sai-base", + "name": "None", "prompt": "{prompt}", "negative_prompt": "" }, @@ -529,7 +529,7 @@ styles = [ ] styles = {k['name']: (k['prompt'], k['negative_prompt']) for k in styles} -default_style = styles['sai-base'] +default_style = styles['None'] style_keys = list(styles.keys()) diff --git a/update_log.md b/update_log.md index ab5825c5..0f150375 100644 --- a/update_log.md +++ b/update_log.md @@ -1,3 +1,7 @@ +### 1.0.19 + +* Unlock to allow changing model. + ### 1.0.17 * Change default model to SDXL-1.0-vae-0.9. (This means the models will be downloaded again, but we should do it as early as possible so that all new users only need to download once. Really sorry for day-0 users. But frankly this is not too late considering that the project is just publicly available in less than 24 hours - if it has been a week then we will prefer more lightweight tricks to update.) diff --git a/webui.py b/webui.py index 4e6336e5..f9bc4a8c 100644 --- a/webui.py +++ b/webui.py @@ -1,16 +1,23 @@ import gradio as gr +import modules.path import random import fooocus_version +import modules.default_pipeline as pipeline from modules.sdxl_styles import apply_style, style_keys, aspect_ratios -from modules.default_pipeline import process from modules.cv2win32 import close_all_preview, save_image from modules.util import generate_temp_filename -from modules.path import temp_outputs_path def generate_clicked(prompt, negative_prompt, style_selction, performance_selction, - aspect_ratios_selction, image_number, image_seed, progress=gr.Progress()): + aspect_ratios_selction, image_number, image_seed, base_model_name, refiner_model_name, + l1, w1, l2, w2, l3, w3, l4, w4, l5, w5, progress=gr.Progress()): + + loras = [(l1, w1), (l2, w2), (l3, w3), (l4, w4), (l5, w5)] + + pipeline.refresh_base_model(base_model_name) + pipeline.refresh_refiner_model(refiner_model_name) + pipeline.refresh_loras(loras) p_txt, n_txt = apply_style(style_selction, prompt, negative_prompt) @@ -35,10 +42,10 @@ def generate_clicked(prompt, negative_prompt, style_selction, performance_selcti progress(float(done_steps) / float(all_steps), f'Step {step}/{total_steps} in the {i}-th Sampling') for i in range(image_number): - imgs = process(p_txt, n_txt, steps, switch, width, height, seed, callback=callback) + imgs = pipeline.process(p_txt, n_txt, steps, switch, width, height, seed, callback=callback) for x in imgs: - local_temp_filename = generate_temp_filename(folder=temp_outputs_path, extension='png') + local_temp_filename = generate_temp_filename(folder=modules.path.temp_outputs_path, extension='png') save_image(local_temp_filename, x) seed += 1 @@ -61,21 +68,45 @@ with block: with gr.Row(): advanced_checkbox = gr.Checkbox(label='Advanced', value=False, container=False) with gr.Column(scale=0.5, visible=False) as right_col: - with gr.Tab(label='Generator Setting'): + with gr.Tab(label='Setting'): performance_selction = gr.Radio(label='Performance', choices=['Speed', 'Quality'], value='Speed') aspect_ratios_selction = gr.Radio(label='Aspect Ratios (width × height)', choices=list(aspect_ratios.keys()), value='1152×896') image_number = gr.Slider(label='Image Number', minimum=1, maximum=32, step=1, value=2) image_seed = gr.Number(label='Random Seed', value=-1, precision=0) negative_prompt = gr.Textbox(label='Negative Prompt', show_label=True, placeholder="Type prompt here.") - with gr.Tab(label='Image Style'): + with gr.Tab(label='Style'): style_selction = gr.Radio(show_label=False, container=True, choices=style_keys, value='cinematic-default') + with gr.Tab(label='Advanced'): + with gr.Row(): + base_model = gr.Dropdown(label='SDXL Base Model', choices=modules.path.model_filenames, value=modules.path.default_base_model_name, show_label=True) + refiner_model = gr.Dropdown(label='SDXL Refiner', choices=['None'] + modules.path.model_filenames, value=modules.path.default_refiner_model_name, show_label=True) + with gr.Accordion(label='LoRAs', open=True): + lora_ctrls = [] + for i in range(5): + with gr.Row(): + lora_model = gr.Dropdown(label=f'SDXL LoRA {i+1}', choices=['None'] + modules.path.lora_filenames, value=modules.path.default_lora_name if i == 0 else 'None') + lora_weight = gr.Slider(label='Weight', minimum=-2, maximum=2, step=0.01, value=modules.path.default_lora_weight) + lora_ctrls += [lora_model, lora_weight] + model_refresh = gr.Button(label='Refresh', value='Refresh All Files', variant='secondary') + + def model_refresh_clicked(): + modules.path.update_all_model_names() + results = [] + results += [gr.update(choices=modules.path.model_filenames), gr.update(choices=['None'] + modules.path.model_filenames)] + for i in range(5): + results += [gr.update(choices=['None'] + modules.path.lora_filenames), gr.update()] + return results + + model_refresh.click(model_refresh_clicked, [], [base_model, refiner_model] + lora_ctrls) + advanced_checkbox.change(lambda x: gr.update(visible=x), advanced_checkbox, right_col) ctrls = [ prompt, negative_prompt, style_selction, performance_selction, aspect_ratios_selction, image_number, image_seed ] + ctrls += [base_model, refiner_model] + lora_ctrls run_button.click(fn=generate_clicked, inputs=ctrls, outputs=[gallery]) block.launch(inbrowser=True)