diff options
author | Hal Daume III <me@hal3.name> | 2014-09-18 01:19:57 +0400 |
---|---|---|
committer | Hal Daume III <me@hal3.name> | 2014-09-18 01:19:57 +0400 |
commit | cb2f619a8a0dfdd4280c390e2aa97dd709c1d68e (patch) | |
tree | 19651c4bc5f68ba04f7a86687658ce529cd24d28 /library | |
parent | 169a4e54e89b319ef560afd4bc31c0b02f7b1832 (diff) | |
parent | dbcf40d7f3acbd29ccebe9f903e2beaec4f9b854 (diff) |
Merge branch 'master' of github.com:hal3/vowpal_wabbit
Diffstat (limited to 'library')
-rw-r--r-- | library/libsearn.h | 72 | ||||
-rw-r--r-- | library/test_search.cc | 85 |
2 files changed, 157 insertions, 0 deletions
diff --git a/library/libsearn.h b/library/libsearn.h
new file mode 100644
index 00000000..1fbd7f67
--- /dev/null
+++ b/library/libsearn.h
@@ -0,0 +1,72 @@
+/*
+Copyright (c) by respective owners including Yahoo!, Microsoft, and
+individual contributors. All rights reserved. Released under a BSD
+license as described in the file LICENSE.
+ */
+#ifndef LIBSEARN_HOOKTASK_H
+#define LIBSEARN_HOOKTASK_H
+
+#include "../vowpalwabbit/parser.h"
+#include "../vowpalwabbit/vw.h"
+#include "../vowpalwabbit/searn.h"
+#include "../vowpalwabbit/searn_hooktask.h"
+
+using namespace std;  // NOTE(review): header-scope using-directive; test_search.cc relies on it for string/vector/cerr
+// SearchTask<INPUT,OUTPUT>: derive from it, override _run(), then drive it through learn() / predict().
+template<class INPUT, class OUTPUT> class SearchTask {
+ public:
+  SearchTask(vw& vw_obj) : vw_obj(vw_obj), srn(*(Searn::searn*)vw_obj.searnstr) {
+    bogus_example = VW::read_example(vw_obj, (char*)"1 | x");
+    blank_line = VW::read_example(vw_obj, (char*)"");
+    VW::finish_example(vw_obj, blank_line);  // NOTE(review): blank_line is still passed to vw_obj.learn() below -- confirm this is safe
+    HookTask::task_data* d = srn.get_task_data<HookTask::task_data>();
+    d->run_f = _searn_run_fn;  // static trampoline that forwards to this->_run()
+    d->run_object = this;      // lets the trampoline recover this object
+    d->var_map = NULL; // TODO
+    //d->num_actions = num_actions; // TODO
+    d->extra_data = NULL;
+    d->extra_data2 = NULL;
+  }
+  virtual ~SearchTask() { VW::finish_example(vw_obj, bogus_example); }  // virtual: this class is a polymorphic base
+
+  virtual void _run(Searn::searn&srn, INPUT& input_example, OUTPUT& output) {} // YOU MUST DEFINE THIS FUNCTION!
+  // learn(): clears test_only, stores pointers to input/output in the task data, then feeds both examples to vw.
+  void learn(INPUT& input_example, OUTPUT& output) {
+    HookTask::task_data* d = srn.template get_task_data<HookTask::task_data> (); // ugly calling convention :(
+    bogus_example->test_only = false;
+    d->extra_data = (void*)&input_example;
+    d->extra_data2 = (void*)&output;
+    vw_obj.learn(bogus_example);
+    vw_obj.learn(blank_line); // this will cause our searn_run_fn hook to get called
+  }
+  // predict(): same sequence as learn(), but with test_only set on the bogus example.
+  void predict(INPUT& input_example, OUTPUT& output) {
+    HookTask::task_data* d = srn.template get_task_data<HookTask::task_data> (); // ugly calling convention :(
+    bogus_example->test_only = true;
+    d->extra_data = (void*)&input_example;
+    d->extra_data2 = (void*)&output;
+    vw_obj.learn(bogus_example);
+    vw_obj.learn(blank_line); // this will cause our searn_run_fn hook to get called
+  }
+
+
+ protected:
+  vw& vw_obj;
+  Searn::searn& srn;
+
+ private:
+  example* bogus_example, *blank_line;
+  // Hook stored in task_data::run_f: validates the stashed pointers, then calls _run() on the stored object.
+  static void _searn_run_fn(Searn::searn&srn) {
+    HookTask::task_data* d = srn.get_task_data<HookTask::task_data>();
+    if ((d->run_object == NULL) || (d->extra_data == NULL) || (d->extra_data2 == NULL)) {
+      cerr << "error: calling _searn_run_fn without setting run object" << endl;
+      throw exception();
+    }
+    ((SearchTask*)d->run_object)->_run(srn, *(INPUT*)d->extra_data, *(OUTPUT*)d->extra_data2);
+  }
+
+};
+
+
+#endif
diff --git a/library/test_search.cc b/library/test_search.cc
new file mode 100644
index 00000000..37553562
--- /dev/null
+++ b/library/test_search.cc
@@ -0,0 +1,85 @@
+#include <stdio.h>
+#include "../vowpalwabbit/vw.h"
+#include "../vowpalwabbit/ezexample.h"
+#include "libsearn.h"
+
+struct wt {  // one token: a word paired with its tag id
+  string word;
+  uint32_t tag;
+  wt(string w, uint32_t t) : word(w), tag(t) {}
+};
+
+class SequenceLabelerTask : public SearchTask< vector<wt>, vector<uint32_t> > {
+ public:
+  SequenceLabelerTask(vw& vw_obj)
+      : SearchTask< vector<wt>, vector<uint32_t> >(vw_obj) {  // must run parent constructor!
+    srn.set_options( Searn::AUTO_HAMMING_LOSS | Searn::AUTO_HISTORY );
+  }
+
+  // using vanilla vw interface: one example per word, one prediction pushed to output per word
+  void _run(Searn::searn& srn, vector<wt> & input_example, vector<uint32_t> & output) {
+    output.clear();
+    for (size_t i=0; i<input_example.size(); i++) {
+      example* ex = VW::read_example(vw_obj, "1 |w " + input_example[i].word);
+      uint32_t p = srn.predict(ex, input_example[i].tag);
+      VW::finish_example(vw_obj, ex);
+      output.push_back(p);
+    }
+  }
+
+  // using ezexample (alternative implementation; nothing in this file calls it)
+  void _run2(Searn::searn& srn, vector<wt> & input_example, vector<uint32_t> & output) {
+    output.clear();
+    for (size_t i=0; i<input_example.size(); i++) {
+      ezexample ex(&vw_obj);
+      ex(vw_namespace('w'))(input_example[i].word);  // add the feature
+      uint32_t p = srn.predict(ex.get(), input_example[i].tag);
+      output.push_back(p);
+    }
+  }
+
+};
+
+int main(int argc, char *argv[]) {
+  // initialize VW as usual, but use 'hook' as the search_task
+  vw& vw_obj = *VW::initialize("--search 4 --quiet --search_task hook --search_no_snapshot --ring_size 1024");
+
+  {
+    // we put this in its own scope so that its destructor gets called
+    // *before* VW::finish gets called; otherwise we'll get a
+    // segfault :(. not sure what to do about this :(.
+    SequenceLabelerTask task(vw_obj);
+    vector<wt> data;
+    vector<uint32_t> output;
+    uint32_t DET = 1, NOUN = 2, VERB = 3, ADJ = 4;
+    data.push_back( wt("the", DET) );
+    data.push_back( wt("monster", NOUN) );
+    data.push_back( wt("ate", VERB) );
+    data.push_back( wt("a", DET) );
+    data.push_back( wt("big", ADJ) );
+    data.push_back( wt("sandwich", NOUN) );
+    task.learn(data, output);
+    task.learn(data, output);
+    task.learn(data, output);
+    task.predict(data, output);
+    cerr << "output = [";
+    for (size_t i=0; i<output.size(); i++) cerr << " " << output[i];
+    cerr << " ]" << endl;
+  }
+
+  VW::finish(vw_obj);
+}
+
+
+
+
+
+
+
+
+
+
+
+
+
+