Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/threads-ffi.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAdam Paszke <adam.paszke@gmail.com>2016-04-18 21:57:23 +0300
committerAdam Paszke <adam.paszke@gmail.com>2016-04-19 19:31:22 +0300
commiteea2fc997205899bb994c3389affcca7031933be (patch)
treec5f4fce7fe7c6d773f5d5ac5a9bc98fc39a368ae
parentf22f0f45dbf4a1af60889fe34c9f89c62bbd274e (diff)
Add general methods for pointer serialization in sharedserialize
-rw-r--r--sharedserialize.lua43
1 files changed, 27 insertions, 16 deletions
diff --git a/sharedserialize.lua b/sharedserialize.lua
index 9cd8ebf..17d18b6 100644
--- a/sharedserialize.lua
+++ b/sharedserialize.lua
@@ -1,23 +1,43 @@
require 'torch'
+local ffi = require 'ffi'
local serialize = {}
local typenames = {}
+local function serializePointer(obj, f)
+ -- on 32-bit systems double can represent all possible
+ -- pointer values, but signed long can't
+ if ffi.sizeof('long') == 4 then
+ f:writeDouble(torch.pointer(obj))
+ -- on 64-bit systems, long can represent a larger
+ -- range of integers than double, so it's safer to use this
+ else
+ f:writeLong(torch.pointer(obj))
+ end
+end
+
+local function deserializePointer(f)
+ if ffi.sizeof('long') == 4 then
+ return f:readDouble()
+ else
+ return f:readLong()
+ end
+end
+
-- tds support
local _, tds = pcall(require, 'tds') -- for the free/retain functions
if tds then
- local ffi = require 'ffi'
-- hash
local mt = {}
function mt.__factory(f)
- local self = f:readLong()
+ local self = deserializePointer(f)
self = ffi.cast('tds_hash&', self)
ffi.gc(self, tds.C.tds_hash_free)
return self
end
function mt.__write(self, f)
- f:writeLong(torch.pointer(self))
+ serializePointer(self, f)
tds.C.tds_hash_retain(self)
end
function mt.__read(self, f)
@@ -27,22 +47,13 @@ if tds then
-- vec
local mt = {}
function mt.__factory(f)
- local self
- if ffi.sizeof('long') == 4 then
- self = f:readDouble()
- else
- self = f:readLong()
- end
+ local self = deserializePointer(f)
self = ffi.cast('tds_vec&', self)
ffi.gc(self, tds.C.tds_vec_free)
return self
end
function mt.__write(self, f)
- if ffi.sizeof('long') == 4 then
- f:writeDouble(torch.pointer(self))
- else
- f:writeLong(torch.pointer(self))
- end
+ serializePointer(self, f)
tds.C.tds_vec_retain(self)
end
function mt.__read(self, f)
@@ -73,13 +84,13 @@ for _, typename in ipairs{
local mt = {}
function mt.__factory(f)
- local self = f:readLong()
+ local self = deserializePointer(f)
self = torch.pushudata(self, typename)
return self
end
function mt.write(self, f)
- f:writeLong(torch.pointer(self))
+ serializePointer(self, f)
if typename ~= 'torch.Allocator' then
self:retain()
end