diff options
Diffstat (limited to 'vowpalwabbit/binary.cc')
-rw-r--r-- | vowpalwabbit/binary.cc | 28 |
1 files changed, 15 insertions, 13 deletions
diff --git a/vowpalwabbit/binary.cc b/vowpalwabbit/binary.cc index 7eac0ba8..9951ddd3 100644 --- a/vowpalwabbit/binary.cc +++ b/vowpalwabbit/binary.cc @@ -1,13 +1,11 @@ +#include <float.h> #include "reductions.h" #include "multiclass.h" #include "simple_label.h" -using namespace LEARNER; - namespace BINARY { - template <bool is_learn> - void predict_or_learn(char&, base_learner& base, example& ec) { + void predict_or_learn(char&, LEARNER::base_learner& base, example& ec) { if (is_learn) base.learn(ec); else @@ -18,13 +16,19 @@ namespace BINARY { else ec.pred.scalar = -1; - if (ec.l.simple.label == ec.pred.scalar) - ec.loss = 0.; - else - ec.loss = ec.l.simple.weight; + if (ec.l.simple.label != FLT_MAX) + { + if (fabs(ec.l.simple.label) != 1.f) + cout << "You are using a label not -1 or 1 with a loss function expecting that!" << endl; + else + if (ec.l.simple.label == ec.pred.scalar) + ec.loss = 0.; + else + ec.loss = ec.l.simple.weight; + } } - base_learner* setup(vw& all, po::variables_map& vm) +LEARNER::base_learner* setup(vw& all, po::variables_map& vm) {//parse and set arguments po::options_description opts("Binary options"); opts.add_options() @@ -33,11 +37,9 @@ namespace BINARY { if(!vm.count("binary")) return NULL; - all.sd->binary_label = true; //Create new learner - learner<char>& ret = init_learner<char>(NULL, all.l); - ret.set_learn(predict_or_learn<true>); - ret.set_predict(predict_or_learn<false>); + LEARNER::learner<char>& ret = + LEARNER::init_learner<char>(NULL, all.l, predict_or_learn<true>, predict_or_learn<false>); return make_base(ret); } } |