Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/moses-smt/vowpal_wabbit.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorsidsen <sid.sen1@gmail.com>2014-11-11 19:45:18 +0300
committersidsen <sid.sen1@gmail.com>2014-11-11 19:45:18 +0300
commitd4c38ae4284222fbbe51f1152bab0f92df2e2d4e (patch)
tree8a6d8ec52a939c2b9f5f8fd9c4c4e14eb3a78f22 /explore
parent31a33a0dc8be432816337b9926da3f681a2ea0d8 (diff)
Get_Recording method takes bool flush argument
Diffstat (limited to 'explore')
-rw-r--r--explore/clr/explore_clr_wrapper.h66
-rw-r--r--explore/explore_sample.cpp10
-rw-r--r--explore/static/MWTExplorer.h8
-rw-r--r--explore/tests/MWTExploreTests.cpp1272
4 files changed, 680 insertions, 676 deletions
diff --git a/explore/clr/explore_clr_wrapper.h b/explore/clr/explore_clr_wrapper.h
index b1ef80a8..427b8be2 100644
--- a/explore/clr/explore_clr_wrapper.h
+++ b/explore/clr/explore_clr_wrapper.h
@@ -177,53 +177,53 @@ namespace MultiWorldTesting {
this->recorder = recorder;
}
- UInt32 ChooseAction(IExplorer<Ctx>^ explorer, String^ unique_key, Ctx context)
- {
- String^ salt = this->appId;
+ UInt32 ChooseAction(IExplorer<Ctx>^ explorer, String^ unique_key, Ctx context)
+ {
+ String^ salt = this->appId;
NativeMultiWorldTesting::MwtExplorer<NativeContext> mwt(marshal_as<std::string>(salt), *GetNativeRecorder());
GCHandle selfHandle = GCHandle::Alloc(this);
- IntPtr selfPtr = (IntPtr)selfHandle;
-
- GCHandle contextHandle = GCHandle::Alloc(context);
- IntPtr contextPtr = (IntPtr)contextHandle;
-
- GCHandle explorerHandle = GCHandle::Alloc(explorer);
- IntPtr explorerPtr = (IntPtr)explorerHandle;
-
- NativeContext native_context(selfPtr.ToPointer(), explorerPtr.ToPointer(), contextPtr.ToPointer());
- u32 action = 0;
+ IntPtr selfPtr = (IntPtr)selfHandle;
+
+ GCHandle contextHandle = GCHandle::Alloc(context);
+ IntPtr contextPtr = (IntPtr)contextHandle;
+
+ GCHandle explorerHandle = GCHandle::Alloc(explorer);
+ IntPtr explorerPtr = (IntPtr)explorerHandle;
+
+ NativeContext native_context(selfPtr.ToPointer(), explorerPtr.ToPointer(), contextPtr.ToPointer());
+ u32 action = 0;
if (explorer->GetType() == EpsilonGreedyExplorer<Ctx>::typeid)
{
EpsilonGreedyExplorer<Ctx>^ epsilonGreedyExplorer = (EpsilonGreedyExplorer<Ctx>^)explorer;
- action = mwt.Choose_Action(*epsilonGreedyExplorer->Get(), marshal_as<std::string>(unique_key), native_context);
- }
+ action = mwt.Choose_Action(*epsilonGreedyExplorer->Get(), marshal_as<std::string>(unique_key), native_context);
+ }
else if (explorer->GetType() == TauFirstExplorer<Ctx>::typeid)
{
TauFirstExplorer<Ctx>^ tauFirstExplorer = (TauFirstExplorer<Ctx>^)explorer;
- action = mwt.Choose_Action(*tauFirstExplorer->Get(), marshal_as<std::string>(unique_key), native_context);
- }
+ action = mwt.Choose_Action(*tauFirstExplorer->Get(), marshal_as<std::string>(unique_key), native_context);
+ }
else if (explorer->GetType() == SoftmaxExplorer<Ctx>::typeid)
{
SoftmaxExplorer<Ctx>^ softmaxExplorer = (SoftmaxExplorer<Ctx>^)explorer;
- action = mwt.Choose_Action(*softmaxExplorer->Get(), marshal_as<std::string>(unique_key), native_context);
- }
+ action = mwt.Choose_Action(*softmaxExplorer->Get(), marshal_as<std::string>(unique_key), native_context);
+ }
else if (explorer->GetType() == GenericExplorer<Ctx>::typeid)
{
GenericExplorer<Ctx>^ genericExplorer = (GenericExplorer<Ctx>^)explorer;
- action = mwt.Choose_Action(*genericExplorer->Get(), marshal_as<std::string>(unique_key), native_context);
- }
+ action = mwt.Choose_Action(*genericExplorer->Get(), marshal_as<std::string>(unique_key), native_context);
+ }
else if (explorer->GetType() == BaggingExplorer<Ctx>::typeid)
{
BaggingExplorer<Ctx>^ baggingExplorer = (BaggingExplorer<Ctx>^)explorer;
- action = mwt.Choose_Action(*baggingExplorer->Get(), marshal_as<std::string>(unique_key), native_context);
- }
-
- explorerHandle.Free();
- contextHandle.Free();
- selfHandle.Free();
-
- return action;
+ action = mwt.Choose_Action(*baggingExplorer->Get(), marshal_as<std::string>(unique_key), native_context);
+ }
+
+ explorerHandle.Free();
+ contextHandle.Free();
+ selfHandle.Free();
+
+ return action;
}
internal:
@@ -275,7 +275,7 @@ namespace MultiWorldTesting {
/// </returns>
String^ FlushRecording()
{
- return gcnew String(m_string_recorder->Flush_Recording().c_str());
+ return gcnew String(m_string_recorder->Get_Recording().c_str());
}
private:
@@ -291,9 +291,9 @@ namespace MultiWorldTesting {
// TODO: add another constructor overload for native SimpleContext to avoid copying feature values
m_features = new vector<NativeMultiWorldTesting::Feature>();
- for (int i = 0; i < features->Length; i++)
- {
- m_features->push_back({ features[i].Value, features[i].Id });
+ for (int i = 0; i < features->Length; i++)
+ {
+ m_features->push_back({ features[i].Value, features[i].Id });
}
m_native_context = new NativeMultiWorldTesting::SimpleContext(*m_features);
diff --git a/explore/explore_sample.cpp b/explore/explore_sample.cpp
index 43da6deb..aff6d1ef 100644
--- a/explore/explore_sample.cpp
+++ b/explore/explore_sample.cpp
@@ -103,7 +103,7 @@ int main(int argc, char* argv[])
u32 action = mwt.Choose_Action(explorer, unique_key, context);
cout << "Chosen action = " << action << endl;
- cout << "Exploration record = " << recorder.Flush_Recording();
+ cout << "Exploration record = " << recorder.Get_Recording();
}
else if (strcmp(argv[1], "tau-first") == 0)
{
@@ -119,7 +119,7 @@ int main(int argc, char* argv[])
string unique_key = "eventid";
u32 action = mwt.Choose_Action(explorer, unique_key, ctx);
- cout << "action = " << action << endl;
+ cout << "Chosen action = " << action << endl;
}
else if (strcmp(argv[1], "bagging") == 0)
{
@@ -139,7 +139,7 @@ int main(int argc, char* argv[])
string unique_key = "eventid";
u32 action = mwt.Choose_Action(explorer, unique_key, ctx);
- cout << "action = " << action << endl;
+ cout << "Chosen action = " << action << endl;
}
else if (strcmp(argv[1], "softmax") == 0)
{
@@ -156,7 +156,7 @@ int main(int argc, char* argv[])
string unique_key = "eventid";
u32 action = mwt.Choose_Action(explorer, unique_key, ctx);
- cout << "action = " << action << endl;
+ cout << "Chosen action = " << action << endl;
}
else if (strcmp(argv[1], "generic") == 0)
{
@@ -171,7 +171,7 @@ int main(int argc, char* argv[])
string unique_key = "eventid";
u32 action = mwt.Choose_Action(explorer, unique_key, ctx);
- cout << "action = " << action << endl;
+ cout << "Chosen action = " << action << endl;
}
else
{
diff --git a/explore/static/MWTExplorer.h b/explore/static/MWTExplorer.h
index 6ec8bab6..575a44a2 100644
--- a/explore/static/MWTExplorer.h
+++ b/explore/static/MWTExplorer.h
@@ -107,9 +107,13 @@ struct StringRecorder : public IRecorder<Ctx>
m_recording.append("\n");
}
- // Gets the content of recording so far as a string and clears internal content.
- string Flush_Recording()
+ // Gets the content of the recording so far as a string and optionally clears internal content.
+ string Get_Recording(bool flush = true)
{
+ if (!flush)
+ {
+ return m_recording;
+ }
string recording = m_recording;
m_recording.clear();
return recording;
diff --git a/explore/tests/MWTExploreTests.cpp b/explore/tests/MWTExploreTests.cpp
index 6ea35b7a..b3b1ef74 100644
--- a/explore/tests/MWTExploreTests.cpp
+++ b/explore/tests/MWTExploreTests.cpp
@@ -1,636 +1,636 @@
-#include "CppUnitTest.h"
-#include "MWTExploreTests.h"
-
-using namespace Microsoft::VisualStudio::CppUnitTestFramework;
-
-#define COUNT_INVALID(block) try { block } catch (std::invalid_argument) { num_ex++; }
-#define COUNT_BAD_CALL(block) try { block } catch (std::invalid_argument) { num_ex++; }
-
-namespace vw_explore_tests
-{
- TEST_CLASS(VWExploreUnitTests)
- {
- public:
- TEST_METHOD(Epsilon_Greedy)
- {
- int num_actions = 10;
- float epsilon = 0.f; // No randomization
- string unique_key = "1001";
- int params = 101;
-
- TestPolicy my_policy(params, num_actions);
- TestContext my_context;
- TestRecorder my_recorder;
-
- MwtExplorer<TestContext> mwt("salt", my_recorder);
- EpsilonGreedyExplorer<TestContext> explorer(my_policy, epsilon, num_actions);
-
- u32 expected_action = my_policy.Choose_Action(my_context);
-
- u32 chosen_action = mwt.Choose_Action(explorer, unique_key, my_context);
- Assert::AreEqual(expected_action, chosen_action);
-
- chosen_action = mwt.Choose_Action(explorer, unique_key, my_context);
- Assert::AreEqual(expected_action, chosen_action);
-
- float expected_probs[2] = { 1.f, 1.f };
- vector<TestInteraction<TestContext>> interactions = my_recorder.Get_All_Interactions();
- this->Test_Interactions(interactions, 2, expected_probs);
- }
-
- TEST_METHOD(Epsilon_Greedy_Random)
- {
- int num_actions = 10;
- float epsilon = 0.5f; // Verify that about half the time the default policy is chosen
- int params = 101;
-
- TestPolicy my_policy(params, num_actions);
- TestContext my_context;
- TestRecorder my_recorder;
-
- MwtExplorer<TestContext> mwt("salt", my_recorder);
- EpsilonGreedyExplorer<TestContext> explorer(my_policy, epsilon, num_actions);
-
- u32 policy_action = my_policy.Choose_Action(my_context);
-
- int times_choose = 10000;
- int times_policy_action_chosen = 0;
- for (int i = 0; i < times_choose; i++)
- {
- u32 chosen_action = mwt.Choose_Action(explorer, this->Get_Unique_Key(i), my_context);
- if (chosen_action == policy_action)
- {
- times_policy_action_chosen++;
- }
- }
-
- Assert::IsTrue(abs((double)times_policy_action_chosen / times_choose - 0.5) < 0.1);
- }
-
- TEST_METHOD(Tau_First)
- {
- int num_actions = 10;
- u32 tau = 0;
- int params = 101;
-
- TestPolicy my_policy(params, num_actions);
- TestRecorder my_recorder;
- TestContext my_context;
-
- MwtExplorer<TestContext> mwt("salt", my_recorder);
- TauFirstExplorer<TestContext> explorer(my_policy, tau, num_actions);
-
- u32 expected_action = my_policy.Choose_Action(my_context);
-
- u32 chosen_action = mwt.Choose_Action(explorer, this->Get_Unique_Key(1), my_context);
- Assert::AreEqual(expected_action, chosen_action);
-
- // tau = 0 means no randomization and no logging
- vector<TestInteraction<TestContext>> interactions = my_recorder.Get_All_Interactions();
- this->Test_Interactions(interactions, 0, nullptr);
- }
-
- TEST_METHOD(Tau_First_Random)
- {
- int num_actions = 10;
- u32 tau = 2;
- TestPolicy my_policy(99, num_actions);
- TestRecorder my_recorder;
- TestContext my_context;
-
- MwtExplorer<TestContext> mwt("salt", my_recorder);
- TauFirstExplorer<TestContext> explorer(my_policy, tau, num_actions);
-
- u32 chosen_action = mwt.Choose_Action(explorer, this->Get_Unique_Key(1), my_context);
- chosen_action = mwt.Choose_Action(explorer, this->Get_Unique_Key(2), my_context);
-
- // Tau expired, did not explore
- chosen_action = mwt.Choose_Action(explorer, this->Get_Unique_Key(3), my_context);
- Assert::AreEqual((u32)10, chosen_action);
-
- // Only 2 interactions logged, 3rd one should not be stored
- vector<TestInteraction<TestContext>> interactions = my_recorder.Get_All_Interactions();
- float expected_probs[2] = { .1f, .1f };
- this->Test_Interactions(interactions, 2, expected_probs);
- }
-
- TEST_METHOD(Bagging)
- {
- int num_actions = 10;
- int params = 101;
- TestRecorder my_recorder;
-
- vector<unique_ptr<IPolicy<TestContext>>> policies;
- policies.push_back(unique_ptr<IPolicy<TestContext>>(new TestPolicy(params, num_actions)));
- policies.push_back(unique_ptr<IPolicy<TestContext>>(new TestPolicy(params + 1, num_actions)));
-
- TestContext my_context;
-
- MwtExplorer<TestContext> mwt("c++-test", my_recorder);
- BaggingExplorer<TestContext> explorer(policies, num_actions);
-
- u32 expected_action1 = policies[0]->Choose_Action(my_context);
- u32 expected_action2 = policies[1]->Choose_Action(my_context);
-
- u32 chosen_action = mwt.Choose_Action(explorer, this->Get_Unique_Key(1), my_context);
- Assert::AreEqual(expected_action2, chosen_action);
-
- chosen_action = mwt.Choose_Action(explorer, this->Get_Unique_Key(2), my_context);
- Assert::AreEqual(expected_action1, chosen_action);
-
- vector<TestInteraction<TestContext>> interactions = my_recorder.Get_All_Interactions();
- float expected_probs[2] = { .5f, .5f };
- this->Test_Interactions(interactions, 2, expected_probs);
- }
-
- TEST_METHOD(Bagging_Random)
- {
- int num_actions = 10;
- int params = 101;
- TestRecorder my_recorder;
-
- vector<unique_ptr<IPolicy<TestContext>>> policies;
- policies.push_back(unique_ptr<IPolicy<TestContext>>(new TestPolicy(params, num_actions)));
- policies.push_back(unique_ptr<IPolicy<TestContext>>(new TestPolicy(params + 1, num_actions)));
-
- TestContext my_context;
-
- MwtExplorer<TestContext> mwt("c++-test", my_recorder);
- BaggingExplorer<TestContext> explorer(policies, num_actions);
-
- u32 chosen_action = mwt.Choose_Action(explorer, this->Get_Unique_Key(1), my_context);
- chosen_action = mwt.Choose_Action(explorer, this->Get_Unique_Key(2), my_context);
-
- // Two bags choosing different actions so prob of each is 1/2
- vector<TestInteraction<TestContext>> interactions = my_recorder.Get_All_Interactions();
- float expected_probs[2] = { .5f, .5f };
- this->Test_Interactions(interactions, 2, expected_probs);
- }
-
- TEST_METHOD(Softmax)
- {
- int num_actions = 10;
- float lambda = 0.f;
- int scorer_arg = 7;
- u32 NUM_ACTIONS_COVER = 100;
- float C = 5.0f;
-
- TestScorer my_scorer(scorer_arg, num_actions);
- TestRecorder my_recorder;
- TestContext my_context;
-
- MwtExplorer<TestContext> mwt("salt", my_recorder);
- SoftmaxExplorer<TestContext> explorer(my_scorer, lambda, num_actions);
-
- // Scale C up since we have fewer interactions
- u32 num_decisions = num_actions * log(num_actions * 1.0) + log(NUM_ACTIONS_COVER * 1.0 / num_actions) * C * num_actions;
- // The () following the array should ensure zero-initialization
- u32* actions = new u32[num_actions]();
- u32 i;
- for (i = 0; i < num_decisions; i++)
- {
- u32 action = mwt.Choose_Action(explorer, this->Get_Unique_Key(i + 1), my_context);
- // Action IDs are 1-based
- actions[action - 1]++;
- }
- // Ensure all actions are covered
- for (i = 0; i < num_actions; i++)
- {
- Assert::IsTrue(actions[i] > 0);
- }
- float* expected_probs = new float[num_decisions];
- for (i = 0; i < num_decisions; i++)
- {
- // Our default scorer currently assigns equal weight to each action
- expected_probs[i] = 1.0 / num_actions;
- }
- vector<TestInteraction<TestContext>> interactions = my_recorder.Get_All_Interactions();
- this->Test_Interactions(interactions, num_decisions, expected_probs);
-
- delete actions;
- delete expected_probs;
- }
-
- TEST_METHOD(Softmax_Scores)
- {
- int num_actions = 10;
- float lambda = 0.5f;
- int scorer_arg = 7;
- TestScorer my_scorer(scorer_arg, num_actions, /* uniform = */ false);
- TestRecorder my_recorder;
- TestContext my_context;
-
- MwtExplorer<TestContext> mwt("salt", my_recorder);
- SoftmaxExplorer<TestContext> explorer(my_scorer, lambda, num_actions);
-
- u32 action = mwt.Choose_Action(explorer, this->Get_Unique_Key(1), my_context);
- action = mwt.Choose_Action(explorer, this->Get_Unique_Key(2), my_context);
- action = mwt.Choose_Action(explorer, this->Get_Unique_Key(3), my_context);
-
- vector<TestInteraction<TestContext>> interactions = my_recorder.Get_All_Interactions();
- size_t num_interactions = interactions.size();
-
- Assert::AreEqual(3, (int)num_interactions);
- for (int i = 0; i < num_interactions; i++)
- {
- Assert::AreNotEqual(1.f / num_actions, interactions[i].Probability);
- }
- }
-
- TEST_METHOD(Generic)
- {
- int num_actions = 10;
- int scorer_arg = 7;
- TestScorer my_scorer(scorer_arg, num_actions);
- TestRecorder my_recorder;
- TestContext my_context;
-
- MwtExplorer<TestContext> mwt("salt", my_recorder);
- GenericExplorer<TestContext> explorer(my_scorer, num_actions);
-
- u32 chosen_action = mwt.Choose_Action(explorer, this->Get_Unique_Key(1), my_context);
- chosen_action = mwt.Choose_Action(explorer, this->Get_Unique_Key(2), my_context);
- chosen_action = mwt.Choose_Action(explorer, this->Get_Unique_Key(3), my_context);
-
- vector<TestInteraction<TestContext>> interactions = my_recorder.Get_All_Interactions();
- float expected_probs[3] = { .1f, .1f, .1f };
- this->Test_Interactions(interactions, 3, expected_probs);
- }
-
- TEST_METHOD(End_To_End_Epsilon_Greedy)
- {
- int num_actions = 10;
- float epsilon = 0.5f;
- int params = 101;
-
- TestSimplePolicy my_policy(params, num_actions);
- StringRecorder<SimpleContext> my_recorder;
-
- MwtExplorer<SimpleContext> mwt("salt", my_recorder);
- EpsilonGreedyExplorer<SimpleContext> explorer(my_policy, epsilon, num_actions);
-
- this->End_To_End(mwt, explorer, my_recorder);
- }
-
- TEST_METHOD(End_To_End_Tau_First)
- {
- int num_actions = 10;
- u32 tau = 5;
- int params = 101;
-
- TestSimplePolicy my_policy(params, num_actions);
- StringRecorder<SimpleContext> my_recorder;
-
- MwtExplorer<SimpleContext> mwt("salt", my_recorder);
- TauFirstExplorer<SimpleContext> explorer(my_policy, tau, num_actions);
-
- this->End_To_End(mwt, explorer, my_recorder);
- }
-
- TEST_METHOD(End_To_End_Bagging)
- {
- int num_actions = 10;
- u32 bags = 2;
- int params = 101;
- StringRecorder<SimpleContext> my_recorder;
-
- vector<unique_ptr<IPolicy<SimpleContext>>> policies;
- policies.push_back(unique_ptr<IPolicy<SimpleContext>>(new TestSimplePolicy(params, num_actions)));
- policies.push_back(unique_ptr<IPolicy<SimpleContext>>(new TestSimplePolicy(params, num_actions)));
-
- MwtExplorer<SimpleContext> mwt("salt", my_recorder);
- BaggingExplorer<SimpleContext> explorer(policies, num_actions);
-
- this->End_To_End(mwt, explorer, my_recorder);
- }
-
- TEST_METHOD(End_To_End_Softmax)
- {
- int num_actions = 10;
- float lambda = 0.5f;
- int scorer_arg = 7;
- TestSimpleScorer my_scorer(scorer_arg, num_actions);
- StringRecorder<SimpleContext> my_recorder;
-
- MwtExplorer<SimpleContext> mwt("salt", my_recorder);
- SoftmaxExplorer<SimpleContext> explorer(my_scorer, lambda, num_actions);
-
- this->End_To_End(mwt, explorer, my_recorder);
- }
-
- TEST_METHOD(End_To_End_Generic)
- {
- int num_actions = 10;
- int scorer_arg = 7;
- TestSimpleScorer my_scorer(scorer_arg, num_actions);
- StringRecorder<SimpleContext> my_recorder;
-
- MwtExplorer<SimpleContext> mwt("salt", my_recorder);
- GenericExplorer<SimpleContext> explorer(my_scorer, num_actions);
-
- this->End_To_End(mwt, explorer, my_recorder);
- }
-
- TEST_METHOD(PRG_Coverage)
- {
- const u32 NUM_ACTIONS_COVER = 100;
- float C = 5.0f;
-
- // We could use many fewer bits (e.g. u8) per bin since we're throwing uniformly at
- // random, but this is safer in case we change things
- u32 bins[NUM_ACTIONS_COVER] = { 0 };
- u32 num_balls = NUM_ACTIONS_COVER * log(NUM_ACTIONS_COVER) + C * NUM_ACTIONS_COVER;
- PRG::prg rg;
- u32 i;
- for (i = 0; i < num_balls; i++)
- {
- bins[rg.Uniform_Int(0, NUM_ACTIONS_COVER - 1)]++;
- }
- // Ensure all actions are covered
- for (i = 0; i < NUM_ACTIONS_COVER; i++)
- {
- Assert::IsTrue(bins[i] > 0);
- }
- }
-
- TEST_METHOD(Float_To_String)
- {
- PRG::prg rand;
-
- for (int i = 0; i < 10000; i++)
- {
- float f = (rand.Uniform_Unit_Interval() - 0.5f) * rand.Uniform_Int(0, 100000);
-
- ostringstream expected_stream;
- expected_stream << std::fixed << std::setprecision(10) << f;
- string expected_str = expected_stream.str();
-
- char actual_chars[15] = { 0 };
- NumberUtils::Float_To_String(f, actual_chars);
- string actual_str(actual_chars);
-
- size_t length = actual_str.length() - 1;
- int compare_result = expected_str.compare(0, length, actual_str, 0, length);
- Assert::AreEqual(0, compare_result);
- }
- }
-
- TEST_METHOD(Serialized_String)
- {
- int num_actions = 10;
- float epsilon = 0.5f;
- int params = 101;
-
- TestSimplePolicy my_policy(params, num_actions);
-
- StringRecorder<SimpleContext> my_recorder;
- MwtExplorer<SimpleContext> mwt("c++-test", my_recorder);
- EpsilonGreedyExplorer<SimpleContext> explorer(my_policy, epsilon, num_actions);
-
- vector<Feature> features1;
- features1.push_back({ 0.5f, 1 });
- SimpleContext context1(features1);
-
- u32 expected_action = my_policy.Choose_Action(context1);
-
- string unique_key1 = "key1";
- u32 chosen_action1 = mwt.Choose_Action(explorer, unique_key1, context1);
-
- vector<Feature> features2;
- features2.push_back({ -99999.5f, 123456789 });
- features2.push_back({ 1.5f, 39 });
-
- SimpleContext context2(features2);
-
- string unique_key2 = "key2";
- u32 chosen_action2 = mwt.Choose_Action(explorer, unique_key2, context2);
-
- string actual_log = my_recorder.Flush_Recording();
-
- // Use hard-coded string to be independent of sprintf
- char* expected_log = "2 key1 0.55000 | 1:.5\n2 key2 0.55000 | 123456789:-99999.5 39:1.5\n";
-
- Assert::AreEqual(expected_log, actual_log.c_str());
- }
-
- TEST_METHOD(Serialized_String_Random)
- {
- PRG::prg rand;
-
- int num_actions = 10;
- int params = 101;
-
- TestSimplePolicy my_policy(params, num_actions);
-
- char expected_log[100] = { 0 };
-
- for (int i = 0; i < 10000; i++)
- {
- StringRecorder<SimpleContext> my_recorder;
- MwtExplorer<SimpleContext> mwt("c++-test", my_recorder);
- EpsilonGreedyExplorer<SimpleContext> explorer(my_policy, 0.f, num_actions);
-
- Feature feature;
- feature.Value = (rand.Uniform_Unit_Interval() - 0.5f) * rand.Uniform_Int(0, 100000);
- feature.Id = i;
- vector<Feature> features;
- features.push_back(feature);
- SimpleContext my_context(features);
-
- u32 action = mwt.Choose_Action(explorer, "", my_context);
- string actual_log = my_recorder.Flush_Recording();
-
- ostringstream expected_stream;
- expected_stream << std::fixed << std::setprecision(10) << feature.Value;
-
- string expected_str = expected_stream.str();
- if (expected_str[0] == '0')
- {
- expected_str = expected_str.erase(0, 1);
- }
-
- sprintf(expected_log, "%d %s %.5f | %d:%s",
- action, "", 1.f, i, expected_str.c_str());
-
- size_t length = actual_log.length() - 1;
- int compare_result = string(expected_log).compare(0, length, actual_log, 0, length);
- Assert::AreEqual(0, compare_result);
- }
- }
-
- TEST_METHOD(Usage_Bad_Arguments)
- {
- int num_ex = 0;
- int params = 101;
- TestPolicy my_policy(params, 0);
- TestScorer my_scorer(params, 0);
- vector<unique_ptr<IPolicy<TestContext>>> policies;
-
- COUNT_INVALID(EpsilonGreedyExplorer<TestContext> explorer(my_policy, .5f, 0);) // Invalid # actions, must be > 0
- COUNT_INVALID(EpsilonGreedyExplorer<TestContext> explorer(my_policy, 1.5f, 10);) // Invalid epsilon, must be in [0,1]
- COUNT_INVALID(EpsilonGreedyExplorer<TestContext> explorer(my_policy, -.5f, 10);) // Invalid epsilon, must be in [0,1]
-
- COUNT_INVALID(BaggingExplorer<TestContext> explorer(policies, 0);) // Invalid # actions, must be > 0
- COUNT_INVALID(BaggingExplorer<TestContext> explorer(policies, 1);) // Invalid # bags, must be > 0
-
- COUNT_INVALID(TauFirstExplorer<TestContext> explorer(my_policy, 1, 0);) // Invalid # actions, must be > 0
- COUNT_INVALID(SoftmaxExplorer<TestContext> explorer(my_scorer, .5f, 0);) // Invalid # actions, must be > 0
- COUNT_INVALID(GenericExplorer<TestContext> explorer(my_scorer, 0);) // Invalid # actions, must be > 0
-
-
- Assert::AreEqual(8, num_ex);
- }
-
- TEST_METHOD(Usage_Bad_Policy)
- {
- int num_ex = 0;
-
- // Default policy returns action outside valid range
- COUNT_BAD_CALL
- (
- MwtExplorer<TestContext> mwt("salt", TestRecorder());
- EpsilonGreedyExplorer<TestContext> explorer(TestBadPolicy(), 0.f, (u32)1);
-
- u32 expected_action = mwt.Choose_Action(explorer, "1001", TestContext());
- )
- COUNT_BAD_CALL
- (
- MwtExplorer<TestContext> mwt("salt", TestRecorder());
- TauFirstExplorer<TestContext> explorer(TestBadPolicy(), (u32)0, (u32)1);
- mwt.Choose_Action(explorer, "test", TestContext());
- )
- COUNT_BAD_CALL
- (
- vector<unique_ptr<IPolicy<TestContext>>> policies;
- policies.push_back(unique_ptr<IPolicy<TestContext>>(new TestBadPolicy()));
- MwtExplorer<TestContext> mwt("salt", TestRecorder());
- BaggingExplorer<TestContext> explorer(policies, (u32)1);
- mwt.Choose_Action(explorer, "test", TestContext());
- )
- Assert::AreEqual(3, num_ex);
- }
-
- TEST_METHOD(Usage_Bad_Scorer)
- {
- int num_ex = 0;
-
- // Default policy returns action outside valid range
- COUNT_BAD_CALL
- (
- u32 num_actions = 1;
- FixedScorer scorer(num_actions, -1);
- MwtExplorer<TestContext> mwt("salt", TestRecorder());
- GenericExplorer<TestContext> explorer(scorer, num_actions);
- mwt.Choose_Action(explorer, "test", TestContext());
- )
- COUNT_BAD_CALL
- (
- u32 num_actions = 1;
- FixedScorer scorer(num_actions, 0);
- MwtExplorer<TestContext> mwt("salt", TestRecorder());
- GenericExplorer<TestContext> explorer(scorer, num_actions);
- mwt.Choose_Action(explorer, "test", TestContext());
- )
- Assert::AreEqual(2, num_ex);
- }
-
- TEST_METHOD(Custom_Context)
- {
- int num_actions = 10;
- float epsilon = 0.f; // No randomization
- string unique_key = "1001";
-
- TestSimplePolicy my_policy(0, num_actions);
-
- TestSimpleRecorder my_recorder;
- MwtExplorer<SimpleContext> mwt("salt", my_recorder);
-
- vector<Feature> features;
- features.push_back({ 0.5f, 1 });
- features.push_back({ 1.5f, 6 });
- features.push_back({ -5.3f, 13 });
- SimpleContext custom_context(features);
-
- EpsilonGreedyExplorer<SimpleContext> explorer(my_policy, epsilon, num_actions);
-
- u32 chosen_action = mwt.Choose_Action(explorer, unique_key, custom_context);
- Assert::AreEqual((u32)1, chosen_action);
-
- float expected_probs[1] = { 1.f };
-
- vector<TestInteraction<SimpleContext>> interactions = my_recorder.Get_All_Interactions();
- Assert::AreEqual(1, (int)interactions.size());
-
- SimpleContext* returned_context = &interactions[0].Context;
-
- size_t onf = features.size();
- Feature* of = &features[0];
-
- vector<Feature>& returned_features = returned_context->Get_Features();
- size_t rnf = returned_features.size();
- Feature* rf = &returned_features[0];
-
- Assert::AreEqual(rnf, onf);
- for (size_t i = 0; i < rnf; i++)
- {
- Assert::AreEqual(of[i].Id, rf[i].Id);
- Assert::AreEqual(of[i].Value, rf[i].Value);
- }
- }
-
- TEST_METHOD_INITIALIZE(Test_Initialize)
- {
- }
-
- TEST_METHOD_CLEANUP(Test_Cleanup)
- {
- }
-
- private:
- // Test end-to-end using StringRecorder with no crash
- template <class Exp>
- void End_To_End(MwtExplorer<SimpleContext>& mwt, Exp& explorer, StringRecorder<SimpleContext>& recorder)
- {
- PRG::prg rand;
-
- float rewards[10];
- for (int i = 0; i < 10; i++)
- {
- vector<Feature> features;
- for (int j = 0; j < 1000; j++)
- {
- features.push_back({ rand.Uniform_Unit_Interval(), j + 1 });
- }
- SimpleContext c(features);
-
- mwt.Choose_Action(explorer, to_string(i), c);
-
- rewards[i] = rand.Uniform_Unit_Interval();
- }
-
- recorder.Flush_Recording();
- }
-
- template <class Ctx>
- inline void Test_Interactions(vector<TestInteraction<Ctx>> interactions, int num_interactions_expected, float* probs_expected)
- {
- size_t num_interactions = interactions.size();
-
- Assert::AreEqual(num_interactions_expected, (int)num_interactions);
- for (int i = 0; i < num_interactions; i++)
- {
- Assert::AreEqual(probs_expected[i], interactions[i].Probability);
- }
- }
-
- string Get_Unique_Key(u32 seed)
- {
- PRG::prg rg(seed);
-
- std::ostringstream unique_key_container;
- unique_key_container << rg.Uniform_Unit_Interval();
-
- return unique_key_container.str();
- }
- };
-}
+#include "CppUnitTest.h"
+#include "MWTExploreTests.h"
+
+using namespace Microsoft::VisualStudio::CppUnitTestFramework;
+
+#define COUNT_INVALID(block) try { block } catch (std::invalid_argument) { num_ex++; }
+#define COUNT_BAD_CALL(block) try { block } catch (std::invalid_argument) { num_ex++; }
+
+namespace vw_explore_tests
+{
+ TEST_CLASS(VWExploreUnitTests)
+ {
+ public:
+ TEST_METHOD(Epsilon_Greedy)
+ {
+ int num_actions = 10;
+ float epsilon = 0.f; // No randomization
+ string unique_key = "1001";
+ int params = 101;
+
+ TestPolicy my_policy(params, num_actions);
+ TestContext my_context;
+ TestRecorder my_recorder;
+
+ MwtExplorer<TestContext> mwt("salt", my_recorder);
+ EpsilonGreedyExplorer<TestContext> explorer(my_policy, epsilon, num_actions);
+
+ u32 expected_action = my_policy.Choose_Action(my_context);
+
+ u32 chosen_action = mwt.Choose_Action(explorer, unique_key, my_context);
+ Assert::AreEqual(expected_action, chosen_action);
+
+ chosen_action = mwt.Choose_Action(explorer, unique_key, my_context);
+ Assert::AreEqual(expected_action, chosen_action);
+
+ float expected_probs[2] = { 1.f, 1.f };
+ vector<TestInteraction<TestContext>> interactions = my_recorder.Get_All_Interactions();
+ this->Test_Interactions(interactions, 2, expected_probs);
+ }
+
+ TEST_METHOD(Epsilon_Greedy_Random)
+ {
+ int num_actions = 10;
+ float epsilon = 0.5f; // Verify that about half the time the default policy is chosen
+ int params = 101;
+
+ TestPolicy my_policy(params, num_actions);
+ TestContext my_context;
+ TestRecorder my_recorder;
+
+ MwtExplorer<TestContext> mwt("salt", my_recorder);
+ EpsilonGreedyExplorer<TestContext> explorer(my_policy, epsilon, num_actions);
+
+ u32 policy_action = my_policy.Choose_Action(my_context);
+
+ int times_choose = 10000;
+ int times_policy_action_chosen = 0;
+ for (int i = 0; i < times_choose; i++)
+ {
+ u32 chosen_action = mwt.Choose_Action(explorer, this->Get_Unique_Key(i), my_context);
+ if (chosen_action == policy_action)
+ {
+ times_policy_action_chosen++;
+ }
+ }
+
+ Assert::IsTrue(abs((double)times_policy_action_chosen / times_choose - 0.5) < 0.1);
+ }
+
+ TEST_METHOD(Tau_First)
+ {
+ int num_actions = 10;
+ u32 tau = 0;
+ int params = 101;
+
+ TestPolicy my_policy(params, num_actions);
+ TestRecorder my_recorder;
+ TestContext my_context;
+
+ MwtExplorer<TestContext> mwt("salt", my_recorder);
+ TauFirstExplorer<TestContext> explorer(my_policy, tau, num_actions);
+
+ u32 expected_action = my_policy.Choose_Action(my_context);
+
+ u32 chosen_action = mwt.Choose_Action(explorer, this->Get_Unique_Key(1), my_context);
+ Assert::AreEqual(expected_action, chosen_action);
+
+ // tau = 0 means no randomization and no logging
+ vector<TestInteraction<TestContext>> interactions = my_recorder.Get_All_Interactions();
+ this->Test_Interactions(interactions, 0, nullptr);
+ }
+
+ TEST_METHOD(Tau_First_Random)
+ {
+ int num_actions = 10;
+ u32 tau = 2;
+ TestPolicy my_policy(99, num_actions);
+ TestRecorder my_recorder;
+ TestContext my_context;
+
+ MwtExplorer<TestContext> mwt("salt", my_recorder);
+ TauFirstExplorer<TestContext> explorer(my_policy, tau, num_actions);
+
+ u32 chosen_action = mwt.Choose_Action(explorer, this->Get_Unique_Key(1), my_context);
+ chosen_action = mwt.Choose_Action(explorer, this->Get_Unique_Key(2), my_context);
+
+ // Tau expired, did not explore
+ chosen_action = mwt.Choose_Action(explorer, this->Get_Unique_Key(3), my_context);
+ Assert::AreEqual((u32)10, chosen_action);
+
+ // Only 2 interactions logged, 3rd one should not be stored
+ vector<TestInteraction<TestContext>> interactions = my_recorder.Get_All_Interactions();
+ float expected_probs[2] = { .1f, .1f };
+ this->Test_Interactions(interactions, 2, expected_probs);
+ }
+
+ TEST_METHOD(Bagging)
+ {
+ int num_actions = 10;
+ int params = 101;
+ TestRecorder my_recorder;
+
+ vector<unique_ptr<IPolicy<TestContext>>> policies;
+ policies.push_back(unique_ptr<IPolicy<TestContext>>(new TestPolicy(params, num_actions)));
+ policies.push_back(unique_ptr<IPolicy<TestContext>>(new TestPolicy(params + 1, num_actions)));
+
+ TestContext my_context;
+
+ MwtExplorer<TestContext> mwt("c++-test", my_recorder);
+ BaggingExplorer<TestContext> explorer(policies, num_actions);
+
+ u32 expected_action1 = policies[0]->Choose_Action(my_context);
+ u32 expected_action2 = policies[1]->Choose_Action(my_context);
+
+ u32 chosen_action = mwt.Choose_Action(explorer, this->Get_Unique_Key(1), my_context);
+ Assert::AreEqual(expected_action2, chosen_action);
+
+ chosen_action = mwt.Choose_Action(explorer, this->Get_Unique_Key(2), my_context);
+ Assert::AreEqual(expected_action1, chosen_action);
+
+ vector<TestInteraction<TestContext>> interactions = my_recorder.Get_All_Interactions();
+ float expected_probs[2] = { .5f, .5f };
+ this->Test_Interactions(interactions, 2, expected_probs);
+ }
+
+ TEST_METHOD(Bagging_Random)
+ {
+ int num_actions = 10;
+ int params = 101;
+ TestRecorder my_recorder;
+
+ vector<unique_ptr<IPolicy<TestContext>>> policies;
+ policies.push_back(unique_ptr<IPolicy<TestContext>>(new TestPolicy(params, num_actions)));
+ policies.push_back(unique_ptr<IPolicy<TestContext>>(new TestPolicy(params + 1, num_actions)));
+
+ TestContext my_context;
+
+ MwtExplorer<TestContext> mwt("c++-test", my_recorder);
+ BaggingExplorer<TestContext> explorer(policies, num_actions);
+
+ u32 chosen_action = mwt.Choose_Action(explorer, this->Get_Unique_Key(1), my_context);
+ chosen_action = mwt.Choose_Action(explorer, this->Get_Unique_Key(2), my_context);
+
+ // Two bags choosing different actions so prob of each is 1/2
+ vector<TestInteraction<TestContext>> interactions = my_recorder.Get_All_Interactions();
+ float expected_probs[2] = { .5f, .5f };
+ this->Test_Interactions(interactions, 2, expected_probs);
+ }
+
+ TEST_METHOD(Softmax)
+ {
+ int num_actions = 10;
+ float lambda = 0.f;
+ int scorer_arg = 7;
+ u32 NUM_ACTIONS_COVER = 100;
+ float C = 5.0f;
+
+ TestScorer my_scorer(scorer_arg, num_actions);
+ TestRecorder my_recorder;
+ TestContext my_context;
+
+ MwtExplorer<TestContext> mwt("salt", my_recorder);
+ SoftmaxExplorer<TestContext> explorer(my_scorer, lambda, num_actions);
+
+ // Scale C up since we have fewer interactions
+ u32 num_decisions = num_actions * log(num_actions * 1.0) + log(NUM_ACTIONS_COVER * 1.0 / num_actions) * C * num_actions;
+ // The () following the array should ensure zero-initialization
+ u32* actions = new u32[num_actions]();
+ u32 i;
+ for (i = 0; i < num_decisions; i++)
+ {
+ u32 action = mwt.Choose_Action(explorer, this->Get_Unique_Key(i + 1), my_context);
+ // Action IDs are 1-based
+ actions[action - 1]++;
+ }
+ // Ensure all actions are covered
+ for (i = 0; i < num_actions; i++)
+ {
+ Assert::IsTrue(actions[i] > 0);
+ }
+ float* expected_probs = new float[num_decisions];
+ for (i = 0; i < num_decisions; i++)
+ {
+ // Our default scorer currently assigns equal weight to each action
+ expected_probs[i] = 1.0 / num_actions;
+ }
+ vector<TestInteraction<TestContext>> interactions = my_recorder.Get_All_Interactions();
+ this->Test_Interactions(interactions, num_decisions, expected_probs);
+
+ delete actions;
+ delete expected_probs;
+ }
+
+ TEST_METHOD(Softmax_Scores)
+ {
+ int num_actions = 10;
+ float lambda = 0.5f;
+ int scorer_arg = 7;
+ TestScorer my_scorer(scorer_arg, num_actions, /* uniform = */ false);
+ TestRecorder my_recorder;
+ TestContext my_context;
+
+ MwtExplorer<TestContext> mwt("salt", my_recorder);
+ SoftmaxExplorer<TestContext> explorer(my_scorer, lambda, num_actions);
+
+ u32 action = mwt.Choose_Action(explorer, this->Get_Unique_Key(1), my_context);
+ action = mwt.Choose_Action(explorer, this->Get_Unique_Key(2), my_context);
+ action = mwt.Choose_Action(explorer, this->Get_Unique_Key(3), my_context);
+
+ vector<TestInteraction<TestContext>> interactions = my_recorder.Get_All_Interactions();
+ size_t num_interactions = interactions.size();
+
+ Assert::AreEqual(3, (int)num_interactions);
+ for (int i = 0; i < num_interactions; i++)
+ {
+ Assert::AreNotEqual(1.f / num_actions, interactions[i].Probability);
+ }
+ }
+
+ TEST_METHOD(Generic)
+ {
+ int num_actions = 10;
+ int scorer_arg = 7;
+ TestScorer my_scorer(scorer_arg, num_actions);
+ TestRecorder my_recorder;
+ TestContext my_context;
+
+ MwtExplorer<TestContext> mwt("salt", my_recorder);
+ GenericExplorer<TestContext> explorer(my_scorer, num_actions);
+
+ u32 chosen_action = mwt.Choose_Action(explorer, this->Get_Unique_Key(1), my_context);
+ chosen_action = mwt.Choose_Action(explorer, this->Get_Unique_Key(2), my_context);
+ chosen_action = mwt.Choose_Action(explorer, this->Get_Unique_Key(3), my_context);
+
+ vector<TestInteraction<TestContext>> interactions = my_recorder.Get_All_Interactions();
+ float expected_probs[3] = { .1f, .1f, .1f };
+ this->Test_Interactions(interactions, 3, expected_probs);
+ }
+
+ TEST_METHOD(End_To_End_Epsilon_Greedy)
+ {
+ int num_actions = 10;
+ float epsilon = 0.5f;
+ int params = 101;
+
+ TestSimplePolicy my_policy(params, num_actions);
+ StringRecorder<SimpleContext> my_recorder;
+
+ MwtExplorer<SimpleContext> mwt("salt", my_recorder);
+ EpsilonGreedyExplorer<SimpleContext> explorer(my_policy, epsilon, num_actions);
+
+ this->End_To_End(mwt, explorer, my_recorder);
+ }
+
+ TEST_METHOD(End_To_End_Tau_First)
+ {
+ int num_actions = 10;
+ u32 tau = 5;
+ int params = 101;
+
+ TestSimplePolicy my_policy(params, num_actions);
+ StringRecorder<SimpleContext> my_recorder;
+
+ MwtExplorer<SimpleContext> mwt("salt", my_recorder);
+ TauFirstExplorer<SimpleContext> explorer(my_policy, tau, num_actions);
+
+ this->End_To_End(mwt, explorer, my_recorder);
+ }
+
+ TEST_METHOD(End_To_End_Bagging)
+ {
+ int num_actions = 10;
+ u32 bags = 2;
+ int params = 101;
+ StringRecorder<SimpleContext> my_recorder;
+
+ vector<unique_ptr<IPolicy<SimpleContext>>> policies;
+ policies.push_back(unique_ptr<IPolicy<SimpleContext>>(new TestSimplePolicy(params, num_actions)));
+ policies.push_back(unique_ptr<IPolicy<SimpleContext>>(new TestSimplePolicy(params, num_actions)));
+
+ MwtExplorer<SimpleContext> mwt("salt", my_recorder);
+ BaggingExplorer<SimpleContext> explorer(policies, num_actions);
+
+ this->End_To_End(mwt, explorer, my_recorder);
+ }
+
+ TEST_METHOD(End_To_End_Softmax)
+ {
+ int num_actions = 10;
+ float lambda = 0.5f;
+ int scorer_arg = 7;
+ TestSimpleScorer my_scorer(scorer_arg, num_actions);
+ StringRecorder<SimpleContext> my_recorder;
+
+ MwtExplorer<SimpleContext> mwt("salt", my_recorder);
+ SoftmaxExplorer<SimpleContext> explorer(my_scorer, lambda, num_actions);
+
+ this->End_To_End(mwt, explorer, my_recorder);
+ }
+
+ TEST_METHOD(End_To_End_Generic)
+ {
+ int num_actions = 10;
+ int scorer_arg = 7;
+ TestSimpleScorer my_scorer(scorer_arg, num_actions);
+ StringRecorder<SimpleContext> my_recorder;
+
+ MwtExplorer<SimpleContext> mwt("salt", my_recorder);
+ GenericExplorer<SimpleContext> explorer(my_scorer, num_actions);
+
+ this->End_To_End(mwt, explorer, my_recorder);
+ }
+
+ TEST_METHOD(PRG_Coverage)
+ {
+ const u32 NUM_ACTIONS_COVER = 100;
+ float C = 5.0f;
+
+ // We could use many fewer bits (e.g. u8) per bin since we're throwing uniformly at
+ // random, but this is safer in case we change things
+ u32 bins[NUM_ACTIONS_COVER] = { 0 };
+ u32 num_balls = NUM_ACTIONS_COVER * log(NUM_ACTIONS_COVER) + C * NUM_ACTIONS_COVER;
+ PRG::prg rg;
+ u32 i;
+ for (i = 0; i < num_balls; i++)
+ {
+ bins[rg.Uniform_Int(0, NUM_ACTIONS_COVER - 1)]++;
+ }
+ // Ensure all actions are covered
+ for (i = 0; i < NUM_ACTIONS_COVER; i++)
+ {
+ Assert::IsTrue(bins[i] > 0);
+ }
+ }
+
+ TEST_METHOD(Float_To_String)
+ {
+ PRG::prg rand;
+
+ for (int i = 0; i < 10000; i++)
+ {
+ float f = (rand.Uniform_Unit_Interval() - 0.5f) * rand.Uniform_Int(0, 100000);
+
+ ostringstream expected_stream;
+ expected_stream << std::fixed << std::setprecision(10) << f;
+ string expected_str = expected_stream.str();
+
+ char actual_chars[15] = { 0 };
+ NumberUtils::Float_To_String(f, actual_chars);
+ string actual_str(actual_chars);
+
+ size_t length = actual_str.length() - 1;
+ int compare_result = expected_str.compare(0, length, actual_str, 0, length);
+ Assert::AreEqual(0, compare_result);
+ }
+ }
+
+ TEST_METHOD(Serialized_String)
+ {
+ int num_actions = 10;
+ float epsilon = 0.5f;
+ int params = 101;
+
+ TestSimplePolicy my_policy(params, num_actions);
+
+ StringRecorder<SimpleContext> my_recorder;
+ MwtExplorer<SimpleContext> mwt("c++-test", my_recorder);
+ EpsilonGreedyExplorer<SimpleContext> explorer(my_policy, epsilon, num_actions);
+
+ vector<Feature> features1;
+ features1.push_back({ 0.5f, 1 });
+ SimpleContext context1(features1);
+
+ u32 expected_action = my_policy.Choose_Action(context1);
+
+ string unique_key1 = "key1";
+ u32 chosen_action1 = mwt.Choose_Action(explorer, unique_key1, context1);
+
+ vector<Feature> features2;
+ features2.push_back({ -99999.5f, 123456789 });
+ features2.push_back({ 1.5f, 39 });
+
+ SimpleContext context2(features2);
+
+ string unique_key2 = "key2";
+ u32 chosen_action2 = mwt.Choose_Action(explorer, unique_key2, context2);
+
+ string actual_log = my_recorder.Get_Recording();
+
+ // Use hard-coded string to be independent of sprintf
+ char* expected_log = "2 key1 0.55000 | 1:.5\n2 key2 0.55000 | 123456789:-99999.5 39:1.5\n";
+
+ Assert::AreEqual(expected_log, actual_log.c_str());
+ }
+
+ TEST_METHOD(Serialized_String_Random)
+ {
+ PRG::prg rand;
+
+ int num_actions = 10;
+ int params = 101;
+
+ TestSimplePolicy my_policy(params, num_actions);
+
+ char expected_log[100] = { 0 };
+
+ for (int i = 0; i < 10000; i++)
+ {
+ StringRecorder<SimpleContext> my_recorder;
+ MwtExplorer<SimpleContext> mwt("c++-test", my_recorder);
+ EpsilonGreedyExplorer<SimpleContext> explorer(my_policy, 0.f, num_actions);
+
+ Feature feature;
+ feature.Value = (rand.Uniform_Unit_Interval() - 0.5f) * rand.Uniform_Int(0, 100000);
+ feature.Id = i;
+ vector<Feature> features;
+ features.push_back(feature);
+ SimpleContext my_context(features);
+
+ u32 action = mwt.Choose_Action(explorer, "", my_context);
+ string actual_log = my_recorder.Get_Recording();
+
+ ostringstream expected_stream;
+ expected_stream << std::fixed << std::setprecision(10) << feature.Value;
+
+ string expected_str = expected_stream.str();
+ if (expected_str[0] == '0')
+ {
+ expected_str = expected_str.erase(0, 1);
+ }
+
+ sprintf(expected_log, "%d %s %.5f | %d:%s",
+ action, "", 1.f, i, expected_str.c_str());
+
+ size_t length = actual_log.length() - 1;
+ int compare_result = string(expected_log).compare(0, length, actual_log, 0, length);
+ Assert::AreEqual(0, compare_result);
+ }
+ }
+
+ TEST_METHOD(Usage_Bad_Arguments)
+ {
+ int num_ex = 0;
+ int params = 101;
+ TestPolicy my_policy(params, 0);
+ TestScorer my_scorer(params, 0);
+ vector<unique_ptr<IPolicy<TestContext>>> policies;
+
+ COUNT_INVALID(EpsilonGreedyExplorer<TestContext> explorer(my_policy, .5f, 0);) // Invalid # actions, must be > 0
+ COUNT_INVALID(EpsilonGreedyExplorer<TestContext> explorer(my_policy, 1.5f, 10);) // Invalid epsilon, must be in [0,1]
+ COUNT_INVALID(EpsilonGreedyExplorer<TestContext> explorer(my_policy, -.5f, 10);) // Invalid epsilon, must be in [0,1]
+
+ COUNT_INVALID(BaggingExplorer<TestContext> explorer(policies, 0);) // Invalid # actions, must be > 0
+ COUNT_INVALID(BaggingExplorer<TestContext> explorer(policies, 1);) // Invalid # bags, must be > 0
+
+ COUNT_INVALID(TauFirstExplorer<TestContext> explorer(my_policy, 1, 0);) // Invalid # actions, must be > 0
+ COUNT_INVALID(SoftmaxExplorer<TestContext> explorer(my_scorer, .5f, 0);) // Invalid # actions, must be > 0
+ COUNT_INVALID(GenericExplorer<TestContext> explorer(my_scorer, 0);) // Invalid # actions, must be > 0
+
+
+ Assert::AreEqual(8, num_ex);
+ }
+
+ TEST_METHOD(Usage_Bad_Policy)
+ {
+ int num_ex = 0;
+
+ // Default policy returns action outside valid range
+ COUNT_BAD_CALL
+ (
+ MwtExplorer<TestContext> mwt("salt", TestRecorder());
+ EpsilonGreedyExplorer<TestContext> explorer(TestBadPolicy(), 0.f, (u32)1);
+
+ u32 expected_action = mwt.Choose_Action(explorer, "1001", TestContext());
+ )
+ COUNT_BAD_CALL
+ (
+ MwtExplorer<TestContext> mwt("salt", TestRecorder());
+ TauFirstExplorer<TestContext> explorer(TestBadPolicy(), (u32)0, (u32)1);
+ mwt.Choose_Action(explorer, "test", TestContext());
+ )
+ COUNT_BAD_CALL
+ (
+ vector<unique_ptr<IPolicy<TestContext>>> policies;
+ policies.push_back(unique_ptr<IPolicy<TestContext>>(new TestBadPolicy()));
+ MwtExplorer<TestContext> mwt("salt", TestRecorder());
+ BaggingExplorer<TestContext> explorer(policies, (u32)1);
+ mwt.Choose_Action(explorer, "test", TestContext());
+ )
+ Assert::AreEqual(3, num_ex);
+ }
+
+ TEST_METHOD(Usage_Bad_Scorer)
+ {
+ int num_ex = 0;
+
+ // Default policy returns action outside valid range
+ COUNT_BAD_CALL
+ (
+ u32 num_actions = 1;
+ FixedScorer scorer(num_actions, -1);
+ MwtExplorer<TestContext> mwt("salt", TestRecorder());
+ GenericExplorer<TestContext> explorer(scorer, num_actions);
+ mwt.Choose_Action(explorer, "test", TestContext());
+ )
+ COUNT_BAD_CALL
+ (
+ u32 num_actions = 1;
+ FixedScorer scorer(num_actions, 0);
+ MwtExplorer<TestContext> mwt("salt", TestRecorder());
+ GenericExplorer<TestContext> explorer(scorer, num_actions);
+ mwt.Choose_Action(explorer, "test", TestContext());
+ )
+ Assert::AreEqual(2, num_ex);
+ }
+
+ TEST_METHOD(Custom_Context)
+ {
+ int num_actions = 10;
+ float epsilon = 0.f; // No randomization
+ string unique_key = "1001";
+
+ TestSimplePolicy my_policy(0, num_actions);
+
+ TestSimpleRecorder my_recorder;
+ MwtExplorer<SimpleContext> mwt("salt", my_recorder);
+
+ vector<Feature> features;
+ features.push_back({ 0.5f, 1 });
+ features.push_back({ 1.5f, 6 });
+ features.push_back({ -5.3f, 13 });
+ SimpleContext custom_context(features);
+
+ EpsilonGreedyExplorer<SimpleContext> explorer(my_policy, epsilon, num_actions);
+
+ u32 chosen_action = mwt.Choose_Action(explorer, unique_key, custom_context);
+ Assert::AreEqual((u32)1, chosen_action);
+
+ float expected_probs[1] = { 1.f };
+
+ vector<TestInteraction<SimpleContext>> interactions = my_recorder.Get_All_Interactions();
+ Assert::AreEqual(1, (int)interactions.size());
+
+ SimpleContext* returned_context = &interactions[0].Context;
+
+ size_t onf = features.size();
+ Feature* of = &features[0];
+
+ vector<Feature>& returned_features = returned_context->Get_Features();
+ size_t rnf = returned_features.size();
+ Feature* rf = &returned_features[0];
+
+ Assert::AreEqual(rnf, onf);
+ for (size_t i = 0; i < rnf; i++)
+ {
+ Assert::AreEqual(of[i].Id, rf[i].Id);
+ Assert::AreEqual(of[i].Value, rf[i].Value);
+ }
+ }
+
+ TEST_METHOD_INITIALIZE(Test_Initialize)
+ {
+ }
+
+ TEST_METHOD_CLEANUP(Test_Cleanup)
+ {
+ }
+
+ private:
+ // Test end-to-end using StringRecorder with no crash
+ template <class Exp>
+ void End_To_End(MwtExplorer<SimpleContext>& mwt, Exp& explorer, StringRecorder<SimpleContext>& recorder)
+ {
+ PRG::prg rand;
+
+ float rewards[10];
+ for (int i = 0; i < 10; i++)
+ {
+ vector<Feature> features;
+ for (int j = 0; j < 1000; j++)
+ {
+ features.push_back({ rand.Uniform_Unit_Interval(), j + 1 });
+ }
+ SimpleContext c(features);
+
+ mwt.Choose_Action(explorer, to_string(i), c);
+
+ rewards[i] = rand.Uniform_Unit_Interval();
+ }
+
+ recorder.Get_Recording();
+ }
+
+ template <class Ctx>
+ inline void Test_Interactions(vector<TestInteraction<Ctx>> interactions, int num_interactions_expected, float* probs_expected)
+ {
+ size_t num_interactions = interactions.size();
+
+ Assert::AreEqual(num_interactions_expected, (int)num_interactions);
+ for (int i = 0; i < num_interactions; i++)
+ {
+ Assert::AreEqual(probs_expected[i], interactions[i].Probability);
+ }
+ }
+
+ string Get_Unique_Key(u32 seed)
+ {
+ PRG::prg rg(seed);
+
+ std::ostringstream unique_key_container;
+ unique_key_container << rg.Uniform_Unit_Interval();
+
+ return unique_key_container.str();
+ }
+ };
+}