fix: add missing handling for cloth category for u2net_cloth_seg

This commit is contained in:
Manuel Schmid 2024-06-26 20:21:25 +02:00
parent 358b4bd10a
commit b3a4b4e532
No known key found for this signature in database
GPG Key ID: 32C4F7569B40B84B
2 changed files with 8 additions and 2 deletions

View File

@ -125,6 +125,7 @@ class AsyncTask:
enhance_prompt = args.pop()
enhance_negative_prompt = args.pop()
enhance_mask_model = args.pop()
enhance_mask_cloth_category = args.pop()
enhance_mask_sam_model = args.pop()
enhance_mask_text_threshold = args.pop()
enhance_mask_box_threshold = args.pop()
@ -141,6 +142,7 @@ class AsyncTask:
enhance_prompt,
enhance_negative_prompt,
enhance_mask_model,
enhance_mask_cloth_category,
enhance_mask_sam_model,
enhance_mask_text_threshold,
enhance_mask_box_threshold,
@ -1334,17 +1336,20 @@ def worker():
break
# inpaint for all other tabs
for enhance_mask_dino_prompt_text, enhance_prompt, enhance_negative_prompt, enhance_mask_model, enhance_mask_sam_model, enhance_mask_text_threshold, enhance_mask_box_threshold, enhance_mask_sam_max_detections, enhance_inpaint_disable_initial_latent, enhance_inpaint_engine, enhance_inpaint_strength, enhance_inpaint_respective_field, enhance_inpaint_erode_or_dilate, enhance_mask_invert in async_task.enhance_ctrls:
for enhance_mask_dino_prompt_text, enhance_prompt, enhance_negative_prompt, enhance_mask_model, enhance_mask_cloth_category, enhance_mask_sam_model, enhance_mask_text_threshold, enhance_mask_box_threshold, enhance_mask_sam_max_detections, enhance_inpaint_disable_initial_latent, enhance_inpaint_engine, enhance_inpaint_strength, enhance_inpaint_respective_field, enhance_inpaint_erode_or_dilate, enhance_mask_invert in async_task.enhance_ctrls:
current_task_id += 1
current_progress = int(base_progress + (100 - preparation_steps) / float(all_steps) * (done_steps_upscaling + done_steps_inpainting))
progressbar(async_task, current_progress, f'Preparing enhancement {current_task_id + 1}/{total_count} ...')
enhancement_task_start_time = time.perf_counter()
extras = {}
if enhance_mask_model == 'sam':
print(f'[Enhance] Searching for "{enhance_mask_dino_prompt_text}"')
elif enhance_mask_model == 'u2net_cloth_seg':
extras['cloth_category'] = enhance_mask_cloth_category
mask, dino_detection_count, sam_detection_count, sam_detection_on_mask_count = generate_mask_from_image(
img, mask_model=enhance_mask_model, sam_options=SAMOptions(
img, mask_model=enhance_mask_model, extras=extras, sam_options=SAMOptions(
dino_prompt=enhance_mask_dino_prompt_text,
dino_box_threshold=enhance_mask_box_threshold,
dino_text_threshold=enhance_mask_text_threshold,

View File

@ -472,6 +472,7 @@ with shared.gradio_root:
enhance_prompt,
enhance_negative_prompt,
enhance_mask_model,
enhance_mask_cloth_category,
enhance_mask_sam_model,
enhance_mask_text_threshold,
enhance_mask_box_threshold,