From 99f9974e4002e0aaef22741f656e5bc9baf9dc40 Mon Sep 17 00:00:00 2001 From: Hal Daume III Date: Sat, 24 May 2014 16:56:42 -0400 Subject: merged john's changes --- vowpalwabbit/bfgs.cc | 22 +++++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) (limited to 'vowpalwabbit/bfgs.cc') diff --git a/vowpalwabbit/bfgs.cc b/vowpalwabbit/bfgs.cc index 6f9f04c8..8b1778a5 100644 --- a/vowpalwabbit/bfgs.cc +++ b/vowpalwabbit/bfgs.cc @@ -154,7 +154,7 @@ bool test_example(example& ec) float bfgs_predict(vw& all, example& ec) { - ec.partial_prediction = GD::inline_predict(all,ec); + ec.partial_prediction = GD::inline_predict(all,ec); return GD::finalize_prediction(all, ec.partial_prediction); } @@ -198,7 +198,7 @@ void update_preconditioner(vw& all, example& ec) float dot_with_direction(vw& all, example& ec) { ec.ft_offset+= W_DIR; - float ret = GD::inline_predict(all, ec); + float ret = GD::inline_predict(all, ec); ec.ft_offset-= W_DIR; return ret; @@ -802,8 +802,10 @@ void end_pass(bfgs& b) set_done(*all); cerr<<"Early termination reached w.r.t. holdout set error"; } - - } + } if (b.final_pass == b.current_pass) { + finalize_regressor(*all, all->final_regressor_name); + set_done(*all); + } }else{//reaching convergence in the previous pass if(b.output_regularizer) @@ -968,7 +970,7 @@ void save_load(bfgs& b, io_buf& model_file, bool read, bool text) b.backstep_on = true; } -learner* setup(vw& all, std::vector&opts, po::variables_map& vm) +learner* setup(vw& all, po::variables_map& vm) { bfgs* b = (bfgs*)calloc_or_die(1,sizeof(bfgs)); b->all = &all; @@ -982,6 +984,16 @@ learner* setup(vw& all, std::vector&opts, po::variables_map& vm) b->no_win_counter = 0; b->early_stop_thres = 3; + po::options_description bfgs_opts("LBFGS options"); + + bfgs_opts.add_options() + ("hessian_on", "use second derivative in line search") + ("mem", po::value(&(all.m)), "memory in bfgs") + ("conjugate_gradient", "use conjugate gradient based optimization") + ("termination", po::value(&(all.rel_threshold)),"Termination threshold"); + + vm = add_options(all, bfgs_opts); + if(!all.holdout_set_off) { all.sd->holdout_best_loss = FLT_MAX; -- cgit v1.2.3