Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/cutorch.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDominik Grewe <dominikg@google.com>2015-10-15 21:32:10 +0300
committerDominik Grewe <dominikg@google.com>2015-10-15 21:32:10 +0300
commitf30f26ae215620fb5e4bf71cd4e092b869e6fa8f (patch)
treeb7b06a5c273ef266de5faa5add6b3c002f60af12 /init.lua
parent17fb5a6c467100e7ed8746809a089ad9271cef6b (diff)
Allow passing no arguments to cutorch.createCudaHostTensor
Still creates a storage because of the non-default allocator.
Diffstat (limited to 'init.lua')
-rw-r--r--init.lua14
1 files changed, 11 insertions, 3 deletions
diff --git a/init.lua b/init.lua
index c39fd7a..50af27e 100644
--- a/init.lua
+++ b/init.lua
@@ -24,9 +24,17 @@ end
-- Creates a FloatTensor using the CudaHostAllocator.
-- Accepts either a LongStorage or a sequence of numbers.
function cutorch.createCudaHostTensor(...)
- local size = torch.LongTensor(torch.isStorage(...) and ... or {...})
- local storage = torch.FloatStorage(cutorch.CudaHostAllocator, size:prod())
- return torch.FloatTensor(storage, 1, size:storage())
+ local size
+ if not ... then
+ size = torch.LongTensor{0}
+ elseif torch.isStorage(...) then
+ size = torch.LongTensor(...)
+ else
+ size = torch.LongTensor{...}
+ end
+
+ local storage = torch.FloatStorage(cutorch.CudaHostAllocator, size:prod())
+ return torch.FloatTensor(storage, 1, size:storage())
end
cutorch.setHeapTracking(true)