fix: process other models than sam when using enhance

This commit is contained in:
Manuel Schmid 2024-06-16 21:05:10 +02:00
parent ff3418876d
commit 9c93c18d0b
No known key found for this signature in database
GPG Key ID: 32C4F7569B40B84B
3 changed files with 12 additions and 8 deletions

View File

@ -56,7 +56,7 @@ def generate_mask_from_image(image: np.ndarray, mask_model: str = 'sam', extras=
if 'image' in image:
image = image['image']
if mask_model != 'sam' and sam_options is None:
if mask_model != 'sam' or sam_options is None:
result = remove(
image,
session=new_session(mask_model, **extras),
@ -66,8 +66,6 @@ def generate_mask_from_image(image: np.ndarray, mask_model: str = 'sam', extras=
return result, dino_detection_count, sam_detection_count, sam_detection_on_mask_count
assert sam_options is not None
detections, boxes, logits, phrases = default_groundingdino(
image=image,
caption=sam_options.dino_prompt,

View File

@ -118,6 +118,7 @@ class AsyncTask:
enhance_mask_dino_prompt_text = args.pop()
enhance_prompt = args.pop()
enhance_negative_prompt = args.pop()
enhance_mask_model = args.pop()
enhance_mask_sam_model = args.pop()
enhance_mask_text_threshold = args.pop()
enhance_mask_box_threshold = args.pop()
@ -131,6 +132,7 @@ class AsyncTask:
enhance_mask_dino_prompt_text,
enhance_prompt,
enhance_negative_prompt,
enhance_mask_model,
enhance_mask_sam_model,
enhance_mask_text_threshold,
enhance_mask_box_threshold,
@ -1080,9 +1082,11 @@ def worker():
progressbar(async_task, current_progress, 'Processing enhance ...')
for img in imgs:
for enhance_mask_dino_prompt_text, enhance_prompt, enhance_negative_prompt, enhance_mask_sam_model, enhance_mask_text_threshold, enhance_mask_box_threshold, enhance_mask_sam_max_num_boxes, enhance_inpaint_disable_initial_latent, enhance_inpaint_engine, enhance_inpaint_strength, enhance_inpaint_respective_field in async_task.enhance_ctrls:
print(f'[Enhance] Searching for "{enhance_mask_dino_prompt_text}"')
mask, dino_detection_count, sam_detection_count, sam_detection_on_mask_count = generate_mask_from_image(img, sam_options=SAMOptions(
for enhance_mask_dino_prompt_text, enhance_prompt, enhance_negative_prompt, enhance_mask_model, enhance_mask_sam_model, enhance_mask_text_threshold, enhance_mask_box_threshold, enhance_mask_sam_max_num_boxes, enhance_inpaint_disable_initial_latent, enhance_inpaint_engine, enhance_inpaint_strength, enhance_inpaint_respective_field in async_task.enhance_ctrls:
if enhance_mask_model == 'sam':
print(f'[Enhance] Searching for "{enhance_mask_dino_prompt_text}"')
mask, dino_detection_count, sam_detection_count, sam_detection_on_mask_count = generate_mask_from_image(img, mask_model=enhance_mask_model, sam_options=SAMOptions(
dino_prompt=enhance_mask_dino_prompt_text,
dino_box_threshold=enhance_mask_box_threshold,
dino_text_threshold=enhance_mask_text_threshold,
@ -1091,7 +1095,8 @@ def worker():
max_num_boxes=enhance_mask_sam_max_num_boxes,
model_type=enhance_mask_sam_model
))
mask = mask[:, :, 0]
if len(mask.shape) == 3:
mask = mask[:, :, 0]
if int(async_task.inpaint_erode_or_dilate) != 0:
mask = erode_or_dilate(mask, async_task.inpaint_erode_or_dilate)
@ -1106,7 +1111,7 @@ def worker():
print(f'[Enhance] {sam_detection_count} segments detected in boxes')
print(f'[Enhance] {sam_detection_on_mask_count} segments applied to mask')
if dino_detection_count == 0 or not async_task.debugging_dino and sam_detection_on_mask_count == 0:
if enhance_mask_model == 'sam' and (dino_detection_count == 0 or not async_task.debugging_dino and sam_detection_on_mask_count == 0):
print(f'[Enhance] No "{enhance_mask_dino_prompt_text}" detected, skipping')
continue

View File

@ -416,6 +416,7 @@ with shared.gradio_root:
enhance_mask_dino_prompt_text,
enhance_prompt,
enhance_negative_prompt,
enhance_mask_model,
enhance_mask_sam_model,
enhance_mask_text_threshold,
enhance_mask_box_threshold,