diff options
author | Soumith Chintala <soumith@gmail.com> | 2016-03-04 19:16:20 +0300 |
---|---|---|
committer | Soumith Chintala <soumith@gmail.com> | 2016-03-04 19:16:20 +0300 |
commit | 1cf98511b39a486f74aa4ecce02f8de422101a2d (patch) | |
tree | 3a700def960c2a93195abf7d0e1bcac09ed41b38 /File.lua | |
parent | 3b14c94d36035232f4a0efd4d5023cf3a5f0ec0b (diff) | |
parent | fd7f176c2ff32e8612bc1152504e5cf408eb214c (diff) |
Merge pull request #567 from nkoumchatzky/master
Add a hook to File:writeObject
Diffstat (limited to 'File.lua')
-rw-r--r-- | File.lua | 25 |
1 files changed, 17 insertions, 8 deletions
@@ -104,7 +104,9 @@ local function formatStack(objectNameStack) return table.concat(parts, '.') end -function File:writeObject(object, debugname) +function File:writeObject(object, debugname, hook) + -- define a default hook function if not provided + hook = hook or function(object) return object end -- we use an environment to keep a record of written objects if not torch.getenv(self).writeObjects then torch.setenv(self, { @@ -114,7 +116,14 @@ function File:writeObject(object, debugname) upvalueRefToId={}, upvalueIdToClosure={}, }) end - + -- That guy is used for references' book-keeping + local sobject = object + -- That guy is the object that is actually persisted + -- hook(object) can be used to modify the object before writing it to the file. + -- Useful for serializing objects under a config + -- that we want to deserialize safely under another config. + -- (e.g. Cuda to Float tensors, cudnn to nn, ...) + object = hook(object) local force = torch.getenv(self).force -- if nil object, only write the type and return @@ -145,7 +154,7 @@ function File:writeObject(object, debugname) -- check it exists already (we look at the pointer!) local objects = torch.getenv(self).writeObjects local objectsRef = torch.getenv(self).writeObjectsRef - local index = objects[torch.pointer(object)] + local index = objects[torch.pointer(sobject)] if index and (not force) then -- if already exists, write only its index @@ -155,7 +164,7 @@ function File:writeObject(object, debugname) index = objects.nWriteObject or 0 index = index + 1 if not force then - objects[torch.pointer(object)] = index + objects[torch.pointer(sobject)] = index objectsRef[object] = index -- we make sure the object is not going to disappear end self:writeInt(index) @@ -188,7 +197,7 @@ function File:writeObject(object, debugname) local stringStorage = torch.CharStorage():string(dumped) self:writeInt(#stringStorage) self:writeChar(stringStorage) - self:writeObject(upvalues, UPVALUES_TOKEN) + self:writeObject(upvalues, UPVALUES_TOKEN, hook) elseif typeidx == TYPE_TORCH then local version = torch.CharStorage():string('V ' .. torch.version(object)) local className = torch.CharStorage():string(torch.typename(object)) @@ -208,7 +217,7 @@ function File:writeObject(object, debugname) print(string.format('$ Warning: cannot write object field <%s> of <%s> %s', k, torch.typename(object), formatStack(objectNameStack))) end end - self:writeObject(var, torch.typename(object)) + self:writeObject(var, torch.typename(object), hook) else error(string.format('<%s> is a non-serializable Torch object %s', torch.typename(object), formatStack(objectNameStack))) end @@ -216,14 +225,14 @@ function File:writeObject(object, debugname) local size = 0; for k,v in pairs(object) do size = size + 1 end self:writeInt(size) for k,v in pairs(object) do - self:writeObject(k) + self:writeObject(k, nil, hook) local name = (type(k) == 'string' or type(k) == 'number') and tostring(k) or nil -- special case name for upvalues if objectNameStack[#objectNameStack-1] == UPVALUES_TOKEN and name == 'value' and type(object.name) == 'string' then name = object.name end - self:writeObject(v, name) + self:writeObject(v, name, hook) end end end |