diff options
author | Nicholas Leonard <nleonard@twitter.com> | 2017-05-09 22:26:36 +0300 |
---|---|---|
committer | Nicholas Leonard <nleonard@twitter.com> | 2017-05-10 17:38:51 +0300 |
commit | 17bbce696980189e27f0cd12d5f219c8cd8ffbc0 (patch) | |
tree | b02b37bc0daf5e65b43f745d15f654591825000e /Decorator.lua | |
parent | 3752f2426b55bc32cbd0ef112649d47dc674baa8 (diff) |
Decorator modules
Diffstat (limited to 'Decorator.lua')
-rw-r--r-- | Decorator.lua | 47 |
1 files changed, 47 insertions, 0 deletions
diff --git a/Decorator.lua b/Decorator.lua new file mode 100644 index 0000000..05fb4db --- /dev/null +++ b/Decorator.lua @@ -0,0 +1,47 @@ +local Decorator, parent = torch.class("nn.Decorator", "nn.Container") + +function Decorator:__init(module) + parent.__init(self) + -- so that it can be handled like a Container + self.modules[1] = module +end + +function Decorator:updateOutput(input) + self.output = self.modules[1]:updateOutput(input) + return self.output +end + +function Decorator:updateGradInput(input, gradOutput) + self.gradInput = self.modules[1]:updateGradInput(input, gradOutput) + return self.gradInput +end + +function Decorator:accGradParameters(input, gradOutput, scale) + self.modules[1]:accGradParameters(input, gradOutput, scale) +end + +function Decorator:accUpdateGradParameters(input, gradOutput, lr) + self.modules[1]:accUpdateGradParameters(input, gradOutput, lr) +end + +function Decorator:sharedAccUpdateGradParameters(input, gradOutput, lr) + self.modules[1]:sharedAccUpdateGradParameters(input, gradOutput, lr) +end + +function Decorator:__tostring__() + if self.modules[1].__tostring__ then + return torch.type(self) .. ' @ ' .. self.modules[1]:__tostring__() + else + return torch.type(self) .. ' @ ' .. torch.type(self.modules[1]) + end +end + +-- useful for multiple-inheritance +function Decorator.decorate(class) + class.updateOutput = nn.Decorator.updateOutput + class.updateGradInput = nn.Decorator.updateGradInput + class.accGradParameters = nn.Decorator.accGradParameters + class.accUpdateGradParameters = nn.Decorator.accUpdateGradParameters + class.sharedAccUpdateGradParameters = nn.Decorator.sharedAccUpdateGradParameters + class.__tostring__ = nn.Decorator.__tostring__ +end |