//===- TosaToRock.cpp - Lowering Tosa to Rock Dialect -----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// These rewriters lower from the Tosa to the Rock dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/TosaToRock/TosaToRock.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Rock/IR/AmdArchDb.h"
#include "mlir/Dialect/Rock/IR/GetRockInfo.h"
#include "mlir/Dialect/Rock/IR/Rock.h"
#include "mlir/Dialect/Rock/IR/RockConvInterface.h"
#include "mlir/Dialect/Rock/IR/RockGemmGemmWrapperInterface.h"
#include "mlir/Dialect/Rock/IR/TransformMapBuilder.h"
#include "mlir/Dialect/Rock/Tuning/ConvContext.h"
#include "mlir/Dialect/Rock/Tuning/RockTuning.h"
#include "mlir/Dialect/Rock/utility/builderUtils.h"
#include "mlir/Dialect/Rock/utility/loweringUtils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/BuiltinAttributeInterfaces.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Types.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/bit.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/raw_ostream.h"
#include <utility>

#define DEBUG_TYPE "convert-tosa-to-rock"

using namespace mlir;

namespace {

/// Returns true if `value` is a zero integer/float scalar, or an aggregate
/// (splat, elements attribute, or array attribute) whose elements are all zero.
/// Any other attribute kind is reported as non-zero.
static bool isZeroAttribute(Attribute value) {
  if (auto intValue = dyn_cast<IntegerAttr>(value))
    return intValue.getValue().isZero();
  if (auto fpValue = dyn_cast<FloatAttr>(value))
    return fpValue.getValue().isZero();
  if (auto splatValue = dyn_cast<SplatElementsAttr>(value))
    return isZeroAttribute(splatValue.getSplatValue<Attribute>());
  // Every DenseElementsAttr also implements ElementsAttr, so this branch also
  // covers dense constants. Element storage that cannot be iterated as
  // Attribute is conservatively reported as "not zero" instead of asserting.
  if (auto elementsValue = dyn_cast<ElementsAttr>(value)) {
    auto elements = elementsValue.tryGetValues<Attribute>();
    return succeeded(elements) && llvm::all_of(*elements, isZeroAttribute);
  }
  if (auto arrayValue = dyn_cast<ArrayAttr>(value))
    return llvm::all_of(arrayValue.getValue(), isZeroAttribute);
  return false;
}

/// Returns true when `v` is produced by an arith or tosa constant whose value
/// is entirely zero. Values produced any other way (including block
/// arguments) are never considered zero.
static bool isConstantZero(Value v) {
  if (auto arithConst = v.getDefiningOp<arith::ConstantOp>())
    return isZeroAttribute(arithConst.getValue());
  auto tosaConst = v.getDefiningOp<tosa::ConstOp>();
  return tosaConst && isZeroAttribute(tosaConst.getValuesAttr());
}

/// Returns true if `value` is a float scalar equal to negative infinity, or an
/// aggregate (splat, elements attribute, or array attribute) whose elements
/// are all negative infinity.
static bool isNegInfAttribute(Attribute value) {
  if (auto fpValue = dyn_cast<FloatAttr>(value)) {
    // Query the stored value directly rather than converting -inf into the
    // attribute's type: formats with no infinity encoding simply never match
    // (the previous conversion-based check asserted on such types).
    APFloat fp = fpValue.getValue();
    return fp.isInfinity() && fp.isNegative();
  }
  if (auto splatValue = dyn_cast<SplatElementsAttr>(value))
    return isNegInfAttribute(splatValue.getSplatValue<Attribute>());
  // Every DenseElementsAttr also implements ElementsAttr, so this branch also
  // covers dense constants. Element storage that cannot be iterated as
  // Attribute is conservatively reported as "not -inf".
  if (auto elementsValue = dyn_cast<ElementsAttr>(value)) {
    auto elements = elementsValue.tryGetValues<Attribute>();
    return succeeded(elements) && llvm::all_of(*elements, isNegInfAttribute);
  }
  if (auto arrayValue = dyn_cast<ArrayAttr>(value))
    return llvm::all_of(arrayValue.getValue(), isNegInfAttribute);

  return false;
}

/// Returns true when `v` is produced by an arith or tosa constant whose value
/// is entirely negative infinity. Any other producer (or a block argument)
/// yields false.
static bool isConstIsNegInf(Value v) {
  if (auto arithConst = v.getDefiningOp<arith::ConstantOp>())
    return isNegInfAttribute(arithConst.getValue());
  auto tosaConst = v.getDefiningOp<tosa::ConstOp>();
  return tosaConst && isNegInfAttribute(tosaConst.getValuesAttr());
}

/// Returns true if `value` is an IntegerAttr equal to `expectedVal`.
/// `expectedVal` is materialized via rock::createAPInt using the attribute's
/// own type. If that fails, or `value` is not an IntegerAttr, the result is
/// false.
static bool isIntAttrSame(Attribute value, int64_t expectedVal) {
  auto intAttr = dyn_cast<IntegerAttr>(value);
  if (!intAttr)
    return false;

  FailureOr<APInt> expected =
      rock::createAPInt(intAttr.getType(), expectedVal);
  return succeeded(expected) && *expected == intAttr.getValue();
}

/// Returns true if `value` is a non-splat elements attribute or an array
/// attribute whose i-th element is an integer equal to i (i.e. 0, 1, 2, ...).
static bool isConstRangeAttribute(Attribute value) {
  // True when every element of `range` equals its own position.
  auto isIota = [](auto &&range) {
    return llvm::all_of(llvm::enumerate(range), [](const auto &indexedAttr) {
      return isIntAttrSame(indexedAttr.value(), indexedAttr.index());
    });
  };

  // Splat constants are rejected outright.
  // NOTE(review): this may also reject single-element constants if they are
  // represented as splats; confirm that is intended.
  if (isa<SplatElementsAttr>(value))
    return false;
  // Every DenseElementsAttr also implements ElementsAttr, so this branch also
  // covers dense constants. Storage not iterable as Attribute is "not a range".
  if (auto elementsValue = dyn_cast<ElementsAttr>(value)) {
    auto elements = elementsValue.tryGetValues<Attribute>();
    return succeeded(elements) && isIota(*elements);
  }
  if (auto arrayValue = dyn_cast<ArrayAttr>(value))
    return isIota(arrayValue.getValue());

  return false;
}

/// Returns true when `v` is produced by an arith or tosa constant holding the
/// sequence 0, 1, 2, ... (see isConstRangeAttribute). Any other producer
/// yields false.
static bool isConstRange(Value v) {
  if (auto arithConst = v.getDefiningOp<arith::ConstantOp>())
    return isConstRangeAttribute(arithConst.getValue());
  auto tosaConst = v.getDefiningOp<tosa::ConstOp>();
  return tosaConst && isConstRangeAttribute(tosaConst.getValuesAttr());
}

// Note:  we want something a bit more general than SmallString<8> for
// the layout string, but it has to allow for inserting a character into
// the string for the caller to see.
//
// Splits the dimension whose layout character equals lowerName[0] into a
// leading "g" dimension of size `g` and a remaining dimension of size
// shape[dim] / g. It does this by wrapping `operand` in a rock.transform.
// On success, `layout` is updated with 'g' inserted before the split
// dimension. On failure, a match-failure reason is recorded on `op` and a
// null Value is returned. `idx` is currently unused.
static Value expandTensor(PatternRewriter &rw, Operation *op, Value operand,
                          SmallString<8> &layout, StringRef lowerName,
                          int64_t g, uint32_t idx = 4) {
  auto loc = op->getLoc();
  auto oprType = cast<ShapedType>(operand.getType());
  if (!oprType.hasStaticShape()) {
    (void)rw.notifyMatchFailure(
        op, "tosa to rock conversion expects statically shaped tensors");
    return Value();
  }
  ArrayRef<int64_t> shape = oprType.getShape();
  // The layout names each dimension with one character and is indexed per
  // dimension below, so its length must match the rank.
  if (layout.size() != shape.size()) {
    (void)rw.notifyMatchFailure(
        op, "tosa conv layout length does not match tensor rank");
    return Value();
  }

  SmallVector<uint32_t, 8> endDims;
  SmallVector<uint32_t, 8> startDims;
  SmallVector<StringRef, 8> startNames;

  // find the lower dimension that encodes the g dimension
  std::optional<uint32_t> groupFoldedDim = std::nullopt;

  for (uint32_t i = 0, e = shape.size(); i < e; ++i) {
    startNames.push_back(layout.substr(i, 1));
    if (layout[i] == lowerName[0]) {
      groupFoldedDim = i;
    } else {
      startDims.push_back(i);
      endDims.push_back(groupFoldedDim.has_value() ? i + 1 : i);
    }
  }

  if (!groupFoldedDim.has_value()) {
    (void)rw.notifyMatchFailure(op, "tosa conv has an invalid layout");
    return Value();
  }

  uint32_t lowerDim = groupFoldedDim.value();
  // The group count must be positive and evenly divide the split dimension.
  // Otherwise the unmerge below would divide by zero or drop elements.
  if (g <= 0 || shape[lowerDim] % g != 0) {
    (void)rw.notifyMatchFailure(
        op, "tosa conv group count does not divide the grouped dimension");
    return Value();
  }

  // insert 'g' dimension into layout
  rock::BottomUpTMBuilder transform(rw, ArrayRef<StringRef>(startNames), shape,
                                    loc);
  transform.passThrough(endDims, startDims);
  transform.unmerge({"g", lowerName}, {lowerDim, lowerDim + 1}, lowerName,
                    {{g, shape[lowerDim] / g}});
  layout.insert(layout.begin() + lowerDim, 'g');

  return rock::TransformOp::create(rw, loc, operand, transform.get());
}

/// Returns the default GEMM features for `inputType` on the architecture that
/// rock::getArch reports for `op`. An empty arch string is used when no
/// architecture can be determined.
static rock::GemmFeatures getGemmFeaturesFromOp(Operation *op, Type inputType) {
  FailureOr<StringAttr> maybeArch = rock::getArch(op);
  StringAttr arch = succeeded(maybeArch)
                        ? *maybeArch
                        : StringAttr::get(op->getContext(), "");

  rock::AmdArchInfo info = rock::lookupArchInfo(arch);
  return info.getDefaultFeatures(inputType);
}

// Operands and attributes shared by the Rock convolution builders, filled in
// by commonConv().
struct ConvFields {
  // Per-dimension layout strings (one character per dimension). When
  // expandTensor() succeeds, they include the inserted 'g' group dimension.
  SmallString<8> filterLayout;
  SmallString<8> inputLayout;
  SmallString<8> outputLayout;
  // Operands regrouped by expandTensor(). Null if expansion failed; outputExp
  // is also null when no output value was supplied.
  Value inputExp;
  Value filterExp;
  Value outputExp;
  // pad / stride / dilation converted to index array attributes.
  ArrayAttr pad;
  ArrayAttr stride;
  ArrayAttr dilation;
  // NOTE(review): not assigned by commonConv() in the code visible here.
  rock::GemmFeaturesAttr features;
  // Copied from the source op's "perf_config" attribute (null if absent).
  StringAttr perfConfig;
};

/// Collects the pieces shared by every Rock convolution built from `op`.
/// Layout strings come from the op's "*_layout" attributes, or from defaults
/// chosen by rank. Each tensor is regrouped via expandTensor() into
/// `group` groups. The pad/stride/dilation arrays are converted to index
/// attributes. The output is optional: when `output` is null, outputLayout and
/// outputExp are left empty.
static ConvFields commonConv(PatternRewriter &rw, Operation *op, Value input,
                             Value filter, Value output, DenseI64ArrayAttr pad,
                             DenseI64ArrayAttr stride,
                             DenseI64ArrayAttr dilation, int64_t group) {
  // An explicit layout attribute wins. Otherwise use the 2-D spelling for
  // tensors of rank <= 4 and the generic N-D spelling for higher ranks.
  auto chooseLayout = [&](StringRef attrName, Value tensor,
                          StringRef lowRankDefault,
                          StringRef highRankDefault) -> SmallString<8> {
    if (auto attr = op->getAttrOfType<StringAttr>(attrName))
      return SmallString<8>(attr.getValue());
    bool isHighRank = cast<ShapedType>(tensor.getType()).getRank() > 4;
    return SmallString<8>(isHighRank ? highRankDefault : lowRankDefault);
  };

  ConvFields fields;
  fields.filterLayout = chooseLayout("filter_layout", filter, "kyxc", "k012c");
  fields.inputLayout = chooseLayout("input_layout", input, "nhwc", "n012c");
  if (output)
    fields.outputLayout =
        chooseLayout("output_layout", output, "nhwk", "n012k");

  // Split the channel-like dimension of each tensor into (g, c|k) and record
  // the new 'g' in the corresponding layout string.
  fields.inputExp =
      expandTensor(rw, op, input, fields.inputLayout, "c", group);
  fields.filterExp =
      expandTensor(rw, op, filter, fields.filterLayout, "k", group);
  if (output)
    fields.outputExp =
        expandTensor(rw, op, output, fields.outputLayout, "k", group);

  fields.pad = rw.getIndexArrayAttr(pad);
  fields.stride = rw.getIndexArrayAttr(stride);
  fields.dilation = rw.getIndexArrayAttr(dilation);
  fields.perfConfig = op->getAttrOfType<StringAttr>("perf_config");
  return fields;
}

/// Attaches perf_config (if any) and the filter/input/output layout array
/// attributes to the Rock convolution `cop`. Each layout character becomes
/// one string entry. Input entries get an "i" suffix and output entries get
/// an "o" suffix. The output layout is only set when an expanded output
/// exists.
static void addConvAttributes(PatternRewriter &rw, Operation *cop,
                              const ConvFields &convFields) {
  // Expand each layout over its own length, so layouts of different lengths
  // are never truncated or padded with empty names.
  auto buildLayoutSpec = [&rw](StringRef layout, const char *suffix) {
    SmallVector<Attribute, 6> spec;
    spec.reserve(layout.size());
    for (size_t i = 0, e = layout.size(); i < e; ++i)
      spec.push_back(rw.getStringAttr(layout.substr(i, 1) + suffix));
    return rw.getArrayAttr(spec);
  };

  // arch-specific attributes
  // TODO: remove these
  if (auto attr = convFields.perfConfig)
    cop->setAttr("perf_config", attr);

  // convolution config attributes
  cop->setAttr("filter_layout", buildLayoutSpec(convFields.filterLayout, ""));
  cop->setAttr("input_layout", buildLayoutSpec(convFields.inputLayout, "i"));
  if (convFields.outputExp)
    cop->setAttr("output_layout",
                 buildLayoutSpec(convFields.outputLayout, "o"));
}

/// Builds a Rock convolution for `op`. A rock.conv_bwd_data is built when
/// `convKind` is "bwd_data"; otherwise a forward rock.conv is built.
/// Fails (after expandTensor() has recorded a match-failure reason) if any
/// operand could not be regrouped.
static FailureOr<rock::RockConvInterface>
makeRockConv(ConversionPatternRewriter &rw, Operation *op, Value input,
             Value filter, Value output, DenseI64ArrayAttr pad,
             DenseI64ArrayAttr stride, DenseI64ArrayAttr dilation,
             int64_t group, int64_t kernelID, std::string convKind) {
  Location loc = op->getLoc();
  ConvFields convFields =
      commonConv(rw, op, input, filter, output, pad, stride, dilation, group);

  // expandTensor() yields a null Value on failure. Bail out instead of
  // dereferencing it when creating the conv op below.
  if (!convFields.inputExp || !convFields.filterExp || !convFields.outputExp)
    return failure();

  Operation *cop = nullptr;
  if (convKind == "bwd_data") {
    cop = rock::ConvBwdDataOp::create(
        rw, loc, convFields.outputExp.getType(), convFields.filterExp,
        convFields.outputExp, convFields.inputExp,
        /*features=*/nullptr,
        /*blockSize=*/nullptr,
        /*gridSize=*/nullptr, rw.getIndexArrayAttr(pad),
        rw.getIndexArrayAttr(stride), rw.getIndexArrayAttr(dilation),
        /*params=*/nullptr, rw.getIndexAttr(kernelID),
        /*usesV4R1=*/rw.getBoolAttr(false));
  } else {
    // Handle forwards convolution
    assert(convKind != "bwd_weight" && "bwd_weight currently not implemented");
    cop = rock::ConvOp::create(
        rw, loc, convFields.outputExp.getType(), convFields.filterExp,
        convFields.inputExp, convFields.outputExp, /*features=*/nullptr,
        /*blockSize=*/nullptr, /*gridSize=*/nullptr, convFields.pad,
        convFields.stride, convFields.dilation, /*params=*/nullptr);
  }

  addConvAttributes(rw, cop, convFields);

  return cast<rock::RockConvInterface>(cop);
}

/// Walks backwards from `tensor` through view-like ops, expand/collapse
/// shapes, rock.tensor_untransform_cast, and tosa ops (first tensor operand
/// that leads to a hit). Returns `expectedTensor` if it is reached, else null.
/// Answers are memoized in `cache`, keyed by the queried value.
static Value traceToRes(Value tensor, DenseMap<Value, Value> &cache,
                        Value expectedTensor) {
  if (auto hit = cache.find(tensor); hit != cache.end())
    return hit->second;

  auto compute = [&]() -> Value {
    Operation *def = tensor.getDefiningOp();
    // Block arguments never match.
    if (!def)
      return nullptr;
    if (tensor == expectedTensor)
      return tensor;
    if (auto view = dyn_cast<ViewLikeOpInterface>(def))
      return traceToRes(view.getViewSource(), cache, expectedTensor);
    if (auto expand = dyn_cast<tensor::ExpandShapeOp>(def))
      return traceToRes(expand.getSrc(), cache, expectedTensor);
    if (auto collapse = dyn_cast<tensor::CollapseShapeOp>(def))
      return traceToRes(collapse.getSrc(), cache, expectedTensor);
    if (auto untransform = dyn_cast<rock::TensorUntransformCastOp>(def))
      return traceToRes(untransform.getTransformedResult(), cache,
                        expectedTensor);
    if (isa<tosa::TosaOp>(def)) {
      for (Value operand : def->getOperands()) {
        if (!llvm::isa<TensorType>(operand.getType()))
          continue;
        if (Value found = traceToRes(operand, cache, expectedTensor))
          return found;
      }
    }
    return nullptr;
  };

  Value result = compute();
  cache.insert({tensor, result});
  return result;
}

/// Returns the indices of `func`'s results whose returned value traces back
/// (see the cached overload) to `expectedTensor`. Asserts that `func`
/// contains exactly one func.return.
static SetVector<int64_t> traceToRes(Value expectedTensor, func::FuncOp func) {
  SmallVector<func::ReturnOp> returnOps;
  func.walk([&](func::ReturnOp ret) { returnOps.push_back(ret); });
  assert(returnOps.size() == 1 && "Number of returns is not one");

  llvm::DenseMap<Value, Value> memo;
  SetVector<int64_t> matching;
  ValueRange returned = returnOps.front()->getOperands();
  for (int64_t idx = 0, e = returned.size(); idx < e; ++idx) {
    if (traceToRes(returned[idx], memo, expectedTensor) == expectedTensor)
      matching.insert(idx);
  }
  return matching;
}

/// When `op`'s perf_config requests split-K, marks every kernel result that
/// `op`'s output flows into with a zero prefill attribute and
/// "mhal.read_access". If the kernel has an "original_func" attribute, the
/// referenced function (looked up in the grandparent module) also gets
/// "mhal.read_access" on the same result. Does nothing when split-K is not
/// requested.
template <typename OpT>
static LogicalResult setSplitKAttrs(OpT op, rock::GemmFeatures features,
                                    PatternRewriter &rw) {
  auto perfConfig = op->template getAttrOfType<StringAttr>("perf_config");
  if (!perfConfig || !rock::isSplitKRequested(features, perfConfig))
    return success();

  func::FuncOp func = op->template getParentOfType<func::FuncOp>();
  // Without an enclosing function there is no kernel result to trace to.
  if (!func)
    return op.emitOpError(
        "can't trace the operation output to a kernel result");
  SetVector<int64_t> resIndices = traceToRes(op->getResult(0), func);
  if (resIndices.empty())
    return op.emitOpError(
        "can't trace the operation output to a kernel result");

  func::ReturnOp returnOp;
  func.walk([&](func::ReturnOp ret) { returnOp = ret; });

  // Resolve the original (host-side) function once, if one is recorded. Every
  // step may legitimately fail (no grandparent module, symbol not found, not a
  // func.func); in that case only the kernel function is annotated.
  func::FuncOp originalFunc;
  if (auto originalFuncAttr =
          func->getAttrOfType<SymbolRefAttr>("original_func")) {
    ModuleOp kernelMod = func->getParentOfType<ModuleOp>();
    ModuleOp rootMod =
        kernelMod ? kernelMod->getParentOfType<ModuleOp>() : ModuleOp();
    if (rootMod)
      originalFunc = dyn_cast_or_null<func::FuncOp>(
          SymbolTable::lookupSymbolIn(rootMod, originalFuncAttr));
  }

  for (int64_t resNumber : resIndices) {
    Type elementType =
        cast<ShapedType>(returnOp->getOperand(resNumber).getType())
            .getElementType();
    if (!isa<Float32Type, Float16Type, BFloat16Type>(elementType)) {
      return rw.notifyMatchFailure(
          op, "We only support F32, F16 and BF16 split-k, yet.");
    }
    Attribute outputInitVal = rw.getFloatAttr(elementType, 0.0);
    func.setResultAttr(resNumber, rock::PrefillAttr::getMnemonic(),
                       outputInitVal);
    func.setResultAttr(resNumber, "mhal.read_access", rw.getUnitAttr());
    // The original function also needs the read access attr for the output.
    if (originalFunc)
      originalFunc.setResultAttr(resNumber, "mhal.read_access",
                                 rw.getUnitAttr());
  }
  return success();
}

// Tosa ops can broadcast values along axes, which allows for element-wise
// operations without fully-matching dimensions. The Elementwise trait is
// strict about matching dimensions, but broadcastable ops are also
// element-wise, and an explicit list of further ops is treated as
// element-wise as well.
static bool isElementwiseOp(Operation *op) {
  // Trait-based classification first.
  if (op->hasTrait<OpTrait::Elementwise>())
    return true;
  if (op->hasTrait<OpTrait::ResultsBroadcastableShape>())
    return true;
  // Additional Tosa ops that this pass treats as element-wise.
  // clang-format off
  return isa<tosa::CastOp,
             tosa::ClampOp,
             tosa::ErfOp,
             tosa::SigmoidOp,
             tosa::TanhOp,
             tosa::AbsOp,
             tosa::CeilOp,
             tosa::ClzOp,
             tosa::ExpOp,
             tosa::FloorOp,
             tosa::LogOp,
             tosa::LogicalNotOp,
             tosa::NegateOp,
             tosa::ReciprocalOp,
             tosa::RsqrtOp,
             tosa::SelectOp,
             tosa::EqualOp,
             tosa::GreaterOp,
             tosa::GreaterEqualOp,
             tosa::MulOp
            >(op);
  // clang-format on
}

/// Appends to `block` a memref argument with the same shape and element type
/// as the ranked tensor `val`. Returns that argument viewed as a tensor via
/// rock::getAsTensor. `val` is only used for its type.
static Value addBlockArgument(OpBuilder &b, Value val, Block *block,
                              Location loc) {
  auto tensorTy = cast<RankedTensorType>(val.getType());
  auto memrefTy =
      MemRefType::get(tensorTy.getShape(), tensorTy.getElementType());
  Value memrefArg = block->addArgument(memrefTy, loc);
  return rock::getAsTensor(b, loc, memrefArg);
}

/// Starting from the non-null `op`, looks through at most one
/// tensor.expand_shape, then at most one tensor.collapse_shape, then any
/// number of tosa.transpose ops (following operand 0 each time). Returns the
/// op reached if it is a tosa.conv2d or tosa.transpose_conv2d, else nullptr.
static Operation *getConvOp(Operation *op) {
  Operation *cur = op;
  if (isa<tensor::ExpandShapeOp>(cur))
    cur = cur->getOperand(0).getDefiningOp();
  if (cur && isa<tensor::CollapseShapeOp>(cur))
    cur = cur->getOperand(0).getDefiningOp();
  while (cur && isa<tosa::TransposeOp>(cur))
    cur = cur->getOperand(0).getDefiningOp();

  if (cur && isa<tosa::Conv2DOp, tosa::TransposeConv2DOp>(cur))
    return cur;
  return nullptr;
}

/*
GEMM+GEMM based ops can have elementwise region between first gemm and second
gemm. This helps with matching such GEMM+GEMM ops and also constructing the
elementwise region afterwards.
*/
template <typename OpT>
struct ElementwiseRegionFinder {
  /*
  This is simple DFS traversal to find out if it can hit gemm/conv op from the
  input. It keeps track of visited nodes to avoid cycles. It caches visited ops
  in topological order for rewrite. It also caches constant values and block
  argument candidates which will be used during rewrite.
  */
  void visit(Value input) {
    if (visitedSet.contains(input))
      return;
    visitedSet.insert(input);
    OpT fusionOp = input.getDefiningOp<OpT>();
    Operation *op = input.getDefiningOp();

    // We cannot handle bwd_data/weight conv ops + gemm yet, so bail early
    if (std::is_same_v<OpT, tosa::TransposeConv2DOp> && op)
      return;

    // we need to traverse tranposes if it's conv2d
    if (std::is_same_v<OpT, tosa::Conv2DOp> && op) {
      Operation *convOp = getConvOp(op);
      if (convOp)
        fusionOp = cast<OpT>(convOp);
    }
    if (fusionOp) {
      firstGemmBasedOp = fusionOp;
      firstGemmBasedVal = input;
      // cache blockArgCandidates for rewrite
      blockArgCandidates.push_back(input);
      return;
    }
    if (op && dyn_cast<tosa::ConstOp>(op)) {
      constantVals.push_back(input);
      return;
    }
    // Right now, this is a bit restricted that we only allow reshape-like
    // ops between in the elementwise tree that get fused to the fusion point.
    // TODO: however, the latest code gridwise-gemm-to-blockwise should tackle
    // more cases. The absolute restriction is gemm0Output to Linalg block
    // should contain invertible transforms, but that's future work.
    if (!op || (!isElementwiseOp(op) &&
                !isa<tensor::ExpandShapeOp, tensor::CollapseShapeOp>(op))) {
      // cache blockArgCandidates for rewrite
      blockArgCandidates.push_back(input);
      return;
    }
    for (Value operand : op->getOperands()) {
      // do a DFS on each operand
      visit(operand);
    }
    // keep topological order for rewrite
    visitedOps.push_back(op);
  }

