diff options
author | Nicholas Leonard <nick@nikopia.org> | 2014-09-18 23:56:40 +0400 |
---|---|---|
committer | Nicholas Leonard <nick@nikopia.org> | 2014-09-18 23:56:40 +0400 |
commit | f87985eb0b2596dbb82e32cb9eaf0a102708d93c (patch) | |
tree | 6438bf906a4a4907a0aa6dd75dd8c02b9a683114 /PushTable.lua | |
parent | 76abd3b0cf8f4a71b5432736bbf482614d9f78bb (diff) |
initial commit for PushTable/PullTable:backward
Diffstat (limited to 'PushTable.lua')
-rw-r--r-- | PushTable.lua | 33 |
1 files changed, 31 insertions, 2 deletions
diff --git a/PushTable.lua b/PushTable.lua index fbfff76..3ee08c8 100644 --- a/PushTable.lua +++ b/PushTable.lua @@ -4,6 +4,10 @@ function PushTable:__init(index) self._index = index self._pulls = {} self.output = {} + self._gradInput = torch.Tensor() + self.gradInput = {} + self._nForward = 0 + self._nBackward = 0 end function PushTable:pull(index) @@ -23,14 +27,39 @@ function PushTable:updateOutput(inputTable) local input = inputTable[self._index] for i,pull in ipairs(self._pulls) do - pull:push(input) + pull:_updateOutput(input) end + + self._nBackward = 0 return self.output end +function PushTable:_updateGradInput(gradOutput) + if self._nBackward == 0 then + self._gradInput:copy(gradOutput) + else + self._gradInput:add(gradOutput) + end + self._nBackward = self._nBackward + 1 +end + function PushTable:updateGradInput(inputTable, gradOutputTable) + if self._nBackward ~= self._nForward then + error("n Inputs forwarded (pushed) ~= n gradOutputs backwarded".. + " (pulled) : "..self._nForward.." ~= "..self._nBackward) + end + self._nForward = 0 - + for i, gradOutput in ipairs(gradOutputTable) do + if i < self._index then + self.gradInput[i] = gradOutput + elseif i > self._index then + self.gradInput[i+1] = gradOutput + end + end + self.gradInput[self._index] = self._gradInput + assert(#inputTable == #self.gradInput, "tables size mismatch") + return self.gradInput end |