feat: add disclaimer + skipping, stage2 won't properly work when used with inpaint or outpaint
This commit is contained in:
parent
dbc844804b
commit
f8f36828c7
|
|
@ -18,7 +18,7 @@ sam_options = SAMOptions(
|
|||
model_type='vit_b'
|
||||
)
|
||||
|
||||
mask_image = generate_mask_from_image(image, sam_options=sam_options)
|
||||
mask_image, _, _, _ = generate_mask_from_image(image, sam_options=sam_options)
|
||||
|
||||
merged_masks_img = Image.fromarray(mask_image)
|
||||
merged_masks_img.show()
|
||||
|
|
|
|||
|
|
@ -42,9 +42,13 @@ def optimize_masks(masks: torch.Tensor) -> torch.Tensor:
|
|||
|
||||
|
||||
def generate_mask_from_image(image: np.ndarray, mask_model: str = 'sam', extras=None,
|
||||
sam_options: SAMOptions | None = SAMOptions) -> np.ndarray | None:
|
||||
sam_options: SAMOptions | None = SAMOptions) -> tuple[np.ndarray | None, int | None, int | None, int | None]:
|
||||
dino_detection_count = 0
|
||||
sam_detection_count = 0
|
||||
sam_detection_on_mask_count = 0
|
||||
|
||||
if image is None:
|
||||
return
|
||||
return None, dino_detection_count, sam_detection_count, sam_detection_on_mask_count
|
||||
|
||||
if extras is None:
|
||||
extras = {}
|
||||
|
|
@ -53,13 +57,15 @@ def generate_mask_from_image(image: np.ndarray, mask_model: str = 'sam', extras=
|
|||
image = image['image']
|
||||
|
||||
if mask_model != 'sam' and sam_options is None:
|
||||
return remove(
|
||||
result = remove(
|
||||
image,
|
||||
session=new_session(mask_model, **extras),
|
||||
only_mask=True,
|
||||
**extras
|
||||
)
|
||||
|
||||
return result, dino_detection_count, sam_detection_count, sam_detection_on_mask_count
|
||||
|
||||
assert sam_options is not None
|
||||
|
||||
detections, boxes, logits, phrases = default_groundingdino(
|
||||
|
|
@ -80,7 +86,11 @@ def generate_mask_from_image(image: np.ndarray, mask_model: str = 'sam', extras=
|
|||
sam_predictor = SamPredictor(sam)
|
||||
final_mask_tensor = torch.zeros((image.shape[0], image.shape[1]))
|
||||
|
||||
if boxes.size(0) > 0:
|
||||
dino_detection_count = boxes.size(0)
|
||||
sam_detection_count = 0
|
||||
sam_detection_on_mask_count = 0
|
||||
|
||||
if dino_detection_count > 0:
|
||||
sam_predictor.set_image(image)
|
||||
|
||||
if sam_options.dino_erode_or_dilate != 0:
|
||||
|
|
@ -97,7 +107,7 @@ def generate_mask_from_image(image: np.ndarray, mask_model: str = 'sam', extras=
|
|||
draw = ImageDraw.Draw(debug_dino_image)
|
||||
for box in boxes.numpy():
|
||||
draw.rectangle(box.tolist(), fill="white")
|
||||
return np.array(debug_dino_image)
|
||||
return np.array(debug_dino_image), dino_detection_count, sam_detection_count, sam_detection_on_mask_count
|
||||
|
||||
transformed_boxes = sam_predictor.transform.apply_boxes_torch(boxes, image.shape[:2])
|
||||
masks, _, _ = sam_predictor.predict_torch(
|
||||
|
|
@ -109,12 +119,12 @@ def generate_mask_from_image(image: np.ndarray, mask_model: str = 'sam', extras=
|
|||
|
||||
masks = optimize_masks(masks)
|
||||
|
||||
num_obj = min(len(logits), sam_options.max_num_boxes)
|
||||
for obj_ind in range(num_obj):
|
||||
sam_objects = min(len(logits), sam_options.max_num_boxes)
|
||||
for obj_ind in range(sam_objects):
|
||||
mask_tensor = masks[obj_ind][0]
|
||||
final_mask_tensor += mask_tensor
|
||||
|
||||
final_mask_tensor = (final_mask_tensor > 0).to('cpu').numpy()
|
||||
mask_image = np.dstack((final_mask_tensor, final_mask_tensor, final_mask_tensor)) * 255
|
||||
mask_image = np.array(mask_image, dtype=np.uint8)
|
||||
return mask_image
|
||||
return mask_image, dino_detection_count, sam_detection_count, sam_detection_on_mask_count
|
||||
|
|
|
|||
|
|
@ -1040,13 +1040,15 @@ def worker():
|
|||
|
||||
# stage2
|
||||
progressbar(async_task, current_progress, 'Processing stage2 ...')
|
||||
final_unet = pipeline.final_unet.clone()
|
||||
if len(async_task.stage2_ctrls) == 0:
|
||||
final_unet = pipeline.final_unet
|
||||
if len(async_task.stage2_ctrls) == 0 or 'inpaint' in goals:
|
||||
print(f'[Stage2] Skipping, preconditions aren\'t met')
|
||||
continue
|
||||
|
||||
for img in imgs:
|
||||
for stage2_mask_dino_prompt_text, stage2_mask_box_threshold, stage2_mask_text_threshold, stage2_mask_sam_max_num_boxes, stage2_mask_sam_model in async_task.stage2_ctrls:
|
||||
mask = generate_mask_from_image(img, sam_options=SAMOptions(
|
||||
print(f'[Stage2] Searching for "{stage2_mask_dino_prompt_text}"')
|
||||
mask, dino_detection_count, sam_detection_count, sam_detection_on_mask_count = generate_mask_from_image(img, sam_options=SAMOptions(
|
||||
dino_prompt=stage2_mask_dino_prompt_text,
|
||||
dino_box_threshold=stage2_mask_box_threshold,
|
||||
dino_text_threshold=stage2_mask_text_threshold,
|
||||
|
|
@ -1060,12 +1062,43 @@ def worker():
|
|||
async_task.yields.append(['preview', (current_progress, 'Loading ...', mask)])
|
||||
# TODO also show do_not_show_finished_images=len(tasks) == 1
|
||||
yield_result(async_task, mask, async_task.black_out_nsfw, False,
|
||||
do_not_show_finished_images=len(tasks) == 1 or async_task.disable_intermediate_results)
|
||||
do_not_show_finished_images=len(
|
||||
tasks) == 1 or async_task.disable_intermediate_results)
|
||||
|
||||
print(f'[Stage2] {dino_detection_count} boxes detected')
|
||||
print(f'[Stage2] {sam_detection_count} segments detected in boxes')
|
||||
print(f'[Stage2] {sam_detection_on_mask_count} segments applied to mask')
|
||||
|
||||
if dino_detection_count == 0 or not async_task.debugging_dino and sam_detection_on_mask_count == 0:
|
||||
print(f'[Stage2] Skipping')
|
||||
continue
|
||||
|
||||
# TODO make configurable
|
||||
|
||||
# # do not apply loras / controlnets / etc. twice (samplers are needed though)
|
||||
# pipeline.final_unet = pipeline.model_base.unet.clone()
|
||||
|
||||
# pipeline.refresh_everything(refiner_model_name=async_task.refiner_model_name,
|
||||
# base_model_name=async_task.base_model_name,
|
||||
# loras=[],
|
||||
# base_model_additional_loras=[],
|
||||
# use_synthetic_refiner=use_synthetic_refiner,
|
||||
# vae_name=async_task.vae_name)
|
||||
# pipeline.set_clip_skip(async_task.clip_skip)
|
||||
#
|
||||
# # patch everything again except original inpainting
|
||||
# if 'cn' in goals:
|
||||
# apply_control_nets(async_task, height, ip_adapter_face_path, ip_adapter_path, width)
|
||||
# if async_task.freeu_enabled:
|
||||
# apply_freeu(async_task)
|
||||
# patch_samplers(async_task)
|
||||
|
||||
# defaults from inpaint mode improve details
|
||||
denoising_strength_stage2 = 0.5
|
||||
inpaint_respective_field_stage2 = 0.0
|
||||
inpaint_head_model_path_stage2 = None
|
||||
inpaint_parameterized_stage2 = False # inpaint_engine = None, improve detail
|
||||
|
||||
goals_stage2 = ['inpaint']
|
||||
denoising_strength_stage2, initial_latent_stage2, width_stage2, height_stage2 = apply_inpaint(
|
||||
async_task, None, inpaint_head_model_path_stage2, img, mask,
|
||||
|
|
@ -1080,7 +1113,6 @@ def worker():
|
|||
# reset and prepare next iteration
|
||||
img = imgs2[0]
|
||||
pipeline.final_unet = final_unet
|
||||
inpaint_worker.current_task = None
|
||||
|
||||
except ldm_patched.modules.model_management.InterruptProcessingException:
|
||||
if async_task.last_stop == 'skip':
|
||||
|
|
|
|||
65
webui.py
65
webui.py
|
|
@ -250,7 +250,9 @@ with shared.gradio_root:
|
|||
model_type=sam_model
|
||||
)
|
||||
|
||||
return generate_mask_from_image(image, mask_model, extras, sam_options)
|
||||
mask, _, _, _ = generate_mask_from_image(image, mask_model, extras, sam_options)
|
||||
|
||||
return mask
|
||||
|
||||
inpaint_mask_model.change(lambda x: [gr.update(visible=x == 'u2net_cloth_seg'), gr.update(visible=x == 'sam'), gr.update(visible=x == 'sam')],
|
||||
inputs=inpaint_mask_model,
|
||||
|
|
@ -299,38 +301,41 @@ with shared.gradio_root:
|
|||
outputs=metadata_json, queue=False, show_progress=True)
|
||||
|
||||
with gr.Row(visible=False) as stage2_input_panel:
|
||||
with gr.Tabs():
|
||||
stage2_ctrls = []
|
||||
for index in range(modules.config.default_stage2_tabs):
|
||||
with gr.TabItem(label=f'Iteration #{index + 1}') as stage2_tab_item:
|
||||
stage2_enabled = gr.Checkbox(label='Enable', value=False, elem_classes='min_check', container=False)
|
||||
with gr.Accordion('Options', visible=True, open=False) as stage2_accordion:
|
||||
# stage2_mode = gr.Dropdown(choices=modules.flags.inpaint_options, value=modules.flags.inpaint_option_detail, label='Method', interactive=True)
|
||||
stage2_mask_dino_prompt_text = gr.Textbox(label='Segmentation prompt', info='Use singular whenever possible', interactive=True)
|
||||
example_stage2_mask_dino_prompt_text = gr.Dataset(samples=modules.config.example_stage2_prompts,
|
||||
label='Additional Prompt Quick List',
|
||||
components=[stage2_mask_dino_prompt_text],
|
||||
visible=True)
|
||||
example_stage2_mask_dino_prompt_text.click(lambda x: x[0], inputs=example_stage2_mask_dino_prompt_text, outputs=stage2_mask_dino_prompt_text, show_progress=False, queue=False)
|
||||
with gr.Column():
|
||||
gr.HTML('DISCLAIMER: Stage2 will be skipped when used in combination with Inpaint or Outpaint!')
|
||||
with gr.Row():
|
||||
with gr.Tabs():
|
||||
stage2_ctrls = []
|
||||
for index in range(modules.config.default_stage2_tabs):
|
||||
with gr.TabItem(label=f'Iteration #{index + 1}') as stage2_tab_item:
|
||||
stage2_enabled = gr.Checkbox(label='Enable', value=False, elem_classes='min_check', container=False)
|
||||
with gr.Accordion('Options', visible=True, open=False) as stage2_accordion:
|
||||
# stage2_mode = gr.Dropdown(choices=modules.flags.inpaint_options, value=modules.flags.inpaint_option_detail, label='Method', interactive=True)
|
||||
stage2_mask_dino_prompt_text = gr.Textbox(label='Segmentation prompt', info='Use singular whenever possible', interactive=True)
|
||||
example_stage2_mask_dino_prompt_text = gr.Dataset(samples=modules.config.example_stage2_prompts,
|
||||
label='Additional Prompt Quick List',
|
||||
components=[stage2_mask_dino_prompt_text],
|
||||
visible=True)
|
||||
example_stage2_mask_dino_prompt_text.click(lambda x: x[0], inputs=example_stage2_mask_dino_prompt_text, outputs=stage2_mask_dino_prompt_text, show_progress=False, queue=False)
|
||||
|
||||
with gr.Accordion("Advanced options", visible=True, open=False) as inpaint_mask_advanced_options:
|
||||
stage2_mask_sam_model = gr.Dropdown(label='SAM model', choices=flags.inpaint_mask_sam_model, value=modules.config.default_inpaint_mask_sam_model, interactive=True)
|
||||
stage2_mask_box_threshold = gr.Slider(label="Box Threshold", minimum=0.0, maximum=1.0, value=0.3, step=0.05, interactive=True)
|
||||
stage2_mask_text_threshold = gr.Slider(label="Text Threshold", minimum=0.0, maximum=1.0, value=0.25, step=0.05, interactive=True)
|
||||
stage2_mask_sam_max_num_boxes = gr.Slider(label="Maximum number of box detections", minimum=1, maximum=5, value=modules.config.default_sam_max_num_boxes, step=1, interactive=True)
|
||||
with gr.Accordion("Advanced options", visible=True, open=False) as inpaint_mask_advanced_options:
|
||||
stage2_mask_sam_model = gr.Dropdown(label='SAM model', choices=flags.inpaint_mask_sam_model, value=modules.config.default_inpaint_mask_sam_model, interactive=True)
|
||||
stage2_mask_box_threshold = gr.Slider(label="Box Threshold", minimum=0.0, maximum=1.0, value=0.3, step=0.05, interactive=True)
|
||||
stage2_mask_text_threshold = gr.Slider(label="Text Threshold", minimum=0.0, maximum=1.0, value=0.25, step=0.05, interactive=True)
|
||||
stage2_mask_sam_max_num_boxes = gr.Slider(label="Maximum number of box detections", minimum=1, maximum=5, value=modules.config.default_sam_max_num_boxes, step=1, interactive=True)
|
||||
|
||||
stage2_ctrls += [
|
||||
stage2_enabled,
|
||||
# stage2_mode,
|
||||
stage2_mask_dino_prompt_text,
|
||||
stage2_mask_box_threshold,
|
||||
stage2_mask_text_threshold,
|
||||
stage2_mask_sam_max_num_boxes,
|
||||
stage2_mask_sam_model,
|
||||
]
|
||||
stage2_ctrls += [
|
||||
stage2_enabled,
|
||||
# stage2_mode,
|
||||
stage2_mask_dino_prompt_text,
|
||||
stage2_mask_box_threshold,
|
||||
stage2_mask_text_threshold,
|
||||
stage2_mask_sam_max_num_boxes,
|
||||
stage2_mask_sam_model,
|
||||
]
|
||||
|
||||
stage2_enabled.change(lambda x: gr.update(open=x), inputs=stage2_enabled,
|
||||
outputs=stage2_accordion, queue=False, show_progress=False)
|
||||
stage2_enabled.change(lambda x: gr.update(open=x), inputs=stage2_enabled,
|
||||
outputs=stage2_accordion, queue=False, show_progress=False)
|
||||
switch_js = "(x) => {if(x){viewer_to_bottom(100);viewer_to_bottom(500);}else{viewer_to_top();} return x;}"
|
||||
down_js = "() => {viewer_to_bottom();}"
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue