From df70294a3e03a3111a6b483abb16545bcbac8fc7 Mon Sep 17 00:00:00 2001
From: Manuel Schmid
Date: Sat, 8 Jun 2024 23:30:45 +0200
Subject: [PATCH] wip: add adetailer

---
 extras/adetailer/args.py                | 278 ++++++++++++++++++++++++
 extras/adetailer/common.py              | 161 ++++++++++++++
 extras/adetailer/mask.py                | 269 +++++++++++++++++++++++
 extras/adetailer/script.py              |  53 +++++
 extras/adetailer/ultralytics_predict.py |  67 ++++++
 modules/async_worker.py                 |  51 ++++-
 modules/config.py                       |   1 +
 requirements_versions.txt               |   3 +-
 8 files changed, 878 insertions(+), 5 deletions(-)
 create mode 100644 extras/adetailer/args.py
 create mode 100644 extras/adetailer/common.py
 create mode 100644 extras/adetailer/mask.py
 create mode 100644 extras/adetailer/script.py
 create mode 100644 extras/adetailer/ultralytics_predict.py

diff --git a/extras/adetailer/args.py b/extras/adetailer/args.py
new file mode 100644
index 00000000..08ad4a3a
--- /dev/null
+++ b/extras/adetailer/args.py
@@ -0,0 +1,278 @@
+from __future__ import annotations
+
+from collections import UserList
+from dataclasses import dataclass
+from functools import cached_property, partial
+from typing import Any, Literal, NamedTuple, Optional
+
+try:
+    from pydantic.v1 import (
+        BaseModel,
+        Extra,
+        NonNegativeFloat,
+        NonNegativeInt,
+        PositiveInt,
+        confloat,
+        conint,
+        validator,
+    )
+except ImportError:
+    from pydantic import (
+        BaseModel,
+        Extra,
+        NonNegativeFloat,
+        NonNegativeInt,
+        PositiveInt,
+        confloat,
+        conint,
+        validator,
+    )
+
+
+@dataclass
+class SkipImg2ImgOrig:
+    steps: int
+    sampler_name: str
+    width: int
+    height: int
+
+
+class Arg(NamedTuple):
+    attr: str
+    name: str
+
+
+class ArgsList(UserList):
+    @cached_property
+    def attrs(self) -> tuple[str, ...]:
+        return tuple(attr for attr, _ in self)
+
+    @cached_property
+    def names(self) -> tuple[str, ...]:
+        return tuple(name for _, name in self)
+
+
+class ADetailerArgs(BaseModel, extra=Extra.forbid):
+    ad_model: str = "None"
+    ad_model_classes: str = ""
+    ad_tap_enable: bool = True
+    ad_prompt: str = ""
+    ad_negative_prompt: str = ""
+    ad_confidence: confloat(ge=0.0, le=1.0) = 0.3
+    ad_mask_k_largest: NonNegativeInt = 0
+    ad_mask_min_ratio: confloat(ge=0.0, le=1.0) = 0.0
+    ad_mask_max_ratio: confloat(ge=0.0, le=1.0) = 1.0
+    ad_dilate_erode: int = 4
+    ad_x_offset: int = 0
+    ad_y_offset: int = 0
+    ad_mask_merge_invert: Literal["None", "Merge", "Merge and Invert"] = "None"
+    ad_mask_blur: NonNegativeInt = 4
+    ad_denoising_strength: confloat(ge=0.0, le=1.0) = 0.4
+    ad_inpaint_only_masked: bool = True
+    ad_inpaint_only_masked_padding: NonNegativeInt = 32
+    ad_use_inpaint_width_height: bool = False
+    ad_inpaint_width: PositiveInt = 512
+    ad_inpaint_height: PositiveInt = 512
+    ad_use_steps: bool = False
+    ad_steps: PositiveInt = 28
+    ad_use_cfg_scale: bool = False
+    ad_cfg_scale: NonNegativeFloat = 7.0
+    ad_use_checkpoint: bool = False
+    ad_checkpoint: Optional[str] = None
+    ad_use_vae: bool = False
+    ad_vae: Optional[str] = None
+    ad_use_sampler: bool = False
+    ad_sampler: str = "DPM++ 2M Karras"
+    ad_scheduler: str = "Use same scheduler"
+    ad_use_noise_multiplier: bool = False
+    ad_noise_multiplier: confloat(ge=0.5, le=1.5) = 1.0
+    ad_use_clip_skip: bool = False
+    ad_clip_skip: conint(ge=1, le=12) = 1
+    ad_restore_face: bool = False
+    ad_controlnet_model: str = "None"
+    ad_controlnet_module: str = "None"
+    ad_controlnet_weight: confloat(ge=0.0, le=1.0) = 1.0
+    ad_controlnet_guidance_start: confloat(ge=0.0, le=1.0) = 0.0
+    ad_controlnet_guidance_end: confloat(ge=0.0, le=1.0) = 1.0
+    is_api: bool = True
+
+    @validator("is_api", pre=True)
+    def is_api_validator(cls, v: Any):  # noqa: N805
+        "A tuple is JSON serializable but cannot be produced by JSON deserialization, so a tuple value means the call did not come through the API."
+        return type(v) is not tuple
+
+    @staticmethod
+    def ppop(
+        p: dict[str, Any],
+        key: str,
+        pops: list[str] | None = None,
+        cond: Any = None,
+    ) -> None:
+        if pops is None:
+            pops = [key]
+        if key not in p:
+            return
+        value = p[key]
+        cond = (not bool(value)) if cond is None else value == cond
+
+        if cond:
+            for k in pops:
+                p.pop(k, None)
+
+    def extra_params(self, suffix: str = "") -> dict[str, Any]:
+        if self.need_skip():
+            return {}
+
+        p = {name: getattr(self, attr) for attr, name in ALL_ARGS}
+        ppop = partial(self.ppop, p)
+
+        ppop("ADetailer model classes")
+        ppop("ADetailer prompt")
+        ppop("ADetailer negative prompt")
+        p.pop("ADetailer tap enable", None)  # always pop
+        ppop("ADetailer mask only top k largest", cond=0)
+        ppop("ADetailer mask min ratio", cond=0.0)
+        ppop("ADetailer mask max ratio", cond=1.0)
+        ppop("ADetailer x offset", cond=0)
+        ppop("ADetailer y offset", cond=0)
+        ppop("ADetailer mask merge invert", cond="None")
+        ppop("ADetailer inpaint only masked", ["ADetailer inpaint padding"])
+        ppop(
+            "ADetailer use inpaint width height",
+            [
+                "ADetailer use inpaint width height",
+                "ADetailer inpaint width",
+                "ADetailer inpaint height",
+            ],
+        )
+        ppop(
+            "ADetailer use separate steps",
+            ["ADetailer use separate steps", "ADetailer steps"],
+        )
+        ppop(
+            "ADetailer use separate CFG scale",
+            ["ADetailer use separate CFG scale", "ADetailer CFG scale"],
+        )
+        ppop(
+            "ADetailer use separate checkpoint",
+            ["ADetailer use separate checkpoint", "ADetailer checkpoint"],
+        )
+        ppop(
+            "ADetailer use separate VAE",
+            ["ADetailer use separate VAE", "ADetailer VAE"],
+        )
+        ppop(
+            "ADetailer use separate sampler",
+            [
+                "ADetailer use separate sampler",
+                "ADetailer sampler",
+                "ADetailer scheduler",
+            ],
+        )
+        ppop("ADetailer scheduler", cond="Use same scheduler")
+        ppop(
+            "ADetailer use separate noise multiplier",
+            ["ADetailer use separate noise multiplier", "ADetailer noise multiplier"],
+        )
+
+        ppop(
+            "ADetailer use separate CLIP skip",
+            ["ADetailer use separate CLIP skip", "ADetailer CLIP skip"],
+        )
+
+        ppop("ADetailer restore face")
+        ppop(
+            "ADetailer ControlNet model",
+            [
+                "ADetailer ControlNet model",
+                "ADetailer ControlNet module",
+                "ADetailer ControlNet weight",
+                "ADetailer ControlNet guidance start",
+                "ADetailer ControlNet guidance end",
+            ],
+            cond="None",
+        )
+        ppop("ADetailer ControlNet module", cond="None")
+        ppop("ADetailer ControlNet weight", cond=1.0)
+        ppop("ADetailer ControlNet guidance start", cond=0.0)
+        ppop("ADetailer ControlNet guidance end", cond=1.0)
+
+        if suffix:
+            p = {k + suffix: v for k, v in p.items()}
+
+        return p
+
+    def is_mediapipe(self) -> bool:
+        return self.ad_model.lower().startswith("mediapipe")
+
+    def need_skip(self) -> bool:
+        return self.ad_model == "None" or self.ad_tap_enable is False
+
+
+_all_args = [
+    ("ad_model", "ADetailer model"),
+    ("ad_model_classes", "ADetailer model classes"),
+    ("ad_tap_enable", "ADetailer tap enable"),
+    ("ad_prompt", "ADetailer prompt"),
+    ("ad_negative_prompt", "ADetailer negative prompt"),
+    ("ad_confidence", "ADetailer confidence"),
+    ("ad_mask_k_largest", "ADetailer mask only top k largest"),
+    ("ad_mask_min_ratio", "ADetailer mask min ratio"),
+    ("ad_mask_max_ratio", "ADetailer mask max ratio"),
+    ("ad_x_offset", "ADetailer x offset"),
+    ("ad_y_offset", "ADetailer y offset"),
+    ("ad_dilate_erode", "ADetailer dilate erode"),
+    ("ad_mask_merge_invert", "ADetailer mask merge invert"),
+    ("ad_mask_blur", "ADetailer mask blur"),
+    ("ad_denoising_strength", "ADetailer denoising strength"),
+    ("ad_inpaint_only_masked", "ADetailer inpaint only masked"),
+    ("ad_inpaint_only_masked_padding", "ADetailer inpaint padding"),
+    ("ad_use_inpaint_width_height", "ADetailer use inpaint width height"),
+    ("ad_inpaint_width", "ADetailer inpaint width"),
+    ("ad_inpaint_height", "ADetailer inpaint height"),
+    ("ad_use_steps", "ADetailer use separate steps"),
+    ("ad_steps", "ADetailer steps"),
+    ("ad_use_cfg_scale", "ADetailer use separate CFG scale"),
+    ("ad_cfg_scale", "ADetailer CFG scale"),
+    ("ad_use_checkpoint", "ADetailer use separate checkpoint"),
+    ("ad_checkpoint", "ADetailer checkpoint"),
+    ("ad_use_vae", "ADetailer use separate VAE"),
+    ("ad_vae", "ADetailer VAE"),
+    ("ad_use_sampler", "ADetailer use separate sampler"),
+    ("ad_sampler", "ADetailer sampler"),
+    ("ad_scheduler", "ADetailer scheduler"),
+    ("ad_use_noise_multiplier", "ADetailer use separate noise multiplier"),
+    ("ad_noise_multiplier", "ADetailer noise multiplier"),
+    ("ad_use_clip_skip", "ADetailer use separate CLIP skip"),
+    ("ad_clip_skip", "ADetailer CLIP skip"),
+    ("ad_restore_face", "ADetailer restore face"),
+    ("ad_controlnet_model", "ADetailer ControlNet model"),
+    ("ad_controlnet_module", "ADetailer ControlNet module"),
+    ("ad_controlnet_weight", "ADetailer ControlNet weight"),
+    ("ad_controlnet_guidance_start", "ADetailer ControlNet guidance start"),
+    ("ad_controlnet_guidance_end", "ADetailer ControlNet guidance end"),
+]
+
+_args = [Arg(*args) for args in _all_args]
+ALL_ARGS = ArgsList(_args)
+
+BBOX_SORTBY = [
+    "None",
+    "Position (left to right)",
+    "Position (center to edge)",
+    "Area (large to small)",
+]
+MASK_MERGE_INVERT = ["None", "Merge", "Merge and Invert"]
+
+_script_default = (
+    "dynamic_prompting",
+    "dynamic_thresholding",
+    "wildcard_recursive",
+    "wildcards",
+    "lora_block_weight",
+    "negpip",
+)
+SCRIPT_DEFAULT = ",".join(sorted(_script_default))
+
+_builtin_script = ("soft_inpainting", "hypertile_script")
+BUILTIN_SCRIPT = ",".join(sorted(_builtin_script))
\ No newline at end of file
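Review note, not part of the patch: a minimal usage sketch for the ADetailerArgs
model above. Defaults validate as-is, and extra_params() emits the infotext
dictionary while dropping entries that still match their defaults; is_api is
deliberately absent from ALL_ARGS and thus never serialized.

    from extras.adetailer.args import ADetailerArgs

    args = ADetailerArgs(ad_model="face_yolov8n.pt", ad_confidence=0.5)
    assert not args.need_skip()  # real model name and ad_tap_enable=True
    params = args.extra_params(suffix=" 1st")
    # e.g. {'ADetailer model 1st': 'face_yolov8n.pt',
    #       'ADetailer confidence 1st': 0.5, ...}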
erode"), + ("ad_mask_merge_invert", "ADetailer mask merge invert"), + ("ad_mask_blur", "ADetailer mask blur"), + ("ad_denoising_strength", "ADetailer denoising strength"), + ("ad_inpaint_only_masked", "ADetailer inpaint only masked"), + ("ad_inpaint_only_masked_padding", "ADetailer inpaint padding"), + ("ad_use_inpaint_width_height", "ADetailer use inpaint width height"), + ("ad_inpaint_width", "ADetailer inpaint width"), + ("ad_inpaint_height", "ADetailer inpaint height"), + ("ad_use_steps", "ADetailer use separate steps"), + ("ad_steps", "ADetailer steps"), + ("ad_use_cfg_scale", "ADetailer use separate CFG scale"), + ("ad_cfg_scale", "ADetailer CFG scale"), + ("ad_use_checkpoint", "ADetailer use separate checkpoint"), + ("ad_checkpoint", "ADetailer checkpoint"), + ("ad_use_vae", "ADetailer use separate VAE"), + ("ad_vae", "ADetailer VAE"), + ("ad_use_sampler", "ADetailer use separate sampler"), + ("ad_sampler", "ADetailer sampler"), + ("ad_scheduler", "ADetailer scheduler"), + ("ad_use_noise_multiplier", "ADetailer use separate noise multiplier"), + ("ad_noise_multiplier", "ADetailer noise multiplier"), + ("ad_use_clip_skip", "ADetailer use separate CLIP skip"), + ("ad_clip_skip", "ADetailer CLIP skip"), + ("ad_restore_face", "ADetailer restore face"), + ("ad_controlnet_model", "ADetailer ControlNet model"), + ("ad_controlnet_module", "ADetailer ControlNet module"), + ("ad_controlnet_weight", "ADetailer ControlNet weight"), + ("ad_controlnet_guidance_start", "ADetailer ControlNet guidance start"), + ("ad_controlnet_guidance_end", "ADetailer ControlNet guidance end"), +] + +_args = [Arg(*args) for args in _all_args] +ALL_ARGS = ArgsList(_args) + +BBOX_SORTBY = [ + "None", + "Position (left to right)", + "Position (center to edge)", + "Area (large to small)", +] +MASK_MERGE_INVERT = ["None", "Merge", "Merge and Invert"] + +_script_default = ( + "dynamic_prompting", + "dynamic_thresholding", + "wildcard_recursive", + "wildcards", + "lora_block_weight", + "negpip", +) +SCRIPT_DEFAULT = ",".join(sorted(_script_default)) + +_builtin_script = ("soft_inpainting", "hypertile_script") +BUILTIN_SCRIPT = ",".join(sorted(_builtin_script)) \ No newline at end of file diff --git a/extras/adetailer/common.py b/extras/adetailer/common.py new file mode 100644 index 00000000..f80103fc --- /dev/null +++ b/extras/adetailer/common.py @@ -0,0 +1,161 @@ +from __future__ import annotations + +import os +from collections import OrderedDict +from concurrent.futures import ThreadPoolExecutor +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Generic, Optional, TypeVar + +from huggingface_hub import hf_hub_download +from PIL import Image, ImageDraw +from torchvision.transforms.functional import to_pil_image + +REPO_ID = "Bingsu/adetailer" + +T = TypeVar("T", int, float) + + +@dataclass +class PredictOutput(Generic[T]): + bboxes: list[list[T]] = field(default_factory=list) + masks: list[Image.Image] = field(default_factory=list) + preview: Optional[Image.Image] = None + + +def hf_download(file: str, repo_id: str = REPO_ID) -> str: + try: + path = hf_hub_download(repo_id, file) + except Exception: + print(f"[ADetailer] Failed to load model {file!r} from huggingface") + path = "INVALID" + return path + + +def safe_mkdir(path: str | os.PathLike[str]) -> None: + path = Path(path) + if not path.exists() and path.parent.exists() and os.access(path.parent, os.W_OK): + path.mkdir() + + +def scan_model_dir(path: Path) -> list[Path]: + if not path.is_dir(): + return [] + return [p for 
diff --git a/extras/adetailer/mask.py b/extras/adetailer/mask.py
new file mode 100644
index 00000000..2faee71a
--- /dev/null
+++ b/extras/adetailer/mask.py
@@ -0,0 +1,269 @@
+from __future__ import annotations
+
+from enum import IntEnum
+from functools import partial, reduce
+from math import dist
+from typing import Any, TypeVar
+
+import cv2
+import numpy as np
+from PIL import Image, ImageChops
+
+from extras.adetailer.args import MASK_MERGE_INVERT
+from extras.adetailer.common import ensure_pil_image, PredictOutput
+
+
+class SortBy(IntEnum):
+    NONE = 0
+    LEFT_TO_RIGHT = 1
+    CENTER_TO_EDGE = 2
+    AREA = 3
+
+
+class MergeInvert(IntEnum):
+    NONE = 0
+    MERGE = 1
+    MERGE_INVERT = 2
+
+
+T = TypeVar("T", int, float)
+
+
+def _dilate(arr: np.ndarray, value: int) -> np.ndarray:
+    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (value, value))
+    return cv2.dilate(arr, kernel, iterations=1)
+
+
+def _erode(arr: np.ndarray, value: int) -> np.ndarray:
+    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (value, value))
+    return cv2.erode(arr, kernel, iterations=1)
+
+
+def dilate_erode(img: Image.Image, value: int) -> Image.Image:
+    """
+    The dilate_erode function takes an image and a value.
+    If the value is positive, it dilates the image by that amount.
+    If the value is negative, it erodes the image by that amount.
+
+    Parameters
+    ----------
+    img: PIL.Image.Image
+        the image to be processed
+    value: int
+        kernel size of dilation or erosion
+
+    Returns
+    -------
+    PIL.Image.Image
+        The image that has been dilated or eroded
+    """
+    if value == 0:
+        return img
+
+    arr = np.array(img)
+    arr = _dilate(arr, value) if value > 0 else _erode(arr, -value)
+
+    return Image.fromarray(arr)
+
+
+def offset(img: Image.Image, x: int = 0, y: int = 0) -> Image.Image:
+    """
+    The offset function takes an image and offsets it by a given x(→) and y(↑) value.
+
+    Parameters
+    ----------
+    img: Image.Image
+        the image to offset
+    x: int
+        →
+    y: int
+        ↑
+
+    Returns
+    -------
+    PIL.Image.Image
+        A new image that is offset by x and y
+    """
+    return ImageChops.offset(img, x, -y)
+
+
+def is_all_black(img: Image.Image | np.ndarray) -> bool:
+    if isinstance(img, Image.Image):
+        img = np.array(ensure_pil_image(img, "L"))
+    return cv2.countNonZero(img) == 0
+
+
+def has_intersection(im1: Any, im2: Any) -> bool:
+    arr1 = np.array(ensure_pil_image(im1, "L"))
+    arr2 = np.array(ensure_pil_image(im2, "L"))
+    return not is_all_black(cv2.bitwise_and(arr1, arr2))
+
+
+def bbox_area(bbox: list[T]) -> T:
+    return (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])
+
+
+def mask_preprocess(
+    masks: list[Image.Image],
+    kernel: int = 0,
+    x_offset: int = 0,
+    y_offset: int = 0,
+    merge_invert: int | MergeInvert | str = MergeInvert.NONE,
+) -> list[Image.Image]:
+    """
+    The mask_preprocess function takes a list of masks and preprocesses them.
+    It dilates and erodes the masks, offsets them by x_offset and y_offset,
+    and finally merges and/or inverts them according to merge_invert.
+
+    Parameters
+    ----------
+    masks: list[Image.Image]
+        A list of masks
+    kernel: int
+        kernel size of dilation or erosion
+    x_offset: int
+        →
+    y_offset: int
+        ↑
+    merge_invert: int | MergeInvert | str
+        whether to merge and/or invert the masks
+
+    Returns
+    -------
+    list[Image.Image]
+        A list of processed masks
+    """
+    if not masks:
+        return []
+
+    if x_offset != 0 or y_offset != 0:
+        masks = [offset(m, x_offset, y_offset) for m in masks]
+
+    if kernel != 0:
+        masks = [dilate_erode(m, kernel) for m in masks]
+        masks = [m for m in masks if not is_all_black(m)]
+
+    return mask_merge_invert(masks, mode=merge_invert)
+
+
+# Bbox sorting
+def _key_left_to_right(bbox: list[T]) -> T:
+    """
+    Left to right
+
+    Parameters
+    ----------
+    bbox: list[int] | list[float]
+        list of [x1, y1, x2, y2]
+    """
+    return bbox[0]
+
+
+def _key_center_to_edge(bbox: list[T], *, center: tuple[float, float]) -> float:
+    """
+    Center to edge
+
+    Parameters
+    ----------
+    bbox: list[int] | list[float]
+        list of [x1, y1, x2, y2]
+    center: tuple[float, float]
+        the center of the image
+    """
+    bbox_center = ((bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2)
+    return dist(center, bbox_center)
+
+
+def _key_area(bbox: list[T]) -> T:
+    """
+    Large to small
+
+    Parameters
+    ----------
+    bbox: list[int] | list[float]
+        list of [x1, y1, x2, y2]
+    """
+    return -bbox_area(bbox)
+
+
+def sort_bboxes(
+    pred: PredictOutput[T], order: int | SortBy = SortBy.NONE
+) -> PredictOutput[T]:
+    if order == SortBy.NONE or len(pred.bboxes) <= 1:
+        return pred
+
+    if order == SortBy.LEFT_TO_RIGHT:
+        key = _key_left_to_right
+    elif order == SortBy.CENTER_TO_EDGE:
+        width, height = pred.preview.size
+        center = (width / 2, height / 2)
+        key = partial(_key_center_to_edge, center=center)
+    elif order == SortBy.AREA:
+        key = _key_area
+    else:
+        raise RuntimeError(f"Invalid sort order: {order!r}")
+
+    items = len(pred.bboxes)
+    idx = sorted(range(items), key=lambda i: key(pred.bboxes[i]))
+    pred.bboxes = [pred.bboxes[i] for i in idx]
+    pred.masks = [pred.masks[i] for i in idx]
+    return pred
+
+
+# Filter by ratio
+def is_in_ratio(bbox: list[T], low: float, high: float, orig_area: int) -> bool:
+    area = bbox_area(bbox)
+    return low <= area / orig_area <= high
+
+
+def filter_by_ratio(
+    pred: PredictOutput[T], low: float, high: float
+) -> PredictOutput[T]:
+    if not pred.bboxes:
+        return pred
+
+    w, h = pred.preview.size
+    orig_area = w * h
+    items = len(pred.bboxes)
+    idx = [i for i in range(items) if is_in_ratio(pred.bboxes[i], low, high, orig_area)]
+    pred.bboxes = [pred.bboxes[i] for i in idx]
+    pred.masks = [pred.masks[i] for i in idx]
+    return pred
+
+
+def filter_k_largest(pred: PredictOutput[T], k: int = 0) -> PredictOutput[T]:
+    if not pred.bboxes or k == 0:
+        return pred
+    areas = [bbox_area(bbox) for bbox in pred.bboxes]
+    idx = np.argsort(areas)[-k:]
+    idx = idx[::-1]
+    pred.bboxes = [pred.bboxes[i] for i in idx]
+    pred.masks = [pred.masks[i] for i in idx]
+    return pred
+
+
+# Merge / Invert
+def mask_merge(masks: list[Image.Image]) -> list[Image.Image]:
+    arrs = [np.array(m) for m in masks]
+    arr = reduce(cv2.bitwise_or, arrs)
+    return [Image.fromarray(arr)]
+
+
+def mask_invert(masks: list[Image.Image]) -> list[Image.Image]:
+    return [ImageChops.invert(m) for m in masks]
+
+
+def mask_merge_invert(
+    masks: list[Image.Image], mode: int | MergeInvert | str
+) -> list[Image.Image]:
+    if isinstance(mode, str):
+        mode = MASK_MERGE_INVERT.index(mode)
+
+    if mode == MergeInvert.NONE or not masks:
+        return masks
+
+    if mode == MergeInvert.MERGE:
+        return mask_merge(masks)
+
+    if mode == MergeInvert.MERGE_INVERT:
+        merged = mask_merge(masks)
+        return mask_invert(merged)
+
+    raise RuntimeError(f"Invalid merge mode: {mode!r}")
\ No newline at end of file
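Review note, not part of the patch: a quick check of the sign convention in
dilate_erode() above — positive values dilate, negative values erode.

    import numpy as np
    from PIL import Image
    from extras.adetailer.mask import dilate_erode

    m = Image.new("L", (9, 9), 0)
    m.putpixel((4, 4), 255)
    grown = dilate_erode(m, 3)    # 3x3 dilation: single pixel becomes a 3x3 block
    shrunk = dilate_erode(m, -3)  # 3x3 erosion: single pixel disappears
    assert np.array(grown).sum() > np.array(m).sum()
    assert np.array(shrunk).sum() == 0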
diff --git a/extras/adetailer/script.py b/extras/adetailer/script.py
new file mode 100644
index 00000000..05a4110e
--- /dev/null
+++ b/extras/adetailer/script.py
@@ -0,0 +1,53 @@
+from extras.adetailer.args import ADetailerArgs
+from extras.adetailer.common import get_models, PredictOutput
+from extras.adetailer.mask import filter_by_ratio, filter_k_largest, sort_bboxes, mask_preprocess
+from modules import config
+
+model_mapping = get_models(
+    config.path_adetailer,
+    huggingface=True,
+)
+
+
+def get_ad_model(name: str):
+    if name not in model_mapping:
+        msg = f"[-] ADetailer: Model {name!r} not found. Available models: {list(model_mapping.keys())}"
+        raise ValueError(msg)
+    return model_mapping[name]
+
+
+def pred_preprocessing(p, pred: PredictOutput, args: ADetailerArgs, inpaint_only_masked=False):
+    pred = filter_by_ratio(
+        pred, low=args.ad_mask_min_ratio, high=args.ad_mask_max_ratio
+    )
+    pred = filter_k_largest(pred, k=args.ad_mask_k_largest)
+    pred = sort_bboxes(pred)
+    masks = mask_preprocess(
+        pred.masks,
+        kernel=args.ad_dilate_erode,
+        x_offset=args.ad_x_offset,
+        y_offset=args.ad_y_offset,
+        merge_invert=args.ad_mask_merge_invert,
+    )
+
+    # if inpaint_only_masked:
+    #     image_mask = self.get_image_mask(p)
+    #     masks = self.inpaint_mask_filter(image_mask, masks)
+    return masks
+
+
+# def get_image_mask(p) -> Image.Image:
+#     mask = p.image_mask
+#     if getattr(p, "inpainting_mask_invert", False):
+#         mask = ImageChops.invert(mask)
+#     mask = create_binary_mask(mask)
+#
+#     if is_skip_img2img(p):
+#         if hasattr(p, "init_images") and p.init_images:
+#             width, height = p.init_images[0].size
+#         else:
+#             msg = "[-] ADetailer: no init_images."
+#             raise RuntimeError(msg)
+#     else:
+#         width, height = p.width, p.height
+#     return images.resize_image(p.resize_mode, mask, width, height)
\ No newline at end of file
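Review note, not part of the patch: the filter → sort → preprocess composition
that pred_preprocessing() wires together, run on a hand-built PredictOutput so
it is self-contained (values are illustrative).

    from PIL import Image
    from extras.adetailer.common import PredictOutput
    from extras.adetailer.mask import (
        filter_by_ratio, filter_k_largest, sort_bboxes, mask_preprocess, SortBy,
    )

    preview = Image.new("RGB", (64, 64))
    m = Image.new("L", (64, 64), 0)
    m.paste(255, (8, 8, 32, 32))
    pred = PredictOutput(bboxes=[[8.0, 8.0, 32.0, 32.0]], masks=[m], preview=preview)

    pred = filter_by_ratio(pred, low=0.0, high=1.0)  # keep boxes covering 0-100% of the image
    pred = filter_k_largest(pred, k=0)               # k=0 keeps everything
    pred = sort_bboxes(pred, SortBy.AREA)            # largest first
    masks = mask_preprocess(pred.masks, kernel=4, merge_invert="None")  # dilate by 4 px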
diff --git a/extras/adetailer/ultralytics_predict.py b/extras/adetailer/ultralytics_predict.py
new file mode 100644
index 00000000..b028ea83
--- /dev/null
+++ b/extras/adetailer/ultralytics_predict.py
@@ -0,0 +1,67 @@
+from __future__ import annotations
+
+from pathlib import Path
+from typing import TYPE_CHECKING
+
+import cv2
+from PIL import Image
+from torchvision.transforms.functional import to_pil_image
+
+from extras.adetailer.common import PredictOutput, create_mask_from_bbox
+
+if TYPE_CHECKING:
+    import torch
+    from ultralytics import YOLO, YOLOWorld
+
+
+def ultralytics_predict(
+    model_path: str | Path,
+    image: Image.Image,
+    confidence: float = 0.3,
+    device: str = "",
+    classes: str = "",
+) -> PredictOutput[float]:
+    from ultralytics import YOLO
+
+    model = YOLO(model_path)
+    apply_classes(model, model_path, classes)
+    pred = model(image, conf=confidence, device=device)
+
+    bboxes = pred[0].boxes.xyxy.cpu().numpy()
+    if bboxes.size == 0:
+        return PredictOutput()
+    bboxes = bboxes.tolist()
+
+    if pred[0].masks is None:
+        masks = create_mask_from_bbox(bboxes, image.size)
+    else:
+        masks = mask_to_pil(pred[0].masks.data, image.size)
+    preview = pred[0].plot()
+    preview = cv2.cvtColor(preview, cv2.COLOR_BGR2RGB)
+    preview = Image.fromarray(preview)
+
+    return PredictOutput(bboxes=bboxes, masks=masks, preview=preview)
+
+
+def apply_classes(model: YOLO | YOLOWorld, model_path: str | Path, classes: str):
+    if not classes or "-world" not in Path(model_path).stem:
+        return
+    parsed = [c.strip() for c in classes.split(",") if c.strip()]
+    if parsed:
+        model.set_classes(parsed)
+
+
+def mask_to_pil(masks: torch.Tensor, shape: tuple[int, int]) -> list[Image.Image]:
+    """
+    Parameters
+    ----------
+    masks: torch.Tensor, dtype=torch.float32, shape=(N, H, W).
+        The device can be CUDA, but `to_pil_image` takes care of that.
+
+    shape: tuple[int, int]
+        (W, H) of the original image
+    """
+    n = masks.shape[0]
+    return [to_pil_image(masks[i], mode="L").resize(shape) for i in range(n)]
+
+
diff --git a/modules/async_worker.py b/modules/async_worker.py
index 5e7c561f..b79598fe 100644
--- a/modules/async_worker.py
+++ b/modules/async_worker.py
@@ -229,7 +229,7 @@ def worker():
         def process_task(all_steps, async_task, callback, controlnet_canny_path, controlnet_cpds_path,
                          current_task_id, denoising_strength, final_scheduler_name, goals, initial_latent, switch, task, tasks,
-                         tiled, use_expansion, width, height):
+                         tiled, use_expansion, width, height, cleanup_conds=True):
             if async_task.last_stop is not False:
                 ldm_patched.modules.model_management.interrupt_current_processing()
             positive_cond, negative_cond = task['c'], task['uc']
@@ -260,7 +260,8 @@ def worker():
                 refiner_swap_method=async_task.refiner_swap_method,
                 disable_preview=async_task.disable_preview
             )
-            del task['c'], task['uc'], positive_cond, negative_cond  # Save memory
+            if cleanup_conds:
+                del task['c'], task['uc'], positive_cond, negative_cond  # Save memory
             if inpaint_worker.current_task is not None:
                 imgs = [inpaint_worker.current_task.post_process(x) for x in imgs]
             current_progress = int(flags.preparation_step_count + (100 - flags.preparation_step_count) * float(
@@ -1007,9 +1008,51 @@ def worker():
                 execution_start_time = time.perf_counter()

                 try:
-                    process_task(all_steps, async_task, callback, controlnet_canny_path, controlnet_cpds_path,
-                                 current_task_id, denoising_strength, final_scheduler_name, goals, initial_latent,
-                                 switch, task, tasks, tiled, use_expansion, width, height)
+                    imgs, img_paths = process_task(all_steps, async_task, callback, controlnet_canny_path, controlnet_cpds_path,
+                                                   current_task_id, denoising_strength, final_scheduler_name, goals, initial_latent,
+                                                   switch, task, tasks, tiled, use_expansion, width, height, False)
+
+                    # adetailer
+                    for img in imgs:
+                        from extras.adetailer.ultralytics_predict import ultralytics_predict
+                        predictor = ultralytics_predict
+                        from extras.adetailer.script import get_ad_model
+                        ad_model = get_ad_model('face_yolov8n.pt')
+
+                        kwargs = {}
+                        kwargs["device"] = torch.device('cpu')
+                        kwargs["classes"] = ""
+                        from PIL import Image
+                        img2 = Image.fromarray(img)
+                        pred = predictor(ad_model, img2, **kwargs)
+
+                        if pred.preview is None:
+                            print('[-] ADetailer: nothing detected on image')
+                            continue  # nothing to refine for this image; keep the base output
+
+                        from extras.adetailer.args import ADetailerArgs
+                        args = ADetailerArgs()
+                        from extras.adetailer.script import pred_preprocessing
+                        masks = pred_preprocessing(img, pred, args)
+                        # reduce handles a single mask too; np.maximum(*masks) would fail for one element
+                        merged_masks = np.maximum.reduce([np.array(mask) for mask in masks])
+                        async_task.yields.append(['preview', (100, '...', merged_masks)])
+                        denoising_strength = 0.5
+                        inpaint_head_model_path = None
+                        inpaint_parameterized = False
+                        denoising_strength, initial_latent, width, height = apply_inpaint(async_task, None,
+                                                                                          inpaint_head_model_path, img,
+                                                                                          merged_masks,
+                                                                                          inpaint_parameterized,
+                                                                                          denoising_strength, switch)
+
+                        imgs, img_paths = process_task(all_steps, async_task, callback, controlnet_canny_path,
+                                                       controlnet_cpds_path,
+                                                       current_task_id, denoising_strength, final_scheduler_name, goals,
+                                                       initial_latent,
+                                                       switch, task, tasks, tiled, use_expansion, width, height)
+
                 except ldm_patched.modules.model_management.InterruptProcessingException:
                     if async_task.last_stop == 'skip':
                         print('User skipped')
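Review note, not part of the patch: the detect → mask step added above, reduced
to a standalone helper for testing (detect_face_mask is a hypothetical name;
it assumes face_yolov8n.pt is present in the model mapping).

    import numpy as np
    from PIL import Image
    from extras.adetailer.args import ADetailerArgs
    from extras.adetailer.script import get_ad_model, pred_preprocessing
    from extras.adetailer.ultralytics_predict import ultralytics_predict

    def detect_face_mask(img):
        """Return a merged uint8 face mask for an ndarray image, or None."""
        pred = ultralytics_predict(get_ad_model("face_yolov8n.pt"),
                                   Image.fromarray(img), device="cpu")
        if pred.preview is None:
            return None
        masks = pred_preprocessing(img, pred, ADetailerArgs())
        if not masks:
            return None
        return np.maximum.reduce([np.array(m) for m in masks])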
diff --git a/modules/config.py b/modules/config.py
index ae00685d..d7bd2d31 100644
--- a/modules/config.py
+++ b/modules/config.py
@@ -191,6 +191,7 @@ def get_dir_or_set_default(key, default_value, as_array=False, make_directory=Fa
 paths_checkpoints = get_dir_or_set_default('path_checkpoints', ['../models/checkpoints/'], True)
 paths_loras = get_dir_or_set_default('path_loras', ['../models/loras/'], True)
 path_embeddings = get_dir_or_set_default('path_embeddings', '../models/embeddings/')
+path_adetailer = get_dir_or_set_default('path_adetailer', '../models/adetailer/')
 path_vae_approx = get_dir_or_set_default('path_vae_approx', '../models/vae_approx/')
 path_vae = get_dir_or_set_default('path_vae', '../models/vae/')
 path_upscale_models = get_dir_or_set_default('path_upscale_models', '../models/upscale_models/')
diff --git a/requirements_versions.txt b/requirements_versions.txt
index ebcd0297..d4e45e49 100644
--- a/requirements_versions.txt
+++ b/requirements_versions.txt
@@ -19,4 +19,5 @@ onnxruntime==1.16.3
 timm==0.9.2
 translators==5.8.9
 rembg==2.0.53
-groundingdino-py==0.4.0
\ No newline at end of file
+groundingdino-py==0.4.0
+ultralytics==8.2.28
\ No newline at end of file
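Review note, not part of the patch: with the modules/config.py change above,
the detector directory can be overridden like any other path key in Fooocus's
config.txt (the path value below is illustrative). Local *.pt files placed
there are picked up by get_models() in addition to the huggingface downloads.

    {
        "path_adetailer": "/data/models/adetailer"
    }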