Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/moses-smt/vowpal_wabbit.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'vowpalwabbit/loss_functions.cc')
-rw-r--r--vowpalwabbit/loss_functions.cc39
1 files changed, 32 insertions, 7 deletions
diff --git a/vowpalwabbit/loss_functions.cc b/vowpalwabbit/loss_functions.cc
index 2528f904..67acf903 100644
--- a/vowpalwabbit/loss_functions.cc
+++ b/vowpalwabbit/loss_functions.cc
@@ -44,9 +44,13 @@ public:
* with its first order Taylor expansion around 0
* to avoid catastrophic cancellation.
*/
- return (label - prediction)*eta_t/norm;
+ return 2.f*(label - prediction)*eta_t/norm;
}
- return (label - prediction)*(1-exp(-eta_t))/norm;
+ return (label - prediction)*(1.f-exp(-2.f*eta_t))/norm;
+ }
+
+ float getUnsafeUpdate(float prediction, float label,float eta_t,float norm) {
+ return 2.f*(label - prediction)*eta_t/norm;
}
float getRevertingWeight(shared_data* sd, float prediction, float eta_t){
@@ -56,7 +60,7 @@ public:
}
float getSquareGrad(float prediction, float label) {
- return (prediction - label) * (prediction - label);
+ return 4.f*(prediction - label) * (prediction - label);
}
float first_derivative(shared_data* sd, float prediction, float label)
{
@@ -87,7 +91,11 @@ public:
}
float getUpdate(float prediction, float label,float eta_t, float norm) {
- return eta_t*(label - prediction)/norm;
+ return 2.f*eta_t*(label - prediction)/norm;
+ }
+
+ float getUnsafeUpdate(float prediction, float label,float eta_t,float norm) {
+ return 2.f*(label - prediction)*eta_t/norm;
}
float getRevertingWeight(shared_data* sd, float prediction, float eta_t){
@@ -97,7 +105,7 @@ public:
}
float getSquareGrad(float prediction, float label) {
- return (prediction - label) * (prediction - label);
+ return 4.f * (prediction - label) * (prediction - label);
}
float first_derivative(shared_data*, float prediction, float label)
{
@@ -128,6 +136,11 @@ public:
return label * (normal < err ? normal : err)/norm;
}
+ float getUnsafeUpdate(float prediction, float label,float eta_t, float norm) {
+ if(label*prediction >= label*label) return 0;
+ return label * eta_t/norm;
+ }
+
float getRevertingWeight(shared_data*, float prediction, float eta_t){
return fabs(prediction)/eta_t;
}
@@ -171,6 +184,11 @@ public:
w = wexpmx(x);
return -(label*w+prediction)/norm;
}
+
+ float getUnsafeUpdate(float prediction, float label, float eta_t, float norm) {
+ float d = exp(label * prediction);
+ return label*eta_t/((1+d)*norm);
+ }
inline float wexpmx(float x){
/* This piece of code is approximating W(exp(x))-x.
@@ -230,12 +248,19 @@ public:
float normal = eta_t;//base update size
if(err > 0) {
normal = tau*normal;
- return tau*(normal < err ? normal : err) / norm;
+ return (normal < err ? normal : err) / norm;
} else {
normal = -(1-tau) * normal;
- return ( normal < - err ? normal : err) / norm;
+ return ( normal > err ? normal : err) / norm;
}
}
+
+ float getUnsafeUpdate(float prediction, float label, float eta_t, float norm) {
+ float err = label - prediction;
+ if(err == 0) return 0;
+ if(err > 0) return tau*eta_t/norm;
+ return -(1-tau)*eta_t/norm;
+ }
float getRevertingWeight(shared_data* sd, float prediction, float eta_t){
float v,t;