diff options
author | Soumith Chintala <soumith@gmail.com> | 2016-04-18 21:59:12 +0300 |
---|---|---|
committer | Soumith Chintala <soumith@gmail.com> | 2016-04-18 21:59:12 +0300 |
commit | 8700bf64687742ac43f8fec599d1ea71387b7936 (patch) | |
tree | 8e6db78ad4208203a221206122fbe7f09a1ba589 | |
parent | 118320fb285d2fa235a4bb2890832c5393a53f89 (diff) |
Revert "Fix pointer overflows in all cases in sharedserialize"revert-57-sharedserialize_fix2
-rw-r--r-- | sharedserialize.lua | 37 |
1 files changed, 15 insertions, 22 deletions
diff --git a/sharedserialize.lua b/sharedserialize.lua index 2a966fd..9cd8ebf 100644 --- a/sharedserialize.lua +++ b/sharedserialize.lua @@ -8,32 +8,16 @@ local _, tds = pcall(require, 'tds') -- for the free/retain functions if tds then local ffi = require 'ffi' - local function serializePointer(obj, f) - if ffi.sizeof('long') == 4 then - f:writeDouble(torch.pointer(obj)) - else - f:writeLong(torch.pointer(obj)) - end - end - - local function deserializePointer(f) - if ffi.sizeof('long') == 4 then - return f:readDouble() - else - return f:readLong() - end - end - -- hash local mt = {} function mt.__factory(f) - local self = deserializePointer(f) + local self = f:readLong() self = ffi.cast('tds_hash&', self) ffi.gc(self, tds.C.tds_hash_free) return self end function mt.__write(self, f) - serializePointer(self, f) + f:writeLong(torch.pointer(self)) tds.C.tds_hash_retain(self) end function mt.__read(self, f) @@ -43,13 +27,22 @@ if tds then -- vec local mt = {} function mt.__factory(f) - local self = deserializePointer(f) + local self + if ffi.sizeof('long') == 4 then + self = f:readDouble() + else + self = f:readLong() + end self = ffi.cast('tds_vec&', self) ffi.gc(self, tds.C.tds_vec_free) return self end function mt.__write(self, f) - serializePointer(self, f) + if ffi.sizeof('long') == 4 then + f:writeDouble(torch.pointer(self)) + else + f:writeLong(torch.pointer(self)) + end tds.C.tds_vec_retain(self) end function mt.__read(self, f) @@ -80,13 +73,13 @@ for _, typename in ipairs{ local mt = {} function mt.__factory(f) - local self = deserializePointer(f) + local self = f:readLong() self = torch.pushudata(self, typename) return self end function mt.write(self, f) - serializePointer(self, f) + f:writeLong(torch.pointer(self)) if typename ~= 'torch.Allocator' then self:retain() end |