Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorLuke <insperatum@gmail.com>2015-11-20 05:00:03 +0300
committerSoumith Chintala <soumith@gmail.com>2015-11-26 11:26:17 +0300
commitb7a64bf1f8062b394ee28679f4d6bf736b2672c6 (patch)
tree727c6ad0a7cde6471e6083c91e96c761d34318cf /Identity.lua
parent41d2b1d1ee8aa12b5195fbd3bd165bc932e0764f (diff)
Retain nn.Identity state pointers if they are tensors
Diffstat (limited to 'Identity.lua')
-rw-r--r--Identity.lua24
1 files changed, 21 insertions, 3 deletions
diff --git a/Identity.lua b/Identity.lua
index 088cc34..3d73f81 100644
--- a/Identity.lua
+++ b/Identity.lua
@@ -1,12 +1,30 @@
-local Identity, _ = torch.class('nn.Identity', 'nn.Module')
+local Identity, parent = torch.class('nn.Identity', 'nn.Module')
+
+function Identity:__init()
+ parent.__init(self)
+ self.tensorOutput = torch.Tensor{}
+ self.output = self.tensorOutput
+ self.tensorGradInput = torch.Tensor{}
+ self.gradInput = self.tensorGradInput
+end
function Identity:updateOutput(input)
- self.output = input
+ if torch.isTensor(input) then
+ self.tensorOutput:set(input)
+ self.output = self.tensorOutput
+ else
+ self.output = input
+ end
return self.output
end
function Identity:updateGradInput(input, gradOutput)
- self.gradInput = gradOutput
+ if torch.isTensor(gradOutput) then
+ self.tensorGradInput:set(gradOutput)
+ self.gradInput = self.tensorGradInput
+ else
+ self.gradInput = gradOutput
+ end
return self.gradInput
end