  FailureOr<OpT> getFirstGemmBasedOp() const {
    if (!firstGemmBasedOp)
      return failure();
    return firstGemmBasedOp;
  }

  SmallVector<Value> getElementwiseArgs() const {
    // ElementwiseArgs doesn't contain output from the first gemm explictly.
    // Therefore remove it.
    SmallVector<Value> elementwiseArgs = blockArgCandidates;
    uint64_t firstGemmBlockIndex = getFirstGemmBlockIndex();
    elementwiseArgs.erase(elementwiseArgs.begin() + firstGemmBlockIndex);
    return elementwiseArgs;
  }

  int64_t getFirstGemmBlockIndex() const {
    return std::find_if(blockArgCandidates.begin(), blockArgCandidates.end(),
                        [this](Value v) { return v == firstGemmBasedVal; }) -
           blockArgCandidates.begin();
  }

  void rewrite(Value input, OpBuilder &regionBuilder, Block *block,
               Location loc) const {
    PatternRewriter::InsertionGuard guard(regionBuilder);
    regionBuilder.setInsertionPointToEnd(block);
    IRMapping mapper;
    for (Value v : constantVals) {
      auto *newConstOp = regionBuilder.clone(*v.getDefiningOp());
      mapper.map(v, newConstOp->getResult(0));
    }
    for (Value v : blockArgCandidates) {
      auto newBlockArg = addBlockArgument(regionBuilder, v, block, loc);
      mapper.map(v, newBlockArg);
    }
    // make sure firstGemmBasedVal is passed as blockArgument for it is always
    // present
    Value lastRes = mapper.lookup(firstGemmBasedVal);
    for (Operation *op : visitedOps) {
      auto *newOp = regionBuilder.clone(*op, mapper);
      lastRes = newOp->getResult(0);
      mapper.map(lastRes, newOp->getResult(0));
    }
    RankedTensorType resTensorType = cast<RankedTensorType>(lastRes.getType());
    MemRefType resMemRefType = MemRefType::get(resTensorType.getShape(),
                                               resTensorType.getElementType());
    Value resMemref = bufferization::ToBufferOp::create(
        regionBuilder, loc,
        cast<mlir::bufferization::BufferLikeType>(resMemRefType), lastRes);
    Value outMemref = block->addArgument(resMemRefType, loc);
    memref::CopyOp::create(regionBuilder, loc, resMemref, outMemref);
    rock::YieldOp::create(regionBuilder, loc);
  }

private:
  OpT firstGemmBasedOp = nullptr;
  Value firstGemmBasedVal = nullptr;
  DenseSet<Value> visitedSet;
  SmallVector<Value> blockArgCandidates;
  SmallVector<Value> constantVals;
  SmallVector<Operation *> visitedOps;
};

/// If the backward-data convolution `rockConv` (lowered from `op`) may leave
/// some output elements unwritten, marks every result of the enclosing
/// function that `op`'s output flows into with a zero prefill attribute.
static void addZeroInitPrefillAttribute(tosa::TransposeConv2DOp op,
                                        Operation *rockConv) {
  // First check if the TransposeConv2D op is going to require having its
  // output zero-initialized, i.e., not every element of the output buffer is
  // going to be written to.
  rock::ConvolutionContext ctx = rock::populateConvContext(rockConv);
  auto strideDims = ctx.getStrideVal();
  auto dilationDims = ctx.getDilationVal();
  auto filterDims = ctx.getConvDims().fil;

  // If there is no zeroinit kernel needed, then there is nothing more we need
  // to do here.
  if (rock::isEveryElementWrittenBwdData(strideDims, dilationDims, filterDims))
    return;

  // Now we need to determine where to add the prefill attributes. Trace through
  // the output of the TransposeConv2D op to find where the result is used.
  Value output = op.getResult();
  func::FuncOp func = op->getParentOfType<func::FuncOp>();
  if (!func)
    return;

  SetVector<int64_t> resIndices = traceToRes(output, func);
  // If the output cannot be traced to a result index, then we have a case that
  // we cannot yet handle.
  assert(!resIndices.empty() &&
         "Output of TransposeConv2D op cannot be traced to result index");

  OpBuilder builder(op.getContext());
  for (int64_t resNumber : resIndices) {
    Type funcResType = func.getFunctionType().getResult(resNumber);
    auto shapedResType = cast<ShapedType>(funcResType);
    Type elementType = shapedResType.getElementType();

    Attribute outputInitVal;
    if (isa<FloatType>(elementType)) {
      outputInitVal = builder.getFloatAttr(elementType, 0.0);
    } else if (isa<IntegerType>(elementType)) {
      outputInitVal = builder.getIntegerAttr(elementType, 0);
    } else {
      // We only expect integer and float types for now.
      assert(false && "Unsupported element type for prefill attribute");
      // With asserts disabled, skip this result instead of attaching a null
      // attribute.
      continue;
    }

    func.setResultAttr(resNumber, rock::PrefillAttr::getMnemonic(),
                       outputInitVal);
  }
}

/// Lowers a Tosa convolution (OpT) to a Rock convolution. For
/// tosa.transpose_conv2d the "conv_kind" attribute must be "bwd_data". A
/// bias that is not a constant zero is added back with a broadcasting
/// tosa.add.
template <typename OpT>
class ConvConverter final : public OpConversionPattern<OpT> {
public:
  using OpConversionPattern<OpT>::OpConversionPattern;

