﻿import math
import re
import torch
import torch.nn.functional as F

def _safe_audio(audio):
    if audio is None:
        return None
    if not isinstance(audio, dict):
        return audio
    wf = audio.get("waveform", None)
    if isinstance(wf, torch.Tensor):
        if wf.dim() == 1:
            wf = wf.unsqueeze(0)
        elif wf.dim() == 2:
            if wf.shape[0] > wf.shape[1] and wf.shape[1] <= 8:
                wf = wf.transpose(0, 1).contiguous()
        audio = dict(audio)
        audio["waveform"] = wf
    return audio

def _ema_1d(x: torch.Tensor, alpha: float):
    alpha = float(max(0.0, min(0.99, alpha)))
    if alpha <= 0:
        return x
    y = x.clone()
    for i in range(1, x.numel()):
        y[i] = alpha * y[i - 1] + (1.0 - alpha) * x[i]
    return y

def _smooth_params(dx, dy, zoom, rot, level: str):
    lv = (level or "HARD").upper()
    if lv == "OFF":
        return dx, dy, zoom, rot
    alpha = 0.25 if lv == "HARD" else 0.12
    dx = _ema_1d(dx, alpha)
    dy = _ema_1d(dy, alpha)
    zoom = _ema_1d(zoom, alpha)
    rot = _ema_1d(rot, alpha)
    return dx, dy, zoom, rot

def _blur_bg(x_bchw, strength=0.30):
    strength = float(max(0.0, min(1.0, strength)))
    if strength <= 0:
        return x_bchw
    k = int(7 + strength * 31)
    if k % 2 == 0:
        k += 1
    pad = k // 2
    x_bchw = F.avg_pool2d(x_bchw, kernel_size=k, stride=1, padding=pad)
    x_bchw = F.avg_pool2d(x_bchw, kernel_size=k, stride=1, padding=pad)
    return x_bchw

