diff options
author | nicholas-leonard <nick@nikopia.org> | 2015-05-27 19:25:51 +0300 |
---|---|---|
committer | nicholas-leonard <nick@nikopia.org> | 2015-05-27 19:25:51 +0300 |
commit | cece52f783fcf87b8fb6fb371d6f47fc19607964 (patch) | |
tree | 2f4bc537f0fba0a28712263b45762e0e6cb17cff | |
parent | 821674533a3077fd79061d50e73aca3055582edd (diff) |
SoftMaxTree:type() is compatible with dpnn
-rw-r--r-- | SoftMaxTree.lua | 78 |
1 files changed, 49 insertions, 29 deletions
diff --git a/SoftMaxTree.lua b/SoftMaxTree.lua index 6633359..27bc1f7 100644 --- a/SoftMaxTree.lua +++ b/SoftMaxTree.lua @@ -272,36 +272,56 @@ function SoftMaxTree:zeroGradParameters() end function SoftMaxTree:type(type) - if type and (type == 'torch.FloatTensor' or type == 'torch.DoubleTensor' or type == 'torch.CudaTensor') then - self.weight = self.weight:type(type) - self.bias = self.bias:type(type) - if not self.accUpdate then - self.gradWeight = self.gradWeight:type(type) - self.gradBias = self.gradBias:type(type) - end - self._nodeBuffer = self._nodeBuffer:type(type) - self._multiBuffer = self._multiBuffer:type(type) - self.output = self.output:type(type) - self._gradInput = self._gradInput:type(type) - if (type == 'torch.CudaTensor') then - -- cunnx needs this for filling self.updates - self._nodeUpdateHost = torch.IntTensor() - self._nodeUpdateCuda = torch.CudaTensor() - self._paramUpdateHost = torch.IntTensor() - self._paramUpdateCuda = torch.CudaTensor() - self.parentChildrenCuda = self.parentChildren:type(type) - self.childParentCuda = self.childParent:type(type) - self._gradTarget = self._gradTarget:type(type) - elseif self.nodeUpdateHost then - self._nodeUpdateHost = nil - self._nodeUpdateCuda = nil - self.parentChildren = self.parentChildren:type('torch.IntTensor') - self.childParent = self.childParent:type('torch.IntTensor') - self._gradTarget = self._gradTarget:type('torch.IntTensor') - end - self.gradInput = {self._gradInput, self._gradTarget} - self.batchSize = 0 --so that buffers are resized + if type == torch.type(self.weight) then + return self + end + + local hierarchy = self.hierarchy + self.hierarchy = nil + self._nodeUpdateHost = nil + self._nodeUpdateCuda = nil + self._paramUpdateHost = nil + self._paramUpdateCuda = nil + local parentChildren = self.parentChildren + self.parentChildren = nil + self.parentChildrenCuda = nil + local childParent = self.childParent + self.childParent = nil + self.childParentCuda = nil + local _gradTarget = self._gradTarget + self._gradTarget = nil + local childIds = self.childIds + self.childIds = nil + local parentIds = self.parentIds + self.parentIds = nil + + parent.type(self, type) + + self.hierarchy = hierarchy + self.parentChildren = parentChildren + self.childParent = childParent + self._gradTarget = _gradTarget + self.childIds = childIds + self.parentIds = parentIds + + if (type == 'torch.CudaTensor') then + -- cunnx needs this for filling self.updates + self._nodeUpdateHost = torch.IntTensor() + self._nodeUpdateCuda = torch.CudaTensor() + self._paramUpdateHost = torch.IntTensor() + self._paramUpdateCuda = torch.CudaTensor() + self.parentChildrenCuda = self.parentChildren:type(type) + self.childParentCuda = self.childParent:type(type) + self._gradTarget = self._gradTarget:type(type) + elseif self._nodeUpdateHost then + self._nodeUpdateHost = nil + self._nodeUpdateCuda = nil + self.parentChildren = self.parentChildren:type('torch.IntTensor') + self.childParent = self.childParent:type('torch.IntTensor') + self._gradTarget = self._gradTarget:type('torch.IntTensor') end + self.gradInput = {self._gradInput, self._gradTarget} + self.batchSize = 0 --so that buffers are resized return self end |