# Copyright (c) ModelScope Contributors. All rights reserved.
import os
from dataclasses import dataclass, field, fields
from typing import Any, Dict, List, Literal, Optional, Union

import json
from packaging import version

import swift
from swift.hub import get_hub
from swift.model import get_ckpt_dir, get_model_processor, load_by_unsloth
from swift.plugins import extra_tuners
from swift.ray import RayArguments
from swift.template import Template, get_template
from swift.utils import (Processor, check_json_format, get_dist_setting, get_logger, import_external_file, is_dist,
                         is_master, json_parse_to_dict, safe_snapshot_download, set_device, use_hf_hub)
from .data_args import DataArguments
from .generation_args import GenerationArguments
from .model_args import ModelArguments
from .quant_args import QuantizeArguments
from .template_args import TemplateArguments

logger = get_logger()


def get_supported_tuners():
    """Return the set of tuner names that may be used as `train_type`.

    The result is the union of the built-in tuner names and the keys of
    `extra_tuners` (imported from `swift.plugins`).
    """
    builtin_tuners = {
        'lora', 'full', 'longlora', 'adalora', 'llamapro', 'adapter', 'vera', 'boft', 'fourierft', 'reft', 'bone'
    }
    return builtin_tuners.union(extra_tuners.keys())


@dataclass
class CompatArguments:
    """Compatibility arguments that are folded into `model` / `adapters`.

    Args:
        ckpt_dir (Optional[str]): A checkpoint directory. Depending on its contents it is registered either as
            an adapter or as the model. Default is None.
        lora_modules (List[str]): Entries appended to `adapters`. Default is [].
    """
    ckpt_dir: Optional[str] = None
    lora_modules: List[str] = field(default_factory=list)

    def _handle_ckpt_dir(self: 'BaseArguments'):
        """Route `ckpt_dir` to `adapters` (adapter checkpoint) or `model` (anything else).

        A directory counts as an adapter checkpoint when it contains `adapter_config.json`,
        `default/adapter_config.json`, or a `reft` entry.
        """
        assert os.path.isdir(self.ckpt_dir), f'self.ckpt_dir: {self.ckpt_dir}'
        adapter_markers = (
            os.path.join(self.ckpt_dir, 'adapter_config.json'),
            os.path.join(self.ckpt_dir, 'default', 'adapter_config.json'),
            os.path.join(self.ckpt_dir, 'reft'),
        )
        if not any(os.path.exists(marker) for marker in adapter_markers):
            self.model = self.ckpt_dir
        elif self.ckpt_dir in self.adapters:
            # Already listed as an adapter: return without touching `adapters` or resetting `ckpt_dir`.
            return
        else:
            # Put the checkpoint in front of any adapters that were passed explicitly.
            self.adapters.insert(0, self.ckpt_dir)
        self.ckpt_dir = None

    def __post_init__(self: 'BaseArguments'):
        """Fold `ckpt_dir` and `lora_modules` into `model` / `adapters`."""
        legacy_ckpt_dir = self.ckpt_dir
        if legacy_ckpt_dir is not None:
            self._handle_ckpt_dir()

        if len(self.lora_modules) == 0:
            return
        self.adapters += self.lora_modules


