from collections.abc import Sequence
from typing import Any, cast

from sqlalchemy import select, update
from sqlalchemy.orm import Session

from configs import dify_config
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.entities.provider_entities import ProviderQuotaType, QuotaUnit
from core.file.models import File
from core.memory import NodeTokenBufferMemory, TokenBufferMemory
from core.memory.base import BaseMemory
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.entities.message_entities import (
    AssistantPromptMessage,
    MultiModalPromptMessageContent,
    PromptMessage,
    PromptMessageContentUnionTypes,
    PromptMessageRole,
    ToolPromptMessage,
)
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.prompt.entities.advanced_prompt_entities import MemoryConfig, MemoryMode
from core.variables.segments import ArrayAnySegment, ArrayFileSegment, FileSegment, NoneSegment, StringSegment
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.llm.entities import LLMGenerationData, ModelConfig
from core.workflow.runtime import VariablePool
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models.model import Conversation
from models.provider import Provider, ProviderType
from models.provider_ids import ModelProviderID

from .exc import InvalidVariableTypeError, LLMModeRequiredError, ModelNotExistError


def fetch_model_config(
    tenant_id: str, node_data_model: ModelConfig
) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
    """
    Resolve an LLM model instance and its credential-bearing config for a node.

    :param tenant_id: Tenant whose provider credentials are used
    :param node_data_model: Model settings taken from the node definition
    :return: ``(model_instance, model_config)`` pair
    :raises LLMModeRequiredError: if ``node_data_model.mode`` is empty
    :raises ModelNotExistError: if the provider model or its schema cannot be found

    NOTE: ``"stop"`` is popped from ``node_data_model.completion_params`` in place,
    so the returned ``parameters`` (the very same dict object) no longer contain it.
    """
    if not node_data_model.mode:
        raise LLMModeRequiredError("LLM mode is required.")

    model_name = node_data_model.name
    model_instance = ModelManager().get_model_instance(
        tenant_id=tenant_id,
        model_type=ModelType.LLM,
        provider=node_data_model.provider,
        model=model_name,
    )
    # Narrow the declared type so the schema lookup below is typed as an LLM call.
    model_instance.model_type_instance = cast(LargeLanguageModel, model_instance.model_type_instance)

    bundle = model_instance.provider_model_bundle
    provider_model = bundle.configuration.get_provider_model(model=model_name, model_type=ModelType.LLM)
    if provider_model is None:
        raise ModelNotExistError(f"Model {node_data_model.name} not exist.")
    provider_model.raise_for_status()

    # Stop sequences travel separately from the other completion parameters.
    # NOTE(review): the pop mutates the caller's node config; a later call with the
    # same ModelConfig object will see no "stop" key — confirm callers never reuse it.
    params = node_data_model.completion_params
    stop: list[str] = params.pop("stop") if "stop" in params else []

    schema = model_instance.model_type_instance.get_model_schema(model_name, model_instance.credentials)
    if not schema:
        raise ModelNotExistError(f"Model {node_data_model.name} not exist.")

    config = ModelConfigWithCredentialsEntity(
        provider=node_data_model.provider,
        model=model_name,
        model_schema=schema,
        mode=node_data_model.mode,
        provider_model_bundle=bundle,
        credentials=model_instance.credentials,
        parameters=params,
        stop=stop,
    )
    return model_instance, config


def fetch_files(variable_pool: VariablePool, selector: Sequence[str]) -> Sequence["File"]:
    variable = variable_pool.get(selector)
    if variable is None:
        return []
    elif isinstance(variable, FileSegment):
        return [variable.value]
    elif isinstance(variable, ArrayFileSegment):
        return variable.value
    elif isinstance(variable, NoneSegment | ArrayAnySegment):
        return []
    raise InvalidVariableTypeError(f"Invalid variable type: {type(variable)}")


