/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.sql.calcite.plan.rule;

import com.google.common.collect.ImmutableList;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelRule;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.logical.LogicalAggregate;
import org.apache.calcite.rel.logical.LogicalProject;
import org.apache.calcite.rel.rules.SubstitutionRule;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexUtil;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.calcite.util.mapping.Mapping;
import org.apache.calcite.util.mapping.Mappings;
import org.apache.commons.lang3.tuple.Pair;
import org.immutables.value.Value;
import org.opensearch.sql.calcite.plan.rule.ImmutablePPLAggGroupMergeRule;
import org.opensearch.sql.calcite.plan.rule.OpenSearchRuleConfig;
import org.opensearch.sql.calcite.utils.CalciteUtils;
import org.opensearch.sql.calcite.utils.PlanUtils;

@Value.Enclosing
public class PPLAggGroupMergeRule
extends RelRule<Config>
implements SubstitutionRule {
    protected PPLAggGroupMergeRule(Config config) {
        super((RelRule.Config)config);
    }

    public void onMatch(RelOptRuleCall call) {
        if (call.rels.length != 2) {
            throw new AssertionError((Object)String.format("The length of rels should be %s but got %s", this.operands.size(), call.rels.length));
        }
        LogicalAggregate aggregate = (LogicalAggregate)call.rel(0);
        LogicalProject project = (LogicalProject)call.rel(1);
        this.apply(call, aggregate, project);
    }

    public void apply(RelOptRuleCall call, LogicalAggregate aggregate, LogicalProject project) {
        List groupSet = aggregate.getGroupSet().asList();
        List<RexNode> groupNodes = groupSet.stream().map(group -> (RexNode)project.getProjects().get((int)group)).toList();
        Pair<List<Integer>, List<Integer>> baseFieldsAndOthers = CalciteUtils.partition(groupSet, i -> ((RexNode)project.getProjects().get((int)i)).getKind() == SqlKind.INPUT_REF);
        List baseGroupList = (List)baseFieldsAndOthers.getLeft();
        if (baseGroupList.size() != 1) {
            return;
        }
        Integer baseGroupField = (Integer)baseGroupList.get(0);
        RexInputRef baseGroupRef = (RexInputRef)project.getProjects().get(baseGroupField);
        List otherGroupList = (List)baseFieldsAndOthers.getRight();
        boolean allDependOnBaseField = otherGroupList.stream().map(i -> (RexNode)project.getProjects().get((int)i)).allMatch(node -> PPLAggGroupMergeRule.isDependentField(node, List.of(baseGroupRef)));
        if (!allDependOnBaseField) {
            return;
        }
        RelBuilder relBuilder = call.builder();
        relBuilder.push((RelNode)project);
        relBuilder.aggregate(relBuilder.groupKey(ImmutableBitSet.of((int)baseGroupField)), aggregate.getAggCallList());
        Mapping mapping = Mappings.target(List.of(Integer.valueOf(baseGroupRef.getIndex())), (int)(baseGroupRef.getIndex() + 1));
        ArrayList parentProjections = new ArrayList(RexUtil.apply((Mappings.TargetMapping)mapping, groupNodes));
        ImmutableList aggCallRefs = relBuilder.fields(IntStream.range(baseGroupList.size(), relBuilder.peek().getRowType().getFieldCount()).boxed().toList());
        parentProjections.addAll(aggCallRefs);
        relBuilder.project(parentProjections);
        call.transformTo(relBuilder.build());
        PlanUtils.tryPruneRelNodes(call);
    }

    public static boolean isDependentField(RexNode node, Collection<RexNode> baseFields) {
        if (node.getKind() == SqlKind.LITERAL) {
            return true;
        }
        if (node.getKind() == SqlKind.INPUT_REF && baseFields.contains(node)) {
            return true;
        }
        if (node instanceof RexCall && ((RexCall)node).getOperator().isDeterministic() && !((RexCall)node).getOperator().isAggregator()) {
            return ((RexCall)node).getOperands().stream().allMatch(op -> PPLAggGroupMergeRule.isDependentField(op, baseFields));
        }
        return false;
    }

    @Value.Immutable
    public static interface Config
    extends OpenSearchRuleConfig {
        public static final Config GROUP_MERGE = ImmutablePPLAggGroupMergeRule.Config.builder().build().withOperandSupplier(b0 -> b0.operand(LogicalAggregate.class).predicate(Config::containsMultipleGroupSets).oneInput(b1 -> b1.operand(LogicalProject.class).predicate(Config::containsDependentFields).anyInputs()));

        public static boolean containsMultipleGroupSets(LogicalAggregate aggregate) {
            return aggregate.getGroupSet().cardinality() > 1;
        }

        public static boolean containsDependentFields(LogicalProject project) {
            Set baseFields = project.getProjects().stream().filter(node -> node.getKind() == SqlKind.INPUT_REF).collect(Collectors.toUnmodifiableSet());
            return project.getProjects().stream().anyMatch(node -> PPLAggGroupMergeRule.isDependentField(node, baseFields));
        }

        default public PPLAggGroupMergeRule toRule() {
            return new PPLAggGroupMergeRule(this);
        }
    }
}

