diff options
author | Protonu Basu <protonu@fb.com> | 2019-09-04 00:30:56 +0300 |
---|---|---|
committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2019-09-04 00:33:31 +0300 |
commit | 3ace43b21a95160b9cbcbf55573c3daa1e92ecb2 (patch) | |
tree | 91dad43496fe9d60ad0e7ec5f22a2ee34ce3104f | |
parent | e55a59653b70ef67e8f48373566756c477cda7b5 (diff) |
Adding Support for dilations in the conv_param_t constructor
Summary: (PART 1) Adding support for convolutions with dilation -- Modifications to the constructor
Reviewed By: jianyuh
Differential Revision: D17165387
fbshipit-source-id: e005c416683d9d40a4413f8aba1b5f21a7afc156
-rw-r--r-- | include/fbgemm/ConvUtils.h | 29 |
1 files changed, 25 insertions, 4 deletions
diff --git a/include/fbgemm/ConvUtils.h b/include/fbgemm/ConvUtils.h index 11f3dcc..4a354b7 100644 --- a/include/fbgemm/ConvUtils.h +++ b/include/fbgemm/ConvUtils.h @@ -44,7 +44,8 @@ struct conv_param_t { int g, std::array<int, SPATIAL_DIM> k, std::array<int, SPATIAL_DIM> strd, - std::array<int, SPATIAL_DIM * 2> pd) + std::array<int, SPATIAL_DIM * 2> pd, + std::array<int, SPATIAL_DIM> dilations = {}) : MB(mb), IC(ic), OC(oc), @@ -52,7 +53,8 @@ struct conv_param_t { G(g), K(k), stride(strd), - pad(pd) { + pad(pd), + dilation(dilations) { if (ic % g != 0) { throw std::runtime_error( "groups = " + std::to_string(g) + @@ -63,10 +65,21 @@ struct conv_param_t { "groups = " + std::to_string(g) + " does not divide number of output channels = " + std::to_string(oc)); } + + bool dilation_unset = true; + for (int d = 0; d < SPATIAL_DIM; ++d) { + if (dilation[d] != 0) { + dilation_unset = false; + break; + } + } + if (dilation_unset) { + dilation.fill(1); + } + for (int d = 0; d < SPATIAL_DIM; ++d) { - dilation[d] = 1; IN_DIMP[d] = IN_DIM[d] + pad[d] + pad[SPATIAL_DIM + d]; - OUT_DIM[d] = (IN_DIMP[d] - K[d]) / stride[d] + 1; + OUT_DIM[d] = (IN_DIMP[d] - dilation[d] * (K[d] - 1) - 1) / stride[d] + 1; } } @@ -107,6 +120,10 @@ struct conv_param_t { out += ", "; } } + for (int d = 0; d < SPATIAL_DIM; ++d) { + out += "dilation_" + dim_string[3 - SPATIAL_DIM + d] + ":" + + std::to_string(dilation[d]) + ", "; + } } else { for (int d = 0; d < SPATIAL_DIM; ++d) { out += "K" + std::to_string(d) + ":" + std::to_string(K[d]) + ", "; @@ -121,6 +138,10 @@ struct conv_param_t { out += ", "; } } + for (int d = 0; d < SPATIAL_DIM; ++d) { + out += "dilation_" + std::to_string(d) + ":" + + std::to_string(dilation[d]) + ", "; + } } return out; } |