Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/moses-smt/vowpal_wabbit.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
path: root/cg.cc
diff options
context:
space:
mode:
author John Langford <jl@hunch.net> 2011-07-18 10:57:40 +0400
committer John Langford <jl@hunch.net> 2011-07-18 10:57:40 +0400
commit 280b23d4a1ad20bfdc90eaab7af2a087b449e377 (patch)
tree 7a3740f0091c4c768ceb3a333840e5bf78d3a0aa /cg.cc
parent 6cb47949f15cfad879976775ae20c733909e9518 (diff)
improved per-feature regularization from Olivier
Diffstat (limited to 'cg.cc')
-rw-r--r-- cg.cc 32
1 files changed, 18 insertions, 14 deletions
diff --git a/cg.cc b/cg.cc
index 6b1897e8..35c879f1 100644
--- a/cg.cc
+++ b/cg.cc
@@ -189,12 +189,12 @@ double regularizer_direction_magnitude(regressor& reg, float regularizer)
weight* weights = reg.weight_vectors[0];
if (reg.regularizers == NULL)
for(uint32_t i = 0; i < length; i++)
- ret += weights[stride*i+2]*weights[stride*i+2];
+ ret += regularizer*weights[stride*i+2]*weights[stride*i+2];
else
- for(uint32_t i = 0; i < length; i++)
- ret += reg.regularizers[0][i]*weights[stride*i+2]*weights[stride*i+2];
+ for(uint32_t i = 0; i < length; i++)
+ ret += reg.regularizers[0][2*i]*weights[stride*i+2]*weights[stride*i+2];
- return ret*regularizer;
+ return ret;
}
double derivative_diff_mag(regressor& reg, float* old_first_derivative)
@@ -221,17 +221,17 @@ double add_regularization(regressor& reg,float regularization)
{
for(uint32_t i = 0; i < length; i++) {
weights[stride*i+1] += regularization*weights[stride*i];
- ret += weights[stride*i]*weights[stride*i];
+ ret += 0.5*regularization*weights[stride*i]*weights[stride*i];
}
}
else
{
for(uint32_t i = 0; i < length; i++) {
- weights[stride*i+1] += regularization*reg.regularizers[0][i]*weights[stride*i];
- ret += reg.regularizers[0][i]*weights[stride*i]*weights[stride*i];
+ weight delta_weight = weights[stride*i] - reg.regularizers[0][2*i+1];
+ weights[stride*i+1] += reg.regularizers[0][2*i]*delta_weight;
+ ret += 0.5*reg.regularizers[0][2*i]*delta_weight*delta_weight;
}
}
- ret *= 0.5*regularization;
return ret;
}
@@ -249,7 +249,7 @@ void finalize_preconditioner(regressor& reg,float regularization)
}
else
for(uint32_t i = 0; i < length; i++) {
- weights[stride*i+3] += regularization*reg.regularizers[0][i];
+ weights[stride*i+3] += reg.regularizers[0][2*i];
if (weights[stride*i+3] > 0)
weights[stride*i+3] = 1. / weights[stride*i+3];
}
@@ -267,7 +267,7 @@ void preconditioner_to_regularizer(regressor& reg, float regularization)
for (size_t i = 0; i < num_threads; i++)
{
if (reg.regularizers != NULL)
- reg.regularizers[i] = (weight *)calloc(length/num_threads, sizeof(weight));
+ reg.regularizers[i] = (weight *)calloc(2*length/num_threads, sizeof(weight));
if ((reg.regularizers != NULL && reg.regularizers[i] == NULL))
{
@@ -276,11 +276,13 @@ void preconditioner_to_regularizer(regressor& reg, float regularization)
}
}
for(uint32_t i = 0; i < length; i++)
- reg.regularizers[0][i] = weights[stride*i+3] + regularization;
+ reg.regularizers[0][2*i] = weights[stride*i+3] + regularization;
}
else
for(uint32_t i = 0; i < length; i++)
- reg.regularizers[0][i] = weights[stride*i+3] + regularization*reg.regularizers[0][i];
+ reg.regularizers[0][2*i] = weights[stride*i+3] + reg.regularizers[0][2*i];
+ for(uint32_t i = 0; i < length; i++)
+ reg.regularizers[0][2*i+1] = weights[stride*i];
}
void zero_state(regressor& reg, float* old_first_derivative)
@@ -396,8 +398,10 @@ void setup_cg(gd_thread_params& t)
}
bool output_regularizer = false;
- if (global.per_feature_regularizer_output != "" || global.per_feature_regularizer_text != "")
+ if (global.per_feature_regularizer_output != "" || global.per_feature_regularizer_text != "") {
+ global.regularization = 1; // To make sure we are adding the regularization
output_regularizer = true;
+ }
while ( true )
{
@@ -549,7 +553,7 @@ void setup_cg(gd_thread_params& t)
}
float step_size = - dd/(max(curvature,1.));
if (!global.quiet) {
- fprintf(stderr, "%-e\t%-e\t%-f\n", curvature, step_size, 0.5*step_size*step_size*curvature);
+ fprintf(stderr, "%-e\t%-e\t%-f\n", curvature, step_size, 0.5*step_size*step_size*curvature/importance_weight_sum);
//fprintf(stdout, "Net comm. time is %f\n",net_comm_time - prev_comm_time);
}
update_weight(reg,step_size);