github.com/clementfarabet/lua---nnx.git
author    Marco Scoffier <github@metm.org>  2011-09-26 19:51:11 +0400
committer Marco Scoffier <github@metm.org>  2011-09-26 19:51:11 +0400
commit    8e6cd11f8efdcba57955031f6e76ad57b9bf5e6f (patch)
tree      de9d26a672d496e85ea1d7187c668b7cd244e65d
parent    ec68f17afbe3fb37b1aa025175ab138830ed19d8 (diff)
parent    100e8c1cd0bdcefb471d64dff38830db54eab23b (diff)

    Merge branch 'master' of github.com:clementfarabet/lua---nnx

 SpatialFovea.lua | 21 ++++++++++++++++++++-
 Type.lua         | 49 +++++++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 69 insertions(+), 1 deletion(-)
diff --git a/SpatialFovea.lua b/SpatialFovea.lua
index b9e3fe4..c4ebcb8 100644
--- a/SpatialFovea.lua
+++ b/SpatialFovea.lua
@@ -116,6 +116,11 @@ function SpatialFovea:configure(width,height)
       else
          self.upsamplers[idx] = nn.SpatialUpSampling(r, r)
       end
+
+      -- set correct types
+      self.downsamplers[idx]:type(self.output:type())
+      self.padders[idx]:type(self.output:type())
+      self.upsamplers[idx]:type(self.output:type())
    end
 end
 
@@ -265,7 +270,7 @@ function SpatialFovea:backward(input, gradOutput)
    -- (4) is fovea focused ?
    if self.focused then
       for idx = 1,nscales do
-         self.gradPadded[idx] = self.gradPadded[idx] or torch.Tensor()
+         self.gradPadded[idx] = self.gradPadded[idx] or torch.Tensor():typeAs(self.output)
          self.gradPadded[idx]:resizeAs(self.padded[idx]):zero()
          local fov = self.fov
          local ox = math.floor(math.floor((self.x-1) / self.ratios[idx]) / self.sub) * self.sub + 1
@@ -327,6 +332,20 @@ function SpatialFovea:updateParameters(learningRate)
    end
 end
 
+function SpatialFovea:type(type)
+   parent.type(self,type)
+   for idx = 1,#self.processors do
+      self.processors[idx]:type(type)
+      self.upsamplers[idx]:type(type)
+      self.downsamplers[idx]:type(type)
+      self.padders[idx]:type(type)
+   end
+   for idx = 1,#self.preProcessors do
+      self.preProcessors[idx]:type(type)
+   end
+   return self
+end
+
 function SpatialFovea:write(file)
    parent.write(self, file)
    file:writeInt(self.nInputPlane)
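
The new SpatialFovea:type override above propagates a tensor-type change to the processors, pre-processors, and the internal downsamplers, padders, and upsamplers in a single call. A minimal usage sketch, assuming a fovea built with illustrative constructor arguments (not taken from this commit):

    require 'nnx'

    -- hypothetical two-scale fovea, for illustration only
    local fovea = nn.SpatialFovea{nInputPlane = 3,
                                  ratios      = {1, 2},
                                  processors  = {nn.SpatialConvolution(3, 8, 5, 5),
                                                 nn.SpatialConvolution(3, 8, 5, 5)}}

    -- with this patch, one call converts every sub-module consistently:
    fovea:type('torch.FloatTensor')
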
diff --git a/Type.lua b/Type.lua
new file mode 100644
index 0000000..c65f735
--- /dev/null
+++ b/Type.lua
@@ -0,0 +1,49 @@
+local Type, parent = torch.class('nn.Type', 'nn.Sequential')
+
+function Type:__init(type)
+   parent.__init(self)
+   if not type:find('torch%..+Tensor') then
+      type = 'torch.' .. type .. 'Tensor'
+   end
+   self.type = type
+   self.defaulttype = torch.getdefaulttensortype()
+   self.convert_input = nn.Copy(self.defaulttype, self.type)
+   self.convert_gradOutput = nn.Copy(self.defaulttype, self.type)
+   self.convert_output = nn.Copy(self.type, self.defaulttype)
+   self.convert_gradInput = nn.Copy(self.type, self.defaulttype)
+end
+
+function Type:add(module)
+   parent.add(self, module:type(self.type))
+   return self
+end
+
+function Type:forward(input)
+   input = self.convert_input:forward(input)
+   local output = parent.forward(self, input)
+   self.output = self.convert_output:forward(output)
+   return self.output
+end
+
+function Type:backward(input, gradOutput)
+   input = self.convert_input:forward(input)
+   gradOutput = self.convert_gradOutput:forward(gradOutput)
+   local gradInput = parent.backward(self, input, gradOutput)
+   self.gradInput = self.convert_gradInput:forward(gradInput)
+   return self.gradInput
+end
+
+local Float, parent = torch.class('nn.Float', 'nn.Type')
+function Float:__init()
+   parent.__init(self, 'torch.FloatTensor')
+end
+
+local Double, parent = torch.class('nn.Double', 'nn.Type')
+function Double:__init()
+   parent.__init(self, 'torch.DoubleTensor')
+end
+
+local Cuda, parent = torch.class('nn.Cuda', 'nn.Type')
+function Cuda:__init()
+   parent.__init(self, 'torch.CudaTensor')
+end