  LogicalResult matchAndRewrite(OpT op, typename OpT::Adaptor adaptor,
                                ConversionPatternRewriter &rw) const final {
    auto operands = adaptor.getOperands();
    auto loc = op->getLoc();
    auto *context = op->getContext();
    auto input = operands[0];
    auto filter = operands[1];
    auto bias = operands[2];
    auto outputType = cast<RankedTensorType>(op.getType());

    rock::GemmFeatures features = getGemmFeaturesFromOp(op, input.getType());

    if (failed(setSplitKAttrs(op, features, rw)))
      return failure();

    Value output =
        bufferization::AllocTensorOp::create(rw, loc, outputType, ValueRange{});

    auto groupAttr = op->template getAttrOfType<IntegerAttr>("group");
    auto padAttr = op->template getAttrOfType<DenseI64ArrayAttr>("pad");
    auto dilationAttr =
        op->template getAttrOfType<DenseI64ArrayAttr>("dilation");

    // Verify all required attributes are present
    int64_t group = 1;
    if (groupAttr)
      group = groupAttr.getInt();

    if (!padAttr)
      return op->emitError(
          "Expected 'pad' attribute to be present on the operation");

    if (!dilationAttr)
      return op->emitError(
          "Expected 'dilation' attribute to be present on the operation");

    std::string convKind = "";
    if (isa<tosa::TransposeConv2DOp>(op)) {
      auto convKindAttr = op->template getAttrOfType<StringAttr>("conv_kind");
      if (!convKindAttr)
        return op->emitError(
            "Expected 'conv_kind' attribute to be present on the operation");
      convKind = convKindAttr.str();
      // bwd_weight is currently not supported. Fail here, because continuing
      // would lower the op as a different kind of convolution.
      if (convKind == "bwd_weight")
        return op->emitError(
            "TosaToRock lowering support for bwd_weight not supported");
      if (convKind != "bwd_data")
        return op->emitError("Expected bwd_data conv op");
    }

    FailureOr<rock::RockConvInterface> rockConv =
        makeRockConv(rw, op, input, filter, output, padAttr, op.getStrideAttr(),
                     dilationAttr, group, /*kernelID=*/0, convKind);
    // Check before touching the result: dereferencing a failed FailureOr is
    // invalid.
    if (failed(rockConv))
      return failure();

    if (convKind == "bwd_data")
      addZeroInitPrefillAttribute(cast<tosa::TransposeConv2DOp>(op),
                                  rockConv->getOperation());

    Value result;
    if (isa<tosa::TransposeConv2DOp>(op)) {
      result = output;
    } else {
      Operation *rockConvOp = rockConv->getOperation();
      result = rock::TensorUntransformCastOp::create(
          rw, loc, outputType, rockConvOp->getResult(0), rockConv->getOutput());
    }

    // test for zero bias, and ignore
    if (!isConstantZero(op.getOperand(2))) {
      // non-zero bias, replace with tosa.add w/ broadcast
      auto biasType = cast<ShapedType>(bias.getType());
      if (!biasType.hasStaticShape())
        return failure();

      int64_t nDims = cast<ShapedType>(input.getType()).getRank();
      SmallVector<int64_t> biasShape;
      for (int i = 0; i < nDims - 1; i++)
        biasShape.push_back(1);
      biasShape.push_back(biasType.getShape()[0]);
      auto newType =
          RankedTensorType::get(biasShape, biasType.getElementType());

      // [[0, 1, 2, 3]]
      ReassociationExprs exprs;
      for (int i = 0; i < nDims; i++)
        exprs.push_back(getAffineDimExpr(i, context));
      SmallVector<ReassociationExprs, 1> reassociations;
      reassociations.push_back(exprs);

      auto biasExpand =
          tensor::ExpandShapeOp::create(rw, loc, newType, bias, reassociations);

      result = tosa::AddOp::create(rw, loc, op.getType(),
                                   ValueRange{result, biasExpand});
    }
    rw.replaceOp(op, result);
    return success();
  }
};

/// Returns `inp` viewed through a rock.transform that broadcasts each
/// dimension of size 1 whose size in `outShape` is not 1. Other dimensions
/// pass through unchanged. If no dimension needs broadcasting, `inp` is
/// returned as-is. Indexes the input shape by `outShape`'s positions, so the
/// ranks are assumed to match (not checked here).
static Value insertBroadcast(Value inp, ArrayRef<int64_t> outShape,
                             Location loc, OpBuilder &b) {
  ArrayRef<int64_t> srcShape = cast<ShapedType>(inp.getType()).getShape();
  rock::BottomUpTMBuilder view(b, srcShape, loc);
  bool anyBroadcast = false;
  for (auto [pos, target] : llvm::enumerate(outShape)) {
    unsigned int dim = static_cast<unsigned int>(pos);
    bool expands = srcShape[dim] == 1 && target != 1;
    if (expands)
      view.broadcast({dim}, {target});
    else
      view.passThrough({dim}, {dim});
    anyBroadcast |= expands;
  }
  // Avoid emitting an identity transform.
  if (!anyBroadcast)
    return inp;
  return rock::TransformOp::create(b, loc, inp, view.get());
}

/// Lowers `tosa.matmul` to `rock.gemm`.
///
/// Optional BoolAttr attributes `transpose_a`, `transpose_b` and
/// `transpose_c` on the source op become the matching UnitAttr flags on the
/// gemm. Size-1 M/K dimensions of A and K/N dimensions of B are broadcast
/// (via insertBroadcast) to the sizes implied by the output and the larger of
/// the two K dimensions. A `perf_config` StringAttr is forwarded when present.
class MatMulConverter final : public OpConversionPattern<tosa::MatMulOp> {
public:
  using OpConversionPattern<tosa::MatMulOp>::OpConversionPattern;

  /// Returns a UnitAttr when `op` carries BoolAttr `name` set to true,
  /// otherwise a null attribute.
  UnitAttr getTranspose(tosa::MatMulOp op, StringRef name) const {
    if (auto attr = op->getAttrOfType<BoolAttr>(name)) {
      if (attr.getValue())
        return UnitAttr::get(op->getContext());
    }
    return nullptr;
  }

  /// Returns the last two dimensions of `type` in logical order: the stored
  /// order normally, and swapped when `transposed` is set.
  std::tuple<int64_t, int64_t> getLastDims(UnitAttr transposed,
                                           RankedTensorType type) const {
    ArrayRef<int64_t> shape = type.getShape();
    int64_t rank = type.getRank();
    if (transposed) {
      return {shape[rank - 1], shape[rank - 2]};
    }
    return {shape[rank - 2], shape[rank - 1]};
  }

  /// Inverse of getLastDims: stores the logical pair `lastDims` into the last
  /// two entries of `shape`, swapping them when `transposed` is set.
  void setLastDims(UnitAttr transposed, SmallVectorImpl<int64_t> &shape,
                   std::pair<int64_t, int64_t> lastDims) const {
    size_t rank = shape.size();
    if (transposed) {
      shape[rank - 1] = lastDims.first;
      shape[rank - 2] = lastDims.second;
    } else {
      shape[rank - 2] = lastDims.first;
      shape[rank - 1] = lastDims.second;
    }
  }

  LogicalResult matchAndRewrite(tosa::MatMulOp op,
                                tosa::MatMulOp::Adaptor adaptor,
                                ConversionPatternRewriter &rw) const final {
    Location loc = op->getLoc();
    auto outputType = cast<RankedTensorType>(op.getType());
    Value output =
        bufferization::AllocTensorOp::create(rw, loc, outputType, ValueRange{});

    rock::GemmFeatures features =
        getGemmFeaturesFromOp(op, op.getA().getType());

    if (failed(setSplitKAttrs(op, features, rw)))
      return failure();

    UnitAttr transposeA = getTranspose(op, "transpose_a"),
             transposeB = getTranspose(op, "transpose_b"),
             transposeC = getTranspose(op, "transpose_c");

    auto [mDim, nDim] = getLastDims(transposeC, outputType);

    int64_t kDimOfA;
    std::tie(std::ignore, kDimOfA) =
        getLastDims(transposeA, cast<RankedTensorType>(op.getA().getType()));
    int64_t kDimOfB;
    std::tie(kDimOfB, std::ignore) =
        getLastDims(transposeB, cast<RankedTensorType>(op.getB().getType()));
    // One side may have K == 1 (to be broadcast); use the larger K. Keep it
    // 64-bit: shape entries are int64_t and narrowing to `int` would truncate
    // large (or dynamic-sentinel) dimensions.
    int64_t kDim = (kDimOfA > kDimOfB) ? kDimOfA : kDimOfB;

    SmallVector<int64_t, 3> aShape = llvm::to_vector<3>(
        cast<RankedTensorType>(op.getA().getType()).getShape());
    setLastDims(transposeA, aShape, {mDim, kDim});
    Value brA = insertBroadcast(adaptor.getA(), aShape, loc, rw);

    SmallVector<int64_t, 3> bShape = llvm::to_vector<3>(
        cast<RankedTensorType>(op.getB().getType()).getShape());
    setLastDims(transposeB, bShape, {kDim, nDim});
    Value brB = insertBroadcast(adaptor.getB(), bShape, loc, rw);

    auto rockGemm = rock::GemmOp::create(
        rw, loc, outputType, brA, brB, output, transposeA, transposeB,
        transposeC,
        /*features=*/nullptr,
        rw.getAttr<rock::StoreMethodAttr>(rock::StoreMethod::Set),
        /*blockSize=*/nullptr, /*gridSize=*/nullptr,
        /*params=*/nullptr);

    if (auto attr = op->getAttrOfType<StringAttr>("perf_config"))
      rockGemm->setAttr("perf_config", attr);

    rw.replaceOp(op, rockGemm.getResult());

    return success();
  }
};

/// Permute the layout string stored in attribute `attrKey` of `op` (or
/// `layoutDefault` when the attribute is absent) by `permDims`, and write the
/// result back to `attrKey`.
///
/// With `isInput` set, character i of the old layout moves to position
/// permDims[i]; otherwise position i receives old character permDims[i].
static void permuteLayout(Operation *op, const char *attrKey,
                          const char *layoutDefault,
                          const ArrayRef<int32_t> permDims,
                          bool isInput = false) {
  auto existingAttr = op->getAttrOfType<StringAttr>(attrKey);
  StringRef source =
      existingAttr ? existingAttr.getValue() : StringRef(layoutDefault);
  SmallString<4> permuted(source);
  for (auto [idx, dim] : llvm::enumerate(permDims)) {
    if (isInput)
      permuted[dim] = source[idx];
    else
      permuted[idx] = source[dim];
  }
  op->setAttr(attrKey, StringAttr::get(op->getContext(), permuted));
}

/// Folds `tosa.transpose` ops into neighboring gemm-like ops (tosa.conv2d,
/// tosa.transpose_conv2d, tosa.matmul) by rewriting their layout / transpose
/// attributes, instead of materializing the transpose.
struct TransposeRewritePattern : public OpRewritePattern<tosa::TransposeOp> {
  using OpRewritePattern<tosa::TransposeOp>::OpRewritePattern;

