github.com/torch/nn.git

author    Soumith Chintala <soumith@gmail.com>  2017-05-21 20:48:19 +0300
committer GitHub <noreply@github.com>           2017-05-21 20:48:19 +0300
commit    78aac1a015ebba0655a7fdad8a4a09419b68da67 (patch)
tree      e534fc3f1f1192102fd4b5b25974abe6d4d7f9f2 /lib
parent    482537275df7fde77cc4dcc1d93de33cbfafde9f (diff)

Revert "Revert "ClassNLLCriterion supports missing targets""
(branch: revert-1217-revert-1215-ClassNLLCriterion-missing-target)
Diffstat (limited to 'lib')
-rw-r--r--  lib/THNN/generic/ClassNLLCriterion.c  48
-rw-r--r--  lib/THNN/generic/THNN.h                6
2 files changed, 34 insertions, 20 deletions
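
This commit re-applies support for missing targets: any sample whose target index equals ignore_index contributes nothing to the loss, nothing to total_weight, and receives zero gradient. The standalone C sketch below shows the intended semantics for the batched case; it is illustrative only, not the THNN code, and the helper name nll_loss and the -100 sentinel are made up for the example.

/* Minimal standalone sketch (not THNN code) of the ignore_index semantics:
 * targets equal to ignore_index are skipped entirely. */
#include <stdio.h>

/* log_probs: batch_size x n_classes row-major log-probabilities
 * targets:   0-based class indices; entries equal to ignore_index are skipped
 * weights:   optional per-class weights (NULL means weight 1 for every class) */
static float nll_loss(const float *log_probs, const long *targets,
                      const float *weights, int batch_size, int n_classes,
                      long ignore_index, int size_average)
{
  float total_weight = 0.0f, loss = 0.0f;
  for (int i = 0; i < batch_size; i++) {
    long t = targets[i];
    if (t == ignore_index)          /* ignored sample: no loss, no weight */
      continue;
    float w = weights ? weights[t] : 1.0f;
    total_weight += w;
    loss -= log_probs[i * n_classes + t] * w;
  }
  if (size_average && total_weight > 0.0f)
    loss /= total_weight;           /* normalize by the summed weights */
  return loss;
}

int main(void)
{
  const float log_probs[2 * 3] = {
    -0.1f, -2.5f, -3.0f,   /* sample 0 */
    -1.2f, -0.4f, -2.1f,   /* sample 1 */
  };
  const long targets[2] = {0, -100};  /* second target is ignored */
  printf("loss = %f\n", nll_loss(log_probs, targets, NULL, 2, 3, -100, 1));
  return 0;
}
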
diff --git a/lib/THNN/generic/ClassNLLCriterion.c b/lib/THNN/generic/ClassNLLCriterion.c
index 0db3a8a..4cf37ae 100644
--- a/lib/THNN/generic/ClassNLLCriterion.c
+++ b/lib/THNN/generic/ClassNLLCriterion.c
@@ -9,12 +9,14 @@ void THNN_(ClassNLLCriterion_updateOutput)(
           THTensor *output,
           bool sizeAverage,
           THTensor *weights,
-          THTensor *total_weight)
+          THTensor *total_weight,
+          long ignore_index)
 {
   THNN_CHECK_DIM_SIZE(output, 1, 0, 1);
   THNN_CHECK_DIM_SIZE(total_weight, 1, 0, 1);
   int n_dims = THTensor_(nDimension)(input);
   int n_classes = THTensor_(size)(input, n_dims - 1);
+  ignore_index -= TH_INDEX_BASE;

   if (THIndexTensor_(nDimension)(target) > 1) {
     THError("multi-target not supported");
@@ -42,9 +44,11 @@ void THNN_(ClassNLLCriterion_updateOutput)(
   if (THTensor_(nDimension)(input) == 1) {
     int cur_target = target_data[0] - TH_INDEX_BASE;
-    THAssert(cur_target >= 0 && cur_target < n_classes);
-    total_weight_data[0] = weights ? weights_data[cur_target] : 1.0f;
-    output_data[0] = -input_data[cur_target] * total_weight_data[0];
+    if (cur_target != ignore_index) {
+      THAssert(cur_target >= 0 && cur_target < n_classes);
+      total_weight_data[0] = weights ? weights_data[cur_target] : 1.0f;
+      output_data[0] = -input_data[cur_target] * total_weight_data[0];
+    }
   } else if (THTensor_(nDimension)(input) == 2) {
     int batch_size = THTensor_(size)(input, 0);
     THAssert(THIndexTensor_(size)(target, 0) == batch_size);
@@ -54,11 +58,13 @@ void THNN_(ClassNLLCriterion_updateOutput)(
     int i;
     for (i = 0; i < batch_size; i++) {
       int cur_target = target_data[i] - TH_INDEX_BASE;
-      THAssert(cur_target >= 0 && cur_target < n_classes);
+      if (cur_target != ignore_index) {
+        THAssert(cur_target >= 0 && cur_target < n_classes);

-      real cur_weight = weights ? weights_data[cur_target] : 1.0f;
-      total_weight_data[0] += cur_weight;
-      output_data[0] -= input_data[i * n_target + cur_target] * cur_weight;
+        real cur_weight = weights ? weights_data[cur_target] : 1.0f;
+        total_weight_data[0] += cur_weight;
+        output_data[0] -= input_data[i * n_target + cur_target] * cur_weight;
+      }
     }
   }
@@ -80,10 +86,12 @@ void THNN_(ClassNLLCriterion_updateGradInput)(
           THTensor *gradInput,
           bool sizeAverage,
           THTensor *weights,
-          THTensor *total_weight)
+          THTensor *total_weight,
+          long ignore_index)
 {
   int n_dims = THTensor_(nDimension)(input);
   int n_classes = THTensor_(size)(input, n_dims - 1);
+  ignore_index -= TH_INDEX_BASE;

   if (!THTensor_(isContiguous)(gradInput)) {
     THError("gradInput must be contiguous");
@@ -102,7 +110,7 @@ void THNN_(ClassNLLCriterion_updateGradInput)(
   if (THTensor_(nDimension)(input) > 2) {
     THError("input tensor should be 1D or 2D");
   }
-
+
   if (weights && THTensor_(nElement)(weights) != n_classes) {
     THError("weight tensor should be defined either for all or no classes");
   }
@@ -116,10 +124,12 @@ void THNN_(ClassNLLCriterion_updateGradInput)(
   if (THTensor_(nDimension)(input) == 1) {
     int cur_target = target_data[0] - TH_INDEX_BASE;
-    THAssert(cur_target >= 0 && cur_target < n_classes);
+    if (cur_target != ignore_index) {
+      THAssert(cur_target >= 0 && cur_target < n_classes);

-    gradInput_data[cur_target] =
-      (!sizeAverage && weights) ? -weights_data[cur_target] : -1;
+      gradInput_data[cur_target] =
+        (!sizeAverage && weights) ? -weights_data[cur_target] : -1;
+    }
   } else if (THTensor_(nDimension)(input) == 2) {
     int batch_size = THTensor_(size)(input, 0);
@@ -131,13 +141,15 @@ void THNN_(ClassNLLCriterion_updateGradInput)(
     for (i = 0; i < batch_size; i++){
       int cur_target = target_data[i] - TH_INDEX_BASE;
-      THAssert(cur_target >= 0 && cur_target < n_classes);
+      if (cur_target != ignore_index) {
+        THAssert(cur_target >= 0 && cur_target < n_classes);

-      gradInput_data[i * n_target + cur_target] =
-        -(weights ? weights_data[cur_target] : 1.0f);
+        gradInput_data[i * n_target + cur_target] =
+          -(weights ? weights_data[cur_target] : 1.0f);

-      if (sizeAverage && *total_weight_data) {
-        gradInput_data[i * n_target + cur_target] /= *total_weight_data;
+        if (sizeAverage && *total_weight_data) {
+          gradInput_data[i * n_target + cur_target] /= *total_weight_data;
+        }
       }
     }
   }
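
The gradient half mirrors updateOutput: an ignored target leaves its row of gradInput at zero, while every other target gets minus its class weight at the target position, divided by total_weight when sizeAverage is set. For the batched (2D) case the rule reduces to the sketch below; it is illustrative only, not the THNN code, and the helper name nll_grad is made up for the example.

/* Companion sketch (not THNN code) of the gradient rule in the hunks above,
 * for the batched case: ignored targets keep a zero gradient. */
#include <string.h>

static void nll_grad(float *grad_input, const long *targets,
                     const float *weights, float total_weight,
                     int batch_size, int n_classes,
                     long ignore_index, int size_average)
{
  /* zero everything first; only target entries of kept samples are filled in */
  memset(grad_input, 0, sizeof(float) * batch_size * n_classes);
  for (int i = 0; i < batch_size; i++) {
    long t = targets[i];
    if (t == ignore_index)
      continue;                               /* gradient stays 0 */
    float g = -(weights ? weights[t] : 1.0f); /* d(loss)/d(log_prob[i][t]) */
    if (size_average && total_weight > 0.0f)
      g /= total_weight;                      /* matches the sizeAverage branch */
    grad_input[i * n_classes + t] = g;
  }
}
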
diff --git a/lib/THNN/generic/THNN.h b/lib/THNN/generic/THNN.h
index fcc7f51..b9fd709 100644
--- a/lib/THNN/generic/THNN.h
+++ b/lib/THNN/generic/THNN.h
@@ -47,7 +47,8 @@ TH_API void THNN_(ClassNLLCriterion_updateOutput)(
           THTensor *output,            // [OUT] a one-element tensor with loss
           bool sizeAverage,            // if true, the loss will be normalized by batch size and class weights
           THTensor *weights,           // [OPTIONAL] class weights
-          THTensor *total_weight);     // [BUFFER]
+          THTensor *total_weight,      // [BUFFER]
+          long ignore_index);          // target index to ignore (loss = 0, gradInput = 0)
 TH_API void THNN_(ClassNLLCriterion_updateGradInput)(
           THNNState *state,            // library's state
           THTensor *input,             // input tensor (1D/2D)
@@ -55,7 +56,8 @@ TH_API void THNN_(ClassNLLCriterion_updateGradInput)(
           THTensor *gradInput,         // [OUT] gradient w.r.t. input
           bool sizeAverage,            // if true, the loss will be normalized by batch size and class weights
           THTensor *weights,           // [OPTIONAL] class weights
-          THTensor *total_weight);     // [BUFFER]
+          THTensor *total_weight,      // [BUFFER]
+          long ignore_index);          // target index to ignore (loss = 0, gradInput = 0)
 TH_API void THNN_(SpatialClassNLLCriterion_updateOutput)(
           THNNState *state,            // library's state
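
Both prototypes take ignore_index in the caller's index base; the implementations shift it once with ignore_index -= TH_INDEX_BASE so it can be compared directly against cur_target, which is shifted the same way. The tiny sketch below illustrates why the two shifts stay consistent; the value assigned to TH_INDEX_BASE here is an assumption for the sketch (1 for Lua-style 1-based indexing), not something this diff defines.

/* Illustrative sketch of the index-base handling; standalone, not THNN code. */
#define TH_INDEX_BASE 1   /* assumption: Lua-style 1-based indices */

static int is_ignored(long target, long ignore_index)
{
  long cur_target = target - TH_INDEX_BASE;    /* same shift as target_data[i] */
  long ignore = ignore_index - TH_INDEX_BASE;  /* applied once up front in THNN */
  return cur_target == ignore;                 /* comparison is base-independent */
}
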