/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.action.agents;

import java.time.Instant;
import java.util.HashMap;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.OpenSearchException;
import org.opensearch.OpenSearchStatusException;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionType;
import org.opensearch.action.index.IndexResponse;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.commons.authuser.User;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.core.xcontent.ToXContentObject;
import org.opensearch.ml.action.agent.MLAgentRegistrationValidator;
import org.opensearch.ml.action.contextmanagement.ContextManagementTemplateService;
import org.opensearch.ml.common.MLAgentType;
import org.opensearch.ml.common.MLMemoryType;
import org.opensearch.ml.common.agent.AgentModelService;
import org.opensearch.ml.common.agent.LLMSpec;
import org.opensearch.ml.common.agent.MLAgent;
import org.opensearch.ml.common.agent.MLAgentModelSpec;
import org.opensearch.ml.common.settings.MLCommonsSettings;
import org.opensearch.ml.common.settings.MLFeatureEnabledSetting;
import org.opensearch.ml.common.transport.agent.MLRegisterAgentRequest;
import org.opensearch.ml.common.transport.agent.MLRegisterAgentResponse;
import org.opensearch.ml.common.transport.register.MLRegisterModelAction;
import org.opensearch.ml.common.transport.register.MLRegisterModelInput;
import org.opensearch.ml.common.transport.register.MLRegisterModelRequest;
import org.opensearch.ml.engine.function_calling.FunctionCallingFactory;
import org.opensearch.ml.engine.indices.MLIndicesHandler;
import org.opensearch.ml.utils.RestActionUtils;
import org.opensearch.ml.utils.TenantAwareHelper;
import org.opensearch.remote.metadata.client.PutDataObjectRequest;
import org.opensearch.remote.metadata.client.SdkClient;
import org.opensearch.remote.metadata.common.SdkClientUtils;
import org.opensearch.tasks.Task;
import org.opensearch.transport.TransportService;
import org.opensearch.transport.client.Client;

/**
 * Transport action that handles agent registration requests.
 *
 * <p>The flow visible in this class is:
 * <ol>
 *   <li>reject the request if it asks for remote agentic memory while that feature is disabled;</li>
 *   <li>if the agent carries a model spec, register that model first and attach the resulting
 *       model id to the agent as its LLM;</li>
 *   <li>validate the agent's context management configuration (template or inline);</li>
 *   <li>validate the {@code mcp_connectors} and {@code _llm_interface} parameters and the tenant id;</li>
 *   <li>for a plan-execute-and-reflect agent without an {@code executor_agent_id}, first index a
 *       companion conversational agent and record its id on the main agent;</li>
 *   <li>index the agent document and answer with its id.</li>
 * </ol>
 *
 * <p>All outcomes are reported through the supplied {@link ActionListener}; the private methods
 * here do not return results directly.
 */
public class TransportRegisterAgentAction extends HandledTransportAction<ActionRequest, MLRegisterAgentResponse> {
    @Generated
    private static final Logger log = LogManager.getLogger(TransportRegisterAgentAction.class);

    /** Index that agent documents are written to. */
    private static final String AGENT_INDEX = ".plugins-ml-agent";
    /** Agent parameter carrying the MCP connector configuration. */
    private static final String MCP_CONNECTORS_PARAM = "mcp_connectors";
    /** Agent parameter naming the LLM interface used for function calling. */
    private static final String LLM_INTERFACE_PARAM = "_llm_interface";
    /** Agent parameter holding the id of the executor agent of a plan-execute-and-reflect agent. */
    private static final String EXECUTOR_AGENT_ID_PARAM = "executor_agent_id";

    MLIndicesHandler mlIndicesHandler;
    Client client;
    SdkClient sdkClient;
    ClusterService clusterService;
    private final MLFeatureEnabledSetting mlFeatureEnabledSetting;
    private final MLAgentRegistrationValidator agentRegistrationValidator;

    /**
     * Creates the action and wires its collaborators.
     *
     * @param transportService transport service passed to the parent action
     * @param actionFilters action filters passed to the parent action
     * @param client client used to register models and to access the thread context
     * @param sdkClient client used to write the agent document
     * @param mlIndicesHandler handler used to initialise the agent index
     * @param clusterService cluster service used for the super-admin check
     * @param mlFeatureEnabledSetting feature flags consulted during registration
     * @param contextManagementTemplateService service wrapped by the registration validator
     */
    @Inject
    public TransportRegisterAgentAction(TransportService transportService, ActionFilters actionFilters, Client client, SdkClient sdkClient, MLIndicesHandler mlIndicesHandler, ClusterService clusterService, MLFeatureEnabledSetting mlFeatureEnabledSetting, ContextManagementTemplateService contextManagementTemplateService) {
        super("cluster:admin/opensearch/ml/agents/register", transportService, actionFilters, MLRegisterAgentRequest::new);
        this.client = client;
        this.sdkClient = sdkClient;
        this.mlIndicesHandler = mlIndicesHandler;
        this.clusterService = clusterService;
        this.mlFeatureEnabledSetting = mlFeatureEnabledSetting;
        this.agentRegistrationValidator = new MLAgentRegistrationValidator(contextManagementTemplateService);
    }

