diff options
author | Nick Hynes <nhynes@mit.edu> | 2016-06-29 19:17:30 +0300 |
---|---|---|
committer | Nick Hynes <nhynes@mit.edu> | 2016-06-29 19:53:15 +0300 |
commit | 378d02f9a2da13dc3614d95a21ec4b555524c561 (patch) | |
tree | 5760f7aa2678f9d549760727f91e08ee0e76348b | |
parent | e7b1a0c5e36b2983986c58d28d45d0a9649c84ee (diff) |
push/pop descs before/after write
-rw-r--r-- | RNN.lua | 25 |
1 files changed, 15 insertions, 10 deletions
@@ -2,6 +2,8 @@ local RNN, parent = torch.class('cudnn.RNN', 'nn.Module') local ffi = require 'ffi' local errcheck = cudnn.errcheck +local DESCS = {'rnnDesc', 'dropoutDesc', 'wDesc', 'xDescs', 'yDescs', 'hxDesc', 'hyDesc', 'cxDesc', 'cyDesc'} + function RNN:__init(inputSize, hiddenSize, numLayers, batchFirst) parent.__init(self) @@ -512,25 +514,28 @@ function RNN:accGradParameters(input, gradOutput, scale) end function RNN:clearDesc() - self.dropoutDesc = nil - self.rnnDesc = nil - self.dropoutDesc = nil - self.wDesc = nil - self.xDescs = nil - self.yDescs = nil - self.hxDesc = nil - self.hyDesc = nil - self.cxDesc = nil - self.cyDesc = nil + for _, desc in pairs(DESCS) do + self[desc] = nil + end end function RNN:write(f) + local pushDescs = {} + for _, desc in pairs(DESCS) do + pushDescs[desc] = self[desc] + end + self:clearDesc() + local var = {} for k,v in pairs(self) do var[k] = v end f:writeObject(var) + + for desc, v in pairs(pushDescs) do + self[desc] = v + end end function RNN:clearState() |