/*
 * Decompiled with CFR 0.152.
 */
package org.springframework.ai.openai.api;

import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.DeserializationFeature;
import com.fasterxml.jackson.databind.ObjectMapper;
import java.io.IOException;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.function.Consumer;
import java.util.function.Predicate;
import org.reactivestreams.Publisher;
import org.springframework.boot.context.properties.bind.ConstructorBinding;
import org.springframework.core.ParameterizedTypeReference;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import org.springframework.http.client.ClientHttpResponse;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.web.client.ResponseErrorHandler;
import org.springframework.web.client.RestClient;
import org.springframework.web.reactive.function.client.WebClient;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

public class OpenAiApi {
    /** Base URL used by the single-argument constructor. */
    private static final String DEFAULT_BASE_URL = "https://api.openai.com";

    /** Chat model identifier offered to callers as a default. */
    public static final String DEFAULT_CHAT_MODEL = "gpt-3.5-turbo";

    /** Embedding model identifier offered to callers as a default. */
    public static final String DEFAULT_EMBEDDING_MODEL = "text-embedding-ada-002";

    /** Recognizes the {@code [DONE]} marker element of a streamed response. */
    private static final Predicate<String> SSE_DONE_PREDICATE = chunk -> "[DONE]".equals(chunk);

    /** Blocking HTTP client, configured in the constructor. */
    private final RestClient restClient;

    /** Reactive HTTP client, configured in the constructor. */
    private final WebClient webClient;

    /** JSON mapper that tolerates properties unknown to the target type. */
    private final ObjectMapper objectMapper =
            new ObjectMapper().disable(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES);

    /**
     * Creates a client that talks to {@code DEFAULT_BASE_URL}.
     *
     * @param openAiToken token passed on to the two-argument constructor
     */
    public OpenAiApi(String openAiToken) {
        this(DEFAULT_BASE_URL, openAiToken);
    }

    /**
     * Creates a client for the given base URL using a fresh {@code RestClient.builder()}.
     *
     * @param baseUrl base URL passed on to the three-argument constructor
     * @param openAiToken token passed on to the three-argument constructor
     */
    public OpenAiApi(String baseUrl, String openAiToken) {
        this(baseUrl, openAiToken, RestClient.builder());
    }

    /**
     * Creates a client whose blocking and reactive HTTP clients share the same base URL and
     * default headers (bearer authentication plus a JSON content type).
     *
     * <p>The blocking client additionally gets a status handler that turns error responses
     * into exceptions: {@link OpenAiApiClientErrorException} for 4xx statuses and
     * {@link OpenAiApiException} for every other error status. The reactive client is built
     * without that handler.
     *
     * @param baseUrl base URL applied to both HTTP clients
     * @param openAiToken token sent as the bearer credential
     * @param restClientBuilder builder from which the blocking client is created
     */
    public OpenAiApi(String baseUrl, String openAiToken, RestClient.Builder restClientBuilder) {
        Consumer<HttpHeaders> defaultHeaders = httpHeaders -> {
            httpHeaders.setBearerAuth(openAiToken);
            httpHeaders.setContentType(MediaType.APPLICATION_JSON);
        };

        ResponseErrorHandler errorHandler = new ResponseErrorHandler() {

            @Override
            public boolean hasError(ClientHttpResponse response) throws IOException {
                return response.getStatusCode().isError();
            }

            @Override
            public void handleError(ClientHttpResponse response) throws IOException {
                var status = response.getStatusCode();
                if (!status.isError()) {
                    return;
                }
                // The response body is deserialized into the structured error payload
                // and appended to the numeric status.
                String message = String.format("%s - %s", status.value(),
                        OpenAiApi.this.objectMapper.readValue(response.getBody(), ResponseError.class));
                if (status.is4xxClientError()) {
                    throw new OpenAiApiClientErrorException(message);
                }
                throw new OpenAiApiException(message);
            }
        };

        this.restClient = restClientBuilder
            .baseUrl(baseUrl)
            .defaultHeaders(defaultHeaders)
            .defaultStatusHandler(errorHandler)
            .build();
        this.webClient = WebClient.builder()
            .baseUrl(baseUrl)
            .defaultHeaders(defaultHeaders)
            .build();
    }

