
github.com/torch/nn.git
author    Clement Farabet <clement.farabet@gmail.com>  2014-07-07 19:13:28 +0400
committer Clement Farabet <clement.farabet@gmail.com>  2014-07-07 19:13:28 +0400
commit    4f01eb0e7359e59ed435fe4bbd45a19bf5df9b17 (patch)
tree      9a312aac358aeabfd7c37396f3e4a60d6618e760
parent    5b640df933ac92eafb7a9cfa72a39506f261ee74 (diff)
parent    e217fc040843fb77f885822d8ef1a45f84cfacae (diff)

Merge pull request #27 from nicholas-leonard/master

Module:evaluate/training
 Concat.lua        | 12 ++++++++++++
 ConcatTable.lua   | 12 ++++++++++++
 Module.lua        |  8 ++++++++
 Parallel.lua      | 12 ++++++++++++
 ParallelTable.lua | 12 ++++++++++++
 Sequential.lua    | 12 ++++++++++++
 SplitTable.lua    |  1 -
 7 files changed, 68 insertions(+), 1 deletion(-)
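
This merge adds a train/evaluate mode switch to nn. The base Module gains :training() and :evaluate(), which set a boolean self.train flag, and each container (Concat, ConcatTable, Parallel, ParallelTable, Sequential) overrides both methods to propagate the call to its children. A minimal usage sketch, assuming a torch setup with the nn package and a mode-dependent module such as nn.Dropout (which consults self.train at forward time):

require 'nn'

-- a small network containing a mode-dependent module
local mlp = nn.Sequential()
mlp:add(nn.Linear(10, 10))
mlp:add(nn.Dropout(0.5))   -- assumed available in this nn version

mlp:training()   -- recursively sets module.train = true on every child
local noisy = mlp:forward(torch.randn(10))   -- dropout mask applied

mlp:evaluate()   -- recursively sets module.train = false
local clean = mlp:forward(torch.randn(10))   -- deterministic pass
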
diff --git a/Concat.lua b/Concat.lua
index 6543bcc..0b11a9d 100644
--- a/Concat.lua
+++ b/Concat.lua
@@ -92,6 +92,18 @@ function Concat:updateParameters(learningRate)
    end
 end
 
+function Concat:training()
+   for i=1,#self.modules do
+      self.modules[i]:training()
+   end
+end
+
+function Concat:evaluate()
+   for i=1,#self.modules do
+      self.modules[i]:evaluate()
+   end
+end
+
 function Concat:share(mlp,...)
    for i=1,#self.modules do
       self.modules[i]:share(mlp.modules[i],...);
diff --git a/ConcatTable.lua b/ConcatTable.lua
index 414e873..c42776d 100644
--- a/ConcatTable.lua
+++ b/ConcatTable.lua
@@ -63,6 +63,18 @@ function ConcatTable:updateParameters(learningRate)
    end
 end
 
+function ConcatTable:training()
+   for i=1,#self.modules do
+      self.modules[i]:training()
+   end
+end
+
+function ConcatTable:evaluate()
+   for i=1,#self.modules do
+      self.modules[i]:evaluate()
+   end
+end
+
 function ConcatTable:share(mlp,...)
    for i=1,#self.modules do
       self.modules[i]:share(mlp.modules[i],...);
diff --git a/Module.lua b/Module.lua
index 7812ef6..4a08d12 100644
--- a/Module.lua
+++ b/Module.lua
@@ -81,6 +81,14 @@ function Module:updateParameters(learningRate)
    end
 end
 
+function Module:training()
+   self.train = true
+end
+
+function Module:evaluate()
+   self.train = false
+end
+
 function Module:share(mlp, ...)
    local arg = {...}
    for i,v in ipairs(arg) do
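
The Module.lua change above is the whole mechanism: :training() and :evaluate() simply flip self.train, and what the flag means is left to individual modules. A hedged sketch of how a user-defined module might consult it (MyDropout is a hypothetical example, not part of this patch; the backward pass is omitted for brevity):

local MyDropout, parent = torch.class('nn.MyDropout', 'nn.Module')

function MyDropout:__init(p)
   parent.__init(self)
   self.p = p or 0.5
   self.train = true
end

function MyDropout:updateOutput(input)
   self.output:resizeAs(input):copy(input)
   if self.train then
      -- during training, apply a scaled Bernoulli mask
      self.noise = self.noise or input.new()
      self.noise:resizeAs(input)
      self.noise:bernoulli(1 - self.p)
      self.noise:div(1 - self.p)
      self.output:cmul(self.noise)
   end
   -- at evaluation time (self.train == false) the input passes through unchanged
   return self.output
end
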
diff --git a/Parallel.lua b/Parallel.lua
index 3c625bc..547f444 100644
--- a/Parallel.lua
+++ b/Parallel.lua
@@ -108,6 +108,18 @@ function Parallel:updateParameters(learningRate)
    end
 end
 
+function Parallel:training()
+   for i=1,#self.modules do
+      self.modules[i]:training()
+   end
+end
+
+function Parallel:evaluate()
+   for i=1,#self.modules do
+      self.modules[i]:evaluate()
+   end
+end
+
 function Parallel:share(mlp,...)
    for i=1,#self.modules do
       self.modules[i]:share(mlp.modules[i],...);
diff --git a/ParallelTable.lua b/ParallelTable.lua
index 3d8f79d..255a7bd 100644
--- a/ParallelTable.lua
+++ b/ParallelTable.lua
@@ -61,6 +61,18 @@ function ParallelTable:updateParameters(learningRate)
    end
 end
 
+function ParallelTable:training()
+   for i=1,#self.modules do
+      self.modules[i]:training()
+   end
+end
+
+function ParallelTable:evaluate()
+   for i=1,#self.modules do
+      self.modules[i]:evaluate()
+   end
+end
+
 function ParallelTable:share(mlp,...)
    for i=1,#self.modules do
       self.modules[i]:share(mlp.modules[i],...);
diff --git a/Sequential.lua b/Sequential.lua
index a368964..b43bd99 100644
--- a/Sequential.lua
+++ b/Sequential.lua
@@ -83,6 +83,18 @@ function Sequential:updateParameters(learningRate)
    end
 end
 
+function Sequential:training()
+   for i=1,#self.modules do
+      self.modules[i]:training()
+   end
+end
+
+function Sequential:evaluate()
+   for i=1,#self.modules do
+      self.modules[i]:evaluate()
+   end
+end
+
 function Sequential:share(mlp,...)
    for i=1,#self.modules do
       self.modules[i]:share(mlp.modules[i],...);
diff --git a/SplitTable.lua b/SplitTable.lua
index 0148c4e..70b45f6 100644
--- a/SplitTable.lua
+++ b/SplitTable.lua
@@ -2,7 +2,6 @@ local SplitTable, parent = torch.class('nn.SplitTable', 'nn.Module')
 
 function SplitTable:__init(dimension, nInputDims)
    parent.__init(self)
-   self.modules = {}
    self.dimension = dimension
    self.nInputDims = nInputDims
 end
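
The SplitTable change is a small cleanup rather than part of the new API: SplitTable holds no child modules, so the empty self.modules table was unused state, and removing it avoids suggesting container behavior the module does not have. A brief usage sketch, assuming the nn package:

require 'nn'

local m = nn.SplitTable(1)
local out = m:forward(torch.randn(3, 4))
-- out is a table of 3 tensors, each of size 4; m.train still toggles
-- via the inherited Module:training() / Module:evaluate()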