diff --git a/css/style.css b/css/style.css index 09ce6cf0..b87b20a7 100644 --- a/css/style.css +++ b/css/style.css @@ -94,6 +94,10 @@ overflow:inherit !important; } +.gradio-container{ + overflow: visible; +} + /* fullpage image viewer */ #lightboxModal{ diff --git a/extras/ip_adapter.py b/extras/ip_adapter.py index a145b68d..cb1d366f 100644 --- a/extras/ip_adapter.py +++ b/extras/ip_adapter.py @@ -167,21 +167,19 @@ def preprocess(img, ip_adapter_path): ldm_patched.modules.model_management.load_model_gpu(clip_vision.patcher) pixel_values = clip_preprocess(numpy_to_pytorch(img).to(clip_vision.load_device)) - - if clip_vision.dtype != torch.float32: - precision_scope = torch.autocast - else: - precision_scope = lambda a, b: contextlib.nullcontext(a) - - with precision_scope(ldm_patched.modules.model_management.get_autocast_device(clip_vision.load_device), torch.float32): - outputs = clip_vision.model(pixel_values=pixel_values, intermediate_output=-2) + outputs = clip_vision.model(pixel_values=pixel_values, output_hidden_states=True) ip_adapter = entry['ip_adapter'] ip_layers = entry['ip_layers'] image_proj_model = entry['image_proj_model'] ip_unconds = entry['ip_unconds'] - cond = outputs[1].to(device=ip_adapter.load_device, dtype=ip_adapter.dtype) + if ip_adapter.plus: + cond = outputs.hidden_states[-2] + else: + cond = outputs.image_embeds + + cond = cond.to(device=ip_adapter.load_device, dtype=ip_adapter.dtype) ldm_patched.modules.model_management.load_model_gpu(image_proj_model) cond = image_proj_model.model(cond).to(device=ip_adapter.load_device, dtype=ip_adapter.dtype) diff --git a/fooocus_version.py b/fooocus_version.py index efcfe020..2511cfc7 100644 --- a/fooocus_version.py +++ b/fooocus_version.py @@ -1 +1 @@ -version = '2.1.839' +version = '2.1.855' diff --git a/javascript/zoom.js b/javascript/zoom.js index e3fdcfb7..450a0347 100644 --- a/javascript/zoom.js +++ b/javascript/zoom.js @@ -1,18 +1,5 @@ onUiLoaded(async() => { // Helper functions - // Get active tab - - /** - * Waits for an element to be present in the DOM. - */ - const waitForElement = (id) => new Promise(resolve => { - const checkForElement = () => { - const element = document.querySelector(id); - if (element) return resolve(element); - setTimeout(checkForElement, 100); - }; - checkForElement(); - }); // Detect whether the element has a horizontal scroll bar function hasHorizontalScrollbar(element) { @@ -33,140 +20,40 @@ onUiLoaded(async() => { } } - // Check if hotkey is valid - function isValidHotkey(value) { - const specialKeys = ["Ctrl", "Alt", "Shift", "Disable"]; - return ( - (typeof value === "string" && - value.length === 1 && - /[a-z]/i.test(value)) || - specialKeys.includes(value) - ); - } - - // Normalize hotkey - function normalizeHotkey(hotkey) { - return hotkey.length === 1 ? "Key" + hotkey.toUpperCase() : hotkey; - } - - // Format hotkey for display - function formatHotkeyForDisplay(hotkey) { - return hotkey.startsWith("Key") ? 
hotkey.slice(3) : hotkey; - } - // Create hotkey configuration with the provided options function createHotkeyConfig(defaultHotkeysConfig) { const result = {}; // Resulting hotkey configuration - for (const key in defaultHotkeysConfig) { result[key] = defaultHotkeysConfig[key]; } - return result; } - // Disables functions in the config object based on the provided list of function names - function disableFunctions(config, disabledFunctions) { - // Bind the hasOwnProperty method to the functionMap object to avoid errors - const hasOwnProperty = - Object.prototype.hasOwnProperty.bind(functionMap); - - // Loop through the disabledFunctions array and disable the corresponding functions in the config object - disabledFunctions.forEach(funcName => { - if (hasOwnProperty(funcName)) { - const key = functionMap[funcName]; - config[key] = "disable"; - } - }); - - // Return the updated config object - return config; - } - - /** - * The restoreImgRedMask function displays a red mask around an image to indicate the aspect ratio. - * If the image display property is set to 'none', the mask breaks. To fix this, the function - * temporarily sets the display property to 'block' and then hides the mask again after 300 milliseconds - * to avoid breaking the canvas. Additionally, the function adjusts the mask to work correctly on - * very long images. - */ - function restoreImgRedMask(elements) { - const mainTabId = getTabId(elements); - - if (!mainTabId) return; - - const mainTab = gradioApp().querySelector(mainTabId); - const img = mainTab.querySelector("img"); - const imageARPreview = gradioApp().querySelector("#imageARPreview"); - - if (!img || !imageARPreview) return; - - imageARPreview.style.transform = ""; - if (parseFloat(mainTab.style.width) > 865) { - const transformString = mainTab.style.transform; - const scaleMatch = transformString.match( - /scale\(([-+]?[0-9]*\.?[0-9]+)\)/ - ); - let zoom = 1; // default zoom - - if (scaleMatch && scaleMatch[1]) { - zoom = Number(scaleMatch[1]); - } - - imageARPreview.style.transformOrigin = "0 0"; - imageARPreview.style.transform = `scale(${zoom})`; - } - - if (img.style.display !== "none") return; - - img.style.display = "block"; - - setTimeout(() => { - img.style.display = "none"; - }, 400); - } - // Default config const defaultHotkeysConfig = { - canvas_hotkey_zoom: "Alt", + canvas_hotkey_zoom: "Shift", canvas_hotkey_adjust: "Ctrl", + canvas_zoom_undo_extra_key: "Ctrl", + canvas_zoom_hotkey_undo: "KeyZ", canvas_hotkey_reset: "KeyR", canvas_hotkey_fullscreen: "KeyS", canvas_hotkey_move: "KeyF", - canvas_hotkey_overlap: "KeyO", - canvas_disabled_functions: [], canvas_show_tooltip: true, canvas_auto_expand: true, - canvas_blur_prompt: false, - }; - - const functionMap = { - "Zoom": "canvas_hotkey_zoom", - "Adjust brush size": "canvas_hotkey_adjust", - "Moving canvas": "canvas_hotkey_move", - "Fullscreen": "canvas_hotkey_fullscreen", - "Reset Zoom": "canvas_hotkey_reset", - "Overlap": "canvas_hotkey_overlap" + canvas_blur_prompt: true, }; // Loading the configuration from opts - const preHotkeysConfig = createHotkeyConfig( + const hotkeysConfig = createHotkeyConfig( defaultHotkeysConfig ); - // Disable functions that are not needed by the user - const hotkeysConfig = disableFunctions( - preHotkeysConfig, - preHotkeysConfig.canvas_disabled_functions - ); - let isMoving = false; - let mouseX, mouseY; let activeElement; const elemData = {}; - function applyZoomAndPan(elemId, isExtension = true) { + function applyZoomAndPan(elemId) { const targetElement = 
gradioApp().querySelector(elemId); if (!targetElement) { @@ -181,6 +68,7 @@ onUiLoaded(async() => { panX: 0, panY: 0 }; + let fullScreenMode = false; // Create tooltip @@ -211,44 +99,46 @@ onUiLoaded(async() => { action: "Adjust brush size", keySuffix: " + wheel" }, + {configKey: "canvas_zoom_hotkey_undo", action: "Undo last action", keyPrefix: `${hotkeysConfig.canvas_zoom_undo_extra_key} + ` }, {configKey: "canvas_hotkey_reset", action: "Reset zoom"}, { configKey: "canvas_hotkey_fullscreen", action: "Fullscreen mode" }, - {configKey: "canvas_hotkey_move", action: "Move canvas"}, - {configKey: "canvas_hotkey_overlap", action: "Overlap"} + {configKey: "canvas_hotkey_move", action: "Move canvas"} ]; - // Create hotkeys array with disabled property based on the config values - const hotkeys = hotkeysInfo.map(info => { + // Create hotkeys array based on the config values + const hotkeys = hotkeysInfo.map((info) => { const configValue = hotkeysConfig[info.configKey]; - const key = info.keySuffix ? - `${configValue}${info.keySuffix}` : - configValue.charAt(configValue.length - 1); - return { - key, - action: info.action, - disabled: configValue === "disable" - }; - }); - - for (const hotkey of hotkeys) { - if (hotkey.disabled) { - continue; + + let key = configValue.slice(-1); + + if (info.keySuffix) { + key = `${configValue}${info.keySuffix}`; } + + if (info.keyPrefix && info.keyPrefix !== "None + ") { + key = `${info.keyPrefix}${configValue[3]}`; + } + + return { + key, + action: info.action, + }; + }); + + hotkeys + .forEach(hotkey => { + const p = document.createElement("p"); + p.innerHTML = `${hotkey.key} - ${hotkey.action}`; + tooltipContent.appendChild(p); + }); + + tooltip.append(info, tooltipContent); - const p = document.createElement("p"); - p.innerHTML = `${hotkey.key} - ${hotkey.action}`; - tooltipContent.appendChild(p); - } - - // Add information and content elements to the tooltip element - tooltip.appendChild(info); - tooltip.appendChild(tooltipContent); - - // Add a hint element to the target element - toolTipElemnt.appendChild(tooltip); + // Add a hint element to the target element + toolTipElemnt.appendChild(tooltip); } //Show tool tip if setting enable @@ -264,9 +154,7 @@ onUiLoaded(async() => { panY: 0 }; - if (isExtension) { - targetElement.style.overflow = "hidden"; - } + targetElement.style.overflow = "hidden"; targetElement.isZoomed = false; @@ -284,7 +172,7 @@ onUiLoaded(async() => { closeBtn.addEventListener("click", resetZoom); } - if (canvas && isExtension) { + if (canvas) { const parentElement = targetElement.closest('[id^="component-"]'); if ( canvas && @@ -297,16 +185,6 @@ onUiLoaded(async() => { } - if ( - canvas && - !isExtension && - parseFloat(canvas.style.width) > 865 && - parseFloat(targetElement.style.width) > 865 - ) { - fitToElement(); - return; - } - targetElement.style.width = ""; } @@ -372,12 +250,10 @@ onUiLoaded(async() => { targetElement.style.transformOrigin = "0 0"; targetElement.style.transform = `translate(${elemData[elemId].panX}px, ${elemData[elemId].panY}px) scale(${newZoomLevel})`; + targetElement.style.overflow = "visible"; toggleOverlap("on"); - if (isExtension) { - targetElement.style.overflow = "visible"; - } - + return newZoomLevel; } @@ -388,6 +264,7 @@ onUiLoaded(async() => { let zoomPosX, zoomPosY; let delta = 0.2; + if (elemData[elemId].zoomLevel > 7) { delta = 0.9; } else if (elemData[elemId].zoomLevel > 2) { @@ -421,12 +298,7 @@ onUiLoaded(async() => { let parentElement; - if (isExtension) { - parentElement = 
targetElement.closest('[id^="component-"]'); - } else { - parentElement = targetElement.parentElement; - } - + parentElement = targetElement.closest('[id^="component-"]'); // Get element and screen dimensions const elementWidth = targetElement.offsetWidth; @@ -455,6 +327,26 @@ onUiLoaded(async() => { toggleOverlap("off"); } + // Undo last action + function undoLastAction(e) { + let isCtrlPressed = isModifierKey(e, hotkeysConfig.canvas_zoom_undo_extra_key) + const isAuxButton = e.button >= 3; + + if (isAuxButton) { + isCtrlPressed = true + } else { + if (!isModifierKey(e, hotkeysConfig.canvas_zoom_undo_extra_key)) return; + } + + // Move undoBtn query outside the if statement to avoid unnecessary queries + const undoBtn = document.querySelector(`${activeElement} button[aria-label="Undo"]`); + + if ((isCtrlPressed) && undoBtn ) { + e.preventDefault(); + undoBtn.click(); + } + } + /** * This function fits the target element to the screen by calculating * the required scale and offsets. It also updates the global variables @@ -469,13 +361,8 @@ onUiLoaded(async() => { if (!canvas) return; - if (canvas.offsetWidth > 862 || isExtension) { - targetElement.style.width = (canvas.offsetWidth + 2) + "px"; - } - - if (isExtension) { - targetElement.style.overflow = "visible"; - } + targetElement.style.width = (canvas.offsetWidth + 2) + "px"; + targetElement.style.overflow = "visible"; if (fullScreenMode) { resetZoom(); @@ -549,11 +436,11 @@ onUiLoaded(async() => { } } - const hotkeyActions = { [hotkeysConfig.canvas_hotkey_reset]: resetZoom, [hotkeysConfig.canvas_hotkey_overlap]: toggleOverlap, - [hotkeysConfig.canvas_hotkey_fullscreen]: fitToScreen + [hotkeysConfig.canvas_hotkey_fullscreen]: fitToScreen, + [hotkeysConfig.canvas_zoom_hotkey_undo]: undoLastAction, }; const action = hotkeyActions[event.code]; @@ -597,26 +484,27 @@ onUiLoaded(async() => { } targetElement.addEventListener("mousemove", getMousePosition); + targetElement.addEventListener("auxclick", undoLastAction); //observers // Creating an observer with a callback function to handle DOM changes const observer = new MutationObserver((mutationsList, observer) => { for (let mutation of mutationsList) { - // If the style attribute of the canvas has changed, by observation it happens only when the picture changes - if (mutation.type === 'attributes' && mutation.attributeName === 'style' && - mutation.target.tagName.toLowerCase() === 'canvas') { - targetElement.isExpanded = false; - setTimeout(resetZoom, 10); - } + // If the style attribute of the canvas has changed, by observation it happens only when the picture changes + if (mutation.type === 'attributes' && mutation.attributeName === 'style' && + mutation.target.tagName.toLowerCase() === 'canvas') { + targetElement.isExpanded = false; + setTimeout(resetZoom, 10); + } } - }); - - // Apply auto expand if enabled - if (hotkeysConfig.canvas_auto_expand) { + }); + + // Apply auto expand if enabled + if (hotkeysConfig.canvas_auto_expand) { targetElement.addEventListener("mousemove", autoExpand); // Set up an observer to track attribute changes - observer.observe(targetElement, {attributes: true, childList: true, subtree: true}); - } + observer.observe(targetElement, { attributes: true, childList: true, subtree: true }); + } // Handle events only inside the targetElement let isKeyDownHandlerAttached = false; @@ -661,7 +549,7 @@ onUiLoaded(async() => { function handleMoveKeyDown(e) { // Disable key locks to make pasting from the buffer work correctly - if ((e.ctrlKey && e.code === 'KeyV') || 
(e.ctrlKey && event.code === 'KeyC') || e.code === "F5") { + if ((e.ctrlKey && e.code === 'KeyV') || (e.ctrlKey && e.code === 'KeyC') || e.code === "F5") { return; } @@ -713,11 +601,7 @@ onUiLoaded(async() => { if (isMoving && elemId === activeElement) { updatePanPosition(e.movementX, e.movementY); targetElement.style.pointerEvents = "none"; - - if (isExtension) { - targetElement.style.overflow = "visible"; - } - + targetElement.style.overflow = "visible"; } else { targetElement.style.pointerEvents = "auto"; } @@ -745,18 +629,13 @@ onUiLoaded(async() => { } } - if (isExtension) { - targetElement.addEventListener("mousemove", checkForOutBox); - } - + targetElement.addEventListener("mousemove", checkForOutBox); window.addEventListener('resize', (e) => { resetZoom(); - if (isExtension) { - targetElement.isExpanded = false; - targetElement.isZoomed = false; - } + targetElement.isExpanded = false; + targetElement.isZoomed = false; }); gradioApp().addEventListener("mousemove", handleMoveByKey); diff --git a/ldm_patched/contrib/external.py b/ldm_patched/contrib/external.py index e20b08c5..7f95f084 100644 --- a/ldm_patched/contrib/external.py +++ b/ldm_patched/contrib/external.py @@ -1870,6 +1870,7 @@ def init_custom_nodes(): "nodes_images.py", "nodes_video_model.py", "nodes_sag.py", + "nodes_perpneg.py", ] for node_file in extras_files: diff --git a/ldm_patched/contrib/external_images.py b/ldm_patched/contrib/external_images.py index 3dbb3e3b..17e9c497 100644 --- a/ldm_patched/contrib/external_images.py +++ b/ldm_patched/contrib/external_images.py @@ -76,7 +76,7 @@ class SaveAnimatedWEBP: OUTPUT_NODE = True - CATEGORY = "_for_testing" + CATEGORY = "image/animation" def save_images(self, images, fps, filename_prefix, lossless, quality, method, num_frames=0, prompt=None, extra_pnginfo=None): method = self.methods.get(method) @@ -138,7 +138,7 @@ class SaveAnimatedPNG: OUTPUT_NODE = True - CATEGORY = "_for_testing" + CATEGORY = "image/animation" def save_images(self, images, fps, compress_level, filename_prefix="ldm_patched", prompt=None, extra_pnginfo=None): filename_prefix += self.prefix_append diff --git a/ldm_patched/contrib/external_latent.py b/ldm_patched/contrib/external_latent.py index e2364b88..c6f874e1 100644 --- a/ldm_patched/contrib/external_latent.py +++ b/ldm_patched/contrib/external_latent.py @@ -5,9 +5,7 @@ import torch def reshape_latent_to(target_shape, latent): if latent.shape[1:] != target_shape[1:]: - latent.movedim(1, -1) latent = ldm_patched.modules.utils.common_upscale(latent, target_shape[3], target_shape[2], "bilinear", "center") - latent.movedim(-1, 1) return ldm_patched.modules.utils.repeat_to_batch_size(latent, target_shape[0]) @@ -104,9 +102,32 @@ class LatentInterpolate: samples_out["samples"] = st * (m1 * ratio + m2 * (1.0 - ratio)) return (samples_out,) +class LatentBatch: + @classmethod + def INPUT_TYPES(s): + return {"required": { "samples1": ("LATENT",), "samples2": ("LATENT",)}} + + RETURN_TYPES = ("LATENT",) + FUNCTION = "batch" + + CATEGORY = "latent/batch" + + def batch(self, samples1, samples2): + samples_out = samples1.copy() + s1 = samples1["samples"] + s2 = samples2["samples"] + + if s1.shape[1:] != s2.shape[1:]: + s2 = ldm_patched.modules.utils.common_upscale(s2, s1.shape[3], s1.shape[2], "bilinear", "center") + s = torch.cat((s1, s2), dim=0) + samples_out["samples"] = s + samples_out["batch_index"] = samples1.get("batch_index", [x for x in range(0, s1.shape[0])]) + samples2.get("batch_index", [x for x in range(0, s2.shape[0])]) + return (samples_out,) + 
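+# LatentBatch (above) concatenates two latent batches along dim 0; when spatial shapes
+# differ, samples2 is bilinearly rescaled to match samples1 so torch.cat succeeds, and
+# the batch_index lists are joined so downstream nodes can still attribute each latent
+# to its source batch.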
NODE_CLASS_MAPPINGS = { "LatentAdd": LatentAdd, "LatentSubtract": LatentSubtract, "LatentMultiply": LatentMultiply, "LatentInterpolate": LatentInterpolate, + "LatentBatch": LatentBatch, } diff --git a/ldm_patched/contrib/external_model_advanced.py b/ldm_patched/contrib/external_model_advanced.py index 4ebd9dbf..03a2f045 100644 --- a/ldm_patched/contrib/external_model_advanced.py +++ b/ldm_patched/contrib/external_model_advanced.py @@ -19,41 +19,19 @@ class LCM(ldm_patched.modules.model_sampling.EPS): return c_out * x0 + c_skip * model_input -class ModelSamplingDiscreteDistilled(torch.nn.Module): +class ModelSamplingDiscreteDistilled(ldm_patched.modules.model_sampling.ModelSamplingDiscrete): original_timesteps = 50 - def __init__(self): - super().__init__() - self.sigma_data = 1.0 - timesteps = 1000 - beta_start = 0.00085 - beta_end = 0.012 + def __init__(self, model_config=None): + super().__init__(model_config) - betas = torch.linspace(beta_start**0.5, beta_end**0.5, timesteps, dtype=torch.float32) ** 2 - alphas = 1.0 - betas - alphas_cumprod = torch.cumprod(alphas, dim=0) + self.skip_steps = self.num_timesteps // self.original_timesteps - self.skip_steps = timesteps // self.original_timesteps - - - alphas_cumprod_valid = torch.zeros((self.original_timesteps), dtype=torch.float32) + sigmas_valid = torch.zeros((self.original_timesteps), dtype=torch.float32) for x in range(self.original_timesteps): - alphas_cumprod_valid[self.original_timesteps - 1 - x] = alphas_cumprod[timesteps - 1 - x * self.skip_steps] + sigmas_valid[self.original_timesteps - 1 - x] = self.sigmas[self.num_timesteps - 1 - x * self.skip_steps] - sigmas = ((1 - alphas_cumprod_valid) / alphas_cumprod_valid) ** 0.5 - self.set_sigmas(sigmas) - - def set_sigmas(self, sigmas): - self.register_buffer('sigmas', sigmas) - self.register_buffer('log_sigmas', sigmas.log()) - - @property - def sigma_min(self): - return self.sigmas[0] - - @property - def sigma_max(self): - return self.sigmas[-1] + self.set_sigmas(sigmas_valid) def timestep(self, sigma): log_sigma = sigma.log() @@ -68,14 +46,6 @@ class ModelSamplingDiscreteDistilled(torch.nn.Module): log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx] return log_sigma.exp().to(timestep.device) - def percent_to_sigma(self, percent): - if percent <= 0.0: - return 999999999.9 - if percent >= 1.0: - return 0.0 - percent = 1.0 - percent - return self.sigma(torch.tensor(percent * 999.0)).item() - def rescale_zero_terminal_snr_sigmas(sigmas): alphas_cumprod = 1 / ((sigmas * sigmas) + 1) @@ -124,7 +94,7 @@ class ModelSamplingDiscrete: class ModelSamplingAdvanced(sampling_base, sampling_type): pass - model_sampling = ModelSamplingAdvanced() + model_sampling = ModelSamplingAdvanced(model.model.model_config) if zsnr: model_sampling.set_sigmas(rescale_zero_terminal_snr_sigmas(model_sampling.sigmas)) @@ -156,7 +126,7 @@ class ModelSamplingContinuousEDM: class ModelSamplingAdvanced(ldm_patched.modules.model_sampling.ModelSamplingContinuousEDM, sampling_type): pass - model_sampling = ModelSamplingAdvanced() + model_sampling = ModelSamplingAdvanced(model.model.model_config) model_sampling.set_sigma_range(sigma_min, sigma_max) m.add_object_patch("model_sampling", model_sampling) return (m, ) diff --git a/ldm_patched/contrib/external_perpneg.py b/ldm_patched/contrib/external_perpneg.py new file mode 100644 index 00000000..ec91681f --- /dev/null +++ b/ldm_patched/contrib/external_perpneg.py @@ -0,0 +1,57 @@ +# https://github.com/comfyanonymous/ComfyUI/blob/master/nodes.py + 
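+# Perp-Neg guidance (see https://perp-neg.github.io/): both guidance directions are
+# measured against an empty-prompt ("nocond") prediction, and the component of the
+# positive direction that lies along the negative direction is projected out, scaled
+# by neg_scale, before classifier-free guidance is applied. This keeps the negative
+# prompt from directly cancelling wanted content in the positive prompt.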
+import torch +import ldm_patched.modules.model_management +import ldm_patched.modules.sample +import ldm_patched.modules.samplers +import ldm_patched.modules.utils + + +class PerpNeg: + @classmethod + def INPUT_TYPES(s): + return {"required": {"model": ("MODEL", ), + "empty_conditioning": ("CONDITIONING", ), + "neg_scale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0}), + }} + RETURN_TYPES = ("MODEL",) + FUNCTION = "patch" + + CATEGORY = "_for_testing" + + def patch(self, model, empty_conditioning, neg_scale): + m = model.clone() + nocond = ldm_patched.modules.sample.convert_cond(empty_conditioning) + + def cfg_function(args): + model = args["model"] + noise_pred_pos = args["cond_denoised"] + noise_pred_neg = args["uncond_denoised"] + cond_scale = args["cond_scale"] + x = args["input"] + sigma = args["sigma"] + model_options = args["model_options"] + nocond_processed = ldm_patched.modules.samplers.encode_model_conds(model.extra_conds, nocond, x, x.device, "negative") + + (noise_pred_nocond, _) = ldm_patched.modules.samplers.calc_cond_uncond_batch(model, nocond_processed, None, x, sigma, model_options) + + pos = noise_pred_pos - noise_pred_nocond + neg = noise_pred_neg - noise_pred_nocond + perp = ((torch.mul(pos, neg).sum())/(torch.norm(neg)**2)) * neg + perp_neg = perp * neg_scale + cfg_result = noise_pred_nocond + cond_scale*(pos - perp_neg) + cfg_result = x - cfg_result + return cfg_result + + m.set_model_sampler_cfg_function(cfg_function) + + return (m, ) + + +NODE_CLASS_MAPPINGS = { + "PerpNeg": PerpNeg, +} + +NODE_DISPLAY_NAME_MAPPINGS = { + "PerpNeg": "Perp-Neg", +} diff --git a/ldm_patched/contrib/external_sag.py b/ldm_patched/contrib/external_sag.py index 59d1890c..06ca67fa 100644 --- a/ldm_patched/contrib/external_sag.py +++ b/ldm_patched/contrib/external_sag.py @@ -29,9 +29,7 @@ def attention_basic_with_sim(q, k, v, heads, mask=None): # force cast to fp32 to avoid overflowing if _ATTN_PRECISION =="fp32": - with torch.autocast(enabled=False, device_type = 'cuda'): - q, k = q.float(), k.float() - sim = einsum('b i d, b j d -> b i j', q, k) * scale + sim = einsum('b i d, b j d -> b i j', q.float(), k.float()) * scale else: sim = einsum('b i d, b j d -> b i j', q, k) * scale @@ -62,7 +60,7 @@ def create_blur_map(x0, attn, sigma=3.0, threshold=1.0): attn = attn.reshape(b, -1, hw1, hw2) # Global Average Pool mask = attn.mean(1, keepdim=False).sum(1, keepdim=False) > threshold - ratio = round(math.sqrt(lh * lw / hw1)) + ratio = math.ceil(math.sqrt(lh * lw / hw1)) mid_shape = [math.ceil(lh / ratio), math.ceil(lw / ratio)] # Reshape @@ -113,7 +111,6 @@ class SelfAttentionGuidance: m = model.clone() attn_scores = None - mid_block_shape = None # TODO: make this work properly with chunked batches # currently, we can only save the attn from one UNet call @@ -136,7 +133,6 @@ class SelfAttentionGuidance: def post_cfg_function(args): nonlocal attn_scores - nonlocal mid_block_shape uncond_attn = attn_scores sag_scale = scale diff --git a/ldm_patched/ldm/modules/attention.py b/ldm_patched/ldm/modules/attention.py index f4579bac..49e502ed 100644 --- a/ldm_patched/ldm/modules/attention.py +++ b/ldm_patched/ldm/modules/attention.py @@ -104,9 +104,7 @@ def attention_basic(q, k, v, heads, mask=None): # force cast to fp32 to avoid overflowing if _ATTN_PRECISION =="fp32": - with torch.autocast(enabled=False, device_type = 'cuda'): - q, k = q.float(), k.float() - sim = einsum('b i d, b j d -> b i j', q, k) * scale + sim = einsum('b i d, b j d -> b i j', q.float(), k.float()) * scale else: sim = 
einsum('b i d, b j d -> b i j', q, k) * scale diff --git a/ldm_patched/modules/args_parser.py b/ldm_patched/modules/args_parser.py index 3931997d..7957783e 100644 --- a/ldm_patched/modules/args_parser.py +++ b/ldm_patched/modules/args_parser.py @@ -102,7 +102,7 @@ vram_group.add_argument("--always-cpu", action="store_true") parser.add_argument("--always-offload-from-vram", action="store_true") - +parser.add_argument("--pytorch-deterministic", action="store_true") parser.add_argument("--disable-server-log", action="store_true") parser.add_argument("--debug-mode", action="store_true") diff --git a/ldm_patched/modules/clip_vision.py b/ldm_patched/modules/clip_vision.py index eda441af..9699210d 100644 --- a/ldm_patched/modules/clip_vision.py +++ b/ldm_patched/modules/clip_vision.py @@ -19,11 +19,13 @@ class Output: def clip_preprocess(image, size=224): mean = torch.tensor([ 0.48145466,0.4578275,0.40821073], device=image.device, dtype=image.dtype) std = torch.tensor([0.26862954,0.26130258,0.27577711], device=image.device, dtype=image.dtype) - scale = (size / min(image.shape[1], image.shape[2])) - image = torch.nn.functional.interpolate(image.movedim(-1, 1), size=(round(scale * image.shape[1]), round(scale * image.shape[2])), mode="bicubic", antialias=True) - h = (image.shape[2] - size)//2 - w = (image.shape[3] - size)//2 - image = image[:,:,h:h+size,w:w+size] + image = image.movedim(-1, 1) + if not (image.shape[2] == size and image.shape[3] == size): + scale = (size / min(image.shape[2], image.shape[3])) + image = torch.nn.functional.interpolate(image, size=(round(scale * image.shape[2]), round(scale * image.shape[3])), mode="bicubic", antialias=True) + h = (image.shape[2] - size)//2 + w = (image.shape[3] - size)//2 + image = image[:,:,h:h+size,w:w+size] image = torch.clip((255. 
* image), 0, 255).round() / 255.0 return (image - mean.view([3,1,1])) / std.view([3,1,1]) @@ -34,11 +36,9 @@ class ClipVisionModel(): self.load_device = ldm_patched.modules.model_management.text_encoder_device() offload_device = ldm_patched.modules.model_management.text_encoder_offload_device() - self.dtype = torch.float32 - if ldm_patched.modules.model_management.should_use_fp16(self.load_device, prioritize_performance=False): - self.dtype = torch.float16 - - self.model = ldm_patched.modules.clip_model.CLIPVisionModelProjection(config, self.dtype, offload_device, ldm_patched.modules.ops.disable_weight_init) + self.dtype = ldm_patched.modules.model_management.text_encoder_dtype(self.load_device) + self.model = ldm_patched.modules.clip_model.CLIPVisionModelProjection(config, self.dtype, offload_device, ldm_patched.modules.ops.manual_cast) + self.model.eval() self.patcher = ldm_patched.modules.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device) def load_sd(self, sd): @@ -46,15 +46,8 @@ class ClipVisionModel(): def encode_image(self, image): ldm_patched.modules.model_management.load_model_gpu(self.patcher) - pixel_values = clip_preprocess(image.to(self.load_device)) - - if self.dtype != torch.float32: - precision_scope = torch.autocast - else: - precision_scope = lambda a, b: contextlib.nullcontext(a) - - with precision_scope(ldm_patched.modules.model_management.get_autocast_device(self.load_device), torch.float32): - out = self.model(pixel_values=pixel_values, intermediate_output=-2) + pixel_values = clip_preprocess(image.to(self.load_device)).float() + out = self.model(pixel_values=pixel_values, intermediate_output=-2) outputs = Output() outputs["last_hidden_state"] = out[0].to(ldm_patched.modules.model_management.intermediate_device()) diff --git a/ldm_patched/modules/model_management.py b/ldm_patched/modules/model_management.py index 0e783b36..31cf95da 100644 --- a/ldm_patched/modules/model_management.py +++ b/ldm_patched/modules/model_management.py @@ -28,6 +28,10 @@ total_vram = 0 lowvram_available = True xpu_available = False +if args.pytorch_deterministic: + print("Using deterministic algorithms for pytorch") + torch.use_deterministic_algorithms(True, warn_only=True) + directml_enabled = False if args.directml is not None: import torch_directml diff --git a/ldm_patched/modules/samplers.py b/ldm_patched/modules/samplers.py index 9996e74c..bfcb3f56 100644 --- a/ldm_patched/modules/samplers.py +++ b/ldm_patched/modules/samplers.py @@ -2,6 +2,7 @@ from ldm_patched.k_diffusion import sampling as k_diffusion_sampling from ldm_patched.unipc import uni_pc import torch import enum +import collections from ldm_patched.modules import model_management import math from ldm_patched.modules import model_base @@ -61,9 +62,7 @@ def get_area_and_mult(conds, x_in, timestep_in): for c in model_conds: conditioning[c] = model_conds[c].process_cond(batch_size=x_in.shape[0], device=x_in.device, area=area) - control = None - if 'control' in conds: - control = conds['control'] + control = conds.get('control', None) patches = None if 'gligen' in conds: @@ -78,7 +77,8 @@ def get_area_and_mult(conds, x_in, timestep_in): patches['middle_patch'] = [gligen_patch] - return (input_x, mult, conditioning, area, control, patches) + cond_obj = collections.namedtuple('cond_obj', ['input_x', 'mult', 'conditioning', 'area', 'control', 'patches']) + return cond_obj(input_x, mult, conditioning, area, control, patches) def cond_equal_size(c1, c2): if c1 is c2: @@ -91,24 +91,24 @@ 
def cond_equal_size(c1, c2): return True def can_concat_cond(c1, c2): - if c1[0].shape != c2[0].shape: + if c1.input_x.shape != c2.input_x.shape: return False - #control - if (c1[4] is None) != (c2[4] is None): - return False - if c1[4] is not None: - if c1[4] is not c2[4]: + def objects_concatable(obj1, obj2): + if (obj1 is None) != (obj2 is None): return False + if obj1 is not None: + if obj1 is not obj2: + return False + return True - #patches - if (c1[5] is None) != (c2[5] is None): + if not objects_concatable(c1.control, c2.control): return False - if (c1[5] is not None): - if c1[5] is not c2[5]: - return False - return cond_equal_size(c1[2], c2[2]) + if not objects_concatable(c1.patches, c2.patches): + return False + + return cond_equal_size(c1.conditioning, c2.conditioning) def cond_cat(c_list): c_crossattn = [] @@ -184,13 +184,13 @@ def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options): for x in to_batch: o = to_run.pop(x) p = o[0] - input_x += [p[0]] - mult += [p[1]] - c += [p[2]] - area += [p[3]] - cond_or_uncond += [o[1]] - control = p[4] - patches = p[5] + input_x.append(p.input_x) + mult.append(p.mult) + c.append(p.conditioning) + area.append(p.area) + cond_or_uncond.append(o[1]) + control = p.control + patches = p.patches batch_chunks = len(cond_or_uncond) input_x = torch.cat(input_x) @@ -251,7 +251,8 @@ def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_option cond_pred, uncond_pred = calc_cond_uncond_batch(model, cond, uncond_, x, timestep, model_options) if "sampler_cfg_function" in model_options: - args = {"cond": x - cond_pred, "uncond": x - uncond_pred, "cond_scale": cond_scale, "timestep": timestep, "input": x, "sigma": timestep} + args = {"cond": x - cond_pred, "uncond": x - uncond_pred, "cond_scale": cond_scale, "timestep": timestep, "input": x, "sigma": timestep, + "cond_denoised": cond_pred, "uncond_denoised": uncond_pred, "model": model, "model_options": model_options} cfg_result = x - model_options["sampler_cfg_function"](args) else: cfg_result = uncond_pred + (cond_pred - uncond_pred) * cond_scale diff --git a/modules/async_worker.py b/modules/async_worker.py index ada35ecc..840bb1d9 100644 --- a/modules/async_worker.py +++ b/modules/async_worker.py @@ -34,6 +34,7 @@ def worker(): import modules.advanced_parameters as advanced_parameters import extras.ip_adapter as ip_adapter import extras.face_crop + import fooocus_version from modules.censor import censor_batch from modules.sdxl_styles import apply_style, apply_wildcards, fooocus_expansion @@ -281,9 +282,10 @@ def worker(): inpaint_image = HWC3(inpaint_image) if isinstance(inpaint_image, np.ndarray) and isinstance(inpaint_mask, np.ndarray) \ and (np.any(inpaint_mask > 127) or len(outpaint_selections) > 0): + progressbar(async_task, 1, 'Downloading upscale models ...') + modules.config.downloading_upscale_model() if inpaint_parameterized: progressbar(async_task, 1, 'Downloading inpainter ...') - modules.config.downloading_upscale_model() inpaint_head_model_path, inpaint_patch_model_path = modules.config.downloading_inpaint_models( advanced_parameters.inpaint_engine) base_model_additional_loras += [(inpaint_patch_model_path, 1.0)] @@ -401,8 +403,8 @@ def worker(): uc=None, positive_top_k=len(positive_basic_workloads), negative_top_k=len(negative_basic_workloads), - log_positive_prompt='; '.join([task_prompt] + task_extra_positive_prompts), - log_negative_prompt='; '.join([task_negative_prompt] + task_extra_negative_prompts), + 
log_positive_prompt='\n'.join([task_prompt] + task_extra_positive_prompts), + log_negative_prompt='\n'.join([task_negative_prompt] + task_extra_negative_prompts), )) if use_expansion: @@ -497,7 +499,7 @@ def worker(): if direct_return: d = [('Upscale (Fast)', '2x')] - log(uov_input_image, d, single_line_number=1) + log(uov_input_image, d) yield_result(async_task, uov_input_image, do_not_show_finished_images=True) return @@ -779,12 +781,13 @@ def worker(): ('Refiner Switch', refiner_switch), ('Sampler', sampler_name), ('Scheduler', scheduler_name), - ('Seed', task['task_seed']) + ('Seed', task['task_seed']), ] - for n, w in loras: + for li, (n, w) in enumerate(loras): if n != 'None': - d.append((f'LoRA [{n}] weight', w)) - log(x, d, single_line_number=3) + d.append((f'LoRA {li + 1}', f'{n} : {w}')) + d.append(('Version', 'v' + fooocus_version.version)) + log(x, d) yield_result(async_task, imgs, do_not_show_finished_images=len(tasks) == 1, progressbar_index=int(15.0 + 85.0 * float((current_task_id + 1) * steps) / float(all_steps))) except ldm_patched.modules.model_management.InterruptProcessingException as e: @@ -806,12 +809,12 @@ def worker(): task = async_tasks.pop(0) try: handler(task) - except: - traceback.print_exc() - finally: build_image_wall(task) task.yields.append(['finish', task.results]) pipeline.prepare_text_encoder(async_call=True) + except: + traceback.print_exc() + task.yields.append(['finish', task.results]) pass diff --git a/modules/core.py b/modules/core.py index 86c56b5c..989b8e32 100644 --- a/modules/core.py +++ b/modules/core.py @@ -191,7 +191,7 @@ def encode_vae_inpaint(vae, pixels, mask): latent_mask = mask[:, None, :, :] latent_mask = torch.nn.functional.interpolate(latent_mask, size=(H * 8, W * 8), mode="bilinear").round() - latent_mask = torch.nn.functional.max_pool2d(latent_mask, (8, 8)).round() + latent_mask = torch.nn.functional.max_pool2d(latent_mask, (8, 8)).round().to(latent) return latent, latent_mask diff --git a/modules/meta_parser.py b/modules/meta_parser.py new file mode 100644 index 00000000..78d73978 --- /dev/null +++ b/modules/meta_parser.py @@ -0,0 +1,144 @@ +import json +import gradio as gr +import modules.config + + +def load_parameter_button_click(raw_prompt_txt): + loaded_parameter_dict = json.loads(raw_prompt_txt) + assert isinstance(loaded_parameter_dict, dict) + + results = [True, 1] + + try: + h = loaded_parameter_dict.get('Prompt', None) + assert isinstance(h, str) + results.append(h) + except: + results.append(gr.update()) + + try: + h = loaded_parameter_dict.get('Negative Prompt', None) + assert isinstance(h, str) + results.append(h) + except: + results.append(gr.update()) + + try: + h = loaded_parameter_dict.get('Styles', None) + h = eval(h) + assert isinstance(h, list) + results.append(h) + except: + results.append(gr.update()) + + try: + h = loaded_parameter_dict.get('Performance', None) + assert isinstance(h, str) + results.append(h) + except: + results.append(gr.update()) + + try: + h = loaded_parameter_dict.get('Resolution', None) + width, height = eval(h) + formatted = modules.config.add_ratio(f'{width}*{height}') + if formatted in modules.config.available_aspect_ratios: + results.append(formatted) + results.append(-1) + results.append(-1) + else: + results.append(gr.update()) + results.append(width) + results.append(height) + except: + results.append(gr.update()) + results.append(gr.update()) + results.append(gr.update()) + + try: + h = loaded_parameter_dict.get('Sharpness', None) + assert h is not None + h = float(h) + 
results.append(h) + except: + results.append(gr.update()) + + try: + h = loaded_parameter_dict.get('Guidance Scale', None) + assert h is not None + h = float(h) + results.append(h) + except: + results.append(gr.update()) + + try: + h = loaded_parameter_dict.get('ADM Guidance', None) + p, n, e = eval(h) + results.append(float(p)) + results.append(float(n)) + results.append(float(e)) + except: + results.append(gr.update()) + results.append(gr.update()) + results.append(gr.update()) + + try: + h = loaded_parameter_dict.get('Base Model', None) + assert isinstance(h, str) + results.append(h) + except: + results.append(gr.update()) + + try: + h = loaded_parameter_dict.get('Refiner Model', None) + assert isinstance(h, str) + results.append(h) + except: + results.append(gr.update()) + + try: + h = loaded_parameter_dict.get('Refiner Switch', None) + assert h is not None + h = float(h) + results.append(h) + except: + results.append(gr.update()) + + try: + h = loaded_parameter_dict.get('Sampler', None) + assert isinstance(h, str) + results.append(h) + except: + results.append(gr.update()) + + try: + h = loaded_parameter_dict.get('Scheduler', None) + assert isinstance(h, str) + results.append(h) + except: + results.append(gr.update()) + + try: + h = loaded_parameter_dict.get('Seed', None) + assert h is not None + h = int(h) + results.append(False) + results.append(h) + except: + results.append(gr.update()) + results.append(gr.update()) + + results.append(gr.update(visible=True)) + results.append(gr.update(visible=False)) + + for i in range(1, 6): + try: + n, w = loaded_parameter_dict.get(f'LoRA {i}').split(' : ') + w = float(w) + results.append(n) + results.append(w) + except: + results.append(gr.update()) + results.append(gr.update()) + + return results diff --git a/modules/patch.py b/modules/patch.py index c6012dfd..66b243cb 100644 --- a/modules/patch.py +++ b/modules/patch.py @@ -25,6 +25,8 @@ import modules.constants as constants from ldm_patched.modules.samplers import calc_cond_uncond_batch from ldm_patched.k_diffusion.sampling import BatchedBrownianTree from ldm_patched.ldm.modules.diffusionmodules.openaimodel import forward_timestep_embed, apply_control +from modules.patch_precision import patch_all_precision +from modules.patch_clip import patch_all_clip sharpness = 2.0 @@ -214,16 +216,20 @@ def compute_cfg(uncond, cond, cfg_scale, t): def patched_sampling_function(model, x, timestep, uncond, cond, cond_scale, model_options=None, seed=None): - if math.isclose(cond_scale, 1.0): - return calc_cond_uncond_batch(model, cond, None, x, timestep, model_options)[0] - global eps_record + if math.isclose(cond_scale, 1.0): + final_x0 = calc_cond_uncond_batch(model, cond, None, x, timestep, model_options)[0] + + if eps_record is not None: + eps_record = ((x - final_x0) / timestep).cpu() + + return final_x0 + positive_x0, negative_x0 = calc_cond_uncond_batch(model, cond, uncond, x, timestep, model_options) positive_eps = x - positive_x0 negative_eps = x - negative_x0 - sigma = timestep alpha = 0.001 * sharpness * global_diffusion_progress @@ -234,7 +240,7 @@ def patched_sampling_function(model, x, timestep, uncond, cond, cond_scale, mode cfg_scale=cond_scale, t=global_diffusion_progress) if eps_record is not None: - eps_record = (final_eps / sigma).cpu() + eps_record = (final_eps / timestep).cpu() return x - final_eps @@ -265,11 +271,11 @@ def sdxl_encode_adm_patched(self, **kwargs): height = float(height) * positive_adm_scale def embedder(number_list): - h = [self.embedder(torch.Tensor([number])) for 
number in number_list] - y = torch.flatten(torch.cat(h)).unsqueeze(dim=0).repeat(clip_pooled.shape[0], 1) - return y + h = self.embedder(torch.tensor(number_list, dtype=torch.float32)) + h = torch.flatten(h).unsqueeze(dim=0).repeat(clip_pooled.shape[0], 1) + return h - width, height = round_to_64(width), round_to_64(height) + width, height = int(width), int(height) target_width, target_height = round_to_64(target_width), round_to_64(target_height) adm_emphasized = embedder([height, width, 0, 0, target_height, target_width]) @@ -281,46 +287,6 @@ def sdxl_encode_adm_patched(self, **kwargs): return final_adm -def encode_token_weights_patched_with_a1111_method(self, token_weight_pairs): - to_encode = list() - max_token_len = 0 - has_weights = False - for x in token_weight_pairs: - tokens = list(map(lambda a: a[0], x)) - max_token_len = max(len(tokens), max_token_len) - has_weights = has_weights or not all(map(lambda a: a[1] == 1.0, x)) - to_encode.append(tokens) - - sections = len(to_encode) - if has_weights or sections == 0: - to_encode.append(ldm_patched.modules.sd1_clip.gen_empty_tokens(self.special_tokens, max_token_len)) - - out, pooled = self.encode(to_encode) - if pooled is not None: - first_pooled = pooled[0:1].to(ldm_patched.modules.model_management.intermediate_device()) - else: - first_pooled = pooled - - output = [] - for k in range(0, sections): - z = out[k:k + 1] - if has_weights: - original_mean = z.mean() - z_empty = out[-1] - for i in range(len(z)): - for j in range(len(z[i])): - weight = token_weight_pairs[k][j][1] - if weight != 1.0: - z[i][j] = (z[i][j] - z_empty[j]) * weight + z_empty[j] - new_mean = z.mean() - z = z * (original_mean / new_mean) - output.append(z) - - if len(output) == 0: - return out[-1:].to(ldm_patched.modules.model_management.intermediate_device()), first_pooled - return torch.cat(output, dim=-2).to(ldm_patched.modules.model_management.intermediate_device()), first_pooled - - def patched_KSamplerX0Inpaint_forward(self, x, sigma, uncond, cond, cond_scale, denoise_mask, model_options={}, seed=None): if inpaint_worker.current_task is not None: latent_processor = self.inner_model.inner_model.process_latent_in @@ -514,6 +480,9 @@ def build_loaded(module, loader_name): def patch_all(): + patch_all_precision() + patch_all_clip() + if not hasattr(ldm_patched.modules.model_management, 'load_models_gpu_origin'): ldm_patched.modules.model_management.load_models_gpu_origin = ldm_patched.modules.model_management.load_models_gpu @@ -522,7 +491,6 @@ def patch_all(): ldm_patched.controlnet.cldm.ControlNet.forward = patched_cldm_forward ldm_patched.ldm.modules.diffusionmodules.openaimodel.UNetModel.forward = patched_unet_forward ldm_patched.modules.model_base.SDXL.encode_adm = sdxl_encode_adm_patched - ldm_patched.modules.sd1_clip.ClipTokenWeightEncoder.encode_token_weights = encode_token_weights_patched_with_a1111_method ldm_patched.modules.samplers.KSamplerX0Inpaint.forward = patched_KSamplerX0Inpaint_forward ldm_patched.k_diffusion.sampling.BrownianTreeNoiseSampler = BrownianTreeNoiseSamplerPatched ldm_patched.modules.samplers.sampling_function = patched_sampling_function diff --git a/modules/patch_clip.py b/modules/patch_clip.py new file mode 100644 index 00000000..74ee436a --- /dev/null +++ b/modules/patch_clip.py @@ -0,0 +1,213 @@ +# Consistent with Kohya/A1111 to reduce differences between model training and inference. 
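+# The patches below rebuild the CLIP text/vision towers from stock transformers classes,
+# constructed under manual-cast ops so low-precision weights are upcast per layer at
+# runtime, while token embeddings are pinned to fp32, matching Kohya/A1111 numerics.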
+ +import os +import torch +import ldm_patched.controlnet.cldm +import ldm_patched.k_diffusion.sampling +import ldm_patched.ldm.modules.attention +import ldm_patched.ldm.modules.diffusionmodules.model +import ldm_patched.ldm.modules.diffusionmodules.openaimodel +import ldm_patched.ldm.modules.diffusionmodules.openaimodel +import ldm_patched.modules.args_parser +import ldm_patched.modules.model_base +import ldm_patched.modules.model_management +import ldm_patched.modules.model_patcher +import ldm_patched.modules.samplers +import ldm_patched.modules.sd +import ldm_patched.modules.sd1_clip +import ldm_patched.modules.clip_vision +import ldm_patched.modules.model_management as model_management +import ldm_patched.modules.ops as ops +import contextlib + +from transformers import CLIPTextModel, CLIPTextConfig, modeling_utils, CLIPVisionConfig, CLIPVisionModelWithProjection + + +@contextlib.contextmanager +def use_patched_ops(operations): + op_names = ['Linear', 'Conv2d', 'Conv3d', 'GroupNorm', 'LayerNorm'] + backups = {op_name: getattr(torch.nn, op_name) for op_name in op_names} + + try: + for op_name in op_names: + setattr(torch.nn, op_name, getattr(operations, op_name)) + + yield + + finally: + for op_name in op_names: + setattr(torch.nn, op_name, backups[op_name]) + return + + +def patched_encode_token_weights(self, token_weight_pairs): + to_encode = list() + max_token_len = 0 + has_weights = False + for x in token_weight_pairs: + tokens = list(map(lambda a: a[0], x)) + max_token_len = max(len(tokens), max_token_len) + has_weights = has_weights or not all(map(lambda a: a[1] == 1.0, x)) + to_encode.append(tokens) + + sections = len(to_encode) + if has_weights or sections == 0: + to_encode.append(ldm_patched.modules.sd1_clip.gen_empty_tokens(self.special_tokens, max_token_len)) + + out, pooled = self.encode(to_encode) + if pooled is not None: + first_pooled = pooled[0:1].to(ldm_patched.modules.model_management.intermediate_device()) + else: + first_pooled = pooled + + output = [] + for k in range(0, sections): + z = out[k:k + 1] + if has_weights: + original_mean = z.mean() + z_empty = out[-1] + for i in range(len(z)): + for j in range(len(z[i])): + weight = token_weight_pairs[k][j][1] + if weight != 1.0: + z[i][j] = (z[i][j] - z_empty[j]) * weight + z_empty[j] + new_mean = z.mean() + z = z * (original_mean / new_mean) + output.append(z) + + if len(output) == 0: + return out[-1:].to(ldm_patched.modules.model_management.intermediate_device()), first_pooled + return torch.cat(output, dim=-2).to(ldm_patched.modules.model_management.intermediate_device()), first_pooled + + +def patched_SDClipModel__init__(self, max_length=77, freeze=True, layer="last", layer_idx=None, + textmodel_json_config=None, dtype=None, special_tokens=None, + layer_norm_hidden_state=True, **kwargs): + torch.nn.Module.__init__(self) + assert layer in self.LAYERS + + if special_tokens is None: + special_tokens = {"start": 49406, "end": 49407, "pad": 49407} + + if textmodel_json_config is None: + textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(ldm_patched.modules.sd1_clip.__file__)), + "sd1_clip_config.json") + + config = CLIPTextConfig.from_json_file(textmodel_json_config) + self.num_layers = config.num_hidden_layers + + with use_patched_ops(ops.manual_cast): + with modeling_utils.no_init_weights(): + self.transformer = CLIPTextModel(config) + + if dtype is not None: + self.transformer.to(dtype) + + self.transformer.text_model.embeddings.to(torch.float32) + + if freeze: + self.freeze() + + self.max_length = 
max_length + self.layer = layer + self.layer_idx = None + self.special_tokens = special_tokens + self.text_projection = torch.nn.Parameter(torch.eye(self.transformer.get_input_embeddings().weight.shape[1])) + self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055)) + self.enable_attention_masks = False + + self.layer_norm_hidden_state = layer_norm_hidden_state + if layer == "hidden": + assert layer_idx is not None + assert abs(layer_idx) < self.num_layers + self.clip_layer(layer_idx) + self.layer_default = (self.layer, self.layer_idx) + + +def patched_SDClipModel_forward(self, tokens): + backup_embeds = self.transformer.get_input_embeddings() + device = backup_embeds.weight.device + tokens = self.set_up_textual_embeddings(tokens, backup_embeds) + tokens = torch.LongTensor(tokens).to(device) + + attention_mask = None + if self.enable_attention_masks: + attention_mask = torch.zeros_like(tokens) + max_token = self.transformer.get_input_embeddings().weight.shape[0] - 1 + for x in range(attention_mask.shape[0]): + for y in range(attention_mask.shape[1]): + attention_mask[x, y] = 1 + if tokens[x, y] == max_token: + break + + outputs = self.transformer(input_ids=tokens, attention_mask=attention_mask, + output_hidden_states=self.layer == "hidden") + self.transformer.set_input_embeddings(backup_embeds) + + if self.layer == "last": + z = outputs.last_hidden_state + elif self.layer == "pooled": + z = outputs.pooler_output[:, None, :] + else: + z = outputs.hidden_states[self.layer_idx] + if self.layer_norm_hidden_state: + z = self.transformer.text_model.final_layer_norm(z) + + if hasattr(outputs, "pooler_output"): + pooled_output = outputs.pooler_output.float() + else: + pooled_output = None + + if self.text_projection is not None and pooled_output is not None: + pooled_output = pooled_output.float().to(self.text_projection.device) @ self.text_projection.float() + + return z.float(), pooled_output + + +def patched_ClipVisionModel__init__(self, json_config): + config = CLIPVisionConfig.from_json_file(json_config) + + self.load_device = ldm_patched.modules.model_management.text_encoder_device() + self.offload_device = ldm_patched.modules.model_management.text_encoder_offload_device() + + if ldm_patched.modules.model_management.should_use_fp16(self.load_device, prioritize_performance=False): + self.dtype = torch.float16 + else: + self.dtype = torch.float32 + + with use_patched_ops(ops.manual_cast): + with modeling_utils.no_init_weights(): + self.model = CLIPVisionModelWithProjection(config) + + self.model.to(self.dtype) + self.patcher = ldm_patched.modules.model_patcher.ModelPatcher( + self.model, + load_device=self.load_device, + offload_device=self.offload_device + ) + + +def patched_ClipVisionModel_encode_image(self, image): + ldm_patched.modules.model_management.load_model_gpu(self.patcher) + pixel_values = ldm_patched.modules.clip_vision.clip_preprocess(image.to(self.load_device)) + outputs = self.model(pixel_values=pixel_values, output_hidden_states=True) + + for k in outputs: + t = outputs[k] + if t is not None: + if k == 'hidden_states': + outputs["penultimate_hidden_states"] = t[-2].to(ldm_patched.modules.model_management.intermediate_device()) + outputs["hidden_states"] = None + else: + outputs[k] = t.to(ldm_patched.modules.model_management.intermediate_device()) + + return outputs + + +def patch_all_clip(): + ldm_patched.modules.sd1_clip.ClipTokenWeightEncoder.encode_token_weights = patched_encode_token_weights + ldm_patched.modules.sd1_clip.SDClipModel.__init__ = 
patched_SDClipModel__init__ + ldm_patched.modules.sd1_clip.SDClipModel.forward = patched_SDClipModel_forward + ldm_patched.modules.clip_vision.ClipVisionModel.__init__ = patched_ClipVisionModel__init__ + ldm_patched.modules.clip_vision.ClipVisionModel.encode_image = patched_ClipVisionModel_encode_image + return diff --git a/modules/patch_precision.py b/modules/patch_precision.py new file mode 100644 index 00000000..83569bdd --- /dev/null +++ b/modules/patch_precision.py @@ -0,0 +1,60 @@ +# Consistent with Kohya to reduce differences between model training and inference. + +import torch +import math +import einops +import numpy as np + +import ldm_patched.ldm.modules.diffusionmodules.openaimodel +import ldm_patched.modules.model_sampling +import ldm_patched.modules.sd1_clip + +from ldm_patched.ldm.modules.diffusionmodules.util import make_beta_schedule + + +def patched_timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): + # Consistent with Kohya to reduce differences between model training and inference. + + if not repeat_only: + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half + ).to(device=timesteps.device) + args = timesteps[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + else: + embedding = einops.repeat(timesteps, 'b -> b d', d=dim) + return embedding + + +def patched_register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000, + linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): + # Consistent with Kohya to reduce differences between model training and inference. + + if given_betas is not None: + betas = given_betas + else: + betas = make_beta_schedule( + beta_schedule, + timesteps, + linear_start=linear_start, + linear_end=linear_end, + cosine_s=cosine_s) + + alphas = 1. - betas + alphas_cumprod = np.cumprod(alphas, axis=0) + timesteps, = betas.shape + self.num_timesteps = int(timesteps) + self.linear_start = linear_start + self.linear_end = linear_end + sigmas = torch.tensor(((1 - alphas_cumprod) / alphas_cumprod) ** 0.5, dtype=torch.float32) + self.set_sigmas(sigmas) + return + + +def patch_all_precision(): + ldm_patched.ldm.modules.diffusionmodules.openaimodel.timestep_embedding = patched_timestep_embedding + ldm_patched.modules.model_sampling.ModelSamplingDiscrete._register_schedule = patched_register_schedule + return diff --git a/modules/private_logger.py b/modules/private_logger.py index 3a992cf6..83ba9e36 100644 --- a/modules/private_logger.py +++ b/modules/private_logger.py @@ -1,6 +1,8 @@ import os import args_manager import modules.config +import json +import urllib.parse from PIL import Image from modules.util import generate_temp_filename @@ -16,7 +18,7 @@ def get_current_html_path(): return html_name -def log(img, dic, single_line_number=3): +def log(img, dic): if args_manager.args.disable_image_log: return @@ -25,36 +27,67 @@ def log(img, dic, single_line_number=3): Image.fromarray(img).save(local_temp_filename) html_name = os.path.join(os.path.dirname(local_temp_filename), 'log.html') - existing_log = log_cache.get(html_name, None) + css_styles = ( + "" + ) - if existing_log is None: + js = ( + "" + ) + + begin_part = f"
<html><head><title>Fooocus Log {date_string}</title>{css_styles}</head><body>{js}<p>Fooocus Log {date_string} (private)</p>\n<p>All images are clean, without any hidden data/meta, and safe to share with others.</p>\n<!--fooocus-log-split-->\n\n"
+    end_part = f'\n<!--fooocus-log-split--></body></html>'
+
+    middle_part = log_cache.get(html_name, "")
+
+    if middle_part == "":
         if os.path.exists(html_name):
-            existing_log = open(html_name, encoding='utf-8').read()
-        else:
-            existing_log = f'Fooocus Log {date_string} (private)\nAll images do not contain any hidden data.'
+            existing_split = open(html_name, 'r', encoding='utf-8').read().split('<!--fooocus-log-split-->')
+            if len(existing_split) == 3:
+                middle_part = existing_split[1]
+            else:
+                middle_part = existing_split[0]
 
     div_name = only_name.replace('.', '_')
-    item = f'{only_name}\n'
-    for i, (k, v) in enumerate(dic):
-        if i < single_line_number:
-            item += f"{k}: {v} \n"
-        else:
-            if (i - single_line_number) % 2 == 0:
-                item += f"{k}: {v}, "
-            else:
-                item += f"{k}: {v} \n"
-    existing_log = item + existing_log
+    item = f"<div id=\"{div_name}\" class=\"image-container\"><hr><table><tr>\n"
+    item += "<td><table class='metadata'>"