Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/moses-smt/vowpal_wabbit.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'vowpalwabbit/mf.cc')
-rw-r--r--vowpalwabbit/mf.cc14
1 files changed, 6 insertions, 8 deletions
diff --git a/vowpalwabbit/mf.cc b/vowpalwabbit/mf.cc
index 2cce6272..4e00be8d 100644
--- a/vowpalwabbit/mf.cc
+++ b/vowpalwabbit/mf.cc
@@ -43,7 +43,7 @@ struct mf {
};
template <bool cache_sub_predictions>
-void predict(mf& data, learner& base, example& ec) {
+void predict(mf& data, base_learner& base, example& ec) {
float prediction = 0;
if (cache_sub_predictions)
data.sub_predictions.resize(2*data.rank+1, true);
@@ -102,7 +102,7 @@ void predict(mf& data, learner& base, example& ec) {
ec.pred.scalar = GD::finalize_prediction(data.all->sd, ec.partial_prediction);
}
-void learn(mf& data, learner& base, example& ec) {
+void learn(mf& data, base_learner& base, example& ec) {
// predict with current weights
predict<true>(data, base, ec);
float predicted = ec.pred.scalar;
@@ -189,7 +189,7 @@ void finish(mf& o) {
}
-learner* setup(vw& all, po::variables_map& vm) {
+base_learner* setup(vw& all, po::variables_map& vm) {
mf* data = new mf;
// copy global data locally
@@ -203,10 +203,8 @@ learner* setup(vw& all, po::variables_map& vm) {
all.random_positive_weights = true;
- learner* l = new learner(data, all.l, 2*data->rank+1);
- l->set_learn<mf, learn>();
- l->set_predict<mf, predict<false> >();
- l->set_finish<mf,finish>();
- return l;
+ learner<mf>& l = init_learner(data, all.l, learn, predict<false>, 2*data->rank+1);
+ l.set_finish(finish);
+ return make_base(l);
}
}