diff options
author | Soumith Chintala <soumith@gmail.com> | 2017-05-15 04:15:30 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2017-05-15 04:15:30 +0300 |
commit | a5ae72397c1c3f07483df66d87fc1c4814083100 (patch) | |
tree | 5fd5b7f39dc291be87f55350b6b5ba8dfe8382d2 | |
parent | e97095d33184f97b5c82e606031071fd107b298e (diff) |
Revert "Update to ignore zero targets"
-rw-r--r-- | lib/THCUNN/ClassNLLCriterion.cu | 36 |
1 files changed, 14 insertions, 22 deletions
diff --git a/lib/THCUNN/ClassNLLCriterion.cu b/lib/THCUNN/ClassNLLCriterion.cu index 9a278d8..58684f4 100644 --- a/lib/THCUNN/ClassNLLCriterion.cu +++ b/lib/THCUNN/ClassNLLCriterion.cu @@ -22,14 +22,12 @@ __global__ void cunn_ClassNLLCriterion_updateOutput_kernel1(Dtype *output, // updateOutput_kernel. int t = (int)*target - TH_INDEX_BASE; - assert(t >= -1 && t < n_classes); - if (t >= 0) { - Dtype cur_weight = weights ? weights[t] : ScalarConvert<int, Dtype>::to(1); - *output = -cur_weight * input[t]; - *total_weight = cur_weight; - if (size_average && *total_weight > 0) { - *output /= *total_weight; - } + assert(t >= 0 && t < n_classes); + Dtype cur_weight = weights ? weights[t] : ScalarConvert<int, Dtype>::to(1); + *output = -cur_weight * input[t]; + *total_weight = cur_weight; + if (size_average && *total_weight > 0) { + *output /= *total_weight; } } @@ -51,12 +49,10 @@ __global__ void cunn_ClassNLLCriterion_updateOutput_kernel(Dtype *output, acc_weight[threadIdx.x] = ScalarConvert<int, Acctype>::to(0); for (i = threadIdx.x; i < nframe; i += NTHREADS) { t = target[i] - TH_INDEX_BASE; - assert(t >= -1 && t < n_classes); - if (t >= 0) { - cur_weight = weights ? weights[t] : ScalarConvert<int, Dtype>::to(1); - shInputs[threadIdx.x] -= input[i * ndim + t] * cur_weight; - acc_weight[threadIdx.x] += cur_weight; - } + assert(t >= 0 && t < n_classes); + cur_weight = weights ? weights[t] : ScalarConvert<int, Dtype>::to(1); + shInputs[threadIdx.x] -= input[i * ndim + t] * cur_weight; + acc_weight[threadIdx.x] += cur_weight; } __syncthreads(); @@ -95,10 +91,8 @@ __global__ void cunn_ClassNLLCriterion_updateGradInput_kernel1( } Dtype norm = size_average ? (ScalarConvert<int, Dtype>::to(1) / *total_weight) : ScalarConvert<int, Dtype>::to(1); int t = (int)*target - TH_INDEX_BASE; - assert(t >= -1 && t < n_classes); - if (t >= 0) { - gradInput[t] = -(weights ? weights[t] : ScalarConvert<int, Dtype>::to(1)) * norm; - } + assert(t >= 0 && t < n_classes); + gradInput[t] = -(weights ? weights[t] : ScalarConvert<int, Dtype>::to(1)) * norm; } template <typename Dtype> @@ -120,10 +114,8 @@ __global__ void cunn_ClassNLLCriterion_updateGradInput_kernel( for (i = threadIdx.x; i < nframe; i += NTHREADS) { t = (int)target[i] - TH_INDEX_BASE; - assert(t >= -1 && t < n_classes); - if (t >= 0) { - gradInput[i * ndim + t] = -(weights ? weights[t] : ScalarConvert<int, Dtype>::to(1)) * norm; - } + assert(t >= 0 && t < n_classes); + gradInput[i * ndim + t] = -(weights ? weights[t] : ScalarConvert<int, Dtype>::to(1)) * norm; } } |