diff options
author | nicholas-leonard <nick@nikopia.org> | 2014-07-09 02:33:19 +0400 |
---|---|---|
committer | nicholas-leonard <nick@nikopia.org> | 2014-07-09 02:33:19 +0400 |
commit | 037c8cdbac310be57ffb921492daf286806020bc (patch) | |
tree | 4f6c36bcbae256f3e97114fb65843be47f62cf47 /SoftMaxTree.lua | |
parent | 11e4e58f22700b8fb4ed033243cb74b0bc14dd53 (diff) |
SoftMaxTree.gradInput is a table (like input)
Diffstat (limited to 'SoftMaxTree.lua')
-rw-r--r-- | SoftMaxTree.lua | 7 |
1 files changed, 6 insertions, 1 deletions
diff --git a/SoftMaxTree.lua b/SoftMaxTree.lua index 5db6a10..d074063 100644 --- a/SoftMaxTree.lua +++ b/SoftMaxTree.lua @@ -136,6 +136,10 @@ function SoftMaxTree:__init(inputSize, hierarchy, rootId, accUpdate, verbose) self.batchSize = 0 + self._gradInput = torch.Tensor() + self._gradTarget = torch.Tensor() -- dummy + self.gradInput = {self._gradInput, self._gradTarget} + self:reset() end @@ -253,7 +257,8 @@ function SoftMaxTree:type(type) self._nodeBuffer = self._nodeBuffer:type(type) self._multiBuffer = self._multiBuffer:type(type) self.output = self.output:type(type) - self.gradInput = self.gradInput:type(type) + self._gradInput = self._gradInput:type(type) + self.gradInput = {self._gradInput, self._gradTarget} if (type == 'torch.CudaTensor') then -- cunnx needs this for filling self.updates self._nodeUpdateHost = torch.IntTensor() |