diff options
author | bixia1 <bixia@google.com> | 2022-09-27 21:45:02 +0300 |
---|---|---|
committer | bixia1 <bixia@google.com> | 2022-09-27 22:00:08 +0300 |
commit | 4329ca61e821bb594ffaa280481b7f92d8ad31b1 (patch) | |
tree | 629115866cb1cc7850edeefd30cc5f55f3e1d82a | |
parent | 30cc712eb6f23a5c7beaae669bf2ab6beede7f20 (diff) |
[mlir][sparse] Add sparse_tensor.sort operator.
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D134482
-rw-r--r-- | mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td | 45 | ||||
-rw-r--r-- | mlir/lib/Dialect/SparseTensor/IR/CMakeLists.txt | 1 | ||||
-rw-r--r-- | mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp | 36 | ||||
-rw-r--r-- | mlir/test/Dialect/SparseTensor/invalid.mlir | 26 | ||||
-rw-r--r-- | mlir/test/Dialect/SparseTensor/roundtrip.mlir | 40 | ||||
-rw-r--r-- | utils/bazel/llvm-project-overlay/mlir/BUILD.bazel | 1 |
6 files changed, 149 insertions, 0 deletions
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td index ef1edf1fdf4f..2d2a0499a7b9 100644 --- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td +++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td @@ -392,6 +392,51 @@ def SparseTensor_OutOp : SparseTensor_Op<"out", []>, let assemblyFormat = "$tensor `,` $dest attr-dict `:` type($tensor) `,` type($dest)"; } +def SparseTensor_SortOp : SparseTensor_Op<"sort", [AttrSizedOperandSegments]>, + // TODO: May want to extend tablegen with + // class NonemptyVariadic<Type type> : Variadic<type> { let minSize = 1; } + // and then use NonemptyVariadic<...>:$xs here. + Arguments<(ins Index:$n, + Variadic<StridedMemRefRankOf<[AnyInteger, Index], [1]>>:$xs, + Variadic<StridedMemRefRankOf<[AnyType], [1]>>:$ys)> { + string summary = "Sorts the arrays in xs and ys lexicographically on the " + "integral values found in the xs list"; + string description = [{ + Lexicographically sort the first `n` values in `xs` along with the values in + `ys`. Conceptually, the values being sorted are tuples produced by + `zip(zip(xs), zip(ys))`. In particular, values in `ys` need to be sorted + along with values in `xs`, but values in `ys` don't affect the + lexicographical order. The order in which arrays appear in `xs` affects the + sorting result. The operator updates `xs` and `ys` in place with the result + of the sorting. + + For example, assume x1=[4, 3], x2=[1, 2], y1=[10, 5], then the output of + "sort 2, x1, x2 jointly y1" are x1=[3, 4], x2=[2, 1], y1=[5, 10] while the + output of "sort 2, x2, x1 jointly y1" are x2=[1, 2], x1=[4, 3], y1=[10, 5]. + + Buffers in `xs` need to have the same integral element type while buffers + in `ys` can have different numeric element types. All buffers in `xs` and + `ys` should have a dimension not less than `n`. The behavior of the operator + is undefined if this condition is not met. 
The operator requires at least + one buffer in `xs` while `ys` can be empty. + + Note that this operation is "impure" in the sense that its behavior is + solely defined by side-effects and not SSA values. The semantics may be + refined over time as our sparse abstractions evolve. + + Example: + + ```mlir + sparse_tensor.sort %n, %x1, %x2 jointly %y1, %y2 + : memref<?xindex>, memref<?xindex> jointly memref<?xindex>, memref<?xf32> + ``` + }]; + let assemblyFormat = "$n `,` $xs (`jointly` $ys^)? attr-dict" + "`:` type($xs) (`jointly` type($ys)^)?"; + + let hasVerifier = 1; +} + //===----------------------------------------------------------------------===// // Sparse Tensor Syntax Operations. //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/SparseTensor/IR/CMakeLists.txt b/mlir/lib/Dialect/SparseTensor/IR/CMakeLists.txt index 5172031909b4..9dba53018015 100644 --- a/mlir/lib/Dialect/SparseTensor/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/SparseTensor/IR/CMakeLists.txt @@ -9,6 +9,7 @@ add_mlir_dialect_library(MLIRSparseTensorDialect MLIRSparseTensorOpsIncGen LINK_LIBS PUBLIC + MLIRArithmeticDialect MLIRDialect MLIRIR MLIRInferTypeOpInterface diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp index 2e98eaa7561c..ba344e372519 100644 --- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp +++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp @@ -8,6 +8,7 @@ #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/IR/Builders.h" #include "mlir/IR/DialectImplementation.h" #include "mlir/IR/Matchers.h" @@ -505,6 +506,41 @@ LogicalResult SelectOp::verify() { return success(); } +LogicalResult SortOp::verify() { + if (getXs().empty()) + return emitError("need at least one xs buffer."); + + auto n = getN().getDefiningOp<arith::ConstantIndexOp>(); + + Type xtp = 
getXs().front().getType().cast<MemRefType>().getElementType(); + auto checkTypes = [&](ValueRange operands, + bool checkEleType = true) -> LogicalResult { + for (Value opnd : operands) { + MemRefType mtp = opnd.getType().cast<MemRefType>(); + uint64_t dim = mtp.getShape()[0]; + // We can't check the size of dynamic dimension at compile-time, but all + // xs and ys should have a dimension not less than n at runtime. + if (n && dim != ShapedType::kDynamicSize && dim < n.value()) + return emitError(llvm::formatv("xs and ys need to have a dimension >= n" + ": {0} < {1}", + dim, n.value())); + + if (checkEleType && xtp != mtp.getElementType()) + return emitError("mismatch xs element types"); + } + return success(); + }; + + LogicalResult result = checkTypes(getXs()); + if (failed(result)) + return result; + + if (n) + return checkTypes(getYs(), false); + + return success(); +} + LogicalResult YieldOp::verify() { // Check for compatible parent. auto *parentOp = (*this)->getParentOp(); diff --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir index af913204fabb..7425d19efb3b 100644 --- a/mlir/test/Dialect/SparseTensor/invalid.mlir +++ b/mlir/test/Dialect/SparseTensor/invalid.mlir @@ -501,3 +501,29 @@ func.func @sparse_tensor_foreach(%arg0: tensor<2x4xf64, #DCSR>) -> () { } return } + +// ----- + +// TODO: a test case with empty xs doesn't work due to some parser issues. 
+ +func.func @sparse_sort_x_type( %arg0: index, %arg1: memref<?xf32>) { + // expected-error@+1 {{operand #1 must be 1D memref of integer or index values}} + sparse_tensor.sort %arg0, %arg1: memref<?xf32> +} + +// ----- + +func.func @sparse_sort_dim_too_small(%arg0: memref<10xindex>) { + %i20 = arith.constant 20 : index + // expected-error@+1 {{xs and ys need to have a dimension >= n: 10 < 20}} + sparse_tensor.sort %i20, %arg0 : memref<10xindex> + return +} + +// ----- + +func.func @sparse_sort_mismatch_x_type(%arg0: index, %arg1: memref<10xindex>, %arg2: memref<10xi8>) { + // expected-error@+1 {{mismatch xs element types}} + sparse_tensor.sort %arg0, %arg1, %arg2 : memref<10xindex>, memref<10xi8> + return +} diff --git a/mlir/test/Dialect/SparseTensor/roundtrip.mlir b/mlir/test/Dialect/SparseTensor/roundtrip.mlir index fd4b508ad485..7059b1590543 100644 --- a/mlir/test/Dialect/SparseTensor/roundtrip.mlir +++ b/mlir/test/Dialect/SparseTensor/roundtrip.mlir @@ -362,3 +362,43 @@ func.func @sparse_tensor_foreach(%arg0: tensor<2x4xf64, #DCSR>) -> () { } return } + +// ----- + +// CHECK-LABEL: func @sparse_sort_1d0v( +// CHECK-SAME: %[[A:.*]]: index, +// CHECK-SAME: %[[B:.*]]: memref<?xindex>) +// CHECK: sparse_tensor.sort %[[A]], %[[B]] : memref<?xindex> +// CHECK: return %[[B]] +func.func @sparse_sort_1d0v(%arg0: index, %arg1: memref<?xindex>) -> (memref<?xindex>) { + sparse_tensor.sort %arg0, %arg1 : memref<?xindex> + return %arg1 : memref<?xindex> +} + +// ----- + +// CHECK-LABEL: func @sparse_sort_1d2v( +// CHECK-SAME: %[[A:.*]]: index, +// CHECK-SAME: %[[B:.*]]: memref<20xindex>, +// CHECK-SAME: %[[C:.*]]: memref<10xindex>, +// CHECK-SAME: %[[D:.*]]: memref<?xf32>) +// CHECK: sparse_tensor.sort %[[A]], %[[B]] jointly %[[C]], %[[D]] : memref<20xindex> jointly memref<10xindex>, memref<?xf32> +// CHECK: return %[[B]], %[[C]], %[[D]] +func.func @sparse_sort_1d2v(%arg0: index, %arg1: memref<20xindex>, %arg2: memref<10xindex>, %arg3: memref<?xf32>) -> (memref<20xindex>, 
memref<10xindex>, memref<?xf32>) { + sparse_tensor.sort %arg0, %arg1 jointly %arg2, %arg3 : memref<20xindex> jointly memref<10xindex>, memref<?xf32> + return %arg1, %arg2, %arg3 : memref<20xindex>, memref<10xindex>, memref<?xf32> +} + +// ----- + +// CHECK-LABEL: func @sparse_sort_2d1v( +// CHECK-SAME: %[[A:.*]]: index, +// CHECK-SAME: %[[B:.*]]: memref<10xi8>, +// CHECK-SAME: %[[C:.*]]: memref<20xi8>, +// CHECK-SAME: %[[D:.*]]: memref<10xf64>) +// CHECK: sparse_tensor.sort %[[A]], %[[B]], %[[C]] jointly %[[D]] : memref<10xi8>, memref<20xi8> jointly memref<10xf64> +// CHECK: return %[[B]], %[[C]], %[[D]] +func.func @sparse_sort_2d1v(%arg0: index, %arg1: memref<10xi8>, %arg2: memref<20xi8>, %arg3: memref<10xf64>) -> (memref<10xi8>, memref<20xi8>, memref<10xf64>) { + sparse_tensor.sort %arg0, %arg1, %arg2 jointly %arg3 : memref<10xi8>, memref<20xi8> jointly memref<10xf64> + return %arg1, %arg2, %arg3 : memref<10xi8>, memref<20xi8>, memref<10xf64> +} diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel index 135e54ce2973..ccc21652b383 100644 --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -2030,6 +2030,7 @@ cc_library( hdrs = ["include/mlir/Dialect/SparseTensor/IR/SparseTensor.h"], includes = ["include"], deps = [ + ":ArithmeticDialect", ":IR", ":InferTypeOpInterface", ":SparseTensorAttrDefsIncGen", |