This commit is contained in:
lvmin 2023-08-10 08:45:30 -07:00
parent 8a452d4c7c
commit 746f8ef1f4
3 changed files with 34 additions and 7 deletions

1
.gitignore vendored
View File

@ -2,6 +2,7 @@ __pycache__
*.ckpt *.ckpt
*.safetensors *.safetensors
*.pth *.pth
!taesdxl_decoder.pth
/repositories /repositories
/venv /venv
/tmp /tmp

Binary file not shown.

View File

@ -1,14 +1,16 @@
import os
import random import random
import cv2
import einops
import torch import torch
import numpy as np import numpy as np
import comfy.model_management import comfy.model_management
import comfy.sample import comfy.sample
import comfy.utils import comfy.utils
import latent_preview
from comfy.sd import load_checkpoint_guess_config from comfy.sd import load_checkpoint_guess_config
from nodes import VAEDecode, EmptyLatentImage, CLIPTextEncode, common_ksampler from nodes import VAEDecode, EmptyLatentImage, CLIPTextEncode
opCLIPTextEncode = CLIPTextEncode() opCLIPTextEncode = CLIPTextEncode()
@ -45,6 +47,20 @@ def decode_vae(vae, latent_image):
return opVAEDecode.decode(samples=latent_image, vae=vae)[0] return opVAEDecode.decode(samples=latent_image, vae=vae)[0]
def get_previewer(device, latent_format):
from latent_preview import TAESD, TAESDPreviewerImpl
taesd_decoder_path = os.path.abspath(os.path.realpath(os.path.join("models", "vae_approx",
latent_format.taesd_decoder_name)))
if not os.path.exists(taesd_decoder_path):
print(f"Warning: TAESD previews enabled, but could not find {taesd_decoder_path}")
return None
taesd = TAESD(None, taesd_decoder_path).to(device)
return taesd
@torch.no_grad() @torch.no_grad()
def ksampler(model, positive, negative, latent, seed=None, steps=30, cfg=9.0, sampler_name='euler_ancestral', scheduler='normal', denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False): def ksampler(model, positive, negative, latent, seed=None, steps=30, cfg=9.0, sampler_name='euler_ancestral', scheduler='normal', denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False):
seed = seed if isinstance(seed, int) else random.randint(1, 2 ** 64) seed = seed if isinstance(seed, int) else random.randint(1, 2 ** 64)
@ -66,21 +82,31 @@ def ksampler(model, positive, negative, latent, seed=None, steps=30, cfg=9.0, sa
if preview_format not in ["JPEG", "PNG"]: if preview_format not in ["JPEG", "PNG"]:
preview_format = "JPEG" preview_format = "JPEG"
previewer = latent_preview.get_previewer(device, model.model.latent_format) previewer = get_previewer(device, model.model.latent_format)
pbar = comfy.utils.ProgressBar(steps) pbar = comfy.utils.ProgressBar(steps)
def callback(step, x0, x, total_steps): def callback(step, x0, x, total_steps):
preview_bytes = None if previewer and step % 3 == 0:
if previewer: with torch.no_grad():
preview_bytes = previewer.decode_latent_to_preview_image(preview_format, x0) x_sample = previewer.decoder(x0).detach() * 255.0
pbar.update_absolute(step + 1, total_steps, preview_bytes) x_sample = einops.rearrange(x_sample, 'b c h w -> b h w c')
x_sample = x_sample.cpu().numpy()[..., ::-1].copy().clip(0, 255).astype(np.uint8)
for i, s in enumerate(x_sample):
cv2.imshow(f'Preview {i}', s)
cv2.waitKey(1)
pbar.update_absolute(step + 1, total_steps, None)
samples = comfy.sample.sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, samples = comfy.sample.sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image,
denoise=denoise, disable_noise=disable_noise, start_step=start_step, last_step=last_step, denoise=denoise, disable_noise=disable_noise, start_step=start_step, last_step=last_step,
force_full_denoise=force_full_denoise, noise_mask=noise_mask, callback=callback, seed=seed) force_full_denoise=force_full_denoise, noise_mask=noise_mask, callback=callback, seed=seed)
out = latent.copy() out = latent.copy()
out["samples"] = samples out["samples"] = samples
if previewer:
cv2.destroyAllWindows()
return out return out