Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/torch7.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNicholas LĂ©onard <nick@nikopia.org>2017-04-29 01:33:08 +0300
committerSoumith Chintala <soumith@gmail.com>2017-04-29 01:33:08 +0300
commit870a99f06dd236b607347b77ed8a077c5ca2275e (patch)
tree3f83d7ce8877bddba2dbdaf0d002eca8de0a6b5e /FFInterface.lua
parente8c556d35c97029531c0d1449fcf418728f835e3 (diff)
fix cdiv/div unit tests; fix Mac OS X require ffi bug (#1016)
Diffstat (limited to 'FFInterface.lua')
-rw-r--r--FFInterface.lua222
1 files changed, 222 insertions, 0 deletions
diff --git a/FFInterface.lua b/FFInterface.lua
new file mode 100644
index 0000000..cb8bd33
--- /dev/null
+++ b/FFInterface.lua
@@ -0,0 +1,222 @@
+-- if this causes issues, you may need to:
+-- luarocks remove --force ffi
+-- and follow instructions to install
+-- https://github.com/facebook/luaffifb
+local ok, ffi = pcall(require, 'ffi')
+
+local function checkArgument(condition, fn, ud, msg, level)
+ local level = level or 3
+ if not condition then
+ error("bad argument #" .. ud .. " to '" .. fn .. "' (" .. msg .. ")", level)
+ end
+end
+
+local function checkArgumentType(expected, actual, fn, ud, level)
+ local level = level or 3
+ if expected ~= actual then
+ checkArgument(false, fn, ud, expected .. " expected, got " .. actual, level + 1)
+ end
+end
+
+if ok then
+
+ local Real2real = {
+ Byte='unsigned char',
+ Char='char',
+ Short='short',
+ Int='int',
+ Long='long',
+ Float='float',
+ Double='double',
+ Half='THHalf'
+ }
+
+ -- Allocator
+ ffi.cdef[[
+typedef struct THAllocator {
+ void* (*malloc)(void*, ptrdiff_t);
+ void* (*realloc)(void*, void*, ptrdiff_t);
+ void (*free)(void*, void*);
+} THAllocator;
+]]
+
+ -- Half
+ ffi.cdef[[
+typedef struct {
+ unsigned short x;
+} __THHalf;
+typedef __THHalf THHalf;
+]]
+
+ -- Storage
+ for Real, real in pairs(Real2real) do
+
+ local cdefs = [[
+typedef struct THRealStorage
+{
+ real *data;
+ ptrdiff_t size;
+ int refcount;
+ char flag;
+ THAllocator *allocator;
+ void *allocatorContext;
+} THRealStorage;
+]]
+ cdefs = cdefs:gsub('Real', Real):gsub('real', real)
+ ffi.cdef(cdefs)
+
+ local Storage = torch.getmetatable(string.format('torch.%sStorage', Real))
+ local Storage_tt = ffi.typeof('TH' .. Real .. 'Storage**')
+
+ rawset(Storage,
+ "cdata",
+ function(self)
+ return Storage_tt(self)[0]
+ end)
+
+ rawset(Storage,
+ "data",
+ function(self)
+ return Storage_tt(self)[0].data
+ end)
+ end
+
+ -- Tensor
+ for Real, real in pairs(Real2real) do
+
+ local cdefs = [[
+typedef struct THRealTensor
+{
+ long *size;
+ long *stride;
+ int nDimension;
+
+ THRealStorage *storage;
+ ptrdiff_t storageOffset;
+ int refcount;
+
+ char flag;
+
+} THRealTensor;
+]]
+ cdefs = cdefs:gsub('Real', Real):gsub('real', real)
+ ffi.cdef(cdefs)
+
+ local Tensor_type = string.format('torch.%sTensor', Real)
+ local Tensor = torch.getmetatable(Tensor_type)
+ local Tensor_tt = ffi.typeof('TH' .. Real .. 'Tensor**')
+
+ rawset(Tensor,
+ "cdata",
+ function(self)
+ if not self then return nil; end
+ return Tensor_tt(self)[0]
+ end)
+
+ rawset(Tensor,
+ "data",
+ function(self)
+ if not self then return nil; end
+ self = Tensor_tt(self)[0]
+ return self.storage ~= nil and self.storage.data + self.storageOffset or nil
+ end)
+
+ -- faster apply (contiguous case)
+ if Tensor_type ~= 'torch.HalfTensor' then
+ local apply = Tensor.apply
+ rawset(Tensor,
+ "apply",
+ function(self, func)
+ if self:isContiguous() and self.data then
+ local self_d = self:data()
+ for i=0,self:nElement()-1 do
+ local res = func(tonumber(self_d[i])) -- tonumber() required for long...
+ if res then
+ self_d[i] = res
+ end
+ end
+ return self
+ else
+ return apply(self, func)
+ end
+ end)
+
+ -- faster map (contiguous case)
+ local map = Tensor.map
+ rawset(Tensor,
+ "map",
+ function(self, src, func)
+ checkArgument(torch.isTensor(src), "map", 1, "tensor expected")
+ checkArgumentType(self:type(), src:type(), "map", 1)
+
+ if self:isContiguous() and src:isContiguous() and self.data and src.data then
+ local self_d = self:data()
+ local src_d = src:data()
+ assert(src:nElement() == self:nElement(), 'size mismatch')
+ for i=0,self:nElement()-1 do
+ local res = func(tonumber(self_d[i]), tonumber(src_d[i])) -- tonumber() required for long...
+ if res then
+ self_d[i] = res
+ end
+ end
+ return self
+ else
+ return map(self, src, func)
+ end
+ end)
+
+ -- faster map2 (contiguous case)
+ local map2 = Tensor.map2
+ rawset(Tensor,
+ "map2",
+ function(self, src1, src2, func)
+ checkArgument(torch.isTensor(src1), "map", 1, "tensor expected")
+ checkArgument(torch.isTensor(src2), "map", 2, "tensor expected")
+ checkArgumentType(self:type(), src1:type(), "map", 1)
+ checkArgumentType(self:type(), src2:type(), "map", 2)
+
+ if self:isContiguous() and src1:isContiguous() and src2:isContiguous() and self.data and src1.data and src2.data then
+ local self_d = self:data()
+ local src1_d = src1:data()
+ local src2_d = src2:data()
+ assert(src1:nElement() == self:nElement(), 'size mismatch')
+ assert(src2:nElement() == self:nElement(), 'size mismatch')
+ for i=0,self:nElement()-1 do
+ local res = func(tonumber(self_d[i]), tonumber(src1_d[i]), tonumber(src2_d[i])) -- tonumber() required for long...
+ if res then
+ self_d[i] = res
+ end
+ end
+ return self
+ else
+ return map2(self, src1, src2, func)
+ end
+ end)
+ end
+ end
+
+ -- torch.data
+ -- will fail if :data() is not defined
+ function torch.data(self, asnumber)
+ if not self then return nil; end
+ local data = self:data()
+ if asnumber then
+ return ffi.cast('intptr_t', data)
+ else
+ return data
+ end
+ end
+
+ -- torch.cdata
+ -- will fail if :cdata() is not defined
+ function torch.cdata(self, asnumber)
+ if not self then return nil; end
+ local cdata = self:cdata()
+ if asnumber then
+ return ffi.cast('intptr_t', cdata)
+ else
+ return cdata
+ end
+ end
+
+end