Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/moses-smt/vowpal_wabbit.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/cg.cc
diff options
context:
space:
mode:
authorJohn Langford <jl@humpty.(none)>2010-11-28 06:12:37 +0300
committerJohn Langford <jl@humpty.(none)>2010-11-28 06:12:37 +0300
commit769cf31138647f971d042426774e4f5af2413cc4 (patch)
treed21005293188d3cf70ac74b5d86475730105a917 /cg.cc
parent29d3095ca3d2f2e30c8f63c01d9d71b021a39391 (diff)
tweak importance weight
Diffstat (limited to 'cg.cc')
-rw-r--r--cg.cc5
1 files changed, 4 insertions, 1 deletions
diff --git a/cg.cc b/cg.cc
index 77ba8303..cc4e1540 100644
--- a/cg.cc
+++ b/cg.cc
@@ -199,6 +199,7 @@ void setup_cg(gd_thread_params t)
bool gradient_pass=true;
double loss_sum = 0;
float step_size = 0.;
+ double importance_weight_sum = 0.;
double previous_d_mag=0;
size_t current_pass = 0;
@@ -227,6 +228,7 @@ void setup_cg(gd_thread_params t)
add_regularization(reg,global.regularization*predictions.index());
example_number = 0;
curvature = 0;
+ importance_weight_sum = 0;
float mix_frac = 0;
if (current_pass != 0)
mix_frac = derivative_diff_mag(reg) / previous_d_mag;
@@ -241,7 +243,7 @@ void setup_cg(gd_thread_params t)
else // just finished all second gradients
{
if (global.regularization > 0.)
- curvature += global.regularization*direction_magnitude(reg)*predictions.index();
+ curvature += global.regularization*direction_magnitude(reg)*importance_weight_sum;
step_size = - derivative_in_direction(reg)/(max(curvature,1.));
predictions.erase();
update_weight(reg,step_size);
@@ -267,6 +269,7 @@ void setup_cg(gd_thread_params t)
ec->loss = reg.loss->getLoss(ec->final_prediction, ld->label) * ld->weight;
float sd = reg.loss->second_derivative(predictions[example_number++],ld->label);
curvature += d_dot_x*d_dot_x*sd*ld->weight;
+ importance_weight_sum += ld->weight;
}
finish_example(ec);
}