diff options
author | Marco Scoffier <github@metm.org> | 2011-09-26 19:51:11 +0400 |
---|---|---|
committer | Marco Scoffier <github@metm.org> | 2011-09-26 19:51:11 +0400 |
commit | 8e6cd11f8efdcba57955031f6e76ad57b9bf5e6f (patch) | |
tree | de9d26a672d496e85ea1d7187c668b7cd244e65d | |
parent | ec68f17afbe3fb37b1aa025175ab138830ed19d8 (diff) | |
parent | 100e8c1cd0bdcefb471d64dff38830db54eab23b (diff) |
Merge branch 'master' of github.com:clementfarabet/lua---nnx
-rw-r--r-- | SpatialFovea.lua | 21 | ||||
-rw-r--r-- | Type.lua | 49 |
2 files changed, 69 insertions, 1 deletions
diff --git a/SpatialFovea.lua b/SpatialFovea.lua index b9e3fe4..c4ebcb8 100644 --- a/SpatialFovea.lua +++ b/SpatialFovea.lua @@ -116,6 +116,11 @@ function SpatialFovea:configure(width,height) else self.upsamplers[idx] = nn.SpatialUpSampling(r, r) end + + -- set correct types + self.downsamplers[idx]:type(self.output:type()) + self.padders[idx]:type(self.output:type()) + self.upsamplers[idx]:type(self.output:type()) end end @@ -265,7 +270,7 @@ function SpatialFovea:backward(input, gradOutput) -- (4) is fovea focused ? if self.focused then for idx = 1,nscales do - self.gradPadded[idx] = self.gradPadded[idx] or torch.Tensor() + self.gradPadded[idx] = self.gradPadded[idx] or torch.Tensor():typeAs(self.output) self.gradPadded[idx]:resizeAs(self.padded[idx]):zero() local fov = self.fov local ox = math.floor(math.floor((self.x-1) / self.ratios[idx]) / self.sub) * self.sub + 1 @@ -327,6 +332,20 @@ function SpatialFovea:updateParameters(learningRate) end end +function SpatialFovea:type(type) + parent.type(self,type) + for idx = 1,#self.processors do + self.processors[idx]:type(type) + self.upsamplers[idx]:type(type) + self.downsamplers[idx]:type(type) + self.padders[idx]:type(type) + end + for idx = 1,#self.preProcessors do + self.preProcessors[idx]:type(type) + end + return self +end + function SpatialFovea:write(file) parent.write(self, file) file:writeInt(self.nInputPlane) diff --git a/Type.lua b/Type.lua new file mode 100644 index 0000000..c65f735 --- /dev/null +++ b/Type.lua @@ -0,0 +1,49 @@ +local Type, parent = torch.class('nn.Type', 'nn.Sequential') + +function Type:__init(type) + parent.__init(self) + if not type:find('torch%..+Tensor') then + type = 'torch.' .. type .. 'Tensor' + end + self.type = type + self.defaulttype = torch.getdefaulttensortype() + self.convert_input = nn.Copy(self.defaulttype, self.type) + self.convert_gradOutput = nn.Copy(self.defaulttype, self.type) + self.convert_output = nn.Copy(self.type, self.defaulttype) + self.convert_gradInput = nn.Copy(self.type, self.defaulttype) +end + +function Type:add(module) + parent.add(self, module:type(self.type)) + return self +end + +function Type:forward(input) + input = self.convert_input:forward(input) + local output = parent.forward(self, input) + self.output = self.convert_output:forward(output) + return self.output +end + +function Type:backward(input, gradOutput) + input = self.convert_input:forward(input) + gradOutput = self.convert_gradOutput:forward(gradOutput) + local gradInput = parent.backward(self, input, gradOutput) + self.gradInput = self.convert_gradInput:forward(gradInput) + return self.gradInput +end + +local Float, parent = torch.class('nn.Float', 'nn.Type') +function Float:__init() + parent.__init(self, 'torch.FloatTensor') +end + +local Double, parent = torch.class('nn.Double', 'nn.Type') +function Double:__init() + parent.__init(self, 'torch.DoubleTensor') +end + +local Cuda, parent = torch.class('nn.Cuda', 'nn.Type') +function Cuda:__init() + parent.__init(self, 'torch.CudaTensor') +end |