
github.com/clementfarabet/lua---nnx.git
author     nicholas-leonard <nick@nikopia.org>  2014-05-28 06:52:34 +0400
committer  nicholas-leonard <nick@nikopia.org>  2014-05-28 06:52:34 +0400
commit     5ac4161d9fd1294e6101fe4f4a73881d335ab334 (patch)
tree       fa3815fb68d4438c1faee8080506d4d5a21456b0 /SoftMaxTree.lua
parent     cbbc338950d79fe65ed3f9b5d986bcb8cb1b8407 (diff)
added CudaTensor support to SoftMaxTree:type()
Diffstat (limited to 'SoftMaxTree.lua')
-rw-r--r--  SoftMaxTree.lua | 12
1 file changed, 8 insertions(+), 4 deletions(-)
diff --git a/SoftMaxTree.lua b/SoftMaxTree.lua
index 8edc94b..415204f 100644
--- a/SoftMaxTree.lua
+++ b/SoftMaxTree.lua
@@ -194,16 +194,20 @@ function SoftMaxTree:zeroGradParameters(partial)
 end
 
 function SoftMaxTree:type(type)
-   if type and (type == 'torch.FloatTensor' or type == 'torch.DoubleTensor') then
+   if type and (type == 'torch.FloatTensor' or type == 'torch.DoubleTensor' or type == 'torch.CudaTensor') then
       self.weight = self.weight:type(type)
       self.bias = self.bias:type(type)
       self.gradWeight = self.gradWeight:type(type)
       self.gradBias = self.gradBias:type(type)
-      self._linearOutput = self._linearOutput:type(type)
-      self._linearGradOutput = self._linearGradOutput:type(type)
-      self._logSoftMaxOutput = self._logSoftMaxOutput:type(type)
+      self._nodeBuffer = self._nodeBuffer:type(type)
+      self._multiBuffer = self._multiBuffer:type(type)
       self.output = self.output:type(type)
       self.gradInput = self.gradInput:type(type)
+      if (type == 'torch.CudaTensor') then
+         -- we need these to be both on GPU and CPU
+         self.parentChildren_d = self.parentChildren:type(type)
+         self.childParent_d = self.childParent:type(type)
+      end
    end
    return self
 end
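
Below is a minimal usage sketch of what this commit enables: converting a SoftMaxTree to CUDA via :type('torch.CudaTensor') (equivalently, the standard :cuda() shorthand). The toy hierarchy and sizes are illustrative, not taken from the repository, and the sketch assumes the nnx and cunn packages are installed:

-- illustrative sketch, not from the commit
require 'nnx'
require 'cunn'

-- toy 2-level hierarchy: root node 1 has children {2, 3}
local hierarchy = {[1] = torch.IntTensor{2, 3}}
local smt = nn.SoftMaxTree(10, hierarchy, 1)  -- inputSize=10, rootId=1

-- after this commit, the conversion moves weights, biases and buffers to
-- the GPU, and also builds parentChildren_d/childParent_d so the node
-- index tables are available on both CPU and GPU
smt:type('torch.CudaTensor')

-- forward takes {input, target}; here we assume targets move to the GPU too
local input = torch.Tensor(4, 10):uniform():cuda()
local target = torch.Tensor{2, 3, 2, 3}:cuda()
local output = smt:forward{input, target}  -- log-likelihood of each target leaf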