Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/llvm/llvm-project.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/mlir
diff options
context:
space:
mode:
authorHanhan Wang <hanchung@google.com>2022-11-02 21:02:48 +0300
committerHanhan Wang <hanchung@google.com>2022-11-02 21:03:14 +0300
commitc050dd4717ec4317bd45adfca8243cb9ea7b6370 (patch)
tree0a66bcc6f9ff519d36934656539918a43bbccbcc /mlir
parentd1fbdf5bf79219549bc1fde255186d02f646a46f (diff)
[mlir][linalg] Add support for vectorizing convs that have different types.
Reviewed By: dcaballe Differential Revision: https://reviews.llvm.org/D137208
Diffstat (limited to 'mlir')
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp2
-rw-r--r--mlir/test/Dialect/Linalg/vectorize-convolution.mlir64
2 files changed, 65 insertions, 1 deletions
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index d565efb30241..cedec72b9cb3 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1465,7 +1465,7 @@ struct Conv1DGenerator : public StructuredGenerator<LinalgOp> {
return;
for (Value operand : mulOp->getOperands()) {
if (Operation *def = operand.getDefiningOp()) {
- if (!isa<arith::ExtFOp>(def))
+ if (!isa<CastOpInterface>(def))
return;
operand = def->getOperand(0);
}
diff --git a/mlir/test/Dialect/Linalg/vectorize-convolution.mlir b/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
index e7495765b3ec..1374c996128a 100644
--- a/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
+++ b/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
@@ -61,6 +61,70 @@ func.func @conv1d_nwc_4x2x8_memref(%input: memref<4x6x3xf32>, %filter: memref<1x
// -----
+// The i8i8i32 case is similar to f32 case, so checking one case is enough for
+// test coverage.
+func.func @conv1d_nwc_4x2x8_i8i8i32_memref(%input: memref<4x6x3xi8>, %filter: memref<1x3x8xi8>, %output: memref<4x2x8xi32>) {
+ linalg.conv_1d_nwc_wcf
+ {dilations = dense<1> : tensor<1xi64>, strides = dense<3> : tensor<1xi64>}
+ ins(%input, %filter : memref<4x6x3xi8>, memref<1x3x8xi8>)
+ outs(%output : memref<4x2x8xi32>)
+ return
+}
+
+// CHECK: #[[INPUT_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CHECK: #[[FILTER_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
+// CHECK: #[[OUTPUT_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+
+// CHECK: func @conv1d_nwc_4x2x8_i8i8i32_memref
+// CHECK-SAME: (%[[INPUT:.+]]: memref<4x6x3xi8>, %[[FILTER:.+]]: memref<1x3x8xi8>, %[[OUTPUT:.+]]: memref<4x2x8xi32>)
+
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C0_I8:.+]] = arith.constant 0 : i8
+// CHECK-DAG: %[[C0_I32:.+]] = arith.constant 0 : i32
+
+/// Read the whole data in one shot.
+// CHECK-DAG: %[[V_INPUT_R:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[C0_I8]]
+// CHECK-DAG: %[[V_FILTER_R:.+]] = vector.transfer_read %[[FILTER]][%[[C0]], %[[C0]], %[[C0]]], %[[C0_I8]]
+// CHECK-DAG: %[[V_OUTPUT_R:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[C0_I32]]
+
+// CHECK: %[[V_INPUT_0:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
+// CHECK-SAME: {offsets = [0, 0, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x4x3xi8> to vector<4x1x3xi8>
+// CHECK: %[[V_INPUT_1:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
+// CHECK-SAME: {offsets = [0, 3, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x4x3xi8> to vector<4x1x3xi8>
+
+// CHECK: %[[V_FILTER:.+]] = vector.extract %[[V_FILTER_R]][0] : vector<1x3x8xi8>
+
+// CHECK: %[[V_OUTPUT_0:.+]] = vector.extract_strided_slice %[[V_OUTPUT_R]]
+// CHECK-SAME: {offsets = [0, 0, 0], sizes = [4, 1, 8], strides = [1, 1, 1]} : vector<4x2x8xi32> to vector<4x1x8xi32>
+// CHECK: %[[V_OUTPUT_1:.+]] = vector.extract_strided_slice %[[V_OUTPUT_R]]
+// CHECK-SAME: {offsets = [0, 1, 0], sizes = [4, 1, 8], strides = [1, 1, 1]} : vector<4x2x8xi32> to vector<4x1x8xi32>
+
+/// w == 0, kw == 0
+// CHECK: %[[CONTRACT_0:.+]] = vector.contract {
+// CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"]
+// CHECK-SAME: %[[V_INPUT_0]], %[[V_FILTER]], %[[V_OUTPUT_0]]
+// CHECK-SAME: : vector<4x1x3xi8>, vector<3x8xi8> into vector<4x1x8xi32>
+
+/// w == 1, kw == 0
+// CHECK: %[[CONTRACT_1:.+]] = vector.contract {
+// CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"]
+// CHECK-SAME: %[[V_INPUT_1]], %[[V_FILTER]], %[[V_OUTPUT_1]]
+// CHECK-SAME: : vector<4x1x3xi8>, vector<3x8xi8> into vector<4x1x8xi32>
+
+/// w == 0, kw == 0
+// CHECK: %[[RES_0:.+]] = vector.insert_strided_slice %[[CONTRACT_0]], %[[V_OUTPUT_R]]
+// CHECK-SAME: {offsets = [0, 0, 0], strides = [1, 1, 1]} : vector<4x1x8xi32> into vector<4x2x8xi32>
+/// w == 1, kw == 0
+// CHECK: %[[RES_1:.+]] = vector.insert_strided_slice %[[CONTRACT_1]], %[[RES_0]]
+// CHECK-SAME: {offsets = [0, 1, 0], strides = [1, 1, 1]} : vector<4x1x8xi32> into vector<4x2x8xi32>
+
+// Write the result back in one shot.
+// CHECK: vector.transfer_write %[[RES_1]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
+
+// -----
+
func.func @conv1d_nwc_4x2x8_memref(%input: memref<4x6x3xf32>, %filter: memref<2x3x8xf32>, %output: memref<4x2x8xf32>) {
linalg.conv_1d_nwc_wcf
{dilations = dense<2> : tensor<1xi64>, strides = dense<3> : tensor<1xi64>}