diff options
author | nicholas-leonard <nick@nikopia.org> | 2014-05-28 06:52:34 +0400 |
---|---|---|
committer | nicholas-leonard <nick@nikopia.org> | 2014-05-28 06:52:34 +0400 |
commit | 5ac4161d9fd1294e6101fe4f4a73881d335ab334 (patch) | |
tree | fa3815fb68d4438c1faee8080506d4d5a21456b0 /SoftMaxTree.lua | |
parent | cbbc338950d79fe65ed3f9b5d986bcb8cb1b8407 (diff) |
added CudaTensor support to SoftMaxTree:type()
Diffstat (limited to 'SoftMaxTree.lua')
-rw-r--r-- | SoftMaxTree.lua | 12 |
1 files changed, 8 insertions, 4 deletions
diff --git a/SoftMaxTree.lua b/SoftMaxTree.lua index 8edc94b..415204f 100644 --- a/SoftMaxTree.lua +++ b/SoftMaxTree.lua @@ -194,16 +194,20 @@ function SoftMaxTree:zeroGradParameters(partial) end function SoftMaxTree:type(type) - if type and (type == 'torch.FloatTensor' or type == 'torch.DoubleTensor') then + if type and (type == 'torch.FloatTensor' or type == 'torch.DoubleTensor' or type == 'torch.CudaTensor') then self.weight = self.weight:type(type) self.bias = self.bias:type(type) self.gradWeight = self.gradWeight:type(type) self.gradBias = self.gradBias:type(type) - self._linearOutput = self._linearOutput:type(type) - self._linearGradOutput = self._linearGradOutput:type(type) - self._logSoftMaxOutput = self._logSoftMaxOutput:type(type) + self._nodeBuffer = self._nodeBuffer:type(type) + self._multiBuffer = self._multiBuffer:type(type) self.output = self.output:type(type) self.gradInput = self.gradInput:type(type) + if (type == 'torch.CudaTensor') then + -- we need these to be both on GPU and CPU + self.parentChildren_d = self.parentChildren:type(type) + self.childParent_d = self.childParent:type(type) + end end return self end |