From eea2fc997205899bb994c3389affcca7031933be Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Mon, 18 Apr 2016 20:57:23 +0200 Subject: Add general methods for pointer serialization in sharedserialize --- sharedserialize.lua | 43 +++++++++++++++++++++++++++---------------- 1 file changed, 27 insertions(+), 16 deletions(-) diff --git a/sharedserialize.lua b/sharedserialize.lua index 9cd8ebf..17d18b6 100644 --- a/sharedserialize.lua +++ b/sharedserialize.lua @@ -1,23 +1,43 @@ require 'torch' +local ffi = require 'ffi' local serialize = {} local typenames = {} +local function serializePointer(obj, f) + -- on 32-bit systems double can represent all possible + -- pointer values, but signed long can't + if ffi.sizeof('long') == 4 then + f:writeDouble(torch.pointer(obj)) + -- on 64-bit systems, long can represent a larger + -- range of integers than double, so it's safer to use this + else + f:writeLong(torch.pointer(obj)) + end +end + +local function deserializePointer(f) + if ffi.sizeof('long') == 4 then + return f:readDouble() + else + return f:readLong() + end +end + -- tds support local _, tds = pcall(require, 'tds') -- for the free/retain functions if tds then - local ffi = require 'ffi' -- hash local mt = {} function mt.__factory(f) - local self = f:readLong() + local self = deserializePointer(f) self = ffi.cast('tds_hash&', self) ffi.gc(self, tds.C.tds_hash_free) return self end function mt.__write(self, f) - f:writeLong(torch.pointer(self)) + serializePointer(self, f) tds.C.tds_hash_retain(self) end function mt.__read(self, f) @@ -27,22 +47,13 @@ if tds then -- vec local mt = {} function mt.__factory(f) - local self - if ffi.sizeof('long') == 4 then - self = f:readDouble() - else - self = f:readLong() - end + local self = deserializePointer(f) self = ffi.cast('tds_vec&', self) ffi.gc(self, tds.C.tds_vec_free) return self end function mt.__write(self, f) - if ffi.sizeof('long') == 4 then - f:writeDouble(torch.pointer(self)) - else - f:writeLong(torch.pointer(self)) - end + serializePointer(self, f) tds.C.tds_vec_retain(self) end function mt.__read(self, f) @@ -73,13 +84,13 @@ for _, typename in ipairs{ local mt = {} function mt.__factory(f) - local self = f:readLong() + local self = deserializePointer(f) self = torch.pushudata(self, typename) return self end function mt.write(self, f) - f:writeLong(torch.pointer(self)) + serializePointer(self, f) if typename ~= 'torch.Allocator' then self:retain() end -- cgit v1.2.3