fix: process other models than sam when using enhance
This commit is contained in:
parent
ff3418876d
commit
9c93c18d0b
|
|
@ -56,7 +56,7 @@ def generate_mask_from_image(image: np.ndarray, mask_model: str = 'sam', extras=
|
||||||
if 'image' in image:
|
if 'image' in image:
|
||||||
image = image['image']
|
image = image['image']
|
||||||
|
|
||||||
if mask_model != 'sam' and sam_options is None:
|
if mask_model != 'sam' or sam_options is None:
|
||||||
result = remove(
|
result = remove(
|
||||||
image,
|
image,
|
||||||
session=new_session(mask_model, **extras),
|
session=new_session(mask_model, **extras),
|
||||||
|
|
@ -66,8 +66,6 @@ def generate_mask_from_image(image: np.ndarray, mask_model: str = 'sam', extras=
|
||||||
|
|
||||||
return result, dino_detection_count, sam_detection_count, sam_detection_on_mask_count
|
return result, dino_detection_count, sam_detection_count, sam_detection_on_mask_count
|
||||||
|
|
||||||
assert sam_options is not None
|
|
||||||
|
|
||||||
detections, boxes, logits, phrases = default_groundingdino(
|
detections, boxes, logits, phrases = default_groundingdino(
|
||||||
image=image,
|
image=image,
|
||||||
caption=sam_options.dino_prompt,
|
caption=sam_options.dino_prompt,
|
||||||
|
|
|
||||||
|
|
@ -118,6 +118,7 @@ class AsyncTask:
|
||||||
enhance_mask_dino_prompt_text = args.pop()
|
enhance_mask_dino_prompt_text = args.pop()
|
||||||
enhance_prompt = args.pop()
|
enhance_prompt = args.pop()
|
||||||
enhance_negative_prompt = args.pop()
|
enhance_negative_prompt = args.pop()
|
||||||
|
enhance_mask_model = args.pop()
|
||||||
enhance_mask_sam_model = args.pop()
|
enhance_mask_sam_model = args.pop()
|
||||||
enhance_mask_text_threshold = args.pop()
|
enhance_mask_text_threshold = args.pop()
|
||||||
enhance_mask_box_threshold = args.pop()
|
enhance_mask_box_threshold = args.pop()
|
||||||
|
|
@ -131,6 +132,7 @@ class AsyncTask:
|
||||||
enhance_mask_dino_prompt_text,
|
enhance_mask_dino_prompt_text,
|
||||||
enhance_prompt,
|
enhance_prompt,
|
||||||
enhance_negative_prompt,
|
enhance_negative_prompt,
|
||||||
|
enhance_mask_model,
|
||||||
enhance_mask_sam_model,
|
enhance_mask_sam_model,
|
||||||
enhance_mask_text_threshold,
|
enhance_mask_text_threshold,
|
||||||
enhance_mask_box_threshold,
|
enhance_mask_box_threshold,
|
||||||
|
|
@ -1080,9 +1082,11 @@ def worker():
|
||||||
progressbar(async_task, current_progress, 'Processing enhance ...')
|
progressbar(async_task, current_progress, 'Processing enhance ...')
|
||||||
|
|
||||||
for img in imgs:
|
for img in imgs:
|
||||||
for enhance_mask_dino_prompt_text, enhance_prompt, enhance_negative_prompt, enhance_mask_sam_model, enhance_mask_text_threshold, enhance_mask_box_threshold, enhance_mask_sam_max_num_boxes, enhance_inpaint_disable_initial_latent, enhance_inpaint_engine, enhance_inpaint_strength, enhance_inpaint_respective_field in async_task.enhance_ctrls:
|
for enhance_mask_dino_prompt_text, enhance_prompt, enhance_negative_prompt, enhance_mask_model, enhance_mask_sam_model, enhance_mask_text_threshold, enhance_mask_box_threshold, enhance_mask_sam_max_num_boxes, enhance_inpaint_disable_initial_latent, enhance_inpaint_engine, enhance_inpaint_strength, enhance_inpaint_respective_field in async_task.enhance_ctrls:
|
||||||
print(f'[Enhance] Searching for "{enhance_mask_dino_prompt_text}"')
|
if enhance_mask_model == 'sam':
|
||||||
mask, dino_detection_count, sam_detection_count, sam_detection_on_mask_count = generate_mask_from_image(img, sam_options=SAMOptions(
|
print(f'[Enhance] Searching for "{enhance_mask_dino_prompt_text}"')
|
||||||
|
|
||||||
|
mask, dino_detection_count, sam_detection_count, sam_detection_on_mask_count = generate_mask_from_image(img, mask_model=enhance_mask_model, sam_options=SAMOptions(
|
||||||
dino_prompt=enhance_mask_dino_prompt_text,
|
dino_prompt=enhance_mask_dino_prompt_text,
|
||||||
dino_box_threshold=enhance_mask_box_threshold,
|
dino_box_threshold=enhance_mask_box_threshold,
|
||||||
dino_text_threshold=enhance_mask_text_threshold,
|
dino_text_threshold=enhance_mask_text_threshold,
|
||||||
|
|
@ -1091,7 +1095,8 @@ def worker():
|
||||||
max_num_boxes=enhance_mask_sam_max_num_boxes,
|
max_num_boxes=enhance_mask_sam_max_num_boxes,
|
||||||
model_type=enhance_mask_sam_model
|
model_type=enhance_mask_sam_model
|
||||||
))
|
))
|
||||||
mask = mask[:, :, 0]
|
if len(mask.shape) == 3:
|
||||||
|
mask = mask[:, :, 0]
|
||||||
|
|
||||||
if int(async_task.inpaint_erode_or_dilate) != 0:
|
if int(async_task.inpaint_erode_or_dilate) != 0:
|
||||||
mask = erode_or_dilate(mask, async_task.inpaint_erode_or_dilate)
|
mask = erode_or_dilate(mask, async_task.inpaint_erode_or_dilate)
|
||||||
|
|
@ -1106,7 +1111,7 @@ def worker():
|
||||||
print(f'[Enhance] {sam_detection_count} segments detected in boxes')
|
print(f'[Enhance] {sam_detection_count} segments detected in boxes')
|
||||||
print(f'[Enhance] {sam_detection_on_mask_count} segments applied to mask')
|
print(f'[Enhance] {sam_detection_on_mask_count} segments applied to mask')
|
||||||
|
|
||||||
if dino_detection_count == 0 or not async_task.debugging_dino and sam_detection_on_mask_count == 0:
|
if enhance_mask_model == 'sam' and (dino_detection_count == 0 or not async_task.debugging_dino and sam_detection_on_mask_count == 0):
|
||||||
print(f'[Enhance] No "{enhance_mask_dino_prompt_text}" detected, skipping')
|
print(f'[Enhance] No "{enhance_mask_dino_prompt_text}" detected, skipping')
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
|
|
||||||
1
webui.py
1
webui.py
|
|
@ -416,6 +416,7 @@ with shared.gradio_root:
|
||||||
enhance_mask_dino_prompt_text,
|
enhance_mask_dino_prompt_text,
|
||||||
enhance_prompt,
|
enhance_prompt,
|
||||||
enhance_negative_prompt,
|
enhance_negative_prompt,
|
||||||
|
enhance_mask_model,
|
||||||
enhance_mask_sam_model,
|
enhance_mask_sam_model,
|
||||||
enhance_mask_text_threshold,
|
enhance_mask_text_threshold,
|
||||||
enhance_mask_box_threshold,
|
enhance_mask_box_threshold,
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue