Merge pull request #3000 from lllyasviel/develop

Release 2.4.0
This commit is contained in:
Manuel Schmid 2024-05-26 18:18:53 +02:00 committed by GitHub
commit 12dc2396f6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
52 changed files with 1484 additions and 271 deletions

View File

@ -1 +1,54 @@
.idea __pycache__
*.ckpt
*.safetensors
*.pth
*.pt
*.bin
*.patch
*.backup
*.corrupted
*.partial
*.onnx
sorted_styles.json
/input
/cache
/language/default.json
/test_imgs
config.txt
config_modification_tutorial.txt
user_path_config.txt
user_path_config-deprecated.txt
/modules/*.png
/repositories
/fooocus_env
/venv
/tmp
/ui-config.json
/outputs
/config.json
/log
/webui.settings.bat
/embeddings
/styles.csv
/params.txt
/styles.csv.bak
/webui-user.bat
/webui-user.sh
/interrogate
/user.css
/.idea
/notification.ogg
/notification.mp3
/SwinIR
/textual_inversion
.vscode
/extensions
/test/stdout.txt
/test/stderr.txt
/cache.json*
/config_states/
/node_modules
/package-lock.json
/.coverage*
/auth.json
.DS_Store

3
.gitattributes vendored Normal file
View File

@ -0,0 +1,3 @@
# Ensure that shell scripts always use lf line endings, e.g. entrypoint.sh for docker
* text=auto
*.sh text eol=lf

2
.github/CODEOWNERS vendored
View File

@ -1 +1 @@
* @lllyasviel * @mashb1t

6
.github/dependabot.yml vendored Normal file
View File

@ -0,0 +1,6 @@
version: 2
updates:
- package-ecosystem: "github-actions"
directory: "/"
schedule:
interval: "monthly"

44
.github/workflows/build_container.yml vendored Normal file
View File

@ -0,0 +1,44 @@
name: Create and publish a container image
on:
push:
tags:
- 'v*'
jobs:
build-and-push-image:
runs-on: ubuntu-latest
permissions:
contents: read
packages: write
steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: Log in to the Container registry
uses: docker/login-action@v3
with:
registry: ghcr.io
username: ${{ github.repository_owner }}
password: ${{ secrets.GITHUB_TOKEN }}
- name: Extract metadata (tags, labels) for Docker
id: meta
uses: docker/metadata-action@v5
with:
images: ghcr.io/${{ github.repository_owner }}/${{ github.event.repository.name }}
tags: |
type=semver,pattern={{version}}
type=semver,pattern={{major}}.{{minor}}
type=semver,pattern={{major}}
- name: Build and push Docker image
uses: docker/build-push-action@v5
with:
context: .
file: ./Dockerfile
push: true
tags: ${{ steps.meta.outputs.tags }}
labels: ${{ steps.meta.outputs.labels }}

View File

@ -1,4 +1,4 @@
FROM nvidia/cuda:12.3.1-base-ubuntu22.04 FROM nvidia/cuda:12.4.1-base-ubuntu22.04
ENV DEBIAN_FRONTEND noninteractive ENV DEBIAN_FRONTEND noninteractive
ENV CMDARGS --listen ENV CMDARGS --listen
@ -23,7 +23,7 @@ RUN chown -R user:user /content
WORKDIR /content WORKDIR /content
USER user USER user
RUN git clone https://github.com/lllyasviel/Fooocus /content/app COPY . /content/app
RUN mv /content/app/models /content/app/models.org RUN mv /content/app/models /content/app/models.org
CMD [ "sh", "-c", "/content/entrypoint.sh ${CMDARGS}" ] CMD [ "sh", "-c", "/content/entrypoint.sh ${CMDARGS}" ]

View File

@ -31,6 +31,9 @@ args_parser.parser.add_argument("--disable-metadata", action='store_true',
args_parser.parser.add_argument("--disable-preset-download", action='store_true', args_parser.parser.add_argument("--disable-preset-download", action='store_true',
help="Disables downloading models for presets", default=False) help="Disables downloading models for presets", default=False)
args_parser.parser.add_argument("--enable-describe-uov-image", action='store_true',
help="Disables automatic description of uov images when prompt is empty", default=False)
args_parser.parser.add_argument("--always-download-new-model", action='store_true', args_parser.parser.add_argument("--always-download-new-model", action='store_true',
help="Always download newer models ", default=False) help="Always download newer models ", default=False)

View File

@ -27,6 +27,7 @@ progress {
border-radius: 5px; /* Round the corners of the progress bar */ border-radius: 5px; /* Round the corners of the progress bar */
background-color: #f3f3f3; /* Light grey background */ background-color: #f3f3f3; /* Light grey background */
width: 100%; width: 100%;
vertical-align: middle !important;
} }
/* Style the progress bar container */ /* Style the progress bar container */
@ -69,6 +70,11 @@ progress::after {
height: 30px !important; height: 30px !important;
} }
.progress-bar span {
text-align: right;
width: 215px;
}
.type_row{ .type_row{
height: 80px !important; height: 80px !important;
} }
@ -101,10 +107,14 @@ progress::after {
overflow: auto !important; overflow: auto !important;
} }
.aspect_ratios label { .performance_selection label {
width: 140px !important; width: 140px !important;
} }
.aspect_ratios label {
flex: calc(50% - 5px) !important;
}
.aspect_ratios label span { .aspect_ratios label span {
white-space: nowrap !important; white-space: nowrap !important;
} }
@ -391,6 +401,14 @@ progress::after {
background-color: #fff8; background-color: #fff8;
font-family: monospace; font-family: monospace;
text-align: center; text-align: center;
border-radius-top: 5px; border-radius: 5px 5px 0px 0px;
display: none; /* remove this to enable tooltip in preview image */ display: none; /* remove this to enable tooltip in preview image */
}
#inpaint_canvas .canvas-tooltip-info {
top: 2px;
}
#inpaint_brush_color input[type=color]{
background: none;
} }

11
development.md Normal file
View File

@ -0,0 +1,11 @@
## Running unit tests
Native python:
```
python -m unittest tests/
```
Embedded python (Windows zip file installation method):
```
..\python_embeded\python.exe -m unittest
```

View File

@ -1,12 +1,10 @@
version: '3.9'
volumes: volumes:
fooocus-data: fooocus-data:
services: services:
app: app:
build: . build: .
image: fooocus image: ghcr.io/lllyasviel/fooocus
ports: ports:
- "7865:7865" - "7865:7865"
environment: environment:

View File

@ -1,35 +1,99 @@
# Fooocus on Docker # Fooocus on Docker
The docker image is based on NVIDIA CUDA 12.3 and PyTorch 2.0, see [Dockerfile](Dockerfile) and [requirements_docker.txt](requirements_docker.txt) for details. The docker image is based on NVIDIA CUDA 12.4 and PyTorch 2.1, see [Dockerfile](Dockerfile) and [requirements_docker.txt](requirements_docker.txt) for details.
## Requirements
- A computer with specs good enough to run Fooocus, and proprietary Nvidia drivers
- Docker, Docker Compose, or Podman
## Quick start ## Quick start
**This is just an easy way for testing. Please find more information in the [notes](#notes).** **More information in the [notes](#notes).**
### Running with Docker Compose
1. Clone this repository 1. Clone this repository
2. Build the image with `docker compose build` 2. Run the docker container with `docker compose up`.
3. Run the docker container with `docker compose up`. Building the image takes some time.
### Running with Docker
```sh
docker run -p 7865:7865 -v fooocus-data:/content/data -it \
--gpus all \
-e CMDARGS=--listen \
-e DATADIR=/content/data \
-e config_path=/content/data/config.txt \
-e config_example_path=/content/data/config_modification_tutorial.txt \
-e path_checkpoints=/content/data/models/checkpoints/ \
-e path_loras=/content/data/models/loras/ \
-e path_embeddings=/content/data/models/embeddings/ \
-e path_vae_approx=/content/data/models/vae_approx/ \
-e path_upscale_models=/content/data/models/upscale_models/ \
-e path_inpaint=/content/data/models/inpaint/ \
-e path_controlnet=/content/data/models/controlnet/ \
-e path_clip_vision=/content/data/models/clip_vision/ \
-e path_fooocus_expansion=/content/data/models/prompt_expansion/fooocus_expansion/ \
-e path_outputs=/content/app/outputs/ \
ghcr.io/lllyasviel/fooocus
```
### Running with Podman
```sh
podman run -p 7865:7865 -v fooocus-data:/content/data -it \
--security-opt=no-new-privileges --cap-drop=ALL --security-opt label=type:nvidia_container_t --device=nvidia.com/gpu=all \
-e CMDARGS=--listen \
-e DATADIR=/content/data \
-e config_path=/content/data/config.txt \
-e config_example_path=/content/data/config_modification_tutorial.txt \
-e path_checkpoints=/content/data/models/checkpoints/ \
-e path_loras=/content/data/models/loras/ \
-e path_embeddings=/content/data/models/embeddings/ \
-e path_vae_approx=/content/data/models/vae_approx/ \
-e path_upscale_models=/content/data/models/upscale_models/ \
-e path_inpaint=/content/data/models/inpaint/ \
-e path_controlnet=/content/data/models/controlnet/ \
-e path_clip_vision=/content/data/models/clip_vision/ \
-e path_fooocus_expansion=/content/data/models/prompt_expansion/fooocus_expansion/ \
-e path_outputs=/content/app/outputs/ \
ghcr.io/lllyasviel/fooocus
```
When you see the message `Use the app with http://0.0.0.0:7865/` in the console, you can access the URL in your browser. When you see the message `Use the app with http://0.0.0.0:7865/` in the console, you can access the URL in your browser.
Your models and outputs are stored in the `fooocus-data` volume, which, depending on OS, is stored in `/var/lib/docker/volumes`. Your models and outputs are stored in the `fooocus-data` volume, which, depending on OS, is stored in `/var/lib/docker/volumes/` (or `~/.local/share/containers/storage/volumes/` when using `podman`).
## Building the container locally
Clone the repository first, and open a terminal in the folder.
Build with `docker`:
```sh
docker build . -t fooocus
```
Build with `podman`:
```sh
podman build . -t fooocus
```
## Details ## Details
### Update the container manually ### Update the container manually (`docker compose`)
When you are using `docker compose up` continuously, the container is not updated to the latest version of Fooocus automatically. When you are using `docker compose up` continuously, the container is not updated to the latest version of Fooocus automatically.
Run `git pull` before executing `docker compose build --no-cache` to build an image with the latest Fooocus version. Run `git pull` before executing `docker compose build --no-cache` to build an image with the latest Fooocus version.
You can then start it with `docker compose up` You can then start it with `docker compose up`
### Import models, outputs ### Import models, outputs
If you want to import files from models or the outputs folder, you can uncomment the following settings in the [docker-compose.yml](docker-compose.yml):
If you want to import files from models or the outputs folder, you can add the following bind mounts in the [docker-compose.yml](docker-compose.yml) or your preferred method of running the container:
``` ```
#- ./models:/import/models # Once you import files, you don't need to mount again. #- ./models:/import/models # Once you import files, you don't need to mount again.
#- ./outputs:/import/outputs # Once you import files, you don't need to mount again. #- ./outputs:/import/outputs # Once you import files, you don't need to mount again.
``` ```
After running `docker compose up`, your files will be copied into `/content/data/models` and `/content/data/outputs` After running the container, your files will be copied into `/content/data/models` and `/content/data/outputs`
Since `/content/data` is a persistent volume folder, your files will be persisted even when you re-run `docker compose up --build` without above volume settings. Since `/content/data` is a persistent volume folder, your files will be persisted even when you re-run the container without the above mounts.
### Paths inside the container ### Paths inside the container
@ -54,6 +118,7 @@ Docker specified environments are there. They are used by 'entrypoint.sh'
|CMDARGS|Arguments for [entry_with_update.py](entry_with_update.py) which is called by [entrypoint.sh](entrypoint.sh)| |CMDARGS|Arguments for [entry_with_update.py](entry_with_update.py) which is called by [entrypoint.sh](entrypoint.sh)|
|config_path|'config.txt' location| |config_path|'config.txt' location|
|config_example_path|'config_modification_tutorial.txt' location| |config_example_path|'config_modification_tutorial.txt' location|
|HF_MIRROR| huggingface mirror site domain|
You can also use the same json key names and values explained in the 'config_modification_tutorial.txt' as the environments. You can also use the same json key names and values explained in the 'config_modification_tutorial.txt' as the environments.
See examples in the [docker-compose.yml](docker-compose.yml) See examples in the [docker-compose.yml](docker-compose.yml)

60
extras/censor.py Normal file
View File

@ -0,0 +1,60 @@
import os
import numpy as np
import torch
from transformers import CLIPConfig, CLIPImageProcessor
import ldm_patched.modules.model_management as model_management
import modules.config
from extras.safety_checker.models.safety_checker import StableDiffusionSafetyChecker
from ldm_patched.modules.model_patcher import ModelPatcher
safety_checker_repo_root = os.path.join(os.path.dirname(__file__), 'safety_checker')
config_path = os.path.join(safety_checker_repo_root, "configs", "config.json")
preprocessor_config_path = os.path.join(safety_checker_repo_root, "configs", "preprocessor_config.json")
class Censor:
def __init__(self):
self.safety_checker_model: ModelPatcher | None = None
self.clip_image_processor: CLIPImageProcessor | None = None
self.load_device = torch.device('cpu')
self.offload_device = torch.device('cpu')
def init(self):
if self.safety_checker_model is None and self.clip_image_processor is None:
safety_checker_model = modules.config.downloading_safety_checker_model()
self.clip_image_processor = CLIPImageProcessor.from_json_file(preprocessor_config_path)
clip_config = CLIPConfig.from_json_file(config_path)
model = StableDiffusionSafetyChecker.from_pretrained(safety_checker_model, config=clip_config)
model.eval()
self.load_device = model_management.text_encoder_device()
self.offload_device = model_management.text_encoder_offload_device()
model.to(self.offload_device)
self.safety_checker_model = ModelPatcher(model, load_device=self.load_device, offload_device=self.offload_device)
def censor(self, images: list | np.ndarray) -> list | np.ndarray:
self.init()
model_management.load_model_gpu(self.safety_checker_model)
single = False
if not isinstance(images, list) or isinstance(images, np.ndarray):
images = [images]
single = True
safety_checker_input = self.clip_image_processor(images, return_tensors="pt")
safety_checker_input.to(device=self.load_device)
checked_images, has_nsfw_concept = self.safety_checker_model.model(images=images,
clip_input=safety_checker_input.pixel_values)
checked_images = [image.astype(np.uint8) for image in checked_images]
if single:
checked_images = checked_images[0]
return checked_images
default_censor = Censor().censor

View File

@ -0,0 +1,171 @@
{
"_name_or_path": "clip-vit-large-patch14/",
"architectures": [
"SafetyChecker"
],
"initializer_factor": 1.0,
"logit_scale_init_value": 2.6592,
"model_type": "clip",
"projection_dim": 768,
"text_config": {
"_name_or_path": "",
"add_cross_attention": false,
"architectures": null,
"attention_dropout": 0.0,
"bad_words_ids": null,
"bos_token_id": 0,
"chunk_size_feed_forward": 0,
"cross_attention_hidden_size": null,
"decoder_start_token_id": null,
"diversity_penalty": 0.0,
"do_sample": false,
"dropout": 0.0,
"early_stopping": false,
"encoder_no_repeat_ngram_size": 0,
"eos_token_id": 2,
"exponential_decay_length_penalty": null,
"finetuning_task": null,
"forced_bos_token_id": null,
"forced_eos_token_id": null,
"hidden_act": "quick_gelu",
"hidden_size": 768,
"id2label": {
"0": "LABEL_0",
"1": "LABEL_1"
},
"initializer_factor": 1.0,
"initializer_range": 0.02,
"intermediate_size": 3072,
"is_decoder": false,
"is_encoder_decoder": false,
"label2id": {
"LABEL_0": 0,
"LABEL_1": 1
},
"layer_norm_eps": 1e-05,
"length_penalty": 1.0,
"max_length": 20,
"max_position_embeddings": 77,
"min_length": 0,
"model_type": "clip_text_model",
"no_repeat_ngram_size": 0,
"num_attention_heads": 12,
"num_beam_groups": 1,
"num_beams": 1,
"num_hidden_layers": 12,
"num_return_sequences": 1,
"output_attentions": false,
"output_hidden_states": false,
"output_scores": false,
"pad_token_id": 1,
"prefix": null,
"problem_type": null,
"pruned_heads": {},
"remove_invalid_values": false,
"repetition_penalty": 1.0,
"return_dict": true,
"return_dict_in_generate": false,
"sep_token_id": null,
"task_specific_params": null,
"temperature": 1.0,
"tie_encoder_decoder": false,
"tie_word_embeddings": true,
"tokenizer_class": null,
"top_k": 50,
"top_p": 1.0,
"torch_dtype": null,
"torchscript": false,
"transformers_version": "4.21.0.dev0",
"typical_p": 1.0,
"use_bfloat16": false,
"vocab_size": 49408
},
"text_config_dict": {
"hidden_size": 768,
"intermediate_size": 3072,
"num_attention_heads": 12,
"num_hidden_layers": 12
},
"torch_dtype": "float32",
"transformers_version": null,
"vision_config": {
"_name_or_path": "",
"add_cross_attention": false,
"architectures": null,
"attention_dropout": 0.0,
"bad_words_ids": null,
"bos_token_id": null,
"chunk_size_feed_forward": 0,
"cross_attention_hidden_size": null,
"decoder_start_token_id": null,
"diversity_penalty": 0.0,
"do_sample": false,
"dropout": 0.0,
"early_stopping": false,
"encoder_no_repeat_ngram_size": 0,
"eos_token_id": null,
"exponential_decay_length_penalty": null,
"finetuning_task": null,
"forced_bos_token_id": null,
"forced_eos_token_id": null,
"hidden_act": "quick_gelu",
"hidden_size": 1024,
"id2label": {
"0": "LABEL_0",
"1": "LABEL_1"
},
"image_size": 224,
"initializer_factor": 1.0,
"initializer_range": 0.02,
"intermediate_size": 4096,
"is_decoder": false,
"is_encoder_decoder": false,
"label2id": {
"LABEL_0": 0,
"LABEL_1": 1
},
"layer_norm_eps": 1e-05,
"length_penalty": 1.0,
"max_length": 20,
"min_length": 0,
"model_type": "clip_vision_model",
"no_repeat_ngram_size": 0,
"num_attention_heads": 16,
"num_beam_groups": 1,
"num_beams": 1,
"num_hidden_layers": 24,
"num_return_sequences": 1,
"output_attentions": false,
"output_hidden_states": false,
"output_scores": false,
"pad_token_id": null,
"patch_size": 14,
"prefix": null,
"problem_type": null,
"pruned_heads": {},
"remove_invalid_values": false,
"repetition_penalty": 1.0,
"return_dict": true,
"return_dict_in_generate": false,
"sep_token_id": null,
"task_specific_params": null,
"temperature": 1.0,
"tie_encoder_decoder": false,
"tie_word_embeddings": true,
"tokenizer_class": null,
"top_k": 50,
"top_p": 1.0,
"torch_dtype": null,
"torchscript": false,
"transformers_version": "4.21.0.dev0",
"typical_p": 1.0,
"use_bfloat16": false
},
"vision_config_dict": {
"hidden_size": 1024,
"intermediate_size": 4096,
"num_attention_heads": 16,
"num_hidden_layers": 24,
"patch_size": 14
}
}

View File

@ -0,0 +1,20 @@
{
"crop_size": 224,
"do_center_crop": true,
"do_convert_rgb": true,
"do_normalize": true,
"do_resize": true,
"feature_extractor_type": "CLIPFeatureExtractor",
"image_mean": [
0.48145466,
0.4578275,
0.40821073
],
"image_std": [
0.26862954,
0.26130258,
0.27577711
],
"resample": 3,
"size": 224
}

View File

@ -0,0 +1,126 @@
# from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/safety_checker.py
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import torch
import torch.nn as nn
from transformers import CLIPConfig, CLIPVisionModel, PreTrainedModel
from transformers.utils import logging
logger = logging.get_logger(__name__)
def cosine_distance(image_embeds, text_embeds):
normalized_image_embeds = nn.functional.normalize(image_embeds)
normalized_text_embeds = nn.functional.normalize(text_embeds)
return torch.mm(normalized_image_embeds, normalized_text_embeds.t())
class StableDiffusionSafetyChecker(PreTrainedModel):
config_class = CLIPConfig
main_input_name = "clip_input"
_no_split_modules = ["CLIPEncoderLayer"]
def __init__(self, config: CLIPConfig):
super().__init__(config)
self.vision_model = CLIPVisionModel(config.vision_config)
self.visual_projection = nn.Linear(config.vision_config.hidden_size, config.projection_dim, bias=False)
self.concept_embeds = nn.Parameter(torch.ones(17, config.projection_dim), requires_grad=False)
self.special_care_embeds = nn.Parameter(torch.ones(3, config.projection_dim), requires_grad=False)
self.concept_embeds_weights = nn.Parameter(torch.ones(17), requires_grad=False)
self.special_care_embeds_weights = nn.Parameter(torch.ones(3), requires_grad=False)
@torch.no_grad()
def forward(self, clip_input, images):
pooled_output = self.vision_model(clip_input)[1] # pooled_output
image_embeds = self.visual_projection(pooled_output)
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds).cpu().float().numpy()
cos_dist = cosine_distance(image_embeds, self.concept_embeds).cpu().float().numpy()
result = []
batch_size = image_embeds.shape[0]
for i in range(batch_size):
result_img = {"special_scores": {}, "special_care": [], "concept_scores": {}, "bad_concepts": []}
# increase this value to create a stronger `nfsw` filter
# at the cost of increasing the possibility of filtering benign images
adjustment = 0.0
for concept_idx in range(len(special_cos_dist[0])):
concept_cos = special_cos_dist[i][concept_idx]
concept_threshold = self.special_care_embeds_weights[concept_idx].item()
result_img["special_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3)
if result_img["special_scores"][concept_idx] > 0:
result_img["special_care"].append({concept_idx, result_img["special_scores"][concept_idx]})
adjustment = 0.01
for concept_idx in range(len(cos_dist[0])):
concept_cos = cos_dist[i][concept_idx]
concept_threshold = self.concept_embeds_weights[concept_idx].item()
result_img["concept_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3)
if result_img["concept_scores"][concept_idx] > 0:
result_img["bad_concepts"].append(concept_idx)
result.append(result_img)
has_nsfw_concepts = [len(res["bad_concepts"]) > 0 for res in result]
for idx, has_nsfw_concept in enumerate(has_nsfw_concepts):
if has_nsfw_concept:
if torch.is_tensor(images) or torch.is_tensor(images[0]):
images[idx] = torch.zeros_like(images[idx]) # black image
else:
images[idx] = np.zeros(images[idx].shape) # black image
if any(has_nsfw_concepts):
logger.warning(
"Potential NSFW content was detected in one or more images. A black image will be returned instead."
" Try again with a different prompt and/or seed."
)
return images, has_nsfw_concepts
@torch.no_grad()
def forward_onnx(self, clip_input: torch.Tensor, images: torch.Tensor):
pooled_output = self.vision_model(clip_input)[1] # pooled_output
image_embeds = self.visual_projection(pooled_output)
special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds)
cos_dist = cosine_distance(image_embeds, self.concept_embeds)
# increase this value to create a stronger `nsfw` filter
# at the cost of increasing the possibility of filtering benign images
adjustment = 0.0
special_scores = special_cos_dist - self.special_care_embeds_weights + adjustment
# special_scores = special_scores.round(decimals=3)
special_care = torch.any(special_scores > 0, dim=1)
special_adjustment = special_care * 0.01
special_adjustment = special_adjustment.unsqueeze(1).expand(-1, cos_dist.shape[1])
concept_scores = (cos_dist - self.concept_embeds_weights) + special_adjustment
# concept_scores = concept_scores.round(decimals=3)
has_nsfw_concepts = torch.any(concept_scores > 0, dim=1)
images[has_nsfw_concepts] = 0.0 # black image
return images, has_nsfw_concepts

View File

@ -1,69 +1,85 @@
# https://github.com/city96/SD-Latent-Interposer/blob/main/interposer.py # https://github.com/city96/SD-Latent-Interposer/blob/main/interposer.py
import os import os
import torch
import safetensors.torch as sf
import torch.nn as nn
import ldm_patched.modules.model_management
import safetensors.torch as sf
import torch
import torch.nn as nn
import ldm_patched.modules.model_management
from ldm_patched.modules.model_patcher import ModelPatcher from ldm_patched.modules.model_patcher import ModelPatcher
from modules.config import path_vae_approx from modules.config import path_vae_approx
class Block(nn.Module): class ResBlock(nn.Module):
def __init__(self, size): """Block with residuals"""
def __init__(self, ch):
super().__init__() super().__init__()
self.join = nn.ReLU() self.join = nn.ReLU()
self.norm = nn.BatchNorm2d(ch)
self.long = nn.Sequential( self.long = nn.Sequential(
nn.Conv2d(size, size, kernel_size=3, stride=1, padding=1), nn.Conv2d(ch, ch, kernel_size=3, stride=1, padding=1),
nn.LeakyReLU(0.1), nn.SiLU(),
nn.Conv2d(size, size, kernel_size=3, stride=1, padding=1), nn.Conv2d(ch, ch, kernel_size=3, stride=1, padding=1),
nn.LeakyReLU(0.1), nn.SiLU(),
nn.Conv2d(size, size, kernel_size=3, stride=1, padding=1), nn.Conv2d(ch, ch, kernel_size=3, stride=1, padding=1),
nn.Dropout(0.1)
) )
def forward(self, x): def forward(self, x):
y = self.long(x) x = self.norm(x)
z = self.join(y + x) return self.join(self.long(x) + x)
return z
class Interposer(nn.Module): class ExtractBlock(nn.Module):
def __init__(self): """Increase no. of channels by [out/in]"""
def __init__(self, ch_in, ch_out):
super().__init__() super().__init__()
self.chan = 4 self.join = nn.ReLU()
self.hid = 128 self.short = nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1)
self.long = nn.Sequential(
self.head_join = nn.ReLU() nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1),
self.head_short = nn.Conv2d(self.chan, self.hid, kernel_size=3, stride=1, padding=1) nn.SiLU(),
self.head_long = nn.Sequential( nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1),
nn.Conv2d(self.chan, self.hid, kernel_size=3, stride=1, padding=1), nn.SiLU(),
nn.LeakyReLU(0.1), nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1),
nn.Conv2d(self.hid, self.hid, kernel_size=3, stride=1, padding=1), nn.Dropout(0.1)
nn.LeakyReLU(0.1),
nn.Conv2d(self.hid, self.hid, kernel_size=3, stride=1, padding=1),
)
self.core = nn.Sequential(
Block(self.hid),
Block(self.hid),
Block(self.hid),
)
self.tail = nn.Sequential(
nn.ReLU(),
nn.Conv2d(self.hid, self.chan, kernel_size=3, stride=1, padding=1)
) )
def forward(self, x): def forward(self, x):
y = self.head_join( return self.join(self.long(x) + self.short(x))
self.head_long(x) +
self.head_short(x)
class InterposerModel(nn.Module):
"""Main neural network"""
def __init__(self, ch_in=4, ch_out=4, ch_mid=64, scale=1.0, blocks=12):
super().__init__()
self.ch_in = ch_in
self.ch_out = ch_out
self.ch_mid = ch_mid
self.blocks = blocks
self.scale = scale
self.head = ExtractBlock(self.ch_in, self.ch_mid)
self.core = nn.Sequential(
nn.Upsample(scale_factor=self.scale, mode="nearest"),
*[ResBlock(self.ch_mid) for _ in range(blocks)],
nn.BatchNorm2d(self.ch_mid),
nn.SiLU(),
) )
self.tail = nn.Conv2d(self.ch_mid, self.ch_out, kernel_size=3, stride=1, padding=1)
def forward(self, x):
y = self.head(x)
z = self.core(y) z = self.core(y)
return self.tail(z) return self.tail(z)
vae_approx_model = None vae_approx_model = None
vae_approx_filename = os.path.join(path_vae_approx, 'xl-to-v1_interposer-v3.1.safetensors') vae_approx_filename = os.path.join(path_vae_approx, 'xl-to-v1_interposer-v4.0.safetensors')
def parse(x): def parse(x):
@ -72,7 +88,7 @@ def parse(x):
x_origin = x.clone() x_origin = x.clone()
if vae_approx_model is None: if vae_approx_model is None:
model = Interposer() model = InterposerModel()
model.eval() model.eval()
sd = sf.load_file(vae_approx_filename) sd = sf.load_file(vae_approx_filename)
model.load_state_dict(sd) model.load_state_dict(sd)

