diff options
author | Soumith Chintala <soumith@gmail.com> | 2017-01-11 03:27:29 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2017-01-11 03:27:29 +0300 |
commit | 49d1c06b9daad86acce7380595b914230360d892 (patch) | |
tree | 2e53ea747cedf2a28dbe3bf886e295758fefef33 | |
parent | 349df42dfa550389f04ac90ea621f21b2838b00c (diff) | |
parent | 1f8292f4f8334ab9e2433a0960607c780b29f848 (diff) |
Merge pull request #414 from gchanan/thrustalloc
Re-route thrust memory allocation to THCudaMalloc / THCudaFree in cunn.
-rw-r--r-- | lib/THCUNN/LookupTable.cu | 1 |
-rw-r--r-- | lib/THCUNN/MSECriterion.cu | 1 |
-rw-r--r-- | lib/THCUNN/SmoothL1Criterion.cu | 1 |
-rw-r--r-- | lib/THCUNN/generic/LookupTable.cu | 5 |
-rw-r--r-- | lib/THCUNN/generic/MSECriterion.cu | 6 |
-rw-r--r-- | lib/THCUNN/generic/SmoothL1Criterion.cu | 6 |
6 files changed, 14 insertions, 6 deletions
diff --git a/lib/THCUNN/LookupTable.cu b/lib/THCUNN/LookupTable.cu
index 5251bd1..29cb3a1 100644
--- a/lib/THCUNN/LookupTable.cu
+++ b/lib/THCUNN/LookupTable.cu
@@ -1,6 +1,7 @@
 #include "THCUNN.h"
 #include "common.h"
 
+#include "THCThrustAllocator.cuh"
 #include <thrust/device_ptr.h>
 #include <thrust/execution_policy.h>
 #include <thrust/iterator/constant_iterator.h>
diff --git a/lib/THCUNN/MSECriterion.cu b/lib/THCUNN/MSECriterion.cu
index 26a35a5..f776152 100644
--- a/lib/THCUNN/MSECriterion.cu
+++ b/lib/THCUNN/MSECriterion.cu
@@ -2,6 +2,7 @@
 #include "common.h"
 #include "THCHalf.h"
 #include "THCHalfAutoNumerics.cuh"
+#include "THCThrustAllocator.cuh"
 
 #include <thrust/fill.h>
 #include <thrust/functional.h>
diff --git a/lib/THCUNN/SmoothL1Criterion.cu b/lib/THCUNN/SmoothL1Criterion.cu
index 8e94fbc..9a0654f 100644
--- a/lib/THCUNN/SmoothL1Criterion.cu
+++ b/lib/THCUNN/SmoothL1Criterion.cu
@@ -2,6 +2,7 @@
 #include "common.h"
 #include "THCHalf.h"
 #include "THCHalfAutoNumerics.cuh"
+#include "THCThrustAllocator.cuh"
 
 #include <thrust/fill.h>
 #include <thrust/functional.h>
