/**
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.drill.exec.planner.logical;


import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import org.apache.drill.exec.planner.sql.DrillSqlOperator;
import org.eigenbase.rel.AggregateCall;
import org.eigenbase.rel.AggregateRel;
import org.eigenbase.rel.AggregateRelBase;
import org.eigenbase.rel.CalcRel;
import org.eigenbase.rel.RelNode;
import org.eigenbase.relopt.RelOptRule;
import org.eigenbase.relopt.RelOptRuleCall;
import org.eigenbase.relopt.RelOptRuleOperand;
import org.eigenbase.reltype.RelDataType;
import org.eigenbase.reltype.RelDataTypeFactory;
import org.eigenbase.reltype.RelDataTypeField;
import org.eigenbase.rex.RexBuilder;
import org.eigenbase.rex.RexCall;
import org.eigenbase.rex.RexLiteral;
import org.eigenbase.rex.RexNode;
import org.eigenbase.sql.SqlAggFunction;
import org.eigenbase.sql.fun.SqlAvgAggFunction;
import org.eigenbase.sql.fun.SqlStdOperatorTable;
import org.eigenbase.sql.fun.SqlSumAggFunction;
import org.eigenbase.sql.fun.SqlSumEmptyIsZeroAggFunction;
import org.eigenbase.sql.type.SqlTypeUtil;
import org.eigenbase.util.CompositeList;
import org.eigenbase.util.ImmutableIntList;
import org.eigenbase.util.Util;

import com.google.common.collect.ImmutableList;

/**
 * Rule to reduce aggregates to simpler forms. SUM(x) is rewritten to a
 * null-guarded SUM0(x), AVG(x) to SUM0(x) / COUNT(x), and STDDEV_POP,
 * STDDEV_SAMP, VAR_POP and VAR_SAMP to expressions over SUM and COUNT.
 */
public class DrillReduceAggregatesRule extends RelOptRule {
  //~ Static fields/initializers ---------------------------------------------

  /**
   * The singleton. It is registered on an operand that matches any
   * {@link AggregateRel}, whatever its inputs are.
   */
  public static final DrillReduceAggregatesRule INSTANCE =
      new DrillReduceAggregatesRule(operand(AggregateRel.class, any()));

  // Operator named "CastHigh". This rule wraps the AVG numerator and the
  // STDDEV/VAR input expression in a call to it.
  // NOTE(review): the second constructor argument (1) is presumably the
  // operator's argument count - confirm against DrillSqlOperator.
  private static final DrillSqlOperator CastHighOp = new DrillSqlOperator("CastHigh", 1);

  //~ Constructors -----------------------------------------------------------

  /**
   * Creates the rule for the given operand. Protected so that subclasses can
   * build the rule on a different operand.
   *
   * @param operand root operand, handed unchanged to {@link RelOptRule}
   */
  protected DrillReduceAggregatesRule(RelOptRuleOperand operand) {
    super(operand);
  }

  //~ Methods ----------------------------------------------------------------

  /**
   * Accepts the call only when the base rule accepts it and the matched
   * aggregate has at least one call this rule can reduce.
   *
   * @param call rule call whose first matched relation is the aggregate
   * @return whether the rule should fire
   */
  @Override
  public boolean matches(RelOptRuleCall call) {
    if (super.matches(call)) {
      final AggregateRelBase aggregate = (AggregateRelBase) call.rels[0];
      return containsAvgStddevVarCall(aggregate.getAggCallList());
    }
    return false;
  }

  /**
   * Fires the rule: rewrites the matched aggregate (the first relation of the
   * call) and registers the replacement with the planner.
   *
   * @param ruleCall rule call whose first matched relation is the aggregate
   */
  @Override
  public void onMatch(RelOptRuleCall ruleCall) {
    AggregateRelBase oldAggRel = (AggregateRelBase) ruleCall.rels[0];
    reduceAggs(ruleCall, oldAggRel);
  }

