# Copyright (c) ModelScope Contributors. All rights reserved.
import os
import sys

from transformers import PreTrainedModel

from swift.template import TemplateType
from swift.utils import get_device, git_clone_github
from ..constant import LLMModelType, MLLMModelType
from ..model_arch import ModelArch
from ..model_meta import Model, ModelGroup, ModelMeta
from ..register import ModelLoader, register_model


class LlamaLoader(ModelLoader):
    """Model loader for checkpoints built on the ``LlamaForCausalLM`` architecture.

    Behaves like :class:`ModelLoader` except that the loaded config never keeps
    a ``pretraining_tp`` value greater than 1.
    """

    def get_config(self, model_dir):
        """Load the config from ``model_dir``, forcing ``pretraining_tp`` down to 1.

        Configs that lack a ``pretraining_tp`` attribute, or whose value is not
        greater than 1, are returned untouched.
        """
        loaded_config = super().get_config(model_dir)
        tp_degree = getattr(loaded_config, 'pretraining_tp', 1)
        # NOTE(review): presumably pretraining_tp > 1 makes transformers' Llama
        # modules compute their projections in slices; 1 selects the plain path.
        if tp_degree > 1:
            loaded_config.pretraining_tp = 1
        return loaded_config


# Register the text-only Llama family under ``LLMModelType.llama``.
#
# Every checkpoint listed here is declared with the ``LlamaForCausalLM``
# architecture and is loaded through ``LlamaLoader`` (which clamps
# ``pretraining_tp`` to 1). Each ``ModelGroup`` chooses the chat template for its
# members. It may also add extra package ``requires``, ``tags`` or download
# ``ignore_patterns``.
# NOTE(review): ``Model(a, b)`` -- presumably ``a`` is the ModelScope repo id and
# ``b`` the Hugging Face repo id (``None`` meaning no ModelScope mirror); confirm
# against the ``Model`` definition.
register_model(
    ModelMeta(
        LLMModelType.llama,
        [
            # llama2
            ModelGroup(
                [
                    # base
                    Model('modelscope/Llama-2-7b-ms', 'meta-llama/Llama-2-7b-hf'),
                    Model('modelscope/Llama-2-13b-ms', 'meta-llama/Llama-2-13b-hf'),
                    Model('modelscope/Llama-2-70b-ms', 'meta-llama/Llama-2-70b-hf'),
                    # chat
                    Model('modelscope/Llama-2-7b-chat-ms', 'meta-llama/Llama-2-7b-chat-hf'),
                    Model('modelscope/Llama-2-13b-chat-ms', 'meta-llama/Llama-2-13b-chat-hf'),
                    Model('modelscope/Llama-2-70b-chat-ms', 'meta-llama/Llama-2-70b-chat-hf'),
                ],
                TemplateType.llama,
                # presumably excludes files ending in ``.bin`` from the download
                ignore_patterns=[r'.+\.bin$']),
            # chinese-llama2
            ModelGroup(
                [
                    # base
                    Model('AI-ModelScope/chinese-llama-2-1.3b', 'hfl/chinese-llama-2-1.3b'),
                    Model('AI-ModelScope/chinese-llama-2-7b', 'hfl/chinese-llama-2-7b'),
                    Model('AI-ModelScope/chinese-llama-2-7b-16k', 'hfl/chinese-llama-2-7b-16k'),
                    Model('AI-ModelScope/chinese-llama-2-7b-64k', 'hfl/chinese-llama-2-7b-64k'),
                    Model('AI-ModelScope/chinese-llama-2-13b', 'hfl/chinese-llama-2-13b'),
                    Model('AI-ModelScope/chinese-llama-2-13b-16k', 'hfl/chinese-llama-2-13b-16k'),
                    # chat
                    Model('AI-ModelScope/chinese-alpaca-2-1.3b', 'hfl/chinese-alpaca-2-1.3b'),
                    Model('AI-ModelScope/chinese-alpaca-2-7b', 'hfl/chinese-alpaca-2-7b'),
                    Model('AI-ModelScope/chinese-alpaca-2-7b-16k', 'hfl/chinese-alpaca-2-7b-16k'),
                    Model('AI-ModelScope/chinese-alpaca-2-7b-64k', 'hfl/chinese-alpaca-2-7b-64k'),
                    Model('AI-ModelScope/chinese-alpaca-2-13b', 'hfl/chinese-alpaca-2-13b'),
                    Model('AI-ModelScope/chinese-alpaca-2-13b-16k', 'hfl/chinese-alpaca-2-13b-16k'),
                ],
                TemplateType.llama),
            # base quant
            ModelGroup([
                Model('AI-ModelScope/Llama-2-7b-AQLM-2Bit-1x16-hf', 'ISTA-DASLab/Llama-2-7b-AQLM-2Bit-1x16-hf'),
            ],
                       TemplateType.llama,
                       requires=['transformers>=4.38', 'aqlm', 'torch>=2.2.0']),
            # third-party Llama-architecture fine-tunes, each with its own chat template
            ModelGroup([
                Model('FlagAlpha/Atom-7B', 'FlagAlpha/Atom-7B'),
                Model('FlagAlpha/Atom-7B-Chat', 'FlagAlpha/Atom-7B-Chat'),
            ],
                       template=TemplateType.atom),
            ModelGroup([
                Model('langboat/Mengzi3-13B-Base', 'Langboat/Mengzi3-13B-Base'),
            ],
                       template=TemplateType.mengzi),
            ModelGroup([
                Model('AI-ModelScope/NuminaMath-7B-TIR', 'AI-MO/NuminaMath-7B-TIR'),
            ],
                       template=TemplateType.numina,
                       tags=['math']),
            ModelGroup([
                Model('Fengshenbang/Ziya2-13B-Base', 'IDEA-CCNL/Ziya2-13B-Base'),
                Model('Fengshenbang/Ziya2-13B-Chat', 'IDEA-CCNL/Ziya2-13B-Chat'),
            ],
                       template=TemplateType.ziya),
            ModelGroup([
                Model('InfiniAI/Megrez-3b-Instruct', 'Infinigence/Megrez-3B-Instruct'),
            ], TemplateType.megrez),
            # deepseek
            ModelGroup([
                Model('deepseek-ai/deepseek-llm-7b-base', 'deepseek-ai/deepseek-llm-7b-base'),
                Model('deepseek-ai/deepseek-llm-7b-chat', 'deepseek-ai/deepseek-llm-7b-chat'),
                Model('deepseek-ai/deepseek-llm-67b-base', 'deepseek-ai/deepseek-llm-67b-base'),
                Model('deepseek-ai/deepseek-llm-67b-chat', 'deepseek-ai/deepseek-llm-67b-chat'),
            ], TemplateType.deepseek),
            ModelGroup(
                [
                    Model('deepseek-ai/deepseek-math-7b-base', 'deepseek-ai/deepseek-math-7b-base'),
                    Model('deepseek-ai/deepseek-math-7b-instruct', 'deepseek-ai/deepseek-math-7b-instruct'),
                    Model('deepseek-ai/deepseek-math-7b-rl', 'deepseek-ai/deepseek-math-7b-rl'),
                ],
                TemplateType.deepseek,
                tags=['math'],
            ),
            ModelGroup(
                [
                    Model('deepseek-ai/deepseek-coder-1.3b-base', 'deepseek-ai/deepseek-coder-1.3b-base'),
                    Model('deepseek-ai/deepseek-coder-1.3b-instruct', 'deepseek-ai/deepseek-coder-1.3b-instruct'),
                    Model('deepseek-ai/deepseek-coder-6.7b-base', 'deepseek-ai/deepseek-coder-6.7b-base'),
                    Model('deepseek-ai/deepseek-coder-6.7b-instruct', 'deepseek-ai/deepseek-coder-6.7b-instruct'),
                    Model('deepseek-ai/deepseek-coder-33b-base', 'deepseek-ai/deepseek-coder-33b-base'),
                    Model('deepseek-ai/deepseek-coder-33b-instruct', 'deepseek-ai/deepseek-coder-33b-instruct'),
                ],
                TemplateType.deepseek,
                tags=['coding'],
            ),
            # MiniMind2
            ModelGroup(
                [
                    # MiniMind2
                    Model('gongjy/MiniMind2', 'jingyaogong/MiniMind2'),
                    # MiniMind2-Small (only a Hugging Face id is given)
                    Model(None, 'jingyaogong/MiniMind2-Small'),
                ],
                TemplateType.minimind,
                requires=['transformers>=4.57.1']),
            # llama3
            ModelGroup(
                [
                    # chat
                    Model('LLM-Research/Meta-Llama-3-8B-Instruct', 'meta-llama/Meta-Llama-3-8B-Instruct'),
                    Model('LLM-Research/Meta-Llama-3-70B-Instruct', 'meta-llama/Meta-Llama-3-70B-Instruct'),
                    # base
                    Model('LLM-Research/Meta-Llama-3-8B', 'meta-llama/Meta-Llama-3-8B'),
                    Model('LLM-Research/Meta-Llama-3-70B', 'meta-llama/Meta-Llama-3-70B'),
                ],
                TemplateType.llama3),
            # llama3-quant
            ModelGroup([
                Model('swift/Meta-Llama-3-8B-Instruct-GPTQ-Int4', 'study-hjt/Meta-Llama-3-8B-Instruct-GPTQ-Int4'),
                Model('swift/Meta-Llama-3-8B-Instruct-GPTQ-Int8', 'study-hjt/Meta-Llama-3-8B-Instruct-GPTQ-Int8'),
                Model('swift/Meta-Llama-3-8B-Instruct-AWQ', 'study-hjt/Meta-Llama-3-8B-Instruct-AWQ'),
                Model('swift/Meta-Llama-3-70B-Instruct-GPTQ-Int4', 'study-hjt/Meta-Llama-3-70B-Instruct-GPTQ-Int4'),
                Model('swift/Meta-Llama-3-70B-Instruct-GPTQ-Int8', 'study-hjt/Meta-Llama-3-70B-Instruct-GPTQ-Int8'),
                Model('swift/Meta-Llama-3-70B-Instruct-AWQ', 'study-hjt/Meta-Llama-3-70B-Instruct-AWQ'),
            ], TemplateType.llama3),
            # chinese-llama3
            ModelGroup([
                Model('ChineseAlpacaGroup/llama-3-chinese-8b-instruct', 'hfl/llama-3-chinese-8b-instruct'),
                Model('ChineseAlpacaGroup/llama-3-chinese-8b', 'hfl/llama-3-chinese-8b'),
            ], TemplateType.llama3),
            # llama3.1
            # NOTE(review): the Llama-3.1 groups (and the 3.1-based fine-tunes below)
            # use TemplateType.llama3_2 rather than llama3 -- presumably intentional; confirm.
            ModelGroup(
                [
                    # chat
                    Model('LLM-Research/Meta-Llama-3.1-8B-Instruct', 'meta-llama/Meta-Llama-3.1-8B-Instruct'),
                    Model('LLM-Research/Meta-Llama-3.1-70B-Instruct', 'meta-llama/Meta-Llama-3.1-70B-Instruct'),
                    Model('LLM-Research/Meta-Llama-3.1-405B-Instruct', 'meta-llama/Meta-Llama-3.1-405B-Instruct'),
                    # base
                    Model('LLM-Research/Meta-Llama-3.1-8B', 'meta-llama/Meta-Llama-3.1-8B'),
                    Model('LLM-Research/Meta-Llama-3.1-70B', 'meta-llama/Meta-Llama-3.1-70B'),
                    Model('LLM-Research/Meta-Llama-3.1-405B', 'meta-llama/Meta-Llama-3.1-405B'),
                    # fp8
                    Model('LLM-Research/Meta-Llama-3.1-70B-Instruct-FP8', 'meta-llama/Meta-Llama-3.1-70B-Instruct-FP8'),
                    Model('LLM-Research/Meta-Llama-3.1-405B-Instruct-FP8',
                          'meta-llama/Meta-Llama-3.1-405B-Instruct-FP8'),
                ],
                TemplateType.llama3_2,
                requires=['transformers>=4.43']),
            # llama3.1-quant
            ModelGroup(
                [
                    # bnb-nf4
                    Model('LLM-Research/Meta-Llama-3.1-8B-Instruct-BNB-NF4',
                          'hugging-quants/Meta-Llama-3.1-8B-Instruct-BNB-NF4'),
                    Model('LLM-Research/Meta-Llama-3.1-70B-Instruct-bnb-4bit',
                          'unsloth/Meta-Llama-3.1-70B-Instruct-bnb-4bit'),
                    Model('LLM-Research/Meta-Llama-3.1-405B-Instruct-BNB-NF4',
                          'hugging-quants/Meta-Llama-3.1-405B-Instruct-BNB-NF4'),
                    # gptq-int4
                    Model('LLM-Research/Meta-Llama-3.1-8B-Instruct-GPTQ-INT4',
                          'hugging-quants/Meta-Llama-3.1-8B-Instruct-GPTQ-INT4'),
                    Model('LLM-Research/Meta-Llama-3.1-70B-Instruct-GPTQ-INT4',
                          'hugging-quants/Meta-Llama-3.1-70B-Instruct-GPTQ-INT4'),
                    Model('LLM-Research/Meta-Llama-3.1-405B-Instruct-GPTQ-INT4',
                          'hugging-quants/Meta-Llama-3.1-405B-Instruct-GPTQ-INT4'),
                    # awq-int4
                    Model('LLM-Research/Meta-Llama-3.1-8B-Instruct-AWQ-INT4',
                          'hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4'),
                    Model('LLM-Research/Meta-Llama-3.1-70B-Instruct-AWQ-INT4',
                          'hugging-quants/Meta-Llama-3.1-70B-Instruct-AWQ-INT4'),
                    Model('LLM-Research/Meta-Llama-3.1-405B-Instruct-AWQ-INT4',
                          'hugging-quants/Meta-Llama-3.1-405B-Instruct-AWQ-INT4'),
                ],
                TemplateType.llama3_2,
                requires=['transformers>=4.43']),
            # nvidia Nemotron
            ModelGroup([
                Model('AI-ModelScope/Llama-3.1-Nemotron-70B-Instruct-HF', 'nvidia/Llama-3.1-Nemotron-70B-Instruct-HF'),
            ],
                       TemplateType.llama3_2,
                       requires=['transformers>=4.43']),
            ModelGroup([
                Model('AI-ModelScope/Skywork-o1-Open-Llama-3.1-8B', 'Skywork/Skywork-o1-Open-Llama-3.1-8B'),
            ],
                       TemplateType.skywork_o1,
                       requires=['transformers>=4.43']),
            # llama3.2 (text-only)
            ModelGroup([
                Model('LLM-Research/Llama-3.2-1B', 'meta-llama/Llama-3.2-1B'),
                Model('LLM-Research/Llama-3.2-3B', 'meta-llama/Llama-3.2-3B'),
                Model('LLM-Research/Llama-3.2-1B-Instruct', 'meta-llama/Llama-3.2-1B-Instruct'),
                Model('LLM-Research/Llama-3.2-3B-Instruct', 'meta-llama/Llama-3.2-3B-Instruct'),
            ],
                       template=TemplateType.llama3_2,
                       requires=['transformers>=4.43']),
            # llama3.3
            ModelGroup([
                Model('LLM-Research/Llama-3.3-70B-Instruct', 'meta-llama/Llama-3.3-70B-Instruct'),
                Model('unsloth/Llama-3.3-70B-Instruct-bnb-4bit', 'unsloth/Llama-3.3-70B-Instruct-bnb-4bit'),
            ],
                       template=TemplateType.llama3_2,
                       requires=['transformers>=4.43']),
            ModelGroup([
                Model('ZhipuAI/LongWriter-llama3.1-8b', 'zai-org/LongWriter-llama3.1-8b'),
            ],
                       TemplateType.longwriter_llama,
                       requires=['transformers>=4.43']),
            ModelGroup([
                Model('deepseek-ai/DeepSeek-R1-Distill-Llama-8B', 'deepseek-ai/DeepSeek-R1-Distill-Llama-8B'),
                Model('deepseek-ai/DeepSeek-R1-Distill-Llama-70B', 'deepseek-ai/DeepSeek-R1-Distill-Llama-70B'),
            ], TemplateType.deepseek_r1),
            ModelGroup([
                Model('LLM-Research/Reflection-Llama-3.1-70B', 'mattshumer/Reflection-Llama-3.1-70B'),
            ],
                       TemplateType.reflection,
                       requires=['transformers>=4.43']),
        ],
        LlamaLoader,
        model_arch=ModelArch.llama,
        architectures=['LlamaForCausalLM'],
    ))


