diff options
author | Clement Farabet <clement.farabet@gmail.com> | 2012-10-22 23:19:16 +0400 |
---|---|---|
committer | Clement Farabet <clement.farabet@gmail.com> | 2012-10-22 23:19:16 +0400 |
commit | b8addac4e7fe7171e42886df7984cff3fca35ef8 (patch) | |
tree | ef825c3202478bd6c00389efd632bacf77c56500 | |
parent | 1703ce686332ecf4f98f28e94ef1820c5c2a5e63 (diff) |
batch verstion of padding module.
-rw-r--r-- | SpatialZeroPadding.lua | 116 |
1 files changed, 80 insertions, 36 deletions
diff --git a/SpatialZeroPadding.lua b/SpatialZeroPadding.lua index af03e71..8e3756d 100644 --- a/SpatialZeroPadding.lua +++ b/SpatialZeroPadding.lua @@ -9,45 +9,89 @@ function SpatialZeroPadding:__init(pad_l, pad_r, pad_t, pad_b) end function SpatialZeroPadding:updateOutput(input) - if input:dim() ~= 3 then error('input must be 3-dimensional') end - local h = input:size(2) + self.pad_t + self.pad_b - local w = input:size(3) + self.pad_l + self.pad_r - if w < 1 or h < 1 then error('input is too small') end - self.output:resize(input:size(1), h, w) - self.output:zero() - -- crop input if necessary - local c_input = input - if self.pad_t < 0 then c_input = c_input:narrow(2, 1 - self.pad_t, c_input:size(2) + self.pad_t) end - if self.pad_b < 0 then c_input = c_input:narrow(2, 1, c_input:size(2) + self.pad_b) end - if self.pad_l < 0 then c_input = c_input:narrow(3, 1 - self.pad_l, c_input:size(3) + self.pad_l) end - if self.pad_r < 0 then c_input = c_input:narrow(3, 1, c_input:size(3) + self.pad_r) end - -- crop outout if necessary - local c_output = self.output - if self.pad_t > 0 then c_output = c_output:narrow(2, 1 + self.pad_t, c_output:size(2) - self.pad_t) end - if self.pad_b > 0 then c_output = c_output:narrow(2, 1, c_output:size(2) - self.pad_b) end - if self.pad_l > 0 then c_output = c_output:narrow(3, 1 + self.pad_l, c_output:size(3) - self.pad_l) end - if self.pad_r > 0 then c_output = c_output:narrow(3, 1, c_output:size(3) - self.pad_r) end - -- copy input to output - c_output:copy(c_input) + if input:dim() == 3 then + -- sizes + local h = input:size(2) + self.pad_t + self.pad_b + local w = input:size(3) + self.pad_l + self.pad_r + if w < 1 or h < 1 then error('input is too small') end + self.output:resize(input:size(1), h, w) + self.output:zero() + -- crop input if necessary + local c_input = input + if self.pad_t < 0 then c_input = c_input:narrow(2, 1 - self.pad_t, c_input:size(2) + self.pad_t) end + if self.pad_b < 0 then c_input = c_input:narrow(2, 1, c_input:size(2) + self.pad_b) end + if self.pad_l < 0 then c_input = c_input:narrow(3, 1 - self.pad_l, c_input:size(3) + self.pad_l) end + if self.pad_r < 0 then c_input = c_input:narrow(3, 1, c_input:size(3) + self.pad_r) end + -- crop outout if necessary + local c_output = self.output + if self.pad_t > 0 then c_output = c_output:narrow(2, 1 + self.pad_t, c_output:size(2) - self.pad_t) end + if self.pad_b > 0 then c_output = c_output:narrow(2, 1, c_output:size(2) - self.pad_b) end + if self.pad_l > 0 then c_output = c_output:narrow(3, 1 + self.pad_l, c_output:size(3) - self.pad_l) end + if self.pad_r > 0 then c_output = c_output:narrow(3, 1, c_output:size(3) - self.pad_r) end + -- copy input to output + c_output:copy(c_input) + elseif input:dim() == 4 then + -- sizes + local h = input:size(3) + self.pad_t + self.pad_b + local w = input:size(4) + self.pad_l + self.pad_r + if w < 1 or h < 1 then error('input is too small') end + self.output:resize(input:size(1), input:size(2), h, w) + self.output:zero() + -- crop input if necessary + local c_input = input + if self.pad_t < 0 then c_input = c_input:narrow(3, 1 - self.pad_t, c_input:size(3) + self.pad_t) end + if self.pad_b < 0 then c_input = c_input:narrow(3, 1, c_input:size(3) + self.pad_b) end + if self.pad_l < 0 then c_input = c_input:narrow(4, 1 - self.pad_l, c_input:size(4) + self.pad_l) end + if self.pad_r < 0 then c_input = c_input:narrow(4, 1, c_input:size(4) + self.pad_r) end + -- crop outout if necessary + local c_output = self.output + if self.pad_t > 0 then c_output = c_output:narrow(3, 1 + self.pad_t, c_output:size(3) - self.pad_t) end + if self.pad_b > 0 then c_output = c_output:narrow(3, 1, c_output:size(3) - self.pad_b) end + if self.pad_l > 0 then c_output = c_output:narrow(4, 1 + self.pad_l, c_output:size(4) - self.pad_l) end + if self.pad_r > 0 then c_output = c_output:narrow(4, 1, c_output:size(4) - self.pad_r) end + -- copy input to output + c_output:copy(c_input) + else + error('input must be 3 or 4-dimensional') + end return self.output end function SpatialZeroPadding:updateGradInput(input, gradOutput) - if input:dim() ~= 3 then error('input must be 3-dimensional') end - self.gradInput:resizeAs(input):zero() - -- crop gradInput if necessary - local cg_input = self.gradInput - if self.pad_t < 0 then cg_input = cg_input:narrow(2, 1 - self.pad_t, cg_input:size(2) + self.pad_t) end - if self.pad_b < 0 then cg_input = cg_input:narrow(2, 1, cg_input:size(2) + self.pad_b) end - if self.pad_l < 0 then cg_input = cg_input:narrow(3, 1 - self.pad_l, cg_input:size(3) + self.pad_l) end - if self.pad_r < 0 then cg_input = cg_input:narrow(3, 1, cg_input:size(3) + self.pad_r) end - -- crop gradOutout if necessary - local cg_output = gradOutput - if self.pad_t > 0 then cg_output = cg_output:narrow(2, 1 + self.pad_t, cg_output:size(2) - self.pad_t) end - if self.pad_b > 0 then cg_output = cg_output:narrow(2, 1, cg_output:size(2) - self.pad_b) end - if self.pad_l > 0 then cg_output = cg_output:narrow(3, 1 + self.pad_l, cg_output:size(3) - self.pad_l) end - if self.pad_r > 0 then cg_output = cg_output:narrow(3, 1, cg_output:size(3) - self.pad_r) end - -- copy gradOuput to gradInput - cg_input:copy(cg_output) + if input:dim() == 3 then + self.gradInput:resizeAs(input):zero() + -- crop gradInput if necessary + local cg_input = self.gradInput + if self.pad_t < 0 then cg_input = cg_input:narrow(2, 1 - self.pad_t, cg_input:size(2) + self.pad_t) end + if self.pad_b < 0 then cg_input = cg_input:narrow(2, 1, cg_input:size(2) + self.pad_b) end + if self.pad_l < 0 then cg_input = cg_input:narrow(3, 1 - self.pad_l, cg_input:size(3) + self.pad_l) end + if self.pad_r < 0 then cg_input = cg_input:narrow(3, 1, cg_input:size(3) + self.pad_r) end + -- crop gradOutout if necessary + local cg_output = gradOutput + if self.pad_t > 0 then cg_output = cg_output:narrow(2, 1 + self.pad_t, cg_output:size(2) - self.pad_t) end + if self.pad_b > 0 then cg_output = cg_output:narrow(2, 1, cg_output:size(2) - self.pad_b) end + if self.pad_l > 0 then cg_output = cg_output:narrow(3, 1 + self.pad_l, cg_output:size(3) - self.pad_l) end + if self.pad_r > 0 then cg_output = cg_output:narrow(3, 1, cg_output:size(3) - self.pad_r) end + -- copy gradOuput to gradInput + cg_input:copy(cg_output) + elseif input:dim() == 4 then + self.gradInput:resizeAs(input):zero() + -- crop gradInput if necessary + local cg_input = self.gradInput + if self.pad_t < 0 then cg_input = cg_input:narrow(3, 1 - self.pad_t, cg_input:size(3) + self.pad_t) end + if self.pad_b < 0 then cg_input = cg_input:narrow(3, 1, cg_input:size(3) + self.pad_b) end + if self.pad_l < 0 then cg_input = cg_input:narrow(4, 1 - self.pad_l, cg_input:size(4) + self.pad_l) end + if self.pad_r < 0 then cg_input = cg_input:narrow(4, 1, cg_input:size(4) + self.pad_r) end + -- crop gradOutout if necessary + local cg_output = gradOutput + if self.pad_t > 0 then cg_output = cg_output:narrow(3, 1 + self.pad_t, cg_output:size(3) - self.pad_t) end + if self.pad_b > 0 then cg_output = cg_output:narrow(3, 1, cg_output:size(3) - self.pad_b) end + if self.pad_l > 0 then cg_output = cg_output:narrow(4, 1 + self.pad_l, cg_output:size(4) - self.pad_l) end + if self.pad_r > 0 then cg_output = cg_output:narrow(4, 1, cg_output:size(4) - self.pad_r) end + -- copy gradOuput to gradInput + cg_input:copy(cg_output) + else + error('input must be 3 or 4-dimensional') + end return self.gradInput end |