This commit is contained in:
parent
8a452d4c7c
commit
746f8ef1f4
|
|
@ -2,6 +2,7 @@ __pycache__
|
|||
*.ckpt
|
||||
*.safetensors
|
||||
*.pth
|
||||
!taesdxl_decoder.pth
|
||||
/repositories
|
||||
/venv
|
||||
/tmp
|
||||
|
|
|
|||
Binary file not shown.
|
|
@ -1,14 +1,16 @@
|
|||
import os
|
||||
import random
|
||||
import cv2
|
||||
import einops
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
import comfy.model_management
|
||||
import comfy.sample
|
||||
import comfy.utils
|
||||
import latent_preview
|
||||
|
||||
from comfy.sd import load_checkpoint_guess_config
|
||||
from nodes import VAEDecode, EmptyLatentImage, CLIPTextEncode, common_ksampler
|
||||
from nodes import VAEDecode, EmptyLatentImage, CLIPTextEncode
|
||||
|
||||
|
||||
opCLIPTextEncode = CLIPTextEncode()
|
||||
|
|
@ -45,6 +47,20 @@ def decode_vae(vae, latent_image):
|
|||
return opVAEDecode.decode(samples=latent_image, vae=vae)[0]
|
||||
|
||||
|
||||
def get_previewer(device, latent_format):
|
||||
from latent_preview import TAESD, TAESDPreviewerImpl
|
||||
taesd_decoder_path = os.path.abspath(os.path.realpath(os.path.join("models", "vae_approx",
|
||||
latent_format.taesd_decoder_name)))
|
||||
|
||||
if not os.path.exists(taesd_decoder_path):
|
||||
print(f"Warning: TAESD previews enabled, but could not find {taesd_decoder_path}")
|
||||
return None
|
||||
|
||||
taesd = TAESD(None, taesd_decoder_path).to(device)
|
||||
|
||||
return taesd
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def ksampler(model, positive, negative, latent, seed=None, steps=30, cfg=9.0, sampler_name='euler_ancestral', scheduler='normal', denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False):
|
||||
seed = seed if isinstance(seed, int) else random.randint(1, 2 ** 64)
|
||||
|
|
@ -66,21 +82,31 @@ def ksampler(model, positive, negative, latent, seed=None, steps=30, cfg=9.0, sa
|
|||
if preview_format not in ["JPEG", "PNG"]:
|
||||
preview_format = "JPEG"
|
||||
|
||||
previewer = latent_preview.get_previewer(device, model.model.latent_format)
|
||||
previewer = get_previewer(device, model.model.latent_format)
|
||||
|
||||
pbar = comfy.utils.ProgressBar(steps)
|
||||
|
||||
def callback(step, x0, x, total_steps):
|
||||
preview_bytes = None
|
||||
if previewer:
|
||||
preview_bytes = previewer.decode_latent_to_preview_image(preview_format, x0)
|
||||
pbar.update_absolute(step + 1, total_steps, preview_bytes)
|
||||
if previewer and step % 3 == 0:
|
||||
with torch.no_grad():
|
||||
x_sample = previewer.decoder(x0).detach() * 255.0
|
||||
x_sample = einops.rearrange(x_sample, 'b c h w -> b h w c')
|
||||
x_sample = x_sample.cpu().numpy()[..., ::-1].copy().clip(0, 255).astype(np.uint8)
|
||||
|
||||
for i, s in enumerate(x_sample):
|
||||
cv2.imshow(f'Preview {i}', s)
|
||||
cv2.waitKey(1)
|
||||
pbar.update_absolute(step + 1, total_steps, None)
|
||||
|
||||
samples = comfy.sample.sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image,
|
||||
denoise=denoise, disable_noise=disable_noise, start_step=start_step, last_step=last_step,
|
||||
force_full_denoise=force_full_denoise, noise_mask=noise_mask, callback=callback, seed=seed)
|
||||
out = latent.copy()
|
||||
out["samples"] = samples
|
||||
|
||||
if previewer:
|
||||
cv2.destroyAllWindows()
|
||||
|
||||
return out
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue