Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/moses-smt/vowpal_wabbit.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNikos Karampatziakis <nikos@research-mm5.corp.sp1.yahoo.com>2010-08-25 00:10:12 +0400
committerNikos Karampatziakis <nikos@research-mm5.corp.sp1.yahoo.com>2010-08-25 00:10:12 +0400
commitd75cdda5c7d939151acb14e001cfbacf6d9fba32 (patch)
tree6aa6f854cc3a9188d15500c1d9050836a3496dd3 /loss_functions.cc
parentdcbdf822ed9ce8af1f3acadbe0cf208c064cc5b0 (diff)
Patch for proper treatment of importance weights
Diffstat (limited to 'loss_functions.cc')
-rw-r--r--loss_functions.cc50
1 files changed, 37 insertions, 13 deletions
diff --git a/loss_functions.cc b/loss_functions.cc
index ad80d231..26ca7348 100644
--- a/loss_functions.cc
+++ b/loss_functions.cc
@@ -21,8 +21,8 @@ public:
return example_loss;
}
- double getUpdate(double prediction, double label) {
- return (label - prediction);
+ double getUpdate(double prediction, double label,double eta_t, double norm, float h) {
+ return (label - prediction)*(1-exp(-h*eta_t*norm))/norm;
}
};
@@ -37,10 +37,11 @@ public:
return (e > 0) ? e : 0;
}
- double getUpdate(double prediction, double label) {
- if(prediction == label) return 0;
- return label;
-
+ double getUpdate(double prediction, double label,double eta_t, double norm, float h) {
+ if(label*prediction >= label*label) return 0;
+ double s1=(label*label-label*prediction)/(label*label*norm);
+ double s2=eta_t*h;
+ return label*(s1<s2 ? s1 : s2);
}
};
@@ -54,9 +55,28 @@ public:
return log(1 + exp(-label * prediction));
}
- double getUpdate(double prediction, double label) {
- double d = exp(-label * prediction);
- return label * d / (1 + d);
+ double getUpdate(double prediction, double label, double eta_t, double norm, float h) {
+ /* There's a simpler solution for this which involves approximating W(exp(x))-x */
+ double s,b,q;
+ double d = exp(label * prediction);
+ double c = eta_t * norm * h + label * prediction + d;
+ /* In general we want s = h - (W(exp(c))-d)/(eta_t*norm) where W is
+ * the Lambert W function. The following is a good approximation:
+ */
+ if (c <= 1){
+ q = -0.915756*c+0.763451;
+ /* Safe-guard a large exponent */
+ b = q > 500 ? 0.45865 : 0.45865+1/(1+exp(q));
+ /* Safe-guard a large exponent */
+ q = b - c > 500 ? exp(c-b) : 1/(1+exp(b-c));
+ s = h-((1+exp(b-1))*q-d)/(eta_t*norm);
+ }
+ else{
+ b = log(c);
+ q = 0.997415 + 0.571902*b/(0.842524 + c);
+ s = (c/(c+q)*b - label * prediction)/(eta_t*norm);
+ }
+ return eta_t * label * s;
}
};
@@ -75,13 +95,17 @@ public:
}
- double getUpdate(double prediction, double label) {
- double e = label - prediction;
+ double getUpdate(double prediction, double label, double eta_t, double norm, float h) {
+ double s2;
+ double e = label - prediction;
if(e == 0) return 0;
+ double s1=eta_t*h;
if(e > 0) {
- return tau;
+ s2=e/(norm*tau);
+ return tau*(s1<s2?s1:s2);
} else {
- return -(1 - tau);
+ s2=-e/(norm*(1-tau));
+ return -(1 - tau)*(s1<s2?s1:s2);
}
}