package io.modelcontextprotocol.client.transport;

import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.vladsch.flexmark.util.sequence.SequenceUtils;
import io.modelcontextprotocol.spec.ClientMcpTransport;
import io.modelcontextprotocol.spec.McpSchema;
import io.modelcontextprotocol.util.Assert;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.OutputStream;
import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.util.ArrayList;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executors;
import java.util.function.Consumer;
import java.util.function.Function;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.core.publisher.Sinks;
import reactor.core.scheduler.Scheduler;
import reactor.core.scheduler.Schedulers;

/* loaded from: input_file:mcp/mcp-0.8.0-SNAPSHOT.jar:io/modelcontextprotocol/client/transport/StdioClientTransport.class */
/**
 * Client-side MCP transport that launches a server as a child process and exchanges
 * newline-delimited JSON-RPC messages with it over the process's standard streams.
 *
 * <p>Messages read from the child's stdout are deserialized and pushed into an inbound
 * sink; messages passed to {@link #sendMessage} are queued in an outbound sink and
 * written to the child's stdin, one message per line. Lines read from the child's stderr
 * are forwarded to a configurable handler (see {@link #setStdErrorHandler}).
 *
 * <p>Each of the three streams is serviced by its own single-threaded scheduler.
 */
public class StdioClientTransport implements ClientMcpTransport {
    private static final Logger logger = LoggerFactory.getLogger(StdioClientTransport.class);

    /** Line terminator used to frame a single outbound message. */
    private static final String LINE_SEPARATOR = "\n";

    /** Escaped form substituted for any raw line break inside a serialized message. */
    private static final String ESCAPED_LINE_BREAK = "\\n";

    private final Sinks.Many<McpSchema.JSONRPCMessage> inboundSink;
    private final Sinks.Many<McpSchema.JSONRPCMessage> outboundSink;
    private final Sinks.Many<String> errorSink;
    private final ServerParameters params;
    private final ObjectMapper objectMapper;
    private final Scheduler inboundScheduler;
    private final Scheduler outboundScheduler;
    private final Scheduler errorScheduler;

    /** The child process; {@code null} until {@link #connect} has started it. */
    private Process process;

    /** Set once shutdown has begun or any of the stream loops has terminated. */
    private volatile boolean isClosing = false;

    /** Receives every line read from the child's stderr; logs at INFO by default. */
    private Consumer<String> stdErrorHandler = error -> logger.info("STDERR Message received: {}", error);

    /**
     * Creates a transport using a default {@link ObjectMapper}.
     *
     * @param serverParameters command, arguments and environment of the server process
     */
    public StdioClientTransport(ServerParameters serverParameters) {
        this(serverParameters, new ObjectMapper());
    }

    /**
     * Creates a transport using the supplied {@link ObjectMapper} for JSON (de)serialization.
     *
     * @param serverParameters command, arguments and environment of the server process
     * @param objectMapper mapper used to serialize outbound and deserialize inbound messages
     */
    public StdioClientTransport(ServerParameters serverParameters, ObjectMapper objectMapper) {
        Assert.notNull(serverParameters, "The params can not be null");
        Assert.notNull(objectMapper, "The ObjectMapper can not be null");
        this.inboundSink = Sinks.many().unicast().onBackpressureBuffer();
        this.outboundSink = Sinks.many().unicast().onBackpressureBuffer();
        this.params = serverParameters;
        this.objectMapper = objectMapper;
        this.errorSink = Sinks.many().unicast().onBackpressureBuffer();
        this.inboundScheduler = Schedulers.fromExecutorService(Executors.newSingleThreadExecutor(), "inbound");
        this.outboundScheduler = Schedulers.fromExecutorService(Executors.newSingleThreadExecutor(), "outbound");
        this.errorScheduler = Schedulers.fromExecutorService(Executors.newSingleThreadExecutor(), "error");
    }

    /**
     * Subscribes the message and stderr handlers, starts the server process and begins
     * servicing its three standard streams.
     *
     * @param function handler applied to every inbound message
     * @return a {@code Mono} that performs the start-up on the bounded-elastic scheduler
     *         when subscribed; it errors with a {@link RuntimeException} if the process
     *         cannot be started
     */
    @Override // io.modelcontextprotocol.spec.McpTransport
    public Mono<Void> connect(Function<Mono<McpSchema.JSONRPCMessage>, Mono<McpSchema.JSONRPCMessage>> function) {
        // The explicit <Void> witness is required: without a target type the chained
        // call would infer Mono<Object>, which does not match the return type.
        return Mono.<Void>fromRunnable(() -> {
            handleIncomingMessages(function);
            handleIncomingErrors();

            ArrayList<String> fullCommand = new ArrayList<>();
            fullCommand.add(this.params.getCommand());
            fullCommand.addAll(this.params.getArgs());

            ProcessBuilder processBuilder = getProcessBuilder();
            processBuilder.command(fullCommand);
            processBuilder.environment().putAll(this.params.getEnv());

            try {
                this.process = processBuilder.start();
            } catch (IOException e) {
                throw new RuntimeException("Failed to start process with command: " + fullCommand, e);
            }

            if (this.process.getInputStream() == null || this.process.getOutputStream() == null) {
                this.process.destroy();
                throw new RuntimeException("Process input or output stream is null");
            }

            startInboundProcessing();
            startOutboundProcessing();
            startErrorProcessing();
        }).subscribeOn(Schedulers.boundedElastic());
    }

