diff options
Diffstat (limited to 'mlir')
-rw-r--r-- | mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp | 165 |
-rw-r--r-- | mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp | 12 |
-rw-r--r-- | mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp | 67 |
-rwxr-xr-x[-rw-r--r--] | mlir/test/Dialect/SparseTensor/rewriting.mlir | 16 |
-rw-r--r-- | mlir/test/Dialect/SparseTensor/sparse_reshape.mlir | 79 |
-rwxr-xr-x[-rw-r--r--] | mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_reshape.mlir | 30 |
6 files changed, 285 insertions, 84 deletions
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp index dae378490d2f..282af7aed2df 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp @@ -238,7 +238,7 @@ static void newParams(OpBuilder &builder, SmallVector<Value, 8> &params, /// the following and the insertion point after this routine is inside the /// if-then branch behind the assignment to ind. This is to ensure that the /// addEltX call generated after is inside the if-then branch. -/// if (tensor[ivs]!=0) { +/// if (tensor[ivs] != 0) /// ind = ivs static Value genIndexAndValueForDense(OpBuilder &builder, Location loc, Value tensor, Value ind, ValueRange ivs) { @@ -382,6 +382,133 @@ static bool canUseDirectConversion( return true; } +/// Helper method to translate indices during a reshaping operation. +/// TODO: provide as general utility to MLIR at large? +static void translateIndices(Location loc, ConversionPatternRewriter &rewriter, + ArrayRef<ReassociationIndices> reassociation, + TensorType dstTp, TensorType srcTp, Value dstIdx, + Value srcIdx) { + unsigned dstRank = dstTp.getRank(); + unsigned srcRank = srcTp.getRank(); + unsigned start = 0; + unsigned i = 0; + bool isExpand = srcRank > dstRank; + ArrayRef<int64_t> shape = isExpand ? srcTp.getShape() : dstTp.getShape(); + // Iterate over reassociation map. + for (const auto &map : llvm::enumerate(reassociation)) { + // Prepare strides information in dimension slice. + uint64_t linear = 1; + for (unsigned j = start, end = start + map.value().size(); j < end; j++) { + assert(!ShapedType::isDynamic(shape[j])); + linear *= shape[j]; + } + // Start collapse. + Value idx = constantIndex(rewriter, loc, i++); + Value val; + if (!isExpand) + val = rewriter.create<memref::LoadOp>(loc, srcIdx, idx); + // Iterate over dimension slice. 
+ for (unsigned j = start, end = start + map.value().size(); j < end; j++) { + linear /= shape[j]; + Value stride = constantIndex(rewriter, loc, linear); + Value jdx = constantIndex(rewriter, loc, j); + if (isExpand) { + Value old = rewriter.create<memref::LoadOp>(loc, srcIdx, jdx); + Value mul = linear == 1 + ? old + : rewriter.create<arith::MulIOp>(loc, old, stride); + val = val ? rewriter.create<arith::AddIOp>(loc, val, mul) : mul; + } else { + Value old = val; + if (linear != 1) + val = rewriter.create<arith::DivUIOp>(loc, val, stride); + rewriter.create<memref::StoreOp>(loc, val, dstIdx, jdx); + if (linear != 1) + val = rewriter.create<arith::RemUIOp>(loc, old, stride); + } + } + // Finalize expansion. + if (isExpand) + rewriter.create<memref::StoreOp>(loc, val, dstIdx, idx); + start += map.value().size(); + } + // Sanity. + assert((isExpand && i == dstRank) || (!isExpand && i == srcRank)); +} + +/// Generate code for a general sparse to sparse reshaping operation. +/// Note that unlike dense reshaping (which can be done with a "cheap" +/// change of view), sparse reshaping is currently done with actual +/// data shuffling. 
+/// +/// TODO: proportional to nnz, but still a lot of data movement +/// https://github.com/llvm/llvm-project/issues/56477 +/// +/// iter = src->toCOO(); +/// coo = newSparseCOO() +/// while (elem = iter->getNext()) { +/// coo->add(reshape(elem.indices), elem.value) +/// } +/// s = newSparseTensor(coo) +static LogicalResult +genSparse2SparseReshape(Operation *op, ConversionPatternRewriter &rewriter, + ArrayRef<ReassociationIndices> reassociation, Value src, + RankedTensorType dstTp, RankedTensorType srcTp) { + Location loc = op->getLoc(); + auto encDst = getSparseTensorEncoding(dstTp); + auto encSrc = getSparseTensorEncoding(srcTp); + assert(encDst && encSrc); + unsigned srcRank = srcTp.getRank(); + unsigned dstRank = dstTp.getRank(); + Type elemTp = srcTp.getElementType(); + assert(elemTp == dstTp.getElementType() && + "reshape should not change element type"); + // Start an iterator over the source tensor (in original index order). + auto noPerm = SparseTensorEncodingAttr::get( + op->getContext(), encSrc.getDimLevelType(), AffineMap(), + encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth()); + SmallVector<Value, 4> sizes; + SmallVector<Value, 8> params; + sizesFromPtr(rewriter, sizes, op, noPerm, srcTp, src); + newParams(rewriter, params, op, srcTp, noPerm, Action::kToIterator, sizes, + src); + Value iter = genNewCall(rewriter, op, params); + // Start a new COO for the destination tensor. + sizes.clear(); + params.clear(); + sizesFromPtr(rewriter, sizes, op, encDst, dstTp, src); + newParams(rewriter, params, op, dstTp, encDst, Action::kEmptyCOO, sizes); + Value coo = genNewCall(rewriter, op, params); + Value dstPerm = params[2]; + // Construct a while loop over the iterator. 
+ Value srcIdx = genAlloca(rewriter, loc, srcRank, rewriter.getIndexType()); + Value dstIdx = genAlloca(rewriter, loc, dstRank, rewriter.getIndexType()); + Value elemPtr = genAllocaScalar(rewriter, loc, elemTp); + SmallVector<Value> noArgs; + SmallVector<Type> noTypes; + auto whileOp = rewriter.create<scf::WhileOp>(loc, noTypes, noArgs); + Block *before = rewriter.createBlock(&whileOp.getBefore(), {}, noTypes); + rewriter.setInsertionPointToEnd(before); + Value cond = genGetNextCall(rewriter, op, iter, srcIdx, elemPtr); + rewriter.create<scf::ConditionOp>(loc, cond, before->getArguments()); + // Translate indices from source to target and insert. Note that we do + // not need to store the value in elemPtr, as the value is still there. + Block *after = rewriter.createBlock(&whileOp.getAfter(), {}, noTypes); + rewriter.setInsertionPointToStart(after); + translateIndices(loc, rewriter, reassociation, dstTp, srcTp, dstIdx, srcIdx); + genAddEltCall(rewriter, op, elemTp, coo, elemPtr, dstIdx, dstPerm); + rewriter.create<scf::YieldOp>(loc); + // Final call to construct sparse tensor storage and free temporary resources. + rewriter.setInsertionPointAfter(whileOp); + params[6] = constantAction(rewriter, loc, Action::kFromCOO); + params[7] = coo; + Value dst = genNewCall(rewriter, op, params); + genDelCOOCall(rewriter, op, elemTp, coo); + genDelCOOCall(rewriter, op, elemTp, iter); + rewriter.replaceOp(op, dst); + return success(); +} + //===----------------------------------------------------------------------===// // Conversion rules. //===----------------------------------------------------------------------===// @@ -423,6 +550,7 @@ public: /// Sparse conversion rule for trivial tensor casts. 
class SparseCastConverter : public OpConversionPattern<tensor::CastOp> { +public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(tensor::CastOp op, OpAdaptor adaptor, @@ -437,8 +565,30 @@ class SparseCastConverter : public OpConversionPattern<tensor::CastOp> { } }; +/// Sparse conversion rule for a reshape operator. +template <typename ReshapeOp> +class SparseReshapeConverter : public OpConversionPattern<ReshapeOp> { +public: + using OpAdaptor = typename OpConversionPattern<ReshapeOp>::OpAdaptor; + using OpConversionPattern<ReshapeOp>::OpConversionPattern; + LogicalResult + matchAndRewrite(ReshapeOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type dstType = op.getResult().getType(); + Type srcType = op.getSrc().getType(); + auto encDst = getSparseTensorEncoding(dstType); + auto encSrc = getSparseTensorEncoding(srcType); + if (encDst && encSrc) + return genSparse2SparseReshape( + op, rewriter, op.getReassociationIndices(), adaptor.getOperands()[0], + dstType.cast<RankedTensorType>(), srcType.cast<RankedTensorType>()); + return failure(); // handled elsewhere + } +}; + /// Sparse conversion rule for the new operator. class SparseTensorNewConverter : public OpConversionPattern<NewOp> { +public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(NewOp op, OpAdaptor adaptor, @@ -463,6 +613,7 @@ class SparseTensorNewConverter : public OpConversionPattern<NewOp> { /// Sparse conversion rule for the alloc operator. class SparseTensorAllocConverter : public OpConversionPattern<bufferization::AllocTensorOp> { +public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(bufferization::AllocTensorOp op, OpAdaptor adaptor, @@ -494,9 +645,6 @@ class SparseTensorAllocConverter /// Sparse conversion rule for the convert operator. class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> { - /// Options to control sparse code generation. 
- SparseTensorConversionOptions options; - public: using OpConversionPattern::OpConversionPattern; SparseTensorConvertConverter(MLIRContext *context, @@ -697,6 +845,10 @@ public: rewriter.replaceOp(op, dst); return success(); } + +private: + /// Options to control sparse code generation. + SparseTensorConversionOptions options; }; /// Sparse conversion rule for the release operator. @@ -799,6 +951,7 @@ public: } }; +/// Sparse conversion rule for the expand operator. class SparseTensorExpandConverter : public OpConversionPattern<ExpandOp> { public: using OpConversionPattern::OpConversionPattern; @@ -841,6 +994,7 @@ public: } }; +/// Sparse conversion rule for the compress operator. class SparseTensorCompressConverter : public OpConversionPattern<CompressOp> { public: using OpConversionPattern::OpConversionPattern; @@ -873,6 +1027,7 @@ public: } }; +/// Sparse conversion rule for the output operator. class SparseTensorOutConverter : public OpConversionPattern<OutOp> { public: using OpConversionPattern::OpConversionPattern; @@ -926,6 +1081,8 @@ void mlir::populateSparseTensorConversionPatterns( const SparseTensorConversionOptions &options) { patterns.add<SparseReturnConverter, SparseTensorToDimSizeConverter, SparseCastConverter, SparseTensorNewConverter, + SparseReshapeConverter<tensor::ExpandShapeOp>, + SparseReshapeConverter<tensor::CollapseShapeOp>, SparseTensorAllocConverter, SparseTensorReleaseConverter, SparseTensorToPointersConverter, SparseTensorToIndicesConverter, SparseTensorToValuesConverter, SparseTensorLoadConverter, diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp index f85f47a29ec9..1f157eab3c57 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp @@ -127,13 +127,11 @@ struct SparseTensorConversionPass }); // The following operations and dialects may be introduced by 
the // rewriting rules, and are therefore marked as legal. - target.addLegalOp<arith::CmpFOp, arith::CmpIOp, arith::ConstantOp, - arith::IndexCastOp, complex::ConstantOp, - complex::NotEqualOp, linalg::FillOp, linalg::YieldOp, - tensor::ExtractOp>(); - target - .addLegalDialect<bufferization::BufferizationDialect, LLVM::LLVMDialect, - memref::MemRefDialect, scf::SCFDialect>(); + target.addLegalOp<complex::ConstantOp, complex::NotEqualOp, linalg::FillOp, + linalg::YieldOp, tensor::ExtractOp>(); + target.addLegalDialect< + arith::ArithmeticDialect, bufferization::BufferizationDialect, + LLVM::LLVMDialect, memref::MemRefDialect, scf::SCFDialect>(); // Translate strategy flags to strategy options. SparseTensorConversionOptions options( sparseToSparseConversionStrategy(sparseToSparse)); diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp index 0a5364ef6c60..d892aa67f8c8 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp @@ -1832,71 +1832,38 @@ private: SparsificationOptions options; }; -/// Sparse rewriting rule for expand shape operator. -struct ExpandShapeRewriter : public OpRewritePattern<tensor::ExpandShapeOp> { +/// Sparse rewriting rule for reshape operator. 
+template <typename ReshapeOp> +struct ReshapeRewriter : public OpRewritePattern<ReshapeOp> { public: - using OpRewritePattern<tensor::ExpandShapeOp>::OpRewritePattern; + using OpRewritePattern<ReshapeOp>::OpRewritePattern; - LogicalResult matchAndRewrite(tensor::ExpandShapeOp op, + LogicalResult matchAndRewrite(ReshapeOp op, PatternRewriter &rewriter) const override { Location loc = op->getLoc(); auto encDst = getSparseTensorEncoding(op.getResult().getType()); auto encSrc = getSparseTensorEncoding(op.getSrc().getType()); // Since a pure dense expansion is very cheap (change of view), for - // sparse2dense or dense2sparse, we can simply unfuse a sparse - // conversion from the actual expansion operation itself. + // a sparse2dense or dense2sparse, we can simply unfuse a sparse + // conversion from the reshape operation itself. + // All other cases are handled elsewhere. if (encDst && encSrc) { - return failure(); // TODO: implement sparse2sparse - } else if (encSrc) { - RankedTensorType rtp = op.getSrc().getType().cast<RankedTensorType>(); - auto denseTp = - RankedTensorType::get(rtp.getShape(), rtp.getElementType()); - auto convert = rewriter.create<ConvertOp>(loc, denseTp, op.getSrc()); - op->setOperand(0, convert); - return success(); - } else if (encDst) { - RankedTensorType rtp = op.getResult().getType().cast<RankedTensorType>(); - auto denseTp = - RankedTensorType::get(rtp.getShape(), rtp.getElementType()); - auto reshape = rewriter.create<tensor::ExpandShapeOp>( - loc, denseTp, op.getSrc(), op.getReassociation()); - Value convert = rewriter.create<ConvertOp>(loc, rtp, reshape); - rewriter.replaceOp(op, convert); - return success(); - } - return failure(); - } -}; - -/// Sparse rewriting rule for collapse shape operator. 
-struct CollapseShapeRewriter - : public OpRewritePattern<tensor::CollapseShapeOp> { -public: - using OpRewritePattern<tensor::CollapseShapeOp>::OpRewritePattern; - - LogicalResult matchAndRewrite(tensor::CollapseShapeOp op, - PatternRewriter &rewriter) const override { - Location loc = op->getLoc(); - auto encDst = getSparseTensorEncoding(op.getResult().getType()); - auto encSrc = getSparseTensorEncoding(op.getSrc().getType()); - // Since a pure dense collapse is very cheap (change of view), for - // sparse2dense or dense2sparse, we can simply unfuse a sparse - // conversion from the actual collapse operation itself. - if (encDst && encSrc) { - return failure(); // TODO: implement sparse2sparse + return failure(); } else if (encSrc) { - RankedTensorType rtp = op.getSrc().getType().cast<RankedTensorType>(); + RankedTensorType rtp = + op.getSrc().getType().template cast<RankedTensorType>(); auto denseTp = RankedTensorType::get(rtp.getShape(), rtp.getElementType()); auto convert = rewriter.create<ConvertOp>(loc, denseTp, op.getSrc()); op->setOperand(0, convert); return success(); } else if (encDst) { - RankedTensorType rtp = op.getResult().getType().cast<RankedTensorType>(); + RankedTensorType rtp = + op.getResult().getType().template cast<RankedTensorType>(); auto denseTp = RankedTensorType::get(rtp.getShape(), rtp.getElementType()); - auto reshape = rewriter.create<tensor::CollapseShapeOp>( - loc, denseTp, op.getSrc(), op.getReassociation()); + auto reshape = rewriter.create<ReshapeOp>(loc, denseTp, op.getSrc(), + op.getReassociation()); Value convert = rewriter.create<ConvertOp>(loc, rtp, reshape); rewriter.replaceOp(op, convert); return success(); @@ -1912,6 +1879,6 @@ public: void mlir::populateSparsificationPatterns( RewritePatternSet &patterns, const SparsificationOptions &options) { patterns.add<GenericOpSparsifier>(patterns.getContext(), options); - patterns.add<ExpandShapeRewriter, CollapseShapeRewriter>( - patterns.getContext()); + 
patterns.add<ReshapeRewriter<tensor::ExpandShapeOp>, + ReshapeRewriter<tensor::CollapseShapeOp>>(patterns.getContext()); } diff --git a/mlir/test/Dialect/SparseTensor/rewriting.mlir b/mlir/test/Dialect/SparseTensor/rewriting.mlir index 3955310fce9b..000c3560f1e0 100644..100755 --- a/mlir/test/Dialect/SparseTensor/rewriting.mlir +++ b/mlir/test/Dialect/SparseTensor/rewriting.mlir @@ -40,8 +40,14 @@ func.func @expand_to_sparse(%arg0: tensor<12xf64>) -> tensor<3x4xf64, #SparseMat return %0 : tensor<3x4xf64, #SparseMatrix> } -// TODO: make this work +// +// Not rewritten, needs conversion. +// // CHECK-LABEL: func.func @expand_sparse2sparse( +// CHECK-SAME: %[[A:.*]]: tensor<12xf64, #sparse_tensor.encoding<{{{.*}}}>>) -> tensor<3x4xf64, #sparse_tensor.encoding<{{{.*}}}>> { +// CHECK: %[[E:.*]] = tensor.expand_shape %[[A]] {{.*}} : tensor<12xf64, #sparse_tensor.encoding<{{{.*}}}>> into tensor<3x4xf64, #sparse_tensor.encoding<{{{.*}}}>> +// CHECK: return %[[E]] : tensor<3x4xf64, #sparse_tensor.encoding<{{{.*}}}>> +// CHECK: } func.func @expand_sparse2sparse(%arg0: tensor<12xf64, #SparseVector>) -> tensor<3x4xf64, #SparseMatrix> { %0 = tensor.expand_shape %arg0 [[0, 1]] : tensor<12xf64, #SparseVector> into tensor<3x4xf64, #SparseMatrix> return %0 : tensor<3x4xf64, #SparseMatrix> @@ -79,8 +85,14 @@ func.func @collapse_to_sparse(%arg0: tensor<3x4xf64>) -> tensor<12xf64, #SparseV return %0 : tensor<12xf64, #SparseVector> } -// TODO: make this work +// +// Not rewritten, needs conversion. 
+// // CHECK-LABEL: func.func @collapse_sparse2sparse( +// CHECK-SAME: %[[A:.*]]: tensor<3x4xf64, #sparse_tensor.encoding<{{{.*}}}>>) -> tensor<12xf64, #sparse_tensor.encoding<{{{.*}}}>> { +// CHECK: %[[C:.*]] = tensor.collapse_shape %[[A]] {{.*}} : tensor<3x4xf64, #sparse_tensor.encoding<{{{.*}}}>> into tensor<12xf64, #sparse_tensor.encoding<{{{.*}}}>> +// CHECK: return %[[C]] : tensor<12xf64, #sparse_tensor.encoding<{{{.*}}}>> +// CHECK: } func.func @collapse_sparse2sparse(%arg0: tensor<3x4xf64, #SparseMatrix>) -> tensor<12xf64, #SparseVector> { %0 = tensor.collapse_shape %arg0 [[0, 1]] : tensor<3x4xf64, #SparseMatrix> into tensor<12xf64, #SparseVector> return %0 : tensor<12xf64, #SparseVector> diff --git a/mlir/test/Dialect/SparseTensor/sparse_reshape.mlir b/mlir/test/Dialect/SparseTensor/sparse_reshape.mlir index c791536e1519..65eb56b9bac3 100644 --- a/mlir/test/Dialect/SparseTensor/sparse_reshape.mlir +++ b/mlir/test/Dialect/SparseTensor/sparse_reshape.mlir @@ -1,24 +1,81 @@ -// RUN: mlir-opt %s | mlir-opt | FileCheck %s - -// TODO: check lowering to an actual implementation +// RUN: mlir-opt %s | mlir-opt | FileCheck %s --check-prefix=CHECK-ROUND +// RUN: mlir-opt %s --sparse-tensor-conversion --cse | FileCheck %s --check-prefix=CHECK-CONV #SparseVector = #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }> #SparseMatrix = #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }> -// CHECK-LABEL: func.func @sparse_expand( -// CHECK-SAME: %[[A:.*]]: tensor<100xf64, #sparse_tensor.encoding<{{{.*}}}>>) -> tensor<10x10xf64, #sparse_tensor.encoding<{{{.*}}}>> -// CHECK: %[[E:.*]] = tensor.expand_shape %[[A]] {{\[\[}}0, 1]] : tensor<100xf64, #sparse_tensor.encoding<{{{.*}}}>> into tensor<10x10xf64, #sparse_tensor.encoding<{{{.*}}}>> -// CHECK: return %[[E]] : tensor<10x10xf64, #sparse_tensor.encoding<{{{.*}}}>> +// +// roundtrip: +// +// CHECK-ROUND-LABEL: func.func @sparse_expand( +// CHECK-ROUND-SAME: %[[A:.*]]: tensor<100xf64, 
#sparse_tensor.encoding<{{{.*}}}>>) -> tensor<10x10xf64, #sparse_tensor.encoding<{{{.*}}}>> +// CHECK-ROUND: %[[E:.*]] = tensor.expand_shape %[[A]] {{\[\[}}0, 1]] : tensor<100xf64, #sparse_tensor.encoding<{{{.*}}}>> into tensor<10x10xf64, #sparse_tensor.encoding<{{{.*}}}>> +// CHECK-ROUND: return %[[E]] : tensor<10x10xf64, #sparse_tensor.encoding<{{{.*}}}>> +// +// conversion: +// +// CHECK-CONV-LABEL: func.func @sparse_expand( +// CHECK-CONV-DAG: %[[C0:.*]] = arith.constant 0 : index +// CHECK-CONV-DAG: %[[C1:.*]] = arith.constant 1 : index +// CHECK-CONV-DAG: %[[C10:.*]] = arith.constant 10 : index +// CHECK-CONV-DAG: call @newSparseTensor +// CHECK-CONV-DAG: call @newSparseTensor +// CHECK-CONV: scf.while : () -> () { +// CHECK-CONV: call @getNextF64 +// CHECK-CONV: scf.condition(%13) +// CHECK-CONV: } do { +// CHECK-CONV: %[[X:.*]] = memref.load %{{.*}}[%[[C0]]] : memref<?xindex> +// CHECK-CONV: %[[D:.*]] = arith.divui %[[X]], %[[C10]] : index +// CHECK-CONV: memref.store %[[D]], %{{.*}}[%[[C0]]] : memref<?xindex> +// CHECK-CONV: %[[R:.*]] = arith.remui %[[X]], %[[C10]] : index +// CHECK-CONV: memref.store %[[R]], %{{.*}}[%[[C1]]] : memref<?xindex> +// CHECK-CONV: call @addEltF64 +// CHECK-CONV: scf.yield +// CHECK-CONV: } +// CHECK-CONV: %[[N:.*]] = call @newSparseTensor +// CHECK-CONV: call @delSparseTensorCOOF64 +// CHECK-CONV: call @delSparseTensorCOOF64 +// CHECK-CONV: return %[[N]] : !llvm.ptr<i8> +// func.func @sparse_expand(%arg0: tensor<100xf64, #SparseVector>) -> tensor<10x10xf64, #SparseMatrix> { %0 = tensor.expand_shape %arg0 [[0, 1]] : tensor<100xf64, #SparseVector> into tensor<10x10xf64, #SparseMatrix> return %0 : tensor<10x10xf64, #SparseMatrix> } -// CHECK-LABEL: func.func @sparse_collapse( -// CHECK-SAME: %[[A:.*]]: tensor<10x10xf64, #sparse_tensor.encoding<{{{.*}}}>>) -> tensor<100xf64, #sparse_tensor.encoding<{{{.*}}}>> -// CHECK: %[[C:.*]] = tensor.collapse_shape %[[A]] {{\[\[}}0, 1]] : tensor<10x10xf64, #sparse_tensor.encoding<{{{.*}}}>> 
into tensor<100xf64, #sparse_tensor.encoding<{{{.*}}}>> -// CHECK: return %[[C]] : tensor<100xf64, #sparse_tensor.encoding<{{{.*}}}>> +// +// roundtrip: +// +// CHECK-ROUND-LABEL: func.func @sparse_collapse( +// CHECK-ROUND-SAME: %[[A:.*]]: tensor<10x10xf64, #sparse_tensor.encoding<{{{.*}}}>>) -> tensor<100xf64, #sparse_tensor.encoding<{{{.*}}}>> +// CHECK-ROUND: %[[C:.*]] = tensor.collapse_shape %[[A]] {{\[\[}}0, 1]] : tensor<10x10xf64, #sparse_tensor.encoding<{{{.*}}}>> into tensor<100xf64, #sparse_tensor.encoding<{{{.*}}}>> +// CHECK-ROUND: return %[[C]] : tensor<100xf64, #sparse_tensor.encoding<{{{.*}}}>> +// +// conversion: +// +// CHECK-CONV-LABEL: func.func @sparse_collapse( +// CHECK-CONV-DAG: %[[C0:.*]] = arith.constant 0 : index +// CHECK-CONV-DAG: %[[C1:.*]] = arith.constant 1 : index +// CHECK-CONV-DAG: %[[C10:.*]] = arith.constant 10 : index +// CHECK-CONV-DAG: call @newSparseTensor +// CHECK-CONV-DAG: call @newSparseTensor +// CHECK-CONV: scf.while : () -> () { +// CHECK-CONV: call @getNextF64 +// CHECK-CONV: scf.condition(%13) +// CHECK-CONV: } do { +// CHECK-CONV: %[[X:.*]] = memref.load %{{.*}}[%[[C0]]] : memref<?xindex> +// CHECK-CONV: %[[M:.*]] = arith.muli %[[X]], %[[C10]] : index +// CHECK-CONV: %[[Y:.*]] = memref.load %{{.*}}[%[[C1]]] : memref<?xindex> +// CHECK-CONV: %[[A:.*]] = arith.addi %[[M]], %[[Y]] : index +// CHECK-CONV: memref.store %[[A]], %{{.*}}[%[[C0]]] : memref<?xindex> +// CHECK-CONV: call @addEltF64 +// CHECK-CONV: scf.yield +// CHECK-CONV: } +// CHECK-CONV: %[[N:.*]] = call @newSparseTensor +// CHECK-CONV: call @delSparseTensorCOOF64 +// CHECK-CONV: call @delSparseTensorCOOF64 +// CHECK-CONV: return %[[N]] : !llvm.ptr<i8> +// func.func @sparse_collapse(%arg0: tensor<10x10xf64, #SparseMatrix>) -> tensor<100xf64, #SparseVector> { %0 = tensor.collapse_shape %arg0 [[0, 1]] : tensor<10x10xf64, #SparseMatrix> into tensor<100xf64, #SparseVector> diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_reshape.mlir 
b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_reshape.mlir index eed880866884..57d2a931fad5 100644..100755 --- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_reshape.mlir +++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_reshape.mlir @@ -32,11 +32,10 @@ module { return %0 : tensor<3x4xf64, #SparseMatrix> } -// TODO: make this work -// func.func @expand_sparse2sparse(%arg0: tensor<12xf64, #SparseVector>) -> tensor<3x4xf64, #SparseMatrix> { -// %0 = tensor.expand_shape %arg0 [[0, 1]] : tensor<12xf64, #SparseVector> into tensor<3x4xf64, #SparseMatrix> -// return %0 : tensor<3x4xf64, #SparseMatrix> -// } + func.func @expand_sparse2sparse(%arg0: tensor<12xf64, #SparseVector>) -> tensor<3x4xf64, #SparseMatrix> { + %0 = tensor.expand_shape %arg0 [[0, 1]] : tensor<12xf64, #SparseVector> into tensor<3x4xf64, #SparseMatrix> + return %0 : tensor<3x4xf64, #SparseMatrix> + } func.func @collapse_dense(%arg0: tensor<3x4xf64>) -> tensor<12xf64> { %0 = tensor.collapse_shape %arg0 [[0, 1]] : tensor<3x4xf64> into tensor<12xf64> @@ -53,11 +52,10 @@ module { return %0 : tensor<12xf64, #SparseVector> } -// TODO: make this work -// func.func @collapse_sparse2sparse(%arg0: tensor<3x4xf64, #SparseMatrix>) -> tensor<12xf64, #SparseVector> { -// %0 = tensor.collapse_shape %arg0 [[0, 1]] : tensor<3x4xf64, #SparseMatrix> into tensor<12xf64, #SparseVector> -// return %0 : tensor<12xf64, #SparseVector> -// } + func.func @collapse_sparse2sparse(%arg0: tensor<3x4xf64, #SparseMatrix>) -> tensor<12xf64, #SparseVector> { + %0 = tensor.collapse_shape %arg0 [[0, 1]] : tensor<3x4xf64, #SparseMatrix> into tensor<12xf64, #SparseVector> + return %0 : tensor<12xf64, #SparseVector> + } // @@ -81,10 +79,12 @@ module { %expand0 = call @expand_dense(%v) : (tensor<12xf64>) -> tensor<3x4xf64> %expand1 = call @expand_from_sparse(%sv) : (tensor<12xf64, #SparseVector>) -> tensor<3x4xf64> %expand2 = call @expand_to_sparse(%v) : (tensor<12xf64>) -> tensor<3x4xf64, #SparseMatrix> + %expand3 
= call @expand_sparse2sparse(%sv) : (tensor<12xf64, #SparseVector>) -> tensor<3x4xf64, #SparseMatrix> %collapse0 = call @collapse_dense(%m) : (tensor<3x4xf64>) -> tensor<12xf64> %collapse1 = call @collapse_from_sparse(%sm) : (tensor<3x4xf64, #SparseMatrix>) -> tensor<12xf64> %collapse2 = call @collapse_to_sparse(%m) : (tensor<3x4xf64>) -> tensor<12xf64, #SparseVector> + %collapse3 = call @collapse_sparse2sparse(%sm) : (tensor<3x4xf64, #SparseMatrix>) -> tensor<12xf64, #SparseVector> // // Verify result. @@ -92,9 +92,11 @@ module { // CHECK: ( ( 1, 2, 3, 4 ), ( 5, 6, 7, 8 ), ( 9, 10, 11, 12 ) ) // CHECK-NEXT: ( ( 1, 2, 3, 4 ), ( 5, 6, 7, 8 ), ( 9, 10, 11, 12 ) ) // CHECK-NEXT: ( 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, -1, -1, -1, -1 ) + // CHECK-NEXT: ( 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, -1, -1, -1, -1 ) // CHECK-NEXT: ( 1.1, 1.2, 1.3, 1.4, 2.1, 2.2, 2.3, 2.4, 3.1, 3.2, 3.3, 3.4 ) // CHECK-NEXT: ( 1.1, 1.2, 1.3, 1.4, 2.1, 2.2, 2.3, 2.4, 3.1, 3.2, 3.3, 3.4 ) // CHECK-NEXT: ( 1.1, 1.2, 1.3, 1.4, 2.1, 2.2, 2.3, 2.4, 3.1, 3.2, 3.3, 3.4, -1, -1, -1, -1 ) + // CHECK-NEXT: ( 1.1, 1.2, 1.3, 1.4, 2.1, 2.2, 2.3, 2.4, 3.1, 3.2, 3.3, 3.4, -1, -1, -1, -1 ) // %m0 = vector.transfer_read %expand0[%c0, %c0], %df: tensor<3x4xf64>, vector<3x4xf64> vector.print %m0 : vector<3x4xf64> @@ -103,6 +105,9 @@ module { %a2 = sparse_tensor.values %expand2 : tensor<3x4xf64, #SparseMatrix> to memref<?xf64> %m2 = vector.transfer_read %a2[%c0], %df: memref<?xf64>, vector<16xf64> vector.print %m2 : vector<16xf64> + %a3 = sparse_tensor.values %expand3 : tensor<3x4xf64, #SparseMatrix> to memref<?xf64> + %m3 = vector.transfer_read %a3[%c0], %df: memref<?xf64>, vector<16xf64> + vector.print %m3 : vector<16xf64> %v0 = vector.transfer_read %collapse0[%c0], %df: tensor<12xf64>, vector<12xf64> vector.print %v0 : vector<12xf64> @@ -111,12 +116,17 @@ module { %b2 = sparse_tensor.values %collapse2 : tensor<12xf64, #SparseVector> to memref<?xf64> %v2 = vector.transfer_read %b2[%c0], %df: memref<?xf64>, 
vector<16xf64> vector.print %v2 : vector<16xf64> + %b3 = sparse_tensor.values %collapse3 : tensor<12xf64, #SparseVector> to memref<?xf64> + %v3 = vector.transfer_read %b3[%c0], %df: memref<?xf64>, vector<16xf64> + vector.print %v3 : vector<16xf64> // Release sparse resources. sparse_tensor.release %sv : tensor<12xf64, #SparseVector> sparse_tensor.release %sm : tensor<3x4xf64, #SparseMatrix> sparse_tensor.release %expand2 : tensor<3x4xf64, #SparseMatrix> + sparse_tensor.release %expand3 : tensor<3x4xf64, #SparseMatrix> sparse_tensor.release %collapse2 : tensor<12xf64, #SparseVector> + sparse_tensor.release %collapse3 : tensor<12xf64, #SparseVector> // Release dense resources. %meme1 = bufferization.to_memref %expand1 : memref<3x4xf64> |