From 056f57eafd3565fdfef300e7821d6f0280450dd3 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Mon, 18 Apr 2016 20:38:55 +0200 Subject: Fix pointer overflows in all cases in sharedserialize --- sharedserialize.lua | 37 ++++++++++++++++++++++--------------- 1 file changed, 22 insertions(+), 15 deletions(-) diff --git a/sharedserialize.lua b/sharedserialize.lua index 9cd8ebf..2a966fd 100644 --- a/sharedserialize.lua +++ b/sharedserialize.lua @@ -8,16 +8,32 @@ local _, tds = pcall(require, 'tds') -- for the free/retain functions if tds then local ffi = require 'ffi' + local function serializePointer(obj, f) + if ffi.sizeof('long') == 4 then + f:writeDouble(torch.pointer(obj)) + else + f:writeLong(torch.pointer(obj)) + end + end + + local function deserializePointer(f) + if ffi.sizeof('long') == 4 then + return f:readDouble() + else + return f:readLong() + end + end + -- hash local mt = {} function mt.__factory(f) - local self = f:readLong() + local self = deserializePointer(f) self = ffi.cast('tds_hash&', self) ffi.gc(self, tds.C.tds_hash_free) return self end function mt.__write(self, f) - f:writeLong(torch.pointer(self)) + serializePointer(self, f) tds.C.tds_hash_retain(self) end function mt.__read(self, f) @@ -27,22 +43,13 @@ if tds then -- vec local mt = {} function mt.__factory(f) - local self - if ffi.sizeof('long') == 4 then - self = f:readDouble() - else - self = f:readLong() - end + local self = deserializePointer(f) self = ffi.cast('tds_vec&', self) ffi.gc(self, tds.C.tds_vec_free) return self end function mt.__write(self, f) - if ffi.sizeof('long') == 4 then - f:writeDouble(torch.pointer(self)) - else - f:writeLong(torch.pointer(self)) - end + serializePointer(self, f) tds.C.tds_vec_retain(self) end function mt.__read(self, f) @@ -73,13 +80,13 @@ for _, typename in ipairs{ local mt = {} function mt.__factory(f) - local self = f:readLong() + local self = deserializePointer(f) self = torch.pushudata(self, typename) return self end function mt.write(self, f) - f:writeLong(torch.pointer(self)) + serializePointer(self, f) if typename ~= 'torch.Allocator' then self:retain() end -- cgit v1.2.3