diff options
| field | value | date |
|---|---|---|
| author | Will Frey <will.frey@digitalreasoning.com> | 2017-01-27 21:30:25 +0300 |
| committer | Soumith Chintala <soumith@gmail.com> | 2017-01-27 21:30:25 +0300 |
| commit | 37db7b80735e0a6d74c0f7a9b6fc72b1df5ccd38 (patch) | |
| tree | b6b2be2dd6761e3b67582921c5066dcaf8b3b655 | |
| parent | 5fa193a84ca8fe112bf6e75487ac96eb1b1239d2 (diff) | |
Added cunn support for TemporalRowConvolutionMM (#415)
* Added cunn TemporalRowConvolutionMM support
| mode | file | lines changed |
|---|---|---|
| -rw-r--r-- | lib/THCUNN/TemporalRowConvolution.cu | 10 |
| -rw-r--r-- | lib/THCUNN/generic/THCUNN.h | 40 |
| -rw-r--r-- | lib/THCUNN/generic/TemporalRowConvolution.cu | 423 |
| -rw-r--r-- | lib/THCUNN/row2col.h | 90 |
| -rw-r--r-- | test.lua | 269 |
5 files changed, 832 insertions, 0 deletions
diff --git a/lib/THCUNN/TemporalRowConvolution.cu b/lib/THCUNN/TemporalRowConvolution.cu new file mode 100644 index 0000000..dc3b18c --- /dev/null +++ b/lib/THCUNN/TemporalRowConvolution.cu @@ -0,0 +1,10 @@ +#include "THCUNN.h" +#include "common.h" +#include "row2col.h" + +#include "THCHalf.h" +#include "THCHalfAutoNumerics.cuh" + +#include "generic/TemporalRowConvolution.cu" + +#include "THCGenerateFloatTypes.h" diff --git a/lib/THCUNN/generic/THCUNN.h b/lib/THCUNN/generic/THCUNN.h index bf903b9..ec3d287 100644 --- a/lib/THCUNN/generic/THCUNN.h +++ b/lib/THCUNN/generic/THCUNN.h @@ -933,6 +933,46 @@ TH_API void THNN_(TemporalMaxPooling_updateGradInput)( THCIndexTensor *indices, int kW, int dW); +TH_API void THNN_(TemporalRowConvolution_updateOutput)( + THCState *state, + THCTensor *input, + THCTensor *output, + THCTensor *weight, + THCTensor *bias, // [OPTIONAL] + THCTensor *finput, + THCTensor *fgradInput, + int kW, + int dW, + int padW, + bool featFirst); + +TH_API void THNN_(TemporalRowConvolution_updateGradInput)( + THCState *state, + THCTensor *input, + THCTensor *gradOutput, + THCTensor *gradInput, + THCTensor *weight, + THCTensor *finput, + THCTensor *fgradInput, + int kW, + int dW, + int padW, + bool featFirst); + +TH_API void THNN_(TemporalRowConvolution_accGradParameters)( + THCState *state, + THCTensor *input, + THCTensor *gradOutput, + THCTensor *gradWeight, + THCTensor *gradBias, + THCTensor *finput, + THCTensor *fgradInput, + int kW, + int dW, + int padW, + bool featFirst, + real scale); + TH_API void THNN_(Threshold_updateOutput)( THCState *state, THCTensor *input, diff --git a/lib/THCUNN/generic/TemporalRowConvolution.cu b/lib/THCUNN/generic/TemporalRowConvolution.cu new file mode 100644 index 0000000..365599d --- /dev/null +++ b/lib/THCUNN/generic/TemporalRowConvolution.cu @@ -0,0 +1,423 @@ +#ifndef THC_GENERIC_FILE +#define THC_GENERIC_FILE "generic/TemporalRowConvolution.cu" +#else + +static inline void THNN_(TemporalRowConvolution_shapeCheck)( + 
THCState *state, THCTensor *input, THCTensor *gradOutput, THCTensor *weight, + THCTensor *bias, int kW, int dW, int padW) { + + THArgCheck(kW > 0, 5, + "kernel size should be greater than zero, but got kW: %d", kW); + THArgCheck(dW > 0, 6, "stride should be greater than zero, but got dW: %d", + dW); + THCUNN_argCheck(state, weight->nDimension == 2 || weight->nDimension == 3, 3, + weight, "2D or 3D weight tensor expected, but got: %s"); + + if (bias != NULL) { + THCUNN_check_dim_size(state, bias, 1, 0, weight->size[0]); + } + + int ndim = input->nDimension; + int dimF = 0; // feature dimension + int dimS = 1; // sequence dimension + + if (ndim == 3) { + ++dimF; + ++dimS; + } + + THCUNN_argCheck(state, ndim == 2 || ndim == 3, 1, input, + "2D or 3D (batch mode) input tensor expected, but got :%s"); + + long inputFrameSize = weight->size[0]; + long nInputFrame = input->size[dimS]; + long nOutputFrame = (nInputFrame + 2 * padW - kW) / dW + 1; + + if (nOutputFrame < 1) { + THError("Given input size: (%d x %d). " + "Calculated output size: (%d x %d). 
Output size is too small", + inputFrameSize, nInputFrame, inputFrameSize, nOutputFrame); + } + + THCUNN_check_dim_size(state, input, ndim, dimF, inputFrameSize); + + if (gradOutput != NULL) { + THCUNN_check_dim_size(state, gradOutput, ndim, dimF, inputFrameSize); + THCUNN_check_dim_size(state, gradOutput, ndim, dimS, nOutputFrame); + } +} + +void THNN_(TemporalRowConvolution_updateOutput)( + THCState *state, THCTensor *input, THCTensor *output, THCTensor *weight, + THCTensor *bias, THCTensor *finput, THCTensor *fgradInput, int kW, int dW, + int padW, bool featFirst) { + + // aliases + THCTensor *columns = finput; + THCTensor *ones = fgradInput; + + // assert same GPU + THCUNN_assertSameGPU(state, 5, input, output, weight, columns, ones); + if (bias != NULL) { + THCUNN_assertSameGPU(state, 2, weight, bias); + } + + // reshape weight if necessary + int ndim = input->nDimension; + + THCTensor *tinput; + + if (!featFirst) { + tinput = THCTensor_(newTranspose)(state, input, ndim - 1, ndim - 2); + input = THCTensor_(newContiguous)(state, tinput); + } else { + input = THCTensor_(newContiguous)(state, input); + } + + THNN_(TemporalRowConvolution_shapeCheck) + (state, input, NULL, weight, bias, kW, dW, padW); + + int batch = 1; + if (ndim == 2) { + // Force batch + batch = 0; + THCTensor_(resize3d)(state, input, 1, input->size[0], input->size[1]); + } + + // Params: + long inputFrameSize = weight->size[0]; + long nInputFrame = input->size[2]; + long nOutputFrame = (nInputFrame + 2 * padW - kW) / dW + 1; + + // Batch size + long batchSize = input->size[0]; + + // Resize output + THCTensor_(resize3d)(state, output, batchSize, inputFrameSize, nOutputFrame); + + // Augment the input + THCTensor_(resize3d)(state, columns, inputFrameSize, kW, nOutputFrame); + + // Define a buffer of ones, for bias accumulation + // Note: this buffer can be shared with other modules, it only ever + // gets increased and always contains ones. 
+ if (ones->nDimension != 2 || ones->size[0] * ones->size[1] < nOutputFrame) { + // Resize plane and fill with ones... + THCTensor_(resize2d)(state, ones, 1, nOutputFrame); + THCTensor_(fill)(state, ones, ScalarConvert<int, real>::to(1)); + } + + // Helpers + THCTensor *input_n = THCTensor_(new)(state); + THCTensor *output_n = THCTensor_(new)(state); + + // For each elt in batch, do: + for (int elt = 0; elt < batchSize; ++elt) { + // Matrix multiply per output: + THCTensor_(select)(state, input_n, input, 0, elt); + THCTensor_(select)(state, output_n, output, 0, elt); + + // Do bias first: + // m_, n_, k_ are dims of matrix A and B + // (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm) + long m_ = inputFrameSize; + long n_ = nOutputFrame; + long k_ = 1; + + // Do GEMM (note: this is a bit confusing because gemm asummes + // column-major matrices) + if (bias != NULL) { +#ifdef THC_REAL_IS_FLOAT + THCudaBlas_Sgemm( +#elif defined(THC_REAL_IS_HALF) + THCudaBlas_Hgemm( +#elif defined(THC_REAL_IS_DOUBLE) + THCudaBlas_Dgemm( +#endif + state, 't', 'n', n_, m_, k_, ScalarConvert<int, real>::to(1), + THCTensor_(data)(state, ones), k_, THCTensor_(data)(state, bias), k_, + ScalarConvert<int, real>::to(0), THCTensor_(data)(state, output_n), + n_); + } else { + THCTensor_(zero)(state, output_n); + } + + // Extract columns: + row2col(THCState_getCurrentStream(state), THCTensor_(data)(state, input_n), + inputFrameSize, nInputFrame, kW, padW, dW, 1, + THCTensor_(data)(state, columns)); + + THCTensor *output3d = THCTensor_(newWithStorage3d)( + state, output_n->storage, output_n->storageOffset, inputFrameSize, -1, + 1, -1, nOutputFrame, -1); + + // weight: inputFrameSize x 1 x kW + // columns: inputFrameSize x kW x nOutputFrame + THCTensor_(baddbmm)(state, output3d, ScalarConvert<int, real>::to(1), + output3d, ScalarConvert<int, real>::to(1), weight, + columns); + // output3d: inputFrameSize x 1 x nOutputFrame + + THCTensor_(free)(state, output3d); + } + + // Free + 
THCTensor_(free)(state, input_n); + THCTensor_(free)(state, output_n); + + // Resize output + if (batch == 0) { + THCTensor_(resize2d)(state, output, inputFrameSize, nOutputFrame); + THCTensor_(resize2d)(state, input, inputFrameSize, nInputFrame); + } + + if (!featFirst) { + THCTensor_(transpose)(state, output, output, ndim - 1, ndim - 2); + THCTensor_(free)(state, tinput); + } + + THCTensor_(free)(state, input); +} + +void THNN_(TemporalRowConvolution_updateGradInput)( + THCState *state, THCTensor *input, THCTensor *gradOutput, + THCTensor *gradInput, THCTensor *weight, THCTensor *finput, + THCTensor *fgradInput, int kW, int dW, int padW, bool featFirst) { + + // aliases + THCTensor *gradColumns = finput; + + THCUNN_assertSameGPU(state, 5, input, gradOutput, weight, gradColumns, + gradInput); + + int ndim = input->nDimension; + + THCTensor *tinput, *tgradOutput; + + if (!featFirst) { + tinput = THCTensor_(newTranspose)(state, input, ndim - 1, ndim - 2); + tgradOutput = + THCTensor_(newTranspose)(state, gradOutput, ndim - 1, ndim - 2); + input = THCTensor_(newContiguous)(state, tinput); + gradOutput = THCTensor_(newContiguous)(state, tgradOutput); + + } else { + input = THCTensor_(newContiguous)(state, input); + gradOutput = THCTensor_(newContiguous)(state, gradOutput); + } + + THNN_(TemporalRowConvolution_shapeCheck) + (state, input, gradOutput, weight, NULL, kW, dW, padW); + + int batch = 1; + if (ndim == 2) { + // Force batch + batch = 0; + THCTensor_(resize3d)(state, input, 1, input->size[0], input->size[1]); + THCTensor_(resize3d)(state, gradOutput, 1, gradOutput->size[0], + gradOutput->size[1]); + } + + // Params: + long inputFrameSize = weight->size[0]; + long nInputFrame = input->size[2]; + long nOutputFrame = gradOutput->size[2]; + + // Batch size + long batchSize = input->size[0]; + + // Resize output + THCTensor_(resize3d)(state, gradInput, batchSize, inputFrameSize, + nInputFrame); + + // Resize temporary columns + THCTensor_(resize3d)(state, 
gradColumns, inputFrameSize, kW, nOutputFrame); + + // Helpers + THCTensor *gradInput_n = THCTensor_(new)(state); + THCTensor *gradOutput_n = THCTensor_(new)(state); + + THCTensor_(transpose)(state, weight, weight, 1, 2); + + for (int elt = 0; elt < batchSize; ++elt) { + // Matrix multiply per sample: + THCTensor_(select)(state, gradInput_n, gradInput, 0, elt); + THCTensor_(select)(state, gradOutput_n, gradOutput, 0, elt); + + THCTensor *gradOutput3d = THCTensor_(newWithStorage3d)( + state, gradOutput_n->storage, gradOutput_n->storageOffset, + inputFrameSize, -1, 1, -1, nOutputFrame, -1); + + // weight: inputFrameSize x kW x 1 + // gradOutput3d: inputFrameSize x 1 x nOutputFrame + THCTensor_(baddbmm)(state, gradColumns, ScalarConvert<int, real>::to(0), + gradColumns, ScalarConvert<int, real>::to(1), weight, + gradOutput3d); + // gradColumns: inputFrameSize x kW x nOutputFrame + + // Unpack columns back into input: + col2row<real, accreal>(THCState_getCurrentStream(state), + THCTensor_(data)(state, gradColumns), inputFrameSize, + nInputFrame, kW, padW, dW, 1, + THCTensor_(data)(state, gradInput_n)); + + THCTensor_(free)(state, gradOutput3d); + } + + // Free + THCTensor_(free)(state, gradInput_n); + THCTensor_(free)(state, gradOutput_n); + + // Resize output + if (batch == 0) { + THCTensor_(resize2d)(state, gradOutput, inputFrameSize, nOutputFrame); + THCTensor_(resize2d)(state, input, inputFrameSize, nInputFrame); + THCTensor_(resize2d)(state, gradInput, inputFrameSize, nInputFrame); + } + + THCTensor_(transpose)(state, weight, weight, 1, 2); + + if (!featFirst) { + THCTensor_(transpose)(state, gradInput, gradInput, ndim - 1, ndim - 2); + THCTensor_(free)(state, tinput); + THCTensor_(free)(state, tgradOutput); + } + + THCTensor_(free)(state, input); + THCTensor_(free)(state, gradOutput); +} + +void THNN_(TemporalRowConvolution_accGradParameters)( + THCState *state, THCTensor *input, THCTensor *gradOutput, + THCTensor *gradWeight, THCTensor *gradBias, THCTensor 
*finput, + THCTensor *fgradInput, int kW, int dW, int padW, bool featFirst, + real scale) { + + // Aliases + THCTensor *columns = finput; + THCTensor *ones = fgradInput; + + THCUNN_assertSameGPU(state, 5, input, gradOutput, gradWeight, columns, ones); + if (gradBias != NULL) { + THCUNN_assertSameGPU(state, 2, gradWeight, gradBias); + } + + int ndim = input->nDimension; + + THCTensor *tinput, *tgradOutput; + + if (!featFirst) { + tinput = THCTensor_(newTranspose)(state, input, ndim - 1, ndim - 2); + tgradOutput = + THCTensor_(newTranspose)(state, gradOutput, ndim - 1, ndim - 2); + input = THCTensor_(newContiguous)(state, tinput); + gradOutput = THCTensor_(newContiguous)(state, tgradOutput); + } else { + input = THCTensor_(newContiguous)(state, input); + gradOutput = THCTensor_(newContiguous)(state, gradOutput); + } + + THNN_(TemporalRowConvolution_shapeCheck) + (state, input, gradOutput, gradWeight, gradBias, kW, dW, padW); + + int batch = 1; + if (ndim == 2) { + // Force batch + batch = 0; + THCTensor_(resize3d)(state, input, 1, input->size[0], input->size[1]); + THCTensor_(resize3d)(state, gradOutput, 1, gradOutput->size[0], + gradOutput->size[1]); + } + + // Params: + long inputFrameSize = gradWeight->size[0]; + long nInputFrame = input->size[2]; + long nOutputFrame = gradOutput->size[2]; + + // Batch size + long batchSize = input->size[0]; + + // Define a buffer of ones, for bias accumulation + if (ones->nDimension != 2 || ones->size[0] * ones->size[1] < nOutputFrame) { + // Resize plane and fill with ones... 
+ THCTensor_(resize2d)(state, ones, 1, nOutputFrame); + THCTensor_(fill)(state, ones, ScalarConvert<int, real>::to(1)); + } + + // // Resize temporary columns + THCTensor_(resize3d)(state, columns, inputFrameSize, kW, nOutputFrame); + + // Helpers + THCTensor *input_n = THCTensor_(new)(state); + THCTensor *gradOutput_n = THCTensor_(new)(state); + + // For each elt in batch, do: + for (int elt = 0; elt < batchSize; ++elt) { + // Matrix multiply per output + THCTensor_(select)(state, input_n, input, 0, elt); + THCTensor_(select)(state, gradOutput_n, gradOutput, 0, elt); + + THCTensor *gradOutput3d = THCTensor_(newWithStorage3d)( + state, gradOutput_n->storage, gradOutput_n->storageOffset, + inputFrameSize, -1, 1, -1, nOutputFrame, -1); + + // Extract columns + row2col(THCState_getCurrentStream(state), THCTensor_(data)(state, input_n), + inputFrameSize, nInputFrame, kW, padW, dW, 1, + THCTensor_(data)(state, columns)); + + THCTensor_(transpose)(state, columns, columns, 1, 2); + + // gradOutput3d: inputFrameSize x 1 x nOutputFrame + // columns: inputFrameSize x nOutputFrame x kW + THCTensor_(baddbmm)(state, gradWeight, ScalarConvert<int, real>::to(1), + gradWeight, scale, gradOutput3d, columns); + // gradWeight: inputFrameSize x 1 x kW + + THCTensor_(transpose)(state, columns, columns, 1, 2); + + THCTensor_(free)(state, gradOutput3d); + + if (gradBias != NULL) { + long m_ = inputFrameSize; + long k_ = nOutputFrame; +#if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE) +#ifdef THC_REAL_IS_FLOAT + THCudaBlas_Sgemv( +#elif defined(THC_REAL_IS_DOUBLE) + THCudaBlas_Dgemv( +#endif + state, 't', k_, m_, scale, THCTensor_(data)(state, gradOutput_n), k_, + THCTensor_(data)(state, ones), 1, ScalarConvert<int, real>::to(1), + THCTensor_(data)(state, gradBias), 1); +#endif +#ifdef THC_REAL_IS_HALF // half not supported due to baddbmm + THCudaBlas_Hgemm(state, 't', 'n', m_, 1, k_, scale, + THCTensor_(data)(state, gradOutput_n), k_, + THCTensor_(data)(state, ones), k_, + 
ScalarConvert<int, real>::to(1), + THCTensor_(data)(state, gradBias), m_); +#endif + } + } + + // Free + THCTensor_(free)(state, input_n); + THCTensor_(free)(state, gradOutput_n); + + // Resize + if (batch == 0) { + THCTensor_(resize2d)(state, gradOutput, inputFrameSize, nOutputFrame); + THCTensor_(resize2d)(state, input, inputFrameSize, nInputFrame); + } + + if (!featFirst) { + THCTensor_(free)(state, tinput); + THCTensor_(free)(state, tgradOutput); + } + + THCTensor_(free)(state, input); + THCTensor_(free)(state, gradOutput); +} + +#endif diff --git a/lib/THCUNN/row2col.h b/lib/THCUNN/row2col.h new file mode 100644 index 0000000..04765dd --- /dev/null +++ b/lib/THCUNN/row2col.h @@ -0,0 +1,90 @@ +#ifndef THCUNN_ROW2COL_H +#define THCUNN_ROW2COL_H + +#include "THCNumerics.cuh" +#include "common.h" + +// Kernel for fast unfold+copy on rows +template <typename Dtype> +__global__ void +row2col_kernel(const int n, const Dtype *data_row, const int width, + const int ksize_w, const int pad_w, const int stride_w, + const int dilation_w, const int width_col, Dtype *data_col) { + CUDA_KERNEL_LOOP(index, n) { + int w_out = index % width_col; + index /= width_col; + int channel_in = index; + int channel_out = channel_in * ksize_w; + int w_in = w_out * stride_w - pad_w; + data_col += (channel_out)*width_col + w_out; + data_row += (channel_in)*width + w_in; + for (int j = 0; j < ksize_w; ++j) { + int w = w_in + j * dilation_w; + *data_col = (w >= 0 && w < width) ? data_row[j * dilation_w] + : ScalarConvert<int, Dtype>::to(0); + data_col += width_col; + } + } +} + +template <typename Dtype> +void row2col(cudaStream_t stream, const Dtype *data_row, const int channels, + const int width, const int ksize_w, const int pad_w, + const int stride_w, const int dilation_w, Dtype *data_col) { + // We are going to launch channels * width_col kernels, each + // kernel responsible for copying a single-channel grid. 
+ int width_col = + (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1; + int num_kernels = channels * width_col; + // Launch + row2col_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS, 0, stream>>>( + num_kernels, data_row, width, ksize_w, pad_w, stride_w, 1, width_col, + data_col); + THCudaCheck(cudaGetLastError()); +} + +template <typename Dtype, typename Acctype> +__global__ void col2row_kernel(const int n, const Dtype *data_col, + const int width, const int channels, + const int kernel_w, const int pad_w, + const int stride_w, const int dilation_w, + const int width_col, Dtype *data_row) { + CUDA_KERNEL_LOOP(index, n) { + Acctype val = Acctype(0); + const int w_row = index % width + pad_w; + const int c_row = index / width; + int kernel_extent_w = (kernel_w - 1) * dilation_w + 1; + // compute the start and end of the output + const int w_col_start = (w_row < kernel_extent_w) + ? 0 + : (w_row - kernel_extent_w) / stride_w + 1; + const int w_col_end = min(w_row / stride_w + 1, width_col); + for (int w_col = w_col_start; w_col < w_col_end; w_col += 1) { + int w_k = (w_row - w_col * stride_w); + if (w_k % dilation_w == 0) { + w_k /= dilation_w; + int data_col_index = (c_row * kernel_w + w_k) * width_col + w_col; + val += data_col[data_col_index]; + } + } + data_row[index] = ScalarConvert<Acctype, Dtype>::to(val); + } + } + +template <typename Dtype, typename Acctype> +void col2row(cudaStream_t stream, const Dtype *data_col, const int channels, + const int width, const int patch_w, const int pad_w, + const int stride_w, const int dilation_w, Dtype *data_row) { + int width_col = + (width + 2 * pad_w - (dilation_w * (patch_w - 1) + 1)) / stride_w + 1; + int num_kernels = channels * width; + // To avoid involving atomic operations, we will launch one kernel per + // bottom dimension, and then in the kernel add up the top dimensions. 
+ col2row_kernel< + Dtype, Acctype><<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS, 0, stream>>>( + num_kernels, data_col, width, channels, patch_w, pad_w, stride_w, + dilation_w, width_col, data_row); + + THCudaCheck(cudaGetLastError()); +} +#endif @@ -3661,6 +3661,275 @@ function cunntest.TemporalConvolution_backward_batch() end end + +function cunntest.TemporalRowConvolution_forward_single() + local from = math.random(1,64) -- nFeature + local to = from + local ki = math.random(3,15) -- kW + local si = math.random(1,2) -- dW + local outi = math.random(1,256) -- nOutputFrame + local ini = (outi-1)*si+ki -- nInputFrame + + local function jacTest(noBias, featFirst) + noBias = noBias or false + featFirst = featFirst or false + + for k, typename in ipairs(typenames) do + if typename ~= "torch.CudaHalfTensor" then + + local input + if featFirst then + input = torch.randn(from, ini):type(typename) + else + input = torch.randn(ini, from):type(typename) + end + + local ctype = t2cpu[typename] + input = makeNonContiguous(input:type(ctype)) + local mod = nn.TemporalRowConvolution(from,ki,si):type(ctype) + if featFirst then + mod.featFirst = true + end + if noBias then + mod:noBias() + end + local groundtruth = mod:forward(input) + + input = makeNonContiguous(input:type(typename)) + local cmod = nn.TemporalRowConvolution(from,ki,si):type(typename) + + if featFirst then + cmod.featFirst = true + end + if noBias then + cmod:noBias() + end + cmod.weight = mod.weight:type(typename) + if mod.bias then cmod.bias = mod.bias:type(typename) end + local rescuda = cmod:forward(input) + + local error = rescuda:double() - groundtruth:double() + mytester:assertlt(error:abs():max(), precision_forward_type(precision_forward, typename), + string.format('error on state (forward) with %s', typename)) + end + end + end + jacTest(false,false) + jacTest(false,true) + jacTest(true,false) + jacTest(true,true) +end + +function cunntest.TemporalRowConvolution_forward_batch() + local bs = 
math.random(4,16) + local from = math.random(1,64) + local to = from + local ki = math.random(3,15) + local si = math.random(1,2) + local outi = math.random(1,256) + local ini = (outi-1)*si+ki + + local function jacTest(noBias,featFirst) + noBias = noBias or false + featFirst = featFirst or false + for k, typename in ipairs(typenames) do + if typename ~= "torch.CudaHalfTensor" then + + local input + if featFirst then + input = torch.randn(bs, from, ini):type(typename) + else + input = torch.randn(bs, ini, from):type(typename) + end + + local ctype = t2cpu[typename] + input = makeNonContiguous(input:type(ctype)) + local mod = nn.TemporalRowConvolution(from,ki,si):type(ctype) + if featFirst then + mod.featFirst = true + end + if noBias then + mod:noBias() + end + local groundtruth = mod:forward(input) + + input = makeNonContiguous(input:type(typename)) + local cmod = nn.TemporalRowConvolution(from,ki,si):type(typename) + if featFirst then + cmod.featFirst = true + end + if noBias then + cmod:noBias() + end + cmod.weight = mod.weight:type(typename) + if mod.bias then + cmod.bias = mod.bias:type(typename) + end + local rescuda = cmod:forward(input) + + local error = rescuda:double() - groundtruth:double() + mytester:assertlt(error:abs():max(), precision_forward_type(precision_forward, typename), + string.format('error on state (forward) with %s', typename)) + end + end + end + jacTest(false,false) + jacTest(false,true) + jacTest(true,false) + jacTest(true,true) +end + +function cunntest.TemporalRowConvolution_backward_single() + local from = math.random(1,64) -- nFeature + local to = from + local ki = math.random(3,15) -- kW + local si = math.random(1,2) -- dW + local outi = math.random(1,256) -- nOutputFrame + local ini = (outi-1)*si+ki -- nInputFrame + + local function jacTest(noBias,featFirst) + noBias = noBias or false + featFirst = featFirst or false + for k, typename in ipairs(typenames) do + if typename ~= "torch.CudaHalfTensor" then + + local input, gradOutput 
+ if featFirst then + input = torch.randn(from, ini):type(typename) + gradOutput = torch.randn(to, outi):type(typename) + else + input = torch.randn(ini, from):type(typename) + gradOutput = torch.rand(outi, to):type(typename) + end + + local ctype = t2cpu[typename] + input = makeNonContiguous(input:type(ctype)) + gradOutput = makeNonContiguous(gradOutput:type(ctype)) + local mod = nn.TemporalRowConvolution(from,ki,si):type(ctype) + if featFirst then mod.featFirst = true end + if noBias then mod:noBias() end + mod:forward(input) + mod:zeroGradParameters() + local groundgrad = mod:backward(input, gradOutput) + local groundweight = mod.gradWeight + local groundbias = mod.gradBias + + input = makeNonContiguous(input:type(typename)) + gradOutput = makeNonContiguous(gradOutput:type(typename)) + local cmod = nn.TemporalRowConvolution(from,ki,si):type(typename) + if featFirst then cmod.featFirst = true end + if noBias then cmod:noBias() end + cmod.weight = mod.weight:type(typename) + if cmod.bias then cmod.bias = mod.bias:type(typename) end + cmod:forward(input) + cmod:zeroGradParameters() + local rescuda = cmod:backward(input, gradOutput) + local weightcuda = cmod.gradWeight + + local error = rescuda:double() - groundgrad:double() + local werror = weightcuda:double() - groundweight:double() + + mytester:assertlt(error:abs():max(), precision_backward_type(precision_backward, typename), + string.format('error on state (backward) with %s', typename)) + mytester:assertlt(werror:abs():max(), + precision_backward_conv_weightbias(precision_backward, typename, weightcuda:abs():max()), + string.format('error on weight (backward) with %s', typename)) + + if cmod.bias then + local berror = cmod.gradBias:double() - groundbias:double() + mytester:assertlt(berror:abs():max(), + precision_backward_conv_weightbias(precision_backward, typename, cmod.gradBias:abs():max()), + string.format('error on bias (backward) with %s', typename)) + end + end + end + end + jacTest(false,false) + 
jacTest(false,true) + jacTest(true,false) + jacTest(true,true) +end + +function cunntest.TemporalRowConvolution_backward_batch() + local bs = math.random(4,16) + local from = math.random(1,64) -- nFeature + local to = from + local ki = math.random(3,15) -- kW + local si = math.random(1,2) -- dW + local outi = math.random(1,256) -- nOutputFrame + local ini = (outi-1)*si+ki -- nInputFrame + + local function jacTest(noBias,featFirst) + for k, typename in ipairs(typenames) do + if typename ~= "torch.CudaHalfTensor" then + + local input, gradOutput + if featFirst then + input = torch.randn(bs, from, ini):type(typename) + gradOutput = torch.randn(bs, to, outi):type(typename) + else + input = torch.randn(bs, ini, from):type(typename) + gradOutput = torch.rand(bs, outi, to):type(typename) + end + + local ctype = t2cpu[typename] + input = makeNonContiguous(input:type(ctype)) + gradOutput = makeNonContiguous(gradOutput:type(ctype)) + local mod = nn.TemporalRowConvolution(from,ki,si):type(ctype) + if featFirst then + mod.featFirst = true + end + if noBias then + mod:noBias() + end + mod:forward(input) + mod:zeroGradParameters() + local groundgrad = mod:backward(input, gradOutput) + local groundweight = mod.gradWeight + local groundbias = mod.gradBias + + input = makeNonContiguous(input:type(typename)) + gradOutput = makeNonContiguous(gradOutput:type(typename)) + local cmod = nn.TemporalRowConvolution(from,ki,si):type(typename) + if featFirst then + cmod.featFirst = true + end + if noBias then + cmod:noBias() + end + cmod.weight = mod.weight:type(typename) + if cmod.bias then + cmod.bias = mod.bias:type(typename) + end + cmod:forward(input) + cmod:zeroGradParameters() + local rescuda = cmod:backward(input, gradOutput) + local weightcuda = cmod.gradWeight + + local error = rescuda:double() - groundgrad:double() + local werror = weightcuda:double() - groundweight:double() + + mytester:assertlt(error:abs():max(), precision_backward_type(precision_backward, typename), + 
string.format('error on state (backward) [batch] with %s', typename)) + mytester:assertlt(werror:abs():max(), + precision_backward_conv_weightbias(precision_backward, typename, weightcuda:abs():max()), + string.format('error on weight (backward) [batch] with %s', typename)) + + if cmod.bias then + local berror = cmod.gradBias:double() - groundbias:double() + mytester:assertlt(berror:abs():max(), + precision_backward_conv_weightbias(precision_backward, typename, cmod.gradBias:abs():max()), + string.format('error on bias (backward) [batch] with %s', typename)) + end + end + end + end + jacTest(false,false) + jacTest(false,true) + jacTest(true,false) + jacTest(true,true) +end + function cunntest.Dropout() local p = 0.2 --prob of droping out a neuron local input = makeNonContiguous(torch.CudaTensor(1000):fill((1-p))) |