diff options
author | Nicholas Léonard <nick@nikopia.org> | 2017-05-12 17:38:18 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2017-05-12 17:38:18 +0300 |
commit | 3845f851d4d1ffc8b50ebb01c92c7f6c53ab21b4 (patch) | |
tree | 5bdaf8e240abe2b755c1c8cb1ff561dac338bfb0 | |
parent | 8252eb7450c6e632be7a4f37f217995059834837 (diff) | |
parent | 5ea11052bebef7a33b024ae1972d8a51be847938 (diff) |
Merge pull request #458 from jnhwkim/master
Update to ignore zero targets
-rw-r--r-- | lib/THCUNN/ClassNLLCriterion.cu | 36 |
1 files changed, 22 insertions, 14 deletions
diff --git a/lib/THCUNN/ClassNLLCriterion.cu b/lib/THCUNN/ClassNLLCriterion.cu index 58684f4..9a278d8 100644 --- a/lib/THCUNN/ClassNLLCriterion.cu +++ b/lib/THCUNN/ClassNLLCriterion.cu @@ -22,12 +22,14 @@ __global__ void cunn_ClassNLLCriterion_updateOutput_kernel1(Dtype *output, // updateOutput_kernel. int t = (int)*target - TH_INDEX_BASE; - assert(t >= 0 && t < n_classes); - Dtype cur_weight = weights ? weights[t] : ScalarConvert<int, Dtype>::to(1); - *output = -cur_weight * input[t]; - *total_weight = cur_weight; - if (size_average && *total_weight > 0) { - *output /= *total_weight; + assert(t >= -1 && t < n_classes); + if (t >= 0) { + Dtype cur_weight = weights ? weights[t] : ScalarConvert<int, Dtype>::to(1); + *output = -cur_weight * input[t]; + *total_weight = cur_weight; + if (size_average && *total_weight > 0) { + *output /= *total_weight; + } } } @@ -49,10 +51,12 @@ __global__ void cunn_ClassNLLCriterion_updateOutput_kernel(Dtype *output, acc_weight[threadIdx.x] = ScalarConvert<int, Acctype>::to(0); for (i = threadIdx.x; i < nframe; i += NTHREADS) { t = target[i] - TH_INDEX_BASE; - assert(t >= 0 && t < n_classes); - cur_weight = weights ? weights[t] : ScalarConvert<int, Dtype>::to(1); - shInputs[threadIdx.x] -= input[i * ndim + t] * cur_weight; - acc_weight[threadIdx.x] += cur_weight; + assert(t >= -1 && t < n_classes); + if (t >= 0) { + cur_weight = weights ? weights[t] : ScalarConvert<int, Dtype>::to(1); + shInputs[threadIdx.x] -= input[i * ndim + t] * cur_weight; + acc_weight[threadIdx.x] += cur_weight; + } } __syncthreads(); @@ -91,8 +95,10 @@ __global__ void cunn_ClassNLLCriterion_updateGradInput_kernel1( } Dtype norm = size_average ? (ScalarConvert<int, Dtype>::to(1) / *total_weight) : ScalarConvert<int, Dtype>::to(1); int t = (int)*target - TH_INDEX_BASE; - assert(t >= 0 && t < n_classes); - gradInput[t] = -(weights ? 
weights[t] : ScalarConvert<int, Dtype>::to(1)) * norm; + assert(t >= -1 && t < n_classes); + if (t >= 0) { + gradInput[t] = -(weights ? weights[t] : ScalarConvert<int, Dtype>::to(1)) * norm; + } } template <typename Dtype> @@ -114,8 +120,10 @@ __global__ void cunn_ClassNLLCriterion_updateGradInput_kernel( for (i = threadIdx.x; i < nframe; i += NTHREADS) { t = (int)target[i] - TH_INDEX_BASE; - assert(t >= 0 && t < n_classes); - gradInput[i * ndim + t] = -(weights ? weights[t] : ScalarConvert<int, Dtype>::to(1)) * norm; + assert(t >= -1 && t < n_classes); + if (t >= 0) { + gradInput[i * ndim + t] = -(weights ? weights[t] : ScalarConvert<int, Dtype>::to(1)) * norm; + } } } |