Welcome to mirror list, hosted at ThFree Co, Russian Federation.

cb_algs.h « vowpalwabbit - github.com/moses-smt/vowpal_wabbit.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: 0593fb6c4fa4f958f26768779752adc2e89f94cf (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
/*
Copyright (c) by respective owners including Yahoo!, Microsoft, and
individual contributors. All rights reserved.  Released under a BSD
license as described in the file LICENSE.
 */
#pragma once
//TODO: extend to handle CSOAA_LDF and WAP_LDF
namespace CB_ALGS {
  // Declared here, defined elsewhere. Presumably constructs the CB_ALGS
  // reduction learner from the configuration in `all` -- confirm in the
  // implementation file.
  LEARNER::base_learner* setup(vw& all);

  // Score action `index` of example `ec` with the shared scorer `all.scorer`,
  // optionally training it first.
  //
  //   is_learn   - compile-time switch; when true, the scorer's learn() is
  //                used for actions whose cost is known, predict() otherwise.
  //   known_cost - the observed (action, cost) pair, or NULL when no cost
  //                was observed for this example.
  //   index      - the action being scored. The scorer offset is
  //                index - 1 + base, so this presumably expects 1-based
  //                action ids (index == 0 would wrap) -- TODO confirm.
  //   base       - added to the offset passed to learn()/predict().
  //
  // Returns ec.pred.scalar as produced by the scorer call.
  //
  // While the scorer runs, ec.l holds a simple regression label instead of
  // the CB label. The CB label is saved up front and written back before
  // returning, so the caller sees ec.l.cb unchanged.
  template <bool is_learn>
  float get_cost_pred(vw& all, CB::cb_class* known_cost, example& ec, uint32_t index, uint32_t base)
  {
    // ec.l.cb and ec.l.simple are both written through ec.l. They presumably
    // share storage, which is why the CB label is copied aside here.
    CB::label const saved_cb_label = ec.l.cb;

    // The cost is "observed" only for the action the known cost refers to.
    bool const observed = known_cost != NULL && index == known_cost->action;

    // Unobserved actions get label FLT_MAX with zero weight. That makes them
    // inert for training, and the FLT_MAX marker disables learn() below.
    label_data regression_label;
    regression_label.initial = 0.;
    regression_label.label = observed ? known_cost->cost : FLT_MAX;
    regression_label.weight = observed ? 1. : 0.;
    ec.l.simple = regression_label;

    uint32_t const offset = index - 1 + base;

    // The gate deliberately checks the label value, not `observed`: a known
    // cost equal to FLT_MAX is treated as "no cost" and only predicted.
    if (is_learn && regression_label.label != FLT_MAX)
      all.scorer->learn(ec, offset);
    else
      all.scorer->predict(ec, offset);

    // Read the prediction before handing the CB label back to the caller.
    float const prediction = ec.pred.scalar;
    ec.l.cb = saved_cb_label;
    return prediction;
  }
}