Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/soumith/cudnn.torch.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorFrederic Besse <fbesse@google.com>2016-02-29 18:02:09 +0300
committerFrederic Besse <fbesse@google.com>2016-02-29 21:26:58 +0300
commit138e30164340d781ea2a5730092bb7781eee9cfd (patch)
treed3ce5bf2f8ec622597545c911c3fd485acc3c0c9 /TemporalConvolution.lua
parentb80bdbab6faf66b711bf7c8a159703f62508c5e2 (diff)
Fixing backward pass of TemporalConvolution
Diffstat (limited to 'TemporalConvolution.lua')
-rw-r--r--TemporalConvolution.lua9
1 files changed, 7 insertions, 2 deletions
diff --git a/TemporalConvolution.lua b/TemporalConvolution.lua
index 72a87c3..014cd95 100644
--- a/TemporalConvolution.lua
+++ b/TemporalConvolution.lua
@@ -13,7 +13,7 @@ function TemporalConvolution:__init(inputFrameSize, outputFrameSize,
local nInputPlane = 1 -- single channel
local nOutputPlane = outputFrameSize
self.inputFrameSize = inputFrameSize
- self.outputFrameSize = outputFramesize
+ self.outputFrameSize = outputFrameSize
cudnn.SpatialConvolution.__init(self, nInputPlane, nOutputPlane, kW, kH, 1, dH,0,padH)
self.weight = self.weight:view(nOutputPlane,inputFrameSize*kH)
self.gradWeight = self.gradWeight:view(outputFrameSize, inputFrameSize*kH)
@@ -26,7 +26,7 @@ function TemporalConvolution:createIODescriptors(input)
if not self.iDesc or not self.oDesc or
input:size(1) ~= self.iSize[1] or input:size(2) ~= self.iSize[2]
or input:size(3) ~= self.iSize[3] or input:size(4) ~= self.iSize[4] then
- sizeChanged = true
+ sizeChanged = true
end
cudnn.SpatialConvolution.createIODescriptors(self,input)
if sizeChanged then
@@ -87,6 +87,11 @@ function TemporalConvolution:updateGradInput(input, gradOutput)
if not self.gradInput then return end
local _gradOutput = transposeGradOutput(gradOutput,self.buffer)
local _input = inputview(input)
+ if input:dim()==3 and self.gradInput:dim() == 3 then
+ self.gradInput = self.gradInput:view(self.gradInput:size(1), 1, self.gradInput:size(2),self.gradInput:size(3))
+ elseif input:dim() == 2 and self.gradInput:dim() == 2 then
+ self.gradInput = self.gradInput:view(1, 1, self.gradInput:size(1),self.gradInput:size(2))
+ end
self.gradInput = cudnn.SpatialConvolution.updateGradInput(self,_input, _gradOutput)
if input:dim()==3 then
self.gradInput = self.gradInput:view(self.gradInput:size(1),self.gradInput:size(3),self.gradInput:size(4))