package io.modelcontextprotocol.server.transport;

import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.net.HttpHeaders;
import io.modelcontextprotocol.spec.McpError;
import io.modelcontextprotocol.spec.McpSchema;
import io.modelcontextprotocol.spec.ServerMcpTransport;
import jakarta.servlet.AsyncContext;
import jakarta.servlet.ServletException;
import jakarta.servlet.annotation.WebServlet;
import jakarta.servlet.http.HttpServlet;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Function;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import reactor.core.publisher.Mono;

/**
 * Servlet-based MCP server transport.
 *
 * <p>Clients open a long-lived Server-Sent Events (SSE) stream with a GET on the SSE endpoint and
 * submit JSON-RPC messages with a POST on the message endpoint. Outbound messages passed to
 * {@link #sendMessage(McpSchema.JSONRPCMessage)} are broadcast to every registered SSE session.
 *
 * <p>Sessions are kept in a {@link ConcurrentHashMap} and the shutdown flag in an
 * {@link AtomicBoolean}, because servlet containers invoke {@code doGet}/{@code doPost} from
 * multiple request threads.
 */
@WebServlet(asyncSupported = true)
public class HttpServletSseServerTransport extends HttpServlet implements ServerMcpTransport {
    private static final Logger logger = LoggerFactory.getLogger(HttpServletSseServerTransport.class);

    /** Character encoding used for every response written by this servlet. */
    public static final String UTF_8 = "UTF-8";

    /** Content type of the JSON bodies written in response to POSTed messages. */
    public static final String APPLICATION_JSON = "application/json";

    /** Log format used when even the error response could not be delivered. */
    public static final String FAILED_TO_SEND_ERROR_RESPONSE = "Failed to send error response: {}";

    /** SSE endpoint path used when none is supplied to the constructor. */
    public static final String DEFAULT_SSE_ENDPOINT = "/sse";

    /** SSE event name used for JSON-RPC messages broadcast to clients. */
    public static final String MESSAGE_EVENT_TYPE = "message";

    /** SSE event name of the first event on a stream, carrying the message endpoint path. */
    public static final String ENDPOINT_EVENT_TYPE = "endpoint";

    private final ObjectMapper objectMapper;
    private final String messageEndpoint;
    private final String sseEndpoint;

    /** Active SSE sessions keyed by session id. */
    private final Map<String, ClientSession> sessions;

    /** Set once shutdown has started; new GET/POST requests are then rejected with 503. */
    private final AtomicBoolean isClosing;

    /**
     * Handler installed by {@link #connect(Function)}. Volatile because it is written by the
     * caller of {@code connect} and read by container request threads.
     */
    private volatile Function<Mono<McpSchema.JSONRPCMessage>, Mono<McpSchema.JSONRPCMessage>> connectHandler;

    /** One connected SSE client: its id, the async context keeping the request open, and its writer. */
    public static class ClientSession {
        private final String id;
        private final AsyncContext asyncContext;
        private final PrintWriter writer;

        ClientSession(String id, AsyncContext asyncContext, PrintWriter writer) {
            this.id = id;
            this.asyncContext = asyncContext;
            this.writer = writer;
        }
    }

    /**
     * Creates a transport with explicit endpoint paths.
     *
     * @param objectMapper    mapper used to (de)serialize JSON-RPC messages
     * @param messageEndpoint path info on which clients POST messages; also announced to clients
     *                        in the initial {@value #ENDPOINT_EVENT_TYPE} event
     * @param sseEndpoint     path info on which clients open the SSE stream
     */
    public HttpServletSseServerTransport(ObjectMapper objectMapper, String messageEndpoint, String sseEndpoint) {
        this.sessions = new ConcurrentHashMap<>();
        this.isClosing = new AtomicBoolean(false);
        this.objectMapper = objectMapper;
        this.messageEndpoint = messageEndpoint;
        this.sseEndpoint = sseEndpoint;
    }