    /**
     * Posts a non-streaming chat completion request to {@code /v1/chat/completions} using the
     * blocking client and returns the full response entity.
     *
     * @param chatRequest request body; must not be {@code null} and must not have
     * {@code stream} set to {@code true}
     * @return the response entity with the body bound to {@link ChatCompletion}
     * @throws IllegalArgumentException if the request is {@code null} or asks for streaming
     */
    public ResponseEntity<ChatCompletion> chatCompletionEntity(ChatCompletionRequest chatRequest) {
        Assert.notNull(chatRequest, "The request body can not be null.");
        // stream is a nullable Boolean: compare against TRUE instead of unboxing so that a
        // request that leaves the flag unset is accepted rather than failing with an NPE.
        Assert.isTrue(!Boolean.TRUE.equals(chatRequest.stream()), "Request must set the stream property to false.");

        return this.restClient.post()
            .uri("/v1/chat/completions")
            .body(chatRequest)
            .retrieve()
            .toEntity(ChatCompletion.class);
    }

    /**
     * Posts a streaming chat completion request to {@code /v1/chat/completions} using the
     * reactive client and exposes the response as a stream of parsed chunks.
     *
     * <p>The response is read as strings; the stream ends at the first {@code [DONE]}
     * element, which is dropped, and every other element is parsed into a
     * {@link ChatCompletionChunk}.
     *
     * @param chatRequest request body; must not be {@code null} and must have {@code stream}
     * set to {@code true}
     * @return a flux of parsed chunks
     * @throws IllegalArgumentException if the request is {@code null} or does not ask for
     * streaming
     */
    public Flux<ChatCompletionChunk> chatCompletionStream(ChatCompletionRequest chatRequest) {
        Assert.notNull(chatRequest, "The request body can not be null.");
        // stream is a nullable Boolean: an unset flag must fail the assertion, not throw an
        // NPE while being unboxed.
        Assert.isTrue(Boolean.TRUE.equals(chatRequest.stream()), "Request must set the stream property to true.");

        return this.webClient.post()
            .uri("/v1/chat/completions")
            .body(Mono.just(chatRequest), ChatCompletionRequest.class)
            .retrieve()
            .bodyToFlux(String.class)
            // takeUntil still emits the matching element, so it is filtered out afterwards.
            .takeUntil(SSE_DONE_PREDICATE)
            .filter(SSE_DONE_PREDICATE.negate())
            .map(content -> this.parseJson(content, ChatCompletionChunk.class));
    }

    /**
     * Validates an embedding request and posts it to {@code /v1/embeddings} using the
     * blocking client.
     *
     * <p>The input must be a {@code String} or a {@code List}. A list must be non-empty,
     * hold at most 2048 elements, and its first element must be a {@code String}, an
     * {@code Integer} or a {@code List}; only the first element is inspected.
     *
     * @param embeddingRequest request body; must not be {@code null}
     * @param <T> type of the request input
     * @return the response entity with the body bound to a list of {@link Embedding}
     * @throws IllegalArgumentException if any of the checks above fails
     */
    public <T> ResponseEntity<EmbeddingList<Embedding>> embeddings(EmbeddingRequest<T> embeddingRequest) {
        Assert.notNull(embeddingRequest, "The request body can not be null.");

        T input = embeddingRequest.input();
        Assert.notNull(input, "The input can not be null.");
        Assert.isTrue(input instanceof String || input instanceof List,
                "The input must be either a String, or a List of Strings or List of List of integers.");

        if (input instanceof List<?> items) {
            Assert.isTrue(!CollectionUtils.isEmpty(items), "The input list can not be empty.");
            Assert.isTrue(items.size() <= 2048, "The list must be 2048 dimensions or less");
            Object first = items.get(0);
            Assert.isTrue(first instanceof String || first instanceof Integer || first instanceof List,
                    "The input must be either a String, or a List of Strings or list of list of integers.");
        }

        return this.restClient.post()
            .uri("/v1/embeddings")
            .body(embeddingRequest)
            .retrieve()
            .toEntity(new ParameterizedTypeReference<EmbeddingList<Embedding>>() {});
    }

