wip: add adetailer
This commit is contained in:
parent
bb72938261
commit
df70294a3e
|
|
@ -0,0 +1,278 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from collections import UserList
|
||||
from dataclasses import dataclass
|
||||
from functools import cached_property, partial
|
||||
from typing import Any, Literal, NamedTuple, Optional
|
||||
|
||||
try:
|
||||
from pydantic.v1 import (
|
||||
BaseModel,
|
||||
Extra,
|
||||
NonNegativeFloat,
|
||||
NonNegativeInt,
|
||||
PositiveInt,
|
||||
confloat,
|
||||
conint,
|
||||
validator,
|
||||
)
|
||||
except ImportError:
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
Extra,
|
||||
NonNegativeFloat,
|
||||
NonNegativeInt,
|
||||
PositiveInt,
|
||||
confloat,
|
||||
conint,
|
||||
validator,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
class SkipImg2ImgOrig:
    """Plain record holding four generation settings.

    NOTE(review): the name suggests these are the original values saved
    while an img2img pass is skipped -- confirm against the caller.
    """

    steps: int  # number of sampling steps
    sampler_name: str  # sampler identifier
    width: int  # image width
    height: int  # image height
|
||||
|
||||
|
||||
class Arg(NamedTuple):
    """Pair linking a model attribute name to its human-readable label."""

    attr: str  # attribute name on the args model, e.g. "ad_model"
    name: str  # display label, e.g. "ADetailer model"
|
||||
|
||||
|
||||
class ArgsList(UserList):
    """List of ``(attr, name)`` pairs with column-wise accessors.

    The accessors are plain properties rather than cached ones: the list is
    mutable, so a cached tuple would go stale after ``append``/``remove``.
    """

    @property
    def attrs(self) -> tuple[str, ...]:
        """Return the first element of every pair, in list order."""
        return tuple(attr for attr, _ in self)

    @property
    def names(self) -> tuple[str, ...]:
        """Return the second element of every pair, in list order."""
        return tuple(name for _, name in self)
