# -*- coding: utf-8 -*-
"""
Seed-VC 推理模块
封装Seed-VC的推理功能，提供简单易用的接口
"""

import os
import sys
import tempfile
from pathlib import Path
from typing import Optional, Union
import numpy as np

# Point HuggingFace downloads at the hf-mirror.com endpoint, but only when the
# user has not already chosen an endpoint via the HF_ENDPOINT variable.
if "HF_ENDPOINT" not in os.environ:
    os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"


class SeedVCInfer:
    """Seed-VC inference wrapper.

    Wraps the ``inference`` module of a local Seed-VC source checkout. The
    checkout is put on ``sys.path``, and the models are loaded once through
    ``inference.load_models``. Conversions run either in-process through
    ``inference.convert_voice`` (:meth:`convert`) or by launching
    ``inference.py`` as a subprocess (:meth:`convert_simple`).
    """

    # Fallback location of the Seed-VC source checkout. It is used when neither
    # the ``seed_vc_path`` argument nor the SEED_VC_PATH environment variable
    # is provided.
    DEFAULT_SEED_VC_PATH = Path("C:/Users/ggpq/Downloads/seed-vc-test")

    def __init__(
        self,
        device: str = "cuda",
        fp16: bool = True,
        diffusion_steps: int = 25,
        seed_vc_path: Optional[Union[str, Path]] = None,
    ):
        """
        Initialise the Seed-VC wrapper. Models are loaded lazily.

        Args:
            device: compute device ("cuda" or "cpu"). NOTE: this value is only
                stored on the instance; nothing in this class forwards it to
                Seed-VC.
            fp16: whether to use half precision; forwarded to Seed-VC.
            diffusion_steps: number of diffusion steps (25 = fast,
                50 = higher quality).
            seed_vc_path: directory of the Seed-VC source checkout. Defaults to
                the SEED_VC_PATH environment variable, then to
                ``DEFAULT_SEED_VC_PATH``.
        """
        self.device = device
        self.fp16 = fp16
        self.diffusion_steps = diffusion_steps

        # The seven components returned by ``inference.load_models``; they are
        # filled in by load_models().
        self.model = None
        self.semantic_fn = None
        self.f0_fn = None
        self.vocoder_fn = None
        self.campplus_model = None
        self.mel_fn = None
        self.mel_fn_args = None

        self._loaded = False

        if seed_vc_path is None:
            seed_vc_path = os.environ.get("SEED_VC_PATH") or self.DEFAULT_SEED_VC_PATH
        self.seed_vc_path = Path(seed_vc_path)

    def _ensure_seed_vc_path(self):
        """Prepend the Seed-VC checkout to ``sys.path`` (once) so ``inference`` is importable."""
        seed_vc_str = str(self.seed_vc_path)
        if seed_vc_str not in sys.path:
            sys.path.insert(0, seed_vc_str)

    def load_models(self) -> bool:
        """
        Load the Seed-VC models once and cache them on the instance.

        Returns:
            True if the models are loaded (or already were). False if loading
            failed; the error and traceback are printed, not raised.
        """
        if self._loaded:
            return True

        try:
            self._ensure_seed_vc_path()

            print("[Seed-VC] 加载模型...")

            # Imported lazily: ``inference`` lives in the Seed-VC checkout and
            # only becomes importable after _ensure_seed_vc_path().
            from inference import load_models
            from types import SimpleNamespace

            # Stand-in for the argparse namespace that ``inference.load_models`` reads.
            args = SimpleNamespace(
                checkpoint=None,  # None: auto-download the default checkpoint
                config=None,
                f0_condition=False,
                fp16=self.fp16,
            )

            (
                self.model,
                self.semantic_fn,
                self.f0_fn,
                self.vocoder_fn,
                self.campplus_model,
                self.mel_fn,
                self.mel_fn_args,
            ) = load_models(args)

            self._loaded = True
            print("[Seed-VC] ✅ 模型加载完成")
            return True

        except Exception as e:
            print(f"[Seed-VC] ❌ 模型加载失败: {e}")
            import traceback
            traceback.print_exc()
            return False

    @staticmethod
    def _as_wav_path(
        audio: Union[str, Path, np.ndarray],
        sr: int,
        temp_files: list,
    ) -> Union[str, Path]:
        """Return a file path for *audio*; ndarray input is written to a new temp WAV.

        Every temp file created is appended to *temp_files* before it is
        written, so the caller can delete it even if the write fails.
        """
        if not isinstance(audio, np.ndarray):
            return audio
        import soundfile as sf

        # mkstemp creates the file atomically, unlike the deprecated, race-prone mktemp.
        fd, path = tempfile.mkstemp(suffix=".wav")
        os.close(fd)
        temp_files.append(path)
        sf.write(path, audio, sr)
        return path

    @staticmethod
    def _place_output(result_path, output_path: Optional[str]):
        """Move the produced file to *output_path*, if one was requested.

        Returns the final location as a string. *result_path* is returned
        unchanged when no output path was requested, or when *result_path* is
        not an existing file. If *output_path* is an existing directory, the
        file keeps its generated name inside it.
        """
        if output_path is None or not isinstance(result_path, (str, os.PathLike)):
            return result_path
        if not os.path.isfile(result_path):
            return result_path
        import shutil

        target = Path(output_path)
        if target.is_dir():
            target = target / Path(result_path).name
        if Path(result_path).resolve() != target.resolve():
            # shutil.move also works across drives and filesystems
            # (e.g. from a temp dir).
            shutil.move(str(result_path), str(target))
        return str(target)

    def convert(
        self,
        source_audio: Union[str, np.ndarray],
        reference_audio: Union[str, np.ndarray],
        output_path: Optional[str] = None,
        sr: int = 22050,
        semi_tone_shift: int = 0,
        auto_f0_adjust: bool = True,
        length_adjust: float = 1.0,
        inference_cfg_rate: float = 0.7,
    ) -> Optional[str]:
        """
        Convert the voice of *source_audio* to sound like *reference_audio*.

        Args:
            source_audio: source audio file path, or a numpy array
                (written to a temporary WAV at *sr*).
            reference_audio: reference (target speaker) audio path, or a numpy
                array (written to a temporary WAV at *sr*).
            output_path: desired output file path. If it is None, the result
                stays in a fresh temporary directory.
            sr: sample rate used when writing numpy-array input.
            semi_tone_shift: pitch shift in semitones.
            auto_f0_adjust: automatically adjust pitch.
            length_adjust: length (speed) adjustment factor.
            inference_cfg_rate: inference CFG rate.

        Returns:
            Path of the output file, or None on failure. Errors are printed,
            not raised. Temporary WAVs created for array input are always
            deleted.
        """
        if not self._loaded and not self.load_models():
            return None

        temp_files: list = []
        try:
            self._ensure_seed_vc_path()

            from inference import convert_voice
            from types import SimpleNamespace

            source_path = self._as_wav_path(source_audio, sr, temp_files)
            reference_path = self._as_wav_path(reference_audio, sr, temp_files)

            if output_path is None:
                output_dir = tempfile.mkdtemp()
            else:
                output_dir = str(Path(output_path).parent)
                Path(output_dir).mkdir(parents=True, exist_ok=True)

            # Stand-in for the argparse namespace that ``inference.convert_voice`` reads.
            args = SimpleNamespace(
                source=source_path,
                target=reference_path,
                output=output_dir,
                diffusion_steps=self.diffusion_steps,
                length_adjust=length_adjust,
                inference_cfg_rate=inference_cfg_rate,
                f0_condition=False,
                auto_f0_adjust=auto_f0_adjust,
                semi_tone_shift=semi_tone_shift,
                fp16=self.fp16,
            )

            print(f"[Seed-VC] 转换中... (steps={self.diffusion_steps})")

            result_path = convert_voice(
                args,
                self.model,
                self.semantic_fn,
                self.f0_fn,
                self.vocoder_fn,
                self.campplus_model,
                self.mel_fn,
                self.mel_fn_args,
            )
            # convert_voice chooses its own file name inside output_dir; honour
            # the exact path the caller asked for.
            result_path = self._place_output(result_path, output_path)

            print(f"[Seed-VC] ✅ 转换完成: {result_path}")
            return result_path

        except Exception as e:
            print(f"[Seed-VC] ❌ 转换失败: {e}")
            import traceback
            traceback.print_exc()
            return None

        finally:
            for tmp in temp_files:
                # Best-effort cleanup: a leftover temp file must not mask the result.
                try:
                    os.remove(tmp)
                except OSError:
                    pass

    def convert_simple(
        self,
        source_path: str,
        reference_path: str,
        output_path: str,
        semi_tone_shift: int = 0,
        auto_f0_adjust: bool = True,
        length_adjust: float = 1.0,
        inference_cfg_rate: float = 0.7,
    ) -> bool:
        """
        Simple conversion interface that runs ``inference.py`` as a subprocess.

        This is more stable because it calls inference.py directly. Only the
        parent directory of *output_path* is passed to the script
        (``--output``); inference.py chooses the output file name.

        Returns:
            True if the script exits with status 0 or prints "RTF:" to stderr.
            False otherwise; errors are printed, not raised.
        """
        try:
            import subprocess

            output_dir = Path(output_path).parent
            # Same as convert(): make sure the output directory exists.
            output_dir.mkdir(parents=True, exist_ok=True)

            cmd = [
                sys.executable,
                str(self.seed_vc_path / "inference.py"),
                "--source", source_path,
                "--target", reference_path,
                "--output", str(output_dir),
                "--diffusion-steps", str(self.diffusion_steps),
                "--semi-tone-shift", str(semi_tone_shift),
                "--auto-f0-adjust", str(auto_f0_adjust),
                "--length-adjust", str(length_adjust),
                "--inference-cfg-rate", str(inference_cfg_rate),
                # Honour the instance setting instead of always forcing fp16.
                "--fp16", str(self.fp16),
            ]

            env = os.environ.copy()
            # Keep a user-supplied endpoint, consistent with the module-level setdefault.
            env.setdefault("HF_ENDPOINT", "https://hf-mirror.com")

            print("[Seed-VC] 执行转换...")
            result = subprocess.run(
                cmd,
                capture_output=True,
                text=True,
                # Undecodable child output must not turn a successful run into a failure.
                errors="replace",
                env=env,
                cwd=str(self.seed_vc_path),
            )

            # An "RTF:" line in stderr is also accepted as success, even with a
            # non-zero exit code. Presumably the script can fail after the
            # conversion itself has finished; TODO confirm.
            if result.returncode == 0 or "RTF:" in result.stderr:
                print("[Seed-VC] ✅ 转换完成")
                return True
            print(f"[Seed-VC] ❌ 转换失败: {result.stderr}")
            return False

        except Exception as e:
            print(f"[Seed-VC] ❌ 转换失败: {e}")
            return False


# Module-level singleton instance, created lazily by get_infer().
_infer_instance: Optional[SeedVCInfer] = None


def get_infer(
    device: str = "cuda",
    fp16: bool = True,
    diffusion_steps: int = 25,
) -> SeedVCInfer:
    """
    Return the shared SeedVCInfer instance, creating it on the first call.

    The arguments are only used when the instance is created. If a later call
    requests a different configuration, the existing instance is still
    returned unchanged, and a warning is printed so the mismatch is visible.
    """
    global _infer_instance

    if _infer_instance is None:
        _infer_instance = SeedVCInfer(
            device=device,
            fp16=fp16,
            diffusion_steps=diffusion_steps,
        )
        return _infer_instance

    current = (
        _infer_instance.device,
        _infer_instance.fp16,
        _infer_instance.diffusion_steps,
    )
    if (device, fp16, diffusion_steps) != current:
        print(
            f"[Seed-VC] ⚠️ 已存在推理实例，新参数被忽略: "
            f"请求 device={device}, fp16={fp16}, diffusion_steps={diffusion_steps}; "
            f"当前 device={current[0]}, fp16={current[1]}, diffusion_steps={current[2]}"
        )

    return _infer_instance
