Welcome to mirror list, hosted at ThFree Co, Russian Federation.

scorer.cc « vowpalwabbit - github.com/moses-smt/vowpal_wabbit.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: 51be45f2b0221d93bd5d7b0ed196026dfe6bdf8b (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
#include <float.h>
#include <cmath>

#include "reductions.h"

using namespace LEARNER;

namespace Scorer {
  // State for the scorer (link-function) reduction: only a back-pointer to the
  // vw instance it was built for. Set in setup() and read by predict_or_learn().
  struct scorer{
    vw* all;  // address of setup()'s `all`; supplies set_minmax, sd and loss
  };

  template <bool is_learn, float (*link)(float in)>
  void predict_or_learn(scorer& s, base_learner& base, example& ec)
  {
    s.all->set_minmax(s.all->sd, ec.l.simple.label);
    
    if (is_learn && ec.l.simple.label != FLT_MAX && ec.l.simple.weight > 0)
      base.learn(ec);
    else
      base.predict(ec);

    if(ec.l.simple.weight > 0 && ec.l.simple.label != FLT_MAX)
      ec.loss = s.all->loss->getLoss(s.all->sd, ec.pred.scalar, ec.l.simple.label) * ec.l.simple.weight;

    ec.pred.scalar = link(ec.pred.scalar);
  }

  // Logistic (sigmoid) link: y = 1 / (1 + e^-x), mapping any score into [0, 1]
  // (the ends are reached only by saturation for large |x|).
  float logistic(float in)
  {
    const auto decay = exp(-in);  // same type exp() returns, so no extra rounding
    return 1.f / (decay + 1.f);
  }

  // http://en.wikipedia.org/wiki/Generalized_logistic_curve
  // where the lower & upper asymptotes are -1 & 1 respectively
  // 'glf1' stands for 'Generalized Logistic Function with [-1,1] range'
  //    y = f(x) = 2 / (1 + e^-x) - 1  -> [-1, 1]
  //
  // Algebraically 2 / (1 + e^-x) - 1 == tanh(x / 2). Evaluating the left-hand
  // form subtracts two nearly equal values near x == 0 (catastrophic
  // cancellation). In float it yields exactly 0 for |x| below ~1e-7 instead
  // of ~x/2. tanh is accurate there, saturates to +/-1 for large |x|, and is
  // exactly odd-symmetric (f(-x) == -f(x)).
  float glf1(float in)
  {
    return std::tanh(0.5f * in);
  }

  // Identity link: the raw score is passed through unchanged.
  float noop(float in)
  {
    const float unchanged = in;
    return unchanged;
  }

  // Builds the scorer (link-function) reduction on top of all.l.
  //
  // Registers the --link option (identity | logistic | glf1, default
  // "identity"). It then installs predict_or_learn<is_learn, link> as the
  // learn/predict callbacks for the chosen link. For non-default links it
  // appends "--link=<name>" to all.file_options (presumably persisted with
  // the model -- confirm).
  //
  // Throws std::exception (after printing a message to cerr) if the link name
  // is not recognized. Validation happens before any allocation so that this
  // error path does not leak the scorer or learner.
  base_learner* setup(vw& all, po::variables_map& vm)
  {
    po::options_description opts("Link options");
    opts.add_options()
      ("link", po::value<string>()->default_value("identity"), "Specify the link function: identity, logistic or glf1");
    vm = add_options(all, opts);
    // "link" has a default_value, so it is always present after parsing.
    string link = vm["link"].as<string>();

    bool is_identity = link.compare("identity") == 0;
    bool is_logistic = link.compare("logistic") == 0;
    bool is_glf1 = link.compare("glf1") == 0;
    if (!is_identity && !is_logistic && !is_glf1)
      {
        cerr << "Unknown link function: " << link << endl;
        throw exception();
      }

    scorer& s = calloc_or_die<scorer>();
    s.all = &all;

    learner<scorer>& l = init_learner(&s, all.l);
    if (is_logistic)
      {
        *all.file_options << " --link=logistic ";
        l.set_learn(predict_or_learn<true, logistic> );
        l.set_predict(predict_or_learn<false, logistic>);
      }
    else if (is_glf1)
      {
        *all.file_options << " --link=glf1 ";
        l.set_learn(predict_or_learn<true, glf1>);
        l.set_predict(predict_or_learn<false, glf1>);
      }
    else // identity: the default, so nothing extra is recorded in file_options
      {
        l.set_learn(predict_or_learn<true, noop> );
        l.set_predict(predict_or_learn<false, noop> );
      }

    return make_base(l);
  }
}