Welcome to mirror list, hosted at ThFree Co, Russian Federation.

best_constant.cc « vowpalwabbit - github.com/moses-smt/vowpal_wabbit.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: 3642a0140fc5da190c2529401a9260d7502d56e9 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
#include "best_constant.h"

// File-scope record of which labels have been observed; read by get_best_constant().
// NOTE(review): nothing visible in this part of the file assigns these variables;
// presumably they are updated elsewhere as labeled examples are processed - confirm.

// When true, get_best_constant() only yields a constant for the squared/Huber/classic
// losses and reports the accompanying loss as the FLT_MIN placeholder.
bool  is_more_than_two_labels_observed = false;
// One observed label. FLT_MAX is the sentinel get_best_constant() treats as
// "no non-test label observed yet", in which case it returns false.
float first_observed_label = FLT_MAX;
// The other observed label (presumably a second distinct value - confirm against the
// code that sets it). While still FLT_MAX, get_best_constant() substitutes 0 for it.
float second_observed_label = FLT_MAX;

// Estimates the single constant prediction that is best for the configured loss
// function, plus the average loss that constant would incur.
//
// The estimate is rebuilt from the aggregate statistics in all.sd and from the two
// labels remembered in first_observed_label / second_observed_label: the total label
// weight is split between the lower and the higher label, and each loss family then
// picks its constant from those two weights.
//
// Parameters:
//   all                - supplies the loss object (all.loss), the running statistics
//                        (all.sd), initial_t, and the command line (all.args/all.opts)
//                        that is re-parsed to find the loss name and quantile_tau.
//   best_constant      - output; written only on the paths that return true.
//   best_constant_loss - output; written only on the paths that return true. Set to
//                        FLT_MIN as a "not available" placeholder when more than two
//                        labels were observed.
//
// Returns false (leaving both outputs untouched) when no label was observed, when
// all.loss or all.sd is NULL, when the two labels coincide, when the labeled weight
// is not positive, or when the loss function has no formula here.
bool get_best_constant(vw& all, float& best_constant, float& best_constant_loss)
{
    // Preconditions: at least one observed label, and both the loss and shared data exist.
    if (first_observed_label == FLT_MAX) return false;
    if (all.loss == NULL) return false;
    if (all.sd == NULL) return false;

    // The observed labels may lie strictly inside [sd->min_label, sd->max_label], so the
    // recorded values are used instead. A missing second label is taken to be 0.
    const float first = first_observed_label;
    const float second = (second_observed_label == FLT_MAX) ? 0.f : second_observed_label;

    // Order the pair so that low_label <= high_label.
    float low_label  = (first > second) ? second : first;
    float high_label = (first > second) ? first : second;

    // With a single distinct value the weight split below would divide by zero.
    if (low_label == high_label) return false;

    // Solve  low*w_low + high*w_high = weighted_labels,  w_low + w_high = labeled_weight.
    const float labeled_weight =
        static_cast<float>(all.sd->weighted_examples - all.sd->weighted_unlabeled_examples + all.initial_t);
    const float low_weight =
        static_cast<float>(all.sd->weighted_labels - high_label * labeled_weight) / (low_label - high_label);
    const float high_weight = labeled_weight - low_weight;
    const float total_weight = low_weight + high_weight;

    if (total_weight <= 0.) return false;

    // Re-parse the command line to learn which loss function (and tau) was requested.
    po::parsed_options parsed = po::command_line_parser(all.args)
        .style(po::command_line_style::default_style ^ po::command_line_style::allow_guessing)
        .options(all.opts)
        .allow_unregistered()
        .run();

    po::variables_map vm;
    po::store(parsed, vm);
    po::notify(vm);

    const string loss_name = vm.count("loss_function") ? vm["loss_function"].as<string>()
                                                       : string("squared");

    if (loss_name == "squared" || loss_name == "Huber" || loss_name == "classic")
    {
        // Weighted mean of the labels; valid for any number of distinct labels.
        best_constant = static_cast<float>(all.sd->weighted_labels) / labeled_weight;
    }
    else
    {
        // The remaining losses only have two-label formulas.
        if (is_more_than_two_labels_observed) return false;

        if (loss_name == "hinge")
        {
            // Side with the heavier label; ties go to -1.
            best_constant = (high_weight <= low_weight) ? -1.f : 1.f;
        }
        else if (loss_name == "logistic")
        {
            // Score against {-1, 1} rather than the observed label values to get a proper loss.
            low_label  = -1.;
            high_label =  1.;

            if (low_weight <= 0)
                best_constant = 1.;
            else if (high_weight <= 0)
                best_constant = -1.;
            else
                best_constant = log(high_weight / low_weight); // log-odds of the higher label
        }
        else if (loss_name == "quantile" || loss_name == "pinball" || loss_name == "absolute")
        {
            const float tau = vm.count("quantile_tau") ? vm["quantile_tau"].as<float>() : 0.5f;

            // Choose the higher label when the higher label's weight exceeds tau * total weight.
            const float threshold = tau * total_weight;
            best_constant = (threshold < high_weight) ? high_label : low_label;
        }
        else
        {
            return false;
        }
    }

    // With more than two labels the two-label average below is meaningless; emit the placeholder.
    if (is_more_than_two_labels_observed)
    {
        best_constant_loss = FLT_MIN;
        return true;
    }

    // Weighted average of the loss of the constant against each of the two labels.
    best_constant_loss = ( all.loss->getLoss(all.sd, best_constant, low_label) * low_weight +
                           all.loss->getLoss(all.sd, best_constant, high_label) * high_weight )
            / total_weight;

    return true;
}