diff options
author | Clement Farabet <clement.farabet@gmail.com> | 2011-09-14 00:12:57 +0400 |
---|---|---|
committer | Clement Farabet <clement.farabet@gmail.com> | 2011-09-14 00:12:57 +0400 |
commit | d6da0e6b62f24d0581255751d3ee22e6f3765035 (patch) | |
tree | 8decd2628965dabc2a7d3b807d37b8575b105338 | |
parent | 367f3c4751a498f67b5c90d776e5198be270d9f2 (diff) |
Added a flag to optionally bypass the creation of a spatial output.
-rw-r--r-- | SpatialClassifier.lua | 63 |
1 files changed, 44 insertions, 19 deletions
diff --git a/SpatialClassifier.lua b/SpatialClassifier.lua index 818c084..9d39f80 100644 --- a/SpatialClassifier.lua +++ b/SpatialClassifier.lua @@ -1,11 +1,11 @@ -local Classifier, parent = torch.class('nn.SpatialClassifier', 'nn.Sequential') +local Classifier, parent = torch.class('nn.SpatialClassifier', 'nn.Module') function Classifier:__init(classifier) parent.__init(self) -- public: self.classifier = classifier or nn.Sequential() + self.spatialOutput = true -- private: - self.modules = self.classifier self.inputF = torch.Tensor() self.inputT = torch.Tensor() self.outputF = torch.Tensor() @@ -14,6 +14,8 @@ function Classifier:__init(classifier) self.gradOutputT = torch.Tensor() self.gradInputF = torch.Tensor() self.gradInput = torch.Tensor() + -- compat: + self.modules = {self.classifier} end function Classifier:add(module) @@ -29,19 +31,23 @@ function Classifier:forward(input) local H = input:size(2) local W = input:size(3) local HW = H*W + -- transpose input: self.inputF:set(input):resize(K, HW) self.inputT:resize(HW, K):copy(self.inputF:t()) + -- classify all locations: self.outputT = self.classifier:forward(self.inputT) - -- force batch - if self.outputT:nDimension() == 1 then - self.outputT:resize(1,self.outputT:size(1)) + + if self.spatialOutput then + -- transpose output: + local N = self.outputT:size(2) + self.outputF:resize(N, HW):copy(self.outputT:t()) + self.output:set(self.outputF):resize(N,H,W) + else + -- leave output flat: + self.output = self.outputT end - -- transpose output: - local N = self.outputT:size(2) - self.outputF:resize(N, HW):copy(self.outputT:t()) - self.output:set(self.outputF):resize(N,H,W) return self.output end @@ -52,18 +58,22 @@ function Classifier:backward(input, gradOutput) local W = input:size(3) local HW = H*W local N = gradOutput:size(1) + -- transpose input self.inputF:set(input):resize(K, HW) self.inputT:resize(HW, K):copy(self.inputF:t()) - -- transpose gradOutput - self.gradOutputF:set(gradOutput):resize(N, HW) - self.gradOutputT:resize(HW, N):copy(self.gradOutputF:t()) + + if self.spatialOutput then + -- transpose gradOutput + self.gradOutputF:set(gradOutput):resize(N, HW) + self.gradOutputT:resize(HW, N):copy(self.gradOutputF:t()) + else + self.gradOutputT = gradOutput + end + -- backward through classifier: self.gradInputT = self.classifier:backward(self.inputT, self.gradOutputT) - -- force batch - if self.gradInputT:nDimension() == 1 then - self.gradInputT:resize(1,self.outputT:size(1)) - end + -- transpose gradInput self.gradInputF:resize(K, HW):copy(self.gradInputT:t()) self.gradInput:set(self.gradInputF):resize(K,H,W) @@ -77,12 +87,27 @@ function Classifier:accGradParameters(input, gradOutput, scale) local W = input:size(3) local HW = H*W local N = gradOutput:size(1) + -- transpose input self.inputF:set(input):resize(K, HW) self.inputT:resize(HW, K):copy(self.inputF:t()) - -- transpose gradOutput - self.gradOutputF:set(gradOutput):resize(N, HW) - self.gradOutputT:resize(HW, N):copy(self.gradOutputF:t()) + + if self.spatialOutput then + -- transpose gradOutput + self.gradOutputF:set(gradOutput):resize(N, HW) + self.gradOutputT:resize(HW, N):copy(self.gradOutputF:t()) + else + self.gradOutputT = gradOutput + end + -- backward through classifier: self.classifier:accGradParameters(self.inputT, self.gradOutputT, scale) end + +function Classifier:zeroGradParameters() + self.classifier:zeroGradParameters() +end + +function Classifier:updateParameters(learningRate) + self.classifier:updateParameters(learningRate) +end |