Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/moses-smt/vowpal_wabbit.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author: John Langford <jl@hunch.net> 2015-01-04 04:45:55 +0300
committer: John Langford <jl@hunch.net> 2015-01-04 04:45:55 +0300
commit: b913d178acb0052f4bb9622dd253a4ddc684a4ce (patch)
tree: 2a7c6348009dea6ee77caf67af897473fb2b77aa
parent: 00ce5188f90cc426493872e8d630b6c61f99df79 (diff)
parent: ed8c4d38aba3b49159f1b2574028b5cbae96a7f2 (diff)
resolve conflicts
-rw-r--r-- Makefile | 2
-rw-r--r-- demo/dna/Makefile | 2
-rw-r--r-- demo/entityrelation/Makefile | 5
-rw-r--r-- demo/movielens/Makefile | 4
-rw-r--r-- demo/normalized/Makefile | 26
-rw-r--r-- explore/clr/explore_clr_wrapper.cpp | 36
-rw-r--r-- explore/clr/explore_clr_wrapper.h | 874
-rw-r--r-- explore/clr/explore_interface.h | 174
-rw-r--r-- explore/clr/explore_interop.h | 714
-rw-r--r-- explore/explore.cpp | 146
-rw-r--r-- explore/static/MWTExplorer.h | 1338
-rw-r--r-- explore/static/utility.h | 566
-rw-r--r-- explore/tests/MWTExploreTests.h | 328
-rw-r--r-- library/Makefile | 1
-rw-r--r-- library/ezexample_predict.cc | 110
-rw-r--r-- library/ezexample_predict_threaded.cc | 298
-rw-r--r-- library/ezexample_train.cc | 144
-rw-r--r-- library/library_example.cc | 124
-rw-r--r-- python/Makefile | 2
-rw-r--r-- vowpalwabbit/Makefile | 1
-rw-r--r-- vowpalwabbit/accumulate.h | 26
21 files changed, 2466 insertions, 2455 deletions
diff --git a/Makefile b/Makefile
index 0d53bb54..1f1de74b 100644
--- a/Makefile
+++ b/Makefile
@@ -117,3 +117,5 @@ clean:
ifneq ($(JAVA_HOME),)
cd java && $(MAKE) clean
endif
+
+.PHONY: all clean install
diff --git a/demo/dna/Makefile b/demo/dna/Makefile
index 93efac3b..477cfb98 100644
--- a/demo/dna/Makefile
+++ b/demo/dna/Makefile
@@ -27,3 +27,5 @@ quaddna2vw: quaddna2vw.cpp
%.perf: %.test.predictions perf.check perl.check zsh.check
./do-perf $<
+
+.PHONY: all clean
diff --git a/demo/entityrelation/Makefile b/demo/entityrelation/Makefile
index c287d97c..c3db3a69 100644
--- a/demo/entityrelation/Makefile
+++ b/demo/entityrelation/Makefile
@@ -8,7 +8,7 @@ eval_script=./evaluationER.py
all:
@cat README.md
clean:
- rm -f *.model *.predictions ER_*.vw *.cache evaluationER.py er.zip *~
+ rm -f *.model *.predictions ER_*.vw *.cache evaluationER.py er.zip *~
%.check:
@test -x "$$(which $*)" || { \
@@ -33,5 +33,4 @@ er.test.predictions: ER_test.vw er.model
er.perf: ER_test.vw er.test.predictions python.check
@$(eval_script) $< er.test.predictions
-
-
+.PHONY: all clean
diff --git a/demo/movielens/Makefile b/demo/movielens/Makefile
index 68be5310..46c92e24 100644
--- a/demo/movielens/Makefile
+++ b/demo/movielens/Makefile
@@ -39,7 +39,7 @@ ml-%.ratings.train.vw: ml-%/ratings.dat
@printf "%s test MAE is %3.3f\n" $* $$(cat $*.results)
#---------------------------------------------------------------------
-# linear model (no interaction terms)
+# linear model (no interaction terms)
#---------------------------------------------------------------------
linear.results: ml-1m.ratings.test.vw ml-1m.ratings.train.vw
@@ -154,3 +154,5 @@ lrqdropouthogwild.results: ml-1m.ratings.test.vw ml-1m.ratings.train.vw do-lrq-h
-d $(word 1,$+) -p \
>(perl -lane '$$s+=abs(($$F[0]-$$F[1])); } { \
1; print $$s/$$.;' > $@)
+
+.PHONY: all clean shootout
diff --git a/demo/normalized/Makefile b/demo/normalized/Makefile
index 8268e82f..f1378b80 100644
--- a/demo/normalized/Makefile
+++ b/demo/normalized/Makefile
@@ -19,8 +19,8 @@ SHUFFLE='BEGIN { srand 69; }; \
$$b[$$i] = $$_; } { print grep { defined $$_ } @b;'
#---------------------------------------------------------------------
-# bank marketing
-#
+# bank marketing
+#
# normalization really helps. The columns have units of euros, seconds,
# days, and years; in addition there are categorical variables.
#---------------------------------------------------------------------
@@ -57,10 +57,10 @@ bankbestmin=1e-2
bankbestmax=10
banknonormbestmin=1e-8
banknonormbestmax=1e-3
-banktimeestimate=1
+banktimeestimate=1
#---------------------------------------------------------------------
-# covertype
+# covertype
#---------------------------------------------------------------------
covtype.data.gz:
@@ -88,11 +88,11 @@ covertypebestmin=1e-2
covertypebestmax=10
covertypenonormbestmin=1e-8
covertypenonormbestmax=1e-5
-covertypetimeestimate=5
+covertypetimeestimate=5
#---------------------------------------------------------------------
-# million song database
-#
+# million song database
+#
# normalization is helpful.
#---------------------------------------------------------------------
@@ -142,7 +142,7 @@ MSDnonormbestmax=1e-5
MSDtimeestimate=15
#---------------------------------------------------------------------
-# census-income (KDD)
+# census-income (KDD)
#---------------------------------------------------------------------
census-income.data.gz:
@@ -180,7 +180,7 @@ censusnonormbestmax=1e-4
censustimeestimate=5
#---------------------------------------------------------------------
-# Statlog (Shuttle)
+# Statlog (Shuttle)
#---------------------------------------------------------------------
shuttle.trn.Z:
@@ -211,7 +211,7 @@ shuttlenonormbestmax=1e-3
shuttletimeestimate=1
#---------------------------------------------------------------------
-# CT slices
+# CT slices
#
# normalization doesn't help much
#---------------------------------------------------------------------
@@ -256,7 +256,7 @@ CTslicenonormbestmax=1
CTslicetimeestimate=15
#---------------------------------------------------------------------
-# common routines
+# common routines
#---------------------------------------------------------------------
%.best: %.data
@@ -269,8 +269,10 @@ CTslicetimeestimate=15
@printf "WARNING: this step takes about %s minutes\n" $($*timeestimate)
@./hypersearch $($*nonormbestmin) $($*nonormbestmax) '$(MAKE)' '$*.%.nonormlearn' > $@
-%.resultsprint:
+%.resultsprint:
@printf "%20.20s\t%9.3g\t%9.3g\t%9.3g\t%9.3g\n" "$*" $$(cut -f1 $*.best) $$(cut -f2 $*.best) $$(cut -f1 $*.nonormbest) $$(cut -f2 $*.nonormbest)
only.%: %.best %.nonormbest all.results.pre %.resultsprint
@true
+
+.PHONY: all all.results all.results.pre
diff --git a/explore/clr/explore_clr_wrapper.cpp b/explore/clr/explore_clr_wrapper.cpp
index 0564482b..be022af1 100644
--- a/explore/clr/explore_clr_wrapper.cpp
+++ b/explore/clr/explore_clr_wrapper.cpp
@@ -1,18 +1,18 @@
-// vw_explore_clr_wrapper.cpp : Defines the exported functions for the DLL application.
-//
-
-#define WIN32_LEAN_AND_MEAN
-#include <Windows.h>
-
-#include "explore_clr_wrapper.h"
-
-using namespace System;
-using namespace System::Collections;
-using namespace System::Collections::Generic;
-using namespace System::Runtime::InteropServices;
-using namespace msclr::interop;
-using namespace NativeMultiWorldTesting;
-
-namespace MultiWorldTesting {
-
-}
+// vw_explore_clr_wrapper.cpp : Defines the exported functions for the DLL application.
+//
+
+#define WIN32_LEAN_AND_MEAN
+#include <Windows.h>
+
+#include "explore_clr_wrapper.h"
+
+using namespace System;
+using namespace System::Collections;
+using namespace System::Collections::Generic;
+using namespace System::Runtime::InteropServices;
+using namespace msclr::interop;
+using namespace NativeMultiWorldTesting;
+
+namespace MultiWorldTesting {
+
+}
diff --git a/explore/clr/explore_clr_wrapper.h b/explore/clr/explore_clr_wrapper.h
index a6cede4b..c73c95d4 100644
--- a/explore/clr/explore_clr_wrapper.h
+++ b/explore/clr/explore_clr_wrapper.h
@@ -1,438 +1,438 @@
-#pragma once
-#include "explore_interop.h"
-
-/*!
-* \addtogroup MultiWorldTestingCsharp
-* @{
-*/
-namespace MultiWorldTesting {
-
- /// <summary>
- /// The epsilon greedy exploration class.
- /// </summary>
- /// <remarks>
- /// This is a good choice if you have no idea which actions should be preferred.
- /// Epsilon greedy is also computationally cheap.
- /// </remarks>
- /// <typeparam name="Ctx">The Context type.</typeparam>
- generic <class Ctx>
- public ref class EpsilonGreedyExplorer : public IExplorer<Ctx>, public PolicyCallback<Ctx>
- {
- public:
- /// <summary>
- /// The constructor is the only public member, because this should be used with the MwtExplorer.
- /// </summary>
- /// <param name="defaultPolicy">A default function which outputs an action given a context.</param>
- /// <param name="epsilon">The probability of a random exploration.</param>
- /// <param name="numActions">The number of actions to randomize over.</param>
- EpsilonGreedyExplorer(IPolicy<Ctx>^ defaultPolicy, float epsilon, UInt32 numActions)
- {
- this->defaultPolicy = defaultPolicy;
- m_explorer = new NativeMultiWorldTesting::EpsilonGreedyExplorer<NativeContext>(*GetNativePolicy(), epsilon, (u32)numActions);
- }
-
- ~EpsilonGreedyExplorer()
- {
- delete m_explorer;
- }
-
- internal:
- virtual UInt32 InvokePolicyCallback(Ctx context, int index) override
- {
- return defaultPolicy->ChooseAction(context);
- }
-
- NativeMultiWorldTesting::EpsilonGreedyExplorer<NativeContext>* Get()
- {
- return m_explorer;
- }
-
- private:
- IPolicy<Ctx>^ defaultPolicy;
- NativeMultiWorldTesting::EpsilonGreedyExplorer<NativeContext>* m_explorer;
- };
-
- /// <summary>
- /// The tau-first exploration class.
- /// </summary>
- /// <remarks>
- /// The tau-first explorer collects precisely tau uniform random
- /// exploration events, and then uses the default policy.
- /// </remarks>
- /// <typeparam name="Ctx">The Context type.</typeparam>
- generic <class Ctx>
- public ref class TauFirstExplorer : public IExplorer<Ctx>, public PolicyCallback<Ctx>
- {
- public:
- /// <summary>
- /// The constructor is the only public member, because this should be used with the MwtExplorer.
- /// </summary>
- /// <param name="defaultPolicy">A default policy after randomization finishes.</param>
- /// <param name="tau">The number of events to be uniform over.</param>
- /// <param name="numActions">The number of actions to randomize over.</param>
- TauFirstExplorer(IPolicy<Ctx>^ defaultPolicy, UInt32 tau, UInt32 numActions)
- {
- this->defaultPolicy = defaultPolicy;
- m_explorer = new NativeMultiWorldTesting::TauFirstExplorer<NativeContext>(*GetNativePolicy(), tau, (u32)numActions);
- }
-
- ~TauFirstExplorer()
- {
- delete m_explorer;
- }
-
- internal:
- virtual UInt32 InvokePolicyCallback(Ctx context, int index) override
- {
- return defaultPolicy->ChooseAction(context);
- }
-
- NativeMultiWorldTesting::TauFirstExplorer<NativeContext>* Get()
- {
- return m_explorer;
- }
-
- private:
- IPolicy<Ctx>^ defaultPolicy;
- NativeMultiWorldTesting::TauFirstExplorer<NativeContext>* m_explorer;
- };
-
- /// <summary>
- /// The epsilon greedy exploration class.
- /// </summary>
- /// <remarks>
- /// In some cases, different actions have a different scores, and you
- /// would prefer to choose actions with large scores. Softmax allows
- /// you to do that.
- /// </remarks>
- /// <typeparam name="Ctx">The Context type.</typeparam>
- generic <class Ctx>
- public ref class SoftmaxExplorer : public IExplorer<Ctx>, public ScorerCallback<Ctx>
- {
- public:
- /// <summary>
- /// The constructor is the only public member, because this should be used with the MwtExplorer.
- /// </summary>
- /// <param name="defaultScorer">A function which outputs a score for each action.</param>
- /// <param name="lambda">lambda = 0 implies uniform distribution. Large lambda is equivalent to a max.</param>
- /// <param name="numActions">The number of actions to randomize over.</param>
- SoftmaxExplorer(IScorer<Ctx>^ defaultScorer, float lambda, UInt32 numActions)
- {
- this->defaultScorer = defaultScorer;
- m_explorer = new NativeMultiWorldTesting::SoftmaxExplorer<NativeContext>(*GetNativeScorer(), lambda, (u32)numActions);
- }
-
- ~SoftmaxExplorer()
- {
- delete m_explorer;
- }
-
- internal:
- virtual List<float>^ InvokeScorerCallback(Ctx context) override
- {
- return defaultScorer->ScoreActions(context);
- }
-
- NativeMultiWorldTesting::SoftmaxExplorer<NativeContext>* Get()
- {
- return m_explorer;
- }
-
- private:
- IScorer<Ctx>^ defaultScorer;
- NativeMultiWorldTesting::SoftmaxExplorer<NativeContext>* m_explorer;
- };
-
- /// <summary>
- /// The generic exploration class.
- /// </summary>
- /// <remarks>
- /// GenericExplorer provides complete flexibility. You can create any
- /// distribution over actions desired, and it will draw from that.
- /// </remarks>
- /// <typeparam name="Ctx">The Context type.</typeparam>
- generic <class Ctx>
- public ref class GenericExplorer : public IExplorer<Ctx>, public ScorerCallback<Ctx>
- {
- public:
- /// <summary>
- /// The constructor is the only public member, because this should be used with the MwtExplorer.
- /// </summary>
- /// <param name="defaultScorer">A function which outputs the probability of each action.</param>
- /// <param name="numActions">The number of actions to randomize over.</param>
- GenericExplorer(IScorer<Ctx>^ defaultScorer, UInt32 numActions)
- {
- this->defaultScorer = defaultScorer;
- m_explorer = new NativeMultiWorldTesting::GenericExplorer<NativeContext>(*GetNativeScorer(), (u32)numActions);
- }
-
- ~GenericExplorer()
- {
- delete m_explorer;
- }
-
- internal:
- virtual List<float>^ InvokeScorerCallback(Ctx context) override
- {
- return defaultScorer->ScoreActions(context);
- }
-
- NativeMultiWorldTesting::GenericExplorer<NativeContext>* Get()
- {
- return m_explorer;
- }
-
- private:
- IScorer<Ctx>^ defaultScorer;
- NativeMultiWorldTesting::GenericExplorer<NativeContext>* m_explorer;
- };
-
- /// <summary>
- /// The bootstrap exploration class.
- /// </summary>
- /// <remarks>
- /// The Bootstrap explorer randomizes over the actions chosen by a set of
- /// default policies. This performs well statistically but can be
- /// computationally expensive.
- /// </remarks>
- /// <typeparam name="Ctx">The Context type.</typeparam>
- generic <class Ctx>
- public ref class BootstrapExplorer : public IExplorer<Ctx>, public PolicyCallback<Ctx>
- {
- public:
- /// <summary>
- /// The constructor is the only public member, because this should be used with the MwtExplorer.
- /// </summary>
- /// <param name="defaultPolicies">A set of default policies to be uniform random over.</param>
- /// <param name="numActions">The number of actions to randomize over.</param>
- BootstrapExplorer(cli::array<IPolicy<Ctx>^>^ defaultPolicies, UInt32 numActions)
- {
- this->defaultPolicies = defaultPolicies;
- if (this->defaultPolicies == nullptr)
- {
- throw gcnew ArgumentNullException("The specified array of default policy functions cannot be null.");
- }
-
- m_explorer = new NativeMultiWorldTesting::BootstrapExplorer<NativeContext>(*GetNativePolicies((u32)defaultPolicies->Length), (u32)numActions);
- }
-
- ~BootstrapExplorer()
- {
- delete m_explorer;
- }
-
- internal:
- virtual UInt32 InvokePolicyCallback(Ctx context, int index) override
- {
- if (index < 0 || index >= defaultPolicies->Length)
- {
- throw gcnew InvalidDataException("Internal error: Index of interop bag is out of range.");
- }
- return defaultPolicies[index]->ChooseAction(context);
- }
-
- NativeMultiWorldTesting::BootstrapExplorer<NativeContext>* Get()
- {
- return m_explorer;
- }
-
- private:
- cli::array<IPolicy<Ctx>^>^ defaultPolicies;
- NativeMultiWorldTesting::BootstrapExplorer<NativeContext>* m_explorer;
- };
-
- /// <summary>
- /// The top level MwtExplorer class. Using this makes sure that the
- /// right bits are recorded and good random actions are chosen.
- /// </summary>
- /// <typeparam name="Ctx">The Context type.</typeparam>
- generic <class Ctx>
- public ref class MwtExplorer : public RecorderCallback<Ctx>
- {
- public:
- /// <summary>
- /// Constructor.
- /// </summary>
- /// <param name="appId">This should be unique to each experiment to avoid correlation bugs.</param>
- /// <param name="recorder">A user-specified class for recording the appropriate bits for use in evaluation and learning.</param>
- MwtExplorer(String^ appId, IRecorder<Ctx>^ recorder)
- {
- this->appId = appId;
- this->recorder = recorder;
- }
-
- /// <summary>
- /// Choose_Action should be drop-in replacement for any existing policy function.
- /// </summary>
- /// <param name="explorer">An existing exploration algorithm (one of the above) which uses the default policy as a callback.</param>
- /// <param name="unique_key">A unique identifier for the experimental unit. This could be a user id, a session id, etc...</param>
- /// <param name="context">The context upon which a decision is made. See SimpleContext above for an example.</param>
- /// <returns>An unsigned 32-bit integer representing the 1-based chosen action.</returns>
- UInt32 ChooseAction(IExplorer<Ctx>^ explorer, String^ unique_key, Ctx context)
- {
- String^ salt = this->appId;
- NativeMultiWorldTesting::MwtExplorer<NativeContext> mwt(marshal_as<std::string>(salt), *GetNativeRecorder());
-
- GCHandle selfHandle = GCHandle::Alloc(this);
- IntPtr selfPtr = (IntPtr)selfHandle;
-
- GCHandle contextHandle = GCHandle::Alloc(context);
- IntPtr contextPtr = (IntPtr)contextHandle;
-
- GCHandle explorerHandle = GCHandle::Alloc(explorer);
- IntPtr explorerPtr = (IntPtr)explorerHandle;
-
- NativeContext native_context(selfPtr.ToPointer(), explorerPtr.ToPointer(), contextPtr.ToPointer());
- u32 action = 0;
- if (explorer->GetType() == EpsilonGreedyExplorer<Ctx>::typeid)
- {
- EpsilonGreedyExplorer<Ctx>^ epsilonGreedyExplorer = (EpsilonGreedyExplorer<Ctx>^)explorer;
- action = mwt.Choose_Action(*epsilonGreedyExplorer->Get(), marshal_as<std::string>(unique_key), native_context);
- }
- else if (explorer->GetType() == TauFirstExplorer<Ctx>::typeid)
- {
- TauFirstExplorer<Ctx>^ tauFirstExplorer = (TauFirstExplorer<Ctx>^)explorer;
- action = mwt.Choose_Action(*tauFirstExplorer->Get(), marshal_as<std::string>(unique_key), native_context);
- }
- else if (explorer->GetType() == SoftmaxExplorer<Ctx>::typeid)
- {
- SoftmaxExplorer<Ctx>^ softmaxExplorer = (SoftmaxExplorer<Ctx>^)explorer;
- action = mwt.Choose_Action(*softmaxExplorer->Get(), marshal_as<std::string>(unique_key), native_context);
- }
- else if (explorer->GetType() == GenericExplorer<Ctx>::typeid)
- {
- GenericExplorer<Ctx>^ genericExplorer = (GenericExplorer<Ctx>^)explorer;
- action = mwt.Choose_Action(*genericExplorer->Get(), marshal_as<std::string>(unique_key), native_context);
- }
- else if (explorer->GetType() == BootstrapExplorer<Ctx>::typeid)
- {
- BootstrapExplorer<Ctx>^ bootstrapExplorer = (BootstrapExplorer<Ctx>^)explorer;
- action = mwt.Choose_Action(*bootstrapExplorer->Get(), marshal_as<std::string>(unique_key), native_context);
- }
-
- explorerHandle.Free();
- contextHandle.Free();
- selfHandle.Free();
-
- return action;
- }
-
- internal:
- virtual void InvokeRecorderCallback(Ctx context, UInt32 action, float probability, String^ unique_key) override
- {
- recorder->Record(context, action, probability, unique_key);
- }
-
- private:
- IRecorder<Ctx>^ recorder;
- String^ appId;
- };
-
- /// <summary>
- /// Represents a feature in a sparse array.
- /// </summary>
- [StructLayout(LayoutKind::Sequential)]
- public value struct Feature
- {
- float Value;
- UInt32 Id;
- };
-
- /// <summary>
- /// A sample recorder class that converts the exploration tuple into string format.
- /// </summary>
- /// <typeparam name="Ctx">The Context type.</typeparam>
- generic <class Ctx> where Ctx : IStringContext
- public ref class StringRecorder : public IRecorder<Ctx>, public ToStringCallback<Ctx>
- {
- public:
- StringRecorder()
- {
- m_string_recorder = new NativeMultiWorldTesting::StringRecorder<NativeStringContext>();
- }
-
- ~StringRecorder()
- {
- delete m_string_recorder;
- }
-
- virtual void Record(Ctx context, UInt32 action, float probability, String^ uniqueKey)
- {
- GCHandle contextHandle = GCHandle::Alloc(context);
- IntPtr contextPtr = (IntPtr)contextHandle;
-
- NativeStringContext native_context(contextPtr.ToPointer(), GetCallback());
- m_string_recorder->Record(native_context, (u32)action, probability, marshal_as<string>(uniqueKey));
- }
-
- /// <summary>
- /// Gets the content of the recording so far as a string and clears internal content.
- /// </summary>
- /// <returns>
- /// A string with recording content.
- /// </returns>
- String^ GetRecording()
- {
- // Workaround for C++-CLI bug which does not allow default value for parameter
- return GetRecording(true);
- }
-
- /// <summary>
- /// Gets the content of the recording so far as a string and optionally clears internal content.
- /// </summary>
- /// <param name="flush">A boolean value indicating whether to clear the internal content.</param>
- /// <returns>
- /// A string with recording content.
- /// </returns>
- String^ GetRecording(bool flush)
- {
- return gcnew String(m_string_recorder->Get_Recording(flush).c_str());
- }
-
- private:
- NativeMultiWorldTesting::StringRecorder<NativeStringContext>* m_string_recorder;
- };
-
- /// <summary>
- /// A sample context class that stores a vector of Features.
- /// </summary>
- public ref class SimpleContext : public IStringContext
- {
- public:
- SimpleContext(cli::array<Feature>^ features)
- {
- Features = features;
-
- // TODO: add another constructor overload for native SimpleContext to avoid copying feature values
- m_features = new vector<NativeMultiWorldTesting::Feature>();
- for (int i = 0; i < features->Length; i++)
- {
- m_features->push_back({ features[i].Value, features[i].Id });
- }
-
- m_native_context = new NativeMultiWorldTesting::SimpleContext(*m_features);
- }
-
- String^ ToString() override
- {
- return gcnew String(m_native_context->To_String().c_str());
- }
-
- ~SimpleContext()
- {
- delete m_native_context;
- }
-
- public:
- cli::array<Feature>^ GetFeatures() { return Features; }
-
- internal:
- cli::array<Feature>^ Features;
-
- private:
- vector<NativeMultiWorldTesting::Feature>* m_features;
- NativeMultiWorldTesting::SimpleContext* m_native_context;
- };
-}
-
+#pragma once
+#include "explore_interop.h"
+
+/*!
+* \addtogroup MultiWorldTestingCsharp
+* @{
+*/
+namespace MultiWorldTesting {
+
+ /// <summary>
+ /// The epsilon greedy exploration class.
+ /// </summary>
+ /// <remarks>
+ /// This is a good choice if you have no idea which actions should be preferred.
+ /// Epsilon greedy is also computationally cheap.
+ /// </remarks>
+ /// <typeparam name="Ctx">The Context type.</typeparam>
+ generic <class Ctx>
+ public ref class EpsilonGreedyExplorer : public IExplorer<Ctx>, public PolicyCallback<Ctx>
+ {
+ public:
+ /// <summary>
+ /// The constructor is the only public member, because this should be used with the MwtExplorer.
+ /// </summary>
+ /// <param name="defaultPolicy">A default function which outputs an action given a context.</param>
+ /// <param name="epsilon">The probability of a random exploration.</param>
+ /// <param name="numActions">The number of actions to randomize over.</param>
+ EpsilonGreedyExplorer(IPolicy<Ctx>^ defaultPolicy, float epsilon, UInt32 numActions)
+ {
+ this->defaultPolicy = defaultPolicy;
+ m_explorer = new NativeMultiWorldTesting::EpsilonGreedyExplorer<NativeContext>(*GetNativePolicy(), epsilon, (u32)numActions);
+ }
+
+ ~EpsilonGreedyExplorer()
+ {
+ delete m_explorer;
+ }
+
+ internal:
+ virtual UInt32 InvokePolicyCallback(Ctx context, int index) override
+ {
+ return defaultPolicy->ChooseAction(context);
+ }
+
+ NativeMultiWorldTesting::EpsilonGreedyExplorer<NativeContext>* Get()
+ {
+ return m_explorer;
+ }
+
+ private:
+ IPolicy<Ctx>^ defaultPolicy;
+ NativeMultiWorldTesting::EpsilonGreedyExplorer<NativeContext>* m_explorer;
+ };
+
+ /// <summary>
+ /// The tau-first exploration class.
+ /// </summary>
+ /// <remarks>
+ /// The tau-first explorer collects precisely tau uniform random
+ /// exploration events, and then uses the default policy.
+ /// </remarks>
+ /// <typeparam name="Ctx">The Context type.</typeparam>
+ generic <class Ctx>
+ public ref class TauFirstExplorer : public IExplorer<Ctx>, public PolicyCallback<Ctx>
+ {
+ public:
+ /// <summary>
+ /// The constructor is the only public member, because this should be used with the MwtExplorer.
+ /// </summary>
+ /// <param name="defaultPolicy">A default policy after randomization finishes.</param>
+ /// <param name="tau">The number of events to be uniform over.</param>
+ /// <param name="numActions">The number of actions to randomize over.</param>
+ TauFirstExplorer(IPolicy<Ctx>^ defaultPolicy, UInt32 tau, UInt32 numActions)
+ {
+ this->defaultPolicy = defaultPolicy;
+ m_explorer = new NativeMultiWorldTesting::TauFirstExplorer<NativeContext>(*GetNativePolicy(), tau, (u32)numActions);
+ }
+
+ ~TauFirstExplorer()
+ {
+ delete m_explorer;
+ }
+
+ internal:
+ virtual UInt32 InvokePolicyCallback(Ctx context, int index) override
+ {
+ return defaultPolicy->ChooseAction(context);
+ }
+
+ NativeMultiWorldTesting::TauFirstExplorer<NativeContext>* Get()
+ {
+ return m_explorer;
+ }
+
+ private:
+ IPolicy<Ctx>^ defaultPolicy;
+ NativeMultiWorldTesting::TauFirstExplorer<NativeContext>* m_explorer;
+ };
+
+ /// <summary>
+ /// The epsilon greedy exploration class.
+ /// </summary>
+ /// <remarks>
+ /// In some cases, different actions have a different scores, and you
+ /// would prefer to choose actions with large scores. Softmax allows
+ /// you to do that.
+ /// </remarks>
+ /// <typeparam name="Ctx">The Context type.</typeparam>
+ generic <class Ctx>
+ public ref class SoftmaxExplorer : public IExplorer<Ctx>, public ScorerCallback<Ctx>
+ {
+ public:
+ /// <summary>
+ /// The constructor is the only public member, because this should be used with the MwtExplorer.
+ /// </summary>
+ /// <param name="defaultScorer">A function which outputs a score for each action.</param>
+ /// <param name="lambda">lambda = 0 implies uniform distribution. Large lambda is equivalent to a max.</param>
+ /// <param name="numActions">The number of actions to randomize over.</param>
+ SoftmaxExplorer(IScorer<Ctx>^ defaultScorer, float lambda, UInt32 numActions)
+ {
+ this->defaultScorer = defaultScorer;
+ m_explorer = new NativeMultiWorldTesting::SoftmaxExplorer<NativeContext>(*GetNativeScorer(), lambda, (u32)numActions);
+ }
+
+ ~SoftmaxExplorer()
+ {
+ delete m_explorer;
+ }
+
+ internal:
+ virtual List<float>^ InvokeScorerCallback(Ctx context) override
+ {
+ return defaultScorer->ScoreActions(context);
+ }
+
+ NativeMultiWorldTesting::SoftmaxExplorer<NativeContext>* Get()
+ {
+ return m_explorer;
+ }
+
+ private:
+ IScorer<Ctx>^ defaultScorer;
+ NativeMultiWorldTesting::SoftmaxExplorer<NativeContext>* m_explorer;
+ };
+
+ /// <summary>
+ /// The generic exploration class.
+ /// </summary>
+ /// <remarks>
+ /// GenericExplorer provides complete flexibility. You can create any
+ /// distribution over actions desired, and it will draw from that.
+ /// </remarks>
+ /// <typeparam name="Ctx">The Context type.</typeparam>
+ generic <class Ctx>
+ public ref class GenericExplorer : public IExplorer<Ctx>, public ScorerCallback<Ctx>
+ {
+ public:
+ /// <summary>
+ /// The constructor is the only public member, because this should be used with the MwtExplorer.
+ /// </summary>
+ /// <param name="defaultScorer">A function which outputs the probability of each action.</param>
+ /// <param name="numActions">The number of actions to randomize over.</param>
+ GenericExplorer(IScorer<Ctx>^ defaultScorer, UInt32 numActions)
+ {
+ this->defaultScorer = defaultScorer;
+ m_explorer = new NativeMultiWorldTesting::GenericExplorer<NativeContext>(*GetNativeScorer(), (u32)numActions);
+ }
+
+ ~GenericExplorer()
+ {
+ delete m_explorer;
+ }
+
+ internal:
+ virtual List<float>^ InvokeScorerCallback(Ctx context) override
+ {
+ return defaultScorer->ScoreActions(context);
+ }
+
+ NativeMultiWorldTesting::GenericExplorer<NativeContext>* Get()
+ {
+ return m_explorer;
+ }
+
+ private:
+ IScorer<Ctx>^ defaultScorer;
+ NativeMultiWorldTesting::GenericExplorer<NativeContext>* m_explorer;
+ };
+
+ /// <summary>
+ /// The bootstrap exploration class.
+ /// </summary>
+ /// <remarks>
+ /// The Bootstrap explorer randomizes over the actions chosen by a set of
+ /// default policies. This performs well statistically but can be
+ /// computationally expensive.
+ /// </remarks>
+ /// <typeparam name="Ctx">The Context type.</typeparam>
+ generic <class Ctx>
+ public ref class BootstrapExplorer : public IExplorer<Ctx>, public PolicyCallback<Ctx>
+ {
+ public:
+ /// <summary>
+ /// The constructor is the only public member, because this should be used with the MwtExplorer.
+ /// </summary>
+ /// <param name="defaultPolicies">A set of default policies to be uniform random over.</param>
+ /// <param name="numActions">The number of actions to randomize over.</param>
+ BootstrapExplorer(cli::array<IPolicy<Ctx>^>^ defaultPolicies, UInt32 numActions)
+ {
+ this->defaultPolicies = defaultPolicies;
+ if (this->defaultPolicies == nullptr)
+ {
+ throw gcnew ArgumentNullException("The specified array of default policy functions cannot be null.");
+ }
+
+ m_explorer = new NativeMultiWorldTesting::BootstrapExplorer<NativeContext>(*GetNativePolicies((u32)defaultPolicies->Length), (u32)numActions);
+ }
+
+ ~BootstrapExplorer()
+ {
+ delete m_explorer;
+ }
+
+ internal:
+ virtual UInt32 InvokePolicyCallback(Ctx context, int index) override
+ {
+ if (index < 0 || index >= defaultPolicies->Length)
+ {
+ throw gcnew InvalidDataException("Internal error: Index of interop bag is out of range.");
+ }
+ return defaultPolicies[index]->ChooseAction(context);
+ }
+
+ NativeMultiWorldTesting::BootstrapExplorer<NativeContext>* Get()
+ {
+ return m_explorer;
+ }
+
+ private:
+ cli::array<IPolicy<Ctx>^>^ defaultPolicies;
+ NativeMultiWorldTesting::BootstrapExplorer<NativeContext>* m_explorer;
+ };
+
+ /// <summary>
+ /// The top level MwtExplorer class. Using this makes sure that the
+ /// right bits are recorded and good random actions are chosen.
+ /// </summary>
+ /// <typeparam name="Ctx">The Context type.</typeparam>
+ generic <class Ctx>
+ public ref class MwtExplorer : public RecorderCallback<Ctx>
+ {
+ public:
+ /// <summary>
+ /// Constructor.
+ /// </summary>
+ /// <param name="appId">This should be unique to each experiment to avoid correlation bugs.</param>
+ /// <param name="recorder">A user-specified class for recording the appropriate bits for use in evaluation and learning.</param>
+ MwtExplorer(String^ appId, IRecorder<Ctx>^ recorder)
+ {
+ this->appId = appId;
+ this->recorder = recorder;
+ }
+
+		/// <summary>
+		/// ChooseAction should be a drop-in replacement for any existing policy function.
+		/// </summary>
+		/// <param name="explorer">An existing exploration algorithm (one of the above) which uses the default policy as a callback.</param>
+		/// <param name="unique_key">A unique identifier for the experimental unit. This could be a user id, a session id, etc...</param>
+		/// <param name="context">The context upon which a decision is made. See SimpleContext above for an example.</param>
+		/// <returns>An unsigned 32-bit integer representing the 1-based chosen action; 0 if the explorer is not one of the explorer types handled below.</returns>
+		UInt32 ChooseAction(IExplorer<Ctx>^ explorer, String^ unique_key, Ctx context)
+		{
+			NativeMultiWorldTesting::MwtExplorer<NativeContext> mwt(marshal_as<std::string>(this->appId), *GetNativeRecorder());
+
+			// Handles keep this object, the context and the explorer reachable while native code holds their addresses.
+			GCHandle selfHandle = GCHandle::Alloc(this);
+			GCHandle contextHandle = GCHandle::Alloc(context);
+			GCHandle explorerHandle = GCHandle::Alloc(explorer);
+			u32 action = 0;
+			try
+			{
+				NativeContext native_context(((IntPtr)selfHandle).ToPointer(), ((IntPtr)explorerHandle).ToPointer(), ((IntPtr)contextHandle).ToPointer());
+				if (explorer->GetType() == EpsilonGreedyExplorer<Ctx>::typeid)
+				{
+					EpsilonGreedyExplorer<Ctx>^ epsilonGreedyExplorer = (EpsilonGreedyExplorer<Ctx>^)explorer;
+					action = mwt.Choose_Action(*epsilonGreedyExplorer->Get(), marshal_as<std::string>(unique_key), native_context);
+				}
+				else if (explorer->GetType() == TauFirstExplorer<Ctx>::typeid)
+				{
+					TauFirstExplorer<Ctx>^ tauFirstExplorer = (TauFirstExplorer<Ctx>^)explorer;
+					action = mwt.Choose_Action(*tauFirstExplorer->Get(), marshal_as<std::string>(unique_key), native_context);
+				}
+				else if (explorer->GetType() == SoftmaxExplorer<Ctx>::typeid)
+				{
+					SoftmaxExplorer<Ctx>^ softmaxExplorer = (SoftmaxExplorer<Ctx>^)explorer;
+					action = mwt.Choose_Action(*softmaxExplorer->Get(), marshal_as<std::string>(unique_key), native_context);
+				}
+				else if (explorer->GetType() == GenericExplorer<Ctx>::typeid)
+				{
+					GenericExplorer<Ctx>^ genericExplorer = (GenericExplorer<Ctx>^)explorer;
+					action = mwt.Choose_Action(*genericExplorer->Get(), marshal_as<std::string>(unique_key), native_context);
+				}
+				else if (explorer->GetType() == BootstrapExplorer<Ctx>::typeid)
+				{
+					BootstrapExplorer<Ctx>^ bootstrapExplorer = (BootstrapExplorer<Ctx>^)explorer;
+					action = mwt.Choose_Action(*bootstrapExplorer->Get(), marshal_as<std::string>(unique_key), native_context);
+				}
+			}
+			finally
+			{
+				// Release the handles even when the native call (or a managed callback it triggers) throws; they were leaked on that path before.
+				explorerHandle.Free();
+				contextHandle.Free();
+				selfHandle.Free();
+			}
+
+			return action;
+		}
+
+ internal:
+ virtual void InvokeRecorderCallback(Ctx context, UInt32 action, float probability, String^ unique_key) override
+ {
+ recorder->Record(context, action, probability, unique_key);
+ }
+
+ private:
+ IRecorder<Ctx>^ recorder;
+ String^ appId;
+ };
+
+ /// <summary>
+ /// Represents a feature in a sparse array.
+ /// </summary>
+ [StructLayout(LayoutKind::Sequential)]
+ public value struct Feature
+ {
+ float Value;
+ UInt32 Id;
+ };
+
+ /// <summary>
+ /// A sample recorder class that converts the exploration tuple into string format.
+ /// </summary>
+ /// <typeparam name="Ctx">The Context type.</typeparam>
+	generic <class Ctx> where Ctx : IStringContext
+	public ref class StringRecorder : public IRecorder<Ctx>, public ToStringCallback<Ctx>
+	{
+	public:
+		StringRecorder()
+		{
+			m_string_recorder = new NativeMultiWorldTesting::StringRecorder<NativeStringContext>();
+		}
+
+		~StringRecorder()
+		{
+			delete m_string_recorder;
+		}
+
+		virtual void Record(Ctx context, UInt32 action, float probability, String^ uniqueKey)
+		{
+			// The handle keeps the managed context reachable while native code holds its address.
+			GCHandle contextHandle = GCHandle::Alloc(context);
+			NativeStringContext native_context(((IntPtr)contextHandle).ToPointer(), GetCallback());
+			try { m_string_recorder->Record(native_context, (u32)action, probability, marshal_as<string>(uniqueKey)); }
+			finally { contextHandle.Free(); } // was leaked on every call; assumes the native recorder does not keep native_context after returning
+		}
+
+		/// <summary>
+		/// Gets the content of the recording so far as a string and clears internal content.
+		/// </summary>
+		/// <returns>
+		/// A string with recording content.
+		/// </returns>
+		String^ GetRecording()
+		{
+			// Workaround for C++-CLI bug which does not allow default value for parameter
+			return GetRecording(true);
+		}
+
+		/// <summary>
+		/// Gets the content of the recording so far as a string and optionally clears internal content.
+		/// </summary>
+		/// <param name="flush">A boolean value indicating whether to clear the internal content.</param>
+		/// <returns>
+		/// A string with recording content.
+		/// </returns>
+		String^ GetRecording(bool flush)
+		{
+			return gcnew String(m_string_recorder->Get_Recording(flush).c_str());
+		}
+
+	private:
+		NativeMultiWorldTesting::StringRecorder<NativeStringContext>* m_string_recorder;
+	};
+
+ /// <summary>
+ /// A sample context class that stores a vector of Features.
+ /// </summary>
+	public ref class SimpleContext : public IStringContext
+	{
+	public:
+		SimpleContext(cli::array<Feature>^ features)
+		{
+			Features = features;
+			// TODO: add another constructor overload for native SimpleContext to avoid copying feature values
+			m_features = new vector<NativeMultiWorldTesting::Feature>();
+			for (int i = 0; i < features->Length; i++)
+			{
+				m_features->push_back({ features[i].Value, features[i].Id });
+			}
+			m_native_context = new NativeMultiWorldTesting::SimpleContext(*m_features); // built from *m_features, so the vector must stay alive at least as long
+		}
+
+		String^ ToString() override
+		{
+			return gcnew String(m_native_context->To_String().c_str());
+		}
+
+		~SimpleContext()
+		{
+			delete m_native_context; // released first: it was constructed from *m_features
+			m_native_context = nullptr;
+			delete m_features; // the feature vector used to be leaked
+			m_features = nullptr; // nulling both pointers makes a repeated Dispose() harmless
+		}
+
+		cli::array<Feature>^ GetFeatures() { return Features; }
+
+	internal:
+		cli::array<Feature>^ Features;
+
+	private:
+		vector<NativeMultiWorldTesting::Feature>* m_features;
+		NativeMultiWorldTesting::SimpleContext* m_native_context;
+	};
+}
+
/*! @} End of Doxygen Groups*/ \ No newline at end of file
diff --git a/explore/clr/explore_interface.h b/explore/clr/explore_interface.h
index 3212b4d8..dfa17697 100644
--- a/explore/clr/explore_interface.h
+++ b/explore/clr/explore_interface.h
@@ -1,88 +1,88 @@
-#pragma once
-
-using namespace System;
-using namespace System::Collections::Generic;
-
-/** \defgroup MultiWorldTestingCsharp
-\brief C# implementation, for sample usage see: https://github.com/sidsen/vowpal_wabbit/blob/v0/cs_test/ExploreOnlySample.cs
-*/
-
-/*!
-* \addtogroup MultiWorldTestingCsharp
-* @{
-*/
-
-//! Interface for C# version of Multiworld Testing library.
-//! For sample usage see: https://github.com/sidsen/vowpal_wabbit/blob/v0/cs_test/ExploreOnlySample.cs
-namespace MultiWorldTesting {
-
-/// <summary>
-/// Represents a recorder that exposes a method to record exploration data based on generic contexts.
-/// </summary>
-/// <typeparam name="Ctx">The Context type.</typeparam>
-/// <remarks>
-/// Exploration data is specified as a set of tuples <context, action, probability, key> as described below. An
-/// application passes an IRecorder object to the @MwtExplorer constructor. See
-/// @StringRecorder for a sample IRecorder object.
-/// </remarks>
-generic <class Ctx>
-public interface class IRecorder
-{
-public:
- /// <summary>
- /// Records the exploration data associated with a given decision.
- /// </summary>
- /// <param name="context">A user-defined context for the decision.</param>
- /// <param name="action">Chosen by an exploration algorithm given context.</param>
- /// <param name="probability">The probability of the chosen action given context.</param>
- /// <param name="uniqueKey">A user-defined identifer for the decision.</param>
- virtual void Record(Ctx context, UInt32 action, float probability, String^ uniqueKey) = 0;
-};
-
-/// <summary>
-/// Exposes a method for choosing an action given a generic context. IPolicy objects are
-/// passed to (and invoked by) exploration algorithms to specify the default policy behavior.
-/// </summary>
-/// <typeparam name="Ctx">The Context type.</typeparam>
-generic <class Ctx>
-public interface class IPolicy
-{
-public:
- /// <summary>
- /// Determines the action to take for a given context.
- /// </summary>
- /// <param name="context">A user-defined context for the decision.</param>
- /// <returns>Index of the action to take (1-based)</returns>
- virtual UInt32 ChooseAction(Ctx context) = 0;
-};
-
-/// <summary>
-/// Exposes a method for specifying a score (weight) for each action given a generic context.
-/// </summary>
-/// <typeparam name="Ctx">The Context type.</typeparam>
-generic <class Ctx>
-public interface class IScorer
-{
-public:
- /// <summary>
- /// Determines the score of each action for a given context.
- /// </summary>
- /// <param name="context">A user-defined context for the decision.</param>
- /// <returns>Vector of scores indexed by action (1-based).</returns>
- virtual List<float>^ ScoreActions(Ctx context) = 0;
-};
-
-generic <class Ctx>
-public interface class IExplorer
-{
-};
-
-public interface class IStringContext
-{
-public:
- virtual String^ ToString() = 0;
-};
-
-}
-
+#pragma once
+
+using namespace System;
+using namespace System::Collections::Generic;
+
+/** \defgroup MultiWorldTestingCsharp
+\brief C# implementation, for sample usage see: https://github.com/sidsen/vowpal_wabbit/blob/v0/cs_test/ExploreOnlySample.cs
+*/
+
+/*!
+* \addtogroup MultiWorldTestingCsharp
+* @{
+*/
+
+//! Interface for C# version of Multiworld Testing library.
+//! For sample usage see: https://github.com/sidsen/vowpal_wabbit/blob/v0/cs_test/ExploreOnlySample.cs
+namespace MultiWorldTesting {
+
+/// <summary>
+/// Represents a recorder that exposes a method to record exploration data based on generic contexts.
+/// </summary>
+/// <typeparam name="Ctx">The Context type.</typeparam>
+/// <remarks>
+/// Exploration data is specified as a set of tuples <context, action, probability, key> as described below. An
+/// application passes an IRecorder object to the @MwtExplorer constructor. See
+/// @StringRecorder for a sample IRecorder object.
+/// </remarks>
+generic <class Ctx>
+public interface class IRecorder
+{
+public:
+ /// <summary>
+ /// Records the exploration data associated with a given decision.
+ /// </summary>
+ /// <param name="context">A user-defined context for the decision.</param>
+ /// <param name="action">Chosen by an exploration algorithm given context.</param>
+ /// <param name="probability">The probability of the chosen action given context.</param>
+	/// <param name="uniqueKey">A user-defined identifier for the decision.</param>
+ virtual void Record(Ctx context, UInt32 action, float probability, String^ uniqueKey) = 0;
+};
+
+/// <summary>
+/// Exposes a method for choosing an action given a generic context. IPolicy objects are
+/// passed to (and invoked by) exploration algorithms to specify the default policy behavior.
+/// </summary>
+/// <typeparam name="Ctx">The Context type.</typeparam>
+generic <class Ctx>
+public interface class IPolicy
+{
+public:
+ /// <summary>
+ /// Determines the action to take for a given context.
+ /// </summary>
+ /// <param name="context">A user-defined context for the decision.</param>
+ /// <returns>Index of the action to take (1-based)</returns>
+ virtual UInt32 ChooseAction(Ctx context) = 0;
+};
+
+/// <summary>
+/// Exposes a method for specifying a score (weight) for each action given a generic context.
+/// </summary>
+/// <typeparam name="Ctx">The Context type.</typeparam>
+generic <class Ctx>
+public interface class IScorer
+{
+public:
+ /// <summary>
+ /// Determines the score of each action for a given context.
+ /// </summary>
+ /// <param name="context">A user-defined context for the decision.</param>
+ /// <returns>Vector of scores indexed by action (1-based).</returns>
+ virtual List<float>^ ScoreActions(Ctx context) = 0;
+};
+
+generic <class Ctx>
+public interface class IExplorer
+{
+};
+
+public interface class IStringContext
+{
+public:
+ virtual String^ ToString() = 0;
+};
+
+}
+
/*! @} End of Doxygen Groups*/ \ No newline at end of file
diff --git a/explore/clr/explore_interop.h b/explore/clr/explore_interop.h
index 99466945..a6cb9327 100644
--- a/explore/clr/explore_interop.h
+++ b/explore/clr/explore_interop.h
@@ -1,358 +1,358 @@
-#pragma once
-
-#define MANAGED_CODE
-
-#include "explore_interface.h"
-#include "MWTExplorer.h"
-
-#include <msclr\marshal_cppstd.h>
-
-using namespace System;
-using namespace System::Collections::Generic;
-using namespace System::IO;
-using namespace System::Runtime::InteropServices;
-using namespace System::Xml::Serialization;
-using namespace msclr::interop;
-
-namespace MultiWorldTesting {
-
-// Policy callback
-private delegate UInt32 ClrPolicyCallback(IntPtr explorerPtr, IntPtr contextPtr, int index);
-typedef u32 Native_Policy_Callback(void* explorer, void* context, int index);
-
-// Scorer callback
-private delegate void ClrScorerCallback(IntPtr explorerPtr, IntPtr contextPtr, IntPtr scores, IntPtr size);
-typedef void Native_Scorer_Callback(void* explorer, void* context, float* scores[], u32* size);
-
-// Recorder callback
-private delegate void ClrRecorderCallback(IntPtr mwtPtr, IntPtr contextPtr, UInt32 action, float probability, IntPtr uniqueKey);
-typedef void Native_Recorder_Callback(void* mwt, void* context, u32 action, float probability, void* unique_key);
-
-// ToString callback
-private delegate void ClrToStringCallback(IntPtr contextPtr, IntPtr stringValue);
-typedef void Native_To_String_Callback(void* explorer, void* string_value);
-
-// NativeContext travels through interop space and contains instances of Mwt, Explorer, Context
-// used for triggering callback for Policy, Scorer, Recorder
-class NativeContext
-{
-public:
- NativeContext(void* clr_mwt, void* clr_explorer, void* clr_context)
- {
- m_clr_mwt = clr_mwt;
- m_clr_explorer = clr_explorer;
- m_clr_context = clr_context;
- }
-
- void* Get_Clr_Mwt()
- {
- return m_clr_mwt;
- }
-
- void* Get_Clr_Context()
- {
- return m_clr_context;
- }
-
- void* Get_Clr_Explorer()
- {
- return m_clr_explorer;
- }
-
-private:
- void* m_clr_mwt;
- void* m_clr_context;
- void* m_clr_explorer;
-};
-
-class NativeStringContext
-{
-public:
- NativeStringContext(void* clr_context, Native_To_String_Callback* func)
- {
- m_clr_context = clr_context;
- m_func = func;
- }
-
- string To_String()
- {
- string value;
- m_func(m_clr_context, &value);
- return value;
- }
-private:
- void* m_clr_context;
- Native_To_String_Callback* m_func;
-};
-
-// NativeRecorder listens to callback event and reroute it to the managed Recorder instance
-class NativeRecorder : public NativeMultiWorldTesting::IRecorder<NativeContext>
-{
-public:
- NativeRecorder(Native_Recorder_Callback* native_func)
- {
- m_func = native_func;
- }
-
- void Record(NativeContext& context, u32 action, float probability, string unique_key)
- {
- GCHandle uniqueKeyHandle = GCHandle::Alloc(gcnew String(unique_key.c_str()));
- IntPtr uniqueKeyPtr = (IntPtr)uniqueKeyHandle;
-
- m_func(context.Get_Clr_Mwt(), context.Get_Clr_Context(), action, probability, uniqueKeyPtr.ToPointer());
-
- uniqueKeyHandle.Free();
- }
-private:
- Native_Recorder_Callback* m_func;
-};
-
-// NativePolicy listens to callback event and reroute it to the managed Policy instance
-class NativePolicy : public NativeMultiWorldTesting::IPolicy<NativeContext>
-{
-public:
- NativePolicy(Native_Policy_Callback* func, int index = -1)
- {
- m_func = func;
- m_index = index;
- }
-
- u32 Choose_Action(NativeContext& context)
- {
- return m_func(context.Get_Clr_Explorer(), context.Get_Clr_Context(), m_index);
- }
-
-private:
- Native_Policy_Callback* m_func;
- int m_index;
-};
-
-class NativeScorer : public NativeMultiWorldTesting::IScorer<NativeContext>
-{
-public:
- NativeScorer(Native_Scorer_Callback* func)
- {
- m_func = func;
- }
-
- vector<float> Score_Actions(NativeContext& context)
- {
- float* scores = nullptr;
- u32 num_scores = 0;
- m_func(context.Get_Clr_Explorer(), context.Get_Clr_Context(), &scores, &num_scores);
-
- // It's ok if scores is null, vector will be empty
- vector<float> scores_vector(scores, scores + num_scores);
- delete[] scores;
-
- return scores_vector;
- }
-private:
- Native_Scorer_Callback* m_func;
-};
-
-// Triggers callback to the Policy instance to choose an action
-generic <class Ctx>
-public ref class PolicyCallback abstract
-{
-internal:
- virtual UInt32 InvokePolicyCallback(Ctx context, int index) = 0;
-
- PolicyCallback()
- {
- policyCallback = gcnew ClrPolicyCallback(&PolicyCallback<Ctx>::InteropInvoke);
- IntPtr policyCallbackPtr = Marshal::GetFunctionPointerForDelegate(policyCallback);
- m_callback = static_cast<Native_Policy_Callback*>(policyCallbackPtr.ToPointer());
- m_native_policy = nullptr;
- m_native_policies = nullptr;
- }
-
- ~PolicyCallback()
- {
- delete m_native_policy;
- delete m_native_policies;
- }
-
- NativePolicy* GetNativePolicy()
- {
- if (m_native_policy == nullptr)
- {
- m_native_policy = new NativePolicy(m_callback);
- }
- return m_native_policy;
- }
-
- vector<unique_ptr<NativeMultiWorldTesting::IPolicy<NativeContext>>>* GetNativePolicies(int count)
- {
- if (m_native_policies == nullptr)
- {
- m_native_policies = new vector<unique_ptr<NativeMultiWorldTesting::IPolicy<NativeContext>>>();
- for (int i = 0; i < count; i++)
- {
- m_native_policies->push_back(unique_ptr<NativeMultiWorldTesting::IPolicy<NativeContext>>(new NativePolicy(m_callback, i)));
- }
- }
-
- return m_native_policies;
- }
-
- static UInt32 InteropInvoke(IntPtr callbackPtr, IntPtr contextPtr, int index)
- {
- GCHandle callbackHandle = (GCHandle)callbackPtr;
- PolicyCallback<Ctx>^ callback = (PolicyCallback<Ctx>^)callbackHandle.Target;
-
- GCHandle contextHandle = (GCHandle)contextPtr;
- Ctx context = (Ctx)contextHandle.Target;
-
- return callback->InvokePolicyCallback(context, index);
- }
-
-private:
- ClrPolicyCallback^ policyCallback;
-
-private:
- NativePolicy* m_native_policy;
- vector<unique_ptr<NativeMultiWorldTesting::IPolicy<NativeContext>>>* m_native_policies;
- Native_Policy_Callback* m_callback;
-};
-
-// Triggers callback to the Recorder instance to record interaction data
-generic <class Ctx>
-public ref class RecorderCallback abstract
-{
-internal:
- virtual void InvokeRecorderCallback(Ctx context, UInt32 action, float probability, String^ unique_key) = 0;
-
- RecorderCallback()
- {
- recorderCallback = gcnew ClrRecorderCallback(&RecorderCallback<Ctx>::InteropInvoke);
- IntPtr recorderCallbackPtr = Marshal::GetFunctionPointerForDelegate(recorderCallback);
- Native_Recorder_Callback* callback = static_cast<Native_Recorder_Callback*>(recorderCallbackPtr.ToPointer());
- m_native_recorder = new NativeRecorder(callback);
- }
-
- ~RecorderCallback()
- {
- delete m_native_recorder;
- }
-
- NativeRecorder* GetNativeRecorder()
- {
- return m_native_recorder;
- }
-
- static void InteropInvoke(IntPtr mwtPtr, IntPtr contextPtr, UInt32 action, float probability, IntPtr uniqueKeyPtr)
- {
- GCHandle mwtHandle = (GCHandle)mwtPtr;
- RecorderCallback<Ctx>^ callback = (RecorderCallback<Ctx>^)mwtHandle.Target;
-
- GCHandle contextHandle = (GCHandle)contextPtr;
- Ctx context = (Ctx)contextHandle.Target;
-
- GCHandle uniqueKeyHandle = (GCHandle)uniqueKeyPtr;
- String^ uniqueKey = (String^)uniqueKeyHandle.Target;
-
- callback->InvokeRecorderCallback(context, action, probability, uniqueKey);
- }
-
-private:
- ClrRecorderCallback^ recorderCallback;
-
-private:
- NativeRecorder* m_native_recorder;
-};
-
-// Triggers callback to the Recorder instance to record interaction data
-generic <class Ctx>
-public ref class ScorerCallback abstract
-{
-internal:
- virtual List<float>^ InvokeScorerCallback(Ctx context) = 0;
-
- ScorerCallback()
- {
- scorerCallback = gcnew ClrScorerCallback(&ScorerCallback<Ctx>::InteropInvoke);
- IntPtr scorerCallbackPtr = Marshal::GetFunctionPointerForDelegate(scorerCallback);
- Native_Scorer_Callback* callback = static_cast<Native_Scorer_Callback*>(scorerCallbackPtr.ToPointer());
- m_native_scorer = new NativeScorer(callback);
- }
-
- ~ScorerCallback()
- {
- delete m_native_scorer;
- }
-
- NativeScorer* GetNativeScorer()
- {
- return m_native_scorer;
- }
-
- static void InteropInvoke(IntPtr callbackPtr, IntPtr contextPtr, IntPtr scoresPtr, IntPtr sizePtr)
- {
- GCHandle callbackHandle = (GCHandle)callbackPtr;
- ScorerCallback<Ctx>^ callback = (ScorerCallback<Ctx>^)callbackHandle.Target;
-
- GCHandle contextHandle = (GCHandle)contextPtr;
- Ctx context = (Ctx)contextHandle.Target;
-
- List<float>^ scoreList = callback->InvokeScorerCallback(context);
-
- if (scoreList == nullptr || scoreList->Count == 0)
- {
- return;
- }
-
- u32* num_scores = (u32*)sizePtr.ToPointer();
- *num_scores = (u32)scoreList->Count;
-
- float* scores = new float[*num_scores];
- for (u32 i = 0; i < *num_scores; i++)
- {
- scores[i] = scoreList[i];
- }
-
- float** native_scores = (float**)scoresPtr.ToPointer();
- *native_scores = scores;
- }
-
-private:
- ClrScorerCallback^ scorerCallback;
-
-private:
- NativeScorer* m_native_scorer;
-};
-
-// Triggers callback to the Context instance to perform ToString() operation
-generic <class Ctx> where Ctx : IStringContext
-public ref class ToStringCallback
-{
-internal:
- ToStringCallback()
- {
- toStringCallback = gcnew ClrToStringCallback(&ToStringCallback<Ctx>::InteropInvoke);
- IntPtr toStringCallbackPtr = Marshal::GetFunctionPointerForDelegate(toStringCallback);
- m_callback = static_cast<Native_To_String_Callback*>(toStringCallbackPtr.ToPointer());
- }
-
- Native_To_String_Callback* GetCallback()
- {
- return m_callback;
- }
-
- static void InteropInvoke(IntPtr contextPtr, IntPtr stringPtr)
- {
- GCHandle contextHandle = (GCHandle)contextPtr;
- Ctx context = (Ctx)contextHandle.Target;
-
- string* out_string = (string*)stringPtr.ToPointer();
- *out_string = marshal_as<string>(context->ToString());
- }
-
-private:
- ClrToStringCallback^ toStringCallback;
-
-private:
- Native_To_String_Callback* m_callback;
-};
-
+#pragma once
+
+#define MANAGED_CODE
+
+#include "explore_interface.h"
+#include "MWTExplorer.h"
+
+#include <msclr\marshal_cppstd.h>
+
+using namespace System;
+using namespace System::Collections::Generic;
+using namespace System::IO;
+using namespace System::Runtime::InteropServices;
+using namespace System::Xml::Serialization;
+using namespace msclr::interop;
+
+namespace MultiWorldTesting {
+
+// Policy callback
+private delegate UInt32 ClrPolicyCallback(IntPtr explorerPtr, IntPtr contextPtr, int index);
+typedef u32 Native_Policy_Callback(void* explorer, void* context, int index);
+
+// Scorer callback
+private delegate void ClrScorerCallback(IntPtr explorerPtr, IntPtr contextPtr, IntPtr scores, IntPtr size);
+typedef void Native_Scorer_Callback(void* explorer, void* context, float* scores[], u32* size);
+
+// Recorder callback
+private delegate void ClrRecorderCallback(IntPtr mwtPtr, IntPtr contextPtr, UInt32 action, float probability, IntPtr uniqueKey);
+typedef void Native_Recorder_Callback(void* mwt, void* context, u32 action, float probability, void* unique_key);
+
+// ToString callback
+private delegate void ClrToStringCallback(IntPtr contextPtr, IntPtr stringValue);
+typedef void Native_To_String_Callback(void* explorer, void* string_value);
+
+// NativeContext travels through interop space and contains instances of Mwt, Explorer, Context
+// used for triggering callback for Policy, Scorer, Recorder
+class NativeContext
+{
+public:
+ NativeContext(void* clr_mwt, void* clr_explorer, void* clr_context)
+ {
+ m_clr_mwt = clr_mwt;
+ m_clr_explorer = clr_explorer;
+ m_clr_context = clr_context;
+ }
+
+ void* Get_Clr_Mwt()
+ {
+ return m_clr_mwt;
+ }
+
+ void* Get_Clr_Context()
+ {
+ return m_clr_context;
+ }
+
+ void* Get_Clr_Explorer()
+ {
+ return m_clr_explorer;
+ }
+
+private:
+ void* m_clr_mwt;
+ void* m_clr_context;
+ void* m_clr_explorer;
+};
+
+class NativeStringContext
+{
+public:
+ NativeStringContext(void* clr_context, Native_To_String_Callback* func)
+ {
+ m_clr_context = clr_context;
+ m_func = func;
+ }
+
+ string To_String()
+ {
+ string value;
+ m_func(m_clr_context, &value);
+ return value;
+ }
+private:
+ void* m_clr_context;
+ Native_To_String_Callback* m_func;
+};
+
+// NativeRecorder listens to callback event and reroute it to the managed Recorder instance
+class NativeRecorder : public NativeMultiWorldTesting::IRecorder<NativeContext>
+{
+public:
+	NativeRecorder(Native_Recorder_Callback* native_func)
+	{
+		m_func = native_func;
+	}
+	// Wraps unique_key in a managed String behind a GCHandle and forwards the record event to the stored callback.
+	void Record(NativeContext& context, u32 action, float probability, string unique_key)
+	{
+		GCHandle uniqueKeyHandle = GCHandle::Alloc(gcnew String(unique_key.c_str()));
+		try
+		{
+			m_func(context.Get_Clr_Mwt(), context.Get_Clr_Context(), action, probability, ((IntPtr)uniqueKeyHandle).ToPointer());
+		}
+		finally { uniqueKeyHandle.Free(); } // release the handle even if the callback throws
+	}
+private:
+	Native_Recorder_Callback* m_func;
+};
+
+// NativePolicy listens to callback event and reroute it to the managed Policy instance
+class NativePolicy : public NativeMultiWorldTesting::IPolicy<NativeContext>
+{
+public:
+ NativePolicy(Native_Policy_Callback* func, int index = -1)
+ {
+ m_func = func;
+ m_index = index;
+ }
+
+ u32 Choose_Action(NativeContext& context)
+ {
+ return m_func(context.Get_Clr_Explorer(), context.Get_Clr_Context(), m_index);
+ }
+
+private:
+ Native_Policy_Callback* m_func;
+ int m_index;
+};
+
+class NativeScorer : public NativeMultiWorldTesting::IScorer<NativeContext>
+{
+public:
+ NativeScorer(Native_Scorer_Callback* func)
+ {
+ m_func = func;
+ }
+
+ vector<float> Score_Actions(NativeContext& context)
+ {
+ float* scores = nullptr;
+ u32 num_scores = 0;
+ m_func(context.Get_Clr_Explorer(), context.Get_Clr_Context(), &scores, &num_scores);
+
+ // It's ok if scores is null, vector will be empty
+ vector<float> scores_vector(scores, scores + num_scores);
+ delete[] scores;
+
+ return scores_vector;
+ }
+private:
+ Native_Scorer_Callback* m_func;
+};
+
+// Triggers callback to the Policy instance to choose an action
+generic <class Ctx>
+public ref class PolicyCallback abstract
+{
+internal:
+ virtual UInt32 InvokePolicyCallback(Ctx context, int index) = 0;
+
+ PolicyCallback()
+ {
+ policyCallback = gcnew ClrPolicyCallback(&PolicyCallback<Ctx>::InteropInvoke);
+ IntPtr policyCallbackPtr = Marshal::GetFunctionPointerForDelegate(policyCallback);
+ m_callback = static_cast<Native_Policy_Callback*>(policyCallbackPtr.ToPointer());
+ m_native_policy = nullptr;
+ m_native_policies = nullptr;
+ }
+
+ ~PolicyCallback()
+ {
+ delete m_native_policy;
+ delete m_native_policies;
+ }
+
+ NativePolicy* GetNativePolicy()
+ {
+ if (m_native_policy == nullptr)
+ {
+ m_native_policy = new NativePolicy(m_callback);
+ }
+ return m_native_policy;
+ }
+
+ vector<unique_ptr<NativeMultiWorldTesting::IPolicy<NativeContext>>>* GetNativePolicies(int count)
+ {
+ if (m_native_policies == nullptr)
+ {
+ m_native_policies = new vector<unique_ptr<NativeMultiWorldTesting::IPolicy<NativeContext>>>();
+ for (int i = 0; i < count; i++)
+ {
+ m_native_policies->push_back(unique_ptr<NativeMultiWorldTesting::IPolicy<NativeContext>>(new NativePolicy(m_callback, i)));
+ }
+ }
+
+ return m_native_policies;
+ }
+
+ static UInt32 InteropInvoke(IntPtr callbackPtr, IntPtr contextPtr, int index)
+ {
+ GCHandle callbackHandle = (GCHandle)callbackPtr;
+ PolicyCallback<Ctx>^ callback = (PolicyCallback<Ctx>^)callbackHandle.Target;
+
+ GCHandle contextHandle = (GCHandle)contextPtr;
+ Ctx context = (Ctx)contextHandle.Target;
+
+ return callback->InvokePolicyCallback(context, index);
+ }
+
+private:
+ ClrPolicyCallback^ policyCallback;
+
+private:
+ NativePolicy* m_native_policy;
+ vector<unique_ptr<NativeMultiWorldTesting::IPolicy<NativeContext>>>* m_native_policies;
+ Native_Policy_Callback* m_callback;
+};
+
+// Triggers callback to the Recorder instance to record interaction data
+generic <class Ctx>
+public ref class RecorderCallback abstract
+{
+internal:
+ virtual void InvokeRecorderCallback(Ctx context, UInt32 action, float probability, String^ unique_key) = 0;
+
+ RecorderCallback()
+ {
+ recorderCallback = gcnew ClrRecorderCallback(&RecorderCallback<Ctx>::InteropInvoke);
+ IntPtr recorderCallbackPtr = Marshal::GetFunctionPointerForDelegate(recorderCallback);
+ Native_Recorder_Callback* callback = static_cast<Native_Recorder_Callback*>(recorderCallbackPtr.ToPointer());
+ m_native_recorder = new NativeRecorder(callback);
+ }
+
+ ~RecorderCallback()
+ {
+ delete m_native_recorder;
+ }
+
+ NativeRecorder* GetNativeRecorder()
+ {
+ return m_native_recorder;
+ }
+
+ static void InteropInvoke(IntPtr mwtPtr, IntPtr contextPtr, UInt32 action, float probability, IntPtr uniqueKeyPtr)
+ {
+ GCHandle mwtHandle = (GCHandle)mwtPtr;
+ RecorderCallback<Ctx>^ callback = (RecorderCallback<Ctx>^)mwtHandle.Target;
+
+ GCHandle contextHandle = (GCHandle)contextPtr;
+ Ctx context = (Ctx)contextHandle.Target;
+
+ GCHandle uniqueKeyHandle = (GCHandle)uniqueKeyPtr;
+ String^ uniqueKey = (String^)uniqueKeyHandle.Target;
+
+ callback->InvokeRecorderCallback(context, action, probability, uniqueKey);
+ }
+
+private:
+ ClrRecorderCallback^ recorderCallback;
+
+private:
+ NativeRecorder* m_native_recorder;
+};
+
+// Triggers callback to the Scorer instance to score the actions for a given context
+generic <class Ctx>
+public ref class ScorerCallback abstract
+{
+internal:
+ virtual List<float>^ InvokeScorerCallback(Ctx context) = 0;
+
+ ScorerCallback()
+ {
+ scorerCallback = gcnew ClrScorerCallback(&ScorerCallback<Ctx>::InteropInvoke);
+ IntPtr scorerCallbackPtr = Marshal::GetFunctionPointerForDelegate(scorerCallback);
+ Native_Scorer_Callback* callback = static_cast<Native_Scorer_Callback*>(scorerCallbackPtr.ToPointer());
+ m_native_scorer = new NativeScorer(callback);
+ }
+
+ ~ScorerCallback()
+ {
+ delete m_native_scorer;
+ }
+
+ NativeScorer* GetNativeScorer()
+ {
+ return m_native_scorer;
+ }
+
+ static void InteropInvoke(IntPtr callbackPtr, IntPtr contextPtr, IntPtr scoresPtr, IntPtr sizePtr)
+ {
+ GCHandle callbackHandle = (GCHandle)callbackPtr;
+ ScorerCallback<Ctx>^ callback = (ScorerCallback<Ctx>^)callbackHandle.Target;
+
+ GCHandle contextHandle = (GCHandle)contextPtr;
+ Ctx context = (Ctx)contextHandle.Target;
+
+ List<float>^ scoreList = callback->InvokeScorerCallback(context);
+
+ if (scoreList == nullptr || scoreList->Count == 0)
+ {
+ return;
+ }
+
+ u32* num_scores = (u32*)sizePtr.ToPointer();
+ *num_scores = (u32)scoreList->Count;
+
+ float* scores = new float[*num_scores];
+ for (u32 i = 0; i < *num_scores; i++)
+ {
+ scores[i] = scoreList[i];
+ }
+
+ float** native_scores = (float**)scoresPtr.ToPointer();
+ *native_scores = scores;
+ }
+
+private:
+ ClrScorerCallback^ scorerCallback;
+
+private:
+ NativeScorer* m_native_scorer;
+};
+
+// Triggers callback to the Context instance to perform ToString() operation
+generic <class Ctx> where Ctx : IStringContext
+public ref class ToStringCallback
+{
+internal:
+ ToStringCallback()
+ {
+ toStringCallback = gcnew ClrToStringCallback(&ToStringCallback<Ctx>::InteropInvoke);
+ IntPtr toStringCallbackPtr = Marshal::GetFunctionPointerForDelegate(toStringCallback);
+ m_callback = static_cast<Native_To_String_Callback*>(toStringCallbackPtr.ToPointer());
+ }
+
+ Native_To_String_Callback* GetCallback()
+ {
+ return m_callback;
+ }
+
+ static void InteropInvoke(IntPtr contextPtr, IntPtr stringPtr)
+ {
+ GCHandle contextHandle = (GCHandle)contextPtr;
+ Ctx context = (Ctx)contextHandle.Target;
+
+ string* out_string = (string*)stringPtr.ToPointer();
+ *out_string = marshal_as<string>(context->ToString());
+ }
+
+private:
+ ClrToStringCallback^ toStringCallback;
+
+private:
+ Native_To_String_Callback* m_callback;
+};
+
} \ No newline at end of file
diff --git a/explore/explore.cpp b/explore/explore.cpp
index ebc1c14e..f44ac353 100644
--- a/explore/explore.cpp
+++ b/explore/explore.cpp
@@ -1,73 +1,73 @@
-// explore.cpp : Timing code to measure performance of MWT Explorer library
-
-#include "MWTExplorer.h"
-#include <chrono>
-#include <tuple>
-#include <iostream>
-
-using namespace std;
-using namespace std::chrono;
-
-using namespace MultiWorldTesting;
-
-class MySimplePolicy : public IPolicy<SimpleContext>
-{
-public:
- u32 Choose_Action(SimpleContext& context)
- {
- return (u32)1;
- }
-};
-
-const u32 num_actions = 10;
-
-void Clock_Explore()
-{
- float epsilon = .2f;
- string unique_key = "key";
- int num_features = 1000;
- int num_iter = 10000;
- int num_warmup = 100;
- int num_interactions = 1;
-
- // pre-create features
- vector<Feature> features;
- for (int i = 0; i < num_features; i++)
- {
- Feature f = {0.5, i+1};
- features.push_back(f);
- }
-
- long long time_init = 0, time_choose = 0;
- for (int iter = 0; iter < num_iter + num_warmup; iter++)
- {
- high_resolution_clock::time_point t1 = high_resolution_clock::now();
- StringRecorder<SimpleContext> recorder;
- MwtExplorer<SimpleContext> mwt("test", recorder);
- MySimplePolicy default_policy;
- EpsilonGreedyExplorer<SimpleContext> explorer(default_policy, epsilon, num_actions);
- high_resolution_clock::time_point t2 = high_resolution_clock::now();
- time_init += iter < num_warmup ? 0 : duration_cast<chrono::microseconds>(t2 - t1).count();
-
- t1 = high_resolution_clock::now();
- SimpleContext appContext(features);
- for (int i = 0; i < num_interactions; i++)
- {
- mwt.Choose_Action(explorer, unique_key, appContext);
- }
- t2 = high_resolution_clock::now();
- time_choose += iter < num_warmup ? 0 : duration_cast<chrono::microseconds>(t2 - t1).count();
- }
-
- cout << "# iterations: " << num_iter << ", # interactions: " << num_interactions << ", # context features: " << num_features << endl;
- cout << "--- PER ITERATION ---" << endl;
- cout << "Init: " << (double)time_init / num_iter << " micro" << endl;
- cout << "Choose Action: " << (double)time_choose / (num_iter * num_interactions) << " micro" << endl;
- cout << "--- TOTAL TIME ---: " << (time_init + time_choose) << " micro" << endl;
-}
-
-int main(int argc, char* argv[])
-{
- Clock_Explore();
- return 0;
-}
+// explore.cpp : Timing code to measure performance of MWT Explorer library
+
+#include "MWTExplorer.h"
+#include <chrono>
+#include <tuple>
+#include <iostream>
+
+using namespace std;
+using namespace std::chrono;
+
+using namespace MultiWorldTesting;
+
+class MySimplePolicy : public IPolicy<SimpleContext>
+{
+public:
+ u32 Choose_Action(SimpleContext& context)
+ {
+ return (u32)1;
+ }
+};
+
+const u32 num_actions = 10;
+
+void Clock_Explore()
+{
+ float epsilon = .2f;
+ string unique_key = "key";
+ int num_features = 1000;
+ int num_iter = 10000;
+ int num_warmup = 100;
+ int num_interactions = 1;
+
+ // pre-create features
+ vector<Feature> features;
+ for (int i = 0; i < num_features; i++)
+ {
+ Feature f = {0.5, i+1};
+ features.push_back(f);
+ }
+
+ long long time_init = 0, time_choose = 0;
+ for (int iter = 0; iter < num_iter + num_warmup; iter++)
+ {
+ high_resolution_clock::time_point t1 = high_resolution_clock::now();
+ StringRecorder<SimpleContext> recorder;
+ MwtExplorer<SimpleContext> mwt("test", recorder);
+ MySimplePolicy default_policy;
+ EpsilonGreedyExplorer<SimpleContext> explorer(default_policy, epsilon, num_actions);
+ high_resolution_clock::time_point t2 = high_resolution_clock::now();
+ time_init += iter < num_warmup ? 0 : duration_cast<chrono::microseconds>(t2 - t1).count();
+
+ t1 = high_resolution_clock::now();
+ SimpleContext appContext(features);
+ for (int i = 0; i < num_interactions; i++)
+ {
+ mwt.Choose_Action(explorer, unique_key, appContext);
+ }
+ t2 = high_resolution_clock::now();
+ time_choose += iter < num_warmup ? 0 : duration_cast<chrono::microseconds>(t2 - t1).count();
+ }
+
+ cout << "# iterations: " << num_iter << ", # interactions: " << num_interactions << ", # context features: " << num_features << endl;
+ cout << "--- PER ITERATION ---" << endl;
+ cout << "Init: " << (double)time_init / num_iter << " micro" << endl;
+ cout << "Choose Action: " << (double)time_choose / (num_iter * num_interactions) << " micro" << endl;
+ cout << "--- TOTAL TIME ---: " << (time_init + time_choose) << " micro" << endl;
+}
+
+int main(int argc, char* argv[])
+{
+ Clock_Explore();
+ return 0;
+}
diff --git a/explore/static/MWTExplorer.h b/explore/static/MWTExplorer.h
index dd5e3044..4c7009b2 100644
--- a/explore/static/MWTExplorer.h
+++ b/explore/static/MWTExplorer.h
@@ -1,669 +1,669 @@
-//
-// Main interface for clients of the Multiworld testing (MWT) service.
-//
-
-#pragma once
-
-#include <stdexcept>
-#include <float.h>
-#include <math.h>
-#include <stdio.h>
-#include <string.h>
-#include <vector>
-#include <utility>
-#include <memory>
-#include <limits.h>
-#include <tuple>
-
-#ifdef MANAGED_CODE
-#define PORTING_INTERFACE public
-#define MWT_NAMESPACE namespace NativeMultiWorldTesting
-#else
-#define PORTING_INTERFACE private
-#define MWT_NAMESPACE namespace MultiWorldTesting
-#endif
-
-using namespace std;
-
-#include "utility.h"
-
-/** \defgroup MultiWorldTestingCpp
-\brief C++ implementation, for sample usage see: https://github.com/sidsen/vowpal_wabbit/blob/v0/explore/explore_sample.cpp
-*/
-
-/*!
-* \addtogroup MultiWorldTestingCpp
-* @{
-*/
-
-//! Interface for C++ version of Multiworld Testing library.
-//! For sample usage see: https://github.com/sidsen/vowpal_wabbit/blob/v0/explore/explore_sample.cpp
-MWT_NAMESPACE {
-
-// Forward declarations
-template <class Ctx>
-class IRecorder;
-template <class Ctx>
-class IExplorer;
-
-///
-/// The top-level MwtExplorer class. Using this enables principled and efficient exploration
-/// over a set of possible actions, and ensures that the right bits are recorded.
-///
-template <class Ctx>
-class MwtExplorer
-{
-public:
- ///
- /// Constructor
- ///
- /// @param appid This should be unique to your experiment or you risk nasty correlation bugs.
- /// @param recorder A user-specified class for recording the appropriate bits for use in evaluation and learning.
- ///
- MwtExplorer(std::string app_id, IRecorder<Ctx>& recorder) : m_recorder(recorder)
- {
- m_app_id = HashUtils::Compute_Id_Hash(app_id);
- }
-
- ///
- /// Chooses an action by invoking an underlying exploration algorithm. This should be a
- /// drop-in replacement for any existing policy function.
- ///
- /// @param explorer An existing exploration algorithm (one of the below) which uses the default policy as a callback.
- /// @param unique_key A unique identifier for the experimental unit. This could be a user id, a session id, etc..
- /// @param context The context upon which a decision is made. See SimpleContext below for an example.
- ///
- u32 Choose_Action(IExplorer<Ctx>& explorer, string unique_key, Ctx& context)
- {
- u64 seed = HashUtils::Compute_Id_Hash(unique_key);
-
- std::tuple<u32, float, bool> action_probability_log_tuple = explorer.Choose_Action(seed + m_app_id, context);
-
- u32 action = std::get<0>(action_probability_log_tuple);
- float prob = std::get<1>(action_probability_log_tuple);
-
- if (std::get<2>(action_probability_log_tuple))
- {
- m_recorder.Record(context, action, prob, unique_key);
- }
-
- return action;
- }
-
-private:
- u64 m_app_id;
- IRecorder<Ctx>& m_recorder;
-};
-
-///
-/// Exposes a method to record exploration data based on generic contexts. Exploration data
-/// is specified as a set of tuples <context, action, probability, key> as described below. An
-/// application passes an IRecorder object to the @MwtExplorer constructor. See
-/// @StringRecorder for a sample IRecorder object.
-///
-template <class Ctx>
-class IRecorder
-{
-public:
- ///
- /// Records the exploration data associated with a given decision.
- ///
- /// @param context A user-defined context for the decision
- /// @param action The action chosen by an exploration algorithm given context
- /// @param probability The probability the exploration algorithm chose said action
- /// @param unique_key A user-defined unique identifer for the decision
- ///
- virtual void Record(Ctx& context, u32 action, float probability, string unique_key) = 0;
-};
-
-///
-/// Exposes a method to choose an action given a generic context, and obtain the relevant
-/// exploration bits. Invokes IPolicy::Choose_Action internally. Do not implement this
-/// interface yourself: instead, use the various exploration algorithms below, which
-/// implement it for you.
-///
-template <class Ctx>
-class IExplorer
-{
-public:
- ///
- /// Determines the action to take and the probability with which it was chosen, for a
- /// given context.
- ///
- /// @param salted_seed A PRG seed based on a unique id information provided by the user
- /// @param context A user-defined context for the decision
- /// @returns The action to take, the probability it was chosen, and a flag indicating
- /// whether to record this decision
- ///
- virtual std::tuple<u32, float, bool> Choose_Action(u64 salted_seed, Ctx& context) = 0;
-};
-
-///
-/// Exposes a method to choose an action given a generic context. IPolicy objects are
-/// passed to (and invoked by) exploration algorithms to specify the default policy behavior.
-///
-template <class Ctx>
-class IPolicy
-{
-public:
- ///
- /// Determines the action to take for a given context.
- ///
- /// @param context A user-defined context for the decision
- /// @returns The action to take (1-based index)
- ///
- virtual u32 Choose_Action(Ctx& context) = 0;
-};
-
-///
-/// Exposes a method for specifying a score (weight) for each action given a generic context.
-///
-template <class Ctx>
-class IScorer
-{
-public:
- ///
- /// Determines the score of each action for a given context.
- ///
- /// @param context A user-defined context for the decision
- /// @returns A vector of scores indexed by action (1-based)
- ///
- virtual vector<float> Score_Actions(Ctx& context) = 0;
-};
-
-///
-/// A sample recorder class that converts the exploration tuple into string format.
-///
-template <class Ctx>
-struct StringRecorder : public IRecorder<Ctx>
-{
- void Record(Ctx& context, u32 action, float probability, string unique_key)
- {
- // Implicitly enforce To_String() API on the context
- m_recording.append(to_string((unsigned long)action));
- m_recording.append(" ", 1);
- m_recording.append(unique_key);
- m_recording.append(" ", 1);
-
- char prob_str[10] = { 0 };
- NumberUtils::Float_To_String(probability, prob_str);
- m_recording.append(prob_str);
-
- m_recording.append(" | ", 3);
- m_recording.append(context.To_String());
- m_recording.append("\n");
- }
-
- // Gets the content of the recording so far as a string and optionally clears internal content.
- string Get_Recording(bool flush = true)
- {
- if (!flush)
- {
- return m_recording;
- }
- string recording = m_recording;
- m_recording.clear();
- return recording;
- }
-
-private:
- string m_recording;
-};
-
-///
-/// Represents a feature in a sparse array.
-///
-struct Feature
-{
- float Value;
- u32 Id;
-
- bool operator==(Feature other_feature)
- {
- return Id == other_feature.Id;
- }
-};
-
-///
-/// A sample context class that stores a vector of Features.
-///
-class SimpleContext
-{
-public:
- SimpleContext(vector<Feature>& features) :
- m_features(features)
- { }
-
- vector<Feature>& Get_Features()
- {
- return m_features;
- }
-
- string To_String()
- {
- string out_string;
- char feature_str[35] = { 0 };
- for (size_t i = 0; i < m_features.size(); i++)
- {
- int chars;
- if (i == 0)
- {
- chars = sprintf(feature_str, "%d:", m_features[i].Id);
- }
- else
- {
- chars = sprintf(feature_str, " %d:", m_features[i].Id);
- }
- NumberUtils::print_float(feature_str + chars, m_features[i].Value);
- out_string.append(feature_str);
- }
- return out_string;
- }
-
-private:
- vector<Feature>& m_features;
-};
-
-///
-/// The epsilon greedy exploration algorithm. This is a good choice if you have no idea
-/// which actions should be preferred. Epsilon greedy is also computationally cheap.
-///
-template <class Ctx>
-class EpsilonGreedyExplorer : public IExplorer<Ctx>
-{
-public:
- ///
- /// The constructor is the only public member, because this should be used with the MwtExplorer.
- ///
- /// @param default_policy A default function which outputs an action given a context.
- /// @param epsilon The probability of a random exploration.
- /// @param num_actions The number of actions to randomize over.
- ///
- EpsilonGreedyExplorer(IPolicy<Ctx>& default_policy, float epsilon, u32 num_actions) :
- m_default_policy(default_policy), m_epsilon(epsilon), m_num_actions(num_actions)
- {
- if (m_num_actions < 1)
- {
- throw std::invalid_argument("Number of actions must be at least 1.");
- }
-
- if (m_epsilon < 0 || m_epsilon > 1)
- {
- throw std::invalid_argument("Epsilon must be between 0 and 1.");
- }
- }
-
- ~EpsilonGreedyExplorer()
- {
- }
-
-private:
- std::tuple<u32, float, bool> Choose_Action(u64 salted_seed, Ctx& context)
- {
- PRG::prg random_generator(salted_seed);
-
- // Invoke the default policy function to get the action
- u32 chosen_action = m_default_policy.Choose_Action(context);
-
- if (chosen_action == 0 || chosen_action > m_num_actions)
- {
- throw std::invalid_argument("Action chosen by default policy is not within valid range.");
- }
-
- float action_probability = 0.f;
- float base_probability = m_epsilon / m_num_actions; // uniform probability
-
- // TODO: check this random generation
- if (random_generator.Uniform_Unit_Interval() < 1.f - m_epsilon)
- {
- action_probability = 1.f - m_epsilon + base_probability;
- }
- else
- {
- // Get uniform random action ID
- u32 actionId = random_generator.Uniform_Int(1, m_num_actions);
-
- if (actionId == chosen_action)
- {
- // IF it matches the one chosen by the default policy
- // then increase the probability
- action_probability = 1.f - m_epsilon + base_probability;
- }
- else
- {
- // Otherwise it's just the uniform probability
- action_probability = base_probability;
- }
- chosen_action = actionId;
- }
-
- return std::tuple<u32, float, bool>(chosen_action, action_probability, true);
- }
-
-private:
- IPolicy<Ctx>& m_default_policy;
- float m_epsilon;
- u32 m_num_actions;
-
-private:
- friend class MwtExplorer<Ctx>;
-};
-
-///
-/// In some cases, different actions have a different scores, and you would prefer to
-/// choose actions with large scores. Softmax allows you to do that.
-///
-template <class Ctx>
-class SoftmaxExplorer : public IExplorer<Ctx>
-{
-public:
- ///
- /// The constructor is the only public member, because this should be used with the MwtExplorer.
- ///
- /// @param default_scorer A function which outputs a score for each action.
- /// @param lambda lambda = 0 implies uniform distribution. Large lambda is equivalent to a max.
- /// @param num_actions The number of actions to randomize over.
- ///
- SoftmaxExplorer(IScorer<Ctx>& default_scorer, float lambda, u32 num_actions) :
- m_default_scorer(default_scorer), m_lambda(lambda), m_num_actions(num_actions)
- {
- if (m_num_actions < 1)
- {
- throw std::invalid_argument("Number of actions must be at least 1.");
- }
- }
-
-private:
- std::tuple<u32, float, bool> Choose_Action(u64 salted_seed, Ctx& context)
- {
- PRG::prg random_generator(salted_seed);
-
- // Invoke the default scorer function
- vector<float> scores = m_default_scorer.Score_Actions(context);
- u32 num_scores = (u32)scores.size();
- if (num_scores != m_num_actions)
- {
- throw std::invalid_argument("The number of scores returned by the scorer must equal number of actions");
- }
-
- u32 i = 0;
-
- float max_score = -FLT_MAX;
- for (i = 0; i < num_scores; i++)
- {
- if (max_score < scores[i])
- {
- max_score = scores[i];
- }
- }
-
- // Create a normalized exponential distribution based on the returned scores
- for (i = 0; i < num_scores; i++)
- {
- scores[i] = exp(m_lambda * (scores[i] - max_score));
- }
-
- // Create a discrete_distribution based on the returned weights. This class handles the
- // case where the sum of the weights is < or > 1, by normalizing agains the sum.
- float total = 0.f;
- for (size_t i = 0; i < num_scores; i++)
- total += scores[i];
-
- float draw = random_generator.Uniform_Unit_Interval();
-
- float sum = 0.f;
- float action_probability = 0.f;
- u32 action_index = num_scores - 1;
- for (u32 i = 0; i < num_scores; i++)
- {
- scores[i] = scores[i] / total;
- sum += scores[i];
- if (sum > draw)
- {
- action_index = i;
- action_probability = scores[i];
- break;
- }
- }
-
- // action id is one-based
- return std::tuple<u32, float, bool>(action_index + 1, action_probability, true);
- }
-
-private:
- IScorer<Ctx>& m_default_scorer;
- float m_lambda;
- u32 m_num_actions;
-
-private:
- friend class MwtExplorer<Ctx>;
-};
-
-///
-/// GenericExplorer provides complete flexibility. You can create any
-/// distribution over actions desired, and it will draw from that.
-///
-template <class Ctx>
-class GenericExplorer : public IExplorer<Ctx>
-{
-public:
- ///
- /// The constructor is the only public member, because this should be used with the MwtExplorer.
- ///
- /// @param default_scorer A function which outputs the probability of each action.
- /// @param num_actions The number of actions to randomize over.
- ///
- GenericExplorer(IScorer<Ctx>& default_scorer, u32 num_actions) :
- m_default_scorer(default_scorer), m_num_actions(num_actions)
- {
- if (m_num_actions < 1)
- {
- throw std::invalid_argument("Number of actions must be at least 1.");
- }
- }
-
-private:
- std::tuple<u32, float, bool> Choose_Action(u64 salted_seed, Ctx& context)
- {
- PRG::prg random_generator(salted_seed);
-
- // Invoke the default scorer function
- vector<float> weights = m_default_scorer.Score_Actions(context);
- u32 num_weights = (u32)weights.size();
- if (num_weights != m_num_actions)
- {
- throw std::invalid_argument("The number of weights returned by the scorer must equal number of actions");
- }
-
- // Create a discrete_distribution based on the returned weights. This class handles the
- // case where the sum of the weights is < or > 1, by normalizing agains the sum.
- float total = 0.f;
- for (size_t i = 0; i < num_weights; i++)
- {
- if (weights[i] < 0)
- {
- throw std::invalid_argument("Scores must be non-negative.");
- }
- total += weights[i];
- }
- if (total == 0)
- {
- throw std::invalid_argument("At least one score must be positive.");
- }
-
- float draw = random_generator.Uniform_Unit_Interval();
-
- float sum = 0.f;
- float action_probability = 0.f;
- u32 action_index = num_weights - 1;
- for (u32 i = 0; i < num_weights; i++)
- {
- weights[i] = weights[i] / total;
- sum += weights[i];
- if (sum > draw)
- {
- action_index = i;
- action_probability = weights[i];
- break;
- }
- }
-
- // action id is one-based
- return std::tuple<u32, float, bool>(action_index + 1, action_probability, true);
- }
-
-private:
- IScorer<Ctx>& m_default_scorer;
- u32 m_num_actions;
-
-private:
- friend class MwtExplorer<Ctx>;
-};
-
-///
-/// The tau-first explorer collects exactly tau uniform random exploration events, and then
-/// uses the default policy thereafter.
-///
-template <class Ctx>
-class TauFirstExplorer : public IExplorer<Ctx>
-{
-public:
-
- ///
- /// The constructor is the only public member, because this should be used with the MwtExplorer.
- ///
- /// @param default_policy A default policy after randomization finishes.
- /// @param tau The number of events to be uniform over.
- /// @param num_actions The number of actions to randomize over.
- ///
- TauFirstExplorer(IPolicy<Ctx>& default_policy, u32 tau, u32 num_actions) :
- m_default_policy(default_policy), m_tau(tau), m_num_actions(num_actions)
- {
- if (m_num_actions < 1)
- {
- throw std::invalid_argument("Number of actions must be at least 1.");
- }
- }
-
-private:
- std::tuple<u32, float, bool> Choose_Action(u64 salted_seed, Ctx& context)
- {
- PRG::prg random_generator(salted_seed);
-
- u32 chosen_action = 0;
- float action_probability = 0.f;
- bool log_action;
- if (m_tau)
- {
- m_tau--;
- u32 actionId = random_generator.Uniform_Int(1, m_num_actions);
- action_probability = 1.f / m_num_actions;
- chosen_action = actionId;
- log_action = true;
- }
- else
- {
- // Invoke the default policy function to get the action
- chosen_action = m_default_policy.Choose_Action(context);
-
- if (chosen_action == 0 || chosen_action > m_num_actions)
- {
- throw std::invalid_argument("Action chosen by default policy is not within valid range.");
- }
-
- action_probability = 1.f;
- log_action = false;
- }
-
- return std::tuple<u32, float, bool>(chosen_action, action_probability, log_action);
- }
-
-private:
- IPolicy<Ctx>& m_default_policy;
- u32 m_tau;
- u32 m_num_actions;
-
-private:
- friend class MwtExplorer<Ctx>;
-};
-
-///
-/// The Bootstrap explorer randomizes over the actions chosen by a set of default policies.
-/// This performs well statistically but can be computationally expensive.
-///
-template <class Ctx>
-class BootstrapExplorer : public IExplorer<Ctx>
-{
-public:
- ///
- /// The constructor is the only public member, because this should be used with the MwtExplorer.
- ///
- /// @param default_policy_functions A set of default policies to be uniform random over.
- /// The policy pointers must be valid throughout the lifetime of this explorer.
- /// @param num_actions The number of actions to randomize over.
- ///
- BootstrapExplorer(vector<unique_ptr<IPolicy<Ctx>>>& default_policy_functions, u32 num_actions) :
- m_default_policy_functions(default_policy_functions),
- m_num_actions(num_actions)
- {
- m_bags = (u32)default_policy_functions.size();
- if (m_num_actions < 1)
- {
- throw std::invalid_argument("Number of actions must be at least 1.");
- }
-
- if (m_bags < 1)
- {
- throw std::invalid_argument("Number of bags must be at least 1.");
- }
- }
-
-private:
- std::tuple<u32, float, bool> Choose_Action(u64 salted_seed, Ctx& context)
- {
- PRG::prg random_generator(salted_seed);
-
- // Select bag
- u32 chosen_bag = random_generator.Uniform_Int(0, m_bags - 1);
-
- // Invoke the default policy function to get the action
- u32 chosen_action = 0;
- u32 action_from_bag = 0;
- vector<u32> actions_selected;
- for (size_t i = 0; i < m_num_actions; i++)
- {
- actions_selected.push_back(0);
- }
-
- // Invoke the default policy function to get the action
- for (u32 current_bag = 0; current_bag < m_bags; current_bag++)
- {
- action_from_bag = m_default_policy_functions[current_bag]->Choose_Action(context);
-
- if (action_from_bag == 0 || action_from_bag > m_num_actions)
- {
- throw std::invalid_argument("Action chosen by default policy is not within valid range.");
- }
-
- if (current_bag == chosen_bag)
- {
- chosen_action = action_from_bag;
- }
- //this won't work if actions aren't 0 to Count
- actions_selected[action_from_bag - 1]++; // action id is one-based
- }
- float action_probability = (float)actions_selected[chosen_action - 1] / m_bags; // action id is one-based
-
- return std::tuple<u32, float, bool>(chosen_action, action_probability, true);
- }
-
-private:
- vector<unique_ptr<IPolicy<Ctx>>>& m_default_policy_functions;
- u32 m_bags;
- u32 m_num_actions;
-
-private:
- friend class MwtExplorer<Ctx>;
-};
-} // End namespace MultiWorldTestingCpp
-/*! @} End of Doxygen Groups*/
+//
+// Main interface for clients of the Multiworld testing (MWT) service.
+//
+
+#pragma once
+
+#include <stdexcept>
+#include <float.h>
+#include <math.h>
+#include <stdio.h>
+#include <string.h>
+#include <vector>
+#include <utility>
+#include <memory>
+#include <limits.h>
+#include <tuple>
+
+#ifdef MANAGED_CODE
+#define PORTING_INTERFACE public
+#define MWT_NAMESPACE namespace NativeMultiWorldTesting
+#else
+#define PORTING_INTERFACE private
+#define MWT_NAMESPACE namespace MultiWorldTesting
+#endif
+
+using namespace std;
+
+#include "utility.h"
+
+/** \defgroup MultiWorldTestingCpp
+\brief C++ implementation, for sample usage see: https://github.com/sidsen/vowpal_wabbit/blob/v0/explore/explore_sample.cpp
+*/
+
+/*!
+* \addtogroup MultiWorldTestingCpp
+* @{
+*/
+
+//! Interface for C++ version of Multiworld Testing library.
+//! For sample usage see: https://github.com/sidsen/vowpal_wabbit/blob/v0/explore/explore_sample.cpp
+MWT_NAMESPACE {
+
+// Forward declarations
+template <class Ctx>
+class IRecorder;
+template <class Ctx>
+class IExplorer;
+
+///
+/// The top-level MwtExplorer class. Using this enables principled and efficient exploration
+/// over a set of possible actions, and ensures that the right bits are recorded.
+///
+template <class Ctx>
+class MwtExplorer
+{
+public:
+ ///
+ /// Constructor
+ ///
+ /// @param appid This should be unique to your experiment or you risk nasty correlation bugs.
+ /// @param recorder A user-specified class for recording the appropriate bits for use in evaluation and learning.
+ ///
+ MwtExplorer(std::string app_id, IRecorder<Ctx>& recorder) : m_recorder(recorder)
+ {
+ m_app_id = HashUtils::Compute_Id_Hash(app_id);
+ }
+
+ ///
+ /// Chooses an action by invoking an underlying exploration algorithm. This should be a
+ /// drop-in replacement for any existing policy function.
+ ///
+ /// @param explorer An existing exploration algorithm (one of the below) which uses the default policy as a callback.
+ /// @param unique_key A unique identifier for the experimental unit. This could be a user id, a session id, etc..
+ /// @param context The context upon which a decision is made. See SimpleContext below for an example.
+ ///
+ u32 Choose_Action(IExplorer<Ctx>& explorer, string unique_key, Ctx& context)
+ {
+ u64 seed = HashUtils::Compute_Id_Hash(unique_key);
+
+ std::tuple<u32, float, bool> action_probability_log_tuple = explorer.Choose_Action(seed + m_app_id, context);
+
+ u32 action = std::get<0>(action_probability_log_tuple);
+ float prob = std::get<1>(action_probability_log_tuple);
+
+ if (std::get<2>(action_probability_log_tuple))
+ {
+ m_recorder.Record(context, action, prob, unique_key);
+ }
+
+ return action;
+ }
+
+private:
+ u64 m_app_id;
+ IRecorder<Ctx>& m_recorder;
+};
+
+///
+/// Exposes a method to record exploration data based on generic contexts. Exploration data
+/// is specified as a set of tuples <context, action, probability, key> as described below. An
+/// application passes an IRecorder object to the @MwtExplorer constructor. See
+/// @StringRecorder for a sample IRecorder object.
+///
+template <class Ctx>
+class IRecorder
+{
+public:
+ ///
+ /// Records the exploration data associated with a given decision.
+ ///
+ /// @param context A user-defined context for the decision
+ /// @param action The action chosen by an exploration algorithm given context
+ /// @param probability The probability the exploration algorithm chose said action
+ /// @param unique_key A user-defined unique identifer for the decision
+ ///
+ virtual void Record(Ctx& context, u32 action, float probability, string unique_key) = 0;
+};
+
+///
+/// Exposes a method to choose an action given a generic context, and obtain the relevant
+/// exploration bits. Invokes IPolicy::Choose_Action internally. Do not implement this
+/// interface yourself: instead, use the various exploration algorithms below, which
+/// implement it for you.
+///
+template <class Ctx>
+class IExplorer
+{
+public:
+ ///
+ /// Determines the action to take and the probability with which it was chosen, for a
+ /// given context.
+ ///
+ /// @param salted_seed A PRG seed based on a unique id information provided by the user
+ /// @param context A user-defined context for the decision
+ /// @returns The action to take, the probability it was chosen, and a flag indicating
+ /// whether to record this decision
+ ///
+ virtual std::tuple<u32, float, bool> Choose_Action(u64 salted_seed, Ctx& context) = 0;
+};
+
+///
+/// Exposes a method to choose an action given a generic context. IPolicy objects are
+/// passed to (and invoked by) exploration algorithms to specify the default policy behavior.
+///
+template <class Ctx>
+class IPolicy
+{
+public:
+ ///
+ /// Determines the action to take for a given context.
+ ///
+ /// @param context A user-defined context for the decision
+ /// @returns The action to take (1-based index)
+ ///
+ virtual u32 Choose_Action(Ctx& context) = 0;
+};
+
+///
+/// Exposes a method for specifying a score (weight) for each action given a generic context.
+///
+template <class Ctx>
+class IScorer
+{
+public:
+ ///
+ /// Determines the score of each action for a given context.
+ ///
+ /// @param context A user-defined context for the decision
+ /// @returns A vector of scores indexed by action (1-based)
+ ///
+ virtual vector<float> Score_Actions(Ctx& context) = 0;
+};
+
+///
+/// A sample recorder class that converts the exploration tuple into string format.
+///
+template <class Ctx>
+struct StringRecorder : public IRecorder<Ctx>
+{
+ void Record(Ctx& context, u32 action, float probability, string unique_key)
+ {
+ // Implicitly enforce To_String() API on the context
+ m_recording.append(to_string((unsigned long)action));
+ m_recording.append(" ", 1);
+ m_recording.append(unique_key);
+ m_recording.append(" ", 1);
+
+ char prob_str[10] = { 0 };
+ NumberUtils::Float_To_String(probability, prob_str);
+ m_recording.append(prob_str);
+
+ m_recording.append(" | ", 3);
+ m_recording.append(context.To_String());
+ m_recording.append("\n");
+ }
+
+ // Gets the content of the recording so far as a string and optionally clears internal content.
+ string Get_Recording(bool flush = true)
+ {
+ if (!flush)
+ {
+ return m_recording;
+ }
+ string recording = m_recording;
+ m_recording.clear();
+ return recording;
+ }
+
+private:
+ string m_recording;
+};
+
+///
+/// Represents a feature in a sparse array.
+///
+struct Feature
+{
+ float Value;
+ u32 Id;
+
+ bool operator==(Feature other_feature)
+ {
+ return Id == other_feature.Id;
+ }
+};
+
+///
+/// A sample context class that stores a vector of Features.
+///
+class SimpleContext
+{
+public:
+ SimpleContext(vector<Feature>& features) :
+ m_features(features)
+ { }
+
+ vector<Feature>& Get_Features()
+ {
+ return m_features;
+ }
+
+ string To_String()
+ {
+ string out_string;
+ char feature_str[35] = { 0 };
+ for (size_t i = 0; i < m_features.size(); i++)
+ {
+ int chars;
+ if (i == 0)
+ {
+ chars = sprintf(feature_str, "%d:", m_features[i].Id);
+ }
+ else
+ {
+ chars = sprintf(feature_str, " %d:", m_features[i].Id);
+ }
+ NumberUtils::print_float(feature_str + chars, m_features[i].Value);
+ out_string.append(feature_str);
+ }
+ return out_string;
+ }
+
+private:
+ vector<Feature>& m_features;
+};
+
+///
+/// The epsilon greedy exploration algorithm. This is a good choice if you have no idea
+/// which actions should be preferred. Epsilon greedy is also computationally cheap.
+///
+template <class Ctx>
+class EpsilonGreedyExplorer : public IExplorer<Ctx>
+{
+public:
+    ///
+    /// Only the constructor is public; action selection is reached through MwtExplorer,
+    /// which is declared a friend of this class.
+    ///
+    /// @param default_policy A default function which outputs an action given a context. Held by reference.
+    /// @param epsilon The probability of a random exploration; must lie in [0, 1].
+    /// @param num_actions The number of actions to randomize over; must be at least 1.
+    /// @throws std::invalid_argument if num_actions or epsilon is out of range.
+    ///
+    EpsilonGreedyExplorer(IPolicy<Ctx>& default_policy, float epsilon, u32 num_actions)
+        : m_default_policy(default_policy), m_epsilon(epsilon), m_num_actions(num_actions)
+    {
+        if (m_num_actions < 1)
+            throw std::invalid_argument("Number of actions must be at least 1.");
+
+        if (m_epsilon < 0 || m_epsilon > 1)
+            throw std::invalid_argument("Epsilon must be between 0 and 1.");
+    }
+
+    ~EpsilonGreedyExplorer()
+    {
+    }
+
+private:
+    ///
+    /// Picks an action for the given context.
+    ///
+    /// @param salted_seed Seed for the per-call random generator.
+    /// @param context Context forwarded to the default policy.
+    /// @returns (one-based action id, probability of that action, true).
+    /// @throws std::invalid_argument if the default policy returns 0 or an id above num_actions.
+    ///
+    std::tuple<u32, float, bool> Choose_Action(u64 salted_seed, Ctx& context)
+    {
+        PRG::prg rng(salted_seed);
+
+        // The default policy is always consulted, even when we end up exploring.
+        const u32 policy_action = m_default_policy.Choose_Action(context);
+        if (policy_action == 0 || policy_action > m_num_actions)
+            throw std::invalid_argument("Action chosen by default policy is not within valid range.");
+
+        // Each action receives epsilon / num_actions from the uniform part; the policy's
+        // action additionally receives the remaining 1 - epsilon.
+        const float uniform_share = m_epsilon / m_num_actions;
+        const float greedy_probability = 1.f - m_epsilon + uniform_share;
+
+        // TODO: check this random generation
+        if (rng.Uniform_Unit_Interval() < 1.f - m_epsilon)
+        {
+            // Exploit: follow the default policy.
+            return std::tuple<u32, float, bool>(policy_action, greedy_probability, true);
+        }
+
+        // Explore: draw a uniform action id in [1, num_actions]. It may coincide with the
+        // policy's action, in which case the larger probability applies.
+        const u32 random_action = rng.Uniform_Int(1, m_num_actions);
+        const float random_probability =
+            (random_action == policy_action) ? greedy_probability : uniform_share;
+        return std::tuple<u32, float, bool>(random_action, random_probability, true);
+    }
+
+private:
+    IPolicy<Ctx>& m_default_policy;
+    float m_epsilon;
+    u32 m_num_actions;
+
+private:
+    friend class MwtExplorer<Ctx>;
+};
+
+///
+/// In some cases, different actions have different scores, and you would prefer to
+/// choose actions with large scores. Softmax allows you to do that.
+///
+template <class Ctx>
+class SoftmaxExplorer : public IExplorer<Ctx>
+{
+public:
+    ///
+    /// The constructor is the only public member, because this should be used with the MwtExplorer.
+    ///
+    /// @param default_scorer A function which outputs a score for each action. Held by reference.
+    /// @param lambda lambda = 0 implies uniform distribution. Large lambda is equivalent to a max.
+    /// @param num_actions The number of actions to randomize over; must be at least 1.
+    /// @throws std::invalid_argument if num_actions is 0.
+    ///
+    SoftmaxExplorer(IScorer<Ctx>& default_scorer, float lambda, u32 num_actions) :
+        m_default_scorer(default_scorer), m_lambda(lambda), m_num_actions(num_actions)
+    {
+        if (m_num_actions < 1)
+        {
+            throw std::invalid_argument("Number of actions must be at least 1.");
+        }
+    }
+
+private:
+    ///
+    /// Samples an action from the softmax distribution over the scorer's output.
+    ///
+    /// @param salted_seed Seed for the per-call random generator.
+    /// @param context Context forwarded to the scorer.
+    /// @returns (one-based action id, probability of that action, true).
+    /// @throws std::invalid_argument if the scorer does not return exactly num_actions scores.
+    ///
+    std::tuple<u32, float, bool> Choose_Action(u64 salted_seed, Ctx& context)
+    {
+        PRG::prg random_generator(salted_seed);
+
+        // Invoke the default scorer function
+        vector<float> scores = m_default_scorer.Score_Actions(context);
+        const u32 num_scores = (u32)scores.size();
+        if (num_scores != m_num_actions)
+        {
+            throw std::invalid_argument("The number of scores returned by the scorer must equal number of actions");
+        }
+
+        float max_score = -FLT_MAX;
+        for (u32 i = 0; i < num_scores; i++)
+        {
+            if (max_score < scores[i])
+            {
+                max_score = scores[i];
+            }
+        }
+
+        // Turn the scores into unnormalized softmax weights. Subtracting the maximum keeps
+        // the exp() argument non-positive for non-negative lambda, which avoids overflow.
+        float total = 0.f;
+        for (u32 i = 0; i < num_scores; i++)
+        {
+            scores[i] = exp(m_lambda * (scores[i] - max_score));
+            total += scores[i];
+        }
+
+        const float draw = random_generator.Uniform_Unit_Interval();
+
+        // Default to the last action: floating-point rounding can leave the cumulative sum
+        // at or below the draw. The fallback must report that action's real probability,
+        // not 0, so the default is initialized with it up front.
+        u32 action_index = num_scores - 1;
+        float action_probability = scores[action_index] / total;
+
+        // Walk the cumulative distribution until it passes the draw.
+        float sum = 0.f;
+        for (u32 i = 0; i < num_scores; i++)
+        {
+            const float probability = scores[i] / total;
+            sum += probability;
+            if (sum > draw)
+            {
+                action_index = i;
+                action_probability = probability;
+                break;
+            }
+        }
+
+        // action id is one-based
+        return std::tuple<u32, float, bool>(action_index + 1, action_probability, true);
+    }
+
+private:
+    IScorer<Ctx>& m_default_scorer;
+    float m_lambda;
+    u32 m_num_actions;
+
+private:
+    friend class MwtExplorer<Ctx>;
+};
+
+///
+/// GenericExplorer provides complete flexibility. You can create any
+/// distribution over actions desired, and it will draw from that.
+///
+template <class Ctx>
+class GenericExplorer : public IExplorer<Ctx>
+{
+public:
+    ///
+    /// The constructor is the only public member, because this should be used with the MwtExplorer.
+    ///
+    /// @param default_scorer A function which outputs the probability of each action. Held by reference.
+    /// @param num_actions The number of actions to randomize over; must be at least 1.
+    /// @throws std::invalid_argument if num_actions is 0.
+    ///
+    GenericExplorer(IScorer<Ctx>& default_scorer, u32 num_actions) :
+        m_default_scorer(default_scorer), m_num_actions(num_actions)
+    {
+        if (m_num_actions < 1)
+        {
+            throw std::invalid_argument("Number of actions must be at least 1.");
+        }
+    }
+
+private:
+    ///
+    /// Samples an action from the distribution described by the scorer's weights.
+    /// The weights need not sum to 1; they are normalized against their sum.
+    ///
+    /// @param salted_seed Seed for the per-call random generator.
+    /// @param context Context forwarded to the scorer.
+    /// @returns (one-based action id, normalized probability of that action, true).
+    /// @throws std::invalid_argument if the number of weights differs from num_actions,
+    ///         any weight is negative, or all weights are zero.
+    ///
+    std::tuple<u32, float, bool> Choose_Action(u64 salted_seed, Ctx& context)
+    {
+        PRG::prg random_generator(salted_seed);
+
+        // Invoke the default scorer function
+        vector<float> weights = m_default_scorer.Score_Actions(context);
+        const u32 num_weights = (u32)weights.size();
+        if (num_weights != m_num_actions)
+        {
+            throw std::invalid_argument("The number of weights returned by the scorer must equal number of actions");
+        }
+
+        // Validate the weights, accumulate their sum, and remember the last action with a
+        // positive weight; it serves as the fallback choice below.
+        float total = 0.f;
+        u32 last_positive = num_weights - 1;
+        for (u32 i = 0; i < num_weights; i++)
+        {
+            if (weights[i] < 0)
+            {
+                throw std::invalid_argument("Scores must be non-negative.");
+            }
+            if (weights[i] > 0)
+            {
+                last_positive = i;
+            }
+            total += weights[i];
+        }
+        if (total == 0)
+        {
+            throw std::invalid_argument("At least one score must be positive.");
+        }
+
+        const float draw = random_generator.Uniform_Unit_Interval();
+
+        // Floating-point rounding can leave the cumulative sum at or below the draw. In
+        // that case fall back to the last positive-weight action with its real probability,
+        // rather than reporting probability 0 or selecting a zero-weight action.
+        u32 action_index = last_positive;
+        float action_probability = weights[action_index] / total;
+
+        // Walk the cumulative distribution until it passes the draw.
+        float sum = 0.f;
+        for (u32 i = 0; i < num_weights; i++)
+        {
+            const float probability = weights[i] / total;
+            sum += probability;
+            if (sum > draw)
+            {
+                action_index = i;
+                action_probability = probability;
+                break;
+            }
+        }
+
+        // action id is one-based
+        return std::tuple<u32, float, bool>(action_index + 1, action_probability, true);
+    }
+
+private:
+    IScorer<Ctx>& m_default_scorer;
+    u32 m_num_actions;
+
+private:
+    friend class MwtExplorer<Ctx>;
+};
+
+///
+/// The tau-first explorer collects exactly tau uniform random exploration events, and then
+/// uses the default policy thereafter.
+///
+template <class Ctx>
+class TauFirstExplorer : public IExplorer<Ctx>
+{
+public:
+    ///
+    /// Only the constructor is public; action selection is reached through MwtExplorer,
+    /// which is declared a friend of this class.
+    ///
+    /// @param default_policy A default policy after randomization finishes. Held by reference.
+    /// @param tau The number of events to be uniform over.
+    /// @param num_actions The number of actions to randomize over; must be at least 1.
+    /// @throws std::invalid_argument if num_actions is 0.
+    ///
+    TauFirstExplorer(IPolicy<Ctx>& default_policy, u32 tau, u32 num_actions)
+        : m_default_policy(default_policy), m_tau(tau), m_num_actions(num_actions)
+    {
+        if (m_num_actions < 1)
+            throw std::invalid_argument("Number of actions must be at least 1.");
+    }
+
+private:
+    ///
+    /// While the remaining tau budget is non-zero, consumes one unit of it and returns a
+    /// uniformly random action flagged for logging. Afterwards defers to the default
+    /// policy with probability 1 and the logging flag cleared.
+    ///
+    /// @param salted_seed Seed for the per-call random generator.
+    /// @param context Context forwarded to the default policy once exploration is over.
+    /// @returns (one-based action id, probability of that action, log_action flag).
+    /// @throws std::invalid_argument if the default policy returns 0 or an id above num_actions.
+    ///
+    std::tuple<u32, float, bool> Choose_Action(u64 salted_seed, Ctx& context)
+    {
+        PRG::prg rng(salted_seed);
+
+        if (m_tau == 0)
+        {
+            // Exploration budget exhausted: invoke the default policy function to get the action
+            const u32 policy_action = m_default_policy.Choose_Action(context);
+            if (policy_action == 0 || policy_action > m_num_actions)
+                throw std::invalid_argument("Action chosen by default policy is not within valid range.");
+
+            return std::tuple<u32, float, bool>(policy_action, 1.f, false);
+        }
+
+        // Still exploring: spend one unit of the budget on a uniform draw in [1, num_actions].
+        --m_tau;
+        const u32 random_action = rng.Uniform_Int(1, m_num_actions);
+        const float uniform_probability = 1.f / m_num_actions;
+        return std::tuple<u32, float, bool>(random_action, uniform_probability, true);
+    }
+
+private:
+    IPolicy<Ctx>& m_default_policy;
+    u32 m_tau;
+    u32 m_num_actions;
+
+private:
+    friend class MwtExplorer<Ctx>;
+};
+
+///
+/// The Bootstrap explorer randomizes over the actions chosen by a set of default policies.
+/// This performs well statistically but can be computationally expensive.
+///
+template <class Ctx>
+class BootstrapExplorer : public IExplorer<Ctx>
+{
+public:
+    ///
+    /// Only the constructor is public; action selection is reached through MwtExplorer,
+    /// which is declared a friend of this class.
+    ///
+    /// @param default_policy_functions A set of default policies to be uniform random over.
+    /// The policy pointers must be valid throughout the lifetime of this explorer.
+    /// @param num_actions The number of actions to randomize over; must be at least 1.
+    /// @throws std::invalid_argument if num_actions is 0 or no policies are supplied.
+    ///
+    BootstrapExplorer(vector<unique_ptr<IPolicy<Ctx>>>& default_policy_functions, u32 num_actions)
+        : m_default_policy_functions(default_policy_functions),
+          m_bags((u32)default_policy_functions.size()),
+          m_num_actions(num_actions)
+    {
+        if (m_num_actions < 1)
+            throw std::invalid_argument("Number of actions must be at least 1.");
+
+        if (m_bags < 1)
+            throw std::invalid_argument("Number of bags must be at least 1.");
+    }
+
+private:
+    ///
+    /// Picks one policy ("bag") uniformly at random and returns its action. The reported
+    /// probability is the fraction of all bags that chose that same action.
+    ///
+    /// @param salted_seed Seed for the per-call random generator.
+    /// @param context Context forwarded to every policy.
+    /// @returns (one-based action id, probability of that action, true).
+    /// @throws std::invalid_argument if any policy returns 0 or an id above num_actions.
+    ///
+    std::tuple<u32, float, bool> Choose_Action(u64 salted_seed, Ctx& context)
+    {
+        PRG::prg rng(salted_seed);
+
+        // Select bag
+        const u32 selected_bag = rng.Uniform_Int(0, m_bags - 1);
+
+        // votes[a - 1] counts how many bags chose the one-based action a.
+        vector<u32> votes(m_num_actions, 0);
+        u32 selected_action = 0;
+
+        // Every policy is evaluated, not just the selected one, to obtain the vote counts.
+        for (u32 bag = 0; bag < m_bags; ++bag)
+        {
+            const u32 bag_action = m_default_policy_functions[bag]->Choose_Action(context);
+            if (bag_action == 0 || bag_action > m_num_actions)
+                throw std::invalid_argument("Action chosen by default policy is not within valid range.");
+
+            if (bag == selected_bag)
+                selected_action = bag_action;
+
+            //this won't work if actions aren't 0 to Count
+            ++votes[bag_action - 1]; // action id is one-based
+        }
+
+        const float action_probability = (float)votes[selected_action - 1] / m_bags; // action id is one-based
+        return std::tuple<u32, float, bool>(selected_action, action_probability, true);
+    }
+
+private:
+    vector<unique_ptr<IPolicy<Ctx>>>& m_default_policy_functions;
+    u32 m_bags;
+    u32 m_num_actions;
+
+private:
+    friend class MwtExplorer<Ctx>;
+};
+} // End namespace MultiWorldTestingCpp
+/*! @} End of Doxygen Groups*/
diff --git a/explore/static/utility.h b/explore/static/utility.h
index a6284809..6a2a60d6 100644
--- a/explore/static/utility.h
+++ b/explore/static/utility.h
@@ -1,286 +1,286 @@
/*******************************************************************/
// Classes declared in this file are intended for internal use only.
-/*******************************************************************/
-
-#pragma once
-#include <stdint.h>
-#include <sys/types.h> /* defines size_t */
-
-#ifdef WIN32
-typedef unsigned __int64 u64;
-typedef unsigned __int32 u32;
-typedef unsigned __int16 u16;
-typedef unsigned __int8 u8;
-typedef signed __int64 i64;
-typedef signed __int32 i32;
-typedef signed __int16 i16;
-typedef signed __int8 i8;
-#else
-typedef uint64_t u64;
-typedef uint32_t u32;
-typedef uint16_t u16;
-typedef uint8_t u8;
-typedef int64_t i64;
-typedef int32_t i32;
-typedef int16_t i16;
-typedef int8_t i8;
-#endif
-
-typedef unsigned char byte;
-
-#include <string>
-#include <stdint.h>
-#include <math.h>
-
-/*!
-* \addtogroup MultiWorldTestingCpp
-* @{
-*/
-
-MWT_NAMESPACE {
-
-//
-// MurmurHash3, by Austin Appleby
-//
-// Originals at:
-// http://code.google.com/p/smhasher/source/browse/trunk/MurmurHash3.cpp
-// http://code.google.com/p/smhasher/source/browse/trunk/MurmurHash3.h
-//
-// Notes:
-// 1) this code assumes we can read a 4-byte value from any address
-// without crashing (i.e non aligned access is supported). This is
-// not a problem on Intel/x86/AMD64 machines (including new Macs)
-// 2) It produces different results on little-endian and big-endian machines.
-//
-//-----------------------------------------------------------------------------
-// MurmurHash3 was written by Austin Appleby, and is placed in the public
-// domain. The author hereby disclaims copyright to this source code.
-
-// Note - The x86 and x64 versions do _not_ produce the same results, as the
-// algorithms are optimized for their respective platforms. You can still
-// compile and run any of them on any platform, but your performance with the
-// non-native version will be less than optimal.
-//-----------------------------------------------------------------------------
-
-// Platform-specific functions and macros
-#if defined(_MSC_VER) // Microsoft Visual Studio
-# include <stdint.h>
-
-# include <stdlib.h>
-# define ROTL32(x,y) _rotl(x,y)
-# define BIG_CONSTANT(x) (x)
-
-#else // Other compilers
-# include <stdint.h> /* defines uint32_t etc */
-
- inline uint32_t rotl32(uint32_t x, int8_t r)
- {
- return (x << r) | (x >> (32 - r));
- }
-
-# define ROTL32(x,y) rotl32(x,y)
-# define BIG_CONSTANT(x) (x##LLU)
-
-#endif // !defined(_MSC_VER)
-
-struct murmur_hash {
-
- //-----------------------------------------------------------------------------
- // Block read - if your platform needs to do endian-swapping or can only
- // handle aligned reads, do the conversion here
-private:
- static inline uint32_t getblock(const uint32_t * p, int i)
- {
- return p[i];
- }
-
- //-----------------------------------------------------------------------------
- // Finalization mix - force all bits of a hash block to avalanche
-
- static inline uint32_t fmix(uint32_t h)
- {
- h ^= h >> 16;
- h *= 0x85ebca6b;
- h ^= h >> 13;
- h *= 0xc2b2ae35;
- h ^= h >> 16;
-
- return h;
- }
-
- //-----------------------------------------------------------------------------
-public:
- uint32_t uniform_hash(const void * key, size_t len, uint32_t seed)
- {
- const uint8_t * data = (const uint8_t*)key;
- const int nblocks = (int)len / 4;
-
- uint32_t h1 = seed;
-
- const uint32_t c1 = 0xcc9e2d51;
- const uint32_t c2 = 0x1b873593;
-
- // --- body
- const uint32_t * blocks = (const uint32_t *)(data + nblocks * 4);
-
- for (int i = -nblocks; i; i++) {
- uint32_t k1 = getblock(blocks, i);
-
- k1 *= c1;
- k1 = ROTL32(k1, 15);
- k1 *= c2;
-
- h1 ^= k1;
- h1 = ROTL32(h1, 13);
- h1 = h1 * 5 + 0xe6546b64;
- }
-
- // --- tail
- const uint8_t * tail = (const uint8_t*)(data + nblocks * 4);
-
- uint32_t k1 = 0;
-
- switch (len & 3) {
- case 3: k1 ^= tail[2] << 16;
- case 2: k1 ^= tail[1] << 8;
- case 1: k1 ^= tail[0];
- k1 *= c1; k1 = ROTL32(k1, 15); k1 *= c2; h1 ^= k1;
- }
-
- // --- finalization
- h1 ^= len;
-
- return fmix(h1);
- }
-};
-
-class HashUtils
-{
-public:
- static u64 Compute_Id_Hash(const std::string& unique_id)
- {
- size_t ret = 0;
- const char *p = unique_id.c_str();
- while (*p != '\0')
- if (*p >= '0' && *p <= '9')
- ret = 10 * ret + *(p++) - '0';
- else
- {
- murmur_hash foo;
- return foo.uniform_hash(unique_id.c_str(), unique_id.size(), 0);
- }
- return ret;
- }
-};
-
-const size_t max_int = 100000;
-const float max_float = max_int;
-const float min_float = 0.00001f;
-const size_t max_digits = (size_t) roundf((float) (-log(min_float) / log(10.)));
-
-class NumberUtils
-{
-public:
- static void Float_To_String(float f, char* str)
- {
- int x = (int)f;
- int d = (int)(fabs(f - x) * 100000);
- sprintf(str, "%d.%05d", x, d);
- }
-
- template<bool trailing_zeros>
- static void print_mantissa(char*& begin, float f)
- { // helper for print_float
- char values[10];
- size_t v = (size_t)f;
- size_t digit = 0;
- size_t first_nonzero = 0;
- for (size_t max = 1; max <= v; ++digit)
- {
- size_t max_next = max * 10;
- char v_mod = (char) (v % max_next / max);
- if (!trailing_zeros && v_mod != '\0' && first_nonzero == 0)
- first_nonzero = digit;
- values[digit] = '0' + v_mod;
- max = max_next;
- }
- if (!trailing_zeros)
- for (size_t i = max_digits; i > digit; i--)
- *begin++ = '0';
- while (digit > first_nonzero)
- *begin++ = values[--digit];
- }
-
- static void print_float(char* begin, float f)
- {
- bool sign = false;
- if (f < 0.f)
- sign = true;
- float unsigned_f = fabsf(f);
- if (unsigned_f < max_float && unsigned_f > min_float)
- {
- if (sign)
- *begin++ = '-';
- print_mantissa<true>(begin, unsigned_f);
- unsigned_f -= (size_t)unsigned_f;
- unsigned_f *= max_int;
- if (unsigned_f >= 1.f)
- {
- *begin++ = '.';
- print_mantissa<false>(begin, unsigned_f);
- }
- }
- else if (unsigned_f == 0.)
- *begin++ = '0';
- else
- {
- sprintf(begin, "%g", f);
- return;
- }
- *begin = '\0';
- return;
- }
-};
-
-//A quick implementation similar to drand48 for cross-platform compatibility
-namespace PRG {
- const uint64_t a = 0xeece66d5deece66dULL;
- const uint64_t c = 2147483647;
-
- const int bias = 127 << 23;
-
- union int_float {
- int32_t i;
- float f;
- };
-
- struct prg {
- private:
- uint64_t v;
- public:
- prg() { v = c; }
- prg(uint64_t initial) { v = initial; }
-
- float merand48(uint64_t& initial)
- {
- initial = a * initial + c;
- int_float temp;
- temp.i = ((initial >> 25) & 0x7FFFFF) | bias;
- return temp.f - 1;
- }
-
- float Uniform_Unit_Interval()
- {
- return merand48(v);
- }
-
- uint32_t Uniform_Int(uint32_t low, uint32_t high)
- {
- merand48(v);
- uint32_t ret = low + ((v >> 25) % (high - low + 1));
- return ret;
- }
- };
-}
-}
+/*******************************************************************/
+
+#pragma once
+#include <stdint.h>
+#include <sys/types.h> /* defines size_t */
+
+#ifdef WIN32
+typedef unsigned __int64 u64;
+typedef unsigned __int32 u32;
+typedef unsigned __int16 u16;
+typedef unsigned __int8 u8;
+typedef signed __int64 i64;
+typedef signed __int32 i32;
+typedef signed __int16 i16;
+typedef signed __int8 i8;
+#else
+typedef uint64_t u64;
+typedef uint32_t u32;
+typedef uint16_t u16;
+typedef uint8_t u8;
+typedef int64_t i64;
+typedef int32_t i32;
+typedef int16_t i16;
+typedef int8_t i8;
+#endif
+
+typedef unsigned char byte;
+
+#include <string>
+#include <stdint.h>
+#include <math.h>
+
+/*!
+* \addtogroup MultiWorldTestingCpp
+* @{
+*/
+
+MWT_NAMESPACE {
+
+//
+// MurmurHash3, by Austin Appleby
+//
+// Originals at:
+// http://code.google.com/p/smhasher/source/browse/trunk/MurmurHash3.cpp
+// http://code.google.com/p/smhasher/source/browse/trunk/MurmurHash3.h
+//
+// Notes:
+// 1) this code assumes we can read a 4-byte value from any address
+// without crashing (i.e non aligned access is supported). This is
+// not a problem on Intel/x86/AMD64 machines (including new Macs)
+// 2) It produces different results on little-endian and big-endian machines.
+//
+//-----------------------------------------------------------------------------
+// MurmurHash3 was written by Austin Appleby, and is placed in the public
+// domain. The author hereby disclaims copyright to this source code.
+
+// Note - The x86 and x64 versions do _not_ produce the same results, as the
+// algorithms are optimized for their respective platforms. You can still
+// compile and run any of them on any platform, but your performance with the
+// non-native version will be less than optimal.
+//-----------------------------------------------------------------------------
+
+// Platform-specific functions and macros
+#if defined(_MSC_VER) // Microsoft Visual Studio
+# include <stdint.h>
+
+# include <stdlib.h>
+# define ROTL32(x,y) _rotl(x,y)
+# define BIG_CONSTANT(x) (x)
+
+#else // Other compilers
+# include <stdint.h> /* defines uint32_t etc */
+
+ // Rotates the 32-bit value x left by r bit positions.
+ // Both shift counts are masked to 0..31 so that r == 0 (or a multiple of 32) does not
+ // produce an undefined 32-bit shift; for r in 1..31 the result is the plain rotation.
+ inline uint32_t rotl32(uint32_t x, int8_t r)
+ {
+     const unsigned left = (unsigned)r & 31u;
+     const unsigned right = (32u - left) & 31u;
+     return (x << left) | (x >> right);
+ }
+
+# define ROTL32(x,y) rotl32(x,y)
+# define BIG_CONSTANT(x) (x##LLU)
+
+#endif // !defined(_MSC_VER)
+
+struct murmur_hash {
+
+ //-----------------------------------------------------------------------------
+ // Block read - if your platform needs to do endian-swapping or can only
+ // handle aligned reads, do the conversion here
+private:
+ static inline uint32_t getblock(const uint32_t * p, int i)
+ {
+ return p[i]; // plain indexed read; i is negative at the call site in uniform_hash
+ }
+
+ //-----------------------------------------------------------------------------
+ // Finalization mix - force all bits of a hash block to avalanche
+
+ static inline uint32_t fmix(uint32_t h)
+ {
+ h ^= h >> 16;
+ h *= 0x85ebca6b;
+ h ^= h >> 13;
+ h *= 0xc2b2ae35;
+ h ^= h >> 16;
+
+ return h;
+ }
+
+ // uniform_hash: returns a 32-bit hash of the len bytes starting at key, mixed with seed.
+public:
+ uint32_t uniform_hash(const void * key, size_t len, uint32_t seed)
+ {
+ const uint8_t * data = (const uint8_t*)key;
+ const int nblocks = (int)len / 4; // number of complete 4-byte blocks; len is narrowed to int first
+
+ uint32_t h1 = seed;
+
+ const uint32_t c1 = 0xcc9e2d51;
+ const uint32_t c2 = 0x1b873593;
+
+ // --- body: mix every complete 4-byte block into h1
+ const uint32_t * blocks = (const uint32_t *)(data + nblocks * 4); // points just past the last complete block
+
+ for (int i = -nblocks; i; i++) { // negative indices walk the blocks from first to last
+ uint32_t k1 = getblock(blocks, i);
+
+ k1 *= c1;
+ k1 = ROTL32(k1, 15);
+ k1 *= c2;
+
+ h1 ^= k1;
+ h1 = ROTL32(h1, 13);
+ h1 = h1 * 5 + 0xe6546b64;
+ }
+
+ // --- tail: fold in the 0-3 bytes left over after the last complete block
+ const uint8_t * tail = (const uint8_t*)(data + nblocks * 4);
+
+ uint32_t k1 = 0;
+
+ switch (len & 3) {
+ case 3: k1 ^= tail[2] << 16; // no break: continues into case 2
+ case 2: k1 ^= tail[1] << 8; // no break: continues into case 1
+ case 1: k1 ^= tail[0];
+ k1 *= c1; k1 = ROTL32(k1, 15); k1 *= c2; h1 ^= k1;
+ }
+
+ // --- finalization: mix in the length, then avalanche
+ h1 ^= len;
+
+ return fmix(h1);
+ }
+};
+
+class HashUtils
+{
+public:
+    ///
+    /// Maps a unique id string to a 64-bit number.
+    ///
+    /// A string made only of the characters '0'..'9' is read as a decimal number (wrapping
+    /// modulo 2^64 on overflow); the empty string yields 0. Any other string is hashed as
+    /// a whole with murmur_hash::uniform_hash and seed 0.
+    ///
+    /// The accumulator is a u64 to match the return type, so long numeric ids are not
+    /// truncated on platforms where size_t is 32 bits wide.
+    ///
+    /// @param unique_id The id to convert.
+    /// @returns The decimal value or the 32-bit hash widened to u64.
+    ///
+    static u64 Compute_Id_Hash(const std::string& unique_id)
+    {
+        u64 ret = 0;
+        for (const char* p = unique_id.c_str(); *p != '\0'; ++p)
+        {
+            if (*p < '0' || *p > '9')
+            {
+                // Not purely numeric: discard the partial number and hash the whole string.
+                murmur_hash foo;
+                return foo.uniform_hash(unique_id.c_str(), unique_id.size(), 0);
+            }
+            ret = 10 * ret + (u64)(*p - '0');
+        }
+        return ret;
+    }
+};
+
+const size_t max_int = 100000; // scale factor print_float applies to the fractional part
+const float max_float = max_int; // magnitudes at or above this are printed with "%g" by print_float
+const float min_float = 0.00001f; // magnitudes at or below this (other than 0) are printed with "%g" by print_float
+const size_t max_digits = (size_t) roundf((float) (-log(min_float) / log(10.))); // -log10(min_float) rounded: the fraction width print_mantissa pads to
+
+class NumberUtils
+{
+public:
+    ///
+    /// Writes f into str as "<integer part>.<five fractional digits>", truncating toward zero.
+    ///
+    /// Values in (-1, 0) have an integer part of 0, which "%d" prints without a sign; the
+    /// minus sign is therefore written explicitly so that, for example, -0.5 is not
+    /// rendered as "0.50000".
+    ///
+    /// @param f The value to format.
+    /// @param str Destination buffer supplied by the caller; must be large enough for the result.
+    ///
+    static void Float_To_String(float f, char* str)
+    {
+        int x = (int)f;
+        int d = (int)(fabs(f - x) * 100000);
+        const char* sign = (f < 0.f && x == 0 && d != 0) ? "-" : "";
+        sprintf(str, "%s%d.%05d", sign, x, d);
+    }
+
+    ///
+    /// Helper for print_float: writes the decimal digits of (size_t)f at begin and advances
+    /// begin past them. No terminator is written.
+    ///
+    /// With trailing_zeros == true every digit is written (integer part; nothing is
+    /// written for 0). With trailing_zeros == false the digits are left-padded with '0' to
+    /// max_digits and low-order zero digits are dropped (fractional part).
+    ///
+    template<bool trailing_zeros>
+    static void print_mantissa(char*& begin, float f)
+    { // helper for print_float
+        char values[10];
+        size_t v = (size_t)f;
+        size_t digit = 0;
+        size_t first_nonzero = 0;
+        // A separate flag is needed: index 0 is a legitimate position for the first
+        // non-zero digit, so first_nonzero == 0 cannot double as "not found yet".
+        // Otherwise a later non-zero digit overwrites it and low-order digits are lost.
+        bool found_nonzero = false;
+        // Extract the digits least-significant first.
+        for (size_t max = 1; max <= v; ++digit)
+        {
+            size_t max_next = max * 10;
+            char v_mod = (char) (v % max_next / max);
+            if (!trailing_zeros && v_mod != '\0' && !found_nonzero)
+            {
+                first_nonzero = digit;
+                found_nonzero = true;
+            }
+            values[digit] = '0' + v_mod;
+            max = max_next;
+        }
+        if (!trailing_zeros)
+            for (size_t i = max_digits; i > digit; i--)
+                *begin++ = '0';
+        // Emit most-significant first, stopping before the dropped low-order zeros.
+        while (digit > first_nonzero)
+            *begin++ = values[--digit];
+    }
+
+    ///
+    /// Writes a NUL-terminated decimal representation of f at begin.
+    ///
+    /// Magnitudes strictly between min_float and max_float are written as
+    /// [-]<integer digits>[.<fraction digits>]; exactly 0 is written as "0"; everything
+    /// else is formatted with "%g".
+    ///
+    /// @param begin Destination buffer supplied by the caller; must be large enough for the result.
+    /// @param f The value to format.
+    ///
+    static void print_float(char* begin, float f)
+    {
+        bool sign = false;
+        if (f < 0.f)
+            sign = true;
+        float unsigned_f = fabsf(f);
+        if (unsigned_f < max_float && unsigned_f > min_float)
+        {
+            if (sign)
+                *begin++ = '-';
+            print_mantissa<true>(begin, unsigned_f);
+            // Keep only the fraction and scale it to an integer number of 1/max_int units.
+            unsigned_f -= (size_t)unsigned_f;
+            unsigned_f *= max_int;
+            if (unsigned_f >= 1.f)
+            {
+                *begin++ = '.';
+                print_mantissa<false>(begin, unsigned_f);
+            }
+        }
+        else if (unsigned_f == 0.)
+            *begin++ = '0';
+        else
+        {
+            // sprintf writes its own terminator.
+            sprintf(begin, "%g", f);
+            return;
+        }
+        *begin = '\0';
+        return;
+    }
+};
+
+//A quick implementation similar to drand48 for cross-platform compatibility
+namespace PRG {
+    // Multiplier and increment of the linear congruential update performed by merand48.
+    const uint64_t a = 0xeece66d5deece66dULL;
+    const uint64_t c = 2147483647;
+
+    // IEEE-754 single-precision exponent field for values in [1, 2).
+    const int bias = 127 << 23;
+
+    // Used to reinterpret a 32-bit pattern as a float.
+    union int_float {
+        int32_t i;
+        float f;
+    };
+
+    struct prg {
+    private:
+        uint64_t v; // generator state
+    public:
+        prg() { v = c; }
+        prg(uint64_t initial) { v = initial; }
+
+        ///
+        /// Advances the state passed in and returns a float in [0, 1).
+        /// 23 bits of the new state become the mantissa of a float in [1, 2), from which
+        /// 1 is subtracted.
+        ///
+        float merand48(uint64_t& initial)
+        {
+            initial = a * initial + c;
+            int_float temp;
+            temp.i = ((initial >> 25) & 0x7FFFFF) | bias;
+            return temp.f - 1;
+        }
+
+        /// Advances the generator and returns a float in [0, 1).
+        float Uniform_Unit_Interval()
+        {
+            return merand48(v);
+        }
+
+        ///
+        /// Advances the generator and returns an integer in [low, high] (both inclusive).
+        ///
+        /// high - low + 1 wraps to 0 when the range spans all 2^32 values, which used to
+        /// make the modulo divide by zero. That case now reduces modulo 2^32 instead.
+        /// Results for every other range are unchanged.
+        ///
+        uint32_t Uniform_Int(uint32_t low, uint32_t high)
+        {
+            merand48(v);
+            const uint32_t width = high - low + 1;
+            const uint64_t bits = v >> 25;
+            const uint64_t offset = (width != 0) ? (bits % width) : (bits & 0xFFFFFFFFULL);
+            return (uint32_t)(low + offset);
+        }
+    };
+}
+}
/*! @} End of Doxygen Groups*/
diff --git a/explore/tests/MWTExploreTests.h b/explore/tests/MWTExploreTests.h
index 447a368f..1c451fb4 100644
--- a/explore/tests/MWTExploreTests.h
+++ b/explore/tests/MWTExploreTests.h
@@ -1,164 +1,164 @@
-#pragma once
-
-#include "MWTExplorer.h"
-#include "utility.h"
-#include <iomanip>
-#include <iostream>
-#include <sstream>
-
-using namespace MultiWorldTesting;
-
-class TestContext
-{
-
-};
-
-template <class Ctx>
-struct TestInteraction
-{
- Ctx& Context;
- u32 Action;
- float Probability;
- string Unique_Key;
-};
-
-class TestPolicy : public IPolicy<TestContext>
-{
-public:
- TestPolicy(int params, int num_actions) : m_params(params), m_num_actions(num_actions) { }
- u32 Choose_Action(TestContext& context)
- {
- return m_params % m_num_actions + 1; // action id is one-based
- }
-private:
- int m_params;
- int m_num_actions;
-};
-
-class TestScorer : public IScorer<TestContext>
-{
-public:
- TestScorer(int params, int num_actions, bool uniform = true) :
- m_params(params), m_num_actions(num_actions), m_uniform(uniform)
- { }
-
- vector<float> Score_Actions(TestContext& context)
- {
- vector<float> scores;
- if (m_uniform)
- {
- for (u32 i = 0; i < m_num_actions; i++)
- {
- scores.push_back(m_params);
- }
- }
- else
- {
- for (u32 i = 0; i < m_num_actions; i++)
- {
- scores.push_back(m_params + i);
- }
- }
- return scores;
- }
-private:
- int m_params;
- int m_num_actions;
- bool m_uniform;
-};
-
-class FixedScorer : public IScorer<TestContext>
-{
-public:
- FixedScorer(int num_actions, int value) :
- m_num_actions(num_actions), m_value(value)
- { }
-
- vector<float> Score_Actions(TestContext& context)
- {
- vector<float> scores;
- for (u32 i = 0; i < m_num_actions; i++)
- {
- scores.push_back((float)m_value);
- }
- return scores;
- }
-private:
- int m_num_actions;
- int m_value;
-};
-
-class TestSimpleScorer : public IScorer<SimpleContext>
-{
-public:
- TestSimpleScorer(int params, int num_actions) : m_params(params), m_num_actions(num_actions) { }
- vector<float> Score_Actions(SimpleContext& context)
- {
- vector<float> scores;
- for (u32 i = 0; i < m_num_actions; i++)
- {
- scores.push_back(m_params);
- }
- return scores;
- }
-private:
- int m_params;
- int m_num_actions;
-};
-
-class TestSimplePolicy : public IPolicy<SimpleContext>
-{
-public:
- TestSimplePolicy(int params, int num_actions) : m_params(params), m_num_actions(num_actions) { }
- u32 Choose_Action(SimpleContext& context)
- {
- return m_params % m_num_actions + 1; // action id is one-based
- }
-private:
- int m_params;
- int m_num_actions;
-};
-
-class TestSimpleRecorder : public IRecorder<SimpleContext>
-{
-public:
- virtual void Record(SimpleContext& context, u32 action, float probability, string unique_key)
- {
- m_interactions.push_back({ context, action, probability, unique_key });
- }
-
- vector<TestInteraction<SimpleContext>> Get_All_Interactions()
- {
- return m_interactions;
- }
-
-private:
- vector<TestInteraction<SimpleContext>> m_interactions;
-};
-
-// Return action outside valid range
-class TestBadPolicy : public IPolicy<TestContext>
-{
-public:
- u32 Choose_Action(TestContext& context)
- {
- return 100;
- }
-};
-
-class TestRecorder : public IRecorder<TestContext>
-{
-public:
- virtual void Record(TestContext& context, u32 action, float probability, string unique_key)
- {
- m_interactions.push_back({ context, action, probability, unique_key });
- }
-
- vector<TestInteraction<TestContext>> Get_All_Interactions()
- {
- return m_interactions;
- }
-
-private:
- vector<TestInteraction<TestContext>> m_interactions;
-};
+#pragma once
+
+#include "MWTExplorer.h"
+#include "utility.h"
+#include <iomanip>
+#include <iostream>
+#include <sstream>
+
+using namespace MultiWorldTesting;
+
+class TestContext // empty context type; used as the Ctx template argument by the test policies, scorers and recorder below
+{
+
+};
+
+template <class Ctx>
+struct TestInteraction // one recorded call, as stored by the test recorders' Record methods
+{
+ Ctx& Context; // reference, not a copy: the referenced context must outlive this record
+ u32 Action;
+ float Probability;
+ string Unique_Key;
+};
+
+/// Deterministic test policy over TestContext: the chosen action depends only on the
+/// two constructor arguments, never on the context.
+class TestPolicy : public IPolicy<TestContext>
+{
+public:
+    TestPolicy(int params, int num_actions)
+        : m_params(params), m_num_actions(num_actions)
+    {
+    }
+
+    /// @returns (params mod num_actions) + 1.
+    u32 Choose_Action(TestContext& context)
+    {
+        const int zero_based = m_params % m_num_actions;
+        return zero_based + 1; // action id is one-based
+    }
+
+private:
+    int m_params;
+    int m_num_actions;
+};
+
+/// Deterministic test scorer over TestContext.
+/// In uniform mode every action scores params; otherwise action i (zero-based) scores params + i.
+class TestScorer : public IScorer<TestContext>
+{
+public:
+    TestScorer(int params, int num_actions, bool uniform = true)
+        : m_params(params), m_num_actions(num_actions), m_uniform(uniform)
+    {
+    }
+
+    /// @returns One score per action; the context is not consulted.
+    vector<float> Score_Actions(TestContext& context)
+    {
+        vector<float> result;
+        for (u32 action = 0; action < m_num_actions; ++action)
+        {
+            if (m_uniform)
+                result.push_back(m_params);
+            else
+                result.push_back(m_params + action);
+        }
+        return result;
+    }
+
+private:
+    int m_params;
+    int m_num_actions;
+    bool m_uniform;
+};
+
+/// Test scorer over TestContext that assigns the same fixed value to every action.
+class FixedScorer : public IScorer<TestContext>
+{
+public:
+    FixedScorer(int num_actions, int value)
+        : m_num_actions(num_actions), m_value(value)
+    {
+    }
+
+    /// @returns num_actions copies of value, converted to float; the context is not consulted.
+    vector<float> Score_Actions(TestContext& context)
+    {
+        const float fixed_score = (float)m_value;
+        vector<float> result;
+        for (u32 action = 0; action < m_num_actions; ++action)
+            result.push_back(fixed_score);
+        return result;
+    }
+
+private:
+    int m_num_actions;
+    int m_value;
+};
+
+/// Test scorer over SimpleContext: every action receives the score params.
+class TestSimpleScorer : public IScorer<SimpleContext>
+{
+public:
+    TestSimpleScorer(int params, int num_actions)
+        : m_params(params), m_num_actions(num_actions)
+    {
+    }
+
+    /// @returns num_actions copies of params; the context is not consulted.
+    vector<float> Score_Actions(SimpleContext& context)
+    {
+        const float score = m_params;
+        vector<float> result;
+        for (u32 action = 0; action < m_num_actions; ++action)
+            result.push_back(score);
+        return result;
+    }
+
+private:
+    int m_params;
+    int m_num_actions;
+};
+
+/// Deterministic test policy over SimpleContext: the chosen action depends only on the
+/// two constructor arguments, never on the context.
+class TestSimplePolicy : public IPolicy<SimpleContext>
+{
+public:
+    TestSimplePolicy(int params, int num_actions)
+        : m_params(params), m_num_actions(num_actions)
+    {
+    }
+
+    /// @returns (params mod num_actions) + 1.
+    u32 Choose_Action(SimpleContext& context)
+    {
+        const int zero_based = m_params % m_num_actions;
+        return zero_based + 1; // action id is one-based
+    }
+
+private:
+    int m_params;
+    int m_num_actions;
+};
+
+/// In-memory recorder for SimpleContext interactions: stores every Record call in order.
+class TestSimpleRecorder : public IRecorder<SimpleContext>
+{
+public:
+    /// Appends one interaction. The context is stored by reference, not copied.
+    virtual void Record(SimpleContext& context, u32 action, float probability, string unique_key)
+    {
+        TestInteraction<SimpleContext> entry = { context, action, probability, unique_key };
+        m_interactions.push_back(entry);
+    }
+
+    /// @returns A copy of everything recorded so far.
+    vector<TestInteraction<SimpleContext>> Get_All_Interactions()
+    {
+        return m_interactions;
+    }
+
+private:
+    vector<TestInteraction<SimpleContext>> m_interactions;
+};
+
+// Return action outside valid range
+class TestBadPolicy : public IPolicy<TestContext>
+{
+public:
+ u32 Choose_Action(TestContext& context)
+ {
+ return 100; // constant action id, independent of the context; presumably above the action count used by the tests
+ }
+};
+
+/// In-memory recorder for TestContext interactions: stores every Record call in order.
+class TestRecorder : public IRecorder<TestContext>
+{
+public:
+    /// Appends one interaction. The context is stored by reference, not copied.
+    virtual void Record(TestContext& context, u32 action, float probability, string unique_key)
+    {
+        TestInteraction<TestContext> entry = { context, action, probability, unique_key };
+        m_interactions.push_back(entry);
+    }
+
+    /// @returns A copy of everything recorded so far.
+    vector<TestInteraction<TestContext>> Get_All_Interactions()
+    {
+        return m_interactions;
+    }
+
+private:
+    vector<TestInteraction<TestContext>> m_interactions;
+};
diff --git a/library/Makefile b/library/Makefile
index 8f7deb35..5ca18f35 100644
--- a/library/Makefile
+++ b/library/Makefile
@@ -40,3 +40,4 @@ gd_mf_weights: gd_mf_weights.cc ../vowpalwabbit/libvw.a
clean:
rm -f *.o ezexample_predict ezexample_train library_example test_search recommend ezexample_predict_threaded
+.PHONY: all clean
diff --git a/library/ezexample_predict.cc b/library/ezexample_predict.cc
index db061f61..22c8a86d 100644
--- a/library/ezexample_predict.cc
+++ b/library/ezexample_predict.cc
@@ -1,55 +1,55 @@
-#include <stdio.h>
-#include "../vowpalwabbit/parser.h"
-#include "../vowpalwabbit/vw.h"
-#include "../vowpalwabbit/ezexample.h"
-
-using namespace std;
-
-int main(int argc, char *argv[])
-{
- string init_string = "-t -q st --hash all --noconstant --ldf_override s -i ";
- if (argc > 1)
- init_string += argv[1];
- else
- init_string += "train.w";
-
- cerr << "initializing with: '" << init_string << "'" << endl;
-
- // INITIALIZE WITH WHATEVER YOU WOULD PUT ON THE VW COMMAND LINE -- THIS READS IN A MODEL FROM train.w
- vw* vw = VW::initialize(init_string); // "-t -q st --hash all --noconstant --ldf_override s -i train.w");
-
- {
- // HAL'S SPIFFY INTERFACE USING C++ CRAZINESS
- ezexample ex(vw, false); // don't need multiline
- ex(vw_namespace('s'))
- ("p^the_man")
- ("w^the")
- ("w^man")
- (vw_namespace('t'))
- ("p^le_homme")
- ("w^le")
- ("w^homme");
- ex.set_label("1");
- cerr << ex.predict_partial() << endl;
-
- // ex.clear_features();
-
- --ex; // remove the most recent namespace
- ex(vw_namespace('t'))
- ("p^un_homme")
- ("w^un")
- ("w^homme");
- ex.set_label("2");
- cerr << ex.predict_partial() << endl;
-
- --ex; // remove the most recent namespace, and add features with explicit ns
- ex('t', "p^un_homme")
- ('t', "w^un")
- ('t', "w^homme");
- ex.set_label("2");
- cerr << ex.predict_partial() << endl;
- }
-
- // AND FINISH UP
- VW::finish(*vw);
-}
+#include <stdio.h>
+#include "../vowpalwabbit/parser.h"
+#include "../vowpalwabbit/vw.h"
+#include "../vowpalwabbit/ezexample.h"
+
+using namespace std;
+
+int main(int argc, char *argv[])
+{
+ string init_string = "-t -q st --hash all --noconstant --ldf_override s -i ";
+ if (argc > 1)
+ init_string += argv[1];
+ else
+ init_string += "train.w";
+
+ cerr << "initializing with: '" << init_string << "'" << endl;
+
+ // INITIALIZE WITH WHATEVER YOU WOULD PUT ON THE VW COMMAND LINE -- THIS READS IN A MODEL FROM train.w
+ vw* vw = VW::initialize(init_string); // "-t -q st --hash all --noconstant --ldf_override s -i train.w");
+
+ {
+ // HAL'S SPIFFY INTERFACE USING C++ CRAZINESS
+ ezexample ex(vw, false); // don't need multiline
+ ex(vw_namespace('s'))
+ ("p^the_man")
+ ("w^the")
+ ("w^man")
+ (vw_namespace('t'))
+ ("p^le_homme")
+ ("w^le")
+ ("w^homme");
+ ex.set_label("1");
+ cerr << ex.predict_partial() << endl;
+
+ // ex.clear_features();
+
+ --ex; // remove the most recent namespace
+ ex(vw_namespace('t'))
+ ("p^un_homme")
+ ("w^un")
+ ("w^homme");
+ ex.set_label("2");
+ cerr << ex.predict_partial() << endl;
+
+ --ex; // remove the most recent namespace, and add features with explicit ns
+ ex('t', "p^un_homme")
+ ('t', "w^un")
+ ('t', "w^homme");
+ ex.set_label("2");
+ cerr << ex.predict_partial() << endl;
+ }
+
+ // AND FINISH UP
+ VW::finish(*vw);
+}
diff --git a/library/ezexample_predict_threaded.cc b/library/ezexample_predict_threaded.cc
index 0fa5b1e6..c7c39e85 100644
--- a/library/ezexample_predict_threaded.cc
+++ b/library/ezexample_predict_threaded.cc
@@ -1,149 +1,149 @@
-#include <stdio.h>
-#include "../vowpalwabbit/vw.h"
-#include "../vowpalwabbit/ezexample.h"
-
-#include <boost/thread/thread.hpp>
-
-using namespace std;
-
-int runcount = 100;
-
-class Worker
-{
-public:
- Worker(vw & instance, string & vw_init_string, vector<double> & ref)
- : m_vw(instance)
- , m_referenceValues(ref)
- , vw_init_string(vw_init_string)
- { }
-
- void operator()()
- {
- m_vw_parser = VW::initialize(vw_init_string);
- if (m_vw_parser == NULL) {
- cerr << "cannot initialize vw parser" << endl;
- exit(-1);
- }
-
- int errorCount = 0;
- for (int i = 0; i < runcount; ++i)
- {
- vector<double>::iterator it = m_referenceValues.begin();
- ezexample ex(&m_vw, false, m_vw_parser);
-
- ex(vw_namespace('s'))
- ("p^the_man")
- ("w^the")
- ("w^man")
- (vw_namespace('t'))
- ("p^le_homme")
- ("w^le")
- ("w^homme");
- ex.set_label("1");
- if (*it != ex()) { cerr << "fail!" << endl; ++errorCount; }
- //if (*it != pred) { cerr << "fail!" << endl; ++errorCount; }
- //VW::finish_example(m_vw, vec2);
- ++it;
-
- --ex; // remove the most recent namespace
- ex(vw_namespace('t'))
- ("p^un_homme")
- ("w^un")
- ("w^homme");
- ex.set_label("1");
- if (*it != ex()) { cerr << "fail!" << endl; ++errorCount; }
- ++it;
-
- --ex; // remove the most recent namespace
- // add features with explicit ns
- ex('t', "p^un_homme")
- ('t', "w^un")
- ('t', "w^homme");
- ex.set_label("1");
- if (*it != ex()) { cerr << "fail!" << endl; ++errorCount; }
- ++it;
-
- //cout << "."; cout.flush();
- }
- cerr << "error count = " << errorCount << endl;
- VW::finish(*m_vw_parser);
- m_vw_parser = NULL;
- }
-
-private:
- vw & m_vw;
- vw * m_vw_parser;
- vector<double> & m_referenceValues;
- string & vw_init_string;
-};
-
-int main(int argc, char *argv[])
-{
- if (argc != 3)
- {
- cerr << "need two args: threadcount runcount" << endl;
- return 1;
- }
- int threadcount = atoi(argv[1]);
- runcount = atoi(argv[2]);
- // INITIALIZE WITH WHATEVER YOU WOULD PUT ON THE VW COMMAND LINE -- THIS READS IN A MODEL FROM train.w
- string vw_init_string_all = "-t --ldf_override s --quiet -q st --noconstant --hash all -i train.w";
- string vw_init_string_parser = "-t --ldf_override s --quiet -q st --noconstant --hash all --noop"; // this needs to have enough arguments to get the parser right
- vw*vw = VW::initialize(vw_init_string_all);
- vector<double> results;
-
- // HAL'S SPIFFY INTERFACE USING C++ CRAZINESS
- {
- ezexample ex(vw, false);
- ex(vw_namespace('s'))
- ("p^the_man")
- ("w^the")
- ("w^man")
- (vw_namespace('t'))
- ("p^le_homme")
- ("w^le")
- ("w^homme");
- ex.set_label("1");
- results.push_back(ex.predict_partial());
- cerr << "should be near zero = " << ex.predict_partial() << endl;
-
- --ex; // remove the most recent namespace
- ex(vw_namespace('t'))
- ("p^un_homme")
- ("w^un")
- ("w^homme");
- ex.set_label("1");
- results.push_back(ex.predict_partial());
- cerr << "should be near one = " << ex.predict_partial() << endl;
-
- --ex; // remove the most recent namespace
- // add features with explicit ns
- ex('t', "p^un_homme")
- ('t', "w^un")
- ('t', "w^homme");
- ex.set_label("1");
- results.push_back(ex.predict_partial());
- cerr << "should be near one = " << ex.predict_partial() << endl;
- }
-
- if (threadcount == 0)
- {
- Worker w(*vw, vw_init_string_parser, results);
- w();
- }
- else
- {
- boost::thread_group tg;
- for (int t = 0; t < threadcount; ++t)
- {
- cerr << "starting thread " << t << endl;
- boost::thread * pt = tg.create_thread(Worker(*vw, vw_init_string_parser, results));
- }
- tg.join_all();
- cerr << "finished!" << endl;
- }
-
-
- // AND FINISH UP
- VW::finish(*vw);
-}
+#include <stdio.h>
+#include "../vowpalwabbit/vw.h"
+#include "../vowpalwabbit/ezexample.h"
+
+#include <boost/thread/thread.hpp>
+
+using namespace std;
+
+int runcount = 100;
+
+class Worker
+{
+public:
+ Worker(vw & instance, string & vw_init_string, vector<double> & ref)
+ : m_vw(instance)
+ , m_referenceValues(ref)
+ , vw_init_string(vw_init_string)
+ { }
+
+ void operator()()
+ {
+ m_vw_parser = VW::initialize(vw_init_string);
+ if (m_vw_parser == NULL) {
+ cerr << "cannot initialize vw parser" << endl;
+ exit(-1);
+ }
+
+ int errorCount = 0;
+ for (int i = 0; i < runcount; ++i)
+ {
+ vector<double>::iterator it = m_referenceValues.begin();
+ ezexample ex(&m_vw, false, m_vw_parser);
+
+ ex(vw_namespace('s'))
+ ("p^the_man")
+ ("w^the")
+ ("w^man")
+ (vw_namespace('t'))
+ ("p^le_homme")
+ ("w^le")
+ ("w^homme");
+ ex.set_label("1");
+ if (*it != ex()) { cerr << "fail!" << endl; ++errorCount; }
+ //if (*it != pred) { cerr << "fail!" << endl; ++errorCount; }
+ //VW::finish_example(m_vw, vec2);
+ ++it;
+
+ --ex; // remove the most recent namespace
+ ex(vw_namespace('t'))
+ ("p^un_homme")
+ ("w^un")
+ ("w^homme");
+ ex.set_label("1");
+ if (*it != ex()) { cerr << "fail!" << endl; ++errorCount; }
+ ++it;
+
+ --ex; // remove the most recent namespace
+ // add features with explicit ns
+ ex('t', "p^un_homme")
+ ('t', "w^un")
+ ('t', "w^homme");
+ ex.set_label("1");
+ if (*it != ex()) { cerr << "fail!" << endl; ++errorCount; }
+ ++it;
+
+ //cout << "."; cout.flush();
+ }
+ cerr << "error count = " << errorCount << endl;
+ VW::finish(*m_vw_parser);
+ m_vw_parser = NULL;
+ }
+
+private:
+ vw & m_vw;
+ vw * m_vw_parser;
+ vector<double> & m_referenceValues;
+ string & vw_init_string;
+};
+
+int main(int argc, char *argv[])
+{
+ if (argc != 3)
+ {
+ cerr << "need two args: threadcount runcount" << endl;
+ return 1;
+ }
+ int threadcount = atoi(argv[1]);
+ runcount = atoi(argv[2]);
+ // INITIALIZE WITH WHATEVER YOU WOULD PUT ON THE VW COMMAND LINE -- THIS READS IN A MODEL FROM train.w
+ string vw_init_string_all = "-t --ldf_override s --quiet -q st --noconstant --hash all -i train.w";
+ string vw_init_string_parser = "-t --ldf_override s --quiet -q st --noconstant --hash all --noop"; // this needs to have enough arguments to get the parser right
+ vw*vw = VW::initialize(vw_init_string_all);
+ vector<double> results;
+
+ // HAL'S SPIFFY INTERFACE USING C++ CRAZINESS
+ {
+ ezexample ex(vw, false);
+ ex(vw_namespace('s'))
+ ("p^the_man")
+ ("w^the")
+ ("w^man")
+ (vw_namespace('t'))
+ ("p^le_homme")
+ ("w^le")
+ ("w^homme");
+ ex.set_label("1");
+ results.push_back(ex.predict_partial());
+ cerr << "should be near zero = " << ex.predict_partial() << endl;
+
+ --ex; // remove the most recent namespace
+ ex(vw_namespace('t'))
+ ("p^un_homme")
+ ("w^un")
+ ("w^homme");
+ ex.set_label("1");
+ results.push_back(ex.predict_partial());
+ cerr << "should be near one = " << ex.predict_partial() << endl;
+
+ --ex; // remove the most recent namespace
+ // add features with explicit ns
+ ex('t', "p^un_homme")
+ ('t', "w^un")
+ ('t', "w^homme");
+ ex.set_label("1");
+ results.push_back(ex.predict_partial());
+ cerr << "should be near one = " << ex.predict_partial() << endl;
+ }
+
+ if (threadcount == 0)
+ {
+ Worker w(*vw, vw_init_string_parser, results);
+ w();
+ }
+ else
+ {
+ boost::thread_group tg;
+ for (int t = 0; t < threadcount; ++t)
+ {
+ cerr << "starting thread " << t << endl;
+ boost::thread * pt = tg.create_thread(Worker(*vw, vw_init_string_parser, results));
+ }
+ tg.join_all();
+ cerr << "finished!" << endl;
+ }
+
+
+ // AND FINISH UP
+ VW::finish(*vw);
+}
diff --git a/library/ezexample_train.cc b/library/ezexample_train.cc
index a0f66a99..9a8af8e0 100644
--- a/library/ezexample_train.cc
+++ b/library/ezexample_train.cc
@@ -1,72 +1,72 @@
-#include <stdio.h>
-#include "../vowpalwabbit/parser.h"
-#include "../vowpalwabbit/vw.h"
-#include "../vowpalwabbit/ezexample.h"
-
-using namespace std;
-
-void run(vw*vw) {
- ezexample ex(vw, true); // we're doing csoaa_ldf so we need multiline examples
-
- /// BEGIN FIRST MULTILINE EXAMPLE
- ex(vw_namespace('s'))
- ("p^the_man")
- ("w^the")
- ("w^man")
- (vw_namespace('t'))
- ("p^un_homme")
- ("w^un")
- ("w^homme");
-
- ex.set_label("1:1");
- ex.train();
-
- --ex; // remove the most recent namespace
- ex(vw_namespace('t'))
- ("p^le_homme")
- ("w^le")
- ("w^homme");
-
- ex.set_label("2:0");
- ex.train();
-
- // push it through VW for training
- ex.finish();
-
- /// BEGIN SECOND MULTILINE EXAMPLE
- ex(vw_namespace('s'))
- ("p^a_man")
- ("w^a")
- ("w^man")
- (vw_namespace('t'))
- ("p^un_homme")
- ("w^un")
- ("w^homme");
-
- ex.set_label("1:0");
- ex.train();
-
- --ex; // remove the most recent namespace
- ex(vw_namespace('t'))
- ("p^le_homme")
- ("w^le")
- ("w^homme");
-
- ex.set_label("2:1");
- ex.train();
-
- // push it through VW for training
- ex.finish();
-}
-
-int main(int argc, char *argv[])
-{
- // INITIALIZE WITH WHATEVER YOU WOULD PUT ON THE VW COMMAND LINE -- THIS WILL STORE A MODEL TO train.ezw
- vw* vw = VW::initialize("--hash all -q st --noconstant -f train.w --quiet --csoaa_ldf m");
-
- run(vw);
-
- // AND FINISH UP
- cerr << "ezexample_train finish"<<endl;
- VW::finish(*vw);
-}
+#include <stdio.h>
+#include "../vowpalwabbit/parser.h"
+#include "../vowpalwabbit/vw.h"
+#include "../vowpalwabbit/ezexample.h"
+
+using namespace std;
+
+void run(vw*vw) {
+ ezexample ex(vw, true); // we're doing csoaa_ldf so we need multiline examples
+
+ /// BEGIN FIRST MULTILINE EXAMPLE
+ ex(vw_namespace('s'))
+ ("p^the_man")
+ ("w^the")
+ ("w^man")
+ (vw_namespace('t'))
+ ("p^un_homme")
+ ("w^un")
+ ("w^homme");
+
+ ex.set_label("1:1");
+ ex.train();
+
+ --ex; // remove the most recent namespace
+ ex(vw_namespace('t'))
+ ("p^le_homme")
+ ("w^le")
+ ("w^homme");
+
+ ex.set_label("2:0");
+ ex.train();
+
+ // push it through VW for training
+ ex.finish();
+
+ /// BEGIN SECOND MULTILINE EXAMPLE
+ ex(vw_namespace('s'))
+ ("p^a_man")
+ ("w^a")
+ ("w^man")
+ (vw_namespace('t'))
+ ("p^un_homme")
+ ("w^un")
+ ("w^homme");
+
+ ex.set_label("1:0");
+ ex.train();
+
+ --ex; // remove the most recent namespace
+ ex(vw_namespace('t'))
+ ("p^le_homme")
+ ("w^le")
+ ("w^homme");
+
+ ex.set_label("2:1");
+ ex.train();
+
+ // push it through VW for training
+ ex.finish();
+}
+
+int main(int argc, char *argv[])
+{
+ // INITIALIZE WITH WHATEVER YOU WOULD PUT ON THE VW COMMAND LINE -- THIS WILL STORE A MODEL TO train.ezw
+ vw* vw = VW::initialize("--hash all -q st --noconstant -f train.w --quiet --csoaa_ldf m");
+
+ run(vw);
+
+ // AND FINISH UP
+ cerr << "ezexample_train finish"<<endl;
+ VW::finish(*vw);
+}
diff --git a/library/library_example.cc b/library/library_example.cc
index 8abfc1f7..d7c186c2 100644
--- a/library/library_example.cc
+++ b/library/library_example.cc
@@ -1,62 +1,62 @@
-#include <stdio.h>
-#include "../vowpalwabbit/parser.h"
-#include "../vowpalwabbit/vw.h"
-
-using namespace std;
-
-
-inline feature vw_feature_from_string(vw& v, string fstr, unsigned long seed, float val)
-{
- uint32_t foo = VW::hash_feature(v, fstr, seed);
- feature f = { val, foo};
- return f;
-}
-
-int main(int argc, char *argv[])
-{
- vw* model = VW::initialize("--hash all -q st --noconstant -i train.w -f train2.vw");
-
- example *vec2 = VW::read_example(*model, (char*)"|s p^the_man w^the w^man |t p^un_homme w^un w^homme");
- model->learn(vec2);
- cerr << "p2 = " << vec2->pred.scalar << endl;
- VW::finish_example(*model, vec2);
-
- vector< VW::feature_space > ec_info;
- vector<feature> s_features, t_features;
- uint32_t s_hash = VW::hash_space(*model, "s");
- uint32_t t_hash = VW::hash_space(*model, "t");
- s_features.push_back( vw_feature_from_string(*model, "p^the_man", s_hash, 1.0) );
- s_features.push_back( vw_feature_from_string(*model, "w^the", s_hash, 1.0) );
- s_features.push_back( vw_feature_from_string(*model, "w^man", s_hash, 1.0) );
- t_features.push_back( vw_feature_from_string(*model, "p^le_homme", t_hash, 1.0) );
- t_features.push_back( vw_feature_from_string(*model, "w^le", t_hash, 1.0) );
- t_features.push_back( vw_feature_from_string(*model, "w^homme", t_hash, 1.0) );
- ec_info.push_back( VW::feature_space('s', s_features) );
- ec_info.push_back( VW::feature_space('t', t_features) );
- example* vec3 = VW::import_example(*model, ec_info);
-
- model->learn(vec3);
- cerr << "p3 = " << vec3->pred.scalar << endl;
- VW::finish_example(*model, vec3);
-
- VW::finish(*model);
-
- vw* model2 = VW::initialize("--hash all -q st --noconstant -i train2.vw");
- vec2 = VW::read_example(*model2, (char*)" |s p^the_man w^the w^man |t p^un_homme w^un w^homme");
- model2->learn(vec2);
- cerr << "p4 = " << vec2->pred.scalar << endl;
-
- size_t len=0;
- VW::primitive_feature_space* pfs = VW::export_example(*model2, vec2, len);
- for (size_t i = 0; i < len; i++)
- {
- cout << "namespace = " << pfs[i].name;
- for (size_t j = 0; j < pfs[i].len; j++)
- cout << " " << pfs[i].fs[j].weight_index << ":" << pfs[i].fs[j].x << ":" << VW::get_weight(*model2, pfs[i].fs[j].weight_index, 0);
- cout << endl;
- }
-
- VW::finish_example(*model2, vec2);
- VW::finish(*model2);
-}
-
+#include <stdio.h>
+#include "../vowpalwabbit/parser.h"
+#include "../vowpalwabbit/vw.h"
+
+using namespace std;
+
+
+inline feature vw_feature_from_string(vw& v, string fstr, unsigned long seed, float val)
+{
+ uint32_t foo = VW::hash_feature(v, fstr, seed);
+ feature f = { val, foo};
+ return f;
+}
+
+int main(int argc, char *argv[])
+{
+ vw* model = VW::initialize("--hash all -q st --noconstant -i train.w -f train2.vw");
+
+ example *vec2 = VW::read_example(*model, (char*)"|s p^the_man w^the w^man |t p^un_homme w^un w^homme");
+ model->learn(vec2);
+ cerr << "p2 = " << vec2->pred.scalar << endl;
+ VW::finish_example(*model, vec2);
+
+ vector< VW::feature_space > ec_info;
+ vector<feature> s_features, t_features;
+ uint32_t s_hash = VW::hash_space(*model, "s");
+ uint32_t t_hash = VW::hash_space(*model, "t");
+ s_features.push_back( vw_feature_from_string(*model, "p^the_man", s_hash, 1.0) );
+ s_features.push_back( vw_feature_from_string(*model, "w^the", s_hash, 1.0) );
+ s_features.push_back( vw_feature_from_string(*model, "w^man", s_hash, 1.0) );
+ t_features.push_back( vw_feature_from_string(*model, "p^le_homme", t_hash, 1.0) );
+ t_features.push_back( vw_feature_from_string(*model, "w^le", t_hash, 1.0) );
+ t_features.push_back( vw_feature_from_string(*model, "w^homme", t_hash, 1.0) );
+ ec_info.push_back( VW::feature_space('s', s_features) );
+ ec_info.push_back( VW::feature_space('t', t_features) );
+ example* vec3 = VW::import_example(*model, ec_info);
+
+ model->learn(vec3);
+ cerr << "p3 = " << vec3->pred.scalar << endl;
+ VW::finish_example(*model, vec3);
+
+ VW::finish(*model);
+
+ vw* model2 = VW::initialize("--hash all -q st --noconstant -i train2.vw");
+ vec2 = VW::read_example(*model2, (char*)" |s p^the_man w^the w^man |t p^un_homme w^un w^homme");
+ model2->learn(vec2);
+ cerr << "p4 = " << vec2->pred.scalar << endl;
+
+ size_t len=0;
+ VW::primitive_feature_space* pfs = VW::export_example(*model2, vec2, len);
+ for (size_t i = 0; i < len; i++)
+ {
+ cout << "namespace = " << pfs[i].name;
+ for (size_t j = 0; j < pfs[i].len; j++)
+ cout << " " << pfs[i].fs[j].weight_index << ":" << pfs[i].fs[j].x << ":" << VW::get_weight(*model2, pfs[i].fs[j].weight_index, 0);
+ cout << endl;
+ }
+
+ VW::finish_example(*model2, vec2);
+ VW::finish(*model2);
+}
+
diff --git a/python/Makefile b/python/Makefile
index b268663d..845fb15f 100644
--- a/python/Makefile
+++ b/python/Makefile
@@ -43,3 +43,5 @@ pylibvw.o: pylibvw.cc
clean:
rm -f *.o $(PYLIBVW)
+
+.PHONY: all clean things
diff --git a/vowpalwabbit/Makefile b/vowpalwabbit/Makefile
index 79893902..9ecd681c 100644
--- a/vowpalwabbit/Makefile
+++ b/vowpalwabbit/Makefile
@@ -52,3 +52,4 @@ install: $(BINARIES)
clean:
rm -f *.o *.d $(BINARIES) *~ $(MANPAGES) libvw.a
+.PHONY: all clean install test things
diff --git a/vowpalwabbit/accumulate.h b/vowpalwabbit/accumulate.h
index c01ac5fe..4d507a60 100644
--- a/vowpalwabbit/accumulate.h
+++ b/vowpalwabbit/accumulate.h
@@ -1,13 +1,13 @@
-/*
-Copyright (c) by respective owners including Yahoo!, Microsoft, and
-individual contributors. All rights reserved. Released under a BSD
-license as described in the file LICENSE.
- */
-//This implements various accumulate functions building on top of allreduce.
-#pragma once
-#include "global_data.h"
-
-void accumulate(vw& all, std::string master_location, regressor& reg, size_t o);
-float accumulate_scalar(vw& all, std::string master_location, float local_sum);
-void accumulate_weighted_avg(vw& all, std::string master_location, regressor& reg);
-void accumulate_avg(vw& all, std::string master_location, regressor& reg, size_t o);
+/*
+Copyright (c) by respective owners including Yahoo!, Microsoft, and
+individual contributors. All rights reserved. Released under a BSD
+license as described in the file LICENSE.
+ */
+//This implements various accumulate functions building on top of allreduce.
+#pragma once
+#include "global_data.h"
+
+void accumulate(vw& all, std::string master_location, regressor& reg, size_t o);
+float accumulate_scalar(vw& all, std::string master_location, float local_sum);
+void accumulate_weighted_avg(vw& all, std::string master_location, regressor& reg);
+void accumulate_avg(vw& all, std::string master_location, regressor& reg, size_t o);