Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/FBGEMM.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDaya Khudia <dskhudia@fb.com>2019-09-14 00:36:54 +0300
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>2019-09-14 00:38:49 +0300
commit2f1477dfee9465c1e2dbdf21722970b3fa1baf86 (patch)
tree3b71d218676f0b49da8bda85b64961fe1a7fb93d
parentc8cac64995d8d8af871e461affbf505ac7fce4d8 (diff)
Minor changes in initialization of dilation (#126)
Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/126 Default value for dilation is in function definition itself. Reviewed By: protonu Differential Revision: D17371791 fbshipit-source-id: c3430dfa3faccf549dc066aa8dcd422b910dbcaa
-rw-r--r--include/fbgemm/ConvUtils.h39
1 files changed, 21 insertions, 18 deletions
diff --git a/include/fbgemm/ConvUtils.h b/include/fbgemm/ConvUtils.h
index 4a354b7..5431958 100644
--- a/include/fbgemm/ConvUtils.h
+++ b/include/fbgemm/ConvUtils.h
@@ -8,9 +8,24 @@
#include <array>
#include <string>
+#include <type_traits>
namespace fbgemm {
+template <int N, int... Vals>
+constexpr
+ typename std::enable_if<N == sizeof...(Vals), std::array<int, N>>::type
+ array_of_ones() {
+ return std::array<int, N>{{Vals...}};
+}
+
+template <int N, int... Vals>
+constexpr
+ typename std::enable_if<N != sizeof...(Vals), std::array<int, N>>::type
+ array_of_ones() {
+ return array_of_ones<N, Vals..., 1>();
+}
+
/**
* @brief A struct to conveniently store all convolution parameters.
*/
@@ -34,7 +49,6 @@ struct conv_param_t {
/**
* @brief Constructor for initializing the convolution parameters.
- * TODO: Dilation is not handled correctly.
*/
conv_param_t(
int mb,
@@ -45,7 +59,7 @@ struct conv_param_t {
std::array<int, SPATIAL_DIM> k,
std::array<int, SPATIAL_DIM> strd,
std::array<int, SPATIAL_DIM * 2> pd,
- std::array<int, SPATIAL_DIM> dilations = {})
+ std::array<int, SPATIAL_DIM> dilations = array_of_ones<SPATIAL_DIM>())
: MB(mb),
IC(ic),
OC(oc),
@@ -66,17 +80,6 @@ struct conv_param_t {
" does not divide number of output channels = " + std::to_string(oc));
}
- bool dilation_unset = true;
- for (int d = 0; d < SPATIAL_DIM; ++d) {
- if (dilation[d] != 0) {
- dilation_unset = false;
- break;
- }
- }
- if (dilation_unset) {
- dilation.fill(1);
- }
-
for (int d = 0; d < SPATIAL_DIM; ++d) {
IN_DIMP[d] = IN_DIM[d] + pad[d] + pad[SPATIAL_DIM + d];
OUT_DIM[d] = (IN_DIMP[d] - dilation[d] * (K[d] - 1) - 1) / stride[d] + 1;
@@ -115,14 +118,14 @@ struct conv_param_t {
}
for (int d = 0; d < SPATIAL_DIM * 2; ++d) {
out += "pad_" + dim_string[3 - SPATIAL_DIM + (d % SPATIAL_DIM)] + ":" +
- std::to_string(pad[d]);
- if (d < SPATIAL_DIM * 2 - 1) {
- out += ", ";
- }
+ std::to_string(pad[d]) + ", ";
}
for (int d = 0; d < SPATIAL_DIM; ++d) {
out += "dilation_" + dim_string[3 - SPATIAL_DIM + d] + ":" +
- std::to_string(dilation[d]) + ", ";
+ std::to_string(dilation[d]);
+ if (d < SPATIAL_DIM - 1) {
+ out += ", ";
+ }
}
} else {
for (int d = 0; d < SPATIAL_DIM; ++d) {