author     ngimel <ngimelshein@nvidia.com>    2016-10-01 02:44:53 +0300
committer  GitHub <noreply@github.com>        2016-10-01 02:44:53 +0300
commit     383c3dab11a1d46943a93684d12834e6a9dbcce4
tree       0ee6469c7e8fe222b71a992fb1de39a47c770305
parent     278fb2c01bba4bd15c606a27509863cde5960ef3
Sync rnn (#2)
optional syncs in RNN, do not reallocate dropoutStates, fix dropoutStates size
-rw-r--r--  RNN.lua  11
1 file changed, 10 insertions(+), 1 deletion(-)
@@ -28,6 +28,7 @@ function RNN:__init(inputSize, hiddenSize, numLayers, batchFirst, dropout, remem
    self.seed = 0x01234567
    self.batchFirst = batchFirst or false -- Set to true for batch x time x inputdim.
    self.rememberStates = rememberStates or false
+   self.sync = false
    self.gradInput = torch.CudaTensor()
    self.output = torch.CudaTensor()
    self.weight = torch.CudaTensor()
@@ -43,6 +44,10 @@ function RNN:__init(inputSize, hiddenSize, numLayers, batchFirst, dropout, remem
    self:training()
    self:reset()
 end
 
+function RNN:setSync(sync)
+   self.sync = sync
+end
+
 function RNN:reset(stdv)
    stdv = stdv or 1.0 / math.sqrt(self.hiddenSize)
@@ -114,7 +119,9 @@ function RNN:resetDropoutDescriptor()
    errcheck('cudnnDropoutGetStatesSize',
             cudnn.getHandle(),
             self.dropoutStatesSize:data())
-   self.dropoutStates = torch.CudaTensor(self.dropoutStatesSize[1])
+   self.dropoutStates = self.dropoutStates or torch.CudaTensor()
+   local nElem = ((self.dropoutStatesSize[1]-1)/self.dropoutStates:elementSize()+1)
+   self.dropoutStates:resize(nElem)
 
    errcheck('cudnnSetDropoutDescriptor',
             self.dropoutDesc[0],
@@ -384,6 +391,7 @@ function RNN:updateOutput(input)
                 self.cyDesc[0], cy:data(),
                 self.workspace:data(), self.workspace:size(1) * 4) -- sizeof(float)
    end
+   if self.sync then cutorch.synchronize() end
    if self.rememberStates then
       self.hiddenInput = self.hiddenOutput:clone()
      if self.cellOutput then
@@ -473,6 +481,7 @@ function RNN:updateGradInput(input, gradOutput)
             self.cxDesc[0], dcx:data(),
             self.workspace:data(), self.workspace:size(1) * 4, -- sizeof(float)
             self.reserve:data(), self.reserve:size(1) * 4) -- sizeof(float)
+   if self.sync then cutorch.synchronize() end
    if (self.batchFirst) then
      self.gradInput = self.gradInput:transpose(1, 2)
      self.output = self.output:transpose(1, 2)
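
For context, a minimal usage sketch of the toggle this commit adds, assuming a standard cudnn.torch setup (note setSync exists in this fork, not necessarily upstream); the layer sizes and tensor shapes below are illustrative, not taken from the patch:

require 'cudnn'
require 'cutorch'

-- Construct an RNN as usual: inputSize=10, hiddenSize=20, numLayers=2 (illustrative).
local rnn = cudnn.RNN(10, 20, 2)

-- New in this commit: opt in to a cutorch.synchronize() after the forward and
-- backward cuDNN calls, e.g. for accurate timing or easier debugging. Syncs
-- stay off by default so the usual asynchronous behavior is unchanged.
rnn:setSync(true)

-- seqLength x miniBatch x inputSize (batchFirst is false by default).
local input = torch.CudaTensor(5, 4, 10):uniform()
local output = rnn:updateOutput(input)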
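
The dropoutStates change does two things: the buffer is now kept and reused across resetDropoutDescriptor calls instead of being reallocated, and the byte count reported by cudnnDropoutGetStatesSize is converted into a float-element count by ceiling division (the old code allocated one 4-byte element per byte, over-allocating roughly fourfold). A small sketch of the same arithmetic, with a made-up byte count; the patch relies on the implicit truncation in resize, while math.floor makes it explicit here:

-- Ceiling division from bytes to elements, as in the patched resetDropoutDescriptor.
local stateBytes = 1386    -- illustrative; really reported by cudnnDropoutGetStatesSize
local elementSize = 4      -- bytes per element of a torch.CudaTensor (float)

-- (n - 1) / size + 1, truncated to an integer, equals ceil(n / size).
local nElem = math.floor((stateBytes - 1) / elementSize) + 1

assert(nElem * elementSize >= stateBytes)        -- buffer is never too small
assert((nElem - 1) * elementSize < stateBytes)   -- and is the smallest such count
print(nElem)                                     --> 347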