diff options
author | Gregory Chanan <gchanan@fb.com> | 2016-11-04 21:57:08 +0300 |
---|---|---|
committer | Gregory Chanan <gchanan@fb.com> | 2016-11-09 00:47:01 +0300 |
commit | 604d6fffa9913cbcffb2cb32a1660a5cc4e893ab (patch) | |
tree | 733f55a21614edf4219d53f50d12b164c5aa5305 /lib/THCUNN/im2col.h | |
parent | b0444c9da3367e44a0de17d8250742906b47ce46 (diff) |
Allow wider test tolerances for:
1) Size of half numbers
2) Convolution weight/bias
3) BatchNormalization
Diffstat (limited to 'lib/THCUNN/im2col.h')
-rw-r--r-- | lib/THCUNN/im2col.h | 10 |
1 file changed, 5 insertions, 5 deletions
diff --git a/lib/THCUNN/im2col.h b/lib/THCUNN/im2col.h
index 9cb5afe..ba57263 100644
--- a/lib/THCUNN/im2col.h
+++ b/lib/THCUNN/im2col.h
@@ -60,7 +60,7 @@ void im2col(cudaStream_t stream, const Dtype* data_im, const int channels,
   THCudaCheck(cudaGetLastError());
 }

-template <typename Dtype>
+template <typename Dtype, typename Acctype>
 __global__ void col2im_kernel(const int n, const Dtype* data_col,
     const int height, const int width, const int channels,
     const int kernel_h, const int kernel_w,
@@ -70,7 +70,7 @@ __global__ void col2im_kernel(const int n, const Dtype* data_col,
     const int height_col, const int width_col,
     Dtype* data_im) {
   CUDA_KERNEL_LOOP(index, n) {
-    Dtype val = ScalarConvert<int, Dtype>::to(0);
+    Acctype val = Acctype(0);
     const int w_im = index % width + pad_w;
     const int h_im = (index / width) % height + pad_h;
     const int c_im = index / (width * height);
@@ -97,11 +97,11 @@ __global__ void col2im_kernel(const int n, const Dtype* data_col,
         }
       }
     }
-    data_im[index] = val;
+    data_im[index] = ScalarConvert<Acctype, Dtype>::to(val);
   }
 }

-template <typename Dtype>
+template <typename Dtype, typename Acctype>
 void col2im(cudaStream_t stream, const Dtype* data_col, const int channels,
     const int height, const int width, const int patch_h, const int patch_w, const int pad_h,
@@ -114,7 +114,7 @@ void col2im(cudaStream_t stream, const Dtype* data_col, const int channels,
   int num_kernels = channels * height * width;
   // To avoid involving atomic operations, we will launch one kernel per
   // bottom dimension, and then in the kernel add up the top dimensions.
-  col2im_kernel <<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS, 0, stream>>> (
+  col2im_kernel<Dtype, Acctype> <<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS, 0, stream>>> (
      num_kernels, data_col, height, width, channels, patch_h, patch_w,
      pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,