github.com/moses-smt/vowpal_wabbit.git
author    John Langford <jl@hunch.net>  2014-12-29 04:33:43 +0300
committer John Langford <jl@hunch.net>  2014-12-29 04:33:43 +0300
commit    6ded0936d05668220e5e857fc96e08f2ce1939c4 (patch)
tree      8cc418230d0ec4cdb7e6d6b49f3c2705052d5fd4
parent    82a148fd5266a3b932528e53f9e053edf72fecfe (diff)
various simplifications
-rw-r--r--  vowpalwabbit/autolink.cc        11
-rw-r--r--  vowpalwabbit/binary.cc          29
-rw-r--r--  vowpalwabbit/global_data.h       1
-rw-r--r--  vowpalwabbit/kernel_svm.cc       2
-rw-r--r--  vowpalwabbit/log_multi.cc        2
-rw-r--r--  vowpalwabbit/loss_functions.cc  20
-rw-r--r--  vowpalwabbit/loss_functions.h    5
-rw-r--r--  vowpalwabbit/nn.cc               2
-rw-r--r--  vowpalwabbit/parse_args.cc       2
-rw-r--r--  vowpalwabbit/scorer.cc          20
-rw-r--r--  vowpalwabbit/simple_label.cc     3
11 files changed, 41 insertions, 56 deletions
diff --git a/vowpalwabbit/autolink.cc b/vowpalwabbit/autolink.cc
index 6d4f0106..7b3bebd8 100644
--- a/vowpalwabbit/autolink.cc
+++ b/vowpalwabbit/autolink.cc
@@ -1,8 +1,6 @@
#include "reductions.h"
#include "simple_label.h"
-using namespace LEARNER;
-
namespace ALINK {
const int autoconstant = 524267083;
@@ -11,8 +9,8 @@ namespace ALINK {
uint32_t stride_shift;
};
- template <bool is_learn>
- void predict_or_learn(autolink& b, base_learner& base, example& ec)
+ template <bool is_learn>
+ void predict_or_learn(autolink& b, LEARNER::base_learner& base, example& ec)
{
base.predict(ec);
float base_pred = ec.pred.scalar;
@@ -30,7 +28,6 @@ namespace ALINK {
}
ec.total_sum_feat_sq += sum_sq;
- // apply predict or learn
if (is_learn)
base.learn(ec);
else
@@ -41,7 +38,7 @@ namespace ALINK {
ec.total_sum_feat_sq -= sum_sq;
}
- base_learner* setup(vw& all, po::variables_map& vm)
+ LEARNER::base_learner* setup(vw& all, po::variables_map& vm)
{
autolink& data = calloc_or_die<autolink>();
data.d = (uint32_t)vm["autolink"].as<size_t>();
@@ -49,7 +46,7 @@ namespace ALINK {
*all.file_options << " --autolink " << data.d;
- learner<autolink>& ret = init_learner(&data, all.l, predict_or_learn<true>,
+ LEARNER::learner<autolink>& ret = init_learner(&data, all.l, predict_or_learn<true>,
predict_or_learn<false>);
return make_base(ret);
}
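
A minimal standalone sketch (toy names LIB and REDUCTION, not VW code) of the pattern the autolink.cc hunks above apply: drop the file-scope using-directive and qualify the few names actually needed, so the namespace no longer leaks into every file that includes this one.

#include <iostream>

namespace LIB { struct base_learner { void predict() { std::cout << "predict\n"; } }; }

namespace REDUCTION
{ // explicit LIB:: qualification instead of a file-scope `using namespace LIB;`
  void run(LIB::base_learner& base) { base.predict(); }
}

int main() { LIB::base_learner b; REDUCTION::run(b); }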
diff --git a/vowpalwabbit/binary.cc b/vowpalwabbit/binary.cc
index 0916dd6f..04810e26 100644
--- a/vowpalwabbit/binary.cc
+++ b/vowpalwabbit/binary.cc
@@ -1,13 +1,11 @@
+#include <float.h>
#include "reductions.h"
#include "multiclass.h"
#include "simple_label.h"
-using namespace LEARNER;
-
namespace BINARY {
-
template <bool is_learn>
- void predict_or_learn(char&, base_learner& base, example& ec) {
+ void predict_or_learn(char&, LEARNER::base_learner& base, example& ec) {
if (is_learn)
base.learn(ec);
else
@@ -18,17 +16,22 @@ namespace BINARY {
else
ec.pred.scalar = -1;
- if (ec.l.simple.label == ec.pred.scalar)
- ec.loss = 0.;
- else
- ec.loss = ec.l.simple.weight;
+ if (ec.l.simple.label != FLT_MAX)
+ {
+ if (fabs(ec.l.simple.label) != 1.f)
+ cout << "You are using a label not -1 or 1 with a loss function expecting that!" << endl;
+ else
+ if (ec.l.simple.label == ec.pred.scalar)
+ ec.loss = 0.;
+ else
+ ec.loss = ec.l.simple.weight;
+ }
}
- base_learner* setup(vw& all, po::variables_map& vm)
- {//parse and set arguments
- all.sd->binary_label = true;
- //Create new learner
- learner<char>& ret = init_learner<char>(NULL, all.l, predict_or_learn<true>, predict_or_learn<false>);
+ LEARNER::base_learner* setup(vw& all, po::variables_map& vm)
+ {
+ LEARNER::learner<char>& ret =
+ LEARNER::init_learner<char>(NULL, all.l, predict_or_learn<true>, predict_or_learn<false>);
return make_base(ret);
}
}
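
A self-contained sketch of the loss computation that binary.cc now performs itself once the shared binary_label flag is gone. The simple_label struct and binary_loss helper here are toy stand-ins for illustration, not VW's types.

#include <cfloat>
#include <cmath>
#include <iostream>

struct simple_label { float label = FLT_MAX; float weight = 1.f; };

float binary_loss(const simple_label& l, float prediction)
{ // prediction is assumed to already be -1 or +1
  if (l.label == FLT_MAX) return 0.f;              // unlabeled example: nothing to score
  if (std::fabs(l.label) != 1.f)
  { std::cout << "You are using a label not -1 or 1 with a loss function expecting that!\n";
    return 0.f; }
  return (l.label == prediction) ? 0.f : l.weight; // 0/1 loss scaled by the example weight
}

int main()
{ simple_label l; l.label = 1.f; l.weight = 2.f;
  std::cout << binary_loss(l, -1.f) << "\n";       // prints 2
}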
diff --git a/vowpalwabbit/global_data.h b/vowpalwabbit/global_data.h
index fffe8b55..bdf76b46 100644
--- a/vowpalwabbit/global_data.h
+++ b/vowpalwabbit/global_data.h
@@ -154,7 +154,6 @@ struct shared_data {
double holdout_sum_loss_since_last_pass;
size_t holdout_best_pass;
- bool binary_label;
uint32_t k;
};
diff --git a/vowpalwabbit/kernel_svm.cc b/vowpalwabbit/kernel_svm.cc
index 607d90b6..9a88f17e 100644
--- a/vowpalwabbit/kernel_svm.cc
+++ b/vowpalwabbit/kernel_svm.cc
@@ -810,7 +810,7 @@ namespace KSVM
string loss_function = "hinge";
float loss_parameter = 0.0;
delete all.loss;
- all.loss = getLossFunction(&all, loss_function, (float)loss_parameter);
+ all.loss = getLossFunction(all, loss_function, (float)loss_parameter);
svm_params& params = calloc_or_die<svm_params>();
params.model = &calloc_or_die<svm_model>();
diff --git a/vowpalwabbit/log_multi.cc b/vowpalwabbit/log_multi.cc
index 19250a65..26f6b454 100644
--- a/vowpalwabbit/log_multi.cc
+++ b/vowpalwabbit/log_multi.cc
@@ -527,7 +527,7 @@ namespace LOG_MULTI
string loss_function = "quantile";
float loss_parameter = 0.5;
delete(all.loss);
- all.loss = getLossFunction(&all, loss_function, loss_parameter);
+ all.loss = getLossFunction(all, loss_function, loss_parameter);
data.max_predictors = data.k - 1;
diff --git a/vowpalwabbit/loss_functions.cc b/vowpalwabbit/loss_functions.cc
index 4d67ae5e..6badc633 100644
--- a/vowpalwabbit/loss_functions.cc
+++ b/vowpalwabbit/loss_functions.cc
@@ -297,21 +297,18 @@ public:
float tau;
};
-loss_function* getLossFunction(void* a, string funcName, float function_parameter) {
- vw* all=(vw*)a;
- if(funcName.compare("squared") == 0 || funcName.compare("Huber") == 0) {
+loss_function* getLossFunction(vw& all, string funcName, float function_parameter) {
+ if(funcName.compare("squared") == 0 || funcName.compare("Huber") == 0)
return new squaredloss();
- } else if(funcName.compare("classic") == 0){
+ else if(funcName.compare("classic") == 0)
return new classic_squaredloss();
- } else if(funcName.compare("hinge") == 0) {
- all->sd->binary_label = true;
+ else if(funcName.compare("hinge") == 0)
return new hingeloss();
- } else if(funcName.compare("logistic") == 0) {
- if (all->set_minmax != noop_mm)
+ else if(funcName.compare("logistic") == 0) {
+ if (all.set_minmax != noop_mm)
{
- all->sd->min_label = -50;
- all->sd->max_label = 50;
- all->sd->binary_label = true;
+ all.sd->min_label = -50;
+ all.sd->max_label = 50;
}
return new logloss();
} else if(funcName.compare("quantile") == 0 || funcName.compare("pinball") == 0 || funcName.compare("absolute") == 0) {
@@ -320,5 +317,4 @@ loss_function* getLossFunction(void* a, string funcName, float function_paramete
cout << "Invalid loss function name: \'" << funcName << "\' Bailing!" << endl;
throw exception();
}
- cout << "end getLossFunction" << endl;
}
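
A toy illustration of the interface change above, using hypothetical vw, loss_function, and subclass types rather than the real headers: the factory takes a typed vw& instead of a void* that each branch had to cast back, so the compiler checks the argument and the cast disappears.

#include <memory>
#include <stdexcept>
#include <string>

struct vw { float min_label = 0.f, max_label = 0.f; };

struct loss_function { virtual ~loss_function() {} };
struct squaredloss : loss_function {};
struct logloss : loss_function {};

loss_function* getLossFunction(vw& all, const std::string& name, float /*parameter*/ = 0)
{ if (name == "squared") return new squaredloss();
  if (name == "logistic")
  { all.min_label = -50; all.max_label = 50;   // same label clamping as the real branch
    return new logloss(); }
  throw std::runtime_error("Invalid loss function name: " + name);
}

int main()
{ vw all;
  std::unique_ptr<loss_function> loss(getLossFunction(all, "logistic"));
}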
diff --git a/vowpalwabbit/loss_functions.h b/vowpalwabbit/loss_functions.h
index 421b3bfe..35b6f24b 100644
--- a/vowpalwabbit/loss_functions.h
+++ b/vowpalwabbit/loss_functions.h
@@ -8,8 +8,7 @@ license as described in the file LICENSE.
#include "parse_primitives.h"
struct shared_data;
-
-using namespace std;
+struct vw;
class loss_function {
@@ -34,4 +33,4 @@ public :
virtual ~loss_function() {};
};
-loss_function* getLossFunction(void*, string funcName, float function_parameter = 0);
+loss_function* getLossFunction(vw&, std::string funcName, float function_parameter = 0);
diff --git a/vowpalwabbit/nn.cc b/vowpalwabbit/nn.cc
index bf9bc020..bfab6009 100644
--- a/vowpalwabbit/nn.cc
+++ b/vowpalwabbit/nn.cc
@@ -356,7 +356,7 @@ CONVERSE: // That's right, I'm using goto. So sue me.
<< std::endl;
n.finished_setup = false;
- n.squared_loss = getLossFunction (0, "squared", 0);
+ n.squared_loss = getLossFunction (all, "squared", 0);
n.xsubi = 0;
diff --git a/vowpalwabbit/parse_args.cc b/vowpalwabbit/parse_args.cc
index 9f5d071a..ef65345c 100644
--- a/vowpalwabbit/parse_args.cc
+++ b/vowpalwabbit/parse_args.cc
@@ -600,7 +600,7 @@ void parse_example_tweaks(vw& all, po::variables_map& vm)
if(vm.count("quantile_tau"))
loss_parameter = vm["quantile_tau"].as<float>();
- all.loss = getLossFunction(&all, loss_function, (float)loss_parameter);
+ all.loss = getLossFunction(all, loss_function, (float)loss_parameter);
if (all.l1_lambda < 0.) {
cerr << "l1_lambda should be nonnegative: resetting from " << all.l1_lambda << " to 0" << endl;
diff --git a/vowpalwabbit/scorer.cc b/vowpalwabbit/scorer.cc
index e93eb84f..50645ed8 100644
--- a/vowpalwabbit/scorer.cc
+++ b/vowpalwabbit/scorer.cc
@@ -1,14 +1,11 @@
#include <float.h>
-
#include "reductions.h"
-using namespace LEARNER;
-
namespace Scorer {
struct scorer{ vw* all; };
template <bool is_learn, float (*link)(float in)>
- void predict_or_learn(scorer& s, base_learner& base, example& ec)
+ void predict_or_learn(scorer& s, LEARNER::base_learner& base, example& ec)
{
s.all->set_minmax(s.all->sd, ec.l.simple.label);
@@ -24,20 +21,17 @@ namespace Scorer {
}
// y = f(x) -> [0, 1]
- float logistic(float in)
- { return 1.f / (1.f + exp(- in)); }
+ float logistic(float in) { return 1.f / (1.f + exp(- in)); }
// http://en.wikipedia.org/wiki/Generalized_logistic_curve
// where the lower & upper asymptotes are -1 & 1 respectively
// 'glf1' stands for 'Generalized Logistic Function with [-1,1] range'
// y = f(x) -> [-1, 1]
- float glf1(float in)
- { return 2.f / (1.f + exp(- in)) - 1.f; }
+ float glf1(float in) { return 2.f / (1.f + exp(- in)) - 1.f; }
- float noop(float in)
- { return in; }
+ float id(float in) { return in; }
- base_learner* setup(vw& all, po::variables_map& vm)
+ LEARNER::base_learner* setup(vw& all, po::variables_map& vm)
{
scorer& s = calloc_or_die<scorer>();
s.all = &all;
@@ -49,11 +43,11 @@ namespace Scorer {
vm = add_options(all, link_opts);
- learner<scorer>* l;
+ LEARNER::learner<scorer>* l;
string link = vm["link"].as<string>();
if (!vm.count("link") || link.compare("identity") == 0)
- l = &init_learner(&s, all.l, predict_or_learn<true, noop>, predict_or_learn<false, noop>);
+ l = &init_learner(&s, all.l, predict_or_learn<true, id>, predict_or_learn<false, id>);
else if (link.compare("logistic") == 0)
{
*all.file_options << " --link=logistic ";
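
A standalone check, with no VW dependencies, of the three link functions scorer.cc keeps after this change: logistic maps into (0, 1), glf1 maps into (-1, 1), and the identity link (renamed from noop to id) returns its input unchanged.

#include <cmath>
#include <iostream>

float logistic(float in) { return 1.f / (1.f + std::exp(-in)); }      // range (0, 1)
float glf1(float in)     { return 2.f / (1.f + std::exp(-in)) - 1.f; } // range (-1, 1)
float id(float in)       { return in; }                                // identity link

int main()
{ for (float x : {-2.f, 0.f, 2.f})
    std::cout << x << " -> logistic " << logistic(x)
              << ", glf1 " << glf1(x) << ", id " << id(x) << "\n";
}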
diff --git a/vowpalwabbit/simple_label.cc b/vowpalwabbit/simple_label.cc
index 3bf748de..150f95e9 100644
--- a/vowpalwabbit/simple_label.cc
+++ b/vowpalwabbit/simple_label.cc
@@ -96,9 +96,6 @@ void parse_simple_label(parser* p, shared_data* sd, void* v, v_array<substring>&
cerr << "malformed example!\n";
cerr << "words.size() = " << words.size() << endl;
}
- if (words.size() > 0 && sd->binary_label && fabs(ld->label) != 1.f)
- cout << "You are using a label not -1 or 1 with a loss function expecting that!" << endl;
-
count_label(ld->label);
}