Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/soumith/cudnn.torch.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNick Hynes <nhynes@mit.edu>2016-06-29 19:17:30 +0300
committerNick Hynes <nhynes@mit.edu>2016-06-29 19:53:15 +0300
commit378d02f9a2da13dc3614d95a21ec4b555524c561 (patch)
tree5760f7aa2678f9d549760727f91e08ee0e76348b
parente7b1a0c5e36b2983986c58d28d45d0a9649c84ee (diff)
push/pop descs before/after write
-rw-r--r--RNN.lua25
1 files changed, 15 insertions, 10 deletions
diff --git a/RNN.lua b/RNN.lua
index 23b1d46..2d415b1 100644
--- a/RNN.lua
+++ b/RNN.lua
@@ -2,6 +2,8 @@ local RNN, parent = torch.class('cudnn.RNN', 'nn.Module')
local ffi = require 'ffi'
local errcheck = cudnn.errcheck
+local DESCS = {'rnnDesc', 'dropoutDesc', 'wDesc', 'xDescs', 'yDescs', 'hxDesc', 'hyDesc', 'cxDesc', 'cyDesc'}
+
function RNN:__init(inputSize, hiddenSize, numLayers, batchFirst)
parent.__init(self)
@@ -512,25 +514,28 @@ function RNN:accGradParameters(input, gradOutput, scale)
end
function RNN:clearDesc()
- self.dropoutDesc = nil
- self.rnnDesc = nil
- self.dropoutDesc = nil
- self.wDesc = nil
- self.xDescs = nil
- self.yDescs = nil
- self.hxDesc = nil
- self.hyDesc = nil
- self.cxDesc = nil
- self.cyDesc = nil
+ for _, desc in pairs(DESCS) do
+ self[desc] = nil
+ end
end
function RNN:write(f)
+ local pushDescs = {}
+ for _, desc in pairs(DESCS) do
+ pushDescs[desc] = self[desc]
+ end
+
self:clearDesc()
+
local var = {}
for k,v in pairs(self) do
var[k] = v
end
f:writeObject(var)
+
+ for desc, v in pairs(pushDescs) do
+ self[desc] = v
+ end
end
function RNN:clearState()