diff options
author | Adam Paszke <adam.paszke@gmail.com> | 2016-03-02 23:56:26 +0300 |
---|---|---|
committer | Adam Paszke <adam.paszke@gmail.com> | 2016-03-05 15:56:03 +0300 |
commit | b1cf092d84bb6bfdbb6442d13cc0900e3aea7109 (patch) | |
tree | 2a27da19f92bd6262d01abc3fd816c12f39fd7b5 /ConcatTable.lua | |
parent | 3a2a1b42e6e6c61addbf82f1efaa7b35c2a3144f (diff) |
Improve error handling
When an error occurs in some module, all containers up to the
topmost one will be printed now.
Also, removed zeroGradParameters from ConcatTable, because it was
no different from its parent's implementation.
Diffstat (limited to 'ConcatTable.lua')
-rw-r--r-- | ConcatTable.lua | 16 |
1 files changed, 5 insertions, 11 deletions
diff --git a/ConcatTable.lua b/ConcatTable.lua index eddf6cc..cb08de0 100644 --- a/ConcatTable.lua +++ b/ConcatTable.lua @@ -8,7 +8,7 @@ end function ConcatTable:updateOutput(input) for i=1,#self.modules do - self.output[i] = self.modules[i]:updateOutput(input) + self.output[i] = self:rethrowErrors(self.modules[i], i, 'updateOutput', input) end return self.output end @@ -32,7 +32,7 @@ local function backward(self, method, input, gradOutput, scale) local wasTable = torch.type(self.gradInput) == 'table' if isTable then for i,module in ipairs(self.modules) do - local currentGradInput = module[method](module, input, gradOutput[i], scale) + local currentGradInput = self:rethrowErrors(module, i, method, input, gradOutput[i], scale) if torch.type(currentGradInput) ~= 'table' then error"currentGradInput is not a table!" end @@ -63,7 +63,7 @@ local function backward(self, method, input, gradOutput, scale) else self.gradInput = (not wasTable) and self.gradInput or input:clone() for i,module in ipairs(self.modules) do - local currentGradInput = module[method](module, input, gradOutput[i], scale) + local currentGradInput = self:rethrowErrors(module, i, method, input, gradOutput[i], scale) if i == 1 then self.gradInput:resizeAs(currentGradInput):copy(currentGradInput) else @@ -85,19 +85,13 @@ end function ConcatTable:accGradParameters(input, gradOutput, scale) scale = scale or 1 for i,module in ipairs(self.modules) do - module:accGradParameters(input, gradOutput[i], scale) + self:rethrowErrors(module, i, 'accGradParameters', input, gradOutput[i], scale) end end function ConcatTable:accUpdateGradParameters(input, gradOutput, lr) for i,module in ipairs(self.modules) do - module:accUpdateGradParameters(input, gradOutput[i], lr) - end -end - -function ConcatTable:zeroGradParameters() - for _,module in ipairs(self.modules) do - module:zeroGradParameters() + self:rethrowErrors(module, i, 'accUpdateGradParameters', input, gradOutput[i], lr) end end |