diff options
author | Soumith Chintala <soumith@gmail.com> | 2015-07-27 21:41:24 +0300 |
---|---|---|
committer | Soumith Chintala <soumith@gmail.com> | 2015-07-27 21:41:24 +0300 |
commit | 194522f1ba96432ab19c176e23a0b9b981174770 (patch) | |
tree | b1c51e8b39f0f448932265142965ae8914c95c2b | |
parent | 1115a7d492875b87c88aa9673bd73230b1b4598f (diff) | |
parent | 6bcdc6f74873a214d90ac1418da86dc0f75048fc (diff) |
Merge pull request #329 from jonathantompson/cuda_class_nll_criterion_weights
Added weight support to ClassNLLCriterion cuda
-rw-r--r-- | ClassNLLCriterion.lua | 22 |
1 file changed, 20 insertions, 2 deletions
diff --git a/ClassNLLCriterion.lua b/ClassNLLCriterion.lua index 409ac59..bc2a2e9 100644 --- a/ClassNLLCriterion.lua +++ b/ClassNLLCriterion.lua @@ -21,7 +21,23 @@ end function ClassNLLCriterion:updateOutput(input, target) - if input:type() == 'torch.CudaTensor' and not self.weights then + if input:type() == 'torch.CudaTensor' then + if self.weights == nil then + -- The CUDA implementation requires self.weights be non-nil + self.weights = torch.CudaTensor() + end + assert(self.weights:dim() == 0 or self.weights:dim() == 1, + 'weights must be 1D or empty') + -- The cuda code wont check weight size, so we must do it here. + if self.weights:dim() == 1 then + if input:dim() == 1 then + assert(self.weights:size(1) == input:size(1), + 'Wrong number of weights') + else + assert(self.weights:size(1) == input:size(2), + 'Wrong number of weights') + end + end if input:dim() == 1 then self._target = self._target or input.new(1) if type(target) == 'number' then @@ -66,7 +82,9 @@ function ClassNLLCriterion:updateGradInput(input, target) self.gradInput:resizeAs(input) self.gradInput:zero() - if input:type() == 'torch.CudaTensor' and not self.weights then + if input:type() == 'torch.CudaTensor' then + -- Note: we'll assume that updateOutput() has been called and self.weights + -- is non-nil. if input:dim() == 1 then self._target = self._target or input.new(1) if type(target) == 'number' then |