Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/moses-smt/vowpal_wabbit.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author: John Langford <jl@hunch.net> 2011-10-12 05:28:10 +0400
committer: John Langford <jl@hunch.net> 2011-10-12 05:28:10 +0400
commit: 7c3d84cb36a63962fe69cd0c2bce994da9435e81 (patch)
tree: 7e1e8953ca04a3f6e7932b31b1a67af33f06c54b
parent: 686e1e23738cf020e242a5f71006114ea6c305cd (diff)
safe labels for logistic and hinge
-rw-r--r-- global_data.h 1
-rw-r--r-- loss_functions.cc 2
-rw-r--r-- parse_args.cc 1
-rw-r--r-- parser.cc 37
-rw-r--r-- simple_label.cc 5
5 files changed, 25 insertions, 21 deletions
diff --git a/global_data.h b/global_data.h
index 3da68b49..c3fb6a0e 100644
--- a/global_data.h
+++ b/global_data.h
@@ -97,6 +97,7 @@ struct global_data {
bool random_weights;
bool add_constant;
bool nonormalize;
+ bool binary_label;
size_t lda;
float lda_alpha;
diff --git a/loss_functions.cc b/loss_functions.cc
index 62850ced..7d30ece5 100644
--- a/loss_functions.cc
+++ b/loss_functions.cc
@@ -272,12 +272,14 @@ loss_function* getLossFunction(string funcName, double function_parameter) {
} else if(funcName.compare("classic") == 0){
return new classic_squaredloss();
} else if(funcName.compare("hinge") == 0) {
+ global.binary_label = true;
return new hingeloss();
} else if(funcName.compare("logistic") == 0) {
if (set_minmax != noop_mm)
{
global.sd->min_label = -100;
global.sd->max_label = 100;
+ global.binary_label = true;
}
return new logloss();
} else if(funcName.compare("quantile") == 0 || funcName.compare("pinball") == 0 || funcName.compare("absolute") == 0) {
diff --git a/parse_args.cc b/parse_args.cc
index bc67c23c..004eb9f7 100644
--- a/parse_args.cc
+++ b/parse_args.cc
@@ -166,6 +166,7 @@ po::variables_map parse_args(int argc, char *argv[], boost::program_options::opt
global.per_feature_regularizer_text = "";
global.ring_size = 1 << 8;
global.nonormalize = false;
+ global.binary_label = false;
global.adaptive = false;
global.add_constant = true;
diff --git a/parser.cc b/parser.cc
index 2f04cff1..51c50a5c 100644
--- a/parser.cc
+++ b/parser.cc
@@ -730,28 +730,23 @@ void setup_example(parser* p, example* ae)
}
- if (global.rank == 0)
- for (vector<string>::iterator i = global.pairs.begin(); i != global.pairs.end();i++)
- {
- ae->num_features
- += (ae->atomics[(int)(*i)[0]].end - ae->atomics[(int)(*i)[0]].begin)
- *(ae->atomics[(int)(*i)[1]].end - ae->atomics[(int)(*i)[1]].begin);
-
- ae->total_sum_feat_sq += ae->sum_feat_sq[(int)(*i)[0]]*ae->sum_feat_sq[(int)(*i)[1]];
-
- }
- else
- for (vector<string>::iterator i = global.pairs.begin(); i != global.pairs.end();i++)
- {
- ae->num_features
- += (ae->atomics[(int)(*i)[0]].end - ae->atomics[(int)(*i)[0]].begin)
- *global.rank;
- ae->num_features
- += (ae->atomics[(int)(*i)[1]].end - ae->atomics[(int)(*i)[1]].begin)
- *global.rank;
+ if (global.rank == 0)
+ for (vector<string>::iterator i = global.pairs.begin(); i != global.pairs.end();i++)
+ {
+ ae->num_features
+ += (ae->atomics[(int)(*i)[0]].end - ae->atomics[(int)(*i)[0]].begin)
+ *(ae->atomics[(int)(*i)[1]].end - ae->atomics[(int)(*i)[1]].begin);
+ ae->total_sum_feat_sq += ae->sum_feat_sq[(int)(*i)[0]]*ae->sum_feat_sq[(int)(*i)[1]];
+ }
+ else
+ for (vector<string>::iterator i = global.pairs.begin(); i != global.pairs.end();i++)
+ {
+ ae->num_features
+ += (ae->atomics[(int)(*i)[0]].end - ae->atomics[(int)(*i)[0]].begin) * global.rank;
+ ae->num_features
+ += (ae->atomics[(int)(*i)[1]].end - ae->atomics[(int)(*i)[1]].begin)
+ *global.rank;
}
-
-
}
void *main_parse_loop(void *in)
diff --git a/simple_label.cc b/simple_label.cc
index 50e4a772..ca93fd80 100644
--- a/simple_label.cc
+++ b/simple_label.cc
@@ -1,4 +1,5 @@
#include <float.h>
+#include <math.h>
#include "simple_label.h"
#include "cache.h"
@@ -9,6 +10,8 @@ char* bufread_simple_label(label_data* ld, char* c)
{
ld->label = *(float *)c;
c += sizeof(ld->label);
+ if (global.binary_label && fabs(ld->label) != 1.f)
+ cout << "You are using a label not -1 or 1 with a loss function expecting that!" << endl;
ld->weight = *(float *)c;
c += sizeof(ld->weight);
ld->initial = *(float *)c;
@@ -94,5 +97,7 @@ void parse_simple_label(void* v, v_array<substring>& words)
cerr << "malformed example!\n";
cerr << "words.index() = " << words.index() << endl;
}
+ if (global.binary_label && fabs(ld->label) != 1.f)
+ cout << "You are using a label not -1 or 1 with a loss function expecting that!" << endl;
}