github.com/torch/nn.git
author    soumith <soumith@fb.com>  2016-08-04 02:04:43 +0300
committer soumith <soumith@fb.com>  2016-08-04 06:00:02 +0300
commit    e82c40d454392e40dc8ad6d94e65e3b696429aa7 (patch)
tree      515a10f8c907fc9d9708985f81d7a5d40550adaf /VolumetricDilatedConvolution.lua
parent    5bb6470df6c7c89866272ce9d40deaf1b2044773 (diff)
Add volumetric dilated convolution (nn.VolumetricDilatedConvolution)
Diffstat (limited to 'VolumetricDilatedConvolution.lua')
-rw-r--r--  VolumetricDilatedConvolution.lua | 103
1 file changed, 103 insertions(+), 0 deletions(-)
diff --git a/VolumetricDilatedConvolution.lua b/VolumetricDilatedConvolution.lua
new file mode 100644
index 0000000..fc7f037
--- /dev/null
+++ b/VolumetricDilatedConvolution.lua
@@ -0,0 +1,103 @@
+local THNN = require 'nn.THNN'
+local VolumetricDilatedConvolution, parent = torch.class('nn.VolumetricDilatedConvolution', 'nn.VolumetricConvolution')
+
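+-- nInputPlane/nOutputPlane: number of input/output feature planes.
+-- kT,kW,kH: kernel size; dT,dW,dH: stride; padT,padW,padH: zero-padding.
+-- dilationT,dilationW,dilationH: spacing between kernel elements; a value
+-- of 1 (the default) gives an ordinary VolumetricConvolution.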
+function VolumetricDilatedConvolution:__init(nInputPlane, nOutputPlane, kT, kW, kH, dT, dW, dH, padT, padW, padH, dilationT, dilationW, dilationH)
+ parent.__init(self, nInputPlane, nOutputPlane, kT, kW, kH, dT, dW, dH, padT, padW, padH)
+
+ self.dilationT = dilationT or 1
+ self.dilationW = dilationW or 1
+ self.dilationH = dilationH or 1
+end
+
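+-- The THNN kernels require contiguous tensors; copy input (and gradOutput,
+-- when given) into cached buffers if they are not already contiguous.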
+local function makeContiguous(self, input, gradOutput)
+ if not input:isContiguous() then
+ self._input = self._input or input.new()
+ self._input:resizeAs(input):copy(input)
+ input = self._input
+ end
+ if gradOutput then
+ if not gradOutput:isContiguous() then
+ self._gradOutput = self._gradOutput or gradOutput.new()
+ self._gradOutput:resizeAs(gradOutput):copy(gradOutput)
+ gradOutput = self._gradOutput
+ end
+ end
+ return input, gradOutput
+end
+
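+-- Forward pass: delegates to the THNN C implementation, lazily allocating
+-- the finput/fgradInput work buffers on first use.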
+function VolumetricDilatedConvolution:updateOutput(input)
+ self.finput = self.finput or self.weight.new()
+ self.fgradInput = self.fgradInput or self.weight.new()
+ input = makeContiguous(self, input)
+ input.THNN.VolumetricDilatedConvolution_updateOutput(
+ input:cdata(),
+ self.output:cdata(),
+ self.weight:cdata(),
+ THNN.optionalTensor(self.bias),
+ self.finput:cdata(),
+ self.fgradInput:cdata(),
+ self.kT, self.kW, self.kH,
+ self.dT, self.dW, self.dH,
+ self.padT, self.padW, self.padH,
+ self.dilationT, self.dilationW, self.dilationH
+ )
+ return self.output
+end
+
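+-- Backward pass w.r.t. the input; skipped entirely when self.gradInput
+-- has been set to nil to save memory.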
+function VolumetricDilatedConvolution:updateGradInput(input, gradOutput)
+ if self.gradInput then
+ input, gradOutput = makeContiguous(self, input, gradOutput)
+ self.fgradInput = self.fgradInput or self.weight.new()
+ input.THNN.VolumetricDilatedConvolution_updateGradInput(
+ input:cdata(),
+ gradOutput:cdata(),
+ self.gradInput:cdata(),
+ self.weight:cdata(),
+ self.finput:cdata(),
+ self.kT, self.kW, self.kH,
+ self.dT, self.dW, self.dH,
+ self.padT, self.padW, self.padH,
+ self.dilationT, self.dilationW, self.dilationH
+ )
+ return self.gradInput
+ end
+end
+
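+-- Accumulate gradients w.r.t. weight and bias, scaled by `scale`
+-- (defaults to 1).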
+function VolumetricDilatedConvolution:accGradParameters(input, gradOutput, scale)
+ scale = scale or 1
+ input, gradOutput = makeContiguous(self, input, gradOutput)
+ self.fgradInput = self.fgradInput or self.weight.new()
+ input.THNN.VolumetricDilatedConvolution_accGradParameters(
+ input:cdata(),
+ gradOutput:cdata(),
+ self.gradWeight:cdata(),
+ THNN.optionalTensor(self.gradBias),
+ self.finput:cdata(),
+ self.fgradInput:cdata(),
+ self.kT, self.kW, self.kH,
+ self.dT, self.dW, self.dH,
+ self.padT, self.padW, self.padH,
+ self.dilationT, self.dilationW, self.dilationH,
+ scale
+ )
+end
+
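+-- Human-readable summary, e.g.
+-- nn.VolumetricDilatedConvolution(3 -> 8, 3x3x3, 1,1,1, 1,1,1, 2,2,2)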
+function VolumetricDilatedConvolution:__tostring__()
+ local s = string.format('%s(%d -> %d, %dx%dx%d', torch.type(self),
+ self.nInputPlane, self.nOutputPlane, self.kT, self.kW, self.kH)
+ if self.dT ~= 1 or self.dW ~= 1 or self.dH ~= 1
+ or self.padT ~= 0 or self.padW ~= 0 or self.padH ~= 0 then
+ s = s .. string.format(', %d,%d,%d', self.dT, self.dW, self.dH)
+ end
+ if (self.padT or self.padW or self.padH)
+ and (self.padT ~= 0 or self.padW ~= 0 or self.padH ~= 0) then
+ s = s .. ', ' .. self.padT .. ',' .. self.padW .. ',' .. self.padH
+ end
+ s = s .. ', ' .. self.dilationT .. ','
+ .. self.dilationW .. ',' .. self.dilationH
+ if self.bias then
+ return s .. ')'
+ else
+ return s .. ') without bias'
+ end
+end
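
For reference, a minimal usage sketch (not part of the commit; the shapes and parameter values below are illustrative assumptions):

require 'nn'

-- 3 -> 8 feature planes, 3x3x3 kernel, stride 1, padding 1, dilation 2
local conv = nn.VolumetricDilatedConvolution(3, 8, 3,3,3, 1,1,1, 1,1,1, 2,2,2)

-- input is batch x nInputPlane x time x height x width
local input = torch.randn(2, 3, 16, 32, 32)
local output = conv:forward(input)  -- 2 x 8 x 14 x 30 x 30

Each output dimension follows the usual dilated-convolution size formula, e.g. oT = floor((iT + 2*padT - dilationT*(kT - 1) - 1) / dT) + 1, which gives 16 + 2 - 4 - 1 + 1 = 14 along the time axis here.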