diff options
author | sidsen <sid.sen1@gmail.com> | 2014-10-18 00:16:19 +0400 |
---|---|---|
committer | sidsen <sid.sen1@gmail.com> | 2014-10-18 00:16:19 +0400 |
commit | 82c67032c6cb99a0c35a715b6dcced37c165f8a1 (patch) | |
tree | 1565dbc4e4ff012549d639a27d81d7f9a7467192 /cs_test | |
parent | 37edecccbfedb13cec7c88cfcee36b001fa39c68 (diff) |
Fix build break related to salt addition
Diffstat (limited to 'cs_test')
-rw-r--r-- | cs_test/ExploreSample.cs | 458 | ||||
-rwxr-xr-x | cs_test/LabDemo.cs | 420 |
2 files changed, 439 insertions, 439 deletions
diff --git a/cs_test/ExploreSample.cs b/cs_test/ExploreSample.cs index dee7d54a..0ae4f243 100644 --- a/cs_test/ExploreSample.cs +++ b/cs_test/ExploreSample.cs @@ -1,229 +1,229 @@ -using System;
-using System.Collections.Generic;
-using System.Linq;
-using System.Text;
-using MultiWorldTesting;
-
-namespace cs_test
-{
- class ExploreSample
- {
- private static UInt32 SampleStatefulPolicyFunc(int parameters, CONTEXT context)
- {
- return (uint)((parameters + context.Features.Length) % 10 + 1);
- }
-
- private static UInt32 SampleStatefulPolicyFunc2(int parameters, CONTEXT context)
- {
- return (uint)((parameters + context.Features.Length) % 10 + 2);
- }
-
- private static UInt32 SampleStatefulPolicyFunc(CustomParams parameters, CONTEXT context)
- {
- return (uint)((parameters.Value1 + parameters.Value2 + context.Features.Length) % 10 + 1);
- }
-
- private static UInt32 SampleStatelessPolicyFunc(CONTEXT applicationContext)
- {
- return (UInt32)applicationContext.Features.Length;
- }
-
- private static UInt32 SampleStatelessPolicyFunc2(CONTEXT applicationContext)
- {
- return (UInt32)applicationContext.Features.Length + 1;
- }
-
- private static void SampleStatefulScorerFunc(int parameters, CONTEXT applicationContext, float[] scores)
- {
- for (uint i = 0; i < scores.Length; i++)
- {
- scores[i] = (int)parameters + i;
- }
- }
-
- private static void SampleStatelessScorerFunc(CONTEXT applicationContext, float[] scores)
- {
- for (uint i = 0; i < scores.Length; i++)
- {
- scores[i] = applicationContext.Features.Length + i;
- }
- }
-
- class CustomParams
- {
- public int Value1;
- public int Value2;
- }
-
- public static void Run()
- {
- string interactionFile = "serialized.txt";
- MwtLogger logger = new MwtLogger(interactionFile);
-
- MwtExplorer mwt = new MwtExplorer(logger);
-
- uint numActions = 10;
-
- float epsilon = 0.2f;
- uint tau = 0;
- uint bags = 2;
- float lambda = 0.5f;
-
- int policyParams = 1003;
- CustomParams customParams = new CustomParams() { Value1 = policyParams, Value2 = policyParams + 1 };
-
- /*** Initialize Epsilon-Greedy explore algorithm using a default policy function that accepts parameters ***/
- mwt.InitializeEpsilonGreedy<int>(epsilon, new StatefulPolicyDelegate<int>(SampleStatefulPolicyFunc), policyParams, numActions);
-
- /*** Initialize Epsilon-Greedy explore algorithm using a stateless default policy function ***/
- //mwt.InitializeEpsilonGreedy(epsilon, new StatelessPolicyDelegate(SampleStatelessPolicyFunc), numActions);
-
- /*** Initialize Tau-First explore algorithm using a default policy function that accepts parameters ***/
- //mwt.InitializeTauFirst<CustomParams>(tau, new StatefulPolicyDelegate<CustomParams>(SampleStatefulPolicyFunc), customParams, numActions);
-
- /*** Initialize Tau-First explore algorithm using a stateless default policy function ***/
- //mwt.InitializeTauFirst(tau, new StatelessPolicyDelegate(SampleStatelessPolicyFunc), numActions);
-
- /*** Initialize Bagging explore algorithm using a default policy function that accepts parameters ***/
- //StatefulPolicyDelegate<int>[] funcs =
- //{
- // new StatefulPolicyDelegate<int>(SampleStatefulPolicyFunc),
- // new StatefulPolicyDelegate<int>(SampleStatefulPolicyFunc2)
- //};
- //int[] parameters = { policyParams, policyParams };
- //mwt.InitializeBagging<int>(bags, funcs, parameters, numActions);
-
- /*** Initialize Bagging explore algorithm using a stateless default policy function ***/
- //StatelessPolicyDelegate[] funcs =
- //{
- // new StatelessPolicyDelegate(SampleStatelessPolicyFunc),
- // new StatelessPolicyDelegate(SampleStatelessPolicyFunc2)
- //};
- //mwt.InitializeBagging(bags, funcs, numActions);
-
- /*** Initialize Softmax explore algorithm using a default policy function that accepts parameters ***/
- //mwt.InitializeSoftmax<int>(lambda, new StatefulScorerDelegate<int>(SampleStatefulScorerFunc), policyParams, numActions);
-
- /*** Initialize Softmax explore algorithm using a stateless default policy function ***/
- //mwt.InitializeSoftmax(lambda, new StatelessScorerDelegate(SampleStatelessScorerFunc), numActions);
-
- FEATURE[] f = new FEATURE[2];
- f[0].X = 0.5f;
- f[0].Index = 1;
- f[1].X = 0.9f;
- f[1].Index = 2;
-
- string otherContext = "Some other context data that might be helpful to log";
- CONTEXT context = new CONTEXT(f, otherContext);
-
- UInt32 chosenAction = mwt.ChooseAction(context, "myId");
-
- INTERACTION[] interactions = mwt.GetAllInteractions();
-
- mwt.Unintialize();
-
- MwtRewardReporter mrr = new MwtRewardReporter(interactions);
-
- string joinKey = "myId";
- float reward = 0.5f;
- if (!mrr.ReportReward(joinKey, reward))
- {
- throw new Exception();
- }
-
- MwtOptimizer mot = new MwtOptimizer(interactions, numActions);
-
- float eval1 = mot.EvaluatePolicy(new StatefulPolicyDelegate<int>(SampleStatefulPolicyFunc), policyParams);
-
- mot.OptimizePolicyVWCSOAA("model_file");
- float eval2 = mot.EvaluatePolicyVWCSOAA("model_file");
-
- Console.WriteLine(chosenAction);
- Console.WriteLine(interactions);
-
- logger.Flush();
-
- // Create a new logger to read back interaction data
- logger = new MwtLogger(interactionFile);
- INTERACTION[] inters = logger.GetAllInteractions();
-
- // Load and save reward data to file
- string rewardFile = "rewards.txt";
- RewardStore rewardStore = new RewardStore(rewardFile);
- rewardStore.Add(new float[2] { 1.0f, 0.4f });
- rewardStore.Flush();
-
- // Read back reward data
- rewardStore = new RewardStore(rewardFile);
- float[] rewards = rewardStore.GetAllRewards();
- }
-
- public static void Clock()
- {
- float epsilon = .2f;
- int policyParams = 1003;
- string uniqueKey = "clock";
- int numFeatures = 1000;
- int numIter = 1000;
- int numWarmup = 100;
- int numInteractions = 1;
- uint numActions = 10;
- string otherContext = null;
-
- double timeInit = 0, timeChoose = 0, timeSerializedLog = 0, timeTypedLog = 0;
-
- System.Diagnostics.Stopwatch watch = new System.Diagnostics.Stopwatch();
- for (int iter = 0; iter < numIter + numWarmup; iter++)
- {
- watch.Restart();
-
- MwtExplorer mwt = new MwtExplorer();
- mwt.InitializeEpsilonGreedy<int>(epsilon, new StatefulPolicyDelegate<int>(SampleStatefulPolicyFunc), policyParams, numActions);
-
- timeInit += (iter < numWarmup) ? 0 : watch.Elapsed.TotalMilliseconds;
-
- FEATURE[] f = new FEATURE[numFeatures];
- for (int i = 0; i < numFeatures; i++)
- {
- f[i].Index = (uint)i + 1;
- f[i].X = 0.5f;
- }
-
- watch.Restart();
-
- CONTEXT context = new CONTEXT(f, otherContext);
-
- for (int i = 0; i < numInteractions; i++)
- {
- mwt.ChooseAction(context, uniqueKey);
- }
-
- timeChoose += (iter < numWarmup) ? 0 : watch.Elapsed.TotalMilliseconds;
-
- watch.Restart();
-
- string interactions = mwt.GetAllInteractionsAsString();
-
- timeSerializedLog += (iter < numWarmup) ? 0 : watch.Elapsed.TotalMilliseconds;
-
- for (int i = 0; i < numInteractions; i++)
- {
- mwt.ChooseAction(context, uniqueKey);
- }
-
- watch.Restart();
-
- mwt.GetAllInteractions();
-
- timeTypedLog += (iter < numWarmup) ? 0 : watch.Elapsed.TotalMilliseconds;
- }
- Console.WriteLine("--- PER ITERATION ---");
- Console.WriteLine("# iterations: {0}, # interactions: {1}, # context features {2}", numIter, numInteractions, numFeatures);
- Console.WriteLine("Init: {0} micro", timeInit * 1000 / numIter);
- Console.WriteLine("Choose Action: {0} micro", timeChoose * 1000 / (numIter * numInteractions));
- Console.WriteLine("Get Serialized Log: {0} micro", timeSerializedLog * 1000 / numIter);
- Console.WriteLine("Get Typed Log: {0} micro", timeTypedLog * 1000 / numIter);
- Console.WriteLine("--- TOTAL TIME: {0} micro", (timeInit + timeChoose + timeSerializedLog + timeTypedLog) * 1000);
- }
- }
-}
+using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using MultiWorldTesting; + +namespace cs_test +{ + class ExploreSample + { + private static UInt32 SampleStatefulPolicyFunc(int parameters, CONTEXT context) + { + return (uint)((parameters + context.Features.Length) % 10 + 1); + } + + private static UInt32 SampleStatefulPolicyFunc2(int parameters, CONTEXT context) + { + return (uint)((parameters + context.Features.Length) % 10 + 2); + } + + private static UInt32 SampleStatefulPolicyFunc(CustomParams parameters, CONTEXT context) + { + return (uint)((parameters.Value1 + parameters.Value2 + context.Features.Length) % 10 + 1); + } + + private static UInt32 SampleStatelessPolicyFunc(CONTEXT applicationContext) + { + return (UInt32)applicationContext.Features.Length; + } + + private static UInt32 SampleStatelessPolicyFunc2(CONTEXT applicationContext) + { + return (UInt32)applicationContext.Features.Length + 1; + } + + private static void SampleStatefulScorerFunc(int parameters, CONTEXT applicationContext, float[] scores) + { + for (uint i = 0; i < scores.Length; i++) + { + scores[i] = (int)parameters + i; + } + } + + private static void SampleStatelessScorerFunc(CONTEXT applicationContext, float[] scores) + { + for (uint i = 0; i < scores.Length; i++) + { + scores[i] = applicationContext.Features.Length + i; + } + } + + class CustomParams + { + public int Value1; + public int Value2; + } + + public static void Run() + { + string interactionFile = "serialized.txt"; + MwtLogger logger = new MwtLogger(interactionFile); + + MwtExplorer mwt = new MwtExplorer("test", logger); + + uint numActions = 10; + + float epsilon = 0.2f; + uint tau = 0; + uint bags = 2; + float lambda = 0.5f; + + int policyParams = 1003; + CustomParams customParams = new CustomParams() { Value1 = policyParams, Value2 = policyParams + 1 }; + + /*** Initialize Epsilon-Greedy explore algorithm using a default policy function that accepts parameters ***/ + mwt.InitializeEpsilonGreedy<int>(epsilon, new StatefulPolicyDelegate<int>(SampleStatefulPolicyFunc), policyParams, numActions); + + /*** Initialize Epsilon-Greedy explore algorithm using a stateless default policy function ***/ + //mwt.InitializeEpsilonGreedy(epsilon, new StatelessPolicyDelegate(SampleStatelessPolicyFunc), numActions); + + /*** Initialize Tau-First explore algorithm using a default policy function that accepts parameters ***/ + //mwt.InitializeTauFirst<CustomParams>(tau, new StatefulPolicyDelegate<CustomParams>(SampleStatefulPolicyFunc), customParams, numActions); + + /*** Initialize Tau-First explore algorithm using a stateless default policy function ***/ + //mwt.InitializeTauFirst(tau, new StatelessPolicyDelegate(SampleStatelessPolicyFunc), numActions); + + /*** Initialize Bagging explore algorithm using a default policy function that accepts parameters ***/ + //StatefulPolicyDelegate<int>[] funcs = + //{ + // new StatefulPolicyDelegate<int>(SampleStatefulPolicyFunc), + // new StatefulPolicyDelegate<int>(SampleStatefulPolicyFunc2) + //}; + //int[] parameters = { policyParams, policyParams }; + //mwt.InitializeBagging<int>(bags, funcs, parameters, numActions); + + /*** Initialize Bagging explore algorithm using a stateless default policy function ***/ + //StatelessPolicyDelegate[] funcs = + //{ + // new StatelessPolicyDelegate(SampleStatelessPolicyFunc), + // new StatelessPolicyDelegate(SampleStatelessPolicyFunc2) + //}; + //mwt.InitializeBagging(bags, funcs, numActions); + + /*** Initialize Softmax explore algorithm using a default policy function that accepts parameters ***/ + //mwt.InitializeSoftmax<int>(lambda, new StatefulScorerDelegate<int>(SampleStatefulScorerFunc), policyParams, numActions); + + /*** Initialize Softmax explore algorithm using a stateless default policy function ***/ + //mwt.InitializeSoftmax(lambda, new StatelessScorerDelegate(SampleStatelessScorerFunc), numActions); + + FEATURE[] f = new FEATURE[2]; + f[0].X = 0.5f; + f[0].Index = 1; + f[1].X = 0.9f; + f[1].Index = 2; + + string otherContext = "Some other context data that might be helpful to log"; + CONTEXT context = new CONTEXT(f, otherContext); + + UInt32 chosenAction = mwt.ChooseAction(context, "myId"); + + INTERACTION[] interactions = mwt.GetAllInteractions(); + + mwt.Unintialize(); + + MwtRewardReporter mrr = new MwtRewardReporter(interactions); + + string joinKey = "myId"; + float reward = 0.5f; + if (!mrr.ReportReward(joinKey, reward)) + { + throw new Exception(); + } + + MwtOptimizer mot = new MwtOptimizer(interactions, numActions); + + float eval1 = mot.EvaluatePolicy(new StatefulPolicyDelegate<int>(SampleStatefulPolicyFunc), policyParams); + + mot.OptimizePolicyVWCSOAA("model_file"); + float eval2 = mot.EvaluatePolicyVWCSOAA("model_file"); + + Console.WriteLine(chosenAction); + Console.WriteLine(interactions); + + logger.Flush(); + + // Create a new logger to read back interaction data + logger = new MwtLogger(interactionFile); + INTERACTION[] inters = logger.GetAllInteractions(); + + // Load and save reward data to file + string rewardFile = "rewards.txt"; + RewardStore rewardStore = new RewardStore(rewardFile); + rewardStore.Add(new float[2] { 1.0f, 0.4f }); + rewardStore.Flush(); + + // Read back reward data + rewardStore = new RewardStore(rewardFile); + float[] rewards = rewardStore.GetAllRewards(); + } + + public static void Clock() + { + float epsilon = .2f; + int policyParams = 1003; + string uniqueKey = "clock"; + int numFeatures = 1000; + int numIter = 1000; + int numWarmup = 100; + int numInteractions = 1; + uint numActions = 10; + string otherContext = null; + + double timeInit = 0, timeChoose = 0, timeSerializedLog = 0, timeTypedLog = 0; + + System.Diagnostics.Stopwatch watch = new System.Diagnostics.Stopwatch(); + for (int iter = 0; iter < numIter + numWarmup; iter++) + { + watch.Restart(); + + MwtExplorer mwt = new MwtExplorer("test"); + mwt.InitializeEpsilonGreedy<int>(epsilon, new StatefulPolicyDelegate<int>(SampleStatefulPolicyFunc), policyParams, numActions); + + timeInit += (iter < numWarmup) ? 0 : watch.Elapsed.TotalMilliseconds; + + FEATURE[] f = new FEATURE[numFeatures]; + for (int i = 0; i < numFeatures; i++) + { + f[i].Index = (uint)i + 1; + f[i].X = 0.5f; + } + + watch.Restart(); + + CONTEXT context = new CONTEXT(f, otherContext); + + for (int i = 0; i < numInteractions; i++) + { + mwt.ChooseAction(context, uniqueKey); + } + + timeChoose += (iter < numWarmup) ? 0 : watch.Elapsed.TotalMilliseconds; + + watch.Restart(); + + string interactions = mwt.GetAllInteractionsAsString(); + + timeSerializedLog += (iter < numWarmup) ? 0 : watch.Elapsed.TotalMilliseconds; + + for (int i = 0; i < numInteractions; i++) + { + mwt.ChooseAction(context, uniqueKey); + } + + watch.Restart(); + + mwt.GetAllInteractions(); + + timeTypedLog += (iter < numWarmup) ? 0 : watch.Elapsed.TotalMilliseconds; + } + Console.WriteLine("--- PER ITERATION ---"); + Console.WriteLine("# iterations: {0}, # interactions: {1}, # context features {2}", numIter, numInteractions, numFeatures); + Console.WriteLine("Init: {0} micro", timeInit * 1000 / numIter); + Console.WriteLine("Choose Action: {0} micro", timeChoose * 1000 / (numIter * numInteractions)); + Console.WriteLine("Get Serialized Log: {0} micro", timeSerializedLog * 1000 / numIter); + Console.WriteLine("Get Typed Log: {0} micro", timeTypedLog * 1000 / numIter); + Console.WriteLine("--- TOTAL TIME: {0} micro", (timeInit + timeChoose + timeSerializedLog + timeTypedLog) * 1000); + } + } +} diff --git a/cs_test/LabDemo.cs b/cs_test/LabDemo.cs index 75f0e110..06d80ee6 100755 --- a/cs_test/LabDemo.cs +++ b/cs_test/LabDemo.cs @@ -1,210 +1,210 @@ -using System;
-using System.Collections.Generic;
-using System.Linq;
-using System.Text;
-using System.IO;
-using MultiWorldTesting;
-
-
-public class LabDemo
-{
-
- class IOUtils
- {
- public string contextfile, rewardfile;
- public List<CONTEXT> contexts;
- public List<float[]> rewards;
- private const int numActions = 8;
- private static int cur_id = 0;
- public IOUtils(string cfile, string rfile)
- {
- contextfile = cfile;
- rewardfile = rfile;
- contexts = new List<CONTEXT>();
- rewards = new List<float[]>();
- }
-
- public void ParseContexts()
- {
- Console.WriteLine("Parsing contexts");
- CONTEXT c;
- using (StreamReader sr = new StreamReader(contextfile))
- {
- String line;
- int ex_num = 0;
- while ((line = sr.ReadLine()) != null)
- {
- //Console.WriteLine(line);
- char[] delims = { ' ', '\t' };
- List<FEATURE> featureList = new List<FEATURE>();
- string[] features = line.Split(delims);
- foreach (string s in features)
- {
- char[] feat_delim = { ':' };
- string[] words = s.Split(feat_delim);
- //Console.Write("{0} ", words.Length);
- if (words.Length >= 1 && words[0] != "")
- {
- FEATURE f = new FEATURE();
- //Console.WriteLine("{0}", words[0]);
- f.Index = UInt32.Parse(words[0]);
- if (words.Length == 2)
- f.X = float.Parse(words[1]);
- else
- f.X = (float)1.0;
- featureList.Add(f);
- }
- }
- c = new CONTEXT(featureList.ToArray(), null);
- contexts.Add(c);
-
- }
- }
-
- Console.WriteLine("Read {0} contexts", contexts.Count);
- }
-
- public void ParseRewards()
- {
- Console.WriteLine("Parsing rewards");
- using (StreamReader sr = new StreamReader(rewardfile))
- {
- String line;
-
- while ((line = sr.ReadLine()) != null)
- {
- //Console.WriteLine(line);
- char[] delims = { '\t' };
- float[] reward_arr = new float[numActions];
- string[] reward_strings = line.Split(delims);
- int i = 0;
- foreach (string s in reward_strings)
- {
- reward_arr[i++] = float.Parse(s);
- if (i == numActions) break;
- }
- rewards.Add(reward_arr);
- }
- }
- }
-
- public CONTEXT getContext()
- {
- if (contexts.Count == 0)
- ParseContexts();
-
- if (cur_id < contexts.Count)
- return contexts[cur_id++];
- else
- return null;
- }
-
- public float getReward(uint action, uint uid)
- {
- if (rewards.Count == 0)
- ParseRewards();
-
- //Console.WriteLine("Read {0} rewards, uid = {1}, action = {2}", rewards.Count, uid, action);
-
- if (uid >= rewards.Count)
- Console.WriteLine("Found illegal uid {0}", uid);
-
- return rewards[(int)(uid)][action-1];
- }
-
- }
-
- private static UInt32 ScoreBasedPolicy(float threshold, CONTEXT context)
- {
-
- int score_begin = context.Features.Length - 4;
- float base_val = context.Features[score_begin].X;
- uint num_action = 1;
- for(int i = 1;i < 4;i++)
- {
- if (context.Features[score_begin+i].X >= base_val * threshold)
- {
- num_action += (uint)Math.Pow(2, i - 1);
- }
- }
- return num_action;
- }
-
- public static void Run()
- {
- MwtExplorer mwt = new MwtExplorer();
-
- uint numActions = 8;
- float epsilon = 0.1f;
- float policyParams = 0.1f;
-
- mwt.InitializeEpsilonGreedy<float>(epsilon, new StatefulPolicyDelegate<float>(ScoreBasedPolicy), policyParams, numActions);
- IOUtils iou = new IOUtils(@"..\Release\speller-contexts", @"..\Release\speller-rewards");
-
- CONTEXT c;
- uint uniqueID = 1;
- while ((c = iou.getContext()) != null)
- {
- uint action = mwt.ChooseAction(c, uniqueID.ToString());
- //Console.WriteLine("Taking action {0} on id {1}", action,uniqueID-1);
- uniqueID++;
- }
-
- INTERACTION[] interactions = mwt.GetAllInteractions();
-
- MwtRewardReporter rewardReporter = new MwtRewardReporter(interactions);
- for (uint iInter = 0; iInter < interactions.Length; iInter++)
- {
- float r = iou.getReward(interactions[iInter].ChosenAction,iInter);
- //Console.WriteLine("Got reward on interaction {0} with Action {1} as {2}", iInter, interactions[iInter].ChosenAction,r);
- rewardReporter.ReportReward(interactions[iInter].Id, r);
- }
-
- INTERACTION[] full_interactions = rewardReporter.GetAllInteractions();
-
- //for (uint iInter = 0; iInter < full_interactions.Length; iInter++)
- //{
- // Console.WriteLine("Stored reward on interaction {0} with Action {1} as {2}", iInter, full_interactions[iInter].ChosenAction, full_interactions[iInter].Reward);
- // Console.WriteLine("Action of default policy on this context = {0}", ScoreBasedPolicy(policyParams, full_interactions[iInter].ApplicationContext));
- //}
-
- MwtOptimizer mwtopt = new MwtOptimizer(full_interactions, numActions);
- float val = mwtopt.EvaluatePolicy<float>(new StatefulPolicyDelegate<float>(ScoreBasedPolicy), 0.1f);
- if (val == 0)
- Console.WriteLine("ZERO!!");
- Console.WriteLine("Value of default policy = {0}", val);
- Console.WriteLine("Value of default policy and threshold 0.2 = {0}", mwtopt.EvaluatePolicy<float>(new StatefulPolicyDelegate<float>(ScoreBasedPolicy), 0.2f));
- Console.WriteLine("Value of default policy and threshold 0.05 = {0}", mwtopt.EvaluatePolicy<float>(new StatefulPolicyDelegate<float>(ScoreBasedPolicy), 0.05f));
- Console.WriteLine("Value of default policy and threshold 1 = {0}", mwtopt.EvaluatePolicy<float>(new StatefulPolicyDelegate<float>(ScoreBasedPolicy), 1.0f));
-
- Console.WriteLine("Now we will optimize");
- mwtopt.OptimizePolicyVWCSOAA("model");
- Console.WriteLine("Done with optimization, now we will evaluate the optimized model");
- Console.WriteLine("Value of optimized policy using VW = {0}", mwtopt.EvaluatePolicyVWCSOAA("model"));
- Console.ReadKey();
- }
-
- private static CONTEXT GetContext()
- {
- return null;
- }
-
- private static List<float[]> rewardList;
- private float GetReward(uint action, uint uniqueID)
- {
- if (rewardList == null)
- {
- rewardList = new List<float[]>();
- // Read reward from file and populate the list
- }
- //
- return rewardList[(int)uniqueID][action];
- }
-
- private static void DoAction(uint action, uint uniqueID)
- {
- // Performs the action
- }
-
-
-}
+using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.IO; +using MultiWorldTesting; + + +public class LabDemo +{ + + class IOUtils + { + public string contextfile, rewardfile; + public List<CONTEXT> contexts; + public List<float[]> rewards; + private const int numActions = 8; + private static int cur_id = 0; + public IOUtils(string cfile, string rfile) + { + contextfile = cfile; + rewardfile = rfile; + contexts = new List<CONTEXT>(); + rewards = new List<float[]>(); + } + + public void ParseContexts() + { + Console.WriteLine("Parsing contexts"); + CONTEXT c; + using (StreamReader sr = new StreamReader(contextfile)) + { + String line; + int ex_num = 0; + while ((line = sr.ReadLine()) != null) + { + //Console.WriteLine(line); + char[] delims = { ' ', '\t' }; + List<FEATURE> featureList = new List<FEATURE>(); + string[] features = line.Split(delims); + foreach (string s in features) + { + char[] feat_delim = { ':' }; + string[] words = s.Split(feat_delim); + //Console.Write("{0} ", words.Length); + if (words.Length >= 1 && words[0] != "") + { + FEATURE f = new FEATURE(); + //Console.WriteLine("{0}", words[0]); + f.Index = UInt32.Parse(words[0]); + if (words.Length == 2) + f.X = float.Parse(words[1]); + else + f.X = (float)1.0; + featureList.Add(f); + } + } + c = new CONTEXT(featureList.ToArray(), null); + contexts.Add(c); + + } + } + + Console.WriteLine("Read {0} contexts", contexts.Count); + } + + public void ParseRewards() + { + Console.WriteLine("Parsing rewards"); + using (StreamReader sr = new StreamReader(rewardfile)) + { + String line; + + while ((line = sr.ReadLine()) != null) + { + //Console.WriteLine(line); + char[] delims = { '\t' }; + float[] reward_arr = new float[numActions]; + string[] reward_strings = line.Split(delims); + int i = 0; + foreach (string s in reward_strings) + { + reward_arr[i++] = float.Parse(s); + if (i == numActions) break; + } + rewards.Add(reward_arr); + } + } + } + + public CONTEXT getContext() + { + if (contexts.Count == 0) + ParseContexts(); + + if (cur_id < contexts.Count) + return contexts[cur_id++]; + else + return null; + } + + public float getReward(uint action, uint uid) + { + if (rewards.Count == 0) + ParseRewards(); + + //Console.WriteLine("Read {0} rewards, uid = {1}, action = {2}", rewards.Count, uid, action); + + if (uid >= rewards.Count) + Console.WriteLine("Found illegal uid {0}", uid); + + return rewards[(int)(uid)][action-1]; + } + + } + + private static UInt32 ScoreBasedPolicy(float threshold, CONTEXT context) + { + + int score_begin = context.Features.Length - 4; + float base_val = context.Features[score_begin].X; + uint num_action = 1; + for(int i = 1;i < 4;i++) + { + if (context.Features[score_begin+i].X >= base_val * threshold) + { + num_action += (uint)Math.Pow(2, i - 1); + } + } + return num_action; + } + + public static void Run() + { + MwtExplorer mwt = new MwtExplorer("test"); + + uint numActions = 8; + float epsilon = 0.1f; + float policyParams = 0.1f; + + mwt.InitializeEpsilonGreedy<float>(epsilon, new StatefulPolicyDelegate<float>(ScoreBasedPolicy), policyParams, numActions); + IOUtils iou = new IOUtils(@"..\Release\speller-contexts", @"..\Release\speller-rewards"); + + CONTEXT c; + uint uniqueID = 1; + while ((c = iou.getContext()) != null) + { + uint action = mwt.ChooseAction(c, uniqueID.ToString()); + //Console.WriteLine("Taking action {0} on id {1}", action,uniqueID-1); + uniqueID++; + } + + INTERACTION[] interactions = mwt.GetAllInteractions(); + + MwtRewardReporter rewardReporter = new MwtRewardReporter(interactions); + for (uint iInter = 0; iInter < interactions.Length; iInter++) + { + float r = iou.getReward(interactions[iInter].ChosenAction,iInter); + //Console.WriteLine("Got reward on interaction {0} with Action {1} as {2}", iInter, interactions[iInter].ChosenAction,r); + rewardReporter.ReportReward(interactions[iInter].Id, r); + } + + INTERACTION[] full_interactions = rewardReporter.GetAllInteractions(); + + //for (uint iInter = 0; iInter < full_interactions.Length; iInter++) + //{ + // Console.WriteLine("Stored reward on interaction {0} with Action {1} as {2}", iInter, full_interactions[iInter].ChosenAction, full_interactions[iInter].Reward); + // Console.WriteLine("Action of default policy on this context = {0}", ScoreBasedPolicy(policyParams, full_interactions[iInter].ApplicationContext)); + //} + + MwtOptimizer mwtopt = new MwtOptimizer(full_interactions, numActions); + float val = mwtopt.EvaluatePolicy<float>(new StatefulPolicyDelegate<float>(ScoreBasedPolicy), 0.1f); + if (val == 0) + Console.WriteLine("ZERO!!"); + Console.WriteLine("Value of default policy = {0}", val); + Console.WriteLine("Value of default policy and threshold 0.2 = {0}", mwtopt.EvaluatePolicy<float>(new StatefulPolicyDelegate<float>(ScoreBasedPolicy), 0.2f)); + Console.WriteLine("Value of default policy and threshold 0.05 = {0}", mwtopt.EvaluatePolicy<float>(new StatefulPolicyDelegate<float>(ScoreBasedPolicy), 0.05f)); + Console.WriteLine("Value of default policy and threshold 1 = {0}", mwtopt.EvaluatePolicy<float>(new StatefulPolicyDelegate<float>(ScoreBasedPolicy), 1.0f)); + + Console.WriteLine("Now we will optimize"); + mwtopt.OptimizePolicyVWCSOAA("model"); + Console.WriteLine("Done with optimization, now we will evaluate the optimized model"); + Console.WriteLine("Value of optimized policy using VW = {0}", mwtopt.EvaluatePolicyVWCSOAA("model")); + Console.ReadKey(); + } + + private static CONTEXT GetContext() + { + return null; + } + + private static List<float[]> rewardList; + private float GetReward(uint action, uint uniqueID) + { + if (rewardList == null) + { + rewardList = new List<float[]>(); + // Read reward from file and populate the list + } + // + return rewardList[(int)uniqueID][action]; + } + + private static void DoAction(uint action, uint uniqueID) + { + // Performs the action + } + + +} |