rebase changes of main for easier handling

This commit is contained in:
Manuel Schmid 2023-11-18 12:58:05 +01:00
parent 8f9f020e8f
commit 2e1e82941d
No known key found for this signature in database
GPG Key ID: 32C4F7569B40B84B
5 changed files with 120 additions and 25 deletions

View File

@ -13,7 +13,6 @@ async_tasks = []
def worker():
global async_tasks
import traceback
import math
import numpy as np
@ -35,6 +34,8 @@ def worker():
import fooocus_extras.ip_adapter as ip_adapter
import fooocus_extras.face_crop
from modules.censor import censor_batch
from modules.sdxl_styles import apply_style, apply_wildcards, fooocus_expansion
from modules.private_logger import log
from modules.expansion import safe_str
@ -55,10 +56,14 @@ def worker():
print(f'[Fooocus] {text}')
async_task.yields.append(['preview', (number, text, None)])
def yield_result(async_task, imgs, do_not_show_finished_images=False):
def yield_result(async_task, imgs, do_not_show_finished_images=False, progressbar_index=13):
if not isinstance(imgs, list):
imgs = [imgs]
if modules.config.default_black_out_nsfw:
progressbar(async_task, progressbar_index, 'Checking for NSFW content ...')
imgs = censor_batch(imgs)
async_task.results = async_task.results + imgs
if do_not_show_finished_images:
@ -652,7 +657,7 @@ def worker():
done_steps = current_task_id * steps + step
async_task.yields.append(['preview', (
int(15.0 + 85.0 * float(done_steps) / float(all_steps)),
f'Step {step}/{total_steps} in the {current_task_id + 1}-th Sampling',
f'Sampling Image {current_task_id + 1}/{image_number}, Step {step + 1}/{total_steps} ...',
y)])
for current_task_id, task in enumerate(tasks):
@ -720,11 +725,14 @@ def worker():
d.append((f'LoRA [{n}] weight', w))
log(x, d, single_line_number=3)
yield_result(async_task, imgs, do_not_show_finished_images=len(tasks) == 1)
yield_result(async_task, imgs, do_not_show_finished_images=len(tasks) == 1, progressbar_index=int(15.0 + 85.0 * float((current_task_id + 1) * steps) / float(all_steps)))
except fcbh.model_management.InterruptProcessingException as e:
if shared.last_stop == 'skip':
print('User skipped')
continue
elif shared.last_stop == 'stop_previous':
print('Previous task stopped')
break
else:
print('User stopped')
break

54
modules/censor.py Normal file
View File

@ -0,0 +1,54 @@
# modified version of https://github.com/AUTOMATIC1111/stable-diffusion-webui-nsfw-censor/blob/master/scripts/censor.py
import numpy as np
import torch
import modules.core as core
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from transformers import AutoFeatureExtractor
from PIL import Image
safety_model_id = "CompVis/stable-diffusion-safety-checker"
safety_feature_extractor = None
safety_checker = None
def numpy_to_pil(image):
image = (image * 255).round().astype("uint8")
#pil_image = Image.fromarray(image, 'RGB')
pil_image = Image.fromarray(image)
return pil_image
# check and replace nsfw content
def check_safety(x_image):
global safety_feature_extractor, safety_checker
if safety_feature_extractor is None:
safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id)
safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id)
safety_checker_input = safety_feature_extractor(numpy_to_pil(x_image), return_tensors="pt")
x_checked_image, has_nsfw_concept = safety_checker(images=x_image, clip_input=safety_checker_input.pixel_values)
return x_checked_image, has_nsfw_concept
def censor_single(x):
x_checked_image, has_nsfw_concept = check_safety(x)
# replace image with black pixels, keep dimensions
# workaround due to different numpy / pytorch image matrix format
if has_nsfw_concept[0]:
imageshape = x_checked_image.shape
x_checked_image = np.zeros((imageshape[0], imageshape[1], 3), dtype = np.uint8)
return x_checked_image
def censor_batch(images):
images = [censor_single(image) for image in images]
return images