  /// XOR the boolean attribute `name` on `op` with `isNonTrivial` (a missing
  /// attribute counts as false). The attribute is always (re)written.
  void setTranspose(Operation *op, StringRef name, bool isNonTrivial) const {
    bool currentValue = false;
    if (auto attr = op->getAttrOfType<BoolAttr>(name))
      currentValue = attr.getValue();
    bool newValue = currentValue ^ isNonTrivial;
    op->setAttr(name, BoolAttr::get(op->getContext(), newValue));
  }

  /// Fail the match unless `tInput` has exactly one use (the transpose `top`).
  LogicalResult checkInputHasUses(PatternRewriter &rewriter,
                                  tosa::TransposeOp top, Value tInput) const {
    // if the input has uses (apart from this one), we can't do this
    if (!tInput.hasOneUse()) {
      return rewriter.notifyMatchFailure(
          top, "abandoning attempt to fuse transpose "
               "because the operation has other uses");
    }
    return success();
  }

  /// Reject (with a warning) permutations that move the batch dimension,
  /// which is expected to be the third-from-last dimension.
  LogicalResult checkMatMulTransposeValid(tosa::MatMulOp matmulOp,
                                          const ArrayRef<int32_t> dims) const {
    // batch dimension is expected to be 3rd from the last.
    if (dims.size() >= 3 &&
        dims[dims.size() - 3] != static_cast<int32_t>(dims.size()) - 3) {
      return matmulOp.emitWarning(
          "Transposing the batch dimension out of place lowers performance");
    }
    return success();
  }

  /// True iff the permutation swaps the last two dimensions.
  bool isMatMulNonTrivial(const ArrayRef<int32_t> dims) const {
    int32_t lastDim = dims.size() - 1;
    int32_t prevLastDim = dims.size() - 2;
    return (dims[prevLastDim] == lastDim && dims[lastDim] == prevLastDim);
  }

  // This function traverses the uses of tOutput, modifies each use to
  // indicate that its input is transposed, and replaces it with tInput. If
  // collapse shapes are encountered, the collapse is applied to tInput
  // instead and the traversal recurses into the collapse's uses.
  //
  // In-place updates of users are wrapped in rewriter.modifyOpInPlace so the
  // pattern driver is notified of them.
  LogicalResult mergeTransposeWithGemmLikeOp(PatternRewriter &rewriter,
                                             Value tOutput,
                                             const ArrayRef<int32_t> dims,
                                             Value tInput) const {
    auto handleConv = [&](auto convOp) -> LogicalResult {
      if (convOp.getInput() == tOutput) {
        rewriter.modifyOpInPlace(convOp, [&] {
          permuteLayout(convOp.getOperation(), "input_layout", "nhwc", dims,
                        true);
          convOp.getInputMutable().assign(tInput);
        });
      } else if (convOp.getWeight() == tOutput) {
        rewriter.modifyOpInPlace(convOp, [&] {
          permuteLayout(convOp.getOperation(), "filter_layout", "kyxc", dims,
                        true);
          convOp.getWeightMutable().assign(tInput);
        });
      } else {
        return convOp.emitWarning("transpose found leading to a "
                                  "conv2D/transposeConv2D input other than "
                                  "data or weight");
      }
      return success();
    };

    for (auto &use : llvm::make_early_inc_range(tOutput.getUses())) {
      if (auto op = dyn_cast<tensor::CollapseShapeOp>(use.getOwner())) {
        SmallVector<ReassociationIndices, 4> reassocIndices =
            op.getReassociationIndices();
        // A collapse to a 0-d tensor has no reassociation groups. The logic
        // below needs at least one group (it uses size() - 1 and back()), so
        // bail out instead of reading out of bounds.
        if (reassocIndices.empty())
          return rewriter.notifyMatchFailure(
              op, "abandoning attempt to interchange transpose and "
                  "collapse to a 0-d tensor");
        // This is to capture new reassociations above the transpose
        llvm::SmallDenseMap<int32_t, ReassociationIndices> newReassocIdxMap;
        ArrayRef<int64_t> inShape = op.getSrcType().getShape();

        // This loop maps reassociated dims back to pre-transpose dims.
        SmallVector<int32_t, 4> newDims;

        llvm::SmallDenseSet<int64_t> preTpUnitDims;
        for (ReassociationIndices indices : reassocIndices) {
          ReassociationIndices newReassocIdx;
          size_t numNonUnitDimsMerged = 0;
          for (size_t i = 0, e = indices.size(); i < e; ++i) {
            if (inShape[indices[i]] == 1) {
              preTpUnitDims.insert(dims[indices[i]]);
            } else {
              numNonUnitDimsMerged += 1;
            }
            newReassocIdx.push_back(dims[indices[i]]);
          }
          if (numNonUnitDimsMerged > 1) {
            // Per MIGraphX bug #2692, this transpose/collapse swap logic
            // will be incorrect in cases like the following
            //   %0 = expand_shape [[0], [1, 2], [3]] %arg0 : tensor<7x6x5xT>
            //   to tensor<7x3x2x5xT> %1 = transpose %0, [0, 2, 1, 3] :
            //   tensor<7x2x3x5xT> %2 = collapse_shape [[0], [1, 2], [2]] %1 :
            //   tensor<7x2x3x5xT> to tensor<7x6x5xT>
            // by way of creating a trivial expand/collapse pair that isn't
            // correct.
            //
            // Therefore, as a sledgehammer fix, don't handle any cases where
            // non-trivial collapses are performed.
            return rewriter.notifyMatchFailure(
                op, "abandoning attempt to interchange transpose and "
                    "non-trivial collapse");
          }
          if (newReassocIdx.size() > 1) {
            llvm::sort(newReassocIdx);
            // Remove unit dims from larger end of reassociation indices
            // but we need at least one for the reassociation
            while (newReassocIdx.size() > 1 &&
                   preTpUnitDims.contains(newReassocIdx.back())) {
              newReassocIdx.pop_back();
            }
            for (size_t i = 1; i < newReassocIdx.size(); i++) {
              if (newReassocIdx[i] - newReassocIdx[i - 1] != 1) {
                return rewriter.notifyMatchFailure(
                    op, "CollapseShape op following transpose collapses "
                        "non-contigous pre-transpose dims.");
              }
            }
          }
          newDims.push_back(newReassocIdx[0]);
          // minIdx is the representative of a group that is
          // being collapsed. E.g. a collapse of [3,4,5] is assigned
          // 3 as the representative. Note that we only allow
          // collapsing of contiguous pre-transpose dims.
          newReassocIdxMap[newReassocIdx[0]] = newReassocIdx;
        }

        // Assign the ordering index of reassociated dims as the dim index
        SmallVector<int32_t, 4> newDimsSorted = newDims;
        llvm::sort(newDimsSorted);
        SmallVector<ReassociationIndices, 4> newReassocIndicesSorted;
        DenseMap<int32_t, int32_t> dimMap;
        // The vector of newDims may contain a discontinuous
        // range of representative minIdxs. Here we make
        // it contiguous by assigning order idx.
        for (size_t i = 0; i < newDimsSorted.size(); i++) {
          dimMap[newDimsSorted[i]] = i;
          newReassocIndicesSorted.push_back(newReassocIdxMap[newDimsSorted[i]]);
        }
        // HOTFIX: glue trailing unit dimensions onto collapses that need
        // them. This is because a case like
        // %t = transpose %aRaw [0, 1, 3, 2] : tensor<1x1xKxM> ->
        // tensor<1x1xMxK> %a = collapse_shape [[0, 1], [2], [3]]
        //    : tensor<1x1xMxK> -> tensor<1xMxK>
        // will, with the above unit-dimension-removal logic, lead to the
        // invalid reassociation [[0], [2], [3]], causing a crash.
        // See MIGraphX bug #2365.
        // The entire logic here should be reviewed, or at least made less
        // complex if possible, but ... release-critical bug, what can we do?
        for (size_t i = 0, e = newReassocIndicesSorted.size() - 1; i < e; ++i) {
          ReassociationIndices &theseIndices = newReassocIndicesSorted[i];
          const ReassociationIndices &nextIndices =
              newReassocIndicesSorted[i + 1];
          while (theseIndices.back() + 1 < nextIndices[0]) {
            theseIndices.push_back(theseIndices.back() + 1);
          }
        }
        // do the same for the last set of indices too
        // when it does not reach the rank of the input.
        ReassociationIndices &lastIndices = newReassocIndicesSorted.back();
        while (lastIndices.back() + 1 < (int64_t)inShape.size()) {
          lastIndices.push_back(lastIndices.back() + 1);
        }

        for (size_t i = 0; i < newDims.size(); i++) {
          newDims[i] = dimMap[newDims[i]];
        }

        tensor::CollapseShapeOp newCollapseShapeOp =
            tensor::CollapseShapeOp::create(rewriter, op.getLoc(), tInput,
                                            newReassocIndicesSorted);

        if (mergeTransposeWithGemmLikeOp(rewriter, op.getResult(), newDims,
                                         newCollapseShapeOp.getResult())
                .failed()) {
          rewriter.eraseOp(newCollapseShapeOp);
          return failure();
        }
        if (op->use_empty())
          rewriter.eraseOp(op);
      } else if (auto op = dyn_cast<tensor::ExpandShapeOp>(use.getOwner())) {
        return rewriter.notifyMatchFailure(
            op, "We dont support expand shapes yet.");
      } else if (auto transposeConv2D =
                     dyn_cast<tosa::TransposeConv2DOp>(use.getOwner())) {
        return handleConv(transposeConv2D);
      } else if (auto conv2D = dyn_cast<tosa::Conv2DOp>(use.getOwner())) {
        return handleConv(conv2D);
      } else if (auto matMulOp = dyn_cast<tosa::MatMulOp>(use.getOwner())) {
        if (checkMatMulTransposeValid(matMulOp, dims).failed()) {
          return failure();
        }
        bool mmNonTrivial = isMatMulNonTrivial(dims);
        if (matMulOp.getA() == tOutput) {
          rewriter.modifyOpInPlace(matMulOp, [&] {
            setTranspose(matMulOp, "transpose_a", mmNonTrivial);
            matMulOp.getAMutable().assign(tInput);
          });
        } else if (matMulOp.getB() == tOutput) {
          rewriter.modifyOpInPlace(matMulOp, [&] {
            setTranspose(matMulOp, "transpose_b", mmNonTrivial);
            matMulOp.getBMutable().assign(tInput);
          });
        } else {
          return matMulOp.emitWarning(
              "transpose found leading to a matmul input other than A or B");
        }
      } else {
        return failure();
      }
    }
    return success();
  }