    /**
     * Mapper shared by every {@link #parseJson(String)} call so that a new one is not
     * built per invocation. It is never reconfigured after creation.
     */
    private static final ObjectMapper JSON_SCHEMA_MAPPER = new ObjectMapper();

    /**
     * Parses a JSON document into a map keyed by property name.
     *
     * @param jsonSchema JSON text to parse
     * @return the parsed content as a map
     * @throws OpenAiApiException if the text cannot be parsed; the original failure is kept
     * as the cause
     */
    public static Map<String, Object> parseJson(String jsonSchema) {
        try {
            return JSON_SCHEMA_MAPPER.readValue(jsonSchema, new TypeReference<Map<String, Object>>() {});
        }
        catch (Exception e) {
            throw new OpenAiApiException("Failed to parse schema: " + jsonSchema, e);
        }
    }

    /**
     * Deserializes JSON text into an instance of the given type using this client's mapper.
     *
     * @param json JSON text to parse
     * @param type target type
     * @param <T> target type
     * @return the deserialized value
     * @throws OpenAiApiException if the text cannot be bound to {@code type}; the original
     * failure is kept as the cause
     */
    private <T> T parseJson(String json, Class<T> type) {
        try {
            return this.objectMapper.readValue(json, type);
        }
        catch (Exception e) {
            // This overload handles arbitrary payloads, not schemas, so the message says so.
            throw new OpenAiApiException("Failed to parse JSON: " + json, e);
        }
    }

