Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/cunn.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author	Nicholas Leonard <nleonard@twitter.com>	2017-05-16 05:27:00 +0300
committer	Nicholas Leonard <nleonard@twitter.com>	2017-05-16 05:27:00 +0300
commit	53f7b2532da4216bba76a9feafcfb7b273b2cc8d (patch)
tree	96421bf61aba74abed654cb04276fced033360cb
parent	501b31c4763ce236aef46235bdc21cb499fb6e3b (diff)
ClassNLLCriterion ignoreIndex
-rw-r--r--	lib/THCUNN/ClassNLLCriterion.cu	| 48
-rw-r--r--	lib/THCUNN/generic/ClassNLLCriterion.cu	| 20
-rw-r--r--	lib/THCUNN/generic/THCUNN.h	| 6
-rw-r--r--	test.lua	| 23
4 files changed, 71 insertions, 26 deletions
diff --git a/lib/THCUNN/ClassNLLCriterion.cu b/lib/THCUNN/ClassNLLCriterion.cu
index 58684f4..194d64c 100644
--- a/lib/THCUNN/ClassNLLCriterion.cu
+++ b/lib/THCUNN/ClassNLLCriterion.cu
@@ -15,19 +15,22 @@ __global__ void cunn_ClassNLLCriterion_updateOutput_kernel1(Dtype *output,
THCIndex_t *target,
Dtype *weights,
int size_average,
- int n_classes) {
+ int n_classes,
+ long ignore_index) {
assert(threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0);
// TODO: T4951791 Reuse code between updateOutput_kernel1 and
// updateOutput_kernel.
int t = (int)*target - TH_INDEX_BASE;
- assert(t >= 0 && t < n_classes);
- Dtype cur_weight = weights ? weights[t] : ScalarConvert<int, Dtype>::to(1);
- *output = -cur_weight * input[t];
- *total_weight = cur_weight;
- if (size_average && *total_weight > 0) {
- *output /= *total_weight;
+ if (t != ignore_index) {
+ assert(t >= 0 && t < n_classes);
+ Dtype cur_weight = weights ? weights[t] : ScalarConvert<int, Dtype>::to(1);
+ *output = -cur_weight * input[t];
+ *total_weight = cur_weight;
+ if (size_average && *total_weight > 0) {
+ *output /= *total_weight;
+ }
}
}
@@ -40,7 +43,8 @@ __global__ void cunn_ClassNLLCriterion_updateOutput_kernel(Dtype *output,
int size_average,
int nframe,
int ndim,
- int n_classes) {
+ int n_classes,
+ long ignore_index) {
__shared__ Acctype shInputs[NTHREADS], acc_weight[NTHREADS];
int i, t;
Dtype cur_weight;
@@ -49,10 +53,12 @@ __global__ void cunn_ClassNLLCriterion_updateOutput_kernel(Dtype *output,
acc_weight[threadIdx.x] = ScalarConvert<int, Acctype>::to(0);
for (i = threadIdx.x; i < nframe; i += NTHREADS) {
t = target[i] - TH_INDEX_BASE;
- assert(t >= 0 && t < n_classes);
- cur_weight = weights ? weights[t] : ScalarConvert<int, Dtype>::to(1);
- shInputs[threadIdx.x] -= input[i * ndim + t] * cur_weight;
- acc_weight[threadIdx.x] += cur_weight;
+ if (t != ignore_index) {
+ assert(t >= 0 && t < n_classes);
+ cur_weight = weights ? weights[t] : ScalarConvert<int, Dtype>::to(1);
+ shInputs[threadIdx.x] -= input[i * ndim + t] * cur_weight;
+ acc_weight[threadIdx.x] += cur_weight;
+ }
}
__syncthreads();
@@ -84,15 +90,18 @@ __global__ void cunn_ClassNLLCriterion_updateGradInput_kernel1(
THCIndex_t* target,
Dtype* total_weight,
int size_average,
- int n_classes)
+ int n_classes,
+ long ignore_index)
{
if (*total_weight <= 0) {
return;
}
Dtype norm = size_average ? (ScalarConvert<int, Dtype>::to(1) / *total_weight) : ScalarConvert<int, Dtype>::to(1);
int t = (int)*target - TH_INDEX_BASE;
- assert(t >= 0 && t < n_classes);
- gradInput[t] = -(weights ? weights[t] : ScalarConvert<int, Dtype>::to(1)) * norm;
+ if (t != ignore_index) {
+ assert(t >= 0 && t < n_classes);
+ gradInput[t] = -(weights ? weights[t] : ScalarConvert<int, Dtype>::to(1)) * norm;
+ }
}
template <typename Dtype>
@@ -104,7 +113,8 @@ __global__ void cunn_ClassNLLCriterion_updateGradInput_kernel(
int size_average,
int nframe,
int ndim,
- int n_classes)
+ int n_classes,
+ long ignore_index)
{
if (*total_weight <= 0) {
return;
@@ -114,8 +124,10 @@ __global__ void cunn_ClassNLLCriterion_updateGradInput_kernel(
for (i = threadIdx.x; i < nframe; i += NTHREADS) {
t = (int)target[i] - TH_INDEX_BASE;
- assert(t >= 0 && t < n_classes);
- gradInput[i * ndim + t] = -(weights ? weights[t] : ScalarConvert<int, Dtype>::to(1)) * norm;
+ if (t != ignore_index) {
+ assert(t >= 0 && t < n_classes);
+ gradInput[i * ndim + t] = -(weights ? weights[t] : ScalarConvert<int, Dtype>::to(1)) * norm;
+ }
}
}
diff --git a/lib/THCUNN/generic/ClassNLLCriterion.cu b/lib/THCUNN/generic/ClassNLLCriterion.cu
index 039372b..a41c555 100644
--- a/lib/THCUNN/generic/ClassNLLCriterion.cu
+++ b/lib/THCUNN/generic/ClassNLLCriterion.cu
@@ -9,9 +9,11 @@ void THNN_(ClassNLLCriterion_updateOutput)(
THCTensor *output,
bool sizeAverage,
THCTensor *weights,
- THCTensor *total_weight) {
+ THCTensor *total_weight,
+ long ignore_index) {
THCUNN_check_dim_size(state, output, 1, 0, 1);
THCUNN_check_dim_size(state, total_weight, 1, 0, 1);
+ ignore_index -= TH_INDEX_BASE;
if (THCIndexTensor_(nDimension)(state, target) > 1) {
THError("multi-target not supported");
@@ -63,7 +65,8 @@ void THNN_(ClassNLLCriterion_updateOutput)(
target_data,
weights_data,
sizeAverage,
- n_classes
+ n_classes,
+ ignore_index
);
} else if (THCTensor_(nDimension)(state, input) == 2) {
@@ -77,7 +80,8 @@ void THNN_(ClassNLLCriterion_updateOutput)(
sizeAverage,
THCTensor_(size)(state, input, 0),
THCTensor_(size)(state, input, 1),
- n_classes
+ n_classes,
+ ignore_index
);
}
THCudaCheck(cudaGetLastError());
@@ -96,10 +100,12 @@ void THNN_(ClassNLLCriterion_updateGradInput)(
THCTensor *gradInput,
bool sizeAverage,
THCTensor *weights,
- THCTensor *total_weight) {
+ THCTensor *total_weight,
+ long ignore_index) {
if (THCIndexTensor_(nDimension)(state, target) > 1) {
THError("multi-target not supported");
}
+ ignore_index -= TH_INDEX_BASE;
int n_dims = THCTensor_(nDimension)(state, input);
int n_classes = THCTensor_(size)(state, input, n_dims - 1);
@@ -145,7 +151,8 @@ void THNN_(ClassNLLCriterion_updateGradInput)(
target_data,
total_weight_data,
sizeAverage,
- n_classes
+ n_classes,
+ ignore_index
);
} else {
cunn_ClassNLLCriterion_updateGradInput_kernel<real>
@@ -157,7 +164,8 @@ void THNN_(ClassNLLCriterion_updateGradInput)(
sizeAverage,
THCTensor_(size)(state, input, 0),
THCTensor_(size)(state, input, 1),
- n_classes
+ n_classes,
+ ignore_index
);
}
THCudaCheck(cudaGetLastError());
diff --git a/lib/THCUNN/generic/THCUNN.h b/lib/THCUNN/generic/THCUNN.h
index b44fff3..72ea749 100644
--- a/lib/THCUNN/generic/THCUNN.h
+++ b/lib/THCUNN/generic/THCUNN.h
@@ -80,7 +80,8 @@ TH_API void THNN_(ClassNLLCriterion_updateOutput)(
THCTensor *output,
bool sizeAverage,
THCTensor *weights, // [OPTIONAL]
- THCTensor *total_weight);
+ THCTensor *total_weight,
+ long ignore_index);
TH_API void THNN_(ClassNLLCriterion_updateGradInput)(
THCState *state,
@@ -89,7 +90,8 @@ TH_API void THNN_(ClassNLLCriterion_updateGradInput)(
THCTensor *gradInput,
bool sizeAverage,
THCTensor *weights, // [OPTIONAL]
- THCTensor *total_weight);
+ THCTensor *total_weight,
+ long ignore_index);
TH_API void THNN_(DistKLDivCriterion_updateOutput)(
THCState *state,
diff --git a/test.lua b/test.lua
index f7fd728..0cd64fb 100644
--- a/test.lua
+++ b/test.lua
@@ -4626,6 +4626,29 @@ function cunntest.ClassNLLCriterionMultipleTargetWeights()
end
end
+function cunntest.ClassNLLCriterion_ignoreIndex()
+ local numLabels = 10
+ local batchsize = 4
+ local ignoreIndex = -1
+ local cri = nn.ClassNLLCriterion(nil, nil, ignoreIndex):cuda()
+ local input = torch.randn(numLabels):cuda()
+ local target = ignoreIndex
+ mytester:assert(cri:forward(input, target) == 0)
+ mytester:assert(cri:backward(input, target):abs():sum() == 0)
+ local input = torch.randn(batchsize, numLabels):cuda()
+ local target = torch.LongTensor(batchsize):random(1,numLabels)
+ target[1] = ignoreIndex
+ target = target:cudaLong()
+ local output = cri:forward(input, target)
+ local gradInput = cri:backward(input, target):clone()
+ mytester:assert(gradInput[1]:abs():sum() == 0)
+ local input, target = input:sub(2,batchsize), target:sub(2,batchsize)
+ local output2 = cri:forward(input, target)
+ mytester:assert(math.abs(output2 - output) < 0.0000001)
+ local gradInput2 = cri:backward(input, target)
+ mytester:assertTensorEq(gradInput2, gradInput:sub(2,batchsize), 0.0000001)
+end
+
function cunntest.TemporalMaxPooling()
local settings = {{2, 2}, {3, 3}, {4, 2}, {2, 4}, {3, 5}}