Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/clementfarabet/lua---nnx.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNicholas Leonard <nick@nikopia.org>2014-09-18 23:56:40 +0400
committerNicholas Leonard <nick@nikopia.org>2014-09-18 23:56:40 +0400
commitf87985eb0b2596dbb82e32cb9eaf0a102708d93c (patch)
tree6438bf906a4a4907a0aa6dd75dd8c02b9a683114 /PushTable.lua
parent76abd3b0cf8f4a71b5432736bbf482614d9f78bb (diff)
initial commit for PushTable/PullTable:backward
Diffstat (limited to 'PushTable.lua')
-rw-r--r--PushTable.lua33
1 file changed, 31 insertions, 2 deletions
diff --git a/PushTable.lua b/PushTable.lua
index fbfff76..3ee08c8 100644
--- a/PushTable.lua
+++ b/PushTable.lua
@@ -4,6 +4,10 @@ function PushTable:__init(index)
self._index = index
self._pulls = {}
self.output = {}
+ self._gradInput = torch.Tensor()
+ self.gradInput = {}
+ self._nForward = 0
+ self._nBackward = 0
end
function PushTable:pull(index)
@@ -23,14 +27,39 @@ function PushTable:updateOutput(inputTable)
local input = inputTable[self._index]
for i,pull in ipairs(self._pulls) do
- pull:push(input)
+ pull:_updateOutput(input)
end
+
+ self._nBackward = 0
return self.output
end
+function PushTable:_updateGradInput(gradOutput)
+ if self._nBackward == 0 then
+ self._gradInput:copy(gradOutput)
+ else
+ self._gradInput:add(gradOutput)
+ end
+ self._nBackward = self._nBackward + 1
+end
+
function PushTable:updateGradInput(inputTable, gradOutputTable)
+ if self._nBackward ~= self._nForward then
+ error("n Inputs forwarded (pushed) ~= n gradOutputs backwarded"..
+ " (pulled) : "..self._nForward.." ~= "..self._nBackward)
+ end
+ self._nForward = 0
-
+ for i, gradOutput in ipairs(gradOutputTable) do
+ if i < self._index then
+ self.gradInput[i] = gradOutput
+ elseif i > self._index then
+ self.gradInput[i+1] = gradOutput
+ end
+ end
+ self.gradInput[self._index] = self._gradInput
+ assert(#inputTable == #self.gradInput, "tables size mismatch")
+ return self.gradInput
end