class Llama3_2VisionLoader(ModelLoader):
    """Model loader for Llama-3.2-Vision checkpoints (``MllamaForConditionalGeneration``)."""

    def get_model(self, model_dir: str, *args, **kwargs) -> PreTrainedModel:
        """Load the model, using ``MllamaForConditionalGeneration`` as the model class.

        An ``auto_model_cls`` already configured on the loader takes precedence.
        All arguments are forwarded unchanged to :meth:`ModelLoader.get_model`.
        """
        from transformers import MllamaForConditionalGeneration
        chosen_cls = self.auto_model_cls
        if not chosen_cls:
            chosen_cls = MllamaForConditionalGeneration
        self.auto_model_cls = chosen_cls
        return super().get_model(model_dir, *args, **kwargs)


# Register the Llama-3.2-Vision multimodal checkpoints (base and Instruct, 11B/90B).
# They are loaded via ``Llama3_2VisionLoader`` as ``MllamaForConditionalGeneration``
# with the ``llama3_2_vision`` template, and need ``transformers>=4.45``.
register_model(
    ModelMeta(
        MLLMModelType.llama3_2_vision,
        [
            ModelGroup([
                Model('LLM-Research/Llama-3.2-11B-Vision-Instruct', 'meta-llama/Llama-3.2-11B-Vision-Instruct'),
                Model('LLM-Research/Llama-3.2-90B-Vision-Instruct', 'meta-llama/Llama-3.2-90B-Vision-Instruct'),
                Model('LLM-Research/Llama-3.2-11B-Vision', 'meta-llama/Llama-3.2-11B-Vision'),
                Model('LLM-Research/Llama-3.2-90B-Vision', 'meta-llama/Llama-3.2-90B-Vision'),
            ])
        ],
        Llama3_2VisionLoader,
        template=TemplateType.llama3_2_vision,
        requires=['transformers>=4.45'],
        architectures=['MllamaForConditionalGeneration'],
        model_arch=ModelArch.llama3_2_vision,
        tags=['vision'],
    ))


