Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/threads-ffi.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRonan Collobert <ronan@collobert.com>2015-07-21 01:49:18 +0300
committerRonan Collobert <ronan@collobert.com>2015-07-21 01:49:18 +0300
commit28dc2d74ac3a510d1c45967fa3afcb67e57acf90 (patch)
tree226a1bc1fc0b604834a79b4796454b2d97fdbed5
parent7ff6d1de3fa45e8c50fd943b1d0ba58acb75f467 (diff)
cleanup sharedserialize (now supports properly CudaTensor/CudaStorage, if cutorch is loaded)
-rw-r--r--sharedserialize.lua81
1 files changed, 29 insertions, 52 deletions
diff --git a/sharedserialize.lua b/sharedserialize.lua
index 9d2cfe3..2fe98ad 100644
--- a/sharedserialize.lua
+++ b/sharedserialize.lua
@@ -1,41 +1,13 @@
require 'torch'
-local _, tds = pcall(require, 'tds')
-
local serialize = {}
local typenames = {}
--- check if typenames exists
-for _, typename in ipairs{
- 'torch.ByteTensor',
- 'torch.CharTensor',
- 'torch.ShortTensor',
- 'torch.IntTensor',
- 'torch.LongTensor',
- 'torch.CudaTensor',
- 'torch.FloatTensor',
- 'torch.DoubleTensor',
- 'torch.CudaTensor',
- 'torch.ByteStorage',
- 'torch.CharStorage',
- 'torch.ShortStorage',
- 'torch.IntStorage',
- 'torch.LongStorage',
- 'torch.CudaStorage',
- 'torch.FloatStorage',
- 'torch.DoubleStorage',
- 'torch.CudaStorage',
- 'tds_hash'} do
-
- if torch.getmetatable(typename) then
- typenames[typename] = {}
- end
-
-end
-
-if typenames.tds_hash then
+-- tds support
+local _, tds = pcall(require, 'tds') -- for the free/retain functions
+if tds then
local ffi = require 'ffi'
- local mt = typenames.tds_hash
+ local mt = {}
function mt.__factory(f)
local self = f:readLong()
@@ -51,15 +23,17 @@ if typenames.tds_hash then
function mt.__read(self, f)
end
+
+ typenames['tds_hash'] = mt
end
+-- tensor support
for _, typename in ipairs{
'torch.ByteTensor',
'torch.CharTensor',
'torch.ShortTensor',
'torch.IntTensor',
'torch.LongTensor',
- 'torch.CudaTensor',
'torch.FloatTensor',
'torch.DoubleTensor',
'torch.CudaTensor',
@@ -68,44 +42,47 @@ for _, typename in ipairs{
'torch.ShortStorage',
'torch.IntStorage',
'torch.LongStorage',
- 'torch.CudaStorage',
'torch.FloatStorage',
'torch.DoubleStorage',
'torch.CudaStorage'} do
- if typenames[typename] then
- local mt = typenames[typename]
+ local mt = {}
- function mt.__factory(f)
- local self = f:readLong()
- self = torch.pushudata(self, typename)
- return self
- end
+ function mt.__factory(f)
+ local self = f:readLong()
+ self = torch.pushudata(self, typename)
+ return self
+ end
- function mt.write(self, f)
- f:writeLong(torch.pointer(self))
- self:retain()
- end
+ function mt.write(self, f)
+ f:writeLong(torch.pointer(self))
+ self:retain()
+ end
- function mt.read(self, f)
- end
+ function mt.read(self, f)
end
+
+ typenames[typename] = mt
end
local function swapwrite()
for typename, mt in pairs(typenames) do
local mts = torch.getmetatable(typename)
- mts.__write, mt.__write = mt.__write, mts.__write
- mts.write, mt.write = mt.write, mts.write
+ if mts then
+ mts.__write, mt.__write = mt.__write, mts.__write
+ mts.write, mt.write = mt.write, mts.write
+ end
end
end
local function swapread()
for typename, mt in pairs(typenames) do
local mts = torch.getmetatable(typename)
- mts.__factory, mt.__factory = mt.__factory, mts.__factory
- mts.__read, mt.__read = mt.__read, mts.__read
- mts.read, mt.read = mt.read, mts.read
+ if mts then
+ mts.__factory, mt.__factory = mt.__factory, mts.__factory
+ mts.__read, mt.__read = mt.__read, mts.__read
+ mts.read, mt.read = mt.read, mts.read
+ end
end
end