diff options
author | John Langford <jl@hunch.net> | 2014-03-27 02:09:22 +0400 |
---|---|---|
committer | John Langford <jl@hunch.net> | 2014-03-27 02:09:22 +0400 |
commit | d639c86415e89d5df6f9df409e8bec4bb7570662 (patch) | |
tree | c6dace0c2cc47bbc9c1ae8a5f79c70adfc41d5e5 /vowpalwabbit/cb_algs.h | |
parent | c7da88a24d53e2189e807c65ed562034c21f7f4e (diff) |
version control cb_algs
Diffstat (limited to 'vowpalwabbit/cb_algs.h')
-rw-r--r-- | vowpalwabbit/cb_algs.h | 44 |
1 files changed, 44 insertions, 0 deletions
diff --git a/vowpalwabbit/cb_algs.h b/vowpalwabbit/cb_algs.h new file mode 100644 index 00000000..37096585 --- /dev/null +++ b/vowpalwabbit/cb_algs.h @@ -0,0 +1,44 @@ +/* +Copyright (c) by respective owners including Yahoo!, Microsoft, and +individual contributors. All rights reserved. Released under a BSD +license as described in the file LICENSE. + */ +#ifndef CB_ALGS_H +#define CB_ALGS_H + +//TODO: extend to handle CSOAA_LDF and WAP_LDF +namespace CB_ALGS { + + LEARNER::learner* setup(vw& all, std::vector<std::string>&, po::variables_map& vm, po::variables_map& vm_file); + + template <bool is_learn> + float get_cost_pred(vw& all, CB::cb_class* known_cost, example& ec, uint32_t index, uint32_t base) + { + CB::label* ld = (CB::label*)ec.ld; + + label_data simple_temp; + simple_temp.initial = 0.; + if (known_cost != NULL && index == known_cost->action) + { + simple_temp.label = known_cost->cost; + simple_temp.weight = 1.; + } + else + { + simple_temp.label = FLT_MAX; + simple_temp.weight = 0.; + } + + ec.ld = &simple_temp; + + if (is_learn) + all.scorer->learn(ec, index-1+base); + else + all.scorer->predict(ec, index-1+base); + ec.ld = ld; + + return ec.final_prediction; + } +} + +#endif |