Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/torch7.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSoumith Chintala <soumith@gmail.com>2016-03-04 19:16:20 +0300
committerSoumith Chintala <soumith@gmail.com>2016-03-04 19:16:20 +0300
commit1cf98511b39a486f74aa4ecce02f8de422101a2d (patch)
tree3a700def960c2a93195abf7d0e1bcac09ed41b38 /File.lua
parent3b14c94d36035232f4a0efd4d5023cf3a5f0ec0b (diff)
parentfd7f176c2ff32e8612bc1152504e5cf408eb214c (diff)
Merge pull request #567 from nkoumchatzky/master
Add a hook to File:writeObject
Diffstat (limited to 'File.lua')
-rw-r--r--File.lua25
1 files changed, 17 insertions, 8 deletions
diff --git a/File.lua b/File.lua
index 74d9a8f..7811f0b 100644
--- a/File.lua
+++ b/File.lua
@@ -104,7 +104,9 @@ local function formatStack(objectNameStack)
return table.concat(parts, '.')
end
-function File:writeObject(object, debugname)
+function File:writeObject(object, debugname, hook)
+ -- define a default hook function if not provided
+ hook = hook or function(object) return object end
-- we use an environment to keep a record of written objects
if not torch.getenv(self).writeObjects then
torch.setenv(self, {
@@ -114,7 +116,14 @@ function File:writeObject(object, debugname)
upvalueRefToId={}, upvalueIdToClosure={},
})
end
-
+ -- That guy is used for references' book-keeping
+ local sobject = object
+ -- That guy is the object that is actually persisted
+ -- hook(object) can be used to modify the object before writing it to the file.
+ -- Useful for serializing objects under a config
+ -- that we want to deserialize safely under another config.
+ -- (e.g. Cuda to Float tensors, cudnn to nn, ...)
+ object = hook(object)
local force = torch.getenv(self).force
-- if nil object, only write the type and return
@@ -145,7 +154,7 @@ function File:writeObject(object, debugname)
-- check it exists already (we look at the pointer!)
local objects = torch.getenv(self).writeObjects
local objectsRef = torch.getenv(self).writeObjectsRef
- local index = objects[torch.pointer(object)]
+ local index = objects[torch.pointer(sobject)]
if index and (not force) then
-- if already exists, write only its index
@@ -155,7 +164,7 @@ function File:writeObject(object, debugname)
index = objects.nWriteObject or 0
index = index + 1
if not force then
- objects[torch.pointer(object)] = index
+ objects[torch.pointer(sobject)] = index
objectsRef[object] = index -- we make sure the object is not going to disappear
end
self:writeInt(index)
@@ -188,7 +197,7 @@ function File:writeObject(object, debugname)
local stringStorage = torch.CharStorage():string(dumped)
self:writeInt(#stringStorage)
self:writeChar(stringStorage)
- self:writeObject(upvalues, UPVALUES_TOKEN)
+ self:writeObject(upvalues, UPVALUES_TOKEN, hook)
elseif typeidx == TYPE_TORCH then
local version = torch.CharStorage():string('V ' .. torch.version(object))
local className = torch.CharStorage():string(torch.typename(object))
@@ -208,7 +217,7 @@ function File:writeObject(object, debugname)
print(string.format('$ Warning: cannot write object field <%s> of <%s> %s', k, torch.typename(object), formatStack(objectNameStack)))
end
end
- self:writeObject(var, torch.typename(object))
+ self:writeObject(var, torch.typename(object), hook)
else
error(string.format('<%s> is a non-serializable Torch object %s', torch.typename(object), formatStack(objectNameStack)))
end
@@ -216,14 +225,14 @@ function File:writeObject(object, debugname)
local size = 0; for k,v in pairs(object) do size = size + 1 end
self:writeInt(size)
for k,v in pairs(object) do
- self:writeObject(k)
+ self:writeObject(k, nil, hook)
local name = (type(k) == 'string' or type(k) == 'number') and tostring(k) or nil
-- special case name for upvalues
if objectNameStack[#objectNameStack-1] == UPVALUES_TOKEN and
name == 'value' and type(object.name) == 'string' then
name = object.name
end
- self:writeObject(v, name)
+ self:writeObject(v, name, hook)
end
end
end