From 2e1e82941d0434306f3f4c240a69b12ac38bd2d1 Mon Sep 17 00:00:00 2001 From: Manuel Schmid Date: Sat, 18 Nov 2023 12:58:05 +0100 Subject: [PATCH] rebase changes of main for easier handling --- modules/async_worker.py | 16 +++++++++--- modules/censor.py | 54 ++++++++++++++++++++++++++++++++++++++ modules/config.py | 17 +++++++++++- requirements_versions.txt | 3 ++- webui.py | 55 +++++++++++++++++++++++++-------------- 5 files changed, 120 insertions(+), 25 deletions(-) create mode 100644 modules/censor.py diff --git a/modules/async_worker.py b/modules/async_worker.py index b4207591..65618038 100644 --- a/modules/async_worker.py +++ b/modules/async_worker.py @@ -13,7 +13,6 @@ async_tasks = [] def worker(): global async_tasks - import traceback import math import numpy as np @@ -35,6 +34,8 @@ def worker(): import fooocus_extras.ip_adapter as ip_adapter import fooocus_extras.face_crop + from modules.censor import censor_batch + from modules.sdxl_styles import apply_style, apply_wildcards, fooocus_expansion from modules.private_logger import log from modules.expansion import safe_str @@ -55,10 +56,14 @@ def worker(): print(f'[Fooocus] {text}') async_task.yields.append(['preview', (number, text, None)]) - def yield_result(async_task, imgs, do_not_show_finished_images=False): + def yield_result(async_task, imgs, do_not_show_finished_images=False, progressbar_index=13): if not isinstance(imgs, list): imgs = [imgs] + if modules.config.default_black_out_nsfw: + progressbar(async_task, progressbar_index, 'Checking for NSFW content ...') + imgs = censor_batch(imgs) + async_task.results = async_task.results + imgs if do_not_show_finished_images: @@ -652,7 +657,7 @@ def worker(): done_steps = current_task_id * steps + step async_task.yields.append(['preview', ( int(15.0 + 85.0 * float(done_steps) / float(all_steps)), - f'Step {step}/{total_steps} in the {current_task_id + 1}-th Sampling', + f'Sampling Image {current_task_id + 1}/{image_number}, Step {step + 1}/{total_steps} ...', y)]) for current_task_id, task in enumerate(tasks): @@ -720,11 +725,14 @@ def worker(): d.append((f'LoRA [{n}] weight', w)) log(x, d, single_line_number=3) - yield_result(async_task, imgs, do_not_show_finished_images=len(tasks) == 1) + yield_result(async_task, imgs, do_not_show_finished_images=len(tasks) == 1, progressbar_index=int(15.0 + 85.0 * float((current_task_id + 1) * steps) / float(all_steps))) except fcbh.model_management.InterruptProcessingException as e: if shared.last_stop == 'skip': print('User skipped') continue + elif shared.last_stop == 'stop_previous': + print('Previous task stopped') + break else: print('User stopped') break diff --git a/modules/censor.py b/modules/censor.py new file mode 100644 index 00000000..fac6db09 --- /dev/null +++ b/modules/censor.py @@ -0,0 +1,54 @@ +# modified version of https://github.com/AUTOMATIC1111/stable-diffusion-webui-nsfw-censor/blob/master/scripts/censor.py + +import numpy as np +import torch +import modules.core as core + +from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker +from transformers import AutoFeatureExtractor +from PIL import Image + +safety_model_id = "CompVis/stable-diffusion-safety-checker" +safety_feature_extractor = None +safety_checker = None + + +def numpy_to_pil(image): + image = (image * 255).round().astype("uint8") + + #pil_image = Image.fromarray(image, 'RGB') + pil_image = Image.fromarray(image) + + return pil_image + + +# check and replace nsfw content +def check_safety(x_image): + global safety_feature_extractor, safety_checker + + if safety_feature_extractor is None: + safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id) + safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id) + + safety_checker_input = safety_feature_extractor(numpy_to_pil(x_image), return_tensors="pt") + x_checked_image, has_nsfw_concept = safety_checker(images=x_image, clip_input=safety_checker_input.pixel_values) + + return x_checked_image, has_nsfw_concept + + +def censor_single(x): + x_checked_image, has_nsfw_concept = check_safety(x) + + # replace image with black pixels, keep dimensions + # workaround due to different numpy / pytorch image matrix format + if has_nsfw_concept[0]: + imageshape = x_checked_image.shape + x_checked_image = np.zeros((imageshape[0], imageshape[1], 3), dtype = np.uint8) + + return x_checked_image + + +def censor_batch(images): + images = [censor_single(image) for image in images] + + return images diff --git a/modules/config.py b/modules/config.py index 43d79188..24f8bd25 100644 --- a/modules/config.py +++ b/modules/config.py @@ -243,10 +243,15 @@ default_advanced_checkbox = get_config_item_or_set_default( default_value=False, validator=lambda x: isinstance(x, bool) ) +default_max_image_number = get_config_item_or_set_default( + key='default_max_image_number', + default_value=4, + validator=lambda x: isinstance(x, int) and x >= 1 and x <= 32 +) default_image_number = get_config_item_or_set_default( key='default_image_number', default_value=2, - validator=lambda x: isinstance(x, int) and 1 <= x <= 32 + validator=lambda x: isinstance(x, int) and 1 <= x <= default_max_image_number ) checkpoint_downloads = get_config_item_or_set_default( key='checkpoint_downloads', @@ -303,6 +308,16 @@ default_overwrite_switch = get_config_item_or_set_default( default_value=-1, validator=lambda x: isinstance(x, int) ) +default_black_out_nsfw = get_config_item_or_set_default( + key='default_black_out_nsfw', + default_value=False, + validator=lambda x: isinstance(x, bool) +) +default_hide_preview_if_black_out_nsfw = get_config_item_or_set_default( + key='default_hide_preview_if_black_out_nsfw', + default_value=True, + validator=lambda x: isinstance(x, bool) +) config_dict["default_loras"] = default_loras = default_loras[:5] + [['None', 1.0] for _ in range(5 - len(default_loras))] diff --git a/requirements_versions.txt b/requirements_versions.txt index 5d5af5d6..4091e6c0 100644 --- a/requirements_versions.txt +++ b/requirements_versions.txt @@ -14,4 +14,5 @@ omegaconf==2.2.3 gradio==3.41.2 pygit2==1.12.2 opencv-contrib-python==4.8.0.74 -httpx==0.24.1 +diffusers==0.21.4 +httpx==0.24.1 \ No newline at end of file diff --git a/webui.py b/webui.py index d1381d5e..6134b213 100644 --- a/webui.py +++ b/webui.py @@ -13,6 +13,7 @@ import modules.gradio_hijack as grh import modules.advanced_parameters as advanced_parameters import modules.style_sorter as style_sorter import args_manager +import fcbh.model_management as model_management import copy from modules.sdxl_styles import legal_style_names @@ -20,10 +21,8 @@ from modules.private_logger import get_current_html_path from modules.ui_gradio_extensions import reload_javascript from modules.auth import auth_enabled, check_auth - def generate_clicked(*args): - # outputs=[progress_html, progress_window, progress_gallery, gallery] - + # worker_outputs=[progress_html, progress_window, progress_gallery, gallery] execution_start_time = time.perf_counter() task = worker.AsyncTask(args=list(args)) finished = False @@ -31,7 +30,9 @@ def generate_clicked(*args): yield gr.update(visible=True, value=modules.html.make_progress_html(1, 'Waiting for task to start ...')), \ gr.update(visible=True, value=None), \ gr.update(visible=False, value=None), \ - gr.update(visible=False) + gr.update(visible=False), \ + gr.update(visible=True, interactive=True), \ + gr.update(visible=True, interactive=True) worker.async_tasks.append(task) @@ -49,19 +50,25 @@ def generate_clicked(*args): percentage, title, image = product yield gr.update(visible=True, value=modules.html.make_progress_html(percentage, title)), \ - gr.update(visible=True, value=image) if image is not None else gr.update(), \ + gr.update(visible=True, value=image) if image is not None and not (modules.config.default_black_out_nsfw and modules.config.default_hide_preview_if_black_out_nsfw) else gr.update(), \ gr.update(), \ - gr.update(visible=False) + gr.update(visible=False), \ + gr.update(visible=True), \ + gr.update(visible=True) if flag == 'results': yield gr.update(visible=True), \ gr.update(visible=True), \ gr.update(visible=True, value=product), \ - gr.update(visible=False) + gr.update(visible=False), \ + gr.update(visible=True), \ + gr.update(visible=True) if flag == 'finish': yield gr.update(visible=False), \ gr.update(visible=False), \ gr.update(visible=False), \ - gr.update(visible=True, value=product) + gr.update(visible=True, value=product), \ + gr.update(visible=False), \ + gr.update(visible=False) finished = True execution_time = time.perf_counter() - execution_start_time @@ -108,13 +115,11 @@ with shared.gradio_root: stop_button = gr.Button(label="Stop", value="Stop", elem_classes='type_row_half', elem_id='stop_button', visible=False) def stop_clicked(): - import fcbh.model_management as model_management shared.last_stop = 'stop' model_management.interrupt_current_processing() return [gr.update(interactive=False)] * 2 def skip_clicked(): - import fcbh.model_management as model_management shared.last_stop = 'skip' model_management.interrupt_current_processing() return @@ -204,7 +209,7 @@ with shared.gradio_root: aspect_ratios_selection = gr.Radio(label='Aspect Ratios', choices=modules.config.available_aspect_ratios, value=modules.config.default_aspect_ratio, info='width × height', elem_classes='aspect_ratios') - image_number = gr.Slider(label='Image Number', minimum=1, maximum=32, step=1, value=modules.config.default_image_number) + image_number = gr.Slider(label='Image Number', minimum=1, maximum=modules.config.default_max_image_number, step=1, value=modules.config.default_image_number) negative_prompt = gr.Textbox(label='Negative Prompt', show_label=True, placeholder="Type prompt here.", info='Describing what you do not want to see.', lines=2, elem_id='negative_prompt', @@ -414,6 +419,24 @@ with shared.gradio_root: model_refresh.click(model_refresh_clicked, [], [base_model, refiner_model] + lora_ctrls, queue=False, show_progress=False) + with gr.Tab(label='Audio'): + play_notification = gr.Checkbox(label='Play notification after rendering', value=False) + notification_file = 'notification.mp3' + if os.path.exists(notification_file): + notification = gr.State(value=notification_file) + notification_input = gr.Audio(label='Notification', interactive=True, elem_id='audio_notification', visible=False, show_edit_button=False) + + def play_notification_checked(r, notification): + return gr.update(visible=r, value=notification if r else None) + + def notification_input_changed(notification_input, notification): + if notification_input: + notification = notification_input + return notification + + play_notification.change(fn=play_notification_checked, inputs=[play_notification, notification], outputs=[notification_input], queue=False) + notification_input.change(fn=notification_input_changed, inputs=[notification_input, notification], outputs=[notification], queue=False) + performance_selection.change(lambda x: [gr.update(interactive=x != 'Extreme Speed')] * 11, inputs=performance_selection, outputs=[ @@ -437,19 +460,13 @@ with shared.gradio_root: ctrls += [outpaint_selections, inpaint_input_image] ctrls += ip_ctrls - generate_button.click(lambda: (gr.update(visible=True, interactive=True), gr.update(visible=True, interactive=True), gr.update(visible=False), []), outputs=[stop_button, skip_button, generate_button, gallery]) \ + generate_button.click(lambda: (gr.update(visible=True, interactive=False), gr.update(visible=True, interactive=False), gr.update(visible=False), []), outputs=[stop_button, skip_button, generate_button, gallery]) \ .then(fn=refresh_seed, inputs=[seed_random, image_seed], outputs=image_seed) \ .then(advanced_parameters.set_all_advanced_parameters, inputs=adps) \ - .then(fn=generate_clicked, inputs=ctrls, outputs=[progress_html, progress_window, progress_gallery, gallery]) \ + .then(fn=generate_clicked, inputs=ctrls, outputs=[progress_html, progress_window, progress_gallery, gallery, stop_button, skip_button]) \ .then(lambda: (gr.update(visible=True), gr.update(visible=False), gr.update(visible=False)), outputs=[generate_button, stop_button, skip_button]) \ .then(fn=lambda: None, _js='playNotification').then(fn=lambda: None, _js='refresh_grid_delayed') - for notification_file in ['notification.ogg', 'notification.mp3']: - if os.path.exists(notification_file): - gr.Audio(interactive=False, value=notification_file, elem_id='audio_notification', visible=False) - break - - def dump_default_english_config(): from modules.localization import dump_english_config dump_english_config(grh.all_components)