diff options
author | Frederic Besse <fbesse@google.com> | 2016-02-29 18:02:09 +0300 |
---|---|---|
committer | Frederic Besse <fbesse@google.com> | 2016-02-29 21:26:58 +0300 |
commit | 138e30164340d781ea2a5730092bb7781eee9cfd (patch) | |
tree | d3ce5bf2f8ec622597545c911c3fd485acc3c0c9 /TemporalConvolution.lua | |
parent | b80bdbab6faf66b711bf7c8a159703f62508c5e2 (diff) |
Fixing backward pass of TemporalConvolution
Diffstat (limited to 'TemporalConvolution.lua')
-rw-r--r-- | TemporalConvolution.lua | 9 |
1 files changed, 7 insertions, 2 deletions
diff --git a/TemporalConvolution.lua b/TemporalConvolution.lua index 72a87c3..014cd95 100644 --- a/TemporalConvolution.lua +++ b/TemporalConvolution.lua @@ -13,7 +13,7 @@ function TemporalConvolution:__init(inputFrameSize, outputFrameSize, local nInputPlane = 1 -- single channel local nOutputPlane = outputFrameSize self.inputFrameSize = inputFrameSize - self.outputFrameSize = outputFramesize + self.outputFrameSize = outputFrameSize cudnn.SpatialConvolution.__init(self, nInputPlane, nOutputPlane, kW, kH, 1, dH,0,padH) self.weight = self.weight:view(nOutputPlane,inputFrameSize*kH) self.gradWeight = self.gradWeight:view(outputFrameSize, inputFrameSize*kH) @@ -26,7 +26,7 @@ function TemporalConvolution:createIODescriptors(input) if not self.iDesc or not self.oDesc or input:size(1) ~= self.iSize[1] or input:size(2) ~= self.iSize[2] or input:size(3) ~= self.iSize[3] or input:size(4) ~= self.iSize[4] then - sizeChanged = true + sizeChanged = true end cudnn.SpatialConvolution.createIODescriptors(self,input) if sizeChanged then @@ -87,6 +87,11 @@ function TemporalConvolution:updateGradInput(input, gradOutput) if not self.gradInput then return end local _gradOutput = transposeGradOutput(gradOutput,self.buffer) local _input = inputview(input) + if input:dim()==3 and self.gradInput:dim() == 3 then + self.gradInput = self.gradInput:view(self.gradInput:size(1), 1, self.gradInput:size(2),self.gradInput:size(3)) + elseif input:dim() == 2 and self.gradInput:dim() == 2 then + self.gradInput = self.gradInput:view(1, 1, self.gradInput:size(1),self.gradInput:size(2)) + end self.gradInput = cudnn.SpatialConvolution.updateGradInput(self,_input, _gradOutput) if input:dim()==3 then self.gradInput = self.gradInput:view(self.gradInput:size(1),self.gradInput:size(3),self.gradInput:size(4)) |