Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/moses-smt/vowpal_wabbit.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author Hal Daume III <me@hal3.name> 2014-09-18 01:19:57 +0400
committer Hal Daume III <me@hal3.name> 2014-09-18 01:19:57 +0400
commit cb2f619a8a0dfdd4280c390e2aa97dd709c1d68e (patch)
tree 19651c4bc5f68ba04f7a86687658ce529cd24d28 /library
parent 169a4e54e89b319ef560afd4bc31c0b02f7b1832 (diff)
parent dbcf40d7f3acbd29ccebe9f903e2beaec4f9b854 (diff)
Merge branch 'master' of github.com:hal3/vowpal_wabbit
Diffstat (limited to 'library')
-rw-r--r-- library/libsearn.h 72
-rw-r--r-- library/test_search.cc 85
2 files changed, 157 insertions, 0 deletions
diff --git a/library/libsearn.h b/library/libsearn.h
new file mode 100644
index 00000000..1fbd7f67
--- /dev/null
+++ b/library/libsearn.h
@@ -0,0 +1,72 @@
+#/*
+Copyright (c) by respective owners including Yahoo!, Microsoft, and
+individual contributors. All rights reserved. Released under a BSD
+license as described in the file LICENSE.
+ */
+#ifndef LIBSEARN_HOOKTASK_H
+#define LIBSEARN_HOOKTASK_H
+
+#include "../vowpalwabbit/parser.h"
+#include "../vowpalwabbit/vw.h"
+#include "../vowpalwabbit/searn.h"
+#include "../vowpalwabbit/searn_hooktask.h"
+
+using namespace std;
+
+// Adapter between user code and VW's 'hook' search task: derive from it, override _run(), then call learn()/predict().
+template<class INPUT, class OUTPUT> class SearchTask {
+ public:
+  // Reads the two driver examples and registers this object (and _searn_run_fn) in the hook task's task_data.
+  SearchTask(vw& vw_obj) : vw_obj(vw_obj), srn(*(Searn::searn*)vw_obj.searnstr) {
+    bogus_example = VW::read_example(vw_obj, (char*)"1 | x");
+    blank_line = VW::read_example(vw_obj, (char*)"");
+    VW::finish_example(vw_obj, blank_line);  // NOTE(review): blank_line is still passed to vw_obj.learn() later -- confirm this is safe
+    HookTask::task_data* d = srn.get_task_data<HookTask::task_data>();
+    d->run_f = _searn_run_fn;
+    d->run_object = this;
+    d->var_map = NULL; // TODO
+    //d->num_actions = num_actions; // TODO
+    d->extra_data = NULL;
+    d->extra_data2 = NULL;
+  }
+  // Virtual because the class is meant to be derived from; finishes the one example not yet finished (bogus_example).
+  virtual ~SearchTask() { VW::finish_example(vw_obj, bogus_example); }
+
+  // The structured prediction routine that subclasses override; this default does nothing.
+  virtual void _run(Searn::searn&srn, INPUT& input_example, OUTPUT& output) {} // YOU MUST DEFINE THIS FUNCTION!
+
+  // Feeds VW the driver examples with bogus_example->test_only == false; input_example and output are what _run() receives.
+  void learn(INPUT& input_example, OUTPUT& output) { _drive(input_example, output, false); }
+  // Same as learn(), but with bogus_example->test_only == true.
+  void predict(INPUT& input_example, OUTPUT& output) { _drive(input_example, output, true); }
+
+ protected:
+  vw& vw_obj;
+  Searn::searn& srn;
+
+ private:
+  example* bogus_example, *blank_line;
+
+  // Shared body of learn()/predict(): stash pointers to the caller's objects in task_data, then feed VW both driver examples.
+  void _drive(INPUT& input_example, OUTPUT& output, bool test_only) {
+    HookTask::task_data* d = srn.template get_task_data<HookTask::task_data> (); // ugly calling convention :(
+    bogus_example->test_only = test_only;
+    d->extra_data = (void*)&input_example;
+    d->extra_data2 = (void*)&output;
+    vw_obj.learn(bogus_example);
+    vw_obj.learn(blank_line); // this will cause our _searn_run_fn hook to get called
+  }
+
+  // Callback installed as task_data::run_f: recovers the task object and the caller's input/output and forwards to _run().
+  static void _searn_run_fn(Searn::searn&srn) {
+    HookTask::task_data* d = srn.get_task_data<HookTask::task_data>();
+    if ((d->run_object == NULL) || (d->extra_data == NULL) || (d->extra_data2 == NULL)) {
+      cerr << "error: calling _searn_run_fn without setting run object" << endl;
+      throw exception();
+    }
+    ((SearchTask*)d->run_object)->_run(srn, *(INPUT*)d->extra_data, *(OUTPUT*)d->extra_data2);
+  }
+};
+
+
+#endif
diff --git a/library/test_search.cc b/library/test_search.cc
new file mode 100644
index 00000000..37553562
--- /dev/null
+++ b/library/test_search.cc
@@ -0,0 +1,85 @@
+#include <stdio.h>
+#include "../vowpalwabbit/vw.h"
+#include "../vowpalwabbit/ezexample.h"
+#include "libsearn.h"
+
+struct wt {  // one token of a sentence: a word paired with its tag id
+  string   word;  // the word's text
+  uint32_t tag;   // numeric tag id for this word
+  wt(string word_in, uint32_t tag_in) : word(word_in), tag(tag_in) {}
+};
+
+// Example task: walks a sentence word by word and records what srn.predict returns for each word.
+class SequenceLabelerTask : public SearchTask< vector<wt>, vector<uint32_t> > {
+ public:
+  SequenceLabelerTask(vw& vw_obj)
+    : SearchTask< vector<wt>, vector<uint32_t> >(vw_obj) {  // the parent constructor must run!
+    srn.set_options( Searn::AUTO_HAMMING_LOSS | Searn::AUTO_HISTORY );
+  }
+
+  // vanilla vw interface: build one text example per word and finish it right after predicting
+  void _run(Searn::searn& srn, vector<wt> & input_example, vector<uint32_t> & output) {
+    output.clear();
+    for (vector<wt>::iterator tok = input_example.begin(); tok != input_example.end(); ++tok) {
+      example* word_ex = VW::read_example(vw_obj, "1 |w " + tok->word);
+      const uint32_t label = srn.predict(word_ex, tok->tag);
+      VW::finish_example(vw_obj, word_ex);
+      output.push_back(label);
+    }
+  }
+
+  // ezexample interface: same loop, with the example built by an ezexample local to each iteration
+  void _run2(Searn::searn& srn, vector<wt> & input_example, vector<uint32_t> & output) {
+    output.clear();
+    for (vector<wt>::iterator tok = input_example.begin(); tok != input_example.end(); ++tok) {
+      ezexample ex(&vw_obj);
+      ex(vw_namespace('w'))(tok->word);  // add the feature
+      output.push_back(srn.predict(ex.get(), tok->tag));
+    }
+  }
+
+};
+
+int main(int argc, char *argv[]) {
+  // bring VW up in the usual way, selecting 'hook' as the search_task
+  vw& vw_obj = *VW::initialize("--search 4 --quiet --search_task hook --search_no_snapshot --ring_size 1024");
+
+  {
+    // The task gets a scope of its own so that its destructor has run
+    // by the time VW::finish is called; the original author reports a
+    // segfault otherwise and no known way around it.
+    SequenceLabelerTask task(vw_obj);
+    vector<wt> data;
+    vector<uint32_t> output;
+    const uint32_t DET = 1, NOUN = 2, VERB = 3, ADJ = 4;
+    // toy sentence and its gold tags, kept as parallel tables
+    const char*    words[] = { "the", "monster", "ate", "a", "big", "sandwich" };
+    const uint32_t tags[]  = { DET, NOUN, VERB, DET, ADJ, NOUN };
+    for (size_t w = 0; w < sizeof(words) / sizeof(words[0]); ++w) data.push_back( wt(words[w], tags[w]) );
+
+    // three learning passes over the sentence, then a single prediction pass
+    for (int pass = 0; pass < 3; ++pass) task.learn(data, output);
+    task.predict(data, output);
+
+    cerr << "output = [";
+    for (vector<uint32_t>::const_iterator p = output.begin(); p != output.end(); ++p) cerr << " " << *p;
+    cerr << " ]" << endl;
+  }
+
+  // shut VW down only after the task above has been destroyed
+  VW::finish(vw_obj);
+}
+
+
+
+
+
+
+
+
+
+
+
+
+
+