import atexit
import logging
import threading
import time
from concurrent.futures import ThreadPoolExecutor
from dataclasses import asdict
from typing import List, Optional, Union
from urllib.parse import urlparse

import requests
import torch
from packaging import version
from pydantic import ValidationError
from requests import ConnectionError
from torch import nn

from swift.infer_engine import AdapterRequest, RequestConfig
from swift.infer_engine.protocol import ChatCompletionResponse, RolloutInferRequest, RolloutOutput
from swift.metrics import Metric
from swift.utils import is_trl_available, is_vllm_ascend_available, is_vllm_available
from .utils import format_host_for_url, is_valid_ipv6_address, peft_config_to_dict, resolve_hostname

# Optional vLLM dependencies. PyNcclCommunicator and StatelessProcessGroup are only used by
# VLLMClient.init_communicator to build the weight-broadcast group. On Ascend, the HCCL
# communicator is imported under the same name so the rest of the file uses it unchanged.
if is_vllm_available():
    from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
    from vllm.distributed.utils import StatelessProcessGroup

    if is_vllm_ascend_available():
        from vllm_ascend.distributed.device_communicators.pyhccl import PyHcclCommunicator as PyNcclCommunicator  # noqa

# The installed trl version gates the optional `client_device_uuid` field sent by
# VLLMClient.init_communicator. NOTE: the name is spelled `trl_verison` (sic) and only exists
# when trl is installed; any unguarded reference raises NameError otherwise.
if is_trl_available():
    import trl
    trl_verison = version.parse(trl.__version__)

logger = logging.getLogger(__name__)