def fetch_memory(
    variable_pool: VariablePool,
    app_id: str,
    tenant_id: str,
    node_data_memory: MemoryConfig | None,
    model_instance: ModelInstance,
    node_id: str = "",
) -> BaseMemory | None:
    """
    Build the memory object that matches ``node_data_memory.mode``.

    - ``MemoryMode.NODE``: a ``NodeTokenBufferMemory`` scoped to ``node_id`` (Chatflow only).
    - any other mode (the default): a ``TokenBufferMemory`` over the app's ``Conversation`` row.

    :param variable_pool: Pool holding the ``sys.conversation_id`` system variable
    :param app_id: Application ID
    :param tenant_id: Tenant ID (used by node-level memory)
    :param node_data_memory: Memory configuration; falsy means memory is disabled
    :param model_instance: Model instance used for token counting
    :param node_id: Workflow node ID; required for node mode
    :return: A memory instance, or None when memory is not configured, the conversation
        ID is not a string segment, ``node_id`` is empty in node mode, or no conversation
        row matches ``app_id`` and the conversation ID.
    """
    if not node_data_memory:
        return None

    # Both memory modes are keyed by the current conversation.
    conversation_id_segment = variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID])
    if not isinstance(conversation_id_segment, StringSegment):
        return None
    conversation_id = conversation_id_segment.value

    if node_data_memory.mode != MemoryMode.NODE:
        # Conversation-level memory: the loaded Conversation is used after the session closes.
        query = select(Conversation).where(Conversation.app_id == app_id, Conversation.id == conversation_id)
        with Session(db.engine, expire_on_commit=False) as session:
            conversation = session.scalar(query)
        if not conversation:
            return None
        return TokenBufferMemory(conversation=conversation, model_instance=model_instance)

    # Node-level memory needs the node identity to scope its history.
    if not node_id:
        return None
    return NodeTokenBufferMemory(
        app_id=app_id,
        conversation_id=conversation_id,
        node_id=node_id,
        tenant_id=tenant_id,
        model_instance=model_instance,
    )


def deduct_llm_quota(tenant_id: str, model_instance: ModelInstance, usage: LLMUsage):
    """
    Charge one LLM invocation against the tenant's system-provider quota.

    Nothing is deducted when the provider is not of type ``SYSTEM``, when no quota
    configuration matches the current quota type, or when that configuration's
    ``quota_limit`` is ``-1``.

    The amount depends on the quota unit: ``usage.total_tokens`` for ``TOKENS``,
    ``dify_config.get_model_credits(model)`` for ``CREDITS``, otherwise ``1``.
    TRIAL and PAID quota types are charged through ``CreditPoolService`` (PAID uses
    ``pool_type="paid"``); any other quota type increments ``Provider.quota_used``.

    :param tenant_id: Tenant to charge
    :param model_instance: Model that served the request
    :param usage: Usage reported by the invocation
    """
    configuration = model_instance.provider_model_bundle.configuration
    if configuration.using_provider_type != ProviderType.SYSTEM:
        return

    system_config = configuration.system_configuration
    active_quota = next(
        (
            candidate
            for candidate in system_config.quota_configurations
            if candidate.quota_type == system_config.current_quota_type
        ),
        None,
    )
    if active_quota is None or active_quota.quota_limit == -1:
        return

    unit = active_quota.quota_unit
    if not unit:
        return
    if unit == QuotaUnit.TOKENS:
        amount = usage.total_tokens
    elif unit == QuotaUnit.CREDITS:
        amount = dify_config.get_model_credits(model_instance.model)
    else:
        amount = 1

    quota_type = system_config.current_quota_type
    if amount is None or quota_type is None:
        return

    if quota_type in (ProviderQuotaType.TRIAL, ProviderQuotaType.PAID):
        from services.credit_pool_service import CreditPoolService

        if quota_type == ProviderQuotaType.TRIAL:
            CreditPoolService.check_and_deduct_credits(
                tenant_id=tenant_id,
                credits_required=amount,
            )
        else:
            CreditPoolService.check_and_deduct_credits(
                tenant_id=tenant_id,
                credits_required=amount,
                pool_type="paid",
            )
        return

    # Other quota types are tracked on the Provider row itself; only rows that
    # still have headroom (quota_limit > quota_used) are updated.
    charge_stmt = (
        update(Provider)
        .where(
            Provider.tenant_id == tenant_id,
            # TODO: Use provider name with prefix after the data migration.
            Provider.provider_name == ModelProviderID(model_instance.provider).provider_name,
            Provider.provider_type == ProviderType.SYSTEM.value,
            Provider.quota_type == quota_type.value,
            Provider.quota_limit > Provider.quota_used,
        )
        .values(
            quota_used=Provider.quota_used + amount,
            last_used=naive_utc_now(),
        )
    )
    with Session(db.engine) as session:
        session.execute(charge_stmt)
        session.commit()


