Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/llvm/llvm-project.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/mlir
diff options
context:
space:
mode:
authorAart Bik <ajcbik@google.com>2022-03-09 00:24:45 +0300
committerAart Bik <ajcbik@google.com>2022-03-09 04:25:36 +0300
commit53cc3a06378229f5b4713f0db39135e846609d0a (patch)
tree88a0e71b283c79eab5df60a392d5296ac72cc0b6 /mlir
parenta53967cd553cd59452a48aa8651014cd8ed0342e (diff)
[mlir][sparse] index support in sparse compiler codegen
This revision adds support for the linalg.index to the sparse compiler pipeline. In essence, this adds the ability to refer to indices in the tensor index expression, as illustrated below: Y[i, j, k, l, m] = T[i, j, k, l, m] * i * j Reviewed By: bixia Differential Revision: https://reviews.llvm.org/D121251
Diffstat (limited to 'mlir')
-rw-r--r--mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h5
-rw-r--r--mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp29
-rw-r--r--mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp44
-rw-r--r--mlir/test/Dialect/SparseTensor/sparse_index.mlir128
-rw-r--r--mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_index.mlir81
5 files changed, 274 insertions, 13 deletions
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
index 304ba93737b5..3ecd584c8fe8 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
@@ -28,6 +28,7 @@ enum Kind {
// Leaf.
kTensor = 0,
kInvariant,
+ kIndex,
// Unary operations.
kAbsF,
kCeilF,
@@ -42,6 +43,7 @@ enum Kind {
kCastUF, // unsigned
kCastS, // signed
kCastU, // unsigned
+ kCastIdx,
kTruncI,
kBitCast,
// Binary operations.
@@ -79,6 +81,9 @@ struct TensorExp {
/// Expressions representing tensors simply have a tensor number.
unsigned tensor;
+ /// Indices hold the index number.
+ unsigned index;
+
/// Tensor operations hold the indices of their children.
Children children;
};
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index 6945e54a73a0..e4faddd9b7e8 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -11,6 +11,7 @@
//===----------------------------------------------------------------------===//
#include "CodegenUtils.h"
+
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
@@ -870,6 +871,13 @@ static Value genAddress(CodeGen &codegen, PatternRewriter &rewriter,
return rewriter.create<arith::AddIOp>(loc, mul, i);
}
+/// Generates an index value.
+static Value genIndexValue(Merger &merger, CodeGen &codegen, unsigned exp) {
+ assert(codegen.curVecLength == 1); // TODO: implement vectorization!
+ unsigned idx = merger.exp(exp).index;
+ return codegen.loops[idx];
+}
+
/// Recursively generates tensor expression.
static Value genExp(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
linalg::GenericOp op, unsigned exp) {
@@ -880,6 +888,8 @@ static Value genExp(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
return genTensorLoad(merger, codegen, rewriter, op, exp);
if (merger.exp(exp).kind == Kind::kInvariant)
return genInvariantValue(merger, codegen, rewriter, exp);
+ if (merger.exp(exp).kind == Kind::kIndex)
+ return genIndexValue(merger, codegen, exp);
Value v0 = genExp(merger, codegen, rewriter, op, merger.exp(exp).children.e0);
Value v1 = genExp(merger, codegen, rewriter, op, merger.exp(exp).children.e1);
return merger.buildExp(rewriter, loc, exp, v0, v1);
@@ -947,7 +957,8 @@ static void genInvariants(Merger &merger, CodeGen &codegen,
merger.exp(exp).val =
atStart ? genTensorLoad(merger, codegen, rewriter, op, exp) : Value();
}
- } else if (merger.exp(exp).kind != Kind::kInvariant) {
+ } else if (merger.exp(exp).kind != Kind::kInvariant &&
+ merger.exp(exp).kind != Kind::kIndex) {
// Traverse into the binary operations. Note that we only hoist
// tensor loads, since subsequent MLIR/LLVM passes know how to
// deal with all other kinds of derived loop invariants.
@@ -1039,7 +1050,12 @@ static bool genInit(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
/// Returns vectorization strategy. Any implicit inner loop in the Linalg
/// operation is a candidate. Whether it is actually converted to SIMD code
/// depends on the requested strategy.
-static bool isVectorFor(CodeGen &codegen, bool isInner, bool isSparse) {
+static bool isVectorFor(CodeGen &codegen, bool isInner, bool isReduction,
+ bool isSparse) {
+ // Reject vectorization of sparse output, unless innermost is reduction.
+ if (codegen.sparseOut && !isReduction)
+ return false;
+ // Inspect strategy.
switch (codegen.options.vectorizationStrategy) {
case SparseVectorizationStrategy::kNone:
return false;
@@ -1056,6 +1072,10 @@ static bool isVectorFor(CodeGen &codegen, bool isInner, bool isSparse) {
/// to a parallel operation depends on the requested strategy.
static bool isParallelFor(CodeGen &codegen, bool isOuter, bool isReduction,
bool isSparse, bool isVector) {
+ // Reject parallelization of sparse output.
+ if (codegen.sparseOut)
+ return false;
+ // Inspect strategy.
switch (codegen.options.parallelizationStrategy) {
case SparseParallelizationStrategy::kNone:
return false;
@@ -1107,11 +1127,9 @@ static Operation *genFor(Merger &merger, CodeGen &codegen,
auto iteratorTypes = op.iterator_types().getValue();
bool isReduction = isReductionIterator(iteratorTypes[idx]);
bool isSparse = merger.isDim(fb, Dim::kSparse);
- bool isVector = !codegen.sparseOut &&
- isVectorFor(codegen, isInner, isSparse) &&
+ bool isVector = isVectorFor(codegen, isInner, isReduction, isSparse) &&
denseUnitStrides(merger, op, idx);
bool isParallel =
- !codegen.sparseOut &&
isParallelFor(codegen, isOuter, isReduction, isSparse, isVector);
// Prepare vector length.
@@ -1626,6 +1644,7 @@ public:
LogicalResult matchAndRewrite(linalg::GenericOp op,
PatternRewriter &rewriter) const override {
+
// Detects sparse annotations and translate the per-dimension sparsity
// information for all tensors to loop indices in the kernel.
assert(op.getNumOutputs() == 1);
diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
index 37e077acf06a..005278ca70d2 100644
--- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
@@ -29,6 +29,10 @@ TensorExp::TensorExp(Kind k, unsigned x, unsigned y, Value v)
case kInvariant:
assert(x == -1u && y == -1u && v);
break;
+ case kIndex:
+ assert(x != -1u && y == -1u && !v);
+ index = x;
+ break;
case kAbsF:
case kCeilF:
case kFloorF:
@@ -46,6 +50,7 @@ TensorExp::TensorExp(Kind k, unsigned x, unsigned y, Value v)
case kCastUF:
case kCastS:
case kCastU:
+ case kCastIdx:
case kTruncI:
case kBitCast:
assert(x != -1u && y == -1u && v);
@@ -230,6 +235,7 @@ bool Merger::isSingleCondition(unsigned t, unsigned e) const {
case kCastUF:
case kCastS:
case kCastU:
+ case kCastIdx:
case kTruncI:
case kBitCast:
return isSingleCondition(t, tensorExps[e].children.e0);
@@ -273,6 +279,8 @@ static const char *kindToOpSymbol(Kind kind) {
return "tensor";
case kInvariant:
return "invariant";
+ case kIndex:
+ return "index";
case kAbsF:
return "abs";
case kCeilF:
@@ -291,6 +299,7 @@ static const char *kindToOpSymbol(Kind kind) {
case kCastUF:
case kCastS:
case kCastU:
+ case kCastIdx:
case kTruncI:
case kBitCast:
return "cast";
@@ -340,6 +349,9 @@ void Merger::dumpExp(unsigned e) const {
case kInvariant:
llvm::dbgs() << "invariant";
break;
+ case kIndex:
+ llvm::dbgs() << "index_" << tensorExps[e].index;
+ break;
case kAbsF:
case kCeilF:
case kFloorF:
@@ -353,6 +365,7 @@ void Merger::dumpExp(unsigned e) const {
case kCastUF:
case kCastS:
case kCastU:
+ case kCastIdx:
case kTruncI:
case kBitCast:
llvm::dbgs() << kindToOpSymbol(tensorExps[e].kind) << " ";
@@ -420,16 +433,20 @@ unsigned Merger::buildLattices(unsigned e, unsigned i) {
Kind kind = tensorExps[e].kind;
switch (kind) {
case kTensor:
- case kInvariant: {
+ case kInvariant:
+ case kIndex: {
// Either the index is really used in the tensor expression, or it is
- // set to the undefined index in that dimension. An invariant expression
- // and a truly dynamic sparse output tensor are set to a synthetic tensor
- // with undefined indices only to ensure the iteration space is not
- // skipped as a result of their contents.
+ // set to the undefined index in that dimension. An invariant expression,
+ // a proper index value, and a truly dynamic sparse output tensor are set
+ // to a synthetic tensor with undefined indices only to ensure the
+ // iteration space is not skipped as a result of their contents.
unsigned s = addSet();
- unsigned t = kind == kTensor ? tensorExps[e].tensor : syntheticTensor;
- if (hasSparseOut && t == outTensor)
- t = syntheticTensor;
+ unsigned t = syntheticTensor;
+ if (kind == kTensor) {
+ t = tensorExps[e].tensor;
+ if (hasSparseOut && t == outTensor)
+ t = syntheticTensor;
+ }
latSets[s].push_back(addLat(t, i, e));
return s;
}
@@ -446,6 +463,7 @@ unsigned Merger::buildLattices(unsigned e, unsigned i) {
case kCastUF:
case kCastS:
case kCastU:
+ case kCastIdx:
case kTruncI:
case kBitCast:
// A zero preserving operation (viz. f(0) = 0, [Bik96,Ch5]) maps the
@@ -569,6 +587,11 @@ Optional<unsigned> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
Operation *def = v.getDefiningOp();
if (def->getBlock() != &op.region().front())
return addExp(kInvariant, v);
+ // Construct index operations.
+ if (def->getNumOperands() == 0) {
+ if (auto indexOp = dyn_cast<linalg::IndexOp>(def))
+ return addExp(kIndex, indexOp.dim());
+ }
// Construct unary operations if subexpression can be built.
if (def->getNumOperands() == 1) {
auto x = buildTensorExp(op, def->getOperand(0));
@@ -598,6 +621,8 @@ Optional<unsigned> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
return addExp(kCastS, e, v);
if (isa<arith::ExtUIOp>(def))
return addExp(kCastU, e, v);
+ if (isa<arith::IndexCastOp>(def))
+ return addExp(kCastIdx, e, v);
if (isa<arith::TruncIOp>(def))
return addExp(kTruncI, e, v);
if (isa<arith::BitcastOp>(def))
@@ -654,6 +679,7 @@ Value Merger::buildExp(PatternRewriter &rewriter, Location loc, unsigned e,
switch (tensorExps[e].kind) {
case kTensor:
case kInvariant:
+ case kIndex:
llvm_unreachable("unexpected non-op");
// Unary ops.
case kAbsF:
@@ -686,6 +712,8 @@ Value Merger::buildExp(PatternRewriter &rewriter, Location loc, unsigned e,
return rewriter.create<arith::ExtSIOp>(loc, inferType(e, v0), v0);
case kCastU:
return rewriter.create<arith::ExtUIOp>(loc, inferType(e, v0), v0);
+ case kCastIdx:
+ return rewriter.create<arith::IndexCastOp>(loc, inferType(e, v0), v0);
case kTruncI:
return rewriter.create<arith::TruncIOp>(loc, inferType(e, v0), v0);
case kBitCast:
diff --git a/mlir/test/Dialect/SparseTensor/sparse_index.mlir b/mlir/test/Dialect/SparseTensor/sparse_index.mlir
new file mode 100644
index 000000000000..f41c765376bb
--- /dev/null
+++ b/mlir/test/Dialect/SparseTensor/sparse_index.mlir
@@ -0,0 +1,128 @@
+// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
+// RUN: mlir-opt %s -sparsification | FileCheck %s
+
+#DenseMatrix = #sparse_tensor.encoding<{
+ dimLevelType = ["dense", "dense"]
+}>
+
+#SparseMatrix = #sparse_tensor.encoding<{
+ dimLevelType = ["compressed", "compressed"]
+}>
+
+#trait = {
+ indexing_maps = [
+ affine_map<(i,j) -> (i,j)>, // A
+ affine_map<(i,j) -> (i,j)> // X (out)
+ ],
+ iterator_types = ["parallel", "parallel"],
+ doc = "X(i,j) = A(i,j) * i * j"
+}
+
+// CHECK-LABEL: func @dense_index(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<?x?xi64, #sparse_tensor.encoding
+// CHECK-DAG: %[[VAL_1:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[VAL_2:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[VAL_3:.*]] = tensor.dim %[[VAL_0]], %[[VAL_1]] : tensor<?x?xi64, #sparse_tensor.encoding
+// CHECK-DAG: %[[VAL_4:.*]] = tensor.dim %[[VAL_0]], %[[VAL_1]] : tensor<?x?xi64, #sparse_tensor.encoding
+// CHECK-DAG: %[[VAL_5:.*]] = sparse_tensor.init{{\[}}%[[VAL_3]], %[[VAL_4]]] : tensor<?x?xi64, #sparse_tensor.encoding
+// CHECK-DAG: %[[VAL_6:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<?x?xi64, #sparse_tensor.encoding
+// CHECK-DAG: %[[VAL_7:.*]] = tensor.dim %[[VAL_5]], %[[VAL_1]] : tensor<?x?xi64, #sparse_tensor.encoding
+// CHECK-DAG: %[[VAL_8:.*]] = tensor.dim %[[VAL_5]], %[[VAL_2]] : tensor<?x?xi64, #sparse_tensor.encoding
+// CHECK-DAG: %[[VAL_9:.*]] = sparse_tensor.values %[[VAL_5]] : tensor<?x?xi64, #sparse_tensor.encoding
+// CHECK: scf.for %[[VAL_10:.*]] = %[[VAL_1]] to %[[VAL_7]] step %[[VAL_2]] {
+// CHECK: scf.for %[[VAL_11:.*]] = %[[VAL_1]] to %[[VAL_8]] step %[[VAL_2]] {
+// CHECK: %[[VAL_12:.*]] = arith.muli %[[VAL_8]], %[[VAL_10]] : index
+// CHECK: %[[VAL_13:.*]] = arith.addi %[[VAL_12]], %[[VAL_11]] : index
+// CHECK: %[[VAL_14:.*]] = arith.muli %[[VAL_8]], %[[VAL_10]] : index
+// CHECK: %[[VAL_15:.*]] = arith.addi %[[VAL_14]], %[[VAL_11]] : index
+// CHECK: %[[VAL_16:.*]] = arith.index_cast %[[VAL_11]] : index to i64
+// CHECK: %[[VAL_17:.*]] = arith.index_cast %[[VAL_10]] : index to i64
+// CHECK: %[[VAL_18:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_13]]] : memref<?xi64>
+// CHECK: %[[VAL_19:.*]] = arith.muli %[[VAL_17]], %[[VAL_18]] : i64
+// CHECK: %[[VAL_20:.*]] = arith.muli %[[VAL_16]], %[[VAL_19]] : i64
+// CHECK: memref.store %[[VAL_20]], %[[VAL_9]]{{\[}}%[[VAL_15]]] : memref<?xi64>
+// CHECK: }
+// CHECK: }
+// CHECK: %[[VAL_21:.*]] = sparse_tensor.load %[[VAL_5]] : tensor<?x?xi64, #sparse_tensor.encoding
+// CHECK: return %[[VAL_21]] : tensor<?x?xi64, #sparse_tensor.encoding
+// CHECK: }
+func @dense_index(%arga: tensor<?x?xi64, #DenseMatrix>)
+ -> tensor<?x?xi64, #DenseMatrix> {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 0 : index
+ %0 = tensor.dim %arga, %c0 : tensor<?x?xi64, #DenseMatrix>
+ %1 = tensor.dim %arga, %c1 : tensor<?x?xi64, #DenseMatrix>
+ %init = sparse_tensor.init [%0, %1] : tensor<?x?xi64, #DenseMatrix>
+ %r = linalg.generic #trait
+ ins(%arga: tensor<?x?xi64, #DenseMatrix>)
+ outs(%init: tensor<?x?xi64, #DenseMatrix>) {
+ ^bb(%a: i64, %x: i64):
+ %i = linalg.index 0 : index
+ %j = linalg.index 1 : index
+ %ii = arith.index_cast %i : index to i64
+ %jj = arith.index_cast %j : index to i64
+ %m1 = arith.muli %ii, %a : i64
+ %m2 = arith.muli %jj, %m1 : i64
+ linalg.yield %m2 : i64
+ } -> tensor<?x?xi64, #DenseMatrix>
+ return %r : tensor<?x?xi64, #DenseMatrix>
+}
+
+// CHECK-LABEL: func @sparse_index(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<?x?xi64, #sparse_tensor.encoding
+// CHECK-DAG: %[[VAL_1:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[VAL_2:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[VAL_4:.*]] = tensor.dim %[[VAL_0]], %[[VAL_1]] : tensor<?x?xi64, #sparse_tensor.encoding
+// CHECK-DAG: %[[VAL_5:.*]] = tensor.dim %[[VAL_0]], %[[VAL_1]] : tensor<?x?xi64, #sparse_tensor.encoding
+// CHECK-DAG: %[[VAL_6:.*]] = sparse_tensor.init{{\[}}%[[VAL_4]], %[[VAL_5]]] : tensor<?x?xi64, #sparse_tensor.encoding
+// CHECK-DAG: %[[VAL_7:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_1]] : tensor<?x?xi64, #sparse_tensor.encoding
+// CHECK-DAG: %[[VAL_8:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_1]] : tensor<?x?xi64, #sparse_tensor.encoding
+// CHECK-DAG: %[[VAL_9:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_2]] : tensor<?x?xi64, #sparse_tensor.encoding
+// CHECK-DAG: %[[VAL_10:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_2]] : tensor<?x?xi64, #sparse_tensor.encoding
+// CHECK-DAG: %[[VAL_11:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<?x?xi64, #sparse_tensor.encoding
+// CHECK: %[[VAL_12:.*]] = memref.alloca(%[[VAL_3]]) : memref<?xindex>
+// CHECK: %[[VAL_13:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_1]]] : memref<?xindex>
+// CHECK: %[[VAL_14:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_2]]] : memref<?xindex>
+// CHECK: scf.for %[[VAL_15:.*]] = %[[VAL_13]] to %[[VAL_14]] step %[[VAL_2]] {
+// CHECK: %[[VAL_16:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_15]]] : memref<?xindex>
+// CHECK: memref.store %[[VAL_16]], %[[VAL_12]]{{\[}}%[[VAL_1]]] : memref<?xindex>
+// CHECK: %[[VAL_17:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_15]]] : memref<?xindex>
+// CHECK: %[[VAL_18:.*]] = arith.addi %[[VAL_15]], %[[VAL_2]] : index
+// CHECK: %[[VAL_19:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_18]]] : memref<?xindex>
+// CHECK: scf.for %[[VAL_20:.*]] = %[[VAL_17]] to %[[VAL_19]] step %[[VAL_2]] {
+// CHECK: %[[VAL_21:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_20]]] : memref<?xindex>
+// CHECK: memref.store %[[VAL_21]], %[[VAL_12]]{{\[}}%[[VAL_2]]] : memref<?xindex>
+// CHECK: %[[VAL_22:.*]] = arith.index_cast %[[VAL_21]] : index to i64
+// CHECK: %[[VAL_23:.*]] = arith.index_cast %[[VAL_16]] : index to i64
+// CHECK: %[[VAL_24:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_20]]] : memref<?xi64>
+// CHECK: %[[VAL_25:.*]] = arith.muli %[[VAL_23]], %[[VAL_24]] : i64
+// CHECK: %[[VAL_26:.*]] = arith.muli %[[VAL_22]], %[[VAL_25]] : i64
+// CHECK: sparse_tensor.lex_insert %[[VAL_6]], %[[VAL_12]], %[[VAL_26]] : tensor<?x?xi64, #sparse_tensor.encoding
+// CHECK: }
+// CHECK: }
+// CHECK: %[[VAL_27:.*]] = sparse_tensor.load %[[VAL_6]] hasInserts : tensor<?x?xi64, #sparse_tensor.encoding
+// CHECK: return %[[VAL_27]] : tensor<?x?xi64, #sparse_tensor.encoding
+// CHECK: }
+func @sparse_index(%arga: tensor<?x?xi64, #SparseMatrix>)
+ -> tensor<?x?xi64, #SparseMatrix> {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 0 : index
+ %0 = tensor.dim %arga, %c0 : tensor<?x?xi64, #SparseMatrix>
+ %1 = tensor.dim %arga, %c1 : tensor<?x?xi64, #SparseMatrix>
+ %init = sparse_tensor.init [%0, %1] : tensor<?x?xi64, #SparseMatrix>
+ %r = linalg.generic #trait
+ ins(%arga: tensor<?x?xi64, #SparseMatrix>)
+ outs(%init: tensor<?x?xi64, #SparseMatrix>) {
+ ^bb(%a: i64, %x: i64):
+ %i = linalg.index 0 : index
+ %j = linalg.index 1 : index
+ %ii = arith.index_cast %i : index to i64
+ %jj = arith.index_cast %j : index to i64
+ %m1 = arith.muli %ii, %a : i64
+ %m2 = arith.muli %jj, %m1 : i64
+ linalg.yield %m2 : i64
+ } -> tensor<?x?xi64, #SparseMatrix>
+ return %r : tensor<?x?xi64, #SparseMatrix>
+}
+
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_index.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_index.mlir
new file mode 100644
index 000000000000..36a052155a59
--- /dev/null
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_index.mlir
@@ -0,0 +1,81 @@
+// RUN: mlir-opt %s --sparse-compiler | \
+// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
+// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
+#SparseMatrix = #sparse_tensor.encoding<{
+ dimLevelType = ["compressed", "compressed"]
+}>
+
+#trait = {
+ indexing_maps = [
+ affine_map<(i,j) -> (i,j)>, // A
+ affine_map<(i,j) -> (i,j)> // X (out)
+ ],
+ iterator_types = ["parallel", "parallel"],
+ doc = "X(i,j) = A(i,j) * i * j"
+}
+
+module {
+
+ //
+ // Kernel that uses indices in the index notation.
+ //
+ func @sparse_index(%arga: tensor<3x4xi64, #SparseMatrix>)
+ -> tensor<3x4xi64, #SparseMatrix> {
+ %d0 = arith.constant 3 : index
+ %d1 = arith.constant 4 : index
+ %init = sparse_tensor.init [%d0, %d1] : tensor<3x4xi64, #SparseMatrix>
+ %r = linalg.generic #trait
+ ins(%arga: tensor<3x4xi64, #SparseMatrix>)
+ outs(%init: tensor<3x4xi64, #SparseMatrix>) {
+ ^bb(%a: i64, %x: i64):
+ %i = linalg.index 0 : index
+ %j = linalg.index 1 : index
+ %ii = arith.index_cast %i : index to i64
+ %jj = arith.index_cast %j : index to i64
+ %m1 = arith.muli %ii, %a : i64
+ %m2 = arith.muli %jj, %m1 : i64
+ linalg.yield %m2 : i64
+ } -> tensor<3x4xi64, #SparseMatrix>
+ return %r : tensor<3x4xi64, #SparseMatrix>
+ }
+
+ //
+ // Main driver.
+ //
+ func @entry() {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c4 = arith.constant 4 : index
+ %du = arith.constant -1 : i64
+
+ // Setup input "sparse" matrix.
+ %d = arith.constant dense <[
+ [ 1, 1, 1, 1 ],
+ [ 1, 1, 1, 1 ],
+ [ 1, 1, 1, 1 ]
+ ]> : tensor<3x4xi64>
+ %a = sparse_tensor.convert %d : tensor<3x4xi64> to tensor<3x4xi64, #SparseMatrix>
+
+ // Call the kernel.
+ %0 = call @sparse_index(%a) : (tensor<3x4xi64, #SparseMatrix>) -> tensor<3x4xi64, #SparseMatrix>
+
+ //
+ // Verify result.
+ //
+ // CHECK: ( ( 0, 0, 0, 0 ), ( 0, 1, 2, 3 ), ( 0, 2, 4, 6 ) )
+ //
+ %x = sparse_tensor.convert %0 : tensor<3x4xi64, #SparseMatrix> to tensor<3x4xi64>
+ %m = bufferization.to_memref %x : memref<3x4xi64>
+ %v = vector.transfer_read %m[%c0, %c0], %du: memref<3x4xi64>, vector<3x4xi64>
+ vector.print %v : vector<3x4xi64>
+
+ // Release resources.
+ sparse_tensor.release %a : tensor<3x4xi64, #SparseMatrix>
+ sparse_tensor.release %0 : tensor<3x4xi64, #SparseMatrix>
+ memref.dealloc %m : memref<3x4xi64>
+
+ return
+ }
+}