diff options
author | soumith <soumith@fb.com> | 2016-08-04 02:04:43 +0300 |
---|---|---|
committer | soumith <soumith@fb.com> | 2016-08-04 06:00:02 +0300 |
commit | e82c40d454392e40dc8ad6d94e65e3b696429aa7 (patch) | |
tree | 515a10f8c907fc9d9708985f81d7a5d40550adaf /VolumetricDilatedConvolution.lua | |
parent | 5bb6470df6c7c89866272ce9d40deaf1b2044773 (diff) |
volumetric dilated convolution
Diffstat (limited to 'VolumetricDilatedConvolution.lua')
-rw-r--r-- | VolumetricDilatedConvolution.lua | 103 |
1 files changed, 103 insertions, 0 deletions
diff --git a/VolumetricDilatedConvolution.lua b/VolumetricDilatedConvolution.lua new file mode 100644 index 0000000..fc7f037 --- /dev/null +++ b/VolumetricDilatedConvolution.lua @@ -0,0 +1,103 @@ +local THNN = require 'nn.THNN' +local VolumetricDilatedConvolution, parent = torch.class('nn.VolumetricDilatedConvolution', 'nn.VolumetricConvolution') + +function VolumetricDilatedConvolution:__init(nInputPlane, nOutputPlane, kT, kW, kH, dT, dW, dH, padT, padW, padH, dilationT, dilationW, dilationH) + parent.__init(self, nInputPlane, nOutputPlane, kT, kW, kH, dT, dW, dH, padT, padW, padH) + + self.dilationT = dilationT or 1 + self.dilationW = dilationW or 1 + self.dilationH = dilationH or 1 +end + +local function makeContiguous(self, input, gradOutput) + if not input:isContiguous() then + self._input = self._input or input.new() + self._input:resizeAs(input):copy(input) + input = self._input + end + if gradOutput then + if not gradOutput:isContiguous() then + self._gradOutput = self._gradOutput or gradOutput.new() + self._gradOutput:resizeAs(gradOutput):copy(gradOutput) + gradOutput = self._gradOutput + end + end + return input, gradOutput +end + +function VolumetricDilatedConvolution:updateOutput(input) + self.finput = self.finput or self.weight.new() + self.fgradInput = self.fgradInput or self.weight.new() + input = makeContiguous(self, input) + input.THNN.VolumetricDilatedConvolution_updateOutput( + input:cdata(), + self.output:cdata(), + self.weight:cdata(), + THNN.optionalTensor(self.bias), + self.finput:cdata(), + self.fgradInput:cdata(), + self.kT, self.kW, self.kH, + self.dT, self.dW, self.dH, + self.padT, self.padW, self.padH, + self.dilationT, self.dilationW, self.dilationH + ) + return self.output +end + +function VolumetricDilatedConvolution:updateGradInput(input, gradOutput) + if self.gradInput then + input, gradOutput = makeContiguous(self, input, gradOutput) + self.fgradInput = self.fgradInput or self.weight.new() + input.THNN.VolumetricDilatedConvolution_updateGradInput( + input:cdata(), + gradOutput:cdata(), + self.gradInput:cdata(), + self.weight:cdata(), + self.finput:cdata(), + self.kT, self.kW, self.kH, + self.dT, self.dW, self.dH, + self.padT, self.padW, self.padH, + self.dilationT, self.dilationW, self.dilationH + ) + return self.gradInput + end +end + +function VolumetricDilatedConvolution:accGradParameters(input, gradOutput, scale) + scale = scale or 1 + input, gradOutput = makeContiguous(self, input, gradOutput) + self.fgradInput = self.fgradInput or self.weight.new() + input.THNN.VolumetricDilatedConvolution_accGradParameters( + input:cdata(), + gradOutput:cdata(), + self.gradWeight:cdata(), + THNN.optionalTensor(self.gradBias), + self.finput:cdata(), + self.fgradInput:cdata(), + self.kT, self.kW, self.kH, + self.dT, self.dW, self.dH, + self.padT, self.padW, self.padH, + self.dilationT, self.dilationW, self.dilationH, + scale + ) +end + +function VolumetricDilatedConvolution:__tostring__() + local s = string.format('%s(%d -> %d, %dx%dx%d', torch.type(self), + self.nInputPlane, self.nOutputPlane, self.kT, self.kW, self.kH) + if self.dT ~= 1 or self.dW ~= 1 or self.dH ~= 1 + or self.padT ~= 0 or self.padW ~= 0 or self.padH ~= 0 then + s = s .. string.format(', %d,%d,%d', self.dT, self.dW, self.dH) + end + if (self.padT or self.padW or self.padH) + and (self.padT ~= 0 or self.padW ~= 0 or self.padH ~= 0) then + s = s .. ', ' .. self.padT .. ',' .. self.padW .. ',' .. self.padH + end + s = s .. ', ' .. self.dilationT .. ',' + .. self.dilationW .. ',' .. self.dilationH + if self.bias then + return s .. ')' + else + return s .. ') without bias' + end +end |