Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNicholas Leonard <nleonard@twitter.com>2017-05-25 06:04:38 +0300
committerNicholas Leonard <nleonard@twitter.com>2017-05-25 06:04:38 +0300
commit44ba0d3f08755214429eac1eff2c435ab2f14b57 (patch)
tree50349ec4197fe557296d6507df9c80c3843619d5 /CAddTensorTable.lua
parentdf1af9500a45f4deecd0f3f1f5020fe4789248ca (diff)
nn.CAddTensorTable
Diffstat (limited to 'CAddTensorTable.lua')
-rw-r--r--CAddTensorTable.lua43
1 files changed, 43 insertions, 0 deletions
diff --git a/CAddTensorTable.lua b/CAddTensorTable.lua
new file mode 100644
index 0000000..16efe44
--- /dev/null
+++ b/CAddTensorTable.lua
@@ -0,0 +1,43 @@
+
+local CAddTensorTable, parent = torch.class('nn.CAddTensorTable', 'nn.Module')
+
+function CAddTensorTable:__init()
+ parent.__init(self)
+ self.gradInput = {}
+end
+
+-- input is a table with 2 entries. input[1] is the vector to be added.
+-- input[2] is the table to which we add the vector
+function CAddTensorTable:updateOutput(input)
+ local currentOutput = {}
+ for i=1,#input[2] do
+ currentOutput[i] = currentOutput[i] or input[1].new()
+ currentOutput[i]:resizeAs(input[1])
+ currentOutput[i]:copy(input[2][i])
+ currentOutput[i]:add(input[1])
+ end
+ for i = #input[2]+1, #currentOutput do
+ currentOutput[i] = nil
+ end
+ self.output = currentOutput
+ return self.output
+end
+
+function CAddTensorTable:updateGradInput(input, gradOutput)
+ self.gradInput[1] = self.gradInput[1] or input[1].new()
+ self.gradInput[1]:resizeAs(input[1])
+ self.gradInput[1]:copy(gradOutput[1])
+ for i=2, #input[2] do
+ self.gradInput[1]:add(gradOutput[i])
+ end
+ self.gradInput[2] = self.gradInput[2] or {}
+ for i=1,#input[2] do
+ self.gradInput[2][i] = self.gradInput[2][i] or input[1].new()
+ self.gradInput[2][i]:resizeAs(input[1])
+ self.gradInput[2][i]:copy(gradOutput[i])
+ end
+ for i=#input[2]+1, #self.gradInput[2] do
+ self.gradInput[2][i] = nil
+ end
+ return self.gradInput
+end \ No newline at end of file