
github.com/clementfarabet/lua---nnx.git
author     Clement Farabet <clement.farabet@gmail.com>  2011-09-14 00:12:57 +0400
committer  Clement Farabet <clement.farabet@gmail.com>  2011-09-14 00:12:57 +0400
commit     d6da0e6b62f24d0581255751d3ee22e6f3765035 (patch)
tree       8decd2628965dabc2a7d3b807d37b8575b105338
parent     367f3c4751a498f67b5c90d776e5198be270d9f2 (diff)
Added a flag to optionally bypass the creation of a spatial output.
-rw-r--r--  SpatialClassifier.lua | 63
1 file changed, 44 insertions(+), 19 deletions(-)
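
The new flag makes the spatial transpose optional: with spatialOutput left at its default of true, forward() still returns an N x H x W map of class scores; set it to false and the wrapper returns the classifier's native flat (H*W) x N output, skipping the extra transpose and copy. A minimal usage sketch, not from the repository (the sizes are hypothetical; nn.Linear and nn.LogSoftMax come from the standard nn package):

require 'nnx'

-- a dense classifier to replicate over every spatial location:
local classifier = nn.Sequential()
classifier:add(nn.Linear(64, 10))        -- 64 input features, 10 classes
classifier:add(nn.LogSoftMax())

local spatial = nn.SpatialClassifier(classifier)
local input = torch.Tensor(64, 16, 16)   -- K x H x W feature map

-- default (spatialOutput = true): output is a 10 x 16 x 16 score map
local maps = spatial:forward(input)

-- bypass the spatial output: result stays flat, (16*16) x 10
spatial.spatialOutput = false
local flat = spatial:forward(input)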
diff --git a/SpatialClassifier.lua b/SpatialClassifier.lua
index 818c084..9d39f80 100644
--- a/SpatialClassifier.lua
+++ b/SpatialClassifier.lua
@@ -1,11 +1,11 @@
-local Classifier, parent = torch.class('nn.SpatialClassifier', 'nn.Sequential')
+local Classifier, parent = torch.class('nn.SpatialClassifier', 'nn.Module')
 
 function Classifier:__init(classifier)
    parent.__init(self)
    -- public:
    self.classifier = classifier or nn.Sequential()
+   self.spatialOutput = true
    -- private:
-   self.modules = self.classifier
    self.inputF = torch.Tensor()
    self.inputT = torch.Tensor()
    self.outputF = torch.Tensor()
@@ -14,6 +14,8 @@ function Classifier:__init(classifier)
    self.gradOutputT = torch.Tensor()
    self.gradInputF = torch.Tensor()
    self.gradInput = torch.Tensor()
+   -- compat:
+   self.modules = {self.classifier}
 end
 
 function Classifier:add(module)
@@ -29,19 +31,23 @@ function Classifier:forward(input)
    local H = input:size(2)
    local W = input:size(3)
    local HW = H*W
+
    -- transpose input:
    self.inputF:set(input):resize(K, HW)
    self.inputT:resize(HW, K):copy(self.inputF:t())
+
    -- classify all locations:
    self.outputT = self.classifier:forward(self.inputT)
-   -- force batch
-   if self.outputT:nDimension() == 1 then
-      self.outputT:resize(1,self.outputT:size(1))
+
+   if self.spatialOutput then
+      -- transpose output:
+      local N = self.outputT:size(2)
+      self.outputF:resize(N, HW):copy(self.outputT:t())
+      self.output:set(self.outputF):resize(N,H,W)
+   else
+      -- leave output flat:
+      self.output = self.outputT
    end
-   -- transpose output:
-   local N = self.outputT:size(2)
-   self.outputF:resize(N, HW):copy(self.outputT:t())
-   self.output:set(self.outputF):resize(N,H,W)
    return self.output
 end
 
@@ -52,18 +58,22 @@ function Classifier:backward(input, gradOutput)
    local W = input:size(3)
    local HW = H*W
    local N = gradOutput:size(1)
+
    -- transpose input
    self.inputF:set(input):resize(K, HW)
    self.inputT:resize(HW, K):copy(self.inputF:t())
-   -- transpose gradOutput
-   self.gradOutputF:set(gradOutput):resize(N, HW)
-   self.gradOutputT:resize(HW, N):copy(self.gradOutputF:t())
+
+   if self.spatialOutput then
+      -- transpose gradOutput
+      self.gradOutputF:set(gradOutput):resize(N, HW)
+      self.gradOutputT:resize(HW, N):copy(self.gradOutputF:t())
+   else
+      self.gradOutputT = gradOutput
+   end
+
    -- backward through classifier:
    self.gradInputT = self.classifier:backward(self.inputT, self.gradOutputT)
-   -- force batch
-   if self.gradInputT:nDimension() == 1 then
-      self.gradInputT:resize(1,self.outputT:size(1))
-   end
+
    -- transpose gradInput
    self.gradInputF:resize(K, HW):copy(self.gradInputT:t())
    self.gradInput:set(self.gradInputF):resize(K,H,W)
@@ -77,12 +87,27 @@ function Classifier:accGradParameters(input, gradOutput, scale)
    local W = input:size(3)
    local HW = H*W
    local N = gradOutput:size(1)
+
    -- transpose input
    self.inputF:set(input):resize(K, HW)
    self.inputT:resize(HW, K):copy(self.inputF:t())
-   -- transpose gradOutput
-   self.gradOutputF:set(gradOutput):resize(N, HW)
-   self.gradOutputT:resize(HW, N):copy(self.gradOutputF:t())
+
+   if self.spatialOutput then
+      -- transpose gradOutput
+      self.gradOutputF:set(gradOutput):resize(N, HW)
+      self.gradOutputT:resize(HW, N):copy(self.gradOutputF:t())
+   else
+      self.gradOutputT = gradOutput
+   end
+
    -- backward through classifier:
    self.classifier:accGradParameters(self.inputT, self.gradOutputT, scale)
 end
+
+function Classifier:zeroGradParameters()
+   self.classifier:zeroGradParameters()
+end
+
+function Classifier:updateParameters(learningRate)
+   self.classifier:updateParameters(learningRate)
+end
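
The zeroGradParameters and updateParameters methods appended at the bottom simply delegate to the wrapped classifier, which keeps the wrapper trainable now that it inherits from nn.Module instead of nn.Sequential (the wrapper itself holds no weight or bias tensors). A sketch of one training step in flat mode, continuing the hypothetical example above (assumes ClassNLLCriterion's batch mode, i.e. (H*W) x N log-probabilities against H*W targets; the labels and learning rate are made up):

local criterion = nn.ClassNLLCriterion()
local targets = torch.Tensor(16*16):fill(1)    -- one class label per location

spatial.spatialOutput = false
local output = spatial:forward(input)          -- (16*16) x 10 log-probs
local loss = criterion:forward(output, targets)
local gradOutput = criterion:backward(output, targets)

spatial:zeroGradParameters()
spatial:backward(input, gradOutput)            -- also accumulates weight gradients
spatial:updateParameters(0.01)                 -- plain SGD step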