Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/moses-smt/vowpal_wabbit.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJohn Langford <jl@hunch.net>2014-01-23 19:43:09 +0400
committerJohn Langford <jl@hunch.net>2014-01-23 19:43:09 +0400
commitf9a6b9bd2f40abf378481ae6ca92966dfdae2d5e (patch)
treebec1fcf318c70d61fed8955a4482bf43aadbf177
parent238040fb9f5cfd408ea97c4349e8df005481241a (diff)
parentcb8d90352a0007c8b133f535ad0f9111c7a58cee (diff)
reconcile
-rw-r--r--vowpalwabbit/bs.cc6
-rw-r--r--vowpalwabbit/csoaa.cc6
-rw-r--r--vowpalwabbit/global_data.cc30
-rw-r--r--vowpalwabbit/global_data.h1
-rw-r--r--vowpalwabbit/io_buf.cc54
-rw-r--r--vowpalwabbit/io_buf.h29
-rw-r--r--vowpalwabbit/lda_core.cc12
-rw-r--r--vowpalwabbit/parse_args.cc9
-rw-r--r--vowpalwabbit/parser.cc12
-rw-r--r--vowpalwabbit/simple_label.cc4
10 files changed, 94 insertions, 69 deletions
diff --git a/vowpalwabbit/bs.cc b/vowpalwabbit/bs.cc
index aabe326a..035a5e28 100644
--- a/vowpalwabbit/bs.cc
+++ b/vowpalwabbit/bs.cc
@@ -121,11 +121,7 @@ namespace BS {
ss << temp;
ss << '\n';
ssize_t len = ss.str().size();
-#ifdef _WIN32
- ssize_t t = _write(f, ss.str().c_str(), (unsigned int)len);
-#else
- ssize_t t = write(f, ss.str().c_str(), (unsigned int)len);
-#endif
+ ssize_t t = io_buf::write_file_or_socket(f, ss.str().c_str(), (unsigned int)len);
if (t != len)
cerr << "write error" << endl;
}
diff --git a/vowpalwabbit/csoaa.cc b/vowpalwabbit/csoaa.cc
index 72db9c5d..0e97d7ac 100644
--- a/vowpalwabbit/csoaa.cc
+++ b/vowpalwabbit/csoaa.cc
@@ -1191,11 +1191,7 @@ namespace LabelDict {
for (size_t i=0; i<all.final_prediction_sink.size(); i++) {
int f = all.final_prediction_sink[i];
ssize_t t;
-#ifdef _WIN32
- t = _write(f, temp, 1);
-#else
- t = write(f, temp, 1);
-#endif
+ t = io_buf::write_file_or_socket(f, temp, 1);
if (t != 1)
std::cerr << "write error" << std::endl;
}
diff --git a/vowpalwabbit/global_data.cc b/vowpalwabbit/global_data.cc
index f500f478..bf109bbb 100644
--- a/vowpalwabbit/global_data.cc
+++ b/vowpalwabbit/global_data.cc
@@ -31,7 +31,7 @@ size_t really_read(int sock, void* in, size_t count)
{
if ((r =
#ifdef _WIN32
- _read(sock,buf,(unsigned int)(count-done))
+ recv(sock,buf,(unsigned int)(count-done),0)
#else
read(sock,buf,(unsigned int)(count-done))
#endif
@@ -65,7 +65,7 @@ void send_prediction(int sock, global_prediction p)
{
if (
#ifdef _WIN32
- _write(sock, &p, sizeof(p))
+ send(sock, reinterpret_cast<const char*>(&p), sizeof(p), 0)
#else
write(sock, &p, sizeof(p))
#endif
@@ -106,11 +106,7 @@ void print_result(int f, float res, float weight, v_array<char> tag)
print_tag(ss, tag);
ss << '\n';
ssize_t len = ss.str().size();
-#ifdef _WIN32
- ssize_t t = _write(f, ss.str().c_str(), (unsigned int)len);
-#else
- ssize_t t = write(f, ss.str().c_str(), (unsigned int)len);
-#endif
+ ssize_t t = io_buf::write_file_or_socket(f, ss.str().c_str(), (unsigned int)len);
if (t != len)
{
cerr << "write error" << endl;
@@ -128,11 +124,7 @@ void print_raw_text(int f, string s, v_array<char> tag)
print_tag (ss, tag);
ss << '\n';
ssize_t len = ss.str().size();
-#ifdef _WIN32
- ssize_t t = _write(f, ss.str().c_str(), (unsigned int)len);
-#else
- ssize_t t = write(f, ss.str().c_str(), (unsigned int)len);
-#endif
+ ssize_t t = io_buf::write_file_or_socket(f, ss.str().c_str(), (unsigned int)len);
if (t != len)
{
cerr << "write error" << endl;
@@ -156,11 +148,7 @@ void active_print_result(int f, float res, float weight, v_array<char> tag)
}
ss << '\n';
ssize_t len = ss.str().size();
-#ifdef _WIN32
- ssize_t t = _write(f, ss.str().c_str(), (unsigned int)len);
-#else
- ssize_t t = write(f, ss.str().c_str(), (unsigned int)len);
-#endif
+ ssize_t t = io_buf::write_file_or_socket(f, ss.str().c_str(), (unsigned int)len);
if (t != len)
cerr << "write error" << endl;
}
@@ -180,11 +168,8 @@ void print_lda_result(vw& all, int f, float* res, float weight, v_array<char> ta
print_tag(ss, tag);
ss << '\n';
ssize_t len = ss.str().size();
-#ifdef _WIN32
- ssize_t t = _write(f, ss.str().c_str(), (unsigned int)len);
-#else
- ssize_t t = write(f, ss.str().c_str(), (unsigned int)len);
-#endif
+ ssize_t t = io_buf::write_file_or_socket(f, ss.str().c_str(), (unsigned int)len);
+
if (t != len)
cerr << "write error" << endl;
}
@@ -253,6 +238,7 @@ vw::vw()
lda_alpha = 0.1f;
lda_rho = 0.1f;
lda_D = 10000.;
+ lda_epsilon = 0.001;
minibatch = 1;
span_server = "";
m = 15;
diff --git a/vowpalwabbit/global_data.h b/vowpalwabbit/global_data.h
index 0b78ef42..40e12718 100644
--- a/vowpalwabbit/global_data.h
+++ b/vowpalwabbit/global_data.h
@@ -220,6 +220,7 @@ struct vw {
float lda_alpha;
float lda_rho;
float lda_D;
+ float lda_epsilon;
std::string text_regressor_name;
std::string inv_hash_regressor_name;
diff --git a/vowpalwabbit/io_buf.cc b/vowpalwabbit/io_buf.cc
index f633ce1f..4dbf3cf1 100644
--- a/vowpalwabbit/io_buf.cc
+++ b/vowpalwabbit/io_buf.cc
@@ -7,6 +7,10 @@ license as described in the file LICENSE.
#include "io_buf.h"
+#ifdef _WIN32
+#include <winsock2.h> // declares recv/send/closesocket used by the *_file_or_socket helpers in this file
+#endif
+
size_t buf_read(io_buf &i, char* &pointer, size_t n)
{//return a pointer to the next n bytes. n must be smaller than the maximum size.
if (i.space.end + n <= i.endloaded)
@@ -52,7 +56,7 @@ bool isbinary(io_buf &i) {
size_t readto(io_buf &i, char* &pointer, char terminal)
{//Return a pointer to the bytes before the terminal. Must be less than the buffer size.
pointer = i.space.end;
- while (pointer != i.endloaded && *pointer != terminal)
+ while (pointer < i.endloaded && *pointer != terminal)
pointer++;
if (pointer != i.endloaded)
{
@@ -104,3 +108,51 @@ void buf_write(io_buf &o, char* &pointer, size_t n)
buf_write (o, pointer,n);
}
}
+
+bool io_buf::is_socket(int f)
+{
+ // Heuristic: any descriptor value >= 32 is treated as a socket. This appears to work in practice, but could probably be done in a cleaner fashion.
+ const int _nhandle = 32;
+ return f >= _nhandle;
+}
+
+ssize_t io_buf::read_file_or_socket(int f, void* buf, size_t nbytes) { // requests nbytes from f into buf and returns the underlying call's result
+#ifdef _WIN32
+ if (is_socket(f)) { // Windows: descriptors classified as sockets go through recv()
+ return recv(f, reinterpret_cast<char*>(buf), static_cast<int>(nbytes), 0);
+ }
+ else { // Windows: every other descriptor goes through the CRT _read()
+ return _read(f, buf, (unsigned int)nbytes);
+ }
+#else
+ return read(f, buf, (unsigned int)nbytes); // non-Windows: read() is used for every descriptor
+#endif
+}
+
+ssize_t io_buf::write_file_or_socket(int f, const void* buf, size_t nbytes)
+{ // writes nbytes from buf to f and returns the underlying call's result
+#ifdef _WIN32
+ if (is_socket(f)) { // Windows: descriptors classified as sockets go through send()
+ return send(f, reinterpret_cast<const char*>(buf), static_cast<int>(nbytes), 0);
+ }
+ else { // Windows: every other descriptor goes through the CRT _write()
+ return _write(f, buf, (unsigned int)nbytes);
+ }
+#else
+ return write(f, buf, (unsigned int)nbytes); // non-Windows: write() is used for every descriptor
+#endif
+}
+
+void io_buf::close_file_or_socket(int f)
+{ // closes f; the return value of the underlying close call is discarded
+#ifdef _WIN32
+ if (io_buf::is_socket(f)) { // Windows: descriptors classified as sockets are released with closesocket()
+ closesocket(f);
+ }
+ else { // Windows: every other descriptor is released with the CRT _close()
+ _close(f);
+ }
+#else
+ close(f); // non-Windows: close() is used for every descriptor
+#endif
+}
diff --git a/vowpalwabbit/io_buf.h b/vowpalwabbit/io_buf.h
index e8b8157c..8b62ca76 100644
--- a/vowpalwabbit/io_buf.h
+++ b/vowpalwabbit/io_buf.h
@@ -120,13 +120,11 @@ class io_buf {
void set(char *p){space.end = p;}
virtual ssize_t read_file(int f, void* buf, size_t nbytes){
-#ifdef _WIN32
- return _read(f, buf, (unsigned int)nbytes);
-#else
- return read(f, buf, (unsigned int)nbytes);
-#endif
+ return read_file_or_socket(f, buf, nbytes);
}
+ static ssize_t read_file_or_socket(int f, void* buf, size_t nbytes);
+
size_t fill(int f) {
if (space.end_array - endloaded == 0)
{
@@ -144,15 +142,12 @@ class io_buf {
return 0;
}
- virtual ssize_t write_file(int f, const void* buf, size_t nbytes)
- {
-#ifdef _WIN32
- return _write(f, buf, (unsigned int)nbytes);
-#else
- return write(f, buf, (unsigned int)nbytes);
-#endif
+ virtual ssize_t write_file(int f, const void* buf, size_t nbytes) {
+ return write_file_or_socket(f, buf, nbytes);
}
+ static ssize_t write_file_or_socket(int f, const void* buf, size_t nbytes);
+
virtual void flush() {
if (write_file(files[0], space.begin, space.size()) != (int) space.size())
std::cerr << "error, failed to write example\n";
@@ -160,19 +155,19 @@ class io_buf {
virtual bool close_file(){
if(files.size()>0){
-#ifdef _WIN32
- _close(files.pop());
-#else
- close(files.pop());
-#endif
+ close_file_or_socket(files.pop());
return true;
}
return false;
}
+ static void close_file_or_socket(int f);
+
void close_files(){
while(close_file());
}
+
+ static bool is_socket(int f);
};
void buf_write(io_buf &o, char* &pointer, size_t n);
diff --git a/vowpalwabbit/lda_core.cc b/vowpalwabbit/lda_core.cc
index c70c6551..fcf79afc 100644
--- a/vowpalwabbit/lda_core.cc
+++ b/vowpalwabbit/lda_core.cc
@@ -496,7 +496,7 @@ v_array<float> old_gamma;
for (size_t k =0; k<all.lda; k++)
new_gamma[k] = new_gamma[k]*v[k]+all.lda_alpha;
}
- while (average_diff(all, old_gamma.begin, new_gamma.begin) > 0.001);
+ while (average_diff(all, old_gamma.begin, new_gamma.begin) > all.lda_epsilon);
ec->topic_predictions.erase();
ec->topic_predictions.resize(all.lda);
@@ -580,6 +580,15 @@ void save_load(lda& l, io_buf& model_file, bool read, bool text)
void learn_batch(lda& l)
{
+ if (l.sorted_features.empty()) {
+ // This can happen when the socket connection is dropped by the client.
+ // If l.sorted_features is empty, then l.sorted_features[0] does not
+ // exist, so we should not try to take its address in the beginning of
+ // the for loops down there. Since it seems that there's not much to
+ // do in this case, we just return.
+ return;
+ }
+
float eta = -1;
float minuseta = -1;
@@ -754,6 +763,7 @@ learner* setup(vw&all, vector<string>&opts, po::variables_map& vm)
("lda_alpha", po::value<float>(&all.lda_alpha), "Prior on sparsity of per-document topic weights")
("lda_rho", po::value<float>(&all.lda_rho), "Prior on sparsity of topic distributions")
("lda_D", po::value<float>(&all.lda_D), "Number of documents")
+ ("lda_epsilon", po::value<float>(&all.lda_epsilon), "Loop convergence threshold")
("minibatch", po::value<size_t>(&all.minibatch), "Minibatch size, for LDA");
po::parsed_options parsed = po::command_line_parser(opts).
diff --git a/vowpalwabbit/parse_args.cc b/vowpalwabbit/parse_args.cc
index fdf0681d..ae6f7b4d 100644
--- a/vowpalwabbit/parse_args.cc
+++ b/vowpalwabbit/parse_args.cc
@@ -415,7 +415,8 @@ vw* parse_args(int argc, char *argv[])
if( all->normalized_updates ) all->reg.stride *= 2;
if(!vm.count("learning_rate") && !vm.count("l") && !(all->adaptive && all->normalized_updates))
- all->eta = 10; //default learning rate to 10 for non default update rule
+ if (all->lda == 0)
+ all->eta = 10; //default learning rate to 10 for non default update rule
//if not using normalized or adaptive, default initial_t to 1 instead of 0
if(!all->adaptive && !all->normalized_updates && !vm.count("initial_t")) {
@@ -1094,11 +1095,7 @@ namespace VW {
free(all.options_from_file_argv);
for (size_t i = 0; i < all.final_prediction_sink.size(); i++)
if (all.final_prediction_sink[i] != 1)
-#ifdef _WIN32
- _close(all.final_prediction_sink[i]);
-#else
- close(all.final_prediction_sink[i]);
-#endif
+ io_buf::close_file_or_socket(all.final_prediction_sink[i]);
all.final_prediction_sink.delete_v();
delete all.loss;
delete &all;
diff --git a/vowpalwabbit/parser.cc b/vowpalwabbit/parser.cc
index 6bee8cce..2c55f0f9 100644
--- a/vowpalwabbit/parser.cc
+++ b/vowpalwabbit/parser.cc
@@ -239,11 +239,7 @@ void reset_source(vw& all, size_t numbits)
{
int fd = input->files.pop();
if (!member(all.final_prediction_sink, (size_t) fd))
-#ifdef _WIN32
- _close(fd);
-#else
- close(fd);
-#endif
+ io_buf::close_file_or_socket(fd);
}
input->open_file(all.p->output->finalname.begin, all.stdin_off, io_buf::READ); //pushing is merged into open_file
all.p->reader = read_cached_features;
@@ -259,11 +255,7 @@ void reset_source(vw& all, size_t numbits)
mutex_unlock(&all.p->output_lock);
// close socket, erase final prediction sink and socket
-#ifdef _WIN32
- _close(all.p->input->files[0]);
-#else
- close(all.p->input->files[0]);
-#endif
+ io_buf::close_file_or_socket(all.p->input->files[0]);
all.final_prediction_sink.erase();
all.p->input->files.erase();
diff --git a/vowpalwabbit/simple_label.cc b/vowpalwabbit/simple_label.cc
index fa25b050..d3646ec9 100644
--- a/vowpalwabbit/simple_label.cc
+++ b/vowpalwabbit/simple_label.cc
@@ -216,8 +216,8 @@ void output_and_account_example(vw& all, example& ec)
for (size_t i = 0; i<all.final_prediction_sink.size(); i++)
{
int f = (int)all.final_prediction_sink[i];
- if(all.active)
- active_print_result(f, ec.final_prediction, ai, ec.tag);
+ if(all.active && all.lda == 0)
+ active_print_result(f, ec.final_prediction, ai, ec.tag); // ec is an example& in this function, so members are accessed with '.'
else if (all.lda > 0)
print_lda_result(all, f,ec.topic_predictions.begin,0.,ec.tag);
else