def build_context(
    prompt_messages: Sequence[PromptMessage],
    assistant_response: str,
    generation_data: LLMGenerationData | None = None,
    files: Sequence[Any] | None = None,
) -> list[PromptMessage]:
    """
    Assemble the conversation context produced by one LLM run.

    System messages are dropped; every other prompt message is kept with its
    multi-modal payloads truncated. The assistant side is then appended: for runs
    with a trace (tool-enabled), the model/tool turns are rebuilt from it; otherwise
    a single assistant message carries ``assistant_response``. A description of any
    generated files is appended to the final assistant text.

    Args:
        prompt_messages: Initial prompt messages (user query, etc.)
        assistant_response: Final assistant response text
        generation_data: Optional generation data whose trace drives tool-call reconstruction
        files: Optional File objects generated during execution

    Returns:
        list[PromptMessage] suitable for an ArrayPromptMessageSegment.
    """
    history: list[PromptMessage] = []
    for message in prompt_messages:
        if message.role == PromptMessageRole.SYSTEM:
            continue
        history.append(_truncate_multimodal_content(message))

    descriptions = _build_file_descriptions(files) if files else ""
    file_suffix = f"\n\n{descriptions}" if descriptions else ""

    if generation_data and generation_data.trace:
        history.extend(_build_messages_from_trace(generation_data, assistant_response, file_suffix))
    else:
        history.append(AssistantPromptMessage(content=assistant_response + file_suffix))

    return history


def _build_file_descriptions(files: Sequence[Any]) -> str:
    """
    Build a text description of generated files for inclusion in context.

    The description includes file_id which can be used by subsequent nodes
    to reference the files via structured output.
    """
    if not files:
        return ""

    descriptions: list[str] = ["[Generated Files]"]
    for file in files:
        # Get file attributes (File is a Pydantic model)
        file_id = getattr(file, "id", None) or getattr(file, "related_id", None)
        filename = getattr(file, "filename", "unknown")
        file_type = getattr(file, "type", "unknown")
        if hasattr(file_type, "value"):
            file_type = file_type.value

        if file_id:
            descriptions.append(f"- {filename} (id: {file_id}, type: {file_type})")

    return "\n".join(descriptions)


def _build_messages_from_trace(
    generation_data: LLMGenerationData,
    assistant_response: str,
    file_suffix: str = "",
) -> list[PromptMessage]:
    """
    Rebuild the assistant/tool message sequence from a generation trace.

    Segments are walked in trace order:
    - a ``"model"`` segment with tool calls becomes an AssistantPromptMessage carrying
      its text and the tool calls;
    - a ``"model"`` segment without tool calls becomes a text-only AssistantPromptMessage
      (skipped when its text is empty);
    - a ``"tool"`` segment becomes a ToolPromptMessage with the tool's output.

    ``assistant_response`` holds the text of all model turns combined
    (see LLMGenerationData.text). The lengths of the model segments' texts are summed,
    and only the part of ``assistant_response`` beyond that length, plus ``file_suffix``,
    is emitted as a final assistant message (omitted if empty).
    """
    from core.workflow.nodes.llm.entities import ModelTraceSegment, ToolTraceSegment

    rebuilt: list[PromptMessage] = []
    # Number of leading characters of assistant_response already represented by model segments.
    consumed = 0

    for entry in generation_data.trace:
        payload = entry.output
        if entry.type == "model" and isinstance(payload, ModelTraceSegment):
            turn_text = payload.text or ""
            consumed += len(turn_text)

            if not payload.tool_calls:
                # Text-only model turn (e.g. the final answer).
                if turn_text:
                    rebuilt.append(AssistantPromptMessage(content=turn_text))
                continue

            calls = []
            for call in payload.tool_calls:
                function = AssistantPromptMessage.ToolCall.ToolCallFunction(
                    name=call.name or "",
                    arguments=call.arguments or "{}",
                )
                calls.append(AssistantPromptMessage.ToolCall(id=call.id or "", type="function", function=function))
            rebuilt.append(AssistantPromptMessage(content=turn_text, tool_calls=calls))

        elif entry.type == "tool" and isinstance(payload, ToolTraceSegment):
            rebuilt.append(
                ToolPromptMessage(
                    content=payload.output or "",
                    tool_call_id=payload.id or "",
                    name=payload.name or "",
                )
            )

    tail = assistant_response[consumed:] + file_suffix
    if tail:
        rebuilt.append(AssistantPromptMessage(content=tail))

    return rebuilt


