diff options
author | Soumith Chintala <soumith@gmail.com> | 2017-05-21 20:48:19 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2017-05-21 20:48:19 +0300 |
commit | 78aac1a015ebba0655a7fdad8a4a09419b68da67 (patch) | |
tree | e534fc3f1f1192102fd4b5b25974abe6d4d7f9f2 /lib | |
parent | 482537275df7fde77cc4dcc1d93de33cbfafde9f (diff) |
Revert "Revert "ClassNLLCriterion supports missing targets""revert-1217-revert-1215-ClassNLLCriterion-missing-target
Diffstat (limited to 'lib')
-rw-r--r-- | lib/THNN/generic/ClassNLLCriterion.c | 48 | ||||
-rw-r--r-- | lib/THNN/generic/THNN.h | 6 |
2 files changed, 34 insertions, 20 deletions
diff --git a/lib/THNN/generic/ClassNLLCriterion.c b/lib/THNN/generic/ClassNLLCriterion.c index 0db3a8a..4cf37ae 100644 --- a/lib/THNN/generic/ClassNLLCriterion.c +++ b/lib/THNN/generic/ClassNLLCriterion.c @@ -9,12 +9,14 @@ void THNN_(ClassNLLCriterion_updateOutput)( THTensor *output, bool sizeAverage, THTensor *weights, - THTensor *total_weight) + THTensor *total_weight, + long ignore_index) { THNN_CHECK_DIM_SIZE(output, 1, 0, 1); THNN_CHECK_DIM_SIZE(total_weight, 1, 0, 1); int n_dims = THTensor_(nDimension)(input); int n_classes = THTensor_(size)(input, n_dims - 1); + ignore_index -= TH_INDEX_BASE; if (THIndexTensor_(nDimension)(target) > 1) { THError("multi-target not supported"); @@ -42,9 +44,11 @@ void THNN_(ClassNLLCriterion_updateOutput)( if (THTensor_(nDimension)(input) == 1) { int cur_target = target_data[0] - TH_INDEX_BASE; - THAssert(cur_target >= 0 && cur_target < n_classes); - total_weight_data[0] = weights ? weights_data[cur_target] : 1.0f; - output_data[0] = -input_data[cur_target] * total_weight_data[0]; + if (cur_target != ignore_index) { + THAssert(cur_target >= 0 && cur_target < n_classes); + total_weight_data[0] = weights ? weights_data[cur_target] : 1.0f; + output_data[0] = -input_data[cur_target] * total_weight_data[0]; + } } else if (THTensor_(nDimension)(input) == 2) { int batch_size = THTensor_(size)(input, 0); THAssert(THIndexTensor_(size)(target, 0) == batch_size); @@ -54,11 +58,13 @@ void THNN_(ClassNLLCriterion_updateOutput)( int i; for (i = 0; i < batch_size; i++) { int cur_target = target_data[i] - TH_INDEX_BASE; - THAssert(cur_target >= 0 && cur_target < n_classes); + if (cur_target != ignore_index) { + THAssert(cur_target >= 0 && cur_target < n_classes); - real cur_weight = weights ? weights_data[cur_target] : 1.0f; - total_weight_data[0] += cur_weight; - output_data[0] -= input_data[i * n_target + cur_target] * cur_weight; + real cur_weight = weights ? weights_data[cur_target] : 1.0f; + total_weight_data[0] += cur_weight; + output_data[0] -= input_data[i * n_target + cur_target] * cur_weight; + } } } @@ -80,10 +86,12 @@ void THNN_(ClassNLLCriterion_updateGradInput)( THTensor *gradInput, bool sizeAverage, THTensor *weights, - THTensor *total_weight) + THTensor *total_weight, + long ignore_index) { int n_dims = THTensor_(nDimension)(input); int n_classes = THTensor_(size)(input, n_dims - 1); + ignore_index -= TH_INDEX_BASE; if (!THTensor_(isContiguous)(gradInput)) { THError("gradInput must be contiguous"); @@ -102,7 +110,7 @@ void THNN_(ClassNLLCriterion_updateGradInput)( if (THTensor_(nDimension)(input) > 2) { THError("input tensor should be 1D or 2D"); } - + if (weights && THTensor_(nElement)(weights) != n_classes) { THError("weight tensor should be defined either for all or no classes"); } @@ -116,10 +124,12 @@ void THNN_(ClassNLLCriterion_updateGradInput)( if (THTensor_(nDimension)(input) == 1) { int cur_target = target_data[0] - TH_INDEX_BASE; - THAssert(cur_target >= 0 && cur_target < n_classes); + if (cur_target != ignore_index) { + THAssert(cur_target >= 0 && cur_target < n_classes); - gradInput_data[cur_target] = - (!sizeAverage && weights) ? -weights_data[cur_target] : -1; + gradInput_data[cur_target] = + (!sizeAverage && weights) ? -weights_data[cur_target] : -1; + } } else if (THTensor_(nDimension)(input) == 2) { int batch_size = THTensor_(size)(input, 0); @@ -131,13 +141,15 @@ void THNN_(ClassNLLCriterion_updateGradInput)( for (i = 0; i < batch_size; i++){ int cur_target = target_data[i] - TH_INDEX_BASE; - THAssert(cur_target >= 0 && cur_target < n_classes); + if (cur_target != ignore_index) { + THAssert(cur_target >= 0 && cur_target < n_classes); - gradInput_data[i * n_target + cur_target] = - -(weights ? weights_data[cur_target] : 1.0f); + gradInput_data[i * n_target + cur_target] = + -(weights ? weights_data[cur_target] : 1.0f); - if (sizeAverage && *total_weight_data) { - gradInput_data[i * n_target + cur_target] /= *total_weight_data; + if (sizeAverage && *total_weight_data) { + gradInput_data[i * n_target + cur_target] /= *total_weight_data; + } } } } diff --git a/lib/THNN/generic/THNN.h b/lib/THNN/generic/THNN.h index fcc7f51..b9fd709 100644 --- a/lib/THNN/generic/THNN.h +++ b/lib/THNN/generic/THNN.h @@ -47,7 +47,8 @@ TH_API void THNN_(ClassNLLCriterion_updateOutput)( THTensor *output, // [OUT] a one-element tensor with loss bool sizeAverage, // if true, the loss will be normalized by batch size and class weights THTensor *weights, // [OPTIONAL] class weights - THTensor *total_weight); // [BUFFER] + THTensor *total_weight, // [BUFFER] + long ignore_index); // target index to ignore (loss = 0, gradInput = 0) TH_API void THNN_(ClassNLLCriterion_updateGradInput)( THNNState *state, // library's state THTensor *input, // input tensor (1D/2D) @@ -55,7 +56,8 @@ TH_API void THNN_(ClassNLLCriterion_updateGradInput)( THTensor *gradInput, // [OUT] gradient w.r.t. input bool sizeAverage, // if true, the loss will be normalized by batch size and class weights THTensor *weights, // [OPTIONAL] class weights - THTensor *total_weight); // [BUFFER] + THTensor *total_weight, // [BUFFER] + long ignore_index); // target index to ignore (loss = 0, gradInput = 0) TH_API void THNN_(SpatialClassNLLCriterion_updateOutput)( THNNState *state, // library's state |