Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/moses-smt/vowpal_wabbit.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'vowpalwabbit/binary.cc')
-rw-r--r--vowpalwabbit/binary.cc28
1 files changed, 15 insertions, 13 deletions
diff --git a/vowpalwabbit/binary.cc b/vowpalwabbit/binary.cc
index 7eac0ba8..9951ddd3 100644
--- a/vowpalwabbit/binary.cc
+++ b/vowpalwabbit/binary.cc
@@ -1,13 +1,11 @@
+#include <float.h>
#include "reductions.h"
#include "multiclass.h"
#include "simple_label.h"
-using namespace LEARNER;
-
namespace BINARY {
-
template <bool is_learn>
- void predict_or_learn(char&, base_learner& base, example& ec) {
+ void predict_or_learn(char&, LEARNER::base_learner& base, example& ec) {
if (is_learn)
base.learn(ec);
else
@@ -18,13 +16,19 @@ namespace BINARY {
else
ec.pred.scalar = -1;
- if (ec.l.simple.label == ec.pred.scalar)
- ec.loss = 0.;
- else
- ec.loss = ec.l.simple.weight;
+ if (ec.l.simple.label != FLT_MAX)
+ {
+ if (fabs(ec.l.simple.label) != 1.f)
+ cout << "You are using a label not -1 or 1 with a loss function expecting that!" << endl;
+ else
+ if (ec.l.simple.label == ec.pred.scalar)
+ ec.loss = 0.;
+ else
+ ec.loss = ec.l.simple.weight;
+ }
}
- base_learner* setup(vw& all, po::variables_map& vm)
+LEARNER::base_learner* setup(vw& all, po::variables_map& vm)
{//parse and set arguments
po::options_description opts("Binary options");
opts.add_options()
@@ -33,11 +37,9 @@ namespace BINARY {
if(!vm.count("binary"))
return NULL;
- all.sd->binary_label = true;
//Create new learner
- learner<char>& ret = init_learner<char>(NULL, all.l);
- ret.set_learn(predict_or_learn<true>);
- ret.set_predict(predict_or_learn<false>);
+ LEARNER::learner<char>& ret =
+ LEARNER::init_learner<char>(NULL, all.l, predict_or_learn<true>, predict_or_learn<false>);
return make_base(ret);
}
}