Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorClement Farabet <clement.farabet@gmail.com>2012-10-22 23:19:16 +0400
committerClement Farabet <clement.farabet@gmail.com>2012-10-22 23:19:16 +0400
commitb8addac4e7fe7171e42886df7984cff3fca35ef8 (patch)
treeef825c3202478bd6c00389efd632bacf77c56500
parent1703ce686332ecf4f98f28e94ef1820c5c2a5e63 (diff)
batch verstion of padding module.
-rw-r--r--SpatialZeroPadding.lua116
1 files changed, 80 insertions, 36 deletions
diff --git a/SpatialZeroPadding.lua b/SpatialZeroPadding.lua
index af03e71..8e3756d 100644
--- a/SpatialZeroPadding.lua
+++ b/SpatialZeroPadding.lua
@@ -9,45 +9,89 @@ function SpatialZeroPadding:__init(pad_l, pad_r, pad_t, pad_b)
end
function SpatialZeroPadding:updateOutput(input)
- if input:dim() ~= 3 then error('input must be 3-dimensional') end
- local h = input:size(2) + self.pad_t + self.pad_b
- local w = input:size(3) + self.pad_l + self.pad_r
- if w < 1 or h < 1 then error('input is too small') end
- self.output:resize(input:size(1), h, w)
- self.output:zero()
- -- crop input if necessary
- local c_input = input
- if self.pad_t < 0 then c_input = c_input:narrow(2, 1 - self.pad_t, c_input:size(2) + self.pad_t) end
- if self.pad_b < 0 then c_input = c_input:narrow(2, 1, c_input:size(2) + self.pad_b) end
- if self.pad_l < 0 then c_input = c_input:narrow(3, 1 - self.pad_l, c_input:size(3) + self.pad_l) end
- if self.pad_r < 0 then c_input = c_input:narrow(3, 1, c_input:size(3) + self.pad_r) end
- -- crop outout if necessary
- local c_output = self.output
- if self.pad_t > 0 then c_output = c_output:narrow(2, 1 + self.pad_t, c_output:size(2) - self.pad_t) end
- if self.pad_b > 0 then c_output = c_output:narrow(2, 1, c_output:size(2) - self.pad_b) end
- if self.pad_l > 0 then c_output = c_output:narrow(3, 1 + self.pad_l, c_output:size(3) - self.pad_l) end
- if self.pad_r > 0 then c_output = c_output:narrow(3, 1, c_output:size(3) - self.pad_r) end
- -- copy input to output
- c_output:copy(c_input)
+ if input:dim() == 3 then
+ -- sizes
+ local h = input:size(2) + self.pad_t + self.pad_b
+ local w = input:size(3) + self.pad_l + self.pad_r
+ if w < 1 or h < 1 then error('input is too small') end
+ self.output:resize(input:size(1), h, w)
+ self.output:zero()
+ -- crop input if necessary
+ local c_input = input
+ if self.pad_t < 0 then c_input = c_input:narrow(2, 1 - self.pad_t, c_input:size(2) + self.pad_t) end
+ if self.pad_b < 0 then c_input = c_input:narrow(2, 1, c_input:size(2) + self.pad_b) end
+ if self.pad_l < 0 then c_input = c_input:narrow(3, 1 - self.pad_l, c_input:size(3) + self.pad_l) end
+ if self.pad_r < 0 then c_input = c_input:narrow(3, 1, c_input:size(3) + self.pad_r) end
+ -- crop outout if necessary
+ local c_output = self.output
+ if self.pad_t > 0 then c_output = c_output:narrow(2, 1 + self.pad_t, c_output:size(2) - self.pad_t) end
+ if self.pad_b > 0 then c_output = c_output:narrow(2, 1, c_output:size(2) - self.pad_b) end
+ if self.pad_l > 0 then c_output = c_output:narrow(3, 1 + self.pad_l, c_output:size(3) - self.pad_l) end
+ if self.pad_r > 0 then c_output = c_output:narrow(3, 1, c_output:size(3) - self.pad_r) end
+ -- copy input to output
+ c_output:copy(c_input)
+ elseif input:dim() == 4 then
+ -- sizes
+ local h = input:size(3) + self.pad_t + self.pad_b
+ local w = input:size(4) + self.pad_l + self.pad_r
+ if w < 1 or h < 1 then error('input is too small') end
+ self.output:resize(input:size(1), input:size(2), h, w)
+ self.output:zero()
+ -- crop input if necessary
+ local c_input = input
+ if self.pad_t < 0 then c_input = c_input:narrow(3, 1 - self.pad_t, c_input:size(3) + self.pad_t) end
+ if self.pad_b < 0 then c_input = c_input:narrow(3, 1, c_input:size(3) + self.pad_b) end
+ if self.pad_l < 0 then c_input = c_input:narrow(4, 1 - self.pad_l, c_input:size(4) + self.pad_l) end
+ if self.pad_r < 0 then c_input = c_input:narrow(4, 1, c_input:size(4) + self.pad_r) end
+ -- crop outout if necessary
+ local c_output = self.output
+ if self.pad_t > 0 then c_output = c_output:narrow(3, 1 + self.pad_t, c_output:size(3) - self.pad_t) end
+ if self.pad_b > 0 then c_output = c_output:narrow(3, 1, c_output:size(3) - self.pad_b) end
+ if self.pad_l > 0 then c_output = c_output:narrow(4, 1 + self.pad_l, c_output:size(4) - self.pad_l) end
+ if self.pad_r > 0 then c_output = c_output:narrow(4, 1, c_output:size(4) - self.pad_r) end
+ -- copy input to output
+ c_output:copy(c_input)
+ else
+ error('input must be 3 or 4-dimensional')
+ end
return self.output
end
function SpatialZeroPadding:updateGradInput(input, gradOutput)
- if input:dim() ~= 3 then error('input must be 3-dimensional') end
- self.gradInput:resizeAs(input):zero()
- -- crop gradInput if necessary
- local cg_input = self.gradInput
- if self.pad_t < 0 then cg_input = cg_input:narrow(2, 1 - self.pad_t, cg_input:size(2) + self.pad_t) end
- if self.pad_b < 0 then cg_input = cg_input:narrow(2, 1, cg_input:size(2) + self.pad_b) end
- if self.pad_l < 0 then cg_input = cg_input:narrow(3, 1 - self.pad_l, cg_input:size(3) + self.pad_l) end
- if self.pad_r < 0 then cg_input = cg_input:narrow(3, 1, cg_input:size(3) + self.pad_r) end
- -- crop gradOutout if necessary
- local cg_output = gradOutput
- if self.pad_t > 0 then cg_output = cg_output:narrow(2, 1 + self.pad_t, cg_output:size(2) - self.pad_t) end
- if self.pad_b > 0 then cg_output = cg_output:narrow(2, 1, cg_output:size(2) - self.pad_b) end
- if self.pad_l > 0 then cg_output = cg_output:narrow(3, 1 + self.pad_l, cg_output:size(3) - self.pad_l) end
- if self.pad_r > 0 then cg_output = cg_output:narrow(3, 1, cg_output:size(3) - self.pad_r) end
- -- copy gradOuput to gradInput
- cg_input:copy(cg_output)
+ if input:dim() == 3 then
+ self.gradInput:resizeAs(input):zero()
+ -- crop gradInput if necessary
+ local cg_input = self.gradInput
+ if self.pad_t < 0 then cg_input = cg_input:narrow(2, 1 - self.pad_t, cg_input:size(2) + self.pad_t) end
+ if self.pad_b < 0 then cg_input = cg_input:narrow(2, 1, cg_input:size(2) + self.pad_b) end
+ if self.pad_l < 0 then cg_input = cg_input:narrow(3, 1 - self.pad_l, cg_input:size(3) + self.pad_l) end
+ if self.pad_r < 0 then cg_input = cg_input:narrow(3, 1, cg_input:size(3) + self.pad_r) end
+ -- crop gradOutout if necessary
+ local cg_output = gradOutput
+ if self.pad_t > 0 then cg_output = cg_output:narrow(2, 1 + self.pad_t, cg_output:size(2) - self.pad_t) end
+ if self.pad_b > 0 then cg_output = cg_output:narrow(2, 1, cg_output:size(2) - self.pad_b) end
+ if self.pad_l > 0 then cg_output = cg_output:narrow(3, 1 + self.pad_l, cg_output:size(3) - self.pad_l) end
+ if self.pad_r > 0 then cg_output = cg_output:narrow(3, 1, cg_output:size(3) - self.pad_r) end
+ -- copy gradOuput to gradInput
+ cg_input:copy(cg_output)
+ elseif input:dim() == 4 then
+ self.gradInput:resizeAs(input):zero()
+ -- crop gradInput if necessary
+ local cg_input = self.gradInput
+ if self.pad_t < 0 then cg_input = cg_input:narrow(3, 1 - self.pad_t, cg_input:size(3) + self.pad_t) end
+ if self.pad_b < 0 then cg_input = cg_input:narrow(3, 1, cg_input:size(3) + self.pad_b) end
+ if self.pad_l < 0 then cg_input = cg_input:narrow(4, 1 - self.pad_l, cg_input:size(4) + self.pad_l) end
+ if self.pad_r < 0 then cg_input = cg_input:narrow(4, 1, cg_input:size(4) + self.pad_r) end
+ -- crop gradOutout if necessary
+ local cg_output = gradOutput
+ if self.pad_t > 0 then cg_output = cg_output:narrow(3, 1 + self.pad_t, cg_output:size(3) - self.pad_t) end
+ if self.pad_b > 0 then cg_output = cg_output:narrow(3, 1, cg_output:size(3) - self.pad_b) end
+ if self.pad_l > 0 then cg_output = cg_output:narrow(4, 1 + self.pad_l, cg_output:size(4) - self.pad_l) end
+ if self.pad_r > 0 then cg_output = cg_output:narrow(4, 1, cg_output:size(4) - self.pad_r) end
+ -- copy gradOuput to gradInput
+ cg_input:copy(cg_output)
+ else
+ error('input must be 3 or 4-dimensional')
+ end
return self.gradInput
end