feat: add download for sam models to config
This commit is contained in:
parent
980563de9d
commit
ce1fb74270
|
|
@ -5,6 +5,7 @@ from segment_anything import sam_model_registry, SamPredictor
|
|||
from segment_anything.utils.amg import remove_small_regions
|
||||
|
||||
from extras.GroundingDINO.util.inference import default_groundingdino
|
||||
import modules.config
|
||||
|
||||
|
||||
class SAMOptions:
|
||||
|
|
@ -17,7 +18,6 @@ class SAMOptions:
|
|||
|
||||
# SAM
|
||||
max_num_boxes=2,
|
||||
sam_checkpoint="./models/sam/sam_vit_l_0b3195.pth",
|
||||
model_type="vit_l"
|
||||
):
|
||||
self.dino_prompt = dino_prompt
|
||||
|
|
@ -25,7 +25,6 @@ class SAMOptions:
|
|||
self.dino_text_threshold = dino_text_threshold
|
||||
self.box_erode_or_dilate = box_erode_or_dilate
|
||||
self.max_num_boxes = max_num_boxes
|
||||
self.sam_checkpoint = sam_checkpoint
|
||||
self.model_type = model_type
|
||||
|
||||
|
||||
|
|
@ -99,7 +98,8 @@ def generate_mask_from_image(image: np.ndarray, mask_model: str = 'sam', extras=
|
|||
# TODO add model patcher for model logic and device management
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
sam = sam_model_registry[sam_options.model_type](checkpoint=sam_options.sam_checkpoint)
|
||||
sam_checkpoint = modules.config.download_sam_model(sam_options.model_type)
|
||||
sam = sam_model_registry[sam_options.model_type](checkpoint=sam_checkpoint)
|
||||
sam.to(device=device)
|
||||
|
||||
sam_predictor = SamPredictor(sam)
|
||||
|
|
|
|||
|
|
@ -202,6 +202,7 @@ path_fooocus_expansion = get_dir_or_set_default('path_fooocus_expansion', '../mo
|
|||
path_safety_checker_models = get_dir_or_set_default('path_safety_checker_models', '../models/safety_checker_models/')
|
||||
path_wildcards = get_dir_or_set_default('path_wildcards', '../wildcards/')
|
||||
path_safety_checker = get_dir_or_set_default('path_safety_checker', '../models/safety_checker/')
|
||||
path_sam = get_dir_or_set_default('path_sam', '../models/sam/')
|
||||
path_outputs = get_path_output()
|
||||
|
||||
|
||||
|
|
@ -789,4 +790,43 @@ def downloading_safety_checker_model():
|
|||
return os.path.join(path_safety_checker, 'stable-diffusion-safety-checker.bin')
|
||||
|
||||
|
||||
def download_sam_model(sam_model: str) -> str:
|
||||
match sam_model:
|
||||
case 'default', 'vit_b':
|
||||
return downloading_sam_vit_b()
|
||||
case 'vit_l':
|
||||
return downloading_sam_vit_l()
|
||||
case 'vit_h':
|
||||
return downloading_sam_vit_h()
|
||||
case _:
|
||||
raise ValueError(f"sam model {sam_model} does not exist.")
|
||||
|
||||
|
||||
def downloading_sam_vit_b():
|
||||
load_file_from_url(
|
||||
url='https://huggingface.co/mashb1t/misc/resolve/main/sam_vit_b_01ec64.pth',
|
||||
model_dir=path_sam,
|
||||
file_name='sam_vit_b_01ec64.pth'
|
||||
)
|
||||
return os.path.join(path_sam, 'sam_vit_b_01ec64.pth')
|
||||
|
||||
|
||||
def downloading_sam_vit_l():
|
||||
load_file_from_url(
|
||||
url='https://huggingface.co/mashb1t/misc/resolve/main/sam_vit_l_0b3195.pth',
|
||||
model_dir=path_sam,
|
||||
file_name='sam_vit_l_0b3195.pth'
|
||||
)
|
||||
return os.path.join(path_sam, 'sam_vit_l_0b3195.pth')
|
||||
|
||||
|
||||
def downloading_sam_vit_h():
|
||||
load_file_from_url(
|
||||
url='https://huggingface.co/mashb1t/misc/resolve/main/sam_vit_h_4b8939.pth',
|
||||
model_dir=path_sam,
|
||||
file_name='sam_vit_h_4b8939.pth'
|
||||
)
|
||||
return os.path.join(path_sam, 'sam_vit_h_4b8939.pth')
|
||||
|
||||
|
||||
update_files()
|
||||
|
|
|
|||
3
webui.py
3
webui.py
|
|
@ -245,8 +245,7 @@ with shared.gradio_root:
|
|||
dino_text_threshold=text_threshold,
|
||||
box_erode_or_dilate=dino_erode_or_dilate,
|
||||
max_num_boxes=2, #TODO replace with actual value
|
||||
sam_checkpoint="./models/sam/sam_vit_l_0b3195.pth", # TODO replace with actual value
|
||||
model_type="vit_l"
|
||||
model_type=sam_model
|
||||
)
|
||||
|
||||
return generate_mask_from_image(image, mask_model, extras, sam_options)
|
||||
|
|
|
|||
Loading…
Reference in New Issue