Welcome to mirror list, hosted at ThFree Co, Russian Federation.

init.lua - github.com/soumith/cudnn.torch.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: 5867a9980cf9637a6c31ba46582ee60e1f689097 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
-- Module preamble: load the Torch packages this file relies on, create the
-- package table, and pull in the FFI declarations.
require 'cutorch'
require 'nn'
-- NOTE(review): assigned without `local`, so `cudnn` becomes a global table;
-- presumably this is the package namespace shared with the included files.
cudnn = {}
-- NOTE(review): ffi.lua is expected to populate `cudnn.C`, which is read on
-- the next line -- confirm in ffi.lua.
include 'ffi.lua'
-- Short local aliases used by the rest of this file.
local C = cudnn.C
local ffi = require 'ffi'

-- Invokes the function named `f` on the C namespace with the remaining
-- arguments. If the returned status is not CUDNN_STATUS_SUCCESS, the status
-- is turned into text via cudnnGetErrorString and raised as a Lua error.
-- Returns nothing on success.
local function errcheck(f, ...)
   local rc = C[f](...)
   if rc == 'CUDNN_STATUS_SUCCESS' then
      return
   end
   local msg = ffi.string(C.cudnnGetErrorString(rc))
   error('Error in CuDNN: ' .. msg)
end
-- Also exposed on the package table.
cudnn.errcheck = errcheck

-- Allocate one handle pointer per device reported by cutorch and store the
-- array on `cudnn.handle`; array slot k belongs to device k+1.
local numDevices = cutorch.getDeviceCount()
local currentDevice = cutorch.getDevice()
cudnn.handle = ffi.new('struct cudnnContext*[?]', numDevices)
-- Each handle is created while its own device is the selected one.
for slot = 0, numDevices - 1 do
   cutorch.setDevice(slot + 1)
   errcheck('cudnnCreate', cudnn.handle + slot)
end
-- Reselect the device that was active before the loop.
cutorch.setDevice(currentDevice)

-- Finalizer for the handle array: selects each device in turn, destroys the
-- handle stored in the matching slot, then reselects the device that was
-- active when the finalizer started.
local function destroy(handle)
   local previousDevice = cutorch.getDevice()
   for slot = 0, numDevices - 1 do
      cutorch.setDevice(slot + 1)
      errcheck('cudnnDestroy', handle[slot])
   end
   cutorch.setDevice(previousDevice)
end
-- Run `destroy` when the handle array is garbage collected.
ffi.gc(cudnn.handle, destroy)

--- Creates a tensor descriptor configured from the sizes and strides of `t`.
-- `t` must have the type name 'torch.CudaTensor' (asserted).
-- Returns a one-element cdata array holding the descriptor; a gc finalizer
-- attached to that array calls cudnnDestroyTensorDescriptor on it.
function cudnn.toDescriptor(t)
   assert(torch.typename(t) == 'torch.CudaTensor')

   -- Allocate the descriptor and register its cleanup right away.
   local desc = ffi.new('struct cudnnTensorStruct*[1]')
   errcheck('cudnnCreateTensorDescriptor', desc)
   ffi.gc(desc, function(d)
      errcheck('cudnnDestroyTensorDescriptor', d[0])
   end)

   -- Sizes and strides are converted to int tensors; their raw data pointers
   -- are what gets passed to the descriptor setter.
   local dims = torch.LongTensor(t:size()):int()
   local strides = torch.LongTensor(t:stride()):int()
   errcheck('cudnnSetTensorNdDescriptor', desc[0], 'CUDNN_DATA_FLOAT',
            t:dim(), dims:data(), strides:data())
   return desc
end

-- Load the layer implementations that make up the package.
-- NOTE(review): the order may matter -- Pooling.lua and Pointwise.lua are
-- included before the more specific pooling / activation files and look like
-- shared bases for them; confirm before reordering.
include 'SpatialConvolution.lua'
include 'VolumetricConvolution.lua'
include 'Pooling.lua'
include 'SpatialMaxPooling.lua'
include 'SpatialAveragePooling.lua'
include 'Pointwise.lua'
include 'ReLU.lua'
include 'Tanh.lua'
include 'Sigmoid.lua'
include 'SpatialSoftMax.lua'
include 'SoftMax.lua'