Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/cunn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSoumith Chintala <soumith@gmail.com>2016-07-25 17:00:56 +0300
committerGitHub <noreply@github.com>2016-07-25 17:00:56 +0300
commit9634a071f2edbe37c7539e9c5e2565b26a8ffbb4 (patch)
tree7ec093f65c72b1a3f118bc8ce26641b9524361df
parent2cd59e19df203c9bd45e2a4b5aa7e6f5f043dcda (diff)
parent3f14ce2485983a4804a7ac0525a4cd01e134b571 (diff)
Merge pull request #308 from mys007/classnllbounds
NLL Criteria: weight bound checking
-rw-r--r--lib/THCUNN/ClassNLLCriterion.cu6
-rw-r--r--lib/THCUNN/SpatialClassNLLCriterion.cu6
2 files changed, 12 insertions, 0 deletions
diff --git a/lib/THCUNN/ClassNLLCriterion.cu b/lib/THCUNN/ClassNLLCriterion.cu
index 0949f3c..b2f54cb 100644
--- a/lib/THCUNN/ClassNLLCriterion.cu
+++ b/lib/THCUNN/ClassNLLCriterion.cu
@@ -128,6 +128,9 @@ void THNN_CudaClassNLLCriterion_updateOutput(THCState *state, THCudaTensor *inpu
if (THCudaTensor_nDimension(state, input) > 2) {
THArgCheck(0, 2, "vector or matrix expected");
}
+ if (weights && THCudaTensor_nElement(state, weights) != n_classes) {
+ THError("weight tensor should be defined either for all or no classes");
+ }
input = THCudaTensor_newContiguous(state, input);
weights = weights ? THCudaTensor_newContiguous(state, weights) : NULL;
@@ -198,6 +201,9 @@ void THNN_CudaClassNLLCriterion_updateGradInput(THCState *state, THCudaTensor *i
if (THCudaTensor_nDimension(state, input) > 2) {
THArgCheck(0, 2, "vector or matrix expected");
}
+ if (weights && THCudaTensor_nElement(state, weights) != n_classes) {
+ THError("weight tensor should be defined either for all or no classes");
+ }
weights = weights ? THCudaTensor_newContiguous(state, weights) : NULL;
target = THCudaTensor_newContiguous(state, target);
diff --git a/lib/THCUNN/SpatialClassNLLCriterion.cu b/lib/THCUNN/SpatialClassNLLCriterion.cu
index 56b6a2d..c718772 100644
--- a/lib/THCUNN/SpatialClassNLLCriterion.cu
+++ b/lib/THCUNN/SpatialClassNLLCriterion.cu
@@ -96,6 +96,9 @@ void THNN_CudaSpatialClassNLLCriterion_updateOutput(
"only batches of spatial targets supported (3D tensors)");
THArgCheck(THCudaTensor_nDimension(state, input) == 4, 2,
"only batches of spatial inputs supported (4D tensors)");
+ if (weights && THCudaTensor_nElement(state, weights) != THCudaTensor_size(state, input, 1)) {
+ THError("weight tensor should be defined either for all or no classes");
+ }
if (weights)
THCUNN_assertSameGPU(state, 5, input, target, weights, output, total_weight);
@@ -157,6 +160,9 @@ void THNN_CudaSpatialClassNLLCriterion_updateGradInput(
"only batches of spatial inputs supported (4D tensors)");
THArgCheck(THCudaTensor_isContiguous(state, gradInput), 4,
"gradInput must be contiguous");
+ if (weights && THCudaTensor_nElement(state, weights) != THCudaTensor_size(state, input, 1)) {
+ THError("weight tensor should be defined either for all or no classes");
+ }
if (weights)
THCUNN_assertSameGPU(state, 5, weights, input, target, gradInput, total_weight);