github.com/torch/cunn.git

author     Will Frey <will.frey@digitalreasoning.com>   2017-01-27 21:30:25 +0300
committer  Soumith Chintala <soumith@gmail.com>         2017-01-27 21:30:25 +0300
commit     37db7b80735e0a6d74c0f7a9b6fc72b1df5ccd38 (patch)
tree       b6b2be2dd6761e3b67582921c5066dcaf8b3b655
parent     5fa193a84ca8fe112bf6e75487ac96eb1b1239d2 (diff)
Added cunn support for TemporalRowConvolutionMM (#415)
* Added cunn TemporalRowConvolutionMM support
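
For reference, a minimal usage sketch of the new module in Lua. The constructor
arguments and the featFirst/noBias switches follow the tests added below; padW is
assumed to default to 0, so nOutputFrame = (nInputFrame + 2*padW - kW) / dW + 1.

require 'cunn'

local nFeature, kW, dW = 8, 3, 1
local nInputFrame = 20   -- with padW = 0: nOutputFrame = (20 - 3) / 1 + 1 = 18

local rowconv = nn.TemporalRowConvolution(nFeature, kW, dW):cuda()
rowconv.featFirst = true                -- accept nFeature x nInputFrame input, as in the featFirst tests
-- rowconv:noBias()                     -- optional: drop the bias term

local input  = torch.CudaTensor(nFeature, nInputFrame):normal()
local output = rowconv:forward(input)   -- nFeature x 18: one 1D kernel per feature row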
-rw-r--r--  lib/THCUNN/TemporalRowConvolution.cu           10
-rw-r--r--  lib/THCUNN/generic/THCUNN.h                    40
-rw-r--r--  lib/THCUNN/generic/TemporalRowConvolution.cu  423
-rw-r--r--  lib/THCUNN/row2col.h                           90
-rw-r--r--  test.lua                                      269
5 files changed, 832 insertions, 0 deletions
diff --git a/lib/THCUNN/TemporalRowConvolution.cu b/lib/THCUNN/TemporalRowConvolution.cu
new file mode 100644
index 0000000..dc3b18c
--- /dev/null
+++ b/lib/THCUNN/TemporalRowConvolution.cu
@@ -0,0 +1,10 @@
+#include "THCUNN.h"
+#include "common.h"
+#include "row2col.h"
+
+#include "THCHalf.h"
+#include "THCHalfAutoNumerics.cuh"
+
+#include "generic/TemporalRowConvolution.cu"
+
+#include "THCGenerateFloatTypes.h"
diff --git a/lib/THCUNN/generic/THCUNN.h b/lib/THCUNN/generic/THCUNN.h
index bf903b9..ec3d287 100644
--- a/lib/THCUNN/generic/THCUNN.h
+++ b/lib/THCUNN/generic/THCUNN.h
@@ -933,6 +933,46 @@ TH_API void THNN_(TemporalMaxPooling_updateGradInput)(
THCIndexTensor *indices,
int kW, int dW);
+TH_API void THNN_(TemporalRowConvolution_updateOutput)(
+ THCState *state,
+ THCTensor *input,
+ THCTensor *output,
+ THCTensor *weight,
+ THCTensor *bias, // [OPTIONAL]
+ THCTensor *finput,
+ THCTensor *fgradInput,
+ int kW,
+ int dW,
+ int padW,
+ bool featFirst);
+
+TH_API void THNN_(TemporalRowConvolution_updateGradInput)(
+ THCState *state,
+ THCTensor *input,
+ THCTensor *gradOutput,
+ THCTensor *gradInput,
+ THCTensor *weight,
+ THCTensor *finput,
+ THCTensor *fgradInput,
+ int kW,
+ int dW,
+ int padW,
+ bool featFirst);
+
+TH_API void THNN_(TemporalRowConvolution_accGradParameters)(
+ THCState *state,
+ THCTensor *input,
+ THCTensor *gradOutput,
+ THCTensor *gradWeight,
+ THCTensor *gradBias,
+ THCTensor *finput,
+ THCTensor *fgradInput,
+ int kW,
+ int dW,
+ int padW,
+ bool featFirst,
+ real scale);
+
TH_API void THNN_(Threshold_updateOutput)(
THCState *state,
THCTensor *input,
diff --git a/lib/THCUNN/generic/TemporalRowConvolution.cu b/lib/THCUNN/generic/TemporalRowConvolution.cu
new file mode 100644
index 0000000..365599d
--- /dev/null
+++ b/lib/THCUNN/generic/TemporalRowConvolution.cu
@@ -0,0 +1,423 @@
+#ifndef THC_GENERIC_FILE
+#define THC_GENERIC_FILE "generic/TemporalRowConvolution.cu"
+#else
+
+static inline void THNN_(TemporalRowConvolution_shapeCheck)(
+ THCState *state, THCTensor *input, THCTensor *gradOutput, THCTensor *weight,
+ THCTensor *bias, int kW, int dW, int padW) {
+
+ THArgCheck(kW > 0, 5,
+ "kernel size should be greater than zero, but got kW: %d", kW);
+ THArgCheck(dW > 0, 6, "stride should be greater than zero, but got dW: %d",
+ dW);
+ THCUNN_argCheck(state, weight->nDimension == 2 || weight->nDimension == 3, 3,
+ weight, "2D or 3D weight tensor expected, but got: %s");
+
+ if (bias != NULL) {
+ THCUNN_check_dim_size(state, bias, 1, 0, weight->size[0]);
+ }
+
+ int ndim = input->nDimension;
+ int dimF = 0; // feature dimension
+ int dimS = 1; // sequence dimension
+
+ if (ndim == 3) {
+ ++dimF;
+ ++dimS;
+ }
+
+  THCUNN_argCheck(state, ndim == 2 || ndim == 3, 1, input,
+                  "2D or 3D (batch mode) input tensor expected, but got: %s");
+
+ long inputFrameSize = weight->size[0];
+ long nInputFrame = input->size[dimS];
+ long nOutputFrame = (nInputFrame + 2 * padW - kW) / dW + 1;
+
+ if (nOutputFrame < 1) {
+ THError("Given input size: (%d x %d). "
+ "Calculated output size: (%d x %d). Output size is too small",
+ inputFrameSize, nInputFrame, inputFrameSize, nOutputFrame);
+ }
+
+ THCUNN_check_dim_size(state, input, ndim, dimF, inputFrameSize);
+
+ if (gradOutput != NULL) {
+ THCUNN_check_dim_size(state, gradOutput, ndim, dimF, inputFrameSize);
+ THCUNN_check_dim_size(state, gradOutput, ndim, dimS, nOutputFrame);
+ }
+}
+
+void THNN_(TemporalRowConvolution_updateOutput)(
+ THCState *state, THCTensor *input, THCTensor *output, THCTensor *weight,
+ THCTensor *bias, THCTensor *finput, THCTensor *fgradInput, int kW, int dW,
+ int padW, bool featFirst) {
+
+ // aliases
+ THCTensor *columns = finput;
+ THCTensor *ones = fgradInput;
+
+ // assert same GPU
+ THCUNN_assertSameGPU(state, 5, input, output, weight, columns, ones);
+ if (bias != NULL) {
+ THCUNN_assertSameGPU(state, 2, weight, bias);
+ }
+
+ // reshape weight if necessary
+ int ndim = input->nDimension;
+
+ THCTensor *tinput;
+
+ if (!featFirst) {
+ tinput = THCTensor_(newTranspose)(state, input, ndim - 1, ndim - 2);
+ input = THCTensor_(newContiguous)(state, tinput);
+ } else {
+ input = THCTensor_(newContiguous)(state, input);
+ }
+
+ THNN_(TemporalRowConvolution_shapeCheck)
+ (state, input, NULL, weight, bias, kW, dW, padW);
+
+ int batch = 1;
+ if (ndim == 2) {
+ // Force batch
+ batch = 0;
+ THCTensor_(resize3d)(state, input, 1, input->size[0], input->size[1]);
+ }
+
+ // Params:
+ long inputFrameSize = weight->size[0];
+ long nInputFrame = input->size[2];
+ long nOutputFrame = (nInputFrame + 2 * padW - kW) / dW + 1;
+
+ // Batch size
+ long batchSize = input->size[0];
+
+ // Resize output
+ THCTensor_(resize3d)(state, output, batchSize, inputFrameSize, nOutputFrame);
+
+ // Augment the input
+ THCTensor_(resize3d)(state, columns, inputFrameSize, kW, nOutputFrame);
+
+ // Define a buffer of ones, for bias accumulation
+ // Note: this buffer can be shared with other modules, it only ever
+ // gets increased and always contains ones.
+ if (ones->nDimension != 2 || ones->size[0] * ones->size[1] < nOutputFrame) {
+ // Resize plane and fill with ones...
+ THCTensor_(resize2d)(state, ones, 1, nOutputFrame);
+ THCTensor_(fill)(state, ones, ScalarConvert<int, real>::to(1));
+ }
+
+ // Helpers
+ THCTensor *input_n = THCTensor_(new)(state);
+ THCTensor *output_n = THCTensor_(new)(state);
+
+ // For each elt in batch, do:
+ for (int elt = 0; elt < batchSize; ++elt) {
+ // Matrix multiply per output:
+ THCTensor_(select)(state, input_n, input, 0, elt);
+ THCTensor_(select)(state, output_n, output, 0, elt);
+
+ // Do bias first:
+ // m_, n_, k_ are dims of matrix A and B
+ // (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm)
+ long m_ = inputFrameSize;
+ long n_ = nOutputFrame;
+ long k_ = 1;
+
+    // Do GEMM (note: this is a bit confusing because gemm assumes
+ // column-major matrices)
+ if (bias != NULL) {
+#ifdef THC_REAL_IS_FLOAT
+ THCudaBlas_Sgemm(
+#elif defined(THC_REAL_IS_HALF)
+ THCudaBlas_Hgemm(
+#elif defined(THC_REAL_IS_DOUBLE)
+ THCudaBlas_Dgemm(
+#endif
+ state, 't', 'n', n_, m_, k_, ScalarConvert<int, real>::to(1),
+ THCTensor_(data)(state, ones), k_, THCTensor_(data)(state, bias), k_,
+ ScalarConvert<int, real>::to(0), THCTensor_(data)(state, output_n),
+ n_);
+ } else {
+ THCTensor_(zero)(state, output_n);
+ }
+
+ // Extract columns:
+ row2col(THCState_getCurrentStream(state), THCTensor_(data)(state, input_n),
+ inputFrameSize, nInputFrame, kW, padW, dW, 1,
+ THCTensor_(data)(state, columns));
+
+ THCTensor *output3d = THCTensor_(newWithStorage3d)(
+ state, output_n->storage, output_n->storageOffset, inputFrameSize, -1,
+ 1, -1, nOutputFrame, -1);
+
+ // weight: inputFrameSize x 1 x kW
+ // columns: inputFrameSize x kW x nOutputFrame
+ THCTensor_(baddbmm)(state, output3d, ScalarConvert<int, real>::to(1),
+ output3d, ScalarConvert<int, real>::to(1), weight,
+ columns);
+ // output3d: inputFrameSize x 1 x nOutputFrame
+
+ THCTensor_(free)(state, output3d);
+ }
+
+ // Free
+ THCTensor_(free)(state, input_n);
+ THCTensor_(free)(state, output_n);
+
+ // Resize output
+ if (batch == 0) {
+ THCTensor_(resize2d)(state, output, inputFrameSize, nOutputFrame);
+ THCTensor_(resize2d)(state, input, inputFrameSize, nInputFrame);
+ }
+
+ if (!featFirst) {
+ THCTensor_(transpose)(state, output, output, ndim - 1, ndim - 2);
+ THCTensor_(free)(state, tinput);
+ }
+
+ THCTensor_(free)(state, input);
+}
+
+void THNN_(TemporalRowConvolution_updateGradInput)(
+ THCState *state, THCTensor *input, THCTensor *gradOutput,
+ THCTensor *gradInput, THCTensor *weight, THCTensor *finput,
+ THCTensor *fgradInput, int kW, int dW, int padW, bool featFirst) {
+
+ // aliases
+ THCTensor *gradColumns = finput;
+
+ THCUNN_assertSameGPU(state, 5, input, gradOutput, weight, gradColumns,
+ gradInput);
+
+ int ndim = input->nDimension;
+
+ THCTensor *tinput, *tgradOutput;
+
+ if (!featFirst) {
+ tinput = THCTensor_(newTranspose)(state, input, ndim - 1, ndim - 2);
+ tgradOutput =
+ THCTensor_(newTranspose)(state, gradOutput, ndim - 1, ndim - 2);
+ input = THCTensor_(newContiguous)(state, tinput);
+ gradOutput = THCTensor_(newContiguous)(state, tgradOutput);
+
+ } else {
+ input = THCTensor_(newContiguous)(state, input);
+ gradOutput = THCTensor_(newContiguous)(state, gradOutput);
+ }
+
+ THNN_(TemporalRowConvolution_shapeCheck)
+ (state, input, gradOutput, weight, NULL, kW, dW, padW);
+
+ int batch = 1;
+ if (ndim == 2) {
+ // Force batch
+ batch = 0;
+ THCTensor_(resize3d)(state, input, 1, input->size[0], input->size[1]);
+ THCTensor_(resize3d)(state, gradOutput, 1, gradOutput->size[0],
+ gradOutput->size[1]);
+ }
+
+ // Params:
+ long inputFrameSize = weight->size[0];
+ long nInputFrame = input->size[2];
+ long nOutputFrame = gradOutput->size[2];
+
+ // Batch size
+ long batchSize = input->size[0];
+
+ // Resize output
+ THCTensor_(resize3d)(state, gradInput, batchSize, inputFrameSize,
+ nInputFrame);
+
+ // Resize temporary columns
+ THCTensor_(resize3d)(state, gradColumns, inputFrameSize, kW, nOutputFrame);
+
+ // Helpers
+ THCTensor *gradInput_n = THCTensor_(new)(state);
+ THCTensor *gradOutput_n = THCTensor_(new)(state);
+
+ THCTensor_(transpose)(state, weight, weight, 1, 2);
+
+ for (int elt = 0; elt < batchSize; ++elt) {
+ // Matrix multiply per sample:
+ THCTensor_(select)(state, gradInput_n, gradInput, 0, elt);
+ THCTensor_(select)(state, gradOutput_n, gradOutput, 0, elt);
+
+ THCTensor *gradOutput3d = THCTensor_(newWithStorage3d)(
+ state, gradOutput_n->storage, gradOutput_n->storageOffset,
+ inputFrameSize, -1, 1, -1, nOutputFrame, -1);
+
+ // weight: inputFrameSize x kW x 1
+ // gradOutput3d: inputFrameSize x 1 x nOutputFrame
+ THCTensor_(baddbmm)(state, gradColumns, ScalarConvert<int, real>::to(0),
+ gradColumns, ScalarConvert<int, real>::to(1), weight,
+ gradOutput3d);
+ // gradColumns: inputFrameSize x kW x nOutputFrame
+
+ // Unpack columns back into input:
+ col2row<real, accreal>(THCState_getCurrentStream(state),
+ THCTensor_(data)(state, gradColumns), inputFrameSize,
+ nInputFrame, kW, padW, dW, 1,
+ THCTensor_(data)(state, gradInput_n));
+
+ THCTensor_(free)(state, gradOutput3d);
+ }
+
+ // Free
+ THCTensor_(free)(state, gradInput_n);
+ THCTensor_(free)(state, gradOutput_n);
+
+ // Resize output
+ if (batch == 0) {
+ THCTensor_(resize2d)(state, gradOutput, inputFrameSize, nOutputFrame);
+ THCTensor_(resize2d)(state, input, inputFrameSize, nInputFrame);
+ THCTensor_(resize2d)(state, gradInput, inputFrameSize, nInputFrame);
+ }
+
+ THCTensor_(transpose)(state, weight, weight, 1, 2);
+
+ if (!featFirst) {
+ THCTensor_(transpose)(state, gradInput, gradInput, ndim - 1, ndim - 2);
+ THCTensor_(free)(state, tinput);
+ THCTensor_(free)(state, tgradOutput);
+ }
+
+ THCTensor_(free)(state, input);
+ THCTensor_(free)(state, gradOutput);
+}
+
+void THNN_(TemporalRowConvolution_accGradParameters)(
+ THCState *state, THCTensor *input, THCTensor *gradOutput,
+ THCTensor *gradWeight, THCTensor *gradBias, THCTensor *finput,
+ THCTensor *fgradInput, int kW, int dW, int padW, bool featFirst,
+ real scale) {
+
+ // Aliases
+ THCTensor *columns = finput;
+ THCTensor *ones = fgradInput;
+
+ THCUNN_assertSameGPU(state, 5, input, gradOutput, gradWeight, columns, ones);
+ if (gradBias != NULL) {
+ THCUNN_assertSameGPU(state, 2, gradWeight, gradBias);
+ }
+
+ int ndim = input->nDimension;
+
+ THCTensor *tinput, *tgradOutput;
+
+ if (!featFirst) {
+ tinput = THCTensor_(newTranspose)(state, input, ndim - 1, ndim - 2);
+ tgradOutput =
+ THCTensor_(newTranspose)(state, gradOutput, ndim - 1, ndim - 2);
+ input = THCTensor_(newContiguous)(state, tinput);
+ gradOutput = THCTensor_(newContiguous)(state, tgradOutput);
+ } else {
+ input = THCTensor_(newContiguous)(state, input);
+ gradOutput = THCTensor_(newContiguous)(state, gradOutput);
+ }
+
+ THNN_(TemporalRowConvolution_shapeCheck)
+ (state, input, gradOutput, gradWeight, gradBias, kW, dW, padW);
+
+ int batch = 1;
+ if (ndim == 2) {
+ // Force batch
+ batch = 0;
+ THCTensor_(resize3d)(state, input, 1, input->size[0], input->size[1]);
+ THCTensor_(resize3d)(state, gradOutput, 1, gradOutput->size[0],
+ gradOutput->size[1]);
+ }
+
+ // Params:
+ long inputFrameSize = gradWeight->size[0];
+ long nInputFrame = input->size[2];
+ long nOutputFrame = gradOutput->size[2];
+
+ // Batch size
+ long batchSize = input->size[0];
+
+ // Define a buffer of ones, for bias accumulation
+ if (ones->nDimension != 2 || ones->size[0] * ones->size[1] < nOutputFrame) {
+ // Resize plane and fill with ones...
+ THCTensor_(resize2d)(state, ones, 1, nOutputFrame);
+ THCTensor_(fill)(state, ones, ScalarConvert<int, real>::to(1));
+ }
+
+  // Resize temporary columns
+ THCTensor_(resize3d)(state, columns, inputFrameSize, kW, nOutputFrame);
+
+ // Helpers
+ THCTensor *input_n = THCTensor_(new)(state);
+ THCTensor *gradOutput_n = THCTensor_(new)(state);
+
+ // For each elt in batch, do:
+ for (int elt = 0; elt < batchSize; ++elt) {
+ // Matrix multiply per output
+ THCTensor_(select)(state, input_n, input, 0, elt);
+ THCTensor_(select)(state, gradOutput_n, gradOutput, 0, elt);
+
+ THCTensor *gradOutput3d = THCTensor_(newWithStorage3d)(
+ state, gradOutput_n->storage, gradOutput_n->storageOffset,
+ inputFrameSize, -1, 1, -1, nOutputFrame, -1);
+
+ // Extract columns
+ row2col(THCState_getCurrentStream(state), THCTensor_(data)(state, input_n),
+ inputFrameSize, nInputFrame, kW, padW, dW, 1,
+ THCTensor_(data)(state, columns));
+
+ THCTensor_(transpose)(state, columns, columns, 1, 2);
+
+ // gradOutput3d: inputFrameSize x 1 x nOutputFrame
+ // columns: inputFrameSize x nOutputFrame x kW
+ THCTensor_(baddbmm)(state, gradWeight, ScalarConvert<int, real>::to(1),
+ gradWeight, scale, gradOutput3d, columns);
+ // gradWeight: inputFrameSize x 1 x kW
+
+ THCTensor_(transpose)(state, columns, columns, 1, 2);
+
+ THCTensor_(free)(state, gradOutput3d);
+
+ if (gradBias != NULL) {
+ long m_ = inputFrameSize;
+ long k_ = nOutputFrame;
+#if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE)
+#ifdef THC_REAL_IS_FLOAT
+ THCudaBlas_Sgemv(
+#elif defined(THC_REAL_IS_DOUBLE)
+ THCudaBlas_Dgemv(
+#endif
+ state, 't', k_, m_, scale, THCTensor_(data)(state, gradOutput_n), k_,
+ THCTensor_(data)(state, ones), 1, ScalarConvert<int, real>::to(1),
+ THCTensor_(data)(state, gradBias), 1);
+#endif
+#ifdef THC_REAL_IS_HALF // half not supported due to baddbmm
+ THCudaBlas_Hgemm(state, 't', 'n', m_, 1, k_, scale,
+ THCTensor_(data)(state, gradOutput_n), k_,
+ THCTensor_(data)(state, ones), k_,
+ ScalarConvert<int, real>::to(1),
+ THCTensor_(data)(state, gradBias), m_);
+#endif
+ }
+ }
+
+ // Free
+ THCTensor_(free)(state, input_n);
+ THCTensor_(free)(state, gradOutput_n);
+
+ // Resize
+ if (batch == 0) {
+ THCTensor_(resize2d)(state, gradOutput, inputFrameSize, nOutputFrame);
+ THCTensor_(resize2d)(state, input, inputFrameSize, nInputFrame);
+ }
+
+ if (!featFirst) {
+ THCTensor_(free)(state, tinput);
+ THCTensor_(free)(state, tgradOutput);
+ }
+
+ THCTensor_(free)(state, input);
+ THCTensor_(free)(state, gradOutput);
+}
+
+#endif
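
For reference, a naive CPU sketch of what updateOutput computes for the feature-first
layout. The weight shape (nFeature x 1 x kW) and the zero padding outside the sequence
are inferred from the baddbmm and row2col calls above, so treat this as an illustration
of the row convolution rather than the module's exact contract.

-- output[f][t] = bias[f] + sum_j weight[f][1][j] * x[f][(t-1)*dW + j - padW]
local function rowConvReference(x, weight, bias, kW, dW, padW)
  local nFeature, nInputFrame = x:size(1), x:size(2)
  local nOutputFrame = math.floor((nInputFrame + 2 * padW - kW) / dW) + 1
  local out = torch.zeros(nFeature, nOutputFrame)
  for f = 1, nFeature do
    for t = 1, nOutputFrame do
      local acc = bias and bias[f] or 0
      for j = 1, kW do
        local s = (t - 1) * dW + j - padW        -- 1-based source frame index
        if s >= 1 and s <= nInputFrame then
          acc = acc + weight[f][1][j] * x[f][s]  -- frames in the padding contribute 0
        end
      end
      out[f][t] = acc
    end
  end
  return out
end

The cunn path arrives at the same result by unfolding the input with row2col and then
doing one small (1 x kW) x (kW x nOutputFrame) matrix multiply per feature via baddbmm.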
diff --git a/lib/THCUNN/row2col.h b/lib/THCUNN/row2col.h
new file mode 100644
index 0000000..04765dd
--- /dev/null
+++ b/lib/THCUNN/row2col.h
@@ -0,0 +1,90 @@
+#ifndef THCUNN_ROW2COL_H
+#define THCUNN_ROW2COL_H
+
+#include "THCNumerics.cuh"
+#include "common.h"
+
+// Kernel for fast unfold+copy on rows
+template <typename Dtype>
+__global__ void
+row2col_kernel(const int n, const Dtype *data_row, const int width,
+ const int ksize_w, const int pad_w, const int stride_w,
+ const int dilation_w, const int width_col, Dtype *data_col) {
+ CUDA_KERNEL_LOOP(index, n) {
+ int w_out = index % width_col;
+ index /= width_col;
+ int channel_in = index;
+ int channel_out = channel_in * ksize_w;
+ int w_in = w_out * stride_w - pad_w;
+ data_col += (channel_out)*width_col + w_out;
+ data_row += (channel_in)*width + w_in;
+ for (int j = 0; j < ksize_w; ++j) {
+ int w = w_in + j * dilation_w;
+ *data_col = (w >= 0 && w < width) ? data_row[j * dilation_w]
+ : ScalarConvert<int, Dtype>::to(0);
+ data_col += width_col;
+ }
+ }
+}
+
+template <typename Dtype>
+void row2col(cudaStream_t stream, const Dtype *data_row, const int channels,
+ const int width, const int ksize_w, const int pad_w,
+ const int stride_w, const int dilation_w, Dtype *data_col) {
+ // We are going to launch channels * width_col kernels, each
+ // kernel responsible for copying a single-channel grid.
+ int width_col =
+ (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1;
+ int num_kernels = channels * width_col;
+ // Launch
+ row2col_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS, 0, stream>>>(
+ num_kernels, data_row, width, ksize_w, pad_w, stride_w, 1, width_col,
+ data_col);
+ THCudaCheck(cudaGetLastError());
+}
+
+template <typename Dtype, typename Acctype>
+__global__ void col2row_kernel(const int n, const Dtype *data_col,
+ const int width, const int channels,
+ const int kernel_w, const int pad_w,
+ const int stride_w, const int dilation_w,
+ const int width_col, Dtype *data_row) {
+ CUDA_KERNEL_LOOP(index, n) {
+ Acctype val = Acctype(0);
+ const int w_row = index % width + pad_w;
+ const int c_row = index / width;
+ int kernel_extent_w = (kernel_w - 1) * dilation_w + 1;
+ // compute the start and end of the output
+ const int w_col_start = (w_row < kernel_extent_w)
+ ? 0
+ : (w_row - kernel_extent_w) / stride_w + 1;
+ const int w_col_end = min(w_row / stride_w + 1, width_col);
+ for (int w_col = w_col_start; w_col < w_col_end; w_col += 1) {
+ int w_k = (w_row - w_col * stride_w);
+ if (w_k % dilation_w == 0) {
+ w_k /= dilation_w;
+ int data_col_index = (c_row * kernel_w + w_k) * width_col + w_col;
+ val += data_col[data_col_index];
+ }
+ }
+ data_row[index] = ScalarConvert<Acctype, Dtype>::to(val);
+ }
+}
+
+template <typename Dtype, typename Acctype>
+void col2row(cudaStream_t stream, const Dtype *data_col, const int channels,
+ const int width, const int patch_w, const int pad_w,
+ const int stride_w, const int dilation_w, Dtype *data_row) {
+ int width_col =
+ (width + 2 * pad_w - (dilation_w * (patch_w - 1) + 1)) / stride_w + 1;
+ int num_kernels = channels * width;
+ // To avoid involving atomic operations, we will launch one kernel per
+ // bottom dimension, and then in the kernel add up the top dimensions.
+ col2row_kernel<
+ Dtype, Acctype><<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS, 0, stream>>>(
+ num_kernels, data_col, width, channels, patch_w, pad_w, stride_w,
+ dilation_w, width_col, data_row);
+
+ THCudaCheck(cudaGetLastError());
+}
+#endif
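
For reference, a CPU sketch of the buffer layout row2col produces (dilation fixed to 1,
which is how both launch helpers above call the kernels). Viewed as a 2D matrix of shape
(nFeature * kW) x nOutputFrame, row (f-1)*kW + j of the column buffer holds the j-th tap
of feature row f; col2row is the adjoint and accumulates each entry back onto the input
frame it was read from.

local function row2colReference(x, kW, dW, padW)
  local nFeature, nInputFrame = x:size(1), x:size(2)
  local nOutputFrame = math.floor((nInputFrame + 2 * padW - kW) / dW) + 1
  local columns = torch.zeros(nFeature * kW, nOutputFrame)
  for f = 1, nFeature do
    for j = 1, kW do
      for t = 1, nOutputFrame do
        local s = (t - 1) * dW + j - padW        -- 1-based input frame for this tap
        if s >= 1 and s <= nInputFrame then      -- frames in the padding stay 0
          columns[(f - 1) * kW + j][t] = x[f][s]
        end
      end
    end
  end
  return columns
end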
diff --git a/test.lua b/test.lua
index c3ed9bb..14d072d 100644
--- a/test.lua
+++ b/test.lua
@@ -3661,6 +3661,275 @@ function cunntest.TemporalConvolution_backward_batch()
end
end
+
+function cunntest.TemporalRowConvolution_forward_single()
+ local from = math.random(1,64) -- nFeature
+ local to = from
+ local ki = math.random(3,15) -- kW
+ local si = math.random(1,2) -- dW
+ local outi = math.random(1,256) -- nOutputFrame
+ local ini = (outi-1)*si+ki -- nInputFrame
+
+ local function jacTest(noBias, featFirst)
+ noBias = noBias or false
+ featFirst = featFirst or false
+
+ for k, typename in ipairs(typenames) do
+ if typename ~= "torch.CudaHalfTensor" then
+
+ local input
+ if featFirst then
+ input = torch.randn(from, ini):type(typename)
+ else
+ input = torch.randn(ini, from):type(typename)
+ end
+
+ local ctype = t2cpu[typename]
+ input = makeNonContiguous(input:type(ctype))
+ local mod = nn.TemporalRowConvolution(from,ki,si):type(ctype)
+ if featFirst then
+ mod.featFirst = true
+ end
+ if noBias then
+ mod:noBias()
+ end
+ local groundtruth = mod:forward(input)
+
+ input = makeNonContiguous(input:type(typename))
+ local cmod = nn.TemporalRowConvolution(from,ki,si):type(typename)
+
+ if featFirst then
+ cmod.featFirst = true
+ end
+ if noBias then
+ cmod:noBias()
+ end
+ cmod.weight = mod.weight:type(typename)
+ if mod.bias then cmod.bias = mod.bias:type(typename) end
+ local rescuda = cmod:forward(input)
+
+ local error = rescuda:double() - groundtruth:double()
+ mytester:assertlt(error:abs():max(), precision_forward_type(precision_forward, typename),
+ string.format('error on state (forward) with %s', typename))
+ end
+ end
+ end
+ jacTest(false,false)
+ jacTest(false,true)
+ jacTest(true,false)
+ jacTest(true,true)
+end
+
+function cunntest.TemporalRowConvolution_forward_batch()
+ local bs = math.random(4,16)
+ local from = math.random(1,64)
+ local to = from
+ local ki = math.random(3,15)
+ local si = math.random(1,2)
+ local outi = math.random(1,256)
+ local ini = (outi-1)*si+ki
+
+ local function jacTest(noBias,featFirst)
+ noBias = noBias or false
+ featFirst = featFirst or false
+ for k, typename in ipairs(typenames) do
+ if typename ~= "torch.CudaHalfTensor" then
+
+ local input
+ if featFirst then
+ input = torch.randn(bs, from, ini):type(typename)
+ else
+ input = torch.randn(bs, ini, from):type(typename)
+ end
+
+ local ctype = t2cpu[typename]
+ input = makeNonContiguous(input:type(ctype))
+ local mod = nn.TemporalRowConvolution(from,ki,si):type(ctype)
+ if featFirst then
+ mod.featFirst = true
+ end
+ if noBias then
+ mod:noBias()
+ end
+ local groundtruth = mod:forward(input)
+
+ input = makeNonContiguous(input:type(typename))
+ local cmod = nn.TemporalRowConvolution(from,ki,si):type(typename)
+ if featFirst then
+ cmod.featFirst = true
+ end
+ if noBias then
+ cmod:noBias()
+ end
+ cmod.weight = mod.weight:type(typename)
+ if mod.bias then
+ cmod.bias = mod.bias:type(typename)
+ end
+ local rescuda = cmod:forward(input)
+
+ local error = rescuda:double() - groundtruth:double()
+ mytester:assertlt(error:abs():max(), precision_forward_type(precision_forward, typename),
+ string.format('error on state (forward) with %s', typename))
+ end
+ end
+ end
+ jacTest(false,false)
+ jacTest(false,true)
+ jacTest(true,false)
+ jacTest(true,true)
+end
+
+function cunntest.TemporalRowConvolution_backward_single()
+ local from = math.random(1,64) -- nFeature
+ local to = from
+ local ki = math.random(3,15) -- kW
+ local si = math.random(1,2) -- dW
+ local outi = math.random(1,256) -- nOutputFrame
+ local ini = (outi-1)*si+ki -- nInputFrame
+
+ local function jacTest(noBias,featFirst)
+ noBias = noBias or false
+ featFirst = featFirst or false
+ for k, typename in ipairs(typenames) do
+ if typename ~= "torch.CudaHalfTensor" then
+
+ local input, gradOutput
+ if featFirst then
+ input = torch.randn(from, ini):type(typename)
+ gradOutput = torch.randn(to, outi):type(typename)
+ else
+ input = torch.randn(ini, from):type(typename)
+ gradOutput = torch.rand(outi, to):type(typename)
+ end
+
+ local ctype = t2cpu[typename]
+ input = makeNonContiguous(input:type(ctype))
+ gradOutput = makeNonContiguous(gradOutput:type(ctype))
+ local mod = nn.TemporalRowConvolution(from,ki,si):type(ctype)
+ if featFirst then mod.featFirst = true end
+ if noBias then mod:noBias() end
+ mod:forward(input)
+ mod:zeroGradParameters()
+ local groundgrad = mod:backward(input, gradOutput)
+ local groundweight = mod.gradWeight
+ local groundbias = mod.gradBias
+
+ input = makeNonContiguous(input:type(typename))
+ gradOutput = makeNonContiguous(gradOutput:type(typename))
+ local cmod = nn.TemporalRowConvolution(from,ki,si):type(typename)
+ if featFirst then cmod.featFirst = true end
+ if noBias then cmod:noBias() end
+ cmod.weight = mod.weight:type(typename)
+ if cmod.bias then cmod.bias = mod.bias:type(typename) end
+ cmod:forward(input)
+ cmod:zeroGradParameters()
+ local rescuda = cmod:backward(input, gradOutput)
+ local weightcuda = cmod.gradWeight
+
+ local error = rescuda:double() - groundgrad:double()
+ local werror = weightcuda:double() - groundweight:double()
+
+ mytester:assertlt(error:abs():max(), precision_backward_type(precision_backward, typename),
+ string.format('error on state (backward) with %s', typename))
+ mytester:assertlt(werror:abs():max(),
+ precision_backward_conv_weightbias(precision_backward, typename, weightcuda:abs():max()),
+ string.format('error on weight (backward) with %s', typename))
+
+ if cmod.bias then
+ local berror = cmod.gradBias:double() - groundbias:double()
+ mytester:assertlt(berror:abs():max(),
+ precision_backward_conv_weightbias(precision_backward, typename, cmod.gradBias:abs():max()),
+ string.format('error on bias (backward) with %s', typename))
+ end
+ end
+ end
+ end
+ jacTest(false,false)
+ jacTest(false,true)
+ jacTest(true,false)
+ jacTest(true,true)
+end
+
+function cunntest.TemporalRowConvolution_backward_batch()
+ local bs = math.random(4,16)
+ local from = math.random(1,64) -- nFeature
+ local to = from
+ local ki = math.random(3,15) -- kW
+ local si = math.random(1,2) -- dW
+ local outi = math.random(1,256) -- nOutputFrame
+ local ini = (outi-1)*si+ki -- nInputFrame
+
+ local function jacTest(noBias,featFirst)
+ for k, typename in ipairs(typenames) do
+ if typename ~= "torch.CudaHalfTensor" then
+
+ local input, gradOutput
+ if featFirst then
+ input = torch.randn(bs, from, ini):type(typename)
+ gradOutput = torch.randn(bs, to, outi):type(typename)
+ else
+ input = torch.randn(bs, ini, from):type(typename)
+ gradOutput = torch.rand(bs, outi, to):type(typename)
+ end
+
+ local ctype = t2cpu[typename]
+ input = makeNonContiguous(input:type(ctype))
+ gradOutput = makeNonContiguous(gradOutput:type(ctype))
+ local mod = nn.TemporalRowConvolution(from,ki,si):type(ctype)
+ if featFirst then
+ mod.featFirst = true
+ end
+ if noBias then
+ mod:noBias()
+ end
+ mod:forward(input)
+ mod:zeroGradParameters()
+ local groundgrad = mod:backward(input, gradOutput)
+ local groundweight = mod.gradWeight
+ local groundbias = mod.gradBias
+
+ input = makeNonContiguous(input:type(typename))
+ gradOutput = makeNonContiguous(gradOutput:type(typename))
+ local cmod = nn.TemporalRowConvolution(from,ki,si):type(typename)
+ if featFirst then
+ cmod.featFirst = true
+ end
+ if noBias then
+ cmod:noBias()
+ end
+ cmod.weight = mod.weight:type(typename)
+ if cmod.bias then
+ cmod.bias = mod.bias:type(typename)
+ end
+ cmod:forward(input)
+ cmod:zeroGradParameters()
+ local rescuda = cmod:backward(input, gradOutput)
+ local weightcuda = cmod.gradWeight
+
+ local error = rescuda:double() - groundgrad:double()
+ local werror = weightcuda:double() - groundweight:double()
+
+ mytester:assertlt(error:abs():max(), precision_backward_type(precision_backward, typename),
+ string.format('error on state (backward) [batch] with %s', typename))
+ mytester:assertlt(werror:abs():max(),
+ precision_backward_conv_weightbias(precision_backward, typename, weightcuda:abs():max()),
+ string.format('error on weight (backward) [batch] with %s', typename))
+
+ if cmod.bias then
+ local berror = cmod.gradBias:double() - groundbias:double()
+ mytester:assertlt(berror:abs():max(),
+ precision_backward_conv_weightbias(precision_backward, typename, cmod.gradBias:abs():max()),
+ string.format('error on bias (backward) [batch] with %s', typename))
+ end
+ end
+ end
+ end
+ jacTest(false,false)
+ jacTest(false,true)
+ jacTest(true,false)
+ jacTest(true,true)
+end
+
function cunntest.Dropout()
local p = 0.2 --prob of droping out a neuron
local input = makeNonContiguous(torch.CudaTensor(1000):fill((1-p)))