Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/llvm/llvm-project.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/mlir
diff options
context:
space:
mode:
authorbixia1 <bixia@google.com>2022-10-31 07:55:25 +0300
committerbixia1 <bixia@google.com>2022-11-02 01:57:34 +0300
commiteb877006a61733a87257c1c999f5a0b880ccf3cd (patch)
tree87f04f3116732edb56ce1c14a6ae3c2b6ee125dd /mlir
parentf71d32a0eea47b3d2bb43d6be15cf09d47ef6971 (diff)
[mlir][sparse] Add rewriting rule for the convert operator.
Reviewed By: aartbik Differential Revision: https://reviews.llvm.org/D136301
Diffstat (limited to 'mlir')
-rw-r--r--mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h4
-rw-r--r--mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp18
-rw-r--r--mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp218
-rw-r--r--mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir62
-rw-r--r--mlir/test/Dialect/SparseTensor/convert_sparse2dense.mlir31
-rw-r--r--mlir/test/Dialect/SparseTensor/convert_sparse2sparse.mlir21
-rw-r--r--mlir/test/Dialect/SparseTensor/rewriting_for_codegen.mlir3
-rw-r--r--mlir/test/Dialect/SparseTensor/sparse_concat_codegen.mlir3
-rw-r--r--mlir/test/Dialect/SparseTensor/sparse_reshape.mlir3
9 files changed, 357 insertions, 6 deletions
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
index 7c615b47aed6..52f9fef7041c 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
@@ -33,6 +33,10 @@ namespace sparse_tensor {
/// Returns null-attribute for any type without an encoding.
SparseTensorEncodingAttr getSparseTensorEncoding(Type type);
+/// Returns true iff the given type is a type for a COO tensor with the last
+/// dimension level type being unique.
+bool isUniqueCOOType(RankedTensorType tp);
+
//
// Dimension level types.
//
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index f17080c24710..133879b12b19 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -262,6 +262,24 @@ mlir::sparse_tensor::getSparseTensorEncoding(Type type) {
return nullptr;
}
+bool mlir::sparse_tensor::isUniqueCOOType(RankedTensorType tp) {
+ SparseTensorEncodingAttr enc = getSparseTensorEncoding(tp);
+
+ if (!enc)
+ return false;
+
+ if (!isCompressedDim(tp, 0))
+ return false;
+
+ for (uint64_t i = 1, e = tp.getRank(); i < e; ++i)
+ if (!isSingletonDim(tp, i))
+ return false;
+
+ // This works for rank == 1 (unique the only compressed) and rank > 1 (unique
+ // on the last singleton).
+ return isUniqueDim(tp, tp.getRank() - 1);
+}
+
uint64_t mlir::sparse_tensor::toOrigDim(const SparseTensorEncodingAttr &enc,
uint64_t d) {
if (enc) {
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index 4399caecadcc..a0519462dd8c 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -155,6 +155,18 @@ static RankedTensorType getUnorderedCOOFromType(RankedTensorType src) {
return RankedTensorType::get(src.getShape(), src.getElementType(), enc);
}
+/// Collects the dynamic dimension sizes for `tp` with the assumption that
+/// `sizes` are the dimension sizes for the type. Stores the dynamic dimension
+/// sizes to dynSizes.
+static void getDynamicSizes(RankedTensorType tp,
+ const SmallVectorImpl<Value> &sizes,
+ SmallVectorImpl<Value> &dynSizes) {
+ for (const auto &d : enumerate(tp.getShape())) {
+ if (d.value() == ShapedType::kDynamicSize)
+ dynSizes.push_back(sizes[d.index()]);
+ }
+}
+
//===---------------------------------------------------------------------===//
// The actual sparse tensor rewriting rules.
//===---------------------------------------------------------------------===//
@@ -461,6 +473,204 @@ struct ConcatenateRewriter : public OpRewritePattern<ConcatenateOp> {
}
};
+/// Sparse rewriting rule for the convert operator.
+struct ConvertRewriter : public OpRewritePattern<ConvertOp> {
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(ConvertOp op,
+ PatternRewriter &rewriter) const override {
+ auto encDst = getSparseTensorEncoding(op.getType());
+ auto encSrc = getSparseTensorEncoding(op.getSource().getType());
+ if (encDst && encSrc) {
+ // Trivial tensor conversion is handled in codegen.
+ if (encSrc == encDst)
+ return failure();
+ return sparse2SparseRewrite(op, rewriter);
+ }
+ if (encSrc && !encDst)
+ return sparse2DenseRewrite(op, rewriter);
+ if (!encSrc && encDst)
+ return dense2SparseRewrite(op, rewriter);
+
+ // Dense-to-dense convert is a nop and handled by canonicalization.
+ return failure();
+ }
+
+private:
+ // Handles sparse constant to sparse tensor or dense tensor to sparse tensor
+ // conversion as follows:
+ // t = new sparse COO tensor
+ // fill t using src
+ // dst = convert t
+ //
+ // To fill the COO tensor from a dense tensor:
+ // for i1 in dim1
+ // ..
+ // for ik in dimk
+ // val = a[i1,..,ik]
+ // if val != 0
+ // t->add(val, [i1,..,ik], [p1,..,pk])
+ //
+ // To fill the COO tensor from a sparse constant in COO format:
+ // for i in range(NNZ)
+ // val = values[i]
+ // [i1,..,ik] = indices[i]
+ // t->add(val, [i1,..,ik], [p1,..,pk])
+ LogicalResult dense2SparseRewrite(ConvertOp op,
+ PatternRewriter &rewriter) const {
+ Location loc = op.getLoc();
+ Value src = op.getSource();
+ RankedTensorType dstTp = op.getType().cast<RankedTensorType>();
+ SmallVector<Value, 4> sizes;
+ sizesFromSrc(rewriter, sizes, loc, src);
+ SmallVector<Value, 4> dynSizes;
+ getDynamicSizes(dstTp, sizes, dynSizes);
+
+ RankedTensorType cooTp = getUnorderedCOOFromType(dstTp);
+ auto cooBuffer =
+ rewriter.create<AllocTensorOp>(loc, cooTp, dynSizes).getResult();
+ unsigned rank = dstTp.cast<ShapedType>().getRank();
+
+ genDenseTensorOrSparseConstantIterLoop(
+ rewriter, loc, src, rank,
+ [&](OpBuilder &builder, Location loc, Value val, ValueRange indices) {
+ builder.create<InsertOp>(loc, val, cooBuffer, indices);
+ });
+
+ rewriter.setInsertionPointAfter(op);
+ rewriter.replaceOpWithNewOp<ConvertOp>(op, dstTp, cooBuffer);
+ rewriter.create<DeallocTensorOp>(loc, cooBuffer);
+
+ return success();
+ }
+
+ // Handles sparse tensor to dense tensor conversion as follows:
+ // dst = new dense tensor;
+ // foreach elemment in src
+ // dst[elemment.indices] = element.value
+ LogicalResult sparse2DenseRewrite(ConvertOp op,
+ PatternRewriter &rewriter) const {
+ Location loc = op->getLoc();
+ RankedTensorType dstTp = op.getType().cast<RankedTensorType>();
+ Value src = op.getSource();
+ RankedTensorType srcTp = src.getType().cast<RankedTensorType>();
+
+ SmallVector<Value, 4> sizes;
+ sizesForTensor(rewriter, sizes, loc, srcTp, src);
+ Value dst = allocDenseTensor(rewriter, loc, dstTp, sizes);
+
+ rewriter.create<ForeachOp>(
+ loc, src, [&](OpBuilder &builder, Location loc, ValueRange args) {
+ builder.create<memref::StoreOp>(loc, args.back(), dst,
+ args.drop_back());
+ builder.create<sparse_tensor::YieldOp>(loc);
+ });
+
+ rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, dstTp, dst);
+ return success();
+ }
+
+ // Handles sparse tensor to sparse tensor conversion as follows:
+ // if src is not COO
+ // construct a COO to represent the src
+ // sort the src COO
+ // foreach elemment in the sorted src COO
+ // insert element to dst
+ LogicalResult sparse2SparseRewrite(ConvertOp op,
+ PatternRewriter &rewriter) const {
+ Location loc = op->getLoc();
+ Value src = op.getSource();
+ RankedTensorType srcTp = src.getType().cast<RankedTensorType>();
+ RankedTensorType dstTp = op.getType().cast<RankedTensorType>();
+ SparseTensorEncodingAttr encSrc = getSparseTensorEncoding(srcTp);
+ SparseTensorEncodingAttr encDst = getSparseTensorEncoding(dstTp);
+
+ SmallVector<Value, 4> srcSizes;
+ sizesForTensor(rewriter, srcSizes, loc, srcTp, src);
+ Value tmpCoo = Value();
+ if (!isUniqueCOOType(srcTp)) {
+ // Construct a COO tensor from the src tensor.
+ // TODO: there may be cases for which more efficiently without
+ // going through an intermediate COO, such as cases that only change
+ // the overhead types.
+ SmallVector<Value, 4> dynSrcSizes;
+ getDynamicSizes(srcTp, srcSizes, dynSrcSizes);
+ srcTp = getUnorderedCOOFromType(srcTp);
+ tmpCoo =
+ rewriter.create<AllocTensorOp>(loc, srcTp, dynSrcSizes).getResult();
+ rewriter.create<ForeachOp>(
+ loc, src, [&](OpBuilder &builder, Location loc, ValueRange args) {
+ SmallVector<Value, 4> indices;
+ for (int64_t i = 0, e = srcTp.getRank(); i < e; i++) {
+ uint64_t dim = toStoredDim(encSrc, i);
+ indices.push_back(args[dim]);
+ }
+ builder.create<InsertOp>(loc, args.back(), tmpCoo, indices);
+ builder.create<sparse_tensor::YieldOp>(loc);
+ });
+ src = tmpCoo;
+ }
+
+ // Sort the COO tensor so that its elements are ordered via increasing
+ // indices for the storage ordering of the dst tensor.
+ auto dynShape = {ShapedType::kDynamicSize};
+ auto indTp =
+ MemRefType::get(dynShape, getIndexOverheadType(rewriter, encSrc));
+ uint64_t rank = dstTp.getRank();
+ // Gather the indices-arrays in the dst tensor storage order.
+ SmallVector<Value, 4> xs(rank, Value());
+ for (int64_t i = 0; i < rank; i++) {
+ uint64_t orgDim = toOrigDim(encSrc, i);
+ xs[toStoredDim(encDst, orgDim)] = rewriter.create<ToIndicesOp>(
+ loc, indTp, src, rewriter.getIndexAttr(orgDim));
+ }
+
+ // Retrieve NNZ.
+ auto ptrTp =
+ MemRefType::get(dynShape, getPointerOverheadType(rewriter, encSrc));
+ Value p0 =
+ rewriter.create<ToIndicesOp>(loc, ptrTp, src, rewriter.getIndexAttr(0));
+ Value c1 = constantIndex(rewriter, loc, 1);
+ Value nnz = rewriter.create<memref::LoadOp>(loc, p0, c1);
+ nnz =
+ rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(), nnz);
+
+ // Retrieve the values-array.
+ auto valTp = MemRefType::get(dynShape, srcTp.getElementType());
+ Value y = rewriter.create<ToValuesOp>(loc, valTp, src);
+
+ // Sort the COO tensor.
+ rewriter.create<SortOp>(loc, nnz, xs, ValueRange{y});
+
+ // For each element in the COO tensor, insert the element to the dst tensor.
+ SmallVector<Value, 4> dynDstSizes;
+ getDynamicSizes(dstTp, srcSizes, dynDstSizes);
+ Value dst =
+ rewriter.create<AllocTensorOp>(loc, dstTp, dynDstSizes).getResult();
+ rewriter.create<ForeachOp>(
+ loc, src, [&](OpBuilder &builder, Location loc, ValueRange args) {
+ SmallVector<Value, 4> indices;
+ for (int64_t i = 0, e = srcTp.getRank(); i < e; i++) {
+ uint64_t dim = toStoredDim(encDst, i);
+ indices.push_back(args[dim]);
+ }
+ builder.create<InsertOp>(loc, args.back(), dst, indices);
+ builder.create<sparse_tensor::YieldOp>(loc);
+ });
+
+ // Release the temporary COO if it is created.
+ if (tmpCoo)
+ rewriter.create<DeallocTensorOp>(loc, tmpCoo);
+
+ // Directly replace op with dst results in bufferization error message
+ // "sparse tensor allocation should not escape function".
+ // As such, we insert a trivial tensor convert which will be removed by
+ // codegen.
+ rewriter.setInsertionPointAfter(op);
+ rewriter.replaceOpWithNewOp<ConvertOp>(op, dstTp, dst);
+ return success();
+ }
+};
+
/// Sparse rewriting rule for the foreach operator.
struct ForeachRewriter : public OpRewritePattern<ForeachOp> {
public:
@@ -685,17 +895,19 @@ struct OutRewriter : public OpRewritePattern<OutOp> {
//===---------------------------------------------------------------------===//
void mlir::populateSparseTensorRewriting(RewritePatternSet &patterns,
bool enableRT, bool enableForeach,
- bool /*enableConvert*/) {
+ bool enableConvert) {
patterns.add<FoldInvariantYield, FuseSparseMultiplyOverAdd,
ReshapeRewriter<tensor::ExpandShapeOp>,
ReshapeRewriter<tensor::CollapseShapeOp>>(patterns.getContext());
if (enableForeach)
patterns.add<ForeachRewriter>(patterns.getContext());
-
// TODO: If RT not enabled, rewrite concatenate ops, etc here.
- if (!enableRT)
+ if (!enableRT) {
patterns.add<ConcatenateRewriter, NewRewriter, OutRewriter,
Sparse2SparseReshapeRewriter<tensor::ExpandShapeOp>,
Sparse2SparseReshapeRewriter<tensor::CollapseShapeOp>>(
patterns.getContext());
+ if (enableConvert)
+ patterns.add<ConvertRewriter>(patterns.getContext());
+ }
}
diff --git a/mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir b/mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir
index 13772e83df08..d67e11b92dd9 100644
--- a/mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir
+++ b/mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir
@@ -1,4 +1,6 @@
// RUN: mlir-opt %s --sparse-tensor-conversion --canonicalize --cse | FileCheck %s
+// RUN: mlir-opt %s --sparse-tensor-rewrite="enable-runtime-library=false enable-foreach=false" \
+// RUN: --canonicalize --cse | FileCheck %s --check-prefix=CHECK-RWT
#SparseVector = #sparse_tensor.encoding<{
dimLevelType = ["compressed"]
@@ -100,6 +102,37 @@ func.func @sparse_convert_complex(%arg0: tensor<100xcomplex<f64>>) -> tensor<100
// CHECK: %[[T:.*]] = call @newSparseTensor(%[[X]], %[[Y]], %[[Z]], %{{.*}}, %{{.*}}, %{{.*}}, %[[FromCOO]], %[[C]])
// CHECK: call @delSparseTensorCOOF64(%[[C]])
// CHECK: return %[[T]] : !llvm.ptr<i8>
+
+// CHECK-RWT-LABEL: func.func @sparse_convert_2d(
+// CHECK-RWT-SAME: %[[A:.*]]: tensor<2x4xf64>) -> tensor<2x4xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ] }>> {
+// CHECK-RWT-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-RWT-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-RWT-DAG: %[[C2:.*]] = arith.constant 2 : index
+// CHECK-RWT-DAG: %[[C4:.*]] = arith.constant 4 : index
+// CHECK-RWT-DAG: %[[F0:.*]] = arith.constant 0.000000e+00 : f64
+// CHECK-RWT: %[[COO:.*]] = bufferization.alloc_tensor()
+// CHECK-RWT: scf.for %[[FI:.*]] = %[[C0]] to %[[C2]] step %[[C1]] {
+// CHECK-RWT: scf.for %[[FJ:.*]] = %[[C0]] to %[[C4]] step %[[C1]] {
+// CHECK-RWT: %[[V:.*]] = tensor.extract %[[A]]{{\[}}%[[FI]], %[[FJ]]] : tensor<2x4xf64>
+// CHECK-RWT: %[[NZ:.*]] = arith.cmpf une, %[[V]], %[[F0]] : f64
+// CHECK-RWT: scf.if %[[NZ]] {
+// CHECK-RWT: %{{.*}} = sparse_tensor.insert %[[V]] into %[[COO]]{{\[}}%[[FI]], %[[FJ]]]
+// CHECK-RWT: }
+// CHECK-RWT: }
+// CHECK-RWT: }
+// CHECK-RWT: %[[I0:.*]] = sparse_tensor.indices %[[COO]] {dimension = 0 : index}
+// CHECK-RWT: %[[I1:.*]] = sparse_tensor.indices %[[COO]] {dimension = 1 : index}
+// CHECK-RWT: %[[NNZ:.*]] = memref.load %[[I0]]{{\[}}%[[C1]]] : memref<?xindex>
+// CHECK-RWT: %[[V2:.*]] = sparse_tensor.values %[[COO]]
+// CHECK-RWT: sparse_tensor.sort %[[NNZ]], %[[I0]], %[[I1]] jointly %[[V2]]
+// CHECK-RWT: %[[DST:.*]] = bufferization.alloc_tensor()
+// CHECK-RWT: sparse_tensor.foreach in %[[COO]]
+// CHECK-RWT: ^bb0(%[[FI0:.*]]: index, %[[FI1:.*]]: index, %[[FV:.*]]: f64):
+// CHECK-RWT: sparse_tensor.insert %[[FV]] into %[[DST]]{{\[}}%[[FI0]], %[[FI1]]]
+// CHECK-RWT: }
+// CHECK-RWT: %[[R:.*]] = sparse_tensor.convert %[[DST]]
+// CHECK-RWT: bufferization.dealloc_tensor %[[COO]]
+// CHECK-RWT: return %[[R]] : tensor<2x4xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ] }>>
func.func @sparse_convert_2d(%arg0: tensor<2x4xf64>) -> tensor<2x4xf64, #CSR> {
%0 = sparse_tensor.convert %arg0 : tensor<2x4xf64> to tensor<2x4xf64, #CSR>
return %0 : tensor<2x4xf64, #CSR>
@@ -132,6 +165,35 @@ func.func @sparse_convert_2d(%arg0: tensor<2x4xf64>) -> tensor<2x4xf64, #CSR> {
// CHECK: %[[T:.*]] = call @newSparseTensor(%[[X]], %[[Y]], %[[Z]], %{{.*}}, %{{.*}}, %{{.*}}, %[[FromCOO]], %[[C]])
// CHECK: call @delSparseTensorCOOF32(%[[C]])
// CHECK: return %[[T]] : !llvm.ptr<i8>
+
+// CHECK-RWT-LABEL: func.func @sparse_constant()
+// CHECK-RWT-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-RWT-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-RWT-DAG: %[[SI:.*]] = arith.constant dense<{{\[\[}}0, 0], [1, 6]]> : tensor<2x2xi64>
+// CHECK-RWT-DAG: %[[SV:.*]] = arith.constant dense<[1.000000e+00, 5.000000e+00]> : tensor<2xf32>
+// CHECK-RWT-DAG: %[[C2:.*]] = arith.constant 2 : index
+// CHECK-RWT: %[[COO:.*]] = bufferization.alloc_tensor()
+// CHECK-RWT: scf.for %[[FI:.*]] = %[[C0]] to %[[C2]] step %[[C1]] {
+// CHECK-RWT: %[[I0r:.*]] = tensor.extract %[[SI]]{{\[}}%[[FI]], %[[C0]]] : tensor<2x2xi64>
+// CHECK-RWT: %[[I0:.*]] = arith.index_cast %[[I0r]] : i64 to index
+// CHECK-RWT: %[[I1r:.*]] = tensor.extract %[[SI]]{{\[}}%[[FI]], %[[C1]]] : tensor<2x2xi64>
+// CHECK-RWT: %[[I1:.*]] = arith.index_cast %[[I1r]] : i64 to index
+// CHECK-RWT: %[[V:.*]] = tensor.extract %[[SV]]{{\[}}%[[FI]]] : tensor<2xf32>
+// CHECK-RWT: sparse_tensor.insert %[[V]] into %[[COO]]{{\[}}%[[I0]], %[[I1]]]
+// CHECK-RWT: }
+// CHECK-RWT: %[[TI0:.*]] = sparse_tensor.indices %[[COO]] {dimension = 0 : index}
+// CHECK-RWT: %[[TI1:.*]] = sparse_tensor.indices %[[COO]] {dimension = 1 : index}
+// CHECK-RWT: %[[NNZ:.*]] = memref.load %[[TI0]]{{\[}}%[[C1]]] : memref<?xindex>
+// CHECK-RWT: %[[TV:.*]] = sparse_tensor.values %[[COO]]
+// CHECK-RWT: sparse_tensor.sort %[[NNZ]], %[[TI0]], %[[TI1]] jointly %[[TV]]
+// CHECK-RWT: %[[DST:.*]] = bufferization.alloc_tensor()
+// CHECK-RWT: sparse_tensor.foreach in %[[COO]]
+// CHECK-RWT: ^bb0(%[[F2I0:.*]]: index, %[[F2I1:.*]]: index, %[[F2V:.*]]: f32):
+// CHECK-RWT: sparse_tensor.insert %[[F2V]] into %[[DST]]{{\[}}%[[F2I0]], %[[F2I1]]]
+// CHECK-RWT: }
+// CHECK-RWT: %[[R:.*]] = sparse_tensor.convert %[[DST]]
+// CHECK-RWT: bufferization.dealloc_tensor %[[COO]]
+// CHECK-RWT: return %[[R]] : tensor<8x7xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ] }>>
func.func @sparse_constant() -> tensor<8x7xf32, #CSR>{
// Initialize a tensor.
%0 = arith.constant sparse<[[0, 0], [1, 6]], [1.0, 5.0]> : tensor<8x7xf32>
diff --git a/mlir/test/Dialect/SparseTensor/convert_sparse2dense.mlir b/mlir/test/Dialect/SparseTensor/convert_sparse2dense.mlir
index ee7499a1c120..8980c4276e53 100644
--- a/mlir/test/Dialect/SparseTensor/convert_sparse2dense.mlir
+++ b/mlir/test/Dialect/SparseTensor/convert_sparse2dense.mlir
@@ -1,5 +1,8 @@
// RUN: mlir-opt %s --sparse-tensor-conversion --canonicalize --cse | FileCheck %s
+// RUN: mlir-opt %s --sparse-tensor-rewrite="enable-runtime-library=false enable-foreach=false" \
+// RUN: --canonicalize --cse | FileCheck %s --check-prefix=CHECK-RWT
+
#SparseVector = #sparse_tensor.encoding<{
dimLevelType = ["compressed"]
}>
@@ -128,6 +131,18 @@ func.func @sparse_convert_1d_dyn(%arg0: tensor<?xi32, #SparseVector>) -> tensor<
// CHECK: }
// CHECK: %[[T:.*]] = bufferization.to_tensor %[[M]] : memref<2x4xf64>
// CHECK: return %[[T]] : tensor<2x4xf64>
+
+// CHECK-RWT-LABEL: func.func @sparse_convert_2d(
+// CHECK-RWT-SAME: %[[A:.*]]: tensor<2x4xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ] }>>) -> tensor<2x4xf64> {
+// CHECK-RWT: %[[F0:.*]] = arith.constant 0.000000e+00 : f64
+// CHECK-RWT: %[[B:.*]] = memref.alloc() : memref<2x4xf64>
+// CHECK-RWT: linalg.fill ins(%[[F0]] : f64) outs(%[[B]]
+// CHECK-RWT: sparse_tensor.foreach in %[[A]]
+// CHECK-RWT: ^bb0(%[[FI0:.*]]: index, %[[FI1:.*]]: index, %[[FV:.*]]: f64):
+// CHECK-RWT: memref.store %[[FV]], %[[B]]{{\[}}%[[FI0]], %[[FI1]]]
+// CHECK-RWT: }
+// CHECK-RWT: %[[T:.*]] = bufferization.to_tensor %[[B]]
+// CHECK-RWT: return %[[T]] : tensor<2x4xf64>
func.func @sparse_convert_2d(%arg0: tensor<2x4xf64, #SparseMatrix>) -> tensor<2x4xf64> {
%0 = sparse_tensor.convert %arg0 : tensor<2x4xf64, #SparseMatrix> to tensor<2x4xf64>
return %0 : tensor<2x4xf64>
@@ -260,6 +275,22 @@ func.func @sparse_convert_2d_dyn1(%arg0: tensor<2x?xf64, #SparseMatrix>) -> tens
// CHECK: }
// CHECK: %[[T:.*]] = bufferization.to_tensor %[[M]] : memref<?x?xf64>
// CHECK: return %[[T]] : tensor<?x?xf64>
+
+// CHECK-RWT-LABEL: func.func @sparse_convert_2d_dyn2(
+// CHECK-RWT-SAME: %[[A:.*]]: tensor<?x?xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ] }>>) -> tensor<?x?xf64> {
+// CHECK-RWT-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-RWT-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-RWT-DAG: %[[F0:.*]] = arith.constant 0.000000e+00 : f64
+// CHECK-RWT: %[[D0:.*]] = tensor.dim %[[A]], %[[C0]]
+// CHECK-RWT: %[[D1:.*]] = tensor.dim %[[A]], %[[C1]]
+// CHECK-RWT: %[[B:.*]] = memref.alloc(%[[D0]], %[[D1]])
+// CHECK-RWT: linalg.fill ins(%[[F0]] : f64) outs(%[[B]]
+// CHECK-RWT: sparse_tensor.foreach in %[[A]]
+// CHECK-RWT: ^bb0(%[[FI0:.*]]: index, %[[FI1:.*]]: index, %[[FV:.*]]: f64):
+// CHECK-RWT: memref.store %[[FV]], %[[B]]{{\[}}%[[FI0]], %[[FI1]]]
+// CHECK-RWT: }
+// CHECK-RWT: %[[T:.*]] = bufferization.to_tensor %[[B]]
+// CHECK-RWT: return %[[T]] : tensor<?x?xf64>
func.func @sparse_convert_2d_dyn2(%arg0: tensor<?x?xf64, #SparseMatrix>) -> tensor<?x?xf64> {
%0 = sparse_tensor.convert %arg0 : tensor<?x?xf64, #SparseMatrix> to tensor<?x?xf64>
return %0 : tensor<?x?xf64>
diff --git a/mlir/test/Dialect/SparseTensor/convert_sparse2sparse.mlir b/mlir/test/Dialect/SparseTensor/convert_sparse2sparse.mlir
index cd5575b8e10a..92f9e46b9093 100644
--- a/mlir/test/Dialect/SparseTensor/convert_sparse2sparse.mlir
+++ b/mlir/test/Dialect/SparseTensor/convert_sparse2sparse.mlir
@@ -6,6 +6,9 @@
// RUN: mlir-opt %s --sparse-tensor-conversion="s2s-strategy=0" \
// RUN: --canonicalize --cse | FileCheck %s -check-prefixes=CHECK-AUTO,CHECK
+// RUN: mlir-opt %s --sparse-tensor-rewrite="enable-runtime-library=false enable-foreach=false" \
+// RUN: --canonicalize --cse | FileCheck %s --check-prefix=CHECK-RWT
+
#SparseVector64 = #sparse_tensor.encoding<{
dimLevelType = ["compressed"],
pointerBitWidth = 64,
@@ -79,6 +82,24 @@ func.func @sparse_convert_1d_ss(%arg0: tensor<?xf32, #SparseVector64>) -> tensor
// CHECK-AUTO-DAG: %[[Z:.*]] = memref.cast %[[R]] : memref<1xindex> to memref<?xindex>
// CHECK-AUTO: %[[T:.*]] = call @newSparseTensor(%[[X]], %[[Y]], %[[Z]], %{{.*}}, %{{.*}}, %{{.*}}, %[[SparseToSparse]], %[[A]])
// CHECK-AUTO: return %[[T]] : !llvm.ptr<i8>
+
+// CHECK-RWT-LABEL: func.func @sparse_convert(
+// CHECK-RWT-SAME: %[[A:.*]]: tensor<?xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 64, indexBitWidth = 64 }>>)
+// CHECK-RWT-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-RWT-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-RWT: %[[D:.*]] = tensor.dim %[[A]], %[[C0]]
+// CHECK-RWT: %[[I0:.*]] = sparse_tensor.indices %[[A]] {dimension = 0 : index}
+// CHECK-RWT: %[[NNZr:.*]] = memref.load %[[I0]]{{\[}}%[[C1]]] : memref<?xi64>
+// CHECK-RWT: %[[NNZ:.*]] = arith.index_cast %[[NNZr]] : i64 to index
+// CHECK-RWT: %[[V:.*]] = sparse_tensor.values %[[A]]
+// CHECK-RWT: sparse_tensor.sort %[[NNZ]], %[[I0]] jointly %[[V]]
+// CHECK-RWT: %[[DST:.*]] = bufferization.alloc_tensor(%[[D]])
+// CHECK-RWT: sparse_tensor.foreach in %[[A]]
+// CHECK-RWT: ^bb0(%[[FI2:.*]]: index, %[[FV2:.*]]: f32):
+// CHECK-RWT: sparse_tensor.insert %[[FV2]] into %[[DST]]{{\[}}%[[FI2]]]
+// CHECK-RWT: }
+// CHECK-RWT: %[[R:.*]] = sparse_tensor.convert %[[DST]]
+// CHECK-RWT: return %[[R]] : tensor<?xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 32, indexBitWidth = 32 }>>
func.func @sparse_convert(%arg0: tensor<?xf32, #SparseVector64>) -> tensor<?xf32, #SparseVector32> {
%0 = sparse_tensor.convert %arg0 : tensor<?xf32, #SparseVector64> to tensor<?xf32, #SparseVector32>
return %0 : tensor<?xf32, #SparseVector32>
diff --git a/mlir/test/Dialect/SparseTensor/rewriting_for_codegen.mlir b/mlir/test/Dialect/SparseTensor/rewriting_for_codegen.mlir
index 3d2a5e2f50c1..79b616dec830 100644
--- a/mlir/test/Dialect/SparseTensor/rewriting_for_codegen.mlir
+++ b/mlir/test/Dialect/SparseTensor/rewriting_for_codegen.mlir
@@ -1,4 +1,5 @@
-// RUN: mlir-opt %s -sparse-tensor-rewrite=enable-runtime-library=false | FileCheck %s
+// RUN: mlir-opt %s -sparse-tensor-rewrite="enable-runtime-library=false enable-convert=false" |\
+// RUN: FileCheck %s
#CSR = #sparse_tensor.encoding<{
dimLevelType = ["dense", "compressed"]
diff --git a/mlir/test/Dialect/SparseTensor/sparse_concat_codegen.mlir b/mlir/test/Dialect/SparseTensor/sparse_concat_codegen.mlir
index 8ab23e18e2d1..7280c6f5e7ba 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_concat_codegen.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_concat_codegen.mlir
@@ -1,4 +1,5 @@
-// RUN: mlir-opt %s --sparse-tensor-rewrite=enable-runtime-library=false --sparsification | FileCheck %s
+// RUN: mlir-opt %s --sparse-tensor-rewrite="enable-runtime-library=false enable-convert=false" \
+// RUN: --sparsification | FileCheck %s
#DCSR = #sparse_tensor.encoding<{dimLevelType = ["compressed", "compressed"]}>
diff --git a/mlir/test/Dialect/SparseTensor/sparse_reshape.mlir b/mlir/test/Dialect/SparseTensor/sparse_reshape.mlir
index adc1f6de22d1..c162bacffac9 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_reshape.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_reshape.mlir
@@ -1,6 +1,7 @@
// RUN: mlir-opt %s | mlir-opt | FileCheck %s --check-prefix=CHECK-ROUND
// RUN: mlir-opt %s --sparse-tensor-conversion --cse --canonicalize | FileCheck %s --check-prefix=CHECK-CONV
-// RUN: mlir-opt %s --sparse-tensor-rewrite=enable-runtime-library=false --cse --canonicalize | FileCheck %s --check-prefix=CHECK-RWT
+// RUN: mlir-opt %s --sparse-tensor-rewrite="enable-runtime-library=false enable-convert=false" \
+// RUN: --cse --canonicalize | FileCheck %s --check-prefix=CHECK-RWT
#SparseVector = #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }>
#SparseMatrix = #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>