diff options
author | wren romano <2998727+wrengr@users.noreply.github.com> | 2022-10-08 00:21:02 +0300 |
---|---|---|
committer | wren romano <2998727+wrengr@users.noreply.github.com> | 2022-10-12 00:03:37 +0300 |
commit | 90fd13b0a1120a7fa50a21d4d46af61285b7a964 (patch) | |
tree | d5fb328d1b37f08b4c1ce04620983b3ee0050624 /mlir/test | |
parent | 1079662d2fff7ae799503d910b299c5108d105fd (diff) |
[mlir][sparse] Converting SparseTensorCOO to use standard C++-style iterators.
This differential comprises three related changes: (1) it gives SparseTensorCOO standard C++-style iterators; (2) it removes the old iterator stuff from SparseTensorCOO; and (3) it introduces SparseTensorIterator which behaves like the old SparseTensorCOO iterator stuff used to.
The SparseTensorIterator class is needed because the MLIR codegen cannot easily use the C++-style iterators (which is why SparseTensorCOO had the old iterator stuff). Distinguishing SparseTensorIterator from SparseTensorCOO also helps improve API hygiene, since these two classes are used for distinct purposes. And having SparseTensorIterator as its own class enables changing the underlying implementation in the future, without needing to worry about updating all the codegen tests etc.
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D135485
Diffstat (limited to 'mlir/test')
-rw-r--r-- | mlir/test/Dialect/SparseTensor/sparse_concat.mlir | 10 | ||||
-rw-r--r-- | mlir/test/Dialect/SparseTensor/sparse_reshape.mlir | 8 |
2 files changed, 9 insertions, 9 deletions
diff --git a/mlir/test/Dialect/SparseTensor/sparse_concat.mlir b/mlir/test/Dialect/SparseTensor/sparse_concat.mlir index b51c72e2d6e6..f2fbe292b6dc 100644 --- a/mlir/test/Dialect/SparseTensor/sparse_concat.mlir +++ b/mlir/test/Dialect/SparseTensor/sparse_concat.mlir @@ -58,7 +58,7 @@ // CHECK: memref.store %[[TMP_15]], %[[TMP_0]][%[[TMP_13]], %[[TMP_14]]] : memref<5x4xf64> // CHECK: scf.yield // CHECK: } -// CHECK: call @delSparseTensorCOOF64(%[[TMP_7]]) : (!llvm.ptr<i8>) -> () +// CHECK: call @delSparseTensorIteratorF64(%[[TMP_7]]) : (!llvm.ptr<i8>) -> () // CHECK: %[[TMP_11:.*]] = bufferization.to_tensor %[[TMP_0]] : memref<5x4xf64> // CHECK: return %[[TMP_11]] : tensor<5x4xf64> // CHECK: } @@ -141,7 +141,7 @@ func.func @concat_mix_dense(%arg0: tensor<2x4xf64>, %arg1: tensor<3x4xf64, #Spar // CHECK: %[[TMP_25:.*]] = func.call @addEltF64(%[[TMP_7]], %[[TMP_20]], %[[TMP_10]], %[[TMP_5]]) : (!llvm.ptr<i8>, memref<f64>, memref<?xindex>, memref<?xindex>) -> !llvm.ptr<i8> // CHECK: scf.yield // CHECK: } -// CHECK: call @delSparseTensorCOOF64(%[[TMP_17]]) : (!llvm.ptr<i8>) -> () +// CHECK: call @delSparseTensorIteratorF64(%[[TMP_17]]) : (!llvm.ptr<i8>) -> () // CHECK: %[[TMP_21:.*]] = call @newSparseTensor(%[[TMP_1]], %[[TMP_3]], %[[TMP_5]], %[[TMP_c0_i32]], %[[TMP_c0_i32]], %[[TMP_c1_i32]], %[[TMP_c2_i32]], %[[TMP_7]]) : (memref<?xi8>, memref<?xindex>, memref<?xindex>, i32, i32, i32, i32, !llvm.ptr<i8>) -> !llvm.ptr<i8> // CHECK: call @delSparseTensorCOOF64(%[[TMP_7]]) : (!llvm.ptr<i8>) -> () // CHECK: return %[[TMP_21]] : !llvm.ptr<i8> @@ -225,7 +225,7 @@ func.func @concat_mix_sparse(%arg0: tensor<2x4xf64>, %arg1: tensor<3x4xf64, #Spa // CHECK: %[[TMP_25:.*]] = func.call @addEltF64(%[[TMP_7]], %[[TMP_20]], %[[TMP_10]], %[[TMP_5]]) : (!llvm.ptr<i8>, memref<f64>, memref<?xindex>, memref<?xindex>) -> !llvm.ptr<i8> // CHECK: scf.yield // CHECK: } -// CHECK: call @delSparseTensorCOOF64(%[[TMP_17]]) : (!llvm.ptr<i8>) -> () +// CHECK: call 
@delSparseTensorIteratorF64(%[[TMP_17]]) : (!llvm.ptr<i8>) -> () // CHECK: %[[TMP_21:.*]] = call @newSparseTensor(%[[TMP_1]], %[[TMP_3]], %[[TMP_5]], %[[TMP_c0_i32]], %[[TMP_c0_i32]], %[[TMP_c1_i32]], %[[TMP_c2_i32]], %[[TMP_7]]) : (memref<?xi8>, memref<?xindex>, memref<?xindex>, i32, i32, i32, i32, !llvm.ptr<i8>) -> !llvm.ptr<i8> // CHECK: call @delSparseTensorCOOF64(%[[TMP_7]]) : (!llvm.ptr<i8>) -> () // CHECK: return %[[TMP_21]] : !llvm.ptr<i8> @@ -287,7 +287,7 @@ func.func @concat_mix_sparse_perm_dim1(%arg0: tensor<4x2xf64>, %arg1: tensor<4x3 // CHECK: memref.store %[[TMP_15]], %[[TMP_0]][%[[TMP_12]], %[[TMP_14]]] : memref<4x5xf64> // CHECK: scf.yield // CHECK: } -// CHECK: call @delSparseTensorCOOF64(%[[TMP_7]]) : (!llvm.ptr<i8>) -> () +// CHECK: call @delSparseTensorIteratorF64(%[[TMP_7]]) : (!llvm.ptr<i8>) -> () // CHECK: %[[TMP_11:.*]] = bufferization.to_tensor %[[TMP_0]] : memref<4x5xf64> // CHECK: return %[[TMP_11]] : tensor<4x5xf64> // CHECK: } @@ -348,7 +348,7 @@ func.func @concat_mix_dense_perm_dim1(%arg0: tensor<4x2xf64>, %arg1: tensor<4x3x // CHECK: memref.store %[[TMP_16]], %[[TMP_0]][%[[TMP_13]], %[[TMP_15]]] : memref<3x5xf64> // CHECK: scf.yield // CHECK: } -// CHECK: call @delSparseTensorCOOF64(%[[TMP_8]]) : (!llvm.ptr<i8>) -> () +// CHECK: call @delSparseTensorIteratorF64(%[[TMP_8]]) : (!llvm.ptr<i8>) -> () // CHECK: %[[TMP_12:.*]] = bufferization.to_tensor %[[TMP_1]] : memref<?x?xf64> // CHECK: return %[[TMP_12]] : tensor<?x?xf64> // CHECK: } diff --git a/mlir/test/Dialect/SparseTensor/sparse_reshape.mlir b/mlir/test/Dialect/SparseTensor/sparse_reshape.mlir index 420d732ce62a..83e022146f48 100644 --- a/mlir/test/Dialect/SparseTensor/sparse_reshape.mlir +++ b/mlir/test/Dialect/SparseTensor/sparse_reshape.mlir @@ -35,7 +35,7 @@ // CHECK-CONV: } // CHECK-CONV: %[[N:.*]] = call @newSparseTensor // CHECK-CONV: call @delSparseTensorCOOF64 -// CHECK-CONV: call @delSparseTensorCOOF64 +// CHECK-CONV: call @delSparseTensorIteratorF64 // CHECK-CONV: 
return %[[N]] : !llvm.ptr<i8> // // rewrite for codegen: @@ -97,7 +97,7 @@ func.func @sparse_expand(%arg0: tensor<100xf64, #SparseVector>) -> tensor<10x10x // CHECK-CONV: } // CHECK-CONV: %[[N:.*]] = call @newSparseTensor // CHECK-CONV: call @delSparseTensorCOOF64 -// CHECK-CONV: call @delSparseTensorCOOF64 +// CHECK-CONV: call @delSparseTensorIteratorF64 // CHECK-CONV: return %[[N]] : !llvm.ptr<i8> // // rewrite for codegen: @@ -172,7 +172,7 @@ func.func @sparse_collapse(%arg0: tensor<10x10xf64, #SparseMatrix>) -> tensor<10 // CHECK-CONV: } // CHECK-CONV: %[[N:.*]] = call @newSparseTensor // CHECK-CONV: call @delSparseTensorCOOF64 -// CHECK-CONV: call @delSparseTensorCOOF64 +// CHECK-CONV: call @delSparseTensorIteratorF64 // CHECK-CONV: return %[[N]] : !llvm.ptr<i8> // // rewrite for codegen: @@ -244,7 +244,7 @@ func.func @dynamic_sparse_expand(%arg0: tensor<?xf64, #SparseVector>) -> tensor< // CHECK-CONV: } // CHECK-CONV: %[[N:.*]] = call @newSparseTensor // CHECK-CONV: call @delSparseTensorCOOF64 -// CHECK-CONV: call @delSparseTensorCOOF64 +// CHECK-CONV: call @delSparseTensorIteratorF64 // CHECK-CONV: return %[[N]] : !llvm.ptr<i8> // // rewrite for codegen: |