diff options
author | Daya Khudia <dskhudia@fb.com> | 2019-09-14 00:36:54 +0300 |
---|---|---|
committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2019-09-14 00:38:49 +0300 |
commit | 2f1477dfee9465c1e2dbdf21722970b3fa1baf86 (patch) | |
tree | 3b71d218676f0b49da8bda85b64961fe1a7fb93d | |
parent | c8cac64995d8d8af871e461affbf505ac7fce4d8 (diff) |
Minor changes in initialization of dilation (#126)
Summary:
Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/126
Default value for dilation is in function definition itself.
Reviewed By: protonu
Differential Revision: D17371791
fbshipit-source-id: c3430dfa3faccf549dc066aa8dcd422b910dbcaa
-rw-r--r-- | include/fbgemm/ConvUtils.h | 39 |
1 files changed, 21 insertions, 18 deletions
diff --git a/include/fbgemm/ConvUtils.h b/include/fbgemm/ConvUtils.h index 4a354b7..5431958 100644 --- a/include/fbgemm/ConvUtils.h +++ b/include/fbgemm/ConvUtils.h @@ -8,9 +8,24 @@ #include <array> #include <string> +#include <type_traits> namespace fbgemm { +template <int N, int... Vals> +constexpr + typename std::enable_if<N == sizeof...(Vals), std::array<int, N>>::type + array_of_ones() { + return std::array<int, N>{{Vals...}}; +} + +template <int N, int... Vals> +constexpr + typename std::enable_if<N != sizeof...(Vals), std::array<int, N>>::type + array_of_ones() { + return array_of_ones<N, Vals..., 1>(); +} + /** * @brief A struct to conveniently store all convolution parameters. */ @@ -34,7 +49,6 @@ struct conv_param_t { /** * @brief Constructor for initializing the convolution parameters. - * TODO: Dilation is not handled correctly. */ conv_param_t( int mb, @@ -45,7 +59,7 @@ struct conv_param_t { std::array<int, SPATIAL_DIM> k, std::array<int, SPATIAL_DIM> strd, std::array<int, SPATIAL_DIM * 2> pd, - std::array<int, SPATIAL_DIM> dilations = {}) + std::array<int, SPATIAL_DIM> dilations = array_of_ones<SPATIAL_DIM>()) : MB(mb), IC(ic), OC(oc), @@ -66,17 +80,6 @@ struct conv_param_t { " does not divide number of output channels = " + std::to_string(oc)); } - bool dilation_unset = true; - for (int d = 0; d < SPATIAL_DIM; ++d) { - if (dilation[d] != 0) { - dilation_unset = false; - break; - } - } - if (dilation_unset) { - dilation.fill(1); - } - for (int d = 0; d < SPATIAL_DIM; ++d) { IN_DIMP[d] = IN_DIM[d] + pad[d] + pad[SPATIAL_DIM + d]; OUT_DIM[d] = (IN_DIMP[d] - dilation[d] * (K[d] - 1) - 1) / stride[d] + 1; @@ -115,14 +118,14 @@ struct conv_param_t { } for (int d = 0; d < SPATIAL_DIM * 2; ++d) { out += "pad_" + dim_string[3 - SPATIAL_DIM + (d % SPATIAL_DIM)] + ":" + - std::to_string(pad[d]); - if (d < SPATIAL_DIM * 2 - 1) { - out += ", "; - } + std::to_string(pad[d]) + ", "; } for (int d = 0; d < SPATIAL_DIM; ++d) { out += "dilation_" + dim_string[3 - SPATIAL_DIM + d] + ":" + - std::to_string(dilation[d]) + ", "; + std::to_string(dilation[d]); + if (d < SPATIAL_DIM - 1) { + out += ", "; + } } } else { for (int d = 0; d < SPATIAL_DIM; ++d) { |