diff --git a/.gitignore b/.gitignore
index fa3d9203..9c054065 100644
--- a/.gitignore
+++ b/.gitignore
@@ -2,6 +2,7 @@ __pycache__
 *.ckpt
 *.safetensors
 *.pth
+!taesdxl_decoder.pth
 /repositories
 /venv
 /tmp
diff --git a/models/vae_approx/taesdxl_decoder.pth b/models/vae_approx/taesdxl_decoder.pth
new file mode 100644
index 00000000..f2b34452
Binary files /dev/null and b/models/vae_approx/taesdxl_decoder.pth differ
diff --git a/modules/core.py b/modules/core.py
index e382be86..b7b6f31d 100644
--- a/modules/core.py
+++ b/modules/core.py
@@ -1,14 +1,16 @@
+import os
 import random
+import cv2
+import einops
 import torch
 import numpy as np
 
 import comfy.model_management
 import comfy.sample
 import comfy.utils
-import latent_preview
 
 from comfy.sd import load_checkpoint_guess_config
-from nodes import VAEDecode, EmptyLatentImage, CLIPTextEncode, common_ksampler
+from nodes import VAEDecode, EmptyLatentImage, CLIPTextEncode
 
 
 opCLIPTextEncode = CLIPTextEncode()
@@ -45,6 +47,20 @@ def decode_vae(vae, latent_image):
     return opVAEDecode.decode(samples=latent_image, vae=vae)[0]
 
 
+def get_previewer(device, latent_format):
+    from latent_preview import TAESD, TAESDPreviewerImpl
+    taesd_decoder_path = os.path.abspath(os.path.realpath(os.path.join("models", "vae_approx",
+                                                                       latent_format.taesd_decoder_name)))
+
+    if not os.path.exists(taesd_decoder_path):
+        print(f"Warning: TAESD previews enabled, but could not find {taesd_decoder_path}")
+        return None
+
+    taesd = TAESD(None, taesd_decoder_path).to(device)
+
+    return taesd
+
+
 @torch.no_grad()
 def ksampler(model, positive, negative, latent, seed=None, steps=30, cfg=9.0, sampler_name='euler_ancestral', scheduler='normal', denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False):
     seed = seed if isinstance(seed, int) else random.randint(1, 2 ** 64)
@@ -66,21 +82,31 @@ def ksampler(model, positive, negative, latent, seed=None, steps=30, cfg=9.0, sa
     if preview_format not in ["JPEG", "PNG"]:
         preview_format = "JPEG"
 
-    previewer = latent_preview.get_previewer(device, model.model.latent_format)
+    previewer = get_previewer(device, model.model.latent_format)
 
     pbar = comfy.utils.ProgressBar(steps)
 
     def callback(step, x0, x, total_steps):
-        preview_bytes = None
-        if previewer:
-            preview_bytes = previewer.decode_latent_to_preview_image(preview_format, x0)
-        pbar.update_absolute(step + 1, total_steps, preview_bytes)
+        if previewer and step % 3 == 0:
+            with torch.no_grad():
+                x_sample = previewer.decoder(x0).detach() * 255.0
+                x_sample = einops.rearrange(x_sample, 'b c h w -> b h w c')
+                x_sample = x_sample.cpu().numpy()[..., ::-1].copy().clip(0, 255).astype(np.uint8)
+
+                for i, s in enumerate(x_sample):
+                    cv2.imshow(f'Preview {i}', s)
+                    cv2.waitKey(1)
+        pbar.update_absolute(step + 1, total_steps, None)
 
     samples = comfy.sample.sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image,
                                   denoise=denoise, disable_noise=disable_noise, start_step=start_step,
                                   last_step=last_step, force_full_denoise=force_full_denoise, noise_mask=noise_mask,
                                   callback=callback, seed=seed)
 
     out = latent.copy()
     out["samples"] = samples
+
+    if previewer:
+        cv2.destroyAllWindows()
+
     return out
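
Note: the preview path added above decodes the current latent with the TAESD decoder every third sampling step and shows it in an OpenCV window, replacing latent_preview's JPEG/PNG preview bytes. The snippet below is a minimal, self-contained sketch of that decode-and-display flow, not part of the commit: DummyDecoder and show_preview are hypothetical names, and DummyDecoder stands in for the real TAESD decoder (normally loaded from models/vae_approx/taesdxl_decoder.pth) so the sketch runs without the checkpoint.

# Sketch of the callback's preview path. DummyDecoder is a stand-in for the
# TAESD decoder; the tensor handling mirrors the diff above.
import cv2
import einops
import numpy as np
import torch


class DummyDecoder(torch.nn.Module):
    """Stand-in decoder: maps a 4-channel latent to a 3-channel image in [0, 1]
    at 8x spatial resolution, matching the shape contract TAESD fulfills."""

    def __init__(self):
        super().__init__()
        self.up = torch.nn.Upsample(scale_factor=8, mode='nearest')
        self.proj = torch.nn.Conv2d(4, 3, kernel_size=1)

    def forward(self, x):
        return torch.sigmoid(self.proj(self.up(x)))


@torch.no_grad()
def show_preview(decoder, x0):
    # Decode the latent and scale from [0, 1] to [0, 255].
    x_sample = decoder(x0).detach() * 255.0
    # Move channels last: (batch, channel, height, width) -> (batch, height, width, channel).
    x_sample = einops.rearrange(x_sample, 'b c h w -> b h w c')
    # RGB -> BGR for OpenCV, then clamp and convert to uint8.
    x_sample = x_sample.cpu().numpy()[..., ::-1].copy().clip(0, 255).astype(np.uint8)
    for i, s in enumerate(x_sample):
        cv2.imshow(f'Preview {i}', s)
    cv2.waitKey(1)  # let the window repaint without blocking the sampler


if __name__ == '__main__':
    decoder = DummyDecoder()
    latent = torch.randn(1, 4, 64, 64)  # 4-channel latent at 1/8 scale of a 512x512 image
    show_preview(decoder, latent)
    cv2.waitKey(1000)  # keep the window up briefly before closing
    cv2.destroyAllWindows()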