diff options
author | Luong Hoang <lhoang@live.com> | 2014-11-13 20:01:48 +0300 |
---|---|---|
committer | Luong Hoang <lhoang@live.com> | 2014-11-13 20:01:48 +0300 |
commit | 0dc62a0e6a610077b15552057986aeba867fabf0 (patch) | |
tree | 5566bdf706bc1bea625eadfd74b43dcd9ae461d9 /explore | |
parent | edbbbcf1e1be35a3364418f53b4cbc66965605d8 (diff) |
fixes #30
Diffstat (limited to 'explore')
-rw-r--r-- | explore/clr/explore_clr_wrapper.h | 750 | ||||
-rw-r--r-- | explore/clr/explore_interface.h | 42 | ||||
-rw-r--r-- | explore/clr/explore_interop.h | 708 | ||||
-rw-r--r-- | explore/explore.cpp | 128 | ||||
-rw-r--r-- | explore/static/MWTExplorer.h | 1332 | ||||
-rw-r--r-- | explore/tests/MWTExploreTests.h | 6 |
6 files changed, 1483 insertions, 1483 deletions
diff --git a/explore/clr/explore_clr_wrapper.h b/explore/clr/explore_clr_wrapper.h index f663eb9e..a6cede4b 100644 --- a/explore/clr/explore_clr_wrapper.h +++ b/explore/clr/explore_clr_wrapper.h @@ -1,12 +1,12 @@ -#pragma once -#include "explore_interop.h" - +#pragma once
+#include "explore_interop.h"
+
/*!
* \addtogroup MultiWorldTestingCsharp
* @{
-*/ -namespace MultiWorldTesting { - +*/
+namespace MultiWorldTesting {
+
/// <summary>
/// The epsilon greedy exploration class.
/// </summary>
@@ -14,425 +14,425 @@ namespace MultiWorldTesting { /// This is a good choice if you have no idea which actions should be preferred.
/// Epsilon greedy is also computationally cheap.
/// </remarks>
- /// <typeparam name="Ctx">The Context type.</typeparam> - generic <class Ctx> - public ref class EpsilonGreedyExplorer : public IExplorer<Ctx>, public PolicyCallback<Ctx> - { - public: + /// <typeparam name="Ctx">The Context type.</typeparam>
+ generic <class Ctx>
+ public ref class EpsilonGreedyExplorer : public IExplorer<Ctx>, public PolicyCallback<Ctx>
+ {
+ public:
/// <summary>
/// The constructor is the only public member, because this should be used with the MwtExplorer.
/// </summary>
- /// <param name="defaultPolicy">A default function which outputs an action given a context.</param> - /// <param name="epsilon">The probability of a random exploration.</param> - /// <param name="numActions">The number of actions to randomize over.</param> - EpsilonGreedyExplorer(IPolicy<Ctx>^ defaultPolicy, float epsilon, UInt32 numActions) - { - this->defaultPolicy = defaultPolicy; - m_explorer = new NativeMultiWorldTesting::EpsilonGreedyExplorer<NativeContext>(*GetNativePolicy(), epsilon, (u32)numActions); - } - - ~EpsilonGreedyExplorer() - { - delete m_explorer; - } - - internal: - virtual UInt32 InvokePolicyCallback(Ctx context, int index) override - { - return defaultPolicy->ChooseAction(context); - } - - NativeMultiWorldTesting::EpsilonGreedyExplorer<NativeContext>* Get() - { - return m_explorer; - } - - private: - IPolicy<Ctx>^ defaultPolicy; - NativeMultiWorldTesting::EpsilonGreedyExplorer<NativeContext>* m_explorer; - }; - + /// <param name="defaultPolicy">A default function which outputs an action given a context.</param>
+ /// <param name="epsilon">The probability of a random exploration.</param>
+ /// <param name="numActions">The number of actions to randomize over.</param>
+ EpsilonGreedyExplorer(IPolicy<Ctx>^ defaultPolicy, float epsilon, UInt32 numActions)
+ {
+ this->defaultPolicy = defaultPolicy;
+ m_explorer = new NativeMultiWorldTesting::EpsilonGreedyExplorer<NativeContext>(*GetNativePolicy(), epsilon, (u32)numActions);
+ }
+
+ ~EpsilonGreedyExplorer()
+ {
+ delete m_explorer;
+ }
+
+ internal:
+ virtual UInt32 InvokePolicyCallback(Ctx context, int index) override
+ {
+ return defaultPolicy->ChooseAction(context);
+ }
+
+ NativeMultiWorldTesting::EpsilonGreedyExplorer<NativeContext>* Get()
+ {
+ return m_explorer;
+ }
+
+ private:
+ IPolicy<Ctx>^ defaultPolicy;
+ NativeMultiWorldTesting::EpsilonGreedyExplorer<NativeContext>* m_explorer;
+ };
+
/// <summary>
/// The tau-first exploration class.
/// </summary>
/// <remarks>
- /// The tau-first explorer collects precisely tau uniform random + /// The tau-first explorer collects precisely tau uniform random
/// exploration events, and then uses the default policy.
/// </remarks>
- /// <typeparam name="Ctx">The Context type.</typeparam> - generic <class Ctx> - public ref class TauFirstExplorer : public IExplorer<Ctx>, public PolicyCallback<Ctx> - { - public: + /// <typeparam name="Ctx">The Context type.</typeparam>
+ generic <class Ctx>
+ public ref class TauFirstExplorer : public IExplorer<Ctx>, public PolicyCallback<Ctx>
+ {
+ public:
/// <summary>
/// The constructor is the only public member, because this should be used with the MwtExplorer.
/// </summary>
- /// <param name="defaultPolicy">A default policy after randomization finishes.</param> - /// <param name="tau">The number of events to be uniform over.</param> - /// <param name="numActions">The number of actions to randomize over.</param> - TauFirstExplorer(IPolicy<Ctx>^ defaultPolicy, UInt32 tau, UInt32 numActions) - { - this->defaultPolicy = defaultPolicy; - m_explorer = new NativeMultiWorldTesting::TauFirstExplorer<NativeContext>(*GetNativePolicy(), tau, (u32)numActions); - } - - ~TauFirstExplorer() - { - delete m_explorer; - } - - internal: - virtual UInt32 InvokePolicyCallback(Ctx context, int index) override - { - return defaultPolicy->ChooseAction(context); - } - - NativeMultiWorldTesting::TauFirstExplorer<NativeContext>* Get() - { - return m_explorer; - } - - private: - IPolicy<Ctx>^ defaultPolicy; - NativeMultiWorldTesting::TauFirstExplorer<NativeContext>* m_explorer; - }; - + /// <param name="defaultPolicy">A default policy after randomization finishes.</param>
+ /// <param name="tau">The number of events to be uniform over.</param>
+ /// <param name="numActions">The number of actions to randomize over.</param>
+ TauFirstExplorer(IPolicy<Ctx>^ defaultPolicy, UInt32 tau, UInt32 numActions)
+ {
+ this->defaultPolicy = defaultPolicy;
+ m_explorer = new NativeMultiWorldTesting::TauFirstExplorer<NativeContext>(*GetNativePolicy(), tau, (u32)numActions);
+ }
+
+ ~TauFirstExplorer()
+ {
+ delete m_explorer;
+ }
+
+ internal:
+ virtual UInt32 InvokePolicyCallback(Ctx context, int index) override
+ {
+ return defaultPolicy->ChooseAction(context);
+ }
+
+ NativeMultiWorldTesting::TauFirstExplorer<NativeContext>* Get()
+ {
+ return m_explorer;
+ }
+
+ private:
+ IPolicy<Ctx>^ defaultPolicy;
+ NativeMultiWorldTesting::TauFirstExplorer<NativeContext>* m_explorer;
+ };
+
/// <summary>
/// The epsilon greedy exploration class.
/// </summary>
/// <remarks>
/// In some cases, different actions have a different scores, and you
- /// would prefer to choose actions with large scores. Softmax allows + /// would prefer to choose actions with large scores. Softmax allows
/// you to do that.
/// </remarks>
- /// <typeparam name="Ctx">The Context type.</typeparam> - generic <class Ctx> - public ref class SoftmaxExplorer : public IExplorer<Ctx>, public ScorerCallback<Ctx> - { - public: + /// <typeparam name="Ctx">The Context type.</typeparam>
+ generic <class Ctx>
+ public ref class SoftmaxExplorer : public IExplorer<Ctx>, public ScorerCallback<Ctx>
+ {
+ public:
/// <summary>
/// The constructor is the only public member, because this should be used with the MwtExplorer.
/// </summary>
- /// <param name="defaultScorer">A function which outputs a score for each action.</param> - /// <param name="lambda">lambda = 0 implies uniform distribution. Large lambda is equivalent to a max.</param> - /// <param name="numActions">The number of actions to randomize over.</param> - SoftmaxExplorer(IScorer<Ctx>^ defaultScorer, float lambda, UInt32 numActions) - { - this->defaultScorer = defaultScorer; - m_explorer = new NativeMultiWorldTesting::SoftmaxExplorer<NativeContext>(*GetNativeScorer(), lambda, (u32)numActions); - } - - ~SoftmaxExplorer() - { - delete m_explorer; - } - - internal: - virtual List<float>^ InvokeScorerCallback(Ctx context) override - { - return defaultScorer->ScoreActions(context); - } - - NativeMultiWorldTesting::SoftmaxExplorer<NativeContext>* Get() - { - return m_explorer; - } - - private: - IScorer<Ctx>^ defaultScorer; - NativeMultiWorldTesting::SoftmaxExplorer<NativeContext>* m_explorer; - }; - + /// <param name="defaultScorer">A function which outputs a score for each action.</param>
+ /// <param name="lambda">lambda = 0 implies uniform distribution. Large lambda is equivalent to a max.</param>
+ /// <param name="numActions">The number of actions to randomize over.</param>
+ SoftmaxExplorer(IScorer<Ctx>^ defaultScorer, float lambda, UInt32 numActions)
+ {
+ this->defaultScorer = defaultScorer;
+ m_explorer = new NativeMultiWorldTesting::SoftmaxExplorer<NativeContext>(*GetNativeScorer(), lambda, (u32)numActions);
+ }
+
+ ~SoftmaxExplorer()
+ {
+ delete m_explorer;
+ }
+
+ internal:
+ virtual List<float>^ InvokeScorerCallback(Ctx context) override
+ {
+ return defaultScorer->ScoreActions(context);
+ }
+
+ NativeMultiWorldTesting::SoftmaxExplorer<NativeContext>* Get()
+ {
+ return m_explorer;
+ }
+
+ private:
+ IScorer<Ctx>^ defaultScorer;
+ NativeMultiWorldTesting::SoftmaxExplorer<NativeContext>* m_explorer;
+ };
+
/// <summary>
/// The generic exploration class.
/// </summary>
/// <remarks>
- /// GenericExplorer provides complete flexibility. You can create any + /// GenericExplorer provides complete flexibility. You can create any
/// distribution over actions desired, and it will draw from that.
/// </remarks>
- /// <typeparam name="Ctx">The Context type.</typeparam> - generic <class Ctx> - public ref class GenericExplorer : public IExplorer<Ctx>, public ScorerCallback<Ctx> - { - public: + /// <typeparam name="Ctx">The Context type.</typeparam>
+ generic <class Ctx>
+ public ref class GenericExplorer : public IExplorer<Ctx>, public ScorerCallback<Ctx>
+ {
+ public:
/// <summary>
/// The constructor is the only public member, because this should be used with the MwtExplorer.
/// </summary>
- /// <param name="defaultScorer">A function which outputs the probability of each action.</param> - /// <param name="numActions">The number of actions to randomize over.</param> - GenericExplorer(IScorer<Ctx>^ defaultScorer, UInt32 numActions) - { - this->defaultScorer = defaultScorer; - m_explorer = new NativeMultiWorldTesting::GenericExplorer<NativeContext>(*GetNativeScorer(), (u32)numActions); - } - - ~GenericExplorer() - { - delete m_explorer; - } - - internal: - virtual List<float>^ InvokeScorerCallback(Ctx context) override - { - return defaultScorer->ScoreActions(context); - } - - NativeMultiWorldTesting::GenericExplorer<NativeContext>* Get() - { - return m_explorer; - } - - private: - IScorer<Ctx>^ defaultScorer; - NativeMultiWorldTesting::GenericExplorer<NativeContext>* m_explorer; - }; - + /// <param name="defaultScorer">A function which outputs the probability of each action.</param>
+ /// <param name="numActions">The number of actions to randomize over.</param>
+ GenericExplorer(IScorer<Ctx>^ defaultScorer, UInt32 numActions)
+ {
+ this->defaultScorer = defaultScorer;
+ m_explorer = new NativeMultiWorldTesting::GenericExplorer<NativeContext>(*GetNativeScorer(), (u32)numActions);
+ }
+
+ ~GenericExplorer()
+ {
+ delete m_explorer;
+ }
+
+ internal:
+ virtual List<float>^ InvokeScorerCallback(Ctx context) override
+ {
+ return defaultScorer->ScoreActions(context);
+ }
+
+ NativeMultiWorldTesting::GenericExplorer<NativeContext>* Get()
+ {
+ return m_explorer;
+ }
+
+ private:
+ IScorer<Ctx>^ defaultScorer;
+ NativeMultiWorldTesting::GenericExplorer<NativeContext>* m_explorer;
+ };
+
/// <summary>
/// The bootstrap exploration class.
/// </summary>
/// <remarks>
- /// The Bootstrap explorer randomizes over the actions chosen by a set of - /// default policies. This performs well statistically but can be + /// The Bootstrap explorer randomizes over the actions chosen by a set of
+ /// default policies. This performs well statistically but can be
/// computationally expensive.
/// </remarks>
- /// <typeparam name="Ctx">The Context type.</typeparam> - generic <class Ctx> - public ref class BootstrapExplorer : public IExplorer<Ctx>, public PolicyCallback<Ctx> - { - public: + /// <typeparam name="Ctx">The Context type.</typeparam>
+ generic <class Ctx>
+ public ref class BootstrapExplorer : public IExplorer<Ctx>, public PolicyCallback<Ctx>
+ {
+ public:
/// <summary>
/// The constructor is the only public member, because this should be used with the MwtExplorer.
/// </summary>
- /// <param name="defaultPolicies">A set of default policies to be uniform random over.</param> - /// <param name="numActions">The number of actions to randomize over.</param> - BootstrapExplorer(cli::array<IPolicy<Ctx>^>^ defaultPolicies, UInt32 numActions) - { - this->defaultPolicies = defaultPolicies; - if (this->defaultPolicies == nullptr) - { - throw gcnew ArgumentNullException("The specified array of default policy functions cannot be null."); - } - - m_explorer = new NativeMultiWorldTesting::BootstrapExplorer<NativeContext>(*GetNativePolicies((u32)defaultPolicies->Length), (u32)numActions); - } - - ~BootstrapExplorer() - { - delete m_explorer; - } - - internal: - virtual UInt32 InvokePolicyCallback(Ctx context, int index) override - { - if (index < 0 || index >= defaultPolicies->Length) - { - throw gcnew InvalidDataException("Internal error: Index of interop bag is out of range."); - } - return defaultPolicies[index]->ChooseAction(context); - } - - NativeMultiWorldTesting::BootstrapExplorer<NativeContext>* Get() - { - return m_explorer; - } - - private: - cli::array<IPolicy<Ctx>^>^ defaultPolicies; - NativeMultiWorldTesting::BootstrapExplorer<NativeContext>* m_explorer; - }; - + /// <param name="defaultPolicies">A set of default policies to be uniform random over.</param>
+ /// <param name="numActions">The number of actions to randomize over.</param>
+ BootstrapExplorer(cli::array<IPolicy<Ctx>^>^ defaultPolicies, UInt32 numActions)
+ {
+ this->defaultPolicies = defaultPolicies;
+ if (this->defaultPolicies == nullptr)
+ {
+ throw gcnew ArgumentNullException("The specified array of default policy functions cannot be null.");
+ }
+
+ m_explorer = new NativeMultiWorldTesting::BootstrapExplorer<NativeContext>(*GetNativePolicies((u32)defaultPolicies->Length), (u32)numActions);
+ }
+
+ ~BootstrapExplorer()
+ {
+ delete m_explorer;
+ }
+
+ internal:
+ virtual UInt32 InvokePolicyCallback(Ctx context, int index) override
+ {
+ if (index < 0 || index >= defaultPolicies->Length)
+ {
+ throw gcnew InvalidDataException("Internal error: Index of interop bag is out of range.");
+ }
+ return defaultPolicies[index]->ChooseAction(context);
+ }
+
+ NativeMultiWorldTesting::BootstrapExplorer<NativeContext>* Get()
+ {
+ return m_explorer;
+ }
+
+ private:
+ cli::array<IPolicy<Ctx>^>^ defaultPolicies;
+ NativeMultiWorldTesting::BootstrapExplorer<NativeContext>* m_explorer;
+ };
+
/// <summary>
- /// The top level MwtExplorer class. Using this makes sure that the + /// The top level MwtExplorer class. Using this makes sure that the
/// right bits are recorded and good random actions are chosen.
/// </summary>
- /// <typeparam name="Ctx">The Context type.</typeparam> - generic <class Ctx> - public ref class MwtExplorer : public RecorderCallback<Ctx> - { - public: + /// <typeparam name="Ctx">The Context type.</typeparam>
+ generic <class Ctx>
+ public ref class MwtExplorer : public RecorderCallback<Ctx>
+ {
+ public:
/// <summary>
/// Constructor.
/// </summary>
- /// <param name="appId">This should be unique to each experiment to avoid correlation bugs.</param> - /// <param name="recorder">A user-specified class for recording the appropriate bits for use in evaluation and learning.</param> - MwtExplorer(String^ appId, IRecorder<Ctx>^ recorder) - { - this->appId = appId; - this->recorder = recorder; - } - + /// <param name="appId">This should be unique to each experiment to avoid correlation bugs.</param>
+ /// <param name="recorder">A user-specified class for recording the appropriate bits for use in evaluation and learning.</param>
+ MwtExplorer(String^ appId, IRecorder<Ctx>^ recorder)
+ {
+ this->appId = appId;
+ this->recorder = recorder;
+ }
+
/// <summary>
/// Choose_Action should be drop-in replacement for any existing policy function.
/// </summary>
- /// <param name="explorer">An existing exploration algorithm (one of the above) which uses the default policy as a callback.</param> - /// <param name="unique_key">A unique identifier for the experimental unit. This could be a user id, a session id, etc...</param> - /// <param name="context">The context upon which a decision is made. See SimpleContext above for an example.</param> - /// <returns>An unsigned 32-bit integer representing the 1-based chosen action.</returns> - UInt32 ChooseAction(IExplorer<Ctx>^ explorer, String^ unique_key, Ctx context) - { - String^ salt = this->appId; - NativeMultiWorldTesting::MwtExplorer<NativeContext> mwt(marshal_as<std::string>(salt), *GetNativeRecorder()); - - GCHandle selfHandle = GCHandle::Alloc(this); - IntPtr selfPtr = (IntPtr)selfHandle; - - GCHandle contextHandle = GCHandle::Alloc(context); - IntPtr contextPtr = (IntPtr)contextHandle; - - GCHandle explorerHandle = GCHandle::Alloc(explorer); - IntPtr explorerPtr = (IntPtr)explorerHandle; - - NativeContext native_context(selfPtr.ToPointer(), explorerPtr.ToPointer(), contextPtr.ToPointer()); - u32 action = 0; - if (explorer->GetType() == EpsilonGreedyExplorer<Ctx>::typeid) - { - EpsilonGreedyExplorer<Ctx>^ epsilonGreedyExplorer = (EpsilonGreedyExplorer<Ctx>^)explorer; - action = mwt.Choose_Action(*epsilonGreedyExplorer->Get(), marshal_as<std::string>(unique_key), native_context); - } - else if (explorer->GetType() == TauFirstExplorer<Ctx>::typeid) - { - TauFirstExplorer<Ctx>^ tauFirstExplorer = (TauFirstExplorer<Ctx>^)explorer; - action = mwt.Choose_Action(*tauFirstExplorer->Get(), marshal_as<std::string>(unique_key), native_context); - } - else if (explorer->GetType() == SoftmaxExplorer<Ctx>::typeid) - { - SoftmaxExplorer<Ctx>^ softmaxExplorer = (SoftmaxExplorer<Ctx>^)explorer; - action = mwt.Choose_Action(*softmaxExplorer->Get(), marshal_as<std::string>(unique_key), native_context); - } - else if (explorer->GetType() == GenericExplorer<Ctx>::typeid) - { - GenericExplorer<Ctx>^ genericExplorer = (GenericExplorer<Ctx>^)explorer; - action = mwt.Choose_Action(*genericExplorer->Get(), marshal_as<std::string>(unique_key), native_context); - } - else if (explorer->GetType() == BootstrapExplorer<Ctx>::typeid) - { - BootstrapExplorer<Ctx>^ bootstrapExplorer = (BootstrapExplorer<Ctx>^)explorer; - action = mwt.Choose_Action(*bootstrapExplorer->Get(), marshal_as<std::string>(unique_key), native_context); - } - - explorerHandle.Free(); - contextHandle.Free(); - selfHandle.Free(); - - return action; - } - - internal: - virtual void InvokeRecorderCallback(Ctx context, UInt32 action, float probability, String^ unique_key) override - { - recorder->Record(context, action, probability, unique_key); - } - - private: - IRecorder<Ctx>^ recorder; - String^ appId; - }; - - /// <summary> - /// Represents a feature in a sparse array. - /// </summary> - [StructLayout(LayoutKind::Sequential)] - public value struct Feature - { - float Value; - UInt32 Id; - }; - + /// <param name="explorer">An existing exploration algorithm (one of the above) which uses the default policy as a callback.</param>
+ /// <param name="unique_key">A unique identifier for the experimental unit. This could be a user id, a session id, etc...</param>
+ /// <param name="context">The context upon which a decision is made. See SimpleContext above for an example.</param>
+ /// <returns>An unsigned 32-bit integer representing the 1-based chosen action.</returns>
+ UInt32 ChooseAction(IExplorer<Ctx>^ explorer, String^ unique_key, Ctx context)
+ {
+ String^ salt = this->appId;
+ NativeMultiWorldTesting::MwtExplorer<NativeContext> mwt(marshal_as<std::string>(salt), *GetNativeRecorder());
+
+ GCHandle selfHandle = GCHandle::Alloc(this);
+ IntPtr selfPtr = (IntPtr)selfHandle;
+
+ GCHandle contextHandle = GCHandle::Alloc(context);
+ IntPtr contextPtr = (IntPtr)contextHandle;
+
+ GCHandle explorerHandle = GCHandle::Alloc(explorer);
+ IntPtr explorerPtr = (IntPtr)explorerHandle;
+
+ NativeContext native_context(selfPtr.ToPointer(), explorerPtr.ToPointer(), contextPtr.ToPointer());
+ u32 action = 0;
+ if (explorer->GetType() == EpsilonGreedyExplorer<Ctx>::typeid)
+ {
+ EpsilonGreedyExplorer<Ctx>^ epsilonGreedyExplorer = (EpsilonGreedyExplorer<Ctx>^)explorer;
+ action = mwt.Choose_Action(*epsilonGreedyExplorer->Get(), marshal_as<std::string>(unique_key), native_context);
+ }
+ else if (explorer->GetType() == TauFirstExplorer<Ctx>::typeid)
+ {
+ TauFirstExplorer<Ctx>^ tauFirstExplorer = (TauFirstExplorer<Ctx>^)explorer;
+ action = mwt.Choose_Action(*tauFirstExplorer->Get(), marshal_as<std::string>(unique_key), native_context);
+ }
+ else if (explorer->GetType() == SoftmaxExplorer<Ctx>::typeid)
+ {
+ SoftmaxExplorer<Ctx>^ softmaxExplorer = (SoftmaxExplorer<Ctx>^)explorer;
+ action = mwt.Choose_Action(*softmaxExplorer->Get(), marshal_as<std::string>(unique_key), native_context);
+ }
+ else if (explorer->GetType() == GenericExplorer<Ctx>::typeid)
+ {
+ GenericExplorer<Ctx>^ genericExplorer = (GenericExplorer<Ctx>^)explorer;
+ action = mwt.Choose_Action(*genericExplorer->Get(), marshal_as<std::string>(unique_key), native_context);
+ }
+ else if (explorer->GetType() == BootstrapExplorer<Ctx>::typeid)
+ {
+ BootstrapExplorer<Ctx>^ bootstrapExplorer = (BootstrapExplorer<Ctx>^)explorer;
+ action = mwt.Choose_Action(*bootstrapExplorer->Get(), marshal_as<std::string>(unique_key), native_context);
+ }
+
+ explorerHandle.Free();
+ contextHandle.Free();
+ selfHandle.Free();
+
+ return action;
+ }
+
+ internal:
+ virtual void InvokeRecorderCallback(Ctx context, UInt32 action, float probability, String^ unique_key) override
+ {
+ recorder->Record(context, action, probability, unique_key);
+ }
+
+ private:
+ IRecorder<Ctx>^ recorder;
+ String^ appId;
+ };
+
+ /// <summary>
+ /// Represents a feature in a sparse array.
+ /// </summary>
+ [StructLayout(LayoutKind::Sequential)]
+ public value struct Feature
+ {
+ float Value;
+ UInt32 Id;
+ };
+
/// <summary>
/// A sample recorder class that converts the exploration tuple into string format.
/// </summary>
- /// <typeparam name="Ctx">The Context type.</typeparam> - generic <class Ctx> where Ctx : IStringContext - public ref class StringRecorder : public IRecorder<Ctx>, public ToStringCallback<Ctx> - { - public: - StringRecorder() - { - m_string_recorder = new NativeMultiWorldTesting::StringRecorder<NativeStringContext>(); - } - - ~StringRecorder() - { - delete m_string_recorder; - } - - virtual void Record(Ctx context, UInt32 action, float probability, String^ uniqueKey) - { - GCHandle contextHandle = GCHandle::Alloc(context); - IntPtr contextPtr = (IntPtr)contextHandle; - - NativeStringContext native_context(contextPtr.ToPointer(), GetCallback()); - m_string_recorder->Record(native_context, (u32)action, probability, marshal_as<string>(uniqueKey)); - } - - /// <summary> - /// Gets the content of the recording so far as a string and clears internal content. - /// </summary> - /// <returns> - /// A string with recording content. - /// </returns> - String^ GetRecording() - { - // Workaround for C++-CLI bug which does not allow default value for parameter - return GetRecording(true); - } - - /// <summary> - /// Gets the content of the recording so far as a string and optionally clears internal content. - /// </summary> - /// <param name="flush">A boolean value indicating whether to clear the internal content.</param> - /// <returns> - /// A string with recording content. - /// </returns> - String^ GetRecording(bool flush) - { - return gcnew String(m_string_recorder->Get_Recording(flush).c_str()); - } - - private: - NativeMultiWorldTesting::StringRecorder<NativeStringContext>* m_string_recorder; - }; - - /// <summary> - /// A sample context class that stores a vector of Features. - /// </summary> - public ref class SimpleContext : public IStringContext - { - public: - SimpleContext(cli::array<Feature>^ features) - { - Features = features; - - // TODO: add another constructor overload for native SimpleContext to avoid copying feature values - m_features = new vector<NativeMultiWorldTesting::Feature>(); - for (int i = 0; i < features->Length; i++) - { - m_features->push_back({ features[i].Value, features[i].Id }); - } - - m_native_context = new NativeMultiWorldTesting::SimpleContext(*m_features); - } - - String^ ToString() override - { - return gcnew String(m_native_context->To_String().c_str()); - } - - ~SimpleContext() - { - delete m_native_context; - } - - public: - cli::array<Feature>^ GetFeatures() { return Features; } - - internal: - cli::array<Feature>^ Features; - - private: - vector<NativeMultiWorldTesting::Feature>* m_features; - NativeMultiWorldTesting::SimpleContext* m_native_context; - }; -} - + /// <typeparam name="Ctx">The Context type.</typeparam>
+ generic <class Ctx> where Ctx : IStringContext
+ public ref class StringRecorder : public IRecorder<Ctx>, public ToStringCallback<Ctx>
+ {
+ public:
+ StringRecorder()
+ {
+ m_string_recorder = new NativeMultiWorldTesting::StringRecorder<NativeStringContext>();
+ }
+
+ ~StringRecorder()
+ {
+ delete m_string_recorder;
+ }
+
+ virtual void Record(Ctx context, UInt32 action, float probability, String^ uniqueKey)
+ {
+ GCHandle contextHandle = GCHandle::Alloc(context);
+ IntPtr contextPtr = (IntPtr)contextHandle;
+
+ NativeStringContext native_context(contextPtr.ToPointer(), GetCallback());
+ m_string_recorder->Record(native_context, (u32)action, probability, marshal_as<string>(uniqueKey));
+ }
+
+ /// <summary>
+ /// Gets the content of the recording so far as a string and clears internal content.
+ /// </summary>
+ /// <returns>
+ /// A string with recording content.
+ /// </returns>
+ String^ GetRecording()
+ {
+ // Workaround for C++-CLI bug which does not allow default value for parameter
+ return GetRecording(true);
+ }
+
+ /// <summary>
+ /// Gets the content of the recording so far as a string and optionally clears internal content.
+ /// </summary>
+ /// <param name="flush">A boolean value indicating whether to clear the internal content.</param>
+ /// <returns>
+ /// A string with recording content.
+ /// </returns>
+ String^ GetRecording(bool flush)
+ {
+ return gcnew String(m_string_recorder->Get_Recording(flush).c_str());
+ }
+
+ private:
+ NativeMultiWorldTesting::StringRecorder<NativeStringContext>* m_string_recorder;
+ };
+
+ /// <summary>
+ /// A sample context class that stores a vector of Features.
+ /// </summary>
+ public ref class SimpleContext : public IStringContext
+ {
+ public:
+ SimpleContext(cli::array<Feature>^ features)
+ {
+ Features = features;
+
+ // TODO: add another constructor overload for native SimpleContext to avoid copying feature values
+ m_features = new vector<NativeMultiWorldTesting::Feature>();
+ for (int i = 0; i < features->Length; i++)
+ {
+ m_features->push_back({ features[i].Value, features[i].Id });
+ }
+
+ m_native_context = new NativeMultiWorldTesting::SimpleContext(*m_features);
+ }
+
+ String^ ToString() override
+ {
+ return gcnew String(m_native_context->To_String().c_str());
+ }
+
+ ~SimpleContext()
+ {
+ delete m_native_context;
+ }
+
+ public:
+ cli::array<Feature>^ GetFeatures() { return Features; }
+
+ internal:
+ cli::array<Feature>^ Features;
+
+ private:
+ vector<NativeMultiWorldTesting::Feature>* m_features;
+ NativeMultiWorldTesting::SimpleContext* m_native_context;
+ };
+}
+
/*! @} End of Doxygen Groups*/
\ No newline at end of file diff --git a/explore/clr/explore_interface.h b/explore/clr/explore_interface.h index 62b606dc..3212b4d8 100644 --- a/explore/clr/explore_interface.h +++ b/explore/clr/explore_interface.h @@ -1,30 +1,30 @@ -#pragma once - -using namespace System; -using namespace System::Collections::Generic; - +#pragma once
+
+using namespace System;
+using namespace System::Collections::Generic;
+
/** \defgroup MultiWorldTestingCsharp
\brief C# implementation, for sample usage see: https://github.com/sidsen/vowpal_wabbit/blob/v0/cs_test/ExploreOnlySample.cs
-*/ - +*/
+
/*!
* \addtogroup MultiWorldTestingCsharp
* @{
-*/ - -//! Interface for C# version of Multiworld Testing library. -//! For sample usage see: https://github.com/sidsen/vowpal_wabbit/blob/v0/cs_test/ExploreOnlySample.cs -namespace MultiWorldTesting { - +*/
+
+//! Interface for C# version of Multiworld Testing library.
+//! For sample usage see: https://github.com/sidsen/vowpal_wabbit/blob/v0/cs_test/ExploreOnlySample.cs
+namespace MultiWorldTesting {
+
/// <summary>
/// Represents a recorder that exposes a method to record exploration data based on generic contexts.
/// </summary>
-/// <typeparam name="Ctx">The Context type.</typeparam> -/// <remarks> -/// Exploration data is specified as a set of tuples <context, action, probability, key> as described below. An -/// application passes an IRecorder object to the @MwtExplorer constructor. See -/// @StringRecorder for a sample IRecorder object. -/// </remarks> +/// <typeparam name="Ctx">The Context type.</typeparam>
+/// <remarks>
+/// Exploration data is specified as a set of tuples <context, action, probability, key> as described below. An
+/// application passes an IRecorder object to the @MwtExplorer constructor. See
+/// @StringRecorder for a sample IRecorder object.
+/// </remarks>
generic <class Ctx>
public interface class IRecorder
{
@@ -40,7 +40,7 @@ public: };
/// <summary>
-/// Exposes a method for choosing an action given a generic context. IPolicy objects are +/// Exposes a method for choosing an action given a generic context. IPolicy objects are
/// passed to (and invoked by) exploration algorithms to specify the default policy behavior.
/// </summary>
/// <typeparam name="Ctx">The Context type.</typeparam>
@@ -57,7 +57,7 @@ public: };
/// <summary>
-/// Exposes a method for specifying a score (weight) for each action given a generic context. +/// Exposes a method for specifying a score (weight) for each action given a generic context.
/// </summary>
/// <typeparam name="Ctx">The Context type.</typeparam>
generic <class Ctx>
diff --git a/explore/clr/explore_interop.h b/explore/clr/explore_interop.h index 85b2d231..2042476b 100644 --- a/explore/clr/explore_interop.h +++ b/explore/clr/explore_interop.h @@ -1,358 +1,358 @@ -#pragma once - -#define MANAGED_CODE - -#include "explore_interface.h" -#include "MWTExplorer.h" - -#include <msclr\marshal_cppstd.h> - -using namespace System; -using namespace System::Collections::Generic; -using namespace System::IO; -using namespace System::Runtime::InteropServices; -using namespace System::Xml::Serialization; -using namespace msclr::interop; - -namespace MultiWorldTesting { - -// Policy callback -private delegate UInt32 ClrPolicyCallback(IntPtr explorerPtr, IntPtr contextPtr, int index); -typedef u32 Native_Policy_Callback(void* explorer, void* context, int index); - -// Scorer callback -private delegate void ClrScorerCallback(IntPtr explorerPtr, IntPtr contextPtr, IntPtr scores, IntPtr size); -typedef void Native_Scorer_Callback(void* explorer, void* context, float* scores[], u32* size); - -// Recorder callback -private delegate void ClrRecorderCallback(IntPtr mwtPtr, IntPtr contextPtr, UInt32 action, float probability, IntPtr uniqueKey); -typedef void Native_Recorder_Callback(void* mwt, void* context, u32 action, float probability, void* unique_key); - -// ToString callback -private delegate void ClrToStringCallback(IntPtr contextPtr, IntPtr stringValue); -typedef void Native_To_String_Callback(void* explorer, void* string_value); - -// NativeContext travels through interop space and contains instances of Mwt, Explorer, Context -// used for triggering callback for Policy, Scorer, Recorder -class NativeContext -{ -public: - NativeContext(void* clr_mwt, void* clr_explorer, void* clr_context) - { - m_clr_mwt = clr_mwt; - m_clr_explorer = clr_explorer; - m_clr_context = clr_context; - } - - void* Get_Clr_Mwt() - { - return m_clr_mwt; - } - - void* Get_Clr_Context() - { - return m_clr_context; - } - - void* Get_Clr_Explorer() - { - return m_clr_explorer; - } - -private: - void* m_clr_mwt; - void* m_clr_context; - void* m_clr_explorer; -}; - -class NativeStringContext -{ -public: - NativeStringContext(void* clr_context, Native_To_String_Callback* func) - { - m_clr_context = clr_context; - m_func = func; - } - - string To_String() - { - string value; - m_func(m_clr_context, &value); - return value; - } -private: - void* m_clr_context; - Native_To_String_Callback* m_func; -}; - -// NativeRecorder listens to callback event and reroute it to the managed Recorder instance -class NativeRecorder : public NativeMultiWorldTesting::IRecorder<NativeContext> -{ -public: - NativeRecorder(Native_Recorder_Callback* native_func) - { - m_func = native_func; - } - - void Record(NativeContext& context, u32 action, float probability, string unique_key) - { - GCHandle uniqueKeyHandle = GCHandle::Alloc(gcnew String(unique_key.c_str())); - IntPtr uniqueKeyPtr = (IntPtr)uniqueKeyHandle; - - m_func(context.Get_Clr_Mwt(), context.Get_Clr_Context(), action, probability, uniqueKeyPtr.ToPointer()); - - uniqueKeyHandle.Free(); - } -private: - Native_Recorder_Callback* m_func; -}; - -// NativePolicy listens to callback event and reroute it to the managed Policy instance -class NativePolicy : public NativeMultiWorldTesting::IPolicy<NativeContext> -{ -public: - NativePolicy(Native_Policy_Callback* func, int index = -1) - { - m_func = func; - m_index = index; - } - - u32 Choose_Action(NativeContext& context) - { - return m_func(context.Get_Clr_Explorer(), context.Get_Clr_Context(), m_index); - } - -private: - Native_Policy_Callback* m_func; - int m_index; -}; - -class NativeScorer : public NativeMultiWorldTesting::IScorer<NativeContext> -{ -public: - NativeScorer(Native_Scorer_Callback* func) - { - m_func = func; - } - - vector<float> Score_Actions(NativeContext& context) - { - float* scores = nullptr; - u32 num_scores = 0; - m_func(context.Get_Clr_Explorer(), context.Get_Clr_Context(), &scores, &num_scores); - - // It's ok if scores is null, vector will be empty - vector<float> scores_vector(scores, scores + num_scores); - delete[] scores; - - return scores_vector; - } -private: - Native_Scorer_Callback* m_func; -}; - -// Triggers callback to the Policy instance to choose an action -generic <class Ctx> -public ref class PolicyCallback abstract -{ -internal: - virtual UInt32 InvokePolicyCallback(Ctx context, int index) = 0; - - PolicyCallback() - { - policyCallback = gcnew ClrPolicyCallback(&PolicyCallback<Ctx>::InteropInvoke); - IntPtr policyCallbackPtr = Marshal::GetFunctionPointerForDelegate(policyCallback); - m_callback = static_cast<Native_Policy_Callback*>(policyCallbackPtr.ToPointer()); - m_native_policy = nullptr; - m_native_policies = nullptr; - } - - ~PolicyCallback() - { - delete m_native_policy; - delete m_native_policies; - } - - NativePolicy* GetNativePolicy() - { - if (m_native_policy == nullptr) - { - m_native_policy = new NativePolicy(m_callback); - } - return m_native_policy; - } - - vector<NativeMultiWorldTesting::PolicyPtr<NativeContext>>* GetNativePolicies(int count) - { - if (m_native_policies == nullptr) - { - m_native_policies = new vector<NativeMultiWorldTesting::PolicyPtr<NativeContext>>(); +#pragma once
+
+#define MANAGED_CODE
+
+#include "explore_interface.h"
+#include "MWTExplorer.h"
+
+#include <msclr\marshal_cppstd.h>
+
+using namespace System;
+using namespace System::Collections::Generic;
+using namespace System::IO;
+using namespace System::Runtime::InteropServices;
+using namespace System::Xml::Serialization;
+using namespace msclr::interop;
+
+namespace MultiWorldTesting {
+
+// Policy callback
+private delegate UInt32 ClrPolicyCallback(IntPtr explorerPtr, IntPtr contextPtr, int index);
+typedef u32 Native_Policy_Callback(void* explorer, void* context, int index);
+
+// Scorer callback
+private delegate void ClrScorerCallback(IntPtr explorerPtr, IntPtr contextPtr, IntPtr scores, IntPtr size);
+typedef void Native_Scorer_Callback(void* explorer, void* context, float* scores[], u32* size);
+
+// Recorder callback
+private delegate void ClrRecorderCallback(IntPtr mwtPtr, IntPtr contextPtr, UInt32 action, float probability, IntPtr uniqueKey);
+typedef void Native_Recorder_Callback(void* mwt, void* context, u32 action, float probability, void* unique_key);
+
+// ToString callback
+private delegate void ClrToStringCallback(IntPtr contextPtr, IntPtr stringValue);
+typedef void Native_To_String_Callback(void* explorer, void* string_value);
+
+// NativeContext travels through interop space and contains instances of Mwt, Explorer, Context
+// used for triggering callback for Policy, Scorer, Recorder
+class NativeContext
+{
+public:
+ NativeContext(void* clr_mwt, void* clr_explorer, void* clr_context)
+ {
+ m_clr_mwt = clr_mwt;
+ m_clr_explorer = clr_explorer;
+ m_clr_context = clr_context;
+ }
+
+ void* Get_Clr_Mwt()
+ {
+ return m_clr_mwt;
+ }
+
+ void* Get_Clr_Context()
+ {
+ return m_clr_context;
+ }
+
+ void* Get_Clr_Explorer()
+ {
+ return m_clr_explorer;
+ }
+
+private:
+ void* m_clr_mwt;
+ void* m_clr_context;
+ void* m_clr_explorer;
+};
+
+class NativeStringContext
+{
+public:
+ NativeStringContext(void* clr_context, Native_To_String_Callback* func)
+ {
+ m_clr_context = clr_context;
+ m_func = func;
+ }
+
+ string To_String()
+ {
+ string value;
+ m_func(m_clr_context, &value);
+ return value;
+ }
+private:
+ void* m_clr_context;
+ Native_To_String_Callback* m_func;
+};
+
+// NativeRecorder listens to callback event and reroute it to the managed Recorder instance
+class NativeRecorder : public NativeMultiWorldTesting::IRecorder<NativeContext>
+{
+public:
+ NativeRecorder(Native_Recorder_Callback* native_func)
+ {
+ m_func = native_func;
+ }
+
+ void Record(NativeContext& context, u32 action, float probability, string unique_key)
+ {
+ GCHandle uniqueKeyHandle = GCHandle::Alloc(gcnew String(unique_key.c_str()));
+ IntPtr uniqueKeyPtr = (IntPtr)uniqueKeyHandle;
+
+ m_func(context.Get_Clr_Mwt(), context.Get_Clr_Context(), action, probability, uniqueKeyPtr.ToPointer());
+
+ uniqueKeyHandle.Free();
+ }
+private:
+ Native_Recorder_Callback* m_func;
+};
+
+// NativePolicy listens to callback event and reroute it to the managed Policy instance
+class NativePolicy : public NativeMultiWorldTesting::IPolicy<NativeContext>
+{
+public:
+ NativePolicy(Native_Policy_Callback* func, int index = -1)
+ {
+ m_func = func;
+ m_index = index;
+ }
+
+ u32 Choose_Action(NativeContext& context)
+ {
+ return m_func(context.Get_Clr_Explorer(), context.Get_Clr_Context(), m_index);
+ }
+
+private:
+ Native_Policy_Callback* m_func;
+ int m_index;
+};
+
+class NativeScorer : public NativeMultiWorldTesting::IScorer<NativeContext>
+{
+public:
+ NativeScorer(Native_Scorer_Callback* func)
+ {
+ m_func = func;
+ }
+
+ vector<float> Score_Actions(NativeContext& context)
+ {
+ float* scores = nullptr;
+ u32 num_scores = 0;
+ m_func(context.Get_Clr_Explorer(), context.Get_Clr_Context(), &scores, &num_scores);
+
+ // It's ok if scores is null, vector will be empty
+ vector<float> scores_vector(scores, scores + num_scores);
+ delete[] scores;
+
+ return scores_vector;
+ }
+private:
+ Native_Scorer_Callback* m_func;
+};
+
+// Triggers callback to the Policy instance to choose an action
+generic <class Ctx>
+public ref class PolicyCallback abstract
+{
+internal:
+ virtual UInt32 InvokePolicyCallback(Ctx context, int index) = 0;
+
+ PolicyCallback()
+ {
+ policyCallback = gcnew ClrPolicyCallback(&PolicyCallback<Ctx>::InteropInvoke);
+ IntPtr policyCallbackPtr = Marshal::GetFunctionPointerForDelegate(policyCallback);
+ m_callback = static_cast<Native_Policy_Callback*>(policyCallbackPtr.ToPointer());
+ m_native_policy = nullptr;
+ m_native_policies = nullptr;
+ }
+
+ ~PolicyCallback()
+ {
+ delete m_native_policy;
+ delete m_native_policies;
+ }
+
+ NativePolicy* GetNativePolicy()
+ {
+ if (m_native_policy == nullptr)
+ {
+ m_native_policy = new NativePolicy(m_callback);
+ }
+ return m_native_policy;
+ }
+
+ vector<NativeMultiWorldTesting::PolicyPtr<NativeContext>>* GetNativePolicies(int count)
+ {
+ if (m_native_policies == nullptr)
+ {
+ m_native_policies = new vector<NativeMultiWorldTesting::PolicyPtr<NativeContext>>();
for (int i = 0; i < count; i++)
{
m_native_policies->push_back(NativeMultiWorldTesting::PolicyPtr<NativeContext>(new NativePolicy(m_callback, i)));
- } - } - - return m_native_policies; - } - - static UInt32 InteropInvoke(IntPtr callbackPtr, IntPtr contextPtr, int index) - { - GCHandle callbackHandle = (GCHandle)callbackPtr; - PolicyCallback<Ctx>^ callback = (PolicyCallback<Ctx>^)callbackHandle.Target; - - GCHandle contextHandle = (GCHandle)contextPtr; - Ctx context = (Ctx)contextHandle.Target; - - return callback->InvokePolicyCallback(context, index); - } - -private: - ClrPolicyCallback^ policyCallback; - -private: - NativePolicy* m_native_policy; - vector<unique_ptr<NativeMultiWorldTesting::IPolicy<NativeContext>>>* m_native_policies; - Native_Policy_Callback* m_callback; -}; - -// Triggers callback to the Recorder instance to record interaction data -generic <class Ctx> -public ref class RecorderCallback abstract -{ -internal: - virtual void InvokeRecorderCallback(Ctx context, UInt32 action, float probability, String^ unique_key) = 0; - - RecorderCallback() - { - recorderCallback = gcnew ClrRecorderCallback(&RecorderCallback<Ctx>::InteropInvoke); - IntPtr recorderCallbackPtr = Marshal::GetFunctionPointerForDelegate(recorderCallback); - Native_Recorder_Callback* callback = static_cast<Native_Recorder_Callback*>(recorderCallbackPtr.ToPointer()); - m_native_recorder = new NativeRecorder(callback); - } - - ~RecorderCallback() - { - delete m_native_recorder; - } - - NativeRecorder* GetNativeRecorder() - { - return m_native_recorder; - } - - static void InteropInvoke(IntPtr mwtPtr, IntPtr contextPtr, UInt32 action, float probability, IntPtr uniqueKeyPtr) - { - GCHandle mwtHandle = (GCHandle)mwtPtr; - RecorderCallback<Ctx>^ callback = (RecorderCallback<Ctx>^)mwtHandle.Target; - - GCHandle contextHandle = (GCHandle)contextPtr; - Ctx context = (Ctx)contextHandle.Target; - - GCHandle uniqueKeyHandle = (GCHandle)uniqueKeyPtr; - String^ uniqueKey = (String^)uniqueKeyHandle.Target; - - callback->InvokeRecorderCallback(context, action, probability, uniqueKey); - } - -private: - ClrRecorderCallback^ recorderCallback; - -private: - NativeRecorder* m_native_recorder; -}; - -// Triggers callback to the Recorder instance to record interaction data -generic <class Ctx> -public ref class ScorerCallback abstract -{ -internal: - virtual List<float>^ InvokeScorerCallback(Ctx context) = 0; - - ScorerCallback() - { - scorerCallback = gcnew ClrScorerCallback(&ScorerCallback<Ctx>::InteropInvoke); - IntPtr scorerCallbackPtr = Marshal::GetFunctionPointerForDelegate(scorerCallback); - Native_Scorer_Callback* callback = static_cast<Native_Scorer_Callback*>(scorerCallbackPtr.ToPointer()); - m_native_scorer = new NativeScorer(callback); - } - - ~ScorerCallback() - { - delete m_native_scorer; - } - - NativeScorer* GetNativeScorer() - { - return m_native_scorer; - } - - static void InteropInvoke(IntPtr callbackPtr, IntPtr contextPtr, IntPtr scoresPtr, IntPtr sizePtr) - { - GCHandle callbackHandle = (GCHandle)callbackPtr; - ScorerCallback<Ctx>^ callback = (ScorerCallback<Ctx>^)callbackHandle.Target; - - GCHandle contextHandle = (GCHandle)contextPtr; - Ctx context = (Ctx)contextHandle.Target; - - List<float>^ scoreList = callback->InvokeScorerCallback(context); - - if (scoreList == nullptr || scoreList->Count == 0) - { - return; - } - - u32* num_scores = (u32*)sizePtr.ToPointer(); - *num_scores = (u32)scoreList->Count; - - float* scores = new float[*num_scores]; - for (u32 i = 0; i < *num_scores; i++) - { - scores[i] = scoreList[i]; - } - - float** native_scores = (float**)scoresPtr.ToPointer(); - *native_scores = scores; - } - -private: - ClrScorerCallback^ scorerCallback; - -private: - NativeScorer* m_native_scorer; -}; - -// Triggers callback to the Context instance to perform ToString() operation -generic <class Ctx> where Ctx : IStringContext -public ref class ToStringCallback -{ -internal: - ToStringCallback() - { - toStringCallback = gcnew ClrToStringCallback(&ToStringCallback<Ctx>::InteropInvoke); - IntPtr toStringCallbackPtr = Marshal::GetFunctionPointerForDelegate(toStringCallback); - m_callback = static_cast<Native_To_String_Callback*>(toStringCallbackPtr.ToPointer()); - } - - Native_To_String_Callback* GetCallback() - { - return m_callback; - } - - static void InteropInvoke(IntPtr contextPtr, IntPtr stringPtr) - { - GCHandle contextHandle = (GCHandle)contextPtr; - Ctx context = (Ctx)contextHandle.Target; - - string* out_string = (string*)stringPtr.ToPointer(); - *out_string = marshal_as<string>(context->ToString()); - } - -private: - ClrToStringCallback^ toStringCallback; - -private: - Native_To_String_Callback* m_callback; -}; - + }
+ }
+
+ return m_native_policies;
+ }
+
+ static UInt32 InteropInvoke(IntPtr callbackPtr, IntPtr contextPtr, int index)
+ {
+ GCHandle callbackHandle = (GCHandle)callbackPtr;
+ PolicyCallback<Ctx>^ callback = (PolicyCallback<Ctx>^)callbackHandle.Target;
+
+ GCHandle contextHandle = (GCHandle)contextPtr;
+ Ctx context = (Ctx)contextHandle.Target;
+
+ return callback->InvokePolicyCallback(context, index);
+ }
+
+private:
+ ClrPolicyCallback^ policyCallback;
+
+private:
+ NativePolicy* m_native_policy;
+ vector<unique_ptr<NativeMultiWorldTesting::IPolicy<NativeContext>>>* m_native_policies;
+ Native_Policy_Callback* m_callback;
+};
+
+// Triggers callback to the Recorder instance to record interaction data
+generic <class Ctx>
+public ref class RecorderCallback abstract
+{
+internal:
+ virtual void InvokeRecorderCallback(Ctx context, UInt32 action, float probability, String^ unique_key) = 0;
+
+ RecorderCallback()
+ {
+ recorderCallback = gcnew ClrRecorderCallback(&RecorderCallback<Ctx>::InteropInvoke);
+ IntPtr recorderCallbackPtr = Marshal::GetFunctionPointerForDelegate(recorderCallback);
+ Native_Recorder_Callback* callback = static_cast<Native_Recorder_Callback*>(recorderCallbackPtr.ToPointer());
+ m_native_recorder = new NativeRecorder(callback);
+ }
+
+ ~RecorderCallback()
+ {
+ delete m_native_recorder;
+ }
+
+ NativeRecorder* GetNativeRecorder()
+ {
+ return m_native_recorder;
+ }
+
+ static void InteropInvoke(IntPtr mwtPtr, IntPtr contextPtr, UInt32 action, float probability, IntPtr uniqueKeyPtr)
+ {
+ GCHandle mwtHandle = (GCHandle)mwtPtr;
+ RecorderCallback<Ctx>^ callback = (RecorderCallback<Ctx>^)mwtHandle.Target;
+
+ GCHandle contextHandle = (GCHandle)contextPtr;
+ Ctx context = (Ctx)contextHandle.Target;
+
+ GCHandle uniqueKeyHandle = (GCHandle)uniqueKeyPtr;
+ String^ uniqueKey = (String^)uniqueKeyHandle.Target;
+
+ callback->InvokeRecorderCallback(context, action, probability, uniqueKey);
+ }
+
+private:
+ ClrRecorderCallback^ recorderCallback;
+
+private:
+ NativeRecorder* m_native_recorder;
+};
+
+// Triggers callback to the Recorder instance to record interaction data
+generic <class Ctx>
+public ref class ScorerCallback abstract
+{
+internal:
+ virtual List<float>^ InvokeScorerCallback(Ctx context) = 0;
+
+ ScorerCallback()
+ {
+ scorerCallback = gcnew ClrScorerCallback(&ScorerCallback<Ctx>::InteropInvoke);
+ IntPtr scorerCallbackPtr = Marshal::GetFunctionPointerForDelegate(scorerCallback);
+ Native_Scorer_Callback* callback = static_cast<Native_Scorer_Callback*>(scorerCallbackPtr.ToPointer());
+ m_native_scorer = new NativeScorer(callback);
+ }
+
+ ~ScorerCallback()
+ {
+ delete m_native_scorer;
+ }
+
+ NativeScorer* GetNativeScorer()
+ {
+ return m_native_scorer;
+ }
+
+ static void InteropInvoke(IntPtr callbackPtr, IntPtr contextPtr, IntPtr scoresPtr, IntPtr sizePtr)
+ {
+ GCHandle callbackHandle = (GCHandle)callbackPtr;
+ ScorerCallback<Ctx>^ callback = (ScorerCallback<Ctx>^)callbackHandle.Target;
+
+ GCHandle contextHandle = (GCHandle)contextPtr;
+ Ctx context = (Ctx)contextHandle.Target;
+
+ List<float>^ scoreList = callback->InvokeScorerCallback(context);
+
+ if (scoreList == nullptr || scoreList->Count == 0)
+ {
+ return;
+ }
+
+ u32* num_scores = (u32*)sizePtr.ToPointer();
+ *num_scores = (u32)scoreList->Count;
+
+ float* scores = new float[*num_scores];
+ for (u32 i = 0; i < *num_scores; i++)
+ {
+ scores[i] = scoreList[i];
+ }
+
+ float** native_scores = (float**)scoresPtr.ToPointer();
+ *native_scores = scores;
+ }
+
+private:
+ ClrScorerCallback^ scorerCallback;
+
+private:
+ NativeScorer* m_native_scorer;
+};
+
+// Triggers callback to the Context instance to perform ToString() operation
+generic <class Ctx> where Ctx : IStringContext
+public ref class ToStringCallback
+{
+internal:
+ ToStringCallback()
+ {
+ toStringCallback = gcnew ClrToStringCallback(&ToStringCallback<Ctx>::InteropInvoke);
+ IntPtr toStringCallbackPtr = Marshal::GetFunctionPointerForDelegate(toStringCallback);
+ m_callback = static_cast<Native_To_String_Callback*>(toStringCallbackPtr.ToPointer());
+ }
+
+ Native_To_String_Callback* GetCallback()
+ {
+ return m_callback;
+ }
+
+ static void InteropInvoke(IntPtr contextPtr, IntPtr stringPtr)
+ {
+ GCHandle contextHandle = (GCHandle)contextPtr;
+ Ctx context = (Ctx)contextHandle.Target;
+
+ string* out_string = (string*)stringPtr.ToPointer();
+ *out_string = marshal_as<string>(context->ToString());
+ }
+
+private:
+ ClrToStringCallback^ toStringCallback;
+
+private:
+ Native_To_String_Callback* m_callback;
+};
+
}
\ No newline at end of file diff --git a/explore/explore.cpp b/explore/explore.cpp index 8b5ff465..ebc1c14e 100644 --- a/explore/explore.cpp +++ b/explore/explore.cpp @@ -1,15 +1,15 @@ -// explore.cpp : Timing code to measure performance of MWT Explorer library - -#include "MWTExplorer.h" -#include <chrono> -#include <tuple> -#include <iostream> - -using namespace std; -using namespace std::chrono; - -using namespace MultiWorldTesting; - +// explore.cpp : Timing code to measure performance of MWT Explorer library
+
+#include "MWTExplorer.h"
+#include <chrono>
+#include <tuple>
+#include <iostream>
+
+using namespace std;
+using namespace std::chrono;
+
+using namespace MultiWorldTesting;
+
class MySimplePolicy : public IPolicy<SimpleContext>
{
public:
@@ -18,56 +18,56 @@ public: return (u32)1;
}
};
- -const u32 num_actions = 10; - -void Clock_Explore() -{ - float epsilon = .2f; - string unique_key = "key"; - int num_features = 1000; - int num_iter = 10000; - int num_warmup = 100; - int num_interactions = 1; - - // pre-create features - vector<Feature> features; - for (int i = 0; i < num_features; i++) - { - Feature f = {0.5, i+1}; - features.push_back(f); - } - - long long time_init = 0, time_choose = 0; - for (int iter = 0; iter < num_iter + num_warmup; iter++) - { - high_resolution_clock::time_point t1 = high_resolution_clock::now(); - StringRecorder<SimpleContext> recorder; - MwtExplorer<SimpleContext> mwt("test", recorder); - MySimplePolicy default_policy; +
+const u32 num_actions = 10;
+
+void Clock_Explore()
+{
+ float epsilon = .2f;
+ string unique_key = "key";
+ int num_features = 1000;
+ int num_iter = 10000;
+ int num_warmup = 100;
+ int num_interactions = 1;
+
+ // pre-create features
+ vector<Feature> features;
+ for (int i = 0; i < num_features; i++)
+ {
+ Feature f = {0.5, i+1};
+ features.push_back(f);
+ }
+
+ long long time_init = 0, time_choose = 0;
+ for (int iter = 0; iter < num_iter + num_warmup; iter++)
+ {
+ high_resolution_clock::time_point t1 = high_resolution_clock::now();
+ StringRecorder<SimpleContext> recorder;
+ MwtExplorer<SimpleContext> mwt("test", recorder);
+ MySimplePolicy default_policy;
EpsilonGreedyExplorer<SimpleContext> explorer(default_policy, epsilon, num_actions);
- high_resolution_clock::time_point t2 = high_resolution_clock::now(); - time_init += iter < num_warmup ? 0 : duration_cast<chrono::microseconds>(t2 - t1).count(); - - t1 = high_resolution_clock::now(); - SimpleContext appContext(features); - for (int i = 0; i < num_interactions; i++) - { - mwt.Choose_Action(explorer, unique_key, appContext); - } - t2 = high_resolution_clock::now(); - time_choose += iter < num_warmup ? 0 : duration_cast<chrono::microseconds>(t2 - t1).count(); - } - - cout << "# iterations: " << num_iter << ", # interactions: " << num_interactions << ", # context features: " << num_features << endl; - cout << "--- PER ITERATION ---" << endl; - cout << "Init: " << (double)time_init / num_iter << " micro" << endl; - cout << "Choose Action: " << (double)time_choose / (num_iter * num_interactions) << " micro" << endl; - cout << "--- TOTAL TIME ---: " << (time_init + time_choose) << " micro" << endl; -} - -int main(int argc, char* argv[]) -{ - Clock_Explore(); - return 0; -} + high_resolution_clock::time_point t2 = high_resolution_clock::now();
+ time_init += iter < num_warmup ? 0 : duration_cast<chrono::microseconds>(t2 - t1).count();
+
+ t1 = high_resolution_clock::now();
+ SimpleContext appContext(features);
+ for (int i = 0; i < num_interactions; i++)
+ {
+ mwt.Choose_Action(explorer, unique_key, appContext);
+ }
+ t2 = high_resolution_clock::now();
+ time_choose += iter < num_warmup ? 0 : duration_cast<chrono::microseconds>(t2 - t1).count();
+ }
+
+ cout << "# iterations: " << num_iter << ", # interactions: " << num_interactions << ", # context features: " << num_features << endl;
+ cout << "--- PER ITERATION ---" << endl;
+ cout << "Init: " << (double)time_init / num_iter << " micro" << endl;
+ cout << "Choose Action: " << (double)time_choose / (num_iter * num_interactions) << " micro" << endl;
+ cout << "--- TOTAL TIME ---: " << (time_init + time_choose) << " micro" << endl;
+}
+
+int main(int argc, char* argv[])
+{
+ Clock_Explore();
+ return 0;
+}
diff --git a/explore/static/MWTExplorer.h b/explore/static/MWTExplorer.h index 1e12096b..0d0e1646 100644 --- a/explore/static/MWTExplorer.h +++ b/explore/static/MWTExplorer.h @@ -1,672 +1,672 @@ -// -// Main interface for clients of the Multiworld testing (MWT) service. -// - -#pragma once - -#include <stdexcept> -#include <float.h> -#include <math.h> -#include <stdio.h> -#include <string.h> -#include <vector> -#include <utility> -#include <memory> -#include <limits.h> -#include <tuple> - -#ifdef MANAGED_CODE -#define PORTING_INTERFACE public -#define MWT_NAMESPACE namespace NativeMultiWorldTesting -#else -#define PORTING_INTERFACE private -#define MWT_NAMESPACE namespace MultiWorldTesting -#endif - -using namespace std; - -#include "utility.h" - +//
+// Main interface for clients of the Multiworld testing (MWT) service.
+//
+
+#pragma once
+
+#include <stdexcept>
+#include <float.h>
+#include <math.h>
+#include <stdio.h>
+#include <string.h>
+#include <vector>
+#include <utility>
+#include <memory>
+#include <limits.h>
+#include <tuple>
+
+#ifdef MANAGED_CODE
+#define PORTING_INTERFACE public
+#define MWT_NAMESPACE namespace NativeMultiWorldTesting
+#else
+#define PORTING_INTERFACE private
+#define MWT_NAMESPACE namespace MultiWorldTesting
+#endif
+
+using namespace std;
+
+#include "utility.h"
+
/** \defgroup MultiWorldTestingCpp
\brief C++ implementation, for sample usage see: https://github.com/sidsen/vowpal_wabbit/blob/v0/explore/explore_sample.cpp
-*/ - +*/
+
/*!
* \addtogroup MultiWorldTestingCpp
* @{
-*/ - -//! Interface for C++ version of Multiworld Testing library. -//! For sample usage see: https://github.com/sidsen/vowpal_wabbit/blob/v0/explore/explore_sample.cpp -MWT_NAMESPACE { - -// Forward declarations -template <class Ctx> -class IRecorder; -template <class Ctx> -class IExplorer; - -/// -/// The top-level MwtExplorer class. Using this enables principled and efficient exploration -/// over a set of possible actions, and ensures that the right bits are recorded. -/// -template <class Ctx> -class MwtExplorer -{ -public: - /// - /// Constructor - /// - /// @param appid This should be unique to your experiment or you risk nasty correlation bugs. - /// @param recorder A user-specified class for recording the appropriate bits for use in evaluation and learning. - /// - MwtExplorer(std::string app_id, IRecorder<Ctx>& recorder) : m_recorder(recorder) - { - m_app_id = HashUtils::Compute_Id_Hash(app_id); - } - - /// - /// Chooses an action by invoking an underlying exploration algorithm. This should be a - /// drop-in replacement for any existing policy function. - /// - /// @param explorer An existing exploration algorithm (one of the below) which uses the default policy as a callback. - /// @param unique_key A unique identifier for the experimental unit. This could be a user id, a session id, etc.. - /// @param context The context upon which a decision is made. See SimpleContext below for an example. - /// - u32 Choose_Action(IExplorer<Ctx>& explorer, string unique_key, Ctx& context) - { - u64 seed = HashUtils::Compute_Id_Hash(unique_key); - - std::tuple<u32, float, bool> action_probability_log_tuple = explorer.Choose_Action(seed + m_app_id, context); - - u32 action = std::get<0>(action_probability_log_tuple); - float prob = std::get<1>(action_probability_log_tuple); - - if (std::get<2>(action_probability_log_tuple)) - { - m_recorder.Record(context, action, prob, unique_key); - } - - return action; - } - -private: - u64 m_app_id; - IRecorder<Ctx>& m_recorder; -}; - -/// -/// Exposes a method to record exploration data based on generic contexts. Exploration data -/// is specified as a set of tuples <context, action, probability, key> as described below. An -/// application passes an IRecorder object to the @MwtExplorer constructor. See -/// @StringRecorder for a sample IRecorder object. -/// -template <class Ctx> -class IRecorder -{ -public: - /// - /// Records the exploration data associated with a given decision. - /// - /// @param context A user-defined context for the decision - /// @param action The action chosen by an exploration algorithm given context - /// @param probability The probability the exploration algorithm chose said action - /// @param unique_key A user-defined unique identifer for the decision - /// - virtual void Record(Ctx& context, u32 action, float probability, string unique_key) = 0; -}; - -/// -/// Exposes a method to choose an action given a generic context, and obtain the relevant -/// exploration bits. Invokes IPolicy::Choose_Action internally. Do not implement this -/// interface yourself: instead, use the various exploration algorithms below, which -/// implement it for you. -/// -template <class Ctx> -class IExplorer -{ -public: - /// - /// Determines the action to take and the probability with which it was chosen, for a - /// given context. - /// - /// @param salted_seed A PRG seed based on a unique id information provided by the user - /// @param context A user-defined context for the decision - /// @returns The action to take, the probability it was chosen, and a flag indicating - /// whether to record this decision - /// - virtual std::tuple<u32, float, bool> Choose_Action(u64 salted_seed, Ctx& context) = 0; -}; - -/// -/// Exposes a method to choose an action given a generic context. IPolicy objects are -/// passed to (and invoked by) exploration algorithms to specify the default policy behavior. -/// -template <class Ctx> -class IPolicy -{ -public: - /// - /// Determines the action to take for a given context. - /// - /// @param context A user-defined context for the decision - /// @returns The action to take (1-based index) - /// - virtual u32 Choose_Action(Ctx& context) = 0; -}; - -/// -/// Exposes a method for specifying a score (weight) for each action given a generic context. -/// -template <class Ctx> -class IScorer -{ -public: - /// - /// Determines the score of each action for a given context. - /// - /// @param context A user-defined context for the decision - /// @returns A vector of scores indexed by action (1-based) - /// - virtual vector<float> Score_Actions(Ctx& context) = 0; -}; - -/// -/// A sample recorder class that converts the exploration tuple into string format. -/// -template <class Ctx> -struct StringRecorder : public IRecorder<Ctx> -{ - void Record(Ctx& context, u32 action, float probability, string unique_key) - { - // Implicitly enforce To_String() API on the context - m_recording.append(to_string(action)); - m_recording.append(" ", 1); - m_recording.append(unique_key); - m_recording.append(" ", 1); - - char prob_str[10] = { 0 }; - NumberUtils::Float_To_String(probability, prob_str); - m_recording.append(prob_str); - - m_recording.append(" | ", 3); - m_recording.append(context.To_String()); - m_recording.append("\n"); - } - - // Gets the content of the recording so far as a string and optionally clears internal content. - string Get_Recording(bool flush = true) - { - if (!flush) - { - return m_recording; - } - string recording = m_recording; - m_recording.clear(); - return recording; - } - -private: - string m_recording; -}; - -/// -/// Represents a feature in a sparse array. -/// -struct Feature -{ - float Value; - u32 Id; - - bool operator==(Feature other_feature) - { - return Id == other_feature.Id; - } -}; - -/// -/// A sample context class that stores a vector of Features. -/// -class SimpleContext -{ -public: - SimpleContext(vector<Feature>& features) : - m_features(features) - { } - - vector<Feature>& Get_Features() - { - return m_features; - } - - string To_String() - { - string out_string; - char feature_str[35] = { 0 }; - for (size_t i = 0; i < m_features.size(); i++) - { - int chars; - if (i == 0) - { - chars = sprintf(feature_str, "%d:", m_features[i].Id); - } - else - { - chars = sprintf(feature_str, " %d:", m_features[i].Id); - } - NumberUtils::print_float(feature_str + chars, m_features[i].Value); - out_string.append(feature_str); - } - return out_string; - } - -private: - vector<Feature>& m_features; -}; - -/// -/// The epsilon greedy exploration algorithm. This is a good choice if you have no idea -/// which actions should be preferred. Epsilon greedy is also computationally cheap. -/// -template <class Ctx> -class EpsilonGreedyExplorer : public IExplorer<Ctx> -{ -public: - /// - /// The constructor is the only public member, because this should be used with the MwtExplorer. - /// - /// @param default_policy A default function which outputs an action given a context. - /// @param epsilon The probability of a random exploration. - /// @param num_actions The number of actions to randomize over. - /// - EpsilonGreedyExplorer(IPolicy<Ctx>& default_policy, float epsilon, u32 num_actions) : - m_default_policy(default_policy), m_epsilon(epsilon), m_num_actions(num_actions) - { - if (m_num_actions < 1) - { - throw std::invalid_argument("Number of actions must be at least 1."); - } - - if (m_epsilon < 0 || m_epsilon > 1) - { - throw std::invalid_argument("Epsilon must be between 0 and 1."); - } - } - - ~EpsilonGreedyExplorer() - { - } - -private: - std::tuple<u32, float, bool> Choose_Action(u64 salted_seed, Ctx& context) - { - PRG::prg random_generator(salted_seed); - - // Invoke the default policy function to get the action - u32 chosen_action = m_default_policy.Choose_Action(context); - - if (chosen_action == 0 || chosen_action > m_num_actions) - { - throw std::invalid_argument("Action chosen by default policy is not within valid range."); - } - - float action_probability = 0.f; - float base_probability = m_epsilon / m_num_actions; // uniform probability - - // TODO: check this random generation - if (random_generator.Uniform_Unit_Interval() < 1.f - m_epsilon) - { - action_probability = 1.f - m_epsilon + base_probability; - } - else - { - // Get uniform random action ID - u32 actionId = random_generator.Uniform_Int(1, m_num_actions); - - if (actionId == chosen_action) - { - // IF it matches the one chosen by the default policy - // then increase the probability - action_probability = 1.f - m_epsilon + base_probability; - } - else - { - // Otherwise it's just the uniform probability - action_probability = base_probability; - } - chosen_action = actionId; - } - - return std::tuple<u32, float, bool>(chosen_action, action_probability, true); - } - -private: - IPolicy<Ctx>& m_default_policy; - float m_epsilon; - u32 m_num_actions; - -private: - friend class MwtExplorer<Ctx>; -}; - -/// -/// In some cases, different actions have a different scores, and you would prefer to -/// choose actions with large scores. Softmax allows you to do that. -/// -template <class Ctx> -class SoftmaxExplorer : public IExplorer<Ctx> -{ -public: - /// - /// The constructor is the only public member, because this should be used with the MwtExplorer. - /// - /// @param default_scorer A function which outputs a score for each action. - /// @param lambda lambda = 0 implies uniform distribution. Large lambda is equivalent to a max. - /// @param num_actions The number of actions to randomize over. - /// - SoftmaxExplorer(IScorer<Ctx>& default_scorer, float lambda, u32 num_actions) : - m_default_scorer(default_scorer), m_lambda(lambda), m_num_actions(num_actions) - { - if (m_num_actions < 1) - { - throw std::invalid_argument("Number of actions must be at least 1."); - } - } - -private: - std::tuple<u32, float, bool> Choose_Action(u64 salted_seed, Ctx& context) - { - PRG::prg random_generator(salted_seed); - - // Invoke the default scorer function - vector<float> scores = m_default_scorer.Score_Actions(context); - u32 num_scores = (u32)scores.size(); - if (num_scores != m_num_actions) - { - throw std::invalid_argument("The number of scores returned by the scorer must equal number of actions"); - } - - u32 i = 0; - - float max_score = -FLT_MAX; - for (i = 0; i < num_scores; i++) - { - if (max_score < scores[i]) - { - max_score = scores[i]; - } - } - - // Create a normalized exponential distribution based on the returned scores - for (i = 0; i < num_scores; i++) - { - scores[i] = exp(m_lambda * (scores[i] - max_score)); - } - - // Create a discrete_distribution based on the returned weights. This class handles the - // case where the sum of the weights is < or > 1, by normalizing agains the sum. - float total = 0.f; - for (size_t i = 0; i < num_scores; i++) - total += scores[i]; - - float draw = random_generator.Uniform_Unit_Interval(); - - float sum = 0.f; - float action_probability = 0.f; - u32 action_index = num_scores - 1; - for (u32 i = 0; i < num_scores; i++) - { - scores[i] = scores[i] / total; - sum += scores[i]; - if (sum > draw) - { - action_index = i; - action_probability = scores[i]; - break; - } - } - - // action id is one-based - return std::tuple<u32, float, bool>(action_index + 1, action_probability, true); - } - -private: - IScorer<Ctx>& m_default_scorer; - float m_lambda; - u32 m_num_actions; - -private: - friend class MwtExplorer<Ctx>; -}; - -/// -/// GenericExplorer provides complete flexibility. You can create any -/// distribution over actions desired, and it will draw from that. -/// -template <class Ctx> -class GenericExplorer : public IExplorer<Ctx> -{ -public: - /// - /// The constructor is the only public member, because this should be used with the MwtExplorer. - /// - /// @param default_scorer A function which outputs the probability of each action. - /// @param num_actions The number of actions to randomize over. - /// - GenericExplorer(IScorer<Ctx>& default_scorer, u32 num_actions) : - m_default_scorer(default_scorer), m_num_actions(num_actions) - { - if (m_num_actions < 1) - { - throw std::invalid_argument("Number of actions must be at least 1."); - } - } - -private: - std::tuple<u32, float, bool> Choose_Action(u64 salted_seed, Ctx& context) - { - PRG::prg random_generator(salted_seed); - - // Invoke the default scorer function - vector<float> weights = m_default_scorer.Score_Actions(context); - u32 num_weights = (u32)weights.size(); - if (num_weights != m_num_actions) - { - throw std::invalid_argument("The number of weights returned by the scorer must equal number of actions"); - } - - // Create a discrete_distribution based on the returned weights. This class handles the - // case where the sum of the weights is < or > 1, by normalizing agains the sum. - float total = 0.f; - for (size_t i = 0; i < num_weights; i++) - { - if (weights[i] < 0) - { - throw std::invalid_argument("Scores must be non-negative."); - } - total += weights[i]; - } - if (total == 0) - { - throw std::invalid_argument("At least one score must be positive."); - } - - float draw = random_generator.Uniform_Unit_Interval(); - - float sum = 0.f; - float action_probability = 0.f; - u32 action_index = num_weights - 1; - for (u32 i = 0; i < num_weights; i++) - { - weights[i] = weights[i] / total; - sum += weights[i]; - if (sum > draw) - { - action_index = i; - action_probability = weights[i]; - break; - } - } - - // action id is one-based - return std::tuple<u32, float, bool>(action_index + 1, action_probability, true); - } - -private: - IScorer<Ctx>& m_default_scorer; - u32 m_num_actions; - -private: - friend class MwtExplorer<Ctx>; -}; - -/// -/// The tau-first explorer collects exactly tau uniform random exploration events, and then -/// uses the default policy thereafter. -/// -template <class Ctx> -class TauFirstExplorer : public IExplorer<Ctx> -{ -public: - - /// - /// The constructor is the only public member, because this should be used with the MwtExplorer. - /// - /// @param default_policy A default policy after randomization finishes. - /// @param tau The number of events to be uniform over. - /// @param num_actions The number of actions to randomize over. - /// - TauFirstExplorer(IPolicy<Ctx>& default_policy, u32 tau, u32 num_actions) : - m_default_policy(default_policy), m_tau(tau), m_num_actions(num_actions) - { - if (m_num_actions < 1) - { - throw std::invalid_argument("Number of actions must be at least 1."); - } - } - -private: - std::tuple<u32, float, bool> Choose_Action(u64 salted_seed, Ctx& context) - { - PRG::prg random_generator(salted_seed); - - u32 chosen_action = 0; - float action_probability = 0.f; - bool log_action; - if (m_tau) - { - m_tau--; - u32 actionId = random_generator.Uniform_Int(1, m_num_actions); - action_probability = 1.f / m_num_actions; - chosen_action = actionId; - log_action = true; - } - else - { - // Invoke the default policy function to get the action - chosen_action = m_default_policy.Choose_Action(context); - - if (chosen_action == 0 || chosen_action > m_num_actions) - { - throw std::invalid_argument("Action chosen by default policy is not within valid range."); - } - - action_probability = 1.f; - log_action = false; - } - - return std::tuple<u32, float, bool>(chosen_action, action_probability, log_action); - } - -private: - IPolicy<Ctx>& m_default_policy; - u32 m_tau; - u32 m_num_actions; - -private: - friend class MwtExplorer<Ctx>; -}; - -/// Represents a smart pointer to IPolicy<Ctx> -template <class Ctx> -using PolicyPtr = unique_ptr<IPolicy<Ctx>>; - -/// -/// The Bootstrap explorer randomizes over the actions chosen by a set of default policies. -/// This performs well statistically but can be computationally expensive. -/// -template <class Ctx> -class BootstrapExplorer : public IExplorer<Ctx> -{ -public: - /// - /// The constructor is the only public member, because this should be used with the MwtExplorer. - /// - /// @param default_policy_functions A set of default policies to be uniform random over. - /// @param num_actions The number of actions to randomize over. - /// - BootstrapExplorer(vector<PolicyPtr<Ctx>>& default_policy_functions, u32 num_actions) : - m_default_policy_functions(default_policy_functions), - m_num_actions(num_actions) - { - m_bags = (u32)default_policy_functions.size(); - if (m_num_actions < 1) - { - throw std::invalid_argument("Number of actions must be at least 1."); - } - - if (m_bags < 1) - { - throw std::invalid_argument("Number of bags must be at least 1."); - } - } - -private: - std::tuple<u32, float, bool> Choose_Action(u64 salted_seed, Ctx& context) - { - PRG::prg random_generator(salted_seed); - - // Select bag - u32 chosen_bag = random_generator.Uniform_Int(0, m_bags - 1); - - // Invoke the default policy function to get the action - u32 chosen_action = 0; - u32 action_from_bag = 0; - vector<u32> actions_selected; - for (size_t i = 0; i < m_num_actions; i++) - { - actions_selected.push_back(0); - } - - // Invoke the default policy function to get the action - for (u32 current_bag = 0; current_bag < m_bags; current_bag++) - { - action_from_bag = m_default_policy_functions[current_bag]->Choose_Action(context); - - if (action_from_bag == 0 || action_from_bag > m_num_actions) - { - throw std::invalid_argument("Action chosen by default policy is not within valid range."); - } - - if (current_bag == chosen_bag) - { - chosen_action = action_from_bag; - } - //this won't work if actions aren't 0 to Count - actions_selected[action_from_bag - 1]++; // action id is one-based - } - float action_probability = (float)actions_selected[chosen_action - 1] / m_bags; // action id is one-based - - return std::tuple<u32, float, bool>(chosen_action, action_probability, true); - } - -private: - vector<PolicyPtr<Ctx>>& m_default_policy_functions; - u32 m_bags; - u32 m_num_actions; - -private: - friend class MwtExplorer<Ctx>; -}; -} // End namespace MultiWorldTestingCpp +*/
+
+//! Interface for C++ version of Multiworld Testing library.
+//! For sample usage see: https://github.com/sidsen/vowpal_wabbit/blob/v0/explore/explore_sample.cpp
+MWT_NAMESPACE {
+
+// Forward declarations
+template <class Ctx>
+class IRecorder;
+template <class Ctx>
+class IExplorer;
+
+///
+/// The top-level MwtExplorer class. Using this enables principled and efficient exploration
+/// over a set of possible actions, and ensures that the right bits are recorded.
+///
+template <class Ctx>
+class MwtExplorer
+{
+public:
+ ///
+ /// Constructor
+ ///
+ /// @param appid This should be unique to your experiment or you risk nasty correlation bugs.
+ /// @param recorder A user-specified class for recording the appropriate bits for use in evaluation and learning.
+ ///
+ MwtExplorer(std::string app_id, IRecorder<Ctx>& recorder) : m_recorder(recorder)
+ {
+ m_app_id = HashUtils::Compute_Id_Hash(app_id);
+ }
+
+ ///
+ /// Chooses an action by invoking an underlying exploration algorithm. This should be a
+ /// drop-in replacement for any existing policy function.
+ ///
+ /// @param explorer An existing exploration algorithm (one of the below) which uses the default policy as a callback.
+ /// @param unique_key A unique identifier for the experimental unit. This could be a user id, a session id, etc..
+ /// @param context The context upon which a decision is made. See SimpleContext below for an example.
+ ///
+ u32 Choose_Action(IExplorer<Ctx>& explorer, string unique_key, Ctx& context)
+ {
+ u64 seed = HashUtils::Compute_Id_Hash(unique_key);
+
+ std::tuple<u32, float, bool> action_probability_log_tuple = explorer.Choose_Action(seed + m_app_id, context);
+
+ u32 action = std::get<0>(action_probability_log_tuple);
+ float prob = std::get<1>(action_probability_log_tuple);
+
+ if (std::get<2>(action_probability_log_tuple))
+ {
+ m_recorder.Record(context, action, prob, unique_key);
+ }
+
+ return action;
+ }
+
+private:
+ u64 m_app_id;
+ IRecorder<Ctx>& m_recorder;
+};
+
+///
+/// Exposes a method to record exploration data based on generic contexts. Exploration data
+/// is specified as a set of tuples <context, action, probability, key> as described below. An
+/// application passes an IRecorder object to the @MwtExplorer constructor. See
+/// @StringRecorder for a sample IRecorder object.
+///
+template <class Ctx>
+class IRecorder
+{
+public:
+ ///
+ /// Records the exploration data associated with a given decision.
+ ///
+ /// @param context A user-defined context for the decision
+ /// @param action The action chosen by an exploration algorithm given context
+ /// @param probability The probability the exploration algorithm chose said action
+ /// @param unique_key A user-defined unique identifer for the decision
+ ///
+ virtual void Record(Ctx& context, u32 action, float probability, string unique_key) = 0;
+};
+
+///
+/// Exposes a method to choose an action given a generic context, and obtain the relevant
+/// exploration bits. Invokes IPolicy::Choose_Action internally. Do not implement this
+/// interface yourself: instead, use the various exploration algorithms below, which
+/// implement it for you.
+///
+template <class Ctx>
+class IExplorer
+{
+public:
+ ///
+ /// Determines the action to take and the probability with which it was chosen, for a
+ /// given context.
+ ///
+ /// @param salted_seed A PRG seed based on a unique id information provided by the user
+ /// @param context A user-defined context for the decision
+ /// @returns The action to take, the probability it was chosen, and a flag indicating
+ /// whether to record this decision
+ ///
+ virtual std::tuple<u32, float, bool> Choose_Action(u64 salted_seed, Ctx& context) = 0;
+};
+
+///
+/// Exposes a method to choose an action given a generic context. IPolicy objects are
+/// passed to (and invoked by) exploration algorithms to specify the default policy behavior.
+///
+template <class Ctx>
+class IPolicy
+{
+public:
+ ///
+ /// Determines the action to take for a given context.
+ ///
+ /// @param context A user-defined context for the decision
+ /// @returns The action to take (1-based index)
+ ///
+ virtual u32 Choose_Action(Ctx& context) = 0;
+};
+
+///
+/// Exposes a method for specifying a score (weight) for each action given a generic context.
+///
+template <class Ctx>
+class IScorer
+{
+public:
+ ///
+ /// Determines the score of each action for a given context.
+ ///
+ /// @param context A user-defined context for the decision
+ /// @returns A vector of scores indexed by action (1-based)
+ ///
+ virtual vector<float> Score_Actions(Ctx& context) = 0;
+};
+
+///
+/// A sample recorder class that converts the exploration tuple into string format.
+///
+template <class Ctx>
+struct StringRecorder : public IRecorder<Ctx>
+{
+ void Record(Ctx& context, u32 action, float probability, string unique_key)
+ {
+ // Implicitly enforce To_String() API on the context
+ m_recording.append(to_string(action));
+ m_recording.append(" ", 1);
+ m_recording.append(unique_key);
+ m_recording.append(" ", 1);
+
+ char prob_str[10] = { 0 };
+ NumberUtils::Float_To_String(probability, prob_str);
+ m_recording.append(prob_str);
+
+ m_recording.append(" | ", 3);
+ m_recording.append(context.To_String());
+ m_recording.append("\n");
+ }
+
+ // Gets the content of the recording so far as a string and optionally clears internal content.
+ string Get_Recording(bool flush = true)
+ {
+ if (!flush)
+ {
+ return m_recording;
+ }
+ string recording = m_recording;
+ m_recording.clear();
+ return recording;
+ }
+
+private:
+ string m_recording;
+};
+
+///
+/// Represents a feature in a sparse array.
+///
+struct Feature
+{
+ float Value;
+ u32 Id;
+
+ bool operator==(Feature other_feature)
+ {
+ return Id == other_feature.Id;
+ }
+};
+
+///
+/// A sample context class that stores a vector of Features.
+///
+class SimpleContext
+{
+public:
+ SimpleContext(vector<Feature>& features) :
+ m_features(features)
+ { }
+
+ vector<Feature>& Get_Features()
+ {
+ return m_features;
+ }
+
+ string To_String()
+ {
+ string out_string;
+ char feature_str[35] = { 0 };
+ for (size_t i = 0; i < m_features.size(); i++)
+ {
+ int chars;
+ if (i == 0)
+ {
+ chars = sprintf(feature_str, "%d:", m_features[i].Id);
+ }
+ else
+ {
+ chars = sprintf(feature_str, " %d:", m_features[i].Id);
+ }
+ NumberUtils::print_float(feature_str + chars, m_features[i].Value);
+ out_string.append(feature_str);
+ }
+ return out_string;
+ }
+
+private:
+ vector<Feature>& m_features;
+};
+
+///
+/// The epsilon greedy exploration algorithm. This is a good choice if you have no idea
+/// which actions should be preferred. Epsilon greedy is also computationally cheap.
+///
+template <class Ctx>
+class EpsilonGreedyExplorer : public IExplorer<Ctx>
+{
+public:
+ ///
+ /// The constructor is the only public member, because this should be used with the MwtExplorer.
+ ///
+ /// @param default_policy A default function which outputs an action given a context.
+ /// @param epsilon The probability of a random exploration.
+ /// @param num_actions The number of actions to randomize over.
+ ///
+ EpsilonGreedyExplorer(IPolicy<Ctx>& default_policy, float epsilon, u32 num_actions) :
+ m_default_policy(default_policy), m_epsilon(epsilon), m_num_actions(num_actions)
+ {
+ if (m_num_actions < 1)
+ {
+ throw std::invalid_argument("Number of actions must be at least 1.");
+ }
+
+ if (m_epsilon < 0 || m_epsilon > 1)
+ {
+ throw std::invalid_argument("Epsilon must be between 0 and 1.");
+ }
+ }
+
+ ~EpsilonGreedyExplorer()
+ {
+ }
+
+private:
+ std::tuple<u32, float, bool> Choose_Action(u64 salted_seed, Ctx& context)
+ {
+ PRG::prg random_generator(salted_seed);
+
+ // Invoke the default policy function to get the action
+ u32 chosen_action = m_default_policy.Choose_Action(context);
+
+ if (chosen_action == 0 || chosen_action > m_num_actions)
+ {
+ throw std::invalid_argument("Action chosen by default policy is not within valid range.");
+ }
+
+ float action_probability = 0.f;
+ float base_probability = m_epsilon / m_num_actions; // uniform probability
+
+ // TODO: check this random generation
+ if (random_generator.Uniform_Unit_Interval() < 1.f - m_epsilon)
+ {
+ action_probability = 1.f - m_epsilon + base_probability;
+ }
+ else
+ {
+ // Get uniform random action ID
+ u32 actionId = random_generator.Uniform_Int(1, m_num_actions);
+
+ if (actionId == chosen_action)
+ {
+ // IF it matches the one chosen by the default policy
+ // then increase the probability
+ action_probability = 1.f - m_epsilon + base_probability;
+ }
+ else
+ {
+ // Otherwise it's just the uniform probability
+ action_probability = base_probability;
+ }
+ chosen_action = actionId;
+ }
+
+ return std::tuple<u32, float, bool>(chosen_action, action_probability, true);
+ }
+
+private:
+ IPolicy<Ctx>& m_default_policy;
+ float m_epsilon;
+ u32 m_num_actions;
+
+private:
+ friend class MwtExplorer<Ctx>;
+};
+
+///
+/// In some cases, different actions have a different scores, and you would prefer to
+/// choose actions with large scores. Softmax allows you to do that.
+///
+template <class Ctx>
+class SoftmaxExplorer : public IExplorer<Ctx>
+{
+public:
+ ///
+ /// The constructor is the only public member, because this should be used with the MwtExplorer.
+ ///
+ /// @param default_scorer A function which outputs a score for each action.
+ /// @param lambda lambda = 0 implies uniform distribution. Large lambda is equivalent to a max.
+ /// @param num_actions The number of actions to randomize over.
+ ///
+ SoftmaxExplorer(IScorer<Ctx>& default_scorer, float lambda, u32 num_actions) :
+ m_default_scorer(default_scorer), m_lambda(lambda), m_num_actions(num_actions)
+ {
+ if (m_num_actions < 1)
+ {
+ throw std::invalid_argument("Number of actions must be at least 1.");
+ }
+ }
+
+private:
+ std::tuple<u32, float, bool> Choose_Action(u64 salted_seed, Ctx& context)
+ {
+ PRG::prg random_generator(salted_seed);
+
+ // Invoke the default scorer function
+ vector<float> scores = m_default_scorer.Score_Actions(context);
+ u32 num_scores = (u32)scores.size();
+ if (num_scores != m_num_actions)
+ {
+ throw std::invalid_argument("The number of scores returned by the scorer must equal number of actions");
+ }
+
+ u32 i = 0;
+
+ float max_score = -FLT_MAX;
+ for (i = 0; i < num_scores; i++)
+ {
+ if (max_score < scores[i])
+ {
+ max_score = scores[i];
+ }
+ }
+
+ // Create a normalized exponential distribution based on the returned scores
+ for (i = 0; i < num_scores; i++)
+ {
+ scores[i] = exp(m_lambda * (scores[i] - max_score));
+ }
+
+ // Create a discrete_distribution based on the returned weights. This class handles the
+ // case where the sum of the weights is < or > 1, by normalizing agains the sum.
+ float total = 0.f;
+ for (size_t i = 0; i < num_scores; i++)
+ total += scores[i];
+
+ float draw = random_generator.Uniform_Unit_Interval();
+
+ float sum = 0.f;
+ float action_probability = 0.f;
+ u32 action_index = num_scores - 1;
+ for (u32 i = 0; i < num_scores; i++)
+ {
+ scores[i] = scores[i] / total;
+ sum += scores[i];
+ if (sum > draw)
+ {
+ action_index = i;
+ action_probability = scores[i];
+ break;
+ }
+ }
+
+ // action id is one-based
+ return std::tuple<u32, float, bool>(action_index + 1, action_probability, true);
+ }
+
+private:
+ IScorer<Ctx>& m_default_scorer;
+ float m_lambda;
+ u32 m_num_actions;
+
+private:
+ friend class MwtExplorer<Ctx>;
+};
+
+///
+/// GenericExplorer provides complete flexibility. You can create any
+/// distribution over actions desired, and it will draw from that.
+///
+template <class Ctx>
+class GenericExplorer : public IExplorer<Ctx>
+{
+public:
+ ///
+ /// The constructor is the only public member, because this should be used with the MwtExplorer.
+ ///
+ /// @param default_scorer A function which outputs the probability of each action.
+ /// @param num_actions The number of actions to randomize over.
+ ///
+ GenericExplorer(IScorer<Ctx>& default_scorer, u32 num_actions) :
+ m_default_scorer(default_scorer), m_num_actions(num_actions)
+ {
+ if (m_num_actions < 1)
+ {
+ throw std::invalid_argument("Number of actions must be at least 1.");
+ }
+ }
+
+private:
+ std::tuple<u32, float, bool> Choose_Action(u64 salted_seed, Ctx& context)
+ {
+ PRG::prg random_generator(salted_seed);
+
+ // Invoke the default scorer function
+ vector<float> weights = m_default_scorer.Score_Actions(context);
+ u32 num_weights = (u32)weights.size();
+ if (num_weights != m_num_actions)
+ {
+ throw std::invalid_argument("The number of weights returned by the scorer must equal number of actions");
+ }
+
+ // Create a discrete_distribution based on the returned weights. This class handles the
+ // case where the sum of the weights is < or > 1, by normalizing agains the sum.
+ float total = 0.f;
+ for (size_t i = 0; i < num_weights; i++)
+ {
+ if (weights[i] < 0)
+ {
+ throw std::invalid_argument("Scores must be non-negative.");
+ }
+ total += weights[i];
+ }
+ if (total == 0)
+ {
+ throw std::invalid_argument("At least one score must be positive.");
+ }
+
+ float draw = random_generator.Uniform_Unit_Interval();
+
+ float sum = 0.f;
+ float action_probability = 0.f;
+ u32 action_index = num_weights - 1;
+ for (u32 i = 0; i < num_weights; i++)
+ {
+ weights[i] = weights[i] / total;
+ sum += weights[i];
+ if (sum > draw)
+ {
+ action_index = i;
+ action_probability = weights[i];
+ break;
+ }
+ }
+
+ // action id is one-based
+ return std::tuple<u32, float, bool>(action_index + 1, action_probability, true);
+ }
+
+private:
+ IScorer<Ctx>& m_default_scorer;
+ u32 m_num_actions;
+
+private:
+ friend class MwtExplorer<Ctx>;
+};
+
+///
+/// The tau-first explorer collects exactly tau uniform random exploration events, and then
+/// uses the default policy thereafter.
+///
+template <class Ctx>
+class TauFirstExplorer : public IExplorer<Ctx>
+{
+public:
+
+ ///
+ /// The constructor is the only public member, because this should be used with the MwtExplorer.
+ ///
+ /// @param default_policy A default policy after randomization finishes.
+ /// @param tau The number of events to be uniform over.
+ /// @param num_actions The number of actions to randomize over.
+ ///
+ TauFirstExplorer(IPolicy<Ctx>& default_policy, u32 tau, u32 num_actions) :
+ m_default_policy(default_policy), m_tau(tau), m_num_actions(num_actions)
+ {
+ if (m_num_actions < 1)
+ {
+ throw std::invalid_argument("Number of actions must be at least 1.");
+ }
+ }
+
+private:
+ std::tuple<u32, float, bool> Choose_Action(u64 salted_seed, Ctx& context)
+ {
+ PRG::prg random_generator(salted_seed);
+
+ u32 chosen_action = 0;
+ float action_probability = 0.f;
+ bool log_action;
+ if (m_tau)
+ {
+ m_tau--;
+ u32 actionId = random_generator.Uniform_Int(1, m_num_actions);
+ action_probability = 1.f / m_num_actions;
+ chosen_action = actionId;
+ log_action = true;
+ }
+ else
+ {
+ // Invoke the default policy function to get the action
+ chosen_action = m_default_policy.Choose_Action(context);
+
+ if (chosen_action == 0 || chosen_action > m_num_actions)
+ {
+ throw std::invalid_argument("Action chosen by default policy is not within valid range.");
+ }
+
+ action_probability = 1.f;
+ log_action = false;
+ }
+
+ return std::tuple<u32, float, bool>(chosen_action, action_probability, log_action);
+ }
+
+private:
+ IPolicy<Ctx>& m_default_policy;
+ u32 m_tau;
+ u32 m_num_actions;
+
+private:
+ friend class MwtExplorer<Ctx>;
+};
+
+/// Represents a smart pointer to IPolicy<Ctx>
+template <class Ctx>
+using PolicyPtr = unique_ptr<IPolicy<Ctx>>;
+
+///
+/// The Bootstrap explorer randomizes over the actions chosen by a set of default policies.
+/// This performs well statistically but can be computationally expensive.
+///
+template <class Ctx>
+class BootstrapExplorer : public IExplorer<Ctx>
+{
+public:
+ ///
+ /// The constructor is the only public member, because this should be used with the MwtExplorer.
+ ///
+ /// @param default_policy_functions A set of default policies to be uniform random over.
+ /// @param num_actions The number of actions to randomize over.
+ ///
+ BootstrapExplorer(vector<PolicyPtr<Ctx>>& default_policy_functions, u32 num_actions) :
+ m_default_policy_functions(default_policy_functions),
+ m_num_actions(num_actions)
+ {
+ m_bags = (u32)default_policy_functions.size();
+ if (m_num_actions < 1)
+ {
+ throw std::invalid_argument("Number of actions must be at least 1.");
+ }
+
+ if (m_bags < 1)
+ {
+ throw std::invalid_argument("Number of bags must be at least 1.");
+ }
+ }
+
+private:
+ std::tuple<u32, float, bool> Choose_Action(u64 salted_seed, Ctx& context)
+ {
+ PRG::prg random_generator(salted_seed);
+
+ // Select bag
+ u32 chosen_bag = random_generator.Uniform_Int(0, m_bags - 1);
+
+ // Invoke the default policy function to get the action
+ u32 chosen_action = 0;
+ u32 action_from_bag = 0;
+ vector<u32> actions_selected;
+ for (size_t i = 0; i < m_num_actions; i++)
+ {
+ actions_selected.push_back(0);
+ }
+
+ // Invoke the default policy function to get the action
+ for (u32 current_bag = 0; current_bag < m_bags; current_bag++)
+ {
+ action_from_bag = m_default_policy_functions[current_bag]->Choose_Action(context);
+
+ if (action_from_bag == 0 || action_from_bag > m_num_actions)
+ {
+ throw std::invalid_argument("Action chosen by default policy is not within valid range.");
+ }
+
+ if (current_bag == chosen_bag)
+ {
+ chosen_action = action_from_bag;
+ }
+ //this won't work if actions aren't 0 to Count
+ actions_selected[action_from_bag - 1]++; // action id is one-based
+ }
+ float action_probability = (float)actions_selected[chosen_action - 1] / m_bags; // action id is one-based
+
+ return std::tuple<u32, float, bool>(chosen_action, action_probability, true);
+ }
+
+private:
+ vector<PolicyPtr<Ctx>>& m_default_policy_functions;
+ u32 m_bags;
+ u32 m_num_actions;
+
+private:
+ friend class MwtExplorer<Ctx>;
+};
+} // End namespace MultiWorldTestingCpp
/*! @} End of Doxygen Groups*/
\ No newline at end of file diff --git a/explore/tests/MWTExploreTests.h b/explore/tests/MWTExploreTests.h index f79d3fd5..447a368f 100644 --- a/explore/tests/MWTExploreTests.h +++ b/explore/tests/MWTExploreTests.h @@ -2,9 +2,9 @@ #include "MWTExplorer.h"
#include "utility.h"
-#include <iomanip> -#include <iostream> -#include <sstream> +#include <iomanip>
+#include <iostream>
+#include <sstream>
using namespace MultiWorldTesting;
|