class Llama4Loader(ModelLoader):
    """Model loader for Llama-4 checkpoints (``Llama4ForConditionalGeneration``)."""

    def get_model(self, model_dir: str, *args, **kwargs) -> PreTrainedModel:
        """Load the model, using ``Llama4ForConditionalGeneration`` as the model class.

        An ``auto_model_cls`` already configured on the loader takes precedence.
        All arguments are forwarded unchanged to :meth:`ModelLoader.get_model`.
        """
        from transformers import Llama4ForConditionalGeneration
        chosen_cls = self.auto_model_cls
        if not chosen_cls:
            chosen_cls = Llama4ForConditionalGeneration
        self.auto_model_cls = chosen_cls
        return super().get_model(model_dir, *args, **kwargs)


# Register the Llama-4 multimodal checkpoints (Scout and Maverick, including an
# FP8 Maverick-Instruct variant). They are loaded via ``Llama4Loader`` as
# ``Llama4ForConditionalGeneration`` with the ``llama4`` template, and need
# ``transformers>=4.51``.
register_model(
    ModelMeta(
        MLLMModelType.llama4,
        [
            ModelGroup([
                Model('LLM-Research/Llama-4-Scout-17B-16E', 'meta-llama/Llama-4-Scout-17B-16E'),
                Model('LLM-Research/Llama-4-Maverick-17B-128E', 'meta-llama/Llama-4-Maverick-17B-128E'),
                Model('LLM-Research/Llama-4-Scout-17B-16E-Instruct', 'meta-llama/Llama-4-Scout-17B-16E-Instruct'),
                Model('LLM-Research/Llama-4-Maverick-17B-128E-Instruct-FP8',
                      'meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8'),
                Model('LLM-Research/Llama-4-Maverick-17B-128E-Instruct',
                      'meta-llama/Llama-4-Maverick-17B-128E-Instruct'),
            ])
        ],
        Llama4Loader,
        template=TemplateType.llama4,
        requires=['transformers>=4.51'],
        model_arch=ModelArch.llama4,
        architectures=['Llama4ForConditionalGeneration'],
        tags=['vision'],
    ))


