Unlock to allow changing model.
lllyasviel 2023-08-12 17:43:39 -07:00 committed by GitHub
parent 1ff382c8ef
commit 158afe088d
7 changed files with 217 additions and 41 deletions

View File

@@ -1,2 +1 @@
-version = '1.0.17'
+version = '1.0.19'

View File

@@ -29,6 +29,14 @@ class StableDiffusionModel:
         self.clip = clip
         self.clip_vision = clip_vision
+    def to_meta(self):
+        if self.unet is not None:
+            self.unet.model.to('meta')
+        if self.clip is not None:
+            self.clip.cond_stage_model.to('meta')
+        if self.vae is not None:
+            self.vae.first_stage_model.to('meta')
 @torch.no_grad()
 def load_model(ckpt_filename):
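
The new to_meta() helper moves each wrapped sub-model onto PyTorch's 'meta' device, which replaces the weight tensors with shape-only placeholders so the old checkpoint's memory can be reclaimed before a different one is loaded. A minimal sketch of the idea, using a plain torch module as a stand-in for the real UNet/CLIP/VAE objects:

    import torch

    layer = torch.nn.Linear(4096, 4096)  # stand-in for unet.model / clip.cond_stage_model / vae.first_stage_model
    print(layer.weight.device)           # cpu, with real storage allocated

    layer.to('meta')                     # parameters become shape-only meta tensors; their storage is released
    print(layer.weight.device)           # meta

    # After this the module can no longer run a forward pass, so it is only done
    # right before the wrapper is discarded and a replacement checkpoint is loaded.
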
@@ -42,8 +50,8 @@ def load_lora(model, lora_filename, strength_model=1.0, strength_clip=1.0):
         return model
     lora = comfy.utils.load_torch_file(lora_filename, safe_load=True)
-    model.unet, model.clip = comfy.sd.load_lora_for_models(model.unet, model.clip, lora, strength_model, strength_clip)
-    return model
+    unet, clip = comfy.sd.load_lora_for_models(model.unet, model.clip, lora, strength_model, strength_clip)
+    return StableDiffusionModel(unet=unet, clip=clip, vae=model.vae, clip_vision=model.clip_vision)
 @torch.no_grad()
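
load_lora now leaves the input wrapper untouched and returns a fresh StableDiffusionModel carrying the patched unet/clip, so the cached base model stays clean and a LoRA selection can always be re-applied from scratch. A hedged usage sketch (both file paths are placeholders):

    import modules.core as core

    base = core.load_model('/path/to/base_checkpoint.safetensors')      # placeholder path
    patched = core.load_lora(base, '/path/to/some_lora.safetensors',    # placeholder path
                             strength_model=0.5, strength_clip=0.5)

    # `base` keeps its original unet/clip; `patched` shares the vae and clip_vision
    # but carries the LoRA-patched copies, so re-running with different LoRAs
    # always starts from the unmodified base.
    assert patched is not base
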
@@ -92,7 +100,7 @@ def get_previewer(device, latent_format):
 @torch.no_grad()
 def ksampler(model, positive, negative, latent, seed=None, steps=30, cfg=7.0, sampler_name='dpmpp_2m_sde_gpu',
              scheduler='karras', denoise=1.0, disable_noise=False, start_step=None, last_step=None,
-             force_full_denoise=False):
+             force_full_denoise=False, callback_function=None):
     # SCHEDULERS = ["normal", "karras", "exponential", "simple", "ddim_uniform"]
     # SAMPLERS = ["euler", "euler_ancestral", "heun", "dpm_2", "dpm_2_ancestral",
     #             "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu",
@@ -118,6 +126,8 @@ def ksampler(model, positive, negative, latent, seed=None, steps=30, cfg=7.0, sa
     pbar = comfy.utils.ProgressBar(steps)
     def callback(step, x0, x, total_steps):
+        if callback_function is not None:
+            callback_function(step, x0, x, total_steps)
         if previewer and step % 3 == 0:
             previewer.preview(x0, step, total_steps)
         pbar.update_absolute(step + 1, total_steps, None)
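
ksampler's internal per-step callback now forwards to an optional callback_function before updating the preview and the progress bar; this is the hook webui.py uses further down to drive its Gradio progress display. A sketch of the callback shape a caller is expected to provide:

    def on_step(step, x0, x, total_steps):
        # step: current sampling step, x0: current denoised estimate,
        # x: current noisy latent, total_steps: number of planned steps.
        print(f'step {step + 1}/{total_steps}')

    # Passed through as, e.g.:
    # core.ksampler(model=..., positive=..., negative=..., latent=..., callback_function=on_step)
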

View File

@@ -1,46 +1,146 @@
 import modules.core as core
 import os
 import torch
+import modules.path
-from modules.path import modelfile_path, lorafile_path
+from comfy.model_base import SDXL, SDXLRefiner
-xl_base_filename = os.path.join(modelfile_path, 'sd_xl_base_1.0_0.9vae.safetensors')
-xl_refiner_filename = os.path.join(modelfile_path, 'sd_xl_refiner_1.0_0.9vae.safetensors')
-xl_base_offset_lora_filename = os.path.join(lorafile_path, 'sd_xl_offset_example-lora_1.0.safetensors')
+xl_base: core.StableDiffusionModel = None
+xl_base_hash = ''
-xl_base = core.load_model(xl_base_filename)
-xl_base = core.load_lora(xl_base, xl_base_offset_lora_filename, strength_model=0.5, strength_clip=0.0)
-del xl_base.vae
+xl_refiner: core.StableDiffusionModel = None
+xl_refiner_hash = ''
-xl_refiner = core.load_model(xl_refiner_filename)
+xl_base_patched: core.StableDiffusionModel = None
+xl_base_patched_hash = ''
+def refresh_base_model(name):
+    global xl_base, xl_base_hash, xl_base_patched, xl_base_patched_hash
+    if xl_base_hash == str(name):
+        return
+    filename = os.path.join(modules.path.modelfile_path, name)
+    if xl_base is not None:
+        xl_base.to_meta()
+        xl_base = None
+    xl_base = core.load_model(filename)
+    if not isinstance(xl_base.unet.model, SDXL):
+        print('Model not supported. Fooocus only support SDXL model as the base model.')
+        xl_base = None
+        xl_base_hash = ''
+        refresh_base_model(modules.path.default_base_model_name)
+        xl_base_hash = name
+        xl_base_patched = xl_base
+        xl_base_patched_hash = ''
+        return
+    xl_base_hash = name
+    xl_base_patched = xl_base
+    xl_base_patched_hash = ''
+    print(f'Base model loaded: {xl_base_hash}')
+    return
+def refresh_refiner_model(name):
+    global xl_refiner, xl_refiner_hash
+    if xl_refiner_hash == str(name):
+        return
+    if name == 'None':
+        xl_refiner = None
+        xl_refiner_hash = ''
+        print(f'Refiner unloaded.')
+        return
+    filename = os.path.join(modules.path.modelfile_path, name)
+    if xl_refiner is not None:
+        xl_refiner.to_meta()
+        xl_refiner = None
+    xl_refiner = core.load_model(filename)
+    if not isinstance(xl_refiner.unet.model, SDXLRefiner):
+        print('Model not supported. Fooocus only support SDXL refiner as the refiner.')
+        xl_refiner = None
+        xl_refiner_hash = ''
+        print(f'Refiner unloaded.')
+        return
+    xl_refiner_hash = name
+    print(f'Refiner model loaded: {xl_refiner_hash}')
+    xl_refiner.vae.first_stage_model.to('meta')
+    xl_refiner.vae = None
+    return
+def refresh_loras(loras):
+    global xl_base, xl_base_patched, xl_base_patched_hash
+    if xl_base_patched_hash == str(loras):
+        return
+    model = xl_base
+    for name, weight in loras:
+        if name == 'None':
+            continue
+        filename = os.path.join(modules.path.lorafile_path, name)
+        model = core.load_lora(model, filename, strength_model=weight, strength_clip=weight)
+    xl_base_patched = model
+    xl_base_patched_hash = str(loras)
+    print(f'LoRAs loaded: {xl_base_patched_hash}')
+    return
+refresh_base_model(modules.path.default_base_model_name)
+refresh_refiner_model(modules.path.default_refiner_model_name)
+refresh_loras([(modules.path.default_lora_name, 0.5), ('None', 0.5), ('None', 0.5), ('None', 0.5), ('None', 0.5)])
 @torch.no_grad()
 def process(positive_prompt, negative_prompt, steps, switch, width, height, image_seed, callback):
-    positive_conditions = core.encode_prompt_condition(clip=xl_base.clip, prompt=positive_prompt)
-    negative_conditions = core.encode_prompt_condition(clip=xl_base.clip, prompt=negative_prompt)
-    positive_conditions_refiner = core.encode_prompt_condition(clip=xl_refiner.clip, prompt=positive_prompt)
-    negative_conditions_refiner = core.encode_prompt_condition(clip=xl_refiner.clip, prompt=negative_prompt)
+    positive_conditions = core.encode_prompt_condition(clip=xl_base_patched.clip, prompt=positive_prompt)
+    negative_conditions = core.encode_prompt_condition(clip=xl_base_patched.clip, prompt=negative_prompt)
     empty_latent = core.generate_empty_latent(width=width, height=height, batch_size=1)
-    sampled_latent = core.ksampler_with_refiner(
-        model=xl_base.unet,
-        positive=positive_conditions,
-        negative=negative_conditions,
-        refiner=xl_refiner.unet,
-        refiner_positive=positive_conditions_refiner,
-        refiner_negative=negative_conditions_refiner,
-        refiner_switch_step=switch,
-        latent=empty_latent,
-        steps=steps, start_step=0, last_step=steps, disable_noise=False, force_full_denoise=True,
-        seed=image_seed,
-        callback_function=callback
-    )
+    if xl_refiner is not None:
-    decoded_latent = core.decode_vae(vae=xl_refiner.vae, latent_image=sampled_latent)
+        positive_conditions_refiner = core.encode_prompt_condition(clip=xl_refiner.clip, prompt=positive_prompt)
+        negative_conditions_refiner = core.encode_prompt_condition(clip=xl_refiner.clip, prompt=negative_prompt)
+        sampled_latent = core.ksampler_with_refiner(
+            model=xl_base_patched.unet,
+            positive=positive_conditions,
+            negative=negative_conditions,
+            refiner=xl_refiner.unet,
+            refiner_positive=positive_conditions_refiner,
+            refiner_negative=negative_conditions_refiner,
+            refiner_switch_step=switch,
+            latent=empty_latent,
+            steps=steps, start_step=0, last_step=steps, disable_noise=False, force_full_denoise=True,
+            seed=image_seed,
+            callback_function=callback
+        )
+    else:
+        sampled_latent = core.ksampler(
+            model=xl_base_patched.unet,
+            positive=positive_conditions,
+            negative=negative_conditions,
+            latent=empty_latent,
+            steps=steps, start_step=0, last_step=steps, disable_noise=False, force_full_denoise=True,
+            seed=image_seed,
+            callback_function=callback
+        )
+    decoded_latent = core.decode_vae(vae=xl_base_patched.vae, latent_image=sampled_latent)
     images = core.image_to_numpy(decoded_latent)
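
The pipeline now keeps xl_base, xl_refiner and xl_base_patched as module-level state guarded by simple string hashes, so the refresh_* calls are cheap no-ops unless the requested selection actually changed, and process() falls back to plain core.ksampler and the base model's own VAE when no refiner is loaded. A hedged sketch of switching models between generations, assuming the named files exist in the model folders (prompt, step counts and seed below are just example values):

    import modules.path
    import modules.default_pipeline as pipeline

    def on_step(step, x0, x, total_steps):
        print(f'step {step + 1}/{total_steps}')

    # Re-using the current selection is a no-op; only a real change reloads weights.
    pipeline.refresh_base_model(modules.path.default_base_model_name)
    pipeline.refresh_refiner_model('None')   # unloads the refiner entirely
    pipeline.refresh_loras([(modules.path.default_lora_name, 0.25),
                            ('None', 0.5), ('None', 0.5), ('None', 0.5), ('None', 0.5)])

    images = pipeline.process('a photo of a cat', '', steps=30, switch=20,
                              width=1152, height=896, image_seed=123456,
                              callback=on_step)
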

View File

@@ -5,3 +5,35 @@ lorafile_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '../mode
 temp_outputs_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '../outputs/'))
 os.makedirs(temp_outputs_path, exist_ok=True)
+default_base_model_name = 'sd_xl_base_1.0_0.9vae.safetensors'
+default_refiner_model_name = 'sd_xl_refiner_1.0_0.9vae.safetensors'
+default_lora_name = 'sd_xl_offset_example-lora_1.0.safetensors'
+default_lora_weight = 0.5
+model_filenames = []
+lora_filenames = []
+def get_model_filenames(folder_path):
+    if not os.path.isdir(folder_path):
+        raise ValueError("Folder path is not a valid directory.")
+    filenames = []
+    for filename in os.listdir(folder_path):
+        if os.path.isfile(os.path.join(folder_path, filename)):
+            _, file_extension = os.path.splitext(filename)
+            if file_extension.lower() in ['.pth', '.ckpt', '.bin', '.safetensors']:
+                filenames.append(filename)
+    return filenames
+def update_all_model_names():
+    global model_filenames, lora_filenames
+    model_filenames = get_model_filenames(modelfile_path)
+    lora_filenames = get_model_filenames(lorafile_path)
+    return
+update_all_model_names()
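
get_model_filenames simply filters a directory listing down to checkpoint-like extensions, and update_all_model_names refreshes the two cached lists that the new UI dropdowns read. A small usage sketch:

    import modules.path

    # After dropping a new .safetensors file into the checkpoint or LoRA folder,
    # one call refreshes both cached lists behind the dropdowns.
    modules.path.update_all_model_names()
    print(modules.path.model_filenames)
    print(modules.path.lora_filenames)
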

View File

@@ -2,7 +2,7 @@
 styles = [
     {
-        "name": "sai-base",
+        "name": "None",
         "prompt": "{prompt}",
         "negative_prompt": ""
     },
@@ -529,7 +529,7 @@ styles = [
 ]
 styles = {k['name']: (k['prompt'], k['negative_prompt']) for k in styles}
-default_style = styles['sai-base']
+default_style = styles['None']
 style_keys = list(styles.keys())

View File

@@ -1,3 +1,7 @@
+### 1.0.19
+* Unlock to allow changing model.
 ### 1.0.17
 * Change default model to SDXL-1.0-vae-0.9. (This means the models will be downloaded again, but we should do it as early as possible so that all new users only need to download once. Really sorry for day-0 users. But frankly this is not too late considering that the project is just publicly available in less than 24 hours - if it has been a week then we will prefer more lightweight tricks to update.)

View File

@@ -1,16 +1,23 @@
 import gradio as gr
+import modules.path
 import random
 import fooocus_version
+import modules.default_pipeline as pipeline
 from modules.sdxl_styles import apply_style, style_keys, aspect_ratios
-from modules.default_pipeline import process
 from modules.cv2win32 import close_all_preview, save_image
 from modules.util import generate_temp_filename
-from modules.path import temp_outputs_path
 def generate_clicked(prompt, negative_prompt, style_selction, performance_selction,
-                     aspect_ratios_selction, image_number, image_seed, progress=gr.Progress()):
+                     aspect_ratios_selction, image_number, image_seed, base_model_name, refiner_model_name,
+                     l1, w1, l2, w2, l3, w3, l4, w4, l5, w5, progress=gr.Progress()):
+    loras = [(l1, w1), (l2, w2), (l3, w3), (l4, w4), (l5, w5)]
+    pipeline.refresh_base_model(base_model_name)
+    pipeline.refresh_refiner_model(refiner_model_name)
+    pipeline.refresh_loras(loras)
     p_txt, n_txt = apply_style(style_selction, prompt, negative_prompt)
@@ -35,10 +42,10 @@ def generate_clicked(prompt, negative_prompt, style_selction, performance_selcti
         progress(float(done_steps) / float(all_steps), f'Step {step}/{total_steps} in the {i}-th Sampling')
     for i in range(image_number):
-        imgs = process(p_txt, n_txt, steps, switch, width, height, seed, callback=callback)
+        imgs = pipeline.process(p_txt, n_txt, steps, switch, width, height, seed, callback=callback)
         for x in imgs:
-            local_temp_filename = generate_temp_filename(folder=temp_outputs_path, extension='png')
+            local_temp_filename = generate_temp_filename(folder=modules.path.temp_outputs_path, extension='png')
             save_image(local_temp_filename, x)
         seed += 1
@@ -61,21 +68,45 @@ with block:
             with gr.Row():
                 advanced_checkbox = gr.Checkbox(label='Advanced', value=False, container=False)
         with gr.Column(scale=0.5, visible=False) as right_col:
-            with gr.Tab(label='Generator Setting'):
+            with gr.Tab(label='Setting'):
                 performance_selction = gr.Radio(label='Performance', choices=['Speed', 'Quality'], value='Speed')
                 aspect_ratios_selction = gr.Radio(label='Aspect Ratios (width × height)', choices=list(aspect_ratios.keys()),
                                                   value='1152×896')
                 image_number = gr.Slider(label='Image Number', minimum=1, maximum=32, step=1, value=2)
                 image_seed = gr.Number(label='Random Seed', value=-1, precision=0)
                 negative_prompt = gr.Textbox(label='Negative Prompt', show_label=True, placeholder="Type prompt here.")
-            with gr.Tab(label='Image Style'):
+            with gr.Tab(label='Style'):
                 style_selction = gr.Radio(show_label=False, container=True,
                                           choices=style_keys, value='cinematic-default')
+            with gr.Tab(label='Advanced'):
+                with gr.Row():
+                    base_model = gr.Dropdown(label='SDXL Base Model', choices=modules.path.model_filenames, value=modules.path.default_base_model_name, show_label=True)
+                    refiner_model = gr.Dropdown(label='SDXL Refiner', choices=['None'] + modules.path.model_filenames, value=modules.path.default_refiner_model_name, show_label=True)
+                with gr.Accordion(label='LoRAs', open=True):
+                    lora_ctrls = []
+                    for i in range(5):
+                        with gr.Row():
+                            lora_model = gr.Dropdown(label=f'SDXL LoRA {i+1}', choices=['None'] + modules.path.lora_filenames, value=modules.path.default_lora_name if i == 0 else 'None')
+                            lora_weight = gr.Slider(label='Weight', minimum=-2, maximum=2, step=0.01, value=modules.path.default_lora_weight)
+                            lora_ctrls += [lora_model, lora_weight]
+                model_refresh = gr.Button(label='Refresh', value='Refresh All Files', variant='secondary')
+                def model_refresh_clicked():
+                    modules.path.update_all_model_names()
+                    results = []
+                    results += [gr.update(choices=modules.path.model_filenames), gr.update(choices=['None'] + modules.path.model_filenames)]
+                    for i in range(5):
+                        results += [gr.update(choices=['None'] + modules.path.lora_filenames), gr.update()]
+                    return results
+                model_refresh.click(model_refresh_clicked, [], [base_model, refiner_model] + lora_ctrls)
     advanced_checkbox.change(lambda x: gr.update(visible=x), advanced_checkbox, right_col)
     ctrls = [
         prompt, negative_prompt, style_selction,
         performance_selction, aspect_ratios_selction, image_number, image_seed
     ]
+    ctrls += [base_model, refiner_model] + lora_ctrls
     run_button.click(fn=generate_clicked, inputs=ctrls, outputs=[gallery])
 block.launch(inbrowser=True)
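
model_refresh_clicked returns one gr.update(...) per output component, in the same order as the outputs list passed to .click(), which lets the dropdown choices be refreshed in place without rebuilding the UI. A stripped-down sketch of the same pattern with illustrative names (not the components used in webui.py):

    import gradio as gr

    with gr.Blocks() as demo:
        files = ['a.safetensors', 'b.safetensors']    # stand-in for modules.path.model_filenames
        model_dd = gr.Dropdown(label='Model', choices=files, value='a.safetensors')
        lora_dd = gr.Dropdown(label='LoRA', choices=['None'] + files, value='None')
        refresh = gr.Button(value='Refresh All Files')

        def on_refresh():
            rescanned = files + ['c.safetensors']     # pretend a rescan found a new file on disk
            # One update per component in `outputs`, in the same order.
            return [gr.update(choices=rescanned), gr.update(choices=['None'] + rescanned)]

        refresh.click(on_refresh, inputs=[], outputs=[model_dd, lora_dd])

    # demo.launch()
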