diff options
author | John Langford <jl@hunch.net> | 2011-10-12 05:28:10 +0400 |
---|---|---|
committer | John Langford <jl@hunch.net> | 2011-10-12 05:28:10 +0400 |
commit | 7c3d84cb36a63962fe69cd0c2bce994da9435e81 (patch) | |
tree | 7e1e8953ca04a3f6e7932b31b1a67af33f06c54b | |
parent | 686e1e23738cf020e242a5f71006114ea6c305cd (diff) |
safe labels for logistic and hinge
-rw-r--r-- | global_data.h | 1 | ||||
-rw-r--r-- | loss_functions.cc | 2 | ||||
-rw-r--r-- | parse_args.cc | 1 | ||||
-rw-r--r-- | parser.cc | 37 | ||||
-rw-r--r-- | simple_label.cc | 5 |
5 files changed, 25 insertions, 21 deletions
diff --git a/global_data.h b/global_data.h index 3da68b49..c3fb6a0e 100644 --- a/global_data.h +++ b/global_data.h @@ -97,6 +97,7 @@ struct global_data { bool random_weights; bool add_constant; bool nonormalize; + bool binary_label; size_t lda; float lda_alpha; diff --git a/loss_functions.cc b/loss_functions.cc index 62850ced..7d30ece5 100644 --- a/loss_functions.cc +++ b/loss_functions.cc @@ -272,12 +272,14 @@ loss_function* getLossFunction(string funcName, double function_parameter) { } else if(funcName.compare("classic") == 0){ return new classic_squaredloss(); } else if(funcName.compare("hinge") == 0) { + global.binary_label = true; return new hingeloss(); } else if(funcName.compare("logistic") == 0) { if (set_minmax != noop_mm) { global.sd->min_label = -100; global.sd->max_label = 100; + global.binary_label = true; } return new logloss(); } else if(funcName.compare("quantile") == 0 || funcName.compare("pinball") == 0 || funcName.compare("absolute") == 0) { diff --git a/parse_args.cc b/parse_args.cc index bc67c23c..004eb9f7 100644 --- a/parse_args.cc +++ b/parse_args.cc @@ -166,6 +166,7 @@ po::variables_map parse_args(int argc, char *argv[], boost::program_options::opt global.per_feature_regularizer_text = ""; global.ring_size = 1 << 8; global.nonormalize = false; + global.binary_label = false; global.adaptive = false; global.add_constant = true; @@ -730,28 +730,23 @@ void setup_example(parser* p, example* ae) } - if (global.rank == 0) - for (vector<string>::iterator i = global.pairs.begin(); i != global.pairs.end();i++) - { - ae->num_features - += (ae->atomics[(int)(*i)[0]].end - ae->atomics[(int)(*i)[0]].begin) - *(ae->atomics[(int)(*i)[1]].end - ae->atomics[(int)(*i)[1]].begin); - - ae->total_sum_feat_sq += ae->sum_feat_sq[(int)(*i)[0]]*ae->sum_feat_sq[(int)(*i)[1]]; - - } - else - for (vector<string>::iterator i = global.pairs.begin(); i != global.pairs.end();i++) - { - ae->num_features - += (ae->atomics[(int)(*i)[0]].end - ae->atomics[(int)(*i)[0]].begin) - *global.rank; - ae->num_features - += (ae->atomics[(int)(*i)[1]].end - ae->atomics[(int)(*i)[1]].begin) - *global.rank; + if (global.rank == 0) + for (vector<string>::iterator i = global.pairs.begin(); i != global.pairs.end();i++) + { + ae->num_features + += (ae->atomics[(int)(*i)[0]].end - ae->atomics[(int)(*i)[0]].begin) + *(ae->atomics[(int)(*i)[1]].end - ae->atomics[(int)(*i)[1]].begin); + ae->total_sum_feat_sq += ae->sum_feat_sq[(int)(*i)[0]]*ae->sum_feat_sq[(int)(*i)[1]]; + } + else + for (vector<string>::iterator i = global.pairs.begin(); i != global.pairs.end();i++) + { + ae->num_features + += (ae->atomics[(int)(*i)[0]].end - ae->atomics[(int)(*i)[0]].begin) * global.rank; + ae->num_features + += (ae->atomics[(int)(*i)[1]].end - ae->atomics[(int)(*i)[1]].begin) + *global.rank; } - - } void *main_parse_loop(void *in) diff --git a/simple_label.cc b/simple_label.cc index 50e4a772..ca93fd80 100644 --- a/simple_label.cc +++ b/simple_label.cc @@ -1,4 +1,5 @@ #include <float.h> +#include <math.h> #include "simple_label.h" #include "cache.h" @@ -9,6 +10,8 @@ char* bufread_simple_label(label_data* ld, char* c) { ld->label = *(float *)c; c += sizeof(ld->label); + if (global.binary_label && fabs(ld->label) != 1.f) + cout << "You are using a label not -1 or 1 with a loss function expecting that!" << endl; ld->weight = *(float *)c; c += sizeof(ld->weight); ld->initial = *(float *)c; @@ -94,5 +97,7 @@ void parse_simple_label(void* v, v_array<substring>& words) cerr << "malformed example!\n"; cerr << "words.index() = " << words.index() << endl; } + if (global.binary_label && fabs(ld->label) != 1.f) + cout << "You are using a label not -1 or 1 with a loss function expecting that!" << endl; } |