diff options
author | soumith <soumith@fb.com> | 2014-11-18 05:59:09 +0300 |
---|---|---|
committer | soumith <soumith@fb.com> | 2014-11-18 05:59:09 +0300 |
commit | 56b6d5426509b4d0bef7d2648fad72ab4c122c84 (patch) | |
tree | db8a21f36fe03093c0b383a5cf6523ab4e97de13 /init.lua | |
parent | 7b21377ffe067a86917715f522eb544239c2ec6c (diff) |
adding non-batch mode
Diffstat (limited to 'init.lua')
-rw-r--r-- | init.lua | 9 |
1 files changed, 5 insertions, 4 deletions
@@ -27,25 +27,26 @@ local function destroy(handle) local currentDevice = cutorch.getDevice() for i=1,numDevices do cutorch.setDevice(i) - errcheck('cudnnDestroy', handle[i-1]); + errcheck('cudnnDestroy', handle[i-1]); end cutorch.setDevice(currentDevice) end ffi.gc(cudnn.handle, destroy) function cudnn.toDescriptor(t) + if t:dim() == 3 then t = t:view(1, t:size(1), t:size(2), t:size(3)) end assert(t:dim() == 4); assert(torch.typename(t) == 'torch.CudaTensor') local descriptor = ffi.new('struct cudnnTensor4dStruct*[1]') -- create descriptor errcheck('cudnnCreateTensor4dDescriptor', descriptor) -- set gc hook - local function destroy(d) - errcheck('cudnnDestroyTensor4dDescriptor', d[0]); + local function destroy(d) + errcheck('cudnnDestroyTensor4dDescriptor', d[0]); end ffi.gc(descriptor, destroy) -- set descriptor - errcheck('cudnnSetTensor4dDescriptorEx', descriptor[0], 'CUDNN_DATA_FLOAT', + errcheck('cudnnSetTensor4dDescriptorEx', descriptor[0], 'CUDNN_DATA_FLOAT', t:size(1), t:size(2), t:size(3), t:size(4), t:stride(1), t:stride(2), t:stride(3), t:stride(4)) return descriptor |