Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/soumith/cudnn.torch.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorsoumith <soumith@fb.com>2014-11-18 05:59:09 +0300
committersoumith <soumith@fb.com>2014-11-18 05:59:09 +0300
commit56b6d5426509b4d0bef7d2648fad72ab4c122c84 (patch)
treedb8a21f36fe03093c0b383a5cf6523ab4e97de13 /init.lua
parent7b21377ffe067a86917715f522eb544239c2ec6c (diff)
adding non-batch mode
Diffstat (limited to 'init.lua')
-rw-r--r--init.lua9
1 files changed, 5 insertions, 4 deletions
diff --git a/init.lua b/init.lua
index c27aec5..66fb73d 100644
--- a/init.lua
+++ b/init.lua
@@ -27,25 +27,26 @@ local function destroy(handle)
local currentDevice = cutorch.getDevice()
for i=1,numDevices do
cutorch.setDevice(i)
- errcheck('cudnnDestroy', handle[i-1]);
+ errcheck('cudnnDestroy', handle[i-1]);
end
cutorch.setDevice(currentDevice)
end
ffi.gc(cudnn.handle, destroy)
function cudnn.toDescriptor(t)
+ if t:dim() == 3 then t = t:view(1, t:size(1), t:size(2), t:size(3)) end
assert(t:dim() == 4);
assert(torch.typename(t) == 'torch.CudaTensor')
local descriptor = ffi.new('struct cudnnTensor4dStruct*[1]')
-- create descriptor
errcheck('cudnnCreateTensor4dDescriptor', descriptor)
-- set gc hook
- local function destroy(d)
- errcheck('cudnnDestroyTensor4dDescriptor', d[0]);
+ local function destroy(d)
+ errcheck('cudnnDestroyTensor4dDescriptor', d[0]);
end
ffi.gc(descriptor, destroy)
-- set descriptor
- errcheck('cudnnSetTensor4dDescriptorEx', descriptor[0], 'CUDNN_DATA_FLOAT',
+ errcheck('cudnnSetTensor4dDescriptorEx', descriptor[0], 'CUDNN_DATA_FLOAT',
t:size(1), t:size(2), t:size(3), t:size(4),
t:stride(1), t:stride(2), t:stride(3), t:stride(4))
return descriptor