  /**
   * Returns whether the list holds at least one call this rule rewrites: a
   * call whose aggregation is a {@link SqlAvgAggFunction} (AVG and the
   * STDDEV / VAR subtypes) or a {@link SqlSumAggFunction} (SUM).
   *
   * @param aggCallList List of aggregate calls
   * @return true as soon as one reducible call is found, false otherwise
   */
  private boolean containsAvgStddevVarCall(List<AggregateCall> aggCallList) {
    for (AggregateCall candidate : aggCallList) {
      final Object aggregation = candidate.getAggregation();
      final boolean reducible =
          aggregation instanceof SqlAvgAggFunction
              || aggregation instanceof SqlSumAggFunction;
      if (reducible) {
        return true;
      }
    }
    return false;
  }

  /*
  private boolean isMatch(AggregateCall call) {
    if (call.getAggregation() instanceof SqlAvgAggFunction) {
      final SqlAvgAggFunction.Subtype subtype =
          ((SqlAvgAggFunction) call.getAggregation()).getSubtype();
      return (subtype == SqlAvgAggFunction.Subtype.AVG);
    }
    return false;
  }
 */

  /**
   * Rewrites the aggregate so that every reducible call (SUM, AVG,
   * STDDEV_POP, STDDEV_SAMP, VAR_POP, VAR_SAMP) is expressed through simpler
   * aggregate calls, then hands the replacement to the planner.
   *
   * <p>The replacement is a projection on top of a new aggregate. The
   * projection restores the original row type: group keys first, then one
   * expression per original aggregate call. When a reduction needs an input
   * expression that the child does not provide, an extra projection is
   * inserted below the new aggregate.
   *
   * @param ruleCall  rule call that receives the transformed tree
   * @param oldAggRel aggregate being rewritten
   */
  private void reduceAggs(
      RelOptRuleCall ruleCall,
      AggregateRelBase oldAggRel) {
    final RexBuilder builder = oldAggRel.getCluster().getRexBuilder();
    final int groupCount = oldAggRel.getGroupCount();

    // Top projection: the group keys are passed through unchanged.
    final List<RexNode> outputExprs = new ArrayList<RexNode>();
    for (int ordinal = 0; ordinal < groupCount; ordinal++) {
      outputExprs.add(
          builder.makeInputRef(getFieldType(oldAggRel, ordinal), ordinal));
    }

    // Expressions feeding the new aggregate. It starts as one reference per
    // child field; a reduction may append to it, in which case an extra
    // projection is created below.
    RelNode child = oldAggRel.getChild();
    final List<RelDataTypeField> childFields = child.getRowType().getFieldList();
    final List<RexNode> childExprs = new ArrayList<RexNode>();
    for (int ordinal = 0; ordinal < childFields.size(); ordinal++) {
      childExprs.add(
          builder.makeInputRef(childFields.get(ordinal).getType(), ordinal));
    }

    // Reduce each call: this fills the new call list and yields the
    // expression that stands in for the old call in the top projection.
    final List<AggregateCall> reducedCalls = new ArrayList<AggregateCall>();
    final Map<AggregateCall, RexNode> callToRef =
        new HashMap<AggregateCall, RexNode>();
    for (AggregateCall original : oldAggRel.getAggCallList()) {
      final RexNode replacement =
          reduceAgg(oldAggRel, original, reducedCalls, callToRef, childExprs);
      outputExprs.add(replacement);
    }

    final int addedExprCount =
        childExprs.size() - child.getRowType().getFieldCount();
    if (addedExprCount > 0) {
      // Existing fields keep their names; the appended ones are unnamed.
      final List<String> addedNames =
          Collections.<String>nCopies(addedExprCount, null);
      child =
          CalcRel.createProject(
              child,
              childExprs,
              CompositeList.of(child.getRowType().getFieldNames(), addedNames));
    }

    final AggregateRelBase reducedAggRel =
        newAggregateRel(oldAggRel, child, reducedCalls);
    final RelNode replacementRel =
        CalcRel.createProject(
            reducedAggRel,
            outputExprs,
            oldAggRel.getRowType().getFieldNames());

    ruleCall.transformTo(replacementRel);
  }

