diff options
author | Hal Daume III <me@hal3.name> | 2014-11-04 18:55:38 +0300 |
---|---|---|
committer | Hal Daume III <me@hal3.name> | 2014-11-04 18:55:38 +0300 |
commit | f2afaef7554c0a7351cca7c5bb7b8c6be25f74e6 (patch) | |
tree | 7a18a244a7eeb79af9f28de06973250052db772a /python/pyvw.py | |
parent | 8fce7d0ab9218a8f33678534b01d21d227a61505 (diff) |
allow lambdas for examples for lazy computation
Diffstat (limited to 'python/pyvw.py')
-rw-r--r-- | python/pyvw.py | 47 |
1 files changed, 37 insertions, 10 deletions
diff --git a/python/pyvw.py b/python/pyvw.py
index d1a8fa90..3f8caa9c 100644
--- a/python/pyvw.py
+++ b/python/pyvw.py
@@ -90,8 +90,13 @@ class vw(pylibvw.vw):
             """The basic (via-reduction) prediction mechanism. Several
             variants are supported through this overloaded function:

-                'examples' can be a single example (interpreted as non-LDF
-                   mode) or a list of examples (interpreted as LDF mode).
+                'examples' can be a single example (interpreted as
+                   non-LDF mode) or a list of examples (interpreted as
+                   LDF mode). it can also be a lambda function that
+                   returns a single example or list of examples, and in
+                   that list, each element can also be a lambda function
+                   that returns an example. this is done for lazy
+                   example construction (aka speed).

                 'my_tag' should be an integer id, specifying this prediction

@@ -114,17 +119,39 @@ class vw(pylibvw.vw):
                 'learner_id' specifies the underlying learner id

             Returns a single prediction.
+
             """
-            if (isinstance(examples, list) and all([isinstance(ex, example) or isinstance(ex, pylibvw.example) for ex in examples])) or \
-               isinstance(examples, example) or isinstance(examples, pylibvw.example):
-                P = sch.get_predictor(my_tag)
-                if isinstance(examples, list): # LDF
-                    P.set_input_length(len(examples))
+
+            P = sch.get_predictor(my_tag)
+            if sch.is_ldf():
+                # we need to know how many actions there are, even if we don't know their identities
+                while hasattr(examples, '__call__'): examples = examples()
+                if not isinstance(examples, list): raise TypeError('expected example _list_ in LDF mode for SearchTask.predict()')
+                P.set_input_length(len(examples))
+                if sch.predict_needs_example():
                     for n in range(len(examples)):
-                        P.set_input_at(n, examples[n])
-                else: # non-LDF
+                        ec = examples[n]
+                        while hasattr(ec, '__call__'): ec = ec() # unfold the lambdas
+                        if not isinstance(ec, example) and not isinstance(ec, pylibvw.example): raise TypeError('non-example in LDF example list in SearchTask.predict()')
+                        P.set_input_at(n, ec)
+                else:
+                    pass # TODO: do we need to set the examples even though they're not used?
+            else:
+                if sch.predict_needs_example():
+                    while hasattr(examples, '__call__'): examples = examples()
                     P.set_input(examples)
-
+                else:
+                    pass # TODO: do we need to set the examples even though they're not used?
+
+            # if (isinstance(examples, list) and all([isinstance(ex, example) or isinstance(ex, pylibvw.example) for ex in examples])) or \
+            #    isinstance(examples, example) or isinstance(examples, pylibvw.example):
+            #     if isinstance(examples, list): # LDF
+            #         P.set_input_length(len(examples))
+            #         for n in range(len(examples)):
+            #             P.set_input_at(n, examples[n])
+            #     else: # non-LDF
+            #         P.set_input(examples)
+            if True: # TODO: get rid of this
                 if oracle is None: pass
                 elif isinstance(oracle, list):
                     if len(oracle) > 0: P.set_oracles(oracle)