View File

@ -243,10 +243,15 @@ default_advanced_checkbox = get_config_item_or_set_default(
default_value=False,
validator=lambda x: isinstance(x, bool)
)
default_max_image_number = get_config_item_or_set_default(
key='default_max_image_number',
default_value=4,
validator=lambda x: isinstance(x, int) and x >= 1 and x <= 32
)
default_image_number = get_config_item_or_set_default(
key='default_image_number',
default_value=2,
validator=lambda x: isinstance(x, int) and 1 <= x <= 32
validator=lambda x: isinstance(x, int) and 1 <= x <= default_max_image_number
)
checkpoint_downloads = get_config_item_or_set_default(
key='checkpoint_downloads',
@ -303,6 +308,16 @@ default_overwrite_switch = get_config_item_or_set_default(
default_value=-1,
validator=lambda x: isinstance(x, int)
)
default_black_out_nsfw = get_config_item_or_set_default(
key='default_black_out_nsfw',
default_value=False,
validator=lambda x: isinstance(x, bool)
)
default_hide_preview_if_black_out_nsfw = get_config_item_or_set_default(
key='default_hide_preview_if_black_out_nsfw',
default_value=True,
validator=lambda x: isinstance(x, bool)
)
config_dict["default_loras"] = default_loras = default_loras[:5] + [['None', 1.0] for _ in range(5 - len(default_loras))]

View File

@ -14,4 +14,5 @@ omegaconf==2.2.3
gradio==3.41.2
pygit2==1.12.2
opencv-contrib-python==4.8.0.74
httpx==0.24.1
diffusers==0.21.4
httpx==0.24.1

View File

@ -13,6 +13,7 @@ import modules.gradio_hijack as grh
import modules.advanced_parameters as advanced_parameters
import modules.style_sorter as style_sorter
import args_manager
import fcbh.model_management as model_management
import copy
from modules.sdxl_styles import legal_style_names
@ -20,10 +21,8 @@ from modules.private_logger import get_current_html_path
from modules.ui_gradio_extensions import reload_javascript
from modules.auth import auth_enabled, check_auth
def generate_clicked(*args):
# outputs=[progress_html, progress_window, progress_gallery, gallery]
# worker_outputs=[progress_html, progress_window, progress_gallery, gallery]
execution_start_time = time.perf_counter()
task = worker.AsyncTask(args=list(args))
finished = False
@ -31,7 +30,9 @@ def generate_clicked(*args):
yield gr.update(visible=True, value=modules.html.make_progress_html(1, 'Waiting for task to start ...')), \
gr.update(visible=True, value=None), \
gr.update(visible=False, value=None), \
gr.update(visible=False)
gr.update(visible=False), \
gr.update(visible=True, interactive=True), \
gr.update(visible=True, interactive=True)
worker.async_tasks.append(task)
@ -49,19 +50,25 @@ def generate_clicked(*args):
percentage, title, image = product
yield gr.update(visible=True, value=modules.html.make_progress_html(percentage, title)), \
gr.update(visible=True, value=image) if image is not None else gr.update(), \
gr.update(visible=True, value=image) if image is not None and not (modules.config.default_black_out_nsfw and modules.config.default_hide_preview_if_black_out_nsfw) else gr.update(), \
gr.update(), \
gr.update(visible=False)
gr.update(visible=False), \
gr.update(visible=True), \
gr.update(visible=True)
if flag == 'results':
yield gr.update(visible=True), \
gr.update(visible=True), \
gr.update(visible=True, value=product), \
gr.update(visible=False)
gr.update(visible=False), \
gr.update(visible=True), \
gr.update(visible=True)
if flag == 'finish':
yield gr.update(visible=False), \
gr.update(visible=False), \
gr.update(visible=False), \
gr.update(visible=True, value=product)
gr.update(visible=True, value=product), \
gr.update(visible=False), \
gr.update(visible=False)
finished = True
execution_time = time.perf_counter() - execution_start_time
@ -108,13 +115,11 @@ with shared.gradio_root:
stop_button = gr.Button(label="Stop", value="Stop", elem_classes='type_row_half', elem_id='stop_button', visible=False)
def stop_clicked():
import fcbh.model_management as model_management
shared.last_stop = 'stop'
model_management.interrupt_current_processing()
return [gr.update(interactive=False)] * 2
def skip_clicked():
import fcbh.model_management as model_management
shared.last_stop = 'skip'
model_management.interrupt_current_processing()
return
@ -204,7 +209,7 @@ with shared.gradio_root:
aspect_ratios_selection = gr.Radio(label='Aspect Ratios', choices=modules.config.available_aspect_ratios,
value=modules.config.default_aspect_ratio, info='width × height',
elem_classes='aspect_ratios')
image_number = gr.Slider(label='Image Number', minimum=1, maximum=32, step=1, value=modules.config.default_image_number)
image_number = gr.Slider(label='Image Number', minimum=1, maximum=modules.config.default_max_image_number, step=1, value=modules.config.default_image_number)
negative_prompt = gr.Textbox(label='Negative Prompt', show_label=True, placeholder="Type prompt here.",
info='Describing what you do not want to see.', lines=2,
elem_id='negative_prompt',
@ -414,6 +419,24 @@ with shared.gradio_root:
model_refresh.click(model_refresh_clicked, [], [base_model, refiner_model] + lora_ctrls,
queue=False, show_progress=False)
with gr.Tab(label='Audio'):
play_notification = gr.Checkbox(label='Play notification after rendering', value=False)
notification_file = 'notification.mp3'
if os.path.exists(notification_file):
notification = gr.State(value=notification_file)
notification_input = gr.Audio(label='Notification', interactive=True, elem_id='audio_notification', visible=False, show_edit_button=False)
def play_notification_checked(r, notification):
return gr.update(visible=r, value=notification if r else None)
def notification_input_changed(notification_input, notification):
if notification_input:
notification = notification_input
return notification
play_notification.change(fn=play_notification_checked, inputs=[play_notification, notification], outputs=[notification_input], queue=False)
notification_input.change(fn=notification_input_changed, inputs=[notification_input, notification], outputs=[notification], queue=False)
performance_selection.change(lambda x: [gr.update(interactive=x != 'Extreme Speed')] * 11,
inputs=performance_selection,
outputs=[
@ -437,19 +460,13 @@ with shared.gradio_root:
ctrls += [outpaint_selections, inpaint_input_image]
ctrls += ip_ctrls
generate_button.click(lambda: (gr.update(visible=True, interactive=True), gr.update(visible=True, interactive=True), gr.update(visible=False), []), outputs=[stop_button, skip_button, generate_button, gallery]) \
generate_button.click(lambda: (gr.update(visible=True, interactive=False), gr.update(visible=True, interactive=False), gr.update(visible=False), []), outputs=[stop_button, skip_button, generate_button, gallery]) \
.then(fn=refresh_seed, inputs=[seed_random, image_seed], outputs=image_seed) \
.then(advanced_parameters.set_all_advanced_parameters, inputs=adps) \
.then(fn=generate_clicked, inputs=ctrls, outputs=[progress_html, progress_window, progress_gallery, gallery]) \
.then(fn=generate_clicked, inputs=ctrls, outputs=[progress_html, progress_window, progress_gallery, gallery, stop_button, skip_button]) \
.then(lambda: (gr.update(visible=True), gr.update(visible=False), gr.update(visible=False)), outputs=[generate_button, stop_button, skip_button]) \
.then(fn=lambda: None, _js='playNotification').then(fn=lambda: None, _js='refresh_grid_delayed')
for notification_file in ['notification.ogg', 'notification.mp3']:
if os.path.exists(notification_file):
gr.Audio(interactive=False, value=notification_file, elem_id='audio_notification', visible=False)
break
def dump_default_english_config():
from modules.localization import dump_english_config
dump_english_config(grh.all_components)