diff options
author | Protonu Basu <protonu@fb.com> | 2019-09-04 06:42:16 +0300 |
---|---|---|
committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2019-09-04 06:43:36 +0300 |
commit | 21782ffd9ede194cdf2395854adc10ba11d0d896 (patch) | |
tree | ef74f0e35c6e9d44547e780ae17842c61b5d6cc8 | |
parent | 3ace43b21a95160b9cbcbf55573c3daa1e92ecb2 (diff) |
Modifying reference conv2d/3d, im2col2d.3d to support dilated convolutions
Summary: Modifying reference conv2d/3d, im2col2d.3d to support dilated convolutions
Reviewed By: dskhudia
Differential Revision: D17169707
fbshipit-source-id: f6862f79d9cf10f0b72df1b6feafc3d35ba7e5d5
-rw-r--r-- | src/RefImplementations.cc | 30 |
1 files changed, 20 insertions, 10 deletions
diff --git a/src/RefImplementations.cc b/src/RefImplementations.cc index e3c0eac..da58eba 100644 --- a/src/RefImplementations.cc +++ b/src/RefImplementations.cc @@ -300,9 +300,11 @@ void im2col_ref( for (int h = 0; h < OUT_DIM[0]; ++h) { for (int w = 0; w < OUT_DIM[1]; ++w) { for (int r = 0; r < K[0]; ++r) { - int h_in = -conv_p.pad[0] + h * conv_p.stride[0] + r; + int h_in = + -conv_p.pad[0] + h * conv_p.stride[0] + r * conv_p.dilation[0]; for (int s = 0; s < K[1]; ++s) { - int w_in = -conv_p.pad[1] + w * conv_p.stride[1] + s; + int w_in = + -conv_p.pad[1] + w * conv_p.stride[1] + s * conv_p.dilation[1]; if (h_in < 0 || h_in >= IN_DIM[0] || w_in < 0 || w_in >= IN_DIM[1]) { for (int g = 0; g < G; ++g) { @@ -363,11 +365,14 @@ void im2col_ref( for (int h = 0; h < OUT_DIM[1]; ++h) { for (int w = 0; w < OUT_DIM[2]; ++w) { for (int q = 0; q < K[0]; ++q) { - int t_in = -conv_p.pad[0] + t * conv_p.stride[0] + q; + int t_in = + -conv_p.pad[0] + t * conv_p.stride[0] + q * conv_p.dilation[0]; for (int r = 0; r < K[1]; ++r) { - int h_in = -conv_p.pad[1] + h * conv_p.stride[1] + r; + int h_in = -conv_p.pad[1] + h * conv_p.stride[1] + + r * conv_p.dilation[1]; for (int s = 0; s < K[2]; ++s) { - int w_in = -conv_p.pad[2] + w * conv_p.stride[2] + s; + int w_in = -conv_p.pad[2] + w * conv_p.stride[2] + + s * conv_p.dilation[2]; if (t_in < 0 || t_in >= IN_DIM[0] || h_in < 0 || h_in >= IN_DIM[1] || w_in < 0 || w_in >= IN_DIM[2]) { for (int g = 0; g < G; ++g) { @@ -447,9 +452,11 @@ void conv_ref( for (int m = 0; m < OC / G; ++m) { int sum = 0; for (int r = 0; r < K[0]; ++r) { - int h_in = -conv_p.pad[0] + h * conv_p.stride[0] + r; + int h_in = -conv_p.pad[0] + h * conv_p.stride[0] + + r * conv_p.dilation[0]; for (int s = 0; s < K[1]; ++s) { - int w_in = -conv_p.pad[1] + w * conv_p.stride[1] + s; + int w_in = -conv_p.pad[1] + w * conv_p.stride[1] + + s * conv_p.dilation[1]; for (int c = 0; c < IC / G; ++c) { int a = h_in < 0 || h_in >= IN_DIM[0] || w_in < 0 || w_in >= IN_DIM[1] @@ -499,11 +506,14 @@ void conv_ref( for (int m = 0; m < OC / G; ++m) { int sum = 0; for (int q = 0; q < K[0]; ++q) { - int t_in = -conv_p.pad[0] + t * conv_p.stride[0] + q; + int t_in = -conv_p.pad[0] + t * conv_p.stride[0] + + q * conv_p.dilation[0]; for (int r = 0; r < K[1]; ++r) { - int h_in = -conv_p.pad[1] + h * conv_p.stride[1] + r; + int h_in = -conv_p.pad[1] + h * conv_p.stride[1] + + r * conv_p.dilation[1]; for (int s = 0; s < K[2]; ++s) { - int w_in = -conv_p.pad[2] + w * conv_p.stride[2] + s; + int w_in = -conv_p.pad[2] + w * conv_p.stride[2] + + s * conv_p.dilation[2]; for (int c = 0; c < IC / G; ++c) { int a = t_in < 0 || t_in >= IN_DIM[0] || h_in < 0 || h_in >= IN_DIM[1] || w_in < 0 || w_in >= IN_DIM[2] |