
github.com/torch/cunn.git
author    Greg Heinrich <gheinrich@nvidia.com>  2015-11-16 22:53:46 +0300
committer Greg Heinrich <gheinrich@nvidia.com>  2015-11-16 22:53:46 +0300
commit    e714f9f6a85e8058cdb1e3a9e8c991fabf5ebddf
tree      71a6bc7abe0b4c619c5c407ecc61bfec711219bd  /DataParallelTable.lua
parent    8089e0213e8751b04400e51b9e8067a1126085f7
Allow move to CUDA from top-level module
When defining a network in Torch, it is customary to move it to the currently selected GPU with a single call to the top-level module's cuda() method. This patch enables that by having type() merely assert that the target type is 'torch.CudaTensor', rather than raising an error unconditionally. The patch also moves any added module to CUDA inside add(), since it only makes sense to add CUDA modules to a DataParallelTable container.
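
For context, a minimal usage sketch of the workflow this change enables (illustrative only, not part of the commit; it assumes a machine with at least two GPUs, arbitrary layer sizes, and the add(module, gpuid) signature shown in the hunk below):

require 'cunn'

-- Illustrative network; layers and sizes are arbitrary.
local mlp = nn.Sequential():add(nn.Linear(10, 10)):add(nn.ReLU())

local dpt = nn.DataParallelTable(1)   -- split input batches along dimension 1
dpt:add(mlp, 1)                       -- add() now moves the module to CUDA itself
dpt:add(mlp:clone(), 2)               -- replica assigned to the second GPU

local net = nn.Sequential():add(dpt)
net:cuda()  -- cuda() forwards to type('torch.CudaTensor'), which the patched type() now accepts
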
Diffstat (limited to 'DataParallelTable.lua')
 DataParallelTable.lua | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/DataParallelTable.lua b/DataParallelTable.lua
index 1e25ac8..edb84c8 100644
--- a/DataParallelTable.lua
+++ b/DataParallelTable.lua
@@ -210,7 +210,7 @@ function DataParallelTable:add(module, gpuid)
    assert(gpuid <= cutorch.getDeviceCount() and gpuid >= 1)
    assert(#self.modules == #self.gpuAssignments)
 
-   self.modules[#self.modules + 1] = module
+   self.modules[#self.modules + 1] = module:cuda()
    self.gpuAssignments[#self.gpuAssignments + 1] = gpuid
 
    return self
@@ -482,7 +482,7 @@ function DataParallelTable:name()
 end
 
 function DataParallelTable:type(typeStr)
-   error("type() not supported for DataParallelTable.")
+   assert(typeStr == 'torch.CudaTensor', "DataParallelTable supports only torch.CudaTensor type.")
 end
 
 function DataParallelTable:_calculateSliceRange(tensor, id, total)