Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/threads-ffi.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSoumith Chintala <soumith@gmail.com>2016-04-18 21:59:12 +0300
committerSoumith Chintala <soumith@gmail.com>2016-04-18 21:59:12 +0300
commit8700bf64687742ac43f8fec599d1ea71387b7936 (patch)
tree8e6db78ad4208203a221206122fbe7f09a1ba589
parent118320fb285d2fa235a4bb2890832c5393a53f89 (diff)
Revert "Fix pointer overflows in all cases in sharedserialize"revert-57-sharedserialize_fix2
-rw-r--r--sharedserialize.lua37
1 files changed, 15 insertions, 22 deletions
diff --git a/sharedserialize.lua b/sharedserialize.lua
index 2a966fd..9cd8ebf 100644
--- a/sharedserialize.lua
+++ b/sharedserialize.lua
@@ -8,32 +8,16 @@ local _, tds = pcall(require, 'tds') -- for the free/retain functions
if tds then
local ffi = require 'ffi'
- local function serializePointer(obj, f)
- if ffi.sizeof('long') == 4 then
- f:writeDouble(torch.pointer(obj))
- else
- f:writeLong(torch.pointer(obj))
- end
- end
-
- local function deserializePointer(f)
- if ffi.sizeof('long') == 4 then
- return f:readDouble()
- else
- return f:readLong()
- end
- end
-
-- hash
local mt = {}
function mt.__factory(f)
- local self = deserializePointer(f)
+ local self = f:readLong()
self = ffi.cast('tds_hash&', self)
ffi.gc(self, tds.C.tds_hash_free)
return self
end
function mt.__write(self, f)
- serializePointer(self, f)
+ f:writeLong(torch.pointer(self))
tds.C.tds_hash_retain(self)
end
function mt.__read(self, f)
@@ -43,13 +27,22 @@ if tds then
-- vec
local mt = {}
function mt.__factory(f)
- local self = deserializePointer(f)
+ local self
+ if ffi.sizeof('long') == 4 then
+ self = f:readDouble()
+ else
+ self = f:readLong()
+ end
self = ffi.cast('tds_vec&', self)
ffi.gc(self, tds.C.tds_vec_free)
return self
end
function mt.__write(self, f)
- serializePointer(self, f)
+ if ffi.sizeof('long') == 4 then
+ f:writeDouble(torch.pointer(self))
+ else
+ f:writeLong(torch.pointer(self))
+ end
tds.C.tds_vec_retain(self)
end
function mt.__read(self, f)
@@ -80,13 +73,13 @@ for _, typename in ipairs{
local mt = {}
function mt.__factory(f)
- local self = deserializePointer(f)
+ local self = f:readLong()
self = torch.pushudata(self, typename)
return self
end
function mt.write(self, f)
- serializePointer(self, f)
+ f:writeLong(torch.pointer(self))
if typename ~= 'torch.Allocator' then
self:retain()
end