    /**
     * Chat completion request body. Components that are {@code null} are left out of the
     * serialized JSON; each component is written under the name given by its
     * {@code @JsonProperty}.
     */
    @JsonInclude(JsonInclude.Include.NON_NULL)
    public record ChatCompletionRequest(
            @JsonProperty("messages") List<ChatCompletionMessage> messages,
            @JsonProperty("model") String model,
            @JsonProperty("frequency_penalty") Float frequencyPenalty,
            @JsonProperty("logit_bias") Map<String, Integer> logitBias,
            @JsonProperty("max_tokens") Integer maxTokens,
            @JsonProperty("n") Integer n,
            @JsonProperty("presence_penalty") Float presencePenalty,
            @JsonProperty("response_format") ResponseFormat responseFormat,
            @JsonProperty("seed") Integer seed,
            @JsonProperty("stop") List<String> stop,
            @JsonProperty("stream") Boolean stream,
            @JsonProperty("temperature") Float temperature,
            @JsonProperty("top_p") Float topP,
            @JsonProperty("tools") List<FunctionTool> tools,
            @JsonProperty("tool_choice") ToolChoice toolChoice,
            @JsonProperty("user") String user) {

        /**
         * Non-streaming request with zero frequency and presence penalties and {@code n = 1}.
         *
         * @param messages conversation messages
         * @param model model identifier
         * @param temperature sampling temperature
         */
        public ChatCompletionRequest(List<ChatCompletionMessage> messages, String model, Float temperature) {
            this(messages, model, 0.0f, null, null, 1, 0.0f, null, null, null, false, temperature,
                    null, null, null, null);
        }

        /**
         * Same defaults as the three-argument form, with an explicit streaming flag.
         *
         * @param messages conversation messages
         * @param model model identifier
         * @param temperature sampling temperature
         * @param stream whether the response should be streamed
         */
        public ChatCompletionRequest(List<ChatCompletionMessage> messages, String model, Float temperature,
                boolean stream) {
            this(messages, model, 0.0f, null, null, 1, 0.0f, null, null, null, stream, temperature,
                    null, null, null, null);
        }

        /**
         * Non-streaming request that carries tools, with temperature {@code 0.8}, zero
         * penalties and {@code n = 1}.
         *
         * @param messages conversation messages
         * @param model model identifier
         * @param tools tools offered to the model
         * @param toolChoice tool selection
         */
        public ChatCompletionRequest(List<ChatCompletionMessage> messages, String model, List<FunctionTool> tools,
                ToolChoice toolChoice) {
            this(messages, model, 0.0f, null, null, 1, 0.0f, null, null, null, false, 0.8f,
                    null, tools, toolChoice, null);
        }

        /**
         * Request that sets only the messages and the streaming flag; every other component
         * is {@code null}.
         *
         * @param messages conversation messages
         * @param stream streaming flag, may be {@code null}
         */
        public ChatCompletionRequest(List<ChatCompletionMessage> messages, Boolean stream) {
            this(messages, null, null, null, null, null, null, null, null, null, stream, null,
                    null, null, null, null);
        }

        /** Value of the {@code response_format} property. */
        @JsonInclude(JsonInclude.Include.NON_NULL)
        public record ResponseFormat(@JsonProperty("type") String type) {
        }

        /** Value of the {@code tool_choice} property. */
        @JsonInclude(JsonInclude.Include.NON_NULL)
        public record ToolChoice(
                @JsonProperty("type") String type,
                @JsonProperty("function") Map<String, String> function) {

            /**
             * Builds a choice of type {@code "function"} whose function map holds the given
             * name under the key {@code "name"}.
             *
             * @param functionName function name
             */
            @ConstructorBinding
            public ToolChoice(String functionName) {
                this("function", Map.of("name", functionName));
            }
        }
    }

    /**
     * Response body bound by {@code chatCompletionEntity}. Components map to the JSON
     * properties named by their {@code @JsonProperty}.
     */
    @JsonInclude(JsonInclude.Include.NON_NULL)
    public record ChatCompletion(
            @JsonProperty("id") String id,
            @JsonProperty("choices") List<Choice> choices,
            @JsonProperty("created") Long created,
            @JsonProperty("model") String model,
            @JsonProperty("system_fingerprint") String systemFingerprint,
            @JsonProperty("object") String object,
            @JsonProperty("usage") Usage usage) {

        /** One entry of the {@code choices} list. */
        @JsonInclude(JsonInclude.Include.NON_NULL)
        public record Choice(
                @JsonProperty("finish_reason") ChatCompletionFinishReason finishReason,
                @JsonProperty("index") Integer index,
                @JsonProperty("message") ChatCompletionMessage message,
                @JsonProperty("logprobs") LogProbs logprobs) {
        }
    }

    /**
     * Embedding request body. Components that are {@code null} are left out of the
     * serialized JSON.
     *
     * @param <T> type of the {@code input} value
     */
    @JsonInclude(JsonInclude.Include.NON_NULL)
    public record EmbeddingRequest<T>(
            @JsonProperty("input") T input,
            @JsonProperty("model") String model,
            @JsonProperty("encoding_format") String encodingFormat,
            @JsonProperty("user") String user) {

        /**
         * Request with encoding format {@code "float"} and no user.
         *
         * @param input value to embed
         * @param model model identifier
         */
        public EmbeddingRequest(T input, String model) {
            this(input, model, "float", null);
        }

        /**
         * Request that uses {@code OpenAiApi.DEFAULT_EMBEDDING_MODEL} as the model.
         *
         * @param input value to embed
         */
        public EmbeddingRequest(T input) {
            this(input, OpenAiApi.DEFAULT_EMBEDDING_MODEL);
        }
    }

