// vw_explore.cpp : Defines the entry point for the console application. // #include "MWTExplorer.h" #include #include #include using namespace std; using namespace std::chrono; using namespace MultiWorldTesting; class MyContext { }; class MyPolicy : public IPolicy { public: u32 Choose_Action(MyContext& context) { return (u32)1; } }; class MySimplePolicy : public IPolicy { public: u32 Choose_Action(SimpleContext& context) { return (u32)1; } }; class MyScorer : public IScorer { public: MyScorer(u32 num_actions) : m_num_actions(num_actions) { } vector Score_Actions(MyContext& context) { vector scores; for (size_t i = 0; i < m_num_actions; i++) { scores.push_back(.1f); } return scores; } private: u32 m_num_actions; }; template struct MyInteraction { Ctx Context; u32 Action; float Probability; string Unique_Key; }; class MyRecorder : public IRecorder { public: virtual void Record(MyContext& context, u32 action, float probability, string unique_key) { m_interactions.push_back({ context, action, probability, unique_key }); } private: vector> m_interactions; }; int main(int argc, char* argv[]) { if (argc < 2) { cerr << "arguments: {greedy,tau-first,bagging,softmax,generic}" << endl; exit(1); } //arguments for individual explorers if (strcmp(argv[1], "greedy") == 0) { //Initialize Epsilon-Greedy explore algorithm using MyPolicy vector features; features.push_back({ 0.5f, 1 }); features.push_back({ 1.3f, 11 }); features.push_back({ -.95f, 413 }); SimpleContext context(features); StringRecorder recorder; MySimplePolicy default_policy; MwtExplorer mwt("appid", recorder); u32 num_actions = 10; float epsilon = .2f; EpsilonGreedyExplorer explorer(default_policy, epsilon, num_actions); string unique_key = "eventid"; u32 action = mwt.Choose_Action(explorer, unique_key, context); cout << "action = " << action << " recorder = " << recorder.Get_Recording(); } else if (strcmp(argv[1], "tau-first") == 0) { //Initialize Tau-First explore algorithm using MyPolicy MyRecorder recorder; MwtExplorer mwt("appid", recorder); int num_actions = 10; u32 tau = 5; MyPolicy default_policy; TauFirstExplorer explorer(default_policy, tau, num_actions); MyContext ctx; string unique_key = "eventid"; u32 action = mwt.Choose_Action(explorer, unique_key, ctx); cout << "action = " << action << endl; } else if (strcmp(argv[1], "bagging") == 0) { //Initialize Bagging explore algorithm using MyPolicy MyRecorder recorder; MwtExplorer mwt("appid", recorder); u32 num_bags = 2; vector>> policy_functions; for (size_t i = 0; i < num_bags; i++) { policy_functions.push_back(unique_ptr>(new MyPolicy())); } int num_actions = 10; BaggingExplorer explorer(policy_functions, num_actions); MyContext ctx; string unique_key = "eventid"; u32 action = mwt.Choose_Action(explorer, unique_key, ctx); cout << "action = " << action << endl; } else if (strcmp(argv[1], "softmax") == 0) { //Initialize Softmax explore algorithm using MyScorer MyRecorder recorder; MwtExplorer mwt("salt", recorder); u32 num_actions = 10; MyScorer scorer(num_actions); float lambda = 0.5f; SoftmaxExplorer explorer(scorer, lambda, num_actions); MyContext ctx; string unique_key = "eventid"; u32 action = mwt.Choose_Action(explorer, unique_key, ctx); cout << "action = " << action << endl; } else if (strcmp(argv[1], "generic") == 0) { //Initialize Generic explore algorithm using MyScorer MyRecorder recorder; MwtExplorer mwt("appid", recorder); int num_actions = 10; MyScorer scorer(num_actions); GenericExplorer explorer(scorer, num_actions); MyContext ctx; string unique_key = "eventid"; u32 action = mwt.Choose_Action(explorer, unique_key, ctx); cout << "action = " << action << endl; } else { cerr << "unknown exploration type: " << argv[1] << endl; exit(1); } return 0; }