  /**
   * Translates one call of the old aggregate into an expression over the
   * output of the new aggregate, registering in {@code newCalls} whatever
   * aggregate calls that expression needs.
   *
   * <p>SUM, AVG, STDDEV_POP, STDDEV_SAMP, VAR_POP and VAR_SAMP are rewritten;
   * any other call is carried over unchanged.
   *
   * @param oldAggRel      aggregate being rewritten
   * @param oldCall        call of {@code oldAggRel} to translate
   * @param newCalls       calls of the new aggregate; appended to as needed
   * @param aggCallMapping calls already registered, with their references
   * @param inputExprs     input expressions of the new aggregate; the
   *                       STDDEV / VAR reductions may modify this list
   * @return expression that replaces {@code oldCall} in the top projection
   */
  private RexNode reduceAgg(
      AggregateRelBase oldAggRel,
      AggregateCall oldCall,
      List<AggregateCall> newCalls,
      Map<AggregateCall, RexNode> aggCallMapping,
      List<RexNode> inputExprs) {
    if (oldCall.getAggregation() instanceof SqlSumAggFunction) {
      // replace original SUM(x) with
      // case COUNT(x) when 0 then null else SUM0(x) end
      return reduceSum(oldAggRel, oldCall, newCalls, aggCallMapping);
    }
    if (oldCall.getAggregation() instanceof SqlAvgAggFunction) {
      final SqlAvgAggFunction.Subtype subtype =
          ((SqlAvgAggFunction) oldCall.getAggregation()).getSubtype();

      switch (subtype) {
      case AVG:
        // replace original AVG(x) with SUM(x) / COUNT(x)
        return reduceAvg(
            oldAggRel, oldCall, newCalls, aggCallMapping);
      case STDDEV_POP:
        // replace original STDDEV_POP(x) with
        //   SQRT(
        //     (SUM(x * x) - SUM(x) * SUM(x) / COUNT(x))
        //     / COUNT(x))
        return reduceStddev(
            oldAggRel, oldCall, true, true, newCalls, aggCallMapping,
            inputExprs);
      case STDDEV_SAMP:
        // replace original STDDEV_SAMP(x) with
        //   SQRT(
        //     (SUM(x * x) - SUM(x) * SUM(x) / COUNT(x))
        //     / CASE COUNT(x) WHEN 1 THEN NULL ELSE COUNT(x) - 1 END)
        return reduceStddev(
            oldAggRel, oldCall, false, true, newCalls, aggCallMapping,
            inputExprs);
      case VAR_POP:
        // replace original VAR_POP(x) with
        //     (SUM(x * x) - SUM(x) * SUM(x) / COUNT(x))
        //     / COUNT(x)
        return reduceStddev(
            oldAggRel, oldCall, true, false, newCalls, aggCallMapping,
            inputExprs);
      case VAR_SAMP:
        // replace original VAR_SAMP(x) with
        //     (SUM(x * x) - SUM(x) * SUM(x) / COUNT(x))
        //     / CASE COUNT(x) WHEN 1 THEN NULL ELSE COUNT(x) - 1 END
        return reduceStddev(
            oldAggRel, oldCall, false, false, newCalls, aggCallMapping,
            inputExprs);
      default:
        throw Util.unexpected(subtype);
      }
    } else {
      // anything else:  preserve original call
      RexBuilder rexBuilder = oldAggRel.getCluster().getRexBuilder();
      final int nGroups = oldAggRel.getGroupCount();
      // The call's argument ordinals index the fields of the aggregate's
      // input, so their types must be read from the child's row type, not
      // from the aggregate's own output row type. The other reducers in this
      // class resolve argument types the same way.
      List<RelDataType> oldArgTypes = SqlTypeUtil
          .projectTypes(oldAggRel.getChild().getRowType(), oldCall.getArgList());
      return rexBuilder.addAggCall(
          oldCall,
          nGroups,
          newCalls,
          aggCallMapping,
          oldArgTypes);
    }
  }

