Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/threads-ffi.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAdam Paszke <adam.paszke@gmail.com>2016-04-18 21:38:55 +0300
committerAdam Paszke <adam.paszke@gmail.com>2016-04-18 21:43:39 +0300
commit056f57eafd3565fdfef300e7821d6f0280450dd3 (patch)
tree9defcbf2fd6a305c8406c5b80d38738c50f36e23
parentef58d3ffc53fd4d03b4196979a98ffe11648f4f6 (diff)
Fix pointer overflows in all cases in sharedserialize
-rw-r--r--sharedserialize.lua37
1 files changed, 22 insertions, 15 deletions
diff --git a/sharedserialize.lua b/sharedserialize.lua
index 9cd8ebf..2a966fd 100644
--- a/sharedserialize.lua
+++ b/sharedserialize.lua
@@ -8,16 +8,32 @@ local _, tds = pcall(require, 'tds') -- for the free/retain functions
if tds then
local ffi = require 'ffi'
+ local function serializePointer(obj, f)
+ if ffi.sizeof('long') == 4 then
+ f:writeDouble(torch.pointer(obj))
+ else
+ f:writeLong(torch.pointer(obj))
+ end
+ end
+
+ local function deserializePointer(f)
+ if ffi.sizeof('long') == 4 then
+ return f:readDouble()
+ else
+ return f:readLong()
+ end
+ end
+
-- hash
local mt = {}
function mt.__factory(f)
- local self = f:readLong()
+ local self = deserializePointer(f)
self = ffi.cast('tds_hash&', self)
ffi.gc(self, tds.C.tds_hash_free)
return self
end
function mt.__write(self, f)
- f:writeLong(torch.pointer(self))
+ serializePointer(self, f)
tds.C.tds_hash_retain(self)
end
function mt.__read(self, f)
@@ -27,22 +43,13 @@ if tds then
-- vec
local mt = {}
function mt.__factory(f)
- local self
- if ffi.sizeof('long') == 4 then
- self = f:readDouble()
- else
- self = f:readLong()
- end
+ local self = deserializePointer(f)
self = ffi.cast('tds_vec&', self)
ffi.gc(self, tds.C.tds_vec_free)
return self
end
function mt.__write(self, f)
- if ffi.sizeof('long') == 4 then
- f:writeDouble(torch.pointer(self))
- else
- f:writeLong(torch.pointer(self))
- end
+ serializePointer(self, f)
tds.C.tds_vec_retain(self)
end
function mt.__read(self, f)
@@ -73,13 +80,13 @@ for _, typename in ipairs{
local mt = {}
function mt.__factory(f)
- local self = f:readLong()
+ local self = deserializePointer(f)
self = torch.pushudata(self, typename)
return self
end
function mt.write(self, f)
- f:writeLong(torch.pointer(self))
+ serializePointer(self, f)
if typename ~= 'torch.Allocator' then
self:retain()
end