
github.com/moses-smt/vowpal_wabbit.git
author     Jake Hofman <jhofman@gmail>    2014-01-08 05:26:07 +0400
committer  Jake Hofman <jhofman@gmail>    2014-01-08 05:26:07 +0400
commit     f4f64eb8fbc175ceab134d95b9291d9ae7a5dc7c (patch)
tree       7413ca4108b635881c4b50eefe83bb51fad96848 /vowpalwabbit/mf.cc
parent     de91f21a827ae6e53e239a26edcc53f632b49c4f (diff)
added --new_mf flag to use reductionized matrix factorization
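For context, the reduction changed below assembles the matrix-factorization prediction from 2*rank+1 base-learner sub-predictions: one linear term plus, for each k, a left dot product l^k * x_l and a right dot product r^k * x_r. A minimal sketch of that accumulation follows, using illustrative names rather than VW's actual types (mf_prediction and its arguments are assumptions for exposition; the index layout matches the sub_predictions indexing in the diff):

    // Sketch only: how the final prediction is assembled from the cached
    // sub-predictions (index 0 = linear term, 2k-1 = l^k * x_l, 2k = r^k * x_r).
    #include <cstddef>
    #include <vector>

    float mf_prediction(const std::vector<float>& sub, std::size_t rank) {
      float prediction = sub[0];
      for (std::size_t k = 1; k <= rank; k++)
        prediction += sub[2*k - 1] * sub[2*k];
      return prediction;
    }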
Diffstat (limited to 'vowpalwabbit/mf.cc')
-rw-r--r--  vowpalwabbit/mf.cc  55
1 file changed, 27 insertions(+), 28 deletions(-)
diff --git a/vowpalwabbit/mf.cc b/vowpalwabbit/mf.cc
index da87b812..89c11447 100644
--- a/vowpalwabbit/mf.cc
+++ b/vowpalwabbit/mf.cc
@@ -50,18 +50,21 @@ struct mf {
   vw* all;
 };
+template <bool cache_sub_predictions>
 void predict(void *d, learner& base, example* ec) {
   mf* data = (mf*) d;
   vw* all = data->all;
   float prediction = 0;
-  data->sub_predictions.resize(2*all->rank+1, true);
+  if (cache_sub_predictions)
+    data->sub_predictions.resize(2*all->rank+1, true);
   // predict from linear terms
   base.predict(ec);
   // store linear prediction
-  data->sub_predictions[0] = ec->partial_prediction;
+  if (cache_sub_predictions)
+    data->sub_predictions[0] = ec->partial_prediction;
   prediction += ec->partial_prediction;
   // store namespace indices
@@ -70,30 +73,31 @@ void predict(void *d, learner& base, example* ec) {
   // add interaction terms to prediction
   for (vector<string>::iterator i = data->pairs.begin(); i != data->pairs.end(); i++) {
     if (ec->atomics[(int) (*i)[0]].size() > 0 && ec->atomics[(int) (*i)[1]].size() > 0) {
+      for (size_t k = 1; k <= all->rank; k++) {
-      // set example to left namespace only
-      ec->indices.erase();
-      ec->indices.push_back((int) (*i)[0]);
+        // set example to left namespace only
+        ec->indices.erase();
+        ec->indices.push_back((int) (*i)[0]);
-      for (size_t k = 1; k <= all->rank; k++) {
         // compute l^k * x_l using base learner
         base.predict(ec, k);
-        data->sub_predictions[2*k-1] = ec->partial_prediction;
-      }
+        float x_dot_l = ec->partial_prediction;
+        if (cache_sub_predictions)
+          data->sub_predictions[2*k-1] = x_dot_l;
-      // set example to right namespace only
-      ec->indices.erase();
-      ec->indices.push_back((int) (*i)[1]);
+        // set example to right namespace only
+        ec->indices.erase();
+        ec->indices.push_back((int) (*i)[1]);
-      for (size_t k = 1; k <= all->rank; k++) {
         // compute r^k * x_r using base learner
         base.predict(ec, k + all->rank);
-        data->sub_predictions[2*k] = ec->partial_prediction;
-      }
+        float x_dot_r = ec->partial_prediction;
+        if (cache_sub_predictions)
+          data->sub_predictions[2*k] = x_dot_r;
-      // accumulate prediction
-      for (size_t k = 1; k <= all->rank; k++)
-        prediction += (data->sub_predictions[2*k-1] * data->sub_predictions[2*k]);
+        // accumulate prediction
+        prediction += (x_dot_l * x_dot_r);
+      }
     }
   }
   // restore namespace indices and label
@@ -109,13 +113,10 @@ void learn(void* d, learner& base, example* ec) {
   vw* all = data->all;
   // predict with current weights
-  predict(d, base, ec);
-
-  // force base learner to use precomputed prediction
-  ec->precomputed_prediction = true;
+  predict<true>(d, base, ec);
   // update linear weights
-  base.learn(ec);
+  base.update(ec);
   // store namespace indices
   copy_array(data->indices, ec->indices);
@@ -139,7 +140,7 @@ void learn(void* d, learner& base, example* ec) {
           f->x *= data->sub_predictions[2*k];
         // update l^k using base learner
-        base.learn(ec, k);
+        base.update(ec, k);
         // restore left namespace features (undoing multiply)
         copy_array(ec->atomics[(int) (*i)[0]], data->temp_features);
@@ -160,17 +161,15 @@ void learn(void* d, learner& base, example* ec) {
           f->x *= data->sub_predictions[2*k-1];
         // update r^k using base learner
-        base.learn(ec, k + all->rank);
+        base.update(ec, k + all->rank);
         // restore right namespace features
         copy_array(ec->atomics[(int) (*i)[1]], data->temp_features);
       }
     }
   }
-  // restore namespace indices and unset precomputed prediction
+  // restore namespace indices
   copy_array(ec->indices, data->indices);
-
-  ec->precomputed_prediction = false;
 }
 void finish(void* data) {
@@ -202,7 +201,7 @@ learner* setup(vw& all, po::variables_map& vm) {
     for (size_t j = 0; j < (all.reg.weight_mask + 1) / all.reg.stride; j++)
       all.reg.weight_vector[j*all.reg.stride] = (float) (0.1 * frand48());
   }
-  learner* l = new learner(data, learn, predict, all.l, 2*data->rank+1);
+  learner* l = new learner(data, learn, predict<false>, all.l, 2*data->rank+1);
   l->set_finish(finish);
   return l;
 }
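
The core mechanism of the patch is selecting the caching behaviour at compile time: predict is templated on cache_sub_predictions, learn() calls predict<true> so the sub-predictions needed for the multiplicative updates are stored, and setup() registers predict<false> so prediction-only passes skip those writes. A stripped-down sketch of that bool-template dispatch pattern follows, with illustrative names that are not part of VW's learner interface:

    // Sketch of dispatch on a bool template parameter, mirroring
    // predict<true>/predict<false> above (names here are illustrative only).
    #include <cstdio>
    #include <vector>

    std::vector<float> sub_predictions;  // stands in for mf::sub_predictions

    template <bool cache_sub_predictions>
    float predict_sketch(float linear, float x_dot_l, float x_dot_r) {
      if (cache_sub_predictions)                 // branch resolved at compile time
        sub_predictions = {linear, x_dot_l, x_dot_r};
      return linear + x_dot_l * x_dot_r;
    }

    int main() {
      float p_learn = predict_sketch<true>(0.5f, 0.2f, 0.3f);   // caching path, as in learn()
      float p_only  = predict_sketch<false>(0.5f, 0.2f, 0.3f);  // no caching, as registered in setup()
      std::printf("%f %f %zu\n", p_learn, p_only, sub_predictions.size());
      return 0;
    }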