github.com/clementfarabet/lua---nnx.git
author    Clement Farabet <clement.farabet@gmail.com>  2011-09-24 23:17:37 +0400
committer Clement Farabet <clement.farabet@gmail.com>  2011-09-24 23:17:37 +0400
commit    100e8c1cd0bdcefb471d64dff38830db54eab23b (patch)
tree      5afcb11a6430f67a247f5dc66c13012cb55d008f /SpatialFovea.lua
parent    fef650e3b5b4040231f1527af95e7173f1c06d79 (diff)
Added support for type in SpatialFovea.
Diffstat (limited to 'SpatialFovea.lua')
-rw-r--r--  SpatialFovea.lua  21
1 file changed, 20 insertions(+), 1 deletion(-)
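
With this change, a SpatialFovea can be cast to another tensor type like any other nn module. A minimal usage sketch (the constructor arguments are placeholders, not taken from this diff; the :float() shortcut assumes stock Torch7 nn.Module behavior):

    fovea = nn.SpatialFovea{...}        -- hypothetical construction
    fovea:type('torch.FloatTensor')     -- the override added below
    fovea:float()                       -- equivalent nn.Module shortcut
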
diff --git a/SpatialFovea.lua b/SpatialFovea.lua
index b9e3fe4..c4ebcb8 100644
--- a/SpatialFovea.lua
+++ b/SpatialFovea.lua
@@ -116,6 +116,11 @@ function SpatialFovea:configure(width,height)
       else
          self.upsamplers[idx] = nn.SpatialUpSampling(r, r)
       end
+
+      -- set correct types
+      self.downsamplers[idx]:type(self.output:type())
+      self.padders[idx]:type(self.output:type())
+      self.upsamplers[idx]:type(self.output:type())
    end
 end
 
@@ -265,7 +270,7 @@ function SpatialFovea:backward(input, gradOutput)
    -- (4) is fovea focused ?
    if self.focused then
       for idx = 1,nscales do
-         self.gradPadded[idx] = self.gradPadded[idx] or torch.Tensor()
+         self.gradPadded[idx] = self.gradPadded[idx] or torch.Tensor():typeAs(self.output)
          self.gradPadded[idx]:resizeAs(self.padded[idx]):zero()
          local fov = self.fov
          local ox = math.floor(math.floor((self.x-1) / self.ratios[idx]) / self.sub) * self.sub + 1
@@ -327,6 +332,20 @@ function SpatialFovea:updateParameters(learningRate)
    end
 end
 
+function SpatialFovea:type(type)
+   parent.type(self,type)
+   for idx = 1,#self.processors do
+      self.processors[idx]:type(type)
+      self.upsamplers[idx]:type(type)
+      self.downsamplers[idx]:type(type)
+      self.padders[idx]:type(type)
+   end
+   for idx = 1,#self.preProcessors do
+      self.preProcessors[idx]:type(type)
+   end
+   return self
+end
+
 function SpatialFovea:write(file)
    parent.write(self, file)
    file:writeInt(self.nInputPlane)
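
Why the typeAs change above matters: a bare torch.Tensor() is created with the global default type (torch.DoubleTensor unless torch.setdefaulttensortype says otherwise), so a float-cast fovea would otherwise mix double gradient buffers with float activations during backward. A standalone sketch of the idiom, runnable in a plain Torch7 session:

    require 'torch'

    local x = torch.FloatTensor(4, 4):fill(1)   -- stand-in for self.output
    local g = torch.Tensor():typeAs(x)          -- empty buffer, already float
    g:resizeAs(x):zero()                        -- safe: same type as x
    print(g:type())                             -- torch.FloatTensor

The configure() hunk follows the same logic: downsamplers, padders, and upsamplers built after a cast must inherit self.output's type, and the new SpatialFovea:type() override walks every sub-module so later casts stay consistent.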