diff options
Diffstat (limited to 'test/UniConvPackingTest.cc')
-rw-r--r-- | test/UniConvPackingTest.cc | 148 |
1 files changed, 148 insertions, 0 deletions
diff --git a/test/UniConvPackingTest.cc b/test/UniConvPackingTest.cc new file mode 100644 index 0000000..77552af --- /dev/null +++ b/test/UniConvPackingTest.cc @@ -0,0 +1,148 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include <algorithm> +#include <random> +#include <iostream> + + +#include <gtest/gtest.h> + +#include "QuantizationHelpers.h" +#include "TestUtils.h" +#include "bench/BenchUtils.h" +#include "fbgemm/Fbgemm.h" +#include "src/RefImplementations.h" + +using namespace std; +using namespace fbgemm; + +namespace { + +// tuple represents MB, IC, OC, IT, IH, IW, KH/KW, stride, pad +class convPackingTest + : public testing::TestWithParam< + tuple<int, int, int, int, int, int, int, int, int, int>> {}; + +}; // namespace + +INSTANTIATE_TEST_CASE_P( + InstantiationName, + convPackingTest, + ::testing::Combine( + ::testing::ValuesIn({1, 2}), // MB + ::testing::ValuesIn({16, 32}), // IC + ::testing::ValuesIn({16, 32}), // OC + ::testing::ValuesIn({17}), // IT + ::testing::ValuesIn({10, 30, 55}), // IH + ::testing::ValuesIn({10, 30, 55}), // IW + ::testing::ValuesIn({1, 4, 16}), // G + ::testing::ValuesIn({3, 7}), // kernel + ::testing::ValuesIn({1, 2}), // stride + ::testing::ValuesIn({1, 2}))); // pad + +/** + * Test for conv packing + */ +TEST_P(convPackingTest, packingTest) { + int MB, IC, OC, IT, IH, IW, G, kernel, stride, pad; + tie(MB, IC, OC, IT, IH, IW, G, kernel, stride, pad) = GetParam(); + + conv_param_t<2> conv_p_2d( + MB, + IC, + OC, + {IH, IW}, + G, + {kernel, kernel}, + {stride, stride}, + {pad, pad, pad, pad}); + + int kernel_dim_2d = kernel * kernel; + aligned_vector<int8_t> Bint8_2d( + kernel_dim_2d * conv_p_2d.IC * (conv_p_2d.OC / conv_p_2d.G)); + PackWeightsForConv<2> packedB_2D(conv_p_2d, Bint8_2d.data()); + + switch (ConvFastPath<2, int32_t>(conv_p_2d)) { + case optimized_conv_t::depthwise: { + ASSERT_NE(packedB_2D.getPackedWFor2DDW(), nullptr) + << "2D depthwise packed matrix is null"; + ASSERT_EQ(packedB_2D.getPackedWForIm2col(), nullptr) + << "im2col packed matrix should be null"; + ASSERT_EQ(packedB_2D.getPackedWFor3DDW(), nullptr) + << "3D depthwise packed matrix should be null"; + ASSERT_EQ(packedB_2D.getPackedWForGroupwise(), nullptr) + << "groupwise packed matrix should be null"; + break; + } + case optimized_conv_t::groupwise: { + ASSERT_EQ(packedB_2D.getPackedWForIm2col(), nullptr) + << "im2col packed matrix should be null"; + ASSERT_EQ(packedB_2D.getPackedWFor2DDW(), nullptr) + << "2D depthwise packed matrix is null"; + ASSERT_EQ(packedB_2D.getPackedWFor3DDW(), nullptr) + << "3D depthwise packed matrix should be null"; + ASSERT_NE(packedB_2D.getPackedWForGroupwise(), nullptr) + << "Groupwise packed matrix is null"; + break; + } + case optimized_conv_t::im2col: { + ASSERT_EQ(packedB_2D.getPackedWFor2DDW(), nullptr) + << "2D depthwise packed matrix is null"; + ASSERT_EQ(packedB_2D.getPackedWFor3DDW(), nullptr) + << "3D depthwise packed matrix should be null"; + ASSERT_EQ(packedB_2D.getPackedWForGroupwise(), nullptr) + << "groupwise packed matrix should be null"; + ASSERT_NE(packedB_2D.getPackedWForIm2col(), nullptr) + << "im2col packed matrix is null"; + break; + } + } + + conv_param_t<3> conv_p_3d( + MB, + IC, + OC, + {IT, IH, IW}, + G, + {kernel, kernel, kernel}, + {stride, stride, stride}, + {pad, pad, pad, pad, pad, pad}); + + int kernel_dim_3d = kernel * kernel * kernel; + aligned_vector<int8_t> Bint8_3d( + kernel_dim_3d * conv_p_3d.IC * (conv_p_3d.OC / conv_p_3d.G)); + PackWeightsForConv<3> packedB_3D(conv_p_3d, Bint8_3d.data()); + + switch (ConvFastPath<3, int32_t>(conv_p_3d)) { + case optimized_conv_t::depthwise: { + ASSERT_EQ(packedB_3D.getPackedWFor2DDW(), nullptr) + << "2D depthwise packed matrix is null"; + ASSERT_EQ(packedB_3D.getPackedWForIm2col(), nullptr) + << "im2col packed matrix should be null"; + ASSERT_NE(packedB_3D.getPackedWFor3DDW(), nullptr) + << "3D depthwise packed matrix should be null"; + ASSERT_EQ(packedB_3D.getPackedWForGroupwise(), nullptr) + << "groupwise packed matrix should be null"; + break; + } + case optimized_conv_t::groupwise: { + ASSERT_TRUE(false) << "groupwise are not supported for 3D"; + break; + } + case optimized_conv_t::im2col: { + ASSERT_EQ(packedB_3D.getPackedWFor2DDW(), nullptr) + << "2D depthwise packed matrix is null"; + ASSERT_EQ(packedB_3D.getPackedWFor3DDW(), nullptr) + << "3D depthwise packed matrix should be null"; + ASSERT_EQ(packedB_3D.getPackedWForGroupwise(), nullptr) + << "groupwise packed matrix should be null"; + ASSERT_NE(packedB_3D.getPackedWForIm2col(), nullptr) + << "im2col packed matrix is null"; + break; + } + } +} |