  /**
   * Rewrites {@code AVG(x)} as
   * {@code CastHigh(CASE COUNT(x) = 0 THEN NULL ELSE SUM0(x) END) / COUNT(x)},
   * cast back to the type of the original call.
   *
   * @param oldAggRel      aggregate being rewritten
   * @param oldCall        the AVG call; its first argument is the averaged
   *                       input field
   * @param newCalls       calls of the new aggregate; SUM0 and COUNT are
   *                       registered here
   * @param aggCallMapping calls already registered, with their references
   * @return expression over the new aggregate's output that replaces the call
   */
  private RexNode reduceAvg(
      AggregateRelBase oldAggRel,
      AggregateCall oldCall,
      List<AggregateCall> newCalls,
      Map<AggregateCall, RexNode> aggCallMapping) {
    final RexBuilder builder = oldAggRel.getCluster().getRexBuilder();
    final RelDataTypeFactory types = oldAggRel.getCluster().getTypeFactory();
    final int groupCount = oldAggRel.getGroupCount();

    final int inputOrdinal = oldCall.getArgList().get(0);
    final RelDataType inputType =
        getFieldType(oldAggRel.getChild(), inputOrdinal);
    final List<RelDataType> argTypes = ImmutableList.of(inputType);

    // The sum keeps the input's type; it is nullable if the input is, or if
    // there is no group key.
    final boolean sumNullable = inputType.isNullable() || groupCount == 0;
    final RelDataType sumType =
        types.createTypeWithNullability(inputType, sumNullable);
    // The "empty is zero" flavour of SUM is used instead of plain SUM.
    final SqlAggFunction sumZeroAgg = new SqlSumEmptyIsZeroAggFunction(sumType);
    final AggregateCall sumCall =
        new AggregateCall(
            sumZeroAgg, oldCall.isDistinct(), oldCall.getArgList(), sumType, null);

    final SqlAggFunction countAgg = SqlStdOperatorTable.COUNT;
    final RelDataType countType = countAgg.getReturnType(types);
    final AggregateCall countCall =
        new AggregateCall(
            countAgg, oldCall.isDistinct(), oldCall.getArgList(), countType, null);

    // NOTE: these references are with respect to the output of the new
    // aggregate. The sum is registered before the count.
    final RexNode sumRef =
        builder.addAggCall(sumCall, groupCount, newCalls, aggCallMapping, argTypes);
    final RexNode countRef =
        builder.addAggCall(countCall, groupCount, newCalls, aggCallMapping, argTypes);

    // CASE COUNT(x) = 0 THEN NULL ELSE SUM0(x) END
    final RexNode countIsZero =
        builder.makeCall(
            SqlStdOperatorTable.EQUALS,
            countRef,
            builder.makeExactLiteral(BigDecimal.ZERO));
    final RexNode nullGuardedSum =
        builder.makeCall(
            SqlStdOperatorTable.CASE,
            countIsZero,
            builder.constantNull(),
            sumRef);

    final RexNode numerator = builder.makeCall(CastHighOp, nullGuardedSum);

    // The count call is registered a second time for the denominator.
    // NOTE(review): presumably addAggCall returns the reference already held
    // in aggCallMapping rather than adding a duplicate - confirm.
    final RexNode denominator =
        builder.addAggCall(countCall, groupCount, newCalls, aggCallMapping, argTypes);

    final RexNode quotient =
        builder.makeCall(SqlStdOperatorTable.DIVIDE, numerator, denominator);
    return builder.makeCast(oldCall.getType(), quotient);
  }

