diff options
author | Soumith Chintala <soumith@gmail.com> | 2016-04-18 21:52:36 +0300 |
---|---|---|
committer | Soumith Chintala <soumith@gmail.com> | 2016-04-18 21:52:36 +0300 |
commit | 118320fb285d2fa235a4bb2890832c5393a53f89 (patch) | |
tree | 9defcbf2fd6a305c8406c5b80d38738c50f36e23 | |
parent | ef58d3ffc53fd4d03b4196979a98ffe11648f4f6 (diff) | |
parent | 056f57eafd3565fdfef300e7821d6f0280450dd3 (diff) |
Merge pull request #57 from apaszke/sharedserialize_fix2
Fix pointer overflows in all cases in sharedserialize
-rw-r--r-- | sharedserialize.lua | 37 |
1 files changed, 22 insertions, 15 deletions
diff --git a/sharedserialize.lua b/sharedserialize.lua index 9cd8ebf..2a966fd 100644 --- a/sharedserialize.lua +++ b/sharedserialize.lua @@ -8,16 +8,32 @@ local _, tds = pcall(require, 'tds') -- for the free/retain functions if tds then local ffi = require 'ffi' + local function serializePointer(obj, f) + if ffi.sizeof('long') == 4 then + f:writeDouble(torch.pointer(obj)) + else + f:writeLong(torch.pointer(obj)) + end + end + + local function deserializePointer(f) + if ffi.sizeof('long') == 4 then + return f:readDouble() + else + return f:readLong() + end + end + -- hash local mt = {} function mt.__factory(f) - local self = f:readLong() + local self = deserializePointer(f) self = ffi.cast('tds_hash&', self) ffi.gc(self, tds.C.tds_hash_free) return self end function mt.__write(self, f) - f:writeLong(torch.pointer(self)) + serializePointer(self, f) tds.C.tds_hash_retain(self) end function mt.__read(self, f) @@ -27,22 +43,13 @@ if tds then -- vec local mt = {} function mt.__factory(f) - local self - if ffi.sizeof('long') == 4 then - self = f:readDouble() - else - self = f:readLong() - end + local self = deserializePointer(f) self = ffi.cast('tds_vec&', self) ffi.gc(self, tds.C.tds_vec_free) return self end function mt.__write(self, f) - if ffi.sizeof('long') == 4 then - f:writeDouble(torch.pointer(self)) - else - f:writeLong(torch.pointer(self)) - end + serializePointer(self, f) tds.C.tds_vec_retain(self) end function mt.__read(self, f) @@ -73,13 +80,13 @@ for _, typename in ipairs{ local mt = {} function mt.__factory(f) - local self = f:readLong() + local self = deserializePointer(f) self = torch.pushudata(self, typename) return self end function mt.write(self, f) - f:writeLong(torch.pointer(self)) + serializePointer(self, f) if typename ~= 'torch.Allocator' then self:retain() end |