Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/cutorch.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorgchanan <gregchanan@gmail.com>2016-12-29 22:23:45 +0300
committerSoumith Chintala <soumith@gmail.com>2016-12-29 22:23:45 +0300
commit1ac06689dba1a4a672ed1fb3c3117000a46d7af5 (patch)
tree5edb7b33c8ef506cbe488fe5d83de6349e42aaeb /init.lua
parent6b763fd55f9919ec2f1ccf58c962213e6fb755ea (diff)
Add THHalfTensor support to cutorch (#655)
* Add THHalfTensor support to cutorch.
Diffstat (limited to 'init.lua')
-rw-r--r--init.lua18
1 files changed, 16 insertions, 2 deletions
diff --git a/init.lua b/init.lua
index 09f00e3..fdb7b08 100644
--- a/init.lua
+++ b/init.lua
@@ -16,8 +16,8 @@ torch.CudaTensor.__tostring__ = torch.FloatTensor.__tostring__
torch.CudaDoubleStorage.__tostring__ = torch.DoubleStorage.__tostring__
torch.CudaDoubleTensor.__tostring__ = torch.DoubleTensor.__tostring__
if cutorch.hasHalf then
- torch.CudaHalfStorage.__tostring__ = torch.FloatStorage.__tostring__
- torch.CudaHalfTensor.__tostring__ = torch.FloatTensor.__tostring__
+ torch.CudaHalfStorage.__tostring__ = torch.HalfStorage.__tostring__
+ torch.CudaHalfTensor.__tostring__ = torch.HalfTensor.__tostring__
end
require('cutorch.Tensor')
@@ -57,6 +57,20 @@ function cutorch.createCudaHostTensor(...)
return torch.FloatTensor(storage, 1, size:storage())
end
+function cutorch.createCudaHostDoubleTensor(...)
+ local size = longTensorSize(...)
+ local storage = torch.DoubleStorage(cutorch.CudaHostAllocator, size:prod())
+ return torch.DoubleTensor(storage, 1, size:storage())
+end
+
+if cutorch.hasHalf then
+ function cutorch.createCudaHostHalfTensor(...)
+ local size = longTensorSize(...)
+ local storage = torch.HalfStorage(cutorch.CudaHostAllocator, size:prod())
+ return torch.HalfTensor(storage, 1, size:storage())
+ end
+ end
+
-- Creates a CudaTensor using the CudaUVAAllocator.
-- Accepts either a LongStorage or a sequence of numbers.
local function _createUVATensor(...)