diff options
Diffstat (limited to 'lib/THCUNN/generic/ClassNLLCriterion.cu')
-rw-r--r-- | lib/THCUNN/generic/ClassNLLCriterion.cu | 20 |
1 files changed, 14 insertions, 6 deletions
diff --git a/lib/THCUNN/generic/ClassNLLCriterion.cu b/lib/THCUNN/generic/ClassNLLCriterion.cu index 039372b..a41c555 100644 --- a/lib/THCUNN/generic/ClassNLLCriterion.cu +++ b/lib/THCUNN/generic/ClassNLLCriterion.cu @@ -9,9 +9,11 @@ void THNN_(ClassNLLCriterion_updateOutput)( THCTensor *output, bool sizeAverage, THCTensor *weights, - THCTensor *total_weight) { + THCTensor *total_weight, + long ignore_index) { THCUNN_check_dim_size(state, output, 1, 0, 1); THCUNN_check_dim_size(state, total_weight, 1, 0, 1); + ignore_index -= TH_INDEX_BASE; if (THCIndexTensor_(nDimension)(state, target) > 1) { THError("multi-target not supported"); @@ -63,7 +65,8 @@ void THNN_(ClassNLLCriterion_updateOutput)( target_data, weights_data, sizeAverage, - n_classes + n_classes, + ignore_index ); } else if (THCTensor_(nDimension)(state, input) == 2) { @@ -77,7 +80,8 @@ void THNN_(ClassNLLCriterion_updateOutput)( sizeAverage, THCTensor_(size)(state, input, 0), THCTensor_(size)(state, input, 1), - n_classes + n_classes, + ignore_index ); } THCudaCheck(cudaGetLastError()); @@ -96,10 +100,12 @@ void THNN_(ClassNLLCriterion_updateGradInput)( THCTensor *gradInput, bool sizeAverage, THCTensor *weights, - THCTensor *total_weight) { + THCTensor *total_weight, + long ignore_index) { if (THCIndexTensor_(nDimension)(state, target) > 1) { THError("multi-target not supported"); } + ignore_index -= TH_INDEX_BASE; int n_dims = THCTensor_(nDimension)(state, input); int n_classes = THCTensor_(size)(state, input, n_dims - 1); @@ -145,7 +151,8 @@ void THNN_(ClassNLLCriterion_updateGradInput)( target_data, total_weight_data, sizeAverage, - n_classes + n_classes, + ignore_index ); } else { cunn_ClassNLLCriterion_updateGradInput_kernel<real> @@ -157,7 +164,8 @@ void THNN_(ClassNLLCriterion_updateGradInput)( sizeAverage, THCTensor_(size)(state, input, 0), THCTensor_(size)(state, input, 1), - n_classes + n_classes, + ignore_index ); } THCudaCheck(cudaGetLastError()); |