Welcome to mirror list, hosted at ThFree Co, Russian Federation.

oaa.cc « vowpalwabbit - github.com/moses-smt/vowpal_wabbit.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: 4c905e13bd2f5f6156466f7d2c904095ad26c973 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
/*
Copyright (c) by respective owners including Yahoo!, Microsoft, and
individual contributors. All rights reserved.  Released under a BSD (revised)
license as described in the file LICENSE.
 */
#include <limits>
#include <sstream>
#include "multiclass.h"
#include "simple_label.h"
#include "reductions.h"
#include "vw.h"

using namespace std;
using namespace LEARNER;

namespace OAA {
  /// Per-reduction state for the one-against-all (--oaa) multiclass reduction.
  struct oaa{
    size_t k;            // number of classes; labels outside 1..k trigger a warning
    bool shouldOutput;   // set when all.raw_prediction > 0: emit raw per-class scores
    vw* all;             // back-pointer used to reach print_text / raw_prediction
  };

  /// One-against-all predict/learn step.
  ///
  /// Calls the base learner once per class i = 1..k, at base offset i-1.
  /// When is_learn is true, each call is base.learn() with a binary simple
  /// label: +1 if i is the example's multiclass label, -1 otherwise. When it
  /// is false, each call is base.predict(). The class with the largest
  /// ec.partial_prediction is stored in ec.pred.multiclass. The original
  /// multiclass label is restored into ec.l.multi before returning.
  /// If raw predictions are enabled, "i:score" pairs for every class are
  /// written to all.raw_prediction.
  template <bool is_learn>
  void predict_or_learn(oaa& o, base_learner& base, example& ec) {
    // ec.l is reused below for the binary simple label, so keep a copy of
    // the multiclass label to restore at the end.
    MULTICLASS::multiclass mc_label_data = ec.l.multi;
    // (uint32_t)-1 is tolerated in addition to 1..k -- presumably the
    // "no label" marker of multiclass.h; NOTE(review): confirm.
    if (mc_label_data.label == 0 || (mc_label_data.label > o.k && mc_label_data.label != (uint32_t)-1))
      cout << "label " << mc_label_data.label << " is not in {1,"<< o.k << "} This won't work right." << endl;

    ec.l.simple = {0.f, mc_label_data.weight, 0.f};

    stringstream outputStringStream;

    uint32_t prediction = 1;
    // Seed below every finite float so any finite score can become the
    // argmax. (The former INT_MIN seed, about -2.1e9, left class 1 selected
    // whenever all scores were more negative than that.)
    float score = std::numeric_limits<float>::lowest();
    for (uint32_t i = 1; i <= o.k; i++)
      {
        if (is_learn)
          {
            // Binary target for the i-th classifier: "is it class i?"
            ec.l.simple.label = (mc_label_data.label == i) ? 1.f : -1.f;
            base.learn(ec, i-1);
          }
        else
          base.predict(ec, i-1);

        // Strict '>' keeps the lowest-numbered class on ties.
        if (ec.partial_prediction > score)
          {
            score = ec.partial_prediction;
            prediction = i;
          }

        if (o.shouldOutput) {
          if (i > 1) outputStringStream << ' ';
          outputStringStream << i << ':' << ec.partial_prediction;
        }
      }
    ec.pred.multiclass = prediction;
    ec.l.multi = mc_label_data;

    if (o.shouldOutput)
      o.all->print_text(o.all->raw_prediction, outputStringStream.str(), ec.tag);
  }

  /// Reports the example via MULTICLASS::output_example, then passes it to
  /// VW::finish_example.
  void finish_example(vw& all, oaa&, example& ec)
  {
    MULTICLASS::output_example(all, ec);
    VW::finish_example(all, &ec);
  }

  /// Builds the --oaa reduction on top of all.l.
  ///
  /// Reads the class count k from the "oaa" option and enables raw per-class
  /// output when all.raw_prediction > 0. It records " --oaa k" in
  /// all.file_options and switches the parser to MULTICLASS::mc_label. It
  /// registers predict_or_learn<true>/<false>, passing data.k as the last
  /// init_learner argument (presumably the number of base-learner offsets;
  /// verify in learner.h).
  /// NOTE(review): data is allocated with calloc_or_die and handed to
  /// init_learner. Ownership presumably passes to the learner; confirm.
  base_learner* setup(vw& all, po::variables_map& vm)
  {
    oaa& data = calloc_or_die<oaa>();
    data.k = vm["oaa"].as<size_t>();
    data.shouldOutput = all.raw_prediction > 0;
    data.all = &all;

    *all.file_options << " --oaa " << data.k;
    all.p->lp = MULTICLASS::mc_label;

    learner<oaa>& l = init_learner(&data, all.l, predict_or_learn<true>,
                                   predict_or_learn<false>, data.k);
    l.set_finish_example(finish_example);

    return make_base(l);
  }
}