Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/FBGEMM.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorProtonu Basu <protonu@fb.com>2019-09-04 00:30:56 +0300
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>2019-09-04 00:33:31 +0300
commit3ace43b21a95160b9cbcbf55573c3daa1e92ecb2 (patch)
tree91dad43496fe9d60ad0e7ec5f22a2ee34ce3104f
parente55a59653b70ef67e8f48373566756c477cda7b5 (diff)
Adding Support for dilations in the conv_param_t constructor
Summary: (PART 1) Adding support for convolutions with dilation -- Modifications to the constructor Reviewed By: jianyuh Differential Revision: D17165387 fbshipit-source-id: e005c416683d9d40a4413f8aba1b5f21a7afc156
-rw-r--r--include/fbgemm/ConvUtils.h29
1 files changed, 25 insertions, 4 deletions
diff --git a/include/fbgemm/ConvUtils.h b/include/fbgemm/ConvUtils.h
index 11f3dcc..4a354b7 100644
--- a/include/fbgemm/ConvUtils.h
+++ b/include/fbgemm/ConvUtils.h
@@ -44,7 +44,8 @@ struct conv_param_t {
int g,
std::array<int, SPATIAL_DIM> k,
std::array<int, SPATIAL_DIM> strd,
- std::array<int, SPATIAL_DIM * 2> pd)
+ std::array<int, SPATIAL_DIM * 2> pd,
+ std::array<int, SPATIAL_DIM> dilations = {})
: MB(mb),
IC(ic),
OC(oc),
@@ -52,7 +53,8 @@ struct conv_param_t {
G(g),
K(k),
stride(strd),
- pad(pd) {
+ pad(pd),
+ dilation(dilations) {
if (ic % g != 0) {
throw std::runtime_error(
"groups = " + std::to_string(g) +
@@ -63,10 +65,21 @@ struct conv_param_t {
"groups = " + std::to_string(g) +
" does not divide number of output channels = " + std::to_string(oc));
}
+
+ bool dilation_unset = true;
+ for (int d = 0; d < SPATIAL_DIM; ++d) {
+ if (dilation[d] != 0) {
+ dilation_unset = false;
+ break;
+ }
+ }
+ if (dilation_unset) {
+ dilation.fill(1);
+ }
+
for (int d = 0; d < SPATIAL_DIM; ++d) {
- dilation[d] = 1;
IN_DIMP[d] = IN_DIM[d] + pad[d] + pad[SPATIAL_DIM + d];
- OUT_DIM[d] = (IN_DIMP[d] - K[d]) / stride[d] + 1;
+ OUT_DIM[d] = (IN_DIMP[d] - dilation[d] * (K[d] - 1) - 1) / stride[d] + 1;
}
}
@@ -107,6 +120,10 @@ struct conv_param_t {
out += ", ";
}
}
+ for (int d = 0; d < SPATIAL_DIM; ++d) {
+ out += "dilation_" + dim_string[3 - SPATIAL_DIM + d] + ":" +
+ std::to_string(dilation[d]) + ", ";
+ }
} else {
for (int d = 0; d < SPATIAL_DIM; ++d) {
out += "K" + std::to_string(d) + ":" + std::to_string(K[d]) + ", ";
@@ -121,6 +138,10 @@ struct conv_param_t {
out += ", ";
}
}
+ for (int d = 0; d < SPATIAL_DIM; ++d) {
+ out += "dilation_" + std::to_string(d) + ":" +
+ std::to_string(dilation[d]) + ", ";
+ }
}
return out;
}