diff options
author | Natalia Gimelshein <ngimelshein@nvidia.com> | 2017-07-14 01:12:29 +0300 |
---|---|---|
committer | Soumith Chintala <soumith@gmail.com> | 2017-07-14 02:23:29 +0300 |
commit | 27d79db5ac8f9fa3995ac43f876e4eb146d99913 (patch) | |
tree | 3143f8102fbeac6543b9230195138ecc07a76ad2 | |
parent | c3c0d9b11125f452c82f3225908f4313f91bb849 (diff) |
add launch_bounds to greedy kernels
-rw-r--r-- | lib/THCUNN/VolumetricUpSamplingTrilinear.cu | 2 |
1 files changed, 2 insertions, 0 deletions
diff --git a/lib/THCUNN/VolumetricUpSamplingTrilinear.cu b/lib/THCUNN/VolumetricUpSamplingTrilinear.cu index 0d861af..5d1493d 100644 --- a/lib/THCUNN/VolumetricUpSamplingTrilinear.cu +++ b/lib/THCUNN/VolumetricUpSamplingTrilinear.cu @@ -10,6 +10,7 @@ #include "THCAtomics.cuh" template<typename Dtype, typename Acctype> +__launch_bounds__(1024) __global__ void caffe_gpu_interp2_kernel(const int n, const Acctype rdepth, const Acctype rheight, const Acctype rwidth, const THCDeviceTensor<Dtype, 5> data1, THCDeviceTensor<Dtype, 5> data2) { @@ -77,6 +78,7 @@ __global__ void caffe_gpu_interp2_kernel(const int n, // Backward (adjoint) operation 1 <- 2 (accumulates) template <typename Dtype, typename Acctype> +__launch_bounds__(1024) __global__ void caffe_gpu_interp2_kernel_backward(const int n, const Acctype rdepth, const Acctype rheight, const Acctype rwidth, THCDeviceTensor<Dtype, 5> data1, const THCDeviceTensor<Dtype, 5> data2){ |