diff options
Diffstat (limited to 'vowpalwabbit/topk.cc')
-rw-r--r-- | vowpalwabbit/topk.cc | 36 |
1 files changed, 18 insertions, 18 deletions
diff --git a/vowpalwabbit/topk.cc b/vowpalwabbit/topk.cc index 445bdb23..52f03ad9 100644 --- a/vowpalwabbit/topk.cc +++ b/vowpalwabbit/topk.cc @@ -16,15 +16,12 @@ namespace TOPK { struct compare_scored_examples { bool operator()(scored_example const& a, scored_example const& b) const - { - return a.first > b.first; - } + { return a.first > b.first; } }; struct topk{ uint32_t B; //rec number priority_queue<scored_example, vector<scored_example>, compare_scored_examples > pr_queue; - vw* all; }; void print_result(int f, priority_queue<scored_example, vector<scored_example>, compare_scored_examples > &pr_queue) @@ -72,43 +69,46 @@ namespace TOPK { if (example_is_newline(ec)) for (int* sink = all.final_prediction_sink.begin; sink != all.final_prediction_sink.end; sink++) TOPK::print_result(*sink, d.pr_queue); - + print_update(all, ec); } - + template <bool is_learn> void predict_or_learn(topk& d, LEARNER::base_learner& base, example& ec) { if (example_is_newline(ec)) return;//do not predict newline - + if (is_learn) base.learn(ec); else base.predict(ec); - + if(d.pr_queue.size() < d.B) d.pr_queue.push(make_pair(ec.pred.scalar, ec.tag)); - + else if(d.pr_queue.top().first < ec.pred.scalar) - { - d.pr_queue.pop(); - d.pr_queue.push(make_pair(ec.pred.scalar, ec.tag)); - } + { + d.pr_queue.pop(); + d.pr_queue.push(make_pair(ec.pred.scalar, ec.tag)); + } } - + void finish_example(vw& all, topk& d, example& ec) { TOPK::output_example(all, d, ec); VW::finish_example(all, &ec); } - LEARNER::base_learner* setup(vw& all, po::variables_map& vm) + LEARNER::base_learner* setup(vw& all) { + new_options(all, "TOP K options") + ("top", po::value<size_t>(), "top k recommendation"); + if(missing_required(all)) return NULL; + topk& data = calloc_or_die<topk>(); - data.B = (uint32_t)vm["top"].as<size_t>(); - data.all = &all; + data.B = (uint32_t)all.vm["top"].as<size_t>(); - LEARNER::learner<topk>& l = init_learner(&data, all.l, predict_or_learn<true>, + LEARNER::learner<topk>& l = init_learner(&data, setup_base(all), predict_or_learn<true>, predict_or_learn<false>); l.set_finish_example(finish_example); |