diff options
author | Jonathan Tompson <jonathantompson@gmail.com> | 2015-07-24 23:57:19 +0300 |
---|---|---|
committer | Jonathan Tompson <jonathantompson@gmail.com> | 2015-07-24 23:57:19 +0300 |
commit | 6bcdc6f74873a214d90ac1418da86dc0f75048fc (patch) | |
tree | 19f94fa780b8fb9478d512316e86c50686479c81 | |
parent | 3c13b0d41281990e3f1d4dd8a1ae45b8d5ca0f40 (diff) |
Aded weight support to ClassNLLCriterion cuda.
-rw-r--r-- | ClassNLLCriterion.lua | 22 |
1 files changed, 20 insertions, 2 deletions
diff --git a/ClassNLLCriterion.lua b/ClassNLLCriterion.lua index 409ac59..bc2a2e9 100644 --- a/ClassNLLCriterion.lua +++ b/ClassNLLCriterion.lua @@ -21,7 +21,23 @@ end function ClassNLLCriterion:updateOutput(input, target) - if input:type() == 'torch.CudaTensor' and not self.weights then + if input:type() == 'torch.CudaTensor' then + if self.weights == nil then + -- The CUDA implementation requires self.weights be non-nil + self.weights = torch.CudaTensor() + end + assert(self.weights:dim() == 0 or self.weights:dim() == 1, + 'weights must be 1D or empty') + -- The cuda code wont check weight size, so we must do it here. + if self.weights:dim() == 1 then + if input:dim() == 1 then + assert(self.weights:size(1) == input:size(1), + 'Wrong number of weights') + else + assert(self.weights:size(1) == input:size(2), + 'Wrong number of weights') + end + end if input:dim() == 1 then self._target = self._target or input.new(1) if type(target) == 'number' then @@ -66,7 +82,9 @@ function ClassNLLCriterion:updateGradInput(input, target) self.gradInput:resizeAs(input) self.gradInput:zero() - if input:type() == 'torch.CudaTensor' and not self.weights then + if input:type() == 'torch.CudaTensor' then + -- Note: we'll assume that updateOutput() has been called and self.weights + -- is non-nil. if input:dim() == 1 then self._target = self._target or input.new(1) if type(target) == 'number' then |