Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/moses-smt/vowpal_wabbit.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author Jake Hofman <jhofman+github@gmail.com> 2011-08-03 19:13:27 +0400
committer John <jl@hunch.net> 2011-08-04 04:07:50 +0400
commit f7c683ee47ed080521fd0a4c90169debb3df76cd (patch)
tree e84ad9daabe53f427ecf9633ce756669e3eff4e6 /parse_regressor.cc
parent 6734a953e719f5ca6e0cc942c4848af066c2c35c (diff)
rank and lda now saved in regressor
Diffstat (limited to 'parse_regressor.cc')
-rw-r--r-- parse_regressor.cc | 43
1 file changed, 43 insertions(+), 0 deletions(-)
diff --git a/parse_regressor.cc b/parse_regressor.cc
index bbdbf04a..95e36746 100644
--- a/parse_regressor.cc
+++ b/parse_regressor.cc
@@ -137,6 +137,39 @@ void read_vector(const char* file, regressor& r, bool& initialized, bool reg_vec
string temp(pair, 2);
local_pairs.push_back(temp);
}
+
+
+ size_t local_rank;
+ source.read((char*)&local_rank, sizeof(local_rank));
+ size_t local_lda;
+ source.read((char*)&local_lda, sizeof(local_lda));
+ if (!initialized)
+ {
+ global.rank = local_rank;
+ global.lda = local_lda;
+ //initialized = true;
+ }
+ else
+ {
+ cout << "can't combine regressors" << endl;
+ exit(1);
+ }
+
+ if (global.rank > 0)
+ {
+ float temp = ceilf(logf((float)(global.rank*2+1)) / logf (2.f));
+ global.stride = 1 << (int) temp;
+ global.random_weights = true;
+ }
+
+ if (global.lda > 0)
+ {
+ // par->sort_features = true;
+ float temp = ceilf(logf((float)(global.lda*2+1)) / logf (2.f));
+ global.stride = 1 << (int) temp;
+ global.random_weights = true;
+ }
+
if (!initialized)
{
global.pairs = local_pairs;
@@ -170,6 +203,7 @@ void read_vector(const char* file, regressor& r, bool& initialized, bool reg_vec
cout << "can't combine sources with different ngram features!" << endl;
exit(1);
}
+
size_t stride = global.stride;
while (source.good())
{
@@ -280,6 +314,9 @@ void dump_regressor(string reg_name, regressor &r, bool as_text, bool reg_vector
for (vector<string>::iterator i = global.pairs.begin(); i != global.pairs.end();i++)
io_temp.write_file(f,i->c_str(),2);
+ io_temp.write_file(f,(char*)&global.rank, sizeof(global.rank));
+ io_temp.write_file(f,(char*)&global.lda, sizeof(global.lda));
+
io_temp.write_file(f,(char*)&global.ngram, sizeof(global.ngram));
io_temp.write_file(f,(char*)&global.skips, sizeof(global.skips));
}
@@ -303,6 +340,12 @@ void dump_regressor(string reg_name, regressor &r, bool as_text, bool reg_vector
}
len = sprintf(buff, "ngram:%d skips:%d\nindex:weight pairs:\n", (int)global.ngram, (int)global.skips);
io_temp.write_file(f, buff, len);
+
+ len = sprintf(buff, "rank:%d\n", (int)global.rank);
+ io_temp.write_file(f, buff, len);
+
+ len = sprintf(buff, "lda:%d\n", (int)global.lda);
+ io_temp.write_file(f, buff, len);
}
uint32_t length = 1 << global.num_bits;