
github.com/clementfarabet/lua---nnx.git
author    nicholas-leonard <nick@nikopia.org>  2015-05-27 19:25:51 +0300
committer nicholas-leonard <nick@nikopia.org>  2015-05-27 19:25:51 +0300
commit    cece52f783fcf87b8fb6fb371d6f47fc19607964 (patch)
tree      2f4bc537f0fba0a28712263b45762e0e6cb17cff
parent    821674533a3077fd79061d50e73aca3055582edd (diff)
SoftMaxTree:type() is compatible with dpnn
-rw-r--r--  SoftMaxTree.lua  78
1 file changed, 49 insertions(+), 29 deletions(-)
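The old implementation cast each tensor field by hand and only recognized three tensor types. The rewrite below instead delegates the casting to the parent class (nn.Module), which is presumably what makes the module compatible with dpnn's generic type handling: fields that must keep their integer type (the hierarchy table and the IntTensor node-index tensors) are detached before the parent call and reattached unchanged afterwards. Here is a minimal sketch of that stash-and-restore pattern; the module name nn.MyModule and its protected field self.indices are hypothetical, standing in for the six fields (hierarchy, parentChildren, childParent, _gradTarget, childIds, parentIds) stashed in the actual diff:

require 'nn'

local MyModule, parent = torch.class('nn.MyModule', 'nn.Module')

function MyModule:type(type)
   -- already the requested type: nothing to do (assumes self.weight exists)
   if type == torch.type(self.weight) then
      return self
   end
   -- detach fields that the generic cast must not touch
   local indices = self.indices
   self.indices = nil
   -- let nn.Module:type() cast every remaining tensor field
   parent.type(self, type)
   -- reattach the protected fields unchanged
   self.indices = indices
   return self
end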
diff --git a/SoftMaxTree.lua b/SoftMaxTree.lua
index 6633359..27bc1f7 100644
--- a/SoftMaxTree.lua
+++ b/SoftMaxTree.lua
@@ -272,36 +272,56 @@ function SoftMaxTree:zeroGradParameters()
end
function SoftMaxTree:type(type)
- if type and (type == 'torch.FloatTensor' or type == 'torch.DoubleTensor' or type == 'torch.CudaTensor') then
- self.weight = self.weight:type(type)
- self.bias = self.bias:type(type)
- if not self.accUpdate then
- self.gradWeight = self.gradWeight:type(type)
- self.gradBias = self.gradBias:type(type)
- end
- self._nodeBuffer = self._nodeBuffer:type(type)
- self._multiBuffer = self._multiBuffer:type(type)
- self.output = self.output:type(type)
- self._gradInput = self._gradInput:type(type)
- if (type == 'torch.CudaTensor') then
- -- cunnx needs this for filling self.updates
- self._nodeUpdateHost = torch.IntTensor()
- self._nodeUpdateCuda = torch.CudaTensor()
- self._paramUpdateHost = torch.IntTensor()
- self._paramUpdateCuda = torch.CudaTensor()
- self.parentChildrenCuda = self.parentChildren:type(type)
- self.childParentCuda = self.childParent:type(type)
- self._gradTarget = self._gradTarget:type(type)
- elseif self.nodeUpdateHost then
- self._nodeUpdateHost = nil
- self._nodeUpdateCuda = nil
- self.parentChildren = self.parentChildren:type('torch.IntTensor')
- self.childParent = self.childParent:type('torch.IntTensor')
- self._gradTarget = self._gradTarget:type('torch.IntTensor')
- end
- self.gradInput = {self._gradInput, self._gradTarget}
- self.batchSize = 0 --so that buffers are resized
+ if type == torch.type(self.weight) then
+ return self
+ end
+
+ local hierarchy = self.hierarchy
+ self.hierarchy = nil
+ self._nodeUpdateHost = nil
+ self._nodeUpdateCuda = nil
+ self._paramUpdateHost = nil
+ self._paramUpdateCuda = nil
+ local parentChildren = self.parentChildren
+ self.parentChildren = nil
+ self.parentChildrenCuda = nil
+ local childParent = self.childParent
+ self.childParent = nil
+ self.childParentCuda = nil
+ local _gradTarget = self._gradTarget
+ self._gradTarget = nil
+ local childIds = self.childIds
+ self.childIds = nil
+ local parentIds = self.parentIds
+ self.parentIds = nil
+
+ parent.type(self, type)
+
+ self.hierarchy = hierarchy
+ self.parentChildren = parentChildren
+ self.childParent = childParent
+ self._gradTarget = _gradTarget
+ self.childIds = childIds
+ self.parentIds = parentIds
+
+ if (type == 'torch.CudaTensor') then
+ -- cunnx needs this for filling self.updates
+ self._nodeUpdateHost = torch.IntTensor()
+ self._nodeUpdateCuda = torch.CudaTensor()
+ self._paramUpdateHost = torch.IntTensor()
+ self._paramUpdateCuda = torch.CudaTensor()
+ self.parentChildrenCuda = self.parentChildren:type(type)
+ self.childParentCuda = self.childParent:type(type)
+ self._gradTarget = self._gradTarget:type(type)
+ elseif self._nodeUpdateHost then
+ self._nodeUpdateHost = nil
+ self._nodeUpdateCuda = nil
+ self.parentChildren = self.parentChildren:type('torch.IntTensor')
+ self.childParent = self.childParent:type('torch.IntTensor')
+ self._gradTarget = self._gradTarget:type('torch.IntTensor')
end
+ self.gradInput = {self._gradInput, self._gradTarget}
+ self.batchSize = 0 --so that buffers are resized
return self
end
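A usage example of the new behavior, hedged: the constructor signature nn.SoftMaxTree(inputSize, hierarchy, rootId) and the hierarchy layout follow the nnx README as I recall it, so treat them as assumptions rather than a verified API:

require 'nnx'

-- hypothetical two-level tree: root -1 has children {1, 2}, node 1 has {3, 4}
local hierarchy = {
   [-1] = torch.IntTensor{1, 2},
   [1]  = torch.IntTensor{3, 4},
}
local smt = nn.SoftMaxTree(10, hierarchy, -1)

smt:type('torch.FloatTensor')
print(torch.type(smt.weight))          -- torch.FloatTensor, cast by parent.type
print(torch.type(smt.parentChildren))  -- torch.IntTensor, untouched by the cast

Because the IntTensor index fields are stashed around the parent.type() call, a float or double cast leaves them intact, while the CUDA branch still builds the separate *Cuda copies that cunnx needs.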