/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.neuralsearch.processor;

import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.function.BiConsumer;
import java.util.function.Consumer;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionType;
import org.opensearch.action.get.GetAction;
import org.opensearch.action.get.GetRequest;
import org.opensearch.action.get.GetResponse;
import org.opensearch.action.get.MultiGetAction;
import org.opensearch.action.get.MultiGetResponse;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.util.CollectionUtils;
import org.opensearch.env.Environment;
import org.opensearch.ingest.IngestDocument;
import org.opensearch.ingest.IngestDocumentWrapper;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
import org.opensearch.neuralsearch.processor.InferenceProcessor;
import org.opensearch.neuralsearch.processor.TextInferenceRequest;
import org.opensearch.neuralsearch.processor.optimization.InferenceFilter;
import org.opensearch.neuralsearch.processor.optimization.TextEmbeddingInferenceFilter;
import org.opensearch.neuralsearch.stats.events.EventStatName;
import org.opensearch.neuralsearch.stats.events.EventStatsManager;
import org.opensearch.transport.client.OpenSearchClient;

public final class TextEmbeddingProcessor
extends InferenceProcessor {
    @Generated
    private static final Logger log = LogManager.getLogger(TextEmbeddingProcessor.class);
    public static final String TYPE = "text_embedding";
    public static final String LIST_TYPE_NESTED_MAP_KEY = "knn";
    private final OpenSearchClient openSearchClient;
    private final boolean skipExisting;
    private final TextEmbeddingInferenceFilter textEmbeddingInferenceFilter;

    public TextEmbeddingProcessor(String tag, String description, int batchSize, String modelId, Map<String, Object> fieldMap, boolean skipExisting, TextEmbeddingInferenceFilter textEmbeddingInferenceFilter, OpenSearchClient openSearchClient, MLCommonsClientAccessor clientAccessor, Environment environment, ClusterService clusterService) {
        super(tag, description, batchSize, TYPE, LIST_TYPE_NESTED_MAP_KEY, modelId, fieldMap, clientAccessor, environment, clusterService);
        this.skipExisting = skipExisting;
        this.textEmbeddingInferenceFilter = textEmbeddingInferenceFilter;
        this.openSearchClient = openSearchClient;
    }

    @Override
    public void doExecute(IngestDocument ingestDocument, Map<String, Object> processMap, List<String> inferenceList, BiConsumer<IngestDocument, Exception> handler) {
        EventStatsManager.increment(EventStatName.TEXT_EMBEDDING_PROCESSOR_EXECUTIONS);
        if (!this.skipExisting) {
            this.generateAndSetInference(ingestDocument, processMap, inferenceList, handler);
            return;
        }
        EventStatsManager.increment(EventStatName.SKIP_EXISTING_EXECUTIONS);
        Object index = ingestDocument.getSourceAndMetadata().get("_index");
        Object id = ingestDocument.getSourceAndMetadata().get("_id");
        if (Objects.isNull(index) || Objects.isNull(id)) {
            this.generateAndSetInference(ingestDocument, processMap, inferenceList, handler);
            return;
        }
        this.openSearchClient.execute((ActionType)GetAction.INSTANCE, (ActionRequest)new GetRequest(index.toString(), id.toString()), ActionListener.wrap(response -> this.reuseOrGenerateEmbedding((GetResponse)response, ingestDocument, processMap, inferenceList, handler, (InferenceFilter)this.textEmbeddingInferenceFilter), e -> handler.accept((IngestDocument)null, (Exception)e)));
    }

    @Override
    public void doBatchExecute(List<String> inferenceList, Consumer<List<?>> handler, Consumer<Exception> onException) {
        this.mlCommonsClientAccessor.inferenceSentences((TextInferenceRequest)((TextInferenceRequest.TextInferenceRequestBuilder)((TextInferenceRequest.TextInferenceRequestBuilder)TextInferenceRequest.builder().modelId(this.modelId)).inputTexts(inferenceList)).build(), (ActionListener<List<List<Number>>>)ActionListener.wrap(handler::accept, onException));
    }

    @Override
    public void subBatchExecute(List<IngestDocumentWrapper> ingestDocumentWrappers, Consumer<List<IngestDocumentWrapper>> handler) {
        try {
            EventStatsManager.increment(EventStatName.TEXT_EMBEDDING_PROCESSOR_EXECUTIONS);
            if (CollectionUtils.isEmpty(ingestDocumentWrappers)) {
                handler.accept(ingestDocumentWrappers);
                return;
            }
            List<InferenceProcessor.DataForInference> dataForInferences = this.getDataForInference(ingestDocumentWrappers);
            List<String> inferenceList = this.constructInferenceTexts(dataForInferences);
            if (inferenceList.isEmpty()) {
                handler.accept(ingestDocumentWrappers);
                return;
            }
            if (!this.skipExisting) {
                this.doSubBatchExecute(ingestDocumentWrappers, inferenceList, dataForInferences, handler);
                return;
            }
            EventStatsManager.increment(EventStatName.SKIP_EXISTING_EXECUTIONS);
            this.openSearchClient.execute((ActionType)MultiGetAction.INSTANCE, (ActionRequest)this.buildMultiGetRequest(dataForInferences), ActionListener.wrap(response -> this.reuseOrGenerateEmbedding((MultiGetResponse)response, ingestDocumentWrappers, inferenceList, dataForInferences, handler, (InferenceFilter)this.textEmbeddingInferenceFilter), e -> this.updateWithExceptions(this.getIngestDocumentWrappers(dataForInferences), handler, (Exception)e)));
        }
        catch (Exception e2) {
            this.updateWithExceptions(ingestDocumentWrappers, handler, e2);
        }
    }
}

