diff options
author | John Langford <jl@hunch.net> | 2014-01-23 19:43:09 +0400 |
---|---|---|
committer | John Langford <jl@hunch.net> | 2014-01-23 19:43:09 +0400 |
commit | f9a6b9bd2f40abf378481ae6ca92966dfdae2d5e (patch) | |
tree | bec1fcf318c70d61fed8955a4482bf43aadbf177 | |
parent | 238040fb9f5cfd408ea97c4349e8df005481241a (diff) | |
parent | cb8d90352a0007c8b133f535ad0f9111c7a58cee (diff) |
reconcile
-rw-r--r-- | vowpalwabbit/bs.cc | 6 | ||||
-rw-r--r-- | vowpalwabbit/csoaa.cc | 6 | ||||
-rw-r--r-- | vowpalwabbit/global_data.cc | 30 | ||||
-rw-r--r-- | vowpalwabbit/global_data.h | 1 | ||||
-rw-r--r-- | vowpalwabbit/io_buf.cc | 54 | ||||
-rw-r--r-- | vowpalwabbit/io_buf.h | 29 | ||||
-rw-r--r-- | vowpalwabbit/lda_core.cc | 12 | ||||
-rw-r--r-- | vowpalwabbit/parse_args.cc | 9 | ||||
-rw-r--r-- | vowpalwabbit/parser.cc | 12 | ||||
-rw-r--r-- | vowpalwabbit/simple_label.cc | 4 |
10 files changed, 94 insertions, 69 deletions
diff --git a/vowpalwabbit/bs.cc b/vowpalwabbit/bs.cc index aabe326a..035a5e28 100644 --- a/vowpalwabbit/bs.cc +++ b/vowpalwabbit/bs.cc @@ -121,11 +121,7 @@ namespace BS { ss << temp; ss << '\n'; ssize_t len = ss.str().size(); -#ifdef _WIN32 - ssize_t t = _write(f, ss.str().c_str(), (unsigned int)len); -#else - ssize_t t = write(f, ss.str().c_str(), (unsigned int)len); -#endif + ssize_t t = io_buf::write_file_or_socket(f, ss.str().c_str(), (unsigned int)len); if (t != len) cerr << "write error" << endl; } diff --git a/vowpalwabbit/csoaa.cc b/vowpalwabbit/csoaa.cc index 72db9c5d..0e97d7ac 100644 --- a/vowpalwabbit/csoaa.cc +++ b/vowpalwabbit/csoaa.cc @@ -1191,11 +1191,7 @@ namespace LabelDict { for (size_t i=0; i<all.final_prediction_sink.size(); i++) { int f = all.final_prediction_sink[i]; ssize_t t; -#ifdef _WIN32 - t = _write(f, temp, 1); -#else - t = write(f, temp, 1); -#endif + t = io_buf::write_file_or_socket(f, temp, 1); if (t != 1) std::cerr << "write error" << std::endl; } diff --git a/vowpalwabbit/global_data.cc b/vowpalwabbit/global_data.cc index f500f478..bf109bbb 100644 --- a/vowpalwabbit/global_data.cc +++ b/vowpalwabbit/global_data.cc @@ -31,7 +31,7 @@ size_t really_read(int sock, void* in, size_t count) { if ((r = #ifdef _WIN32 - _read(sock,buf,(unsigned int)(count-done)) + recv(sock,buf,(unsigned int)(count-done),0) #else read(sock,buf,(unsigned int)(count-done)) #endif @@ -65,7 +65,7 @@ void send_prediction(int sock, global_prediction p) { if ( #ifdef _WIN32 - _write(sock, &p, sizeof(p)) + send(sock, reinterpret_cast<const char*>(&p), sizeof(p), 0) #else write(sock, &p, sizeof(p)) #endif @@ -106,11 +106,7 @@ void print_result(int f, float res, float weight, v_array<char> tag) print_tag(ss, tag); ss << '\n'; ssize_t len = ss.str().size(); -#ifdef _WIN32 - ssize_t t = _write(f, ss.str().c_str(), (unsigned int)len); -#else - ssize_t t = write(f, ss.str().c_str(), (unsigned int)len); -#endif + ssize_t t = io_buf::write_file_or_socket(f, 
ss.str().c_str(), (unsigned int)len); if (t != len) { cerr << "write error" << endl; @@ -128,11 +124,7 @@ void print_raw_text(int f, string s, v_array<char> tag) print_tag (ss, tag); ss << '\n'; ssize_t len = ss.str().size(); -#ifdef _WIN32 - ssize_t t = _write(f, ss.str().c_str(), (unsigned int)len); -#else - ssize_t t = write(f, ss.str().c_str(), (unsigned int)len); -#endif + ssize_t t = io_buf::write_file_or_socket(f, ss.str().c_str(), (unsigned int)len); if (t != len) { cerr << "write error" << endl; @@ -156,11 +148,7 @@ void active_print_result(int f, float res, float weight, v_array<char> tag) } ss << '\n'; ssize_t len = ss.str().size(); -#ifdef _WIN32 - ssize_t t = _write(f, ss.str().c_str(), (unsigned int)len); -#else - ssize_t t = write(f, ss.str().c_str(), (unsigned int)len); -#endif + ssize_t t = io_buf::write_file_or_socket(f, ss.str().c_str(), (unsigned int)len); if (t != len) cerr << "write error" << endl; } @@ -180,11 +168,8 @@ void print_lda_result(vw& all, int f, float* res, float weight, v_array<char> ta print_tag(ss, tag); ss << '\n'; ssize_t len = ss.str().size(); -#ifdef _WIN32 - ssize_t t = _write(f, ss.str().c_str(), (unsigned int)len); -#else - ssize_t t = write(f, ss.str().c_str(), (unsigned int)len); -#endif + ssize_t t = io_buf::write_file_or_socket(f, ss.str().c_str(), (unsigned int)len); + if (t != len) cerr << "write error" << endl; } @@ -253,6 +238,7 @@ vw::vw() lda_alpha = 0.1f; lda_rho = 0.1f; lda_D = 10000.; + lda_epsilon = 0.001; minibatch = 1; span_server = ""; m = 15; diff --git a/vowpalwabbit/global_data.h b/vowpalwabbit/global_data.h index 0b78ef42..40e12718 100644 --- a/vowpalwabbit/global_data.h +++ b/vowpalwabbit/global_data.h @@ -220,6 +220,7 @@ struct vw { float lda_alpha; float lda_rho; float lda_D; + float lda_epsilon; std::string text_regressor_name; std::string inv_hash_regressor_name; diff --git a/vowpalwabbit/io_buf.cc b/vowpalwabbit/io_buf.cc index f633ce1f..4dbf3cf1 100644 --- a/vowpalwabbit/io_buf.cc +++ 
b/vowpalwabbit/io_buf.cc @@ -7,6 +7,10 @@ license as described in the file LICENSE. #include "io_buf.h" +#ifdef WIN32 +#include <winsock2.h> +#endif + size_t buf_read(io_buf &i, char* &pointer, size_t n) {//return a pointer to the next n bytes. n must be smaller than the maximum size. if (i.space.end + n <= i.endloaded) @@ -52,7 +56,7 @@ bool isbinary(io_buf &i) { size_t readto(io_buf &i, char* &pointer, char terminal) {//Return a pointer to the bytes before the terminal. Must be less than the buffer size. pointer = i.space.end; - while (pointer != i.endloaded && *pointer != terminal) + while (pointer < i.endloaded && *pointer != terminal) pointer++; if (pointer != i.endloaded) { @@ -104,3 +108,51 @@ void buf_write(io_buf &o, char* &pointer, size_t n) buf_write (o, pointer,n); } } + +bool io_buf::is_socket(int f) +{ + // this appears to work in practice, but could probably be done in a cleaner fashion + const int _nhandle = 32; + return f >= _nhandle; +} + +ssize_t io_buf::read_file_or_socket(int f, void* buf, size_t nbytes) { +#ifdef _WIN32 + if (is_socket(f)) { + return recv(f, reinterpret_cast<char*>(buf), static_cast<int>(nbytes), 0); + } + else { + return _read(f, buf, (unsigned int)nbytes); + } +#else + return read(f, buf, (unsigned int)nbytes); +#endif +} + +ssize_t io_buf::write_file_or_socket(int f, const void* buf, size_t nbytes) +{ +#ifdef _WIN32 + if (is_socket(f)) { + return send(f, reinterpret_cast<const char*>(buf), static_cast<int>(nbytes), 0); + } + else { + return _write(f, buf, (unsigned int)nbytes); + } +#else + return write(f, buf, (unsigned int)nbytes); +#endif +} + +void io_buf::close_file_or_socket(int f) +{ +#ifdef _WIN32 + if (io_buf::is_socket(f)) { + closesocket(f); + } + else { + _close(f); + } +#else + close(f); +#endif +} diff --git a/vowpalwabbit/io_buf.h b/vowpalwabbit/io_buf.h index e8b8157c..8b62ca76 100644 --- a/vowpalwabbit/io_buf.h +++ b/vowpalwabbit/io_buf.h @@ -120,13 +120,11 @@ class io_buf { void set(char *p){space.end = 
p;} virtual ssize_t read_file(int f, void* buf, size_t nbytes){ -#ifdef _WIN32 - return _read(f, buf, (unsigned int)nbytes); -#else - return read(f, buf, (unsigned int)nbytes); -#endif + return read_file_or_socket(f, buf, nbytes); } + static ssize_t read_file_or_socket(int f, void* buf, size_t nbytes); + size_t fill(int f) { if (space.end_array - endloaded == 0) { @@ -144,15 +142,12 @@ class io_buf { return 0; } - virtual ssize_t write_file(int f, const void* buf, size_t nbytes) - { -#ifdef _WIN32 - return _write(f, buf, (unsigned int)nbytes); -#else - return write(f, buf, (unsigned int)nbytes); -#endif + virtual ssize_t write_file(int f, const void* buf, size_t nbytes) { + return write_file_or_socket(f, buf, nbytes); } + static ssize_t write_file_or_socket(int f, const void* buf, size_t nbytes); + virtual void flush() { if (write_file(files[0], space.begin, space.size()) != (int) space.size()) std::cerr << "error, failed to write example\n"; @@ -160,19 +155,19 @@ class io_buf { virtual bool close_file(){ if(files.size()>0){ -#ifdef _WIN32 - _close(files.pop()); -#else - close(files.pop()); -#endif + close_file_or_socket(files.pop()); return true; } return false; } + static void close_file_or_socket(int f); + void close_files(){ while(close_file()); } + + static bool is_socket(int f); }; void buf_write(io_buf &o, char* &pointer, size_t n); diff --git a/vowpalwabbit/lda_core.cc b/vowpalwabbit/lda_core.cc index c70c6551..fcf79afc 100644 --- a/vowpalwabbit/lda_core.cc +++ b/vowpalwabbit/lda_core.cc @@ -496,7 +496,7 @@ v_array<float> old_gamma; for (size_t k =0; k<all.lda; k++)
new_gamma[k] = new_gamma[k]*v[k]+all.lda_alpha;
}
- while (average_diff(all, old_gamma.begin, new_gamma.begin) > 0.001);
+ while (average_diff(all, old_gamma.begin, new_gamma.begin) > all.lda_epsilon);
ec->topic_predictions.erase();
ec->topic_predictions.resize(all.lda);
@@ -580,6 +580,15 @@ void save_load(lda& l, io_buf& model_file, bool read, bool text) void learn_batch(lda& l)
{
+ if (l.sorted_features.empty()) {
+ // This can happen when the socket connection is dropped by the client.
+ // If l.sorted_features is empty, then l.sorted_features[0] does not
+ // exist, so we must not take its address at the start of the for loops
+ // below. Since there seems to be nothing left to do in this case, we
+ // just return.
+ return;
+ }
+
float eta = -1;
float minuseta = -1;
@@ -754,6 +763,7 @@ learner* setup(vw&all, vector<string>&opts, po::variables_map& vm) ("lda_alpha", po::value<float>(&all.lda_alpha), "Prior on sparsity of per-document topic weights")
("lda_rho", po::value<float>(&all.lda_rho), "Prior on sparsity of topic distributions")
("lda_D", po::value<float>(&all.lda_D), "Number of documents")
+ ("lda_epsilon", po::value<float>(&all.lda_epsilon), "Loop convergence threshold")
("minibatch", po::value<size_t>(&all.minibatch), "Minibatch size, for LDA");
po::parsed_options parsed = po::command_line_parser(opts).
diff --git a/vowpalwabbit/parse_args.cc b/vowpalwabbit/parse_args.cc index fdf0681d..ae6f7b4d 100644 --- a/vowpalwabbit/parse_args.cc +++ b/vowpalwabbit/parse_args.cc @@ -415,7 +415,8 @@ vw* parse_args(int argc, char *argv[]) if( all->normalized_updates ) all->reg.stride *= 2; if(!vm.count("learning_rate") && !vm.count("l") && !(all->adaptive && all->normalized_updates)) - all->eta = 10; //default learning rate to 10 for non default update rule + if (all->lda == 0) + all->eta = 10; //default learning rate to 10 for non default update rule //if not using normalized or adaptive, default initial_t to 1 instead of 0 if(!all->adaptive && !all->normalized_updates && !vm.count("initial_t")) { @@ -1094,11 +1095,7 @@ namespace VW { free(all.options_from_file_argv); for (size_t i = 0; i < all.final_prediction_sink.size(); i++) if (all.final_prediction_sink[i] != 1) -#ifdef _WIN32 - _close(all.final_prediction_sink[i]); -#else - close(all.final_prediction_sink[i]); -#endif + io_buf::close_file_or_socket(all.final_prediction_sink[i]); all.final_prediction_sink.delete_v(); delete all.loss; delete &all; diff --git a/vowpalwabbit/parser.cc b/vowpalwabbit/parser.cc index 6bee8cce..2c55f0f9 100644 --- a/vowpalwabbit/parser.cc +++ b/vowpalwabbit/parser.cc @@ -239,11 +239,7 @@ void reset_source(vw& all, size_t numbits) { int fd = input->files.pop(); if (!member(all.final_prediction_sink, (size_t) fd)) -#ifdef _WIN32 - _close(fd); -#else - close(fd); -#endif + io_buf::close_file_or_socket(fd); } input->open_file(all.p->output->finalname.begin, all.stdin_off, io_buf::READ); //pushing is merged into open_file all.p->reader = read_cached_features; @@ -259,11 +255,7 @@ void reset_source(vw& all, size_t numbits) mutex_unlock(&all.p->output_lock); // close socket, erase final prediction sink and socket -#ifdef _WIN32 - _close(all.p->input->files[0]); -#else - close(all.p->input->files[0]); -#endif + io_buf::close_file_or_socket(all.p->input->files[0]); all.final_prediction_sink.erase(); 
all.p->input->files.erase(); diff --git a/vowpalwabbit/simple_label.cc b/vowpalwabbit/simple_label.cc index fa25b050..d3646ec9 100644 --- a/vowpalwabbit/simple_label.cc +++ b/vowpalwabbit/simple_label.cc @@ -216,8 +216,8 @@ void output_and_account_example(vw& all, example& ec) for (size_t i = 0; i<all.final_prediction_sink.size(); i++) { int f = (int)all.final_prediction_sink[i]; - if(all.active) - active_print_result(f, ec.final_prediction, ai, ec.tag); + if(all.active && all.lda == 0) + active_print_result(f, ec->final_prediction, ai, ec->tag); else if (all.lda > 0) print_lda_result(all, f,ec.topic_predictions.begin,0.,ec.tag); else |