diff options
author | Sam Gross <sgross@fb.com> | 2016-01-09 07:53:18 +0300 |
---|---|---|
committer | Sam Gross <sgross@fb.com> | 2016-01-09 07:53:18 +0300 |
commit | 051b3b0ffd59a340f27ec21765d4826c8139efe4 (patch) | |
tree | 2219b6a62b473551d188c2ee77332ea0b4ecf56b | |
parent | 2a5e2a666ef2b34d728db8e4f3aac8f1660c4bf0 (diff) |
sharedserialize: add support for torch.Allocator
-rw-r--r-- | sharedserialize.lua | 8 |
1 files changed, 6 insertions, 2 deletions
diff --git a/sharedserialize.lua b/sharedserialize.lua index aba4fb3..6e09a90 100644 --- a/sharedserialize.lua +++ b/sharedserialize.lua @@ -58,7 +58,8 @@ for _, typename in ipairs{ 'torch.LongStorage', 'torch.FloatStorage', 'torch.DoubleStorage', - 'torch.CudaStorage'} do + 'torch.CudaStorage', + 'torch.Allocator'} do local mt = {} @@ -70,7 +71,9 @@ for _, typename in ipairs{ function mt.write(self, f) f:writeLong(torch.pointer(self)) - self:retain() + if typename ~= 'torch.Allocator' then + self:retain() + end end function mt.read(self, f) @@ -83,6 +86,7 @@ local function swapwrite() for typename, mt in pairs(typenames) do local mts = torch.getmetatable(typename) if mts then + mts.__factory, mt.__factory = mt.__factory, mts.__factory mts.__write, mt.__write = mt.__write, mts.__write mts.write, mt.write = mt.write, mts.write end |