    /**
     * Unchecked exception used by this client for JSON parsing failures and for error
     * responses that are not 4xx.
     */
    public static class OpenAiApiException extends RuntimeException {

        /**
         * @param message description of the failure
         */
        public OpenAiApiException(String message) {
            super(message);
        }

        /**
         * @param message description of the failure
         * @param cause underlying failure
         */
        public OpenAiApiException(String message, Throwable cause) {
            super(message, cause);
        }
    }

    /**
     * One element of the stream produced by {@code chatCompletionStream}. Components map to
     * the JSON properties named by their {@code @JsonProperty}.
     */
    @JsonInclude(JsonInclude.Include.NON_NULL)
    public record ChatCompletionChunk(
            @JsonProperty("id") String id,
            @JsonProperty("choices") List<ChunkChoice> choices,
            @JsonProperty("created") Long created,
            @JsonProperty("model") String model,
            @JsonProperty("system_fingerprint") String systemFingerprint,
            @JsonProperty("object") String object) {

        /** One entry of the {@code choices} list; the message content is under {@code delta}. */
        @JsonInclude(JsonInclude.Include.NON_NULL)
        public record ChunkChoice(
                @JsonProperty("finish_reason") ChatCompletionFinishReason finishReason,
                @JsonProperty("index") Integer index,
                @JsonProperty("delta") ChatCompletionMessage delta,
                @JsonProperty("logprobs") LogProbs logprobs) {
        }
    }

    /**
     * Response body bound by {@code embeddings}: a {@code data} list plus metadata.
     *
     * @param <T> element type of {@code data}
     */
    @JsonInclude(JsonInclude.Include.NON_NULL)
    public record EmbeddingList<T>(
            @JsonProperty("object") String object,
            @JsonProperty("data") List<T> data,
            @JsonProperty("model") String model,
            @JsonProperty("usage") Usage usage) {
    }

    /** A single embedding vector together with its index. */
    @JsonInclude(JsonInclude.Include.NON_NULL)
    public record Embedding(
            @JsonProperty("index") Integer index,
            @JsonProperty("embedding") List<Double> embedding,
            @JsonProperty("object") String object) {

        /**
         * Creates an embedding whose {@code object} is {@code "embedding"}.
         *
         * @param index index of the embedding
         * @param embedding vector values
         */
        public Embedding(Integer index, List<Double> embedding) {
            this(index, embedding, "embedding");
        }
    }

    /** Token counts carried in the {@code usage} property of a response. */
    @JsonInclude(JsonInclude.Include.NON_NULL)
    public record Usage(
            @JsonProperty("completion_tokens") Integer completionTokens,
            @JsonProperty("prompt_tokens") Integer promptTokens,
            @JsonProperty("total_tokens") Integer totalTokens) {
    }

    /** Value of the {@code logprobs} property of a choice. */
    @JsonInclude(JsonInclude.Include.NON_NULL)
    public record LogProbs(@JsonProperty("content") List<Content> content) {

        /** One entry of the {@code content} list; the JSON {@code bytes} property maps to {@code probBytes}. */
        @JsonInclude(JsonInclude.Include.NON_NULL)
        public record Content(
                @JsonProperty("token") String token,
                @JsonProperty("logprob") Float logprob,
                @JsonProperty("bytes") List<Integer> probBytes,
                @JsonProperty("top_logprobs") List<TopLogProbs> topLogprobs) {

            /** One entry of the {@code top_logprobs} list. */
            @JsonInclude(JsonInclude.Include.NON_NULL)
            public record TopLogProbs(
                    @JsonProperty("token") String token,
                    @JsonProperty("logprob") Float logprob,
                    @JsonProperty("bytes") List<Integer> probBytes) {
            }
        }
    }

    public static enum ChatCompletionFinishReason {
        STOP,
        LENGTH,
        CONTENT_FILTER,
        TOOL_CALLS,
        FUNCTION_CALL;

    }

