diff options
author | Adam Paszke <adam.paszke@gmail.com> | 2016-08-11 22:29:57 +0300 |
---|---|---|
committer | Adam Paszke <adam.paszke@gmail.com> | 2016-08-11 22:31:31 +0300 |
commit | 8ae08a2bc7e9813f1660c25274164884c27fe641 (patch) | |
tree | 6f9e7cadfd611293f91d36a851cd19d0a69f9b75 | |
parent | 1b7667145d311ecb3dfe9715ae6569a958f0e8e9 (diff) |
Use TH_INDEX_BASE in THCUNN
-rw-r--r-- | lib/THCUNN/ClassNLLCriterion.cu | 8 | ||||
-rw-r--r-- | lib/THCUNN/LookupTable.cu | 10 | ||||
-rw-r--r-- | lib/THCUNN/MultiLabelMarginCriterion.cu | 22 | ||||
-rw-r--r-- | lib/THCUNN/MultiMarginCriterion.cu | 4 | ||||
-rw-r--r-- | lib/THCUNN/SpatialAdaptiveMaxPooling.cu | 12 | ||||
-rw-r--r-- | lib/THCUNN/SpatialClassNLLCriterion.cu | 4 | ||||
-rw-r--r-- | lib/THCUNN/SpatialFractionalMaxPooling.cu | 4 | ||||
-rw-r--r-- | lib/THCUNN/SpatialMaxPooling.cu | 4 | ||||
-rw-r--r-- | lib/THCUNN/SpatialMaxUnpooling.cu | 4 |
9 files changed, 36 insertions, 36 deletions
diff --git a/lib/THCUNN/ClassNLLCriterion.cu b/lib/THCUNN/ClassNLLCriterion.cu index b2f54cb..7f72c23 100644 --- a/lib/THCUNN/ClassNLLCriterion.cu +++ b/lib/THCUNN/ClassNLLCriterion.cu @@ -18,7 +18,7 @@ __global__ void cunn_ClassNLLCriterion_updateOutput_kernel1(float *output, // TODO: T4951791 Reuse code between updateOutput_kernel1 and // updateOutput_kernel. - int t = (int)*target - 1; + int t = (int)*target - TH_INDEX_BASE; assert(t >= 0 && t < n_classes); float cur_weight = weights ? weights[t] : 1.0f; *output = -cur_weight * input[t]; @@ -44,7 +44,7 @@ __global__ void cunn_ClassNLLCriterion_updateOutput_kernel(float *output, shInputs[threadIdx.x] = 0.0f; acc_weight[threadIdx.x] = 0.0f; for (i = threadIdx.x; i < nframe; i += NTHREADS) { - t = target[i] - 1; + t = target[i] - TH_INDEX_BASE; assert(t >= 0 && t < n_classes); cur_weight = weights ? weights[t] : 1.0f; shInputs[threadIdx.x] -= input[i * ndim + t] * cur_weight; @@ -79,7 +79,7 @@ __global__ void cunn_ClassNLLCriterion_updateGradInput_kernel1( return; } float norm = size_average ? (1.0f / *total_weight) : 1.0f; - int t = (int)*target - 1; + int t = (int)*target - TH_INDEX_BASE; assert(t >= 0 && t < n_classes); gradInput[t] = -(weights ? weights[t] : 1.0f) * norm; } @@ -101,7 +101,7 @@ __global__ void cunn_ClassNLLCriterion_updateGradInput_kernel( float norm = size_average ? (1.0f / *total_weight) : 1.0f; for (i = threadIdx.x; i < nframe; i += NTHREADS) { - t = (int)target[i] - 1; + t = (int)target[i] - TH_INDEX_BASE; assert(t >= 0 && t < n_classes); gradInput[i * ndim + t] = -(weights ? weights[t] : 1.0f) * norm; } diff --git a/lib/THCUNN/LookupTable.cu b/lib/THCUNN/LookupTable.cu index 2d1fee4..749ce15 100644 --- a/lib/THCUNN/LookupTable.cu +++ b/lib/THCUNN/LookupTable.cu @@ -73,8 +73,8 @@ __global__ void cunn_LookupTable_accGradParametersKernelByFeature( // warp-wide collision detector `warpHasCollision`. const int laneId = threadIdx.x % 32; for (int i = laneId; i < numel; i += WARP_SIZE) { - const int weightIndex = (int) (input[i] - 1); - if (weightIndex == paddingValue - 1) { + const int weightIndex = (int) (input[i] - TH_INDEX_BASE); + if (weightIndex == paddingValue - TH_INDEX_BASE) { continue; } @@ -120,8 +120,8 @@ __global__ void cunn_LookupTable_accGradParametersKernel( && input[idx] != paddingValue) { do { const int startFeature = threadIdx.x + blockIdx.y * blockDim.x * SZ; - const int weightRow = ((int) input[idx] - 1) * stride; - const int gradOutputRow = ((int) indices[idx] - 1) * stride; + const int weightRow = ((int) input[idx] - TH_INDEX_BASE) * stride; + const int gradOutputRow = ((int) indices[idx] - TH_INDEX_BASE) * stride; const float scale = count ? defaultScale / count[idx] : defaultScale; float gradient[SZ]; @@ -326,7 +326,7 @@ void THNN_CudaLookupTable_renorm( // numel << stride, since idx usually contains sparse row indices for (long i = 0; i < numel; i++) { - long k = idx_ptr[i] - 1; + long k = idx_ptr[i] - TH_INDEX_BASE; thrust::device_ptr<float> row_ptr = weight_ptr + k * stride; float norm = thrust::transform_reduce(row_ptr, row_ptr + stride, unary_pow, 0, binary_plus); diff --git a/lib/THCUNN/MultiLabelMarginCriterion.cu b/lib/THCUNN/MultiLabelMarginCriterion.cu index 97769ab..903e064 100644 --- a/lib/THCUNN/MultiLabelMarginCriterion.cu +++ b/lib/THCUNN/MultiLabelMarginCriterion.cu @@ -33,9 +33,9 @@ __global__ void cunn_MultiLabelMarginCriterion_updateOutput_kernel(float *output // mark targets in istarget if (threadIdx.x == 0) { for (int dt = 0; dt < dim; dt++) { - int target_idx = (int)target_k[dt]; - if (target_idx == 0) break; - istarget_k[target_idx - 1] = 1; + int target_idx = (int)target_k[dt] - TH_INDEX_BASE; + if (target_idx < 0) break; + istarget_k[target_idx] = 1; } } __syncthreads(); @@ -44,11 +44,11 @@ __global__ void cunn_MultiLabelMarginCriterion_updateOutput_kernel(float *output float sum = 0; for (int dt = 0; dt < dim; dt++) { // next target: - int target_idx = (int)target_k[dt]; - if (target_idx == 0) break; + int target_idx = (int)target_k[dt] - TH_INDEX_BASE; + if (target_idx < 0) break; // current value for target - float input_target_k = input_k[target_idx-1]; + float input_target_k = input_k[target_idx]; // compare to all inputs (multithreaded): for (int d = threadIdx.x; d < dim; d += blockDim.x) { @@ -102,11 +102,11 @@ __global__ void cunn_MultiLabelMarginCriterion_updateGradInput_kernel(float *gra // iterate over targets for (int dt = 0; dt < dim; dt++) { // next target: - int target_idx = (int)target_k[dt]; - if (target_idx == 0) break; + int target_idx = (int)target_k[dt] - TH_INDEX_BASE; + if (target_idx < 0) break; // current value for target - float input_target_k = input_k[target_idx-1]; + float input_target_k = input_k[target_idx]; // compare to all inputs (multithreaded): float sum = 0; @@ -125,7 +125,7 @@ __global__ void cunn_MultiLabelMarginCriterion_updateGradInput_kernel(float *gra // reduce sum float totalSum = reduceBlock(sums, blockDim.x, sum, thrust::plus<float>(), 0.0f); if (threadIdx.x == 0) { - gradInput_k[target_idx-1] += totalSum; + gradInput_k[target_idx] += totalSum; } __syncthreads(); } @@ -237,4 +237,4 @@ void THNN_CudaMultiLabelMarginCriterion_updateGradInput( THCudaTensor_free(state, istarget); } -#undef MULTILABELMARGIN_THREADS
\ No newline at end of file +#undef MULTILABELMARGIN_THREADS diff --git a/lib/THCUNN/MultiMarginCriterion.cu b/lib/THCUNN/MultiMarginCriterion.cu index a3c3e22..31caa75 100644 --- a/lib/THCUNN/MultiMarginCriterion.cu +++ b/lib/THCUNN/MultiMarginCriterion.cu @@ -10,7 +10,7 @@ __global__ void cunn_MultiMarginCriterion_updateOutput_kernel(float *output, flo int k = blockIdx.x; float *input_k = input + k*dim; float *output_k = output + k; - int target_k = ((int)target[k])-1; + int target_k = ((int)target[k]) - TH_INDEX_BASE; float input_target_k = input_k[target_k]; int i_start = threadIdx.x; @@ -53,7 +53,7 @@ __global__ void cunn_MultiMarginCriterion_updateGradInput_kernel(float *gradInpu int k = blockIdx.x; float *input_k = input + k*dim; float *gradInput_k = gradInput + k*dim; - int target_k = ((int)target[k])-1; + int target_k = ((int)target[k]) - TH_INDEX_BASE; float input_target_k = input_k[target_k]; float g = (sizeAverage ? 1./((float)(nframe*dim)) : 1./((float)dim)); diff --git a/lib/THCUNN/SpatialAdaptiveMaxPooling.cu b/lib/THCUNN/SpatialAdaptiveMaxPooling.cu index c615064..5dd8659 100644 --- a/lib/THCUNN/SpatialAdaptiveMaxPooling.cu +++ b/lib/THCUNN/SpatialAdaptiveMaxPooling.cu @@ -71,8 +71,8 @@ __global__ void adaptivemaxpool(float *input, float *output, float *indices_x, f } // Update output and argmax *ptr_output = max; - *ptr_ind_x = argmax_x + 1; - *ptr_ind_y = argmax_y + 1; + *ptr_ind_x = argmax_x + TH_INDEX_BASE; + *ptr_ind_y = argmax_y + TH_INDEX_BASE; } } } @@ -122,8 +122,8 @@ __global__ void adaptivemaxgradinput(float *gradInput, float *gradOutput, float float *ptr_ind_y = indices_y + yy*output_w + xx; float z = *ptr_gradOutput; - int argmax_x = (*ptr_ind_x)-1; - int argmax_y = (*ptr_ind_y)-1; + int argmax_x = (*ptr_ind_x) - TH_INDEX_BASE; + int argmax_y = (*ptr_ind_y) - TH_INDEX_BASE; ptr_gradInput[argmax_x + argmax_y*input_w] += z; } @@ -176,8 +176,8 @@ __global__ void atomicadaptivemaxgradinput( float *ptr_ind_y = indices_y + yy*output_w + xx; float z = *ptr_gradOutput; - int argmax_x = (*ptr_ind_x)-1; - int argmax_y = (*ptr_ind_y)-1; + int argmax_x = (*ptr_ind_x) - TH_INDEX_BASE; + int argmax_y = (*ptr_ind_y) - TH_INDEX_BASE; // atomic add since different threads could update same variable atomicAdd(&(ptr_gradInput[argmax_x + argmax_y*input_w]), z); diff --git a/lib/THCUNN/SpatialClassNLLCriterion.cu b/lib/THCUNN/SpatialClassNLLCriterion.cu index c718772..7f9e21f 100644 --- a/lib/THCUNN/SpatialClassNLLCriterion.cu +++ b/lib/THCUNN/SpatialClassNLLCriterion.cu @@ -32,7 +32,7 @@ __global__ void cunn_SpatialClassNLLCriterion_updateOutput_kernel( for (i = (blockIdx.x % blocks_per_sample) * blockDim.x + threadIdx.x; i < map_nelem; i += step) { - t = target[toffset + i] - 1; + t = target[toffset + i] - TH_INDEX_BASE; assert(t >= 0 && t < n_classes); cur_weight = weights ? weights[t] : 1.0f; input_sum -= input[ioffset + i + map_nelem * t] * cur_weight; @@ -77,7 +77,7 @@ __global__ void cunn_SpatialClassNLLCriterion_updateGradInput_kernel( for (i = (blockIdx.x % blocks_per_sample) * blockDim.x + threadIdx.x; i < map_nelem; i += step) { - t = (int)target[toffset + i] - 1; + t = (int)target[toffset + i] - TH_INDEX_BASE; assert(t >= 0 && t < n_classes); gradInput[ioffset + i + map_nelem * t] = -(weights ? weights[t] : 1.0f) * norm; } diff --git a/lib/THCUNN/SpatialFractionalMaxPooling.cu b/lib/THCUNN/SpatialFractionalMaxPooling.cu index 81d12b1..289b1d6 100644 --- a/lib/THCUNN/SpatialFractionalMaxPooling.cu +++ b/lib/THCUNN/SpatialFractionalMaxPooling.cu @@ -68,7 +68,7 @@ __global__ void SpatialFractionalMaxPooling_updateOutput( assert(maxIndex != -1); // +1 for Lua index - indices[batch][plane][outputH][outputW] = maxIndex + 1; + indices[batch][plane][outputH][outputW] = maxIndex + TH_INDEX_BASE; output[batch][plane][outputH][outputW] = maxVal; } } @@ -177,7 +177,7 @@ __global__ void SpatialFractionalMaxPooling_updateGradInput( int outputW = ourOutputPoint % gradOutput.getSize(3); int outputH = ourOutputPoint / gradOutput.getSize(3); - int index = indices[batch][plane][outputH][outputW] - 1; + int index = indices[batch][plane][outputH][outputW] - TH_INDEX_BASE; assert(index >= 0); int inputW = index % gradInput.getSize(3); int inputH = index / gradInput.getSize(3); diff --git a/lib/THCUNN/SpatialMaxPooling.cu b/lib/THCUNN/SpatialMaxPooling.cu index 824823f..fe27f50 100644 --- a/lib/THCUNN/SpatialMaxPooling.cu +++ b/lib/THCUNN/SpatialMaxPooling.cu @@ -32,7 +32,7 @@ __global__ void MaxPoolForward(const int nthreads, const Dtype* bottom_data, } } top_data[index] = maxval; - top_mask[index] = maxidx + 1; + top_mask[index] = maxidx + TH_INDEX_BASE; } } @@ -63,7 +63,7 @@ __global__ void MaxPoolBackward(const int nthreads, const Dtype* top_diff, top_mask += offset; for (int ph = phstart; ph < phend; ++ph) { for (int pw = pwstart; pw < pwend; ++pw) { - if (top_mask[ph * pooled_width + pw] - 1 == h * width + w) { + if (top_mask[ph * pooled_width + pw] - TH_INDEX_BASE == h * width + w) { gradient += top_diff[ph * pooled_width + pw]; } } diff --git a/lib/THCUNN/SpatialMaxUnpooling.cu b/lib/THCUNN/SpatialMaxUnpooling.cu index bd2c3af..b56bd56 100644 --- a/lib/THCUNN/SpatialMaxUnpooling.cu +++ b/lib/THCUNN/SpatialMaxUnpooling.cu @@ -8,7 +8,7 @@ __global__ void MaxUnpoolForward(const int nthreads, const Dtype* bottom_data, c int c = (index / iwidth / iheight) % channels; int n = index / iwidth / iheight / channels; top_data += (n*channels + c)*oheight*owidth; - int maxind = bottom_mask[index]-1; + int maxind = bottom_mask[index] - TH_INDEX_BASE; top_data[maxind] = bottom_data[index]; } @@ -21,7 +21,7 @@ __global__ void MaxUnpoolBackward(const int nthreads, const Dtype* top_diff, con int c = (index / iwidth / iheight) % channels; int n = index / iwidth / iheight / channels; top_diff += (n*channels + c)*oheight*owidth; - int maxind = bottom_mask[index]-1; + int maxind = bottom_mask[index] - TH_INDEX_BASE; bottom_diff[index] = top_diff[maxind]; } |