From c5531bf697dc9235dc5e2fc3afb0a32c384b92e4 Mon Sep 17 00:00:00 2001 From: Stephane Ross Date: Tue, 28 Aug 2012 09:16:02 -0400 Subject: Update to gd to combine all update rules into single functions and allow specifying any combination of invariant, adaptive, normalized for the update. Fixed bugs in loss functions. --- vowpalwabbit/loss_functions.cc | 39 ++++++++++++++++++++++++++++++++------- 1 file changed, 32 insertions(+), 7 deletions(-) (limited to 'vowpalwabbit/loss_functions.cc') diff --git a/vowpalwabbit/loss_functions.cc b/vowpalwabbit/loss_functions.cc index 2528f904..67acf903 100644 --- a/vowpalwabbit/loss_functions.cc +++ b/vowpalwabbit/loss_functions.cc @@ -44,9 +44,13 @@ public: * with its first order Taylor expansion around 0 * to avoid catastrophic cancellation. */ - return (label - prediction)*eta_t/norm; + return 2.f*(label - prediction)*eta_t/norm; } - return (label - prediction)*(1-exp(-eta_t))/norm; + return (label - prediction)*(1.f-exp(-2.f*eta_t))/norm; + } + + float getUnsafeUpdate(float prediction, float label,float eta_t,float norm) { + return 2.f*(label - prediction)*eta_t/norm; } float getRevertingWeight(shared_data* sd, float prediction, float eta_t){ @@ -56,7 +60,7 @@ public: } float getSquareGrad(float prediction, float label) { - return (prediction - label) * (prediction - label); + return 4.f*(prediction - label) * (prediction - label); } float first_derivative(shared_data* sd, float prediction, float label) { @@ -87,7 +91,11 @@ public: } float getUpdate(float prediction, float label,float eta_t, float norm) { - return eta_t*(label - prediction)/norm; + return 2.f*eta_t*(label - prediction)/norm; + } + + float getUnsafeUpdate(float prediction, float label,float eta_t,float norm) { + return 2.f*(label - prediction)*eta_t/norm; } float getRevertingWeight(shared_data* sd, float prediction, float eta_t){ @@ -97,7 +105,7 @@ public: } float getSquareGrad(float prediction, float label) { - return (prediction - label) * (prediction - label); + return 4.f * (prediction - label) * (prediction - label); } float first_derivative(shared_data*, float prediction, float label) { @@ -128,6 +136,11 @@ public: return label * (normal < err ? normal : err)/norm; } + float getUnsafeUpdate(float prediction, float label,float eta_t, float norm) { + if(label*prediction >= label*label) return 0; + return label * eta_t/norm; + } + float getRevertingWeight(shared_data*, float prediction, float eta_t){ return fabs(prediction)/eta_t; } @@ -171,6 +184,11 @@ public: w = wexpmx(x); return -(label*w+prediction)/norm; } + + float getUnsafeUpdate(float prediction, float label, float eta_t, float norm) { + float d = exp(label * prediction); + return label*eta_t/((1+d)*norm); + } inline float wexpmx(float x){ /* This piece of code is approximating W(exp(x))-x. @@ -230,12 +248,19 @@ public: float normal = eta_t;//base update size if(err > 0) { normal = tau*normal; - return tau*(normal < err ? normal : err) / norm; + return (normal < err ? normal : err) / norm; } else { normal = -(1-tau) * normal; - return ( normal < - err ? normal : err) / norm; + return ( normal > err ? normal : err) / norm; } } + + float getUnsafeUpdate(float prediction, float label, float eta_t, float norm) { + float err = label - prediction; + if(err == 0) return 0; + if(err > 0) return tau*eta_t/norm; + return -(1-tau)*eta_t/norm; + } float getRevertingWeight(shared_data* sd, float prediction, float eta_t){ float v,t; -- cgit v1.2.3