package io.rsocket.broker.acceptor;

import io.rsocket.ConnectionSetupPayload;
import io.rsocket.RSocket;
import io.rsocket.SocketAcceptor;
import io.rsocket.broker.RSocketIndex;
import io.rsocket.broker.RoutingTable;
import io.rsocket.broker.common.Id;
import io.rsocket.broker.common.WellKnownKey;
import io.rsocket.broker.frames.BrokerFrame;
import io.rsocket.broker.frames.BrokerInfo;
import io.rsocket.broker.frames.RouteJoin;
import io.rsocket.broker.frames.RouteSetup;
import io.rsocket.broker.rsocket.ErrorOnDisconnectRSocket;
import io.rsocket.broker.rsocket.RoutingRSocketFactory;
import java.util.function.BiConsumer;
import java.util.function.Consumer;
import java.util.function.Function;
import org.reactivestreams.Publisher;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

/* loaded from: input_file:io/rsocket/broker/acceptor/BrokerSocketAcceptor.class */
/**
 * {@link SocketAcceptor} used by the broker to admit incoming connections.
 *
 * <p>The {@link BrokerFrame} carried in each connection's setup payload decides how the peer is
 * registered:
 * <ul>
 *   <li>{@link BrokerInfo}: the peer is another broker and is handed to the configured broker-info
 *       consumer together with its (wrapped) sending socket;</li>
 *   <li>{@link RouteSetup}: the peer is a routable client; a {@link RouteJoin} is built for it,
 *       its sending socket is put into the {@link RSocketIndex} and the join is added to the
 *       {@link RoutingTable}.</li>
 * </ul>
 * Any other frame (or none) rejects the connection with an {@link IllegalStateException}.
 *
 * <p>For every accepted connection a receiving socket is obtained from the
 * {@link RoutingRSocketFactory}. As soon as either that socket or the peer's sending socket
 * closes, the registration performed during setup is undone.
 */
public class BrokerSocketAcceptor implements SocketAcceptor {
    protected static final Logger logger = LoggerFactory.getLogger(BrokerSocketAcceptor.class);

    /** Identifier of this broker, stamped on every {@link RouteJoin} created for a client. */
    protected final Id brokerId;
    /** Table that client {@link RouteJoin}s are added to and removed from. */
    protected final RoutingTable routingTable;
    /** Index of client sending sockets, keyed by route id and tags. */
    protected final RSocketIndex rSocketIndex;
    /** Extracts the {@link BrokerFrame} from a connection's setup payload. */
    protected final Function<ConnectionSetupPayload, BrokerFrame> payloadExtractor;
    /** Invoked when another broker connects, with its info and wrapped sending socket. */
    protected final BiConsumer<BrokerInfo, RSocket> brokerInfoConsumer;
    /** Invoked when a connected broker's connection closes. */
    protected final Consumer<BrokerInfo> brokerInfoCleaner;
    /** Creates the receiving socket returned to each accepted peer. */
    protected final RoutingRSocketFactory routingRSocketFactory;

    /**
     * Creates the acceptor.
     *
     * @param id identifier of this broker
     * @param routingTable routing table that client route joins are registered in
     * @param rSocketIndex index that client sending sockets are registered in
     * @param routingRSocketFactory factory for the receiving socket handed back to each peer
     * @param function extracts the {@link BrokerFrame} from a setup payload
     * @param biConsumer called with the {@link BrokerInfo} and sending socket of a connecting broker
     * @param consumer called with the {@link BrokerInfo} of a broker whose connection closed
     */
    public BrokerSocketAcceptor(Id id, RoutingTable routingTable, RSocketIndex rSocketIndex, RoutingRSocketFactory routingRSocketFactory, Function<ConnectionSetupPayload, BrokerFrame> function, BiConsumer<BrokerInfo, RSocket> biConsumer, Consumer<BrokerInfo> consumer) {
        this.brokerId = id;
        this.routingTable = routingTable;
        this.rSocketIndex = rSocketIndex;
        this.routingRSocketFactory = routingRSocketFactory;
        this.payloadExtractor = function;
        this.brokerInfoConsumer = biConsumer;
        this.brokerInfoCleaner = consumer;
        logger.info("Starting Broker {}", id);
    }

