feat: add download for sam models to config

This commit is contained in:
Manuel Schmid 2024-06-10 20:33:40 +02:00
parent 980563de9d
commit ce1fb74270
No known key found for this signature in database
GPG Key ID: 32C4F7569B40B84B
3 changed files with 44 additions and 5 deletions

View File

@ -5,6 +5,7 @@ from segment_anything import sam_model_registry, SamPredictor
from segment_anything.utils.amg import remove_small_regions
from extras.GroundingDINO.util.inference import default_groundingdino
import modules.config
class SAMOptions:
@ -17,7 +18,6 @@ class SAMOptions:
# SAM
max_num_boxes=2,
sam_checkpoint="./models/sam/sam_vit_l_0b3195.pth",
model_type="vit_l"
):
self.dino_prompt = dino_prompt
@ -25,7 +25,6 @@ class SAMOptions:
self.dino_text_threshold = dino_text_threshold
self.box_erode_or_dilate = box_erode_or_dilate
self.max_num_boxes = max_num_boxes
self.sam_checkpoint = sam_checkpoint
self.model_type = model_type
@ -99,7 +98,8 @@ def generate_mask_from_image(image: np.ndarray, mask_model: str = 'sam', extras=
# TODO add model patcher for model logic and device management
device = "cuda" if torch.cuda.is_available() else "cpu"
sam = sam_model_registry[sam_options.model_type](checkpoint=sam_options.sam_checkpoint)
sam_checkpoint = modules.config.download_sam_model(sam_options.model_type)
sam = sam_model_registry[sam_options.model_type](checkpoint=sam_checkpoint)
sam.to(device=device)
sam_predictor = SamPredictor(sam)

View File

@ -202,6 +202,7 @@ path_fooocus_expansion = get_dir_or_set_default('path_fooocus_expansion', '../mo
path_safety_checker_models = get_dir_or_set_default('path_safety_checker_models', '../models/safety_checker_models/')
path_wildcards = get_dir_or_set_default('path_wildcards', '../wildcards/')
path_safety_checker = get_dir_or_set_default('path_safety_checker', '../models/safety_checker/')
path_sam = get_dir_or_set_default('path_sam', '../models/sam/')
path_outputs = get_path_output()
@ -789,4 +790,43 @@ def downloading_safety_checker_model():
return os.path.join(path_safety_checker, 'stable-diffusion-safety-checker.bin')
def download_sam_model(sam_model: str) -> str:
match sam_model:
case 'default', 'vit_b':
return downloading_sam_vit_b()
case 'vit_l':
return downloading_sam_vit_l()
case 'vit_h':
return downloading_sam_vit_h()
case _:
raise ValueError(f"sam model {sam_model} does not exist.")
def downloading_sam_vit_b():
load_file_from_url(
url='https://huggingface.co/mashb1t/misc/resolve/main/sam_vit_b_01ec64.pth',
model_dir=path_sam,
file_name='sam_vit_b_01ec64.pth'
)
return os.path.join(path_sam, 'sam_vit_b_01ec64.pth')
def downloading_sam_vit_l():
load_file_from_url(
url='https://huggingface.co/mashb1t/misc/resolve/main/sam_vit_l_0b3195.pth',
model_dir=path_sam,
file_name='sam_vit_l_0b3195.pth'
)
return os.path.join(path_sam, 'sam_vit_l_0b3195.pth')
def downloading_sam_vit_h():
load_file_from_url(
url='https://huggingface.co/mashb1t/misc/resolve/main/sam_vit_h_4b8939.pth',
model_dir=path_sam,
file_name='sam_vit_h_4b8939.pth'
)
return os.path.join(path_sam, 'sam_vit_h_4b8939.pth')
update_files()

View File

@ -245,8 +245,7 @@ with shared.gradio_root:
dino_text_threshold=text_threshold,
box_erode_or_dilate=dino_erode_or_dilate,
max_num_boxes=2, #TODO replace with actual value
sam_checkpoint="./models/sam/sam_vit_l_0b3195.pth", # TODO replace with actual value
model_type="vit_l"
model_type=sam_model
)
return generate_mask_from_image(image, mask_model, extras, sam_options)