﻿import math
import wave
import json
import subprocess
import torch
from typing import Any, Tuple, Optional


def _to_int(v, default=0):
    try:
        return int(v)
    except Exception:
        return default


def _to_float(v, default=0.0):
    try:
        return float(v)
    except Exception:
        return default


def _format_mmss(sec: float) -> str:
    if sec < 0:
        sec = 0.0
    m = int(sec // 60)
    s = sec - m * 60
    return f"{m:02d}:{s:06.3f}"


def _pick_samples_dim(shape: Tuple[int, ...]) -> int:
    if not shape:
        return 0
    if len(shape) == 1:
        return shape[0]
    if len(shape) >= 3:
        last = shape[-1]
        if last > 8:
            return last
        return max(shape)
    a, b = shape[0], shape[1]
    if a <= 8 and b > 8:
        return b
    if b <= 8 and a > 8:
        return a
    return max(a, b)


def _ffprobe_duration(path: str) -> Optional[float]:
    if not isinstance(path, str):
        return None
    p = path.strip().strip('"').strip("'")
    if not p:
        return None
    try:
        cp = subprocess.run(
            ["ffprobe", "-v", "error", "-show_entries", "format=duration", "-of", "json", p],
            capture_output=True,
            text=True,
            check=False,
        )
        if cp.returncode != 0:
            return None
        data = json.loads(cp.stdout or "{}")
        dur = data.get("format", {}).get("duration", None)
        if dur is None:
            return None
        dur_f = float(dur)
        if dur_f < 0:
            return None
        return dur_f
    except Exception:
        return None


def _wav_duration(path: str) -> Optional[float]:
    if not isinstance(path, str):
        return None
    p = path.strip().strip('"').strip("'")
    if not p:
        return None
    if not p.lower().endswith(".wav"):
        return None
    try:
        with wave.open(p, "rb") as w:
            n = w.getnframes()
            r = w.getframerate()
            if r <= 0:
                return None
            return float(n) / float(r)
    except Exception:
        return None


def _duration_from_audio(audio: Any) -> Optional[float]:
    """Work out an audio duration in seconds from a dict or a file path.

    Supported inputs:

    * ``dict`` — the sample rate is read from ``"sample_rate"`` (falling back
      to ``"sr"``, then ``"rate"``) and the waveform from ``"waveform"``
      (falling back to ``"samples"``, then ``"data"``). The sample count is
      taken from the waveform's ``shape`` attribute via
      ``_pick_samples_dim`` when it has one, or from list/tuple lengths
      otherwise (for a nested sequence, the longest inner sequence).
    * ``str`` — treated as a file path; ``_ffprobe_duration`` is tried
      first, then ``_wav_duration``.

    Args:
        audio: The audio object or path.

    Returns:
        Duration in seconds; ``0.0`` for an empty waveform; ``None`` when the
        input is ``None``, of an unsupported type, lacks a usable sample
        rate or waveform, or the path cannot be probed.
    """
    if isinstance(audio, str):
        probed = _ffprobe_duration(audio)
        return probed if probed is not None else _wav_duration(audio)

    if not isinstance(audio, dict):
        # Covers None as well as every unsupported container type.
        return None

    rate = audio.get("sample_rate", None)
    wave_data = audio.get("waveform", None)
    if rate is None:
        rate = audio.get("sr", audio.get("rate", None))
    if wave_data is None:
        wave_data = audio.get("samples", audio.get("data", None))
    if rate is None or wave_data is None:
        return None

    rate_hz = _to_int(rate, 0)
    if rate_hz <= 0:
        return None

    if hasattr(wave_data, "shape"):
        # Array/tensor-like: let the shape heuristic choose the sample axis.
        count = _pick_samples_dim(tuple(getattr(wave_data, "shape")))
    elif isinstance(wave_data, (list, tuple)):
        if len(wave_data) == 0:
            return 0.0
        if isinstance(wave_data[0], (list, tuple)):
            # Nested sequence: use the longest inner sequence.
            inner_lengths = (
                len(item) for item in wave_data if isinstance(item, (list, tuple))
            )
            count = max(inner_lengths, default=0)
        else:
            count = len(wave_data)
    else:
        return None

    return float(count) / float(rate_hz) if count > 0 else 0.0


def _align_mul(frames: int, align: int) -> int:
    if align < 1:
        align = 1
    if frames < 1:
        frames = 1
    return int(math.ceil(frames / align) * align)


def _align_plus1(frames: int, align: int) -> int:
    if align < 1:
        align = 1
    if frames < 1:
        frames = 1
    return int(math.ceil((frames - 1) / align) * align + 1)


class 冻结首帧并拼接:
    """Drop leading frames from an image batch and prepend a frozen first frame.

    The class follows the ``INPUT_TYPES`` / ``RETURN_TYPES`` / ``FUNCTION``
    node layout (NOTE(review): presumably loaded by a ComfyUI-style host —
    confirm). ``run`` treats the input as a 4-D tensor whose first dimension
    indexes frames.
    """

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("图像批次",)
    FUNCTION = "run"
    CATEGORY = "video_batch"

    @classmethod
    def INPUT_TYPES(cls):
        """Describe the node inputs: the image batch, freeze count N, drop count M."""
        required = {
            "原图像批次": ("IMAGE",),
            "冻结帧数N": ("INT", {"default": 10, "min": 1, "max": 240, "step": 1}),
            "丢弃开头帧M": ("INT", {"default": 2, "min": 0, "max": 240, "step": 1}),
        }
        return {"required": required}

    def run(self, 原图像批次, 冻结帧数N=10, 丢弃开头帧M=2):
        """Build the output batch.

        The first M frames are discarded, then the first remaining frame is
        repeated N times and placed in front of the remaining frames. If M
        discards every frame, the result is just the original first frame
        repeated N times.

        Args:
            原图像批次: 4-D ``torch.Tensor``; dimension 0 is the frame index.
            冻结帧数N: Number of frozen copies to prepend (clamped to >= 1).
            丢弃开头帧M: Number of leading frames to drop (clamped to >= 0).

        Returns:
            A one-element tuple holding the resulting tensor.

        Raises:
            ValueError: If the input is not a tensor, is not 4-D, or has no
                frames.
        """
        frames = 原图像批次
        if not isinstance(frames, torch.Tensor):
            raise ValueError("原图像批次必须是 IMAGE 批次")
        if frames.dim() != 4:
            raise ValueError("原图像批次维度必须是4维")

        repeat_count = max(int(冻结帧数N), 1)
        skip = max(int(丢弃开头帧M), 0)

        frame_total = int(frames.shape[0])
        if frame_total < 1:
            raise ValueError("原图像批次为空")

        if skip >= frame_total:
            # Nothing survives the drop: fall back to the original first frame.
            return (frames[0:1].repeat(repeat_count, 1, 1, 1),)

        kept = frames[skip:]
        frozen = kept[0:1].repeat(repeat_count, 1, 1, 1)
        return (torch.cat([frozen, kept], dim=0),)


class 音频时长转帧数:
    """Convert an audio duration into video frame counts with alignment.

    The class follows the ``INPUT_TYPES`` / ``RETURN_TYPES`` / ``FUNCTION``
    node layout (NOTE(review): presumably loaded by a ComfyUI-style host —
    confirm).
    """

    @classmethod
    def INPUT_TYPES(cls):
        """Describe the node inputs.

        Required: the audio, the FPS, and the alignment mode ("4+1" or
        "8+1"). Optional: lead and tail padding in seconds and the rounding
        mode ("ceil", "round" or "floor").
        """
        return {
            "required": {
                "音频": ("AUDIO",),
                "FPS": ("INT", {"default": 24, "min": 1, "max": 240, "step": 1}),
                "对齐模式": (["4+1", "8+1"], {"default": "4+1"}),
            },
            "optional": {
                "开头缓冲秒": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 10.0, "step": 0.01}),
                "结尾缓冲秒": ("FLOAT", {"default": 0.20, "min": 0.0, "max": 10.0, "step": 0.01}),
                "取整方式": (["ceil", "round", "floor"], {"default": "ceil"}),
            },
        }

    RETURN_TYPES = ("FLOAT", "INT", "INT", "INT", "STRING", "STRING")
    RETURN_NAMES = ("时长秒", "时长毫秒", "帧数_倍数对齐", "帧数_4+1或8+1", "时长显示", "调试信息")
    FUNCTION = "calc"
    CATEGORY = "audio_duration"

    def calc(self, 音频, FPS, 对齐模式, 开头缓冲秒=0.0, 结尾缓冲秒=0.20, 取整方式="ceil"):
        """Compute the padded duration and the aligned frame counts.

        Args:
            音频: Audio object or path understood by ``_duration_from_audio``.
            FPS: Frames per second; values that cannot be converted or are
                below 1 fall back to 24.
            对齐模式: "4+1" selects an alignment step of 4; any other value
                selects 8.
            开头缓冲秒: Seconds added before the audio (negative -> 0).
            结尾缓冲秒: Seconds added after the audio (negative -> 0).
            取整方式: "floor", "round", or anything else for ceil; applied to
                ``seconds * fps`` to obtain the base frame count.

        Returns:
            A tuple ``(seconds, milliseconds, frames aligned to a multiple of
            the step, frames aligned to step*k + 1, "MM:SS.mmm" display
            string, debug string)``.

        Raises:
            ValueError: If no duration can be determined from the audio.
        """
        dur = _duration_from_audio(音频)
        if dur is None:
            raise ValueError("无法读取音频时长: 需要标准 AUDIO; 若传路径需要 ffprobe 或 wav")

        fps = _to_int(FPS, 24)
        if fps < 1:
            fps = 24

        lead = _to_float(开头缓冲秒, 0.0)
        tail = _to_float(结尾缓冲秒, 0.0)
        if lead < 0:
            lead = 0.0
        if tail < 0:
            tail = 0.0

        align = 4 if 对齐模式 == "4+1" else 8

        sec = float(dur) + lead + tail
        # The product carries binary floating-point noise: a 0.1 s clip with a
        # 0.2 s tail at 10 fps yields 3.0000000000000004, which ceil would turn
        # into 4 frames (and the mirror-image case makes floor lose a frame).
        # Snap to 6 decimals of a frame so only genuine fractions affect the
        # rounding below.
        raw = round(sec * float(fps), 6)

        if 取整方式 == "floor":
            base = int(math.floor(raw))
        elif 取整方式 == "round":
            base = int(round(raw))
        else:
            base = int(math.ceil(raw))

        # Always emit at least one frame.
        if base < 1:
            base = 1

        frames_mul = _align_mul(base, align)
        frames_plus1 = _align_plus1(base, align)

        ms = int(round(sec * 1000.0))
        show = _format_mmss(sec)
        debug = f"fps={fps}, align={align}, base={base}, mul={frames_mul}, plus1={frames_plus1}"

        return (sec, ms, frames_mul, frames_plus1, show, debug)


# Node identifier -> implementing class.
# NOTE(review): presumably read by the host application's node loader
# (ComfyUI-style convention) — confirm.
NODE_CLASS_MAPPINGS = dict(
    FreezeHead=冻结首帧并拼接,
    AudioDurationToFrames=音频时长转帧数,
)

# Node identifier -> human-readable display name.
# NOTE(review): presumably shown in the host application's UI — confirm.
NODE_DISPLAY_NAME_MAPPINGS = dict(
    FreezeHead="Freeze Head Frames",
    AudioDurationToFrames="Audio Duration To Frames",
)
