Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/clementfarabet/lua---nnx.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authornicholas-leonard <nick@nikopia.org>2014-07-15 21:38:16 +0400
committernicholas-leonard <nick@nikopia.org>2014-07-15 21:38:16 +0400
commit91b9753dc1ad5254940ec7dce30d4962c5489ba9 (patch)
tree8ba6ec9054e765c0a67771bf035f217094ca6234 /SoftMaxTree.lua
parent4e0fbda0f0e8e0c691cf189a8a7550ed65bdf184 (diff)
SoftMaxTree works with ConcatTable
Diffstat (limited to 'SoftMaxTree.lua')
-rw-r--r--SoftMaxTree.lua4
1 files changed, 3 insertions, 1 deletions
diff --git a/SoftMaxTree.lua b/SoftMaxTree.lua
index 6b783a0..17c0ca8 100644
--- a/SoftMaxTree.lua
+++ b/SoftMaxTree.lua
@@ -137,7 +137,7 @@ function SoftMaxTree:__init(inputSize, hierarchy, rootId, accUpdate, verbose)
self.batchSize = 0
self._gradInput = torch.Tensor()
- self._gradTarget = torch.Tensor() -- dummy
+ self._gradTarget = torch.IntTensor() -- dummy
self.gradInput = {self._gradInput, self._gradTarget}
self:reset()
@@ -160,6 +160,8 @@ function SoftMaxTree:updateOutput(inputTable)
self._nodeBuffer:resize(self.maxFamily)
self._multiBuffer:resize(input:size(1)*self.maxFamilyPath)
self.batchSize = input:size(1)
+ -- so that it works within nn.ConcatTable :
+ self._gradTarget:resizeAs(target):zero()
if self._nodeUpdateHost then
self._nodeUpdateHost:resize(input:size(1),self.maxDept)
self._nodeUpdateCuda:resize(input:size(1),self.maxDept)