Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/moses-smt/vowpal_wabbit.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSiddhartha Sen <sid.sen1@gmail.com>2014-11-11 08:31:38 +0300
committerSiddhartha Sen <sid.sen1@gmail.com>2014-11-11 08:31:38 +0300
commit1f3ee451c5ff4f23f527105a3d65d4173ff741f8 (patch)
tree70e77c3a071fb15d604c82225c5c053f77e33d5b /explore
parent31fb8d02301cbd069b21bf2ff208edcdc14c773d (diff)
parent401b9dd64820d1bfafa46a8107e055c3bb4c1604 (diff)
Merge branch 'v0' of https://github.com/sidsen/vowpal_wabbit into v0
Diffstat (limited to 'explore')
-rw-r--r--explore/ExploreSample/ExploreSample.csproj3
-rw-r--r--explore/ExploreSample/Program.cs140
-rw-r--r--explore/explore_sample.cpp15
3 files changed, 17 insertions, 141 deletions
diff --git a/explore/ExploreSample/ExploreSample.csproj b/explore/ExploreSample/ExploreSample.csproj
index 0f68feef..9809d62e 100644
--- a/explore/ExploreSample/ExploreSample.csproj
+++ b/explore/ExploreSample/ExploreSample.csproj
@@ -62,6 +62,9 @@
<Reference Include="System.Xml" />
</ItemGroup>
<ItemGroup>
+ <Compile Include="..\..\cs_test\ExploreOnlySample.cs">
+ <Link>ExploreOnlySample.cs</Link>
+ </Compile>
<Compile Include="Program.cs" />
<Compile Include="Properties\AssemblyInfo.cs" />
</ItemGroup>
diff --git a/explore/ExploreSample/Program.cs b/explore/ExploreSample/Program.cs
index 40e303a6..21aacfbd 100644
--- a/explore/ExploreSample/Program.cs
+++ b/explore/ExploreSample/Program.cs
@@ -8,147 +8,9 @@ namespace ExploreSample
{
class Program
{
- class MyContext { }
-
- class MyRecorder : IRecorder<MyContext>
- {
- public void Record(MyContext context, UInt32 action, float probability, string uniqueKey)
- {
- actions.Add(action);
- }
-
- public List<uint> GetData()
- {
- return actions;
- }
-
- private List<uint> actions = new List<uint>();
- }
-
- class MyPolicy : IPolicy<MyContext>
- {
- public MyPolicy() : this(-1) { }
-
- public MyPolicy(int index)
- {
- this.index = index;
- }
-
- public uint ChooseAction(MyContext context)
- {
- return 5;
- }
-
- private int index;
- }
-
- class StringPolicy : IPolicy<SimpleContext>
- {
- public uint ChooseAction(SimpleContext context)
- {
- return 1;
- }
- }
-
- class MyScorer : IScorer<MyContext>
- {
- public MyScorer(uint numActions)
- {
- this.numActions = numActions;
- }
- public List<float> ScoreActions(MyContext context)
- {
- return Enumerable.Repeat<float>(1.0f / numActions, (int)numActions).ToList();
- }
- private uint numActions;
- }
-
public static void Main(string[] args)
{
- string exploration_type = "greedy";
-
- if (exploration_type == "greedy")
- {
- // Initialize Epsilon-Greedy explore algorithm using built-in StringRecorder and SimpleContext types
- StringRecorder<SimpleContext> recorder = new StringRecorder<SimpleContext>();
- MwtExplorer<SimpleContext> mwtt = new MwtExplorer<SimpleContext>("mwt", recorder);
-
- uint numActions = 10;
- float epsilon = 0.2f;
- StringPolicy policy = new StringPolicy();
- SimpleContext context = new SimpleContext(new Feature[] {
- new Feature() { Id = 1, Value = 0.5f },
- new Feature() { Id = 4, Value = 1.3f },
- new Feature() { Id = 9, Value = -0.5f },
- });
- uint action = mwtt.ChooseAction(new EpsilonGreedyExplorer<SimpleContext>(policy, epsilon, numActions), "key", context);
-
- Console.WriteLine(recorder.GetRecording());
-
- return;
- }
- else if (exploration_type == "tau-first")
- {
- // Initialize Tau-First explore algorithm using custom Recorder, Policy & Context types
- MyRecorder recorder = new MyRecorder();
- MwtExplorer<MyContext> mwtt = new MwtExplorer<MyContext>("mwt", recorder);
-
- uint numActions = 10;
- uint tau = 0;
- MyPolicy policy = new MyPolicy();
- uint action = mwtt.ChooseAction(new TauFirstExplorer<MyContext>(policy, tau, numActions), "key", new MyContext());
- Console.WriteLine(String.Join(",", recorder.GetData()));
- return;
- }
- else if (exploration_type == "bagging")
- {
- // Initialize Bagging explore algorithm using custom Recorder, Policy & Context types
- MyRecorder recorder = new MyRecorder();
- MwtExplorer<MyContext> mwtt = new MwtExplorer<MyContext>("mwt", recorder);
-
- uint numActions = 10;
- uint numbags = 2;
- MyPolicy[] policies = new MyPolicy[numbags];
- for (int i = 0; i < numbags; i++)
- {
- policies[i] = new MyPolicy(i * 2);
- }
- uint action = mwtt.ChooseAction(new BaggingExplorer<MyContext>(policies, numbags, numActions), "key", new MyContext());
- Console.WriteLine(String.Join(",", recorder.GetData()));
- return;
- }
- else if (exploration_type == "softmax")
- {
- // Initialize Softmax explore algorithm using custom Recorder, Scorer & Context types
- MyRecorder recorder = new MyRecorder();
- MwtExplorer<MyContext> mwtt = new MwtExplorer<MyContext>("mwt", recorder);
-
- uint numActions = 10;
- float lambda = 0.5f;
- MyScorer scorer = new MyScorer(numActions);
- uint action = mwtt.ChooseAction(new SoftmaxExplorer<MyContext>(scorer, lambda, numActions), "key", new MyContext());
-
- Console.WriteLine(String.Join(",", recorder.GetData()));
- return;
- }
- else if (exploration_type == "generic")
- {
- // Initialize Generic explore algorithm using custom Recorder, Scorer & Context types
- MyRecorder recorder = new MyRecorder();
- MwtExplorer<MyContext> mwtt = new MwtExplorer<MyContext>("mwt", recorder);
-
- uint numActions = 10;
- MyScorer scorer = new MyScorer(numActions);
- uint action = mwtt.ChooseAction(new GenericExplorer<MyContext>(scorer, numActions), "key", new MyContext());
-
- Console.WriteLine(String.Join(",", recorder.GetData()));
- return;
- }
- else
- { //add error here
-
-
- }
+ cs_test.ExploreOnlySample.Run();
}
}
}
diff --git a/explore/explore_sample.cpp b/explore/explore_sample.cpp
index 8418792c..e6e04be2 100644
--- a/explore/explore_sample.cpp
+++ b/explore/explore_sample.cpp
@@ -53,20 +53,31 @@ private:
u32 m_num_actions;
};
+template <class Ctx>
+struct MyInteraction
+{
+ Ctx& Context;
+ u32 Action;
+ float Probability;
+ string Unique_Key;
+};
+
class MyRecorder : public IRecorder<MyContext>
{
public:
virtual void Record(MyContext& context, u32 action, float probability, string unique_key)
{
-
+ m_interactions.push_back({ context, action, probability, unique_key });
}
+private:
+ vector<MyInteraction<MyContext>> m_interactions;
};
int main(int argc, char* argv[])
{
if (argc < 2)
{
- cerr << "arguments: {greedy,tau-first,bagging,softmax,generic} [stateful]" << endl;
+ cerr << "arguments: {greedy,tau-first,bagging,softmax,generic}" << endl;
exit(1);
}