rebase changes of main for easier handling
This commit is contained in:
parent
8f9f020e8f
commit
2e1e82941d
|
|
@ -13,7 +13,6 @@ async_tasks = []
|
|||
|
||||
def worker():
|
||||
global async_tasks
|
||||
|
||||
import traceback
|
||||
import math
|
||||
import numpy as np
|
||||
|
|
@ -35,6 +34,8 @@ def worker():
|
|||
import fooocus_extras.ip_adapter as ip_adapter
|
||||
import fooocus_extras.face_crop
|
||||
|
||||
from modules.censor import censor_batch
|
||||
|
||||
from modules.sdxl_styles import apply_style, apply_wildcards, fooocus_expansion
|
||||
from modules.private_logger import log
|
||||
from modules.expansion import safe_str
|
||||
|
|
@ -55,10 +56,14 @@ def worker():
|
|||
print(f'[Fooocus] {text}')
|
||||
async_task.yields.append(['preview', (number, text, None)])
|
||||
|
||||
def yield_result(async_task, imgs, do_not_show_finished_images=False):
|
||||
def yield_result(async_task, imgs, do_not_show_finished_images=False, progressbar_index=13):
|
||||
if not isinstance(imgs, list):
|
||||
imgs = [imgs]
|
||||
|
||||
if modules.config.default_black_out_nsfw:
|
||||
progressbar(async_task, progressbar_index, 'Checking for NSFW content ...')
|
||||
imgs = censor_batch(imgs)
|
||||
|
||||
async_task.results = async_task.results + imgs
|
||||
|
||||
if do_not_show_finished_images:
|
||||
|
|
@ -652,7 +657,7 @@ def worker():
|
|||
done_steps = current_task_id * steps + step
|
||||
async_task.yields.append(['preview', (
|
||||
int(15.0 + 85.0 * float(done_steps) / float(all_steps)),
|
||||
f'Step {step}/{total_steps} in the {current_task_id + 1}-th Sampling',
|
||||
f'Sampling Image {current_task_id + 1}/{image_number}, Step {step + 1}/{total_steps} ...',
|
||||
y)])
|
||||
|
||||
for current_task_id, task in enumerate(tasks):
|
||||
|
|
@ -720,11 +725,14 @@ def worker():
|
|||
d.append((f'LoRA [{n}] weight', w))
|
||||
log(x, d, single_line_number=3)
|
||||
|
||||
yield_result(async_task, imgs, do_not_show_finished_images=len(tasks) == 1)
|
||||
yield_result(async_task, imgs, do_not_show_finished_images=len(tasks) == 1, progressbar_index=int(15.0 + 85.0 * float((current_task_id + 1) * steps) / float(all_steps)))
|
||||
except fcbh.model_management.InterruptProcessingException as e:
|
||||
if shared.last_stop == 'skip':
|
||||
print('User skipped')
|
||||
continue
|
||||
elif shared.last_stop == 'stop_previous':
|
||||
print('Previous task stopped')
|
||||
break
|
||||
else:
|
||||
print('User stopped')
|
||||
break
|
||||
|
|
|
|||
|
|
@ -0,0 +1,54 @@
|
|||
# modified version of https://github.com/AUTOMATIC1111/stable-diffusion-webui-nsfw-censor/blob/master/scripts/censor.py
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import modules.core as core
|
||||
|
||||
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||
from transformers import AutoFeatureExtractor
|
||||
from PIL import Image
|
||||
|
||||
safety_model_id = "CompVis/stable-diffusion-safety-checker"
|
||||
safety_feature_extractor = None
|
||||
safety_checker = None
|
||||
|
||||
|
||||
def numpy_to_pil(image):
|
||||
image = (image * 255).round().astype("uint8")
|
||||
|
||||
#pil_image = Image.fromarray(image, 'RGB')
|
||||
pil_image = Image.fromarray(image)
|
||||
|
||||
return pil_image
|
||||
|
||||
|
||||
# check and replace nsfw content
|
||||
def check_safety(x_image):
|
||||
global safety_feature_extractor, safety_checker
|
||||
|
||||
if safety_feature_extractor is None:
|
||||
safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id)
|
||||
safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id)
|
||||
|
||||
safety_checker_input = safety_feature_extractor(numpy_to_pil(x_image), return_tensors="pt")
|
||||
x_checked_image, has_nsfw_concept = safety_checker(images=x_image, clip_input=safety_checker_input.pixel_values)
|
||||
|
||||
return x_checked_image, has_nsfw_concept
|
||||
|
||||
|
||||
def censor_single(x):
|
||||
x_checked_image, has_nsfw_concept = check_safety(x)
|
||||
|
||||
# replace image with black pixels, keep dimensions
|
||||
# workaround due to different numpy / pytorch image matrix format
|
||||
if has_nsfw_concept[0]:
|
||||
imageshape = x_checked_image.shape
|
||||
x_checked_image = np.zeros((imageshape[0], imageshape[1], 3), dtype = np.uint8)
|
||||
|
||||
return x_checked_image
|
||||
|
||||
|
||||
def censor_batch(images):
|
||||
images = [censor_single(image) for image in images]
|
||||
|
||||
return images
|
||||
|
|
@ -243,10 +243,15 @@ default_advanced_checkbox = get_config_item_or_set_default(
|
|||
default_value=False,
|
||||
validator=lambda x: isinstance(x, bool)
|
||||
)
|
||||
default_max_image_number = get_config_item_or_set_default(
|
||||
key='default_max_image_number',
|
||||
default_value=4,
|
||||
validator=lambda x: isinstance(x, int) and x >= 1 and x <= 32
|
||||
)
|
||||
default_image_number = get_config_item_or_set_default(
|
||||
key='default_image_number',
|
||||
default_value=2,
|
||||
validator=lambda x: isinstance(x, int) and 1 <= x <= 32
|
||||
validator=lambda x: isinstance(x, int) and 1 <= x <= default_max_image_number
|
||||
)
|
||||
checkpoint_downloads = get_config_item_or_set_default(
|
||||
key='checkpoint_downloads',
|
||||
|
|
@ -303,6 +308,16 @@ default_overwrite_switch = get_config_item_or_set_default(
|
|||
default_value=-1,
|
||||
validator=lambda x: isinstance(x, int)
|
||||
)
|
||||
default_black_out_nsfw = get_config_item_or_set_default(
|
||||
key='default_black_out_nsfw',
|
||||
default_value=False,
|
||||
validator=lambda x: isinstance(x, bool)
|
||||
)
|
||||
default_hide_preview_if_black_out_nsfw = get_config_item_or_set_default(
|
||||
key='default_hide_preview_if_black_out_nsfw',
|
||||
default_value=True,
|
||||
validator=lambda x: isinstance(x, bool)
|
||||
)
|
||||
|
||||
config_dict["default_loras"] = default_loras = default_loras[:5] + [['None', 1.0] for _ in range(5 - len(default_loras))]
|
||||
|
||||
|
|
|
|||
|
|
@ -14,4 +14,5 @@ omegaconf==2.2.3
|
|||
gradio==3.41.2
|
||||
pygit2==1.12.2
|
||||
opencv-contrib-python==4.8.0.74
|
||||
httpx==0.24.1
|
||||
diffusers==0.21.4
|
||||
httpx==0.24.1
|
||||
55
webui.py
55
webui.py
|
|
@ -13,6 +13,7 @@ import modules.gradio_hijack as grh
|
|||
import modules.advanced_parameters as advanced_parameters
|
||||
import modules.style_sorter as style_sorter
|
||||
import args_manager
|
||||
import fcbh.model_management as model_management
|
||||
import copy
|
||||
|
||||
from modules.sdxl_styles import legal_style_names
|
||||
|
|
@ -20,10 +21,8 @@ from modules.private_logger import get_current_html_path
|
|||
from modules.ui_gradio_extensions import reload_javascript
|
||||
from modules.auth import auth_enabled, check_auth
|
||||
|
||||
|
||||
def generate_clicked(*args):
|
||||
# outputs=[progress_html, progress_window, progress_gallery, gallery]
|
||||
|
||||
# worker_outputs=[progress_html, progress_window, progress_gallery, gallery]
|
||||
execution_start_time = time.perf_counter()
|
||||
task = worker.AsyncTask(args=list(args))
|
||||
finished = False
|
||||
|
|
@ -31,7 +30,9 @@ def generate_clicked(*args):
|
|||
yield gr.update(visible=True, value=modules.html.make_progress_html(1, 'Waiting for task to start ...')), \
|
||||
gr.update(visible=True, value=None), \
|
||||
gr.update(visible=False, value=None), \
|
||||
gr.update(visible=False)
|
||||
gr.update(visible=False), \
|
||||
gr.update(visible=True, interactive=True), \
|
||||
gr.update(visible=True, interactive=True)
|
||||
|
||||
worker.async_tasks.append(task)
|
||||
|
||||
|
|
@ -49,19 +50,25 @@ def generate_clicked(*args):
|
|||
|
||||
percentage, title, image = product
|
||||
yield gr.update(visible=True, value=modules.html.make_progress_html(percentage, title)), \
|
||||
gr.update(visible=True, value=image) if image is not None else gr.update(), \
|
||||
gr.update(visible=True, value=image) if image is not None and not (modules.config.default_black_out_nsfw and modules.config.default_hide_preview_if_black_out_nsfw) else gr.update(), \
|
||||
gr.update(), \
|
||||
gr.update(visible=False)
|
||||
gr.update(visible=False), \
|
||||
gr.update(visible=True), \
|
||||
gr.update(visible=True)
|
||||
if flag == 'results':
|
||||
yield gr.update(visible=True), \
|
||||
gr.update(visible=True), \
|
||||
gr.update(visible=True, value=product), \
|
||||
gr.update(visible=False)
|
||||
gr.update(visible=False), \
|
||||
gr.update(visible=True), \
|
||||
gr.update(visible=True)
|
||||
if flag == 'finish':
|
||||
yield gr.update(visible=False), \
|
||||
gr.update(visible=False), \
|
||||
gr.update(visible=False), \
|
||||
gr.update(visible=True, value=product)
|
||||
gr.update(visible=True, value=product), \
|
||||
gr.update(visible=False), \
|
||||
gr.update(visible=False)
|
||||
finished = True
|
||||
|
||||
execution_time = time.perf_counter() - execution_start_time
|
||||
|
|
@ -108,13 +115,11 @@ with shared.gradio_root:
|
|||
stop_button = gr.Button(label="Stop", value="Stop", elem_classes='type_row_half', elem_id='stop_button', visible=False)
|
||||
|
||||
def stop_clicked():
|
||||
import fcbh.model_management as model_management
|
||||
shared.last_stop = 'stop'
|
||||
model_management.interrupt_current_processing()
|
||||
return [gr.update(interactive=False)] * 2
|
||||
|
||||
def skip_clicked():
|
||||
import fcbh.model_management as model_management
|
||||
shared.last_stop = 'skip'
|
||||
model_management.interrupt_current_processing()
|
||||
return
|
||||
|
|
@ -204,7 +209,7 @@ with shared.gradio_root:
|
|||
aspect_ratios_selection = gr.Radio(label='Aspect Ratios', choices=modules.config.available_aspect_ratios,
|
||||
value=modules.config.default_aspect_ratio, info='width × height',
|
||||
elem_classes='aspect_ratios')
|
||||
image_number = gr.Slider(label='Image Number', minimum=1, maximum=32, step=1, value=modules.config.default_image_number)
|
||||
image_number = gr.Slider(label='Image Number', minimum=1, maximum=modules.config.default_max_image_number, step=1, value=modules.config.default_image_number)
|
||||
negative_prompt = gr.Textbox(label='Negative Prompt', show_label=True, placeholder="Type prompt here.",
|
||||
info='Describing what you do not want to see.', lines=2,
|
||||
elem_id='negative_prompt',
|
||||
|
|
@ -414,6 +419,24 @@ with shared.gradio_root:
|
|||
model_refresh.click(model_refresh_clicked, [], [base_model, refiner_model] + lora_ctrls,
|
||||
queue=False, show_progress=False)
|
||||
|
||||
with gr.Tab(label='Audio'):
|
||||
play_notification = gr.Checkbox(label='Play notification after rendering', value=False)
|
||||
notification_file = 'notification.mp3'
|
||||
if os.path.exists(notification_file):
|
||||
notification = gr.State(value=notification_file)
|
||||
notification_input = gr.Audio(label='Notification', interactive=True, elem_id='audio_notification', visible=False, show_edit_button=False)
|
||||
|
||||
def play_notification_checked(r, notification):
|
||||
return gr.update(visible=r, value=notification if r else None)
|
||||
|
||||
def notification_input_changed(notification_input, notification):
|
||||
if notification_input:
|
||||
notification = notification_input
|
||||
return notification
|
||||
|
||||
play_notification.change(fn=play_notification_checked, inputs=[play_notification, notification], outputs=[notification_input], queue=False)
|
||||
notification_input.change(fn=notification_input_changed, inputs=[notification_input, notification], outputs=[notification], queue=False)
|
||||
|
||||
performance_selection.change(lambda x: [gr.update(interactive=x != 'Extreme Speed')] * 11,
|
||||
inputs=performance_selection,
|
||||
outputs=[
|
||||
|
|
@ -437,19 +460,13 @@ with shared.gradio_root:
|
|||
ctrls += [outpaint_selections, inpaint_input_image]
|
||||
ctrls += ip_ctrls
|
||||
|
||||
generate_button.click(lambda: (gr.update(visible=True, interactive=True), gr.update(visible=True, interactive=True), gr.update(visible=False), []), outputs=[stop_button, skip_button, generate_button, gallery]) \
|
||||
generate_button.click(lambda: (gr.update(visible=True, interactive=False), gr.update(visible=True, interactive=False), gr.update(visible=False), []), outputs=[stop_button, skip_button, generate_button, gallery]) \
|
||||
.then(fn=refresh_seed, inputs=[seed_random, image_seed], outputs=image_seed) \
|
||||
.then(advanced_parameters.set_all_advanced_parameters, inputs=adps) \
|
||||
.then(fn=generate_clicked, inputs=ctrls, outputs=[progress_html, progress_window, progress_gallery, gallery]) \
|
||||
.then(fn=generate_clicked, inputs=ctrls, outputs=[progress_html, progress_window, progress_gallery, gallery, stop_button, skip_button]) \
|
||||
.then(lambda: (gr.update(visible=True), gr.update(visible=False), gr.update(visible=False)), outputs=[generate_button, stop_button, skip_button]) \
|
||||
.then(fn=lambda: None, _js='playNotification').then(fn=lambda: None, _js='refresh_grid_delayed')
|
||||
|
||||
for notification_file in ['notification.ogg', 'notification.mp3']:
|
||||
if os.path.exists(notification_file):
|
||||
gr.Audio(interactive=False, value=notification_file, elem_id='audio_notification', visible=False)
|
||||
break
|
||||
|
||||
|
||||
def dump_default_english_config():
|
||||
from modules.localization import dump_english_config
|
||||
dump_english_config(grh.all_components)
|
||||
|
|
|
|||
Loading…
Reference in New Issue