Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/llvm/llvm-project.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorbixia1 <bixia@google.com>2022-09-27 21:45:02 +0300
committerbixia1 <bixia@google.com>2022-09-27 22:00:08 +0300
commit4329ca61e821bb594ffaa280481b7f92d8ad31b1 (patch)
tree629115866cb1cc7850edeefd30cc5f55f3e1d82a
parent30cc712eb6f23a5c7beaae669bf2ab6beede7f20 (diff)
[mlir][sparse] Add sparse_tensor.sort operator.
Reviewed By: aartbik Differential Revision: https://reviews.llvm.org/D134482
-rw-r--r--mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td45
-rw-r--r--mlir/lib/Dialect/SparseTensor/IR/CMakeLists.txt1
-rw-r--r--mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp36
-rw-r--r--mlir/test/Dialect/SparseTensor/invalid.mlir26
-rw-r--r--mlir/test/Dialect/SparseTensor/roundtrip.mlir40
-rw-r--r--utils/bazel/llvm-project-overlay/mlir/BUILD.bazel1
6 files changed, 149 insertions, 0 deletions
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index ef1edf1fdf4f..2d2a0499a7b9 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -392,6 +392,51 @@ def SparseTensor_OutOp : SparseTensor_Op<"out", []>,
let assemblyFormat = "$tensor `,` $dest attr-dict `:` type($tensor) `,` type($dest)";
}
+def SparseTensor_SortOp : SparseTensor_Op<"sort", [AttrSizedOperandSegments]>,
+ // TODO: May want to extend tablegen with
+ // class NonemptyVariadic<Type type> : Variadic<type> { let minSize = 1; }
+ // and then use NonemptyVariadic<...>:$xs here.
+ Arguments<(ins Index:$n,
+ Variadic<StridedMemRefRankOf<[AnyInteger, Index], [1]>>:$xs,
+ Variadic<StridedMemRefRankOf<[AnyType], [1]>>:$ys)> {
+ string summary = "Sorts the arrays in xs and ys lexicographically on the "
+ "integral values found in the xs list";
+ string description = [{
+ Lexicographically sort the first `n` values in `xs` along with the values in
+ `ys`. Conceptually, the values being sorted are tuples produced by
+ `zip(zip(xs), zip(ys))`. In particular, values in `ys` needed to be sorted
+ along with values in `xs`, but values in `ys` don't affect the
+ lexicographical order. The order in which arrays appear in `xs` affects the
+ sorting result. The operator updates `xs` and `ys` in place with the result
+ of the sorting.
+
+ For example, assume x1=[4, 3], x2=[1, 2], y1=[10, 5], then the output of
+ "sort 2, x1, x2 jointly y1" are x1=[3, 4], x2=[2, 1], y1=[5, 10] while the
+ output of "sort 2, x2, x1, jointly y1" are x2=[1, 2], x1=[4, 3], y1=[10, 5].
+
+ Buffers in `xs` needs to have the same integral element type while buffers
+ in `ys` can have different numeric element types. All buffers in `xs` and
+ `ys` should have a dimension not less than `n`. The behavior of the operator
+ is undefined if this condition is not met. The operator requires at least
+ one buffer in `xs` while `ys` can be empty.
+
+ Note that this operation is "impure" in the sense that its behavior is
+ solely defined by side-effects and not SSA values. The semantics may be
+ refined over time as our sparse abstractions evolve.
+
+ Example:
+
+ ```mlir
+ sparse_tensor.sort %n, %x1, %x2 jointly y1, %y2
+ : memref<?xindex>, memref<?xindex> jointly memref<?xindex>, memref<?xf32>
+ ```
+ }];
+ let assemblyFormat = "$n `,` $xs (`jointly` $ys^)? attr-dict"
+ "`:` type($xs) (`jointly` type($ys)^)?";
+
+ let hasVerifier = 1;
+}
+
//===----------------------------------------------------------------------===//
// Sparse Tensor Syntax Operations.
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SparseTensor/IR/CMakeLists.txt b/mlir/lib/Dialect/SparseTensor/IR/CMakeLists.txt
index 5172031909b4..9dba53018015 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/SparseTensor/IR/CMakeLists.txt
@@ -9,6 +9,7 @@ add_mlir_dialect_library(MLIRSparseTensorDialect
MLIRSparseTensorOpsIncGen
LINK_LIBS PUBLIC
+ MLIRArithmeticDialect
MLIRDialect
MLIRIR
MLIRInferTypeOpInterface
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 2e98eaa7561c..ba344e372519 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -8,6 +8,7 @@
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/Matchers.h"
@@ -505,6 +506,41 @@ LogicalResult SelectOp::verify() {
return success();
}
+LogicalResult SortOp::verify() {
+ if (getXs().empty())
+ return emitError("need at least one xs buffer.");
+
+ auto n = getN().getDefiningOp<arith::ConstantIndexOp>();
+
+ Type xtp = getXs().front().getType().cast<MemRefType>().getElementType();
+ auto checkTypes = [&](ValueRange operands,
+ bool checkEleType = true) -> LogicalResult {
+ for (Value opnd : operands) {
+ MemRefType mtp = opnd.getType().cast<MemRefType>();
+ uint64_t dim = mtp.getShape()[0];
+ // We can't check the size of dynamic dimension at compile-time, but all
+ // xs and ys should have a dimension not less than n at runtime.
+ if (n && dim != ShapedType::kDynamicSize && dim < n.value())
+ return emitError(llvm::formatv("xs and ys need to have a dimension >= n"
+ ": {0} < {1}",
+ dim, n.value()));
+
+ if (checkEleType && xtp != mtp.getElementType())
+ return emitError("mismatch xs element types");
+ }
+ return success();
+ };
+
+ LogicalResult result = checkTypes(getXs());
+ if (failed(result))
+ return result;
+
+ if (n)
+ return checkTypes(getYs(), false);
+
+ return success();
+}
+
LogicalResult YieldOp::verify() {
// Check for compatible parent.
auto *parentOp = (*this)->getParentOp();
diff --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir
index af913204fabb..7425d19efb3b 100644
--- a/mlir/test/Dialect/SparseTensor/invalid.mlir
+++ b/mlir/test/Dialect/SparseTensor/invalid.mlir
@@ -501,3 +501,29 @@ func.func @sparse_tensor_foreach(%arg0: tensor<2x4xf64, #DCSR>) -> () {
}
return
}
+
+// -----
+
+// TODO: a test case with empty xs doesn't work due to some parser issues.
+
+func.func @sparse_sort_x_type( %arg0: index, %arg1: memref<?xf32>) {
+ // expected-error@+1 {{operand #1 must be 1D memref of integer or index values}}
+ sparse_tensor.sort %arg0, %arg1: memref<?xf32>
+}
+
+// -----
+
+func.func @sparse_sort_dim_too_small(%arg0: memref<10xindex>) {
+ %i20 = arith.constant 20 : index
+ // expected-error@+1 {{xs and ys need to have a dimension >= n: 10 < 20}}
+ sparse_tensor.sort %i20, %arg0 : memref<10xindex>
+ return
+}
+
+// -----
+
+func.func @sparse_sort_mismatch_x_type(%arg0: index, %arg1: memref<10xindex>, %arg2: memref<10xi8>) {
+ // expected-error@+1 {{mismatch xs element types}}
+ sparse_tensor.sort %arg0, %arg1, %arg2 : memref<10xindex>, memref<10xi8>
+ return
+}
diff --git a/mlir/test/Dialect/SparseTensor/roundtrip.mlir b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
index fd4b508ad485..7059b1590543 100644
--- a/mlir/test/Dialect/SparseTensor/roundtrip.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
@@ -362,3 +362,43 @@ func.func @sparse_tensor_foreach(%arg0: tensor<2x4xf64, #DCSR>) -> () {
}
return
}
+
+// ----
+
+// CHECK-LABEL: func @sparse_sort_1d0v(
+// CHECK-SAME: %[[A:.*]]: index,
+// CHECK-SAME: %[[B:.*]]: memref<?xindex>)
+// CHECK: sparse_tensor.sort %[[A]], %[[B]] : memref<?xindex>
+// CHECK: return %[[B]]
+func.func @sparse_sort_1d0v(%arg0: index, %arg1: memref<?xindex>) -> (memref<?xindex>) {
+ sparse_tensor.sort %arg0, %arg1 : memref<?xindex>
+ return %arg1 : memref<?xindex>
+}
+
+// -----
+
+// CHECK-LABEL: func @sparse_sort_1d2v(
+// CHECK-SAME: %[[A:.*]]: index,
+// CHECK-SAME: %[[B:.*]]: memref<20xindex>,
+// CHECK-SAME: %[[C:.*]]: memref<10xindex>,
+// CHECK-SAME: %[[D:.*]]: memref<?xf32>)
+// CHECK: sparse_tensor.sort %[[A]], %[[B]] jointly %[[C]], %[[D]] : memref<20xindex> jointly memref<10xindex>, memref<?xf32>
+// CHECK: return %[[B]], %[[C]], %[[D]]
+func.func @sparse_sort_1d2v(%arg0: index, %arg1: memref<20xindex>, %arg2: memref<10xindex>, %arg3: memref<?xf32>) -> (memref<20xindex>, memref<10xindex>, memref<?xf32>) {
+ sparse_tensor.sort %arg0, %arg1 jointly %arg2, %arg3 : memref<20xindex> jointly memref<10xindex>, memref<?xf32>
+ return %arg1, %arg2, %arg3 : memref<20xindex>, memref<10xindex>, memref<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @sparse_sort_2d1v(
+// CHECK-SAME: %[[A:.*]]: index,
+// CHECK-SAME: %[[B:.*]]: memref<10xi8>,
+// CHECK-SAME: %[[C:.*]]: memref<20xi8>,
+// CHECK-SAME: %[[D:.*]]: memref<10xf64>)
+// CHECK: sparse_tensor.sort %[[A]], %[[B]], %[[C]] jointly %[[D]] : memref<10xi8>, memref<20xi8> jointly memref<10xf64>
+// CHECK: return %[[B]], %[[C]], %[[D]]
+func.func @sparse_sort_2d1v(%arg0: index, %arg1: memref<10xi8>, %arg2: memref<20xi8>, %arg3: memref<10xf64>) -> (memref<10xi8>, memref<20xi8>, memref<10xf64>) {
+ sparse_tensor.sort %arg0, %arg1, %arg2 jointly %arg3 : memref<10xi8>, memref<20xi8> jointly memref<10xf64>
+ return %arg1, %arg2, %arg3 : memref<10xi8>, memref<20xi8>, memref<10xf64>
+}
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 135e54ce2973..ccc21652b383 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -2030,6 +2030,7 @@ cc_library(
hdrs = ["include/mlir/Dialect/SparseTensor/IR/SparseTensor.h"],
includes = ["include"],
deps = [
+ ":ArithmeticDialect",
":IR",
":InferTypeOpInterface",
":SparseTensorAttrDefsIncGen",