From 3873892b0ad61000e540eba60d0bc44f4667c889 Mon Sep 17 00:00:00 2001 From: Manuel Schmid Date: Mon, 10 Jun 2024 20:45:39 +0200 Subject: [PATCH] feat: change default_inpaint_mask_sam_model to match sam model registry --- extras/inpaint_mask.py | 2 +- modules/config.py | 6 +++--- modules/flags.py | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/extras/inpaint_mask.py b/extras/inpaint_mask.py index 71a926d1..b67b81b3 100644 --- a/extras/inpaint_mask.py +++ b/extras/inpaint_mask.py @@ -18,7 +18,7 @@ class SAMOptions: # SAM max_num_boxes=2, - model_type="vit_l" + model_type='vit_b' ): self.dino_prompt = dino_prompt self.dino_box_threshold = dino_box_threshold diff --git a/modules/config.py b/modules/config.py index 2833413c..9bd354f5 100644 --- a/modules/config.py +++ b/modules/config.py @@ -551,8 +551,8 @@ default_inpaint_mask_cloth_category = get_config_item_or_set_default( default_inpaint_mask_sam_model = get_config_item_or_set_default( key='default_inpaint_mask_sam_model', - default_value='sam_vit_b_01ec64', - validator=lambda x: x in modules.flags.inpaint_mask_sam_model, + default_value='vit_b', + validator=lambda x: x in [y[1] for y in modules.flags.inpaint_mask_sam_model if y[1] == x], expected_type=str ) @@ -792,7 +792,7 @@ def downloading_safety_checker_model(): def download_sam_model(sam_model: str) -> str: match sam_model: - case 'default', 'vit_b': + case 'vit_b': return downloading_sam_vit_b() case 'vit_l': return downloading_sam_vit_l() diff --git a/modules/flags.py b/modules/flags.py index 6fec3663..ed9a5606 100644 --- a/modules/flags.py +++ b/modules/flags.py @@ -76,7 +76,7 @@ output_formats = ['png', 'jpeg', 'webp'] inpaint_mask_models = ['u2net', 'u2netp', 'u2net_human_seg', 'u2net_cloth_seg', 'silueta', 'isnet-general-use', 'isnet-anime', 'sam'] inpaint_mask_cloth_category = ['full', 'upper', 'lower'] -inpaint_mask_sam_model = ['sam_vit_b_01ec64', 'sam_vit_l_0b3195', 'sam_vit_h_4b8939'] +inpaint_mask_sam_model = [('base', 'vit_b'), ('large', 'vit_l'), ('huge', 'vit_h')] inpaint_engine_versions = ['None', 'v1', 'v2.5', 'v2.6'] inpaint_option_default = 'Inpaint or Outpaint (default)'