author | John Langford <jl@hunch.net> | 2013-01-21 19:47:03 +0400
---|---|---
committer | John Langford <jl@hunch.net> | 2013-01-21 19:47:03 +0400
commit | f8d453e6eec1067e411271f2c208f001ec32237e (patch) |
tree | 757ff6bfe1d1b928d3c035c6e794a6577d47b047 |
parent | 85a5725045ca861feae61ab65b7a4593f6f81615 (diff) |
first compiling version
31 files changed, 1851 insertions(+), 1818 deletions(-)
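The commit message is terse, but the shape of the change is consistent across files: per-algorithm globals move into structs (`bfgs`, `cb`, `csoaa`, `ldf`), and every entry point gains a second `void*` argument carrying that state. The glue is the new `vowpalwabbit/learner.h` (added to Makefile.am below). A minimal sketch of what it plausibly declares, inferred from the aggregate initializers in this diff (`learner t = {b, drive, learn, finish, save_load}` in bfgs.cc, `learner l = {c, drive, learn, finish, all.l.save_load}` in cb.cc and csoaa.cc) and from the new callback signatures — the actual header may differ:

```cpp
// Hypothetical reconstruction of vowpalwabbit/learner.h as of this commit;
// field order follows the aggregate initializers seen in the diff.
struct io_buf;   // from the surrounding codebase
struct example;

struct learner {
  void* data;                                          // reduction-private state
  void (*drive)(void* all, void* data);                // main example loop
  void (*learn)(void* all, void* data, example* ec);   // train/predict one example
  void (*finish)(void* all, void* data);               // flush and free state
  void (*save_load)(void* all, void* data, io_buf& model_file,
                    bool read, bool text);             // (de)serialize the model
};
```

With this in place, `all.l` replaces the old `all.learn`/`all.finish`/`all.driver` function pointers, and stacked reductions (CB, CSOAA) keep the learner beneath them as a `base` member, forwarding through `c->base.learn(all, c->base.data, ec)` instead of the file-scope `base_learner`/`base_finish` pointers the diff deletes.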
diff --git a/Makefile.am b/Makefile.am index 7e405360..1750ebe3 100644 --- a/Makefile.am +++ b/Makefile.am @@ -17,7 +17,7 @@ noinst_HEADERS = vowpalwabbit/accumulate.h vowpalwabbit/oaa.h \ vowpalwabbit/v_array.h vowpalwabbit/lda_core.h \ vowpalwabbit/v_hashmap.h vowpalwabbit/loss_functions.h \ vowpalwabbit/network.h vowpalwabbit/wap.h vowpalwabbit/noop.h \ - vowpalwabbit/nn.h + vowpalwabbit/nn.h vowpalwabbit/learner.h ACLOCAL_AMFLAGS = -I acinclude.d diff --git a/vowpalwabbit/bfgs.cc b/vowpalwabbit/bfgs.cc index d863d9e7..2d3febd1 100644 --- a/vowpalwabbit/bfgs.cc +++ b/vowpalwabbit/bfgs.cc @@ -61,42 +61,43 @@ namespace BFGS { //nonrentrant - -double wolfe1_bound = 0.01; - -struct timeb t_start, t_end; -double net_comm_time = 0.0; - -struct timeb t_start_global, t_end_global; -double net_time; - -v_array<float> predictions; -size_t example_number=0; -size_t current_pass = 0; - - // default transition behavior -bool first_hessian_on=true; -bool backstep_on=false; - - // set by initializer -int mem_stride; -bool output_regularizer; -float* mem; -double* rho; -double* alpha; - - weight* regularizers = NULL; - // the below needs to be included when resetting, in addition to preconditioner and derivative - int lastj, origin; - double loss_sum, previous_loss_sum; - float step_size; - double importance_weight_sum; - double curvature; - - // first pass specification - bool first_pass=true; -bool gradient_pass=true; -bool preconditioner_pass=true; + struct bfgs { + double wolfe1_bound; + + struct timeb t_start, t_end; + double net_comm_time; + + struct timeb t_start_global, t_end_global; + double net_time; + + v_array<float> predictions; + size_t example_number; + size_t current_pass; + + // default transition behavior + bool first_hessian_on; + bool backstep_on; + + // set by initializer + int mem_stride; + bool output_regularizer; + float* mem; + double* rho; + double* alpha; + + weight* regularizers; + // the below needs to be included when resetting, in addition to preconditioner and derivative + int lastj, origin; + double loss_sum, previous_loss_sum; + float step_size; + double importance_weight_sum; + double curvature; + + // first pass specification + bool first_pass; + bool gradient_pass; + bool preconditioner_pass; + }; const char* curv_message = "Zero or negative curvature detected.\n" "To increase curvature you can increase regularization or rescale features.\n" @@ -121,15 +122,15 @@ void zero_preconditioner(vw& all) weights[stride*i+W_COND] = 0; } -void reset_state(vw& all, bool zero) +void reset_state(vw& all, bfgs& b, bool zero) { - lastj = origin = 0; - loss_sum = previous_loss_sum = 0.; - importance_weight_sum = 0.; - curvature = 0.; - first_pass = true; - gradient_pass = true; - preconditioner_pass = true; + b.lastj = b.origin = 0; + b.loss_sum = b.previous_loss_sum = 0.; + b.importance_weight_sum = 0.; + b.curvature = 0.; + b.first_pass = true; + b.gradient_pass = true; + b.preconditioner_pass = true; if (zero) { zero_derivative(all); @@ -300,7 +301,7 @@ float dot_with_direction(vw& all, example* &ec) return ret; } -double regularizer_direction_magnitude(vw& all, float regularizer) +double regularizer_direction_magnitude(vw& all, bfgs& b, float regularizer) {//compute direction magnitude double ret = 0.; @@ -310,12 +311,12 @@ double regularizer_direction_magnitude(vw& all, float regularizer) uint32_t length = 1 << all.num_bits; size_t stride = all.stride; weight* weights = all.reg.weight_vector; - if (regularizers == NULL) + if (b.regularizers == NULL) for(uint32_t i = 0; i < 
length; i++) ret += regularizer*weights[stride*i+W_DIR]*weights[stride*i+W_DIR]; else for(uint32_t i = 0; i < length; i++) - ret += regularizers[2*i]*weights[stride*i+W_DIR]*weights[stride*i+W_DIR]; + ret += b.regularizers[2*i]*weights[stride*i+W_DIR]*weights[stride*i+W_DIR]; return ret; } @@ -332,7 +333,7 @@ float direction_magnitude(vw& all) return (float)ret; } - void bfgs_iter_start(vw& all, float* mem, int& lastj, double importance_weight_sum, int&origin) +void bfgs_iter_start(vw& all, bfgs& b, float* mem, int& lastj, double importance_weight_sum, int&origin) { uint32_t length = 1 << all.num_bits; size_t stride = all.stride; @@ -342,10 +343,10 @@ float direction_magnitude(vw& all) double g1_g1 = 0.; origin = 0; - for(uint32_t i = 0; i < length; i++, mem+=mem_stride, w+=stride) { + for(uint32_t i = 0; i < length; i++, mem+=b.mem_stride, w+=stride) { if (all.m>0) - mem[(MEM_XT+origin)%mem_stride] = w[W_XT]; - mem[(MEM_GT+origin)%mem_stride] = w[W_GT]; + mem[(MEM_XT+origin)%b.mem_stride] = w[W_XT]; + mem[(MEM_GT+origin)%b.mem_stride] = w[W_GT]; g1_Hg1 += w[W_GT] * w[W_GT] * w[W_COND]; g1_g1 += w[W_GT] * w[W_GT]; w[W_DIR] = -w[W_COND]*w[W_GT]; @@ -358,7 +359,7 @@ float direction_magnitude(vw& all) g1_Hg1/importance_weight_sum, "", "", ""); } -void bfgs_iter_middle(vw& all, float* mem, double* rho, double* alpha, int& lastj, int &origin) +void bfgs_iter_middle(vw& all, bfgs& b, float* mem, double* rho, double* alpha, int& lastj, int &origin) { uint32_t length = 1 << all.num_bits; size_t stride = all.stride; @@ -373,10 +374,10 @@ void bfgs_iter_middle(vw& all, float* mem, double* rho, double* alpha, int& last double g_Hg = 0.; double y = 0.; - for(uint32_t i = 0; i < length; i++, mem+=mem_stride, w+=stride) { - y = w[W_GT]-mem[(MEM_GT+origin)%mem_stride]; + for(uint32_t i = 0; i < length; i++, mem+=b.mem_stride, w+=stride) { + y = w[W_GT]-mem[(MEM_GT+origin)%b.mem_stride]; g_Hy += w[W_GT] * w[W_COND] * y; - g_Hg += mem[(MEM_GT+origin)%mem_stride] * w[W_COND] * mem[(MEM_GT+origin)%mem_stride]; + g_Hg += mem[(MEM_GT+origin)%b.mem_stride] * w[W_COND] * mem[(MEM_GT+origin)%b.mem_stride]; } float beta = (float) (g_Hy/g_Hg); @@ -386,8 +387,8 @@ void bfgs_iter_middle(vw& all, float* mem, double* rho, double* alpha, int& last mem = mem0; w = w0; - for(uint32_t i = 0; i < length; i++, mem+=mem_stride, w+=stride) { - mem[(MEM_GT+origin)%mem_stride] = w[W_GT]; + for(uint32_t i = 0; i < length; i++, mem+=b.mem_stride, w+=stride) { + mem[(MEM_GT+origin)%b.mem_stride] = w[W_GT]; w[W_DIR] *= beta; w[W_DIR] -= w[W_COND]*w[W_GT]; @@ -407,13 +408,13 @@ void bfgs_iter_middle(vw& all, float* mem, double* rho, double* alpha, int& last double y_Hy = 0.; double s_q = 0.; - for(uint32_t i = 0; i < length; i++, mem+=mem_stride, w+=stride) { - mem[(MEM_YT+origin)%mem_stride] = w[W_GT] - mem[(MEM_GT+origin)%mem_stride]; - mem[(MEM_ST+origin)%mem_stride] = w[W_XT] - mem[(MEM_XT+origin)%mem_stride]; + for(uint32_t i = 0; i < length; i++, mem+=b.mem_stride, w+=stride) { + mem[(MEM_YT+origin)%b.mem_stride] = w[W_GT] - mem[(MEM_GT+origin)%b.mem_stride]; + mem[(MEM_ST+origin)%b.mem_stride] = w[W_XT] - mem[(MEM_XT+origin)%b.mem_stride]; w[W_DIR] = w[W_GT]; - y_s += mem[(MEM_YT+origin)%mem_stride]*mem[(MEM_ST+origin)%mem_stride]; - y_Hy += mem[(MEM_YT+origin)%mem_stride]*mem[(MEM_YT+origin)%mem_stride]*w[W_COND]; - s_q += mem[(MEM_ST+origin)%mem_stride]*w[W_GT]; + y_s += mem[(MEM_YT+origin)%b.mem_stride]*mem[(MEM_ST+origin)%b.mem_stride]; + y_Hy += 
mem[(MEM_YT+origin)%b.mem_stride]*mem[(MEM_YT+origin)%b.mem_stride]*w[W_COND]; + s_q += mem[(MEM_ST+origin)%b.mem_stride]*w[W_GT]; } if (y_s <= 0. || y_Hy <= 0.) @@ -427,9 +428,9 @@ void bfgs_iter_middle(vw& all, float* mem, double* rho, double* alpha, int& last s_q = 0.; mem = mem0; w = w0; - for(uint32_t i = 0; i < length; i++, mem+=mem_stride, w+=stride) { - w[W_DIR] -= (float)alpha[j]*mem[(2*j+MEM_YT+origin)%mem_stride]; - s_q += mem[(2*j+2+MEM_ST+origin)%mem_stride]*w[W_DIR]; + for(uint32_t i = 0; i < length; i++, mem+=b.mem_stride, w+=stride) { + w[W_DIR] -= (float)alpha[j]*mem[(2*j+MEM_YT+origin)%b.mem_stride]; + s_q += mem[(2*j+2+MEM_ST+origin)%b.mem_stride]*w[W_DIR]; } } @@ -437,10 +438,10 @@ void bfgs_iter_middle(vw& all, float* mem, double* rho, double* alpha, int& last double y_r = 0.; mem = mem0; w = w0; - for(uint32_t i = 0; i < length; i++, mem+=mem_stride, w+=stride) { - w[W_DIR] -= (float)alpha[lastj]*mem[(2*lastj+MEM_YT+origin)%mem_stride]; + for(uint32_t i = 0; i < length; i++, mem+=b.mem_stride, w+=stride) { + w[W_DIR] -= (float)alpha[lastj]*mem[(2*lastj+MEM_YT+origin)%b.mem_stride]; w[W_DIR] *= gamma*w[W_COND]; - y_r += mem[(2*lastj+MEM_YT+origin)%mem_stride]*w[W_DIR]; + y_r += mem[(2*lastj+MEM_YT+origin)%b.mem_stride]*w[W_DIR]; } double coef_j; @@ -450,17 +451,17 @@ void bfgs_iter_middle(vw& all, float* mem, double* rho, double* alpha, int& last y_r = 0.; mem = mem0; w = w0; - for(uint32_t i = 0; i < length; i++, mem+=mem_stride, w+=stride) { - w[W_DIR] += (float)coef_j*mem[(2*j+MEM_ST+origin)%mem_stride]; - y_r += mem[(2*j-2+MEM_YT+origin)%mem_stride]*w[W_DIR]; + for(uint32_t i = 0; i < length; i++, mem+=b.mem_stride, w+=stride) { + w[W_DIR] += (float)coef_j*mem[(2*j+MEM_ST+origin)%b.mem_stride]; + y_r += mem[(2*j-2+MEM_YT+origin)%b.mem_stride]*w[W_DIR]; } } coef_j = alpha[0] - rho[0] * y_r; mem = mem0; w = w0; - for(uint32_t i = 0; i < length; i++, mem+=mem_stride, w+=stride) { - w[W_DIR] = -w[W_DIR]-(float)coef_j*mem[(MEM_ST+origin)%mem_stride]; + for(uint32_t i = 0; i < length; i++, mem+=b.mem_stride, w+=stride) { + w[W_DIR] = -w[W_DIR]-(float)coef_j*mem[(MEM_ST+origin)%b.mem_stride]; } /********************* @@ -470,17 +471,17 @@ void bfgs_iter_middle(vw& all, float* mem, double* rho, double* alpha, int& last mem = mem0; w = w0; lastj = (lastj<all.m-1) ? 
lastj+1 : all.m-1; - origin = (origin+mem_stride-2)%mem_stride; - for(uint32_t i = 0; i < length; i++, mem+=mem_stride, w+=stride) { - mem[(MEM_GT+origin)%mem_stride] = w[W_GT]; - mem[(MEM_XT+origin)%mem_stride] = w[W_XT]; + origin = (origin+b.mem_stride-2)%b.mem_stride; + for(uint32_t i = 0; i < length; i++, mem+=b.mem_stride, w+=stride) { + mem[(MEM_GT+origin)%b.mem_stride] = w[W_GT]; + mem[(MEM_XT+origin)%b.mem_stride] = w[W_XT]; w[W_GT] = 0; } for (int j=lastj; j>0; j--) rho[j] = rho[j-1]; } -double wolfe_eval(vw& all, float* mem, double loss_sum, double previous_loss_sum, double step_size, double importance_weight_sum, int &origin, double& wolfe1) { +double wolfe_eval(vw& all, bfgs& b, float* mem, double loss_sum, double previous_loss_sum, double step_size, double importance_weight_sum, int &origin, double& wolfe1) { uint32_t length = 1 << all.num_bits; size_t stride = all.stride; weight* w = all.reg.weight_vector; @@ -490,8 +491,8 @@ double wolfe_eval(vw& all, float* mem, double loss_sum, double previous_loss_sum double g1_Hg1 = 0.; double g1_g1 = 0.; - for(uint32_t i = 0; i < length; i++, mem+=mem_stride, w+=stride) { - g0_d += mem[(MEM_GT+origin)%mem_stride] * w[W_DIR]; + for(uint32_t i = 0; i < length; i++, mem+=b.mem_stride, w+=stride) { + g0_d += mem[(MEM_GT+origin)%b.mem_stride] * w[W_DIR]; g1_d += w[W_GT] * w[W_DIR]; g1_Hg1 += w[W_GT] * w[W_GT] * w[W_COND]; g1_g1 += w[W_GT] * w[W_GT]; @@ -507,13 +508,13 @@ double wolfe_eval(vw& all, float* mem, double loss_sum, double previous_loss_sum } -double add_regularization(vw& all, float regularization) +double add_regularization(vw& all, bfgs& b, float regularization) {//compute the derivative difference double ret = 0.; uint32_t length = 1 << all.num_bits; size_t stride = all.stride; weight* weights = all.reg.weight_vector; - if (regularizers == NULL) + if (b.regularizers == NULL) { for(uint32_t i = 0; i < length; i++) { weights[stride*i+W_GT] += regularization*weights[stride*i]; @@ -523,21 +524,21 @@ double add_regularization(vw& all, float regularization) else { for(uint32_t i = 0; i < length; i++) { - weight delta_weight = weights[stride*i] - regularizers[2*i+1]; - weights[stride*i+W_GT] += regularizers[2*i]*delta_weight; - ret += 0.5*regularizers[2*i]*delta_weight*delta_weight; + weight delta_weight = weights[stride*i] - b.regularizers[2*i+1]; + weights[stride*i+W_GT] += b.regularizers[2*i]*delta_weight; + ret += 0.5*b.regularizers[2*i]*delta_weight*delta_weight; } } return ret; } -void finalize_preconditioner(vw& all, float regularization) +void finalize_preconditioner(vw& all, bfgs& b, float regularization) { uint32_t length = 1 << all.num_bits; size_t stride = all.stride; weight* weights = all.reg.weight_vector; - if (regularizers == NULL) + if (b.regularizers == NULL) for(uint32_t i = 0; i < length; i++) { weights[stride*i+W_COND] += regularization; if (weights[stride*i+W_COND] > 0) @@ -545,34 +546,34 @@ void finalize_preconditioner(vw& all, float regularization) } else for(uint32_t i = 0; i < length; i++) { - weights[stride*i+W_COND] += regularizers[2*i]; + weights[stride*i+W_COND] += b.regularizers[2*i]; if (weights[stride*i+W_COND] > 0) weights[stride*i+W_COND] = 1.f / weights[stride*i+W_COND]; } } -void preconditioner_to_regularizer(vw& all, float regularization) +void preconditioner_to_regularizer(vw& all, bfgs& b, float regularization) { uint32_t length = 1 << all.num_bits; size_t stride = all.stride; weight* weights = all.reg.weight_vector; - if (regularizers == NULL) + if (b.regularizers == NULL) { - regularizers = 
(weight *)calloc(2*length, sizeof(weight)); + b.regularizers = (weight *)calloc(2*length, sizeof(weight)); - if (regularizers == NULL) + if (b.regularizers == NULL) { cerr << all.program_name << ": Failed to allocate weight array: try decreasing -b <bits>" << endl; exit (1); } for(uint32_t i = 0; i < length; i++) - regularizers[2*i] = weights[stride*i+W_COND] + regularization; + b.regularizers[2*i] = weights[stride*i+W_COND] + regularization; } else for(uint32_t i = 0; i < length; i++) - regularizers[2*i] = weights[stride*i+W_COND] + regularizers[2*i]; + b.regularizers[2*i] = weights[stride*i+W_COND] + b.regularizers[2*i]; for(uint32_t i = 0; i < length; i++) - regularizers[2*i+1] = weights[stride*i]; + b.regularizers[2*i+1] = weights[stride*i]; } void zero_state(vw& all) @@ -588,15 +589,15 @@ void zero_state(vw& all) } } -double derivative_in_direction(vw& all, float* mem, int &origin) +double derivative_in_direction(vw& all, bfgs& b, float* mem, int &origin) { double ret = 0.; uint32_t length = 1 << all.num_bits; size_t stride = all.stride; weight* w = all.reg.weight_vector; - for(uint32_t i = 0; i < length; i++, w+=stride, mem+=mem_stride) - ret += mem[(MEM_GT+origin)%mem_stride]*w[W_DIR]; + for(uint32_t i = 0; i < length; i++, w+=stride, mem+=b.mem_stride) + ret += mem[(MEM_GT+origin)%b.mem_stride]*w[W_DIR]; return ret; } @@ -611,63 +612,66 @@ void update_weight(vw& all, string& reg_name, float step_size, size_t current_pa save_predictor(all, reg_name, current_pass); } -int process_pass(vw& all) { +int process_pass(vw& all, bfgs& b) { int status = LEARN_OK; /********************************************************************/ /* A) FIRST PASS FINISHED: INITIALIZE FIRST LINE SEARCH *************/ /********************************************************************/ - if (first_pass) { + if (b.first_pass) { if(all.span_server != "") { accumulate(all, all.span_server, all.reg, W_COND); //Accumulate preconditioner - importance_weight_sum = accumulate_scalar(all, all.span_server, (float)importance_weight_sum); + float temp = b.importance_weight_sum; + b.importance_weight_sum = accumulate_scalar(all, all.span_server, temp); } - finalize_preconditioner(all, all.l2_lambda); + finalize_preconditioner(all, b, all.l2_lambda); if(all.span_server != "") { - loss_sum = accumulate_scalar(all, all.span_server, (float)loss_sum); //Accumulate loss_sums + float temp = (float)b.loss_sum; + b.loss_sum = accumulate_scalar(all, all.span_server, temp); //Accumulate loss_sums accumulate(all, all.span_server, all.reg, 1); //Accumulate gradients from all nodes } if (all.l2_lambda > 0.) 
- loss_sum += add_regularization(all, all.l2_lambda); + b.loss_sum += add_regularization(all, b, all.l2_lambda); if (!all.quiet) - fprintf(stderr, "%2lu %-10.5f\t", (long unsigned int)current_pass+1, loss_sum / importance_weight_sum); + fprintf(stderr, "%2lu %-10.5f\t", (long unsigned int)b.current_pass+1, b.loss_sum / b.importance_weight_sum); - previous_loss_sum = loss_sum; - loss_sum = 0.; - example_number = 0; - curvature = 0; - bfgs_iter_start(all, mem, lastj, importance_weight_sum, origin); - if (first_hessian_on) { - gradient_pass = false;//now start computing curvature + b.previous_loss_sum = b.loss_sum; + b.loss_sum = 0.; + b.example_number = 0; + b.curvature = 0; + bfgs_iter_start(all, b, b.mem, b.lastj, b.importance_weight_sum, b.origin); + if (b.first_hessian_on) { + b.gradient_pass = false;//now start computing curvature } else { - step_size = 0.5; + b.step_size = 0.5; float d_mag = direction_magnitude(all); - ftime(&t_end_global); - net_time = (int) (1000.0 * (t_end_global.time - t_start_global.time) + (t_end_global.millitm - t_start_global.millitm)); + ftime(&b.t_end_global); + b.net_time = (int) (1000.0 * (b.t_end_global.time - b.t_start_global.time) + (b.t_end_global.millitm - b.t_start_global.millitm)); if (!all.quiet) - fprintf(stderr, "%-10s\t%-10.5f\t%-10.5f\n", "", d_mag, step_size); - predictions.erase(); - update_weight(all, all.final_regressor_name, step_size, current_pass); } + fprintf(stderr, "%-10s\t%-10.5f\t%-10.5f\n", "", d_mag, b.step_size); + b.predictions.erase(); + update_weight(all, all.final_regressor_name, b.step_size, b.current_pass); } } else /********************************************************************/ /* B) GRADIENT CALCULATED *******************************************/ /********************************************************************/ - if (gradient_pass) // We just finished computing all gradients + if (b.gradient_pass) // We just finished computing all gradients { if(all.span_server != "") { - loss_sum = accumulate_scalar(all, all.span_server, (float)loss_sum); //Accumulate loss_sums + float t = (float)b.loss_sum; + b.loss_sum = accumulate_scalar(all, all.span_server, t); //Accumulate loss_sums accumulate(all, all.span_server, all.reg, 1); //Accumulate gradients from all nodes } if (all.l2_lambda > 0.) 
- loss_sum += add_regularization(all, all.l2_lambda); + b.loss_sum += add_regularization(all, b, all.l2_lambda); if (!all.quiet) - fprintf(stderr, "%2lu %-10.5f\t", (long unsigned int)current_pass+1, loss_sum / importance_weight_sum); + fprintf(stderr, "%2lu %-10.5f\t", (long unsigned int)b.current_pass+1, b.loss_sum / b.importance_weight_sum); double wolfe1; - double new_step = wolfe_eval(all, mem, loss_sum, previous_loss_sum, step_size, importance_weight_sum, origin, wolfe1); + double new_step = wolfe_eval(all, b, b.mem, b.loss_sum, b.previous_loss_sum, b.step_size, b.importance_weight_sum, b.origin, wolfe1); /********************************************************************/ /* B0) DERIVATIVE ZERO: MINIMUM FOUND *******************************/ @@ -676,26 +680,26 @@ int process_pass(vw& all) { { fprintf(stderr, "\n"); fprintf(stdout, "Derivative 0 detected.\n"); - step_size=0.0; + b.step_size=0.0; status = LEARN_CONV; } /********************************************************************/ /* B1) LINE SEARCH FAILED *******************************************/ /********************************************************************/ - else if (backstep_on && (wolfe1<wolfe1_bound || loss_sum > previous_loss_sum)) + else if (b.backstep_on && (wolfe1<b.wolfe1_bound || b.loss_sum > b.previous_loss_sum)) {// curvature violated, or we stepped too far last time: step back - ftime(&t_end_global); - net_time = (int) (1000.0 * (t_end_global.time - t_start_global.time) + (t_end_global.millitm - t_start_global.millitm)); - float ratio = (step_size==0.f) ? 0.f : (float)new_step/(float)step_size; + ftime(&b.t_end_global); + b.net_time = (int) (1000.0 * (b.t_end_global.time - b.t_start_global.time) + (b.t_end_global.millitm - b.t_start_global.millitm)); + float ratio = (b.step_size==0.f) ? 
0.f : (float)new_step/(float)b.step_size; if (!all.quiet) fprintf(stderr, "%-10s\t%-10s\t(revise x %.1f)\t%-10.5f\n", "","",ratio, new_step); - predictions.erase(); - update_weight(all, all.final_regressor_name, (float)(-step_size+new_step), current_pass); - step_size = (float)new_step; + b.predictions.erase(); + update_weight(all, all.final_regressor_name, (float)(-b.step_size+new_step), b.current_pass); + b.step_size = (float)new_step; zero_derivative(all); - loss_sum = 0.; + b.loss_sum = 0.; } /********************************************************************/ @@ -703,38 +707,38 @@ int process_pass(vw& all) { /* DETERMINE NEXT SEARCH DIRECTION ******************/ /********************************************************************/ else { - double rel_decrease = (previous_loss_sum-loss_sum)/previous_loss_sum; - if (!nanpattern((float)rel_decrease) && backstep_on && fabs(rel_decrease)<all.rel_threshold) { + double rel_decrease = (b.previous_loss_sum-b.loss_sum)/b.previous_loss_sum; + if (!nanpattern((float)rel_decrease) && b.backstep_on && fabs(rel_decrease)<all.rel_threshold) { fprintf(stdout, "\nTermination condition reached in pass %ld: decrease in loss less than %.3f%%.\n" - "If you want to optimize further, decrease termination threshold.\n", (long int)current_pass+1, all.rel_threshold*100.0); + "If you want to optimize further, decrease termination threshold.\n", (long int)b.current_pass+1, all.rel_threshold*100.0); status = LEARN_CONV; } - previous_loss_sum = loss_sum; - loss_sum = 0.; - example_number = 0; - curvature = 0; - step_size = 1.0; + b.previous_loss_sum = b.loss_sum; + b.loss_sum = 0.; + b.example_number = 0; + b.curvature = 0; + b.step_size = 1.0; try { - bfgs_iter_middle(all, mem, rho, alpha, lastj, origin); + bfgs_iter_middle(all, b, b.mem, b.rho, b.alpha, b.lastj, b.origin); } catch (curv_exception e) { fprintf(stdout, "In bfgs_iter_middle: %s", curv_message); - step_size=0.0; + b.step_size=0.0; status = LEARN_CURV; } if (all.hessian_on) { - gradient_pass = false;//now start computing curvature + b.gradient_pass = false;//now start computing curvature } else { float d_mag = direction_magnitude(all); - ftime(&t_end_global); - net_time = (int) (1000.0 * (t_end_global.time - t_start_global.time) + (t_end_global.millitm - t_start_global.millitm)); + ftime(&b.t_end_global); + b.net_time = (int) (1000.0 * (b.t_end_global.time - b.t_start_global.time) + (b.t_end_global.millitm - b.t_start_global.millitm)); if (!all.quiet) - fprintf(stderr, "%-10s\t%-10.5f\t%-10.5f\n", "", d_mag, step_size); - predictions.erase(); - update_weight(all, all.final_regressor_name, step_size, current_pass); + fprintf(stderr, "%-10s\t%-10.5f\t%-10.5f\n", "", d_mag, b.step_size); + b.predictions.erase(); + update_weight(all, all.final_regressor_name, b.step_size, b.current_pass); } } } @@ -745,57 +749,58 @@ int process_pass(vw& all) { else // just finished all second gradients { if(all.span_server != "") { - curvature = accumulate_scalar(all, all.span_server, (float)curvature); //Accumulate curvatures + float t = (float)b.curvature; + b.curvature = accumulate_scalar(all, all.span_server, t); //Accumulate curvatures } if (all.l2_lambda > 0.) - curvature += regularizer_direction_magnitude(all, all.l2_lambda); - float dd = (float)derivative_in_direction(all, mem, origin); - if (curvature == 0. && dd != 0.) + b.curvature += regularizer_direction_magnitude(all, b, all.l2_lambda); + float dd = (float)derivative_in_direction(all, b, b.mem, b.origin); + if (b.curvature == 0. && dd != 0.) 
{ fprintf(stdout, "%s", curv_message); - step_size=0.0; + b.step_size=0.0; status = LEARN_CURV; } else if ( dd == 0.) { fprintf(stdout, "Derivative 0 detected.\n"); - step_size=0.0; + b.step_size=0.0; status = LEARN_CONV; } else - step_size = - dd/(float)curvature; + b.step_size = - dd/(float)b.curvature; float d_mag = direction_magnitude(all); - predictions.erase(); - update_weight(all, all.final_regressor_name , step_size, current_pass); - ftime(&t_end_global); - net_time = (int) (1000.0 * (t_end_global.time - t_start_global.time) + (t_end_global.millitm - t_start_global.millitm)); + b.predictions.erase(); + update_weight(all, all.final_regressor_name , b.step_size, b.current_pass); + ftime(&b.t_end_global); + b.net_time = (int) (1000.0 * (b.t_end_global.time - b.t_start_global.time) + (b.t_end_global.millitm - b.t_start_global.millitm)); if (!all.quiet) - fprintf(stderr, "%-10.5f\t%-10.5f\t%-10.5f\n", curvature / importance_weight_sum, d_mag, step_size); - gradient_pass = true; + fprintf(stderr, "%-10.5f\t%-10.5f\t%-10.5f\n", b.curvature / b.importance_weight_sum, d_mag, b.step_size); + b.gradient_pass = true; }//now start computing derivatives. - current_pass++; - first_pass = false; - preconditioner_pass = false; + b.current_pass++; + b.first_pass = false; + b.preconditioner_pass = false; return status; } -void process_example(vw& all, example *ec) +void process_example(vw& all, bfgs& b, example *ec) { label_data* ld = (label_data*)ec->ld; - if (first_pass) - importance_weight_sum += ld->weight; + if (b.first_pass) + b.importance_weight_sum += ld->weight; /********************************************************************/ /* I) GRADIENT CALCULATION ******************************************/ /********************************************************************/ - if (gradient_pass) + if (b.gradient_pass) { ec->final_prediction = predict_and_gradient(all, ec);//w[0] & w[1] ec->loss = all.loss->getLoss(all.sd, ec->final_prediction, ld->label) * ld->weight; - loss_sum += ec->loss; - predictions.push_back(ec->final_prediction); + b.loss_sum += ec->loss; + b.predictions.push_back(ec->final_prediction); } /********************************************************************/ /* II) CURVATURE CALCULATION ****************************************/ @@ -803,62 +808,64 @@ void process_example(vw& all, example *ec) else //computing curvature { float d_dot_x = dot_with_direction(all, ec);//w[2] - if (example_number >= predictions.size())//Make things safe in case example source is strange. - example_number = predictions.size()-1; - ec->final_prediction = predictions[example_number]; - ec->partial_prediction = predictions[example_number]; + if (b.example_number >= b.predictions.size())//Make things safe in case example source is strange. 
+ b.example_number = b.predictions.size()-1; + ec->final_prediction = b.predictions[b.example_number]; + ec->partial_prediction = b.predictions[b.example_number]; ec->loss = all.loss->getLoss(all.sd, ec->final_prediction, ld->label) * ld->weight; - float sd = all.loss->second_derivative(all.sd, predictions[example_number++],ld->label); - curvature += d_dot_x*d_dot_x*sd*ld->weight; + float sd = all.loss->second_derivative(all.sd, b.predictions[b.example_number++],ld->label); + b.curvature += d_dot_x*d_dot_x*sd*ld->weight; } - if (preconditioner_pass) + if (b.preconditioner_pass) update_preconditioner(all, ec);//w[3] } -void learn(void* a, example* ec) +void learn(void* a, void* d, example* ec) { vw* all = (vw*)a; + bfgs* b = (bfgs*)d; assert(ec->in_use); - if (ec->pass != current_pass) { - int status = process_pass(*all); + if (ec->pass != b->current_pass) { + int status = process_pass(*all, *b); if (status != LEARN_OK) - reset_state(*all, true); - else if (output_regularizer && current_pass==all->numpasses-1) { + reset_state(*all, *b, true); + else if (b->output_regularizer && b->current_pass==all->numpasses-1) { zero_preconditioner(*all); - preconditioner_pass = true; + b->preconditioner_pass = true; } } if (test_example(ec)) ec->final_prediction = bfgs_predict(*all,ec);//w[0] else - process_example(*all, ec); + process_example(*all, *b, ec); } -void finish(void* a) +void finish(void* a, void* d) { vw* all = (vw*)a; - if (current_pass != 0 && !output_regularizer) - process_pass(*all); + bfgs* b = (bfgs*)d; + if (b->current_pass != 0 && !b->output_regularizer) + process_pass(*all, *b); if (!all->quiet) fprintf(stderr, "\n"); - if (output_regularizer)//need to accumulate and place the regularizer. + if (b->output_regularizer)//need to accumulate and place the regularizer. { if(all->span_server != "") accumulate(*all, all->span_server, all->reg, W_COND); //Accumulate preconditioner - preconditioner_to_regularizer(*all, all->l2_lambda); + preconditioner_to_regularizer(*all, *b, all->l2_lambda); } - ftime(&t_end_global); - net_time = (int) (1000.0 * (t_end_global.time - t_start_global.time) + (t_end_global.millitm - t_start_global.millitm)); + ftime(&b->t_end_global); + b->net_time = (int) (1000.0 * (b->t_end_global.time - b->t_start_global.time) + (b->t_end_global.millitm - b->t_start_global.millitm)); - predictions.delete_v(); - free(mem); - free(rho); - free(alpha); + b->predictions.delete_v(); + free(b->mem); + free(b->rho); + free(b->alpha); } -void save_load_regularizer(vw& all, io_buf& model_file, bool read, bool text) +void save_load_regularizer(vw& all, bfgs& b, io_buf& model_file, bool read, bool text) { char buff[512]; int c = 0; @@ -877,14 +884,14 @@ void save_load_regularizer(vw& all, io_buf& model_file, bool read, bool text) if (brw > 0) { assert (i< length); - v = &(regularizers[i]); + v = &(b.regularizers[i]); if (brw > 0) brw += bin_read_fixed(model_file, (char*)v, sizeof(*v), ""); } } else // write binary or text { - v = &(regularizers[i]); + v = &(b.regularizers[i]); if (*v != 0.) 
{ c++; @@ -906,9 +913,10 @@ void save_load_regularizer(vw& all, io_buf& model_file, bool read, bool text) } -void save_load(void* in, io_buf& model_file, bool read, bool text) +void save_load(void* in, void* d, io_buf& model_file, bool read, bool text) { vw* all = (vw*)in; + bfgs* b = (bfgs*)d; uint32_t length = 1 << all->num_bits; @@ -917,8 +925,8 @@ void save_load(void* in, io_buf& model_file, bool read, bool text) initialize_regressor(*all); if (all->per_feature_regularizer_input != "") { - regularizers = (weight *)calloc(2*length, sizeof(weight)); - if (regularizers == NULL) + b->regularizers = (weight *)calloc(2*length, sizeof(weight)); + if (b->regularizers == NULL) { cerr << all->program_name << ": Failed to allocate regularizers array: try decreasing -b <bits>" << endl; exit (1); @@ -926,18 +934,18 @@ void save_load(void* in, io_buf& model_file, bool read, bool text) } int m = all->m; - mem_stride = (m==0) ? CG_EXTRA : 2*m; - mem = (float*) malloc(sizeof(float)*all->length()*(mem_stride)); - rho = (double*) malloc(sizeof(double)*m); - alpha = (double*) malloc(sizeof(double)*m); + b->mem_stride = (m==0) ? CG_EXTRA : 2*m; + b->mem = (float*) malloc(sizeof(float)*all->length()*(b->mem_stride)); + b->rho = (double*) malloc(sizeof(double)*m); + b->alpha = (double*) malloc(sizeof(double)*m); if (!all->quiet) { - fprintf(stderr, "m = %d\nAllocated %luM for weights and mem\n", m, (long unsigned int)all->length()*(sizeof(float)*(mem_stride)+sizeof(weight)*all->stride) >> 20); + fprintf(stderr, "m = %d\nAllocated %luM for weights and mem\n", m, (long unsigned int)all->length()*(sizeof(float)*(b->mem_stride)+sizeof(weight)*all->stride) >> 20); } - net_time = 0.0; - ftime(&t_start_global); + b->net_time = 0.0; + ftime(&b->t_start_global); if (!all->quiet) { @@ -947,13 +955,13 @@ void save_load(void* in, io_buf& model_file, bool read, bool text) cerr.precision(5); } - if (regularizers != NULL) + if (b->regularizers != NULL) all->l2_lambda = 1; // To make sure we are adding the regularization - output_regularizer = (all->per_feature_regularizer_output != "" || all->per_feature_regularizer_text != ""); - reset_state(*all, false); + b->output_regularizer = (all->per_feature_regularizer_output != "" || all->per_feature_regularizer_text != ""); + reset_state(*all, *b, false); } - bool reg_vector = output_regularizer || all->per_feature_regularizer_input.length() > 0; + bool reg_vector = b->output_regularizer || all->per_feature_regularizer_input.length() > 0; if (model_file.files.size() > 0) { char buff[512]; @@ -963,21 +971,22 @@ void save_load(void* in, io_buf& model_file, bool read, bool text) buff, text_len, text); if (reg_vector) - save_load_regularizer(*all, model_file, read, text); + save_load_regularizer(*all, *b, model_file, read, text); else GD::save_load_regressor(*all, model_file, read, text); } } -void drive(void* in) +void drive(void* in, void* d) { vw* all = (vw*)in; + bfgs* b = (bfgs*)d; example* ec = NULL; size_t final_pass=all->numpasses-1; - first_hessian_on = true; - backstep_on = true; + b->first_hessian_on = true; + b->backstep_on = true; while ( true ) { @@ -986,22 +995,22 @@ void drive(void* in) assert(ec->in_use); if (ec->pass<=final_pass) { - if (ec->pass != current_pass) { - int status = process_pass(*all); - if (status != LEARN_OK && final_pass>current_pass) { - final_pass = current_pass; + if (ec->pass != b->current_pass) { + int status = process_pass(*all, *b); + if (status != LEARN_OK && final_pass>b->current_pass) { + final_pass = b->current_pass; } - if 
(output_regularizer && current_pass==final_pass) { + if (b->output_regularizer && b->current_pass==final_pass) { zero_preconditioner(*all); - preconditioner_pass = true; + b->preconditioner_pass = true; } } - process_example(*all, ec); + process_example(*all, *b, ec); } - + return_simple_example(*all, ec); } - else if (parser_done(all->p)) + else if (parser_done(all->p)) { // finish(all); return; @@ -1011,5 +1020,38 @@ void drive(void* in) } } -} +void parse_args(vw& all, std::vector<std::string>&opts, po::variables_map& vm, po::variables_map& vm_file) +{ + bfgs* b = (bfgs*)calloc(1,sizeof(bfgs)); + b->wolfe1_bound = 0.01; + b->first_hessian_on=true; + b->first_pass = true; + b->gradient_pass = true; + b->preconditioner_pass = true; + + learner t = {b,drive,learn,finish,save_load}; + all.l = t; + all.bfgs = true; + all.stride = 4; + + if (vm.count("hessian_on") || all.m==0) { + all.hessian_on = true; + } + if (!all.quiet) { + if (all.m>0) + cerr << "enabling BFGS based optimization "; + else + cerr << "enabling conjugate gradient optimization via BFGS "; + if (all.hessian_on) + cerr << "with curvature calculation" << endl; + else + cerr << "**without** curvature calculation" << endl; + } + if (all.numpasses < 2) + { + cout << "you must make at least 2 passes to use BFGS" << endl; + exit(1); + } +} +} diff --git a/vowpalwabbit/bfgs.h b/vowpalwabbit/bfgs.h index def78634..fe1356ac 100644 --- a/vowpalwabbit/bfgs.h +++ b/vowpalwabbit/bfgs.h @@ -8,11 +8,7 @@ license as described in the file LICENSE. #include "gd.h" namespace BFGS { - - void drive(void*); - void finish(void*); - void learn(void*, example* ec); - void save_load(void* in, io_buf& model_file, bool read, bool text); + void parse_args(vw& all, std::vector<std::string>&opts, po::variables_map& vm, po::variables_map& vm_file); } #endif diff --git a/vowpalwabbit/cb.cc b/vowpalwabbit/cb.cc index e631f334..6fbd6c04 100644 --- a/vowpalwabbit/cb.cc +++ b/vowpalwabbit/cb.cc @@ -15,14 +15,18 @@ license as described in the file LICENSE. namespace CB { - uint32_t increment = 0; - size_t cb_type = 0; - CSOAA::label cb_cs_ld; - float avg_loss_regressors = 0.; - size_t nb_ex_regressors = 0; - float last_pred_reg = 0.; - float last_correct_cost = 0.; - bool first_print_call = true; + struct cb { + uint32_t increment; + size_t cb_type; + CSOAA::label cb_cs_ld; + float avg_loss_regressors; + size_t nb_ex_regressors; + float last_pred_reg; + float last_correct_cost; + bool first_print_call; + + learner base; + }; bool know_all_cost_example(CB::label* ld) { @@ -212,14 +216,8 @@ namespace CB return NULL; } - void (*base_learner)(void*, example*) = NULL; //base learning algorithm (gd,bfgs,etc...) 
for training regressors of cb - void (*base_learner_cs)(void*, example*) = NULL; //base learner for cost-sensitive data - void (*base_finish)(void*) = NULL; - - void gen_cs_example_ips(void* a, example* ec, CSOAA::label& cs_ld) - { - //this implements the inverse propensity score method, where cost are importance weighted by the probability of the chosen action - vw* all = (vw*)a; + void gen_cs_example_ips(vw& all, cb& c, example* ec, CSOAA::label& cs_ld) + {//this implements the inverse propensity score method, where cost are importance weighted by the probability of the chosen action CB::label* ld = (CB::label*)ec->ld; cb_class* cl_obs = get_observed_cost(ld); @@ -228,7 +226,7 @@ namespace CB cs_ld.costs.erase(); if( ld->costs.size() == 1) { //this is a typical example where we can perform all actions //in this case generate cost-sensitive example with all actions - for(uint32_t i = 1; i <= all->sd->k; i++) + for(uint32_t i = 1; i <= all.sd->k; i++) { CSOAA::wclass wc; wc.wap_value = 0.; @@ -242,10 +240,10 @@ namespace CB //ips can be thought as the doubly robust method with a fixed regressor that predicts 0 costs for everything //update the loss of this regressor - nb_ex_regressors++; - avg_loss_regressors += (1.0f/nb_ex_regressors)*( (cl_obs->x)*(cl_obs->x) - avg_loss_regressors ); - last_pred_reg = 0; - last_correct_cost = cl_obs->x; + c.nb_ex_regressors++; + c.avg_loss_regressors += (1.0f/c.nb_ex_regressors)*( (cl_obs->x)*(cl_obs->x) - c.avg_loss_regressors ); + c.last_pred_reg = 0; + c.last_correct_cost = cl_obs->x; } cs_ld.costs.push_back(wc ); @@ -267,10 +265,10 @@ namespace CB //ips can be thought as the doubly robust method with a fixed regressor that predicts 0 costs for everything //update the loss of this regressor - nb_ex_regressors++; - avg_loss_regressors += (1.0f/nb_ex_regressors)*( (cl_obs->x)*(cl_obs->x) - avg_loss_regressors ); - last_pred_reg = 0; - last_correct_cost = cl_obs->x; + c.nb_ex_regressors++; + c.avg_loss_regressors += (1.0f/c.nb_ex_regressors)*( (cl_obs->x)*(cl_obs->x) - c.avg_loss_regressors ); + c.last_pred_reg = 0; + c.last_correct_cost = cl_obs->x; } cs_ld.costs.push_back( wc ); @@ -279,9 +277,8 @@ namespace CB } - float get_cost_pred(void* a, example* ec, uint32_t index) + float get_cost_pred(vw& all, cb& c, example* ec, uint32_t index) { - vw* all = (vw*)a; CB::label* ld = (CB::label*)ec->ld; label_data simple_temp; @@ -292,11 +289,11 @@ namespace CB ec->ld = &simple_temp; ec->partial_prediction = 0.; - uint32_t desired_increment = increment * (2*index-1); + uint32_t desired_increment = c.increment * (2*index-1); - update_example_indicies(all->audit, ec, desired_increment); - base_learner(all, ec); - update_example_indicies(all->audit, ec, -desired_increment); + update_example_indicies(all.audit, ec, desired_increment); + all.scorer.learn(&all, all.scorer.data, ec); + update_example_indicies(all.audit, ec, -desired_increment); ec->ld = ld; @@ -307,10 +304,9 @@ namespace CB } //this function below was a test to see if we save time by carefully organizing the feature offset/regression calls, but seems to be same time as gen_cs_example_dm - void gen_cs_example_dm2(void* a, example* ec, CSOAA::label& cs_ld) + void gen_cs_example_dm2(vw& all, cb& c, example* ec, CSOAA::label& cs_ld) { //this implements the direct estimation method, where costs are directly specified by the learned regressor. 
- vw* all = (vw*)a; CB::label* ld = (CB::label*)ec->ld; cb_class* cl_obs = get_observed_cost(ld); @@ -329,16 +325,16 @@ namespace CB cs_ld.costs.erase(); if( ld->costs.size() == 1) { //this is a typical example where we can perform all actions //in this case generate cost-sensitive example with all actions - for( uint32_t i = 1; i <= all->sd->k; i++) + for( uint32_t i = 1; i <= all.sd->k; i++) { CSOAA::wclass wc; wc.wap_value = 0.; ec->partial_prediction = 0.; - desired_increment = increment * (2*i-1); - update_example_indicies(all->audit, ec, desired_increment-current_increment); + desired_increment = c.increment * (2*i-1); + update_example_indicies(all.audit, ec, desired_increment-current_increment); current_increment = desired_increment; - base_learner(all, ec); + all.scorer.learn(&all, all.scorer.data, ec); //get cost prediction for this action wc.x = ec->partial_prediction; @@ -347,10 +343,10 @@ namespace CB wc.wap_value = 0.; if( cl_obs != NULL && cl_obs->weight_index == i ) { - nb_ex_regressors++; - avg_loss_regressors += (1.0f/nb_ex_regressors)*( (cl_obs->x - wc.x)*(cl_obs->x - wc.x) - avg_loss_regressors ); - last_pred_reg = wc.x; - last_correct_cost = cl_obs->x; + c.nb_ex_regressors++; + c.avg_loss_regressors += (1.0f/c.nb_ex_regressors)*( (cl_obs->x - wc.x)*(cl_obs->x - wc.x) - c.avg_loss_regressors ); + c.last_pred_reg = wc.x; + c.last_correct_cost = cl_obs->x; } cs_ld.costs.push_back( wc ); @@ -364,10 +360,10 @@ namespace CB wc.wap_value = 0.; ec->partial_prediction = 0.; - desired_increment = increment * (2*cl->weight_index-1); - update_example_indicies(all->audit, ec, desired_increment-current_increment); + desired_increment = c.increment * (2*cl->weight_index-1); + update_example_indicies(all.audit, ec, desired_increment-current_increment); current_increment = desired_increment; - base_learner(all, ec); + all.scorer.learn(&all, all.scorer.data, ec); //get cost prediction for this action wc.x = ec->partial_prediction; @@ -376,10 +372,10 @@ namespace CB wc.wap_value = 0.; if( cl_obs != NULL && cl_obs->weight_index == cl->weight_index ) { - nb_ex_regressors++; - avg_loss_regressors += (1.0f/nb_ex_regressors)*( (cl_obs->x - wc.x)*(cl_obs->x - wc.x) - avg_loss_regressors ); - last_pred_reg = wc.x; - last_correct_cost = cl_obs->x; + c.nb_ex_regressors++; + c.avg_loss_regressors += (1.0f/c.nb_ex_regressors)*( (cl_obs->x - wc.x)*(cl_obs->x - wc.x) - c.avg_loss_regressors ); + c.last_pred_reg = wc.x; + c.last_correct_cost = cl_obs->x; } cs_ld.costs.push_back( wc ); @@ -388,13 +384,12 @@ namespace CB ec->ld = ld; ec->partial_prediction = 0.; - update_example_indicies(all->audit, ec, -current_increment); + update_example_indicies(all.audit, ec, -current_increment); } - void gen_cs_example_dm(void* a, example* ec, CSOAA::label& cs_ld) + void gen_cs_example_dm(vw& all, cb& c, example* ec, CSOAA::label& cs_ld) { //this implements the direct estimation method, where costs are directly specified by the learned regressor. 
- vw* all = (vw*)a; CB::label* ld = (CB::label*)ec->ld; cb_class* cl_obs = get_observed_cost(ld); @@ -403,22 +398,22 @@ namespace CB cs_ld.costs.erase(); if( ld->costs.size() == 1) { //this is a typical example where we can perform all actions //in this case generate cost-sensitive example with all actions - for(uint32_t i = 1; i <= all->sd->k; i++) + for(uint32_t i = 1; i <= all.sd->k; i++) { CSOAA::wclass wc; wc.wap_value = 0.; //get cost prediction for this action - wc.x = get_cost_pred(a,ec,i); + wc.x = get_cost_pred(all, c,ec,i); wc.weight_index = i; wc.partial_prediction = 0.; wc.wap_value = 0.; if( cl_obs != NULL && cl_obs->weight_index == i ) { - nb_ex_regressors++; - avg_loss_regressors += (1.0f/nb_ex_regressors)*( (cl_obs->x - wc.x)*(cl_obs->x - wc.x) - avg_loss_regressors ); - last_pred_reg = wc.x; - last_correct_cost = cl_obs->x; + c.nb_ex_regressors++; + c.avg_loss_regressors += (1.0f/c.nb_ex_regressors)*( (cl_obs->x - wc.x)*(cl_obs->x - wc.x) - c.avg_loss_regressors ); + c.last_pred_reg = wc.x; + c.last_correct_cost = cl_obs->x; } cs_ld.costs.push_back( wc ); @@ -432,16 +427,16 @@ namespace CB wc.wap_value = 0.; //get cost prediction for this action - wc.x = get_cost_pred(a,ec,cl->weight_index); + wc.x = get_cost_pred(all, c,ec,cl->weight_index); wc.weight_index = cl->weight_index; wc.partial_prediction = 0.; wc.wap_value = 0.; if( cl_obs != NULL && cl_obs->weight_index == cl->weight_index ) { - nb_ex_regressors++; - avg_loss_regressors += (1.0f/nb_ex_regressors)*( (cl_obs->x - wc.x)*(cl_obs->x - wc.x) - avg_loss_regressors ); - last_pred_reg = wc.x; - last_correct_cost = cl_obs->x; + c.nb_ex_regressors++; + c.avg_loss_regressors += (1.0f/c.nb_ex_regressors)*( (cl_obs->x - wc.x)*(cl_obs->x - wc.x) - c.avg_loss_regressors ); + c.last_pred_reg = wc.x; + c.last_correct_cost = cl_obs->x; } cs_ld.costs.push_back( wc ); @@ -449,35 +444,33 @@ namespace CB } } - void gen_cs_example_dr(void* a, example* ec, CSOAA::label& cs_ld) - { - //this implements the doubly robust method - vw* all = (vw*)a; + void gen_cs_example_dr(vw& all, cb& c, example* ec, CSOAA::label& cs_ld) + {//this implements the doubly robust method CB::label* ld = (CB::label*)ec->ld; - + cb_class* cl_obs = get_observed_cost(ld); - + //generate cost sensitive example cs_ld.costs.erase(); if( ld->costs.size() == 1) { //this is a typical example where we can perform all actions //in this case generate cost-sensitive example with all actions - for(uint32_t i = 1; i <= all->sd->k; i++) + for(uint32_t i = 1; i <= all.sd->k; i++) { CSOAA::wclass wc; wc.wap_value = 0.; //get cost prediction for this label - wc.x = get_cost_pred(a,ec,i); + wc.x = get_cost_pred(all, c,ec,i); wc.weight_index = i; wc.partial_prediction = 0.; wc.wap_value = 0.; //add correction if we observed cost for this action and regressor is wrong if( cl_obs != NULL && cl_obs->weight_index == i ) { - nb_ex_regressors++; - avg_loss_regressors += (1.0f/nb_ex_regressors)*( (cl_obs->x - wc.x)*(cl_obs->x - wc.x) - avg_loss_regressors ); - last_pred_reg = wc.x; - last_correct_cost = cl_obs->x; + c.nb_ex_regressors++; + c.avg_loss_regressors += (1.0f/c.nb_ex_regressors)*( (cl_obs->x - wc.x)*(cl_obs->x - wc.x) - c.avg_loss_regressors ); + c.last_pred_reg = wc.x; + c.last_correct_cost = cl_obs->x; wc.x += (cl_obs->x - wc.x) / cl_obs->prob_action; } @@ -492,17 +485,17 @@ namespace CB wc.wap_value = 0.; //get cost prediction for this label - wc.x = get_cost_pred(a,ec,cl->weight_index); + wc.x = get_cost_pred(all,c,ec,cl->weight_index); wc.weight_index = cl->weight_index; 
wc.partial_prediction = 0.; wc.wap_value = 0.; //add correction if we observed cost for this action and regressor is wrong if( cl_obs != NULL && cl_obs->weight_index == cl->weight_index ) { - nb_ex_regressors++; - avg_loss_regressors += (1.0f/nb_ex_regressors)*( (cl_obs->x - wc.x)*(cl_obs->x - wc.x) - avg_loss_regressors ); - last_pred_reg = wc.x; - last_correct_cost = cl_obs->x; + c.nb_ex_regressors++; + c.avg_loss_regressors += (1.0f/c.nb_ex_regressors)*( (cl_obs->x - wc.x)*(cl_obs->x - wc.x) - c.avg_loss_regressors ); + c.last_pred_reg = wc.x; + c.last_correct_cost = cl_obs->x; wc.x += (cl_obs->x - wc.x) / cl_obs->prob_action; } @@ -511,7 +504,7 @@ namespace CB } } - void cb_test_to_cs_test_label(void* a, example* ec, CSOAA::label& cs_ld) + void cb_test_to_cs_test_label(vw& all, example* ec, CSOAA::label& cs_ld) { CB::label* ld = (CB::label*)ec->ld; @@ -534,8 +527,9 @@ namespace CB } } - void learn(void* a, example* ec) { + void learn(void* a, void* d, example* ec) { vw* all = (vw*)a; + cb* c = (cb*)d; CB::label* ld = (CB::label*)ec->ld; float prediction = 1; @@ -543,10 +537,10 @@ namespace CB if( CB::is_test_label(ld) ) { //if so just query base cost-sensitive learner - cb_test_to_cs_test_label(a,ec,cb_cs_ld); + cb_test_to_cs_test_label(*all,ec,c->cb_cs_ld); - ec->ld = &cb_cs_ld; - base_learner_cs(all,ec); + ec->ld = &c->cb_cs_ld; + c->base.learn(all,c->base.data,ec); ec->ld = ld; return; } @@ -554,33 +548,33 @@ namespace CB //now this is a training example //generate a cost-sensitive example to update classifiers - switch(cb_type) + switch(c->cb_type) { case CB_TYPE_IPS: - gen_cs_example_ips(a,ec,cb_cs_ld); + gen_cs_example_ips(*all,*c,ec,c->cb_cs_ld); break; case CB_TYPE_DM: - gen_cs_example_dm(a,ec,cb_cs_ld); + gen_cs_example_dm(*all,*c,ec,c->cb_cs_ld); break; case CB_TYPE_DR: - gen_cs_example_dr(a,ec,cb_cs_ld); + gen_cs_example_dr(*all,*c,ec,c->cb_cs_ld); break; default: - std::cerr << "Unknown cb_type specified for contextual bandit learning: " << cb_type << ". Exiting." << endl; + std::cerr << "Unknown cb_type specified for contextual bandit learning: " << c->cb_type << ". Exiting." 
<< endl; exit(1); } //update classifiers with cost-sensitive exemple - ec->ld = &cb_cs_ld; + ec->ld = &c->cb_cs_ld; ec->partial_prediction = 0.; - base_learner_cs(all,ec); + c->base.learn(all,c->base.data,ec); ec->ld = ld; //store current class prediction prediction = ec->final_prediction; //update our regressors if we are training regressors - if( cb_type == CB_TYPE_DM || cb_type == CB_TYPE_DR ) + if( c->cb_type == CB_TYPE_DM || c->cb_type == CB_TYPE_DR ) { cb_class* cl_obs = get_observed_cost(ld); @@ -595,10 +589,10 @@ namespace CB ec->ld = &simple_temp; - uint32_t desired_increment = increment * (2*i-1); + uint32_t desired_increment = c->increment * (2*i-1); update_example_indicies(all->audit, ec, desired_increment); ec->partial_prediction = 0.; - base_learner(all, ec); + all->scorer.learn(all, all->scorer.data, ec); update_example_indicies(all->audit, ec, -desired_increment); ec->ld = ld; @@ -608,12 +602,12 @@ namespace CB ec->final_prediction = prediction; } - void print_update(vw& all, bool is_test, example *ec) + void print_update(vw& all, cb& c, bool is_test, example *ec) { - if( first_print_call ) + if( c.first_print_call ) { fprintf(stderr, "*estimate* *estimate* avglossreg last pred last correct\n"); - first_print_call = false; + c.first_print_call = false; } if (all.sd->weighted_examples > all.sd->dump_interval && !all.quiet && !all.bfgs) @@ -632,9 +626,9 @@ namespace CB label_buf, (long unsigned int)*(OAA::prediction_t*)&ec->final_prediction, (long unsigned int)ec->num_features, - avg_loss_regressors, - last_pred_reg, - last_correct_cost); + c.avg_loss_regressors, + c.last_pred_reg, + c.last_correct_cost); all.sd->sum_loss_since_last_dump = 0.0; all.sd->old_weighted_examples = all.sd->weighted_examples; @@ -642,7 +636,7 @@ namespace CB } } - void output_example(vw& all, example* ec) + void output_example(vw& all, cb& c, example* ec) { CB::label* ld = (CB::label*)ec->ld; all.sd->weighted_examples += 1.; @@ -666,7 +660,7 @@ namespace CB } else { //we do not know exact cost of each action, so evaluate on generated cost-sensitive example currently stored in cb_cs_ld - for (CSOAA::wclass *cl = cb_cs_ld.costs.begin; cl != cb_cs_ld.costs.end; cl ++) { + for (CSOAA::wclass *cl = c.cb_cs_ld.costs.begin; cl != c.cb_cs_ld.costs.end; cl ++) { if (cl->weight_index == pred) chosen_loss = cl->x; if (cl->x < min) @@ -692,26 +686,28 @@ namespace CB all.sd->example_number++; - print_update(all, CB::is_test_label((CB::label*)ec->ld), ec); + print_update(all, c, CB::is_test_label((CB::label*)ec->ld), ec); } - void finish(void* a) + void finish(void* a, void* d) { - vw* all = (vw*)a; - cb_cs_ld.costs.delete_v(); - base_finish(all); + cb* c=(cb*)d; + c->base.finish(a,c->base.data); + c->cb_cs_ld.costs.delete_v(); + free(c); } - void drive_cb(void* in) + void drive(void* in, void* d) { vw*all = (vw*)in; + cb* c = (cb*)d; example* ec = NULL; while ( true ) { if ((ec = get_example(all->p)) != NULL)//semiblocking operation. 
{ - learn(all, ec); - output_example(*all, ec); + learn(all, d, ec); + output_example(*all, *c, ec); VW::finish_example(*all, ec); } else if (parser_done(all->p)) @@ -724,6 +720,8 @@ namespace CB void parse_flags(vw& all, std::vector<std::string>&opts, po::variables_map& vm, po::variables_map& vm_file) { + cb* c = (cb*)calloc(1, sizeof(cb)); + c->first_print_call = true; po::options_description desc("CB options"); desc.add_options() ("cb_type", po::value<string>(), "contextual bandit method to use in {ips,dm,dr}"); @@ -772,47 +770,37 @@ namespace CB } if (type_string.compare("dr") == 0) { - cb_type = CB_TYPE_DR; + c->cb_type = CB_TYPE_DR; all.base_learner_nb_w *= nb_actions * 2; } else if (type_string.compare("dm") == 0) { - cb_type = CB_TYPE_DM; + c->cb_type = CB_TYPE_DM; all.base_learner_nb_w *= nb_actions * 2; } else if (type_string.compare("ips") == 0) { - cb_type = CB_TYPE_IPS; + c->cb_type = CB_TYPE_IPS; all.base_learner_nb_w *= nb_actions; } else { std::cerr << "warning: cb_type must be in {'ips','dm','dr'}; resetting to dr." << std::endl; - cb_type = CB_TYPE_DR; + c->cb_type = CB_TYPE_DR; all.base_learner_nb_w *= nb_actions * 2; } } else { //by default use doubly robust - cb_type = CB_TYPE_DR; + c->cb_type = CB_TYPE_DR; all.base_learner_nb_w *= nb_actions * 2; all.options_from_file.append(" --cb_type dr"); } - increment = ((uint32_t)all.length()/all.base_learner_nb_w) * all.stride; + c->increment = ((uint32_t)all.length()/all.base_learner_nb_w) * all.stride; *(all.p->lp) = CB::cb_label_parser; all.sd->k = nb_actions; - all.driver = drive_cb; - //this parsing is done after the cost-sensitive parsing, so all.learn currently points to the base cs learner - //and all.base_learn points to gd/bfgs base learner - base_learner_cs = all.learn; - base_learner = all.base_learn; - - all.learn = learn; - base_finish = all.finish; - all.finish = finish; - + learner l = {c, drive, learn, finish, all.l.save_load}; + all.l = l; } - - } diff --git a/vowpalwabbit/cb.h b/vowpalwabbit/cb.h index c2e40098..e185eeba 100644 --- a/vowpalwabbit/cb.h +++ b/vowpalwabbit/cb.h @@ -34,7 +34,6 @@ namespace CB { void parse_flags(vw& all, std::vector<std::string>&, po::variables_map& vm, po::variables_map& vm_file); - void output_example(vw& all, example* ec); size_t read_cached_label(shared_data* sd, void* v, io_buf& cache); void cache_label(void* v, io_buf& cache); void default_label(void* v); diff --git a/vowpalwabbit/csoaa.cc b/vowpalwabbit/csoaa.cc index e1af04ca..fb44b73d 100644 --- a/vowpalwabbit/csoaa.cc +++ b/vowpalwabbit/csoaa.cc @@ -17,6 +17,10 @@ license as described in the file LICENSE. 
using namespace std; namespace CSOAA { + struct csoaa{ + uint32_t csoaa_increment; + learner base; + }; void name_value(substring &s, v_array<substring>& name, float &v) { @@ -274,11 +278,9 @@ namespace CSOAA { print_update(all, is_test_label((label*)ec->ld), ec); } - void (*base_learner)(void*, example*) = NULL; - void (*base_finish)(void*) = NULL; - - void learn(void* a, example* ec) { + void learn(void* a, void* d, example* ec) { vw* all = (vw*)a; + csoaa* c = (csoaa*)d; label* ld = (label*)ec->ld; size_t prediction = 1; float score = FLT_MAX; @@ -303,14 +305,14 @@ namespace CSOAA { ec->ld = &simple_temp; - uint32_t desired_increment = all->csoaa_increment * (i-1); + uint32_t desired_increment = c->csoaa_increment * (i-1); if (desired_increment != current_increment) { update_example_indicies(all->audit, ec, desired_increment - current_increment); current_increment = desired_increment; } - base_learner(all, ec); + c->base.learn(all, c->base.data, ec); cl->partial_prediction = ec->partial_prediction; if (ec->partial_prediction < score || (ec->partial_prediction == score && i < prediction)) { score = ec->partial_prediction; @@ -324,13 +326,14 @@ namespace CSOAA { update_example_indicies(all->audit, ec, -current_increment); } - void finish(void* a) + void finish(void* a, void* d) { - vw* all = (vw*)a; - base_finish(all); + csoaa* c=(csoaa*)d; + c->base.finish(a,c->base.data); + free(c); } - void drive_csoaa(void* in) + void drive(void* in, void* d) { vw* all = (vw*)in; example* ec = NULL; @@ -338,23 +341,21 @@ namespace CSOAA { { if ((ec = get_example(all->p)) != NULL)//semiblocking operation. { - learn(all, ec); + learn(all, d, ec); output_example(*all, ec); if (ec->in_use) VW::finish_example(*all, ec); } else if (parser_done(all->p)) - { - // finish(all); - return; - } + return; else ; } - } + } void parse_flags(vw& all, std::vector<std::string>&opts, po::variables_map& vm, po::variables_map& vm_file) { + csoaa* c=(csoaa*)calloc(1,sizeof(csoaa)); //first parse for number of actions uint32_t nb_actions = 0; if( vm_file.count("csoaa") ) { //if loaded options from regressor @@ -372,18 +373,14 @@ namespace CSOAA { } *(all.p->lp) = cs_label_parser; - if (!all.is_noop) - all.driver = drive_csoaa; + c->base=all.l; + c->csoaa_increment = ((uint32_t)all.length()/all.base_learner_nb_w) * all.stride; all.sd->k = nb_actions; all.base_learner_nb_w *= nb_actions; - all.csoaa_increment = ((uint32_t)all.length()/all.base_learner_nb_w) * all.stride; - base_learner = all.learn; - all.base_learn = all.learn; - all.learn = learn; - base_finish = all.finish; - all.finish = finish; + learner l = {c, drive, learn, finish, all.l.save_load}; + all.l = l; } bool example_is_test(example* ec) @@ -398,15 +395,153 @@ namespace CSOAA { } namespace CSOAA_AND_WAP_LDF { - v_array<example*> ec_seq = v_array<example*>(); - size_t read_example_this_loop = 0; - bool need_to_clear = true; - bool is_singleline = true; - bool is_wap = false; - float csoaa_example_t = 0; - void (*base_learner)(void*, example*) = NULL; - void (*base_finish)(void*) = NULL; + struct ldf { + v_array<example*> ec_seq; + v_hashmap< size_t, v_array<feature> > label_features; + + size_t read_example_this_loop; + bool need_to_clear; + bool is_singleline; + bool is_wap; + float csoaa_example_t; + learner base; + }; + +namespace LabelDict { + bool size_t_eq(size_t a, size_t b) { return (a==b); } + + size_t hash_lab(size_t lab) { return 328051 + 94389193 * lab; } + + bool ec_is_label_definition(example*ec) // label defs look like "___:-1" + { + 
v_array<CSOAA::wclass> costs = ((CSOAA::label*)ec->ld)->costs; + for (size_t j=0; j<costs.size(); j++) + if (costs[j].x >= 0.) return false; + if (ec->indices.size() == 0) return false; + if (ec->indices.size() > 2) return false; + if (ec->indices[0] != 'l') return false; + return true; + } + + bool ec_is_example_header(example*ec) // example headers look like "0:-1" + { + v_array<CSOAA::wclass> costs = ((CSOAA::label*)ec->ld)->costs; + if (costs.size() != 1) return false; + if (costs[0].weight_index != 0) return false; + if (costs[0].x >= 0) return false; + return true; + } + + bool ec_seq_is_label_definition(ldf& l, v_array<example*>ec_seq) + { + if (l.ec_seq.size() == 0) return false; + bool is_lab = ec_is_label_definition(l.ec_seq[0]); + for (size_t i=1; i<l.ec_seq.size(); i++) { + if (is_lab != ec_is_label_definition(l.ec_seq[i])) { + if (!((i == l.ec_seq.size()-1) && (OAA::example_is_newline(l.ec_seq[i])))) { + cerr << "error: mixed label definition and examples in ldf data!" << endl; + exit(-1); + } + } + } + return is_lab; + } + + void del_example_namespace(example*ec, char ns, v_array<feature> features) { + size_t numf = features.size(); + ec->num_features -= numf; + + assert (ec->atomics[(size_t)ns].size() >= numf); + if (ec->atomics[(size_t)ns].size() == numf) { // did NOT have ns + assert(ec->indices.size() > 0); + assert(ec->indices[ec->indices.size()-1] == (size_t)ns); + ec->indices.pop(); + ec->total_sum_feat_sq -= ec->sum_feat_sq[(size_t)ns]; + ec->atomics[(size_t)ns].erase(); + ec->sum_feat_sq[(size_t)ns] = 0.; + } else { // DID have ns + for (feature*f=features.begin; f!=features.end; f++) { + ec->sum_feat_sq[(size_t)ns] -= f->x * f->x; + ec->atomics[(size_t)ns].pop(); + } + } + } + + void add_example_namespace(example*ec, char ns, v_array<feature> features) { + bool has_ns = false; + for (size_t i=0; i<ec->indices.size(); i++) { + if (ec->indices[i] == (size_t)ns) { + has_ns = true; + break; + } + } + if (has_ns) { + ec->total_sum_feat_sq -= ec->sum_feat_sq[(size_t)ns]; + } else { + ec->indices.push_back((size_t)ns); + ec->sum_feat_sq[(size_t)ns] = 0; + } + + for (feature*f=features.begin; f!=features.end; f++) { + ec->sum_feat_sq[(size_t)ns] += f->x * f->x; + ec->atomics[(size_t)ns].push_back(*f); + } + + ec->num_features += features.size(); + ec->total_sum_feat_sq += ec->sum_feat_sq[(size_t)ns]; + } + + + + void add_example_namespaces_from_example(example*target, example*source) { + for (unsigned char* idx=source->indices.begin; idx!=source->indices.end; idx++) { + if (*idx == constant_namespace) continue; + add_example_namespace(target, (char)*idx, source->atomics[*idx]); + } + } + + void del_example_namespaces_from_example(example*target, example*source) { + //for (size_t*idx=source->indices.begin; idx!=source->indices.end; idx++) { + unsigned char* idx = source->indices.end; + idx--; + for (; idx>=source->indices.begin; idx--) { + if (*idx == constant_namespace) continue; + del_example_namespace(target, (char)*idx, source->atomics[*idx]); + } + } + + void add_example_namespace_from_memory(ldf& l, example*ec, size_t lab) { + size_t lab_hash = hash_lab(lab); + v_array<feature> features = l.label_features.get(lab, lab_hash); + if (features.size() == 0) return; + add_example_namespace(ec, 'l', features); + } + + void del_example_namespace_from_memory(ldf& l, example* ec, size_t lab) { + size_t lab_hash = hash_lab(lab); + v_array<feature> features = l.label_features.get(lab, lab_hash); + if (features.size() == 0) return; + del_example_namespace(ec, 'l', features); + } + 
+ void set_label_features(ldf& l, size_t lab, v_array<feature>features) { + size_t lab_hash = hash_lab(lab); + if (l.label_features.contains(lab, lab_hash)) { return; } + l.label_features.put_after_get(lab, lab_hash, features); + } + + void free_label_features(ldf& l) { + void* label_iter = l.label_features.iterator(); + while (label_iter != NULL) { + v_array<feature> features = l.label_features.iterator_get_value(label_iter); + features.erase(); + features.delete_v(); + + label_iter = l.label_features.iterator_next(label_iter); + } + } +} inline bool cmp_wclass_ptr(const CSOAA::wclass* a, const CSOAA::wclass* b) { return a->x < b->x; } @@ -479,7 +614,7 @@ namespace CSOAA_AND_WAP_LDF { ec->indices.decr(); } - void make_single_prediction(vw& all, example*ec, size_t*prediction, float*min_score) { + void make_single_prediction(vw& all, ldf& l, example*ec, size_t*prediction, float*min_score) { label *ld = (label*)ec->ld; v_array<CSOAA::wclass> costs = ld->costs; label_data simple_label; @@ -490,10 +625,10 @@ namespace CSOAA_AND_WAP_LDF { simple_label.weight = 0.; ec->partial_prediction = 0.; - LabelDict::add_example_namespace_from_memory(ec, costs[j].weight_index); + LabelDict::add_example_namespace_from_memory(l, ec, costs[j].weight_index); ec->ld = &simple_label; - base_learner(&all, ec); // make a prediction + l.base.learn(&all, l.base.data, ec); // make a prediction costs[j].partial_prediction = ec->partial_prediction; if (ec->partial_prediction < *min_score) { @@ -501,7 +636,7 @@ namespace CSOAA_AND_WAP_LDF { *prediction = costs[j].weight_index; } - LabelDict::del_example_namespace_from_memory(ec, costs[j].weight_index); + LabelDict::del_example_namespace_from_memory(l, ec, costs[j].weight_index); } ec->ld = ld; @@ -509,46 +644,46 @@ namespace CSOAA_AND_WAP_LDF { - void do_actual_learning_wap(vw& all, size_t start_K) + void do_actual_learning_wap(vw& all, ldf& l, size_t start_K) { - size_t K = ec_seq.size(); - bool isTest = CSOAA::example_is_test(ec_seq[start_K]); + size_t K = l.ec_seq.size(); + bool isTest = CSOAA::example_is_test(l.ec_seq[start_K]); size_t prediction = 0; float min_score = FLT_MAX; v_hashmap<size_t,float> hit_labels(8, 0., NULL); for (size_t k=start_K; k<K; k++) { - example *ec = ec_seq.begin[k]; + example *ec = l.ec_seq.begin[k]; if (CSOAA::example_is_test(ec) != isTest) { isTest = true; cerr << "warning: wap_ldf got mix of train/test data; assuming test" << endl; } - if (LabelDict::ec_is_example_header(ec_seq[k])) { + if (LabelDict::ec_is_example_header(l.ec_seq[k])) { cerr << "warning: example headers at position " << k << ": can only have in initial position!" 
<< endl; exit(-1); } - make_single_prediction(all, ec, &prediction, &min_score); + make_single_prediction(all, l, ec, &prediction, &min_score); } // do actual learning vector<CSOAA::wclass*> all_costs; if (all.training && !isTest) { for (size_t k=start_K; k<K; k++) { - v_array<CSOAA::wclass> this_costs = ((label*)ec_seq.begin[k]->ld)->costs; + v_array<CSOAA::wclass> this_costs = ((label*)l.ec_seq.begin[k]->ld)->costs; for (size_t j=0; j<this_costs.size(); j++) all_costs.push_back(&this_costs[j]); } compute_wap_values(all_costs); - csoaa_example_t += 1.; + l.csoaa_example_t += 1.; } label_data simple_label; for (size_t k1=start_K; k1<K; k1++) { - example *ec1 = ec_seq.begin[k1]; + example *ec1 = l.ec_seq.begin[k1]; label *ld1 = (label*)ec1->ld; v_array<CSOAA::wclass> costs1 = ld1->costs; bool prediction_is_me = false; @@ -558,10 +693,10 @@ namespace CSOAA_AND_WAP_LDF { for (size_t j1=0; j1<costs1.size(); j1++) { if (costs1[j1].weight_index == (uint32_t)-1) continue; if (all.training && !isTest) { - LabelDict::add_example_namespace_from_memory(ec1, costs1[j1].weight_index); + LabelDict::add_example_namespace_from_memory(l, ec1, costs1[j1].weight_index); for (size_t k2=k1+1; k2<K; k2++) { - example *ec2 = ec_seq.begin[k2]; + example *ec2 = l.ec_seq.begin[k2]; label *ld2 = (label*)ec2->ld; v_array<CSOAA::wclass> costs2 = ld2->costs; @@ -572,22 +707,22 @@ namespace CSOAA_AND_WAP_LDF { if (value_diff < 1e-6) continue; - LabelDict::add_example_namespace_from_memory(ec2, costs2[j2].weight_index); + LabelDict::add_example_namespace_from_memory(l, ec2, costs2[j2].weight_index); // learn - ec1->example_t = csoaa_example_t; + ec1->example_t = l.csoaa_example_t; simple_label.initial = 0.; simple_label.label = (costs1[j1].x < costs2[j2].x) ? -1.0f : 1.0f; simple_label.weight = value_diff; ec1->partial_prediction = 0.; subtract_example(all, ec1, ec2); - base_learner(&all, ec1); + l.base.learn(&all, l.base.data, ec1); unsubtract_example(all, ec1); - LabelDict::del_example_namespace_from_memory(ec2, costs2[j2].weight_index); + LabelDict::del_example_namespace_from_memory(l, ec2, costs2[j2].weight_index); } } - LabelDict::del_example_namespace_from_memory(ec1, costs1[j1].weight_index); + LabelDict::del_example_namespace_from_memory(l, ec1, costs1[j1].weight_index); } if (prediction == costs1[j1].weight_index) prediction_is_me = true; @@ -598,31 +733,31 @@ namespace CSOAA_AND_WAP_LDF { } } - void do_actual_learning_oaa(vw& all, size_t start_K) + void do_actual_learning_oaa(vw& all, ldf& l, size_t start_K) { - size_t K = ec_seq.size(); + size_t K = l.ec_seq.size(); size_t prediction = 0; - bool isTest = CSOAA::example_is_test(ec_seq[start_K]); + bool isTest = CSOAA::example_is_test(l.ec_seq[start_K]); float min_score = FLT_MAX; for (size_t k=start_K; k<K; k++) { - example *ec = ec_seq.begin[k]; + example *ec = l.ec_seq.begin[k]; if (CSOAA::example_is_test(ec) != isTest) { isTest = true; cerr << "warning: ldf got mix of train/test data; assuming test" << endl; } - if (LabelDict::ec_is_example_header(ec_seq[k])) { + if (LabelDict::ec_is_example_header(l.ec_seq[k])) { cerr << "warning: example headers at position " << k << ": can only have in initial position!" 
<< endl; exit(-1); } - make_single_prediction(all, ec, &prediction, &min_score); + make_single_prediction(all, l, ec, &prediction, &min_score); } // do actual learning if (all.training && !isTest) - csoaa_example_t += 1.; + l.csoaa_example_t += 1.; for (size_t k=start_K; k<K; k++) { - example *ec = ec_seq.begin[k]; + example *ec = l.ec_seq.begin[k]; label *ld = (label*)ec->ld; v_array<CSOAA::wclass> costs = ld->costs; @@ -632,15 +767,15 @@ namespace CSOAA_AND_WAP_LDF { for (size_t j=0; j<costs.size(); j++) { if (all.training && !isTest) { float example_t = ec->example_t; - ec->example_t = csoaa_example_t; + ec->example_t = l.csoaa_example_t; simple_label.initial = 0.; simple_label.label = costs[j].x; simple_label.weight = 1.; ec->ld = &simple_label; ec->partial_prediction = 0.; - LabelDict::add_example_namespace_from_memory(ec, costs[j].weight_index); - base_learner(&all, ec); - LabelDict::del_example_namespace_from_memory(ec, costs[j].weight_index); + LabelDict::add_example_namespace_from_memory(l, ec, costs[j].weight_index); + l.base.learn(&all, l.base.data, ec); + LabelDict::del_example_namespace_from_memory(l, ec, costs[j].weight_index); ec->example_t = example_t; } @@ -660,46 +795,46 @@ namespace CSOAA_AND_WAP_LDF { } - void do_actual_learning(vw& all) + void do_actual_learning(vw& all, ldf& l) { - if (ec_seq.size() <= 0) return; // nothing to do + if (l.ec_seq.size() <= 0) return; // nothing to do /////////////////////// handle label definitions - if (LabelDict::ec_seq_is_label_definition(ec_seq)) { - for (size_t i=0; i<ec_seq.size(); i++) { + if (LabelDict::ec_seq_is_label_definition(l, l.ec_seq)) { + for (size_t i=0; i<l.ec_seq.size(); i++) { v_array<feature> features; - for (feature*f=ec_seq[i]->atomics[ec_seq[i]->indices[0]].begin; f!=ec_seq[i]->atomics[ec_seq[i]->indices[0]].end; f++) { + for (feature*f=l.ec_seq[i]->atomics[l.ec_seq[i]->indices[0]].begin; f!=l.ec_seq[i]->atomics[l.ec_seq[i]->indices[0]].end; f++) { feature fnew = { f->x, f->weight_index }; features.push_back(fnew); } - v_array<CSOAA::wclass> costs = ((CSOAA::label*)ec_seq[i]->ld)->costs; + v_array<CSOAA::wclass> costs = ((CSOAA::label*)l.ec_seq[i]->ld)->costs; for (size_t j=0; j<costs.size(); j++) { size_t lab = costs[j].weight_index; - LabelDict::set_label_features(lab, features); + LabelDict::set_label_features(l, lab, features); } } return; } /////////////////////// check for headers - size_t K = ec_seq.size(); + size_t K = l.ec_seq.size(); size_t start_K = 0; - if (LabelDict::ec_is_example_header(ec_seq[0])) { + if (LabelDict::ec_is_example_header(l.ec_seq[0])) { start_K = 1; for (size_t k=1; k<K; k++) - LabelDict::add_example_namespaces_from_example(ec_seq[k], ec_seq[0]); + LabelDict::add_example_namespaces_from_example(l.ec_seq[k], l.ec_seq[0]); } /////////////////////// learn - if (is_wap) do_actual_learning_wap(all, start_K); - else do_actual_learning_oaa(all, start_K); - + if (l.is_wap) do_actual_learning_wap(all, l, start_K); + else do_actual_learning_oaa(all, l, start_K); + /////////////////////// remove header if (start_K > 0) for (size_t k=1; k<K; k++) - LabelDict::del_example_namespaces_from_example(ec_seq[k], ec_seq[0]); + LabelDict::del_example_namespaces_from_example(l.ec_seq[k], l.ec_seq[0]); } @@ -749,149 +884,155 @@ namespace CSOAA_AND_WAP_LDF { CSOAA::print_update(all, CSOAA::example_is_test(ec), ec); } - void output_example_seq(vw& all) + void output_example_seq(vw& all, ldf& l) { - if ((ec_seq.size() > 0) && !LabelDict::ec_seq_is_label_definition(ec_seq)) { + if ((l.ec_seq.size() > 0) && 
!LabelDict::ec_seq_is_label_definition(l, l.ec_seq)) { all.sd->weighted_examples += 1; all.sd->example_number++; bool hit_loss = false; - for (example** ecc=ec_seq.begin; ecc!=ec_seq.end; ecc++) + for (example** ecc=l.ec_seq.begin; ecc!=l.ec_seq.end; ecc++) output_example(all, *ecc, hit_loss); if (all.raw_prediction > 0) - all.print_text(all.raw_prediction, "", ec_seq[0]->tag); + all.print_text(all.raw_prediction, "", l.ec_seq[0]->tag); } } - void clear_seq(vw& all) + void clear_seq(vw& all, ldf& l) { - if (ec_seq.size() > 0) - for (example** ecc=ec_seq.begin; ecc!=ec_seq.end; ecc++) + if (l.ec_seq.size() > 0) + for (example** ecc=l.ec_seq.begin; ecc!=l.ec_seq.end; ecc++) if ((*ecc)->in_use) VW::finish_example(all, *ecc); - ec_seq.erase(); + l.ec_seq.erase(); } - void learn_singleline(vw*all, example*ec) { - if ((!all->training) || CSOAA::example_is_test(ec)) { + void learn_singleline(vw& all, ldf& l, example*ec) { + if ((!all.training) || CSOAA::example_is_test(ec)) { size_t prediction = 0; float min_score = FLT_MAX; - make_single_prediction(*all, ec, &prediction, &min_score); + make_single_prediction(all, l, ec, &prediction, &min_score); } else { - ec_seq.erase(); - ec_seq.push_back(ec); - do_actual_learning(*all); - ec_seq.erase(); + l.ec_seq.erase(); + l.ec_seq.push_back(ec); + do_actual_learning(all,l); + l.ec_seq.erase(); } } - void learn_multiline(vw*all, example *ec) { - if (ec_seq.size() >= all->p->ring_size - 2) { // give some wiggle room - if (ec_seq[0]->pass == 0) + void learn_multiline(vw& all, ldf& l, example *ec) { + if (l.ec_seq.size() >= all.p->ring_size - 2) { // give some wiggle room + if (l.ec_seq[0]->pass == 0) cerr << "warning: length of sequence at " << ec->example_counter << " exceeds ring size; breaking apart" << endl; - do_actual_learning(*all); - need_to_clear = true; + do_actual_learning(all, l); + l.need_to_clear = true; } - if (need_to_clear) { - output_example_seq(*all); - clear_seq(*all); - need_to_clear = false; + if (l.need_to_clear) { + output_example_seq(all, l); + clear_seq(all, l); + l.need_to_clear = false; } if (OAA::example_is_newline(ec)) { - do_actual_learning(*all); - if (!LabelDict::ec_seq_is_label_definition(ec_seq)) - global_print_newline(*all); + do_actual_learning(all, l); + if (!LabelDict::ec_seq_is_label_definition(l, l.ec_seq)) + global_print_newline(all); if (ec->in_use) - VW::finish_example(*all, ec); - need_to_clear = true; + VW::finish_example(all, ec); + l.need_to_clear = true; } else if (LabelDict::ec_is_label_definition(ec)) { - if (ec_seq.size() > 0) + if (l.ec_seq.size() > 0) cerr << "warning: label definition encountered in data block -- ignoring data!" 
<< endl; - learn_singleline(all, ec); + learn_singleline(all, l, ec); if (ec->in_use) - VW::finish_example(*all, ec); + VW::finish_example(all, ec); } else { - ec_seq.push_back(ec); + l.ec_seq.push_back(ec); } } - void learn(void*a, example*ec) { + void learn(void*a, void* d, example*ec) { vw* all = (vw*)a; - if (is_singleline) learn_singleline(all, ec); - else learn_multiline(all, ec); + ldf* l = (ldf*)d; + if (l->is_singleline) learn_singleline(*all,*l, ec); + else learn_multiline(*all,*l, ec); } - void finish(void* a) + void finish(void* a, void* d) { - vw* all = (vw*)a; - clear_seq(*all); - ec_seq.delete_v(); - base_finish(all); - LabelDict::free_label_features(); + ldf* l=(ldf*)d; + l->base.finish(a,l->base.data); + vw* all = (vw*)a; + clear_seq(*all, *l); + l->ec_seq.delete_v(); + LabelDict::free_label_features(*l); } - void drive_ldf_singleline(vw*all) { + void drive_ldf_singleline(vw& all, ldf& l) { example* ec = NULL; while (true) { - if ((ec = get_example(all->p)) != NULL) { //semiblocking operation. + if ((ec = get_example(all.p)) != NULL) { //semiblocking operation. if (LabelDict::ec_is_example_header(ec)) { cerr << "error: example headers not allowed in ldf singleline mode" << endl; exit(-1); } - learn_singleline(all, ec); + learn_singleline(all, l, ec); if (! LabelDict::ec_is_label_definition(ec)) { - all->sd->weighted_examples += 1; - all->sd->example_number++; + all.sd->weighted_examples += 1; + all.sd->example_number++; } bool hit_loss = false; - output_example(*all, ec, hit_loss); + output_example(all, ec, hit_loss); if (ec->in_use) - VW::finish_example(*all, ec); - } else if (parser_done(all->p)) { + VW::finish_example(all, ec); + } else if (parser_done(all.p)) { return; } } } - void drive_ldf_multiline(vw*all) { + void drive_ldf_multiline(vw& all, ldf& l) { example* ec = NULL; - read_example_this_loop = 0; - need_to_clear = false; + l.read_example_this_loop = 0; + l.need_to_clear = false; while (true) { - if ((ec = get_example(all->p)) != NULL) { // semiblocking operation - learn_multiline(all, ec); - if (need_to_clear) { - output_example_seq(*all); - clear_seq(*all); - need_to_clear = false; + if ((ec = get_example(all.p)) != NULL) { // semiblocking operation + learn_multiline(all, l, ec); + if (l.need_to_clear) { + output_example_seq(all, l); + clear_seq(all, l); + l.need_to_clear = false; } - } else if (parser_done(all->p)) { - do_actual_learning(*all); - output_example_seq(*all); - clear_seq(*all); - ec_seq.delete_v(); + } else if (parser_done(all.p)) { + do_actual_learning(all, l); + output_example_seq(all, l); + clear_seq(all, l); + l.ec_seq.delete_v(); return; } } } - void drive_ldf(void*in) + void drive(void*in, void* d) { vw* all =(vw*)in; - if (is_singleline) - drive_ldf_singleline(all); + ldf* l = (ldf*)d; + if (l->is_singleline) + drive_ldf_singleline(*all, *l); else - drive_ldf_multiline(all); + drive_ldf_multiline(*all,*l); } - - + void parse_flags(vw& all, std::vector<std::string>&opts, po::variables_map& vm, po::variables_map& vm_file) { + ldf* ld = (ldf*)calloc(1, sizeof(ldf)); + ld->need_to_clear = true; + ld->is_singleline = true; + string ldf_arg; if(vm_file.count("csoaa_ldf")) { ldf_arg = vm_file["csoaa_ldf"].as<string>(); @@ -908,7 +1049,7 @@ namespace CSOAA_AND_WAP_LDF { } else if( vm_file.count("wap_ldf") ) { ldf_arg = vm_file["wap_ldf"].as<string>(); - is_wap = true; + ld->is_wap = true; if(vm.count("wap_ldf") && ldf_arg.compare(vm["wap_ldf"].as<string>()) != 0) { ldf_arg = vm["csoaa_ldf"].as<string>(); @@ -917,7 +1058,7 @@ namespace 
CSOAA_AND_WAP_LDF { } else { ldf_arg = vm["wap_ldf"].as<string>(); - is_wap = true; + ld->is_wap = true; all.options_from_file.append(" --wap_ldf "); all.options_from_file.append(ldf_arg); } @@ -927,9 +1068,9 @@ namespace CSOAA_AND_WAP_LDF { all.sd->k = (uint32_t)-1; if (ldf_arg.compare("singleline") == 0 || ldf_arg.compare("s") == 0) - is_singleline = true; + ld->is_singleline = true; else if (ldf_arg.compare("multiline") == 0 || ldf_arg.compare("m") == 0) - is_singleline = false; + ld->is_singleline = false; else { cerr << "ldf requires either [s]ingleline or [m]ultiline argument" << endl; exit(-1); @@ -938,14 +1079,10 @@ namespace CSOAA_AND_WAP_LDF { if (all.add_constant) { all.add_constant = false; } - - if (!all.is_noop) - all.driver = drive_ldf; - base_learner = all.learn; - all.base_learn = all.learn; - all.learn = learn; - base_finish = all.finish; - all.finish = finish; + ld->label_features = v_hashmap<size_t, v_array<feature> >::v_hashmap(256, v_array<feature>(), LabelDict::size_t_eq); + + learner l = {ld, drive, learn, finish, all.l.save_load}; + all.l = l; } void global_print_newline(vw& all) @@ -959,142 +1096,6 @@ namespace CSOAA_AND_WAP_LDF { std::cerr << "write error" << std::endl; } } - } -namespace LabelDict { - bool size_t_eq(size_t a, size_t b) { return (a==b); } - v_hashmap< size_t, v_array<feature> > label_features(256, v_array<feature>(), size_t_eq); - - size_t hash_lab(size_t lab) { return 328051 + 94389193 * lab; } - - bool ec_is_label_definition(example*ec) // label defs look like "___:-1" - { - v_array<CSOAA::wclass> costs = ((CSOAA::label*)ec->ld)->costs; - for (size_t j=0; j<costs.size(); j++) - if (costs[j].x >= 0.) return false; - if (ec->indices.size() == 0) return false; - if (ec->indices.size() > 2) return false; - if (ec->indices[0] != 'l') return false; - return true; - } - - bool ec_is_example_header(example*ec) // example headers look like "0:-1" - { - v_array<CSOAA::wclass> costs = ((CSOAA::label*)ec->ld)->costs; - if (costs.size() != 1) return false; - if (costs[0].weight_index != 0) return false; - if (costs[0].x >= 0) return false; - return true; - } - - bool ec_seq_is_label_definition(v_array<example*>ec_seq) - { - if (ec_seq.size() == 0) return false; - bool is_lab = ec_is_label_definition(ec_seq[0]); - for (size_t i=1; i<ec_seq.size(); i++) { - if (is_lab != ec_is_label_definition(ec_seq[i])) { - if (!((i == ec_seq.size()-1) && (OAA::example_is_newline(ec_seq[i])))) { - cerr << "error: mixed label definition and examples in ldf data!" 
<< endl; - exit(-1); - } - } - } - return is_lab; - } - - void del_example_namespace(example*ec, char ns, v_array<feature> features) { - size_t numf = features.size(); - ec->num_features -= numf; - - assert (ec->atomics[(size_t)ns].size() >= numf); - if (ec->atomics[(size_t)ns].size() == numf) { // did NOT have ns - assert(ec->indices.size() > 0); - assert(ec->indices[ec->indices.size()-1] == (size_t)ns); - ec->indices.pop(); - ec->total_sum_feat_sq -= ec->sum_feat_sq[(size_t)ns]; - ec->atomics[(size_t)ns].erase(); - ec->sum_feat_sq[(size_t)ns] = 0.; - } else { // DID have ns - for (feature*f=features.begin; f!=features.end; f++) { - ec->sum_feat_sq[(size_t)ns] -= f->x * f->x; - ec->atomics[(size_t)ns].pop(); - } - } - } - - void add_example_namespace(example*ec, char ns, v_array<feature> features) { - bool has_ns = false; - for (size_t i=0; i<ec->indices.size(); i++) { - if (ec->indices[i] == (size_t)ns) { - has_ns = true; - break; - } - } - if (has_ns) { - ec->total_sum_feat_sq -= ec->sum_feat_sq[(size_t)ns]; - } else { - ec->indices.push_back((size_t)ns); - ec->sum_feat_sq[(size_t)ns] = 0; - } - - for (feature*f=features.begin; f!=features.end; f++) { - ec->sum_feat_sq[(size_t)ns] += f->x * f->x; - ec->atomics[(size_t)ns].push_back(*f); - } - - ec->num_features += features.size(); - ec->total_sum_feat_sq += ec->sum_feat_sq[(size_t)ns]; - } - - - - void add_example_namespaces_from_example(example*target, example*source) { - for (unsigned char* idx=source->indices.begin; idx!=source->indices.end; idx++) { - if (*idx == constant_namespace) continue; - add_example_namespace(target, (char)*idx, source->atomics[*idx]); - } - } - - void del_example_namespaces_from_example(example*target, example*source) { - //for (size_t*idx=source->indices.begin; idx!=source->indices.end; idx++) { - unsigned char* idx = source->indices.end; - idx--; - for (; idx>=source->indices.begin; idx--) { - if (*idx == constant_namespace) continue; - del_example_namespace(target, (char)*idx, source->atomics[*idx]); - } - } - - void add_example_namespace_from_memory(example*ec, size_t lab) { - size_t lab_hash = hash_lab(lab); - v_array<feature> features = label_features.get(lab, lab_hash); - if (features.size() == 0) return; - add_example_namespace(ec, 'l', features); - } - - void del_example_namespace_from_memory(example* ec, size_t lab) { - size_t lab_hash = hash_lab(lab); - v_array<feature> features = label_features.get(lab, lab_hash); - if (features.size() == 0) return; - del_example_namespace(ec, 'l', features); - } - - void set_label_features(size_t lab, v_array<feature>features) { - size_t lab_hash = hash_lab(lab); - if (label_features.contains(lab, lab_hash)) { return; } - label_features.put_after_get(lab, lab_hash, features); - } - - void free_label_features() { - void* label_iter = LabelDict::label_features.iterator(); - while (label_iter != NULL) { - v_array<feature> features = LabelDict::label_features.iterator_get_value(label_iter); - features.erase(); - features.delete_v(); - - label_iter = LabelDict::label_features.iterator_next(label_iter); - } - } -} diff --git a/vowpalwabbit/csoaa.h b/vowpalwabbit/csoaa.h index fa1fcb34..d8722ea5 100644 --- a/vowpalwabbit/csoaa.h +++ b/vowpalwabbit/csoaa.h @@ -58,16 +58,4 @@ namespace CSOAA_AND_WAP_LDF { const label_parser cs_label_parser = CSOAA::cs_label_parser; } -namespace LabelDict { - bool ec_is_label_definition(example*ec); - bool ec_is_example_header(example*ec); - bool ec_seq_is_label_definition(v_array<example*>ec_seq); - void 
add_example_namespaces_from_example(example*target, example*source); - void del_example_namespaces_from_example(example*target, example*source); - void add_example_namespace_from_memory(example*ec, size_t lab); - void del_example_namespace_from_memory(example* ec, size_t lab); - void set_label_features(size_t lab, v_array<feature>features); - void free_label_features(); -} - #endif diff --git a/vowpalwabbit/ect.cc b/vowpalwabbit/ect.cc index d5a81563..4486db0d 100644 --- a/vowpalwabbit/ect.cc +++ b/vowpalwabbit/ect.cc @@ -23,11 +23,6 @@ using namespace std; namespace ECT { - - //nonreentrant - uint32_t k = 1; - uint32_t errors = 0; - struct direction { size_t id; //unique id for node size_t tournament; //unique id for node @@ -37,24 +32,30 @@ namespace ECT uint32_t right; //down traversal, right bool last; }; - - v_array<direction> directions;//The nodes of the tournament datastructure - - v_array<v_array<v_array<uint32_t > > > all_levels; - - v_array<uint32_t> final_nodes; //The final nodes of each tournament. - - v_array<size_t> up_directions; //On edge e, which node n is in the up direction? - v_array<size_t> down_directions;//On edge e, which node n is in the down direction? - - size_t tree_height = 0; //The height of the final tournament. - uint32_t last_pair = 0; - - uint32_t increment = 0; + struct ect{ + uint32_t k; + uint32_t errors; + v_array<direction> directions;//The nodes of the tournament datastructure + + v_array<v_array<v_array<uint32_t > > > all_levels; + + v_array<uint32_t> final_nodes; //The final nodes of each tournament. + + v_array<size_t> up_directions; //On edge e, which node n is in the up direction? + v_array<size_t> down_directions;//On edge e, which node n is in the down direction? + + size_t tree_height; //The height of the final tournament. 
+ + uint32_t last_pair; + + uint32_t increment; + + v_array<bool> tournaments_won; + + learner base; + }; - v_array<bool> tournaments_won; - bool exists(v_array<size_t> db) { for (size_t i = 0; i< db.size();i++) @@ -94,19 +95,7 @@ namespace ECT cout << endl; } - void print_state() - { - cout << "all_levels = " << endl; - for (size_t l = 0; l < all_levels.size(); l++) - print_level(all_levels[l]); - - cout << "directions = " << endl; - for (size_t i = 0; i < directions.size(); i++) - cout << " | " << directions[i].id << " t" << directions[i].tournament << " " << directions[i].winner << " " << directions[i].loser << " " << directions[i].left << " " << directions[i].right << " " << directions[i].last; - cout << endl; - } - - void create_circuit(vw& all, uint32_t max_label, uint32_t eliminations) + void create_circuit(vw& all, ect& e, uint32_t max_label, uint32_t eliminations) { if (max_label == 1) return; @@ -119,7 +108,7 @@ namespace ECT { t.push_back(i); direction d = {i,0,0,0,0,0, false}; - directions.push_back(d); + e.directions.push_back(d); } tournaments.push_back(t); @@ -127,16 +116,16 @@ namespace ECT for (size_t i = 0; i < eliminations-1; i++) tournaments.push_back(v_array<uint32_t>()); - all_levels.push_back(tournaments); + e.all_levels.push_back(tournaments); size_t level = 0; - uint32_t node = (uint32_t)directions.size(); + uint32_t node = (uint32_t)e.directions.size(); - while (not_empty(all_levels[level])) + while (not_empty(e.all_levels[level])) { v_array<v_array<uint32_t > > new_tournaments; - tournaments = all_levels[level]; + tournaments = e.all_levels[level]; for (size_t t = 0; t < tournaments.size(); t++) { @@ -153,59 +142,56 @@ namespace ECT uint32_t right = tournaments[t][2*j+1]; direction d = {id,t,0,0,left,right, false}; - directions.push_back(d); - uint32_t direction_index = (uint32_t)directions.size()-1; - if (directions[left].tournament == t) - directions[left].winner = direction_index; + e.directions.push_back(d); + uint32_t direction_index = (uint32_t)e.directions.size()-1; + if (e.directions[left].tournament == t) + e.directions[left].winner = direction_index; else - directions[left].loser = direction_index; - if (directions[right].tournament == t) - directions[right].winner = direction_index; + e.directions[left].loser = direction_index; + if (e.directions[right].tournament == t) + e.directions[right].winner = direction_index; else - directions[right].loser = direction_index; - if (directions[left].last == true) - directions[left].winner = direction_index; + e.directions[right].loser = direction_index; + if (e.directions[left].last == true) + e.directions[left].winner = direction_index; if (tournaments[t].size() == 2 && (t == 0 || tournaments[t-1].size() == 0)) { - directions[direction_index].last = true; + e.directions[direction_index].last = true; if (t+1 < tournaments.size()) new_tournaments[t+1].push_back(id); else // winner eliminated. - directions[direction_index].winner = 0; - final_nodes.push_back((uint32_t)(directions.size()- 1)); + e.directions[direction_index].winner = 0; + e.final_nodes.push_back((uint32_t)(e.directions.size()- 1)); } else new_tournaments[t].push_back(id); if (t+1 < tournaments.size()) new_tournaments[t+1].push_back(id); else // loser eliminated. 
- directions[direction_index].loser = 0; + e.directions[direction_index].loser = 0; } if (tournaments[t].size() % 2 == 1) new_tournaments[t].push_back(tournaments[t].last()); } - all_levels.push_back(new_tournaments); + e.all_levels.push_back(new_tournaments); level++; } - last_pair = (max_label - 1)*(eliminations); + e.last_pair = (max_label - 1)*(eliminations); if ( max_label > 1) - tree_height = final_depth(eliminations); + e.tree_height = final_depth(eliminations); - if (last_pair > 0) { - all.base_learner_nb_w *= (last_pair + (eliminations-1)); - increment = (uint32_t) all.length() / all.base_learner_nb_w * all.stride; + if (e.last_pair > 0) { + all.base_learner_nb_w *= (e.last_pair + (eliminations-1)); + e.increment = (uint32_t) all.length() / all.base_learner_nb_w * all.stride; } } - void (*base_learner)(void*, example*) = NULL; - void (*base_finish)(void*) = NULL; - - size_t ect_predict(vw& all, example* ec) + size_t ect_predict(vw& all, ect& e, example* ec) { - if (k == (size_t)1) + if (e.k == (size_t)1) return 1; uint32_t finals_winner = 0; @@ -214,19 +200,19 @@ namespace ECT label_data simple_temp = {FLT_MAX, 0., 0.}; ec->ld = & simple_temp; - for (size_t i = tree_height-1; i != (size_t)0 -1; i--) + for (size_t i = e.tree_height-1; i != (size_t)0 -1; i--) { - if ((finals_winner | (((size_t)1) << i)) <= errors) + if ((finals_winner | (((size_t)1) << i)) <= e.errors) {// a real choice exists uint32_t offset = 0; - uint32_t problem_number = last_pair + (finals_winner | (((uint32_t)1) << i)) - 1; //This is unique. - offset = problem_number*increment; + uint32_t problem_number = e.last_pair + (finals_winner | (((uint32_t)1) << i)) - 1; //This is unique. + offset = problem_number*e.increment; update_example_indicies(all.audit, ec,offset); ec->partial_prediction = 0; - base_learner(&all, ec); + e.base.learn(&all,e.base.data, ec); update_example_indicies(all.audit, ec,-offset); @@ -236,21 +222,21 @@ namespace ECT } } - uint32_t id = final_nodes[finals_winner]; - while (id >= k) + uint32_t id = e.final_nodes[finals_winner]; + while (id >= e.k) { - uint32_t offset = (id-k)*increment; + uint32_t offset = (id-e.k)*e.increment; ec->partial_prediction = 0; update_example_indicies(all.audit, ec,offset); - base_learner(&all, ec); + e.base.learn(&all,e.base.data, ec); float pred = ec->final_prediction; update_example_indicies(all.audit, ec,-offset); if (pred > 0.) 
- id = directions[id].right; + id = e.directions[id].right; else - id = directions[id].left; + id = e.directions[id].left; } return id+1; } @@ -263,18 +249,18 @@ namespace ECT return false; } - void ect_train(vw& all, example* ec) + void ect_train(vw& all, ect& e, example* ec) { - if (k == 1)//nothing to do + if (e.k == 1)//nothing to do return; OAA::mc_label * mc = (OAA::mc_label*)ec->ld; label_data simple_temp = {1.,mc->weight,0.}; - tournaments_won.erase(); + e.tournaments_won.erase(); - uint32_t id = directions[mc->label-1].winner; - bool left = directions[id].left == mc->label - 1; + uint32_t id = e.directions[mc->label-1].winner; + bool left = e.directions[id].left == mc->label - 1; do { if (left) @@ -285,15 +271,15 @@ namespace ECT simple_temp.weight = mc->weight; ec->ld = &simple_temp; - uint32_t offset = (id-k)*increment; + uint32_t offset = (id-e.k)*e.increment; update_example_indicies(all.audit, ec,offset); ec->partial_prediction = 0; - base_learner(&all, ec); + e.base.learn(&all,e.base.data, ec); simple_temp.weight = 0.; ec->partial_prediction = 0; - base_learner(&all, ec);//inefficient, we should extract final prediction exactly. + e.base.learn(&all,e.base.data, ec);//inefficient, we should extract final prediction exactly. float pred = ec->final_prediction; update_example_indicies(all.audit, ec,-offset); @@ -301,39 +287,39 @@ namespace ECT if (won) { - if (!directions[id].last) - left = directions[directions[id].winner].left == id; + if (!e.directions[id].last) + left = e.directions[e.directions[id].winner].left == id; else - tournaments_won.push_back(true); - id = directions[id].winner; + e.tournaments_won.push_back(true); + id = e.directions[id].winner; } else { - if (!directions[id].last) + if (!e.directions[id].last) { - left = directions[directions[id].loser].left == id; - if (directions[id].loser == 0) - tournaments_won.push_back(false); + left = e.directions[e.directions[id].loser].left == id; + if (e.directions[id].loser == 0) + e.tournaments_won.push_back(false); } else - tournaments_won.push_back(false); - id = directions[id].loser; + e.tournaments_won.push_back(false); + id = e.directions[id].loser; } } while(id != 0); - if (tournaments_won.size() < 1) + if (e.tournaments_won.size() < 1) cout << "badness!" << endl; //tournaments_won is a bit vector determining which tournaments the label won. - for (size_t i = 0; i < tree_height; i++) + for (size_t i = 0; i < e.tree_height; i++) { - for (uint32_t j = 0; j < tournaments_won.size()/2; j++) + for (uint32_t j = 0; j < e.tournaments_won.size()/2; j++) { - bool left = tournaments_won[j*2]; - bool right = tournaments_won[j*2+1]; + bool left = e.tournaments_won[j*2]; + bool right = e.tournaments_won[j*2+1]; if (left == right)//no query to do - tournaments_won[j] = left; + e.tournaments_won[j] = left; else //query to do { float label; @@ -342,72 +328,73 @@ namespace ECT else label = 1; simple_temp.label = label; - simple_temp.weight = (float)(1 << (tree_height -i -1)); + simple_temp.weight = (float)(1 << (e.tree_height -i -1)); ec->ld = & simple_temp; - uint32_t problem_number = last_pair + j*(1 << (i+1)) + (1 << i) -1; + uint32_t problem_number = e.last_pair + j*(1 << (i+1)) + (1 << i) -1; - uint32_t offset = problem_number*increment; + uint32_t offset = problem_number*e.increment; update_example_indicies(all.audit, ec,offset); ec->partial_prediction = 0; - base_learner(&all, ec); + e.base.learn(&all,e.base.data, ec); update_example_indicies(all.audit, ec,-offset); float pred = ec->final_prediction; if (pred > 0.) 
- tournaments_won[j] = right; + e.tournaments_won[j] = right; else - tournaments_won[j] = left; + e.tournaments_won[j] = left; } - if (tournaments_won.size() %2 == 1) - tournaments_won[tournaments_won.size()/2] = tournaments_won[tournaments_won.size()-1]; - tournaments_won.end = tournaments_won.begin+(1+tournaments_won.size())/2; + if (e.tournaments_won.size() %2 == 1) + e.tournaments_won[e.tournaments_won.size()/2] = e.tournaments_won[e.tournaments_won.size()-1]; + e.tournaments_won.end = e.tournaments_won.begin+(1+e.tournaments_won.size())/2; } } } - void learn(void*a, example* ec) + void learn(void*a, void* d, example* ec) { vw* all = (vw*)a; + ect* e=(ect*)d; OAA::mc_label* mc = (OAA::mc_label*)ec->ld; - if (mc->label > k) + if (mc->label > e->k) cout << "label > maximum label! This won't work right." << endl; - size_t new_label = ect_predict(*all, ec); + size_t new_label = ect_predict(*all, *e, ec); ec->ld = mc; if (mc->label != (uint32_t)-1 && all->training) - ect_train(*all, ec); + ect_train(*all, *e, ec); ec->ld = mc; *(OAA::prediction_t*)&(ec->final_prediction) = new_label; } - void finish(void* all) + void finish(void* all, void* d) { - for (size_t l = 0; l < all_levels.size(); l++) + ect* e = (ect*)d; + e->base.finish(all, e->base.data); + for (size_t l = 0; l < e->all_levels.size(); l++) { - for (size_t t = 0; t < all_levels[l].size(); t++) - all_levels[l][t].delete_v(); - all_levels[l].delete_v(); + for (size_t t = 0; t < e->all_levels[l].size(); t++) + e->all_levels[l][t].delete_v(); + e->all_levels[l].delete_v(); } - final_nodes.delete_v(); + e->final_nodes.delete_v(); - up_directions.delete_v(); + e->up_directions.delete_v(); - directions.delete_v(); + e->directions.delete_v(); - down_directions.delete_v(); + e->down_directions.delete_v(); - tournaments_won.delete_v(); - - base_finish(all); + e->tournaments_won.delete_v(); } - void drive_ect(void* in) + void drive(void* in, void* d) { vw* all = (vw*)in; example* ec = NULL; @@ -415,7 +402,7 @@ namespace ECT { if ((ec = get_example(all->p)) != NULL)//semiblocking operation. { - learn(all, ec); + learn(all, d, ec); OAA::output_example(*all, ec); VW::finish_example(*all, ec); } @@ -430,6 +417,7 @@ namespace ECT void parse_flags(vw& all, std::vector<std::string>&opts, po::variables_map& vm, po::variables_map& vm_file) { + ect* data = (ect*)calloc(1, sizeof(ect)); po::options_description desc("ECT options"); desc.add_options() ("error", po::value<size_t>(), "error in ECT"); @@ -448,48 +436,43 @@ namespace ECT po::notify(vm_file); //first parse for number of actions - k = 0; + data->k = 0; if( vm_file.count("ect") ) { - k = (int)vm_file["ect"].as<size_t>(); - if( vm.count("ect") && vm["ect"].as<size_t>() != k ) - std::cerr << "warning: you specified a different number of actions through --ect than the one loaded from predictor. Pursuing with loaded value of: " << k << endl; + data->k = (int)vm_file["ect"].as<size_t>(); + if( vm.count("ect") && vm["ect"].as<size_t>() != data->k ) + std::cerr << "warning: you specified a different number of actions through --ect than the one loaded from predictor. 
Pursuing with loaded value of: " << data->k << endl; } else { - k = (int)vm["ect"].as<size_t>(); + data->k = (int)vm["ect"].as<size_t>(); //append ect with nb_actions to options_from_file so it is saved to regressor later std::stringstream ss; - ss << " --ect " << k; + ss << " --ect " << data->k; all.options_from_file.append(ss.str()); } if(vm_file.count("error")) { - errors = (uint32_t)vm_file["error"].as<size_t>(); - if (vm.count("error") && (uint32_t)vm["error"].as<size_t>() != errors) { - cerr << "warning: specified value for --error different than the one loaded from predictor file. Pursuing with loaded value of: " << errors << endl; + data->errors = (uint32_t)vm_file["error"].as<size_t>(); + if (vm.count("error") && (uint32_t)vm["error"].as<size_t>() != data->errors) { + cerr << "warning: specified value for --error different than the one loaded from predictor file. Pursuing with loaded value of: " << data->errors << endl; } } else if (vm.count("error")) { - errors = (uint32_t)vm["error"].as<size_t>(); + data->errors = (uint32_t)vm["error"].as<size_t>(); //append error flag to options_from_file so it is saved in regressor file later stringstream ss; - ss << " --error " << errors; + ss << " --error " << data->errors; all.options_from_file.append(ss.str()); } else { - errors = 0; + data->errors = 0; } *(all.p->lp) = OAA::mc_label_parser; - all.driver = drive_ect; - base_learner = all.learn; - all.base_learn = all.learn; - all.learn = learn; - - base_finish = all.finish; - all.finish = finish; - - create_circuit(all, k, errors+1); + data->base = all.l; + create_circuit(all, *data, data->k, data->errors+1); + + learner l = {data, drive, learn, finish, all.l.save_load}; + all.l = l; } - } diff --git a/vowpalwabbit/gd.cc b/vowpalwabbit/gd.cc index 2bf4542e..a493eaa0 100644 --- a/vowpalwabbit/gd.cc +++ b/vowpalwabbit/gd.cc @@ -27,15 +27,13 @@ license as described in the file LICENSE. 
#include "simple_label.h" #include "allreduce.h" #include "accumulate.h" +#include "learner.h" using namespace std; namespace GD { -//nonreentrant -size_t gd_current_pass = 0; - void predict(vw& all, example* ex); void sync_weights(vw& all); @@ -120,11 +118,11 @@ inline void specialized_update(vw& all, float x, uint32_t fi, float avg_norm, fl w[0] += update * x * t; } -void learn_gd(void* a, example* ec) +void learn(void* a, void* d, example* ec) { vw* all = (vw*)a; assert(ec->in_use); - if (ec->pass != gd_current_pass) + if (ec->pass != all->current_pass) { if(all->span_server != "") { @@ -137,11 +135,11 @@ void learn_gd(void* a, example* ec) if (all->save_per_pass) { sync_weights(*all); - save_predictor(*all, all->final_regressor_name, gd_current_pass); + save_predictor(*all, all->final_regressor_name, all->current_pass); } all->eta *= all->eta_decay_rate; - gd_current_pass = ec->pass; + all->current_pass = ec->pass; } if (!command_example(*all, ec)) @@ -163,7 +161,7 @@ void learn_gd(void* a, example* ec) } } -void finish_gd(void* a) + void finish(void* a, void* d) { vw* all = (vw*)a; sync_weights(*all); @@ -173,6 +171,8 @@ void finish_gd(void* a) else accumulate_avg(*all, all->span_server, all->reg, 0); } + size_t* current_pass = (size_t*)d; + free(current_pass); } void sync_weights(vw& all) { @@ -748,7 +748,7 @@ void save_load_online_state(vw& all, io_buf& model_file, bool read, bool text) while ((!read && i < length) || (read && brw >0)); } -void save_load(void* in, io_buf& model_file, bool read, bool text) +void save_load(void* in, void* data, io_buf& model_file, bool read, bool text) { vw* all=(vw*)in; if(read) @@ -783,7 +783,7 @@ void save_load(void* in, io_buf& model_file, bool read, bool text) } } -void drive_gd(void* in) +void driver(void* in, void* data) { vw* all = (vw*)in; example* ec = NULL; @@ -792,16 +792,20 @@ void drive_gd(void* in) { if ((ec = get_example(all->p)) != NULL)//semiblocking operation. { - learn_gd(all, ec); + learn(all, data, ec); return_simple_example(*all, ec); } else if (parser_done(all->p)) - { - finish_gd(all); - return; - } + return; else ;//busywait when we have predicted on all examples but not yet trained on all. } } + + learner get_learner() + { + size_t* current_pass = (size_t*)calloc(1, sizeof(size_t)); + learner ret = {current_pass,driver,learn,finish,save_load}; + return ret; + } } diff --git a/vowpalwabbit/gd.h b/vowpalwabbit/gd.h index b62a59b3..13612774 100644 --- a/vowpalwabbit/gd.h +++ b/vowpalwabbit/gd.h @@ -16,6 +16,7 @@ license as described in the file LICENSE. 
#include "parser.h" #include "allreduce.h" #include "sparse_dense.h" +#include "learner.h" namespace GD{ void print_result(int f, float res, v_array<char> tag); @@ -31,15 +32,11 @@ void train_offset_example(regressor& r, example* ex, size_t offset); void compute_update(example* ec); void offset_train(regressor ®, example* &ec, float update, size_t offset); void train_one_example_single_thread(regressor& r, example* ex); -void drive_gd(void*); -void finish_gd(void*); -void learn_gd(void*, example* ec); -void save_load(void* in, io_buf& model_file, bool read, bool text); + learner get_learner(); void save_load_regressor(vw& all, io_buf& model_file, bool read, bool text); void output_and_account_example(example* ec); bool command_example(vw&, example* ec); - template <float (*T)(vw&,float,uint32_t)> float inline_predict(vw& all, example* &ec) { diff --git a/vowpalwabbit/gd_mf.cc b/vowpalwabbit/gd_mf.cc index 3ecb2a81..d7034053 100644 --- a/vowpalwabbit/gd_mf.cc +++ b/vowpalwabbit/gd_mf.cc @@ -207,7 +207,7 @@ float mf_predict(vw& all, example* ex) return ex->final_prediction; } -void save_load(void* in, io_buf& model_file, bool read, bool text) + void save_load(void* in, void* d, io_buf& model_file, bool read, bool text) { vw* all = (vw*)in; uint32_t length = 1 << all->num_bits; @@ -261,27 +261,39 @@ void save_load(void* in, io_buf& model_file, bool read, bool text) } } -void drive(void* in) + void learn(void* in, void* d, example* ec) + { + vw* all = (vw*)in; + size_t* current_pass = (size_t*) d; + if (ec->pass != *current_pass) { + all->eta *= all->eta_decay_rate; + *current_pass = ec->pass; + } + if (!GD::command_example(*all, ec)) + { + mf_predict(*all,ec); + if (all->training && ((label_data*)(ec->ld))->label != FLT_MAX) + mf_inline_train(*all, ec, ec->eta_round); + } + } + + void finish(void* a, void* d) + { + size_t* current_pass = (size_t*)d; + free(current_pass); + } + + void drive(void* in, void* d) { vw* all = (vw*)in; + example* ec = NULL; - size_t current_pass = 0; while ( true ) { if ((ec = get_example(all->p)) != NULL)//blocking operation. { - if (ec->pass != current_pass) { - all->eta *= all->eta_decay_rate; - //save_predictor(*all, all->final_regressor_name, current_pass); - current_pass = ec->pass; - } - if (!GD::command_example(*all, ec)) - { - mf_predict(*all,ec); - if (all->training && ((label_data*)(ec->ld))->label != FLT_MAX) - mf_inline_train(*all, ec, ec->eta_round); - } + learn(in,d,ec); return_simple_example(*all, ec); } else if (parser_done(all->p)) @@ -291,4 +303,10 @@ void drive(void* in) } } + void parse_flags(vw& all) + { + size_t* current_pass = (size_t*)calloc(1, sizeof(size_t)); + learner t = {current_pass,drive,learn,finish,save_load}; + all.l = t; + } } diff --git a/vowpalwabbit/gd_mf.h b/vowpalwabbit/gd_mf.h index 8bd6c4c6..b88f12fd 100644 --- a/vowpalwabbit/gd_mf.h +++ b/vowpalwabbit/gd_mf.h @@ -13,7 +13,6 @@ license as described in the file LICENSE. 
#include "gd.h" namespace GDMF{ - void drive(void*); - void save_load(void* in, io_buf& model_file, bool read, bool text); + void parse_flags(vw& all); } #endif diff --git a/vowpalwabbit/global_data.cc b/vowpalwabbit/global_data.cc index ae91381b..c5031bfa 100644 --- a/vowpalwabbit/global_data.cc +++ b/vowpalwabbit/global_data.cc @@ -172,6 +172,12 @@ void set_mm(shared_data* sd, float label) void noop_mm(shared_data* sd, float label) {} +void vw::learn(void* a, example* ec) +{ + vw* all = (vw*)a; + all->l.learn(a,all->l.data,ec); +} + vw::vw() { sd = (shared_data *) calloc(1, sizeof(shared_data)); @@ -185,6 +191,8 @@ vw::vw() reg_mode = 0; + current_pass = 0; + bfgs = false; hessian_on = false; stride = 1; @@ -200,13 +208,10 @@ vw::vw() m = 15; save_resume = false; - driver = GD::drive_gd; - learn = GD::learn_gd; - finish = GD::finish_gd; - save_load = GD::save_load; - set_minmax = set_mm; + l = GD::get_learner(); + scorer = l; - base_learn = NULL; + set_minmax = set_mm; base_learner_nb_w = 1; diff --git a/vowpalwabbit/global_data.h b/vowpalwabbit/global_data.h index b2ba8a39..c58c934a 100644 --- a/vowpalwabbit/global_data.h +++ b/vowpalwabbit/global_data.h @@ -14,6 +14,7 @@ license as described in the file LICENSE. #include "comp_io.h" #include "example.h" #include "config.h" +#include "learner.h" struct version_struct { int major; @@ -115,13 +116,15 @@ struct vw { parser* p; - void (*driver)(void *); - void (*learn)(void *, example*); - void (*base_learn)(void *, example*); - void (*finish)(void *); - void (*save_load)(void *, io_buf&, bool, bool); + learner l;//the top level leaner + learner scorer;//a scoring function + + void learn(void*, example*); + void (*set_minmax)(shared_data* sd, float label); + size_t current_pass; + uint32_t num_bits; // log_2 of the number of features. bool default_bits; @@ -188,8 +191,6 @@ struct vw { bool nonormalize; bool do_reset_source; - uint32_t csoaa_increment; - float normalized_sum_norm_x; size_t normalized_idx; //offset idx where the norm is stored (1 or 2 depending on whether adaptive is true) diff --git a/vowpalwabbit/lda_core.cc b/vowpalwabbit/lda_core.cc index be83fc5a..d73df641 100644 --- a/vowpalwabbit/lda_core.cc +++ b/vowpalwabbit/lda_core.cc @@ -564,8 +564,7 @@ void save_load(void* in, io_buf& model_file, bool read, bool text) void parse_flags(vw&all, std::vector<std::string>&opts, po::variables_map& vm)
{
-
- po::options_description desc("Searn options");
+ po::options_description desc("LDA options");
desc.add_options()
("lda_alpha", po::value<float>(&all.lda_alpha), "Prior on sparsity of per-document topic weights")
("lda_rho", po::value<float>(&all.lda_rho), "Prior on sparsity of topic distributions")
@@ -750,4 +749,8 @@ void drive(void* in)
}
}
+ void parse_args()
+ {
+ }
+
}
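Stepping back from the per-file hunks: the change repeated in ldf, ect, gd and gd_mf above (and in nn and noop below) is the point of this commit. Each formerly non-reentrant module moves its file-scope globals into a calloc'd state struct, remembers the previous all.l as its base, and installs itself as the new top-level learner, so reductions compose as a stack. Here is a self-contained sketch of that pattern, with hypothetical names (my_state, my_learn, my_finish, base_*); the learner struct itself mirrors the new vowpalwabbit/learner.h:

#include <cstdio>
#include <cstdlib>

struct example { float prediction; };
struct io_buf {};

struct learner {                // mirrors vowpalwabbit/learner.h
  void* data;
  void (*driver)(void* all, void* data);
  void (*learn)(void* all, void* data, example*);
  void (*finish)(void* all, void* data);
  void (*save_load)(void* all, void* data, io_buf&, bool read, bool text);
};

// a trivial bottom learner standing in for, e.g., GD::get_learner()
void base_learn(void*, void*, example* ec) { ec->prediction = 1.f; }
void base_finish(void*, void*) {}
void base_driver(void*, void*) {}
void base_save_load(void*, void*, io_buf&, bool, bool) {}

struct my_state {               // per-reduction data, calloc'd like ldf/ect/nn
  learner base;                 // the learner this reduction wraps
};

void my_learn(void* all, void* d, example* ec) {
  my_state* s = (my_state*)d;
  // a reduction may rewrite ec here, then delegates downward:
  s->base.learn(all, s->base.data, ec);
}

void my_finish(void* all, void* d) {
  my_state* s = (my_state*)d;
  s->base.finish(all, s->base.data);  // finish the stack below us first
  free(s);
}

int main() {
  learner bottom = { NULL, base_driver, base_learn, base_finish, base_save_load };
  my_state* s = (my_state*)calloc(1, sizeof(my_state));
  s->base = bottom;                   // remember the old all.l as the base
  learner top = { s, base_driver, my_learn, my_finish, bottom.save_load };
  example ec = { 0.f };
  top.learn(NULL, top.data, &ec);     // what vw::learn() now forwards to
  std::printf("prediction = %f\n", ec.prediction);
  top.finish(NULL, top.data);
  return 0;
}

Note the finish ordering, which the real code follows too: a reduction finishes its base before freeing its own state, so teardown unwinds the stack from the top down.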
diff --git a/vowpalwabbit/lda_core.h b/vowpalwabbit/lda_core.h index e087da04..df0ca613 100644 --- a/vowpalwabbit/lda_core.h +++ b/vowpalwabbit/lda_core.h @@ -7,8 +7,6 @@ license as described in the file LICENSE. #define LDA_CORE_H namespace LDA{ - void drive(void*); - void save_load(void* in, io_buf& model_file, bool read, bool text); void parse_flags(vw&, std::vector<std::string>&, po::variables_map&); } diff --git a/vowpalwabbit/learner.h b/vowpalwabbit/learner.h new file mode 100644 index 00000000..e5e2ba89 --- /dev/null +++ b/vowpalwabbit/learner.h @@ -0,0 +1,18 @@ +/* +Copyright (c) by respective owners including Yahoo!, Microsoft, and +individual contributors. All rights reserved. Released under a BSD +license as described in the file LICENSE. + */ +#ifndef LEARNER_H +#define LEARNER_H +// This is the interface for a learning algorithm + +struct learner { + void* data; + + void (*driver)(void* all, void* data); + void (*learn)(void* all, void* data, example*); + void (*finish)(void* all, void* data); + void (*save_load)(void* all, void* data, io_buf&, bool read, bool text); +}; +#endif diff --git a/vowpalwabbit/nn.cc b/vowpalwabbit/nn.cc index a15d4757..286fc403 100644 --- a/vowpalwabbit/nn.cc +++ b/vowpalwabbit/nn.cc @@ -18,27 +18,23 @@ license as described in the file LICENSE. using namespace std; namespace NN { - //nonreentrant - uint32_t k=0; - uint32_t increment=0; - loss_function* squared_loss; - example output_layer; const float hidden_min_activation = -3; const float hidden_max_activation = 3; const int nn_constant = 533357803; - bool dropout = false; - uint64_t xsubi; - uint64_t save_xsubi; - size_t nn_current_pass = 0; - bool inpass = false; - - static void - free_stuff (void) - { - delete squared_loss; - free (output_layer.indices.begin); - free (output_layer.atomics[nn_output_namespace].begin); - } + + struct nn { + uint32_t k; + uint32_t increment; + loss_function* squared_loss; + example output_layer; + bool dropout; + uint64_t xsubi; + uint64_t save_xsubi; + size_t nn_current_pass; + bool inpass; + + learner base; + }; #define cast_uint32_t static_cast<uint32_t> @@ -68,48 +64,48 @@ namespace NN { void (*base_learner)(void*,example*) = NULL; - void learn_with_output(vw*all, example* ec, bool shouldOutput) + void learn_with_output(vw& all, nn& n, example* ec, bool shouldOutput) { - if (GD::command_example(*all, ec)) { + if (GD::command_example(all, ec)) { return; } - if (all->bfgs && ec->pass != nn_current_pass) { - xsubi = save_xsubi; - nn_current_pass = ec->pass; + if (all.bfgs && ec->pass != n.nn_current_pass) { + n.xsubi = n.save_xsubi; + n.nn_current_pass = ec->pass; } label_data* ld = (label_data*)ec->ld; float save_label = ld->label; - void (*save_set_minmax) (shared_data*, float) = all->set_minmax; + void (*save_set_minmax) (shared_data*, float) = all.set_minmax; float save_min_label; float save_max_label; - float dropscale = dropout ? 2.0f : 1.0f; - loss_function* save_loss = all->loss; + float dropscale = n.dropout ? 
2.0f : 1.0f; + loss_function* save_loss = all.loss; - float* hidden_units = (float*) alloca (k * sizeof (float)); - bool* dropped_out = (bool*) alloca (k * sizeof (bool)); + float* hidden_units = (float*) alloca (n.k * sizeof (float)); + bool* dropped_out = (bool*) alloca (n.k * sizeof (bool)); string outputString; stringstream outputStringStream(outputString); - all->set_minmax = noop_mm; - all->loss = squared_loss; - save_min_label = all->sd->min_label; - all->sd->min_label = hidden_min_activation; - save_max_label = all->sd->max_label; - all->sd->max_label = hidden_max_activation; + all.set_minmax = noop_mm; + all.loss = n.squared_loss; + save_min_label = all.sd->min_label; + all.sd->min_label = hidden_min_activation; + save_max_label = all.sd->max_label; + all.sd->max_label = hidden_max_activation; ld->label = FLT_MAX; - for (unsigned int i = 0; i < k; ++i) + for (unsigned int i = 0; i < n.k; ++i) { if (i != 0) - update_example_indicies(all->audit, ec, increment); + update_example_indicies(all.audit, ec, n.increment); ec->partial_prediction = 0.; - base_learner(all,ec); - hidden_units[i] = GD::finalize_prediction (*all, ec->partial_prediction); + n.base.learn(&all,n.base.data,ec); + hidden_units[i] = GD::finalize_prediction (all, ec->partial_prediction); - dropped_out[i] = (dropout && merand48 (xsubi) < 0.5); + dropped_out[i] = (n.dropout && merand48 (n.xsubi) < 0.5); if (shouldOutput) { if (i > 0) outputStringStream << ' '; @@ -117,10 +113,10 @@ namespace NN { } } ld->label = save_label; - all->loss = save_loss; - all->set_minmax = save_set_minmax; - all->sd->min_label = save_min_label; - all->sd->max_label = save_max_label; + all.loss = save_loss; + all.set_minmax = save_set_minmax; + all.sd->min_label = save_min_label; + all.sd->max_label = save_max_label; bool converse = false; float save_partial_prediction = 0; @@ -129,115 +125,115 @@ namespace NN { CONVERSE: // That's right, I'm using goto. So sue me. - output_layer.total_sum_feat_sq = 1; - output_layer.sum_feat_sq[nn_output_namespace] = 1; + n.output_layer.total_sum_feat_sq = 1; + n.output_layer.sum_feat_sq[nn_output_namespace] = 1; - for (unsigned int i = 0; i < k; ++i) + for (unsigned int i = 0; i < n.k; ++i) { float sigmah = (dropped_out[i]) ? 
0.0f : dropscale * fasttanh (hidden_units[i]); - output_layer.atomics[nn_output_namespace][i+1].x = sigmah; + n.output_layer.atomics[nn_output_namespace][i+1].x = sigmah; - output_layer.total_sum_feat_sq += sigmah * sigmah; - output_layer.sum_feat_sq[nn_output_namespace] += sigmah * sigmah; + n.output_layer.total_sum_feat_sq += sigmah * sigmah; + n.output_layer.sum_feat_sq[nn_output_namespace] += sigmah * sigmah; } - if (inpass) { + if (n.inpass) { // TODO: this is not correct if there is something in the // nn_output_namespace but at least it will not leak memory // in that case - update_example_indicies (all->audit, ec, increment); + update_example_indicies (all.audit, ec, n.increment); ec->indices.push_back (nn_output_namespace); v_array<feature> save_nn_output_namespace = ec->atomics[nn_output_namespace]; - ec->atomics[nn_output_namespace] = output_layer.atomics[nn_output_namespace]; - ec->sum_feat_sq[nn_output_namespace] = output_layer.sum_feat_sq[nn_output_namespace]; - ec->total_sum_feat_sq += output_layer.sum_feat_sq[nn_output_namespace]; + ec->atomics[nn_output_namespace] = n.output_layer.atomics[nn_output_namespace]; + ec->sum_feat_sq[nn_output_namespace] = n.output_layer.sum_feat_sq[nn_output_namespace]; + ec->total_sum_feat_sq += n.output_layer.sum_feat_sq[nn_output_namespace]; ec->partial_prediction = 0.; - base_learner(all,ec); - output_layer.partial_prediction = ec->partial_prediction; - output_layer.loss = ec->loss; - ec->total_sum_feat_sq -= output_layer.sum_feat_sq[nn_output_namespace]; + n.base.learn(&all, n.base.data, ec); + n.output_layer.partial_prediction = ec->partial_prediction; + n.output_layer.loss = ec->loss; + ec->total_sum_feat_sq -= n.output_layer.sum_feat_sq[nn_output_namespace]; ec->sum_feat_sq[nn_output_namespace] = 0; ec->atomics[nn_output_namespace] = save_nn_output_namespace; ec->indices.pop (); - update_example_indicies (all->audit, ec, -increment); + update_example_indicies (all.audit, ec, -n.increment); } else { - output_layer.ld = ec->ld; - output_layer.pass = ec->pass; - output_layer.partial_prediction = 0; - output_layer.eta_round = ec->eta_round; - output_layer.eta_global = ec->eta_global; - output_layer.global_weight = ec->global_weight; - output_layer.example_t = ec->example_t; - base_learner(all,&output_layer); - output_layer.ld = 0; + n.output_layer.ld = ec->ld; + n.output_layer.pass = ec->pass; + n.output_layer.partial_prediction = 0; + n.output_layer.eta_round = ec->eta_round; + n.output_layer.eta_global = ec->eta_global; + n.output_layer.global_weight = ec->global_weight; + n.output_layer.example_t = ec->example_t; + n.base.learn(&all,n.base.data,&n.output_layer); + n.output_layer.ld = 0; } - output_layer.final_prediction = GD::finalize_prediction (*all, output_layer.partial_prediction); + n.output_layer.final_prediction = GD::finalize_prediction (all, n.output_layer.partial_prediction); if (shouldOutput) { - outputStringStream << ' ' << output_layer.partial_prediction; - all->print_text(all->raw_prediction, outputStringStream.str(), ec->tag); + outputStringStream << ' ' << n.output_layer.partial_prediction; + all.print_text(all.raw_prediction, outputStringStream.str(), ec->tag); } - if (all->training && ld->label != FLT_MAX) { - float gradient = all->loss->first_derivative(all->sd, - output_layer.final_prediction, + if (all.training && ld->label != FLT_MAX) { + float gradient = all.loss->first_derivative(all.sd, + n.output_layer.final_prediction, ld->label); if (fabs (gradient) > 0) { - all->loss = squared_loss; - all->set_minmax = 
noop_mm; - save_min_label = all->sd->min_label; - all->sd->min_label = hidden_min_activation; - save_max_label = all->sd->max_label; - all->sd->max_label = hidden_max_activation; - - for (size_t i = k; i > 0; --i) { + all.loss = n.squared_loss; + all.set_minmax = noop_mm; + save_min_label = all.sd->min_label; + all.sd->min_label = hidden_min_activation; + save_max_label = all.sd->max_label; + all.sd->max_label = hidden_max_activation; + + for (size_t i = n.k; i > 0; --i) { if (! dropped_out[i-1]) { float sigmah = - output_layer.atomics[nn_output_namespace][i].x / dropscale; + n.output_layer.atomics[nn_output_namespace][i].x / dropscale; float sigmahprime = dropscale * (1.0f - sigmah * sigmah); - float nu = all->reg.weight_vector[output_layer.atomics[nn_output_namespace][i].weight_index & all->weight_mask]; + float nu = all.reg.weight_vector[n.output_layer.atomics[nn_output_namespace][i].weight_index & all.weight_mask]; float gradhw = 0.5f * nu * gradient * sigmahprime; - ld->label = GD::finalize_prediction (*all, hidden_units[i-1] - gradhw); + ld->label = GD::finalize_prediction (all, hidden_units[i-1] - gradhw); if (ld->label != hidden_units[i-1]) { ec->partial_prediction = 0.; - base_learner(all,ec); + n.base.learn(&all,n.base.data,ec); } } if (i != 1) { - update_example_indicies(all->audit, ec, -increment); + update_example_indicies(all.audit, ec, -n.increment); } } - all->loss = save_loss; - all->set_minmax = save_set_minmax; - all->sd->min_label = save_min_label; - all->sd->max_label = save_max_label; + all.loss = save_loss; + all.set_minmax = save_set_minmax; + all.sd->min_label = save_min_label; + all.sd->max_label = save_max_label; } else - update_example_indicies(all->audit, ec, -(k-1)*increment); + update_example_indicies(all.audit, ec, -(n.k-1)*n.increment); } else - update_example_indicies(all->audit, ec, -(k-1)*increment); + update_example_indicies(all.audit, ec, -(n.k-1)*n.increment); ld->label = save_label; if (! converse) { - save_partial_prediction = output_layer.partial_prediction; - save_final_prediction = output_layer.final_prediction; - save_ec_loss = output_layer.loss; + save_partial_prediction = n.output_layer.partial_prediction; + save_final_prediction = n.output_layer.final_prediction; + save_ec_loss = n.output_layer.loss; } - if (dropout && ! converse) + if (n.dropout && ! converse) { - update_example_indicies (all->audit, ec, (k-1)*increment); + update_example_indicies (all.audit, ec, (n.k-1)*n.increment); - for (unsigned int i = 0; i < k; ++i) + for (unsigned int i = 0; i < n.k; ++i) { dropped_out[i] = ! dropped_out[i]; } @@ -251,20 +247,22 @@ CONVERSE: // That's right, I'm using goto. So sue me. ec->loss = save_ec_loss; } - void learn(void*a, example* ec) { + void learn(void*a, void* d,example* ec) { vw* all = (vw*)a; - learn_with_output(all, ec, false); + nn* n = (nn*)d; + learn_with_output(*all, *n, ec, false); } - void drive_nn(void *in) + void drive_nn(void *in, void* d) { vw* all = (vw*)in; + nn* n = (nn*)d; example* ec = NULL; while ( true ) { if ((ec = get_example(all->p)) != NULL)//semiblocking operation. { - learn_with_output(all, ec, all->raw_prediction > 0); + learn_with_output(*all, *n, ec, all->raw_prediction > 0); int save_raw_prediction = all->raw_prediction; all->raw_prediction = -1; return_simple_example(*all, ec); @@ -277,8 +275,20 @@ CONVERSE: // That's right, I'm using goto. So sue me. 
} } + void finish(void* a, void* d) + { + nn* n =(nn*)d; + n->base.finish(a,n->base.data); + delete n->squared_loss; + free (n->output_layer.indices.begin); + free (n->output_layer.atomics[nn_output_namespace].begin); + free(n); + } + void parse_flags(vw& all, std::vector<std::string>&opts, po::variables_map& vm, po::variables_map& vm_file) { + nn* n = (nn*)calloc(1,sizeof(nn)); + po::options_description desc("NN options"); desc.add_options() ("inpass", "Train or test sigmoidal feedforward network with input passthrough.") @@ -299,28 +309,28 @@ CONVERSE: // That's right, I'm using goto. So sue me. po::notify(vm_file); //first parse for number of hidden units - k = 0; + n->k = 0; if( vm_file.count("nn") ) { - k = (uint32_t)vm_file["nn"].as<size_t>(); - if( vm.count("nn") && (uint32_t)vm["nn"].as<size_t>() != k ) - std::cerr << "warning: you specified a different number of hidden units through --nn than the one loaded from predictor. Pursuing with loaded value of: " << k << endl; + n->k = (uint32_t)vm_file["nn"].as<size_t>(); + if( vm.count("nn") && (uint32_t)vm["nn"].as<size_t>() != n->k ) + std::cerr << "warning: you specified a different number of hidden units through --nn than the one loaded from predictor. Pursuing with loaded value of: " << n->k << endl; } else { - k = (uint32_t)vm["nn"].as<size_t>(); + n->k = (uint32_t)vm["nn"].as<size_t>(); std::stringstream ss; - ss << " --nn " << k; + ss << " --nn " << n->k; all.options_from_file.append(ss.str()); } if( vm_file.count("dropout") ) { - dropout = all.training || vm.count("dropout"); + n->dropout = all.training || vm.count("dropout"); - if (! dropout && ! vm.count("meanfield") && ! all.quiet) + if (! n->dropout && ! vm.count("meanfield") && ! all.quiet) std::cerr << "using mean field for testing, specify --dropout explicitly to override" << std::endl; } else if ( vm.count("dropout") ) { - dropout = true; + n->dropout = true; std::stringstream ss; ss << " --dropout "; @@ -328,62 +338,60 @@ CONVERSE: // That's right, I'm using goto. So sue me. } if ( vm.count("meanfield") ) { - dropout = false; + n->dropout = false; if (! all.quiet) std::cerr << "using mean field for neural network " << (all.training ? "training" : "testing") << std::endl; } - if (dropout) + if (n->dropout) if (! all.quiet) std::cerr << "using dropout for neural network " << (all.training ? "training" : "testing") << std::endl; if( vm_file.count("inpass") ) { - inpass = true; + n->inpass = true; } else if (vm.count ("inpass")) { - inpass = true; + n->inpass = true; std::stringstream ss; ss << " --inpass"; all.options_from_file.append(ss.str()); } - if (inpass && ! all.quiet) + if (n->inpass && ! all.quiet) std::cerr << "using input passthrough for neural network " << (all.training ? "training" : "testing") << std::endl; - all.driver = drive_nn; - base_learner = all.learn; - all.base_learn = all.learn; - all.learn = learn; + learner t = {n,drive_nn,learn,finish,all.l.save_load}; + all.l = t; - all.base_learner_nb_w *= (inpass) ? k + 1 : k; - increment = ((uint32_t)all.length()/all.base_learner_nb_w) * all.stride; + all.base_learner_nb_w *= (n->inpass) ? 
n->k + 1 : n->k; + n->increment = ((uint32_t)all.length()/all.base_learner_nb_w) * all.stride; bool initialize = true; // TODO: output_layer audit - memset (&output_layer, 0, sizeof (output_layer)); - output_layer.indices.push_back(nn_output_namespace); + memset (&n->output_layer, 0, sizeof (n->output_layer)); + n->output_layer.indices.push_back(nn_output_namespace); feature output = {1., nn_constant*all.stride}; - output_layer.atomics[nn_output_namespace].push_back(output); - initialize &= (all.reg.weight_vector[output_layer.atomics[nn_output_namespace][0].weight_index & all.weight_mask] == 0); + n->output_layer.atomics[nn_output_namespace].push_back(output); + initialize &= (all.reg.weight_vector[n->output_layer.atomics[nn_output_namespace][0].weight_index & all.weight_mask] == 0); - for (unsigned int i = 0; i < k; ++i) + for (unsigned int i = 0; i < n->k; ++i) { output.weight_index += all.stride; - output_layer.atomics[nn_output_namespace].push_back(output); - initialize &= (all.reg.weight_vector[output_layer.atomics[nn_output_namespace][i+1].weight_index & all.weight_mask] == 0); + n->output_layer.atomics[nn_output_namespace].push_back(output); + initialize &= (all.reg.weight_vector[n->output_layer.atomics[nn_output_namespace][i+1].weight_index & all.weight_mask] == 0); } - output_layer.num_features = k + 1; - output_layer.in_use = true; + n->output_layer.num_features = n->k + 1; + n->output_layer.in_use = true; if (initialize) { if (! all.quiet) @@ -391,15 +399,15 @@ CONVERSE: // That's right, I'm using goto. So sue me. // output weights - float sqrtk = sqrt ((float)k); - for (unsigned int i = 0; i <= k; ++i) + float sqrtk = sqrt ((float)n->k); + for (unsigned int i = 0; i <= n->k; ++i) { - weight* w = &all.reg.weight_vector[output_layer.atomics[nn_output_namespace][i].weight_index & all.weight_mask]; + weight* w = &all.reg.weight_vector[n->output_layer.atomics[nn_output_namespace][i].weight_index & all.weight_mask]; w[0] = (float) (frand48 () - 0.5) / sqrtk; // prevent divide by zero error - if (dropout && all.normalized_updates) + if (n->dropout && all.normalized_updates) w[all.normalized_idx] = 1e-4f; } @@ -407,22 +415,20 @@ CONVERSE: // That's right, I'm using goto. So sue me. unsigned int weight_index = constant * all.stride; - for (unsigned int i = 0; i < k; ++i) + for (unsigned int i = 0; i < n->k; ++i) { all.reg.weight_vector[weight_index & all.weight_mask] = (float) (frand48 () - 0.5); - weight_index += increment; + weight_index += n->increment; } } - squared_loss = getLossFunction (0, "squared", 0); + n->squared_loss = getLossFunction (0, "squared", 0); - xsubi = 0; + n->xsubi = 0; if (vm.count("random_seed")) - xsubi = vm["random_seed"].as<size_t>(); - - save_xsubi = xsubi; + n->xsubi = vm["random_seed"].as<size_t>(); - atexit (free_stuff); + n->save_xsubi = n->xsubi; } } diff --git a/vowpalwabbit/noop.cc b/vowpalwabbit/noop.cc index 390fe53b..02ffdb27 100644 --- a/vowpalwabbit/noop.cc +++ b/vowpalwabbit/noop.cc @@ -11,20 +11,28 @@ license as described in the file LICENSE. 
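
The installation pattern that every parse_flags in this patch repeats is worth isolating. A learner is aggregate-initialized in the order {data, drive, learn, finish, save_load} and assigned to all.l; nn, oaa, and searn forward the previous learner's save_load, while noop below supplies its own empty one. OAA's variant of the wiring, condensed as a sketch (option parsing elided):

void parse_flags(vw& all /* , opts, vm, vm_file as in the patch */)
{
  oaa* data = (oaa*)calloc(1, sizeof(oaa));  // zeroed per-instance state
  // ... fill data->k, data->increment, data->total_increment from options ...
  data->base = all.l;                        // capture the learner below us
  learner l = { data, drive, learn, finish, all.l.save_load };
  all.l = l;                                 // this reduction is now on top
}

Because each reduction snapshots all.l into its own base before overwriting it, reductions compose as a stack, and parse_args.cc below can drop the old all.driver/all.learn/all.finish globals in favor of calls through all.l.
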
#include "simple_label.h" namespace NOOP { -void learn(void*in, example*ec) {} + void learn(void*in, void* d, example*ec) {} + void finish(void*in, void* d) {} -void save_load(void* in, io_buf& model_file, bool read, bool text) {} - -void drive(void* in) -{ - vw* all = (vw*)in; - example* ec = NULL; + void save_load(void* in, void* d, io_buf& model_file, bool read, bool text) {} - while ( !parser_done(all->p)){ - ec = get_example(all->p); - if (ec != NULL) - return_simple_example(*all, ec); + void drive(void* in, void* d) + { + vw* all = (vw*)in; + example* ec = NULL; + + while ( !parser_done(all->p)){ + ec = get_example(all->p); + if (ec != NULL) + return_simple_example(*all, ec); + } + } + + void parse_flags(vw& all) + { + learner t = {NULL,drive,learn,finish,save_load}; + all.l = t; + all.is_noop = true; } -} } diff --git a/vowpalwabbit/noop.h b/vowpalwabbit/noop.h index aac50e08..d54f7d3a 100644 --- a/vowpalwabbit/noop.h +++ b/vowpalwabbit/noop.h @@ -4,7 +4,5 @@ individual contributors. All rights reserved. Released under a BSD license as described in the file LICENSE. */ namespace NOOP { - void drive(void*); - void learn(void*, example*); - void save_load(void* in, io_buf& model_file, bool read, bool text); + void parse_flags(vw&); } diff --git a/vowpalwabbit/oaa.cc b/vowpalwabbit/oaa.cc index 6358ad8c..2011af1d 100644 --- a/vowpalwabbit/oaa.cc +++ b/vowpalwabbit/oaa.cc @@ -17,6 +17,13 @@ using namespace std; namespace OAA { + struct oaa{ + uint32_t k; + uint32_t increment; + uint32_t total_increment; + learner base; + }; + char* bufread_label(mc_label* ld, char* c) { ld->label = *(uint32_t *)c; @@ -98,11 +105,6 @@ namespace OAA { } } - //nonreentrant - uint32_t k=0; - uint32_t increment=0; - uint32_t total_increment=0; - void print_update(vw& all, example *ec) { if (all.sd->weighted_examples > all.sd->dump_interval && !all.quiet && !all.bfgs) @@ -148,21 +150,19 @@ namespace OAA { print_update(all, ec); } - void (*base_learner)(void*,example*) = NULL; - - void learn_with_output(vw*all, example* ec, bool shouldOutput) + void learn_with_output(vw*all,oaa* d, example* ec, bool shouldOutput) { mc_label* mc_label_data = (mc_label*)ec->ld; size_t prediction = 1; float score = INT_MIN; - if (mc_label_data->label > k && mc_label_data->label != (uint32_t)-1) - cerr << "warning: label " << mc_label_data->label << " is greater than " << k << endl; + if (mc_label_data->label > d->k && mc_label_data->label != (uint32_t)-1) + cerr << "warning: label " << mc_label_data->label << " is greater than " << d->k << endl; string outputString; stringstream outputStringStream(outputString); - for (size_t i = 1; i <= k; i++) + for (size_t i = 1; i <= d->k; i++) { label_data simple_temp; simple_temp.initial = 0.; @@ -173,8 +173,8 @@ namespace OAA { simple_temp.weight = mc_label_data->weight; ec->ld = &simple_temp; if (i != 1) - update_example_indicies(all->audit, ec, increment); - base_learner(all,ec); + update_example_indicies(all->audit, ec, d->increment); + d->base.learn((void*)all,d->base.data,ec); if (ec->partial_prediction > score) { score = ec->partial_prediction; @@ -190,7 +190,7 @@ namespace OAA { } ec->ld = mc_label_data; *(prediction_t*)&(ec->final_prediction) = prediction; - update_example_indicies(all->audit, ec, -total_increment); + update_example_indicies(all->audit, ec, -d->total_increment); if (shouldOutput) { outputStringStream << endl; @@ -198,12 +198,11 @@ namespace OAA { } } - void learn(void*a, example* ec) { - vw* all = (vw*)a; - learn_with_output(all, ec, false); + void learn(void*a, 
void* d, example* ec) { + learn_with_output((vw*)a, (oaa*)d, ec, false); } - void drive_oaa(void *in) + void drive(void *in, void* d) { vw* all = (vw*)in; example* ec = NULL; @@ -211,7 +210,7 @@ namespace OAA { { if ((ec = get_example(all->p)) != NULL)//semiblocking operation. { - learn_with_output(all, ec, all->raw_prediction > 0); + learn_with_output(all, (oaa*)d, ec, all->raw_prediction > 0); output_example(*all, ec); VW::finish_example(*all, ec); } @@ -222,32 +221,38 @@ namespace OAA { } } + void finish(void* all, void* data) + { + oaa* o=(oaa*)data; + o->base.finish(all,o->base.data); + free(o); + } + void parse_flags(vw& all, std::vector<std::string>&opts, po::variables_map& vm, po::variables_map& vm_file) { + oaa* data = (oaa*)calloc(1, sizeof(oaa)); //first parse for number of actions - k = 0; if( vm_file.count("oaa") ) { - k = (uint32_t)vm_file["oaa"].as<size_t>(); - if( vm.count("oaa") && (uint32_t)vm["oaa"].as<size_t>() != k ) - std::cerr << "warning: you specified a different number of actions through --oaa than the one loaded from predictor. Pursuing with loaded value of: " << k << endl; + data->k = (uint32_t)vm_file["oaa"].as<size_t>(); + if( vm.count("oaa") && (uint32_t)vm["oaa"].as<size_t>() != data->k ) + std::cerr << "warning: you specified a different number of actions through --oaa than the one loaded from predictor. Pursuing with loaded value of: " << data->k << endl; } else { - k = (uint32_t)vm["oaa"].as<size_t>(); + data->k = (uint32_t)vm["oaa"].as<size_t>(); //append oaa with nb_actions to options_from_file so it is saved to regressor later std::stringstream ss; - ss << " --oaa " << k; + ss << " --oaa " << data->k; all.options_from_file.append(ss.str()); } *(all.p->lp) = mc_label_parser; - all.driver = drive_oaa; - base_learner = all.learn; - all.base_learn = all.learn; - all.learn = learn; - - all.base_learner_nb_w *= k; - increment = ((uint32_t)all.length()/all.base_learner_nb_w) * all.stride; - total_increment = increment*(k-1); + data->increment = ((uint32_t)all.length()/all.base_learner_nb_w) * all.stride; + data->total_increment = data->increment*(data->k-1); + data->base = all.l; + learner l = {data, drive, learn, finish, all.l.save_load}; + all.l = l; + + all.base_learner_nb_w *= data->k; } } diff --git a/vowpalwabbit/parse_args.cc b/vowpalwabbit/parse_args.cc index 54c6ab6f..30bcc3f1 100644 --- a/vowpalwabbit/parse_args.cc +++ b/vowpalwabbit/parse_args.cc @@ -237,33 +237,8 @@ vw parse_args(int argc, char *argv[]) } } - if (vm.count("bfgs") || vm.count("conjugate_gradient")) { - all.driver = BFGS::drive; - all.learn = BFGS::learn; - all.finish = BFGS::finish; - all.save_load = BFGS::save_load; - all.bfgs = true; - all.stride = 4; - - if (vm.count("hessian_on") || all.m==0) { - all.hessian_on = true; - } - if (!all.quiet) { - if (all.m>0) - cerr << "enabling BFGS based optimization "; - else - cerr << "enabling conjugate gradient optimization via BFGS "; - if (all.hessian_on) - cerr << "with curvature calculation" << endl; - else - cerr << "**without** curvature calculation" << endl; - } - if (all.numpasses < 2) - { - cout << "you must make at least 2 passes to use BFGS" << endl; - exit(1); - } - } + if (vm.count("bfgs") || vm.count("conjugate_gradient")) + BFGS::parse_args(all, to_pass_further, vm, vm_file); if (vm.count("version") || argc == 1) { /* upon direct query for version -- spit it out to stdout */ @@ -467,18 +442,8 @@ vw parse_args(int argc, char *argv[]) //if (vm.count("nonormalize")) // all.nonormalize = true; - if (vm.count("lda")) { - //default 
initial_t to 1 instead of 0 - if(!vm.count("initial_t")) { - all.sd->t = 1.f; - all.sd->weighted_unlabeled_examples = 1.f; - all.initial_t = 1.f; - } - + if (vm.count("lda")) LDA::parse_flags(all, to_pass_further, vm); - all.driver = LDA::drive; - all.save_load = LDA::save_load; - } if (!vm.count("lda") && !all.adaptive && !all.normalized_updates) all.eta *= powf((float)(all.sd->t), all.power_t); @@ -509,19 +474,11 @@ vw parse_args(int argc, char *argv[]) loss_parameter = vm["quantile_tau"].as<float>(); all.is_noop = false; - if (vm.count("noop")) { - all.driver = NOOP::drive; - all.learn = NOOP::learn; - all.save_load = NOOP::save_load; - all.is_noop = true; - } + if (vm.count("noop")) + NOOP::parse_flags(all); - if (all.rank != 0) { - all.driver = GDMF::drive; - all.save_load = GDMF::save_load; - loss_function = "classic"; - cerr << "Forcing classic squared loss for matrix factorization" << endl; - } + if (all.rank != 0) + GDMF::parse_flags(all); all.loss = getLossFunction(&all, loss_function, (float)loss_parameter); @@ -585,14 +542,10 @@ vw parse_args(int argc, char *argv[]) all.audit = true; if (vm.count("sendto")) - { - all.driver = SENDER::drive_send; - all.save_load = SENDER::save_load; - SENDER::parse_send_args(vm, all.pairs); - } + SENDER::parse_send_args(all, vm, all.pairs); // load rest of regressor - all.save_load(&all, io_temp, true, false); + all.l.save_load(&all, all.l.data, io_temp, true, false); io_temp.close_file(); if (all.l1_lambda < 0.) { @@ -703,7 +656,7 @@ vw parse_args(int argc, char *argv[]) CSOAA::parse_flags(all, to_pass_further, vm, vm_file); // default to CSOAA unless others have been specified got_cs = true; } - all.searnstr = (ImperativeSearn::searn_struct*)calloc(1, sizeof(ImperativeSearn::searn_struct)); + all.searnstr = (ImperativeSearn::searn*)calloc(1, sizeof(ImperativeSearn::searn)); ImperativeSearn::parse_flags(all, to_pass_further, vm, vm_file); } @@ -821,7 +774,7 @@ namespace VW { void finish(vw& all) { - all.finish(&all); + all.l.finish(&all, all.l.data); if (all.searnstr != NULL) free(all.searnstr); free_parser(all); finalize_regressor(all, all.final_regressor_name); diff --git a/vowpalwabbit/parse_regressor.cc b/vowpalwabbit/parse_regressor.cc index 77afbb84..b5c4c8c6 100644 --- a/vowpalwabbit/parse_regressor.cc +++ b/vowpalwabbit/parse_regressor.cc @@ -196,7 +196,7 @@ void dump_regressor(vw& all, string reg_name, bool as_text) io_temp.open_file(start_name.c_str(), all.stdin_off, io_buf::WRITE); save_load_header(all, io_temp, false, as_text); - all.save_load(&all, io_temp, false, as_text); + all.l.save_load(&all, all.l.data, io_temp, false, as_text); io_temp.flush(); // close_file() should do this for me ... io_temp.close_file(); diff --git a/vowpalwabbit/searn.cc b/vowpalwabbit/searn.cc index b1ac2fdb..1bc650bb 100644 --- a/vowpalwabbit/searn.cc +++ b/vowpalwabbit/searn.cc @@ -279,32 +279,6 @@ namespace SearnUtil namespace Searn { - // task stuff - search_task task; - bool is_singleline; - bool is_ldf; - bool has_hash; - bool constrainted_actions; - size_t input_label_size; - - // options - size_t max_action = 1; - size_t max_rollout = INT_MAX; - size_t passes_per_policy = 1; //this should be set to the same value as --passes for dagger - float beta = 0.5; - float gamma = 1.; - bool do_recombination = false; - bool allow_current_policy = false; //this should be set to true for dagger - bool rollout_oracle = false; //if true then rollout are performed using oracle instead (optimal approximation discussed in searn's paper). 
this should be set to true for dagger - bool adaptive_beta = false; //used to implement dagger through searn. if true, beta = 1-(1-alpha)^n after n updates, and policy is mixed with oracle as \pi' = (1-beta)\pi^* + beta \pi - float alpha = 0.001f; //parameter used to adapt beta for dagger (see above comment), should be in (0,1) - bool rollout_all_actions = true; //by default we rollout all actions. This is set to false when searn is used with a contextual bandit base learner, where we rollout only one sampled action - - // debug stuff - bool PRINT_DEBUG_INFO = 0; - bool PRINT_UPDATE_EVERY_EXAMPLE = 0 | PRINT_DEBUG_INFO; - - // rollout struct rollout_item { state st; @@ -313,42 +287,69 @@ namespace Searn size_t hash; }; - // memory - rollout_item* rollout; - v_array<example*> ec_seq = v_array<example*>(); - example** global_example_set = NULL; - example* empty_example = NULL; - OAA::mc_label empty_label; - v_array<CSOAA::wclass>loss_vector = v_array<CSOAA::wclass>(); - v_array<CB::cb_class>loss_vector_cb = v_array<CB::cb_class>(); - v_array<void*>old_labels = v_array<void*>(); - v_array<OAA::mc_label>new_labels = v_array<OAA::mc_label>(); - CSOAA::label testall_labels = { v_array<CSOAA::wclass>() }; - CSOAA::label allowed_labels = { v_array<CSOAA::wclass>() }; - CB::label testall_labels_cb = { v_array<CB::cb_class>() }; - CB::label allowed_labels_cb = { v_array<CB::cb_class>() }; - - // we need a hashmap that maps from STATES to ACTIONS - v_hashmap<state,action> *past_states = NULL; - v_array<state> unfreed_states = v_array<state>(); - - // tracking of example - size_t read_example_this_loop = 0; - size_t read_example_last_id = 0; - size_t passes_since_new_policy = 0; - size_t read_example_last_pass = 0; - size_t total_examples_generated = 0; - size_t total_predictions_made = 0; - size_t searn_num_features = 0; - - // variables - uint32_t current_policy = 0; - uint32_t total_number_of_policies = 1; - uint32_t increment = 0; //for policy offset - - void (*base_learner)(void*, example*) = NULL; - void (*base_finish)(void*) = NULL; + // debug stuff + const bool PRINT_DEBUG_INFO =0; + const bool PRINT_UPDATE_EVERY_EXAMPLE =0; + + struct searn { + // task stuff + search_task task; + bool is_singleline; + bool is_ldf; + bool has_hash; + bool constrainted_actions; + size_t input_label_size; + + // options + size_t max_action; + size_t max_rollout; + size_t passes_per_policy; //this should be set to the same value as --passes for dagger + float beta; + float gamma; + bool do_recombination; + bool allow_current_policy; //this should be set to true for dagger + bool rollout_oracle; //if true then rollout are performed using oracle instead (optimal approximation discussed in searn's paper). this should be set to true for dagger + bool adaptive_beta; //used to implement dagger through searn. if true, beta = 1-(1-alpha)^n after n updates, and policy is mixed with oracle as \pi' = (1-beta)\pi^* + beta \pi + float alpha; //parameter used to adapt beta for dagger (see above comment), should be in (0,1) + bool rollout_all_actions; //by default we rollout all actions. 
This is set to false when searn is used with a contextual bandit base learner, where we rollout only one sampled action + + // memory + rollout_item* rollout; + v_array<example*> ec_seq; + example** global_example_set; + example* empty_example; + OAA::mc_label empty_label; + v_array<CSOAA::wclass>loss_vector; + v_array<CB::cb_class>loss_vector_cb; + v_array<void*>old_labels; + v_array<OAA::mc_label>new_labels; + CSOAA::label testall_labels; + CSOAA::label allowed_labels; + CB::label testall_labels_cb; + CB::label allowed_labels_cb; + // we need a hashmap that maps from STATES to ACTIONS + v_hashmap<state,action> *past_states; + v_array<state> unfreed_states; + + // tracking of example + size_t read_example_this_loop; + size_t read_example_last_id; + size_t passes_since_new_policy; + size_t read_example_last_pass; + size_t total_examples_generated; + size_t total_predictions_made; + size_t searn_num_features; + + // variables + uint32_t current_policy; + uint32_t total_number_of_policies; + uint32_t increment; //for policy offset + + learner base; + }; + void drive(void*in, void*d); + void simple_print_example_features(vw&all, example *ec) { for (unsigned char* i = ec->indices.begin; i != ec->indices.end; i++) @@ -400,35 +401,35 @@ namespace Searn out[max_len] = 0; } - bool will_global_print_label(vw& all) + bool will_global_print_label(vw& all, searn& s) { - if (!task.to_string) return false; + if (!s.task.to_string) return false; if (all.final_prediction_sink.size() == 0) return false; return true; } - void global_print_label(vw& all, example*ec, state s0, std::vector<action> last_action_sequence) + void global_print_label(vw& all, searn& s, example*ec, state s0, std::vector<action> last_action_sequence) { - if (!task.to_string) return; + if (!s.task.to_string) return; if (all.final_prediction_sink.size() == 0) return; - string str = task.to_string(s0, false, last_action_sequence); + string str = s.task.to_string(s0, false, last_action_sequence); for (size_t i=0; i<all.final_prediction_sink.size(); i++) { int f = all.final_prediction_sink[i]; all.print_text(f, str, ec->tag); } } - void print_update(vw& all, state s0, std::vector<action> last_action_sequence) + void print_update(vw& all, searn& s, state s0, std::vector<action> last_action_sequence) { if (!should_print_update(all)) return; char true_label[21]; char pred_label[21]; - if (task.to_string) { - to_short_string(task.to_string(s0, true , empty_action_vector ), 20, true_label); - to_short_string(task.to_string(s0, false, last_action_sequence), 20, pred_label); + if (s.task.to_string) { + to_short_string(s.task.to_string(s0, true , empty_action_vector ), 20, true_label); + to_short_string(s.task.to_string(s0, false, last_action_sequence), 20, pred_label); } else { to_short_string("", 20, true_label); to_short_string("", 20, pred_label); @@ -441,11 +442,11 @@ namespace Searn all.sd->weighted_examples, true_label, pred_label, - (long unsigned int)searn_num_features, - (int)read_example_last_pass, - (int)current_policy, - (long unsigned int)total_predictions_made, - (long unsigned int)total_examples_generated); + (long unsigned int)s.searn_num_features, + (int)s.read_example_last_pass, + (int)s.current_policy, + (long unsigned int)s.total_predictions_made, + (long unsigned int)s.total_examples_generated); all.sd->sum_loss_since_last_dump = 0.0; all.sd->old_weighted_examples = all.sd->weighted_examples; @@ -454,94 +455,118 @@ namespace Searn - void clear_seq(vw&all) + void clear_seq(vw&all, searn& s) { - if (ec_seq.size() > 0) - for 
(example** ecc=ec_seq.begin; ecc!=ec_seq.end; ecc++) { + if (s.ec_seq.size() > 0) + for (example** ecc=s.ec_seq.begin; ecc!=s.ec_seq.end; ecc++) { VW::finish_example(all, *ecc); } - ec_seq.erase(); + s.ec_seq.erase(); } - void free_unfreed_states() + void free_unfreed_states(searn& s) { - while (!unfreed_states.empty()) { - state s = unfreed_states.pop(); - task.finish(s); + while (!s.unfreed_states.empty()) { + state st = s.unfreed_states.pop(); + s.task.finish(st); } } - void initialize_memory() + void initialize_memory(searn& s) { // initialize searn's memory - rollout = (rollout_item*)SearnUtil::calloc_or_die(max_action, sizeof(rollout_item)); - global_example_set = (example**)SearnUtil::calloc_or_die(max_action, sizeof(example*)); + s.rollout = (rollout_item*)SearnUtil::calloc_or_die(s.max_action, sizeof(rollout_item)); + s.global_example_set = (example**)SearnUtil::calloc_or_die(s.max_action, sizeof(example*)); - for (uint32_t k=1; k<=max_action; k++) { + for (uint32_t k=1; k<=s.max_action; k++) { CSOAA::wclass cost = { FLT_MAX, k, 1., 0. }; - testall_labels.costs.push_back(cost); + s.testall_labels.costs.push_back(cost); CB::cb_class cost_cb = { FLT_MAX, k, 0. }; - testall_labels_cb.costs.push_back(cost_cb); + s.testall_labels_cb.costs.push_back(cost_cb); } - empty_example = alloc_example(sizeof(OAA::mc_label)); - OAA::default_label(empty_example->ld); + s.empty_example = alloc_example(sizeof(OAA::mc_label)); + OAA::default_label(s.empty_example->ld); // cerr << "create: empty_example->ld = " << empty_example->ld << endl; - empty_example->in_use = true; + s.empty_example->in_use = true; } - void free_memory(vw&all) + void free_memory(vw&all, searn& s) { - dealloc_example(NULL, *empty_example); - free(empty_example); + dealloc_example(NULL, *s.empty_example); + free(s.empty_example); - SearnUtil::free_it(rollout); + SearnUtil::free_it(s.rollout); - loss_vector.delete_v(); + s.loss_vector.delete_v(); - old_labels.delete_v(); + s.old_labels.delete_v(); - new_labels.delete_v(); + s.new_labels.delete_v(); - free_unfreed_states(); - unfreed_states.delete_v(); + free_unfreed_states(s); + s.unfreed_states.delete_v(); - clear_seq(all); - ec_seq.delete_v(); + clear_seq(all,s); + s.ec_seq.delete_v(); - SearnUtil::free_it(global_example_set); + SearnUtil::free_it(s.global_example_set); - testall_labels.costs.delete_v(); - testall_labels_cb.costs.delete_v(); - allowed_labels.costs.delete_v(); - allowed_labels_cb.costs.delete_v(); + s.testall_labels.costs.delete_v(); + s.testall_labels_cb.costs.delete_v(); + s.allowed_labels.costs.delete_v(); + s.allowed_labels_cb.costs.delete_v(); - if (do_recombination) { - delete past_states; - past_states = NULL; + if (s.do_recombination) { + delete s.past_states; + s.past_states = NULL; } } - void learn(void*in, example *ec) + void learn(void*in, void*d, example *ec) { //vw*all = (vw*)in; // TODO } - void finish(void*in) + void finish(void*in, void*d) { + searn* s = (searn*)d; + s->base.finish(in,s->base.data); vw*all = (vw*)in; // free everything - if (task.finalize != NULL) - task.finalize(); - free_memory(*all); - base_finish(all); + if (s->task.finalize != NULL) + s->task.finalize(); + free_memory(*all,*s); + free(s); } void parse_flags(vw&all, std::vector<std::string>&opts, po::variables_map& vm, po::variables_map& vm_file) { + searn* s = (searn*)calloc(1,sizeof(searn)); + + s->max_action = 1; + s->max_rollout = INT_MAX; + s->passes_per_policy = 1; //this should be set to the same value as --passes for dagger + s->beta = 0.5; + s->gamma = 1.; + s->alpha 
= 0.001f; //parameter used to adapt beta for dagger (see above comment), should be in (0,1) + s->rollout_all_actions = true; //by default we rollout all actions. This is set to false when searn is used with a contextual bandit base learner, where we rollout only one sampled action + s->ec_seq = v_array<example*>(); + s->loss_vector = v_array<CSOAA::wclass>(); + s->loss_vector_cb = v_array<CB::cb_class>(); + s->old_labels = v_array<void*>(); + s->new_labels = v_array<OAA::mc_label>(); + s->testall_labels.costs = v_array<CSOAA::wclass>(); + s->allowed_labels.costs = v_array<CSOAA::wclass>(); + s->testall_labels_cb.costs = v_array<CB::cb_class>(); + s->allowed_labels_cb.costs = v_array<CB::cb_class>(); + // we need a hashmap that maps from STATES to ACTIONS + s->unfreed_states = v_array<state>(); + s->total_number_of_policies = 1; + po::options_description desc("Searn options"); desc.add_options() ("searn_task", po::value<string>(), "the searn task") @@ -596,154 +621,154 @@ namespace Searn } if (task_string.compare("sequence") == 0) { - task.final = SequenceTask::final; - task.loss = SequenceTask::loss; - task.step = SequenceTask::step; - task.oracle = SequenceTask::oracle; - task.copy = SequenceTask::copy; - task.finish = SequenceTask::finish; - task.searn_label_parser = OAA::mc_label_parser; - task.is_test_example = SequenceTask::is_test_example; - input_label_size = sizeof(OAA::mc_label); - task.start_state = NULL; - task.start_state_multiline = SequenceTask::start_state_multiline; + s->task.final = SequenceTask::final; + s->task.loss = SequenceTask::loss; + s->task.step = SequenceTask::step; + s->task.oracle = SequenceTask::oracle; + s->task.copy = SequenceTask::copy; + s->task.finish = SequenceTask::finish; + s->task.searn_label_parser = OAA::mc_label_parser; + s->task.is_test_example = SequenceTask::is_test_example; + s->input_label_size = sizeof(OAA::mc_label); + s->task.start_state = NULL; + s->task.start_state_multiline = SequenceTask::start_state_multiline; if (1) { - task.cs_example = SequenceTask::cs_example; - task.cs_ldf_example = NULL; + s->task.cs_example = SequenceTask::cs_example; + s->task.cs_ldf_example = NULL; } else { - task.cs_example = NULL; - task.cs_ldf_example = SequenceTask::cs_ldf_example; + s->task.cs_example = NULL; + s->task.cs_ldf_example = SequenceTask::cs_ldf_example; } - task.initialize = SequenceTask::initialize; - task.finalize = NULL; - task.equivalent = SequenceTask::equivalent; - task.hash = SequenceTask::hash; - task.allowed = SequenceTask::allowed; - task.to_string = SequenceTask::to_string; + s->task.initialize = SequenceTask::initialize; + s->task.finalize = NULL; + s->task.equivalent = SequenceTask::equivalent; + s->task.hash = SequenceTask::hash; + s->task.allowed = SequenceTask::allowed; + s->task.to_string = SequenceTask::to_string; } else { std::cerr << "error: unknown search task '" << task_string << "'" << std::endl; exit(-1); } - *(all.p->lp)=task.searn_label_parser; + *(all.p->lp)=s->task.searn_label_parser; if(vm_file.count("searn")) { //we loaded searn flag from regressor file - max_action = vm_file["searn"].as<size_t>(); - if( vm.count("searn") && vm["searn"].as<size_t>() != max_action ) - std::cerr << "warning: you specified a different number of actions through --searn than the one loaded from predictor. 
Pursuing with loaded value of: " << max_action << endl; + s->max_action = vm_file["searn"].as<size_t>(); + if( vm.count("searn") && vm["searn"].as<size_t>() != s->max_action ) + std::cerr << "warning: you specified a different number of actions through --searn than the one loaded from predictor. Pursuing with loaded value of: " << s->max_action << endl; } else { - max_action = vm["searn"].as<size_t>(); + s->max_action = vm["searn"].as<size_t>(); //append searn with nb_actions to options_from_file so it is saved to regressor later std::stringstream ss; - ss << " --searn " << max_action; + ss << " --searn " << s->max_action; all.options_from_file.append(ss.str()); } if(vm_file.count("searn_beta")) { //we loaded searn_beta flag from regressor file - beta = vm_file["searn_beta"].as<float>(); - if (vm.count("searn_beta") && vm["searn_beta"].as<float>() != beta ) - std::cerr << "warning: you specified a different value through --searn_beta than the one loaded from predictor. Pursuing with loaded value of: " << beta << endl; + s->beta = vm_file["searn_beta"].as<float>(); + if (vm.count("searn_beta") && vm["searn_beta"].as<float>() != s->beta ) + std::cerr << "warning: you specified a different value through --searn_beta than the one loaded from predictor. Pursuing with loaded value of: " << s->beta << endl; } else { - if (vm.count("searn_beta")) beta = vm["searn_beta"].as<float>(); + if (vm.count("searn_beta")) s->beta = vm["searn_beta"].as<float>(); //append searn_beta to options_from_file so it is saved in the regressor file later std::stringstream ss; - ss << " --searn_beta " << beta; + ss << " --searn_beta " << s->beta; all.options_from_file.append(ss.str()); } - if (vm.count("searn_rollout")) max_rollout = vm["searn_rollout"].as<size_t>(); - if (vm.count("searn_passes_per_policy")) passes_per_policy = vm["searn_passes_per_policy"].as<size_t>(); + if (vm.count("searn_rollout")) s->max_rollout = vm["searn_rollout"].as<size_t>(); + if (vm.count("searn_passes_per_policy")) s->passes_per_policy = vm["searn_passes_per_policy"].as<size_t>(); - if (vm.count("searn_gamma")) gamma = vm["searn_gamma"].as<float>(); - if (vm.count("searn_norecombine")) do_recombination = false; - if (vm.count("searn_allow_current_policy")) allow_current_policy = true; - if (vm.count("searn_rollout_oracle")) rollout_oracle = true; + if (vm.count("searn_gamma")) s->gamma = vm["searn_gamma"].as<float>(); + if (vm.count("searn_norecombine")) s->do_recombination = false; + if (vm.count("searn_allow_current_policy")) s->allow_current_policy = true; + if (vm.count("searn_rollout_oracle")) s->rollout_oracle = true; //check if the base learner is contextual bandit, in which case, we dont rollout all actions. - if ( vm.count("cb") || vm_file.count("cb") ) rollout_all_actions = false; + if ( vm.count("cb") || vm_file.count("cb") ) s->rollout_all_actions = false; //if we loaded a regressor with -i option, --searn_trained_nb_policies contains the number of trained policies in the file // and --searn_total_nb_policies contains the total number of policies in the file if ( vm_file.count("searn_total_nb_policies") ) { - current_policy = (uint32_t)vm_file["searn_trained_nb_policies"].as<size_t>(); - total_number_of_policies = (uint32_t)vm_file["searn_total_nb_policies"].as<size_t>(); - if (vm.count("searn_total_nb_policies") && (uint32_t)vm["searn_total_nb_policies"].as<size_t>() != total_number_of_policies) - std::cerr << "warning: --searn_total_nb_policies doesn't match the total number of policies stored in initial predictor. 
Using loaded value of: " << total_number_of_policies << endl; + s->current_policy = (uint32_t)vm_file["searn_trained_nb_policies"].as<size_t>(); + s->total_number_of_policies = (uint32_t)vm_file["searn_total_nb_policies"].as<size_t>(); + if (vm.count("searn_total_nb_policies") && (uint32_t)vm["searn_total_nb_policies"].as<size_t>() != s->total_number_of_policies) + std::cerr << "warning: --searn_total_nb_policies doesn't match the total number of policies stored in initial predictor. Using loaded value of: " << s->total_number_of_policies << endl; } else if (vm.count("searn_total_nb_policies")) { - total_number_of_policies = (uint32_t)vm["searn_total_nb_policies"].as<size_t>(); + s->total_number_of_policies = (uint32_t)vm["searn_total_nb_policies"].as<size_t>(); } if (vm.count("searn_as_dagger")) { //overide previously loaded options to set searn as dagger - allow_current_policy = true; - passes_per_policy = all.numpasses; - //rollout_oracle = true; - if( current_policy > 1 ) - current_policy = 1; + s->allow_current_policy = true; + s->passes_per_policy = all.numpasses; + //s->rollout_oracle = true; + if( s->current_policy > 1 ) + s->current_policy = 1; //indicate to adapt beta for each update - adaptive_beta = true; - alpha = vm["searn_as_dagger"].as<float>(); + s->adaptive_beta = true; + s->alpha = vm["searn_as_dagger"].as<float>(); } - if (beta <= 0 || beta >= 1) { + if (s->beta <= 0 || s->beta >= 1) { std::cerr << "warning: searn_beta must be in (0,1); resetting to 0.5" << std::endl; - beta = 0.5; + s->beta = 0.5; } - if (gamma <= 0 || gamma > 1) { + if (s->gamma <= 0 || s->gamma > 1) { std::cerr << "warning: searn_gamma must be in (0,1); resetting to 1.0" << std::endl; - gamma = 1.0; + s->gamma = 1.0; } - if (alpha < 0 || alpha > 1) { + if (s->alpha < 0 || s->alpha > 1) { std::cerr << "warning: searn_adaptive_beta must be in (0,1); resetting to 0.001" << std::endl; - alpha = 0.001f; + s->alpha = 0.001f; } - if (task.initialize != NULL) - if (!task.initialize(all, opts, vm, vm_file)) { + if (s->task.initialize != NULL) + if (!s->task.initialize(all, opts, vm, vm_file)) { std::cerr << "error: task did not initialize properly" << std::endl; exit(-1); } // check to make sure task is valid and set up our variables - if (task.final == NULL || - task.loss == NULL || - task.step == NULL || - task.oracle == NULL || - task.copy == NULL || - task.finish == NULL || - ((task.start_state == NULL) == (task.start_state_multiline == NULL)) || - ((task.cs_example == NULL) == (task.cs_ldf_example == NULL))) { + if (s->task.final == NULL || + s->task.loss == NULL || + s->task.step == NULL || + s->task.oracle == NULL || + s->task.copy == NULL || + s->task.finish == NULL || + ((s->task.start_state == NULL) == (s->task.start_state_multiline == NULL)) || + ((s->task.cs_example == NULL) == (s->task.cs_ldf_example == NULL))) { std::cerr << "error: searn task malformed" << std::endl; exit(-1); } - is_singleline = (task.start_state != NULL); - is_ldf = (task.cs_example == NULL); - has_hash = (task.hash != NULL); - constrainted_actions = (task.allowed != NULL); + s->is_singleline = (s->task.start_state != NULL); + s->is_ldf = (s->task.cs_example == NULL); + s->has_hash = (s->task.hash != NULL); + s->constrainted_actions = (s->task.allowed != NULL); - if (do_recombination && (task.hash == NULL)) { + if (s->do_recombination && (s->task.hash == NULL)) { std::cerr << "warning: cannot do recombination when hashing is unavailable -- turning off recombination" << std::endl; - do_recombination = false; + 
s->do_recombination = false; } - if (do_recombination) { + if (s->do_recombination) { // 0 is an invalid action - past_states = new v_hashmap<state,action>(1023, 0, task.equivalent); + s->past_states = new v_hashmap<state,action>(1023, 0, s->task.equivalent); } - if (is_ldf && !constrainted_actions) { + if (s->is_ldf && !s->constrainted_actions) { std::cerr << "error: LDF requires allowed" << std::endl; exit(-1); } @@ -752,16 +777,16 @@ namespace Searn //compute total number of policies we will have at end of training // we add current_policy for cases where we start from an initial set of policies loaded through -i option - uint32_t tmp_number_of_policies = current_policy; + uint32_t tmp_number_of_policies = s->current_policy; if( all.training ) - tmp_number_of_policies += (int)ceil(((float)all.numpasses) / ((float)passes_per_policy)); + tmp_number_of_policies += (int)ceil(((float)all.numpasses) / ((float)s->passes_per_policy)); //the user might have specified the number of policies that will eventually be trained through multiple vw calls, //so only set total_number_of_policies to computed value if it is larger - if( tmp_number_of_policies > total_number_of_policies ) + if( tmp_number_of_policies > s->total_number_of_policies ) { - total_number_of_policies = tmp_number_of_policies; - if( current_policy > 0 ) //we loaded a file but total number of policies didn't match what is needed for training + s->total_number_of_policies = tmp_number_of_policies; + if( s->current_policy > 0 ) //we loaded a file but total number of policies didn't match what is needed for training { std::cerr << "warning: you're attempting to train more classifiers than was allocated initially. Likely to cause bad performance." << endl; } @@ -770,108 +795,105 @@ namespace Searn //current policy currently points to a new policy we would train //if we are not training and loaded a bunch of policies for testing, we need to subtract 1 from current policy //so that we only use those loaded when testing (as run_prediction is called with allow_current to true) - if( !all.training && current_policy > 0 ) - current_policy--; + if( !all.training && s->current_policy > 0 ) + s->current_policy--; - //std::cerr << "Current Policy: " << current_policy << endl; + //std::cerr << "Current Policy: " << s->current_policy << endl; //std::cerr << "Total Number of Policies: " << total_number_of_policies << endl; std::stringstream ss1; std::stringstream ss2; - ss1 << current_policy; + ss1 << s->current_policy; //use cmd_string_replace_value in case we already loaded a predictor which had a value stored for --searn_trained_nb_policies VW::cmd_string_replace_value(all.options_from_file,"--searn_trained_nb_policies", ss1.str()); - ss2 << total_number_of_policies; + ss2 << s->total_number_of_policies; //use cmd_string_replace_value in case we already loaded a predictor which had a value stored for --searn_total_nb_policies VW::cmd_string_replace_value(all.options_from_file,"--searn_total_nb_policies", ss2.str()); - all.base_learner_nb_w *= total_number_of_policies; - increment = ((uint32_t)all.length() / all.base_learner_nb_w) * all.stride; - //cerr << "searn increment = " << increment << endl; - - all.driver = drive; - base_learner = all.learn; - all.learn = learn; - base_finish = all.finish; - all.finish = finish; + all.base_learner_nb_w *= s->total_number_of_policies; + s->increment = ((uint32_t)all.length() / all.base_learner_nb_w) * all.stride; + //cerr << "searn increment = " << s->increment << endl; + + learner l = {s, drive, learn, 
finish, all.l.save_load}; + all.l = l; } - uint32_t searn_predict(vw&all, state s0, size_t step, bool allow_oracle, bool allow_current, v_array< pair<uint32_t,float> >* partial_predictions) // TODO: partial_predictions + uint32_t searn_predict(vw&all, searn& s, state s0, size_t step, bool allow_oracle, bool allow_current, v_array< pair<uint32_t,float> >* partial_predictions) // TODO: partial_predictions { - int policy = SearnUtil::random_policy(read_example_last_id * 2147483 + step * 2147483647 /* has_hash ? task.hash(s0) : step */, beta, allow_current, (int)current_policy, allow_oracle, rollout_all_actions); - if (PRINT_DEBUG_INFO) { cerr << "predicing with policy " << policy << " (allow_oracle=" << allow_oracle << ", allow_current=" << allow_current << "), current_policy=" << current_policy << endl; } + int policy = SearnUtil::random_policy(s.read_example_last_id * 2147483 + step * 2147483647 /* s.has_hash ? s.task.hash(s0) : step */, s.beta, allow_current, (int)s.current_policy, allow_oracle, s.rollout_all_actions); + if (PRINT_DEBUG_INFO) { cerr << "predicing with policy " << policy << " (allow_oracle=" << allow_oracle << ", allow_current=" << allow_current << "), current_policy=" << s.current_policy << endl; } if (policy == -1) { - return task.oracle(s0); + return s.task.oracle(s0); } example *ec; - if (!is_ldf) { - task.cs_example(all, s0, ec, true); - SearnUtil::add_policy_offset(all, ec, increment, policy); + if (!s.is_ldf) { + s.task.cs_example(all, s0, ec, true); + SearnUtil::add_policy_offset(all, ec, s.increment, policy); void* old_label = ec->ld; - if(rollout_all_actions) { //this means we have a cost-sensitive base learner - ec->ld = (void*)&testall_labels; - if (task.allowed != NULL) { // we need to check which actions are allowed - allowed_labels.costs.erase(); + if(s.rollout_all_actions) { //this means we have a cost-sensitive base learner + ec->ld = (void*)&s.testall_labels; + if (s.task.allowed != NULL) { // we need to check which actions are allowed + s.allowed_labels.costs.erase(); bool all_allowed = true; - for (uint32_t k=1; k<=max_action; k++) - if (task.allowed(s0, k)) { + for (uint32_t k=1; k<=s.max_action; k++) + if (s.task.allowed(s0, k)) { CSOAA::wclass cost = { FLT_MAX, k, 1., 0. }; - allowed_labels.costs.push_back(cost); + s.allowed_labels.costs.push_back(cost); } else all_allowed = false; if (!all_allowed) - ec->ld = (void*)&allowed_labels; + ec->ld = (void*)&s.allowed_labels; } } else { //if we have a contextual bandit base learner - ec->ld = (void*)&testall_labels_cb; - if (task.allowed != NULL) { // we need to check which actions are allowed - allowed_labels_cb.costs.erase(); + ec->ld = (void*)&s.testall_labels_cb; + if (s.task.allowed != NULL) { // we need to check which actions are allowed + s.allowed_labels_cb.costs.erase(); bool all_allowed = true; - for (uint32_t k=1; k<=max_action; k++) - if (task.allowed(s0, k)) { + for (uint32_t k=1; k<=s.max_action; k++) + if (s.task.allowed(s0, k)) { CB::cb_class cost = { FLT_MAX, k, 0. 
}; - allowed_labels_cb.costs.push_back(cost); + s.allowed_labels_cb.costs.push_back(cost); } else all_allowed = false; if (!all_allowed) - ec->ld = (void*)&allowed_labels_cb; + ec->ld = (void*)&s.allowed_labels_cb; } } //cerr << "searn>"; //simple_print_example_features(all,ec); - base_learner(&all,ec); - total_predictions_made++; - searn_num_features += ec->num_features; + s.base.learn(&all,s.base.data,ec); + s.total_predictions_made++; + s.searn_num_features += ec->num_features; uint32_t final_prediction = (uint32_t)(*(OAA::prediction_t*)&(ec->final_prediction)); ec->ld = old_label; - SearnUtil::remove_policy_offset(all, ec, increment, policy); - task.cs_example(all, s0, ec, false); + SearnUtil::remove_policy_offset(all, ec, s.increment, policy); + s.task.cs_example(all, s0, ec, false); return final_prediction; } else { // is_ldf //TODO: modify this to handle contextual bandit base learner with ldf float best_prediction = 0; uint32_t best_action = 0; - for (uint32_t action=1; action <= max_action; action++) { - if (!task.allowed(s0, action)) + for (uint32_t action=1; action <= s.max_action; action++) { + if (!s.task.allowed(s0, action)) break; // for LDF, there are no more actions - task.cs_ldf_example(all, s0, action, ec, true); + s.task.cs_ldf_example(all, s0, action, ec, true); //cerr << "created example: " << ec << ", label: " << ec->ld << endl; - SearnUtil::add_policy_offset(all, ec, increment, policy); - base_learner(&all,ec); total_predictions_made++; searn_num_features += ec->num_features; + SearnUtil::add_policy_offset(all, ec, s.increment, policy); + s.base.learn(&all,s.base.data,ec); s.total_predictions_made++; s.searn_num_features += ec->num_features; //cerr << "base_learned on example: " << ec << endl; - empty_example->in_use = true; - base_learner(&all,empty_example); - //cerr << "base_learned on empty example: " << empty_example << endl; - SearnUtil::remove_policy_offset(all, ec, increment, policy); + s.empty_example->in_use = true; + s.base.learn(&all,s.base.data,s.empty_example); + //cerr << "base_learned on empty example: " << s.empty_example << endl; + SearnUtil::remove_policy_offset(all, ec, s.increment, policy); if (action == 1 || ec->partial_prediction < best_prediction) { @@ -879,7 +901,7 @@ namespace Searn best_action = action; } //cerr << "releasing example: " << ec << ", label: " << ec->ld << endl; - task.cs_ldf_example(all, s0, action, ec, false); + s.task.cs_ldf_example(all, s0, action, ec, false); } if (best_action < 1) { @@ -890,111 +912,111 @@ namespace Searn } } - float single_rollout(vw&all, state s0, uint32_t action) + float single_rollout(vw&all, searn& s, state s0, uint32_t action) { //first check if action is valid for current state - if( action < 1 || action > max_action || (task.allowed && !task.allowed(s0,action)) ) + if( action < 1 || action > s.max_action || (s.task.allowed && !s.task.allowed(s0,action)) ) { std::cerr << "warning: asked to rollout an unallowed action: " << action << "; not performing rollout." 
<< std::endl; return 0; } //copy state and step it with current action - rollout[action-1].alive = true; - rollout[action-1].st = task.copy(s0); - task.step(rollout[action-1].st, action); - rollout[action-1].is_finished = task.final(rollout[action-1].st); - if (do_recombination) rollout[action-1].hash = task.hash(rollout[action-1].st); + s.rollout[action-1].alive = true; + s.rollout[action-1].st = s.task.copy(s0); + s.task.step(s.rollout[action-1].st, action); + s.rollout[action-1].is_finished = s.task.final(s.rollout[action-1].st); + if (s.do_recombination) s.rollout[action-1].hash = s.task.hash(s.rollout[action-1].st); //if not finished complete rollout - if (!rollout[action-1].is_finished) { - for (size_t step=1; step<max_rollout; step++) { + if (!s.rollout[action-1].is_finished) { + for (size_t step=1; step<s.max_rollout; step++) { uint32_t act_tmp = 0; - if (do_recombination) - act_tmp = past_states->get(rollout[action-1].st, rollout[action-1].hash); + if (s.do_recombination) + act_tmp = s.past_states->get(s.rollout[action-1].st, s.rollout[action-1].hash); if (act_tmp == 0) { // this means we didn't find it or we're not recombining - if( !rollout_oracle ) - act_tmp = searn_predict(all, rollout[action-1].st, step, true, allow_current_policy, NULL); + if( !s.rollout_oracle ) + act_tmp = searn_predict(all, s, s.rollout[action-1].st, step, true, s.allow_current_policy, NULL); else - act_tmp = task.oracle(rollout[action-1].st); + act_tmp = s.task.oracle(s.rollout[action-1].st); - if (do_recombination) { + if (s.do_recombination) { // we need to make a copy of the state - state copy = task.copy(rollout[action-1].st); - past_states->put_after_get(copy, rollout[action-1].hash, act_tmp); - unfreed_states.push_back(copy); + state copy = s.task.copy(s.rollout[action-1].st); + s.past_states->put_after_get(copy, s.rollout[action-1].hash, act_tmp); + s.unfreed_states.push_back(copy); } } - task.step(rollout[action-1].st, act_tmp); - rollout[action-1].is_finished = task.final(rollout[action-1].st); - if (do_recombination) rollout[action-1].hash = task.hash(rollout[action-1].st); - if (rollout[action-1].is_finished) break; + s.task.step(s.rollout[action-1].st, act_tmp); + s.rollout[action-1].is_finished = s.task.final(s.rollout[action-1].st); + if (s.do_recombination) s.rollout[action-1].hash = s.task.hash(s.rollout[action-1].st); + if (s.rollout[action-1].is_finished) break; } } // finally, compute losses and free copies - float l = task.loss(rollout[action-1].st); - if ((l == FLT_MAX) && (!rollout[action-1].is_finished) && (max_rollout < INT_MAX)) { + float l = s.task.loss(s.rollout[action-1].st); + if ((l == FLT_MAX) && (!s.rollout[action-1].is_finished) && (s.max_rollout < INT_MAX)) { std::cerr << "error: you asked for short rollouts, but your task does not support pre-final losses" << std::endl; exit(-1); } - task.finish(rollout[action-1].st); + s.task.finish(s.rollout[action-1].st); return l; } - void parallel_rollout(vw&all, state s0) + void parallel_rollout(vw&all, searn& s, state s0) { // first, make K copies of s0 and step them bool all_finished = true; - for (size_t k=1; k<=max_action; k++) - rollout[k-1].alive = false; + for (size_t k=1; k<=s.max_action; k++) + s.rollout[k-1].alive = false; - for (uint32_t k=1; k<=max_action; k++) { + for (uint32_t k=1; k<=s.max_action; k++) { // in the case of LDF, we might run out of actions early - if (task.allowed && !task.allowed(s0, k)) { - if (is_ldf) break; + if (s.task.allowed && !s.task.allowed(s0, k)) { + if (s.is_ldf) break; else continue; } - 
rollout[k-1].alive = true; - rollout[k-1].st = task.copy(s0); - task.step(rollout[k-1].st, k); - rollout[k-1].is_finished = task.final(rollout[k-1].st); - if (do_recombination) rollout[k-1].hash = task.hash(rollout[k-1].st); - all_finished = all_finished && rollout[k-1].is_finished; + s.rollout[k-1].alive = true; + s.rollout[k-1].st = s.task.copy(s0); + s.task.step(s.rollout[k-1].st, k); + s.rollout[k-1].is_finished = s.task.final(s.rollout[k-1].st); + if (s.do_recombination) s.rollout[k-1].hash = s.task.hash(s.rollout[k-1].st); + all_finished = all_finished && s.rollout[k-1].is_finished; } // now, complete all rollouts if (!all_finished) { - for (size_t step=1; step<max_rollout; step++) { + for (size_t step=1; step<s.max_rollout; step++) { all_finished = true; - for (size_t k=1; k<=max_action; k++) { - if (rollout[k-1].is_finished) continue; + for (size_t k=1; k<=s.max_action; k++) { + if (s.rollout[k-1].is_finished) continue; uint32_t action = 0; - if (do_recombination) - action = past_states->get(rollout[k-1].st, rollout[k-1].hash); + if (s.do_recombination) + action = s.past_states->get(s.rollout[k-1].st, s.rollout[k-1].hash); if (action == 0) { // this means we didn't find it or we're not recombining - if( !rollout_oracle ) - action = searn_predict(all, rollout[k-1].st, step, true, allow_current_policy, NULL); + if( !s.rollout_oracle ) + action = searn_predict(all, s, s.rollout[k-1].st, step, true, s.allow_current_policy, NULL); else - action = task.oracle(rollout[k-1].st); + action = s.task.oracle(s.rollout[k-1].st); - if (do_recombination) { + if (s.do_recombination) { // we need to make a copy of the state - state copy = task.copy(rollout[k-1].st); - past_states->put_after_get(copy, rollout[k-1].hash, action); - unfreed_states.push_back(copy); + state copy = s.task.copy(s.rollout[k-1].st); + s.past_states->put_after_get(copy, s.rollout[k-1].hash, action); + s.unfreed_states.push_back(copy); } } - task.step(rollout[k-1].st, action); - rollout[k-1].is_finished = task.final(rollout[k-1].st); - if (do_recombination) rollout[k-1].hash = task.hash(rollout[k-1].st); - all_finished = all_finished && rollout[k-1].is_finished; + s.task.step(s.rollout[k-1].st, action); + s.rollout[k-1].is_finished = s.task.final(s.rollout[k-1].st); + if (s.do_recombination) s.rollout[k-1].hash = s.task.hash(s.rollout[k-1].st); + all_finished = all_finished && s.rollout[k-1].is_finished; } if (all_finished) break; } @@ -1002,39 +1024,39 @@ namespace Searn // finally, compute losses and free copies float min_loss = 0; - loss_vector.erase(); - for (uint32_t k=1; k<=max_action; k++) { - if (!rollout[k-1].alive) + s.loss_vector.erase(); + for (uint32_t k=1; k<=s.max_action; k++) { + if (!s.rollout[k-1].alive) break; - float l = task.loss(rollout[k-1].st); - if ((l == FLT_MAX) && (!rollout[k-1].is_finished) && (max_rollout < INT_MAX)) { + float l = s.task.loss(s.rollout[k-1].st); + if ((l == FLT_MAX) && (!s.rollout[k-1].is_finished) && (s.max_rollout < INT_MAX)) { std::cerr << "error: you asked for short rollouts, but your task does not support pre-final losses" << std::endl; exit(-1); } CSOAA::wclass temp = { l, k, 1., 0. 
}; - loss_vector.push_back(temp); + s.loss_vector.push_back(temp); if ((k == 1) || (l < min_loss)) { min_loss = l; } - task.finish(rollout[k-1].st); + s.task.finish(s.rollout[k-1].st); } // subtract the smallest loss - for (size_t k=1; k<=max_action; k++) - if (rollout[k-1].alive) - loss_vector[k-1].x -= min_loss; + for (size_t k=1; k<=s.max_action; k++) + if (s.rollout[k-1].alive) + s.loss_vector[k-1].x -= min_loss; } - uint32_t uniform_exploration(state s0, float& prob_sampled_action) + uint32_t uniform_exploration(searn& s, state s0, float& prob_sampled_action) { //find how many valid actions - size_t nb_allowed_actions = max_action; - if( task.allowed ) { - for (uint32_t k=1; k<=max_action; k++) { - if( !task.allowed(s0,k) ) { + size_t nb_allowed_actions = s.max_action; + if( s.task.allowed ) { + for (uint32_t k=1; k<=s.max_action; k++) { + if( !s.task.allowed(s0,k) ) { nb_allowed_actions--; - if (is_ldf) { + if (s.is_ldf) { nb_allowed_actions = k-1; break; } @@ -1043,25 +1065,25 @@ namespace Searn } uint32_t action = (size_t)(frand48() * nb_allowed_actions) + 1; - if( task.allowed && nb_allowed_actions < max_action && !is_ldf) { + if( s.task.allowed && nb_allowed_actions < s.max_action && !s.is_ldf) { //need to adjust action to the corresponding valid action for (uint32_t k=1; k<=action; k++) { - if( !task.allowed(s0,k) ) action++; + if( !s.task.allowed(s0,k) ) action++; } } prob_sampled_action = (float) (1.0/nb_allowed_actions); return action; } - void get_contextual_bandit_loss_vector(vw&all, state s0) + void get_contextual_bandit_loss_vector(vw&all, searn& s, state s0) { float prob_sampled = 1.; - uint32_t act = uniform_exploration(s0,prob_sampled); - float loss = single_rollout(all,s0,act); + uint32_t act = uniform_exploration(s, s0,prob_sampled); + float loss = single_rollout(all,s, s0,act); - loss_vector_cb.erase(); - for (uint32_t k=1; k<=max_action; k++) { - if( task.allowed && !task.allowed(s0,k)) + s.loss_vector_cb.erase(); + for (uint32_t k=1; k<=s.max_action; k++) { + if( s.task.allowed && !s.task.allowed(s0,k)) break; CB::cb_class temp; @@ -1072,150 +1094,150 @@ namespace Searn temp.x = loss; temp.prob_action = prob_sampled; } - loss_vector_cb.push_back(temp); + s.loss_vector_cb.push_back(temp); } } - void generate_state_example(vw&all, state s0) + void generate_state_example(vw&all, searn& s, state s0) { // start by doing rollouts so we can get costs - loss_vector.erase(); - loss_vector_cb.erase(); - if( rollout_all_actions ) { - parallel_rollout(all, s0); + s.loss_vector.erase(); + s.loss_vector_cb.erase(); + if( s.rollout_all_actions ) { + parallel_rollout(all, s, s0); } else { - get_contextual_bandit_loss_vector(all, s0); + get_contextual_bandit_loss_vector(all, s, s0); } - if (loss_vector.size() <= 1 && loss_vector_cb.size() == 0) { + if (s.loss_vector.size() <= 1 && s.loss_vector_cb.size() == 0) { // nothing interesting to do! 
      return;
    }

    // now, generate training examples
-   if (!is_ldf) {
-     total_examples_generated++;
+   if (!s.is_ldf) {
+     s.total_examples_generated++;
      example* ec;
-     task.cs_example(all, s0, ec, true);
+     s.task.cs_example(all, s0, ec, true);
      void* old_label = ec->ld;
-     if(rollout_all_actions) {
-       CSOAA::label ld = { loss_vector };
+     if(s.rollout_all_actions) {
+       CSOAA::label ld = { s.loss_vector };
        ec->ld = (void*)&ld;
      } else {
-       CB::label ld = { loss_vector_cb };
+       CB::label ld = { s.loss_vector_cb };
        ec->ld = (void*)&ld;
      }
-     SearnUtil::add_policy_offset(all, ec, increment, current_policy);
-     base_learner(&all,ec);
-     SearnUtil::remove_policy_offset(all, ec, increment, current_policy);
+     SearnUtil::add_policy_offset(all, ec, s.increment, s.current_policy);
+     s.base.learn(&all,s.base.data,ec);
+     SearnUtil::remove_policy_offset(all, ec, s.increment, s.current_policy);
      ec->ld = old_label;
-     task.cs_example(all, s0, ec, false);
+     s.task.cs_example(all, s0, ec, false);
    } else { // is_ldf
      //TODO: support ldf with contextual bandit base learner
-     old_labels.erase();
-     new_labels.erase();
+     s.old_labels.erase();
+     s.new_labels.erase();

-     for (uint32_t k=1; k<=max_action; k++) {
-       if (rollout[k-1].alive) {
-         OAA::mc_label ld = { k, loss_vector[k-1].x };
-         new_labels.push_back(ld);
+     for (uint32_t k=1; k<=s.max_action; k++) {
+       if (s.rollout[k-1].alive) {
+         OAA::mc_label ld = { k, s.loss_vector[k-1].x };
+         s.new_labels.push_back(ld);
        } else {
          OAA::mc_label ld = { k, 0. };
-         new_labels.push_back(ld);
+         s.new_labels.push_back(ld);
        }
      }
      // cerr << "vvvvvvvvvvvvvvvvvvvvvvvvvvvv" << endl;
-     for (uint32_t k=1; k<=max_action; k++) {
-       if (!rollout[k-1].alive) break;
-       total_examples_generated++;
+     for (uint32_t k=1; k<=s.max_action; k++) {
+       if (!s.rollout[k-1].alive) break;
+       s.total_examples_generated++;

-       task.cs_ldf_example(all, s0, k, global_example_set[k-1], true);
-       old_labels.push_back(global_example_set[k-1]->ld);
-       global_example_set[k-1]->ld = (void*)(&new_labels[k-1]);
-       SearnUtil::add_policy_offset(all, global_example_set[k-1], increment, current_policy);
-       if (PRINT_DEBUG_INFO) { cerr << "add_policy_offset, max_action=" << max_action << ", total_number_of_policies=" << total_number_of_policies << ", current_policy=" << current_policy << endl;}
-       base_learner(&all,global_example_set[k-1]);
+       s.task.cs_ldf_example(all, s0, k, s.global_example_set[k-1], true);
+       s.old_labels.push_back(s.global_example_set[k-1]->ld);
+       s.global_example_set[k-1]->ld = (void*)(&s.new_labels[k-1]);
+       SearnUtil::add_policy_offset(all, s.global_example_set[k-1], s.increment, s.current_policy);
+       if (PRINT_DEBUG_INFO) { cerr << "add_policy_offset, s.max_action=" << s.max_action << ", total_number_of_policies=" << s.total_number_of_policies << ", current_policy=" << s.current_policy << endl;}
+       s.base.learn(&all,s.base.data,s.global_example_set[k-1]);
      }
-     // cerr << "============================ (empty = " << empty_example << ")" << endl;
-     empty_example->in_use = true;
-     base_learner(&all,empty_example);
+     // cerr << "============================ (empty = " << s.empty_example << ")" << endl;
+     s.empty_example->in_use = true;
+     s.base.learn(&all,s.base.data,s.empty_example);

-     for (uint32_t k=1; k<=max_action; k++) {
-       if (!rollout[k-1].alive) break;
-       SearnUtil::remove_policy_offset(all, global_example_set[k-1], increment, current_policy);
-       global_example_set[k-1]->ld = old_labels[k-1];
-       task.cs_ldf_example(all, s0, k, global_example_set[k-1], false);
+     for (uint32_t k=1; k<=s.max_action; k++) {
+       if (!s.rollout[k-1].alive) break;
+       SearnUtil::remove_policy_offset(all, s.global_example_set[k-1], s.increment, s.current_policy);
+       s.global_example_set[k-1]->ld = s.old_labels[k-1];
+       s.task.cs_ldf_example(all, s0, k, s.global_example_set[k-1], false);
      }
      // cerr << "^^^^^^^^^^^^^^^^^^^^^^^^^^^^" << endl;
    }
  }

- void run_prediction(vw&all, state s0, bool allow_oracle, bool allow_current, bool track_actions, std::vector<action>* action_sequence)
+ void run_prediction(vw&all, searn& s, state s0, bool allow_oracle, bool allow_current, bool track_actions, std::vector<action>* action_sequence)
  {
    int step = 1;
-   while (!task.final(s0)) {
-     uint32_t action = searn_predict(all, s0, step, allow_oracle, allow_current, NULL);
+   while (!s.task.final(s0)) {
+     uint32_t action = searn_predict(all, s, s0, step, allow_oracle, allow_current, NULL);
      if (track_actions) action_sequence->push_back(action);
-     task.step(s0, action);
+     s.task.step(s0, action);
      step++;
    }
  }

- void do_actual_learning(vw&all)
+ void do_actual_learning(vw&all, searn& s)
  {
    // there are two cases:
    //   * is_singleline --> look only at ec_seq[0]
    //   * otherwise --> look at everything

-   if (ec_seq.size() == 0)
+   if (s.ec_seq.size() == 0)
      return;

    // generate the start state
    state s0;
-   if (is_singleline)
-     task.start_state(ec_seq[0], &s0);
+   if (s.is_singleline)
+     s.task.start_state(s.ec_seq[0], &s0);
    else
-     task.start_state_multiline(ec_seq.begin, ec_seq.size(), &s0);
+     s.task.start_state_multiline(s.ec_seq.begin, s.ec_seq.size(), &s0);

    state s0copy = NULL;
-   bool is_test = task.is_test_example(ec_seq.begin, ec_seq.size());
+   bool is_test = s.task.is_test_example(s.ec_seq.begin, s.ec_seq.size());
    if (!is_test) {
-     s0copy = task.copy(s0);
+     s0copy = s.task.copy(s0);
      all.sd->example_number++;
-     all.sd->total_features += searn_num_features;
+     all.sd->total_features += s.searn_num_features;
      all.sd->weighted_examples += 1.;
    }

-   bool will_print = is_test || should_print_update(all) || will_global_print_label(all);
+   bool will_print = is_test || should_print_update(all) || will_global_print_label(all, s);

-   searn_num_features = 0;
+   s.searn_num_features = 0;
    std::vector<action> action_sequence;

    // if we are using adaptive beta, update it to take into account the latest updates
-   if( adaptive_beta ) beta = 1.f - powf(1.f - alpha,(float)total_examples_generated);
+   if( s.adaptive_beta ) s.beta = 1.f - powf(1.f - s.alpha,(float)s.total_examples_generated);

-   run_prediction(all, s0, false, true, will_print, &action_sequence);
-   global_print_label(all, ec_seq[0], s0, action_sequence);
+   run_prediction(all, s, s0, false, true, will_print, &action_sequence);
+   global_print_label(all, s, s.ec_seq[0], s0, action_sequence);

    if (!is_test) {
-     float loss = task.loss(s0);
+     float loss = s.task.loss(s0);
      all.sd->sum_loss += loss;
      all.sd->sum_loss_since_last_dump += loss;
    }

-   print_update(all, s0, action_sequence);
+   print_update(all, s, s0, action_sequence);
-   task.finish(s0);
+   s.task.finish(s0);

    if (is_test || !all.training)
      return;
@@ -1224,105 +1246,106 @@ namespace Searn
    // training examples only get here
    int step = 1;
-   while (!task.final(s0)) {
+   while (!s.task.final(s0)) {
      // if we are using adaptive beta, update it to take into account the latest updates
-     if( adaptive_beta ) beta = 1.f - powf(1.f - alpha,(float)total_examples_generated);
+     if( s.adaptive_beta ) s.beta = 1.f - powf(1.f - s.alpha,(float)s.total_examples_generated);

      // first, make a prediction (we don't want to bias ourselves if
      // we're using the current policy to predict)
-     uint32_t action = searn_predict(all, s0, step, true, allow_current_policy, NULL);
+     uint32_t action = searn_predict(all, s, s0, step, true, s.allow_current_policy, NULL);

      // generate training example for the current state
-     generate_state_example(all, s0);
+     generate_state_example(all, s, s0);

      // take the prescribed step
-     task.step(s0, action);
+     s.task.step(s0, action);
      step++;
    }
-   task.finish(s0);
-   if (do_recombination) { // we need to free a bunch of memory
-     //   past_states->iter(&hm_free_state_copies);
-     free_unfreed_states();
-     past_states->clear();
+   s.task.finish(s0);
+   if (s.do_recombination) { // we need to free a bunch of memory
+     //   s.past_states->iter(&hm_free_state_copies);
+     free_unfreed_states(s);
+     s.past_states->clear();
    }
  }

- void process_next_example(vw&all, example *ec)
+ void process_next_example(vw&all, searn& s, example *ec)
  {
    bool is_real_example = true;

-   if (is_singleline) {
-     if (ec_seq.size() == 0)
-       ec_seq.push_back(ec);
+   if (s.is_singleline) {
+     if (s.ec_seq.size() == 0)
+       s.ec_seq.push_back(ec);
      else
-       ec_seq[0] = ec;
+       s.ec_seq[0] = ec;

-     do_actual_learning(all);
+     do_actual_learning(all, s);
    } else {
      // is multiline
-     if (ec_seq.size() >= all.p->ring_size - 2) { // give some wiggle room
+     if (s.ec_seq.size() >= all.p->ring_size - 2) { // give some wiggle room
        std::cerr << "warning: length of sequence at " << ec->example_counter << " exceeds ring size; breaking apart" << std::endl;
-       do_actual_learning(all);
-       clear_seq(all);
+       do_actual_learning(all, s);
+       clear_seq(all, s);
      }

      if (OAA::example_is_newline(ec)) {
-       do_actual_learning(all);
-       clear_seq(all);
+       do_actual_learning(all, s);
+       clear_seq(all, s);
        //CSOAA_LDF::global_print_newline(all);
        VW::finish_example(all, ec);
        is_real_example = false;
      } else {
-       ec_seq.push_back(ec);
+       s.ec_seq.push_back(ec);
      }
    }

    // for both single and multiline
    if (is_real_example) {
-     read_example_this_loop++;
-     read_example_last_id = ec->example_counter;
-     if (ec->pass != read_example_last_pass) {
-       read_example_last_pass = ec->pass;
-       passes_since_new_policy++;
-       if (passes_since_new_policy >= passes_per_policy) {
-         passes_since_new_policy = 0;
+     s.read_example_this_loop++;
+     s.read_example_last_id = ec->example_counter;
+     if (ec->pass != s.read_example_last_pass) {
+       s.read_example_last_pass = ec->pass;
+       s.passes_since_new_policy++;
+       if (s.passes_since_new_policy >= s.passes_per_policy) {
+         s.passes_since_new_policy = 0;
          if(all.training)
-           current_policy++;
-         if (current_policy > total_number_of_policies) {
+           s.current_policy++;
+         if (s.current_policy > s.total_number_of_policies) {
            std::cerr << "internal error (bug): too many policies; not advancing" << std::endl;
-           current_policy = total_number_of_policies;
+           s.current_policy = s.total_number_of_policies;
          }
          //reset searn_trained_nb_policies in options_from_file so it is saved to regressor file later
          std::stringstream ss;
-         ss << current_policy;
+         ss << s.current_policy;
          VW::cmd_string_replace_value(all.options_from_file,"--searn_trained_nb_policies", ss.str());
        }
      }
    }
  }

- void drive(void*in)
+ void drive(void*in, void*d)
  {
    vw*all = (vw*)in;
    // initialize everything
-
+   searn* s = (searn*)d;
+
    const char * header_fmt = "%-10s %-10s %8s %15s %24s %22s %8s %5s %5s %15s %15s\n";
    fprintf(stderr, header_fmt, "average", "since", "sequence", "example", "current label", "current predicted", "current", "cur", "cur", "predic.", "examples");
    fprintf(stderr, header_fmt, "loss", "last", "counter", "weight", "sequence prefix", "sequence prefix", "features", "pass", "pol", "made", "gener.");
    cerr.precision(5);
-   initialize_memory();
+   initialize_memory(*s);

    example* ec = NULL;
-   read_example_this_loop = 0;
+   s->read_example_this_loop = 0;
    while (true) {
      if ((ec = get_example(all->p)) != NULL) { // semiblocking operation
-       process_next_example(*all, ec);
+       process_next_example(*all, *s, ec);
      } else if (parser_done(all->p)) {
-       if (!is_singleline)
-         do_actual_learning(*all);
+       if (!s->is_singleline)
+         do_actual_learning(*all, *s);
        break;
      }
    }
@@ -1330,10 +1353,10 @@ namespace Searn
    if( all->training ) {
      std::stringstream ss1;
      std::stringstream ss2;
-     ss1 << (current_policy+1);
+     ss1 << (s->current_policy+1);
      //use cmd_string_replace_value in case we already loaded a predictor which had a value stored for --searn_trained_nb_policies
      VW::cmd_string_replace_value(all->options_from_file,"--searn_trained_nb_policies", ss1.str());
-     ss2 << total_number_of_policies;
+     ss2 << s->total_number_of_policies;
      //use cmd_string_replace_value in case we already loaded a predictor which had a value stored for --searn_total_nb_policies
      VW::cmd_string_replace_value(all->options_from_file,"--searn_total_nb_policies", ss2.str());
    }
  }
@@ -1348,15 +1371,15 @@ namespace ImperativeSearn {
  const char INIT_TRAIN = 1;
  const char LEARN = 2;

- inline bool isLDF(searn_struct* srn) { return (srn->A == 0); }
+ inline bool isLDF(searn& srn) { return (srn.A == 0); }

- uint32_t choose_policy(searn_struct* srn, bool allow_current, bool allow_optimal)
+ uint32_t choose_policy(searn& srn, bool allow_current, bool allow_optimal)
  {
-   uint32_t seed = 0; // TODO: srn->read_example_last_id * 2147483 + srn->t * 2147483647;
-   return SearnUtil::random_policy(seed, srn->beta, allow_current, srn->current_policy, allow_optimal, srn->rollout_all_actions);
+   uint32_t seed = 0; // TODO: srn.read_example_last_id * 2147483 + srn.t * 2147483647;
+   return SearnUtil::random_policy(seed, srn.beta, allow_current, srn.current_policy, allow_optimal, srn.rollout_all_actions);
  }

- v_array<CSOAA::wclass> get_all_labels(searn_struct* srn, size_t num_ec, v_array<uint32_t> *yallowed)
+ v_array<CSOAA::wclass> get_all_labels(searn& srn, size_t num_ec, v_array<uint32_t> *yallowed)
  {
    if (isLDF(srn)) {
      v_array<CSOAA::wclass> ret;  // TODO: cache these!
@@ -1369,7 +1392,7 @@ namespace ImperativeSearn {
    // is not LDF
    if (yallowed == NULL) {
      v_array<CSOAA::wclass> ret;  // TODO: cache this!
-     for (uint32_t i=1; i<=srn->A; i++) {
+     for (uint32_t i=1; i<=srn.A; i++) {
        CSOAA::wclass cost = { FLT_MAX, i, 1., 0. };
        ret.push_back(cost);
      }
@@ -1390,21 +1413,20 @@ namespace ImperativeSearn {
    return 0;
  }

- uint32_t single_prediction_notLDF(vw& all, example* ec, v_array<CSOAA::wclass> valid_labels, uint32_t pol)
+ uint32_t single_prediction_notLDF(vw& all, searn& srn, example* ec, v_array<CSOAA::wclass> valid_labels, uint32_t pol)
  {
-   searn_struct *srn = (searn_struct*)all.searnstr;
    assert(pol > 0);
    void* old_label = ec->ld;
    ec->ld = (void*)&valid_labels;
-   SearnUtil::add_policy_offset(all, ec, srn->increment, pol);
+   SearnUtil::add_policy_offset(all, ec, srn.increment, pol);

-   srn->base_learner(&all, ec);
-   srn->total_predictions_made++;
-   srn->num_features += ec->num_features;
+   srn.base.learn(&all,srn.base.data, ec);
+   srn.total_predictions_made++;
+   srn.num_features += ec->num_features;
    uint32_t final_prediction = (uint32_t)(*(OAA::prediction_t*)&(ec->final_prediction));

-   SearnUtil::remove_policy_offset(all, ec, srn->increment, pol);
+   SearnUtil::remove_policy_offset(all, ec, srn.increment, pol);
    ec->ld = old_label;

    return final_prediction;
@@ -1416,7 +1438,7 @@ namespace ImperativeSearn {
    return opts[(size_t)(((float)opts.size()) * r)];
  }

- uint32_t single_action(vw& all, example** ecs, size_t num_ec, v_array<CSOAA::wclass> valid_labels, uint32_t pol, v_array<uint32_t> *ystar) {
+ uint32_t single_action(vw& all, searn& srn, example** ecs, size_t num_ec, v_array<CSOAA::wclass> valid_labels, uint32_t pol, v_array<uint32_t> *ystar) {
    //cerr << "pol=" << pol << " ystar.size()=" << ystar->size() << " ystar[0]=" << ((ystar->size() > 0) ? (*ystar)[0] : 0) << endl;
    if (pol == 0) { // optimal policy
      if ((ystar == NULL) || (ystar->size() == 0))
@@ -1424,19 +1446,18 @@ namespace ImperativeSearn {
      else return choose_random<uint32_t>(*ystar);
    } else { // learned policy
-     if (isLDF((searn_struct*)all.searnstr))
+     if (isLDF(srn))
        return single_prediction_LDF(all, ecs, num_ec, pol);
      else
-       return single_prediction_notLDF(all, *ecs, valid_labels, pol);
+       return single_prediction_notLDF(all, srn, *ecs, valid_labels, pol);
    }
  }

- void clear_snapshot(vw& all)
+ void clear_snapshot(vw& all, searn& srn)
  {
-   searn_struct *srn = (searn_struct*)all.searnstr;
-   for (size_t i=0; i<srn->snapshot_data.size(); i++)
-     free(srn->snapshot_data[i].data_ptr);
-   srn->snapshot_data.erase();
+   for (size_t i=0; i<srn.snapshot_data.size(); i++)
+     free(srn.snapshot_data[i].data_ptr);
+   srn.snapshot_data.erase();
  }

  // if not LDF:
@@ -1456,10 +1477,10 @@ namespace ImperativeSearn {
  //   otherwise means the oracle could do any of the listed actions
  uint32_t searn_predict(vw& all, example** ecs, size_t num_ec, v_array<uint32_t> *yallowed, v_array<uint32_t> *ystar)   // num_ec == 0 means normal example, >0 means ldf, yallowed==NULL means all allowed, ystar==NULL means don't know
  {
-   searn_struct *srn = (searn_struct*)all.searnstr;
+   searn* srn=(searn*)all.searnstr;

    // check ldf sanity
-   if (!isLDF(srn)) {
+   if (!isLDF(*srn)) {
      assert(num_ec == 0); // searntask is trying to define an ldf example in a non-ldf problem
    } else { // is LDF
      assert(num_ec != 0); // searntask is trying to define a non-ldf example in an ldf problem" << endl;
@@ -1467,17 +1488,17 @@ namespace ImperativeSearn {
    }

    if (srn->state == INIT_TEST) {
-     uint32_t pol = choose_policy(srn, true, false);
-     v_array<CSOAA::wclass> valid_labels = get_all_labels(srn, num_ec, yallowed);
-     uint32_t a = single_action(all, ecs, num_ec, valid_labels, pol, ystar);
+     uint32_t pol = choose_policy(*srn, true, false);
+     v_array<CSOAA::wclass> valid_labels = get_all_labels(*srn, num_ec, yallowed);
+     uint32_t a = single_action(all, *srn, ecs, num_ec, valid_labels, pol, ystar);
      srn->t++;
      valid_labels.erase(); valid_labels.delete_v();
      return a;
    }
    if (srn->state == INIT_TRAIN) {
-     uint32_t pol = choose_policy(srn, srn->allow_current_policy, true);
-     v_array<CSOAA::wclass> valid_labels = get_all_labels(srn, num_ec, yallowed);
-     uint32_t a = single_action(all, ecs, num_ec, valid_labels, pol, ystar);
+     uint32_t pol = choose_policy(*srn, srn->allow_current_policy, true);
+     v_array<CSOAA::wclass> valid_labels = get_all_labels(*srn, num_ec, yallowed);
+     uint32_t a = single_action(all, *srn, ecs, num_ec, valid_labels, pol, ystar);
      srn->train_action.push_back(a);
      srn->train_labels.push_back(valid_labels);
      srn->t++;
@@ -1502,9 +1523,9 @@ namespace ImperativeSearn {
      srn->t++;
      return srn->learn_a;
    } else {
-     uint32_t pol = choose_policy(srn, srn->allow_current_policy, true);
-     v_array<CSOAA::wclass> valid_labels = get_all_labels(srn, num_ec, yallowed);
-     uint32_t a = single_action(all, ecs, num_ec, valid_labels, pol, ystar);
+     uint32_t pol = choose_policy(*srn, srn->allow_current_policy, true);
+     v_array<CSOAA::wclass> valid_labels = get_all_labels(*srn, num_ec, yallowed);
+     uint32_t a = single_action(all, *srn, ecs, num_ec, valid_labels, pol, ystar);
      srn->t++;
      valid_labels.erase(); valid_labels.delete_v();
      return a;
@@ -1517,8 +1538,7 @@ namespace ImperativeSearn {
  }

  void searn_declare_loss(vw& all, size_t predictions_since_last, float incr_loss)
  {
-   searn_struct *srn = (searn_struct*)all.searnstr;
-
+   searn* srn=(searn*)all.searnstr;
    if (srn->t != srn->loss_last_step + predictions_since_last) {
      cerr << "fail: searntask hasn't counted its predictions correctly.  current time step=" << srn->t << ", last declaration at " << srn->loss_last_step << ", declared # of predictions since then is " << predictions_since_last << endl;
      exit(-1);
@@ -1544,7 +1564,7 @@ namespace ImperativeSearn {

  void searn_snapshot(vw& all, size_t index, size_t tag, void* data_ptr, size_t sizeof_data)
  {
-   searn_struct *srn = (searn_struct*)all.searnstr;
+   searn* srn=(searn*)all.searnstr;
    if (!srn->do_snapshot) return;

    //cerr << "snapshot called with: { index=" << index << ", tag=" << tag << ", data_ptr=" << *(size_t*)data_ptr << ", t=" << srn->t << " }" << endl;
@@ -1593,19 +1613,16 @@ namespace ImperativeSearn {
    srn->t = item.pred_step;
  }

- v_array<size_t> get_training_timesteps(vw& all)
+ v_array<size_t> get_training_timesteps(vw& all, searn& srn)
  {
-   searn_struct *srn = (searn_struct*)all.searnstr;
    v_array<size_t> timesteps;
-   for (size_t t=0; t<srn->T; t++)
+   for (size_t t=0; t<srn.T; t++)
      timesteps.push_back(t);
    return timesteps;
  }

- void generate_training_example(vw& all, example** ec, size_t len, v_array<CSOAA::wclass> labels, v_array<float> losses)
+ void generate_training_example(vw& all, searn& srn, example** ec, size_t len, v_array<CSOAA::wclass> labels, v_array<float> losses)
  {
-   searn_struct *srn = (searn_struct*)all.searnstr;
-
    assert(labels.size() == losses.size());
    for (size_t i=0; i<labels.size(); i++)
      labels[i].x = losses[i];
@@ -1614,130 +1631,127 @@ namespace ImperativeSearn {
      void* old_label = ec[0]->ld;
      CSOAA::label new_label = { labels };
      ec[0]->ld = (void*)&new_label;
-     SearnUtil::add_policy_offset(all, ec[0], srn->increment, srn->current_policy);
-     srn->base_learner(&all, ec[0]);
-     SearnUtil::remove_policy_offset(all, ec[0], srn->increment, srn->current_policy);
+     SearnUtil::add_policy_offset(all, ec[0], srn.increment, srn.current_policy);
+     srn.base.learn(&all,srn.base.data, ec[0]);
+     SearnUtil::remove_policy_offset(all, ec[0], srn.increment, srn.current_policy);
      ec[0]->ld = old_label;
-     srn->total_examples_generated++;
+     srn.total_examples_generated++;
    } else { // isLDF
      //TODO
    }
  }

- void train_single_example(vw& all, example**ec, size_t len)
+ void train_single_example(vw& all, searn& srn, example**ec, size_t len)
  {
-   searn_struct *srn = (searn_struct*)all.searnstr;
-
    // do an initial test pass to compute output (and loss)
    //cerr << "======================================== INIT TEST ========================================" << endl;
-   srn->state = INIT_TEST;
-   srn->t = 0;
-   srn->T = 0;
-   srn->loss_last_step = 0;
-   srn->test_loss = 0.f;
-   srn->train_loss = 0.f;
-   srn->learn_loss = 0.f;
-   srn->learn_example_copy = NULL;
-   srn->learn_example_len = 0;
-   srn->train_action.erase();
-   srn->num_features = 0;
+   srn.state = INIT_TEST;
+   srn.t = 0;
+   srn.T = 0;
+   srn.loss_last_step = 0;
+   srn.test_loss = 0.f;
+   srn.train_loss = 0.f;
+   srn.learn_loss = 0.f;
+   srn.learn_example_copy = NULL;
+   srn.learn_example_len = 0;
+   srn.train_action.erase();
+   srn.num_features = 0;

-   srn->task.structured_predict(all, ec, len, srn->pred_string, srn->truth_string);
+   srn.task->structured_predict(all, srn, ec, len, srn.pred_string, srn.truth_string);

-   if (srn->t == 0)
+   if (srn.t == 0)
      return;  // there was no data!
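    // What follows are the second and third passes that train_single_example
    // drives over the same input. INIT_TRAIN repeats structured_predict with
    // the oracle allowed and snapshotting on, recording the chosen action and
    // the valid-label set at every timestep into train_action/train_labels.
    // LEARN then revisits each recorded timestep t: for every alternative
    // action a it replays the sequence (srn.learn_t and srn.learn_a steer
    // searn_predict), accumulates the declared loss into learn_losses, and
    // finally hands the completed cost vector to the base learner through
    // generate_training_example.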
    // do a pass over the data allowing oracle and snapshotting
    //cerr << "======================================== INIT TRAIN ========================================" << endl;
-   srn->state = INIT_TRAIN;
-   srn->t = 0;
-   srn->loss_last_step = 0;
-   clear_snapshot(all);
+   srn.state = INIT_TRAIN;
+   srn.t = 0;
+   srn.loss_last_step = 0;
+   clear_snapshot(all, srn);

-   srn->task.structured_predict(all, ec, len, NULL, NULL);
+   srn.task->structured_predict(all, srn, ec, len, NULL, NULL);

-   if (srn->t == 0) {
-     clear_snapshot(all);
+   if (srn.t == 0) {
+     clear_snapshot(all, srn);
      return;  // there was no data
    }
-   srn->T = srn->t;
+   srn.T = srn.t;

    // generate training examples on which to learn
    //cerr << "======================================== LEARN ========================================" << endl;
-   srn->state = LEARN;
-   v_array<size_t> tset = get_training_timesteps(all);
+   srn.state = LEARN;
+   v_array<size_t> tset = get_training_timesteps(all, srn);
    for (size_t t=0; t<tset.size(); t++) {
-     v_array<CSOAA::wclass> aset = srn->train_labels[t];
-     srn->learn_t = t;
-     srn->learn_losses.erase();
+     v_array<CSOAA::wclass> aset = srn.train_labels[t];
+     srn.learn_t = t;
+     srn.learn_losses.erase();

      for (size_t i=0; i<aset.size(); i++) {
-       if (aset[i].weight_index == srn->train_action[srn->learn_t])
-         srn->learn_losses.push_back( srn->train_loss );
+       if (aset[i].weight_index == srn.train_action[srn.learn_t])
+         srn.learn_losses.push_back( srn.train_loss );
        else {
-         srn->t = 0;
-         srn->learn_a = aset[i].weight_index;
-         srn->loss_last_step = 0;
-         srn->learn_loss = 0.f;
+         srn.t = 0;
+         srn.learn_a = aset[i].weight_index;
+         srn.loss_last_step = 0;
+         srn.learn_loss = 0.f;

-         //cerr << "learn_t = " << srn->learn_t << " || learn_a = " << srn->learn_a << endl;
-         srn->task.structured_predict(all, ec, len, NULL, NULL);
+         //cerr << "learn_t = " << srn.learn_t << " || learn_a = " << srn.learn_a << endl;
+         srn.task->structured_predict(all, srn, ec, len, NULL, NULL);

-         srn->learn_losses.push_back( srn->learn_loss );
-         //cerr << "total loss: " << srn->learn_loss << endl;
+         srn.learn_losses.push_back( srn.learn_loss );
+         //cerr << "total loss: " << srn.learn_loss << endl;
        }
      }

-     if (srn->learn_example_copy != NULL) {
-       generate_training_example(all, srn->learn_example_copy, srn->learn_example_len, aset, srn->learn_losses);
+     if (srn.learn_example_copy != NULL) {
+       generate_training_example(all, srn, srn.learn_example_copy, srn.learn_example_len, aset, srn.learn_losses);

-       for (size_t n=0; n<srn->learn_example_len; n++) {
-         dealloc_example(CSOAA::delete_label, *srn->learn_example_copy[n]);
-         free(srn->learn_example_copy[n]);
+       for (size_t n=0; n<srn.learn_example_len; n++) {
+         dealloc_example(CSOAA::delete_label, *srn.learn_example_copy[n]);
+         free(srn.learn_example_copy[n]);
        }
-       free(srn->learn_example_copy);
-       srn->learn_example_copy = NULL;
-       srn->learn_example_len = 0;
+       free(srn.learn_example_copy);
+       srn.learn_example_copy = NULL;
+       srn.learn_example_len = 0;
      } else {
        cerr << "warning: searn did not generate an example for a given time-step" << endl;
      }
    }
    tset.erase(); tset.delete_v();

-   clear_snapshot(all);
-   srn->train_action.erase();
-   srn->train_action.delete_v();
-   for (size_t i=0; i<srn->train_labels.size(); i++) {
-     srn->train_labels[i].erase();
-     srn->train_labels[i].delete_v();
+   clear_snapshot(all, srn);
+   srn.train_action.erase();
+   srn.train_action.delete_v();
+   for (size_t i=0; i<srn.train_labels.size(); i++) {
+     srn.train_labels[i].erase();
+     srn.train_labels[i].delete_v();
    }
-   srn->train_labels.erase();
-   srn->train_labels.delete_v();
+   srn.train_labels.erase();
+   srn.train_labels.delete_v();
  }

- void clear_seq(vw&all)
+ void clear_seq(vw&all, searn& srn)
  {
-   searn_struct *srn = (searn_struct*)all.searnstr;
-   if (srn->ec_seq.size() > 0)
-     for (example** ecc=srn->ec_seq.begin; ecc!=srn->ec_seq.end; ecc++) {
+   if (srn.ec_seq.size() > 0)
+     for (example** ecc=srn.ec_seq.begin; ecc!=srn.ec_seq.end; ecc++) {
        VW::finish_example(all, *ecc);
      }
-   srn->ec_seq.erase();
+   srn.ec_seq.erase();
  }

  float safediv(float a,float b) { if (b == 0.f) return 0.f; else return (a/b); }

- void print_update(vw& all)
+
+ void print_update(vw& all, searn& srn)
  {
-   searn_struct *srn = (searn_struct*)all.searnstr;
    if (!Searn::should_print_update(all))
      return;

    char true_label[21];
    char pred_label[21];
-   Searn::to_short_string(srn->truth_string ? srn->truth_string->str() : "", 20, true_label);
-   Searn::to_short_string(srn->pred_string ? srn->pred_string->str() : "", 20, pred_label);
+   Searn::to_short_string(srn.truth_string ? srn.truth_string->str() : "", 20, true_label);
+   Searn::to_short_string(srn.pred_string ? srn.pred_string->str() : "", 20, pred_label);

    fprintf(stderr, "%-10.6f %-10.6f %8ld %15f [%s] [%s] %8lu %5d %5d %15lu %15lu\n",
            safediv((float)all.sd->sum_loss, (float)all.sd->weighted_examples),
@@ -1746,11 +1760,11 @@ namespace ImperativeSearn {
            all.sd->weighted_examples,
            true_label,
            pred_label,
-           (long unsigned int)srn->num_features,
-           (int)srn->read_example_last_pass,
-           (int)srn->current_policy,
-           (long unsigned int)srn->total_predictions_made,
-           (long unsigned int)srn->total_examples_generated);
+           (long unsigned int)srn.num_features,
+           (int)srn.read_example_last_pass,
+           (int)srn.current_policy,
+           (long unsigned int)srn.total_predictions_made,
+           (long unsigned int)srn.total_examples_generated);

    all.sd->sum_loss_since_last_dump = 0.0;
    all.sd->old_weighted_examples = all.sd->weighted_examples;
@@ -1758,54 +1772,53 @@ namespace ImperativeSearn {
  }

- void do_actual_learning(vw&all)
+ void do_actual_learning(vw&all, searn& srn)
  {
-   searn_struct *srn = (searn_struct*)all.searnstr;
-   if (srn->ec_seq.size() == 0)
+   if (srn.ec_seq.size() == 0)
      return;  // nothing to do :)

-   if (Searn::should_print_update(all) || Searn::will_global_print_label(all)) {
-     srn->truth_string = new stringstream();
-     srn->pred_string = new stringstream();
+   if (Searn::should_print_update(all)) {
+     srn.truth_string = new stringstream();
+     srn.pred_string = new stringstream();
    }

-   train_single_example(all, srn->ec_seq.begin, srn->ec_seq.size());
-   if (srn->test_loss >= 0.f) {
-     all.sd->sum_loss += srn->test_loss;
-     all.sd->sum_loss_since_last_dump += srn->test_loss;
+   train_single_example(all, srn, srn.ec_seq.begin, srn.ec_seq.size());
+   if (srn.test_loss >= 0.f) {
+     all.sd->sum_loss += srn.test_loss;
+     all.sd->sum_loss_since_last_dump += srn.test_loss;
      all.sd->example_number++;
-     all.sd->total_features += srn->num_features;
+     all.sd->total_features += srn.num_features;
      all.sd->weighted_examples += 1.f;
    }

-   print_update(all);
+   print_update(all, srn);

-   if (srn->truth_string != NULL) {
-     delete srn->truth_string;
-     srn->truth_string = NULL;
+   if (srn.truth_string != NULL) {
+     delete srn.truth_string;
+     srn.truth_string = NULL;
    }
-   if (srn->pred_string != NULL) {
-     delete srn->pred_string;
-     srn->pred_string = NULL;
+   if (srn.pred_string != NULL) {
+     delete srn.pred_string;
+     srn.pred_string = NULL;
    }
  }

- void searn_learn(void*in, example*ec) {
-   vw all = *(vw*)in;
-   searn_struct *srn = (searn_struct*)all.searnstr;
+ void searn_learn(void*in, void*d, example*ec) {
+   vw* all = (vw*)in;
+   searn *srn = (searn*)d;

-   if (srn->ec_seq.size() >= all.p->ring_size - 2) { // give some wiggle room
+   if (srn->ec_seq.size() >= all->p->ring_size - 2) { // give some wiggle room
      std::cerr << "warning: length of sequence at " << ec->example_counter << " exceeds ring size; breaking apart" << std::endl;
-     do_actual_learning(all);
-     clear_seq(all);
+     do_actual_learning(*all, *srn);
+     clear_seq(*all, *srn);
    }

    bool is_real_example = true;
    if (OAA::example_is_newline(ec)) {
-     do_actual_learning(all);
-     clear_seq(all);
-     VW::finish_example(all, ec);
+     do_actual_learning(*all, *srn);
+     clear_seq(*all, *srn);
+     VW::finish_example(*all, ec);
      is_real_example = false;
    } else {
      srn->ec_seq.push_back(ec);
@@ -1819,7 +1832,7 @@ namespace ImperativeSearn {
        srn->passes_since_new_policy++;
        if (srn->passes_since_new_policy >= srn->passes_per_policy) {
          srn->passes_since_new_policy = 0;
-         if(all.training)
+         if(all->training)
            srn->current_policy++;
          if (srn->current_policy > srn->total_number_of_policies) {
            std::cerr << "internal error (bug): too many policies; not advancing" << std::endl;
@@ -1828,15 +1841,15 @@ namespace ImperativeSearn {
          //reset searn_trained_nb_policies in options_from_file so it is saved to regressor file later
          std::stringstream ss;
          ss << srn->current_policy;
-         VW::cmd_string_replace_value(all.options_from_file,"--searn_trained_nb_policies", ss.str());
+         VW::cmd_string_replace_value(all->options_from_file,"--searn_trained_nb_policies", ss.str());
        }
      }
    }
  }

- void searn_drive(void*in) {
-   vw all = *(vw*)in;
-   searn_struct *srn = (searn_struct*)all.searnstr;
+ void searn_drive(void*in, void *d) {
+   vw* all = (vw*)in;
+   searn *srn = (searn*)d;

    const char * header_fmt = "%-10s %-10s %8s %15s %24s %22s %8s %5s %5s %15s %15s\n";
@@ -1847,63 +1860,61 @@ namespace ImperativeSearn {
    example* ec = NULL;
    srn->read_example_this_loop = 0;
    while (true) {
-     if ((ec = get_example(all.p)) != NULL) { // semiblocking operation
-       searn_learn(in, ec);
-     } else if (parser_done(all.p)) {
-       do_actual_learning(all);
+     if ((ec = get_example(all->p)) != NULL) { // semiblocking operation
+       searn_learn(in,d, ec);
+     } else if (parser_done(all->p)) {
+       do_actual_learning(*all, *srn);
        break;
      }
    }

-   if( all.training ) {
+   if( all->training ) {
      std::stringstream ss1;
      std::stringstream ss2;
      ss1 << (srn->current_policy+1);
      //use cmd_string_replace_value in case we already loaded a predictor which had a value stored for --searnimp_trained_nb_policies
-     VW::cmd_string_replace_value(all.options_from_file,"--searnimp_trained_nb_policies", ss1.str());
+     VW::cmd_string_replace_value(all->options_from_file,"--searnimp_trained_nb_policies", ss1.str());
      ss2 << srn->total_number_of_policies;
      //use cmd_string_replace_value in case we already loaded a predictor which had a value stored for --searnimp_total_nb_policies
-     VW::cmd_string_replace_value(all.options_from_file,"--searnimp_total_nb_policies", ss2.str());
+     VW::cmd_string_replace_value(all->options_from_file,"--searnimp_total_nb_policies", ss2.str());
    }
  }

- void searn_initialize(vw& all)
+ void searn_initialize(vw& all, searn& srn)
  {
-   searn_struct *srn = (searn_struct*)all.searnstr;
-
-   srn->predict = searn_predict;
-   srn->declare_loss = searn_declare_loss;
-   srn->snapshot = searn_snapshot;
-
-   srn->beta = 0.5;
-   srn->allow_current_policy = false;
-   srn->rollout_all_actions = true;
-   srn->num_features = 0;
-   srn->current_policy = 1;
-   srn->state = 0;
-   srn->do_snapshot = true;
-
-   srn->passes_per_policy = 1;     //this should be set to the same value as --passes for dagger
-
-   srn->read_example_this_loop = 0;
-   srn->read_example_last_id = 0;
-   srn->passes_since_new_policy = 0;
-   srn->read_example_last_pass = 0;
-   srn->total_examples_generated = 0;
-   srn->total_predictions_made = 0;
+   srn.predict = searn_predict;
+   srn.declare_loss = searn_declare_loss;
+   srn.snapshot = searn_snapshot;
+
+   srn.beta = 0.5;
+   srn.allow_current_policy = false;
+   srn.rollout_all_actions = true;
+   srn.num_features = 0;
+   srn.current_policy = 1;
+   srn.state = 0;
+   srn.do_snapshot = true;
+
+   srn.passes_per_policy = 1;     //this should be set to the same value as --passes for dagger
+
+   srn.read_example_this_loop = 0;
+   srn.read_example_last_id = 0;
+   srn.passes_since_new_policy = 0;
+   srn.read_example_last_pass = 0;
+   srn.total_examples_generated = 0;
+   srn.total_predictions_made = 0;
  }

- void searn_finish(void*in)
+ void searn_finish(void*in, void* d)
  {
    vw*all = (vw*)in;
-   searn_struct *srn = (searn_struct*)all->searnstr;
+   searn *srn = (searn*)d;

    //cerr << "searn_finish" << endl;

-   clear_seq(*all);
+   clear_seq(*all,*srn);
    srn->ec_seq.delete_v();

-   clear_snapshot(*all);
+   clear_snapshot(*all, *srn);
    srn->snapshot_data.delete_v();

    for (size_t i=0; i<srn->train_labels.size(); i++) {
@@ -1914,10 +1925,10 @@ namespace ImperativeSearn {
    srn->train_action.erase(); srn->train_action.delete_v();
    srn->learn_losses.erase(); srn->learn_losses.delete_v();

-   if (srn->task.finish != NULL)
-     srn->task.finish(*all);
-   if (srn->task.finish != NULL)
-     srn->base_finish(all);
+   if (srn->task->finish != NULL)
+     srn->task->finish(*all);
+   if (srn->task->finish != NULL)
+     srn->base.finish(all, srn->base.data);
  }

  void ensure_param(float &v, float lo, float hi, float def, const char* string) {
@@ -1952,9 +1963,9 @@ namespace ImperativeSearn {

  void parse_flags(vw&all, std::vector<std::string>&opts, po::variables_map& vm, po::variables_map& vm_file)
  {
-   searn_struct *srn = (searn_struct*)all.searnstr;
+   searn* srn = (searn*)calloc(1,sizeof(searn));

-   searn_initialize(all);
+   searn_initialize(all, *srn);

    po::options_description desc("Imperative Searn options");
    desc.add_options()
@@ -2038,25 +2049,24 @@ namespace ImperativeSearn {
    srn->increment = ((uint32_t)all.length() / all.base_learner_nb_w) * all.stride;

    if (task_string.compare("sequence") == 0) {
-     searn_task mytask = { SequenceTask_Easy::initialize,
-                           SequenceTask_Easy::finish,
-                           SequenceTask_Easy::structured_predict_v1
-                         };
+     searn_task* mytask = (searn_task*)calloc(1, sizeof(searn_task));
+     mytask->initialize = SequenceTask_Easy::initialize;
+     mytask->finish = SequenceTask_Easy::finish;
+     mytask->structured_predict = SequenceTask_Easy::structured_predict_v1;
+     srn->task = mytask;
    } else {
      cerr << "fail: unknown task for --searn_task: " << task_string << endl;
      exit(-1);
    }

-   srn->task.initialize(all, srn->A);
+   srn->task->initialize(all, srn->A);

-   all.driver = searn_drive;
-   srn->base_learner = all.learn;
-   all.learn = searn_learn;
-   srn->base_finish = all.finish;
-   all.finish = searn_finish;
+   srn->base = all.l;
+
+   learner l = {srn, searn_drive, searn_learn, searn_finish, all.l.save_load};
+   all.l = l;
  }
- }

/* time ./vw --searn 45 --searn_task sequence -k -c -d ../test/train-sets/wsj_small.dat2.gz --passes 5 --searn_passes_per_policy 4

diff --git a/vowpalwabbit/searn.h b/vowpalwabbit/searn.h
index b7c02528..71cff9ab 100644
--- a/vowpalwabbit/searn.h
+++ b/vowpalwabbit/searn.h
@@ -195,16 +195,9 @@ namespace Searn
  };
  void parse_flags(vw&all, std::vector<std::string>&, po::variables_map& vm, po::variables_map& vm_file);
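  // Note the pattern in this header and in the reductions above: free-standing
  // drive/learn/finish entry points disappear from the public interface, and
  // each reduction is reachable only through the learner value it stores in
  // all.l (see the brace-initializers at the end of searn.cc, sender.cc and
  // wap.cc in this diff).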
- void drive(void*);
}

namespace ImperativeSearn {
- struct searn_task {
-   void (*initialize)(vw&,uint32_t&);
-   void (*finish)(vw&);
-   void (*structured_predict)(vw&,example**,size_t,stringstream*,stringstream*);
- };
-
  struct snapshot_item {
    size_t index;
    size_t tag;
@@ -212,15 +205,17 @@ namespace ImperativeSearn {
    size_t data_size;  // sizeof(data_ptr)
    size_t pred_step;  // srn->t when snapshot is made
  };
+
+ struct searn_task;

- struct searn_struct {
+ struct searn {
    // functions that you will call
    uint32_t (*predict)(vw&,example**,size_t,v_array<uint32_t>*,v_array<uint32_t>*);
    void (*declare_loss)(vw&,size_t,float);   // <0 means it was a test example!
    void (*snapshot)(vw&,size_t,size_t,void*,size_t);

    // structure that you must set
-   searn_task task;
+   searn_task* task;

    // data that you should not look at.  ever.
    uint32_t A;   // total number of actions, [1..A]; 0 means ldf
@@ -265,8 +260,13 @@ namespace ImperativeSearn {

    v_array<example*> ec_seq;

-   void (*base_finish)(void*);
-   void (*base_learner)(void*,example*);
+   learner base;
+ };
+
+ struct searn_task {
+   void (*initialize)(vw&,uint32_t&);
+   void (*finish)(vw&);
+   void (*structured_predict)(vw&, searn&, example**,size_t,stringstream*,stringstream*);
  };

  void parse_flags(vw&, std::vector<std::string>&, po::variables_map&, po::variables_map&);

diff --git a/vowpalwabbit/searn_sequencetask.cc b/vowpalwabbit/searn_sequencetask.cc
index f2487e47..e1001328 100644
--- a/vowpalwabbit/searn_sequencetask.cc
+++ b/vowpalwabbit/searn_sequencetask.cc
@@ -351,8 +351,7 @@ namespace SequenceTask_Easy {
        out->push_back( lab->costs[l].weight_index );
    }

- void structured_predict_v1(vw& vw, example**ec, size_t len, stringstream*output_ss, stringstream*truth_ss) {
-   searn_struct srn = *(searn_struct*)vw.searnstr;
+ void structured_predict_v1(vw& vw, searn& srn, example**ec, size_t len, stringstream*output_ss, stringstream*truth_ss) {
    float total_loss = 0;
    size_t history_length = max(hinfo.features, hinfo.length);
    bool is_train = false;

diff --git a/vowpalwabbit/searn_sequencetask.h b/vowpalwabbit/searn_sequencetask.h
index 3c174ce9..89f5c30d 100644
--- a/vowpalwabbit/searn_sequencetask.h
+++ b/vowpalwabbit/searn_sequencetask.h
@@ -8,6 +8,7 @@ license as described in the file LICENSE.
#include "oaa.h" #include "parse_primitives.h" +#include "searn.h" namespace SequenceTask { bool final(state); @@ -30,7 +31,7 @@ namespace SequenceTask { namespace SequenceTask_Easy { void initialize(vw&, uint32_t&); void finish(vw&); - void structured_predict_v1(vw&,example**,size_t,stringstream*,stringstream*); + void structured_predict_v1(vw&, ImperativeSearn::searn&, example**,size_t,stringstream*,stringstream*); } diff --git a/vowpalwabbit/sender.cc b/vowpalwabbit/sender.cc index ea7243d3..95d20ff9 100644 --- a/vowpalwabbit/sender.cc +++ b/vowpalwabbit/sender.cc @@ -24,25 +24,17 @@ using namespace std; namespace SENDER { -//nonreentrant -io_buf* buf; + struct sender { + io_buf* buf; + + int sd; + }; -int sd = -1; - -void open_sockets(string host) -{ - sd = open_socket(host.c_str()); - buf = new io_buf(); - buf->files.push_back(sd); -} - -void parse_send_args(po::variables_map& vm, vector<string> pairs) + void open_sockets(sender& s, string host) { - if (vm.count("sendto")) - { - vector<string> hosts = vm["sendto"].as< vector<string> >(); - open_sockets(hosts[0]); - } + s.sd = open_socket(host.c_str()); + s.buf = new io_buf(); + s.buf->files.push_back(s.sd); } void send_features(io_buf *b, example* ec) @@ -58,11 +50,12 @@ void send_features(io_buf *b, example* ec) b->flush(); } -void save_load(void* in, io_buf& model_file, bool read, bool text) {} + void save_load(void* in, void* d, io_buf& model_file, bool read, bool text) {} -void drive_send(void* in) + void drive_send(void* in, void* d) { vw* all = (vw*)in; + sender* s = (sender*)d; example* ec = NULL; v_array<char> null_tag; null_tag.erase(); @@ -77,7 +70,7 @@ void drive_send(void* in) if (received_index + all->p->ring_size == sent_index || (parser_finished & (received_index != sent_index))) { float res, weight; - get_prediction(sd,res,weight); + get_prediction(s->sd,res,weight); ec=delay_ring[received_index++ % all->p->ring_size]; label_data* ld = (label_data*)ec->ld; @@ -92,9 +85,9 @@ void drive_send(void* in) { label_data* ld = (label_data*)ec->ld; all->set_minmax(all->sd, ld->label); - simple_label.cache_label(ld, *buf);//send label information. - cache_tag(*buf, ec->tag); - send_features(buf,ec); + simple_label.cache_label(ld, *s->buf);//send label information. + cache_tag(*s->buf, ec->tag); + send_features(s->buf,ec); delay_ring[sent_index++ % all->p->ring_size] = ec; } else if (parser_done(all->p)) @@ -102,9 +95,9 @@ void drive_send(void* in) parser_finished = true; if (received_index == sent_index) { - shutdown(buf->files[0],SHUT_WR); - buf->files.delete_v(); - buf->space.delete_v(); + shutdown(s->buf->files[0],SHUT_WR); + s->buf->files.delete_v(); + s->buf->space.delete_v(); free(delay_ring); return; } @@ -114,5 +107,21 @@ void drive_send(void* in) } return; } + void learn(void*in, void* d, example*ec) { cout << "sender can't be used under reduction" << endl; } + void finish(void*in, void* d) { cout << "sender can't be used under reduction" << endl; } + + void parse_send_args(vw& all, po::variables_map& vm, vector<string> pairs) +{ + sender* s = (sender*)calloc(1,sizeof(sender)); + s->sd = -1; + if (vm.count("sendto")) + { + vector<string> hosts = vm["sendto"].as< vector<string> >(); + open_sockets(*s, hosts[0]); + } + + learner ret = {s,drive_send,learn,finish,save_load}; + all.l = ret; +} } diff --git a/vowpalwabbit/sender.h b/vowpalwabbit/sender.h index 57b727cd..a43cfed9 100644 --- a/vowpalwabbit/sender.h +++ b/vowpalwabbit/sender.h @@ -4,7 +4,5 @@ individual contributors. All rights reserved. 
 Released under a BSD license as described in the file LICENSE.
 */
namespace SENDER{
-void parse_send_args(po::variables_map& vm, std::vector<std::string> pairs);
-void drive_send(void*);
- void save_load(void* in, io_buf& model_file, bool read, bool text);
+ void parse_send_args(vw& all, po::variables_map& vm, vector<string> pairs);
}

diff --git a/vowpalwabbit/vw.cc b/vowpalwabbit/vw.cc
index 0753816c..372bf824 100644
--- a/vowpalwabbit/vw.cc
+++ b/vowpalwabbit/vw.cc
@@ -44,7 +44,7 @@ int main(int argc, char *argv[])

  start_parser(all);

- all.driver(&all);
+ all.l.driver(&all, &all.l.data);

  end_parser(all);

diff --git a/vowpalwabbit/wap.cc b/vowpalwabbit/wap.cc
index 90b8103b..e45e65e2 100644
--- a/vowpalwabbit/wap.cc
+++ b/vowpalwabbit/wap.cc
@@ -17,9 +17,11 @@ license as described in the file LICENSE.
using namespace std;

namespace WAP {
- //nonreentrant
- uint32_t increment=0;
-
+ struct wap{
+   uint32_t increment;
+   learner base;
+ };
+
  void mirror_features(vw& all, example* ec, uint32_t offset1, uint32_t offset2)
  {
    for (unsigned char* i = ec->indices.begin; i != ec->indices.end; i++)
@@ -128,7 +130,7 @@ namespace WAP {

  void (*base_learner)(void*, example*) = NULL;

- void train(vw& all, example* ec)
+ void train(vw& all, wap& w, example* ec)
  {
    CSOAA::label* ld = (CSOAA::label*)ec->ld;
@@ -182,10 +184,10 @@ namespace WAP {
          uint32_t myi = (uint32_t)vs[i].ci.weight_index;
          uint32_t myj = (uint32_t)vs[j].ci.weight_index;

-         mirror_features(all, ec,(myi-1)*increment, (myj-1)*increment);
+         mirror_features(all, ec,(myi-1)*w.increment, (myj-1)*w.increment);
          base_learner(&all, ec);
-         unmirror_features(all, ec,(myi-1)*increment, (myj-1)*increment);
+         unmirror_features(all, ec,(myi-1)*w.increment, (myj-1)*w.increment);
        }
    }
@@ -193,7 +195,7 @@ namespace WAP {
    ec->ld = ld;
  }

- size_t test(vw& all, example* ec)
+ size_t test(vw& all, wap& w, example* ec)
  {
    size_t prediction = 1;
    float score = -FLT_MAX;
@@ -208,12 +210,12 @@ namespace WAP {
      simple_temp.label = FLT_MAX;
      uint32_t myi = (uint32_t)cost_label->costs[i].weight_index;
      if (myi!= 1)
-       update_example_indicies(all.audit, ec, increment*(myi-1));
+       update_example_indicies(all.audit, ec, w.increment*(myi-1));
      ec->partial_prediction = 0.;
      ec->ld = &simple_temp;
      base_learner(&all, ec);
      if (myi != 1)
-       update_example_indicies(all.audit, ec, -increment*(myi-1));
+       update_example_indicies(all.audit, ec, -w.increment*(myi-1));
      if (ec->partial_prediction > score)
        {
          score = ec->partial_prediction;
@@ -224,21 +226,29 @@ namespace WAP {
    return prediction;
  }

- void learn(void* a, example* ec)
+ void learn(void* a, void* d, example* ec)
  {
    vw* all = (vw*)a;
    CSOAA::label* cost_label = (CSOAA::label*)ec->ld;
+   wap* w = (wap*)d;

-   size_t prediction = test(*all, ec);
+   size_t prediction = test(*all, *w, ec);

    ec->ld = cost_label;

    if (cost_label->costs.size() > 0)
-     train(*all, ec);
+     train(*all, *w, ec);
    *(OAA::prediction_t*)&(ec->final_prediction) = prediction;
  }
-
- void drive_wap(void* in)
+
+ void finish(void* a, void* d)
+ {
+   wap* w=(wap*)d;
+   w->base.finish(a,w->base.data);
+   free(w);
+ }
+
+ void drive(void* in, void* d)
  {
    vw* all = (vw*)in;
    example* ec = NULL;
@@ -246,22 +256,20 @@ namespace WAP {
    {
      if ((ec = get_example(all->p)) != NULL)//semiblocking operation.
        {
-         learn(all,ec);
+         learn(all, d, ec);
          CSOAA::output_example(*all, ec);
          VW::finish_example(*all, ec);
        }
      else if (parser_done(all->p))
-       {
-         all->finish(all);
-         return;
-       }
+       return;
      else
        ;
    }
  }
-
+
  void parse_flags(vw& all, std::vector<std::string>&, po::variables_map& vm, po::variables_map& vm_file)
  {
+   wap* w=(wap*)calloc(1,sizeof(wap));
    uint32_t nb_actions = 0;
    if( vm_file.count("wap") ) { //if loaded options from regressor
      nb_actions = (uint32_t)vm_file["wap"].as<size_t>();
@@ -281,11 +289,9 @@ namespace WAP {
    all.sd->k = (uint32_t)nb_actions;
    all.base_learner_nb_w *= nb_actions;

-   increment = (uint32_t)((all.length()/ all.base_learner_nb_w) * all.stride);
+   w->increment = (uint32_t)((all.length()/ all.base_learner_nb_w) * all.stride);

-   all.driver = drive_wap;
-   base_learner = all.learn;
-   all.base_learn = all.learn;
-   all.learn = learn;
+   learner l = {w, drive, learn, finish, all.l.save_load};
+   all.l = l;
  }
}
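The recurring change across searn.cc, sender.cc and wap.cc above is one refactoring move applied reduction by reduction: file-scope, nonreentrant globals become a heap-allocated per-reduction struct, and the reduction registers itself by writing a learner value into all.l, whose entry points all take the reduction's data pointer as an extra argument. The sketch below shows the shape of that pattern in isolation; it is a minimal reconstruction for illustration, not VW's actual definitions. The learner field names and the toy "counter" reduction are assumptions; only the initializer order {data, driver, learn, finish, save_load} is taken from the brace-initializers visible in this diff.

  #include <cstdio>
  #include <cstdlib>

  struct io_buf;                // stand-in for VW's model-file buffer (assumed)

  // Assumed field layout, inferred from initializers like
  //   learner l = {w, drive, learn, finish, all.l.save_load};
  struct learner {
    void* data;                                           // reduction-private state (was file-scope globals)
    void (*driver)(void* all, void* data);                // main example loop
    void (*learn)(void* all, void* data, void* example);  // per-example update
    void (*finish)(void* all, void* data);                // teardown
    void (*save_load)(void* all, void* data, io_buf& model_file, bool read, bool text);
  };

  struct vw { learner l; };     // stand-in for the real vw struct (assumed)

  // A toy reduction: its former globals live in a private struct.
  struct counter { size_t n; };

  void counter_learn(void* all, void* d, void* ec) { ((counter*)d)->n++; }
  void counter_driver(void* all, void* d) {
    vw* v = (vw*)all;
    for (int i = 0; i < 3; i++)            // pretend the parser yielded 3 examples
      v->l.learn(all, v->l.data, NULL);
  }
  void counter_finish(void* all, void* d) {
    printf("saw %zu examples\n", ((counter*)d)->n);
    free(d);                               // the reduction owns its own state
  }
  void counter_save_load(void* all, void* d, io_buf& mf, bool read, bool text) {}

  // What each reduction's parse_flags now does: allocate state, install callbacks.
  void counter_parse_flags(vw& all) {
    counter* c = (counter*)calloc(1, sizeof(counter));
    learner l = { c, counter_driver, counter_learn, counter_finish, counter_save_load };
    all.l = l;
  }

  // ...and what vw.cc's main now does instead of calling a global driver.
  int main() {
    vw all;
    counter_parse_flags(all);
    all.l.driver(&all, all.l.data);
    all.l.finish(&all, all.l.data);
    return 0;
  }

Two things are worth noting about the commit as written. First, chaining works by value: a reduction such as WAP or searn copies the previous all.l into its own base field before overwriting all.l, so calls bottom out in the underlying learner. Second, vw.cc passes &all.l.data (a pointer to the data pointer) while the drivers in this diff cast their second argument straight to the reduction struct; the sketch above passes data directly, which is what the drivers' casts appear to expect.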