def _expand_canvas(frames_bhwc, scale, mode, blur_strength):
    mode = str(mode or "镜像扩边")
    if mode == "关闭" or scale <= 1.0001:
        return frames_bhwc, (0, 0, frames_bhwc.shape[1], frames_bhwc.shape[2])

    b, h, w, c = frames_bhwc.shape
    nh = int(round(h * scale))
    nw = int(round(w * scale))

    pad_y = max(0, nh - h)
    pad_x = max(0, nw - w)
    pt = pad_y // 2
    pb = pad_y - pt
    pl = pad_x // 2
    pr = pad_x - pl

    x = frames_bhwc.permute(0, 3, 1, 2).contiguous()

    if mode == "黑边扩边":
        x = F.pad(x, (pl, pr, pt, pb), mode="constant", value=0.0)
    elif mode == "镜像扩边":
        x = F.pad(x, (pl, pr, pt, pb), mode="reflect")
    elif mode == "模糊背景扩边":
        bg = F.interpolate(x, size=(nh, nw), mode="bilinear", align_corners=False)
        bg = _blur_bg(bg, strength=blur_strength)

        fg = torch.zeros_like(bg)
        fg[:, :, pt:pt + h, pl:pl + w] = x

        feather = max(8, int(min(h, w) * (0.03 + 0.10 * float(blur_strength))))
        feather = min(feather, max(8, min(h, w) // 2 - 1))

        yy = torch.arange(nh, device=bg.device, dtype=bg.dtype)
        xx = torch.arange(nw, device=bg.device, dtype=bg.dtype)

        def _ramp(v, start, end, feather):
            t = torch.zeros_like(v)
            t = torch.where((v >= start) & (v < end), torch.ones_like(v), t)
            if feather > 0:
                pi = 3.141592653589793
                l0, l1 = start - feather, start
                lt = ((v - l0) / max(1.0, float(feather))).clamp(0, 1)
                lt = 0.5 - 0.5 * torch.cos(lt * pi)
                t = torch.where((v >= l0) & (v < l1), lt, t)
                r0, r1 = end, end + feather
                rt = (1 - (v - r0) / max(1.0, float(feather))).clamp(0, 1)
                rt = 0.5 - 0.5 * torch.cos(rt * pi)
                t = torch.where((v >= r0) & (v < r1), rt, t)
            return t

        my = _ramp(yy, pt, pt + h, feather).view(1, 1, nh, 1)
        mx = _ramp(xx, pl, pl + w, feather).view(1, 1, 1, nw)
        mask = (my * mx).clamp(0, 1)

        x = bg * (1 - mask) + fg * mask
    else:
        return frames_bhwc, (0, 0, h, w)

    out = x.permute(0, 2, 3, 1).contiguous()
    crop = (pt, pl, h, w)
    return out, crop

def _crop_back(frames_bhwc, crop):
    top, left, h, w = crop
    return frames_bhwc[:, top:top + h, left:left + w, :]

def _apply_affine(frames_bhwc, dx_px, dy_px, zoom, rot_rad):
    b, h, w, c = frames_bhwc.shape
    device = frames_bhwc.device
    dtype = frames_bhwc.dtype

    tx = (2.0 * dx_px / max(1.0, w)).to(device=device, dtype=dtype)
    ty = (2.0 * dy_px / max(1.0, h)).to(device=device, dtype=dtype)

    zoom = zoom.to(device=device, dtype=dtype)
    rot = rot_rad.to(device=device, dtype=dtype)

    cos = torch.cos(rot)
    sin = torch.sin(rot)

    a = cos / zoom
    b_ = -sin / zoom
    c_ = sin / zoom
    d = cos / zoom

    theta = torch.zeros((b, 2, 3), device=device, dtype=dtype)
    theta[:, 0, 0] = a
    theta[:, 0, 1] = b_
    theta[:, 0, 2] = tx
    theta[:, 1, 0] = c_
    theta[:, 1, 1] = d
    theta[:, 1, 2] = ty

    x = frames_bhwc.permute(0, 3, 1, 2).contiguous()
    grid = F.affine_grid(theta, size=x.size(), align_corners=False)
    y = F.grid_sample(x, grid, mode="bilinear", padding_mode="zeros", align_corners=False)
    return y.permute(0, 2, 3, 1).contiguous()

def _parse_segments(script: str):
    segs = []
    if not script:
        return segs
    lines = script.splitlines()
    for ln in lines:
        s = ln.strip()
        if not s:
            continue
        s = re.sub(r"#.*$", "", s).strip()
        if not s:
            continue

        m = re.match(r"^\s*([0-9]+(?:\.[0-9]+)?)\s*(?:s|秒)?\s*(.*)$", s, re.I)
        if not m:
            continue
        dur = float(m.group(1))
        rest = (m.group(2) or "").strip()
        if dur <= 0:
            continue

        parts = rest.split()
        if not parts:
            continue
        move = parts[0].strip()

        strength = None
        speed = None
        for p in parts[1:]:
            if "强度" in p or "strength" in p.lower():
                val = p.split("=")[-1]
                try:
                    strength = float(val)
                except:
                    pass
            if "速度" in p or "speed" in p.lower():
                val = p.split("=")[-1]
                try:
                    speed = float(val)
                except:
                    pass

        segs.append({"dur": dur, "move": move, "strength": strength, "speed": speed})
    return segs

def _smoothstep(t):
    return t * t * (3 - 2 * t)

def _norm_move_name(x: str):
    s = str(x or "").strip().lower()
    if not s:
        return ""
    s = s.replace("（", "(").replace("）", ")")
    s = re.sub(r"[\s\-\_\/\\\+\|\:\;\,\.\!\?\[\]\{\}\<\>\(\)→←=]+", "", s)
    s = s.replace("l", "l").replace("r", "r")
    return s

# Canonical move names, in the order offered by the node's "运镜类型" (move type)
# dropdown. Every alias below resolves to one of these.
MOVE_PRESETS = [
    # single-axis pans / tilts
    "横摇左到右",
    "横摇右到左",
    "纵摇上到下",
    "纵摇下到上",
    # zoom
    "推近",
    "拉远",
    # diagonal moves
    "斜移左上",
    "斜移右上",
    "斜移左下",
    "斜移右下",
    # orbit
    "环绕一圈",
    "环绕推近",
    # combined pan/tilt + zoom
    "横摇推近",
    "横摇拉远",
    "推近上移",
    "推近下移",
    "呼吸变焦",
    "多利变焦",
    "摇臂上升推近",
    "摇臂下降拉远",
    # zoom there-and-back sequences
    "推近拉远",
    "拉远推近",
    "推近拉远推近",
    "拉远推近拉远",
    # handheld shake levels
    "手持轻",
    "手持中",
    "手持强",
]

# Alternate spellings (older labels, English hints, extra separators) mapped to
# canonical names. Keys are stored pre-normalized with _norm_move_name so that
# lookups ignore case and separators; repeated normalized keys keep the last value.
MOVE_ALIASES_NORM = {
    _norm_move_name(alias): canonical
    for alias, canonical in (
        ("横向摇镜_LR", "横摇左到右"),
        ("横向摇镜_RL", "横摇右到左"),
        ("纵向摇镜_UD", "纵摇上到下"),
        ("纵向摇镜_DU", "纵摇下到上"),
        ("斜向推移_左上", "斜移左上"),
        ("斜向推移_右上", "斜移右上"),
        ("斜向推移_左下", "斜移左下"),
        ("斜向推移_右下", "斜移右下"),
        ("环绕_一圈回到原位", "环绕一圈"),
        ("环绕推近_展示", "环绕推近"),
        ("组合_横摇推近", "横摇推近"),
        ("组合_横摇拉远", "横摇拉远"),
        ("组合_推近上移", "推近上移"),
        ("组合_推近下移", "推近下移"),
        ("电影呼吸_轻微变焦", "呼吸变焦"),
        ("多利变焦_假效果", "多利变焦"),
        ("摇臂_上升推近", "摇臂上升推近"),
        ("摇臂_下降拉远", "摇臂下降拉远"),
        ("横向摇镜", "横摇左到右"),
        ("横向摇镜l→r", "横摇左到右"),
        ("横向摇镜r→l", "横摇右到左"),
        ("纵向摇镜", "纵摇上到下"),
        ("推近zoomin", "推近"),
        ("拉远zoomout", "拉远"),
        ("环绕", "环绕一圈"),
        ("环绕主体", "环绕一圈"),
        ("组合横摇推近", "横摇推近"),
        ("组合横摇拉远", "横摇拉远"),
        ("推近+拉远", "推近拉远"),
        ("拉远+推近", "拉远推近"),
        ("推近拉远再推近", "推近拉远推近"),
        ("拉远推近再拉远", "拉远推近拉远"),
        ("手持抖动轻", "手持轻"),
        ("手持抖动中", "手持中"),
        ("手持抖动强", "手持强"),
    )
}

def _to_canonical_move(move_name: str):
    """Resolve a user-supplied move name to its canonical MOVE_PRESETS entry.

    Resolution order:
    1. the alias table (MOVE_ALIASES_NORM), matched on the normalized name;
    2. the presets themselves, matched on the normalized name;
    3. otherwise the stripped input text is returned unchanged.

    ``None`` or blank input gives "".
    """
    text = str(move_name or "").strip()
    if not text:
        return ""

    key = _norm_move_name(text)
    aliased = MOVE_ALIASES_NORM.get(key)
    if aliased is not None:
        return aliased
    return next((preset for preset in MOVE_PRESETS if _norm_move_name(preset) == key), text)

# Per-move target gains: (dx in units of the pan amount, dy in units of the
# tilt amount, zoom change per unit of strength). Moves not listed here fall
# back to a plain left-to-right pan.
_MOVE_TARGET_GAINS = {
    "横摇左到右": (1.0, 0.0, 0.0),
    "横摇右到左": (-1.0, 0.0, 0.0),
    "纵摇上到下": (0.0, 1.0, 0.0),
    "纵摇下到上": (0.0, -1.0, 0.0),
    "推近": (0.0, 0.0, 0.42),
    "拉远": (0.0, 0.0, -0.28),
    "斜移左上": (-1.0, -1.0, 0.0),
    "斜移右上": (1.0, -1.0, 0.0),
    "斜移左下": (-1.0, 1.0, 0.0),
    "斜移右下": (1.0, 1.0, 0.0),
    "横摇推近": (1.0, 0.0, 0.32),
    "横摇拉远": (1.0, 0.0, -0.20),
    "推近上移": (0.0, -0.8, 0.36),
    "推近下移": (0.0, 0.8, 0.36),
    "呼吸变焦": (0.0, 0.0, 0.0),
    "环绕一圈": (0.0, 0.0, 0.0),
    "环绕推近": (0.0, 0.0, 0.25),
    "多利变焦": (-0.35, 0.0, 0.42),
    "摇臂上升推近": (0.0, -1.0, 0.30),
    "摇臂下降拉远": (0.0, 1.0, -0.20),
    "推近拉远": (0.0, 0.0, 0.0),
    "拉远推近": (0.0, 0.0, 0.0),
    "推近拉远推近": (0.0, 0.0, 0.0),
    "拉远推近拉远": (0.0, 0.0, 0.0),
    "手持轻": (0.0, 0.0, 0.0),
    "手持中": (0.0, 0.0, 0.0),
    "手持强": (0.0, 0.0, 0.0),
}


def _segment_target(move_name: str, strength: float, speed: float):
    """Return the end-of-segment camera delta for a move.

    ``strength`` is clamped to [0, 3]. The pan amount is 0.35 * strength and
    the tilt amount is 0.28 * strength; zoom is ``1 + gain * strength``.
    ``speed`` is clamped to [0.1, 5] but does not affect any value returned here.

    Returns:
        ``(name, dx, dy, zoom, rot)``, where ``name`` is the canonicalized move
        name and ``rot`` is always 0.0. _build_trajectory adds dx/dy/rot to the
        current pose and multiplies the current zoom by ``zoom``.
    """
    canonical = _to_canonical_move(move_name)

    strength = float(max(0.0, min(3.0, strength)))
    speed = float(max(0.1, min(5.0, speed)))

    gain_x, gain_y, gain_z = _MOVE_TARGET_GAINS.get(canonical, (1.0, 0.0, 0.0))
    pan = 0.35 * strength
    tilt = 0.28 * strength
    dx = gain_x * pan
    dy = gain_y * tilt
    zoom = 1.0 + gain_z * strength
    return canonical, dx, dy, zoom, 0.0

def _handheld_curve(L, strength, speed, device=None, dtype=None):
    """Generate a random, smoothed shake for one handheld segment.

    Draws four Gaussian noise tracks (x, y, zoom, roll) in that fixed order
    from the global torch RNG, so a given seed reproduces the same shake. Each
    track is:
    * scaled by an amplitude that grows with ``strength`` (clamped to [0, 3]);
    * EMA-smoothed;
    * multiplied by a 4 t (1 - t) envelope that is 0 at both ends and 1 mid-way;
    * detrended by subtracting the straight line between its first and last samples.

    Args:
        L: number of frames (clamped to >= 1).
        strength: shake amplitude control.
        speed: clamped to [0.1, 5]; feeds the EMA alpha.
        device, dtype: tensor placement; default CPU / float32.

    Returns:
        ``(dx, dy, dz, dr)``: four 1-D tensors of length L.
    """
    L = int(max(1, L))
    dev = device or torch.device("cpu")
    dt = dtype or torch.float32

    base = float(max(0.0, min(3.0, strength)))
    sp = float(max(0.1, min(5.0, speed)))

    amp_xy = 0.010 + 0.020 * base
    amp_z = 0.004 + 0.010 * base
    amp_r = 0.003 + 0.010 * base
    scales = (amp_xy * 0.9, amp_xy * 0.7, amp_z * 0.8, amp_r * 0.9)
    raw = [torch.randn(L, device=dev, dtype=dt) * s for s in scales]

    # NOTE(review): in _ema_1d a larger alpha smooths harder, so a higher
    # `speed` gives slower, gentler jitter here -- confirm that is intended.
    alpha = float(max(0.05, min(0.6, 0.18 + 0.08 * sp)))
    smoothed = [_ema_1d(track, alpha) for track in raw]

    t = torch.linspace(0, 1, steps=L, device=dev, dtype=dt)
    envelope = (t * (1 - t) * 4.0).clamp(0, 1)

    shaped = []
    for track in smoothed:
        wave = track * envelope
        trend = torch.linspace(wave[0], wave[-1], steps=L, device=dev, dtype=dt)
        shaped.append(wave - trend)

    dx, dy, z, r = shaped
    return dx, dy, z, r

def _piecewise_zoom(L, cur_zoom, a0, a1, a2, amp):
    """Build an L-frame zoom curve from three cubic-Hermite pieces: in, out, in.

    Keyframes, relative to ``z0 = cur_zoom``:
        z1 = z0 * (1 + amp)          at frame end0 (~round(L * a0))
        z2 = z0 * (1 - 0.85 * amp)   at frame end1 (~round(L * a1))
        z3 = z0 * (1 + amp)          approached at frame end2 (~round(L * a2))
    Breakpoints are clamped to be strictly increasing and within L, so the
    third piece may be empty for very short L. Frames after end2 (if any) are
    held at z3. Each piece samples t in [0, 1), so a piece's own end keyframe
    is only reached as the first frame of the next piece.

    Args:
        L: number of frames (clamped to >= 1).
        cur_zoom: zoom at frame 0.
        a0, a1, a2: breakpoint positions as fractions of L.
        amp: relative zoom amplitude.

    Returns:
        1-D CPU tensor (torch default dtype) of length L.
    """
    L = int(max(1, L))
    z = torch.ones(L) * float(cur_zoom)
    if L == 1:
        return z

    p0 = int(round(L * a0))
    p1 = int(round(L * a1))
    p2 = int(round(L * a2))
    # Force 1 <= p0 < p1 < p2 so each piece gets at least one frame when L allows.
    p0 = max(1, min(L - 1, p0))
    p1 = max(p0 + 1, min(L - 1, p1))
    p2 = max(p1 + 1, min(L, p2))
    end0 = min(p0, L)
    end1 = min(p1, L)
    end2 = min(p2, L)

    z0 = float(cur_zoom)
    z1 = z0 * (1.0 + amp)
    z2 = z0 * (1.0 - amp * 0.85)
    z3 = z0 * (1.0 + amp)

    # Tangents in zoom-per-frame. Interior keyframes use the chord between
    # their neighbours (Catmull-Rom style); the ends use 0 so the curve
    # eases in and out.
    t0, t1, t2, t3 = 0.0, float(end0), float(end1), float(end2)
    m0 = 0.0
    m1 = (z2 - z0) / max(1e-6, (t2 - t0))
    m2 = (z3 - z1) / max(1e-6, (t3 - t1)) if t3 > t1 else 0.0
    m3 = 0.0

    def _hermite(p_start, p_end, m_start, m_end, length, out_slice):
        # Evaluate one cubic Hermite piece at `length` evenly spaced t in [0, 1)
        # and write it into z[out_slice]. Empty pieces are skipped.
        length = int(length)
        if length <= 0:
            return
        t = torch.arange(length, dtype=torch.float32) / float(length)
        t2 = t * t
        t3 = t2 * t
        # Standard cubic Hermite basis functions.
        h00 = (2 * t3) - (3 * t2) + 1
        h10 = (t3) - (2 * t2) + t
        h01 = (-2 * t3) + (3 * t2)
        h11 = (t3) - t2
        # Convert per-frame tangents to the unit parameter of this piece.
        m_start = m_start * float(length)
        m_end = m_end * float(length)
        z_seg = (h00 * p_start) + (h10 * m_start) + (h01 * p_end) + (h11 * m_end)
        z[out_slice] = z_seg

    _hermite(z0, z1, m0, m1, end0 - 0, slice(0, end0))
    _hermite(z1, z2, m1, m2, end1 - end0, slice(end0, end1))
    _hermite(z2, z3, m2, m3, end2 - end1, slice(end1, end2))

    # Hold the final keyframe for any frames past the last breakpoint.
    if end2 < L:
        z[end2:] = z3

    return z

def _assign_frame_ranges(n, fps, seg_list):
    """Split frames [0, n) into consecutive ``(start, end, segment)`` ranges in script order."""
    ranges = []
    cur = 0
    last = len(seg_list) - 1
    for i, s in enumerate(seg_list):
        if i == last:
            end = n  # the final segment always absorbs rounding leftovers
        else:
            end = min(n, cur + max(1, int(round(s["dur"] * fps))))
        if end <= cur:
            end = min(n, cur + 1)
        ranges.append((cur, end, s))
        cur = end
        if cur >= n:
            break
    if ranges and ranges[-1][1] != n:
        start, _, s = ranges[-1]
        ranges[-1] = (start, n, s)
    return ranges


def _hold(length, value):
    """Constant 1-D CPU tensor of ``length`` copies of ``value``."""
    return torch.full((length,), float(value))


def _zoom_there_and_back(length, start_zoom, peak_zoom):
    """Ease start -> peak over the first half of the segment and peak -> start over the second.

    Each half gets its own full smoothstep, so the curve is continuous at the
    midpoint and ends exactly at ``start_zoom``.
    """
    z = _hold(length, start_zoom)
    p = max(1, int(round(length * 0.5)))
    rise = _smoothstep(torch.linspace(0, 1, steps=p))
    z[:p] = start_zoom + (peak_zoom - start_zoom) * rise
    tail = length - p
    if tail > 0:
        # Drop t=0 so the peak frame is not repeated at the half-way point.
        fall = _smoothstep(torch.linspace(0, 1, steps=tail + 1)[1:])
        z[p:] = peak_zoom + (start_zoom - peak_zoom) * fall
    return z


def _trajectory_from_segments(n, fps, segments, single_move, strength, speed,
                              stabilize_level, overlap_frames, enable_micro_jitter, jitter):
    """Body of _build_trajectory. Assumes the caller has already seeded the RNG."""
    dx = torch.zeros(n)
    dy = torch.zeros(n)
    zoom = torch.ones(n)
    rot = torch.zeros(n)

    if n <= 1:
        return dx, dy, zoom, rot

    clip_sec = n / max(1, fps)
    # Copy the segment dicts: durations are rescaled below, and the caller's
    # list must not be modified.
    seg_list = [dict(s) for s in segments] if segments else []
    if not seg_list:
        seg_list = [{"dur": clip_sec, "move": single_move, "strength": None, "speed": None}]

    # Rescale script durations proportionally so they exactly fill the clip.
    total_sec = sum(s["dur"] for s in seg_list)
    if total_sec <= 0:
        total_sec = clip_sec
    scale = clip_sec / total_sec
    for s in seg_list:
        s["dur"] = max(1e-6, float(s["dur"]) * scale)

    seg_ranges = _assign_frame_ranges(n, fps, seg_list)

    # Pose carried from one segment to the next.
    cur_dx = 0.0
    cur_dy = 0.0
    cur_zoom = 1.0
    cur_rot = 0.0

    for (a, b, s) in seg_ranges:
        L = max(1, b - a)
        sstep = _smoothstep(torch.linspace(0, 1, steps=L))

        seg_strength = float(s["strength"]) if s.get("strength") is not None else float(strength)
        seg_speed = float(s["speed"]) if s.get("speed") is not None else float(speed)

        move_name, dx_end, dy_end, zoom_end, rot_end = _segment_target(
            s.get("move", single_move), seg_strength, seg_speed)
        move_name = _to_canonical_move(move_name)

        if str(move_name).startswith("环绕"):  # orbit moves
            ang = 2 * math.pi * sstep * seg_speed
            r_x = 0.28 * seg_strength
            r_y = 0.18 * seg_strength
            # Subtract the t=0 point (cos 0 = 1) so the orbit starts exactly
            # at the current pose instead of jumping sideways by r_x.
            dx_seg = cur_dx + r_x * (torch.cos(ang) - 1.0)
            dy_seg = cur_dy + r_y * torch.sin(ang)
            if move_name == "环绕推近":
                z_to = cur_zoom * float(zoom_end)
                zoom_seg = cur_zoom + (z_to - cur_zoom) * sstep
            else:
                zoom_seg = _hold(L, cur_zoom)
            rot_seg = _hold(L, cur_rot)
        elif move_name == "呼吸变焦":  # breathing zoom
            amp = 0.06 * seg_strength
            phase = 2 * math.pi * sstep * seg_speed
            zoom_seg = cur_zoom * (1.0 + amp * torch.sin(phase))
            dx_seg = _hold(L, cur_dx)
            dy_seg = _hold(L, cur_dy)
            rot_seg = _hold(L, cur_rot)
        elif move_name in ("推近拉远", "拉远推近", "推近拉远推近", "拉远推近拉远"):
            amp = float(max(0.02, min(0.20, 0.10 * seg_strength)))
            if move_name == "推近拉远":
                zoom_seg = _zoom_there_and_back(L, float(cur_zoom), float(cur_zoom) * (1.0 + amp))
            elif move_name == "拉远推近":
                zoom_seg = _zoom_there_and_back(L, float(cur_zoom), float(cur_zoom) * (1.0 - amp * 0.85))
            elif move_name == "推近拉远推近":
                zoom_seg = _piecewise_zoom(L, cur_zoom, 1 / 3, 2 / 3, 1.0, amp)
            else:
                # Mirror the in-out-in curve around cur_zoom to get out-in-out.
                zoom_seg = float(cur_zoom) * 2.0 - _piecewise_zoom(L, cur_zoom, 1 / 3, 2 / 3, 1.0, amp)
            dx_seg = _hold(L, cur_dx)
            dy_seg = _hold(L, cur_dy)
            rot_seg = _hold(L, cur_rot)
        elif move_name in ("手持轻", "手持中", "手持强"):  # handheld light / medium / strong
            level = {"手持轻": 0.7, "手持中": 1.2, "手持强": 1.9}[move_name]
            ddx, ddy, dz, dr = _handheld_curve(L, seg_strength * level, seg_speed, device=None, dtype=None)
            dx_seg = _hold(L, cur_dx) + ddx
            dy_seg = _hold(L, cur_dy) + ddy
            zoom_seg = _hold(L, cur_zoom) * (1.0 + dz).clamp(0.2, 3.0)
            rot_seg = _hold(L, cur_rot) + dr
        else:
            # Ease from the current pose to current + target delta.
            dx_to = cur_dx + float(dx_end)
            dy_to = cur_dy + float(dy_end)
            zoom_to = cur_zoom * float(zoom_end)
            rot_to = cur_rot + float(rot_end)
            dx_seg = cur_dx + (dx_to - cur_dx) * sstep
            dy_seg = cur_dy + (dy_to - cur_dy) * sstep
            zoom_seg = cur_zoom + (zoom_to - cur_zoom) * sstep
            rot_seg = cur_rot + (rot_to - cur_rot) * sstep

        # The next segment continues from where this one actually ended (the
        # breathing zoom previously reset to its start zoom, causing a jump).
        cur_dx = float(dx_seg[-1].item())
        cur_dy = float(dy_seg[-1].item())
        cur_zoom = float(zoom_seg[-1].item())
        cur_rot = float(rot_seg[-1].item())

        dx[a:b] = dx_seg
        dy[a:b] = dy_seg
        zoom[a:b] = zoom_seg
        rot[a:b] = rot_seg

    # Extra soft pass to round off segment joins when several segments exist.
    if int(max(0, overlap_frames)) > 0 and len(seg_ranges) > 1:
        dx, dy, zoom, rot = _smooth_params(dx, dy, zoom, rot, "SOFT")

    dx, dy, zoom, rot = _smooth_params(dx, dy, zoom, rot, stabilize_level)

    if enable_micro_jitter and float(jitter) > 0:
        j = float(jitter)
        noise = torch.randn_like(dx) * j
        dx = dx + noise * 0.6
        dy = dy + noise * 0.4

    return dx, dy, zoom.clamp(min=0.2, max=3.0), rot


def _build_trajectory(n_frames, fps, segments, single_move, strength, speed,
                      stabilize_level, overlap_frames, seed, enable_micro_jitter=False, jitter=0.0):
    """Build per-frame camera curves (dx, dy, zoom, rot) for a clip.

    The clip is split into consecutive segments: either the parsed script, or
    a single segment running ``single_move`` over the whole clip. Script
    durations are rescaled so they exactly fill ``n_frames`` at ``fps``.

    Each segment starts from the pose the previous one ended at. The curves
    are then smoothed with _smooth_params: first an extra "SOFT" pass when
    ``overlap_frames`` > 0 and there are several segments, then
    ``stabilize_level``. Optional micro-jitter is added next, and zoom is
    finally clamped to [0.2, 3.0]. The ``segments`` dicts are not modified.

    Random draws (handheld shake, jitter) come from the CPU RNG seeded with
    ``seed``. This happens inside ``torch.random.fork_rng``, so results are
    reproducible and the process-wide RNG state is restored afterwards.

    Returns:
        ``(dx, dy, zoom, rot)``: four 1-D CPU tensors of length ``n_frames``.
    """
    with torch.random.fork_rng(devices=[]):
        torch.manual_seed(int(seed) if seed is not None else 0)
        return _trajectory_from_segments(
            max(0, int(n_frames)), fps, segments, single_move, strength, speed,
            stabilize_level, overlap_frames, enable_micro_jitter, jitter,
        )

def _auto_pad_scale(dx_px, dy_px, zoom, h, w, safety_margin):
    zmin = float(zoom.min().item())
    zmin = max(0.2, zmin)
    need_zoom = 1.0 / zmin

    max_dx = float(dx_px.abs().max().item())
    max_dy = float(dy_px.abs().max().item())
    need_tx = 1.0 + (2.0 * max_dx / max(1.0, w))
    need_ty = 1.0 + (2.0 * max_dy / max(1.0, h))

    need = max(need_zoom, need_tx, need_ty)
    need = need + float(max(0.0, safety_margin))
    need = need + 0.02
    return float(max(1.0, min(4.0, need)))

class CameraMovePro:
    """Node-style class that applies virtual camera moves to an image batch.

    Uses the INPUT_TYPES / RETURN_TYPES / FUNCTION / CATEGORY class-attribute
    convention (presumably a ComfyUI node -- confirm against the registry).
    ``run`` optionally retimes the frames, builds a per-frame trajectory, pads
    the canvas, warps each frame, crops back to the original size, and passes
    the audio through.
    """

    @classmethod
    def INPUT_TYPES(cls):
        # Input labels are user-facing (Chinese); English glosses in comments.
        return {
            "required": {
                "frames": ("IMAGE",),
                "audio": ("AUDIO",),

                # mode: single continuous move / segmented script
                "模式": (["一镜到底", "分段运镜"], {"default": "一镜到底"}),

                # move type used in single-move mode
                "运镜类型": (MOVE_PRESETS, {"default": "横摇左到右"}),

                # output frame rate and optional target duration in seconds (0 = keep)
                "帧率FPS": ("INT", {"default": 25, "min": 1, "max": 240}),
                "时长秒": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 36000.0}),

                # strength / speed
                "强度": ("FLOAT", {"default": 1.20, "min": 0.0, "max": 3.0, "step": 0.01}),
                "速度": ("FLOAT", {"default": 0.85, "min": 0.1, "max": 5.0, "step": 0.01}),

                # stabilization level and inter-segment smoothing
                "稳定模式": (["HARD", "SOFT", "OFF"], {"default": "HARD"}),
                "段间平滑": ("INT", {"default": 8, "min": 0, "max": 60}),

                # edge protection, padding mode, padding ratio, pad blur, safety margin
                "不露边": (["开启", "关闭"], {"default": "开启"}),
                "扩边模式": (["镜像扩边", "模糊背景扩边", "黑边扩边", "关闭"], {"default": "镜像扩边"}),
                "扩边比例": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 2.0, "step": 0.01}),
                "扩边模糊强度": ("FLOAT", {"default": 0.30, "min": 0.0, "max": 1.0, "step": 0.01}),
                "安全边距": ("FLOAT", {"default": 0.08, "min": 0.0, "max": 0.35, "step": 0.01}),

                # random seed / new seed every run
                "随机种子": ("INT", {"default": 12345, "min": 0, "max": 2147483647}),
                "每次随机": (["关闭", "开启"], {"default": "关闭"}),

                # frames per processing chunk and half-precision mode
                "分块chunk": ("INT", {"default": 12, "min": 1, "max": 128}),
                "FP16": (["Auto", "On", "Off"], {"default": "Auto"}),

                # segment script (one line per segment: seconds, move, [strength=], [speed=])
                "分段脚本": ("STRING", {
                    "multiline": True,
                    "default": """# 每行：秒数 镜头类型 [强度=] [速度=]
3 横摇左到右
3 推近拉远推近
4 拉远
"""
                }),

                # audio passthrough on/off
                "音频透传": (["开启", "关闭"], {"default": "开启"}),
            }
        }

    RETURN_TYPES = ("IMAGE", "AUDIO")
    RETURN_NAMES = ("frames", "audio")
    FUNCTION = "run"
    CATEGORY = "视频/相机运镜"

    def run(self, frames, audio,
            模式, 运镜类型, 帧率FPS, 时长秒, 强度, 速度,
            稳定模式, 段间平滑,
            不露边, 扩边模式, 扩边比例, 扩边模糊强度, 安全边距,
            随机种子, 每次随机,
            分块chunk, FP16,
            分段脚本,
            音频透传):
        """Apply the configured camera move to ``frames``.

        ``frames`` is indexed as (N, H, W, C). Non-tensor input and empty
        batches are returned unchanged. Output frames are float32 on the CPU.

        Returns:
            ``(frames, audio)``; ``audio`` is None when passthrough (音频透传)
            is off.
        """
        audio = _safe_audio(audio) if (音频透传 == "开启") else None

        if not isinstance(frames, torch.Tensor):
            return (frames, audio)

        n, h, w, c = frames.shape
        if n == 0:
            # Empty batch: nothing to animate, and the trajectory / pad-scale
            # reductions and the final torch.cat below are undefined for it.
            return (frames, audio)
        fps = int(帧率FPS)

        seed = int(torch.randint(0, 2147483647, (1,)).item()) if 每次随机 == "开启" else int(随机种子)

        # Optional retiming to 时长秒 (duration) by blending neighbouring frames.
        target_n = n
        if float(时长秒) > 0:
            target_n = max(1, int(round(float(时长秒) * fps)))

        if target_n != n:
            idx = torch.linspace(0, n - 1, steps=target_n, device=frames.device)
            idx0 = torch.floor(idx).long().clamp(0, n - 1)
            idx1 = (idx0 + 1).clamp(0, n - 1)
            w1 = (idx - idx0.float()).view(-1, 1, 1, 1)
            f0 = frames[idx0]
            f1 = frames[idx1]
            frames = f0 * (1 - w1) + f1 * w1
            n = target_n

        segments = []
        if 模式 == "分段运镜":
            segments = _parse_segments(分段脚本)

        dx_r, dy_r, zoom, rot = _build_trajectory(
            n_frames=n,
            fps=fps,
            segments=segments,
            single_move=运镜类型,
            strength=float(强度),
            speed=float(速度),
            stabilize_level=str(稳定模式),
            overlap_frames=int(段间平滑),
            seed=seed,
            enable_micro_jitter=False,
            jitter=0.0
        )

        # Trajectory offsets are scaled to pixels by the short frame side.
        base = float(min(h, w))
        dx_px = dx_r * base
        dy_px = dy_r * base

        leak_safe = (不露边 == "开启")
        pad_mode = str(扩边模式)
        blur_strength = float(扩边模糊强度)

        auto_need = _auto_pad_scale(dx_px, dy_px, zoom, h, w, float(安全边距)) if leak_safe else 1.0
        user_scale = float(扩边比例)
        pad_scale = auto_need if user_scale <= 0.0 else max(user_scale, auto_need)
        pad_scale = float(max(1.0, min(4.0, pad_scale)))

        device = frames.device
        if device.type == "cpu" and torch.cuda.is_available():
            device = torch.device("cuda")

        fp16mode = str(FP16)
        use_fp16 = (device.type == "cuda") if fp16mode == "Auto" else (fp16mode == "On")
        dtype = torch.float16 if use_fp16 else torch.float32

        chunk = int(max(1, int(分块chunk)))

        outs = []
        with torch.no_grad():
            dx_cpu = dx_px.cpu()
            dy_cpu = dy_px.cpu()
            zoom_cpu = zoom.cpu()
            rot_cpu = rot.cpu()

            for start in range(0, n, chunk):
                end = min(n, start + chunk)

                f = frames[start:end].to(device=device, dtype=dtype, non_blocking=True)
                f_exp, crop = _expand_canvas(f, scale=pad_scale, mode=pad_mode, blur_strength=blur_strength)

                dx_d = dx_cpu[start:end].to(device=device, dtype=dtype, non_blocking=True)
                dy_d = dy_cpu[start:end].to(device=device, dtype=dtype, non_blocking=True)
                zm_d = zoom_cpu[start:end].to(device=device, dtype=dtype, non_blocking=True)
                rt_d = rot_cpu[start:end].to(device=device, dtype=dtype, non_blocking=True)

                out = _apply_affine(f_exp, dx_d, dy_d, zm_d, rt_d)
                out = _crop_back(out, crop)

                outs.append(out.to("cpu", dtype=torch.float32).contiguous())

                del f, f_exp, out, dx_d, dy_d, zm_d, rt_d

        # Release cached GPU blocks once at the end; doing it per chunk forced
        # a synchronize + free on every iteration, and the allocator reuses
        # the same block sizes across chunks anyway.
        if device.type == "cuda":
            torch.cuda.empty_cache()

        out_frames = torch.cat(outs, dim=0)
        return (out_frames, audio)