class VLLMClient:
    """Client for one or more vLLM rollout servers.

    Inference requests are split across the servers and sent concurrently over HTTP (one worker
    thread per server). Weight updates work in two steps. First, each server is told over HTTP what
    is coming. Then the tensor data is broadcast through a `PyNcclCommunicator` (the HCCL variant
    on Ascend) created by `init_communicator`.
    """

    def __init__(self,
                 base_urls: Optional[List[str]] = None,
                 hosts: Optional[List[str]] = None,
                 server_ports: Optional[List[int]] = None,
                 group_ports: Optional[Union[int, List[int]]] = None,
                 connection_timeout: float = 240.0):
        """Configure the server endpoints and block until every server answers its health check.

        Args:
            base_urls: Full server URLs. When given, `hosts` and `server_ports` are ignored.
            hosts: Server hosts, used with `server_ports` when `base_urls` is None.
                Defaults to `['0.0.0.0']`.
            server_ports: HTTP ports parallel to `hosts`. Defaults to `[8000]`.
            group_ports: Rendezvous ports for the weight-broadcast groups. An int means consecutive
                ports starting at it. A list/tuple gives one port per server. None means 51216 + i.
            connection_timeout: Seconds to keep retrying each server's `/health/` endpoint.

        Raises:
            ImportError: If vLLM is not installed.
            ValueError: On mismatched `hosts`/`server_ports`, no servers, or invalid `group_ports`.
            ConnectionError: If some server is not healthy within `connection_timeout`.
        """
        if not is_vllm_available():
            raise ImportError('vLLM is not installed. Please install it with `pip install vllm`.')

        if base_urls is not None:
            self.base_urls = []
            self.hosts = []
            for url in base_urls:
                parsed_url = urlparse(url)
                # Use resolve_hostname instead of gethostbyname for IPv6 support
                self.hosts.append(resolve_hostname(parsed_url.hostname))
                scheme = parsed_url.scheme or 'http'
                # Drop a trailing '/' so endpoint URLs like '{base_url}/health/' never contain '//'.
                path = parsed_url.path.rstrip('/')
                self.base_urls.append(f'{scheme}://{parsed_url.netloc}{path}')
        else:
            # None defaults avoid shared mutable default arguments. The effective default is a
            # single server at 0.0.0.0:8000.
            hosts = ['0.0.0.0'] if hosts is None else list(hosts)
            server_ports = [8000] if server_ports is None else list(server_ports)
            if len(hosts) != len(server_ports):
                raise ValueError('host and server_port must have same length when lists are provided')
            # Format IPv6 addresses correctly in URLs (wrap with brackets)
            self.base_urls = [f'http://{format_host_for_url(h)}:{p}' for h, p in zip(hosts, server_ports)]
            self.hosts = hosts

        self.num_servers = len(self.base_urls)
        if self.num_servers == 0:
            # Fail early. Otherwise ThreadPoolExecutor(max_workers=0) or infer()'s chunking would fail later.
            raise ValueError('At least one vLLM server must be configured')

        if group_ports is None:
            group_ports = 51216
        if isinstance(group_ports, int):
            self.group_ports = [group_ports + i for i in range(self.num_servers)]
        elif isinstance(group_ports, (list, tuple)) and len(group_ports) == self.num_servers:
            self.group_ports = list(group_ports)
        else:
            raise ValueError('group_port must be int or list of length num_servers')

        self.sessions = [requests.Session() for _ in range(self.num_servers)]
        self.pynccl_comms = []
        self._close_registered = False
        self.check_server(connection_timeout)

    def check_server(self, total_timeout: float = 0.0, retry_interval: float = 2.0):
        """Poll every server's `/health/` endpoint concurrently until all return HTTP 200.

        Args:
            total_timeout: Seconds each server keeps being retried. With 0, each server is probed once.
            retry_interval: Request timeout for each probe, and the pause between probes, in seconds.

        Raises:
            ConnectionError: If any server never returned HTTP 200.
        """
        server_status = [False] * self.num_servers

        def check_single_server(i):
            start_time = time.time()
            url = f'{self.base_urls[i]}/health/'
            while True:
                try:
                    response = requests.get(url, timeout=retry_interval)
                    if response.status_code == 200:
                        server_status[i] = True
                        return
                except Exception:
                    # Server not reachable yet (refused, timed out, ...): deliberately keep retrying.
                    pass
                if time.time() - start_time >= total_timeout:
                    return
                time.sleep(retry_interval)

        threads = [
            threading.Thread(target=check_single_server, args=(i, ), daemon=True) for i in range(self.num_servers)
        ]
        for t in threads:
            t.start()

        # A worker may start its last probe just before `total_timeout` elapses. That probe can
        # take up to `retry_interval` to connect plus `retry_interval` to read, after a final
        # sleep. Joining with only `total_timeout` could ignore that in-flight probe and report a
        # healthy server as unreachable. With total_timeout=0, join(0) returned before any probe
        # finished. All joins therefore share one deadline that includes this grace period.
        deadline = time.time() + total_timeout + 3 * retry_interval
        for t in threads:
            t.join(max(0.0, deadline - time.time()))

        if not all(server_status):
            failed_servers = [self.base_urls[i] for i, status in enumerate(server_status) if not status]
            raise ConnectionError(f'Servers not reachable after {total_timeout}s: {failed_servers}')

    def _run_on_all_servers(self, fn, error_prefix: str = 'Multiple errors'):
        """Call ``fn(i)`` for every server index concurrently and wait for all calls to finish.

        Each call's exception is recorded per server, so one failing server does not stop the
        others. All recorded failures are then raised together as a single RuntimeError, chained
        to the first failure.
        """
        errors = [None] * self.num_servers

        def _guarded(i):
            try:
                fn(i)
            except Exception as e:
                errors[i] = e

        with ThreadPoolExecutor(max_workers=self.num_servers) as executor:
            list(executor.map(_guarded, range(self.num_servers)))

        all_errors = [e for e in errors if e is not None]
        if all_errors:
            raise RuntimeError(f'{error_prefix}: {all_errors}') from all_errors[0]

    def _post_checked(self, i: int, endpoint: str, action: str, payload: Optional[dict] = None):
        """POST `payload` as JSON to `{base_url}/{endpoint}/` on server `i` and return the response.

        Raises:
            Exception: ``'Server {i} {action} failed: ...'`` unless the reply is HTTP 200.
        """
        response = self.sessions[i].post(f'{self.base_urls[i]}/{endpoint}/', json=payload)
        if response.status_code != 200:
            raise Exception(f'Server {i} {action} failed: {response.text}')
        return response

    def _broadcast_to_server(self, i: int, tensors):
        """Broadcast each tensor, in order, from this client over server `i`'s communicator, then barrier."""
        comm = self.pynccl_comms[i]
        for tensor in tensors:
            comm.broadcast(tensor, src=comm.rank)
        comm.group.barrier()

    @staticmethod
    def _new_lora_int_id() -> int:
        """Return a time-derived LoRA id in [1, 0x7FFFFFFF].

        The ``+ 1`` keeps the id non-zero. LoRA ids are presumably required to be positive on the
        server side (vLLM LoRA ids start at 1), so confirm against the server before changing this.
        """
        return int(time.time_ns() % 0x7FFFFFFF) + 1

    def infer(
        self,
        infer_requests: List[RolloutInferRequest],
        request_config: Optional[RequestConfig] = None,
        metrics: Optional[List[Metric]] = None,
        *,
        use_tqdm: Optional[bool] = None,
        adapter_request: Optional[AdapterRequest] = None,
    ):
        """Split `infer_requests` across the servers, run them concurrently, and merge the results.

        Requests are cut into contiguous chunks of ceil(n / num_servers). Each server's results are
        concatenated in server order, so the output follows the input order.

        Args:
            infer_requests: Requests to run. `RolloutInferRequest` dataclasses are converted with
                `asdict`. Other items are sent as-is, so presumably they are already JSON-ready.
            request_config: Generation config. A `RequestConfig` is converted with `asdict`.
            metrics, use_tqdm, adapter_request: Forwarded unchanged in the JSON body.

        Returns:
            A list of `RolloutOutput` items. An item that fails `RolloutOutput` validation is
            returned as a `ChatCompletionResponse` instead.

        Raises:
            RuntimeError: If any server request fails (chained to the first failure).
        """
        if not hasattr(self, 'use_async_engine') or not hasattr(self, 'use_gym_env'):
            self.get_engine_type()

        n = len(infer_requests)
        if n == 0:
            # Nothing to run. This also avoids a zero step in range() below (chunk_size would be 0).
            return []

        chunk_size = (n + self.num_servers - 1) // self.num_servers
        chunks = [infer_requests[i:i + chunk_size] for i in range(0, n, chunk_size)]
        # Trailing servers may receive an empty chunk. They are still called, as before.
        chunks += [[] for _ in range(self.num_servers - len(chunks))]

        if isinstance(request_config, RequestConfig):
            request_config = asdict(request_config)

        results = [None] * self.num_servers

        def process_chunk(i):
            chunk = chunks[i]
            if len(chunk) > 0 and isinstance(chunk[0], RolloutInferRequest):
                chunk = [asdict(req) for req in chunk]

            response = self.sessions[i].post(
                f'{self.base_urls[i]}/infer/',
                json={
                    'infer_requests': chunk,
                    'request_config': request_config,
                    'metrics': metrics,
                    'use_tqdm': use_tqdm,
                    'adapter_request': adapter_request,
                },
            )
            if response.status_code != 200:
                raise Exception(f'Server {i} failed: {response.status_code}, {response.text}')

            parsed: List[Union[RolloutOutput, ChatCompletionResponse]] = []
            for item in response.json():
                try:
                    parsed.append(RolloutOutput.model_validate(item))
                except ValidationError:
                    parsed.append(ChatCompletionResponse(**item))
            results[i] = parsed

        self._run_on_all_servers(process_chunk)
        return [res for server_results in results for res in server_results]

    def init_communicator(self, device: Union[int, str] = 0):
        """Create one weight-broadcast communicator per server.

        For server i, the client joins a group of size `world_size + 1` as its last rank
        (rank == server world size). The group meets at `self.hosts[i]:self.group_ports[i]`.

        Args:
            device: Local device for the communicator. It is also used to read the device UUID.

        Raises:
            Exception: If a server's `/get_world_size/` or `/init_communicator/` request fails.
        """
        self.pynccl_comms = []

        # Loop-invariant: the optional device UUID depends only on `device`. `trl_verison` exists
        # only when trl is installed, so availability is checked first to avoid a NameError.
        kwargs = {}
        if is_trl_available() and trl_verison >= version.parse('0.20.0'):
            try:
                kwargs['client_device_uuid'] = str(torch.cuda.get_device_properties(device).uuid)
            except Exception:
                # Placeholder UUID when the device properties cannot be queried.
                kwargs['client_device_uuid'] = '42'

        for i in range(self.num_servers):
            response = self.sessions[i].get(f'{self.base_urls[i]}/get_world_size/')
            if response.status_code != 200:
                raise Exception(f'Server {i} failed: {response.text}')
            vllm_world_size = response.json()['world_size']

            # The client takes the one extra rank after all of the server's workers.
            world_size = vllm_world_size + 1
            rank = vllm_world_size

            # Use '::' for IPv6 hosts, '0.0.0.0' for IPv4 hosts
            bind_host = '::' if is_valid_ipv6_address(self.hosts[i]) else '0.0.0.0'
            self._post_checked(i, 'init_communicator', 'init', {
                'host': bind_host,
                'port': self.group_ports[i],
                'world_size': world_size,
                **kwargs
            })

            # Brief pause, presumably to let the server side start listening before we connect.
            time.sleep(0.1)

            pg = StatelessProcessGroup.create(
                host=self.hosts[i], port=self.group_ports[i], rank=rank, world_size=world_size)
            self.pynccl_comms.append(PyNcclCommunicator(pg, device=device))

        # Register the exit hook only once, even if the communicators are re-initialized.
        if not getattr(self, '_close_registered', False):
            atexit.register(self.close_communicator)
            self._close_registered = True

    def update_named_param(self, name: str, weights: torch.Tensor):
        """Send one named tensor to every server: first its name, dtype and shape, then the data.

        Raises:
            RuntimeError: If any server fails (chained to the first failure).
        """
        payload = {'name': name, 'dtype': str(weights.dtype), 'shape': tuple(weights.shape)}

        def _update_single_server(i):
            self._post_checked(i, 'update_named_param', 'update', payload)
            self._broadcast_to_server(i, [weights])

        self._run_on_all_servers(_update_single_server)

    def update_adapter_flattened_param(self, peft_config, metadatas, flattened_tensor):
        """
        Adds a LoRA adapter to the model on all servers using flattened tensor.

        Args:
            peft_config: PEFT configuration for LoRA adapter.
            metadatas: List of FlattenedTensorMetadata objects (pydantic v2 `model_dump` or v1 `dict`).
            flattened_tensor: The flattened tensor containing all adapter parameters.

        Raises:
            RuntimeError: If any server fails (chained to the first failure).
        """
        payload = {
            'lora_int_id': self._new_lora_int_id(),
            'peft_config': peft_config_to_dict(peft_config),
            'metadatas': [m.model_dump() if hasattr(m, 'model_dump') else m.dict() for m in metadatas],
        }

        def _update_single_server(i):
            self._post_checked(i, 'update_adapter_flattened_param', 'update adapter', payload)
            self._broadcast_to_server(i, [flattened_tensor])

        self._run_on_all_servers(_update_single_server)

    def update_adapter_param(self, peft_config, lora_params):
        """
        Adds a LoRA adapter to the model on all servers without flattening.
        Sends each tensor individually, in `lora_params` iteration order.

        Args:
            peft_config: PEFT configuration for LoRA adapter.
            lora_params: OrderedDict of (name, tensor) pairs for LoRA parameters.

        Raises:
            RuntimeError: If any server fails (chained to the first failure).
        """
        lora_tensors_metadata = [
            {
                'name': name,
                'dtype': str(param.dtype),
                'shape': tuple(param.shape),
                'start_idx': 0,  # Not used in non-flattened mode
                'end_idx': param.numel(),  # Not used in non-flattened mode
                'numel': param.numel(),
            } for name, param in lora_params.items()
        ]
        payload = {
            'lora_int_id': self._new_lora_int_id(),
            'peft_config': peft_config_to_dict(peft_config),
            'lora_tensors_metadata': lora_tensors_metadata,
        }

        def _update_single_server(i):
            self._post_checked(i, 'update_adapter_param', 'update adapter', payload)
            # Tensors are broadcast in the same order as the metadata above.
            self._broadcast_to_server(i, lora_params.values())

        self._run_on_all_servers(_update_single_server)

    def update_flattened_params(self, metadatas, flattened_tensor):
        """
        Updates model parameters using flattened tensor data.

        Args:
            metadatas: List of FlattenedTensorMetadata objects (pydantic v2 `model_dump` or v1 `dict`).
            flattened_tensor: The flattened tensor containing all parameters

        Raises:
            RuntimeError: If any server fails (chained to the first failure).
        """
        payload = {'metadatas': [m.model_dump() if hasattr(m, 'model_dump') else m.dict() for m in metadatas]}

        def _update_single_server(i):
            self._post_checked(i, 'update_flattened_params', 'update flattened params', payload)
            self._broadcast_to_server(i, [flattened_tensor])

        self._run_on_all_servers(_update_single_server)

    def update_model_params(self, model: nn.Module):
        """Send every named parameter of `model` to all servers, one `update_named_param` call each."""
        for name, param in model.named_parameters():
            self.update_named_param(name, param.data)

    def reset_prefix_cache(self):
        """Ask every server to reset its prefix cache.

        Raises:
            RuntimeError: If any server fails (chained to the first failure).
        """
        self._run_on_all_servers(lambda i: self._post_checked(i, 'reset_prefix_cache', 'reset'),
                                 error_prefix='Multiple errors on reset_prefix_cache')

    def get_engine_type(self):
        """Query server 0 for its engine info and cache the flags on this client.

        Sets `use_async_engine`, `enable_multi_turn`, `use_gym_env` and `enable_lora`. All servers
        are assumed to run the same engine type.

        Returns:
            The decoded JSON response.
        """
        response = self.sessions[0].post(f'{self.base_urls[0]}/get_engine_type/')
        if response.status_code != 200:
            raise Exception(f'Engine type request failed: {response.text}')

        result = response.json()
        self.use_async_engine = result['engine_type'] == 'AsyncLLMEngine'
        self.enable_multi_turn = result.get('enable_multi_turn', False)
        self.use_gym_env = result.get('use_gym_env', False)
        self.enable_lora = result.get('enable_lora', False)
        return result

    def close_communicator(self):
        """Best-effort request to every server to close its communicator. Failures are only logged."""
        for i in range(self.num_servers):
            try:
                response = self.sessions[i].post(f'{self.base_urls[i]}/close_communicator/')
                if response.status_code != 200:
                    logger.warning('Server %s close failed: %s', i, response.text)
            except Exception as e:
                logger.warning('Error closing server %s communicator: %s', i, e)
