From 8700bf64687742ac43f8fec599d1ea71387b7936 Mon Sep 17 00:00:00 2001 From: Soumith Chintala Date: Mon, 18 Apr 2016 14:59:12 -0400 Subject: Revert "Fix pointer overflows in all cases in sharedserialize" --- sharedserialize.lua | 37 +++++++++++++++---------------------- 1 file changed, 15 insertions(+), 22 deletions(-) diff --git a/sharedserialize.lua b/sharedserialize.lua index 2a966fd..9cd8ebf 100644 --- a/sharedserialize.lua +++ b/sharedserialize.lua @@ -8,32 +8,16 @@ local _, tds = pcall(require, 'tds') -- for the free/retain functions if tds then local ffi = require 'ffi' - local function serializePointer(obj, f) - if ffi.sizeof('long') == 4 then - f:writeDouble(torch.pointer(obj)) - else - f:writeLong(torch.pointer(obj)) - end - end - - local function deserializePointer(f) - if ffi.sizeof('long') == 4 then - return f:readDouble() - else - return f:readLong() - end - end - -- hash local mt = {} function mt.__factory(f) - local self = deserializePointer(f) + local self = f:readLong() self = ffi.cast('tds_hash&', self) ffi.gc(self, tds.C.tds_hash_free) return self end function mt.__write(self, f) - serializePointer(self, f) + f:writeLong(torch.pointer(self)) tds.C.tds_hash_retain(self) end function mt.__read(self, f) @@ -43,13 +27,22 @@ if tds then -- vec local mt = {} function mt.__factory(f) - local self = deserializePointer(f) + local self + if ffi.sizeof('long') == 4 then + self = f:readDouble() + else + self = f:readLong() + end self = ffi.cast('tds_vec&', self) ffi.gc(self, tds.C.tds_vec_free) return self end function mt.__write(self, f) - serializePointer(self, f) + if ffi.sizeof('long') == 4 then + f:writeDouble(torch.pointer(self)) + else + f:writeLong(torch.pointer(self)) + end tds.C.tds_vec_retain(self) end function mt.__read(self, f) @@ -80,13 +73,13 @@ for _, typename in ipairs{ local mt = {} function mt.__factory(f) - local self = deserializePointer(f) + local self = f:readLong() self = torch.pushudata(self, typename) return self end function mt.write(self, f) - serializePointer(self, f) + f:writeLong(torch.pointer(self)) if typename ~= 'torch.Allocator' then self:retain() end -- cgit v1.2.3