View File

@ -1 +1 @@
version = '2.3.1' version = '2.4.0'

View File

@ -122,6 +122,43 @@ document.addEventListener("DOMContentLoaded", function() {
initStylePreviewOverlay(); initStylePreviewOverlay();
}); });
var onAppend = function(elem, f) {
var observer = new MutationObserver(function(mutations) {
mutations.forEach(function(m) {
if (m.addedNodes.length) {
f(m.addedNodes);
}
});
});
observer.observe(elem, {childList: true});
}
function addObserverIfDesiredNodeAvailable(querySelector, callback) {
var elem = document.querySelector(querySelector);
if (!elem) {
window.setTimeout(() => addObserverIfDesiredNodeAvailable(querySelector, callback), 1000);
return;
}
onAppend(elem, callback);
}
/**
* Show reset button on toast "Connection errored out."
*/
addObserverIfDesiredNodeAvailable(".toast-wrap", function(added) {
added.forEach(function(element) {
if (element.innerText.includes("Connection errored out.")) {
window.setTimeout(function() {
document.getElementById("reset_button").classList.remove("hidden");
document.getElementById("generate_button").classList.add("hidden");
document.getElementById("skip_button").classList.add("hidden");
document.getElementById("stop_button").classList.add("hidden");
});
}
});
});
/** /**
* Add a ctrl+enter as a shortcut to start a generation * Add a ctrl+enter as a shortcut to start a generation
*/ */

View File

@ -4,12 +4,20 @@
"Generate": "Generate", "Generate": "Generate",
"Skip": "Skip", "Skip": "Skip",
"Stop": "Stop", "Stop": "Stop",
"Reconnect": "Reconnect",
"Input Image": "Input Image", "Input Image": "Input Image",
"Advanced": "Advanced", "Advanced": "Advanced",
"Upscale or Variation": "Upscale or Variation", "Upscale or Variation": "Upscale or Variation",
"Image Prompt": "Image Prompt", "Image Prompt": "Image Prompt",
"Inpaint or Outpaint (beta)": "Inpaint or Outpaint (beta)", "Inpaint or Outpaint": "Inpaint or Outpaint",
"Drag above image to here": "Drag above image to here", "Outpaint Direction": "Outpaint Direction",
"Method": "Method",
"Describe": "Describe",
"Content Type": "Content Type",
"Photograph": "Photograph",
"Art/Anime": "Art/Anime",
"Describe this Image into Prompt": "Describe this Image into Prompt",
"Image Size and Recommended Size": "Image Size and Recommended Size",
"Upscale or Variation:": "Upscale or Variation:", "Upscale or Variation:": "Upscale or Variation:",
"Disabled": "Disabled", "Disabled": "Disabled",
"Vary (Subtle)": "Vary (Subtle)", "Vary (Subtle)": "Vary (Subtle)",
@ -54,9 +62,12 @@
"Disable seed increment": "Disable seed increment", "Disable seed increment": "Disable seed increment",
"Disable automatic seed increment when image number is > 1.": "Disable automatic seed increment when image number is > 1.", "Disable automatic seed increment when image number is > 1.": "Disable automatic seed increment when image number is > 1.",
"Read wildcards in order": "Read wildcards in order", "Read wildcards in order": "Read wildcards in order",
"Black Out NSFW": "Black Out NSFW",
"Use black image if NSFW is detected.": "Use black image if NSFW is detected.",
"\ud83d\udcda History Log": "\uD83D\uDCDA History Log", "\ud83d\udcda History Log": "\uD83D\uDCDA History Log",
"Image Style": "Image Style", "Image Style": "Image Style",
"Fooocus V2": "Fooocus V2", "Fooocus V2": "Fooocus V2",
"Random Style": "Random Style",
"Default (Slightly Cinematic)": "Default (Slightly Cinematic)", "Default (Slightly Cinematic)": "Default (Slightly Cinematic)",
"Fooocus Masterpiece": "Fooocus Masterpiece", "Fooocus Masterpiece": "Fooocus Masterpiece",
"Fooocus Photograph": "Fooocus Photograph", "Fooocus Photograph": "Fooocus Photograph",
@ -309,6 +320,7 @@
"vae": "vae", "vae": "vae",
"CFG Mimicking from TSNR": "CFG Mimicking from TSNR", "CFG Mimicking from TSNR": "CFG Mimicking from TSNR",
"Enabling Fooocus's implementation of CFG mimicking for TSNR (effective when real CFG > mimicked CFG).": "Enabling Fooocus's implementation of CFG mimicking for TSNR (effective when real CFG > mimicked CFG).", "Enabling Fooocus's implementation of CFG mimicking for TSNR (effective when real CFG > mimicked CFG).": "Enabling Fooocus's implementation of CFG mimicking for TSNR (effective when real CFG > mimicked CFG).",
"CLIP Skip": "CLIP Skip",
"Sampler": "Sampler", "Sampler": "Sampler",
"dpmpp_2m_sde_gpu": "dpmpp_2m_sde_gpu", "dpmpp_2m_sde_gpu": "dpmpp_2m_sde_gpu",
"Only effective in non-inpaint mode.": "Only effective in non-inpaint mode.", "Only effective in non-inpaint mode.": "Only effective in non-inpaint mode.",
@ -339,6 +351,8 @@
"sgm_uniform": "sgm_uniform", "sgm_uniform": "sgm_uniform",
"simple": "simple", "simple": "simple",
"ddim_uniform": "ddim_uniform", "ddim_uniform": "ddim_uniform",
"VAE": "VAE",
"Default (model)": "Default (model)",
"Forced Overwrite of Sampling Step": "Forced Overwrite of Sampling Step", "Forced Overwrite of Sampling Step": "Forced Overwrite of Sampling Step",
"Set as -1 to disable. For developer debugging.": "Set as -1 to disable. For developer debugging.", "Set as -1 to disable. For developer debugging.": "Set as -1 to disable. For developer debugging.",
"Forced Overwrite of Refiner Switch Step": "Forced Overwrite of Refiner Switch Step", "Forced Overwrite of Refiner Switch Step": "Forced Overwrite of Refiner Switch Step",
@ -378,7 +392,7 @@
"Fooocus Enhance": "Fooocus Enhance", "Fooocus Enhance": "Fooocus Enhance",
"Fooocus Cinematic": "Fooocus Cinematic", "Fooocus Cinematic": "Fooocus Cinematic",
"Fooocus Sharp": "Fooocus Sharp", "Fooocus Sharp": "Fooocus Sharp",
"Drag any image generated by Fooocus here": "Drag any image generated by Fooocus here", "For images created by Fooocus": "For images created by Fooocus",
"Metadata": "Metadata", "Metadata": "Metadata",
"Apply Metadata": "Apply Metadata", "Apply Metadata": "Apply Metadata",
"Metadata Scheme": "Metadata Scheme", "Metadata Scheme": "Metadata Scheme",

View File

@ -62,8 +62,8 @@ def prepare_environment():
vae_approx_filenames = [ vae_approx_filenames = [
('xlvaeapp.pth', 'https://huggingface.co/lllyasviel/misc/resolve/main/xlvaeapp.pth'), ('xlvaeapp.pth', 'https://huggingface.co/lllyasviel/misc/resolve/main/xlvaeapp.pth'),
('vaeapp_sd15.pth', 'https://huggingface.co/lllyasviel/misc/resolve/main/vaeapp_sd15.pt'), ('vaeapp_sd15.pth', 'https://huggingface.co/lllyasviel/misc/resolve/main/vaeapp_sd15.pt'),
('xl-to-v1_interposer-v3.1.safetensors', ('xl-to-v1_interposer-v4.0.safetensors',
'https://huggingface.co/lllyasviel/misc/resolve/main/xl-to-v1_interposer-v3.1.safetensors') 'https://huggingface.co/mashb1t/misc/resolve/main/xl-to-v1_interposer-v4.0.safetensors')
] ]
@ -80,6 +80,10 @@ if args.gpu_device_id is not None:
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_device_id) os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_device_id)
print("Set device to:", args.gpu_device_id) print("Set device to:", args.gpu_device_id)
if args.hf_mirror is not None :
os.environ['HF_MIRROR'] = str(args.hf_mirror)
print("Set hf_mirror to:", args.hf_mirror)
from modules import config from modules import config
os.environ['GRADIO_TEMP_DIR'] = config.temp_path os.environ['GRADIO_TEMP_DIR'] = config.temp_path

View File

@ -0,0 +1,55 @@
# https://github.com/comfyanonymous/ComfyUI/blob/master/nodes.py
#from: https://research.nvidia.com/labs/toronto-ai/AlignYourSteps/howto.html
import numpy as np
import torch
def loglinear_interp(t_steps, num_steps):
"""
Performs log-linear interpolation of a given array of decreasing numbers.
"""
xs = np.linspace(0, 1, len(t_steps))
ys = np.log(t_steps[::-1])
new_xs = np.linspace(0, 1, num_steps)
new_ys = np.interp(new_xs, xs, ys)
interped_ys = np.exp(new_ys)[::-1].copy()
return interped_ys
NOISE_LEVELS = {"SD1": [14.6146412293, 6.4745760956, 3.8636745985, 2.6946151520, 1.8841921177, 1.3943805092, 0.9642583904, 0.6523686016, 0.3977456272, 0.1515232662, 0.0291671582],
"SDXL":[14.6146412293, 6.3184485287, 3.7681790315, 2.1811480769, 1.3405244945, 0.8620721141, 0.5550693289, 0.3798540708, 0.2332364134, 0.1114188177, 0.0291671582],
"SVD": [700.00, 54.5, 15.886, 7.977, 4.248, 1.789, 0.981, 0.403, 0.173, 0.034, 0.002]}
class AlignYourStepsScheduler:
@classmethod
def INPUT_TYPES(s):
return {"required":
{"model_type": (["SD1", "SDXL", "SVD"], ),
"steps": ("INT", {"default": 10, "min": 10, "max": 10000}),
"denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
}
}
RETURN_TYPES = ("SIGMAS",)
CATEGORY = "sampling/custom_sampling/schedulers"
FUNCTION = "get_sigmas"
def get_sigmas(self, model_type, steps, denoise):
total_steps = steps
if denoise < 1.0:
if denoise <= 0.0:
return (torch.FloatTensor([]),)
total_steps = round(steps * denoise)
sigmas = NOISE_LEVELS[model_type][:]
if (steps + 1) != len(sigmas):
sigmas = loglinear_interp(sigmas, steps + 1)
sigmas = sigmas[-(total_steps + 1):]
sigmas[-1] = 0
return (torch.FloatTensor(sigmas), )
NODE_CLASS_MAPPINGS = {
"AlignYourStepsScheduler": AlignYourStepsScheduler,
}

