diff options
Diffstat (limited to 'vowpalwabbit/mf.cc')
-rw-r--r-- | vowpalwabbit/mf.cc | 14 |
1 files changed, 6 insertions, 8 deletions
diff --git a/vowpalwabbit/mf.cc b/vowpalwabbit/mf.cc index 2cce6272..4e00be8d 100644 --- a/vowpalwabbit/mf.cc +++ b/vowpalwabbit/mf.cc @@ -43,7 +43,7 @@ struct mf { }; template <bool cache_sub_predictions> -void predict(mf& data, learner& base, example& ec) { +void predict(mf& data, base_learner& base, example& ec) { float prediction = 0; if (cache_sub_predictions) data.sub_predictions.resize(2*data.rank+1, true); @@ -102,7 +102,7 @@ void predict(mf& data, learner& base, example& ec) { ec.pred.scalar = GD::finalize_prediction(data.all->sd, ec.partial_prediction); } -void learn(mf& data, learner& base, example& ec) { +void learn(mf& data, base_learner& base, example& ec) { // predict with current weights predict<true>(data, base, ec); float predicted = ec.pred.scalar; @@ -189,7 +189,7 @@ void finish(mf& o) { } -learner* setup(vw& all, po::variables_map& vm) { +base_learner* setup(vw& all, po::variables_map& vm) { mf* data = new mf; // copy global data locally @@ -203,10 +203,8 @@ learner* setup(vw& all, po::variables_map& vm) { all.random_positive_weights = true; - learner* l = new learner(data, all.l, 2*data->rank+1); - l->set_learn<mf, learn>(); - l->set_predict<mf, predict<false> >(); - l->set_finish<mf,finish>(); - return l; + learner<mf>& l = init_learner(data, all.l, learn, predict<false>, 2*data->rank+1); + l.set_finish(finish); + return make_base(l); } } |