Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/clementfarabet/lua---nnx.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authornicholas-leonard <nick@nikopia.org>2014-07-09 02:33:19 +0400
committernicholas-leonard <nick@nikopia.org>2014-07-09 02:33:19 +0400
commit037c8cdbac310be57ffb921492daf286806020bc (patch)
tree4f6c36bcbae256f3e97114fb65843be47f62cf47 /SoftMaxTree.lua
parent11e4e58f22700b8fb4ed033243cb74b0bc14dd53 (diff)
SoftMaxTree.gradInput is a table (like input)
Diffstat (limited to 'SoftMaxTree.lua')
-rw-r--r--SoftMaxTree.lua7
1 files changed, 6 insertions, 1 deletions
diff --git a/SoftMaxTree.lua b/SoftMaxTree.lua
index 5db6a10..d074063 100644
--- a/SoftMaxTree.lua
+++ b/SoftMaxTree.lua
@@ -136,6 +136,10 @@ function SoftMaxTree:__init(inputSize, hierarchy, rootId, accUpdate, verbose)
self.batchSize = 0
+ self._gradInput = torch.Tensor()
+ self._gradTarget = torch.Tensor() -- dummy
+ self.gradInput = {self._gradInput, self._gradTarget}
+
self:reset()
end
@@ -253,7 +257,8 @@ function SoftMaxTree:type(type)
self._nodeBuffer = self._nodeBuffer:type(type)
self._multiBuffer = self._multiBuffer:type(type)
self.output = self.output:type(type)
- self.gradInput = self.gradInput:type(type)
+ self._gradInput = self._gradInput:type(type)
+ self.gradInput = {self._gradInput, self._gradTarget}
if (type == 'torch.CudaTensor') then
-- cunnx needs this for filling self.updates
self._nodeUpdateHost = torch.IntTensor()