    /**
     * Entry point: checks the remote-agentic-memory feature gate, then registers the agent,
     * creating its model first when a model spec is present.
     *
     * @param task the task this execution belongs to (not used by this method)
     * @param request the incoming request, converted to an {@link MLRegisterAgentRequest}
     * @param listener receives the registration response or the failure
     */
    protected void doExecute(Task task, ActionRequest request, ActionListener<MLRegisterAgentResponse> listener) {
        // The user context is resolved here but is not consulted anywhere else in this method.
        // NOTE(review): kept as-is because it is unclear from this file whether the lookup has
        // side effects - confirm before removing.
        User user = RestActionUtils.getUserContext(this.client);
        MLRegisterAgentRequest registerAgentRequest = MLRegisterAgentRequest.fromActionRequest(request);
        MLAgent mlAgent = registerAgentRequest.getMlAgent();

        if (mlAgent.getMemory() != null
            && MLMemoryType.REMOTE_AGENTIC_MEMORY.name().equalsIgnoreCase(mlAgent.getMemory().getType())
            && !this.mlFeatureEnabledSetting.isRemoteAgenticMemoryEnabled()) {
            listener.onFailure(new OpenSearchStatusException(MLCommonsSettings.ML_COMMONS_REMOTE_AGENTIC_MEMORY_DISABLED_MESSAGE, RestStatus.FORBIDDEN));
            return;
        }

        if (mlAgent.getModel() != null) {
            this.createModelAndRegisterAgent(mlAgent, listener);
            return;
        }
        this.registerAgent(mlAgent, listener);
    }

    /**
     * Registers the model described by the agent's model spec and then registers the agent with
     * an LLM spec pointing at the newly created model.
     *
     * <p>The agent that is stored carries a copy of the model spec with its model parameters and
     * credential set to {@code null}; the model parameters are moved onto the LLM spec instead.
     *
     * @param mlAgent agent whose {@code getModel()} is non-null
     * @param listener receives the registration response or the failure
     */
    private void createModelAndRegisterAgent(MLAgent mlAgent, ActionListener<MLRegisterAgentResponse> listener) {
        try {
            MLRegisterModelInput modelInput = AgentModelService.createModelFromSpec(mlAgent.getModel());
            MLRegisterModelRequest modelRequest = new MLRegisterModelRequest(modelInput);
            this.client.execute(MLRegisterModelAction.INSTANCE, modelRequest, ActionListener.wrap(modelResponse -> {
                String modelId = modelResponse.getModelId();

                HashMap<String, String> parameters = copyParameters(mlAgent);
                String llmInterface = AgentModelService.inferLLMInterface(mlAgent.getModel().getModelProvider());
                if (llmInterface != null) {
                    parameters.put(LLM_INTERFACE_PARAM, llmInterface);
                }

                LLMSpec llmSpec = LLMSpec.builder().modelId(modelId).parameters(mlAgent.getModel().getModelParameters()).build();
                // Strip model parameters and credential from the spec that gets stored with the agent.
                MLAgentModelSpec sanitizedModelSpec = mlAgent.getModel().toBuilder().modelParameters(null).credential(null).build();
                MLAgent agent = mlAgent.toBuilder().llm(llmSpec).model(sanitizedModelSpec).parameters(parameters).build();
                this.registerAgent(agent, listener);
            }, listener::onFailure));
        }
        catch (Exception e) {
            listener.onFailure(e);
        }
    }

    /**
     * Validates the agent and, on success, continues with the registration.
     *
     * @param agent agent to validate and register
     * @param listener receives the registration response or the failure
     */
    private void registerAgent(MLAgent agent, ActionListener<MLRegisterAgentResponse> listener) {
        this.validateAgent(agent, ActionListener.wrap(validatedAgent -> this.proceedWithAgentRegistration(validatedAgent, listener), listener::onFailure));
    }

