Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/moses-smt/vowpal_wabbit.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJake Hofman <jhofman@gmail.com>2014-01-01 01:46:18 +0400
committerJake Hofman <jhofman@gmail.com>2014-01-01 01:46:18 +0400
commit2154710510fb7d678c386fab1dedf084ac494990 (patch)
treed0115f929396067182997f8040bc0e9388b07ef4 /vowpalwabbit/gd_mf.cc
parent93819c84faacdf5f0f187dd71542491a19d33568 (diff)
consolidated gd_mf predict functions
Diffstat (limited to 'vowpalwabbit/gd_mf.cc')
-rw-r--r--vowpalwabbit/gd_mf.cc55
1 files changed, 23 insertions, 32 deletions
diff --git a/vowpalwabbit/gd_mf.cc b/vowpalwabbit/gd_mf.cc
index 5230b8e6..9a4b93a4 100644
--- a/vowpalwabbit/gd_mf.cc
+++ b/vowpalwabbit/gd_mf.cc
@@ -28,9 +28,9 @@ namespace GDMF {
vw* all;
};
-void mf_local_predict(example* ec, regressor& reg);
+void mf_print_audit_features(vw& all, example* ec, size_t offset);
-float mf_inline_predict(vw& all, example* &ec)
+float mf_predict(vw& all, example* ec)
{
float prediction = all.p->lp->get_initial(ec->ld);
@@ -79,13 +79,30 @@ float mf_inline_predict(vw& all, example* &ec)
cerr << "cannot use triples in matrix factorization" << endl;
throw exception();
}
-
+
// ec->topic_predictions has linear, x_dot_l_1, x_dot_r_1, x_dot_l_2, x_dot_r_2, ...
- return prediction;
+ ec->partial_prediction = prediction;
+
+ // finalize prediction and compute loss
+ label_data* ld = (label_data*)ec->ld;
+ all.set_minmax(all.sd, ld->label);
+
+ ec->final_prediction = GD::finalize_prediction(all, ec->partial_prediction);
+
+ if (ld->label != FLT_MAX)
+ {
+ ec->loss = all.loss->getLoss(all.sd, ec->final_prediction, ld->label) * ld->weight;
+ }
+
+ if (all.audit)
+ mf_print_audit_features(all, ec, 0);
+
+ return ec->final_prediction;
}
-void mf_inline_train(vw& all, example* &ec, float update)
+
+void mf_train(vw& all, example* &ec, float update)
{
weight* weights = all.reg.weight_vector;
size_t mask = all.reg.weight_mask;
@@ -190,32 +207,6 @@ void mf_print_audit_features(vw& all, example* ec, size_t offset)
mf_print_offset_features(all, ec, offset);
}
-void mf_local_predict(vw& all, example* ec)
-{
- label_data* ld = (label_data*)ec->ld;
- all.set_minmax(all.sd, ld->label);
-
- ec->final_prediction = GD::finalize_prediction(all, ec->partial_prediction);
-
- if (ld->label != FLT_MAX)
- {
- ec->loss = all.loss->getLoss(all.sd, ec->final_prediction, ld->label) * ld->weight;
- }
-
- if (all.audit)
- mf_print_audit_features(all, ec, 0);
-}
-
-float mf_predict(vw& all, example* ex)
-{
- float prediction = mf_inline_predict(all, ex);
-
- ex->partial_prediction = prediction;
- mf_local_predict(all, ex);
-
- return ex->final_prediction;
-}
-
void save_load(void* d, io_buf& model_file, bool read, bool text)
{
vw* all = ((gdmf*)d)->all;
@@ -287,7 +278,7 @@ void end_pass(void* d)
mf_predict(*all,ec);
if (all->training && ((label_data*)(ec->ld))->label != FLT_MAX)
- mf_inline_train(*all, ec, ec->eta_round);
+ mf_train(*all, ec, ec->eta_round);
}
// placeholder