    /**
     * Creates the {@link ProcessBuilder} used to launch the server. Overridable so that
     * subclasses can customize how the process is built.
     *
     * @return a new, empty process builder
     */
    protected ProcessBuilder getProcessBuilder() {
        return new ProcessBuilder(new String[0]);
    }

    /**
     * Replaces the handler invoked for every line the server writes to stderr.
     *
     * @param consumer the new stderr line handler
     */
    public void setStdErrorHandler(Consumer<String> consumer) {
        this.stdErrorHandler = consumer;
    }

    /**
     * Blocks the calling thread until the server process has exited.
     *
     * @throws IllegalStateException if the process has not been started yet
     * @throws RuntimeException if the waiting thread is interrupted; the thread's
     *         interrupt status is restored before throwing
     */
    public void awaitForExit() {
        Process current = this.process;
        if (current == null) {
            throw new IllegalStateException("Process not started");
        }
        try {
            current.waitFor();
        } catch (InterruptedException e) {
            // Restore the flag so callers further up the stack can observe the interrupt.
            Thread.currentThread().interrupt();
            throw new RuntimeException("Process interrupted", e);
        }
    }

    /**
     * Schedules the loop that reads the child's stderr line by line and pushes each line
     * into the error sink. When the loop ends for any reason the transport is marked as
     * closing and the error sink is completed.
     */
    private void startErrorProcessing() {
        this.errorScheduler.schedule(() -> {
            // stderr is free-form diagnostic output of the child, so it is decoded with
            // the platform default charset rather than a protocol-mandated one.
            try (BufferedReader stderrReader = new BufferedReader(
                    new InputStreamReader(this.process.getErrorStream()))) {
                String line;
                while (!this.isClosing && (line = stderrReader.readLine()) != null) {
                    try {
                        if (!this.errorSink.tryEmitNext(line).isSuccess()) {
                            if (!this.isClosing) {
                                logger.error("Failed to enqueue error message: {}", line);
                            }
                            break;
                        }
                    } catch (Exception e) {
                        if (!this.isClosing) {
                            logger.error("Error processing error message", e);
                        }
                    }
                }
            } catch (IOException e) {
                if (!this.isClosing) {
                    logger.error("Error reading from error stream", e);
                }
            } finally {
                this.isClosing = true;
                this.errorSink.tryEmitComplete();
            }
        });
    }

    /**
     * Subscribes to the inbound sink and applies {@code function} to each received message.
     *
     * @param function handler applied to every inbound message
     */
    private void handleIncomingMessages(Function<Mono<McpSchema.JSONRPCMessage>, Mono<McpSchema.JSONRPCMessage>> function) {
        this.inboundSink.asFlux()
            .flatMap(message -> Mono.just(message)
                .transform(function)
                .contextWrite(context -> context.put("observation", "myObservation")))
            .subscribe();
    }

    /** Subscribes to the error sink and forwards each stderr line to the current handler. */
    private void handleIncomingErrors() {
        // The handler field is read per line so that a handler installed later via
        // setStdErrorHandler takes effect for subsequent lines.
        this.errorSink.asFlux().subscribe(line -> this.stdErrorHandler.accept(line));
    }

    /**
     * Queues a message for delivery to the server. The message is enqueued immediately,
     * at call time, not when the returned {@code Mono} is subscribed.
     *
     * @param jSONRPCMessage the message to send
     * @return an empty {@code Mono} if the message was enqueued, otherwise a {@code Mono}
     *         that errors with a {@link RuntimeException}
     */
    @Override // io.modelcontextprotocol.spec.McpTransport
    public Mono<Void> sendMessage(McpSchema.JSONRPCMessage jSONRPCMessage) {
        if (this.outboundSink.tryEmitNext(jSONRPCMessage).isSuccess()) {
            return Mono.empty();
        }
        return Mono.error(new RuntimeException("Failed to enqueue message"));
    }

    /**
     * Schedules the loop that reads the child's stdout line by line, deserializes each
     * line as a JSON-RPC message and pushes it into the inbound sink. A line that fails
     * to deserialize is logged and skipped. When the loop ends for any reason the
     * transport is marked as closing and the inbound sink is completed.
     */
    private void startInboundProcessing() {
        this.inboundScheduler.schedule(() -> {
            // Decode as UTF-8 to mirror the explicit UTF-8 encoding used on the outbound side.
            try (BufferedReader stdoutReader = new BufferedReader(
                    new InputStreamReader(this.process.getInputStream(), StandardCharsets.UTF_8))) {
                String line;
                while (!this.isClosing && (line = stdoutReader.readLine()) != null) {
                    try {
                        McpSchema.JSONRPCMessage message =
                                McpSchema.deserializeJsonRpcMessage(this.objectMapper, line);
                        if (!this.inboundSink.tryEmitNext(message).isSuccess()) {
                            if (!this.isClosing) {
                                logger.error("Failed to enqueue inbound message: {}", message);
                            }
                            break;
                        }
                    } catch (Exception e) {
                        if (!this.isClosing) {
                            logger.error("Error processing inbound message for line: {}", line, e);
                        }
                    }
                }
            } catch (IOException e) {
                if (!this.isClosing) {
                    logger.error("Error reading from input stream", e);
                }
            } finally {
                this.isClosing = true;
                this.inboundSink.tryEmitComplete();
            }
        });
    }