|
||||
|
||||
|
||||
class ADetailerArgs(BaseModel, extra=Extra.forbid):
    """Validated settings for one ADetailer pass.

    Unknown fields are rejected (``extra=Extra.forbid``).
    """

    ad_model: str = "None"
    ad_model_classes: str = ""
    ad_tap_enable: bool = True
    ad_prompt: str = ""
    ad_negative_prompt: str = ""
    ad_confidence: confloat(ge=0.0, le=1.0) = 0.3
    ad_mask_k_largest: NonNegativeInt = 0
    ad_mask_min_ratio: confloat(ge=0.0, le=1.0) = 0.0
    ad_mask_max_ratio: confloat(ge=0.0, le=1.0) = 1.0
    ad_dilate_erode: int = 4
    ad_x_offset: int = 0
    ad_y_offset: int = 0
    ad_mask_merge_invert: Literal["None", "Merge", "Merge and Invert"] = "None"
    ad_mask_blur: NonNegativeInt = 4
    ad_denoising_strength: confloat(ge=0.0, le=1.0) = 0.4
    ad_inpaint_only_masked: bool = True
    ad_inpaint_only_masked_padding: NonNegativeInt = 32
    ad_use_inpaint_width_height: bool = False
    ad_inpaint_width: PositiveInt = 512
    ad_inpaint_height: PositiveInt = 512
    ad_use_steps: bool = False
    ad_steps: PositiveInt = 28
    ad_use_cfg_scale: bool = False
    ad_cfg_scale: NonNegativeFloat = 7.0
    ad_use_checkpoint: bool = False
    ad_checkpoint: Optional[str] = None
    ad_use_vae: bool = False
    ad_vae: Optional[str] = None
    ad_use_sampler: bool = False
    ad_sampler: str = "DPM++ 2M Karras"
    ad_scheduler: str = "Use same scheduler"
    ad_use_noise_multiplier: bool = False
    ad_noise_multiplier: confloat(ge=0.5, le=1.5) = 1.0
    ad_use_clip_skip: bool = False
    ad_clip_skip: conint(ge=1, le=12) = 1
    ad_restore_face: bool = False
    ad_controlnet_model: str = "None"
    ad_controlnet_module: str = "None"
    ad_controlnet_weight: confloat(ge=0.0, le=1.0) = 1.0
    ad_controlnet_guidance_start: confloat(ge=0.0, le=1.0) = 0.0
    ad_controlnet_guidance_end: confloat(ge=0.0, le=1.0) = 1.0
    is_api: bool = True

    @validator("is_api", pre=True)
    def is_api_validator(cls, v: Any):  # noqa: N805
        """Return False only when the raw value is exactly a ``tuple``.

        A tuple can be serialized to JSON but never comes back out of JSON
        deserialization, so a tuple here means the value was not sent as JSON.
        """
        return type(v) is not tuple

    @staticmethod
    def ppop(
        p: dict[str, Any],
        key: str,
        pops: list[str] | None = None,
        cond: Any = None,
    ) -> None:
        """Conditionally delete entries from ``p`` in place.

        Parameters
        ----------
        p: dict[str, Any]
            mapping to prune
        key: str
            entry whose value decides whether anything is removed;
            nothing happens when it is absent
        pops: list[str] | None
            keys to delete when the condition holds (defaults to ``[key]``)
        cond: Any
            when None, the condition is "value is falsy";
            otherwise it is "value == cond"
        """
        if key not in p:
            return

        targets = [key] if pops is None else pops
        current = p[key]
        if cond is None:
            should_drop = not bool(current)
        else:
            should_drop = current == cond

        if should_drop:
            for target in targets:
                p.pop(target, None)

    def extra_params(self, suffix: str = "") -> dict[str, Any]:
        """Build a label -> value mapping with default-valued entries removed.

        Returns an empty dict when ``need_skip()`` is true. ``suffix`` is
        appended to every key of the result when it is non-empty.
        """
        if self.need_skip():
            return {}

        params = {label: getattr(self, attr) for attr, label in ALL_ARGS}

        # This entry is never reported, regardless of its value.
        params.pop("ADetailer tap enable", None)

        # (deciding key, keys to drop, condition) -- applied in this order.
        rules = (
            ("ADetailer model classes", None, None),
            ("ADetailer prompt", None, None),
            ("ADetailer negative prompt", None, None),
            ("ADetailer mask only top k largest", None, 0),
            ("ADetailer mask min ratio", None, 0.0),
            ("ADetailer mask max ratio", None, 1.0),
            ("ADetailer x offset", None, 0),
            ("ADetailer y offset", None, 0),
            ("ADetailer mask merge invert", None, "None"),
            ("ADetailer inpaint only masked", ["ADetailer inpaint padding"], None),
            (
                "ADetailer use inpaint width height",
                [
                    "ADetailer use inpaint width height",
                    "ADetailer inpaint width",
                    "ADetailer inpaint height",
                ],
                None,
            ),
            (
                "ADetailer use separate steps",
                ["ADetailer use separate steps", "ADetailer steps"],
                None,
            ),
            (
                "ADetailer use separate CFG scale",
                ["ADetailer use separate CFG scale", "ADetailer CFG scale"],
                None,
            ),
            (
                "ADetailer use separate checkpoint",
                ["ADetailer use separate checkpoint", "ADetailer checkpoint"],
                None,
            ),
            (
                "ADetailer use separate VAE",
                ["ADetailer use separate VAE", "ADetailer VAE"],
                None,
            ),
            (
                "ADetailer use separate sampler",
                [
                    "ADetailer use separate sampler",
                    "ADetailer sampler",
                    "ADetailer scheduler",
                ],
                None,
            ),
            ("ADetailer scheduler", None, "Use same scheduler"),
            (
                "ADetailer use separate noise multiplier",
                [
                    "ADetailer use separate noise multiplier",
                    "ADetailer noise multiplier",
                ],
                None,
            ),
            (
                "ADetailer use separate CLIP skip",
                ["ADetailer use separate CLIP skip", "ADetailer CLIP skip"],
                None,
            ),
            ("ADetailer restore face", None, None),
            (
                "ADetailer ControlNet model",
                [
                    "ADetailer ControlNet model",
                    "ADetailer ControlNet module",
                    "ADetailer ControlNet weight",
                    "ADetailer ControlNet guidance start",
                    "ADetailer ControlNet guidance end",
                ],
                "None",
            ),
            ("ADetailer ControlNet module", None, "None"),
            ("ADetailer ControlNet weight", None, 1.0),
            ("ADetailer ControlNet guidance start", None, 0.0),
            ("ADetailer ControlNet guidance end", None, 1.0),
        )
        for decider, targets, expected in rules:
            self.ppop(params, decider, targets, expected)

        if not suffix:
            return params
        return {label + suffix: value for label, value in params.items()}

    def is_mediapipe(self) -> bool:
        """Return True when the model name starts with "mediapipe" (case-insensitive)."""
        lowered = self.ad_model.lower()
        return lowered.startswith("mediapipe")

    def need_skip(self) -> bool:
        """Return True when no model is selected or the tab is disabled."""
        if self.ad_model == "None":
            return True
        return self.ad_tap_enable is False
