diff options
author | Greg Heinrich <gheinrich@nvidia.com> | 2015-11-16 22:53:46 +0300 |
---|---|---|
committer | Greg Heinrich <gheinrich@nvidia.com> | 2015-11-16 22:53:46 +0300 |
commit | e714f9f6a85e8058cdb1e3a9e8c991fabf5ebddf (patch) | |
tree | 71a6bc7abe0b4c619c5c407ecc61bfec711219bd /DataParallelTable.lua | |
parent | 8089e0213e8751b04400e51b9e8067a1126085f7 (diff) |
Allow move to CUDA from top-level module
When defining a network in Torch it is customary to move it
to the currently selected GPU through a single call to the
top-level module's cuda() method.
This patch allows this by merely asserting in the type()
method that the target type is 'torch.CudaTensor', rather
that dropping an error unconditionally.
In this patch we are also conveniently moving any added
module to CUDA, since it only makes sense to add CUDA
modules to a DataParallelTable container.
Diffstat (limited to 'DataParallelTable.lua')
-rw-r--r-- | DataParallelTable.lua | 4 |
1 files changed, 2 insertions, 2 deletions
diff --git a/DataParallelTable.lua b/DataParallelTable.lua index 1e25ac8..edb84c8 100644 --- a/DataParallelTable.lua +++ b/DataParallelTable.lua @@ -210,7 +210,7 @@ function DataParallelTable:add(module, gpuid) assert(gpuid <= cutorch.getDeviceCount() and gpuid >= 1) assert(#self.modules == #self.gpuAssignments) - self.modules[#self.modules + 1] = module + self.modules[#self.modules + 1] = module:cuda() self.gpuAssignments[#self.gpuAssignments + 1] = gpuid return self @@ -482,7 +482,7 @@ function DataParallelTable:name() end function DataParallelTable:type(typeStr) - error("type() not supported for DataParallelTable.") + assert(typeStr == 'torch.CudaTensor', "DataParallelTable supports only torch.CudaTensor type.") end function DataParallelTable:_calculateSliceRange(tensor, id, total) |