Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/llvm/llvm-project.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/mlir
diff options
context:
space:
mode:
Diffstat (limited to 'mlir')
-rw-r--r--mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp165
-rw-r--r--mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp12
-rw-r--r--mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp67
-rwxr-xr-x[-rw-r--r--]mlir/test/Dialect/SparseTensor/rewriting.mlir16
-rw-r--r--mlir/test/Dialect/SparseTensor/sparse_reshape.mlir79
-rwxr-xr-x[-rw-r--r--]mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_reshape.mlir30
6 files changed, 285 insertions, 84 deletions
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
index dae378490d2f..282af7aed2df 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -238,7 +238,7 @@ static void newParams(OpBuilder &builder, SmallVector<Value, 8> &params,
/// the following and the insertion point after this routine is inside the
/// if-then branch behind the assignment to ind. This is to ensure that the
/// addEltX call generated after is inside the if-then branch.
-/// if (tensor[ivs]!=0) {
+/// if (tensor[ivs] != 0)
/// ind = ivs
static Value genIndexAndValueForDense(OpBuilder &builder, Location loc,
Value tensor, Value ind, ValueRange ivs) {
@@ -382,6 +382,133 @@ static bool canUseDirectConversion(
return true;
}
+/// Helper method to translate indices during a reshaping operation.
+/// TODO: provide as general utility to MLIR at large?
+static void translateIndices(Location loc, ConversionPatternRewriter &rewriter,
+ ArrayRef<ReassociationIndices> reassociation,
+ TensorType dstTp, TensorType srcTp, Value dstIdx,
+ Value srcIdx) {
+ unsigned dstRank = dstTp.getRank();
+ unsigned srcRank = srcTp.getRank();
+ unsigned start = 0;
+ unsigned i = 0;
+ bool isExpand = srcRank > dstRank;
+ ArrayRef<int64_t> shape = isExpand ? srcTp.getShape() : dstTp.getShape();
+ // Iterate over reassociation map.
+ for (const auto &map : llvm::enumerate(reassociation)) {
+ // Prepare strides information in dimension slice.
+ uint64_t linear = 1;
+ for (unsigned j = start, end = start + map.value().size(); j < end; j++) {
+ assert(!ShapedType::isDynamic(shape[j]));
+ linear *= shape[j];
+ }
+ // Start collapse.
+ Value idx = constantIndex(rewriter, loc, i++);
+ Value val;
+ if (!isExpand)
+ val = rewriter.create<memref::LoadOp>(loc, srcIdx, idx);
+ // Iterate over dimension slice.
+ for (unsigned j = start, end = start + map.value().size(); j < end; j++) {
+ linear /= shape[j];
+ Value stride = constantIndex(rewriter, loc, linear);
+ Value jdx = constantIndex(rewriter, loc, j);
+ if (isExpand) {
+ Value old = rewriter.create<memref::LoadOp>(loc, srcIdx, jdx);
+ Value mul = linear == 1
+ ? old
+ : rewriter.create<arith::MulIOp>(loc, old, stride);
+ val = val ? rewriter.create<arith::AddIOp>(loc, val, mul) : mul;
+ } else {
+ Value old = val;
+ if (linear != 1)
+ val = rewriter.create<arith::DivUIOp>(loc, val, stride);
+ rewriter.create<memref::StoreOp>(loc, val, dstIdx, jdx);
+ if (linear != 1)
+ val = rewriter.create<arith::RemUIOp>(loc, old, stride);
+ }
+ }
+ // Finalize expansion.
+ if (isExpand)
+ rewriter.create<memref::StoreOp>(loc, val, dstIdx, idx);
+ start += map.value().size();
+ }
+ // Sanity.
+ assert((isExpand && i == dstRank) || (!isExpand && i == srcRank));
+}
+
+/// Generate code for a general sparse to sparse reshaping operation.
+/// Note that unlike dense reshaping (which can be done with a "cheap"
+/// change of view), sparse reshaping is currently done with actual
+/// data shuffling.
+///
+/// TODO: proportional to nnz, but still a lot of data movement
+/// https://github.com/llvm/llvm-project/issues/56477
+///
+/// iter = src->toCOO();
+/// coo = newSparseCOO()
+/// while (elem = iter->getNext()) {
+/// coo->add(reshape(elem.indices), elem.value)
+/// }
+/// s = newSparseTensor(coo)
+static LogicalResult
+genSparse2SparseReshape(Operation *op, ConversionPatternRewriter &rewriter,
+ ArrayRef<ReassociationIndices> reassociation, Value src,
+ RankedTensorType dstTp, RankedTensorType srcTp) {
+ Location loc = op->getLoc();
+ auto encDst = getSparseTensorEncoding(dstTp);
+ auto encSrc = getSparseTensorEncoding(srcTp);
+ assert(encDst && encSrc);
+ unsigned srcRank = srcTp.getRank();
+ unsigned dstRank = dstTp.getRank();
+ Type elemTp = srcTp.getElementType();
+ assert(elemTp == dstTp.getElementType() &&
+ "reshape should not change element type");
+ // Start an iterator over the source tensor (in original index order).
+ auto noPerm = SparseTensorEncodingAttr::get(
+ op->getContext(), encSrc.getDimLevelType(), AffineMap(),
+ encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth());
+ SmallVector<Value, 4> sizes;
+ SmallVector<Value, 8> params;
+ sizesFromPtr(rewriter, sizes, op, noPerm, srcTp, src);
+ newParams(rewriter, params, op, srcTp, noPerm, Action::kToIterator, sizes,
+ src);
+ Value iter = genNewCall(rewriter, op, params);
+ // Start a new COO for the destination tensor.
+ sizes.clear();
+ params.clear();
+ sizesFromPtr(rewriter, sizes, op, encDst, dstTp, src);
+ newParams(rewriter, params, op, dstTp, encDst, Action::kEmptyCOO, sizes);
+ Value coo = genNewCall(rewriter, op, params);
+ Value dstPerm = params[2];
+ // Construct a while loop over the iterator.
+ Value srcIdx = genAlloca(rewriter, loc, srcRank, rewriter.getIndexType());
+ Value dstIdx = genAlloca(rewriter, loc, dstRank, rewriter.getIndexType());
+ Value elemPtr = genAllocaScalar(rewriter, loc, elemTp);
+ SmallVector<Value> noArgs;
+ SmallVector<Type> noTypes;
+ auto whileOp = rewriter.create<scf::WhileOp>(loc, noTypes, noArgs);
+ Block *before = rewriter.createBlock(&whileOp.getBefore(), {}, noTypes);
+ rewriter.setInsertionPointToEnd(before);
+ Value cond = genGetNextCall(rewriter, op, iter, srcIdx, elemPtr);
+ rewriter.create<scf::ConditionOp>(loc, cond, before->getArguments());
+ // Translate indices from source to target and insert. Note that we do
+ // not need to store the value in elemPtr, as the value is still there.
+ Block *after = rewriter.createBlock(&whileOp.getAfter(), {}, noTypes);
+ rewriter.setInsertionPointToStart(after);
+ translateIndices(loc, rewriter, reassociation, dstTp, srcTp, dstIdx, srcIdx);
+ genAddEltCall(rewriter, op, elemTp, coo, elemPtr, dstIdx, dstPerm);
+ rewriter.create<scf::YieldOp>(loc);
+ // Final call to construct sparse tensor storage and free temporary resources.
+ rewriter.setInsertionPointAfter(whileOp);
+ params[6] = constantAction(rewriter, loc, Action::kFromCOO);
+ params[7] = coo;
+ Value dst = genNewCall(rewriter, op, params);
+ genDelCOOCall(rewriter, op, elemTp, coo);
+ genDelCOOCall(rewriter, op, elemTp, iter);
+ rewriter.replaceOp(op, dst);
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// Conversion rules.
//===----------------------------------------------------------------------===//
@@ -423,6 +550,7 @@ public:
/// Sparse conversion rule for trivial tensor casts.
class SparseCastConverter : public OpConversionPattern<tensor::CastOp> {
+public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(tensor::CastOp op, OpAdaptor adaptor,
@@ -437,8 +565,30 @@ class SparseCastConverter : public OpConversionPattern<tensor::CastOp> {
}
};
+/// Sparse conversion rule for a reshape operator.
+template <typename ReshapeOp>
+class SparseReshapeConverter : public OpConversionPattern<ReshapeOp> {
+public:
+ using OpAdaptor = typename OpConversionPattern<ReshapeOp>::OpAdaptor;
+ using OpConversionPattern<ReshapeOp>::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(ReshapeOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ Type dstType = op.getResult().getType();
+ Type srcType = op.getSrc().getType();
+ auto encDst = getSparseTensorEncoding(dstType);
+ auto encSrc = getSparseTensorEncoding(srcType);
+ if (encDst && encSrc)
+ return genSparse2SparseReshape(
+ op, rewriter, op.getReassociationIndices(), adaptor.getOperands()[0],
+ dstType.cast<RankedTensorType>(), srcType.cast<RankedTensorType>());
+ return failure(); // handled elsewhere
+ }
+};
+
/// Sparse conversion rule for the new operator.
class SparseTensorNewConverter : public OpConversionPattern<NewOp> {
+public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(NewOp op, OpAdaptor adaptor,
@@ -463,6 +613,7 @@ class SparseTensorNewConverter : public OpConversionPattern<NewOp> {
/// Sparse conversion rule for the alloc operator.
class SparseTensorAllocConverter
: public OpConversionPattern<bufferization::AllocTensorOp> {
+public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(bufferization::AllocTensorOp op, OpAdaptor adaptor,
@@ -494,9 +645,6 @@ class SparseTensorAllocConverter
/// Sparse conversion rule for the convert operator.
class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
- /// Options to control sparse code generation.
- SparseTensorConversionOptions options;
-
public:
using OpConversionPattern::OpConversionPattern;
SparseTensorConvertConverter(MLIRContext *context,
@@ -697,6 +845,10 @@ public:
rewriter.replaceOp(op, dst);
return success();
}
+
+private:
+ /// Options to control sparse code generation.
+ SparseTensorConversionOptions options;
};
/// Sparse conversion rule for the release operator.
@@ -799,6 +951,7 @@ public:
}
};
+/// Sparse conversion rule for the expand operator.
class SparseTensorExpandConverter : public OpConversionPattern<ExpandOp> {
public:
using OpConversionPattern::OpConversionPattern;
@@ -841,6 +994,7 @@ public:
}
};
+/// Sparse conversion rule for the compress operator.
class SparseTensorCompressConverter : public OpConversionPattern<CompressOp> {
public:
using OpConversionPattern::OpConversionPattern;
@@ -873,6 +1027,7 @@ public:
}
};
+/// Sparse conversion rule for the output operator.
class SparseTensorOutConverter : public OpConversionPattern<OutOp> {
public:
using OpConversionPattern::OpConversionPattern;
@@ -926,6 +1081,8 @@ void mlir::populateSparseTensorConversionPatterns(
const SparseTensorConversionOptions &options) {
patterns.add<SparseReturnConverter, SparseTensorToDimSizeConverter,
SparseCastConverter, SparseTensorNewConverter,
+ SparseReshapeConverter<tensor::ExpandShapeOp>,
+ SparseReshapeConverter<tensor::CollapseShapeOp>,
SparseTensorAllocConverter, SparseTensorReleaseConverter,
SparseTensorToPointersConverter, SparseTensorToIndicesConverter,
SparseTensorToValuesConverter, SparseTensorLoadConverter,
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
index f85f47a29ec9..1f157eab3c57 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -127,13 +127,11 @@ struct SparseTensorConversionPass
});
// The following operations and dialects may be introduced by the
// rewriting rules, and are therefore marked as legal.
- target.addLegalOp<arith::CmpFOp, arith::CmpIOp, arith::ConstantOp,
- arith::IndexCastOp, complex::ConstantOp,
- complex::NotEqualOp, linalg::FillOp, linalg::YieldOp,
- tensor::ExtractOp>();
- target
- .addLegalDialect<bufferization::BufferizationDialect, LLVM::LLVMDialect,
- memref::MemRefDialect, scf::SCFDialect>();
+ target.addLegalOp<complex::ConstantOp, complex::NotEqualOp, linalg::FillOp,
+ linalg::YieldOp, tensor::ExtractOp>();
+ target.addLegalDialect<
+ arith::ArithmeticDialect, bufferization::BufferizationDialect,
+ LLVM::LLVMDialect, memref::MemRefDialect, scf::SCFDialect>();
// Translate strategy flags to strategy options.
SparseTensorConversionOptions options(
sparseToSparseConversionStrategy(sparseToSparse));
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index 0a5364ef6c60..d892aa67f8c8 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -1832,71 +1832,38 @@ private:
SparsificationOptions options;
};
-/// Sparse rewriting rule for expand shape operator.
-struct ExpandShapeRewriter : public OpRewritePattern<tensor::ExpandShapeOp> {
+/// Sparse rewriting rule for reshape operator.
+template <typename ReshapeOp>
+struct ReshapeRewriter : public OpRewritePattern<ReshapeOp> {
public:
- using OpRewritePattern<tensor::ExpandShapeOp>::OpRewritePattern;
+ using OpRewritePattern<ReshapeOp>::OpRewritePattern;
- LogicalResult matchAndRewrite(tensor::ExpandShapeOp op,
+ LogicalResult matchAndRewrite(ReshapeOp op,
PatternRewriter &rewriter) const override {
Location loc = op->getLoc();
auto encDst = getSparseTensorEncoding(op.getResult().getType());
auto encSrc = getSparseTensorEncoding(op.getSrc().getType());
// Since a pure dense expansion is very cheap (change of view), for
- // sparse2dense or dense2sparse, we can simply unfuse a sparse
- // conversion from the actual expansion operation itself.
+ // a sparse2dense or dense2sparse, we can simply unfuse a sparse
+ // conversion from the reshape operation itself.
+ // All other cases are handled elsewhere.
if (encDst && encSrc) {
- return failure(); // TODO: implement sparse2sparse
- } else if (encSrc) {
- RankedTensorType rtp = op.getSrc().getType().cast<RankedTensorType>();
- auto denseTp =
- RankedTensorType::get(rtp.getShape(), rtp.getElementType());
- auto convert = rewriter.create<ConvertOp>(loc, denseTp, op.getSrc());
- op->setOperand(0, convert);
- return success();
- } else if (encDst) {
- RankedTensorType rtp = op.getResult().getType().cast<RankedTensorType>();
- auto denseTp =
- RankedTensorType::get(rtp.getShape(), rtp.getElementType());
- auto reshape = rewriter.create<tensor::ExpandShapeOp>(
- loc, denseTp, op.getSrc(), op.getReassociation());
- Value convert = rewriter.create<ConvertOp>(loc, rtp, reshape);
- rewriter.replaceOp(op, convert);
- return success();
- }
- return failure();
- }
-};
-
-/// Sparse rewriting rule for collapse shape operator.
-struct CollapseShapeRewriter
- : public OpRewritePattern<tensor::CollapseShapeOp> {
-public:
- using OpRewritePattern<tensor::CollapseShapeOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(tensor::CollapseShapeOp op,
- PatternRewriter &rewriter) const override {
- Location loc = op->getLoc();
- auto encDst = getSparseTensorEncoding(op.getResult().getType());
- auto encSrc = getSparseTensorEncoding(op.getSrc().getType());
- // Since a pure dense collapse is very cheap (change of view), for
- // sparse2dense or dense2sparse, we can simply unfuse a sparse
- // conversion from the actual collapse operation itself.
- if (encDst && encSrc) {
- return failure(); // TODO: implement sparse2sparse
+ return failure();
} else if (encSrc) {
- RankedTensorType rtp = op.getSrc().getType().cast<RankedTensorType>();
+ RankedTensorType rtp =
+ op.getSrc().getType().template cast<RankedTensorType>();
auto denseTp =
RankedTensorType::get(rtp.getShape(), rtp.getElementType());
auto convert = rewriter.create<ConvertOp>(loc, denseTp, op.getSrc());
op->setOperand(0, convert);
return success();
} else if (encDst) {
- RankedTensorType rtp = op.getResult().getType().cast<RankedTensorType>();
+ RankedTensorType rtp =
+ op.getResult().getType().template cast<RankedTensorType>();
auto denseTp =
RankedTensorType::get(rtp.getShape(), rtp.getElementType());
- auto reshape = rewriter.create<tensor::CollapseShapeOp>(
- loc, denseTp, op.getSrc(), op.getReassociation());
+ auto reshape = rewriter.create<ReshapeOp>(loc, denseTp, op.getSrc(),
+ op.getReassociation());
Value convert = rewriter.create<ConvertOp>(loc, rtp, reshape);
rewriter.replaceOp(op, convert);
return success();
@@ -1912,6 +1879,6 @@ public:
void mlir::populateSparsificationPatterns(
RewritePatternSet &patterns, const SparsificationOptions &options) {
patterns.add<GenericOpSparsifier>(patterns.getContext(), options);
- patterns.add<ExpandShapeRewriter, CollapseShapeRewriter>(
- patterns.getContext());
+ patterns.add<ReshapeRewriter<tensor::ExpandShapeOp>,
+ ReshapeRewriter<tensor::CollapseShapeOp>>(patterns.getContext());
}
diff --git a/mlir/test/Dialect/SparseTensor/rewriting.mlir b/mlir/test/Dialect/SparseTensor/rewriting.mlir
index 3955310fce9b..000c3560f1e0 100644..100755
--- a/mlir/test/Dialect/SparseTensor/rewriting.mlir
+++ b/mlir/test/Dialect/SparseTensor/rewriting.mlir
@@ -40,8 +40,14 @@ func.func @expand_to_sparse(%arg0: tensor<12xf64>) -> tensor<3x4xf64, #SparseMat
return %0 : tensor<3x4xf64, #SparseMatrix>
}
-// TODO: make this work
+//
+// Not rewritten, needs conversion.
+//
// CHECK-LABEL: func.func @expand_sparse2sparse(
+// CHECK-SAME: %[[A:.*]]: tensor<12xf64, #sparse_tensor.encoding<{{{.*}}}>>) -> tensor<3x4xf64, #sparse_tensor.encoding<{{{.*}}}>> {
+// CHECK: %[[E:.*]] = tensor.expand_shape %[[A]] {{.*}} : tensor<12xf64, #sparse_tensor.encoding<{{{.*}}}>> into tensor<3x4xf64, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK: return %[[E]] : tensor<3x4xf64, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK: }
func.func @expand_sparse2sparse(%arg0: tensor<12xf64, #SparseVector>) -> tensor<3x4xf64, #SparseMatrix> {
%0 = tensor.expand_shape %arg0 [[0, 1]] : tensor<12xf64, #SparseVector> into tensor<3x4xf64, #SparseMatrix>
return %0 : tensor<3x4xf64, #SparseMatrix>
@@ -79,8 +85,14 @@ func.func @collapse_to_sparse(%arg0: tensor<3x4xf64>) -> tensor<12xf64, #SparseV
return %0 : tensor<12xf64, #SparseVector>
}
-// TODO: make this work
+//
+// Not rewritten, needs conversion.
+//
// CHECK-LABEL: func.func @collapse_sparse2sparse(
+// CHECK-SAME: %[[A:.*]]: tensor<3x4xf64, #sparse_tensor.encoding<{{{.*}}}>>) -> tensor<12xf64, #sparse_tensor.encoding<{{{.*}}}>> {
+// CHECK: %[[C:.*]] = tensor.collapse_shape %[[A]] {{.*}} : tensor<3x4xf64, #sparse_tensor.encoding<{{{.*}}}>> into tensor<12xf64, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK: return %[[C]] : tensor<12xf64, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK: }
func.func @collapse_sparse2sparse(%arg0: tensor<3x4xf64, #SparseMatrix>) -> tensor<12xf64, #SparseVector> {
%0 = tensor.collapse_shape %arg0 [[0, 1]] : tensor<3x4xf64, #SparseMatrix> into tensor<12xf64, #SparseVector>
return %0 : tensor<12xf64, #SparseVector>
diff --git a/mlir/test/Dialect/SparseTensor/sparse_reshape.mlir b/mlir/test/Dialect/SparseTensor/sparse_reshape.mlir
index c791536e1519..65eb56b9bac3 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_reshape.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_reshape.mlir
@@ -1,24 +1,81 @@
-// RUN: mlir-opt %s | mlir-opt | FileCheck %s
-
-// TODO: check lowering to an actual implementation
+// RUN: mlir-opt %s | mlir-opt | FileCheck %s --check-prefix=CHECK-ROUND
+// RUN: mlir-opt %s --sparse-tensor-conversion --cse | FileCheck %s --check-prefix=CHECK-CONV
#SparseVector = #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }>
#SparseMatrix = #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>
-// CHECK-LABEL: func.func @sparse_expand(
-// CHECK-SAME: %[[A:.*]]: tensor<100xf64, #sparse_tensor.encoding<{{{.*}}}>>) -> tensor<10x10xf64, #sparse_tensor.encoding<{{{.*}}}>>
-// CHECK: %[[E:.*]] = tensor.expand_shape %[[A]] {{\[\[}}0, 1]] : tensor<100xf64, #sparse_tensor.encoding<{{{.*}}}>> into tensor<10x10xf64, #sparse_tensor.encoding<{{{.*}}}>>
-// CHECK: return %[[E]] : tensor<10x10xf64, #sparse_tensor.encoding<{{{.*}}}>>
+//
+// roundtrip:
+//
+// CHECK-ROUND-LABEL: func.func @sparse_expand(
+// CHECK-ROUND-SAME: %[[A:.*]]: tensor<100xf64, #sparse_tensor.encoding<{{{.*}}}>>) -> tensor<10x10xf64, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK-ROUND: %[[E:.*]] = tensor.expand_shape %[[A]] {{\[\[}}0, 1]] : tensor<100xf64, #sparse_tensor.encoding<{{{.*}}}>> into tensor<10x10xf64, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK-ROUND: return %[[E]] : tensor<10x10xf64, #sparse_tensor.encoding<{{{.*}}}>>
+//
+// conversion:
+//
+// CHECK-CONV-LABEL: func.func @sparse_expand(
+// CHECK-CONV-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-CONV-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-CONV-DAG: %[[C10:.*]] = arith.constant 10 : index
+// CHECK-CONV-DAG: call @newSparseTensor
+// CHECK-CONV-DAG: call @newSparseTensor
+// CHECK-CONV: scf.while : () -> () {
+// CHECK-CONV: call @getNextF64
+// CHECK-CONV: scf.condition(%13)
+// CHECK-CONV: } do {
+// CHECK-CONV: %[[X:.*]] = memref.load %{{.*}}[%[[C0]]] : memref<?xindex>
+// CHECK-CONV: %[[D:.*]] = arith.divui %[[X]], %[[C10]] : index
+// CHECK-CONV: memref.store %[[D]], %{{.*}}[%[[C0]]] : memref<?xindex>
+// CHECK-CONV: %[[R:.*]] = arith.remui %[[X]], %[[C10]] : index
+// CHECK-CONV: memref.store %[[R]], %{{.*}}[%[[C1]]] : memref<?xindex>
+// CHECK-CONV: call @addEltF64
+// CHECK-CONV: scf.yield
+// CHECK-CONV: }
+// CHECK-CONV: %[[N:.*]] = call @newSparseTensor
+// CHECK-CONV: call @delSparseTensorCOOF64
+// CHECK-CONV: call @delSparseTensorCOOF64
+// CHECK-CONV: return %[[N]] : !llvm.ptr<i8>
+//
func.func @sparse_expand(%arg0: tensor<100xf64, #SparseVector>) -> tensor<10x10xf64, #SparseMatrix> {
%0 = tensor.expand_shape %arg0 [[0, 1]] :
tensor<100xf64, #SparseVector> into tensor<10x10xf64, #SparseMatrix>
return %0 : tensor<10x10xf64, #SparseMatrix>
}
-// CHECK-LABEL: func.func @sparse_collapse(
-// CHECK-SAME: %[[A:.*]]: tensor<10x10xf64, #sparse_tensor.encoding<{{{.*}}}>>) -> tensor<100xf64, #sparse_tensor.encoding<{{{.*}}}>>
-// CHECK: %[[C:.*]] = tensor.collapse_shape %[[A]] {{\[\[}}0, 1]] : tensor<10x10xf64, #sparse_tensor.encoding<{{{.*}}}>> into tensor<100xf64, #sparse_tensor.encoding<{{{.*}}}>>
-// CHECK: return %[[C]] : tensor<100xf64, #sparse_tensor.encoding<{{{.*}}}>>
+//
+// roundtrip:
+//
+// CHECK-ROUND-LABEL: func.func @sparse_collapse(
+// CHECK-ROUND-SAME: %[[A:.*]]: tensor<10x10xf64, #sparse_tensor.encoding<{{{.*}}}>>) -> tensor<100xf64, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK-ROUND: %[[C:.*]] = tensor.collapse_shape %[[A]] {{\[\[}}0, 1]] : tensor<10x10xf64, #sparse_tensor.encoding<{{{.*}}}>> into tensor<100xf64, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK-ROUND: return %[[C]] : tensor<100xf64, #sparse_tensor.encoding<{{{.*}}}>>
+//
+// conversion:
+//
+// CHECK-CONV-LABEL: func.func @sparse_collapse(
+// CHECK-CONV-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-CONV-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-CONV-DAG: %[[C10:.*]] = arith.constant 10 : index
+// CHECK-CONV-DAG: call @newSparseTensor
+// CHECK-CONV-DAG: call @newSparseTensor
+// CHECK-CONV: scf.while : () -> () {
+// CHECK-CONV: call @getNextF64
+// CHECK-CONV: scf.condition(%13)
+// CHECK-CONV: } do {
+// CHECK-CONV: %[[X:.*]] = memref.load %{{.*}}[%[[C0]]] : memref<?xindex>
+// CHECK-CONV: %[[M:.*]] = arith.muli %[[X]], %[[C10]] : index
+// CHECK-CONV: %[[Y:.*]] = memref.load %{{.*}}[%[[C1]]] : memref<?xindex>
+// CHECK-CONV: %[[A:.*]] = arith.addi %[[M]], %[[Y]] : index
+// CHECK-CONV: memref.store %[[A]], %{{.*}}[%[[C0]]] : memref<?xindex>
+// CHECK-CONV: call @addEltF64
+// CHECK-CONV: scf.yield
+// CHECK-CONV: }
+// CHECK-CONV: %[[N:.*]] = call @newSparseTensor
+// CHECK-CONV: call @delSparseTensorCOOF64
+// CHECK-CONV: call @delSparseTensorCOOF64
+// CHECK-CONV: return %[[N]] : !llvm.ptr<i8>
+//
func.func @sparse_collapse(%arg0: tensor<10x10xf64, #SparseMatrix>) -> tensor<100xf64, #SparseVector> {
%0 = tensor.collapse_shape %arg0 [[0, 1]] :
tensor<10x10xf64, #SparseMatrix> into tensor<100xf64, #SparseVector>
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_reshape.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_reshape.mlir
index eed880866884..57d2a931fad5 100644..100755
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_reshape.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_reshape.mlir
@@ -32,11 +32,10 @@ module {
return %0 : tensor<3x4xf64, #SparseMatrix>
}
-// TODO: make this work
-// func.func @expand_sparse2sparse(%arg0: tensor<12xf64, #SparseVector>) -> tensor<3x4xf64, #SparseMatrix> {
-// %0 = tensor.expand_shape %arg0 [[0, 1]] : tensor<12xf64, #SparseVector> into tensor<3x4xf64, #SparseMatrix>
-// return %0 : tensor<3x4xf64, #SparseMatrix>
-// }
+ func.func @expand_sparse2sparse(%arg0: tensor<12xf64, #SparseVector>) -> tensor<3x4xf64, #SparseMatrix> {
+ %0 = tensor.expand_shape %arg0 [[0, 1]] : tensor<12xf64, #SparseVector> into tensor<3x4xf64, #SparseMatrix>
+ return %0 : tensor<3x4xf64, #SparseMatrix>
+ }
func.func @collapse_dense(%arg0: tensor<3x4xf64>) -> tensor<12xf64> {
%0 = tensor.collapse_shape %arg0 [[0, 1]] : tensor<3x4xf64> into tensor<12xf64>
@@ -53,11 +52,10 @@ module {
return %0 : tensor<12xf64, #SparseVector>
}
-// TODO: make this work
-// func.func @collapse_sparse2sparse(%arg0: tensor<3x4xf64, #SparseMatrix>) -> tensor<12xf64, #SparseVector> {
-// %0 = tensor.collapse_shape %arg0 [[0, 1]] : tensor<3x4xf64, #SparseMatrix> into tensor<12xf64, #SparseVector>
-// return %0 : tensor<12xf64, #SparseVector>
-// }
+ func.func @collapse_sparse2sparse(%arg0: tensor<3x4xf64, #SparseMatrix>) -> tensor<12xf64, #SparseVector> {
+ %0 = tensor.collapse_shape %arg0 [[0, 1]] : tensor<3x4xf64, #SparseMatrix> into tensor<12xf64, #SparseVector>
+ return %0 : tensor<12xf64, #SparseVector>
+ }
//
@@ -81,10 +79,12 @@ module {
%expand0 = call @expand_dense(%v) : (tensor<12xf64>) -> tensor<3x4xf64>
%expand1 = call @expand_from_sparse(%sv) : (tensor<12xf64, #SparseVector>) -> tensor<3x4xf64>
%expand2 = call @expand_to_sparse(%v) : (tensor<12xf64>) -> tensor<3x4xf64, #SparseMatrix>
+ %expand3 = call @expand_sparse2sparse(%sv) : (tensor<12xf64, #SparseVector>) -> tensor<3x4xf64, #SparseMatrix>
%collapse0 = call @collapse_dense(%m) : (tensor<3x4xf64>) -> tensor<12xf64>
%collapse1 = call @collapse_from_sparse(%sm) : (tensor<3x4xf64, #SparseMatrix>) -> tensor<12xf64>
%collapse2 = call @collapse_to_sparse(%m) : (tensor<3x4xf64>) -> tensor<12xf64, #SparseVector>
+ %collapse3 = call @collapse_sparse2sparse(%sm) : (tensor<3x4xf64, #SparseMatrix>) -> tensor<12xf64, #SparseVector>
//
// Verify result.
@@ -92,9 +92,11 @@ module {
// CHECK: ( ( 1, 2, 3, 4 ), ( 5, 6, 7, 8 ), ( 9, 10, 11, 12 ) )
// CHECK-NEXT: ( ( 1, 2, 3, 4 ), ( 5, 6, 7, 8 ), ( 9, 10, 11, 12 ) )
// CHECK-NEXT: ( 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, -1, -1, -1, -1 )
+ // CHECK-NEXT: ( 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, -1, -1, -1, -1 )
// CHECK-NEXT: ( 1.1, 1.2, 1.3, 1.4, 2.1, 2.2, 2.3, 2.4, 3.1, 3.2, 3.3, 3.4 )
// CHECK-NEXT: ( 1.1, 1.2, 1.3, 1.4, 2.1, 2.2, 2.3, 2.4, 3.1, 3.2, 3.3, 3.4 )
// CHECK-NEXT: ( 1.1, 1.2, 1.3, 1.4, 2.1, 2.2, 2.3, 2.4, 3.1, 3.2, 3.3, 3.4, -1, -1, -1, -1 )
+ // CHECK-NEXT: ( 1.1, 1.2, 1.3, 1.4, 2.1, 2.2, 2.3, 2.4, 3.1, 3.2, 3.3, 3.4, -1, -1, -1, -1 )
//
%m0 = vector.transfer_read %expand0[%c0, %c0], %df: tensor<3x4xf64>, vector<3x4xf64>
vector.print %m0 : vector<3x4xf64>
@@ -103,6 +105,9 @@ module {
%a2 = sparse_tensor.values %expand2 : tensor<3x4xf64, #SparseMatrix> to memref<?xf64>
%m2 = vector.transfer_read %a2[%c0], %df: memref<?xf64>, vector<16xf64>
vector.print %m2 : vector<16xf64>
+ %a3 = sparse_tensor.values %expand3 : tensor<3x4xf64, #SparseMatrix> to memref<?xf64>
+ %m3 = vector.transfer_read %a3[%c0], %df: memref<?xf64>, vector<16xf64>
+ vector.print %m3 : vector<16xf64>
%v0 = vector.transfer_read %collapse0[%c0], %df: tensor<12xf64>, vector<12xf64>
vector.print %v0 : vector<12xf64>
@@ -111,12 +116,17 @@ module {
%b2 = sparse_tensor.values %collapse2 : tensor<12xf64, #SparseVector> to memref<?xf64>
%v2 = vector.transfer_read %b2[%c0], %df: memref<?xf64>, vector<16xf64>
vector.print %v2 : vector<16xf64>
+ %b3 = sparse_tensor.values %collapse3 : tensor<12xf64, #SparseVector> to memref<?xf64>
+ %v3 = vector.transfer_read %b3[%c0], %df: memref<?xf64>, vector<16xf64>
+ vector.print %v3 : vector<16xf64>
// Release sparse resources.
sparse_tensor.release %sv : tensor<12xf64, #SparseVector>
sparse_tensor.release %sm : tensor<3x4xf64, #SparseMatrix>
sparse_tensor.release %expand2 : tensor<3x4xf64, #SparseMatrix>
+ sparse_tensor.release %expand3 : tensor<3x4xf64, #SparseMatrix>
sparse_tensor.release %collapse2 : tensor<12xf64, #SparseVector>
+ sparse_tensor.release %collapse3 : tensor<12xf64, #SparseVector>
// Release dense resources.
%meme1 = bufferization.to_memref %expand1 : memref<3x4xf64>