diff options
author | Ronan Collobert <ronan@collobert.com> | 2015-07-21 01:49:18 +0300 |
---|---|---|
committer | Ronan Collobert <ronan@collobert.com> | 2015-07-21 01:49:18 +0300 |
commit | 28dc2d74ac3a510d1c45967fa3afcb67e57acf90 (patch) | |
tree | 226a1bc1fc0b604834a79b4796454b2d97fdbed5 | |
parent | 7ff6d1de3fa45e8c50fd943b1d0ba58acb75f467 (diff) |
cleanup sharedserialize (now supports properly CudaTensor/CudaStorage, if cutorch is loaded)
-rw-r--r-- | sharedserialize.lua | 81 |
1 files changed, 29 insertions, 52 deletions
diff --git a/sharedserialize.lua b/sharedserialize.lua index 9d2cfe3..2fe98ad 100644 --- a/sharedserialize.lua +++ b/sharedserialize.lua @@ -1,41 +1,13 @@ require 'torch' -local _, tds = pcall(require, 'tds') - local serialize = {} local typenames = {} --- check if typenames exists -for _, typename in ipairs{ - 'torch.ByteTensor', - 'torch.CharTensor', - 'torch.ShortTensor', - 'torch.IntTensor', - 'torch.LongTensor', - 'torch.CudaTensor', - 'torch.FloatTensor', - 'torch.DoubleTensor', - 'torch.CudaTensor', - 'torch.ByteStorage', - 'torch.CharStorage', - 'torch.ShortStorage', - 'torch.IntStorage', - 'torch.LongStorage', - 'torch.CudaStorage', - 'torch.FloatStorage', - 'torch.DoubleStorage', - 'torch.CudaStorage', - 'tds_hash'} do - - if torch.getmetatable(typename) then - typenames[typename] = {} - end - -end - -if typenames.tds_hash then +-- tds support +local _, tds = pcall(require, 'tds') -- for the free/retain functions +if tds then local ffi = require 'ffi' - local mt = typenames.tds_hash + local mt = {} function mt.__factory(f) local self = f:readLong() @@ -51,15 +23,17 @@ if typenames.tds_hash then function mt.__read(self, f) end + + typenames['tds_hash'] = mt end +-- tensor support for _, typename in ipairs{ 'torch.ByteTensor', 'torch.CharTensor', 'torch.ShortTensor', 'torch.IntTensor', 'torch.LongTensor', - 'torch.CudaTensor', 'torch.FloatTensor', 'torch.DoubleTensor', 'torch.CudaTensor', @@ -68,44 +42,47 @@ for _, typename in ipairs{ 'torch.ShortStorage', 'torch.IntStorage', 'torch.LongStorage', - 'torch.CudaStorage', 'torch.FloatStorage', 'torch.DoubleStorage', 'torch.CudaStorage'} do - if typenames[typename] then - local mt = typenames[typename] + local mt = {} - function mt.__factory(f) - local self = f:readLong() - self = torch.pushudata(self, typename) - return self - end + function mt.__factory(f) + local self = f:readLong() + self = torch.pushudata(self, typename) + return self + end - function mt.write(self, f) - f:writeLong(torch.pointer(self)) - self:retain() - end + function mt.write(self, f) + f:writeLong(torch.pointer(self)) + self:retain() + end - function mt.read(self, f) - end + function mt.read(self, f) end + + typenames[typename] = mt end local function swapwrite() for typename, mt in pairs(typenames) do local mts = torch.getmetatable(typename) - mts.__write, mt.__write = mt.__write, mts.__write - mts.write, mt.write = mt.write, mts.write + if mts then + mts.__write, mt.__write = mt.__write, mts.__write + mts.write, mt.write = mt.write, mts.write + end end end local function swapread() for typename, mt in pairs(typenames) do local mts = torch.getmetatable(typename) - mts.__factory, mt.__factory = mt.__factory, mts.__factory - mts.__read, mt.__read = mt.__read, mts.__read - mts.read, mt.read = mt.read, mts.read + if mts then + mts.__factory, mt.__factory = mt.__factory, mts.__factory + mts.__read, mt.__read = mt.__read, mts.__read + mts.read, mt.read = mt.read, mts.read + end end end |