    /**
     * Validates the agent's context management configuration.
     *
     * <ul>
     *   <li>With a context management template: asks the registration validator whether the
     *       template may be used and fails with {@link IllegalArgumentException} unless it
     *       answers {@code true}.</li>
     *   <li>With inline context management: runs {@link #validateInlineContextManagement(MLAgent)}.</li>
     *   <li>Otherwise: the agent is passed through unchanged.</li>
     * </ul>
     *
     * @param agent agent to validate
     * @param listener receives the same agent instance on success, or the failure
     */
    private void validateAgent(MLAgent agent, ActionListener<MLAgent> listener) {
        if (agent.hasContextManagementTemplate()) {
            String templateName = agent.getContextManagementTemplateName();
            this.agentRegistrationValidator.validateContextManagementTemplateAccess(templateName, ActionListener.wrap(hasAccess -> {
                if (Boolean.TRUE.equals(hasAccess)) {
                    listener.onResponse(agent);
                } else {
                    listener.onFailure(new IllegalArgumentException("You don't have permission to use the context management template provided, template name: " + templateName));
                }
            }, e -> {
                log.error("You don't have permission to use the context management template provided, template name: {}", templateName, e);
                listener.onFailure(e);
            }));
        } else if (agent.getInlineContextManagement() != null) {
            try {
                this.validateInlineContextManagement(agent);
                listener.onResponse(agent);
            }
            catch (Exception e) {
                listener.onFailure(e);
            }
        } else {
            listener.onResponse(agent);
        }
    }

    /**
     * Checks that the agent has an inline context management configuration and that the
     * configuration reports itself as valid.
     *
     * @param agent agent whose inline context management is checked
     * @throws IllegalArgumentException if the configuration is missing or {@code isValid()} is false
     */
    private void validateInlineContextManagement(MLAgent agent) {
        if (agent.getInlineContextManagement() == null) {
            log.error("You must provide context management content when creating an agent without providing context management template name!");
            throw new IllegalArgumentException("You must provide context management content when creating an agent without context management template name!");
        }
        if (!agent.getInlineContextManagement().isValid()) {
            log.error("Invalid context management configuration: configuration must have a name and at least one hook with valid context manager configurations");
            throw new IllegalArgumentException("Invalid context management configuration: configuration must have a name and at least one hook with valid context manager configurations");
        }
    }

    /**
     * Runs the remaining parameter and tenant checks, stamps timestamps and the hidden flag on
     * the agent, and indexes it.
     *
     * <p>A plan-execute-and-reflect agent that has no {@code executor_agent_id} parameter
     * (including the case where it has no parameters at all) first gets a companion
     * conversational agent indexed; that agent's id is then stored under
     * {@code executor_agent_id} before the main agent is indexed.
     *
     * @param agent validated agent
     * @param listener receives the registration response or the failure
     */
    private void proceedWithAgentRegistration(MLAgent agent, ActionListener<MLRegisterAgentResponse> listener) {
        String mcpConnectorConfigJSON = getParameter(agent, MCP_CONNECTORS_PARAM);
        if (mcpConnectorConfigJSON != null && !this.mlFeatureEnabledSetting.isMcpConnectorEnabled()) {
            listener.onFailure(new OpenSearchException(MLCommonsSettings.ML_COMMONS_MCP_CONNECTOR_DISABLED_MESSAGE));
            return;
        }

        String llmInterface = getParameter(agent, LLM_INTERFACE_PARAM);
        if (llmInterface != null) {
            if (llmInterface.trim().isEmpty()) {
                listener.onFailure(new IllegalArgumentException("_llm_interface cannot be blank or empty"));
                return;
            }
            try {
                // Only used as a validity probe; the created object is discarded.
                FunctionCallingFactory.create(llmInterface);
            }
            catch (Exception e) {
                // Keep the factory's exception as the cause so the reason is not lost.
                listener.onFailure(new IllegalArgumentException("Invalid _llm_interface", e));
                return;
            }
        }

        Instant now = Instant.now();
        boolean isHiddenAgent = RestActionUtils.isSuperAdminUser(this.clusterService, this.client);
        MLAgent mlAgent = agent.toBuilder().createdTime(now).lastUpdateTime(now).isHidden(isHiddenAgent).build();
        String tenantId = agent.getTenantId();
        // validateTenantId is handed the listener; this method only stops when it returns false.
        if (!TenantAwareHelper.validateTenantId(this.mlFeatureEnabledSetting, tenantId, listener)) {
            return;
        }

        boolean needsExecutorAgent = MLAgentType.from(mlAgent.getType()) == MLAgentType.PLAN_EXECUTE_AND_REFLECT
            && !hasParameter(mlAgent, EXECUTOR_AGENT_ID_PARAM);
        if (!needsExecutorAgent) {
            this.registerAgentToIndex(mlAgent, tenantId, listener);
            return;
        }

        this.createConversationAgent(mlAgent, tenantId, ActionListener.wrap(conversationAgentId -> {
            HashMap<String, String> parameters = copyParameters(mlAgent);
            parameters.put(EXECUTOR_AGENT_ID_PARAM, conversationAgentId);
            MLAgent updatedAgent = mlAgent.toBuilder().parameters(parameters).build();
            this.registerAgentToIndex(updatedAgent, tenantId, listener);
        }, listener::onFailure));
    }

