Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/clementfarabet/lua---nnx.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorClement Farabet <clement.farabet@gmail.com>2011-07-08 05:34:17 +0400
committerClement Farabet <clement.farabet@gmail.com>2011-07-08 05:34:17 +0400
commit60a6a1342e7ee26fc885fd50ff224e023ec3c46d (patch)
tree043420c4edefd4187521d83f2ee4924b87283b21 /SpatialPadding.lua
parent0d3f8772d0d5872979af0685f40548ef919f684c (diff)
Added SpatialNormalization module. Still buggy with non-square kernels.
Diffstat (limited to 'SpatialPadding.lua')
-rw-r--r--SpatialPadding.lua12
1 files changed, 3 insertions, 9 deletions
diff --git a/SpatialPadding.lua b/SpatialPadding.lua
index 6d7bd11..9c87c2b 100644
--- a/SpatialPadding.lua
+++ b/SpatialPadding.lua
@@ -19,17 +19,11 @@ function SpatialPadding:__init(pad_l, pad_r, pad_t, pad_b)
self.pad_b = pad_b or self.pad_l
end
-function SpatialPadding:setPadding(pad_l, pad_r, pad_t, pad_b)
- self.pad_l = pad_l or 0
- self.pad_r = pad_r or self.pad_l
- self.pad_t = pad_t or self.pad_l
- self.pad_b = pad_b or self.pad_l
-end
-
function SpatialPadding:forward(input)
+ if input:dim() ~= 3 then error('input must be 3-dimensional') end
local h = input:size(2) + self.pad_t + self.pad_b
local w = input:size(3) + self.pad_l + self.pad_r
- if w < 1 or h < 1 then error("Input too small") end
+ if w < 1 or h < 1 then error('input is too small') end
self.output:resize(input:size(1), h, w)
self.output:zero()
-- crop input if necessary
@@ -50,6 +44,7 @@ function SpatialPadding:forward(input)
end
function SpatialPadding:backward(input, gradOutput)
+ if input:dim() ~= 3 then error('input must be 3-dimensional') end
self.gradInput:resizeAs(input):zero()
-- crop gradInput if necessary
local cg_input = self.gradInput
@@ -68,7 +63,6 @@ function SpatialPadding:backward(input, gradOutput)
return self.gradInput
end
-
function SpatialPadding:write(file)
parent.write(self, file)
file:writeInt(self.pad_l)