Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/cunn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAdam Paszke <adam.paszke@gmail.com>2016-08-11 22:29:57 +0300
committerAdam Paszke <adam.paszke@gmail.com>2016-08-11 22:31:31 +0300
commit8ae08a2bc7e9813f1660c25274164884c27fe641 (patch)
tree6f9e7cadfd611293f91d36a851cd19d0a69f9b75
parent1b7667145d311ecb3dfe9715ae6569a958f0e8e9 (diff)
Use TH_INDEX_BASE in THCUNN
-rw-r--r--lib/THCUNN/ClassNLLCriterion.cu8
-rw-r--r--lib/THCUNN/LookupTable.cu10
-rw-r--r--lib/THCUNN/MultiLabelMarginCriterion.cu22
-rw-r--r--lib/THCUNN/MultiMarginCriterion.cu4
-rw-r--r--lib/THCUNN/SpatialAdaptiveMaxPooling.cu12
-rw-r--r--lib/THCUNN/SpatialClassNLLCriterion.cu4
-rw-r--r--lib/THCUNN/SpatialFractionalMaxPooling.cu4
-rw-r--r--lib/THCUNN/SpatialMaxPooling.cu4
-rw-r--r--lib/THCUNN/SpatialMaxUnpooling.cu4
9 files changed, 36 insertions, 36 deletions
diff --git a/lib/THCUNN/ClassNLLCriterion.cu b/lib/THCUNN/ClassNLLCriterion.cu
index b2f54cb..7f72c23 100644
--- a/lib/THCUNN/ClassNLLCriterion.cu
+++ b/lib/THCUNN/ClassNLLCriterion.cu
@@ -18,7 +18,7 @@ __global__ void cunn_ClassNLLCriterion_updateOutput_kernel1(float *output,
// TODO: T4951791 Reuse code between updateOutput_kernel1 and
// updateOutput_kernel.
- int t = (int)*target - 1;
+ int t = (int)*target - TH_INDEX_BASE;
assert(t >= 0 && t < n_classes);
float cur_weight = weights ? weights[t] : 1.0f;
*output = -cur_weight * input[t];
@@ -44,7 +44,7 @@ __global__ void cunn_ClassNLLCriterion_updateOutput_kernel(float *output,
shInputs[threadIdx.x] = 0.0f;
acc_weight[threadIdx.x] = 0.0f;
for (i = threadIdx.x; i < nframe; i += NTHREADS) {
- t = target[i] - 1;
+ t = target[i] - TH_INDEX_BASE;
assert(t >= 0 && t < n_classes);
cur_weight = weights ? weights[t] : 1.0f;
shInputs[threadIdx.x] -= input[i * ndim + t] * cur_weight;
@@ -79,7 +79,7 @@ __global__ void cunn_ClassNLLCriterion_updateGradInput_kernel1(
return;
}
float norm = size_average ? (1.0f / *total_weight) : 1.0f;
- int t = (int)*target - 1;
+ int t = (int)*target - TH_INDEX_BASE;
assert(t >= 0 && t < n_classes);
gradInput[t] = -(weights ? weights[t] : 1.0f) * norm;
}
@@ -101,7 +101,7 @@ __global__ void cunn_ClassNLLCriterion_updateGradInput_kernel(
float norm = size_average ? (1.0f / *total_weight) : 1.0f;
for (i = threadIdx.x; i < nframe; i += NTHREADS) {
- t = (int)target[i] - 1;
+ t = (int)target[i] - TH_INDEX_BASE;
assert(t >= 0 && t < n_classes);
gradInput[i * ndim + t] = -(weights ? weights[t] : 1.0f) * norm;
}
diff --git a/lib/THCUNN/LookupTable.cu b/lib/THCUNN/LookupTable.cu
index 2d1fee4..749ce15 100644
--- a/lib/THCUNN/LookupTable.cu
+++ b/lib/THCUNN/LookupTable.cu
@@ -73,8 +73,8 @@ __global__ void cunn_LookupTable_accGradParametersKernelByFeature(
// warp-wide collision detector `warpHasCollision`.
const int laneId = threadIdx.x % 32;
for (int i = laneId; i < numel; i += WARP_SIZE) {
- const int weightIndex = (int) (input[i] - 1);
- if (weightIndex == paddingValue - 1) {
+ const int weightIndex = (int) (input[i] - TH_INDEX_BASE);
+ if (weightIndex == paddingValue - TH_INDEX_BASE) {
continue;
}
@@ -120,8 +120,8 @@ __global__ void cunn_LookupTable_accGradParametersKernel(
&& input[idx] != paddingValue) {
do {
const int startFeature = threadIdx.x + blockIdx.y * blockDim.x * SZ;
- const int weightRow = ((int) input[idx] - 1) * stride;
- const int gradOutputRow = ((int) indices[idx] - 1) * stride;
+ const int weightRow = ((int) input[idx] - TH_INDEX_BASE) * stride;
+ const int gradOutputRow = ((int) indices[idx] - TH_INDEX_BASE) * stride;
const float scale = count ? defaultScale / count[idx] : defaultScale;
float gradient[SZ];
@@ -326,7 +326,7 @@ void THNN_CudaLookupTable_renorm(
// numel << stride, since idx usually contains sparse row indices
for (long i = 0; i < numel; i++)
{
- long k = idx_ptr[i] - 1;
+ long k = idx_ptr[i] - TH_INDEX_BASE;
thrust::device_ptr<float> row_ptr = weight_ptr + k * stride;
float norm = thrust::transform_reduce(row_ptr, row_ptr + stride,
unary_pow, 0, binary_plus);
diff --git a/lib/THCUNN/MultiLabelMarginCriterion.cu b/lib/THCUNN/MultiLabelMarginCriterion.cu
index 97769ab..903e064 100644
--- a/lib/THCUNN/MultiLabelMarginCriterion.cu
+++ b/lib/THCUNN/MultiLabelMarginCriterion.cu
@@ -33,9 +33,9 @@ __global__ void cunn_MultiLabelMarginCriterion_updateOutput_kernel(float *output
// mark targets in istarget
if (threadIdx.x == 0) {
for (int dt = 0; dt < dim; dt++) {
- int target_idx = (int)target_k[dt];
- if (target_idx == 0) break;
- istarget_k[target_idx - 1] = 1;
+ int target_idx = (int)target_k[dt] - TH_INDEX_BASE;
+ if (target_idx < 0) break;
+ istarget_k[target_idx] = 1;
}
}
__syncthreads();
@@ -44,11 +44,11 @@ __global__ void cunn_MultiLabelMarginCriterion_updateOutput_kernel(float *output
float sum = 0;
for (int dt = 0; dt < dim; dt++) {
// next target:
- int target_idx = (int)target_k[dt];
- if (target_idx == 0) break;
+ int target_idx = (int)target_k[dt] - TH_INDEX_BASE;
+ if (target_idx < 0) break;
// current value for target
- float input_target_k = input_k[target_idx-1];
+ float input_target_k = input_k[target_idx];
// compare to all inputs (multithreaded):
for (int d = threadIdx.x; d < dim; d += blockDim.x) {
@@ -102,11 +102,11 @@ __global__ void cunn_MultiLabelMarginCriterion_updateGradInput_kernel(float *gra
// iterate over targets
for (int dt = 0; dt < dim; dt++) {
// next target:
- int target_idx = (int)target_k[dt];
- if (target_idx == 0) break;
+ int target_idx = (int)target_k[dt] - TH_INDEX_BASE;
+ if (target_idx < 0) break;
// current value for target
- float input_target_k = input_k[target_idx-1];
+ float input_target_k = input_k[target_idx];
// compare to all inputs (multithreaded):
float sum = 0;
@@ -125,7 +125,7 @@ __global__ void cunn_MultiLabelMarginCriterion_updateGradInput_kernel(float *gra
// reduce sum
float totalSum = reduceBlock(sums, blockDim.x, sum, thrust::plus<float>(), 0.0f);
if (threadIdx.x == 0) {
- gradInput_k[target_idx-1] += totalSum;
+ gradInput_k[target_idx] += totalSum;
}
__syncthreads();
}
@@ -237,4 +237,4 @@ void THNN_CudaMultiLabelMarginCriterion_updateGradInput(
THCudaTensor_free(state, istarget);
}
-#undef MULTILABELMARGIN_THREADS \ No newline at end of file
+#undef MULTILABELMARGIN_THREADS
diff --git a/lib/THCUNN/MultiMarginCriterion.cu b/lib/THCUNN/MultiMarginCriterion.cu
index a3c3e22..31caa75 100644
--- a/lib/THCUNN/MultiMarginCriterion.cu
+++ b/lib/THCUNN/MultiMarginCriterion.cu
@@ -10,7 +10,7 @@ __global__ void cunn_MultiMarginCriterion_updateOutput_kernel(float *output, flo
int k = blockIdx.x;
float *input_k = input + k*dim;
float *output_k = output + k;
- int target_k = ((int)target[k])-1;
+ int target_k = ((int)target[k]) - TH_INDEX_BASE;
float input_target_k = input_k[target_k];
int i_start = threadIdx.x;
@@ -53,7 +53,7 @@ __global__ void cunn_MultiMarginCriterion_updateGradInput_kernel(float *gradInpu
int k = blockIdx.x;
float *input_k = input + k*dim;
float *gradInput_k = gradInput + k*dim;
- int target_k = ((int)target[k])-1;
+ int target_k = ((int)target[k]) - TH_INDEX_BASE;
float input_target_k = input_k[target_k];
float g = (sizeAverage ? 1./((float)(nframe*dim)) : 1./((float)dim));
diff --git a/lib/THCUNN/SpatialAdaptiveMaxPooling.cu b/lib/THCUNN/SpatialAdaptiveMaxPooling.cu
index c615064..5dd8659 100644
--- a/lib/THCUNN/SpatialAdaptiveMaxPooling.cu
+++ b/lib/THCUNN/SpatialAdaptiveMaxPooling.cu
@@ -71,8 +71,8 @@ __global__ void adaptivemaxpool(float *input, float *output, float *indices_x, f
}
// Update output and argmax
*ptr_output = max;
- *ptr_ind_x = argmax_x + 1;
- *ptr_ind_y = argmax_y + 1;
+ *ptr_ind_x = argmax_x + TH_INDEX_BASE;
+ *ptr_ind_y = argmax_y + TH_INDEX_BASE;
}
}
}
@@ -122,8 +122,8 @@ __global__ void adaptivemaxgradinput(float *gradInput, float *gradOutput, float
float *ptr_ind_y = indices_y + yy*output_w + xx;
float z = *ptr_gradOutput;
- int argmax_x = (*ptr_ind_x)-1;
- int argmax_y = (*ptr_ind_y)-1;
+ int argmax_x = (*ptr_ind_x) - TH_INDEX_BASE;
+ int argmax_y = (*ptr_ind_y) - TH_INDEX_BASE;
ptr_gradInput[argmax_x + argmax_y*input_w] += z;
}
@@ -176,8 +176,8 @@ __global__ void atomicadaptivemaxgradinput(
float *ptr_ind_y = indices_y + yy*output_w + xx;
float z = *ptr_gradOutput;
- int argmax_x = (*ptr_ind_x)-1;
- int argmax_y = (*ptr_ind_y)-1;
+ int argmax_x = (*ptr_ind_x) - TH_INDEX_BASE;
+ int argmax_y = (*ptr_ind_y) - TH_INDEX_BASE;
// atomic add since different threads could update same variable
atomicAdd(&(ptr_gradInput[argmax_x + argmax_y*input_w]), z);
diff --git a/lib/THCUNN/SpatialClassNLLCriterion.cu b/lib/THCUNN/SpatialClassNLLCriterion.cu
index c718772..7f9e21f 100644
--- a/lib/THCUNN/SpatialClassNLLCriterion.cu
+++ b/lib/THCUNN/SpatialClassNLLCriterion.cu
@@ -32,7 +32,7 @@ __global__ void cunn_SpatialClassNLLCriterion_updateOutput_kernel(
for (i = (blockIdx.x % blocks_per_sample) * blockDim.x + threadIdx.x;
i < map_nelem;
i += step) {
- t = target[toffset + i] - 1;
+ t = target[toffset + i] - TH_INDEX_BASE;
assert(t >= 0 && t < n_classes);
cur_weight = weights ? weights[t] : 1.0f;
input_sum -= input[ioffset + i + map_nelem * t] * cur_weight;
@@ -77,7 +77,7 @@ __global__ void cunn_SpatialClassNLLCriterion_updateGradInput_kernel(
for (i = (blockIdx.x % blocks_per_sample) * blockDim.x + threadIdx.x;
i < map_nelem;
i += step) {
- t = (int)target[toffset + i] - 1;
+ t = (int)target[toffset + i] - TH_INDEX_BASE;
assert(t >= 0 && t < n_classes);
gradInput[ioffset + i + map_nelem * t] = -(weights ? weights[t] : 1.0f) * norm;
}
diff --git a/lib/THCUNN/SpatialFractionalMaxPooling.cu b/lib/THCUNN/SpatialFractionalMaxPooling.cu
index 81d12b1..289b1d6 100644
--- a/lib/THCUNN/SpatialFractionalMaxPooling.cu
+++ b/lib/THCUNN/SpatialFractionalMaxPooling.cu
@@ -68,7 +68,7 @@ __global__ void SpatialFractionalMaxPooling_updateOutput(
assert(maxIndex != -1);
// +1 for Lua index
- indices[batch][plane][outputH][outputW] = maxIndex + 1;
+ indices[batch][plane][outputH][outputW] = maxIndex + TH_INDEX_BASE;
output[batch][plane][outputH][outputW] = maxVal;
}
}
@@ -177,7 +177,7 @@ __global__ void SpatialFractionalMaxPooling_updateGradInput(
int outputW = ourOutputPoint % gradOutput.getSize(3);
int outputH = ourOutputPoint / gradOutput.getSize(3);
- int index = indices[batch][plane][outputH][outputW] - 1;
+ int index = indices[batch][plane][outputH][outputW] - TH_INDEX_BASE;
assert(index >= 0);
int inputW = index % gradInput.getSize(3);
int inputH = index / gradInput.getSize(3);
diff --git a/lib/THCUNN/SpatialMaxPooling.cu b/lib/THCUNN/SpatialMaxPooling.cu
index 824823f..fe27f50 100644
--- a/lib/THCUNN/SpatialMaxPooling.cu
+++ b/lib/THCUNN/SpatialMaxPooling.cu
@@ -32,7 +32,7 @@ __global__ void MaxPoolForward(const int nthreads, const Dtype* bottom_data,
}
}
top_data[index] = maxval;
- top_mask[index] = maxidx + 1;
+ top_mask[index] = maxidx + TH_INDEX_BASE;
}
}
@@ -63,7 +63,7 @@ __global__ void MaxPoolBackward(const int nthreads, const Dtype* top_diff,
top_mask += offset;
for (int ph = phstart; ph < phend; ++ph) {
for (int pw = pwstart; pw < pwend; ++pw) {
- if (top_mask[ph * pooled_width + pw] - 1 == h * width + w) {
+ if (top_mask[ph * pooled_width + pw] - TH_INDEX_BASE == h * width + w) {
gradient += top_diff[ph * pooled_width + pw];
}
}
diff --git a/lib/THCUNN/SpatialMaxUnpooling.cu b/lib/THCUNN/SpatialMaxUnpooling.cu
index bd2c3af..b56bd56 100644
--- a/lib/THCUNN/SpatialMaxUnpooling.cu
+++ b/lib/THCUNN/SpatialMaxUnpooling.cu
@@ -8,7 +8,7 @@ __global__ void MaxUnpoolForward(const int nthreads, const Dtype* bottom_data, c
int c = (index / iwidth / iheight) % channels;
int n = index / iwidth / iheight / channels;
top_data += (n*channels + c)*oheight*owidth;
- int maxind = bottom_mask[index]-1;
+ int maxind = bottom_mask[index] - TH_INDEX_BASE;
top_data[maxind] = bottom_data[index];
}
@@ -21,7 +21,7 @@ __global__ void MaxUnpoolBackward(const int nthreads, const Dtype* top_diff, con
int c = (index / iwidth / iheight) % channels;
int n = index / iwidth / iheight / channels;
top_diff += (n*channels + c)*oheight*owidth;
- int maxind = bottom_mask[index]-1;
+ int maxind = bottom_mask[index] - TH_INDEX_BASE;
bottom_diff[index] = top_diff[maxind];
}