diff options
author | kmul00 <coolkoustav@gmail.com> | 2016-08-28 16:27:28 +0300 |
---|---|---|
committer | kmul00 <coolkoustav@gmail.com> | 2016-08-28 17:08:52 +0300 |
commit | ddebc7227dd82bd518954da24bedfde2afb07965 (patch) | |
tree | d11dd01524ffdcd50fdde18ae30a64631f094b8b /VolumetricDilatedMaxPooling.lua | |
parent | c1232cd0daa6ea20289a6c63c6f4c8246c9af3a9 (diff) |
Add Volumetric Dilated Max Pooling
new file: VolumetricDilatedMaxPooling.lua
modified: init.lua
modified: lib/THNN/generic/THNN.h
copied: lib/THNN/generic/VolumetricMaxPooling.c -> lib/THNN/generic/VolumetricDilatedMaxPooling.c
modified: lib/THNN/generic/VolumetricMaxPooling.c
modified: lib/THNN/init.c
modified: test.lua
Update docs
modified: doc/convolution.md
Resolved bug in ceil mode
modified: lib/THNN/generic/VolumetricDilatedMaxPooling.c
Diffstat (limited to 'VolumetricDilatedMaxPooling.lua')
-rw-r--r-- | VolumetricDilatedMaxPooling.lua | 64 |
1 files changed, 64 insertions, 0 deletions
diff --git a/VolumetricDilatedMaxPooling.lua b/VolumetricDilatedMaxPooling.lua new file mode 100644 index 0000000..050e2c9 --- /dev/null +++ b/VolumetricDilatedMaxPooling.lua @@ -0,0 +1,64 @@ +local THNN = require 'nn.THNN' +local VolumetricDilatedMaxPooling, parent = torch.class('nn.VolumetricDilatedMaxPooling', 'nn.VolumetricMaxPooling') + +function VolumetricDilatedMaxPooling:__init(kT, kW, kH, dT, dW, dH, padT, padW, padH, dilationT, dilationW, dilationH) + parent.__init(self, kT, kW, kH, dT, dW, dH, padT, padW, padH) + + self.dilationT = dilationT or 1 + self.dilationW = dilationW or 1 + self.dilationH = dilationH or 1 + +end + +function VolumetricDilatedMaxPooling:updateOutput(input) + local dims = input:dim() + self.itime = input:size(dims-2) + self.iheight = input:size(dims-1) + self.iwidth = input:size(dims) + + self.indices = self.indices or input.new() + input.THNN.VolumetricDilatedMaxPooling_updateOutput( + input:cdata(), + self.output:cdata(), + self.indices:cdata(), + self.kT, self.kW, self.kH, + self.dT, self.dW, self.dH, + self.padT, self.padW, self.padH, + self.dilationT, self.dilationW, self.dilationH, + self.ceil_mode + ) + return self.output +end + +function VolumetricDilatedMaxPooling:updateGradInput(input, gradOutput) + input.THNN.VolumetricDilatedMaxPooling_updateGradInput( + input:cdata(), + gradOutput:cdata(), + self.gradInput:cdata(), + self.indices:cdata(), + self.dT, self.dW, self.dH, + self.padT, self.padW, self.padH, + self.dilationT, self.dilationW, self.dilationH + ) + return self.gradInput +end + +function VolumetricDilatedMaxPooling:clearState() + if self.indices then + self.indices:set() + end + return parent.clearState(self) +end + +function VolumetricDilatedMaxPooling:__tostring__() + local s = string.format('%s(%dx%dx%d, %d,%d,%d', torch.type(self), + self.kT, self.kW, self.kH, self.dT, self.dW, self.dH) + if (self.padT or self.padW or self.padH) and + (self.padT ~= 0 or self.padW ~= 0 or self.padH ~= 0) then + s = s .. ', ' .. self.padT.. ',' .. self.padW .. ','.. self.padH + end + s = s .. ', ' .. self.dilationT .. ',' .. self.dilationW .. ',' .. self.dilationH + s = s .. ')' + + return s +end |