github.com/moses-smt/vowpal_wabbit.git
author    John Langford <jl@hunch.net>  2013-01-21 19:47:03 +0400
committer John Langford <jl@hunch.net>  2013-01-21 19:47:03 +0400
commit    f8d453e6eec1067e411271f2c208f001ec32237e (patch)
tree      757ff6bfe1d1b928d3c035c6e794a6577d47b047
parent    85a5725045ca861feae61ab65b7a4593f6f81615 (diff)
first compiling version
-rw-r--r--  Makefile.am                             2
-rw-r--r--  vowpalwabbit/bfgs.cc                  496
-rw-r--r--  vowpalwabbit/bfgs.h                     6
-rw-r--r--  vowpalwabbit/cb.cc                    244
-rw-r--r--  vowpalwabbit/cb.h                       1
-rw-r--r--  vowpalwabbit/csoaa.cc                 593
-rw-r--r--  vowpalwabbit/csoaa.h                   12
-rw-r--r--  vowpalwabbit/ect.cc                   279
-rw-r--r--  vowpalwabbit/gd.cc                     34
-rw-r--r--  vowpalwabbit/gd.h                       7
-rw-r--r--  vowpalwabbit/gd_mf.cc                  46
-rw-r--r--  vowpalwabbit/gd_mf.h                    3
-rw-r--r--  vowpalwabbit/global_data.cc            17
-rw-r--r--  vowpalwabbit/global_data.h             15
-rw-r--r--  vowpalwabbit/lda_core.cc                7
-rw-r--r--  vowpalwabbit/lda_core.h                 2
-rw-r--r--  vowpalwabbit/learner.h                 18
-rw-r--r--  vowpalwabbit/nn.cc                    294
-rw-r--r--  vowpalwabbit/noop.cc                   32
-rw-r--r--  vowpalwabbit/noop.h                     4
-rw-r--r--  vowpalwabbit/oaa.cc                    71
-rw-r--r--  vowpalwabbit/parse_args.cc             69
-rw-r--r--  vowpalwabbit/parse_regressor.cc         2
-rw-r--r--  vowpalwabbit/searn.cc                1264
-rw-r--r--  vowpalwabbit/searn.h                   22
-rw-r--r--  vowpalwabbit/searn_sequencetask.cc      3
-rw-r--r--  vowpalwabbit/searn_sequencetask.h       3
-rw-r--r--  vowpalwabbit/sender.cc                 61
-rw-r--r--  vowpalwabbit/sender.h                   4
-rw-r--r--  vowpalwabbit/vw.cc                      2
-rw-r--r--  vowpalwabbit/wap.cc                    56
31 files changed, 1851 insertions(+), 1818 deletions(-)
diff --git a/Makefile.am b/Makefile.am
index 7e405360..1750ebe3 100644
--- a/Makefile.am
+++ b/Makefile.am
@@ -17,7 +17,7 @@ noinst_HEADERS = vowpalwabbit/accumulate.h vowpalwabbit/oaa.h \
vowpalwabbit/v_array.h vowpalwabbit/lda_core.h \
vowpalwabbit/v_hashmap.h vowpalwabbit/loss_functions.h \
vowpalwabbit/network.h vowpalwabbit/wap.h vowpalwabbit/noop.h \
- vowpalwabbit/nn.h
+ vowpalwabbit/nn.h vowpalwabbit/learner.h
ACLOCAL_AMFLAGS = -I acinclude.d
diff --git a/vowpalwabbit/bfgs.cc b/vowpalwabbit/bfgs.cc
index d863d9e7..2d3febd1 100644
--- a/vowpalwabbit/bfgs.cc
+++ b/vowpalwabbit/bfgs.cc
@@ -61,42 +61,43 @@ namespace BFGS
{
//nonreentrant
-
-double wolfe1_bound = 0.01;
-
-struct timeb t_start, t_end;
-double net_comm_time = 0.0;
-
-struct timeb t_start_global, t_end_global;
-double net_time;
-
-v_array<float> predictions;
-size_t example_number=0;
-size_t current_pass = 0;
-
- // default transition behavior
-bool first_hessian_on=true;
-bool backstep_on=false;
-
- // set by initializer
-int mem_stride;
-bool output_regularizer;
-float* mem;
-double* rho;
-double* alpha;
-
- weight* regularizers = NULL;
- // the below needs to be included when resetting, in addition to preconditioner and derivative
- int lastj, origin;
- double loss_sum, previous_loss_sum;
- float step_size;
- double importance_weight_sum;
- double curvature;
-
- // first pass specification
- bool first_pass=true;
-bool gradient_pass=true;
-bool preconditioner_pass=true;
+ struct bfgs {
+ double wolfe1_bound;
+
+ struct timeb t_start, t_end;
+ double net_comm_time;
+
+ struct timeb t_start_global, t_end_global;
+ double net_time;
+
+ v_array<float> predictions;
+ size_t example_number;
+ size_t current_pass;
+
+ // default transition behavior
+ bool first_hessian_on;
+ bool backstep_on;
+
+ // set by initializer
+ int mem_stride;
+ bool output_regularizer;
+ float* mem;
+ double* rho;
+ double* alpha;
+
+ weight* regularizers;
+ // the below needs to be included when resetting, in addition to preconditioner and derivative
+ int lastj, origin;
+ double loss_sum, previous_loss_sum;
+ float step_size;
+ double importance_weight_sum;
+ double curvature;
+
+ // first pass specification
+ bool first_pass;
+ bool gradient_pass;
+ bool preconditioner_pass;
+ };
const char* curv_message = "Zero or negative curvature detected.\n"
"To increase curvature you can increase regularization or rescale features.\n"
@@ -121,15 +122,15 @@ void zero_preconditioner(vw& all)
weights[stride*i+W_COND] = 0;
}
-void reset_state(vw& all, bool zero)
+void reset_state(vw& all, bfgs& b, bool zero)
{
- lastj = origin = 0;
- loss_sum = previous_loss_sum = 0.;
- importance_weight_sum = 0.;
- curvature = 0.;
- first_pass = true;
- gradient_pass = true;
- preconditioner_pass = true;
+ b.lastj = b.origin = 0;
+ b.loss_sum = b.previous_loss_sum = 0.;
+ b.importance_weight_sum = 0.;
+ b.curvature = 0.;
+ b.first_pass = true;
+ b.gradient_pass = true;
+ b.preconditioner_pass = true;
if (zero)
{
zero_derivative(all);
@@ -300,7 +301,7 @@ float dot_with_direction(vw& all, example* &ec)
return ret;
}
-double regularizer_direction_magnitude(vw& all, float regularizer)
+double regularizer_direction_magnitude(vw& all, bfgs& b, float regularizer)
{//compute direction magnitude
double ret = 0.;
@@ -310,12 +311,12 @@ double regularizer_direction_magnitude(vw& all, float regularizer)
uint32_t length = 1 << all.num_bits;
size_t stride = all.stride;
weight* weights = all.reg.weight_vector;
- if (regularizers == NULL)
+ if (b.regularizers == NULL)
for(uint32_t i = 0; i < length; i++)
ret += regularizer*weights[stride*i+W_DIR]*weights[stride*i+W_DIR];
else
for(uint32_t i = 0; i < length; i++)
- ret += regularizers[2*i]*weights[stride*i+W_DIR]*weights[stride*i+W_DIR];
+ ret += b.regularizers[2*i]*weights[stride*i+W_DIR]*weights[stride*i+W_DIR];
return ret;
}
@@ -332,7 +333,7 @@ float direction_magnitude(vw& all)
return (float)ret;
}
- void bfgs_iter_start(vw& all, float* mem, int& lastj, double importance_weight_sum, int&origin)
+void bfgs_iter_start(vw& all, bfgs& b, float* mem, int& lastj, double importance_weight_sum, int&origin)
{
uint32_t length = 1 << all.num_bits;
size_t stride = all.stride;
@@ -342,10 +343,10 @@ float direction_magnitude(vw& all)
double g1_g1 = 0.;
origin = 0;
- for(uint32_t i = 0; i < length; i++, mem+=mem_stride, w+=stride) {
+ for(uint32_t i = 0; i < length; i++, mem+=b.mem_stride, w+=stride) {
if (all.m>0)
- mem[(MEM_XT+origin)%mem_stride] = w[W_XT];
- mem[(MEM_GT+origin)%mem_stride] = w[W_GT];
+ mem[(MEM_XT+origin)%b.mem_stride] = w[W_XT];
+ mem[(MEM_GT+origin)%b.mem_stride] = w[W_GT];
g1_Hg1 += w[W_GT] * w[W_GT] * w[W_COND];
g1_g1 += w[W_GT] * w[W_GT];
w[W_DIR] = -w[W_COND]*w[W_GT];
@@ -358,7 +359,7 @@ float direction_magnitude(vw& all)
g1_Hg1/importance_weight_sum, "", "", "");
}
-void bfgs_iter_middle(vw& all, float* mem, double* rho, double* alpha, int& lastj, int &origin)
+void bfgs_iter_middle(vw& all, bfgs& b, float* mem, double* rho, double* alpha, int& lastj, int &origin)
{
uint32_t length = 1 << all.num_bits;
size_t stride = all.stride;
@@ -373,10 +374,10 @@ void bfgs_iter_middle(vw& all, float* mem, double* rho, double* alpha, int& last
double g_Hg = 0.;
double y = 0.;
- for(uint32_t i = 0; i < length; i++, mem+=mem_stride, w+=stride) {
- y = w[W_GT]-mem[(MEM_GT+origin)%mem_stride];
+ for(uint32_t i = 0; i < length; i++, mem+=b.mem_stride, w+=stride) {
+ y = w[W_GT]-mem[(MEM_GT+origin)%b.mem_stride];
g_Hy += w[W_GT] * w[W_COND] * y;
- g_Hg += mem[(MEM_GT+origin)%mem_stride] * w[W_COND] * mem[(MEM_GT+origin)%mem_stride];
+ g_Hg += mem[(MEM_GT+origin)%b.mem_stride] * w[W_COND] * mem[(MEM_GT+origin)%b.mem_stride];
}
float beta = (float) (g_Hy/g_Hg);
@@ -386,8 +387,8 @@ void bfgs_iter_middle(vw& all, float* mem, double* rho, double* alpha, int& last
mem = mem0;
w = w0;
- for(uint32_t i = 0; i < length; i++, mem+=mem_stride, w+=stride) {
- mem[(MEM_GT+origin)%mem_stride] = w[W_GT];
+ for(uint32_t i = 0; i < length; i++, mem+=b.mem_stride, w+=stride) {
+ mem[(MEM_GT+origin)%b.mem_stride] = w[W_GT];
w[W_DIR] *= beta;
w[W_DIR] -= w[W_COND]*w[W_GT];
@@ -407,13 +408,13 @@ void bfgs_iter_middle(vw& all, float* mem, double* rho, double* alpha, int& last
double y_Hy = 0.;
double s_q = 0.;
- for(uint32_t i = 0; i < length; i++, mem+=mem_stride, w+=stride) {
- mem[(MEM_YT+origin)%mem_stride] = w[W_GT] - mem[(MEM_GT+origin)%mem_stride];
- mem[(MEM_ST+origin)%mem_stride] = w[W_XT] - mem[(MEM_XT+origin)%mem_stride];
+ for(uint32_t i = 0; i < length; i++, mem+=b.mem_stride, w+=stride) {
+ mem[(MEM_YT+origin)%b.mem_stride] = w[W_GT] - mem[(MEM_GT+origin)%b.mem_stride];
+ mem[(MEM_ST+origin)%b.mem_stride] = w[W_XT] - mem[(MEM_XT+origin)%b.mem_stride];
w[W_DIR] = w[W_GT];
- y_s += mem[(MEM_YT+origin)%mem_stride]*mem[(MEM_ST+origin)%mem_stride];
- y_Hy += mem[(MEM_YT+origin)%mem_stride]*mem[(MEM_YT+origin)%mem_stride]*w[W_COND];
- s_q += mem[(MEM_ST+origin)%mem_stride]*w[W_GT];
+ y_s += mem[(MEM_YT+origin)%b.mem_stride]*mem[(MEM_ST+origin)%b.mem_stride];
+ y_Hy += mem[(MEM_YT+origin)%b.mem_stride]*mem[(MEM_YT+origin)%b.mem_stride]*w[W_COND];
+ s_q += mem[(MEM_ST+origin)%b.mem_stride]*w[W_GT];
}
if (y_s <= 0. || y_Hy <= 0.)
@@ -427,9 +428,9 @@ void bfgs_iter_middle(vw& all, float* mem, double* rho, double* alpha, int& last
s_q = 0.;
mem = mem0;
w = w0;
- for(uint32_t i = 0; i < length; i++, mem+=mem_stride, w+=stride) {
- w[W_DIR] -= (float)alpha[j]*mem[(2*j+MEM_YT+origin)%mem_stride];
- s_q += mem[(2*j+2+MEM_ST+origin)%mem_stride]*w[W_DIR];
+ for(uint32_t i = 0; i < length; i++, mem+=b.mem_stride, w+=stride) {
+ w[W_DIR] -= (float)alpha[j]*mem[(2*j+MEM_YT+origin)%b.mem_stride];
+ s_q += mem[(2*j+2+MEM_ST+origin)%b.mem_stride]*w[W_DIR];
}
}
@@ -437,10 +438,10 @@ void bfgs_iter_middle(vw& all, float* mem, double* rho, double* alpha, int& last
double y_r = 0.;
mem = mem0;
w = w0;
- for(uint32_t i = 0; i < length; i++, mem+=mem_stride, w+=stride) {
- w[W_DIR] -= (float)alpha[lastj]*mem[(2*lastj+MEM_YT+origin)%mem_stride];
+ for(uint32_t i = 0; i < length; i++, mem+=b.mem_stride, w+=stride) {
+ w[W_DIR] -= (float)alpha[lastj]*mem[(2*lastj+MEM_YT+origin)%b.mem_stride];
w[W_DIR] *= gamma*w[W_COND];
- y_r += mem[(2*lastj+MEM_YT+origin)%mem_stride]*w[W_DIR];
+ y_r += mem[(2*lastj+MEM_YT+origin)%b.mem_stride]*w[W_DIR];
}
double coef_j;
@@ -450,17 +451,17 @@ void bfgs_iter_middle(vw& all, float* mem, double* rho, double* alpha, int& last
y_r = 0.;
mem = mem0;
w = w0;
- for(uint32_t i = 0; i < length; i++, mem+=mem_stride, w+=stride) {
- w[W_DIR] += (float)coef_j*mem[(2*j+MEM_ST+origin)%mem_stride];
- y_r += mem[(2*j-2+MEM_YT+origin)%mem_stride]*w[W_DIR];
+ for(uint32_t i = 0; i < length; i++, mem+=b.mem_stride, w+=stride) {
+ w[W_DIR] += (float)coef_j*mem[(2*j+MEM_ST+origin)%b.mem_stride];
+ y_r += mem[(2*j-2+MEM_YT+origin)%b.mem_stride]*w[W_DIR];
}
}
coef_j = alpha[0] - rho[0] * y_r;
mem = mem0;
w = w0;
- for(uint32_t i = 0; i < length; i++, mem+=mem_stride, w+=stride) {
- w[W_DIR] = -w[W_DIR]-(float)coef_j*mem[(MEM_ST+origin)%mem_stride];
+ for(uint32_t i = 0; i < length; i++, mem+=b.mem_stride, w+=stride) {
+ w[W_DIR] = -w[W_DIR]-(float)coef_j*mem[(MEM_ST+origin)%b.mem_stride];
}
/*********************
@@ -470,17 +471,17 @@ void bfgs_iter_middle(vw& all, float* mem, double* rho, double* alpha, int& last
mem = mem0;
w = w0;
lastj = (lastj<all.m-1) ? lastj+1 : all.m-1;
- origin = (origin+mem_stride-2)%mem_stride;
- for(uint32_t i = 0; i < length; i++, mem+=mem_stride, w+=stride) {
- mem[(MEM_GT+origin)%mem_stride] = w[W_GT];
- mem[(MEM_XT+origin)%mem_stride] = w[W_XT];
+ origin = (origin+b.mem_stride-2)%b.mem_stride;
+ for(uint32_t i = 0; i < length; i++, mem+=b.mem_stride, w+=stride) {
+ mem[(MEM_GT+origin)%b.mem_stride] = w[W_GT];
+ mem[(MEM_XT+origin)%b.mem_stride] = w[W_XT];
w[W_GT] = 0;
}
for (int j=lastj; j>0; j--)
rho[j] = rho[j-1];
}
-double wolfe_eval(vw& all, float* mem, double loss_sum, double previous_loss_sum, double step_size, double importance_weight_sum, int &origin, double& wolfe1) {
+double wolfe_eval(vw& all, bfgs& b, float* mem, double loss_sum, double previous_loss_sum, double step_size, double importance_weight_sum, int &origin, double& wolfe1) {
uint32_t length = 1 << all.num_bits;
size_t stride = all.stride;
weight* w = all.reg.weight_vector;
@@ -490,8 +491,8 @@ double wolfe_eval(vw& all, float* mem, double loss_sum, double previous_loss_sum
double g1_Hg1 = 0.;
double g1_g1 = 0.;
- for(uint32_t i = 0; i < length; i++, mem+=mem_stride, w+=stride) {
- g0_d += mem[(MEM_GT+origin)%mem_stride] * w[W_DIR];
+ for(uint32_t i = 0; i < length; i++, mem+=b.mem_stride, w+=stride) {
+ g0_d += mem[(MEM_GT+origin)%b.mem_stride] * w[W_DIR];
g1_d += w[W_GT] * w[W_DIR];
g1_Hg1 += w[W_GT] * w[W_GT] * w[W_COND];
g1_g1 += w[W_GT] * w[W_GT];
@@ -507,13 +508,13 @@ double wolfe_eval(vw& all, float* mem, double loss_sum, double previous_loss_sum
}
-double add_regularization(vw& all, float regularization)
+double add_regularization(vw& all, bfgs& b, float regularization)
{//compute the derivative difference
double ret = 0.;
uint32_t length = 1 << all.num_bits;
size_t stride = all.stride;
weight* weights = all.reg.weight_vector;
- if (regularizers == NULL)
+ if (b.regularizers == NULL)
{
for(uint32_t i = 0; i < length; i++) {
weights[stride*i+W_GT] += regularization*weights[stride*i];
@@ -523,21 +524,21 @@ double add_regularization(vw& all, float regularization)
else
{
for(uint32_t i = 0; i < length; i++) {
- weight delta_weight = weights[stride*i] - regularizers[2*i+1];
- weights[stride*i+W_GT] += regularizers[2*i]*delta_weight;
- ret += 0.5*regularizers[2*i]*delta_weight*delta_weight;
+ weight delta_weight = weights[stride*i] - b.regularizers[2*i+1];
+ weights[stride*i+W_GT] += b.regularizers[2*i]*delta_weight;
+ ret += 0.5*b.regularizers[2*i]*delta_weight*delta_weight;
}
}
return ret;
}
-void finalize_preconditioner(vw& all, float regularization)
+void finalize_preconditioner(vw& all, bfgs& b, float regularization)
{
uint32_t length = 1 << all.num_bits;
size_t stride = all.stride;
weight* weights = all.reg.weight_vector;
- if (regularizers == NULL)
+ if (b.regularizers == NULL)
for(uint32_t i = 0; i < length; i++) {
weights[stride*i+W_COND] += regularization;
if (weights[stride*i+W_COND] > 0)
@@ -545,34 +546,34 @@ void finalize_preconditioner(vw& all, float regularization)
}
else
for(uint32_t i = 0; i < length; i++) {
- weights[stride*i+W_COND] += regularizers[2*i];
+ weights[stride*i+W_COND] += b.regularizers[2*i];
if (weights[stride*i+W_COND] > 0)
weights[stride*i+W_COND] = 1.f / weights[stride*i+W_COND];
}
}
-void preconditioner_to_regularizer(vw& all, float regularization)
+void preconditioner_to_regularizer(vw& all, bfgs& b, float regularization)
{
uint32_t length = 1 << all.num_bits;
size_t stride = all.stride;
weight* weights = all.reg.weight_vector;
- if (regularizers == NULL)
+ if (b.regularizers == NULL)
{
- regularizers = (weight *)calloc(2*length, sizeof(weight));
+ b.regularizers = (weight *)calloc(2*length, sizeof(weight));
- if (regularizers == NULL)
+ if (b.regularizers == NULL)
{
cerr << all.program_name << ": Failed to allocate weight array: try decreasing -b <bits>" << endl;
exit (1);
}
for(uint32_t i = 0; i < length; i++)
- regularizers[2*i] = weights[stride*i+W_COND] + regularization;
+ b.regularizers[2*i] = weights[stride*i+W_COND] + regularization;
}
else
for(uint32_t i = 0; i < length; i++)
- regularizers[2*i] = weights[stride*i+W_COND] + regularizers[2*i];
+ b.regularizers[2*i] = weights[stride*i+W_COND] + b.regularizers[2*i];
for(uint32_t i = 0; i < length; i++)
- regularizers[2*i+1] = weights[stride*i];
+ b.regularizers[2*i+1] = weights[stride*i];
}
void zero_state(vw& all)
@@ -588,15 +589,15 @@ void zero_state(vw& all)
}
}
-double derivative_in_direction(vw& all, float* mem, int &origin)
+double derivative_in_direction(vw& all, bfgs& b, float* mem, int &origin)
{
double ret = 0.;
uint32_t length = 1 << all.num_bits;
size_t stride = all.stride;
weight* w = all.reg.weight_vector;
- for(uint32_t i = 0; i < length; i++, w+=stride, mem+=mem_stride)
- ret += mem[(MEM_GT+origin)%mem_stride]*w[W_DIR];
+ for(uint32_t i = 0; i < length; i++, w+=stride, mem+=b.mem_stride)
+ ret += mem[(MEM_GT+origin)%b.mem_stride]*w[W_DIR];
return ret;
}
@@ -611,63 +612,66 @@ void update_weight(vw& all, string& reg_name, float step_size, size_t current_pa
save_predictor(all, reg_name, current_pass);
}
-int process_pass(vw& all) {
+int process_pass(vw& all, bfgs& b) {
int status = LEARN_OK;
/********************************************************************/
/* A) FIRST PASS FINISHED: INITIALIZE FIRST LINE SEARCH *************/
/********************************************************************/
- if (first_pass) {
+ if (b.first_pass) {
if(all.span_server != "")
{
accumulate(all, all.span_server, all.reg, W_COND); //Accumulate preconditioner
- importance_weight_sum = accumulate_scalar(all, all.span_server, (float)importance_weight_sum);
+ float temp = b.importance_weight_sum;
+ b.importance_weight_sum = accumulate_scalar(all, all.span_server, temp);
}
- finalize_preconditioner(all, all.l2_lambda);
+ finalize_preconditioner(all, b, all.l2_lambda);
if(all.span_server != "") {
- loss_sum = accumulate_scalar(all, all.span_server, (float)loss_sum); //Accumulate loss_sums
+ float temp = (float)b.loss_sum;
+ b.loss_sum = accumulate_scalar(all, all.span_server, temp); //Accumulate loss_sums
accumulate(all, all.span_server, all.reg, 1); //Accumulate gradients from all nodes
}
if (all.l2_lambda > 0.)
- loss_sum += add_regularization(all, all.l2_lambda);
+ b.loss_sum += add_regularization(all, b, all.l2_lambda);
if (!all.quiet)
- fprintf(stderr, "%2lu %-10.5f\t", (long unsigned int)current_pass+1, loss_sum / importance_weight_sum);
+ fprintf(stderr, "%2lu %-10.5f\t", (long unsigned int)b.current_pass+1, b.loss_sum / b.importance_weight_sum);
- previous_loss_sum = loss_sum;
- loss_sum = 0.;
- example_number = 0;
- curvature = 0;
- bfgs_iter_start(all, mem, lastj, importance_weight_sum, origin);
- if (first_hessian_on) {
- gradient_pass = false;//now start computing curvature
+ b.previous_loss_sum = b.loss_sum;
+ b.loss_sum = 0.;
+ b.example_number = 0;
+ b.curvature = 0;
+ bfgs_iter_start(all, b, b.mem, b.lastj, b.importance_weight_sum, b.origin);
+ if (b.first_hessian_on) {
+ b.gradient_pass = false;//now start computing curvature
}
else {
- step_size = 0.5;
+ b.step_size = 0.5;
float d_mag = direction_magnitude(all);
- ftime(&t_end_global);
- net_time = (int) (1000.0 * (t_end_global.time - t_start_global.time) + (t_end_global.millitm - t_start_global.millitm));
+ ftime(&b.t_end_global);
+ b.net_time = (int) (1000.0 * (b.t_end_global.time - b.t_start_global.time) + (b.t_end_global.millitm - b.t_start_global.millitm));
if (!all.quiet)
- fprintf(stderr, "%-10s\t%-10.5f\t%-10.5f\n", "", d_mag, step_size);
- predictions.erase();
- update_weight(all, all.final_regressor_name, step_size, current_pass); }
+ fprintf(stderr, "%-10s\t%-10.5f\t%-10.5f\n", "", d_mag, b.step_size);
+ b.predictions.erase();
+ update_weight(all, all.final_regressor_name, b.step_size, b.current_pass); }
}
else
/********************************************************************/
/* B) GRADIENT CALCULATED *******************************************/
/********************************************************************/
- if (gradient_pass) // We just finished computing all gradients
+ if (b.gradient_pass) // We just finished computing all gradients
{
if(all.span_server != "") {
- loss_sum = accumulate_scalar(all, all.span_server, (float)loss_sum); //Accumulate loss_sums
+ float t = (float)b.loss_sum;
+ b.loss_sum = accumulate_scalar(all, all.span_server, t); //Accumulate loss_sums
accumulate(all, all.span_server, all.reg, 1); //Accumulate gradients from all nodes
}
if (all.l2_lambda > 0.)
- loss_sum += add_regularization(all, all.l2_lambda);
+ b.loss_sum += add_regularization(all, b, all.l2_lambda);
if (!all.quiet)
- fprintf(stderr, "%2lu %-10.5f\t", (long unsigned int)current_pass+1, loss_sum / importance_weight_sum);
+ fprintf(stderr, "%2lu %-10.5f\t", (long unsigned int)b.current_pass+1, b.loss_sum / b.importance_weight_sum);
double wolfe1;
- double new_step = wolfe_eval(all, mem, loss_sum, previous_loss_sum, step_size, importance_weight_sum, origin, wolfe1);
+ double new_step = wolfe_eval(all, b, b.mem, b.loss_sum, b.previous_loss_sum, b.step_size, b.importance_weight_sum, b.origin, wolfe1);
/********************************************************************/
/* B0) DERIVATIVE ZERO: MINIMUM FOUND *******************************/
@@ -676,26 +680,26 @@ int process_pass(vw& all) {
{
fprintf(stderr, "\n");
fprintf(stdout, "Derivative 0 detected.\n");
- step_size=0.0;
+ b.step_size=0.0;
status = LEARN_CONV;
}
/********************************************************************/
/* B1) LINE SEARCH FAILED *******************************************/
/********************************************************************/
- else if (backstep_on && (wolfe1<wolfe1_bound || loss_sum > previous_loss_sum))
+ else if (b.backstep_on && (wolfe1<b.wolfe1_bound || b.loss_sum > b.previous_loss_sum))
{// curvature violated, or we stepped too far last time: step back
- ftime(&t_end_global);
- net_time = (int) (1000.0 * (t_end_global.time - t_start_global.time) + (t_end_global.millitm - t_start_global.millitm));
- float ratio = (step_size==0.f) ? 0.f : (float)new_step/(float)step_size;
+ ftime(&b.t_end_global);
+ b.net_time = (int) (1000.0 * (b.t_end_global.time - b.t_start_global.time) + (b.t_end_global.millitm - b.t_start_global.millitm));
+ float ratio = (b.step_size==0.f) ? 0.f : (float)new_step/(float)b.step_size;
if (!all.quiet)
fprintf(stderr, "%-10s\t%-10s\t(revise x %.1f)\t%-10.5f\n",
"","",ratio,
new_step);
- predictions.erase();
- update_weight(all, all.final_regressor_name, (float)(-step_size+new_step), current_pass);
- step_size = (float)new_step;
+ b.predictions.erase();
+ update_weight(all, all.final_regressor_name, (float)(-b.step_size+new_step), b.current_pass);
+ b.step_size = (float)new_step;
zero_derivative(all);
- loss_sum = 0.;
+ b.loss_sum = 0.;
}
/********************************************************************/
@@ -703,38 +707,38 @@ int process_pass(vw& all) {
/* DETERMINE NEXT SEARCH DIRECTION ******************/
/********************************************************************/
else {
- double rel_decrease = (previous_loss_sum-loss_sum)/previous_loss_sum;
- if (!nanpattern((float)rel_decrease) && backstep_on && fabs(rel_decrease)<all.rel_threshold) {
+ double rel_decrease = (b.previous_loss_sum-b.loss_sum)/b.previous_loss_sum;
+ if (!nanpattern((float)rel_decrease) && b.backstep_on && fabs(rel_decrease)<all.rel_threshold) {
fprintf(stdout, "\nTermination condition reached in pass %ld: decrease in loss less than %.3f%%.\n"
- "If you want to optimize further, decrease termination threshold.\n", (long int)current_pass+1, all.rel_threshold*100.0);
+ "If you want to optimize further, decrease termination threshold.\n", (long int)b.current_pass+1, all.rel_threshold*100.0);
status = LEARN_CONV;
}
- previous_loss_sum = loss_sum;
- loss_sum = 0.;
- example_number = 0;
- curvature = 0;
- step_size = 1.0;
+ b.previous_loss_sum = b.loss_sum;
+ b.loss_sum = 0.;
+ b.example_number = 0;
+ b.curvature = 0;
+ b.step_size = 1.0;
try {
- bfgs_iter_middle(all, mem, rho, alpha, lastj, origin);
+ bfgs_iter_middle(all, b, b.mem, b.rho, b.alpha, b.lastj, b.origin);
}
catch (curv_exception e) {
fprintf(stdout, "In bfgs_iter_middle: %s", curv_message);
- step_size=0.0;
+ b.step_size=0.0;
status = LEARN_CURV;
}
if (all.hessian_on) {
- gradient_pass = false;//now start computing curvature
+ b.gradient_pass = false;//now start computing curvature
}
else {
float d_mag = direction_magnitude(all);
- ftime(&t_end_global);
- net_time = (int) (1000.0 * (t_end_global.time - t_start_global.time) + (t_end_global.millitm - t_start_global.millitm));
+ ftime(&b.t_end_global);
+ b.net_time = (int) (1000.0 * (b.t_end_global.time - b.t_start_global.time) + (b.t_end_global.millitm - b.t_start_global.millitm));
if (!all.quiet)
- fprintf(stderr, "%-10s\t%-10.5f\t%-10.5f\n", "", d_mag, step_size);
- predictions.erase();
- update_weight(all, all.final_regressor_name, step_size, current_pass);
+ fprintf(stderr, "%-10s\t%-10.5f\t%-10.5f\n", "", d_mag, b.step_size);
+ b.predictions.erase();
+ update_weight(all, all.final_regressor_name, b.step_size, b.current_pass);
}
}
}
@@ -745,57 +749,58 @@ int process_pass(vw& all) {
else // just finished all second gradients
{
if(all.span_server != "") {
- curvature = accumulate_scalar(all, all.span_server, (float)curvature); //Accumulate curvatures
+ float t = (float)b.curvature;
+ b.curvature = accumulate_scalar(all, all.span_server, t); //Accumulate curvatures
}
if (all.l2_lambda > 0.)
- curvature += regularizer_direction_magnitude(all, all.l2_lambda);
- float dd = (float)derivative_in_direction(all, mem, origin);
- if (curvature == 0. && dd != 0.)
+ b.curvature += regularizer_direction_magnitude(all, b, all.l2_lambda);
+ float dd = (float)derivative_in_direction(all, b, b.mem, b.origin);
+ if (b.curvature == 0. && dd != 0.)
{
fprintf(stdout, "%s", curv_message);
- step_size=0.0;
+ b.step_size=0.0;
status = LEARN_CURV;
}
else if ( dd == 0.)
{
fprintf(stdout, "Derivative 0 detected.\n");
- step_size=0.0;
+ b.step_size=0.0;
status = LEARN_CONV;
}
else
- step_size = - dd/(float)curvature;
+ b.step_size = - dd/(float)b.curvature;
float d_mag = direction_magnitude(all);
- predictions.erase();
- update_weight(all, all.final_regressor_name , step_size, current_pass);
- ftime(&t_end_global);
- net_time = (int) (1000.0 * (t_end_global.time - t_start_global.time) + (t_end_global.millitm - t_start_global.millitm));
+ b.predictions.erase();
+ update_weight(all, all.final_regressor_name , b.step_size, b.current_pass);
+ ftime(&b.t_end_global);
+ b.net_time = (int) (1000.0 * (b.t_end_global.time - b.t_start_global.time) + (b.t_end_global.millitm - b.t_start_global.millitm));
if (!all.quiet)
- fprintf(stderr, "%-10.5f\t%-10.5f\t%-10.5f\n", curvature / importance_weight_sum, d_mag, step_size);
- gradient_pass = true;
+ fprintf(stderr, "%-10.5f\t%-10.5f\t%-10.5f\n", b.curvature / b.importance_weight_sum, d_mag, b.step_size);
+ b.gradient_pass = true;
}//now start computing derivatives.
- current_pass++;
- first_pass = false;
- preconditioner_pass = false;
+ b.current_pass++;
+ b.first_pass = false;
+ b.preconditioner_pass = false;
return status;
}
-void process_example(vw& all, example *ec)
+void process_example(vw& all, bfgs& b, example *ec)
{
label_data* ld = (label_data*)ec->ld;
- if (first_pass)
- importance_weight_sum += ld->weight;
+ if (b.first_pass)
+ b.importance_weight_sum += ld->weight;
/********************************************************************/
/* I) GRADIENT CALCULATION ******************************************/
/********************************************************************/
- if (gradient_pass)
+ if (b.gradient_pass)
{
ec->final_prediction = predict_and_gradient(all, ec);//w[0] & w[1]
ec->loss = all.loss->getLoss(all.sd, ec->final_prediction, ld->label) * ld->weight;
- loss_sum += ec->loss;
- predictions.push_back(ec->final_prediction);
+ b.loss_sum += ec->loss;
+ b.predictions.push_back(ec->final_prediction);
}
/********************************************************************/
/* II) CURVATURE CALCULATION ****************************************/
@@ -803,62 +808,64 @@ void process_example(vw& all, example *ec)
else //computing curvature
{
float d_dot_x = dot_with_direction(all, ec);//w[2]
- if (example_number >= predictions.size())//Make things safe in case example source is strange.
- example_number = predictions.size()-1;
- ec->final_prediction = predictions[example_number];
- ec->partial_prediction = predictions[example_number];
+ if (b.example_number >= b.predictions.size())//Make things safe in case example source is strange.
+ b.example_number = b.predictions.size()-1;
+ ec->final_prediction = b.predictions[b.example_number];
+ ec->partial_prediction = b.predictions[b.example_number];
ec->loss = all.loss->getLoss(all.sd, ec->final_prediction, ld->label) * ld->weight;
- float sd = all.loss->second_derivative(all.sd, predictions[example_number++],ld->label);
- curvature += d_dot_x*d_dot_x*sd*ld->weight;
+ float sd = all.loss->second_derivative(all.sd, b.predictions[b.example_number++],ld->label);
+ b.curvature += d_dot_x*d_dot_x*sd*ld->weight;
}
- if (preconditioner_pass)
+ if (b.preconditioner_pass)
update_preconditioner(all, ec);//w[3]
}
-void learn(void* a, example* ec)
+void learn(void* a, void* d, example* ec)
{
vw* all = (vw*)a;
+ bfgs* b = (bfgs*)d;
assert(ec->in_use);
- if (ec->pass != current_pass) {
- int status = process_pass(*all);
+ if (ec->pass != b->current_pass) {
+ int status = process_pass(*all, *b);
if (status != LEARN_OK)
- reset_state(*all, true);
- else if (output_regularizer && current_pass==all->numpasses-1) {
+ reset_state(*all, *b, true);
+ else if (b->output_regularizer && b->current_pass==all->numpasses-1) {
zero_preconditioner(*all);
- preconditioner_pass = true;
+ b->preconditioner_pass = true;
}
}
if (test_example(ec))
ec->final_prediction = bfgs_predict(*all,ec);//w[0]
else
- process_example(*all, ec);
+ process_example(*all, *b, ec);
}
-void finish(void* a)
+void finish(void* a, void* d)
{
vw* all = (vw*)a;
- if (current_pass != 0 && !output_regularizer)
- process_pass(*all);
+ bfgs* b = (bfgs*)d;
+ if (b->current_pass != 0 && !b->output_regularizer)
+ process_pass(*all, *b);
if (!all->quiet)
fprintf(stderr, "\n");
- if (output_regularizer)//need to accumulate and place the regularizer.
+ if (b->output_regularizer)//need to accumulate and place the regularizer.
{
if(all->span_server != "")
accumulate(*all, all->span_server, all->reg, W_COND); //Accumulate preconditioner
- preconditioner_to_regularizer(*all, all->l2_lambda);
+ preconditioner_to_regularizer(*all, *b, all->l2_lambda);
}
- ftime(&t_end_global);
- net_time = (int) (1000.0 * (t_end_global.time - t_start_global.time) + (t_end_global.millitm - t_start_global.millitm));
+ ftime(&b->t_end_global);
+ b->net_time = (int) (1000.0 * (b->t_end_global.time - b->t_start_global.time) + (b->t_end_global.millitm - b->t_start_global.millitm));
- predictions.delete_v();
- free(mem);
- free(rho);
- free(alpha);
+ b->predictions.delete_v();
+ free(b->mem);
+ free(b->rho);
+ free(b->alpha);
}
-void save_load_regularizer(vw& all, io_buf& model_file, bool read, bool text)
+void save_load_regularizer(vw& all, bfgs& b, io_buf& model_file, bool read, bool text)
{
char buff[512];
int c = 0;
@@ -877,14 +884,14 @@ void save_load_regularizer(vw& all, io_buf& model_file, bool read, bool text)
if (brw > 0)
{
assert (i< length);
- v = &(regularizers[i]);
+ v = &(b.regularizers[i]);
if (brw > 0)
brw += bin_read_fixed(model_file, (char*)v, sizeof(*v), "");
}
}
else // write binary or text
{
- v = &(regularizers[i]);
+ v = &(b.regularizers[i]);
if (*v != 0.)
{
c++;
@@ -906,9 +913,10 @@ void save_load_regularizer(vw& all, io_buf& model_file, bool read, bool text)
}
-void save_load(void* in, io_buf& model_file, bool read, bool text)
+void save_load(void* in, void* d, io_buf& model_file, bool read, bool text)
{
vw* all = (vw*)in;
+ bfgs* b = (bfgs*)d;
uint32_t length = 1 << all->num_bits;
@@ -917,8 +925,8 @@ void save_load(void* in, io_buf& model_file, bool read, bool text)
initialize_regressor(*all);
if (all->per_feature_regularizer_input != "")
{
- regularizers = (weight *)calloc(2*length, sizeof(weight));
- if (regularizers == NULL)
+ b->regularizers = (weight *)calloc(2*length, sizeof(weight));
+ if (b->regularizers == NULL)
{
cerr << all->program_name << ": Failed to allocate regularizers array: try decreasing -b <bits>" << endl;
exit (1);
@@ -926,18 +934,18 @@ void save_load(void* in, io_buf& model_file, bool read, bool text)
}
int m = all->m;
- mem_stride = (m==0) ? CG_EXTRA : 2*m;
- mem = (float*) malloc(sizeof(float)*all->length()*(mem_stride));
- rho = (double*) malloc(sizeof(double)*m);
- alpha = (double*) malloc(sizeof(double)*m);
+ b->mem_stride = (m==0) ? CG_EXTRA : 2*m;
+ b->mem = (float*) malloc(sizeof(float)*all->length()*(b->mem_stride));
+ b->rho = (double*) malloc(sizeof(double)*m);
+ b->alpha = (double*) malloc(sizeof(double)*m);
if (!all->quiet)
{
- fprintf(stderr, "m = %d\nAllocated %luM for weights and mem\n", m, (long unsigned int)all->length()*(sizeof(float)*(mem_stride)+sizeof(weight)*all->stride) >> 20);
+ fprintf(stderr, "m = %d\nAllocated %luM for weights and mem\n", m, (long unsigned int)all->length()*(sizeof(float)*(b->mem_stride)+sizeof(weight)*all->stride) >> 20);
}
- net_time = 0.0;
- ftime(&t_start_global);
+ b->net_time = 0.0;
+ ftime(&b->t_start_global);
if (!all->quiet)
{
@@ -947,13 +955,13 @@ void save_load(void* in, io_buf& model_file, bool read, bool text)
cerr.precision(5);
}
- if (regularizers != NULL)
+ if (b->regularizers != NULL)
all->l2_lambda = 1; // To make sure we are adding the regularization
- output_regularizer = (all->per_feature_regularizer_output != "" || all->per_feature_regularizer_text != "");
- reset_state(*all, false);
+ b->output_regularizer = (all->per_feature_regularizer_output != "" || all->per_feature_regularizer_text != "");
+ reset_state(*all, *b, false);
}
- bool reg_vector = output_regularizer || all->per_feature_regularizer_input.length() > 0;
+ bool reg_vector = b->output_regularizer || all->per_feature_regularizer_input.length() > 0;
if (model_file.files.size() > 0)
{
char buff[512];
@@ -963,21 +971,22 @@ void save_load(void* in, io_buf& model_file, bool read, bool text)
buff, text_len, text);
if (reg_vector)
- save_load_regularizer(*all, model_file, read, text);
+ save_load_regularizer(*all, *b, model_file, read, text);
else
GD::save_load_regressor(*all, model_file, read, text);
}
}
-void drive(void* in)
+void drive(void* in, void* d)
{
vw* all = (vw*)in;
+ bfgs* b = (bfgs*)d;
example* ec = NULL;
size_t final_pass=all->numpasses-1;
- first_hessian_on = true;
- backstep_on = true;
+ b->first_hessian_on = true;
+ b->backstep_on = true;
while ( true )
{
@@ -986,22 +995,22 @@ void drive(void* in)
assert(ec->in_use);
if (ec->pass<=final_pass) {
- if (ec->pass != current_pass) {
- int status = process_pass(*all);
- if (status != LEARN_OK && final_pass>current_pass) {
- final_pass = current_pass;
+ if (ec->pass != b->current_pass) {
+ int status = process_pass(*all, *b);
+ if (status != LEARN_OK && final_pass>b->current_pass) {
+ final_pass = b->current_pass;
}
- if (output_regularizer && current_pass==final_pass) {
+ if (b->output_regularizer && b->current_pass==final_pass) {
zero_preconditioner(*all);
- preconditioner_pass = true;
+ b->preconditioner_pass = true;
}
}
- process_example(*all, ec);
+ process_example(*all, *b, ec);
}
-
+
return_simple_example(*all, ec);
}
- else if (parser_done(all->p))
+ else if (parser_done(all->p))
{
// finish(all);
return;
@@ -1011,5 +1020,38 @@ void drive(void* in)
}
}
-}
+void parse_args(vw& all, std::vector<std::string>&opts, po::variables_map& vm, po::variables_map& vm_file)
+{
+ bfgs* b = (bfgs*)calloc(1,sizeof(bfgs));
+ b->wolfe1_bound = 0.01;
+ b->first_hessian_on=true;
+ b->first_pass = true;
+ b->gradient_pass = true;
+ b->preconditioner_pass = true;
+
+ learner t = {b,drive,learn,finish,save_load};
+ all.l = t;
+ all.bfgs = true;
+ all.stride = 4;
+
+ if (vm.count("hessian_on") || all.m==0) {
+ all.hessian_on = true;
+ }
+ if (!all.quiet) {
+ if (all.m>0)
+ cerr << "enabling BFGS based optimization ";
+ else
+ cerr << "enabling conjugate gradient optimization via BFGS ";
+ if (all.hessian_on)
+ cerr << "with curvature calculation" << endl;
+ else
+ cerr << "**without** curvature calculation" << endl;
+ }
+ if (all.numpasses < 2)
+ {
+ cout << "you must make at least 2 passes to use BFGS" << endl;
+ exit(1);
+ }
+}
+}
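
A note on the step-size logic threaded through process_pass above: the curvature pass accumulates d'Hd one example at a time (curvature += d_dot_x*d_dot_x*sd*ld->weight), the derivative along the search direction is dd = derivative_in_direction(...), and the one-dimensional Newton step is step_size = -dd/curvature. The following is a minimal standalone sketch of that formula on hypothetical toy data (squared loss, one weight); it is illustrative only and not part of this patch.

#include <cstdio>
#include <vector>

// Hypothetical toy problem: squared loss f(w) = 0.5 * sum_i (w*x_i - y_i)^2
// with a single weight w and a fixed search direction d. The quantities
// mirror process_pass/process_example above:
//   dd        = gradient . d            (derivative_in_direction)
//   curvature = sum_i (d*x_i)^2 * f''   (d'Hd; f'' = 1 for this loss)
//   step_size = -dd / curvature         (the Newton step taken above)
int main() {
  std::vector<double> x = {1.0, 2.0, 3.0}, y = {2.0, 3.9, 6.1};
  double w = 0.0, d = 1.0;
  double dd = 0.0, curvature = 0.0;
  for (size_t i = 0; i < x.size(); i++) {
    double residual = w * x[i] - y[i];   // dloss/dprediction
    dd += residual * x[i] * d;           // accumulate g . d
    double d_dot_x = d * x[i];           // like dot_with_direction(...)
    curvature += d_dot_x * d_dot_x;      // second derivative is 1 here
  }
  double step_size = -dd / curvature;
  printf("step_size = %f, updated w = %f\n", step_size, w + step_size * d);
  return 0;
}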
diff --git a/vowpalwabbit/bfgs.h b/vowpalwabbit/bfgs.h
index def78634..fe1356ac 100644
--- a/vowpalwabbit/bfgs.h
+++ b/vowpalwabbit/bfgs.h
@@ -8,11 +8,7 @@ license as described in the file LICENSE.
#include "gd.h"
namespace BFGS {
-
- void drive(void*);
- void finish(void*);
- void learn(void*, example* ec);
- void save_load(void* in, io_buf& model_file, bool read, bool text);
+ void parse_args(vw& all, std::vector<std::string>&opts, po::variables_map& vm, po::variables_map& vm_file);
}
#endif
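
The free functions removed from this header now live behind the learner value stored in all.l (see learner t = {b,drive,learn,finish,save_load}; in bfgs.cc above). The new vowpalwabbit/learner.h added by this commit is not shown on this page; the sketch below is only an inference from the brace initializers and call sites in this patch, not the actual definition.

// Inferred sketch -- the real struct is in vowpalwabbit/learner.h, which
// this page does not include. Field order matches the initializers used
// throughout this patch: {data, drive, learn, finish, save_load}.
struct learner {
  void* data;  // reduction-private state: a bfgs*, cb*, csoaa*, ...
  void (*drive)(void* all, void* data);
  void (*learn)(void* all, void* data, example* ec);
  void (*finish)(void* all, void* data);
  void (*save_load)(void* all, void* data, io_buf& model_file, bool read, bool text);
};

Each reduction saves the learner it wraps (e.g. c->base = all.l in csoaa.cc below) and then installs itself in all.l, so calls chain through c->base.learn(all, c->base.data, ec) instead of through the file-scope base_learner pointers this commit deletes.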
diff --git a/vowpalwabbit/cb.cc b/vowpalwabbit/cb.cc
index e631f334..6fbd6c04 100644
--- a/vowpalwabbit/cb.cc
+++ b/vowpalwabbit/cb.cc
@@ -15,14 +15,18 @@ license as described in the file LICENSE.
namespace CB
{
- uint32_t increment = 0;
- size_t cb_type = 0;
- CSOAA::label cb_cs_ld;
- float avg_loss_regressors = 0.;
- size_t nb_ex_regressors = 0;
- float last_pred_reg = 0.;
- float last_correct_cost = 0.;
- bool first_print_call = true;
+ struct cb {
+ uint32_t increment;
+ size_t cb_type;
+ CSOAA::label cb_cs_ld;
+ float avg_loss_regressors;
+ size_t nb_ex_regressors;
+ float last_pred_reg;
+ float last_correct_cost;
+ bool first_print_call;
+
+ learner base;
+ };
bool know_all_cost_example(CB::label* ld)
{
@@ -212,14 +216,8 @@ namespace CB
return NULL;
}
- void (*base_learner)(void*, example*) = NULL; //base learning algorithm (gd,bfgs,etc...) for training regressors of cb
- void (*base_learner_cs)(void*, example*) = NULL; //base learner for cost-sensitive data
- void (*base_finish)(void*) = NULL;
-
- void gen_cs_example_ips(void* a, example* ec, CSOAA::label& cs_ld)
- {
- //this implements the inverse propensity score method, where costs are importance weighted by the probability of the chosen action
- vw* all = (vw*)a;
+ void gen_cs_example_ips(vw& all, cb& c, example* ec, CSOAA::label& cs_ld)
+ {//this implements the inverse propensity score method, where costs are importance weighted by the probability of the chosen action
CB::label* ld = (CB::label*)ec->ld;
cb_class* cl_obs = get_observed_cost(ld);
@@ -228,7 +226,7 @@ namespace CB
cs_ld.costs.erase();
if( ld->costs.size() == 1) { //this is a typical example where we can perform all actions
//in this case generate cost-sensitive example with all actions
- for(uint32_t i = 1; i <= all->sd->k; i++)
+ for(uint32_t i = 1; i <= all.sd->k; i++)
{
CSOAA::wclass wc;
wc.wap_value = 0.;
@@ -242,10 +240,10 @@ namespace CB
//ips can be thought of as the doubly robust method with a fixed regressor that predicts 0 costs for everything
//update the loss of this regressor
- nb_ex_regressors++;
- avg_loss_regressors += (1.0f/nb_ex_regressors)*( (cl_obs->x)*(cl_obs->x) - avg_loss_regressors );
- last_pred_reg = 0;
- last_correct_cost = cl_obs->x;
+ c.nb_ex_regressors++;
+ c.avg_loss_regressors += (1.0f/c.nb_ex_regressors)*( (cl_obs->x)*(cl_obs->x) - c.avg_loss_regressors );
+ c.last_pred_reg = 0;
+ c.last_correct_cost = cl_obs->x;
}
cs_ld.costs.push_back(wc );
@@ -267,10 +265,10 @@ namespace CB
//ips can be thought of as the doubly robust method with a fixed regressor that predicts 0 costs for everything
//update the loss of this regressor
- nb_ex_regressors++;
- avg_loss_regressors += (1.0f/nb_ex_regressors)*( (cl_obs->x)*(cl_obs->x) - avg_loss_regressors );
- last_pred_reg = 0;
- last_correct_cost = cl_obs->x;
+ c.nb_ex_regressors++;
+ c.avg_loss_regressors += (1.0f/c.nb_ex_regressors)*( (cl_obs->x)*(cl_obs->x) - c.avg_loss_regressors );
+ c.last_pred_reg = 0;
+ c.last_correct_cost = cl_obs->x;
}
cs_ld.costs.push_back( wc );
@@ -279,9 +277,8 @@ namespace CB
}
- float get_cost_pred(void* a, example* ec, uint32_t index)
+ float get_cost_pred(vw& all, cb& c, example* ec, uint32_t index)
{
- vw* all = (vw*)a;
CB::label* ld = (CB::label*)ec->ld;
label_data simple_temp;
@@ -292,11 +289,11 @@ namespace CB
ec->ld = &simple_temp;
ec->partial_prediction = 0.;
- uint32_t desired_increment = increment * (2*index-1);
+ uint32_t desired_increment = c.increment * (2*index-1);
- update_example_indicies(all->audit, ec, desired_increment);
- base_learner(all, ec);
- update_example_indicies(all->audit, ec, -desired_increment);
+ update_example_indicies(all.audit, ec, desired_increment);
+ all.scorer.learn(&all, all.scorer.data, ec);
+ update_example_indicies(all.audit, ec, -desired_increment);
ec->ld = ld;
@@ -307,10 +304,9 @@ namespace CB
}
//this function below was a test to see if we save time by carefully organizing the feature offset/regression calls, but it seems to take the same time as gen_cs_example_dm
- void gen_cs_example_dm2(void* a, example* ec, CSOAA::label& cs_ld)
+ void gen_cs_example_dm2(vw& all, cb& c, example* ec, CSOAA::label& cs_ld)
{
//this implements the direct estimation method, where costs are directly specified by the learned regressor.
- vw* all = (vw*)a;
CB::label* ld = (CB::label*)ec->ld;
cb_class* cl_obs = get_observed_cost(ld);
@@ -329,16 +325,16 @@ namespace CB
cs_ld.costs.erase();
if( ld->costs.size() == 1) { //this is a typical example where we can perform all actions
//in this case generate cost-sensitive example with all actions
- for( uint32_t i = 1; i <= all->sd->k; i++)
+ for( uint32_t i = 1; i <= all.sd->k; i++)
{
CSOAA::wclass wc;
wc.wap_value = 0.;
ec->partial_prediction = 0.;
- desired_increment = increment * (2*i-1);
- update_example_indicies(all->audit, ec, desired_increment-current_increment);
+ desired_increment = c.increment * (2*i-1);
+ update_example_indicies(all.audit, ec, desired_increment-current_increment);
current_increment = desired_increment;
- base_learner(all, ec);
+ all.scorer.learn(&all, all.scorer.data, ec);
//get cost prediction for this action
wc.x = ec->partial_prediction;
@@ -347,10 +343,10 @@ namespace CB
wc.wap_value = 0.;
if( cl_obs != NULL && cl_obs->weight_index == i ) {
- nb_ex_regressors++;
- avg_loss_regressors += (1.0f/nb_ex_regressors)*( (cl_obs->x - wc.x)*(cl_obs->x - wc.x) - avg_loss_regressors );
- last_pred_reg = wc.x;
- last_correct_cost = cl_obs->x;
+ c.nb_ex_regressors++;
+ c.avg_loss_regressors += (1.0f/c.nb_ex_regressors)*( (cl_obs->x - wc.x)*(cl_obs->x - wc.x) - c.avg_loss_regressors );
+ c.last_pred_reg = wc.x;
+ c.last_correct_cost = cl_obs->x;
}
cs_ld.costs.push_back( wc );
@@ -364,10 +360,10 @@ namespace CB
wc.wap_value = 0.;
ec->partial_prediction = 0.;
- desired_increment = increment * (2*cl->weight_index-1);
- update_example_indicies(all->audit, ec, desired_increment-current_increment);
+ desired_increment = c.increment * (2*cl->weight_index-1);
+ update_example_indicies(all.audit, ec, desired_increment-current_increment);
current_increment = desired_increment;
- base_learner(all, ec);
+ all.scorer.learn(&all, all.scorer.data, ec);
//get cost prediction for this action
wc.x = ec->partial_prediction;
@@ -376,10 +372,10 @@ namespace CB
wc.wap_value = 0.;
if( cl_obs != NULL && cl_obs->weight_index == cl->weight_index ) {
- nb_ex_regressors++;
- avg_loss_regressors += (1.0f/nb_ex_regressors)*( (cl_obs->x - wc.x)*(cl_obs->x - wc.x) - avg_loss_regressors );
- last_pred_reg = wc.x;
- last_correct_cost = cl_obs->x;
+ c.nb_ex_regressors++;
+ c.avg_loss_regressors += (1.0f/c.nb_ex_regressors)*( (cl_obs->x - wc.x)*(cl_obs->x - wc.x) - c.avg_loss_regressors );
+ c.last_pred_reg = wc.x;
+ c.last_correct_cost = cl_obs->x;
}
cs_ld.costs.push_back( wc );
@@ -388,13 +384,12 @@ namespace CB
ec->ld = ld;
ec->partial_prediction = 0.;
- update_example_indicies(all->audit, ec, -current_increment);
+ update_example_indicies(all.audit, ec, -current_increment);
}
- void gen_cs_example_dm(void* a, example* ec, CSOAA::label& cs_ld)
+ void gen_cs_example_dm(vw& all, cb& c, example* ec, CSOAA::label& cs_ld)
{
//this implements the direct estimation method, where costs are directly specified by the learned regressor.
- vw* all = (vw*)a;
CB::label* ld = (CB::label*)ec->ld;
cb_class* cl_obs = get_observed_cost(ld);
@@ -403,22 +398,22 @@ namespace CB
cs_ld.costs.erase();
if( ld->costs.size() == 1) { //this is a typical example where we can perform all actions
//in this case generate cost-sensitive example with all actions
- for(uint32_t i = 1; i <= all->sd->k; i++)
+ for(uint32_t i = 1; i <= all.sd->k; i++)
{
CSOAA::wclass wc;
wc.wap_value = 0.;
//get cost prediction for this action
- wc.x = get_cost_pred(a,ec,i);
+ wc.x = get_cost_pred(all, c,ec,i);
wc.weight_index = i;
wc.partial_prediction = 0.;
wc.wap_value = 0.;
if( cl_obs != NULL && cl_obs->weight_index == i ) {
- nb_ex_regressors++;
- avg_loss_regressors += (1.0f/nb_ex_regressors)*( (cl_obs->x - wc.x)*(cl_obs->x - wc.x) - avg_loss_regressors );
- last_pred_reg = wc.x;
- last_correct_cost = cl_obs->x;
+ c.nb_ex_regressors++;
+ c.avg_loss_regressors += (1.0f/c.nb_ex_regressors)*( (cl_obs->x - wc.x)*(cl_obs->x - wc.x) - c.avg_loss_regressors );
+ c.last_pred_reg = wc.x;
+ c.last_correct_cost = cl_obs->x;
}
cs_ld.costs.push_back( wc );
@@ -432,16 +427,16 @@ namespace CB
wc.wap_value = 0.;
//get cost prediction for this action
- wc.x = get_cost_pred(a,ec,cl->weight_index);
+ wc.x = get_cost_pred(all, c,ec,cl->weight_index);
wc.weight_index = cl->weight_index;
wc.partial_prediction = 0.;
wc.wap_value = 0.;
if( cl_obs != NULL && cl_obs->weight_index == cl->weight_index ) {
- nb_ex_regressors++;
- avg_loss_regressors += (1.0f/nb_ex_regressors)*( (cl_obs->x - wc.x)*(cl_obs->x - wc.x) - avg_loss_regressors );
- last_pred_reg = wc.x;
- last_correct_cost = cl_obs->x;
+ c.nb_ex_regressors++;
+ c.avg_loss_regressors += (1.0f/c.nb_ex_regressors)*( (cl_obs->x - wc.x)*(cl_obs->x - wc.x) - c.avg_loss_regressors );
+ c.last_pred_reg = wc.x;
+ c.last_correct_cost = cl_obs->x;
}
cs_ld.costs.push_back( wc );
@@ -449,35 +444,33 @@ namespace CB
}
}
- void gen_cs_example_dr(void* a, example* ec, CSOAA::label& cs_ld)
- {
- //this implements the doubly robust method
- vw* all = (vw*)a;
+ void gen_cs_example_dr(vw& all, cb& c, example* ec, CSOAA::label& cs_ld)
+ {//this implements the doubly robust method
CB::label* ld = (CB::label*)ec->ld;
-
+
cb_class* cl_obs = get_observed_cost(ld);
-
+
//generate cost sensitive example
cs_ld.costs.erase();
if( ld->costs.size() == 1) { //this is a typical example where we can perform all actions
//in this case generate cost-sensitive example with all actions
- for(uint32_t i = 1; i <= all->sd->k; i++)
+ for(uint32_t i = 1; i <= all.sd->k; i++)
{
CSOAA::wclass wc;
wc.wap_value = 0.;
//get cost prediction for this label
- wc.x = get_cost_pred(a,ec,i);
+ wc.x = get_cost_pred(all, c,ec,i);
wc.weight_index = i;
wc.partial_prediction = 0.;
wc.wap_value = 0.;
//add correction if we observed cost for this action and regressor is wrong
if( cl_obs != NULL && cl_obs->weight_index == i ) {
- nb_ex_regressors++;
- avg_loss_regressors += (1.0f/nb_ex_regressors)*( (cl_obs->x - wc.x)*(cl_obs->x - wc.x) - avg_loss_regressors );
- last_pred_reg = wc.x;
- last_correct_cost = cl_obs->x;
+ c.nb_ex_regressors++;
+ c.avg_loss_regressors += (1.0f/c.nb_ex_regressors)*( (cl_obs->x - wc.x)*(cl_obs->x - wc.x) - c.avg_loss_regressors );
+ c.last_pred_reg = wc.x;
+ c.last_correct_cost = cl_obs->x;
wc.x += (cl_obs->x - wc.x) / cl_obs->prob_action;
}
@@ -492,17 +485,17 @@ namespace CB
wc.wap_value = 0.;
//get cost prediction for this label
- wc.x = get_cost_pred(a,ec,cl->weight_index);
+ wc.x = get_cost_pred(all,c,ec,cl->weight_index);
wc.weight_index = cl->weight_index;
wc.partial_prediction = 0.;
wc.wap_value = 0.;
//add correction if we observed cost for this action and regressor is wrong
if( cl_obs != NULL && cl_obs->weight_index == cl->weight_index ) {
- nb_ex_regressors++;
- avg_loss_regressors += (1.0f/nb_ex_regressors)*( (cl_obs->x - wc.x)*(cl_obs->x - wc.x) - avg_loss_regressors );
- last_pred_reg = wc.x;
- last_correct_cost = cl_obs->x;
+ c.nb_ex_regressors++;
+ c.avg_loss_regressors += (1.0f/c.nb_ex_regressors)*( (cl_obs->x - wc.x)*(cl_obs->x - wc.x) - c.avg_loss_regressors );
+ c.last_pred_reg = wc.x;
+ c.last_correct_cost = cl_obs->x;
wc.x += (cl_obs->x - wc.x) / cl_obs->prob_action;
}
@@ -511,7 +504,7 @@ namespace CB
}
}
- void cb_test_to_cs_test_label(void* a, example* ec, CSOAA::label& cs_ld)
+ void cb_test_to_cs_test_label(vw& all, example* ec, CSOAA::label& cs_ld)
{
CB::label* ld = (CB::label*)ec->ld;
@@ -534,8 +527,9 @@ namespace CB
}
}
- void learn(void* a, example* ec) {
+ void learn(void* a, void* d, example* ec) {
vw* all = (vw*)a;
+ cb* c = (cb*)d;
CB::label* ld = (CB::label*)ec->ld;
float prediction = 1;
@@ -543,10 +537,10 @@ namespace CB
if( CB::is_test_label(ld) )
{
//if so just query base cost-sensitive learner
- cb_test_to_cs_test_label(a,ec,cb_cs_ld);
+ cb_test_to_cs_test_label(*all,ec,c->cb_cs_ld);
- ec->ld = &cb_cs_ld;
- base_learner_cs(all,ec);
+ ec->ld = &c->cb_cs_ld;
+ c->base.learn(all,c->base.data,ec);
ec->ld = ld;
return;
}
@@ -554,33 +548,33 @@ namespace CB
//now this is a training example
//generate a cost-sensitive example to update classifiers
- switch(cb_type)
+ switch(c->cb_type)
{
case CB_TYPE_IPS:
- gen_cs_example_ips(a,ec,cb_cs_ld);
+ gen_cs_example_ips(*all,*c,ec,c->cb_cs_ld);
break;
case CB_TYPE_DM:
- gen_cs_example_dm(a,ec,cb_cs_ld);
+ gen_cs_example_dm(*all,*c,ec,c->cb_cs_ld);
break;
case CB_TYPE_DR:
- gen_cs_example_dr(a,ec,cb_cs_ld);
+ gen_cs_example_dr(*all,*c,ec,c->cb_cs_ld);
break;
default:
- std::cerr << "Unknown cb_type specified for contextual bandit learning: " << cb_type << ". Exiting." << endl;
+ std::cerr << "Unknown cb_type specified for contextual bandit learning: " << c->cb_type << ". Exiting." << endl;
exit(1);
}
//update classifiers with cost-sensitive example
- ec->ld = &cb_cs_ld;
+ ec->ld = &c->cb_cs_ld;
ec->partial_prediction = 0.;
- base_learner_cs(all,ec);
+ c->base.learn(all,c->base.data,ec);
ec->ld = ld;
//store current class prediction
prediction = ec->final_prediction;
//update our regressors if we are training regressors
- if( cb_type == CB_TYPE_DM || cb_type == CB_TYPE_DR )
+ if( c->cb_type == CB_TYPE_DM || c->cb_type == CB_TYPE_DR )
{
cb_class* cl_obs = get_observed_cost(ld);
@@ -595,10 +589,10 @@ namespace CB
ec->ld = &simple_temp;
- uint32_t desired_increment = increment * (2*i-1);
+ uint32_t desired_increment = c->increment * (2*i-1);
update_example_indicies(all->audit, ec, desired_increment);
ec->partial_prediction = 0.;
- base_learner(all, ec);
+ all->scorer.learn(all, all->scorer.data, ec);
update_example_indicies(all->audit, ec, -desired_increment);
ec->ld = ld;
@@ -608,12 +602,12 @@ namespace CB
ec->final_prediction = prediction;
}
- void print_update(vw& all, bool is_test, example *ec)
+ void print_update(vw& all, cb& c, bool is_test, example *ec)
{
- if( first_print_call )
+ if( c.first_print_call )
{
fprintf(stderr, "*estimate* *estimate* avglossreg last pred last correct\n");
- first_print_call = false;
+ c.first_print_call = false;
}
if (all.sd->weighted_examples > all.sd->dump_interval && !all.quiet && !all.bfgs)
@@ -632,9 +626,9 @@ namespace CB
label_buf,
(long unsigned int)*(OAA::prediction_t*)&ec->final_prediction,
(long unsigned int)ec->num_features,
- avg_loss_regressors,
- last_pred_reg,
- last_correct_cost);
+ c.avg_loss_regressors,
+ c.last_pred_reg,
+ c.last_correct_cost);
all.sd->sum_loss_since_last_dump = 0.0;
all.sd->old_weighted_examples = all.sd->weighted_examples;
@@ -642,7 +636,7 @@ namespace CB
}
}
- void output_example(vw& all, example* ec)
+ void output_example(vw& all, cb& c, example* ec)
{
CB::label* ld = (CB::label*)ec->ld;
all.sd->weighted_examples += 1.;
@@ -666,7 +660,7 @@ namespace CB
}
else {
//we do not know exact cost of each action, so evaluate on generated cost-sensitive example currently stored in cb_cs_ld
- for (CSOAA::wclass *cl = cb_cs_ld.costs.begin; cl != cb_cs_ld.costs.end; cl ++) {
+ for (CSOAA::wclass *cl = c.cb_cs_ld.costs.begin; cl != c.cb_cs_ld.costs.end; cl ++) {
if (cl->weight_index == pred)
chosen_loss = cl->x;
if (cl->x < min)
@@ -692,26 +686,28 @@ namespace CB
all.sd->example_number++;
- print_update(all, CB::is_test_label((CB::label*)ec->ld), ec);
+ print_update(all, c, CB::is_test_label((CB::label*)ec->ld), ec);
}
- void finish(void* a)
+ void finish(void* a, void* d)
{
- vw* all = (vw*)a;
- cb_cs_ld.costs.delete_v();
- base_finish(all);
+ cb* c=(cb*)d;
+ c->base.finish(a,c->base.data);
+ c->cb_cs_ld.costs.delete_v();
+ free(c);
}
- void drive_cb(void* in)
+ void drive(void* in, void* d)
{
vw*all = (vw*)in;
+ cb* c = (cb*)d;
example* ec = NULL;
while ( true )
{
if ((ec = get_example(all->p)) != NULL)//semiblocking operation.
{
- learn(all, ec);
- output_example(*all, ec);
+ learn(all, d, ec);
+ output_example(*all, *c, ec);
VW::finish_example(*all, ec);
}
else if (parser_done(all->p))
@@ -724,6 +720,8 @@ namespace CB
void parse_flags(vw& all, std::vector<std::string>&opts, po::variables_map& vm, po::variables_map& vm_file)
{
+ cb* c = (cb*)calloc(1, sizeof(cb));
+ c->first_print_call = true;
po::options_description desc("CB options");
desc.add_options()
("cb_type", po::value<string>(), "contextual bandit method to use in {ips,dm,dr}");
@@ -772,47 +770,37 @@ namespace CB
}
if (type_string.compare("dr") == 0) {
- cb_type = CB_TYPE_DR;
+ c->cb_type = CB_TYPE_DR;
all.base_learner_nb_w *= nb_actions * 2;
}
else if (type_string.compare("dm") == 0) {
- cb_type = CB_TYPE_DM;
+ c->cb_type = CB_TYPE_DM;
all.base_learner_nb_w *= nb_actions * 2;
}
else if (type_string.compare("ips") == 0) {
- cb_type = CB_TYPE_IPS;
+ c->cb_type = CB_TYPE_IPS;
all.base_learner_nb_w *= nb_actions;
}
else {
std::cerr << "warning: cb_type must be in {'ips','dm','dr'}; resetting to dr." << std::endl;
- cb_type = CB_TYPE_DR;
+ c->cb_type = CB_TYPE_DR;
all.base_learner_nb_w *= nb_actions * 2;
}
}
else {
//by default use doubly robust
- cb_type = CB_TYPE_DR;
+ c->cb_type = CB_TYPE_DR;
all.base_learner_nb_w *= nb_actions * 2;
all.options_from_file.append(" --cb_type dr");
}
- increment = ((uint32_t)all.length()/all.base_learner_nb_w) * all.stride;
+ c->increment = ((uint32_t)all.length()/all.base_learner_nb_w) * all.stride;
*(all.p->lp) = CB::cb_label_parser;
all.sd->k = nb_actions;
- all.driver = drive_cb;
- //this parsing is done after the cost-sensitive parsing, so all.learn currently points to the base cs learner
- //and all.base_learn points to gd/bfgs base learner
- base_learner_cs = all.learn;
- base_learner = all.base_learn;
-
- all.learn = learn;
- base_finish = all.finish;
- all.finish = finish;
-
+ learner l = {c, drive, learn, finish, all.l.save_load};
+ all.l = l;
}
-
-
}
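
The three cb_type branches implemented above are the standard contextual-bandit cost estimators: ips importance-weights the observed cost by the probability of the chosen action, dm takes the learned regressor's prediction directly, and dr starts from the regressor and adds the importance-weighted residual (wc.x += (cl_obs->x - wc.x) / cl_obs->prob_action). Below is a minimal self-contained sketch of the three estimates for a single action on hypothetical values; illustrative only, not part of this patch.

#include <cstdio>

// For one action a, with regressor prediction r_hat, observed cost c_obs,
// and probability p that a was the chosen action:
//   ips: c_obs / p                     (and 0 for actions not taken)
//   dm : r_hat
//   dr : r_hat + (c_obs - r_hat) / p   (the correction applied in
//                                       gen_cs_example_dr above)
int main() {
  float r_hat = 0.4f;   // hypothetical regressor prediction for action a
  float c_obs = 1.0f;   // hypothetical observed cost of action a
  float p     = 0.25f;  // hypothetical probability of having chosen a
  printf("ips=%.2f dm=%.2f dr=%.2f\n",
         c_obs / p, r_hat, r_hat + (c_obs - r_hat) / p);  // 4.00 0.40 2.80
  return 0;
}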
diff --git a/vowpalwabbit/cb.h b/vowpalwabbit/cb.h
index c2e40098..e185eeba 100644
--- a/vowpalwabbit/cb.h
+++ b/vowpalwabbit/cb.h
@@ -34,7 +34,6 @@ namespace CB {
void parse_flags(vw& all, std::vector<std::string>&, po::variables_map& vm, po::variables_map& vm_file);
- void output_example(vw& all, example* ec);
size_t read_cached_label(shared_data* sd, void* v, io_buf& cache);
void cache_label(void* v, io_buf& cache);
void default_label(void* v);
diff --git a/vowpalwabbit/csoaa.cc b/vowpalwabbit/csoaa.cc
index e1af04ca..fb44b73d 100644
--- a/vowpalwabbit/csoaa.cc
+++ b/vowpalwabbit/csoaa.cc
@@ -17,6 +17,10 @@ license as described in the file LICENSE.
using namespace std;
namespace CSOAA {
+ struct csoaa{
+ uint32_t csoaa_increment;
+ learner base;
+ };
void name_value(substring &s, v_array<substring>& name, float &v)
{
@@ -274,11 +278,9 @@ namespace CSOAA {
print_update(all, is_test_label((label*)ec->ld), ec);
}
- void (*base_learner)(void*, example*) = NULL;
- void (*base_finish)(void*) = NULL;
-
- void learn(void* a, example* ec) {
+ void learn(void* a, void* d, example* ec) {
vw* all = (vw*)a;
+ csoaa* c = (csoaa*)d;
label* ld = (label*)ec->ld;
size_t prediction = 1;
float score = FLT_MAX;
@@ -303,14 +305,14 @@ namespace CSOAA {
ec->ld = &simple_temp;
- uint32_t desired_increment = all->csoaa_increment * (i-1);
+ uint32_t desired_increment = c->csoaa_increment * (i-1);
if (desired_increment != current_increment) {
update_example_indicies(all->audit, ec, desired_increment - current_increment);
current_increment = desired_increment;
}
- base_learner(all, ec);
+ c->base.learn(all, c->base.data, ec);
cl->partial_prediction = ec->partial_prediction;
if (ec->partial_prediction < score || (ec->partial_prediction == score && i < prediction)) {
score = ec->partial_prediction;
@@ -324,13 +326,14 @@ namespace CSOAA {
update_example_indicies(all->audit, ec, -current_increment);
}
- void finish(void* a)
+ void finish(void* a, void* d)
{
- vw* all = (vw*)a;
- base_finish(all);
+ csoaa* c=(csoaa*)d;
+ c->base.finish(a,c->base.data);
+ free(c);
}
- void drive_csoaa(void* in)
+ void drive(void* in, void* d)
{
vw* all = (vw*)in;
example* ec = NULL;
@@ -338,23 +341,21 @@ namespace CSOAA {
{
if ((ec = get_example(all->p)) != NULL)//semiblocking operation.
{
- learn(all, ec);
+ learn(all, d, ec);
output_example(*all, ec);
if (ec->in_use)
VW::finish_example(*all, ec);
}
else if (parser_done(all->p))
- {
- // finish(all);
- return;
- }
+ return;
else
;
}
- }
+ }
void parse_flags(vw& all, std::vector<std::string>&opts, po::variables_map& vm, po::variables_map& vm_file)
{
+ csoaa* c=(csoaa*)calloc(1,sizeof(csoaa));
//first parse for number of actions
uint32_t nb_actions = 0;
if( vm_file.count("csoaa") ) { //if loaded options from regressor
@@ -372,18 +373,14 @@ namespace CSOAA {
}
*(all.p->lp) = cs_label_parser;
- if (!all.is_noop)
- all.driver = drive_csoaa;
+ c->base=all.l;
+ c->csoaa_increment = ((uint32_t)all.length()/all.base_learner_nb_w) * all.stride;
all.sd->k = nb_actions;
all.base_learner_nb_w *= nb_actions;
- all.csoaa_increment = ((uint32_t)all.length()/all.base_learner_nb_w) * all.stride;
- base_learner = all.learn;
- all.base_learn = all.learn;
- all.learn = learn;
- base_finish = all.finish;
- all.finish = finish;
+ learner l = {c, drive, learn, finish, all.l.save_load};
+ all.l = l;
}
bool example_is_test(example* ec)
@@ -398,15 +395,153 @@ namespace CSOAA {
}
namespace CSOAA_AND_WAP_LDF {
- v_array<example*> ec_seq = v_array<example*>();
- size_t read_example_this_loop = 0;
- bool need_to_clear = true;
- bool is_singleline = true;
- bool is_wap = false;
- float csoaa_example_t = 0;
- void (*base_learner)(void*, example*) = NULL;
- void (*base_finish)(void*) = NULL;
+ struct ldf {
+ v_array<example*> ec_seq;
+ v_hashmap< size_t, v_array<feature> > label_features;
+
+ size_t read_example_this_loop;
+ bool need_to_clear;
+ bool is_singleline;
+ bool is_wap;
+ float csoaa_example_t;
+ learner base;
+ };
+
+namespace LabelDict {
+ bool size_t_eq(size_t a, size_t b) { return (a==b); }
+
+ size_t hash_lab(size_t lab) { return 328051 + 94389193 * lab; }
+
+ bool ec_is_label_definition(example*ec) // label defs look like "___:-1"
+ {
+ v_array<CSOAA::wclass> costs = ((CSOAA::label*)ec->ld)->costs;
+ for (size_t j=0; j<costs.size(); j++)
+ if (costs[j].x >= 0.) return false;
+ if (ec->indices.size() == 0) return false;
+ if (ec->indices.size() > 2) return false;
+ if (ec->indices[0] != 'l') return false;
+ return true;
+ }
+
+ bool ec_is_example_header(example*ec) // example headers look like "0:-1"
+ {
+ v_array<CSOAA::wclass> costs = ((CSOAA::label*)ec->ld)->costs;
+ if (costs.size() != 1) return false;
+ if (costs[0].weight_index != 0) return false;
+ if (costs[0].x >= 0) return false;
+ return true;
+ }
+
+ bool ec_seq_is_label_definition(ldf& l, v_array<example*>ec_seq)
+ {
+ if (ec_seq.size() == 0) return false;
+ bool is_lab = ec_is_label_definition(ec_seq[0]);
+ for (size_t i=1; i<ec_seq.size(); i++) {
+ if (is_lab != ec_is_label_definition(ec_seq[i])) {
+ if (!((i == ec_seq.size()-1) && (OAA::example_is_newline(ec_seq[i])))) {
+ cerr << "error: mixed label definition and examples in ldf data!" << endl;
+ exit(-1);
+ }
+ }
+ }
+ return is_lab;
+ }
+
+ void del_example_namespace(example*ec, char ns, v_array<feature> features) {
+ size_t numf = features.size();
+ ec->num_features -= numf;
+
+ assert (ec->atomics[(size_t)ns].size() >= numf);
+ if (ec->atomics[(size_t)ns].size() == numf) { // did NOT have ns
+ assert(ec->indices.size() > 0);
+ assert(ec->indices[ec->indices.size()-1] == (size_t)ns);
+ ec->indices.pop();
+ ec->total_sum_feat_sq -= ec->sum_feat_sq[(size_t)ns];
+ ec->atomics[(size_t)ns].erase();
+ ec->sum_feat_sq[(size_t)ns] = 0.;
+ } else { // DID have ns
+ for (feature*f=features.begin; f!=features.end; f++) {
+ ec->sum_feat_sq[(size_t)ns] -= f->x * f->x;
+ ec->atomics[(size_t)ns].pop();
+ }
+ }
+ }
+
+ void add_example_namespace(example*ec, char ns, v_array<feature> features) {
+ bool has_ns = false;
+ for (size_t i=0; i<ec->indices.size(); i++) {
+ if (ec->indices[i] == (size_t)ns) {
+ has_ns = true;
+ break;
+ }
+ }
+ if (has_ns) {
+ ec->total_sum_feat_sq -= ec->sum_feat_sq[(size_t)ns];
+ } else {
+ ec->indices.push_back((size_t)ns);
+ ec->sum_feat_sq[(size_t)ns] = 0;
+ }
+
+ for (feature*f=features.begin; f!=features.end; f++) {
+ ec->sum_feat_sq[(size_t)ns] += f->x * f->x;
+ ec->atomics[(size_t)ns].push_back(*f);
+ }
+
+ ec->num_features += features.size();
+ ec->total_sum_feat_sq += ec->sum_feat_sq[(size_t)ns];
+ }
+
+
+
+ void add_example_namespaces_from_example(example*target, example*source) {
+ for (unsigned char* idx=source->indices.begin; idx!=source->indices.end; idx++) {
+ if (*idx == constant_namespace) continue;
+ add_example_namespace(target, (char)*idx, source->atomics[*idx]);
+ }
+ }
+
+ void del_example_namespaces_from_example(example*target, example*source) {
+ //for (size_t*idx=source->indices.begin; idx!=source->indices.end; idx++) {
+ unsigned char* idx = source->indices.end;
+ idx--;
+ for (; idx>=source->indices.begin; idx--) {
+ if (*idx == constant_namespace) continue;
+ del_example_namespace(target, (char)*idx, source->atomics[*idx]);
+ }
+ }
+
+ void add_example_namespace_from_memory(ldf& l, example*ec, size_t lab) {
+ size_t lab_hash = hash_lab(lab);
+ v_array<feature> features = l.label_features.get(lab, lab_hash);
+ if (features.size() == 0) return;
+ add_example_namespace(ec, 'l', features);
+ }
+
+ void del_example_namespace_from_memory(ldf& l, example* ec, size_t lab) {
+ size_t lab_hash = hash_lab(lab);
+ v_array<feature> features = l.label_features.get(lab, lab_hash);
+ if (features.size() == 0) return;
+ del_example_namespace(ec, 'l', features);
+ }
+
+ void set_label_features(ldf& l, size_t lab, v_array<feature>features) {
+ size_t lab_hash = hash_lab(lab);
+ if (l.label_features.contains(lab, lab_hash)) { return; }
+ l.label_features.put_after_get(lab, lab_hash, features);
+ }
+
+ void free_label_features(ldf& l) {
+ void* label_iter = l.label_features.iterator();
+ while (label_iter != NULL) {
+ v_array<feature> features = l.label_features.iterator_get_value(label_iter);
+ features.erase();
+ features.delete_v();
+
+ label_iter = l.label_features.iterator_next(label_iter);
+ }
+ }
+}
inline bool cmp_wclass_ptr(const CSOAA::wclass* a, const CSOAA::wclass* b) { return a->x < b->x; }
@@ -479,7 +614,7 @@ namespace CSOAA_AND_WAP_LDF {
ec->indices.decr();
}
- void make_single_prediction(vw& all, example*ec, size_t*prediction, float*min_score) {
+ void make_single_prediction(vw& all, ldf& l, example*ec, size_t*prediction, float*min_score) {
label *ld = (label*)ec->ld;
v_array<CSOAA::wclass> costs = ld->costs;
label_data simple_label;
@@ -490,10 +625,10 @@ namespace CSOAA_AND_WAP_LDF {
simple_label.weight = 0.;
ec->partial_prediction = 0.;
- LabelDict::add_example_namespace_from_memory(ec, costs[j].weight_index);
+ LabelDict::add_example_namespace_from_memory(l, ec, costs[j].weight_index);
ec->ld = &simple_label;
- base_learner(&all, ec); // make a prediction
+ l.base.learn(&all, l.base.data, ec); // make a prediction
costs[j].partial_prediction = ec->partial_prediction;
if (ec->partial_prediction < *min_score) {
@@ -501,7 +636,7 @@ namespace CSOAA_AND_WAP_LDF {
*prediction = costs[j].weight_index;
}
- LabelDict::del_example_namespace_from_memory(ec, costs[j].weight_index);
+ LabelDict::del_example_namespace_from_memory(l, ec, costs[j].weight_index);
}
ec->ld = ld;
@@ -509,46 +644,46 @@ namespace CSOAA_AND_WAP_LDF {
- void do_actual_learning_wap(vw& all, size_t start_K)
+ void do_actual_learning_wap(vw& all, ldf& l, size_t start_K)
{
- size_t K = ec_seq.size();
- bool isTest = CSOAA::example_is_test(ec_seq[start_K]);
+ size_t K = l.ec_seq.size();
+ bool isTest = CSOAA::example_is_test(l.ec_seq[start_K]);
size_t prediction = 0;
float min_score = FLT_MAX;
v_hashmap<size_t,float> hit_labels(8, 0., NULL);
for (size_t k=start_K; k<K; k++) {
- example *ec = ec_seq.begin[k];
+ example *ec = l.ec_seq.begin[k];
if (CSOAA::example_is_test(ec) != isTest) {
isTest = true;
cerr << "warning: wap_ldf got mix of train/test data; assuming test" << endl;
}
- if (LabelDict::ec_is_example_header(ec_seq[k])) {
+ if (LabelDict::ec_is_example_header(l.ec_seq[k])) {
cerr << "warning: example headers at position " << k << ": can only have in initial position!" << endl;
exit(-1);
}
- make_single_prediction(all, ec, &prediction, &min_score);
+ make_single_prediction(all, l, ec, &prediction, &min_score);
}
// do actual learning
vector<CSOAA::wclass*> all_costs;
if (all.training && !isTest) {
for (size_t k=start_K; k<K; k++) {
- v_array<CSOAA::wclass> this_costs = ((label*)ec_seq.begin[k]->ld)->costs;
+ v_array<CSOAA::wclass> this_costs = ((label*)l.ec_seq.begin[k]->ld)->costs;
for (size_t j=0; j<this_costs.size(); j++)
all_costs.push_back(&this_costs[j]);
}
compute_wap_values(all_costs);
- csoaa_example_t += 1.;
+ l.csoaa_example_t += 1.;
}
label_data simple_label;
for (size_t k1=start_K; k1<K; k1++) {
- example *ec1 = ec_seq.begin[k1];
+ example *ec1 = l.ec_seq.begin[k1];
label *ld1 = (label*)ec1->ld;
v_array<CSOAA::wclass> costs1 = ld1->costs;
bool prediction_is_me = false;
@@ -558,10 +693,10 @@ namespace CSOAA_AND_WAP_LDF {
for (size_t j1=0; j1<costs1.size(); j1++) {
if (costs1[j1].weight_index == (uint32_t)-1) continue;
if (all.training && !isTest) {
- LabelDict::add_example_namespace_from_memory(ec1, costs1[j1].weight_index);
+ LabelDict::add_example_namespace_from_memory(l, ec1, costs1[j1].weight_index);
for (size_t k2=k1+1; k2<K; k2++) {
- example *ec2 = ec_seq.begin[k2];
+ example *ec2 = l.ec_seq.begin[k2];
label *ld2 = (label*)ec2->ld;
v_array<CSOAA::wclass> costs2 = ld2->costs;
@@ -572,22 +707,22 @@ namespace CSOAA_AND_WAP_LDF {
if (value_diff < 1e-6)
continue;
- LabelDict::add_example_namespace_from_memory(ec2, costs2[j2].weight_index);
+ LabelDict::add_example_namespace_from_memory(l, ec2, costs2[j2].weight_index);
// learn
- ec1->example_t = csoaa_example_t;
+ ec1->example_t = l.csoaa_example_t;
simple_label.initial = 0.;
simple_label.label = (costs1[j1].x < costs2[j2].x) ? -1.0f : 1.0f;
simple_label.weight = value_diff;
ec1->partial_prediction = 0.;
subtract_example(all, ec1, ec2);
- base_learner(&all, ec1);
+ l.base.learn(&all, l.base.data, ec1);
unsubtract_example(all, ec1);
- LabelDict::del_example_namespace_from_memory(ec2, costs2[j2].weight_index);
+ LabelDict::del_example_namespace_from_memory(l, ec2, costs2[j2].weight_index);
}
}
- LabelDict::del_example_namespace_from_memory(ec1, costs1[j1].weight_index);
+ LabelDict::del_example_namespace_from_memory(l, ec1, costs1[j1].weight_index);
}
if (prediction == costs1[j1].weight_index) prediction_is_me = true;
@@ -598,31 +733,31 @@ namespace CSOAA_AND_WAP_LDF {
}
}
- void do_actual_learning_oaa(vw& all, size_t start_K)
+ void do_actual_learning_oaa(vw& all, ldf& l, size_t start_K)
{
- size_t K = ec_seq.size();
+ size_t K = l.ec_seq.size();
size_t prediction = 0;
- bool isTest = CSOAA::example_is_test(ec_seq[start_K]);
+ bool isTest = CSOAA::example_is_test(l.ec_seq[start_K]);
float min_score = FLT_MAX;
for (size_t k=start_K; k<K; k++) {
- example *ec = ec_seq.begin[k];
+ example *ec = l.ec_seq.begin[k];
if (CSOAA::example_is_test(ec) != isTest) {
isTest = true;
cerr << "warning: ldf got mix of train/test data; assuming test" << endl;
}
- if (LabelDict::ec_is_example_header(ec_seq[k])) {
+ if (LabelDict::ec_is_example_header(l.ec_seq[k])) {
cerr << "warning: example headers at position " << k << ": can only have in initial position!" << endl;
exit(-1);
}
- make_single_prediction(all, ec, &prediction, &min_score);
+ make_single_prediction(all, l, ec, &prediction, &min_score);
}
// do actual learning
if (all.training && !isTest)
- csoaa_example_t += 1.;
+ l.csoaa_example_t += 1.;
for (size_t k=start_K; k<K; k++) {
- example *ec = ec_seq.begin[k];
+ example *ec = l.ec_seq.begin[k];
label *ld = (label*)ec->ld;
v_array<CSOAA::wclass> costs = ld->costs;
@@ -632,15 +767,15 @@ namespace CSOAA_AND_WAP_LDF {
for (size_t j=0; j<costs.size(); j++) {
if (all.training && !isTest) {
float example_t = ec->example_t;
- ec->example_t = csoaa_example_t;
+ ec->example_t = l.csoaa_example_t;
simple_label.initial = 0.;
simple_label.label = costs[j].x;
simple_label.weight = 1.;
ec->ld = &simple_label;
ec->partial_prediction = 0.;
- LabelDict::add_example_namespace_from_memory(ec, costs[j].weight_index);
- base_learner(&all, ec);
- LabelDict::del_example_namespace_from_memory(ec, costs[j].weight_index);
+ LabelDict::add_example_namespace_from_memory(l, ec, costs[j].weight_index);
+ l.base.learn(&all, l.base.data, ec);
+ LabelDict::del_example_namespace_from_memory(l, ec, costs[j].weight_index);
ec->example_t = example_t;
}
@@ -660,46 +795,46 @@ namespace CSOAA_AND_WAP_LDF {
}
- void do_actual_learning(vw& all)
+ void do_actual_learning(vw& all, ldf& l)
{
- if (ec_seq.size() <= 0) return; // nothing to do
+ if (l.ec_seq.size() <= 0) return; // nothing to do
/////////////////////// handle label definitions
- if (LabelDict::ec_seq_is_label_definition(ec_seq)) {
- for (size_t i=0; i<ec_seq.size(); i++) {
+ if (LabelDict::ec_seq_is_label_definition(l, l.ec_seq)) {
+ for (size_t i=0; i<l.ec_seq.size(); i++) {
v_array<feature> features;
- for (feature*f=ec_seq[i]->atomics[ec_seq[i]->indices[0]].begin; f!=ec_seq[i]->atomics[ec_seq[i]->indices[0]].end; f++) {
+ for (feature*f=l.ec_seq[i]->atomics[l.ec_seq[i]->indices[0]].begin; f!=l.ec_seq[i]->atomics[l.ec_seq[i]->indices[0]].end; f++) {
feature fnew = { f->x, f->weight_index };
features.push_back(fnew);
}
- v_array<CSOAA::wclass> costs = ((CSOAA::label*)ec_seq[i]->ld)->costs;
+ v_array<CSOAA::wclass> costs = ((CSOAA::label*)l.ec_seq[i]->ld)->costs;
for (size_t j=0; j<costs.size(); j++) {
size_t lab = costs[j].weight_index;
- LabelDict::set_label_features(lab, features);
+ LabelDict::set_label_features(l, lab, features);
}
}
return;
}
/////////////////////// check for headers
- size_t K = ec_seq.size();
+ size_t K = l.ec_seq.size();
size_t start_K = 0;
- if (LabelDict::ec_is_example_header(ec_seq[0])) {
+ if (LabelDict::ec_is_example_header(l.ec_seq[0])) {
start_K = 1;
for (size_t k=1; k<K; k++)
- LabelDict::add_example_namespaces_from_example(ec_seq[k], ec_seq[0]);
+ LabelDict::add_example_namespaces_from_example(l.ec_seq[k], l.ec_seq[0]);
}
/////////////////////// learn
- if (is_wap) do_actual_learning_wap(all, start_K);
- else do_actual_learning_oaa(all, start_K);
-
+ if (l.is_wap) do_actual_learning_wap(all, l, start_K);
+ else do_actual_learning_oaa(all, l, start_K);
+
/////////////////////// remove header
if (start_K > 0)
for (size_t k=1; k<K; k++)
- LabelDict::del_example_namespaces_from_example(ec_seq[k], ec_seq[0]);
+ LabelDict::del_example_namespaces_from_example(l.ec_seq[k], l.ec_seq[0]);
}
@@ -749,149 +884,155 @@ namespace CSOAA_AND_WAP_LDF {
CSOAA::print_update(all, CSOAA::example_is_test(ec), ec);
}
- void output_example_seq(vw& all)
+ void output_example_seq(vw& all, ldf& l)
{
- if ((ec_seq.size() > 0) && !LabelDict::ec_seq_is_label_definition(ec_seq)) {
+ if ((l.ec_seq.size() > 0) && !LabelDict::ec_seq_is_label_definition(l, l.ec_seq)) {
all.sd->weighted_examples += 1;
all.sd->example_number++;
bool hit_loss = false;
- for (example** ecc=ec_seq.begin; ecc!=ec_seq.end; ecc++)
+ for (example** ecc=l.ec_seq.begin; ecc!=l.ec_seq.end; ecc++)
output_example(all, *ecc, hit_loss);
if (all.raw_prediction > 0)
- all.print_text(all.raw_prediction, "", ec_seq[0]->tag);
+ all.print_text(all.raw_prediction, "", l.ec_seq[0]->tag);
}
}
- void clear_seq(vw& all)
+ void clear_seq(vw& all, ldf& l)
{
- if (ec_seq.size() > 0)
- for (example** ecc=ec_seq.begin; ecc!=ec_seq.end; ecc++)
+ if (l.ec_seq.size() > 0)
+ for (example** ecc=l.ec_seq.begin; ecc!=l.ec_seq.end; ecc++)
if ((*ecc)->in_use)
VW::finish_example(all, *ecc);
- ec_seq.erase();
+ l.ec_seq.erase();
}
- void learn_singleline(vw*all, example*ec) {
- if ((!all->training) || CSOAA::example_is_test(ec)) {
+ void learn_singleline(vw& all, ldf& l, example*ec) {
+ if ((!all.training) || CSOAA::example_is_test(ec)) {
size_t prediction = 0;
float min_score = FLT_MAX;
- make_single_prediction(*all, ec, &prediction, &min_score);
+ make_single_prediction(all, l, ec, &prediction, &min_score);
} else {
- ec_seq.erase();
- ec_seq.push_back(ec);
- do_actual_learning(*all);
- ec_seq.erase();
+ l.ec_seq.erase();
+ l.ec_seq.push_back(ec);
+ do_actual_learning(all,l);
+ l.ec_seq.erase();
}
}
- void learn_multiline(vw*all, example *ec) {
- if (ec_seq.size() >= all->p->ring_size - 2) { // give some wiggle room
- if (ec_seq[0]->pass == 0)
+ void learn_multiline(vw& all, ldf& l, example *ec) {
+ if (l.ec_seq.size() >= all.p->ring_size - 2) { // give some wiggle room
+ if (l.ec_seq[0]->pass == 0)
cerr << "warning: length of sequence at " << ec->example_counter << " exceeds ring size; breaking apart" << endl;
- do_actual_learning(*all);
- need_to_clear = true;
+ do_actual_learning(all, l);
+ l.need_to_clear = true;
}
- if (need_to_clear) {
- output_example_seq(*all);
- clear_seq(*all);
- need_to_clear = false;
+ if (l.need_to_clear) {
+ output_example_seq(all, l);
+ clear_seq(all, l);
+ l.need_to_clear = false;
}
if (OAA::example_is_newline(ec)) {
- do_actual_learning(*all);
- if (!LabelDict::ec_seq_is_label_definition(ec_seq))
- global_print_newline(*all);
+ do_actual_learning(all, l);
+ if (!LabelDict::ec_seq_is_label_definition(l, l.ec_seq))
+ global_print_newline(all);
if (ec->in_use)
- VW::finish_example(*all, ec);
- need_to_clear = true;
+ VW::finish_example(all, ec);
+ l.need_to_clear = true;
} else if (LabelDict::ec_is_label_definition(ec)) {
- if (ec_seq.size() > 0)
+ if (l.ec_seq.size() > 0)
cerr << "warning: label definition encountered in data block -- ignoring data!" << endl;
- learn_singleline(all, ec);
+ learn_singleline(all, l, ec);
if (ec->in_use)
- VW::finish_example(*all, ec);
+ VW::finish_example(all, ec);
} else {
- ec_seq.push_back(ec);
+ l.ec_seq.push_back(ec);
}
}
- void learn(void*a, example*ec) {
+ void learn(void*a, void* d, example*ec) {
vw* all = (vw*)a;
- if (is_singleline) learn_singleline(all, ec);
- else learn_multiline(all, ec);
+ ldf* l = (ldf*)d;
+ if (l->is_singleline) learn_singleline(*all,*l, ec);
+ else learn_multiline(*all,*l, ec);
}
- void finish(void* a)
+ void finish(void* a, void* d)
{
- vw* all = (vw*)a;
- clear_seq(*all);
- ec_seq.delete_v();
- base_finish(all);
- LabelDict::free_label_features();
+ ldf* l=(ldf*)d;
+ l->base.finish(a,l->base.data);
+ vw* all = (vw*)a;
+ clear_seq(*all, *l);
+ l->ec_seq.delete_v();
+ LabelDict::free_label_features(*l);
}
- void drive_ldf_singleline(vw*all) {
+ void drive_ldf_singleline(vw& all, ldf& l) {
example* ec = NULL;
while (true) {
- if ((ec = get_example(all->p)) != NULL) { //semiblocking operation.
+ if ((ec = get_example(all.p)) != NULL) { //semiblocking operation.
if (LabelDict::ec_is_example_header(ec)) {
cerr << "error: example headers not allowed in ldf singleline mode" << endl;
exit(-1);
}
- learn_singleline(all, ec);
+ learn_singleline(all, l, ec);
if (! LabelDict::ec_is_label_definition(ec)) {
- all->sd->weighted_examples += 1;
- all->sd->example_number++;
+ all.sd->weighted_examples += 1;
+ all.sd->example_number++;
}
bool hit_loss = false;
- output_example(*all, ec, hit_loss);
+ output_example(all, ec, hit_loss);
if (ec->in_use)
- VW::finish_example(*all, ec);
- } else if (parser_done(all->p)) {
+ VW::finish_example(all, ec);
+ } else if (parser_done(all.p)) {
return;
}
}
}
- void drive_ldf_multiline(vw*all) {
+ void drive_ldf_multiline(vw& all, ldf& l) {
example* ec = NULL;
- read_example_this_loop = 0;
- need_to_clear = false;
+ l.read_example_this_loop = 0;
+ l.need_to_clear = false;
while (true) {
- if ((ec = get_example(all->p)) != NULL) { // semiblocking operation
- learn_multiline(all, ec);
- if (need_to_clear) {
- output_example_seq(*all);
- clear_seq(*all);
- need_to_clear = false;
+ if ((ec = get_example(all.p)) != NULL) { // semiblocking operation
+ learn_multiline(all, l, ec);
+ if (l.need_to_clear) {
+ output_example_seq(all, l);
+ clear_seq(all, l);
+ l.need_to_clear = false;
}
- } else if (parser_done(all->p)) {
- do_actual_learning(*all);
- output_example_seq(*all);
- clear_seq(*all);
- ec_seq.delete_v();
+ } else if (parser_done(all.p)) {
+ do_actual_learning(all, l);
+ output_example_seq(all, l);
+ clear_seq(all, l);
+ l.ec_seq.delete_v();
return;
}
}
}
- void drive_ldf(void*in)
+ void drive(void*in, void* d)
{
vw* all =(vw*)in;
- if (is_singleline)
- drive_ldf_singleline(all);
+ ldf* l = (ldf*)d;
+ if (l->is_singleline)
+ drive_ldf_singleline(*all, *l);
else
- drive_ldf_multiline(all);
+ drive_ldf_multiline(*all,*l);
}
-
-
+
void parse_flags(vw& all, std::vector<std::string>&opts, po::variables_map& vm, po::variables_map& vm_file)
{
+ ldf* ld = (ldf*)calloc(1, sizeof(ldf));
+ ld->need_to_clear = true;
+ ld->is_singleline = true;
+
string ldf_arg;
if(vm_file.count("csoaa_ldf")) {
ldf_arg = vm_file["csoaa_ldf"].as<string>();
@@ -908,7 +1049,7 @@ namespace CSOAA_AND_WAP_LDF {
}
else if( vm_file.count("wap_ldf") ) {
ldf_arg = vm_file["wap_ldf"].as<string>();
- is_wap = true;
+ ld->is_wap = true;
if(vm.count("wap_ldf") && ldf_arg.compare(vm["wap_ldf"].as<string>()) != 0) {
ldf_arg = vm["csoaa_ldf"].as<string>();
@@ -917,7 +1058,7 @@ namespace CSOAA_AND_WAP_LDF {
}
else {
ldf_arg = vm["wap_ldf"].as<string>();
- is_wap = true;
+ ld->is_wap = true;
all.options_from_file.append(" --wap_ldf ");
all.options_from_file.append(ldf_arg);
}
@@ -927,9 +1068,9 @@ namespace CSOAA_AND_WAP_LDF {
all.sd->k = (uint32_t)-1;
if (ldf_arg.compare("singleline") == 0 || ldf_arg.compare("s") == 0)
- is_singleline = true;
+ ld->is_singleline = true;
else if (ldf_arg.compare("multiline") == 0 || ldf_arg.compare("m") == 0)
- is_singleline = false;
+ ld->is_singleline = false;
else {
cerr << "ldf requires either [s]ingleline or [m]ultiline argument" << endl;
exit(-1);
@@ -938,14 +1079,10 @@ namespace CSOAA_AND_WAP_LDF {
if (all.add_constant) {
all.add_constant = false;
}
-
- if (!all.is_noop)
- all.driver = drive_ldf;
- base_learner = all.learn;
- all.base_learn = all.learn;
- all.learn = learn;
- base_finish = all.finish;
- all.finish = finish;
+ ld->label_features = v_hashmap<size_t, v_array<feature> >::v_hashmap(256, v_array<feature>(), LabelDict::size_t_eq);
+
+ learner l = {ld, drive, learn, finish, all.l.save_load};
+ all.l = l;
}
void global_print_newline(vw& all)
@@ -959,142 +1096,6 @@ namespace CSOAA_AND_WAP_LDF {
std::cerr << "write error" << std::endl;
}
}
-
}
-namespace LabelDict {
- bool size_t_eq(size_t a, size_t b) { return (a==b); }
- v_hashmap< size_t, v_array<feature> > label_features(256, v_array<feature>(), size_t_eq);
-
- size_t hash_lab(size_t lab) { return 328051 + 94389193 * lab; }
-
- bool ec_is_label_definition(example*ec) // label defs look like "___:-1"
- {
- v_array<CSOAA::wclass> costs = ((CSOAA::label*)ec->ld)->costs;
- for (size_t j=0; j<costs.size(); j++)
- if (costs[j].x >= 0.) return false;
- if (ec->indices.size() == 0) return false;
- if (ec->indices.size() > 2) return false;
- if (ec->indices[0] != 'l') return false;
- return true;
- }
-
- bool ec_is_example_header(example*ec) // example headers look like "0:-1"
- {
- v_array<CSOAA::wclass> costs = ((CSOAA::label*)ec->ld)->costs;
- if (costs.size() != 1) return false;
- if (costs[0].weight_index != 0) return false;
- if (costs[0].x >= 0) return false;
- return true;
- }
-
- bool ec_seq_is_label_definition(v_array<example*>ec_seq)
- {
- if (ec_seq.size() == 0) return false;
- bool is_lab = ec_is_label_definition(ec_seq[0]);
- for (size_t i=1; i<ec_seq.size(); i++) {
- if (is_lab != ec_is_label_definition(ec_seq[i])) {
- if (!((i == ec_seq.size()-1) && (OAA::example_is_newline(ec_seq[i])))) {
- cerr << "error: mixed label definition and examples in ldf data!" << endl;
- exit(-1);
- }
- }
- }
- return is_lab;
- }
-
- void del_example_namespace(example*ec, char ns, v_array<feature> features) {
- size_t numf = features.size();
- ec->num_features -= numf;
-
- assert (ec->atomics[(size_t)ns].size() >= numf);
- if (ec->atomics[(size_t)ns].size() == numf) { // did NOT have ns
- assert(ec->indices.size() > 0);
- assert(ec->indices[ec->indices.size()-1] == (size_t)ns);
- ec->indices.pop();
- ec->total_sum_feat_sq -= ec->sum_feat_sq[(size_t)ns];
- ec->atomics[(size_t)ns].erase();
- ec->sum_feat_sq[(size_t)ns] = 0.;
- } else { // DID have ns
- for (feature*f=features.begin; f!=features.end; f++) {
- ec->sum_feat_sq[(size_t)ns] -= f->x * f->x;
- ec->atomics[(size_t)ns].pop();
- }
- }
- }
-
- void add_example_namespace(example*ec, char ns, v_array<feature> features) {
- bool has_ns = false;
- for (size_t i=0; i<ec->indices.size(); i++) {
- if (ec->indices[i] == (size_t)ns) {
- has_ns = true;
- break;
- }
- }
- if (has_ns) {
- ec->total_sum_feat_sq -= ec->sum_feat_sq[(size_t)ns];
- } else {
- ec->indices.push_back((size_t)ns);
- ec->sum_feat_sq[(size_t)ns] = 0;
- }
-
- for (feature*f=features.begin; f!=features.end; f++) {
- ec->sum_feat_sq[(size_t)ns] += f->x * f->x;
- ec->atomics[(size_t)ns].push_back(*f);
- }
-
- ec->num_features += features.size();
- ec->total_sum_feat_sq += ec->sum_feat_sq[(size_t)ns];
- }
-
-
-
- void add_example_namespaces_from_example(example*target, example*source) {
- for (unsigned char* idx=source->indices.begin; idx!=source->indices.end; idx++) {
- if (*idx == constant_namespace) continue;
- add_example_namespace(target, (char)*idx, source->atomics[*idx]);
- }
- }
-
- void del_example_namespaces_from_example(example*target, example*source) {
- //for (size_t*idx=source->indices.begin; idx!=source->indices.end; idx++) {
- unsigned char* idx = source->indices.end;
- idx--;
- for (; idx>=source->indices.begin; idx--) {
- if (*idx == constant_namespace) continue;
- del_example_namespace(target, (char)*idx, source->atomics[*idx]);
- }
- }
-
- void add_example_namespace_from_memory(example*ec, size_t lab) {
- size_t lab_hash = hash_lab(lab);
- v_array<feature> features = label_features.get(lab, lab_hash);
- if (features.size() == 0) return;
- add_example_namespace(ec, 'l', features);
- }
-
- void del_example_namespace_from_memory(example* ec, size_t lab) {
- size_t lab_hash = hash_lab(lab);
- v_array<feature> features = label_features.get(lab, lab_hash);
- if (features.size() == 0) return;
- del_example_namespace(ec, 'l', features);
- }
-
- void set_label_features(size_t lab, v_array<feature>features) {
- size_t lab_hash = hash_lab(lab);
- if (label_features.contains(lab, lab_hash)) { return; }
- label_features.put_after_get(lab, lab_hash, features);
- }
-
- void free_label_features() {
- void* label_iter = LabelDict::label_features.iterator();
- while (label_iter != NULL) {
- v_array<feature> features = LabelDict::label_features.iterator_get_value(label_iter);
- features.erase();
- features.delete_v();
-
- label_iter = LabelDict::label_features.iterator_next(label_iter);
- }
- }
-}
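The csoaa_increment computed in parse_flags is what partitions the single weight vector among the k per-class predictors. A worked example with hypothetical sizes, to make the offset arithmetic in learn() concrete:

  // Assume all.length() = 1 << 18 = 262144 weight slots and all.stride = 1.
  // With nb_actions = 3, base_learner_nb_w becomes 3 and
  //   csoaa_increment = (262144 / 3) * 1 = 87381   (integer division)
  // learn() trains class i at offset (i-1) * 87381:
  //   class 1 -> slots [0, 87381), class 2 -> [87381, 174762),
  //   class 3 -> [174762, 262143)   (one slot spare from the rounding)
  // update_example_indicies shifts every feature index by the delta to the
  // desired offset, and by -current_increment after the loop, so the
  // example leaves learn() with its original indices.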
diff --git a/vowpalwabbit/csoaa.h b/vowpalwabbit/csoaa.h
index fa1fcb34..d8722ea5 100644
--- a/vowpalwabbit/csoaa.h
+++ b/vowpalwabbit/csoaa.h
@@ -58,16 +58,4 @@ namespace CSOAA_AND_WAP_LDF {
const label_parser cs_label_parser = CSOAA::cs_label_parser;
}
-namespace LabelDict {
- bool ec_is_label_definition(example*ec);
- bool ec_is_example_header(example*ec);
- bool ec_seq_is_label_definition(v_array<example*>ec_seq);
- void add_example_namespaces_from_example(example*target, example*source);
- void del_example_namespaces_from_example(example*target, example*source);
- void add_example_namespace_from_memory(example*ec, size_t lab);
- void del_example_namespace_from_memory(example* ec, size_t lab);
- void set_label_features(size_t lab, v_array<feature>features);
- void free_label_features();
-}
-
#endif
diff --git a/vowpalwabbit/ect.cc b/vowpalwabbit/ect.cc
index d5a81563..4486db0d 100644
--- a/vowpalwabbit/ect.cc
+++ b/vowpalwabbit/ect.cc
@@ -23,11 +23,6 @@ using namespace std;
namespace ECT
{
-
- //nonreentrant
- uint32_t k = 1;
- uint32_t errors = 0;
-
struct direction {
size_t id; //unique id for node
size_t tournament; //the tournament this node belongs to
@@ -37,24 +32,30 @@ namespace ECT
uint32_t right; //down traversal, right
bool last;
};
-
- v_array<direction> directions;//The nodes of the tournament datastructure
-
- v_array<v_array<v_array<uint32_t > > > all_levels;
-
- v_array<uint32_t> final_nodes; //The final nodes of each tournament.
-
- v_array<size_t> up_directions; //On edge e, which node n is in the up direction?
- v_array<size_t> down_directions;//On edge e, which node n is in the down direction?
-
- size_t tree_height = 0; //The height of the final tournament.
- uint32_t last_pair = 0;
-
- uint32_t increment = 0;
+ struct ect{
+ uint32_t k;
+ uint32_t errors;
+ v_array<direction> directions;//The nodes of the tournament datastructure
+
+ v_array<v_array<v_array<uint32_t > > > all_levels;
+
+ v_array<uint32_t> final_nodes; //The final nodes of each tournament.
+
+ v_array<size_t> up_directions; //On edge e, which node n is in the up direction?
+ v_array<size_t> down_directions;//On edge e, which node n is in the down direction?
+
+ size_t tree_height; //The height of the final tournament.
+
+ uint32_t last_pair;
+
+ uint32_t increment;
+
+ v_array<bool> tournaments_won;
+
+ learner base;
+ };
- v_array<bool> tournaments_won;
-
bool exists(v_array<size_t> db)
{
for (size_t i = 0; i< db.size();i++)
@@ -94,19 +95,7 @@ namespace ECT
cout << endl;
}
- void print_state()
- {
- cout << "all_levels = " << endl;
- for (size_t l = 0; l < all_levels.size(); l++)
- print_level(all_levels[l]);
-
- cout << "directions = " << endl;
- for (size_t i = 0; i < directions.size(); i++)
- cout << " | " << directions[i].id << " t" << directions[i].tournament << " " << directions[i].winner << " " << directions[i].loser << " " << directions[i].left << " " << directions[i].right << " " << directions[i].last;
- cout << endl;
- }
-
- void create_circuit(vw& all, uint32_t max_label, uint32_t eliminations)
+ void create_circuit(vw& all, ect& e, uint32_t max_label, uint32_t eliminations)
{
if (max_label == 1)
return;
@@ -119,7 +108,7 @@ namespace ECT
{
t.push_back(i);
direction d = {i,0,0,0,0,0, false};
- directions.push_back(d);
+ e.directions.push_back(d);
}
tournaments.push_back(t);
@@ -127,16 +116,16 @@ namespace ECT
for (size_t i = 0; i < eliminations-1; i++)
tournaments.push_back(v_array<uint32_t>());
- all_levels.push_back(tournaments);
+ e.all_levels.push_back(tournaments);
size_t level = 0;
- uint32_t node = (uint32_t)directions.size();
+ uint32_t node = (uint32_t)e.directions.size();
- while (not_empty(all_levels[level]))
+ while (not_empty(e.all_levels[level]))
{
v_array<v_array<uint32_t > > new_tournaments;
- tournaments = all_levels[level];
+ tournaments = e.all_levels[level];
for (size_t t = 0; t < tournaments.size(); t++)
{
@@ -153,59 +142,56 @@ namespace ECT
uint32_t right = tournaments[t][2*j+1];
direction d = {id,t,0,0,left,right, false};
- directions.push_back(d);
- uint32_t direction_index = (uint32_t)directions.size()-1;
- if (directions[left].tournament == t)
- directions[left].winner = direction_index;
+ e.directions.push_back(d);
+ uint32_t direction_index = (uint32_t)e.directions.size()-1;
+ if (e.directions[left].tournament == t)
+ e.directions[left].winner = direction_index;
else
- directions[left].loser = direction_index;
- if (directions[right].tournament == t)
- directions[right].winner = direction_index;
+ e.directions[left].loser = direction_index;
+ if (e.directions[right].tournament == t)
+ e.directions[right].winner = direction_index;
else
- directions[right].loser = direction_index;
- if (directions[left].last == true)
- directions[left].winner = direction_index;
+ e.directions[right].loser = direction_index;
+ if (e.directions[left].last == true)
+ e.directions[left].winner = direction_index;
if (tournaments[t].size() == 2 && (t == 0 || tournaments[t-1].size() == 0))
{
- directions[direction_index].last = true;
+ e.directions[direction_index].last = true;
if (t+1 < tournaments.size())
new_tournaments[t+1].push_back(id);
else // winner eliminated.
- directions[direction_index].winner = 0;
- final_nodes.push_back((uint32_t)(directions.size()- 1));
+ e.directions[direction_index].winner = 0;
+ e.final_nodes.push_back((uint32_t)(e.directions.size()- 1));
}
else
new_tournaments[t].push_back(id);
if (t+1 < tournaments.size())
new_tournaments[t+1].push_back(id);
else // loser eliminated.
- directions[direction_index].loser = 0;
+ e.directions[direction_index].loser = 0;
}
if (tournaments[t].size() % 2 == 1)
new_tournaments[t].push_back(tournaments[t].last());
}
- all_levels.push_back(new_tournaments);
+ e.all_levels.push_back(new_tournaments);
level++;
}
- last_pair = (max_label - 1)*(eliminations);
+ e.last_pair = (max_label - 1)*(eliminations);
if ( max_label > 1)
- tree_height = final_depth(eliminations);
+ e.tree_height = final_depth(eliminations);
- if (last_pair > 0) {
- all.base_learner_nb_w *= (last_pair + (eliminations-1));
- increment = (uint32_t) all.length() / all.base_learner_nb_w * all.stride;
+ if (e.last_pair > 0) {
+ all.base_learner_nb_w *= (e.last_pair + (eliminations-1));
+ e.increment = (uint32_t) all.length() / all.base_learner_nb_w * all.stride;
}
}
- void (*base_learner)(void*, example*) = NULL;
- void (*base_finish)(void*) = NULL;
-
- size_t ect_predict(vw& all, example* ec)
+ size_t ect_predict(vw& all, ect& e, example* ec)
{
- if (k == (size_t)1)
+ if (e.k == (size_t)1)
return 1;
uint32_t finals_winner = 0;
@@ -214,19 +200,19 @@ namespace ECT
label_data simple_temp = {FLT_MAX, 0., 0.};
ec->ld = & simple_temp;
- for (size_t i = tree_height-1; i != (size_t)0 -1; i--)
+ for (size_t i = e.tree_height-1; i != (size_t)0 -1; i--)
{
- if ((finals_winner | (((size_t)1) << i)) <= errors)
+ if ((finals_winner | (((size_t)1) << i)) <= e.errors)
{// a real choice exists
uint32_t offset = 0;
- uint32_t problem_number = last_pair + (finals_winner | (((uint32_t)1) << i)) - 1; //This is unique.
- offset = problem_number*increment;
+ uint32_t problem_number = e.last_pair + (finals_winner | (((uint32_t)1) << i)) - 1; //This is unique.
+ offset = problem_number*e.increment;
update_example_indicies(all.audit, ec,offset);
ec->partial_prediction = 0;
- base_learner(&all, ec);
+ e.base.learn(&all,e.base.data, ec);
update_example_indicies(all.audit, ec,-offset);
@@ -236,21 +222,21 @@ namespace ECT
}
}
- uint32_t id = final_nodes[finals_winner];
- while (id >= k)
+ uint32_t id = e.final_nodes[finals_winner];
+ while (id >= e.k)
{
- uint32_t offset = (id-k)*increment;
+ uint32_t offset = (id-e.k)*e.increment;
ec->partial_prediction = 0;
update_example_indicies(all.audit, ec,offset);
- base_learner(&all, ec);
+ e.base.learn(&all,e.base.data, ec);
float pred = ec->final_prediction;
update_example_indicies(all.audit, ec,-offset);
if (pred > 0.)
- id = directions[id].right;
+ id = e.directions[id].right;
else
- id = directions[id].left;
+ id = e.directions[id].left;
}
return id+1;
}
@@ -263,18 +249,18 @@ namespace ECT
return false;
}
- void ect_train(vw& all, example* ec)
+ void ect_train(vw& all, ect& e, example* ec)
{
- if (k == 1)//nothing to do
+ if (e.k == 1)//nothing to do
return;
OAA::mc_label * mc = (OAA::mc_label*)ec->ld;
label_data simple_temp = {1.,mc->weight,0.};
- tournaments_won.erase();
+ e.tournaments_won.erase();
- uint32_t id = directions[mc->label-1].winner;
- bool left = directions[id].left == mc->label - 1;
+ uint32_t id = e.directions[mc->label-1].winner;
+ bool left = e.directions[id].left == mc->label - 1;
do
{
if (left)
@@ -285,15 +271,15 @@ namespace ECT
simple_temp.weight = mc->weight;
ec->ld = &simple_temp;
- uint32_t offset = (id-k)*increment;
+ uint32_t offset = (id-e.k)*e.increment;
update_example_indicies(all.audit, ec,offset);
ec->partial_prediction = 0;
- base_learner(&all, ec);
+ e.base.learn(&all,e.base.data, ec);
simple_temp.weight = 0.;
ec->partial_prediction = 0;
- base_learner(&all, ec);//inefficient, we should extract final prediction exactly.
+ e.base.learn(&all,e.base.data, ec);//inefficient, we should extract final prediction exactly.
float pred = ec->final_prediction;
update_example_indicies(all.audit, ec,-offset);
@@ -301,39 +287,39 @@ namespace ECT
if (won)
{
- if (!directions[id].last)
- left = directions[directions[id].winner].left == id;
+ if (!e.directions[id].last)
+ left = e.directions[e.directions[id].winner].left == id;
else
- tournaments_won.push_back(true);
- id = directions[id].winner;
+ e.tournaments_won.push_back(true);
+ id = e.directions[id].winner;
}
else
{
- if (!directions[id].last)
+ if (!e.directions[id].last)
{
- left = directions[directions[id].loser].left == id;
- if (directions[id].loser == 0)
- tournaments_won.push_back(false);
+ left = e.directions[e.directions[id].loser].left == id;
+ if (e.directions[id].loser == 0)
+ e.tournaments_won.push_back(false);
}
else
- tournaments_won.push_back(false);
- id = directions[id].loser;
+ e.tournaments_won.push_back(false);
+ id = e.directions[id].loser;
}
}
while(id != 0);
- if (tournaments_won.size() < 1)
+ if (e.tournaments_won.size() < 1)
cout << "badness!" << endl;
//tournaments_won is a bit vector determining which tournaments the label won.
- for (size_t i = 0; i < tree_height; i++)
+ for (size_t i = 0; i < e.tree_height; i++)
{
- for (uint32_t j = 0; j < tournaments_won.size()/2; j++)
+ for (uint32_t j = 0; j < e.tournaments_won.size()/2; j++)
{
- bool left = tournaments_won[j*2];
- bool right = tournaments_won[j*2+1];
+ bool left = e.tournaments_won[j*2];
+ bool right = e.tournaments_won[j*2+1];
if (left == right)//no query to do
- tournaments_won[j] = left;
+ e.tournaments_won[j] = left;
else //query to do
{
float label;
@@ -342,72 +328,73 @@ namespace ECT
else
label = 1;
simple_temp.label = label;
- simple_temp.weight = (float)(1 << (tree_height -i -1));
+ simple_temp.weight = (float)(1 << (e.tree_height -i -1));
ec->ld = & simple_temp;
- uint32_t problem_number = last_pair + j*(1 << (i+1)) + (1 << i) -1;
+ uint32_t problem_number = e.last_pair + j*(1 << (i+1)) + (1 << i) -1;
- uint32_t offset = problem_number*increment;
+ uint32_t offset = problem_number*e.increment;
update_example_indicies(all.audit, ec,offset);
ec->partial_prediction = 0;
- base_learner(&all, ec);
+ e.base.learn(&all,e.base.data, ec);
update_example_indicies(all.audit, ec,-offset);
float pred = ec->final_prediction;
if (pred > 0.)
- tournaments_won[j] = right;
+ e.tournaments_won[j] = right;
else
- tournaments_won[j] = left;
+ e.tournaments_won[j] = left;
}
- if (tournaments_won.size() %2 == 1)
- tournaments_won[tournaments_won.size()/2] = tournaments_won[tournaments_won.size()-1];
- tournaments_won.end = tournaments_won.begin+(1+tournaments_won.size())/2;
+ if (e.tournaments_won.size() %2 == 1)
+ e.tournaments_won[e.tournaments_won.size()/2] = e.tournaments_won[e.tournaments_won.size()-1];
+ e.tournaments_won.end = e.tournaments_won.begin+(1+e.tournaments_won.size())/2;
}
}
}
- void learn(void*a, example* ec)
+ void learn(void*a, void* d, example* ec)
{
vw* all = (vw*)a;
+ ect* e=(ect*)d;
OAA::mc_label* mc = (OAA::mc_label*)ec->ld;
- if (mc->label > k)
+ if (mc->label > e->k)
cout << "label > maximum label! This won't work right." << endl;
- size_t new_label = ect_predict(*all, ec);
+ size_t new_label = ect_predict(*all, *e, ec);
ec->ld = mc;
if (mc->label != (uint32_t)-1 && all->training)
- ect_train(*all, ec);
+ ect_train(*all, *e, ec);
ec->ld = mc;
*(OAA::prediction_t*)&(ec->final_prediction) = new_label;
}
- void finish(void* all)
+ void finish(void* all, void* d)
{
- for (size_t l = 0; l < all_levels.size(); l++)
+ ect* e = (ect*)d;
+ e->base.finish(all, e->base.data);
+ for (size_t l = 0; l < e->all_levels.size(); l++)
{
- for (size_t t = 0; t < all_levels[l].size(); t++)
- all_levels[l][t].delete_v();
- all_levels[l].delete_v();
+ for (size_t t = 0; t < e->all_levels[l].size(); t++)
+ e->all_levels[l][t].delete_v();
+ e->all_levels[l].delete_v();
}
- final_nodes.delete_v();
+ e->final_nodes.delete_v();
- up_directions.delete_v();
+ e->up_directions.delete_v();
- directions.delete_v();
+ e->directions.delete_v();
- down_directions.delete_v();
+ e->down_directions.delete_v();
- tournaments_won.delete_v();
-
- base_finish(all);
+ e->tournaments_won.delete_v();
}
- void drive_ect(void* in)
+ void drive(void* in, void* d)
{
vw* all = (vw*)in;
example* ec = NULL;
@@ -415,7 +402,7 @@ namespace ECT
{
if ((ec = get_example(all->p)) != NULL)//semiblocking operation.
{
- learn(all, ec);
+ learn(all, d, ec);
OAA::output_example(*all, ec);
VW::finish_example(*all, ec);
}
@@ -430,6 +417,7 @@ namespace ECT
void parse_flags(vw& all, std::vector<std::string>&opts, po::variables_map& vm, po::variables_map& vm_file)
{
+ ect* data = (ect*)calloc(1, sizeof(ect));
po::options_description desc("ECT options");
desc.add_options()
("error", po::value<size_t>(), "error in ECT");
@@ -448,48 +436,43 @@ namespace ECT
po::notify(vm_file);
//first parse for number of actions
- k = 0;
+ data->k = 0;
if( vm_file.count("ect") ) {
- k = (int)vm_file["ect"].as<size_t>();
- if( vm.count("ect") && vm["ect"].as<size_t>() != k )
- std::cerr << "warning: you specified a different number of actions through --ect than the one loaded from predictor. Pursuing with loaded value of: " << k << endl;
+ data->k = (int)vm_file["ect"].as<size_t>();
+ if( vm.count("ect") && vm["ect"].as<size_t>() != data->k )
+ std::cerr << "warning: you specified a different number of actions through --ect than the one loaded from predictor. Pursuing with loaded value of: " << data->k << endl;
}
else {
- k = (int)vm["ect"].as<size_t>();
+ data->k = (int)vm["ect"].as<size_t>();
//append ect with nb_actions to options_from_file so it is saved to regressor later
std::stringstream ss;
- ss << " --ect " << k;
+ ss << " --ect " << data->k;
all.options_from_file.append(ss.str());
}
if(vm_file.count("error")) {
- errors = (uint32_t)vm_file["error"].as<size_t>();
- if (vm.count("error") && (uint32_t)vm["error"].as<size_t>() != errors) {
- cerr << "warning: specified value for --error different than the one loaded from predictor file. Pursuing with loaded value of: " << errors << endl;
+ data->errors = (uint32_t)vm_file["error"].as<size_t>();
+ if (vm.count("error") && (uint32_t)vm["error"].as<size_t>() != data->errors) {
+ cerr << "warning: specified value for --error different than the one loaded from predictor file. Pursuing with loaded value of: " << data->errors << endl;
}
}
else if (vm.count("error")) {
- errors = (uint32_t)vm["error"].as<size_t>();
+ data->errors = (uint32_t)vm["error"].as<size_t>();
//append error flag to options_from_file so it is saved in regressor file later
stringstream ss;
- ss << " --error " << errors;
+ ss << " --error " << data->errors;
all.options_from_file.append(ss.str());
} else {
- errors = 0;
+ data->errors = 0;
}
*(all.p->lp) = OAA::mc_label_parser;
- all.driver = drive_ect;
- base_learner = all.learn;
- all.base_learn = all.learn;
- all.learn = learn;
-
- base_finish = all.finish;
- all.finish = finish;
-
- create_circuit(all, k, errors+1);
+ data->base = all.l;
+ create_circuit(all, *data, data->k, data->errors+1);
+
+ learner l = {data, drive, learn, finish, all.l.save_load};
+ all.l = l;
}
-
}
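A hand trace of ect_predict under assumed parameters (--ect 4 --error 0, so eliminations = 1) may help; the node ids below assume the construction order in create_circuit, leaves first:

  // k = 4, errors = 0: last_pair = (4-1)*1 = 3 pairwise problems.
  // Leaves 0..3 are created first; create_circuit then adds internal nodes
  //   4 = (0 vs 1), 5 = (2 vs 3), 6 = (4 vs 5),
  // and node 6, the single "last" node, is pushed onto final_nodes.
  // With errors = 0 the test (finals_winner | (1 << i)) <= errors is never
  // true, so the finals loop in ect_predict is skipped and finals_winner
  // stays 0. The descent then starts at id = final_nodes[0] = 6:
  //   id = 6: query problem (6 - k) = 2 at offset 2*increment;
  //           pred > 0 picks directions[6].right (5), else .left (4)
  //   id = 4 or 5: query problem 0 or 1; the sign again picks a child,
  //           now a leaf in 0..3
  //   id < k: the loop exits and the predicted class is id + 1.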
diff --git a/vowpalwabbit/gd.cc b/vowpalwabbit/gd.cc
index 2bf4542e..a493eaa0 100644
--- a/vowpalwabbit/gd.cc
+++ b/vowpalwabbit/gd.cc
@@ -27,15 +27,13 @@ license as described in the file LICENSE.
#include "simple_label.h"
#include "allreduce.h"
#include "accumulate.h"
+#include "learner.h"
using namespace std;
namespace GD
{
-//nonreentrant
-size_t gd_current_pass = 0;
-
void predict(vw& all, example* ex);
void sync_weights(vw& all);
@@ -120,11 +118,11 @@ inline void specialized_update(vw& all, float x, uint32_t fi, float avg_norm, fl
w[0] += update * x * t;
}
-void learn_gd(void* a, example* ec)
+void learn(void* a, void* d, example* ec)
{
vw* all = (vw*)a;
assert(ec->in_use);
- if (ec->pass != gd_current_pass)
+ if (ec->pass != all->current_pass)
{
if(all->span_server != "") {
@@ -137,11 +135,11 @@ void learn_gd(void* a, example* ec)
if (all->save_per_pass)
{
sync_weights(*all);
- save_predictor(*all, all->final_regressor_name, gd_current_pass);
+ save_predictor(*all, all->final_regressor_name, all->current_pass);
}
all->eta *= all->eta_decay_rate;
- gd_current_pass = ec->pass;
+ all->current_pass = ec->pass;
}
if (!command_example(*all, ec))
@@ -163,7 +161,7 @@ void learn_gd(void* a, example* ec)
}
}
-void finish_gd(void* a)
+ void finish(void* a, void* d)
{
vw* all = (vw*)a;
sync_weights(*all);
@@ -173,6 +171,8 @@ void finish_gd(void* a)
else
accumulate_avg(*all, all->span_server, all->reg, 0);
}
+ size_t* current_pass = (size_t*)d;
+ free(current_pass);
}
void sync_weights(vw& all) {
@@ -748,7 +748,7 @@ void save_load_online_state(vw& all, io_buf& model_file, bool read, bool text)
while ((!read && i < length) || (read && brw >0));
}
-void save_load(void* in, io_buf& model_file, bool read, bool text)
+void save_load(void* in, void* data, io_buf& model_file, bool read, bool text)
{
vw* all=(vw*)in;
if(read)
@@ -783,7 +783,7 @@ void save_load(void* in, io_buf& model_file, bool read, bool text)
}
}
-void drive_gd(void* in)
+void driver(void* in, void* data)
{
vw* all = (vw*)in;
example* ec = NULL;
@@ -792,16 +792,20 @@ void drive_gd(void* in)
{
if ((ec = get_example(all->p)) != NULL)//semiblocking operation.
{
- learn_gd(all, ec);
+ learn(all, data, ec);
return_simple_example(*all, ec);
}
else if (parser_done(all->p))
- {
- finish_gd(all);
- return;
- }
+ return;
else
;//busywait when we have predicted on all examples but not yet trained on all.
}
}
+
+ learner get_learner()
+ {
+ size_t* current_pass = (size_t*)calloc(1, sizeof(size_t));
+ learner ret = {current_pass,driver,learn,finish,save_load};
+ return ret;
+ }
}
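get_learner() is the factory shape that replaces the exported drive/learn/finish/save_load symbols. Note that learn() keys pass transitions off all->current_pass, so the heap counter handed to the learner as data is only touched again by finish(), which frees it. A minimal usage sketch (hypothetical call site; the real consumer is the vw constructor in global_data.cc below):

  learner gd = GD::get_learner();  // calloc's the learner's data
  gd.driver(&all, gd.data);        // run the example loop to completion
  gd.finish(&all, gd.data);        // finish() frees the data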
diff --git a/vowpalwabbit/gd.h b/vowpalwabbit/gd.h
index b62a59b3..13612774 100644
--- a/vowpalwabbit/gd.h
+++ b/vowpalwabbit/gd.h
@@ -16,6 +16,7 @@ license as described in the file LICENSE.
#include "parser.h"
#include "allreduce.h"
#include "sparse_dense.h"
+#include "learner.h"
namespace GD{
void print_result(int f, float res, v_array<char> tag);
@@ -31,15 +32,11 @@ void train_offset_example(regressor& r, example* ex, size_t offset);
void compute_update(example* ec);
void offset_train(regressor &reg, example* &ec, float update, size_t offset);
void train_one_example_single_thread(regressor& r, example* ex);
-void drive_gd(void*);
-void finish_gd(void*);
-void learn_gd(void*, example* ec);
-void save_load(void* in, io_buf& model_file, bool read, bool text);
+ learner get_learner();
void save_load_regressor(vw& all, io_buf& model_file, bool read, bool text);
void output_and_account_example(example* ec);
bool command_example(vw&, example* ec);
-
template <float (*T)(vw&,float,uint32_t)>
float inline_predict(vw& all, example* &ec)
{
diff --git a/vowpalwabbit/gd_mf.cc b/vowpalwabbit/gd_mf.cc
index 3ecb2a81..d7034053 100644
--- a/vowpalwabbit/gd_mf.cc
+++ b/vowpalwabbit/gd_mf.cc
@@ -207,7 +207,7 @@ float mf_predict(vw& all, example* ex)
return ex->final_prediction;
}
-void save_load(void* in, io_buf& model_file, bool read, bool text)
+ void save_load(void* in, void* d, io_buf& model_file, bool read, bool text)
{
vw* all = (vw*)in;
uint32_t length = 1 << all->num_bits;
@@ -261,27 +261,39 @@ void save_load(void* in, io_buf& model_file, bool read, bool text)
}
}
-void drive(void* in)
+ void learn(void* in, void* d, example* ec)
+ {
+ vw* all = (vw*)in;
+ size_t* current_pass = (size_t*) d;
+ if (ec->pass != *current_pass) {
+ all->eta *= all->eta_decay_rate;
+ *current_pass = ec->pass;
+ }
+ if (!GD::command_example(*all, ec))
+ {
+ mf_predict(*all,ec);
+ if (all->training && ((label_data*)(ec->ld))->label != FLT_MAX)
+ mf_inline_train(*all, ec, ec->eta_round);
+ }
+ }
+
+ void finish(void* a, void* d)
+ {
+ size_t* current_pass = (size_t*)d;
+ free(current_pass);
+ }
+
+ void drive(void* in, void* d)
{
vw* all = (vw*)in;
+
example* ec = NULL;
- size_t current_pass = 0;
while ( true )
{
if ((ec = get_example(all->p)) != NULL)//blocking operation.
{
- if (ec->pass != current_pass) {
- all->eta *= all->eta_decay_rate;
- //save_predictor(*all, all->final_regressor_name, current_pass);
- current_pass = ec->pass;
- }
- if (!GD::command_example(*all, ec))
- {
- mf_predict(*all,ec);
- if (all->training && ((label_data*)(ec->ld))->label != FLT_MAX)
- mf_inline_train(*all, ec, ec->eta_round);
- }
+ learn(in,d,ec);
return_simple_example(*all, ec);
}
else if (parser_done(all->p))
@@ -291,4 +303,10 @@ void drive(void* in)
}
}
+ void parse_flags(vw& all)
+ {
+ size_t* current_pass = (size_t*)calloc(1, sizeof(size_t));
+ learner t = {current_pass,drive,learn,finish,save_load};
+ all.l = t;
+ }
}
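Unlike gd, GDMF keeps its pass counter in the learner's data slot, which is the intended use of that field: state that used to be a local in drive() now survives across callbacks. The lifecycle under this scheme, in brief:

  // parse_flags: calloc(1, sizeof(size_t)) becomes the learner's data.
  // learn:       if (ec->pass != *current_pass) { decay eta; update it; }
  // finish:      free(current_pass) releases the slot.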
diff --git a/vowpalwabbit/gd_mf.h b/vowpalwabbit/gd_mf.h
index 8bd6c4c6..b88f12fd 100644
--- a/vowpalwabbit/gd_mf.h
+++ b/vowpalwabbit/gd_mf.h
@@ -13,7 +13,6 @@ license as described in the file LICENSE.
#include "gd.h"
namespace GDMF{
- void drive(void*);
- void save_load(void* in, io_buf& model_file, bool read, bool text);
+ void parse_flags(vw& all);
}
#endif
diff --git a/vowpalwabbit/global_data.cc b/vowpalwabbit/global_data.cc
index ae91381b..c5031bfa 100644
--- a/vowpalwabbit/global_data.cc
+++ b/vowpalwabbit/global_data.cc
@@ -172,6 +172,12 @@ void set_mm(shared_data* sd, float label)
void noop_mm(shared_data* sd, float label)
{}
+void vw::learn(void* a, example* ec)
+{
+ vw* all = (vw*)a;
+ all->l.learn(a,all->l.data,ec);
+}
+
vw::vw()
{
sd = (shared_data *) calloc(1, sizeof(shared_data));
@@ -185,6 +191,8 @@ vw::vw()
reg_mode = 0;
+ current_pass = 0;
+
bfgs = false;
hessian_on = false;
stride = 1;
@@ -200,13 +208,10 @@ vw::vw()
m = 15;
save_resume = false;
- driver = GD::drive_gd;
- learn = GD::learn_gd;
- finish = GD::finish_gd;
- save_load = GD::save_load;
- set_minmax = set_mm;
+ l = GD::get_learner();
+ scorer = l;
- base_learn = NULL;
+ set_minmax = set_mm;
base_learner_nb_w = 1;
diff --git a/vowpalwabbit/global_data.h b/vowpalwabbit/global_data.h
index b2ba8a39..c58c934a 100644
--- a/vowpalwabbit/global_data.h
+++ b/vowpalwabbit/global_data.h
@@ -14,6 +14,7 @@ license as described in the file LICENSE.
#include "comp_io.h"
#include "example.h"
#include "config.h"
+#include "learner.h"
struct version_struct {
int major;
@@ -115,13 +116,15 @@ struct vw {
parser* p;
- void (*driver)(void *);
- void (*learn)(void *, example*);
- void (*base_learn)(void *, example*);
- void (*finish)(void *);
- void (*save_load)(void *, io_buf&, bool, bool);
+ learner l;//the top level learner
+ learner scorer;//a scoring function
+
+ void learn(void*, example*);
+
void (*set_minmax)(shared_data* sd, float label);
+ size_t current_pass;
+
uint32_t num_bits; // log_2 of the number of features.
bool default_bits;
@@ -188,8 +191,6 @@ struct vw {
bool nonormalize;
bool do_reset_source;
- uint32_t csoaa_increment;
-
float normalized_sum_norm_x;
size_t normalized_idx; //offset idx where the norm is stored (1 or 2 depending on whether adaptive is true)
diff --git a/vowpalwabbit/lda_core.cc b/vowpalwabbit/lda_core.cc
index be83fc5a..d73df641 100644
--- a/vowpalwabbit/lda_core.cc
+++ b/vowpalwabbit/lda_core.cc
@@ -564,8 +564,7 @@ void save_load(void* in, io_buf& model_file, bool read, bool text)
void parse_flags(vw&all, std::vector<std::string>&opts, po::variables_map& vm)
{
-
- po::options_description desc("Searn options");
+ po::options_description desc("LDA options");
desc.add_options()
("lda_alpha", po::value<float>(&all.lda_alpha), "Prior on sparsity of per-document topic weights")
("lda_rho", po::value<float>(&all.lda_rho), "Prior on sparsity of topic distributions")
@@ -750,4 +749,8 @@ void drive(void* in)
}
}
+ void parse_args()
+ {
+ }
+
}
diff --git a/vowpalwabbit/lda_core.h b/vowpalwabbit/lda_core.h
index e087da04..df0ca613 100644
--- a/vowpalwabbit/lda_core.h
+++ b/vowpalwabbit/lda_core.h
@@ -7,8 +7,6 @@ license as described in the file LICENSE.
#define LDA_CORE_H
namespace LDA{
- void drive(void*);
- void save_load(void* in, io_buf& model_file, bool read, bool text);
void parse_flags(vw&, std::vector<std::string>&, po::variables_map&);
}
diff --git a/vowpalwabbit/learner.h b/vowpalwabbit/learner.h
new file mode 100644
index 00000000..e5e2ba89
--- /dev/null
+++ b/vowpalwabbit/learner.h
@@ -0,0 +1,18 @@
+/*
+Copyright (c) by respective owners including Yahoo!, Microsoft, and
+individual contributors. All rights reserved. Released under a BSD
+license as described in the file LICENSE.
+ */
+#ifndef LEARNER_H
+#define LEARNER_H
+// This is the interface for a learning algorithm
+
+struct learner {
+ void* data;
+
+ void (*driver)(void* all, void* data);
+ void (*learn)(void* all, void* data, example*);
+ void (*finish)(void* all, void* data);
+ void (*save_load)(void* all, void* data, io_buf&, bool read, bool text);
+};
+#endif
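Every callback takes the learner's own data pointer as its second argument, which is what lets each module's former file-scope globals move into a struct. A sketch of the consumer side (hypothetical shell; the commit's actual driver call site is not part of this view):

  #include "global_data.h"   // vw now carries `learner l`

  void run_top_learner(vw& all) {
    all.l.driver(&all, all.l.data); // the top learner owns the example loop
    all.l.finish(&all, all.l.data); // each reduction finishes its base in turn
  }

Because reductions copy all.l.save_load into their own learner when they stack, a whole stack still serializes with the bottom (gd) save_load.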
diff --git a/vowpalwabbit/nn.cc b/vowpalwabbit/nn.cc
index a15d4757..286fc403 100644
--- a/vowpalwabbit/nn.cc
+++ b/vowpalwabbit/nn.cc
@@ -18,27 +18,23 @@ license as described in the file LICENSE.
using namespace std;
namespace NN {
- //nonreentrant
- uint32_t k=0;
- uint32_t increment=0;
- loss_function* squared_loss;
- example output_layer;
const float hidden_min_activation = -3;
const float hidden_max_activation = 3;
const int nn_constant = 533357803;
- bool dropout = false;
- uint64_t xsubi;
- uint64_t save_xsubi;
- size_t nn_current_pass = 0;
- bool inpass = false;
-
- static void
- free_stuff (void)
- {
- delete squared_loss;
- free (output_layer.indices.begin);
- free (output_layer.atomics[nn_output_namespace].begin);
- }
+
+ struct nn {
+ uint32_t k;
+ uint32_t increment;
+ loss_function* squared_loss;
+ example output_layer;
+ bool dropout;
+ uint64_t xsubi;
+ uint64_t save_xsubi;
+ size_t nn_current_pass;
+ bool inpass;
+
+ learner base;
+ };
#define cast_uint32_t static_cast<uint32_t>
@@ -68,48 +64,48 @@ namespace NN {
void (*base_learner)(void*,example*) = NULL;
- void learn_with_output(vw*all, example* ec, bool shouldOutput)
+ void learn_with_output(vw& all, nn& n, example* ec, bool shouldOutput)
{
- if (GD::command_example(*all, ec)) {
+ if (GD::command_example(all, ec)) {
return;
}
- if (all->bfgs && ec->pass != nn_current_pass) {
- xsubi = save_xsubi;
- nn_current_pass = ec->pass;
+ if (all.bfgs && ec->pass != n.nn_current_pass) {
+ n.xsubi = n.save_xsubi;
+ n.nn_current_pass = ec->pass;
}
label_data* ld = (label_data*)ec->ld;
float save_label = ld->label;
- void (*save_set_minmax) (shared_data*, float) = all->set_minmax;
+ void (*save_set_minmax) (shared_data*, float) = all.set_minmax;
float save_min_label;
float save_max_label;
- float dropscale = dropout ? 2.0f : 1.0f;
- loss_function* save_loss = all->loss;
+ float dropscale = n.dropout ? 2.0f : 1.0f;
+ loss_function* save_loss = all.loss;
- float* hidden_units = (float*) alloca (k * sizeof (float));
- bool* dropped_out = (bool*) alloca (k * sizeof (bool));
+ float* hidden_units = (float*) alloca (n.k * sizeof (float));
+ bool* dropped_out = (bool*) alloca (n.k * sizeof (bool));
string outputString;
stringstream outputStringStream(outputString);
- all->set_minmax = noop_mm;
- all->loss = squared_loss;
- save_min_label = all->sd->min_label;
- all->sd->min_label = hidden_min_activation;
- save_max_label = all->sd->max_label;
- all->sd->max_label = hidden_max_activation;
+ all.set_minmax = noop_mm;
+ all.loss = n.squared_loss;
+ save_min_label = all.sd->min_label;
+ all.sd->min_label = hidden_min_activation;
+ save_max_label = all.sd->max_label;
+ all.sd->max_label = hidden_max_activation;
ld->label = FLT_MAX;
- for (unsigned int i = 0; i < k; ++i)
+ for (unsigned int i = 0; i < n.k; ++i)
{
if (i != 0)
- update_example_indicies(all->audit, ec, increment);
+ update_example_indicies(all.audit, ec, n.increment);
ec->partial_prediction = 0.;
- base_learner(all,ec);
- hidden_units[i] = GD::finalize_prediction (*all, ec->partial_prediction);
+ n.base.learn(&all,n.base.data,ec);
+ hidden_units[i] = GD::finalize_prediction (all, ec->partial_prediction);
- dropped_out[i] = (dropout && merand48 (xsubi) < 0.5);
+ dropped_out[i] = (n.dropout && merand48 (n.xsubi) < 0.5);
if (shouldOutput) {
if (i > 0) outputStringStream << ' ';
@@ -117,10 +113,10 @@ namespace NN {
}
}
ld->label = save_label;
- all->loss = save_loss;
- all->set_minmax = save_set_minmax;
- all->sd->min_label = save_min_label;
- all->sd->max_label = save_max_label;
+ all.loss = save_loss;
+ all.set_minmax = save_set_minmax;
+ all.sd->min_label = save_min_label;
+ all.sd->max_label = save_max_label;
bool converse = false;
float save_partial_prediction = 0;
@@ -129,115 +125,115 @@ namespace NN {
CONVERSE: // That's right, I'm using goto. So sue me.
- output_layer.total_sum_feat_sq = 1;
- output_layer.sum_feat_sq[nn_output_namespace] = 1;
+ n.output_layer.total_sum_feat_sq = 1;
+ n.output_layer.sum_feat_sq[nn_output_namespace] = 1;
- for (unsigned int i = 0; i < k; ++i)
+ for (unsigned int i = 0; i < n.k; ++i)
{
float sigmah =
(dropped_out[i]) ? 0.0f : dropscale * fasttanh (hidden_units[i]);
- output_layer.atomics[nn_output_namespace][i+1].x = sigmah;
+ n.output_layer.atomics[nn_output_namespace][i+1].x = sigmah;
- output_layer.total_sum_feat_sq += sigmah * sigmah;
- output_layer.sum_feat_sq[nn_output_namespace] += sigmah * sigmah;
+ n.output_layer.total_sum_feat_sq += sigmah * sigmah;
+ n.output_layer.sum_feat_sq[nn_output_namespace] += sigmah * sigmah;
}
- if (inpass) {
+ if (n.inpass) {
// TODO: this is not correct if there is something in the
// nn_output_namespace but at least it will not leak memory
// in that case
- update_example_indicies (all->audit, ec, increment);
+ update_example_indicies (all.audit, ec, n.increment);
ec->indices.push_back (nn_output_namespace);
v_array<feature> save_nn_output_namespace = ec->atomics[nn_output_namespace];
- ec->atomics[nn_output_namespace] = output_layer.atomics[nn_output_namespace];
- ec->sum_feat_sq[nn_output_namespace] = output_layer.sum_feat_sq[nn_output_namespace];
- ec->total_sum_feat_sq += output_layer.sum_feat_sq[nn_output_namespace];
+ ec->atomics[nn_output_namespace] = n.output_layer.atomics[nn_output_namespace];
+ ec->sum_feat_sq[nn_output_namespace] = n.output_layer.sum_feat_sq[nn_output_namespace];
+ ec->total_sum_feat_sq += n.output_layer.sum_feat_sq[nn_output_namespace];
ec->partial_prediction = 0.;
- base_learner(all,ec);
- output_layer.partial_prediction = ec->partial_prediction;
- output_layer.loss = ec->loss;
- ec->total_sum_feat_sq -= output_layer.sum_feat_sq[nn_output_namespace];
+ n.base.learn(&all, n.base.data, ec);
+ n.output_layer.partial_prediction = ec->partial_prediction;
+ n.output_layer.loss = ec->loss;
+ ec->total_sum_feat_sq -= n.output_layer.sum_feat_sq[nn_output_namespace];
ec->sum_feat_sq[nn_output_namespace] = 0;
ec->atomics[nn_output_namespace] = save_nn_output_namespace;
ec->indices.pop ();
- update_example_indicies (all->audit, ec, -increment);
+ update_example_indicies (all.audit, ec, -n.increment);
}
else {
- output_layer.ld = ec->ld;
- output_layer.pass = ec->pass;
- output_layer.partial_prediction = 0;
- output_layer.eta_round = ec->eta_round;
- output_layer.eta_global = ec->eta_global;
- output_layer.global_weight = ec->global_weight;
- output_layer.example_t = ec->example_t;
- base_learner(all,&output_layer);
- output_layer.ld = 0;
+ n.output_layer.ld = ec->ld;
+ n.output_layer.pass = ec->pass;
+ n.output_layer.partial_prediction = 0;
+ n.output_layer.eta_round = ec->eta_round;
+ n.output_layer.eta_global = ec->eta_global;
+ n.output_layer.global_weight = ec->global_weight;
+ n.output_layer.example_t = ec->example_t;
+ n.base.learn(&all,n.base.data,&n.output_layer);
+ n.output_layer.ld = 0;
}
- output_layer.final_prediction = GD::finalize_prediction (*all, output_layer.partial_prediction);
+ n.output_layer.final_prediction = GD::finalize_prediction (all, n.output_layer.partial_prediction);
if (shouldOutput) {
- outputStringStream << ' ' << output_layer.partial_prediction;
- all->print_text(all->raw_prediction, outputStringStream.str(), ec->tag);
+ outputStringStream << ' ' << n.output_layer.partial_prediction;
+ all.print_text(all.raw_prediction, outputStringStream.str(), ec->tag);
}
- if (all->training && ld->label != FLT_MAX) {
- float gradient = all->loss->first_derivative(all->sd,
- output_layer.final_prediction,
+ if (all.training && ld->label != FLT_MAX) {
+ float gradient = all.loss->first_derivative(all.sd,
+ n.output_layer.final_prediction,
ld->label);
if (fabs (gradient) > 0) {
- all->loss = squared_loss;
- all->set_minmax = noop_mm;
- save_min_label = all->sd->min_label;
- all->sd->min_label = hidden_min_activation;
- save_max_label = all->sd->max_label;
- all->sd->max_label = hidden_max_activation;
-
- for (size_t i = k; i > 0; --i) {
+ all.loss = n.squared_loss;
+ all.set_minmax = noop_mm;
+ save_min_label = all.sd->min_label;
+ all.sd->min_label = hidden_min_activation;
+ save_max_label = all.sd->max_label;
+ all.sd->max_label = hidden_max_activation;
+
+ for (size_t i = n.k; i > 0; --i) {
if (! dropped_out[i-1]) {
float sigmah =
- output_layer.atomics[nn_output_namespace][i].x / dropscale;
+ n.output_layer.atomics[nn_output_namespace][i].x / dropscale;
float sigmahprime = dropscale * (1.0f - sigmah * sigmah);
- float nu = all->reg.weight_vector[output_layer.atomics[nn_output_namespace][i].weight_index & all->weight_mask];
+ float nu = all.reg.weight_vector[n.output_layer.atomics[nn_output_namespace][i].weight_index & all.weight_mask];
float gradhw = 0.5f * nu * gradient * sigmahprime;
- ld->label = GD::finalize_prediction (*all, hidden_units[i-1] - gradhw);
+ ld->label = GD::finalize_prediction (all, hidden_units[i-1] - gradhw);
if (ld->label != hidden_units[i-1]) {
ec->partial_prediction = 0.;
- base_learner(all,ec);
+ n.base.learn(&all,n.base.data,ec);
}
}
if (i != 1) {
- update_example_indicies(all->audit, ec, -increment);
+ update_example_indicies(all.audit, ec, -n.increment);
}
}
- all->loss = save_loss;
- all->set_minmax = save_set_minmax;
- all->sd->min_label = save_min_label;
- all->sd->max_label = save_max_label;
+ all.loss = save_loss;
+ all.set_minmax = save_set_minmax;
+ all.sd->min_label = save_min_label;
+ all.sd->max_label = save_max_label;
}
else
- update_example_indicies(all->audit, ec, -(k-1)*increment);
+ update_example_indicies(all.audit, ec, -(n.k-1)*n.increment);
}
else
- update_example_indicies(all->audit, ec, -(k-1)*increment);
+ update_example_indicies(all.audit, ec, -(n.k-1)*n.increment);
ld->label = save_label;
if (! converse) {
- save_partial_prediction = output_layer.partial_prediction;
- save_final_prediction = output_layer.final_prediction;
- save_ec_loss = output_layer.loss;
+ save_partial_prediction = n.output_layer.partial_prediction;
+ save_final_prediction = n.output_layer.final_prediction;
+ save_ec_loss = n.output_layer.loss;
}
- if (dropout && ! converse)
+ if (n.dropout && ! converse)
{
- update_example_indicies (all->audit, ec, (k-1)*increment);
+ update_example_indicies (all.audit, ec, (n.k-1)*n.increment);
- for (unsigned int i = 0; i < k; ++i)
+ for (unsigned int i = 0; i < n.k; ++i)
{
dropped_out[i] = ! dropped_out[i];
}
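A hedged reading of the backprop hunk above, in the patch's own symbols (a restatement for reference, not a change): the forward pass stores sigmah = dropscale * tanh(h_i) for hidden unit i, so with output weight nu_i and top-level gradient g = dloss/dy the chain rule gives

    dloss/dh_i = g * nu_i * dropscale * (1 - tanh(h_i)^2)    // = gradient * nu * sigmahprime

and each live hidden unit is regressed toward the pseudo-target h_i - 0.5 * dloss/dh_i (the gradhw half-step) under squared loss, which is why the code temporarily swaps in n.squared_loss and noop_mm around the inner loop.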
@@ -251,20 +247,22 @@ CONVERSE: // That's right, I'm using goto. So sue me.
ec->loss = save_ec_loss;
}
- void learn(void*a, example* ec) {
+ void learn(void*a, void* d,example* ec) {
vw* all = (vw*)a;
- learn_with_output(all, ec, false);
+ nn* n = (nn*)d;
+ learn_with_output(*all, *n, ec, false);
}
- void drive_nn(void *in)
+ void drive_nn(void *in, void* d)
{
vw* all = (vw*)in;
+ nn* n = (nn*)d;
example* ec = NULL;
while ( true )
{
if ((ec = get_example(all->p)) != NULL)//semiblocking operation.
{
- learn_with_output(all, ec, all->raw_prediction > 0);
+ learn_with_output(*all, *n, ec, all->raw_prediction > 0);
int save_raw_prediction = all->raw_prediction;
all->raw_prediction = -1;
return_simple_example(*all, ec);
@@ -277,8 +275,20 @@ CONVERSE: // That's right, I'm using goto. So sue me.
}
}
+ void finish(void* a, void* d)
+ {
+ nn* n =(nn*)d;
+ n->base.finish(a,n->base.data);
+ delete n->squared_loss;
+ free (n->output_layer.indices.begin);
+ free (n->output_layer.atomics[nn_output_namespace].begin);
+ free(n);
+ }
+
void parse_flags(vw& all, std::vector<std::string>&opts, po::variables_map& vm, po::variables_map& vm_file)
{
+ nn* n = (nn*)calloc(1,sizeof(nn));
+
po::options_description desc("NN options");
desc.add_options()
("inpass", "Train or test sigmoidal feedforward network with input passthrough.")
@@ -299,28 +309,28 @@ CONVERSE: // That's right, I'm using goto. So sue me.
po::notify(vm_file);
//first parse for number of hidden units
- k = 0;
+ n->k = 0;
if( vm_file.count("nn") ) {
- k = (uint32_t)vm_file["nn"].as<size_t>();
- if( vm.count("nn") && (uint32_t)vm["nn"].as<size_t>() != k )
- std::cerr << "warning: you specified a different number of hidden units through --nn than the one loaded from predictor. Pursuing with loaded value of: " << k << endl;
+ n->k = (uint32_t)vm_file["nn"].as<size_t>();
+ if( vm.count("nn") && (uint32_t)vm["nn"].as<size_t>() != n->k )
+ std::cerr << "warning: you specified a different number of hidden units through --nn than the one loaded from predictor. Pursuing with loaded value of: " << n->k << endl;
}
else {
- k = (uint32_t)vm["nn"].as<size_t>();
+ n->k = (uint32_t)vm["nn"].as<size_t>();
std::stringstream ss;
- ss << " --nn " << k;
+ ss << " --nn " << n->k;
all.options_from_file.append(ss.str());
}
if( vm_file.count("dropout") ) {
- dropout = all.training || vm.count("dropout");
+ n->dropout = all.training || vm.count("dropout");
- if (! dropout && ! vm.count("meanfield") && ! all.quiet)
+ if (! n->dropout && ! vm.count("meanfield") && ! all.quiet)
std::cerr << "using mean field for testing, specify --dropout explicitly to override" << std::endl;
}
else if ( vm.count("dropout") ) {
- dropout = true;
+ n->dropout = true;
std::stringstream ss;
ss << " --dropout ";
@@ -328,62 +338,60 @@ CONVERSE: // That's right, I'm using goto. So sue me.
}
if ( vm.count("meanfield") ) {
- dropout = false;
+ n->dropout = false;
if (! all.quiet)
std::cerr << "using mean field for neural network "
<< (all.training ? "training" : "testing")
<< std::endl;
}
- if (dropout)
+ if (n->dropout)
if (! all.quiet)
std::cerr << "using dropout for neural network "
<< (all.training ? "training" : "testing")
<< std::endl;
if( vm_file.count("inpass") ) {
- inpass = true;
+ n->inpass = true;
}
else if (vm.count ("inpass")) {
- inpass = true;
+ n->inpass = true;
std::stringstream ss;
ss << " --inpass";
all.options_from_file.append(ss.str());
}
- if (inpass && ! all.quiet)
+ if (n->inpass && ! all.quiet)
std::cerr << "using input passthrough for neural network "
<< (all.training ? "training" : "testing")
<< std::endl;
- all.driver = drive_nn;
- base_learner = all.learn;
- all.base_learn = all.learn;
- all.learn = learn;
+ n->base = all.l;
+ learner t = {n,drive_nn,learn,finish,all.l.save_load};
+ all.l = t;
- all.base_learner_nb_w *= (inpass) ? k + 1 : k;
- increment = ((uint32_t)all.length()/all.base_learner_nb_w) * all.stride;
+ all.base_learner_nb_w *= (n->inpass) ? n->k + 1 : n->k;
+ n->increment = ((uint32_t)all.length()/all.base_learner_nb_w) * all.stride;
bool initialize = true;
// TODO: output_layer audit
- memset (&output_layer, 0, sizeof (output_layer));
- output_layer.indices.push_back(nn_output_namespace);
+ memset (&n->output_layer, 0, sizeof (n->output_layer));
+ n->output_layer.indices.push_back(nn_output_namespace);
feature output = {1., nn_constant*all.stride};
- output_layer.atomics[nn_output_namespace].push_back(output);
- initialize &= (all.reg.weight_vector[output_layer.atomics[nn_output_namespace][0].weight_index & all.weight_mask] == 0);
+ n->output_layer.atomics[nn_output_namespace].push_back(output);
+ initialize &= (all.reg.weight_vector[n->output_layer.atomics[nn_output_namespace][0].weight_index & all.weight_mask] == 0);
- for (unsigned int i = 0; i < k; ++i)
+ for (unsigned int i = 0; i < n->k; ++i)
{
output.weight_index += all.stride;
- output_layer.atomics[nn_output_namespace].push_back(output);
- initialize &= (all.reg.weight_vector[output_layer.atomics[nn_output_namespace][i+1].weight_index & all.weight_mask] == 0);
+ n->output_layer.atomics[nn_output_namespace].push_back(output);
+ initialize &= (all.reg.weight_vector[n->output_layer.atomics[nn_output_namespace][i+1].weight_index & all.weight_mask] == 0);
}
- output_layer.num_features = k + 1;
- output_layer.in_use = true;
+ n->output_layer.num_features = n->k + 1;
+ n->output_layer.in_use = true;
if (initialize) {
if (! all.quiet)
@@ -391,15 +399,15 @@ CONVERSE: // That's right, I'm using goto. So sue me.
// output weights
- float sqrtk = sqrt ((float)k);
- for (unsigned int i = 0; i <= k; ++i)
+ float sqrtk = sqrt ((float)n->k);
+ for (unsigned int i = 0; i <= n->k; ++i)
{
- weight* w = &all.reg.weight_vector[output_layer.atomics[nn_output_namespace][i].weight_index & all.weight_mask];
+ weight* w = &all.reg.weight_vector[n->output_layer.atomics[nn_output_namespace][i].weight_index & all.weight_mask];
w[0] = (float) (frand48 () - 0.5) / sqrtk;
// prevent divide by zero error
- if (dropout && all.normalized_updates)
+ if (n->dropout && all.normalized_updates)
w[all.normalized_idx] = 1e-4f;
}
@@ -407,22 +415,20 @@ CONVERSE: // That's right, I'm using goto. So sue me.
unsigned int weight_index = constant * all.stride;
- for (unsigned int i = 0; i < k; ++i)
+ for (unsigned int i = 0; i < n->k; ++i)
{
all.reg.weight_vector[weight_index & all.weight_mask] = (float) (frand48 () - 0.5);
- weight_index += increment;
+ weight_index += n->increment;
}
}
- squared_loss = getLossFunction (0, "squared", 0);
+ n->squared_loss = getLossFunction (0, "squared", 0);
- xsubi = 0;
+ n->xsubi = 0;
if (vm.count("random_seed"))
- xsubi = vm["random_seed"].as<size_t>();
-
- save_xsubi = xsubi;
+ n->xsubi = vm["random_seed"].as<size_t>();
- atexit (free_stuff);
+ n->save_xsubi = n->xsubi;
}
}
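The aggregate initializers introduced throughout this patch (learner t = {n,drive_nn,learn,finish,all.l.save_load}; and friends) imply a five-slot, vtable-like struct. Below is a minimal sketch inferred from those call sites alone; the authoritative definition is the newly exported vowpalwabbit/learner.h, and the slot names other than data, learn, finish and save_load are assumptions:

struct learner {
  void* data;                                          // reduction-private state (nn*, oaa*, searn*, ...)
  void (*drive)(void* all, void* data);                // owns the example loop
  void (*learn)(void* all, void* data, example* ec);   // per-example update
  void (*finish)(void* all, void* data);               // teardown; chains to the base learner
  void (*save_load)(void* all, void* data, io_buf& io, bool read, bool text);
};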
diff --git a/vowpalwabbit/noop.cc b/vowpalwabbit/noop.cc
index 390fe53b..02ffdb27 100644
--- a/vowpalwabbit/noop.cc
+++ b/vowpalwabbit/noop.cc
@@ -11,20 +11,28 @@ license as described in the file LICENSE.
#include "simple_label.h"
namespace NOOP {
-void learn(void*in, example*ec) {}
+ void learn(void*in, void* d, example*ec) {}
+ void finish(void*in, void* d) {}
-void save_load(void* in, io_buf& model_file, bool read, bool text) {}
-
-void drive(void* in)
-{
- vw* all = (vw*)in;
- example* ec = NULL;
+ void save_load(void* in, void* d, io_buf& model_file, bool read, bool text) {}
- while ( !parser_done(all->p)){
- ec = get_example(all->p);
- if (ec != NULL)
- return_simple_example(*all, ec);
+ void drive(void* in, void* d)
+ {
+ vw* all = (vw*)in;
+ example* ec = NULL;
+
+ while ( !parser_done(all->p)){
+ ec = get_example(all->p);
+ if (ec != NULL)
+ return_simple_example(*all, ec);
+ }
+ }
+
+ void parse_flags(vw& all)
+ {
+ learner t = {NULL,drive,learn,finish,save_load};
+ all.l = t;
+ all.is_noop = true;
}
-}
}
diff --git a/vowpalwabbit/noop.h b/vowpalwabbit/noop.h
index aac50e08..d54f7d3a 100644
--- a/vowpalwabbit/noop.h
+++ b/vowpalwabbit/noop.h
@@ -4,7 +4,5 @@ individual contributors. All rights reserved. Released under a BSD
license as described in the file LICENSE.
*/
namespace NOOP {
- void drive(void*);
- void learn(void*, example*);
- void save_load(void* in, io_buf& model_file, bool read, bool text);
+ void parse_flags(vw&);
}
diff --git a/vowpalwabbit/oaa.cc b/vowpalwabbit/oaa.cc
index 6358ad8c..2011af1d 100644
--- a/vowpalwabbit/oaa.cc
+++ b/vowpalwabbit/oaa.cc
@@ -17,6 +17,13 @@ using namespace std;
namespace OAA {
+ struct oaa{
+ uint32_t k;
+ uint32_t increment;
+ uint32_t total_increment;
+ learner base;
+ };
+
char* bufread_label(mc_label* ld, char* c)
{
ld->label = *(uint32_t *)c;
@@ -98,11 +105,6 @@ namespace OAA {
}
}
- //nonreentrant
- uint32_t k=0;
- uint32_t increment=0;
- uint32_t total_increment=0;
-
void print_update(vw& all, example *ec)
{
if (all.sd->weighted_examples > all.sd->dump_interval && !all.quiet && !all.bfgs)
@@ -148,21 +150,19 @@ namespace OAA {
print_update(all, ec);
}
- void (*base_learner)(void*,example*) = NULL;
-
- void learn_with_output(vw*all, example* ec, bool shouldOutput)
+ void learn_with_output(vw*all,oaa* d, example* ec, bool shouldOutput)
{
mc_label* mc_label_data = (mc_label*)ec->ld;
size_t prediction = 1;
float score = INT_MIN;
- if (mc_label_data->label > k && mc_label_data->label != (uint32_t)-1)
- cerr << "warning: label " << mc_label_data->label << " is greater than " << k << endl;
+ if (mc_label_data->label > d->k && mc_label_data->label != (uint32_t)-1)
+ cerr << "warning: label " << mc_label_data->label << " is greater than " << d->k << endl;
string outputString;
stringstream outputStringStream(outputString);
- for (size_t i = 1; i <= k; i++)
+ for (size_t i = 1; i <= d->k; i++)
{
label_data simple_temp;
simple_temp.initial = 0.;
@@ -173,8 +173,8 @@ namespace OAA {
simple_temp.weight = mc_label_data->weight;
ec->ld = &simple_temp;
if (i != 1)
- update_example_indicies(all->audit, ec, increment);
- base_learner(all,ec);
+ update_example_indicies(all->audit, ec, d->increment);
+ d->base.learn((void*)all,d->base.data,ec);
if (ec->partial_prediction > score)
{
score = ec->partial_prediction;
@@ -190,7 +190,7 @@ namespace OAA {
}
ec->ld = mc_label_data;
*(prediction_t*)&(ec->final_prediction) = prediction;
- update_example_indicies(all->audit, ec, -total_increment);
+ update_example_indicies(all->audit, ec, -d->total_increment);
if (shouldOutput) {
outputStringStream << endl;
@@ -198,12 +198,11 @@ namespace OAA {
}
}
- void learn(void*a, example* ec) {
- vw* all = (vw*)a;
- learn_with_output(all, ec, false);
+ void learn(void*a, void* d, example* ec) {
+ learn_with_output((vw*)a, (oaa*)d, ec, false);
}
- void drive_oaa(void *in)
+ void drive(void *in, void* d)
{
vw* all = (vw*)in;
example* ec = NULL;
@@ -211,7 +210,7 @@ namespace OAA {
{
if ((ec = get_example(all->p)) != NULL)//semiblocking operation.
{
- learn_with_output(all, ec, all->raw_prediction > 0);
+ learn_with_output(all, (oaa*)d, ec, all->raw_prediction > 0);
output_example(*all, ec);
VW::finish_example(*all, ec);
}
@@ -222,32 +221,38 @@ namespace OAA {
}
}
+ void finish(void* all, void* data)
+ {
+ oaa* o=(oaa*)data;
+ o->base.finish(all,o->base.data);
+ free(o);
+ }
+
void parse_flags(vw& all, std::vector<std::string>&opts, po::variables_map& vm, po::variables_map& vm_file)
{
+ oaa* data = (oaa*)calloc(1, sizeof(oaa));
//first parse for number of actions
- k = 0;
if( vm_file.count("oaa") ) {
- k = (uint32_t)vm_file["oaa"].as<size_t>();
- if( vm.count("oaa") && (uint32_t)vm["oaa"].as<size_t>() != k )
- std::cerr << "warning: you specified a different number of actions through --oaa than the one loaded from predictor. Pursuing with loaded value of: " << k << endl;
+ data->k = (uint32_t)vm_file["oaa"].as<size_t>();
+ if( vm.count("oaa") && (uint32_t)vm["oaa"].as<size_t>() != data->k )
+ std::cerr << "warning: you specified a different number of actions through --oaa than the one loaded from predictor. Pursuing with loaded value of: " << data->k << endl;
}
else {
- k = (uint32_t)vm["oaa"].as<size_t>();
+ data->k = (uint32_t)vm["oaa"].as<size_t>();
//append oaa with nb_actions to options_from_file so it is saved to regressor later
std::stringstream ss;
- ss << " --oaa " << k;
+ ss << " --oaa " << data->k;
all.options_from_file.append(ss.str());
}
*(all.p->lp) = mc_label_parser;
- all.driver = drive_oaa;
- base_learner = all.learn;
- all.base_learn = all.learn;
- all.learn = learn;
-
- all.base_learner_nb_w *= k;
- increment = ((uint32_t)all.length()/all.base_learner_nb_w) * all.stride;
- total_increment = increment*(k-1);
+ all.base_learner_nb_w *= data->k;
+
+ data->increment = ((uint32_t)all.length()/all.base_learner_nb_w) * all.stride;
+ data->total_increment = data->increment*(data->k-1);
+ data->base = all.l;
+ learner l = {data, drive, learn, finish, all.l.save_load};
+ all.l = l;
}
}
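oaa.cc above and nn.cc earlier converge on the same stacking idiom. A sketch of the pattern under assumed names (my_data, setup and the drive/learn stubs are illustrative, not part of this patch):

struct my_data { learner base; /* plus reduction state */ };

void drive(void* all, void* d);                  // example loop, as in oaa.cc
void learn(void* all, void* d, example* ec);     // per-example update

void finish(void* a, void* d)
{
  my_data* m = (my_data*)d;
  m->base.finish(a, m->base.data);               // unwind the learner below first
  free(m);                                       // then release our own state
}

void setup(vw& all)
{
  my_data* m = (my_data*)calloc(1, sizeof(my_data));
  m->base = all.l;                               // capture the current top of the stack
  learner l = {m, drive, learn, finish, all.l.save_load};
  all.l = l;                                     // this reduction is now the entry point
}

The save_load slot is passed through from the base unchanged, so model I/O still reaches whichever learner actually owns the weights.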
diff --git a/vowpalwabbit/parse_args.cc b/vowpalwabbit/parse_args.cc
index 54c6ab6f..30bcc3f1 100644
--- a/vowpalwabbit/parse_args.cc
+++ b/vowpalwabbit/parse_args.cc
@@ -237,33 +237,8 @@ vw parse_args(int argc, char *argv[])
}
}
- if (vm.count("bfgs") || vm.count("conjugate_gradient")) {
- all.driver = BFGS::drive;
- all.learn = BFGS::learn;
- all.finish = BFGS::finish;
- all.save_load = BFGS::save_load;
- all.bfgs = true;
- all.stride = 4;
-
- if (vm.count("hessian_on") || all.m==0) {
- all.hessian_on = true;
- }
- if (!all.quiet) {
- if (all.m>0)
- cerr << "enabling BFGS based optimization ";
- else
- cerr << "enabling conjugate gradient optimization via BFGS ";
- if (all.hessian_on)
- cerr << "with curvature calculation" << endl;
- else
- cerr << "**without** curvature calculation" << endl;
- }
- if (all.numpasses < 2)
- {
- cout << "you must make at least 2 passes to use BFGS" << endl;
- exit(1);
- }
- }
+ if (vm.count("bfgs") || vm.count("conjugate_gradient"))
+ BFGS::parse_args(all, to_pass_further, vm, vm_file);
if (vm.count("version") || argc == 1) {
/* upon direct query for version -- spit it out to stdout */
@@ -467,18 +442,8 @@ vw parse_args(int argc, char *argv[])
//if (vm.count("nonormalize"))
// all.nonormalize = true;
- if (vm.count("lda")) {
- //default initial_t to 1 instead of 0
- if(!vm.count("initial_t")) {
- all.sd->t = 1.f;
- all.sd->weighted_unlabeled_examples = 1.f;
- all.initial_t = 1.f;
- }
-
+ if (vm.count("lda"))
LDA::parse_flags(all, to_pass_further, vm);
- all.driver = LDA::drive;
- all.save_load = LDA::save_load;
- }
if (!vm.count("lda") && !all.adaptive && !all.normalized_updates)
all.eta *= powf((float)(all.sd->t), all.power_t);
@@ -509,19 +474,11 @@ vw parse_args(int argc, char *argv[])
loss_parameter = vm["quantile_tau"].as<float>();
all.is_noop = false;
- if (vm.count("noop")) {
- all.driver = NOOP::drive;
- all.learn = NOOP::learn;
- all.save_load = NOOP::save_load;
- all.is_noop = true;
- }
+ if (vm.count("noop"))
+ NOOP::parse_flags(all);
- if (all.rank != 0) {
- all.driver = GDMF::drive;
- all.save_load = GDMF::save_load;
- loss_function = "classic";
- cerr << "Forcing classic squared loss for matrix factorization" << endl;
- }
+ if (all.rank != 0)
+ GDMF::parse_flags(all);
all.loss = getLossFunction(&all, loss_function, (float)loss_parameter);
@@ -585,14 +542,10 @@ vw parse_args(int argc, char *argv[])
all.audit = true;
if (vm.count("sendto"))
- {
- all.driver = SENDER::drive_send;
- all.save_load = SENDER::save_load;
- SENDER::parse_send_args(vm, all.pairs);
- }
+ SENDER::parse_send_args(all, vm, all.pairs);
// load rest of regressor
- all.save_load(&all, io_temp, true, false);
+ all.l.save_load(&all, all.l.data, io_temp, true, false);
io_temp.close_file();
if (all.l1_lambda < 0.) {
@@ -703,7 +656,7 @@ vw parse_args(int argc, char *argv[])
CSOAA::parse_flags(all, to_pass_further, vm, vm_file); // default to CSOAA unless others have been specified
got_cs = true;
}
- all.searnstr = (ImperativeSearn::searn_struct*)calloc(1, sizeof(ImperativeSearn::searn_struct));
+ all.searnstr = (ImperativeSearn::searn*)calloc(1, sizeof(ImperativeSearn::searn));
ImperativeSearn::parse_flags(all, to_pass_further, vm, vm_file);
}
@@ -821,7 +774,7 @@ namespace VW {
void finish(vw& all)
{
- all.finish(&all);
+ all.l.finish(&all, all.l.data);
if (all.searnstr != NULL) free(all.searnstr);
free_parser(all);
finalize_regressor(all, all.final_regressor_name);
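With reductions stacked this way, the single dispatch in VW::finish above unwinds the whole stack. A hedged trace for a hypothetical searn-over-oaa-over-gd configuration:

// all.l.finish(&all, all.l.data)           -> Searn::finish
//   s->base.finish(in, s->base.data)       -> OAA::finish
//     o->base.finish(all, o->base.data)    -> the bottom learner's finish
// each frame frees its own allocation only after its base has been torn down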
diff --git a/vowpalwabbit/parse_regressor.cc b/vowpalwabbit/parse_regressor.cc
index 77afbb84..b5c4c8c6 100644
--- a/vowpalwabbit/parse_regressor.cc
+++ b/vowpalwabbit/parse_regressor.cc
@@ -196,7 +196,7 @@ void dump_regressor(vw& all, string reg_name, bool as_text)
io_temp.open_file(start_name.c_str(), all.stdin_off, io_buf::WRITE);
save_load_header(all, io_temp, false, as_text);
- all.save_load(&all, io_temp, false, as_text);
+ all.l.save_load(&all, all.l.data, io_temp, false, as_text);
io_temp.flush(); // close_file() should do this for me ...
io_temp.close_file();
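The dump path here and the load path in parse_args.cc above now go through the same vtable slot. A hedged usage sketch (io_temp and as_text are the surrounding code's own variables):

all.l.save_load(&all, all.l.data, io_temp, true,  false);    // read=true: load a regressor
all.l.save_load(&all, all.l.data, io_temp, false, as_text);  // read=false: write it out, optionally as text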
diff --git a/vowpalwabbit/searn.cc b/vowpalwabbit/searn.cc
index b1ac2fdb..1bc650bb 100644
--- a/vowpalwabbit/searn.cc
+++ b/vowpalwabbit/searn.cc
@@ -279,32 +279,6 @@ namespace SearnUtil
namespace Searn
{
- // task stuff
- search_task task;
- bool is_singleline;
- bool is_ldf;
- bool has_hash;
- bool constrainted_actions;
- size_t input_label_size;
-
- // options
- size_t max_action = 1;
- size_t max_rollout = INT_MAX;
- size_t passes_per_policy = 1; //this should be set to the same value as --passes for dagger
- float beta = 0.5;
- float gamma = 1.;
- bool do_recombination = false;
- bool allow_current_policy = false; //this should be set to true for dagger
- bool rollout_oracle = false; //if true then rollout are performed using oracle instead (optimal approximation discussed in searn's paper). this should be set to true for dagger
- bool adaptive_beta = false; //used to implement dagger through searn. if true, beta = 1-(1-alpha)^n after n updates, and policy is mixed with oracle as \pi' = (1-beta)\pi^* + beta \pi
- float alpha = 0.001f; //parameter used to adapt beta for dagger (see above comment), should be in (0,1)
- bool rollout_all_actions = true; //by default we rollout all actions. This is set to false when searn is used with a contextual bandit base learner, where we rollout only one sampled action
-
- // debug stuff
- bool PRINT_DEBUG_INFO = 0;
- bool PRINT_UPDATE_EVERY_EXAMPLE = 0 | PRINT_DEBUG_INFO;
-
-
// rollout
struct rollout_item {
state st;
@@ -313,42 +287,69 @@ namespace Searn
size_t hash;
};
- // memory
- rollout_item* rollout;
- v_array<example*> ec_seq = v_array<example*>();
- example** global_example_set = NULL;
- example* empty_example = NULL;
- OAA::mc_label empty_label;
- v_array<CSOAA::wclass>loss_vector = v_array<CSOAA::wclass>();
- v_array<CB::cb_class>loss_vector_cb = v_array<CB::cb_class>();
- v_array<void*>old_labels = v_array<void*>();
- v_array<OAA::mc_label>new_labels = v_array<OAA::mc_label>();
- CSOAA::label testall_labels = { v_array<CSOAA::wclass>() };
- CSOAA::label allowed_labels = { v_array<CSOAA::wclass>() };
- CB::label testall_labels_cb = { v_array<CB::cb_class>() };
- CB::label allowed_labels_cb = { v_array<CB::cb_class>() };
-
- // we need a hashmap that maps from STATES to ACTIONS
- v_hashmap<state,action> *past_states = NULL;
- v_array<state> unfreed_states = v_array<state>();
-
- // tracking of example
- size_t read_example_this_loop = 0;
- size_t read_example_last_id = 0;
- size_t passes_since_new_policy = 0;
- size_t read_example_last_pass = 0;
- size_t total_examples_generated = 0;
- size_t total_predictions_made = 0;
- size_t searn_num_features = 0;
-
- // variables
- uint32_t current_policy = 0;
- uint32_t total_number_of_policies = 1;
- uint32_t increment = 0; //for policy offset
-
- void (*base_learner)(void*, example*) = NULL;
- void (*base_finish)(void*) = NULL;
+ // debug stuff
+ const bool PRINT_DEBUG_INFO = 0;
+ const bool PRINT_UPDATE_EVERY_EXAMPLE = 0;
+
+ struct searn {
+ // task stuff
+ search_task task;
+ bool is_singleline;
+ bool is_ldf;
+ bool has_hash;
+ bool constrainted_actions;
+ size_t input_label_size;
+
+ // options
+ size_t max_action;
+ size_t max_rollout;
+ size_t passes_per_policy; //this should be set to the same value as --passes for dagger
+ float beta;
+ float gamma;
+ bool do_recombination;
+ bool allow_current_policy; //this should be set to true for dagger
+ bool rollout_oracle; //if true then rollouts are performed using the oracle instead (the optimal approximation discussed in the searn paper); this should be set to true for dagger
+ bool adaptive_beta; //used to implement dagger through searn. if true, beta = 1-(1-alpha)^n after n updates, and the policy is mixed with the oracle as \pi' = (1-beta)\pi^* + beta \pi
+ float alpha; //parameter used to adapt beta for dagger (see the comment above); should be in (0,1)
+ bool rollout_all_actions; //by default we roll out all actions. This is set to false when searn is used with a contextual bandit base learner, where we roll out only one sampled action
+
+ // memory
+ rollout_item* rollout;
+ v_array<example*> ec_seq;
+ example** global_example_set;
+ example* empty_example;
+ OAA::mc_label empty_label;
+ v_array<CSOAA::wclass>loss_vector;
+ v_array<CB::cb_class>loss_vector_cb;
+ v_array<void*>old_labels;
+ v_array<OAA::mc_label>new_labels;
+ CSOAA::label testall_labels;
+ CSOAA::label allowed_labels;
+ CB::label testall_labels_cb;
+ CB::label allowed_labels_cb;
+ // we need a hashmap that maps from STATES to ACTIONS
+ v_hashmap<state,action> *past_states;
+ v_array<state> unfreed_states;
+
+ // tracking of example
+ size_t read_example_this_loop;
+ size_t read_example_last_id;
+ size_t passes_since_new_policy;
+ size_t read_example_last_pass;
+ size_t total_examples_generated;
+ size_t total_predictions_made;
+ size_t searn_num_features;
+
+ // variables
+ uint32_t current_policy;
+ uint32_t total_number_of_policies;
+ uint32_t increment; //for policy offset
+
+ learner base;
+ };
+ void drive(void*in, void*d);
+
void simple_print_example_features(vw&all, example *ec)
{
for (unsigned char* i = ec->indices.begin; i != ec->indices.end; i++)
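The adaptive_beta comment in the struct above specifies the dagger-style mixing schedule; the patch only relocates that comment, but for reference here is a standalone sketch of the stated formula (dagger_beta is an illustrative name, not part of the code):

#include <cmath>

// After n updates, beta_n = 1 - (1 - alpha)^n; rollouts then follow the
// mixture pi' = (1 - beta_n) * pi_oracle + beta_n * pi_learned.
inline float dagger_beta(float alpha, size_t n_updates)
{
  return 1.0f - std::pow(1.0f - alpha, (float)n_updates);
}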
@@ -400,35 +401,35 @@ namespace Searn
out[max_len] = 0;
}
- bool will_global_print_label(vw& all)
+ bool will_global_print_label(vw& all, searn& s)
{
- if (!task.to_string) return false;
+ if (!s.task.to_string) return false;
if (all.final_prediction_sink.size() == 0) return false;
return true;
}
- void global_print_label(vw& all, example*ec, state s0, std::vector<action> last_action_sequence)
+ void global_print_label(vw& all, searn& s, example*ec, state s0, std::vector<action> last_action_sequence)
{
- if (!task.to_string) return;
+ if (!s.task.to_string) return;
if (all.final_prediction_sink.size() == 0) return;
- string str = task.to_string(s0, false, last_action_sequence);
+ string str = s.task.to_string(s0, false, last_action_sequence);
for (size_t i=0; i<all.final_prediction_sink.size(); i++) {
int f = all.final_prediction_sink[i];
all.print_text(f, str, ec->tag);
}
}
- void print_update(vw& all, state s0, std::vector<action> last_action_sequence)
+ void print_update(vw& all, searn& s, state s0, std::vector<action> last_action_sequence)
{
if (!should_print_update(all))
return;
char true_label[21];
char pred_label[21];
- if (task.to_string) {
- to_short_string(task.to_string(s0, true , empty_action_vector ), 20, true_label);
- to_short_string(task.to_string(s0, false, last_action_sequence), 20, pred_label);
+ if (s.task.to_string) {
+ to_short_string(s.task.to_string(s0, true , empty_action_vector ), 20, true_label);
+ to_short_string(s.task.to_string(s0, false, last_action_sequence), 20, pred_label);
} else {
to_short_string("", 20, true_label);
to_short_string("", 20, pred_label);
@@ -441,11 +442,11 @@ namespace Searn
all.sd->weighted_examples,
true_label,
pred_label,
- (long unsigned int)searn_num_features,
- (int)read_example_last_pass,
- (int)current_policy,
- (long unsigned int)total_predictions_made,
- (long unsigned int)total_examples_generated);
+ (long unsigned int)s.searn_num_features,
+ (int)s.read_example_last_pass,
+ (int)s.current_policy,
+ (long unsigned int)s.total_predictions_made,
+ (long unsigned int)s.total_examples_generated);
all.sd->sum_loss_since_last_dump = 0.0;
all.sd->old_weighted_examples = all.sd->weighted_examples;
@@ -454,94 +455,118 @@ namespace Searn
- void clear_seq(vw&all)
+ void clear_seq(vw&all, searn& s)
{
- if (ec_seq.size() > 0)
- for (example** ecc=ec_seq.begin; ecc!=ec_seq.end; ecc++) {
+ if (s.ec_seq.size() > 0)
+ for (example** ecc=s.ec_seq.begin; ecc!=s.ec_seq.end; ecc++) {
VW::finish_example(all, *ecc);
}
- ec_seq.erase();
+ s.ec_seq.erase();
}
- void free_unfreed_states()
+ void free_unfreed_states(searn& s)
{
- while (!unfreed_states.empty()) {
- state s = unfreed_states.pop();
- task.finish(s);
+ while (!s.unfreed_states.empty()) {
+ state st = s.unfreed_states.pop();
+ s.task.finish(st);
}
}
- void initialize_memory()
+ void initialize_memory(searn& s)
{
// initialize searn's memory
- rollout = (rollout_item*)SearnUtil::calloc_or_die(max_action, sizeof(rollout_item));
- global_example_set = (example**)SearnUtil::calloc_or_die(max_action, sizeof(example*));
+ s.rollout = (rollout_item*)SearnUtil::calloc_or_die(s.max_action, sizeof(rollout_item));
+ s.global_example_set = (example**)SearnUtil::calloc_or_die(s.max_action, sizeof(example*));
- for (uint32_t k=1; k<=max_action; k++) {
+ for (uint32_t k=1; k<=s.max_action; k++) {
CSOAA::wclass cost = { FLT_MAX, k, 1., 0. };
- testall_labels.costs.push_back(cost);
+ s.testall_labels.costs.push_back(cost);
CB::cb_class cost_cb = { FLT_MAX, k, 0. };
- testall_labels_cb.costs.push_back(cost_cb);
+ s.testall_labels_cb.costs.push_back(cost_cb);
}
- empty_example = alloc_example(sizeof(OAA::mc_label));
- OAA::default_label(empty_example->ld);
+ s.empty_example = alloc_example(sizeof(OAA::mc_label));
+ OAA::default_label(s.empty_example->ld);
// cerr << "create: empty_example->ld = " << empty_example->ld << endl;
- empty_example->in_use = true;
+ s.empty_example->in_use = true;
}
- void free_memory(vw&all)
+ void free_memory(vw&all, searn& s)
{
- dealloc_example(NULL, *empty_example);
- free(empty_example);
+ dealloc_example(NULL, *s.empty_example);
+ free(s.empty_example);
- SearnUtil::free_it(rollout);
+ SearnUtil::free_it(s.rollout);
- loss_vector.delete_v();
+ s.loss_vector.delete_v();
- old_labels.delete_v();
+ s.old_labels.delete_v();
- new_labels.delete_v();
+ s.new_labels.delete_v();
- free_unfreed_states();
- unfreed_states.delete_v();
+ free_unfreed_states(s);
+ s.unfreed_states.delete_v();
- clear_seq(all);
- ec_seq.delete_v();
+ clear_seq(all,s);
+ s.ec_seq.delete_v();
- SearnUtil::free_it(global_example_set);
+ SearnUtil::free_it(s.global_example_set);
- testall_labels.costs.delete_v();
- testall_labels_cb.costs.delete_v();
- allowed_labels.costs.delete_v();
- allowed_labels_cb.costs.delete_v();
+ s.testall_labels.costs.delete_v();
+ s.testall_labels_cb.costs.delete_v();
+ s.allowed_labels.costs.delete_v();
+ s.allowed_labels_cb.costs.delete_v();
- if (do_recombination) {
- delete past_states;
- past_states = NULL;
+ if (s.do_recombination) {
+ delete s.past_states;
+ s.past_states = NULL;
}
}
- void learn(void*in, example *ec)
+ void learn(void*in, void*d, example *ec)
{
//vw*all = (vw*)in;
// TODO
}
- void finish(void*in)
+ void finish(void*in, void*d)
{
+ searn* s = (searn*)d;
+ s->base.finish(in,s->base.data);
vw*all = (vw*)in;
// free everything
- if (task.finalize != NULL)
- task.finalize();
- free_memory(*all);
- base_finish(all);
+ if (s->task.finalize != NULL)
+ s->task.finalize();
+ free_memory(*all,*s);
+ free(s);
}
void parse_flags(vw&all, std::vector<std::string>&opts, po::variables_map& vm, po::variables_map& vm_file)
{
+ searn* s = (searn*)calloc(1,sizeof(searn));
+
+ s->max_action = 1;
+ s->max_rollout = INT_MAX;
+ s->passes_per_policy = 1; //this should be set to the same value as --passes for dagger
+ s->beta = 0.5;
+ s->gamma = 1.;
+ s->alpha = 0.001f; //parameter used to adapt beta for dagger (see the comment above); should be in (0,1)
+ s->rollout_all_actions = true; //by default we roll out all actions. This is set to false when searn is used with a contextual bandit base learner, where we roll out only one sampled action
+ s->ec_seq = v_array<example*>();
+ s->loss_vector = v_array<CSOAA::wclass>();
+ s->loss_vector_cb = v_array<CB::cb_class>();
+ s->old_labels = v_array<void*>();
+ s->new_labels = v_array<OAA::mc_label>();
+ s->testall_labels.costs = v_array<CSOAA::wclass>();
+ s->allowed_labels.costs = v_array<CSOAA::wclass>();
+ s->testall_labels_cb.costs = v_array<CB::cb_class>();
+ s->allowed_labels_cb.costs = v_array<CB::cb_class>();
+ s->unfreed_states = v_array<state>();
+ s->total_number_of_policies = 1;
+
po::options_description desc("Searn options");
desc.add_options()
("searn_task", po::value<string>(), "the searn task")
@@ -596,154 +621,154 @@ namespace Searn
}
if (task_string.compare("sequence") == 0) {
- task.final = SequenceTask::final;
- task.loss = SequenceTask::loss;
- task.step = SequenceTask::step;
- task.oracle = SequenceTask::oracle;
- task.copy = SequenceTask::copy;
- task.finish = SequenceTask::finish;
- task.searn_label_parser = OAA::mc_label_parser;
- task.is_test_example = SequenceTask::is_test_example;
- input_label_size = sizeof(OAA::mc_label);
- task.start_state = NULL;
- task.start_state_multiline = SequenceTask::start_state_multiline;
+ s->task.final = SequenceTask::final;
+ s->task.loss = SequenceTask::loss;
+ s->task.step = SequenceTask::step;
+ s->task.oracle = SequenceTask::oracle;
+ s->task.copy = SequenceTask::copy;
+ s->task.finish = SequenceTask::finish;
+ s->task.searn_label_parser = OAA::mc_label_parser;
+ s->task.is_test_example = SequenceTask::is_test_example;
+ s->input_label_size = sizeof(OAA::mc_label);
+ s->task.start_state = NULL;
+ s->task.start_state_multiline = SequenceTask::start_state_multiline;
if (1) {
- task.cs_example = SequenceTask::cs_example;
- task.cs_ldf_example = NULL;
+ s->task.cs_example = SequenceTask::cs_example;
+ s->task.cs_ldf_example = NULL;
} else {
- task.cs_example = NULL;
- task.cs_ldf_example = SequenceTask::cs_ldf_example;
+ s->task.cs_example = NULL;
+ s->task.cs_ldf_example = SequenceTask::cs_ldf_example;
}
- task.initialize = SequenceTask::initialize;
- task.finalize = NULL;
- task.equivalent = SequenceTask::equivalent;
- task.hash = SequenceTask::hash;
- task.allowed = SequenceTask::allowed;
- task.to_string = SequenceTask::to_string;
+ s->task.initialize = SequenceTask::initialize;
+ s->task.finalize = NULL;
+ s->task.equivalent = SequenceTask::equivalent;
+ s->task.hash = SequenceTask::hash;
+ s->task.allowed = SequenceTask::allowed;
+ s->task.to_string = SequenceTask::to_string;
} else {
std::cerr << "error: unknown search task '" << task_string << "'" << std::endl;
exit(-1);
}
- *(all.p->lp)=task.searn_label_parser;
+ *(all.p->lp)=s->task.searn_label_parser;
if(vm_file.count("searn")) { //we loaded searn flag from regressor file
- max_action = vm_file["searn"].as<size_t>();
- if( vm.count("searn") && vm["searn"].as<size_t>() != max_action )
- std::cerr << "warning: you specified a different number of actions through --searn than the one loaded from predictor. Pursuing with loaded value of: " << max_action << endl;
+ s->max_action = vm_file["searn"].as<size_t>();
+ if( vm.count("searn") && vm["searn"].as<size_t>() != s->max_action )
+ std::cerr << "warning: you specified a different number of actions through --searn than the one loaded from predictor. Pursuing with loaded value of: " << s->max_action << endl;
}
else {
- max_action = vm["searn"].as<size_t>();
+ s->max_action = vm["searn"].as<size_t>();
//append searn with nb_actions to options_from_file so it is saved to regressor later
std::stringstream ss;
- ss << " --searn " << max_action;
+ ss << " --searn " << s->max_action;
all.options_from_file.append(ss.str());
}
if(vm_file.count("searn_beta")) { //we loaded searn_beta flag from regressor file
- beta = vm_file["searn_beta"].as<float>();
- if (vm.count("searn_beta") && vm["searn_beta"].as<float>() != beta )
- std::cerr << "warning: you specified a different value through --searn_beta than the one loaded from predictor. Pursuing with loaded value of: " << beta << endl;
+ s->beta = vm_file["searn_beta"].as<float>();
+ if (vm.count("searn_beta") && vm["searn_beta"].as<float>() != s->beta )
+ std::cerr << "warning: you specified a different value through --searn_beta than the one loaded from predictor. Pursuing with loaded value of: " << s->beta << endl;
}
else {
- if (vm.count("searn_beta")) beta = vm["searn_beta"].as<float>();
+ if (vm.count("searn_beta")) s->beta = vm["searn_beta"].as<float>();
//append searn_beta to options_from_file so it is saved in the regressor file later
std::stringstream ss;
- ss << " --searn_beta " << beta;
+ ss << " --searn_beta " << s->beta;
all.options_from_file.append(ss.str());
}
- if (vm.count("searn_rollout")) max_rollout = vm["searn_rollout"].as<size_t>();
- if (vm.count("searn_passes_per_policy")) passes_per_policy = vm["searn_passes_per_policy"].as<size_t>();
+ if (vm.count("searn_rollout")) s->max_rollout = vm["searn_rollout"].as<size_t>();
+ if (vm.count("searn_passes_per_policy")) s->passes_per_policy = vm["searn_passes_per_policy"].as<size_t>();
- if (vm.count("searn_gamma")) gamma = vm["searn_gamma"].as<float>();
- if (vm.count("searn_norecombine")) do_recombination = false;
- if (vm.count("searn_allow_current_policy")) allow_current_policy = true;
- if (vm.count("searn_rollout_oracle")) rollout_oracle = true;
+ if (vm.count("searn_gamma")) s->gamma = vm["searn_gamma"].as<float>();
+ if (vm.count("searn_norecombine")) s->do_recombination = false;
+ if (vm.count("searn_allow_current_policy")) s->allow_current_policy = true;
+ if (vm.count("searn_rollout_oracle")) s->rollout_oracle = true;
//check if the base learner is contextual bandit, in which case we don't roll out all actions.
- if ( vm.count("cb") || vm_file.count("cb") ) rollout_all_actions = false;
+ if ( vm.count("cb") || vm_file.count("cb") ) s->rollout_all_actions = false;
//if we loaded a regressor with -i option, --searn_trained_nb_policies contains the number of trained policies in the file
// and --searn_total_nb_policies contains the total number of policies in the file
if ( vm_file.count("searn_total_nb_policies") )
{
- current_policy = (uint32_t)vm_file["searn_trained_nb_policies"].as<size_t>();
- total_number_of_policies = (uint32_t)vm_file["searn_total_nb_policies"].as<size_t>();
- if (vm.count("searn_total_nb_policies") && (uint32_t)vm["searn_total_nb_policies"].as<size_t>() != total_number_of_policies)
- std::cerr << "warning: --searn_total_nb_policies doesn't match the total number of policies stored in initial predictor. Using loaded value of: " << total_number_of_policies << endl;
+ s->current_policy = (uint32_t)vm_file["searn_trained_nb_policies"].as<size_t>();
+ s->total_number_of_policies = (uint32_t)vm_file["searn_total_nb_policies"].as<size_t>();
+ if (vm.count("searn_total_nb_policies") && (uint32_t)vm["searn_total_nb_policies"].as<size_t>() != s->total_number_of_policies)
+ std::cerr << "warning: --searn_total_nb_policies doesn't match the total number of policies stored in initial predictor. Using loaded value of: " << s->total_number_of_policies << endl;
}
else if (vm.count("searn_total_nb_policies"))
{
- total_number_of_policies = (uint32_t)vm["searn_total_nb_policies"].as<size_t>();
+ s->total_number_of_policies = (uint32_t)vm["searn_total_nb_policies"].as<size_t>();
}
if (vm.count("searn_as_dagger"))
{
//override previously loaded options to set searn as dagger
- allow_current_policy = true;
- passes_per_policy = all.numpasses;
- //rollout_oracle = true;
- if( current_policy > 1 )
- current_policy = 1;
+ s->allow_current_policy = true;
+ s->passes_per_policy = all.numpasses;
+ //s->rollout_oracle = true;
+ if( s->current_policy > 1 )
+ s->current_policy = 1;
//indicate to adapt beta for each update
- adaptive_beta = true;
- alpha = vm["searn_as_dagger"].as<float>();
+ s->adaptive_beta = true;
+ s->alpha = vm["searn_as_dagger"].as<float>();
}
- if (beta <= 0 || beta >= 1) {
+ if (s->beta <= 0 || s->beta >= 1) {
std::cerr << "warning: searn_beta must be in (0,1); resetting to 0.5" << std::endl;
- beta = 0.5;
+ s->beta = 0.5;
}
- if (gamma <= 0 || gamma > 1) {
+ if (s->gamma <= 0 || s->gamma > 1) {
std::cerr << "warning: searn_gamma must be in (0,1); resetting to 1.0" << std::endl;
- gamma = 1.0;
+ s->gamma = 1.0;
}
- if (alpha < 0 || alpha > 1) {
+ if (s->alpha < 0 || s->alpha > 1) {
std::cerr << "warning: searn_adaptive_beta must be in (0,1); resetting to 0.001" << std::endl;
- alpha = 0.001f;
+ s->alpha = 0.001f;
}
- if (task.initialize != NULL)
- if (!task.initialize(all, opts, vm, vm_file)) {
+ if (s->task.initialize != NULL)
+ if (!s->task.initialize(all, opts, vm, vm_file)) {
std::cerr << "error: task did not initialize properly" << std::endl;
exit(-1);
}
// check to make sure task is valid and set up our variables
- if (task.final == NULL ||
- task.loss == NULL ||
- task.step == NULL ||
- task.oracle == NULL ||
- task.copy == NULL ||
- task.finish == NULL ||
- ((task.start_state == NULL) == (task.start_state_multiline == NULL)) ||
- ((task.cs_example == NULL) == (task.cs_ldf_example == NULL))) {
+ if (s->task.final == NULL ||
+ s->task.loss == NULL ||
+ s->task.step == NULL ||
+ s->task.oracle == NULL ||
+ s->task.copy == NULL ||
+ s->task.finish == NULL ||
+ ((s->task.start_state == NULL) == (s->task.start_state_multiline == NULL)) ||
+ ((s->task.cs_example == NULL) == (s->task.cs_ldf_example == NULL))) {
std::cerr << "error: searn task malformed" << std::endl;
exit(-1);
}
- is_singleline = (task.start_state != NULL);
- is_ldf = (task.cs_example == NULL);
- has_hash = (task.hash != NULL);
- constrainted_actions = (task.allowed != NULL);
+ s->is_singleline = (s->task.start_state != NULL);
+ s->is_ldf = (s->task.cs_example == NULL);
+ s->has_hash = (s->task.hash != NULL);
+ s->constrainted_actions = (s->task.allowed != NULL);
- if (do_recombination && (task.hash == NULL)) {
+ if (s->do_recombination && (s->task.hash == NULL)) {
std::cerr << "warning: cannot do recombination when hashing is unavailable -- turning off recombination" << std::endl;
- do_recombination = false;
+ s->do_recombination = false;
}
- if (do_recombination) {
+ if (s->do_recombination) {
// 0 is an invalid action
- past_states = new v_hashmap<state,action>(1023, 0, task.equivalent);
+ s->past_states = new v_hashmap<state,action>(1023, 0, s->task.equivalent);
}
- if (is_ldf && !constrainted_actions) {
+ if (s->is_ldf && !s->constrainted_actions) {
std::cerr << "error: LDF requires allowed" << std::endl;
exit(-1);
}
@@ -752,16 +777,16 @@ namespace Searn
//compute total number of policies we will have at end of training
// we add current_policy for cases where we start from an initial set of policies loaded through -i option
- uint32_t tmp_number_of_policies = current_policy;
+ uint32_t tmp_number_of_policies = s->current_policy;
if( all.training )
- tmp_number_of_policies += (int)ceil(((float)all.numpasses) / ((float)passes_per_policy));
+ tmp_number_of_policies += (int)ceil(((float)all.numpasses) / ((float)s->passes_per_policy));
//the user might have specified the number of policies that will eventually be trained through multiple vw calls,
//so only set total_number_of_policies to computed value if it is larger
- if( tmp_number_of_policies > total_number_of_policies )
+ if( tmp_number_of_policies > s->total_number_of_policies )
{
- total_number_of_policies = tmp_number_of_policies;
- if( current_policy > 0 ) //we loaded a file but total number of policies didn't match what is needed for training
+ s->total_number_of_policies = tmp_number_of_policies;
+ if( s->current_policy > 0 ) //we loaded a file but total number of policies didn't match what is needed for training
{
std::cerr << "warning: you're attempting to train more classifiers than was allocated initially. Likely to cause bad performance." << endl;
}
@@ -770,108 +795,105 @@ namespace Searn
//current policy currently points to a new policy we would train
//if we are not training and loaded a bunch of policies for testing, we need to subtract 1 from current policy
//so that we only use those loaded when testing (as run_prediction is called with allow_current to true)
- if( !all.training && current_policy > 0 )
- current_policy--;
+ if( !all.training && s->current_policy > 0 )
+ s->current_policy--;
- //std::cerr << "Current Policy: " << current_policy << endl;
+ //std::cerr << "Current Policy: " << s->current_policy << endl;
//std::cerr << "Total Number of Policies: " << total_number_of_policies << endl;
std::stringstream ss1;
std::stringstream ss2;
- ss1 << current_policy;
+ ss1 << s->current_policy;
//use cmd_string_replace_value in case we already loaded a predictor which had a value stored for --searn_trained_nb_policies
VW::cmd_string_replace_value(all.options_from_file,"--searn_trained_nb_policies", ss1.str());
- ss2 << total_number_of_policies;
+ ss2 << s->total_number_of_policies;
//use cmd_string_replace_value in case we already loaded a predictor which had a value stored for --searn_total_nb_policies
VW::cmd_string_replace_value(all.options_from_file,"--searn_total_nb_policies", ss2.str());
- all.base_learner_nb_w *= total_number_of_policies;
- increment = ((uint32_t)all.length() / all.base_learner_nb_w) * all.stride;
- //cerr << "searn increment = " << increment << endl;
-
- all.driver = drive;
- base_learner = all.learn;
- all.learn = learn;
- base_finish = all.finish;
- all.finish = finish;
+ all.base_learner_nb_w *= s->total_number_of_policies;
+ s->increment = ((uint32_t)all.length() / all.base_learner_nb_w) * all.stride;
+ //cerr << "searn increment = " << s->increment << endl;
+
+ s->base = all.l;
+ learner l = {s, drive, learn, finish, all.l.save_load};
+ all.l = l;
}
- uint32_t searn_predict(vw&all, state s0, size_t step, bool allow_oracle, bool allow_current, v_array< pair<uint32_t,float> >* partial_predictions) // TODO: partial_predictions
+ uint32_t searn_predict(vw&all, searn& s, state s0, size_t step, bool allow_oracle, bool allow_current, v_array< pair<uint32_t,float> >* partial_predictions) // TODO: partial_predictions
{
- int policy = SearnUtil::random_policy(read_example_last_id * 2147483 + step * 2147483647 /* has_hash ? task.hash(s0) : step */, beta, allow_current, (int)current_policy, allow_oracle, rollout_all_actions);
- if (PRINT_DEBUG_INFO) { cerr << "predicing with policy " << policy << " (allow_oracle=" << allow_oracle << ", allow_current=" << allow_current << "), current_policy=" << current_policy << endl; }
+ int policy = SearnUtil::random_policy(s.read_example_last_id * 2147483 + step * 2147483647 /* s.has_hash ? s.task.hash(s0) : step */, s.beta, allow_current, (int)s.current_policy, allow_oracle, s.rollout_all_actions);
+ if (PRINT_DEBUG_INFO) { cerr << "predicing with policy " << policy << " (allow_oracle=" << allow_oracle << ", allow_current=" << allow_current << "), current_policy=" << s.current_policy << endl; }
if (policy == -1) {
- return task.oracle(s0);
+ return s.task.oracle(s0);
}
example *ec;
- if (!is_ldf) {
- task.cs_example(all, s0, ec, true);
- SearnUtil::add_policy_offset(all, ec, increment, policy);
+ if (!s.is_ldf) {
+ s.task.cs_example(all, s0, ec, true);
+ SearnUtil::add_policy_offset(all, ec, s.increment, policy);
void* old_label = ec->ld;
- if(rollout_all_actions) { //this means we have a cost-sensitive base learner
- ec->ld = (void*)&testall_labels;
- if (task.allowed != NULL) { // we need to check which actions are allowed
- allowed_labels.costs.erase();
+ if(s.rollout_all_actions) { //this means we have a cost-sensitive base learner
+ ec->ld = (void*)&s.testall_labels;
+ if (s.task.allowed != NULL) { // we need to check which actions are allowed
+ s.allowed_labels.costs.erase();
bool all_allowed = true;
- for (uint32_t k=1; k<=max_action; k++)
- if (task.allowed(s0, k)) {
+ for (uint32_t k=1; k<=s.max_action; k++)
+ if (s.task.allowed(s0, k)) {
CSOAA::wclass cost = { FLT_MAX, k, 1., 0. };
- allowed_labels.costs.push_back(cost);
+ s.allowed_labels.costs.push_back(cost);
} else
all_allowed = false;
if (!all_allowed)
- ec->ld = (void*)&allowed_labels;
+ ec->ld = (void*)&s.allowed_labels;
}
}
else { //if we have a contextual bandit base learner
- ec->ld = (void*)&testall_labels_cb;
- if (task.allowed != NULL) { // we need to check which actions are allowed
- allowed_labels_cb.costs.erase();
+ ec->ld = (void*)&s.testall_labels_cb;
+ if (s.task.allowed != NULL) { // we need to check which actions are allowed
+ s.allowed_labels_cb.costs.erase();
bool all_allowed = true;
- for (uint32_t k=1; k<=max_action; k++)
- if (task.allowed(s0, k)) {
+ for (uint32_t k=1; k<=s.max_action; k++)
+ if (s.task.allowed(s0, k)) {
CB::cb_class cost = { FLT_MAX, k, 0. };
- allowed_labels_cb.costs.push_back(cost);
+ s.allowed_labels_cb.costs.push_back(cost);
} else
all_allowed = false;
if (!all_allowed)
- ec->ld = (void*)&allowed_labels_cb;
+ ec->ld = (void*)&s.allowed_labels_cb;
}
}
//cerr << "searn>";
//simple_print_example_features(all,ec);
- base_learner(&all,ec);
- total_predictions_made++;
- searn_num_features += ec->num_features;
+ s.base.learn(&all,s.base.data,ec);
+ s.total_predictions_made++;
+ s.searn_num_features += ec->num_features;
uint32_t final_prediction = (uint32_t)(*(OAA::prediction_t*)&(ec->final_prediction));
ec->ld = old_label;
- SearnUtil::remove_policy_offset(all, ec, increment, policy);
- task.cs_example(all, s0, ec, false);
+ SearnUtil::remove_policy_offset(all, ec, s.increment, policy);
+ s.task.cs_example(all, s0, ec, false);
return final_prediction;
} else { // is_ldf
//TODO: modify this to handle contextual bandit base learner with ldf
float best_prediction = 0;
uint32_t best_action = 0;
- for (uint32_t action=1; action <= max_action; action++) {
- if (!task.allowed(s0, action))
+ for (uint32_t action=1; action <= s.max_action; action++) {
+ if (!s.task.allowed(s0, action))
break; // for LDF, there are no more actions
- task.cs_ldf_example(all, s0, action, ec, true);
+ s.task.cs_ldf_example(all, s0, action, ec, true);
//cerr << "created example: " << ec << ", label: " << ec->ld << endl;
- SearnUtil::add_policy_offset(all, ec, increment, policy);
- base_learner(&all,ec); total_predictions_made++; searn_num_features += ec->num_features;
+ SearnUtil::add_policy_offset(all, ec, s.increment, policy);
+ s.base.learn(&all,s.base.data,ec); s.total_predictions_made++; s.searn_num_features += ec->num_features;
//cerr << "base_learned on example: " << ec << endl;
- empty_example->in_use = true;
- base_learner(&all,empty_example);
- //cerr << "base_learned on empty example: " << empty_example << endl;
- SearnUtil::remove_policy_offset(all, ec, increment, policy);
+ s.empty_example->in_use = true;
+ s.base.learn(&all,s.base.data,s.empty_example);
+ //cerr << "base_learned on empty example: " << s.empty_example << endl;
+ SearnUtil::remove_policy_offset(all, ec, s.increment, policy);
if (action == 1 ||
ec->partial_prediction < best_prediction) {
@@ -879,7 +901,7 @@ namespace Searn
best_action = action;
}
//cerr << "releasing example: " << ec << ", label: " << ec->ld << endl;
- task.cs_ldf_example(all, s0, action, ec, false);
+ s.task.cs_ldf_example(all, s0, action, ec, false);
}
if (best_action < 1) {
@@ -890,111 +912,111 @@ namespace Searn
}
}
- float single_rollout(vw&all, state s0, uint32_t action)
+ float single_rollout(vw&all, searn& s, state s0, uint32_t action)
{
//first check if action is valid for current state
- if( action < 1 || action > max_action || (task.allowed && !task.allowed(s0,action)) )
+ if( action < 1 || action > s.max_action || (s.task.allowed && !s.task.allowed(s0,action)) )
{
std::cerr << "warning: asked to rollout an unallowed action: " << action << "; not performing rollout." << std::endl;
return 0;
}
//copy state and step it with current action
- rollout[action-1].alive = true;
- rollout[action-1].st = task.copy(s0);
- task.step(rollout[action-1].st, action);
- rollout[action-1].is_finished = task.final(rollout[action-1].st);
- if (do_recombination) rollout[action-1].hash = task.hash(rollout[action-1].st);
+ s.rollout[action-1].alive = true;
+ s.rollout[action-1].st = s.task.copy(s0);
+ s.task.step(s.rollout[action-1].st, action);
+ s.rollout[action-1].is_finished = s.task.final(s.rollout[action-1].st);
+ if (s.do_recombination) s.rollout[action-1].hash = s.task.hash(s.rollout[action-1].st);
//if not finished complete rollout
- if (!rollout[action-1].is_finished) {
- for (size_t step=1; step<max_rollout; step++) {
+ if (!s.rollout[action-1].is_finished) {
+ for (size_t step=1; step<s.max_rollout; step++) {
uint32_t act_tmp = 0;
- if (do_recombination)
- act_tmp = past_states->get(rollout[action-1].st, rollout[action-1].hash);
+ if (s.do_recombination)
+ act_tmp = s.past_states->get(s.rollout[action-1].st, s.rollout[action-1].hash);
if (act_tmp == 0) { // this means we didn't find it or we're not recombining
- if( !rollout_oracle )
- act_tmp = searn_predict(all, rollout[action-1].st, step, true, allow_current_policy, NULL);
+ if( !s.rollout_oracle )
+ act_tmp = searn_predict(all, s, s.rollout[action-1].st, step, true, s.allow_current_policy, NULL);
else
- act_tmp = task.oracle(rollout[action-1].st);
+ act_tmp = s.task.oracle(s.rollout[action-1].st);
- if (do_recombination) {
+ if (s.do_recombination) {
// we need to make a copy of the state
- state copy = task.copy(rollout[action-1].st);
- past_states->put_after_get(copy, rollout[action-1].hash, act_tmp);
- unfreed_states.push_back(copy);
+ state copy = s.task.copy(s.rollout[action-1].st);
+ s.past_states->put_after_get(copy, s.rollout[action-1].hash, act_tmp);
+ s.unfreed_states.push_back(copy);
}
}
- task.step(rollout[action-1].st, act_tmp);
- rollout[action-1].is_finished = task.final(rollout[action-1].st);
- if (do_recombination) rollout[action-1].hash = task.hash(rollout[action-1].st);
- if (rollout[action-1].is_finished) break;
+ s.task.step(s.rollout[action-1].st, act_tmp);
+ s.rollout[action-1].is_finished = s.task.final(s.rollout[action-1].st);
+ if (s.do_recombination) s.rollout[action-1].hash = s.task.hash(s.rollout[action-1].st);
+ if (s.rollout[action-1].is_finished) break;
}
}
// finally, compute losses and free copies
- float l = task.loss(rollout[action-1].st);
- if ((l == FLT_MAX) && (!rollout[action-1].is_finished) && (max_rollout < INT_MAX)) {
+ float l = s.task.loss(s.rollout[action-1].st);
+ if ((l == FLT_MAX) && (!s.rollout[action-1].is_finished) && (s.max_rollout < INT_MAX)) {
std::cerr << "error: you asked for short rollouts, but your task does not support pre-final losses" << std::endl;
exit(-1);
}
- task.finish(rollout[action-1].st);
+ s.task.finish(s.rollout[action-1].st);
return l;
}
- void parallel_rollout(vw&all, state s0)
+ void parallel_rollout(vw&all, searn& s, state s0)
{
// first, make K copies of s0 and step them
bool all_finished = true;
- for (size_t k=1; k<=max_action; k++)
- rollout[k-1].alive = false;
+ for (size_t k=1; k<=s.max_action; k++)
+ s.rollout[k-1].alive = false;
- for (uint32_t k=1; k<=max_action; k++) {
+ for (uint32_t k=1; k<=s.max_action; k++) {
// in the case of LDF, we might run out of actions early
- if (task.allowed && !task.allowed(s0, k)) {
- if (is_ldf) break;
+ if (s.task.allowed && !s.task.allowed(s0, k)) {
+ if (s.is_ldf) break;
else continue;
}
- rollout[k-1].alive = true;
- rollout[k-1].st = task.copy(s0);
- task.step(rollout[k-1].st, k);
- rollout[k-1].is_finished = task.final(rollout[k-1].st);
- if (do_recombination) rollout[k-1].hash = task.hash(rollout[k-1].st);
- all_finished = all_finished && rollout[k-1].is_finished;
+ s.rollout[k-1].alive = true;
+ s.rollout[k-1].st = s.task.copy(s0);
+ s.task.step(s.rollout[k-1].st, k);
+ s.rollout[k-1].is_finished = s.task.final(s.rollout[k-1].st);
+ if (s.do_recombination) s.rollout[k-1].hash = s.task.hash(s.rollout[k-1].st);
+ all_finished = all_finished && s.rollout[k-1].is_finished;
}
// now, complete all rollouts
if (!all_finished) {
- for (size_t step=1; step<max_rollout; step++) {
+ for (size_t step=1; step<s.max_rollout; step++) {
all_finished = true;
- for (size_t k=1; k<=max_action; k++) {
- if (rollout[k-1].is_finished) continue;
+ for (size_t k=1; k<=s.max_action; k++) {
+ if (s.rollout[k-1].is_finished) continue;
uint32_t action = 0;
- if (do_recombination)
- action = past_states->get(rollout[k-1].st, rollout[k-1].hash);
+ if (s.do_recombination)
+ action = s.past_states->get(s.rollout[k-1].st, s.rollout[k-1].hash);
if (action == 0) { // this means we didn't find it or we're not recombining
- if( !rollout_oracle )
- action = searn_predict(all, rollout[k-1].st, step, true, allow_current_policy, NULL);
+ if( !s.rollout_oracle )
+ action = searn_predict(all, s, s.rollout[k-1].st, step, true, s.allow_current_policy, NULL);
else
- action = task.oracle(rollout[k-1].st);
+ action = s.task.oracle(s.rollout[k-1].st);
- if (do_recombination) {
+ if (s.do_recombination) {
// we need to make a copy of the state
- state copy = task.copy(rollout[k-1].st);
- past_states->put_after_get(copy, rollout[k-1].hash, action);
- unfreed_states.push_back(copy);
+ state copy = s.task.copy(s.rollout[k-1].st);
+ s.past_states->put_after_get(copy, s.rollout[k-1].hash, action);
+ s.unfreed_states.push_back(copy);
}
}
- task.step(rollout[k-1].st, action);
- rollout[k-1].is_finished = task.final(rollout[k-1].st);
- if (do_recombination) rollout[k-1].hash = task.hash(rollout[k-1].st);
- all_finished = all_finished && rollout[k-1].is_finished;
+ s.task.step(s.rollout[k-1].st, action);
+ s.rollout[k-1].is_finished = s.task.final(s.rollout[k-1].st);
+ if (s.do_recombination) s.rollout[k-1].hash = s.task.hash(s.rollout[k-1].st);
+ all_finished = all_finished && s.rollout[k-1].is_finished;
}
if (all_finished) break;
}
@@ -1002,39 +1024,39 @@ namespace Searn
// finally, compute losses and free copies
float min_loss = 0;
- loss_vector.erase();
- for (uint32_t k=1; k<=max_action; k++) {
- if (!rollout[k-1].alive)
+ s.loss_vector.erase();
+ for (uint32_t k=1; k<=s.max_action; k++) {
+ if (!s.rollout[k-1].alive)
break;
- float l = task.loss(rollout[k-1].st);
- if ((l == FLT_MAX) && (!rollout[k-1].is_finished) && (max_rollout < INT_MAX)) {
+ float l = s.task.loss(s.rollout[k-1].st);
+ if ((l == FLT_MAX) && (!s.rollout[k-1].is_finished) && (s.max_rollout < INT_MAX)) {
std::cerr << "error: you asked for short rollouts, but your task does not support pre-final losses" << std::endl;
exit(-1);
}
CSOAA::wclass temp = { l, k, 1., 0. };
- loss_vector.push_back(temp);
+ s.loss_vector.push_back(temp);
if ((k == 1) || (l < min_loss)) { min_loss = l; }
- task.finish(rollout[k-1].st);
+ s.task.finish(s.rollout[k-1].st);
}
// subtract the smallest loss
- for (size_t k=1; k<=max_action; k++)
- if (rollout[k-1].alive)
- loss_vector[k-1].x -= min_loss;
+ for (size_t k=1; k<=s.max_action; k++)
+ if (s.rollout[k-1].alive)
+ s.loss_vector[k-1].x -= min_loss;
}
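The erase/push_back sequence above turns per-action rollout losses into a cost-sensitive label, and the final loop subtracts the per-state minimum so the best action has cost zero and every other entry is a regret against it. A minimal sketch of that normalization, assuming a plain std::vector of losses in place of the wclass array used here:

#include <algorithm>
#include <vector>

// Shift rollout losses so the cheapest action costs 0; the rest become
// regrets relative to the best action found for this state.
void subtract_min_loss(std::vector<float>& losses) {
  if (losses.empty()) return;
  float min_loss = *std::min_element(losses.begin(), losses.end());
  for (size_t k = 0; k < losses.size(); k++)
    losses[k] -= min_loss;
}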
- uint32_t uniform_exploration(state s0, float& prob_sampled_action)
+ uint32_t uniform_exploration(searn& s, state s0, float& prob_sampled_action)
{
//find how many valid actions
- size_t nb_allowed_actions = max_action;
- if( task.allowed ) {
- for (uint32_t k=1; k<=max_action; k++) {
- if( !task.allowed(s0,k) ) {
+ size_t nb_allowed_actions = s.max_action;
+ if( s.task.allowed ) {
+ for (uint32_t k=1; k<=s.max_action; k++) {
+ if( !s.task.allowed(s0,k) ) {
nb_allowed_actions--;
- if (is_ldf) {
+ if (s.is_ldf) {
nb_allowed_actions = k-1;
break;
}
@@ -1043,25 +1065,25 @@ namespace Searn
}
uint32_t action = (size_t)(frand48() * nb_allowed_actions) + 1;
- if( task.allowed && nb_allowed_actions < max_action && !is_ldf) {
+ if( s.task.allowed && nb_allowed_actions < s.max_action && !s.is_ldf) {
//need to adjust action to the corresponding valid action
for (uint32_t k=1; k<=action; k++) {
- if( !task.allowed(s0,k) ) action++;
+ if( !s.task.allowed(s0,k) ) action++;
}
}
prob_sampled_action = (float) (1.0/nb_allowed_actions);
return action;
}
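uniform_exploration draws an index uniformly over the allowed subset of actions, reports probability 1/nb_allowed_actions, and then walks the action range to map the drawn index back onto a real action id. The same computation as a standalone sketch, with allowed() standing in for the task-provided predicate and drand48 for vw's frand48:

#include <cstdint>
#include <cstdlib>

// Sample uniformly from the allowed members of {1..max_action} and
// report the sampling probability used for importance weighting.
uint32_t sample_allowed(uint32_t max_action, bool (*allowed)(uint32_t),
                        float& prob) {
  uint32_t n = 0;
  for (uint32_t k = 1; k <= max_action; k++)
    if (allowed(k)) n++;
  if (n == 0) { prob = 0.f; return 0; } // no legal action
  uint32_t idx = (uint32_t)(drand48() * n) + 1; // uniform in 1..n
  uint32_t seen = 0;
  for (uint32_t k = 1; k <= max_action; k++)
    if (allowed(k) && ++seen == idx) { prob = 1.f / n; return k; }
  return 0; // unreachable
}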
- void get_contextual_bandit_loss_vector(vw&all, state s0)
+ void get_contextual_bandit_loss_vector(vw&all, searn& s, state s0)
{
float prob_sampled = 1.;
- uint32_t act = uniform_exploration(s0,prob_sampled);
- float loss = single_rollout(all,s0,act);
+ uint32_t act = uniform_exploration(s, s0,prob_sampled);
+ float loss = single_rollout(all,s, s0,act);
- loss_vector_cb.erase();
- for (uint32_t k=1; k<=max_action; k++) {
- if( task.allowed && !task.allowed(s0,k))
+ s.loss_vector_cb.erase();
+ for (uint32_t k=1; k<=s.max_action; k++) {
+ if( s.task.allowed && !s.task.allowed(s0,k))
break;
CB::cb_class temp;
@@ -1072,150 +1094,150 @@ namespace Searn
temp.x = loss;
temp.prob_action = prob_sampled;
}
- loss_vector_cb.push_back(temp);
+ s.loss_vector_cb.push_back(temp);
}
}
- void generate_state_example(vw&all, state s0)
+ void generate_state_example(vw&all, searn& s, state s0)
{
// start by doing rollouts so we can get costs
- loss_vector.erase();
- loss_vector_cb.erase();
- if( rollout_all_actions ) {
- parallel_rollout(all, s0);
+ s.loss_vector.erase();
+ s.loss_vector_cb.erase();
+ if( s.rollout_all_actions ) {
+ parallel_rollout(all, s, s0);
}
else {
- get_contextual_bandit_loss_vector(all, s0);
+ get_contextual_bandit_loss_vector(all, s, s0);
}
- if (loss_vector.size() <= 1 && loss_vector_cb.size() == 0) {
+ if (s.loss_vector.size() <= 1 && s.loss_vector_cb.size() == 0) {
// nothing interesting to do!
return;
}
// now, generate training examples
- if (!is_ldf) {
- total_examples_generated++;
+ if (!s.is_ldf) {
+ s.total_examples_generated++;
example* ec;
- task.cs_example(all, s0, ec, true);
+ s.task.cs_example(all, s0, ec, true);
void* old_label = ec->ld;
- if(rollout_all_actions) {
- CSOAA::label ld = { loss_vector };
+ if(s.rollout_all_actions) {
+ CSOAA::label ld = { s.loss_vector };
ec->ld = (void*)&ld;
}
else {
- CB::label ld = { loss_vector_cb };
+ CB::label ld = { s.loss_vector_cb };
ec->ld = (void*)&ld;
}
- SearnUtil::add_policy_offset(all, ec, increment, current_policy);
- base_learner(&all,ec);
- SearnUtil::remove_policy_offset(all, ec, increment, current_policy);
+ SearnUtil::add_policy_offset(all, ec, s.increment, s.current_policy);
+ s.base.learn(&all,s.base.data,ec);
+ SearnUtil::remove_policy_offset(all, ec, s.increment, s.current_policy);
ec->ld = old_label;
- task.cs_example(all, s0, ec, false);
+ s.task.cs_example(all, s0, ec, false);
} else { // is_ldf
//TODO: support ldf with contextual bandit base learner
- old_labels.erase();
- new_labels.erase();
+ s.old_labels.erase();
+ s.new_labels.erase();
- for (uint32_t k=1; k<=max_action; k++) {
- if (rollout[k-1].alive) {
- OAA::mc_label ld = { k, loss_vector[k-1].x };
- new_labels.push_back(ld);
+ for (uint32_t k=1; k<=s.max_action; k++) {
+ if (s.rollout[k-1].alive) {
+ OAA::mc_label ld = { k, s.loss_vector[k-1].x };
+ s.new_labels.push_back(ld);
} else {
OAA::mc_label ld = { k, 0. };
- new_labels.push_back(ld);
+ s.new_labels.push_back(ld);
}
}
// cerr << "vvvvvvvvvvvvvvvvvvvvvvvvvvvv" << endl;
- for (uint32_t k=1; k<=max_action; k++) {
- if (!rollout[k-1].alive) break;
+ for (uint32_t k=1; k<=s.max_action; k++) {
+ if (!s.rollout[k-1].alive) break;
- total_examples_generated++;
+ s.total_examples_generated++;
- task.cs_ldf_example(all, s0, k, global_example_set[k-1], true);
- old_labels.push_back(global_example_set[k-1]->ld);
- global_example_set[k-1]->ld = (void*)(&new_labels[k-1]);
- SearnUtil::add_policy_offset(all, global_example_set[k-1], increment, current_policy);
- if (PRINT_DEBUG_INFO) { cerr << "add_policy_offset, max_action=" << max_action << ", total_number_of_policies=" << total_number_of_policies << ", current_policy=" << current_policy << endl;}
- base_learner(&all,global_example_set[k-1]);
+ s.task.cs_ldf_example(all, s0, k, s.global_example_set[k-1], true);
+ s.old_labels.push_back(s.global_example_set[k-1]->ld);
+ s.global_example_set[k-1]->ld = (void*)(&s.new_labels[k-1]);
+ SearnUtil::add_policy_offset(all, s.global_example_set[k-1], s.increment, s.current_policy);
+ if (PRINT_DEBUG_INFO) { cerr << "add_policy_offset, s.max_action=" << s.max_action << ", total_number_of_policies=" << s.total_number_of_policies << ", current_policy=" << s.current_policy << endl;}
+ s.base.learn(&all,s.base.data,s.global_example_set[k-1]);
}
- // cerr << "============================ (empty = " << empty_example << ")" << endl;
- empty_example->in_use = true;
- base_learner(&all,empty_example);
+ // cerr << "============================ (empty = " << s.empty_example << ")" << endl;
+ s.empty_example->in_use = true;
+ s.base.learn(&all,s.base.data,s.empty_example);
- for (uint32_t k=1; k<=max_action; k++) {
- if (!rollout[k-1].alive) break;
- SearnUtil::remove_policy_offset(all, global_example_set[k-1], increment, current_policy);
- global_example_set[k-1]->ld = old_labels[k-1];
- task.cs_ldf_example(all, s0, k, global_example_set[k-1], false);
+ for (uint32_t k=1; k<=s.max_action; k++) {
+ if (!s.rollout[k-1].alive) break;
+ SearnUtil::remove_policy_offset(all, s.global_example_set[k-1], s.increment, s.current_policy);
+ s.global_example_set[k-1]->ld = s.old_labels[k-1];
+ s.task.cs_ldf_example(all, s0, k, s.global_example_set[k-1], false);
}
// cerr << "^^^^^^^^^^^^^^^^^^^^^^^^^^^^" << endl;
}
}
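Both branches above train through the same sandwich: swap in a cost-sensitive label, offset the example's features into the slice of the weight vector owned by the current policy, call the base learner through the new learner struct, then undo both changes. As a sketch, with learn_under_policy a hypothetical helper name:

// Hypothetical helper capturing the add-offset / learn / remove-offset
// pattern used throughout generate_state_example.
void learn_under_policy(vw& all, searn& s, example* ec, void* cs_label) {
  void* old_label = ec->ld;
  ec->ld = cs_label;
  SearnUtil::add_policy_offset(all, ec, s.increment, s.current_policy);
  s.base.learn(&all, s.base.data, ec);
  SearnUtil::remove_policy_offset(all, ec, s.increment, s.current_policy);
  ec->ld = old_label;
}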
- void run_prediction(vw&all, state s0, bool allow_oracle, bool allow_current, bool track_actions, std::vector<action>* action_sequence)
+ void run_prediction(vw&all, searn& s, state s0, bool allow_oracle, bool allow_current, bool track_actions, std::vector<action>* action_sequence)
{
int step = 1;
- while (!task.final(s0)) {
- uint32_t action = searn_predict(all, s0, step, allow_oracle, allow_current, NULL);
+ while (!s.task.final(s0)) {
+ uint32_t action = searn_predict(all, s, s0, step, allow_oracle, allow_current, NULL);
if (track_actions)
action_sequence->push_back(action);
- task.step(s0, action);
+ s.task.step(s0, action);
step++;
}
}
- void do_actual_learning(vw&all)
+ void do_actual_learning(vw&all, searn& s)
{
// there are two cases:
// * is_singleline --> look only at ec_seq[0]
// * otherwise --> look at everything
- if (ec_seq.size() == 0)
+ if (s.ec_seq.size() == 0)
return;
// generate the start state
state s0;
- if (is_singleline)
- task.start_state(ec_seq[0], &s0);
+ if (s.is_singleline)
+ s.task.start_state(s.ec_seq[0], &s0);
else
- task.start_state_multiline(ec_seq.begin, ec_seq.size(), &s0);
+ s.task.start_state_multiline(s.ec_seq.begin, s.ec_seq.size(), &s0);
state s0copy = NULL;
- bool is_test = task.is_test_example(ec_seq.begin, ec_seq.size());
+ bool is_test = s.task.is_test_example(s.ec_seq.begin, s.ec_seq.size());
if (!is_test) {
- s0copy = task.copy(s0);
+ s0copy = s.task.copy(s0);
all.sd->example_number++;
- all.sd->total_features += searn_num_features;
+ all.sd->total_features += s.searn_num_features;
all.sd->weighted_examples += 1.;
}
- bool will_print = is_test || should_print_update(all) || will_global_print_label(all);
+ bool will_print = is_test || should_print_update(all) || will_global_print_label(all, s);
- searn_num_features = 0;
+ s.searn_num_features = 0;
std::vector<action> action_sequence;
// if we are using adaptive beta, update it to take into account the latest updates
- if( adaptive_beta ) beta = 1.f - powf(1.f - alpha,(float)total_examples_generated);
+ if( s.adaptive_beta ) s.beta = 1.f - powf(1.f - s.alpha,(float)s.total_examples_generated);
- run_prediction(all, s0, false, true, will_print, &action_sequence);
- global_print_label(all, ec_seq[0], s0, action_sequence);
+ run_prediction(all, s, s0, false, true, will_print, &action_sequence);
+ global_print_label(all, s, s.ec_seq[0], s0, action_sequence);
if (!is_test) {
- float loss = task.loss(s0);
+ float loss = s.task.loss(s0);
all.sd->sum_loss += loss;
all.sd->sum_loss_since_last_dump += loss;
}
- print_update(all, s0, action_sequence);
+ print_update(all, s, s0, action_sequence);
- task.finish(s0);
+ s.task.finish(s0);
if (is_test || !all.training)
return;
@@ -1224,105 +1246,106 @@ namespace Searn
// training examples only get here
int step = 1;
- while (!task.final(s0)) {
+ while (!s.task.final(s0)) {
// if we are using adaptive beta, update it to take into account the latest updates
- if( adaptive_beta ) beta = 1.f - powf(1.f - alpha,(float)total_examples_generated);
+ if( s.adaptive_beta ) s.beta = 1.f - powf(1.f - s.alpha,(float)s.total_examples_generated);
// first, make a prediction (we don't want to bias ourselves if
// we're using the current policy to predict)
- uint32_t action = searn_predict(all, s0, step, true, allow_current_policy, NULL);
+ uint32_t action = searn_predict(all, s, s0, step, true, s.allow_current_policy, NULL);
// generate training example for the current state
- generate_state_example(all, s0);
+ generate_state_example(all, s, s0);
// take the prescribed step
- task.step(s0, action);
+ s.task.step(s0, action);
step++;
}
- task.finish(s0);
- if (do_recombination) { // we need to free a bunch of memory
- // past_states->iter(&hm_free_state_copies);
- free_unfreed_states();
- past_states->clear();
+ s.task.finish(s0);
+ if (s.do_recombination) { // we need to free a bunch of memory
+ // s.past_states->iter(&hm_free_state_copies);
+ free_unfreed_states(s);
+ s.past_states->clear();
}
}
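When adaptive beta is on, the interpolation parameter is refreshed from the running count of generated examples before every prediction, as seen twice in do_actual_learning. The update, isolated:

#include <cmath>

// beta = 1 - (1 - alpha)^T: as the number of generated training
// examples T grows, the probability of trusting a learned policy
// (rather than the oracle) approaches 1.
inline float adaptive_beta(float alpha, size_t T) {
  return 1.f - powf(1.f - alpha, (float)T);
}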
- void process_next_example(vw&all, example *ec)
+ void process_next_example(vw&all, searn& s, example *ec)
{
bool is_real_example = true;
- if (is_singleline) {
- if (ec_seq.size() == 0)
- ec_seq.push_back(ec);
+ if (s.is_singleline) {
+ if (s.ec_seq.size() == 0)
+ s.ec_seq.push_back(ec);
else
- ec_seq[0] = ec;
+ s.ec_seq[0] = ec;
- do_actual_learning(all);
+ do_actual_learning(all, s);
} else {
// is multiline
- if (ec_seq.size() >= all.p->ring_size - 2) { // give some wiggle room
+ if (s.ec_seq.size() >= all.p->ring_size - 2) { // give some wiggle room
std::cerr << "warning: length of sequence at " << ec->example_counter << " exceeds ring size; breaking apart" << std::endl;
- do_actual_learning(all);
- clear_seq(all);
+ do_actual_learning(all, s);
+ clear_seq(all, s);
}
if (OAA::example_is_newline(ec)) {
- do_actual_learning(all);
- clear_seq(all);
+ do_actual_learning(all, s);
+ clear_seq(all, s);
//CSOAA_LDF::global_print_newline(all);
VW::finish_example(all, ec);
is_real_example = false;
} else {
- ec_seq.push_back(ec);
+ s.ec_seq.push_back(ec);
}
}
// for both single and multiline
if (is_real_example) {
- read_example_this_loop++;
- read_example_last_id = ec->example_counter;
- if (ec->pass != read_example_last_pass) {
- read_example_last_pass = ec->pass;
- passes_since_new_policy++;
- if (passes_since_new_policy >= passes_per_policy) {
- passes_since_new_policy = 0;
+ s.read_example_this_loop++;
+ s.read_example_last_id = ec->example_counter;
+ if (ec->pass != s.read_example_last_pass) {
+ s.read_example_last_pass = ec->pass;
+ s.passes_since_new_policy++;
+ if (s.passes_since_new_policy >= s.passes_per_policy) {
+ s.passes_since_new_policy = 0;
if(all.training)
- current_policy++;
- if (current_policy > total_number_of_policies) {
+ s.current_policy++;
+ if (s.current_policy > s.total_number_of_policies) {
std::cerr << "internal error (bug): too many policies; not advancing" << std::endl;
- current_policy = total_number_of_policies;
+ s.current_policy = s.total_number_of_policies;
}
//reset searn_trained_nb_policies in options_from_file so it is saved to regressor file later
std::stringstream ss;
- ss << current_policy;
+ ss << s.current_policy;
VW::cmd_string_replace_value(all.options_from_file,"--searn_trained_nb_policies", ss.str());
}
}
}
}
- void drive(void*in)
+ void drive(void*in, void*d)
{
vw*all = (vw*)in;
// initialize everything
-
+ searn* s = (searn*)d;
+
const char * header_fmt = "%-10s %-10s %8s %15s %24s %22s %8s %5s %5s %15s %15s\n";
fprintf(stderr, header_fmt, "average", "since", "sequence", "example", "current label", "current predicted", "current", "cur", "cur", "predic.", "examples");
fprintf(stderr, header_fmt, "loss", "last", "counter", "weight", "sequence prefix", "sequence prefix", "features", "pass", "pol", "made", "gener.");
cerr.precision(5);
- initialize_memory();
+ initialize_memory(*s);
example* ec = NULL;
- read_example_this_loop = 0;
+ s->read_example_this_loop = 0;
while (true) {
if ((ec = get_example(all->p)) != NULL) { // semiblocking operation
- process_next_example(*all, ec);
+ process_next_example(*all, *s, ec);
} else if (parser_done(all->p)) {
- if (!is_singleline)
- do_actual_learning(*all);
+ if (!s->is_singleline)
+ do_actual_learning(*all, *s);
break;
}
}
@@ -1330,10 +1353,10 @@ namespace Searn
if( all->training ) {
std::stringstream ss1;
std::stringstream ss2;
- ss1 << (current_policy+1);
+ ss1 << (s->current_policy+1);
//use cmd_string_replace_value in case we already loaded a predictor which had a value stored for --searn_trained_nb_policies
VW::cmd_string_replace_value(all->options_from_file,"--searn_trained_nb_policies", ss1.str());
- ss2 << total_number_of_policies;
+ ss2 << s->total_number_of_policies;
//use cmd_string_replace_value in case we already loaded a predictor which had a value stored for --searn_total_nb_policies
VW::cmd_string_replace_value(all->options_from_file,"--searn_total_nb_policies", ss2.str());
}
@@ -1348,15 +1371,15 @@ namespace ImperativeSearn {
const char INIT_TRAIN = 1;
const char LEARN = 2;
- inline bool isLDF(searn_struct* srn) { return (srn->A == 0); }
+ inline bool isLDF(searn& srn) { return (srn.A == 0); }
- uint32_t choose_policy(searn_struct* srn, bool allow_current, bool allow_optimal)
+ uint32_t choose_policy(searn& srn, bool allow_current, bool allow_optimal)
{
- uint32_t seed = 0; // TODO: srn->read_example_last_id * 2147483 + srn->t * 2147483647;
- return SearnUtil::random_policy(seed, srn->beta, allow_current, srn->current_policy, allow_optimal, srn->rollout_all_actions);
+ uint32_t seed = 0; // TODO: srn.read_example_last_id * 2147483 + srn.t * 2147483647;
+ return SearnUtil::random_policy(seed, srn.beta, allow_current, srn.current_policy, allow_optimal, srn.rollout_all_actions);
}
- v_array<CSOAA::wclass> get_all_labels(searn_struct* srn, size_t num_ec, v_array<uint32_t> *yallowed)
+ v_array<CSOAA::wclass> get_all_labels(searn& srn, size_t num_ec, v_array<uint32_t> *yallowed)
{
if (isLDF(srn)) {
v_array<CSOAA::wclass> ret; // TODO: cache these!
@@ -1369,7 +1392,7 @@ namespace ImperativeSearn {
// is not LDF
if (yallowed == NULL) {
v_array<CSOAA::wclass> ret; // TODO: cache this!
- for (uint32_t i=1; i<=srn->A; i++) {
+ for (uint32_t i=1; i<=srn.A; i++) {
CSOAA::wclass cost = { FLT_MAX, i, 1., 0. };
ret.push_back(cost);
}
@@ -1390,21 +1413,20 @@ namespace ImperativeSearn {
return 0;
}
- uint32_t single_prediction_notLDF(vw& all, example* ec, v_array<CSOAA::wclass> valid_labels, uint32_t pol)
+ uint32_t single_prediction_notLDF(vw& all, searn& srn, example* ec, v_array<CSOAA::wclass> valid_labels, uint32_t pol)
{
- searn_struct *srn = (searn_struct*)all.searnstr;
assert(pol > 0);
void* old_label = ec->ld;
ec->ld = (void*)&valid_labels;
- SearnUtil::add_policy_offset(all, ec, srn->increment, pol);
+ SearnUtil::add_policy_offset(all, ec, srn.increment, pol);
- srn->base_learner(&all, ec);
- srn->total_predictions_made++;
- srn->num_features += ec->num_features;
+ srn.base.learn(&all,srn.base.data, ec);
+ srn.total_predictions_made++;
+ srn.num_features += ec->num_features;
uint32_t final_prediction = (uint32_t)(*(OAA::prediction_t*)&(ec->final_prediction));
- SearnUtil::remove_policy_offset(all, ec, srn->increment, pol);
+ SearnUtil::remove_policy_offset(all, ec, srn.increment, pol);
ec->ld = old_label;
return final_prediction;
@@ -1416,7 +1438,7 @@ namespace ImperativeSearn {
return opts[(size_t)(((float)opts.size()) * r)];
}
- uint32_t single_action(vw& all, example** ecs, size_t num_ec, v_array<CSOAA::wclass> valid_labels, uint32_t pol, v_array<uint32_t> *ystar) {
+ uint32_t single_action(vw& all, searn& srn, example** ecs, size_t num_ec, v_array<CSOAA::wclass> valid_labels, uint32_t pol, v_array<uint32_t> *ystar) {
//cerr << "pol=" << pol << " ystar.size()=" << ystar->size() << " ystar[0]=" << ((ystar->size() > 0) ? (*ystar)[0] : 0) << endl;
if (pol == 0) { // optimal policy
if ((ystar == NULL) || (ystar->size() == 0))
@@ -1424,19 +1446,18 @@ namespace ImperativeSearn {
else
return choose_random<uint32_t>(*ystar);
} else { // learned policy
- if (isLDF((searn_struct*)all.searnstr))
+ if (isLDF(srn))
return single_prediction_LDF(all, ecs, num_ec, pol);
else
- return single_prediction_notLDF(all, *ecs, valid_labels, pol);
+ return single_prediction_notLDF(all, srn, *ecs, valid_labels, pol);
}
}
- void clear_snapshot(vw& all)
+ void clear_snapshot(vw& all, searn& srn)
{
- searn_struct *srn = (searn_struct*)all.searnstr;
- for (size_t i=0; i<srn->snapshot_data.size(); i++)
- free(srn->snapshot_data[i].data_ptr);
- srn->snapshot_data.erase();
+ for (size_t i=0; i<srn.snapshot_data.size(); i++)
+ free(srn.snapshot_data[i].data_ptr);
+ srn.snapshot_data.erase();
}
// if not LDF:
@@ -1456,10 +1477,10 @@ namespace ImperativeSearn {
// otherwise means the oracle could do any of the listed actions
uint32_t searn_predict(vw& all, example** ecs, size_t num_ec, v_array<uint32_t> *yallowed, v_array<uint32_t> *ystar) // num_ec == 0 means normal example, >0 means ldf, yallowed==NULL means all allowed, ystar==NULL means don't know
{
- searn_struct *srn = (searn_struct*)all.searnstr;
+ searn* srn=(searn*)all.searnstr;
// check ldf sanity
- if (!isLDF(srn)) {
+ if (!isLDF(*srn)) {
assert(num_ec == 0); // searntask is trying to define an ldf example in a non-ldf problem
} else { // is LDF
      assert(num_ec != 0); // searntask is trying to define a non-ldf example in an ldf problem
@@ -1467,17 +1488,17 @@ namespace ImperativeSearn {
}
if (srn->state == INIT_TEST) {
- uint32_t pol = choose_policy(srn, true, false);
- v_array<CSOAA::wclass> valid_labels = get_all_labels(srn, num_ec, yallowed);
- uint32_t a = single_action(all, ecs, num_ec, valid_labels, pol, ystar);
+ uint32_t pol = choose_policy(*srn, true, false);
+ v_array<CSOAA::wclass> valid_labels = get_all_labels(*srn, num_ec, yallowed);
+ uint32_t a = single_action(all, *srn, ecs, num_ec, valid_labels, pol, ystar);
srn->t++;
valid_labels.erase(); valid_labels.delete_v();
return a;
}
if (srn->state == INIT_TRAIN) {
- uint32_t pol = choose_policy(srn, srn->allow_current_policy, true);
- v_array<CSOAA::wclass> valid_labels = get_all_labels(srn, num_ec, yallowed);
- uint32_t a = single_action(all, ecs, num_ec, valid_labels, pol, ystar);
+ uint32_t pol = choose_policy(*srn, srn->allow_current_policy, true);
+ v_array<CSOAA::wclass> valid_labels = get_all_labels(*srn, num_ec, yallowed);
+ uint32_t a = single_action(all, *srn, ecs, num_ec, valid_labels, pol, ystar);
srn->train_action.push_back(a);
srn->train_labels.push_back(valid_labels);
srn->t++;
@@ -1502,9 +1523,9 @@ namespace ImperativeSearn {
srn->t++;
return srn->learn_a;
} else {
- uint32_t pol = choose_policy(srn, srn->allow_current_policy, true);
- v_array<CSOAA::wclass> valid_labels = get_all_labels(srn, num_ec, yallowed);
- uint32_t a = single_action(all, ecs, num_ec, valid_labels, pol, ystar);
+ uint32_t pol = choose_policy(*srn, srn->allow_current_policy, true);
+ v_array<CSOAA::wclass> valid_labels = get_all_labels(*srn, num_ec, yallowed);
+ uint32_t a = single_action(all, *srn, ecs, num_ec, valid_labels, pol, ystar);
srn->t++;
valid_labels.erase(); valid_labels.delete_v();
return a;
@@ -1517,8 +1538,7 @@ namespace ImperativeSearn {
void searn_declare_loss(vw& all, size_t predictions_since_last, float incr_loss)
{
- searn_struct *srn = (searn_struct*)all.searnstr;
-
+ searn* srn=(searn*)all.searnstr;
if (srn->t != srn->loss_last_step + predictions_since_last) {
cerr << "fail: searntask hasn't counted its predictions correctly. current time step=" << srn->t << ", last declaration at " << srn->loss_last_step << ", declared # of predictions since then is " << predictions_since_last << endl;
exit(-1);
@@ -1544,7 +1564,7 @@ namespace ImperativeSearn {
void searn_snapshot(vw& all, size_t index, size_t tag, void* data_ptr, size_t sizeof_data)
{
- searn_struct *srn = (searn_struct*)all.searnstr;
+ searn* srn=(searn*)all.searnstr;
if (! srn->do_snapshot) return;
//cerr << "snapshot called with: { index=" << index << ", tag=" << tag << ", data_ptr=" << *(size_t*)data_ptr << ", t=" << srn->t << " }" << endl;
@@ -1593,19 +1613,16 @@ namespace ImperativeSearn {
srn->t = item.pred_step;
}
- v_array<size_t> get_training_timesteps(vw& all)
+ v_array<size_t> get_training_timesteps(vw& all, searn& srn)
{
- searn_struct *srn = (searn_struct*)all.searnstr;
v_array<size_t> timesteps;
- for (size_t t=0; t<srn->T; t++)
+ for (size_t t=0; t<srn.T; t++)
timesteps.push_back(t);
return timesteps;
}
- void generate_training_example(vw& all, example** ec, size_t len, v_array<CSOAA::wclass> labels, v_array<float> losses)
+ void generate_training_example(vw& all, searn& srn, example** ec, size_t len, v_array<CSOAA::wclass> labels, v_array<float> losses)
{
- searn_struct *srn = (searn_struct*)all.searnstr;
-
assert(labels.size() == losses.size());
for (size_t i=0; i<labels.size(); i++)
labels[i].x = losses[i];
@@ -1614,130 +1631,127 @@ namespace ImperativeSearn {
void* old_label = ec[0]->ld;
CSOAA::label new_label = { labels };
ec[0]->ld = (void*)&new_label;
- SearnUtil::add_policy_offset(all, ec[0], srn->increment, srn->current_policy);
- srn->base_learner(&all, ec[0]);
- SearnUtil::remove_policy_offset(all, ec[0], srn->increment, srn->current_policy);
+ SearnUtil::add_policy_offset(all, ec[0], srn.increment, srn.current_policy);
+ srn.base.learn(&all,srn.base.data, ec[0]);
+ SearnUtil::remove_policy_offset(all, ec[0], srn.increment, srn.current_policy);
ec[0]->ld = old_label;
- srn->total_examples_generated++;
+ srn.total_examples_generated++;
} else { // isLDF
//TODO
}
}
- void train_single_example(vw& all, example**ec, size_t len)
+ void train_single_example(vw& all, searn& srn, example**ec, size_t len)
{
- searn_struct *srn = (searn_struct*)all.searnstr;
-
// do an initial test pass to compute output (and loss)
//cerr << "======================================== INIT TEST ========================================" << endl;
- srn->state = INIT_TEST;
- srn->t = 0;
- srn->T = 0;
- srn->loss_last_step = 0;
- srn->test_loss = 0.f;
- srn->train_loss = 0.f;
- srn->learn_loss = 0.f;
- srn->learn_example_copy = NULL;
- srn->learn_example_len = 0;
- srn->train_action.erase();
- srn->num_features = 0;
+ srn.state = INIT_TEST;
+ srn.t = 0;
+ srn.T = 0;
+ srn.loss_last_step = 0;
+ srn.test_loss = 0.f;
+ srn.train_loss = 0.f;
+ srn.learn_loss = 0.f;
+ srn.learn_example_copy = NULL;
+ srn.learn_example_len = 0;
+ srn.train_action.erase();
+ srn.num_features = 0;
- srn->task.structured_predict(all, ec, len, srn->pred_string, srn->truth_string);
+ srn.task->structured_predict(all, srn, ec, len, srn.pred_string, srn.truth_string);
- if (srn->t == 0)
+ if (srn.t == 0)
return; // there was no data!
// do a pass over the data allowing oracle and snapshotting
//cerr << "======================================== INIT TRAIN ========================================" << endl;
- srn->state = INIT_TRAIN;
- srn->t = 0;
- srn->loss_last_step = 0;
- clear_snapshot(all);
+ srn.state = INIT_TRAIN;
+ srn.t = 0;
+ srn.loss_last_step = 0;
+ clear_snapshot(all, srn);
- srn->task.structured_predict(all, ec, len, NULL, NULL);
+ srn.task->structured_predict(all, srn, ec, len, NULL, NULL);
- if (srn->t == 0) {
- clear_snapshot(all);
+ if (srn.t == 0) {
+ clear_snapshot(all, srn);
return; // there was no data
}
- srn->T = srn->t;
+ srn.T = srn.t;
// generate training examples on which to learn
//cerr << "======================================== LEARN ========================================" << endl;
- srn->state = LEARN;
- v_array<size_t> tset = get_training_timesteps(all);
+ srn.state = LEARN;
+ v_array<size_t> tset = get_training_timesteps(all, srn);
for (size_t t=0; t<tset.size(); t++) {
- v_array<CSOAA::wclass> aset = srn->train_labels[t];
- srn->learn_t = t;
- srn->learn_losses.erase();
+ v_array<CSOAA::wclass> aset = srn.train_labels[t];
+ srn.learn_t = t;
+ srn.learn_losses.erase();
for (size_t i=0; i<aset.size(); i++) {
- if (aset[i].weight_index == srn->train_action[srn->learn_t])
- srn->learn_losses.push_back( srn->train_loss );
+ if (aset[i].weight_index == srn.train_action[srn.learn_t])
+ srn.learn_losses.push_back( srn.train_loss );
else {
- srn->t = 0;
- srn->learn_a = aset[i].weight_index;
- srn->loss_last_step = 0;
- srn->learn_loss = 0.f;
+ srn.t = 0;
+ srn.learn_a = aset[i].weight_index;
+ srn.loss_last_step = 0;
+ srn.learn_loss = 0.f;
- //cerr << "learn_t = " << srn->learn_t << " || learn_a = " << srn->learn_a << endl;
- srn->task.structured_predict(all, ec, len, NULL, NULL);
+ //cerr << "learn_t = " << srn.learn_t << " || learn_a = " << srn.learn_a << endl;
+ srn.task->structured_predict(all, srn, ec, len, NULL, NULL);
- srn->learn_losses.push_back( srn->learn_loss );
- //cerr << "total loss: " << srn->learn_loss << endl;
+ srn.learn_losses.push_back( srn.learn_loss );
+ //cerr << "total loss: " << srn.learn_loss << endl;
}
}
- if (srn->learn_example_copy != NULL) {
- generate_training_example(all, srn->learn_example_copy, srn->learn_example_len, aset, srn->learn_losses);
+ if (srn.learn_example_copy != NULL) {
+ generate_training_example(all, srn, srn.learn_example_copy, srn.learn_example_len, aset, srn.learn_losses);
- for (size_t n=0; n<srn->learn_example_len; n++) {
- dealloc_example(CSOAA::delete_label, *srn->learn_example_copy[n]);
- free(srn->learn_example_copy[n]);
+ for (size_t n=0; n<srn.learn_example_len; n++) {
+ dealloc_example(CSOAA::delete_label, *srn.learn_example_copy[n]);
+ free(srn.learn_example_copy[n]);
}
- free(srn->learn_example_copy);
- srn->learn_example_copy = NULL;
- srn->learn_example_len = 0;
+ free(srn.learn_example_copy);
+ srn.learn_example_copy = NULL;
+ srn.learn_example_len = 0;
} else {
cerr << "warning: searn did not generate an example for a given time-step" << endl;
}
}
tset.erase(); tset.delete_v();
- clear_snapshot(all);
- srn->train_action.erase();
- srn->train_action.delete_v();
- for (size_t i=0; i<srn->train_labels.size(); i++) {
- srn->train_labels[i].erase();
- srn->train_labels[i].delete_v();
+ clear_snapshot(all, srn);
+ srn.train_action.erase();
+ srn.train_action.delete_v();
+ for (size_t i=0; i<srn.train_labels.size(); i++) {
+ srn.train_labels[i].erase();
+ srn.train_labels[i].delete_v();
}
- srn->train_labels.erase();
- srn->train_labels.delete_v();
+ srn.train_labels.erase();
+ srn.train_labels.delete_v();
}
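train_single_example drives the task's structured_predict in three phases: an INIT_TEST pass that measures loss and renders output, an INIT_TRAIN pass that records the trajectory of actions, and a LEARN phase that replays to each time step, forces each alternative action, and rolls out to completion to cost it. A compressed schematic of that control flow; snapshot handling, early exits, and cleanup are omitted:

// Schematic of train_single_example's three phases.
void train_one(vw& all, searn& srn, example** ec, size_t len) {
  srn.state = INIT_TEST;   // 1. test pass: measure loss, produce output
  srn.t = 0;
  srn.task->structured_predict(all, srn, ec, len, srn.pred_string, srn.truth_string);

  srn.state = INIT_TRAIN;  // 2. training pass: record actions taken
  srn.t = 0;
  srn.task->structured_predict(all, srn, ec, len, NULL, NULL);
  srn.T = srn.t;

  srn.state = LEARN;       // 3. for each step, cost out each alternative
  for (size_t t = 0; t < srn.T; t++) {
    v_array<CSOAA::wclass> aset = srn.train_labels[t];
    srn.learn_t = t;
    srn.learn_losses.erase();
    for (size_t i = 0; i < aset.size(); i++)
      if (aset[i].weight_index == srn.train_action[t])
        srn.learn_losses.push_back(srn.train_loss); // the action we took
      else { // replay to t, force this action, roll to the end
        srn.t = 0;
        srn.learn_a = aset[i].weight_index;
        srn.learn_loss = 0.f;
        srn.task->structured_predict(all, srn, ec, len, NULL, NULL);
        srn.learn_losses.push_back(srn.learn_loss);
      }
    generate_training_example(all, srn, srn.learn_example_copy,
                              srn.learn_example_len, aset, srn.learn_losses);
  }
}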
- void clear_seq(vw&all)
+ void clear_seq(vw&all, searn& srn)
{
- searn_struct *srn = (searn_struct*)all.searnstr;
- if (srn->ec_seq.size() > 0)
- for (example** ecc=srn->ec_seq.begin; ecc!=srn->ec_seq.end; ecc++) {
+ if (srn.ec_seq.size() > 0)
+ for (example** ecc=srn.ec_seq.begin; ecc!=srn.ec_seq.end; ecc++) {
VW::finish_example(all, *ecc);
}
- srn->ec_seq.erase();
+ srn.ec_seq.erase();
}
float safediv(float a,float b) { if (b == 0.f) return 0.f; else return (a/b); }
- void print_update(vw& all)
+
+ void print_update(vw& all, searn& srn)
{
- searn_struct *srn = (searn_struct*)all.searnstr;
if (!Searn::should_print_update(all))
return;
char true_label[21];
char pred_label[21];
- Searn::to_short_string(srn->truth_string ? srn->truth_string->str() : "", 20, true_label);
- Searn::to_short_string(srn->pred_string ? srn->pred_string->str() : "", 20, pred_label);
+ Searn::to_short_string(srn.truth_string ? srn.truth_string->str() : "", 20, true_label);
+ Searn::to_short_string(srn.pred_string ? srn.pred_string->str() : "", 20, pred_label);
fprintf(stderr, "%-10.6f %-10.6f %8ld %15f [%s] [%s] %8lu %5d %5d %15lu %15lu\n",
safediv((float)all.sd->sum_loss, (float)all.sd->weighted_examples),
@@ -1746,11 +1760,11 @@ namespace ImperativeSearn {
all.sd->weighted_examples,
true_label,
pred_label,
- (long unsigned int)srn->num_features,
- (int)srn->read_example_last_pass,
- (int)srn->current_policy,
- (long unsigned int)srn->total_predictions_made,
- (long unsigned int)srn->total_examples_generated);
+ (long unsigned int)srn.num_features,
+ (int)srn.read_example_last_pass,
+ (int)srn.current_policy,
+ (long unsigned int)srn.total_predictions_made,
+ (long unsigned int)srn.total_examples_generated);
all.sd->sum_loss_since_last_dump = 0.0;
all.sd->old_weighted_examples = all.sd->weighted_examples;
@@ -1758,54 +1772,53 @@ namespace ImperativeSearn {
}
- void do_actual_learning(vw&all)
+ void do_actual_learning(vw&all, searn& srn)
{
- searn_struct *srn = (searn_struct*)all.searnstr;
- if (srn->ec_seq.size() == 0)
+ if (srn.ec_seq.size() == 0)
return; // nothing to do :)
- if (Searn::should_print_update(all) || Searn::will_global_print_label(all)) {
- srn->truth_string = new stringstream();
- srn->pred_string = new stringstream();
+ if (Searn::should_print_update(all)) {
+ srn.truth_string = new stringstream();
+ srn.pred_string = new stringstream();
}
- train_single_example(all, srn->ec_seq.begin, srn->ec_seq.size());
- if (srn->test_loss >= 0.f) {
- all.sd->sum_loss += srn->test_loss;
- all.sd->sum_loss_since_last_dump += srn->test_loss;
+ train_single_example(all, srn, srn.ec_seq.begin, srn.ec_seq.size());
+ if (srn.test_loss >= 0.f) {
+ all.sd->sum_loss += srn.test_loss;
+ all.sd->sum_loss_since_last_dump += srn.test_loss;
all.sd->example_number++;
- all.sd->total_features += srn->num_features;
+ all.sd->total_features += srn.num_features;
all.sd->weighted_examples += 1.f;
}
- print_update(all);
+ print_update(all, srn);
- if (srn->truth_string != NULL) {
- delete srn->truth_string;
- srn->truth_string = NULL;
+ if (srn.truth_string != NULL) {
+ delete srn.truth_string;
+ srn.truth_string = NULL;
}
- if (srn->pred_string != NULL) {
- delete srn->pred_string;
- srn->pred_string = NULL;
+ if (srn.pred_string != NULL) {
+ delete srn.pred_string;
+ srn.pred_string = NULL;
}
}
- void searn_learn(void*in, example*ec) {
- vw all = *(vw*)in;
- searn_struct *srn = (searn_struct*)all.searnstr;
+ void searn_learn(void*in, void*d, example*ec) {
+ vw* all = (vw*)in;
+ searn *srn = (searn*)d;
- if (srn->ec_seq.size() >= all.p->ring_size - 2) { // give some wiggle room
+ if (srn->ec_seq.size() >= all->p->ring_size - 2) { // give some wiggle room
std::cerr << "warning: length of sequence at " << ec->example_counter << " exceeds ring size; breaking apart" << std::endl;
- do_actual_learning(all);
- clear_seq(all);
+ do_actual_learning(*all, *srn);
+ clear_seq(*all, *srn);
}
bool is_real_example = true;
if (OAA::example_is_newline(ec)) {
- do_actual_learning(all);
- clear_seq(all);
- VW::finish_example(all, ec);
+ do_actual_learning(*all, *srn);
+ clear_seq(*all, *srn);
+ VW::finish_example(*all, ec);
is_real_example = false;
} else {
srn->ec_seq.push_back(ec);
@@ -1819,7 +1832,7 @@ namespace ImperativeSearn {
srn->passes_since_new_policy++;
if (srn->passes_since_new_policy >= srn->passes_per_policy) {
srn->passes_since_new_policy = 0;
- if(all.training)
+ if(all->training)
srn->current_policy++;
if (srn->current_policy > srn->total_number_of_policies) {
std::cerr << "internal error (bug): too many policies; not advancing" << std::endl;
@@ -1828,15 +1841,15 @@ namespace ImperativeSearn {
//reset searn_trained_nb_policies in options_from_file so it is saved to regressor file later
std::stringstream ss;
ss << srn->current_policy;
- VW::cmd_string_replace_value(all.options_from_file,"--searn_trained_nb_policies", ss.str());
+ VW::cmd_string_replace_value(all->options_from_file,"--searn_trained_nb_policies", ss.str());
}
}
}
}
- void searn_drive(void*in) {
- vw all = *(vw*)in;
- searn_struct *srn = (searn_struct*)all.searnstr;
+ void searn_drive(void*in, void *d) {
+ vw* all = (vw*)in;
+ searn *srn = (searn*)d;
const char * header_fmt = "%-10s %-10s %8s %15s %24s %22s %8s %5s %5s %15s %15s\n";
@@ -1847,63 +1860,61 @@ namespace ImperativeSearn {
example* ec = NULL;
srn->read_example_this_loop = 0;
while (true) {
- if ((ec = get_example(all.p)) != NULL) { // semiblocking operation
- searn_learn(in, ec);
- } else if (parser_done(all.p)) {
- do_actual_learning(all);
+ if ((ec = get_example(all->p)) != NULL) { // semiblocking operation
+ searn_learn(in,d, ec);
+ } else if (parser_done(all->p)) {
+ do_actual_learning(*all, *srn);
break;
}
}
- if( all.training ) {
+ if( all->training ) {
std::stringstream ss1;
std::stringstream ss2;
ss1 << (srn->current_policy+1);
//use cmd_string_replace_value in case we already loaded a predictor which had a value stored for --searnimp_trained_nb_policies
- VW::cmd_string_replace_value(all.options_from_file,"--searnimp_trained_nb_policies", ss1.str());
+ VW::cmd_string_replace_value(all->options_from_file,"--searnimp_trained_nb_policies", ss1.str());
ss2 << srn->total_number_of_policies;
//use cmd_string_replace_value in case we already loaded a predictor which had a value stored for --searnimp_total_nb_policies
- VW::cmd_string_replace_value(all.options_from_file,"--searnimp_total_nb_policies", ss2.str());
+ VW::cmd_string_replace_value(all->options_from_file,"--searnimp_total_nb_policies", ss2.str());
}
}
- void searn_initialize(vw& all)
+ void searn_initialize(vw& all, searn& srn)
{
- searn_struct *srn = (searn_struct*)all.searnstr;
-
- srn->predict = searn_predict;
- srn->declare_loss = searn_declare_loss;
- srn->snapshot = searn_snapshot;
-
- srn->beta = 0.5;
- srn->allow_current_policy = false;
- srn->rollout_all_actions = true;
- srn->num_features = 0;
- srn->current_policy = 1;
- srn->state = 0;
- srn->do_snapshot = true;
-
- srn->passes_per_policy = 1; //this should be set to the same value as --passes for dagger
-
- srn->read_example_this_loop = 0;
- srn->read_example_last_id = 0;
- srn->passes_since_new_policy = 0;
- srn->read_example_last_pass = 0;
- srn->total_examples_generated = 0;
- srn->total_predictions_made = 0;
+ srn.predict = searn_predict;
+ srn.declare_loss = searn_declare_loss;
+ srn.snapshot = searn_snapshot;
+
+ srn.beta = 0.5;
+ srn.allow_current_policy = false;
+ srn.rollout_all_actions = true;
+ srn.num_features = 0;
+ srn.current_policy = 1;
+ srn.state = 0;
+ srn.do_snapshot = true;
+
+ srn.passes_per_policy = 1; //this should be set to the same value as --passes for dagger
+
+ srn.read_example_this_loop = 0;
+ srn.read_example_last_id = 0;
+ srn.passes_since_new_policy = 0;
+ srn.read_example_last_pass = 0;
+ srn.total_examples_generated = 0;
+ srn.total_predictions_made = 0;
}
- void searn_finish(void*in)
+ void searn_finish(void*in, void* d)
{
vw*all = (vw*)in;
- searn_struct *srn = (searn_struct*)all->searnstr;
+ searn *srn = (searn*)d;
//cerr << "searn_finish" << endl;
- clear_seq(*all);
+ clear_seq(*all,*srn);
srn->ec_seq.delete_v();
- clear_snapshot(*all);
+ clear_snapshot(*all, *srn);
srn->snapshot_data.delete_v();
for (size_t i=0; i<srn->train_labels.size(); i++) {
@@ -1914,10 +1925,10 @@ namespace ImperativeSearn {
srn->train_action.erase(); srn->train_action.delete_v();
srn->learn_losses.erase(); srn->learn_losses.delete_v();
- if (srn->task.finish != NULL)
- srn->task.finish(*all);
- if (srn->task.finish != NULL)
- srn->base_finish(all);
+ if (srn->task->finish != NULL)
+ srn->task->finish(*all);
+    if (srn->base.finish != NULL)
+ srn->base.finish(all, srn->base.data);
}
void ensure_param(float &v, float lo, float hi, float def, const char* string) {
@@ -1952,9 +1963,9 @@ namespace ImperativeSearn {
void parse_flags(vw&all, std::vector<std::string>&opts, po::variables_map& vm, po::variables_map& vm_file)
{
- searn_struct *srn = (searn_struct*)all.searnstr;
+ searn* srn = (searn*)calloc(1,sizeof(searn));
- searn_initialize(all);
+ searn_initialize(all, *srn);
po::options_description desc("Imperative Searn options");
desc.add_options()
@@ -2038,25 +2049,24 @@ namespace ImperativeSearn {
srn->increment = ((uint32_t)all.length() / all.base_learner_nb_w) * all.stride;
if (task_string.compare("sequence") == 0) {
- searn_task mytask = { SequenceTask_Easy::initialize,
- SequenceTask_Easy::finish,
- SequenceTask_Easy::structured_predict_v1
- };
+ searn_task* mytask = (searn_task*)calloc(1, sizeof(searn_task));
+ mytask->initialize = SequenceTask_Easy::initialize;
+ mytask->finish = SequenceTask_Easy::finish;
+ mytask->structured_predict = SequenceTask_Easy::structured_predict_v1;
+
srn->task = mytask;
} else {
cerr << "fail: unknown task for --searn_task: " << task_string << endl;
exit(-1);
}
- srn->task.initialize(all, srn->A);
+ srn->task->initialize(all, srn->A);
- all.driver = searn_drive;
- srn->base_learner = all.learn;
- all.learn = searn_learn;
- srn->base_finish = all.finish;
- all.finish = searn_finish;
+ srn->base = all.l;
+
+ learner l = {srn, searn_drive, searn_learn, searn_finish, all.l.save_load};
+ all.l = l;
}
-
}
/*
time ./vw --searn 45 --searn_task sequence -k -c -d ../test/train-sets/wsj_small.dat2.gz --passes 5 --searn_passes_per_policy 4
diff --git a/vowpalwabbit/searn.h b/vowpalwabbit/searn.h
index b7c02528..71cff9ab 100644
--- a/vowpalwabbit/searn.h
+++ b/vowpalwabbit/searn.h
@@ -195,16 +195,9 @@ namespace Searn
};
void parse_flags(vw&all, std::vector<std::string>&, po::variables_map& vm, po::variables_map& vm_file);
- void drive(void*);
}
namespace ImperativeSearn {
- struct searn_task {
- void (*initialize)(vw&,uint32_t&);
- void (*finish)(vw&);
- void (*structured_predict)(vw&,example**,size_t,stringstream*,stringstream*);
- };
-
struct snapshot_item {
size_t index;
size_t tag;
@@ -212,15 +205,17 @@ namespace ImperativeSearn {
size_t data_size; // sizeof(data_ptr)
size_t pred_step; // srn->t when snapshot is made
};
+
+ struct searn_task;
- struct searn_struct {
+ struct searn {
// functions that you will call
uint32_t (*predict)(vw&,example**,size_t,v_array<uint32_t>*,v_array<uint32_t>*);
void (*declare_loss)(vw&,size_t,float); // <0 means it was a test example!
void (*snapshot)(vw&,size_t,size_t,void*,size_t);
// structure that you must set
- searn_task task;
+ searn_task* task;
// data that you should not look at. ever.
uint32_t A; // total number of actions, [1..A]; 0 means ldf
@@ -265,8 +260,13 @@ namespace ImperativeSearn {
v_array<example*> ec_seq;
- void (*base_finish)(void*);
- void (*base_learner)(void*,example*);
+ learner base;
+ };
+
+ struct searn_task {
+ void (*initialize)(vw&,uint32_t&);
+ void (*finish)(vw&);
+ void (*structured_predict)(vw&, searn&, example**,size_t,stringstream*,stringstream*);
};
void parse_flags(vw&, std::vector<std::string>&, po::variables_map&, po::variables_map&);
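With this header change, searn holds a learner value instead of raw base_finish/base_learner pointers. learner itself lives in the new learner.h, which this excerpt does not show; its field names (data, driver, learn, finish, save_load) all appear at call sites in this commit, and the following sketch infers only their declared order from brace initializers such as learner l = {srn, searn_drive, searn_learn, searn_finish, all.l.save_load}:

// Inferred shape of the learner struct; the field order is a guess.
struct learner {
  void* data;                                         // reduction state
  void (*driver)(void* all, void* data);              // example pump
  void (*learn)(void* all, void* data, example* ec);  // one update
  void (*finish)(void* all, void* data);              // teardown
  void (*save_load)(void* all, void* data, io_buf& mf, bool read, bool text);
};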
diff --git a/vowpalwabbit/searn_sequencetask.cc b/vowpalwabbit/searn_sequencetask.cc
index f2487e47..e1001328 100644
--- a/vowpalwabbit/searn_sequencetask.cc
+++ b/vowpalwabbit/searn_sequencetask.cc
@@ -351,8 +351,7 @@ namespace SequenceTask_Easy {
out->push_back( lab->costs[l].weight_index );
}
- void structured_predict_v1(vw& vw, example**ec, size_t len, stringstream*output_ss, stringstream*truth_ss) {
- searn_struct srn = *(searn_struct*)vw.searnstr;
+ void structured_predict_v1(vw& vw, searn& srn, example**ec, size_t len, stringstream*output_ss, stringstream*truth_ss) {
float total_loss = 0;
size_t history_length = max(hinfo.features, hinfo.length);
bool is_train = false;
diff --git a/vowpalwabbit/searn_sequencetask.h b/vowpalwabbit/searn_sequencetask.h
index 3c174ce9..89f5c30d 100644
--- a/vowpalwabbit/searn_sequencetask.h
+++ b/vowpalwabbit/searn_sequencetask.h
@@ -8,6 +8,7 @@ license as described in the file LICENSE.
#include "oaa.h"
#include "parse_primitives.h"
+#include "searn.h"
namespace SequenceTask {
bool final(state);
@@ -30,7 +31,7 @@ namespace SequenceTask {
namespace SequenceTask_Easy {
void initialize(vw&, uint32_t&);
void finish(vw&);
- void structured_predict_v1(vw&,example**,size_t,stringstream*,stringstream*);
+ void structured_predict_v1(vw&, ImperativeSearn::searn&, example**,size_t,stringstream*,stringstream*);
}
diff --git a/vowpalwabbit/sender.cc b/vowpalwabbit/sender.cc
index ea7243d3..95d20ff9 100644
--- a/vowpalwabbit/sender.cc
+++ b/vowpalwabbit/sender.cc
@@ -24,25 +24,17 @@
using namespace std;
namespace SENDER {
-//nonreentrant
-io_buf* buf;
+ struct sender {
+ io_buf* buf;
+
+ int sd;
+ };
-int sd = -1;
-
-void open_sockets(string host)
-{
- sd = open_socket(host.c_str());
- buf = new io_buf();
- buf->files.push_back(sd);
-}
-
-void parse_send_args(po::variables_map& vm, vector<string> pairs)
+ void open_sockets(sender& s, string host)
{
- if (vm.count("sendto"))
- {
- vector<string> hosts = vm["sendto"].as< vector<string> >();
- open_sockets(hosts[0]);
- }
+ s.sd = open_socket(host.c_str());
+ s.buf = new io_buf();
+ s.buf->files.push_back(s.sd);
}
void send_features(io_buf *b, example* ec)
@@ -58,11 +50,12 @@ void send_features(io_buf *b, example* ec)
b->flush();
}
-void save_load(void* in, io_buf& model_file, bool read, bool text) {}
+ void save_load(void* in, void* d, io_buf& model_file, bool read, bool text) {}
-void drive_send(void* in)
+ void drive_send(void* in, void* d)
{
vw* all = (vw*)in;
+ sender* s = (sender*)d;
example* ec = NULL;
v_array<char> null_tag;
null_tag.erase();
@@ -77,7 +70,7 @@ void drive_send(void* in)
if (received_index + all->p->ring_size == sent_index || (parser_finished & (received_index != sent_index)))
{
float res, weight;
- get_prediction(sd,res,weight);
+ get_prediction(s->sd,res,weight);
ec=delay_ring[received_index++ % all->p->ring_size];
label_data* ld = (label_data*)ec->ld;
@@ -92,9 +85,9 @@ void drive_send(void* in)
{
label_data* ld = (label_data*)ec->ld;
all->set_minmax(all->sd, ld->label);
- simple_label.cache_label(ld, *buf);//send label information.
- cache_tag(*buf, ec->tag);
- send_features(buf,ec);
+ simple_label.cache_label(ld, *s->buf);//send label information.
+ cache_tag(*s->buf, ec->tag);
+ send_features(s->buf,ec);
delay_ring[sent_index++ % all->p->ring_size] = ec;
}
else if (parser_done(all->p))
@@ -102,9 +95,9 @@ void drive_send(void* in)
parser_finished = true;
if (received_index == sent_index)
{
- shutdown(buf->files[0],SHUT_WR);
- buf->files.delete_v();
- buf->space.delete_v();
+ shutdown(s->buf->files[0],SHUT_WR);
+ s->buf->files.delete_v();
+ s->buf->space.delete_v();
free(delay_ring);
return;
}
@@ -114,5 +107,21 @@ void drive_send(void* in)
}
return;
}
+ void learn(void*in, void* d, example*ec) { cout << "sender can't be used under reduction" << endl; }
+ void finish(void*in, void* d) { cout << "sender can't be used under reduction" << endl; }
+
+ void parse_send_args(vw& all, po::variables_map& vm, vector<string> pairs)
+{
+ sender* s = (sender*)calloc(1,sizeof(sender));
+ s->sd = -1;
+ if (vm.count("sendto"))
+ {
+ vector<string> hosts = vm["sendto"].as< vector<string> >();
+ open_sockets(*s, hosts[0]);
+ }
+
+ learner ret = {s,drive_send,learn,finish,save_load};
+ all.l = ret;
+}
}
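sender is the smallest example of the new registration recipe: allocate the reduction's state with calloc, fill a learner with that state plus the callbacks, and install it as all.l; searn above and wap below do the same. A hypothetical skeleton for a fresh reduction following this pattern (the my_* names are placeholders):

struct my_reduction { int example_count; /* per-reduction state */ };

void my_drive(void* in, void* d) { /* pull examples and learn on them */ }
void my_learn(void* in, void* d, example* ec) { ((my_reduction*)d)->example_count++; }
void my_finish(void* in, void* d) { free(d); }
void my_save_load(void* in, void* d, io_buf& mf, bool read, bool text) {}

void my_parse_flags(vw& all) {
  my_reduction* m = (my_reduction*)calloc(1, sizeof(my_reduction));
  learner l = {m, my_drive, my_learn, my_finish, my_save_load};
  all.l = l; // main() later hands control to all.l.driver
}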
diff --git a/vowpalwabbit/sender.h b/vowpalwabbit/sender.h
index 57b727cd..a43cfed9 100644
--- a/vowpalwabbit/sender.h
+++ b/vowpalwabbit/sender.h
@@ -4,7 +4,5 @@ individual contributors. All rights reserved. Released under a BSD
license as described in the file LICENSE.
*/
namespace SENDER{
-void parse_send_args(po::variables_map& vm, std::vector<std::string> pairs);
-void drive_send(void*);
- void save_load(void* in, io_buf& model_file, bool read, bool text);
+ void parse_send_args(vw& all, po::variables_map& vm, vector<string> pairs);
}
diff --git a/vowpalwabbit/vw.cc b/vowpalwabbit/vw.cc
index 0753816c..372bf824 100644
--- a/vowpalwabbit/vw.cc
+++ b/vowpalwabbit/vw.cc
@@ -44,7 +44,7 @@ int main(int argc, char *argv[])
start_parser(all);
- all.driver(&all);
+  all.l.driver(&all, all.l.data);
end_parser(all);
diff --git a/vowpalwabbit/wap.cc b/vowpalwabbit/wap.cc
index 90b8103b..e45e65e2 100644
--- a/vowpalwabbit/wap.cc
+++ b/vowpalwabbit/wap.cc
@@ -17,9 +17,11 @@ license as described in the file LICENSE.
using namespace std;
namespace WAP {
- //nonreentrant
- uint32_t increment=0;
-
+ struct wap{
+ uint32_t increment;
+ learner base;
+ };
+
void mirror_features(vw& all, example* ec, uint32_t offset1, uint32_t offset2)
{
for (unsigned char* i = ec->indices.begin; i != ec->indices.end; i++)
@@ -128,7 +130,7 @@ namespace WAP {
void (*base_learner)(void*, example*) = NULL;
- void train(vw& all, example* ec)
+ void train(vw& all, wap& w, example* ec)
{
CSOAA::label* ld = (CSOAA::label*)ec->ld;
@@ -182,10 +184,10 @@ namespace WAP {
uint32_t myi = (uint32_t)vs[i].ci.weight_index;
uint32_t myj = (uint32_t)vs[j].ci.weight_index;
- mirror_features(all, ec,(myi-1)*increment, (myj-1)*increment);
+ mirror_features(all, ec,(myi-1)*w.increment, (myj-1)*w.increment);
base_learner(&all, ec);
- unmirror_features(all, ec,(myi-1)*increment, (myj-1)*increment);
+ unmirror_features(all, ec,(myi-1)*w.increment, (myj-1)*w.increment);
}
}
@@ -193,7 +195,7 @@ namespace WAP {
ec->ld = ld;
}
- size_t test(vw& all, example* ec)
+ size_t test(vw& all, wap& w, example* ec)
{
size_t prediction = 1;
float score = -FLT_MAX;
@@ -208,12 +210,12 @@ namespace WAP {
simple_temp.label = FLT_MAX;
uint32_t myi = (uint32_t)cost_label->costs[i].weight_index;
if (myi!= 1)
- update_example_indicies(all.audit, ec, increment*(myi-1));
+ update_example_indicies(all.audit, ec, w.increment*(myi-1));
ec->partial_prediction = 0.;
ec->ld = &simple_temp;
base_learner(&all, ec);
if (myi != 1)
- update_example_indicies(all.audit, ec, -increment*(myi-1));
+ update_example_indicies(all.audit, ec, -w.increment*(myi-1));
if (ec->partial_prediction > score)
{
score = ec->partial_prediction;
@@ -224,21 +226,29 @@ namespace WAP {
return prediction;
}
- void learn(void* a, example* ec)
+ void learn(void* a, void* d, example* ec)
{
vw* all = (vw*)a;
CSOAA::label* cost_label = (CSOAA::label*)ec->ld;
+ wap* w = (wap*)d;
- size_t prediction = test(*all, ec);
+ size_t prediction = test(*all, *w, ec);
ec->ld = cost_label;
if (cost_label->costs.size() > 0)
- train(*all, ec);
+ train(*all, *w, ec);
*(OAA::prediction_t*)&(ec->final_prediction) = prediction;
}
-
- void drive_wap(void* in)
+
+ void finish(void* a, void* d)
+ {
+ wap* w=(wap*)d;
+ w->base.finish(a,w->base.data);
+ free(w);
+ }
+
+ void drive(void* in, void* d)
{
vw* all = (vw*)in;
example* ec = NULL;
@@ -246,22 +256,20 @@ namespace WAP {
{
if ((ec = get_example(all->p)) != NULL)//semiblocking operation.
{
- learn(all,ec);
+ learn(all, d, ec);
CSOAA::output_example(*all, ec);
VW::finish_example(*all, ec);
}
else if (parser_done(all->p))
- {
- all->finish(all);
- return;
- }
+ return;
else
;
}
}
-
+
void parse_flags(vw& all, std::vector<std::string>&, po::variables_map& vm, po::variables_map& vm_file)
{
+ wap* w=(wap*)calloc(1,sizeof(wap));
uint32_t nb_actions = 0;
if( vm_file.count("wap") ) { //if loaded options from regressor
nb_actions = (uint32_t)vm_file["wap"].as<size_t>();
@@ -281,11 +289,9 @@ namespace WAP {
all.sd->k = (uint32_t)nb_actions;
all.base_learner_nb_w *= nb_actions;
- increment = (uint32_t)((all.length()/ all.base_learner_nb_w) * all.stride);
+ w->increment = (uint32_t)((all.length()/ all.base_learner_nb_w) * all.stride);
- all.driver = drive_wap;
- base_learner = all.learn;
- all.base_learn = all.learn;
- all.learn = learn;
+ learner l = {w, drive, learn, finish, all.l.save_load};
+ all.l = l;
}
}
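Taken together, the commit's control flow is: parse_flags builds a learner and assigns all.l; main (vw.cc hunk above) calls all.l.driver(&all, all.l.data); and every driver here runs the same semiblocking loop against the parser. The canonical loop, assuming the reduction casts d back to its own state:

// The driver loop shared by searn, sender, and wap in this commit.
void generic_drive(void* in, void* d) {
  vw* all = (vw*)in;
  example* ec = NULL;
  while (true) {
    if ((ec = get_example(all->p)) != NULL) { // semiblocking operation
      // learn on ec using the state behind d, then finish the example
    } else if (parser_done(all->p))
      return; // parser drained: training/prediction is complete
    // otherwise spin until the parser yields another example
  }
}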