Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/moses-smt/vowpal_wabbit.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNikos Karampatziakis <nikos@research-mm5.corp.sp1.yahoo.com>2010-09-04 20:48:03 +0400
committerNikos Karampatziakis <nikos@research-mm5.corp.sp1.yahoo.com>2010-09-04 20:48:03 +0400
commit8164117b728a46d4468689e720efba1ab8fd0bc1 (patch)
treeccecd35ab9f625ba42611a62de4373eeb94f494f /loss_functions.cc
parentd75cdda5c7d939151acb14e001cfbacf6d9fba32 (diff)
Norm invariant loss functions
Diffstat (limited to 'loss_functions.cc')
-rw-r--r--loss_functions.cc83
1 files changed, 58 insertions, 25 deletions
diff --git a/loss_functions.cc b/loss_functions.cc
index 26ca7348..6a7decfa 100644
--- a/loss_functions.cc
+++ b/loss_functions.cc
@@ -9,6 +9,7 @@ embodied in the content of this file are licensed under the BSD
using namespace std;
#include "loss_functions.h"
+#include "global_data.h"
class squaredloss : public loss_function {
public:
@@ -22,10 +23,40 @@ public:
}
double getUpdate(double prediction, double label,double eta_t, double norm, float h) {
- return (label - prediction)*(1-exp(-h*eta_t*norm))/norm;
+ eta_t*=h;
+ if (eta_t<1e-12){
+ /* When exp(-eta_t)~= 1 we replace 1-exp(-eta_t)
+ * with its first order Taylor expansion around 0
+ * to avoid catastrophic cancellation.
+ */
+ return (label - prediction)*eta_t/norm;
+ }
+ return (label - prediction)*(1-exp(-eta_t))/norm;
}
+
+ //Second order update
+ //double getUpdate(double prediction, double label,double eta_t, double norm, float h) {
+ // return h*eta_t*(label - prediction)/(1+h*eta_t*norm);
+ //}
};
+class classic_squaredloss : public loss_function {
+public:
+ classic_squaredloss() {
+
+ }
+
+ double getLoss(double prediction, double label) {
+ double example_loss = (prediction - label) * (prediction - label);
+ return example_loss;
+ }
+
+ double getUpdate(double prediction, double label,double eta_t, double norm, float h) {
+ return h*eta_t*(label - prediction)/norm;
+ }
+};
+
+
class hingeloss : public loss_function {
public:
hingeloss() {
@@ -39,9 +70,9 @@ public:
double getUpdate(double prediction, double label,double eta_t, double norm, float h) {
if(label*prediction >= label*label) return 0;
- double s1=(label*label-label*prediction)/(label*label*norm);
+ double s1=(label*label-label*prediction)/(label*label);
double s2=eta_t*h;
- return label*(s1<s2 ? s1 : s2);
+ return label * (s1<s2 ? s1 : s2)/norm;
}
};
@@ -56,27 +87,25 @@ public:
}
double getUpdate(double prediction, double label, double eta_t, double norm, float h) {
- /* There's a simpler solution for this which involves approximating W(exp(x))-x */
- double s,b,q;
+ double b,l,q;
double d = exp(label * prediction);
- double c = eta_t * norm * h + label * prediction + d;
- /* In general we want s = h - (W(exp(c))-d)/(eta_t*norm) where W is
- * the Lambert W function. The following is a good approximation:
+ double x = eta_t*h + label*prediction + d;
+ /* This piece of code is approximating W(exp(x))-x.
+ * W is the Lambert W function.
+ * Faster/better approximations can be substituted here
*/
- if (c <= 1){
- q = -0.915756*c+0.763451;
- /* Safe-guard a large exponent */
- b = q > 500 ? 0.45865 : 0.45865+1/(1+exp(q));
- /* Safe-guard a large exponent */
- q = b - c > 500 ? exp(c-b) : 1/(1+exp(b-c));
- s = h-((1+exp(b-1))*q-d)/(eta_t*norm);
+ if (x >= 1){
+ l=log(x);
+ q=(2.16612+1.89678*x)/(2.16276+1.90021*x-l);
+ b=-x*l/(q+x);
+ }
+ else if(x<-7.010881832645721){
+ b=-x;
}
else{
- b = log(c);
- q = 0.997415 + 0.571902*b/(0.842524 + c);
- s = (c/(c+q)*b - label * prediction)/(eta_t*norm);
+ b=0.566841-x*(0.637815-x*(0.0752909-x*(0.00122244+x*(0.00284082+x*(0.000413765+0.0000193232*x)))));
}
- return eta_t * label * s;
+ return -(label*b+prediction)/norm;
}
};
@@ -101,11 +130,11 @@ public:
if(e == 0) return 0;
double s1=eta_t*h;
if(e > 0) {
- s2=e/(norm*tau);
- return tau*(s1<s2?s1:s2);
+ s2=e/tau;
+ return tau*(s1<s2?s1:s2)/norm;
} else {
- s2=-e/(norm*(1-tau));
- return -(1 - tau)*(s1<s2?s1:s2);
+ s2=-e/(1-tau);
+ return -(1 - tau)*(s1<s2?s1:s2)/norm;
}
}
@@ -115,14 +144,18 @@ public:
loss_function* getLossFunction(string funcName, double function_parameter) {
if(funcName.compare("squared") == 0) {
return new squaredloss();
- } else if(funcName.compare("hinge") == 0) {
+ } else if(funcName.compare("classic") == 0){
+ return new classic_squaredloss();
+ } else if(funcName.compare("hinge") == 0) {
return new hingeloss();
} else if(funcName.compare("logistic") == 0) {
+ global.min_label = -100;
+ global.max_label = 100;
return new logloss();
} else if(funcName.compare("quantile") == 0 || funcName.compare("pinball") == 0 || funcName.compare("absolute") == 0) {
return new quantileloss(function_parameter);
} else {
- cout << "Invalid loss function name: " << funcName << " Bailing!" << endl;
+ cout << "Invalid loss function name: \'" << funcName << "\' Bailing!" << endl;
exit(1);
}
cout << "end getLossFunction" << endl;