github.com/torch/cunn.git
author     Soumith Chintala <soumith@gmail.com>  2017-07-28 06:17:44 +0300
committer  Soumith Chintala <soumith@gmail.com>  2017-08-03 05:44:59 +0300
commit     bbebfdc88c6cd0e533d10a08fc48565c7452e2e6 (patch)
tree       1b03b79b47186ff2976217c5be6c2fab92bbb355
parent     8d9e9562beb792e17c6614ec2b515094f9663776 (diff)
add 2D and 3D dilated full convolution
-rw-r--r--  lib/THCUNN/SpatialFullDilatedConvolution.cu               8
-rw-r--r--  lib/THCUNN/VolumetricFullConvolution.cu                   1
-rw-r--r--  lib/THCUNN/VolumetricFullDilatedConvolution.cu            8
-rw-r--r--  lib/THCUNN/generic/SpatialFullConvolution.cu            421
-rw-r--r--  lib/THCUNN/generic/SpatialFullDilatedConvolution.cu     469
-rw-r--r--  lib/THCUNN/generic/THCUNN.h                              82
-rw-r--r--  lib/THCUNN/generic/VolumetricFullConvolution.cu         465
-rw-r--r--  lib/THCUNN/generic/VolumetricFullDilatedConvolution.cu  511
8 files changed, 1097 insertions, 868 deletions
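
In short: the patch introduces dilated variants of the 2D and 3D full (transposed) convolutions and rewrites the existing non-dilated entry points as thin wrappers that forward with dilation = 1. The only arithmetic change is the per-axis output size, where the effective kernel extent grows from k to dilation*(k - 1) + 1. A minimal host-side sketch of that formula (the function and names here are illustrative, not part of the patch):

#include <cassert>

// Per-axis output size of a full (transposed) dilated convolution, as used by
// the shape checks in this patch: (in - 1)*stride - 2*pad + dilation*(k-1) + 1 + adj.
long fullDilatedOutputSize(long in, int k, int stride, int pad, int dilation, int adj) {
  return (in - 1) * stride - 2 * pad + (dilation * (k - 1) + 1) + adj;
}

int main() {
  // With dilation = 1 the formula reduces to the pre-patch (in-1)*stride - 2*pad + k + adj.
  assert(fullDilatedOutputSize(/*in=*/8, /*k=*/3, /*stride=*/2, /*pad=*/1, /*dilation=*/1, /*adj=*/0)
         == (8 - 1) * 2 - 2 * 1 + 3 + 0);      // 15
  // Dilation 2 widens the effective kernel from 3 to 2*(3-1)+1 = 5.
  assert(fullDilatedOutputSize(8, 3, 2, 1, 2, 0) == 17);
  return 0;
}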
diff --git a/lib/THCUNN/SpatialFullDilatedConvolution.cu b/lib/THCUNN/SpatialFullDilatedConvolution.cu
new file mode 100644
index 0000000..77d9811
--- /dev/null
+++ b/lib/THCUNN/SpatialFullDilatedConvolution.cu
@@ -0,0 +1,8 @@
+#include "THCUNN.h"
+#include "im2col.h"
+
+#include "THCHalf.h"
+#include "THCHalfAutoNumerics.cuh"
+
+#include "generic/SpatialFullDilatedConvolution.cu"
+#include "THCGenerateFloatTypes.h"
diff --git a/lib/THCUNN/VolumetricFullConvolution.cu b/lib/THCUNN/VolumetricFullConvolution.cu
index 93c4c0f..556b5bc 100644
--- a/lib/THCUNN/VolumetricFullConvolution.cu
+++ b/lib/THCUNN/VolumetricFullConvolution.cu
@@ -1,6 +1,5 @@
#include "THCUNN.h"
#include "common.h"
-#include "vol2col.h"
#include "THCHalf.h"
#include "THCHalfAutoNumerics.cuh"
diff --git a/lib/THCUNN/VolumetricFullDilatedConvolution.cu b/lib/THCUNN/VolumetricFullDilatedConvolution.cu
new file mode 100644
index 0000000..47173f2
--- /dev/null
+++ b/lib/THCUNN/VolumetricFullDilatedConvolution.cu
@@ -0,0 +1,8 @@
+#include "THCUNN.h"
+#include "common.h"
+#include "vol2col.h"
+#include "THCHalf.h"
+#include "THCHalfAutoNumerics.cuh"
+
+#include "generic/VolumetricFullDilatedConvolution.cu"
+#include "THCGenerateFloatTypes.h"
diff --git a/lib/THCUNN/generic/SpatialFullConvolution.cu b/lib/THCUNN/generic/SpatialFullConvolution.cu
index 76abb90..af9a473 100644
--- a/lib/THCUNN/generic/SpatialFullConvolution.cu
+++ b/lib/THCUNN/generic/SpatialFullConvolution.cu
@@ -2,65 +2,6 @@
#define THC_GENERIC_FILE "generic/SpatialFullConvolution.cu"
#else
-static inline void THNN_(SpatialFullConvolution_shapeCheck)(
- THCState *state,
- THCTensor *input, THCTensor *gradOutput,
- THCTensor *weight, THCTensor *bias,
- int kH, int kW, int dH, int dW, int padH, int padW,
- int adjH, int adjW) {
- THArgCheck(kW > 0 && kH > 0, 9,
- "kernel size should be greater than zero, but got kH: %d kW: %d", kH, kW);
- THArgCheck(dW > 0 && dH > 0, 11,
- "stride should be greater than zero, but got dH: %d dW: %d", dH, dW);
- THArgCheck(adjW < dW && adjH < dH, 15,
- "output adjustment must be smaller than stride, but got adjH: %d adjW: %d dH: %d dW: %d",
- adjH, adjW, dH, dW);
- THArgCheck(THCTensor_(isContiguous)(state, weight), 4,
- "weight tensor has to be contiguous");
- THArgCheck(!bias || THCTensor_(isContiguous)(state, bias), 5,
- "bias tensor has to be contiguous");
- THCUNN_argCheck(state, weight->nDimension == 2 || weight->nDimension == 4, 5, weight,
- "2D or 4D weight tensor expected, but got: %s");
-
- if (bias != NULL) {
- THCUNN_check_dim_size(state, bias, 1, 0, weight->size[1]);
- }
-
- int ndim = input->nDimension;
- int dimf = 0;
- int dimh = 1;
- int dimw = 2;
-
- if (ndim == 4) {
- dimf++;
- dimh++;
- dimw++;
- }
-
- THCUNN_argCheck(state, ndim == 3 || ndim == 4, 2, input,
- "3D or 4D input tensor expected but got: %s");
-
- long nInputPlane = weight->size[0];
- long inputHeight = input->size[dimh];
- long inputWidth = input->size[dimw];
- long nOutputPlane = weight->size[1];
- long outputHeight = (inputHeight - 1) * dH - 2*padH + kH + adjH;
- long outputWidth = (inputWidth - 1) * dW - 2*padW + kW + adjW;
-
- if (outputWidth < 1 || outputHeight < 1)
- THError("Given input size: (%d x %d x %d). "
- "Calculated output size: (%d x %d x %d). Output size is too small",
- nInputPlane,inputHeight,inputWidth,nOutputPlane,outputHeight,outputWidth);
-
- THCUNN_check_dim_size(state, input, ndim, dimf, nInputPlane);
-
- if (gradOutput != NULL) {
- THCUNN_check_dim_size(state, gradOutput, ndim, dimf, nOutputPlane);
- THCUNN_check_dim_size(state, gradOutput, ndim, dimh, outputHeight);
- THCUNN_check_dim_size(state, gradOutput, ndim, dimw, outputWidth);
- }
-}
-
void THNN_(SpatialFullConvolution_updateOutput)(
THCState *state,
THCTensor *input,
@@ -74,133 +15,9 @@ void THNN_(SpatialFullConvolution_updateOutput)(
int padW, int padH,
int adjW, int adjH)
{
-
- int nInputPlane = THCTensor_(size)(state, weight, 0);
- int nOutputPlane = THCTensor_(size)(state, weight, 1);
-
- THCUNN_assertSameGPU(state, 6, input, output, weight,
- bias, columns, ones);
- THNN_(SpatialFullConvolution_shapeCheck)
- (state, input, NULL, weight, bias, kH, kW, dH, dW, padH, padW, adjH, adjW);
-
- input = THCTensor_(newContiguous)(state, input);
- weight = THCTensor_(newContiguous)(state, weight);
- bias = bias ? THCTensor_(newContiguous)(state, bias) : bias;
-
- int batch = 1;
- if (input->nDimension == 3) {
- // Force batch
- batch = 0;
- THCTensor_(resize4d)(state, input, 1, input->size[0], input->size[1], input->size[2]);
- }
-
- long inputWidth = input->size[3];
- long inputHeight = input->size[2];
- long outputWidth = (inputWidth - 1) * dW - 2*padW + kW + adjW;
- long outputHeight = (inputHeight - 1) * dH - 2*padH + kH + adjH;
-
- // Batch size + input planes
- long batchSize = input->size[0];
-
- // Resize output
- THCTensor_(resize4d)(state, output, batchSize, nOutputPlane, outputHeight, outputWidth);
-
- // Resize temporary columns
- THCTensor_(resize2d)(state, columns, nOutputPlane*kW*kH, inputHeight*inputWidth);
-
- // Define a buffer of ones, for bias accumulation
- // Note: this buffer can be shared with other modules, it only ever gets increased,
- // and always contains ones.
- if (ones->nDimension != 2 || ones->size[0]*ones->size[1] < outputHeight*outputWidth) {
- // Resize plane and fill with ones...
- THCTensor_(resize2d)(state, ones, outputHeight, outputWidth);
- THCTensor_(fill)(state, ones, ScalarConvert<int, real>::to(1));
- }
-
- // Helpers
- THCTensor *input_n = THCTensor_(new)(state);
- THCTensor *output_n = THCTensor_(new)(state);
-
- // For each elt in batch, do:
- for (int elt = 0; elt < batchSize; elt ++) {
- // Matrix mulitply per output:
- THCTensor_(select)(state, input_n, input, 0, elt);
- THCTensor_(select)(state, output_n, output, 0, elt);
-
- // M,N,K are dims of matrix A and B
- // (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm)
- long m = weight->size[1] * weight->size[2] * weight->size[3];
- long n = columns->size[1];
- long k = weight->size[0];
-
- // Do GEMM (note: this is a bit confusing because gemm assumes column-major matrices)
- #ifdef THC_REAL_IS_FLOAT
- THCudaBlas_Sgemm(
- #elif defined(THC_REAL_IS_HALF)
- THCudaBlas_Hgemm(
- #elif defined(THC_REAL_IS_DOUBLE)
- THCudaBlas_Dgemm(
- #endif
- state,
- 'n', 't',
- n, m, k,
- ScalarConvert<int, real>::to(1),
- THCTensor_(data)(state, input_n), n,
- THCTensor_(data)(state, weight), m,
- ScalarConvert<int, real>::to(0),
- THCTensor_(data)(state, columns), n
- );
-
- // Unpack columns back into input:
- col2im<real, accreal>(
- THCState_getCurrentStream(state),
- THCTensor_(data)(state, columns),
- nOutputPlane, outputHeight, outputWidth, kH, kW, padH, padW, dH, dW,
- 1, 1, THCTensor_(data)(state, output_n)
- );
-
- // Do Bias after:
- // M,N,K are dims of matrix A and B
- // (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm)
- long m_ = nOutputPlane;
- long n_ = outputHeight * outputWidth;
- long k_ = 1;
-
- // Do GEMM (note: this is a bit confusing because gemm assumes column-major matrices)
- if (bias) {
- #ifdef THC_REAL_IS_FLOAT
- THCudaBlas_Sgemm(
- #elif defined(THC_REAL_IS_HALF)
- THCudaBlas_Hgemm(
- #elif defined(THC_REAL_IS_DOUBLE)
- THCudaBlas_Dgemm(
- #endif
- state,
- 't', 'n',
- n_, m_, k_,
- ScalarConvert<int, real>::to(1),
- THCTensor_(data)(state, ones), k_,
- THCTensor_(data)(state, bias), k_,
- ScalarConvert<int, real>::to(1),
- THCTensor_(data)(state, output_n), n_
- );
- }
- }
-
- // Free
- THCTensor_(free)(state, input_n);
- THCTensor_(free)(state, output_n);
-
- // Resize output
- if (batch == 0) {
- THCTensor_(resize3d)(state, output, nOutputPlane, outputHeight, outputWidth);
- THCTensor_(resize3d)(state, input, nInputPlane, inputHeight, inputWidth);
- }
-
- THCTensor_(free)(state, input);
- THCTensor_(free)(state, weight);
- if (bias) THCTensor_(free)(state, bias);
-
+ THNN_(SpatialFullDilatedConvolution_updateOutput)(
+ state, input, output, weight, bias, columns, ones,
+ kW, kH, dW, dH, padW, padH, 1, 1, adjW, adjH);
}
void THNN_(SpatialFullConvolution_updateGradInput)(
@@ -215,98 +32,9 @@ void THNN_(SpatialFullConvolution_updateGradInput)(
int padW, int padH,
int adjW, int adjH)
{
- int nInputPlane = THCTensor_(size)(state, weight, 0);
- int nOutputPlane = THCTensor_(size)(state, weight, 1);
-
- THCUNN_assertSameGPU(state, 5, input, gradOutput, weight,
- gradColumns, gradInput);
- THNN_(SpatialFullConvolution_shapeCheck)
- (state, input, gradOutput, weight, NULL, kH, kW, dH, dW, padH, padW, adjH, adjW);
-
- input = THCTensor_(newContiguous)(state, input);
- gradOutput = THCTensor_(newContiguous)(state, gradOutput);
- weight = THCTensor_(newContiguous)(state, weight);
- int batch = 1;
- if (input->nDimension == 3) {
- // Force batch
- batch = 0;
- THCTensor_(resize4d)(state, input, 1, input->size[0], input->size[1], input->size[2]);
- THCTensor_(resize4d)(state, gradOutput, 1, gradOutput->size[0], gradOutput->size[1], gradOutput->size[2]);
- }
-
- long inputWidth = input->size[3];
- long inputHeight = input->size[2];
- long outputWidth = (inputWidth - 1) * dW - 2*padW + kW + adjW;
- long outputHeight = (inputHeight - 1) * dH - 2*padH + kH + adjH;
-
- // Batch size + input planes
- long batchSize = input->size[0];
-
- // Resize output
- THCTensor_(resize4d)(state, gradInput, batchSize, nInputPlane, inputHeight, inputWidth);
-
- // Resize temporary columns
- THCTensor_(resize2d)(state, gradColumns, nOutputPlane*kW*kH, inputHeight*inputWidth);
-
- // Helpers
- THCTensor *gradInput_n = THCTensor_(new)(state);
- THCTensor *gradOutput_n = THCTensor_(new)(state);
-
- // For each elt in batch, do:
- for (int elt = 0; elt < batchSize; elt ++) {
- // Matrix mulitply per sample:
- THCTensor_(select)(state, gradInput_n, gradInput, 0, elt);
- THCTensor_(select)(state, gradOutput_n, gradOutput, 0, elt);
-
- // Extract columns:
- im2col(
- THCState_getCurrentStream(state),
- THCTensor_(data)(state, gradOutput_n),
- nOutputPlane, outputHeight, outputWidth, kH, kW, padH, padW, dH, dW,
- 1, 1, THCTensor_(data)(state, gradColumns)
- );
-
-
- // M,N,K are dims of matrix A and B
- // (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm)
- long m = weight->size[0];
- long n = gradColumns->size[1];
- long k = weight->size[1] * weight->size[2] * weight->size[3];
-
- // Do GEMM (note: this is a bit confusing because gemm assumes column-major matrices)
- #ifdef THC_REAL_IS_FLOAT
- THCudaBlas_Sgemm(
- #elif defined(THC_REAL_IS_HALF)
- THCudaBlas_Hgemm(
- #elif defined(THC_REAL_IS_DOUBLE)
- THCudaBlas_Dgemm(
- #endif
- state,
- 'n', 'n',
- n, m, k,
- ScalarConvert<int, real>::to(1),
- THCTensor_(data)(state, gradColumns), n,
- THCTensor_(data)(state, weight), k,
- ScalarConvert<int, real>::to(0),
- THCTensor_(data)(state, gradInput_n), n
- );
- }
-
-
- // Free
- THCTensor_(free)(state, gradInput_n);
- THCTensor_(free)(state, gradOutput_n);
-
- // Resize output
- if (batch == 0) {
- THCTensor_(resize3d)(state, gradOutput, nOutputPlane, outputHeight, outputWidth);
- THCTensor_(resize3d)(state, input, nInputPlane, inputHeight, inputWidth);
- THCTensor_(resize3d)(state, gradInput, nInputPlane, inputHeight, inputWidth);
- }
-
- THCTensor_(free)(state, input);
- THCTensor_(free)(state, gradOutput);
- THCTensor_(free)(state, weight);
+ THNN_(SpatialFullDilatedConvolution_updateGradInput)(
+ state, input, gradOutput, gradInput, weight, gradColumns,
+ kW, kH, dW, dH, padW, padH, 1, 1, adjW, adjH);
}
@@ -324,139 +52,10 @@ void THNN_(SpatialFullConvolution_accGradParameters)(
int adjW, int adjH,
accreal scale_)
{
- real scale = ScalarConvert<accreal, real>::to(scale_);
- int nInputPlane = THCTensor_(size)(state, gradWeight, 0);
- int nOutputPlane = THCTensor_(size)(state, gradWeight, 1);
-
- THCUNN_assertSameGPU(state, 6, input, gradOutput, gradWeight,
- gradBias, columns, ones);
- THNN_(SpatialFullConvolution_shapeCheck)
- (state, input, gradOutput, gradWeight, gradBias, kH, kW, dH, dW, padH, padW, adjH, adjW);
-
- THArgCheck(THCTensor_(isContiguous)(state, gradWeight), 4, "gradWeight needs to be contiguous");
- if (gradBias)
- THArgCheck(THCTensor_(isContiguous)(state, gradBias), 5, "gradBias needs to be contiguous");
- input = THCTensor_(newContiguous)(state, input);
- gradOutput = THCTensor_(newContiguous)(state, gradOutput);
- int batch = 1;
- if (input->nDimension == 3) {
- // Force batch
- batch = 0;
- THCTensor_(resize4d)(state, input, 1, input->size[0], input->size[1], input->size[2]);
- THCTensor_(resize4d)(state, gradOutput, 1, gradOutput->size[0], gradOutput->size[1], gradOutput->size[2]);
- }
-
- long inputWidth = input->size[3];
- long inputHeight = input->size[2];
- long outputWidth = (inputWidth - 1) * dW - 2*padW + kW + adjW;
- long outputHeight = (inputHeight - 1) * dH - 2*padH + kH + adjH;
-
- // Batch size + input planes
- long batchSize = input->size[0];
-
- // Define a buffer of ones, for bias accumulation
- if (ones->nDimension != 2 || ones->size[0]*ones->size[1] < outputHeight*outputWidth) {
- // Resize plane and fill with ones...
- THCTensor_(resize2d)(state, ones, outputHeight, outputWidth);
- THCTensor_(fill)(state, ones, ScalarConvert<int, real>::to(1));
- }
-
- // Resize temporary columns
- THCTensor_(resize2d)(state, columns, nOutputPlane*kW*kH, inputHeight*inputWidth);
-
- // Helpers
- THCTensor *input_n = THCTensor_(new)(state);
- THCTensor *gradOutput_n = THCTensor_(new)(state);
-
- // For each elt in batch, do:
- for (int elt = 0; elt < batchSize; elt ++) {
- // Matrix mulitply per output:
- THCTensor_(select)(state, input_n, input, 0, elt);
- THCTensor_(select)(state, gradOutput_n, gradOutput, 0, elt);
-
- // Extract columns:
- im2col(
- THCState_getCurrentStream(state),
- THCTensor_(data)(state, gradOutput_n),
- nOutputPlane, outputHeight, outputWidth, kH, kW, padH, padW, dH, dW,
- 1, 1, THCTensor_(data)(state, columns)
- );
-
- // M,N,K are dims of matrix A and B
- // (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm)
- long n = columns->size[0]; // nOutputPlane * kh * kw
- long m = input_n->size[0]; // nInputPlane
- long k = columns->size[1]; // inputHeight * inputWidth
-
- // Do GEMM (note: this is a bit confusing because gemm assumes column-major matrices)
- #ifdef THC_REAL_IS_FLOAT
- THCudaBlas_Sgemm(
- #elif defined(THC_REAL_IS_HALF)
- THCudaBlas_Hgemm(
- #elif defined(THC_REAL_IS_DOUBLE)
- THCudaBlas_Dgemm(
- #endif
- state,
- 't', 'n',
- n, m, k,
- scale,
- THCTensor_(data)(state, columns), k,
- THCTensor_(data)(state, input_n), k,
- ScalarConvert<int, real>::to(1),
- THCTensor_(data)(state, gradWeight), n
- );
-
- // Do Bias:
- // M,N,K are dims of matrix A and B
- // (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm)
- long m_ = nOutputPlane;
- long k_ = outputHeight * outputWidth;
-
- // Do GEMV (note: this is a bit confusing because gemv assumes column-major matrices)
- if (gradBias) {
- #if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE)
- #ifdef THC_REAL_IS_FLOAT
- THCudaBlas_Sgemv(
- #elif defined(THC_REAL_IS_DOUBLE)
- THCudaBlas_Dgemv(
- #endif
- state,
- 't',
- k_, m_,
- scale,
- THCTensor_(data)(state, gradOutput_n), k_,
- THCTensor_(data)(state, ones), 1,
- ScalarConvert<int, real>::to(1),
- THCTensor_(data)(state, gradBias), 1
- );
- #endif
- #ifdef THC_REAL_IS_HALF
- THCudaBlas_Hgemm(
- state,
- 't', 'n',
- m_, 1, k_,
- scale,
- THCTensor_(data)(state, gradOutput_n), k_,
- THCTensor_(data)(state, ones), k_,
- ScalarConvert<int, real>::to(1),
- THCTensor_(data)(state, gradBias), m_
- );
- #endif
- }
- }
-
- // Free
- THCTensor_(free)(state, input_n);
- THCTensor_(free)(state, gradOutput_n);
-
- // Resize
- if (batch == 0) {
- THCTensor_(resize3d)(state, gradOutput, nOutputPlane, outputHeight, outputWidth);
- THCTensor_(resize3d)(state, input, nInputPlane, inputHeight, inputWidth);
- }
-
- THCTensor_(free)(state, input);
- THCTensor_(free)(state, gradOutput);
+ THNN_(SpatialFullDilatedConvolution_accGradParameters)(
+ state, input, gradOutput, gradWeight, gradBias,
+ columns, ones,
+ kW, kH, dW, dH, padW, padH, 1, 1, adjW, adjH, scale_);
}
#endif
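
Worth noting for both the removed body above and the dilated implementation it now calls: the bias is added with a GEMM against a persistent all-ones buffer, i.e. a rank-1 update out(n_ x m_) += ones(n_ x 1) * bias(1 x m_) in column-major terms, broadcasting bias[c] over every output location. A naive loop equivalent (illustrative names, assuming contiguous [nOutputPlane][outH*outW] storage):

// What the 't','n' GEMM with k_ = 1 computes: for every output plane c, add
// bias[c] to all outputHeight*outputWidth cells (ones[i] == 1 for every i).
void addBiasNaive(float *output, const float *bias,
                  long nOutputPlane, long outHW) {
  for (long c = 0; c < nOutputPlane; ++c)
    for (long i = 0; i < outHW; ++i)
      output[c * outHW + i] += bias[c];  // == bias[c] * ones[i]
}

Casting the broadcast as a GEMM keeps it in one cuBLAS call, which is also why the shared ones buffer is only ever grown, never shrunk.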
diff --git a/lib/THCUNN/generic/SpatialFullDilatedConvolution.cu b/lib/THCUNN/generic/SpatialFullDilatedConvolution.cu
new file mode 100644
index 0000000..322a213
--- /dev/null
+++ b/lib/THCUNN/generic/SpatialFullDilatedConvolution.cu
@@ -0,0 +1,469 @@
+#ifndef THC_GENERIC_FILE
+#define THC_GENERIC_FILE "generic/SpatialFullDilatedConvolution.cu"
+#else
+
+static inline void THNN_(SpatialFullDilatedConvolution_shapeCheck)(
+ THCState *state,
+ THCTensor *input, THCTensor *gradOutput,
+ THCTensor *weight, THCTensor *bias,
+ int kH, int kW, int dH, int dW, int padH, int padW,
+ int dilationH, int dilationW,
+ int adjH, int adjW) {
+ THArgCheck(kW > 0 && kH > 0, 9,
+ "kernel size should be greater than zero, but got kH: %d kW: %d", kH, kW);
+ THArgCheck(dW > 0 && dH > 0, 11,
+ "stride should be greater than zero, but got dH: %d dW: %d", dH, dW);
+ THArgCheck(adjW < dW && adjH < dH, 15,
+ "output adjustment must be smaller than stride, but got adjH: %d adjW: %d dH: %d dW: %d",
+ adjH, adjW, dH, dW);
+ THArgCheck(dilationW > 0 && dilationH > 0, 15,
+ "dilation should be greater than zero, but got dilationH: %d, dilationW: %d",
+ dilationH, dilationW);
+ THArgCheck(THCTensor_(isContiguous)(state, weight), 4,
+ "weight tensor has to be contiguous");
+ THArgCheck(!bias || THCTensor_(isContiguous)(state, bias), 5,
+ "bias tensor has to be contiguous");
+ THCUNN_argCheck(state, weight->nDimension == 2 || weight->nDimension == 4, 5, weight,
+ "2D or 4D weight tensor expected, but got: %s");
+
+ if (bias != NULL) {
+ THCUNN_check_dim_size(state, bias, 1, 0, weight->size[1]);
+ }
+
+ int ndim = input->nDimension;
+ int dimf = 0;
+ int dimh = 1;
+ int dimw = 2;
+
+ if (ndim == 4) {
+ dimf++;
+ dimh++;
+ dimw++;
+ }
+
+ THCUNN_argCheck(state, ndim == 3 || ndim == 4, 2, input,
+ "3D or 4D input tensor expected but got: %s");
+
+ long nInputPlane = weight->size[0];
+ long inputHeight = input->size[dimh];
+ long inputWidth = input->size[dimw];
+ long nOutputPlane = weight->size[1];
+ long outputHeight = (inputHeight - 1) * dH - 2*padH + (dilationH * (kH - 1) + 1) + adjH;
+ long outputWidth = (inputWidth - 1) * dW - 2*padW + (dilationW * (kW - 1) + 1) + adjW;
+
+ if (outputWidth < 1 || outputHeight < 1)
+ THError("Given input size: (%d x %d x %d). "
+ "Calculated output size: (%d x %d x %d). Output size is too small",
+ nInputPlane,inputHeight,inputWidth,nOutputPlane,outputHeight,outputWidth);
+
+ THCUNN_check_dim_size(state, input, ndim, dimf, nInputPlane);
+
+ if (gradOutput != NULL) {
+ THCUNN_check_dim_size(state, gradOutput, ndim, dimf, nOutputPlane);
+ THCUNN_check_dim_size(state, gradOutput, ndim, dimh, outputHeight);
+ THCUNN_check_dim_size(state, gradOutput, ndim, dimw, outputWidth);
+ }
+}
+
+void THNN_(SpatialFullDilatedConvolution_updateOutput)(
+ THCState *state,
+ THCTensor *input,
+ THCTensor *output,
+ THCTensor *weight,
+ THCTensor *bias,
+ THCTensor *columns,
+ THCTensor *ones,
+ int kW, int kH,
+ int dW, int dH,
+ int padW, int padH,
+ int dilationW, int dilationH,
+ int adjW, int adjH)
+{
+
+ int nInputPlane = THCTensor_(size)(state, weight, 0);
+ int nOutputPlane = THCTensor_(size)(state, weight, 1);
+
+ THCUNN_assertSameGPU(state, 6, input, output, weight,
+ bias, columns, ones);
+ THNN_(SpatialFullDilatedConvolution_shapeCheck)
+ (state, input, NULL, weight, bias, kH, kW, dH, dW, padH, padW, dilationH, dilationW, adjH, adjW);
+
+ input = THCTensor_(newContiguous)(state, input);
+ weight = THCTensor_(newContiguous)(state, weight);
+ bias = bias ? THCTensor_(newContiguous)(state, bias) : bias;
+
+ int batch = 1;
+ if (input->nDimension == 3) {
+ // Force batch
+ batch = 0;
+ THCTensor_(resize4d)(state, input, 1, input->size[0], input->size[1], input->size[2]);
+ }
+
+ long inputWidth = input->size[3];
+ long inputHeight = input->size[2];
+ long outputHeight = (inputHeight - 1) * dH - 2*padH + (dilationH * (kH - 1) + 1) + adjH;
+ long outputWidth = (inputWidth - 1) * dW - 2*padW + (dilationW * (kW - 1) + 1) + adjW;
+
+ // Batch size + input planes
+ long batchSize = input->size[0];
+
+ // Resize output
+ THCTensor_(resize4d)(state, output, batchSize, nOutputPlane, outputHeight, outputWidth);
+
+ // Resize temporary columns
+ THCTensor_(resize2d)(state, columns, nOutputPlane*kW*kH, inputHeight*inputWidth);
+
+ // Define a buffer of ones, for bias accumulation
+ // Note: this buffer can be shared with other modules, it only ever gets increased,
+ // and always contains ones.
+ if (ones->nDimension != 2 || ones->size[0]*ones->size[1] < outputHeight*outputWidth) {
+ // Resize plane and fill with ones...
+ THCTensor_(resize2d)(state, ones, outputHeight, outputWidth);
+ THCTensor_(fill)(state, ones, ScalarConvert<int, real>::to(1));
+ }
+
+ // Helpers
+ THCTensor *input_n = THCTensor_(new)(state);
+ THCTensor *output_n = THCTensor_(new)(state);
+
+ // For each elt in batch, do:
+ for (int elt = 0; elt < batchSize; elt ++) {
+ // Matrix multiply per output:
+ THCTensor_(select)(state, input_n, input, 0, elt);
+ THCTensor_(select)(state, output_n, output, 0, elt);
+
+ // M,N,K are dims of matrix A and B
+ // (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm)
+ long m = weight->size[1] * weight->size[2] * weight->size[3];
+ long n = columns->size[1];
+ long k = weight->size[0];
+
+ // Do GEMM (note: this is a bit confusing because gemm assumes column-major matrices)
+ #ifdef THC_REAL_IS_FLOAT
+ THCudaBlas_Sgemm(
+ #elif defined(THC_REAL_IS_HALF)
+ THCudaBlas_Hgemm(
+ #elif defined(THC_REAL_IS_DOUBLE)
+ THCudaBlas_Dgemm(
+ #endif
+ state,
+ 'n', 't',
+ n, m, k,
+ ScalarConvert<int, real>::to(1),
+ THCTensor_(data)(state, input_n), n,
+ THCTensor_(data)(state, weight), m,
+ ScalarConvert<int, real>::to(0),
+ THCTensor_(data)(state, columns), n
+ );
+
+ // Unpack columns back into input:
+ col2im<real, accreal>(
+ THCState_getCurrentStream(state),
+ THCTensor_(data)(state, columns),
+ nOutputPlane, outputHeight, outputWidth, kH, kW, padH, padW, dH, dW,
+ dilationH, dilationW, THCTensor_(data)(state, output_n)
+ );
+
+ // Do Bias after:
+ // M,N,K are dims of matrix A and B
+ // (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm)
+ long m_ = nOutputPlane;
+ long n_ = outputHeight * outputWidth;
+ long k_ = 1;
+
+ // Do GEMM (note: this is a bit confusing because gemm assumes column-major matrices)
+ if (bias) {
+ #ifdef THC_REAL_IS_FLOAT
+ THCudaBlas_Sgemm(
+ #elif defined(THC_REAL_IS_HALF)
+ THCudaBlas_Hgemm(
+ #elif defined(THC_REAL_IS_DOUBLE)
+ THCudaBlas_Dgemm(
+ #endif
+ state,
+ 't', 'n',
+ n_, m_, k_,
+ ScalarConvert<int, real>::to(1),
+ THCTensor_(data)(state, ones), k_,
+ THCTensor_(data)(state, bias), k_,
+ ScalarConvert<int, real>::to(1),
+ THCTensor_(data)(state, output_n), n_
+ );
+ }
+ }
+
+ // Free
+ THCTensor_(free)(state, input_n);
+ THCTensor_(free)(state, output_n);
+
+ // Resize output
+ if (batch == 0) {
+ THCTensor_(resize3d)(state, output, nOutputPlane, outputHeight, outputWidth);
+ THCTensor_(resize3d)(state, input, nInputPlane, inputHeight, inputWidth);
+ }
+
+ THCTensor_(free)(state, input);
+ THCTensor_(free)(state, weight);
+ if (bias) THCTensor_(free)(state, bias);
+
+}
+
+void THNN_(SpatialFullDilatedConvolution_updateGradInput)(
+ THCState *state,
+ THCTensor *input,
+ THCTensor *gradOutput,
+ THCTensor *gradInput,
+ THCTensor *weight,
+ THCTensor *gradColumns,
+ int kW, int kH,
+ int dW, int dH,
+ int padW, int padH,
+ int dilationW, int dilationH,
+ int adjW, int adjH)
+{
+ int nInputPlane = THCTensor_(size)(state, weight, 0);
+ int nOutputPlane = THCTensor_(size)(state, weight, 1);
+
+ THCUNN_assertSameGPU(state, 5, input, gradOutput, weight,
+ gradColumns, gradInput);
+ THNN_(SpatialFullDilatedConvolution_shapeCheck)
+ (state, input, gradOutput, weight, NULL, kH, kW, dH, dW, padH, padW, dilationH, dilationW, adjH, adjW);
+
+ input = THCTensor_(newContiguous)(state, input);
+ gradOutput = THCTensor_(newContiguous)(state, gradOutput);
+ weight = THCTensor_(newContiguous)(state, weight);
+ int batch = 1;
+ if (input->nDimension == 3) {
+ // Force batch
+ batch = 0;
+ THCTensor_(resize4d)(state, input, 1, input->size[0], input->size[1], input->size[2]);
+ THCTensor_(resize4d)(state, gradOutput, 1, gradOutput->size[0], gradOutput->size[1], gradOutput->size[2]);
+ }
+
+ long inputWidth = input->size[3];
+ long inputHeight = input->size[2];
+ long outputHeight = (inputHeight - 1) * dH - 2*padH + (dilationH * (kH - 1) + 1) + adjH;
+ long outputWidth = (inputWidth - 1) * dW - 2*padW + (dilationW * (kW - 1) + 1) + adjW;
+
+ // Batch size + input planes
+ long batchSize = input->size[0];
+
+ // Resize output
+ THCTensor_(resize4d)(state, gradInput, batchSize, nInputPlane, inputHeight, inputWidth);
+
+ // Resize temporary columns
+ THCTensor_(resize2d)(state, gradColumns, nOutputPlane*kW*kH, inputHeight*inputWidth);
+
+ // Helpers
+ THCTensor *gradInput_n = THCTensor_(new)(state);
+ THCTensor *gradOutput_n = THCTensor_(new)(state);
+
+ // For each elt in batch, do:
+ for (int elt = 0; elt < batchSize; elt ++) {
+ // Matrix multiply per sample:
+ THCTensor_(select)(state, gradInput_n, gradInput, 0, elt);
+ THCTensor_(select)(state, gradOutput_n, gradOutput, 0, elt);
+
+ // Extract columns:
+ im2col(
+ THCState_getCurrentStream(state),
+ THCTensor_(data)(state, gradOutput_n),
+ nOutputPlane, outputHeight, outputWidth, kH, kW, padH, padW, dH, dW,
+ dilationH, dilationW, THCTensor_(data)(state, gradColumns)
+ );
+
+
+ // M,N,K are dims of matrix A and B
+ // (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm)
+ long m = weight->size[0];
+ long n = gradColumns->size[1];
+ long k = weight->size[1] * weight->size[2] * weight->size[3];
+
+ // Do GEMM (note: this is a bit confusing because gemm assumes column-major matrices)
+ #ifdef THC_REAL_IS_FLOAT
+ THCudaBlas_Sgemm(
+ #elif defined(THC_REAL_IS_HALF)
+ THCudaBlas_Hgemm(
+ #elif defined(THC_REAL_IS_DOUBLE)
+ THCudaBlas_Dgemm(
+ #endif
+ state,
+ 'n', 'n',
+ n, m, k,
+ ScalarConvert<int, real>::to(1),
+ THCTensor_(data)(state, gradColumns), n,
+ THCTensor_(data)(state, weight), k,
+ ScalarConvert<int, real>::to(0),
+ THCTensor_(data)(state, gradInput_n), n
+ );
+ }
+
+
+ // Free
+ THCTensor_(free)(state, gradInput_n);
+ THCTensor_(free)(state, gradOutput_n);
+
+ // Resize output
+ if (batch == 0) {
+ THCTensor_(resize3d)(state, gradOutput, nOutputPlane, outputHeight, outputWidth);
+ THCTensor_(resize3d)(state, input, nInputPlane, inputHeight, inputWidth);
+ THCTensor_(resize3d)(state, gradInput, nInputPlane, inputHeight, inputWidth);
+ }
+
+ THCTensor_(free)(state, input);
+ THCTensor_(free)(state, gradOutput);
+ THCTensor_(free)(state, weight);
+}
+
+
+void THNN_(SpatialFullDilatedConvolution_accGradParameters)(
+ THCState *state,
+ THCTensor *input,
+ THCTensor *gradOutput,
+ THCTensor *gradWeight,
+ THCTensor *gradBias,
+ THCTensor *columns,
+ THCTensor *ones,
+ int kW, int kH,
+ int dW, int dH,
+ int padW, int padH,
+ int dilationW, int dilationH,
+ int adjW, int adjH,
+ accreal scale_)
+{
+ real scale = ScalarConvert<accreal, real>::to(scale_);
+ int nInputPlane = THCTensor_(size)(state, gradWeight, 0);
+ int nOutputPlane = THCTensor_(size)(state, gradWeight, 1);
+
+ THCUNN_assertSameGPU(state, 6, input, gradOutput, gradWeight,
+ gradBias, columns, ones);
+ THNN_(SpatialFullDilatedConvolution_shapeCheck)
+ (state, input, gradOutput, gradWeight, gradBias, kH, kW, dH, dW, padH, padW, dilationH, dilationW, adjH, adjW);
+
+ THArgCheck(THCTensor_(isContiguous)(state, gradWeight), 4, "gradWeight needs to be contiguous");
+ if (gradBias)
+ THArgCheck(THCTensor_(isContiguous)(state, gradBias), 5, "gradBias needs to be contiguous");
+ input = THCTensor_(newContiguous)(state, input);
+ gradOutput = THCTensor_(newContiguous)(state, gradOutput);
+ int batch = 1;
+ if (input->nDimension == 3) {
+ // Force batch
+ batch = 0;
+ THCTensor_(resize4d)(state, input, 1, input->size[0], input->size[1], input->size[2]);
+ THCTensor_(resize4d)(state, gradOutput, 1, gradOutput->size[0], gradOutput->size[1], gradOutput->size[2]);
+ }
+
+ long inputWidth = input->size[3];
+ long inputHeight = input->size[2];
+ long outputHeight = (inputHeight - 1) * dH - 2*padH + (dilationH * (kH - 1) + 1) + adjH;
+ long outputWidth = (inputWidth - 1) * dW - 2*padW + (dilationW * (kW - 1) + 1) + adjW;
+
+ // Batch size + input planes
+ long batchSize = input->size[0];
+
+ // Define a buffer of ones, for bias accumulation
+ if (ones->nDimension != 2 || ones->size[0]*ones->size[1] < outputHeight*outputWidth) {
+ // Resize plane and fill with ones...
+ THCTensor_(resize2d)(state, ones, outputHeight, outputWidth);
+ THCTensor_(fill)(state, ones, ScalarConvert<int, real>::to(1));
+ }
+
+ // Resize temporary columns
+ THCTensor_(resize2d)(state, columns, nOutputPlane*kW*kH, inputHeight*inputWidth);
+
+ // Helpers
+ THCTensor *input_n = THCTensor_(new)(state);
+ THCTensor *gradOutput_n = THCTensor_(new)(state);
+
+ // For each elt in batch, do:
+ for (int elt = 0; elt < batchSize; elt ++) {
+ // Matrix multiply per output:
+ THCTensor_(select)(state, input_n, input, 0, elt);
+ THCTensor_(select)(state, gradOutput_n, gradOutput, 0, elt);
+
+ // Extract columns:
+ im2col(
+ THCState_getCurrentStream(state),
+ THCTensor_(data)(state, gradOutput_n),
+ nOutputPlane, outputHeight, outputWidth, kH, kW, padH, padW, dH, dW,
+ dilationH, dilationW, THCTensor_(data)(state, columns)
+ );
+
+ // M,N,K are dims of matrix A and B
+ // (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm)
+ long n = columns->size[0]; // nOutputPlane * kh * kw
+ long m = input_n->size[0]; // nInputPlane
+ long k = columns->size[1]; // inputHeight * inputWidth
+
+ // Do GEMM (note: this is a bit confusing because gemm assumes column-major matrices)
+ #ifdef THC_REAL_IS_FLOAT
+ THCudaBlas_Sgemm(
+ #elif defined(THC_REAL_IS_HALF)
+ THCudaBlas_Hgemm(
+ #elif defined(THC_REAL_IS_DOUBLE)
+ THCudaBlas_Dgemm(
+ #endif
+ state,
+ 't', 'n',
+ n, m, k,
+ scale,
+ THCTensor_(data)(state, columns), k,
+ THCTensor_(data)(state, input_n), k,
+ ScalarConvert<int, real>::to(1),
+ THCTensor_(data)(state, gradWeight), n
+ );
+
+ // Do Bias:
+ // M,N,K are dims of matrix A and B
+ // (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm)
+ long m_ = nOutputPlane;
+ long k_ = outputHeight * outputWidth;
+
+ // Do GEMV (note: this is a bit confusing because gemv assumes column-major matrices)
+ if (gradBias) {
+ #if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE)
+ #ifdef THC_REAL_IS_FLOAT
+ THCudaBlas_Sgemv(
+ #elif defined(THC_REAL_IS_DOUBLE)
+ THCudaBlas_Dgemv(
+ #endif
+ state,
+ 't',
+ k_, m_,
+ scale,
+ THCTensor_(data)(state, gradOutput_n), k_,
+ THCTensor_(data)(state, ones), 1,
+ ScalarConvert<int, real>::to(1),
+ THCTensor_(data)(state, gradBias), 1
+ );
+ #endif
+ #ifdef THC_REAL_IS_HALF
+ THCudaBlas_Hgemm(
+ state,
+ 't', 'n',
+ m_, 1, k_,
+ scale,
+ THCTensor_(data)(state, gradOutput_n), k_,
+ THCTensor_(data)(state, ones), k_,
+ ScalarConvert<int, real>::to(1),
+ THCTensor_(data)(state, gradBias), m_
+ );
+ #endif
+ }
+ }
+
+ // Free
+ THCTensor_(free)(state, input_n);
+ THCTensor_(free)(state, gradOutput_n);
+
+ // Resize
+ if (batch == 0) {
+ THCTensor_(resize3d)(state, gradOutput, nOutputPlane, outputHeight, outputWidth);
+ THCTensor_(resize3d)(state, input, nInputPlane, inputHeight, inputWidth);
+ }
+
+ THCTensor_(free)(state, input);
+ THCTensor_(free)(state, gradOutput);
+}
+
+#endif
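
The updateOutput path above computes columns = weight^T * input per sample, then col2im scatters the overlapping columns into the output; functionally that is a direct scatter of each input cell, weighted by the dilated kernel, onto a strided output grid. A single-plane CPU reference of what the pipeline computes (a sketch for illustration, not the patch's code path):

#include <cstdio>
#include <vector>

// Naive dilated full (transposed) 2D convolution, one input and one output
// plane; output size matches (in-1)*d - 2*pad + dilation*(k-1) + 1 + adj.
std::vector<float> fullDilatedConv2d(const std::vector<float> &in, long inH, long inW,
                                     const std::vector<float> &w, int kH, int kW,
                                     int dH, int dW, int padH, int padW,
                                     int dilH, int dilW, int adjH, int adjW) {
  long outH = (inH - 1) * dH - 2 * padH + (dilH * (kH - 1) + 1) + adjH;
  long outW = (inW - 1) * dW - 2 * padW + (dilW * (kW - 1) + 1) + adjW;
  std::vector<float> out(outH * outW, 0.f);
  for (long y = 0; y < inH; ++y)
    for (long x = 0; x < inW; ++x)
      for (int i = 0; i < kH; ++i)
        for (int j = 0; j < kW; ++j) {
          // Scatter: each input cell paints a dilated kernel footprint.
          long oy = y * dH - padH + (long)i * dilH;
          long ox = x * dW - padW + (long)j * dilW;
          if (oy >= 0 && oy < outH && ox >= 0 && ox < outW)
            out[oy * outW + ox] += in[y * inW + x] * w[i * kW + j];
        }
  return out;
}

int main() {
  std::vector<float> in(2 * 2, 1.f), w(3 * 3, 1.f);
  auto out = fullDilatedConv2d(in, 2, 2, w, 3, 3, 2, 2, 0, 0, 1, 1, 0, 0);
  std::printf("%zu output values (5x5 expected)\n", out.size());
  return 0;
}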
diff --git a/lib/THCUNN/generic/THCUNN.h b/lib/THCUNN/generic/THCUNN.h
index 9692094..1a4464f 100644
--- a/lib/THCUNN/generic/THCUNN.h
+++ b/lib/THCUNN/generic/THCUNN.h
@@ -763,6 +763,48 @@ TH_API void THNN_(SpatialDilatedConvolution_accGradParameters)(
int dilationW, int dilationH,
accreal scale);
+TH_API void THNN_(SpatialFullDilatedConvolution_updateOutput)(
+ THCState *state,
+ THCTensor *input,
+ THCTensor *output,
+ THCTensor *weight,
+ THCTensor *bias, // [OPTIONAL]
+ THCTensor *columns,
+ THCTensor *ones,
+ int kW, int kH,
+ int dW, int dH,
+ int padW, int padH,
+ int dilationW, int dilationH,
+ int adjW, int adjH);
+
+TH_API void THNN_(SpatialFullDilatedConvolution_updateGradInput)(
+ THCState *state,
+ THCTensor *input,
+ THCTensor *gradOutput,
+ THCTensor *gradInput,
+ THCTensor *weight,
+ THCTensor *gradColumns,
+ int kW, int kH,
+ int dW, int dH,
+ int padW, int padH,
+ int dilationW, int dilationH,
+ int adjW, int adjH);
+
+TH_API void THNN_(SpatialFullDilatedConvolution_accGradParameters)(
+ THCState *state,
+ THCTensor *input,
+ THCTensor *gradOutput,
+ THCTensor *gradWeight,
+ THCTensor *gradBias, // [OPTIONAL]
+ THCTensor *columns,
+ THCTensor *ones,
+ int kW, int kH,
+ int dW, int dH,
+ int padW, int padH,
+ int dilationW, int dilationH,
+ int adjW, int adjH,
+ accreal scale);
+
TH_API void THNN_(SpatialDilatedMaxPooling_updateOutput)(
THCState *state,
THCTensor *input,
@@ -1279,6 +1321,46 @@ TH_API void THNN_(VolumetricDilatedConvolution_accGradParameters)(
int dilationT, int dilationW, int dilationH,
accreal scale);
+TH_API void THNN_(VolumetricFullDilatedConvolution_updateOutput)(
+ THCState *state,
+ THCTensor *input,
+ THCTensor *output,
+ THCTensor *weight,
+ THCTensor *bias, // [OPTIONAL]
+ THCTensor *finput,
+ THCTensor *fgradInput,
+ int dT, int dW, int dH,
+ int padT, int padW, int padH,
+ int dilationT, int dilationW, int dilationH,
+ int adjT, int adjW, int adjH);
+
+TH_API void THNN_(VolumetricFullDilatedConvolution_updateGradInput)(
+ THCState *state,
+ THCTensor *input,
+ THCTensor *gradOutput,
+ THCTensor *gradInput,
+ THCTensor *weight,
+ THCTensor *finput,
+ THCTensor *fgradInput,
+ int dT, int dW, int dH,
+ int padT, int padW, int padH,
+ int dilationT, int dilationW, int dilationH,
+ int adjT, int adjW, int adjH);
+
+TH_API void THNN_(VolumetricFullDilatedConvolution_accGradParameters)(
+ THCState *state,
+ THCTensor *input,
+ THCTensor *gradOutput,
+ THCTensor *gradWeight,
+ THCTensor *gradBias, // [OPTIONAL]
+ THCTensor *finput,
+ THCTensor *fgradInput,
+ int dT, int dW, int dH,
+ int padT, int padW, int padH,
+ int dilationT, int dilationW, int dilationH,
+ int adjT, int adjW, int adjH,
+ accreal scale);
+
TH_API void THNN_(VolumetricDilatedMaxPooling_updateOutput)(
THCState *state,
THCTensor *input,
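
These declarations are stamped out once per floating type when the header is included through THCGenerateFloatTypes.h; the THNN_(NAME) macro splices the type tag into the exported symbol name. A standalone sketch of that concatenation (the macro definitions below are simplified assumptions, not the verbatim TH ones):

#include <cstdio>

// Simplified stand-ins for TH's concatenation machinery.
#define TH_CONCAT_3_EXPAND(a, b, c) a##b##c
#define TH_CONCAT_3(a, b, c) TH_CONCAT_3_EXPAND(a, b, c)

// For the float instantiation CReal is "Cuda" (CudaDouble/CudaHalf otherwise).
#define CReal Cuda
#define THNN_(NAME) TH_CONCAT_3(THNN_, CReal, NAME)

// Expands to THNN_CudaSpatialFullDilatedConvolution_updateOutput.
void THNN_(SpatialFullDilatedConvolution_updateOutput)() {
  std::puts("float instantiation called");
}

int main() {
  THNN_CudaSpatialFullDilatedConvolution_updateOutput();
  return 0;
}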
diff --git a/lib/THCUNN/generic/VolumetricFullConvolution.cu b/lib/THCUNN/generic/VolumetricFullConvolution.cu
index 9dd266c..9837a2d 100644
--- a/lib/THCUNN/generic/VolumetricFullConvolution.cu
+++ b/lib/THCUNN/generic/VolumetricFullConvolution.cu
@@ -2,75 +2,6 @@
#define THC_GENERIC_FILE "generic/VolumetricFullConvolution.cu"
#else
-static inline void THNN_(VolumetricFullConvolution_shapeCheck)(
- THCState *state,
- THCTensor *input,
- THCTensor *gradOutput,
- THCTensor *weight,
- THCTensor *bias,
- int dT, int dW, int dH,
- int padT, int padW, int padH,
- int adjT, int adjW, int adjH) {
- THCUNN_argCheck(state, input->nDimension == 4 || input->nDimension == 5, 2, input,
- "4D or 5D (batch mode) tensor expected for input, but got: %s");
- // number of input & output planes and kernel size is indirectly defined by the weight tensor
- THCUNN_argCheck(state, weight->nDimension == 5, 4, weight,
- "5D (nOutputPlane x nInputPlane x kT x kH x kW) tensor "
- "expected for weight, but got: %s");
- THArgCheck(THCTensor_(isContiguous)(state, weight), 4,
- "weight tensor has to be contiguous");
- THArgCheck(!bias || THCTensor_(isContiguous)(state, bias), 5,
- "bias tensor has to be contiguous");
- THArgCheck(dT > 0 && dW > 0 && dH > 0, 8,
- "stride should be greater than zero, but got dT: %d dH: %d dW: %d", dT, dH, dW);
- THArgCheck(adjT < dT && adjW < dW && adjH < dH, 14,
- "output adjustment must be smaller than stride, but got "
- "adjT: %d adjH: %d adjW: %d dT: %d dH: %d dW: %d",
- adjT, adjH, adjW, dT, dH, dW);
-
- int ndim = input->nDimension;
- int nInputPlane = THCTensor_(size)(state, weight, 0);
- int nOutputPlane = THCTensor_(size)(state, weight, 1);
- const int kT = (int)weight->size[2];
- const int kH = (int)weight->size[3];
- const int kW = (int)weight->size[4];
-
- if (bias != NULL) {
- THCUNN_check_dim_size(state, bias, 1, 0, weight->size[1]);
- }
-
- int dimf = 0;
- int dimd = 1;
- int dimh = 2;
- int dimw = 3;
-
- if (ndim == 5) {
- dimf++;
- dimd++;
- dimh++;
- dimw++;
- }
-
- long inputWidth = input->size[dimw];
- long inputHeight = input->size[dimh];
- long inputDepth = input->size[dimd];
- long outputWidth = (inputWidth - 1) * dW - 2*padW + kW + adjW;
- long outputHeight = (inputHeight - 1) * dH - 2*padH + kH + adjH;
- long outputDepth = (inputDepth - 1) * dT - 2*padT + kT + adjT;
-
- if (outputDepth < 1 || outputWidth < 1 || outputHeight < 1)
- THError("Given input size: (%dx%dx%dx%d). Calculated output size: (%dx%dx%dx%d). Output size is too small",
- nInputPlane,inputDepth,inputHeight,inputWidth,nOutputPlane,outputDepth,outputHeight,outputWidth);
-
- THCUNN_check_dim_size(state, input, ndim, dimf, nInputPlane);
- if (gradOutput != NULL) {
- THCUNN_check_dim_size(state, gradOutput, ndim, dimf, nOutputPlane);
- THCUNN_check_dim_size(state, gradOutput, ndim, dimd, outputDepth);
- THCUNN_check_dim_size(state, gradOutput, ndim, dimh, outputHeight);
- THCUNN_check_dim_size(state, gradOutput, ndim, dimw, outputWidth);
- }
-}
-
void THNN_(VolumetricFullConvolution_updateOutput)(
THCState *state,
THCTensor *input,
@@ -83,144 +14,9 @@ void THNN_(VolumetricFullConvolution_updateOutput)(
int padT, int padW, int padH,
int adjT, int adjW, int adjH)
{
-
- THCTensor *columns = finput;
- THCTensor *ones = fgradInput;
-
- int nInputPlane = THCTensor_(size)(state, weight, 0);
- int nOutputPlane = THCTensor_(size)(state, weight, 1);
- const int kT = (int)weight->size[2];
- const int kH = (int)weight->size[3];
- const int kW = (int)weight->size[4];
-
- THCUNN_assertSameGPU(state, 6, input, output, weight,
- bias, columns, ones);
- THNN_(VolumetricFullConvolution_shapeCheck)(
- state, input, NULL, weight, bias,
- dT, dW, dH, padT, padW, padH,
- adjT, adjW, adjH);
-
- input = THCTensor_(newContiguous)(state, input);
- weight = THCTensor_(newContiguous)(state, weight);
- bias = bias ? THCTensor_(newContiguous)(state, bias) : bias;
-
- int batch = 1;
- if (input->nDimension == 4) {
- // Force batch
- batch = 0;
- THCTensor_(resize5d)(state, input, 1, input->size[0], input->size[1], input->size[2], input->size[3]);
- }
-
- long inputWidth = input->size[4];
- long inputHeight = input->size[3];
- long inputDepth = input->size[2];
- long outputWidth = (inputWidth - 1) * dW - 2*padW + kW + adjW;
- long outputHeight = (inputHeight - 1) * dH - 2*padH + kH + adjH;
- long outputDepth = (inputDepth - 1) * dT - 2*padT + kT + adjT;
-
- // Batch size + input planes
- long batchSize = input->size[0];
-
- // Resize output
- THCTensor_(resize5d)(state, output, batchSize, nOutputPlane, outputDepth, outputHeight, outputWidth);
-
- // Resize temporary columns
- THCTensor_(resize2d)(state, columns, nOutputPlane*kW*kH*kT, inputDepth*inputHeight*inputWidth);
-
- // Define a buffer of ones, for bias accumulation
- // Note: this buffer can be shared with other modules, it only ever gets increased,
- // and always contains ones.
- if (ones->nDimension != 3 || ones->size[0]*ones->size[1]*ones->size[2] < outputDepth*outputHeight*outputWidth) {
- // Resize plane and fill with ones...
- THCTensor_(resize3d)(state, ones, outputDepth, outputHeight, outputWidth);
- THCTensor_(fill)(state, ones, ScalarConvert<int, real>::to(1));
- }
-
- // Helpers
- THCTensor *input_n = THCTensor_(new)(state);
- THCTensor *output_n = THCTensor_(new)(state);
-
- // For each elt in batch, do:
- for (int elt = 0; elt < batchSize; elt ++) {
- // Matrix mulitply per output:
- THCTensor_(select)(state, input_n, input, 0, elt);
- THCTensor_(select)(state, output_n, output, 0, elt);
-
- // M,N,K are dims of matrix A and B
- // (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm)
- long m = weight->size[1] * weight->size[2] * weight->size[3] * weight->size[4];
- long n = columns->size[1];
- long k = weight->size[0];
-
- // Do GEMM (note: this is a bit confusing because gemm assumes column-major matrices)
- #ifdef THC_REAL_IS_FLOAT
- THCudaBlas_Sgemm(
- #elif defined(THC_REAL_IS_HALF)
- THCudaBlas_Hgemm(
- #elif defined(THC_REAL_IS_DOUBLE)
- THCudaBlas_Dgemm(
- #endif
- state,
- 'n', 't',
- n, m, k,
- ScalarConvert<int, real>::to(1),
- THCTensor_(data)(state, input_n), n,
- THCTensor_(data)(state, weight), m,
- ScalarConvert<int, real>::to(0),
- THCTensor_(data)(state, columns), n
- );
-
- // Unpack columns back into input:
- col2vol<real, accreal>(
- THCState_getCurrentStream(state),
- THCTensor_(data)(state, columns),
- nOutputPlane, outputDepth, outputHeight, outputWidth, kT, kH, kW, padT, padH, padW, dT, dH, dW,
- 1,1,1,
- THCTensor_(data)(state, output_n)
- );
-
- // Do Bias after:
- // M,N,K are dims of matrix A and B
- // (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm)
- long m_ = nOutputPlane;
- long n_ = outputDepth * outputHeight * outputWidth;
- long k_ = 1;
-
- // Do GEMM (note: this is a bit confusing because gemm assumes column-major matrices)
- if (bias) {
- #ifdef THC_REAL_IS_FLOAT
- THCudaBlas_Sgemm(
- #elif defined(THC_REAL_IS_HALF)
- THCudaBlas_Hgemm(
- #elif defined(THC_REAL_IS_DOUBLE)
- THCudaBlas_Dgemm(
- #endif
- state,
- 't', 'n',
- n_, m_, k_,
- ScalarConvert<int, real>::to(1),
- THCTensor_(data)(state, ones), k_,
- THCTensor_(data)(state, bias), k_,
- ScalarConvert<int, real>::to(1),
- THCTensor_(data)(state, output_n), n_
- );
- }
- }
-
- // Free
- THCTensor_(free)(state, input_n);
- THCTensor_(free)(state, output_n);
-
- // Resize output
- if (batch == 0) {
- THCTensor_(resize4d)(state, output, nOutputPlane, outputDepth, outputHeight, outputWidth);
- THCTensor_(resize4d)(state, input, nInputPlane, inputDepth, inputHeight, inputWidth);
- }
-
- THCTensor_(free)(state, input);
- THCTensor_(free)(state, weight);
- if (bias) THCTensor_(free)(state, bias);
-
+ THNN_(VolumetricFullDilatedConvolution_updateOutput)(
+ state, input, output, weight, bias, finput, fgradInput,
+ dT, dW, dH, padT, padW, padH, 1, 1, 1, adjT, adjW, adjH);
}
void THNN_(VolumetricFullConvolution_updateGradInput)(
@@ -235,109 +31,9 @@ void THNN_(VolumetricFullConvolution_updateGradInput)(
int padT, int padW, int padH,
int adjT, int adjW, int adjH)
{
- THCTensor *gradColumns = finput;
-
- int nInputPlane = THCTensor_(size)(state, weight, 0);
- int nOutputPlane = THCTensor_(size)(state, weight, 1);
- const int kT = (int)weight->size[2];
- const int kH = (int)weight->size[3];
- const int kW = (int)weight->size[4];
-
- THCUNN_assertSameGPU(state, 5, input, gradOutput, weight,
- gradColumns, gradInput);
- THNN_(VolumetricFullConvolution_shapeCheck)(
- state, input, gradOutput, weight, NULL,
- dT, dW, dH, padT, padW, padH,
- adjT, adjW, adjH);
-
- input = THCTensor_(newContiguous)(state, input);
- gradOutput = THCTensor_(newContiguous)(state, gradOutput);
- weight = THCTensor_(newContiguous)(state, weight);
-
- int batch = 1;
- if (input->nDimension == 4) {
- // Force batch
- batch = 0;
- THCTensor_(resize5d)(state, input, 1, input->size[0], input->size[1], input->size[2], input->size[3]);
- THCTensor_(resize5d)(state, gradOutput, 1, gradOutput->size[0], gradOutput->size[1], gradOutput->size[2], gradOutput->size[3]);
- }
-
- long inputWidth = input->size[4];
- long inputHeight = input->size[3];
- long inputDepth = input->size[2];
- long outputWidth = (inputWidth - 1) * dW - 2*padW + kW + adjW;
- long outputHeight = (inputHeight - 1) * dH - 2*padH + kH + adjH;
- long outputDepth = (inputDepth - 1) * dT - 2*padT + kT + adjT;
-
- // Batch size + input planes
- long batchSize = input->size[0];
-
- // Resize output
- THCTensor_(resize5d)(state, gradInput, batchSize, nInputPlane, inputDepth, inputHeight, inputWidth);
-
- // Resize temporary columns
- THCTensor_(resize2d)(state, gradColumns, nOutputPlane*kW*kH*kT, inputDepth*inputHeight*inputWidth);
-
- // Helpers
- THCTensor *gradInput_n = THCTensor_(new)(state);
- THCTensor *gradOutput_n = THCTensor_(new)(state);
-
- // For each elt in batch, do:
- for (int elt = 0; elt < batchSize; elt ++) {
- // Matrix mulitply per sample:
- THCTensor_(select)(state, gradInput_n, gradInput, 0, elt);
- THCTensor_(select)(state, gradOutput_n, gradOutput, 0, elt);
-
- // Extract columns:
- vol2col(
- THCState_getCurrentStream(state),
- THCTensor_(data)(state, gradOutput_n),
- nOutputPlane, outputDepth, outputHeight, outputWidth, kT, kH, kW, padT, padH, padW, dT, dH, dW,
- 1,1,1,
- THCTensor_(data)(state, gradColumns)
- );
-
-
- // M,N,K are dims of matrix A and B
- // (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm)
- long m = weight->size[0];
- long n = gradColumns->size[1];
- long k = weight->size[1] * weight->size[2] * weight->size[3] * weight->size[4];
-
- // Do GEMM (note: this is a bit confusing because gemm assumes column-major matrices)
- #ifdef THC_REAL_IS_FLOAT
- THCudaBlas_Sgemm(
- #elif defined(THC_REAL_IS_HALF)
- THCudaBlas_Hgemm(
- #elif defined(THC_REAL_IS_DOUBLE)
- THCudaBlas_Dgemm(
- #endif
- state,
- 'n', 'n',
- n, m, k,
- ScalarConvert<int, real>::to(1),
- THCTensor_(data)(state, gradColumns), n,
- THCTensor_(data)(state, weight), k,
- ScalarConvert<int, real>::to(0),
- THCTensor_(data)(state, gradInput_n), n
- );
- }
-
-
- // Free
- THCTensor_(free)(state, gradInput_n);
- THCTensor_(free)(state, gradOutput_n);
-
- // Resize output
- if (batch == 0) {
- THCTensor_(resize4d)(state, gradOutput, nOutputPlane, outputDepth, outputHeight, outputWidth);
- THCTensor_(resize4d)(state, input, nInputPlane, inputDepth, inputHeight, inputWidth);
- THCTensor_(resize4d)(state, gradInput, nInputPlane, inputDepth, inputHeight, inputWidth);
- }
-
- THCTensor_(free)(state, input);
- THCTensor_(free)(state, gradOutput);
- THCTensor_(free)(state, weight);
+ THNN_(VolumetricFullDilatedConvolution_updateGradInput)(
+ state, input, gradOutput, gradInput, weight, finput, fgradInput,
+ dT, dW, dH, padT, padW, padH, 1, 1, 1, adjT, adjW, adjH);
}
@@ -354,152 +50,9 @@ void THNN_(VolumetricFullConvolution_accGradParameters)(
int adjT, int adjW, int adjH,
accreal scale_)
{
- real scale = ScalarConvert<accreal, real>::to(scale_);
- THCTensor *columns = finput;
- THCTensor *ones = fgradInput;
-
- int nInputPlane = THCTensor_(size)(state, gradWeight, 0);
- int nOutputPlane = THCTensor_(size)(state, gradWeight, 1);
- const int kT = (int)gradWeight->size[2];
- const int kH = (int)gradWeight->size[3];
- const int kW = (int)gradWeight->size[4];
-
- THCUNN_assertSameGPU(state, 6, input, gradOutput, gradWeight,
- gradBias, columns, ones);
- THNN_(VolumetricFullConvolution_shapeCheck)(
- state, input, gradOutput, gradWeight,
- gradBias, dT, dW, dH, padT, padW, padH,
- adjT, adjW, adjH);
-
- THArgCheck(THCTensor_(isContiguous)(state, gradWeight), 4, "gradWeight needs to be contiguous");
- if (gradBias)
- THArgCheck(THCTensor_(isContiguous)(state, gradBias), 5, "gradBias needs to be contiguous");
-
- input = THCTensor_(newContiguous)(state, input);
- gradOutput = THCTensor_(newContiguous)(state, gradOutput);
-
- int batch = 1;
- if (input->nDimension == 4) {
- // Force batch
- batch = 0;
- THCTensor_(resize5d)(state, input, 1, input->size[0], input->size[1], input->size[2], input->size[3]);
- THCTensor_(resize5d)(state, gradOutput, 1, gradOutput->size[0], gradOutput->size[1], gradOutput->size[2], gradOutput->size[3]);
- }
-
- long inputWidth = input->size[4];
- long inputHeight = input->size[3];
- long inputDepth = input->size[2];
- long outputWidth = (inputWidth - 1) * dW - 2*padW + kW + adjW;
- long outputHeight = (inputHeight - 1) * dH - 2*padH + kH + adjH;
- long outputDepth = (inputDepth - 1) * dT - 2*padT + kT + adjT;
-
- // Batch size + input planes
- long batchSize = input->size[0];
-
- // Define a buffer of ones, for bias accumulation
- if (ones->nDimension != 3 || ones->size[0]*ones->size[1]*ones->size[2] < outputDepth*outputHeight*outputWidth) {
- // Resize plane and fill with ones...
- THCTensor_(resize3d)(state, ones, outputDepth, outputHeight, outputWidth);
- THCTensor_(fill)(state, ones, ScalarConvert<int, real>::to(1));
- }
-
- // Resize temporary columns
- THCTensor_(resize2d)(state, columns, nOutputPlane*kW*kH*kT, inputDepth*inputHeight*inputWidth);
-
- // Helpers
- THCTensor *input_n = THCTensor_(new)(state);
- THCTensor *gradOutput_n = THCTensor_(new)(state);
-
- // For each elt in batch, do:
- for (int elt = 0; elt < batchSize; elt ++) {
- // Matrix mulitply per output:
- THCTensor_(select)(state, input_n, input, 0, elt);
- THCTensor_(select)(state, gradOutput_n, gradOutput, 0, elt);
-
- // Extract columns:
- vol2col(
- THCState_getCurrentStream(state),
- THCTensor_(data)(state, gradOutput_n),
- nOutputPlane, outputDepth, outputHeight, outputWidth, kT, kH, kW, padT, padH, padW, dT, dH, dW,
- 1,1,1,
- THCTensor_(data)(state, columns)
- );
-
- // M,N,K are dims of matrix A and B
- // (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm)
- long n = columns->size[0]; // nOutputPlane * kt * kh * kw
- long m = input_n->size[0]; // nInputPlane
- long k = columns->size[1]; // inputHeight * inputWidth
-
- // Do GEMM (note: this is a bit confusing because gemm assumes column-major matrices)
- #ifdef THC_REAL_IS_FLOAT
- THCudaBlas_Sgemm(
- #elif defined(THC_REAL_IS_HALF)
- THCudaBlas_Hgemm(
- #elif defined(THC_REAL_IS_DOUBLE)
- THCudaBlas_Dgemm(
- #endif
- state,
- 't', 'n',
- n, m, k,
- scale,
- THCTensor_(data)(state, columns), k,
- THCTensor_(data)(state, input_n), k,
- ScalarConvert<int, real>::to(1),
- THCTensor_(data)(state, gradWeight), n
- );
-
- // Do Bias:
- // M,N,K are dims of matrix A and B
- // (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm)
- long m_ = nOutputPlane;
- long k_ = outputDepth * outputHeight * outputWidth;
-
- // Do GEMV (note: this is a bit confusing because gemv assumes column-major matrices)
- if (gradBias) {
- #if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE)
- #ifdef THC_REAL_IS_FLOAT
- THCudaBlas_Sgemv(
- #elif defined(THC_REAL_IS_DOUBLE)
- THCudaBlas_Dgemv(
- #endif
- state,
- 't',
- k_, m_,
- scale,
- THCTensor_(data)(state, gradOutput_n), k_,
- THCTensor_(data)(state, ones), 1,
- ScalarConvert<int, real>::to(1),
- THCTensor_(data)(state, gradBias), 1
- );
- #endif
- #ifdef THC_REAL_IS_HALF
- THCudaBlas_Hgemm(
- state,
- 't', 'n',
- m_, 1, k_,
- scale,
- THCTensor_(data)(state, gradOutput_n), k_,
- THCTensor_(data)(state, ones), k_,
- ScalarConvert<int, real>::to(1),
- THCTensor_(data)(state, gradBias), m_
- );
- #endif
- }
- }
-
- // Free
- THCTensor_(free)(state, input_n);
- THCTensor_(free)(state, gradOutput_n);
-
- // Resize
- if (batch == 0) {
- THCTensor_(resize4d)(state, gradOutput, nOutputPlane, outputDepth, outputHeight, outputWidth);
- THCTensor_(resize4d)(state, input, nInputPlane, inputDepth, inputHeight, inputWidth);
- }
-
- THCTensor_(free)(state, input);
- THCTensor_(free)(state, gradOutput);
+ THNN_(VolumetricFullDilatedConvolution_accGradParameters)(
+ state, input, gradOutput, gradWeight, gradBias, finput, fgradInput,
+ dT, dW, dH, padT, padW, padH, 1, 1, 1, adjT, adjW, adjH, scale_);
}
#endif
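
As in the spatial case, the updateGradInput paths (the body removed above and its dilated replacement below) gather gradOutput patches with vol2col/im2col and contract them against the weight with an 'n','n' GEMM; that composite is just an ordinary dilated convolution of gradOutput, the adjoint of the forward scatter sketched earlier. A single-plane CPU reference of the gather (illustrative only):

#include <vector>

// Gradient w.r.t. the input of the transposed convolution above: an ordinary
// dilated convolution of gradOutput, gathering back over the kernel footprint.
std::vector<float> fullDilatedConv2dGradInput(
    const std::vector<float> &gradOut, long outH, long outW,
    const std::vector<float> &w, int kH, int kW,
    int dH, int dW, int padH, int padW, int dilH, int dilW,
    long inH, long inW) {
  std::vector<float> gradIn(inH * inW, 0.f);
  for (long y = 0; y < inH; ++y)
    for (long x = 0; x < inW; ++x)
      for (int i = 0; i < kH; ++i)
        for (int j = 0; j < kW; ++j) {
          long oy = y * dH - padH + (long)i * dilH;  // same index map as the forward scatter
          long ox = x * dW - padW + (long)j * dilW;
          if (oy >= 0 && oy < outH && ox >= 0 && ox < outW)
            gradIn[y * inW + x] += gradOut[oy * outW + ox] * w[i * kW + j];
        }
  return gradIn;
}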
diff --git a/lib/THCUNN/generic/VolumetricFullDilatedConvolution.cu b/lib/THCUNN/generic/VolumetricFullDilatedConvolution.cu
new file mode 100644
index 0000000..bda0b59
--- /dev/null
+++ b/lib/THCUNN/generic/VolumetricFullDilatedConvolution.cu
@@ -0,0 +1,511 @@
+#ifndef THC_GENERIC_FILE
+#define THC_GENERIC_FILE "generic/VolumetricFullDilatedConvolution.cu"
+#else
+
+static inline void THNN_(VolumetricFullDilatedConvolution_shapeCheck)(
+ THCState *state,
+ THCTensor *input,
+ THCTensor *gradOutput,
+ THCTensor *weight,
+ THCTensor *bias,
+ int dT, int dW, int dH,
+ int padT, int padW, int padH,
+ int dilationT, int dilationW, int dilationH,
+ int adjT, int adjW, int adjH) {
+ THCUNN_argCheck(state, input->nDimension == 4 || input->nDimension == 5, 2, input,
+ "4D or 5D (batch mode) tensor expected for input, but got: %s");
+ // number of input & output planes and kernel size is indirectly defined by the weight tensor
+ THCUNN_argCheck(state, weight->nDimension == 5, 4, weight,
+ "5D (nOutputPlane x nInputPlane x kT x kH x kW) tensor "
+ "expected for weight, but got: %s");
+ THArgCheck(THCTensor_(isContiguous)(state, weight), 4,
+ "weight tensor has to be contiguous");
+ THArgCheck(!bias || THCTensor_(isContiguous)(state, bias), 5,
+ "bias tensor has to be contiguous");
+ THArgCheck(dT > 0 && dW > 0 && dH > 0, 8,
+ "stride should be greater than zero, but got dT: %d dH: %d dW: %d", dT, dH, dW);
+ THArgCheck(adjT < dT && adjW < dW && adjH < dH, 14,
+ "output adjustment must be smaller than stride, but got "
+ "adjT: %d adjH: %d adjW: %d dT: %d dH: %d dW: %d",
+ adjT, adjH, adjW, dT, dH, dW);
+ THArgCheck(dilationT > 0 && dilationW > 0 && dilationH > 0, 15,
+ "dilation should be greater than zero, but got dilationT: %d, dilationH: %d, dilationW: %d",
+ dilationT, dilationH, dilationW);
+
+ int ndim = input->nDimension;
+ int nInputPlane = THCTensor_(size)(state, weight, 0);
+ int nOutputPlane = THCTensor_(size)(state, weight, 1);
+ const int kT = (int)weight->size[2];
+ const int kH = (int)weight->size[3];
+ const int kW = (int)weight->size[4];
+
+ if (bias != NULL) {
+ THCUNN_check_dim_size(state, bias, 1, 0, weight->size[1]);
+ }
+
+ int dimf = 0;
+ int dimd = 1;
+ int dimh = 2;
+ int dimw = 3;
+
+ if (ndim == 5) {
+ dimf++;
+ dimd++;
+ dimh++;
+ dimw++;
+ }
+
+ long inputWidth = input->size[dimw];
+ long inputHeight = input->size[dimh];
+ long inputDepth = input->size[dimd];
+ long outputDepth = (inputDepth - 1) * dT - 2*padT + (dilationT * (kT - 1) + 1) + adjT;
+ long outputHeight = (inputHeight - 1) * dH - 2*padH + (dilationH * (kH - 1) + 1) + adjH;
+ long outputWidth = (inputWidth - 1) * dW - 2*padW + (dilationW * (kW - 1) + 1) + adjW;
+ if (outputDepth < 1 || outputWidth < 1 || outputHeight < 1)
+ THError("Given input size: (%dx%dx%dx%d). Calculated output size: (%dx%dx%dx%d). Output size is too small",
+ nInputPlane,inputDepth,inputHeight,inputWidth,nOutputPlane,outputDepth,outputHeight,outputWidth);
+
+ THCUNN_check_dim_size(state, input, ndim, dimf, nInputPlane);
+ if (gradOutput != NULL) {
+ THCUNN_check_dim_size(state, gradOutput, ndim, dimf, nOutputPlane);
+ THCUNN_check_dim_size(state, gradOutput, ndim, dimd, outputDepth);
+ THCUNN_check_dim_size(state, gradOutput, ndim, dimh, outputHeight);
+ THCUNN_check_dim_size(state, gradOutput, ndim, dimw, outputWidth);
+ }
+}
+
+void THNN_(VolumetricFullDilatedConvolution_updateOutput)(
+ THCState *state,
+ THCTensor *input,
+ THCTensor *output,
+ THCTensor *weight,
+ THCTensor *bias,
+ THCTensor *finput,
+ THCTensor *fgradInput,
+ int dT, int dW, int dH,
+ int padT, int padW, int padH,
+ int dilationT, int dilationW, int dilationH,
+ int adjT, int adjW, int adjH)
+{
+
+ THCTensor *columns = finput;
+ THCTensor *ones = fgradInput;
+
+ int nInputPlane = THCTensor_(size)(state, weight, 0);
+ int nOutputPlane = THCTensor_(size)(state, weight, 1);
+ const int kT = (int)weight->size[2];
+ const int kH = (int)weight->size[3];
+ const int kW = (int)weight->size[4];
+
+ THCUNN_assertSameGPU(state, 6, input, output, weight,
+ bias, columns, ones);
+ THNN_(VolumetricFullDilatedConvolution_shapeCheck)(
+ state, input, NULL, weight, bias,
+ dT, dW, dH, padT, padW, padH, dilationT, dilationW, dilationH,
+ adjT, adjW, adjH);
+
+ input = THCTensor_(newContiguous)(state, input);
+ weight = THCTensor_(newContiguous)(state, weight);
+ bias = bias ? THCTensor_(newContiguous)(state, bias) : bias;
+
+ int batch = 1;
+ if (input->nDimension == 4) {
+ // Force batch
+ batch = 0;
+ THCTensor_(resize5d)(state, input, 1, input->size[0], input->size[1], input->size[2], input->size[3]);
+ }
+
+ long inputWidth = input->size[4];
+ long inputHeight = input->size[3];
+ long inputDepth = input->size[2];
+ long outputDepth = (inputDepth - 1) * dT - 2*padT + (dilationT * (kT - 1) + 1) + adjT;
+ long outputHeight = (inputHeight - 1) * dH - 2*padH + (dilationH * (kH - 1) + 1) + adjH;
+ long outputWidth = (inputWidth - 1) * dW - 2*padW + (dilationW * (kW - 1) + 1) + adjW;
+
+  // Batch size
+ long batchSize = input->size[0];
+
+ // Resize output
+ THCTensor_(resize5d)(state, output, batchSize, nOutputPlane, outputDepth, outputHeight, outputWidth);
+
+ // Resize temporary columns
+ THCTensor_(resize2d)(state, columns, nOutputPlane*kW*kH*kT, inputDepth*inputHeight*inputWidth);
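+  // columns is a 2D view with one row per (output plane, kernel offset) pair
+  // and one column per input voxel; the GEMM below fills it with all kernel
+  // taps for every input location, which col2vol then scatters into output.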
+
+  // Define a buffer of ones, for bias accumulation.
+  // Note: this buffer can be shared with other modules; it only ever grows
+  // in size, and it always contains ones.
+ if (ones->nDimension != 3 || ones->size[0]*ones->size[1]*ones->size[2] < outputDepth*outputHeight*outputWidth) {
+ // Resize plane and fill with ones...
+ THCTensor_(resize3d)(state, ones, outputDepth, outputHeight, outputWidth);
+ THCTensor_(fill)(state, ones, ScalarConvert<int, real>::to(1));
+ }
+
+ // Helpers
+ THCTensor *input_n = THCTensor_(new)(state);
+ THCTensor *output_n = THCTensor_(new)(state);
+
+  // For each element in the batch, do:
+  for (int elt = 0; elt < batchSize; elt++) {
+    // Matrix multiply per output:
+ THCTensor_(select)(state, input_n, input, 0, elt);
+ THCTensor_(select)(state, output_n, output, 0, elt);
+
+ // M,N,K are dims of matrix A and B
+ // (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm)
+ long m = weight->size[1] * weight->size[2] * weight->size[3] * weight->size[4];
+ long n = columns->size[1];
+ long k = weight->size[0];
+
+ // Do GEMM (note: this is a bit confusing because gemm assumes column-major matrices)
+ #ifdef THC_REAL_IS_FLOAT
+ THCudaBlas_Sgemm(
+ #elif defined(THC_REAL_IS_HALF)
+ THCudaBlas_Hgemm(
+ #elif defined(THC_REAL_IS_DOUBLE)
+ THCudaBlas_Dgemm(
+ #endif
+ state,
+ 'n', 't',
+ n, m, k,
+ ScalarConvert<int, real>::to(1),
+ THCTensor_(data)(state, input_n), n,
+ THCTensor_(data)(state, weight), m,
+ ScalarConvert<int, real>::to(0),
+ THCTensor_(data)(state, columns), n
+ );
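+    // The GEMM above computes, in effect (row-major view),
+    //   columns (m x n) = weight^T (m x k) * input_n (k x n),
+    // i.e. every kernel tap of the transposed convolution for every input
+    // voxel at once.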
+
+    // Unpack columns back into output:
+ col2vol<real, accreal>(
+ THCState_getCurrentStream(state),
+ THCTensor_(data)(state, columns),
+ nOutputPlane, outputDepth, outputHeight, outputWidth, kT, kH, kW, padT, padH, padW, dT, dH, dW,
+ dilationT, dilationH, dilationW,
+ THCTensor_(data)(state, output_n)
+ );
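+    // col2vol is the adjoint of vol2col: values of overlapping (strided,
+    // dilated) kernel windows are summed back into the output volume.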
+
+    // Add bias afterwards:
+ // M,N,K are dims of matrix A and B
+ // (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm)
+ long m_ = nOutputPlane;
+ long n_ = outputDepth * outputHeight * outputWidth;
+ long k_ = 1;
+
+ // Do GEMM (note: this is a bit confusing because gemm assumes column-major matrices)
+ if (bias) {
+ #ifdef THC_REAL_IS_FLOAT
+ THCudaBlas_Sgemm(
+ #elif defined(THC_REAL_IS_HALF)
+ THCudaBlas_Hgemm(
+ #elif defined(THC_REAL_IS_DOUBLE)
+ THCudaBlas_Dgemm(
+ #endif
+ state,
+ 't', 'n',
+ n_, m_, k_,
+ ScalarConvert<int, real>::to(1),
+ THCTensor_(data)(state, ones), k_,
+ THCTensor_(data)(state, bias), k_,
+ ScalarConvert<int, real>::to(1),
+ THCTensor_(data)(state, output_n), n_
+ );
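+      // The GEMM above is a rank-1 update: output_n += bias * ones^T,
+      // broadcasting each plane's bias over all of its output voxels.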
+ }
+ }
+
+ // Free
+ THCTensor_(free)(state, input_n);
+ THCTensor_(free)(state, output_n);
+
+  // Resize back to the original (non-batched) shape
+ if (batch == 0) {
+ THCTensor_(resize4d)(state, output, nOutputPlane, outputDepth, outputHeight, outputWidth);
+ THCTensor_(resize4d)(state, input, nInputPlane, inputDepth, inputHeight, inputWidth);
+ }
+
+ THCTensor_(free)(state, input);
+ THCTensor_(free)(state, weight);
+ if (bias) THCTensor_(free)(state, bias);
+}
+
+void THNN_(VolumetricFullDilatedConvolution_updateGradInput)(
+ THCState *state,
+ THCTensor *input,
+ THCTensor *gradOutput,
+ THCTensor *gradInput,
+ THCTensor *weight,
+ THCTensor *finput,
+ THCTensor *fgradInput,
+ int dT, int dW, int dH,
+ int padT, int padW, int padH,
+ int dilationT, int dilationW, int dilationH,
+ int adjT, int adjW, int adjH)
+{
+ THCTensor *gradColumns = finput;
+
+ int nInputPlane = THCTensor_(size)(state, weight, 0);
+ int nOutputPlane = THCTensor_(size)(state, weight, 1);
+ const int kT = (int)weight->size[2];
+ const int kH = (int)weight->size[3];
+ const int kW = (int)weight->size[4];
+
+ THCUNN_assertSameGPU(state, 5, input, gradOutput, weight,
+ gradColumns, gradInput);
+ THNN_(VolumetricFullDilatedConvolution_shapeCheck)(
+ state, input, gradOutput, weight, NULL,
+ dT, dW, dH, padT, padW, padH, dilationT, dilationW, dilationH,
+ adjT, adjW, adjH);
+
+ input = THCTensor_(newContiguous)(state, input);
+ gradOutput = THCTensor_(newContiguous)(state, gradOutput);
+ weight = THCTensor_(newContiguous)(state, weight);
+
+ int batch = 1;
+ if (input->nDimension == 4) {
+ // Force batch
+ batch = 0;
+ THCTensor_(resize5d)(state, input, 1, input->size[0], input->size[1], input->size[2], input->size[3]);
+ THCTensor_(resize5d)(state, gradOutput, 1, gradOutput->size[0], gradOutput->size[1], gradOutput->size[2], gradOutput->size[3]);
+ }
+
+ long inputWidth = input->size[4];
+ long inputHeight = input->size[3];
+ long inputDepth = input->size[2];
+ long outputDepth = (inputDepth - 1) * dT - 2*padT + (dilationT * (kT - 1) + 1) + adjT;
+ long outputHeight = (inputHeight - 1) * dH - 2*padH + (dilationH * (kH - 1) + 1) + adjH;
+ long outputWidth = (inputWidth - 1) * dW - 2*padW + (dilationW * (kW - 1) + 1) + adjW;
+
+  // Batch size
+ long batchSize = input->size[0];
+
+  // Resize gradInput
+ THCTensor_(resize5d)(state, gradInput, batchSize, nInputPlane, inputDepth, inputHeight, inputWidth);
+
+ // Resize temporary columns
+ THCTensor_(resize2d)(state, gradColumns, nOutputPlane*kW*kH*kT, inputDepth*inputHeight*inputWidth);
+
+ // Helpers
+ THCTensor *gradInput_n = THCTensor_(new)(state);
+ THCTensor *gradOutput_n = THCTensor_(new)(state);
+
+  // For each element in the batch, do:
+  for (int elt = 0; elt < batchSize; elt++) {
+    // Matrix multiply per sample:
+ THCTensor_(select)(state, gradInput_n, gradInput, 0, elt);
+ THCTensor_(select)(state, gradOutput_n, gradOutput, 0, elt);
+
+ // Extract columns:
+ vol2col(
+ THCState_getCurrentStream(state),
+ THCTensor_(data)(state, gradOutput_n),
+ nOutputPlane, outputDepth, outputHeight, outputWidth, kT, kH, kW, padT, padH, padW, dT, dH, dW,
+ dilationT, dilationH, dilationW,
+ THCTensor_(data)(state, gradColumns)
+ );
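+    // vol2col above gathers, for each input voxel, every (dilated) gradOutput
+    // tap that the forward col2vol scattered to it, ready for the GEMM below.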
+
+ // M,N,K are dims of matrix A and B
+ // (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm)
+ long m = weight->size[0];
+ long n = gradColumns->size[1];
+ long k = weight->size[1] * weight->size[2] * weight->size[3] * weight->size[4];
+
+ // Do GEMM (note: this is a bit confusing because gemm assumes column-major matrices)
+ #ifdef THC_REAL_IS_FLOAT
+ THCudaBlas_Sgemm(
+ #elif defined(THC_REAL_IS_HALF)
+ THCudaBlas_Hgemm(
+ #elif defined(THC_REAL_IS_DOUBLE)
+ THCudaBlas_Dgemm(
+ #endif
+ state,
+ 'n', 'n',
+ n, m, k,
+ ScalarConvert<int, real>::to(1),
+ THCTensor_(data)(state, gradColumns), n,
+ THCTensor_(data)(state, weight), k,
+ ScalarConvert<int, real>::to(0),
+ THCTensor_(data)(state, gradInput_n), n
+ );
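+    // The GEMM above computes, in effect (row-major view),
+    //   gradInput_n (m x n) = weight (m x k) * gradColumns (k x n),
+    // with m = nInputPlane, k = nOutputPlane*kT*kH*kW and n the input voxel count.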
+ }
+
+ // Free
+ THCTensor_(free)(state, gradInput_n);
+ THCTensor_(free)(state, gradOutput_n);
+
+  // Resize back to the original (non-batched) shape
+ if (batch == 0) {
+ THCTensor_(resize4d)(state, gradOutput, nOutputPlane, outputDepth, outputHeight, outputWidth);
+ THCTensor_(resize4d)(state, input, nInputPlane, inputDepth, inputHeight, inputWidth);
+ THCTensor_(resize4d)(state, gradInput, nInputPlane, inputDepth, inputHeight, inputWidth);
+ }
+
+ THCTensor_(free)(state, input);
+ THCTensor_(free)(state, gradOutput);
+ THCTensor_(free)(state, weight);
+}
+
+void THNN_(VolumetricFullDilatedConvolution_accGradParameters)(
+ THCState *state,
+ THCTensor *input,
+ THCTensor *gradOutput,
+ THCTensor *gradWeight,
+ THCTensor *gradBias,
+ THCTensor *finput,
+ THCTensor *fgradInput,
+ int dT, int dW, int dH,
+ int padT, int padW, int padH,
+ int dilationT, int dilationW, int dilationH,
+ int adjT, int adjW, int adjH,
+ accreal scale_)
+{
+ real scale = ScalarConvert<accreal, real>::to(scale_);
+ THCTensor *columns = finput;
+ THCTensor *ones = fgradInput;
+
+ int nInputPlane = THCTensor_(size)(state, gradWeight, 0);
+ int nOutputPlane = THCTensor_(size)(state, gradWeight, 1);
+ const int kT = (int)gradWeight->size[2];
+ const int kH = (int)gradWeight->size[3];
+ const int kW = (int)gradWeight->size[4];
+
+ THCUNN_assertSameGPU(state, 6, input, gradOutput, gradWeight,
+ gradBias, columns, ones);
+ THNN_(VolumetricFullDilatedConvolution_shapeCheck)(
+ state, input, gradOutput, gradWeight,
+ gradBias, dT, dW, dH, padT, padW, padH, dilationT, dilationW, dilationH,
+ adjT, adjW, adjH);
+
+ THArgCheck(THCTensor_(isContiguous)(state, gradWeight), 4, "gradWeight needs to be contiguous");
+ if (gradBias)
+ THArgCheck(THCTensor_(isContiguous)(state, gradBias), 5, "gradBias needs to be contiguous");
+
+ input = THCTensor_(newContiguous)(state, input);
+ gradOutput = THCTensor_(newContiguous)(state, gradOutput);
+
+ int batch = 1;
+ if (input->nDimension == 4) {
+ // Force batch
+ batch = 0;
+ THCTensor_(resize5d)(state, input, 1, input->size[0], input->size[1], input->size[2], input->size[3]);
+ THCTensor_(resize5d)(state, gradOutput, 1, gradOutput->size[0], gradOutput->size[1], gradOutput->size[2], gradOutput->size[3]);
+ }
+
+ long inputWidth = input->size[4];
+ long inputHeight = input->size[3];
+ long inputDepth = input->size[2];
+ long outputDepth = (inputDepth - 1) * dT - 2*padT + (dilationT * (kT - 1) + 1) + adjT;
+ long outputHeight = (inputHeight - 1) * dH - 2*padH + (dilationH * (kH - 1) + 1) + adjH;
+ long outputWidth = (inputWidth - 1) * dW - 2*padW + (dilationW * (kW - 1) + 1) + adjW;
+
+  // Batch size
+ long batchSize = input->size[0];
+
+ // Define a buffer of ones, for bias accumulation
+ if (ones->nDimension != 3 || ones->size[0]*ones->size[1]*ones->size[2] < outputDepth*outputHeight*outputWidth) {
+ // Resize plane and fill with ones...
+ THCTensor_(resize3d)(state, ones, outputDepth, outputHeight, outputWidth);
+ THCTensor_(fill)(state, ones, ScalarConvert<int, real>::to(1));
+ }
+
+ // Resize temporary columns
+ THCTensor_(resize2d)(state, columns, nOutputPlane*kW*kH*kT, inputDepth*inputHeight*inputWidth);
+
+ // Helpers
+ THCTensor *input_n = THCTensor_(new)(state);
+ THCTensor *gradOutput_n = THCTensor_(new)(state);
+
+  // For each element in the batch, do:
+  for (int elt = 0; elt < batchSize; elt++) {
+    // Matrix multiply per output:
+ THCTensor_(select)(state, input_n, input, 0, elt);
+ THCTensor_(select)(state, gradOutput_n, gradOutput, 0, elt);
+
+ // Extract columns:
+ vol2col(
+ THCState_getCurrentStream(state),
+ THCTensor_(data)(state, gradOutput_n),
+ nOutputPlane, outputDepth, outputHeight, outputWidth, kT, kH, kW, padT, padH, padW, dT, dH, dW,
+ dilationT, dilationH, dilationW,
+ THCTensor_(data)(state, columns)
+ );
+
+ // M,N,K are dims of matrix A and B
+ // (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm)
+    long n = columns->size[0];   // nOutputPlane * kT * kH * kW
+    long m = input_n->size[0];   // nInputPlane
+    long k = columns->size[1];   // inputDepth * inputHeight * inputWidth
+
+ // Do GEMM (note: this is a bit confusing because gemm assumes column-major matrices)
+ #ifdef THC_REAL_IS_FLOAT
+ THCudaBlas_Sgemm(
+ #elif defined(THC_REAL_IS_HALF)
+ THCudaBlas_Hgemm(
+ #elif defined(THC_REAL_IS_DOUBLE)
+ THCudaBlas_Dgemm(
+ #endif
+ state,
+ 't', 'n',
+ n, m, k,
+ scale,
+ THCTensor_(data)(state, columns), k,
+ THCTensor_(data)(state, input_n), k,
+ ScalarConvert<int, real>::to(1),
+ THCTensor_(data)(state, gradWeight), n
+ );
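+    // The GEMM above accumulates, in effect (row-major view),
+    //   gradWeight (m x n) += scale * input_n (m x k) * columns^T (k x n),
+    // summing the per-sample contributions over the batch.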
+
+    // Accumulate bias gradient:
+    // m_ and k_ are the dims of the reduction below
+    // (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm)
+ long m_ = nOutputPlane;
+ long k_ = outputDepth * outputHeight * outputWidth;
+
+ // Do GEMV (note: this is a bit confusing because gemv assumes column-major matrices)
+ if (gradBias) {
+ #if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE)
+ #ifdef THC_REAL_IS_FLOAT
+ THCudaBlas_Sgemv(
+ #elif defined(THC_REAL_IS_DOUBLE)
+ THCudaBlas_Dgemv(
+ #endif
+ state,
+ 't',
+ k_, m_,
+ scale,
+ THCTensor_(data)(state, gradOutput_n), k_,
+ THCTensor_(data)(state, ones), 1,
+ ScalarConvert<int, real>::to(1),
+ THCTensor_(data)(state, gradBias), 1
+ );
+ #endif
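+      // No gemv wrapper is used for half precision; the same per-plane
+      // reduction is written as an (m_ x 1) GEMM against the ones buffer.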
+ #ifdef THC_REAL_IS_HALF
+ THCudaBlas_Hgemm(
+ state,
+ 't', 'n',
+ m_, 1, k_,
+ scale,
+ THCTensor_(data)(state, gradOutput_n), k_,
+ THCTensor_(data)(state, ones), k_,
+ ScalarConvert<int, real>::to(1),
+ THCTensor_(data)(state, gradBias), m_
+ );
+ #endif
+ }
+ }
+
+ // Free
+ THCTensor_(free)(state, input_n);
+ THCTensor_(free)(state, gradOutput_n);
+
+  // Resize back to the original (non-batched) shape
+ if (batch == 0) {
+ THCTensor_(resize4d)(state, gradOutput, nOutputPlane, outputDepth, outputHeight, outputWidth);
+ THCTensor_(resize4d)(state, input, nInputPlane, inputDepth, inputHeight, inputWidth);
+ }
+
+ THCTensor_(free)(state, input);
+ THCTensor_(free)(state, gradOutput);
+}
+
+#endif