Welcome to mirror list, hosted at ThFree Co, Russian Federation.

FFI.lua - github.com/torch/cutorch.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: b2777a2b6128af46b3d40a24beb16be3c1546ba7 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
-- LuaJIT FFI bindings for cutorch.
--
-- Declares the THC C structures to the FFI and installs :cdata() / :data()
-- accessors on every torch.Cuda*Storage and torch.Cuda*Tensor metatable.
-- When the 'ffi' module cannot be required (e.g. plain PUC Lua), this whole
-- file is a no-op.
local ok, ffi = pcall(require, 'ffi')
if ok then
   local unpack = unpack or table.unpack

   -- Declarations shared by every element type. The struct field order and
   -- types are read by the FFI at fixed offsets, so they must mirror the C
   -- definitions in the THC library exactly (NOTE(review): keep in sync with
   -- the THC headers when they change).
   --
   -- cublasHandle_t is declared exactly once, as a pointer to the opaque
   -- struct cublasContext. An earlier revision also re-typedef'd it to
   -- 'struct CUhandle_st *'. That conflicting redefinition is not valid C and
   -- was presumably accepted only because LuaJIT lets a later typedef shadow
   -- an earlier one.
   local cdefs = {[[
typedef struct CUstream_st *cudaStream_t;

struct cublasContext;
typedef struct cublasContext *cublasHandle_t;

typedef struct _THCStream {
   cudaStream_t stream;
   int device;
   int refcount;
} THCStream;


typedef struct _THCCudaResourcesPerDevice {
  THCStream** streams;
  cublasHandle_t* blasHandles;
  size_t scratchSpacePerStream;
  void** devScratchSpacePerStream;
} THCCudaResourcesPerDevice;


typedef struct THCState
{
  struct THCRNGState* rngState;
  struct cudaDeviceProp* deviceProperties;
  THCCudaResourcesPerDevice* resourcesPerDevice;
  int numDevices;
  int numUserStreams;
  int numUserBlasHandles;
  struct THAllocator* cudaHostAllocator;
} THCState;

cudaStream_t THCState_getCurrentStream(THCState *state);

]]}

   -- Per-element-type template. For each entry of CudaTypes, 'real' is
   -- replaced by the C element type, and 'THCStorage' / 'THCTensor' are
   -- replaced by 'THCuda<Real>Storage' / 'THCuda<Real>Tensor'.
   -- THAllocator is not declared here; it is assumed to have been declared
   -- already (presumably by torch's own FFI bindings — TODO confirm).
   local type_template = [[
typedef struct THCStorage
{
    real *data;
    ptrdiff_t size;
    int refcount;
    char flag;
    THAllocator *allocator;
    void *allocatorContext;
    struct THCStorage *view;
    int device;
} THCStorage;

typedef struct THCTensor
{
    long *size;
    long *stride;
    int nDimension;

    THCStorage *storage;
    ptrdiff_t storageOffset;
    int refcount;

    char flag;

} THCTensor;
]]

   -- Each entry is {C element type, infix used in the torch/THC type names}.
   -- 'float' uses the empty infix, giving THCudaStorage / torch.CudaStorage.
   local CudaTypes = {
      {'float', ''},
      {'unsigned char', 'Byte'},
      {'char', 'Char'},
      {'short', 'Short'},
      {'int', 'Int'},
      {'long', 'Long'},
      {'double', 'Double'},
   }
   if cutorch.hasHalf then
      CudaTypes[#CudaTypes + 1] = {'half', 'Half'}
   end

   -- Instantiate the template once per element type. The pieces are collected
   -- in a buffer and joined below (same text as repeated '..', without the
   -- quadratic copying).
   for _, typedata in ipairs(CudaTypes) do
      local real, Real = unpack(typedata)
      local ctype_def = type_template
         :gsub('real', real)
         :gsub('THCStorage', 'THCuda' .. Real .. 'Storage')
         :gsub('THCTensor', 'THCuda' .. Real .. 'Tensor')
      cdefs[#cdefs + 1] = ctype_def
   end

   -- 'half' must be known to the FFI before the Half storage struct (which
   -- contains 'half *data') is declared, so it gets its own earlier cdef.
   if cutorch.hasHalf then
      ffi.cdef([[
typedef struct {
    unsigned short x;
} __half;
typedef __half half;
      ]])
   end
   ffi.cdef(table.concat(cdefs))

   -- Install the FFI accessors on each Storage/Tensor metatable. A torch
   -- userdata is converted by the FFI to a pointer to its payload. Casting
   -- that to 'T**' and indexing [0] therefore reads the struct pointer
   -- stored in the payload (assumes torch keeps the raw THC struct pointer
   -- there — TODO confirm).
   for _, typedata in ipairs(CudaTypes) do
      local Real = typedata[2]

      -- Storage
      local Storage = torch.getmetatable('torch.Cuda' .. Real .. 'Storage')
      local Storage_tt = ffi.typeof('THCuda' .. Real .. 'Storage**')

      rawset(Storage, "cdata", function(self) return Storage_tt(self)[0] end)
      rawset(Storage, "data", function(self) return Storage_tt(self)[0].data end)

      -- Tensor
      local Tensor = torch.getmetatable('torch.Cuda' .. Real .. 'Tensor')
      local Tensor_tt = ffi.typeof('THCuda' .. Real .. 'Tensor**')

      rawset(Tensor, "cdata", function(self) return Tensor_tt(self)[0] end)

      -- Pointer to the tensor's first element (storage data + storageOffset),
      -- or nil when the tensor has no storage. The and/or form is safe here:
      -- cdata pointers are always truthy in LuaJIT, even when NULL.
      rawset(Tensor, "data",
             function(self)
                self = Tensor_tt(self)[0]
                return self.storage ~= nil and self.storage.data + self.storageOffset or nil
             end
      )
   end

end