Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/moses-smt/vowpal_wabbit.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJake Hofman <jhofman@gmail.com>2014-01-07 02:01:15 +0400
committerJake Hofman <jhofman@gmail.com>2014-01-07 02:01:15 +0400
commitad6ac6854cb482bc4ab27f1cdc4a5276b7bf27b8 (patch)
tree2b8267ee7c7a125c64f51038eac027cc3628354c /vowpalwabbit/mf.cc
parentef823fba0220b7693f21b17f71261db1bcccf431 (diff)
reductions now have predict functions, but tests break
Diffstat (limited to 'vowpalwabbit/mf.cc')
-rw-r--r--vowpalwabbit/mf.cc22
1 files changed, 8 insertions, 14 deletions
diff --git a/vowpalwabbit/mf.cc b/vowpalwabbit/mf.cc
index a7972726..da87b812 100644
--- a/vowpalwabbit/mf.cc
+++ b/vowpalwabbit/mf.cc
@@ -50,17 +50,15 @@ struct mf {
vw* all;
};
-void inline_predict(mf* data, vw* all, learner& base, example* &ec) {
+void predict(void *d, learner& base, example* ec) {
+ mf* data = (mf*) d;
+ vw* all = data->all;
float prediction = 0;
data->sub_predictions.resize(2*all->rank+1, true);
- // set weight to 0 to indicate test example (predict only)
- float weight = ((label_data*) ec->ld)->weight;
- ((label_data*) ec->ld)->weight = 0;
-
// predict from linear terms
- base.learn(ec);
+ base.predict(ec);
// store linear prediction
data->sub_predictions[0] = ec->partial_prediction;
@@ -79,7 +77,7 @@ void inline_predict(mf* data, vw* all, learner& base, example* &ec) {
for (size_t k = 1; k <= all->rank; k++) {
// compute l^k * x_l using base learner
- base.learn(ec, k);
+ base.predict(ec, k);
data->sub_predictions[2*k-1] = ec->partial_prediction;
}
@@ -89,7 +87,7 @@ void inline_predict(mf* data, vw* all, learner& base, example* &ec) {
for (size_t k = 1; k <= all->rank; k++) {
// compute r^k * x_r using base learner
- base.learn(ec, k + all->rank);
+ base.predict(ec, k + all->rank);
data->sub_predictions[2*k] = ec->partial_prediction;
}
@@ -101,21 +99,17 @@ void inline_predict(mf* data, vw* all, learner& base, example* &ec) {
// restore namespace indices and label
copy_array(ec->indices, data->indices);
- ((label_data*) ec->ld)->weight = weight;
-
// finalize prediction
ec->partial_prediction = prediction;
ec->final_prediction = GD::finalize_prediction(*(data->all), ec->partial_prediction);
-
}
-
void learn(void* d, learner& base, example* ec) {
mf* data = (mf*) d;
vw* all = data->all;
// predict with current weights
- inline_predict(data, all, base, ec);
+ predict(d, base, ec);
// force base learner to use precomputed prediction
ec->precomputed_prediction = true;
@@ -208,7 +202,7 @@ learner* setup(vw& all, po::variables_map& vm) {
for (size_t j = 0; j < (all.reg.weight_mask + 1) / all.reg.stride; j++)
all.reg.weight_vector[j*all.reg.stride] = (float) (0.1 * frand48());
}
- learner* l = new learner(data, learn, all.l, 2*data->rank+1);
+ learner* l = new learner(data, learn, predict, all.l, 2*data->rank+1);
l->set_finish(finish);
return l;
}