Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/moses-smt/vowpal_wabbit.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorariel faigon <github.2009@yendor.com>2014-07-24 00:26:10 +0400
committerariel faigon <github.2009@yendor.com>2014-07-24 00:26:10 +0400
commit295d532b4e7c6b648c02c7b35b8ebde4a0b6c937 (patch)
treeb4d4e10273485c8aee0d4a9ca286a00d8856e64b
parent7e138ac19bb3e4be88201d521249d87f52e378f3 (diff)
parent3c18847b1935f83608824e095385d042d1d77bb9 (diff)
Merge branch 'master' of git://github.com/JohnLangford/vowpal_wabbit
-rw-r--r--library/Makefile10
-rw-r--r--library/ezexample_predict.cc18
-rw-r--r--library/ezexample_predict_threaded.cc18
-rw-r--r--library/ezexample_train.cc2
-rwxr-xr-xlibrary/train.sh4
-rw-r--r--library/train.wbin74 -> 0 bytes
-rwxr-xr-xtest/RunTests14
-rw-r--r--test/train-sets/library_train6
-rw-r--r--test/train-sets/ref/argmax_data.stderr12
-rw-r--r--test/train-sets/ref/ezexample_predict.stderr13
-rw-r--r--test/train-sets/ref/ezexample_predict.stdout0
-rw-r--r--test/train-sets/ref/library_train.stderr29
-rw-r--r--test/train-sets/ref/library_train.stdout0
-rw-r--r--vowpalwabbit/csoaa.cc9
-rw-r--r--vowpalwabbit/ezexample.h (renamed from library/ezexample.h)96
-rw-r--r--vowpalwabbit/parser.cc7
-rw-r--r--vowpalwabbit/searn.cc139
-rw-r--r--vowpalwabbit/searn_sequencetask.cc41
-rw-r--r--vowpalwabbit/searn_sequencetask.h7
19 files changed, 315 insertions, 110 deletions
diff --git a/library/Makefile b/library/Makefile
index 499a0a69..6b7fb917 100644
--- a/library/Makefile
+++ b/library/Makefile
@@ -11,21 +11,21 @@ endif
VWLIBS = -L ../vowpalwabbit -l vw -l allreduce
STDLIBS = $(BOOST_LIBRARY) $(LIBS)
-all: ezexample_predict ezexample_train library_example recommend gd_mf_weights
+all: ezexample_predict ezexample_train library_example recommend gd_mf_weights # ezexample_predict_threaded
-ezexample_predict: ezexample_predict.cc ../vowpalwabbit/libvw.a ezexample.h
+ezexample_predict: ezexample_predict.cc ../vowpalwabbit/libvw.a
$(CXX) -g $(FLAGS) -o $@ $< $(VWLIBS) $(STDLIBS)
-ezexample_predict_threaded: ezexample_predict_threaded.cc ../vowpalwabbit/libvw.a ezexample.h
+ezexample_predict_threaded: ezexample_predict_threaded.cc ../vowpalwabbit/libvw.a
$(CXX) -g $(FLAGS) -o $@ $< $(VWLIBS) $(BOOST_PROGRAM_OPTIONS) -l z -l boost_thread
-ezexample_train: ezexample_train.cc ../vowpalwabbit/libvw.a ezexample.h
+ezexample_train: ezexample_train.cc ../vowpalwabbit/libvw.a
$(CXX) -g $(FLAGS) -o $@ $< $(VWLIBS) $(STDLIBS)
library_example: library_example.cc ../vowpalwabbit/libvw.a
$(CXX) -g $(FLAGS) -o $@ $< $(VWLIBS) $(STDLIBS)
-recommend: recommend.cc ../vowpalwabbit/libvw.a ezexample.h
+recommend: recommend.cc ../vowpalwabbit/libvw.a
$(CXX) -g $(FLAGS) -o $@ $< $(VWLIBS) $(STDLIBS)
gd_mf_weights: gd_mf_weights.cc ../vowpalwabbit/libvw.a
diff --git a/library/ezexample_predict.cc b/library/ezexample_predict.cc
index 0aa95941..db061f61 100644
--- a/library/ezexample_predict.cc
+++ b/library/ezexample_predict.cc
@@ -1,14 +1,22 @@
#include <stdio.h>
#include "../vowpalwabbit/parser.h"
#include "../vowpalwabbit/vw.h"
-#include "ezexample.h"
+#include "../vowpalwabbit/ezexample.h"
using namespace std;
int main(int argc, char *argv[])
{
+ string init_string = "-t -q st --hash all --noconstant --ldf_override s -i ";
+ if (argc > 1)
+ init_string += argv[1];
+ else
+ init_string += "train.w";
+
+ cerr << "initializing with: '" << init_string << "'" << endl;
+
// INITIALIZE WITH WHATEVER YOU WOULD PUT ON THE VW COMMAND LINE -- THIS READS IN A MODEL FROM train.w
- vw* vw = VW::initialize("-t -i train.w -q st --hash all --noconstant --csoaa_ldf s --quiet");
+ vw* vw = VW::initialize(init_string); // "-t -q st --hash all --noconstant --ldf_override s -i train.w");
{
// HAL'S SPIFFY INTERFACE USING C++ CRAZINESS
@@ -22,7 +30,7 @@ int main(int argc, char *argv[])
("w^le")
("w^homme");
ex.set_label("1");
- cerr << ex.predict() << endl;
+ cerr << ex.predict_partial() << endl;
// ex.clear_features();
@@ -32,14 +40,14 @@ int main(int argc, char *argv[])
("w^un")
("w^homme");
ex.set_label("2");
- cerr << ex.predict() << endl;
+ cerr << ex.predict_partial() << endl;
--ex; // remove the most recent namespace, and add features with explicit ns
ex('t', "p^un_homme")
('t', "w^un")
('t', "w^homme");
ex.set_label("2");
- cerr << ex.predict() << endl;
+ cerr << ex.predict_partial() << endl;
}
// AND FINISH UP
diff --git a/library/ezexample_predict_threaded.cc b/library/ezexample_predict_threaded.cc
index a45d76ca..0fa5b1e6 100644
--- a/library/ezexample_predict_threaded.cc
+++ b/library/ezexample_predict_threaded.cc
@@ -1,6 +1,6 @@
#include <stdio.h>
#include "../vowpalwabbit/vw.h"
-#include "ezexample.h"
+#include "../vowpalwabbit/ezexample.h"
#include <boost/thread/thread.hpp>
@@ -87,8 +87,8 @@ int main(int argc, char *argv[])
int threadcount = atoi(argv[1]);
runcount = atoi(argv[2]);
// INITIALIZE WITH WHATEVER YOU WOULD PUT ON THE VW COMMAND LINE -- THIS READS IN A MODEL FROM train.w
- string vw_init_string_all = "-t --csoaa_ldf s --quiet -q st --noconstant --hash all -i train.w";
- string vw_init_string_parser = "-t --csoaa_ldf s --quiet -q st --noconstant --hash all --noop"; // this needs to have enough arguments to get the parser right
+ string vw_init_string_all = "-t --ldf_override s --quiet -q st --noconstant --hash all -i train.w";
+ string vw_init_string_parser = "-t --ldf_override s --quiet -q st --noconstant --hash all --noop"; // this needs to have enough arguments to get the parser right
vw*vw = VW::initialize(vw_init_string_all);
vector<double> results;
@@ -104,8 +104,8 @@ int main(int argc, char *argv[])
("w^le")
("w^homme");
ex.set_label("1");
- results.push_back(ex.predict());
- cerr << "should be near zero = " << ex.predict() << endl;
+ results.push_back(ex.predict_partial());
+ cerr << "should be near zero = " << ex.predict_partial() << endl;
--ex; // remove the most recent namespace
ex(vw_namespace('t'))
@@ -113,8 +113,8 @@ int main(int argc, char *argv[])
("w^un")
("w^homme");
ex.set_label("1");
- results.push_back(ex.predict());
- cerr << "should be near one = " << ex.predict() << endl;
+ results.push_back(ex.predict_partial());
+ cerr << "should be near one = " << ex.predict_partial() << endl;
--ex; // remove the most recent namespace
// add features with explicit ns
@@ -122,8 +122,8 @@ int main(int argc, char *argv[])
('t', "w^un")
('t', "w^homme");
ex.set_label("1");
- results.push_back(ex.predict());
- cerr << "should be near one = " << ex.predict() << endl;
+ results.push_back(ex.predict_partial());
+ cerr << "should be near one = " << ex.predict_partial() << endl;
}
if (threadcount == 0)
diff --git a/library/ezexample_train.cc b/library/ezexample_train.cc
index 7df2dc50..a0f66a99 100644
--- a/library/ezexample_train.cc
+++ b/library/ezexample_train.cc
@@ -1,7 +1,7 @@
#include <stdio.h>
#include "../vowpalwabbit/parser.h"
#include "../vowpalwabbit/vw.h"
-#include "ezexample.h"
+#include "../vowpalwabbit/ezexample.h"
using namespace std;
diff --git a/library/train.sh b/library/train.sh
index ce118f04..e0d5f121 100755
--- a/library/train.sh
+++ b/library/train.sh
@@ -1,5 +1,5 @@
#!/bin/bash
rm -f train.cache train.w
-../vowpalwabbit/vw -c -d train -f train.w -q st --passes 100 --hash all --noconstant --csoaa_ldf m
-../vowpalwabbit/vw -t -d train -i train.w -p train.pred --noconstant --csoaa_ldf m
+../vowpalwabbit/vw -c -d train -f train.w -q st --passes 100 --hash all --noconstant --csoaa_ldf m --holdout_off
+../vowpalwabbit/vw -t -d train -i train.w -p train.pred --noconstant
diff --git a/library/train.w b/library/train.w
deleted file mode 100644
index be1820ff..00000000
--- a/library/train.w
+++ /dev/null
Binary files differ
diff --git a/test/RunTests b/test/RunTests
index f10e05b3..5b64bf84 100755
--- a/test/RunTests
+++ b/test/RunTests
@@ -288,6 +288,10 @@ sub next_test() {
}
next;
}
+ if ($line =~ /library\/ezexample_/) {
+ $cmd = trim_spaces($line);
+ next;
+ }
if ($line =~ m/\.stdout\b/) {
$out_ref = ref_file(trim_spaces($line));
next;
@@ -1015,3 +1019,13 @@ __DATA__
{VW} --stage_poly --sched_exponent 1.0 --batch_sz 1000 -d train-sets/rcv1_small.dat -p stage_poly.s100.doubling.predict --quiet
train-sets/ref/stage_poly.s100.doubling.stderr
train-sets/ref/stage_poly.s100.doubling.predict
+
+# Test 65: library test, train the initial model
+{VW} -c -k -d train-sets/library_train -f models/library_train.w -q st --passes 100 --hash all --noconstant --csoaa_ldf m --holdout_off
+ train-sets/ref/library_train.stdout
+ train-sets/ref/library_train.stderr
+
+# Test 66: library test, run ezexample_predict
+../library/ezexample_predict models/library_train.w
+ train-sets/ref/ezexample_predict.stdout
+ train-sets/ref/ezexample_predict.stderr
diff --git a/test/train-sets/library_train b/test/train-sets/library_train
new file mode 100644
index 00000000..71330c98
--- /dev/null
+++ b/test/train-sets/library_train
@@ -0,0 +1,6 @@
+1:1 |s p^the_man w^the w^man |t p^un_homme w^un w^homme
+2:0 |s p^the_man w^the w^man |t p^le_homme w^le w^homme
+
+1:0 |s p^a_man w^a w^man |t p^un_homme w^un w^homme
+2:1 |s p^a_man w^a w^man |t p^le_homme w^le w^homme
+
diff --git a/test/train-sets/ref/argmax_data.stderr b/test/train-sets/ref/argmax_data.stderr
index 6357dbac..48d88bb6 100644
--- a/test/train-sets/ref/argmax_data.stderr
+++ b/test/train-sets/ref/argmax_data.stderr
@@ -12,17 +12,17 @@ average since sequence example current label cur
loss last counter weight sequence prefix sequence prefix features pass pol made gener.
10.000000 10.000000 1 1.000000 [2 ] [1 ] 15 0 0 5 5
5.500000 1.000000 2 2.000000 [1 ] [2 ] 12 0 0 9 9
-5.250000 5.000000 4 4.000000 [2 ] [1 ] 9 0 0 15 15
-2.875000 0.500000 8 8.000000 [2 ] [2 ] 9 1 0 30 30
-1.687500 0.500000 16 16.000000 [2 ] [2 ] 9 3 0 60 60
-1.093750 0.500000 32 32.000000 [2 ] [2 ] 9 7 0 120 120
-0.796875 0.500000 64 64.000000 [2 ] [2 ] 9 15 0 240 240
+3.000000 0.500000 4 4.000000 [2 ] [2 ] 9 0 0 15 15
+1.750000 0.500000 8 8.000000 [2 ] [2 ] 9 1 0 30 30
+1.125000 0.500000 16 16.000000 [2 ] [2 ] 9 3 0 60 60
+0.812500 0.500000 32 32.000000 [2 ] [2 ] 9 7 0 120 120
+0.656250 0.500000 64 64.000000 [2 ] [2 ] 9 15 0 240 240
finished run
number of examples per pass = 4
passes used = 20
weighted example sum = 80
weighted label sum = 0
-average loss = 0.7375
+average loss = 0.625
best constant = 0
total feature number = 900
diff --git a/test/train-sets/ref/ezexample_predict.stderr b/test/train-sets/ref/ezexample_predict.stderr
new file mode 100644
index 00000000..ebbd34f7
--- /dev/null
+++ b/test/train-sets/ref/ezexample_predict.stderr
@@ -0,0 +1,13 @@
+initializing with: '-t -q st --hash all --noconstant --ldf_override s -i models/library_train.w'
+creating quadratic features for pairs: st
+only testing
+Num weight bits = 18
+learning rate = 10
+initial_t = 1
+power_t = 0.5
+using no cache
+Reading datafile =
+num sources = 0
+2.23517e-08
+1
+1
diff --git a/test/train-sets/ref/ezexample_predict.stdout b/test/train-sets/ref/ezexample_predict.stdout
new file mode 100644
index 00000000..e69de29b
--- /dev/null
+++ b/test/train-sets/ref/ezexample_predict.stdout
diff --git a/test/train-sets/ref/library_train.stderr b/test/train-sets/ref/library_train.stderr
new file mode 100644
index 00000000..8cc3a531
--- /dev/null
+++ b/test/train-sets/ref/library_train.stderr
@@ -0,0 +1,29 @@
+creating quadratic features for pairs: st
+final_regressor = models/library_train.w
+Num weight bits = 18
+learning rate = 0.5
+initial_t = 0
+power_t = 0.5
+decay_learning_rate = 1
+creating cache_file = train-sets/library_train.cache
+Reading datafile = train-sets/library_train
+num sources = 1
+average since example example current current current
+loss last counter weight label predict features
+1.000000 1.000000 1 1.0 known 1 15
+0.500000 0.000000 2 2.0 known 0 15
+0.500000 0.500000 4 4.0 known 1 15
+0.250000 0.000000 8 8.0 known 1 15
+0.125000 0.000000 16 16.0 known 1 15
+0.062500 0.000000 32 32.0 known 1 15
+0.031250 0.000000 64 64.0 known 1 15
+0.015625 0.000000 128 128.0 known 1 15
+
+finished run
+number of examples per pass = 2
+passes used = 100
+weighted example sum = 200
+weighted label sum = 0
+average loss = 0.01
+best constant = 0
+total feature number = 6000
diff --git a/test/train-sets/ref/library_train.stdout b/test/train-sets/ref/library_train.stdout
new file mode 100644
index 00000000..e69de29b
--- /dev/null
+++ b/test/train-sets/ref/library_train.stdout
diff --git a/vowpalwabbit/csoaa.cc b/vowpalwabbit/csoaa.cc
index 162d8608..cf04ab61 100644
--- a/vowpalwabbit/csoaa.cc
+++ b/vowpalwabbit/csoaa.cc
@@ -725,6 +725,13 @@ namespace LabelDict {
learner* setup(vw& all, po::variables_map& vm)
{
+ po::options_description ldf_opts("LDF Options");
+ ldf_opts.add_options()
+ ("ldf_override", po::value<string>(), "Override singleline or multiline from csoaa_ldf or wap_ldf, eg if stored in file")
+ ;
+
+ vm = add_options(all, ldf_opts);
+
ldf* ld = (ldf*)calloc_or_die(1, sizeof(ldf));
ld->all = &all;
@@ -744,6 +751,8 @@ namespace LabelDict {
all.file_options.append(" --wap_ldf ");
all.file_options.append(ldf_arg);
}
+ if ( vm.count("ldf_override") )
+ ldf_arg = vm["ldf_override"].as<string>();
all.p->lp = COST_SENSITIVE::cs_label;
diff --git a/library/ezexample.h b/vowpalwabbit/ezexample.h
index 69797738..3457c690 100644
--- a/library/ezexample.h
+++ b/vowpalwabbit/ezexample.h
@@ -22,6 +22,7 @@ class ezexample {
char str[2];
example*ec;
+ bool we_create_ec;
vector<fid> past_seeds;
fid current_seed;
size_t quadratic_features_num;
@@ -41,10 +42,7 @@ class ezexample {
return new_ec;
}
- public:
-
- // REAL FUNCTIONALITY
- ezexample(vw*this_vw, bool multiline=false, vw*this_vw_parser=NULL) {
+ void setup_new_ezexample(vw*this_vw, bool multiline, vw*this_vw_parser) {
vw_ref = this_vw;
vw_par_ref = (this_vw_parser == NULL) ? this_vw : this_vw_parser;
is_multiline = multiline;
@@ -53,20 +51,59 @@ class ezexample {
current_seed = 0;
current_ns = 0;
- ec = get_new_example();
-
quadratic_features_num = 0;
quadratic_features_sqr = 0.;
for (size_t i=0; i<256; i++) ns_exists[i] = false;
+ example_changed_since_prediction = true;
+ }
+
+
+ void setup_for_predict() {
+ static example* empty_example = is_multiline ? VW::read_example(*vw_par_ref, (char*)"") : NULL;
+ if (example_changed_since_prediction) {
+ mini_setup_example();
+ vw_ref->learn(ec);
+ if (is_multiline) vw_ref->learn(empty_example);
+ example_changed_since_prediction = false;
+ }
+ }
+
+ public:
+
+ // REAL FUNCTIONALITY
+ // create a new ezexample by asking the vw parser for an example
+ ezexample(vw*this_vw, bool multiline=false, vw*this_vw_parser=NULL) {
+ setup_new_ezexample(this_vw, multiline, this_vw_parser);
+
+ ec = get_new_example();
+ we_create_ec = true;
+
if (vw_ref->add_constant)
VW::add_constant_feature(*vw_ref, ec);
-
- example_changed_since_prediction = true;
}
- ~ezexample() {
+ // create a new ezexample by wrapping around an already existing example
+ // we do NOT copy your data, therefore, WARNING:
+ // do NOT touch the underlying example unless you really know what you're doing
+ ezexample(vw*this_vw, example*this_ec, bool multiline=false, vw*this_vw_parser=NULL) {
+ setup_new_ezexample(this_vw, multiline, this_vw_parser);
+
+ ec = this_ec;
+ we_create_ec = false;
+
+ for (unsigned char*i=ec->indices.begin; i != ec->indices.end; ++i) {
+ current_ns = *i;
+ ns_exists[(int)current_ns] = true;
+ }
+ if (current_ns != 0) {
+ str[0] = current_ns;
+ current_seed = VW::hash_space(*vw_ref, str);
+ }
+ }
+
+ ~ezexample() { // calls finish_example *only* if we created our own example!
if (ec->in_use)
VW::finish_example(*vw_par_ref, ec);
for (example**ecc=example_copies.begin; ecc!=example_copies.end; ecc++)
@@ -132,6 +169,24 @@ class ezexample {
inline fid addf(fid fint, float v) { return addf(current_ns, fint, v); }
+ // copy an entire namespace from this other example, you can even give it a new namespace name if you want!
+ void add_other_example_ns(example& other, char other_ns, char to_ns) {
+ if (ensure_ns_exists(to_ns)) return;
+ for (feature*f = other.atomics[(int)other_ns].begin; f != other.atomics[(int)other_ns].end; ++f) {
+ ec->atomics[(int)to_ns].push_back(*f);
+ ec->sum_feat_sq[(int)to_ns] += f->x * f->x;
+ ec->total_sum_feat_sq += f->x * f->x;
+ ec->num_features++;
+ }
+ example_changed_since_prediction = true;
+ }
+ void add_other_example_ns(example& other, char ns) { // default to_ns to other_ns
+ add_other_example_ns(other, ns, ns);
+ }
+
+ void add_other_example_ns(ezexample& other, char other_ns, char to_ns) { add_other_example_ns(*other.ec, other_ns, to_ns); }
+ void add_other_example_ns(ezexample& other, char ns ) { add_other_example_ns(*other.ec, ns); }
+
inline ezexample& set_label(string label) {
VW::parse_example_label(*vw_par_ref, *ec, label);
example_changed_since_prediction = true;
@@ -160,18 +215,17 @@ class ezexample {
ec->num_features += quadratic_features_num;
ec->total_sum_feat_sq += quadratic_features_sqr;
}
-
+
float predict() {
- static example* empty_example = is_multiline ? VW::read_example(*vw_par_ref, (char*)"") : NULL;
- if (example_changed_since_prediction) {
- mini_setup_example();
- vw_ref->learn(ec);
- if (is_multiline) vw_ref->learn(empty_example);
- example_changed_since_prediction = false;
- }
+ setup_for_predict();
return ((label_data*) ec->ld)->prediction;
}
+ float predict_partial() {
+ setup_for_predict();
+ return ec->partial_prediction;
+ }
+
void train() { // if multiline, add to stack; otherwise, actually train
if (example_changed_since_prediction) {
mini_setup_example();
@@ -225,13 +279,14 @@ class ezexample {
inline fid addf(char ns, string fstr, float val) { return addf(ns, hash(ns, fstr), val); }
inline fid addf(char ns, string fstr ) { return addf(ns, hash(ns, fstr), 1.0); }
+ inline ezexample& operator()(const vw_namespace&n) { addns(n.namespace_letter); return *this; }
+
inline ezexample& operator()(fid fint ) { addf(fint, 1.0); return *this; }
inline ezexample& operator()(string fstr ) { addf(fstr, 1.0); return *this; }
inline ezexample& operator()(const char* fstr ) { addf(fstr, 1.0); return *this; }
inline ezexample& operator()(fid fint, float val) { addf(fint, val); return *this; }
inline ezexample& operator()(string fstr, float val) { addf(fstr, val); return *this; }
inline ezexample& operator()(const char* fstr, float val) { addf(fstr, val); return *this; }
- inline ezexample& operator()(const vw_namespace&n) { addns(n.namespace_letter); return *this; }
inline ezexample& operator()(char ns, fid fint ) { addf(ns, fint, 1.0); return *this; }
inline ezexample& operator()(char ns, string fstr ) { addf(ns, fstr, 1.0); return *this; }
@@ -240,6 +295,11 @@ class ezexample {
inline ezexample& operator()(char ns, string fstr, float val) { addf(ns, fstr, val); return *this; }
inline ezexample& operator()(char ns, const char* fstr, float val) { addf(ns, fstr, val); return *this; }
+ inline ezexample& operator()( example&other, char other_ns, char to_ns) { add_other_example_ns(other, other_ns, to_ns); return *this; }
+ inline ezexample& operator()( example&other, char ns ) { add_other_example_ns(other, ns); return *this; }
+ inline ezexample& operator()(ezexample&other, char other_ns, char to_ns) { add_other_example_ns(other, other_ns, to_ns); return *this; }
+ inline ezexample& operator()(ezexample&other, char ns ) { add_other_example_ns(other, ns); return *this; }
+
inline ezexample& operator--() { remns(); return *this; }
diff --git a/vowpalwabbit/parser.cc b/vowpalwabbit/parser.cc
index 5a58ef09..26a65432 100644
--- a/vowpalwabbit/parser.cc
+++ b/vowpalwabbit/parser.cc
@@ -140,11 +140,11 @@ void handle_sigterm (int)
got_sigterm = true;
}
-bool is_test_only(uint32_t counter, uint32_t period, uint32_t after, bool holdout_off)
+bool is_test_only(uint32_t counter, uint32_t period, uint32_t after, bool holdout_off, uint32_t target_modulus) // target should be 0 in the normal case, or period-1 in the case that emptylines separate examples
{
if(holdout_off) return false;
if (after == 0) // hold out by period
- return (counter % period == 0);
+ return (counter % period == target_modulus);
else // hold out by position
return (counter+1 >= after);
}
@@ -757,10 +757,11 @@ void setup_example(vw& all, example* ae)
if ((!all.p->emptylines_separate_examples) || example_is_newline(*ae))
all.p->in_pass_counter++;
- ae->test_only = is_test_only(all.p->in_pass_counter, all.holdout_period, all.holdout_after, all.holdout_set_off);
+ ae->test_only = is_test_only(all.p->in_pass_counter, all.holdout_period, all.holdout_after, all.holdout_set_off, all.p->emptylines_separate_examples ? (all.holdout_period-1) : 0);
all.sd->t += all.p->lp.get_weight(ae->ld);
ae->example_t = (float)all.sd->t;
+
if (all.ignore_some)
{
if (all.audit || all.hash_inv)
diff --git a/vowpalwabbit/searn.cc b/vowpalwabbit/searn.cc
index 9f78e676..4d283172 100644
--- a/vowpalwabbit/searn.cc
+++ b/vowpalwabbit/searn.cc
@@ -47,6 +47,7 @@ namespace Searn
&ArgmaxTask::task,
&SequenceTask_DemoLDF::task,
&SequenceSpanTask::task,
+ &SequenceDoubleTask::task,
&EntityRelationTask::task,
NULL }; // must NULL terminate!
@@ -658,17 +659,23 @@ namespace Searn
}
void set_most_recent_snapshot_action(searn_private* priv, uint32_t action, float loss) {
+ if ((priv->state == GET_TRUTH_STRING) ||
+ (priv->state == INIT_TEST))
+ return;
if (priv->most_recent_snapshot_end == (size_t)-1) return;
bool on_training_path = priv->state == INIT_TRAIN || priv->state == BEAM_INIT;
- cdbg << "set: " << priv->most_recent_snapshot_begin << "\t" << priv->most_recent_snapshot_end << "\t" << priv->most_recent_snapshot_hash << "\totp=" << on_training_path << endl;
+ cdbg << "set: " << priv->most_recent_snapshot_begin << "\t" << priv->most_recent_snapshot_end << "\t" << priv->most_recent_snapshot_hash << "\totp=" << on_training_path << "\tloss=" << loss << "\taction=" << action << endl;
snapshot_item_ptr sip = { priv->most_recent_snapshot_begin,
priv->most_recent_snapshot_end,
priv->most_recent_snapshot_hash };
+ cdbg << "sip: start=" << sip.start << " end=" << sip.end << " hash=" << sip.hash_value << endl;
snapshot_item_result &res = priv->snapshot_map->get(sip, sip.hash_value);
if (res.loss < 0.f) { // not found!
snapshot_item_result me = { action, loss, on_training_path };
priv->snapshot_map->put_after_get(sip, sip.hash_value, me);
+ cdbg << "adding snapshot, new size = " << priv->snapshot_data.size() << endl;
} else {
+ clog << "found: action=" << res.action << " loss=" << res.loss << " otp=" << res.on_training_path << endl;
assert(false);
}
}
@@ -1037,42 +1044,6 @@ namespace Searn
}
- uint32_t searn_predict(searn_private* priv, example* ecs, size_t num_ec, v_array<uint32_t> *yallowed, v_array<uint32_t> *ystar, bool ystar_is_uint32t) // num_ec == 0 means normal example, >0 means ldf, yallowed==NULL means all allowed, ystar==NULL means don't know; ystar_is_uint32t means that the ystar ref is really just a uint32_t
- {
- vw* all = priv->all;
- learner* base = priv->base_learner;
- searn* srn=(searn*)all->searnstr;
- uint32_t a;
-
- //bool found_ss = get_most_recent_snapshot_action(priv, a);
-
- //if (!found_ss) a = (uint32_t)-1;
- if (srn->priv->rollout_all_actions)
- a = searn_predict_without_loss<COST_SENSITIVE::label>(*all, *base, ecs, num_ec, yallowed, ystar, ystar_is_uint32t);
- else
- a = searn_predict_without_loss<CB::label >(*all, *base, ecs, num_ec, yallowed, ystar, ystar_is_uint32t);
-
- //set_most_recent_snapshot_action(priv, a);
- priv->snapshotted_since_predict = false;
-
- if (priv->auto_hamming_loss) {
- float this_loss = 0.;
- if (ystar) {
- if (ystar_is_uint32t && // single allowed ystar
- (*((uint32_t*)ystar) != (uint32_t)-1) && // not a test example
- (*((uint32_t*)ystar) != a)) // wrong prediction
- this_loss = 1.;
- if ((!ystar_is_uint32t) && // many allowed ystar
- (!v_array_contains(*ystar, a)))
- this_loss = 1.;
- }
- searn_declare_loss(priv, 1, this_loss);
- }
-
- return a;
- }
-
-
bool snapshot_binary_search_lt(v_array<snapshot_item> a, size_t desired_t, size_t tag, size_t &pos, size_t last_found_pos) {
size_t hi = a.size();
if (hi == 0) return false;
@@ -1138,10 +1109,13 @@ namespace Searn
return false;
}
+
void searn_snapshot_data(searn_private* priv, size_t index, size_t tag, void* data_ptr, size_t sizeof_data, bool used_for_prediction) {
if ((priv->state == NONE) || (priv->state == INIT_TEST) || (priv->state == GET_TRUTH_STRING) || (priv->state == BEAM_PLAYOUT))
return;
+ priv->snapshotted_since_predict = true;
+
size_t i;
if ((priv->state == LEARN) && (priv->t <= priv->learn_t) &&
snapshot_binary_search_lt(priv->snapshot_data, priv->learn_t, tag, i, priv->snapshot_last_found_pos)) {
@@ -1158,7 +1132,6 @@ namespace Searn
}
if ((priv->state == INIT_TRAIN) || (priv->state == LEARN)) {
- priv->snapshotted_since_predict = true;
priv->most_recent_snapshot_end = priv->snapshot_data.size();
cdbg << "end = " << priv->most_recent_snapshot_end << endl;
@@ -1181,7 +1154,7 @@ namespace Searn
}
snapshot_item item = { index, tag, new_data, sizeof_data, priv->t };
priv->snapshot_data.push_back(item);
- //cerr << "priv->snapshot_data.push_back(item);" << endl;
+ cdbg << "priv->snapshot_data.push_back(item);" << endl;
return;
}
@@ -1263,32 +1236,80 @@ namespace Searn
}
- void searn_snapshot(searn_private* priv, size_t index, size_t tag, void* data_ptr, size_t sizeof_data, bool used_for_prediction) {
- if (! priv->do_snapshot) return;
- if (tag == 1) {
- priv->most_recent_snapshot_hash = 38429103;
- priv->most_recent_snapshot_begin = priv->snapshot_data.size();
- priv->most_recent_snapshot_end = -1;
- cdbg << "end = -1 ***" << endl;
- priv->most_recent_snapshot_loss = 0.f;
- if (priv->loss_declared)
- switch (priv->state) {
- case INIT_TEST : priv->most_recent_snapshot_loss = priv->test_loss; break;
- case INIT_TRAIN: priv->most_recent_snapshot_loss = priv->train_loss; break;
- case LEARN : priv->most_recent_snapshot_loss = priv->learn_loss; break;
- default : break;
- }
+ void searn_snapshot_initialize(searn_private* priv, size_t index) {
+ priv->most_recent_snapshot_hash = 38429103;
+ priv->most_recent_snapshot_begin = priv->snapshot_data.size();
+ priv->most_recent_snapshot_end = -1;
+ cdbg << "end = -1 ***" << endl;
+ priv->most_recent_snapshot_loss = 0.f;
+ if (priv->loss_declared)
+ switch (priv->state) {
+ case INIT_TEST : priv->most_recent_snapshot_loss = priv->test_loss; break;
+ case INIT_TRAIN: priv->most_recent_snapshot_loss = priv->train_loss; break;
+ case LEARN : priv->most_recent_snapshot_loss = priv->learn_loss; break;
+ default : break;
+ }
- if (priv->state == INIT_TRAIN)
- priv->final_snapshot_begin = priv->most_recent_snapshot_begin;
+ if (priv->state == INIT_TRAIN)
+ priv->final_snapshot_begin = priv->most_recent_snapshot_begin;
- if (priv->auto_history) {
- size_t history_size = sizeof(uint32_t) * priv->hinfo.length;
- searn_snapshot_data(priv, index, 0, priv->rollout_action.begin + priv->t, history_size, true);
+ if (priv->auto_history) {
+ size_t history_size = sizeof(uint32_t) * priv->hinfo.length;
+ searn_snapshot_data(priv, index, 0, priv->rollout_action.begin + priv->t, history_size, true);
+ }
+ }
+
+
+ uint32_t searn_predict(searn_private* priv, example* ecs, size_t num_ec, v_array<uint32_t> *yallowed, v_array<uint32_t> *ystar, bool ystar_is_uint32t) // num_ec == 0 means normal example, >0 means ldf, yallowed==NULL means all allowed, ystar==NULL means don't know; ystar_is_uint32t means that the ystar ref is really just a uint32_t
+ {
+ vw* all = priv->all;
+ learner* base = priv->base_learner;
+ searn* srn=(searn*)all->searnstr;
+ uint32_t a;
+
+ // handle the case where you want auto-history but you're not snapshotting yourself
+ if ((!priv->snapshotted_since_predict) && priv->auto_history) {
+ searn_snapshot_initialize(priv, priv->t);
+ }
+
+ //bool found_ss = get_most_recent_snapshot_action(priv, a);
+
+ //if (!found_ss) a = (uint32_t)-1;
+ if (srn->priv->rollout_all_actions)
+ a = searn_predict_without_loss<COST_SENSITIVE::label>(*all, *base, ecs, num_ec, yallowed, ystar, ystar_is_uint32t);
+ else
+ a = searn_predict_without_loss<CB::label >(*all, *base, ecs, num_ec, yallowed, ystar, ystar_is_uint32t);
+
+ //set_most_recent_snapshot_action(priv, a);
+ priv->snapshotted_since_predict = false;
+
+ if (priv->auto_hamming_loss) {
+ float this_loss = 0.;
+ if (ystar) {
+ if (ystar_is_uint32t && // single allowed ystar
+ (*((uint32_t*)ystar) != (uint32_t)-1) && // not a test example
+ (*((uint32_t*)ystar) != a)) // wrong prediction
+ this_loss = 1.;
+ if ((!ystar_is_uint32t) && // many allowed ystar
+ (!v_array_contains(*ystar, a)))
+ this_loss = 1.;
}
+ searn_declare_loss(priv, 1, this_loss);
}
+ return a;
+ }
+
+ void searn_snapshot(searn_private* priv, size_t index, size_t tag, void* data_ptr, size_t sizeof_data, bool used_for_prediction) {
+ if (! priv->do_snapshot) return;
+ if ((priv->state == GET_TRUTH_STRING) ||
+ (priv->state == INIT_TEST))
+ return;
+
+ if (tag == 1)
+ searn_snapshot_initialize(priv, index);
+
searn_snapshot_data(priv, index, tag, data_ptr, sizeof_data, used_for_prediction);
if (priv->state == INIT_TRAIN)
priv->final_snapshot_end = priv->most_recent_snapshot_end;
diff --git a/vowpalwabbit/searn_sequencetask.cc b/vowpalwabbit/searn_sequencetask.cc
index 97ac4ab8..06b8d9db 100644
--- a/vowpalwabbit/searn_sequencetask.cc
+++ b/vowpalwabbit/searn_sequencetask.cc
@@ -8,11 +8,13 @@ license as described in the file LICENSE.
#include "memory.h"
#include "example.h"
#include "gd.h"
+#include "ezexample.h"
namespace SequenceTask { Searn::searn_task task = { "sequence", initialize, finish, structured_predict }; }
-namespace ArgmaxTask { Searn::searn_task task = { "argmax", initialize, finish, structured_predict }; }
-namespace SequenceTask_DemoLDF { Searn::searn_task task = { "sequence_demoldf", initialize, finish, structured_predict }; }
+namespace ArgmaxTask { Searn::searn_task task = { "argmax", initialize, finish, structured_predict }; }
+namespace SequenceDoubleTask { Searn::searn_task task = { "sequencedouble", initialize, finish, structured_predict }; }
namespace SequenceSpanTask { Searn::searn_task task = { "sequencespan", initialize, finish, structured_predict }; }
+namespace SequenceTask_DemoLDF { Searn::searn_task task = { "sequence_demoldf", initialize, finish, structured_predict }; }
namespace SequenceTask {
@@ -110,6 +112,41 @@ namespace ArgmaxTask {
}
}
+
+namespace SequenceDoubleTask {
+ using namespace Searn;
+
+ void initialize(searn& srn, size_t& num_actions, po::variables_map& vm) {
+ srn.set_options( AUTO_HISTORY | // automatically add history features to our examples, please
+ EXAMPLES_DONT_CHANGE ); // we don't do any internal example munging
+ }
+
+ void finish(searn& srn) { } // if we had task data, we'd want to free it here
+
+ void structured_predict(searn& srn, vector<example*> ec) {
+ size_t N = ec.size();
+ for (size_t j=0; j<N*2; j++) {
+ srn.snapshot(j, 1, &j, sizeof(j), true);
+ size_t i =
+ (j == 0) ? 0 :
+ (j == 2*N-1) ? (N-1) :
+ (j%2 == 0) ? (j/2 - 1) :
+ ((j+1)/2);
+
+ size_t prediction = srn.predict(ec[i], MULTICLASS::get_example_label(ec[i]));
+
+ if ((j >= 2) && (j%2==0)) {
+ srn.loss( prediction != MULTICLASS::get_example_label(ec[i]) );
+ if (srn.output().good())
+ srn.output() << prediction << ' ';
+ } else
+ srn.loss(0.);
+ }
+ }
+}
+
+
+
namespace SequenceSpanTask {
enum EncodingType { BIO, BILOU };
// the format for the BIO encoding is:
diff --git a/vowpalwabbit/searn_sequencetask.h b/vowpalwabbit/searn_sequencetask.h
index 92621e68..7d99ac94 100644
--- a/vowpalwabbit/searn_sequencetask.h
+++ b/vowpalwabbit/searn_sequencetask.h
@@ -29,6 +29,13 @@ namespace SequenceSpanTask {
extern Searn::searn_task task;
}
+namespace SequenceDoubleTask {
+ void initialize(Searn::searn&, size_t&, po::variables_map&);
+ void finish(Searn::searn&);
+ void structured_predict(Searn::searn&, vector<example*>);
+ extern Searn::searn_task task;
+}
+
namespace SequenceTask_DemoLDF {
void initialize(Searn::searn&, size_t&, po::variables_map&);
void finish(Searn::searn&);