|
||||
|
||||
|
||||
# (model attribute, display label) pairs, in reporting order.
_all_args = [
    ("ad_model", "ADetailer model"),
    ("ad_model_classes", "ADetailer model classes"),
    ("ad_tap_enable", "ADetailer tap enable"),
    ("ad_prompt", "ADetailer prompt"),
    ("ad_negative_prompt", "ADetailer negative prompt"),
    ("ad_confidence", "ADetailer confidence"),
    ("ad_mask_k_largest", "ADetailer mask only top k largest"),
    ("ad_mask_min_ratio", "ADetailer mask min ratio"),
    ("ad_mask_max_ratio", "ADetailer mask max ratio"),
    ("ad_x_offset", "ADetailer x offset"),
    ("ad_y_offset", "ADetailer y offset"),
    ("ad_dilate_erode", "ADetailer dilate erode"),
    ("ad_mask_merge_invert", "ADetailer mask merge invert"),
    ("ad_mask_blur", "ADetailer mask blur"),
    ("ad_denoising_strength", "ADetailer denoising strength"),
    ("ad_inpaint_only_masked", "ADetailer inpaint only masked"),
    ("ad_inpaint_only_masked_padding", "ADetailer inpaint padding"),
    ("ad_use_inpaint_width_height", "ADetailer use inpaint width height"),
    ("ad_inpaint_width", "ADetailer inpaint width"),
    ("ad_inpaint_height", "ADetailer inpaint height"),
    ("ad_use_steps", "ADetailer use separate steps"),
    ("ad_steps", "ADetailer steps"),
    ("ad_use_cfg_scale", "ADetailer use separate CFG scale"),
    ("ad_cfg_scale", "ADetailer CFG scale"),
    ("ad_use_checkpoint", "ADetailer use separate checkpoint"),
    ("ad_checkpoint", "ADetailer checkpoint"),
    ("ad_use_vae", "ADetailer use separate VAE"),
    ("ad_vae", "ADetailer VAE"),
    ("ad_use_sampler", "ADetailer use separate sampler"),
    ("ad_sampler", "ADetailer sampler"),
    ("ad_scheduler", "ADetailer scheduler"),
    ("ad_use_noise_multiplier", "ADetailer use separate noise multiplier"),
    ("ad_noise_multiplier", "ADetailer noise multiplier"),
    ("ad_use_clip_skip", "ADetailer use separate CLIP skip"),
    ("ad_clip_skip", "ADetailer CLIP skip"),
    ("ad_restore_face", "ADetailer restore face"),
    ("ad_controlnet_model", "ADetailer ControlNet model"),
    ("ad_controlnet_module", "ADetailer ControlNet module"),
    ("ad_controlnet_weight", "ADetailer ControlNet weight"),
    ("ad_controlnet_guidance_start", "ADetailer ControlNet guidance start"),
    ("ad_controlnet_guidance_end", "ADetailer ControlNet guidance end"),
]

# Wrap each raw pair in the named tuple, then expose the whole table.
_args = [Arg(attr, label) for attr, label in _all_args]
ALL_ARGS = ArgsList(_args)
|
||||
|
||||
BBOX_SORTBY = [
|
||||
"None",
|
||||
"Position (left to right)",
|
||||
"Position (center to edge)",
|
||||
"Area (large to small)",
|
||||
]
|
||||
MASK_MERGE_INVERT = ["None", "Merge", "Merge and Invert"]
|
||||
|
||||
_script_default = (
|
||||
"dynamic_prompting",
|
||||
"dynamic_thresholding",
|
||||
"wildcard_recursive",
|
||||
"wildcards",
|
||||
"lora_block_weight",
|
||||
"negpip",
|
||||
)
|
||||
SCRIPT_DEFAULT = ",".join(sorted(_script_default))
|
||||
|
||||
_builtin_script = ("soft_inpainting", "hypertile_script")
|
||||
BUILTIN_SCRIPT = ",".join(sorted(_builtin_script))
|
||||
|
|
@ -0,0 +1,161 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from collections import OrderedDict
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Generic, Optional, TypeVar
|
||||
|
||||
from huggingface_hub import hf_hub_download
|
||||
from PIL import Image, ImageDraw
|
||||
from torchvision.transforms.functional import to_pil_image
|
||||
|
||||
REPO_ID = "Bingsu/adetailer"
|
||||
|
||||
T = TypeVar("T", int, float)
|
||||
|
||||
|
||||
@dataclass
class PredictOutput(Generic[T]):
    """Result of one detection run; every field defaults to "nothing found"."""

    # One [x1, y1, x2, y2] box per detection.
    bboxes: list[list[T]] = field(default_factory=list)
    # One mask image per detection.
    masks: list[Image.Image] = field(default_factory=list)
    # Annotated preview image, or None when absent.
    preview: Optional[Image.Image] = None
|
||||
|
||||
|
||||
def hf_download(file: str, repo_id: str = REPO_ID) -> str:
    """Fetch ``file`` from a Hugging Face repository.

    Returns the path reported by ``hf_hub_download``. Any failure is
    reported on stdout and the sentinel string ``"INVALID"`` is returned
    instead of raising.
    """
    try:
        return hf_hub_download(repo_id, file)
    except Exception:
        # Best effort: callers filter out the "INVALID" sentinel.
        print(f"[ADetailer] Failed to load model {file!r} from huggingface")
        return "INVALID"
|
||||
|
||||
|
||||
def safe_mkdir(path: str | os.PathLike[str]) -> None:
    """Create ``path`` as a directory when that is possible without fuss.

    Nothing happens when ``path`` already exists, when its parent does not
    exist, or when the parent is not writable.
    """
    target = Path(path)
    if target.exists():
        return

    parent = target.parent
    if not parent.exists():
        return

    if os.access(parent, os.W_OK):
        target.mkdir()
|
||||
|
||||
|
||||
def scan_model_dir(path: Path) -> list[Path]:
    """Recursively collect every regular file under ``path`` ending in ``.pt``.

    Returns an empty list when ``path`` is not a directory.
    """
    found: list[Path] = []
    if path.is_dir():
        for entry in path.rglob("*"):
            if entry.is_file() and entry.suffix == ".pt":
                found.append(entry)
    return found
|
||||
|
||||
|
||||
def download_models(*names: str) -> dict[str, str]:
    """Download several model files concurrently.

    Names containing ``"-world"`` are fetched from the
    ``Bingsu/yolo-world-mirror`` repository; all others use the default
    repository of ``hf_download``. Returns a name -> path mapping in the
    order the names were given.
    """
    pending = OrderedDict()
    with ThreadPoolExecutor() as pool:
        for model_name in names:
            overrides = (
                {"repo_id": "Bingsu/yolo-world-mirror"}
                if "-world" in model_name
                else {}
            )
            pending[model_name] = pool.submit(hf_download, model_name, **overrides)

        # Collect results while the pool is still open.
        resolved = {}
        for model_name, job in pending.items():
            resolved[model_name] = job.result()
        return resolved
