Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/clementfarabet/lua---nnx.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMarco Scoffier <github@metm.org>2011-11-03 07:44:33 +0400
committerMarco Scoffier <github@metm.org>2011-11-03 07:44:33 +0400
commit372d8bf80c88040a4b18a9b6a72eb3e8914b88a6 (patch)
treef9c659e0a6a875b19651c65d0bfd5162588f2220
parent947a61dcea4d0ca854953fa3168f0a9b71f90e6f (diff)
added test() function to ASGD
-rw-r--r--ASGDOptimization.lua22
1 file changed, 21 insertions, 1 deletion
diff --git a/ASGDOptimization.lua b/ASGDOptimization.lua
index 892d740..03a2058 100644
--- a/ASGDOptimization.lua
+++ b/ASGDOptimization.lua
@@ -64,4 +64,24 @@ function ASGD:optimize()
-- (4c) update mu_t
-- mu_t = 1/max(1,t-t0)
self.mu_t = 1 / math.max(1,self.t - self.t0)
-end \ No newline at end of file
+end
+
+-- in ASGD we keep a copy of the parameters which is an averaged
+-- version of the current parameters. This function is to test with
+-- those averaged parameters. Best to run on batches because we have
+-- to copy the full parameter vector
+
+function ASGD:test(_inputs, _targets) -- function test
+ -- (0) make a backup of the online parameters
+ self.backup = self.backup or
+ self.parameters.new():resizeAs(self.parameters)
+ self.backup:copy(self.parameters)
+ -- (1) copy average parameters into the model
+ self.parameters:copy(self.a)
+ -- (2) do the test with the average parameters
+ self.output = self.module:forward(_inputs)
+ self.error = self.criterion:forward(self.output, _targets)
+ -- (3) copy back the online parameters to continue training
+ self.parameters:copy(self.backup)
+ return self.error
+end