github.com/torch/nn.git
author     Soumith Chintala <soumith@gmail.com>  2017-07-28 06:17:44 +0300
committer  Soumith Chintala <soumith@gmail.com>  2017-08-03 05:44:44 +0300
commit     d87d7c7619a008a27ad0a1d03dfaf978ccfb1719 (patch)
tree       d52eb1a9e0f0fb24ee236307718098b46ee59dec
parent     55374c200cb01fe2b9a549200a1e055e2438fe41 (diff)
add 2d and 3d dilated full Convolution
-rw-r--r--  lib/THNN/generic/SpatialFullConvolution.c            421
-rw-r--r--  lib/THNN/generic/SpatialFullDilatedConvolution.c     472
-rw-r--r--  lib/THNN/generic/THNN.h                               81
-rw-r--r--  lib/THNN/generic/VolumetricFullConvolution.c         502
-rw-r--r--  lib/THNN/generic/VolumetricFullDilatedConvolution.c  548
-rw-r--r--  lib/THNN/init.c                                         6
6 files changed, 1125 insertions, 905 deletions
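
The new dilated kernels generalize the output-size arithmetic of the full (transposed) convolution by replacing the kernel extent k with the dilated extent dilation * (k - 1) + 1. A minimal sketch of that computation along one dimension, written as a standalone helper for illustration (the function name is hypothetical, not part of THNN):

  /* Output extent of a dilated full (transposed) convolution along one axis,
   * mirroring the shape checks added in this commit.
   * in: input extent, d: stride, pad: padding, k: kernel size,
   * dilation: dilation factor, adj: extra output adjustment. */
  static long full_dilated_output_size(long in, int d, int pad,
                                       int k, int dilation, int adj)
  {
    return (in - 1) * d - 2 * pad + (dilation * (k - 1) + 1) + adj;
  }

For example, in = 8, stride 2, pad 1, kernel 3, dilation 2, adj 0 gives (8 - 1) * 2 - 2 + 5 + 0 = 17; with dilation 1 the same call reduces to the old SpatialFullConvolution formula and gives 15.
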
diff --git a/lib/THNN/generic/SpatialFullConvolution.c b/lib/THNN/generic/SpatialFullConvolution.c
index 2edc53b..b9cd9fe 100644
--- a/lib/THNN/generic/SpatialFullConvolution.c
+++ b/lib/THNN/generic/SpatialFullConvolution.c
@@ -2,115 +2,6 @@
#define TH_GENERIC_FILE "generic/SpatialFullConvolution.c"
#else
-static void THNN_(im2col)(const real* data_im, const int channels,
- const int height, const int width, const int kernel_h, const int kernel_w,
- const int pad_h, const int pad_w,
- const int stride_h, const int stride_w,
- const int dilation_h, const int dilation_w,
- real* data_col) {
- const int height_col = (height + 2 * pad_h -
- (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
- const int width_col = (width + 2 * pad_w -
- (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
- const int channels_col = channels * kernel_h * kernel_w;
- for (int c_col = 0; c_col < channels_col; ++c_col) {
- int w_offset = c_col % kernel_w;
- int h_offset = (c_col / kernel_w) % kernel_h;
- int c_im = c_col / kernel_h / kernel_w;
- for (int h_col = 0; h_col < height_col; ++h_col) {
- for (int w_col = 0; w_col < width_col; ++w_col) {
- int h_im = h_col * stride_h - pad_h + h_offset * dilation_h;
- int w_im = w_col * stride_w - pad_w + w_offset * dilation_w;
- data_col[(c_col * height_col + h_col) * width_col + w_col] =
- (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width) ?
- data_im[(c_im * height + h_im) * width + w_im] : 0;
- }
- }
- }
-}
-
-static void THNN_(col2im)(const real* data_col, const int channels,
- const int height, const int width, const int kernel_h, const int kernel_w,
- const int pad_h, const int pad_w,
- const int stride_h, const int stride_w,
- const int dilation_h, const int dilation_w,
- real* data_im) {
- memset(data_im, 0, sizeof(real) * height * width * channels);
- const int height_col = (height + 2 * pad_h -
- (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
- const int width_col = (width + 2 * pad_w -
- (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
- const int channels_col = channels * kernel_h * kernel_w;
- for (int c_col = 0; c_col < channels_col; ++c_col) {
- int w_offset = c_col % kernel_w;
- int h_offset = (c_col / kernel_w) % kernel_h;
- int c_im = c_col / kernel_h / kernel_w;
- for (int h_col = 0; h_col < height_col; ++h_col) {
- for (int w_col = 0; w_col < width_col; ++w_col) {
- int h_im = h_col * stride_h - pad_h + h_offset * dilation_h;
- int w_im = w_col * stride_w - pad_w + w_offset * dilation_w;
- if (h_im >= 0 && h_im < height && w_im >= 0 && w_im < width)
- data_im[(c_im * height + h_im) * width + w_im] +=
- data_col[(c_col * height_col + h_col) * width_col + w_col];
- }
- }
- }
-}
-
-static inline void THNN_(SpatialFullConvolution_shapeCheck)(
- THTensor *input, THTensor *gradOutput,
- THTensor *weight, THTensor *bias,
- int kH, int kW, int dH, int dW, int padH, int padW, int adjH, int adjW) {
-
- THArgCheck(kW > 0 && kH > 0, 9,
- "kernel size should be greater than zero, but got kH: %d kW: %d", kH, kW);
- THArgCheck(dW > 0 && dH > 0, 11,
- "stride should be greater than zero, but got dH: %d dW: %d", dH, dW);
- THArgCheck(adjW < dW && adjH < dH, 15,
- "output adjustment must be smaller than stride, but got adjH: %d adjW: %d dH: %d dW: %d",
- adjH, adjW, dH, dW);
- THNN_ARGCHECK(weight->nDimension == 2 || weight->nDimension == 4, 5, weight,
- "2D or 4D weight tensor expected, but got: %s");
-
- if (bias != NULL) {
- THNN_CHECK_DIM_SIZE(bias, 1, 0, weight->size[1]);
- }
-
- int ndim = input->nDimension;
- int dimf = 0;
- int dimh = 1;
- int dimw = 2;
-
- if (ndim == 4) {
- dimf++;
- dimh++;
- dimw++;
- }
-
- THNN_ARGCHECK(ndim == 3 || ndim == 4, 2, input,
- "3D or 4D input tensor expected but got: %s");
-
- long nInputPlane = weight->size[0];
- long inputHeight = input->size[dimh];
- long inputWidth = input->size[dimw];
- long nOutputPlane = weight->size[1];
- long outputHeight = (inputHeight - 1) * dH - 2*padH + kH + adjH;
- long outputWidth = (inputWidth - 1) * dW - 2*padW + kW + adjW;
-
- if (outputWidth < 1 || outputHeight < 1)
- THError("Given input size: (%d x %d x %d). "
- "Calculated output size: (%d x %d x %d). Output size is too small",
- nInputPlane,inputHeight,inputWidth,nOutputPlane,outputHeight,outputWidth);
-
- THNN_CHECK_DIM_SIZE(input, ndim, dimf, nInputPlane);
-
- if (gradOutput != NULL) {
- THNN_CHECK_DIM_SIZE(gradOutput, ndim, dimf, nOutputPlane);
- THNN_CHECK_DIM_SIZE(gradOutput, ndim, dimh, outputHeight);
- THNN_CHECK_DIM_SIZE(gradOutput, ndim, dimw, outputWidth);
- }
-}
-
void THNN_(SpatialFullConvolution_updateOutput)(
THNNState *state,
THTensor *input,
@@ -124,118 +15,11 @@ void THNN_(SpatialFullConvolution_updateOutput)(
int padW, int padH,
int adjW, int adjH)
{
- THNN_(SpatialFullConvolution_shapeCheck)
- (input, NULL, weight, bias, kH, kW, dH, dW, padH, padW, adjH, adjW);
-
- int nInputPlane = THTensor_(size)(weight,0);
- int nOutputPlane = THTensor_(size)(weight,1);
-
- input = THTensor_(newContiguous)(input);
- weight = THTensor_(newContiguous)(weight);
- bias = bias ? THTensor_(newContiguous)(bias) : bias;
- int batch = 1;
- if (input->nDimension == 3) {
- // Force batch
- batch = 0;
- THTensor_(resize4d)(input, 1, input->size[0], input->size[1], input->size[2]);
- }
-
- long inputHeight = input->size[2];
- long inputWidth = input->size[3];
- long outputHeight = (inputHeight - 1) * dH - 2*padH + kH + adjH;
- long outputWidth = (inputWidth - 1) * dW - 2*padW + kW + adjW;
-
- // Batch size + input planes
- long batchSize = input->size[0];
-
- // Resize output
- THTensor_(resize4d)(output, batchSize, nOutputPlane, outputHeight, outputWidth);
-
- // Resize temporary columns
- THTensor_(resize2d)(columns, nOutputPlane*kW*kH, inputHeight*inputWidth);
- THTensor_(zero)(columns);
-
- // Define a buffer of ones, for bias accumulation
- // Note: this buffer can be shared with other modules, it only ever gets increased,
- // and always contains ones.
- if (ones->nDimension != 2 || ones->size[0]*ones->size[1] < outputHeight*outputWidth) {
- // Resize plane and fill with ones...
- THTensor_(resize2d)(ones, outputHeight, outputWidth);
- THTensor_(fill)(ones, 1);
- }
-
- // Helpers
- THTensor *input_n = THTensor_(new)();
- THTensor *output_n = THTensor_(new)();
-
- int elt;
- // For each elt in batch, do:
- for (elt = 0; elt < batchSize; elt ++) {
- // Matrix mulitply per output:
- THTensor_(select)(input_n, input, 0, elt);
- THTensor_(select)(output_n, output, 0, elt);
-
- // M,N,K are dims of matrix A and B
- // (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm)
- long m = weight->size[1] * weight->size[2] * weight->size[3];
- long n = columns->size[1];
- long k = weight->size[0];
-
- // Do GEMM (note: this is a bit confusing because gemm assumes column-major matrices)
- THBlas_(gemm)(
- 'n', 't',
- n, m, k,
- 1,
- THTensor_(data)(input_n), n,
- THTensor_(data)(weight), m,
- 0,
- THTensor_(data)(columns), n
- );
-
- // Unpack columns back into input:
- THNN_(col2im)(
- THTensor_(data)(columns),
- nOutputPlane, outputHeight, outputWidth, kH, kW, padH, padW, dH, dW,
- 1, 1,
- THTensor_(data)(output_n)
- );
-
- // Do Bias after:
- // M,N,K are dims of matrix A and B
- // (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm)
- long m_ = nOutputPlane;
- long n_ = outputHeight * outputWidth;
- long k_ = 1;
-
- // Do GEMM (note: this is a bit confusing because gemm assumes column-major matrices)
- if (bias) {
- THBlas_(gemm)(
- 't', 'n',
- n_, m_, k_,
- 1,
- THTensor_(data)(ones), k_,
- THTensor_(data)(bias), k_,
- 1,
- THTensor_(data)(output_n), n_
- );
- }
+ THNN_(SpatialFullDilatedConvolution_updateOutput)(
+ state, input, output, weight, bias, columns, ones,
+ kW, kH, dW, dH, padW, padH, 1, 1, adjW, adjH);
}
- // Free
- THTensor_(free)(input_n);
- THTensor_(free)(output_n);
-
- // Resize output
- if (batch == 0) {
- THTensor_(resize3d)(output, nOutputPlane, outputHeight, outputWidth);
- THTensor_(resize3d)(input, nInputPlane, inputHeight, inputWidth);
- }
-
- THTensor_(free)(input);
- THTensor_(free)(weight);
- if (bias) THTensor_(free)(bias);
-}
-
void THNN_(SpatialFullConvolution_updateGradInput)(
THNNState *state,
THTensor *input,
@@ -248,94 +32,11 @@ void THNN_(SpatialFullConvolution_updateGradInput)(
int padW, int padH,
int adjW, int adjH)
{
- THNN_(SpatialFullConvolution_shapeCheck)
- (input, gradOutput, weight, NULL, kH, kW, dH, dW, padH, padW, adjH, adjW);
-
- int nInputPlane = THTensor_(size)(weight,0);
- int nOutputPlane = THTensor_(size)(weight,1);
-
- input = THTensor_(newContiguous)(input);
- gradOutput = THTensor_(newContiguous)(gradOutput);
- weight = THTensor_(newContiguous)(weight);
- int batch = 1;
- if (input->nDimension == 3) {
- // Force batch
- batch = 0;
- THTensor_(resize4d)(input, 1, input->size[0], input->size[1], input->size[2]);
- THTensor_(resize4d)(gradOutput, 1, gradOutput->size[0], gradOutput->size[1], gradOutput->size[2]);
- }
-
- long inputWidth = input->size[3];
- long inputHeight = input->size[2];
- long outputWidth = (inputWidth - 1) * dW - 2*padW + kW + adjW;
- long outputHeight = (inputHeight - 1) * dH - 2*padH + kH + adjH;
-
- // Batch size + input planes
- long batchSize = input->size[0];
-
- // Resize output
- THTensor_(resize4d)(gradInput, batchSize, nInputPlane, inputHeight, inputWidth);
- THTensor_(zero)(gradInput);
-
- // Resize temporary columns
- THTensor_(resize2d)(gradColumns, nOutputPlane*kW*kH, inputHeight*inputWidth);
-
- // Helpers
- THTensor *gradInput_n = THTensor_(new)();
- THTensor *gradOutput_n = THTensor_(new)();
-
- int elt;
- // For each elt in batch, do:
- for (elt = 0; elt < batchSize; elt ++) {
- // Matrix mulitply per sample:
- THTensor_(select)(gradInput_n, gradInput, 0, elt);
- THTensor_(select)(gradOutput_n, gradOutput, 0, elt);
-
- // Extract columns:
- THNN_(im2col)(
- THTensor_(data)(gradOutput_n),
- nOutputPlane, outputHeight, outputWidth, kH, kW, padH, padW, dH, dW,
- 1, 1,
- THTensor_(data)(gradColumns)
- );
-
-
- // M,N,K are dims of matrix A and B
- // (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm)
- long m = weight->size[0];
- long n = gradColumns->size[1];
- long k = weight->size[1] * weight->size[2] * weight->size[3];
-
- // Do GEMM (note: this is a bit confusing because gemm assumes column-major matrices)
- THBlas_(gemm)(
- 'n', 'n',
- n, m, k,
- 1,
- THTensor_(data)(gradColumns), n,
- THTensor_(data)(weight), k,
- 0,
- THTensor_(data)(gradInput_n), n
- );
- }
-
-
- // Free
- THTensor_(free)(gradInput_n);
- THTensor_(free)(gradOutput_n);
-
- // Resize output
- if (batch == 0) {
- THTensor_(resize3d)(gradOutput, nOutputPlane, outputHeight, outputWidth);
- THTensor_(resize3d)(input, nInputPlane, inputHeight, inputWidth);
- THTensor_(resize3d)(gradInput, nInputPlane, inputHeight, inputWidth);
- }
-
- THTensor_(free)(input);
- THTensor_(free)(gradOutput);
- THTensor_(free)(weight);
+ THNN_(SpatialFullDilatedConvolution_updateGradInput)(
+ state, input, gradOutput, gradInput, weight, gradColumns,
+ kW, kH, dW, dH, padW, padH, 1, 1, adjW, adjH);
}
-
void THNN_(SpatialFullConvolution_accGradParameters)(
THNNState *state,
THTensor *input,
@@ -350,113 +51,9 @@ void THNN_(SpatialFullConvolution_accGradParameters)(
int adjW, int adjH,
accreal scale_)
{
- real scale = TH_CONVERT_ACCREAL_TO_REAL(scale_);
- THNN_(SpatialFullConvolution_shapeCheck)
- (input, gradOutput, gradWeight, gradBias, kH, kW, dH, dW, padH, padW, adjH, adjW);
-
- int nInputPlane = THTensor_(size)(gradWeight,0);
- int nOutputPlane = THTensor_(size)(gradWeight,1);
-
- input = THTensor_(newContiguous)(input);
- gradOutput = THTensor_(newContiguous)(gradOutput);
- THArgCheck(THTensor_(isContiguous)(gradWeight), 4, "gradWeight needs to be contiguous");
- if (gradBias)
- THArgCheck(THTensor_(isContiguous)(gradBias), 5, "gradBias needs to be contiguous");
- int batch = 1;
- if (input->nDimension == 3) {
- // Force batch
- batch = 0;
- THTensor_(resize4d)(input, 1, input->size[0], input->size[1], input->size[2]);
- THTensor_(resize4d)(gradOutput, 1, gradOutput->size[0], gradOutput->size[1], gradOutput->size[2]);
- }
-
- long inputWidth = input->size[3];
- long inputHeight = input->size[2];
- long outputWidth = (inputWidth - 1) * dW - 2*padW + kW + adjW;
- long outputHeight = (inputHeight - 1) * dH - 2*padH + kH + adjH;
-
- // Batch size + input planes
- long batchSize = input->size[0];
-
- // Define a buffer of ones, for bias accumulation
- if (ones->nDimension != 2 || ones->size[0]*ones->size[1] < outputHeight*outputWidth) {
- // Resize plane and fill with ones...
- THTensor_(resize2d)(ones, outputHeight, outputWidth);
- THTensor_(fill)(ones, 1);
- }
-
- // Resize temporary columns
- THTensor_(resize2d)(columns, nOutputPlane*kW*kH, inputHeight*inputWidth);
-
- // Helpers
- THTensor *input_n = THTensor_(new)();
- THTensor *gradOutput_n = THTensor_(new)();
-
- int elt;
- // For each elt in batch, do:
- for (elt = 0; elt < batchSize; elt ++) {
- // Matrix mulitply per output:
- THTensor_(select)(input_n, input, 0, elt);
- THTensor_(select)(gradOutput_n, gradOutput, 0, elt);
-
- // Extract columns:
- THNN_(im2col)(
- THTensor_(data)(gradOutput_n),
- nOutputPlane, outputHeight, outputWidth, kH, kW, padH, padW, dH, dW,
- 1, 1,
- THTensor_(data)(columns)
- );
-
- // M,N,K are dims of matrix A and B
- // (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm)
- long n = columns->size[0]; // nOutputPlane * kh * kw
- long m = input_n->size[0]; // nInputPlane
- long k = columns->size[1]; // inputHeight * inputWidth
-
- // Do GEMM (note: this is a bit confusing because gemm assumes column-major matrices)
- THBlas_(gemm)(
- 't', 'n',
- n, m, k,
- scale,
- THTensor_(data)(columns), k,
- THTensor_(data)(input_n), k,
- 1,
- THTensor_(data)(gradWeight), n
- );
-
-
- // Do Bias:
- // M,N,K are dims of matrix A and B
- // (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm)
- long m_ = nOutputPlane;
- long k_ = outputHeight * outputWidth;
-
- // Do GEMV (note: this is a bit confusing because gemv assumes column-major matrices)
- if (gradBias) {
- THBlas_(gemv)(
- 't',
- k_, m_,
- scale,
- THTensor_(data)(gradOutput_n), k_,
- THTensor_(data)(ones), 1,
- 1,
- THTensor_(data)(gradBias), 1
- );
- }
- }
-
- // Free
- THTensor_(free)(input_n);
- THTensor_(free)(gradOutput_n);
-
- // Resize
- if (batch == 0) {
- THTensor_(resize3d)(gradOutput, nOutputPlane, outputHeight, outputWidth);
- THTensor_(resize3d)(input, nInputPlane, inputHeight, inputWidth);
- }
-
- THTensor_(free)(input);
- THTensor_(free)(gradOutput);
+THNN_(SpatialFullDilatedConvolution_accGradParameters)(
+ state, input, gradOutput, gradWeight, gradBias, columns, ones,
+ kW, kH, dW, dH, padW, padH, 1, 1, adjW, adjH, scale_);
}
#endif
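
Both the removed bodies above and their dilated replacements in the new file below drive the heavy lifting through THBlas gemm, and the recurring "gemm assumes column-major matrices" comments refer to the standard trick for feeding row-major data to a column-major BLAS: a row-major m-by-n matrix is bit-for-bit a column-major n-by-m matrix, so a row-major product C = A*B is obtained as C^T = B^T * A^T, i.e. by passing the operands in swapped order with swapped dimensions. A small self-contained illustration with a naive column-major gemm standing in for THBlas_(gemm) (illustrative only, not library code):

  #include <assert.h>

  /* Naive column-major GEMM: C (m x n) = A (m x k) * B (k x n), all stored
   * column-major with leading dimensions equal to their row counts. */
  static void gemm_colmajor(int m, int n, int k,
                            const float *A, const float *B, float *C) {
    for (int j = 0; j < n; ++j)
      for (int i = 0; i < m; ++i) {
        float acc = 0;
        for (int p = 0; p < k; ++p)
          acc += A[i + p * m] * B[p + j * k];
        C[i + j * m] = acc;
      }
  }

  int main(void) {
    /* Row-major A (2x3) and B (3x2); we want row-major C = A*B (2x2). */
    float A[6] = {1, 2, 3,
                  4, 5, 6};
    float B[6] = {7,  8,
                  9, 10,
                 11, 12};
    float C[4];
    /* Row-major B reinterpreted as column-major is B^T (2x3), likewise A^T
     * (3x2), and row-major C viewed column-major is C^T = B^T * A^T, so the
     * call simply swaps the operand order and dimensions. */
    gemm_colmajor(2, 2, 3, B, A, C);
    assert(C[0] == 58 && C[1] == 64 && C[2] == 139 && C[3] == 154);
    return 0;
  }
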
diff --git a/lib/THNN/generic/SpatialFullDilatedConvolution.c b/lib/THNN/generic/SpatialFullDilatedConvolution.c
new file mode 100644
index 0000000..4d5a3fc
--- /dev/null
+++ b/lib/THNN/generic/SpatialFullDilatedConvolution.c
@@ -0,0 +1,472 @@
+#ifndef TH_GENERIC_FILE
+#define TH_GENERIC_FILE "generic/SpatialFullDilatedConvolution.c"
+#else
+
+static void THNN_(im2col)(const real* data_im, const int channels,
+ const int height, const int width, const int kernel_h, const int kernel_w,
+ const int pad_h, const int pad_w,
+ const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w,
+ real* data_col) {
+ const int height_col = (height + 2 * pad_h -
+ (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
+ const int width_col = (width + 2 * pad_w -
+ (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
+ const int channels_col = channels * kernel_h * kernel_w;
+ for (int c_col = 0; c_col < channels_col; ++c_col) {
+ int w_offset = c_col % kernel_w;
+ int h_offset = (c_col / kernel_w) % kernel_h;
+ int c_im = c_col / kernel_h / kernel_w;
+ for (int h_col = 0; h_col < height_col; ++h_col) {
+ for (int w_col = 0; w_col < width_col; ++w_col) {
+ int h_im = h_col * stride_h - pad_h + h_offset * dilation_h;
+ int w_im = w_col * stride_w - pad_w + w_offset * dilation_w;
+ data_col[(c_col * height_col + h_col) * width_col + w_col] =
+ (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width) ?
+ data_im[(c_im * height + h_im) * width + w_im] : 0;
+ }
+ }
+ }
+}
+
+static void THNN_(col2im)(const real* data_col, const int channels,
+ const int height, const int width, const int kernel_h, const int kernel_w,
+ const int pad_h, const int pad_w,
+ const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w,
+ real* data_im) {
+ memset(data_im, 0, sizeof(real) * height * width * channels);
+ const int height_col = (height + 2 * pad_h -
+ (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
+ const int width_col = (width + 2 * pad_w -
+ (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
+ const int channels_col = channels * kernel_h * kernel_w;
+ for (int c_col = 0; c_col < channels_col; ++c_col) {
+ int w_offset = c_col % kernel_w;
+ int h_offset = (c_col / kernel_w) % kernel_h;
+ int c_im = c_col / kernel_h / kernel_w;
+ for (int h_col = 0; h_col < height_col; ++h_col) {
+ for (int w_col = 0; w_col < width_col; ++w_col) {
+ int h_im = h_col * stride_h - pad_h + h_offset * dilation_h;
+ int w_im = w_col * stride_w - pad_w + w_offset * dilation_w;
+ if (h_im >= 0 && h_im < height && w_im >= 0 && w_im < width)
+ data_im[(c_im * height + h_im) * width + w_im] +=
+ data_col[(c_col * height_col + h_col) * width_col + w_col];
+ }
+ }
+ }
+}
+
+static inline void THNN_(SpatialFullDilatedConvolution_shapeCheck)(
+ THTensor *input, THTensor *gradOutput,
+ THTensor *weight, THTensor *bias,
+ int kH, int kW, int dH, int dW, int padH, int padW,
+ int dilationH, int dilationW, int adjH, int adjW) {
+
+ THArgCheck(kW > 0 && kH > 0, 9,
+ "kernel size should be greater than zero, but got kH: %d kW: %d", kH, kW);
+ THArgCheck(dW > 0 && dH > 0, 11,
+ "stride should be greater than zero, but got dH: %d dW: %d", dH, dW);
+ THArgCheck(adjW < dW && adjH < dH, 15,
+ "output adjustment must be smaller than stride, but got adjH: %d adjW: %d dH: %d dW: %d",
+ adjH, adjW, dH, dW);
+ THArgCheck(dilationW > 0 && dilationH > 0, 15,
+ "dilation should be greater than zero, but got dilationH: %d, dilationW: %d",
+ dilationH, dilationW);
+ THNN_ARGCHECK(weight->nDimension == 2 || weight->nDimension == 4, 5, weight,
+ "2D or 4D weight tensor expected, but got: %s");
+
+ if (bias != NULL) {
+ THNN_CHECK_DIM_SIZE(bias, 1, 0, weight->size[1]);
+ }
+
+ int ndim = input->nDimension;
+ int dimf = 0;
+ int dimh = 1;
+ int dimw = 2;
+
+ if (ndim == 4) {
+ dimf++;
+ dimh++;
+ dimw++;
+ }
+
+ THNN_ARGCHECK(ndim == 3 || ndim == 4, 2, input,
+ "3D or 4D input tensor expected but got: %s");
+
+ long nInputPlane = weight->size[0];
+ long inputHeight = input->size[dimh];
+ long inputWidth = input->size[dimw];
+ long nOutputPlane = weight->size[1];
+ long outputHeight = (inputHeight - 1) * dH - 2*padH + (dilationH * (kH - 1) + 1) + adjH;
+ long outputWidth = (inputWidth - 1) * dW - 2*padW + (dilationW * (kW - 1) + 1) + adjW;
+
+ if (outputWidth < 1 || outputHeight < 1)
+ THError("Given input size: (%d x %d x %d). "
+ "Calculated output size: (%d x %d x %d). Output size is too small",
+ nInputPlane,inputHeight,inputWidth,nOutputPlane,outputHeight,outputWidth);
+
+ THNN_CHECK_DIM_SIZE(input, ndim, dimf, nInputPlane);
+
+ if (gradOutput != NULL) {
+ THNN_CHECK_DIM_SIZE(gradOutput, ndim, dimf, nOutputPlane);
+ THNN_CHECK_DIM_SIZE(gradOutput, ndim, dimh, outputHeight);
+ THNN_CHECK_DIM_SIZE(gradOutput, ndim, dimw, outputWidth);
+ }
+}
+
+void THNN_(SpatialFullDilatedConvolution_updateOutput)(
+ THNNState *state,
+ THTensor *input,
+ THTensor *output,
+ THTensor *weight,
+ THTensor *bias,
+ THTensor *columns,
+ THTensor *ones,
+ int kW, int kH,
+ int dW, int dH,
+ int padW, int padH,
+ int dilationW, int dilationH,
+ int adjW, int adjH)
+{
+ THNN_(SpatialFullDilatedConvolution_shapeCheck)
+ (input, NULL, weight, bias, kH, kW, dH, dW, padH, padW,
+ dilationH, dilationW, adjH, adjW);
+
+ int nInputPlane = THTensor_(size)(weight,0);
+ int nOutputPlane = THTensor_(size)(weight,1);
+
+ input = THTensor_(newContiguous)(input);
+ weight = THTensor_(newContiguous)(weight);
+ bias = bias ? THTensor_(newContiguous)(bias) : bias;
+ int batch = 1;
+ if (input->nDimension == 3) {
+ // Force batch
+ batch = 0;
+ THTensor_(resize4d)(input, 1, input->size[0], input->size[1], input->size[2]);
+ }
+
+ long inputHeight = input->size[2];
+ long inputWidth = input->size[3];
+ long outputHeight = (inputHeight - 1) * dH - 2*padH + (dilationH * (kH - 1) + 1) + adjH;
+ long outputWidth = (inputWidth - 1) * dW - 2*padW + (dilationW * (kW - 1) + 1) + adjW;
+
+ // Batch size + input planes
+ long batchSize = input->size[0];
+
+ // Resize output
+ THTensor_(resize4d)(output, batchSize, nOutputPlane, outputHeight, outputWidth);
+
+ // Resize temporary columns
+ THTensor_(resize2d)(columns, nOutputPlane*kW*kH, inputHeight*inputWidth);
+ THTensor_(zero)(columns);
+
+ // Define a buffer of ones, for bias accumulation
+ // Note: this buffer can be shared with other modules, it only ever gets increased,
+ // and always contains ones.
+ if (ones->nDimension != 2 || ones->size[0]*ones->size[1] < outputHeight*outputWidth) {
+ // Resize plane and fill with ones...
+ THTensor_(resize2d)(ones, outputHeight, outputWidth);
+ THTensor_(fill)(ones, 1);
+ }
+
+ // Helpers
+ THTensor *input_n = THTensor_(new)();
+ THTensor *output_n = THTensor_(new)();
+
+ int elt;
+ // For each elt in batch, do:
+ for (elt = 0; elt < batchSize; elt ++) {
+ // Matrix multiply per output:
+ THTensor_(select)(input_n, input, 0, elt);
+ THTensor_(select)(output_n, output, 0, elt);
+
+ // M,N,K are dims of matrix A and B
+ // (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm)
+ long m = weight->size[1] * weight->size[2] * weight->size[3];
+ long n = columns->size[1];
+ long k = weight->size[0];
+
+ // Do GEMM (note: this is a bit confusing because gemm assumes column-major matrices)
+ THBlas_(gemm)(
+ 'n', 't',
+ n, m, k,
+ 1,
+ THTensor_(data)(input_n), n,
+ THTensor_(data)(weight), m,
+ 0,
+ THTensor_(data)(columns), n
+ );
+
+ // Unpack columns back into the output:
+ THNN_(col2im)(
+ THTensor_(data)(columns),
+ nOutputPlane, outputHeight, outputWidth, kH, kW, padH, padW, dH, dW,
+ dilationH, dilationW,
+ THTensor_(data)(output_n)
+ );
+
+ // Do Bias after:
+ // M,N,K are dims of matrix A and B
+ // (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm)
+ long m_ = nOutputPlane;
+ long n_ = outputHeight * outputWidth;
+ long k_ = 1;
+
+ // Do GEMM (note: this is a bit confusing because gemm assumes column-major matrices)
+ if (bias) {
+ THBlas_(gemm)(
+ 't', 'n',
+ n_, m_, k_,
+ 1,
+ THTensor_(data)(ones), k_,
+ THTensor_(data)(bias), k_,
+ 1,
+ THTensor_(data)(output_n), n_
+ );
+ }
+ }
+
+ // Free
+ THTensor_(free)(input_n);
+ THTensor_(free)(output_n);
+
+ // Resize output
+ if (batch == 0) {
+ THTensor_(resize3d)(output, nOutputPlane, outputHeight, outputWidth);
+ THTensor_(resize3d)(input, nInputPlane, inputHeight, inputWidth);
+ }
+
+ THTensor_(free)(input);
+ THTensor_(free)(weight);
+ if (bias) THTensor_(free)(bias);
+}
+
+void THNN_(SpatialFullDilatedConvolution_updateGradInput)(
+ THNNState *state,
+ THTensor *input,
+ THTensor *gradOutput,
+ THTensor *gradInput,
+ THTensor *weight,
+ THTensor *gradColumns,
+ int kW, int kH,
+ int dW, int dH,
+ int padW, int padH,
+ int dilationW, int dilationH,
+ int adjW, int adjH)
+{
+ THNN_(SpatialFullDilatedConvolution_shapeCheck)
+ (input, gradOutput, weight, NULL, kH, kW, dH, dW, padH, padW,
+ dilationH, dilationW, adjH, adjW);
+
+ int nInputPlane = THTensor_(size)(weight,0);
+ int nOutputPlane = THTensor_(size)(weight,1);
+
+ input = THTensor_(newContiguous)(input);
+ gradOutput = THTensor_(newContiguous)(gradOutput);
+ weight = THTensor_(newContiguous)(weight);
+ int batch = 1;
+ if (input->nDimension == 3) {
+ // Force batch
+ batch = 0;
+ THTensor_(resize4d)(input, 1, input->size[0], input->size[1], input->size[2]);
+ THTensor_(resize4d)(gradOutput, 1, gradOutput->size[0], gradOutput->size[1], gradOutput->size[2]);
+ }
+
+ long inputWidth = input->size[3];
+ long inputHeight = input->size[2];
+ long outputHeight = (inputHeight - 1) * dH - 2*padH + (dilationH * (kH - 1) + 1) + adjH;
+ long outputWidth = (inputWidth - 1) * dW - 2*padW + (dilationW * (kW - 1) + 1) + adjW;
+
+ // Batch size + input planes
+ long batchSize = input->size[0];
+
+ // Resize output
+ THTensor_(resize4d)(gradInput, batchSize, nInputPlane, inputHeight, inputWidth);
+ THTensor_(zero)(gradInput);
+
+ // Resize temporary columns
+ THTensor_(resize2d)(gradColumns, nOutputPlane*kW*kH, inputHeight*inputWidth);
+
+ // Helpers
+ THTensor *gradInput_n = THTensor_(new)();
+ THTensor *gradOutput_n = THTensor_(new)();
+
+ int elt;
+ // For each elt in batch, do:
+ for (elt = 0; elt < batchSize; elt ++) {
+ // Matrix multiply per sample:
+ THTensor_(select)(gradInput_n, gradInput, 0, elt);
+ THTensor_(select)(gradOutput_n, gradOutput, 0, elt);
+
+ // Extract columns:
+ THNN_(im2col)(
+ THTensor_(data)(gradOutput_n),
+ nOutputPlane, outputHeight, outputWidth, kH, kW, padH, padW, dH, dW,
+ dilationH, dilationW,
+ THTensor_(data)(gradColumns)
+ );
+
+
+ // M,N,K are dims of matrix A and B
+ // (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm)
+ long m = weight->size[0];
+ long n = gradColumns->size[1];
+ long k = weight->size[1] * weight->size[2] * weight->size[3];
+
+ // Do GEMM (note: this is a bit confusing because gemm assumes column-major matrices)
+ THBlas_(gemm)(
+ 'n', 'n',
+ n, m, k,
+ 1,
+ THTensor_(data)(gradColumns), n,
+ THTensor_(data)(weight), k,
+ 0,
+ THTensor_(data)(gradInput_n), n
+ );
+ }
+
+
+ // Free
+ THTensor_(free)(gradInput_n);
+ THTensor_(free)(gradOutput_n);
+
+ // Resize output
+ if (batch == 0) {
+ THTensor_(resize3d)(gradOutput, nOutputPlane, outputHeight, outputWidth);
+ THTensor_(resize3d)(input, nInputPlane, inputHeight, inputWidth);
+ THTensor_(resize3d)(gradInput, nInputPlane, inputHeight, inputWidth);
+ }
+
+ THTensor_(free)(input);
+ THTensor_(free)(gradOutput);
+ THTensor_(free)(weight);
+}
+
+
+void THNN_(SpatialFullDilatedConvolution_accGradParameters)(
+ THNNState *state,
+ THTensor *input,
+ THTensor *gradOutput,
+ THTensor *gradWeight,
+ THTensor *gradBias,
+ THTensor *columns,
+ THTensor *ones,
+ int kW, int kH,
+ int dW, int dH,
+ int padW, int padH,
+ int dilationW, int dilationH,
+ int adjW, int adjH,
+ accreal scale_)
+{
+ real scale = TH_CONVERT_ACCREAL_TO_REAL(scale_);
+ THNN_(SpatialFullDilatedConvolution_shapeCheck)
+ (input, gradOutput, gradWeight, gradBias, kH, kW, dH, dW, padH, padW,
+ dilationH, dilationW, adjH, adjW);
+
+ int nInputPlane = THTensor_(size)(gradWeight,0);
+ int nOutputPlane = THTensor_(size)(gradWeight,1);
+
+ input = THTensor_(newContiguous)(input);
+ gradOutput = THTensor_(newContiguous)(gradOutput);
+ THArgCheck(THTensor_(isContiguous)(gradWeight), 4, "gradWeight needs to be contiguous");
+ if (gradBias)
+ THArgCheck(THTensor_(isContiguous)(gradBias), 5, "gradBias needs to be contiguous");
+ int batch = 1;
+ if (input->nDimension == 3) {
+ // Force batch
+ batch = 0;
+ THTensor_(resize4d)(input, 1, input->size[0], input->size[1], input->size[2]);
+ THTensor_(resize4d)(gradOutput, 1, gradOutput->size[0], gradOutput->size[1], gradOutput->size[2]);
+ }
+
+ long inputWidth = input->size[3];
+ long inputHeight = input->size[2];
+ long outputHeight = (inputHeight - 1) * dH - 2*padH + (dilationH * (kH - 1) + 1) + adjH;
+ long outputWidth = (inputWidth - 1) * dW - 2*padW + (dilationW * (kW - 1) + 1) + adjW;
+
+ // Batch size + input planes
+ long batchSize = input->size[0];
+
+ // Define a buffer of ones, for bias accumulation
+ if (ones->nDimension != 2 || ones->size[0]*ones->size[1] < outputHeight*outputWidth) {
+ // Resize plane and fill with ones...
+ THTensor_(resize2d)(ones, outputHeight, outputWidth);
+ THTensor_(fill)(ones, 1);
+ }
+
+ // Resize temporary columns
+ THTensor_(resize2d)(columns, nOutputPlane*kW*kH, inputHeight*inputWidth);
+
+ // Helpers
+ THTensor *input_n = THTensor_(new)();
+ THTensor *gradOutput_n = THTensor_(new)();
+
+ int elt;
+ // For each elt in batch, do:
+ for (elt = 0; elt < batchSize; elt ++) {
+ // Matrix multiply per output:
+ THTensor_(select)(input_n, input, 0, elt);
+ THTensor_(select)(gradOutput_n, gradOutput, 0, elt);
+
+ // Extract columns:
+ THNN_(im2col)(
+ THTensor_(data)(gradOutput_n),
+ nOutputPlane, outputHeight, outputWidth, kH, kW, padH, padW, dH, dW,
+ dilationH, dilationW,
+ THTensor_(data)(columns)
+ );
+
+ // M,N,K are dims of matrix A and B
+ // (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm)
+ long n = columns->size[0]; // nOutputPlane * kh * kw
+ long m = input_n->size[0]; // nInputPlane
+ long k = columns->size[1]; // inputHeight * inputWidth
+
+ // Do GEMM (note: this is a bit confusing because gemm assumes column-major matrices)
+ THBlas_(gemm)(
+ 't', 'n',
+ n, m, k,
+ scale,
+ THTensor_(data)(columns), k,
+ THTensor_(data)(input_n), k,
+ 1,
+ THTensor_(data)(gradWeight), n
+ );
+
+
+ // Do Bias:
+ // M,N,K are dims of matrix A and B
+ // (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm)
+ long m_ = nOutputPlane;
+ long k_ = outputHeight * outputWidth;
+
+ // Do GEMV (note: this is a bit confusing because gemv assumes column-major matrices)
+ if (gradBias) {
+ THBlas_(gemv)(
+ 't',
+ k_, m_,
+ scale,
+ THTensor_(data)(gradOutput_n), k_,
+ THTensor_(data)(ones), 1,
+ 1,
+ THTensor_(data)(gradBias), 1
+ );
+ }
+ }
+
+ // Free
+ THTensor_(free)(input_n);
+ THTensor_(free)(gradOutput_n);
+
+ // Resize
+ if (batch == 0) {
+ THTensor_(resize3d)(gradOutput, nOutputPlane, outputHeight, outputWidth);
+ THTensor_(resize3d)(input, nInputPlane, inputHeight, inputWidth);
+ }
+
+ THTensor_(free)(input);
+ THTensor_(free)(gradOutput);
+}
+
+#endif
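
The im2col/col2im helpers copied into this file are where the dilation parameters actually take effect: each kernel tap samples the input at offsets of dilation_h / dilation_w pixels rather than 1. A standalone copy of the same loop with real fixed to float, plus a tiny driver, can be used to inspect the sampling pattern outside the THNN generic machinery (a sketch for illustration, not part of the library):

  #include <stdio.h>

  /* Same arithmetic as THNN_(im2col) above, with `real` fixed to float. */
  static void im2col_f(const float* data_im, int channels,
                       int height, int width, int kernel_h, int kernel_w,
                       int pad_h, int pad_w, int stride_h, int stride_w,
                       int dilation_h, int dilation_w, float* data_col) {
    const int height_col = (height + 2 * pad_h -
        (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
    const int width_col = (width + 2 * pad_w -
        (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
    const int channels_col = channels * kernel_h * kernel_w;
    for (int c_col = 0; c_col < channels_col; ++c_col) {
      int w_offset = c_col % kernel_w;
      int h_offset = (c_col / kernel_w) % kernel_h;
      int c_im = c_col / kernel_h / kernel_w;
      for (int h_col = 0; h_col < height_col; ++h_col)
        for (int w_col = 0; w_col < width_col; ++w_col) {
          int h_im = h_col * stride_h - pad_h + h_offset * dilation_h;
          int w_im = w_col * stride_w - pad_w + w_offset * dilation_w;
          data_col[(c_col * height_col + h_col) * width_col + w_col] =
              (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width) ?
              data_im[(c_im * height + h_im) * width + w_im] : 0;
        }
    }
  }

  int main(void) {
    /* 1 channel, 5x5 input holding 0..24; 3x3 kernel, stride 1, no padding,
     * dilation 2: the taps are two pixels apart, only one output position
     * exists, and the column buffer picks up the corners, edge midpoints and
     * centre of the 5x5 grid. */
    float im[25], col[9];
    for (int i = 0; i < 25; ++i) im[i] = (float)i;
    im2col_f(im, 1, 5, 5, 3, 3, 0, 0, 1, 1, 2, 2, col);
    for (int i = 0; i < 9; ++i) printf("%g ", col[i]);
    printf("\n"); /* prints: 0 2 4 10 12 14 20 22 24 */
    return 0;
  }
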
diff --git a/lib/THNN/generic/THNN.h b/lib/THNN/generic/THNN.h
index ad4ea51..37b094b 100644
--- a/lib/THNN/generic/THNN.h
+++ b/lib/THNN/generic/THNN.h
@@ -1060,6 +1060,48 @@ TH_API void THNN_(SpatialDilatedConvolution_accGradParameters)(
int dilationW, int dilationH,
accreal scale);
+TH_API void THNN_(SpatialFullDilatedConvolution_updateOutput)(
+ THNNState *state,
+ THTensor *input,
+ THTensor *output,
+ THTensor *weight,
+ THTensor *bias, // [OPTIONAL]
+ THTensor *columns,
+ THTensor *ones,
+ int kW, int kH,
+ int dW, int dH,
+ int padW, int padH,
+ int dilationW, int dilationH,
+ int adjW, int adjH);
+
+TH_API void THNN_(SpatialFullDilatedConvolution_updateGradInput)(
+ THNNState *state,
+ THTensor *input,
+ THTensor *gradOutput,
+ THTensor *gradInput,
+ THTensor *weight,
+ THTensor *gradColumns,
+ int kW, int kH,
+ int dW, int dH,
+ int padW, int padH,
+ int dilationW, int dilationH,
+ int adjW, int adjH);
+
+TH_API void THNN_(SpatialFullDilatedConvolution_accGradParameters)(
+ THNNState *state,
+ THTensor *input,
+ THTensor *gradOutput,
+ THTensor *gradWeight,
+ THTensor *gradBias, // [OPTIONAL]
+ THTensor *columns,
+ THTensor *ones,
+ int kW, int kH,
+ int dW, int dH,
+ int padW, int padH,
+ int dilationW, int dilationH,
+ int adjW, int adjH,
+ accreal scale);
+
TH_API void THNN_(SpatialMaxPooling_updateOutput)(
THNNState *state,
THTensor *input,
@@ -1371,6 +1413,45 @@ TH_API void THNN_(VolumetricDilatedConvolution_accGradParameters)(
int dilationT, int dilationW, int dilationH,
accreal scale);
+TH_API void THNN_(VolumetricFullDilatedConvolution_updateOutput)(
+ THNNState *state, // library state
+ THTensor *input, // 4D or 5D (batch) tensor
+ THTensor *output, // [OUT] volumetric convolution output
+ THTensor *weight, // weight tensor (nInputPlane x nOutputPlane x kT x kH x kW)
+ THTensor *bias, // [OPTIONAL] bias tensor (nOutputPlane)
+ THTensor *finput, // [OUT] internal columns buffer
+ THTensor *fgradInput, // [OUT] internal ones buffer
+ int dT, int dW, int dH, // stride of the convolution
+ int pT, int pW, int pH, // padding
+ int dilationT, int dilationW, int dilationH,
+ int aT, int aW, int aH); // extra output adjustment
+TH_API void THNN_(VolumetricFullDilatedConvolution_updateGradInput)(
+ THNNState *state, // library state
+ THTensor *input, // 4D or 5D (batch) tensor
+ THTensor *gradOutput, // gradient w.r.t. output
+ THTensor *gradInput, // [OUT] gradient w.r.t. input
+ THTensor *weight, // weight tensor (nInputPlane x nOutputPlane x kT x kH x kW)
+ THTensor *finput, // internal columns buffer
+ THTensor *fgradInput, // internal ones buffer
+ int dT, int dW, int dH, // stride
+ int pT, int pW, int pH, // padding
+ int dilationT, int dilationW, int dilationH,
+ int aT, int aW, int aH); // extra output adjustment
+
+TH_API void THNN_(VolumetricFullDilatedConvolution_accGradParameters)(
+ THNNState *state, // library state
+ THTensor *input, // 4D or 5D (batch) tensor
+ THTensor *gradOutput, // gradient w.r.t. output
+ THTensor *gradWeight, // gradWeight tensor (nInputPlane x nOutputPlane x kT x kH x kW)
+ THTensor *gradBias, // [OPTIONAL] gradBias tensor (nOutputPlane)
+ THTensor *finput, // internal columns buffer
+ THTensor *fgradInput, // internal ones buffer
+ int dT, int dW, int dH, // stride
+ int pT, int pW, int pH, // padding
+ int dilationT, int dilationW, int dilationH,
+ int aT, int aW, int aH, // extra output adjustment
+ accreal scale); // scaling factor
+
TH_API void THNN_(VolumetricMaxPooling_updateOutput)(
THNNState *state,
THTensor *input,
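
The declarations above fix the argument order for the new entry points. As a hedged usage sketch (not part of the diff), this is how the 2D kernel could be driven from plain C for the float instantiation, assuming the usual THNN conventions that THNN_(Name) expands to THNN_FloatName / THNN_DoubleName, that a NULL THNNState is accepted by the CPU kernels, and that the columns/ones buffers are resized internally; the include path is likewise an assumption:

  #include <stdio.h>
  #include "THNN.h"

  int main(void) {
    THFloatTensor *input   = THFloatTensor_newWithSize4d(1, 4, 8, 8); /* batch x nInputPlane x H x W */
    THFloatTensor *weight  = THFloatTensor_newWithSize4d(4, 3, 3, 3); /* nInputPlane x nOutputPlane x kH x kW */
    THFloatTensor *output  = THFloatTensor_new();
    THFloatTensor *columns = THFloatTensor_new();
    THFloatTensor *ones    = THFloatTensor_new();
    THFloatTensor_fill(input, 1);
    THFloatTensor_fill(weight, 0.5);

    THNN_FloatSpatialFullDilatedConvolution_updateOutput(
        NULL, input, output, weight, NULL /* bias is [OPTIONAL] */, columns, ones,
        3, 3,   /* kW, kH */
        2, 2,   /* dW, dH */
        1, 1,   /* padW, padH */
        2, 2,   /* dilationW, dilationH */
        0, 0);  /* adjW, adjH */

    /* Expected 1 x 3 x 17 x 17: (8 - 1) * 2 - 2 + (2 * (3 - 1) + 1) + 0 = 17. */
    printf("%ld x %ld x %ld x %ld\n",
           output->size[0], output->size[1], output->size[2], output->size[3]);

    THFloatTensor_free(input);  THFloatTensor_free(weight);
    THFloatTensor_free(output); THFloatTensor_free(columns); THFloatTensor_free(ones);
    return 0;
  }
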
diff --git a/lib/THNN/generic/VolumetricFullConvolution.c b/lib/THNN/generic/VolumetricFullConvolution.c
index c974fab..cef3c7f 100644
--- a/lib/THNN/generic/VolumetricFullConvolution.c
+++ b/lib/THNN/generic/VolumetricFullConvolution.c
@@ -2,150 +2,6 @@
#define TH_GENERIC_FILE "generic/VolumetricFullConvolution.c"
#else
-static void THNN_(vol2col)(
- const real *data_vol, const int channels,
- const int depth, const int height, const int width,
- const int kT, const int kH, const int kW,
- const int pT, const int pH, const int pW,
- const int dT, const int dH, const int dW,
- const int dilationT, const int dilationH, const int dilationW,
- real *data_col)
-{
- int c, t, h, w;
- int depth_col = (depth + 2 * pT - (dilationT * (kT - 1) + 1)) / dT + 1;
- int height_col = (height + 2 * pH - (dilationH * (kH - 1) + 1)) / dH + 1;
- int width_col = (width + 2 * pW - (dilationW * (kW - 1) + 1)) / dW + 1;
- int channels_col = channels * kT * kH * kW;
- for (c = 0; c < channels_col; ++c)
- {
- int w_offset = c % kW;
- int h_offset = (c / kW) % kH;
- int t_offset = (c / kW / kH) % kT;
- int c_vol = c / kT / kH / kW;
- for (t = 0; t < depth_col; ++t)
- {
- for (h = 0; h < height_col; ++h)
- {
- for (w = 0; w < width_col; ++w)
- {
- int t_pad = t * dT - pT + t_offset * dilationT;
- int h_pad = h * dH - pH + h_offset * dilationH;
- int w_pad = w * dW - pW + w_offset * dilationW;
- if (t_pad >= 0 && t_pad < depth &&
- h_pad >= 0 && h_pad < height &&
- w_pad >= 0 && w_pad < width)
- data_col[((c * depth_col + t) * height_col + h) * width_col + w] =
- data_vol[((c_vol * depth + t_pad) * height + h_pad) * width + w_pad];
- else
- data_col[((c * depth_col + t) * height_col + h) * width_col + w] = 0;
- }
- }
- }
- }
-}
-
-static void THNN_(col2vol)(
- const real* data_col, const int channels,
- const int depth, const int height, const int width,
- const int kT, const int kH, const int kW,
- const int pT, const int pH, const int pW,
- const int dT, const int dH, const int dW,
- const int dilationT, const int dilationH, const int dilationW,
- real* data_vol)
-{
- int c, t, h, w;
- memset(data_vol, 0, sizeof(real) * depth * height * width * channels);
- int depth_col = (depth + 2 * pT - (dilationT * (kT - 1) + 1)) / dT + 1;
- int height_col = (height + 2 * pH - (dilationH * (kH - 1) + 1)) / dH + 1;
- int width_col = (width + 2 * pW - (dilationW * (kW - 1) + 1)) / dW + 1;
- int channels_col = channels * kT * kH * kW;
- for (c = 0; c < channels_col; ++c)
- {
- int w_offset = c % kW;
- int h_offset = (c / kW) % kH;
- int t_offset = (c / kW / kH) % kT;
- int c_vol = c / kT / kH / kW;
- for (t = 0; t < depth_col; ++t)
- {
- for (h = 0; h < height_col; ++h)
- {
- for (w = 0; w < width_col; ++w)
- {
- int t_pad = t * dT - pT + t_offset * dilationT;
- int h_pad = h * dH - pH + h_offset * dilationH;
- int w_pad = w * dW - pW + w_offset * dilationW;
- if (t_pad >= 0 && t_pad < depth &&
- h_pad >= 0 && h_pad < height &&
- w_pad >= 0 && w_pad < width)
- data_vol[((c_vol * depth + t_pad) * height + h_pad) * width + w_pad] +=
- data_col[((c * depth_col + t) * height_col + h) * width_col + w];
- }
- }
- }
- }
-}
-
-static inline void THNN_(VolumetricFullConvolution_shapeCheck)(
- THTensor *input, THTensor *gradOutput,
- THTensor *weight, THTensor *bias,
- int dT, int dW, int dH, int pT, int pW, int pH,
- int aT, int aW, int aH) {
- THNN_ARGCHECK(input->nDimension == 4 || input->nDimension == 5, 2, input,
- "4D or 5D (batch mode) tensor expected for input, but got: %s");
- // number of input & output planes and kernel size is indirectly defined by the weight tensor
- THNN_ARGCHECK(weight->nDimension == 5, 4, weight,
- "5D (nOutputPlane x nInputPlane x kT x kH x kW) tensor "
- "expected for weight, but got: %s");
- THArgCheck(dT > 0 && dW > 0 && dH > 0, 11,
- "stride should be greater than zero, but got dT: %d dH: %d dW: %d", dT, dH, dW);
- THArgCheck(aT < dT && aW < dW && aH < dH, 15,
- "output adjustment must be smaller than stride, but got "
- "adjT: %d adjH: %d adjW: %d dT: %d dH: %d dW: %d",
- aT, aH, aW, dT, dH, dW);
-
- int ndim = input->nDimension;
- const int nInputPlane = (int)weight->size[0];
- const int nOutputPlane = (int)weight->size[1];
- const int kT = (int)weight->size[2];
- const int kH = (int)weight->size[3];
- const int kW = (int)weight->size[4];
-
- if (bias != NULL) {
- THNN_CHECK_DIM_SIZE(bias, 1, 0, weight->size[1]);
- }
-
- int dimf = 0;
- int dimd = 1;
- int dimh = 2;
- int dimw = 3;
-
- if (ndim == 5) {
- dimf++;
- dimd++;
- dimh++;
- dimw++;
- }
-
- const long inputWidth = input->size[dimw];
- const long inputHeight = input->size[dimh];
- const long inputDepth = input->size[dimd];
- const long outputWidth = (inputWidth - 1) * dW - 2*pW + kW + aW;
- const long outputHeight = (inputHeight - 1) * dH - 2*pH + kH + aH;
- const long outputDepth = (inputDepth - 1) * dT - 2*pT + kT + aT;
-
- if (outputDepth < 1 || outputWidth < 1 || outputHeight < 1)
- THError("Given input size: (%dx%dx%dx%d). Calculated output size: (%dx%dx%dx%d). Output size is too small",
- nInputPlane,inputDepth,inputHeight,inputWidth,nOutputPlane,outputDepth,outputHeight,outputWidth);
-
- THNN_CHECK_DIM_SIZE(input, ndim, dimf, nInputPlane);
- if (gradOutput != NULL) {
- THNN_CHECK_DIM_SIZE(gradOutput, ndim, dimf, nOutputPlane);
- THNN_CHECK_DIM_SIZE(gradOutput, ndim, dimd, outputDepth);
- THNN_CHECK_DIM_SIZE(gradOutput, ndim, dimh, outputHeight);
- THNN_CHECK_DIM_SIZE(gradOutput, ndim, dimw, outputWidth);
- }
-}
-
void THNN_(VolumetricFullConvolution_updateOutput)(
THNNState *state,
THTensor *input, // 4D or 5D (batch) tensor
@@ -158,132 +14,9 @@ void THNN_(VolumetricFullConvolution_updateOutput)(
int pT, int pW, int pH, // padding
int aT, int aW, int aH) // extra output adjustment
{
- THTensor *columns = finput;
- THTensor *ones = fgradInput;
-
- THNN_(VolumetricFullConvolution_shapeCheck)(
- input, NULL, weight, bias,
- dT, dW, dH, pT, pW, pH, aT, aW, aH);
-
- const int nInputPlane = (int)weight->size[0];
- const int nOutputPlane = (int)weight->size[1];
- const int kT = (int)weight->size[2];
- const int kH = (int)weight->size[3];
- const int kW = (int)weight->size[4];
-
- input = THTensor_(newContiguous)(input);
- weight = THTensor_(newContiguous)(weight);
- bias = bias ? THTensor_(newContiguous)(bias) : bias;
- int batch = 1;
- if (input->nDimension == 4)
- {
- // Force batch
- batch = 0;
- THTensor_(resize5d)(input, 1, input->size[0], input->size[1], input->size[2], input->size[3]);
- }
-
- const long inputWidth = input->size[4];
- const long inputHeight = input->size[3];
- const long inputDepth = input->size[2];
- const long outputWidth = (inputWidth - 1) * dW - 2*pW + kW + aW;
- const long outputHeight = (inputHeight - 1) * dH - 2*pH + kH + aH;
- const long outputDepth = (inputDepth - 1) * dT - 2*pT + kT + aT;
-
- // Batch size + input planes
- const long batchSize = input->size[0];
-
- // Resize output
- THTensor_(resize5d)(output, batchSize, nOutputPlane, outputDepth, outputHeight, outputWidth);
-
- // Resize temporary columns
- THTensor_(resize2d)(columns, nOutputPlane*kW*kH*kT, inputDepth*inputHeight*inputWidth);
- THTensor_(zero)(columns);
-
- // Define a buffer of ones, for bias accumulation
- // Note: this buffer can be shared with other modules, it only ever gets increased,
- // and always contains ones.
- if (ones->nDimension != 3 || ones->size[0]*ones->size[1]*ones->size[2] < outputDepth*outputHeight*outputWidth)
- {
- // Resize plane and fill with ones...
- THTensor_(resize3d)(ones, outputDepth, outputHeight, outputWidth);
- THTensor_(fill)(ones, 1);
- }
-
- // Helpers
- THTensor *input_n = THTensor_(new)();
- THTensor *output_n = THTensor_(new)();
-
- int elt;
- // For each elt in batch, do:
- for (elt = 0; elt < batchSize; ++elt)
- {
- // Matrix mulitply per output:
- THTensor_(select)(input_n, input, 0, elt);
- THTensor_(select)(output_n, output, 0, elt);
-
- // M,N,K are dims of matrix A and B
- // (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm)
- const long m = weight->size[1] * weight->size[2] * weight->size[3] * weight->size[4];
- const long n = columns->size[1];
- const long k = weight->size[0];
-
- // Do GEMM (note: this is a bit confusing because gemm assumes column-major matrices)
- THBlas_(gemm)(
- 'n', 't',
- n, m, k,
- 1,
- THTensor_(data)(input_n), n,
- THTensor_(data)(weight), m,
- 0,
- THTensor_(data)(columns), n
- );
-
- // Unpack columns back into input:
- THNN_(col2vol)(
- THTensor_(data)(columns),
- nOutputPlane, outputDepth, outputHeight, outputWidth,
- kT, kH, kW,
- pT, pH, pW,
- dT, dH, dW,
- 1, 1, 1,
- THTensor_(data)(output_n)
- );
-
- // Do Bias after:
- // M,N,K are dims of matrix A and B
- // (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm)
- const long m_ = nOutputPlane;
- const long n_ = outputDepth * outputHeight * outputWidth;
- const long k_ = 1;
-
- // Do GEMM (note: this is a bit confusing because gemm assumes column-major matrices)
- if (bias) {
- THBlas_(gemm)(
- 't', 'n',
- n_, m_, k_,
- 1,
- THTensor_(data)(ones), k_,
- THTensor_(data)(bias), k_,
- 1,
- THTensor_(data)(output_n), n_
- );
- }
- }
-
- // Free
- THTensor_(free)(input_n);
- THTensor_(free)(output_n);
-
- // Resize output
- if (batch == 0)
- {
- THTensor_(resize4d)(output, nOutputPlane, outputDepth, outputHeight, outputWidth);
- THTensor_(resize4d)(input, nInputPlane, inputDepth, inputHeight, inputWidth);
- }
-
- THTensor_(free)(input);
- THTensor_(free)(weight);
- if (bias) THTensor_(free)(bias);
+ THNN_(VolumetricFullDilatedConvolution_updateOutput)(
+ state, input, output, weight, bias, finput, fgradInput,
+ dT, dW, dH, pT, pW, pH, 1, 1, 1, aT, aW, aH);
}
void THNN_(VolumetricFullConvolution_updateGradInput)(
@@ -298,105 +31,9 @@ void THNN_(VolumetricFullConvolution_updateGradInput)(
int pT, int pW, int pH, // padding
int aT, int aW, int aH) // extra output adjustment
{
- THTensor *gradColumns = finput;
-
- // number of input & output planes and kernel size is indirectly defined by the weight tensor
- THNN_(VolumetricFullConvolution_shapeCheck)(
- input, gradOutput, weight, NULL,
- dT, dW, dH, pT, pW, pH, aT, aW, aH);
-
- const int nInputPlane = (int)weight->size[0];
- const int nOutputPlane = (int)weight->size[1];
- const int kT = (int)weight->size[2];
- const int kH = (int)weight->size[3];
- const int kW = (int)weight->size[4];
-
- input = THTensor_(newContiguous)(input);
- weight = THTensor_(newContiguous)(weight);
- gradOutput = THTensor_(newContiguous)(gradOutput);
-
- int batch = 1;
- if (input->nDimension == 4)
- {
- // Force batch
- batch = 0;
- THTensor_(resize5d)(input, 1, input->size[0], input->size[1], input->size[2], input->size[3]);
- THTensor_(resize5d)(gradOutput, 1, gradOutput->size[0], gradOutput->size[1], gradOutput->size[2], gradOutput->size[3]);
- }
-
- const long inputWidth = input->size[4];
- const long inputHeight = input->size[3];
- const long inputDepth = input->size[2];
- const long outputWidth = (inputWidth - 1) * dW - 2*pW + kW + aW;
- const long outputHeight = (inputHeight - 1) * dH - 2*pH + kH + aH;
- const long outputDepth = (inputDepth - 1) * dT - 2*pT + kT + aT;
-
- // Batch size + input planes
- const long batchSize = input->size[0];
-
- // Resize output
- THTensor_(resize5d)(gradInput, batchSize, nInputPlane, inputDepth, inputHeight, inputWidth);
- THTensor_(zero)(gradInput);
-
- // Resize temporary columns
- THTensor_(resize2d)(gradColumns, nOutputPlane*kW*kH*kT, inputDepth*inputHeight*inputWidth);
-
- // Helpers
- THTensor *gradInput_n = THTensor_(new)();
- THTensor *gradOutput_n = THTensor_(new)();
-
- int elt;
- // For each elt in batch, do:
- for (elt = 0; elt < batchSize; ++elt)
- {
- // Matrix mulitply per sample:
- THTensor_(select)(gradInput_n, gradInput, 0, elt);
- THTensor_(select)(gradOutput_n, gradOutput, 0, elt);
-
- // Extract columns:
- THNN_(vol2col)(
- THTensor_(data)(gradOutput_n),
- nOutputPlane, outputDepth, outputHeight, outputWidth,
- kT, kH, kW,
- pT, pH, pW,
- dT, dH, dW,
- 1, 1, 1,
- THTensor_(data)(gradColumns)
- );
-
- // M,N,K are dims of matrix A and B
- // (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm)
- const long m = weight->size[0];
- const long n = gradColumns->size[1];
- const long k = weight->size[1] * weight->size[2] * weight->size[3] * weight->size[4];
-
- // Do GEMM (note: this is a bit confusing because gemm assumes column-major matrices)
- THBlas_(gemm)(
- 'n', 'n',
- n, m, k,
- 1,
- THTensor_(data)(gradColumns), n,
- THTensor_(data)(weight), k,
- 0,
- THTensor_(data)(gradInput_n), n
- );
- }
-
- // Free
- THTensor_(free)(gradInput_n);
- THTensor_(free)(gradOutput_n);
-
- // Resize output
- if (batch == 0)
- {
- THTensor_(resize4d)(gradOutput, nOutputPlane, outputDepth, outputHeight, outputWidth);
- THTensor_(resize4d)(input, nInputPlane, inputDepth, inputHeight, inputWidth);
- THTensor_(resize4d)(gradInput, nInputPlane, inputDepth, inputHeight, inputWidth);
- }
-
- THTensor_(free)(input);
- THTensor_(free)(gradOutput);
- THTensor_(free)(weight);
+ THNN_(VolumetricFullDilatedConvolution_updateGradInput)(
+ state, input, gradOutput, gradInput, weight, finput, fgradInput,
+ dT, dW, dH, pT, pW, pH, 1, 1, 1, aT, aW, aH);
}
void THNN_(VolumetricFullConvolution_accGradParameters)(
@@ -412,130 +49,9 @@ void THNN_(VolumetricFullConvolution_accGradParameters)(
int aT, int aW, int aH, // extra output adjustment
accreal scale_)
{
- real scale = TH_CONVERT_ACCREAL_TO_REAL(scale_);
- // number of input & output planes and kernel size is indirectly defined by the gradWeight tensor
- THNN_(VolumetricFullConvolution_shapeCheck)(
- input, gradOutput, gradWeight, gradBias,
- dT, dW, dH, pT, pW, pH, aT, aW, aH);
-
- int nInputPlane = (int)gradWeight->size[0];
- int nOutputPlane = (int)gradWeight->size[1];
- int kT = (int)gradWeight->size[2];
- int kH = (int)gradWeight->size[3];
- int kW = (int)gradWeight->size[4];
-
- THTensor *columns = finput;
- THTensor *ones = fgradInput;
-
- input = THTensor_(newContiguous)(input);
- gradOutput = THTensor_(newContiguous)(gradOutput);
- THArgCheck(THTensor_(isContiguous)(gradWeight), 4, "gradWeight needs to be contiguous");
- if (gradBias)
- THArgCheck(THTensor_(isContiguous)(gradBias), 5, "gradBias needs to be contiguous");
-
- int batch = 1;
- if (input->nDimension == 4)
- {
- // Force batch
- batch = 0;
- THTensor_(resize5d)(input, 1, input->size[0], input->size[1], input->size[2], input->size[3]);
- THTensor_(resize5d)(gradOutput, 1, gradOutput->size[0], gradOutput->size[1], gradOutput->size[2], gradOutput->size[3]);
- }
-
- const long inputWidth = input->size[4];
- const long inputHeight = input->size[3];
- const long inputDepth = input->size[2];
- const long outputWidth = (inputWidth - 1) * dW - 2*pW + kW + aW;
- const long outputHeight = (inputHeight - 1) * dH - 2*pH + kH + aH;
- const long outputDepth = (inputDepth - 1) * dT - 2*pT + kT + aT;
-
- // Batch size + input planes
- const long batchSize = input->size[0];
-
- // Define a buffer of ones, for bias accumulation
- if (ones->nDimension != 3 || ones->size[0]*ones->size[1]*ones->size[2] < outputDepth*outputHeight*outputWidth)
- {
- // Resize plane and fill with ones...
- THTensor_(resize3d)(ones, outputDepth, outputHeight, outputWidth);
- THTensor_(fill)(ones, 1);
- }
-
- // Resize temporary columns
- THTensor_(resize2d)(columns, nOutputPlane*kW*kH*kT, inputDepth*inputHeight*inputWidth);
-
- // Helpers
- THTensor *input_n = THTensor_(new)();
- THTensor *gradOutput_n = THTensor_(new)();
-
- int elt;
- // For each elt in batch, do:
- for (elt = 0; elt < batchSize; ++elt)
- {
- // Matrix mulitply per output:
- THTensor_(select)(input_n, input, 0, elt);
- THTensor_(select)(gradOutput_n, gradOutput, 0, elt);
-
- // Extract columns:
- THNN_(vol2col)(
- THTensor_(data)(gradOutput_n), nOutputPlane,
- outputDepth, outputHeight, outputWidth,
- kT, kH, kW,
- pT, pH, pW,
- dT, dH, dW,
- 1, 1, 1,
- THTensor_(data)(columns)
- );
-
- // M,N,K are dims of matrix A and B
- // (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm)
- const long n = columns->size[0]; // nOutputPlane * kt * kh * kw
- const long m = input_n->size[0]; // nInputPlane
- const long k = columns->size[1]; // inputHeight * inputWidth
-
- // Do GEMM (note: this is a bit confusing because gemm assumes column-major matrices)
- THBlas_(gemm)(
- 't', 'n',
- n, m, k,
- scale,
- THTensor_(data)(columns), k,
- THTensor_(data)(input_n), k,
- 1,
- THTensor_(data)(gradWeight), n
- );
-
- // Do Bias:
- // M,N,K are dims of matrix A and B
- // (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm)
- const long m_ = nOutputPlane;
- const long k_ = outputDepth * outputHeight * outputWidth;
-
- // Do GEMV (note: this is a bit confusing because gemv assumes column-major matrices)
- if (gradBias) {
- THBlas_(gemv)(
- 't',
- k_, m_,
- scale,
- THTensor_(data)(gradOutput_n), k_,
- THTensor_(data)(ones), 1,
- 1,
- THTensor_(data)(gradBias), 1
- );
- }
- }
-
- // Free
- THTensor_(free)(input_n);
- THTensor_(free)(gradOutput_n);
-
- // Resize
- if (batch == 0)
- {
- THTensor_(resize4d)(gradOutput, nOutputPlane, outputDepth, outputHeight, outputWidth);
- THTensor_(resize4d)(input, nInputPlane, inputDepth, inputHeight, inputWidth);
- }
-
- THTensor_(free)(input);
- THTensor_(free)(gradOutput);
+ THNN_(VolumetricFullDilatedConvolution_accGradParameters)(
+ state, input, gradOutput, gradWeight, gradBias, finput, fgradInput,
+ dT, dW, dH, pT, pW, pH, 1, 1, 1, aT, aW, aH, scale_);
}
#endif
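
The volumetric path mirrors the 2D one: the wrappers above forward to the dilated kernels with dilationT = dilationW = dilationH = 1, and the new file below carries the vol2col/col2vol helpers. One practical point visible in the buffer resizing is that the temporary columns tensor scales with the kernel volume times the input volume, independent of dilation; a small bookkeeping sketch (illustrative helper, not part of THNN):

  /* Elements in the temporary `columns` buffer the volumetric kernels resize
   * internally: (nOutputPlane * kT * kH * kW) rows by
   * (inputDepth * inputHeight * inputWidth) columns, per sample. Dilation
   * changes where the taps land in the output volume, not this buffer. */
  static long vol_columns_elements(long nOutputPlane, int kT, int kH, int kW,
                                   long inputDepth, long inputHeight, long inputWidth)
  {
    return (nOutputPlane * kT * kH * kW) * (inputDepth * inputHeight * inputWidth);
  }

For example, nOutputPlane = 16 with a 3x3x3 kernel and an 8x32x32 input volume needs 432 * 8192 = 3,538,944 elements per sample.
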
diff --git a/lib/THNN/generic/VolumetricFullDilatedConvolution.c b/lib/THNN/generic/VolumetricFullDilatedConvolution.c
new file mode 100644
index 0000000..4e22d38
--- /dev/null
+++ b/lib/THNN/generic/VolumetricFullDilatedConvolution.c
@@ -0,0 +1,548 @@
+#ifndef TH_GENERIC_FILE
+#define TH_GENERIC_FILE "generic/VolumetricFullDilatedConvolution.c"
+#else
+
+static void THNN_(vol2col)(
+ const real *data_vol, const int channels,
+ const int depth, const int height, const int width,
+ const int kT, const int kH, const int kW,
+ const int pT, const int pH, const int pW,
+ const int dT, const int dH, const int dW,
+ const int dilationT, const int dilationH, const int dilationW,
+ real *data_col)
+{
+ int c, t, h, w;
+ int depth_col = (depth + 2 * pT - (dilationT * (kT - 1) + 1)) / dT + 1;
+ int height_col = (height + 2 * pH - (dilationH * (kH - 1) + 1)) / dH + 1;
+ int width_col = (width + 2 * pW - (dilationW * (kW - 1) + 1)) / dW + 1;
+ int channels_col = channels * kT * kH * kW;
+ for (c = 0; c < channels_col; ++c)
+ {
+ int w_offset = c % kW;
+ int h_offset = (c / kW) % kH;
+ int t_offset = (c / kW / kH) % kT;
+ int c_vol = c / kT / kH / kW;
+ for (t = 0; t < depth_col; ++t)
+ {
+ for (h = 0; h < height_col; ++h)
+ {
+ for (w = 0; w < width_col; ++w)
+ {
+ int t_pad = t * dT - pT + t_offset * dilationT;
+ int h_pad = h * dH - pH + h_offset * dilationH;
+ int w_pad = w * dW - pW + w_offset * dilationW;
+ if (t_pad >= 0 && t_pad < depth &&
+ h_pad >= 0 && h_pad < height &&
+ w_pad >= 0 && w_pad < width)
+ data_col[((c * depth_col + t) * height_col + h) * width_col + w] =
+ data_vol[((c_vol * depth + t_pad) * height + h_pad) * width + w_pad];
+ else
+ data_col[((c * depth_col + t) * height_col + h) * width_col + w] = 0;
+ }
+ }
+ }
+ }
+}
+
+static void THNN_(col2vol)(
+ const real* data_col, const int channels,
+ const int depth, const int height, const int width,
+ const int kT, const int kH, const int kW,
+ const int pT, const int pH, const int pW,
+ const int dT, const int dH, const int dW,
+ const int dilationT, const int dilationH, const int dilationW,
+ real* data_vol)
+{
+ int c, t, h, w;
+ memset(data_vol, 0, sizeof(real) * depth * height * width * channels);
+ int depth_col = (depth + 2 * pT - (dilationT * (kT - 1) + 1)) / dT + 1;
+ int height_col = (height + 2 * pH - (dilationH * (kH - 1) + 1)) / dH + 1;
+ int width_col = (width + 2 * pW - (dilationW * (kW - 1) + 1)) / dW + 1;
+ int channels_col = channels * kT * kH * kW;
+ for (c = 0; c < channels_col; ++c)
+ {
+ int w_offset = c % kW;
+ int h_offset = (c / kW) % kH;
+ int t_offset = (c / kW / kH) % kT;
+ int c_vol = c / kT / kH / kW;
+ for (t = 0; t < depth_col; ++t)
+ {
+ for (h = 0; h < height_col; ++h)
+ {
+ for (w = 0; w < width_col; ++w)
+ {
+ int t_pad = t * dT - pT + t_offset * dilationT;
+ int h_pad = h * dH - pH + h_offset * dilationH;
+ int w_pad = w * dW - pW + w_offset * dilationW;
+ if (t_pad >= 0 && t_pad < depth &&
+ h_pad >= 0 && h_pad < height &&
+ w_pad >= 0 && w_pad < width)
+ data_vol[((c_vol * depth + t_pad) * height + h_pad) * width + w_pad] +=
+ data_col[((c * depth_col + t) * height_col + h) * width_col + w];
+ }
+ }
+ }
+ }
+}
+
+static inline void THNN_(VolumetricFullDilatedConvolution_shapeCheck)(
+ THTensor *input, THTensor *gradOutput,
+ THTensor *weight, THTensor *bias,
+ int dT, int dW, int dH, int pT, int pW, int pH,
+ int dilationT, int dilationW, int dilationH,
+ int aT, int aW, int aH) {
+ THNN_ARGCHECK(input->nDimension == 4 || input->nDimension == 5, 2, input,
+ "4D or 5D (batch mode) tensor expected for input, but got: %s");
+ // number of input & output planes and kernel size is indirectly defined by the weight tensor
+ THNN_ARGCHECK(weight->nDimension == 5, 4, weight,
+ "5D (nOutputPlane x nInputPlane x kT x kH x kW) tensor "
+ "expected for weight, but got: %s");
+ THArgCheck(dT > 0 && dW > 0 && dH > 0, 11,
+ "stride should be greater than zero, but got dT: %d dH: %d dW: %d", dT, dH, dW);
+ THArgCheck(aT < dT && aW < dW && aH < dH, 15,
+ "output adjustment must be smaller than stride, but got "
+ "adjT: %d adjH: %d adjW: %d dT: %d dH: %d dW: %d",
+ aT, aH, aW, dT, dH, dW);
+ THArgCheck(dilationT > 0 && dilationW > 0 && dilationH > 0, 15,
+ "dilation should be greater than zero, but got dilationT: %d, dilationH: %d, dilationW: %d",
+ dilationT, dilationH, dilationW);
+
+ int ndim = input->nDimension;
+ const int nInputPlane = (int)weight->size[0];
+ const int nOutputPlane = (int)weight->size[1];
+ const int kT = (int)weight->size[2];
+ const int kH = (int)weight->size[3];
+ const int kW = (int)weight->size[4];
+
+ if (bias != NULL) {
+ THNN_CHECK_DIM_SIZE(bias, 1, 0, weight->size[1]);
+ }
+
+ int dimf = 0;
+ int dimd = 1;
+ int dimh = 2;
+ int dimw = 3;
+
+ if (ndim == 5) {
+ dimf++;
+ dimd++;
+ dimh++;
+ dimw++;
+ }
+
+ const long inputWidth = input->size[dimw];
+ const long inputHeight = input->size[dimh];
+ const long inputDepth = input->size[dimd];
+ const long outputDepth = (inputDepth - 1) * dT - 2*pT + (dilationT * (kT - 1) + 1) + aT;
+ const long outputHeight = (inputHeight - 1) * dH - 2*pH + (dilationH * (kH - 1) + 1) + aH;
+ const long outputWidth = (inputWidth - 1) * dW - 2*pW + (dilationW * (kW - 1) + 1) + aW;
+
+ if (outputDepth < 1 || outputWidth < 1 || outputHeight < 1)
+ THError("Given input size: (%dx%dx%dx%d). Calculated output size: (%dx%dx%dx%d). Output size is too small",
+ nInputPlane,inputDepth,inputHeight,inputWidth,nOutputPlane,outputDepth,outputHeight,outputWidth);
+
+ THNN_CHECK_DIM_SIZE(input, ndim, dimf, nInputPlane);
+ if (gradOutput != NULL) {
+ THNN_CHECK_DIM_SIZE(gradOutput, ndim, dimf, nOutputPlane);
+ THNN_CHECK_DIM_SIZE(gradOutput, ndim, dimd, outputDepth);
+ THNN_CHECK_DIM_SIZE(gradOutput, ndim, dimh, outputHeight);
+ THNN_CHECK_DIM_SIZE(gradOutput, ndim, dimw, outputWidth);
+ }
+}
+
+void THNN_(VolumetricFullDilatedConvolution_updateOutput)(
+ THNNState *state,
+ THTensor *input, // 4D or 5D (batch) tensor
+ THTensor *output,
+ THTensor *weight, // weight tensor (nInputPlane x nOutputPlane x kT x kH x kW)
+ THTensor *bias,
+ THTensor *finput, // internal columns buffer
+ THTensor *fgradInput, // internal ones buffer
+ int dT, int dW, int dH, // stride of the convolution
+ int pT, int pW, int pH, // padding
+ int dilationT, int dilationW, int dilationH,
+ int aT, int aW, int aH) // extra output adjustment
+{
+ THTensor *columns = finput;
+ THTensor *ones = fgradInput;
+
+ THNN_(VolumetricFullDilatedConvolution_shapeCheck)(
+ input, NULL, weight, bias,
+ dT, dW, dH, pT, pW, pH, dilationT, dilationW, dilationH, aT, aW, aH);
+
+ const int nInputPlane = (int)weight->size[0];
+ const int nOutputPlane = (int)weight->size[1];
+ const int kT = (int)weight->size[2];
+ const int kH = (int)weight->size[3];
+ const int kW = (int)weight->size[4];
+
+ input = THTensor_(newContiguous)(input);
+ weight = THTensor_(newContiguous)(weight);
+ bias = bias ? THTensor_(newContiguous)(bias) : bias;
+ int batch = 1;
+ if (input->nDimension == 4)
+ {
+ // Force batch
+ batch = 0;
+ THTensor_(resize5d)(input, 1, input->size[0], input->size[1], input->size[2], input->size[3]);
+ }
+
+ const long inputWidth = input->size[4];
+ const long inputHeight = input->size[3];
+ const long inputDepth = input->size[2];
+ const long outputDepth = (inputDepth - 1) * dT - 2*pT + (dilationT * (kT - 1) + 1) + aT;
+ const long outputHeight = (inputHeight - 1) * dH - 2*pH + (dilationH * (kH - 1) + 1) + aH;
+ const long outputWidth = (inputWidth - 1) * dW - 2*pW + (dilationW * (kW - 1) + 1) + aW;
+
+  // Batch size
+ const long batchSize = input->size[0];
+
+ // Resize output
+ THTensor_(resize5d)(output, batchSize, nOutputPlane, outputDepth, outputHeight, outputWidth);
+
+ // Resize temporary columns
+ THTensor_(resize2d)(columns, nOutputPlane*kW*kH*kT, inputDepth*inputHeight*inputWidth);
+ THTensor_(zero)(columns);
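+  // columns has one row per (output plane, kernel offset) pair and one column per
+  // input voxel; the GEMM below fills it with the fully-connected view of the
+  // transposed convolution, and col2vol then scatters it into the output volume.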
+
+ // Define a buffer of ones, for bias accumulation
+  // Note: this buffer can be shared with other modules; it only ever grows and always contains ones.
+ if (ones->nDimension != 3 || ones->size[0]*ones->size[1]*ones->size[2] < outputDepth*outputHeight*outputWidth)
+ {
+ // Resize plane and fill with ones...
+ THTensor_(resize3d)(ones, outputDepth, outputHeight, outputWidth);
+ THTensor_(fill)(ones, 1);
+ }
+
+ // Helpers
+ THTensor *input_n = THTensor_(new)();
+ THTensor *output_n = THTensor_(new)();
+
+ int elt;
+ // For each elt in batch, do:
+ for (elt = 0; elt < batchSize; ++elt)
+ {
+    // Matrix multiply per output:
+ THTensor_(select)(input_n, input, 0, elt);
+ THTensor_(select)(output_n, output, 0, elt);
+
+ // M,N,K are dims of matrix A and B
+ // (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm)
+ const long m = weight->size[1] * weight->size[2] * weight->size[3] * weight->size[4];
+ const long n = columns->size[1];
+ const long k = weight->size[0];
+
+ // Do GEMM (note: this is a bit confusing because gemm assumes column-major matrices)
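+    // In row-major terms this computes: columns (m x n) = weight^T (m x k) * input_n (k x n);
+    // the row-major buffers are handed to gemm as if they were transposed column-major matrices.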
+ THBlas_(gemm)(
+ 'n', 't',
+ n, m, k,
+ 1,
+ THTensor_(data)(input_n), n,
+ THTensor_(data)(weight), m,
+ 0,
+ THTensor_(data)(columns), n
+ );
+
+ // Unpack columns back into input:
+ THNN_(col2vol)(
+ THTensor_(data)(columns),
+ nOutputPlane, outputDepth, outputHeight, outputWidth,
+ kT, kH, kW,
+ pT, pH, pW,
+ dT, dH, dW,
+ dilationT, dilationH, dilationW,
+ THTensor_(data)(output_n)
+ );
+
+ // Do Bias after:
+ // M,N,K are dims of matrix A and B
+ // (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm)
+ const long m_ = nOutputPlane;
+ const long n_ = outputDepth * outputHeight * outputWidth;
+ const long k_ = 1;
+
+ // Do GEMM (note: this is a bit confusing because gemm assumes column-major matrices)
+ if (bias) {
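+      // Rank-1 update: output_n[c][s] += bias[c], with ones supplying the all-ones
+      // vector over the outputDepth * outputHeight * outputWidth positions.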
+ THBlas_(gemm)(
+ 't', 'n',
+ n_, m_, k_,
+ 1,
+ THTensor_(data)(ones), k_,
+ THTensor_(data)(bias), k_,
+ 1,
+ THTensor_(data)(output_n), n_
+ );
+ }
+ }
+
+ // Free
+ THTensor_(free)(input_n);
+ THTensor_(free)(output_n);
+
+ // Resize output
+ if (batch == 0)
+ {
+ THTensor_(resize4d)(output, nOutputPlane, outputDepth, outputHeight, outputWidth);
+ THTensor_(resize4d)(input, nInputPlane, inputDepth, inputHeight, inputWidth);
+ }
+
+ THTensor_(free)(input);
+ THTensor_(free)(weight);
+ if (bias) THTensor_(free)(bias);
+}
+
+void THNN_(VolumetricFullDilatedConvolution_updateGradInput)(
+ THNNState *state,
+ THTensor *input,
+ THTensor *gradOutput,
+ THTensor *gradInput,
+ THTensor *weight,
+ THTensor *finput,
+ THTensor *fgradInput, // only used by cuda impl
+ int dT, int dW, int dH, // stride
+ int pT, int pW, int pH, // padding
+ int dilationT, int dilationW, int dilationH,
+ int aT, int aW, int aH) // extra output adjustment
+{
+ THTensor *gradColumns = finput;
+
+  // the number of input & output planes and the kernel size are indirectly defined by the weight tensor
+ THNN_(VolumetricFullDilatedConvolution_shapeCheck)(
+ input, gradOutput, weight, NULL,
+ dT, dW, dH, pT, pW, pH, dilationT, dilationW, dilationH, aT, aW, aH);
+
+ const int nInputPlane = (int)weight->size[0];
+ const int nOutputPlane = (int)weight->size[1];
+ const int kT = (int)weight->size[2];
+ const int kH = (int)weight->size[3];
+ const int kW = (int)weight->size[4];
+
+ input = THTensor_(newContiguous)(input);
+ weight = THTensor_(newContiguous)(weight);
+ gradOutput = THTensor_(newContiguous)(gradOutput);
+
+ int batch = 1;
+ if (input->nDimension == 4)
+ {
+ // Force batch
+ batch = 0;
+ THTensor_(resize5d)(input, 1, input->size[0], input->size[1], input->size[2], input->size[3]);
+ THTensor_(resize5d)(gradOutput, 1, gradOutput->size[0], gradOutput->size[1], gradOutput->size[2], gradOutput->size[3]);
+ }
+
+ const long inputWidth = input->size[4];
+ const long inputHeight = input->size[3];
+ const long inputDepth = input->size[2];
+ const long outputDepth = (inputDepth - 1) * dT - 2*pT + (dilationT * (kT - 1) + 1) + aT;
+ const long outputHeight = (inputHeight - 1) * dH - 2*pH + (dilationH * (kH - 1) + 1) + aH;
+ const long outputWidth = (inputWidth - 1) * dW - 2*pW + (dilationW * (kW - 1) + 1) + aW;
+
+  // Batch size
+ const long batchSize = input->size[0];
+
+  // Resize gradInput
+ THTensor_(resize5d)(gradInput, batchSize, nInputPlane, inputDepth, inputHeight, inputWidth);
+ THTensor_(zero)(gradInput);
+
+ // Resize temporary columns
+ THTensor_(resize2d)(gradColumns, nOutputPlane*kW*kH*kT, inputDepth*inputHeight*inputWidth);
+
+ // Helpers
+ THTensor *gradInput_n = THTensor_(new)();
+ THTensor *gradOutput_n = THTensor_(new)();
+
+ int elt;
+ // For each elt in batch, do:
+ for (elt = 0; elt < batchSize; ++elt)
+ {
+    // Matrix multiply per sample:
+ THTensor_(select)(gradInput_n, gradInput, 0, elt);
+ THTensor_(select)(gradOutput_n, gradOutput, 0, elt);
+
+ // Extract columns:
+ THNN_(vol2col)(
+ THTensor_(data)(gradOutput_n),
+ nOutputPlane, outputDepth, outputHeight, outputWidth,
+ kT, kH, kW,
+ pT, pH, pW,
+ dT, dH, dW,
+ dilationT, dilationH, dilationW,
+ THTensor_(data)(gradColumns)
+ );
+
+ // M,N,K are dims of matrix A and B
+ // (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm)
+ const long m = weight->size[0];
+ const long n = gradColumns->size[1];
+ const long k = weight->size[1] * weight->size[2] * weight->size[3] * weight->size[4];
+
+ // Do GEMM (note: this is a bit confusing because gemm assumes column-major matrices)
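+    // In row-major terms: gradInput_n (m x n) = weight (m x k) * gradColumns (k x n);
+    // the backward pass of a full convolution is an ordinary dilated convolution of gradOutput.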
+ THBlas_(gemm)(
+ 'n', 'n',
+ n, m, k,
+ 1,
+ THTensor_(data)(gradColumns), n,
+ THTensor_(data)(weight), k,
+ 0,
+ THTensor_(data)(gradInput_n), n
+ );
+ }
+
+ // Free
+ THTensor_(free)(gradInput_n);
+ THTensor_(free)(gradOutput_n);
+
+ // Resize output
+ if (batch == 0)
+ {
+ THTensor_(resize4d)(gradOutput, nOutputPlane, outputDepth, outputHeight, outputWidth);
+ THTensor_(resize4d)(input, nInputPlane, inputDepth, inputHeight, inputWidth);
+ THTensor_(resize4d)(gradInput, nInputPlane, inputDepth, inputHeight, inputWidth);
+ }
+
+ THTensor_(free)(input);
+ THTensor_(free)(gradOutput);
+ THTensor_(free)(weight);
+}
+
+void THNN_(VolumetricFullDilatedConvolution_accGradParameters)(
+ THNNState *state,
+ THTensor *input,
+ THTensor *gradOutput,
+ THTensor *gradWeight,
+ THTensor *gradBias,
+ THTensor *finput,
+ THTensor *fgradInput,
+ int dT, int dW, int dH, // stride
+ int pT, int pW, int pH, // padding
+ int dilationT, int dilationW, int dilationH,
+ int aT, int aW, int aH, // extra output adjustment
+ accreal scale_)
+{
+ real scale = TH_CONVERT_ACCREAL_TO_REAL(scale_);
+  // the number of input & output planes and the kernel size are indirectly defined by the gradWeight tensor
+ THNN_(VolumetricFullDilatedConvolution_shapeCheck)(
+ input, gradOutput, gradWeight, gradBias,
+ dT, dW, dH, pT, pW, pH, dilationT, dilationW, dilationH, aT, aW, aH);
+
+ int nInputPlane = (int)gradWeight->size[0];
+ int nOutputPlane = (int)gradWeight->size[1];
+ int kT = (int)gradWeight->size[2];
+ int kH = (int)gradWeight->size[3];
+ int kW = (int)gradWeight->size[4];
+
+ THTensor *columns = finput;
+ THTensor *ones = fgradInput;
+
+ input = THTensor_(newContiguous)(input);
+ gradOutput = THTensor_(newContiguous)(gradOutput);
+ THArgCheck(THTensor_(isContiguous)(gradWeight), 4, "gradWeight needs to be contiguous");
+ if (gradBias)
+ THArgCheck(THTensor_(isContiguous)(gradBias), 5, "gradBias needs to be contiguous");
+
+ int batch = 1;
+ if (input->nDimension == 4)
+ {
+ // Force batch
+ batch = 0;
+ THTensor_(resize5d)(input, 1, input->size[0], input->size[1], input->size[2], input->size[3]);
+ THTensor_(resize5d)(gradOutput, 1, gradOutput->size[0], gradOutput->size[1], gradOutput->size[2], gradOutput->size[3]);
+ }
+
+ const long inputWidth = input->size[4];
+ const long inputHeight = input->size[3];
+ const long inputDepth = input->size[2];
+ const long outputDepth = (inputDepth - 1) * dT - 2*pT + (dilationT * (kT - 1) + 1) + aT;
+ const long outputHeight = (inputHeight - 1) * dH - 2*pH + (dilationH * (kH - 1) + 1) + aH;
+ const long outputWidth = (inputWidth - 1) * dW - 2*pW + (dilationW * (kW - 1) + 1) + aW;
+
+  // Batch size
+ const long batchSize = input->size[0];
+
+ // Define a buffer of ones, for bias accumulation
+ if (ones->nDimension != 3 || ones->size[0]*ones->size[1]*ones->size[2] < outputDepth*outputHeight*outputWidth)
+ {
+ // Resize plane and fill with ones...
+ THTensor_(resize3d)(ones, outputDepth, outputHeight, outputWidth);
+ THTensor_(fill)(ones, 1);
+ }
+
+ // Resize temporary columns
+ THTensor_(resize2d)(columns, nOutputPlane*kW*kH*kT, inputDepth*inputHeight*inputWidth);
+
+ // Helpers
+ THTensor *input_n = THTensor_(new)();
+ THTensor *gradOutput_n = THTensor_(new)();
+
+ int elt;
+ // For each elt in batch, do:
+ for (elt = 0; elt < batchSize; ++elt)
+ {
+    // Matrix multiply per output:
+ THTensor_(select)(input_n, input, 0, elt);
+ THTensor_(select)(gradOutput_n, gradOutput, 0, elt);
+
+ // Extract columns:
+ THNN_(vol2col)(
+ THTensor_(data)(gradOutput_n), nOutputPlane,
+ outputDepth, outputHeight, outputWidth,
+ kT, kH, kW,
+ pT, pH, pW,
+ dT, dH, dW,
+ dilationT, dilationH, dilationW,
+ THTensor_(data)(columns)
+ );
+
+ // M,N,K are dims of matrix A and B
+ // (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm)
+    const long n = columns->size[0]; // nOutputPlane * kT * kH * kW
+    const long m = input_n->size[0]; // nInputPlane
+    const long k = columns->size[1]; // inputDepth * inputHeight * inputWidth
+
+ // Do GEMM (note: this is a bit confusing because gemm assumes column-major matrices)
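+    // In row-major terms: gradWeight (m x n) += scale * input_n (m x k) * columns^T (k x n).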
+ THBlas_(gemm)(
+ 't', 'n',
+ n, m, k,
+ scale,
+ THTensor_(data)(columns), k,
+ THTensor_(data)(input_n), k,
+ 1,
+ THTensor_(data)(gradWeight), n
+ );
+
+ // Do Bias:
+ // M,N,K are dims of matrix A and B
+ // (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm)
+ const long m_ = nOutputPlane;
+ const long k_ = outputDepth * outputHeight * outputWidth;
+
+ // Do GEMV (note: this is a bit confusing because gemv assumes column-major matrices)
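+    // gemv sums gradOutput_n over all output positions: gradBias[c] += scale * sum_s gradOutput_n[c][s].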
+ if (gradBias) {
+ THBlas_(gemv)(
+ 't',
+ k_, m_,
+ scale,
+ THTensor_(data)(gradOutput_n), k_,
+ THTensor_(data)(ones), 1,
+ 1,
+ THTensor_(data)(gradBias), 1
+ );
+ }
+ }
+
+ // Free
+ THTensor_(free)(input_n);
+ THTensor_(free)(gradOutput_n);
+
+ // Resize
+ if (batch == 0)
+ {
+ THTensor_(resize4d)(gradOutput, nOutputPlane, outputDepth, outputHeight, outputWidth);
+ THTensor_(resize4d)(input, nInputPlane, inputDepth, inputHeight, inputWidth);
+ }
+
+ THTensor_(free)(input);
+ THTensor_(free)(gradOutput);
+}
+
+#endif
diff --git a/lib/THNN/init.c b/lib/THNN/init.c
index acb88c0..cd5ddb9 100644
--- a/lib/THNN/init.c
+++ b/lib/THNN/init.c
@@ -200,6 +200,9 @@
#include "generic/SpatialConvolutionLocal.c"
#include "THGenerateFloatTypes.h"
+#include "generic/SpatialFullDilatedConvolution.c"
+#include "THGenerateFloatTypes.h"
+
#include "generic/SpatialFullConvolution.c"
#include "THGenerateFloatTypes.h"
@@ -251,6 +254,9 @@
#include "generic/VolumetricConvolutionMM.c"
#include "THGenerateFloatTypes.h"
+#include "generic/VolumetricFullDilatedConvolution.c"
+#include "THGenerateFloatTypes.h"
+
#include "generic/VolumetricFullConvolution.c"
#include "THGenerateFloatTypes.h"