    /**
     * Wires the outbound sink to the child's stdin: each queued message is serialized to
     * JSON, has any raw line break replaced by an escaped {@code \n}, and is written as a
     * single UTF-8 line followed by a flush.
     */
    private void startOutboundProcessing() {
        handleOutbound(messages -> messages.publishOn(this.outboundScheduler).handle((message, sink) -> {
            if (message == null || this.isClosing) {
                return;
            }
            try {
                // A message must occupy exactly one line, so embedded line breaks are escaped.
                String json = this.objectMapper.writeValueAsString(message)
                    .replace("\r\n", ESCAPED_LINE_BREAK)
                    .replace("\n", ESCAPED_LINE_BREAK)
                    .replace("\r", ESCAPED_LINE_BREAK);
                OutputStream stdin = this.process.getOutputStream();
                synchronized (stdin) {
                    stdin.write(json.getBytes(StandardCharsets.UTF_8));
                    stdin.write(LINE_SEPARATOR.getBytes(StandardCharsets.UTF_8));
                    stdin.flush();
                }
                sink.next(message);
            } catch (IOException e) {
                sink.error(new RuntimeException(e));
            }
        }));
    }

    /**
     * Applies {@code function} to the outbound message stream and subscribes to the
     * result. On completion or error the transport is marked as closing and the outbound
     * sink is completed; an error is logged unless the transport was already closing.
     *
     * @param function transformation that performs the actual delivery of outbound messages
     */
    protected void handleOutbound(Function<Flux<McpSchema.JSONRPCMessage>, Flux<McpSchema.JSONRPCMessage>> function) {
        function.apply(this.outboundSink.asFlux())
            .doOnComplete(() -> {
                this.isClosing = true;
                this.outboundSink.tryEmitComplete();
            })
            .doOnError(error -> {
                if (this.isClosing) {
                    return;
                }
                logger.error("Error in outbound processing", error);
                this.isClosing = true;
                this.outboundSink.tryEmitComplete();
            })
            .subscribe();
    }

    /**
     * Shuts the transport down: marks it as closing, completes all sinks, waits 100 ms,
     * destroys the server process and waits for it to exit, then disposes the three
     * schedulers.
     *
     * @return a {@code Mono} that performs the shutdown on the bounded-elastic scheduler
     *         when subscribed; it errors with a {@link RuntimeException} if the process
     *         was never started
     */
    @Override // io.modelcontextprotocol.spec.McpTransport
    public Mono<Void> closeGracefully() {
        return Mono.fromRunnable(() -> {
            this.isClosing = true;
            logger.debug("Initiating graceful shutdown");
        }).then(Mono.defer(() -> {
            this.inboundSink.tryEmitComplete();
            this.outboundSink.tryEmitComplete();
            this.errorSink.tryEmitComplete();
            // Short pause between completing the sinks and terminating the process.
            return Mono.delay(Duration.ofMillis(100L));
        })).then(Mono.fromFuture(() -> {
            logger.debug("Sending TERM to process");
            if (this.process == null) {
                return CompletableFuture.<Process>failedFuture(new RuntimeException("Process not started"));
            }
            this.process.destroy();
            return this.process.onExit();
        })).doOnNext(exited -> {
            int exitCode = exited.exitValue();
            if (exitCode != 0) {
                logger.warn("Process terminated with code {}", exitCode);
            }
        }).then(Mono.fromRunnable(() -> {
            try {
                this.inboundScheduler.dispose();
                this.errorScheduler.dispose();
                this.outboundScheduler.dispose();
                logger.debug("Graceful shutdown completed");
            } catch (Exception e) {
                logger.error("Error during graceful shutdown", e);
            }
        })).then().subscribeOn(Schedulers.boundedElastic());
    }

    /**
     * Returns the sink into which lines read from the server's stderr are emitted.
     *
     * @return the stderr line sink
     */
    public Sinks.Many<String> getErrorSink() {
        return this.errorSink;
    }

    /**
     * Converts {@code obj} to the type described by {@code typeReference} using this
     * transport's {@link ObjectMapper}.
     *
     * @param obj the value to convert
     * @param typeReference the target type
     * @param <T> the target type
     * @return the converted value
     */
    @Override // io.modelcontextprotocol.spec.McpTransport
    public <T> T unmarshalFrom(Object obj, TypeReference<T> typeReference) {
        return this.objectMapper.convertValue(obj, typeReference);
    }
}
