Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/moses-smt/vowpal_wabbit.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/cg.cc
diff options
context:
space:
mode:
authorJohn Langford <jl@hunch.net>2011-06-24 06:08:14 +0400
committerJohn Langford <jl@hunch.net>2011-06-24 06:08:14 +0400
commit2f623c4fe2aa664bcd3145ca6e0ec8e0bd1f05a3 (patch)
treebccf35ad6b1adf3fd4d71e14cf18fcd104cc5dad /cg.cc
parent0454c2131c54278010a0bef49d86f730dea3255a (diff)
initiall cluster parallel CG code
Diffstat (limited to 'cg.cc')
-rw-r--r--cg.cc27
1 files changed, 15 insertions, 12 deletions
diff --git a/cg.cc b/cg.cc
index 0ca2bceb..b7c7bf33 100644
--- a/cg.cc
+++ b/cg.cc
@@ -1,10 +1,10 @@
/*
-Copyright (c) 2009 Yahoo! Inc. All rights reserved. The copyrights
+Copyright (c) 2009-2011 Yahoo! Inc. All rights reserved. The copyrights
embodied in the content of this file are licensed under the BSD
(revised) open source license
-The algorithm here is generally based on Jonathan Shewchuck's tutorial.
-
+The algorithm here is generally based on Jonathan Shewchuck's
+tutorial.
*/
#include <fstream>
#include <float.h>
@@ -342,7 +342,10 @@ void setup_cg(gd_thread_params t)
if (current_pass == 0)
{
if(global.master_location != "")
- accumulate(socks, reg, 3); //Accumulate preconditioner
+ {
+ accumulate(socks, reg, 3); //Accumulate preconditioner
+ importance_weight_sum = accumulate_scalar(socks, importance_weight_sum);
+ }
finalize_preconditioner(reg,global.regularization);
}
if (gradient_pass) // We just finished computing all gradients
@@ -354,7 +357,7 @@ void setup_cg(gd_thread_params t)
if (global.regularization > 0.)
loss_sum += add_regularization(reg,global.regularization);
if (!global.quiet)
- fprintf(stderr, "%-f\t", loss_sum );
+ fprintf(stderr, "%-f\t", loss_sum / importance_weight_sum);
if (current_pass > 0 && loss_sum > previous_loss_sum)
{// we stepped to far last time, step back
@@ -380,7 +383,7 @@ void setup_cg(gd_thread_params t)
float new_d_mag = derivative_magnitude(reg, old_first_derivative);
previous_d_mag = new_d_mag;
if (!global.quiet)
- fprintf(stderr, "%f\t%f\t", mix_frac, new_d_mag);
+ fprintf(stderr, "%f\t%f\t", mix_frac, new_d_mag / importance_weight_sum);
update_direction(reg, mix_frac, old_first_derivative);
gradient_pass = false;//now start computing curvature
@@ -402,7 +405,7 @@ void setup_cg(gd_thread_params t)
}
step_size = - dd/curvature;
if (!global.quiet) {
- fprintf(stderr, "%-e\t%-e\t%-e\t%-f\n", curvature, d_mag, step_size, 0.5*step_size*step_size*curvature);
+ fprintf(stderr, "%-e\t%-e\t%-e\t%-f\n", curvature/importance_weight_sum, d_mag/importance_weight_sum, step_size, 0.5*step_size*step_size*curvature /importance_weight_sum);
//fprintf(stdout, "Net comm. time is %f\n",net_comm_time - prev_comm_time);
}
prev_comm_time = net_comm_time;
@@ -480,8 +483,11 @@ void setup_cg(gd_thread_params t)
}
ftime(&t_end_global);
net_time += (int) (1000.0 * (t_end_global.time - t_start_global.time) + (t_end_global.millitm - t_start_global.millitm));
- cerr<<"Net time spent in communication = "<<(float)net_comm_time/(float)1000<<"seconds\n";
- cerr<<"Net time spent = "<<(float)net_time/(float)1000<<"seconds\n";
+ if (!global.quiet)
+ {
+ cerr<<"Net time spent in communication = "<<(float)net_comm_time/(float)1000<<" seconds\n";
+ cerr<<"Net time spent = "<<(float)net_time/(float)1000<<" seconds\n";
+ }
if (global.local_prediction > 0)
shutdown(global.local_prediction, SHUT_WR);
if(global.master_location != "")
@@ -495,13 +501,10 @@ void setup_cg(gd_thread_params t)
;//busywait when we have predicted on all examples but not yet trained on all.
}
- cerr<<"Done CG\n";
- fflush(stderr);
if(global.master_location != "")
all_reduce_close(socks);
free(predictions.begin);
free(old_first_derivative);
- cerr<<"Really Done CG\n";
ftime(&t_end_global);
net_time += (int) (1000.0 * (t_end_global.time - t_start_global.time) + (t_end_global.millitm - t_start_global.millitm));
cerr<<"Net time spent in communication = "<<(float)net_comm_time/(float)1000<<"seconds\n";