diff options
author | nicholas-leonard <nick@nikopia.org> | 2014-06-04 20:22:26 +0400 |
---|---|---|
committer | nicholas-leonard <nick@nikopia.org> | 2014-06-04 20:22:26 +0400 |
commit | 175ce432b77108672e33ae67caae1ffa50afd062 (patch) | |
tree | 96b2021f15992a0ceca1363fcd66055e6df7f7f0 /SoftMaxTree.lua | |
parent | 57d7316d80bc6611a7c6d4088283f3117dc0e569 (diff) |
fixed SoftMaxTree:type() bug
Diffstat (limited to 'SoftMaxTree.lua')
-rw-r--r-- | SoftMaxTree.lua | 8 |
1 files changed, 5 insertions, 3 deletions
diff --git a/SoftMaxTree.lua b/SoftMaxTree.lua index fc2b55c..4fa2d9a 100644 --- a/SoftMaxTree.lua +++ b/SoftMaxTree.lua @@ -242,16 +242,18 @@ function SoftMaxTree:type(type) self._nodeBuffer = self._nodeBuffer:type(type) self._multiBuffer = self._multiBuffer:type(type) self.output = self.output:type(type) - self.gradInput = self.gradInput:type(type) - self.parentChildren = self.parentChildren:type(type) - self.childParent = self.childParent:type(type) + self.gradInput = self.gradInput:type(type) if (type == 'torch.CudaTensor') then -- cunnx needs this for filling self.updates self._nodeUpdateHost = torch.IntTensor() self._nodeUpdateCuda = torch.CudaTensor() + self.parentChildren = self.parentChildren:type(type) + self.childParent = self.childParent:type(type) elseif self.nodeUpdateHost then self._nodeUpdateHost = nil self._nodeUpdateCuda = nil + self.parentChildren = self.parentChildren:type('torch.IntTensor') + self.childParent = self.childParent:type('torch.IntTensor') end self.batchSize = 0 --so that buffers are resized end |