  // Fold transpose ops and convert convolution into changed layout.
  // case #0 : fold TP(NCHW2NHWC)+tosa.conv.NHWC+TP(NHWC2NCHW) back to
  //           rock.conv.NCHW
  // Pattern match start from the output transpose
  LogicalResult matchAndRewrite(tosa::TransposeOp top,
                                PatternRewriter &b) const final {
    const auto dims = top.getPerms();

    Value tInput = top.getInput1();
    Value tOutput = top.getResult();
    Operation *definingOp = tInput.getDefiningOp();
    if (definingOp &&
        isa<tosa::Conv2DOp, tosa::TransposeConv2DOp>(definingOp)) {
      Operation *convOp = definingOp;
      if (checkInputHasUses(b, top, tInput).failed()) {
        return failure();
      }
      // conv output is transposed: fold the permutation into output_layout.
      b.modifyOpInPlace(convOp, [&] {
        permuteLayout(convOp, "output_layout", "nhwk", dims);
        convOp->getResult(0).setType(tOutput.getType());
      });
      b.replaceAllUsesWith(tOutput, convOp->getResult(0));
    } else if (tosa::MatMulOp matMulOp =
                   tInput.getDefiningOp<tosa::MatMulOp>()) {

      if (checkInputHasUses(b, top, tInput).failed()) {
        return failure();
      }
      if (checkMatMulTransposeValid(matMulOp, dims).failed()) {
        return failure();
      }
      b.modifyOpInPlace(matMulOp, [&] {
        setTranspose(matMulOp, "transpose_c", isMatMulNonTrivial(dims));
        matMulOp->getResult(0).setType(tOutput.getType());
      });
      b.replaceAllUsesWith(tOutput, matMulOp->getResult(0));
    } else {
      if (mergeTransposeWithGemmLikeOp(b, tOutput, dims, tInput).failed()) {
        return failure();
      }
    }

    if (top.use_empty())
      b.eraseOp(top);
    return success();
  }
};

// In Tosa canonicalize, a transpose of NCHW to NHWC where H==W==1 will
// convert to a reshape because it does not change memory layout. Then in
// TosaToTensor conversion, the reshape is replaced by this pattern:
//     %0 = collapse(filters[KCHW]) -> [KC]
//     %1 = expand(%0[KC]) -> [KHWC]
// If this feeds into a conv as filter, we will drop the collapse/expand and
// update the filter_layout attribute.
struct CollapseExpandRewritePattern
    : public OpRewritePattern<tensor::ExpandShapeOp> {
  using OpRewritePattern<tensor::ExpandShapeOp>::OpRewritePattern;

  /// True when `expOp` maps a rank-2 [N, C] tensor to a rank-4 [N, 1, 1, C].
  bool checkExpand(tensor::ExpandShapeOp expOp) const {
    auto srcSh = cast<ShapedType>(expOp.getOperand(0).getType()).getShape();
    auto resSh = cast<ShapedType>(expOp.getResultType()).getShape();
    // [[0, 1, 2], [3]]
    // NC -> NHWC
    if (srcSh.size() == 2 && resSh.size() == 4 && srcSh[0] == resSh[0] &&
        srcSh[1] == resSh[3] && resSh[1] == 1 && resSh[2] == 1) {
      return true;
    }
    return false;
  }

  /// True when `colOp` maps a rank-4 [N, C, 1, 1] tensor to a rank-2 [N, C].
  bool checkCollapse(tensor::CollapseShapeOp colOp) const {
    auto srcSh = cast<ShapedType>(colOp.getOperand().getType()).getShape();
    auto resSh = cast<ShapedType>(colOp.getResultType()).getShape();
    // [[0], [1, 2, 3]]
    // NCHW -> NC
    if (srcSh.size() == 4 && resSh.size() == 2 && srcSh[0] == resSh[0] &&
        srcSh[1] == resSh[1] && srcSh[2] == 1 && srcSh[3] == 1) {
      return true;
    }
    return false;
  }

  LogicalResult matchAndRewrite(tensor::ExpandShapeOp expOp,
                                PatternRewriter &b) const final {
    if (!checkExpand(expOp))
      return failure();

    Value expInp = expOp.getOperand(0);
    auto colOp = expInp.getDefiningOp<tensor::CollapseShapeOp>();
    if (!colOp || !checkCollapse(colOp))
      return failure();

    Value expOut = expOp.getResult();
    Value colInp = colOp.getOperand();

    // Snapshot the users first. Rewiring an operand unlinks that use from
    // expOut's use list, which would otherwise invalidate the iteration.
    SmallVector<Operation *> users = llvm::to_vector(expOut.getUsers());
    bool changed = false;
    for (Operation *usr : users) {
      if (!isa<tosa::TransposeConv2DOp, tosa::Conv2DOp>(usr))
        continue;
      // Only the filter operand (operand 1) is rewired.
      if (usr->getOperand(1) != expOut)
        continue;
      b.modifyOpInPlace(usr, [&] {
        // update filter_layout: KCHW source viewed through the NHWC expand
        SmallVector<int32_t> dims{0, 2, 3, 1};
        permuteLayout(usr, "filter_layout", "kyxc", dims, true);
        // replace filter input with collapse source
        usr->setOperand(1, colInp);
      });
      changed = true;
    }

    return success(changed);
  }
};

/// Fuses `tosa.conv2d` -> elementwise region -> `tosa.matmul` (the conv result
/// feeding the matmul's A operand, as found by ElementwiseRegionFinder) into a
/// single rock::ConvElementwiseGemmOp.
struct ConvElementwiseGemmRewritePattern
    : public OpRewritePattern<tosa::MatMulOp> {
  using OpRewritePattern::OpRewritePattern;

  /// Search backwards from `op`'s A operand for a conv2d reached through
  /// elementwise ops. Returns the region finder on success. Fails, and emits
  /// an error on `op`, when that conv has a non-zero bias.
  FailureOr<ElementwiseRegionFinder<tosa::Conv2DOp>>
  match(tosa::MatMulOp op) const {
    ElementwiseRegionFinder<tosa::Conv2DOp> elementwiseRegionFinder;
    elementwiseRegionFinder.visit(op.getA());
    FailureOr<tosa::Conv2DOp> maybeConv =
        elementwiseRegionFinder.getFirstGemmBasedOp();

    if (failed(maybeConv)) {
      LLVM_DEBUG(llvm::dbgs() << "conv not found\n");
      return failure();
    }
    LLVM_DEBUG(llvm::dbgs() << "conv = " << maybeConv.value() << "\n");

    tosa::Conv2DOp firstConv = maybeConv.value();
    // bias not supported
    if (!isConstantZero(firstConv.getBias())) {
      op.emitOpError("bias not supported yet");
      return failure();
    }
    return elementwiseRegionFinder;
  }

  /// Build the fused op for `op` using the region found by match() and
  /// replace `op` with it. Returns failure (leaving `op` in place) when the
  /// split-K attributes cannot be set.
  LogicalResult rewrite(
      tosa::MatMulOp op,
      const ElementwiseRegionFinder<tosa::Conv2DOp> &elementwiseRegionFinder,
      PatternRewriter &rewriter) const {
    Location loc = op.getLoc();
    auto outputType = cast<RankedTensorType>(op.getType());
    Value output = bufferization::AllocTensorOp::create(
        rewriter, loc, outputType, ValueRange{});

    // This is guaranteed by the matcher
    tosa::Conv2DOp firstConv =
        elementwiseRegionFinder.getFirstGemmBasedOp().value();

    SmallVector<Value> elementwiseOtherArgs =
        elementwiseRegionFinder.getElementwiseArgs();

    // NOTE(review): "group" is read from the matmul `op`, not from
    // `firstConv`; confirm this is the intended source for grouped convs.
    int64_t group = 1;
    if (auto attr = op->getAttrOfType<IntegerAttr>("group"))
      group = attr.getInt();
    ConvFields convFields =
        commonConv(rewriter, op, firstConv.getInput(), firstConv.getWeight(),
                   output, firstConv.getPadAttr(), firstConv.getStrideAttr(),
                   firstConv.getDilationAttr(), group);
    auto firstGemmBlockIndex = elementwiseRegionFinder.getFirstGemmBlockIndex();

    rock::GemmFeatures featuresA =
        getGemmFeaturesFromOp(op, convFields.filterExp.getType());
    rock::GemmFeatures featuresC =
        getGemmFeaturesFromOp(op, op.getB().getType());
    rock::GemmFeatures features = intersectGemmFeatures(featuresA, featuresC);

    // The features depend on the expanded filter type produced by commonConv,
    // so this check cannot happen earlier. Report the failure instead of
    // claiming success with `op` still in place; the ops created above are
    // left without uses.
    if (failed(setSplitKAttrs(op, features, rewriter)))
      return failure();

    auto convElementwiseGemmOp = rock::ConvElementwiseGemmOp::create(
        rewriter, loc, outputType, convFields.filterExp, convFields.inputExp,
        op.getB(), elementwiseOtherArgs, output,
        /*cTransposed=*/nullptr,
        /*oTransposed=*/nullptr, /*features=*/nullptr,
        rewriter.getAttr<rock::StoreMethodAttr>(rock::StoreMethod::Set),
        convFields.pad, convFields.stride, convFields.dilation,
        /*params0=*/nullptr, /*params1=*/nullptr,
        /*firstGemmIndices=*/
        rewriter.getDenseI64ArrayAttr(firstGemmBlockIndex));

    addConvAttributes(rewriter, convElementwiseGemmOp, convFields);

    Block *preSecondGemmElemwiseBlock =
        &convElementwiseGemmOp.getPreSecondGemmBody().emplaceBlock();
    {
      PatternRewriter::InsertionGuard guard(rewriter);
      rewriter.setInsertionPointToStart(preSecondGemmElemwiseBlock);
      elementwiseRegionFinder.rewrite(op.getA(), rewriter,
                                      preSecondGemmElemwiseBlock, loc);
    }
    if (auto attr = op->getAttrOfType<StringAttr>("perf_config"))
      convElementwiseGemmOp->setAttr("perf_config", attr);

    rewriter.replaceOp(op, convElementwiseGemmOp.getResult());
    return success();
  }

  LogicalResult matchAndRewrite(tosa::MatMulOp op,
                                PatternRewriter &rewriter) const override {
    FailureOr<ElementwiseRegionFinder<tosa::Conv2DOp>> elemwiseFinder =
        match(op);
    if (failed(elemwiseFinder))
      return failure();
    return rewrite(op, elemwiseFinder.value(), rewriter);
  }
};

/// Fuses `tosa.matmul` -> elementwise region -> `tosa.matmul` (the first
/// matmul feeding the second's A operand, as found by ElementwiseRegionFinder)
/// into a single rock::GemmElementwiseGemmOp.
struct GemmElementwiseGemmRewritePattern
    : public OpRewritePattern<tosa::MatMulOp> {
  using OpRewritePattern::OpRewritePattern;

  /// Search backwards from `op`'s A operand for a first matmul reached
  /// through elementwise ops. Returns the region finder on success.
  FailureOr<ElementwiseRegionFinder<tosa::MatMulOp>>
  match(tosa::MatMulOp op) const {
    ElementwiseRegionFinder<tosa::MatMulOp> elemwiseRegionFinder;
    elemwiseRegionFinder.visit(op.getA());
    FailureOr<tosa::MatMulOp> maybeFirstMatMul =
        elemwiseRegionFinder.getFirstGemmBasedOp();
    if (failed(maybeFirstMatMul)) {
      LLVM_DEBUG(llvm::dbgs() << "first matmul not found\n");
      return failure();
    }
    LLVM_DEBUG(llvm::dbgs()
               << "first matmul = " << maybeFirstMatMul.value() << "\n");
    return elemwiseRegionFinder;
  }

  /// Build the fused op for `op` using the region found by match() and
  /// replace `op` with it. Returns failure (leaving `op` in place) when the
  /// split-K attributes cannot be set.
  LogicalResult
  rewrite(tosa::MatMulOp op,
          const ElementwiseRegionFinder<tosa::MatMulOp> &elemwiseFinder,
          PatternRewriter &rewriter) const {
    Location loc = op.getLoc();

    // This is guaranteed by the matcher
    tosa::MatMulOp firstMatMulOp = elemwiseFinder.getFirstGemmBasedOp().value();

    // The features only need existing operand types, so check split-K
    // support before creating the alloc / fused op. A failure must be
    // reported rather than returning "success" with `op` still in place.
    rock::GemmFeatures featuresA =
        getGemmFeaturesFromOp(op, firstMatMulOp.getA().getType());
    rock::GemmFeatures featuresC =
        getGemmFeaturesFromOp(op, op.getB().getType());
    rock::GemmFeatures features = intersectGemmFeatures(featuresA, featuresC);
    if (failed(setSplitKAttrs(op, features, rewriter)))
      return failure();

    auto outputType = cast<RankedTensorType>(op.getType());
    Value output = bufferization::AllocTensorOp::create(
        rewriter, loc, outputType, ValueRange{});
    SmallVector<Value> elementwiseOtherArgs =
        elemwiseFinder.getElementwiseArgs();
    int64_t firstGemmBlockIndex = elemwiseFinder.getFirstGemmBlockIndex();

    rock::GemmElementwiseGemmOp gemmElementwiseGemmOp =
        rock::GemmElementwiseGemmOp::create(
            rewriter, loc, outputType, firstMatMulOp.getA(),
            firstMatMulOp.getB(), op.getB(), elementwiseOtherArgs, output,
            /*qTransposed=*/nullptr,
            /*kTransposed=*/nullptr,
            /*vTransposed=*/nullptr,
            /*oTransposed=*/nullptr,
            /*features=*/nullptr,
            rewriter.getAttr<rock::StoreMethodAttr>(rock::StoreMethod::Set),
            /*params0=*/nullptr, /*params1=*/nullptr,
            /*firstGemmIndices=*/
            rewriter.getDenseI64ArrayAttr(firstGemmBlockIndex));
    Block *preSecondGemmElemwiseBlock =
        &gemmElementwiseGemmOp.getPreSecondGemmBody().emplaceBlock();
    {
      PatternRewriter::InsertionGuard guard(rewriter);
      rewriter.setInsertionPointToStart(preSecondGemmElemwiseBlock);
      elemwiseFinder.rewrite(op.getA(), rewriter, preSecondGemmElemwiseBlock,
                             loc);
    }
    if (auto attr = op->getAttrOfType<StringAttr>("perf_config"))
      gemmElementwiseGemmOp->setAttr("perf_config", attr);

    rewriter.replaceOp(op, gemmElementwiseGemmOp.getResult());
    return success();
  }

  LogicalResult matchAndRewrite(tosa::MatMulOp op,
                                PatternRewriter &rewriter) const override {
    FailureOr<ElementwiseRegionFinder<tosa::MatMulOp>> elemwiseFinder =
        match(op);
    if (failed(elemwiseFinder))
      return failure();
    return rewrite(op, elemwiseFinder.value(), rewriter);
  }
};

/// Values recorded for a matched softmax sub-graph.
/// NOTE(review): field meanings are inferred from their names; confirm
/// against the matcher that fills this struct in.
/// Raw pointers and flags default to null/false so a partially filled
/// instance never holds indeterminate values.
struct SoftmaxMatcherValues {
  Value softmaxInput;
  Operation *subOp = nullptr;
  Value exp;
  Operation *reduceMaxOp = nullptr;
  Operation *reduceSumOp = nullptr;
  bool hasReduceOp = false;
};

/// Values recorded for a matched attention sub-graph.
/// NOTE(review): field meanings are inferred from their names; confirm
/// against the matcher that fills this struct in.
/// `isCausal` defaults to false so it is never read uninitialized.
struct AttentionMatcherValues {
  SoftmaxMatcherValues softmaxValues;
  Value lse;
  Value causalMaskInput;
  Value currentSeqLen;
  bool isCausal = false;
  Type softmaxType;
  ElementwiseRegionFinder<tosa::MatMulOp> preSoftmaxElementwiseFinder;
};

struct AttentionRewritePattern : public OpRewritePattern<tosa::MatMulOp> {
  using OpRewritePattern::OpRewritePattern;

  /// Walk up the def chain from `val`, looking through collapse_shape,
  /// expand_shape, transpose and "broadcast adds" (see addBroadcast).
  /// Returns the first value not produced by one of those ops, or failure if
  /// a tosa.add on the chain is not a broadcast add.
  FailureOr<Value>
  getValueNonReshapeOpNonBroadcastNonTranspose(Value val) const {
    for (Operation *def = val.getDefiningOp(); def;
         def = val.getDefiningOp()) {
      if (isa<tosa::AddOp>(def)) {
        FailureOr<Value> source = addBroadcast(val);
        if (failed(source))
          return failure();
        val = *source;
      } else if (isa<tensor::CollapseShapeOp, tensor::ExpandShapeOp,
                     tosa::TransposeOp>(def)) {
        val = def->getOperand(0);
      } else {
        break;
      }
    }
    return val;
  }

  /// Look through any chain of collapse_shape / expand_shape ops above `val`
  /// and return the first value not produced by one of them.
  Value getValueNonReshapeOp(Value val) const {
    Operation *def = val.getDefiningOp();
    while (def && isa<tensor::CollapseShapeOp, tensor::ExpandShapeOp>(def)) {
      val = def->getOperand(0);
      def = val.getDefiningOp();
    }
    return val;
  }

