diff options
author | Nicholas Leonard <nleonard@twitter.com> | 2017-05-25 06:04:38 +0300 |
---|---|---|
committer | Nicholas Leonard <nleonard@twitter.com> | 2017-05-25 06:04:38 +0300 |
commit | 44ba0d3f08755214429eac1eff2c435ab2f14b57 (patch) | |
tree | 50349ec4197fe557296d6507df9c80c3843619d5 /CAddTensorTable.lua | |
parent | df1af9500a45f4deecd0f3f1f5020fe4789248ca (diff) |
nn.CAddTensorTable
Diffstat (limited to 'CAddTensorTable.lua')
-rw-r--r-- | CAddTensorTable.lua | 43 |
1 files changed, 43 insertions, 0 deletions
diff --git a/CAddTensorTable.lua b/CAddTensorTable.lua new file mode 100644 index 0000000..16efe44 --- /dev/null +++ b/CAddTensorTable.lua @@ -0,0 +1,43 @@ + +local CAddTensorTable, parent = torch.class('nn.CAddTensorTable', 'nn.Module') + +function CAddTensorTable:__init() + parent.__init(self) + self.gradInput = {} +end + +-- input is a table with 2 entries. input[1] is the vector to be added. +-- input[2] is the table to which we add the vector +function CAddTensorTable:updateOutput(input) + local currentOutput = {} + for i=1,#input[2] do + currentOutput[i] = currentOutput[i] or input[1].new() + currentOutput[i]:resizeAs(input[1]) + currentOutput[i]:copy(input[2][i]) + currentOutput[i]:add(input[1]) + end + for i = #input[2]+1, #currentOutput do + currentOutput[i] = nil + end + self.output = currentOutput + return self.output +end + +function CAddTensorTable:updateGradInput(input, gradOutput) + self.gradInput[1] = self.gradInput[1] or input[1].new() + self.gradInput[1]:resizeAs(input[1]) + self.gradInput[1]:copy(gradOutput[1]) + for i=2, #input[2] do + self.gradInput[1]:add(gradOutput[i]) + end + self.gradInput[2] = self.gradInput[2] or {} + for i=1,#input[2] do + self.gradInput[2][i] = self.gradInput[2][i] or input[1].new() + self.gradInput[2][i]:resizeAs(input[1]) + self.gradInput[2][i]:copy(gradOutput[i]) + end + for i=#input[2]+1, #self.gradInput[2] do + self.gradInput[2][i] = nil + end + return self.gradInput +end
\ No newline at end of file |