github.com/torch/nn.git
author    Jonathan Tompson <jonathantompson@gmail.com>    2015-07-24 23:57:19 +0300
committer Jonathan Tompson <jonathantompson@gmail.com>    2015-07-24 23:57:19 +0300
commit    6bcdc6f74873a214d90ac1418da86dc0f75048fc (patch)
tree      19f94fa780b8fb9478d512316e86c50686479c81
parent    3c13b0d41281990e3f1d4dd8a1ae45b8d5ca0f40 (diff)
Added weight support to ClassNLLCriterion cuda.
 ClassNLLCriterion.lua | 22 ++++++++++++++++++++--
 1 file changed, 20 insertions(+), 2 deletions(-)
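A minimal usage sketch (not part of the commit) of how a weighted ClassNLLCriterion could be exercised on CUDA tensors after this change; the class count, weight values, and batch below are illustrative assumptions, and it assumes nn and cunn are available.

require 'nn'
require 'cunn'  -- needed for the CUDA code path touched by this commit

local nClasses = 10

-- Per-class weights; the values here are illustrative.
local weights = torch.Tensor(nClasses):fill(1)
weights[3] = 0.5  -- down-weight class 3

-- Build the criterion with weights and move it to the GPU.
local criterion = nn.ClassNLLCriterion(weights):cuda()

-- Inputs are log-probabilities, e.g. the output of nn.LogSoftMax.
local input = torch.Tensor(4, nClasses):fill(-math.log(nClasses)):cuda()
local target = torch.Tensor({1, 3, 5, 7}):cuda()

local loss = criterion:forward(input, target)
local gradInput = criterion:backward(input, target)
print(loss)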
diff --git a/ClassNLLCriterion.lua b/ClassNLLCriterion.lua
index 409ac59..bc2a2e9 100644
--- a/ClassNLLCriterion.lua
+++ b/ClassNLLCriterion.lua
@@ -21,7 +21,23 @@ end
function ClassNLLCriterion:updateOutput(input, target)
- if input:type() == 'torch.CudaTensor' and not self.weights then
+ if input:type() == 'torch.CudaTensor' then
+ if self.weights == nil then
+ -- The CUDA implementation requires self.weights be non-nil
+ self.weights = torch.CudaTensor()
+ end
+ assert(self.weights:dim() == 0 or self.weights:dim() == 1,
+ 'weights must be 1D or empty')
+ -- The CUDA code won't check the weight size, so we must do it here.
+ if self.weights:dim() == 1 then
+ if input:dim() == 1 then
+ assert(self.weights:size(1) == input:size(1),
+ 'Wrong number of weights')
+ else
+ assert(self.weights:size(1) == input:size(2),
+ 'Wrong number of weights')
+ end
+ end
if input:dim() == 1 then
self._target = self._target or input.new(1)
if type(target) == 'number' then
@@ -66,7 +82,9 @@ function ClassNLLCriterion:updateGradInput(input, target)
self.gradInput:resizeAs(input)
self.gradInput:zero()
- if input:type() == 'torch.CudaTensor' and not self.weights then
+ if input:type() == 'torch.CudaTensor' then
+ -- Note: we'll assume that updateOutput() has been called and self.weights
+ -- is non-nil.
if input:dim() == 1 then
self._target = self._target or input.new(1)
if type(target) == 'number' then