Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/FBGEMM.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorProtonu Basu <protonu@fb.com>2019-09-04 06:42:16 +0300
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>2019-09-04 06:43:36 +0300
commit21782ffd9ede194cdf2395854adc10ba11d0d896 (patch)
treeef74f0e35c6e9d44547e780ae17842c61b5d6cc8
parent3ace43b21a95160b9cbcbf55573c3daa1e92ecb2 (diff)
Modifying reference conv2d/3d, im2col2d.3d to support dilated convolutions
Summary: Modifying reference conv2d/3d, im2col2d.3d to support dilated convolutions Reviewed By: dskhudia Differential Revision: D17169707 fbshipit-source-id: f6862f79d9cf10f0b72df1b6feafc3d35ba7e5d5
-rw-r--r--src/RefImplementations.cc30
1 files changed, 20 insertions, 10 deletions
diff --git a/src/RefImplementations.cc b/src/RefImplementations.cc
index e3c0eac..da58eba 100644
--- a/src/RefImplementations.cc
+++ b/src/RefImplementations.cc
@@ -300,9 +300,11 @@ void im2col_ref(
for (int h = 0; h < OUT_DIM[0]; ++h) {
for (int w = 0; w < OUT_DIM[1]; ++w) {
for (int r = 0; r < K[0]; ++r) {
- int h_in = -conv_p.pad[0] + h * conv_p.stride[0] + r;
+ int h_in =
+ -conv_p.pad[0] + h * conv_p.stride[0] + r * conv_p.dilation[0];
for (int s = 0; s < K[1]; ++s) {
- int w_in = -conv_p.pad[1] + w * conv_p.stride[1] + s;
+ int w_in =
+ -conv_p.pad[1] + w * conv_p.stride[1] + s * conv_p.dilation[1];
if (h_in < 0 || h_in >= IN_DIM[0] || w_in < 0 ||
w_in >= IN_DIM[1]) {
for (int g = 0; g < G; ++g) {
@@ -363,11 +365,14 @@ void im2col_ref(
for (int h = 0; h < OUT_DIM[1]; ++h) {
for (int w = 0; w < OUT_DIM[2]; ++w) {
for (int q = 0; q < K[0]; ++q) {
- int t_in = -conv_p.pad[0] + t * conv_p.stride[0] + q;
+ int t_in =
+ -conv_p.pad[0] + t * conv_p.stride[0] + q * conv_p.dilation[0];
for (int r = 0; r < K[1]; ++r) {
- int h_in = -conv_p.pad[1] + h * conv_p.stride[1] + r;
+ int h_in = -conv_p.pad[1] + h * conv_p.stride[1] +
+ r * conv_p.dilation[1];
for (int s = 0; s < K[2]; ++s) {
- int w_in = -conv_p.pad[2] + w * conv_p.stride[2] + s;
+ int w_in = -conv_p.pad[2] + w * conv_p.stride[2] +
+ s * conv_p.dilation[2];
if (t_in < 0 || t_in >= IN_DIM[0] || h_in < 0 ||
h_in >= IN_DIM[1] || w_in < 0 || w_in >= IN_DIM[2]) {
for (int g = 0; g < G; ++g) {
@@ -447,9 +452,11 @@ void conv_ref(
for (int m = 0; m < OC / G; ++m) {
int sum = 0;
for (int r = 0; r < K[0]; ++r) {
- int h_in = -conv_p.pad[0] + h * conv_p.stride[0] + r;
+ int h_in = -conv_p.pad[0] + h * conv_p.stride[0] +
+ r * conv_p.dilation[0];
for (int s = 0; s < K[1]; ++s) {
- int w_in = -conv_p.pad[1] + w * conv_p.stride[1] + s;
+ int w_in = -conv_p.pad[1] + w * conv_p.stride[1] +
+ s * conv_p.dilation[1];
for (int c = 0; c < IC / G; ++c) {
int a = h_in < 0 || h_in >= IN_DIM[0] || w_in < 0 ||
w_in >= IN_DIM[1]
@@ -499,11 +506,14 @@ void conv_ref(
for (int m = 0; m < OC / G; ++m) {
int sum = 0;
for (int q = 0; q < K[0]; ++q) {
- int t_in = -conv_p.pad[0] + t * conv_p.stride[0] + q;
+ int t_in = -conv_p.pad[0] + t * conv_p.stride[0] +
+ q * conv_p.dilation[0];
for (int r = 0; r < K[1]; ++r) {
- int h_in = -conv_p.pad[1] + h * conv_p.stride[1] + r;
+ int h_in = -conv_p.pad[1] + h * conv_p.stride[1] +
+ r * conv_p.dilation[1];
for (int s = 0; s < K[2]; ++s) {
- int w_in = -conv_p.pad[2] + w * conv_p.stride[2] + s;
+ int w_in = -conv_p.pad[2] + w * conv_p.stride[2] +
+ s * conv_p.dilation[2];
for (int c = 0; c < IC / G; ++c) {
int a = t_in < 0 || t_in >= IN_DIM[0] || h_in < 0 ||
h_in >= IN_DIM[1] || w_in < 0 || w_in >= IN_DIM[2]