Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSoumith Chintala <soumith@gmail.com>2015-07-27 21:41:24 +0300
committerSoumith Chintala <soumith@gmail.com>2015-07-27 21:41:24 +0300
commit194522f1ba96432ab19c176e23a0b9b981174770 (patch)
treeb1c51e8b39f0f448932265142965ae8914c95c2b
parent1115a7d492875b87c88aa9673bd73230b1b4598f (diff)
parent6bcdc6f74873a214d90ac1418da86dc0f75048fc (diff)
Merge pull request #329 from jonathantompson/cuda_class_nll_criterion_weights
Added weight support to ClassNLLCriterion cuda
-rw-r--r--ClassNLLCriterion.lua22
1 files changed, 20 insertions, 2 deletions
diff --git a/ClassNLLCriterion.lua b/ClassNLLCriterion.lua
index 409ac59..bc2a2e9 100644
--- a/ClassNLLCriterion.lua
+++ b/ClassNLLCriterion.lua
@@ -21,7 +21,23 @@ end
function ClassNLLCriterion:updateOutput(input, target)
- if input:type() == 'torch.CudaTensor' and not self.weights then
+ if input:type() == 'torch.CudaTensor' then
+ if self.weights == nil then
+ -- The CUDA implementation requires self.weights be non-nil
+ self.weights = torch.CudaTensor()
+ end
+ assert(self.weights:dim() == 0 or self.weights:dim() == 1,
+ 'weights must be 1D or empty')
+ -- The cuda code wont check weight size, so we must do it here.
+ if self.weights:dim() == 1 then
+ if input:dim() == 1 then
+ assert(self.weights:size(1) == input:size(1),
+ 'Wrong number of weights')
+ else
+ assert(self.weights:size(1) == input:size(2),
+ 'Wrong number of weights')
+ end
+ end
if input:dim() == 1 then
self._target = self._target or input.new(1)
if type(target) == 'number' then
@@ -66,7 +82,9 @@ function ClassNLLCriterion:updateGradInput(input, target)
self.gradInput:resizeAs(input)
self.gradInput:zero()
- if input:type() == 'torch.CudaTensor' and not self.weights then
+ if input:type() == 'torch.CudaTensor' then
+ -- Note: we'll assume that updateOutput() has been called and self.weights
+ -- is non-nil.
if input:dim() == 1 then
self._target = self._target or input.new(1)
if type(target) == 'number' then