Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/moses-smt/vowpal_wabbit.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author sidsen <sid.sen1@gmail.com> 2014-10-18 00:16:19 +0400
committer sidsen <sid.sen1@gmail.com> 2014-10-18 00:16:19 +0400
commit 82c67032c6cb99a0c35a715b6dcced37c165f8a1 (patch)
tree 1565dbc4e4ff012549d639a27d81d7f9a7467192 /cs_test
parent 37edecccbfedb13cec7c88cfcee36b001fa39c68 (diff)
Fix build break related to salt addition
Diffstat (limited to 'cs_test')
-rw-r--r-- cs_test/ExploreSample.cs 458
-rwxr-xr-x cs_test/LabDemo.cs 420
2 files changed, 439 insertions, 439 deletions
diff --git a/cs_test/ExploreSample.cs b/cs_test/ExploreSample.cs
index dee7d54a..0ae4f243 100644
--- a/cs_test/ExploreSample.cs
+++ b/cs_test/ExploreSample.cs
@@ -1,229 +1,229 @@
-using System;
-using System.Collections.Generic;
-using System.Linq;
-using System.Text;
-using MultiWorldTesting;
-
-namespace cs_test
-{
- class ExploreSample
- {
- private static UInt32 SampleStatefulPolicyFunc(int parameters, CONTEXT context)
- {
- return (uint)((parameters + context.Features.Length) % 10 + 1);
- }
-
- private static UInt32 SampleStatefulPolicyFunc2(int parameters, CONTEXT context)
- {
- return (uint)((parameters + context.Features.Length) % 10 + 2);
- }
-
- private static UInt32 SampleStatefulPolicyFunc(CustomParams parameters, CONTEXT context)
- {
- return (uint)((parameters.Value1 + parameters.Value2 + context.Features.Length) % 10 + 1);
- }
-
- private static UInt32 SampleStatelessPolicyFunc(CONTEXT applicationContext)
- {
- return (UInt32)applicationContext.Features.Length;
- }
-
- private static UInt32 SampleStatelessPolicyFunc2(CONTEXT applicationContext)
- {
- return (UInt32)applicationContext.Features.Length + 1;
- }
-
- private static void SampleStatefulScorerFunc(int parameters, CONTEXT applicationContext, float[] scores)
- {
- for (uint i = 0; i < scores.Length; i++)
- {
- scores[i] = (int)parameters + i;
- }
- }
-
- private static void SampleStatelessScorerFunc(CONTEXT applicationContext, float[] scores)
- {
- for (uint i = 0; i < scores.Length; i++)
- {
- scores[i] = applicationContext.Features.Length + i;
- }
- }
-
- class CustomParams
- {
- public int Value1;
- public int Value2;
- }
-
- public static void Run()
- {
- string interactionFile = "serialized.txt";
- MwtLogger logger = new MwtLogger(interactionFile);
-
- MwtExplorer mwt = new MwtExplorer(logger);
-
- uint numActions = 10;
-
- float epsilon = 0.2f;
- uint tau = 0;
- uint bags = 2;
- float lambda = 0.5f;
-
- int policyParams = 1003;
- CustomParams customParams = new CustomParams() { Value1 = policyParams, Value2 = policyParams + 1 };
-
- /*** Initialize Epsilon-Greedy explore algorithm using a default policy function that accepts parameters ***/
- mwt.InitializeEpsilonGreedy<int>(epsilon, new StatefulPolicyDelegate<int>(SampleStatefulPolicyFunc), policyParams, numActions);
-
- /*** Initialize Epsilon-Greedy explore algorithm using a stateless default policy function ***/
- //mwt.InitializeEpsilonGreedy(epsilon, new StatelessPolicyDelegate(SampleStatelessPolicyFunc), numActions);
-
- /*** Initialize Tau-First explore algorithm using a default policy function that accepts parameters ***/
- //mwt.InitializeTauFirst<CustomParams>(tau, new StatefulPolicyDelegate<CustomParams>(SampleStatefulPolicyFunc), customParams, numActions);
-
- /*** Initialize Tau-First explore algorithm using a stateless default policy function ***/
- //mwt.InitializeTauFirst(tau, new StatelessPolicyDelegate(SampleStatelessPolicyFunc), numActions);
-
- /*** Initialize Bagging explore algorithm using a default policy function that accepts parameters ***/
- //StatefulPolicyDelegate<int>[] funcs =
- //{
- // new StatefulPolicyDelegate<int>(SampleStatefulPolicyFunc),
- // new StatefulPolicyDelegate<int>(SampleStatefulPolicyFunc2)
- //};
- //int[] parameters = { policyParams, policyParams };
- //mwt.InitializeBagging<int>(bags, funcs, parameters, numActions);
-
- /*** Initialize Bagging explore algorithm using a stateless default policy function ***/
- //StatelessPolicyDelegate[] funcs =
- //{
- // new StatelessPolicyDelegate(SampleStatelessPolicyFunc),
- // new StatelessPolicyDelegate(SampleStatelessPolicyFunc2)
- //};
- //mwt.InitializeBagging(bags, funcs, numActions);
-
- /*** Initialize Softmax explore algorithm using a default policy function that accepts parameters ***/
- //mwt.InitializeSoftmax<int>(lambda, new StatefulScorerDelegate<int>(SampleStatefulScorerFunc), policyParams, numActions);
-
- /*** Initialize Softmax explore algorithm using a stateless default policy function ***/
- //mwt.InitializeSoftmax(lambda, new StatelessScorerDelegate(SampleStatelessScorerFunc), numActions);
-
- FEATURE[] f = new FEATURE[2];
- f[0].X = 0.5f;
- f[0].Index = 1;
- f[1].X = 0.9f;
- f[1].Index = 2;
-
- string otherContext = "Some other context data that might be helpful to log";
- CONTEXT context = new CONTEXT(f, otherContext);
-
- UInt32 chosenAction = mwt.ChooseAction(context, "myId");
-
- INTERACTION[] interactions = mwt.GetAllInteractions();
-
- mwt.Unintialize();
-
- MwtRewardReporter mrr = new MwtRewardReporter(interactions);
-
- string joinKey = "myId";
- float reward = 0.5f;
- if (!mrr.ReportReward(joinKey, reward))
- {
- throw new Exception();
- }
-
- MwtOptimizer mot = new MwtOptimizer(interactions, numActions);
-
- float eval1 = mot.EvaluatePolicy(new StatefulPolicyDelegate<int>(SampleStatefulPolicyFunc), policyParams);
-
- mot.OptimizePolicyVWCSOAA("model_file");
- float eval2 = mot.EvaluatePolicyVWCSOAA("model_file");
-
- Console.WriteLine(chosenAction);
- Console.WriteLine(interactions);
-
- logger.Flush();
-
- // Create a new logger to read back interaction data
- logger = new MwtLogger(interactionFile);
- INTERACTION[] inters = logger.GetAllInteractions();
-
- // Load and save reward data to file
- string rewardFile = "rewards.txt";
- RewardStore rewardStore = new RewardStore(rewardFile);
- rewardStore.Add(new float[2] { 1.0f, 0.4f });
- rewardStore.Flush();
-
- // Read back reward data
- rewardStore = new RewardStore(rewardFile);
- float[] rewards = rewardStore.GetAllRewards();
- }
-
- public static void Clock()
- {
- float epsilon = .2f;
- int policyParams = 1003;
- string uniqueKey = "clock";
- int numFeatures = 1000;
- int numIter = 1000;
- int numWarmup = 100;
- int numInteractions = 1;
- uint numActions = 10;
- string otherContext = null;
-
- double timeInit = 0, timeChoose = 0, timeSerializedLog = 0, timeTypedLog = 0;
-
- System.Diagnostics.Stopwatch watch = new System.Diagnostics.Stopwatch();
- for (int iter = 0; iter < numIter + numWarmup; iter++)
- {
- watch.Restart();
-
- MwtExplorer mwt = new MwtExplorer();
- mwt.InitializeEpsilonGreedy<int>(epsilon, new StatefulPolicyDelegate<int>(SampleStatefulPolicyFunc), policyParams, numActions);
-
- timeInit += (iter < numWarmup) ? 0 : watch.Elapsed.TotalMilliseconds;
-
- FEATURE[] f = new FEATURE[numFeatures];
- for (int i = 0; i < numFeatures; i++)
- {
- f[i].Index = (uint)i + 1;
- f[i].X = 0.5f;
- }
-
- watch.Restart();
-
- CONTEXT context = new CONTEXT(f, otherContext);
-
- for (int i = 0; i < numInteractions; i++)
- {
- mwt.ChooseAction(context, uniqueKey);
- }
-
- timeChoose += (iter < numWarmup) ? 0 : watch.Elapsed.TotalMilliseconds;
-
- watch.Restart();
-
- string interactions = mwt.GetAllInteractionsAsString();
-
- timeSerializedLog += (iter < numWarmup) ? 0 : watch.Elapsed.TotalMilliseconds;
-
- for (int i = 0; i < numInteractions; i++)
- {
- mwt.ChooseAction(context, uniqueKey);
- }
-
- watch.Restart();
-
- mwt.GetAllInteractions();
-
- timeTypedLog += (iter < numWarmup) ? 0 : watch.Elapsed.TotalMilliseconds;
- }
- Console.WriteLine("--- PER ITERATION ---");
- Console.WriteLine("# iterations: {0}, # interactions: {1}, # context features {2}", numIter, numInteractions, numFeatures);
- Console.WriteLine("Init: {0} micro", timeInit * 1000 / numIter);
- Console.WriteLine("Choose Action: {0} micro", timeChoose * 1000 / (numIter * numInteractions));
- Console.WriteLine("Get Serialized Log: {0} micro", timeSerializedLog * 1000 / numIter);
- Console.WriteLine("Get Typed Log: {0} micro", timeTypedLog * 1000 / numIter);
- Console.WriteLine("--- TOTAL TIME: {0} micro", (timeInit + timeChoose + timeSerializedLog + timeTypedLog) * 1000);
- }
- }
-}
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using MultiWorldTesting;
+
+namespace cs_test
+{
+ class ExploreSample
+ {
+ private static UInt32 SampleStatefulPolicyFunc(int parameters, CONTEXT context)
+ {
+ return (uint)((parameters + context.Features.Length) % 10 + 1);
+ }
+
+ private static UInt32 SampleStatefulPolicyFunc2(int parameters, CONTEXT context)
+ {
+ return (uint)((parameters + context.Features.Length) % 10 + 2);
+ }
+
+ private static UInt32 SampleStatefulPolicyFunc(CustomParams parameters, CONTEXT context)
+ {
+ return (uint)((parameters.Value1 + parameters.Value2 + context.Features.Length) % 10 + 1);
+ }
+
+ private static UInt32 SampleStatelessPolicyFunc(CONTEXT applicationContext)
+ {
+ return (UInt32)applicationContext.Features.Length;
+ }
+
+ private static UInt32 SampleStatelessPolicyFunc2(CONTEXT applicationContext)
+ {
+ return (UInt32)applicationContext.Features.Length + 1;
+ }
+
+ private static void SampleStatefulScorerFunc(int parameters, CONTEXT applicationContext, float[] scores)
+ {
+ for (uint i = 0; i < scores.Length; i++)
+ {
+ scores[i] = (int)parameters + i;
+ }
+ }
+
+ private static void SampleStatelessScorerFunc(CONTEXT applicationContext, float[] scores)
+ {
+ for (uint i = 0; i < scores.Length; i++)
+ {
+ scores[i] = applicationContext.Features.Length + i;
+ }
+ }
+
+ class CustomParams
+ {
+ public int Value1;
+ public int Value2;
+ }
+
+ public static void Run()
+ {
+ string interactionFile = "serialized.txt";
+ MwtLogger logger = new MwtLogger(interactionFile);
+
+ MwtExplorer mwt = new MwtExplorer("test", logger);
+
+ uint numActions = 10;
+
+ float epsilon = 0.2f;
+ uint tau = 0;
+ uint bags = 2;
+ float lambda = 0.5f;
+
+ int policyParams = 1003;
+ CustomParams customParams = new CustomParams() { Value1 = policyParams, Value2 = policyParams + 1 };
+
+ /*** Initialize Epsilon-Greedy explore algorithm using a default policy function that accepts parameters ***/
+ mwt.InitializeEpsilonGreedy<int>(epsilon, new StatefulPolicyDelegate<int>(SampleStatefulPolicyFunc), policyParams, numActions);
+
+ /*** Initialize Epsilon-Greedy explore algorithm using a stateless default policy function ***/
+ //mwt.InitializeEpsilonGreedy(epsilon, new StatelessPolicyDelegate(SampleStatelessPolicyFunc), numActions);
+
+ /*** Initialize Tau-First explore algorithm using a default policy function that accepts parameters ***/
+ //mwt.InitializeTauFirst<CustomParams>(tau, new StatefulPolicyDelegate<CustomParams>(SampleStatefulPolicyFunc), customParams, numActions);
+
+ /*** Initialize Tau-First explore algorithm using a stateless default policy function ***/
+ //mwt.InitializeTauFirst(tau, new StatelessPolicyDelegate(SampleStatelessPolicyFunc), numActions);
+
+ /*** Initialize Bagging explore algorithm using a default policy function that accepts parameters ***/
+ //StatefulPolicyDelegate<int>[] funcs =
+ //{
+ // new StatefulPolicyDelegate<int>(SampleStatefulPolicyFunc),
+ // new StatefulPolicyDelegate<int>(SampleStatefulPolicyFunc2)
+ //};
+ //int[] parameters = { policyParams, policyParams };
+ //mwt.InitializeBagging<int>(bags, funcs, parameters, numActions);
+
+ /*** Initialize Bagging explore algorithm using a stateless default policy function ***/
+ //StatelessPolicyDelegate[] funcs =
+ //{
+ // new StatelessPolicyDelegate(SampleStatelessPolicyFunc),
+ // new StatelessPolicyDelegate(SampleStatelessPolicyFunc2)
+ //};
+ //mwt.InitializeBagging(bags, funcs, numActions);
+
+ /*** Initialize Softmax explore algorithm using a default policy function that accepts parameters ***/
+ //mwt.InitializeSoftmax<int>(lambda, new StatefulScorerDelegate<int>(SampleStatefulScorerFunc), policyParams, numActions);
+
+ /*** Initialize Softmax explore algorithm using a stateless default policy function ***/
+ //mwt.InitializeSoftmax(lambda, new StatelessScorerDelegate(SampleStatelessScorerFunc), numActions);
+
+ FEATURE[] f = new FEATURE[2];
+ f[0].X = 0.5f;
+ f[0].Index = 1;
+ f[1].X = 0.9f;
+ f[1].Index = 2;
+
+ string otherContext = "Some other context data that might be helpful to log";
+ CONTEXT context = new CONTEXT(f, otherContext);
+
+ UInt32 chosenAction = mwt.ChooseAction(context, "myId");
+
+ INTERACTION[] interactions = mwt.GetAllInteractions();
+
+ mwt.Unintialize();
+
+ MwtRewardReporter mrr = new MwtRewardReporter(interactions);
+
+ string joinKey = "myId";
+ float reward = 0.5f;
+ if (!mrr.ReportReward(joinKey, reward))
+ {
+ throw new Exception();
+ }
+
+ MwtOptimizer mot = new MwtOptimizer(interactions, numActions);
+
+ float eval1 = mot.EvaluatePolicy(new StatefulPolicyDelegate<int>(SampleStatefulPolicyFunc), policyParams);
+
+ mot.OptimizePolicyVWCSOAA("model_file");
+ float eval2 = mot.EvaluatePolicyVWCSOAA("model_file");
+
+ Console.WriteLine(chosenAction);
+ Console.WriteLine(interactions);
+
+ logger.Flush();
+
+ // Create a new logger to read back interaction data
+ logger = new MwtLogger(interactionFile);
+ INTERACTION[] inters = logger.GetAllInteractions();
+
+ // Load and save reward data to file
+ string rewardFile = "rewards.txt";
+ RewardStore rewardStore = new RewardStore(rewardFile);
+ rewardStore.Add(new float[2] { 1.0f, 0.4f });
+ rewardStore.Flush();
+
+ // Read back reward data
+ rewardStore = new RewardStore(rewardFile);
+ float[] rewards = rewardStore.GetAllRewards();
+ }
+
+ public static void Clock()
+ {
+ float epsilon = .2f;
+ int policyParams = 1003;
+ string uniqueKey = "clock";
+ int numFeatures = 1000;
+ int numIter = 1000;
+ int numWarmup = 100;
+ int numInteractions = 1;
+ uint numActions = 10;
+ string otherContext = null;
+
+ double timeInit = 0, timeChoose = 0, timeSerializedLog = 0, timeTypedLog = 0;
+
+ System.Diagnostics.Stopwatch watch = new System.Diagnostics.Stopwatch();
+ for (int iter = 0; iter < numIter + numWarmup; iter++)
+ {
+ watch.Restart();
+
+ MwtExplorer mwt = new MwtExplorer("test");
+ mwt.InitializeEpsilonGreedy<int>(epsilon, new StatefulPolicyDelegate<int>(SampleStatefulPolicyFunc), policyParams, numActions);
+
+ timeInit += (iter < numWarmup) ? 0 : watch.Elapsed.TotalMilliseconds;
+
+ FEATURE[] f = new FEATURE[numFeatures];
+ for (int i = 0; i < numFeatures; i++)
+ {
+ f[i].Index = (uint)i + 1;
+ f[i].X = 0.5f;
+ }
+
+ watch.Restart();
+
+ CONTEXT context = new CONTEXT(f, otherContext);
+
+ for (int i = 0; i < numInteractions; i++)
+ {
+ mwt.ChooseAction(context, uniqueKey);
+ }
+
+ timeChoose += (iter < numWarmup) ? 0 : watch.Elapsed.TotalMilliseconds;
+
+ watch.Restart();
+
+ string interactions = mwt.GetAllInteractionsAsString();
+
+ timeSerializedLog += (iter < numWarmup) ? 0 : watch.Elapsed.TotalMilliseconds;
+
+ for (int i = 0; i < numInteractions; i++)
+ {
+ mwt.ChooseAction(context, uniqueKey);
+ }
+
+ watch.Restart();
+
+ mwt.GetAllInteractions();
+
+ timeTypedLog += (iter < numWarmup) ? 0 : watch.Elapsed.TotalMilliseconds;
+ }
+ Console.WriteLine("--- PER ITERATION ---");
+ Console.WriteLine("# iterations: {0}, # interactions: {1}, # context features {2}", numIter, numInteractions, numFeatures);
+ Console.WriteLine("Init: {0} micro", timeInit * 1000 / numIter);
+ Console.WriteLine("Choose Action: {0} micro", timeChoose * 1000 / (numIter * numInteractions));
+ Console.WriteLine("Get Serialized Log: {0} micro", timeSerializedLog * 1000 / numIter);
+ Console.WriteLine("Get Typed Log: {0} micro", timeTypedLog * 1000 / numIter);
+ Console.WriteLine("--- TOTAL TIME: {0} micro", (timeInit + timeChoose + timeSerializedLog + timeTypedLog) * 1000);
+ }
+ }
+}
diff --git a/cs_test/LabDemo.cs b/cs_test/LabDemo.cs
index 75f0e110..06d80ee6 100755
--- a/cs_test/LabDemo.cs
+++ b/cs_test/LabDemo.cs
@@ -1,210 +1,210 @@
-using System;
-using System.Collections.Generic;
-using System.Linq;
-using System.Text;
-using System.IO;
-using MultiWorldTesting;
-
-
-public class LabDemo
-{
-
- class IOUtils
- {
- public string contextfile, rewardfile;
- public List<CONTEXT> contexts;
- public List<float[]> rewards;
- private const int numActions = 8;
- private static int cur_id = 0;
- public IOUtils(string cfile, string rfile)
- {
- contextfile = cfile;
- rewardfile = rfile;
- contexts = new List<CONTEXT>();
- rewards = new List<float[]>();
- }
-
- public void ParseContexts()
- {
- Console.WriteLine("Parsing contexts");
- CONTEXT c;
- using (StreamReader sr = new StreamReader(contextfile))
- {
- String line;
- int ex_num = 0;
- while ((line = sr.ReadLine()) != null)
- {
- //Console.WriteLine(line);
- char[] delims = { ' ', '\t' };
- List<FEATURE> featureList = new List<FEATURE>();
- string[] features = line.Split(delims);
- foreach (string s in features)
- {
- char[] feat_delim = { ':' };
- string[] words = s.Split(feat_delim);
- //Console.Write("{0} ", words.Length);
- if (words.Length >= 1 && words[0] != "")
- {
- FEATURE f = new FEATURE();
- //Console.WriteLine("{0}", words[0]);
- f.Index = UInt32.Parse(words[0]);
- if (words.Length == 2)
- f.X = float.Parse(words[1]);
- else
- f.X = (float)1.0;
- featureList.Add(f);
- }
- }
- c = new CONTEXT(featureList.ToArray(), null);
- contexts.Add(c);
-
- }
- }
-
- Console.WriteLine("Read {0} contexts", contexts.Count);
- }
-
- public void ParseRewards()
- {
- Console.WriteLine("Parsing rewards");
- using (StreamReader sr = new StreamReader(rewardfile))
- {
- String line;
-
- while ((line = sr.ReadLine()) != null)
- {
- //Console.WriteLine(line);
- char[] delims = { '\t' };
- float[] reward_arr = new float[numActions];
- string[] reward_strings = line.Split(delims);
- int i = 0;
- foreach (string s in reward_strings)
- {
- reward_arr[i++] = float.Parse(s);
- if (i == numActions) break;
- }
- rewards.Add(reward_arr);
- }
- }
- }
-
- public CONTEXT getContext()
- {
- if (contexts.Count == 0)
- ParseContexts();
-
- if (cur_id < contexts.Count)
- return contexts[cur_id++];
- else
- return null;
- }
-
- public float getReward(uint action, uint uid)
- {
- if (rewards.Count == 0)
- ParseRewards();
-
- //Console.WriteLine("Read {0} rewards, uid = {1}, action = {2}", rewards.Count, uid, action);
-
- if (uid >= rewards.Count)
- Console.WriteLine("Found illegal uid {0}", uid);
-
- return rewards[(int)(uid)][action-1];
- }
-
- }
-
- private static UInt32 ScoreBasedPolicy(float threshold, CONTEXT context)
- {
-
- int score_begin = context.Features.Length - 4;
- float base_val = context.Features[score_begin].X;
- uint num_action = 1;
- for(int i = 1;i < 4;i++)
- {
- if (context.Features[score_begin+i].X >= base_val * threshold)
- {
- num_action += (uint)Math.Pow(2, i - 1);
- }
- }
- return num_action;
- }
-
- public static void Run()
- {
- MwtExplorer mwt = new MwtExplorer();
-
- uint numActions = 8;
- float epsilon = 0.1f;
- float policyParams = 0.1f;
-
- mwt.InitializeEpsilonGreedy<float>(epsilon, new StatefulPolicyDelegate<float>(ScoreBasedPolicy), policyParams, numActions);
- IOUtils iou = new IOUtils(@"..\Release\speller-contexts", @"..\Release\speller-rewards");
-
- CONTEXT c;
- uint uniqueID = 1;
- while ((c = iou.getContext()) != null)
- {
- uint action = mwt.ChooseAction(c, uniqueID.ToString());
- //Console.WriteLine("Taking action {0} on id {1}", action,uniqueID-1);
- uniqueID++;
- }
-
- INTERACTION[] interactions = mwt.GetAllInteractions();
-
- MwtRewardReporter rewardReporter = new MwtRewardReporter(interactions);
- for (uint iInter = 0; iInter < interactions.Length; iInter++)
- {
- float r = iou.getReward(interactions[iInter].ChosenAction,iInter);
- //Console.WriteLine("Got reward on interaction {0} with Action {1} as {2}", iInter, interactions[iInter].ChosenAction,r);
- rewardReporter.ReportReward(interactions[iInter].Id, r);
- }
-
- INTERACTION[] full_interactions = rewardReporter.GetAllInteractions();
-
- //for (uint iInter = 0; iInter < full_interactions.Length; iInter++)
- //{
- // Console.WriteLine("Stored reward on interaction {0} with Action {1} as {2}", iInter, full_interactions[iInter].ChosenAction, full_interactions[iInter].Reward);
- // Console.WriteLine("Action of default policy on this context = {0}", ScoreBasedPolicy(policyParams, full_interactions[iInter].ApplicationContext));
- //}
-
- MwtOptimizer mwtopt = new MwtOptimizer(full_interactions, numActions);
- float val = mwtopt.EvaluatePolicy<float>(new StatefulPolicyDelegate<float>(ScoreBasedPolicy), 0.1f);
- if (val == 0)
- Console.WriteLine("ZERO!!");
- Console.WriteLine("Value of default policy = {0}", val);
- Console.WriteLine("Value of default policy and threshold 0.2 = {0}", mwtopt.EvaluatePolicy<float>(new StatefulPolicyDelegate<float>(ScoreBasedPolicy), 0.2f));
- Console.WriteLine("Value of default policy and threshold 0.05 = {0}", mwtopt.EvaluatePolicy<float>(new StatefulPolicyDelegate<float>(ScoreBasedPolicy), 0.05f));
- Console.WriteLine("Value of default policy and threshold 1 = {0}", mwtopt.EvaluatePolicy<float>(new StatefulPolicyDelegate<float>(ScoreBasedPolicy), 1.0f));
-
- Console.WriteLine("Now we will optimize");
- mwtopt.OptimizePolicyVWCSOAA("model");
- Console.WriteLine("Done with optimization, now we will evaluate the optimized model");
- Console.WriteLine("Value of optimized policy using VW = {0}", mwtopt.EvaluatePolicyVWCSOAA("model"));
- Console.ReadKey();
- }
-
- private static CONTEXT GetContext()
- {
- return null;
- }
-
- private static List<float[]> rewardList;
- private float GetReward(uint action, uint uniqueID)
- {
- if (rewardList == null)
- {
- rewardList = new List<float[]>();
- // Read reward from file and populate the list
- }
- //
- return rewardList[(int)uniqueID][action];
- }
-
- private static void DoAction(uint action, uint uniqueID)
- {
- // Performs the action
- }
-
-
-}
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using System.IO;
+using MultiWorldTesting;
+
+
+public class LabDemo
+{
+
+ class IOUtils
+ {
+ public string contextfile, rewardfile;
+ public List<CONTEXT> contexts;
+ public List<float[]> rewards;
+ private const int numActions = 8;
+ private static int cur_id = 0;
+ public IOUtils(string cfile, string rfile)
+ {
+ contextfile = cfile;
+ rewardfile = rfile;
+ contexts = new List<CONTEXT>();
+ rewards = new List<float[]>();
+ }
+
+ public void ParseContexts()
+ {
+ Console.WriteLine("Parsing contexts");
+ CONTEXT c;
+ using (StreamReader sr = new StreamReader(contextfile))
+ {
+ String line;
+ int ex_num = 0;
+ while ((line = sr.ReadLine()) != null)
+ {
+ //Console.WriteLine(line);
+ char[] delims = { ' ', '\t' };
+ List<FEATURE> featureList = new List<FEATURE>();
+ string[] features = line.Split(delims);
+ foreach (string s in features)
+ {
+ char[] feat_delim = { ':' };
+ string[] words = s.Split(feat_delim);
+ //Console.Write("{0} ", words.Length);
+ if (words.Length >= 1 && words[0] != "")
+ {
+ FEATURE f = new FEATURE();
+ //Console.WriteLine("{0}", words[0]);
+ f.Index = UInt32.Parse(words[0]);
+ if (words.Length == 2)
+ f.X = float.Parse(words[1]);
+ else
+ f.X = (float)1.0;
+ featureList.Add(f);
+ }
+ }
+ c = new CONTEXT(featureList.ToArray(), null);
+ contexts.Add(c);
+
+ }
+ }
+
+ Console.WriteLine("Read {0} contexts", contexts.Count);
+ }
+
+ public void ParseRewards()
+ {
+ Console.WriteLine("Parsing rewards");
+ using (StreamReader sr = new StreamReader(rewardfile))
+ {
+ String line;
+
+ while ((line = sr.ReadLine()) != null)
+ {
+ //Console.WriteLine(line);
+ char[] delims = { '\t' };
+ float[] reward_arr = new float[numActions];
+ string[] reward_strings = line.Split(delims);
+ int i = 0;
+ foreach (string s in reward_strings)
+ {
+ reward_arr[i++] = float.Parse(s);
+ if (i == numActions) break;
+ }
+ rewards.Add(reward_arr);
+ }
+ }
+ }
+
+ public CONTEXT getContext()
+ {
+ if (contexts.Count == 0)
+ ParseContexts();
+
+ if (cur_id < contexts.Count)
+ return contexts[cur_id++];
+ else
+ return null;
+ }
+
+ public float getReward(uint action, uint uid)
+ {
+ if (rewards.Count == 0)
+ ParseRewards();
+
+ //Console.WriteLine("Read {0} rewards, uid = {1}, action = {2}", rewards.Count, uid, action);
+
+ if (uid >= rewards.Count)
+ Console.WriteLine("Found illegal uid {0}", uid);
+
+ return rewards[(int)(uid)][action-1];
+ }
+
+ }
+
+ private static UInt32 ScoreBasedPolicy(float threshold, CONTEXT context)
+ {
+
+ int score_begin = context.Features.Length - 4;
+ float base_val = context.Features[score_begin].X;
+ uint num_action = 1;
+ for(int i = 1;i < 4;i++)
+ {
+ if (context.Features[score_begin+i].X >= base_val * threshold)
+ {
+ num_action += (uint)Math.Pow(2, i - 1);
+ }
+ }
+ return num_action;
+ }
+
+ public static void Run()
+ {
+ MwtExplorer mwt = new MwtExplorer("test");
+
+ uint numActions = 8;
+ float epsilon = 0.1f;
+ float policyParams = 0.1f;
+
+ mwt.InitializeEpsilonGreedy<float>(epsilon, new StatefulPolicyDelegate<float>(ScoreBasedPolicy), policyParams, numActions);
+ IOUtils iou = new IOUtils(@"..\Release\speller-contexts", @"..\Release\speller-rewards");
+
+ CONTEXT c;
+ uint uniqueID = 1;
+ while ((c = iou.getContext()) != null)
+ {
+ uint action = mwt.ChooseAction(c, uniqueID.ToString());
+ //Console.WriteLine("Taking action {0} on id {1}", action,uniqueID-1);
+ uniqueID++;
+ }
+
+ INTERACTION[] interactions = mwt.GetAllInteractions();
+
+ MwtRewardReporter rewardReporter = new MwtRewardReporter(interactions);
+ for (uint iInter = 0; iInter < interactions.Length; iInter++)
+ {
+ float r = iou.getReward(interactions[iInter].ChosenAction,iInter);
+ //Console.WriteLine("Got reward on interaction {0} with Action {1} as {2}", iInter, interactions[iInter].ChosenAction,r);
+ rewardReporter.ReportReward(interactions[iInter].Id, r);
+ }
+
+ INTERACTION[] full_interactions = rewardReporter.GetAllInteractions();
+
+ //for (uint iInter = 0; iInter < full_interactions.Length; iInter++)
+ //{
+ // Console.WriteLine("Stored reward on interaction {0} with Action {1} as {2}", iInter, full_interactions[iInter].ChosenAction, full_interactions[iInter].Reward);
+ // Console.WriteLine("Action of default policy on this context = {0}", ScoreBasedPolicy(policyParams, full_interactions[iInter].ApplicationContext));
+ //}
+
+ MwtOptimizer mwtopt = new MwtOptimizer(full_interactions, numActions);
+ float val = mwtopt.EvaluatePolicy<float>(new StatefulPolicyDelegate<float>(ScoreBasedPolicy), 0.1f);
+ if (val == 0)
+ Console.WriteLine("ZERO!!");
+ Console.WriteLine("Value of default policy = {0}", val);
+ Console.WriteLine("Value of default policy and threshold 0.2 = {0}", mwtopt.EvaluatePolicy<float>(new StatefulPolicyDelegate<float>(ScoreBasedPolicy), 0.2f));
+ Console.WriteLine("Value of default policy and threshold 0.05 = {0}", mwtopt.EvaluatePolicy<float>(new StatefulPolicyDelegate<float>(ScoreBasedPolicy), 0.05f));
+ Console.WriteLine("Value of default policy and threshold 1 = {0}", mwtopt.EvaluatePolicy<float>(new StatefulPolicyDelegate<float>(ScoreBasedPolicy), 1.0f));
+
+ Console.WriteLine("Now we will optimize");
+ mwtopt.OptimizePolicyVWCSOAA("model");
+ Console.WriteLine("Done with optimization, now we will evaluate the optimized model");
+ Console.WriteLine("Value of optimized policy using VW = {0}", mwtopt.EvaluatePolicyVWCSOAA("model"));
+ Console.ReadKey();
+ }
+
+ private static CONTEXT GetContext()
+ {
+ return null;
+ }
+
+ private static List<float[]> rewardList;
+ private float GetReward(uint action, uint uniqueID)
+ {
+ if (rewardList == null)
+ {
+ rewardList = new List<float[]>();
+ // Read reward from file and populate the list
+ }
+ //
+ return rewardList[(int)uniqueID][action];
+ }
+
+ private static void DoAction(uint action, uint uniqueID)
+ {
+ // Performs the action
+ }
+
+
+}