diff options
Diffstat (limited to 'vowpalwabbit/stagewise_poly.cc')
-rw-r--r-- | vowpalwabbit/stagewise_poly.cc | 64 |
1 files changed, 31 insertions, 33 deletions
diff --git a/vowpalwabbit/stagewise_poly.cc b/vowpalwabbit/stagewise_poly.cc index c6352d26..b2e7e150 100644 --- a/vowpalwabbit/stagewise_poly.cc +++ b/vowpalwabbit/stagewise_poly.cc @@ -129,7 +129,7 @@ namespace StagewisePoly void depthsbits_create(stagewise_poly &poly) { - poly.depthsbits = (uint8_t *) calloc_or_die(1, depthsbits_sizeof(poly)); + poly.depthsbits = calloc_or_die<uint8_t>(2 * poly.all->length()); for (uint32_t i = 0; i < poly.all->length() * 2; i += 2) { poly.depthsbits[i] = default_depth; poly.depthsbits[i+1] = indicator_bit; @@ -247,7 +247,7 @@ namespace StagewisePoly cout << ", new size " << poly.sd_len << endl; #endif //DEBUG free(poly.sd); //okay for null. - poly.sd = (sort_data *) calloc_or_die(poly.sd_len, sizeof(sort_data)); + poly.sd = calloc_or_die<sort_data>(poly.sd_len); } assert(len <= poly.sd_len); } @@ -502,7 +502,7 @@ namespace StagewisePoly } } - void predict(stagewise_poly &poly, learner &base, example &ec) + void predict(stagewise_poly &poly, base_learner &base, example &ec) { poly.original_ec = &ec; synthetic_create(poly, ec, false); @@ -511,7 +511,7 @@ namespace StagewisePoly ec.updated_prediction = poly.synth_ec.updated_prediction; } - void learn(stagewise_poly &poly, learner &base, example &ec) + void learn(stagewise_poly &poly, base_learner &base, example &ec) { bool training = poly.all->training && ec.l.simple.label != FLT_MAX; poly.original_ec = &ec; @@ -657,13 +657,13 @@ namespace StagewisePoly } - learner *setup(vw &all, po::variables_map &vm) + base_learner *setup(vw &all, po::variables_map &vm) { - stagewise_poly *poly = (stagewise_poly *) calloc_or_die(1, sizeof(stagewise_poly)); - poly->all = &all; + stagewise_poly& poly = calloc_or_die<stagewise_poly>(); + poly.all = &all; - depthsbits_create(*poly); - sort_data_create(*poly); + depthsbits_create(poly); + sort_data_create(poly); po::options_description sp_opt("Stagewise poly options"); sp_opt.add_options() @@ -676,36 +676,34 @@ namespace StagewisePoly ; vm = add_options(all, sp_opt); - poly->sched_exponent = vm.count("sched_exponent") ? vm["sched_exponent"].as<float>() : 1.f; - poly->batch_sz = vm.count("batch_sz") ? vm["batch_sz"].as<uint32_t>() : 1000; - poly->batch_sz_double = vm.count("batch_sz_no_doubling") ? false : true; + poly.sched_exponent = vm.count("sched_exponent") ? vm["sched_exponent"].as<float>() : 1.f; + poly.batch_sz = vm.count("batch_sz") ? vm["batch_sz"].as<uint32_t>() : 1000; + poly.batch_sz_double = vm.count("batch_sz_no_doubling") ? false : true; #ifdef MAGIC_ARGUMENT - poly->magic_argument = vm.count("magic_argument") ? vm["magic_argument"].as<float>() : 0.; + poly.magic_argument = vm.count("magic_argument") ? vm["magic_argument"].as<float>() : 0.; #endif //MAGIC_ARGUMENT - poly->sum_sparsity = 0; - poly->sum_input_sparsity = 0; - poly->num_examples = 0; - poly->sum_sparsity_sync = 0; - poly->sum_input_sparsity_sync = 0; - poly->num_examples_sync = 0; - poly->last_example_counter = -1; - poly->numpasses = 1; - poly->update_support = false; - poly->original_ec = NULL; - poly->next_batch_sz = poly->batch_sz; + poly.sum_sparsity = 0; + poly.sum_input_sparsity = 0; + poly.num_examples = 0; + poly.sum_sparsity_sync = 0; + poly.sum_input_sparsity_sync = 0; + poly.num_examples_sync = 0; + poly.last_example_counter = -1; + poly.numpasses = 1; + poly.update_support = false; + poly.original_ec = NULL; + poly.next_batch_sz = poly.batch_sz; //following is so that saved models know to load us. - all.file_options.append(" --stage_poly"); + *all.file_options << " --stage_poly"; - learner *l = new learner(poly, all.l); - l->set_learn<stagewise_poly, learn>(); - l->set_predict<stagewise_poly, predict>(); - l->set_finish<stagewise_poly, finish>(); - l->set_save_load<stagewise_poly, save_load>(); - l->set_finish_example<stagewise_poly,finish_example>(); - l->set_end_pass<stagewise_poly, end_pass>(); + learner<stagewise_poly>& l = init_learner(&poly, all.l, learn, predict); + l.set_finish(finish); + l.set_save_load(save_load); + l.set_finish_example(finish_example); + l.set_end_pass(end_pass); - return l; + return make_base(l); } } |