diff options
author | Nicholas Leonard <nick@nikopia.org> | 2014-09-19 00:21:42 +0400 |
---|---|---|
committer | Nicholas Leonard <nick@nikopia.org> | 2014-09-19 00:21:42 +0400 |
commit | e68d15b0414c059a8ec747ee825266404d167238 (patch) | |
tree | 2d3db182a8ffc02e594f52b6c69ca19279202e10 /PushTable.lua | |
parent | f87985eb0b2596dbb82e32cb9eaf0a102708d93c (diff) |
PushTable/PullTable unit tested
Diffstat (limited to 'PushTable.lua')
-rw-r--r-- | PushTable.lua | 19 |
1 files changed, 8 insertions, 11 deletions
diff --git a/PushTable.lua b/PushTable.lua index 3ee08c8..b7cfb64 100644 --- a/PushTable.lua +++ b/PushTable.lua @@ -6,8 +6,7 @@ function PushTable:__init(index) self.output = {} self._gradInput = torch.Tensor() self.gradInput = {} - self._nForward = 0 - self._nBackward = 0 + self._forward = false end function PushTable:pull(index) @@ -30,26 +29,24 @@ function PushTable:updateOutput(inputTable) pull:_updateOutput(input) end - self._nBackward = 0 + self._forward = true return self.output end function PushTable:_updateGradInput(gradOutput) - if self._nBackward == 0 then + if self._forward then + if torch.type(self.gradInput) ~= torch.type(gradOutput) then + self._gradInput = gradOutput.new() + end + self._gradInput:resizeAs(gradOutput) self._gradInput:copy(gradOutput) else self._gradInput:add(gradOutput) end - self._nBackward = self._nBackward + 1 + self._forward = false end function PushTable:updateGradInput(inputTable, gradOutputTable) - if self._nBackward ~= self._nForward then - error("n Inputs forwarded (pushed) ~= n gradOutputs backwarded".. - " (pulled) : "..self._nForward.." ~= "..self._nBackward) - end - self._nForward = 0 - for i, gradOutput in ipairs(gradOutputTable) do if i < self._index then self.gradInput[i] = gradOutput |