diff options
author | sidsen <sid.sen1@gmail.com> | 2014-11-11 19:45:18 +0300 |
---|---|---|
committer | sidsen <sid.sen1@gmail.com> | 2014-11-11 19:45:18 +0300 |
commit | d4c38ae4284222fbbe51f1152bab0f92df2e2d4e (patch) | |
tree | 8a6d8ec52a939c2b9f5f8fd9c4c4e14eb3a78f22 /explore | |
parent | 31a33a0dc8be432816337b9926da3f681a2ea0d8 (diff) |
Get_Recording method takes bool flush argument
Diffstat (limited to 'explore')
-rw-r--r-- | explore/clr/explore_clr_wrapper.h | 66 | ||||
-rw-r--r-- | explore/explore_sample.cpp | 10 | ||||
-rw-r--r-- | explore/static/MWTExplorer.h | 8 | ||||
-rw-r--r-- | explore/tests/MWTExploreTests.cpp | 1272 |
4 files changed, 680 insertions, 676 deletions
diff --git a/explore/clr/explore_clr_wrapper.h b/explore/clr/explore_clr_wrapper.h index b1ef80a8..427b8be2 100644 --- a/explore/clr/explore_clr_wrapper.h +++ b/explore/clr/explore_clr_wrapper.h @@ -177,53 +177,53 @@ namespace MultiWorldTesting { this->recorder = recorder; } - UInt32 ChooseAction(IExplorer<Ctx>^ explorer, String^ unique_key, Ctx context)
- {
- String^ salt = this->appId;
+ UInt32 ChooseAction(IExplorer<Ctx>^ explorer, String^ unique_key, Ctx context) + { + String^ salt = this->appId; NativeMultiWorldTesting::MwtExplorer<NativeContext> mwt(marshal_as<std::string>(salt), *GetNativeRecorder()); GCHandle selfHandle = GCHandle::Alloc(this); - IntPtr selfPtr = (IntPtr)selfHandle;
-
- GCHandle contextHandle = GCHandle::Alloc(context);
- IntPtr contextPtr = (IntPtr)contextHandle;
-
- GCHandle explorerHandle = GCHandle::Alloc(explorer);
- IntPtr explorerPtr = (IntPtr)explorerHandle;
-
- NativeContext native_context(selfPtr.ToPointer(), explorerPtr.ToPointer(), contextPtr.ToPointer());
- u32 action = 0;
+ IntPtr selfPtr = (IntPtr)selfHandle; + + GCHandle contextHandle = GCHandle::Alloc(context); + IntPtr contextPtr = (IntPtr)contextHandle; + + GCHandle explorerHandle = GCHandle::Alloc(explorer); + IntPtr explorerPtr = (IntPtr)explorerHandle; + + NativeContext native_context(selfPtr.ToPointer(), explorerPtr.ToPointer(), contextPtr.ToPointer()); + u32 action = 0; if (explorer->GetType() == EpsilonGreedyExplorer<Ctx>::typeid) { EpsilonGreedyExplorer<Ctx>^ epsilonGreedyExplorer = (EpsilonGreedyExplorer<Ctx>^)explorer; - action = mwt.Choose_Action(*epsilonGreedyExplorer->Get(), marshal_as<std::string>(unique_key), native_context);
- }
+ action = mwt.Choose_Action(*epsilonGreedyExplorer->Get(), marshal_as<std::string>(unique_key), native_context); + } else if (explorer->GetType() == TauFirstExplorer<Ctx>::typeid) { TauFirstExplorer<Ctx>^ tauFirstExplorer = (TauFirstExplorer<Ctx>^)explorer; - action = mwt.Choose_Action(*tauFirstExplorer->Get(), marshal_as<std::string>(unique_key), native_context);
- }
+ action = mwt.Choose_Action(*tauFirstExplorer->Get(), marshal_as<std::string>(unique_key), native_context); + } else if (explorer->GetType() == SoftmaxExplorer<Ctx>::typeid) { SoftmaxExplorer<Ctx>^ softmaxExplorer = (SoftmaxExplorer<Ctx>^)explorer; - action = mwt.Choose_Action(*softmaxExplorer->Get(), marshal_as<std::string>(unique_key), native_context);
- }
+ action = mwt.Choose_Action(*softmaxExplorer->Get(), marshal_as<std::string>(unique_key), native_context); + } else if (explorer->GetType() == GenericExplorer<Ctx>::typeid) { GenericExplorer<Ctx>^ genericExplorer = (GenericExplorer<Ctx>^)explorer; - action = mwt.Choose_Action(*genericExplorer->Get(), marshal_as<std::string>(unique_key), native_context);
- }
+ action = mwt.Choose_Action(*genericExplorer->Get(), marshal_as<std::string>(unique_key), native_context); + } else if (explorer->GetType() == BaggingExplorer<Ctx>::typeid) { BaggingExplorer<Ctx>^ baggingExplorer = (BaggingExplorer<Ctx>^)explorer; - action = mwt.Choose_Action(*baggingExplorer->Get(), marshal_as<std::string>(unique_key), native_context);
- }
-
- explorerHandle.Free();
- contextHandle.Free();
- selfHandle.Free();
-
- return action;
+ action = mwt.Choose_Action(*baggingExplorer->Get(), marshal_as<std::string>(unique_key), native_context); + } + + explorerHandle.Free(); + contextHandle.Free(); + selfHandle.Free(); + + return action; } internal: @@ -275,7 +275,7 @@ namespace MultiWorldTesting { /// </returns> String^ FlushRecording() { - return gcnew String(m_string_recorder->Flush_Recording().c_str()); + return gcnew String(m_string_recorder->Get_Recording().c_str()); } private: @@ -291,9 +291,9 @@ namespace MultiWorldTesting { // TODO: add another constructor overload for native SimpleContext to avoid copying feature values m_features = new vector<NativeMultiWorldTesting::Feature>(); - for (int i = 0; i < features->Length; i++)
- {
- m_features->push_back({ features[i].Value, features[i].Id });
+ for (int i = 0; i < features->Length; i++) + { + m_features->push_back({ features[i].Value, features[i].Id }); } m_native_context = new NativeMultiWorldTesting::SimpleContext(*m_features); diff --git a/explore/explore_sample.cpp b/explore/explore_sample.cpp index 43da6deb..aff6d1ef 100644 --- a/explore/explore_sample.cpp +++ b/explore/explore_sample.cpp @@ -103,7 +103,7 @@ int main(int argc, char* argv[]) u32 action = mwt.Choose_Action(explorer, unique_key, context); cout << "Chosen action = " << action << endl; - cout << "Exploration record = " << recorder.Flush_Recording(); + cout << "Exploration record = " << recorder.Get_Recording(); } else if (strcmp(argv[1], "tau-first") == 0) { @@ -119,7 +119,7 @@ int main(int argc, char* argv[]) string unique_key = "eventid"; u32 action = mwt.Choose_Action(explorer, unique_key, ctx); - cout << "action = " << action << endl; + cout << "Chosen action = " << action << endl; } else if (strcmp(argv[1], "bagging") == 0) { @@ -139,7 +139,7 @@ int main(int argc, char* argv[]) string unique_key = "eventid"; u32 action = mwt.Choose_Action(explorer, unique_key, ctx); - cout << "action = " << action << endl; + cout << "Chosen action = " << action << endl; } else if (strcmp(argv[1], "softmax") == 0) { @@ -156,7 +156,7 @@ int main(int argc, char* argv[]) string unique_key = "eventid"; u32 action = mwt.Choose_Action(explorer, unique_key, ctx); - cout << "action = " << action << endl; + cout << "Chosen action = " << action << endl; } else if (strcmp(argv[1], "generic") == 0) { @@ -171,7 +171,7 @@ int main(int argc, char* argv[]) string unique_key = "eventid"; u32 action = mwt.Choose_Action(explorer, unique_key, ctx); - cout << "action = " << action << endl; + cout << "Chosen action = " << action << endl; } else { diff --git a/explore/static/MWTExplorer.h b/explore/static/MWTExplorer.h index 6ec8bab6..575a44a2 100644 --- a/explore/static/MWTExplorer.h +++ b/explore/static/MWTExplorer.h @@ -107,9 +107,13 @@ struct StringRecorder : public IRecorder<Ctx> m_recording.append("\n"); } - // Gets the content of recording so far as a string and clears internal content. - string Flush_Recording() + // Gets the content of the recording so far as a string and optionally clears internal content. + string Get_Recording(bool flush = true) { + if (!flush) + { + return m_recording; + } string recording = m_recording; m_recording.clear(); return recording; diff --git a/explore/tests/MWTExploreTests.cpp b/explore/tests/MWTExploreTests.cpp index 6ea35b7a..b3b1ef74 100644 --- a/explore/tests/MWTExploreTests.cpp +++ b/explore/tests/MWTExploreTests.cpp @@ -1,636 +1,636 @@ -#include "CppUnitTest.h"
-#include "MWTExploreTests.h"
-
-using namespace Microsoft::VisualStudio::CppUnitTestFramework;
-
-#define COUNT_INVALID(block) try { block } catch (std::invalid_argument) { num_ex++; }
-#define COUNT_BAD_CALL(block) try { block } catch (std::invalid_argument) { num_ex++; }
-
-namespace vw_explore_tests
-{
- TEST_CLASS(VWExploreUnitTests)
- {
- public:
- TEST_METHOD(Epsilon_Greedy)
- {
- int num_actions = 10;
- float epsilon = 0.f; // No randomization
- string unique_key = "1001";
- int params = 101;
-
- TestPolicy my_policy(params, num_actions);
- TestContext my_context;
- TestRecorder my_recorder;
-
- MwtExplorer<TestContext> mwt("salt", my_recorder);
- EpsilonGreedyExplorer<TestContext> explorer(my_policy, epsilon, num_actions);
-
- u32 expected_action = my_policy.Choose_Action(my_context);
-
- u32 chosen_action = mwt.Choose_Action(explorer, unique_key, my_context);
- Assert::AreEqual(expected_action, chosen_action);
-
- chosen_action = mwt.Choose_Action(explorer, unique_key, my_context);
- Assert::AreEqual(expected_action, chosen_action);
-
- float expected_probs[2] = { 1.f, 1.f };
- vector<TestInteraction<TestContext>> interactions = my_recorder.Get_All_Interactions();
- this->Test_Interactions(interactions, 2, expected_probs);
- }
-
- TEST_METHOD(Epsilon_Greedy_Random)
- {
- int num_actions = 10;
- float epsilon = 0.5f; // Verify that about half the time the default policy is chosen
- int params = 101;
-
- TestPolicy my_policy(params, num_actions);
- TestContext my_context;
- TestRecorder my_recorder;
-
- MwtExplorer<TestContext> mwt("salt", my_recorder);
- EpsilonGreedyExplorer<TestContext> explorer(my_policy, epsilon, num_actions);
-
- u32 policy_action = my_policy.Choose_Action(my_context);
-
- int times_choose = 10000;
- int times_policy_action_chosen = 0;
- for (int i = 0; i < times_choose; i++)
- {
- u32 chosen_action = mwt.Choose_Action(explorer, this->Get_Unique_Key(i), my_context);
- if (chosen_action == policy_action)
- {
- times_policy_action_chosen++;
- }
- }
-
- Assert::IsTrue(abs((double)times_policy_action_chosen / times_choose - 0.5) < 0.1);
- }
-
- TEST_METHOD(Tau_First)
- {
- int num_actions = 10;
- u32 tau = 0;
- int params = 101;
-
- TestPolicy my_policy(params, num_actions);
- TestRecorder my_recorder;
- TestContext my_context;
-
- MwtExplorer<TestContext> mwt("salt", my_recorder);
- TauFirstExplorer<TestContext> explorer(my_policy, tau, num_actions);
-
- u32 expected_action = my_policy.Choose_Action(my_context);
-
- u32 chosen_action = mwt.Choose_Action(explorer, this->Get_Unique_Key(1), my_context);
- Assert::AreEqual(expected_action, chosen_action);
-
- // tau = 0 means no randomization and no logging
- vector<TestInteraction<TestContext>> interactions = my_recorder.Get_All_Interactions();
- this->Test_Interactions(interactions, 0, nullptr);
- }
-
- TEST_METHOD(Tau_First_Random)
- {
- int num_actions = 10;
- u32 tau = 2;
- TestPolicy my_policy(99, num_actions);
- TestRecorder my_recorder;
- TestContext my_context;
-
- MwtExplorer<TestContext> mwt("salt", my_recorder);
- TauFirstExplorer<TestContext> explorer(my_policy, tau, num_actions);
-
- u32 chosen_action = mwt.Choose_Action(explorer, this->Get_Unique_Key(1), my_context);
- chosen_action = mwt.Choose_Action(explorer, this->Get_Unique_Key(2), my_context);
-
- // Tau expired, did not explore
- chosen_action = mwt.Choose_Action(explorer, this->Get_Unique_Key(3), my_context);
- Assert::AreEqual((u32)10, chosen_action);
-
- // Only 2 interactions logged, 3rd one should not be stored
- vector<TestInteraction<TestContext>> interactions = my_recorder.Get_All_Interactions();
- float expected_probs[2] = { .1f, .1f };
- this->Test_Interactions(interactions, 2, expected_probs);
- }
-
- TEST_METHOD(Bagging)
- {
- int num_actions = 10;
- int params = 101;
- TestRecorder my_recorder;
-
- vector<unique_ptr<IPolicy<TestContext>>> policies;
- policies.push_back(unique_ptr<IPolicy<TestContext>>(new TestPolicy(params, num_actions)));
- policies.push_back(unique_ptr<IPolicy<TestContext>>(new TestPolicy(params + 1, num_actions)));
-
- TestContext my_context;
-
- MwtExplorer<TestContext> mwt("c++-test", my_recorder);
- BaggingExplorer<TestContext> explorer(policies, num_actions);
-
- u32 expected_action1 = policies[0]->Choose_Action(my_context);
- u32 expected_action2 = policies[1]->Choose_Action(my_context);
-
- u32 chosen_action = mwt.Choose_Action(explorer, this->Get_Unique_Key(1), my_context);
- Assert::AreEqual(expected_action2, chosen_action);
-
- chosen_action = mwt.Choose_Action(explorer, this->Get_Unique_Key(2), my_context);
- Assert::AreEqual(expected_action1, chosen_action);
-
- vector<TestInteraction<TestContext>> interactions = my_recorder.Get_All_Interactions();
- float expected_probs[2] = { .5f, .5f };
- this->Test_Interactions(interactions, 2, expected_probs);
- }
-
- TEST_METHOD(Bagging_Random)
- {
- int num_actions = 10;
- int params = 101;
- TestRecorder my_recorder;
-
- vector<unique_ptr<IPolicy<TestContext>>> policies;
- policies.push_back(unique_ptr<IPolicy<TestContext>>(new TestPolicy(params, num_actions)));
- policies.push_back(unique_ptr<IPolicy<TestContext>>(new TestPolicy(params + 1, num_actions)));
-
- TestContext my_context;
-
- MwtExplorer<TestContext> mwt("c++-test", my_recorder);
- BaggingExplorer<TestContext> explorer(policies, num_actions);
-
- u32 chosen_action = mwt.Choose_Action(explorer, this->Get_Unique_Key(1), my_context);
- chosen_action = mwt.Choose_Action(explorer, this->Get_Unique_Key(2), my_context);
-
- // Two bags choosing different actions so prob of each is 1/2
- vector<TestInteraction<TestContext>> interactions = my_recorder.Get_All_Interactions();
- float expected_probs[2] = { .5f, .5f };
- this->Test_Interactions(interactions, 2, expected_probs);
- }
-
- TEST_METHOD(Softmax)
- {
- int num_actions = 10;
- float lambda = 0.f;
- int scorer_arg = 7;
- u32 NUM_ACTIONS_COVER = 100;
- float C = 5.0f;
-
- TestScorer my_scorer(scorer_arg, num_actions);
- TestRecorder my_recorder;
- TestContext my_context;
-
- MwtExplorer<TestContext> mwt("salt", my_recorder);
- SoftmaxExplorer<TestContext> explorer(my_scorer, lambda, num_actions);
-
- // Scale C up since we have fewer interactions
- u32 num_decisions = num_actions * log(num_actions * 1.0) + log(NUM_ACTIONS_COVER * 1.0 / num_actions) * C * num_actions;
- // The () following the array should ensure zero-initialization
- u32* actions = new u32[num_actions]();
- u32 i;
- for (i = 0; i < num_decisions; i++)
- {
- u32 action = mwt.Choose_Action(explorer, this->Get_Unique_Key(i + 1), my_context);
- // Action IDs are 1-based
- actions[action - 1]++;
- }
- // Ensure all actions are covered
- for (i = 0; i < num_actions; i++)
- {
- Assert::IsTrue(actions[i] > 0);
- }
- float* expected_probs = new float[num_decisions];
- for (i = 0; i < num_decisions; i++)
- {
- // Our default scorer currently assigns equal weight to each action
- expected_probs[i] = 1.0 / num_actions;
- }
- vector<TestInteraction<TestContext>> interactions = my_recorder.Get_All_Interactions();
- this->Test_Interactions(interactions, num_decisions, expected_probs);
-
- delete actions;
- delete expected_probs;
- }
-
- TEST_METHOD(Softmax_Scores)
- {
- int num_actions = 10;
- float lambda = 0.5f;
- int scorer_arg = 7;
- TestScorer my_scorer(scorer_arg, num_actions, /* uniform = */ false);
- TestRecorder my_recorder;
- TestContext my_context;
-
- MwtExplorer<TestContext> mwt("salt", my_recorder);
- SoftmaxExplorer<TestContext> explorer(my_scorer, lambda, num_actions);
-
- u32 action = mwt.Choose_Action(explorer, this->Get_Unique_Key(1), my_context);
- action = mwt.Choose_Action(explorer, this->Get_Unique_Key(2), my_context);
- action = mwt.Choose_Action(explorer, this->Get_Unique_Key(3), my_context);
-
- vector<TestInteraction<TestContext>> interactions = my_recorder.Get_All_Interactions();
- size_t num_interactions = interactions.size();
-
- Assert::AreEqual(3, (int)num_interactions);
- for (int i = 0; i < num_interactions; i++)
- {
- Assert::AreNotEqual(1.f / num_actions, interactions[i].Probability);
- }
- }
-
- TEST_METHOD(Generic)
- {
- int num_actions = 10;
- int scorer_arg = 7;
- TestScorer my_scorer(scorer_arg, num_actions);
- TestRecorder my_recorder;
- TestContext my_context;
-
- MwtExplorer<TestContext> mwt("salt", my_recorder);
- GenericExplorer<TestContext> explorer(my_scorer, num_actions);
-
- u32 chosen_action = mwt.Choose_Action(explorer, this->Get_Unique_Key(1), my_context);
- chosen_action = mwt.Choose_Action(explorer, this->Get_Unique_Key(2), my_context);
- chosen_action = mwt.Choose_Action(explorer, this->Get_Unique_Key(3), my_context);
-
- vector<TestInteraction<TestContext>> interactions = my_recorder.Get_All_Interactions();
- float expected_probs[3] = { .1f, .1f, .1f };
- this->Test_Interactions(interactions, 3, expected_probs);
- }
-
- TEST_METHOD(End_To_End_Epsilon_Greedy)
- {
- int num_actions = 10;
- float epsilon = 0.5f;
- int params = 101;
-
- TestSimplePolicy my_policy(params, num_actions);
- StringRecorder<SimpleContext> my_recorder;
-
- MwtExplorer<SimpleContext> mwt("salt", my_recorder);
- EpsilonGreedyExplorer<SimpleContext> explorer(my_policy, epsilon, num_actions);
-
- this->End_To_End(mwt, explorer, my_recorder);
- }
-
- TEST_METHOD(End_To_End_Tau_First)
- {
- int num_actions = 10;
- u32 tau = 5;
- int params = 101;
-
- TestSimplePolicy my_policy(params, num_actions);
- StringRecorder<SimpleContext> my_recorder;
-
- MwtExplorer<SimpleContext> mwt("salt", my_recorder);
- TauFirstExplorer<SimpleContext> explorer(my_policy, tau, num_actions);
-
- this->End_To_End(mwt, explorer, my_recorder);
- }
-
- TEST_METHOD(End_To_End_Bagging)
- {
- int num_actions = 10;
- u32 bags = 2;
- int params = 101;
- StringRecorder<SimpleContext> my_recorder;
-
- vector<unique_ptr<IPolicy<SimpleContext>>> policies;
- policies.push_back(unique_ptr<IPolicy<SimpleContext>>(new TestSimplePolicy(params, num_actions)));
- policies.push_back(unique_ptr<IPolicy<SimpleContext>>(new TestSimplePolicy(params, num_actions)));
-
- MwtExplorer<SimpleContext> mwt("salt", my_recorder);
- BaggingExplorer<SimpleContext> explorer(policies, num_actions);
-
- this->End_To_End(mwt, explorer, my_recorder);
- }
-
- TEST_METHOD(End_To_End_Softmax)
- {
- int num_actions = 10;
- float lambda = 0.5f;
- int scorer_arg = 7;
- TestSimpleScorer my_scorer(scorer_arg, num_actions);
- StringRecorder<SimpleContext> my_recorder;
-
- MwtExplorer<SimpleContext> mwt("salt", my_recorder);
- SoftmaxExplorer<SimpleContext> explorer(my_scorer, lambda, num_actions);
-
- this->End_To_End(mwt, explorer, my_recorder);
- }
-
- TEST_METHOD(End_To_End_Generic)
- {
- int num_actions = 10;
- int scorer_arg = 7;
- TestSimpleScorer my_scorer(scorer_arg, num_actions);
- StringRecorder<SimpleContext> my_recorder;
-
- MwtExplorer<SimpleContext> mwt("salt", my_recorder);
- GenericExplorer<SimpleContext> explorer(my_scorer, num_actions);
-
- this->End_To_End(mwt, explorer, my_recorder);
- }
-
- TEST_METHOD(PRG_Coverage)
- {
- const u32 NUM_ACTIONS_COVER = 100;
- float C = 5.0f;
-
- // We could use many fewer bits (e.g. u8) per bin since we're throwing uniformly at
- // random, but this is safer in case we change things
- u32 bins[NUM_ACTIONS_COVER] = { 0 };
- u32 num_balls = NUM_ACTIONS_COVER * log(NUM_ACTIONS_COVER) + C * NUM_ACTIONS_COVER;
- PRG::prg rg;
- u32 i;
- for (i = 0; i < num_balls; i++)
- {
- bins[rg.Uniform_Int(0, NUM_ACTIONS_COVER - 1)]++;
- }
- // Ensure all actions are covered
- for (i = 0; i < NUM_ACTIONS_COVER; i++)
- {
- Assert::IsTrue(bins[i] > 0);
- }
- }
-
- TEST_METHOD(Float_To_String)
- {
- PRG::prg rand;
-
- for (int i = 0; i < 10000; i++)
- {
- float f = (rand.Uniform_Unit_Interval() - 0.5f) * rand.Uniform_Int(0, 100000);
-
- ostringstream expected_stream;
- expected_stream << std::fixed << std::setprecision(10) << f;
- string expected_str = expected_stream.str();
-
- char actual_chars[15] = { 0 };
- NumberUtils::Float_To_String(f, actual_chars);
- string actual_str(actual_chars);
-
- size_t length = actual_str.length() - 1;
- int compare_result = expected_str.compare(0, length, actual_str, 0, length);
- Assert::AreEqual(0, compare_result);
- }
- }
-
- TEST_METHOD(Serialized_String)
- {
- int num_actions = 10;
- float epsilon = 0.5f;
- int params = 101;
-
- TestSimplePolicy my_policy(params, num_actions);
-
- StringRecorder<SimpleContext> my_recorder;
- MwtExplorer<SimpleContext> mwt("c++-test", my_recorder);
- EpsilonGreedyExplorer<SimpleContext> explorer(my_policy, epsilon, num_actions);
-
- vector<Feature> features1;
- features1.push_back({ 0.5f, 1 });
- SimpleContext context1(features1);
-
- u32 expected_action = my_policy.Choose_Action(context1);
-
- string unique_key1 = "key1";
- u32 chosen_action1 = mwt.Choose_Action(explorer, unique_key1, context1);
-
- vector<Feature> features2;
- features2.push_back({ -99999.5f, 123456789 });
- features2.push_back({ 1.5f, 39 });
-
- SimpleContext context2(features2);
-
- string unique_key2 = "key2";
- u32 chosen_action2 = mwt.Choose_Action(explorer, unique_key2, context2);
-
- string actual_log = my_recorder.Flush_Recording();
-
- // Use hard-coded string to be independent of sprintf
- char* expected_log = "2 key1 0.55000 | 1:.5\n2 key2 0.55000 | 123456789:-99999.5 39:1.5\n";
-
- Assert::AreEqual(expected_log, actual_log.c_str());
- }
-
- TEST_METHOD(Serialized_String_Random)
- {
- PRG::prg rand;
-
- int num_actions = 10;
- int params = 101;
-
- TestSimplePolicy my_policy(params, num_actions);
-
- char expected_log[100] = { 0 };
-
- for (int i = 0; i < 10000; i++)
- {
- StringRecorder<SimpleContext> my_recorder;
- MwtExplorer<SimpleContext> mwt("c++-test", my_recorder);
- EpsilonGreedyExplorer<SimpleContext> explorer(my_policy, 0.f, num_actions);
-
- Feature feature;
- feature.Value = (rand.Uniform_Unit_Interval() - 0.5f) * rand.Uniform_Int(0, 100000);
- feature.Id = i;
- vector<Feature> features;
- features.push_back(feature);
- SimpleContext my_context(features);
-
- u32 action = mwt.Choose_Action(explorer, "", my_context);
- string actual_log = my_recorder.Flush_Recording();
-
- ostringstream expected_stream;
- expected_stream << std::fixed << std::setprecision(10) << feature.Value;
-
- string expected_str = expected_stream.str();
- if (expected_str[0] == '0')
- {
- expected_str = expected_str.erase(0, 1);
- }
-
- sprintf(expected_log, "%d %s %.5f | %d:%s",
- action, "", 1.f, i, expected_str.c_str());
-
- size_t length = actual_log.length() - 1;
- int compare_result = string(expected_log).compare(0, length, actual_log, 0, length);
- Assert::AreEqual(0, compare_result);
- }
- }
-
- TEST_METHOD(Usage_Bad_Arguments)
- {
- int num_ex = 0;
- int params = 101;
- TestPolicy my_policy(params, 0);
- TestScorer my_scorer(params, 0);
- vector<unique_ptr<IPolicy<TestContext>>> policies;
-
- COUNT_INVALID(EpsilonGreedyExplorer<TestContext> explorer(my_policy, .5f, 0);) // Invalid # actions, must be > 0
- COUNT_INVALID(EpsilonGreedyExplorer<TestContext> explorer(my_policy, 1.5f, 10);) // Invalid epsilon, must be in [0,1]
- COUNT_INVALID(EpsilonGreedyExplorer<TestContext> explorer(my_policy, -.5f, 10);) // Invalid epsilon, must be in [0,1]
-
- COUNT_INVALID(BaggingExplorer<TestContext> explorer(policies, 0);) // Invalid # actions, must be > 0
- COUNT_INVALID(BaggingExplorer<TestContext> explorer(policies, 1);) // Invalid # bags, must be > 0
-
- COUNT_INVALID(TauFirstExplorer<TestContext> explorer(my_policy, 1, 0);) // Invalid # actions, must be > 0
- COUNT_INVALID(SoftmaxExplorer<TestContext> explorer(my_scorer, .5f, 0);) // Invalid # actions, must be > 0
- COUNT_INVALID(GenericExplorer<TestContext> explorer(my_scorer, 0);) // Invalid # actions, must be > 0
-
-
- Assert::AreEqual(8, num_ex);
- }
-
- TEST_METHOD(Usage_Bad_Policy)
- {
- int num_ex = 0;
-
- // Default policy returns action outside valid range
- COUNT_BAD_CALL
- (
- MwtExplorer<TestContext> mwt("salt", TestRecorder());
- EpsilonGreedyExplorer<TestContext> explorer(TestBadPolicy(), 0.f, (u32)1);
-
- u32 expected_action = mwt.Choose_Action(explorer, "1001", TestContext());
- )
- COUNT_BAD_CALL
- (
- MwtExplorer<TestContext> mwt("salt", TestRecorder());
- TauFirstExplorer<TestContext> explorer(TestBadPolicy(), (u32)0, (u32)1);
- mwt.Choose_Action(explorer, "test", TestContext());
- )
- COUNT_BAD_CALL
- (
- vector<unique_ptr<IPolicy<TestContext>>> policies;
- policies.push_back(unique_ptr<IPolicy<TestContext>>(new TestBadPolicy()));
- MwtExplorer<TestContext> mwt("salt", TestRecorder());
- BaggingExplorer<TestContext> explorer(policies, (u32)1);
- mwt.Choose_Action(explorer, "test", TestContext());
- )
- Assert::AreEqual(3, num_ex);
- }
-
- TEST_METHOD(Usage_Bad_Scorer)
- {
- int num_ex = 0;
-
- // Default policy returns action outside valid range
- COUNT_BAD_CALL
- (
- u32 num_actions = 1;
- FixedScorer scorer(num_actions, -1);
- MwtExplorer<TestContext> mwt("salt", TestRecorder());
- GenericExplorer<TestContext> explorer(scorer, num_actions);
- mwt.Choose_Action(explorer, "test", TestContext());
- )
- COUNT_BAD_CALL
- (
- u32 num_actions = 1;
- FixedScorer scorer(num_actions, 0);
- MwtExplorer<TestContext> mwt("salt", TestRecorder());
- GenericExplorer<TestContext> explorer(scorer, num_actions);
- mwt.Choose_Action(explorer, "test", TestContext());
- )
- Assert::AreEqual(2, num_ex);
- }
-
- TEST_METHOD(Custom_Context)
- {
- int num_actions = 10;
- float epsilon = 0.f; // No randomization
- string unique_key = "1001";
-
- TestSimplePolicy my_policy(0, num_actions);
-
- TestSimpleRecorder my_recorder;
- MwtExplorer<SimpleContext> mwt("salt", my_recorder);
-
- vector<Feature> features;
- features.push_back({ 0.5f, 1 });
- features.push_back({ 1.5f, 6 });
- features.push_back({ -5.3f, 13 });
- SimpleContext custom_context(features);
-
- EpsilonGreedyExplorer<SimpleContext> explorer(my_policy, epsilon, num_actions);
-
- u32 chosen_action = mwt.Choose_Action(explorer, unique_key, custom_context);
- Assert::AreEqual((u32)1, chosen_action);
-
- float expected_probs[1] = { 1.f };
-
- vector<TestInteraction<SimpleContext>> interactions = my_recorder.Get_All_Interactions();
- Assert::AreEqual(1, (int)interactions.size());
-
- SimpleContext* returned_context = &interactions[0].Context;
-
- size_t onf = features.size();
- Feature* of = &features[0];
-
- vector<Feature>& returned_features = returned_context->Get_Features();
- size_t rnf = returned_features.size();
- Feature* rf = &returned_features[0];
-
- Assert::AreEqual(rnf, onf);
- for (size_t i = 0; i < rnf; i++)
- {
- Assert::AreEqual(of[i].Id, rf[i].Id);
- Assert::AreEqual(of[i].Value, rf[i].Value);
- }
- }
-
- TEST_METHOD_INITIALIZE(Test_Initialize)
- {
- }
-
- TEST_METHOD_CLEANUP(Test_Cleanup)
- {
- }
-
- private:
- // Test end-to-end using StringRecorder with no crash
- template <class Exp>
- void End_To_End(MwtExplorer<SimpleContext>& mwt, Exp& explorer, StringRecorder<SimpleContext>& recorder)
- {
- PRG::prg rand;
-
- float rewards[10];
- for (int i = 0; i < 10; i++)
- {
- vector<Feature> features;
- for (int j = 0; j < 1000; j++)
- {
- features.push_back({ rand.Uniform_Unit_Interval(), j + 1 });
- }
- SimpleContext c(features);
-
- mwt.Choose_Action(explorer, to_string(i), c);
-
- rewards[i] = rand.Uniform_Unit_Interval();
- }
-
- recorder.Flush_Recording();
- }
-
- template <class Ctx>
- inline void Test_Interactions(vector<TestInteraction<Ctx>> interactions, int num_interactions_expected, float* probs_expected)
- {
- size_t num_interactions = interactions.size();
-
- Assert::AreEqual(num_interactions_expected, (int)num_interactions);
- for (int i = 0; i < num_interactions; i++)
- {
- Assert::AreEqual(probs_expected[i], interactions[i].Probability);
- }
- }
-
- string Get_Unique_Key(u32 seed)
- {
- PRG::prg rg(seed);
-
- std::ostringstream unique_key_container;
- unique_key_container << rg.Uniform_Unit_Interval();
-
- return unique_key_container.str();
- }
- };
-}
+#include "CppUnitTest.h" +#include "MWTExploreTests.h" + +using namespace Microsoft::VisualStudio::CppUnitTestFramework; + +#define COUNT_INVALID(block) try { block } catch (std::invalid_argument) { num_ex++; } +#define COUNT_BAD_CALL(block) try { block } catch (std::invalid_argument) { num_ex++; } + +namespace vw_explore_tests +{ + TEST_CLASS(VWExploreUnitTests) + { + public: + TEST_METHOD(Epsilon_Greedy) + { + int num_actions = 10; + float epsilon = 0.f; // No randomization + string unique_key = "1001"; + int params = 101; + + TestPolicy my_policy(params, num_actions); + TestContext my_context; + TestRecorder my_recorder; + + MwtExplorer<TestContext> mwt("salt", my_recorder); + EpsilonGreedyExplorer<TestContext> explorer(my_policy, epsilon, num_actions); + + u32 expected_action = my_policy.Choose_Action(my_context); + + u32 chosen_action = mwt.Choose_Action(explorer, unique_key, my_context); + Assert::AreEqual(expected_action, chosen_action); + + chosen_action = mwt.Choose_Action(explorer, unique_key, my_context); + Assert::AreEqual(expected_action, chosen_action); + + float expected_probs[2] = { 1.f, 1.f }; + vector<TestInteraction<TestContext>> interactions = my_recorder.Get_All_Interactions(); + this->Test_Interactions(interactions, 2, expected_probs); + } + + TEST_METHOD(Epsilon_Greedy_Random) + { + int num_actions = 10; + float epsilon = 0.5f; // Verify that about half the time the default policy is chosen + int params = 101; + + TestPolicy my_policy(params, num_actions); + TestContext my_context; + TestRecorder my_recorder; + + MwtExplorer<TestContext> mwt("salt", my_recorder); + EpsilonGreedyExplorer<TestContext> explorer(my_policy, epsilon, num_actions); + + u32 policy_action = my_policy.Choose_Action(my_context); + + int times_choose = 10000; + int times_policy_action_chosen = 0; + for (int i = 0; i < times_choose; i++) + { + u32 chosen_action = mwt.Choose_Action(explorer, this->Get_Unique_Key(i), my_context); + if (chosen_action == policy_action) + { + times_policy_action_chosen++; + } + } + + Assert::IsTrue(abs((double)times_policy_action_chosen / times_choose - 0.5) < 0.1); + } + + TEST_METHOD(Tau_First) + { + int num_actions = 10; + u32 tau = 0; + int params = 101; + + TestPolicy my_policy(params, num_actions); + TestRecorder my_recorder; + TestContext my_context; + + MwtExplorer<TestContext> mwt("salt", my_recorder); + TauFirstExplorer<TestContext> explorer(my_policy, tau, num_actions); + + u32 expected_action = my_policy.Choose_Action(my_context); + + u32 chosen_action = mwt.Choose_Action(explorer, this->Get_Unique_Key(1), my_context); + Assert::AreEqual(expected_action, chosen_action); + + // tau = 0 means no randomization and no logging + vector<TestInteraction<TestContext>> interactions = my_recorder.Get_All_Interactions(); + this->Test_Interactions(interactions, 0, nullptr); + } + + TEST_METHOD(Tau_First_Random) + { + int num_actions = 10; + u32 tau = 2; + TestPolicy my_policy(99, num_actions); + TestRecorder my_recorder; + TestContext my_context; + + MwtExplorer<TestContext> mwt("salt", my_recorder); + TauFirstExplorer<TestContext> explorer(my_policy, tau, num_actions); + + u32 chosen_action = mwt.Choose_Action(explorer, this->Get_Unique_Key(1), my_context); + chosen_action = mwt.Choose_Action(explorer, this->Get_Unique_Key(2), my_context); + + // Tau expired, did not explore + chosen_action = mwt.Choose_Action(explorer, this->Get_Unique_Key(3), my_context); + Assert::AreEqual((u32)10, chosen_action); + + // Only 2 interactions logged, 3rd one should not be stored + vector<TestInteraction<TestContext>> interactions = my_recorder.Get_All_Interactions(); + float expected_probs[2] = { .1f, .1f }; + this->Test_Interactions(interactions, 2, expected_probs); + } + + TEST_METHOD(Bagging) + { + int num_actions = 10; + int params = 101; + TestRecorder my_recorder; + + vector<unique_ptr<IPolicy<TestContext>>> policies; + policies.push_back(unique_ptr<IPolicy<TestContext>>(new TestPolicy(params, num_actions))); + policies.push_back(unique_ptr<IPolicy<TestContext>>(new TestPolicy(params + 1, num_actions))); + + TestContext my_context; + + MwtExplorer<TestContext> mwt("c++-test", my_recorder); + BaggingExplorer<TestContext> explorer(policies, num_actions); + + u32 expected_action1 = policies[0]->Choose_Action(my_context); + u32 expected_action2 = policies[1]->Choose_Action(my_context); + + u32 chosen_action = mwt.Choose_Action(explorer, this->Get_Unique_Key(1), my_context); + Assert::AreEqual(expected_action2, chosen_action); + + chosen_action = mwt.Choose_Action(explorer, this->Get_Unique_Key(2), my_context); + Assert::AreEqual(expected_action1, chosen_action); + + vector<TestInteraction<TestContext>> interactions = my_recorder.Get_All_Interactions(); + float expected_probs[2] = { .5f, .5f }; + this->Test_Interactions(interactions, 2, expected_probs); + } + + TEST_METHOD(Bagging_Random) + { + int num_actions = 10; + int params = 101; + TestRecorder my_recorder; + + vector<unique_ptr<IPolicy<TestContext>>> policies; + policies.push_back(unique_ptr<IPolicy<TestContext>>(new TestPolicy(params, num_actions))); + policies.push_back(unique_ptr<IPolicy<TestContext>>(new TestPolicy(params + 1, num_actions))); + + TestContext my_context; + + MwtExplorer<TestContext> mwt("c++-test", my_recorder); + BaggingExplorer<TestContext> explorer(policies, num_actions); + + u32 chosen_action = mwt.Choose_Action(explorer, this->Get_Unique_Key(1), my_context); + chosen_action = mwt.Choose_Action(explorer, this->Get_Unique_Key(2), my_context); + + // Two bags choosing different actions so prob of each is 1/2 + vector<TestInteraction<TestContext>> interactions = my_recorder.Get_All_Interactions(); + float expected_probs[2] = { .5f, .5f }; + this->Test_Interactions(interactions, 2, expected_probs); + } + + TEST_METHOD(Softmax) + { + int num_actions = 10; + float lambda = 0.f; + int scorer_arg = 7; + u32 NUM_ACTIONS_COVER = 100; + float C = 5.0f; + + TestScorer my_scorer(scorer_arg, num_actions); + TestRecorder my_recorder; + TestContext my_context; + + MwtExplorer<TestContext> mwt("salt", my_recorder); + SoftmaxExplorer<TestContext> explorer(my_scorer, lambda, num_actions); + + // Scale C up since we have fewer interactions + u32 num_decisions = num_actions * log(num_actions * 1.0) + log(NUM_ACTIONS_COVER * 1.0 / num_actions) * C * num_actions; + // The () following the array should ensure zero-initialization + u32* actions = new u32[num_actions](); + u32 i; + for (i = 0; i < num_decisions; i++) + { + u32 action = mwt.Choose_Action(explorer, this->Get_Unique_Key(i + 1), my_context); + // Action IDs are 1-based + actions[action - 1]++; + } + // Ensure all actions are covered + for (i = 0; i < num_actions; i++) + { + Assert::IsTrue(actions[i] > 0); + } + float* expected_probs = new float[num_decisions]; + for (i = 0; i < num_decisions; i++) + { + // Our default scorer currently assigns equal weight to each action + expected_probs[i] = 1.0 / num_actions; + } + vector<TestInteraction<TestContext>> interactions = my_recorder.Get_All_Interactions(); + this->Test_Interactions(interactions, num_decisions, expected_probs); + + delete actions; + delete expected_probs; + } + + TEST_METHOD(Softmax_Scores) + { + int num_actions = 10; + float lambda = 0.5f; + int scorer_arg = 7; + TestScorer my_scorer(scorer_arg, num_actions, /* uniform = */ false); + TestRecorder my_recorder; + TestContext my_context; + + MwtExplorer<TestContext> mwt("salt", my_recorder); + SoftmaxExplorer<TestContext> explorer(my_scorer, lambda, num_actions); + + u32 action = mwt.Choose_Action(explorer, this->Get_Unique_Key(1), my_context); + action = mwt.Choose_Action(explorer, this->Get_Unique_Key(2), my_context); + action = mwt.Choose_Action(explorer, this->Get_Unique_Key(3), my_context); + + vector<TestInteraction<TestContext>> interactions = my_recorder.Get_All_Interactions(); + size_t num_interactions = interactions.size(); + + Assert::AreEqual(3, (int)num_interactions); + for (int i = 0; i < num_interactions; i++) + { + Assert::AreNotEqual(1.f / num_actions, interactions[i].Probability); + } + } + + TEST_METHOD(Generic) + { + int num_actions = 10; + int scorer_arg = 7; + TestScorer my_scorer(scorer_arg, num_actions); + TestRecorder my_recorder; + TestContext my_context; + + MwtExplorer<TestContext> mwt("salt", my_recorder); + GenericExplorer<TestContext> explorer(my_scorer, num_actions); + + u32 chosen_action = mwt.Choose_Action(explorer, this->Get_Unique_Key(1), my_context); + chosen_action = mwt.Choose_Action(explorer, this->Get_Unique_Key(2), my_context); + chosen_action = mwt.Choose_Action(explorer, this->Get_Unique_Key(3), my_context); + + vector<TestInteraction<TestContext>> interactions = my_recorder.Get_All_Interactions(); + float expected_probs[3] = { .1f, .1f, .1f }; + this->Test_Interactions(interactions, 3, expected_probs); + } + + TEST_METHOD(End_To_End_Epsilon_Greedy) + { + int num_actions = 10; + float epsilon = 0.5f; + int params = 101; + + TestSimplePolicy my_policy(params, num_actions); + StringRecorder<SimpleContext> my_recorder; + + MwtExplorer<SimpleContext> mwt("salt", my_recorder); + EpsilonGreedyExplorer<SimpleContext> explorer(my_policy, epsilon, num_actions); + + this->End_To_End(mwt, explorer, my_recorder); + } + + TEST_METHOD(End_To_End_Tau_First) + { + int num_actions = 10; + u32 tau = 5; + int params = 101; + + TestSimplePolicy my_policy(params, num_actions); + StringRecorder<SimpleContext> my_recorder; + + MwtExplorer<SimpleContext> mwt("salt", my_recorder); + TauFirstExplorer<SimpleContext> explorer(my_policy, tau, num_actions); + + this->End_To_End(mwt, explorer, my_recorder); + } + + TEST_METHOD(End_To_End_Bagging) + { + int num_actions = 10; + u32 bags = 2; + int params = 101; + StringRecorder<SimpleContext> my_recorder; + + vector<unique_ptr<IPolicy<SimpleContext>>> policies; + policies.push_back(unique_ptr<IPolicy<SimpleContext>>(new TestSimplePolicy(params, num_actions))); + policies.push_back(unique_ptr<IPolicy<SimpleContext>>(new TestSimplePolicy(params, num_actions))); + + MwtExplorer<SimpleContext> mwt("salt", my_recorder); + BaggingExplorer<SimpleContext> explorer(policies, num_actions); + + this->End_To_End(mwt, explorer, my_recorder); + } + + TEST_METHOD(End_To_End_Softmax) + { + int num_actions = 10; + float lambda = 0.5f; + int scorer_arg = 7; + TestSimpleScorer my_scorer(scorer_arg, num_actions); + StringRecorder<SimpleContext> my_recorder; + + MwtExplorer<SimpleContext> mwt("salt", my_recorder); + SoftmaxExplorer<SimpleContext> explorer(my_scorer, lambda, num_actions); + + this->End_To_End(mwt, explorer, my_recorder); + } + + TEST_METHOD(End_To_End_Generic) + { + int num_actions = 10; + int scorer_arg = 7; + TestSimpleScorer my_scorer(scorer_arg, num_actions); + StringRecorder<SimpleContext> my_recorder; + + MwtExplorer<SimpleContext> mwt("salt", my_recorder); + GenericExplorer<SimpleContext> explorer(my_scorer, num_actions); + + this->End_To_End(mwt, explorer, my_recorder); + } + + TEST_METHOD(PRG_Coverage) + { + const u32 NUM_ACTIONS_COVER = 100; + float C = 5.0f; + + // We could use many fewer bits (e.g. u8) per bin since we're throwing uniformly at + // random, but this is safer in case we change things + u32 bins[NUM_ACTIONS_COVER] = { 0 }; + u32 num_balls = NUM_ACTIONS_COVER * log(NUM_ACTIONS_COVER) + C * NUM_ACTIONS_COVER; + PRG::prg rg; + u32 i; + for (i = 0; i < num_balls; i++) + { + bins[rg.Uniform_Int(0, NUM_ACTIONS_COVER - 1)]++; + } + // Ensure all actions are covered + for (i = 0; i < NUM_ACTIONS_COVER; i++) + { + Assert::IsTrue(bins[i] > 0); + } + } + + TEST_METHOD(Float_To_String) + { + PRG::prg rand; + + for (int i = 0; i < 10000; i++) + { + float f = (rand.Uniform_Unit_Interval() - 0.5f) * rand.Uniform_Int(0, 100000); + + ostringstream expected_stream; + expected_stream << std::fixed << std::setprecision(10) << f; + string expected_str = expected_stream.str(); + + char actual_chars[15] = { 0 }; + NumberUtils::Float_To_String(f, actual_chars); + string actual_str(actual_chars); + + size_t length = actual_str.length() - 1; + int compare_result = expected_str.compare(0, length, actual_str, 0, length); + Assert::AreEqual(0, compare_result); + } + } + + TEST_METHOD(Serialized_String) + { + int num_actions = 10; + float epsilon = 0.5f; + int params = 101; + + TestSimplePolicy my_policy(params, num_actions); + + StringRecorder<SimpleContext> my_recorder; + MwtExplorer<SimpleContext> mwt("c++-test", my_recorder); + EpsilonGreedyExplorer<SimpleContext> explorer(my_policy, epsilon, num_actions); + + vector<Feature> features1; + features1.push_back({ 0.5f, 1 }); + SimpleContext context1(features1); + + u32 expected_action = my_policy.Choose_Action(context1); + + string unique_key1 = "key1"; + u32 chosen_action1 = mwt.Choose_Action(explorer, unique_key1, context1); + + vector<Feature> features2; + features2.push_back({ -99999.5f, 123456789 }); + features2.push_back({ 1.5f, 39 }); + + SimpleContext context2(features2); + + string unique_key2 = "key2"; + u32 chosen_action2 = mwt.Choose_Action(explorer, unique_key2, context2); + + string actual_log = my_recorder.Get_Recording(); + + // Use hard-coded string to be independent of sprintf + char* expected_log = "2 key1 0.55000 | 1:.5\n2 key2 0.55000 | 123456789:-99999.5 39:1.5\n"; + + Assert::AreEqual(expected_log, actual_log.c_str()); + } + + TEST_METHOD(Serialized_String_Random) + { + PRG::prg rand; + + int num_actions = 10; + int params = 101; + + TestSimplePolicy my_policy(params, num_actions); + + char expected_log[100] = { 0 }; + + for (int i = 0; i < 10000; i++) + { + StringRecorder<SimpleContext> my_recorder; + MwtExplorer<SimpleContext> mwt("c++-test", my_recorder); + EpsilonGreedyExplorer<SimpleContext> explorer(my_policy, 0.f, num_actions); + + Feature feature; + feature.Value = (rand.Uniform_Unit_Interval() - 0.5f) * rand.Uniform_Int(0, 100000); + feature.Id = i; + vector<Feature> features; + features.push_back(feature); + SimpleContext my_context(features); + + u32 action = mwt.Choose_Action(explorer, "", my_context); + string actual_log = my_recorder.Get_Recording(); + + ostringstream expected_stream; + expected_stream << std::fixed << std::setprecision(10) << feature.Value; + + string expected_str = expected_stream.str(); + if (expected_str[0] == '0') + { + expected_str = expected_str.erase(0, 1); + } + + sprintf(expected_log, "%d %s %.5f | %d:%s", + action, "", 1.f, i, expected_str.c_str()); + + size_t length = actual_log.length() - 1; + int compare_result = string(expected_log).compare(0, length, actual_log, 0, length); + Assert::AreEqual(0, compare_result); + } + } + + TEST_METHOD(Usage_Bad_Arguments) + { + int num_ex = 0; + int params = 101; + TestPolicy my_policy(params, 0); + TestScorer my_scorer(params, 0); + vector<unique_ptr<IPolicy<TestContext>>> policies; + + COUNT_INVALID(EpsilonGreedyExplorer<TestContext> explorer(my_policy, .5f, 0);) // Invalid # actions, must be > 0 + COUNT_INVALID(EpsilonGreedyExplorer<TestContext> explorer(my_policy, 1.5f, 10);) // Invalid epsilon, must be in [0,1] + COUNT_INVALID(EpsilonGreedyExplorer<TestContext> explorer(my_policy, -.5f, 10);) // Invalid epsilon, must be in [0,1] + + COUNT_INVALID(BaggingExplorer<TestContext> explorer(policies, 0);) // Invalid # actions, must be > 0 + COUNT_INVALID(BaggingExplorer<TestContext> explorer(policies, 1);) // Invalid # bags, must be > 0 + + COUNT_INVALID(TauFirstExplorer<TestContext> explorer(my_policy, 1, 0);) // Invalid # actions, must be > 0 + COUNT_INVALID(SoftmaxExplorer<TestContext> explorer(my_scorer, .5f, 0);) // Invalid # actions, must be > 0 + COUNT_INVALID(GenericExplorer<TestContext> explorer(my_scorer, 0);) // Invalid # actions, must be > 0 + + + Assert::AreEqual(8, num_ex); + } + + TEST_METHOD(Usage_Bad_Policy) + { + int num_ex = 0; + + // Default policy returns action outside valid range + COUNT_BAD_CALL + ( + MwtExplorer<TestContext> mwt("salt", TestRecorder()); + EpsilonGreedyExplorer<TestContext> explorer(TestBadPolicy(), 0.f, (u32)1); + + u32 expected_action = mwt.Choose_Action(explorer, "1001", TestContext()); + ) + COUNT_BAD_CALL + ( + MwtExplorer<TestContext> mwt("salt", TestRecorder()); + TauFirstExplorer<TestContext> explorer(TestBadPolicy(), (u32)0, (u32)1); + mwt.Choose_Action(explorer, "test", TestContext()); + ) + COUNT_BAD_CALL + ( + vector<unique_ptr<IPolicy<TestContext>>> policies; + policies.push_back(unique_ptr<IPolicy<TestContext>>(new TestBadPolicy())); + MwtExplorer<TestContext> mwt("salt", TestRecorder()); + BaggingExplorer<TestContext> explorer(policies, (u32)1); + mwt.Choose_Action(explorer, "test", TestContext()); + ) + Assert::AreEqual(3, num_ex); + } + + TEST_METHOD(Usage_Bad_Scorer) + { + int num_ex = 0; + + // Default policy returns action outside valid range + COUNT_BAD_CALL + ( + u32 num_actions = 1; + FixedScorer scorer(num_actions, -1); + MwtExplorer<TestContext> mwt("salt", TestRecorder()); + GenericExplorer<TestContext> explorer(scorer, num_actions); + mwt.Choose_Action(explorer, "test", TestContext()); + ) + COUNT_BAD_CALL + ( + u32 num_actions = 1; + FixedScorer scorer(num_actions, 0); + MwtExplorer<TestContext> mwt("salt", TestRecorder()); + GenericExplorer<TestContext> explorer(scorer, num_actions); + mwt.Choose_Action(explorer, "test", TestContext()); + ) + Assert::AreEqual(2, num_ex); + } + + TEST_METHOD(Custom_Context) + { + int num_actions = 10; + float epsilon = 0.f; // No randomization + string unique_key = "1001"; + + TestSimplePolicy my_policy(0, num_actions); + + TestSimpleRecorder my_recorder; + MwtExplorer<SimpleContext> mwt("salt", my_recorder); + + vector<Feature> features; + features.push_back({ 0.5f, 1 }); + features.push_back({ 1.5f, 6 }); + features.push_back({ -5.3f, 13 }); + SimpleContext custom_context(features); + + EpsilonGreedyExplorer<SimpleContext> explorer(my_policy, epsilon, num_actions); + + u32 chosen_action = mwt.Choose_Action(explorer, unique_key, custom_context); + Assert::AreEqual((u32)1, chosen_action); + + float expected_probs[1] = { 1.f }; + + vector<TestInteraction<SimpleContext>> interactions = my_recorder.Get_All_Interactions(); + Assert::AreEqual(1, (int)interactions.size()); + + SimpleContext* returned_context = &interactions[0].Context; + + size_t onf = features.size(); + Feature* of = &features[0]; + + vector<Feature>& returned_features = returned_context->Get_Features(); + size_t rnf = returned_features.size(); + Feature* rf = &returned_features[0]; + + Assert::AreEqual(rnf, onf); + for (size_t i = 0; i < rnf; i++) + { + Assert::AreEqual(of[i].Id, rf[i].Id); + Assert::AreEqual(of[i].Value, rf[i].Value); + } + } + + TEST_METHOD_INITIALIZE(Test_Initialize) + { + } + + TEST_METHOD_CLEANUP(Test_Cleanup) + { + } + + private: + // Test end-to-end using StringRecorder with no crash + template <class Exp> + void End_To_End(MwtExplorer<SimpleContext>& mwt, Exp& explorer, StringRecorder<SimpleContext>& recorder) + { + PRG::prg rand; + + float rewards[10]; + for (int i = 0; i < 10; i++) + { + vector<Feature> features; + for (int j = 0; j < 1000; j++) + { + features.push_back({ rand.Uniform_Unit_Interval(), j + 1 }); + } + SimpleContext c(features); + + mwt.Choose_Action(explorer, to_string(i), c); + + rewards[i] = rand.Uniform_Unit_Interval(); + } + + recorder.Get_Recording(); + } + + template <class Ctx> + inline void Test_Interactions(vector<TestInteraction<Ctx>> interactions, int num_interactions_expected, float* probs_expected) + { + size_t num_interactions = interactions.size(); + + Assert::AreEqual(num_interactions_expected, (int)num_interactions); + for (int i = 0; i < num_interactions; i++) + { + Assert::AreEqual(probs_expected[i], interactions[i].Probability); + } + } + + string Get_Unique_Key(u32 seed) + { + PRG::prg rg(seed); + + std::ostringstream unique_key_container; + unique_key_container << rg.Uniform_Unit_Interval(); + + return unique_key_container.str(); + } + }; +} |