Diffstat:

 TemporalConvolution.lua | 103 +++++++++++++
 init.lua                |   2 +-
 test/test.lua           | 148 ++++++++++++++++++

 3 files changed, 252 insertions(+), 1 deletion(-)
diff --git a/TemporalConvolution.lua b/TemporalConvolution.lua
new file mode 100644
index 0000000..9ace2f2
--- /dev/null
+++ b/TemporalConvolution.lua
@@ -0,0 +1,103 @@
+local TemporalConvolution, parent =
+    torch.class('cudnn.TemporalConvolution', 'cudnn.SpatialConvolution')
+-- Use cudnn to perform temporal convolutions.
+-- Note: if the padH parameter is not passed, no padding will be performed, as
+-- in parent TemporalConvolution. However, instead of separately padding data,
+-- as is required now for nn.TemporalConvolution, it is recommended to pass the
+-- padding parameter to this routine and use cudnn's implicit padding
+-- facilities. The limitation is that padding will be equal on both sides.
+
+function TemporalConvolution:__init(inputFrameSize, outputFrameSize,
+                                    kH, dH, padH)
+   local delayedReset = self.reset
+   -- Map the temporal (1D) convolution onto a spatial (2D) one: the frame
+   -- dimension becomes the width, with a single input plane.
+   local kW = inputFrameSize
+   local nInputPlane = 1 -- single channel
+   local nOutputPlane = outputFrameSize
+   parent.__init(self, nInputPlane, nOutputPlane, kW, kH, 1, dH, 0, padH)
+   -- Present weights in nn.TemporalConvolution's 2D layout so that weights
+   -- can be copied between the nn and cudnn implementations.
+   self.weight = self.weight:view(nOutputPlane, inputFrameSize*kH)
+   self.gradWeight = self.gradWeight:view(outputFrameSize, inputFrameSize*kH)
+   self.inputFrameSize = inputFrameSize
+   -- BUGFIX: was `outputFramesize` (lowercase s), a nil local, which left
+   -- self.outputFrameSize unset.
+   self.outputFrameSize = outputFrameSize
+-- self.dW and self.kW now have different meaning than in
+-- nn.TemporalConvolution, because W and H are switched in temporal and spatial.
+end
+
+function TemporalConvolution:createIODescriptors(input)
+   -- Detect a size change before delegating, so we can re-cache the 4D output
+   -- size (self.oSize) that updateOutput needs for its transposed views.
+   local sizeChanged = false
+   if not self.iDesc or not self.oDesc or
+      input:size(1) ~= self.iSize[1] or input:size(2) ~= self.iSize[2]
+      or input:size(3) ~= self.iSize[3] or input:size(4) ~= self.iSize[4] then
+      sizeChanged = true
+   end
+   parent.createIODescriptors(self, input)
+   if sizeChanged then
+      self.oSize = self.output:size()
+   end
+end
+
+-- View a 2D (time x frame) or 3D (batch x time x frame) input as the 4D
+-- (batch x 1 x time x frame) tensor the spatial parent expects.
+local function inputview(input)
+   local _input = input
+   if input:dim() == 2 then
+      _input = input:view(1, input:size(1), input:size(2))
+   end
+   return _input:view(_input:size(1), 1, _input:size(2), _input:size(3))
+end
+
+function TemporalConvolution:updateOutput(input)
+   local _input = inputview(input)
+   assert(_input:size(4) == self.inputFrameSize, 'invalid input frame size')
+   self.buffer = self.buffer or torch.CudaTensor()
+   self._output = self._output or torch.CudaTensor()
+   -- Swap storages so the parent writes into self.buffer's storage while
+   -- self._output keeps the previously returned output storage alive.
+   if self.output:storage() then self._output:set(self.output:storage()) else self._output = self.output end
+   if self.buffer:storage() then self.output:set(self.buffer:storage()) else self.output = self.buffer end
+   parent.updateOutput(self, _input)
+   -- Transpose plane/height back into nn.TemporalConvolution's layout.
+   self.buffer = self.output:view(self.oSize):transpose(2, 3)
+   self.output = self._output:resize(self.buffer:size()):copy(self.buffer)
+   -- self.output here is always 4D; use input dimensions to properly view output.
+   if input:dim() == 3 then
+      self.output = self.output:view(self.oSize[1], self.oSize[3], self.oSize[2])
+   else
+      self.output = self.output:view(self.oSize[3], self.oSize[2])
+   end
+   return self.output
+end
+
+-- Transpose gradOutput from nn's (time x frame) layout into the 4D layout the
+-- spatial parent expects, materializing the copy into dst.
+local function transposeGradOutput(src, dst)
+   assert(src:dim() == 2 or src:dim() == 3, 'gradOutput has to be 2D or 3D');
+   local srctransposed = src:transpose(src:dim(), src:dim()-1)
+   dst:resize(srctransposed:size())
+   dst:copy(srctransposed)
+   if src:dim() == 3 then
+      dst = dst:view(dst:size(1), dst:size(2), dst:size(3), 1)
+   else
+      dst = dst:view(dst:size(1), dst:size(2), 1)
+   end
+   return dst
+end
+
+function TemporalConvolution:updateGradInput(input, gradOutput)
+   if not self.gradInput then return end
+   local _gradOutput = transposeGradOutput(gradOutput, self.buffer)
+   local _input = inputview(input)
+   self.gradInput = parent.updateGradInput(self, _input, _gradOutput)
+   -- Collapse the singleton plane dimension back out of the 4D gradInput.
+   if input:dim() == 3 then
+      self.gradInput = self.gradInput:view(self.gradInput:size(1), self.gradInput:size(3), self.gradInput:size(4))
+   else
+      self.gradInput = self.gradInput:view(self.gradInput:size(3), self.gradInput:size(4))
+   end
+   return self.gradInput
+end
+
+function TemporalConvolution:accGradParameters(input, gradOutput, scale)
+-- 2d (4d) view of input
+   local _input = inputview(input)
+-- transpose gradOutput (it will likely be transposed twice; hopefully, no big deal)
+   local _gradOutput = transposeGradOutput(gradOutput, self.buffer)
+   parent.accGradParameters(self, _input, _gradOutput, scale)
+end
+
+function TemporalConvolution:write(f)
+   -- Strip non-serializable / scratch state before serialization.
+   self.buffer = nil
+   -- BUGFIX: was `self._ouptut = nil` (typo), which left the _output scratch
+   -- tensor inside the serialized model instead of clearing it.
+   self._output = nil
+   self.oSize = nil
+   parent.write(self, f)
+end
diff --git a/init.lua b/init.lua
index 0000000..0000000
--- a/init.lua
+++ b/init.lua
@@ -112,7 +112,7 @@
 include 'LogSoftMax.lua'
 include 'SpatialCrossMapLRN.lua'
 include 'SpatialBatchNormalization.lua'
 include 'SpatialCrossEntropyCriterion.lua'