|
||||
|
||||
|
||||
def get_models(
    *dirs: str | os.PathLike[str], huggingface: bool = True
) -> OrderedDict[str, str]:
    """Assemble the name -> location mapping of available detection models.

    Parameters
    ----------
    dirs: str | os.PathLike[str]
        directories scanned recursively for ``.pt`` files; falsy entries
        are ignored
    huggingface: bool
        when True, the bundled list of models is downloaded first

    Returns
    -------
    OrderedDict[str, str]
        downloaded models (failed downloads removed), then the mediapipe
        pseudo-models, then local files whose file name is not already taken
    """
    local_files = [
        found
        for dir_ in dirs
        if dir_
        for found in scan_model_dir(Path(dir_))
    ]

    models = OrderedDict()
    if huggingface:
        models.update(
            download_models(
                "face_yolov8n.pt",
                "face_yolov8s.pt",
                "hand_yolov8n.pt",
                "person_yolov8n-seg.pt",
                "person_yolov8s-seg.pt",
                "yolov8x-worldv2.pt",
            )
        )

    # Mediapipe entries map a name to itself rather than to a file path.
    for mediapipe_name in (
        "mediapipe_face_full",
        "mediapipe_face_short",
        "mediapipe_face_mesh",
        "mediapipe_face_mesh_eyes_only",
    ):
        models[mediapipe_name] = mediapipe_name

    # Drop downloads that failed (marked with the "INVALID" sentinel).
    for failed in [key for key, value in models.items() if value == "INVALID"]:
        models.pop(failed)

    # Local files never override an entry that is already present.
    for model_file in local_files:
        models.setdefault(model_file.name, str(model_file))

    return models
|
||||
|
||||
|
||||
def create_mask_from_bbox(
    bboxes: list[list[float]], shape: tuple[int, int]
) -> list[Image.Image]:
    """
    Draw one rectangular mask per bounding box.

    Parameters
    ----------
        bboxes: list[list[float]]
            list of [x1, y1, x2, y2] bounding boxes
        shape: tuple[int, int]
            shape of the image (width, height)

    Returns
    -------
        masks: list[Image.Image]
            one "L"-mode image per box: 0 everywhere, 255 inside the box
    """

    def _draw(bbox: list[float]) -> Image.Image:
        canvas = Image.new("L", shape, 0)
        ImageDraw.Draw(canvas).rectangle(bbox, fill=255)
        return canvas

    return [_draw(bbox) for bbox in bboxes]
|
||||
|
||||
|
||||
def create_bbox_from_mask(
    masks: list[Image.Image], shape: tuple[int, int]
) -> list[list[int]]:
    """
    Compute the bounding box of each mask after resizing it to ``shape``.

    Parameters
    ----------
        masks: list[Image.Image]
            A list of masks
        shape: tuple[int, int]
            shape of the image (width, height)

    Returns
    -------
        bboxes: list[list[int]]
            A list of bounding boxes; masks for which ``getbbox()``
            returns None contribute nothing
    """
    resized_boxes = (mask.resize(shape).getbbox() for mask in masks)
    return [list(box) for box in resized_boxes if box is not None]
|
||||
|
||||
|
||||
def ensure_pil_image(image: Any, mode: str = "RGB") -> Image.Image:
    """Return ``image`` as a PIL image in the requested ``mode``.

    Non-PIL inputs go through ``to_pil_image`` first; a mode conversion
    happens only when the current mode differs.
    """
    pil = image if isinstance(image, Image.Image) else to_pil_image(image)
    return pil if pil.mode == mode else pil.convert(mode)
|
||||
|
|
@ -0,0 +1,269 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from enum import IntEnum
|
||||
from functools import partial, reduce
|
||||
from math import dist
|
||||
from typing import Any, TypeVar
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from PIL import Image, ImageChops
|
||||
|
||||
from extras.adetailer.args import MASK_MERGE_INVERT
|
||||
from extras.adetailer.common import ensure_pil_image, PredictOutput
|
||||
|
||||
|
||||
class SortBy(IntEnum):
    """Bounding-box ordering modes."""

    NONE = 0  # keep detection order
    LEFT_TO_RIGHT = 1  # ascending x1
    CENTER_TO_EDGE = 2  # ascending distance from the image center
    AREA = 3  # largest box first
|
||||
|
||||
|
||||
class MergeInvert(IntEnum):
    """Mask post-processing modes."""

    NONE = 0  # leave masks untouched
    MERGE = 1  # combine all masks into one
    MERGE_INVERT = 2  # combine all masks, then invert the result
|
||||
|
||||
|
||||
T = TypeVar("T", int, float)
|
||||
|
||||
|
||||
def _dilate(arr: np.ndarray, value: int) -> np.ndarray:
    """Dilate ``arr`` once with a ``value`` x ``value`` rectangular kernel."""
    size = (value, value)
    rect = cv2.getStructuringElement(cv2.MORPH_RECT, size)
    return cv2.dilate(arr, rect, iterations=1)