  /// Look through collapse_shape, expand_shape and broadcast adds (see
  /// addBroadcast) above `val` and return the defining op if it is a
  /// `TosaOp`. Returns null when it is not, or when a tosa.add on the chain is
  /// not a broadcast add.
  template <typename TosaOp>
  TosaOp getDefiningNonReshapeOpNonBroadcast(Value val) const {
    for (;;) {
      Operation *def = val.getDefiningOp();
      if (!def)
        break;
      if (isa<tosa::AddOp>(def)) {
        FailureOr<Value> source = addBroadcast(val);
        if (failed(source))
          return nullptr;
        val = *source;
        continue;
      }
      if (!isa<tensor::CollapseShapeOp, tensor::ExpandShapeOp>(def))
        break;
      val = def->getOperand(0);
    }
    return val.getDefiningOp<TosaOp>();
  }

  /// Return the op of type `TosaOp` that defines `val` once any
  /// collapse_shape / expand_shape chain is looked through, or null.
  template <typename TosaOp>
  TosaOp getDefiningNonReshapeOp(Value val) const {
    Value source = getValueNonReshapeOp(val);
    return source.getDefiningOp<TosaOp>();
  }

  /// Return the op of type `TosaOp` that defines `val` once any chain of
  /// collapse_shape / expand_shape / tosa.cast ops is looked through, or null.
  template <typename TosaOp>
  TosaOp getDefiningNonReshapeOpNonCastOp(Value val) const {
    Operation *def = val.getDefiningOp();
    while (def && isa<tensor::CollapseShapeOp, tensor::ExpandShapeOp,
                      tosa::CastOp>(def)) {
      val = def->getOperand(0);
      def = val.getDefiningOp();
    }
    return val.getDefiningOp<TosaOp>();
  }

  /// Recognize a "broadcast add": a tosa.add (possibly behind reshapes) where
  /// one operand is a zero constant (tosa.const or arith.constant, possibly
  /// behind reshapes). Returns the other operand, or failure if `val` is not
  /// such an add. When both operands are zero constants, input1 is returned.
  FailureOr<Value> addBroadcast(Value val) const {
    auto add = getDefiningNonReshapeOp<tosa::AddOp>(val);
    if (!add)
      return failure();

    // True if `operand` comes from a tosa.const / arith.constant holding zero.
    auto isZeroConstant = [&](Value operand) -> bool {
      if (auto tosaConst = getDefiningNonReshapeOp<tosa::ConstOp>(operand))
        return isConstantZero(tosaConst.getResult());
      if (auto arithConst =
              getDefiningNonReshapeOp<arith::ConstantOp>(operand))
        return isConstantZero(arithConst.getResult());
      return false;
    };

    Value nonZero;
    if (isZeroConstant(add.getInput1()))
      nonZero = add.getInput2();
    if (isZeroConstant(add.getInput2()))
      nonZero = add.getInput1();

    if (!nonZero)
      return failure();
    return nonZero;
  }

  /// Check that `input` is a broadcast add (see addBroadcast) of a constant
  /// range (per isConstRange) whose shape is 1 in every dimension except
  /// possibly the one `nonOneDimFromEnd` positions from the end (0 = last).
  LogicalResult getConstComparison(TypedValue<TensorType> input,
                                   size_t nonOneDimFromEnd) const {
    // input is a constant with a range from 0 to maxSeqLen
    FailureOr<Value> maybeNonZero = addBroadcast(input);
    if (failed(maybeNonZero))
      return failure();

    // check that maybeNonZero is a const with range 0..maxSeqLen
    bool isRange = false;
    Value rangeResult;
    if (auto constRange =
            getDefiningNonReshapeOp<arith::ConstantOp>(maybeNonZero.value())) {
      rangeResult = constRange.getResult();
      isRange = isConstRange(rangeResult);
    } else if (auto constRange = getDefiningNonReshapeOp<tosa::ConstOp>(
                   maybeNonZero.value())) {
      rangeResult = constRange.getResult();
      isRange = isConstRange(rangeResult);
    }

    if (!isRange)
      return failure();

    auto shapedType = dyn_cast<ShapedType>(rangeResult.getType());
    if (!shapedType)
      return failure();

    ArrayRef<int64_t> shape = shapedType.getShape();
    // The constant comes from arbitrary input IR, so a rank too small to have
    // the requested dimension is a match failure, not an internal error.
    if (nonOneDimFromEnd >= shape.size())
      return failure();
    size_t couldBeDiffOne = shape.size() - nonOneDimFromEnd - 1;
    for (auto [i, dim] : llvm::enumerate(shape)) {
      if (i != couldBeDiffOne && dim != 1) {
        return failure();
      }
    }
    return success();
  }

  /// Match a causal mask rooted at `input` (looking through reshapes/casts):
  ///   select(greater(rangeKV, rangeQ), -inf, x)
  /// where rangeKV varies only along the last dim and rangeQ only along the
  /// second-to-last dim (see getConstComparison). Returns `x` on success.
  FailureOr<Value> getCausal(Value input) const {
    auto select = getDefiningNonReshapeOpNonCastOp<tosa::SelectOp>(input);
    if (!select)
      return failure();

    // The value chosen where the predicate holds must be a -inf constant.
    auto onTrue = select.getInput2();
    bool isNegInf = false;
    if (auto arithConst = getDefiningNonReshapeOp<arith::ConstantOp>(onTrue))
      isNegInf = isConstIsNegInf(arithConst.getResult());
    else if (auto tosaConst = getDefiningNonReshapeOp<tosa::ConstOp>(onTrue))
      isNegInf = isConstIsNegInf(tosaConst.getResult());
    if (!isNegInf)
      return failure();

    auto greater =
        getDefiningNonReshapeOpNonCastOp<tosa::GreaterOp>(select.getInput1());
    if (!greater)
      return failure();

    // input1 is a constant with a range from 0 to maxSeqLen (KV)
    if (failed(getConstComparison(greater.getInput1(), 0)))
      return failure();
    // input2 is a constant with a range from 0 to seqLenQ
    if (failed(getConstComparison(greater.getInput2(), 1)))
      return failure();

    Value maskedValue = select.getInput3();
    return maskedValue;
  }

  // For seqLen == 1 the LSE pattern
  //   log(sum(exp(sub(x, max(x))))) + max(x)
  // simplifies, since max(x) == x and the sum is over a single element, to
  //   log(exp(sub(x, x))) + x = sub(x, x) + x
  // Given `subOp` == sub(x, x), returns the result of a tosa.add user that
  // computes sub(x, x) + x (in either operand order), or null if there is none
  // or if `subOp` is not a self-subtraction.
  Value getLSESeqLen1(tosa::SubOp subOp) const {
    Value subInput = subOp.getInput1();
    if (subInput != subOp.getInput2()) {
      // this is a sub of two different values, we cannot match LSE
      return nullptr;
    }
    Value subResult = subOp->getResult(0);
    for (Operation *user : subOp->getUsers()) {
      auto addOp = dyn_cast<tosa::AddOp>(user);
      if (!addOp)
        continue;
      Value lhs = addOp.getInput1();
      Value rhs = addOp.getInput2();
      // Compare the values directly. Checking only whether lhs is produced by
      // *some* tosa.sub would skip the swapped order whenever x itself comes
      // from a tosa.sub, e.g. add(x, sub(x, x)).
      if ((lhs == subResult && rhs == subInput) ||
          (rhs == subResult && lhs == subInput))
        return addOp.getOutput();
    }
    return nullptr;
  }
  /**
   * Attempts to match and extract a Log-Sum-Exp (LSE) pattern from TOSA
   * operations.
   *
   * This function traverses the users of a reduce sum operation to identify a
   * complete LSE computation pattern, which typically consists of:
   * 1. A reduce_max operation to find the maximum values (in some cases this
   * might not exist)
   * 2. Subtraction of the max from original values (implicit in the pattern)
   * 3. Exponential and sum operations (represented by reduceSum)
   * 4. A logarithm operation on the result
   * 5. Addition of the original max values back
   *
   * Note that reduceSum and reduceMax are given. Casts and
   * collapse/expand_shape ops between reduceSum and the final add are
   * followed recursively; `logOp` is the log already seen on the current path
   * (nullptr if none yet).
   *
   * The LSE pattern: log(sum(exp(x - max(x)))) + max(x)
   *
   * Returns the output of the final add, or nullptr if no match is found.
   */
  Value getLSE(Operation *reduceSum, Operation *reduceMax,
               tosa::LogOp logOp = nullptr) const {
    for (auto *user : reduceSum->getUsers()) {
      if (auto op = dyn_cast<tosa::CastOp>(user)) {
        // a cast is only accepted before the log
        if (logOp != nullptr)
          return nullptr;
        if (Value val = getLSE(op, reduceMax))
          return val;
      } else if (auto op = dyn_cast<tosa::LogOp>(user)) {
        // we already found a log
        if (logOp != nullptr)
          return nullptr;
        if (Value val = getLSE(op, reduceMax, op))
          return val;
      } else if (auto addOp = dyn_cast<tosa::AddOp>(user)) {
        if (!logOp)
          continue;
        auto logOpFromAdd =
            getDefiningNonReshapeOp<tosa::LogOp>(addOp.getInput1());
        if (!logOpFromAdd)
          logOpFromAdd =
              getDefiningNonReshapeOp<tosa::LogOp>(addOp.getInput2());

        // must match the logOp
        if (logOp != logOpFromAdd)
          return nullptr;

        // ReduceMax could be gone if there's only one dim, then, we don't
        // know the previous op, because it could be anything we want to fuse
        auto *reduceMaxOpFromAdd =
            getDefiningNonReshapeOp<Operation *>(addOp.getInput1());
        if (!reduceMaxOpFromAdd || isa<tosa::LogOp>(reduceMaxOpFromAdd))
          reduceMaxOpFromAdd =
              getDefiningNonReshapeOp<Operation *>(addOp.getInput2());

        // The other operand may have no defining op (e.g. a block argument),
        // so the cast check must tolerate nullptr; plain dyn_cast asserts.
        if (auto castOp = dyn_cast_if_present<tosa::CastOp>(reduceMaxOpFromAdd)) {
          // if the reduceMax is a cast, we need to get the input of the cast
          reduceMaxOpFromAdd =
              getDefiningNonReshapeOp<Operation *>(castOp.getInput());
        }

        // must match the reduceMax
        if (!reduceMax || reduceMax != reduceMaxOpFromAdd)
          return nullptr;

        return addOp.getOutput();
      } else if (isa<tensor::CollapseShapeOp>(user) ||
                 isa<tensor::ExpandShapeOp>(user)) {
        if (Value val = getLSE(user, reduceMax, logOp))
          return val;
      }
    }
    return nullptr;
  }

  /// Recognizes a KV-cache mask of the form
  ///   select(greater(constRange, broadcast(currentSeqLen)), -inf, x)
  /// where `currentSeqLen` is an i32 tensor that (through reshapes,
  /// broadcasts and transposes) comes from a block argument.
  /// On success returns {x, currentSeqLen}; otherwise failure().
  FailureOr<std::pair<Value, Value>> getKVCache(Value softmaxInput) const {
    auto select =
        getDefiningNonReshapeOpNonCastOp<tosa::SelectOp>(softmaxInput);
    if (!select)
      return failure();

    // Check onTrue is -inf
    auto onTrue = select.getInput2();
    bool isConsNegInf = false;
    if (auto constOp = getDefiningNonReshapeOp<arith::ConstantOp>(onTrue))
      isConsNegInf = isConstIsNegInf(constOp.getResult());
    else if (auto constOp = getDefiningNonReshapeOp<tosa::ConstOp>(onTrue))
      isConsNegInf = isConstIsNegInf(constOp.getResult());
    if (!isConsNegInf)
      return failure();

    auto greater =
        getDefiningNonReshapeOpNonCastOp<tosa::GreaterOp>(select.getInput1());
    if (!greater)
      return failure();

    // input1 is a constant with a range from 0 to maxSeqLen
    if (failed(getConstComparison(greater.getInput1(), 0)))
      return failure();

    // input2 comes from argument: currentSeqLen
    FailureOr<Value> maybeNonZero2 = addBroadcast(greater.getInput2());
    if (failed(maybeNonZero2))
      return failure();

    // check that the right dimensions are broadcasted: every dimension after
    // the first two must be 1. Unranked types are rejected up front because
    // getShape() asserts on them.
    auto beforeBroadcastShape = dyn_cast<ShapedType>(maybeNonZero2->getType());
    if (!beforeBroadcastShape || !beforeBroadcastShape.hasRank())
      return failure();
    // Dimensions are int64_t; a narrower lambda parameter would truncate.
    if (beforeBroadcastShape.getRank() > 2 &&
        !llvm::all_of(beforeBroadcastShape.getShape().drop_front(2),
                      [](int64_t dim) { return dim == 1; }))
      return failure();

    Value currentSeqLen = getValueNonReshapeOp(maybeNonZero2.value());
    Value result = select.getInput3();

    // currentSeqLen must be of i32 type
    auto currentSeqLenShape = dyn_cast<ShapedType>(currentSeqLen.getType());
    if (!currentSeqLenShape ||
        !currentSeqLenShape.getElementType().isInteger(32))
      return failure();

    // we'll check now if currentSeqLen comes from a block argument
    FailureOr<Value> mustBeBlockArg =
        getValueNonReshapeOpNonBroadcastNonTranspose(currentSeqLen);
    if (failed(mustBeBlockArg) || !isa<BlockArgument>(mustBeBlockArg.value()))
      return failure();

    return std::make_pair(result, currentSeqLen);
  }

  /*
  Returns true when `toVal` is reachable from `fromVal` by following use
  edges (value -> using op -> each result of that op). `pathMap` memoizes the
  answer per visited value; an entry is seeded with false before its users are
  explored, so a value is never expanded twice.
  */
  bool areConnected(Value fromVal, Value toVal,
                    llvm::SmallDenseMap<Value, bool> &pathMap) const {
    if (fromVal == toVal)
      return pathMap[fromVal] = true;

    auto [slot, freshlyInserted] = pathMap.try_emplace(fromVal, false);
    if (!freshlyInserted)
      return slot->second;

    // `slot` is not reused below: recursion may grow (and rehash) pathMap.
    for (Operation *consumer : fromVal.getUsers())
      for (Value produced : consumer->getResults())
        if (areConnected(produced, toVal, pathMap))
          pathMap[fromVal] = true;
    return pathMap[fromVal];
  }