-
+include 'TemporalConvolution.lua'
 include 'functional.lua'
diff --git a/test/test.lua b/test/test.lua
index b5991c5..e11f048 100644
--- a/test/test.lua
+++ b/test/test.lua
@@ -163,6 +163,154 @@ function cudnntest.SpatialConvolution_backward_single()
        'error on bias (backward) ')
 end
 
+function cudnntest.TemporalConvolution_batch()
+   local bs = math.random(1,32)
+   local inputFrameSize = math.random(1,64)
+   local outputFrameSize = math.random(1,64)
+   local ki = math.random(1,15)
+   local si = math.random(1,ki)
+   local outi = math.random(1,15)
+   local ini = (outi-1)*si+ki
+   local scale = math.random()
+
+   local input = torch.randn(bs,ini,inputFrameSize):cuda()
+   local gradOutput = torch.randn(bs,outi,outputFrameSize):cuda()
+   local sconv = nn.TemporalConvolution(inputFrameSize,outputFrameSize, ki, si):cuda()
+   local groundForward = sconv:forward(input)
+   sconv:zeroGradParameters()
+   local groundgrad = sconv:backward(input, gradOutput, scale)
+   cutorch.synchronize()
+   local groundweight = sconv.gradWeight
+   local groundbias = sconv.gradBias
+
+   local gconv = cudnn.TemporalConvolution(inputFrameSize,outputFrameSize, ki, si):cuda():fastest()
+   gconv.weight:copy(sconv.weight:view(gconv.weight:size()))
+   gconv.bias:copy(sconv.bias)
+   gconv:forward(input)
+
+   -- serialize and deserialize
+   torch.save('modelTemp.t7', gconv)
+   gconv = torch.load('modelTemp.t7')
+
+   local cudaForward = gconv:forward(input)
+   gconv:zeroGradParameters()
+   local rescuda = gconv:backward(input, gradOutput, scale)
+   cutorch.synchronize()
+   local weightcuda = gconv.gradWeight
+   local biascuda = gconv.gradBias
+
+   local ferror = cudaForward:float() - groundForward:float()
+   local error = rescuda:float() - groundgrad:float()
+   local werror = weightcuda:float() - groundweight:float()
+   local berror = biascuda:float() - groundbias:float()
+   mytester:assertlt(ferror:abs():max(), precision_forward, 'error on forward ')
+   mytester:assertlt(error:abs():max(), precision_backward, 'error on state (backward) ')
+   mytester:assertlt(werror:abs():max(), precision_backward, 'error on weight (backward) ')
+   mytester:assertlt(berror:abs():max(), precision_backward, 'error on bias (backward) ')
+end
+
+function cudnntest.TemporalConvolution_padding_batch()
+   local bs = math.random(1,32)
+   local inputFrameSize = math.random(1,64)
+   local outputFrameSize = math.random(1,64)
+   local ki = math.random(2,15)
+   local pad_h = math.floor(ki/2)
+   local si = math.random(1,ki)
+   local outi = math.random(1,15)
+   local ini = (outi-1)*si+ki
+   local scale = math.random()
+
+   local inputpadded = torch.randn(bs,ini,inputFrameSize):cuda()
+   for i=1,pad_h do
+      inputpadded:narrow(2,i,1):fill(0)
+      inputpadded:narrow(2,ini-i+1,1):fill(0)
+   end
+   local input = torch.Tensor(bs,ini - 2 * pad_h, inputFrameSize):cuda()
+   input:copy(inputpadded:narrow(2, pad_h+1, ini - 2 * pad_h))
+   local gradOutput = torch.randn(bs,outi,outputFrameSize):cuda()
+   local sconv = nn.TemporalConvolution(inputFrameSize,outputFrameSize, ki, si):cuda()
+   local groundForward = sconv:forward(inputpadded)
+   sconv:zeroGradParameters()
+   local groundgrad = sconv:backward(inputpadded, gradOutput, scale)
+   cutorch.synchronize()
+   local groundweight = sconv.gradWeight
+   local groundbias = sconv.gradBias
+
+   local gconv = cudnn.TemporalConvolution(inputFrameSize,outputFrameSize, ki, si,pad_h):cuda():fastest()
+   gconv.weight:copy(sconv.weight:view(gconv.weight:size()))
+   gconv.bias:copy(sconv.bias)
+   gconv:forward(input)
+
+   -- serialize and deserialize
+   torch.save('modelTemp.t7', gconv)
+   gconv = torch.load('modelTemp.t7')
+
+   local cudaForward = gconv:forward(input)
+   gconv:zeroGradParameters()
+   local rescuda = gconv:backward(input, gradOutput, scale)
+   cutorch.synchronize()
+   local weightcuda = gconv.gradWeight
+   local biascuda = gconv.gradBias
+
+   local ferror = cudaForward:float() - groundForward:float()
+   groundgrad = groundgrad:narrow(2, pad_h + 1, ini - 2 * pad_h)
+   local error = rescuda:float() - groundgrad:float()
+   local werror = weightcuda:float() - groundweight:float()
+   local berror = biascuda:float() - groundbias:float()
+   mytester:assertlt(ferror:abs():max(), precision_forward, 'error on forward ')
+   mytester:assertlt(error:abs():max(), precision_backward, 'error on state (backward) ')
+   mytester:assertlt(werror:abs():max(), precision_backward, 'error on weight (backward) ')
+   mytester:assertlt(berror:abs():max(), precision_backward, 'error on bias (backward) ')
+end
+
+
+function cudnntest.TemporalConvolution_single()
+   local inputFrameSize = math.random(1,64)
+   local outputFrameSize = math.random(1,64)
+   local ki = math.random(1,15)
+   local si = math.random(1,ki)
+   local outi = math.random(1,15)
+   local ini = (outi-1)*si+ki
+   local scale = math.random()
+
+   local input = torch.randn(ini,inputFrameSize):cuda()
+   local gradOutput = torch.randn(outi,outputFrameSize):cuda()
+   local sconv = nn.TemporalConvolution(inputFrameSize,outputFrameSize, ki, si):cuda()
+   local groundForward = sconv:forward(input)
+   sconv:zeroGradParameters()
+   local groundgrad = sconv:backward(input, gradOutput, scale)
+   cutorch.synchronize()
+   local groundweight = sconv.gradWeight
+   local groundbias = sconv.gradBias
+
+   local gconv = cudnn.TemporalConvolution(inputFrameSize,outputFrameSize, ki, si):cuda():fastest()
+   gconv.weight:copy(sconv.weight:view(gconv.weight:size()))
+   gconv.bias:copy(sconv.bias)
+   gconv:forward(input)
+
+   -- serialize and deserialize
+   torch.save('modelTemp.t7', gconv)
+   gconv = torch.load('modelTemp.t7')
+
+   local cudaForward = gconv:forward(input)
+   gconv:zeroGradParameters()
+   local rescuda = gconv:backward(input, gradOutput, scale)
+   cutorch.synchronize()
+   local weightcuda = gconv.gradWeight
+   local biascuda = gconv.gradBias
+
+   local ferror = cudaForward:float() - groundForward:float()
+   local error = rescuda:float() - groundgrad:float()
+   local werror = weightcuda:float() - groundweight:float()
+   local berror = biascuda:float() - groundbias:float()
+   mytester:assertlt(ferror:abs():max(), precision_forward, 'error on forward ')
+   mytester:assertlt(error:abs():max(), precision_backward, 'error on state (backward) ')
+   mytester:assertlt(werror:abs():max(), precision_backward, 'error on weight (backward) ')
+   mytester:assertlt(berror:abs():max(), precision_backward, 'error on bias (backward) ')
+end
+
+
+
 function cudnntest.VolumetricConvolution_forward_single()
    local from = math.random(1,16)
    local to = math.random(1,16)