Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/clementfarabet/lua---nnx.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authornicholas-leonard <nick@nikopia.org>2014-06-04 20:22:26 +0400
committernicholas-leonard <nick@nikopia.org>2014-06-04 20:22:26 +0400
commit175ce432b77108672e33ae67caae1ffa50afd062 (patch)
tree96b2021f15992a0ceca1363fcd66055e6df7f7f0 /SoftMaxTree.lua
parent57d7316d80bc6611a7c6d4088283f3117dc0e569 (diff)
fixed SoftMaxTree:type() bug
Diffstat (limited to 'SoftMaxTree.lua')
-rw-r--r--SoftMaxTree.lua8
1 files changed, 5 insertions, 3 deletions
diff --git a/SoftMaxTree.lua b/SoftMaxTree.lua
index fc2b55c..4fa2d9a 100644
--- a/SoftMaxTree.lua
+++ b/SoftMaxTree.lua
@@ -242,16 +242,18 @@ function SoftMaxTree:type(type)
self._nodeBuffer = self._nodeBuffer:type(type)
self._multiBuffer = self._multiBuffer:type(type)
self.output = self.output:type(type)
- self.gradInput = self.gradInput:type(type)
- self.parentChildren = self.parentChildren:type(type)
- self.childParent = self.childParent:type(type)
+ self.gradInput = self.gradInput:type(type)
if (type == 'torch.CudaTensor') then
-- cunnx needs this for filling self.updates
self._nodeUpdateHost = torch.IntTensor()
self._nodeUpdateCuda = torch.CudaTensor()
+ self.parentChildren = self.parentChildren:type(type)
+ self.childParent = self.childParent:type(type)
elseif self.nodeUpdateHost then
self._nodeUpdateHost = nil
self._nodeUpdateCuda = nil
+ self.parentChildren = self.parentChildren:type('torch.IntTensor')
+ self.childParent = self.childParent:type('torch.IntTensor')
end
self.batchSize = 0 --so that buffers are resized
end