    /**
     * Creates a transport that serves the SSE stream on {@value #DEFAULT_SSE_ENDPOINT}.
     *
     * @param objectMapper    mapper used to (de)serialize JSON-RPC messages
     * @param messageEndpoint path info on which clients POST messages
     */
    public HttpServletSseServerTransport(ObjectMapper objectMapper, String messageEndpoint) {
        this(objectMapper, messageEndpoint, DEFAULT_SSE_ENDPOINT);
    }

    /**
     * Opens an SSE stream for the requesting client.
     *
     * <p>Responds 404 when the path is not the SSE endpoint and 503 while shutting down.
     * Otherwise the request is switched to async mode with no timeout, registered as a session,
     * and the message endpoint is announced as the first event. If that first event cannot be
     * delivered, the session is dropped again instead of being left registered.
     */
    @Override
    protected void doGet(HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException {
        if (!this.sseEndpoint.equals(request.getPathInfo())) {
            response.sendError(HttpServletResponse.SC_NOT_FOUND);
            return;
        }
        if (this.isClosing.get()) {
            response.sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE, "Server is shutting down");
            return;
        }

        response.setContentType("text/event-stream");
        response.setCharacterEncoding(UTF_8);
        response.setHeader(HttpHeaders.CACHE_CONTROL, "no-cache");
        response.setHeader(HttpHeaders.CONNECTION, "keep-alive");
        response.setHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN, "*");

        String sessionId = UUID.randomUUID().toString();
        AsyncContext asyncContext = request.startAsync();
        // A timeout of 0 keeps the stream open until it is completed explicitly.
        asyncContext.setTimeout(0L);
        PrintWriter writer = response.getWriter();
        ClientSession session = new ClientSession(sessionId, asyncContext, writer);
        this.sessions.put(sessionId, session);

