Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/moses-smt/vowpal_wabbit.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJohn Langford <jl@hunch.net>2014-03-27 02:09:22 +0400
committerJohn Langford <jl@hunch.net>2014-03-27 02:09:22 +0400
commitd639c86415e89d5df6f9df409e8bec4bb7570662 (patch)
treec6dace0c2cc47bbc9c1ae8a5f79c70adfc41d5e5 /vowpalwabbit/cb_algs.h
parentc7da88a24d53e2189e807c65ed562034c21f7f4e (diff)
version control cb_algs
Diffstat (limited to 'vowpalwabbit/cb_algs.h')
-rw-r--r--vowpalwabbit/cb_algs.h44
1 files changed, 44 insertions, 0 deletions
diff --git a/vowpalwabbit/cb_algs.h b/vowpalwabbit/cb_algs.h
new file mode 100644
index 00000000..37096585
--- /dev/null
+++ b/vowpalwabbit/cb_algs.h
@@ -0,0 +1,44 @@
+/*
+Copyright (c) by respective owners including Yahoo!, Microsoft, and
+individual contributors. All rights reserved. Released under a BSD
+license as described in the file LICENSE.
+ */
+#ifndef CB_ALGS_H
+#define CB_ALGS_H
+
+//TODO: extend to handle CSOAA_LDF and WAP_LDF
+namespace CB_ALGS {
+
+ LEARNER::learner* setup(vw& all, std::vector<std::string>&, po::variables_map& vm, po::variables_map& vm_file);
+
+ template <bool is_learn>
+ float get_cost_pred(vw& all, CB::cb_class* known_cost, example& ec, uint32_t index, uint32_t base)
+ {
+ CB::label* ld = (CB::label*)ec.ld;
+
+ label_data simple_temp;
+ simple_temp.initial = 0.;
+ if (known_cost != NULL && index == known_cost->action)
+ {
+ simple_temp.label = known_cost->cost;
+ simple_temp.weight = 1.;
+ }
+ else
+ {
+ simple_temp.label = FLT_MAX;
+ simple_temp.weight = 0.;
+ }
+
+ ec.ld = &simple_temp;
+
+ if (is_learn)
+ all.scorer->learn(ec, index-1+base);
+ else
+ all.scorer->predict(ec, index-1+base);
+ ec.ld = ld;
+
+ return ec.final_prediction;
+ }
+}
+
+#endif