Welcome to mirror list, hosted at ThFree Co, Russian Federation.

search.h « vowpalwabbit - github.com/moses-smt/vowpal_wabbit.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: 8c89a37306821caf885902c8a08d566942120c9e (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
/*
Copyright (c) by respective owners including Yahoo!, Microsoft, and
individual contributors. All rights reserved.  Released under a BSD
license as described in the file LICENSE.
*/
#pragma once
#include "global_data.h"

// Debug-output switch. `cdbg` is first defined as plain `clog`, then
// immediately undefined and redefined so that the stream expression ends up
// in the never-taken `else` branch of `if (1) {} else`, which disables the
// debug output.
#define cdbg clog
#undef cdbg
#define cdbg if (1) {} else clog
// comment out the previous two lines (the #undef and the second #define) if you want loads of debug output :)

// 32-bit unsigned aliases used throughout the declarations in this header:
//   action - an action id (e.g. the return type of Search::search::predict)
//   ptag   - a tag naming a prediction (e.g. the my_tag / condition_on arguments)
typedef uint32_t    action;
typedef uint32_t    ptag;

namespace Search {
  // forward declarations: search_private is only referenced through a pointer
  // (search::priv) in this header; search_task is defined further down
  struct search_private;
  struct search_task;

  // option flags declared here and defined elsewhere; search::set_options()
  // expects an or ("|") of these
  extern uint32_t AUTO_CONDITION_FEATURES, AUTO_HAMMING_LOSS, EXAMPLES_DONT_CHANGE, IS_LDF;

  struct search {
    // ===== INTERFACE (what a task implementation is expected to call) =====

    // Task-specific data that you want to keep on the heap: hand a pointer
    // over with set_task_data and read it back with get_task_data. The
    // pointer is stored untyped (see task_data below), so use the same T for
    // both calls.
    template<class T> void set_task_data(T* data) {
      task_data = data;
    }
    template<class T> T* get_task_data() {
      return static_cast<T*>(task_data);
    }

    // Set programmatic options during initialization. opts should be an
    // or ("|") of AUTO_CONDITION_FEATURES, etc.
    void set_options(uint32_t opts);

    // Replace the default label parser. If you do, you _must_ also supply
    // is_test so that test examples can be detected.
    void set_label_parser(label_parser& lp, bool (*is_test)(void*));

    // Add command-line options.
    void add_program_options(po::variables_map& vw, po::options_description& opts);

    // Explicitly declare a loss, incrementally.
    void loss(float incr_loss);

    // Make a prediction on an example; returns the predicted action.
    //
    //   ec                   the example (features) to predict on
    //   my_tag               a tag for this prediction, so that future
    //                          predictions can state explicitly that they
    //                          depend (explicitly or implicitly) on this one
    //   oracle_actions       array of actions the oracle would take;
    //                          NULL => the oracle doesn't know (is random!)
    //   oracle_actions_cnt   length of oracle_actions, or 0 if it is NULL
    //                          NOTE(review): the declared default is 1, not 0 --
    //                          confirm what callers passing NULL are expected to pass
    //   condition_on         array of previous (or future) predictions that
    //                          this prediction depends on. Conditioning means:
    //                          IF the predictions for all the tags in
    //                          condition_on were the same, then the prediction
    //                          for _this_ example will also be the same (same
    //                          features, etc., and assuming the same policy).
    //                          With AUTO_CONDITION_FEATURES on, features are
    //                          added to ec automatically based on what you are
    //                          conditioning on. NULL => independent prediction
    //   condition_on_names   string listing the names of the features you are
    //                          conditioning on; used explicitly for auditing
    //                          and implicitly for keeping tags separated.
    //                          strlen(condition_on_names) also tells us how
    //                          long condition_on is, so it should equal
    //                          |condition_on|
    //   allowed_actions      array of actions allowed at this step, or NULL if
    //                          everything is allowed
    //   allowed_actions_cnt  length of allowed_actions
    //   learner_id           id of the underlying learner to use (via
    //                          set_num_learners)
    action predict(example& ec,
                   ptag my_tag,
                   const action* oracle_actions,
                   size_t oracle_actions_cnt = 1,
                   const ptag* condition_on = NULL,
                   const char* condition_on_names = NULL,
                   const action* allowed_actions = NULL,
                   size_t allowed_actions_cnt = 0,
                   size_t learner_id = 0);

    // Make an LDF prediction on a list of examples. Arguments match
    // predict(...) except that:
    //   * ecs/ec_cnt replace ec: ecs is the list of examples that make up a
    //     single LDF example, and ec_cnt is its length
    //   * there are no "allowed_actions" any more, because that is implicit
    //     in the LDF example structure
    action predictLDF(example* ecs,
                      size_t ec_cnt,
                      ptag my_tag,
                      const action* oracle_actions,
                      size_t oracle_actions_cnt = 1,
                      const ptag* condition_on = NULL,
                      const char* condition_on_names = NULL,
                      size_t learner_id = 0);

    // Sometimes during training a call to "predict" doesn't actually use the
    // example you pass (*), and for efficiency you may want to skip building
    // examples in those cases. If predictNeedsExample() returns true, any
    // subsequent call to predict must include correctly processed examples;
    // if it returns false, you can pass anything to the next call to predict.
    //
    // (*) predictLDF is a slight exception: some examples must always be
    // provided so that we know which actions are possible. In LDF mode, when
    // predictNeedsExample() returns false it is okay to provide just the
    // labels in the subsequent predictLDF() call and skip the feature values.
    bool predictNeedsExample();

    // The value specified by --search_history_length.
    uint32_t get_history_length();

    // Whether the user declared ldf mode.
    bool is_ldf();

    // Where you should write output.
    std::stringstream& output();

    // Set the number of learners.
    void set_num_learners(size_t num_learners);

    // Get the action sequence from the test run (only run if test_only or -t or...).
    void get_test_action_sequence(vector<action>&);

    // Get the feature index mask.
    size_t get_mask();

    // Get stride_shift.
    size_t get_stride_shift();

    // ===== data members =====
    search_private* priv;       // internal data that you don't get to see!
    void*           task_data;  // your task data! (see set_task_data / get_task_data)
    const char*     task_name;
  };

  // for defining new tasks, you must fill out a search_task.
  // every hook is a plain function pointer that receives the search object.
  struct search_task {
    // required
    const char* task_name;
    // the task body; receives the search object and a vector of example pointers
    void (*run)(search&, std::vector<example*>&);

    // optional
    // NOTE(review): the meaning of the size_t& argument is not documented in
    // this header -- confirm against the implementation
    void (*initialize)(search&,size_t&, po::variables_map&);
    void (*finish)(search&);
    // NOTE(review): judging by the names, these presumably bracket a call to
    // run on the same example vector -- confirm against the implementation
    void (*run_setup)(search&, std::vector<example*>&);
    void (*run_takedown)(search&, std::vector<example*>&);
  };

  // to make calls to "predict" (and "predictLDF") cleaner when you
  // want to use crazy combinations of arguments.
  // construct one with a search object and a tag, configure it through the
  // set_*/add_* members (each returns predictor&, presumably so that calls can
  // be chained), then call predict().
  class predictor {
    public:
    predictor(search& sch, ptag my_tag);
    ~predictor();

    // tell the predictor what to use as input. a single example input
    // means non-LDF mode; an array of inputs means LDF mode
    predictor& set_input(example& input_example);
    predictor& set_input(example* input_example, size_t input_length);    // if you're lucky and have an array of examples

    // the following is mostly to make life manageable for the Python interface
    void set_input_length(size_t input_length);  // declare that we have an input_length-long LDF example
    void set_input_at(size_t posn, example&input_example); // set the corresponding input (*after* set_input_length)

    // different ways of adding to the list of oracle actions. you can
    // either add_ or set_; setting erases previous actions. these
    // functions attempt to allocate as little memory as possible, so if
    // you pass a v_array or an action*, unless you later add something
    // else, we'll just store a pointer to your memory. this means that
    // you probably shouldn't change the data there, or free that pointer,
    // between calling add/set_oracle and calling predict()
    predictor& erase_oracles();

    predictor& add_oracle(action a);
    predictor& add_oracle(action*a, size_t action_count);
    predictor& add_oracle(v_array<action>& a);

    predictor& set_oracle(action a);
    predictor& set_oracle(action*a, size_t action_count);
    predictor& set_oracle(v_array<action>& a);
    
    // same as add/set_oracle but for allowed actions
    predictor& erase_alloweds();

    predictor& add_allowed(action a);
    predictor& add_allowed(action*a, size_t action_count);
    predictor& add_allowed(v_array<action>& a);
    
    predictor& set_allowed(action a);
    predictor& set_allowed(action*a, size_t action_count);
    predictor& set_allowed(v_array<action>& a);

    // add a tag to condition on with a name, or set the conditioning
    // variables (i.e., erase previous ones)
    predictor& add_condition(ptag tag, char name);
    predictor& set_condition(ptag tag, char name);
    predictor& add_condition_range(ptag hi, ptag count, char name0); // add (hi,name0), (hi-1,name0+1), ..., (hi-count,name0+count)
    predictor& set_condition_range(ptag hi, ptag count, char name0); // set (hi,name0), (hi-1,name0+1), ..., (hi-count,name0+count)

    // set learner id
    predictor& set_learner_id(size_t id);

    // change the current tag
    predictor& set_tag(ptag tag);

    // make a prediction
    action predict();
    
    private:
    bool is_ldf;       // LDF vs. non-LDF mode (see set_input above)
    ptag my_tag;       // tag for this prediction (see set_tag)
    example* ec;       // the input example(s)
    size_t ec_cnt;     // number of input examples
    bool ec_alloced;   // NOTE(review): presumably true when ec was allocated by this object (see free_ec) -- confirm
    v_array<action> oracle_actions;    bool oracle_is_pointer;   // if we're pointing to your memory TRUE; if it's our own memory FALSE
    v_array<ptag> condition_on_tags;   // tags added via add/set_condition*
    v_array<char> condition_on_names;  // names added via add/set_condition*
    v_array<action> allowed_actions;   bool allowed_is_pointer;  // if we're pointing to your memory TRUE; if it's our own memory FALSE
    size_t learner_id;
    search&sch;        // the search object passed to the constructor

    // internal helpers behind the public add_*/set_* members
    // NOTE(review): exact semantics live in the implementation file -- confirm there
    void make_new_pointer(v_array<action>& A, size_t new_size);
    predictor& add_to(v_array<action>& A, bool& A_is_ptr, action a, bool clear_first);
    predictor& add_to(v_array<action>&A, bool& A_is_ptr, action*a, size_t action_count, bool clear_first);
    void free_ec();
    
    // prevent the user from doing something stupid :) ... ugh needed to turn this off for python :(
    // NOTE(review): with these left commented out the class stays copyable
    // even though it declares a destructor and tracks *_is_pointer ownership
    // flags -- be careful when copying predictors
    //predictor(const predictor&P);
    //predictor&operator=(const predictor&P);
  };
  
  // some helper functions you might find helpful
  // check_option: the templated overload takes an equality predicate plus
  // separate "mismatch" and "required" error strings; the bool overload takes
  // only a mismatch error string. the result is delivered through ret.
  // NOTE(review): the exact command-line vs. saved-value semantics of
  // default_to_cmdline are not visible in this header -- confirm in the implementation
  template<class T> void check_option(T& ret, vw&all, po::variables_map& vm, const char* opt_name, bool default_to_cmdline, bool(*equal)(T,T), const char* mismatch_error_string, const char* required_error_string);
  void check_option(bool& ret, vw&all, po::variables_map& vm, const char* opt_name, bool default_to_cmdline, const char* mismatch_error_string);
  // by-value equality predicates; their signatures match the bool(*equal)(T,T)
  // parameter of check_option for string, float, uint32_t and size_t
  bool string_equal(string a, string b);
  bool float_equal(float a, float b);
  bool uint32_equal(uint32_t a, uint32_t b);
  bool size_equal(size_t a, size_t b);
  
  // our interface within VW
  // setup returns a LEARNER::learner* built from the vw object and the parsed
  // command-line variables; the remaining entry points take an untyped void*
  // NOTE(review): the void* is presumably the search (or vw) object -- confirm
  // against the implementation
  LEARNER::learner* setup(vw&, po::variables_map&);
  void search_finish(void*);
  void search_drive(void*);
  void search_learn(void*,example*);  
}