diff options
-rw-r--r-- | lib/THCUNN/CMakeLists.txt | 1 | ||||
-rw-r--r-- | lib/THCUNN/HardTanh.cu | 2 | ||||
-rw-r--r-- | lib/THCUNN/generic/LogSoftMax.cu | 20 | ||||
-rw-r--r-- | lib/THCUNN/generic/SparseLinear.cu | 7 | ||||
-rw-r--r-- | lib/THCUNN/generic/SpatialConvolutionLocal.cu | 8 | ||||
-rw-r--r-- | lib/THCUNN/generic/TemporalConvolution.cu | 28 | ||||
-rw-r--r-- | lib/THCUNN/generic/TemporalRowConvolution.cu | 15 |
7 files changed, 44 insertions, 37 deletions
diff --git a/lib/THCUNN/CMakeLists.txt b/lib/THCUNN/CMakeLists.txt index 433d99d..d4777bf 100644 --- a/lib/THCUNN/CMakeLists.txt +++ b/lib/THCUNN/CMakeLists.txt @@ -33,6 +33,7 @@ ENDIF() IF(NOT COMMAND CUDA_SELECT_NVCC_ARCH_FLAGS OR MSVC) INCLUDE(${CMAKE_CURRENT_SOURCE_DIR}/cmake/select_compute_arch.cmake) ENDIF() +LIST(APPEND CUDA_NVCC_FLAGS $ENV{TORCH_NVCC_FLAGS}) CUDA_SELECT_NVCC_ARCH_FLAGS(NVCC_FLAGS_EXTRA $ENV{TORCH_CUDA_ARCH_LIST}) LIST(APPEND CUDA_NVCC_FLAGS ${NVCC_FLAGS_EXTRA}) diff --git a/lib/THCUNN/HardTanh.cu b/lib/THCUNN/HardTanh.cu index 0543a4c..5921f7f 100644 --- a/lib/THCUNN/HardTanh.cu +++ b/lib/THCUNN/HardTanh.cu @@ -46,7 +46,7 @@ struct hardtanhupdateGradInput_functor __device__ void operator()(T *gradInput, const T *input, const T *gradOutput) const { - if (*input < min_val_ || *input > max_val_) + if (*input <= min_val_ || *input >= max_val_) *gradInput = ScalarConvert<int, T>::to(0); else *gradInput = *gradOutput; diff --git a/lib/THCUNN/generic/LogSoftMax.cu b/lib/THCUNN/generic/LogSoftMax.cu index b39798c..2f24697 100644 --- a/lib/THCUNN/generic/LogSoftMax.cu +++ b/lib/THCUNN/generic/LogSoftMax.cu @@ -40,13 +40,13 @@ void THNN_(LogSoftMax_updateOutput)( width = THCTensor_(size)(state, input, 2); // create contiguous tensor with cuda layout from tensor with torch layout + THCTensor *tinput = THCTensor_(new)(state); // C x H x W -> W x H x C - THCTensor_(transpose)(state, input, input, 0, 2); + THCTensor_(transpose)(state, tinput, input, 0, 2); // W x H x C -> H x W x C - THCTensor_(transpose)(state, input, input, 0, 1); - THCTensor *transposedInput = THCTensor_(newContiguous)(state, input); - THCTensor_(transpose)(state, input, input, 0, 1); - THCTensor_(transpose)(state, input, input, 0, 2); + THCTensor_(transpose)(state, tinput, tinput, 0, 1); + THCTensor *transposedInput = THCTensor_(newContiguous)(state, tinput); + THCTensor_(free)(state, tinput); input = transposedInput; } else if (ndims == 4) @@ -59,12 +59,12 @@ void THNN_(LogSoftMax_updateOutput)( // create contiguous tensor with cuda layout from tensor with torch layout // B x C x H x W -> B x W x H x C - THCTensor_(transpose)(state, input, input, 1, 3); + THCTensor *tinput = THCTensor_(new)(state); + THCTensor_(transpose)(state, tinput, input, 1, 3); // B x W x H x C -> B x H x W x C - THCTensor_(transpose)(state, input, input, 1, 2); - THCTensor *transposedInput = THCTensor_(newContiguous)(state, input); - THCTensor_(transpose)(state, input, input, 1, 2); - THCTensor_(transpose)(state, input, input, 1, 3); + THCTensor_(transpose)(state, tinput, tinput, 1, 2); + THCTensor *transposedInput = THCTensor_(newContiguous)(state, tinput); + THCTensor_(free)(state, tinput); input = transposedInput; } else diff --git a/lib/THCUNN/generic/SparseLinear.cu b/lib/THCUNN/generic/SparseLinear.cu index 6838cac..23a5c94 100644 --- a/lib/THCUNN/generic/SparseLinear.cu +++ b/lib/THCUNN/generic/SparseLinear.cu @@ -175,10 +175,11 @@ void THNN_(SparseLinear_accGradParameters)( THCudaIntTensor_data(state, colPtrs), CUSPARSE_INDEX_BASE_ONE); // FORTRAN expects contiguous col-major matricies - THCTensor_(transpose)(state, gradOutput, NULL, 0, 1); + THCTensor *tgradOutput = THCTensor_(new)(state); + THCTensor_(transpose)(state, tgradOutput, gradOutput, 0, 1); THCTensor_(resize2d)(state, buf, batchnum, outDim); - THCTensor_(copy)(state, buf, gradOutput); - THCTensor_(transpose)(state, gradOutput, NULL, 0, 1); // Restore gradOutput + THCTensor_(copy)(state, buf, tgradOutput); + THCTensor_(free)(state, tgradOutput); real one = ScalarConvert<int, real>::to(1); cusparseMatDescr_t descr = 0; diff --git a/lib/THCUNN/generic/SpatialConvolutionLocal.cu b/lib/THCUNN/generic/SpatialConvolutionLocal.cu index 0d4b9ad..9cbddd1 100644 --- a/lib/THCUNN/generic/SpatialConvolutionLocal.cu +++ b/lib/THCUNN/generic/SpatialConvolutionLocal.cu @@ -230,7 +230,8 @@ void THNN_(SpatialConvolutionLocal_updateGradInput)( THCTensor *fgradInput_n = THCTensor_(new)(state); THCTensor *gradOutput_n = THCTensor_(new)(state); - THCTensor_(transpose)(state, weight, weight, 1, 2); + THCTensor *tweight = THCTensor_(new)(state); + THCTensor_(transpose)(state, tweight, weight, 1, 2); // For each elt in batch, do: for (int elt = 0; elt < batchSize; elt ++) { @@ -258,7 +259,7 @@ void THNN_(SpatialConvolutionLocal_updateGradInput)( THCTensor_(baddbmm)(state, fgradInput3d, ScalarConvert<int, real>::to(0), fgradInput3d, ScalarConvert<int, real>::to(1), - weight, gradOutput3d); + tweight, gradOutput3d); // fgradInput3d: oH*oW x nInputPlane*kH*kW x 1 // Unpack columns back into input: @@ -288,8 +289,7 @@ void THNN_(SpatialConvolutionLocal_updateGradInput)( THCTensor_(resize3d)(state, gradInput, nInputPlane, inputHeight, inputWidth); } - THCTensor_(transpose)(state, weight, weight, 1, 2); - + THCTensor_(free)(state, tweight); THCTensor_(free)(state, input); THCTensor_(free)(state, gradOutput); if (freeWeight) diff --git a/lib/THCUNN/generic/TemporalConvolution.cu b/lib/THCUNN/generic/TemporalConvolution.cu index 5658527..abe4b54 100644 --- a/lib/THCUNN/generic/TemporalConvolution.cu +++ b/lib/THCUNN/generic/TemporalConvolution.cu @@ -98,9 +98,10 @@ void THNN_(TemporalConvolution_updateOutput)( nFrame, outputFrameStride*output->size[1], output->size[1], 1); - THCTensor_(transpose)(state, weight, NULL, 0, 1); - THCTensor_(addmm)(state, outputWindow, ScalarConvert<int, real>::to(1), outputWindow, ScalarConvert<int, real>::to(1), inputWindow, weight); - THCTensor_(transpose)(state, weight, NULL, 0, 1); + THCTensor *tweight = THCTensor_(new)(state); + THCTensor_(transpose)(state, tweight, weight, 0, 1); + THCTensor_(addmm)(state, outputWindow, ScalarConvert<int, real>::to(1), outputWindow, ScalarConvert<int, real>::to(1), inputWindow, tweight); + THCTensor_(free)(state, tweight); } } else @@ -145,9 +146,10 @@ void THNN_(TemporalConvolution_updateOutput)( nFrame, outputFrameStride*outputSample->size[1], outputSample->size[1], 1); - THCTensor_(transpose)(state, weight, NULL, 0, 1); - THCTensor_(addmm)(state, outputWindow, ScalarConvert<int, real>::to(1), outputWindow, ScalarConvert<int, real>::to(1), inputWindow, weight); - THCTensor_(transpose)(state, weight, NULL, 0, 1); + THCTensor *tweight = THCTensor_(new)(state); + THCTensor_(transpose)(state, tweight, weight, 0, 1); + THCTensor_(addmm)(state, outputWindow, ScalarConvert<int, real>::to(1), outputWindow, ScalarConvert<int, real>::to(1), inputWindow, tweight); + THCTensor_(free)(state, tweight); } } THCTensor_(free)(state, outputSample); @@ -329,9 +331,10 @@ void THNN_(TemporalConvolution_accGradParameters)( nFrame, outputFrameStride*gradOutput->size[1], gradOutput->size[1], 1); - THCTensor_(transpose)(state, gradOutputWindow, NULL, 0, 1); - THCTensor_(addmm)(state, gradWeight, ScalarConvert<int, real>::to(1), gradWeight, scale, gradOutputWindow, inputWindow); - THCTensor_(transpose)(state, gradOutputWindow, NULL, 0, 1); + THCTensor *tgradOutputWindow = THCTensor_(new)(state); + THCTensor_(transpose)(state, tgradOutputWindow, gradOutputWindow, 0, 1); + THCTensor_(addmm)(state, gradWeight, ScalarConvert<int, real>::to(1), gradWeight, scale, tgradOutputWindow, inputWindow); + THCTensor_(free)(state, tgradOutputWindow); } } else @@ -371,9 +374,10 @@ void THNN_(TemporalConvolution_accGradParameters)( nFrame, outputFrameStride*gradOutputSample->size[1], gradOutputSample->size[1], 1); - THCTensor_(transpose)(state, gradOutputWindow, NULL, 0, 1); - THCTensor_(addmm)(state, gradWeight, ScalarConvert<int, real>::to(1), gradWeight, scale, gradOutputWindow, inputWindow); - THCTensor_(transpose)(state, gradOutputWindow, NULL, 0, 1); + THCTensor *tgradOutputWindow = THCTensor_(new)(state); + THCTensor_(transpose)(state, tgradOutputWindow, gradOutputWindow, 0, 1); + THCTensor_(addmm)(state, gradWeight, ScalarConvert<int, real>::to(1), gradWeight, scale, tgradOutputWindow, inputWindow); + THCTensor_(free)(state, tgradOutputWindow); } } THCTensor_(free)(state, gradOutputSample); diff --git a/lib/THCUNN/generic/TemporalRowConvolution.cu b/lib/THCUNN/generic/TemporalRowConvolution.cu index a0835a9..9959322 100644 --- a/lib/THCUNN/generic/TemporalRowConvolution.cu +++ b/lib/THCUNN/generic/TemporalRowConvolution.cu @@ -237,7 +237,8 @@ void THNN_(TemporalRowConvolution_updateGradInput)( THCTensor *gradInput_n = THCTensor_(new)(state); THCTensor *gradOutput_n = THCTensor_(new)(state); - THCTensor_(transpose)(state, weight, weight, 1, 2); + THCTensor *tweight = THCTensor_(new)(state); + THCTensor_(transpose)(state, tweight, weight, 1, 2); for (int elt = 0; elt < batchSize; ++elt) { // Matrix multiply per sample: @@ -251,7 +252,7 @@ void THNN_(TemporalRowConvolution_updateGradInput)( // weight: inputFrameSize x kW x 1 // gradOutput3d: inputFrameSize x 1 x nOutputFrame THCTensor_(baddbmm)(state, gradColumns, ScalarConvert<int, real>::to(0), - gradColumns, ScalarConvert<int, real>::to(1), weight, + gradColumns, ScalarConvert<int, real>::to(1), tweight, gradOutput3d); // gradColumns: inputFrameSize x kW x nOutputFrame @@ -275,7 +276,7 @@ void THNN_(TemporalRowConvolution_updateGradInput)( THCTensor_(resize2d)(state, gradInput, inputFrameSize, nInputFrame); } - THCTensor_(transpose)(state, weight, weight, 1, 2); + THCTensor_(free)(state, tweight); if (!featFirst) { THCTensor_(transpose)(state, gradInput, gradInput, ndim - 1, ndim - 2); @@ -367,16 +368,16 @@ void THNN_(TemporalRowConvolution_accGradParameters)( inputFrameSize, nInputFrame, kW, padW, dW, 1, THCTensor_(data)(state, columns)); - THCTensor_(transpose)(state, columns, columns, 1, 2); + THCTensor *tcolumns = THCTensor_(new)(state); + THCTensor_(transpose)(state, tcolumns, columns, 1, 2); // gradOutput3d: inputFrameSize x 1 x nOutputFrame // columns: inputFrameSize x nOutputFrame x kW THCTensor_(baddbmm)(state, gradWeight, ScalarConvert<int, real>::to(1), - gradWeight, scale, gradOutput3d, columns); + gradWeight, scale, gradOutput3d, tcolumns); // gradWeight: inputFrameSize x 1 x kW - THCTensor_(transpose)(state, columns, columns, 1, 2); - + THCTensor_(free)(state, tcolumns); THCTensor_(free)(state, gradOutput3d); if (gradBias != NULL) { |