Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/cunn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGregory Chanan <gchanan@fb.com>2016-12-29 00:06:00 +0300
committerGregory Chanan <gchanan@fb.com>2016-12-29 00:06:00 +0300
commit4a21b71b58f30cc7ce7474f1db54e1d6d55a41df (patch)
treef49d31b31baff1f3d18ee29b15d1da394eb4fc16
parent4173b226cf3a0cccc5551a68e45c2afb29ffc9b3 (diff)
Add support for Half tensors to DataParallelTable.
-rw-r--r--DataParallelTable.lua4
-rw-r--r--test_DataParallelTable.lua2
2 files changed, 3 insertions, 3 deletions
diff --git a/DataParallelTable.lua b/DataParallelTable.lua
index e0194d4..9e07978 100644
--- a/DataParallelTable.lua
+++ b/DataParallelTable.lua
@@ -550,8 +550,8 @@ function DataParallelTable:_distributeTensorRecursive(dst, src, idx, n)
assert(torch.isTensor(src), 'input must be a tensor or table of tensors')
if self.typeStr == 'torch.CudaHalfTensor' then
- assert(false,
- 'Half Tensors not supported yet by DataParallelTable')
+ assert(src:type() == self.typeStr or src:type() == 'torch.HalfTensor',
+ 'input must be a CudaHalf or Half tensor')
elseif self.typeStr == 'torch.CudaDoubleTensor' then
assert(src:type() == self.typeStr or src:type() == 'torch.DoubleTensor',
'input must be a CudaDouble or Double tensor')
diff --git a/test_DataParallelTable.lua b/test_DataParallelTable.lua
index 2b25cf2..ec91b78 100644
--- a/test_DataParallelTable.lua
+++ b/test_DataParallelTable.lua
@@ -25,7 +25,7 @@ local t2cpu = {
local function checkHalf()
if cutorch.hasHalf then
table.insert(typenames, 'torch.CudaHalfTensor')
- t2cpu['torch.CudaHalfTensor'] = 'torch.FloatTensor'
+ t2cpu['torch.CudaHalfTensor'] = 'torch.HalfTensor'
end
end