Welcome to mirror list, hosted at ThFree Co, Russian Federation.

oaa.cc « vowpalwabbit - github.com/moses-smt/vowpal_wabbit.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: 6063ed478c7831435349a3bb9c414243a086aca1 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
/*
Copyright (c) by respective owners including Yahoo!, Microsoft, and
individual contributors. All rights reserved.  Released under a BSD (revised)
license as described in the file LICENSE.
 */
#include <limits>
#include <sstream>
#include "multiclass.h"
#include "simple_label.h"
#include "reductions.h"
#include "vw.h"

namespace OAA {
  // State carried by the one-against-all reduction; populated in setup().
  struct oaa{
    size_t k;           // number of classes, taken from the --oaa option
    bool shouldOutput;  // set when all.raw_prediction > 0; gates emitting per-class scores
    vw* all;            // the vw instance passed to setup(); used to print raw predictions
  };

  // One-against-all step for a single example.
  //
  // For every class i in 1..o.k the base learner is invoked at offset i-1.
  // When is_learn is true it is trained on a binary target: +1 if i equals
  // the example's multiclass label, -1 otherwise. When is_learn is false it
  // only predicts. The class with the strictly largest ec.partial_prediction
  // is stored in ec.pred.multiclass; on ties the lowest class index wins, and
  // class 1 is reported if no score beats the initial floor.
  //
  // ec.l.simple is overwritten while the base learner runs, so the multiclass
  // label is copied on entry and written back to ec.l.multi before returning.
  //
  // If o.shouldOutput is set, a line of space-separated "class:score" pairs
  // is passed to o.all->print_text for o.all->raw_prediction.
  template <bool is_learn>
  void predict_or_learn(oaa& o, LEARNER::base_learner& base, example& ec) {
    MULTICLASS::multiclass mc_label_data = ec.l.multi;

    // Label (uint32_t)-1 is exempt from the range warning
    // (presumably it marks an unlabeled example -- confirm against the label parser).
    const bool label_is_zero = (mc_label_data.label == 0);
    const bool label_too_large =
      (mc_label_data.label > o.k) && (mc_label_data.label != (uint32_t)-1);
    if (label_is_zero || label_too_large)
      cout << "label " << mc_label_data.label << " is not in {1,"<< o.k << "} This won't work right." << endl;

    // Binary label handed to the base learner; only .label changes per class.
    ec.l.simple = {0.f, mc_label_data.weight, 0.f};

    stringstream outputStringStream;

    uint32_t prediction = 1;
    // Start from the lowest finite float so any finite score can win.
    // (The previous INT_MIN sentinel was an integer limit and made scores
    // at or below -2^31 impossible to select.)
    float score = std::numeric_limits<float>::lowest();

    for (uint32_t i = 1; i <= o.k; i++)
      {
        if (is_learn)
          {
            ec.l.simple.label = (mc_label_data.label == i) ? 1.f : -1.f;
            base.learn(ec, i-1);
          }
        else
          base.predict(ec, i-1);

        // Strict comparison keeps the earliest class on ties.
        if (ec.partial_prediction > score)
          {
            score = ec.partial_prediction;
            prediction = i;
          }

        if (o.shouldOutput)
          {
            if (i > 1) outputStringStream << ' ';
            outputStringStream << i << ':' << ec.partial_prediction;
          }
      }

    ec.pred.multiclass = prediction;
    ec.l.multi = mc_label_data;  // restore the caller's multiclass label

    if (o.shouldOutput)
      o.all->print_text(o.all->raw_prediction, outputStringStream.str(), ec.tag);
  }
  
  // Example-completion hook registered by setup(). The oaa state is not
  // needed here; the work is delegated to MULTICLASS::finish_example.
  void finish_example(vw& all, oaa&, example& ec)
  {
    MULTICLASS::finish_example(all, ec);
  }

  /*{
  new_options(all, "One-against-all options")
    ("oaa", po::value<size_t>(), "Use one-against-all multiclass learning with <k> labels");
  if (missing_required(all,vm)) return NULL;
  options(all)
    ...;
  add_options(all)
  }*/

  // Creates the one-against-all reduction.
  //
  // Declares the "oaa" option and returns NULL when missing_required(all)
  // reports true. Otherwise allocates the oaa state, records " --oaa <k>" in
  // all.file_options, installs MULTICLASS::mc_label in all.p->lp, and wraps
  // the learner returned by setup_base(all) in a learner that calls
  // predict_or_learn<true>/<false>, with data.k passed as the last
  // init_learner argument.
  LEARNER::base_learner* setup(vw& all)
  {
    new_options(all, "One-against-all options")
      ("oaa", po::value<size_t>(), "Use one-against-all multiclass learning with <k> labels");
    if (missing_required(all))
      return NULL;

    oaa& state = calloc_or_die<oaa>();
    state.all = &all;
    state.k = all.vm["oaa"].as<size_t>();
    state.shouldOutput = (all.raw_prediction > 0);

    // Both side effects must happen before setup_base(all) runs below.
    *all.file_options << " --oaa " << state.k;
    all.p->lp = MULTICLASS::mc_label;

    LEARNER::learner<oaa>& oaa_learner =
      init_learner(&state, setup_base(all), predict_or_learn<true>,
                   predict_or_learn<false>, state.k);
    oaa_learner.set_finish_example(finish_example);
    return make_base(oaa_learner);
  }
}