Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/moses-smt/vowpal_wabbit.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'vowpalwabbit/topk.cc')
-rw-r--r--vowpalwabbit/topk.cc36
1 files changed, 18 insertions, 18 deletions
diff --git a/vowpalwabbit/topk.cc b/vowpalwabbit/topk.cc
index 445bdb23..52f03ad9 100644
--- a/vowpalwabbit/topk.cc
+++ b/vowpalwabbit/topk.cc
@@ -16,15 +16,12 @@ namespace TOPK {
struct compare_scored_examples
{
bool operator()(scored_example const& a, scored_example const& b) const
- {
- return a.first > b.first;
- }
+ { return a.first > b.first; }
};
struct topk{
uint32_t B; //rec number
priority_queue<scored_example, vector<scored_example>, compare_scored_examples > pr_queue;
- vw* all;
};
void print_result(int f, priority_queue<scored_example, vector<scored_example>, compare_scored_examples > &pr_queue)
@@ -72,43 +69,46 @@ namespace TOPK {
if (example_is_newline(ec))
for (int* sink = all.final_prediction_sink.begin; sink != all.final_prediction_sink.end; sink++)
TOPK::print_result(*sink, d.pr_queue);
-
+
print_update(all, ec);
}
-
+
template <bool is_learn>
void predict_or_learn(topk& d, LEARNER::base_learner& base, example& ec)
{
if (example_is_newline(ec)) return;//do not predict newline
-
+
if (is_learn)
base.learn(ec);
else
base.predict(ec);
-
+
if(d.pr_queue.size() < d.B)
d.pr_queue.push(make_pair(ec.pred.scalar, ec.tag));
-
+
else if(d.pr_queue.top().first < ec.pred.scalar)
- {
- d.pr_queue.pop();
- d.pr_queue.push(make_pair(ec.pred.scalar, ec.tag));
- }
+ {
+ d.pr_queue.pop();
+ d.pr_queue.push(make_pair(ec.pred.scalar, ec.tag));
+ }
}
-
+
void finish_example(vw& all, topk& d, example& ec)
{
TOPK::output_example(all, d, ec);
VW::finish_example(all, &ec);
}
- LEARNER::base_learner* setup(vw& all, po::variables_map& vm)
+ LEARNER::base_learner* setup(vw& all)
{
+ new_options(all, "TOP K options")
+ ("top", po::value<size_t>(), "top k recommendation");
+ if(missing_required(all)) return NULL;
+
topk& data = calloc_or_die<topk>();
- data.B = (uint32_t)vm["top"].as<size_t>();
- data.all = &all;
+ data.B = (uint32_t)all.vm["top"].as<size_t>();
- LEARNER::learner<topk>& l = init_learner(&data, all.l, predict_or_learn<true>,
+ LEARNER::learner<topk>& l = init_learner(&data, setup_base(all), predict_or_learn<true>,
predict_or_learn<false>);
l.set_finish_example(finish_example);