Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/cunn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSoumith Chintala <soumith@gmail.com>2017-05-15 04:15:30 +0300
committerGitHub <noreply@github.com>2017-05-15 04:15:30 +0300
commita5ae72397c1c3f07483df66d87fc1c4814083100 (patch)
tree5fd5b7f39dc291be87f55350b6b5ba8dfe8382d2
parente97095d33184f97b5c82e606031071fd107b298e (diff)
Revert "Update to ignore zero targets"
-rw-r--r--lib/THCUNN/ClassNLLCriterion.cu36
1 files changed, 14 insertions, 22 deletions
diff --git a/lib/THCUNN/ClassNLLCriterion.cu b/lib/THCUNN/ClassNLLCriterion.cu
index 9a278d8..58684f4 100644
--- a/lib/THCUNN/ClassNLLCriterion.cu
+++ b/lib/THCUNN/ClassNLLCriterion.cu
@@ -22,14 +22,12 @@ __global__ void cunn_ClassNLLCriterion_updateOutput_kernel1(Dtype *output,
// updateOutput_kernel.
int t = (int)*target - TH_INDEX_BASE;
- assert(t >= -1 && t < n_classes);
- if (t >= 0) {
- Dtype cur_weight = weights ? weights[t] : ScalarConvert<int, Dtype>::to(1);
- *output = -cur_weight * input[t];
- *total_weight = cur_weight;
- if (size_average && *total_weight > 0) {
- *output /= *total_weight;
- }
+ assert(t >= 0 && t < n_classes);
+ Dtype cur_weight = weights ? weights[t] : ScalarConvert<int, Dtype>::to(1);
+ *output = -cur_weight * input[t];
+ *total_weight = cur_weight;
+ if (size_average && *total_weight > 0) {
+ *output /= *total_weight;
}
}
@@ -51,12 +49,10 @@ __global__ void cunn_ClassNLLCriterion_updateOutput_kernel(Dtype *output,
acc_weight[threadIdx.x] = ScalarConvert<int, Acctype>::to(0);
for (i = threadIdx.x; i < nframe; i += NTHREADS) {
t = target[i] - TH_INDEX_BASE;
- assert(t >= -1 && t < n_classes);
- if (t >= 0) {
- cur_weight = weights ? weights[t] : ScalarConvert<int, Dtype>::to(1);
- shInputs[threadIdx.x] -= input[i * ndim + t] * cur_weight;
- acc_weight[threadIdx.x] += cur_weight;
- }
+ assert(t >= 0 && t < n_classes);
+ cur_weight = weights ? weights[t] : ScalarConvert<int, Dtype>::to(1);
+ shInputs[threadIdx.x] -= input[i * ndim + t] * cur_weight;
+ acc_weight[threadIdx.x] += cur_weight;
}
__syncthreads();
@@ -95,10 +91,8 @@ __global__ void cunn_ClassNLLCriterion_updateGradInput_kernel1(
}
Dtype norm = size_average ? (ScalarConvert<int, Dtype>::to(1) / *total_weight) : ScalarConvert<int, Dtype>::to(1);
int t = (int)*target - TH_INDEX_BASE;
- assert(t >= -1 && t < n_classes);
- if (t >= 0) {
- gradInput[t] = -(weights ? weights[t] : ScalarConvert<int, Dtype>::to(1)) * norm;
- }
+ assert(t >= 0 && t < n_classes);
+ gradInput[t] = -(weights ? weights[t] : ScalarConvert<int, Dtype>::to(1)) * norm;
}
template <typename Dtype>
@@ -120,10 +114,8 @@ __global__ void cunn_ClassNLLCriterion_updateGradInput_kernel(
for (i = threadIdx.x; i < nframe; i += NTHREADS) {
t = (int)target[i] - TH_INDEX_BASE;
- assert(t >= -1 && t < n_classes);
- if (t >= 0) {
- gradInput[i * ndim + t] = -(weights ? weights[t] : ScalarConvert<int, Dtype>::to(1)) * norm;
- }
+ assert(t >= 0 && t < n_classes);
+ gradInput[i * ndim + t] = -(weights ? weights[t] : ScalarConvert<int, Dtype>::to(1)) * norm;
}
}