
github.com/moses-smt/vowpal_wabbit.git
author    John Langford <jl@hunch.net>  2013-12-24 23:46:23 +0400
committer John Langford <jl@hunch.net>  2013-12-24 23:46:23 +0400
commit    c95619bd2c56b346fe317b8fa134ee1c5b7c828e (patch)
tree      fbb31b1a16c318c272950188733acb6e02d91502
parent    cc7db28db06560a45bf4c61c2b936a02b3697fa8 (diff)
initial cbify and a bugfix
-rw-r--r--  cs_test/VowpalWabbitInterface.cs |   3
-rw-r--r--  vowpalwabbit/Makefile.am         |   2
-rw-r--r--  vowpalwabbit/cb.cc               | 110
-rw-r--r--  vowpalwabbit/cb.h                |   8
-rw-r--r--  vowpalwabbit/oaa.h               |   2
-rw-r--r--  vowpalwabbit/parse_args.cc       |  22
-rw-r--r--  vowpalwabbit/searn.cc            |   8
-rw-r--r--  vowpalwabbit/vwdll.h             |   1
8 files changed, 90 insertions(+), 66 deletions(-)
diff --git a/cs_test/VowpalWabbitInterface.cs b/cs_test/VowpalWabbitInterface.cs
index f42aa576..3c9ae27c 100644
--- a/cs_test/VowpalWabbitInterface.cs
+++ b/cs_test/VowpalWabbitInterface.cs
@@ -115,6 +115,9 @@ namespace Microsoft.Research.MachineLearning
[DllImport("libvw.dll", EntryPoint = "VW_Get_Weight", CallingConvention = CallingConvention.StdCall)]
public static extern float Get_Weight(IntPtr vw, UInt32 index, UInt32 offset);
+ [DllImport("libvw.dll", EntryPoint = "VW_Set_Weight", CallingConvention = CallingConvention.StdCall)]
+ public static extern void Set_Weight(IntPtr vw, UInt32 index, UInt32 offset, float value);
+
[DllImport("libvw.dll", EntryPoint = "VW_Num_Weights", CallingConvention = CallingConvention.StdCall)]
public static extern UInt32 Num_Weights(IntPtr vw);
diff --git a/vowpalwabbit/Makefile.am b/vowpalwabbit/Makefile.am
index 05b6ce66..1a9dddc4 100644
--- a/vowpalwabbit/Makefile.am
+++ b/vowpalwabbit/Makefile.am
@@ -4,7 +4,7 @@ liballreduce_la_SOURCES = allreduce.cc
bin_PROGRAMS = vw active_interactor
-libvw_la_SOURCES = hash.cc global_data.cc io_buf.cc parse_regressor.cc parse_primitives.cc unique_sort.cc cache.cc rand48.cc simple_label.cc oaa.cc ect.cc autolink.cc binary.cc csoaa.cc cb.cc wap.cc beam.cc searn.cc searn_sequencetask.cc parse_example.cc sparse_dense.cc network.cc parse_args.cc accumulate.cc gd.cc learner.cc lda_core.cc gd_mf.cc bfgs.cc noop.cc example.cc parser.cc loss_functions.cc sender.cc nn.cc bs.cc topk.cc
+libvw_la_SOURCES = hash.cc global_data.cc io_buf.cc parse_regressor.cc parse_primitives.cc unique_sort.cc cache.cc rand48.cc simple_label.cc oaa.cc ect.cc autolink.cc binary.cc csoaa.cc cb.cc wap.cc beam.cc searn.cc searn_sequencetask.cc parse_example.cc sparse_dense.cc network.cc parse_args.cc accumulate.cc gd.cc learner.cc lda_core.cc gd_mf.cc bfgs.cc noop.cc example.cc parser.cc loss_functions.cc sender.cc nn.cc bs.cc cbify.cc topk.cc
# accumulate.cc uses all_reduce
libvw_la_LIBADD = liballreduce.la
diff --git a/vowpalwabbit/cb.cc b/vowpalwabbit/cb.cc
index 482edaf3..c1346a61 100644
--- a/vowpalwabbit/cb.cc
+++ b/vowpalwabbit/cb.cc
@@ -39,7 +39,7 @@ namespace CB
//if we specified more than 1 action for this example, i.e. either we have a limited set of possible actions, or all actions are specified
//than check if all actions have a specified cost
for (cb_class* cl = ld->costs.begin; cl != ld->costs.end; cl++)
- if (cl->x == FLT_MAX)
+ if (cl->cost == FLT_MAX)
return false;
return true;
@@ -50,7 +50,7 @@ namespace CB
if (ld->costs.size() == 0)
return true;
for (size_t i=0; i<ld->costs.size(); i++)
- if (FLT_MAX != ld->costs[i].x && ld->costs[i].prob_action > 0.)
+ if (FLT_MAX != ld->costs[i].cost && ld->costs[i].probability > 0.)
return false;
return true;
}
@@ -154,19 +154,19 @@ namespace CB
throw exception();
}
- f.weight_index = (uint32_t)hashstring(p->parse_name[0], 0);
- if (f.weight_index < 1 || f.weight_index > sd->k)
+ f.action = (uint32_t)hashstring(p->parse_name[0], 0);
+ if (f.action < 1 || f.action > sd->k)
{
- cerr << "invalid action: " << f.weight_index << endl;
+ cerr << "invalid action: " << f.action << endl;
cerr << "terminating." << endl;
throw exception();
}
- f.x = FLT_MAX;
+ f.cost = FLT_MAX;
if(p->parse_name.size() > 1)
- f.x = float_of_substring(p->parse_name[1]);
+ f.cost = float_of_substring(p->parse_name[1]);
- if ( nanpattern(f.x))
+ if ( nanpattern(f.cost))
{
cerr << "error NaN cost for action: ";
cerr.write(p->parse_name[0].begin, p->parse_name[0].end - p->parse_name[0].begin);
@@ -174,11 +174,11 @@ namespace CB
throw exception();
}
- f.prob_action = .0;
+ f.probability = .0;
if(p->parse_name.size() > 2)
- f.prob_action = float_of_substring(p->parse_name[2]);
+ f.probability = float_of_substring(p->parse_name[2]);
- if ( nanpattern(f.prob_action))
+ if ( nanpattern(f.probability))
{
cerr << "error NaN probability for action: ";
cerr.write(p->parse_name[0].begin, p->parse_name[0].end - p->parse_name[0].begin);
@@ -186,15 +186,15 @@ namespace CB
throw exception();
}
- if( f.prob_action > 1.0 )
+ if( f.probability > 1.0 )
{
cerr << "invalid probability > 1 specified for an action, resetting to 1." << endl;
- f.prob_action = 1.0;
+ f.probability = 1.0;
}
- if( f.prob_action < 0.0 )
+ if( f.probability < 0.0 )
{
cerr << "invalid probability < 0 specified for an action, resetting to 0." << endl;
- f.prob_action = .0;
+ f.probability = .0;
}
ld->costs.push_back(f);
@@ -204,7 +204,7 @@ namespace CB
inline bool observed_cost(cb_class* cl)
{
//cost observed for this action if it has non zero probability and cost != FLT_MAX
- return (cl != NULL && cl->x != FLT_MAX && cl->prob_action > .0);
+ return (cl != NULL && cl->cost != FLT_MAX && cl->probability > .0);
}
cb_class* get_observed_cost(CB::label* ld)
@@ -235,15 +235,15 @@ namespace CB
wc.weight_index = i;
wc.partial_prediction = 0.;
wc.wap_value = 0.;
- if( c.known_cost != NULL && i == c.known_cost->weight_index )
+ if( c.known_cost != NULL && i == c.known_cost->action )
{
- wc.x = c.known_cost->x / c.known_cost->prob_action; //use importance weighted cost for observed action, 0 otherwise
+ wc.x = c.known_cost->cost / c.known_cost->probability; //use importance weighted cost for observed action, 0 otherwise
//ips can be thought as the doubly robust method with a fixed regressor that predicts 0 costs for everything
//update the loss of this regressor
c.nb_ex_regressors++;
- c.avg_loss_regressors += (1.0f/c.nb_ex_regressors)*( (c.known_cost->x)*(c.known_cost->x) - c.avg_loss_regressors );
+ c.avg_loss_regressors += (1.0f/c.nb_ex_regressors)*( (c.known_cost->cost)*(c.known_cost->cost) - c.avg_loss_regressors );
c.last_pred_reg = 0;
- c.last_correct_cost = c.known_cost->x;
+ c.last_correct_cost = c.known_cost->cost;
}
cs_ld.costs.push_back(wc );
@@ -256,19 +256,19 @@ namespace CB
CSOAA::wclass wc;
wc.wap_value = 0.;
wc.x = 0.;
- wc.weight_index = cl->weight_index;
+ wc.weight_index = cl->action;
wc.partial_prediction = 0.;
wc.wap_value = 0.;
- if( c.known_cost != NULL && cl->weight_index == c.known_cost->weight_index )
+ if( c.known_cost != NULL && cl->action == c.known_cost->action )
{
- wc.x = c.known_cost->x / c.known_cost->prob_action; //use importance weighted cost for observed action, 0 otherwise
+ wc.x = c.known_cost->cost / c.known_cost->probability; //use importance weighted cost for observed action, 0 otherwise
//ips can be thought as the doubly robust method with a fixed regressor that predicts 0 costs for everything
//update the loss of this regressor
c.nb_ex_regressors++;
- c.avg_loss_regressors += (1.0f/c.nb_ex_regressors)*( (c.known_cost->x)*(c.known_cost->x) - c.avg_loss_regressors );
+ c.avg_loss_regressors += (1.0f/c.nb_ex_regressors)*( (c.known_cost->cost)*(c.known_cost->cost) - c.avg_loss_regressors );
c.last_pred_reg = 0;
- c.last_correct_cost = c.known_cost->x;
+ c.last_correct_cost = c.known_cost->cost;
}
cs_ld.costs.push_back( wc );
@@ -294,9 +294,9 @@ namespace CB
label_data simple_temp;
simple_temp.initial = 0.;
- if (c.known_cost != NULL && index == c.known_cost->weight_index)
+ if (c.known_cost != NULL && index == c.known_cost->action)
{
- simple_temp.label = c.known_cost->x;
+ simple_temp.label = c.known_cost->cost;
simple_temp.weight = 1.;
}
else
@@ -343,11 +343,11 @@ namespace CB
wc.partial_prediction = 0.;
wc.wap_value = 0.;
- if( c.known_cost != NULL && c.known_cost->weight_index == i ) {
+ if( c.known_cost != NULL && c.known_cost->action == i ) {
c.nb_ex_regressors++;
- c.avg_loss_regressors += (1.0f/c.nb_ex_regressors)*( (c.known_cost->x - wc.x)*(c.known_cost->x - wc.x) - c.avg_loss_regressors );
+ c.avg_loss_regressors += (1.0f/c.nb_ex_regressors)*( (c.known_cost->cost - wc.x)*(c.known_cost->cost - wc.x) - c.avg_loss_regressors );
c.last_pred_reg = wc.x;
- c.last_correct_cost = c.known_cost->x;
+ c.last_correct_cost = c.known_cost->cost;
}
cs_ld.costs.push_back( wc );
@@ -361,22 +361,22 @@ namespace CB
wc.wap_value = 0.;
//get cost prediction for this action
- wc.x = get_cost_pred(all, c, ec, cl->weight_index - 1);
- if (wc.x < min || (wc.x == min && cl->weight_index < argmin))
+ wc.x = get_cost_pred(all, c, ec, cl->action - 1);
+ if (wc.x < min || (wc.x == min && cl->action < argmin))
{
min = wc.x;
- argmin = cl->weight_index;
+ argmin = cl->action;
}
- wc.weight_index = cl->weight_index;
+ wc.weight_index = cl->action;
wc.partial_prediction = 0.;
wc.wap_value = 0.;
- if( c.known_cost != NULL && c.known_cost->weight_index == cl->weight_index ) {
+ if( c.known_cost != NULL && c.known_cost->action == cl->action ) {
c.nb_ex_regressors++;
- c.avg_loss_regressors += (1.0f/c.nb_ex_regressors)*( (c.known_cost->x - wc.x)*(c.known_cost->x - wc.x) - c.avg_loss_regressors );
+ c.avg_loss_regressors += (1.0f/c.nb_ex_regressors)*( (c.known_cost->cost - wc.x)*(c.known_cost->cost - wc.x) - c.avg_loss_regressors );
c.last_pred_reg = wc.x;
- c.last_correct_cost = c.known_cost->x;
+ c.last_correct_cost = c.known_cost->cost;
}
cs_ld.costs.push_back( wc );
@@ -406,12 +406,12 @@ namespace CB
wc.wap_value = 0.;
//add correction if we observed cost for this action and regressor is wrong
- if( c.known_cost != NULL && c.known_cost->weight_index == i ) {
+ if( c.known_cost != NULL && c.known_cost->action == i ) {
c.nb_ex_regressors++;
- c.avg_loss_regressors += (1.0f/c.nb_ex_regressors)*( (c.known_cost->x - wc.x)*(c.known_cost->x - wc.x) - c.avg_loss_regressors );
+ c.avg_loss_regressors += (1.0f/c.nb_ex_regressors)*( (c.known_cost->cost - wc.x)*(c.known_cost->cost - wc.x) - c.avg_loss_regressors );
c.last_pred_reg = wc.x;
- c.last_correct_cost = c.known_cost->x;
- wc.x += (c.known_cost->x - wc.x) / c.known_cost->prob_action;
+ c.last_correct_cost = c.known_cost->cost;
+ wc.x += (c.known_cost->cost - wc.x) / c.known_cost->probability;
}
cs_ld.costs.push_back( wc );
@@ -425,18 +425,18 @@ namespace CB
wc.wap_value = 0.;
//get cost prediction for this label
- wc.x = get_cost_pred(all, c, ec, all.sd->k + cl->weight_index - 1);
- wc.weight_index = cl->weight_index;
+ wc.x = get_cost_pred(all, c, ec, all.sd->k + cl->action - 1);
+ wc.weight_index = cl->action;
wc.partial_prediction = 0.;
wc.wap_value = 0.;
//add correction if we observed cost for this action and regressor is wrong
- if( c.known_cost != NULL && c.known_cost->weight_index == cl->weight_index ) {
+ if( c.known_cost != NULL && c.known_cost->action == cl->action ) {
c.nb_ex_regressors++;
- c.avg_loss_regressors += (1.0f/c.nb_ex_regressors)*( (c.known_cost->x - wc.x)*(c.known_cost->x - wc.x) - c.avg_loss_regressors );
+ c.avg_loss_regressors += (1.0f/c.nb_ex_regressors)*( (c.known_cost->cost - wc.x)*(c.known_cost->cost - wc.x) - c.avg_loss_regressors );
c.last_pred_reg = wc.x;
- c.last_correct_cost = c.known_cost->x;
- wc.x += (c.known_cost->x - wc.x) / c.known_cost->prob_action;
+ c.last_correct_cost = c.known_cost->cost;
+ wc.x += (c.known_cost->cost - wc.x) / c.known_cost->probability;
}
cs_ld.costs.push_back( wc );
@@ -457,8 +457,8 @@ namespace CB
CSOAA::wclass wc;
wc.wap_value = 0.;
- wc.x = cl->x;
- wc.weight_index = cl->weight_index;
+ wc.x = cl->cost;
+ wc.weight_index = cl->action;
wc.partial_prediction = 0.;
wc.wap_value = 0.;
@@ -486,8 +486,8 @@ namespace CB
//now this is a training example
c->known_cost = get_observed_cost(ld);
- c->min_cost = min (c->min_cost, c->known_cost->x);
- c->max_cost = max (c->max_cost, c->known_cost->x);
+ c->min_cost = min (c->min_cost, c->known_cost->cost);
+ c->max_cost = max (c->max_cost, c->known_cost->cost);
//generate a cost-sensitive example to update classifiers
switch(c->cb_type)
@@ -585,8 +585,8 @@ namespace CB
float chosen_loss = FLT_MAX;
if( know_all_cost_example(ld) ) {
for (cb_class *cl = ld->costs.begin; cl != ld->costs.end; cl ++) {
- if (cl->weight_index == pred)
- chosen_loss = cl->x;
+ if (cl->action == pred)
+ chosen_loss = cl->cost;
}
}
else {
@@ -595,8 +595,8 @@ namespace CB
if (cl->weight_index == pred)
{
chosen_loss = cl->x;
- if (c.known_cost->weight_index == pred && c.cb_type == CB_TYPE_DM)
- chosen_loss += (c.known_cost->x - chosen_loss) / c.known_cost->prob_action;
+ if (c.known_cost->action == pred && c.cb_type == CB_TYPE_DM)
+ chosen_loss += (c.known_cost->cost - chosen_loss) / c.known_cost->probability;
}
}
}
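
The cb.cc hunks above are a field rename (x -> cost, weight_index -> action, prob_action -> probability), but they pass through the two cost estimates this reduction builds. A minimal sketch — not taken from the patch, with invented helper names — of what those estimates compute with the new field names:

    #include <stdint.h>
    #include "cb.h"   // CB::cb_class with the renamed fields

    // IPS: importance-weighted cost -- cost/probability for the logged
    // action, 0 for every other action (as in the cs_ld construction above).
    float ips_cost(const CB::cb_class& known, uint32_t action)
    {
      return (action == known.action) ? known.cost / known.probability : 0.f;
    }

    // DR: start from the regressor's prediction and correct it on the one
    // action whose cost was actually observed (the wc.x += ... lines above).
    float dr_cost(const CB::cb_class& known, uint32_t action, float predicted)
    {
      if (action == known.action)
        predicted += (known.cost - predicted) / known.probability;
      return predicted;
    }
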
diff --git a/vowpalwabbit/cb.h b/vowpalwabbit/cb.h
index 62870016..4fcb34e2 100644
--- a/vowpalwabbit/cb.h
+++ b/vowpalwabbit/cb.h
@@ -22,10 +22,10 @@ license as described in the file LICENSE.
namespace CB {
struct cb_class {
- float x; // the cost of this class
- uint32_t weight_index; // the index of this class
- float prob_action; //new for bandit setting, specifies the probability the data collection policy chose this class for importance weighting
- bool operator==(cb_class j){return weight_index == j.weight_index;}
+ float cost; // the cost of this class
+ uint32_t action; // the index of this class
+ float probability; //new for bandit setting, specifies the probability the data collection policy chose this class for importance weighting
+ bool operator==(cb_class j){return action == j.action;}
};
struct label {
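
With the rename in place, cb_class is still initialized positionally elsewhere (see the searn.cc hunk further down), so the field order matters. A small sketch, with made-up values and a hypothetical helper, of filling one in under the new names:

    #include "cb.h"

    void add_logged_action(CB::label* ld)    // ld as in cb.cc's parse code
    {
      CB::cb_class cl = { 2.5f,    // cost         (was x)
                          3,       // action       (was weight_index)
                          0.25f }; // probability  (was prob_action)
      ld->costs.push_back(cl);
    }
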
diff --git a/vowpalwabbit/oaa.h b/vowpalwabbit/oaa.h
index 2ea20fda..c8b900af 100644
--- a/vowpalwabbit/oaa.h
+++ b/vowpalwabbit/oaa.h
@@ -61,5 +61,3 @@ namespace OAA
-
-
diff --git a/vowpalwabbit/parse_args.cc b/vowpalwabbit/parse_args.cc
index 3eee061d..4fd53803 100644
--- a/vowpalwabbit/parse_args.cc
+++ b/vowpalwabbit/parse_args.cc
@@ -15,6 +15,7 @@ license as described in the file LICENSE.
#include "network.h"
#include "global_data.h"
#include "nn.h"
+#include "cbify.h"
#include "oaa.h"
#include "bs.h"
#include "topk.h"
@@ -244,6 +245,7 @@ vw* parse_args(int argc, char *argv[])
("cb", po::value<size_t>(), "Use contextual bandit learning with <k> costs")
("lda", po::value<size_t>(&(all->lda)), "Run lda with <int> topics")
("nn", po::value<size_t>(), "Use sigmoidal feedforward network with <k> hidden units")
+ ("cbify", po::value<size_t>(), "Convert multiclass on <k> classes into a contextual bandit problem and solve")
("searn", po::value<size_t>(), "use searn, argument=maximum action id or 0 for LDF")
;
@@ -862,6 +864,26 @@ vw* parse_args(int argc, char *argv[])
got_cb = true;
}
+ if (vm.count("cbify") || vm_file.count("cbify"))
+ {
+ if(!got_cs) {
+ if( vm_file.count("cbify") ) vm.insert(pair<string,po::variable_value>(string("csoaa"),vm_file["cbify"]));
+ else vm.insert(pair<string,po::variable_value>(string("csoaa"),vm["cbify"]));
+
+ all->l = CSOAA::setup(*all, to_pass_further, vm, vm_file); // default to CSOAA unless wap is specified
+ got_cs = true;
+ }
+
+ if (!got_cb) {
+ if( vm_file.count("cbify") ) vm.insert(pair<string,po::variable_value>(string("cb"),vm_file["cbify"]));
+ else vm.insert(pair<string,po::variable_value>(string("cb"),vm["cbify"]));
+ all->l = CB::setup(*all, to_pass_further, vm, vm_file);
+ got_cb = true;
+ }
+
+ all->l = CBIFY::setup(*all, to_pass_further, vm, vm_file);
+ }
+
all->searnstr = NULL;
if (vm.count("searn") || vm_file.count("searn") ) {
if (!got_cs && !got_cb) {
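
The --cbify block above stacks reductions: unless a cost-sensitive or contextual-bandit learner was already configured, it first runs CSOAA::setup and then CB::setup with the same <k> taken from --cbify, and finally layers CBIFY::setup on top. A hypothetical invocation (the data file name is made up) would therefore be

    vw --cbify 10 train.multiclass

which treats a 10-class multiclass dataset as a contextual bandit problem solved through the CBIFY -> CB -> CSOAA stack.
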
diff --git a/vowpalwabbit/searn.cc b/vowpalwabbit/searn.cc
index 2d38144e..337f9fe0 100644
--- a/vowpalwabbit/searn.cc
+++ b/vowpalwabbit/searn.cc
@@ -405,7 +405,7 @@ namespace Searn {
if (srn.rollout_all_actions)
return choose_random<CSOAA::wclass>(((CSOAA::label*)valid_labels)->costs).weight_index;
else
- return choose_random<CB::cb_class >(((CB::label *)valid_labels)->costs).weight_index;
+ return choose_random<CB::cb_class >(((CB::label *)valid_labels)->costs).action;
} else if (ystar_is_uint32t)
return *((uint32_t*)ystar);
else
@@ -443,7 +443,7 @@ namespace Searn {
CB::label *ret = new CB::label();
v_array<CB::cb_class> costs = ((CB::label*)l)->costs;
for (size_t i=0; i<costs.size(); i++) {
- CB::cb_class c = { costs[i].x, costs[i].weight_index, costs[i].prob_action };
+ CB::cb_class c = { costs[i].cost, costs[i].action, costs[i].probability };
ret->costs.push_back(c);
}
return ret;
@@ -851,7 +851,7 @@ namespace Searn {
if (srn.rollout_all_actions)
return ((CSOAA::label*)l)->costs[i].weight_index;
else
- return ((CB::label*)l)->costs[i].weight_index;
+ return ((CB::label*)l)->costs[i].action;
}
bool should_print_update(vw& all, bool hit_new_pass=false)
@@ -885,7 +885,7 @@ namespace Searn {
if (srn.rollout_all_actions)
((CSOAA::label*)labels)->costs[i].x = losses[i] - min_loss;
else
- ((CB::label*)labels)->costs[i].x = losses[i] - min_loss;
+ ((CB::label*)labels)->costs[i].cost = losses[i] - min_loss;
if (!isLDF(srn)) {
void* old_label = ec[0]->ld;
diff --git a/vowpalwabbit/vwdll.h b/vowpalwabbit/vwdll.h
index 00cbbe44..85fa4f32 100644
--- a/vowpalwabbit/vwdll.h
+++ b/vowpalwabbit/vwdll.h
@@ -58,6 +58,7 @@ extern "C"
VW_DLL_MEMBER void VW_CALLING_CONV VW_AddLabel(VW_EXAMPLE e, float label, float weight, float base);
VW_DLL_MEMBER float VW_CALLING_CONV VW_Get_Weight(VW_HANDLE handle, size_t index, size_t offset);
+ VW_DLL_MEMBER void VW_CALLING_CONV VW_Set_Weight(VW_HANDLE handle, size_t index, size_t offset, float value);
VW_DLL_MEMBER size_t VW_CALLING_CONV VW_Num_Weights(VW_HANDLE handle);
VW_DLL_MEMBER size_t VW_CALLING_CONV VW_Get_Stride(VW_HANDLE handle);
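
The new VW_Set_Weight export is the native counterpart of the Set_Weight binding added to cs_test/VowpalWabbitInterface.cs above. A minimal sketch of driving it through this C API — assuming a VW_HANDLE obtained from the library's usual initialization call, and using only entry points visible in this header:

    #include <stddef.h>
    #include "vwdll.h"

    // Zero out the first n model weights, using the setter added here.
    void zero_first_weights(VW_HANDLE vw, size_t n)
    {
      size_t total = VW_Num_Weights(vw);
      for (size_t i = 0; i < n && i < total; i++)
      {
        float w = VW_Get_Weight(vw, i, 0);   // offset 0: the weight value itself
        if (w != 0.f)
          VW_Set_Weight(vw, i, 0, 0.f);      // new in this commit
      }
    }
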