        try {
            sendEvent(writer, ENDPOINT_EVENT_TYPE, this.messageEndpoint);
        } catch (IOException e) {
            // The client is already gone: release the async context and forget the session
            // rather than leaving a dead entry in the map.
            logger.error("Failed to send endpoint event to session {}: {}", sessionId, e.getMessage());
            removeSession(session);
        }
    }

    /**
     * Accepts a JSON-RPC message from a client and hands it to the connected handler.
     *
     * <p>Responds 503 while shutting down or when no handler is connected, 404 for any path
     * other than the message endpoint, 400 with a JSON error body when the body cannot be
     * parsed, and 500 with a JSON error body when the handler fails. A message emitted by the
     * handler is written back as the JSON response body.
     */
    @Override
    protected void doPost(HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException {
        if (this.isClosing.get()) {
            response.sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE, "Server is shutting down");
            return;
        }
        if (!this.messageEndpoint.equals(request.getPathInfo())) {
            response.sendError(HttpServletResponse.SC_NOT_FOUND);
            return;
        }

        // Parsing is isolated in its own try block so that only genuine parse failures are
        // reported as "Invalid message format"; handler failures are reported separately below.
        McpSchema.JSONRPCMessage message;
        try {
            message = McpSchema.deserializeJsonRpcMessage(this.objectMapper, readBody(request));
        } catch (Exception e) {
            logger.error("Invalid message format: {}", e.getMessage());
            try {
                writeJson(response, HttpServletResponse.SC_BAD_REQUEST,
                        new McpError("Invalid message format: " + e.getMessage()));
            } catch (IOException | RuntimeException e2) {
                logger.error(FAILED_TO_SEND_ERROR_RESPONSE, e2.getMessage());
                sendErrorQuietly(response, HttpServletResponse.SC_BAD_REQUEST, "Invalid message format");
            }
            return;
        }

        // Read the volatile field once so the null check and the call see the same handler.
        Function<Mono<McpSchema.JSONRPCMessage>, Mono<McpSchema.JSONRPCMessage>> handler = this.connectHandler;
        if (handler == null) {
            response.sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE, "No message handler configured");
            return;
        }

        // NOTE(review): the request is not switched to async mode here, so if the handler's Mono
        // emits after doPost has returned, these callbacks would write to a response the
        // container may already have finished. Confirm handlers complete synchronously, or
        // switch this path to request.startAsync().
        try {
            handler.apply(Mono.just(message))
                    .subscribe(reply -> sendReply(response, reply), failure -> sendFailure(response, failure));
        } catch (RuntimeException e) {
            // Failure raised synchronously by the handler itself rather than through the Mono.
            sendFailure(response, e);
        }
    }

    /** Reads the whole request body; line terminators are dropped, which is harmless for JSON. */
    private static String readBody(HttpServletRequest request) throws IOException {
        BufferedReader reader = request.getReader();
        StringBuilder body = new StringBuilder();
        for (String line = reader.readLine(); line != null; line = reader.readLine()) {
            body.append(line);
        }
        return body.toString();
    }

    /** Writes the handler's reply as the JSON response body, falling back to a plain 500. */
    private void sendReply(HttpServletResponse response, McpSchema.JSONRPCMessage reply) {
        try {
            writeJson(response, HttpServletResponse.SC_OK, reply);
        } catch (Exception e) {
            logger.error("Error sending response: {}", e.getMessage());
            sendErrorQuietly(response, HttpServletResponse.SC_INTERNAL_SERVER_ERROR,
                    "Error processing response: " + e.getMessage());
        }
    }

    /** Reports a handler failure as a 500 with a JSON error body, falling back to a plain 500. */
    private void sendFailure(HttpServletResponse response, Throwable failure) {
        String reason = describe(failure);
        logger.error("Error processing message: {}", reason);
        try {
            writeJson(response, HttpServletResponse.SC_INTERNAL_SERVER_ERROR, new McpError(reason));
        } catch (IOException | RuntimeException e) {
            logger.error(FAILED_TO_SEND_ERROR_RESPONSE, e.getMessage());
            sendErrorQuietly(response, HttpServletResponse.SC_INTERNAL_SERVER_ERROR,
                    "Error sending error response: " + e.getMessage());
        }
    }

    /**
     * Serializes {@code payload} and writes it as a UTF-8 JSON body with the given status.
     * Serialization happens first so a failure leaves the response untouched.
     */
    private void writeJson(HttpServletResponse response, int status, Object payload) throws IOException {
        String json = this.objectMapper.writeValueAsString(payload);
        response.setContentType(APPLICATION_JSON);
        response.setCharacterEncoding(UTF_8);
        response.setStatus(status);
        PrintWriter writer = response.getWriter();
        writer.write(json);
        writer.flush();
    }

    /**
     * Best-effort {@code sendError}: skipped when the response is already committed (where
     * {@code sendError} would throw), and any failure is logged instead of propagated.
     */
    private void sendErrorQuietly(HttpServletResponse response, int status, String message) {
        if (response.isCommitted()) {
            logger.error(FAILED_TO_SEND_ERROR_RESPONSE, "response already committed");
            return;
        }
        try {
            response.sendError(status, message);
        } catch (IOException | IllegalStateException e) {
            logger.error(FAILED_TO_SEND_ERROR_RESPONSE, e.getMessage());
        }
    }

    /** Returns the throwable's message, or its {@code toString()} when it has none, so the result is never null. */
    private static String describe(Throwable failure) {
        String message = failure.getMessage();
        return message != null ? message : failure.toString();
    }

    /**
     * Installs the handler that processes messages POSTed by clients.
     *
     * @param function maps the incoming message to the message written back in the POST response
     * @return an already-completed {@link Mono}
     */
    @Override
    public Mono<Void> connect(Function<Mono<McpSchema.JSONRPCMessage>, Mono<McpSchema.JSONRPCMessage>> function) {
        this.connectHandler = function;
        return Mono.empty();
    }

    /**
     * Broadcasts a message to every active SSE session.
     *
     * <p>A session whose stream can no longer be written to is removed; the broadcast continues
     * with the remaining sessions. The returned Mono errors only when the message cannot be
     * serialized or the broadcast fails unexpectedly.
     *
     * @param jSONRPCMessage the message to broadcast
     * @return a Mono completing once the message has been written to all sessions
     */
    @Override
    public Mono<Void> sendMessage(McpSchema.JSONRPCMessage jSONRPCMessage) {
        if (this.sessions.isEmpty()) {
            logger.debug("No active sessions to broadcast message to");
            return Mono.empty();
        }
        return Mono.create(sink -> {
            try {
                String json = this.objectMapper.writeValueAsString(jSONRPCMessage);
                this.sessions.values().forEach(session -> {
                    try {
                        sendEvent(session.writer, MESSAGE_EVENT_TYPE, json);
                    } catch (IOException e) {
                        logger.error("Failed to send message to session {}: {}", session.id, e.getMessage());
                        removeSession(session);
                    }
                });
                sink.success();
            } catch (Exception e) {
                logger.error("Failed to process message: {}", e.getMessage());
                sink.error(new McpError("Failed to process message: " + e.getMessage()));
            }
        });
    }

    /**
     * Closes the transport by delegating to the interface's default implementation.
     * The qualified {@code super} is required: the servlet superclass has no {@code close()}.
     */
    @Override
    public void close() {
        ServerMcpTransport.super.close();
    }

    /**
     * Converts an already-parsed value into the requested type using the configured mapper.
     *
     * @param obj           the value to convert
     * @param typeReference the target type
     * @return the converted value
     */
    @Override
    public <T> T unmarshalFrom(Object obj, TypeReference<T> typeReference) {
        return this.objectMapper.convertValue(obj, typeReference);
    }

    /**
     * Starts shutdown: new requests are rejected immediately, and on subscription every active
     * SSE session is removed and its async context completed.
     *
     * @return a Mono completing once all sessions have been closed
     */
    @Override
    public Mono<Void> closeGracefully() {
        this.isClosing.set(true);
        logger.debug("Initiating graceful shutdown with {} active sessions", this.sessions.size());
        return Mono.create(sink -> {
            this.sessions.values().forEach(this::removeSession);
            sink.success();
        });
    }

    /**
     * Writes one SSE event and flushes it.
     *
     * <p>The frame is assembled first and written with a single call. Payloads containing line
     * breaks are emitted as one {@code data:} line per payload line, as the event-stream format
     * requires; a single-line payload yields exactly {@code event: <type>\ndata: <data>\n\n}.
     *
     * @throws IOException if the writer reports an error, i.e. the client is no longer reachable
     */
    private void sendEvent(PrintWriter writer, String eventType, String data) throws IOException {
        StringBuilder frame = new StringBuilder("event: ").append(eventType).append('\n');
        for (String line : String.valueOf(data).split("\r\n|\r|\n", -1)) {
            frame.append("data: ").append(line).append('\n');
        }
        frame.append('\n');
        writer.write(frame.toString());
        writer.flush();
        // PrintWriter swallows IOExceptions; checkError() is the only way to detect them.
        if (writer.checkError()) {
            throw new IOException("Client disconnected");
        }
    }

    /**
     * Unregisters a session and completes its async context.
     *
     * <p>Only the caller that actually removes the map entry completes the context, so
     * concurrent removals (for example a failed broadcast racing with shutdown) do not complete
     * it twice. A context that is already finished is tolerated rather than aborting the caller.
     */
    private void removeSession(ClientSession session) {
        if (!this.sessions.remove(session.id, session)) {
            return;
        }
        try {
            session.asyncContext.complete();
        } catch (IllegalStateException e) {
            logger.debug("Async context for session {} was already completed: {}", session.id, e.getMessage());
        }
    }

    /** Closes all sessions, waiting for that to finish, before the servlet is taken out of service. */
    @Override
    public void destroy() {
        closeGracefully().block();
        super.destroy();
    }
}