class Llama3OmniLoader(ModelLoader):
    """Model loader for ICTNLP's Llama-3.1-8B-Omni speech model.

    The model classes come from the ``omni_speech`` package of the LLaMA-Omni
    GitHub repository. That repository is cloned on demand when no local
    checkout is configured. The model also expects a Whisper ``large-v3``
    checkpoint inside ``model_dir``.
    """

    def get_model(self, model_dir: str, config, processor, model_kwargs) -> PreTrainedModel:
        """Build the Omni model and place it on a single device.

        Args:
            model_dir: Local directory of the model weights. The Whisper
                ``large-v3.pt`` checkpoint is expected (and fetched) here.
            config: Model config. ``config.speech_encoder`` is set to the
                Whisper checkpoint path.
            processor: Forwarded unchanged to :meth:`ModelLoader.get_model`.
            model_kwargs: Loading kwargs. Must contain ``'device_map'``. That
                entry is temporarily replaced by ``None`` because
                ``device_map='auto'`` is not supported for this model.

        Returns:
            The loaded model. It is moved to ``get_device()`` when
            ``device_map == 'auto'``, and to the given ``device_map`` otherwise.

        Raises:
            KeyError: if ``model_kwargs`` has no ``'device_map'`` entry.
        """
        local_repo_path = self.local_repo_path
        if not local_repo_path:
            local_repo_path = git_clone_github('https://github.com/ictnlp/LLaMA-Omni')
        # Use the resolved path. ``self.local_repo_path`` may be empty here when
        # the repo was just cloned, which would leave ``omni_speech`` unimportable.
        # Skip the append if the path is already present, so repeated loads do
        # not grow sys.path.
        if local_repo_path not in sys.path:
            sys.path.append(local_repo_path)
        from omni_speech.model import OmniSpeech2SLlamaForCausalLM, OmniSpeechLlamaForCausalLM
        import whisper
        config.speech_encoder = os.path.join(model_dir, 'large-v3.pt')
        if not os.path.exists(config.speech_encoder):
            # Called only for its side effect of fetching the checkpoint into
            # ``model_dir``; the returned Whisper model is discarded.
            whisper.load_model('large-v3', download_root=model_dir)
        self.auto_model_cls = self.auto_model_cls or OmniSpeech2SLlamaForCausalLM
        # Drop the repo's own ``forward``/``generate`` overrides so that attribute
        # lookup falls back to the implementations inherited from each class's
        # parent. Each deletion is guarded independently. An attribute that is
        # absent (already removed by an earlier load, or never defined on that
        # class) must not stop the cleanup of the others.
        for model_cls in (OmniSpeech2SLlamaForCausalLM, OmniSpeechLlamaForCausalLM):
            for key in ('forward', 'generate'):
                try:
                    delattr(model_cls, key)
                except AttributeError:
                    pass
        # device_map='auto' is not supported: load without a device map, then
        # move the whole model to one device.
        device_map = model_kwargs['device_map']
        model_kwargs['device_map'] = None
        model = super().get_model(model_dir, config, processor, model_kwargs)
        model.to(get_device() if device_map == 'auto' else device_map)
        return model


# Register ICTNLP's Llama-3.1-8B-Omni speech model. It is loaded via
# ``Llama3OmniLoader`` with the ``llama3_1_omni`` template, and requires the
# ``openai-whisper`` package, which that loader imports for the speech encoder.
register_model(
    ModelMeta(
        MLLMModelType.llama3_1_omni,
        [ModelGroup([
            Model('ICTNLP/Llama-3.1-8B-Omni', 'ICTNLP/Llama-3.1-8B-Omni'),
        ], )],
        Llama3OmniLoader,
        template=TemplateType.llama3_1_omni,
        architectures=['OmniSpeech2SLlamaForCausalLM'],
        model_arch=ModelArch.llama3_1_omni,
        requires=['openai-whisper'],
        tags=['audio'],
    ))