  /*
  When softmax runs in a different precision than the first GEMM, the GEMM
  output reaches the softmax input through a tosa.cast. Walk every op
  transitively using `firstGemmOutput`; for each cast whose result element
  type equals the softmax input element type and whose result can reach
  `softmaxInput`, report that element type. match() then compares it with the
  cast on the softmax output to make sure the casts only change the softmax
  precision and are not part of the fusion. Returns failure() if no such cast
  exists.
  */
  FailureOr<Type> getSoftmaxType(Value firstGemmOutput,
                                 Value softmaxInput) const {
    Type wantedElemType =
        cast<ShapedType>(softmaxInput.getType()).getElementType();
    llvm::SmallDenseSet<Operation *> seen;
    llvm::SmallVector<Operation *> pending(firstGemmOutput.getUsers().begin(),
                                           firstGemmOutput.getUsers().end());
    while (!pending.empty()) {
      Operation *current = pending.pop_back_val();
      if (!seen.insert(current).second)
        continue;
      if (isa<tosa::CastOp>(current)) {
        Value castResult = current->getResult(0);
        Type castElemType =
            cast<ShapedType>(castResult.getType()).getElementType();
        llvm::SmallDenseMap<Value, bool> reachability;
        // Any qualifying cast yields the same answer (wantedElemType), so the
        // first one found is sufficient.
        if (areConnected(castResult, softmaxInput, reachability) &&
            castElemType == wantedElemType)
          return castElemType;
      }
      pending.insert(pending.end(), current->getUsers().begin(),
                     current->getUsers().end());
    }
    return failure();
  }

  /// Matches the softmax numerator exp(x - reduce_max(x)) rooted at `val`
  /// (reshapes, and broadcasts on the max operand, are looked through).
  /// When no reduce_max is present the pattern must be the seq_len == 1
  /// degenerate form exp(x - x). `rsum` is stored unchanged in the result.
  FailureOr<SoftmaxMatcherValues> maybeSoftmaxNumerator(Value val,
                                                        Operation *rsum) const {
    auto expOp = getDefiningNonReshapeOp<tosa::ExpOp>(val);
    if (!expOp)
      return failure();

    auto subOp = getDefiningNonReshapeOp<tosa::SubOp>(expOp.getInput1());
    if (!subOp)
      return failure();

    Value minuend = subOp.getInput1();
    auto maxOp =
        getDefiningNonReshapeOpNonBroadcast<tosa::ReduceMaxOp>(subOp.getInput2());
    bool sawReduce = static_cast<bool>(maxOp);

    // With reduce_max the subtrahend must be max(minuend). Without it
    // (seq_len == 1: both reductions were const-folded away) the subtrahend
    // must be the minuend itself.
    Value expected;
    if (sawReduce)
      expected = maxOp.getInput();
    else
      expected = subOp.getInput2();
    if (expected != minuend)
      return failure();

    return SoftmaxMatcherValues{minuend, subOp, expOp, maxOp, rsum, sawReduce};
  }

  /// Matches the softmax denominator rooted at `val`. It is either
  /// reduce_sum(numerator) or, in the seq_len == 1 case, the numerator itself.
  /// The presence of a reduction must agree between denominator (reduce_sum)
  /// and numerator (reduce_max); otherwise failure() is returned.
  FailureOr<SoftmaxMatcherValues> maybeSoftmaxDenominator(Value val) const {
    auto sumOp = getDefiningNonReshapeOpNonBroadcast<tosa::ReduceSumOp>(val);
    bool denominatorReduces = static_cast<bool>(sumOp);

    FailureOr<SoftmaxMatcherValues> numerator =
        denominatorReduces ? maybeSoftmaxNumerator(sumOp.getInput(), sumOp)
                           : maybeSoftmaxNumerator(val, val.getDefiningOp());

    // Reduce ops must appear on both sides of the division or on neither.
    if (succeeded(numerator) && numerator->hasReduceOp != denominatorReduces)
      return failure();
    return numerator;
  }

  /// Matches softmax written as numerator * reciprocal(denominator) rooted at
  /// `val`. The reciprocal may be either operand of the multiply; the first
  /// operand is tried first.
  FailureOr<SoftmaxMatcherValues> maybeSoftmax(Value val) const {
    auto mulOp = getDefiningNonReshapeOp<tosa::MulOp>(val);
    if (!mulOp)
      return failure();

    SmallVector<Value, 2> factors = {mulOp.getInput1(), mulOp.getInput2()};
    for (Value factor : factors) {
      auto recipOp =
          getDefiningNonReshapeOpNonBroadcast<tosa::ReciprocalOp>(factor);
      if (recipOp)
        return maybeSoftmaxDenominator(recipOp.getInput1());
    }
    return failure();
  }

  /// Reshapes `inputTensor` of shape [d0, ..., dN-3, M, K] into the rank-3
  /// shape [d0 * ... * dN-3, M, K] with a tosa.reshape. A null input is
  /// returned unchanged. Requires a ranked tensor of rank >= 2.
  Value normalizeInputTensor(PatternRewriter &rewriter, Location loc,
                             TypedValue<TensorType> inputTensor) const {
    if (!inputTensor) {
      return inputTensor;
    }
    assert(inputTensor.getType().hasRank() &&
           inputTensor.getType().getRank() >= 2 &&
           "expected a ranked tensor of rank >= 2");
    ArrayRef<int64_t> shape = inputTensor.getType().getShape();

    // Dimensions are int64_t; accumulate in int64_t so large batch products
    // are not narrowed to int.
    int64_t collapsedBatchLen = 1;
    for (int64_t dimLen : shape.drop_back(2))
      collapsedBatchLen *= dimLen;

    SmallVector<int64_t, 4> normalizedShape = {
        collapsedBatchLen, shape[shape.size() - 2], shape.back()};
    auto normalizedType = RankedTensorType::get(
        normalizedShape, inputTensor.getType().getElementType());
    auto normalizedShapeValue =
        tosa::getTosaConstShape(rewriter, loc, normalizedShape);
    auto reshapeOp = tosa::ReshapeOp::create(rewriter, loc, normalizedType,
                                             inputTensor, normalizedShapeValue);
    return reshapeOp;
  }

  /// Moves every op that transitively uses `addOp` (excluding func.return)
  /// to just after `expandedOutLse`, keeping their relative order, so they
  /// come after the value that will replace `addOp`.
  ///
  /// NOTE(review): the sort uses isBeforeInBlock, which assumes all collected
  /// users live in the same block — confirm no user sits in a nested region.
  void moveUsersAfterExpandShape(PatternRewriter &rewriter, Location loc,
                                 Operation *expandedOutLse,
                                 tosa::AddOp addOp) const {
    llvm::SmallVector<Operation *> toMove;
    llvm::SmallDenseSet<Operation *> visited;
    llvm::SmallVector<Operation *> worklist;

    auto enqueueUsers = [&](Operation *producer) {
      for (Operation *user : producer->getUsers()) {
        if (!isa<func::ReturnOp>(user))
          worklist.push_back(user);
      }
    };

    // Seed the worklist with direct users, then collect all transitive users
    // (LIFO worklist traversal; order is fixed up by the sort below).
    enqueueUsers(addOp.getOperation());
    while (!worklist.empty()) {
      Operation *op = worklist.pop_back_val();
      if (!visited.insert(op).second)
        continue;
      toMove.push_back(op);
      enqueueUsers(op);
    }

    // Sort by IR order
    llvm::sort(toMove, [](Operation *a, Operation *b) {
      return a->isBeforeInBlock(b);
    });

    // Move in reverse order so the final sequence after expandedOutLse keeps
    // the original relative order. Go through the rewriter so the pattern
    // driver/listener is notified of every IR modification.
    for (Operation *op : llvm::reverse(toMove))
      rewriter.moveOpAfter(op, expandedOutLse);
  }

  /// Tries to recognize `op` as the second matmul of an attention block,
  ///   matmul(softmax(elementwise(firstMatMul)), op.B)
  /// optionally with: a cast around the softmax (different softmax
  /// precision), an LSE output, a KV-cache mask and a causal mask.
  /// Returns the collected values for rewrite(), or failure().
  FailureOr<AttentionMatcherValues> match(tosa::MatMulOp op) const {
    Value softmaxOutput = op.getA();

    // check if the softmax is done in different precision compared to GEMMs
    Type softmaxType =
        cast<ShapedType>(softmaxOutput.getType()).getElementType();
    auto softmaxOutputCastOp =
        getDefiningNonReshapeOp<tosa::CastOp>(softmaxOutput);
    if (softmaxOutputCastOp) {
      softmaxOutput = softmaxOutputCastOp.getInput();
      if (getDefiningNonReshapeOp<tosa::CastOp>(softmaxOutput)) {
        LLVM_DEBUG(llvm::dbgs()
                   << "softmax output has multiple casts. rocMLIR only allows "
                      "one cast between softmax and gemm2\n");
        return failure();
      }
      softmaxType = cast<ShapedType>(softmaxOutput.getType()).getElementType();
    }

    // pattern match for softmax operation
    FailureOr<SoftmaxMatcherValues> softmaxMatcherResults =
        maybeSoftmax(softmaxOutput);

    if (failed(softmaxMatcherResults))
      return failure();
    SoftmaxMatcherValues softmaxMatcherValues = softmaxMatcherResults.value();

    Value softmaxInput = softmaxMatcherValues.softmaxInput;
    bool hasReduceOp = softmaxMatcherValues.hasReduceOp;
    Operation *sub = softmaxMatcherValues.subOp;
    Operation *rmax = softmaxMatcherValues.reduceMaxOp;
    Operation *rsum = softmaxMatcherValues.reduceSumOp;
    Value lse;
    if (hasReduceOp) {
      lse = getLSE(rsum, rmax);
    } else {
      // if there is no reduce op, then we have seq_len=1 and lse is just
      // sub(x, x) + x
      lse = getLSESeqLen1(cast<tosa::SubOp>(sub));
    }
    // lse has three or four dimensions
    if (lse) {
      auto type = cast<ShapedType>(lse.getType());
      if (type.getRank() != 4 && type.getRank() != 3)
        return failure();
      // last dimension must be 1: {B, NUM_HEADS, SEQ_LEN_Q, 1}
      if (type.getRank() == 4 && type.getDimSize(type.getRank() - 1) != 1)
        return failure();
    }

    // Note that non KV-Cache fusions might have tosa.select
    // so, if the checks for kv-cache fail, we just keep going
    Value kvCacheInput, currentSeqLen;
    auto maybeKVCache = getKVCache(softmaxInput);
    if (succeeded(maybeKVCache))
      std::tie(kvCacheInput, currentSeqLen) = maybeKVCache.value();
    else
      kvCacheInput = softmaxInput;

    // currentSeqLen needs one or two dimensions
    if (currentSeqLen &&
        cast<ShapedType>(currentSeqLen.getType()).getRank() > 2)
      return failure();

    auto causal = getCausal(kvCacheInput);
    bool isCausal = succeeded(causal);
    Value causalMaskInput = isCausal ? causal.value() : kvCacheInput;

    // Walk back from the (unmasked) softmax input through elementwise ops
    // to find the first matmul.
    ElementwiseRegionFinder<tosa::MatMulOp> preSoftmaxElementwiseFinder;
    preSoftmaxElementwiseFinder.visit(causalMaskInput);
    FailureOr<tosa::MatMulOp> maybeFirstMatMul =
        preSoftmaxElementwiseFinder.getFirstGemmBasedOp();
    if (failed(maybeFirstMatMul)) {
      LLVM_DEBUG(llvm::dbgs() << "first matmul not found\n");
      return failure();
    }

    // The first matmul is a "dot product" when its two trailing output
    // dimensions are both 1.
    TypedValue<TensorType> matC = maybeFirstMatMul.value().getOutput();
    ArrayRef<int64_t> shapeC = matC.getType().getShape();
    bool isDotProduct = *(std::prev(shapeC.end(), 1)) == 1;
    isDotProduct &= *(std::prev(shapeC.end(), 2)) == 1;

    LLVM_DEBUG(llvm::dbgs()
               << "first matmul = " << maybeFirstMatMul.value() << "\n");
    LLVM_DEBUG(llvm::dbgs() << "hasReduceOp = " << hasReduceOp << "\n");
    LLVM_DEBUG(llvm::dbgs() << "isCausal = " << isCausal << "\n");
    // Reductions are expected exactly when the first matmul is not a dot
    // product (the seq_len == 1 form folds the reductions away).
    if (isDotProduct && hasReduceOp)
      return failure();
    if (!isDotProduct && !hasReduceOp)
      return failure();

    // if softmax is done in different precision than GEMMs then there must be
    // cast operation on one of the uses of first GEMM
    if (softmaxOutputCastOp) {
      FailureOr<Type> softmaxInputCast = getSoftmaxType(matC, softmaxInput);
      if (failed(softmaxInputCast)) {
        LLVM_DEBUG(llvm::dbgs() << "softmax input cast not found\n");
        return failure();
      }
      if (softmaxInputCast.value() != softmaxType) {
        LLVM_DEBUG(
            llvm::dbgs()
            << "softmax type on input cast and output cast does not match\n");
        return failure();
      }
    }

    // populate struct to aggregate attention matcher values and pass it to
    // rewriter
    AttentionMatcherValues attentionMatcherValues;
    attentionMatcherValues.isCausal = isCausal;
    attentionMatcherValues.softmaxType = softmaxType;
    attentionMatcherValues.softmaxValues = softmaxMatcherValues;
    attentionMatcherValues.lse = lse;
    attentionMatcherValues.causalMaskInput = causalMaskInput;
    attentionMatcherValues.currentSeqLen = currentSeqLen;
    attentionMatcherValues.preSoftmaxElementwiseFinder =
        preSoftmaxElementwiseFinder;
    return attentionMatcherValues;
  }

