Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorkmul00 <coolkoustav@gmail.com>2016-08-28 16:27:28 +0300
committerkmul00 <coolkoustav@gmail.com>2016-08-28 17:08:52 +0300
commitddebc7227dd82bd518954da24bedfde2afb07965 (patch)
treed11dd01524ffdcd50fdde18ae30a64631f094b8b /VolumetricDilatedMaxPooling.lua
parentc1232cd0daa6ea20289a6c63c6f4c8246c9af3a9 (diff)
Add Volumetric Dilated Max Pooling
new file: VolumetricDilatedMaxPooling.lua modified: init.lua modified: lib/THNN/generic/THNN.h copied: lib/THNN/generic/VolumetricMaxPooling.c -> lib/THNN/generic/VolumetricDilatedMaxPooling.c modified: lib/THNN/generic/VolumetricMaxPooling.c modified: lib/THNN/init.c modified: test.lua Update docs modified: doc/convolution.md Resolved bug in ceil mode modified: lib/THNN/generic/VolumetricDilatedMaxPooling.c
Diffstat (limited to 'VolumetricDilatedMaxPooling.lua')
-rw-r--r--VolumetricDilatedMaxPooling.lua64
1 files changed, 64 insertions, 0 deletions
diff --git a/VolumetricDilatedMaxPooling.lua b/VolumetricDilatedMaxPooling.lua
new file mode 100644
index 0000000..050e2c9
--- /dev/null
+++ b/VolumetricDilatedMaxPooling.lua
@@ -0,0 +1,64 @@
+local THNN = require 'nn.THNN'
+local VolumetricDilatedMaxPooling, parent = torch.class('nn.VolumetricDilatedMaxPooling', 'nn.VolumetricMaxPooling')
+
+function VolumetricDilatedMaxPooling:__init(kT, kW, kH, dT, dW, dH, padT, padW, padH, dilationT, dilationW, dilationH)
+ parent.__init(self, kT, kW, kH, dT, dW, dH, padT, padW, padH)
+
+ self.dilationT = dilationT or 1
+ self.dilationW = dilationW or 1
+ self.dilationH = dilationH or 1
+
+end
+
+function VolumetricDilatedMaxPooling:updateOutput(input)
+ local dims = input:dim()
+ self.itime = input:size(dims-2)
+ self.iheight = input:size(dims-1)
+ self.iwidth = input:size(dims)
+
+ self.indices = self.indices or input.new()
+ input.THNN.VolumetricDilatedMaxPooling_updateOutput(
+ input:cdata(),
+ self.output:cdata(),
+ self.indices:cdata(),
+ self.kT, self.kW, self.kH,
+ self.dT, self.dW, self.dH,
+ self.padT, self.padW, self.padH,
+ self.dilationT, self.dilationW, self.dilationH,
+ self.ceil_mode
+ )
+ return self.output
+end
+
+function VolumetricDilatedMaxPooling:updateGradInput(input, gradOutput)
+ input.THNN.VolumetricDilatedMaxPooling_updateGradInput(
+ input:cdata(),
+ gradOutput:cdata(),
+ self.gradInput:cdata(),
+ self.indices:cdata(),
+ self.dT, self.dW, self.dH,
+ self.padT, self.padW, self.padH,
+ self.dilationT, self.dilationW, self.dilationH
+ )
+ return self.gradInput
+end
+
+function VolumetricDilatedMaxPooling:clearState()
+ if self.indices then
+ self.indices:set()
+ end
+ return parent.clearState(self)
+end
+
+function VolumetricDilatedMaxPooling:__tostring__()
+ local s = string.format('%s(%dx%dx%d, %d,%d,%d', torch.type(self),
+ self.kT, self.kW, self.kH, self.dT, self.dW, self.dH)
+ if (self.padT or self.padW or self.padH) and
+ (self.padT ~= 0 or self.padW ~= 0 or self.padH ~= 0) then
+ s = s .. ', ' .. self.padT.. ',' .. self.padW .. ','.. self.padH
+ end
+ s = s .. ', ' .. self.dilationT .. ',' .. self.dilationW .. ',' .. self.dilationH
+ s = s .. ')'
+
+ return s
+end