/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.searchrelevance.ml;

import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.TimeUnit;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.core.action.ActionListener;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.searchrelevance.ml.ChunkProcessingContext;
import org.opensearch.searchrelevance.ml.ChunkResult;
import org.opensearch.searchrelevance.ml.MLInputOutputTransformer;

public class MLAccessor {
    @Generated
    private static final Logger log = LogManager.getLogger(MLAccessor.class);
    private final MachineLearningNodeClient mlClient;
    private final MLInputOutputTransformer transformer;
    private static final int MAX_RETRY_NUMBER = 3;
    private static final long RETRY_DELAY_MS = 1000L;

    public MLAccessor(MachineLearningNodeClient mlClient) {
        this.mlClient = mlClient;
        this.transformer = new MLInputOutputTransformer();
    }

    public void predict(String modelId, int tokenLimit, String searchText, String reference, Map<String, String> hits, ActionListener<ChunkResult> progressListener) {
        List<MLInput> mlInputs = this.transformer.createMLInputs(tokenLimit, searchText, reference, hits);
        log.info("Number of chunks: {}", (Object)mlInputs.size());
        ChunkProcessingContext context = new ChunkProcessingContext(mlInputs.size(), progressListener);
        for (int i = 0; i < mlInputs.size(); ++i) {
            this.processChunk(modelId, mlInputs.get(i), i, context);
        }
    }

    private void processChunk(String modelId, MLInput mlInput, int chunkIndex, ChunkProcessingContext context) {
        this.predictSingleChunkWithRetry(modelId, mlInput, chunkIndex, 0, (ActionListener<String>)ActionListener.wrap(response -> {
            log.info("Chunk {} processed successfully", (Object)chunkIndex);
            String processedResponse = this.cleanResponse((String)response);
            context.handleSuccess(chunkIndex, processedResponse);
        }, e -> {
            log.error("Chunk {} failed after all retries", (Object)chunkIndex, e);
            context.handleFailure(chunkIndex, (Exception)e);
        }));
    }

    private String cleanResponse(String response) {
        return response.substring(1, response.length() - 1);
    }

    private void predictSingleChunkWithRetry(final String modelId, final MLInput mlInput, final int chunkIndex, final int retryCount, final ActionListener<String> chunkListener) {
        this.predictSingleChunk(modelId, mlInput, new ActionListener<String>(){
            final /* synthetic */ MLAccessor this$0;
            {
                this.this$0 = this$0;
            }

            public void onResponse(String response) {
                chunkListener.onResponse((Object)response);
            }

            public void onFailure(Exception e) {
                if (retryCount < 3) {
                    log.warn("Chunk {} failed, attempt {}/{}. Retrying...", (Object)chunkIndex, (Object)(retryCount + 1), (Object)3);
                    long delay = 1000L * (long)Math.pow(2.0, retryCount);
                    this.this$0.scheduleRetry(() -> this.this$0.predictSingleChunkWithRetry(modelId, mlInput, chunkIndex, retryCount + 1, (ActionListener<String>)chunkListener), delay);
                } else {
                    chunkListener.onFailure(e);
                }
            }
        });
    }

    private void scheduleRetry(Runnable runnable, long delayMs) {
        CompletableFuture.delayedExecutor(delayMs, TimeUnit.MILLISECONDS).execute(runnable);
    }

    public void predictSingleChunk(String modelId, MLInput mlInput, ActionListener<String> listener) {
        this.mlClient.predict(modelId, mlInput, ActionListener.wrap(mlOutput -> listener.onResponse((Object)this.transformer.extractResponseContent((MLOutput)mlOutput)), arg_0 -> listener.onFailure(arg_0)));
    }
}

