feat: extract remaining attributes, do not use globals in patch

Manuel Schmid 2024-01-22 21:13:44 +01:00
parent f3222b0f27
commit 177075ff7b
3 changed files with 63 additions and 57 deletions
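
This change replaces the module-level globals in modules.patch (sharpness, adm_scaler_end, positive_adm_scale, negative_adm_scale, controlnet_softness, adaptive_cfg, global_diffusion_progress, eps_record) with a PatchSettings class plus a patch_settings dict keyed by process ID: the worker registers its settings once under os.getpid(), the patched functions look them up per call, and the worker's finally block removes the entry again. A minimal, self-contained sketch of the pattern (worker_setup and patched_function are illustrative names, and the stand-in class is trimmed to two fields; the full class and the real wiring are in the diff below):

import os

# Stand-ins mirroring the new modules.patch API, trimmed for illustration;
# in Fooocus these come from: from modules.patch import PatchSettings, patch_settings
class PatchSettings:
    def __init__(self, sharpness=2.0, adaptive_cfg=7.0):
        self.sharpness = sharpness
        self.adaptive_cfg = adaptive_cfg

patch_settings = {}  # one PatchSettings per worker process, keyed by PID

def worker_setup(sharpness, adaptive_cfg):
    # illustrative: the worker registers its settings once instead of assigning module globals
    pid = os.getpid()
    patch_settings[pid] = PatchSettings(sharpness=sharpness, adaptive_cfg=adaptive_cfg)
    return pid

def patched_function(cfg_scale):
    # illustrative: patched functions resolve the calling process's settings at call time
    settings = patch_settings[os.getpid()]
    return min(float(cfg_scale), float(settings.adaptive_cfg))

pid = worker_setup(sharpness=4.0, adaptive_cfg=7.0)
try:
    print(patched_function(9.0))  # -> 7.0, capped by this process's adaptive_cfg
finally:
    # mirrors the worker's new finally block so stale PID entries do not accumulate
    patch_settings.pop(pid, None)

Keying the dict on os.getpid() lets each worker process carry its own settings instead of sharing one mutable set of module globals.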

View File

@@ -1,5 +1,6 @@
import threading
import os
+from modules.patch import PatchSettings, patch_settings
class AsyncTask:
def __init__(self, args):
@@ -42,6 +43,9 @@ def worker():
get_image_shape_ceil, set_image_shape_ceil, get_shape_ceil, resample_image, erode_or_dilate
from modules.upscaler import perform_upscale
+pid = os.getpid()
+print(f'Started worker with PID {pid}')
try:
async_gradio_app = shared.gradio_root
flag = f'''App started successful. Use the app with {str(async_gradio_app.local_url)} or {str(async_gradio_app.server_name)}:{str(async_gradio_app.server_port)}'''
@@ -227,22 +231,22 @@ def worker():
adm_scaler_end = 0.0
steps = 8
-modules.patch.adaptive_cfg = adaptive_cfg
-print(f'[Parameters] Adaptive CFG = {modules.patch.adaptive_cfg}')
-modules.patch.sharpness = sharpness
-print(f'[Parameters] Sharpness = {modules.patch.sharpness}')
-modules.patch.controlnet_softness = controlnet_softness
-print(f'[Parameters] ControlNet Softness = {modules.patch.controlnet_softness}')
-modules.patch.positive_adm_scale = adm_scaler_positive
-modules.patch.negative_adm_scale = adm_scaler_negative
-modules.patch.adm_scaler_end = adm_scaler_end
+print(f'[Parameters] Adaptive CFG = {adaptive_cfg}')
+print(f'[Parameters] Sharpness = {sharpness}')
+print(f'[Parameters] ControlNet Softness = {controlnet_softness}')
print(f'[Parameters] ADM Scale = '
-f'{modules.patch.positive_adm_scale} : '
-f'{modules.patch.negative_adm_scale} : '
-f'{modules.patch.adm_scaler_end}')
+f'{adm_scaler_positive} : '
+f'{adm_scaler_negative} : '
+f'{adm_scaler_end}')
+patch_settings[pid] = PatchSettings(
+sharpness,
+adm_scaler_end,
+adm_scaler_positive,
+adm_scaler_negative,
+controlnet_softness,
+adaptive_cfg
+)
cfg_scale = float(guidance_scale)
print(f'[Parameters] CFG = {cfg_scale}')
@@ -815,9 +819,9 @@ def worker():
('Sharpness', sharpness),
('Guidance Scale', guidance_scale),
('ADM Guidance', str((
-modules.patch.positive_adm_scale,
-modules.patch.negative_adm_scale,
-modules.patch.adm_scaler_end))),
+modules.patch.patch_settings[pid].positive_adm_scale,
+modules.patch.patch_settings[pid].negative_adm_scale,
+modules.patch.patch_settings[pid].adm_scaler_end))),
('Base Model', base_model_name),
('Refiner Model', refiner_model_name),
('Refiner Switch', refiner_switch),
@@ -860,6 +864,9 @@ def worker():
except:
traceback.print_exc()
task.yields.append(['finish', task.results])
+finally:
+if pid in modules.patch.patch_settings:
+del modules.patch.patch_settings[os.getpid()]
pass

View File

@@ -425,7 +425,7 @@ def process_diffusion(positive_cond, negative_cond, steps, switch, width, height
decoded_latent = core.decode_vae(vae=target_model, latent_image=sampled_latent, tiled=tiled)
if refiner_swap_method == 'vae':
-modules.patch.eps_record = 'vae'
+modules.patch.patch_settings[os.getpid()].eps_record = 'vae'
if modules.inpaint_worker.current_task is not None:
modules.inpaint_worker.current_task.unswap()
@@ -463,7 +463,7 @@ def process_diffusion(positive_cond, negative_cond, steps, switch, width, height
denoise=denoise)[switch:] * k_sigmas
len_sigmas = len(sigmas) - 1
-noise_mean = torch.mean(modules.patch.eps_record, dim=1, keepdim=True)
+noise_mean = torch.mean(modules.patch.patch_settings[os.getpid()].eps_record, dim=1, keepdim=True)
if modules.inpaint_worker.current_task is not None:
modules.inpaint_worker.current_task.swap()
@@ -493,5 +493,5 @@ def process_diffusion(positive_cond, negative_cond, steps, switch, width, height
decoded_latent = core.decode_vae(vae=target_model, latent_image=sampled_latent, tiled=tiled)
images = core.pytorch_to_numpy(decoded_latent)
-modules.patch.eps_record = None
+modules.patch.patch_settings[os.getpid()].eps_record = None
return images

View File

@@ -27,20 +27,27 @@ from ldm_patched.ldm.modules.diffusionmodules.openaimodel import forward_timeste
from modules.patch_precision import patch_all_precision
from modules.patch_clip import patch_all_clip
-# TODO make these parameters dynamic:
-# TODO sharpness, adm_scaler_end, positive_adm_scale, negative_adm_scale, adaptive_cfg + controlnet_softness
-sharpness = 2.0
+class PatchSettings:
+def __init__(self,
+sharpness=2.0,
+adm_scaler_end=0.3,
+positive_adm_scale=1.5,
+negative_adm_scale=0.8,
+controlnet_softness=0.25,
+adaptive_cfg=7.0):
+self.sharpness = sharpness
+self.adm_scaler_end = adm_scaler_end
+self.positive_adm_scale = positive_adm_scale
+self.negative_adm_scale = negative_adm_scale
+self.controlnet_softness = controlnet_softness
+self.adaptive_cfg = adaptive_cfg
+self.global_diffusion_progress = 0
+self.eps_record = None
-adm_scaler_end = 0.3
-positive_adm_scale = 1.5
-negative_adm_scale = 0.8
-controlnet_softness = 0.25
+patch_settings = {}
-adaptive_cfg = 7.0
-global_diffusion_progress = 0
-eps_record = None
def calculate_weight_patched(self, patches, weight, key):
for p in patches:
@@ -203,14 +210,12 @@ class BrownianTreeNoiseSamplerPatched:
def compute_cfg(uncond, cond, cfg_scale, t):
-global adaptive_cfg
-mimic_cfg = float(adaptive_cfg)
+mimic_cfg = float(patch_settings[os.getpid()].adaptive_cfg)
real_cfg = float(cfg_scale)
real_eps = uncond + real_cfg * (cond - uncond)
-if cfg_scale > adaptive_cfg:
+if cfg_scale > patch_settings[os.getpid()].adaptive_cfg:
mimicked_eps = uncond + mimic_cfg * (cond - uncond)
return real_eps * t + mimicked_eps * (1 - t)
else:
@@ -218,13 +223,11 @@ def compute_cfg(uncond, cond, cfg_scale, t):
def patched_sampling_function(model, x, timestep, uncond, cond, cond_scale, model_options=None, seed=None):
-global eps_record
if math.isclose(cond_scale, 1.0) and not model_options.get("disable_cfg1_optimization", False):
final_x0 = calc_cond_uncond_batch(model, cond, None, x, timestep, model_options)[0]
-if eps_record is not None:
-eps_record = ((x - final_x0) / timestep).cpu()
+if patch_settings[os.getpid()].eps_record is not None:
+patch_settings[os.getpid()].eps_record = ((x - final_x0) / timestep).cpu()
return final_x0
@@ -233,16 +236,16 @@ def patched_sampling_function(model, x, timestep, uncond, cond, cond_scale, mode
positive_eps = x - positive_x0
negative_eps = x - negative_x0
-alpha = 0.001 * sharpness * global_diffusion_progress
+alpha = 0.001 * patch_settings[os.getpid()].sharpness * patch_settings[os.getpid()].global_diffusion_progress
positive_eps_degraded = anisotropic.adaptive_anisotropic_filter(x=positive_eps, g=positive_x0)
positive_eps_degraded_weighted = positive_eps_degraded * alpha + positive_eps * (1.0 - alpha)
final_eps = compute_cfg(uncond=negative_eps, cond=positive_eps_degraded_weighted,
-cfg_scale=cond_scale, t=global_diffusion_progress)
+cfg_scale=cond_scale, t=patch_settings[os.getpid()].global_diffusion_progress)
-if eps_record is not None:
-eps_record = (final_eps / timestep).cpu()
+if patch_settings[os.getpid()].eps_record is not None:
+patch_settings[os.getpid()].eps_record = (final_eps / timestep).cpu()
return x - final_eps
@@ -257,8 +260,6 @@ def round_to_64(x):
def sdxl_encode_adm_patched(self, **kwargs):
-global positive_adm_scale, negative_adm_scale
clip_pooled = ldm_patched.modules.model_base.sdxl_pooled(kwargs, self.noise_augmentor)
width = kwargs.get("width", 1024)
height = kwargs.get("height", 1024)
@@ -266,11 +267,11 @@ def sdxl_encode_adm_patched(self, **kwargs):
target_height = height
if kwargs.get("prompt_type", "") == "negative":
-width = float(width) * negative_adm_scale
-height = float(height) * negative_adm_scale
+width = float(width) * patch_settings[os.getpid()].negative_adm_scale
+height = float(height) * patch_settings[os.getpid()].negative_adm_scale
elif kwargs.get("prompt_type", "") == "positive":
-width = float(width) * positive_adm_scale
-height = float(height) * positive_adm_scale
+width = float(width) * patch_settings[os.getpid()].positive_adm_scale
+height = float(height) * patch_settings[os.getpid()].positive_adm_scale
def embedder(number_list):
h = self.embedder(torch.tensor(number_list, dtype=torch.float32))
@@ -324,7 +325,7 @@ def patched_KSamplerX0Inpaint_forward(self, x, sigma, uncond, cond, cond_scale,
def timed_adm(y, timesteps):
if isinstance(y, torch.Tensor) and int(y.dim()) == 2 and int(y.shape[1]) == 5632:
-y_mask = (timesteps > 999.0 * (1.0 - float(adm_scaler_end))).to(y)[..., None]
+y_mask = (timesteps > 999.0 * (1.0 - float(patch_settings[os.getpid()].adm_scaler_end))).to(y)[..., None]
y_with_adm = y[..., :2816].clone()
y_without_adm = y[..., 2816:].clone()
return y_with_adm * y_mask + y_without_adm * (1.0 - y_mask)
@@ -359,19 +360,17 @@ def patched_cldm_forward(self, x, hint, timesteps, context, y=None, **kwargs):
h = self.middle_block(h, emb, context)
outs.append(self.middle_block_out(h, emb, context))
-if controlnet_softness > 0:
+if patch_settings[os.getpid()].controlnet_softness > 0:
for i in range(10):
k = 1.0 - float(i) / 9.0
-outs[i] = outs[i] * (1.0 - controlnet_softness * k)
+outs[i] = outs[i] * (1.0 - patch_settings[os.getpid()].controlnet_softness * k)
return outs
def patched_unet_forward(self, x, timesteps=None, context=None, y=None, control=None, transformer_options={}, **kwargs):
-global global_diffusion_progress
self.current_step = 1.0 - timesteps.to(x) / 999.0
-global_diffusion_progress = float(self.current_step.detach().cpu().numpy().tolist()[0])
+patch_settings[os.getpid()].global_diffusion_progress = float(self.current_step.detach().cpu().numpy().tolist()[0])
y = timed_adm(y, timesteps)
@@ -485,7 +484,7 @@ def patch_all():
if ldm_patched.modules.model_management.directml_enabled:
ldm_patched.modules.model_management.lowvram_available = True
ldm_patched.modules.model_management.OOM_EXCEPTION = Exception
patch_all_precision()
patch_all_clip()
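
For context on the adm_scaler_end lookup in the timed_adm hunk above: the mask keeps the ADM-augmented channels only while the timestep is above 999 * (1 - adm_scaler_end), i.e. for roughly the first adm_scaler_end fraction of sampling. A standalone check of that cutoff (tensor values are made up for illustration, and the per-PID lookup is replaced by a local constant):

import torch

adm_scaler_end = 0.3                     # default from PatchSettings
timesteps = torch.tensor([999.0, 800.0, 699.0, 100.0])
y = torch.ones(4, 5632)                  # dummy ADM conditioning batch

# same expression as timed_adm, with a local value instead of patch_settings[os.getpid()]
y_mask = (timesteps > 999.0 * (1.0 - float(adm_scaler_end))).to(y)[..., None]

print(y_mask.squeeze(-1))                # tensor([1., 1., 0., 0.]): ADM guidance stays active only above ~699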