diff options
Diffstat (limited to 'vowpalwabbit/nn.cc')
-rw-r--r-- | vowpalwabbit/nn.cc | 23 |
1 files changed, 13 insertions, 10 deletions
diff --git a/vowpalwabbit/nn.cc b/vowpalwabbit/nn.cc index bfab6009..bcd42092 100644 --- a/vowpalwabbit/nn.cc +++ b/vowpalwabbit/nn.cc @@ -308,19 +308,20 @@ CONVERSE: // That's right, I'm using goto. So sue me. free (n.output_layer.atomics[nn_output_namespace].begin); } - base_learner* setup(vw& all, po::variables_map& vm) + base_learner* setup(vw& all) { - nn& n = calloc_or_die<nn>(); - n.all = &all; - - po::options_description nn_opts("NN options"); - nn_opts.add_options() + new_options(all, "Neural Network options") + ("nn", po::value<size_t>(), "Use sigmoidal feedforward network with <k> hidden units"); + if(missing_required(all)) return NULL; + new_options(all) ("inpass", "Train or test sigmoidal feedforward network with input passthrough.") ("dropout", "Train or test sigmoidal feedforward network using dropout.") ("meanfield", "Train or test sigmoidal feedforward network using mean field."); + add_options(all); - vm = add_options(all, nn_opts); - + po::variables_map& vm = all.vm; + nn& n = calloc_or_die<nn>(); + n.all = &all; //first parse for number of hidden units n.k = (uint32_t)vm["nn"].as<size_t>(); *all.file_options << " --nn " << n.k; @@ -364,8 +365,10 @@ CONVERSE: // That's right, I'm using goto. So sue me. n.xsubi = vm["random_seed"].as<size_t>(); n.save_xsubi = n.xsubi; - n.increment = all.l->increment;//Indexing of output layer is odd. - learner<nn>& l = init_learner(&n, all.l, predict_or_learn<true>, + + base_learner* base = setup_base(all); + n.increment = base->increment;//Indexing of output layer is odd. + learner<nn>& l = init_learner(&n, base, predict_or_learn<true>, predict_or_learn<false>, n.k+1); l.set_finish(finish); l.set_finish_example(finish_example); |