    /**
     * Indexes a conversational agent derived from the given plan-execute-and-reflect agent.
     *
     * <p>The new agent is built from a copy of the given agent with its name suffixed by
     * {@code " (ReAct)"}, its type set to {@code CONVERSATIONAL}, a generated description, fresh
     * timestamps and the hidden flag; every other field is carried over by {@code toBuilder()}.
     *
     * @param planExecuteReflectAgent agent to derive the conversational agent from
     * @param tenantId tenant id to index the agent under
     * @param listener receives the id of the indexed conversational agent, or the failure
     */
    private void createConversationAgent(MLAgent planExecuteReflectAgent, String tenantId, ActionListener<String> listener) {
        Instant now = Instant.now();
        boolean isHiddenAgent = RestActionUtils.isSuperAdminUser(this.clusterService, this.client);
        MLAgent conversationAgent = planExecuteReflectAgent
            .toBuilder()
            .name(planExecuteReflectAgent.getName() + " (ReAct)")
            .type(MLAgentType.CONVERSATIONAL.name())
            .description("Execution Agent for Plan Execute Reflect - " + planExecuteReflectAgent.getName())
            .createdTime(now)
            .lastUpdateTime(now)
            .isHidden(isHiddenAgent)
            .build();
        this.registerAgentToIndex(conversationAgent, tenantId, ActionListener.wrap(response -> listener.onResponse(response.getAgentId()), listener::onFailure));
    }

    /**
     * Initialises the agent index and writes the agent document to it.
     *
     * @param mlAgent agent document to write
     * @param tenantId tenant id set on the put request
     * @param listener receives a response carrying the new document id, or the failure
     */
    private void registerAgentToIndex(MLAgent mlAgent, String tenantId, ActionListener<MLRegisterAgentResponse> listener) {
        this.mlIndicesHandler.initMLAgentIndex(ActionListener.wrap(result -> {
            if (!result.booleanValue()) {
                log.error("Failed to create ML agent index");
                listener.onFailure(new OpenSearchException("Failed to create ML agent index"));
                return;
            }
            // The thread context is stashed while the put request is issued. The stored context
            // is restored explicitly inside the completion callback before the listener is
            // notified, and is also closed by try-with-resources when this block exits.
            try (ThreadContext.StoredContext context = this.client.threadPool().getThreadContext().stashContext()) {
                PutDataObjectRequest putRequest = PutDataObjectRequest.builder().index(AGENT_INDEX).tenantId(tenantId).dataObject(mlAgent).build();
                this.sdkClient.putDataObjectAsync(putRequest).whenComplete((r, throwable) -> {
                    context.restore();
                    if (throwable != null) {
                        Exception cause = SdkClientUtils.unwrapAndConvertToException(throwable);
                        log.error("Failed to index ML agent", cause);
                        listener.onFailure(cause);
                        return;
                    }
                    try {
                        IndexResponse indexResponse = r.indexResponse();
                        log.info("Agent creation result: {}, Agent id: {}", indexResponse.getResult(), indexResponse.getId());
                        listener.onResponse(new MLRegisterAgentResponse(r.id()));
                    }
                    catch (Exception e) {
                        listener.onFailure(e);
                    }
                });
            }
            catch (Exception e) {
                log.error("Failed to index ML agent", e);
                listener.onFailure(e);
            }
        }, e -> {
            log.error("Failed to create ML agent index", e);
            listener.onFailure(e);
        }));
    }

    /**
     * Returns the value of an agent parameter, or {@code null} when the agent has no parameters
     * map or the map has no value for the key.
     */
    private static String getParameter(MLAgent agent, String key) {
        return agent.getParameters() != null ? agent.getParameters().get(key) : null;
    }

    /**
     * Returns whether the agent's parameters map exists and contains the key.
     */
    private static boolean hasParameter(MLAgent agent, String key) {
        return agent.getParameters() != null && agent.getParameters().containsKey(key);
    }

    /**
     * Returns a new mutable map holding the agent's parameters; empty when the agent has none.
     */
    private static HashMap<String, String> copyParameters(MLAgent agent) {
        HashMap<String, String> copy = new HashMap<>();
        if (agent.getParameters() != null) {
            copy.putAll(agent.getParameters());
        }
        return copy;
    }
}