View File

@ -230,6 +230,25 @@ class SamplerDPMPP_SDE:
sampler = ldm_patched.modules.samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise, "r": r}) sampler = ldm_patched.modules.samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise, "r": r})
return (sampler, ) return (sampler, )
class SamplerTCD:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"eta": ("FLOAT", {"default": 0.3, "min": 0.0, "max": 1.0, "step": 0.01}),
}
}
RETURN_TYPES = ("SAMPLER",)
CATEGORY = "sampling/custom_sampling/samplers"
FUNCTION = "get_sampler"
def get_sampler(self, eta=0.3):
sampler = ldm_patched.modules.samplers.ksampler("tcd", {"eta": eta})
return (sampler, )
class SamplerCustom: class SamplerCustom:
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
@ -292,6 +311,7 @@ NODE_CLASS_MAPPINGS = {
"KSamplerSelect": KSamplerSelect, "KSamplerSelect": KSamplerSelect,
"SamplerDPMPP_2M_SDE": SamplerDPMPP_2M_SDE, "SamplerDPMPP_2M_SDE": SamplerDPMPP_2M_SDE,
"SamplerDPMPP_SDE": SamplerDPMPP_SDE, "SamplerDPMPP_SDE": SamplerDPMPP_SDE,
"SamplerTCD": SamplerTCD,
"SplitSigmas": SplitSigmas, "SplitSigmas": SplitSigmas,
"FlipSigmas": FlipSigmas, "FlipSigmas": FlipSigmas,
} }

View File

@ -70,7 +70,7 @@ class ModelSamplingDiscrete:
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",), return {"required": { "model": ("MODEL",),
"sampling": (["eps", "v_prediction", "lcm"],), "sampling": (["eps", "v_prediction", "lcm", "tcd"]),
"zsnr": ("BOOLEAN", {"default": False}), "zsnr": ("BOOLEAN", {"default": False}),
}} }}
@ -90,6 +90,9 @@ class ModelSamplingDiscrete:
elif sampling == "lcm": elif sampling == "lcm":
sampling_type = LCM sampling_type = LCM
sampling_base = ModelSamplingDiscreteDistilled sampling_base = ModelSamplingDiscreteDistilled
elif sampling == "tcd":
sampling_type = ldm_patched.modules.model_sampling.EPS
sampling_base = ModelSamplingDiscreteDistilled
class ModelSamplingAdvanced(sampling_base, sampling_type): class ModelSamplingAdvanced(sampling_base, sampling_type):
pass pass

View File

@ -752,7 +752,6 @@ def sample_lcm(model, x, sigmas, extra_args=None, callback=None, disable=None, n
return x return x
@torch.no_grad() @torch.no_grad()
def sample_heunpp2(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.): def sample_heunpp2(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
# From MIT licensed: https://github.com/Carzit/sd-webui-samplers-scheduler/ # From MIT licensed: https://github.com/Carzit/sd-webui-samplers-scheduler/
@ -808,3 +807,30 @@ def sample_heunpp2(model, x, sigmas, extra_args=None, callback=None, disable=Non
d_prime = w1 * d + w2 * d_2 + w3 * d_3 d_prime = w1 * d + w2 * d_2 + w3 * d_3
x = x + d_prime * dt x = x + d_prime * dt
return x return x
@torch.no_grad()
def sample_tcd(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None, eta=0.3):
extra_args = {} if extra_args is None else extra_args
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]])
model_sampling = model.inner_model.inner_model.model_sampling
timesteps_s = torch.floor((1 - eta) * model_sampling.timestep(sigmas)).to(dtype=torch.long).detach().cpu()
timesteps_s[-1] = 0
alpha_prod_s = model_sampling.alphas_cumprod[timesteps_s]
beta_prod_s = 1 - alpha_prod_s
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args) # predicted_original_sample
eps = (x - denoised) / sigmas[i]
denoised = alpha_prod_s[i + 1].sqrt() * denoised + beta_prod_s[i + 1].sqrt() * eps
if callback is not None:
callback({"x": x, "i": i, "sigma": sigmas[i], "sigma_hat": sigmas[i], "denoised": denoised})
x = denoised
if eta > 0 and sigmas[i + 1] > 0:
noise = noise_sampler(sigmas[i], sigmas[i + 1])
x = x / alpha_prod_s[i+1].sqrt() + noise * (sigmas[i+1]**2 + 1 - 1/alpha_prod_s[i+1]).sqrt()
return x

View File

@ -37,6 +37,7 @@ parser.add_argument("--listen", type=str, default="127.0.0.1", metavar="IP", nar
parser.add_argument("--port", type=int, default=8188) parser.add_argument("--port", type=int, default=8188)
parser.add_argument("--disable-header-check", type=str, default=None, metavar="ORIGIN", nargs="?", const="*") parser.add_argument("--disable-header-check", type=str, default=None, metavar="ORIGIN", nargs="?", const="*")
parser.add_argument("--web-upload-size", type=float, default=100) parser.add_argument("--web-upload-size", type=float, default=100)
parser.add_argument("--hf-mirror", type=str, default=None)
parser.add_argument("--external-working-path", type=str, default=None, metavar="PATH", nargs='+', action='append') parser.add_argument("--external-working-path", type=str, default=None, metavar="PATH", nargs='+', action='append')
parser.add_argument("--output-path", type=str, default=None) parser.add_argument("--output-path", type=str, default=None)

View File

@ -50,17 +50,17 @@ class ModelSamplingDiscrete(torch.nn.Module):
self.linear_start = linear_start self.linear_start = linear_start
self.linear_end = linear_end self.linear_end = linear_end
# self.register_buffer('betas', torch.tensor(betas, dtype=torch.float32))
# self.register_buffer('alphas_cumprod', torch.tensor(alphas_cumprod, dtype=torch.float32))
# self.register_buffer('alphas_cumprod_prev', torch.tensor(alphas_cumprod_prev, dtype=torch.float32))
sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5 sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
self.set_sigmas(sigmas) self.set_sigmas(sigmas)
self.set_alphas_cumprod(alphas_cumprod.float())
def set_sigmas(self, sigmas): def set_sigmas(self, sigmas):
self.register_buffer('sigmas', sigmas) self.register_buffer('sigmas', sigmas)
self.register_buffer('log_sigmas', sigmas.log()) self.register_buffer('log_sigmas', sigmas.log())
def set_alphas_cumprod(self, alphas_cumprod):
self.register_buffer("alphas_cumprod", alphas_cumprod.float())
@property @property
def sigma_min(self): def sigma_min(self):
return self.sigmas[0] return self.sigmas[0]

View File

@ -523,7 +523,7 @@ class UNIPCBH2(Sampler):
KSAMPLER_NAMES = ["euler", "euler_ancestral", "heun", "heunpp2","dpm_2", "dpm_2_ancestral", KSAMPLER_NAMES = ["euler", "euler_ancestral", "heun", "heunpp2","dpm_2", "dpm_2_ancestral",
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu", "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu",
"dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm"] "dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm", "tcd"]
class KSAMPLER(Sampler): class KSAMPLER(Sampler):
def __init__(self, sampler_function, extra_options={}, inpaint_options={}): def __init__(self, sampler_function, extra_options={}, inpaint_options={}):

View File

@ -427,12 +427,13 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
return (ldm_patched.modules.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device), clip, vae) return (ldm_patched.modules.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device), clip, vae)
def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True): def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, vae_filename_param=None):
sd = ldm_patched.modules.utils.load_torch_file(ckpt_path) sd = ldm_patched.modules.utils.load_torch_file(ckpt_path)
sd_keys = sd.keys() sd_keys = sd.keys()
clip = None clip = None
clipvision = None clipvision = None
vae = None vae = None
vae_filename = None
model = None model = None
model_patcher = None model_patcher = None
clip_target = None clip_target = None
@ -462,8 +463,12 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
model.load_model_weights(sd, "model.diffusion_model.") model.load_model_weights(sd, "model.diffusion_model.")
if output_vae: if output_vae:
vae_sd = ldm_patched.modules.utils.state_dict_prefix_replace(sd, {"first_stage_model.": ""}, filter_keys=True) if vae_filename_param is None:
vae_sd = model_config.process_vae_state_dict(vae_sd) vae_sd = ldm_patched.modules.utils.state_dict_prefix_replace(sd, {"first_stage_model.": ""}, filter_keys=True)
vae_sd = model_config.process_vae_state_dict(vae_sd)
else:
vae_sd = ldm_patched.modules.utils.load_torch_file(vae_filename_param)
vae_filename = vae_filename_param
vae = VAE(sd=vae_sd) vae = VAE(sd=vae_sd)
if output_clip: if output_clip:
@ -485,7 +490,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
print("loaded straight to GPU") print("loaded straight to GPU")
model_management.load_model_gpu(model_patcher) model_management.load_model_gpu(model_patcher)
return (model_patcher, clip, vae, clipvision) return model_patcher, clip, vae, vae_filename, clipvision
def load_unet_state_dict(sd): #load unet in diffusers format def load_unet_state_dict(sd): #load unet in diffusers format

0
modules/__init__.py Normal file
View File

View File

