Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/moses-smt/vowpal_wabbit.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/cg.cc
diff options
context:
space:
mode:
authorJohn Langford <jl@hunch.net>2011-07-12 04:39:07 +0400
committerJohn Langford <jl@hunch.net>2011-07-12 04:39:07 +0400
commit6690f79874927f5047461f82d74378ca65f87400 (patch)
tree5bc0d5bba1d4d4e8af14063e501968992df63220 /cg.cc
parent12196971e50b96c7b4d738e368ef1e15fdd857f3 (diff)
initial per-feature regularization
Diffstat (limited to 'cg.cc')
-rw-r--r--cg.cc71
1 files changed, 63 insertions, 8 deletions
diff --git a/cg.cc b/cg.cc
index b7c7bf33..88454969 100644
--- a/cg.cc
+++ b/cg.cc
@@ -168,6 +168,15 @@ void zero_derivative(regressor& reg)
weights[stride*i+1] = 0;
}
+void zero_preconditioner(regressor& reg)
+{//set derivative to 0.
+ uint32_t length = 1 << global.num_bits;
+ size_t stride = global.stride;
+ weight* weights = reg.weight_vectors[0];
+ for(uint32_t i = 0; i < length; i++)
+ weights[stride*i+3] = 0;
+}
+
double direction_magnitude(regressor& reg)
{//compute direction magnitude
double ret = 0.;
@@ -195,15 +204,25 @@ double derivative_diff_mag(regressor& reg, float* old_first_derivative)
}
double add_regularization(regressor& reg,float regularization)
-{//compute the derivative difference
+{
double ret = 0.;
uint32_t length = 1 << global.num_bits;
size_t stride = global.stride;
weight* weights = reg.weight_vectors[0];
- for(uint32_t i = 0; i < length; i++) {
- weights[stride*i+1] += regularization*weights[stride*i];
- ret += weights[stride*i]*weights[stride*i];
- }
+ if (reg.regularizers == NULL)
+ {
+ for(uint32_t i = 0; i < length; i++) {
+ weights[stride*i+1] += regularization*weights[stride*i];
+ ret += weights[stride*i]*weights[stride*i];
+ }
+ }
+ else
+ {
+ for(uint32_t i = 0; i < length; i++) {
+ weights[stride*i+1] += regularization*reg.regularizers[0][i]*weights[stride*i];
+ ret += reg.regularizers[0][i]*weights[stride*i]*weights[stride*i];
+ }
+ }
ret *= 0.5*regularization;
return ret;
}
@@ -213,10 +232,28 @@ void finalize_preconditioner(regressor& reg,float regularization)
uint32_t length = 1 << global.num_bits;
size_t stride = global.stride;
weight* weights = reg.weight_vectors[0];
+
+ if (reg.regularizers == NULL)
+ for(uint32_t i = 0; i < length; i++) {
+ weights[stride*i+3] += regularization;
+ if (weights[stride*i+3] > 0)
+ weights[stride*i+3] = 1. / weights[stride*i+3];
+ }
+ else
+ for(uint32_t i = 0; i < length; i++) {
+ weights[stride*i+3] += regularization + reg.regularizers[0][i];
+ if (weights[stride*i+3] > 0)
+ weights[stride*i+3] = 1. / weights[stride*i+3];
+ }
+}
+
+void preconditioner_to_regularizer(regressor& reg, float regularization)
+{
+ uint32_t length = 1 << global.num_bits;
+ size_t stride = global.stride;
+ weight* weights = reg.weight_vectors[0];
for(uint32_t i = 0; i < length; i++) {
- weights[stride*i+3] += regularization;
- if (weights[stride*i+3] > 0)
- weights[stride*i+3] = 1. / weights[stride*i+3];
+ reg.regularizers[0][i] = weights[stride*i+3] + regularization;
}
}
@@ -332,6 +369,10 @@ void setup_cg(gd_thread_params t)
cerr.precision(5);
}
+ bool output_regularizer = false;
+ if (global.per_feature_regularizer_output != "" || global.per_feature_regularizer_text != "")
+ output_regularizer = true;
+
while ( true )
{
if ((ec = get_example(thread_num)) != NULL)//semiblocking operation.
@@ -432,6 +473,8 @@ void setup_cg(gd_thread_params t)
}
else
current_pass++;
+ if (output_regularizer && current_pass == global.numpasses - 1)
+ zero_preconditioner(reg);
}
if (gradient_pass)
{
@@ -456,6 +499,12 @@ void setup_cg(gd_thread_params t)
float sd = reg.loss->second_derivative(predictions[example_number++],ld->label);
curvature += d_dot_x*d_dot_x*sd*ld->weight;
}
+ if (output_regularizer && current_pass == global.numpasses -1)
+ {
+ label_data* ld = (label_data*)ec->ld;
+ importance_weight_sum += ld->weight;
+ update_preconditioner(reg,ec);//w[3]
+ }
finish_example(ec);
}
else if (thread_done(thread_num))
@@ -481,6 +530,12 @@ void setup_cg(gd_thread_params t)
}
update_weight(reg,step_size);
}
+ if (output_regularizer)//need to accumulate and place the regularizer.
+ {
+ if(global.master_location != "")
+ accumulate(socks, reg, 3); //Accumulate preconditioner
+ preconditioner_to_regularizer(reg,global.regularization);
+ }
ftime(&t_end_global);
net_time += (int) (1000.0 * (t_end_global.time - t_start_global.time) + (t_end_global.millitm - t_start_global.millitm));
if (!global.quiet)