diff --git a/lib/THCUNN/generic/LookupTable.cu b/lib/THCUNN/generic/LookupTable.cu
index d467c6a..bd59a04 100644
--- a/lib/THCUNN/generic/LookupTable.cu
+++ b/lib/THCUNN/generic/LookupTable.cu
@@ -64,6 +64,7 @@ void THNN_(LookupTable_accGradParameters)(
   THCIndexTensor_(resizeAs)(state, count, input);
   count_data = THCIndexTensor_(data)(state, count);
 
+  THCThrustAllocator thrustAlloc(state);
   thrust::device_ptr<THCIndex_t> sorted_ptr(sorted_data);
   thrust::device_ptr<THCIndex_t> count_ptr(count_data);
 
@@ -72,7 +73,7 @@ void THNN_(LookupTable_accGradParameters)(
   // count: 1 1 2 3 1 2 1 1 2
   thrust::inclusive_scan_by_key(
 #if CUDA_VERSION >= 7000
-    thrust::cuda::par.on(THCState_getCurrentStream(state)),
+    thrust::cuda::par(thrustAlloc).on(THCState_getCurrentStream(state)),
 #endif
     sorted_ptr,
     sorted_ptr + numel,
@@ -85,7 +86,7 @@ void THNN_(LookupTable_accGradParameters)(
   // count: 1 3 3 3 2 2 1 2 2
   thrust::inclusive_scan_by_key(
 #if CUDA_VERSION >= 7000
-    thrust::cuda::par.on(THCState_getCurrentStream(state)),
+    thrust::cuda::par(thrustAlloc).on(THCState_getCurrentStream(state)),
 #endif
     thrust::make_reverse_iterator(sorted_ptr + numel),
     thrust::make_reverse_iterator(sorted_ptr),
diff --git a/lib/THCUNN/generic/MSECriterion.cu b/lib/THCUNN/generic/MSECriterion.cu
index a290db1..84c0083 100644
--- a/lib/THCUNN/generic/MSECriterion.cu
+++ b/lib/THCUNN/generic/MSECriterion.cu
@@ -18,11 +18,12 @@ void THNN_(MSECriterion_updateOutput)(
   input = THCTensor_(newContiguous)(state, input);
   target = THCTensor_(newContiguous)(state, target);
 
+  THCThrustAllocator thrustAlloc(state);
   thrust::device_ptr<real> input_data(THCTensor_(data)(state, input));
   thrust::device_ptr<real> target_data(THCTensor_(data)(state, target));
   accreal sum = thrust::inner_product(
 #if CUDA_VERSION >= 7000
-    thrust::cuda::par.on(THCState_getCurrentStream(state)),
+    thrust::cuda::par(thrustAlloc).on(THCState_getCurrentStream(state)),
 #endif
     input_data, input_data+size, target_data, (accreal) 0,
     thrust::plus<accreal>(), mse_functor<real, accreal>());
@@ -54,13 +55,14 @@ void THNN_(MSECriterion_updateGradInput)(
 
   THCTensor_(resizeAs)(state, gradInput, input);
 
+  THCThrustAllocator thrustAlloc(state);
   thrust::device_ptr<real> input_data(THCTensor_(data)(state, input));
   thrust::device_ptr<real> target_data(THCTensor_(data)(state, target));
   thrust::device_ptr<real> gradInput_data(THCTensor_(data)(state, gradInput));
 
   thrust::transform(
 #if CUDA_VERSION >= 7000
-    thrust::cuda::par.on(THCState_getCurrentStream(state)),
+    thrust::cuda::par(thrustAlloc).on(THCState_getCurrentStream(state)),
 #endif
     input_data, input_data+size, target_data, gradInput_data,
     mse_updateGradInput_functor<real, accreal>(norm));
diff --git a/lib/THCUNN/generic/SmoothL1Criterion.cu b/lib/THCUNN/generic/SmoothL1Criterion.cu
index de29175..03c6630 100644
--- a/lib/THCUNN/generic/SmoothL1Criterion.cu
+++ b/lib/THCUNN/generic/SmoothL1Criterion.cu
@@ -22,11 +22,12 @@ void THNN_(SmoothL1Criterion_updateOutput)(
   input = THCTensor_(newContiguous)(state, input);
   target = THCTensor_(newContiguous)(state, target);
 
+  THCThrustAllocator thrustAlloc(state);
   thrust::device_ptr<real> input_data(THCTensor_(data)(state, input));
   thrust::device_ptr<real> target_data(THCTensor_(data)(state, target));
   accreal sum = thrust::inner_product(
 #if CUDA_VERSION >= 7000
-    thrust::cuda::par.on(THCState_getCurrentStream(state)),
+    thrust::cuda::par(thrustAlloc).on(THCState_getCurrentStream(state)),
 #endif
     input_data, input_data+size, target_data, (accreal) 0,
     thrust::plus<accreal>(), smoothl1_functor<real, accreal>()
@@ -63,13 +64,14 @@ void THNN_(SmoothL1Criterion_updateGradInput)(
 
   THCTensor_(resizeAs)(state, gradInput, input);
 
+  THCThrustAllocator thrustAlloc(state);
   thrust::device_ptr<real> input_data(THCTensor_(data)(state, input));
   thrust::device_ptr<real> target_data(THCTensor_(data)(state, target));
   thrust::device_ptr<real> gradInput_data(THCTensor_(data)(state, gradInput));
 
   thrust::transform(
 #if CUDA_VERSION >= 7000
-    thrust::cuda::par.on(THCState_getCurrentStream(state)),
+    thrust::cuda::par(thrustAlloc).on(THCState_getCurrentStream(state)),
 #endif
     input_data, input_data+size, target_data, gradInput_data,
     smoothl1_updateGradInput_functor<real>(norm)