@ -4,6 +4,7 @@ from modules.patch import PatchSettings, patch_settings, patch_all
patch_all() patch_all()
class AsyncTask: class AsyncTask:
def __init__(self, args): def __init__(self, args):
self.args = args self.args = args
@ -43,11 +44,13 @@ def worker():
import fooocus_version import fooocus_version
import args_manager import args_manager
from modules.sdxl_styles import apply_style, apply_wildcards, fooocus_expansion, apply_arrays from extras.censor import default_censor
from modules.sdxl_styles import apply_style, get_random_style, fooocus_expansion, apply_arrays, random_style_name
from modules.private_logger import log from modules.private_logger import log
from extras.expansion import safe_str from extras.expansion import safe_str
from modules.util import remove_empty_str, HWC3, resize_image, get_image_shape_ceil, set_image_shape_ceil, \ from modules.util import (remove_empty_str, HWC3, resize_image, get_image_shape_ceil, set_image_shape_ceil,
get_shape_ceil, resample_image, erode_or_dilate, ordinal_suffix, get_enabled_loras get_shape_ceil, resample_image, erode_or_dilate, get_enabled_loras,
parse_lora_references_from_prompt, apply_wildcards)
from modules.upscaler import perform_upscale from modules.upscaler import perform_upscale
from modules.flags import Performance from modules.flags import Performance
from modules.meta_parser import get_metadata_parser, MetadataScheme from modules.meta_parser import get_metadata_parser, MetadataScheme
@ -68,10 +71,15 @@ def worker():
print(f'[Fooocus] {text}') print(f'[Fooocus] {text}')
async_task.yields.append(['preview', (number, text, None)]) async_task.yields.append(['preview', (number, text, None)])
def yield_result(async_task, imgs, do_not_show_finished_images=False): def yield_result(async_task, imgs, black_out_nsfw, censor=True, do_not_show_finished_images=False,
progressbar_index=flags.preparation_step_count):
if not isinstance(imgs, list): if not isinstance(imgs, list):
imgs = [imgs] imgs = [imgs]
if censor and (modules.config.default_black_out_nsfw or black_out_nsfw):
progressbar(async_task, progressbar_index, 'Checking for NSFW content ...')
imgs = default_censor(imgs)
async_task.results = async_task.results + imgs async_task.results = async_task.results + imgs
if do_not_show_finished_images: if do_not_show_finished_images:
@ -147,7 +155,8 @@ def worker():
base_model_name = args.pop() base_model_name = args.pop()
refiner_model_name = args.pop() refiner_model_name = args.pop()
refiner_switch = args.pop() refiner_switch = args.pop()
loras = get_enabled_loras([[bool(args.pop()), str(args.pop()), float(args.pop())] for _ in range(modules.config.default_max_lora_number)]) loras = get_enabled_loras([(bool(args.pop()), str(args.pop()), float(args.pop())) for _ in
range(modules.config.default_max_lora_number)])
input_image_checkbox = args.pop() input_image_checkbox = args.pop()
current_tab = args.pop() current_tab = args.pop()
uov_method = args.pop() uov_method = args.pop()
@ -160,12 +169,15 @@ def worker():
disable_preview = args.pop() disable_preview = args.pop()
disable_intermediate_results = args.pop() disable_intermediate_results = args.pop()
disable_seed_increment = args.pop() disable_seed_increment = args.pop()
black_out_nsfw = args.pop()
adm_scaler_positive = args.pop() adm_scaler_positive = args.pop()
adm_scaler_negative = args.pop() adm_scaler_negative = args.pop()
adm_scaler_end = args.pop() adm_scaler_end = args.pop()
adaptive_cfg = args.pop() adaptive_cfg = args.pop()
clip_skip = args.pop()
sampler_name = args.pop() sampler_name = args.pop()
scheduler_name = args.pop() scheduler_name = args.pop()
vae_name = args.pop()
overwrite_step = args.pop() overwrite_step = args.pop()
overwrite_switch = args.pop() overwrite_switch = args.pop()
overwrite_width = args.pop() overwrite_width = args.pop()
@ -195,7 +207,8 @@ def worker():
inpaint_erode_or_dilate = args.pop() inpaint_erode_or_dilate = args.pop()
save_metadata_to_images = args.pop() if not args_manager.args.disable_metadata else False save_metadata_to_images = args.pop() if not args_manager.args.disable_metadata else False
metadata_scheme = MetadataScheme(args.pop()) if not args_manager.args.disable_metadata else MetadataScheme.FOOOCUS metadata_scheme = MetadataScheme(
args.pop()) if not args_manager.args.disable_metadata else MetadataScheme.FOOOCUS
cn_tasks = {x: [] for x in flags.ip_list} cn_tasks = {x: [] for x in flags.ip_list}
for _ in range(flags.controlnet_image_count): for _ in range(flags.controlnet_image_count):
@ -225,10 +238,12 @@ def worker():
steps = performance_selection.steps() steps = performance_selection.steps()
performance_loras = []
if performance_selection == Performance.EXTREME_SPEED: if performance_selection == Performance.EXTREME_SPEED:
print('Enter LCM mode.') print('Enter LCM mode.')
progressbar(async_task, 1, 'Downloading LCM components ...') progressbar(async_task, 1, 'Downloading LCM components ...')
loras += [(modules.config.downloading_sdxl_lcm_lora(), 1.0)] performance_loras += [(modules.config.downloading_sdxl_lcm_lora(), 1.0)]
if refiner_model_name != 'None': if refiner_model_name != 'None':
print(f'Refiner disabled in LCM mode.') print(f'Refiner disabled in LCM mode.')
@ -247,7 +262,7 @@ def worker():
elif performance_selection == Performance.LIGHTNING: elif performance_selection == Performance.LIGHTNING:
print('Enter Lightning mode.') print('Enter Lightning mode.')
progressbar(async_task, 1, 'Downloading Lightning components ...') progressbar(async_task, 1, 'Downloading Lightning components ...')
loras += [(modules.config.downloading_sdxl_lightning_lora(), 1.0)] performance_loras += [(modules.config.downloading_sdxl_lightning_lora(), 1.0)]
if refiner_model_name != 'None': if refiner_model_name != 'None':
print(f'Refiner disabled in Lightning mode.') print(f'Refiner disabled in Lightning mode.')
@ -263,7 +278,27 @@ def worker():
adm_scaler_negative = 1.0 adm_scaler_negative = 1.0
adm_scaler_end = 0.0 adm_scaler_end = 0.0
elif performance_selection == Performance.HYPER_SD:
print('Enter Hyper-SD mode.')
progressbar(async_task, 1, 'Downloading Hyper-SD components ...')
performance_loras += [(modules.config.downloading_sdxl_hyper_sd_lora(), 0.8)]
if refiner_model_name != 'None':
print(f'Refiner disabled in Hyper-SD mode.')
refiner_model_name = 'None'
sampler_name = 'dpmpp_sde_gpu'
scheduler_name = 'karras'
sharpness = 0.0
guidance_scale = 1.0
adaptive_cfg = 1.0
refiner_switch = 1.0
adm_scaler_positive = 1.0
adm_scaler_negative = 1.0
adm_scaler_end = 0.0
print(f'[Parameters] Adaptive CFG = {adaptive_cfg}') print(f'[Parameters] Adaptive CFG = {adaptive_cfg}')
print(f'[Parameters] CLIP Skip = {clip_skip}')
print(f'[Parameters] Sharpness = {sharpness}') print(f'[Parameters] Sharpness = {sharpness}')
print(f'[Parameters] ControlNet Softness = {controlnet_softness}') print(f'[Parameters] ControlNet Softness = {controlnet_softness}')
print(f'[Parameters] ADM Scale = ' print(f'[Parameters] ADM Scale = '
@ -425,14 +460,19 @@ def worker():
extra_positive_prompts = prompts[1:] if len(prompts) > 1 else [] extra_positive_prompts = prompts[1:] if len(prompts) > 1 else []
extra_negative_prompts = negative_prompts[1:] if len(negative_prompts) > 1 else [] extra_negative_prompts = negative_prompts[1:] if len(negative_prompts) > 1 else []
progressbar(async_task, 3, 'Loading models ...') progressbar(async_task, 2, 'Loading models ...')
loras, prompt = parse_lora_references_from_prompt(prompt, loras, modules.config.default_max_lora_number)
loras += performance_loras
pipeline.refresh_everything(refiner_model_name=refiner_model_name, base_model_name=base_model_name, pipeline.refresh_everything(refiner_model_name=refiner_model_name, base_model_name=base_model_name,
loras=loras, base_model_additional_loras=base_model_additional_loras, loras=loras, base_model_additional_loras=base_model_additional_loras,
use_synthetic_refiner=use_synthetic_refiner) use_synthetic_refiner=use_synthetic_refiner, vae_name=vae_name)
pipeline.set_clip_skip(clip_skip)
progressbar(async_task, 3, 'Processing prompts ...') progressbar(async_task, 3, 'Processing prompts ...')
tasks = [] tasks = []
for i in range(image_number): for i in range(image_number):
if disable_seed_increment: if disable_seed_increment:
task_seed = seed % (constants.MAX_SEED + 1) task_seed = seed % (constants.MAX_SEED + 1)
@ -443,14 +483,20 @@ def worker():
task_prompt = apply_wildcards(prompt, task_rng, i, read_wildcards_in_order) task_prompt = apply_wildcards(prompt, task_rng, i, read_wildcards_in_order)
task_prompt = apply_arrays(task_prompt, i) task_prompt = apply_arrays(task_prompt, i)
task_negative_prompt = apply_wildcards(negative_prompt, task_rng, i, read_wildcards_in_order) task_negative_prompt = apply_wildcards(negative_prompt, task_rng, i, read_wildcards_in_order)
task_extra_positive_prompts = [apply_wildcards(pmt, task_rng, i, read_wildcards_in_order) for pmt in extra_positive_prompts] task_extra_positive_prompts = [apply_wildcards(pmt, task_rng, i, read_wildcards_in_order) for pmt in
task_extra_negative_prompts = [apply_wildcards(pmt, task_rng, i, read_wildcards_in_order) for pmt in extra_negative_prompts] extra_positive_prompts]
task_extra_negative_prompts = [apply_wildcards(pmt, task_rng, i, read_wildcards_in_order) for pmt in
extra_negative_prompts]
positive_basic_workloads = [] positive_basic_workloads = []
negative_basic_workloads = [] negative_basic_workloads = []
task_styles = style_selections.copy()
if use_style: if use_style:
for s in style_selections: for i, s in enumerate(task_styles):
if s == random_style_name:
s = get_random_style(task_rng)
task_styles[i] = s
p, n = apply_style(s, positive=task_prompt) p, n = apply_style(s, positive=task_prompt)
positive_basic_workloads = positive_basic_workloads + p positive_basic_workloads = positive_basic_workloads + p
negative_basic_workloads = negative_basic_workloads + n negative_basic_workloads = negative_basic_workloads + n
@ -478,29 +524,30 @@ def worker():
negative_top_k=len(negative_basic_workloads), negative_top_k=len(negative_basic_workloads),
log_positive_prompt='\n'.join([task_prompt] + task_extra_positive_prompts), log_positive_prompt='\n'.join([task_prompt] + task_extra_positive_prompts),
log_negative_prompt='\n'.join([task_negative_prompt] + task_extra_negative_prompts), log_negative_prompt='\n'.join([task_negative_prompt] + task_extra_negative_prompts),
styles=task_styles
)) ))
if use_expansion: if use_expansion:
for i, t in enumerate(tasks): for i, t in enumerate(tasks):
progressbar(async_task, 5, f'Preparing Fooocus text #{i + 1} ...') progressbar(async_task, 4, f'Preparing Fooocus text #{i + 1} ...')
expansion = pipeline.final_expansion(t['task_prompt'], t['task_seed']) expansion = pipeline.final_expansion(t['task_prompt'], t['task_seed'])
print(f'[Prompt Expansion] {expansion}') print(f'[Prompt Expansion] {expansion}')
t['expansion'] = expansion t['expansion'] = expansion
t['positive'] = copy.deepcopy(t['positive']) + [expansion] # Deep copy. t['positive'] = copy.deepcopy(t['positive']) + [expansion] # Deep copy.
for i, t in enumerate(tasks): for i, t in enumerate(tasks):
progressbar(async_task, 7, f'Encoding positive #{i + 1} ...') progressbar(async_task, 5, f'Encoding positive #{i + 1} ...')
t['c'] = pipeline.clip_encode(texts=t['positive'], pool_top_k=t['positive_top_k']) t['c'] = pipeline.clip_encode(texts=t['positive'], pool_top_k=t['positive_top_k'])
for i, t in enumerate(tasks): for i, t in enumerate(tasks):
if abs(float(cfg_scale) - 1.0) < 1e-4: if abs(float(cfg_scale) - 1.0) < 1e-4:
t['uc'] = pipeline.clone_cond(t['c']) t['uc'] = pipeline.clone_cond(t['c'])
else: else:
progressbar(async_task, 10, f'Encoding negative #{i + 1} ...') progressbar(async_task, 6, f'Encoding negative #{i + 1} ...')
t['uc'] = pipeline.clip_encode(texts=t['negative'], pool_top_k=t['negative_top_k']) t['uc'] = pipeline.clip_encode(texts=t['negative'], pool_top_k=t['negative_top_k'])
if len(goals) > 0: if len(goals) > 0:
progressbar(async_task, 13, 'Image processing ...') progressbar(async_task, 7, 'Image processing ...')
if 'vary' in goals: if 'vary' in goals:
if 'subtle' in uov_method: if 'subtle' in uov_method:
@ -521,7 +568,7 @@ def worker():
uov_input_image = set_image_shape_ceil(uov_input_image, shape_ceil) uov_input_image = set_image_shape_ceil(uov_input_image, shape_ceil)
initial_pixels = core.numpy_to_pytorch(uov_input_image) initial_pixels = core.numpy_to_pytorch(uov_input_image)
progressbar(async_task, 13, 'VAE encoding ...') progressbar(async_task, 8, 'VAE encoding ...')
candidate_vae, _ = pipeline.get_candidate_vae( candidate_vae, _ = pipeline.get_candidate_vae(
steps=steps, steps=steps,
@ -538,7 +585,7 @@ def worker():
if 'upscale' in goals: if 'upscale' in goals:
H, W, C = uov_input_image.shape H, W, C = uov_input_image.shape
progressbar(async_task, 13, f'Upscaling image from {str((H, W))} ...') progressbar(async_task, 9, f'Upscaling image from {str((H, W))} ...')
uov_input_image = perform_upscale(uov_input_image) uov_input_image = perform_upscale(uov_input_image)
print(f'Image upscaled.') print(f'Image upscaled.')
@ -572,8 +619,12 @@ def worker():
if direct_return: if direct_return:
d = [('Upscale (Fast)', 'upscale_fast', '2x')] d = [('Upscale (Fast)', 'upscale_fast', '2x')]
if modules.config.default_black_out_nsfw or black_out_nsfw:
progressbar(async_task, 100, 'Checking for NSFW content ...')
uov_input_image = default_censor(uov_input_image)
progressbar(async_task, 100, 'Saving image to system ...')
uov_input_image_path = log(uov_input_image, d, output_format=output_format) uov_input_image_path = log(uov_input_image, d, output_format=output_format)
yield_result(async_task, uov_input_image_path, do_not_show_finished_images=True) yield_result(async_task, uov_input_image_path, black_out_nsfw, False, do_not_show_finished_images=True)
return return
tiled = True tiled = True
@ -583,7 +634,7 @@ def worker():
denoising_strength = overwrite_upscale_strength denoising_strength = overwrite_upscale_strength
initial_pixels = core.numpy_to_pytorch(uov_input_image) initial_pixels = core.numpy_to_pytorch(uov_input_image)
progressbar(async_task, 13, 'VAE encoding ...') progressbar(async_task, 10, 'VAE encoding ...')
candidate_vae, _ = pipeline.get_candidate_vae( candidate_vae, _ = pipeline.get_candidate_vae(
steps=steps, steps=steps,
@ -637,11 +688,11 @@ def worker():
) )
if debugging_inpaint_preprocessor: if debugging_inpaint_preprocessor:
yield_result(async_task, inpaint_worker.current_task.visualize_mask_processing(), yield_result(async_task, inpaint_worker.current_task.visualize_mask_processing(), black_out_nsfw,
do_not_show_finished_images=True) do_not_show_finished_images=True)
return return
progressbar(async_task, 13, 'VAE Inpaint encoding ...') progressbar(async_task, 11, 'VAE Inpaint encoding ...')
inpaint_pixel_fill = core.numpy_to_pytorch(inpaint_worker.current_task.interested_fill) inpaint_pixel_fill = core.numpy_to_pytorch(inpaint_worker.current_task.interested_fill)
inpaint_pixel_image = core.numpy_to_pytorch(inpaint_worker.current_task.interested_image) inpaint_pixel_image = core.numpy_to_pytorch(inpaint_worker.current_task.interested_image)
@ -661,7 +712,7 @@ def worker():
latent_swap = None latent_swap = None
if candidate_vae_swap is not None: if candidate_vae_swap is not None:
progressbar(async_task, 13, 'VAE SD15 encoding ...') progressbar(async_task, 12, 'VAE SD15 encoding ...')
latent_swap = core.encode_vae( latent_swap = core.encode_vae(
vae=candidate_vae_swap, vae=candidate_vae_swap,
pixels=inpaint_pixel_fill)['samples'] pixels=inpaint_pixel_fill)['samples']
@ -701,7 +752,7 @@ def worker():
cn_img = HWC3(cn_img) cn_img = HWC3(cn_img)
task[0] = core.numpy_to_pytorch(cn_img) task[0] = core.numpy_to_pytorch(cn_img)
if debugging_cn_preprocessor: if debugging_cn_preprocessor:
yield_result(async_task, cn_img, do_not_show_finished_images=True) yield_result(async_task, cn_img, black_out_nsfw, do_not_show_finished_images=True)
return return
for task in cn_tasks[flags.cn_cpds]: for task in cn_tasks[flags.cn_cpds]:
cn_img, cn_stop, cn_weight = task cn_img, cn_stop, cn_weight = task
@ -713,7 +764,7 @@ def worker():
cn_img = HWC3(cn_img) cn_img = HWC3(cn_img)
task[0] = core.numpy_to_pytorch(cn_img) task[0] = core.numpy_to_pytorch(cn_img)
if debugging_cn_preprocessor: if debugging_cn_preprocessor:
yield_result(async_task, cn_img, do_not_show_finished_images=True) yield_result(async_task, cn_img, black_out_nsfw, do_not_show_finished_images=True)
return return
for task in cn_tasks[flags.cn_ip]: for task in cn_tasks[flags.cn_ip]:
cn_img, cn_stop, cn_weight = task cn_img, cn_stop, cn_weight = task
@ -724,7 +775,7 @@ def worker():
task[0] = ip_adapter.preprocess(cn_img, ip_adapter_path=ip_adapter_path) task[0] = ip_adapter.preprocess(cn_img, ip_adapter_path=ip_adapter_path)
if debugging_cn_preprocessor: if debugging_cn_preprocessor:
yield_result(async_task, cn_img, do_not_show_finished_images=True) yield_result(async_task, cn_img, black_out_nsfw, do_not_show_finished_images=True)
return return
for task in cn_tasks[flags.cn_ip_face]: for task in cn_tasks[flags.cn_ip_face]:
cn_img, cn_stop, cn_weight = task cn_img, cn_stop, cn_weight = task
@ -738,7 +789,7 @@ def worker():
task[0] = ip_adapter.preprocess(cn_img, ip_adapter_path=ip_adapter_face_path) task[0] = ip_adapter.preprocess(cn_img, ip_adapter_path=ip_adapter_face_path)
if debugging_cn_preprocessor: if debugging_cn_preprocessor:
yield_result(async_task, cn_img, do_not_show_finished_images=True) yield_result(async_task, cn_img, black_out_nsfw, do_not_show_finished_images=True)
return return
all_ip_tasks = cn_tasks[flags.cn_ip] + cn_tasks[flags.cn_ip_face] all_ip_tasks = cn_tasks[flags.cn_ip] + cn_tasks[flags.cn_ip_face]
@ -773,29 +824,31 @@ def worker():
final_sampler_name = sampler_name final_sampler_name = sampler_name
final_scheduler_name = scheduler_name final_scheduler_name = scheduler_name
if scheduler_name == 'lcm': if scheduler_name in ['lcm', 'tcd']:
final_scheduler_name = 'sgm_uniform' final_scheduler_name = 'sgm_uniform'
if pipeline.final_unet is not None: if pipeline.final_unet is not None:
pipeline.final_unet = core.opModelSamplingDiscrete.patch( pipeline.final_unet = core.opModelSamplingDiscrete.patch(
pipeline.final_unet, pipeline.final_unet,
sampling='lcm', sampling=scheduler_name,
zsnr=False)[0] zsnr=False)[0]
if pipeline.final_refiner_unet is not None: if pipeline.final_refiner_unet is not None:
pipeline.final_refiner_unet = core.opModelSamplingDiscrete.patch( pipeline.final_refiner_unet = core.opModelSamplingDiscrete.patch(
pipeline.final_refiner_unet, pipeline.final_refiner_unet,
sampling='lcm', sampling=scheduler_name,
zsnr=False)[0] zsnr=False)[0]
print('Using lcm scheduler.') print(f'Using {scheduler_name} scheduler.')
async_task.yields.append(['preview', (13, 'Moving model to GPU ...', None)]) async_task.yields.append(['preview', (flags.preparation_step_count, 'Moving model to GPU ...', None)])
def callback(step, x0, x, total_steps, y): def callback(step, x0, x, total_steps, y):
done_steps = current_task_id * steps + step done_steps = current_task_id * steps + step
async_task.yields.append(['preview', ( async_task.yields.append(['preview', (
int(15.0 + 85.0 * float(done_steps) / float(all_steps)), int(flags.preparation_step_count + (100 - flags.preparation_step_count) * float(done_steps) / float(all_steps)),
f'Step {step}/{total_steps} in the {current_task_id + 1}{ordinal_suffix(current_task_id + 1)} Sampling', y)]) f'Sampling step {step + 1}/{total_steps}, image {current_task_id + 1}/{image_number} ...', y)])
for current_task_id, task in enumerate(tasks): for current_task_id, task in enumerate(tasks):
current_progress = int(flags.preparation_step_count + (100 - flags.preparation_step_count) * float(current_task_id * steps) / float(all_steps))
progressbar(async_task, current_progress, f'Preparing task {current_task_id + 1}/{image_number} ...')
execution_start_time = time.perf_counter() execution_start_time = time.perf_counter()
try: try:
@ -838,11 +891,18 @@ def worker():
imgs = [inpaint_worker.current_task.post_process(x) for x in imgs] imgs = [inpaint_worker.current_task.post_process(x) for x in imgs]
img_paths = [] img_paths = []
current_progress = int(flags.preparation_step_count + (100 - flags.preparation_step_count) * float((current_task_id + 1) * steps) / float(all_steps))
if modules.config.default_black_out_nsfw or black_out_nsfw:
progressbar(async_task, current_progress, 'Checking for NSFW content ...')
imgs = default_censor(imgs)
progressbar(async_task, current_progress, f'Saving image {current_task_id + 1}/{image_number} to system ...')
for x in imgs: for x in imgs:
d = [('Prompt', 'prompt', task['log_positive_prompt']), d = [('Prompt', 'prompt', task['log_positive_prompt']),
('Negative Prompt', 'negative_prompt', task['log_negative_prompt']), ('Negative Prompt', 'negative_prompt', task['log_negative_prompt']),
('Fooocus V2 Expansion', 'prompt_expansion', task['expansion']), ('Fooocus V2 Expansion', 'prompt_expansion', task['expansion']),
('Styles', 'styles', str(raw_style_selections)), ('Styles', 'styles',
str(task['styles'] if not use_expansion else [fooocus_expansion] + task['styles'])),
('Performance', 'performance', performance_selection.value)] ('Performance', 'performance', performance_selection.value)]
if performance_selection.steps() != steps: if performance_selection.steps() != steps:
@ -865,10 +925,14 @@ def worker():
if refiner_swap_method != flags.refiner_swap_method: if refiner_swap_method != flags.refiner_swap_method:
d.append(('Refiner Swap Method', 'refiner_swap_method', refiner_swap_method)) d.append(('Refiner Swap Method', 'refiner_swap_method', refiner_swap_method))
if modules.patch.patch_settings[pid].adaptive_cfg != modules.config.default_cfg_tsnr: if modules.patch.patch_settings[pid].adaptive_cfg != modules.config.default_cfg_tsnr:
d.append(('CFG Mimicking from TSNR', 'adaptive_cfg', modules.patch.patch_settings[pid].adaptive_cfg)) d.append(
('CFG Mimicking from TSNR', 'adaptive_cfg', modules.patch.patch_settings[pid].adaptive_cfg))
if clip_skip > 1:
d.append(('CLIP Skip', 'clip_skip', clip_skip))
d.append(('Sampler', 'sampler', sampler_name)) d.append(('Sampler', 'sampler', sampler_name))
d.append(('Scheduler', 'scheduler', scheduler_name)) d.append(('Scheduler', 'scheduler', scheduler_name))
d.append(('VAE', 'vae', vae_name))
d.append(('Seed', 'seed', str(task['task_seed']))) d.append(('Seed', 'seed', str(task['task_seed'])))
if freeu_enabled: if freeu_enabled:
@ -883,12 +947,14 @@ def worker():
metadata_parser = modules.meta_parser.get_metadata_parser(metadata_scheme) metadata_parser = modules.meta_parser.get_metadata_parser(metadata_scheme)
metadata_parser.set_data(task['log_positive_prompt'], task['positive'], metadata_parser.set_data(task['log_positive_prompt'], task['positive'],
task['log_negative_prompt'], task['negative'], task['log_negative_prompt'], task['negative'],
steps, base_model_name, refiner_model_name, loras) steps, base_model_name, refiner_model_name, loras, vae_name)
d.append(('Metadata Scheme', 'metadata_scheme', metadata_scheme.value if save_metadata_to_images else save_metadata_to_images)) d.append(('Metadata Scheme', 'metadata_scheme',
metadata_scheme.value if save_metadata_to_images else save_metadata_to_images))
d.append(('Version', 'version', 'Fooocus v' + fooocus_version.version)) d.append(('Version', 'version', 'Fooocus v' + fooocus_version.version))
img_paths.append(log(x, d, metadata_parser, output_format)) img_paths.append(log(x, d, metadata_parser, output_format, task))
yield_result(async_task, img_paths, do_not_show_finished_images=len(tasks) == 1 or disable_intermediate_results) yield_result(async_task, img_paths, black_out_nsfw, False,
do_not_show_finished_images=len(tasks) == 1 or disable_intermediate_results)
except ldm_patched.modules.model_management.InterruptProcessingException as e: except ldm_patched.modules.model_management.InterruptProcessingException as e:
if async_task.last_stop == 'skip': if async_task.last_stop == 'skip':
print('User skipped') print('User skipped')

View File

@ -8,7 +8,7 @@ import modules.flags
import modules.sdxl_styles import modules.sdxl_styles
from modules.model_loader import load_file_from_url from modules.model_loader import load_file_from_url
from modules.util import get_files_from_folder, makedirs_with_log from modules.extra_utils import makedirs_with_log, get_files_from_folder
from modules.flags import OutputFormat, Performance, MetadataScheme from modules.flags import OutputFormat, Performance, MetadataScheme
@ -20,7 +20,7 @@ def get_config_path(key, default_value):
else: else:
return os.path.abspath(default_value) return os.path.abspath(default_value)
wildcards_max_bfs_depth = 64
config_path = get_config_path('config_path', "./config.txt") config_path = get_config_path('config_path', "./config.txt")
config_example_path = get_config_path('config_example_path', "config_modification_tutorial.txt") config_example_path = get_config_path('config_example_path', "config_modification_tutorial.txt")
config_dict = {} config_dict = {}
@ -189,12 +189,14 @@ paths_checkpoints = get_dir_or_set_default('path_checkpoints', ['../models/check
paths_loras = get_dir_or_set_default('path_loras', ['../models/loras/'], True) paths_loras = get_dir_or_set_default('path_loras', ['../models/loras/'], True)
path_embeddings = get_dir_or_set_default('path_embeddings', '../models/embeddings/') path_embeddings = get_dir_or_set_default('path_embeddings', '../models/embeddings/')
path_vae_approx = get_dir_or_set_default('path_vae_approx', '../models/vae_approx/') path_vae_approx = get_dir_or_set_default('path_vae_approx', '../models/vae_approx/')
path_vae = get_dir_or_set_default('path_vae', '../models/vae/')
path_upscale_models = get_dir_or_set_default('path_upscale_models', '../models/upscale_models/') path_upscale_models = get_dir_or_set_default('path_upscale_models', '../models/upscale_models/')
path_inpaint = get_dir_or_set_default('path_inpaint', '../models/inpaint/') path_inpaint = get_dir_or_set_default('path_inpaint', '../models/inpaint/')
path_controlnet = get_dir_or_set_default('path_controlnet', '../models/controlnet/') path_controlnet = get_dir_or_set_default('path_controlnet', '../models/controlnet/')
path_clip_vision = get_dir_or_set_default('path_clip_vision', '../models/clip_vision/') path_clip_vision = get_dir_or_set_default('path_clip_vision', '../models/clip_vision/')
path_fooocus_expansion = get_dir_or_set_default('path_fooocus_expansion', '../models/prompt_expansion/fooocus_expansion') path_fooocus_expansion = get_dir_or_set_default('path_fooocus_expansion', '../models/prompt_expansion/fooocus_expansion')
path_wildcards = get_dir_or_set_default('path_wildcards', '../wildcards/') path_wildcards = get_dir_or_set_default('path_wildcards', '../wildcards/')
path_safety_checker = get_dir_or_set_default('path_safety_checker', '../models/safety_checker/')
path_outputs = get_path_output() path_outputs = get_path_output()
@ -346,6 +348,11 @@ default_scheduler = get_config_item_or_set_default(
default_value='karras', default_value='karras',
validator=lambda x: x in modules.flags.scheduler_list validator=lambda x: x in modules.flags.scheduler_list
) )
default_vae = get_config_item_or_set_default(
key='default_vae',
default_value=modules.flags.default_vae,
validator=lambda x: isinstance(x, str)
)
default_styles = get_config_item_or_set_default( default_styles = get_config_item_or_set_default(
key='default_styles', key='default_styles',
default_value=[ default_value=[
@ -409,13 +416,7 @@ embeddings_downloads = get_config_item_or_set_default(
) )
available_aspect_ratios = get_config_item_or_set_default( available_aspect_ratios = get_config_item_or_set_default(
key='available_aspect_ratios', key='available_aspect_ratios',
default_value=[ default_value=modules.flags.sdxl_aspect_ratios,
'704*1408', '704*1344', '768*1344', '768*1280', '832*1216', '832*1152',
'896*1152', '896*1088', '960*1088', '960*1024', '1024*1024', '1024*960',
'1088*960', '1088*896', '1152*896', '1152*832', '1216*832', '1280*768',
'1344*768', '1344*704', '1408*704', '1472*704', '1536*640', '1600*640',
'1664*576', '1728*576'
],
validator=lambda x: isinstance(x, list) and all('*' in v for v in x) and len(x) > 1 validator=lambda x: isinstance(x, list) and all('*' in v for v in x) and len(x) > 1
) )
default_aspect_ratio = get_config_item_or_set_default( default_aspect_ratio = get_config_item_or_set_default(
@ -433,6 +434,11 @@ default_cfg_tsnr = get_config_item_or_set_default(
default_value=7.0, default_value=7.0,
validator=lambda x: isinstance(x, numbers.Number) validator=lambda x: isinstance(x, numbers.Number)
) )
default_clip_skip = get_config_item_or_set_default(
key='default_clip_skip',
default_value=1,
validator=lambda x: isinstance(x, numbers.Number)
)
default_overwrite_step = get_config_item_or_set_default( default_overwrite_step = get_config_item_or_set_default(
key='default_overwrite_step', key='default_overwrite_step',
default_value=-1, default_value=-1,
@ -450,6 +456,11 @@ example_inpaint_prompts = get_config_item_or_set_default(
], ],
validator=lambda x: isinstance(x, list) and all(isinstance(v, str) for v in x) validator=lambda x: isinstance(x, list) and all(isinstance(v, str) for v in x)
) )
default_black_out_nsfw = get_config_item_or_set_default(
key='default_black_out_nsfw',
default_value=False,
validator=lambda x: isinstance(x, bool)
)
default_save_metadata_to_images = get_config_item_or_set_default( default_save_metadata_to_images = get_config_item_or_set_default(
key='default_save_metadata_to_images', key='default_save_metadata_to_images',
default_value=False, default_value=False,
@ -481,6 +492,8 @@ possible_preset_keys = {
"default_loras": "<processed>", "default_loras": "<processed>",
"default_cfg_scale": "guidance_scale", "default_cfg_scale": "guidance_scale",
"default_sample_sharpness": "sharpness", "default_sample_sharpness": "sharpness",
"default_cfg_tsnr": "adaptive_cfg",
"default_clip_skip": "clip_skip",
"default_sampler": "sampler", "default_sampler": "sampler",
"default_scheduler": "scheduler", "default_scheduler": "scheduler",
"default_overwrite_step": "steps", "default_overwrite_step": "steps",
@ -514,7 +527,7 @@ def add_ratio(x):
default_aspect_ratio = add_ratio(default_aspect_ratio) default_aspect_ratio = add_ratio(default_aspect_ratio)
available_aspect_ratios = [add_ratio(x) for x in available_aspect_ratios] available_aspect_ratios_labels = [add_ratio(x) for x in available_aspect_ratios]
# Only write config in the first launch. # Only write config in the first launch.
@ -535,26 +548,45 @@ with open(config_example_path, "w", encoding="utf-8") as json_file:
model_filenames = [] model_filenames = []
lora_filenames = [] lora_filenames = []
lora_filenames_no_special = []
vae_filenames = []
wildcard_filenames = [] wildcard_filenames = []
sdxl_lcm_lora = 'sdxl_lcm_lora.safetensors' sdxl_lcm_lora = 'sdxl_lcm_lora.safetensors'
sdxl_lightning_lora = 'sdxl_lightning_4step_lora.safetensors' sdxl_lightning_lora = 'sdxl_lightning_4step_lora.safetensors'
loras_metadata_remove = [sdxl_lcm_lora, sdxl_lightning_lora] sdxl_hyper_sd_lora = 'sdxl_hyper_sd_4step_lora.safetensors'
loras_metadata_remove = [sdxl_lcm_lora, sdxl_lightning_lora, sdxl_hyper_sd_lora]
def remove_special_loras(lora_filenames):
global loras_metadata_remove
loras_no_special = lora_filenames.copy()
for lora_to_remove in loras_metadata_remove:
if lora_to_remove in loras_no_special:
loras_no_special.remove(lora_to_remove)
return loras_no_special
def get_model_filenames(folder_paths, extensions=None, name_filter=None): def get_model_filenames(folder_paths, extensions=None, name_filter=None):
if extensions is None: if extensions is None:
extensions = ['.pth', '.ckpt', '.bin', '.safetensors', '.fooocus.patch'] extensions = ['.pth', '.ckpt', '.bin', '.safetensors', '.fooocus.patch']
files = [] files = []
if not isinstance(folder_paths, list):
folder_paths = [folder_paths]
for folder in folder_paths: for folder in folder_paths:
files += get_files_from_folder(folder, extensions, name_filter) files += get_files_from_folder(folder, extensions, name_filter)
return files return files
def update_files(): def update_files():
global model_filenames, lora_filenames, wildcard_filenames, available_presets global model_filenames, lora_filenames, lora_filenames_no_special, vae_filenames, wildcard_filenames, available_presets
model_filenames = get_model_filenames(paths_checkpoints) model_filenames = get_model_filenames(paths_checkpoints)
lora_filenames = get_model_filenames(paths_loras) lora_filenames = get_model_filenames(paths_loras)
lora_filenames_no_special = remove_special_loras(lora_filenames)
vae_filenames = get_model_filenames(path_vae)
wildcard_filenames = get_files_from_folder(path_wildcards, ['.txt']) wildcard_filenames = get_files_from_folder(path_wildcards, ['.txt'])
available_presets = get_presets() available_presets = get_presets()
return return
@ -608,13 +640,22 @@ def downloading_sdxl_lcm_lora():
def downloading_sdxl_lightning_lora(): def downloading_sdxl_lightning_lora():
load_file_from_url( load_file_from_url(
url='https://huggingface.co/ByteDance/SDXL-Lightning/resolve/main/sdxl_lightning_4step_lora.safetensors', url='https://huggingface.co/mashb1t/misc/resolve/main/sdxl_lightning_4step_lora.safetensors',
model_dir=paths_loras[0], model_dir=paths_loras[0],
file_name=sdxl_lightning_lora file_name=sdxl_lightning_lora
) )
return sdxl_lightning_lora return sdxl_lightning_lora
def downloading_sdxl_hyper_sd_lora():
load_file_from_url(
url='https://huggingface.co/mashb1t/misc/resolve/main/sdxl_hyper_sd_4step_lora.safetensors',
model_dir=paths_loras[0],
file_name=sdxl_hyper_sd_lora
)
return sdxl_hyper_sd_lora
def downloading_controlnet_canny(): def downloading_controlnet_canny():
load_file_from_url( load_file_from_url(
url='https://huggingface.co/lllyasviel/misc/resolve/main/control-lora-canny-rank128.safetensors', url='https://huggingface.co/lllyasviel/misc/resolve/main/control-lora-canny-rank128.safetensors',
@ -679,5 +720,13 @@ def downloading_upscale_model():
) )
return os.path.join(path_upscale_models, 'fooocus_upscaler_s409985e5.bin') return os.path.join(path_upscale_models, 'fooocus_upscaler_s409985e5.bin')
def downloading_safety_checker_model():
load_file_from_url(
url='https://huggingface.co/mashb1t/misc/resolve/main/stable-diffusion-safety-checker.bin',
model_dir=path_safety_checker,
file_name='stable-diffusion-safety-checker.bin'
)
return os.path.join(path_safety_checker, 'stable-diffusion-safety-checker.bin')
update_files() update_files()

View File

@ -35,12 +35,13 @@ opModelSamplingDiscrete = ModelSamplingDiscrete()
class StableDiffusionModel: class StableDiffusionModel:
def __init__(self, unet=None, vae=None, clip=None, clip_vision=None, filename=None): def __init__(self, unet=None, vae=None, clip=None, clip_vision=None, filename=None, vae_filename=None):
self.unet = unet self.unet = unet
self.vae = vae self.vae = vae
self.clip = clip self.clip = clip
self.clip_vision = clip_vision self.clip_vision = clip_vision
self.filename = filename self.filename = filename
self.vae_filename = vae_filename
self.unet_with_lora = unet self.unet_with_lora = unet
self.clip_with_lora = clip self.clip_with_lora = clip
self.visited_loras = '' self.visited_loras = ''
@ -142,9 +143,10 @@ def apply_controlnet(positive, negative, control_net, image, strength, start_per
@torch.no_grad() @torch.no_grad()
@torch.inference_mode() @torch.inference_mode()
def load_model(ckpt_filename): def load_model(ckpt_filename, vae_filename=None):
unet, clip, vae, clip_vision = load_checkpoint_guess_config(ckpt_filename, embedding_directory=path_embeddings) unet, clip, vae, vae_filename, clip_vision = load_checkpoint_guess_config(ckpt_filename, embedding_directory=path_embeddings,
return StableDiffusionModel(unet=unet, clip=clip, vae=vae, clip_vision=clip_vision, filename=ckpt_filename) vae_filename_param=vae_filename)
return StableDiffusionModel(unet=unet, clip=clip, vae=vae, clip_vision=clip_vision, filename=ckpt_filename, vae_filename=vae_filename)
@torch.no_grad() @torch.no_grad()

View File

@ -3,6 +3,7 @@ import os
import torch import torch
import modules.patch import modules.patch
import modules.config import modules.config
import modules.flags
import ldm_patched.modules.model_management import ldm_patched.modules.model_management
import ldm_patched.modules.latent_formats import ldm_patched.modules.latent_formats
import modules.inpaint_worker import modules.inpaint_worker
@ -58,17 +59,21 @@ def assert_model_integrity():
@torch.no_grad() @torch.no_grad()
@torch.inference_mode() @torch.inference_mode()
def refresh_base_model(name): def refresh_base_model(name, vae_name=None):
global model_base global model_base
filename = get_file_from_folder_list(name, modules.config.paths_checkpoints) filename = get_file_from_folder_list(name, modules.config.paths_checkpoints)
if model_base.filename == filename: vae_filename = None
if vae_name is not None and vae_name != modules.flags.default_vae:
vae_filename = get_file_from_folder_list(vae_name, modules.config.path_vae)
if model_base.filename == filename and model_base.vae_filename == vae_filename:
return return
model_base = core.StableDiffusionModel() model_base = core.load_model(filename, vae_filename)
model_base = core.load_model(filename)
print(f'Base model loaded: {model_base.filename}') print(f'Base model loaded: {model_base.filename}')
print(f'VAE loaded: {model_base.vae_filename}')
return return
@ -196,6 +201,17 @@ def clip_encode(texts, pool_top_k=1):
return [[torch.cat(cond_list, dim=1), {"pooled_output": pooled_acc}]] return [[torch.cat(cond_list, dim=1), {"pooled_output": pooled_acc}]]
@torch.no_grad()
@torch.inference_mode()
def set_clip_skip(clip_skip: int):
global final_clip
if final_clip is None:
return
final_clip.clip_layer(-abs(clip_skip))
return
@torch.no_grad() @torch.no_grad()
@torch.inference_mode() @torch.inference_mode()
def clear_all_caches(): def clear_all_caches():
@ -216,7 +232,7 @@ def prepare_text_encoder(async_call=True):
@torch.no_grad() @torch.no_grad()
@torch.inference_mode() @torch.inference_mode()
def refresh_everything(refiner_model_name, base_model_name, loras, def refresh_everything(refiner_model_name, base_model_name, loras,
base_model_additional_loras=None, use_synthetic_refiner=False): base_model_additional_loras=None, use_synthetic_refiner=False, vae_name=None):
global final_unet, final_clip, final_vae, final_refiner_unet, final_refiner_vae, final_expansion global final_unet, final_clip, final_vae, final_refiner_unet, final_refiner_vae, final_expansion
final_unet = None final_unet = None
@ -227,11 +243,11 @@ def refresh_everything(refiner_model_name, base_model_name, loras,
if use_synthetic_refiner and refiner_model_name == 'None': if use_synthetic_refiner and refiner_model_name == 'None':
print('Synthetic Refiner Activated') print('Synthetic Refiner Activated')
refresh_base_model(base_model_name) refresh_base_model(base_model_name, vae_name)
synthesize_refiner_model() synthesize_refiner_model()
else: else:
refresh_refiner_model(refiner_model_name) refresh_refiner_model(refiner_model_name)
refresh_base_model(base_model_name) refresh_base_model(base_model_name, vae_name)
refresh_loras(loras, base_model_additional_loras=base_model_additional_loras) refresh_loras(loras, base_model_additional_loras=base_model_additional_loras)
assert_model_integrity() assert_model_integrity()
@ -254,7 +270,8 @@ def refresh_everything(refiner_model_name, base_model_name, loras,
refresh_everything( refresh_everything(
refiner_model_name=modules.config.default_refiner_model_name, refiner_model_name=modules.config.default_refiner_model_name,
base_model_name=modules.config.default_base_model_name, base_model_name=modules.config.default_base_model_name,
loras=get_enabled_loras(modules.config.default_loras) loras=get_enabled_loras(modules.config.default_loras),
vae_name=modules.config.default_vae,
) )

26
modules/extra_utils.py Normal file
View File

@ -0,0 +1,26 @@
import os
def makedirs_with_log(path):
try:
os.makedirs(path, exist_ok=True)
except OSError as error:
print(f'Directory {path} could not be created, reason: {error}')
def get_files_from_folder(folder_path, extensions=None, name_filter=None):
if not os.path.isdir(folder_path):
raise ValueError("Folder path is not a valid directory.")
filenames = []
for root, _, files in os.walk(folder_path, topdown=False):
relative_path = os.path.relpath(root, folder_path)
if relative_path == ".":
relative_path = ""
for filename in sorted(files, key=lambda s: s.casefold()):
_, file_extension = os.path.splitext(filename)
if (extensions is None or file_extension.lower() in extensions) and (name_filter is None or name_filter in _):
path = os.path.join(relative_path, filename)
filenames.append(path)
return filenames

View File

@ -34,7 +34,8 @@ KSAMPLER = {
"dpmpp_3m_sde": "", "dpmpp_3m_sde": "",
"dpmpp_3m_sde_gpu": "", "dpmpp_3m_sde_gpu": "",
"ddpm": "", "ddpm": "",
"lcm": "LCM" "lcm": "LCM",
"tcd": "TCD"
} }
SAMPLER_EXTRA = { SAMPLER_EXTRA = {
@ -47,12 +48,14 @@ SAMPLERS = KSAMPLER | SAMPLER_EXTRA
KSAMPLER_NAMES = list(KSAMPLER.keys()) KSAMPLER_NAMES = list(KSAMPLER.keys())
SCHEDULER_NAMES = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform", "lcm", "turbo"] SCHEDULER_NAMES = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform", "lcm", "turbo", "align_your_steps", "tcd"]
SAMPLER_NAMES = KSAMPLER_NAMES + list(SAMPLER_EXTRA.keys()) SAMPLER_NAMES = KSAMPLER_NAMES + list(SAMPLER_EXTRA.keys())
sampler_list = SAMPLER_NAMES sampler_list = SAMPLER_NAMES
scheduler_list = SCHEDULER_NAMES scheduler_list = SCHEDULER_NAMES
default_vae = 'Default (model)'
refiner_swap_method = 'joint' refiner_swap_method = 'joint'
cn_ip = "ImagePrompt" cn_ip = "ImagePrompt"
@ -78,6 +81,13 @@ inpaint_options = [inpaint_option_default, inpaint_option_detail, inpaint_option
desc_type_photo = 'Photograph' desc_type_photo = 'Photograph'
desc_type_anime = 'Art/Anime' desc_type_anime = 'Art/Anime'
sdxl_aspect_ratios = [
'704*1408', '704*1344', '768*1344', '768*1280', '832*1216', '832*1152',
'896*1152', '896*1088', '960*1088', '960*1024', '1024*1024', '1024*960',
'1088*960', '1088*896', '1152*896', '1152*832', '1216*832', '1280*768',
'1344*768', '1344*704', '1408*704', '1472*704', '1536*640', '1600*640',
'1664*576', '1728*576'
]
class MetadataScheme(Enum): class MetadataScheme(Enum):
FOOOCUS = 'fooocus' FOOOCUS = 'fooocus'
@ -90,6 +100,7 @@ metadata_scheme = [
] ]
controlnet_image_count = 4 controlnet_image_count = 4
preparation_step_count = 13
class OutputFormat(Enum): class OutputFormat(Enum):
@ -107,6 +118,7 @@ class Steps(IntEnum):
SPEED = 30 SPEED = 30
EXTREME_SPEED = 8 EXTREME_SPEED = 8
LIGHTNING = 4 LIGHTNING = 4
HYPER_SD = 4
class StepsUOV(IntEnum): class StepsUOV(IntEnum):
@ -114,6 +126,7 @@ class StepsUOV(IntEnum):
SPEED = 18 SPEED = 18
EXTREME_SPEED = 8 EXTREME_SPEED = 8
LIGHTNING = 4 LIGHTNING = 4
HYPER_SD = 4
class Performance(Enum): class Performance(Enum):
@ -121,6 +134,7 @@ class Performance(Enum):
SPEED = 'Speed' SPEED = 'Speed'
EXTREME_SPEED = 'Extreme Speed' EXTREME_SPEED = 'Extreme Speed'
LIGHTNING = 'Lightning' LIGHTNING = 'Lightning'
HYPER_SD = 'Hyper-SD'
@classmethod @classmethod
def list(cls) -> list: def list(cls) -> list:
@ -130,7 +144,7 @@ class Performance(Enum):
def has_restricted_features(cls, x) -> bool: def has_restricted_features(cls, x) -> bool:
if isinstance(x, Performance): if isinstance(x, Performance):
x = x.value x = x.value
return x in [cls.EXTREME_SPEED.value, cls.LIGHTNING.value] return x in [cls.EXTREME_SPEED.value, cls.LIGHTNING.value, cls.HYPER_SD.value]
def steps(self) -> int | None: def steps(self) -> int | None:
return Steps[self.name].value if Steps[self.name] else None return Steps[self.name].value if Steps[self.name] else None

View File

@ -34,18 +34,20 @@ def load_parameter_button_click(raw_metadata: dict | str, is_generating: bool):
get_list('styles', 'Styles', loaded_parameter_dict, results) get_list('styles', 'Styles', loaded_parameter_dict, results)
get_str('performance', 'Performance', loaded_parameter_dict, results) get_str('performance', 'Performance', loaded_parameter_dict, results)
get_steps('steps', 'Steps', loaded_parameter_dict, results) get_steps('steps', 'Steps', loaded_parameter_dict, results)
get_float('overwrite_switch', 'Overwrite Switch', loaded_parameter_dict, results) get_number('overwrite_switch', 'Overwrite Switch', loaded_parameter_dict, results)
get_resolution('resolution', 'Resolution', loaded_parameter_dict, results) get_resolution('resolution', 'Resolution', loaded_parameter_dict, results)
get_float('guidance_scale', 'Guidance Scale', loaded_parameter_dict, results) get_number('guidance_scale', 'Guidance Scale', loaded_parameter_dict, results)
get_float('sharpness', 'Sharpness', loaded_parameter_dict, results) get_number('sharpness', 'Sharpness', loaded_parameter_dict, results)
get_adm_guidance('adm_guidance', 'ADM Guidance', loaded_parameter_dict, results) get_adm_guidance('adm_guidance', 'ADM Guidance', loaded_parameter_dict, results)
get_str('refiner_swap_method', 'Refiner Swap Method', loaded_parameter_dict, results) get_str('refiner_swap_method', 'Refiner Swap Method', loaded_parameter_dict, results)
get_float('adaptive_cfg', 'CFG Mimicking from TSNR', loaded_parameter_dict, results) get_number('adaptive_cfg', 'CFG Mimicking from TSNR', loaded_parameter_dict, results)
get_number('clip_skip', 'CLIP Skip', loaded_parameter_dict, results, cast_type=int)
get_str('base_model', 'Base Model', loaded_parameter_dict, results) get_str('base_model', 'Base Model', loaded_parameter_dict, results)
get_str('refiner_model', 'Refiner Model', loaded_parameter_dict, results) get_str('refiner_model', 'Refiner Model', loaded_parameter_dict, results)
get_float('refiner_switch', 'Refiner Switch', loaded_parameter_dict, results) get_number('refiner_switch', 'Refiner Switch', loaded_parameter_dict, results)
get_str('sampler', 'Sampler', loaded_parameter_dict, results) get_str('sampler', 'Sampler', loaded_parameter_dict, results)
get_str('scheduler', 'Scheduler', loaded_parameter_dict, results) get_str('scheduler', 'Scheduler', loaded_parameter_dict, results)
get_str('vae', 'VAE', loaded_parameter_dict, results)
get_seed('seed', 'Seed', loaded_parameter_dict, results) get_seed('seed', 'Seed', loaded_parameter_dict, results)
if is_generating: if is_generating:
@ -82,11 +84,11 @@ def get_list(key: str, fallback: str | None, source_dict: dict, results: list, d
results.append(gr.update()) results.append(gr.update())
def get_float(key: str, fallback: str | None, source_dict: dict, results: list, default=None): def get_number(key: str, fallback: str | None, source_dict: dict, results: list, default=None, cast_type=float):
try: try:
h = source_dict.get(key, source_dict.get(fallback, default)) h = source_dict.get(key, source_dict.get(fallback, default))
assert h is not None assert h is not None
h = float(h) h = cast_type(h)
results.append(h) results.append(h)
except: except:
results.append(gr.update()) results.append(gr.update())
@ -123,7 +125,7 @@ def get_resolution(key: str, fallback: str | None, source_dict: dict, results: l
h = source_dict.get(key, source_dict.get(fallback, default)) h = source_dict.get(key, source_dict.get(fallback, default))
width, height = eval(h) width, height = eval(h)
formatted = modules.config.add_ratio(f'{width}*{height}') formatted = modules.config.add_ratio(f'{width}*{height}')
if formatted in modules.config.available_aspect_ratios: if formatted in modules.config.available_aspect_ratios_labels:
results.append(formatted) results.append(formatted)
results.append(-1) results.append(-1)
results.append(-1) results.append(-1)
@ -204,7 +206,6 @@ def get_lora(key: str, fallback: str | None, source_dict: dict, results: list):
def get_sha256(filepath): def get_sha256(filepath):
global hash_cache global hash_cache
if filepath not in hash_cache: if filepath not in hash_cache:
# is_safetensors = os.path.splitext(filepath)[1].lower() == '.safetensors'
hash_cache[filepath] = sha256(filepath) hash_cache[filepath] = sha256(filepath)
return hash_cache[filepath] return hash_cache[filepath]
@ -253,6 +254,7 @@ class MetadataParser(ABC):
self.refiner_model_name: str = '' self.refiner_model_name: str = ''
self.refiner_model_hash: str = '' self.refiner_model_hash: str = ''
self.loras: list = [] self.loras: list = []
self.vae_name: str = ''
@abstractmethod @abstractmethod
def get_scheme(self) -> MetadataScheme: def get_scheme(self) -> MetadataScheme:
@ -267,7 +269,7 @@ class MetadataParser(ABC):
raise NotImplementedError raise NotImplementedError
def set_data(self, raw_prompt, full_prompt, raw_negative_prompt, full_negative_prompt, steps, base_model_name, def set_data(self, raw_prompt, full_prompt, raw_negative_prompt, full_negative_prompt, steps, base_model_name,
refiner_model_name, loras): refiner_model_name, loras, vae_name):
self.raw_prompt = raw_prompt self.raw_prompt = raw_prompt
self.full_prompt = full_prompt self.full_prompt = full_prompt
self.raw_negative_prompt = raw_negative_prompt self.raw_negative_prompt = raw_negative_prompt
@ -289,12 +291,7 @@ class MetadataParser(ABC):
lora_path = get_file_from_folder_list(lora_name, modules.config.paths_loras) lora_path = get_file_from_folder_list(lora_name, modules.config.paths_loras)
lora_hash = get_sha256(lora_path) lora_hash = get_sha256(lora_path)
self.loras.append((Path(lora_name).stem, lora_weight, lora_hash)) self.loras.append((Path(lora_name).stem, lora_weight, lora_hash))
self.vae_name = Path(vae_name).stem
@staticmethod
def remove_special_loras(lora_filenames):
for lora_to_remove in modules.config.loras_metadata_remove:
if lora_to_remove in lora_filenames:
lora_filenames.remove(lora_to_remove)
class A1111MetadataParser(MetadataParser): class A1111MetadataParser(MetadataParser):
@ -310,6 +307,7 @@ class A1111MetadataParser(MetadataParser):
'steps': 'Steps', 'steps': 'Steps',
'sampler': 'Sampler', 'sampler': 'Sampler',
'scheduler': 'Scheduler', 'scheduler': 'Scheduler',
'vae': 'VAE',
'guidance_scale': 'CFG scale', 'guidance_scale': 'CFG scale',
'seed': 'Seed', 'seed': 'Seed',
'resolution': 'Size', 'resolution': 'Size',
@ -317,6 +315,7 @@ class A1111MetadataParser(MetadataParser):
'adm_guidance': 'ADM Guidance', 'adm_guidance': 'ADM Guidance',
'refiner_swap_method': 'Refiner Swap Method', 'refiner_swap_method': 'Refiner Swap Method',
'adaptive_cfg': 'Adaptive CFG', 'adaptive_cfg': 'Adaptive CFG',
'clip_skip': 'Clip skip',
'overwrite_switch': 'Overwrite Switch', 'overwrite_switch': 'Overwrite Switch',
'freeu': 'FreeU', 'freeu': 'FreeU',
'base_model': 'Model', 'base_model': 'Model',
@ -397,13 +396,12 @@ class A1111MetadataParser(MetadataParser):
data['sampler'] = k data['sampler'] = k
break break
for key in ['base_model', 'refiner_model']: for key in ['base_model', 'refiner_model', 'vae']:
if key in data: if key in data:
for filename in modules.config.model_filenames: if key == 'vae':
path = Path(filename) self.add_extension_to_filename(data, modules.config.vae_filenames, 'vae')
if data[key] == path.stem: else:
data[key] = filename self.add_extension_to_filename(data, modules.config.model_filenames, key)
break
lora_data = '' lora_data = ''
if 'lora_weights' in data and data['lora_weights'] != '': if 'lora_weights' in data and data['lora_weights'] != '':
@ -412,13 +410,11 @@ class A1111MetadataParser(MetadataParser):
lora_data = data['lora_hashes'] lora_data = data['lora_hashes']
if lora_data != '': if lora_data != '':
lora_filenames = modules.config.lora_filenames.copy()
self.remove_special_loras(lora_filenames)
for li, lora in enumerate(lora_data.split(', ')): for li, lora in enumerate(lora_data.split(', ')):
lora_split = lora.split(': ') lora_split = lora.split(': ')
lora_name = lora_split[0] lora_name = lora_split[0]
lora_weight = lora_split[2] if len(lora_split) == 3 else lora_split[1] lora_weight = lora_split[2] if len(lora_split) == 3 else lora_split[1]
for filename in lora_filenames: for filename in modules.config.lora_filenames_no_special:
path = Path(filename) path = Path(filename)
if lora_name == path.stem: if lora_name == path.stem:
data[f'lora_combined_{li + 1}'] = f'{filename} : {lora_weight}' data[f'lora_combined_{li + 1}'] = f'{filename} : {lora_weight}'
@ -433,6 +429,7 @@ class A1111MetadataParser(MetadataParser):
sampler = data['sampler'] sampler = data['sampler']
scheduler = data['scheduler'] scheduler = data['scheduler']
if sampler in SAMPLERS and SAMPLERS[sampler] != '': if sampler in SAMPLERS and SAMPLERS[sampler] != '':
sampler = SAMPLERS[sampler] sampler = SAMPLERS[sampler]
if sampler not in CIVITAI_NO_KARRAS and scheduler == 'karras': if sampler not in CIVITAI_NO_KARRAS and scheduler == 'karras':
@ -451,6 +448,7 @@ class A1111MetadataParser(MetadataParser):
self.fooocus_to_a1111['performance']: data['performance'], self.fooocus_to_a1111['performance']: data['performance'],
self.fooocus_to_a1111['scheduler']: scheduler, self.fooocus_to_a1111['scheduler']: scheduler,
self.fooocus_to_a1111['vae']: Path(data['vae']).stem,
# workaround for multiline prompts # workaround for multiline prompts
self.fooocus_to_a1111['raw_prompt']: self.raw_prompt, self.fooocus_to_a1111['raw_prompt']: self.raw_prompt,
self.fooocus_to_a1111['raw_negative_prompt']: self.raw_negative_prompt, self.fooocus_to_a1111['raw_negative_prompt']: self.raw_negative_prompt,
@ -462,7 +460,7 @@ class A1111MetadataParser(MetadataParser):
self.fooocus_to_a1111['refiner_model_hash']: self.refiner_model_hash self.fooocus_to_a1111['refiner_model_hash']: self.refiner_model_hash
} }
for key in ['adaptive_cfg', 'overwrite_switch', 'refiner_swap_method', 'freeu']: for key in ['adaptive_cfg', 'clip_skip', 'overwrite_switch', 'refiner_swap_method', 'freeu']:
if key in data: if key in data:
generation_params[self.fooocus_to_a1111[key]] = data[key] generation_params[self.fooocus_to_a1111[key]] = data[key]
@ -491,22 +489,29 @@ class A1111MetadataParser(MetadataParser):
negative_prompt_text = f"\nNegative prompt: {negative_prompt_resolved}" if negative_prompt_resolved else "" negative_prompt_text = f"\nNegative prompt: {negative_prompt_resolved}" if negative_prompt_resolved else ""
return f"{positive_prompt_resolved}{negative_prompt_text}\n{generation_params_text}".strip() return f"{positive_prompt_resolved}{negative_prompt_text}\n{generation_params_text}".strip()
@staticmethod
def add_extension_to_filename(data, filenames, key):
for filename in filenames:
path = Path(filename)
if data[key] == path.stem:
data[key] = filename
break
class FooocusMetadataParser(MetadataParser): class FooocusMetadataParser(MetadataParser):
def get_scheme(self) -> MetadataScheme: def get_scheme(self) -> MetadataScheme:
return MetadataScheme.FOOOCUS return MetadataScheme.FOOOCUS
def parse_json(self, metadata: dict) -> dict: def parse_json(self, metadata: dict) -> dict:
model_filenames = modules.config.model_filenames.copy()
lora_filenames = modules.config.lora_filenames.copy()
self.remove_special_loras(lora_filenames)
for key, value in metadata.items(): for key, value in metadata.items():
if value in ['', 'None']: if value in ['', 'None']:
continue continue
if key in ['base_model', 'refiner_model']: if key in ['base_model', 'refiner_model']:
metadata[key] = self.replace_value_with_filename(key, value, model_filenames) metadata[key] = self.replace_value_with_filename(key, value, modules.config.model_filenames)
elif key.startswith('lora_combined_'): elif key.startswith('lora_combined_'):
metadata[key] = self.replace_value_with_filename(key, value, lora_filenames) metadata[key] = self.replace_value_with_filename(key, value, modules.config.lora_filenames_no_special)
elif key == 'vae':
metadata[key] = self.replace_value_with_filename(key, value, modules.config.vae_filenames)
else: else:
continue continue
@ -533,6 +538,7 @@ class FooocusMetadataParser(MetadataParser):
res['refiner_model'] = self.refiner_model_name res['refiner_model'] = self.refiner_model_name
res['refiner_model_hash'] = self.refiner_model_hash res['refiner_model_hash'] = self.refiner_model_hash
res['vae'] = self.vae_name
res['loras'] = self.loras res['loras'] = self.loras
if modules.config.metadata_created_by != '': if modules.config.metadata_created_by != '':

View File

@ -14,6 +14,8 @@ def load_file_from_url(
Returns the path to the downloaded file. Returns the path to the downloaded file.
""" """
domain = os.environ.get("HF_MIRROR", "https://huggingface.co").rstrip('/')
url = str.replace(url, "https://huggingface.co", domain, 1)
os.makedirs(model_dir, exist_ok=True) os.makedirs(model_dir, exist_ok=True)
if not file_name: if not file_name:
parts = urlparse(url) parts = urlparse(url)

View File

@ -51,6 +51,8 @@ def patched_register_schedule(self, given_betas=None, beta_schedule="linear", ti
self.linear_end = linear_end self.linear_end = linear_end
sigmas = torch.tensor(((1 - alphas_cumprod) / alphas_cumprod) ** 0.5, dtype=torch.float32) sigmas = torch.tensor(((1 - alphas_cumprod) / alphas_cumprod) ** 0.5, dtype=torch.float32)
self.set_sigmas(sigmas) self.set_sigmas(sigmas)
alphas_cumprod = torch.tensor(alphas_cumprod, dtype=torch.float32)
self.set_alphas_cumprod(alphas_cumprod)
return return

View File

@ -21,7 +21,7 @@ def get_current_html_path(output_format=None):
return html_name return html_name
def log(img, metadata, metadata_parser: MetadataParser | None = None, output_format=None) -> str: def log(img, metadata, metadata_parser: MetadataParser | None = None, output_format=None, task=None) -> str:
path_outputs = modules.config.temp_path if args_manager.args.disable_image_log else modules.config.path_outputs path_outputs = modules.config.temp_path if args_manager.args.disable_image_log else modules.config.path_outputs
output_format = output_format if output_format else modules.config.default_output_format output_format = output_format if output_format else modules.config.default_output_format
date_string, local_temp_filename, only_name = generate_temp_filename(folder=path_outputs, extension=output_format) date_string, local_temp_filename, only_name = generate_temp_filename(folder=path_outputs, extension=output_format)
@ -111,9 +111,15 @@ def log(img, metadata, metadata_parser: MetadataParser | None = None, output_for
for label, key, value in metadata: for label, key, value in metadata:
value_txt = str(value).replace('\n', ' </br> ') value_txt = str(value).replace('\n', ' </br> ')
item += f"<tr><td class='label'>{label}</td><td class='value'>{value_txt}</td></tr>\n" item += f"<tr><td class='label'>{label}</td><td class='value'>{value_txt}</td></tr>\n"
if task is not None and 'positive' in task and 'negative' in task:
full_prompt_details = f"""<details><summary>Positive</summary>{', '.join(task['positive'])}</details>
<details><summary>Negative</summary>{', '.join(task['negative'])}</details>"""
item += f"<tr><td class='label'>Full raw prompt</td><td class='value'>{full_prompt_details}</td></tr>\n"
item += "</table>" item += "</table>"
js_txt = urllib.parse.quote(json.dumps({k: v for _, k, v in metadata}, indent=0), safe='') js_txt = urllib.parse.quote(json.dumps({k: v for _, k, v, in metadata}, indent=0), safe='')
item += f"</br><button onclick=\"to_clipboard('{js_txt}')\">Copy to Clipboard</button>" item += f"</br><button onclick=\"to_clipboard('{js_txt}')\">Copy to Clipboard</button>"
item += "</td>" item += "</td>"

View File

@ -3,6 +3,7 @@ import ldm_patched.modules.samplers
import ldm_patched.modules.model_management import ldm_patched.modules.model_management
from collections import namedtuple from collections import namedtuple
from ldm_patched.contrib.external_align_your_steps import AlignYourStepsScheduler
from ldm_patched.contrib.external_custom_sampler import SDTurboScheduler from ldm_patched.contrib.external_custom_sampler import SDTurboScheduler
from ldm_patched.k_diffusion import sampling as k_diffusion_sampling from ldm_patched.k_diffusion import sampling as k_diffusion_sampling
from ldm_patched.modules.samplers import normal_scheduler, simple_scheduler, ddim_scheduler from ldm_patched.modules.samplers import normal_scheduler, simple_scheduler, ddim_scheduler
@ -175,6 +176,9 @@ def calculate_sigmas_scheduler_hacked(model, scheduler_name, steps):
sigmas = normal_scheduler(model, steps, sgm=True) sigmas = normal_scheduler(model, steps, sgm=True)
elif scheduler_name == "turbo": elif scheduler_name == "turbo":
sigmas = SDTurboScheduler().get_sigmas(namedtuple('Patcher', ['model'])(model=model), steps=steps, denoise=1.0)[0] sigmas = SDTurboScheduler().get_sigmas(namedtuple('Patcher', ['model'])(model=model), steps=steps, denoise=1.0)[0]
elif scheduler_name == "align_your_steps":
model_type = 'SDXL' if isinstance(model.latent_format, ldm_patched.modules.latent_formats.SDXL) else 'SD1'
sigmas = AlignYourStepsScheduler().get_sigmas(model_type=model_type, steps=steps, denoise=1.0)[0]
else: else:
raise TypeError("error invalid scheduler") raise TypeError("error invalid scheduler")
return sigmas return sigmas

View File

@ -2,13 +2,12 @@ import os
import re import re
import json import json
import math import math
import modules.config
from modules.util import get_files_from_folder from modules.extra_utils import get_files_from_folder
from random import Random
# cannot use modules.config - validators causing circular imports # cannot use modules.config - validators causing circular imports
styles_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '../sdxl_styles/')) styles_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '../sdxl_styles/'))
wildcards_max_bfs_depth = 64
def normalize_key(k): def normalize_key(k):
@ -24,7 +23,6 @@ def normalize_key(k):
styles = {} styles = {}
styles_files = get_files_from_folder(styles_path, ['.json']) styles_files = get_files_from_folder(styles_path, ['.json'])
for x in ['sdxl_styles_fooocus.json', for x in ['sdxl_styles_fooocus.json',
@ -50,8 +48,13 @@ for styles_file in styles_files:
print(f'Failed to load style file {styles_file}') print(f'Failed to load style file {styles_file}')
style_keys = list(styles.keys()) style_keys = list(styles.keys())
fooocus_expansion = "Fooocus V2" fooocus_expansion = 'Fooocus V2'
legal_style_names = [fooocus_expansion] + style_keys random_style_name = 'Random Style'
legal_style_names = [fooocus_expansion, random_style_name] + style_keys
def get_random_style(rng: Random) -> str:
return rng.choice(list(styles.items()))[0]
def apply_style(style, positive): def apply_style(style, positive):
@ -59,34 +62,7 @@ def apply_style(style, positive):
return p.replace('{prompt}', positive).splitlines(), n.splitlines() return p.replace('{prompt}', positive).splitlines(), n.splitlines()
def apply_wildcards(wildcard_text, rng, i, read_wildcards_in_order): def get_words(arrays, total_mult, index):
for _ in range(wildcards_max_bfs_depth):
placeholders = re.findall(r'__([\w-]+)__', wildcard_text)
if len(placeholders) == 0:
return wildcard_text
print(f'[Wildcards] processing: {wildcard_text}')
for placeholder in placeholders:
try:
matches = [x for x in modules.config.wildcard_filenames if os.path.splitext(os.path.basename(x))[0] == placeholder]
words = open(os.path.join(modules.config.path_wildcards, matches[0]), encoding='utf-8').read().splitlines()
words = [x for x in words if x != '']
assert len(words) > 0
if read_wildcards_in_order:
wildcard_text = wildcard_text.replace(f'__{placeholder}__', words[i % len(words)], 1)
else:
wildcard_text = wildcard_text.replace(f'__{placeholder}__', rng.choice(words), 1)
except:
print(f'[Wildcards] Warning: {placeholder}.txt missing or empty. '
f'Using "{placeholder}" as a normal word.')
wildcard_text = wildcard_text.replace(f'__{placeholder}__', placeholder)
print(f'[Wildcards] {wildcard_text}')
print(f'[Wildcards] BFS stack overflow. Current text: {wildcard_text}')
return wildcard_text
def get_words(arrays, totalMult, index):
if len(arrays) == 1: if len(arrays) == 1:
return [arrays[0].split(',')[index]] return [arrays[0].split(',')[index]]
else: else:
@ -95,7 +71,7 @@ def get_words(arrays, totalMult, index):
index -= index % len(words) index -= index % len(words)
index /= len(words) index /= len(words)
index = math.floor(index) index = math.floor(index)
return [word] + get_words(arrays[1:], math.floor(totalMult/len(words)), index) return [word] + get_words(arrays[1:], math.floor(total_mult / len(words)), index)
def apply_arrays(text, index): def apply_arrays(text, index):

View File

@ -39,7 +39,7 @@ def javascript_html():
head += f'<script type="text/javascript" src="{edit_attention_js_path}"></script>\n' head += f'<script type="text/javascript" src="{edit_attention_js_path}"></script>\n'
head += f'<script type="text/javascript" src="{viewer_js_path}"></script>\n' head += f'<script type="text/javascript" src="{viewer_js_path}"></script>\n'
head += f'<script type="text/javascript" src="{image_viewer_js_path}"></script>\n' head += f'<script type="text/javascript" src="{image_viewer_js_path}"></script>\n'
head += f'<meta name="samples-path" content="{samples_path}"></meta>\n' head += f'<meta name="samples-path" content="{samples_path}">\n'
if args_manager.args.theme: if args_manager.args.theme:
head += f'<script type="text/javascript">set_theme(\"{args_manager.args.theme}\");</script>\n' head += f'<script type="text/javascript">set_theme(\"{args_manager.args.theme}\");</script>\n'

View File

@ -1,4 +1,4 @@
import typing from pathlib import Path
import numpy as np import numpy as np
import datetime import datetime
@ -6,16 +6,27 @@ import random
import math import math
import os import os
import cv2 import cv2
import re
from typing import List, Tuple, AnyStr, NamedTuple
import json import json
import hashlib import hashlib
from PIL import Image from PIL import Image
import modules.config
import modules.sdxl_styles import modules.sdxl_styles
LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS) LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
# Regexp compiled once. Matches entries with the following pattern:
# <lora:some_lora:1>
# <lora:aNotherLora:-1.6>
LORAS_PROMPT_PATTERN = re.compile(r"(<lora:([^:]+):([+-]?(?:\d+(?:\.\d*)?|\.\d+))>)", re.X)
HASH_SHA256_LENGTH = 10 HASH_SHA256_LENGTH = 10
def erode_or_dilate(x, k): def erode_or_dilate(x, k):
k = int(k) k = int(k)
if k > 0: if k > 0:
@ -163,25 +174,6 @@ def generate_temp_filename(folder='./outputs/', extension='png'):
return date_string, os.path.abspath(result), filename return date_string, os.path.abspath(result), filename
def get_files_from_folder(folder_path, extensions=None, name_filter=None):
if not os.path.isdir(folder_path):
raise ValueError("Folder path is not a valid directory.")
filenames = []
for root, dirs, files in os.walk(folder_path, topdown=False):
relative_path = os.path.relpath(root, folder_path)
if relative_path == ".":
relative_path = ""
for filename in sorted(files, key=lambda s: s.casefold()):
_, file_extension = os.path.splitext(filename)
if (extensions is None or file_extension.lower() in extensions) and (name_filter is None or name_filter in _):
path = os.path.join(relative_path, filename)
filenames.append(path)
return filenames
def sha256(filename, use_addnet_hash=False, length=HASH_SHA256_LENGTH): def sha256(filename, use_addnet_hash=False, length=HASH_SHA256_LENGTH):
print(f"Calculating sha256 for {filename}: ", end='') print(f"Calculating sha256 for {filename}: ", end='')
if use_addnet_hash: if use_addnet_hash:
@ -355,7 +347,7 @@ def extract_styles_from_prompt(prompt, negative_prompt):
return list(reversed(extracted)), real_prompt, negative_prompt return list(reversed(extracted)), real_prompt, negative_prompt
class PromptStyle(typing.NamedTuple): class PromptStyle(NamedTuple):
name: str name: str
prompt: str prompt: str
negative_prompt: str negative_prompt: str
@ -370,7 +362,18 @@ def is_json(data: str) -> bool:
return True return True
def get_filname_by_stem(lora_name, filenames: List[str]) -> str | None:
for filename in filenames:
path = Path(filename)
if lora_name == path.stem:
return filename
return None
def get_file_from_folder_list(name, folders): def get_file_from_folder_list(name, folders):
if not isinstance(folders, list):
folders = [folders]
for folder in folders: for folder in folders:
filename = os.path.abspath(os.path.realpath(os.path.join(folder, name))) filename = os.path.abspath(os.path.realpath(os.path.join(folder, name)))
if os.path.isfile(filename): if os.path.isfile(filename):
@ -378,7 +381,6 @@ def get_file_from_folder_list(name, folders):
return os.path.abspath(os.path.realpath(os.path.join(folders[0], name))) return os.path.abspath(os.path.realpath(os.path.join(folders[0], name)))
def ordinal_suffix(number: int) -> str: def ordinal_suffix(number: int) -> str:
return 'th' if 10 <= number % 100 <= 20 else {1: 'st', 2: 'nd', 3: 'rd'}.get(number % 10, 'th') return 'th' if 10 <= number % 100 <= 20 else {1: 'st', 2: 'nd', 3: 'rd'}.get(number % 10, 'th')
@ -390,5 +392,111 @@ def makedirs_with_log(path):
print(f'Directory {path} could not be created, reason: {error}') print(f'Directory {path} could not be created, reason: {error}')
def get_enabled_loras(loras: list) -> list: def get_enabled_loras(loras: list, remove_none=True) -> list:
return [[lora[1], lora[2]] for lora in loras if lora[0]] return [(lora[1], lora[2]) for lora in loras if lora[0] and (lora[1] != 'None' if remove_none else True)]
def parse_lora_references_from_prompt(prompt: str, loras: List[Tuple[AnyStr, float]], loras_limit: int = 5,
skip_file_check=False, prompt_cleanup=True, deduplicate_loras=True) -> tuple[List[Tuple[AnyStr, float]], str]:
found_loras = []
prompt_without_loras = ''
cleaned_prompt = ''
for token in prompt.split(','):
matches = LORAS_PROMPT_PATTERN.findall(token)
if len(matches) == 0:
prompt_without_loras += token + ', '
continue
for match in matches:
lora_name = match[1] + '.safetensors'
if not skip_file_check:
lora_name = get_filname_by_stem(match[1], modules.config.lora_filenames_no_special)
if lora_name is not None:
found_loras.append((lora_name, float(match[2])))
token = token.replace(match[0], '')
prompt_without_loras += token + ', '
if prompt_without_loras != '':
cleaned_prompt = prompt_without_loras[:-2]
if prompt_cleanup:
cleaned_prompt = cleanup_prompt(prompt_without_loras)
new_loras = []
lora_names = [lora[0] for lora in loras]
for found_lora in found_loras:
if deduplicate_loras and (found_lora[0] in lora_names or found_lora in new_loras):
continue
new_loras.append(found_lora)
if len(new_loras) == 0:
return loras, cleaned_prompt
updated_loras = []
for lora in loras + new_loras:
if lora[0] != "None":
updated_loras.append(lora)
return updated_loras[:loras_limit], cleaned_prompt
def cleanup_prompt(prompt):
prompt = re.sub(' +', ' ', prompt)
prompt = re.sub(',+', ',', prompt)
cleaned_prompt = ''
for token in prompt.split(','):
token = token.strip()
if token == '':
continue
cleaned_prompt += token + ', '
return cleaned_prompt[:-2]
def apply_wildcards(wildcard_text, rng, i, read_wildcards_in_order) -> str:
for _ in range(modules.config.wildcards_max_bfs_depth):
placeholders = re.findall(r'__([\w-]+)__', wildcard_text)
if len(placeholders) == 0:
return wildcard_text
print(f'[Wildcards] processing: {wildcard_text}')
for placeholder in placeholders:
try:
matches = [x for x in modules.config.wildcard_filenames if os.path.splitext(os.path.basename(x))[0] == placeholder]
words = open(os.path.join(modules.config.path_wildcards, matches[0]), encoding='utf-8').read().splitlines()
words = [x for x in words if x != '']
assert len(words) > 0
if read_wildcards_in_order:
wildcard_text = wildcard_text.replace(f'__{placeholder}__', words[i % len(words)], 1)
else:
wildcard_text = wildcard_text.replace(f'__{placeholder}__', rng.choice(words), 1)
except:
print(f'[Wildcards] Warning: {placeholder}.txt missing or empty. '
f'Using "{placeholder}" as a normal word.')
wildcard_text = wildcard_text.replace(f'__{placeholder}__', placeholder)
print(f'[Wildcards] {wildcard_text}')
print(f'[Wildcards] BFS stack overflow. Current text: {wildcard_text}')
return wildcard_text
def get_image_size_info(image: np.ndarray, aspect_ratios: list) -> str:
try:
image = Image.fromarray(np.uint8(image))
width, height = image.size
ratio = round(width / height, 2)
gcd = math.gcd(width, height)
lcm_ratio = f'{width // gcd}:{height // gcd}'
size_info = f'Image Size: {width} x {height}, Ratio: {ratio}, {lcm_ratio}'
closest_ratio = min(aspect_ratios, key=lambda x: abs(ratio - float(x.split('*')[0]) / float(x.split('*')[1])))
recommended_width, recommended_height = map(int, closest_ratio.split('*'))
recommended_ratio = round(recommended_width / recommended_height, 2)
recommended_gcd = math.gcd(recommended_width, recommended_height)
recommended_lcm_ratio = f'{recommended_width // recommended_gcd}:{recommended_height // recommended_gcd}'
size_info = f'{width} x {height}, {ratio}, {lcm_ratio}'
size_info += f'\n{recommended_width} x {recommended_height}, {recommended_ratio}, {recommended_lcm_ratio}'
return size_info
except Exception as e:
return f'Error reading image: {e}'

View File

@ -368,6 +368,7 @@ A safer way is just to try "run_anime.bat" or "run_realistic.bat" - they should
entry_with_update.py [-h] [--listen [IP]] [--port PORT] entry_with_update.py [-h] [--listen [IP]] [--port PORT]
[--disable-header-check [ORIGIN]] [--disable-header-check [ORIGIN]]
[--web-upload-size WEB_UPLOAD_SIZE] [--web-upload-size WEB_UPLOAD_SIZE]
[--hf-mirror HF_MIRROR]
[--external-working-path PATH [PATH ...]] [--external-working-path PATH [PATH ...]]
[--output-path OUTPUT_PATH] [--temp-path TEMP_PATH] [--output-path OUTPUT_PATH] [--temp-path TEMP_PATH]
[--cache-path CACHE_PATH] [--in-browser] [--cache-path CACHE_PATH] [--in-browser]

View File

@ -1,5 +1,2 @@
torch==2.0.1 torch==2.1.0
torchvision==0.15.2 torchvision==0.16.0
torchaudio==2.0.2
torchtext==0.15.2
torchdata==0.6.1

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.4 KiB

4
tests/__init__.py Normal file
View File

@ -0,0 +1,4 @@
import sys
import pathlib
sys.path.append(pathlib.Path(f'{__file__}/../modules').parent.resolve())

81
tests/test_utils.py Normal file
View File

@ -0,0 +1,81 @@
import unittest
from modules import util
class TestUtils(unittest.TestCase):
def test_can_parse_tokens_with_lora(self):
test_cases = [
{
"input": ("some prompt, very cool, <lora:hey-lora:0.4>, cool <lora:you-lora:0.2>", [], 5, True),
"output": (
[('hey-lora.safetensors', 0.4), ('you-lora.safetensors', 0.2)], 'some prompt, very cool, cool'),
},
# Test can not exceed limit
{
"input": ("some prompt, very cool, <lora:hey-lora:0.4>, cool <lora:you-lora:0.2>", [], 1, True),
"output": (
[('hey-lora.safetensors', 0.4)],
'some prompt, very cool, cool'
),
},
# test Loras from UI take precedence over prompt
{
"input": (
"some prompt, very cool, <lora:l1:0.4>, <lora:l2:-0.2>, <lora:l3:0.3>, <lora:l4:0.5>, <lora:l6:0.24>, <lora:l7:0.1>",
[("hey-lora.safetensors", 0.4)],
5,
True
),
"output": (
[
('hey-lora.safetensors', 0.4),
('l1.safetensors', 0.4),
('l2.safetensors', -0.2),
('l3.safetensors', 0.3),
('l4.safetensors', 0.5)
],
'some prompt, very cool'
)
},
# test correct matching even if there is no space separating loras in the same token
{
"input": ("some prompt, very cool, <lora:hey-lora:0.4><lora:you-lora:0.2>", [], 3, True),
"output": (
[
('hey-lora.safetensors', 0.4),
('you-lora.safetensors', 0.2)
],
'some prompt, very cool'
),
},
# test deduplication, also selected loras are never overridden with loras in prompt
{
"input": (
"some prompt, very cool, <lora:hey-lora:0.4><lora:hey-lora:0.4><lora:you-lora:0.2>",
[('you-lora.safetensors', 0.3)],
3,
True
),
"output": (
[
('you-lora.safetensors', 0.3),
('hey-lora.safetensors', 0.4)
],
'some prompt, very cool'
),
},
{
"input": ("<lora:foo:1..2>, <lora:bar:.>, <test:1.0>, <lora:baz:+> and <lora:quux:>", [], 6, True),
"output": (
[],
'<lora:foo:1..2>, <lora:bar:.>, <test:1.0>, <lora:baz:+> and <lora:quux:>'
)
}
]
for test in test_cases:
prompt, loras, loras_limit, skip_file_check = test["input"]
expected = test["output"]
actual = util.parse_lora_references_from_prompt(prompt, loras, loras_limit=loras_limit, skip_file_check=skip_file_check)
self.assertEqual(expected, actual)

View File

@ -1,3 +1,23 @@
# [2.4.0](https://github.com/lllyasviel/Fooocus/releases/tag/v2.4.0)
* Change settings tab elements to be more compact
* Add clip skip slider
* Add select for custom VAE
* Add new style "Random Style"
* Update default anime model to animaPencilXL_v310
* Add button to reconnect the UI after Fooocus crashed without having to configure everything again (no page reload required)
* Add performance "hyper-sd" (based on [Hyper-SDXL 4 step LoRA](https://huggingface.co/ByteDance/Hyper-SD/blob/main/Hyper-SDXL-4steps-lora.safetensors))
* Add [AlignYourSteps](https://research.nvidia.com/labs/toronto-ai/AlignYourSteps/) scheduler by Nvidia, see
* Add [TCD](https://github.com/jabir-zheng/TCD) sampler and scheduler (based on sgm_uniform)
* Add NSFW image censoring (disables intermediate image preview while generating). Set config value `default_black_out_nsfw` to True to always enable.
* Add argument `--enable-describe-uov-image` to automatically describe uploaded images for upscaling
* Add inline lora prompt references with subfolder support, example prompt: `colorful bird <lora:toucan:1.2>`
* Add size and aspect ratio recommendation on image describe
* Add inpaint brush color picker, helpful when image and mask brush have the same color
* Add automated Docker image build using Github Actions on each release.
* Add full raw prompts to history logs
* Change code ownership from @lllyasviel to @mashb1t for automated issue / MR notification
# [2.3.1](https://github.com/lllyasviel/Fooocus/releases/tag/2.3.1) # [2.3.1](https://github.com/lllyasviel/Fooocus/releases/tag/2.3.1)
* Remove positive prompt from anime prefix to not reset prompt after switching presets * Remove positive prompt from anime prefix to not reset prompt after switching presets

119
webui.py
View File

@ -123,8 +123,9 @@ with shared.gradio_root:
with gr.Column(scale=3, min_width=0): with gr.Column(scale=3, min_width=0):
generate_button = gr.Button(label="Generate", value="Generate", elem_classes='type_row', elem_id='generate_button', visible=True) generate_button = gr.Button(label="Generate", value="Generate", elem_classes='type_row', elem_id='generate_button', visible=True)
reset_button = gr.Button(label="Reconnect", value="Reconnect", elem_classes='type_row', elem_id='reset_button', visible=False)
load_parameter_button = gr.Button(label="Load Parameters", value="Load Parameters", elem_classes='type_row', elem_id='load_parameter_button', visible=False) load_parameter_button = gr.Button(label="Load Parameters", value="Load Parameters", elem_classes='type_row', elem_id='load_parameter_button', visible=False)
skip_button = gr.Button(label="Skip", value="Skip", elem_classes='type_row_half', visible=False) skip_button = gr.Button(label="Skip", value="Skip", elem_classes='type_row_half', elem_id='skip_button', visible=False)
stop_button = gr.Button(label="Stop", value="Stop", elem_classes='type_row_half', elem_id='stop_button', visible=False) stop_button = gr.Button(label="Stop", value="Stop", elem_classes='type_row_half', elem_id='stop_button', visible=False)
def stop_clicked(currentTask): def stop_clicked(currentTask):
@ -151,7 +152,7 @@ with shared.gradio_root:
with gr.TabItem(label='Upscale or Variation') as uov_tab: with gr.TabItem(label='Upscale or Variation') as uov_tab:
with gr.Row(): with gr.Row():
with gr.Column(): with gr.Column():
uov_input_image = grh.Image(label='Drag above image to here', source='upload', type='numpy') uov_input_image = grh.Image(label='Image', source='upload', type='numpy', show_label=False)
with gr.Column(): with gr.Column():
uov_method = gr.Radio(label='Upscale or Variation:', choices=flags.uov_list, value=flags.disabled) uov_method = gr.Radio(label='Upscale or Variation:', choices=flags.uov_list, value=flags.disabled)
gr.HTML('<a href="https://github.com/lllyasviel/Fooocus/discussions/390" target="_blank">\U0001F4D4 Document</a>') gr.HTML('<a href="https://github.com/lllyasviel/Fooocus/discussions/390" target="_blank">\U0001F4D4 Document</a>')
@ -200,7 +201,7 @@ with shared.gradio_root:
queue=False, show_progress=False) queue=False, show_progress=False)
with gr.TabItem(label='Inpaint or Outpaint') as inpaint_tab: with gr.TabItem(label='Inpaint or Outpaint') as inpaint_tab:
with gr.Row(): with gr.Row():
inpaint_input_image = grh.Image(label='Drag inpaint or outpaint image to here', source='upload', type='numpy', tool='sketch', height=500, brush_color="#FFFFFF", elem_id='inpaint_canvas') inpaint_input_image = grh.Image(label='Image', source='upload', type='numpy', tool='sketch', height=500, brush_color="#FFFFFF", elem_id='inpaint_canvas', show_label=False)
inpaint_mask_image = grh.Image(label='Mask Upload', source='upload', type='numpy', height=500, visible=False) inpaint_mask_image = grh.Image(label='Mask Upload', source='upload', type='numpy', height=500, visible=False)
with gr.Row(): with gr.Row():
@ -213,17 +214,26 @@ with shared.gradio_root:
with gr.TabItem(label='Describe') as desc_tab: with gr.TabItem(label='Describe') as desc_tab:
with gr.Row(): with gr.Row():
with gr.Column(): with gr.Column():
desc_input_image = grh.Image(label='Drag any image to here', source='upload', type='numpy') desc_input_image = grh.Image(label='Image', source='upload', type='numpy', show_label=False)
with gr.Column(): with gr.Column():
desc_method = gr.Radio( desc_method = gr.Radio(
label='Content Type', label='Content Type',
choices=[flags.desc_type_photo, flags.desc_type_anime], choices=[flags.desc_type_photo, flags.desc_type_anime],
value=flags.desc_type_photo) value=flags.desc_type_photo)
desc_btn = gr.Button(value='Describe this Image into Prompt') desc_btn = gr.Button(value='Describe this Image into Prompt')
desc_image_size = gr.Textbox(label='Image Size and Recommended Size', elem_id='desc_image_size', visible=False)
gr.HTML('<a href="https://github.com/lllyasviel/Fooocus/discussions/1363" target="_blank">\U0001F4D4 Document</a>') gr.HTML('<a href="https://github.com/lllyasviel/Fooocus/discussions/1363" target="_blank">\U0001F4D4 Document</a>')
with gr.TabItem(label='Metadata') as load_tab:
def trigger_show_image_properties(image):
value = modules.util.get_image_size_info(image, modules.flags.sdxl_aspect_ratios)
return gr.update(value=value, visible=True)
desc_input_image.upload(trigger_show_image_properties, inputs=desc_input_image,
outputs=desc_image_size, show_progress=False, queue=False)
with gr.TabItem(label='Metadata') as metadata_tab:
with gr.Column(): with gr.Column():
metadata_input_image = grh.Image(label='Drag any image generated by Fooocus here', source='upload', type='filepath') metadata_input_image = grh.Image(label='For images created by Fooocus', source='upload', type='filepath')
metadata_json = gr.JSON(label='Metadata') metadata_json = gr.JSON(label='Metadata')
metadata_import_button = gr.Button(value='Apply Metadata') metadata_import_button = gr.Button(value='Apply Metadata')
@ -254,25 +264,40 @@ with shared.gradio_root:
inpaint_tab.select(lambda: 'inpaint', outputs=current_tab, queue=False, _js=down_js, show_progress=False) inpaint_tab.select(lambda: 'inpaint', outputs=current_tab, queue=False, _js=down_js, show_progress=False)
ip_tab.select(lambda: 'ip', outputs=current_tab, queue=False, _js=down_js, show_progress=False) ip_tab.select(lambda: 'ip', outputs=current_tab, queue=False, _js=down_js, show_progress=False)
desc_tab.select(lambda: 'desc', outputs=current_tab, queue=False, _js=down_js, show_progress=False) desc_tab.select(lambda: 'desc', outputs=current_tab, queue=False, _js=down_js, show_progress=False)
metadata_tab.select(lambda: 'metadata', outputs=current_tab, queue=False, _js=down_js, show_progress=False)
with gr.Column(scale=1, visible=modules.config.default_advanced_checkbox) as advanced_column: with gr.Column(scale=1, visible=modules.config.default_advanced_checkbox) as advanced_column:
with gr.Tab(label='Setting'): with gr.Tab(label='Setting'):
if not args_manager.args.disable_preset_selection: if not args_manager.args.disable_preset_selection:
preset_selection = gr.Radio(label='Preset', preset_selection = gr.Dropdown(label='Preset',
choices=modules.config.available_presets, choices=modules.config.available_presets,
value=args_manager.args.preset if args_manager.args.preset else "initial", value=args_manager.args.preset if args_manager.args.preset else "initial",
interactive=True) interactive=True)
performance_selection = gr.Radio(label='Performance', performance_selection = gr.Radio(label='Performance',
choices=flags.Performance.list(), choices=flags.Performance.list(),
value=modules.config.default_performance) value=modules.config.default_performance,
aspect_ratios_selection = gr.Radio(label='Aspect Ratios', choices=modules.config.available_aspect_ratios, elem_classes=['performance_selection'])
value=modules.config.default_aspect_ratio, info='width × height', with gr.Accordion(label='Aspect Ratios', open=False) as aspect_ratios_accordion:
elem_classes='aspect_ratios') aspect_ratios_selection = gr.Radio(label='Aspect Ratios', show_label=False,
choices=modules.config.available_aspect_ratios_labels,
value=modules.config.default_aspect_ratio,
info='width × height',
elem_classes='aspect_ratios')
def change_aspect_ratio(text):
import re
regex = re.compile('<.*?>')
cleaned_text = re.sub(regex, '', text)
return gr.update(label='Aspect Ratios ' + cleaned_text)
aspect_ratios_selection.change(change_aspect_ratio, inputs=aspect_ratios_selection, outputs=aspect_ratios_accordion, queue=False, show_progress=False)
shared.gradio_root.load(change_aspect_ratio, inputs=aspect_ratios_selection, outputs=aspect_ratios_accordion, queue=False, show_progress=False)
image_number = gr.Slider(label='Image Number', minimum=1, maximum=modules.config.default_max_image_number, step=1, value=modules.config.default_image_number) image_number = gr.Slider(label='Image Number', minimum=1, maximum=modules.config.default_max_image_number, step=1, value=modules.config.default_image_number)
output_format = gr.Radio(label='Output Format', output_format = gr.Radio(label='Output Format',
choices=flags.OutputFormat.list(), choices=flags.OutputFormat.list(),
value=modules.config.default_output_format) value=modules.config.default_output_format)
negative_prompt = gr.Textbox(label='Negative Prompt', show_label=True, placeholder="Type prompt here.", negative_prompt = gr.Textbox(label='Negative Prompt', show_label=True, placeholder="Type prompt here.",
info='Describing what you do not want to see.', lines=2, info='Describing what you do not want to see.', lines=2,
@ -402,10 +427,15 @@ with shared.gradio_root:
value=modules.config.default_cfg_tsnr, value=modules.config.default_cfg_tsnr,
info='Enabling Fooocus\'s implementation of CFG mimicking for TSNR ' info='Enabling Fooocus\'s implementation of CFG mimicking for TSNR '
'(effective when real CFG > mimicked CFG).') '(effective when real CFG > mimicked CFG).')
clip_skip = gr.Slider(label='CLIP Skip', minimum=1, maximum=10, step=1,
value=modules.config.default_clip_skip,
info='Bypass CLIP layers to avoid overfitting (use 1 to disable).')
sampler_name = gr.Dropdown(label='Sampler', choices=flags.sampler_list, sampler_name = gr.Dropdown(label='Sampler', choices=flags.sampler_list,
value=modules.config.default_sampler) value=modules.config.default_sampler)
scheduler_name = gr.Dropdown(label='Scheduler', choices=flags.scheduler_list, scheduler_name = gr.Dropdown(label='Scheduler', choices=flags.scheduler_list,
value=modules.config.default_scheduler) value=modules.config.default_scheduler)
vae_name = gr.Dropdown(label='VAE', choices=[modules.flags.default_vae] + modules.config.vae_filenames,
value=modules.config.default_vae, show_label=True)
generate_image_grid = gr.Checkbox(label='Generate Image Grid for Each Batch', generate_image_grid = gr.Checkbox(label='Generate Image Grid for Each Batch',
info='(Experimental) This may cause performance problems on some computers and certain internet conditions.', info='(Experimental) This may cause performance problems on some computers and certain internet conditions.',
@ -433,7 +463,8 @@ with shared.gradio_root:
overwrite_upscale_strength = gr.Slider(label='Forced Overwrite of Denoising Strength of "Upscale"', overwrite_upscale_strength = gr.Slider(label='Forced Overwrite of Denoising Strength of "Upscale"',
minimum=-1, maximum=1.0, step=0.001, value=-1, minimum=-1, maximum=1.0, step=0.001, value=-1,
info='Set as negative number to disable. For developer debugging.') info='Set as negative number to disable. For developer debugging.')
disable_preview = gr.Checkbox(label='Disable Preview', value=False, disable_preview = gr.Checkbox(label='Disable Preview', value=modules.config.default_black_out_nsfw,
interactive=not modules.config.default_black_out_nsfw,
info='Disable preview during generation.') info='Disable preview during generation.')
disable_intermediate_results = gr.Checkbox(label='Disable Intermediate Results', disable_intermediate_results = gr.Checkbox(label='Disable Intermediate Results',
value=modules.config.default_performance == flags.Performance.EXTREME_SPEED.value, value=modules.config.default_performance == flags.Performance.EXTREME_SPEED.value,
@ -444,6 +475,15 @@ with shared.gradio_root:
value=False) value=False)
read_wildcards_in_order = gr.Checkbox(label="Read wildcards in order", value=False) read_wildcards_in_order = gr.Checkbox(label="Read wildcards in order", value=False)
black_out_nsfw = gr.Checkbox(label='Black Out NSFW',
value=modules.config.default_black_out_nsfw,
interactive=not modules.config.default_black_out_nsfw,
info='Use black image if NSFW is detected.')
black_out_nsfw.change(lambda x: gr.update(value=x, interactive=not x),
inputs=black_out_nsfw, outputs=disable_preview, queue=False,
show_progress=False)
if not args_manager.args.disable_metadata: if not args_manager.args.disable_metadata:
save_metadata_to_images = gr.Checkbox(label='Save Metadata to Images', value=modules.config.default_save_metadata_to_images, save_metadata_to_images = gr.Checkbox(label='Save Metadata to Images', value=modules.config.default_save_metadata_to_images,
info='Adds parameters to generated images allowing manual regeneration.') info='Adds parameters to generated images allowing manual regeneration.')
@ -502,13 +542,20 @@ with shared.gradio_root:
inpaint_mask_upload_checkbox = gr.Checkbox(label='Enable Mask Upload', value=False) inpaint_mask_upload_checkbox = gr.Checkbox(label='Enable Mask Upload', value=False)
invert_mask_checkbox = gr.Checkbox(label='Invert Mask', value=False) invert_mask_checkbox = gr.Checkbox(label='Invert Mask', value=False)
inpaint_mask_color = gr.ColorPicker(label='Inpaint brush color', value='#FFFFFF', elem_id='inpaint_brush_color')
inpaint_ctrls = [debugging_inpaint_preprocessor, inpaint_disable_initial_latent, inpaint_engine, inpaint_ctrls = [debugging_inpaint_preprocessor, inpaint_disable_initial_latent, inpaint_engine,
inpaint_strength, inpaint_respective_field, inpaint_strength, inpaint_respective_field,
inpaint_mask_upload_checkbox, invert_mask_checkbox, inpaint_erode_or_dilate] inpaint_mask_upload_checkbox, invert_mask_checkbox, inpaint_erode_or_dilate]
inpaint_mask_upload_checkbox.change(lambda x: gr.update(visible=x), inpaint_mask_upload_checkbox.change(lambda x: gr.update(visible=x),
inputs=inpaint_mask_upload_checkbox, inputs=inpaint_mask_upload_checkbox,
outputs=inpaint_mask_image, queue=False, show_progress=False) outputs=inpaint_mask_image, queue=False,
show_progress=False)
inpaint_mask_color.change(lambda x: gr.update(brush_color=x), inputs=inpaint_mask_color,
outputs=inpaint_input_image,
queue=False, show_progress=False)
with gr.Tab(label='FreeU'): with gr.Tab(label='FreeU'):
freeu_enabled = gr.Checkbox(label='Enabled', value=False) freeu_enabled = gr.Checkbox(label='Enabled', value=False)
@ -528,6 +575,7 @@ with shared.gradio_root:
modules.config.update_files() modules.config.update_files()
results = [gr.update(choices=modules.config.model_filenames)] results = [gr.update(choices=modules.config.model_filenames)]
results += [gr.update(choices=['None'] + modules.config.model_filenames)] results += [gr.update(choices=['None'] + modules.config.model_filenames)]
results += [gr.update(choices=['None'] + modules.config.vae_filenames)]
if not args_manager.args.disable_preset_selection: if not args_manager.args.disable_preset_selection:
results += [gr.update(choices=modules.config.available_presets)] results += [gr.update(choices=modules.config.available_presets)]
for i in range(modules.config.default_max_lora_number): for i in range(modules.config.default_max_lora_number):
@ -535,7 +583,7 @@ with shared.gradio_root:
gr.update(choices=['None'] + modules.config.lora_filenames), gr.update()] gr.update(choices=['None'] + modules.config.lora_filenames), gr.update()]
return results return results
refresh_files_output = [base_model, refiner_model] refresh_files_output = [base_model, refiner_model, vae_name]
if not args_manager.args.disable_preset_selection: if not args_manager.args.disable_preset_selection:
refresh_files_output += [preset_selection] refresh_files_output += [preset_selection]
refresh_files.click(refresh_files_clicked, [], refresh_files_output + lora_ctrls, refresh_files.click(refresh_files_clicked, [], refresh_files_output + lora_ctrls,
@ -546,9 +594,9 @@ with shared.gradio_root:
load_data_outputs = [advanced_checkbox, image_number, prompt, negative_prompt, style_selections, load_data_outputs = [advanced_checkbox, image_number, prompt, negative_prompt, style_selections,
performance_selection, overwrite_step, overwrite_switch, aspect_ratios_selection, performance_selection, overwrite_step, overwrite_switch, aspect_ratios_selection,
overwrite_width, overwrite_height, guidance_scale, sharpness, adm_scaler_positive, overwrite_width, overwrite_height, guidance_scale, sharpness, adm_scaler_positive,
adm_scaler_negative, adm_scaler_end, refiner_swap_method, adaptive_cfg, base_model, adm_scaler_negative, adm_scaler_end, refiner_swap_method, adaptive_cfg, clip_skip,
refiner_model, refiner_switch, sampler_name, scheduler_name, seed_random, image_seed, base_model, refiner_model, refiner_switch, sampler_name, scheduler_name, vae_name,
generate_button, load_parameter_button] + freeu_ctrls + lora_ctrls seed_random, image_seed, generate_button, load_parameter_button] + freeu_ctrls + lora_ctrls
if not args_manager.args.disable_preset_selection: if not args_manager.args.disable_preset_selection:
def preset_selection_change(preset, is_generating): def preset_selection_change(preset, is_generating):
@ -570,7 +618,7 @@ with shared.gradio_root:
return modules.meta_parser.load_parameter_button_click(json.dumps(preset_prepared), is_generating) return modules.meta_parser.load_parameter_button_click(json.dumps(preset_prepared), is_generating)
preset_selection.change(preset_selection_change, inputs=[preset_selection, state_is_generating], outputs=load_data_outputs, queue=False, show_progress=True) \ preset_selection.change(preset_selection_change, inputs=[preset_selection, state_is_generating], outputs=load_data_outputs, queue=False, show_progress=True) \
.then(fn=style_sorter.sort_styles, inputs=style_selections, outputs=style_selections, queue=False, show_progress=False) \ .then(fn=style_sorter.sort_styles, inputs=style_selections, outputs=style_selections, queue=False, show_progress=False)
performance_selection.change(lambda x: [gr.update(interactive=not flags.Performance.has_restricted_features(x))] * 11 + performance_selection.change(lambda x: [gr.update(interactive=not flags.Performance.has_restricted_features(x))] * 11 +
[gr.update(visible=not flags.Performance.has_restricted_features(x))] * 1 + [gr.update(visible=not flags.Performance.has_restricted_features(x))] * 1 +
@ -632,9 +680,9 @@ with shared.gradio_root:
ctrls += [input_image_checkbox, current_tab] ctrls += [input_image_checkbox, current_tab]
ctrls += [uov_method, uov_input_image] ctrls += [uov_method, uov_input_image]
ctrls += [outpaint_selections, inpaint_input_image, inpaint_additional_prompt, inpaint_mask_image] ctrls += [outpaint_selections, inpaint_input_image, inpaint_additional_prompt, inpaint_mask_image]
ctrls += [disable_preview, disable_intermediate_results, disable_seed_increment] ctrls += [disable_preview, disable_intermediate_results, disable_seed_increment, black_out_nsfw]
ctrls += [adm_scaler_positive, adm_scaler_negative, adm_scaler_end, adaptive_cfg] ctrls += [adm_scaler_positive, adm_scaler_negative, adm_scaler_end, adaptive_cfg, clip_skip]
ctrls += [sampler_name, scheduler_name] ctrls += [sampler_name, scheduler_name, vae_name]
ctrls += [overwrite_step, overwrite_switch, overwrite_width, overwrite_height, overwrite_vary_strength] ctrls += [overwrite_step, overwrite_switch, overwrite_width, overwrite_height, overwrite_vary_strength]
ctrls += [overwrite_upscale_strength, mixing_image_prompt_and_vary_upscale, mixing_image_prompt_and_inpaint] ctrls += [overwrite_upscale_strength, mixing_image_prompt_and_vary_upscale, mixing_image_prompt_and_inpaint]
ctrls += [debugging_cn_preprocessor, skipping_cn_preprocessor, canny_low_threshold, canny_high_threshold] ctrls += [debugging_cn_preprocessor, skipping_cn_preprocessor, canny_low_threshold, canny_high_threshold]
@ -688,6 +736,14 @@ with shared.gradio_root:
.then(fn=update_history_link, outputs=history_link) \ .then(fn=update_history_link, outputs=history_link) \
.then(fn=lambda: None, _js='playNotification').then(fn=lambda: None, _js='refresh_grid_delayed') .then(fn=lambda: None, _js='playNotification').then(fn=lambda: None, _js='refresh_grid_delayed')
reset_button.click(lambda: [worker.AsyncTask(args=[]), False, gr.update(visible=True, interactive=True)] +
[gr.update(visible=False)] * 6 +
[gr.update(visible=True, value=[])],
outputs=[currentTask, state_is_generating, generate_button,
reset_button, stop_button, skip_button,
progress_html, progress_window, progress_gallery, gallery],
queue=False)
for notification_file in ['notification.ogg', 'notification.mp3']: for notification_file in ['notification.ogg', 'notification.mp3']:
if os.path.exists(notification_file): if os.path.exists(notification_file):
gr.Audio(interactive=False, value=notification_file, elem_id='audio_notification', visible=False) gr.Audio(interactive=False, value=notification_file, elem_id='audio_notification', visible=False)
@ -705,6 +761,15 @@ with shared.gradio_root:
desc_btn.click(trigger_describe, inputs=[desc_method, desc_input_image], desc_btn.click(trigger_describe, inputs=[desc_method, desc_input_image],
outputs=[prompt, style_selections], show_progress=True, queue=True) outputs=[prompt, style_selections], show_progress=True, queue=True)
if args_manager.args.enable_describe_uov_image:
def trigger_uov_describe(mode, img, prompt):
# keep prompt if not empty
if prompt == '':
return trigger_describe(mode, img)
return gr.update(), gr.update()
uov_input_image.upload(trigger_uov_describe, inputs=[desc_method, uov_input_image, prompt],
outputs=[prompt, style_selections], show_progress=True, queue=True)
def dump_default_english_config(): def dump_default_english_config():
from modules.localization import dump_english_config from modules.localization import dump_english_config

8
wildcards/.gitignore vendored Normal file
View File

@ -0,0 +1,8 @@
*.txt
!animal.txt
!artist.txt
!color.txt
!color_flower.txt
!extended-color.txt
!flower.txt
!nationality.txt