Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/soumith/cudnn.torch.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorngimel <ngimelshein@nvidia.com>2016-10-01 02:44:53 +0300
committerGitHub <noreply@github.com>2016-10-01 02:44:53 +0300
commit383c3dab11a1d46943a93684d12834e6a9dbcce4 (patch)
tree0ee6469c7e8fe222b71a992fb1de39a47c770305
parent278fb2c01bba4bd15c606a27509863cde5960ef3 (diff)
Sync rnn (#2)
optional syncs in RNN, do not reallocate dropoutStates, fix dropoutStates size
-rw-r--r--RNN.lua11
1 files changed, 10 insertions, 1 deletions
diff --git a/RNN.lua b/RNN.lua
index 1763110..aa917cb 100644
--- a/RNN.lua
+++ b/RNN.lua
@@ -28,6 +28,7 @@ function RNN:__init(inputSize, hiddenSize, numLayers, batchFirst, dropout, remem
self.seed = 0x01234567
self.batchFirst = batchFirst or false -- Set to true for batch x time x inputdim.
self.rememberStates = rememberStates or false
+ self.sync = false
self.gradInput = torch.CudaTensor()
self.output = torch.CudaTensor()
@@ -43,6 +44,10 @@ function RNN:__init(inputSize, hiddenSize, numLayers, batchFirst, dropout, remem
self:reset()
end
+function RNN:setSync(sync)
+ self.sync = sync
+end
+
function RNN:reset(stdv)
stdv = stdv or 1.0 / math.sqrt(self.hiddenSize)
@@ -114,7 +119,9 @@ function RNN:resetDropoutDescriptor()
errcheck('cudnnDropoutGetStatesSize',
cudnn.getHandle(),
self.dropoutStatesSize:data())
- self.dropoutStates = torch.CudaTensor(self.dropoutStatesSize[1])
+ self.dropoutStates = self.dropoutStates or torch.CudaTensor()
+ local nElem = ((self.dropoutStatesSize[1]-1)/self.dropoutStates:elementSize()+1)
+ self.dropoutStates:resize(nElem)
errcheck('cudnnSetDropoutDescriptor',
self.dropoutDesc[0],
@@ -384,6 +391,7 @@ function RNN:updateOutput(input)
self.cyDesc[0], cy:data(),
self.workspace:data(), self.workspace:size(1) * 4) -- sizeof(float)
end
+ if self.sync then cutorch.synchronize() end
if self.rememberStates then
self.hiddenInput = self.hiddenOutput:clone()
if self.cellOutput then
@@ -473,6 +481,7 @@ function RNN:updateGradInput(input, gradOutput)
self.cxDesc[0], dcx:data(),
self.workspace:data(), self.workspace:size(1) * 4, -- sizeof(float)
self.reserve:data(), self.reserve:size(1) * 4) -- sizeof(float)
+ if self.sync then cutorch.synchronize() end
if (self.batchFirst) then
self.gradInput = self.gradInput:transpose(1, 2)
self.output = self.output:transpose(1, 2)