From 90fd13b0a1120a7fa50a21d4d46af61285b7a964 Mon Sep 17 00:00:00 2001 From: wren romano <2998727+wrengr@users.noreply.github.com> Date: Fri, 7 Oct 2022 14:21:02 -0700 Subject: [mlir][sparse] Converting SparseTensorCOO to use standard C++-style iterators. This differential comprises three related changes: (1) it gives SparseTensorCOO standard C++-style iterators; (2) it removes the old iterator stuff from SparseTensorCOO; and (3) it introduces SparseTensorIterator which behaves like the old SparseTensorCOO iterator stuff used to. The SparseTensorIterator class is needed because the MLIR codegen cannot easily use the C++-style iterators (hence why SparseTensorCOO had the old iterator stuff). Distinguishing SparseTensorIterator from SparseTensorCOO also helps improve API hygiene since these two classes are used for distinct purposes. And having SparseTensorIterator as its own class enables changing the underlying implementation in the future, without needing to worry about updating all the codegen tests etc. Reviewed By: aartbik Differential Revision: https://reviews.llvm.org/D135485 --- mlir/test/Dialect/SparseTensor/sparse_concat.mlir | 10 +++++----- mlir/test/Dialect/SparseTensor/sparse_reshape.mlir | 8 ++++---- 2 files changed, 9 insertions(+), 9 deletions(-) (limited to 'mlir/test') diff --git a/mlir/test/Dialect/SparseTensor/sparse_concat.mlir b/mlir/test/Dialect/SparseTensor/sparse_concat.mlir index b51c72e2d6e6..f2fbe292b6dc 100644 --- a/mlir/test/Dialect/SparseTensor/sparse_concat.mlir +++ b/mlir/test/Dialect/SparseTensor/sparse_concat.mlir @@ -58,7 +58,7 @@ // CHECK: memref.store %[[TMP_15]], %[[TMP_0]][%[[TMP_13]], %[[TMP_14]]] : memref<5x4xf64> // CHECK: scf.yield // CHECK: } -// CHECK: call @delSparseTensorCOOF64(%[[TMP_7]]) : (!llvm.ptr) -> () +// CHECK: call @delSparseTensorIteratorF64(%[[TMP_7]]) : (!llvm.ptr) -> () // CHECK: %[[TMP_11:.*]] = bufferization.to_tensor %[[TMP_0]] : memref<5x4xf64> // CHECK: return %[[TMP_11]] : tensor<5x4xf64> // CHECK: } @@ -141,7 +141,7 @@ func.func @concat_mix_dense(%arg0: tensor<2x4xf64>, %arg1: tensor<3x4xf64, #Spar // CHECK: %[[TMP_25:.*]] = func.call @addEltF64(%[[TMP_7]], %[[TMP_20]], %[[TMP_10]], %[[TMP_5]]) : (!llvm.ptr, memref, memref, memref) -> !llvm.ptr // CHECK: scf.yield // CHECK: } -// CHECK: call @delSparseTensorCOOF64(%[[TMP_17]]) : (!llvm.ptr) -> () +// CHECK: call @delSparseTensorIteratorF64(%[[TMP_17]]) : (!llvm.ptr) -> () // CHECK: %[[TMP_21:.*]] = call @newSparseTensor(%[[TMP_1]], %[[TMP_3]], %[[TMP_5]], %[[TMP_c0_i32]], %[[TMP_c0_i32]], %[[TMP_c1_i32]], %[[TMP_c2_i32]], %[[TMP_7]]) : (memref, memref, memref, i32, i32, i32, i32, !llvm.ptr) -> !llvm.ptr // CHECK: call @delSparseTensorCOOF64(%[[TMP_7]]) : (!llvm.ptr) -> () // CHECK: return %[[TMP_21]] : !llvm.ptr @@ -225,7 +225,7 @@ func.func @concat_mix_sparse(%arg0: tensor<2x4xf64>, %arg1: tensor<3x4xf64, #Spa // CHECK: %[[TMP_25:.*]] = func.call @addEltF64(%[[TMP_7]], %[[TMP_20]], %[[TMP_10]], %[[TMP_5]]) : (!llvm.ptr, memref, memref, memref) -> !llvm.ptr // CHECK: scf.yield // CHECK: } -// CHECK: call @delSparseTensorCOOF64(%[[TMP_17]]) : (!llvm.ptr) -> () +// CHECK: call @delSparseTensorIteratorF64(%[[TMP_17]]) : (!llvm.ptr) -> () // CHECK: %[[TMP_21:.*]] = call @newSparseTensor(%[[TMP_1]], %[[TMP_3]], %[[TMP_5]], %[[TMP_c0_i32]], %[[TMP_c0_i32]], %[[TMP_c1_i32]], %[[TMP_c2_i32]], %[[TMP_7]]) : (memref, memref, memref, i32, i32, i32, i32, !llvm.ptr) -> !llvm.ptr // CHECK: call @delSparseTensorCOOF64(%[[TMP_7]]) : (!llvm.ptr) -> () // CHECK: return %[[TMP_21]] : !llvm.ptr @@ -287,7 +287,7 @@ func.func @concat_mix_sparse_perm_dim1(%arg0: tensor<4x2xf64>, %arg1: tensor<4x3 // CHECK: memref.store %[[TMP_15]], %[[TMP_0]][%[[TMP_12]], %[[TMP_14]]] : memref<4x5xf64> // CHECK: scf.yield // CHECK: } -// CHECK: call @delSparseTensorCOOF64(%[[TMP_7]]) : (!llvm.ptr) -> () +// CHECK: call @delSparseTensorIteratorF64(%[[TMP_7]]) : (!llvm.ptr) -> () // CHECK: %[[TMP_11:.*]] = bufferization.to_tensor %[[TMP_0]] : memref<4x5xf64> // CHECK: return %[[TMP_11]] : tensor<4x5xf64> // CHECK: } @@ -348,7 +348,7 @@ func.func @concat_mix_dense_perm_dim1(%arg0: tensor<4x2xf64>, %arg1: tensor<4x3x // CHECK: memref.store %[[TMP_16]], %[[TMP_0]][%[[TMP_13]], %[[TMP_15]]] : memref<3x5xf64> // CHECK: scf.yield // CHECK: } -// CHECK: call @delSparseTensorCOOF64(%[[TMP_8]]) : (!llvm.ptr) -> () +// CHECK: call @delSparseTensorIteratorF64(%[[TMP_8]]) : (!llvm.ptr) -> () // CHECK: %[[TMP_12:.*]] = bufferization.to_tensor %[[TMP_1]] : memref // CHECK: return %[[TMP_12]] : tensor // CHECK: } diff --git a/mlir/test/Dialect/SparseTensor/sparse_reshape.mlir b/mlir/test/Dialect/SparseTensor/sparse_reshape.mlir index 420d732ce62a..83e022146f48 100644 --- a/mlir/test/Dialect/SparseTensor/sparse_reshape.mlir +++ b/mlir/test/Dialect/SparseTensor/sparse_reshape.mlir @@ -35,7 +35,7 @@ // CHECK-CONV: } // CHECK-CONV: %[[N:.*]] = call @newSparseTensor // CHECK-CONV: call @delSparseTensorCOOF64 -// CHECK-CONV: call @delSparseTensorCOOF64 +// CHECK-CONV: call @delSparseTensorIteratorF64 // CHECK-CONV: return %[[N]] : !llvm.ptr // // rewrite for codegen: @@ -97,7 +97,7 @@ func.func @sparse_expand(%arg0: tensor<100xf64, #SparseVector>) -> tensor<10x10x // CHECK-CONV: } // CHECK-CONV: %[[N:.*]] = call @newSparseTensor // CHECK-CONV: call @delSparseTensorCOOF64 -// CHECK-CONV: call @delSparseTensorCOOF64 +// CHECK-CONV: call @delSparseTensorIteratorF64 // CHECK-CONV: return %[[N]] : !llvm.ptr // // rewrite for codegen: @@ -172,7 +172,7 @@ func.func @sparse_collapse(%arg0: tensor<10x10xf64, #SparseMatrix>) -> tensor<10 // CHECK-CONV: } // CHECK-CONV: %[[N:.*]] = call @newSparseTensor // CHECK-CONV: call @delSparseTensorCOOF64 -// CHECK-CONV: call @delSparseTensorCOOF64 +// CHECK-CONV: call @delSparseTensorIteratorF64 // CHECK-CONV: return %[[N]] : !llvm.ptr // // rewrite for codegen: @@ -244,7 +244,7 @@ func.func @dynamic_sparse_expand(%arg0: tensor) -> tensor< // CHECK-CONV: } // CHECK-CONV: %[[N:.*]] = call @newSparseTensor // CHECK-CONV: call @delSparseTensorCOOF64 -// CHECK-CONV: call @delSparseTensorCOOF64 +// CHECK-CONV: call @delSparseTensorIteratorF64 // CHECK-CONV: return %[[N]] : !llvm.ptr // // rewrite for codegen: -- cgit v1.2.3