diff options
author | Clement Farabet <clement.farabet@gmail.com> | 2014-07-07 19:13:28 +0400 |
---|---|---|
committer | Clement Farabet <clement.farabet@gmail.com> | 2014-07-07 19:13:28 +0400 |
commit | 4f01eb0e7359e59ed435fe4bbd45a19bf5df9b17 (patch) | |
tree | 9a312aac358aeabfd7c37396f3e4a60d6618e760 | |
parent | 5b640df933ac92eafb7a9cfa72a39506f261ee74 (diff) | |
parent | e217fc040843fb77f885822d8ef1a45f84cfacae (diff) |
Merge pull request #27 from nicholas-leonard/master
Module:evaluate/training
-rw-r--r-- | Concat.lua | 12 | ||||
-rw-r--r-- | ConcatTable.lua | 12 | ||||
-rw-r--r-- | Module.lua | 8 | ||||
-rw-r--r-- | Parallel.lua | 12 | ||||
-rw-r--r-- | ParallelTable.lua | 12 | ||||
-rw-r--r-- | Sequential.lua | 12 | ||||
-rw-r--r-- | SplitTable.lua | 1 |
7 files changed, 68 insertions, 1 deletions
@@ -92,6 +92,18 @@ function Concat:updateParameters(learningRate) end end +function Concat:training() + for i=1,#self.modules do + self.modules[i]:training() + end +end + +function Concat:evaluate() + for i=1,#self.modules do + self.modules[i]:evaluate() + end +end + function Concat:share(mlp,...) for i=1,#self.modules do self.modules[i]:share(mlp.modules[i],...); diff --git a/ConcatTable.lua b/ConcatTable.lua index 414e873..c42776d 100644 --- a/ConcatTable.lua +++ b/ConcatTable.lua @@ -63,6 +63,18 @@ function ConcatTable:updateParameters(learningRate) end end +function ConcatTable:training() + for i=1,#self.modules do + self.modules[i]:training() + end +end + +function ConcatTable:evaluate() + for i=1,#self.modules do + self.modules[i]:evaluate() + end +end + function ConcatTable:share(mlp,...) for i=1,#self.modules do self.modules[i]:share(mlp.modules[i],...); @@ -81,6 +81,14 @@ function Module:updateParameters(learningRate) end end +function Module:training() + self.train = true +end + +function Module:evaluate() + self.train = false +end + function Module:share(mlp, ...) local arg = {...} for i,v in ipairs(arg) do diff --git a/Parallel.lua b/Parallel.lua index 3c625bc..547f444 100644 --- a/Parallel.lua +++ b/Parallel.lua @@ -108,6 +108,18 @@ function Parallel:updateParameters(learningRate) end end +function Parallel:training() + for i=1,#self.modules do + self.modules[i]:training() + end +end + +function Parallel:evaluate() + for i=1,#self.modules do + self.modules[i]:evaluate() + end +end + function Parallel:share(mlp,...) for i=1,#self.modules do self.modules[i]:share(mlp.modules[i],...); diff --git a/ParallelTable.lua b/ParallelTable.lua index 3d8f79d..255a7bd 100644 --- a/ParallelTable.lua +++ b/ParallelTable.lua @@ -61,6 +61,18 @@ function ParallelTable:updateParameters(learningRate) end end +function ParallelTable:training() + for i=1,#self.modules do + self.modules[i]:training() + end +end + +function ParallelTable:evaluate() + for i=1,#self.modules do + self.modules[i]:evaluate() + end +end + function ParallelTable:share(mlp,...) for i=1,#self.modules do self.modules[i]:share(mlp.modules[i],...); diff --git a/Sequential.lua b/Sequential.lua index a368964..b43bd99 100644 --- a/Sequential.lua +++ b/Sequential.lua @@ -83,6 +83,18 @@ function Sequential:updateParameters(learningRate) end end +function Sequential:training() + for i=1,#self.modules do + self.modules[i]:training() + end +end + +function Sequential:evaluate() + for i=1,#self.modules do + self.modules[i]:evaluate() + end +end + function Sequential:share(mlp,...) for i=1,#self.modules do self.modules[i]:share(mlp.modules[i],...); diff --git a/SplitTable.lua b/SplitTable.lua index 0148c4e..70b45f6 100644 --- a/SplitTable.lua +++ b/SplitTable.lua @@ -2,7 +2,6 @@ local SplitTable, parent = torch.class('nn.SplitTable', 'nn.Module') function SplitTable:__init(dimension, nInputDims) parent.__init(self) - self.modules = {} self.dimension = dimension self.nInputDims = nInputDims end |