diff options
author | Hal Daume III <me@hal3.name> | 2014-08-02 18:19:17 +0400 |
---|---|---|
committer | Hal Daume III <me@hal3.name> | 2014-08-02 18:19:17 +0400 |
commit | dba35f69c913639c417d7db18c15365c002f9516 (patch) | |
tree | 5f539c48c991d36163f27df3aaba1b002426ca76 /python | |
parent | b9ad37c089a818c256d313f13b54191d7e653ae4 (diff) |
re-fixed snapshot bug, changed type of structured_predict to remove copying
Diffstat (limited to 'python')
-rw-r--r-- | python/Makefile | 11 | ||||
-rw-r--r-- | python/pylibvw.cc | 40 | ||||
-rw-r--r-- | python/pyvw.py | 8 | ||||
-rw-r--r-- | python/test_search.py | 9 |
4 files changed, 61 insertions, 7 deletions
diff --git a/python/Makefile b/python/Makefile index fbb7f7fc..c06bd2ed 100644 --- a/python/Makefile +++ b/python/Makefile @@ -4,9 +4,9 @@ # # If we don't know where to look for boost - it's a no go. # -ifeq ($(BOOST_LIBRARY),) - $(error Please run 'make' at the top level only) -endif +#ifeq ($(BOOST_LIBRARY),) +# $(error Please run 'make' at the top level only) +#endif PYTHON_VERSION = 2.7 PYTHON_INCLUDE = /usr/include/python$(PYTHON_VERSION) @@ -15,7 +15,10 @@ PYTHON_LIBS = VWLIBS = -L ../vowpalwabbit -l vw -l allreduce STDLIBS = $(BOOST_LIBRARY) $(LIBS) $(PYTHONLIBS) -all: pylibvw.so +all: + cd ..; $(MAKE) python + +things: pylibvw.so pylibvw.so: pylibvw.o ../vowpalwabbit/libvw.a $(CXX) -shared -Wl,--export-dynamic pylibvw.o $(BOOST_LIBRARY) -lboost_python -L/usr/lib/python$(PYTHON_VERSION)/config -lpython$(PYTHON_VERSION) $(VWLIBS) $(STDLIBS) -o pylibvw.so diff --git a/python/pylibvw.cc b/python/pylibvw.cc index 031ea32f..41d60b4e 100644 --- a/python/pylibvw.cc +++ b/python/pylibvw.cc @@ -222,6 +222,41 @@ void set_structured_predict_hook(searn_ptr srn, py::object run_object) { void my_set_test_only(example_ptr ec, bool val) { ec->test_only = val; } +bool po_exists(searn_ptr srn, string arg) { + PythonTask::task_data* d = srn->get_task_data<PythonTask::task_data>(); + return d->var_map.count(arg) > 0; +} + +string po_get_string(searn_ptr srn, string arg) { + PythonTask::task_data* d = srn->get_task_data<PythonTask::task_data>(); + return d->var_map[arg].as<string>(); +} + +int32_t po_get_int(searn_ptr srn, string arg) { + PythonTask::task_data* d = srn->get_task_data<PythonTask::task_data>(); + try { return d->var_map[arg].as<int>(); } catch (...) {} + try { return d->var_map[arg].as<size_t>(); } catch (...) {} + try { return d->var_map[arg].as<uint32_t>(); } catch (...) {} + try { return d->var_map[arg].as<uint64_t>(); } catch (...) {} + try { return d->var_map[arg].as<uint16_t>(); } catch (...) {} + try { return d->var_map[arg].as<int32_t>(); } catch (...) {} + try { return d->var_map[arg].as<int64_t>(); } catch (...) {} + try { return d->var_map[arg].as<int16_t>(); } catch (...) {} + // we know this'll fail but do it anyway to get the exception + return d->var_map[arg].as<int>(); +} + +PyObject* po_get(searn_ptr srn, string arg) { + try { + return py::incref(py::object(po_get_string(srn, arg)).ptr()); + } catch (...) {} + try { + return py::incref(py::object(po_get_int(srn, arg)).ptr()); + } catch (...) {} + // return None + return py::incref(py::object().ptr()); +} + //BOOST_PYTHON_MEMBER_FUNCTION_OVERLOADS(searn_predict_overloads, Searn::searn::predict, 2, 3); //BOOST_PYTHON_MEMBER_FUNCTION_OVERLOADS(searn_predictLDF_overloads, Searn::searn::predictLDF, 3, 4); @@ -330,6 +365,11 @@ BOOST_PYTHON_MODULE(pylibvw) { .def("get_num_actions", &searn_get_num_actions, "TODO") .def("set_structured_predict_hook", &set_structured_predict_hook, "TODO") + .def("po_exists", &po_exists, "TODO") + .def("po_get", &po_get, "TODO") + .def("po_get_str", &po_get_string, "TODO") + .def("po_get_int", &po_get_int, "TODO") + .def_readonly("AUTO_HISTORY", Searn::AUTO_HISTORY, "TODO") .def_readonly("AUTO_HAMMING_LOSS", Searn::AUTO_HAMMING_LOSS, "TODO") .def_readonly("EXAMPLES_DONT_CHANGE", Searn::EXAMPLES_DONT_CHANGE, "TODO") diff --git a/python/pyvw.py b/python/pyvw.py index 81d6ecef..e9e7207d 100644 --- a/python/pyvw.py +++ b/python/pyvw.py @@ -28,11 +28,13 @@ class SearchTask(): self._call_vw(lambda: self._run(my_example), isTest=False) def predict(self, my_example): - def f(): self._output = self._run(my_example) self._output = None + def f(): self._output = self._run(my_example) self._call_vw(f, isTest=True) - if self._output is None: - raise Exception('structured predict hook failed to return anything') + #if self._output is None: + # raise Exception('structured predict hook failed to return anything') + # don't raise this exception because your _run code legitimately + # _could_ return None! return self._output class vw(pylibvw.vw): diff --git a/python/test_search.py b/python/test_search.py index 10bca70f..f01c6e8c 100644 --- a/python/test_search.py +++ b/python/test_search.py @@ -4,7 +4,16 @@ import pyvw class SequenceLabeler(pyvw.SearchTask): def __init__(self, vw, srn, num_actions): # you must must must initialize the parent class + # this will automatically store self.srn <- srn, self.vw <- vw pyvw.SearchTask.__init__(self, vw, srn, num_actions) + + # you can test program options with srn.po_exists + # and get their values with srn.po_get -> string and + # srn.po_get_int -> int + if srn.po_exists('search'): + print 'found --search' + print '--search value =', srn.po_get('search'), ', type =', type(srn.po_get('search')) + # set whatever options you want srn.set_options( srn.AUTO_HAMMING_LOSS | srn.AUTO_HISTORY ) |