github.com/moses-smt/vowpal_wabbit.git
author    John Langford <jl@hunch.net>  2014-12-26 21:21:24 +0300
committer John Langford <jl@hunch.net>  2014-12-26 21:21:24 +0300
commit    834bb36afe50c82b7457f4920644f0b37dc1f7f4
tree      cbe89d07be193c10a2d6e46fb51d2d1f4c9a67cb
parent    6d65f8b5667bf8828a4d6d6fa9a22277728c97d6
parent    cb0dc4cfde89a7f9198ccc9b0d9dedad52826f39

fix conflicts
Diffstat (limited to 'vowpalwabbit')
-rw-r--r--  vowpalwabbit/Makefile                1
-rw-r--r--  vowpalwabbit/Makefile.am             2
-rw-r--r--  vowpalwabbit/active.cc               2
-rw-r--r--  vowpalwabbit/autolink.cc             3
-rw-r--r--  vowpalwabbit/bfgs.cc                 6
-rw-r--r--  vowpalwabbit/bs.cc                   2
-rw-r--r--  vowpalwabbit/cb_algs.cc              4
-rw-r--r--  vowpalwabbit/cbify.cc               10
-rw-r--r--  vowpalwabbit/csoaa.cc                5
-rw-r--r--  vowpalwabbit/ect.cc                  4
-rw-r--r--  vowpalwabbit/example.cc             10
-rw-r--r--  vowpalwabbit/ftrl_proximal.cc      229
-rw-r--r--  vowpalwabbit/ftrl_proximal.h        12
-rw-r--r--  vowpalwabbit/gd.cc                  31
-rw-r--r--  vowpalwabbit/gd.h                    1
-rw-r--r--  vowpalwabbit/gd_mf.cc                2
-rw-r--r--  vowpalwabbit/global_data.cc          2
-rw-r--r--  vowpalwabbit/global_data.h           1
-rw-r--r--  vowpalwabbit/kernel_svm.cc          40
-rw-r--r--  vowpalwabbit/lda_core.cc             3
-rw-r--r--  vowpalwabbit/learner.cc             39
-rw-r--r--  vowpalwabbit/learner.h               4
-rw-r--r--  vowpalwabbit/log_multi.cc            1
-rw-r--r--  vowpalwabbit/loss_functions.cc       6
-rw-r--r--  vowpalwabbit/lrq.cc                  4
-rw-r--r--  vowpalwabbit/main.cc                 2
-rw-r--r--  vowpalwabbit/memory.cc              13
-rw-r--r--  vowpalwabbit/memory.h               18
-rw-r--r--  vowpalwabbit/nn.cc                   4
-rw-r--r--  vowpalwabbit/oaa.cc                  2
-rw-r--r--  vowpalwabbit/parse_args.cc          11
-rw-r--r--  vowpalwabbit/parse_example.cc        4
-rw-r--r--  vowpalwabbit/parse_regressor.cc      2
-rw-r--r--  vowpalwabbit/parser.cc               4
-rw-r--r--  vowpalwabbit/print.cc                2
-rw-r--r--  vowpalwabbit/scorer.cc               2
-rw-r--r--  vowpalwabbit/search.cc              10
-rw-r--r--  vowpalwabbit/search_dep_parser.cc    1
-rw-r--r--  vowpalwabbit/sender.cc               4
-rw-r--r--  vowpalwabbit/simple_label.cc         2
-rw-r--r--  vowpalwabbit/stagewise_poly.cc       6
-rw-r--r--  vowpalwabbit/topk.cc                 2
42 files changed, 389 insertions(+), 124 deletions(-)
diff --git a/vowpalwabbit/Makefile b/vowpalwabbit/Makefile
index 7cf003ea..79893902 100644
--- a/vowpalwabbit/Makefile
+++ b/vowpalwabbit/Makefile
@@ -10,6 +10,7 @@ all:
test:
cd ..; $(MAKE) test
+
things: config.h $(BINARIES)
%.1: %
diff --git a/vowpalwabbit/Makefile.am b/vowpalwabbit/Makefile.am
index dd8a9787..676d6f4e 100644
--- a/vowpalwabbit/Makefile.am
+++ b/vowpalwabbit/Makefile.am
@@ -4,7 +4,7 @@ liballreduce_la_SOURCES = allreduce.cc
bin_PROGRAMS = vw active_interactor
-libvw_la_SOURCES = hash.cc memory.cc global_data.cc io_buf.cc parse_regressor.cc parse_primitives.cc unique_sort.cc cache.cc rand48.cc simple_label.cc multiclass.cc oaa.cc ect.cc autolink.cc binary.cc lrq.cc cost_sensitive.cc csoaa.cc cb.cc cb_algs.cc search.cc search_sequencetask.cc search_dep_parser.cc search_hooktask.cc search_multiclasstask.cc search_entityrelationtask.cc parse_example.cc scorer.cc network.cc parse_args.cc accumulate.cc gd.cc learner.cc lda_core.cc gd_mf.cc mf.cc bfgs.cc noop.cc print.cc example.cc parser.cc loss_functions.cc sender.cc nn.cc bs.cc cbify.cc topk.cc stagewise_poly.cc log_multi.cc active.cc kernel_svm.cc best_constant.cc
+libvw_la_SOURCES = hash.cc memory.cc global_data.cc io_buf.cc parse_regressor.cc parse_primitives.cc unique_sort.cc cache.cc rand48.cc simple_label.cc multiclass.cc oaa.cc ect.cc autolink.cc binary.cc lrq.cc cost_sensitive.cc csoaa.cc cb.cc cb_algs.cc search.cc search_sequencetask.cc search_dep_parser.cc search_hooktask.cc search_multiclasstask.cc search_entityrelationtask.cc parse_example.cc scorer.cc network.cc parse_args.cc accumulate.cc gd.cc learner.cc lda_core.cc gd_mf.cc mf.cc bfgs.cc noop.cc print.cc example.cc parser.cc loss_functions.cc sender.cc nn.cc bs.cc cbify.cc topk.cc stagewise_poly.cc log_multi.cc active.cc kernel_svm.cc best_constant.cc ftrl_proximal.cc
libvw_c_wrapper_la_SOURCES = vwdll.cpp
diff --git a/vowpalwabbit/active.cc b/vowpalwabbit/active.cc
index 772c5450..3dc6af02 100644
--- a/vowpalwabbit/active.cc
+++ b/vowpalwabbit/active.cc
@@ -162,7 +162,7 @@ namespace ACTIVE {
if(!vm.count("active"))
return NULL;
- active* data = (active*)calloc_or_die(1, sizeof(active));
+ active* data = calloc_or_die<active>();
data->active_c0 = 8;
data->all=&all;
diff --git a/vowpalwabbit/autolink.cc b/vowpalwabbit/autolink.cc
index d43c7492..3fc67920 100644
--- a/vowpalwabbit/autolink.cc
+++ b/vowpalwabbit/autolink.cc
@@ -50,7 +50,8 @@ namespace ALINK {
if(!vm.count("autolink"))
return NULL;
- autolink* data = (autolink*)calloc_or_die(1,sizeof(autolink));
+ autolink* data = calloc_or_die<autolink>();
+
data->d = (uint32_t)vm["autolink"].as<size_t>();
data->stride_shift = all.reg.stride_shift;
diff --git a/vowpalwabbit/bfgs.cc b/vowpalwabbit/bfgs.cc
index bcfeb315..3ac88939 100644
--- a/vowpalwabbit/bfgs.cc
+++ b/vowpalwabbit/bfgs.cc
@@ -479,7 +479,7 @@ void preconditioner_to_regularizer(vw& all, bfgs& b, float regularization)
weight* weights = all.reg.weight_vector;
if (b.regularizers == NULL)
{
- b.regularizers = (weight *)calloc_or_die(2*length, sizeof(weight));
+ b.regularizers = calloc_or_die<weight>(2*length);
if (b.regularizers == NULL)
{
@@ -909,7 +909,7 @@ void save_load(bfgs& b, io_buf& model_file, bool read, bool text)
initialize_regressor(*all);
if (all->per_feature_regularizer_input != "")
{
- b.regularizers = (weight *)calloc_or_die(2*length, sizeof(weight));
+ b.regularizers = calloc_or_die<weight>(2*length);
if (b.regularizers == NULL)
{
cerr << all->program_name << ": Failed to allocate regularizers array: try decreasing -b <bits>" << endl;
@@ -981,7 +981,7 @@ learner* setup(vw& all, po::variables_map& vm)
if(!vm.count("bfgs") && !vm.count("conjugate_gradient"))
return NULL;
- bfgs* b = (bfgs*)calloc_or_die(1,sizeof(bfgs));
+ bfgs* b = calloc_or_die<bfgs>();
b->all = &all;
b->m = vm["mem"].as<uint32_t>();
b->rel_threshold = vm["termination"].as<float>();
diff --git a/vowpalwabbit/bs.cc b/vowpalwabbit/bs.cc
index 8c2ca4b9..5ae4bf1d 100644
--- a/vowpalwabbit/bs.cc
+++ b/vowpalwabbit/bs.cc
@@ -241,7 +241,7 @@ namespace BS {
learner* setup(vw& all, po::variables_map& vm)
{
- bs* data = (bs*)calloc_or_die(1, sizeof(bs));
+ bs* data = calloc_or_die<bs>();
data->ub = FLT_MAX;
data->lb = -FLT_MAX;
diff --git a/vowpalwabbit/cb_algs.cc b/vowpalwabbit/cb_algs.cc
index daaa2826..45542f04 100644
--- a/vowpalwabbit/cb_algs.cc
+++ b/vowpalwabbit/cb_algs.cc
@@ -428,6 +428,7 @@ namespace CB_ALGS
float loss = 0.;
if (!is_test_label(ld))
{//need to compute exact loss
+ c.known_cost = get_observed_cost(ld);
float chosen_loss = FLT_MAX;
if( know_all_cost_example(ld) ) {
for (cb_class *cl = ld.costs.begin; cl != ld.costs.end; cl ++) {
@@ -507,7 +508,8 @@ namespace CB_ALGS
if (!vm.count("cb"))
return NULL;
- cb* c = (cb*)calloc_or_die(1, sizeof(cb));
+ cb* c = calloc_or_die<cb>();
+
c->all = &all;
c->min_cost = 0.;
c->max_cost = 1.;
diff --git a/vowpalwabbit/cbify.cc b/vowpalwabbit/cbify.cc
index 3ccf339e..8d8f62d0 100644
--- a/vowpalwabbit/cbify.cc
+++ b/vowpalwabbit/cbify.cc
@@ -170,7 +170,7 @@ namespace CBIFY {
vwc.l = &base;
vwc.e = &ec;
- uint32_t action = data.mwt_explorer->Choose_Action(*data.tau_explorer.get(), to_string((u64)ec.example_counter), vwc);
+ uint32_t action = data.mwt_explorer->Choose_Action(*data.tau_explorer.get(), to_string((unsigned long long)ec.example_counter), vwc);
ec.loss = loss(ld.label, action);
if (vwc.recorded && is_learn)
@@ -196,7 +196,7 @@ namespace CBIFY {
vw_context vwc;
vwc.l = &base;
vwc.e = &ec;
- data.mwt_explorer->Choose_Action(*data.greedy_explorer.get(), to_string((u64)ec.example_counter), vwc);
+ data.mwt_explorer->Choose_Action(*data.greedy_explorer.get(), to_string((unsigned long long)ec.example_counter), vwc);
u32 action = data.recorder->Get_Action();
float prob = data.recorder->Get_Prob();
@@ -224,7 +224,7 @@ namespace CBIFY {
vw_context context;
context.l = &base;
context.e = &ec;
- uint32_t action = data.mwt_explorer->Choose_Action(*data.bootstrap_explorer.get(), to_string((u64)ec.example_counter), context);
+ uint32_t action = data.mwt_explorer->Choose_Action(*data.bootstrap_explorer.get(), to_string((unsigned long long)ec.example_counter), context);
assert(action != 0);
if (is_learn)
@@ -317,7 +317,7 @@ namespace CBIFY {
vw_context cp;
cp.data = &data;
cp.e = &ec;
- uint32_t action = data.mwt_explorer->Choose_Action(*data.generic_explorer.get(), to_string((u64)ec.example_counter), cp);
+ uint32_t action = data.mwt_explorer->Choose_Action(*data.generic_explorer.get(), to_string((unsigned long long)ec.example_counter), cp);
if (is_learn)
{
@@ -390,7 +390,7 @@ namespace CBIFY {
if (!vm.count("cbify"))
return NULL;
- cbify* data = (cbify*)calloc_or_die(1, sizeof(cbify));
+ cbify* data = calloc_or_die<cbify>();
data->all = &all;
data->k = (uint32_t)vm["cbify"].as<size_t>();
diff --git a/vowpalwabbit/csoaa.cc b/vowpalwabbit/csoaa.cc
index b4945f0f..f5e5d91f 100644
--- a/vowpalwabbit/csoaa.cc
+++ b/vowpalwabbit/csoaa.cc
@@ -76,8 +76,7 @@ namespace CSOAA {
vm = add_options(all, opts);
if(!vm.count("csoaa"))
return NULL;
-
- csoaa* c=(csoaa*)calloc_or_die(1,sizeof(csoaa));
+ csoaa* c = calloc_or_die<csoaa>();
c->all = &all;
//first parse for number of actions
uint32_t nb_actions = 0;
@@ -668,7 +667,7 @@ namespace LabelDict {
if(!vm.count("csoaa_ldf") && !vm.count("wap_ldf"))
return NULL;
- ldf* ld = (ldf*)calloc_or_die(1, sizeof(ldf));
+ ldf* ld = calloc_or_die<ldf>();
ld->all = &all;
ld->need_to_clear = true;
diff --git a/vowpalwabbit/ect.cc b/vowpalwabbit/ect.cc
index b2eac7cd..5acbd525 100644
--- a/vowpalwabbit/ect.cc
+++ b/vowpalwabbit/ect.cc
@@ -375,8 +375,8 @@ namespace ECT
vm = add_options(all, opts);
if (!vm.count("ect"))
return NULL;
-
- ect* data = (ect*)calloc_or_die(1, sizeof(ect));
+
+ ect* data = calloc_or_die<ect>();
//first parse for number of actions
data->k = (int)vm["ect"].as<size_t>();
diff --git a/vowpalwabbit/example.cc b/vowpalwabbit/example.cc
index 2db36e46..f38bd041 100644
--- a/vowpalwabbit/example.cc
+++ b/vowpalwabbit/example.cc
@@ -40,9 +40,9 @@ float collision_cleanup(feature* feature_map, size_t& len) {
audit_data copy_audit_data(audit_data &src) {
audit_data dst;
- dst.space = (char*)calloc_or_die(strlen(src.space)+1, sizeof(char));
+ dst.space = calloc_or_die<char>(strlen(src.space)+1);
strcpy(dst.space, src.space);
- dst.feature = (char*)calloc_or_die(strlen(src.feature)+1, sizeof(char));
+ dst.feature = calloc_or_die<char>(strlen(src.feature)+1);
strcpy(dst.feature, src.feature);
dst.weight_index = src.weight_index;
dst.x = src.x;
@@ -135,13 +135,13 @@ void return_features(feature* f)
flat_example* flatten_example(vw& all, example *ec)
{
- flat_example* fec = (flat_example*) calloc_or_die(1,sizeof(flat_example));
+ flat_example* fec = calloc_or_die<flat_example>();
fec->l = ec->l;
fec->tag_len = ec->tag.size();
if (fec->tag_len >0)
{
- fec->tag = (char*)calloc_or_die(fec->tag_len+1, sizeof(char));
+ fec->tag = calloc_or_die<char>(fec->tag_len+1);
memcpy(fec->tag,ec->tag.begin, fec->tag_len);
}
@@ -176,7 +176,7 @@ void free_flatten_example(flat_example* fec)
example *alloc_examples(size_t label_size, size_t count=1)
{
- example* ec = (example*)calloc_or_die(count, sizeof(example));
+ example* ec = calloc_or_die<example>(count);
if (ec == NULL) return NULL;
for (size_t i=0; i<count; i++) {
ec[i].in_use = true;
diff --git a/vowpalwabbit/ftrl_proximal.cc b/vowpalwabbit/ftrl_proximal.cc
new file mode 100644
index 00000000..12363994
--- /dev/null
+++ b/vowpalwabbit/ftrl_proximal.cc
@@ -0,0 +1,229 @@
+/*
+ Copyright (c) by respective owners including Yahoo!, Microsoft, and
+ individual contributors. All rights reserved. Released under a BSD (revised)
+ license as described in the file LICENSE.
+ */
+#include <fstream>
+#include <float.h>
+#ifndef _WIN32
+#include <netdb.h>
+#endif
+#include <string.h>
+#include <stdio.h>
+#include <assert.h>
+#include <sys/timeb.h>
+#include "parse_example.h"
+#include "constant.h"
+#include "cache.h"
+#include "simple_label.h"
+#include "vw.h"
+#include "gd.h"
+#include "accumulate.h"
+#include "memory.h"
+#include <exception>
+
+using namespace std;
+using namespace LEARNER;
+
+
+#define W_XT 0 // current parameter w(XT)
+#define W_GT 1 // current gradient g(GT)
+#define W_ZT 2 // accumulated z(t) = z(t-1) + g(t) + sigma*w(t)
+#define W_G2 3 // accumulated gradient square n(t) = n(t-1) + g(t)*g(t)
+
+/********************************************************************/
+/* mem & w definition ***********************************************/
+/********************************************************************/
+// w[0] = current weight
+// w[1] = current first derivative
+// w[2] = accumulated zt
+// w[3] = accumulated g2
+
+namespace FTRL {
+
+  // non-reentrant
+ struct ftrl {
+
+ vw* all;
+ // set by initializer
+ float ftrl_alpha;
+ float ftrl_beta;
+
+ // evaluation file pointer
+ FILE* fo;
+ bool progressive_validation;
+ };
+
+ void update_accumulated_state(weight* w, float ftrl_alpha) {
+ double ng2 = w[W_G2] + w[W_GT]*w[W_GT];
+ double sigma = (sqrt(ng2) - sqrt(w[W_G2]))/ ftrl_alpha;
+ w[W_ZT] += w[W_GT] - sigma * w[W_XT];
+ w[W_G2] = ng2;
+ }
+
+ struct update_data {
+ float update;
+ float ftrl_alpha;
+ float ftrl_beta;
+ float l1_lambda;
+ float l2_lambda;
+ };
+
+ //void update_grad(weight* weights, size_t mask, float loss_grad)
+ void update_grad(update_data& d, float x, float& wref) {
+ float* w = &wref;
+ w[W_GT] = d.update * x;
+ update_accumulated_state(w, d.ftrl_alpha);
+ }
+
+ float ftrl_predict(vw& all, example& ec) {
+ ec.partial_prediction = GD::inline_predict(all, ec);
+ return GD::finalize_prediction(all.sd, ec.partial_prediction);
+ }
+
+ float predict_and_gradient(vw& all, ftrl &b, example& ec) {
+ float fp = ftrl_predict(all, ec);
+ ec.updated_prediction = fp;
+
+ label_data& ld = ec.l.simple;
+ all.set_minmax(all.sd, ld.label);
+
+ struct update_data data;
+
+ data.update = all.loss->first_derivative(all.sd, fp, ld.label) * ld.weight;
+ data.ftrl_alpha = b.ftrl_alpha;
+
+ GD::foreach_feature<update_data,update_grad>(all, ec, data);
+
+ return fp;
+ }
+
+ inline float sign(float w){ if (w < 0.) return -1.; else return 1.;}
+
+ void update_w(update_data& d, float x, float& wref) {
+ float* w = &wref;
+ float flag = sign(w[W_ZT]);
+ float fabs_zt = w[W_ZT] * flag;
+ if (fabs_zt <= d.l1_lambda) {
+ w[W_XT] = 0.;
+ } else {
+ double step = 1/(d.l2_lambda + (d.ftrl_beta + sqrt(w[W_G2]))/d.ftrl_alpha);
+ w[W_XT] = step * flag * (d.l1_lambda - fabs_zt);
+ }
+ }
+
+ void update_weight(vw& all, ftrl &b, example& ec) {
+
+ struct update_data data;
+
+ data.ftrl_alpha = b.ftrl_alpha;
+ data.ftrl_beta = b.ftrl_beta;
+ data.l1_lambda = all.l1_lambda;
+ data.l2_lambda = all.l2_lambda;
+
+ GD::foreach_feature<update_data, update_w>(all, ec, data);
+
+ }
+
+ void evaluate_example(vw& all, ftrl& b , example& ec) {
+ label_data& ld = ec.l.simple;
+ ec.loss = all.loss->getLoss(all.sd, ec.updated_prediction, ld.label) * ld.weight;
+ if (b.progressive_validation) {
+ float v = 1./(1 + exp(-ec.updated_prediction));
+ fprintf(b.fo, "%.6f\t%d\n", v, (int)(ld.label * ld.weight));
+ }
+ }
+
+ //void learn(void* a, void* d, example* ec) {
+ void learn(ftrl& a, learner& base, example& ec) {
+ vw* all = a.all;
+ assert(ec.in_use);
+
+    // predict w*x, compute gradient, update accumulated state
+    ec.pred.scalar = predict_and_gradient(*all, a, ec);
+    // evaluate and record statistics
+    evaluate_example(*all, a, ec);
+ // update weight
+ update_weight(*all, a, ec);
+ }
+
+ void save_load(ftrl& b, io_buf& model_file, bool read, bool text) {
+ vw* all = b.all;
+ if (read) {
+ initialize_regressor(*all);
+ }
+
+ if (model_file.files.size() > 0) {
+ bool resume = all->save_resume;
+ char buff[512];
+ uint32_t text_len = sprintf(buff, ":%d\n", resume);
+ bin_text_read_write_fixed(model_file,(char *)&resume, sizeof (resume), "", read, buff, text_len, text);
+
+ if (resume) {
+ GD::save_load_online_state(*all, model_file, read, text);
+ //save_load_online_state(*all, model_file, read, text);
+ } else {
+ GD::save_load_regressor(*all, model_file, read, text);
+ }
+ }
+
+ }
+
+ // placeholder
+ void predict(ftrl& b, learner& base, example& ec)
+ {
+ vw* all = b.all;
+ //ec.l.simple.prediction = ftrl_predict(*all,ec);
+ ec.pred.scalar = ftrl_predict(*all,ec);
+ }
+
+ learner* setup(vw& all, po::variables_map& vm) {
+
+ ftrl* b = calloc_or_die<ftrl>();
+ b->all = &all;
+ b->ftrl_beta = 0.0;
+ b->ftrl_alpha = 0.1;
+
+ po::options_description ftrl_opts("FTRL options");
+
+ ftrl_opts.add_options()
+ ("ftrl_alpha", po::value<float>(&(b->ftrl_alpha)), "Learning rate for FTRL-proximal optimization")
+ ("ftrl_beta", po::value<float>(&(b->ftrl_beta)), "FTRL beta")
+ ("progressive_validation", po::value<string>()->default_value("ftrl.evl"), "File to record progressive validation for ftrl-proximal");
+
+ vm = add_options(all, ftrl_opts);
+
+ if (vm.count("ftrl_alpha")) {
+ b->ftrl_alpha = vm["ftrl_alpha"].as<float>();
+ }
+
+ if (vm.count("ftrl_beta")) {
+ b->ftrl_beta = vm["ftrl_beta"].as<float>();
+ }
+
+    all.reg.stride_shift = 2; // NOTE: stride 1<<2 = 4 slots per weight, for the FTRL state above
+
+ b->progressive_validation = false;
+ if (vm.count("progressive_validation")) {
+ std::string filename = vm["progressive_validation"].as<string>();
+ b->fo = fopen(filename.c_str(), "w");
+ assert(b->fo != NULL);
+ b->progressive_validation = true;
+ }
+
+ if (!all.quiet) {
+ cerr << "Enabling FTRL-Proximal based optimization" << endl;
+ cerr << "ftrl_alpha = " << b->ftrl_alpha << endl;
+ cerr << "ftrl_beta = " << b->ftrl_beta << endl;
+ }
+
+ learner* l = new learner(b, 1 << all.reg.stride_shift);
+ l->set_learn<ftrl, learn>();
+ l->set_predict<ftrl, predict>();
+ l->set_save_load<ftrl,save_load>();
+
+ return l;
+ }
+
+
+} // end namespace
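
For orientation, the per-coordinate update the new reduction performs can be read in isolation. Below is a minimal sketch that folds update_accumulated_state and update_w from the file above into one self-contained function; the four-slot weight layout follows the W_XT/W_GT/W_ZT/W_G2 definitions, while the standalone signature is illustrative only:

    #include <cmath>

    // w[0]=weight, w[1]=gradient, w[2]=accumulated z, w[3]=accumulated g^2
    void ftrl_proximal_step(float* w, float g, float alpha, float beta,
                            float l1, float l2) {
      double ng2   = w[3] + (double)g * g;
      double sigma = (std::sqrt(ng2) - std::sqrt((double)w[3])) / alpha;
      w[2] += g - (float)(sigma * w[0]);   // z(t) = z(t-1) + g(t) - sigma*w(t)
      w[3]  = (float)ng2;                  // n(t) = n(t-1) + g(t)*g(t)
      float sgn = (w[2] < 0.f) ? -1.f : 1.f;
      if (sgn * w[2] <= l1)
        w[0] = 0.f;                        // L1 term clips small |z| to an exact zero
      else                                 // closed-form proximal solution otherwise
        w[0] = (l1 - sgn * w[2]) * sgn / (l2 + (beta + std::sqrt(w[3])) / alpha);
    }

In the diff itself this is split across two foreach_feature passes: update_grad stores the gradient and advances the accumulators during predict_and_gradient, and update_w applies the proximal solution afterwards.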
diff --git a/vowpalwabbit/ftrl_proximal.h b/vowpalwabbit/ftrl_proximal.h
new file mode 100644
index 00000000..934d91c8
--- /dev/null
+++ b/vowpalwabbit/ftrl_proximal.h
@@ -0,0 +1,12 @@
+/*
+Copyright (c) by respective owners including Yahoo!, Microsoft, and
+individual contributors. All rights reserved. Released under a BSD
+license as described in the file LICENSE.
+ */
+#ifndef FTRL_PROXIMAL_H
+#define FTRL_PROXIMAL_H
+
+namespace FTRL {
+ LEARNER::learner* setup(vw& all, po::variables_map& vm);
+}
+#endif
diff --git a/vowpalwabbit/gd.cc b/vowpalwabbit/gd.cc
index f0736c90..46a71873 100644
--- a/vowpalwabbit/gd.cc
+++ b/vowpalwabbit/gd.cc
@@ -32,7 +32,7 @@ using namespace LEARNER;
namespace GD
{
struct gd{
- double normalized_sum_norm_x;
+ //double normalized_sum_norm_x;
double total_weight;
size_t no_win_counter;
size_t early_stop_thres;
@@ -75,14 +75,14 @@ namespace GD
if (normalized) {
if (sqrt_rate)
{
- float avg_norm = (float) g.total_weight / (float) g.normalized_sum_norm_x;
+ float avg_norm = (float) g.total_weight / (float) g.all->normalized_sum_norm_x;
if (adaptive)
return sqrt(avg_norm);
else
return avg_norm;
}
else
- return powf( (float) g.normalized_sum_norm_x / (float) g.total_weight, g.neg_norm_power);
+ return powf( (float) g.all->normalized_sum_norm_x / (float) g.total_weight, g.neg_norm_power);
}
return 1.f;
}
@@ -451,7 +451,7 @@ template<bool sqrt_rate, bool feature_mask_off, size_t adaptive, size_t normaliz
foreach_feature<norm_data,pred_per_update_feature<sqrt_rate, feature_mask_off, adaptive, normalized, spare> >(all, ec, nd);
if(normalized) {
- g.normalized_sum_norm_x += ld.weight * nd.norm_x;
+ g.all->normalized_sum_norm_x += ld.weight * nd.norm_x;
g.total_weight += ld.weight;
g.update_multiplier = average_update<sqrt_rate, adaptive, normalized>(g, nd.pred_per_update);
@@ -609,9 +609,10 @@ void save_load_regressor(vw& all, io_buf& model_file, bool read, bool text)
while ((!read && i < length) || (read && brw >0));
}
-void save_load_online_state(gd& g, io_buf& model_file, bool read, bool text)
+//void save_load_online_state(gd& g, io_buf& model_file, bool read, bool text)
+void save_load_online_state(vw& all, io_buf& model_file, bool read, bool text)
{
- vw& all = *g.all;
+ //vw& all = *g.all;
char buff[512];
@@ -620,10 +621,10 @@ void save_load_online_state(gd& g, io_buf& model_file, bool read, bool text)
"", read,
buff, text_len, text);
- text_len = sprintf(buff, "norm normalizer %f\n", g.normalized_sum_norm_x);
- bin_text_read_write_fixed(model_file,(char*)&g.normalized_sum_norm_x, sizeof(g.normalized_sum_norm_x),
- "", read,
- buff, text_len, text);
+ text_len = sprintf(buff, "norm normalizer %f\n", all.normalized_sum_norm_x);
+ bin_text_read_write_fixed(model_file,(char*)&all.normalized_sum_norm_x, sizeof(all.normalized_sum_norm_x),
+ "", read,
+ buff, text_len, text);
text_len = sprintf(buff, "t %f\n", all.sd->t);
bin_text_read_write_fixed(model_file,(char*)&all.sd->t, sizeof(all.sd->t),
@@ -780,7 +781,8 @@ void save_load(gd& g, io_buf& model_file, bool read, bool text)
"", read,
buff, text_len, text);
if (resume)
- save_load_online_state(g, model_file, read, text);
+ //save_load_online_state(g, model_file, read, text);
+ save_load_online_state(all, model_file, read, text);
else
save_load_regressor(all, model_file, read, text);
}
@@ -850,9 +852,10 @@ learner* setup(vw& all, po::variables_map& vm)
("normalized", "use per feature normalized updates")
("exact_adaptive_norm", "use current default invariant normalized adaptive update rule");
vm = add_options(all, opts);
- gd* g = (gd*)calloc_or_die(1, sizeof(gd));
+ gd* g = calloc_or_die<gd>();
+
g->all = &all;
- g->normalized_sum_norm_x = 0;
+ g->all->normalized_sum_norm_x = 0;
g->no_win_counter = 0;
g->total_weight = 0.;
g->early_stop_thres = 3;
@@ -861,7 +864,7 @@ learner* setup(vw& all, po::variables_map& vm)
if(all.initial_t > 0)//for the normalized update: if initial_t is bigger than 1 we interpret this as if we had seen (all.initial_t) previous fake datapoints all with norm 1
{
- g->normalized_sum_norm_x = all.initial_t;
+ g->all->normalized_sum_norm_x = all.initial_t;
g->total_weight = all.initial_t;
}
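
The point of hoisting normalized_sum_norm_x out of struct gd and into vw (see the global_data.h hunk below) is that the online save/resume state no longer needs a gd object, so other reductions can reuse it. The resulting call pattern, exactly as the FTRL save_load above uses it:

    if (resume)
      GD::save_load_online_state(*all, model_file, read, text); // now takes vw&, not gd&
    else
      GD::save_load_regressor(*all, model_file, read, text);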
diff --git a/vowpalwabbit/gd.h b/vowpalwabbit/gd.h
index 05bd5b5d..be52e62f 100644
--- a/vowpalwabbit/gd.h
+++ b/vowpalwabbit/gd.h
@@ -26,6 +26,7 @@ namespace GD{
void train_one_example_single_thread(regressor& r, example* ex);
LEARNER::learner* setup(vw& all, po::variables_map& vm);
void save_load_regressor(vw& all, io_buf& model_file, bool read, bool text);
+ void save_load_online_state(vw& all, io_buf& model_file, bool read, bool text);
void output_and_account_example(example* ec);
// iterate through one namespace (or its part), callback function T(some_data_R, feature_value_x, feature_weight)
diff --git a/vowpalwabbit/gd_mf.cc b/vowpalwabbit/gd_mf.cc
index db8cb8e5..36aa8f90 100644
--- a/vowpalwabbit/gd_mf.cc
+++ b/vowpalwabbit/gd_mf.cc
@@ -299,7 +299,7 @@ void mf_train(vw& all, example& ec)
else
all.rank = vm["gdmf"].as<uint32_t>();
- gdmf* data = (gdmf*)calloc_or_die(1,sizeof(gdmf));
+ gdmf* data = calloc_or_die<gdmf>();
data->all = &all;
// store linear + 2*rank weights per index, round up to power of two
diff --git a/vowpalwabbit/global_data.cc b/vowpalwabbit/global_data.cc
index fbffb83a..cdd870e5 100644
--- a/vowpalwabbit/global_data.cc
+++ b/vowpalwabbit/global_data.cc
@@ -235,7 +235,7 @@ po::variables_map add_options(vw& all, po::options_description& opts)
vw::vw()
{
- sd = (shared_data *) calloc_or_die(1, sizeof(shared_data));
+ sd = calloc_or_die<shared_data>();
sd->dump_interval = 1.; // next update progress dump
sd->contraction = 1.;
sd->max_label = 1.;
diff --git a/vowpalwabbit/global_data.h b/vowpalwabbit/global_data.h
index a07792e4..9026d46b 100644
--- a/vowpalwabbit/global_data.h
+++ b/vowpalwabbit/global_data.h
@@ -196,6 +196,7 @@ struct vw {
bool hessian_on;
bool save_resume;
+ double normalized_sum_norm_x;
po::options_description opts;
std::string file_options;
diff --git a/vowpalwabbit/kernel_svm.cc b/vowpalwabbit/kernel_svm.cc
index 17c86056..b28b2ad0 100644
--- a/vowpalwabbit/kernel_svm.cc
+++ b/vowpalwabbit/kernel_svm.cc
@@ -106,7 +106,7 @@ namespace KSVM
{
krow.delete_v();
// free flatten example contents
- flat_example *fec = (flat_example*)calloc_or_die(1, sizeof(flat_example));
+ flat_example *fec = calloc_or_die<flat_example>();
*fec = ex;
free_flatten_example(fec); // free contents of flat example and frees fec.
}
@@ -222,17 +222,17 @@ namespace KSVM
int save_load_flat_example(io_buf& model_file, bool read, flat_example*& fec) {
size_t brw = 1;
if(read) {
- fec = (flat_example*) calloc_or_die(1, sizeof(flat_example));
+ fec = calloc_or_die<flat_example>();
brw = bin_read_fixed(model_file, (char*) fec, sizeof(flat_example), "");
if(brw > 0) {
if(fec->tag_len > 0) {
- fec->tag = (char*) calloc_or_die(fec->tag_len, sizeof(char));
+ fec->tag = calloc_or_die<char>(fec->tag_len);
brw = bin_read_fixed(model_file, (char*) fec->tag, fec->tag_len*sizeof(char), "");
if(!brw) return 2;
}
if(fec->feature_map_len > 0) {
- fec->feature_map = (feature*) calloc_or_die(fec->feature_map_len, sizeof(feature));
+ fec->feature_map = calloc_or_die<feature>(fec->feature_map_len);
brw = bin_read_fixed(model_file, (char*) fec->feature_map, fec->feature_map_len*sizeof(feature), ""); if(!brw) return 3;
}
}
@@ -277,7 +277,7 @@ namespace KSVM
for(uint32_t i = 0;i < model->num_support;i++) {
if(read) {
save_load_flat_example(model_file, read, fec);
- svm_example* tmp= (svm_example*)calloc_or_die(1,sizeof(svm_example));
+ svm_example* tmp= calloc_or_die<svm_example>();
tmp->init_svm_example(fec);
model->support_vec.push_back(tmp);
}
@@ -398,7 +398,7 @@ namespace KSVM
void predict(svm_params& params, learner &base, example& ec) {
flat_example* fec = flatten_sort_example(*(params.all),&ec);
if(fec) {
- svm_example* sec = (svm_example*)calloc_or_die(1, sizeof(svm_example));
+ svm_example* sec = calloc_or_die<svm_example>();
sec->init_svm_example(fec);
float score;
predict(params, &sec, &score, 1);
@@ -556,7 +556,7 @@ namespace KSVM
}
- size_t* sizes = (size_t*) calloc_or_die(all.total, sizeof(size_t));
+ size_t* sizes = calloc_or_die<size_t>(all.total);
sizes[all.node] = b->space.end - b->space.begin;
//cerr<<"Sizes = "<<sizes[all.node]<<" ";
all_reduce<size_t, add_size_t>(sizes, all.total, all.span_server, all.unique_id, all.total, all.node, all.socks);
@@ -570,7 +570,7 @@ namespace KSVM
//cerr<<total_sum<<" "<<prev_sum<<endl;
if(total_sum > 0) {
- queries = (char*) calloc_or_die(total_sum, sizeof(char));
+ queries = calloc_or_die<char>(total_sum);
memcpy(queries + prev_sum, b->space.begin, b->space.end - b->space.begin);
b->space.delete_v();
all_reduce<char, copy_char>(queries, total_sum, all.span_server, all.unique_id, all.total, all.node, all.socks);
@@ -584,7 +584,7 @@ namespace KSVM
for(size_t i = 0;i < params.pool_size; i++) {
if(!save_load_flat_example(*b, true, fec)) {
- params.pool[i] = (svm_example*)calloc_or_die(1,sizeof(svm_example));
+ params.pool[i] = calloc_or_die<svm_example>();
params.pool[i]->init_svm_example(fec);
train_pool[i] = true;
params.pool_pos++;
@@ -617,11 +617,11 @@ namespace KSVM
//cerr<<"In train "<<params.all->training<<endl;
- bool* train_pool = (bool*)calloc_or_die(params.pool_size, sizeof(bool));
+ bool* train_pool = calloc_or_die<bool>(params.pool_size);
for(size_t i = 0;i < params.pool_size;i++)
train_pool[i] = false;
- float* scores = (float*)calloc_or_die(params.pool_pos, sizeof(float));
+ float* scores = calloc_or_die<float>(params.pool_pos);
predict(params, params.pool, scores, params.pool_pos);
//cout<<scores[0]<<endl;
@@ -690,7 +690,7 @@ namespace KSVM
bool overshoot = update(params, model_pos);
//cerr<<model_pos<<":alpha = "<<model->alpha[model_pos]<<endl;
- double* subopt = (double*)calloc_or_die(model->num_support,sizeof(double));
+ double* subopt = calloc_or_die<double>(model->num_support);
for(size_t j = 0;j < params.reprocess;j++) {
if(model->num_support == 0) break;
//cerr<<"reprocess: ";
@@ -739,7 +739,7 @@ namespace KSVM
// cout<<i<<":"<<fec->feature_map[i].x<<" "<<fec->feature_map[i].weight_index<<" ";
// cout<<endl;
if(fec) {
- svm_example* sec = (svm_example*)calloc_or_die(1, sizeof(svm_example));
+ svm_example* sec = calloc_or_die<svm_example>();
sec->init_svm_example(fec);
float score = 0;
predict(params, &sec, &score, 1);
@@ -814,8 +814,8 @@ namespace KSVM
delete all.loss;
all.loss = getLossFunction(&all, loss_function, (float)loss_parameter);
- svm_params* params = (svm_params*) calloc_or_die(1,sizeof(svm_params));
- params->model = (svm_model*) calloc_or_die(1,sizeof(svm_model));
+ svm_params* params = calloc_or_die<svm_params>();
+ params->model = calloc_or_die<svm_model>();
params->model->num_support = 0;
//params->curcache = 0;
params->maxcache = 1024*1024*1024;
@@ -845,7 +845,7 @@ namespace KSVM
else
params->pool_size = 1;
- params->pool = (svm_example**)calloc_or_die(params->pool_size, sizeof(svm_example*));
+ params->pool = calloc_or_die<svm_example*>(params->pool_size);
params->pool_pos = 0;
if(vm.count("subsample"))
@@ -885,7 +885,7 @@ namespace KSVM
all.file_options.append(ss.str());
}
cerr<<"bandwidth = "<<bandwidth<<endl;
- params->kernel_params = calloc_or_die(1,sizeof(double*));
+ params->kernel_params = calloc_or_die<double>();
*((float*)params->kernel_params) = bandwidth;
}
else if(kernel_type.compare("poly") == 0) {
@@ -898,14 +898,14 @@ namespace KSVM
all.file_options.append(ss.str());
}
cerr<<"degree = "<<degree<<endl;
- params->kernel_params = calloc_or_die(1,sizeof(int*));
+ params->kernel_params = calloc_or_die<int>();
*((int*)params->kernel_params) = degree;
}
else
params->kernel_type = SVM_KER_LIN;
- all.reg.weight_mask = (uint32_t)FLT_MAX;
- params->all->reg.weight_mask = (uint32_t)FLT_MAX;
+ params->all->reg.weight_mask = (uint32_t)LONG_MAX;
+ params->all->reg.stride_shift = 0;
learner* l = new learner(params, 1);
l->set_learn<svm_params, learn>();
diff --git a/vowpalwabbit/lda_core.cc b/vowpalwabbit/lda_core.cc
index 1fac43cc..c7d9ede2 100644
--- a/vowpalwabbit/lda_core.cc
+++ b/vowpalwabbit/lda_core.cc
@@ -753,6 +753,7 @@ void end_examples(lda& l)
ld.v.delete_v();
}
+
learner* setup(vw&all, po::variables_map& vm)
{
po::options_description opts("Lda options");
@@ -769,7 +770,7 @@ void end_examples(lda& l)
else
all.lda = vm["lda"].as<uint32_t>();
- lda* ld = (lda*)calloc_or_die(1,sizeof(lda));
+ lda* ld = calloc_or_die<lda>();
ld->lda = all.lda;
ld->lda_alpha = vm["lda_alpha"].as<float>();
diff --git a/vowpalwabbit/learner.cc b/vowpalwabbit/learner.cc
index e718d309..40d37ba3 100644
--- a/vowpalwabbit/learner.cc
+++ b/vowpalwabbit/learner.cc
@@ -16,44 +16,55 @@ void dispatch_example(vw& all, example& ec)
namespace LEARNER
{
- void generic_driver(vw* all)
+ void generic_driver(vw& all)
{
example* ec = NULL;
- all->l->init_driver();
- while ( true )
+ all.l->init_driver();
+ while ( all.early_terminate == false )
{
- if ((ec = VW::get_example(all->p)) != NULL)//semiblocking operation.
+ if ((ec = VW::get_example(all.p)) != NULL)//semiblocking operation.
{
if (ec->indices.size() > 1) // 1+ nonconstant feature. (most common case first)
- dispatch_example(*all, *ec);
+ dispatch_example(all, *ec);
else if (ec->end_pass)
{
- all->l->end_pass();
- VW::finish_example(*all,ec);
+ all.l->end_pass();
+ VW::finish_example(all, ec);
}
else if (ec->tag.size() >= 4 && !strncmp((const char*) ec->tag.begin, "save", 4))
{// save state command
- string final_regressor_name = all->final_regressor_name;
+ string final_regressor_name = all.final_regressor_name;
if ((ec->tag).size() >= 6 && (ec->tag)[4] == '_')
final_regressor_name = string(ec->tag.begin+5, (ec->tag).size()-5);
- if (!all->quiet)
+ if (!all.quiet)
cerr << "saving regressor to " << final_regressor_name << endl;
- save_predictor(*all, final_regressor_name, 0);
+ save_predictor(all, final_regressor_name, 0);
- VW::finish_example(*all,ec);
+ VW::finish_example(all,ec);
}
else // empty example
- dispatch_example(*all, *ec);
+ dispatch_example(all, *ec);
}
- else if (parser_done(all->p))
+ else if (parser_done(all.p))
{
- all->l->end_examples();
+ all.l->end_examples();
return;
}
}
+ if (all.early_terminate) //drain any extra examples from parser and call end_examples
+    while (true) // drain until parser_done; early_terminate is already known true here
+ {
+ if ((ec = VW::get_example(all.p)) != NULL)//semiblocking operation.
+ VW::finish_example(all, ec);
+ else if (parser_done(all.p))
+ {
+ all.l->end_examples();
+ return;
+ }
+ }
}
}
diff --git a/vowpalwabbit/learner.h b/vowpalwabbit/learner.h
index 7deeb4d2..71571c40 100644
--- a/vowpalwabbit/learner.h
+++ b/vowpalwabbit/learner.h
@@ -50,7 +50,7 @@ namespace LEARNER
void (*finish_example_f)(vw&, void* data, example&);
};
- void generic_driver(vw* all);
+ void generic_driver(vw& all);
inline void generic_sl(void*, io_buf&, bool, bool) {}
inline void generic_learner(void* data, learner& base, example&) {}
@@ -184,8 +184,6 @@ public:
{finish_example_fd.data = learn_fd.data;
finish_example_fd.finish_example_f = tend_example<T,f>;}
- void driver(vw* all) {LEARNER::generic_driver(all);}
-
inline learner()
{
weights = 1;
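
With the driver() member gone, front ends call the free function directly; the main.cc hunk later in this patch makes exactly this swap:

    VW::start_parser(*all);
    LEARNER::generic_driver(*all);   // was: all->l->driver(all);
    VW::end_parser(*all);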
diff --git a/vowpalwabbit/log_multi.cc b/vowpalwabbit/log_multi.cc
index a2ad4952..7bb66e9a 100644
--- a/vowpalwabbit/log_multi.cc
+++ b/vowpalwabbit/log_multi.cc
@@ -110,6 +110,7 @@ namespace LOG_MULTI
node.parent = 0;
node.min_count = 0;
+ node.preds = v_init<node_pred>();
init_leaf(node);
return node;
diff --git a/vowpalwabbit/loss_functions.cc b/vowpalwabbit/loss_functions.cc
index 0cdc5fa7..4d67ae5e 100644
--- a/vowpalwabbit/loss_functions.cc
+++ b/vowpalwabbit/loss_functions.cc
@@ -127,7 +127,8 @@ public:
}
float getLoss(shared_data*, float prediction, float label) {
- assert(label == -1.f || label == 1.f);
+ if (label != -1.f && label != 1.f)
+ cout << "You are using a label not -1 or 1 with a loss function expecting that!" << endl;
float e = 1 - label*prediction;
return (e > 0) ? e : 0;
}
@@ -170,7 +171,8 @@ public:
}
float getLoss(shared_data*, float prediction, float label) {
- assert(label == -1.f || label == 1.f || label == FLT_MAX);
+ if (label != -1.f && label != 1.f)
+ cout << "You are using a label not -1 or 1 with a loss function expecting that!" << endl;
return log(1 + exp(-label * prediction));
}
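
Behaviorally, the hinge and logistic losses now warn instead of assert-failing on labels outside {-1, 1}. A hypothetical repro (flag spelling assumed from the standard vw CLI):

    echo '0 | price:0.5' | vw --loss_function logistic
    # prints: You are using a label not -1 or 1 with a loss function expecting that!

Note that the logistic check no longer special-cases label == FLT_MAX, so unlabeled test examples will trigger the same warning.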
diff --git a/vowpalwabbit/lrq.cc b/vowpalwabbit/lrq.cc
index 2c4ab028..3caf81ae 100644
--- a/vowpalwabbit/lrq.cc
+++ b/vowpalwabbit/lrq.cc
@@ -135,11 +135,11 @@ namespace LRQ {
if (iter == 0 && (all.audit || all.hash_inv))
{
- char* new_space = (char*)calloc_or_die(4, sizeof(char));
+ char* new_space = calloc_or_die<char>(4);
strcpy(new_space, "lrq");
size_t n_len = strlen(i->c_str () + 4);
size_t len = strlen(ra->feature) + n_len + 2;
- char* new_feature = (char*)calloc_or_die(len, sizeof(char));
+ char* new_feature = calloc_or_die<char>(len);
new_feature[0] = right;
new_feature[1] = '^';
strcat(new_feature, ra->feature);
diff --git a/vowpalwabbit/main.cc b/vowpalwabbit/main.cc
index 96e08aa1..6773d3be 100644
--- a/vowpalwabbit/main.cc
+++ b/vowpalwabbit/main.cc
@@ -45,7 +45,7 @@ int main(int argc, char *argv[])
VW::start_parser(*all);
- all->l->driver(all);
+ LEARNER::generic_driver(*all);
VW::end_parser(*all);
diff --git a/vowpalwabbit/memory.cc b/vowpalwabbit/memory.cc
index aa3772d3..ea23597c 100644
--- a/vowpalwabbit/memory.cc
+++ b/vowpalwabbit/memory.cc
@@ -1,19 +1,6 @@
#include <stdlib.h>
#include <iostream>
-void* calloc_or_die(size_t nmemb, size_t size)
-{
- if (nmemb == 0 || size == 0)
- return NULL;
-
- void* data = calloc(nmemb, size);
- if (data == NULL) {
- std::cerr << "internal error: memory allocation failed; dying!" << std::endl;
- throw std::exception();
- }
- return data;
-}
-
void free_it(void*ptr)
{
if (ptr != NULL)
diff --git a/vowpalwabbit/memory.h b/vowpalwabbit/memory.h
index 266290cd..63441f46 100644
--- a/vowpalwabbit/memory.h
+++ b/vowpalwabbit/memory.h
@@ -1,4 +1,20 @@
#pragma once
-void* calloc_or_die(size_t nmemb, size_t size);
+
+#include <stdlib.h>
+#include <iostream>
+
+template<class T>
+T* calloc_or_die(size_t nmemb = 1)
+{
+ if (nmemb == 0)
+ return NULL;
+
+ void* data = calloc(nmemb, sizeof(T));
+ if (data == NULL) {
+ std::cerr << "internal error: memory allocation failed; dying!" << std::endl;
+ throw std::exception();
+ }
+ return (T*)data;
+}
void free_it(void* ptr);
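
Call sites throughout this patch follow the same mechanical rewrite onto the new template. A minimal illustration (foo is a hypothetical type):

    struct foo { int n; float x; };

    foo*  f   = calloc_or_die<foo>();     // one zero-initialized object; nmemb defaults to 1
    char* tag = calloc_or_die<char>(16);  // zeroed array of 16 chars
    free_it(f);
    free_it(tag);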
diff --git a/vowpalwabbit/nn.cc b/vowpalwabbit/nn.cc
index e3fd0544..2ae38489 100644
--- a/vowpalwabbit/nn.cc
+++ b/vowpalwabbit/nn.cc
@@ -319,8 +319,8 @@ CONVERSE: // That's right, I'm using goto. So sue me.
vm = add_options(all, opts);
if(!vm.count("nn"))
return NULL;
-
- nn* n = (nn*)calloc_or_die(1,sizeof(nn));
+
+ nn* n = calloc_or_die<nn>();
n->all = &all;
//first parse for number of hidden units
diff --git a/vowpalwabbit/oaa.cc b/vowpalwabbit/oaa.cc
index 7d6183b1..2fd3c9d3 100644
--- a/vowpalwabbit/oaa.cc
+++ b/vowpalwabbit/oaa.cc
@@ -85,7 +85,7 @@ namespace OAA {
if(!vm.count("oaa"))
return NULL;
- oaa* data = (oaa*)calloc_or_die(1, sizeof(oaa));
+ oaa* data = calloc_or_die<oaa>();
//first parse for number of actions
data->k = (uint32_t)vm["oaa"].as<size_t>();
diff --git a/vowpalwabbit/parse_args.cc b/vowpalwabbit/parse_args.cc
index 51c6708c..3136115e 100644
--- a/vowpalwabbit/parse_args.cc
+++ b/vowpalwabbit/parse_args.cc
@@ -36,6 +36,7 @@ license as described in the file LICENSE.
#include "gd_mf.h"
#include "mf.h"
#include "vw.h"
+#include "ftrl_proximal.h"
#include "rand48.h"
#include "parse_args.h"
#include "binary.h"
@@ -139,7 +140,7 @@ void parse_dictionary_argument(vw&all, string str) {
void parse_affix_argument(vw&all, string str) {
if (str.length() == 0) return;
- char* cstr = (char*)calloc_or_die(str.length()+1, sizeof(char));
+ char* cstr = calloc_or_die<char>(str.length()+1);
strcpy(cstr, str.c_str());
char*p = strtok(cstr, ",");
@@ -720,6 +721,8 @@ void parse_base_algorithm(vw& all, po::variables_map& vm)
{
// all.l = GD::setup(all, vm);
all.scorer = all.l;
+ if (vm.count("ftrl"))
+ all.l = FTRL::setup(all, vm);
}
void load_input_model(vw& all, po::variables_map& vm, io_buf& io_temp)
@@ -997,7 +1000,7 @@ namespace VW {
char** get_argv_from_string(string s, int& argc)
{
- char* c = (char*)calloc_or_die(s.length()+3, sizeof(char));
+ char* c = calloc_or_die<char>(s.length()+3);
c[0] = 'b';
c[1] = ' ';
strcpy(c+2, s.c_str());
@@ -1005,11 +1008,11 @@ namespace VW {
v_array<substring> foo = v_init<substring>();
tokenize(' ', ss, foo);
- char** argv = (char**)calloc_or_die(foo.size(), sizeof(char*));
+ char** argv = calloc_or_die<char*>(foo.size());
for (size_t i = 0; i < foo.size(); i++)
{
*(foo[i].end) = '\0';
- argv[i] = (char*)calloc_or_die(foo[i].end-foo[i].begin+1, sizeof(char));
+ argv[i] = calloc_or_die<char>(foo[i].end-foo[i].begin+1);
sprintf(argv[i],"%s",foo[i].begin);
}
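
With the setup hook wired in, the new reduction becomes reachable from the command line. A hypothetical invocation (the --ftrl selector itself is registered outside this diff, so its exact spelling is an assumption; the tuning flags come from ftrl_proximal.cc above):

    vw --ftrl --ftrl_alpha 0.1 --ftrl_beta 0.0 train.vw -f model.vw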
diff --git a/vowpalwabbit/parse_example.cc b/vowpalwabbit/parse_example.cc
index 31ee3e1f..941f0366 100644
--- a/vowpalwabbit/parse_example.cc
+++ b/vowpalwabbit/parse_example.cc
@@ -20,7 +20,7 @@ char* copy(char* base)
{
size_t len = 0;
while (base[len++] != '\0');
- char* ret = (char *)calloc_or_die(len,sizeof(char));
+ char* ret = calloc_or_die<char>(len);
memcpy(ret,base,len);
return ret;
}
@@ -270,7 +270,7 @@ public:
{
if (base != NULL)
free(base);
- base = (char *) calloc_or_die(2,sizeof(char));
+ base = calloc_or_die<char>(2);
base[0] = ' ';
base[1] = '\0';
}
diff --git a/vowpalwabbit/parse_regressor.cc b/vowpalwabbit/parse_regressor.cc
index 6cd8c459..504db6a7 100644
--- a/vowpalwabbit/parse_regressor.cc
+++ b/vowpalwabbit/parse_regressor.cc
@@ -35,7 +35,7 @@ void initialize_regressor(vw& all)
size_t length = ((size_t)1) << all.num_bits;
all.reg.weight_mask = (length << all.reg.stride_shift) - 1;
- all.reg.weight_vector = (weight *)calloc_or_die(length << all.reg.stride_shift, sizeof(weight));
+ all.reg.weight_vector = calloc_or_die<weight>(length << all.reg.stride_shift);
if (all.reg.weight_vector == NULL)
{
cerr << all.program_name << ": Failed to allocate weight array with " << all.num_bits << " bits: try decreasing -b <bits>" << endl;
diff --git a/vowpalwabbit/parser.cc b/vowpalwabbit/parser.cc
index 4a230d99..f47d5a10 100644
--- a/vowpalwabbit/parser.cc
+++ b/vowpalwabbit/parser.cc
@@ -153,7 +153,7 @@ bool is_test_only(uint32_t counter, uint32_t period, uint32_t after, bool holdou
parser* new_parser()
{
- parser* ret = (parser*) calloc_or_die(1,sizeof(parser));
+ parser* ret = calloc_or_die<parser>();
ret->input = new io_buf;
ret->output = new io_buf;
ret->local_example_number = 0;
@@ -1170,7 +1170,7 @@ void initialize_examples(vw& all)
all.p->end_parsed_examples = 0;
all.p->done = false;
- all.p->examples = (example*)calloc_or_die(all.p->ring_size, sizeof(example));
+ all.p->examples = calloc_or_die<example>(all.p->ring_size);
for (size_t i = 0; i < all.p->ring_size; i++)
{
diff --git a/vowpalwabbit/print.cc b/vowpalwabbit/print.cc
index d5af84f3..c5c6566e 100644
--- a/vowpalwabbit/print.cc
+++ b/vowpalwabbit/print.cc
@@ -54,7 +54,7 @@ namespace PRINT
if(!vm.count("print"))
return NULL;
- print* p = (print*)calloc_or_die(1, sizeof(print));
+ print* p = calloc_or_die<print>();
p->all = &all;
size_t length = ((size_t)1) << all.num_bits;
all.reg.weight_mask = (length << all.reg.stride_shift) - 1;
diff --git a/vowpalwabbit/scorer.cc b/vowpalwabbit/scorer.cc
index 02643b26..a76e11b1 100644
--- a/vowpalwabbit/scorer.cc
+++ b/vowpalwabbit/scorer.cc
@@ -51,7 +51,7 @@ namespace Scorer {
opts.add_options()
("link", po::value<string>()->default_value("identity"), "Specify the link function: identity, logistic or glf1");
vm = add_options(all, opts);
- scorer* s = (scorer*)calloc_or_die(1, sizeof(scorer));
+ scorer* s = calloc_or_die<scorer>();
s->all = &all;
learner* l = new learner(s, all.l);
diff --git a/vowpalwabbit/search.cc b/vowpalwabbit/search.cc
index 79c82ffc..6f03f84e 100644
--- a/vowpalwabbit/search.cc
+++ b/vowpalwabbit/search.cc
@@ -383,8 +383,8 @@ namespace Search {
priv.dat_new_feature_ec->sum_feat_sq[priv.dat_new_feature_namespace] += f.x * f.x;
if (priv.all->audit) {
audit_data a = { NULL, NULL, f.weight_index, f.x, true };
- a.space = (char*)calloc_or_die(priv.dat_new_feature_feature_space->length()+1, sizeof(char));
- a.feature = (char*)calloc_or_die(priv.dat_new_feature_audit_ss.str().length() + 32, sizeof(char));
+ a.space = calloc_or_die<char>(priv.dat_new_feature_feature_space->length()+1);
+ a.feature = calloc_or_die<char>(priv.dat_new_feature_audit_ss.str().length() + 32);
strcpy(a.space, priv.dat_new_feature_feature_space->c_str());
int num = sprintf(a.feature, "fid=%lu_", (idx & mask) >> ss);
strcpy(a.feature+num, priv.dat_new_feature_audit_ss.str().c_str());
@@ -1809,7 +1809,7 @@ namespace Search {
if (all.args[i] == "--search_task" && all.args[i+1] != "hook")
all.args.erase(all.args.begin() + i, all.args.begin() + i + 2);
- search* sch = (search*)calloc_or_die(1,sizeof(search));
+ search* sch = calloc_or_die<search>();
sch->priv = new search_private();
search_initialize(&all, *sch);
search_private& priv = *sch->priv;
@@ -1896,7 +1896,7 @@ namespace Search {
"warning: you specified a different history length through --search_history_length than the one loaded from predictor. using loaded value of: ", "");
//check if the base learner is contextual bandit, in which case, we dont rollout all actions.
- priv.allowed_actions_cache = (polylabel*)calloc_or_die(1,sizeof(polylabel));
+ priv.allowed_actions_cache = calloc_or_die<polylabel>();
if (vm.count("cb")) {
priv.cb_learner = true;
CB::cb_label.default_label(priv.allowed_actions_cache);
@@ -2140,7 +2140,7 @@ namespace Search {
void predictor::make_new_pointer(v_array<action>& A, size_t new_size) {
size_t old_size = A.size();
action* old_pointer = A.begin;
- A.begin = (action*)calloc_or_die(new_size, sizeof(action));
+ A.begin = calloc_or_die<action>(new_size);
A.end = A.begin + new_size;
A.end_array = A.end;
memcpy(A.begin, old_pointer, old_size * sizeof(action));
diff --git a/vowpalwabbit/search_dep_parser.cc b/vowpalwabbit/search_dep_parser.cc
index c29c754d..9629c30a 100644
--- a/vowpalwabbit/search_dep_parser.cc
+++ b/vowpalwabbit/search_dep_parser.cc
@@ -45,7 +45,6 @@ namespace DepParserTask {
void initialize(Search::search& srn, size_t& num_actions, po::variables_map& vm) {
task_data *data = new task_data();
data->my_init_flag = false;
- //data->ex = (example*)calloc_or_die(1, sizeof(example));
data->ec_buf.resize(12, true);
data->children = new v_array<uint32_t>[6];
diff --git a/vowpalwabbit/sender.cc b/vowpalwabbit/sender.cc
index 9b811e96..e7407f9d 100644
--- a/vowpalwabbit/sender.cc
+++ b/vowpalwabbit/sender.cc
@@ -109,7 +109,7 @@ learner* setup(vw& all, po::variables_map& vm)
if(!vm.count("sendto"))
return NULL;
- sender* s = (sender*)calloc_or_die(1,sizeof(sender));
+ sender* s = calloc_or_die<sender>();
s->sd = -1;
if (vm.count("sendto"))
{
@@ -118,7 +118,7 @@ learner* setup(vw& all, po::variables_map& vm)
}
s->all = &all;
- s->delay_ring = (example**) calloc_or_die(all.p->ring_size, sizeof(example*));
+ s->delay_ring = calloc_or_die<example*>(all.p->ring_size);
learner* l = new learner(s, 1);
l->set_learn<sender, learn>();
diff --git a/vowpalwabbit/simple_label.cc b/vowpalwabbit/simple_label.cc
index 9e220cde..3bf748de 100644
--- a/vowpalwabbit/simple_label.cc
+++ b/vowpalwabbit/simple_label.cc
@@ -15,8 +15,6 @@ char* bufread_simple_label(shared_data* sd, label_data* ld, char* c)
{
ld->label = *(float *)c;
c += sizeof(ld->label);
- if (sd->binary_label && fabs(ld->label) != 1.f && ld->label != FLT_MAX)
- cout << "You are using a label not -1 or 1 with a loss function expecting that!" << endl;
ld->weight = *(float *)c;
c += sizeof(ld->weight);
ld->initial = *(float *)c;
diff --git a/vowpalwabbit/stagewise_poly.cc b/vowpalwabbit/stagewise_poly.cc
index 68077829..a12ce55d 100644
--- a/vowpalwabbit/stagewise_poly.cc
+++ b/vowpalwabbit/stagewise_poly.cc
@@ -129,7 +129,7 @@ namespace StagewisePoly
void depthsbits_create(stagewise_poly &poly)
{
- poly.depthsbits = (uint8_t *) calloc_or_die(1, depthsbits_sizeof(poly));
+ poly.depthsbits = calloc_or_die<uint8_t>(2 * poly.all->length());
for (uint32_t i = 0; i < poly.all->length() * 2; i += 2) {
poly.depthsbits[i] = default_depth;
poly.depthsbits[i+1] = indicator_bit;
@@ -247,7 +247,7 @@ namespace StagewisePoly
cout << ", new size " << poly.sd_len << endl;
#endif //DEBUG
free(poly.sd); //okay for null.
- poly.sd = (sort_data *) calloc_or_die(poly.sd_len, sizeof(sort_data));
+ poly.sd = calloc_or_die<sort_data>(poly.sd_len);
}
assert(len <= poly.sd_len);
}
@@ -672,7 +672,7 @@ namespace StagewisePoly
if (vm.count("stage_poly"))
return NULL;
- stagewise_poly *poly = (stagewise_poly *) calloc_or_die(1, sizeof(stagewise_poly));
+ stagewise_poly *poly = calloc_or_die<stagewise_poly>();
poly->all = &all;
depthsbits_create(*poly);
diff --git a/vowpalwabbit/topk.cc b/vowpalwabbit/topk.cc
index e5098c7c..1aeb4127 100644
--- a/vowpalwabbit/topk.cc
+++ b/vowpalwabbit/topk.cc
@@ -120,7 +120,7 @@ namespace TOPK {
if(!vm.count("top"))
return NULL;
- topk* data = (topk*)calloc_or_die(1, sizeof(topk));
+ topk* data = calloc_or_die<topk>();
data->B = (uint32_t)vm["top"].as<size_t>();