Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNicholas Leonard <nleonard@twitter.com>2017-05-09 22:26:36 +0300
committerNicholas Leonard <nleonard@twitter.com>2017-05-10 17:38:51 +0300
commit17bbce696980189e27f0cd12d5f219c8cd8ffbc0 (patch)
treeb02b37bc0daf5e65b43f745d15f654591825000e /Decorator.lua
parent3752f2426b55bc32cbd0ef112649d47dc674baa8 (diff)
Decorator modules
Diffstat (limited to 'Decorator.lua')
-rw-r--r--Decorator.lua47
1 files changed, 47 insertions, 0 deletions
diff --git a/Decorator.lua b/Decorator.lua
new file mode 100644
index 0000000..05fb4db
--- /dev/null
+++ b/Decorator.lua
@@ -0,0 +1,47 @@
+local Decorator, parent = torch.class("nn.Decorator", "nn.Container")
+
+function Decorator:__init(module)
+ parent.__init(self)
+ -- so that it can be handled like a Container
+ self.modules[1] = module
+end
+
+function Decorator:updateOutput(input)
+ self.output = self.modules[1]:updateOutput(input)
+ return self.output
+end
+
+function Decorator:updateGradInput(input, gradOutput)
+ self.gradInput = self.modules[1]:updateGradInput(input, gradOutput)
+ return self.gradInput
+end
+
+function Decorator:accGradParameters(input, gradOutput, scale)
+ self.modules[1]:accGradParameters(input, gradOutput, scale)
+end
+
+function Decorator:accUpdateGradParameters(input, gradOutput, lr)
+ self.modules[1]:accUpdateGradParameters(input, gradOutput, lr)
+end
+
+function Decorator:sharedAccUpdateGradParameters(input, gradOutput, lr)
+ self.modules[1]:sharedAccUpdateGradParameters(input, gradOutput, lr)
+end
+
+function Decorator:__tostring__()
+ if self.modules[1].__tostring__ then
+ return torch.type(self) .. ' @ ' .. self.modules[1]:__tostring__()
+ else
+ return torch.type(self) .. ' @ ' .. torch.type(self.modules[1])
+ end
+end
+
+-- useful for multiple-inheritance
+function Decorator.decorate(class)
+ class.updateOutput = nn.Decorator.updateOutput
+ class.updateGradInput = nn.Decorator.updateGradInput
+ class.accGradParameters = nn.Decorator.accGradParameters
+ class.accUpdateGradParameters = nn.Decorator.accUpdateGradParameters
+ class.sharedAccUpdateGradParameters = nn.Decorator.sharedAccUpdateGradParameters
+ class.__tostring__ = nn.Decorator.__tostring__
+end