Repository: github.com/moses-smt/vowpal_wabbit.git

Diffstat (limited to 'vowpalwabbit/bfgs.cc')
-rw-r--r--  vowpalwabbit/bfgs.cc  22
1 file changed, 17 insertions(+), 5 deletions(-)
diff --git a/vowpalwabbit/bfgs.cc b/vowpalwabbit/bfgs.cc
index 6f9f04c8..8b1778a5 100644
--- a/vowpalwabbit/bfgs.cc
+++ b/vowpalwabbit/bfgs.cc
@@ -154,7 +154,7 @@ bool test_example(example& ec)
 
 float bfgs_predict(vw& all, example& ec)
 {
-  ec.partial_prediction = GD::inline_predict<vec_add>(all,ec);
+  ec.partial_prediction = GD::inline_predict(all,ec);
   return GD::finalize_prediction(all, ec.partial_prediction);
 }
@@ -198,7 +198,7 @@ void update_preconditioner(vw& all, example& ec)
 float dot_with_direction(vw& all, example& ec)
 {
   ec.ft_offset+= W_DIR;
-  float ret = GD::inline_predict<vec_add>(all, ec);
+  float ret = GD::inline_predict(all, ec);
   ec.ft_offset-= W_DIR;
   return ret;
 }
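
Both inline_predict hunks make the same change: the explicit <vec_add> template argument disappears from the call sites. A minimal sketch of one way such calls stay valid, assuming inline_predict gained a default template argument (C++11); the stub types below are illustrative, not VW's real gd.h:

#include <iostream>

struct example { float partial_prediction = 0.f; };

// Stand-in for VW's vec_add: accumulate feature value * weight.
inline void vec_add(float& p, float x, float w) { p += x * w; }

// Hypothetical signature: defaulting T to vec_add lets callers write
// inline_predict(ec) instead of inline_predict<vec_add>(ec).
template <void (*T)(float&, float, float) = vec_add>
float inline_predict(example& ec)
{
  T(ec.partial_prediction, 2.0f, 0.5f); // one toy feature * weight
  return ec.partial_prediction;
}

int main()
{
  example ec;
  // Old and new call styles instantiate the same function.
  std::cout << inline_predict<vec_add>(ec) << "\n"; // prints 1
  std::cout << inline_predict(ec) << "\n";          // prints 2 (accumulates)
  return 0;
}

Either spelling selects the same instantiation, so the simplified call sites in the diff behave identically to the old ones.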
@@ -802,8 +802,10 @@ void end_pass(bfgs& b)
        set_done(*all);
        cerr<<"Early termination reached w.r.t. holdout set error";
      }
-
-    }
+    } if (b.final_pass == b.current_pass) {
+      finalize_regressor(*all, all->final_regressor_name);
+      set_done(*all);
+    }
 
   }else{//reaching convergence in the previous pass
     if(b.output_regularizer)
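
This hunk does more than rebalance braces: finalize_regressor and set_done now also run whenever the final pass is reached, rather than only through the early-termination branch above. A runnable sketch of the resulting control flow, using stub types in place of VW's bfgs and vw structs (assumptions, not the real definitions):

#include <cstddef>
#include <iostream>
#include <string>

// Stubs standing in for VW's structs (assumption, for illustration only).
struct bfgs { std::size_t final_pass = 3, current_pass = 0; };
struct vw   { std::string final_regressor_name = "model"; bool done = false; };

void set_done(vw& all) { all.done = true; }
void finalize_regressor(vw& all, const std::string& name)
{ std::cerr << "writing " << name << "\n"; }

// Mirrors the patched control flow: finalize on early termination and,
// now, unconditionally when the final pass is reached.
void end_pass_sketch(bfgs& b, vw& all, bool early_stop)
{
  ++b.current_pass;
  if (early_stop) {
    std::cerr << "Early termination reached w.r.t. holdout set error\n";
    set_done(all);
  }
  if (b.final_pass == b.current_pass) {
    finalize_regressor(all, all.final_regressor_name);
    set_done(all);
  }
}

int main()
{
  bfgs b; vw all;
  while (!all.done) end_pass_sketch(b, all, /*early_stop=*/false);
  return 0;
}

Driving the sketch to completion writes the model exactly once, on the pass where final_pass == current_pass.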
@@ -968,7 +970,7 @@ void save_load(bfgs& b, io_buf& model_file, bool read, bool text)
   b.backstep_on = true;
 }
 
-learner* setup(vw& all, std::vector<std::string>&opts, po::variables_map& vm)
+learner* setup(vw& all, po::variables_map& vm)
 {
   bfgs* b = (bfgs*)calloc_or_die(1,sizeof(bfgs));
   b->all = &all;
@@ -982,6 +984,16 @@ learner* setup(vw& all, std::vector<std::string>&opts, po::variables_map& vm)
   b->no_win_counter = 0;
   b->early_stop_thres = 3;
 
+  po::options_description bfgs_opts("LBFGS options");
+
+  bfgs_opts.add_options()
+    ("hessian_on", "use second derivative in line search")
+    ("mem", po::value<int>(&(all.m)), "memory in bfgs")
+    ("conjugate_gradient", "use conjugate gradient based optimization")
+    ("termination", po::value<float>(&(all.rel_threshold)),"Termination threshold");
+
+  vm = add_options(all, bfgs_opts);
+
   if(!all.holdout_set_off)
   {
     all.sd->holdout_best_loss = FLT_MAX;
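
The setup hunks move option registration into the reduction itself: the opts vector disappears from the signature, and setup instead builds a boost::program_options description and passes it to the add_options helper shown in the diff. A standalone sketch of the same add_options() chaining pattern, parsing into local stand-ins for all.m and all.rel_threshold:

#include <iostream>
#include <boost/program_options.hpp>

namespace po = boost::program_options;

int main(int argc, char* argv[])
{
  int mem = 15;                 // local stand-in for all.m
  float termination = 0.001f;   // local stand-in for all.rel_threshold

  // Same registration pattern as the diff: flags take no value,
  // po::value<T>(&var) binds a typed option to a variable.
  po::options_description bfgs_opts("LBFGS options");
  bfgs_opts.add_options()
    ("hessian_on", "use second derivative in line search")
    ("mem", po::value<int>(&mem), "memory in bfgs")
    ("conjugate_gradient", "use conjugate gradient based optimization")
    ("termination", po::value<float>(&termination), "Termination threshold");

  po::variables_map vm;
  po::store(po::parse_command_line(argc, argv, bfgs_opts), vm);
  po::notify(vm);  // copies parsed values into the bound variables

  std::cout << "mem=" << mem
            << " termination=" << termination
            << " hessian_on=" << (vm.count("hessian_on") ? "yes" : "no")
            << "\n";
  return 0;
}

Compile with -lboost_program_options; running with --mem 7 --hessian_on sets mem to 7 and makes vm.count("hessian_on") nonzero.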