    /**
     * Registers the connecting peer according to its setup frame and returns the socket that
     * will handle requests sent by that peer.
     *
     * @param connectionSetupPayload setup payload of the incoming connection
     * @param rSocket socket used to send requests to the connecting peer
     * @return the receiving socket, or an error {@link Mono} if the setup frame could not be
     *     extracted or is neither a {@link BrokerInfo} nor a {@link RouteSetup}
     */
    @Override
    public Mono<RSocket> accept(ConnectionSetupPayload connectionSetupPayload, RSocket rSocket) {
        try {
            BrokerFrame brokerFrame = this.payloadExtractor.apply(connectionSetupPayload);
            logger.debug("accept {}", brokerFrame);
            // Validate the frame type before doing any work for the connection.
            if (!(brokerFrame instanceof BrokerInfo) && !(brokerFrame instanceof RouteSetup)) {
                throw new IllegalStateException("RouteSetup not found in metadata");
            }
            Runnable cleanupAction = () -> cleanup(brokerFrame);
            RSocket sendingSocket = wrapSendingSocket(rSocket, brokerFrame);

            if (brokerFrame instanceof BrokerInfo) {
                // Broker-to-broker connection: registered eagerly, during accept().
                this.brokerInfoConsumer.accept((BrokerInfo) brokerFrame, sendingSocket);
                return createReceivingSocket(rSocket, cleanupAction);
            }

            RouteSetup routeSetup = (RouteSetup) brokerFrame;
            // Client connection: registration is deferred until the returned Mono is subscribed.
            return Mono.defer(() -> {
                RouteJoin routeJoin = toRouteJoin(routeSetup);
                this.rSocketIndex.put(routeJoin.getRouteId(), sendingSocket, routeJoin.getTags());
                this.routingTable.add(routeJoin);
                return createReceivingSocket(rSocket, cleanupAction);
            });
        } catch (Exception e) {
            logger.error("Error accepting setup", e);
            return Mono.error(e);
        }
    }

    /**
     * Creates the receiving socket for an accepted connection and arranges for
     * {@code onDisconnect} to run once either the receiving or the sending socket closes.
     */
    private Mono<RSocket> createReceivingSocket(RSocket sendingSocket, Runnable onDisconnect) {
        RSocket receivingSocket = this.routingRSocketFactory.create();
        // Whichever side closes first terminates the race; doFinally then runs the cleanup.
        Flux.first(receivingSocket.onClose(), sendingSocket.onClose())
                .doFinally(signalType -> onDisconnect.run())
                .subscribe();
        return Mono.just(receivingSocket);
    }

    /** Undoes the registration performed in {@link #accept} for the given setup frame. */
    private void cleanup(BrokerFrame brokerFrame) {
        if (brokerFrame instanceof BrokerInfo) {
            this.brokerInfoCleaner.accept((BrokerInfo) brokerFrame);
        } else if (brokerFrame instanceof RouteSetup) {
            Id routeId = ((RouteSetup) brokerFrame).getRouteId();
            this.routingTable.remove(routeId);
            this.rSocketIndex.remove(routeId);
        }
    }

    /**
     * Wraps the peer's sending socket in an {@link ErrorOnDisconnectRSocket} and logs when that
     * socket closes.
     */
    private RSocket wrapSendingSocket(RSocket rSocket, BrokerFrame brokerFrame) {
        ErrorOnDisconnectRSocket wrapped = new ErrorOnDisconnectRSocket(rSocket);
        // onClose() is lazy: without a subscription the doFinally callback would never fire.
        // The error consumer keeps an erroring close from being reported as an unhandled error;
        // the close itself is already logged by doFinally.
        wrapped.onClose()
                .doFinally(signalType -> logger.info("Closing socket for {}", brokerFrame))
                .subscribe(ignored -> { }, error -> logger.debug("Socket for {} closed with error", brokerFrame, error));
        return wrapped;
    }

    /**
     * Builds the {@link RouteJoin} announcing a client route on this broker. Besides the client's
     * own tags it always carries the route id and service name as well-known tags.
     */
    private RouteJoin toRouteJoin(RouteSetup routeSetup) {
        return RouteJoin.builder()
                .brokerId(this.brokerId)
                .routeId(routeSetup.getRouteId())
                .serviceName(routeSetup.getServiceName())
                .with(routeSetup.getTags())
                .with(WellKnownKey.ROUTE_ID, routeSetup.getRouteId().toString())
                .with(WellKnownKey.SERVICE_NAME, routeSetup.getServiceName())
                .build();
    }
}
