feat: add disclaimer + skipping, stage2 won't properly work when used with inpaint or outpaint

This commit is contained in:
Manuel Schmid 2024-06-13 01:13:21 +02:00
parent dbc844804b
commit f8f36828c7
No known key found for this signature in database
GPG Key ID: 32C4F7569B40B84B
4 changed files with 91 additions and 44 deletions

View File

@ -18,7 +18,7 @@ sam_options = SAMOptions(
model_type='vit_b'
)
mask_image = generate_mask_from_image(image, sam_options=sam_options)
mask_image, _, _, _ = generate_mask_from_image(image, sam_options=sam_options)
merged_masks_img = Image.fromarray(mask_image)
merged_masks_img.show()

View File

@ -42,9 +42,13 @@ def optimize_masks(masks: torch.Tensor) -> torch.Tensor:
def generate_mask_from_image(image: np.ndarray, mask_model: str = 'sam', extras=None,
sam_options: SAMOptions | None = SAMOptions) -> np.ndarray | None:
sam_options: SAMOptions | None = SAMOptions) -> tuple[np.ndarray | None, int | None, int | None, int | None]:
dino_detection_count = 0
sam_detection_count = 0
sam_detection_on_mask_count = 0
if image is None:
return
return None, dino_detection_count, sam_detection_count, sam_detection_on_mask_count
if extras is None:
extras = {}
@ -53,13 +57,15 @@ def generate_mask_from_image(image: np.ndarray, mask_model: str = 'sam', extras=
image = image['image']
if mask_model != 'sam' and sam_options is None:
return remove(
result = remove(
image,
session=new_session(mask_model, **extras),
only_mask=True,
**extras
)
return result, dino_detection_count, sam_detection_count, sam_detection_on_mask_count
assert sam_options is not None
detections, boxes, logits, phrases = default_groundingdino(
@ -80,7 +86,11 @@ def generate_mask_from_image(image: np.ndarray, mask_model: str = 'sam', extras=
sam_predictor = SamPredictor(sam)
final_mask_tensor = torch.zeros((image.shape[0], image.shape[1]))
if boxes.size(0) > 0:
dino_detection_count = boxes.size(0)
sam_detection_count = 0
sam_detection_on_mask_count = 0
if dino_detection_count > 0:
sam_predictor.set_image(image)
if sam_options.dino_erode_or_dilate != 0:
@ -97,7 +107,7 @@ def generate_mask_from_image(image: np.ndarray, mask_model: str = 'sam', extras=
draw = ImageDraw.Draw(debug_dino_image)
for box in boxes.numpy():
draw.rectangle(box.tolist(), fill="white")
return np.array(debug_dino_image)
return np.array(debug_dino_image), dino_detection_count, sam_detection_count, sam_detection_on_mask_count
transformed_boxes = sam_predictor.transform.apply_boxes_torch(boxes, image.shape[:2])
masks, _, _ = sam_predictor.predict_torch(
@ -109,12 +119,12 @@ def generate_mask_from_image(image: np.ndarray, mask_model: str = 'sam', extras=
masks = optimize_masks(masks)
num_obj = min(len(logits), sam_options.max_num_boxes)
for obj_ind in range(num_obj):
sam_objects = min(len(logits), sam_options.max_num_boxes)
for obj_ind in range(sam_objects):
mask_tensor = masks[obj_ind][0]
final_mask_tensor += mask_tensor
final_mask_tensor = (final_mask_tensor > 0).to('cpu').numpy()
mask_image = np.dstack((final_mask_tensor, final_mask_tensor, final_mask_tensor)) * 255
mask_image = np.array(mask_image, dtype=np.uint8)
return mask_image
return mask_image, dino_detection_count, sam_detection_count, sam_detection_on_mask_count

View File

@ -1040,13 +1040,15 @@ def worker():
# stage2
progressbar(async_task, current_progress, 'Processing stage2 ...')
final_unet = pipeline.final_unet.clone()
if len(async_task.stage2_ctrls) == 0:
final_unet = pipeline.final_unet
if len(async_task.stage2_ctrls) == 0 or 'inpaint' in goals:
print(f'[Stage2] Skipping, preconditions aren\'t met')
continue
for img in imgs:
for stage2_mask_dino_prompt_text, stage2_mask_box_threshold, stage2_mask_text_threshold, stage2_mask_sam_max_num_boxes, stage2_mask_sam_model in async_task.stage2_ctrls:
mask = generate_mask_from_image(img, sam_options=SAMOptions(
print(f'[Stage2] Searching for "{stage2_mask_dino_prompt_text}"')
mask, dino_detection_count, sam_detection_count, sam_detection_on_mask_count = generate_mask_from_image(img, sam_options=SAMOptions(
dino_prompt=stage2_mask_dino_prompt_text,
dino_box_threshold=stage2_mask_box_threshold,
dino_text_threshold=stage2_mask_text_threshold,
@ -1060,12 +1062,43 @@ def worker():
async_task.yields.append(['preview', (current_progress, 'Loading ...', mask)])
# TODO also show do_not_show_finished_images=len(tasks) == 1
yield_result(async_task, mask, async_task.black_out_nsfw, False,
do_not_show_finished_images=len(tasks) == 1 or async_task.disable_intermediate_results)
do_not_show_finished_images=len(
tasks) == 1 or async_task.disable_intermediate_results)
print(f'[Stage2] {dino_detection_count} boxes detected')
print(f'[Stage2] {sam_detection_count} segments detected in boxes')
print(f'[Stage2] {sam_detection_on_mask_count} segments applied to mask')
if dino_detection_count == 0 or not async_task.debugging_dino and sam_detection_on_mask_count == 0:
print(f'[Stage2] Skipping')
continue
# TODO make configurable
# # do not apply loras / controlnets / etc. twice (samplers are needed though)
# pipeline.final_unet = pipeline.model_base.unet.clone()
# pipeline.refresh_everything(refiner_model_name=async_task.refiner_model_name,
# base_model_name=async_task.base_model_name,
# loras=[],
# base_model_additional_loras=[],
# use_synthetic_refiner=use_synthetic_refiner,
# vae_name=async_task.vae_name)
# pipeline.set_clip_skip(async_task.clip_skip)
#
# # patch everything again except original inpainting
# if 'cn' in goals:
# apply_control_nets(async_task, height, ip_adapter_face_path, ip_adapter_path, width)
# if async_task.freeu_enabled:
# apply_freeu(async_task)
# patch_samplers(async_task)
# defaults from inpaint mode improve details
denoising_strength_stage2 = 0.5
inpaint_respective_field_stage2 = 0.0
inpaint_head_model_path_stage2 = None
inpaint_parameterized_stage2 = False # inpaint_engine = None, improve detail
goals_stage2 = ['inpaint']
denoising_strength_stage2, initial_latent_stage2, width_stage2, height_stage2 = apply_inpaint(
async_task, None, inpaint_head_model_path_stage2, img, mask,
@ -1080,7 +1113,6 @@ def worker():
# reset and prepare next iteration
img = imgs2[0]
pipeline.final_unet = final_unet
inpaint_worker.current_task = None
except ldm_patched.modules.model_management.InterruptProcessingException:
if async_task.last_stop == 'skip':

View File

@ -250,7 +250,9 @@ with shared.gradio_root:
model_type=sam_model
)
return generate_mask_from_image(image, mask_model, extras, sam_options)
mask, _, _, _ = generate_mask_from_image(image, mask_model, extras, sam_options)
return mask
inpaint_mask_model.change(lambda x: [gr.update(visible=x == 'u2net_cloth_seg'), gr.update(visible=x == 'sam'), gr.update(visible=x == 'sam')],
inputs=inpaint_mask_model,
@ -299,6 +301,9 @@ with shared.gradio_root:
outputs=metadata_json, queue=False, show_progress=True)
with gr.Row(visible=False) as stage2_input_panel:
with gr.Column():
gr.HTML('DISCLAIMER: Stage2 will be skipped when used in combination with Inpaint or Outpaint!')
with gr.Row():
with gr.Tabs():
stage2_ctrls = []
for index in range(modules.config.default_stage2_tabs):