-- FFInterface.lua - github.com/torch/torch7.git
-- Retrieved from the mirror list hosted at ThFree Co, Russian Federation.
-- blob: cb8bd3365bfe382bd5d1810086a3e139eab6ef73
-- If loading 'ffi' causes issues, you may need to run:
--   luarocks remove --force ffi
-- and then follow the installation instructions at:
--   https://github.com/facebook/luaffifb
local ok, ffi = pcall(require, 'ffi')

-- Raises a "bad argument" error unless `condition` is truthy.
--   condition : value to test; nothing happens when it is truthy
--   fn        : function name quoted in the message
--   ud        : argument position printed after '#'
--   msg       : explanation placed between the parentheses
--   level     : optional level forwarded to error(); 3 when omitted
local function checkArgument(condition, fn, ud, msg, level)
   if condition then
      return
   end
   local errlevel = level or 3
   -- plain statement call (not a tail call) so the stack depth seen by error() is stable
   error("bad argument #" .. ud .. " to '" .. fn .. "' (" .. msg .. ")", errlevel)
end

-- Compares two type names and, when they differ, reports a "bad argument"
-- error through checkArgument with the text "<expected> expected, got <actual>".
--   expected, actual : values compared with ==; concatenated into the message
--   fn, ud           : function name and argument position for the message
--   level            : optional error level (3 when omitted); one is added
--                      before it is handed to checkArgument, presumably to
--                      account for this helper's own stack frame
local function checkArgumentType(expected, actual, fn, ud, level)
   if expected == actual then
      return
   end
   local msg = expected .. " expected, got " .. actual
   local errlevel = (level or 3) + 1
   -- plain statement call (not a tail call) so this frame stays on the stack
   checkArgument(false, fn, ud, msg, errlevel)
end

