1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
|
/*
Copyright (c) by respective owners including Yahoo!, Microsoft, and
individual contributors. All rights reserved. Released under a BSD (revised)
license as described in the file LICENSE.
*/
#include <iostream>
#include <limits>
#include <sstream>
#include "multiclass.h"
#include "simple_label.h"
#include "reductions.h"
namespace OAA {
// Per-reduction state for one-against-all multiclass learning.
// setup() obtains the instance from calloc_or_die<oaa>() and fills the fields
// by hand, so this is kept as a plain aggregate with no constructors or
// default member initialisers.
struct oaa
{
  // Number of classes, read from the "--oaa" option; valid labels are 1..k.
  size_t k;

  // When true, every per-class score is also written out through
  // all->print_text(all->raw_prediction, ...).
  bool shouldOutput;

  // The vw instance handed to setup(); used to reach raw_prediction and
  // print_text when emitting the per-class scores.
  vw* all;
};
/*
 * One-against-all step for a single example.
 *
 * For every class i in 1..o.k the example is handed to sub-learner i-1 of
 * `base`:
 *   - is_learn == true : base.learn() is called with a binary simple label,
 *                        +1 when i equals the example's multiclass label and
 *                        -1 otherwise;
 *   - is_learn == false: base.predict() is called.
 * After each call ec.partial_prediction is read as the score of class i. The
 * class with the strictly largest score is written to ec.pred.multiclass
 * (ties therefore keep the lowest class index; class 1 is the fallback when
 * no score compares greater than the initial sentinel).
 *
 * ec.l.simple is overwritten while the sub-learners run; the original
 * multiclass label is written back to ec.l.multi before returning.
 *
 * When o.shouldOutput is set, the scores are additionally emitted as
 * "1:<s1> 2:<s2> ..." through o.all->print_text on o.all->raw_prediction,
 * tagged with ec.tag.
 *
 * A label of 0, or a label above o.k other than (uint32_t)-1, only produces a
 * warning on stdout; processing continues regardless.
 */
template <bool is_learn>
void predict_or_learn(oaa& o, LEARNER::base_learner& base, example& ec)
{
  // Keep a copy of the multiclass label: ec.l.simple is assigned below and
  // ec.l.multi is restored from this copy at the end.
  const MULTICLASS::multiclass mc_label_data = ec.l.multi;

  // (uint32_t)-1 is exempt from the range check.
  // NOTE(review): presumably that value marks an unlabeled/test example - confirm.
  const bool label_out_of_range =
      mc_label_data.label == 0 ||
      (mc_label_data.label > o.k && mc_label_data.label != static_cast<uint32_t>(-1));
  if (label_out_of_range)
    std::cout << "label " << mc_label_data.label << " is not in {1,"<< o.k << "} This won't work right." << std::endl;

  ec.l.simple = {0.f, mc_label_data.weight, 0.f};

  std::stringstream outputStringStream;
  uint32_t prediction = 1;
  // Start below every finite score so that the true arg-max is always found.
  // (The previous INT_MIN sentinel made scores <= -2^31 unselectable and
  // depended on <climits> being pulled in by some other header.)
  float score = -std::numeric_limits<float>::infinity();

  for (uint32_t i = 1; i <= o.k; i++)
  {
    if (is_learn)
    {
      // Binary target for the i-th one-vs-rest problem.
      ec.l.simple.label = (mc_label_data.label == i) ? 1.f : -1.f;
      base.learn(ec, i - 1);
    }
    else
      base.predict(ec, i - 1);

    // Strict '>' keeps the lowest class index on ties.
    if (ec.partial_prediction > score)
    {
      score = ec.partial_prediction;
      prediction = i;
    }

    if (o.shouldOutput)
    {
      if (i > 1)
        outputStringStream << ' ';
      outputStringStream << i << ':' << ec.partial_prediction;
    }
  }

  ec.pred.multiclass = prediction;
  ec.l.multi = mc_label_data;  // restore the caller-visible multiclass label

  if (o.shouldOutput)
    o.all->print_text(o.all->raw_prediction, outputStringStream.str(), ec.tag);
}
/*
 * Builds the one-against-all reduction.
 *
 * Registers the "--oaa <k>" option; if missing_required(all) reports that it
 * was not supplied, returns a null pointer and leaves `all` otherwise
 * untouched by this function. Otherwise it allocates the oaa state, records
 * " --oaa <k>" in all.file_options, wraps the learner returned by
 * setup_base(all) with predict_or_learn<true>/<false>, installs
 * MULTICLASS::finish_example<oaa> as the finish-example hook, and sets
 * all.p->lp to MULTICLASS::mc_label.
 */
LEARNER::base_learner* setup(vw& all)
{
  new_options(all, "One-against-all options")
    ("oaa", po::value<size_t>(), "Use one-against-all multiclass learning with <k> labels");
  if (missing_required(all))
    return nullptr;

  oaa& state = calloc_or_die<oaa>();
  state.all = &all;
  // Raw per-class scores are emitted only when raw_prediction is positive.
  // NOTE(review): raw_prediction looks like a file descriptor - confirm.
  state.shouldOutput = all.raw_prediction > 0;
  state.k = all.vm["oaa"].as<size_t>();

  // NOTE(review): presumably file_options is what gets stored with the model
  // so the option is re-applied on load - confirm.
  *all.file_options << " --oaa " << state.k;

  // The trailing state.k is forwarded to init_learner; predict_or_learn
  // addresses sub-learners 0..k-1 of the base learner.
  LEARNER::learner<oaa>& reduction = init_learner(&state, setup_base(all),
                                                  predict_or_learn<true>,
                                                  predict_or_learn<false>,
                                                  state.k);
  reduction.set_finish_example(MULTICLASS::finish_example<oaa>);
  all.p->lp = MULTICLASS::mc_label;
  return make_base(reduction);
}
}
|