|
||||
|
||||
|
||||
def _erode(arr: np.ndarray, value: int) -> np.ndarray:
    """Erode ``arr`` once with a ``value`` x ``value`` rectangular kernel."""
    size = (value, value)
    rect = cv2.getStructuringElement(cv2.MORPH_RECT, size)
    return cv2.erode(arr, rect, iterations=1)
|
||||
|
||||
|
||||
def dilate_erode(img: Image.Image, value: int) -> Image.Image:
    """
    Grow or shrink the bright regions of an image.

    A positive value dilates, a negative value erodes with the absolute
    value as kernel size, and zero returns the input unchanged.

    Parameters
    ----------
        img: PIL.Image.Image
            the image to be processed
        value: int
            signed kernel size of dilation (> 0) or erosion (< 0)

    Returns
    -------
        PIL.Image.Image
            The image that has been dilated or eroded
    """
    if value == 0:
        return img

    pixels = np.array(img)
    if value > 0:
        result = _dilate(pixels, value)
    else:
        result = _erode(pixels, -value)
    return Image.fromarray(result)
|
||||
|
||||
|
||||
def offset(img: Image.Image, x: int = 0, y: int = 0) -> Image.Image:
    """
    Shift an image by x to the right and y upwards.

    Parameters
    ----------
        img: Image.Image
            the image to shift
        x: int
            rightward shift (→)
        y: int
            upward shift (↑); negated before being passed on

    Returns
    -------
        PIL.Image.Image
            A new image that is offset by x and y
    """
    upward = -y
    return ImageChops.offset(img, x, upward)
|
||||
|
||||
|
||||
def is_all_black(img: Image.Image | np.ndarray) -> bool:
    """Return True when the image has no non-zero pixel.

    PIL images are converted to an "L"-mode array first; arrays are
    passed to ``cv2.countNonZero`` unchanged.
    """
    pixels = np.array(ensure_pil_image(img, "L")) if isinstance(img, Image.Image) else img
    return cv2.countNonZero(pixels) == 0
|
||||
|
||||
|
||||
def has_intersection(im1: Any, im2: Any) -> bool:
    """Return True when the two images share at least one non-zero pixel.

    Both inputs are converted to "L"-mode arrays and combined with a
    bitwise AND.
    """
    first, second = (np.array(ensure_pil_image(im, "L")) for im in (im1, im2))
    overlap = cv2.bitwise_and(first, second)
    return not is_all_black(overlap)
|
||||
|
||||
|
||||
def bbox_area(bbox: list[T]) -> T:
    """Return the area of an [x1, y1, x2, y2] box as width times height."""
    width = bbox[2] - bbox[0]
    height = bbox[3] - bbox[1]
    return width * height
|
||||
|
||||
|
||||
def mask_preprocess(
    masks: list[Image.Image],
    kernel: int = 0,
    x_offset: int = 0,
    y_offset: int = 0,
    merge_invert: int | MergeInvert | str = MergeInvert.NONE,
) -> list[Image.Image]:
    """
    Offset, dilate/erode and optionally merge/invert a list of masks.

    Steps run in that order. After a non-zero dilate/erode, masks that
    ended up completely black are discarded.

    Parameters
    ----------
        masks: list[Image.Image]
            A list of masks
        kernel: int
            signed kernel size of dilation or erosion; 0 skips the step
        x_offset: int
            →
        y_offset: int
            ↑
        merge_invert: int | MergeInvert | str
            merge mode handed to ``mask_merge_invert``

    Returns
    -------
        list[Image.Image]
            A list of processed masks (empty when ``masks`` is empty)
    """
    if not masks:
        return []

    processed = masks
    if x_offset != 0 or y_offset != 0:
        processed = [offset(mask, x_offset, y_offset) for mask in processed]

    if kernel != 0:
        processed = list(map(partial(dilate_erode, value=kernel), processed))
        # Erosion can wipe a mask out entirely; such masks are useless.
        processed = [mask for mask in processed if not is_all_black(mask)]

    return mask_merge_invert(processed, mode=merge_invert)