  /**
   * Rewrites {@code SUM(x)} as {@code SUM0(x)} when the original call is not
   * nullable, and otherwise as
   * {@code CASE COUNT(x) = 0 THEN NULL ELSE SUM0(x) END}.
   *
   * @param oldAggRel      aggregate being rewritten
   * @param oldCall        the SUM call; its first argument is the summed
   *                       input field
   * @param newCalls       calls of the new aggregate; SUM0 and, if needed,
   *                       COUNT are registered here
   * @param aggCallMapping calls already registered, with their references
   * @return expression over the new aggregate's output that replaces the call
   */
  private RexNode reduceSum(
      AggregateRelBase oldAggRel,
      AggregateCall oldCall,
      List<AggregateCall> newCalls,
      Map<AggregateCall, RexNode> aggCallMapping) {
    final RexBuilder builder = oldAggRel.getCluster().getRexBuilder();
    final RelDataTypeFactory types = oldAggRel.getCluster().getTypeFactory();
    final int groupCount = oldAggRel.getGroupCount();

    final int argOrdinal = oldCall.getArgList().get(0);
    final RelDataType argType = getFieldType(oldAggRel.getChild(), argOrdinal);
    final List<RelDataType> argTypes = ImmutableList.of(argType);

    // SUM0 keeps both the type and the nullability of its argument.
    final RelDataType sumType =
        types.createTypeWithNullability(argType, argType.isNullable());
    final SqlAggFunction sumZeroAgg = new SqlSumEmptyIsZeroAggFunction(sumType);
    final AggregateCall sumZeroCall =
        new AggregateCall(
            sumZeroAgg, oldCall.isDistinct(), oldCall.getArgList(), sumType, null);

    final SqlAggFunction countAgg = SqlStdOperatorTable.COUNT;
    final RelDataType countType = countAgg.getReturnType(types);
    final AggregateCall countCall =
        new AggregateCall(
            countAgg, oldCall.isDistinct(), oldCall.getArgList(), countType, null);

    // NOTE: these references are with respect to the output of the new
    // aggregate.
    final RexNode sumZeroRef =
        builder.addAggCall(
            sumZeroCall, groupCount, newCalls, aggCallMapping, argTypes);

    final boolean resultNullable = oldCall.getType().isNullable();
    if (!resultNullable) {
      // A non-nullable SUM(x) means the validator has established that the
      // group is never empty and x is never null, so SUM0(x) alone is an
      // exact replacement and no COUNT is registered.
      return sumZeroRef;
    }

    final RexNode countRef =
        builder.addAggCall(
            countCall, groupCount, newCalls, aggCallMapping, argTypes);
    final RexNode countIsZero =
        builder.makeCall(
            SqlStdOperatorTable.EQUALS,
            countRef,
            builder.makeExactLiteral(BigDecimal.ZERO));
    return builder.makeCall(
        SqlStdOperatorTable.CASE,
        countIsZero,
        builder.constantNull(),
        sumZeroRef);
  }

