diff options
author | Nicholas Leonard <nleonard@twitter.com> | 2017-05-16 05:27:00 +0300 |
---|---|---|
committer | Nicholas Leonard <nleonard@twitter.com> | 2017-05-16 05:27:00 +0300 |
commit | 53f7b2532da4216bba76a9feafcfb7b273b2cc8d (patch) | |
tree | 96421bf61aba74abed654cb04276fced033360cb | |
parent | 501b31c4763ce236aef46235bdc21cb499fb6e3b (diff) |
ClassNLLCriterion ignoreIndex
-rw-r--r-- | lib/THCUNN/ClassNLLCriterion.cu | 48 | ||||
-rw-r--r-- | lib/THCUNN/generic/ClassNLLCriterion.cu | 20 | ||||
-rw-r--r-- | lib/THCUNN/generic/THCUNN.h | 6 | ||||
-rw-r--r-- | test.lua | 23 |
4 files changed, 71 insertions, 26 deletions
diff --git a/lib/THCUNN/ClassNLLCriterion.cu b/lib/THCUNN/ClassNLLCriterion.cu index 58684f4..194d64c 100644 --- a/lib/THCUNN/ClassNLLCriterion.cu +++ b/lib/THCUNN/ClassNLLCriterion.cu @@ -15,19 +15,22 @@ __global__ void cunn_ClassNLLCriterion_updateOutput_kernel1(Dtype *output, THCIndex_t *target, Dtype *weights, int size_average, - int n_classes) { + int n_classes, + long ignore_index) { assert(threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0); // TODO: T4951791 Reuse code between updateOutput_kernel1 and // updateOutput_kernel. int t = (int)*target - TH_INDEX_BASE; - assert(t >= 0 && t < n_classes); - Dtype cur_weight = weights ? weights[t] : ScalarConvert<int, Dtype>::to(1); - *output = -cur_weight * input[t]; - *total_weight = cur_weight; - if (size_average && *total_weight > 0) { - *output /= *total_weight; + if (t != ignore_index) { + assert(t >= 0 && t < n_classes); + Dtype cur_weight = weights ? weights[t] : ScalarConvert<int, Dtype>::to(1); + *output = -cur_weight * input[t]; + *total_weight = cur_weight; + if (size_average && *total_weight > 0) { + *output /= *total_weight; + } } } @@ -40,7 +43,8 @@ __global__ void cunn_ClassNLLCriterion_updateOutput_kernel(Dtype *output, int size_average, int nframe, int ndim, - int n_classes) { + int n_classes, + long ignore_index) { __shared__ Acctype shInputs[NTHREADS], acc_weight[NTHREADS]; int i, t; Dtype cur_weight; @@ -49,10 +53,12 @@ __global__ void cunn_ClassNLLCriterion_updateOutput_kernel(Dtype *output, acc_weight[threadIdx.x] = ScalarConvert<int, Acctype>::to(0); for (i = threadIdx.x; i < nframe; i += NTHREADS) { t = target[i] - TH_INDEX_BASE; - assert(t >= 0 && t < n_classes); - cur_weight = weights ? weights[t] : ScalarConvert<int, Dtype>::to(1); - shInputs[threadIdx.x] -= input[i * ndim + t] * cur_weight; - acc_weight[threadIdx.x] += cur_weight; + if (t != ignore_index) { + assert(t >= 0 && t < n_classes); + cur_weight = weights ? weights[t] : ScalarConvert<int, Dtype>::to(1); + shInputs[threadIdx.x] -= input[i * ndim + t] * cur_weight; + acc_weight[threadIdx.x] += cur_weight; + } } __syncthreads(); @@ -84,15 +90,18 @@ __global__ void cunn_ClassNLLCriterion_updateGradInput_kernel1( THCIndex_t* target, Dtype* total_weight, int size_average, - int n_classes) + int n_classes, + long ignore_index) { if (*total_weight <= 0) { return; } Dtype norm = size_average ? (ScalarConvert<int, Dtype>::to(1) / *total_weight) : ScalarConvert<int, Dtype>::to(1); int t = (int)*target - TH_INDEX_BASE; - assert(t >= 0 && t < n_classes); - gradInput[t] = -(weights ? weights[t] : ScalarConvert<int, Dtype>::to(1)) * norm; + if (t != ignore_index) { + assert(t >= 0 && t < n_classes); + gradInput[t] = -(weights ? weights[t] : ScalarConvert<int, Dtype>::to(1)) * norm; + } } template <typename Dtype> @@ -104,7 +113,8 @@ __global__ void cunn_ClassNLLCriterion_updateGradInput_kernel( int size_average, int nframe, int ndim, - int n_classes) + int n_classes, + long ignore_index) { if (*total_weight <= 0) { return; @@ -114,8 +124,10 @@ __global__ void cunn_ClassNLLCriterion_updateGradInput_kernel( for (i = threadIdx.x; i < nframe; i += NTHREADS) { t = (int)target[i] - TH_INDEX_BASE; - assert(t >= 0 && t < n_classes); - gradInput[i * ndim + t] = -(weights ? weights[t] : ScalarConvert<int, Dtype>::to(1)) * norm; + if (t != ignore_index) { + assert(t >= 0 && t < n_classes); + gradInput[i * ndim + t] = -(weights ? weights[t] : ScalarConvert<int, Dtype>::to(1)) * norm; + } } }
diff --git a/lib/THCUNN/generic/ClassNLLCriterion.cu b/lib/THCUNN/generic/ClassNLLCriterion.cu index 039372b..a41c555 100644 --- a/lib/THCUNN/generic/ClassNLLCriterion.cu +++ b/lib/THCUNN/generic/ClassNLLCriterion.cu @@ -9,9 +9,11 @@ void THNN_(ClassNLLCriterion_updateOutput)( THCTensor *output, bool sizeAverage, THCTensor *weights, - THCTensor *total_weight) { + THCTensor *total_weight, + long ignore_index) { THCUNN_check_dim_size(state, output, 1, 0, 1); THCUNN_check_dim_size(state, total_weight, 1, 0, 1); + ignore_index -= TH_INDEX_BASE; if (THCIndexTensor_(nDimension)(state, target) > 1) { THError("multi-target not supported"); @@ -63,7 +65,8 @@ void THNN_(ClassNLLCriterion_updateOutput)( target_data, weights_data, sizeAverage, - n_classes + n_classes, + ignore_index ); } else if (THCTensor_(nDimension)(state, input) == 2) { @@ -77,7 +80,8 @@ void THNN_(ClassNLLCriterion_updateOutput)( sizeAverage, THCTensor_(size)(state, input, 0), THCTensor_(size)(state, input, 1), - n_classes + n_classes, + ignore_index ); } THCudaCheck(cudaGetLastError()); @@ -96,10 +100,12 @@ void THNN_(ClassNLLCriterion_updateGradInput)( THCTensor *gradInput, bool sizeAverage, THCTensor *weights, - THCTensor *total_weight) { + THCTensor *total_weight, + long ignore_index) { if (THCIndexTensor_(nDimension)(state, target) > 1) { THError("multi-target not supported"); } + ignore_index -= TH_INDEX_BASE; int n_dims = THCTensor_(nDimension)(state, input); int n_classes = THCTensor_(size)(state, input, n_dims - 1); @@ -145,7 +151,8 @@ void THNN_(ClassNLLCriterion_updateGradInput)( target_data, total_weight_data, sizeAverage, - n_classes + n_classes, + ignore_index ); } else { cunn_ClassNLLCriterion_updateGradInput_kernel<real> @@ -157,7 +164,8 @@ void THNN_(ClassNLLCriterion_updateGradInput)( sizeAverage, THCTensor_(size)(state, input, 0), THCTensor_(size)(state, input, 1), - n_classes + n_classes, + ignore_index ); } THCudaCheck(cudaGetLastError());
diff --git a/lib/THCUNN/generic/THCUNN.h b/lib/THCUNN/generic/THCUNN.h index b44fff3..72ea749 100644 --- a/lib/THCUNN/generic/THCUNN.h +++ b/lib/THCUNN/generic/THCUNN.h @@ -80,7 +80,8 @@ TH_API void THNN_(ClassNLLCriterion_updateOutput)( THCTensor *output, bool sizeAverage, THCTensor *weights, // [OPTIONAL] - THCTensor *total_weight); + THCTensor *total_weight, + long ignore_index); TH_API void THNN_(ClassNLLCriterion_updateGradInput)( THCState *state, @@ -89,7 +90,8 @@ TH_API void THNN_(ClassNLLCriterion_updateGradInput)( THCTensor *gradInput, bool sizeAverage, THCTensor *weights, // [OPTIONAL] - THCTensor *total_weight); + THCTensor *total_weight, + long ignore_index); TH_API void THNN_(DistKLDivCriterion_updateOutput)( THCState *state,
@@ -4626,6 +4626,29 @@ function cunntest.ClassNLLCriterionMultipleTargetWeights() end end +function cunntest.ClassNLLCriterion_ignoreIndex() + local numLabels = 10 + local batchsize = 4 + local ignoreIndex = -1 + local cri = nn.ClassNLLCriterion(nil, nil, ignoreIndex):cuda() + local input = torch.randn(numLabels):cuda() + local target = ignoreIndex + mytester:assert(cri:forward(input, target) == 0) + mytester:assert(cri:backward(input, target):abs():sum() == 0) + local input = torch.randn(batchsize, numLabels):cuda() + local target = torch.LongTensor(batchsize):random(1,numLabels) + target[1] = ignoreIndex + target = target:cudaLong() + local output = cri:forward(input, target) + local gradInput = cri:backward(input, target):clone() + mytester:assert(gradInput[1]:abs():sum() == 0) + local input, target = input:sub(2,batchsize), target:sub(2,batchsize) + local output2 = cri:forward(input, target) + mytester:assert(math.abs(output2 - output) < 0.0000001) + local gradInput2 = cri:backward(input, target) + mytester:assertTensorEq(gradInput2, gradInput:sub(2,batchsize), 0.0000001) +end + function cunntest.TemporalMaxPooling() local settings = {{2, 2}, {3, 3}, {4, 2}, {2, 4}, {3, 5}} |