diff --git a/extras/inpaint_mask.py b/extras/inpaint_mask.py new file mode 100644 index 00000000..9773eb47 --- /dev/null +++ b/extras/inpaint_mask.py @@ -0,0 +1,15 @@ +from rembg import remove, new_session + + +def generate_mask_from_image(image, mask_model): + if image is None: + return + + if 'image' in image: + image = image['image'] + + return remove( + image, + session=new_session(mask_model), + only_mask=True + ) diff --git a/modules/config.py b/modules/config.py index c7af33db..4149f151 100644 --- a/modules/config.py +++ b/modules/config.py @@ -318,6 +318,12 @@ example_inpaint_prompts = get_config_item_or_set_default( example_inpaint_prompts = [[x] for x in example_inpaint_prompts] +default_inpaint_mask_model = get_config_item_or_set_default( + key='default_inpaint_mask_model', + default_value='isnet-general-use', + validator=lambda x: isinstance(x, str) +) + config_dict["default_loras"] = default_loras = default_loras[:5] + [['None', 1.0] for _ in range(5 - len(default_loras))] possible_preset_keys = [ diff --git a/modules/flags.py b/modules/flags.py index 27f2d716..ffdce78d 100644 --- a/modules/flags.py +++ b/modules/flags.py @@ -35,6 +35,10 @@ default_parameters = { inpaint_engine_versions = ['None', 'v1', 'v2.5', 'v2.6'] performance_selections = ['Speed', 'Quality', 'Extreme Speed'] +inpaint_mask_models = [ + 'u2net', 'u2netp', 'u2net_human_seg', 'u2net_cloth_seg', 'silueta', 'isnet-general-use', 'isnet-anime', 'sam' +] + inpaint_option_default = 'Inpaint or Outpaint (default)' inpaint_option_detail = 'Improve Detail (face, hand, eyes, etc.)' inpaint_option_modify = 'Modify Content (add objects, change background, etc.)' diff --git a/requirements_versions.txt b/requirements_versions.txt index b2111c1f..9cca8fca 100644 --- a/requirements_versions.txt +++ b/requirements_versions.txt @@ -16,3 +16,4 @@ opencv-contrib-python==4.8.0.74 httpx==0.24.1 onnxruntime==1.16.3 timm==0.9.2 +rembg==2.0.53 \ No newline at end of file diff --git a/webui.py b/webui.py index fadd852a..b4807a06 100644 --- a/webui.py +++ b/webui.py @@ -194,6 +194,19 @@ with shared.gradio_root: inpaint_additional_prompt = gr.Textbox(placeholder="Describe what you want to inpaint.", elem_id='inpaint_additional_prompt', label='Inpaint Additional Prompt', visible=False) outpaint_selections = gr.CheckboxGroup(choices=['Left', 'Right', 'Top', 'Bottom'], value=[], label='Outpaint Direction') inpaint_mode = gr.Dropdown(choices=modules.flags.inpaint_options, value=modules.flags.inpaint_option_default, label='Method') + with gr.Row(visible=False) as inpaint_mask_generation_row: + inpaint_mask_model = gr.Dropdown(label='Mask generation model', + choices=flags.inpaint_mask_models, + value=modules.config.default_inpaint_mask_model, visible=False) + generate_mask_button = gr.Button(value='Generate mask from image', visible=False) + + def generate_mask(image, mask_model): + from extras.inpaint_mask import generate_mask_from_image + return generate_mask_from_image(image, mask_model) + + generate_mask_button.click(fn=generate_mask, inputs=[inpaint_input_image, inpaint_mask_model], + outputs=inpaint_mask_image) + example_inpaint_prompts = gr.Dataset(samples=modules.config.example_inpaint_prompts, label='Additional Prompt Quick List', components=[inpaint_additional_prompt], visible=False) gr.HTML('* Powered by Fooocus Inpaint Engine \U0001F4D4 Document') example_inpaint_prompts.click(lambda x: x[0], inputs=example_inpaint_prompts, outputs=inpaint_additional_prompt, show_progress=False, queue=False) @@ -434,9 +447,11 @@ with shared.gradio_root: inpaint_strength, inpaint_respective_field, inpaint_mask_upload_checkbox, invert_mask_checkbox, inpaint_erode_or_dilate] - inpaint_mask_upload_checkbox.change(lambda x: gr.update(visible=x), - inputs=inpaint_mask_upload_checkbox, - outputs=inpaint_mask_image, queue=False, show_progress=False) + inpaint_mask_upload_checkbox.change(lambda x: [gr.update(visible=x)] * 4, + inputs=inpaint_mask_upload_checkbox, + outputs=[inpaint_mask_image, generate_mask_button, + inpaint_mask_model, inpaint_mask_generation_row], + queue=False, show_progress=False) with gr.Tab(label='FreeU'): freeu_enabled = gr.Checkbox(label='Enabled', value=False)