|
||||
|
||||
|
||||
# Bbox sorting
|
||||
def _key_left_to_right(bbox: list[T]) -> T:
|
||||
"""
|
||||
Left to right
|
||||
|
||||
Parameters
|
||||
----------
|
||||
bbox: list[int] | list[float]
|
||||
list of [x1, y1, x2, y2]
|
||||
"""
|
||||
return bbox[0]
|
||||
|
||||
|
||||
def _key_center_to_edge(bbox: list[T], *, center: tuple[float, float]) -> float:
|
||||
"""
|
||||
Center to edge
|
||||
|
||||
Parameters
|
||||
----------
|
||||
bbox: list[int] | list[float]
|
||||
list of [x1, y1, x2, y2]
|
||||
image: Image.Image
|
||||
the image
|
||||
"""
|
||||
bbox_center = ((bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2)
|
||||
return dist(center, bbox_center)
|
||||
|
||||
|
||||
def _key_area(bbox: list[T]) -> T:
    """
    Sort key for large-to-small ordering: the negated box area.

    Parameters
    ----------
        bbox: list[int] | list[float]
            list of [x1, y1, x2, y2]
    """
    area = bbox_area(bbox)
    return -area
|
||||
|
||||
|
||||
def sort_bboxes(
    pred: PredictOutput[T], order: int | SortBy = SortBy.NONE
) -> PredictOutput[T]:
    """Reorder the boxes and masks of ``pred`` in place and return it.

    Parameters
    ----------
        pred: PredictOutput
            prediction to reorder; ``bboxes`` and ``masks`` are reassigned
        order: int | SortBy
            ordering mode; ``CENTER_TO_EDGE`` reads ``pred.preview.size``

    Raises
    ------
        RuntimeError
            when ``order`` is not one of the known modes
    """
    count = len(pred.bboxes)
    if order == SortBy.NONE or count <= 1:
        return pred

    if order == SortBy.LEFT_TO_RIGHT:
        key = _key_left_to_right
    elif order == SortBy.CENTER_TO_EDGE:
        width, height = pred.preview.size
        key = partial(_key_center_to_edge, center=(width / 2, height / 2))
    elif order == SortBy.AREA:
        key = _key_area
    else:
        raise RuntimeError

    # Rank positions by their box so boxes and masks stay paired.
    ranked = sorted(enumerate(pred.bboxes), key=lambda pair: key(pair[1]))
    positions = [position for position, _ in ranked]
    pred.bboxes = [pred.bboxes[position] for position in positions]
    pred.masks = [pred.masks[position] for position in positions]
    return pred
|
||||
|
||||
|
||||
# Filter by ratio
|
||||
def is_in_ratio(bbox: list[T], low: float, high: float, orig_area: int) -> bool:
    """Return True when the box area divided by ``orig_area`` lies in [low, high]."""
    ratio = bbox_area(bbox) / orig_area
    return low <= ratio <= high
|
||||
|
||||
|
||||
def filter_by_ratio(
    pred: PredictOutput[T], low: float, high: float
) -> PredictOutput[T]:
    """Keep only boxes whose share of the preview area lies in [low, high].

    ``pred`` is modified in place and returned; an empty prediction is
    returned untouched. The reference area is the size of ``pred.preview``.
    """
    if not pred.bboxes:
        return pred

    width, height = pred.preview.size
    total_area = width * height
    kept = [
        position
        for position, bbox in enumerate(pred.bboxes)
        if is_in_ratio(bbox, low, high, total_area)
    ]
    pred.bboxes = [pred.bboxes[position] for position in kept]
    pred.masks = [pred.masks[position] for position in kept]
    return pred
|
||||
|
||||
|
||||
def filter_k_largest(pred: PredictOutput[T], k: int = 0) -> PredictOutput[T]:
    """Keep the ``k`` boxes with the largest area, largest first.

    ``pred`` is modified in place and returned. ``k == 0`` or an empty
    prediction leaves it untouched.
    """
    if not pred.bboxes or k == 0:
        return pred

    areas = [bbox_area(bbox) for bbox in pred.bboxes]
    # argsort is ascending: take the last k positions, then flip them.
    largest_first = np.argsort(areas)[-k:][::-1]
    pred.bboxes = [pred.bboxes[position] for position in largest_first]
    pred.masks = [pred.masks[position] for position in largest_first]
    return pred
|
||||
|
||||
|
||||
# Merge / Invert
|
||||
def mask_merge(masks: list[Image.Image]) -> list[Image.Image]:
    """Combine all masks with a bitwise OR into a single-element list."""
    as_arrays = (np.array(mask) for mask in masks)
    combined = reduce(cv2.bitwise_or, as_arrays)
    return [Image.fromarray(combined)]
|
||||
|
||||
|
||||
def mask_invert(masks: list[Image.Image]) -> list[Image.Image]:
    """Return a new list holding the inverted version of every mask."""
    return list(map(ImageChops.invert, masks))
|
||||
|
||||
|
||||
def mask_merge_invert(
    masks: list[Image.Image], mode: int | MergeInvert | str
) -> list[Image.Image]:
    """Apply the selected merge mode to ``masks``.

    Parameters
    ----------
        masks: list[Image.Image]
            masks to process; an empty list is returned as-is
        mode: int | MergeInvert | str
            numeric mode, or one of the labels in ``MASK_MERGE_INVERT``
            (translated to its list index)

    Raises
    ------
        RuntimeError
            when the numeric mode is not a known ``MergeInvert`` value
    """
    selected = MASK_MERGE_INVERT.index(mode) if isinstance(mode, str) else mode

    if selected == MergeInvert.NONE or not masks:
        return masks
    if selected == MergeInvert.MERGE:
        return mask_merge(masks)
    if selected == MergeInvert.MERGE_INVERT:
        return mask_invert(mask_merge(masks))

    raise RuntimeError
|
||||
|
|
@ -0,0 +1,53 @@
|
|||
from extras.adetailer.args import ADetailerArgs
|
||||
from extras.adetailer.common import get_models, PredictOutput
|
||||
from extras.adetailer.mask import filter_by_ratio, filter_k_largest, sort_bboxes, mask_preprocess
|
||||
from modules import config
|
||||
|
||||
model_mapping = get_models(
|
||||
config.path_adetailer,
|
||||
huggingface=True,
|
||||
)
|
||||
|
||||
|
||||
def get_ad_model(name: str):
    """Look up ``name`` in ``model_mapping`` and return the mapped value.

    Raises
    ------
        ValueError
            when the name is unknown; the message lists the available names
    """
    if name in model_mapping:
        return model_mapping[name]

    msg = f"[-] ADetailer: Model {name!r} not found. Available models: {list(model_mapping.keys())}"
    raise ValueError(msg)
|
||||
|
||||
|
||||
def pred_preprocessing(p, pred: PredictOutput, args: ADetailerArgs, inpaint_only_masked=False):
    """Filter a prediction and turn its masks into final inpaint masks.

    Boxes are filtered by area ratio, reduced to the k largest, passed
    through ``sort_bboxes`` with its default order, and the remaining masks
    are preprocessed with the offset / dilate-erode / merge settings from
    ``args``.

    ``p`` and ``inpaint_only_masked`` are currently unused.
    TODO: filter the masks against the image mask of ``p`` when
    ``inpaint_only_masked`` is set.
    """
    filtered = filter_by_ratio(
        pred, low=args.ad_mask_min_ratio, high=args.ad_mask_max_ratio
    )
    filtered = filter_k_largest(filtered, k=args.ad_mask_k_largest)
    filtered = sort_bboxes(filtered)

    return mask_preprocess(
        filtered.masks,
        kernel=args.ad_dilate_erode,
        x_offset=args.ad_x_offset,
        y_offset=args.ad_y_offset,
        merge_invert=args.ad_mask_merge_invert,
    )
|
||||
|
||||
|
||||
# def get_image_mask(p) -> Image.Image:
|
||||
# mask = p.image_mask
|
||||
# if getattr(p, "inpainting_mask_invert", False):
|
||||
# mask = ImageChops.invert(mask)
|
||||
# mask = create_binary_mask(mask)
|
||||
#
|
||||
# if is_skip_img2img(p):
|
||||
# if hasattr(p, "init_images") and p.init_images:
|
||||
# width, height = p.init_images[0].size
|
||||
# else:
|
||||
# msg = "[-] ADetailer: no init_images."
|
||||
# raise RuntimeError(msg)
|
||||
# else:
|
||||
# width, height = p.width, p.height
|
||||
# return images.resize_image(p.resize_mode, mask, width, height)
|
||||
|
|
@ -0,0 +1,67 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import cv2
|
||||
from PIL import Image
|
||||
from torchvision.transforms.functional import to_pil_image
|
||||
|
||||
from extras.adetailer.common import PredictOutput, create_mask_from_bbox
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import torch
|
||||
from ultralytics import YOLO, YOLOWorld
|
||||
|
||||
|
||||
def ultralytics_predict(
    model_path: str | Path,
    image: Image.Image,
    confidence: float = 0.3,
    device: str = "",
    classes: str = "",
) -> PredictOutput[float]:
    """Run a YOLO model on ``image`` and package boxes, masks and a preview.

    Parameters
    ----------
        model_path: str | Path
            weights file handed to ``YOLO``
        image: Image.Image
            image to analyse
        confidence: float
            passed to the model as ``conf``
        device: str
            passed to the model as ``device``
        classes: str
            comma-separated class names, forwarded to ``apply_classes``

    Returns
    -------
        PredictOutput[float]
            empty when nothing was detected; otherwise boxes, one mask per
            box (rectangles when the model yields no masks) and an RGB
            preview image
    """
    from ultralytics import YOLO

    detector = YOLO(model_path)
    apply_classes(detector, model_path, classes)
    result = detector(image, conf=confidence, device=device)[0]

    xyxy = result.boxes.xyxy.cpu().numpy()
    if xyxy.size == 0:
        return PredictOutput()
    boxes = xyxy.tolist()

    if result.masks is None:
        # No masks from the model: fall back to box-shaped masks.
        masks = create_mask_from_bbox(boxes, image.size)
    else:
        masks = mask_to_pil(result.masks.data, image.size)

    # The plotted array is converted from BGR to RGB before wrapping it.
    plotted = cv2.cvtColor(result.plot(), cv2.COLOR_BGR2RGB)
    return PredictOutput(bboxes=boxes, masks=masks, preview=Image.fromarray(plotted))
|
||||
|
||||
|
||||
def apply_classes(model: YOLO | YOLOWorld, model_path: str | Path, classes: str):
    """Restrict a "-world" model to the given comma-separated class names.

    Nothing happens when ``classes`` is empty, when the file stem of
    ``model_path`` lacks ``"-world"``, or when no non-blank name remains
    after splitting.
    """
    if not classes:
        return
    if "-world" not in Path(model_path).stem:
        return

    wanted = []
    for chunk in classes.split(","):
        label = chunk.strip()
        if label:
            wanted.append(label)

    if wanted:
        model.set_classes(wanted)
|
||||
|
||||
|
||||
def mask_to_pil(masks: torch.Tensor, shape: tuple[int, int]) -> list[Image.Image]:
    """
    Convert a stack of masks into "L"-mode PIL images resized to ``shape``.

    Parameters
    ----------
        masks: torch.Tensor, dtype=torch.float32, shape=(N, H, W).
            The device can be CUDA, but `to_pil_image` takes care of that.

        shape: tuple[int, int]
            (W, H) of the original image
    """
    images = []
    for index in range(masks.shape[0]):
        single = to_pil_image(masks[index], mode="L")
        images.append(single.resize(shape))
    return images
|
||||
|
||||
|
||||
|
|
@ -229,7 +229,7 @@ def worker():
|
|||
|
||||
def process_task(all_steps, async_task, callback, controlnet_canny_path, controlnet_cpds_path, current_task_id,
|
||||
denoising_strength, final_scheduler_name, goals, initial_latent, switch, task, tasks,
|
||||
tiled, use_expansion, width, height):
|
||||
tiled, use_expansion, width, height, cleanup_conds=True):
|
||||
if async_task.last_stop is not False:
|
||||
ldm_patched.modules.model_management.interrupt_current_processing()
|
||||
positive_cond, negative_cond = task['c'], task['uc']
|
||||
|
|
@ -260,7 +260,8 @@ def worker():
|
|||
refiner_swap_method=async_task.refiner_swap_method,
|
||||
disable_preview=async_task.disable_preview
|
||||
)
|
||||
del task['c'], task['uc'], positive_cond, negative_cond # Save memory
|
||||
if cleanup_conds:
|
||||
del task['c'], task['uc'], positive_cond, negative_cond # Save memory
|
||||
if inpaint_worker.current_task is not None:
|
||||
imgs = [inpaint_worker.current_task.post_process(x) for x in imgs]
|
||||
current_progress = int(flags.preparation_step_count + (100 - flags.preparation_step_count) * float(
|
||||
|
|
@ -1007,9 +1008,51 @@ def worker():
|
|||
execution_start_time = time.perf_counter()
|
||||
|
||||
try:
|
||||
process_task(all_steps, async_task, callback, controlnet_canny_path, controlnet_cpds_path,
|
||||
imgs, img_paths = process_task(all_steps, async_task, callback, controlnet_canny_path, controlnet_cpds_path,
|
||||
current_task_id, denoising_strength, final_scheduler_name, goals, initial_latent,
|
||||
switch, task, tasks, tiled, use_expansion, width, height)
|
||||
switch, task, tasks, tiled, use_expansion, width, height, False)
|
||||
|
||||
# adetailer
|
||||
for img in imgs:
|
||||
from extras.adetailer.ultralytics_predict import ultralytics_predict
|
||||
predictor = ultralytics_predict
|
||||
from extras.adetailer.script import get_ad_model
|
||||
ad_model = get_ad_model('face_yolov8n.pt')
|
||||
|
||||
kwargs = {}
|
||||
kwargs["device"] = torch.device('cpu')
|
||||
kwargs["classes"] = ""
|
||||
from PIL import Image
|
||||
img2 = Image.fromarray(img)
|
||||
pred = predictor(ad_model, img2, **kwargs)
|
||||
|
||||
if pred.preview is None:
|
||||
print(
|
||||
f"[-] ADetailer: nothing detected on image"
|
||||
)
|
||||
return False
|
||||
|
||||
from extras.adetailer.args import ADetailerArgs
|
||||
args = ADetailerArgs()
|
||||
from extras.adetailer.script import pred_preprocessing
|
||||
masks = pred_preprocessing(img, pred, args)
|
||||
merged_masks = np.maximum(*[np.array(mask) for mask in masks])
|
||||
async_task.yields.append(['preview', (100, '...', merged_masks)])
|
||||
denoising_strength = 0.5
|
||||
inpaint_head_model_path = None
|
||||
inpaint_parameterized = False
|
||||
denoising_strength, initial_latent, width, height = apply_inpaint(async_task, None,
|
||||
inpaint_head_model_path, img,
|
||||
merged_masks,
|
||||
inpaint_parameterized,
|
||||
denoising_strength, switch)
|
||||
|
||||
imgs, img_paths = process_task(all_steps, async_task, callback, controlnet_canny_path,
|
||||
controlnet_cpds_path,
|
||||
current_task_id, denoising_strength, final_scheduler_name, goals,
|
||||
initial_latent,
|
||||
switch, task, tasks, tiled, use_expansion, width, height)
|
||||
|
||||
except ldm_patched.modules.model_management.InterruptProcessingException:
|
||||
if async_task.last_stop == 'skip':
|
||||
print('User skipped')
|
||||
|
|
|
|||
|
|
@ -191,6 +191,7 @@ def get_dir_or_set_default(key, default_value, as_array=False, make_directory=Fa
|
|||
paths_checkpoints = get_dir_or_set_default('path_checkpoints', ['../models/checkpoints/'], True)
|
||||
paths_loras = get_dir_or_set_default('path_loras', ['../models/loras/'], True)
|
||||
path_embeddings = get_dir_or_set_default('path_embeddings', '../models/embeddings/')
|
||||
path_adetailer = get_dir_or_set_default('path_adetailer', '../models/adetailer/')
|
||||
path_vae_approx = get_dir_or_set_default('path_vae_approx', '../models/vae_approx/')
|
||||
path_vae = get_dir_or_set_default('path_vae', '../models/vae/')
|
||||
path_upscale_models = get_dir_or_set_default('path_upscale_models', '../models/upscale_models/')
|
||||
|
|
|
|||
|
|
@ -19,4 +19,5 @@ onnxruntime==1.16.3
|
|||
timm==0.9.2
|
||||
translators==5.8.9
|
||||
rembg==2.0.53
|
||||
groundingdino-py==0.4.0
|
||||
groundingdino-py==0.4.0
|
||||
ultralytics==8.2.28
|
||||
Loading…
Reference in New Issue