diff options
author | John Langford <jl@hunch.net> | 2014-12-29 05:14:23 +0300 |
---|---|---|
committer | John Langford <jl@hunch.net> | 2014-12-29 05:14:23 +0300 |
commit | 19b21b995948621aa0df7500f3f397d3bb869bb7 (patch) | |
tree | b02e336a168aa05744eeaf3da9f4b34e45148df3 | |
parent | 6ded0936d05668220e5e857fc96e08f2ce1939c4 (diff) |
many simplifications
-rw-r--r-- | vowpalwabbit/accumulate.cc | 23 | ||||
-rw-r--r-- | vowpalwabbit/active.cc | 2 | ||||
-rw-r--r-- | vowpalwabbit/active_interactor.cc | 6 | ||||
-rw-r--r-- | vowpalwabbit/autolink.cc | 1 | ||||
-rw-r--r-- | vowpalwabbit/best_constant.cc | 2 | ||||
-rw-r--r-- | vowpalwabbit/cache.cc | 6 | ||||
-rw-r--r-- | vowpalwabbit/io_buf.cc | 24 | ||||
-rw-r--r-- | vowpalwabbit/learner.cc | 3 | ||||
-rw-r--r-- | vowpalwabbit/main.cc | 9 | ||||
-rw-r--r-- | vowpalwabbit/memory.cc | 3 | ||||
-rw-r--r-- | vowpalwabbit/multiclass.cc | 42 | ||||
-rw-r--r-- | vowpalwabbit/network.cc | 1 | ||||
-rw-r--r-- | vowpalwabbit/noop.cc | 8 | ||||
-rw-r--r-- | vowpalwabbit/parse_args.cc | 8 | ||||
-rw-r--r-- | vowpalwabbit/parse_example.cc | 2 | ||||
-rw-r--r-- | vowpalwabbit/parse_primitives.cc | 12 | ||||
-rw-r--r-- | vowpalwabbit/print.cc | 13 | ||||
-rw-r--r-- | vowpalwabbit/rand48.cc | 13 | ||||
-rw-r--r-- | vowpalwabbit/search.cc | 4 | ||||
-rw-r--r-- | vowpalwabbit/search_sequencetask.cc | 2 | ||||
-rw-r--r-- | vowpalwabbit/sender.cc | 13 | ||||
-rw-r--r-- | vowpalwabbit/topk.cc | 35 | ||||
-rw-r--r-- | vowpalwabbit/unique_sort.cc | 4 |
23 files changed, 75 insertions, 161 deletions
diff --git a/vowpalwabbit/accumulate.cc b/vowpalwabbit/accumulate.cc index d6c5e71f..49d2e969 100644 --- a/vowpalwabbit/accumulate.cc +++ b/vowpalwabbit/accumulate.cc @@ -17,9 +17,7 @@ Alekh Agarwal and John Langford, with help Olivier Chapelle. using namespace std;
-void add_float(float& c1, const float& c2) {
- c1 += c2;
-}
+void add_float(float& c1, const float& c2) { c1 += c2; }
void accumulate(vw& all, string master_location, regressor& reg, size_t o) {
uint32_t length = 1 << all.num_bits; //This is size of gradient
@@ -27,15 +25,11 @@ void accumulate(vw& all, string master_location, regressor& reg, size_t o) { float* local_grad = new float[length];
weight* weights = reg.weight_vector;
for(uint32_t i = 0;i < length;i++)
- {
- local_grad[i] = weights[stride*i+o];
- }
+ local_grad[i] = weights[stride*i+o];
all_reduce<float, add_float>(local_grad, length, master_location, all.unique_id, all.total, all.node, all.socks);
for(uint32_t i = 0;i < length;i++)
- {
- weights[stride*i+o] = local_grad[i];
- }
+ weights[stride*i+o] = local_grad[i];
delete[] local_grad;
}
@@ -53,11 +47,11 @@ void accumulate_avg(vw& all, string master_location, regressor& reg, size_t o) { float numnodes = (float)all.total;
for(uint32_t i = 0;i < length;i++)
- local_grad[i] = weights[stride*i+o];
+ local_grad[i] = weights[stride*i+o];
all_reduce<float, add_float>(local_grad, length, master_location, all.unique_id, all.total, all.node, all.socks);
for(uint32_t i = 0;i < length;i++)
- weights[stride*i+o] = local_grad[i]/numnodes;
+ weights[stride*i+o] = local_grad[i]/numnodes;
delete[] local_grad;
}
@@ -83,17 +77,14 @@ void accumulate_weighted_avg(vw& all, string master_location, regressor& reg) { uint32_t length = 1 << all.num_bits; //This is the number of parameters
size_t stride = 1 << all.reg.stride_shift;
weight* weights = reg.weight_vector;
-
-
float* local_weights = new float[length];
for(uint32_t i = 0;i < length;i++)
local_weights[i] = weights[stride*i+1];
-
//First compute weights for averaging
all_reduce<float, add_float>(local_weights, length, master_location, all.unique_id, all.total, all.node, all.socks);
-
+
for(uint32_t i = 0;i < length;i++) //Compute weighted versions
if(local_weights[i] > 0) {
float ratio = weights[stride*i+1]/local_weights[i];
@@ -107,7 +98,7 @@ void accumulate_weighted_avg(vw& all, string master_location, regressor& reg) { local_weights[i] = 0;
weights[stride*i] = 0;
}
-
+
all_reduce<float, add_float>(weights, length*stride, master_location, all.unique_id, all.total, all.node, all.socks);
delete[] local_weights;
diff --git a/vowpalwabbit/active.cc b/vowpalwabbit/active.cc index 627928d6..a1070be3 100644 --- a/vowpalwabbit/active.cc +++ b/vowpalwabbit/active.cc @@ -160,9 +160,7 @@ namespace ACTIVE { ("simulation", "active learning simulation mode") ("mellowness", po::value<float>(&(data.active_c0)), "active learning mellowness parameter c_0. Default 8") ; - vm = add_options(all, active_opts); - data.all=&all; //Create new learner diff --git a/vowpalwabbit/active_interactor.cc b/vowpalwabbit/active_interactor.cc index 1f67fa69..ad4f5869 100644 --- a/vowpalwabbit/active_interactor.cc +++ b/vowpalwabbit/active_interactor.cc @@ -19,12 +19,6 @@ license as described in the file LICENSE. #include <netdb.h> #endif -using std::cin; -using std::endl; -using std::cout; -using std::cerr; -using std::string; - using namespace std; int open_socket(const char* host, unsigned short port) diff --git a/vowpalwabbit/autolink.cc b/vowpalwabbit/autolink.cc index 7b3bebd8..7cdaecef 100644 --- a/vowpalwabbit/autolink.cc +++ b/vowpalwabbit/autolink.cc @@ -14,7 +14,6 @@ namespace ALINK { { base.predict(ec); float base_pred = ec.pred.scalar; - // add features of label ec.indices.push_back(autolink_namespace); float sum_sq = 0; diff --git a/vowpalwabbit/best_constant.cc b/vowpalwabbit/best_constant.cc index 3642a014..c56cb336 100644 --- a/vowpalwabbit/best_constant.cc +++ b/vowpalwabbit/best_constant.cc @@ -76,13 +76,11 @@ bool get_best_constant(vw& all, float& best_constant, float& best_constant_loss) } else return false; - if (!is_more_than_two_labels_observed) best_constant_loss = ( all.loss->getLoss(all.sd, best_constant, label1) * label1_cnt + all.loss->getLoss(all.sd, best_constant, label2) * label2_cnt ) / (label1_cnt + label2_cnt); else best_constant_loss = FLT_MIN; - return true; } diff --git a/vowpalwabbit/cache.cc b/vowpalwabbit/cache.cc index f91eba3c..9881e7f8 100644 --- a/vowpalwabbit/cache.cc +++ b/vowpalwabbit/cache.cc @@ -7,8 +7,6 @@ license as described in the file LICENSE. #include "unique_sort.h" #include "global_data.h" -using namespace std; - const size_t neg_1 = 1; const size_t general = 2; @@ -40,9 +38,7 @@ size_t read_cached_tag(io_buf& cache, example* ae) return tag_size+sizeof(tag_size); } -struct one_float { - float f; -} +struct one_float { float f; } #ifndef _WIN32 __attribute__((packed)) #endif diff --git a/vowpalwabbit/io_buf.cc b/vowpalwabbit/io_buf.cc index 4dbf3cf1..ba220762 100644 --- a/vowpalwabbit/io_buf.cc +++ b/vowpalwabbit/io_buf.cc @@ -3,10 +3,7 @@ Copyright (c) by respective owners including Yahoo!, Microsoft, and individual contributors. All rights reserved. Released under a BSD (revised) license as described in the file LICENSE. */ -#include <string.h> - #include "io_buf.h" - #ifdef WIN32 #include <winsock2.h> #endif @@ -110,20 +107,17 @@ void buf_write(io_buf &o, char* &pointer, size_t n) } bool io_buf::is_socket(int f) -{ - // this appears to work in practice, but could probably be done in a cleaner fashion +{ // this appears to work in practice, but could probably be done in a cleaner fashion const int _nhandle = 32; return f >= _nhandle; } ssize_t io_buf::read_file_or_socket(int f, void* buf, size_t nbytes) { #ifdef _WIN32 - if (is_socket(f)) { + if (is_socket(f)) return recv(f, reinterpret_cast<char*>(buf), static_cast<int>(nbytes), 0); - } - else { + else return _read(f, buf, (unsigned int)nbytes); - } #else return read(f, buf, (unsigned int)nbytes); #endif @@ -132,12 +126,10 @@ ssize_t io_buf::read_file_or_socket(int f, void* buf, size_t nbytes) { ssize_t io_buf::write_file_or_socket(int f, const void* buf, size_t nbytes) { #ifdef _WIN32 - if (is_socket(f)) { + if (is_socket(f)) return send(f, reinterpret_cast<const char*>(buf), static_cast<int>(nbytes), 0); - } - else { + else return _write(f, buf, (unsigned int)nbytes); - } #else return write(f, buf, (unsigned int)nbytes); #endif @@ -146,12 +138,10 @@ ssize_t io_buf::write_file_or_socket(int f, const void* buf, size_t nbytes) void io_buf::close_file_or_socket(int f) { #ifdef _WIN32 - if (io_buf::is_socket(f)) { + if (io_buf::is_socket(f)) closesocket(f); - } - else { + else _close(f); - } #else close(f); #endif diff --git a/vowpalwabbit/learner.cc b/vowpalwabbit/learner.cc index 40d37ba3..0376a00c 100644 --- a/vowpalwabbit/learner.cc +++ b/vowpalwabbit/learner.cc @@ -2,8 +2,7 @@ #include "parser.h" #include "learner.h" #include "vw.h" - -void save_predictor(vw& all, string reg_name, size_t current_pass); +#include "parse_regressor.h" void dispatch_example(vw& all, example& ec) { diff --git a/vowpalwabbit/main.cc b/vowpalwabbit/main.cc index 6773d3be..c7f40326 100644 --- a/vowpalwabbit/main.cc +++ b/vowpalwabbit/main.cc @@ -3,12 +3,6 @@ Copyright (c) by respective owners including Yahoo!, Microsoft, and individual contributors. All rights reserved. Released under a BSD license as described in the file LICENSE. */ - -#include <math.h> -#include <iostream> -#include <fstream> -#include <float.h> -#include <time.h> #ifdef _WIN32 #include <WinSock2.h> #else @@ -17,7 +11,6 @@ license as described in the file LICENSE. #endif #include <sys/timeb.h> #include "global_data.h" -#include "parse_example.h" #include "parse_args.h" #include "accumulate.h" #include "best_constant.h" @@ -44,9 +37,7 @@ int main(int argc, char *argv[]) } VW::start_parser(*all); - LEARNER::generic_driver(*all); - VW::end_parser(*all); ftime(&t_end); diff --git a/vowpalwabbit/memory.cc b/vowpalwabbit/memory.cc index ea23597c..7e40bf71 100644 --- a/vowpalwabbit/memory.cc +++ b/vowpalwabbit/memory.cc @@ -1,7 +1,6 @@ #include <stdlib.h> -#include <iostream> -void free_it(void*ptr) +void free_it(void* ptr) { if (ptr != NULL) free(ptr); diff --git a/vowpalwabbit/multiclass.cc b/vowpalwabbit/multiclass.cc index 8736bc6e..7e1f7a8e 100644 --- a/vowpalwabbit/multiclass.cc +++ b/vowpalwabbit/multiclass.cc @@ -56,9 +56,7 @@ namespace MULTICLASS { ld->weight = 1.; } - void delete_label(void* v) - { - } + void delete_label(void* v) {} void parse_label(parser* p, shared_data*, void* v, v_array<substring>& words) { @@ -145,32 +143,32 @@ namespace MULTICLASS { void output_example(vw& all, example& ec) { multiclass ld = ec.l.multi; - + size_t loss = 1; if (ld.label == (uint32_t)ec.pred.multiclass) loss = 0; - + if(ec.test_only) - { - all.sd->weighted_holdout_examples += ld.weight;//test weight seen - all.sd->weighted_holdout_examples_since_last_dump += ld.weight; - all.sd->weighted_holdout_examples_since_last_pass += ld.weight; - all.sd->holdout_sum_loss += loss; - all.sd->holdout_sum_loss_since_last_dump += loss; - all.sd->holdout_sum_loss_since_last_pass += loss;//since last pass - } + { + all.sd->weighted_holdout_examples += ld.weight;//test weight seen + all.sd->weighted_holdout_examples_since_last_dump += ld.weight; + all.sd->weighted_holdout_examples_since_last_pass += ld.weight; + all.sd->holdout_sum_loss += loss; + all.sd->holdout_sum_loss_since_last_dump += loss; + all.sd->holdout_sum_loss_since_last_pass += loss;//since last pass + } else - { - all.sd->weighted_examples += ld.weight; - all.sd->total_features += ec.num_features; - all.sd->sum_loss += loss; - all.sd->sum_loss_since_last_dump += loss; - all.sd->example_number++; - } - + { + all.sd->weighted_examples += ld.weight; + all.sd->total_features += ec.num_features; + all.sd->sum_loss += loss; + all.sd->sum_loss_since_last_dump += loss; + all.sd->example_number++; + } + for (int* sink = all.final_prediction_sink.begin; sink != all.final_prediction_sink.end; sink++) all.print(*sink, (float)ec.pred.multiclass, 0, ec.tag); - + MULTICLASS::print_update(all, ec); } } diff --git a/vowpalwabbit/network.cc b/vowpalwabbit/network.cc index 7e39e879..b2922063 100644 --- a/vowpalwabbit/network.cc +++ b/vowpalwabbit/network.cc @@ -18,7 +18,6 @@ license as described in the file LICENSE. #include <netdb.h> #include <strings.h> #endif -#include <stdlib.h> #include <string.h> #include <string> diff --git a/vowpalwabbit/noop.cc b/vowpalwabbit/noop.cc index 797bfc3e..0c883a8c 100644 --- a/vowpalwabbit/noop.cc +++ b/vowpalwabbit/noop.cc @@ -7,11 +7,9 @@ license as described in the file LICENSE. #include "reductions.h" -using namespace LEARNER; - namespace NOOP { - void learn(char&, base_learner&, example&) {} + void learn(char&, LEARNER::base_learner&, example&) {} - base_learner* setup(vw& all) - { return &init_learner<char>(NULL, learn, 1); } + LEARNER::base_learner* setup(vw& all) + { return &LEARNER::init_learner<char>(NULL, learn, 1); } } diff --git a/vowpalwabbit/parse_args.cc b/vowpalwabbit/parse_args.cc index ef65345c..a760261f 100644 --- a/vowpalwabbit/parse_args.cc +++ b/vowpalwabbit/parse_args.cc @@ -99,13 +99,13 @@ void parse_dictionary_argument(vw&all, string str) { ifstream infile(s); size_t def = (size_t)' '; for (string line; getline(infile, line);) { - char*c = (char*)line.c_str(); // we're throwing away const, which is dangerous... + char* c = (char*)line.c_str(); // we're throwing away const, which is dangerous... while (*c == ' ' || *c == '\t') ++c; // skip initial whitespace - char*d = c; + char* d = c; while (*d != ' ' && *d != '\t' && *d != '\n' && *d != '\0') ++d; // gobble up initial word if (d == c) continue; // no word if (*d != ' ' && *d != '\t') continue; // reached end of line - char*word = (char*)calloc(d-c, sizeof(char)); + char* word = calloc_or_die<char>(d-c); memcpy(word, c, d-c); substring ss = { word, word + (d - c) }; uint32_t hash = uniform_hash( ss.begin, ss.end-ss.begin, quadratic_constant); @@ -132,7 +132,7 @@ void parse_dictionary_argument(vw&all, string str) { cerr << "dictionary " << s << " contains " << map->size() << " item" << (map->size() == 1 ? "\n" : "s\n"); all.namespace_dictionaries[(size_t)ns].push_back(map); - dictionary_info info = { (char*)calloc(strlen(s)+1, sizeof(char)), map }; + dictionary_info info = { calloc_or_die<char>(strlen(s)+1), map }; strcpy(info.name, s); all.read_dictionaries.push_back(info); } diff --git a/vowpalwabbit/parse_example.cc b/vowpalwabbit/parse_example.cc index 941f0366..63bcb227 100644 --- a/vowpalwabbit/parse_example.cc +++ b/vowpalwabbit/parse_example.cc @@ -182,7 +182,7 @@ public: for (feature*f = feats->begin; f != feats->end; ++f) { uint32_t id = f->weight_index; size_t len = 2 + (feature_name.end-feature_name.begin) + 1 + (size_t)ceil(log10(id)) + 1; - char* str = (char*)calloc(len, sizeof(char)); + char* str = calloc_or_die<char>(len); str[0] = index; str[1] = '_'; char *c = str+2; diff --git a/vowpalwabbit/parse_primitives.cc b/vowpalwabbit/parse_primitives.cc index b08f05fb..4ed67313 100644 --- a/vowpalwabbit/parse_primitives.cc +++ b/vowpalwabbit/parse_primitives.cc @@ -11,8 +11,6 @@ license as described in the file LICENSE. #include "parse_primitives.h" #include "hash.h" -using namespace std; - void tokenize(char delim, substring s, v_array<substring>& ret, bool allow_empty) { ret.erase(); @@ -53,17 +51,15 @@ size_t hashstring (substring s, uint32_t h) } size_t hashall (substring s, uint32_t h) -{ - return uniform_hash((unsigned char *)s.begin, s.end - s.begin, h); -} +{ return uniform_hash((unsigned char *)s.begin, s.end - s.begin, h); } -hash_func_t getHasher(const string& s){ +hash_func_t getHasher(const std::string& s){ if (s=="strings") return hashstring; else if(s=="all") return hashall; else{ - cerr << "Unknown hash function: " << s.c_str() << ". Exiting " << endl; - throw exception(); + std::cerr << "Unknown hash function: " << s.c_str() << ". Exiting " << std::endl; + throw std::exception(); } } diff --git a/vowpalwabbit/print.cc b/vowpalwabbit/print.cc index 8cfe5dd7..d0dc2765 100644 --- a/vowpalwabbit/print.cc +++ b/vowpalwabbit/print.cc @@ -4,14 +4,9 @@ #include "float.h" #include "reductions.h" -using namespace LEARNER; - namespace PRINT { - struct print{ - vw* all; - - }; + struct print{ vw* all; }; void print_feature(vw& all, float value, float& weight) { @@ -23,7 +18,7 @@ namespace PRINT cout << " "; } - void learn(print& p, base_learner& base, example& ec) + void learn(print& p, LEARNER::base_learner& base, example& ec) { label_data& ld = ec.l.simple; if (ld.label != FLT_MAX) @@ -46,7 +41,7 @@ namespace PRINT cout << endl; } - base_learner* setup(vw& all) + LEARNER::base_learner* setup(vw& all) { print& p = calloc_or_die<print>(); p.all = &all; @@ -54,7 +49,7 @@ namespace PRINT all.reg.weight_mask = (length << all.reg.stride_shift) - 1; all.reg.stride_shift = 0; - learner<print>& ret = init_learner(&p, learn, 1); + LEARNER::learner<print>& ret = init_learner(&p, learn, 1); return make_base(ret); } } diff --git a/vowpalwabbit/rand48.cc b/vowpalwabbit/rand48.cc index 4ea4e75e..4288e64d 100644 --- a/vowpalwabbit/rand48.cc +++ b/vowpalwabbit/rand48.cc @@ -1,8 +1,5 @@ //A quick implementation similar to drand48 for cross-platform compatibility #include <stdint.h> -#include <iostream> -using namespace std; - // // NB: the 'ULL' suffix is not part of the constant it is there to // prevent truncation of constant to (32-bit long) when compiling @@ -25,15 +22,9 @@ float merand48(uint64_t& initial) uint64_t v = c; -void msrand48(uint64_t initial) -{ - v = initial; -} +void msrand48(uint64_t initial) { v = initial; } -float frand48() -{ - return merand48(v); -} +float frand48() { return merand48(v); } float frand48_noadvance() { diff --git a/vowpalwabbit/search.cc b/vowpalwabbit/search.cc index 63ba16d8..762ef1e9 100644 --- a/vowpalwabbit/search.cc +++ b/vowpalwabbit/search.cc @@ -855,7 +855,7 @@ namespace Search { size_t sz = sizeof(size_t) + sizeof(ptag) + sizeof(int) + sizeof(size_t) + sizeof(size_t) + condition_on_cnt * (sizeof(ptag) + sizeof(action) + sizeof(char)); if (sz % 4 != 0) sz = 4 * (sz / 4 + 1); // make sure sz aligns to 4 so that uniform_hash does the right thing - unsigned char* item = (unsigned char*)calloc(sz, 1); + unsigned char* item = &calloc_or_die<unsigned char>(); unsigned char* here = item; *here = (unsigned char)sz; here += sizeof(size_t); *here = mytag; here += sizeof(ptag); @@ -2117,7 +2117,7 @@ namespace Search { void predictor::set_input_length(size_t input_length) { is_ldf = true; if (ec_alloced) ec = (example*)realloc(ec, input_length * sizeof(example)); - else ec = (example*)calloc(input_length, sizeof(example)); + else ec = calloc_or_die<example>(input_length); ec_cnt = input_length; ec_alloced = true; } diff --git a/vowpalwabbit/search_sequencetask.cc b/vowpalwabbit/search_sequencetask.cc index 24f97ad5..d92013b4 100644 --- a/vowpalwabbit/search_sequencetask.cc +++ b/vowpalwabbit/search_sequencetask.cc @@ -264,7 +264,7 @@ namespace SequenceTask_DemoLDF { // this is just to debug/show off how to do LD lab.costs.push_back(default_wclass); } - task_data* data = (task_data*)calloc(1, sizeof(task_data)); + task_data* data = &calloc_or_die<task_data>(); data->ldf_examples = ldf_examples; data->num_actions = num_actions; diff --git a/vowpalwabbit/sender.cc b/vowpalwabbit/sender.cc index 3b14131f..a9ded7e4 100644 --- a/vowpalwabbit/sender.cc +++ b/vowpalwabbit/sender.cc @@ -21,9 +21,6 @@ #include "network.h" #include "reductions.h" -using namespace std; -using namespace LEARNER; - namespace SENDER { struct sender { io_buf* buf; @@ -69,7 +66,7 @@ void receive_result(sender& s) return_simple_example(*(s.all), NULL, *ec); } - void learn(sender& s, base_learner& base, example& ec) + void learn(sender& s, LEARNER::base_learner& base, example& ec) { if (s.received_index + s.all->p->ring_size / 2 - 1 == s.sent_index) receive_result(s); @@ -81,8 +78,7 @@ void receive_result(sender& s) s.delay_ring[s.sent_index++ % s.all->p->ring_size] = &ec; } - void finish_example(vw& all, sender&, example& ec) -{} + void finish_example(vw& all, sender&, example& ec){} void end_examples(sender& s) { @@ -100,7 +96,7 @@ void end_examples(sender& s) delete s.buf; } - base_learner* setup(vw& all, po::variables_map& vm, vector<string> pairs) + LEARNER::base_learner* setup(vw& all, po::variables_map& vm, vector<string> pairs) { sender& s = calloc_or_die<sender>(); s.sd = -1; @@ -113,11 +109,10 @@ void end_examples(sender& s) s.all = &all; s.delay_ring = calloc_or_die<example*>(all.p->ring_size); - learner<sender>& l = init_learner(&s, learn, 1); + LEARNER::learner<sender>& l = init_learner(&s, learn, 1); l.set_finish(finish); l.set_finish_example(finish_example); l.set_end_examples(end_examples); return make_base(l); } - } diff --git a/vowpalwabbit/topk.cc b/vowpalwabbit/topk.cc index 8dd245f7..445bdb23 100644 --- a/vowpalwabbit/topk.cc +++ b/vowpalwabbit/topk.cc @@ -4,31 +4,23 @@ individual contributors. All rights reserved. Released under a BSD (revised) license as described in the file LICENSE. */ #include <float.h> -#include <math.h> -#include <stdio.h> #include <sstream> -#include <numeric> -#include <vector> #include <queue> #include "reductions.h" #include "vw.h" -using namespace std; -using namespace LEARNER; - -typedef pair<float, v_array<char> > scored_example; - -struct compare_scored_examples -{ +namespace TOPK { + typedef pair<float, v_array<char> > scored_example; + + struct compare_scored_examples + { bool operator()(scored_example const& a, scored_example const& b) const { - return a.first > b.first; + return a.first > b.first; } -}; - -namespace TOPK { - + }; + struct topk{ uint32_t B; //rec number priority_queue<scored_example, vector<scored_example>, compare_scored_examples > pr_queue; @@ -85,7 +77,7 @@ namespace TOPK { } template <bool is_learn> - void predict_or_learn(topk& d, base_learner& base, example& ec) + void predict_or_learn(topk& d, LEARNER::base_learner& base, example& ec) { if (example_is_newline(ec)) return;//do not predict newline @@ -102,7 +94,6 @@ namespace TOPK { d.pr_queue.pop(); d.pr_queue.push(make_pair(ec.pred.scalar, ec.tag)); } - } void finish_example(vw& all, topk& d, example& ec) @@ -111,16 +102,14 @@ namespace TOPK { VW::finish_example(all, &ec); } - base_learner* setup(vw& all, po::variables_map& vm) + LEARNER::base_learner* setup(vw& all, po::variables_map& vm) { topk& data = calloc_or_die<topk>(); - data.B = (uint32_t)vm["top"].as<size_t>(); - data.all = &all; - learner<topk>& l = init_learner(&data, all.l, predict_or_learn<true>, - predict_or_learn<false>); + LEARNER::learner<topk>& l = init_learner(&data, all.l, predict_or_learn<true>, + predict_or_learn<false>); l.set_finish_example(finish_example); return make_base(l); diff --git a/vowpalwabbit/unique_sort.cc b/vowpalwabbit/unique_sort.cc index c682cf63..1a323d2d 100644 --- a/vowpalwabbit/unique_sort.cc +++ b/vowpalwabbit/unique_sort.cc @@ -6,9 +6,7 @@ license as described in the file LICENSE. #include "global_data.h" int order_features(const void* first, const void* second) -{ - return ((feature*)first)->weight_index - ((feature*)second)->weight_index; -} +{ return ((feature*)first)->weight_index - ((feature*)second)->weight_index;} int order_audit_features(const void* first, const void* second) { |