  /**
   * Rewrites a STDDEV_POP, STDDEV_SAMP, VAR_POP or VAR_SAMP call as an
   * expression over SUM(x * x), SUM(x) and COUNT(x), cast back to the type of
   * the original call.
   *
   * @param oldAggRel      aggregate being rewritten
   * @param oldCall        the call to reduce; must have exactly one argument
   * @param biased         if true the divisor is COUNT(x) (the "POP"
   *                       variants); if false it is COUNT(x) - 1, or NULL
   *                       when COUNT(x) is 1 (the "SAMP" variants)
   * @param sqrt           if true the quotient is raised to the power 0.5
   *                       (the STDDEV variants); if false it is returned
   *                       as is (the VAR variants)
   * @param newCalls       calls of the new aggregate; the two SUMs and the
   *                       COUNT are registered here
   * @param aggCallMapping calls already registered, with their references
   * @param inputExprs     input expressions of the new aggregate; this
   *                       method replaces the entry at the argument's
   *                       ordinal and may append the squared argument
   * @return expression over the new aggregate's output that replaces the call
   */
  private RexNode reduceStddev(
      AggregateRelBase oldAggRel,
      AggregateCall oldCall,
      boolean biased,
      boolean sqrt,
      List<AggregateCall> newCalls,
      Map<AggregateCall, RexNode> aggCallMapping,
      List<RexNode> inputExprs) {
    // stddev_pop(x) ==>
    //   power(
    //     (sum(x * x) - sum(x) * sum(x) / count(x))
    //     / count(x),
    //     .5)
    //
    // stddev_samp(x) ==>
    //   power(
    //     (sum(x * x) - sum(x) * sum(x) / count(x))
    //     / nullif(count(x) - 1, 0),
    //     .5)
    final int nGroups = oldAggRel.getGroupCount();
    RelDataTypeFactory typeFactory =
        oldAggRel.getCluster().getTypeFactory();
    final RexBuilder rexBuilder = oldAggRel.getCluster().getRexBuilder();

    assert oldCall.getArgList().size() == 1 : oldCall.getArgList();
    final int argOrdinal = oldCall.getArgList().get(0);
    // Type of the argument as the old aggregate's child provides it, i.e.
    // before the CastHigh wrapping applied below.
    final RelDataType argType =
        getFieldType(
            oldAggRel.getChild(),
            argOrdinal);

    // final RexNode argRef = inputExprs.get(argOrdinal);
    // Wrap the argument in CastHigh and store the wrapped expression back
    // into the list at the same ordinal.
    // NOTE(review): inputExprs is supplied by the caller, so this replacement
    // is visible to every other user of that ordinal, and a second reduction
    // on the same ordinal wraps the already-wrapped expression again -
    // confirm this is intended.
    RexNode argRef = rexBuilder.makeCall(CastHighOp, inputExprs.get(argOrdinal));
    inputExprs.set(argOrdinal, argRef);

    // x * x, appended to the input expressions unless an equal expression is
    // already present.
    final RexNode argSquared =
        rexBuilder.makeCall(
            SqlStdOperatorTable.MULTIPLY, argRef, argRef);
    final int argSquaredOrdinal = lookupOrAdd(inputExprs, argSquared);

    // Both sums use the nullable version of the argument's type.
    final RelDataType sumType =
        typeFactory.createTypeWithNullability(
            argType,
            true);
    // SUM(x * x)
    final AggregateCall sumArgSquaredAggCall =
        new AggregateCall(
            new SqlSumAggFunction(sumType),
            oldCall.isDistinct(),
            ImmutableIntList.of(argSquaredOrdinal),
            sumType,
            null);
    final RexNode sumArgSquared =
        rexBuilder.addAggCall(
            sumArgSquaredAggCall,
            nGroups,
            newCalls,
            aggCallMapping,
            ImmutableList.of(argType));

    // SUM(x)
    final AggregateCall sumArgAggCall =
        new AggregateCall(
            new SqlSumAggFunction(sumType),
            oldCall.isDistinct(),
            ImmutableIntList.of(argOrdinal),
            sumType,
            null);
    final RexNode sumArg =
          rexBuilder.addAggCall(
              sumArgAggCall,
              nGroups,
              newCalls,
              aggCallMapping,
              ImmutableList.of(argType));

    // SUM(x) * SUM(x)
    final RexNode sumSquaredArg =
          rexBuilder.makeCall(
              SqlStdOperatorTable.MULTIPLY, sumArg, sumArg);

    // COUNT(x)
    final SqlAggFunction countAgg = SqlStdOperatorTable.COUNT;
    final RelDataType countType = countAgg.getReturnType(typeFactory);
    final AggregateCall countArgAggCall =
        new AggregateCall(
            countAgg,
            oldCall.isDistinct(),
            oldCall.getArgList(),
            countType,
            null);
    final RexNode countArg =
        rexBuilder.addAggCall(
            countArgAggCall,
            nGroups,
            newCalls,
            aggCallMapping,
            ImmutableList.of(argType));

    // SUM(x) * SUM(x) / COUNT(x)
    final RexNode avgSumSquaredArg =
        rexBuilder.makeCall(
            SqlStdOperatorTable.DIVIDE,
            sumSquaredArg, countArg);

    // SUM(x * x) - SUM(x) * SUM(x) / COUNT(x)
    final RexNode diff =
        rexBuilder.makeCall(
            SqlStdOperatorTable.MINUS,
            sumArgSquared, avgSumSquaredArg);

    final RexNode denominator;
    if (biased) {
      denominator = countArg;
    } else {
      // CASE COUNT(x) = 1 THEN NULL ELSE COUNT(x) - 1 END, so that a single
      // row yields NULL instead of a division by zero.
      final RexLiteral one =
          rexBuilder.makeExactLiteral(BigDecimal.ONE);
      final RexNode nul =
          rexBuilder.makeNullLiteral(countArg.getType().getSqlTypeName());
      final RexNode countMinusOne =
          rexBuilder.makeCall(
              SqlStdOperatorTable.MINUS, countArg, one);
      final RexNode countEqOne =
          rexBuilder.makeCall(
              SqlStdOperatorTable.EQUALS, countArg, one);
      denominator =
          rexBuilder.makeCall(
              SqlStdOperatorTable.CASE,
              countEqOne, nul, countMinusOne);
    }

    final RexNode div =
        rexBuilder.makeCall(
            SqlStdOperatorTable.DIVIDE, diff, denominator);

    RexNode result = div;
    if (sqrt) {
      // The square root is expressed as POWER(div, 0.5).
      final RexNode half =
          rexBuilder.makeExactLiteral(new BigDecimal("0.5"));
      result =
          rexBuilder.makeCall(
              SqlStdOperatorTable.POWER, div, half);
    }

    return rexBuilder.makeCast(
        oldCall.getType(), result);
  }

