diff options
author | soumith <soumith@fb.com> | 2016-02-25 03:59:30 +0300 |
---|---|---|
committer | soumith <soumith@fb.com> | 2016-02-25 03:59:30 +0300 |
commit | 47f58cff33cc86e52b72b02c468a72d810e4216d (patch) | |
tree | a47da6251611d931d9465166755aa925f117b2bd /SpatialReflectionPadding.lua | |
parent | bcdda102686713037538d770ee5c033879450ee3 (diff) |
Adding SpatialReflection and SpatialReplication padding
Diffstat (limited to 'SpatialReflectionPadding.lua')
-rw-r--r-- | SpatialReflectionPadding.lua | 51 |
1 files changed, 51 insertions, 0 deletions
diff --git a/SpatialReflectionPadding.lua b/SpatialReflectionPadding.lua new file mode 100644 index 0000000..f626591 --- /dev/null +++ b/SpatialReflectionPadding.lua @@ -0,0 +1,51 @@ +local SpatialReflectionPadding, parent = + torch.class('nn.SpatialReflectionPadding', 'nn.Module') + +function SpatialReflectionPadding:__init(pad_l, pad_r, pad_t, pad_b) + parent.__init(self) + self.pad_l = pad_l + self.pad_r = pad_r or self.pad_l + self.pad_t = pad_t or self.pad_l + self.pad_b = pad_b or self.pad_l +end + +function SpatialReflectionPadding:updateOutput(input) + if input:dim() == 3 or input:dim() == 4 then + input.THNN.SpatialReflectionPadding_updateOutput( + input:cdata(), self.output:cdata(), + self.pad_l, self.pad_r, self.pad_t, self.pad_b) + else + error('input must be 3 or 4-dimensional') + end + return self.output +end + +function SpatialReflectionPadding:updateGradInput(input, gradOutput) + if input:dim() == 3 and gradOutput:dim() == 3 then + assert(input:size(1) == gradOutput:size(1) + and input:size(2) + self.pad_t + self.pad_b == gradOutput:size(2) + and input:size(3) + self.pad_l + self.pad_r == gradOutput:size(3), + 'input and gradOutput must be compatible in size') + elseif input:dim() == 4 and gradOutput:dim() == 4 then + assert(input:size(1) == gradOutput:size(1) + and input:size(2) == gradOutput:size(2) + and input:size(3) + self.pad_t + self.pad_b == gradOutput:size(3) + and input:size(4) + self.pad_l + self.pad_r == gradOutput:size(4), + 'input and gradOutput must be compatible in size') + else + error( + [[input and gradOutput must be 3 or 4-dimensional + and have equal number of dimensions]] + ) + end + input.THNN.SpatialReflectionPadding_updateGradInput( + input:cdata(), gradOutput:cdata(), self.gradInput:cdata(), + self.pad_l, self.pad_r, self.pad_t, self.pad_b) + return self.gradInput +end + +function SpatialReflectionPadding:__tostring__() + return torch.type(self) .. + string.format('(l=%d,r=%d,t=%d,b=%d)', self.pad_l, self.pad_r, + self.pad_t, self.pad_b) +end |