Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/cunn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNicholas LĂ©onard <nick@nikopia.org>2017-05-12 17:38:18 +0300
committerGitHub <noreply@github.com>2017-05-12 17:38:18 +0300
commit3845f851d4d1ffc8b50ebb01c92c7f6c53ab21b4 (patch)
tree5bdaf8e240abe2b755c1c8cb1ff561dac338bfb0
parent8252eb7450c6e632be7a4f37f217995059834837 (diff)
parent5ea11052bebef7a33b024ae1972d8a51be847938 (diff)
Merge pull request #458 from jnhwkim/master
Update to ignore zero targets
-rw-r--r--lib/THCUNN/ClassNLLCriterion.cu36
1 files changed, 22 insertions, 14 deletions
diff --git a/lib/THCUNN/ClassNLLCriterion.cu b/lib/THCUNN/ClassNLLCriterion.cu
index 58684f4..9a278d8 100644
--- a/lib/THCUNN/ClassNLLCriterion.cu
+++ b/lib/THCUNN/ClassNLLCriterion.cu
@@ -22,12 +22,14 @@ __global__ void cunn_ClassNLLCriterion_updateOutput_kernel1(Dtype *output,
// updateOutput_kernel.
int t = (int)*target - TH_INDEX_BASE;
- assert(t >= 0 && t < n_classes);
- Dtype cur_weight = weights ? weights[t] : ScalarConvert<int, Dtype>::to(1);
- *output = -cur_weight * input[t];
- *total_weight = cur_weight;
- if (size_average && *total_weight > 0) {
- *output /= *total_weight;
+ assert(t >= -1 && t < n_classes);
+ if (t >= 0) {
+ Dtype cur_weight = weights ? weights[t] : ScalarConvert<int, Dtype>::to(1);
+ *output = -cur_weight * input[t];
+ *total_weight = cur_weight;
+ if (size_average && *total_weight > 0) {
+ *output /= *total_weight;
+ }
}
}
@@ -49,10 +51,12 @@ __global__ void cunn_ClassNLLCriterion_updateOutput_kernel(Dtype *output,
acc_weight[threadIdx.x] = ScalarConvert<int, Acctype>::to(0);
for (i = threadIdx.x; i < nframe; i += NTHREADS) {
t = target[i] - TH_INDEX_BASE;
- assert(t >= 0 && t < n_classes);
- cur_weight = weights ? weights[t] : ScalarConvert<int, Dtype>::to(1);
- shInputs[threadIdx.x] -= input[i * ndim + t] * cur_weight;
- acc_weight[threadIdx.x] += cur_weight;
+ assert(t >= -1 && t < n_classes);
+ if (t >= 0) {
+ cur_weight = weights ? weights[t] : ScalarConvert<int, Dtype>::to(1);
+ shInputs[threadIdx.x] -= input[i * ndim + t] * cur_weight;
+ acc_weight[threadIdx.x] += cur_weight;
+ }
}
__syncthreads();
@@ -91,8 +95,10 @@ __global__ void cunn_ClassNLLCriterion_updateGradInput_kernel1(
}
Dtype norm = size_average ? (ScalarConvert<int, Dtype>::to(1) / *total_weight) : ScalarConvert<int, Dtype>::to(1);
int t = (int)*target - TH_INDEX_BASE;
- assert(t >= 0 && t < n_classes);
- gradInput[t] = -(weights ? weights[t] : ScalarConvert<int, Dtype>::to(1)) * norm;
+ assert(t >= -1 && t < n_classes);
+ if (t >= 0) {
+ gradInput[t] = -(weights ? weights[t] : ScalarConvert<int, Dtype>::to(1)) * norm;
+ }
}
template <typename Dtype>
@@ -114,8 +120,10 @@ __global__ void cunn_ClassNLLCriterion_updateGradInput_kernel(
for (i = threadIdx.x; i < nframe; i += NTHREADS) {
t = (int)target[i] - TH_INDEX_BASE;
- assert(t >= 0 && t < n_classes);
- gradInput[i * ndim + t] = -(weights ? weights[t] : ScalarConvert<int, Dtype>::to(1)) * norm;
+ assert(t >= -1 && t < n_classes);
+ if (t >= 0) {
+ gradInput[i * ndim + t] = -(weights ? weights[t] : ScalarConvert<int, Dtype>::to(1)) * norm;
+ }
}
}