diff --git a/extras/inpaint_mask.py b/extras/inpaint_mask.py index 85cd7fc5..71a926d1 100644 --- a/extras/inpaint_mask.py +++ b/extras/inpaint_mask.py @@ -5,6 +5,7 @@ from segment_anything import sam_model_registry, SamPredictor from segment_anything.utils.amg import remove_small_regions from extras.GroundingDINO.util.inference import default_groundingdino +import modules.config class SAMOptions: @@ -17,7 +18,6 @@ class SAMOptions: # SAM max_num_boxes=2, - sam_checkpoint="./models/sam/sam_vit_l_0b3195.pth", model_type="vit_l" ): self.dino_prompt = dino_prompt @@ -25,7 +25,6 @@ class SAMOptions: self.dino_text_threshold = dino_text_threshold self.box_erode_or_dilate = box_erode_or_dilate self.max_num_boxes = max_num_boxes - self.sam_checkpoint = sam_checkpoint self.model_type = model_type @@ -99,7 +98,8 @@ def generate_mask_from_image(image: np.ndarray, mask_model: str = 'sam', extras= # TODO add model patcher for model logic and device management device = "cuda" if torch.cuda.is_available() else "cpu" - sam = sam_model_registry[sam_options.model_type](checkpoint=sam_options.sam_checkpoint) + sam_checkpoint = modules.config.download_sam_model(sam_options.model_type) + sam = sam_model_registry[sam_options.model_type](checkpoint=sam_checkpoint) sam.to(device=device) sam_predictor = SamPredictor(sam) diff --git a/modules/config.py b/modules/config.py index d7bd2d31..2833413c 100644 --- a/modules/config.py +++ b/modules/config.py @@ -202,6 +202,7 @@ path_fooocus_expansion = get_dir_or_set_default('path_fooocus_expansion', '../mo path_safety_checker_models = get_dir_or_set_default('path_safety_checker_models', '../models/safety_checker_models/') path_wildcards = get_dir_or_set_default('path_wildcards', '../wildcards/') path_safety_checker = get_dir_or_set_default('path_safety_checker', '../models/safety_checker/') +path_sam = get_dir_or_set_default('path_sam', '../models/sam/') path_outputs = get_path_output() @@ -789,4 +790,43 @@ def downloading_safety_checker_model(): return os.path.join(path_safety_checker, 'stable-diffusion-safety-checker.bin') +def download_sam_model(sam_model: str) -> str: + match sam_model: + case 'default', 'vit_b': + return downloading_sam_vit_b() + case 'vit_l': + return downloading_sam_vit_l() + case 'vit_h': + return downloading_sam_vit_h() + case _: + raise ValueError(f"sam model {sam_model} does not exist.") + + +def downloading_sam_vit_b(): + load_file_from_url( + url='https://huggingface.co/mashb1t/misc/resolve/main/sam_vit_b_01ec64.pth', + model_dir=path_sam, + file_name='sam_vit_b_01ec64.pth' + ) + return os.path.join(path_sam, 'sam_vit_b_01ec64.pth') + + +def downloading_sam_vit_l(): + load_file_from_url( + url='https://huggingface.co/mashb1t/misc/resolve/main/sam_vit_l_0b3195.pth', + model_dir=path_sam, + file_name='sam_vit_l_0b3195.pth' + ) + return os.path.join(path_sam, 'sam_vit_l_0b3195.pth') + + +def downloading_sam_vit_h(): + load_file_from_url( + url='https://huggingface.co/mashb1t/misc/resolve/main/sam_vit_h_4b8939.pth', + model_dir=path_sam, + file_name='sam_vit_h_4b8939.pth' + ) + return os.path.join(path_sam, 'sam_vit_h_4b8939.pth') + + update_files() diff --git a/webui.py b/webui.py index 0d8e2396..a194b4c4 100644 --- a/webui.py +++ b/webui.py @@ -245,8 +245,7 @@ with shared.gradio_root: dino_text_threshold=text_threshold, box_erode_or_dilate=dino_erode_or_dilate, max_num_boxes=2, #TODO replace with actual value - sam_checkpoint="./models/sam/sam_vit_l_0b3195.pth", # TODO replace with actual value - model_type="vit_l" + model_type=sam_model ) return generate_mask_from_image(image, mask_model, extras, sam_options)