diff --git a/backend/headless/fcbh/clip_vision.py b/backend/headless/fcbh/clip_vision.py index 266a9375..b93b0da7 100644 --- a/backend/headless/fcbh/clip_vision.py +++ b/backend/headless/fcbh/clip_vision.py @@ -92,8 +92,11 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False): json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_g.json") elif "vision_model.encoder.layers.30.layer_norm1.weight" in sd: json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_h.json") - else: + elif "vision_model.encoder.layers.22.layer_norm1.weight" in sd: json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl.json") + else: + return None + clip = ClipVisionModel(json_config) m, u = clip.load_sd(sd) if len(m) > 0: diff --git a/backend/headless/fcbh/model_base.py b/backend/headless/fcbh/model_base.py index c5e1e399..f3f708f7 100644 --- a/backend/headless/fcbh/model_base.py +++ b/backend/headless/fcbh/model_base.py @@ -26,6 +26,7 @@ class BaseModel(torch.nn.Module): self.adm_channels = unet_config.get("adm_in_channels", None) if self.adm_channels is None: self.adm_channels = 0 + self.inpaint_model = False print("model_type", model_type.name) print("adm", self.adm_channels) @@ -71,6 +72,38 @@ class BaseModel(torch.nn.Module): def encode_adm(self, **kwargs): return None + def cond_concat(self, **kwargs): + if self.inpaint_model: + concat_keys = ("mask", "masked_image") + cond_concat = [] + denoise_mask = kwargs.get("denoise_mask", None) + latent_image = kwargs.get("latent_image", None) + noise = kwargs.get("noise", None) + device = kwargs["device"] + + def blank_inpaint_image_like(latent_image): + blank_image = torch.ones_like(latent_image) + # these are the values for "zero" in pixel space translated to latent space + blank_image[:,0] *= 0.8223 + blank_image[:,1] *= -0.6876 + blank_image[:,2] *= 0.6364 + blank_image[:,3] *= 0.1380 + return blank_image + + for ck in concat_keys: + if denoise_mask is not None: + if ck == "mask": + cond_concat.append(denoise_mask[:,:1].to(device)) + elif ck == "masked_image": + cond_concat.append(latent_image.to(device)) #NOTE: the latent_image should be masked by the mask in pixel space + else: + if ck == "mask": + cond_concat.append(torch.ones_like(noise)[:,:1]) + elif ck == "masked_image": + cond_concat.append(blank_inpaint_image_like(noise)) + return cond_concat + return None + def load_model_weights(self, sd, unet_prefix=""): to_load = {} keys = list(sd.keys()) @@ -112,7 +145,7 @@ class BaseModel(torch.nn.Module): return {**unet_state_dict, **vae_state_dict, **clip_state_dict} def set_inpaint(self): - self.concat_keys = ("mask", "masked_image") + self.inpaint_model = True def unclip_adm(unclip_conditioning, device, noise_augmentor, noise_augment_merge=0.0): adm_inputs = [] diff --git a/backend/headless/fcbh/samplers.py b/backend/headless/fcbh/samplers.py index 7fe947ea..c88a3cc1 100644 --- a/backend/headless/fcbh/samplers.py +++ b/backend/headless/fcbh/samplers.py @@ -14,8 +14,8 @@ def lcm(a, b): #TODO: eventually replace by math.lcm (added in python3.9) #The main sampling function shared by all the samplers #Returns predicted noise -def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, cond_concat=None, model_options={}, seed=None): - def get_area_and_mult(cond, x_in, cond_concat_in, timestep_in): +def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, model_options={}, seed=None): + def get_area_and_mult(cond, x_in, timestep_in): area = (x_in.shape[2], x_in.shape[3], 0, 0) strength = 1.0 if 'timestep_start' in cond[1]: @@ -68,12 +68,15 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con conditionning = {} conditionning['c_crossattn'] = cond[0] - if cond_concat_in is not None and len(cond_concat_in) > 0: - cropped = [] - for x in cond_concat_in: - cr = x[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]] - cropped.append(cr) - conditionning['c_concat'] = torch.cat(cropped, dim=1) + + if 'concat' in cond[1]: + cond_concat_in = cond[1]['concat'] + if cond_concat_in is not None and len(cond_concat_in) > 0: + cropped = [] + for x in cond_concat_in: + cr = x[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]] + cropped.append(cr) + conditionning['c_concat'] = torch.cat(cropped, dim=1) if adm_cond is not None: conditionning['c_adm'] = adm_cond @@ -173,7 +176,7 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con out['c_adm'] = torch.cat(c_adm) return out - def calc_cond_uncond_batch(model_function, cond, uncond, x_in, timestep, max_total_area, cond_concat_in, model_options): + def calc_cond_uncond_batch(model_function, cond, uncond, x_in, timestep, max_total_area, model_options): out_cond = torch.zeros_like(x_in) out_count = torch.ones_like(x_in)/100000.0 @@ -185,14 +188,14 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con to_run = [] for x in cond: - p = get_area_and_mult(x, x_in, cond_concat_in, timestep) + p = get_area_and_mult(x, x_in, timestep) if p is None: continue to_run += [(p, COND)] if uncond is not None: for x in uncond: - p = get_area_and_mult(x, x_in, cond_concat_in, timestep) + p = get_area_and_mult(x, x_in, timestep) if p is None: continue @@ -286,7 +289,7 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con if math.isclose(cond_scale, 1.0): uncond = None - cond, uncond = calc_cond_uncond_batch(model_function, cond, uncond, x, timestep, max_total_area, cond_concat, model_options) + cond, uncond = calc_cond_uncond_batch(model_function, cond, uncond, x, timestep, max_total_area, model_options) if "sampler_cfg_function" in model_options: args = {"cond": cond, "uncond": uncond, "cond_scale": cond_scale, "timestep": timestep} return model_options["sampler_cfg_function"](args) @@ -307,8 +310,8 @@ class CFGNoisePredictor(torch.nn.Module): super().__init__() self.inner_model = model self.alphas_cumprod = model.alphas_cumprod - def apply_model(self, x, timestep, cond, uncond, cond_scale, cond_concat=None, model_options={}, seed=None): - out = sampling_function(self.inner_model.apply_model, x, timestep, uncond, cond, cond_scale, cond_concat, model_options=model_options, seed=seed) + def apply_model(self, x, timestep, cond, uncond, cond_scale, model_options={}, seed=None): + out = sampling_function(self.inner_model.apply_model, x, timestep, uncond, cond, cond_scale, model_options=model_options, seed=seed) return out @@ -316,11 +319,11 @@ class KSamplerX0Inpaint(torch.nn.Module): def __init__(self, model): super().__init__() self.inner_model = model - def forward(self, x, sigma, uncond, cond, cond_scale, denoise_mask, cond_concat=None, model_options={}, seed=None): + def forward(self, x, sigma, uncond, cond, cond_scale, denoise_mask, model_options={}, seed=None): if denoise_mask is not None: latent_mask = 1. - denoise_mask x = x * denoise_mask + (self.latent_image + self.noise * sigma.reshape([sigma.shape[0]] + [1] * (len(self.noise.shape) - 1))) * latent_mask - out = self.inner_model(x, sigma, cond=cond, uncond=uncond, cond_scale=cond_scale, cond_concat=cond_concat, model_options=model_options, seed=seed) + out = self.inner_model(x, sigma, cond=cond, uncond=uncond, cond_scale=cond_scale, model_options=model_options, seed=seed) if denoise_mask is not None: out *= denoise_mask @@ -358,15 +361,6 @@ def sgm_scheduler(model, steps): sigs += [0.0] return torch.FloatTensor(sigs) -def blank_inpaint_image_like(latent_image): - blank_image = torch.ones_like(latent_image) - # these are the values for "zero" in pixel space translated to latent space - blank_image[:,0] *= 0.8223 - blank_image[:,1] *= -0.6876 - blank_image[:,2] *= 0.6364 - blank_image[:,3] *= 0.1380 - return blank_image - def get_mask_aabb(masks): if masks.numel() == 0: return torch.zeros((0, 4), device=masks.device, dtype=torch.int) @@ -543,6 +537,20 @@ def encode_adm(model, conds, batch_size, width, height, device, prompt_type): return conds +def encode_cond(model_function, key, conds, device, **kwargs): + for t in range(len(conds)): + x = conds[t] + params = x[1].copy() + params["device"] = device + for k in kwargs: + if k not in params: + params[k] = kwargs[k] + + out = model_function(**params) + if out is not None: + x[1] = x[1].copy() + x[1][key] = out + return conds class Sampler: def sample(self): @@ -662,31 +670,19 @@ def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model apply_empty_x_to_equal_area(list(filter(lambda c: c[1].get('control_apply_to_uncond', False) == True, positive)), negative, 'control', lambda cond_cnets, x: cond_cnets[x]) apply_empty_x_to_equal_area(positive, negative, 'gligen', lambda cond_cnets, x: cond_cnets[x]) + if latent_image is not None: + latent_image = model.process_latent_in(latent_image) + if model.is_adm(): positive = encode_adm(model, positive, noise.shape[0], noise.shape[3], noise.shape[2], device, "positive") negative = encode_adm(model, negative, noise.shape[0], noise.shape[3], noise.shape[2], device, "negative") - if latent_image is not None: - latent_image = model.process_latent_in(latent_image) + if hasattr(model, 'cond_concat'): + positive = encode_cond(model.cond_concat, "concat", positive, device, noise=noise, latent_image=latent_image, denoise_mask=denoise_mask) + negative = encode_cond(model.cond_concat, "concat", negative, device, noise=noise, latent_image=latent_image, denoise_mask=denoise_mask) extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg, "model_options": model_options, "seed":seed} - cond_concat = None - if hasattr(model, 'concat_keys'): #inpaint - cond_concat = [] - for ck in model.concat_keys: - if denoise_mask is not None: - if ck == "mask": - cond_concat.append(denoise_mask[:,:1]) - elif ck == "masked_image": - cond_concat.append(latent_image) #NOTE: the latent_image should be masked by the mask in pixel space - else: - if ck == "mask": - cond_concat.append(torch.ones_like(noise)[:,:1]) - elif ck == "masked_image": - cond_concat.append(blank_inpaint_image_like(noise)) - extra_args["cond_concat"] = cond_concat - samples = sampler.sample(model_wrap, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar) return model.process_latent_out(samples.to(torch.float32)) diff --git a/backend/headless/fcbh/sd.py b/backend/headless/fcbh/sd.py index 9b646140..5f1f0c6b 100644 --- a/backend/headless/fcbh/sd.py +++ b/backend/headless/fcbh/sd.py @@ -434,10 +434,11 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o if output_clip: w = WeightsLoader() clip_target = model_config.clip_target() - clip = CLIP(clip_target, embedding_directory=embedding_directory) - w.cond_stage_model = clip.cond_stage_model - sd = model_config.process_clip_state_dict(sd) - load_model_weights(w, sd) + if clip_target is not None: + clip = CLIP(clip_target, embedding_directory=embedding_directory) + w.cond_stage_model = clip.cond_stage_model + sd = model_config.process_clip_state_dict(sd) + load_model_weights(w, sd) left_over = sd.keys() if len(left_over) > 0: diff --git a/fooocus_version.py b/fooocus_version.py index f9d6ec5f..a4dcdf49 100644 --- a/fooocus_version.py +++ b/fooocus_version.py @@ -1 +1 @@ -version = '2.1.705' +version = '2.1.707' diff --git a/modules/sample_hijack.py b/modules/sample_hijack.py index d4359e99..ed184dd6 100644 --- a/modules/sample_hijack.py +++ b/modules/sample_hijack.py @@ -6,7 +6,7 @@ from fcbh.model_base import SDXLRefiner, SDXL from fcbh.sample import get_additional_models, get_models_from_cond, cleanup_additional_models from fcbh.samplers import resolve_areas_and_cond_masks, wrap_model, calculate_start_end_timesteps, \ create_cond_with_same_area_if_none, pre_run_control, apply_empty_x_to_equal_area, encode_adm, \ - blank_inpaint_image_like + encode_cond current_refiner = None @@ -77,10 +77,19 @@ def sample_hacked(model, noise, positive, negative, cfg, device, sampler, sigmas apply_empty_x_to_equal_area(list(filter(lambda c: c[1].get('control_apply_to_uncond', False) == True, positive)), negative, 'control', lambda cond_cnets, x: cond_cnets[x]) apply_empty_x_to_equal_area(positive, negative, 'gligen', lambda cond_cnets, x: cond_cnets[x]) + if latent_image is not None: + latent_image = model.process_latent_in(latent_image) + if model.is_adm(): positive = encode_adm(model, positive, noise.shape[0], noise.shape[3], noise.shape[2], device, "positive") negative = encode_adm(model, negative, noise.shape[0], noise.shape[3], noise.shape[2], device, "negative") + if hasattr(model, 'cond_concat'): + positive = encode_cond(model.cond_concat, "concat", positive, device, noise=noise, latent_image=latent_image, denoise_mask=denoise_mask) + negative = encode_cond(model.cond_concat, "concat", negative, device, noise=noise, latent_image=latent_image, denoise_mask=denoise_mask) + + extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg, "model_options": model_options, "seed":seed} + if current_refiner is not None and current_refiner.model.is_adm(): positive_refiner = clip_separate(positive, target_model=current_refiner.model) negative_refiner = clip_separate(negative, target_model=current_refiner.model) @@ -91,27 +100,6 @@ def sample_hacked(model, noise, positive, negative, cfg, device, sampler, sigmas positive_refiner[0][1]['adm_encoded'].to(positive[0][1]['adm_encoded']) negative_refiner[0][1]['adm_encoded'].to(negative[0][1]['adm_encoded']) - if latent_image is not None: - latent_image = model.process_latent_in(latent_image) - - extra_args = {"cond": positive, "uncond": negative, "cond_scale": cfg, "model_options": model_options, "seed": seed} - - cond_concat = None - if hasattr(model, 'concat_keys'): # inpaint - cond_concat = [] - for ck in model.concat_keys: - if denoise_mask is not None: - if ck == "mask": - cond_concat.append(denoise_mask[:,:1]) - elif ck == "masked_image": - cond_concat.append(latent_image) #NOTE: the latent_image should be masked by the mask in pixel space - else: - if ck == "mask": - cond_concat.append(torch.ones_like(noise)[:, :1]) - elif ck == "masked_image": - cond_concat.append(blank_inpaint_image_like(noise)) - extra_args["cond_concat"] = cond_concat - def refiner_switch(): cleanup_additional_models(set(get_models_from_cond(positive, "control") + get_models_from_cond(negative, "control")))