Welcome to mirror list, hosted at ThFree Co, Russian Federation.

topk.cc « vowpalwabbit - github.com/moses-smt/vowpal_wabbit.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: d3414bcd1269c2c5ca014e4c93955f0b9f735e1e (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
/*
Copyright (c) by respective owners including Yahoo!, Microsoft, and
individual contributors. All rights reserved.  Released under a BSD (revised)
license as described in the file LICENSE.
 */
#include <float.h>
#include <sstream>
#include <queue>

#include "reductions.h"
#include "vw.h"

namespace TOPK {
  // One queue entry: a prediction score (first) paired with the tag of the
  // example that produced it (second).
  typedef pair<float, v_array<char> > scored_example;
  
  // Ordering functor for the top-k priority queue.
  //
  // Reports a as "lower priority" than b when a's score is the larger one.
  // Because std::priority_queue exposes the highest-priority element, this
  // inverted ordering makes top() return the entry with the SMALLEST score.
  struct compare_scored_examples
  {
    bool operator()(scored_example const& a, scored_example const& b) const
    {
      return b.first < a.first;
    }
  };
  
  // State carried by the top-k reduction.
  struct topk{
    uint32_t B; // maximum number of entries kept in pr_queue (the "k"; read from the "top" option)
    // Best-scoring examples seen so far. With compare_scored_examples the
    // smallest retained score sits at top(), so it is the one to evict.
    priority_queue<scored_example, vector<scored_example>, compare_scored_examples > pr_queue;
    vw* all; // the vw instance this reduction was set up with
  };

  /*
   * Drains pr_queue and writes its entries to file descriptor f.
   *
   * Nothing happens (and the queue is left untouched) when f is negative.
   * Otherwise every entry is popped and emitted as
   *   "<score> " + whatever print_tag() writes for the tag + " \n"
   * with the score in fixed notation, six decimals (same text as printf "%f").
   * Entries come out in pop order, i.e. smallest score first. A final empty
   * line terminates the group. On return the queue is empty.
   *
   * A short or failed write is reported on cerr; it is not retried.
   */
  void print_result(int f, priority_queue<scored_example, vector<scored_example>, compare_scored_examples > &pr_queue)
  {
    if (f < 0)
      return;

    std::stringstream ss;
    ss.precision(6); // the "%f" default; spelled out so the format is explicit

    while (!pr_queue.empty())
    {
      // Copy before pop(): top() returns a reference into the queue.
      scored_example tmp_example = pr_queue.top();
      pr_queue.pop();

      // Format the score through the stream rather than sprintf() into a
      // fixed char[30]: "%f" of a large-magnitude float (up to FLT_MAX) needs
      // as many as 47 characters and would overrun such a buffer. The flags
      // are restored so print_tag() sees the stream's default formatting.
      const std::ios_base::fmtflags saved_flags = ss.flags();
      ss.setf(std::ios_base::fixed, std::ios_base::floatfield);
      ss << tmp_example.first;
      ss.flags(saved_flags);

      ss << ' ';
      print_tag(ss, tmp_example.second);
      ss << ' ';
      ss << '\n';
    }
    ss << '\n';

    // Materialise the text once; str() returns a fresh copy on every call.
    const std::string text = ss.str();
    const ssize_t len = (ssize_t)text.size();
#ifdef _WIN32
    const ssize_t t = _write(f, text.c_str(), (unsigned int)len);
#else
    const ssize_t t = write(f, text.c_str(), (unsigned int)len);
#endif
    if (t != len)
      cerr << "write error" << endl;
  }

  /*
   * Per-example bookkeeping and output for the top-k reduction.
   *
   * Adds the example's label/weight/loss/feature counts to the shared
   * statistics in all.sd. When ec is a newline example, the accumulated
   * top-k queue is written to every entry of all.final_prediction_sink and
   * then reset, so the next group starts empty. Finally print_update() is
   * called for the example.
   */
  void output_example(vw& all, topk& d, example& ec)
  {
    label_data& ld = ec.l.simple;

    // NOTE(review): FLT_MAX presumably marks an unlabeled example - confirm.
    if (ld.label != FLT_MAX)
      all.sd->weighted_labels += ld.label * ld.weight;
    all.sd->weighted_examples += ld.weight;
    all.sd->sum_loss += ec.loss;
    all.sd->sum_loss_since_last_dump += ec.loss;
    all.sd->total_features += ec.num_features;
    all.sd->example_number++;

    if (example_is_newline(ec))
    {
      for (int* sink = all.final_prediction_sink.begin; sink != all.final_prediction_sink.end; sink++)
      {
        // print_result() pops everything from the queue it is given. Handing
        // it the live queue meant only the first sink received the ranking and
        // every later sink got an empty group, so each sink gets its own copy.
        priority_queue<scored_example, vector<scored_example>, compare_scored_examples > snapshot(d.pr_queue);
        TOPK::print_result(*sink, snapshot);
      }

      // The group is finished: drop its entries before the next one begins.
      d.pr_queue = priority_queue<scored_example, vector<scored_example>, compare_scored_examples >();
    }

    print_update(all, ec);
  }

  /*
   * Runs the base learner on ec (learn when is_learn, predict otherwise) and
   * offers the resulting score to the top-k queue.
   *
   * Newline examples are skipped entirely. While the queue holds fewer than
   * d.B entries the example is always added; once full, it replaces the
   * current minimum only when its score is strictly greater.
   */
  template <bool is_learn>
  void predict_or_learn(topk& d, LEARNER::base_learner& base, example& ec)
  {
    if (example_is_newline(ec)) return;//do not predict newline

    if (is_learn)
      base.learn(ec);
    else
      base.predict(ec);

    // With a capacity of zero there is nothing to keep. Without this guard
    // the "queue is full" branch below would call top() on an empty queue,
    // which is undefined behaviour.
    if (d.B == 0)
      return;

    // NOTE(review): the pair stores a copy of the ec.tag v_array object;
    // whether the tag bytes remain valid until print_result() runs depends on
    // v_array copy semantics and on how examples are recycled - confirm.
    if (d.pr_queue.size() < d.B)
      d.pr_queue.push(make_pair(ec.pred.scalar, ec.tag));
    else if (d.pr_queue.top().first < ec.pred.scalar)
    {
      // Full: evict the smallest retained score to make room for this one.
      d.pr_queue.pop();
      d.pr_queue.push(make_pair(ec.pred.scalar, ec.tag));
    }
  }

  // finish_example hook for the top-k reduction (installed in setup()):
  // first does this reduction's own accounting/output for ec, then passes
  // the example on to VW::finish_example.
  void finish_example(vw& all, topk& d, example& ec)
  {
    TOPK::output_example(all, d, ec);

    example* finished = &ec;
    VW::finish_example(all, finished);
  }

  /*
   * Registers the "top" option and, when it is present, builds the top-k
   * reduction on top of the base learner.
   *
   * Returns NULL when missing_required(all) reports true; otherwise returns
   * the new learner wrapped by make_base().
   */
  LEARNER::base_learner* setup(vw& all)
  {
    new_options(all, "TOP K options")
      ("top", po::value<size_t>(), "top k recommendation");

    if (missing_required(all))
      return NULL;

    // Reduction state; its capacity comes from the "top" option, narrowed to 32 bits.
    topk& state = calloc_or_die<topk>();
    state.all = &all;
    state.B = (uint32_t)all.vm["top"].as<size_t>();

    LEARNER::learner<topk>& reduction =
      init_learner(&state, setup_base(all), predict_or_learn<true>, predict_or_learn<false>);
    reduction.set_finish_example(finish_example);

    return make_base(reduction);
  }
}