  /**
   * Returns the position of an element in a list, appending the element
   * first when the list does not contain it yet.
   *
   * @param list    List to search and, if needed, extend
   * @param element Element to find or append
   * @param <T>     Element type
   * @return Index of the first occurrence of the element, or the index at
   *     which it has just been appended
   */
  private static <T> int lookupOrAdd(List<T> list, T element) {
    final int existing = list.indexOf(element);
    if (existing != -1) {
      return existing;
    }
    final int appendedAt = list.size();
    list.add(element);
    return appendedAt;
  }

  /**
   * Builds the replacement aggregate: same cluster and group set as the old
   * one, but with a new input and a new list of calls. The result is always
   * an {@link AggregateRel}, whatever the concrete class of the old
   * aggregate; subclasses may override this to produce another kind.
   *
   * @param oldAggRel aggregate whose cluster and group set are reused
   * @param inputRel  input relational expression of the new aggregate
   * @param newCalls  aggregate calls of the new aggregate
   * @return the new aggregate
   */
  protected AggregateRelBase newAggregateRel(
      AggregateRelBase oldAggRel,
      RelNode inputRel,
      List<AggregateCall> newCalls) {
    final AggregateRel rebuilt =
        new AggregateRel(
            oldAggRel.getCluster(), inputRel, oldAggRel.getGroupSet(), newCalls);
    return rebuilt;
  }

  /**
   * Returns the type of the field at the given ordinal of a relational
   * expression's row type.
   *
   * @param relNode relational expression whose row type is inspected
   * @param i       zero-based ordinal of the field
   * @return type of that field
   */
  private RelDataType getFieldType(RelNode relNode, int i) {
    return relNode.getRowType().getFieldList().get(i).getType();
  }

}

