From f8f36828c79229b69571a153a6cebcd606d67c35 Mon Sep 17 00:00:00 2001 From: Manuel Schmid Date: Thu, 13 Jun 2024 01:13:21 +0200 Subject: [PATCH] feat: add disclaimer + skipping, stage2 won't properly work when used with inpaint or outpaint --- experiments_mask_generation.py | 2 +- extras/inpaint_mask.py | 26 +++++++++----- modules/async_worker.py | 42 +++++++++++++++++++--- webui.py | 65 ++++++++++++++++++---------------- 4 files changed, 91 insertions(+), 44 deletions(-) diff --git a/experiments_mask_generation.py b/experiments_mask_generation.py index 272adfc5..0f6b960d 100644 --- a/experiments_mask_generation.py +++ b/experiments_mask_generation.py @@ -18,7 +18,7 @@ sam_options = SAMOptions( model_type='vit_b' ) -mask_image = generate_mask_from_image(image, sam_options=sam_options) +mask_image, _, _, _ = generate_mask_from_image(image, sam_options=sam_options) merged_masks_img = Image.fromarray(mask_image) merged_masks_img.show() diff --git a/extras/inpaint_mask.py b/extras/inpaint_mask.py index 7bb671f5..f9025ef2 100644 --- a/extras/inpaint_mask.py +++ b/extras/inpaint_mask.py @@ -42,9 +42,13 @@ def optimize_masks(masks: torch.Tensor) -> torch.Tensor: def generate_mask_from_image(image: np.ndarray, mask_model: str = 'sam', extras=None, - sam_options: SAMOptions | None = SAMOptions) -> np.ndarray | None: + sam_options: SAMOptions | None = SAMOptions) -> tuple[np.ndarray | None, int | None, int | None, int | None]: + dino_detection_count = 0 + sam_detection_count = 0 + sam_detection_on_mask_count = 0 + if image is None: - return + return None, dino_detection_count, sam_detection_count, sam_detection_on_mask_count if extras is None: extras = {} @@ -53,13 +57,15 @@ def generate_mask_from_image(image: np.ndarray, mask_model: str = 'sam', extras= image = image['image'] if mask_model != 'sam' and sam_options is None: - return remove( + result = remove( image, session=new_session(mask_model, **extras), only_mask=True, **extras ) + return result, dino_detection_count, sam_detection_count, sam_detection_on_mask_count + assert sam_options is not None detections, boxes, logits, phrases = default_groundingdino( @@ -80,7 +86,11 @@ def generate_mask_from_image(image: np.ndarray, mask_model: str = 'sam', extras= sam_predictor = SamPredictor(sam) final_mask_tensor = torch.zeros((image.shape[0], image.shape[1])) - if boxes.size(0) > 0: + dino_detection_count = boxes.size(0) + sam_detection_count = 0 + sam_detection_on_mask_count = 0 + + if dino_detection_count > 0: sam_predictor.set_image(image) if sam_options.dino_erode_or_dilate != 0: @@ -97,7 +107,7 @@ def generate_mask_from_image(image: np.ndarray, mask_model: str = 'sam', extras= draw = ImageDraw.Draw(debug_dino_image) for box in boxes.numpy(): draw.rectangle(box.tolist(), fill="white") - return np.array(debug_dino_image) + return np.array(debug_dino_image), dino_detection_count, sam_detection_count, sam_detection_on_mask_count transformed_boxes = sam_predictor.transform.apply_boxes_torch(boxes, image.shape[:2]) masks, _, _ = sam_predictor.predict_torch( @@ -109,12 +119,12 @@ def generate_mask_from_image(image: np.ndarray, mask_model: str = 'sam', extras= masks = optimize_masks(masks) - num_obj = min(len(logits), sam_options.max_num_boxes) - for obj_ind in range(num_obj): + sam_objects = min(len(logits), sam_options.max_num_boxes) + for obj_ind in range(sam_objects): mask_tensor = masks[obj_ind][0] final_mask_tensor += mask_tensor final_mask_tensor = (final_mask_tensor > 0).to('cpu').numpy() mask_image = np.dstack((final_mask_tensor, final_mask_tensor, final_mask_tensor)) * 255 mask_image = np.array(mask_image, dtype=np.uint8) - return mask_image + return mask_image, dino_detection_count, sam_detection_count, sam_detection_on_mask_count diff --git a/modules/async_worker.py b/modules/async_worker.py index 3191895f..1c288e16 100644 --- a/modules/async_worker.py +++ b/modules/async_worker.py @@ -1040,13 +1040,15 @@ def worker(): # stage2 progressbar(async_task, current_progress, 'Processing stage2 ...') - final_unet = pipeline.final_unet.clone() - if len(async_task.stage2_ctrls) == 0: + final_unet = pipeline.final_unet + if len(async_task.stage2_ctrls) == 0 or 'inpaint' in goals: + print(f'[Stage2] Skipping, preconditions aren\'t met') continue for img in imgs: for stage2_mask_dino_prompt_text, stage2_mask_box_threshold, stage2_mask_text_threshold, stage2_mask_sam_max_num_boxes, stage2_mask_sam_model in async_task.stage2_ctrls: - mask = generate_mask_from_image(img, sam_options=SAMOptions( + print(f'[Stage2] Searching for "{stage2_mask_dino_prompt_text}"') + mask, dino_detection_count, sam_detection_count, sam_detection_on_mask_count = generate_mask_from_image(img, sam_options=SAMOptions( dino_prompt=stage2_mask_dino_prompt_text, dino_box_threshold=stage2_mask_box_threshold, dino_text_threshold=stage2_mask_text_threshold, @@ -1060,12 +1062,43 @@ def worker(): async_task.yields.append(['preview', (current_progress, 'Loading ...', mask)]) # TODO also show do_not_show_finished_images=len(tasks) == 1 yield_result(async_task, mask, async_task.black_out_nsfw, False, - do_not_show_finished_images=len(tasks) == 1 or async_task.disable_intermediate_results) + do_not_show_finished_images=len( + tasks) == 1 or async_task.disable_intermediate_results) + + print(f'[Stage2] {dino_detection_count} boxes detected') + print(f'[Stage2] {sam_detection_count} segments detected in boxes') + print(f'[Stage2] {sam_detection_on_mask_count} segments applied to mask') + + if dino_detection_count == 0 or not async_task.debugging_dino and sam_detection_on_mask_count == 0: + print(f'[Stage2] Skipping') + continue + # TODO make configurable + + # # do not apply loras / controlnets / etc. twice (samplers are needed though) + # pipeline.final_unet = pipeline.model_base.unet.clone() + + # pipeline.refresh_everything(refiner_model_name=async_task.refiner_model_name, + # base_model_name=async_task.base_model_name, + # loras=[], + # base_model_additional_loras=[], + # use_synthetic_refiner=use_synthetic_refiner, + # vae_name=async_task.vae_name) + # pipeline.set_clip_skip(async_task.clip_skip) + # + # # patch everything again except original inpainting + # if 'cn' in goals: + # apply_control_nets(async_task, height, ip_adapter_face_path, ip_adapter_path, width) + # if async_task.freeu_enabled: + # apply_freeu(async_task) + # patch_samplers(async_task) + + # defaults from inpaint mode improve details denoising_strength_stage2 = 0.5 inpaint_respective_field_stage2 = 0.0 inpaint_head_model_path_stage2 = None inpaint_parameterized_stage2 = False # inpaint_engine = None, improve detail + goals_stage2 = ['inpaint'] denoising_strength_stage2, initial_latent_stage2, width_stage2, height_stage2 = apply_inpaint( async_task, None, inpaint_head_model_path_stage2, img, mask, @@ -1080,7 +1113,6 @@ def worker(): # reset and prepare next iteration img = imgs2[0] pipeline.final_unet = final_unet - inpaint_worker.current_task = None except ldm_patched.modules.model_management.InterruptProcessingException: if async_task.last_stop == 'skip': diff --git a/webui.py b/webui.py index 839cd1db..83a2296e 100644 --- a/webui.py +++ b/webui.py @@ -250,7 +250,9 @@ with shared.gradio_root: model_type=sam_model ) - return generate_mask_from_image(image, mask_model, extras, sam_options) + mask, _, _, _ = generate_mask_from_image(image, mask_model, extras, sam_options) + + return mask inpaint_mask_model.change(lambda x: [gr.update(visible=x == 'u2net_cloth_seg'), gr.update(visible=x == 'sam'), gr.update(visible=x == 'sam')], inputs=inpaint_mask_model, @@ -299,38 +301,41 @@ with shared.gradio_root: outputs=metadata_json, queue=False, show_progress=True) with gr.Row(visible=False) as stage2_input_panel: - with gr.Tabs(): - stage2_ctrls = [] - for index in range(modules.config.default_stage2_tabs): - with gr.TabItem(label=f'Iteration #{index + 1}') as stage2_tab_item: - stage2_enabled = gr.Checkbox(label='Enable', value=False, elem_classes='min_check', container=False) - with gr.Accordion('Options', visible=True, open=False) as stage2_accordion: - # stage2_mode = gr.Dropdown(choices=modules.flags.inpaint_options, value=modules.flags.inpaint_option_detail, label='Method', interactive=True) - stage2_mask_dino_prompt_text = gr.Textbox(label='Segmentation prompt', info='Use singular whenever possible', interactive=True) - example_stage2_mask_dino_prompt_text = gr.Dataset(samples=modules.config.example_stage2_prompts, - label='Additional Prompt Quick List', - components=[stage2_mask_dino_prompt_text], - visible=True) - example_stage2_mask_dino_prompt_text.click(lambda x: x[0], inputs=example_stage2_mask_dino_prompt_text, outputs=stage2_mask_dino_prompt_text, show_progress=False, queue=False) + with gr.Column(): + gr.HTML('DISCLAIMER: Stage2 will be skipped when used in combination with Inpaint or Outpaint!') + with gr.Row(): + with gr.Tabs(): + stage2_ctrls = [] + for index in range(modules.config.default_stage2_tabs): + with gr.TabItem(label=f'Iteration #{index + 1}') as stage2_tab_item: + stage2_enabled = gr.Checkbox(label='Enable', value=False, elem_classes='min_check', container=False) + with gr.Accordion('Options', visible=True, open=False) as stage2_accordion: + # stage2_mode = gr.Dropdown(choices=modules.flags.inpaint_options, value=modules.flags.inpaint_option_detail, label='Method', interactive=True) + stage2_mask_dino_prompt_text = gr.Textbox(label='Segmentation prompt', info='Use singular whenever possible', interactive=True) + example_stage2_mask_dino_prompt_text = gr.Dataset(samples=modules.config.example_stage2_prompts, + label='Additional Prompt Quick List', + components=[stage2_mask_dino_prompt_text], + visible=True) + example_stage2_mask_dino_prompt_text.click(lambda x: x[0], inputs=example_stage2_mask_dino_prompt_text, outputs=stage2_mask_dino_prompt_text, show_progress=False, queue=False) - with gr.Accordion("Advanced options", visible=True, open=False) as inpaint_mask_advanced_options: - stage2_mask_sam_model = gr.Dropdown(label='SAM model', choices=flags.inpaint_mask_sam_model, value=modules.config.default_inpaint_mask_sam_model, interactive=True) - stage2_mask_box_threshold = gr.Slider(label="Box Threshold", minimum=0.0, maximum=1.0, value=0.3, step=0.05, interactive=True) - stage2_mask_text_threshold = gr.Slider(label="Text Threshold", minimum=0.0, maximum=1.0, value=0.25, step=0.05, interactive=True) - stage2_mask_sam_max_num_boxes = gr.Slider(label="Maximum number of box detections", minimum=1, maximum=5, value=modules.config.default_sam_max_num_boxes, step=1, interactive=True) + with gr.Accordion("Advanced options", visible=True, open=False) as inpaint_mask_advanced_options: + stage2_mask_sam_model = gr.Dropdown(label='SAM model', choices=flags.inpaint_mask_sam_model, value=modules.config.default_inpaint_mask_sam_model, interactive=True) + stage2_mask_box_threshold = gr.Slider(label="Box Threshold", minimum=0.0, maximum=1.0, value=0.3, step=0.05, interactive=True) + stage2_mask_text_threshold = gr.Slider(label="Text Threshold", minimum=0.0, maximum=1.0, value=0.25, step=0.05, interactive=True) + stage2_mask_sam_max_num_boxes = gr.Slider(label="Maximum number of box detections", minimum=1, maximum=5, value=modules.config.default_sam_max_num_boxes, step=1, interactive=True) - stage2_ctrls += [ - stage2_enabled, - # stage2_mode, - stage2_mask_dino_prompt_text, - stage2_mask_box_threshold, - stage2_mask_text_threshold, - stage2_mask_sam_max_num_boxes, - stage2_mask_sam_model, - ] + stage2_ctrls += [ + stage2_enabled, + # stage2_mode, + stage2_mask_dino_prompt_text, + stage2_mask_box_threshold, + stage2_mask_text_threshold, + stage2_mask_sam_max_num_boxes, + stage2_mask_sam_model, + ] - stage2_enabled.change(lambda x: gr.update(open=x), inputs=stage2_enabled, - outputs=stage2_accordion, queue=False, show_progress=False) + stage2_enabled.change(lambda x: gr.update(open=x), inputs=stage2_enabled, + outputs=stage2_accordion, queue=False, show_progress=False) switch_js = "(x) => {if(x){viewer_to_bottom(100);viewer_to_bottom(500);}else{viewer_to_top();} return x;}" down_js = "() => {viewer_to_bottom();}"