"""
Overview:
    This module provides utilities for image tagging using IdolSankaku taggers.
    It includes functions for loading models, processing images, and extracting tags.
    The module is inspired by the `SmilingWolf/wd-tagger <https://huggingface.co/spaces/SmilingWolf/wd-tagger>`_
    project on Hugging Face.
    .. collapse:: Overview of IdolSankaku (NSFW Warning!!!)
        .. image:: idolsankaku_demo.plot.py.svg
            :align: center
    This is an overall benchmark of all the idolsankaku models:
    .. image:: idolsankaku_benchmark.plot.py.svg
        :align: center
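
    Example:
        A minimal usage sketch (``get_idolsankaku_tags`` is the tagging entry used in the
        docstring examples below; ``skadi.jpg`` stands for any local image file):

        >>> from realutils.tagging import get_idolsankaku_tags
        >>> rating, general, character = get_idolsankaku_tags('skadi.jpg')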
"""
from typing import List, Tuple, Any
import numpy as np
import onnxruntime
import pandas as pd
from PIL import Image
from hbutils.testing.requires.version import VersionInfo
from huggingface_hub import hf_hub_download
from imgutils.data import load_image, ImageTyping
from imgutils.tagging.format import remove_underline
from imgutils.tagging.overlap import drop_overlap_tags
from imgutils.utils import open_onnx_model, vreplace, ts_lru_cache, sigmoid
EXP_REPO = 'deepghs/idolsankaku_tagger_with_embeddings'
EVA02_LARGE_MODEL_DSV3_REPO = "deepghs/idolsankaku-eva02-large-tagger-v1"
SWIN_MODEL_REPO = "deepghs/idolsankaku-swinv2-tagger-v1"
MODEL_FILENAME = "model.onnx"
LABEL_FILENAME = "selected_tags.csv"
_IS_SUPPORT = VersionInfo(onnxruntime.__version__) >= '1.17'
MODEL_NAMES = {
    "EVA02_Large": EVA02_LARGE_MODEL_DSV3_REPO,
    "SwinV2": SWIN_MODEL_REPO,
}
_DEFAULT_MODEL_NAME = 'SwinV2'
def _version_support_check(model_name):
    """
    Check if the current onnxruntime version supports the given model.
    :param model_name: The name of the model to check.
    :type model_name: str
    :raises EnvironmentError: If the model is not supported by the current onnxruntime version.
    """
    _ = model_name
    if not _IS_SUPPORT:
        raise EnvironmentError(f'Idolsankaku taggers are not supported on onnxruntime {onnxruntime.__version__}, '
                               f'please upgrade it to version 1.17 or later.\n'
                               f'If you are running on CPU, use "pip install -U onnxruntime" .\n'
                               f'If you are running on GPU, use "pip install -U onnxruntime-gpu" .')  # pragma: no cover
@ts_lru_cache()
def _get_idolsankaku_model(model_name):
    """
    Load an ONNX model from the Hugging Face Hub.
    :param model_name: The name of the model to load.
    :type model_name: str
    :return: The loaded ONNX model.
    :rtype: ONNXModel
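
    Example:
        A minimal sketch; the ONNX file is downloaded from the Hugging Face Hub on the
        first call, and the opened model is cached afterwards:

        >>> model = _get_idolsankaku_model('SwinV2')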
    """
    _version_support_check(model_name)
    return open_onnx_model(hf_hub_download(
        repo_id=EXP_REPO,
        filename=f'{MODEL_NAMES[model_name]}/model.onnx',
    ))
@ts_lru_cache()
def _get_idolsankaku_labels(model_name, no_underline: bool = False) -> Tuple[
    List[str], List[int], List[int], List[int]]:
    """
    Get labels for the IdolSankaku model.
    :param model_name: The name of the model.
    :type model_name: str
    :param no_underline: If True, replaces underscores in tag names with spaces.
    :type no_underline: bool
    :return: A tuple containing the list of tag names, and lists of indexes for rating, general, and character categories.
    :rtype: Tuple[List[str], List[int], List[int], List[int]]
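
    Example:
        A minimal sketch; label metadata is downloaded from the Hugging Face Hub on the
        first call and cached afterwards:

        >>> tag_names, rating_idx, general_idx, character_idx = _get_idolsankaku_labels('SwinV2')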
    """
    df = pd.read_csv(hf_hub_download(
        repo_id=EXP_REPO,
        filename=f'{MODEL_NAMES[model_name]}/selected_tags.csv',
    ))
    name_series = df["name"]
    if no_underline:
        name_series = name_series.map(remove_underline)
    tag_names = name_series.tolist()
    rating_indexes = list(np.where(df["category"] == 9)[0])
    general_indexes = list(np.where(df["category"] == 0)[0])
    character_indexes = list(np.where(df["category"] == 4)[0])
    return tag_names, rating_indexes, general_indexes, character_indexes
@ts_lru_cache()
def _get_idolsankaku_weights(model_name):
    """
    Load the weights for an idolsankaku model.
    :param model_name: The name of the model.
    :type model_name: str
    :return: The loaded weights.
    :rtype: numpy.ndarray
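
    Example:
        A minimal sketch; the ``.npz`` file holds the ``weight`` and ``bias`` arrays of the
        linear classification head, as consumed by :func:`convert_idolsankaku_emb_to_prediction`:

        >>> z_weights = _get_idolsankaku_weights('SwinV2')
        >>> weight, bias = z_weights['weight'], z_weights['bias']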
    """
    _version_support_check(model_name)
    return np.load(hf_hub_download(
        repo_id=EXP_REPO,
        filename=f'{MODEL_NAMES[model_name]}/matrix.npz',
    ))
def _mcut_threshold(probs) -> float:
    """
    Compute the Maximum Cut Thresholding (MCut) for multi-label classification.
    This method is based on the paper:
    Largeron, C., Moulin, C., & Gery, M. (2012). MCut: A Thresholding Strategy
    for Multi-label Classification. In 11th International Symposium, IDA 2012
    (pp. 172-183).
    :param probs: Array of probabilities.
    :type probs: numpy.ndarray
    :return: The computed threshold.
    :rtype: float
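
    Example:
        An illustrative sketch; the threshold is the midpoint of the largest gap between
        consecutive sorted probabilities (here between 0.75 and 0.25):

        >>> import numpy as np
        >>> float(_mcut_threshold(np.array([0.95, 0.75, 0.25, 0.05])))
        0.5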
    """
    sorted_probs = probs[probs.argsort()[::-1]]
    difs = sorted_probs[:-1] - sorted_probs[1:]
    t = difs.argmax()
    thresh = (sorted_probs[t] + sorted_probs[t + 1]) / 2
    return thresh
def _prepare_image_for_tagging(image: ImageTyping, target_size: int):
    """
    Prepare an image for tagging by resizing and padding it.
    :param image: The input image.
    :type image: ImageTyping
    :param target_size: The target size for the image.
    :type target_size: int
    :return: The prepared image as a numpy array.
    :rtype: numpy.ndarray
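
    Example:
        A minimal sketch using a synthetic PIL image; ``448`` is an illustrative target
        size, not necessarily the input size expected by every model:

        >>> from PIL import Image
        >>> batch = _prepare_image_for_tagging(Image.new('RGB', (300, 200)), 448)
        >>> batch.shape
        (1, 3, 448, 448)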
    """
    image = load_image(image, force_background=None, mode=None)
    image_shape = image.size
    max_dim = max(image_shape)
    pad_left = (max_dim - image_shape[0]) // 2
    pad_top = (max_dim - image_shape[1]) // 2
    padded_image = Image.new("RGB", (max_dim, max_dim), (255, 255, 255))
    try:
        # use the image itself as the paste mask so any alpha channel is respected
        padded_image.paste(image, (pad_left, pad_top), mask=image)
    except ValueError:
        # images without an alpha channel cannot serve as a mask, so paste directly
        padded_image.paste(image, (pad_left, pad_top))
    if max_dim != target_size:
        padded_image = padded_image.resize((target_size, target_size), Image.BICUBIC)
    image_array = np.asarray(padded_image, dtype=np.float32)
    # RGB -> BGR, HWC -> CHW, then scale pixel values from [0, 255] to [-1, 1]
    image_array = image_array[:, :, ::-1].transpose((2, 0, 1))
    image_array = image_array / 127.5 - 1.0
    return np.expand_dims(image_array, axis=0)
def _postprocess_embedding(
        pred, embedding, logit,
        model_name: str = _DEFAULT_MODEL_NAME,
        general_threshold: float = 0.35,
        general_mcut_enabled: bool = False,
        character_threshold: float = 0.85,
        character_mcut_enabled: bool = False,
        no_underline: bool = False,
        drop_overlap: bool = False,
        fmt: Any = ('rating', 'general', 'character'),
):
    """
    Post-process the embedding and prediction results.
    :param pred: The prediction array.
    :type pred: numpy.ndarray
    :param embedding: The embedding array.
    :type embedding: numpy.ndarray
    :param logit: The logit array.
    :type logit: numpy.ndarray
    :param model_name: The name of the model used.
    :type model_name: str
    :param general_threshold: Threshold for general tags.
    :type general_threshold: float
    :param general_mcut_enabled: Whether to use MCut for general tags.
    :type general_mcut_enabled: bool
    :param character_threshold: Threshold for character tags.
    :type character_threshold: float
    :param character_mcut_enabled: Whether to use MCut for character tags.
    :type character_mcut_enabled: bool
    :param no_underline: Whether to remove underscores from tag names.
    :type no_underline: bool
    :param drop_overlap: Whether to drop overlapping tags.
    :type drop_overlap: bool
    :param fmt: The format of the output. Any occurrence of the keys ``rating``, ``general``,
        ``character``, ``tag``, ``embedding``, ``prediction`` or ``logit`` in ``fmt`` is
        replaced with the corresponding result.
    :type fmt: Any
    :return: The post-processed results.
    """
    assert len(pred.shape) == len(embedding.shape) == 1, \
        f'Both pred and embedding should be 1-dimensional, ' \
        f'but pred: {pred.shape!r}, embedding: {embedding.shape!r} actually found.'
    tag_names, rating_indexes, general_indexes, character_indexes = _get_idolsankaku_labels(model_name, no_underline)
    labels = list(zip(tag_names, pred.astype(float)))
    rating = {labels[i][0]: labels[i][1].item() for i in rating_indexes}
    general_names = [labels[i] for i in general_indexes]
    if general_mcut_enabled:
        general_probs = np.array([x[1] for x in general_names])
        general_threshold = _mcut_threshold(general_probs)
    general_res = {x: v.item() for x, v in general_names if v > general_threshold}
    if drop_overlap:
        general_res = drop_overlap_tags(general_res)
    character_names = [labels[i] for i in character_indexes]
    if character_mcut_enabled:
        character_probs = np.array([x[1] for x in character_names])
        character_threshold = _mcut_threshold(character_probs)
        character_threshold = max(0.15, character_threshold)
    character_res = {x: v.item() for x, v in character_names if v > character_threshold}
    return vreplace(
        fmt,
        {
            'rating': rating,
            'general': general_res,
            'character': character_res,
            'tag': {**general_res, **character_res},
            'embedding': embedding.astype(np.float32),
            'prediction': pred.astype(np.float32),
            'logit': logit.astype(np.float32),
        }
    )
def convert_idolsankaku_emb_to_prediction(
        emb: np.ndarray,
        model_name: str = _DEFAULT_MODEL_NAME,
        general_threshold: float = 0.35,
        general_mcut_enabled: bool = False,
        character_threshold: float = 0.85,
        character_mcut_enabled: bool = False,
        no_underline: bool = False,
        drop_overlap: bool = False,
        fmt: Any = ('rating', 'general', 'character'),
):
    """
    Convert an idolsankaku embedding into a human-readable prediction result. This function can process
    both single embeddings (1-dimensional arrays) and batches of embeddings (2-dimensional arrays).
    :param emb: The extracted embedding(s). Can be either a 1-dim array for single image or
                2-dim array for batch processing
    :type emb: numpy.ndarray
    :param model_name: Name of the idolsankaku model to use for prediction
    :type model_name: str
    :param general_threshold: Confidence threshold for general tags (0.0 to 1.0)
    :type general_threshold: float
    :param general_mcut_enabled: Enable MCut thresholding for general tags to improve prediction quality
    :type general_mcut_enabled: bool
    :param character_threshold: Confidence threshold for character tags (0.0 to 1.0)
    :type character_threshold: float
    :param character_mcut_enabled: Enable MCut thresholding for character tags to improve prediction quality
    :type character_mcut_enabled: bool
    :param no_underline: Replace underscores with spaces in tag names for better readability
    :type no_underline: bool
    :param drop_overlap: Remove overlapping tags to reduce redundancy
    :type drop_overlap: bool
    :param fmt: Specify the return format structure for predictions, default is ``('rating', 'general', 'character')``.
        The keys ``rating``, ``general``, ``character``, ``tag``, ``embedding``, ``prediction`` and ``logit`` are supported.
    :type fmt: Any
    :return: For a single embedding (1-dim input), the prediction result structured according to ``fmt``.
        For batch input (2-dim), a list where each element holds one embedding's predictions in the
        same format as the single-embedding output.
    Example:
        >>> import os
        >>> import numpy as np
        >>> from realutils.tagging import get_idolsankaku_tags, convert_idolsankaku_emb_to_prediction
        >>>
        >>> # extract the feature embedding, shape: (W, )
        >>> embedding = get_idolsankaku_tags('skadi.jpg', fmt='embedding')
        >>>
        >>> # convert to understandable result
        >>> rating, general, character = convert_idolsankaku_emb_to_prediction(embedding)
        >>> # these 3 dicts will be the same as that returned by `get_idolsankaku_tags('skadi.jpg')`
        >>>
        >>> # Batch processing, shape: (B, W)
        >>> embeddings = np.stack([
        ...     get_idolsankaku_tags('img1.jpg', fmt='embedding'),
        ...     get_idolsankaku_tags('img2.jpg', fmt='embedding'),
        ... ])
        >>> # results will be a list of (rating, general, character) tuples
        >>> results = convert_idolsankaku_emb_to_prediction(embeddings)
    """
    z_weights = _get_idolsankaku_weights(model_name)
    weight, bias = z_weights['weight'], z_weights['bias']
    logit = emb @ weight + bias  # re-apply the linear classification head to the embedding(s)
    pred = sigmoid(logit)  # convert logits to per-tag probabilities
    if len(emb.shape) == 1:
        return _postprocess_embedding(
            pred=pred,
            embedding=emb,
            logit=logit,
            model_name=model_name,
            general_threshold=general_threshold,
            general_mcut_enabled=general_mcut_enabled,
            character_threshold=character_threshold,
            character_mcut_enabled=character_mcut_enabled,
            no_underline=no_underline,
            drop_overlap=drop_overlap,
            fmt=fmt,
        )
    else:
        return [
            _postprocess_embedding(
                pred=pred_item,
                embedding=emb_item,
                logit=logit_item,
                model_name=model_name,
                general_threshold=general_threshold,
                general_mcut_enabled=general_mcut_enabled,
                character_threshold=character_threshold,
                character_mcut_enabled=character_mcut_enabled,
                no_underline=no_underline,
                drop_overlap=drop_overlap,
                fmt=fmt,
            )
            for pred_item, emb_item, logit_item in zip(pred, emb, logit)
        ]