diff options
author | Gregory Chanan <gchanan@fb.com> | 2016-12-29 00:06:00 +0300 |
---|---|---|
committer | Gregory Chanan <gchanan@fb.com> | 2016-12-29 00:06:00 +0300 |
commit | 4a21b71b58f30cc7ce7474f1db54e1d6d55a41df (patch) | |
tree | f49d31b31baff1f3d18ee29b15d1da394eb4fc16 | |
parent | 4173b226cf3a0cccc5551a68e45c2afb29ffc9b3 (diff) |
Add support for Half tensors to DataParallelTable.
-rw-r--r-- | DataParallelTable.lua | 4 | ||||
-rw-r--r-- | test_DataParallelTable.lua | 2 |
2 files changed, 3 insertions, 3 deletions
diff --git a/DataParallelTable.lua b/DataParallelTable.lua index e0194d4..9e07978 100644 --- a/DataParallelTable.lua +++ b/DataParallelTable.lua @@ -550,8 +550,8 @@ function DataParallelTable:_distributeTensorRecursive(dst, src, idx, n) assert(torch.isTensor(src), 'input must be a tensor or table of tensors') if self.typeStr == 'torch.CudaHalfTensor' then - assert(false, - 'Half Tensors not supported yet by DataParallelTable') + assert(src:type() == self.typeStr or src:type() == 'torch.HalfTensor', + 'input must be a CudaHalf or Half tensor') elseif self.typeStr == 'torch.CudaDoubleTensor' then assert(src:type() == self.typeStr or src:type() == 'torch.DoubleTensor', 'input must be a CudaDouble or Double tensor') diff --git a/test_DataParallelTable.lua b/test_DataParallelTable.lua index 2b25cf2..ec91b78 100644 --- a/test_DataParallelTable.lua +++ b/test_DataParallelTable.lua @@ -25,7 +25,7 @@ local t2cpu = { local function checkHalf() if cutorch.hasHalf then table.insert(typenames, 'torch.CudaHalfTensor') - t2cpu['torch.CudaHalfTensor'] = 'torch.FloatTensor' + t2cpu['torch.CudaHalfTensor'] = 'torch.HalfTensor' end end |