Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAdam Paszke <adam.paszke@gmail.com>2016-03-02 23:56:26 +0300
committerAdam Paszke <adam.paszke@gmail.com>2016-03-05 15:56:03 +0300
commitb1cf092d84bb6bfdbb6442d13cc0900e3aea7109 (patch)
tree2a27da19f92bd6262d01abc3fd816c12f39fd7b5 /ConcatTable.lua
parent3a2a1b42e6e6c61addbf82f1efaa7b35c2a3144f (diff)
Improve error handling
When an error occurs in some module, all containers up to the topmost one will be printed now. Also, removed zeroGradParameters from ConcatTable, because it was no different from its parent's implementation.
Diffstat (limited to 'ConcatTable.lua')
-rw-r--r--ConcatTable.lua16
1 files changed, 5 insertions, 11 deletions
diff --git a/ConcatTable.lua b/ConcatTable.lua
index eddf6cc..cb08de0 100644
--- a/ConcatTable.lua
+++ b/ConcatTable.lua
@@ -8,7 +8,7 @@ end
function ConcatTable:updateOutput(input)
for i=1,#self.modules do
- self.output[i] = self.modules[i]:updateOutput(input)
+ self.output[i] = self:rethrowErrors(self.modules[i], i, 'updateOutput', input)
end
return self.output
end
@@ -32,7 +32,7 @@ local function backward(self, method, input, gradOutput, scale)
local wasTable = torch.type(self.gradInput) == 'table'
if isTable then
for i,module in ipairs(self.modules) do
- local currentGradInput = module[method](module, input, gradOutput[i], scale)
+ local currentGradInput = self:rethrowErrors(module, i, method, input, gradOutput[i], scale)
if torch.type(currentGradInput) ~= 'table' then
error"currentGradInput is not a table!"
end
@@ -63,7 +63,7 @@ local function backward(self, method, input, gradOutput, scale)
else
self.gradInput = (not wasTable) and self.gradInput or input:clone()
for i,module in ipairs(self.modules) do
- local currentGradInput = module[method](module, input, gradOutput[i], scale)
+ local currentGradInput = self:rethrowErrors(module, i, method, input, gradOutput[i], scale)
if i == 1 then
self.gradInput:resizeAs(currentGradInput):copy(currentGradInput)
else
@@ -85,19 +85,13 @@ end
function ConcatTable:accGradParameters(input, gradOutput, scale)
scale = scale or 1
for i,module in ipairs(self.modules) do
- module:accGradParameters(input, gradOutput[i], scale)
+ self:rethrowErrors(module, i, 'accGradParameters', input, gradOutput[i], scale)
end
end
function ConcatTable:accUpdateGradParameters(input, gradOutput, lr)
for i,module in ipairs(self.modules) do
- module:accUpdateGradParameters(input, gradOutput[i], lr)
- end
-end
-
-function ConcatTable:zeroGradParameters()
- for _,module in ipairs(self.modules) do
- module:zeroGradParameters()
+ self:rethrowErrors(module, i, 'accUpdateGradParameters', input, gradOutput[i], lr)
end
end