diff options
author | John Langford <jl@hunch.net> | 2014-12-29 04:33:43 +0300 |
---|---|---|
committer | John Langford <jl@hunch.net> | 2014-12-29 04:33:43 +0300 |
commit | 6ded0936d05668220e5e857fc96e08f2ce1939c4 (patch) | |
tree | 8cc418230d0ec4cdb7e6d6b49f3c2705052d5fd4 | |
parent | 82a148fd5266a3b932528e53f9e053edf72fecfe (diff) |
various simplifications
-rw-r--r-- | vowpalwabbit/autolink.cc | 11 | ||||
-rw-r--r-- | vowpalwabbit/binary.cc | 29 | ||||
-rw-r--r-- | vowpalwabbit/global_data.h | 1 | ||||
-rw-r--r-- | vowpalwabbit/kernel_svm.cc | 2 | ||||
-rw-r--r-- | vowpalwabbit/log_multi.cc | 2 | ||||
-rw-r--r-- | vowpalwabbit/loss_functions.cc | 20 | ||||
-rw-r--r-- | vowpalwabbit/loss_functions.h | 5 | ||||
-rw-r--r-- | vowpalwabbit/nn.cc | 2 | ||||
-rw-r--r-- | vowpalwabbit/parse_args.cc | 2 | ||||
-rw-r--r-- | vowpalwabbit/scorer.cc | 20 | ||||
-rw-r--r-- | vowpalwabbit/simple_label.cc | 3 |
11 files changed, 41 insertions, 56 deletions
diff --git a/vowpalwabbit/autolink.cc b/vowpalwabbit/autolink.cc index 6d4f0106..7b3bebd8 100644 --- a/vowpalwabbit/autolink.cc +++ b/vowpalwabbit/autolink.cc @@ -1,8 +1,6 @@ #include "reductions.h" #include "simple_label.h" -using namespace LEARNER; - namespace ALINK { const int autoconstant = 524267083; @@ -11,8 +9,8 @@ namespace ALINK { uint32_t stride_shift; }; - template <bool is_learn> - void predict_or_learn(autolink& b, base_learner& base, example& ec) + template <bool is_learn> + void predict_or_learn(autolink& b, LEARNER::base_learner& base, example& ec) { base.predict(ec); float base_pred = ec.pred.scalar; @@ -30,7 +28,6 @@ namespace ALINK { } ec.total_sum_feat_sq += sum_sq; - // apply predict or learn if (is_learn) base.learn(ec); else @@ -41,7 +38,7 @@ namespace ALINK { ec.total_sum_feat_sq -= sum_sq; } - base_learner* setup(vw& all, po::variables_map& vm) + LEARNER::base_learner* setup(vw& all, po::variables_map& vm) { autolink& data = calloc_or_die<autolink>(); data.d = (uint32_t)vm["autolink"].as<size_t>(); @@ -49,7 +46,7 @@ namespace ALINK { *all.file_options << " --autolink " << data.d; - learner<autolink>& ret = init_learner(&data, all.l, predict_or_learn<true>, + LEARNER::learner<autolink>& ret = init_learner(&data, all.l, predict_or_learn<true>, predict_or_learn<false>); return make_base(ret); } diff --git a/vowpalwabbit/binary.cc b/vowpalwabbit/binary.cc index 0916dd6f..04810e26 100644 --- a/vowpalwabbit/binary.cc +++ b/vowpalwabbit/binary.cc @@ -1,13 +1,11 @@ +#include <float.h> #include "reductions.h" #include "multiclass.h" #include "simple_label.h" -using namespace LEARNER; - namespace BINARY { - template <bool is_learn> - void predict_or_learn(char&, base_learner& base, example& ec) { + void predict_or_learn(char&, LEARNER::base_learner& base, example& ec) { if (is_learn) base.learn(ec); else @@ -18,17 +16,22 @@ namespace BINARY { else ec.pred.scalar = -1; - if (ec.l.simple.label == ec.pred.scalar) - ec.loss = 0.; - else - ec.loss = ec.l.simple.weight; + if (ec.l.simple.label != FLT_MAX) + { + if (fabs(ec.l.simple.label) != 1.f) + cout << "You are using a label not -1 or 1 with a loss function expecting that!" << endl; + else + if (ec.l.simple.label == ec.pred.scalar) + ec.loss = 0.; + else + ec.loss = ec.l.simple.weight; + } } - base_learner* setup(vw& all, po::variables_map& vm) - {//parse and set arguments - all.sd->binary_label = true; - //Create new learner - learner<char>& ret = init_learner<char>(NULL, all.l, predict_or_learn<true>, predict_or_learn<false>); + LEARNER::base_learner* setup(vw& all, po::variables_map& vm) + { + LEARNER::learner<char>& ret = + LEARNER::init_learner<char>(NULL, all.l, predict_or_learn<true>, predict_or_learn<false>); return make_base(ret); } } diff --git a/vowpalwabbit/global_data.h b/vowpalwabbit/global_data.h index fffe8b55..bdf76b46 100644 --- a/vowpalwabbit/global_data.h +++ b/vowpalwabbit/global_data.h @@ -154,7 +154,6 @@ struct shared_data { double holdout_sum_loss_since_last_pass; size_t holdout_best_pass; - bool binary_label; uint32_t k; }; diff --git a/vowpalwabbit/kernel_svm.cc b/vowpalwabbit/kernel_svm.cc index 607d90b6..9a88f17e 100644 --- a/vowpalwabbit/kernel_svm.cc +++ b/vowpalwabbit/kernel_svm.cc @@ -810,7 +810,7 @@ namespace KSVM string loss_function = "hinge"; float loss_parameter = 0.0; delete all.loss; - all.loss = getLossFunction(&all, loss_function, (float)loss_parameter); + all.loss = getLossFunction(all, loss_function, (float)loss_parameter); svm_params& params = calloc_or_die<svm_params>(); params.model = &calloc_or_die<svm_model>(); diff --git a/vowpalwabbit/log_multi.cc b/vowpalwabbit/log_multi.cc index 19250a65..26f6b454 100644 --- a/vowpalwabbit/log_multi.cc +++ b/vowpalwabbit/log_multi.cc @@ -527,7 +527,7 @@ namespace LOG_MULTI string loss_function = "quantile";
float loss_parameter = 0.5;
delete(all.loss);
- all.loss = getLossFunction(&all, loss_function, loss_parameter);
+ all.loss = getLossFunction(all, loss_function, loss_parameter);
data.max_predictors = data.k - 1;
diff --git a/vowpalwabbit/loss_functions.cc b/vowpalwabbit/loss_functions.cc index 4d67ae5e..6badc633 100644 --- a/vowpalwabbit/loss_functions.cc +++ b/vowpalwabbit/loss_functions.cc @@ -297,21 +297,18 @@ public: float tau; }; -loss_function* getLossFunction(void* a, string funcName, float function_parameter) { - vw* all=(vw*)a; - if(funcName.compare("squared") == 0 || funcName.compare("Huber") == 0) { +loss_function* getLossFunction(vw& all, string funcName, float function_parameter) { + if(funcName.compare("squared") == 0 || funcName.compare("Huber") == 0) return new squaredloss(); - } else if(funcName.compare("classic") == 0){ + else if(funcName.compare("classic") == 0) return new classic_squaredloss(); - } else if(funcName.compare("hinge") == 0) { - all->sd->binary_label = true; + else if(funcName.compare("hinge") == 0) return new hingeloss(); - } else if(funcName.compare("logistic") == 0) { - if (all->set_minmax != noop_mm) + else if(funcName.compare("logistic") == 0) { + if (all.set_minmax != noop_mm) { - all->sd->min_label = -50; - all->sd->max_label = 50; - all->sd->binary_label = true; + all.sd->min_label = -50; + all.sd->max_label = 50; } return new logloss(); } else if(funcName.compare("quantile") == 0 || funcName.compare("pinball") == 0 || funcName.compare("absolute") == 0) { @@ -320,5 +317,4 @@ loss_function* getLossFunction(void* a, string funcName, float function_paramete cout << "Invalid loss function name: \'" << funcName << "\' Bailing!" << endl; throw exception(); } - cout << "end getLossFunction" << endl; } diff --git a/vowpalwabbit/loss_functions.h b/vowpalwabbit/loss_functions.h index 421b3bfe..35b6f24b 100644 --- a/vowpalwabbit/loss_functions.h +++ b/vowpalwabbit/loss_functions.h @@ -8,8 +8,7 @@ license as described in the file LICENSE. #include "parse_primitives.h" struct shared_data; - -using namespace std; +struct vw; class loss_function { @@ -34,4 +33,4 @@ public : virtual ~loss_function() {}; }; -loss_function* getLossFunction(void*, string funcName, float function_parameter = 0); +loss_function* getLossFunction(vw&, std::string funcName, float function_parameter = 0); diff --git a/vowpalwabbit/nn.cc b/vowpalwabbit/nn.cc index bf9bc020..bfab6009 100644 --- a/vowpalwabbit/nn.cc +++ b/vowpalwabbit/nn.cc @@ -356,7 +356,7 @@ CONVERSE: // That's right, I'm using goto. So sue me. << std::endl; n.finished_setup = false; - n.squared_loss = getLossFunction (0, "squared", 0); + n.squared_loss = getLossFunction (all, "squared", 0); n.xsubi = 0; diff --git a/vowpalwabbit/parse_args.cc b/vowpalwabbit/parse_args.cc index 9f5d071a..ef65345c 100644 --- a/vowpalwabbit/parse_args.cc +++ b/vowpalwabbit/parse_args.cc @@ -600,7 +600,7 @@ void parse_example_tweaks(vw& all, po::variables_map& vm) if(vm.count("quantile_tau")) loss_parameter = vm["quantile_tau"].as<float>(); - all.loss = getLossFunction(&all, loss_function, (float)loss_parameter); + all.loss = getLossFunction(all, loss_function, (float)loss_parameter); if (all.l1_lambda < 0.) { cerr << "l1_lambda should be nonnegative: resetting from " << all.l1_lambda << " to 0" << endl; diff --git a/vowpalwabbit/scorer.cc b/vowpalwabbit/scorer.cc index e93eb84f..50645ed8 100644 --- a/vowpalwabbit/scorer.cc +++ b/vowpalwabbit/scorer.cc @@ -1,14 +1,11 @@ #include <float.h> - #include "reductions.h" -using namespace LEARNER; - namespace Scorer { struct scorer{ vw* all; }; template <bool is_learn, float (*link)(float in)> - void predict_or_learn(scorer& s, base_learner& base, example& ec) + void predict_or_learn(scorer& s, LEARNER::base_learner& base, example& ec) { s.all->set_minmax(s.all->sd, ec.l.simple.label); @@ -24,20 +21,17 @@ namespace Scorer { } // y = f(x) -> [0, 1] - float logistic(float in) - { return 1.f / (1.f + exp(- in)); } + float logistic(float in) { return 1.f / (1.f + exp(- in)); } // http://en.wikipedia.org/wiki/Generalized_logistic_curve // where the lower & upper asymptotes are -1 & 1 respectively // 'glf1' stands for 'Generalized Logistic Function with [-1,1] range' // y = f(x) -> [-1, 1] - float glf1(float in) - { return 2.f / (1.f + exp(- in)) - 1.f; } + float glf1(float in) { return 2.f / (1.f + exp(- in)) - 1.f; } - float noop(float in) - { return in; } + float id(float in) { return in; } - base_learner* setup(vw& all, po::variables_map& vm) + LEARNER::base_learner* setup(vw& all, po::variables_map& vm) { scorer& s = calloc_or_die<scorer>(); s.all = &all; @@ -49,11 +43,11 @@ namespace Scorer { vm = add_options(all, link_opts); - learner<scorer>* l; + LEARNER::learner<scorer>* l; string link = vm["link"].as<string>(); if (!vm.count("link") || link.compare("identity") == 0) - l = &init_learner(&s, all.l, predict_or_learn<true, noop>, predict_or_learn<false, noop>); + l = &init_learner(&s, all.l, predict_or_learn<true, id>, predict_or_learn<false, id>); else if (link.compare("logistic") == 0) { *all.file_options << " --link=logistic "; diff --git a/vowpalwabbit/simple_label.cc b/vowpalwabbit/simple_label.cc index 3bf748de..150f95e9 100644 --- a/vowpalwabbit/simple_label.cc +++ b/vowpalwabbit/simple_label.cc @@ -96,9 +96,6 @@ void parse_simple_label(parser* p, shared_data* sd, void* v, v_array<substring>& cerr << "malformed example!\n"; cerr << "words.size() = " << words.size() << endl; } - if (words.size() > 0 && sd->binary_label && fabs(ld->label) != 1.f) - cout << "You are using a label not -1 or 1 with a loss function expecting that!" << endl; - count_label(ld->label); } |