    /**
     * A chat message, used in requests and in the {@code message} / {@code delta} property
     * of response choices. Components that are {@code null} are left out of the serialized
     * JSON.
     */
    @JsonInclude(JsonInclude.Include.NON_NULL)
    public record ChatCompletionMessage(
            @JsonProperty("content") String content,
            @JsonProperty("role") Role role,
            @JsonProperty("name") String name,
            @JsonProperty("tool_call_id") String toolCallId,
            @JsonProperty("tool_calls") List<ToolCall> toolCalls) {

        /**
         * Creates a message that has only content and a role.
         *
         * @param content message text
         * @param role author role
         */
        public ChatCompletionMessage(String content, Role role) {
            this(content, role, null, null, null);
        }

        /**
         * Author role of a message.
         *
         * <p>Each constant declares its lower-case JSON name explicitly; without it the
         * constants would be written and read by their upper-case Java names.
         */
        public enum Role {
            @JsonProperty("system")
            SYSTEM,
            @JsonProperty("user")
            USER,
            @JsonProperty("assistant")
            ASSISTANT,
            @JsonProperty("tool")
            TOOL
        }

        /** Function name and its arguments as carried by a tool call. */
        @JsonInclude(JsonInclude.Include.NON_NULL)
        public record ChatCompletionFunction(
                @JsonProperty("name") String name,
                @JsonProperty("arguments") String arguments) {
        }

        /** One entry of the {@code tool_calls} list. */
        @JsonInclude(JsonInclude.Include.NON_NULL)
        public record ToolCall(
                @JsonProperty("id") String id,
                @JsonProperty("type") String type,
                @JsonProperty("function") ChatCompletionFunction function) {
        }
    }

    /**
     * One entry of a request's {@code tools} list: a tool type plus a function definition.
     */
    @JsonInclude(JsonInclude.Include.NON_NULL)
    public record FunctionTool(
            @JsonProperty("type") Type type,
            @JsonProperty("function") Function function) {

        /**
         * Creates a tool of type {@link Type#FUNCTION}.
         *
         * @param function function definition
         */
        @ConstructorBinding
        public FunctionTool(Function function) {
            this(Type.FUNCTION, function);
        }

        /**
         * Tool type.
         *
         * <p>The constant declares its lower-case JSON name explicitly; without it the value
         * would be written as {@code "FUNCTION"}.
         */
        public enum Type {
            @JsonProperty("function")
            FUNCTION
        }

        /** Function definition: description, name and a parameters map. */
        public record Function(
                @JsonProperty("description") String description,
                @JsonProperty("name") String name,
                @JsonProperty("parameters") Map<String, Object> parameters) {

            /**
             * Creates a definition whose parameters are parsed from JSON text with
             * {@code OpenAiApi.parseJson}.
             *
             * @param description function description
             * @param name function name
             * @param jsonSchema JSON text describing the parameters
             */
            @ConstructorBinding
            public Function(String description, String name, String jsonSchema) {
                this(description, name, OpenAiApi.parseJson(jsonSchema));
            }
        }
    }

    @JsonInclude(value=JsonInclude.Include.NON_NULL)
    public record ResponseError(@JsonProperty(value="error") Error error) {

        @JsonInclude(value=JsonInclude.Include.NON_NULL)
        public record Error(@JsonProperty(value="message") String message, @JsonProperty(value="type") String type, @JsonProperty(value="param") String param, @JsonProperty(value="code") String code) {
        }
    }

    /**
     * Unchecked exception raised by the blocking client's status handler for 4xx responses.
     */
    public static class OpenAiApiClientErrorException extends RuntimeException {

        /**
         * @param message description of the failure
         */
        public OpenAiApiClientErrorException(String message) {
            super(message);
        }

        /**
         * @param message description of the failure
         * @param cause underlying failure
         */
        public OpenAiApiClientErrorException(String message, Throwable cause) {
            super(message, cause);
        }
    }
}

