Welcome to mirror list, hosted at ThFree Co, Russian Federation.

FFI.lua - github.com/torch/cutorch.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: 0afa3db1321c3fa7a539981724fd062a05377945 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
-- LuaJIT FFI bindings for cutorch.
--
-- When the `ffi` module can be required (i.e. running under LuaJIT), this
-- declares C layouts for THCState, THCudaStorage and THCudaTensor and
-- installs `cdata()` / `data()` accessors on the torch.CudaStorage and
-- torch.CudaTensor metatables. When `ffi` is unavailable nothing is done.
local ok, ffi = pcall(require, 'ffi')
if ok then

   -- The struct declarations below must mirror the THC C headers field for
   -- field: FFI accesses memory through these layouts, so any drift from the
   -- real C structs silently misreads fields.
   --
   -- cublasHandle_t is declared exactly once, as a pointer to the opaque
   -- `struct cublasContext`; declaring the same typedef name again with a
   -- different type would be a conflicting redeclaration.
   --
   -- NOTE(review): `THAllocator` (used unqualified in THCudaStorage) is not
   -- declared in this string; presumably an earlier ffi.cdef (e.g. torch's
   -- own FFI bindings) declares it -- confirm this file is loaded after it.
   local cdefs = [[
typedef struct CUstream_st *cudaStream_t;

struct cublasContext;
typedef struct cublasContext *cublasHandle_t;

typedef struct _THCCudaResourcesPerDevice {
  cudaStream_t* streams;
  cublasHandle_t* blasHandles;
  size_t scratchSpacePerStream;
  void** devScratchSpacePerStream;
} THCCudaResourcesPerDevice;


typedef struct THCState
{
  struct THCRNGState* rngState;
  struct cudaDeviceProp* deviceProperties;
  cudaStream_t currentStream;
  cublasHandle_t currentBlasHandle;
  THCCudaResourcesPerDevice* resourcesPerDevice;
  int numDevices;
  int numUserStreams;
  int numUserBlasHandles;
  int currentPerDeviceStream;
  int currentPerDeviceBlasHandle;
  struct THAllocator* cudaHostAllocator;
} THCState;

cudaStream_t THCState_getCurrentStream(THCState *state);

typedef struct THCudaStorage
{
    float *data;
    long size;
    int refcount;
    char flag;
    THAllocator *allocator;
    void *allocatorContext;
    struct THCudaStorage *view;
} THCudaStorage;

typedef struct THCudaTensor
{
    long *size;
    long *stride;
    int nDimension;

    THCudaStorage *storage;
    long storageOffset;
    int refcount;

    char flag;

} THCudaTensor;
]]
   ffi.cdef(cdefs)

   -- Storage
   local Storage = torch.getmetatable('torch.CudaStorage')
   local Storage_tt = ffi.typeof('THCudaStorage**')

   -- The accessors cast `self` to THCudaStorage** and index [0]; this
   -- assumes the torch object's payload is a pointer to the C struct
   -- (NOTE(review): confirm against the luaT object layout).

   -- storage:cdata() -> THCudaStorage* backing this storage.
   rawset(Storage, "cdata", function(self) return Storage_tt(self)[0] end)
   -- storage:data() -> the struct's `float *data` field.
   rawset(Storage, "data", function(self) return Storage_tt(self)[0].data end)

   -- Tensor
   local Tensor = torch.getmetatable('torch.CudaTensor')
   local Tensor_tt = ffi.typeof('THCudaTensor**')

   -- tensor:cdata() -> THCudaTensor* backing this tensor.
   rawset(Tensor, "cdata", function(self) return Tensor_tt(self)[0] end)

   -- tensor:data() -> float* pointing `storageOffset` elements into the
   -- storage's buffer, or nil when the tensor has no storage (NULL pointer).
   rawset(Tensor, "data", function(self)
      local tensor = Tensor_tt(self)[0]
      local storage = tensor.storage
      if storage == nil then
         return nil
      end
      return storage.data + tensor.storageOffset
   end)

end