diff --git a/modules/async_worker.py b/modules/async_worker.py index a9d45086..944a1d49 100644 --- a/modules/async_worker.py +++ b/modules/async_worker.py @@ -125,6 +125,7 @@ class AsyncTask: enhance_prompt = args.pop() enhance_negative_prompt = args.pop() enhance_mask_model = args.pop() + enhance_mask_cloth_category = args.pop() enhance_mask_sam_model = args.pop() enhance_mask_text_threshold = args.pop() enhance_mask_box_threshold = args.pop() @@ -141,6 +142,7 @@ class AsyncTask: enhance_prompt, enhance_negative_prompt, enhance_mask_model, + enhance_mask_cloth_category, enhance_mask_sam_model, enhance_mask_text_threshold, enhance_mask_box_threshold, @@ -1334,17 +1336,20 @@ def worker(): break # inpaint for all other tabs - for enhance_mask_dino_prompt_text, enhance_prompt, enhance_negative_prompt, enhance_mask_model, enhance_mask_sam_model, enhance_mask_text_threshold, enhance_mask_box_threshold, enhance_mask_sam_max_detections, enhance_inpaint_disable_initial_latent, enhance_inpaint_engine, enhance_inpaint_strength, enhance_inpaint_respective_field, enhance_inpaint_erode_or_dilate, enhance_mask_invert in async_task.enhance_ctrls: + for enhance_mask_dino_prompt_text, enhance_prompt, enhance_negative_prompt, enhance_mask_model, enhance_mask_cloth_category, enhance_mask_sam_model, enhance_mask_text_threshold, enhance_mask_box_threshold, enhance_mask_sam_max_detections, enhance_inpaint_disable_initial_latent, enhance_inpaint_engine, enhance_inpaint_strength, enhance_inpaint_respective_field, enhance_inpaint_erode_or_dilate, enhance_mask_invert in async_task.enhance_ctrls: current_task_id += 1 current_progress = int(base_progress + (100 - preparation_steps) / float(all_steps) * (done_steps_upscaling + done_steps_inpainting)) progressbar(async_task, current_progress, f'Preparing enhancement {current_task_id + 1}/{total_count} ...') enhancement_task_start_time = time.perf_counter() + extras = {} if enhance_mask_model == 'sam': print(f'[Enhance] Searching for "{enhance_mask_dino_prompt_text}"') + elif enhance_mask_model == 'u2net_cloth_seg': + extras['cloth_category'] = enhance_mask_cloth_category mask, dino_detection_count, sam_detection_count, sam_detection_on_mask_count = generate_mask_from_image( - img, mask_model=enhance_mask_model, sam_options=SAMOptions( + img, mask_model=enhance_mask_model, extras=extras, sam_options=SAMOptions( dino_prompt=enhance_mask_dino_prompt_text, dino_box_threshold=enhance_mask_box_threshold, dino_text_threshold=enhance_mask_text_threshold, diff --git a/webui.py b/webui.py index f9cb41ec..b5229ffe 100644 --- a/webui.py +++ b/webui.py @@ -472,6 +472,7 @@ with shared.gradio_root: enhance_prompt, enhance_negative_prompt, enhance_mask_model, + enhance_mask_cloth_category, enhance_mask_sam_model, enhance_mask_text_threshold, enhance_mask_box_threshold,