diff options
author | John Langford <jl@hunch.net> | 2015-01-04 04:45:55 +0300 |
---|---|---|
committer | John Langford <jl@hunch.net> | 2015-01-04 04:45:55 +0300 |
commit | b913d178acb0052f4bb9622dd253a4ddc684a4ce (patch) | |
tree | 2a7c6348009dea6ee77caf67af897473fb2b77aa | |
parent | 00ce5188f90cc426493872e8d630b6c61f99df79 (diff) | |
parent | ed8c4d38aba3b49159f1b2574028b5cbae96a7f2 (diff) |
resolve conflicts
-rw-r--r-- | Makefile | 2 | ||||
-rw-r--r-- | demo/dna/Makefile | 2 | ||||
-rw-r--r-- | demo/entityrelation/Makefile | 5 | ||||
-rw-r--r-- | demo/movielens/Makefile | 4 | ||||
-rw-r--r-- | demo/normalized/Makefile | 26 | ||||
-rw-r--r-- | explore/clr/explore_clr_wrapper.cpp | 36 | ||||
-rw-r--r-- | explore/clr/explore_clr_wrapper.h | 874 | ||||
-rw-r--r-- | explore/clr/explore_interface.h | 174 | ||||
-rw-r--r-- | explore/clr/explore_interop.h | 714 | ||||
-rw-r--r-- | explore/explore.cpp | 146 | ||||
-rw-r--r-- | explore/static/MWTExplorer.h | 1338 | ||||
-rw-r--r-- | explore/static/utility.h | 566 | ||||
-rw-r--r-- | explore/tests/MWTExploreTests.h | 328 | ||||
-rw-r--r-- | library/Makefile | 1 | ||||
-rw-r--r-- | library/ezexample_predict.cc | 110 | ||||
-rw-r--r-- | library/ezexample_predict_threaded.cc | 298 | ||||
-rw-r--r-- | library/ezexample_train.cc | 144 | ||||
-rw-r--r-- | library/library_example.cc | 124 | ||||
-rw-r--r-- | python/Makefile | 2 | ||||
-rw-r--r-- | vowpalwabbit/Makefile | 1 | ||||
-rw-r--r-- | vowpalwabbit/accumulate.h | 26 |
21 files changed, 2466 insertions, 2455 deletions
@@ -117,3 +117,5 @@ clean: ifneq ($(JAVA_HOME),) cd java && $(MAKE) clean endif + +.PHONY: all clean install diff --git a/demo/dna/Makefile b/demo/dna/Makefile index 93efac3b..477cfb98 100644 --- a/demo/dna/Makefile +++ b/demo/dna/Makefile @@ -27,3 +27,5 @@ quaddna2vw: quaddna2vw.cpp %.perf: %.test.predictions perf.check perl.check zsh.check ./do-perf $< + +.PHONY: all clean diff --git a/demo/entityrelation/Makefile b/demo/entityrelation/Makefile index c287d97c..c3db3a69 100644 --- a/demo/entityrelation/Makefile +++ b/demo/entityrelation/Makefile @@ -8,7 +8,7 @@ eval_script=./evaluationER.py all: @cat README.md clean: - rm -f *.model *.predictions ER_*.vw *.cache evaluationER.py er.zip *~ + rm -f *.model *.predictions ER_*.vw *.cache evaluationER.py er.zip *~ %.check: @test -x "$$(which $*)" || { \ @@ -33,5 +33,4 @@ er.test.predictions: ER_test.vw er.model er.perf: ER_test.vw er.test.predictions python.check @$(eval_script) $< er.test.predictions - - +.PHONY: all clean diff --git a/demo/movielens/Makefile b/demo/movielens/Makefile index 68be5310..46c92e24 100644 --- a/demo/movielens/Makefile +++ b/demo/movielens/Makefile @@ -39,7 +39,7 @@ ml-%.ratings.train.vw: ml-%/ratings.dat @printf "%s test MAE is %3.3f\n" $* $$(cat $*.results) #--------------------------------------------------------------------- -# linear model (no interaction terms) +# linear model (no interaction terms) #--------------------------------------------------------------------- linear.results: ml-1m.ratings.test.vw ml-1m.ratings.train.vw @@ -154,3 +154,5 @@ lrqdropouthogwild.results: ml-1m.ratings.test.vw ml-1m.ratings.train.vw do-lrq-h -d $(word 1,$+) -p \ >(perl -lane '$$s+=abs(($$F[0]-$$F[1])); } { \ 1; print $$s/$$.;' > $@) + +.PHONY: all clean shootout diff --git a/demo/normalized/Makefile b/demo/normalized/Makefile index 8268e82f..f1378b80 100644 --- a/demo/normalized/Makefile +++ b/demo/normalized/Makefile @@ -19,8 +19,8 @@ SHUFFLE='BEGIN { srand 69; }; \ $$b[$$i] = $$_; } { print grep { defined $$_ } @b;' #--------------------------------------------------------------------- -# bank marketing -# +# bank marketing +# # normalization really helps. The columns have units of euros, seconds, # days, and years; in addition there are categorical variables. #--------------------------------------------------------------------- @@ -57,10 +57,10 @@ bankbestmin=1e-2 bankbestmax=10 banknonormbestmin=1e-8 banknonormbestmax=1e-3 -banktimeestimate=1 +banktimeestimate=1 #--------------------------------------------------------------------- -# covertype +# covertype #--------------------------------------------------------------------- covtype.data.gz: @@ -88,11 +88,11 @@ covertypebestmin=1e-2 covertypebestmax=10 covertypenonormbestmin=1e-8 covertypenonormbestmax=1e-5 -covertypetimeestimate=5 +covertypetimeestimate=5 #--------------------------------------------------------------------- -# million song database -# +# million song database +# # normalization is helpful. #--------------------------------------------------------------------- @@ -142,7 +142,7 @@ MSDnonormbestmax=1e-5 MSDtimeestimate=15 #--------------------------------------------------------------------- -# census-income (KDD) +# census-income (KDD) #--------------------------------------------------------------------- census-income.data.gz: @@ -180,7 +180,7 @@ censusnonormbestmax=1e-4 censustimeestimate=5 #--------------------------------------------------------------------- -# Statlog (Shuttle) +# Statlog (Shuttle) #--------------------------------------------------------------------- shuttle.trn.Z: @@ -211,7 +211,7 @@ shuttlenonormbestmax=1e-3 shuttletimeestimate=1 #--------------------------------------------------------------------- -# CT slices +# CT slices # # normalization doesn't help much #--------------------------------------------------------------------- @@ -256,7 +256,7 @@ CTslicenonormbestmax=1 CTslicetimeestimate=15 #--------------------------------------------------------------------- -# common routines +# common routines #--------------------------------------------------------------------- %.best: %.data @@ -269,8 +269,10 @@ CTslicetimeestimate=15 @printf "WARNING: this step takes about %s minutes\n" $($*timeestimate) @./hypersearch $($*nonormbestmin) $($*nonormbestmax) '$(MAKE)' '$*.%.nonormlearn' > $@ -%.resultsprint: +%.resultsprint: @printf "%20.20s\t%9.3g\t%9.3g\t%9.3g\t%9.3g\n" "$*" $$(cut -f1 $*.best) $$(cut -f2 $*.best) $$(cut -f1 $*.nonormbest) $$(cut -f2 $*.nonormbest) only.%: %.best %.nonormbest all.results.pre %.resultsprint @true + +.PHONY: all all.results all.results.pre diff --git a/explore/clr/explore_clr_wrapper.cpp b/explore/clr/explore_clr_wrapper.cpp index 0564482b..be022af1 100644 --- a/explore/clr/explore_clr_wrapper.cpp +++ b/explore/clr/explore_clr_wrapper.cpp @@ -1,18 +1,18 @@ -// vw_explore_clr_wrapper.cpp : Defines the exported functions for the DLL application.
-//
-
-#define WIN32_LEAN_AND_MEAN
-#include <Windows.h>
-
-#include "explore_clr_wrapper.h"
-
-using namespace System;
-using namespace System::Collections;
-using namespace System::Collections::Generic;
-using namespace System::Runtime::InteropServices;
-using namespace msclr::interop;
-using namespace NativeMultiWorldTesting;
-
-namespace MultiWorldTesting {
-
-}
+// vw_explore_clr_wrapper.cpp : Defines the exported functions for the DLL application. +// + +#define WIN32_LEAN_AND_MEAN +#include <Windows.h> + +#include "explore_clr_wrapper.h" + +using namespace System; +using namespace System::Collections; +using namespace System::Collections::Generic; +using namespace System::Runtime::InteropServices; +using namespace msclr::interop; +using namespace NativeMultiWorldTesting; + +namespace MultiWorldTesting { + +} diff --git a/explore/clr/explore_clr_wrapper.h b/explore/clr/explore_clr_wrapper.h index a6cede4b..c73c95d4 100644 --- a/explore/clr/explore_clr_wrapper.h +++ b/explore/clr/explore_clr_wrapper.h @@ -1,438 +1,438 @@ -#pragma once
-#include "explore_interop.h"
-
-/*!
-* \addtogroup MultiWorldTestingCsharp
-* @{
-*/
-namespace MultiWorldTesting {
-
- /// <summary>
- /// The epsilon greedy exploration class.
- /// </summary>
- /// <remarks>
- /// This is a good choice if you have no idea which actions should be preferred.
- /// Epsilon greedy is also computationally cheap.
- /// </remarks>
- /// <typeparam name="Ctx">The Context type.</typeparam>
- generic <class Ctx>
- public ref class EpsilonGreedyExplorer : public IExplorer<Ctx>, public PolicyCallback<Ctx>
- {
- public:
- /// <summary>
- /// The constructor is the only public member, because this should be used with the MwtExplorer.
- /// </summary>
- /// <param name="defaultPolicy">A default function which outputs an action given a context.</param>
- /// <param name="epsilon">The probability of a random exploration.</param>
- /// <param name="numActions">The number of actions to randomize over.</param>
- EpsilonGreedyExplorer(IPolicy<Ctx>^ defaultPolicy, float epsilon, UInt32 numActions)
- {
- this->defaultPolicy = defaultPolicy;
- m_explorer = new NativeMultiWorldTesting::EpsilonGreedyExplorer<NativeContext>(*GetNativePolicy(), epsilon, (u32)numActions);
- }
-
- ~EpsilonGreedyExplorer()
- {
- delete m_explorer;
- }
-
- internal:
- virtual UInt32 InvokePolicyCallback(Ctx context, int index) override
- {
- return defaultPolicy->ChooseAction(context);
- }
-
- NativeMultiWorldTesting::EpsilonGreedyExplorer<NativeContext>* Get()
- {
- return m_explorer;
- }
-
- private:
- IPolicy<Ctx>^ defaultPolicy;
- NativeMultiWorldTesting::EpsilonGreedyExplorer<NativeContext>* m_explorer;
- };
-
- /// <summary>
- /// The tau-first exploration class.
- /// </summary>
- /// <remarks>
- /// The tau-first explorer collects precisely tau uniform random
- /// exploration events, and then uses the default policy.
- /// </remarks>
- /// <typeparam name="Ctx">The Context type.</typeparam>
- generic <class Ctx>
- public ref class TauFirstExplorer : public IExplorer<Ctx>, public PolicyCallback<Ctx>
- {
- public:
- /// <summary>
- /// The constructor is the only public member, because this should be used with the MwtExplorer.
- /// </summary>
- /// <param name="defaultPolicy">A default policy after randomization finishes.</param>
- /// <param name="tau">The number of events to be uniform over.</param>
- /// <param name="numActions">The number of actions to randomize over.</param>
- TauFirstExplorer(IPolicy<Ctx>^ defaultPolicy, UInt32 tau, UInt32 numActions)
- {
- this->defaultPolicy = defaultPolicy;
- m_explorer = new NativeMultiWorldTesting::TauFirstExplorer<NativeContext>(*GetNativePolicy(), tau, (u32)numActions);
- }
-
- ~TauFirstExplorer()
- {
- delete m_explorer;
- }
-
- internal:
- virtual UInt32 InvokePolicyCallback(Ctx context, int index) override
- {
- return defaultPolicy->ChooseAction(context);
- }
-
- NativeMultiWorldTesting::TauFirstExplorer<NativeContext>* Get()
- {
- return m_explorer;
- }
-
- private:
- IPolicy<Ctx>^ defaultPolicy;
- NativeMultiWorldTesting::TauFirstExplorer<NativeContext>* m_explorer;
- };
-
- /// <summary>
- /// The epsilon greedy exploration class.
- /// </summary>
- /// <remarks>
- /// In some cases, different actions have a different scores, and you
- /// would prefer to choose actions with large scores. Softmax allows
- /// you to do that.
- /// </remarks>
- /// <typeparam name="Ctx">The Context type.</typeparam>
- generic <class Ctx>
- public ref class SoftmaxExplorer : public IExplorer<Ctx>, public ScorerCallback<Ctx>
- {
- public:
- /// <summary>
- /// The constructor is the only public member, because this should be used with the MwtExplorer.
- /// </summary>
- /// <param name="defaultScorer">A function which outputs a score for each action.</param>
- /// <param name="lambda">lambda = 0 implies uniform distribution. Large lambda is equivalent to a max.</param>
- /// <param name="numActions">The number of actions to randomize over.</param>
- SoftmaxExplorer(IScorer<Ctx>^ defaultScorer, float lambda, UInt32 numActions)
- {
- this->defaultScorer = defaultScorer;
- m_explorer = new NativeMultiWorldTesting::SoftmaxExplorer<NativeContext>(*GetNativeScorer(), lambda, (u32)numActions);
- }
-
- ~SoftmaxExplorer()
- {
- delete m_explorer;
- }
-
- internal:
- virtual List<float>^ InvokeScorerCallback(Ctx context) override
- {
- return defaultScorer->ScoreActions(context);
- }
-
- NativeMultiWorldTesting::SoftmaxExplorer<NativeContext>* Get()
- {
- return m_explorer;
- }
-
- private:
- IScorer<Ctx>^ defaultScorer;
- NativeMultiWorldTesting::SoftmaxExplorer<NativeContext>* m_explorer;
- };
-
- /// <summary>
- /// The generic exploration class.
- /// </summary>
- /// <remarks>
- /// GenericExplorer provides complete flexibility. You can create any
- /// distribution over actions desired, and it will draw from that.
- /// </remarks>
- /// <typeparam name="Ctx">The Context type.</typeparam>
- generic <class Ctx>
- public ref class GenericExplorer : public IExplorer<Ctx>, public ScorerCallback<Ctx>
- {
- public:
- /// <summary>
- /// The constructor is the only public member, because this should be used with the MwtExplorer.
- /// </summary>
- /// <param name="defaultScorer">A function which outputs the probability of each action.</param>
- /// <param name="numActions">The number of actions to randomize over.</param>
- GenericExplorer(IScorer<Ctx>^ defaultScorer, UInt32 numActions)
- {
- this->defaultScorer = defaultScorer;
- m_explorer = new NativeMultiWorldTesting::GenericExplorer<NativeContext>(*GetNativeScorer(), (u32)numActions);
- }
-
- ~GenericExplorer()
- {
- delete m_explorer;
- }
-
- internal:
- virtual List<float>^ InvokeScorerCallback(Ctx context) override
- {
- return defaultScorer->ScoreActions(context);
- }
-
- NativeMultiWorldTesting::GenericExplorer<NativeContext>* Get()
- {
- return m_explorer;
- }
-
- private:
- IScorer<Ctx>^ defaultScorer;
- NativeMultiWorldTesting::GenericExplorer<NativeContext>* m_explorer;
- };
-
- /// <summary>
- /// The bootstrap exploration class.
- /// </summary>
- /// <remarks>
- /// The Bootstrap explorer randomizes over the actions chosen by a set of
- /// default policies. This performs well statistically but can be
- /// computationally expensive.
- /// </remarks>
- /// <typeparam name="Ctx">The Context type.</typeparam>
- generic <class Ctx>
- public ref class BootstrapExplorer : public IExplorer<Ctx>, public PolicyCallback<Ctx>
- {
- public:
- /// <summary>
- /// The constructor is the only public member, because this should be used with the MwtExplorer.
- /// </summary>
- /// <param name="defaultPolicies">A set of default policies to be uniform random over.</param>
- /// <param name="numActions">The number of actions to randomize over.</param>
- BootstrapExplorer(cli::array<IPolicy<Ctx>^>^ defaultPolicies, UInt32 numActions)
- {
- this->defaultPolicies = defaultPolicies;
- if (this->defaultPolicies == nullptr)
- {
- throw gcnew ArgumentNullException("The specified array of default policy functions cannot be null.");
- }
-
- m_explorer = new NativeMultiWorldTesting::BootstrapExplorer<NativeContext>(*GetNativePolicies((u32)defaultPolicies->Length), (u32)numActions);
- }
-
- ~BootstrapExplorer()
- {
- delete m_explorer;
- }
-
- internal:
- virtual UInt32 InvokePolicyCallback(Ctx context, int index) override
- {
- if (index < 0 || index >= defaultPolicies->Length)
- {
- throw gcnew InvalidDataException("Internal error: Index of interop bag is out of range.");
- }
- return defaultPolicies[index]->ChooseAction(context);
- }
-
- NativeMultiWorldTesting::BootstrapExplorer<NativeContext>* Get()
- {
- return m_explorer;
- }
-
- private:
- cli::array<IPolicy<Ctx>^>^ defaultPolicies;
- NativeMultiWorldTesting::BootstrapExplorer<NativeContext>* m_explorer;
- };
-
- /// <summary>
- /// The top level MwtExplorer class. Using this makes sure that the
- /// right bits are recorded and good random actions are chosen.
- /// </summary>
- /// <typeparam name="Ctx">The Context type.</typeparam>
- generic <class Ctx>
- public ref class MwtExplorer : public RecorderCallback<Ctx>
- {
- public:
- /// <summary>
- /// Constructor.
- /// </summary>
- /// <param name="appId">This should be unique to each experiment to avoid correlation bugs.</param>
- /// <param name="recorder">A user-specified class for recording the appropriate bits for use in evaluation and learning.</param>
- MwtExplorer(String^ appId, IRecorder<Ctx>^ recorder)
- {
- this->appId = appId;
- this->recorder = recorder;
- }
-
- /// <summary>
- /// Choose_Action should be drop-in replacement for any existing policy function.
- /// </summary>
- /// <param name="explorer">An existing exploration algorithm (one of the above) which uses the default policy as a callback.</param>
- /// <param name="unique_key">A unique identifier for the experimental unit. This could be a user id, a session id, etc...</param>
- /// <param name="context">The context upon which a decision is made. See SimpleContext above for an example.</param>
- /// <returns>An unsigned 32-bit integer representing the 1-based chosen action.</returns>
- UInt32 ChooseAction(IExplorer<Ctx>^ explorer, String^ unique_key, Ctx context)
- {
- String^ salt = this->appId;
- NativeMultiWorldTesting::MwtExplorer<NativeContext> mwt(marshal_as<std::string>(salt), *GetNativeRecorder());
-
- GCHandle selfHandle = GCHandle::Alloc(this);
- IntPtr selfPtr = (IntPtr)selfHandle;
-
- GCHandle contextHandle = GCHandle::Alloc(context);
- IntPtr contextPtr = (IntPtr)contextHandle;
-
- GCHandle explorerHandle = GCHandle::Alloc(explorer);
- IntPtr explorerPtr = (IntPtr)explorerHandle;
-
- NativeContext native_context(selfPtr.ToPointer(), explorerPtr.ToPointer(), contextPtr.ToPointer());
- u32 action = 0;
- if (explorer->GetType() == EpsilonGreedyExplorer<Ctx>::typeid)
- {
- EpsilonGreedyExplorer<Ctx>^ epsilonGreedyExplorer = (EpsilonGreedyExplorer<Ctx>^)explorer;
- action = mwt.Choose_Action(*epsilonGreedyExplorer->Get(), marshal_as<std::string>(unique_key), native_context);
- }
- else if (explorer->GetType() == TauFirstExplorer<Ctx>::typeid)
- {
- TauFirstExplorer<Ctx>^ tauFirstExplorer = (TauFirstExplorer<Ctx>^)explorer;
- action = mwt.Choose_Action(*tauFirstExplorer->Get(), marshal_as<std::string>(unique_key), native_context);
- }
- else if (explorer->GetType() == SoftmaxExplorer<Ctx>::typeid)
- {
- SoftmaxExplorer<Ctx>^ softmaxExplorer = (SoftmaxExplorer<Ctx>^)explorer;
- action = mwt.Choose_Action(*softmaxExplorer->Get(), marshal_as<std::string>(unique_key), native_context);
- }
- else if (explorer->GetType() == GenericExplorer<Ctx>::typeid)
- {
- GenericExplorer<Ctx>^ genericExplorer = (GenericExplorer<Ctx>^)explorer;
- action = mwt.Choose_Action(*genericExplorer->Get(), marshal_as<std::string>(unique_key), native_context);
- }
- else if (explorer->GetType() == BootstrapExplorer<Ctx>::typeid)
- {
- BootstrapExplorer<Ctx>^ bootstrapExplorer = (BootstrapExplorer<Ctx>^)explorer;
- action = mwt.Choose_Action(*bootstrapExplorer->Get(), marshal_as<std::string>(unique_key), native_context);
- }
-
- explorerHandle.Free();
- contextHandle.Free();
- selfHandle.Free();
-
- return action;
- }
-
- internal:
- virtual void InvokeRecorderCallback(Ctx context, UInt32 action, float probability, String^ unique_key) override
- {
- recorder->Record(context, action, probability, unique_key);
- }
-
- private:
- IRecorder<Ctx>^ recorder;
- String^ appId;
- };
-
- /// <summary>
- /// Represents a feature in a sparse array.
- /// </summary>
- [StructLayout(LayoutKind::Sequential)]
- public value struct Feature
- {
- float Value;
- UInt32 Id;
- };
-
- /// <summary>
- /// A sample recorder class that converts the exploration tuple into string format.
- /// </summary>
- /// <typeparam name="Ctx">The Context type.</typeparam>
- generic <class Ctx> where Ctx : IStringContext
- public ref class StringRecorder : public IRecorder<Ctx>, public ToStringCallback<Ctx>
- {
- public:
- StringRecorder()
- {
- m_string_recorder = new NativeMultiWorldTesting::StringRecorder<NativeStringContext>();
- }
-
- ~StringRecorder()
- {
- delete m_string_recorder;
- }
-
- virtual void Record(Ctx context, UInt32 action, float probability, String^ uniqueKey)
- {
- GCHandle contextHandle = GCHandle::Alloc(context);
- IntPtr contextPtr = (IntPtr)contextHandle;
-
- NativeStringContext native_context(contextPtr.ToPointer(), GetCallback());
- m_string_recorder->Record(native_context, (u32)action, probability, marshal_as<string>(uniqueKey));
- }
-
- /// <summary>
- /// Gets the content of the recording so far as a string and clears internal content.
- /// </summary>
- /// <returns>
- /// A string with recording content.
- /// </returns>
- String^ GetRecording()
- {
- // Workaround for C++-CLI bug which does not allow default value for parameter
- return GetRecording(true);
- }
-
- /// <summary>
- /// Gets the content of the recording so far as a string and optionally clears internal content.
- /// </summary>
- /// <param name="flush">A boolean value indicating whether to clear the internal content.</param>
- /// <returns>
- /// A string with recording content.
- /// </returns>
- String^ GetRecording(bool flush)
- {
- return gcnew String(m_string_recorder->Get_Recording(flush).c_str());
- }
-
- private:
- NativeMultiWorldTesting::StringRecorder<NativeStringContext>* m_string_recorder;
- };
-
- /// <summary>
- /// A sample context class that stores a vector of Features.
- /// </summary>
- public ref class SimpleContext : public IStringContext
- {
- public:
- SimpleContext(cli::array<Feature>^ features)
- {
- Features = features;
-
- // TODO: add another constructor overload for native SimpleContext to avoid copying feature values
- m_features = new vector<NativeMultiWorldTesting::Feature>();
- for (int i = 0; i < features->Length; i++)
- {
- m_features->push_back({ features[i].Value, features[i].Id });
- }
-
- m_native_context = new NativeMultiWorldTesting::SimpleContext(*m_features);
- }
-
- String^ ToString() override
- {
- return gcnew String(m_native_context->To_String().c_str());
- }
-
- ~SimpleContext()
- {
- delete m_native_context;
- }
-
- public:
- cli::array<Feature>^ GetFeatures() { return Features; }
-
- internal:
- cli::array<Feature>^ Features;
-
- private:
- vector<NativeMultiWorldTesting::Feature>* m_features;
- NativeMultiWorldTesting::SimpleContext* m_native_context;
- };
-}
-
+#pragma once +#include "explore_interop.h" + +/*! +* \addtogroup MultiWorldTestingCsharp +* @{ +*/ +namespace MultiWorldTesting { + + /// <summary> + /// The epsilon greedy exploration class. + /// </summary> + /// <remarks> + /// This is a good choice if you have no idea which actions should be preferred. + /// Epsilon greedy is also computationally cheap. + /// </remarks> + /// <typeparam name="Ctx">The Context type.</typeparam> + generic <class Ctx> + public ref class EpsilonGreedyExplorer : public IExplorer<Ctx>, public PolicyCallback<Ctx> + { + public: + /// <summary> + /// The constructor is the only public member, because this should be used with the MwtExplorer. + /// </summary> + /// <param name="defaultPolicy">A default function which outputs an action given a context.</param> + /// <param name="epsilon">The probability of a random exploration.</param> + /// <param name="numActions">The number of actions to randomize over.</param> + EpsilonGreedyExplorer(IPolicy<Ctx>^ defaultPolicy, float epsilon, UInt32 numActions) + { + this->defaultPolicy = defaultPolicy; + m_explorer = new NativeMultiWorldTesting::EpsilonGreedyExplorer<NativeContext>(*GetNativePolicy(), epsilon, (u32)numActions); + } + + ~EpsilonGreedyExplorer() + { + delete m_explorer; + } + + internal: + virtual UInt32 InvokePolicyCallback(Ctx context, int index) override + { + return defaultPolicy->ChooseAction(context); + } + + NativeMultiWorldTesting::EpsilonGreedyExplorer<NativeContext>* Get() + { + return m_explorer; + } + + private: + IPolicy<Ctx>^ defaultPolicy; + NativeMultiWorldTesting::EpsilonGreedyExplorer<NativeContext>* m_explorer; + }; + + /// <summary> + /// The tau-first exploration class. + /// </summary> + /// <remarks> + /// The tau-first explorer collects precisely tau uniform random + /// exploration events, and then uses the default policy. + /// </remarks> + /// <typeparam name="Ctx">The Context type.</typeparam> + generic <class Ctx> + public ref class TauFirstExplorer : public IExplorer<Ctx>, public PolicyCallback<Ctx> + { + public: + /// <summary> + /// The constructor is the only public member, because this should be used with the MwtExplorer. + /// </summary> + /// <param name="defaultPolicy">A default policy after randomization finishes.</param> + /// <param name="tau">The number of events to be uniform over.</param> + /// <param name="numActions">The number of actions to randomize over.</param> + TauFirstExplorer(IPolicy<Ctx>^ defaultPolicy, UInt32 tau, UInt32 numActions) + { + this->defaultPolicy = defaultPolicy; + m_explorer = new NativeMultiWorldTesting::TauFirstExplorer<NativeContext>(*GetNativePolicy(), tau, (u32)numActions); + } + + ~TauFirstExplorer() + { + delete m_explorer; + } + + internal: + virtual UInt32 InvokePolicyCallback(Ctx context, int index) override + { + return defaultPolicy->ChooseAction(context); + } + + NativeMultiWorldTesting::TauFirstExplorer<NativeContext>* Get() + { + return m_explorer; + } + + private: + IPolicy<Ctx>^ defaultPolicy; + NativeMultiWorldTesting::TauFirstExplorer<NativeContext>* m_explorer; + }; + + /// <summary> + /// The epsilon greedy exploration class. + /// </summary> + /// <remarks> + /// In some cases, different actions have a different scores, and you + /// would prefer to choose actions with large scores. Softmax allows + /// you to do that. + /// </remarks> + /// <typeparam name="Ctx">The Context type.</typeparam> + generic <class Ctx> + public ref class SoftmaxExplorer : public IExplorer<Ctx>, public ScorerCallback<Ctx> + { + public: + /// <summary> + /// The constructor is the only public member, because this should be used with the MwtExplorer. + /// </summary> + /// <param name="defaultScorer">A function which outputs a score for each action.</param> + /// <param name="lambda">lambda = 0 implies uniform distribution. Large lambda is equivalent to a max.</param> + /// <param name="numActions">The number of actions to randomize over.</param> + SoftmaxExplorer(IScorer<Ctx>^ defaultScorer, float lambda, UInt32 numActions) + { + this->defaultScorer = defaultScorer; + m_explorer = new NativeMultiWorldTesting::SoftmaxExplorer<NativeContext>(*GetNativeScorer(), lambda, (u32)numActions); + } + + ~SoftmaxExplorer() + { + delete m_explorer; + } + + internal: + virtual List<float>^ InvokeScorerCallback(Ctx context) override + { + return defaultScorer->ScoreActions(context); + } + + NativeMultiWorldTesting::SoftmaxExplorer<NativeContext>* Get() + { + return m_explorer; + } + + private: + IScorer<Ctx>^ defaultScorer; + NativeMultiWorldTesting::SoftmaxExplorer<NativeContext>* m_explorer; + }; + + /// <summary> + /// The generic exploration class. + /// </summary> + /// <remarks> + /// GenericExplorer provides complete flexibility. You can create any + /// distribution over actions desired, and it will draw from that. + /// </remarks> + /// <typeparam name="Ctx">The Context type.</typeparam> + generic <class Ctx> + public ref class GenericExplorer : public IExplorer<Ctx>, public ScorerCallback<Ctx> + { + public: + /// <summary> + /// The constructor is the only public member, because this should be used with the MwtExplorer. + /// </summary> + /// <param name="defaultScorer">A function which outputs the probability of each action.</param> + /// <param name="numActions">The number of actions to randomize over.</param> + GenericExplorer(IScorer<Ctx>^ defaultScorer, UInt32 numActions) + { + this->defaultScorer = defaultScorer; + m_explorer = new NativeMultiWorldTesting::GenericExplorer<NativeContext>(*GetNativeScorer(), (u32)numActions); + } + + ~GenericExplorer() + { + delete m_explorer; + } + + internal: + virtual List<float>^ InvokeScorerCallback(Ctx context) override + { + return defaultScorer->ScoreActions(context); + } + + NativeMultiWorldTesting::GenericExplorer<NativeContext>* Get() + { + return m_explorer; + } + + private: + IScorer<Ctx>^ defaultScorer; + NativeMultiWorldTesting::GenericExplorer<NativeContext>* m_explorer; + }; + + /// <summary> + /// The bootstrap exploration class. + /// </summary> + /// <remarks> + /// The Bootstrap explorer randomizes over the actions chosen by a set of + /// default policies. This performs well statistically but can be + /// computationally expensive. + /// </remarks> + /// <typeparam name="Ctx">The Context type.</typeparam> + generic <class Ctx> + public ref class BootstrapExplorer : public IExplorer<Ctx>, public PolicyCallback<Ctx> + { + public: + /// <summary> + /// The constructor is the only public member, because this should be used with the MwtExplorer. + /// </summary> + /// <param name="defaultPolicies">A set of default policies to be uniform random over.</param> + /// <param name="numActions">The number of actions to randomize over.</param> + BootstrapExplorer(cli::array<IPolicy<Ctx>^>^ defaultPolicies, UInt32 numActions) + { + this->defaultPolicies = defaultPolicies; + if (this->defaultPolicies == nullptr) + { + throw gcnew ArgumentNullException("The specified array of default policy functions cannot be null."); + } + + m_explorer = new NativeMultiWorldTesting::BootstrapExplorer<NativeContext>(*GetNativePolicies((u32)defaultPolicies->Length), (u32)numActions); + } + + ~BootstrapExplorer() + { + delete m_explorer; + } + + internal: + virtual UInt32 InvokePolicyCallback(Ctx context, int index) override + { + if (index < 0 || index >= defaultPolicies->Length) + { + throw gcnew InvalidDataException("Internal error: Index of interop bag is out of range."); + } + return defaultPolicies[index]->ChooseAction(context); + } + + NativeMultiWorldTesting::BootstrapExplorer<NativeContext>* Get() + { + return m_explorer; + } + + private: + cli::array<IPolicy<Ctx>^>^ defaultPolicies; + NativeMultiWorldTesting::BootstrapExplorer<NativeContext>* m_explorer; + }; + + /// <summary> + /// The top level MwtExplorer class. Using this makes sure that the + /// right bits are recorded and good random actions are chosen. + /// </summary> + /// <typeparam name="Ctx">The Context type.</typeparam> + generic <class Ctx> + public ref class MwtExplorer : public RecorderCallback<Ctx> + { + public: + /// <summary> + /// Constructor. + /// </summary> + /// <param name="appId">This should be unique to each experiment to avoid correlation bugs.</param> + /// <param name="recorder">A user-specified class for recording the appropriate bits for use in evaluation and learning.</param> + MwtExplorer(String^ appId, IRecorder<Ctx>^ recorder) + { + this->appId = appId; + this->recorder = recorder; + } + + /// <summary> + /// Choose_Action should be drop-in replacement for any existing policy function. + /// </summary> + /// <param name="explorer">An existing exploration algorithm (one of the above) which uses the default policy as a callback.</param> + /// <param name="unique_key">A unique identifier for the experimental unit. This could be a user id, a session id, etc...</param> + /// <param name="context">The context upon which a decision is made. See SimpleContext above for an example.</param> + /// <returns>An unsigned 32-bit integer representing the 1-based chosen action.</returns> + UInt32 ChooseAction(IExplorer<Ctx>^ explorer, String^ unique_key, Ctx context) + { + String^ salt = this->appId; + NativeMultiWorldTesting::MwtExplorer<NativeContext> mwt(marshal_as<std::string>(salt), *GetNativeRecorder()); + + GCHandle selfHandle = GCHandle::Alloc(this); + IntPtr selfPtr = (IntPtr)selfHandle; + + GCHandle contextHandle = GCHandle::Alloc(context); + IntPtr contextPtr = (IntPtr)contextHandle; + + GCHandle explorerHandle = GCHandle::Alloc(explorer); + IntPtr explorerPtr = (IntPtr)explorerHandle; + + NativeContext native_context(selfPtr.ToPointer(), explorerPtr.ToPointer(), contextPtr.ToPointer()); + u32 action = 0; + if (explorer->GetType() == EpsilonGreedyExplorer<Ctx>::typeid) + { + EpsilonGreedyExplorer<Ctx>^ epsilonGreedyExplorer = (EpsilonGreedyExplorer<Ctx>^)explorer; + action = mwt.Choose_Action(*epsilonGreedyExplorer->Get(), marshal_as<std::string>(unique_key), native_context); + } + else if (explorer->GetType() == TauFirstExplorer<Ctx>::typeid) + { + TauFirstExplorer<Ctx>^ tauFirstExplorer = (TauFirstExplorer<Ctx>^)explorer; + action = mwt.Choose_Action(*tauFirstExplorer->Get(), marshal_as<std::string>(unique_key), native_context); + } + else if (explorer->GetType() == SoftmaxExplorer<Ctx>::typeid) + { + SoftmaxExplorer<Ctx>^ softmaxExplorer = (SoftmaxExplorer<Ctx>^)explorer; + action = mwt.Choose_Action(*softmaxExplorer->Get(), marshal_as<std::string>(unique_key), native_context); + } + else if (explorer->GetType() == GenericExplorer<Ctx>::typeid) + { + GenericExplorer<Ctx>^ genericExplorer = (GenericExplorer<Ctx>^)explorer; + action = mwt.Choose_Action(*genericExplorer->Get(), marshal_as<std::string>(unique_key), native_context); + } + else if (explorer->GetType() == BootstrapExplorer<Ctx>::typeid) + { + BootstrapExplorer<Ctx>^ bootstrapExplorer = (BootstrapExplorer<Ctx>^)explorer; + action = mwt.Choose_Action(*bootstrapExplorer->Get(), marshal_as<std::string>(unique_key), native_context); + } + + explorerHandle.Free(); + contextHandle.Free(); + selfHandle.Free(); + + return action; + } + + internal: + virtual void InvokeRecorderCallback(Ctx context, UInt32 action, float probability, String^ unique_key) override + { + recorder->Record(context, action, probability, unique_key); + } + + private: + IRecorder<Ctx>^ recorder; + String^ appId; + }; + + /// <summary> + /// Represents a feature in a sparse array. + /// </summary> + [StructLayout(LayoutKind::Sequential)] + public value struct Feature + { + float Value; + UInt32 Id; + }; + + /// <summary> + /// A sample recorder class that converts the exploration tuple into string format. + /// </summary> + /// <typeparam name="Ctx">The Context type.</typeparam> + generic <class Ctx> where Ctx : IStringContext + public ref class StringRecorder : public IRecorder<Ctx>, public ToStringCallback<Ctx> + { + public: + StringRecorder() + { + m_string_recorder = new NativeMultiWorldTesting::StringRecorder<NativeStringContext>(); + } + + ~StringRecorder() + { + delete m_string_recorder; + } + + virtual void Record(Ctx context, UInt32 action, float probability, String^ uniqueKey) + { + GCHandle contextHandle = GCHandle::Alloc(context); + IntPtr contextPtr = (IntPtr)contextHandle; + + NativeStringContext native_context(contextPtr.ToPointer(), GetCallback()); + m_string_recorder->Record(native_context, (u32)action, probability, marshal_as<string>(uniqueKey)); + } + + /// <summary> + /// Gets the content of the recording so far as a string and clears internal content. + /// </summary> + /// <returns> + /// A string with recording content. + /// </returns> + String^ GetRecording() + { + // Workaround for C++-CLI bug which does not allow default value for parameter + return GetRecording(true); + } + + /// <summary> + /// Gets the content of the recording so far as a string and optionally clears internal content. + /// </summary> + /// <param name="flush">A boolean value indicating whether to clear the internal content.</param> + /// <returns> + /// A string with recording content. + /// </returns> + String^ GetRecording(bool flush) + { + return gcnew String(m_string_recorder->Get_Recording(flush).c_str()); + } + + private: + NativeMultiWorldTesting::StringRecorder<NativeStringContext>* m_string_recorder; + }; + + /// <summary> + /// A sample context class that stores a vector of Features. + /// </summary> + public ref class SimpleContext : public IStringContext + { + public: + SimpleContext(cli::array<Feature>^ features) + { + Features = features; + + // TODO: add another constructor overload for native SimpleContext to avoid copying feature values + m_features = new vector<NativeMultiWorldTesting::Feature>(); + for (int i = 0; i < features->Length; i++) + { + m_features->push_back({ features[i].Value, features[i].Id }); + } + + m_native_context = new NativeMultiWorldTesting::SimpleContext(*m_features); + } + + String^ ToString() override + { + return gcnew String(m_native_context->To_String().c_str()); + } + + ~SimpleContext() + { + delete m_native_context; + } + + public: + cli::array<Feature>^ GetFeatures() { return Features; } + + internal: + cli::array<Feature>^ Features; + + private: + vector<NativeMultiWorldTesting::Feature>* m_features; + NativeMultiWorldTesting::SimpleContext* m_native_context; + }; +} + /*! @} End of Doxygen Groups*/
\ No newline at end of file diff --git a/explore/clr/explore_interface.h b/explore/clr/explore_interface.h index 3212b4d8..dfa17697 100644 --- a/explore/clr/explore_interface.h +++ b/explore/clr/explore_interface.h @@ -1,88 +1,88 @@ -#pragma once
-
-using namespace System;
-using namespace System::Collections::Generic;
-
-/** \defgroup MultiWorldTestingCsharp
-\brief C# implementation, for sample usage see: https://github.com/sidsen/vowpal_wabbit/blob/v0/cs_test/ExploreOnlySample.cs
-*/
-
-/*!
-* \addtogroup MultiWorldTestingCsharp
-* @{
-*/
-
-//! Interface for C# version of Multiworld Testing library.
-//! For sample usage see: https://github.com/sidsen/vowpal_wabbit/blob/v0/cs_test/ExploreOnlySample.cs
-namespace MultiWorldTesting {
-
-/// <summary>
-/// Represents a recorder that exposes a method to record exploration data based on generic contexts.
-/// </summary>
-/// <typeparam name="Ctx">The Context type.</typeparam>
-/// <remarks>
-/// Exploration data is specified as a set of tuples <context, action, probability, key> as described below. An
-/// application passes an IRecorder object to the @MwtExplorer constructor. See
-/// @StringRecorder for a sample IRecorder object.
-/// </remarks>
-generic <class Ctx>
-public interface class IRecorder
-{
-public:
- /// <summary>
- /// Records the exploration data associated with a given decision.
- /// </summary>
- /// <param name="context">A user-defined context for the decision.</param>
- /// <param name="action">Chosen by an exploration algorithm given context.</param>
- /// <param name="probability">The probability of the chosen action given context.</param>
- /// <param name="uniqueKey">A user-defined identifer for the decision.</param>
- virtual void Record(Ctx context, UInt32 action, float probability, String^ uniqueKey) = 0;
-};
-
-/// <summary>
-/// Exposes a method for choosing an action given a generic context. IPolicy objects are
-/// passed to (and invoked by) exploration algorithms to specify the default policy behavior.
-/// </summary>
-/// <typeparam name="Ctx">The Context type.</typeparam>
-generic <class Ctx>
-public interface class IPolicy
-{
-public:
- /// <summary>
- /// Determines the action to take for a given context.
- /// </summary>
- /// <param name="context">A user-defined context for the decision.</param>
- /// <returns>Index of the action to take (1-based)</returns>
- virtual UInt32 ChooseAction(Ctx context) = 0;
-};
-
-/// <summary>
-/// Exposes a method for specifying a score (weight) for each action given a generic context.
-/// </summary>
-/// <typeparam name="Ctx">The Context type.</typeparam>
-generic <class Ctx>
-public interface class IScorer
-{
-public:
- /// <summary>
- /// Determines the score of each action for a given context.
- /// </summary>
- /// <param name="context">A user-defined context for the decision.</param>
- /// <returns>Vector of scores indexed by action (1-based).</returns>
- virtual List<float>^ ScoreActions(Ctx context) = 0;
-};
-
-generic <class Ctx>
-public interface class IExplorer
-{
-};
-
-public interface class IStringContext
-{
-public:
- virtual String^ ToString() = 0;
-};
-
-}
-
+#pragma once + +using namespace System; +using namespace System::Collections::Generic; + +/** \defgroup MultiWorldTestingCsharp +\brief C# implementation, for sample usage see: https://github.com/sidsen/vowpal_wabbit/blob/v0/cs_test/ExploreOnlySample.cs +*/ + +/*! +* \addtogroup MultiWorldTestingCsharp +* @{ +*/ + +//! Interface for C# version of Multiworld Testing library. +//! For sample usage see: https://github.com/sidsen/vowpal_wabbit/blob/v0/cs_test/ExploreOnlySample.cs +namespace MultiWorldTesting { + +/// <summary> +/// Represents a recorder that exposes a method to record exploration data based on generic contexts. +/// </summary> +/// <typeparam name="Ctx">The Context type.</typeparam> +/// <remarks> +/// Exploration data is specified as a set of tuples <context, action, probability, key> as described below. An +/// application passes an IRecorder object to the @MwtExplorer constructor. See +/// @StringRecorder for a sample IRecorder object. +/// </remarks> +generic <class Ctx> +public interface class IRecorder +{ +public: + /// <summary> + /// Records the exploration data associated with a given decision. + /// </summary> + /// <param name="context">A user-defined context for the decision.</param> + /// <param name="action">Chosen by an exploration algorithm given context.</param> + /// <param name="probability">The probability of the chosen action given context.</param> + /// <param name="uniqueKey">A user-defined identifer for the decision.</param> + virtual void Record(Ctx context, UInt32 action, float probability, String^ uniqueKey) = 0; +}; + +/// <summary> +/// Exposes a method for choosing an action given a generic context. IPolicy objects are +/// passed to (and invoked by) exploration algorithms to specify the default policy behavior. +/// </summary> +/// <typeparam name="Ctx">The Context type.</typeparam> +generic <class Ctx> +public interface class IPolicy +{ +public: + /// <summary> + /// Determines the action to take for a given context. + /// </summary> + /// <param name="context">A user-defined context for the decision.</param> + /// <returns>Index of the action to take (1-based)</returns> + virtual UInt32 ChooseAction(Ctx context) = 0; +}; + +/// <summary> +/// Exposes a method for specifying a score (weight) for each action given a generic context. +/// </summary> +/// <typeparam name="Ctx">The Context type.</typeparam> +generic <class Ctx> +public interface class IScorer +{ +public: + /// <summary> + /// Determines the score of each action for a given context. + /// </summary> + /// <param name="context">A user-defined context for the decision.</param> + /// <returns>Vector of scores indexed by action (1-based).</returns> + virtual List<float>^ ScoreActions(Ctx context) = 0; +}; + +generic <class Ctx> +public interface class IExplorer +{ +}; + +public interface class IStringContext +{ +public: + virtual String^ ToString() = 0; +}; + +} + /*! @} End of Doxygen Groups*/
\ No newline at end of file diff --git a/explore/clr/explore_interop.h b/explore/clr/explore_interop.h index 99466945..a6cb9327 100644 --- a/explore/clr/explore_interop.h +++ b/explore/clr/explore_interop.h @@ -1,358 +1,358 @@ -#pragma once
-
-#define MANAGED_CODE
-
-#include "explore_interface.h"
-#include "MWTExplorer.h"
-
-#include <msclr\marshal_cppstd.h>
-
-using namespace System;
-using namespace System::Collections::Generic;
-using namespace System::IO;
-using namespace System::Runtime::InteropServices;
-using namespace System::Xml::Serialization;
-using namespace msclr::interop;
-
-namespace MultiWorldTesting {
-
-// Policy callback
-private delegate UInt32 ClrPolicyCallback(IntPtr explorerPtr, IntPtr contextPtr, int index);
-typedef u32 Native_Policy_Callback(void* explorer, void* context, int index);
-
-// Scorer callback
-private delegate void ClrScorerCallback(IntPtr explorerPtr, IntPtr contextPtr, IntPtr scores, IntPtr size);
-typedef void Native_Scorer_Callback(void* explorer, void* context, float* scores[], u32* size);
-
-// Recorder callback
-private delegate void ClrRecorderCallback(IntPtr mwtPtr, IntPtr contextPtr, UInt32 action, float probability, IntPtr uniqueKey);
-typedef void Native_Recorder_Callback(void* mwt, void* context, u32 action, float probability, void* unique_key);
-
-// ToString callback
-private delegate void ClrToStringCallback(IntPtr contextPtr, IntPtr stringValue);
-typedef void Native_To_String_Callback(void* explorer, void* string_value);
-
-// NativeContext travels through interop space and contains instances of Mwt, Explorer, Context
-// used for triggering callback for Policy, Scorer, Recorder
-class NativeContext
-{
-public:
- NativeContext(void* clr_mwt, void* clr_explorer, void* clr_context)
- {
- m_clr_mwt = clr_mwt;
- m_clr_explorer = clr_explorer;
- m_clr_context = clr_context;
- }
-
- void* Get_Clr_Mwt()
- {
- return m_clr_mwt;
- }
-
- void* Get_Clr_Context()
- {
- return m_clr_context;
- }
-
- void* Get_Clr_Explorer()
- {
- return m_clr_explorer;
- }
-
-private:
- void* m_clr_mwt;
- void* m_clr_context;
- void* m_clr_explorer;
-};
-
-class NativeStringContext
-{
-public:
- NativeStringContext(void* clr_context, Native_To_String_Callback* func)
- {
- m_clr_context = clr_context;
- m_func = func;
- }
-
- string To_String()
- {
- string value;
- m_func(m_clr_context, &value);
- return value;
- }
-private:
- void* m_clr_context;
- Native_To_String_Callback* m_func;
-};
-
-// NativeRecorder listens to callback event and reroute it to the managed Recorder instance
-class NativeRecorder : public NativeMultiWorldTesting::IRecorder<NativeContext>
-{
-public:
- NativeRecorder(Native_Recorder_Callback* native_func)
- {
- m_func = native_func;
- }
-
- void Record(NativeContext& context, u32 action, float probability, string unique_key)
- {
- GCHandle uniqueKeyHandle = GCHandle::Alloc(gcnew String(unique_key.c_str()));
- IntPtr uniqueKeyPtr = (IntPtr)uniqueKeyHandle;
-
- m_func(context.Get_Clr_Mwt(), context.Get_Clr_Context(), action, probability, uniqueKeyPtr.ToPointer());
-
- uniqueKeyHandle.Free();
- }
-private:
- Native_Recorder_Callback* m_func;
-};
-
-// NativePolicy listens to callback event and reroute it to the managed Policy instance
-class NativePolicy : public NativeMultiWorldTesting::IPolicy<NativeContext>
-{
-public:
- NativePolicy(Native_Policy_Callback* func, int index = -1)
- {
- m_func = func;
- m_index = index;
- }
-
- u32 Choose_Action(NativeContext& context)
- {
- return m_func(context.Get_Clr_Explorer(), context.Get_Clr_Context(), m_index);
- }
-
-private:
- Native_Policy_Callback* m_func;
- int m_index;
-};
-
-class NativeScorer : public NativeMultiWorldTesting::IScorer<NativeContext>
-{
-public:
- NativeScorer(Native_Scorer_Callback* func)
- {
- m_func = func;
- }
-
- vector<float> Score_Actions(NativeContext& context)
- {
- float* scores = nullptr;
- u32 num_scores = 0;
- m_func(context.Get_Clr_Explorer(), context.Get_Clr_Context(), &scores, &num_scores);
-
- // It's ok if scores is null, vector will be empty
- vector<float> scores_vector(scores, scores + num_scores);
- delete[] scores;
-
- return scores_vector;
- }
-private:
- Native_Scorer_Callback* m_func;
-};
-
-// Triggers callback to the Policy instance to choose an action
-generic <class Ctx>
-public ref class PolicyCallback abstract
-{
-internal:
- virtual UInt32 InvokePolicyCallback(Ctx context, int index) = 0;
-
- PolicyCallback()
- {
- policyCallback = gcnew ClrPolicyCallback(&PolicyCallback<Ctx>::InteropInvoke);
- IntPtr policyCallbackPtr = Marshal::GetFunctionPointerForDelegate(policyCallback);
- m_callback = static_cast<Native_Policy_Callback*>(policyCallbackPtr.ToPointer());
- m_native_policy = nullptr;
- m_native_policies = nullptr;
- }
-
- ~PolicyCallback()
- {
- delete m_native_policy;
- delete m_native_policies;
- }
-
- NativePolicy* GetNativePolicy()
- {
- if (m_native_policy == nullptr)
- {
- m_native_policy = new NativePolicy(m_callback);
- }
- return m_native_policy;
- }
-
- vector<unique_ptr<NativeMultiWorldTesting::IPolicy<NativeContext>>>* GetNativePolicies(int count)
- {
- if (m_native_policies == nullptr)
- {
- m_native_policies = new vector<unique_ptr<NativeMultiWorldTesting::IPolicy<NativeContext>>>();
- for (int i = 0; i < count; i++)
- {
- m_native_policies->push_back(unique_ptr<NativeMultiWorldTesting::IPolicy<NativeContext>>(new NativePolicy(m_callback, i)));
- }
- }
-
- return m_native_policies;
- }
-
- static UInt32 InteropInvoke(IntPtr callbackPtr, IntPtr contextPtr, int index)
- {
- GCHandle callbackHandle = (GCHandle)callbackPtr;
- PolicyCallback<Ctx>^ callback = (PolicyCallback<Ctx>^)callbackHandle.Target;
-
- GCHandle contextHandle = (GCHandle)contextPtr;
- Ctx context = (Ctx)contextHandle.Target;
-
- return callback->InvokePolicyCallback(context, index);
- }
-
-private:
- ClrPolicyCallback^ policyCallback;
-
-private:
- NativePolicy* m_native_policy;
- vector<unique_ptr<NativeMultiWorldTesting::IPolicy<NativeContext>>>* m_native_policies;
- Native_Policy_Callback* m_callback;
-};
-
-// Triggers callback to the Recorder instance to record interaction data
-generic <class Ctx>
-public ref class RecorderCallback abstract
-{
-internal:
- virtual void InvokeRecorderCallback(Ctx context, UInt32 action, float probability, String^ unique_key) = 0;
-
- RecorderCallback()
- {
- recorderCallback = gcnew ClrRecorderCallback(&RecorderCallback<Ctx>::InteropInvoke);
- IntPtr recorderCallbackPtr = Marshal::GetFunctionPointerForDelegate(recorderCallback);
- Native_Recorder_Callback* callback = static_cast<Native_Recorder_Callback*>(recorderCallbackPtr.ToPointer());
- m_native_recorder = new NativeRecorder(callback);
- }
-
- ~RecorderCallback()
- {
- delete m_native_recorder;
- }
-
- NativeRecorder* GetNativeRecorder()
- {
- return m_native_recorder;
- }
-
- static void InteropInvoke(IntPtr mwtPtr, IntPtr contextPtr, UInt32 action, float probability, IntPtr uniqueKeyPtr)
- {
- GCHandle mwtHandle = (GCHandle)mwtPtr;
- RecorderCallback<Ctx>^ callback = (RecorderCallback<Ctx>^)mwtHandle.Target;
-
- GCHandle contextHandle = (GCHandle)contextPtr;
- Ctx context = (Ctx)contextHandle.Target;
-
- GCHandle uniqueKeyHandle = (GCHandle)uniqueKeyPtr;
- String^ uniqueKey = (String^)uniqueKeyHandle.Target;
-
- callback->InvokeRecorderCallback(context, action, probability, uniqueKey);
- }
-
-private:
- ClrRecorderCallback^ recorderCallback;
-
-private:
- NativeRecorder* m_native_recorder;
-};
-
-// Triggers callback to the Recorder instance to record interaction data
-generic <class Ctx>
-public ref class ScorerCallback abstract
-{
-internal:
- virtual List<float>^ InvokeScorerCallback(Ctx context) = 0;
-
- ScorerCallback()
- {
- scorerCallback = gcnew ClrScorerCallback(&ScorerCallback<Ctx>::InteropInvoke);
- IntPtr scorerCallbackPtr = Marshal::GetFunctionPointerForDelegate(scorerCallback);
- Native_Scorer_Callback* callback = static_cast<Native_Scorer_Callback*>(scorerCallbackPtr.ToPointer());
- m_native_scorer = new NativeScorer(callback);
- }
-
- ~ScorerCallback()
- {
- delete m_native_scorer;
- }
-
- NativeScorer* GetNativeScorer()
- {
- return m_native_scorer;
- }
-
- static void InteropInvoke(IntPtr callbackPtr, IntPtr contextPtr, IntPtr scoresPtr, IntPtr sizePtr)
- {
- GCHandle callbackHandle = (GCHandle)callbackPtr;
- ScorerCallback<Ctx>^ callback = (ScorerCallback<Ctx>^)callbackHandle.Target;
-
- GCHandle contextHandle = (GCHandle)contextPtr;
- Ctx context = (Ctx)contextHandle.Target;
-
- List<float>^ scoreList = callback->InvokeScorerCallback(context);
-
- if (scoreList == nullptr || scoreList->Count == 0)
- {
- return;
- }
-
- u32* num_scores = (u32*)sizePtr.ToPointer();
- *num_scores = (u32)scoreList->Count;
-
- float* scores = new float[*num_scores];
- for (u32 i = 0; i < *num_scores; i++)
- {
- scores[i] = scoreList[i];
- }
-
- float** native_scores = (float**)scoresPtr.ToPointer();
- *native_scores = scores;
- }
-
-private:
- ClrScorerCallback^ scorerCallback;
-
-private:
- NativeScorer* m_native_scorer;
-};
-
-// Triggers callback to the Context instance to perform ToString() operation
-generic <class Ctx> where Ctx : IStringContext
-public ref class ToStringCallback
-{
-internal:
- ToStringCallback()
- {
- toStringCallback = gcnew ClrToStringCallback(&ToStringCallback<Ctx>::InteropInvoke);
- IntPtr toStringCallbackPtr = Marshal::GetFunctionPointerForDelegate(toStringCallback);
- m_callback = static_cast<Native_To_String_Callback*>(toStringCallbackPtr.ToPointer());
- }
-
- Native_To_String_Callback* GetCallback()
- {
- return m_callback;
- }
-
- static void InteropInvoke(IntPtr contextPtr, IntPtr stringPtr)
- {
- GCHandle contextHandle = (GCHandle)contextPtr;
- Ctx context = (Ctx)contextHandle.Target;
-
- string* out_string = (string*)stringPtr.ToPointer();
- *out_string = marshal_as<string>(context->ToString());
- }
-
-private:
- ClrToStringCallback^ toStringCallback;
-
-private:
- Native_To_String_Callback* m_callback;
-};
-
+#pragma once + +#define MANAGED_CODE + +#include "explore_interface.h" +#include "MWTExplorer.h" + +#include <msclr\marshal_cppstd.h> + +using namespace System; +using namespace System::Collections::Generic; +using namespace System::IO; +using namespace System::Runtime::InteropServices; +using namespace System::Xml::Serialization; +using namespace msclr::interop; + +namespace MultiWorldTesting { + +// Policy callback +private delegate UInt32 ClrPolicyCallback(IntPtr explorerPtr, IntPtr contextPtr, int index); +typedef u32 Native_Policy_Callback(void* explorer, void* context, int index); + +// Scorer callback +private delegate void ClrScorerCallback(IntPtr explorerPtr, IntPtr contextPtr, IntPtr scores, IntPtr size); +typedef void Native_Scorer_Callback(void* explorer, void* context, float* scores[], u32* size); + +// Recorder callback +private delegate void ClrRecorderCallback(IntPtr mwtPtr, IntPtr contextPtr, UInt32 action, float probability, IntPtr uniqueKey); +typedef void Native_Recorder_Callback(void* mwt, void* context, u32 action, float probability, void* unique_key); + +// ToString callback +private delegate void ClrToStringCallback(IntPtr contextPtr, IntPtr stringValue); +typedef void Native_To_String_Callback(void* explorer, void* string_value); + +// NativeContext travels through interop space and contains instances of Mwt, Explorer, Context +// used for triggering callback for Policy, Scorer, Recorder +class NativeContext +{ +public: + NativeContext(void* clr_mwt, void* clr_explorer, void* clr_context) + { + m_clr_mwt = clr_mwt; + m_clr_explorer = clr_explorer; + m_clr_context = clr_context; + } + + void* Get_Clr_Mwt() + { + return m_clr_mwt; + } + + void* Get_Clr_Context() + { + return m_clr_context; + } + + void* Get_Clr_Explorer() + { + return m_clr_explorer; + } + +private: + void* m_clr_mwt; + void* m_clr_context; + void* m_clr_explorer; +}; + +class NativeStringContext +{ +public: + NativeStringContext(void* clr_context, Native_To_String_Callback* func) + { + m_clr_context = clr_context; + m_func = func; + } + + string To_String() + { + string value; + m_func(m_clr_context, &value); + return value; + } +private: + void* m_clr_context; + Native_To_String_Callback* m_func; +}; + +// NativeRecorder listens to callback event and reroute it to the managed Recorder instance +class NativeRecorder : public NativeMultiWorldTesting::IRecorder<NativeContext> +{ +public: + NativeRecorder(Native_Recorder_Callback* native_func) + { + m_func = native_func; + } + + void Record(NativeContext& context, u32 action, float probability, string unique_key) + { + GCHandle uniqueKeyHandle = GCHandle::Alloc(gcnew String(unique_key.c_str())); + IntPtr uniqueKeyPtr = (IntPtr)uniqueKeyHandle; + + m_func(context.Get_Clr_Mwt(), context.Get_Clr_Context(), action, probability, uniqueKeyPtr.ToPointer()); + + uniqueKeyHandle.Free(); + } +private: + Native_Recorder_Callback* m_func; +}; + +// NativePolicy listens to callback event and reroute it to the managed Policy instance +class NativePolicy : public NativeMultiWorldTesting::IPolicy<NativeContext> +{ +public: + NativePolicy(Native_Policy_Callback* func, int index = -1) + { + m_func = func; + m_index = index; + } + + u32 Choose_Action(NativeContext& context) + { + return m_func(context.Get_Clr_Explorer(), context.Get_Clr_Context(), m_index); + } + +private: + Native_Policy_Callback* m_func; + int m_index; +}; + +class NativeScorer : public NativeMultiWorldTesting::IScorer<NativeContext> +{ +public: + NativeScorer(Native_Scorer_Callback* func) + { + m_func = func; + } + + vector<float> Score_Actions(NativeContext& context) + { + float* scores = nullptr; + u32 num_scores = 0; + m_func(context.Get_Clr_Explorer(), context.Get_Clr_Context(), &scores, &num_scores); + + // It's ok if scores is null, vector will be empty + vector<float> scores_vector(scores, scores + num_scores); + delete[] scores; + + return scores_vector; + } +private: + Native_Scorer_Callback* m_func; +}; + +// Triggers callback to the Policy instance to choose an action +generic <class Ctx> +public ref class PolicyCallback abstract +{ +internal: + virtual UInt32 InvokePolicyCallback(Ctx context, int index) = 0; + + PolicyCallback() + { + policyCallback = gcnew ClrPolicyCallback(&PolicyCallback<Ctx>::InteropInvoke); + IntPtr policyCallbackPtr = Marshal::GetFunctionPointerForDelegate(policyCallback); + m_callback = static_cast<Native_Policy_Callback*>(policyCallbackPtr.ToPointer()); + m_native_policy = nullptr; + m_native_policies = nullptr; + } + + ~PolicyCallback() + { + delete m_native_policy; + delete m_native_policies; + } + + NativePolicy* GetNativePolicy() + { + if (m_native_policy == nullptr) + { + m_native_policy = new NativePolicy(m_callback); + } + return m_native_policy; + } + + vector<unique_ptr<NativeMultiWorldTesting::IPolicy<NativeContext>>>* GetNativePolicies(int count) + { + if (m_native_policies == nullptr) + { + m_native_policies = new vector<unique_ptr<NativeMultiWorldTesting::IPolicy<NativeContext>>>(); + for (int i = 0; i < count; i++) + { + m_native_policies->push_back(unique_ptr<NativeMultiWorldTesting::IPolicy<NativeContext>>(new NativePolicy(m_callback, i))); + } + } + + return m_native_policies; + } + + static UInt32 InteropInvoke(IntPtr callbackPtr, IntPtr contextPtr, int index) + { + GCHandle callbackHandle = (GCHandle)callbackPtr; + PolicyCallback<Ctx>^ callback = (PolicyCallback<Ctx>^)callbackHandle.Target; + + GCHandle contextHandle = (GCHandle)contextPtr; + Ctx context = (Ctx)contextHandle.Target; + + return callback->InvokePolicyCallback(context, index); + } + +private: + ClrPolicyCallback^ policyCallback; + +private: + NativePolicy* m_native_policy; + vector<unique_ptr<NativeMultiWorldTesting::IPolicy<NativeContext>>>* m_native_policies; + Native_Policy_Callback* m_callback; +}; + +// Triggers callback to the Recorder instance to record interaction data +generic <class Ctx> +public ref class RecorderCallback abstract +{ +internal: + virtual void InvokeRecorderCallback(Ctx context, UInt32 action, float probability, String^ unique_key) = 0; + + RecorderCallback() + { + recorderCallback = gcnew ClrRecorderCallback(&RecorderCallback<Ctx>::InteropInvoke); + IntPtr recorderCallbackPtr = Marshal::GetFunctionPointerForDelegate(recorderCallback); + Native_Recorder_Callback* callback = static_cast<Native_Recorder_Callback*>(recorderCallbackPtr.ToPointer()); + m_native_recorder = new NativeRecorder(callback); + } + + ~RecorderCallback() + { + delete m_native_recorder; + } + + NativeRecorder* GetNativeRecorder() + { + return m_native_recorder; + } + + static void InteropInvoke(IntPtr mwtPtr, IntPtr contextPtr, UInt32 action, float probability, IntPtr uniqueKeyPtr) + { + GCHandle mwtHandle = (GCHandle)mwtPtr; + RecorderCallback<Ctx>^ callback = (RecorderCallback<Ctx>^)mwtHandle.Target; + + GCHandle contextHandle = (GCHandle)contextPtr; + Ctx context = (Ctx)contextHandle.Target; + + GCHandle uniqueKeyHandle = (GCHandle)uniqueKeyPtr; + String^ uniqueKey = (String^)uniqueKeyHandle.Target; + + callback->InvokeRecorderCallback(context, action, probability, uniqueKey); + } + +private: + ClrRecorderCallback^ recorderCallback; + +private: + NativeRecorder* m_native_recorder; +}; + +// Triggers callback to the Recorder instance to record interaction data +generic <class Ctx> +public ref class ScorerCallback abstract +{ +internal: + virtual List<float>^ InvokeScorerCallback(Ctx context) = 0; + + ScorerCallback() + { + scorerCallback = gcnew ClrScorerCallback(&ScorerCallback<Ctx>::InteropInvoke); + IntPtr scorerCallbackPtr = Marshal::GetFunctionPointerForDelegate(scorerCallback); + Native_Scorer_Callback* callback = static_cast<Native_Scorer_Callback*>(scorerCallbackPtr.ToPointer()); + m_native_scorer = new NativeScorer(callback); + } + + ~ScorerCallback() + { + delete m_native_scorer; + } + + NativeScorer* GetNativeScorer() + { + return m_native_scorer; + } + + static void InteropInvoke(IntPtr callbackPtr, IntPtr contextPtr, IntPtr scoresPtr, IntPtr sizePtr) + { + GCHandle callbackHandle = (GCHandle)callbackPtr; + ScorerCallback<Ctx>^ callback = (ScorerCallback<Ctx>^)callbackHandle.Target; + + GCHandle contextHandle = (GCHandle)contextPtr; + Ctx context = (Ctx)contextHandle.Target; + + List<float>^ scoreList = callback->InvokeScorerCallback(context); + + if (scoreList == nullptr || scoreList->Count == 0) + { + return; + } + + u32* num_scores = (u32*)sizePtr.ToPointer(); + *num_scores = (u32)scoreList->Count; + + float* scores = new float[*num_scores]; + for (u32 i = 0; i < *num_scores; i++) + { + scores[i] = scoreList[i]; + } + + float** native_scores = (float**)scoresPtr.ToPointer(); + *native_scores = scores; + } + +private: + ClrScorerCallback^ scorerCallback; + +private: + NativeScorer* m_native_scorer; +}; + +// Triggers callback to the Context instance to perform ToString() operation +generic <class Ctx> where Ctx : IStringContext +public ref class ToStringCallback +{ +internal: + ToStringCallback() + { + toStringCallback = gcnew ClrToStringCallback(&ToStringCallback<Ctx>::InteropInvoke); + IntPtr toStringCallbackPtr = Marshal::GetFunctionPointerForDelegate(toStringCallback); + m_callback = static_cast<Native_To_String_Callback*>(toStringCallbackPtr.ToPointer()); + } + + Native_To_String_Callback* GetCallback() + { + return m_callback; + } + + static void InteropInvoke(IntPtr contextPtr, IntPtr stringPtr) + { + GCHandle contextHandle = (GCHandle)contextPtr; + Ctx context = (Ctx)contextHandle.Target; + + string* out_string = (string*)stringPtr.ToPointer(); + *out_string = marshal_as<string>(context->ToString()); + } + +private: + ClrToStringCallback^ toStringCallback; + +private: + Native_To_String_Callback* m_callback; +}; + }
\ No newline at end of file diff --git a/explore/explore.cpp b/explore/explore.cpp index ebc1c14e..f44ac353 100644 --- a/explore/explore.cpp +++ b/explore/explore.cpp @@ -1,73 +1,73 @@ -// explore.cpp : Timing code to measure performance of MWT Explorer library
-
-#include "MWTExplorer.h"
-#include <chrono>
-#include <tuple>
-#include <iostream>
-
-using namespace std;
-using namespace std::chrono;
-
-using namespace MultiWorldTesting;
-
-class MySimplePolicy : public IPolicy<SimpleContext>
-{
-public:
- u32 Choose_Action(SimpleContext& context)
- {
- return (u32)1;
- }
-};
-
-const u32 num_actions = 10;
-
-void Clock_Explore()
-{
- float epsilon = .2f;
- string unique_key = "key";
- int num_features = 1000;
- int num_iter = 10000;
- int num_warmup = 100;
- int num_interactions = 1;
-
- // pre-create features
- vector<Feature> features;
- for (int i = 0; i < num_features; i++)
- {
- Feature f = {0.5, i+1};
- features.push_back(f);
- }
-
- long long time_init = 0, time_choose = 0;
- for (int iter = 0; iter < num_iter + num_warmup; iter++)
- {
- high_resolution_clock::time_point t1 = high_resolution_clock::now();
- StringRecorder<SimpleContext> recorder;
- MwtExplorer<SimpleContext> mwt("test", recorder);
- MySimplePolicy default_policy;
- EpsilonGreedyExplorer<SimpleContext> explorer(default_policy, epsilon, num_actions);
- high_resolution_clock::time_point t2 = high_resolution_clock::now();
- time_init += iter < num_warmup ? 0 : duration_cast<chrono::microseconds>(t2 - t1).count();
-
- t1 = high_resolution_clock::now();
- SimpleContext appContext(features);
- for (int i = 0; i < num_interactions; i++)
- {
- mwt.Choose_Action(explorer, unique_key, appContext);
- }
- t2 = high_resolution_clock::now();
- time_choose += iter < num_warmup ? 0 : duration_cast<chrono::microseconds>(t2 - t1).count();
- }
-
- cout << "# iterations: " << num_iter << ", # interactions: " << num_interactions << ", # context features: " << num_features << endl;
- cout << "--- PER ITERATION ---" << endl;
- cout << "Init: " << (double)time_init / num_iter << " micro" << endl;
- cout << "Choose Action: " << (double)time_choose / (num_iter * num_interactions) << " micro" << endl;
- cout << "--- TOTAL TIME ---: " << (time_init + time_choose) << " micro" << endl;
-}
-
-int main(int argc, char* argv[])
-{
- Clock_Explore();
- return 0;
-}
+// explore.cpp : Timing code to measure performance of MWT Explorer library + +#include "MWTExplorer.h" +#include <chrono> +#include <tuple> +#include <iostream> + +using namespace std; +using namespace std::chrono; + +using namespace MultiWorldTesting; + +class MySimplePolicy : public IPolicy<SimpleContext> +{ +public: + u32 Choose_Action(SimpleContext& context) + { + return (u32)1; + } +}; + +const u32 num_actions = 10; + +void Clock_Explore() +{ + float epsilon = .2f; + string unique_key = "key"; + int num_features = 1000; + int num_iter = 10000; + int num_warmup = 100; + int num_interactions = 1; + + // pre-create features + vector<Feature> features; + for (int i = 0; i < num_features; i++) + { + Feature f = {0.5, i+1}; + features.push_back(f); + } + + long long time_init = 0, time_choose = 0; + for (int iter = 0; iter < num_iter + num_warmup; iter++) + { + high_resolution_clock::time_point t1 = high_resolution_clock::now(); + StringRecorder<SimpleContext> recorder; + MwtExplorer<SimpleContext> mwt("test", recorder); + MySimplePolicy default_policy; + EpsilonGreedyExplorer<SimpleContext> explorer(default_policy, epsilon, num_actions); + high_resolution_clock::time_point t2 = high_resolution_clock::now(); + time_init += iter < num_warmup ? 0 : duration_cast<chrono::microseconds>(t2 - t1).count(); + + t1 = high_resolution_clock::now(); + SimpleContext appContext(features); + for (int i = 0; i < num_interactions; i++) + { + mwt.Choose_Action(explorer, unique_key, appContext); + } + t2 = high_resolution_clock::now(); + time_choose += iter < num_warmup ? 0 : duration_cast<chrono::microseconds>(t2 - t1).count(); + } + + cout << "# iterations: " << num_iter << ", # interactions: " << num_interactions << ", # context features: " << num_features << endl; + cout << "--- PER ITERATION ---" << endl; + cout << "Init: " << (double)time_init / num_iter << " micro" << endl; + cout << "Choose Action: " << (double)time_choose / (num_iter * num_interactions) << " micro" << endl; + cout << "--- TOTAL TIME ---: " << (time_init + time_choose) << " micro" << endl; +} + +int main(int argc, char* argv[]) +{ + Clock_Explore(); + return 0; +} diff --git a/explore/static/MWTExplorer.h b/explore/static/MWTExplorer.h index dd5e3044..4c7009b2 100644 --- a/explore/static/MWTExplorer.h +++ b/explore/static/MWTExplorer.h @@ -1,669 +1,669 @@ -//
-// Main interface for clients of the Multiworld testing (MWT) service.
-//
-
-#pragma once
-
-#include <stdexcept>
-#include <float.h>
-#include <math.h>
-#include <stdio.h>
-#include <string.h>
-#include <vector>
-#include <utility>
-#include <memory>
-#include <limits.h>
-#include <tuple>
-
-#ifdef MANAGED_CODE
-#define PORTING_INTERFACE public
-#define MWT_NAMESPACE namespace NativeMultiWorldTesting
-#else
-#define PORTING_INTERFACE private
-#define MWT_NAMESPACE namespace MultiWorldTesting
-#endif
-
-using namespace std;
-
-#include "utility.h"
-
-/** \defgroup MultiWorldTestingCpp
-\brief C++ implementation, for sample usage see: https://github.com/sidsen/vowpal_wabbit/blob/v0/explore/explore_sample.cpp
-*/
-
-/*!
-* \addtogroup MultiWorldTestingCpp
-* @{
-*/
-
-//! Interface for C++ version of Multiworld Testing library.
-//! For sample usage see: https://github.com/sidsen/vowpal_wabbit/blob/v0/explore/explore_sample.cpp
-MWT_NAMESPACE {
-
-// Forward declarations
-template <class Ctx>
-class IRecorder;
-template <class Ctx>
-class IExplorer;
-
-///
-/// The top-level MwtExplorer class. Using this enables principled and efficient exploration
-/// over a set of possible actions, and ensures that the right bits are recorded.
-///
-template <class Ctx>
-class MwtExplorer
-{
-public:
- ///
- /// Constructor
- ///
- /// @param appid This should be unique to your experiment or you risk nasty correlation bugs.
- /// @param recorder A user-specified class for recording the appropriate bits for use in evaluation and learning.
- ///
- MwtExplorer(std::string app_id, IRecorder<Ctx>& recorder) : m_recorder(recorder)
- {
- m_app_id = HashUtils::Compute_Id_Hash(app_id);
- }
-
- ///
- /// Chooses an action by invoking an underlying exploration algorithm. This should be a
- /// drop-in replacement for any existing policy function.
- ///
- /// @param explorer An existing exploration algorithm (one of the below) which uses the default policy as a callback.
- /// @param unique_key A unique identifier for the experimental unit. This could be a user id, a session id, etc..
- /// @param context The context upon which a decision is made. See SimpleContext below for an example.
- ///
- u32 Choose_Action(IExplorer<Ctx>& explorer, string unique_key, Ctx& context)
- {
- u64 seed = HashUtils::Compute_Id_Hash(unique_key);
-
- std::tuple<u32, float, bool> action_probability_log_tuple = explorer.Choose_Action(seed + m_app_id, context);
-
- u32 action = std::get<0>(action_probability_log_tuple);
- float prob = std::get<1>(action_probability_log_tuple);
-
- if (std::get<2>(action_probability_log_tuple))
- {
- m_recorder.Record(context, action, prob, unique_key);
- }
-
- return action;
- }
-
-private:
- u64 m_app_id;
- IRecorder<Ctx>& m_recorder;
-};
-
-///
-/// Exposes a method to record exploration data based on generic contexts. Exploration data
-/// is specified as a set of tuples <context, action, probability, key> as described below. An
-/// application passes an IRecorder object to the @MwtExplorer constructor. See
-/// @StringRecorder for a sample IRecorder object.
-///
-template <class Ctx>
-class IRecorder
-{
-public:
- ///
- /// Records the exploration data associated with a given decision.
- ///
- /// @param context A user-defined context for the decision
- /// @param action The action chosen by an exploration algorithm given context
- /// @param probability The probability the exploration algorithm chose said action
- /// @param unique_key A user-defined unique identifer for the decision
- ///
- virtual void Record(Ctx& context, u32 action, float probability, string unique_key) = 0;
-};
-
-///
-/// Exposes a method to choose an action given a generic context, and obtain the relevant
-/// exploration bits. Invokes IPolicy::Choose_Action internally. Do not implement this
-/// interface yourself: instead, use the various exploration algorithms below, which
-/// implement it for you.
-///
-template <class Ctx>
-class IExplorer
-{
-public:
- ///
- /// Determines the action to take and the probability with which it was chosen, for a
- /// given context.
- ///
- /// @param salted_seed A PRG seed based on a unique id information provided by the user
- /// @param context A user-defined context for the decision
- /// @returns The action to take, the probability it was chosen, and a flag indicating
- /// whether to record this decision
- ///
- virtual std::tuple<u32, float, bool> Choose_Action(u64 salted_seed, Ctx& context) = 0;
-};
-
-///
-/// Exposes a method to choose an action given a generic context. IPolicy objects are
-/// passed to (and invoked by) exploration algorithms to specify the default policy behavior.
-///
-template <class Ctx>
-class IPolicy
-{
-public:
- ///
- /// Determines the action to take for a given context.
- ///
- /// @param context A user-defined context for the decision
- /// @returns The action to take (1-based index)
- ///
- virtual u32 Choose_Action(Ctx& context) = 0;
-};
-
-///
-/// Exposes a method for specifying a score (weight) for each action given a generic context.
-///
-template <class Ctx>
-class IScorer
-{
-public:
- ///
- /// Determines the score of each action for a given context.
- ///
- /// @param context A user-defined context for the decision
- /// @returns A vector of scores indexed by action (1-based)
- ///
- virtual vector<float> Score_Actions(Ctx& context) = 0;
-};
-
-///
-/// A sample recorder class that converts the exploration tuple into string format.
-///
-template <class Ctx>
-struct StringRecorder : public IRecorder<Ctx>
-{
- void Record(Ctx& context, u32 action, float probability, string unique_key)
- {
- // Implicitly enforce To_String() API on the context
- m_recording.append(to_string((unsigned long)action));
- m_recording.append(" ", 1);
- m_recording.append(unique_key);
- m_recording.append(" ", 1);
-
- char prob_str[10] = { 0 };
- NumberUtils::Float_To_String(probability, prob_str);
- m_recording.append(prob_str);
-
- m_recording.append(" | ", 3);
- m_recording.append(context.To_String());
- m_recording.append("\n");
- }
-
- // Gets the content of the recording so far as a string and optionally clears internal content.
- string Get_Recording(bool flush = true)
- {
- if (!flush)
- {
- return m_recording;
- }
- string recording = m_recording;
- m_recording.clear();
- return recording;
- }
-
-private:
- string m_recording;
-};
-
-///
-/// Represents a feature in a sparse array.
-///
-struct Feature
-{
- float Value;
- u32 Id;
-
- bool operator==(Feature other_feature)
- {
- return Id == other_feature.Id;
- }
-};
-
-///
-/// A sample context class that stores a vector of Features.
-///
-class SimpleContext
-{
-public:
- SimpleContext(vector<Feature>& features) :
- m_features(features)
- { }
-
- vector<Feature>& Get_Features()
- {
- return m_features;
- }
-
- string To_String()
- {
- string out_string;
- char feature_str[35] = { 0 };
- for (size_t i = 0; i < m_features.size(); i++)
- {
- int chars;
- if (i == 0)
- {
- chars = sprintf(feature_str, "%d:", m_features[i].Id);
- }
- else
- {
- chars = sprintf(feature_str, " %d:", m_features[i].Id);
- }
- NumberUtils::print_float(feature_str + chars, m_features[i].Value);
- out_string.append(feature_str);
- }
- return out_string;
- }
-
-private:
- vector<Feature>& m_features;
-};
-
-///
-/// The epsilon greedy exploration algorithm. This is a good choice if you have no idea
-/// which actions should be preferred. Epsilon greedy is also computationally cheap.
-///
-template <class Ctx>
-class EpsilonGreedyExplorer : public IExplorer<Ctx>
-{
-public:
- ///
- /// The constructor is the only public member, because this should be used with the MwtExplorer.
- ///
- /// @param default_policy A default function which outputs an action given a context.
- /// @param epsilon The probability of a random exploration.
- /// @param num_actions The number of actions to randomize over.
- ///
- EpsilonGreedyExplorer(IPolicy<Ctx>& default_policy, float epsilon, u32 num_actions) :
- m_default_policy(default_policy), m_epsilon(epsilon), m_num_actions(num_actions)
- {
- if (m_num_actions < 1)
- {
- throw std::invalid_argument("Number of actions must be at least 1.");
- }
-
- if (m_epsilon < 0 || m_epsilon > 1)
- {
- throw std::invalid_argument("Epsilon must be between 0 and 1.");
- }
- }
-
- ~EpsilonGreedyExplorer()
- {
- }
-
-private:
- std::tuple<u32, float, bool> Choose_Action(u64 salted_seed, Ctx& context)
- {
- PRG::prg random_generator(salted_seed);
-
- // Invoke the default policy function to get the action
- u32 chosen_action = m_default_policy.Choose_Action(context);
-
- if (chosen_action == 0 || chosen_action > m_num_actions)
- {
- throw std::invalid_argument("Action chosen by default policy is not within valid range.");
- }
-
- float action_probability = 0.f;
- float base_probability = m_epsilon / m_num_actions; // uniform probability
-
- // TODO: check this random generation
- if (random_generator.Uniform_Unit_Interval() < 1.f - m_epsilon)
- {
- action_probability = 1.f - m_epsilon + base_probability;
- }
- else
- {
- // Get uniform random action ID
- u32 actionId = random_generator.Uniform_Int(1, m_num_actions);
-
- if (actionId == chosen_action)
- {
- // IF it matches the one chosen by the default policy
- // then increase the probability
- action_probability = 1.f - m_epsilon + base_probability;
- }
- else
- {
- // Otherwise it's just the uniform probability
- action_probability = base_probability;
- }
- chosen_action = actionId;
- }
-
- return std::tuple<u32, float, bool>(chosen_action, action_probability, true);
- }
-
-private:
- IPolicy<Ctx>& m_default_policy;
- float m_epsilon;
- u32 m_num_actions;
-
-private:
- friend class MwtExplorer<Ctx>;
-};
-
-///
-/// In some cases, different actions have a different scores, and you would prefer to
-/// choose actions with large scores. Softmax allows you to do that.
-///
-template <class Ctx>
-class SoftmaxExplorer : public IExplorer<Ctx>
-{
-public:
- ///
- /// The constructor is the only public member, because this should be used with the MwtExplorer.
- ///
- /// @param default_scorer A function which outputs a score for each action.
- /// @param lambda lambda = 0 implies uniform distribution. Large lambda is equivalent to a max.
- /// @param num_actions The number of actions to randomize over.
- ///
- SoftmaxExplorer(IScorer<Ctx>& default_scorer, float lambda, u32 num_actions) :
- m_default_scorer(default_scorer), m_lambda(lambda), m_num_actions(num_actions)
- {
- if (m_num_actions < 1)
- {
- throw std::invalid_argument("Number of actions must be at least 1.");
- }
- }
-
-private:
- std::tuple<u32, float, bool> Choose_Action(u64 salted_seed, Ctx& context)
- {
- PRG::prg random_generator(salted_seed);
-
- // Invoke the default scorer function
- vector<float> scores = m_default_scorer.Score_Actions(context);
- u32 num_scores = (u32)scores.size();
- if (num_scores != m_num_actions)
- {
- throw std::invalid_argument("The number of scores returned by the scorer must equal number of actions");
- }
-
- u32 i = 0;
-
- float max_score = -FLT_MAX;
- for (i = 0; i < num_scores; i++)
- {
- if (max_score < scores[i])
- {
- max_score = scores[i];
- }
- }
-
- // Create a normalized exponential distribution based on the returned scores
- for (i = 0; i < num_scores; i++)
- {
- scores[i] = exp(m_lambda * (scores[i] - max_score));
- }
-
- // Create a discrete_distribution based on the returned weights. This class handles the
- // case where the sum of the weights is < or > 1, by normalizing agains the sum.
- float total = 0.f;
- for (size_t i = 0; i < num_scores; i++)
- total += scores[i];
-
- float draw = random_generator.Uniform_Unit_Interval();
-
- float sum = 0.f;
- float action_probability = 0.f;
- u32 action_index = num_scores - 1;
- for (u32 i = 0; i < num_scores; i++)
- {
- scores[i] = scores[i] / total;
- sum += scores[i];
- if (sum > draw)
- {
- action_index = i;
- action_probability = scores[i];
- break;
- }
- }
-
- // action id is one-based
- return std::tuple<u32, float, bool>(action_index + 1, action_probability, true);
- }
-
-private:
- IScorer<Ctx>& m_default_scorer;
- float m_lambda;
- u32 m_num_actions;
-
-private:
- friend class MwtExplorer<Ctx>;
-};
-
-///
-/// GenericExplorer provides complete flexibility. You can create any
-/// distribution over actions desired, and it will draw from that.
-///
-template <class Ctx>
-class GenericExplorer : public IExplorer<Ctx>
-{
-public:
- ///
- /// The constructor is the only public member, because this should be used with the MwtExplorer.
- ///
- /// @param default_scorer A function which outputs the probability of each action.
- /// @param num_actions The number of actions to randomize over.
- ///
- GenericExplorer(IScorer<Ctx>& default_scorer, u32 num_actions) :
- m_default_scorer(default_scorer), m_num_actions(num_actions)
- {
- if (m_num_actions < 1)
- {
- throw std::invalid_argument("Number of actions must be at least 1.");
- }
- }
-
-private:
- std::tuple<u32, float, bool> Choose_Action(u64 salted_seed, Ctx& context)
- {
- PRG::prg random_generator(salted_seed);
-
- // Invoke the default scorer function
- vector<float> weights = m_default_scorer.Score_Actions(context);
- u32 num_weights = (u32)weights.size();
- if (num_weights != m_num_actions)
- {
- throw std::invalid_argument("The number of weights returned by the scorer must equal number of actions");
- }
-
- // Create a discrete_distribution based on the returned weights. This class handles the
- // case where the sum of the weights is < or > 1, by normalizing agains the sum.
- float total = 0.f;
- for (size_t i = 0; i < num_weights; i++)
- {
- if (weights[i] < 0)
- {
- throw std::invalid_argument("Scores must be non-negative.");
- }
- total += weights[i];
- }
- if (total == 0)
- {
- throw std::invalid_argument("At least one score must be positive.");
- }
-
- float draw = random_generator.Uniform_Unit_Interval();
-
- float sum = 0.f;
- float action_probability = 0.f;
- u32 action_index = num_weights - 1;
- for (u32 i = 0; i < num_weights; i++)
- {
- weights[i] = weights[i] / total;
- sum += weights[i];
- if (sum > draw)
- {
- action_index = i;
- action_probability = weights[i];
- break;
- }
- }
-
- // action id is one-based
- return std::tuple<u32, float, bool>(action_index + 1, action_probability, true);
- }
-
-private:
- IScorer<Ctx>& m_default_scorer;
- u32 m_num_actions;
-
-private:
- friend class MwtExplorer<Ctx>;
-};
-
-///
-/// The tau-first explorer collects exactly tau uniform random exploration events, and then
-/// uses the default policy thereafter.
-///
-template <class Ctx>
-class TauFirstExplorer : public IExplorer<Ctx>
-{
-public:
-
- ///
- /// The constructor is the only public member, because this should be used with the MwtExplorer.
- ///
- /// @param default_policy A default policy after randomization finishes.
- /// @param tau The number of events to be uniform over.
- /// @param num_actions The number of actions to randomize over.
- ///
- TauFirstExplorer(IPolicy<Ctx>& default_policy, u32 tau, u32 num_actions) :
- m_default_policy(default_policy), m_tau(tau), m_num_actions(num_actions)
- {
- if (m_num_actions < 1)
- {
- throw std::invalid_argument("Number of actions must be at least 1.");
- }
- }
-
-private:
- std::tuple<u32, float, bool> Choose_Action(u64 salted_seed, Ctx& context)
- {
- PRG::prg random_generator(salted_seed);
-
- u32 chosen_action = 0;
- float action_probability = 0.f;
- bool log_action;
- if (m_tau)
- {
- m_tau--;
- u32 actionId = random_generator.Uniform_Int(1, m_num_actions);
- action_probability = 1.f / m_num_actions;
- chosen_action = actionId;
- log_action = true;
- }
- else
- {
- // Invoke the default policy function to get the action
- chosen_action = m_default_policy.Choose_Action(context);
-
- if (chosen_action == 0 || chosen_action > m_num_actions)
- {
- throw std::invalid_argument("Action chosen by default policy is not within valid range.");
- }
-
- action_probability = 1.f;
- log_action = false;
- }
-
- return std::tuple<u32, float, bool>(chosen_action, action_probability, log_action);
- }
-
-private:
- IPolicy<Ctx>& m_default_policy;
- u32 m_tau;
- u32 m_num_actions;
-
-private:
- friend class MwtExplorer<Ctx>;
-};
-
-///
-/// The Bootstrap explorer randomizes over the actions chosen by a set of default policies.
-/// This performs well statistically but can be computationally expensive.
-///
-template <class Ctx>
-class BootstrapExplorer : public IExplorer<Ctx>
-{
-public:
- ///
- /// The constructor is the only public member, because this should be used with the MwtExplorer.
- ///
- /// @param default_policy_functions A set of default policies to be uniform random over.
- /// The policy pointers must be valid throughout the lifetime of this explorer.
- /// @param num_actions The number of actions to randomize over.
- ///
- BootstrapExplorer(vector<unique_ptr<IPolicy<Ctx>>>& default_policy_functions, u32 num_actions) :
- m_default_policy_functions(default_policy_functions),
- m_num_actions(num_actions)
- {
- m_bags = (u32)default_policy_functions.size();
- if (m_num_actions < 1)
- {
- throw std::invalid_argument("Number of actions must be at least 1.");
- }
-
- if (m_bags < 1)
- {
- throw std::invalid_argument("Number of bags must be at least 1.");
- }
- }
-
-private:
- std::tuple<u32, float, bool> Choose_Action(u64 salted_seed, Ctx& context)
- {
- PRG::prg random_generator(salted_seed);
-
- // Select bag
- u32 chosen_bag = random_generator.Uniform_Int(0, m_bags - 1);
-
- // Invoke the default policy function to get the action
- u32 chosen_action = 0;
- u32 action_from_bag = 0;
- vector<u32> actions_selected;
- for (size_t i = 0; i < m_num_actions; i++)
- {
- actions_selected.push_back(0);
- }
-
- // Invoke the default policy function to get the action
- for (u32 current_bag = 0; current_bag < m_bags; current_bag++)
- {
- action_from_bag = m_default_policy_functions[current_bag]->Choose_Action(context);
-
- if (action_from_bag == 0 || action_from_bag > m_num_actions)
- {
- throw std::invalid_argument("Action chosen by default policy is not within valid range.");
- }
-
- if (current_bag == chosen_bag)
- {
- chosen_action = action_from_bag;
- }
- //this won't work if actions aren't 0 to Count
- actions_selected[action_from_bag - 1]++; // action id is one-based
- }
- float action_probability = (float)actions_selected[chosen_action - 1] / m_bags; // action id is one-based
-
- return std::tuple<u32, float, bool>(chosen_action, action_probability, true);
- }
-
-private:
- vector<unique_ptr<IPolicy<Ctx>>>& m_default_policy_functions;
- u32 m_bags;
- u32 m_num_actions;
-
-private:
- friend class MwtExplorer<Ctx>;
-};
-} // End namespace MultiWorldTestingCpp
-/*! @} End of Doxygen Groups*/
+// +// Main interface for clients of the Multiworld testing (MWT) service. +// + +#pragma once + +#include <stdexcept> +#include <float.h> +#include <math.h> +#include <stdio.h> +#include <string.h> +#include <vector> +#include <utility> +#include <memory> +#include <limits.h> +#include <tuple> + +#ifdef MANAGED_CODE +#define PORTING_INTERFACE public +#define MWT_NAMESPACE namespace NativeMultiWorldTesting +#else +#define PORTING_INTERFACE private +#define MWT_NAMESPACE namespace MultiWorldTesting +#endif + +using namespace std; + +#include "utility.h" + +/** \defgroup MultiWorldTestingCpp +\brief C++ implementation, for sample usage see: https://github.com/sidsen/vowpal_wabbit/blob/v0/explore/explore_sample.cpp +*/ + +/*! +* \addtogroup MultiWorldTestingCpp +* @{ +*/ + +//! Interface for C++ version of Multiworld Testing library. +//! For sample usage see: https://github.com/sidsen/vowpal_wabbit/blob/v0/explore/explore_sample.cpp +MWT_NAMESPACE { + +// Forward declarations +template <class Ctx> +class IRecorder; +template <class Ctx> +class IExplorer; + +/// +/// The top-level MwtExplorer class. Using this enables principled and efficient exploration +/// over a set of possible actions, and ensures that the right bits are recorded. +/// +template <class Ctx> +class MwtExplorer +{ +public: + /// + /// Constructor + /// + /// @param appid This should be unique to your experiment or you risk nasty correlation bugs. + /// @param recorder A user-specified class for recording the appropriate bits for use in evaluation and learning. + /// + MwtExplorer(std::string app_id, IRecorder<Ctx>& recorder) : m_recorder(recorder) + { + m_app_id = HashUtils::Compute_Id_Hash(app_id); + } + + /// + /// Chooses an action by invoking an underlying exploration algorithm. This should be a + /// drop-in replacement for any existing policy function. + /// + /// @param explorer An existing exploration algorithm (one of the below) which uses the default policy as a callback. + /// @param unique_key A unique identifier for the experimental unit. This could be a user id, a session id, etc.. + /// @param context The context upon which a decision is made. See SimpleContext below for an example. + /// + u32 Choose_Action(IExplorer<Ctx>& explorer, string unique_key, Ctx& context) + { + u64 seed = HashUtils::Compute_Id_Hash(unique_key); + + std::tuple<u32, float, bool> action_probability_log_tuple = explorer.Choose_Action(seed + m_app_id, context); + + u32 action = std::get<0>(action_probability_log_tuple); + float prob = std::get<1>(action_probability_log_tuple); + + if (std::get<2>(action_probability_log_tuple)) + { + m_recorder.Record(context, action, prob, unique_key); + } + + return action; + } + +private: + u64 m_app_id; + IRecorder<Ctx>& m_recorder; +}; + +/// +/// Exposes a method to record exploration data based on generic contexts. Exploration data +/// is specified as a set of tuples <context, action, probability, key> as described below. An +/// application passes an IRecorder object to the @MwtExplorer constructor. See +/// @StringRecorder for a sample IRecorder object. +/// +template <class Ctx> +class IRecorder +{ +public: + /// + /// Records the exploration data associated with a given decision. + /// + /// @param context A user-defined context for the decision + /// @param action The action chosen by an exploration algorithm given context + /// @param probability The probability the exploration algorithm chose said action + /// @param unique_key A user-defined unique identifer for the decision + /// + virtual void Record(Ctx& context, u32 action, float probability, string unique_key) = 0; +}; + +/// +/// Exposes a method to choose an action given a generic context, and obtain the relevant +/// exploration bits. Invokes IPolicy::Choose_Action internally. Do not implement this +/// interface yourself: instead, use the various exploration algorithms below, which +/// implement it for you. +/// +template <class Ctx> +class IExplorer +{ +public: + /// + /// Determines the action to take and the probability with which it was chosen, for a + /// given context. + /// + /// @param salted_seed A PRG seed based on a unique id information provided by the user + /// @param context A user-defined context for the decision + /// @returns The action to take, the probability it was chosen, and a flag indicating + /// whether to record this decision + /// + virtual std::tuple<u32, float, bool> Choose_Action(u64 salted_seed, Ctx& context) = 0; +}; + +/// +/// Exposes a method to choose an action given a generic context. IPolicy objects are +/// passed to (and invoked by) exploration algorithms to specify the default policy behavior. +/// +template <class Ctx> +class IPolicy +{ +public: + /// + /// Determines the action to take for a given context. + /// + /// @param context A user-defined context for the decision + /// @returns The action to take (1-based index) + /// + virtual u32 Choose_Action(Ctx& context) = 0; +}; + +/// +/// Exposes a method for specifying a score (weight) for each action given a generic context. +/// +template <class Ctx> +class IScorer +{ +public: + /// + /// Determines the score of each action for a given context. + /// + /// @param context A user-defined context for the decision + /// @returns A vector of scores indexed by action (1-based) + /// + virtual vector<float> Score_Actions(Ctx& context) = 0; +}; + +/// +/// A sample recorder class that converts the exploration tuple into string format. +/// +template <class Ctx> +struct StringRecorder : public IRecorder<Ctx> +{ + void Record(Ctx& context, u32 action, float probability, string unique_key) + { + // Implicitly enforce To_String() API on the context + m_recording.append(to_string((unsigned long)action)); + m_recording.append(" ", 1); + m_recording.append(unique_key); + m_recording.append(" ", 1); + + char prob_str[10] = { 0 }; + NumberUtils::Float_To_String(probability, prob_str); + m_recording.append(prob_str); + + m_recording.append(" | ", 3); + m_recording.append(context.To_String()); + m_recording.append("\n"); + } + + // Gets the content of the recording so far as a string and optionally clears internal content. + string Get_Recording(bool flush = true) + { + if (!flush) + { + return m_recording; + } + string recording = m_recording; + m_recording.clear(); + return recording; + } + +private: + string m_recording; +}; + +/// +/// Represents a feature in a sparse array. +/// +struct Feature +{ + float Value; + u32 Id; + + bool operator==(Feature other_feature) + { + return Id == other_feature.Id; + } +}; + +/// +/// A sample context class that stores a vector of Features. +/// +class SimpleContext +{ +public: + SimpleContext(vector<Feature>& features) : + m_features(features) + { } + + vector<Feature>& Get_Features() + { + return m_features; + } + + string To_String() + { + string out_string; + char feature_str[35] = { 0 }; + for (size_t i = 0; i < m_features.size(); i++) + { + int chars; + if (i == 0) + { + chars = sprintf(feature_str, "%d:", m_features[i].Id); + } + else + { + chars = sprintf(feature_str, " %d:", m_features[i].Id); + } + NumberUtils::print_float(feature_str + chars, m_features[i].Value); + out_string.append(feature_str); + } + return out_string; + } + +private: + vector<Feature>& m_features; +}; + +/// +/// The epsilon greedy exploration algorithm. This is a good choice if you have no idea +/// which actions should be preferred. Epsilon greedy is also computationally cheap. +/// +template <class Ctx> +class EpsilonGreedyExplorer : public IExplorer<Ctx> +{ +public: + /// + /// The constructor is the only public member, because this should be used with the MwtExplorer. + /// + /// @param default_policy A default function which outputs an action given a context. + /// @param epsilon The probability of a random exploration. + /// @param num_actions The number of actions to randomize over. + /// + EpsilonGreedyExplorer(IPolicy<Ctx>& default_policy, float epsilon, u32 num_actions) : + m_default_policy(default_policy), m_epsilon(epsilon), m_num_actions(num_actions) + { + if (m_num_actions < 1) + { + throw std::invalid_argument("Number of actions must be at least 1."); + } + + if (m_epsilon < 0 || m_epsilon > 1) + { + throw std::invalid_argument("Epsilon must be between 0 and 1."); + } + } + + ~EpsilonGreedyExplorer() + { + } + +private: + std::tuple<u32, float, bool> Choose_Action(u64 salted_seed, Ctx& context) + { + PRG::prg random_generator(salted_seed); + + // Invoke the default policy function to get the action + u32 chosen_action = m_default_policy.Choose_Action(context); + + if (chosen_action == 0 || chosen_action > m_num_actions) + { + throw std::invalid_argument("Action chosen by default policy is not within valid range."); + } + + float action_probability = 0.f; + float base_probability = m_epsilon / m_num_actions; // uniform probability + + // TODO: check this random generation + if (random_generator.Uniform_Unit_Interval() < 1.f - m_epsilon) + { + action_probability = 1.f - m_epsilon + base_probability; + } + else + { + // Get uniform random action ID + u32 actionId = random_generator.Uniform_Int(1, m_num_actions); + + if (actionId == chosen_action) + { + // IF it matches the one chosen by the default policy + // then increase the probability + action_probability = 1.f - m_epsilon + base_probability; + } + else + { + // Otherwise it's just the uniform probability + action_probability = base_probability; + } + chosen_action = actionId; + } + + return std::tuple<u32, float, bool>(chosen_action, action_probability, true); + } + +private: + IPolicy<Ctx>& m_default_policy; + float m_epsilon; + u32 m_num_actions; + +private: + friend class MwtExplorer<Ctx>; +}; + +/// +/// In some cases, different actions have a different scores, and you would prefer to +/// choose actions with large scores. Softmax allows you to do that. +/// +template <class Ctx> +class SoftmaxExplorer : public IExplorer<Ctx> +{ +public: + /// + /// The constructor is the only public member, because this should be used with the MwtExplorer. + /// + /// @param default_scorer A function which outputs a score for each action. + /// @param lambda lambda = 0 implies uniform distribution. Large lambda is equivalent to a max. + /// @param num_actions The number of actions to randomize over. + /// + SoftmaxExplorer(IScorer<Ctx>& default_scorer, float lambda, u32 num_actions) : + m_default_scorer(default_scorer), m_lambda(lambda), m_num_actions(num_actions) + { + if (m_num_actions < 1) + { + throw std::invalid_argument("Number of actions must be at least 1."); + } + } + +private: + std::tuple<u32, float, bool> Choose_Action(u64 salted_seed, Ctx& context) + { + PRG::prg random_generator(salted_seed); + + // Invoke the default scorer function + vector<float> scores = m_default_scorer.Score_Actions(context); + u32 num_scores = (u32)scores.size(); + if (num_scores != m_num_actions) + { + throw std::invalid_argument("The number of scores returned by the scorer must equal number of actions"); + } + + u32 i = 0; + + float max_score = -FLT_MAX; + for (i = 0; i < num_scores; i++) + { + if (max_score < scores[i]) + { + max_score = scores[i]; + } + } + + // Create a normalized exponential distribution based on the returned scores + for (i = 0; i < num_scores; i++) + { + scores[i] = exp(m_lambda * (scores[i] - max_score)); + } + + // Create a discrete_distribution based on the returned weights. This class handles the + // case where the sum of the weights is < or > 1, by normalizing agains the sum. + float total = 0.f; + for (size_t i = 0; i < num_scores; i++) + total += scores[i]; + + float draw = random_generator.Uniform_Unit_Interval(); + + float sum = 0.f; + float action_probability = 0.f; + u32 action_index = num_scores - 1; + for (u32 i = 0; i < num_scores; i++) + { + scores[i] = scores[i] / total; + sum += scores[i]; + if (sum > draw) + { + action_index = i; + action_probability = scores[i]; + break; + } + } + + // action id is one-based + return std::tuple<u32, float, bool>(action_index + 1, action_probability, true); + } + +private: + IScorer<Ctx>& m_default_scorer; + float m_lambda; + u32 m_num_actions; + +private: + friend class MwtExplorer<Ctx>; +}; + +/// +/// GenericExplorer provides complete flexibility. You can create any +/// distribution over actions desired, and it will draw from that. +/// +template <class Ctx> +class GenericExplorer : public IExplorer<Ctx> +{ +public: + /// + /// The constructor is the only public member, because this should be used with the MwtExplorer. + /// + /// @param default_scorer A function which outputs the probability of each action. + /// @param num_actions The number of actions to randomize over. + /// + GenericExplorer(IScorer<Ctx>& default_scorer, u32 num_actions) : + m_default_scorer(default_scorer), m_num_actions(num_actions) + { + if (m_num_actions < 1) + { + throw std::invalid_argument("Number of actions must be at least 1."); + } + } + +private: + std::tuple<u32, float, bool> Choose_Action(u64 salted_seed, Ctx& context) + { + PRG::prg random_generator(salted_seed); + + // Invoke the default scorer function + vector<float> weights = m_default_scorer.Score_Actions(context); + u32 num_weights = (u32)weights.size(); + if (num_weights != m_num_actions) + { + throw std::invalid_argument("The number of weights returned by the scorer must equal number of actions"); + } + + // Create a discrete_distribution based on the returned weights. This class handles the + // case where the sum of the weights is < or > 1, by normalizing agains the sum. + float total = 0.f; + for (size_t i = 0; i < num_weights; i++) + { + if (weights[i] < 0) + { + throw std::invalid_argument("Scores must be non-negative."); + } + total += weights[i]; + } + if (total == 0) + { + throw std::invalid_argument("At least one score must be positive."); + } + + float draw = random_generator.Uniform_Unit_Interval(); + + float sum = 0.f; + float action_probability = 0.f; + u32 action_index = num_weights - 1; + for (u32 i = 0; i < num_weights; i++) + { + weights[i] = weights[i] / total; + sum += weights[i]; + if (sum > draw) + { + action_index = i; + action_probability = weights[i]; + break; + } + } + + // action id is one-based + return std::tuple<u32, float, bool>(action_index + 1, action_probability, true); + } + +private: + IScorer<Ctx>& m_default_scorer; + u32 m_num_actions; + +private: + friend class MwtExplorer<Ctx>; +}; + +/// +/// The tau-first explorer collects exactly tau uniform random exploration events, and then +/// uses the default policy thereafter. +/// +template <class Ctx> +class TauFirstExplorer : public IExplorer<Ctx> +{ +public: + + /// + /// The constructor is the only public member, because this should be used with the MwtExplorer. + /// + /// @param default_policy A default policy after randomization finishes. + /// @param tau The number of events to be uniform over. + /// @param num_actions The number of actions to randomize over. + /// + TauFirstExplorer(IPolicy<Ctx>& default_policy, u32 tau, u32 num_actions) : + m_default_policy(default_policy), m_tau(tau), m_num_actions(num_actions) + { + if (m_num_actions < 1) + { + throw std::invalid_argument("Number of actions must be at least 1."); + } + } + +private: + std::tuple<u32, float, bool> Choose_Action(u64 salted_seed, Ctx& context) + { + PRG::prg random_generator(salted_seed); + + u32 chosen_action = 0; + float action_probability = 0.f; + bool log_action; + if (m_tau) + { + m_tau--; + u32 actionId = random_generator.Uniform_Int(1, m_num_actions); + action_probability = 1.f / m_num_actions; + chosen_action = actionId; + log_action = true; + } + else + { + // Invoke the default policy function to get the action + chosen_action = m_default_policy.Choose_Action(context); + + if (chosen_action == 0 || chosen_action > m_num_actions) + { + throw std::invalid_argument("Action chosen by default policy is not within valid range."); + } + + action_probability = 1.f; + log_action = false; + } + + return std::tuple<u32, float, bool>(chosen_action, action_probability, log_action); + } + +private: + IPolicy<Ctx>& m_default_policy; + u32 m_tau; + u32 m_num_actions; + +private: + friend class MwtExplorer<Ctx>; +}; + +/// +/// The Bootstrap explorer randomizes over the actions chosen by a set of default policies. +/// This performs well statistically but can be computationally expensive. +/// +template <class Ctx> +class BootstrapExplorer : public IExplorer<Ctx> +{ +public: + /// + /// The constructor is the only public member, because this should be used with the MwtExplorer. + /// + /// @param default_policy_functions A set of default policies to be uniform random over. + /// The policy pointers must be valid throughout the lifetime of this explorer. + /// @param num_actions The number of actions to randomize over. + /// + BootstrapExplorer(vector<unique_ptr<IPolicy<Ctx>>>& default_policy_functions, u32 num_actions) : + m_default_policy_functions(default_policy_functions), + m_num_actions(num_actions) + { + m_bags = (u32)default_policy_functions.size(); + if (m_num_actions < 1) + { + throw std::invalid_argument("Number of actions must be at least 1."); + } + + if (m_bags < 1) + { + throw std::invalid_argument("Number of bags must be at least 1."); + } + } + +private: + std::tuple<u32, float, bool> Choose_Action(u64 salted_seed, Ctx& context) + { + PRG::prg random_generator(salted_seed); + + // Select bag + u32 chosen_bag = random_generator.Uniform_Int(0, m_bags - 1); + + // Invoke the default policy function to get the action + u32 chosen_action = 0; + u32 action_from_bag = 0; + vector<u32> actions_selected; + for (size_t i = 0; i < m_num_actions; i++) + { + actions_selected.push_back(0); + } + + // Invoke the default policy function to get the action + for (u32 current_bag = 0; current_bag < m_bags; current_bag++) + { + action_from_bag = m_default_policy_functions[current_bag]->Choose_Action(context); + + if (action_from_bag == 0 || action_from_bag > m_num_actions) + { + throw std::invalid_argument("Action chosen by default policy is not within valid range."); + } + + if (current_bag == chosen_bag) + { + chosen_action = action_from_bag; + } + //this won't work if actions aren't 0 to Count + actions_selected[action_from_bag - 1]++; // action id is one-based + } + float action_probability = (float)actions_selected[chosen_action - 1] / m_bags; // action id is one-based + + return std::tuple<u32, float, bool>(chosen_action, action_probability, true); + } + +private: + vector<unique_ptr<IPolicy<Ctx>>>& m_default_policy_functions; + u32 m_bags; + u32 m_num_actions; + +private: + friend class MwtExplorer<Ctx>; +}; +} // End namespace MultiWorldTestingCpp +/*! @} End of Doxygen Groups*/ diff --git a/explore/static/utility.h b/explore/static/utility.h index a6284809..6a2a60d6 100644 --- a/explore/static/utility.h +++ b/explore/static/utility.h @@ -1,286 +1,286 @@ /*******************************************************************/ // Classes declared in this file are intended for internal use only. -/*******************************************************************/
-
-#pragma once
-#include <stdint.h>
-#include <sys/types.h> /* defines size_t */
-
-#ifdef WIN32
-typedef unsigned __int64 u64;
-typedef unsigned __int32 u32;
-typedef unsigned __int16 u16;
-typedef unsigned __int8 u8;
-typedef signed __int64 i64;
-typedef signed __int32 i32;
-typedef signed __int16 i16;
-typedef signed __int8 i8;
-#else
-typedef uint64_t u64;
-typedef uint32_t u32;
-typedef uint16_t u16;
-typedef uint8_t u8;
-typedef int64_t i64;
-typedef int32_t i32;
-typedef int16_t i16;
-typedef int8_t i8;
-#endif
-
-typedef unsigned char byte;
-
-#include <string>
-#include <stdint.h>
-#include <math.h>
-
-/*!
-* \addtogroup MultiWorldTestingCpp
-* @{
-*/
-
-MWT_NAMESPACE {
-
-//
-// MurmurHash3, by Austin Appleby
-//
-// Originals at:
-// http://code.google.com/p/smhasher/source/browse/trunk/MurmurHash3.cpp
-// http://code.google.com/p/smhasher/source/browse/trunk/MurmurHash3.h
-//
-// Notes:
-// 1) this code assumes we can read a 4-byte value from any address
-// without crashing (i.e non aligned access is supported). This is
-// not a problem on Intel/x86/AMD64 machines (including new Macs)
-// 2) It produces different results on little-endian and big-endian machines.
-//
-//-----------------------------------------------------------------------------
-// MurmurHash3 was written by Austin Appleby, and is placed in the public
-// domain. The author hereby disclaims copyright to this source code.
-
-// Note - The x86 and x64 versions do _not_ produce the same results, as the
-// algorithms are optimized for their respective platforms. You can still
-// compile and run any of them on any platform, but your performance with the
-// non-native version will be less than optimal.
-//-----------------------------------------------------------------------------
-
-// Platform-specific functions and macros
-#if defined(_MSC_VER) // Microsoft Visual Studio
-# include <stdint.h>
-
-# include <stdlib.h>
-# define ROTL32(x,y) _rotl(x,y)
-# define BIG_CONSTANT(x) (x)
-
-#else // Other compilers
-# include <stdint.h> /* defines uint32_t etc */
-
- inline uint32_t rotl32(uint32_t x, int8_t r)
- {
- return (x << r) | (x >> (32 - r));
- }
-
-# define ROTL32(x,y) rotl32(x,y)
-# define BIG_CONSTANT(x) (x##LLU)
-
-#endif // !defined(_MSC_VER)
-
-struct murmur_hash {
-
- //-----------------------------------------------------------------------------
- // Block read - if your platform needs to do endian-swapping or can only
- // handle aligned reads, do the conversion here
-private:
- static inline uint32_t getblock(const uint32_t * p, int i)
- {
- return p[i];
- }
-
- //-----------------------------------------------------------------------------
- // Finalization mix - force all bits of a hash block to avalanche
-
- static inline uint32_t fmix(uint32_t h)
- {
- h ^= h >> 16;
- h *= 0x85ebca6b;
- h ^= h >> 13;
- h *= 0xc2b2ae35;
- h ^= h >> 16;
-
- return h;
- }
-
- //-----------------------------------------------------------------------------
-public:
- uint32_t uniform_hash(const void * key, size_t len, uint32_t seed)
- {
- const uint8_t * data = (const uint8_t*)key;
- const int nblocks = (int)len / 4;
-
- uint32_t h1 = seed;
-
- const uint32_t c1 = 0xcc9e2d51;
- const uint32_t c2 = 0x1b873593;
-
- // --- body
- const uint32_t * blocks = (const uint32_t *)(data + nblocks * 4);
-
- for (int i = -nblocks; i; i++) {
- uint32_t k1 = getblock(blocks, i);
-
- k1 *= c1;
- k1 = ROTL32(k1, 15);
- k1 *= c2;
-
- h1 ^= k1;
- h1 = ROTL32(h1, 13);
- h1 = h1 * 5 + 0xe6546b64;
- }
-
- // --- tail
- const uint8_t * tail = (const uint8_t*)(data + nblocks * 4);
-
- uint32_t k1 = 0;
-
- switch (len & 3) {
- case 3: k1 ^= tail[2] << 16;
- case 2: k1 ^= tail[1] << 8;
- case 1: k1 ^= tail[0];
- k1 *= c1; k1 = ROTL32(k1, 15); k1 *= c2; h1 ^= k1;
- }
-
- // --- finalization
- h1 ^= len;
-
- return fmix(h1);
- }
-};
-
-class HashUtils
-{
-public:
- static u64 Compute_Id_Hash(const std::string& unique_id)
- {
- size_t ret = 0;
- const char *p = unique_id.c_str();
- while (*p != '\0')
- if (*p >= '0' && *p <= '9')
- ret = 10 * ret + *(p++) - '0';
- else
- {
- murmur_hash foo;
- return foo.uniform_hash(unique_id.c_str(), unique_id.size(), 0);
- }
- return ret;
- }
-};
-
-const size_t max_int = 100000;
-const float max_float = max_int;
-const float min_float = 0.00001f;
-const size_t max_digits = (size_t) roundf((float) (-log(min_float) / log(10.)));
-
-class NumberUtils
-{
-public:
- static void Float_To_String(float f, char* str)
- {
- int x = (int)f;
- int d = (int)(fabs(f - x) * 100000);
- sprintf(str, "%d.%05d", x, d);
- }
-
- template<bool trailing_zeros>
- static void print_mantissa(char*& begin, float f)
- { // helper for print_float
- char values[10];
- size_t v = (size_t)f;
- size_t digit = 0;
- size_t first_nonzero = 0;
- for (size_t max = 1; max <= v; ++digit)
- {
- size_t max_next = max * 10;
- char v_mod = (char) (v % max_next / max);
- if (!trailing_zeros && v_mod != '\0' && first_nonzero == 0)
- first_nonzero = digit;
- values[digit] = '0' + v_mod;
- max = max_next;
- }
- if (!trailing_zeros)
- for (size_t i = max_digits; i > digit; i--)
- *begin++ = '0';
- while (digit > first_nonzero)
- *begin++ = values[--digit];
- }
-
- static void print_float(char* begin, float f)
- {
- bool sign = false;
- if (f < 0.f)
- sign = true;
- float unsigned_f = fabsf(f);
- if (unsigned_f < max_float && unsigned_f > min_float)
- {
- if (sign)
- *begin++ = '-';
- print_mantissa<true>(begin, unsigned_f);
- unsigned_f -= (size_t)unsigned_f;
- unsigned_f *= max_int;
- if (unsigned_f >= 1.f)
- {
- *begin++ = '.';
- print_mantissa<false>(begin, unsigned_f);
- }
- }
- else if (unsigned_f == 0.)
- *begin++ = '0';
- else
- {
- sprintf(begin, "%g", f);
- return;
- }
- *begin = '\0';
- return;
- }
-};
-
-//A quick implementation similar to drand48 for cross-platform compatibility
-namespace PRG {
- const uint64_t a = 0xeece66d5deece66dULL;
- const uint64_t c = 2147483647;
-
- const int bias = 127 << 23;
-
- union int_float {
- int32_t i;
- float f;
- };
-
- struct prg {
- private:
- uint64_t v;
- public:
- prg() { v = c; }
- prg(uint64_t initial) { v = initial; }
-
- float merand48(uint64_t& initial)
- {
- initial = a * initial + c;
- int_float temp;
- temp.i = ((initial >> 25) & 0x7FFFFF) | bias;
- return temp.f - 1;
- }
-
- float Uniform_Unit_Interval()
- {
- return merand48(v);
- }
-
- uint32_t Uniform_Int(uint32_t low, uint32_t high)
- {
- merand48(v);
- uint32_t ret = low + ((v >> 25) % (high - low + 1));
- return ret;
- }
- };
-}
-}
+/*******************************************************************/ + +#pragma once +#include <stdint.h> +#include <sys/types.h> /* defines size_t */ + +#ifdef WIN32 +typedef unsigned __int64 u64; +typedef unsigned __int32 u32; +typedef unsigned __int16 u16; +typedef unsigned __int8 u8; +typedef signed __int64 i64; +typedef signed __int32 i32; +typedef signed __int16 i16; +typedef signed __int8 i8; +#else +typedef uint64_t u64; +typedef uint32_t u32; +typedef uint16_t u16; +typedef uint8_t u8; +typedef int64_t i64; +typedef int32_t i32; +typedef int16_t i16; +typedef int8_t i8; +#endif + +typedef unsigned char byte; + +#include <string> +#include <stdint.h> +#include <math.h> + +/*! +* \addtogroup MultiWorldTestingCpp +* @{ +*/ + +MWT_NAMESPACE { + +// +// MurmurHash3, by Austin Appleby +// +// Originals at: +// http://code.google.com/p/smhasher/source/browse/trunk/MurmurHash3.cpp +// http://code.google.com/p/smhasher/source/browse/trunk/MurmurHash3.h +// +// Notes: +// 1) this code assumes we can read a 4-byte value from any address +// without crashing (i.e non aligned access is supported). This is +// not a problem on Intel/x86/AMD64 machines (including new Macs) +// 2) It produces different results on little-endian and big-endian machines. +// +//----------------------------------------------------------------------------- +// MurmurHash3 was written by Austin Appleby, and is placed in the public +// domain. The author hereby disclaims copyright to this source code. + +// Note - The x86 and x64 versions do _not_ produce the same results, as the +// algorithms are optimized for their respective platforms. You can still +// compile and run any of them on any platform, but your performance with the +// non-native version will be less than optimal. +//----------------------------------------------------------------------------- + +// Platform-specific functions and macros +#if defined(_MSC_VER) // Microsoft Visual Studio +# include <stdint.h> + +# include <stdlib.h> +# define ROTL32(x,y) _rotl(x,y) +# define BIG_CONSTANT(x) (x) + +#else // Other compilers +# include <stdint.h> /* defines uint32_t etc */ + + inline uint32_t rotl32(uint32_t x, int8_t r) + { + return (x << r) | (x >> (32 - r)); + } + +# define ROTL32(x,y) rotl32(x,y) +# define BIG_CONSTANT(x) (x##LLU) + +#endif // !defined(_MSC_VER) + +struct murmur_hash { + + //----------------------------------------------------------------------------- + // Block read - if your platform needs to do endian-swapping or can only + // handle aligned reads, do the conversion here +private: + static inline uint32_t getblock(const uint32_t * p, int i) + { + return p[i]; + } + + //----------------------------------------------------------------------------- + // Finalization mix - force all bits of a hash block to avalanche + + static inline uint32_t fmix(uint32_t h) + { + h ^= h >> 16; + h *= 0x85ebca6b; + h ^= h >> 13; + h *= 0xc2b2ae35; + h ^= h >> 16; + + return h; + } + + //----------------------------------------------------------------------------- +public: + uint32_t uniform_hash(const void * key, size_t len, uint32_t seed) + { + const uint8_t * data = (const uint8_t*)key; + const int nblocks = (int)len / 4; + + uint32_t h1 = seed; + + const uint32_t c1 = 0xcc9e2d51; + const uint32_t c2 = 0x1b873593; + + // --- body + const uint32_t * blocks = (const uint32_t *)(data + nblocks * 4); + + for (int i = -nblocks; i; i++) { + uint32_t k1 = getblock(blocks, i); + + k1 *= c1; + k1 = ROTL32(k1, 15); + k1 *= c2; + + h1 ^= k1; + h1 = ROTL32(h1, 13); + h1 = h1 * 5 + 0xe6546b64; + } + + // --- tail + const uint8_t * tail = (const uint8_t*)(data + nblocks * 4); + + uint32_t k1 = 0; + + switch (len & 3) { + case 3: k1 ^= tail[2] << 16; + case 2: k1 ^= tail[1] << 8; + case 1: k1 ^= tail[0]; + k1 *= c1; k1 = ROTL32(k1, 15); k1 *= c2; h1 ^= k1; + } + + // --- finalization + h1 ^= len; + + return fmix(h1); + } +}; + +class HashUtils +{ +public: + static u64 Compute_Id_Hash(const std::string& unique_id) + { + size_t ret = 0; + const char *p = unique_id.c_str(); + while (*p != '\0') + if (*p >= '0' && *p <= '9') + ret = 10 * ret + *(p++) - '0'; + else + { + murmur_hash foo; + return foo.uniform_hash(unique_id.c_str(), unique_id.size(), 0); + } + return ret; + } +}; + +const size_t max_int = 100000; +const float max_float = max_int; +const float min_float = 0.00001f; +const size_t max_digits = (size_t) roundf((float) (-log(min_float) / log(10.))); + +class NumberUtils +{ +public: + static void Float_To_String(float f, char* str) + { + int x = (int)f; + int d = (int)(fabs(f - x) * 100000); + sprintf(str, "%d.%05d", x, d); + } + + template<bool trailing_zeros> + static void print_mantissa(char*& begin, float f) + { // helper for print_float + char values[10]; + size_t v = (size_t)f; + size_t digit = 0; + size_t first_nonzero = 0; + for (size_t max = 1; max <= v; ++digit) + { + size_t max_next = max * 10; + char v_mod = (char) (v % max_next / max); + if (!trailing_zeros && v_mod != '\0' && first_nonzero == 0) + first_nonzero = digit; + values[digit] = '0' + v_mod; + max = max_next; + } + if (!trailing_zeros) + for (size_t i = max_digits; i > digit; i--) + *begin++ = '0'; + while (digit > first_nonzero) + *begin++ = values[--digit]; + } + + static void print_float(char* begin, float f) + { + bool sign = false; + if (f < 0.f) + sign = true; + float unsigned_f = fabsf(f); + if (unsigned_f < max_float && unsigned_f > min_float) + { + if (sign) + *begin++ = '-'; + print_mantissa<true>(begin, unsigned_f); + unsigned_f -= (size_t)unsigned_f; + unsigned_f *= max_int; + if (unsigned_f >= 1.f) + { + *begin++ = '.'; + print_mantissa<false>(begin, unsigned_f); + } + } + else if (unsigned_f == 0.) + *begin++ = '0'; + else + { + sprintf(begin, "%g", f); + return; + } + *begin = '\0'; + return; + } +}; + +//A quick implementation similar to drand48 for cross-platform compatibility +namespace PRG { + const uint64_t a = 0xeece66d5deece66dULL; + const uint64_t c = 2147483647; + + const int bias = 127 << 23; + + union int_float { + int32_t i; + float f; + }; + + struct prg { + private: + uint64_t v; + public: + prg() { v = c; } + prg(uint64_t initial) { v = initial; } + + float merand48(uint64_t& initial) + { + initial = a * initial + c; + int_float temp; + temp.i = ((initial >> 25) & 0x7FFFFF) | bias; + return temp.f - 1; + } + + float Uniform_Unit_Interval() + { + return merand48(v); + } + + uint32_t Uniform_Int(uint32_t low, uint32_t high) + { + merand48(v); + uint32_t ret = low + ((v >> 25) % (high - low + 1)); + return ret; + } + }; +} +} /*! @} End of Doxygen Groups*/ diff --git a/explore/tests/MWTExploreTests.h b/explore/tests/MWTExploreTests.h index 447a368f..1c451fb4 100644 --- a/explore/tests/MWTExploreTests.h +++ b/explore/tests/MWTExploreTests.h @@ -1,164 +1,164 @@ -#pragma once
-
-#include "MWTExplorer.h"
-#include "utility.h"
-#include <iomanip>
-#include <iostream>
-#include <sstream>
-
-using namespace MultiWorldTesting;
-
-class TestContext
-{
-
-};
-
-template <class Ctx>
-struct TestInteraction
-{
- Ctx& Context;
- u32 Action;
- float Probability;
- string Unique_Key;
-};
-
-class TestPolicy : public IPolicy<TestContext>
-{
-public:
- TestPolicy(int params, int num_actions) : m_params(params), m_num_actions(num_actions) { }
- u32 Choose_Action(TestContext& context)
- {
- return m_params % m_num_actions + 1; // action id is one-based
- }
-private:
- int m_params;
- int m_num_actions;
-};
-
-class TestScorer : public IScorer<TestContext>
-{
-public:
- TestScorer(int params, int num_actions, bool uniform = true) :
- m_params(params), m_num_actions(num_actions), m_uniform(uniform)
- { }
-
- vector<float> Score_Actions(TestContext& context)
- {
- vector<float> scores;
- if (m_uniform)
- {
- for (u32 i = 0; i < m_num_actions; i++)
- {
- scores.push_back(m_params);
- }
- }
- else
- {
- for (u32 i = 0; i < m_num_actions; i++)
- {
- scores.push_back(m_params + i);
- }
- }
- return scores;
- }
-private:
- int m_params;
- int m_num_actions;
- bool m_uniform;
-};
-
-class FixedScorer : public IScorer<TestContext>
-{
-public:
- FixedScorer(int num_actions, int value) :
- m_num_actions(num_actions), m_value(value)
- { }
-
- vector<float> Score_Actions(TestContext& context)
- {
- vector<float> scores;
- for (u32 i = 0; i < m_num_actions; i++)
- {
- scores.push_back((float)m_value);
- }
- return scores;
- }
-private:
- int m_num_actions;
- int m_value;
-};
-
-class TestSimpleScorer : public IScorer<SimpleContext>
-{
-public:
- TestSimpleScorer(int params, int num_actions) : m_params(params), m_num_actions(num_actions) { }
- vector<float> Score_Actions(SimpleContext& context)
- {
- vector<float> scores;
- for (u32 i = 0; i < m_num_actions; i++)
- {
- scores.push_back(m_params);
- }
- return scores;
- }
-private:
- int m_params;
- int m_num_actions;
-};
-
-class TestSimplePolicy : public IPolicy<SimpleContext>
-{
-public:
- TestSimplePolicy(int params, int num_actions) : m_params(params), m_num_actions(num_actions) { }
- u32 Choose_Action(SimpleContext& context)
- {
- return m_params % m_num_actions + 1; // action id is one-based
- }
-private:
- int m_params;
- int m_num_actions;
-};
-
-class TestSimpleRecorder : public IRecorder<SimpleContext>
-{
-public:
- virtual void Record(SimpleContext& context, u32 action, float probability, string unique_key)
- {
- m_interactions.push_back({ context, action, probability, unique_key });
- }
-
- vector<TestInteraction<SimpleContext>> Get_All_Interactions()
- {
- return m_interactions;
- }
-
-private:
- vector<TestInteraction<SimpleContext>> m_interactions;
-};
-
-// Return action outside valid range
-class TestBadPolicy : public IPolicy<TestContext>
-{
-public:
- u32 Choose_Action(TestContext& context)
- {
- return 100;
- }
-};
-
-class TestRecorder : public IRecorder<TestContext>
-{
-public:
- virtual void Record(TestContext& context, u32 action, float probability, string unique_key)
- {
- m_interactions.push_back({ context, action, probability, unique_key });
- }
-
- vector<TestInteraction<TestContext>> Get_All_Interactions()
- {
- return m_interactions;
- }
-
-private:
- vector<TestInteraction<TestContext>> m_interactions;
-};
+#pragma once + +#include "MWTExplorer.h" +#include "utility.h" +#include <iomanip> +#include <iostream> +#include <sstream> + +using namespace MultiWorldTesting; + +class TestContext +{ + +}; + +template <class Ctx> +struct TestInteraction +{ + Ctx& Context; + u32 Action; + float Probability; + string Unique_Key; +}; + +class TestPolicy : public IPolicy<TestContext> +{ +public: + TestPolicy(int params, int num_actions) : m_params(params), m_num_actions(num_actions) { } + u32 Choose_Action(TestContext& context) + { + return m_params % m_num_actions + 1; // action id is one-based + } +private: + int m_params; + int m_num_actions; +}; + +class TestScorer : public IScorer<TestContext> +{ +public: + TestScorer(int params, int num_actions, bool uniform = true) : + m_params(params), m_num_actions(num_actions), m_uniform(uniform) + { } + + vector<float> Score_Actions(TestContext& context) + { + vector<float> scores; + if (m_uniform) + { + for (u32 i = 0; i < m_num_actions; i++) + { + scores.push_back(m_params); + } + } + else + { + for (u32 i = 0; i < m_num_actions; i++) + { + scores.push_back(m_params + i); + } + } + return scores; + } +private: + int m_params; + int m_num_actions; + bool m_uniform; +}; + +class FixedScorer : public IScorer<TestContext> +{ +public: + FixedScorer(int num_actions, int value) : + m_num_actions(num_actions), m_value(value) + { } + + vector<float> Score_Actions(TestContext& context) + { + vector<float> scores; + for (u32 i = 0; i < m_num_actions; i++) + { + scores.push_back((float)m_value); + } + return scores; + } +private: + int m_num_actions; + int m_value; +}; + +class TestSimpleScorer : public IScorer<SimpleContext> +{ +public: + TestSimpleScorer(int params, int num_actions) : m_params(params), m_num_actions(num_actions) { } + vector<float> Score_Actions(SimpleContext& context) + { + vector<float> scores; + for (u32 i = 0; i < m_num_actions; i++) + { + scores.push_back(m_params); + } + return scores; + } +private: + int m_params; + int m_num_actions; +}; + +class TestSimplePolicy : public IPolicy<SimpleContext> +{ +public: + TestSimplePolicy(int params, int num_actions) : m_params(params), m_num_actions(num_actions) { } + u32 Choose_Action(SimpleContext& context) + { + return m_params % m_num_actions + 1; // action id is one-based + } +private: + int m_params; + int m_num_actions; +}; + +class TestSimpleRecorder : public IRecorder<SimpleContext> +{ +public: + virtual void Record(SimpleContext& context, u32 action, float probability, string unique_key) + { + m_interactions.push_back({ context, action, probability, unique_key }); + } + + vector<TestInteraction<SimpleContext>> Get_All_Interactions() + { + return m_interactions; + } + +private: + vector<TestInteraction<SimpleContext>> m_interactions; +}; + +// Return action outside valid range +class TestBadPolicy : public IPolicy<TestContext> +{ +public: + u32 Choose_Action(TestContext& context) + { + return 100; + } +}; + +class TestRecorder : public IRecorder<TestContext> +{ +public: + virtual void Record(TestContext& context, u32 action, float probability, string unique_key) + { + m_interactions.push_back({ context, action, probability, unique_key }); + } + + vector<TestInteraction<TestContext>> Get_All_Interactions() + { + return m_interactions; + } + +private: + vector<TestInteraction<TestContext>> m_interactions; +}; diff --git a/library/Makefile b/library/Makefile index 8f7deb35..5ca18f35 100644 --- a/library/Makefile +++ b/library/Makefile @@ -40,3 +40,4 @@ gd_mf_weights: gd_mf_weights.cc ../vowpalwabbit/libvw.a clean: rm -f *.o ezexample_predict ezexample_train library_example test_search recommend ezexample_predict_threaded +.PHONY: all clean diff --git a/library/ezexample_predict.cc b/library/ezexample_predict.cc index db061f61..22c8a86d 100644 --- a/library/ezexample_predict.cc +++ b/library/ezexample_predict.cc @@ -1,55 +1,55 @@ -#include <stdio.h>
-#include "../vowpalwabbit/parser.h"
-#include "../vowpalwabbit/vw.h"
-#include "../vowpalwabbit/ezexample.h"
-
-using namespace std;
-
-int main(int argc, char *argv[])
-{
- string init_string = "-t -q st --hash all --noconstant --ldf_override s -i ";
- if (argc > 1)
- init_string += argv[1];
- else
- init_string += "train.w";
-
- cerr << "initializing with: '" << init_string << "'" << endl;
-
- // INITIALIZE WITH WHATEVER YOU WOULD PUT ON THE VW COMMAND LINE -- THIS READS IN A MODEL FROM train.w
- vw* vw = VW::initialize(init_string); // "-t -q st --hash all --noconstant --ldf_override s -i train.w");
-
- {
- // HAL'S SPIFFY INTERFACE USING C++ CRAZINESS
- ezexample ex(vw, false); // don't need multiline
- ex(vw_namespace('s'))
- ("p^the_man")
- ("w^the")
- ("w^man")
- (vw_namespace('t'))
- ("p^le_homme")
- ("w^le")
- ("w^homme");
- ex.set_label("1");
- cerr << ex.predict_partial() << endl;
-
- // ex.clear_features();
-
- --ex; // remove the most recent namespace
- ex(vw_namespace('t'))
- ("p^un_homme")
- ("w^un")
- ("w^homme");
- ex.set_label("2");
- cerr << ex.predict_partial() << endl;
-
- --ex; // remove the most recent namespace, and add features with explicit ns
- ex('t', "p^un_homme")
- ('t', "w^un")
- ('t', "w^homme");
- ex.set_label("2");
- cerr << ex.predict_partial() << endl;
- }
-
- // AND FINISH UP
- VW::finish(*vw);
-}
+#include <stdio.h> +#include "../vowpalwabbit/parser.h" +#include "../vowpalwabbit/vw.h" +#include "../vowpalwabbit/ezexample.h" + +using namespace std; + +int main(int argc, char *argv[]) +{ + string init_string = "-t -q st --hash all --noconstant --ldf_override s -i "; + if (argc > 1) + init_string += argv[1]; + else + init_string += "train.w"; + + cerr << "initializing with: '" << init_string << "'" << endl; + + // INITIALIZE WITH WHATEVER YOU WOULD PUT ON THE VW COMMAND LINE -- THIS READS IN A MODEL FROM train.w + vw* vw = VW::initialize(init_string); // "-t -q st --hash all --noconstant --ldf_override s -i train.w"); + + { + // HAL'S SPIFFY INTERFACE USING C++ CRAZINESS + ezexample ex(vw, false); // don't need multiline + ex(vw_namespace('s')) + ("p^the_man") + ("w^the") + ("w^man") + (vw_namespace('t')) + ("p^le_homme") + ("w^le") + ("w^homme"); + ex.set_label("1"); + cerr << ex.predict_partial() << endl; + + // ex.clear_features(); + + --ex; // remove the most recent namespace + ex(vw_namespace('t')) + ("p^un_homme") + ("w^un") + ("w^homme"); + ex.set_label("2"); + cerr << ex.predict_partial() << endl; + + --ex; // remove the most recent namespace, and add features with explicit ns + ex('t', "p^un_homme") + ('t', "w^un") + ('t', "w^homme"); + ex.set_label("2"); + cerr << ex.predict_partial() << endl; + } + + // AND FINISH UP + VW::finish(*vw); +} diff --git a/library/ezexample_predict_threaded.cc b/library/ezexample_predict_threaded.cc index 0fa5b1e6..c7c39e85 100644 --- a/library/ezexample_predict_threaded.cc +++ b/library/ezexample_predict_threaded.cc @@ -1,149 +1,149 @@ -#include <stdio.h>
-#include "../vowpalwabbit/vw.h"
-#include "../vowpalwabbit/ezexample.h"
-
-#include <boost/thread/thread.hpp>
-
-using namespace std;
-
-int runcount = 100;
-
-class Worker
-{
-public:
- Worker(vw & instance, string & vw_init_string, vector<double> & ref)
- : m_vw(instance)
- , m_referenceValues(ref)
- , vw_init_string(vw_init_string)
- { }
-
- void operator()()
- {
- m_vw_parser = VW::initialize(vw_init_string);
- if (m_vw_parser == NULL) {
- cerr << "cannot initialize vw parser" << endl;
- exit(-1);
- }
-
- int errorCount = 0;
- for (int i = 0; i < runcount; ++i)
- {
- vector<double>::iterator it = m_referenceValues.begin();
- ezexample ex(&m_vw, false, m_vw_parser);
-
- ex(vw_namespace('s'))
- ("p^the_man")
- ("w^the")
- ("w^man")
- (vw_namespace('t'))
- ("p^le_homme")
- ("w^le")
- ("w^homme");
- ex.set_label("1");
- if (*it != ex()) { cerr << "fail!" << endl; ++errorCount; }
- //if (*it != pred) { cerr << "fail!" << endl; ++errorCount; }
- //VW::finish_example(m_vw, vec2);
- ++it;
-
- --ex; // remove the most recent namespace
- ex(vw_namespace('t'))
- ("p^un_homme")
- ("w^un")
- ("w^homme");
- ex.set_label("1");
- if (*it != ex()) { cerr << "fail!" << endl; ++errorCount; }
- ++it;
-
- --ex; // remove the most recent namespace
- // add features with explicit ns
- ex('t', "p^un_homme")
- ('t', "w^un")
- ('t', "w^homme");
- ex.set_label("1");
- if (*it != ex()) { cerr << "fail!" << endl; ++errorCount; }
- ++it;
-
- //cout << "."; cout.flush();
- }
- cerr << "error count = " << errorCount << endl;
- VW::finish(*m_vw_parser);
- m_vw_parser = NULL;
- }
-
-private:
- vw & m_vw;
- vw * m_vw_parser;
- vector<double> & m_referenceValues;
- string & vw_init_string;
-};
-
-int main(int argc, char *argv[])
-{
- if (argc != 3)
- {
- cerr << "need two args: threadcount runcount" << endl;
- return 1;
- }
- int threadcount = atoi(argv[1]);
- runcount = atoi(argv[2]);
- // INITIALIZE WITH WHATEVER YOU WOULD PUT ON THE VW COMMAND LINE -- THIS READS IN A MODEL FROM train.w
- string vw_init_string_all = "-t --ldf_override s --quiet -q st --noconstant --hash all -i train.w";
- string vw_init_string_parser = "-t --ldf_override s --quiet -q st --noconstant --hash all --noop"; // this needs to have enough arguments to get the parser right
- vw*vw = VW::initialize(vw_init_string_all);
- vector<double> results;
-
- // HAL'S SPIFFY INTERFACE USING C++ CRAZINESS
- {
- ezexample ex(vw, false);
- ex(vw_namespace('s'))
- ("p^the_man")
- ("w^the")
- ("w^man")
- (vw_namespace('t'))
- ("p^le_homme")
- ("w^le")
- ("w^homme");
- ex.set_label("1");
- results.push_back(ex.predict_partial());
- cerr << "should be near zero = " << ex.predict_partial() << endl;
-
- --ex; // remove the most recent namespace
- ex(vw_namespace('t'))
- ("p^un_homme")
- ("w^un")
- ("w^homme");
- ex.set_label("1");
- results.push_back(ex.predict_partial());
- cerr << "should be near one = " << ex.predict_partial() << endl;
-
- --ex; // remove the most recent namespace
- // add features with explicit ns
- ex('t', "p^un_homme")
- ('t', "w^un")
- ('t', "w^homme");
- ex.set_label("1");
- results.push_back(ex.predict_partial());
- cerr << "should be near one = " << ex.predict_partial() << endl;
- }
-
- if (threadcount == 0)
- {
- Worker w(*vw, vw_init_string_parser, results);
- w();
- }
- else
- {
- boost::thread_group tg;
- for (int t = 0; t < threadcount; ++t)
- {
- cerr << "starting thread " << t << endl;
- boost::thread * pt = tg.create_thread(Worker(*vw, vw_init_string_parser, results));
- }
- tg.join_all();
- cerr << "finished!" << endl;
- }
-
-
- // AND FINISH UP
- VW::finish(*vw);
-}
+#include <stdio.h> +#include "../vowpalwabbit/vw.h" +#include "../vowpalwabbit/ezexample.h" + +#include <boost/thread/thread.hpp> + +using namespace std; + +int runcount = 100; + +class Worker +{ +public: + Worker(vw & instance, string & vw_init_string, vector<double> & ref) + : m_vw(instance) + , m_referenceValues(ref) + , vw_init_string(vw_init_string) + { } + + void operator()() + { + m_vw_parser = VW::initialize(vw_init_string); + if (m_vw_parser == NULL) { + cerr << "cannot initialize vw parser" << endl; + exit(-1); + } + + int errorCount = 0; + for (int i = 0; i < runcount; ++i) + { + vector<double>::iterator it = m_referenceValues.begin(); + ezexample ex(&m_vw, false, m_vw_parser); + + ex(vw_namespace('s')) + ("p^the_man") + ("w^the") + ("w^man") + (vw_namespace('t')) + ("p^le_homme") + ("w^le") + ("w^homme"); + ex.set_label("1"); + if (*it != ex()) { cerr << "fail!" << endl; ++errorCount; } + //if (*it != pred) { cerr << "fail!" << endl; ++errorCount; } + //VW::finish_example(m_vw, vec2); + ++it; + + --ex; // remove the most recent namespace + ex(vw_namespace('t')) + ("p^un_homme") + ("w^un") + ("w^homme"); + ex.set_label("1"); + if (*it != ex()) { cerr << "fail!" << endl; ++errorCount; } + ++it; + + --ex; // remove the most recent namespace + // add features with explicit ns + ex('t', "p^un_homme") + ('t', "w^un") + ('t', "w^homme"); + ex.set_label("1"); + if (*it != ex()) { cerr << "fail!" << endl; ++errorCount; } + ++it; + + //cout << "."; cout.flush(); + } + cerr << "error count = " << errorCount << endl; + VW::finish(*m_vw_parser); + m_vw_parser = NULL; + } + +private: + vw & m_vw; + vw * m_vw_parser; + vector<double> & m_referenceValues; + string & vw_init_string; +}; + +int main(int argc, char *argv[]) +{ + if (argc != 3) + { + cerr << "need two args: threadcount runcount" << endl; + return 1; + } + int threadcount = atoi(argv[1]); + runcount = atoi(argv[2]); + // INITIALIZE WITH WHATEVER YOU WOULD PUT ON THE VW COMMAND LINE -- THIS READS IN A MODEL FROM train.w + string vw_init_string_all = "-t --ldf_override s --quiet -q st --noconstant --hash all -i train.w"; + string vw_init_string_parser = "-t --ldf_override s --quiet -q st --noconstant --hash all --noop"; // this needs to have enough arguments to get the parser right + vw*vw = VW::initialize(vw_init_string_all); + vector<double> results; + + // HAL'S SPIFFY INTERFACE USING C++ CRAZINESS + { + ezexample ex(vw, false); + ex(vw_namespace('s')) + ("p^the_man") + ("w^the") + ("w^man") + (vw_namespace('t')) + ("p^le_homme") + ("w^le") + ("w^homme"); + ex.set_label("1"); + results.push_back(ex.predict_partial()); + cerr << "should be near zero = " << ex.predict_partial() << endl; + + --ex; // remove the most recent namespace + ex(vw_namespace('t')) + ("p^un_homme") + ("w^un") + ("w^homme"); + ex.set_label("1"); + results.push_back(ex.predict_partial()); + cerr << "should be near one = " << ex.predict_partial() << endl; + + --ex; // remove the most recent namespace + // add features with explicit ns + ex('t', "p^un_homme") + ('t', "w^un") + ('t', "w^homme"); + ex.set_label("1"); + results.push_back(ex.predict_partial()); + cerr << "should be near one = " << ex.predict_partial() << endl; + } + + if (threadcount == 0) + { + Worker w(*vw, vw_init_string_parser, results); + w(); + } + else + { + boost::thread_group tg; + for (int t = 0; t < threadcount; ++t) + { + cerr << "starting thread " << t << endl; + boost::thread * pt = tg.create_thread(Worker(*vw, vw_init_string_parser, results)); + } + tg.join_all(); + cerr << "finished!" << endl; + } + + + // AND FINISH UP + VW::finish(*vw); +} diff --git a/library/ezexample_train.cc b/library/ezexample_train.cc index a0f66a99..9a8af8e0 100644 --- a/library/ezexample_train.cc +++ b/library/ezexample_train.cc @@ -1,72 +1,72 @@ -#include <stdio.h>
-#include "../vowpalwabbit/parser.h"
-#include "../vowpalwabbit/vw.h"
-#include "../vowpalwabbit/ezexample.h"
-
-using namespace std;
-
-void run(vw*vw) {
- ezexample ex(vw, true); // we're doing csoaa_ldf so we need multiline examples
-
- /// BEGIN FIRST MULTILINE EXAMPLE
- ex(vw_namespace('s'))
- ("p^the_man")
- ("w^the")
- ("w^man")
- (vw_namespace('t'))
- ("p^un_homme")
- ("w^un")
- ("w^homme");
-
- ex.set_label("1:1");
- ex.train();
-
- --ex; // remove the most recent namespace
- ex(vw_namespace('t'))
- ("p^le_homme")
- ("w^le")
- ("w^homme");
-
- ex.set_label("2:0");
- ex.train();
-
- // push it through VW for training
- ex.finish();
-
- /// BEGIN SECOND MULTILINE EXAMPLE
- ex(vw_namespace('s'))
- ("p^a_man")
- ("w^a")
- ("w^man")
- (vw_namespace('t'))
- ("p^un_homme")
- ("w^un")
- ("w^homme");
-
- ex.set_label("1:0");
- ex.train();
-
- --ex; // remove the most recent namespace
- ex(vw_namespace('t'))
- ("p^le_homme")
- ("w^le")
- ("w^homme");
-
- ex.set_label("2:1");
- ex.train();
-
- // push it through VW for training
- ex.finish();
-}
-
-int main(int argc, char *argv[])
-{
- // INITIALIZE WITH WHATEVER YOU WOULD PUT ON THE VW COMMAND LINE -- THIS WILL STORE A MODEL TO train.ezw
- vw* vw = VW::initialize("--hash all -q st --noconstant -f train.w --quiet --csoaa_ldf m");
-
- run(vw);
-
- // AND FINISH UP
- cerr << "ezexample_train finish"<<endl;
- VW::finish(*vw);
-}
+#include <stdio.h> +#include "../vowpalwabbit/parser.h" +#include "../vowpalwabbit/vw.h" +#include "../vowpalwabbit/ezexample.h" + +using namespace std; + +void run(vw*vw) { + ezexample ex(vw, true); // we're doing csoaa_ldf so we need multiline examples + + /// BEGIN FIRST MULTILINE EXAMPLE + ex(vw_namespace('s')) + ("p^the_man") + ("w^the") + ("w^man") + (vw_namespace('t')) + ("p^un_homme") + ("w^un") + ("w^homme"); + + ex.set_label("1:1"); + ex.train(); + + --ex; // remove the most recent namespace + ex(vw_namespace('t')) + ("p^le_homme") + ("w^le") + ("w^homme"); + + ex.set_label("2:0"); + ex.train(); + + // push it through VW for training + ex.finish(); + + /// BEGIN SECOND MULTILINE EXAMPLE + ex(vw_namespace('s')) + ("p^a_man") + ("w^a") + ("w^man") + (vw_namespace('t')) + ("p^un_homme") + ("w^un") + ("w^homme"); + + ex.set_label("1:0"); + ex.train(); + + --ex; // remove the most recent namespace + ex(vw_namespace('t')) + ("p^le_homme") + ("w^le") + ("w^homme"); + + ex.set_label("2:1"); + ex.train(); + + // push it through VW for training + ex.finish(); +} + +int main(int argc, char *argv[]) +{ + // INITIALIZE WITH WHATEVER YOU WOULD PUT ON THE VW COMMAND LINE -- THIS WILL STORE A MODEL TO train.ezw + vw* vw = VW::initialize("--hash all -q st --noconstant -f train.w --quiet --csoaa_ldf m"); + + run(vw); + + // AND FINISH UP + cerr << "ezexample_train finish"<<endl; + VW::finish(*vw); +} diff --git a/library/library_example.cc b/library/library_example.cc index 8abfc1f7..d7c186c2 100644 --- a/library/library_example.cc +++ b/library/library_example.cc @@ -1,62 +1,62 @@ -#include <stdio.h>
-#include "../vowpalwabbit/parser.h"
-#include "../vowpalwabbit/vw.h"
-
-using namespace std;
-
-
-inline feature vw_feature_from_string(vw& v, string fstr, unsigned long seed, float val)
-{
- uint32_t foo = VW::hash_feature(v, fstr, seed);
- feature f = { val, foo};
- return f;
-}
-
-int main(int argc, char *argv[])
-{
- vw* model = VW::initialize("--hash all -q st --noconstant -i train.w -f train2.vw");
-
- example *vec2 = VW::read_example(*model, (char*)"|s p^the_man w^the w^man |t p^un_homme w^un w^homme");
- model->learn(vec2);
- cerr << "p2 = " << vec2->pred.scalar << endl;
- VW::finish_example(*model, vec2);
-
- vector< VW::feature_space > ec_info;
- vector<feature> s_features, t_features;
- uint32_t s_hash = VW::hash_space(*model, "s");
- uint32_t t_hash = VW::hash_space(*model, "t");
- s_features.push_back( vw_feature_from_string(*model, "p^the_man", s_hash, 1.0) );
- s_features.push_back( vw_feature_from_string(*model, "w^the", s_hash, 1.0) );
- s_features.push_back( vw_feature_from_string(*model, "w^man", s_hash, 1.0) );
- t_features.push_back( vw_feature_from_string(*model, "p^le_homme", t_hash, 1.0) );
- t_features.push_back( vw_feature_from_string(*model, "w^le", t_hash, 1.0) );
- t_features.push_back( vw_feature_from_string(*model, "w^homme", t_hash, 1.0) );
- ec_info.push_back( VW::feature_space('s', s_features) );
- ec_info.push_back( VW::feature_space('t', t_features) );
- example* vec3 = VW::import_example(*model, ec_info);
-
- model->learn(vec3);
- cerr << "p3 = " << vec3->pred.scalar << endl;
- VW::finish_example(*model, vec3);
-
- VW::finish(*model);
-
- vw* model2 = VW::initialize("--hash all -q st --noconstant -i train2.vw");
- vec2 = VW::read_example(*model2, (char*)" |s p^the_man w^the w^man |t p^un_homme w^un w^homme");
- model2->learn(vec2);
- cerr << "p4 = " << vec2->pred.scalar << endl;
-
- size_t len=0;
- VW::primitive_feature_space* pfs = VW::export_example(*model2, vec2, len);
- for (size_t i = 0; i < len; i++)
- {
- cout << "namespace = " << pfs[i].name;
- for (size_t j = 0; j < pfs[i].len; j++)
- cout << " " << pfs[i].fs[j].weight_index << ":" << pfs[i].fs[j].x << ":" << VW::get_weight(*model2, pfs[i].fs[j].weight_index, 0);
- cout << endl;
- }
-
- VW::finish_example(*model2, vec2);
- VW::finish(*model2);
-}
-
+#include <stdio.h> +#include "../vowpalwabbit/parser.h" +#include "../vowpalwabbit/vw.h" + +using namespace std; + + +inline feature vw_feature_from_string(vw& v, string fstr, unsigned long seed, float val) +{ + uint32_t foo = VW::hash_feature(v, fstr, seed); + feature f = { val, foo}; + return f; +} + +int main(int argc, char *argv[]) +{ + vw* model = VW::initialize("--hash all -q st --noconstant -i train.w -f train2.vw"); + + example *vec2 = VW::read_example(*model, (char*)"|s p^the_man w^the w^man |t p^un_homme w^un w^homme"); + model->learn(vec2); + cerr << "p2 = " << vec2->pred.scalar << endl; + VW::finish_example(*model, vec2); + + vector< VW::feature_space > ec_info; + vector<feature> s_features, t_features; + uint32_t s_hash = VW::hash_space(*model, "s"); + uint32_t t_hash = VW::hash_space(*model, "t"); + s_features.push_back( vw_feature_from_string(*model, "p^the_man", s_hash, 1.0) ); + s_features.push_back( vw_feature_from_string(*model, "w^the", s_hash, 1.0) ); + s_features.push_back( vw_feature_from_string(*model, "w^man", s_hash, 1.0) ); + t_features.push_back( vw_feature_from_string(*model, "p^le_homme", t_hash, 1.0) ); + t_features.push_back( vw_feature_from_string(*model, "w^le", t_hash, 1.0) ); + t_features.push_back( vw_feature_from_string(*model, "w^homme", t_hash, 1.0) ); + ec_info.push_back( VW::feature_space('s', s_features) ); + ec_info.push_back( VW::feature_space('t', t_features) ); + example* vec3 = VW::import_example(*model, ec_info); + + model->learn(vec3); + cerr << "p3 = " << vec3->pred.scalar << endl; + VW::finish_example(*model, vec3); + + VW::finish(*model); + + vw* model2 = VW::initialize("--hash all -q st --noconstant -i train2.vw"); + vec2 = VW::read_example(*model2, (char*)" |s p^the_man w^the w^man |t p^un_homme w^un w^homme"); + model2->learn(vec2); + cerr << "p4 = " << vec2->pred.scalar << endl; + + size_t len=0; + VW::primitive_feature_space* pfs = VW::export_example(*model2, vec2, len); + for (size_t i = 0; i < len; i++) + { + cout << "namespace = " << pfs[i].name; + for (size_t j = 0; j < pfs[i].len; j++) + cout << " " << pfs[i].fs[j].weight_index << ":" << pfs[i].fs[j].x << ":" << VW::get_weight(*model2, pfs[i].fs[j].weight_index, 0); + cout << endl; + } + + VW::finish_example(*model2, vec2); + VW::finish(*model2); +} + diff --git a/python/Makefile b/python/Makefile index b268663d..845fb15f 100644 --- a/python/Makefile +++ b/python/Makefile @@ -43,3 +43,5 @@ pylibvw.o: pylibvw.cc clean: rm -f *.o $(PYLIBVW) + +.PHONY: all clean things diff --git a/vowpalwabbit/Makefile b/vowpalwabbit/Makefile index 79893902..9ecd681c 100644 --- a/vowpalwabbit/Makefile +++ b/vowpalwabbit/Makefile @@ -52,3 +52,4 @@ install: $(BINARIES) clean: rm -f *.o *.d $(BINARIES) *~ $(MANPAGES) libvw.a +.PHONY: all clean install test things diff --git a/vowpalwabbit/accumulate.h b/vowpalwabbit/accumulate.h index c01ac5fe..4d507a60 100644 --- a/vowpalwabbit/accumulate.h +++ b/vowpalwabbit/accumulate.h @@ -1,13 +1,13 @@ -/*
-Copyright (c) by respective owners including Yahoo!, Microsoft, and
-individual contributors. All rights reserved. Released under a BSD
-license as described in the file LICENSE.
- */
-//This implements various accumulate functions building on top of allreduce.
-#pragma once
-#include "global_data.h"
-
-void accumulate(vw& all, std::string master_location, regressor& reg, size_t o);
-float accumulate_scalar(vw& all, std::string master_location, float local_sum);
-void accumulate_weighted_avg(vw& all, std::string master_location, regressor& reg);
-void accumulate_avg(vw& all, std::string master_location, regressor& reg, size_t o);
+/* +Copyright (c) by respective owners including Yahoo!, Microsoft, and +individual contributors. All rights reserved. Released under a BSD +license as described in the file LICENSE. + */ +//This implements various accumulate functions building on top of allreduce. +#pragma once +#include "global_data.h" + +void accumulate(vw& all, std::string master_location, regressor& reg, size_t o); +float accumulate_scalar(vw& all, std::string master_location, float local_sum); +void accumulate_weighted_avg(vw& all, std::string master_location, regressor& reg); +void accumulate_avg(vw& all, std::string master_location, regressor& reg, size_t o); |