@dataclass
class BaseArguments(CompatArguments, GenerationArguments, QuantizeArguments, DataArguments, TemplateArguments,
                    ModelArguments, RayArguments):
    """BaseArguments class is a dataclass that inherits from multiple argument classes.

    This class consolidates arguments from CompatArguments, GenerationArguments, QuantizeArguments, DataArguments,
    TemplateArguments, ModelArguments, RayArguments.

    Args:
        tuner_backend (str): The tuner backend to use. Choices are 'peft' or 'unsloth'. Default is 'peft'.
        train_type (str): The training type. Choices include 'lora', 'full', 'longlora', 'adalora', 'llamapro',
            'adapter', 'vera', 'boft', 'fourierft', 'reft', 'bone', plus any tuner registered in `extra_tuners`.
            Default is 'lora'.
        adapters (List[str]): A list of adapter IDs or paths. This is typically used for inference or deployment.
            It can also resume training by only loading adapter weights, differing from `resume_from_checkpoint`
            which also loads optimizer states. Default is [].
        external_plugins (List[str]): A list of external 'plugin.py' files to be registered and imported into
            the plugin module. Default is [].
        custom_register_path (List[str]): Kept for swift3.x compatibility; its entries are appended to
            `external_plugins`. Please use `external_plugins` instead. Default is [].
        seed (int): The global random seed for reproducibility. Note that this does not affect `data_seed`,
            which controls dataset randomization. Default is 42.
        model_kwargs (Optional[Union[dict, str]]): Additional keyword arguments for specific models, passed as a
            JSON string (e.g., '{"key": "value"}') or a dict. Each key is upper-cased and exported as an
            environment variable. It's recommended to use the same arguments for inference as for training.
            Default is None.
        load_args (bool): Whether to load `args.json` from a checkpoint when using `--resume_from_checkpoint`,
            `--model`, or `--adapters`. Defaults to True for inference/export and False for training. Usually,
            this does not need to be modified. Default is True.
        load_data_args (bool): If True, will also load data-related arguments from `args.json`. This is useful
            for running inference on the same validation split used during training. Default is False.
        packing (bool): Whether to enable packing of datasets. Default is False.
        packing_length (Optional[int]): Length of packing. If None while `packing` is enabled, it falls back to
            `max_length`. Default is None.
        packing_num_proc (int): Number of processes used for packing. Default is 1.
        lazy_tokenize (Optional[bool]): Whether to enable lazy tokenization. If None, the value is chosen
            automatically during initialization. Incompatible with `packing` and `streaming`. Default is None.
        use_hf (bool): Whether to use Hugging Face for downloading/uploading models and datasets. If False,
            ModelScope is used. Default is False.
        hub_token (Optional[str]): The authentication token for ModelScope or Hugging Face Hub. Default is None.
        ddp_timeout (int): Timeout for DDP (Distributed Data Parallel) operations, in seconds. Default is 18000000.
        ddp_backend (Optional[str]): The backend for DDP. Choices include "nccl", "gloo", "mpi", "ccl", "hccl",
            "cncl", "mccl". If None, it will be automatically selected. Default is None.
        ignore_args_error (bool): Whether to ignore argument errors. This is useful for compatibility with Jupyter
            notebooks. Default is False.
        use_swift_lora (bool): Whether to use swift lora. This is a compatible argument. Default is False.
    """
    tuner_backend: Literal['peft', 'unsloth'] = 'peft'
    # The help text lists the tuners known at class-definition time (built-ins plus registered extra tuners).
    train_type: str = field(default='lora', metadata={'help': f'train_type choices: {list(get_supported_tuners())}'})
    adapters: List[str] = field(default_factory=list)
    external_plugins: List[str] = field(default_factory=list)
    # This parameter is kept for swift3.x compatibility. Please use `external_plugins` as a replacement.
    custom_register_path: List[str] = field(default_factory=list)

    seed: int = 42
    # Accepts a JSON string or a dict; normalised to a dict in `_init_model_kwargs`.
    model_kwargs: Optional[Union[dict, str]] = None
    load_args: bool = True
    load_data_args: bool = False
    # dataset
    packing: bool = False
    packing_length: Optional[int] = None
    packing_num_proc: int = 1
    # None: decided automatically in `_init_lazy_tokenize`.
    lazy_tokenize: Optional[bool] = None
    # hub
    use_hf: bool = False
    # None: use env var `MODELSCOPE_API_TOKEN`
    hub_token: Optional[str] = field(
        default=None, metadata={'help': 'SDK token can be found in https://modelscope.cn/my/myaccesstoken'})
    # dist
    ddp_timeout: int = 18000000
    ddp_backend: Optional[str] = None

    # extra
    ignore_args_error: bool = False  # True: notebook compatibility
    use_swift_lora: bool = False  # True for using tuner_backend == swift, don't specify this unless you know what you are doing # noqa

    def _prepare_training_args(self, training_args: Dict[str, Any]) -> None:
        pass

    def _init_lazy_tokenize(self):
        if self.lazy_tokenize is None:
            if self.cached_dataset or self.cached_val_dataset:
                self.lazy_tokenize = False
            elif (self.model_meta is not None and self.model_meta.is_multimodal and not self.streaming
                  and not self.packing and not getattr(self, 'group_by_length', False)):
                self.lazy_tokenize = True
            else:
                self.lazy_tokenize = False
            logger.info(f'Setting args.lazy_tokenize: {self.lazy_tokenize}')
        if self.lazy_tokenize:
            if self.packing:
                raise ValueError('Packing and lazy_tokenize are incompatible.')
            if self.streaming:
                raise ValueError('Streaming and lazy_tokenize are incompatible.')

    def _import_external_plugins(self):
        if isinstance(self.external_plugins, str):
            self.external_plugins = [self.external_plugins]
        # swift v3.x compatibility
        if isinstance(self.custom_register_path, str):
            self.custom_register_path = [self.custom_register_path]
        if self.custom_register_path:
            self.external_plugins += self.custom_register_path

        if not self.external_plugins:
            return
        for external_plugin in self.external_plugins:
            import_external_file(external_plugin)
        logger.info(f'Successfully imported external_plugins: {self.external_plugins}.')

    @staticmethod
    def _check_is_adapter(adapter_dir: str) -> bool:
        if (os.path.exists(os.path.join(adapter_dir, 'adapter_config.json'))
                or os.path.exists(os.path.join(adapter_dir, 'default', 'adapter_config.json'))
                or os.path.exists(os.path.join(adapter_dir, 'reft'))):
            return True
        return False

    def _init_adapters(self):
        if isinstance(self.adapters, str):
            self.adapters = [self.adapters]
        self.adapters = [
            safe_snapshot_download(adapter, use_hf=self.use_hf, hub_token=self.hub_token) for adapter in self.adapters
        ]

    def __post_init__(self):
        """Finish initialisation after the dataclass fields have been assigned.

        The steps run in a fixed order: record the swift version, settle the hub choice, fold the compatibility
        arguments into `model`/`adapters`, resolve adapters and the checkpoint directory (optionally loading
        `args.json`), import external plugins, export `model_kwargs`, read the distributed settings, validate
        the adapters, run the parent `__post_init__` hooks, derive `max_length` / `packing_length` /
        `lazy_tokenize`, and finally attempt a hub login.
        """
        # Record the running swift version on the instance.
        self.swift_version = swift.__version__
        # Either the argument or `use_hf_hub()` selects Hugging Face; the choice is also exported via `USE_HF`.
        if self.use_hf or use_hf_hub():
            self.use_hf = True
            os.environ['USE_HF'] = '1'
        # Fold `ckpt_dir` / `lora_modules` into `model` / `adapters`.
        CompatArguments.__post_init__(self)
        self._init_adapters()
        # Determines `ckpt_dir` and, when `load_args` is set, loads `args.json` from it.
        # NOTE(review): `external_plugins` can be restored from `args.json`, which presumably is why this runs
        # before the plugin import below - confirm.
        self._init_ckpt_dir()
        self._import_external_plugins()
        self._init_model_kwargs()
        # The Seq2SeqTrainingArguments has a property called world_size, which cannot be assigned a value.
        self.rank, self.local_rank, self.global_world_size, self.local_world_size = get_dist_setting()
        logger.info(f'rank: {self.rank}, local_rank: {self.local_rank}, '
                    f'world_size: {self.global_world_size}, local_world_size: {self.local_world_size}')
        # Built-in tuners require every adapter path to look like an adapter checkpoint;
        # tuners registered in `extra_tuners` skip this check.
        if self.train_type not in extra_tuners:
            for adapter in self.adapters:
                assert self._check_is_adapter(adapter), (
                    f'`{adapter}` is not an adapter, please try using `--model` to pass it.')
        # Parent hooks are invoked explicitly, in this order.
        ModelArguments.__post_init__(self)
        QuantizeArguments.__post_init__(self)
        TemplateArguments.__post_init__(self)
        DataArguments.__post_init__(self)
        RayArguments.__post_init__(self)
        # `_init_stream` is not defined in this class; presumably inherited from a parent - TODO confirm.
        self._init_stream()
        # Fall back to the model's maximum length when `max_length` was not given.
        if self.max_length is None and self.model_info is not None:
            self.max_length = self.model_info.max_model_len
        # Packing without an explicit length uses `max_length`.
        if self.packing and self.packing_length is None:
            self.packing_length = self.max_length
        self._init_lazy_tokenize()
        self.hub = get_hub(self.use_hf)
        if self.hub.try_login(self.hub_token):
            logger.info('hub login successful!')

    def _init_model_kwargs(self):
        """Convert `model_kwargs` with `json_parse_to_dict` and export each entry as an environment variable.

        The environment variable name is the upper-cased key and its value is `str(value)`.
        """
        parsed_kwargs: Dict[str, Any] = json_parse_to_dict(self.model_kwargs)
        self.model_kwargs = parsed_kwargs
        for name, value in parsed_kwargs.items():
            os.environ[name.upper()] = str(value)

    @property
    def is_adapter(self) -> bool:
        return self.train_type not in {'full'}

    @property
    def supported_tuners(self):
        """The tuner names returned by the module-level `get_supported_tuners()`."""
        tuners = get_supported_tuners()
        return tuners

    @property
    def adapters_can_be_merged(self):
        return {'lora', 'longlora', 'llamapro', 'adalora'}

    @classmethod
    def from_pretrained(cls, checkpoint_dir: str):
        """Build an instance from the `args.json` stored in `checkpoint_dir` without running `__init__`.

        Data arguments are loaded as well (`load_data_args` is forced to True). Any `BaseArguments` field that
        is still missing on the instance afterwards is set to None.
        """
        instance = super().__new__(cls)
        instance.load_data_args = True
        instance.ckpt_dir = checkpoint_dir
        instance.load_args_from_ckpt()
        field_names = [f.name for f in fields(BaseArguments)]
        for field_name in field_names:
            if hasattr(instance, field_name):
                continue
            setattr(instance, field_name, None)
        return instance

    def _init_ckpt_dir(self, adapters=None):
        """Set `ckpt_dir` from the model/adapter sources and optionally load `args.json` from it.

        Args:
            adapters: Optional adapter path(s) taking precedence over `self.adapters`; a bare string is wrapped
                into a list.
        """
        # compat megatron
        model_source = self.model or getattr(self, 'mcore_model', None) or getattr(self, 'load', None)
        adapter_sources = (
            adapters or self.adapters or getattr(self, 'mcore_adapters', None)
            or getattr(self, 'adapter_load', None))
        if isinstance(adapter_sources, str):
            adapter_sources = [adapter_sources]
        self.ckpt_dir = get_ckpt_dir(model_source, adapter_sources)
        should_load = self.ckpt_dir and self.load_args
        if should_load:
            self.load_args_from_ckpt()

    def load_args_from_ckpt(self) -> None:
        """Restore selected arguments from `<ckpt_dir>/args.json`.

        Saved values that are None are ignored. For the remaining keys:
          * force-load keys (and data argument keys when `load_data_args` is set) always overwrite the
            current value;
          * load keys only fill in a current value that is None or an empty list/tuple.

        `model_type` is not restored when the file has no `swift_version` or one older than '4.0.0.dev'.
        """
        args_path = os.path.join(self.ckpt_dir, 'args.json')
        assert os.path.exists(args_path), f'args_path: {args_path}'
        with open(args_path, 'r', encoding='utf-8') as f:
            old_args = json.load(f)

        # Always taken from the checkpoint.
        force_load_keys = {
            # base_args
            'train_type',
            # model_args
            'task_type',
            # quant_args
            'bnb_4bit_quant_type',
            'bnb_4bit_use_double_quant',
        }
        # Taken from the checkpoint only when the current value is None or an empty list/tuple.
        load_keys = {
            'external_plugins',
            # model_args
            'model',
            'model_type',
            'model_revision',
            'torch_dtype',
            'attn_impl',
            'new_special_tokens',
            'num_labels',
            'problem_type',
            'rope_scaling',
            'max_model_len',
            # quant_args
            'quant_method',
            'quant_bits',
            'hqq_axis',
            'bnb_4bit_compute_dtype',
            # template_args
            'template',
            'system',
            'truncation_strategy',
            'agent_template',
            'norm_bbox',
            'use_chat_template',
            'response_prefix',
        }
        data_keys = {data_field.name for data_field in fields(DataArguments)}

        saved_version = old_args.get('swift_version')
        if saved_version is None or version.parse(saved_version) < version.parse('4.0.0.dev'):
            load_keys.discard('model_type')

        for key, old_value in old_args.items():
            if old_value is None:
                continue
            if key in force_load_keys or (self.load_data_args and key in data_keys):
                setattr(self, key, old_value)
            current = getattr(self, key, None)
            is_unset = current is None or (isinstance(current, (list, tuple)) and len(current) == 0)
            if key in load_keys and is_unset:
                setattr(self, key, old_value)
        logger.info(f'Successfully loaded {args_path}.')

    def save_args(self, output_dir=None) -> None:
        """Write the instance `__dict__` to `<output_dir>/args.json`; does nothing unless `is_master()` is true.

        Args:
            output_dir: Target directory; falls back to `self.output_dir` when falsy. Created if missing.
        """
        if not is_master():
            return
        target_dir = output_dir or self.output_dir
        os.makedirs(target_dir, exist_ok=True)
        fpath = os.path.join(target_dir, 'args.json')
        logger.info(f'The {self.__class__.__name__} will be saved in: {fpath}')
        with open(fpath, 'w', encoding='utf-8') as f:
            # The attributes go through `check_json_format` before being dumped.
            json.dump(check_json_format(self.__dict__), f, ensure_ascii=False, indent=2)

    def _init_device(self):
        """Call `set_device()` when `is_dist()` reports a distributed run; otherwise do nothing."""
        if not is_dist():
            return
        set_device()

    def get_template(self, processor: Optional[Processor] = None, **kwargs) -> Template:
        """Build the template for these arguments.

        Args:
            processor: Processor handed to the template. When None, it is taken from
                `self.get_model_processor(load_model=False)`.
            **kwargs: Only `template_type` is read; when present it overrides `self.template`.

        Returns:
            The template created by the module-level `get_template`.
        """
        if processor is None:
            processor = self.get_model_processor(load_model=False)[1]
        template_kwargs = self.get_template_kwargs()
        template_kwargs['template_type'] = kwargs['template_type'] if 'template_type' in kwargs else self.template
        return get_template(processor, **template_kwargs)

    def get_model_processor(self,
                            *,
                            model=None,
                            model_type=None,
                            model_revision=None,
                            task_type=None,
                            num_labels=None,
                            **kwargs):
        """Load the model and processor described by these arguments.

        With `tuner_backend == 'unsloth'` the call is delegated to `load_by_unsloth(self)` and every parameter
        is ignored. Otherwise the keyword arguments start from `self.get_model_kwargs()`, are updated with
        `kwargs`, and the explicit parameters below are applied last.

        Args:
            model: Overrides `self.model` when truthy.
            model_type: Overrides `self.model_type` when truthy.
            model_revision: Overrides `self.model_revision` when truthy.
            task_type: Overrides `self.task_type` when truthy.
            num_labels: Overrides `self.num_labels` when truthy.
            **kwargs: Extra keyword arguments forwarded to the module-level `get_model_processor`.
        """
        if self.tuner_backend == 'unsloth':
            return load_by_unsloth(self)
        load_kwargs = self.get_model_kwargs()
        load_kwargs.update(kwargs)
        # compat rlhf
        load_kwargs.update(
            model_id_or_path=model or self.model,
            model_type=model_type or self.model_type,
            model_revision=model_revision or self.model_revision,
            task_type=task_type or self.task_type,
            num_labels=num_labels or self.num_labels,
        )

        return get_model_processor(**load_kwargs)
