Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/moses-smt/vowpal_wabbit.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/python
diff options
context:
space:
mode:
authorHal Daume III <me@hal3.name>2014-08-02 18:19:17 +0400
committerHal Daume III <me@hal3.name>2014-08-02 18:19:17 +0400
commitdba35f69c913639c417d7db18c15365c002f9516 (patch)
tree5f539c48c991d36163f27df3aaba1b002426ca76 /python
parentb9ad37c089a818c256d313f13b54191d7e653ae4 (diff)
re-fixed snapshot bug, changed type of structured_predict to remove copying
Diffstat (limited to 'python')
-rw-r--r--python/Makefile11
-rw-r--r--python/pylibvw.cc40
-rw-r--r--python/pyvw.py8
-rw-r--r--python/test_search.py9
4 files changed, 61 insertions, 7 deletions
diff --git a/python/Makefile b/python/Makefile
index fbb7f7fc..c06bd2ed 100644
--- a/python/Makefile
+++ b/python/Makefile
@@ -4,9 +4,9 @@
#
# If we don't know where to look for boost - it's a no go.
#
-ifeq ($(BOOST_LIBRARY),)
- $(error Please run 'make' at the top level only)
-endif
+#ifeq ($(BOOST_LIBRARY),)
+# $(error Please run 'make' at the top level only)
+#endif
PYTHON_VERSION = 2.7
PYTHON_INCLUDE = /usr/include/python$(PYTHON_VERSION)
@@ -15,7 +15,10 @@ PYTHON_LIBS =
VWLIBS = -L ../vowpalwabbit -l vw -l allreduce
STDLIBS = $(BOOST_LIBRARY) $(LIBS) $(PYTHONLIBS)
-all: pylibvw.so
+all:
+ cd ..; $(MAKE) python
+
+things: pylibvw.so
pylibvw.so: pylibvw.o ../vowpalwabbit/libvw.a
$(CXX) -shared -Wl,--export-dynamic pylibvw.o $(BOOST_LIBRARY) -lboost_python -L/usr/lib/python$(PYTHON_VERSION)/config -lpython$(PYTHON_VERSION) $(VWLIBS) $(STDLIBS) -o pylibvw.so
diff --git a/python/pylibvw.cc b/python/pylibvw.cc
index 031ea32f..41d60b4e 100644
--- a/python/pylibvw.cc
+++ b/python/pylibvw.cc
@@ -222,6 +222,41 @@ void set_structured_predict_hook(searn_ptr srn, py::object run_object) {
void my_set_test_only(example_ptr ec, bool val) { ec->test_only = val; }
+bool po_exists(searn_ptr srn, string arg) {
+ PythonTask::task_data* d = srn->get_task_data<PythonTask::task_data>();
+ return d->var_map.count(arg) > 0;
+}
+
+string po_get_string(searn_ptr srn, string arg) {
+ PythonTask::task_data* d = srn->get_task_data<PythonTask::task_data>();
+ return d->var_map[arg].as<string>();
+}
+
+int32_t po_get_int(searn_ptr srn, string arg) {
+ PythonTask::task_data* d = srn->get_task_data<PythonTask::task_data>();
+ try { return d->var_map[arg].as<int>(); } catch (...) {}
+ try { return d->var_map[arg].as<size_t>(); } catch (...) {}
+ try { return d->var_map[arg].as<uint32_t>(); } catch (...) {}
+ try { return d->var_map[arg].as<uint64_t>(); } catch (...) {}
+ try { return d->var_map[arg].as<uint16_t>(); } catch (...) {}
+ try { return d->var_map[arg].as<int32_t>(); } catch (...) {}
+ try { return d->var_map[arg].as<int64_t>(); } catch (...) {}
+ try { return d->var_map[arg].as<int16_t>(); } catch (...) {}
+ // we know this'll fail but do it anyway to get the exception
+ return d->var_map[arg].as<int>();
+}
+
+PyObject* po_get(searn_ptr srn, string arg) {
+ try {
+ return py::incref(py::object(po_get_string(srn, arg)).ptr());
+ } catch (...) {}
+ try {
+ return py::incref(py::object(po_get_int(srn, arg)).ptr());
+ } catch (...) {}
+ // return None
+ return py::incref(py::object().ptr());
+}
+
//BOOST_PYTHON_MEMBER_FUNCTION_OVERLOADS(searn_predict_overloads, Searn::searn::predict, 2, 3);
//BOOST_PYTHON_MEMBER_FUNCTION_OVERLOADS(searn_predictLDF_overloads, Searn::searn::predictLDF, 3, 4);
@@ -330,6 +365,11 @@ BOOST_PYTHON_MODULE(pylibvw) {
.def("get_num_actions", &searn_get_num_actions, "TODO")
.def("set_structured_predict_hook", &set_structured_predict_hook, "TODO")
+ .def("po_exists", &po_exists, "TODO")
+ .def("po_get", &po_get, "TODO")
+ .def("po_get_str", &po_get_string, "TODO")
+ .def("po_get_int", &po_get_int, "TODO")
+
.def_readonly("AUTO_HISTORY", Searn::AUTO_HISTORY, "TODO")
.def_readonly("AUTO_HAMMING_LOSS", Searn::AUTO_HAMMING_LOSS, "TODO")
.def_readonly("EXAMPLES_DONT_CHANGE", Searn::EXAMPLES_DONT_CHANGE, "TODO")
diff --git a/python/pyvw.py b/python/pyvw.py
index 81d6ecef..e9e7207d 100644
--- a/python/pyvw.py
+++ b/python/pyvw.py
@@ -28,11 +28,13 @@ class SearchTask():
self._call_vw(lambda: self._run(my_example), isTest=False)
def predict(self, my_example):
- def f(): self._output = self._run(my_example)
self._output = None
+ def f(): self._output = self._run(my_example)
self._call_vw(f, isTest=True)
- if self._output is None:
- raise Exception('structured predict hook failed to return anything')
+ #if self._output is None:
+ # raise Exception('structured predict hook failed to return anything')
+ # don't raise this exception because your _run code legitimately
+ # _could_ return None!
return self._output
class vw(pylibvw.vw):
diff --git a/python/test_search.py b/python/test_search.py
index 10bca70f..f01c6e8c 100644
--- a/python/test_search.py
+++ b/python/test_search.py
@@ -4,7 +4,16 @@ import pyvw
class SequenceLabeler(pyvw.SearchTask):
def __init__(self, vw, srn, num_actions):
# you must must must initialize the parent class
+ # this will automatically store self.srn <- srn, self.vw <- vw
pyvw.SearchTask.__init__(self, vw, srn, num_actions)
+
+ # you can test program options with srn.po_exists
+ # and get their values with srn.po_get -> string and
+ # srn.po_get_int -> int
+ if srn.po_exists('search'):
+ print 'found --search'
+ print '--search value =', srn.po_get('search'), ', type =', type(srn.po_get('search'))
+
# set whatever options you want
srn.set_options( srn.AUTO_HAMMING_LOSS | srn.AUTO_HISTORY )