diff options
Diffstat (limited to 'lib/THCUNN/ClassNLLCriterion.cu')
-rw-r--r-- | lib/THCUNN/ClassNLLCriterion.cu | 48 |
1 files changed, 30 insertions, 18 deletions
diff --git a/lib/THCUNN/ClassNLLCriterion.cu b/lib/THCUNN/ClassNLLCriterion.cu index 58684f4..194d64c 100644 --- a/lib/THCUNN/ClassNLLCriterion.cu +++ b/lib/THCUNN/ClassNLLCriterion.cu @@ -15,19 +15,22 @@ __global__ void cunn_ClassNLLCriterion_updateOutput_kernel1(Dtype *output, THCIndex_t *target, Dtype *weights, int size_average, - int n_classes) { + int n_classes, + long ignore_index) { assert(threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0); // TODO: T4951791 Reuse code between updateOutput_kernel1 and // updateOutput_kernel. int t = (int)*target - TH_INDEX_BASE; - assert(t >= 0 && t < n_classes); - Dtype cur_weight = weights ? weights[t] : ScalarConvert<int, Dtype>::to(1); - *output = -cur_weight * input[t]; - *total_weight = cur_weight; - if (size_average && *total_weight > 0) { - *output /= *total_weight; + if (t != ignore_index) { + assert(t >= 0 && t < n_classes); + Dtype cur_weight = weights ? weights[t] : ScalarConvert<int, Dtype>::to(1); + *output = -cur_weight * input[t]; + *total_weight = cur_weight; + if (size_average && *total_weight > 0) { + *output /= *total_weight; + } } } @@ -40,7 +43,8 @@ __global__ void cunn_ClassNLLCriterion_updateOutput_kernel(Dtype *output, int size_average, int nframe, int ndim, - int n_classes) { + int n_classes, + long ignore_index) { __shared__ Acctype shInputs[NTHREADS], acc_weight[NTHREADS]; int i, t; Dtype cur_weight; @@ -49,10 +53,12 @@ __global__ void cunn_ClassNLLCriterion_updateOutput_kernel(Dtype *output, acc_weight[threadIdx.x] = ScalarConvert<int, Acctype>::to(0); for (i = threadIdx.x; i < nframe; i += NTHREADS) { t = target[i] - TH_INDEX_BASE; - assert(t >= 0 && t < n_classes); - cur_weight = weights ? weights[t] : ScalarConvert<int, Dtype>::to(1); - shInputs[threadIdx.x] -= input[i * ndim + t] * cur_weight; - acc_weight[threadIdx.x] += cur_weight; + if (t != ignore_index) { + assert(t >= 0 && t < n_classes); + cur_weight = weights ? weights[t] : ScalarConvert<int, Dtype>::to(1); + shInputs[threadIdx.x] -= input[i * ndim + t] * cur_weight; + acc_weight[threadIdx.x] += cur_weight; + } } __syncthreads(); @@ -84,15 +90,18 @@ __global__ void cunn_ClassNLLCriterion_updateGradInput_kernel1( THCIndex_t* target, Dtype* total_weight, int size_average, - int n_classes) + int n_classes, + long ignore_index) { if (*total_weight <= 0) { return; } Dtype norm = size_average ? (ScalarConvert<int, Dtype>::to(1) / *total_weight) : ScalarConvert<int, Dtype>::to(1); int t = (int)*target - TH_INDEX_BASE; - assert(t >= 0 && t < n_classes); - gradInput[t] = -(weights ? weights[t] : ScalarConvert<int, Dtype>::to(1)) * norm; + if (t != ignore_index) { + assert(t >= 0 && t < n_classes); + gradInput[t] = -(weights ? weights[t] : ScalarConvert<int, Dtype>::to(1)) * norm; + } } template <typename Dtype> @@ -104,7 +113,8 @@ __global__ void cunn_ClassNLLCriterion_updateGradInput_kernel( int size_average, int nframe, int ndim, - int n_classes) + int n_classes, + long ignore_index) { if (*total_weight <= 0) { return; @@ -114,8 +124,10 @@ __global__ void cunn_ClassNLLCriterion_updateGradInput_kernel( for (i = threadIdx.x; i < nframe; i += NTHREADS) { t = (int)target[i] - TH_INDEX_BASE; - assert(t >= 0 && t < n_classes); - gradInput[i * ndim + t] = -(weights ? weights[t] : ScalarConvert<int, Dtype>::to(1)) * norm; + if (t != ignore_index) { + assert(t >= 0 && t < n_classes); + gradInput[i * ndim + t] = -(weights ? weights[t] : ScalarConvert<int, Dtype>::to(1)) * norm; + } } } |