Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorElie Michel <elie.michel@ens.fr>2015-06-09 21:06:10 +0300
committerElie Michel <elie.michel@ens.fr>2015-06-09 21:06:10 +0300
commitbf8fa1714aabe1602c9112998cf626110895d00c (patch)
tree8d2f45d5f8c22bf1a7ae50cfbb1ab1519b77a176 /Sequential.lua
parent8929c5cf41b2633b4cc83d9a5218268351cf5c0f (diff)
Add nn.Sequential.remove([index])
Diffstat (limited to 'Sequential.lua')
-rw-r--r--Sequential.lua12
1 files changed, 11 insertions, 1 deletions
diff --git a/Sequential.lua b/Sequential.lua
index b08f7df..359a764 100644
--- a/Sequential.lua
+++ b/Sequential.lua
@@ -15,7 +15,7 @@ end
function Sequential:insert(module, index)
index = index or (#self.modules + 1)
- if index > (#self.modules + 1) then
+ if index > (#self.modules + 1) or index < 1 then
error"index should be contiguous to existing modules"
end
table.insert(self.modules, index, module)
@@ -23,6 +23,16 @@ function Sequential:insert(module, index)
self.gradInput = self.modules[1].gradInput
end
+function Sequential:remove(index)
+ index = index or #self.modules
+ if index > #self.modules or index < 1 then
+ error"index out of range"
+ end
+ table.remove(self.modules, index)
+ self.output = self.modules[#self.modules].output
+ self.gradInput = self.modules[1].gradInput
+end
+
function Sequential:updateOutput(input)
local currentOutput = input
for i=1,#self.modules do