if ok then

   -- Maps each torch type prefix ('Real') to the C element type ('real')
   -- substituted into the struct templates declared further down.
   local Real2real = {
      Byte   = "unsigned char",
      Char   = "char",
      Short  = "short",
      Int    = "int",
      Long   = "long",
      Float  = "float",
      Double = "double",
      Half   = "THHalf",
   }

   -- Allocator
   -- Declares THAllocator to the FFI: a struct holding three function
   -- pointers (malloc, realloc, free). The storage structs declared later in
   -- this file contain a THAllocator* field, so this typedef has to be
   -- registered before them.
   -- NOTE(review): presumably this must mirror the layout in TH's C headers;
   -- verify when the C library changes.
   ffi.cdef[[
typedef struct THAllocator {
  void* (*malloc)(void*, ptrdiff_t);
  void* (*realloc)(void*, void*, ptrdiff_t);
  void (*free)(void*, void*);
} THAllocator;
]]

   -- Half
   -- Declares THHalf to the FFI as a struct wrapping a single unsigned short.
   -- 'THHalf' is the C element type that Real2real associates with 'Half'.
   ffi.cdef[[
typedef struct {
  unsigned short x;
} __THHalf;
typedef __THHalf THHalf;
]]

   -- Storage
   -- For every entry of Real2real: declare the TH<Real>Storage struct to the
   -- FFI and install two accessors on the torch.<Real>Storage metatable:
   --   cdata() -> the TH<Real>Storage* behind the Lua object
   --   data()  -> the storage's raw element pointer
   -- Both return nil when called without a receiver, matching the Tensor
   -- accessors; without that guard a nil receiver is cast to a NULL
   -- pointer-to-pointer and then dereferenced.
   for Real, real in pairs(Real2real) do

      -- 'Real' / 'real' in the template are placeholders for the type prefix
      -- and the C element type of this iteration.
      local cdefs = [[
typedef struct THRealStorage
{
    real *data;
    ptrdiff_t size;
    int refcount;
    char flag;
    THAllocator *allocator;
    void *allocatorContext;
} THRealStorage;
]]
      cdefs = cdefs:gsub('Real', Real):gsub('real', real)
      ffi.cdef(cdefs)

      local Storage = torch.getmetatable(string.format('torch.%sStorage', Real))
      local Storage_tt = ffi.typeof('TH' .. Real .. 'Storage**')

      rawset(Storage,
             "cdata",
             function(self)
                if not self then return nil end
                return Storage_tt(self)[0]
             end)

      rawset(Storage,
             "data",
             function(self)
                if not self then return nil end
                return Storage_tt(self)[0].data
             end)
   end

   -- Tensor
   for Real, real in pairs(Real2real) do

      -- Struct template for this element type; every 'Real' / 'real'
      -- placeholder is replaced with the loop's type prefix / C element type
      -- before the declaration is registered with the FFI.
      local template = [[
typedef struct THRealTensor
{
    long *size;
    long *stride;
    int nDimension;

    THRealStorage *storage;
    ptrdiff_t storageOffset;
    int refcount;

    char flag;

} THRealTensor;
]]
      local declaration = template:gsub('Real', Real)
      declaration = declaration:gsub('real', real)
      ffi.cdef(declaration)

      -- Name and metatable of the torch tensor class, plus the
      -- TH<Real>Tensor** ctype used to reinterpret a tensor object.
      local Tensor_type = string.format('torch.%sTensor', Real)
      local Tensor = torch.getmetatable(Tensor_type)
      local Tensor_tt = ffi.typeof('TH' .. Real .. 'Tensor**')

      -- cdata(): reinterprets the tensor object as TH<Real>Tensor** and returns
      -- the TH<Real>Tensor* it holds; nil when no receiver is given.
      rawset(Tensor, "cdata", function(self)
         if self then
            return Tensor_tt(self)[0]
         end
         return nil
      end)

      -- data(): pointer to the tensor's first element, i.e. the storage's data
      -- pointer advanced by storageOffset. Returns nil when no receiver is
      -- given or when the tensor's storage pointer compares equal to nil.
      rawset(Tensor, "data", function(self)
         if not self then
            return nil
         end
         local th = Tensor_tt(self)[0]
         local storage = th.storage
         if storage == nil then
            return nil
         end
         return storage.data + th.storageOffset
      end)

      -- faster apply (contiguous case)
      if Tensor_type ~= 'torch.HalfTensor' then
         -- Keep the original implementation for the non-contiguous case.
         local generic_apply = Tensor.apply
         -- apply(func): calls func on every element (converted with tonumber)
         -- and stores any truthy result back in place. Returns self.
         rawset(Tensor, "apply", function(self, func)
            -- self.data is looked up only to confirm a data accessor exists
            if not (self:isContiguous() and self.data) then
               return generic_apply(self, func)
            end

            local buf = self:data()
            local last = self:nElement() - 1
            for idx = 0, last do
               local out = func(tonumber(buf[idx])) -- tonumber() required for long...
               if out then
                  buf[idx] = out
               end
            end
            return self
         end)

         -- faster map (contiguous case)
         -- Keep the original implementation for the non-contiguous case.
         local generic_map = Tensor.map
         -- map(src, func): calls func(self[i], src[i]) over the flat data of
         -- both tensors and stores any truthy result back into self.
         -- Raises a "bad argument" error when src is not a tensor or its type
         -- differs from self's; asserts 'size mismatch' on the fast path when
         -- the element counts differ. Returns self.
         rawset(Tensor, "map", function(self, src, func)
            -- the checks stay directly in this closure: their default error
            -- level depends on the call depth
            checkArgument(torch.isTensor(src), "map", 1, "tensor expected")
            checkArgumentType(self:type(), src:type(), "map", 1)

            local fast = self:isContiguous() and src:isContiguous() and self.data and src.data
            if not fast then
               return generic_map(self, src, func)
            end

            local dst_p = self:data()
            local src_p = src:data()
            assert(src:nElement() == self:nElement(), 'size mismatch')
            local last = self:nElement() - 1
            for idx = 0, last do
               local out = func(tonumber(dst_p[idx]), tonumber(src_p[idx])) -- tonumber() required for long...
               if out then
                  dst_p[idx] = out
               end
            end
            return self
         end)

         -- faster map2 (contiguous case)
         -- Original implementation, used whenever any of the three tensors is
         -- not contiguous.
         local map2 = Tensor.map2
         -- map2(src1, src2, func): calls func(self[i], src1[i], src2[i]) over the
         -- flat data of the three tensors (values converted with tonumber) and
         -- stores any truthy result back into self. Returns self.
         -- Raises a "bad argument ... to 'map2'" error when src1/src2 is not a
         -- tensor or has a different type than self; asserts 'size mismatch'
         -- on the fast path when the element counts differ.
         rawset(Tensor,
                "map2",
                function(self, src1, src2, func)
                   -- report failures under this method's own name
                   checkArgument(torch.isTensor(src1), "map2", 1, "tensor expected")
                   checkArgument(torch.isTensor(src2), "map2", 2, "tensor expected")
                   checkArgumentType(self:type(), src1:type(), "map2", 1)
                   checkArgumentType(self:type(), src2:type(), "map2", 2)

                   if self:isContiguous() and src1:isContiguous() and src2:isContiguous() and self.data and src1.data and src2.data then
                      local self_d = self:data()
                      local src1_d = src1:data()
                      local src2_d = src2:data()
                      assert(src1:nElement() == self:nElement(), 'size mismatch')
                      assert(src2:nElement() == self:nElement(), 'size mismatch')
                      for i=0,self:nElement()-1 do
                         local res = func(tonumber(self_d[i]), tonumber(src1_d[i]), tonumber(src2_d[i])) -- tonumber() required for long...
                         if res then
                            self_d[i] = res
                         end
                      end
                      return self
                   else
                      return map2(self, src1, src2, func)
                   end
                end)
             end
   end

   -- torch.data
   -- will fail if :data() is not defined
   -- Returns self:data(), or nil when self is nil/false.
   -- When `asnumber` is truthy the result is cast to 'intptr_t' via ffi.cast
   -- instead of being returned as-is.
   function torch.data(self, asnumber)
      if not self then
         return nil
      end
      local ptr = self:data()
      if not asnumber then
         return ptr
      end
      return ffi.cast('intptr_t', ptr)
   end

   -- torch.cdata
   -- will fail if :cdata() is not defined
   -- Returns self:cdata(), or nil when self is nil/false.
   -- When `asnumber` is truthy the result is cast to 'intptr_t' via ffi.cast
   -- instead of being returned as-is.
   function torch.cdata(self, asnumber)
      if not self then
         return nil
      end
      local handle = self:cdata()
      if not asnumber then
         return handle
      end
      return ffi.cast('intptr_t', handle)
   end

end