diff options
author | nicholas-leonard <nick@nikopia.org> | 2014-07-15 21:38:16 +0400 |
---|---|---|
committer | nicholas-leonard <nick@nikopia.org> | 2014-07-15 21:38:16 +0400 |
commit | 91b9753dc1ad5254940ec7dce30d4962c5489ba9 (patch) | |
tree | 8ba6ec9054e765c0a67771bf035f217094ca6234 /SoftMaxTree.lua | |
parent | 4e0fbda0f0e8e0c691cf189a8a7550ed65bdf184 (diff) |
SoftMaxTree works with ConcatTable
Diffstat (limited to 'SoftMaxTree.lua')
-rw-r--r-- | SoftMaxTree.lua | 4 |
1 files changed, 3 insertions, 1 deletions
diff --git a/SoftMaxTree.lua b/SoftMaxTree.lua index 6b783a0..17c0ca8 100644 --- a/SoftMaxTree.lua +++ b/SoftMaxTree.lua @@ -137,7 +137,7 @@ function SoftMaxTree:__init(inputSize, hierarchy, rootId, accUpdate, verbose) self.batchSize = 0 self._gradInput = torch.Tensor() - self._gradTarget = torch.Tensor() -- dummy + self._gradTarget = torch.IntTensor() -- dummy self.gradInput = {self._gradInput, self._gradTarget} self:reset() @@ -160,6 +160,8 @@ function SoftMaxTree:updateOutput(inputTable) self._nodeBuffer:resize(self.maxFamily) self._multiBuffer:resize(input:size(1)*self.maxFamilyPath) self.batchSize = input:size(1) + -- so that it works within nn.ConcatTable : + self._gradTarget:resizeAs(target):zero() if self._nodeUpdateHost then self._nodeUpdateHost:resize(input:size(1),self.maxDept) self._nodeUpdateCuda:resize(input:size(1),self.maxDept) |