diff options
author | Nicholas LĂ©onard <nick@nikopia.org> | 2017-04-29 01:33:08 +0300 |
---|---|---|
committer | Soumith Chintala <soumith@gmail.com> | 2017-04-29 01:33:08 +0300 |
commit | 870a99f06dd236b607347b77ed8a077c5ca2275e (patch) | |
tree | 3f83d7ce8877bddba2dbdaf0d002eca8de0a6b5e /FFInterface.lua | |
parent | e8c556d35c97029531c0d1449fcf418728f835e3 (diff) |
fix cdiv/div unit tests; fix Mac OS X require ffi bug (#1016)
Diffstat (limited to 'FFInterface.lua')
-rw-r--r-- | FFInterface.lua | 222 |
1 files changed, 222 insertions, 0 deletions
diff --git a/FFInterface.lua b/FFInterface.lua new file mode 100644 index 0000000..cb8bd33 --- /dev/null +++ b/FFInterface.lua @@ -0,0 +1,222 @@ +-- if this causes issues, you may need to: +-- luarocks remove --force ffi +-- and follow instructions to install +-- https://github.com/facebook/luaffifb +local ok, ffi = pcall(require, 'ffi') + +local function checkArgument(condition, fn, ud, msg, level) + local level = level or 3 + if not condition then + error("bad argument #" .. ud .. " to '" .. fn .. "' (" .. msg .. ")", level) + end +end + +local function checkArgumentType(expected, actual, fn, ud, level) + local level = level or 3 + if expected ~= actual then + checkArgument(false, fn, ud, expected .. " expected, got " .. actual, level + 1) + end +end + +if ok then + + local Real2real = { + Byte='unsigned char', + Char='char', + Short='short', + Int='int', + Long='long', + Float='float', + Double='double', + Half='THHalf' + } + + -- Allocator + ffi.cdef[[ +typedef struct THAllocator { + void* (*malloc)(void*, ptrdiff_t); + void* (*realloc)(void*, void*, ptrdiff_t); + void (*free)(void*, void*); +} THAllocator; +]] + + -- Half + ffi.cdef[[ +typedef struct { + unsigned short x; +} __THHalf; +typedef __THHalf THHalf; +]] + + -- Storage + for Real, real in pairs(Real2real) do + + local cdefs = [[ +typedef struct THRealStorage +{ + real *data; + ptrdiff_t size; + int refcount; + char flag; + THAllocator *allocator; + void *allocatorContext; +} THRealStorage; +]] + cdefs = cdefs:gsub('Real', Real):gsub('real', real) + ffi.cdef(cdefs) + + local Storage = torch.getmetatable(string.format('torch.%sStorage', Real)) + local Storage_tt = ffi.typeof('TH' .. Real .. 'Storage**') + + rawset(Storage, + "cdata", + function(self) + return Storage_tt(self)[0] + end) + + rawset(Storage, + "data", + function(self) + return Storage_tt(self)[0].data + end) + end + + -- Tensor + for Real, real in pairs(Real2real) do + + local cdefs = [[ +typedef struct THRealTensor +{ + long *size; + long *stride; + int nDimension; + + THRealStorage *storage; + ptrdiff_t storageOffset; + int refcount; + + char flag; + +} THRealTensor; +]] + cdefs = cdefs:gsub('Real', Real):gsub('real', real) + ffi.cdef(cdefs) + + local Tensor_type = string.format('torch.%sTensor', Real) + local Tensor = torch.getmetatable(Tensor_type) + local Tensor_tt = ffi.typeof('TH' .. Real .. 'Tensor**') + + rawset(Tensor, + "cdata", + function(self) + if not self then return nil; end + return Tensor_tt(self)[0] + end) + + rawset(Tensor, + "data", + function(self) + if not self then return nil; end + self = Tensor_tt(self)[0] + return self.storage ~= nil and self.storage.data + self.storageOffset or nil + end) + + -- faster apply (contiguous case) + if Tensor_type ~= 'torch.HalfTensor' then + local apply = Tensor.apply + rawset(Tensor, + "apply", + function(self, func) + if self:isContiguous() and self.data then + local self_d = self:data() + for i=0,self:nElement()-1 do + local res = func(tonumber(self_d[i])) -- tonumber() required for long... + if res then + self_d[i] = res + end + end + return self + else + return apply(self, func) + end + end) + + -- faster map (contiguous case) + local map = Tensor.map + rawset(Tensor, + "map", + function(self, src, func) + checkArgument(torch.isTensor(src), "map", 1, "tensor expected") + checkArgumentType(self:type(), src:type(), "map", 1) + + if self:isContiguous() and src:isContiguous() and self.data and src.data then + local self_d = self:data() + local src_d = src:data() + assert(src:nElement() == self:nElement(), 'size mismatch') + for i=0,self:nElement()-1 do + local res = func(tonumber(self_d[i]), tonumber(src_d[i])) -- tonumber() required for long... + if res then + self_d[i] = res + end + end + return self + else + return map(self, src, func) + end + end) + + -- faster map2 (contiguous case) + local map2 = Tensor.map2 + rawset(Tensor, + "map2", + function(self, src1, src2, func) + checkArgument(torch.isTensor(src1), "map", 1, "tensor expected") + checkArgument(torch.isTensor(src2), "map", 2, "tensor expected") + checkArgumentType(self:type(), src1:type(), "map", 1) + checkArgumentType(self:type(), src2:type(), "map", 2) + + if self:isContiguous() and src1:isContiguous() and src2:isContiguous() and self.data and src1.data and src2.data then + local self_d = self:data() + local src1_d = src1:data() + local src2_d = src2:data() + assert(src1:nElement() == self:nElement(), 'size mismatch') + assert(src2:nElement() == self:nElement(), 'size mismatch') + for i=0,self:nElement()-1 do + local res = func(tonumber(self_d[i]), tonumber(src1_d[i]), tonumber(src2_d[i])) -- tonumber() required for long... + if res then + self_d[i] = res + end + end + return self + else + return map2(self, src1, src2, func) + end + end) + end + end + + -- torch.data + -- will fail if :data() is not defined + function torch.data(self, asnumber) + if not self then return nil; end + local data = self:data() + if asnumber then + return ffi.cast('intptr_t', data) + else + return data + end + end + + -- torch.cdata + -- will fail if :cdata() is not defined + function torch.cdata(self, asnumber) + if not self then return nil; end + local cdata = self:cdata() + if asnumber then + return ffi.cast('intptr_t', cdata) + else + return cdata + end + end + +end |