diff options
author | Luke <insperatum@gmail.com> | 2015-11-20 05:00:03 +0300 |
---|---|---|
committer | Soumith Chintala <soumith@gmail.com> | 2015-11-26 11:26:17 +0300 |
commit | b7a64bf1f8062b394ee28679f4d6bf736b2672c6 (patch) | |
tree | 727c6ad0a7cde6471e6083c91e96c761d34318cf /Identity.lua | |
parent | 41d2b1d1ee8aa12b5195fbd3bd165bc932e0764f (diff) |
Retain nn.Identity state pointers if they are tensors
Diffstat (limited to 'Identity.lua')
-rw-r--r-- | Identity.lua | 24 |
1 files changed, 21 insertions, 3 deletions
diff --git a/Identity.lua b/Identity.lua index 088cc34..3d73f81 100644 --- a/Identity.lua +++ b/Identity.lua @@ -1,12 +1,30 @@ -local Identity, _ = torch.class('nn.Identity', 'nn.Module') +local Identity, parent = torch.class('nn.Identity', 'nn.Module') + +function Identity:__init() + parent.__init(self) + self.tensorOutput = torch.Tensor{} + self.output = self.tensorOutput + self.tensorGradInput = torch.Tensor{} + self.gradInput = self.tensorGradInput +end function Identity:updateOutput(input) - self.output = input + if torch.isTensor(input) then + self.tensorOutput:set(input) + self.output = self.tensorOutput + else + self.output = input + end return self.output end function Identity:updateGradInput(input, gradOutput) - self.gradInput = gradOutput + if torch.isTensor(gradOutput) then + self.tensorGradInput:set(gradOutput) + self.gradInput = self.tensorGradInput + else + self.gradInput = gradOutput + end return self.gradInput end |