author     Natalia Gimelshein <ngimelshein@nvidia.com>   2016-07-20 02:55:06 +0300
committer  Natalia Gimelshein <ngimelshein@nvidia.com>   2016-07-29 02:18:02 +0300
commit     8688e3eb439b09893cfb43901367c5cb4b3c23bd (patch)
tree       a4d7d6fdbe1447f1836c31ce0c1903624478317a
parent     c3250a987bae1cf9940b5162aa28ed0317ca5c01 (diff)
saving states works
set tensors to nil when resetting states
don't silently nil hidden states, add rememberStates to GRU
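For context, this commit threads a new `rememberStates` argument through every subclass constructor into `cudnn.RNN` and adds `RNN:resetStates()`. A minimal usage sketch, assuming the cudnn.torch API as of this commit (the tensor sizes, variable names, and LSTM choice are illustrative, not part of the commit):

```lua
require 'cudnn'

-- Arguments: inputSize, hiddenSize, numLayers, batchFirst, dropout, rememberStates
local rnn = cudnn.LSTM(10, 20, 2, false, 0, true)

-- Two consecutive chunks of one long sequence, shaped seqLen x batch x inputSize.
local chunk1 = torch.CudaTensor(5, 3, 10):normal()
local chunk2 = torch.CudaTensor(5, 3, 10):normal()

local out1 = rnn:forward(chunk1)
-- With rememberStates = true, the final hidden/cell states of chunk1 were
-- cloned into hiddenInput/cellInput, so chunk2 continues the same sequence:
local out2 = rnn:forward(chunk2)

-- Explicitly drop the carried states before feeding an unrelated sequence:
rnn:resetStates()
```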
-rw-r--r--   GRU.lua       4
-rw-r--r--   LSTM.lua      4
-rw-r--r--   RNN.lua      28
-rw-r--r--   RNNReLU.lua   4
-rw-r--r--   RNNTanh.lua   4
5 files changed, 32 insertions, 12 deletions
```diff
diff --git a/GRU.lua b/GRU.lua
--- a/GRU.lua
+++ b/GRU.lua
@@ -1,7 +1,7 @@
 local GRU, parent = torch.class('cudnn.GRU', 'cudnn.RNN')
 
-function GRU:__init(inputSize, hiddenSize, numLayers, batchFirst, dropout)
-    parent.__init(self,inputSize, hiddenSize, numLayers, batchFirst, dropout)
+function GRU:__init(inputSize, hiddenSize, numLayers, batchFirst, dropout, rememberStates)
+    parent.__init(self,inputSize, hiddenSize, numLayers, batchFirst, dropout, rememberStates)
     self.mode = 'CUDNN_GRU'
     self:reset()
 end
diff --git a/LSTM.lua b/LSTM.lua
--- a/LSTM.lua
+++ b/LSTM.lua
@@ -1,7 +1,7 @@
 local LSTM, parent = torch.class('cudnn.LSTM', 'cudnn.RNN')
 
-function LSTM:__init(inputSize, hiddenSize, numLayers, batchFirst, dropout)
-    parent.__init(self,inputSize, hiddenSize, numLayers, batchFirst, dropout)
+function LSTM:__init(inputSize, hiddenSize, numLayers, batchFirst, dropout, rememberStates)
+    parent.__init(self,inputSize, hiddenSize, numLayers, batchFirst, dropout, rememberStates)
     self.mode = 'CUDNN_LSTM'
     self:reset()
 end
diff --git a/RNN.lua b/RNN.lua
--- a/RNN.lua
+++ b/RNN.lua
@@ -11,7 +11,7 @@ RNN.linearLayers = {
    CUDNN_RNN_TANH = 2
 }
 
-function RNN:__init(inputSize, hiddenSize, numLayers, batchFirst, dropout)
+function RNN:__init(inputSize, hiddenSize, numLayers, batchFirst, dropout, rememberStates)
    parent.__init(self)
 
    self.datatype = 'CUDNN_DATA_FLOAT'
@@ -27,6 +27,7 @@ function RNN:__init(inputSize, hiddenSize, numLayers, batchFirst, dropout)
    self.dropout = dropout or 0
    self.seed = 0x01234567
    self.batchFirst = batchFirst or false -- Set to true for batch x time x inputdim.
+   self.rememberStates = rememberStates or false
 
    self.gradInput = torch.CudaTensor()
    self.output = torch.CudaTensor()
@@ -240,6 +241,18 @@ function RNN:resizeHidden(tensor)
    return tensor:resize(self.numLayers * self.numDirections, self.miniBatch, self.hiddenSize)
 end
 
+function RNN:resetStates()
+   if self.hiddenInput then
+      self.hiddenInput = nil
+   end
+   if self.cellInput then
+      self.cellInput = nil
+   end
+end
+
+
+
 function RNN:updateOutput(input)
    if (self.batchFirst) then
       input = input:transpose(1, 2)
@@ -266,6 +279,7 @@ function RNN:updateOutput(input)
    assert(input:size(3) == self.inputSize, 'Incorrect input size!')
 
+   -- Update descriptors/tensors
    if resetRNN then
       if not self.dropoutDesc then self:resetDropoutDescriptor() end
@@ -364,9 +378,15 @@
                  self.cyDesc[0], cy:data(),
                  self.workspace:data(), self.workspace:size(1) * 4) -- sizeof(float)
    end
-   if (self.batchFirst) then
-      self.output = self.output:transpose(1, 2)
-   end
+   if self.rememberStates then
+      self.hiddenInput = self.hiddenOutput:clone()
+      if self.cellOutput then
+          self.cellInput = self.cellOutput:clone()
+      end
+   end
+   if (self.batchFirst) then
+      self.output = self.output:transpose(1, 2)
+   end
    return self.output
 end
diff --git a/RNNReLU.lua b/RNNReLU.lua
index fc262e2..377e3f2 100644
--- a/RNNReLU.lua
+++ b/RNNReLU.lua
@@ -1,7 +1,7 @@
 local RNNReLU, parent = torch.class('cudnn.RNNReLU', 'cudnn.RNN')
 
-function RNNReLU:__init(inputSize, hiddenSize, numLayers, batchFirst, dropout)
-    parent.__init(self,inputSize, hiddenSize, numLayers, batchFirst, dropout)
+function RNNReLU:__init(inputSize, hiddenSize, numLayers, batchFirst, dropout, rememberStates)
+    parent.__init(self,inputSize, hiddenSize, numLayers, batchFirst, dropout, rememberStates)
     self.mode = 'CUDNN_RNN_RELU'
     self:reset()
 end
diff --git a/RNNTanh.lua b/RNNTanh.lua
index 3382a52..5cf1ccd 100644
--- a/RNNTanh.lua
+++ b/RNNTanh.lua
@@ -1,7 +1,7 @@
 local RNNTanh, parent = torch.class('cudnn.RNNTanh', 'cudnn.RNN')
 
-function RNNTanh:__init(inputSize, hiddenSize, numLayers, batchFirst, dropout)
-    parent.__init(self,inputSize, hiddenSize, numLayers, batchFirst, dropout)
+function RNNTanh:__init(inputSize, hiddenSize, numLayers, batchFirst, dropout, rememberStates)
+    parent.__init(self,inputSize, hiddenSize, numLayers, batchFirst, dropout, rememberStates)
     self.mode = 'CUDNN_RNN_TANH'
     self:reset()
 end
```
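A note on the behavior this diff establishes: when `rememberStates` is set, `updateOutput` clones `hiddenOutput` (and `cellOutput`, when present, i.e. for LSTM) into `hiddenInput`/`cellInput` after each forward pass, so the next call starts from the previous final states rather than from zeros, and `resetStates()` nils those fields explicitly instead of silently discarding them. A hedged sketch of the resulting stateful-streaming pattern (the chunked data and loop shape are assumptions for illustration, not from the commit):

```lua
require 'cudnn'

local rnn = cudnn.GRU(10, 20, 1, false, 0, true)  -- rememberStates = true

-- Pretend one long sequence arrives as four chunks of 5 timesteps each.
local chunks = {}
for i = 1, 4 do
   chunks[i] = torch.CudaTensor(5, 3, 10):normal()
end

for pass = 1, 3 do
   rnn:resetStates()  -- hiddenInput/cellInput become nil; the next forward
                      -- starts from zero initial states
   for _, chunk in ipairs(chunks) do
      local out = rnn:forward(chunk)
      -- hiddenOutput was cloned into hiddenInput, so the next chunk
      -- resumes from the state this chunk ended with.
   end
end
```

Cloning the state (rather than pointing `hiddenInput` at `hiddenOutput`) costs a copy per forward pass, but it keeps the stored state from being overwritten in place when the next forward reuses the output buffers.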