Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/cunn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'lib/THCUNN/ClassNLLCriterion.cu')
-rw-r--r--lib/THCUNN/ClassNLLCriterion.cu48
1 files changed, 30 insertions, 18 deletions
diff --git a/lib/THCUNN/ClassNLLCriterion.cu b/lib/THCUNN/ClassNLLCriterion.cu
index 58684f4..194d64c 100644
--- a/lib/THCUNN/ClassNLLCriterion.cu
+++ b/lib/THCUNN/ClassNLLCriterion.cu
@@ -15,19 +15,22 @@ __global__ void cunn_ClassNLLCriterion_updateOutput_kernel1(Dtype *output,
THCIndex_t *target,
Dtype *weights,
int size_average,
- int n_classes) {
+ int n_classes,
+ long ignore_index) {
assert(threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0);
// TODO: T4951791 Reuse code between updateOutput_kernel1 and
// updateOutput_kernel.
int t = (int)*target - TH_INDEX_BASE;
- assert(t >= 0 && t < n_classes);
- Dtype cur_weight = weights ? weights[t] : ScalarConvert<int, Dtype>::to(1);
- *output = -cur_weight * input[t];
- *total_weight = cur_weight;
- if (size_average && *total_weight > 0) {
- *output /= *total_weight;
+ if (t != ignore_index) {
+ assert(t >= 0 && t < n_classes);
+ Dtype cur_weight = weights ? weights[t] : ScalarConvert<int, Dtype>::to(1);
+ *output = -cur_weight * input[t];
+ *total_weight = cur_weight;
+ if (size_average && *total_weight > 0) {
+ *output /= *total_weight;
+ }
}
}
@@ -40,7 +43,8 @@ __global__ void cunn_ClassNLLCriterion_updateOutput_kernel(Dtype *output,
int size_average,
int nframe,
int ndim,
- int n_classes) {
+ int n_classes,
+ long ignore_index) {
__shared__ Acctype shInputs[NTHREADS], acc_weight[NTHREADS];
int i, t;
Dtype cur_weight;
@@ -49,10 +53,12 @@ __global__ void cunn_ClassNLLCriterion_updateOutput_kernel(Dtype *output,
acc_weight[threadIdx.x] = ScalarConvert<int, Acctype>::to(0);
for (i = threadIdx.x; i < nframe; i += NTHREADS) {
t = target[i] - TH_INDEX_BASE;
- assert(t >= 0 && t < n_classes);
- cur_weight = weights ? weights[t] : ScalarConvert<int, Dtype>::to(1);
- shInputs[threadIdx.x] -= input[i * ndim + t] * cur_weight;
- acc_weight[threadIdx.x] += cur_weight;
+ if (t != ignore_index) {
+ assert(t >= 0 && t < n_classes);
+ cur_weight = weights ? weights[t] : ScalarConvert<int, Dtype>::to(1);
+ shInputs[threadIdx.x] -= input[i * ndim + t] * cur_weight;
+ acc_weight[threadIdx.x] += cur_weight;
+ }
}
__syncthreads();
@@ -84,15 +90,18 @@ __global__ void cunn_ClassNLLCriterion_updateGradInput_kernel1(
THCIndex_t* target,
Dtype* total_weight,
int size_average,
- int n_classes)
+ int n_classes,
+ long ignore_index)
{
if (*total_weight <= 0) {
return;
}
Dtype norm = size_average ? (ScalarConvert<int, Dtype>::to(1) / *total_weight) : ScalarConvert<int, Dtype>::to(1);
int t = (int)*target - TH_INDEX_BASE;
- assert(t >= 0 && t < n_classes);
- gradInput[t] = -(weights ? weights[t] : ScalarConvert<int, Dtype>::to(1)) * norm;
+ if (t != ignore_index) {
+ assert(t >= 0 && t < n_classes);
+ gradInput[t] = -(weights ? weights[t] : ScalarConvert<int, Dtype>::to(1)) * norm;
+ }
}
template <typename Dtype>
@@ -104,7 +113,8 @@ __global__ void cunn_ClassNLLCriterion_updateGradInput_kernel(
int size_average,
int nframe,
int ndim,
- int n_classes)
+ int n_classes,
+ long ignore_index)
{
if (*total_weight <= 0) {
return;
@@ -114,8 +124,10 @@ __global__ void cunn_ClassNLLCriterion_updateGradInput_kernel(
for (i = threadIdx.x; i < nframe; i += NTHREADS) {
t = (int)target[i] - TH_INDEX_BASE;
- assert(t >= 0 && t < n_classes);
- gradInput[i * ndim + t] = -(weights ? weights[t] : ScalarConvert<int, Dtype>::to(1)) * norm;
+ if (t != ignore_index) {
+ assert(t >= 0 && t < n_classes);
+ gradInput[i * ndim + t] = -(weights ? weights[t] : ScalarConvert<int, Dtype>::to(1)) * norm;
+ }
}
}