diff options
author | John Langford <jl@hunch.net> | 2015-01-03 02:10:15 +0300 |
---|---|---|
committer | John Langford <jl@hunch.net> | 2015-01-03 02:10:15 +0300 |
commit | d82ee88ac358fd71d3f0fb09b14ae441e8c4d867 (patch) | |
tree | a84340810ff9ae1c55ec9478abe2e39d37b550d1 | |
parent | 21d8e591d754487485cd28a66cf1034123b5302f (diff) |
topk simplifications
-rw-r--r-- | vowpalwabbit/topk.cc | 23 |
1 files changed, 10 insertions, 13 deletions
diff --git a/vowpalwabbit/topk.cc b/vowpalwabbit/topk.cc index d3414bcd..234c72f2 100644 --- a/vowpalwabbit/topk.cc +++ b/vowpalwabbit/topk.cc @@ -3,7 +3,6 @@ Copyright (c) by respective owners including Yahoo!, Microsoft, and individual contributors. All rights reserved. Released under a BSD (revised) license as described in the file LICENSE. */ -#include <float.h> #include <sstream> #include <queue> @@ -22,7 +21,6 @@ namespace TOPK { struct topk{ uint32_t B; //rec number priority_queue<scored_example, vector<scored_example>, compare_scored_examples > pr_queue; - vw* all; }; void print_result(int f, priority_queue<scored_example, vector<scored_example>, compare_scored_examples > &pr_queue) @@ -70,30 +68,30 @@ namespace TOPK { if (example_is_newline(ec)) for (int* sink = all.final_prediction_sink.begin; sink != all.final_prediction_sink.end; sink++) TOPK::print_result(*sink, d.pr_queue); - + print_update(all, ec); } - + template <bool is_learn> void predict_or_learn(topk& d, LEARNER::base_learner& base, example& ec) { if (example_is_newline(ec)) return;//do not predict newline - + if (is_learn) base.learn(ec); else base.predict(ec); - + if(d.pr_queue.size() < d.B) d.pr_queue.push(make_pair(ec.pred.scalar, ec.tag)); - + else if(d.pr_queue.top().first < ec.pred.scalar) - { - d.pr_queue.pop(); - d.pr_queue.push(make_pair(ec.pred.scalar, ec.tag)); - } + { + d.pr_queue.pop(); + d.pr_queue.push(make_pair(ec.pred.scalar, ec.tag)); + } } - + void finish_example(vw& all, topk& d, example& ec) { TOPK::output_example(all, d, ec); @@ -108,7 +106,6 @@ namespace TOPK { topk& data = calloc_or_die<topk>(); data.B = (uint32_t)all.vm["top"].as<size_t>(); - data.all = &all; LEARNER::learner<topk>& l = init_learner(&data, setup_base(all), predict_or_learn<true>, predict_or_learn<false>); |