Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/moses-smt/vowpal_wabbit.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorStephane Ross <stephaneross@cmu.edu>2012-07-23 14:26:02 +0400
committerStephane Ross <stephaneross@cmu.edu>2012-07-23 14:26:02 +0400
commit3687d0e8cb61bfff8f9b153abf8d99d8415cc400 (patch)
tree48e017a7f86aee21ab770891bcfacce2f72a2540 /vowpalwabbit/cb.h
parent2d4af86e5bc3b76af71df242fd10129d8dc7e54e (diff)
Addition of contextual bandit module, but buggy right now
Diffstat (limited to 'vowpalwabbit/cb.h')
-rw-r--r--vowpalwabbit/cb.h48
1 files changed, 48 insertions, 0 deletions
diff --git a/vowpalwabbit/cb.h b/vowpalwabbit/cb.h
new file mode 100644
index 00000000..50c10ac7
--- /dev/null
+++ b/vowpalwabbit/cb.h
@@ -0,0 +1,48 @@
+#ifndef CB_H
+#define CB_H
+
+#define CB_TYPE_DR 0
+#define CB_TYPE_DM 1
+#define CB_TYPE_IPS 2
+
+#include "global_data.h"
+#include "parser.h"
+
+//Contextual Bandit module to deal with incomplete cost-sensitive data
+//Currently implemented as a reduction to cost-sensitive learning, using the methods discussed in the paper 'Doubly Robust Policy Evaluation and Learning'.
+
+//CB is currently made to work with CSOAA or WAP as base cs learner
+//TODO: extend to handle CSOAA_LDF and WAP_LDF
+
+namespace CB {
+
+ struct cb_class { // names are for compatibility with 'features'
+ float x; // the cost of this class
+ uint32_t weight_index; // the index of this class
+ float partial_prediction; // a partial prediction: new!
+ float prob_action; //new for bandit setting, specifies the probability our policy chose this class for importance weighting
+ bool operator==(cb_class j){return weight_index == j.weight_index;}
+ };
+
+ struct label {
+ v_array<cb_class> costs;
+ };
+
+ void parse_flags(vw& all, std::vector<std::string>&, po::variables_map& vm, size_t s);
+
+ void output_example(vw& all, example* ec);
+ size_t read_cached_label(shared_data* sd, void* v, io_buf& cache);
+ void cache_label(void* v, io_buf& cache);
+ void default_label(void* v);
+ void parse_label(shared_data* sd, void* v, v_array<substring>& words);
+ void delete_label(void* v);
+ float weight(void* v);
+ float initial(void* v);
+ const label_parser cb_label_parser = {default_label, parse_label,
+ cache_label, read_cached_label,
+ delete_label, weight, initial,
+ sizeof(label)};
+
+}
+
+#endif