Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/cunn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'lib/THCUNN/generic/ClassNLLCriterion.cu')
-rw-r--r--lib/THCUNN/generic/ClassNLLCriterion.cu20
1 files changed, 14 insertions, 6 deletions
diff --git a/lib/THCUNN/generic/ClassNLLCriterion.cu b/lib/THCUNN/generic/ClassNLLCriterion.cu
index 039372b..a41c555 100644
--- a/lib/THCUNN/generic/ClassNLLCriterion.cu
+++ b/lib/THCUNN/generic/ClassNLLCriterion.cu
@@ -9,9 +9,11 @@ void THNN_(ClassNLLCriterion_updateOutput)(
THCTensor *output,
bool sizeAverage,
THCTensor *weights,
- THCTensor *total_weight) {
+ THCTensor *total_weight,
+ long ignore_index) {
THCUNN_check_dim_size(state, output, 1, 0, 1);
THCUNN_check_dim_size(state, total_weight, 1, 0, 1);
+ ignore_index -= TH_INDEX_BASE;
if (THCIndexTensor_(nDimension)(state, target) > 1) {
THError("multi-target not supported");
@@ -63,7 +65,8 @@ void THNN_(ClassNLLCriterion_updateOutput)(
target_data,
weights_data,
sizeAverage,
- n_classes
+ n_classes,
+ ignore_index
);
} else if (THCTensor_(nDimension)(state, input) == 2) {
@@ -77,7 +80,8 @@ void THNN_(ClassNLLCriterion_updateOutput)(
sizeAverage,
THCTensor_(size)(state, input, 0),
THCTensor_(size)(state, input, 1),
- n_classes
+ n_classes,
+ ignore_index
);
}
THCudaCheck(cudaGetLastError());
@@ -96,10 +100,12 @@ void THNN_(ClassNLLCriterion_updateGradInput)(
THCTensor *gradInput,
bool sizeAverage,
THCTensor *weights,
- THCTensor *total_weight) {
+ THCTensor *total_weight,
+ long ignore_index) {
if (THCIndexTensor_(nDimension)(state, target) > 1) {
THError("multi-target not supported");
}
+ ignore_index -= TH_INDEX_BASE;
int n_dims = THCTensor_(nDimension)(state, input);
int n_classes = THCTensor_(size)(state, input, n_dims - 1);
@@ -145,7 +151,8 @@ void THNN_(ClassNLLCriterion_updateGradInput)(
target_data,
total_weight_data,
sizeAverage,
- n_classes
+ n_classes,
+ ignore_index
);
} else {
cunn_ClassNLLCriterion_updateGradInput_kernel<real>
@@ -157,7 +164,8 @@ void THNN_(ClassNLLCriterion_updateGradInput)(
sizeAverage,
THCTensor_(size)(state, input, 0),
THCTensor_(size)(state, input, 1),
- n_classes
+ n_classes,
+ ignore_index
);
}
THCudaCheck(cudaGetLastError());