Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/clementfarabet/lua---nnx.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNicholas Leonard <nick@nikopia.org>2014-09-19 00:21:42 +0400
committerNicholas Leonard <nick@nikopia.org>2014-09-19 00:21:42 +0400
commite68d15b0414c059a8ec747ee825266404d167238 (patch)
tree2d3db182a8ffc02e594f52b6c69ca19279202e10 /PushTable.lua
parentf87985eb0b2596dbb82e32cb9eaf0a102708d93c (diff)
PushTable/PullTable unit tested
Diffstat (limited to 'PushTable.lua')
-rw-r--r--PushTable.lua19
1 files changed, 8 insertions, 11 deletions
diff --git a/PushTable.lua b/PushTable.lua
index 3ee08c8..b7cfb64 100644
--- a/PushTable.lua
+++ b/PushTable.lua
@@ -6,8 +6,7 @@ function PushTable:__init(index)
self.output = {}
self._gradInput = torch.Tensor()
self.gradInput = {}
- self._nForward = 0
- self._nBackward = 0
+ self._forward = false
end
function PushTable:pull(index)
@@ -30,26 +29,24 @@ function PushTable:updateOutput(inputTable)
pull:_updateOutput(input)
end
- self._nBackward = 0
+ self._forward = true
return self.output
end
function PushTable:_updateGradInput(gradOutput)
- if self._nBackward == 0 then
+ if self._forward then
+ if torch.type(self.gradInput) ~= torch.type(gradOutput) then
+ self._gradInput = gradOutput.new()
+ end
+ self._gradInput:resizeAs(gradOutput)
self._gradInput:copy(gradOutput)
else
self._gradInput:add(gradOutput)
end
- self._nBackward = self._nBackward + 1
+ self._forward = false
end
function PushTable:updateGradInput(inputTable, gradOutputTable)
- if self._nBackward ~= self._nForward then
- error("n Inputs forwarded (pushed) ~= n gradOutputs backwarded"..
- " (pulled) : "..self._nForward.." ~= "..self._nBackward)
- end
- self._nForward = 0
-
for i, gradOutput in ipairs(gradOutputTable) do
if i < self._index then
self.gradInput[i] = gradOutput