"""
Overview:
This module provides utilities for image tagging using IdolSankaku taggers.
It includes functions for loading models, processing images, and extracting tags.
The module is inspired by the `SmilingWolf/wd-tagger <https://huggingface.co/spaces/SmilingWolf/wd-tagger>`_
project on Hugging Face.
.. collapse:: Overview of IdolSankaku (NSFW Warning!!!)
.. image:: idolsankaku_demo.plot.py.svg
:align: center
This is an overall benchmark of all the idolsankaku models:
.. image:: idolsankaku_benchmark.plot.py.svg
:align: center
"""
from typing import List, Tuple, Any
import numpy as np
import onnxruntime
import pandas as pd
from PIL import Image
from hbutils.testing.requires.version import VersionInfo
from huggingface_hub import hf_hub_download
from imgutils.data import load_image, ImageTyping
from imgutils.tagging.format import remove_underline
from imgutils.tagging.overlap import drop_overlap_tags
from imgutils.utils import open_onnx_model, vreplace, ts_lru_cache, sigmoid
EXP_REPO = 'deepghs/idolsankaku_tagger_with_embeddings'
EVA02_LARGE_MODEL_DSV3_REPO = "deepghs/idolsankaku-eva02-large-tagger-v1"
SWIN_MODEL_REPO = "deepghs/idolsankaku-swinv2-tagger-v1"
MODEL_FILENAME = "model.onnx"
LABEL_FILENAME = "selected_tags.csv"
_IS_SUPPORT = VersionInfo(onnxruntime.__version__) >= '1.17'
MODEL_NAMES = {
"EVA02_Large": EVA02_LARGE_MODEL_DSV3_REPO,
"SwinV2": SWIN_MODEL_REPO,
}
_DEFAULT_MODEL_NAME = 'SwinV2'
def _version_support_check(model_name):
"""
Check if the current onnxruntime version supports the given model.
:param model_name: The name of the model to check.
:type model_name: str
:raises EnvironmentError: If the model is not supported by the current onnxruntime version.
"""
_ = model_name
if not _IS_SUPPORT:
raise EnvironmentError(f'Idolsankaku taggers not supported on onnxruntime {onnxruntime.__version__}, '
f'please upgrade it to 1.17+ version.\n'
f'If you are running on CPU, use "pip install -U onnxruntime" .\n'
f'If you are running on GPU, use "pip install -U onnxruntime-gpu" .') # pragma: no cover
@ts_lru_cache()
def _get_idolsankaku_model(model_name):
"""
Load an ONNX model from the Hugging Face Hub.
:param model_name: The name of the model to load.
:type model_name: str
:return: The loaded ONNX model.
:rtype: ONNXModel
"""
_version_support_check(model_name)
return open_onnx_model(hf_hub_download(
repo_id=EXP_REPO,
filename=f'{MODEL_NAMES[model_name]}/model.onnx',
))
@ts_lru_cache()
def _get_idolsankaku_labels(model_name, no_underline: bool = False) -> Tuple[
List[str], List[int], List[int], List[int]]:
"""
Get labels for the IdolSankaku model.
:param model_name: The name of the model.
:type model_name: str
:param no_underline: If True, replaces underscores in tag names with spaces.
:type no_underline: bool
:return: A tuple containing the list of tag names, and lists of indexes for rating, general, and character categories.
:rtype: Tuple[List[str], List[int], List[int], List[int]]
"""
df = pd.read_csv(hf_hub_download(
repo_id=EXP_REPO,
filename=f'{MODEL_NAMES[model_name]}/selected_tags.csv',
))
name_series = df["name"]
if no_underline:
name_series = name_series.map(remove_underline)
tag_names = name_series.tolist()
rating_indexes = list(np.where(df["category"] == 9)[0])
general_indexes = list(np.where(df["category"] == 0)[0])
character_indexes = list(np.where(df["category"] == 4)[0])
return tag_names, rating_indexes, general_indexes, character_indexes
@ts_lru_cache()
def _get_idolsankaku_weights(model_name):
"""
Load the weights for a idolsankaku model.
:param model_name: The name of the model.
:type model_name: str
:return: The loaded weights.
:rtype: numpy.ndarray
"""
_version_support_check(model_name)
return np.load(hf_hub_download(
repo_id=EXP_REPO,
filename=f'{MODEL_NAMES[model_name]}/matrix.npz',
))
def _mcut_threshold(probs) -> float:
"""
Compute the Maximum Cut Thresholding (MCut) for multi-label classification.
This method is based on the paper:
Largeron, C., Moulin, C., & Gery, M. (2012). MCut: A Thresholding Strategy
for Multi-label Classification. In 11th International Symposium, IDA 2012
(pp. 172-183).
:param probs: Array of probabilities.
:type probs: numpy.ndarray
:return: The computed threshold.
:rtype: float
"""
sorted_probs = probs[probs.argsort()[::-1]]
difs = sorted_probs[:-1] - sorted_probs[1:]
t = difs.argmax()
thresh = (sorted_probs[t] + sorted_probs[t + 1]) / 2
return thresh
def _prepare_image_for_tagging(image: ImageTyping, target_size: int):
"""
Prepare an image for tagging by resizing and padding it.
:param image: The input image.
:type image: ImageTyping
:param target_size: The target size for the image.
:type target_size: int
:return: The prepared image as a numpy array.
:rtype: numpy.ndarray
"""
image = load_image(image, force_background=None, mode=None)
image_shape = image.size
max_dim = max(image_shape)
pad_left = (max_dim - image_shape[0]) // 2
pad_top = (max_dim - image_shape[1]) // 2
padded_image = Image.new("RGB", (max_dim, max_dim), (255, 255, 255))
try:
padded_image.paste(image, (pad_left, pad_top), mask=image)
except ValueError:
padded_image.paste(image, (pad_left, pad_top))
if max_dim != target_size:
padded_image = padded_image.resize((target_size, target_size), Image.BICUBIC)
image_array = np.asarray(padded_image, dtype=np.float32)
image_array = image_array[:, :, ::-1].transpose((2, 0, 1))
image_array = image_array / 127.5 - 1.0
return np.expand_dims(image_array, axis=0)
def _postprocess_embedding(
pred, embedding, logit,
model_name: str = _DEFAULT_MODEL_NAME,
general_threshold: float = 0.35,
general_mcut_enabled: bool = False,
character_threshold: float = 0.85,
character_mcut_enabled: bool = False,
no_underline: bool = False,
drop_overlap: bool = False,
fmt: Any = ('rating', 'general', 'character'),
):
"""
Post-process the embedding and prediction results.
:param pred: The prediction array.
:type pred: numpy.ndarray
:param embedding: The embedding array.
:type embedding: numpy.ndarray
:param logit: The logit array.
:type logit: numpy.ndarray
:param model_name: The name of the model used.
:type model_name: str
:param general_threshold: Threshold for general tags.
:type general_threshold: float
:param general_mcut_enabled: Whether to use MCut for general tags.
:type general_mcut_enabled: bool
:param character_threshold: Threshold for character tags.
:type character_threshold: float
:param character_mcut_enabled: Whether to use MCut for character tags.
:type character_mcut_enabled: bool
:param no_underline: Whether to remove underscores from tag names.
:type no_underline: bool
:param drop_overlap: Whether to drop overlapping tags.
:type drop_overlap: bool
:param fmt: The format of the output.
:type fmt: Any
:return: The post-processed results.
"""
assert len(pred.shape) == len(embedding.shape) == 1, \
f'Both pred and embeddings shapes should be 1-dim, ' \
f'but pred: {pred.shape!r}, embedding: {embedding.shape!r} actually found.'
tag_names, rating_indexes, general_indexes, character_indexes = _get_idolsankaku_labels(model_name, no_underline)
labels = list(zip(tag_names, pred.astype(float)))
rating = {labels[i][0]: labels[i][1].item() for i in rating_indexes}
general_names = [labels[i] for i in general_indexes]
if general_mcut_enabled:
general_probs = np.array([x[1] for x in general_names])
general_threshold = _mcut_threshold(general_probs)
general_res = {x: v.item() for x, v in general_names if v > general_threshold}
if drop_overlap:
general_res = drop_overlap_tags(general_res)
character_names = [labels[i] for i in character_indexes]
if character_mcut_enabled:
character_probs = np.array([x[1] for x in character_names])
character_threshold = _mcut_threshold(character_probs)
character_threshold = max(0.15, character_threshold)
character_res = {x: v.item() for x, v in character_names if v > character_threshold}
return vreplace(
fmt,
{
'rating': rating,
'general': general_res,
'character': character_res,
'tag': {**general_res, **character_res},
'embedding': embedding.astype(np.float32),
'prediction': pred.astype(np.float32),
'logit': logit.astype(np.float32),
}
)
[docs]def convert_idolsankaku_emb_to_prediction(
emb: np.ndarray,
model_name: str = _DEFAULT_MODEL_NAME,
general_threshold: float = 0.35,
general_mcut_enabled: bool = False,
character_threshold: float = 0.85,
character_mcut_enabled: bool = False,
no_underline: bool = False,
drop_overlap: bool = False,
fmt: Any = ('rating', 'general', 'character'),
):
"""
Convert idolsankaku embedding to understandable prediction result. This function can process both
single embeddings (1-dimensional array) and batches of embeddings (2-dimensional array).
:param emb: The extracted embedding(s). Can be either a 1-dim array for single image or
2-dim array for batch processing
:type emb: numpy.ndarray
:param model_name: Name of the idolsankaku model to use for prediction
:type model_name: str
:param general_threshold: Confidence threshold for general tags (0.0 to 1.0)
:type general_threshold: float
:param general_mcut_enabled: Enable MCut thresholding for general tags to improve prediction quality
:type general_mcut_enabled: bool
:param character_threshold: Confidence threshold for character tags (0.0 to 1.0)
:type character_threshold: float
:param character_mcut_enabled: Enable MCut thresholding for character tags to improve prediction quality
:type character_mcut_enabled: bool
:param no_underline: Replace underscores with spaces in tag names for better readability
:type no_underline: bool
:param drop_overlap: Remove overlapping tags to reduce redundancy
:type drop_overlap: bool
:param fmt: Specify return format structure for predictions, default is ``('rating', 'general', 'character')``.
:type fmt: Any
:return: For single embeddings: prediction result based on fmt. For batches: list of prediction results.
For batch processing (2-dim input), returns a list where each element corresponds
to one embedding's predictions in the same format as single embedding output.
Example:
>>> import os
>>> import numpy as np
>>> from realutils.tagging import get_idolsankaku_tags, convert_idolsankaku_emb_to_prediction
>>>
>>> # extract the feature embedding, shape: (W, )
>>> embedding = get_idolsankaku_tags('skadi.jpg', fmt='embedding')
>>>
>>> # convert to understandable result
>>> rating, general, character = convert_idolsankaku_emb_to_prediction(embedding)
>>> # these 3 dicts will be the same as that returned by `get_idolsankaku_tags('skadi.jpg')`
>>>
>>> # Batch processing, shape: (B, W)
>>> embeddings = np.stack([
... get_idolsankaku_tags('img1.jpg', fmt='embedding'),
... get_idolsankaku_tags('img2.jpg', fmt='embedding'),
... ])
>>> # results will be a list of (rating, general, character) tuples
>>> results = convert_idolsankaku_emb_to_prediction(embeddings)
"""
z_weights = _get_idolsankaku_weights(model_name)
weight, bias = z_weights['weight'], z_weights['bias']
logit = emb @ weight + bias
pred = sigmoid(logit)
if len(emb.shape) == 1:
return _postprocess_embedding(
pred=pred,
embedding=emb,
logit=logit,
model_name=model_name,
general_threshold=general_threshold,
general_mcut_enabled=general_mcut_enabled,
character_threshold=character_threshold,
character_mcut_enabled=character_mcut_enabled,
no_underline=no_underline,
drop_overlap=drop_overlap,
fmt=fmt,
)
else:
return [
_postprocess_embedding(
pred=pred_item,
embedding=emb_item,
logit=logit_item,
model_name=model_name,
general_threshold=general_threshold,
general_mcut_enabled=general_mcut_enabled,
character_threshold=character_threshold,
character_mcut_enabled=character_mcut_enabled,
no_underline=no_underline,
drop_overlap=drop_overlap,
fmt=fmt,
)
for pred_item, emb_item, logit_item in zip(pred, emb, logit)
]