diff options
author | Soumith Chintala <soumith@gmail.com> | 2016-07-25 17:00:56 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2016-07-25 17:00:56 +0300 |
commit | 9634a071f2edbe37c7539e9c5e2565b26a8ffbb4 (patch) | |
tree | 7ec093f65c72b1a3f118bc8ce26641b9524361df | |
parent | 2cd59e19df203c9bd45e2a4b5aa7e6f5f043dcda (diff) | |
parent | 3f14ce2485983a4804a7ac0525a4cd01e134b571 (diff) |
Merge pull request #308 from mys007/classnllbounds
NLL Criteria: weight bound checking
-rw-r--r-- | lib/THCUNN/ClassNLLCriterion.cu | 6 | ||||
-rw-r--r-- | lib/THCUNN/SpatialClassNLLCriterion.cu | 6 |
2 files changed, 12 insertions, 0 deletions
diff --git a/lib/THCUNN/ClassNLLCriterion.cu b/lib/THCUNN/ClassNLLCriterion.cu index 0949f3c..b2f54cb 100644 --- a/lib/THCUNN/ClassNLLCriterion.cu +++ b/lib/THCUNN/ClassNLLCriterion.cu @@ -128,6 +128,9 @@ void THNN_CudaClassNLLCriterion_updateOutput(THCState *state, THCudaTensor *inpu if (THCudaTensor_nDimension(state, input) > 2) { THArgCheck(0, 2, "vector or matrix expected"); } + if (weights && THCudaTensor_nElement(state, weights) != n_classes) { + THError("weight tensor should be defined either for all or no classes"); + } input = THCudaTensor_newContiguous(state, input); weights = weights ? THCudaTensor_newContiguous(state, weights) : NULL; @@ -198,6 +201,9 @@ void THNN_CudaClassNLLCriterion_updateGradInput(THCState *state, THCudaTensor *i if (THCudaTensor_nDimension(state, input) > 2) { THArgCheck(0, 2, "vector or matrix expected"); } + if (weights && THCudaTensor_nElement(state, weights) != n_classes) { + THError("weight tensor should be defined either for all or no classes"); + } weights = weights ? THCudaTensor_newContiguous(state, weights) : NULL; target = THCudaTensor_newContiguous(state, target); diff --git a/lib/THCUNN/SpatialClassNLLCriterion.cu b/lib/THCUNN/SpatialClassNLLCriterion.cu index 56b6a2d..c718772 100644 --- a/lib/THCUNN/SpatialClassNLLCriterion.cu +++ b/lib/THCUNN/SpatialClassNLLCriterion.cu @@ -96,6 +96,9 @@ void THNN_CudaSpatialClassNLLCriterion_updateOutput( "only batches of spatial targets supported (3D tensors)"); THArgCheck(THCudaTensor_nDimension(state, input) == 4, 2, "only batches of spatial inputs supported (4D tensors)"); + if (weights && THCudaTensor_nElement(state, weights) != THCudaTensor_size(state, input, 1)) { + THError("weight tensor should be defined either for all or no classes"); + } if (weights) THCUNN_assertSameGPU(state, 5, input, target, weights, output, total_weight); @@ -157,6 +160,9 @@ void THNN_CudaSpatialClassNLLCriterion_updateGradInput( "only batches of spatial inputs supported (4D tensors)"); THArgCheck(THCudaTensor_isContiguous(state, gradInput), 4, "gradInput must be contiguous"); + if (weights && THCudaTensor_nElement(state, weights) != THCudaTensor_size(state, input, 1)) { + THError("weight tensor should be defined either for all or no classes"); + } if (weights) THCUNN_assertSameGPU(state, 5, weights, input, target, gradInput, total_weight); |