diff options
author | Stephane Ross <stephaneross@cmu.edu> | 2012-07-23 14:26:02 +0400 |
---|---|---|
committer | Stephane Ross <stephaneross@cmu.edu> | 2012-07-23 14:26:02 +0400 |
commit | 3687d0e8cb61bfff8f9b153abf8d99d8415cc400 (patch) | |
tree | 48e017a7f86aee21ab770891bcfacce2f72a2540 /vowpalwabbit/cb.h | |
parent | 2d4af86e5bc3b76af71df242fd10129d8dc7e54e (diff) |
Addition of contextual bandit module, but buggy right now
Diffstat (limited to 'vowpalwabbit/cb.h')
-rw-r--r-- | vowpalwabbit/cb.h | 48 |
1 files changed, 48 insertions, 0 deletions
diff --git a/vowpalwabbit/cb.h b/vowpalwabbit/cb.h new file mode 100644 index 00000000..50c10ac7 --- /dev/null +++ b/vowpalwabbit/cb.h @@ -0,0 +1,48 @@ +#ifndef CB_H +#define CB_H + +#define CB_TYPE_DR 0 +#define CB_TYPE_DM 1 +#define CB_TYPE_IPS 2 + +#include "global_data.h" +#include "parser.h" + +//Contextual Bandit module to deal with incomplete cost-sensitive data +//Currently implemented as a reduction to cost-sensitive learning, using the methods discussed in the paper 'Doubly Robust Policy Evaluation and Learning'. + +//CB is currently made to work with CSOAA or WAP as base cs learner +//TODO: extend to handle CSOAA_LDF and WAP_LDF + +namespace CB { + + struct cb_class { // names are for compatibility with 'features' + float x; // the cost of this class + uint32_t weight_index; // the index of this class + float partial_prediction; // a partial prediction: new! + float prob_action; //new for bandit setting, specifies the probability our policy chose this class for importance weighting + bool operator==(cb_class j){return weight_index == j.weight_index;} + }; + + struct label { + v_array<cb_class> costs; + }; + + void parse_flags(vw& all, std::vector<std::string>&, po::variables_map& vm, size_t s); + + void output_example(vw& all, example* ec); + size_t read_cached_label(shared_data* sd, void* v, io_buf& cache); + void cache_label(void* v, io_buf& cache); + void default_label(void* v); + void parse_label(shared_data* sd, void* v, v_array<substring>& words); + void delete_label(void* v); + float weight(void* v); + float initial(void* v); + const label_parser cb_label_parser = {default_label, parse_label, + cache_label, read_cached_label, + delete_label, weight, initial, + sizeof(label)}; + +} + +#endif |