def _truncate_multimodal_content(message: PromptMessage) -> PromptMessage:
    """
    Strip or shorten multi-modal payloads in a message to avoid storing large data.

    Preserves the PromptMessage structure for ArrayPromptMessageSegment compatibility.
    String (or empty) content is returned unchanged.

    For each MultiModalPromptMessageContent item:
    - with a ``file_ref``: ``base64_data`` and ``url`` are cleared, since the content
      can be restored from the reference later;
    - without one (legacy data): ``base64_data`` is reduced to its first and last
      characters around a ``...[TRUNCATED]...`` marker. Payloads too short for that to
      save space are kept as-is (a missing payload becomes ``""``).

    :param message: Message whose content may contain multi-modal items
    :return: The original message, or a copy with processed content
    """
    content = message.content
    if content is None or isinstance(content, str):
        return message

    keep_chars = 10  # characters kept at each end of a legacy base64 payload
    marker = "...[TRUNCATED]..."
    # Truncating anything at or below this length would not shrink it and would
    # duplicate overlapping head/tail characters, corrupting short payloads.
    min_truncatable = 2 * keep_chars + len(marker)

    new_content: list[PromptMessageContentUnionTypes] = []
    for item in content:
        if not isinstance(item, MultiModalPromptMessageContent):
            new_content.append(item)
            continue
        if item.file_ref:
            # Clear base64 and url, keep file_ref for later restoration.
            new_content.append(item.model_copy(update={"base64_data": "", "url": ""}))
            continue
        # Fallback for legacy data without a file_ref.
        base64_data = item.base64_data or ""
        if len(base64_data) > min_truncatable:
            base64_data = base64_data[:keep_chars] + marker + base64_data[-keep_chars:]
        new_content.append(item.model_copy(update={"base64_data": base64_data}))

    return message.model_copy(update={"content": new_content})


def restore_multimodal_content_in_messages(messages: Sequence[PromptMessage]) -> list[PromptMessage]:
    """
    Re-hydrate multi-modal items (base64 or url) in each of ``messages``.

    Saved context has its base64 payloads stripped to save space; every
    MultiModalPromptMessageContent is passed through
    ``file_manager.restore_multimodal_content`` to rebuild it from its file_ref.

    Args:
        messages: PromptMessages whose multi-modal content may have been truncated

    Returns:
        A new list with one (possibly copied) PromptMessage per input message, in order.
    """
    from core.file import file_manager

    return [_restore_message_content(prompt_message, file_manager) for prompt_message in messages]


def _restore_message_content(message: PromptMessage, file_manager) -> PromptMessage:
    """
    Restore multi-modal items inside a single PromptMessage.

    String or empty content is returned untouched. Otherwise every
    MultiModalPromptMessageContent item is replaced by
    ``file_manager.restore_multimodal_content(item)``; other items are kept as-is,
    and a copy of the message with the new content list is returned.
    """
    parts = message.content
    if parts is None or isinstance(parts, str):
        return message

    restored_parts: list[PromptMessageContentUnionTypes] = [
        cast(PromptMessageContentUnionTypes, file_manager.restore_multimodal_content(part))
        if isinstance(part, MultiModalPromptMessageContent)
        else part
        for part in parts
    ]
    return message.model_copy(update={"content": restored_parts})