  /// Replaces the second matmul `op` with a rock.attention op built from the
  /// values collected by match(). If an LSE value was matched, it is collapsed
  /// to rank 2 for rock.attention, produced as the attention op's second
  /// result, expanded back to its original shape, and substituted for the
  /// tosa.add that originally computed it.
  void rewrite(tosa::MatMulOp op,
               const AttentionMatcherValues &attentionMatcherValues,
               PatternRewriter &rewriter) const {
    Location loc = op.getLoc();
    auto outputType = cast<RankedTensorType>(op.getType());
    // Destination tensor for the attention output.
    Value output = bufferization::AllocTensorOp::create(
        rewriter, loc, outputType, ValueRange{});
    // Stays a null type when there is no LSE output.
    RankedTensorType lseType;
    Value lse = attentionMatcherValues.lse;
    Value lseOut, lseOrig;
    SmallVector<ReassociationIndices> reassocIndicesLSE = {{0, 1}, {2, 3}};
    if (lse) {
      // {{0, 1}, {2, 3}} for 4D tensor, {{0}, {1, 2}} for 3D tensor
      if (cast<ShapedType>(lse.getType()).getRank() == 3)
        reassocIndicesLSE = {{0}, {1, 2}};

      // rock.attention expects lse to have the shape = {B, SEQ_LEN_Q}
      lseOrig = lse;
      lse = tensor::CollapseShapeOp::create(rewriter, op.getLoc(), lse,
                                            reassocIndicesLSE);

      lseType = cast<RankedTensorType>(lse.getType());
      // Destination tensor for the collapsed LSE result.
      lseOut = bufferization::AllocTensorOp::create(rewriter, loc, lseType,
                                                    ValueRange{});
    }
    ElementwiseRegionFinder<tosa::MatMulOp> preSoftmaxElementwiseFinder =
        attentionMatcherValues.preSoftmaxElementwiseFinder;
    // Extra (non-GEMM) operands of the pre-softmax elementwise region; they
    // become operands of rock.attention.
    SmallVector<Value> elementwiseOtherArgs =
        preSoftmaxElementwiseFinder.getElementwiseArgs();
    // causalMaskInput would be equal to kvCacheInput if there is no causal
    // mask and kvCacheInput would be same as softmaxInput if there is no
    // kv-cache. see match() for details
    Value causalMaskInput = attentionMatcherValues.causalMaskInput;
    tosa::MatMulOp firstMatMulOp =
        preSoftmaxElementwiseFinder.getFirstGemmBasedOp().value();
    Value currentSeqLen = attentionMatcherValues.currentSeqLen;
    bool isCausal = attentionMatcherValues.isCausal;
    TypeAttr softmaxTypeAttr =
        TypeAttr::get(attentionMatcherValues.softmaxType);

    // Reshape currentSeqLen {batch, numHeads} -> {batch * numHeads}
    if (currentSeqLen &&
        cast<ShapedType>(currentSeqLen.getType()).getRank() == 2) {
      SmallVector<ReassociationIndices> reassocIndices = {{0, 1}};
      currentSeqLen = tensor::CollapseShapeOp::create(
          rewriter, op.getLoc(), currentSeqLen, reassocIndices);
    }
    // A null attribute leaves the optional causal flag unset.
    UnitAttr causalAttr = isCausal ? rewriter.getUnitAttr() : nullptr;
    ElementwiseRegionFinder<tosa::MatMulOp> elemwiseRegion =
        attentionMatcherValues.preSoftmaxElementwiseFinder;
    int64_t firstGemmBlockIndex = elemwiseRegion.getFirstGemmBlockIndex();

    // Operands: first matmul's A and B, this matmul's B, the elementwise
    // extra args, currentSeqLen (may be null), and the output/LSE buffers.
    // TODO: numHeadsQ and numHeadsKV migraphx integration
    rock::AttentionOp attnOp = rock::AttentionOp::create(
        rewriter, loc, outputType, lseType, firstMatMulOp.getA(),
        firstMatMulOp.getB(), op.getB(), elementwiseOtherArgs, currentSeqLen,
        output, lseOut,
        /*numHeadsQ=*/rewriter.getI32IntegerAttr(1),
        /*numHeadsKV=*/rewriter.getI32IntegerAttr(1),
        /*qTransposed=*/nullptr,
        /*kTransposed=*/nullptr,
        /*vTransposed=*/nullptr,
        /*oTransposed=*/nullptr, causalAttr,
        /*splitKV=*/rewriter.getI32IntegerAttr(1),
        /*features=*/nullptr,
        rewriter.getAttr<rock::StoreMethodAttr>(rock::StoreMethod::Set),
        softmaxTypeAttr,
        /*params0=*/nullptr, /*params1=*/nullptr,
        /*firstGemmIndices=*/
        rewriter.getDenseI64ArrayAttr(firstGemmBlockIndex));
    // Rebuild the matched pre-softmax elementwise ops inside the attention
    // op's preSoftmaxBody region.
    Block *preSoftmaxElemwiseBlock = &attnOp.getPreSoftmaxBody().emplaceBlock();
    {
      PatternRewriter::InsertionGuard guard(rewriter);
      rewriter.setInsertionPointToStart(preSoftmaxElemwiseBlock);
      elemwiseRegion.rewrite(causalMaskInput, rewriter, preSoftmaxElemwiseBlock,
                             loc);
    }
    tosa::AddOp addOp;
    Value expandedOutLse;
    if (lse) {
      // Reverse the collapse operation
      expandedOutLse = tensor::ExpandShapeOp::create(
          rewriter, op.getLoc(), lseOrig.getType(), attnOp->getResult(1),
          reassocIndicesLSE);

      // collecting AddOp before the first replace
      addOp = lseOrig.getDefiningOp<tosa::AddOp>();

      // all users have to be moved after the expand shape, so they still
      // follow their operand once addOp is replaced by expandedOutLse below
      moveUsersAfterExpandShape(rewriter, op.getLoc(),
                                expandedOutLse.getDefiningOp(), addOp);
    }
    // Carry over a user-provided tuning configuration, if any.
    if (auto attr = op->getAttrOfType<StringAttr>("perf_config"))
      attnOp->setAttr("perf_config", attr);

    rewriter.replaceOp(op, attnOp->getResult(0));
    if (lse) {
      rewriter.replaceOp(addOp, expandedOutLse);
    }
  }

  /// Entry point: runs match() and, only if it succeeds, rewrite().
  LogicalResult matchAndRewrite(tosa::MatMulOp op,
                                PatternRewriter &rewriter) const override {
    FailureOr<AttentionMatcherValues> matched = match(op);
    if (succeeded(matched)) {
      rewrite(op, *matched, rewriter);
      return success();
    }
    return failure();
  }
};

template <typename TosaReduceOp>
typename std::enable_if_t<
    std::is_same<TosaReduceOp, tosa::ReduceSumOp>::value ||
        std::is_same<TosaReduceOp, tosa::ReduceMaxOp>::value,
    LogicalResult> static matchAndRewriteReductions(TosaReduceOp op,
                                                    rock::ReduceMethod rMethod,
                                                    Attribute outputInitVal,
                                                    ConversionPatternRewriter
                                                        &rw) {
  Location loc = op->getLoc();
  auto outputType = cast<RankedTensorType>(op.getType());
  Value output =
      bufferization::AllocTensorOp::create(rw, loc, outputType, ValueRange{});

  int32_t blockSize = 256;
  auto elementCount =
      cast<ShapedType>(op.getInput().getType()).getNumElements();
  int32_t gridSize = (elementCount + blockSize - 1) / blockSize;
  auto numCU = rock::getNumCU(op);
  if (succeeded(numCU)) {
    gridSize = std::min((int32_t)(20 * numCU.value()), gridSize);
  }

  auto rockReduce = rock::ReduceOp::create(
      rw, loc, outputType, op.getInput(), output,
      rw.getAttr<rock::ReduceMethodAttr>(rMethod),
      rw.getIndexAttr(op.getAxis()), rw.getI32IntegerAttr(blockSize),
      rw.getI32IntegerAttr(gridSize),
      /*useLDS=*/nullptr,
      /*useDPP=*/nullptr);

  func::FuncOp func = op->template getParentOfType<func::FuncOp>();
  SetVector<int64_t> resIndices = traceToRes(op.getOutput(), func);
  if (resIndices.empty())
    return op.emitOpError(
        "can't trace the reduction output to a kernel result");

  for (int64_t resNumber : resIndices) {
    func.setResultAttr(resNumber, rock::PrefillAttr::getMnemonic(),
                       outputInitVal);
    func.setResultAttr(resNumber, "mhal.read_access", rw.getUnitAttr());
    // The original function also need the read access attr for the output.
    if (func->hasAttr("original_func")) {
      if (ModuleOp rootMod =
              func->getParentOfType<ModuleOp>()->getParentOfType<ModuleOp>()) {
        SymbolTable symTable(rootMod);
        SymbolRefAttr originalFuncAttr =
            func->getAttrOfType<SymbolRefAttr>("original_func");
        if (func::FuncOp originalFunc = dyn_cast<func::FuncOp>(
                symTable.lookupSymbolIn(rootMod, originalFuncAttr))) {
          originalFunc.setResultAttr(resNumber, "mhal.read_access",
                                     rw.getUnitAttr());
        }
      }
    }
  }
  rw.replaceOp(op, rockReduce.getResult());
  return success();
}

/// Lowers tosa.reduce_sum on f32/f16/bf16 inputs to rock.reduce (Sum), using
/// 0 as the output initial value.
class ReduceSumConverter final : public OpConversionPattern<tosa::ReduceSumOp> {
public:
  using OpConversionPattern<tosa::ReduceSumOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(tosa::ReduceSumOp op,
                                tosa::ReduceSumOp::Adaptor adaptor,
                                ConversionPatternRewriter &rw) const final {
    Type elemTy = cast<ShapedType>(op.getInput().getType()).getElementType();
    bool isSupportedFloat = isa<Float32Type, Float16Type, BFloat16Type>(elemTy);
    if (!isSupportedFloat)
      return rw.notifyMatchFailure(
          op, "We only support F32, F16 and BF16 reductions, yet.");

    // Zero is the identity element of addition.
    Attribute zeroInit = rw.getFloatAttr(elemTy, 0.0);
    return matchAndRewriteReductions(op, rock::ReduceMethod::Sum, zeroInit, rw);
  }
};

/// Lowers tosa.reduce_max on f32 inputs to rock.reduce (Max), using -inf as
/// the output initial value.
class ReduceMaxConverter final : public OpConversionPattern<tosa::ReduceMaxOp> {
public:
  using OpConversionPattern<tosa::ReduceMaxOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(tosa::ReduceMaxOp op,
                                tosa::ReduceMaxOp::Adaptor adaptor,
                                ConversionPatternRewriter &rw) const final {
    Type elemTy = cast<ShapedType>(op.getInput().getType()).getElementType();
    if (!elemTy.isF32())
      return rw.notifyMatchFailure(op, "We only support F32 reductions, yet.");

    // -inf is the identity element of max.
    APFloat negInf =
        APFloat::getInf(APFloat::IEEEsingle(), /*Negative=*/true);
    Attribute negInfInit = rw.getFloatAttr(elemTy, negInf);
    return matchAndRewriteReductions(op, rock::ReduceMethod::Max, negInfInit,
                                     rw);
  }
};

// Rewrites add(x, splat 0) / add(splat 0, x) — an add that exists only to
// broadcast `x` implicitly to the output shape — into an explicit broadcast
// of `x` (via insertBroadcast). Adds whose operands are both, or neither,
// splat zeros are left untouched.
class AddSplatZeroRewritePattern final : public OpRewritePattern<tosa::AddOp> {
public:
  using OpRewritePattern<tosa::AddOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(tosa::AddOp op,
                                PatternRewriter &rw) const final {
    TypedValue<TensorType> lhs = op.getInput1();
    TypedValue<TensorType> rhs = op.getInput2();
    bool lhsIsZero = isConstantZero(lhs);
    bool rhsIsZero = isConstantZero(rhs);

    if (lhsIsZero && rhsIsZero)
      return rw.notifyMatchFailure(op, "both inputs are splat zeros");
    if (!lhsIsZero && !rhsIsZero)
      return rw.notifyMatchFailure(op, "none of the inputs are splat zeros");

    // Exactly one side is zero: keep the other one, broadcast to the result.
    TypedValue<TensorType> kept = lhsIsZero ? rhs : lhs;
    ArrayRef<int64_t> resultShape = op.getOutput().getType().getShape();
    Value broadcasted = insertBroadcast(kept, resultShape, op.getLoc(), rw);
    rw.replaceOp(op, broadcasted);
    return success();
  }
};

} // namespace

/// Registers the core TOSA -> Rock conversion patterns: convolutions, matmul
/// and reductions.
void tosa::populateTosaToRockConversionPatterns(MLIRContext *context,
                                                RewritePatternSet &patterns) {
  // Convolution variants (2-D, 3-D, transposed 2-D).
  patterns.add<ConvConverter<tosa::Conv2DOp>, ConvConverter<tosa::Conv3DOp>,
               ConvConverter<tosa::TransposeConv2DOp>>(context);
  // Plain GEMM and the sum/max reductions.
  patterns.add<MatMulConverter, ReduceSumConverter, ReduceMaxConverter>(
      context);
}

/// Registers the pattern that fuses a matmul -> softmax -> matmul chain into
/// rock.attention.
void tosa::populateTosaToRockAttentionConversionPatterns(
    MLIRContext *context, RewritePatternSet &patterns) {
  patterns.add<AttentionRewritePattern>(context);
}

/// Registers GemmElementwiseGemmRewritePattern (GEMM, elementwise ops, GEMM
/// fusion) into `patterns`.
void tosa::populateTosaToRockGemmGemmConversionPatterns(
    MLIRContext *context, RewritePatternSet &patterns) {
  patterns.add<GemmElementwiseGemmRewritePattern>(context);
}

/// Registers ConvElementwiseGemmRewritePattern (convolution, elementwise ops,
/// GEMM fusion) into `patterns`.
void tosa::populateTosaToRockConvGemmConversionPatterns(
    MLIRContext *context, RewritePatternSet &patterns) {
  patterns.add<ConvElementwiseGemmRewritePattern>(context);
}

/// Registers tensor-level rewrites: transpose, collapse/expand shape, and the
/// splat-zero add -> broadcast rewrite.
void tosa::populateTosaToRockTensorConversionPatterns(
    MLIRContext *context, RewritePatternSet &patterns) {
  // Shape/layout rewrites.
  patterns.add<TransposeRewritePattern, CollapseExpandRewritePattern>(context);
  // add(x, splat 0) used purely as an implicit broadcast.
  patterns.add<AddSplatZeroRewritePattern>(context);
}
