diff options
author | Sam Steingold <sds@gnu.org> | 2014-12-23 20:09:43 +0300 |
---|---|---|
committer | Sam Steingold <sds@gnu.org> | 2014-12-23 20:09:43 +0300 |
commit | ed8c4d38aba3b49159f1b2574028b5cbae96a7f2 (patch) | |
tree | 4ecd16430fb82c8d1d67fcccfd887ce79c02eac6 | |
parent | 1452485ae7248f5a12e3ac909aa8a9dedaf26241 (diff) |
convert to unix line endings, like all the other sources
-rw-r--r-- | explore/clr/explore_clr_wrapper.cpp | 36 | ||||
-rw-r--r-- | explore/clr/explore_clr_wrapper.h | 874 | ||||
-rw-r--r-- | explore/clr/explore_interface.h | 174 | ||||
-rw-r--r-- | explore/clr/explore_interop.h | 714 | ||||
-rw-r--r-- | explore/explore.cpp | 146 | ||||
-rw-r--r-- | explore/static/MWTExplorer.h | 1338 | ||||
-rw-r--r-- | explore/static/utility.h | 566 | ||||
-rw-r--r-- | explore/tests/MWTExploreTests.h | 328 | ||||
-rw-r--r-- | library/ezexample_predict.cc | 110 | ||||
-rw-r--r-- | library/ezexample_predict_threaded.cc | 298 | ||||
-rw-r--r-- | library/ezexample_train.cc | 144 | ||||
-rw-r--r-- | library/gd_mf_weights.cc | 242 | ||||
-rw-r--r-- | library/library_example.cc | 124 | ||||
-rw-r--r-- | vowpalwabbit/accumulate.cc | 230 | ||||
-rw-r--r-- | vowpalwabbit/accumulate.h | 26 | ||||
-rw-r--r-- | vowpalwabbit/lda_core.cc | 1604 | ||||
-rw-r--r-- | vowpalwabbit/log_multi.cc | 1098 |
17 files changed, 4026 insertions, 4026 deletions
diff --git a/explore/clr/explore_clr_wrapper.cpp b/explore/clr/explore_clr_wrapper.cpp index 0564482b..be022af1 100644 --- a/explore/clr/explore_clr_wrapper.cpp +++ b/explore/clr/explore_clr_wrapper.cpp @@ -1,18 +1,18 @@ -// vw_explore_clr_wrapper.cpp : Defines the exported functions for the DLL application.
-//
-
-#define WIN32_LEAN_AND_MEAN
-#include <Windows.h>
-
-#include "explore_clr_wrapper.h"
-
-using namespace System;
-using namespace System::Collections;
-using namespace System::Collections::Generic;
-using namespace System::Runtime::InteropServices;
-using namespace msclr::interop;
-using namespace NativeMultiWorldTesting;
-
-namespace MultiWorldTesting {
-
-}
+// vw_explore_clr_wrapper.cpp : Defines the exported functions for the DLL application. +// + +#define WIN32_LEAN_AND_MEAN +#include <Windows.h> + +#include "explore_clr_wrapper.h" + +using namespace System; +using namespace System::Collections; +using namespace System::Collections::Generic; +using namespace System::Runtime::InteropServices; +using namespace msclr::interop; +using namespace NativeMultiWorldTesting; + +namespace MultiWorldTesting { + +} diff --git a/explore/clr/explore_clr_wrapper.h b/explore/clr/explore_clr_wrapper.h index a6cede4b..c73c95d4 100644 --- a/explore/clr/explore_clr_wrapper.h +++ b/explore/clr/explore_clr_wrapper.h @@ -1,438 +1,438 @@ -#pragma once
-#include "explore_interop.h"
-
-/*!
-* \addtogroup MultiWorldTestingCsharp
-* @{
-*/
-namespace MultiWorldTesting {
-
- /// <summary>
- /// The epsilon greedy exploration class.
- /// </summary>
- /// <remarks>
- /// This is a good choice if you have no idea which actions should be preferred.
- /// Epsilon greedy is also computationally cheap.
- /// </remarks>
- /// <typeparam name="Ctx">The Context type.</typeparam>
- generic <class Ctx>
- public ref class EpsilonGreedyExplorer : public IExplorer<Ctx>, public PolicyCallback<Ctx>
- {
- public:
- /// <summary>
- /// The constructor is the only public member, because this should be used with the MwtExplorer.
- /// </summary>
- /// <param name="defaultPolicy">A default function which outputs an action given a context.</param>
- /// <param name="epsilon">The probability of a random exploration.</param>
- /// <param name="numActions">The number of actions to randomize over.</param>
- EpsilonGreedyExplorer(IPolicy<Ctx>^ defaultPolicy, float epsilon, UInt32 numActions)
- {
- this->defaultPolicy = defaultPolicy;
- m_explorer = new NativeMultiWorldTesting::EpsilonGreedyExplorer<NativeContext>(*GetNativePolicy(), epsilon, (u32)numActions);
- }
-
- ~EpsilonGreedyExplorer()
- {
- delete m_explorer;
- }
-
- internal:
- virtual UInt32 InvokePolicyCallback(Ctx context, int index) override
- {
- return defaultPolicy->ChooseAction(context);
- }
-
- NativeMultiWorldTesting::EpsilonGreedyExplorer<NativeContext>* Get()
- {
- return m_explorer;
- }
-
- private:
- IPolicy<Ctx>^ defaultPolicy;
- NativeMultiWorldTesting::EpsilonGreedyExplorer<NativeContext>* m_explorer;
- };
-
- /// <summary>
- /// The tau-first exploration class.
- /// </summary>
- /// <remarks>
- /// The tau-first explorer collects precisely tau uniform random
- /// exploration events, and then uses the default policy.
- /// </remarks>
- /// <typeparam name="Ctx">The Context type.</typeparam>
- generic <class Ctx>
- public ref class TauFirstExplorer : public IExplorer<Ctx>, public PolicyCallback<Ctx>
- {
- public:
- /// <summary>
- /// The constructor is the only public member, because this should be used with the MwtExplorer.
- /// </summary>
- /// <param name="defaultPolicy">A default policy after randomization finishes.</param>
- /// <param name="tau">The number of events to be uniform over.</param>
- /// <param name="numActions">The number of actions to randomize over.</param>
- TauFirstExplorer(IPolicy<Ctx>^ defaultPolicy, UInt32 tau, UInt32 numActions)
- {
- this->defaultPolicy = defaultPolicy;
- m_explorer = new NativeMultiWorldTesting::TauFirstExplorer<NativeContext>(*GetNativePolicy(), tau, (u32)numActions);
- }
-
- ~TauFirstExplorer()
- {
- delete m_explorer;
- }
-
- internal:
- virtual UInt32 InvokePolicyCallback(Ctx context, int index) override
- {
- return defaultPolicy->ChooseAction(context);
- }
-
- NativeMultiWorldTesting::TauFirstExplorer<NativeContext>* Get()
- {
- return m_explorer;
- }
-
- private:
- IPolicy<Ctx>^ defaultPolicy;
- NativeMultiWorldTesting::TauFirstExplorer<NativeContext>* m_explorer;
- };
-
- /// <summary>
- /// The epsilon greedy exploration class.
- /// </summary>
- /// <remarks>
- /// In some cases, different actions have a different scores, and you
- /// would prefer to choose actions with large scores. Softmax allows
- /// you to do that.
- /// </remarks>
- /// <typeparam name="Ctx">The Context type.</typeparam>
- generic <class Ctx>
- public ref class SoftmaxExplorer : public IExplorer<Ctx>, public ScorerCallback<Ctx>
- {
- public:
- /// <summary>
- /// The constructor is the only public member, because this should be used with the MwtExplorer.
- /// </summary>
- /// <param name="defaultScorer">A function which outputs a score for each action.</param>
- /// <param name="lambda">lambda = 0 implies uniform distribution. Large lambda is equivalent to a max.</param>
- /// <param name="numActions">The number of actions to randomize over.</param>
- SoftmaxExplorer(IScorer<Ctx>^ defaultScorer, float lambda, UInt32 numActions)
- {
- this->defaultScorer = defaultScorer;
- m_explorer = new NativeMultiWorldTesting::SoftmaxExplorer<NativeContext>(*GetNativeScorer(), lambda, (u32)numActions);
- }
-
- ~SoftmaxExplorer()
- {
- delete m_explorer;
- }
-
- internal:
- virtual List<float>^ InvokeScorerCallback(Ctx context) override
- {
- return defaultScorer->ScoreActions(context);
- }
-
- NativeMultiWorldTesting::SoftmaxExplorer<NativeContext>* Get()
- {
- return m_explorer;
- }
-
- private:
- IScorer<Ctx>^ defaultScorer;
- NativeMultiWorldTesting::SoftmaxExplorer<NativeContext>* m_explorer;
- };
-
- /// <summary>
- /// The generic exploration class.
- /// </summary>
- /// <remarks>
- /// GenericExplorer provides complete flexibility. You can create any
- /// distribution over actions desired, and it will draw from that.
- /// </remarks>
- /// <typeparam name="Ctx">The Context type.</typeparam>
- generic <class Ctx>
- public ref class GenericExplorer : public IExplorer<Ctx>, public ScorerCallback<Ctx>
- {
- public:
- /// <summary>
- /// The constructor is the only public member, because this should be used with the MwtExplorer.
- /// </summary>
- /// <param name="defaultScorer">A function which outputs the probability of each action.</param>
- /// <param name="numActions">The number of actions to randomize over.</param>
- GenericExplorer(IScorer<Ctx>^ defaultScorer, UInt32 numActions)
- {
- this->defaultScorer = defaultScorer;
- m_explorer = new NativeMultiWorldTesting::GenericExplorer<NativeContext>(*GetNativeScorer(), (u32)numActions);
- }
-
- ~GenericExplorer()
- {
- delete m_explorer;
- }
-
- internal:
- virtual List<float>^ InvokeScorerCallback(Ctx context) override
- {
- return defaultScorer->ScoreActions(context);
- }
-
- NativeMultiWorldTesting::GenericExplorer<NativeContext>* Get()
- {
- return m_explorer;
- }
-
- private:
- IScorer<Ctx>^ defaultScorer;
- NativeMultiWorldTesting::GenericExplorer<NativeContext>* m_explorer;
- };
-
- /// <summary>
- /// The bootstrap exploration class.
- /// </summary>
- /// <remarks>
- /// The Bootstrap explorer randomizes over the actions chosen by a set of
- /// default policies. This performs well statistically but can be
- /// computationally expensive.
- /// </remarks>
- /// <typeparam name="Ctx">The Context type.</typeparam>
- generic <class Ctx>
- public ref class BootstrapExplorer : public IExplorer<Ctx>, public PolicyCallback<Ctx>
- {
- public:
- /// <summary>
- /// The constructor is the only public member, because this should be used with the MwtExplorer.
- /// </summary>
- /// <param name="defaultPolicies">A set of default policies to be uniform random over.</param>
- /// <param name="numActions">The number of actions to randomize over.</param>
- BootstrapExplorer(cli::array<IPolicy<Ctx>^>^ defaultPolicies, UInt32 numActions)
- {
- this->defaultPolicies = defaultPolicies;
- if (this->defaultPolicies == nullptr)
- {
- throw gcnew ArgumentNullException("The specified array of default policy functions cannot be null.");
- }
-
- m_explorer = new NativeMultiWorldTesting::BootstrapExplorer<NativeContext>(*GetNativePolicies((u32)defaultPolicies->Length), (u32)numActions);
- }
-
- ~BootstrapExplorer()
- {
- delete m_explorer;
- }
-
- internal:
- virtual UInt32 InvokePolicyCallback(Ctx context, int index) override
- {
- if (index < 0 || index >= defaultPolicies->Length)
- {
- throw gcnew InvalidDataException("Internal error: Index of interop bag is out of range.");
- }
- return defaultPolicies[index]->ChooseAction(context);
- }
-
- NativeMultiWorldTesting::BootstrapExplorer<NativeContext>* Get()
- {
- return m_explorer;
- }
-
- private:
- cli::array<IPolicy<Ctx>^>^ defaultPolicies;
- NativeMultiWorldTesting::BootstrapExplorer<NativeContext>* m_explorer;
- };
-
- /// <summary>
- /// The top level MwtExplorer class. Using this makes sure that the
- /// right bits are recorded and good random actions are chosen.
- /// </summary>
- /// <typeparam name="Ctx">The Context type.</typeparam>
- generic <class Ctx>
- public ref class MwtExplorer : public RecorderCallback<Ctx>
- {
- public:
- /// <summary>
- /// Constructor.
- /// </summary>
- /// <param name="appId">This should be unique to each experiment to avoid correlation bugs.</param>
- /// <param name="recorder">A user-specified class for recording the appropriate bits for use in evaluation and learning.</param>
- MwtExplorer(String^ appId, IRecorder<Ctx>^ recorder)
- {
- this->appId = appId;
- this->recorder = recorder;
- }
-
- /// <summary>
- /// Choose_Action should be drop-in replacement for any existing policy function.
- /// </summary>
- /// <param name="explorer">An existing exploration algorithm (one of the above) which uses the default policy as a callback.</param>
- /// <param name="unique_key">A unique identifier for the experimental unit. This could be a user id, a session id, etc...</param>
- /// <param name="context">The context upon which a decision is made. See SimpleContext above for an example.</param>
- /// <returns>An unsigned 32-bit integer representing the 1-based chosen action.</returns>
- UInt32 ChooseAction(IExplorer<Ctx>^ explorer, String^ unique_key, Ctx context)
- {
- String^ salt = this->appId;
- NativeMultiWorldTesting::MwtExplorer<NativeContext> mwt(marshal_as<std::string>(salt), *GetNativeRecorder());
-
- GCHandle selfHandle = GCHandle::Alloc(this);
- IntPtr selfPtr = (IntPtr)selfHandle;
-
- GCHandle contextHandle = GCHandle::Alloc(context);
- IntPtr contextPtr = (IntPtr)contextHandle;
-
- GCHandle explorerHandle = GCHandle::Alloc(explorer);
- IntPtr explorerPtr = (IntPtr)explorerHandle;
-
- NativeContext native_context(selfPtr.ToPointer(), explorerPtr.ToPointer(), contextPtr.ToPointer());
- u32 action = 0;
- if (explorer->GetType() == EpsilonGreedyExplorer<Ctx>::typeid)
- {
- EpsilonGreedyExplorer<Ctx>^ epsilonGreedyExplorer = (EpsilonGreedyExplorer<Ctx>^)explorer;
- action = mwt.Choose_Action(*epsilonGreedyExplorer->Get(), marshal_as<std::string>(unique_key), native_context);
- }
- else if (explorer->GetType() == TauFirstExplorer<Ctx>::typeid)
- {
- TauFirstExplorer<Ctx>^ tauFirstExplorer = (TauFirstExplorer<Ctx>^)explorer;
- action = mwt.Choose_Action(*tauFirstExplorer->Get(), marshal_as<std::string>(unique_key), native_context);
- }
- else if (explorer->GetType() == SoftmaxExplorer<Ctx>::typeid)
- {
- SoftmaxExplorer<Ctx>^ softmaxExplorer = (SoftmaxExplorer<Ctx>^)explorer;
- action = mwt.Choose_Action(*softmaxExplorer->Get(), marshal_as<std::string>(unique_key), native_context);
- }
- else if (explorer->GetType() == GenericExplorer<Ctx>::typeid)
- {
- GenericExplorer<Ctx>^ genericExplorer = (GenericExplorer<Ctx>^)explorer;
- action = mwt.Choose_Action(*genericExplorer->Get(), marshal_as<std::string>(unique_key), native_context);
- }
- else if (explorer->GetType() == BootstrapExplorer<Ctx>::typeid)
- {
- BootstrapExplorer<Ctx>^ bootstrapExplorer = (BootstrapExplorer<Ctx>^)explorer;
- action = mwt.Choose_Action(*bootstrapExplorer->Get(), marshal_as<std::string>(unique_key), native_context);
- }
-
- explorerHandle.Free();
- contextHandle.Free();
- selfHandle.Free();
-
- return action;
- }
-
- internal:
- virtual void InvokeRecorderCallback(Ctx context, UInt32 action, float probability, String^ unique_key) override
- {
- recorder->Record(context, action, probability, unique_key);
- }
-
- private:
- IRecorder<Ctx>^ recorder;
- String^ appId;
- };
-
- /// <summary>
- /// Represents a feature in a sparse array.
- /// </summary>
- [StructLayout(LayoutKind::Sequential)]
- public value struct Feature
- {
- float Value;
- UInt32 Id;
- };
-
- /// <summary>
- /// A sample recorder class that converts the exploration tuple into string format.
- /// </summary>
- /// <typeparam name="Ctx">The Context type.</typeparam>
- generic <class Ctx> where Ctx : IStringContext
- public ref class StringRecorder : public IRecorder<Ctx>, public ToStringCallback<Ctx>
- {
- public:
- StringRecorder()
- {
- m_string_recorder = new NativeMultiWorldTesting::StringRecorder<NativeStringContext>();
- }
-
- ~StringRecorder()
- {
- delete m_string_recorder;
- }
-
- virtual void Record(Ctx context, UInt32 action, float probability, String^ uniqueKey)
- {
- GCHandle contextHandle = GCHandle::Alloc(context);
- IntPtr contextPtr = (IntPtr)contextHandle;
-
- NativeStringContext native_context(contextPtr.ToPointer(), GetCallback());
- m_string_recorder->Record(native_context, (u32)action, probability, marshal_as<string>(uniqueKey));
- }
-
- /// <summary>
- /// Gets the content of the recording so far as a string and clears internal content.
- /// </summary>
- /// <returns>
- /// A string with recording content.
- /// </returns>
- String^ GetRecording()
- {
- // Workaround for C++-CLI bug which does not allow default value for parameter
- return GetRecording(true);
- }
-
- /// <summary>
- /// Gets the content of the recording so far as a string and optionally clears internal content.
- /// </summary>
- /// <param name="flush">A boolean value indicating whether to clear the internal content.</param>
- /// <returns>
- /// A string with recording content.
- /// </returns>
- String^ GetRecording(bool flush)
- {
- return gcnew String(m_string_recorder->Get_Recording(flush).c_str());
- }
-
- private:
- NativeMultiWorldTesting::StringRecorder<NativeStringContext>* m_string_recorder;
- };
-
- /// <summary>
- /// A sample context class that stores a vector of Features.
- /// </summary>
- public ref class SimpleContext : public IStringContext
- {
- public:
- SimpleContext(cli::array<Feature>^ features)
- {
- Features = features;
-
- // TODO: add another constructor overload for native SimpleContext to avoid copying feature values
- m_features = new vector<NativeMultiWorldTesting::Feature>();
- for (int i = 0; i < features->Length; i++)
- {
- m_features->push_back({ features[i].Value, features[i].Id });
- }
-
- m_native_context = new NativeMultiWorldTesting::SimpleContext(*m_features);
- }
-
- String^ ToString() override
- {
- return gcnew String(m_native_context->To_String().c_str());
- }
-
- ~SimpleContext()
- {
- delete m_native_context;
- }
-
- public:
- cli::array<Feature>^ GetFeatures() { return Features; }
-
- internal:
- cli::array<Feature>^ Features;
-
- private:
- vector<NativeMultiWorldTesting::Feature>* m_features;
- NativeMultiWorldTesting::SimpleContext* m_native_context;
- };
-}
-
+#pragma once +#include "explore_interop.h" + +/*! +* \addtogroup MultiWorldTestingCsharp +* @{ +*/ +namespace MultiWorldTesting { + + /// <summary> + /// The epsilon greedy exploration class. + /// </summary> + /// <remarks> + /// This is a good choice if you have no idea which actions should be preferred. + /// Epsilon greedy is also computationally cheap. + /// </remarks> + /// <typeparam name="Ctx">The Context type.</typeparam> + generic <class Ctx> + public ref class EpsilonGreedyExplorer : public IExplorer<Ctx>, public PolicyCallback<Ctx> + { + public: + /// <summary> + /// The constructor is the only public member, because this should be used with the MwtExplorer. + /// </summary> + /// <param name="defaultPolicy">A default function which outputs an action given a context.</param> + /// <param name="epsilon">The probability of a random exploration.</param> + /// <param name="numActions">The number of actions to randomize over.</param> + EpsilonGreedyExplorer(IPolicy<Ctx>^ defaultPolicy, float epsilon, UInt32 numActions) + { + this->defaultPolicy = defaultPolicy; + m_explorer = new NativeMultiWorldTesting::EpsilonGreedyExplorer<NativeContext>(*GetNativePolicy(), epsilon, (u32)numActions); + } + + ~EpsilonGreedyExplorer() + { + delete m_explorer; + } + + internal: + virtual UInt32 InvokePolicyCallback(Ctx context, int index) override + { + return defaultPolicy->ChooseAction(context); + } + + NativeMultiWorldTesting::EpsilonGreedyExplorer<NativeContext>* Get() + { + return m_explorer; + } + + private: + IPolicy<Ctx>^ defaultPolicy; + NativeMultiWorldTesting::EpsilonGreedyExplorer<NativeContext>* m_explorer; + }; + + /// <summary> + /// The tau-first exploration class. + /// </summary> + /// <remarks> + /// The tau-first explorer collects precisely tau uniform random + /// exploration events, and then uses the default policy. 
+ /// </remarks> + /// <typeparam name="Ctx">The Context type.</typeparam> + generic <class Ctx> + public ref class TauFirstExplorer : public IExplorer<Ctx>, public PolicyCallback<Ctx> + { + public: + /// <summary> + /// The constructor is the only public member, because this should be used with the MwtExplorer. + /// </summary> + /// <param name="defaultPolicy">A default policy after randomization finishes.</param> + /// <param name="tau">The number of events to be uniform over.</param> + /// <param name="numActions">The number of actions to randomize over.</param> + TauFirstExplorer(IPolicy<Ctx>^ defaultPolicy, UInt32 tau, UInt32 numActions) + { + this->defaultPolicy = defaultPolicy; + m_explorer = new NativeMultiWorldTesting::TauFirstExplorer<NativeContext>(*GetNativePolicy(), tau, (u32)numActions); + } + + ~TauFirstExplorer() + { + delete m_explorer; + } + + internal: + virtual UInt32 InvokePolicyCallback(Ctx context, int index) override + { + return defaultPolicy->ChooseAction(context); + } + + NativeMultiWorldTesting::TauFirstExplorer<NativeContext>* Get() + { + return m_explorer; + } + + private: + IPolicy<Ctx>^ defaultPolicy; + NativeMultiWorldTesting::TauFirstExplorer<NativeContext>* m_explorer; + }; + + /// <summary> + /// The epsilon greedy exploration class. + /// </summary> + /// <remarks> + /// In some cases, different actions have a different scores, and you + /// would prefer to choose actions with large scores. Softmax allows + /// you to do that. + /// </remarks> + /// <typeparam name="Ctx">The Context type.</typeparam> + generic <class Ctx> + public ref class SoftmaxExplorer : public IExplorer<Ctx>, public ScorerCallback<Ctx> + { + public: + /// <summary> + /// The constructor is the only public member, because this should be used with the MwtExplorer. + /// </summary> + /// <param name="defaultScorer">A function which outputs a score for each action.</param> + /// <param name="lambda">lambda = 0 implies uniform distribution. 
Large lambda is equivalent to a max.</param> + /// <param name="numActions">The number of actions to randomize over.</param> + SoftmaxExplorer(IScorer<Ctx>^ defaultScorer, float lambda, UInt32 numActions) + { + this->defaultScorer = defaultScorer; + m_explorer = new NativeMultiWorldTesting::SoftmaxExplorer<NativeContext>(*GetNativeScorer(), lambda, (u32)numActions); + } + + ~SoftmaxExplorer() + { + delete m_explorer; + } + + internal: + virtual List<float>^ InvokeScorerCallback(Ctx context) override + { + return defaultScorer->ScoreActions(context); + } + + NativeMultiWorldTesting::SoftmaxExplorer<NativeContext>* Get() + { + return m_explorer; + } + + private: + IScorer<Ctx>^ defaultScorer; + NativeMultiWorldTesting::SoftmaxExplorer<NativeContext>* m_explorer; + }; + + /// <summary> + /// The generic exploration class. + /// </summary> + /// <remarks> + /// GenericExplorer provides complete flexibility. You can create any + /// distribution over actions desired, and it will draw from that. + /// </remarks> + /// <typeparam name="Ctx">The Context type.</typeparam> + generic <class Ctx> + public ref class GenericExplorer : public IExplorer<Ctx>, public ScorerCallback<Ctx> + { + public: + /// <summary> + /// The constructor is the only public member, because this should be used with the MwtExplorer. 
+ /// </summary> + /// <param name="defaultScorer">A function which outputs the probability of each action.</param> + /// <param name="numActions">The number of actions to randomize over.</param> + GenericExplorer(IScorer<Ctx>^ defaultScorer, UInt32 numActions) + { + this->defaultScorer = defaultScorer; + m_explorer = new NativeMultiWorldTesting::GenericExplorer<NativeContext>(*GetNativeScorer(), (u32)numActions); + } + + ~GenericExplorer() + { + delete m_explorer; + } + + internal: + virtual List<float>^ InvokeScorerCallback(Ctx context) override + { + return defaultScorer->ScoreActions(context); + } + + NativeMultiWorldTesting::GenericExplorer<NativeContext>* Get() + { + return m_explorer; + } + + private: + IScorer<Ctx>^ defaultScorer; + NativeMultiWorldTesting::GenericExplorer<NativeContext>* m_explorer; + }; + + /// <summary> + /// The bootstrap exploration class. + /// </summary> + /// <remarks> + /// The Bootstrap explorer randomizes over the actions chosen by a set of + /// default policies. This performs well statistically but can be + /// computationally expensive. + /// </remarks> + /// <typeparam name="Ctx">The Context type.</typeparam> + generic <class Ctx> + public ref class BootstrapExplorer : public IExplorer<Ctx>, public PolicyCallback<Ctx> + { + public: + /// <summary> + /// The constructor is the only public member, because this should be used with the MwtExplorer. 
+ /// </summary> + /// <param name="defaultPolicies">A set of default policies to be uniform random over.</param> + /// <param name="numActions">The number of actions to randomize over.</param> + BootstrapExplorer(cli::array<IPolicy<Ctx>^>^ defaultPolicies, UInt32 numActions) + { + this->defaultPolicies = defaultPolicies; + if (this->defaultPolicies == nullptr) + { + throw gcnew ArgumentNullException("The specified array of default policy functions cannot be null."); + } + + m_explorer = new NativeMultiWorldTesting::BootstrapExplorer<NativeContext>(*GetNativePolicies((u32)defaultPolicies->Length), (u32)numActions); + } + + ~BootstrapExplorer() + { + delete m_explorer; + } + + internal: + virtual UInt32 InvokePolicyCallback(Ctx context, int index) override + { + if (index < 0 || index >= defaultPolicies->Length) + { + throw gcnew InvalidDataException("Internal error: Index of interop bag is out of range."); + } + return defaultPolicies[index]->ChooseAction(context); + } + + NativeMultiWorldTesting::BootstrapExplorer<NativeContext>* Get() + { + return m_explorer; + } + + private: + cli::array<IPolicy<Ctx>^>^ defaultPolicies; + NativeMultiWorldTesting::BootstrapExplorer<NativeContext>* m_explorer; + }; + + /// <summary> + /// The top level MwtExplorer class. Using this makes sure that the + /// right bits are recorded and good random actions are chosen. + /// </summary> + /// <typeparam name="Ctx">The Context type.</typeparam> + generic <class Ctx> + public ref class MwtExplorer : public RecorderCallback<Ctx> + { + public: + /// <summary> + /// Constructor. 
+ /// </summary> + /// <param name="appId">This should be unique to each experiment to avoid correlation bugs.</param> + /// <param name="recorder">A user-specified class for recording the appropriate bits for use in evaluation and learning.</param> + MwtExplorer(String^ appId, IRecorder<Ctx>^ recorder) + { + this->appId = appId; + this->recorder = recorder; + } + + /// <summary> + /// Choose_Action should be drop-in replacement for any existing policy function. + /// </summary> + /// <param name="explorer">An existing exploration algorithm (one of the above) which uses the default policy as a callback.</param> + /// <param name="unique_key">A unique identifier for the experimental unit. This could be a user id, a session id, etc...</param> + /// <param name="context">The context upon which a decision is made. See SimpleContext above for an example.</param> + /// <returns>An unsigned 32-bit integer representing the 1-based chosen action.</returns> + UInt32 ChooseAction(IExplorer<Ctx>^ explorer, String^ unique_key, Ctx context) + { + String^ salt = this->appId; + NativeMultiWorldTesting::MwtExplorer<NativeContext> mwt(marshal_as<std::string>(salt), *GetNativeRecorder()); + + GCHandle selfHandle = GCHandle::Alloc(this); + IntPtr selfPtr = (IntPtr)selfHandle; + + GCHandle contextHandle = GCHandle::Alloc(context); + IntPtr contextPtr = (IntPtr)contextHandle; + + GCHandle explorerHandle = GCHandle::Alloc(explorer); + IntPtr explorerPtr = (IntPtr)explorerHandle; + + NativeContext native_context(selfPtr.ToPointer(), explorerPtr.ToPointer(), contextPtr.ToPointer()); + u32 action = 0; + if (explorer->GetType() == EpsilonGreedyExplorer<Ctx>::typeid) + { + EpsilonGreedyExplorer<Ctx>^ epsilonGreedyExplorer = (EpsilonGreedyExplorer<Ctx>^)explorer; + action = mwt.Choose_Action(*epsilonGreedyExplorer->Get(), marshal_as<std::string>(unique_key), native_context); + } + else if (explorer->GetType() == TauFirstExplorer<Ctx>::typeid) + { + TauFirstExplorer<Ctx>^ tauFirstExplorer = 
(TauFirstExplorer<Ctx>^)explorer; + action = mwt.Choose_Action(*tauFirstExplorer->Get(), marshal_as<std::string>(unique_key), native_context); + } + else if (explorer->GetType() == SoftmaxExplorer<Ctx>::typeid) + { + SoftmaxExplorer<Ctx>^ softmaxExplorer = (SoftmaxExplorer<Ctx>^)explorer; + action = mwt.Choose_Action(*softmaxExplorer->Get(), marshal_as<std::string>(unique_key), native_context); + } + else if (explorer->GetType() == GenericExplorer<Ctx>::typeid) + { + GenericExplorer<Ctx>^ genericExplorer = (GenericExplorer<Ctx>^)explorer; + action = mwt.Choose_Action(*genericExplorer->Get(), marshal_as<std::string>(unique_key), native_context); + } + else if (explorer->GetType() == BootstrapExplorer<Ctx>::typeid) + { + BootstrapExplorer<Ctx>^ bootstrapExplorer = (BootstrapExplorer<Ctx>^)explorer; + action = mwt.Choose_Action(*bootstrapExplorer->Get(), marshal_as<std::string>(unique_key), native_context); + } + + explorerHandle.Free(); + contextHandle.Free(); + selfHandle.Free(); + + return action; + } + + internal: + virtual void InvokeRecorderCallback(Ctx context, UInt32 action, float probability, String^ unique_key) override + { + recorder->Record(context, action, probability, unique_key); + } + + private: + IRecorder<Ctx>^ recorder; + String^ appId; + }; + + /// <summary> + /// Represents a feature in a sparse array. + /// </summary> + [StructLayout(LayoutKind::Sequential)] + public value struct Feature + { + float Value; + UInt32 Id; + }; + + /// <summary> + /// A sample recorder class that converts the exploration tuple into string format. 
+ /// </summary> + /// <typeparam name="Ctx">The Context type.</typeparam> + generic <class Ctx> where Ctx : IStringContext + public ref class StringRecorder : public IRecorder<Ctx>, public ToStringCallback<Ctx> + { + public: + StringRecorder() + { + m_string_recorder = new NativeMultiWorldTesting::StringRecorder<NativeStringContext>(); + } + + ~StringRecorder() + { + delete m_string_recorder; + } + + virtual void Record(Ctx context, UInt32 action, float probability, String^ uniqueKey) + { + GCHandle contextHandle = GCHandle::Alloc(context); + IntPtr contextPtr = (IntPtr)contextHandle; + + NativeStringContext native_context(contextPtr.ToPointer(), GetCallback()); + m_string_recorder->Record(native_context, (u32)action, probability, marshal_as<string>(uniqueKey)); + } + + /// <summary> + /// Gets the content of the recording so far as a string and clears internal content. + /// </summary> + /// <returns> + /// A string with recording content. + /// </returns> + String^ GetRecording() + { + // Workaround for C++-CLI bug which does not allow default value for parameter + return GetRecording(true); + } + + /// <summary> + /// Gets the content of the recording so far as a string and optionally clears internal content. + /// </summary> + /// <param name="flush">A boolean value indicating whether to clear the internal content.</param> + /// <returns> + /// A string with recording content. + /// </returns> + String^ GetRecording(bool flush) + { + return gcnew String(m_string_recorder->Get_Recording(flush).c_str()); + } + + private: + NativeMultiWorldTesting::StringRecorder<NativeStringContext>* m_string_recorder; + }; + + /// <summary> + /// A sample context class that stores a vector of Features. 
+ /// </summary> + public ref class SimpleContext : public IStringContext + { + public: + SimpleContext(cli::array<Feature>^ features) + { + Features = features; + + // TODO: add another constructor overload for native SimpleContext to avoid copying feature values + m_features = new vector<NativeMultiWorldTesting::Feature>(); + for (int i = 0; i < features->Length; i++) + { + m_features->push_back({ features[i].Value, features[i].Id }); + } + + m_native_context = new NativeMultiWorldTesting::SimpleContext(*m_features); + } + + String^ ToString() override + { + return gcnew String(m_native_context->To_String().c_str()); + } + + ~SimpleContext() + { + delete m_native_context; + } + + public: + cli::array<Feature>^ GetFeatures() { return Features; } + + internal: + cli::array<Feature>^ Features; + + private: + vector<NativeMultiWorldTesting::Feature>* m_features; + NativeMultiWorldTesting::SimpleContext* m_native_context; + }; +} + /*! @} End of Doxygen Groups*/
\ No newline at end of file diff --git a/explore/clr/explore_interface.h b/explore/clr/explore_interface.h index 3212b4d8..dfa17697 100644 --- a/explore/clr/explore_interface.h +++ b/explore/clr/explore_interface.h @@ -1,88 +1,88 @@ -#pragma once
-
-using namespace System;
-using namespace System::Collections::Generic;
-
-/** \defgroup MultiWorldTestingCsharp
-\brief C# implementation, for sample usage see: https://github.com/sidsen/vowpal_wabbit/blob/v0/cs_test/ExploreOnlySample.cs
-*/
-
-/*!
-* \addtogroup MultiWorldTestingCsharp
-* @{
-*/
-
-//! Interface for C# version of Multiworld Testing library.
-//! For sample usage see: https://github.com/sidsen/vowpal_wabbit/blob/v0/cs_test/ExploreOnlySample.cs
-namespace MultiWorldTesting {
-
-/// <summary>
-/// Represents a recorder that exposes a method to record exploration data based on generic contexts.
-/// </summary>
-/// <typeparam name="Ctx">The Context type.</typeparam>
-/// <remarks>
-/// Exploration data is specified as a set of tuples <context, action, probability, key> as described below. An
-/// application passes an IRecorder object to the @MwtExplorer constructor. See
-/// @StringRecorder for a sample IRecorder object.
-/// </remarks>
-generic <class Ctx>
-public interface class IRecorder
-{
-public:
- /// <summary>
- /// Records the exploration data associated with a given decision.
- /// </summary>
- /// <param name="context">A user-defined context for the decision.</param>
- /// <param name="action">Chosen by an exploration algorithm given context.</param>
- /// <param name="probability">The probability of the chosen action given context.</param>
- /// <param name="uniqueKey">A user-defined identifer for the decision.</param>
- virtual void Record(Ctx context, UInt32 action, float probability, String^ uniqueKey) = 0;
-};
-
-/// <summary>
-/// Exposes a method for choosing an action given a generic context. IPolicy objects are
-/// passed to (and invoked by) exploration algorithms to specify the default policy behavior.
-/// </summary>
-/// <typeparam name="Ctx">The Context type.</typeparam>
-generic <class Ctx>
-public interface class IPolicy
-{
-public:
- /// <summary>
- /// Determines the action to take for a given context.
- /// </summary>
- /// <param name="context">A user-defined context for the decision.</param>
- /// <returns>Index of the action to take (1-based)</returns>
- virtual UInt32 ChooseAction(Ctx context) = 0;
-};
-
-/// <summary>
-/// Exposes a method for specifying a score (weight) for each action given a generic context.
-/// </summary>
-/// <typeparam name="Ctx">The Context type.</typeparam>
-generic <class Ctx>
-public interface class IScorer
-{
-public:
- /// <summary>
- /// Determines the score of each action for a given context.
- /// </summary>
- /// <param name="context">A user-defined context for the decision.</param>
- /// <returns>Vector of scores indexed by action (1-based).</returns>
- virtual List<float>^ ScoreActions(Ctx context) = 0;
-};
-
-generic <class Ctx>
-public interface class IExplorer
-{
-};
-
-public interface class IStringContext
-{
-public:
- virtual String^ ToString() = 0;
-};
-
-}
-
+#pragma once + +using namespace System; +using namespace System::Collections::Generic; + +/** \defgroup MultiWorldTestingCsharp +\brief C# implementation, for sample usage see: https://github.com/sidsen/vowpal_wabbit/blob/v0/cs_test/ExploreOnlySample.cs +*/ + +/*! +* \addtogroup MultiWorldTestingCsharp +* @{ +*/ + +//! Interface for C# version of Multiworld Testing library. +//! For sample usage see: https://github.com/sidsen/vowpal_wabbit/blob/v0/cs_test/ExploreOnlySample.cs +namespace MultiWorldTesting { + +/// <summary> +/// Represents a recorder that exposes a method to record exploration data based on generic contexts. +/// </summary> +/// <typeparam name="Ctx">The Context type.</typeparam> +/// <remarks> +/// Exploration data is specified as a set of tuples <context, action, probability, key> as described below. An +/// application passes an IRecorder object to the @MwtExplorer constructor. See +/// @StringRecorder for a sample IRecorder object. +/// </remarks> +generic <class Ctx> +public interface class IRecorder +{ +public: + /// <summary> + /// Records the exploration data associated with a given decision. + /// </summary> + /// <param name="context">A user-defined context for the decision.</param> + /// <param name="action">Chosen by an exploration algorithm given context.</param> + /// <param name="probability">The probability of the chosen action given context.</param> + /// <param name="uniqueKey">A user-defined identifer for the decision.</param> + virtual void Record(Ctx context, UInt32 action, float probability, String^ uniqueKey) = 0; +}; + +/// <summary> +/// Exposes a method for choosing an action given a generic context. IPolicy objects are +/// passed to (and invoked by) exploration algorithms to specify the default policy behavior. +/// </summary> +/// <typeparam name="Ctx">The Context type.</typeparam> +generic <class Ctx> +public interface class IPolicy +{ +public: + /// <summary> + /// Determines the action to take for a given context. 
+ /// </summary> + /// <param name="context">A user-defined context for the decision.</param> + /// <returns>Index of the action to take (1-based)</returns> + virtual UInt32 ChooseAction(Ctx context) = 0; +}; + +/// <summary> +/// Exposes a method for specifying a score (weight) for each action given a generic context. +/// </summary> +/// <typeparam name="Ctx">The Context type.</typeparam> +generic <class Ctx> +public interface class IScorer +{ +public: + /// <summary> + /// Determines the score of each action for a given context. + /// </summary> + /// <param name="context">A user-defined context for the decision.</param> + /// <returns>Vector of scores indexed by action (1-based).</returns> + virtual List<float>^ ScoreActions(Ctx context) = 0; +}; + +generic <class Ctx> +public interface class IExplorer +{ +}; + +public interface class IStringContext +{ +public: + virtual String^ ToString() = 0; +}; + +} + /*! @} End of Doxygen Groups*/
\ No newline at end of file diff --git a/explore/clr/explore_interop.h b/explore/clr/explore_interop.h index 99466945..a6cb9327 100644 --- a/explore/clr/explore_interop.h +++ b/explore/clr/explore_interop.h @@ -1,358 +1,358 @@ -#pragma once
-
-#define MANAGED_CODE
-
-#include "explore_interface.h"
-#include "MWTExplorer.h"
-
-#include <msclr\marshal_cppstd.h>
-
-using namespace System;
-using namespace System::Collections::Generic;
-using namespace System::IO;
-using namespace System::Runtime::InteropServices;
-using namespace System::Xml::Serialization;
-using namespace msclr::interop;
-
-namespace MultiWorldTesting {
-
-// Policy callback
-private delegate UInt32 ClrPolicyCallback(IntPtr explorerPtr, IntPtr contextPtr, int index);
-typedef u32 Native_Policy_Callback(void* explorer, void* context, int index);
-
-// Scorer callback
-private delegate void ClrScorerCallback(IntPtr explorerPtr, IntPtr contextPtr, IntPtr scores, IntPtr size);
-typedef void Native_Scorer_Callback(void* explorer, void* context, float* scores[], u32* size);
-
-// Recorder callback
-private delegate void ClrRecorderCallback(IntPtr mwtPtr, IntPtr contextPtr, UInt32 action, float probability, IntPtr uniqueKey);
-typedef void Native_Recorder_Callback(void* mwt, void* context, u32 action, float probability, void* unique_key);
-
-// ToString callback
-private delegate void ClrToStringCallback(IntPtr contextPtr, IntPtr stringValue);
-typedef void Native_To_String_Callback(void* explorer, void* string_value);
-
-// NativeContext travels through interop space and contains instances of Mwt, Explorer, Context
-// used for triggering callback for Policy, Scorer, Recorder
-class NativeContext
-{
-public:
- NativeContext(void* clr_mwt, void* clr_explorer, void* clr_context)
- {
- m_clr_mwt = clr_mwt;
- m_clr_explorer = clr_explorer;
- m_clr_context = clr_context;
- }
-
- void* Get_Clr_Mwt()
- {
- return m_clr_mwt;
- }
-
- void* Get_Clr_Context()
- {
- return m_clr_context;
- }
-
- void* Get_Clr_Explorer()
- {
- return m_clr_explorer;
- }
-
-private:
- void* m_clr_mwt;
- void* m_clr_context;
- void* m_clr_explorer;
-};
-
-class NativeStringContext
-{
-public:
- NativeStringContext(void* clr_context, Native_To_String_Callback* func)
- {
- m_clr_context = clr_context;
- m_func = func;
- }
-
- string To_String()
- {
- string value;
- m_func(m_clr_context, &value);
- return value;
- }
-private:
- void* m_clr_context;
- Native_To_String_Callback* m_func;
-};
-
-// NativeRecorder listens to callback event and reroute it to the managed Recorder instance
-class NativeRecorder : public NativeMultiWorldTesting::IRecorder<NativeContext>
-{
-public:
- NativeRecorder(Native_Recorder_Callback* native_func)
- {
- m_func = native_func;
- }
-
- void Record(NativeContext& context, u32 action, float probability, string unique_key)
- {
- GCHandle uniqueKeyHandle = GCHandle::Alloc(gcnew String(unique_key.c_str()));
- IntPtr uniqueKeyPtr = (IntPtr)uniqueKeyHandle;
-
- m_func(context.Get_Clr_Mwt(), context.Get_Clr_Context(), action, probability, uniqueKeyPtr.ToPointer());
-
- uniqueKeyHandle.Free();
- }
-private:
- Native_Recorder_Callback* m_func;
-};
-
-// NativePolicy listens to callback event and reroute it to the managed Policy instance
-class NativePolicy : public NativeMultiWorldTesting::IPolicy<NativeContext>
-{
-public:
- NativePolicy(Native_Policy_Callback* func, int index = -1)
- {
- m_func = func;
- m_index = index;
- }
-
- u32 Choose_Action(NativeContext& context)
- {
- return m_func(context.Get_Clr_Explorer(), context.Get_Clr_Context(), m_index);
- }
-
-private:
- Native_Policy_Callback* m_func;
- int m_index;
-};
-
-class NativeScorer : public NativeMultiWorldTesting::IScorer<NativeContext>
-{
-public:
- NativeScorer(Native_Scorer_Callback* func)
- {
- m_func = func;
- }
-
- vector<float> Score_Actions(NativeContext& context)
- {
- float* scores = nullptr;
- u32 num_scores = 0;
- m_func(context.Get_Clr_Explorer(), context.Get_Clr_Context(), &scores, &num_scores);
-
- // It's ok if scores is null, vector will be empty
- vector<float> scores_vector(scores, scores + num_scores);
- delete[] scores;
-
- return scores_vector;
- }
-private:
- Native_Scorer_Callback* m_func;
-};
-
-// Triggers callback to the Policy instance to choose an action
-generic <class Ctx>
-public ref class PolicyCallback abstract
-{
-internal:
- virtual UInt32 InvokePolicyCallback(Ctx context, int index) = 0;
-
- PolicyCallback()
- {
- policyCallback = gcnew ClrPolicyCallback(&PolicyCallback<Ctx>::InteropInvoke);
- IntPtr policyCallbackPtr = Marshal::GetFunctionPointerForDelegate(policyCallback);
- m_callback = static_cast<Native_Policy_Callback*>(policyCallbackPtr.ToPointer());
- m_native_policy = nullptr;
- m_native_policies = nullptr;
- }
-
- ~PolicyCallback()
- {
- delete m_native_policy;
- delete m_native_policies;
- }
-
- NativePolicy* GetNativePolicy()
- {
- if (m_native_policy == nullptr)
- {
- m_native_policy = new NativePolicy(m_callback);
- }
- return m_native_policy;
- }
-
- vector<unique_ptr<NativeMultiWorldTesting::IPolicy<NativeContext>>>* GetNativePolicies(int count)
- {
- if (m_native_policies == nullptr)
- {
- m_native_policies = new vector<unique_ptr<NativeMultiWorldTesting::IPolicy<NativeContext>>>();
- for (int i = 0; i < count; i++)
- {
- m_native_policies->push_back(unique_ptr<NativeMultiWorldTesting::IPolicy<NativeContext>>(new NativePolicy(m_callback, i)));
- }
- }
-
- return m_native_policies;
- }
-
- static UInt32 InteropInvoke(IntPtr callbackPtr, IntPtr contextPtr, int index)
- {
- GCHandle callbackHandle = (GCHandle)callbackPtr;
- PolicyCallback<Ctx>^ callback = (PolicyCallback<Ctx>^)callbackHandle.Target;
-
- GCHandle contextHandle = (GCHandle)contextPtr;
- Ctx context = (Ctx)contextHandle.Target;
-
- return callback->InvokePolicyCallback(context, index);
- }
-
-private:
- ClrPolicyCallback^ policyCallback;
-
-private:
- NativePolicy* m_native_policy;
- vector<unique_ptr<NativeMultiWorldTesting::IPolicy<NativeContext>>>* m_native_policies;
- Native_Policy_Callback* m_callback;
-};
-
-// Triggers callback to the Recorder instance to record interaction data
-generic <class Ctx>
-public ref class RecorderCallback abstract
-{
-internal:
- virtual void InvokeRecorderCallback(Ctx context, UInt32 action, float probability, String^ unique_key) = 0;
-
- RecorderCallback()
- {
- recorderCallback = gcnew ClrRecorderCallback(&RecorderCallback<Ctx>::InteropInvoke);
- IntPtr recorderCallbackPtr = Marshal::GetFunctionPointerForDelegate(recorderCallback);
- Native_Recorder_Callback* callback = static_cast<Native_Recorder_Callback*>(recorderCallbackPtr.ToPointer());
- m_native_recorder = new NativeRecorder(callback);
- }
-
- ~RecorderCallback()
- {
- delete m_native_recorder;
- }
-
- NativeRecorder* GetNativeRecorder()
- {
- return m_native_recorder;
- }
-
- static void InteropInvoke(IntPtr mwtPtr, IntPtr contextPtr, UInt32 action, float probability, IntPtr uniqueKeyPtr)
- {
- GCHandle mwtHandle = (GCHandle)mwtPtr;
- RecorderCallback<Ctx>^ callback = (RecorderCallback<Ctx>^)mwtHandle.Target;
-
- GCHandle contextHandle = (GCHandle)contextPtr;
- Ctx context = (Ctx)contextHandle.Target;
-
- GCHandle uniqueKeyHandle = (GCHandle)uniqueKeyPtr;
- String^ uniqueKey = (String^)uniqueKeyHandle.Target;
-
- callback->InvokeRecorderCallback(context, action, probability, uniqueKey);
- }
-
-private:
- ClrRecorderCallback^ recorderCallback;
-
-private:
- NativeRecorder* m_native_recorder;
-};
-
-// Triggers callback to the Recorder instance to record interaction data
-generic <class Ctx>
-public ref class ScorerCallback abstract
-{
-internal:
- virtual List<float>^ InvokeScorerCallback(Ctx context) = 0;
-
- ScorerCallback()
- {
- scorerCallback = gcnew ClrScorerCallback(&ScorerCallback<Ctx>::InteropInvoke);
- IntPtr scorerCallbackPtr = Marshal::GetFunctionPointerForDelegate(scorerCallback);
- Native_Scorer_Callback* callback = static_cast<Native_Scorer_Callback*>(scorerCallbackPtr.ToPointer());
- m_native_scorer = new NativeScorer(callback);
- }
-
- ~ScorerCallback()
- {
- delete m_native_scorer;
- }
-
- NativeScorer* GetNativeScorer()
- {
- return m_native_scorer;
- }
-
- static void InteropInvoke(IntPtr callbackPtr, IntPtr contextPtr, IntPtr scoresPtr, IntPtr sizePtr)
- {
- GCHandle callbackHandle = (GCHandle)callbackPtr;
- ScorerCallback<Ctx>^ callback = (ScorerCallback<Ctx>^)callbackHandle.Target;
-
- GCHandle contextHandle = (GCHandle)contextPtr;
- Ctx context = (Ctx)contextHandle.Target;
-
- List<float>^ scoreList = callback->InvokeScorerCallback(context);
-
- if (scoreList == nullptr || scoreList->Count == 0)
- {
- return;
- }
-
- u32* num_scores = (u32*)sizePtr.ToPointer();
- *num_scores = (u32)scoreList->Count;
-
- float* scores = new float[*num_scores];
- for (u32 i = 0; i < *num_scores; i++)
- {
- scores[i] = scoreList[i];
- }
-
- float** native_scores = (float**)scoresPtr.ToPointer();
- *native_scores = scores;
- }
-
-private:
- ClrScorerCallback^ scorerCallback;
-
-private:
- NativeScorer* m_native_scorer;
-};
-
-// Triggers callback to the Context instance to perform ToString() operation
-generic <class Ctx> where Ctx : IStringContext
-public ref class ToStringCallback
-{
-internal:
- ToStringCallback()
- {
- toStringCallback = gcnew ClrToStringCallback(&ToStringCallback<Ctx>::InteropInvoke);
- IntPtr toStringCallbackPtr = Marshal::GetFunctionPointerForDelegate(toStringCallback);
- m_callback = static_cast<Native_To_String_Callback*>(toStringCallbackPtr.ToPointer());
- }
-
- Native_To_String_Callback* GetCallback()
- {
- return m_callback;
- }
-
- static void InteropInvoke(IntPtr contextPtr, IntPtr stringPtr)
- {
- GCHandle contextHandle = (GCHandle)contextPtr;
- Ctx context = (Ctx)contextHandle.Target;
-
- string* out_string = (string*)stringPtr.ToPointer();
- *out_string = marshal_as<string>(context->ToString());
- }
-
-private:
- ClrToStringCallback^ toStringCallback;
-
-private:
- Native_To_String_Callback* m_callback;
-};
-
+#pragma once + +#define MANAGED_CODE + +#include "explore_interface.h" +#include "MWTExplorer.h" + +#include <msclr\marshal_cppstd.h> + +using namespace System; +using namespace System::Collections::Generic; +using namespace System::IO; +using namespace System::Runtime::InteropServices; +using namespace System::Xml::Serialization; +using namespace msclr::interop; + +namespace MultiWorldTesting { + +// Policy callback +private delegate UInt32 ClrPolicyCallback(IntPtr explorerPtr, IntPtr contextPtr, int index); +typedef u32 Native_Policy_Callback(void* explorer, void* context, int index); + +// Scorer callback +private delegate void ClrScorerCallback(IntPtr explorerPtr, IntPtr contextPtr, IntPtr scores, IntPtr size); +typedef void Native_Scorer_Callback(void* explorer, void* context, float* scores[], u32* size); + +// Recorder callback +private delegate void ClrRecorderCallback(IntPtr mwtPtr, IntPtr contextPtr, UInt32 action, float probability, IntPtr uniqueKey); +typedef void Native_Recorder_Callback(void* mwt, void* context, u32 action, float probability, void* unique_key); + +// ToString callback +private delegate void ClrToStringCallback(IntPtr contextPtr, IntPtr stringValue); +typedef void Native_To_String_Callback(void* explorer, void* string_value); + +// NativeContext travels through interop space and contains instances of Mwt, Explorer, Context +// used for triggering callback for Policy, Scorer, Recorder +class NativeContext +{ +public: + NativeContext(void* clr_mwt, void* clr_explorer, void* clr_context) + { + m_clr_mwt = clr_mwt; + m_clr_explorer = clr_explorer; + m_clr_context = clr_context; + } + + void* Get_Clr_Mwt() + { + return m_clr_mwt; + } + + void* Get_Clr_Context() + { + return m_clr_context; + } + + void* Get_Clr_Explorer() + { + return m_clr_explorer; + } + +private: + void* m_clr_mwt; + void* m_clr_context; + void* m_clr_explorer; +}; + +class NativeStringContext +{ +public: + NativeStringContext(void* clr_context, Native_To_String_Callback* 
func) + { + m_clr_context = clr_context; + m_func = func; + } + + string To_String() + { + string value; + m_func(m_clr_context, &value); + return value; + } +private: + void* m_clr_context; + Native_To_String_Callback* m_func; +}; + +// NativeRecorder listens to callback event and reroute it to the managed Recorder instance +class NativeRecorder : public NativeMultiWorldTesting::IRecorder<NativeContext> +{ +public: + NativeRecorder(Native_Recorder_Callback* native_func) + { + m_func = native_func; + } + + void Record(NativeContext& context, u32 action, float probability, string unique_key) + { + GCHandle uniqueKeyHandle = GCHandle::Alloc(gcnew String(unique_key.c_str())); + IntPtr uniqueKeyPtr = (IntPtr)uniqueKeyHandle; + + m_func(context.Get_Clr_Mwt(), context.Get_Clr_Context(), action, probability, uniqueKeyPtr.ToPointer()); + + uniqueKeyHandle.Free(); + } +private: + Native_Recorder_Callback* m_func; +}; + +// NativePolicy listens to callback event and reroute it to the managed Policy instance +class NativePolicy : public NativeMultiWorldTesting::IPolicy<NativeContext> +{ +public: + NativePolicy(Native_Policy_Callback* func, int index = -1) + { + m_func = func; + m_index = index; + } + + u32 Choose_Action(NativeContext& context) + { + return m_func(context.Get_Clr_Explorer(), context.Get_Clr_Context(), m_index); + } + +private: + Native_Policy_Callback* m_func; + int m_index; +}; + +class NativeScorer : public NativeMultiWorldTesting::IScorer<NativeContext> +{ +public: + NativeScorer(Native_Scorer_Callback* func) + { + m_func = func; + } + + vector<float> Score_Actions(NativeContext& context) + { + float* scores = nullptr; + u32 num_scores = 0; + m_func(context.Get_Clr_Explorer(), context.Get_Clr_Context(), &scores, &num_scores); + + // It's ok if scores is null, vector will be empty + vector<float> scores_vector(scores, scores + num_scores); + delete[] scores; + + return scores_vector; + } +private: + Native_Scorer_Callback* m_func; +}; + +// Triggers callback 
to the Policy instance to choose an action +generic <class Ctx> +public ref class PolicyCallback abstract +{ +internal: + virtual UInt32 InvokePolicyCallback(Ctx context, int index) = 0; + + PolicyCallback() + { + policyCallback = gcnew ClrPolicyCallback(&PolicyCallback<Ctx>::InteropInvoke); + IntPtr policyCallbackPtr = Marshal::GetFunctionPointerForDelegate(policyCallback); + m_callback = static_cast<Native_Policy_Callback*>(policyCallbackPtr.ToPointer()); + m_native_policy = nullptr; + m_native_policies = nullptr; + } + + ~PolicyCallback() + { + delete m_native_policy; + delete m_native_policies; + } + + NativePolicy* GetNativePolicy() + { + if (m_native_policy == nullptr) + { + m_native_policy = new NativePolicy(m_callback); + } + return m_native_policy; + } + + vector<unique_ptr<NativeMultiWorldTesting::IPolicy<NativeContext>>>* GetNativePolicies(int count) + { + if (m_native_policies == nullptr) + { + m_native_policies = new vector<unique_ptr<NativeMultiWorldTesting::IPolicy<NativeContext>>>(); + for (int i = 0; i < count; i++) + { + m_native_policies->push_back(unique_ptr<NativeMultiWorldTesting::IPolicy<NativeContext>>(new NativePolicy(m_callback, i))); + } + } + + return m_native_policies; + } + + static UInt32 InteropInvoke(IntPtr callbackPtr, IntPtr contextPtr, int index) + { + GCHandle callbackHandle = (GCHandle)callbackPtr; + PolicyCallback<Ctx>^ callback = (PolicyCallback<Ctx>^)callbackHandle.Target; + + GCHandle contextHandle = (GCHandle)contextPtr; + Ctx context = (Ctx)contextHandle.Target; + + return callback->InvokePolicyCallback(context, index); + } + +private: + ClrPolicyCallback^ policyCallback; + +private: + NativePolicy* m_native_policy; + vector<unique_ptr<NativeMultiWorldTesting::IPolicy<NativeContext>>>* m_native_policies; + Native_Policy_Callback* m_callback; +}; + +// Triggers callback to the Recorder instance to record interaction data +generic <class Ctx> +public ref class RecorderCallback abstract +{ +internal: + virtual void 
InvokeRecorderCallback(Ctx context, UInt32 action, float probability, String^ unique_key) = 0; + + RecorderCallback() + { + recorderCallback = gcnew ClrRecorderCallback(&RecorderCallback<Ctx>::InteropInvoke); + IntPtr recorderCallbackPtr = Marshal::GetFunctionPointerForDelegate(recorderCallback); + Native_Recorder_Callback* callback = static_cast<Native_Recorder_Callback*>(recorderCallbackPtr.ToPointer()); + m_native_recorder = new NativeRecorder(callback); + } + + ~RecorderCallback() + { + delete m_native_recorder; + } + + NativeRecorder* GetNativeRecorder() + { + return m_native_recorder; + } + + static void InteropInvoke(IntPtr mwtPtr, IntPtr contextPtr, UInt32 action, float probability, IntPtr uniqueKeyPtr) + { + GCHandle mwtHandle = (GCHandle)mwtPtr; + RecorderCallback<Ctx>^ callback = (RecorderCallback<Ctx>^)mwtHandle.Target; + + GCHandle contextHandle = (GCHandle)contextPtr; + Ctx context = (Ctx)contextHandle.Target; + + GCHandle uniqueKeyHandle = (GCHandle)uniqueKeyPtr; + String^ uniqueKey = (String^)uniqueKeyHandle.Target; + + callback->InvokeRecorderCallback(context, action, probability, uniqueKey); + } + +private: + ClrRecorderCallback^ recorderCallback; + +private: + NativeRecorder* m_native_recorder; +}; + +// Triggers callback to the Recorder instance to record interaction data +generic <class Ctx> +public ref class ScorerCallback abstract +{ +internal: + virtual List<float>^ InvokeScorerCallback(Ctx context) = 0; + + ScorerCallback() + { + scorerCallback = gcnew ClrScorerCallback(&ScorerCallback<Ctx>::InteropInvoke); + IntPtr scorerCallbackPtr = Marshal::GetFunctionPointerForDelegate(scorerCallback); + Native_Scorer_Callback* callback = static_cast<Native_Scorer_Callback*>(scorerCallbackPtr.ToPointer()); + m_native_scorer = new NativeScorer(callback); + } + + ~ScorerCallback() + { + delete m_native_scorer; + } + + NativeScorer* GetNativeScorer() + { + return m_native_scorer; + } + + static void InteropInvoke(IntPtr callbackPtr, IntPtr contextPtr, 
IntPtr scoresPtr, IntPtr sizePtr) + { + GCHandle callbackHandle = (GCHandle)callbackPtr; + ScorerCallback<Ctx>^ callback = (ScorerCallback<Ctx>^)callbackHandle.Target; + + GCHandle contextHandle = (GCHandle)contextPtr; + Ctx context = (Ctx)contextHandle.Target; + + List<float>^ scoreList = callback->InvokeScorerCallback(context); + + if (scoreList == nullptr || scoreList->Count == 0) + { + return; + } + + u32* num_scores = (u32*)sizePtr.ToPointer(); + *num_scores = (u32)scoreList->Count; + + float* scores = new float[*num_scores]; + for (u32 i = 0; i < *num_scores; i++) + { + scores[i] = scoreList[i]; + } + + float** native_scores = (float**)scoresPtr.ToPointer(); + *native_scores = scores; + } + +private: + ClrScorerCallback^ scorerCallback; + +private: + NativeScorer* m_native_scorer; +}; + +// Triggers callback to the Context instance to perform ToString() operation +generic <class Ctx> where Ctx : IStringContext +public ref class ToStringCallback +{ +internal: + ToStringCallback() + { + toStringCallback = gcnew ClrToStringCallback(&ToStringCallback<Ctx>::InteropInvoke); + IntPtr toStringCallbackPtr = Marshal::GetFunctionPointerForDelegate(toStringCallback); + m_callback = static_cast<Native_To_String_Callback*>(toStringCallbackPtr.ToPointer()); + } + + Native_To_String_Callback* GetCallback() + { + return m_callback; + } + + static void InteropInvoke(IntPtr contextPtr, IntPtr stringPtr) + { + GCHandle contextHandle = (GCHandle)contextPtr; + Ctx context = (Ctx)contextHandle.Target; + + string* out_string = (string*)stringPtr.ToPointer(); + *out_string = marshal_as<string>(context->ToString()); + } + +private: + ClrToStringCallback^ toStringCallback; + +private: + Native_To_String_Callback* m_callback; +}; + }
\ No newline at end of file diff --git a/explore/explore.cpp b/explore/explore.cpp index ebc1c14e..f44ac353 100644 --- a/explore/explore.cpp +++ b/explore/explore.cpp @@ -1,73 +1,73 @@ -// explore.cpp : Timing code to measure performance of MWT Explorer library
-
-#include "MWTExplorer.h"
-#include <chrono>
-#include <tuple>
-#include <iostream>
-
-using namespace std;
-using namespace std::chrono;
-
-using namespace MultiWorldTesting;
-
-class MySimplePolicy : public IPolicy<SimpleContext>
-{
-public:
- u32 Choose_Action(SimpleContext& context)
- {
- return (u32)1;
- }
-};
-
-const u32 num_actions = 10;
-
-void Clock_Explore()
-{
- float epsilon = .2f;
- string unique_key = "key";
- int num_features = 1000;
- int num_iter = 10000;
- int num_warmup = 100;
- int num_interactions = 1;
-
- // pre-create features
- vector<Feature> features;
- for (int i = 0; i < num_features; i++)
- {
- Feature f = {0.5, i+1};
- features.push_back(f);
- }
-
- long long time_init = 0, time_choose = 0;
- for (int iter = 0; iter < num_iter + num_warmup; iter++)
- {
- high_resolution_clock::time_point t1 = high_resolution_clock::now();
- StringRecorder<SimpleContext> recorder;
- MwtExplorer<SimpleContext> mwt("test", recorder);
- MySimplePolicy default_policy;
- EpsilonGreedyExplorer<SimpleContext> explorer(default_policy, epsilon, num_actions);
- high_resolution_clock::time_point t2 = high_resolution_clock::now();
- time_init += iter < num_warmup ? 0 : duration_cast<chrono::microseconds>(t2 - t1).count();
-
- t1 = high_resolution_clock::now();
- SimpleContext appContext(features);
- for (int i = 0; i < num_interactions; i++)
- {
- mwt.Choose_Action(explorer, unique_key, appContext);
- }
- t2 = high_resolution_clock::now();
- time_choose += iter < num_warmup ? 0 : duration_cast<chrono::microseconds>(t2 - t1).count();
- }
-
- cout << "# iterations: " << num_iter << ", # interactions: " << num_interactions << ", # context features: " << num_features << endl;
- cout << "--- PER ITERATION ---" << endl;
- cout << "Init: " << (double)time_init / num_iter << " micro" << endl;
- cout << "Choose Action: " << (double)time_choose / (num_iter * num_interactions) << " micro" << endl;
- cout << "--- TOTAL TIME ---: " << (time_init + time_choose) << " micro" << endl;
-}
-
-int main(int argc, char* argv[])
-{
- Clock_Explore();
- return 0;
-}
+// explore.cpp : Timing code to measure performance of MWT Explorer library + +#include "MWTExplorer.h" +#include <chrono> +#include <tuple> +#include <iostream> + +using namespace std; +using namespace std::chrono; + +using namespace MultiWorldTesting; + +class MySimplePolicy : public IPolicy<SimpleContext> +{ +public: + u32 Choose_Action(SimpleContext& context) + { + return (u32)1; + } +}; + +const u32 num_actions = 10; + +void Clock_Explore() +{ + float epsilon = .2f; + string unique_key = "key"; + int num_features = 1000; + int num_iter = 10000; + int num_warmup = 100; + int num_interactions = 1; + + // pre-create features + vector<Feature> features; + for (int i = 0; i < num_features; i++) + { + Feature f = {0.5, i+1}; + features.push_back(f); + } + + long long time_init = 0, time_choose = 0; + for (int iter = 0; iter < num_iter + num_warmup; iter++) + { + high_resolution_clock::time_point t1 = high_resolution_clock::now(); + StringRecorder<SimpleContext> recorder; + MwtExplorer<SimpleContext> mwt("test", recorder); + MySimplePolicy default_policy; + EpsilonGreedyExplorer<SimpleContext> explorer(default_policy, epsilon, num_actions); + high_resolution_clock::time_point t2 = high_resolution_clock::now(); + time_init += iter < num_warmup ? 0 : duration_cast<chrono::microseconds>(t2 - t1).count(); + + t1 = high_resolution_clock::now(); + SimpleContext appContext(features); + for (int i = 0; i < num_interactions; i++) + { + mwt.Choose_Action(explorer, unique_key, appContext); + } + t2 = high_resolution_clock::now(); + time_choose += iter < num_warmup ? 
0 : duration_cast<chrono::microseconds>(t2 - t1).count(); + } + + cout << "# iterations: " << num_iter << ", # interactions: " << num_interactions << ", # context features: " << num_features << endl; + cout << "--- PER ITERATION ---" << endl; + cout << "Init: " << (double)time_init / num_iter << " micro" << endl; + cout << "Choose Action: " << (double)time_choose / (num_iter * num_interactions) << " micro" << endl; + cout << "--- TOTAL TIME ---: " << (time_init + time_choose) << " micro" << endl; +} + +int main(int argc, char* argv[]) +{ + Clock_Explore(); + return 0; +} diff --git a/explore/static/MWTExplorer.h b/explore/static/MWTExplorer.h index dd5e3044..4c7009b2 100644 --- a/explore/static/MWTExplorer.h +++ b/explore/static/MWTExplorer.h @@ -1,669 +1,669 @@ -//
-// Main interface for clients of the Multiworld testing (MWT) service.
-//
-
-#pragma once
-
-#include <stdexcept>
-#include <float.h>
-#include <math.h>
-#include <stdio.h>
-#include <string.h>
-#include <vector>
-#include <utility>
-#include <memory>
-#include <limits.h>
-#include <tuple>
-
-#ifdef MANAGED_CODE
-#define PORTING_INTERFACE public
-#define MWT_NAMESPACE namespace NativeMultiWorldTesting
-#else
-#define PORTING_INTERFACE private
-#define MWT_NAMESPACE namespace MultiWorldTesting
-#endif
-
-using namespace std;
-
-#include "utility.h"
-
-/** \defgroup MultiWorldTestingCpp
-\brief C++ implementation, for sample usage see: https://github.com/sidsen/vowpal_wabbit/blob/v0/explore/explore_sample.cpp
-*/
-
-/*!
-* \addtogroup MultiWorldTestingCpp
-* @{
-*/
-
-//! Interface for C++ version of Multiworld Testing library.
-//! For sample usage see: https://github.com/sidsen/vowpal_wabbit/blob/v0/explore/explore_sample.cpp
-MWT_NAMESPACE {
-
-// Forward declarations
-template <class Ctx>
-class IRecorder;
-template <class Ctx>
-class IExplorer;
-
-///
-/// The top-level MwtExplorer class. Using this enables principled and efficient exploration
-/// over a set of possible actions, and ensures that the right bits are recorded.
-///
-template <class Ctx>
-class MwtExplorer
-{
-public:
- ///
- /// Constructor
- ///
- /// @param appid This should be unique to your experiment or you risk nasty correlation bugs.
- /// @param recorder A user-specified class for recording the appropriate bits for use in evaluation and learning.
- ///
- MwtExplorer(std::string app_id, IRecorder<Ctx>& recorder) : m_recorder(recorder)
- {
- m_app_id = HashUtils::Compute_Id_Hash(app_id);
- }
-
- ///
- /// Chooses an action by invoking an underlying exploration algorithm. This should be a
- /// drop-in replacement for any existing policy function.
- ///
- /// @param explorer An existing exploration algorithm (one of the below) which uses the default policy as a callback.
- /// @param unique_key A unique identifier for the experimental unit. This could be a user id, a session id, etc..
- /// @param context The context upon which a decision is made. See SimpleContext below for an example.
- ///
- u32 Choose_Action(IExplorer<Ctx>& explorer, string unique_key, Ctx& context)
- {
- u64 seed = HashUtils::Compute_Id_Hash(unique_key);
-
- std::tuple<u32, float, bool> action_probability_log_tuple = explorer.Choose_Action(seed + m_app_id, context);
-
- u32 action = std::get<0>(action_probability_log_tuple);
- float prob = std::get<1>(action_probability_log_tuple);
-
- if (std::get<2>(action_probability_log_tuple))
- {
- m_recorder.Record(context, action, prob, unique_key);
- }
-
- return action;
- }
-
-private:
- u64 m_app_id;
- IRecorder<Ctx>& m_recorder;
-};
-
-///
-/// Exposes a method to record exploration data based on generic contexts. Exploration data
-/// is specified as a set of tuples <context, action, probability, key> as described below. An
-/// application passes an IRecorder object to the @MwtExplorer constructor. See
-/// @StringRecorder for a sample IRecorder object.
-///
-template <class Ctx>
-class IRecorder
-{
-public:
- ///
- /// Records the exploration data associated with a given decision.
- ///
- /// @param context A user-defined context for the decision
- /// @param action The action chosen by an exploration algorithm given context
- /// @param probability The probability the exploration algorithm chose said action
- /// @param unique_key A user-defined unique identifer for the decision
- ///
- virtual void Record(Ctx& context, u32 action, float probability, string unique_key) = 0;
-};
-
-///
-/// Exposes a method to choose an action given a generic context, and obtain the relevant
-/// exploration bits. Invokes IPolicy::Choose_Action internally. Do not implement this
-/// interface yourself: instead, use the various exploration algorithms below, which
-/// implement it for you.
-///
-template <class Ctx>
-class IExplorer
-{
-public:
- ///
- /// Determines the action to take and the probability with which it was chosen, for a
- /// given context.
- ///
- /// @param salted_seed A PRG seed based on a unique id information provided by the user
- /// @param context A user-defined context for the decision
- /// @returns The action to take, the probability it was chosen, and a flag indicating
- /// whether to record this decision
- ///
- virtual std::tuple<u32, float, bool> Choose_Action(u64 salted_seed, Ctx& context) = 0;
-};
-
-///
-/// Exposes a method to choose an action given a generic context. IPolicy objects are
-/// passed to (and invoked by) exploration algorithms to specify the default policy behavior.
-///
-template <class Ctx>
-class IPolicy
-{
-public:
- ///
- /// Determines the action to take for a given context.
- ///
- /// @param context A user-defined context for the decision
- /// @returns The action to take (1-based index)
- ///
- virtual u32 Choose_Action(Ctx& context) = 0;
-};
-
///
/// Exposes a method for specifying a score (weight) for each action given a generic context.
///
template <class Ctx>
class IScorer
{
public:
    ///
    /// Virtual destructor so that a scorer can be destroyed safely through a pointer
    /// or reference to this interface.
    ///
    virtual ~IScorer() { }

    ///
    /// Determines the score of each action for a given context.
    ///
    /// @param context  A user-defined context for the decision
    /// @returns A vector with one score per action. Action ids are 1-based, so element
    ///          [i] of the vector holds the score of action i + 1.
    ///
    virtual std::vector<float> Score_Actions(Ctx& context) = 0;
};
-
-///
-/// A sample recorder class that converts the exploration tuple into string format.
-///
-template <class Ctx>
-struct StringRecorder : public IRecorder<Ctx>
-{
- void Record(Ctx& context, u32 action, float probability, string unique_key)
- {
- // Implicitly enforce To_String() API on the context
- m_recording.append(to_string((unsigned long)action));
- m_recording.append(" ", 1);
- m_recording.append(unique_key);
- m_recording.append(" ", 1);
-
- char prob_str[10] = { 0 };
- NumberUtils::Float_To_String(probability, prob_str);
- m_recording.append(prob_str);
-
- m_recording.append(" | ", 3);
- m_recording.append(context.To_String());
- m_recording.append("\n");
- }
-
- // Gets the content of the recording so far as a string and optionally clears internal content.
- string Get_Recording(bool flush = true)
- {
- if (!flush)
- {
- return m_recording;
- }
- string recording = m_recording;
- m_recording.clear();
- return recording;
- }
-
-private:
- string m_recording;
-};
-
-///
-/// Represents a feature in a sparse array.
-///
-struct Feature
-{
- float Value;
- u32 Id;
-
- bool operator==(Feature other_feature)
- {
- return Id == other_feature.Id;
- }
-};
-
-///
-/// A sample context class that stores a vector of Features.
-///
-class SimpleContext
-{
-public:
- SimpleContext(vector<Feature>& features) :
- m_features(features)
- { }
-
- vector<Feature>& Get_Features()
- {
- return m_features;
- }
-
- string To_String()
- {
- string out_string;
- char feature_str[35] = { 0 };
- for (size_t i = 0; i < m_features.size(); i++)
- {
- int chars;
- if (i == 0)
- {
- chars = sprintf(feature_str, "%d:", m_features[i].Id);
- }
- else
- {
- chars = sprintf(feature_str, " %d:", m_features[i].Id);
- }
- NumberUtils::print_float(feature_str + chars, m_features[i].Value);
- out_string.append(feature_str);
- }
- return out_string;
- }
-
-private:
- vector<Feature>& m_features;
-};
-
-///
-/// The epsilon greedy exploration algorithm. This is a good choice if you have no idea
-/// which actions should be preferred. Epsilon greedy is also computationally cheap.
-///
-template <class Ctx>
-class EpsilonGreedyExplorer : public IExplorer<Ctx>
-{
-public:
- ///
- /// The constructor is the only public member, because this should be used with the MwtExplorer.
- ///
- /// @param default_policy A default function which outputs an action given a context.
- /// @param epsilon The probability of a random exploration.
- /// @param num_actions The number of actions to randomize over.
- ///
- EpsilonGreedyExplorer(IPolicy<Ctx>& default_policy, float epsilon, u32 num_actions) :
- m_default_policy(default_policy), m_epsilon(epsilon), m_num_actions(num_actions)
- {
- if (m_num_actions < 1)
- {
- throw std::invalid_argument("Number of actions must be at least 1.");
- }
-
- if (m_epsilon < 0 || m_epsilon > 1)
- {
- throw std::invalid_argument("Epsilon must be between 0 and 1.");
- }
- }
-
- ~EpsilonGreedyExplorer()
- {
- }
-
-private:
- std::tuple<u32, float, bool> Choose_Action(u64 salted_seed, Ctx& context)
- {
- PRG::prg random_generator(salted_seed);
-
- // Invoke the default policy function to get the action
- u32 chosen_action = m_default_policy.Choose_Action(context);
-
- if (chosen_action == 0 || chosen_action > m_num_actions)
- {
- throw std::invalid_argument("Action chosen by default policy is not within valid range.");
- }
-
- float action_probability = 0.f;
- float base_probability = m_epsilon / m_num_actions; // uniform probability
-
- // TODO: check this random generation
- if (random_generator.Uniform_Unit_Interval() < 1.f - m_epsilon)
- {
- action_probability = 1.f - m_epsilon + base_probability;
- }
- else
- {
- // Get uniform random action ID
- u32 actionId = random_generator.Uniform_Int(1, m_num_actions);
-
- if (actionId == chosen_action)
- {
- // IF it matches the one chosen by the default policy
- // then increase the probability
- action_probability = 1.f - m_epsilon + base_probability;
- }
- else
- {
- // Otherwise it's just the uniform probability
- action_probability = base_probability;
- }
- chosen_action = actionId;
- }
-
- return std::tuple<u32, float, bool>(chosen_action, action_probability, true);
- }
-
-private:
- IPolicy<Ctx>& m_default_policy;
- float m_epsilon;
- u32 m_num_actions;
-
-private:
- friend class MwtExplorer<Ctx>;
-};
-
-///
-/// In some cases, different actions have a different scores, and you would prefer to
-/// choose actions with large scores. Softmax allows you to do that.
-///
-template <class Ctx>
-class SoftmaxExplorer : public IExplorer<Ctx>
-{
-public:
- ///
- /// The constructor is the only public member, because this should be used with the MwtExplorer.
- ///
- /// @param default_scorer A function which outputs a score for each action.
- /// @param lambda lambda = 0 implies uniform distribution. Large lambda is equivalent to a max.
- /// @param num_actions The number of actions to randomize over.
- ///
- SoftmaxExplorer(IScorer<Ctx>& default_scorer, float lambda, u32 num_actions) :
- m_default_scorer(default_scorer), m_lambda(lambda), m_num_actions(num_actions)
- {
- if (m_num_actions < 1)
- {
- throw std::invalid_argument("Number of actions must be at least 1.");
- }
- }
-
-private:
- std::tuple<u32, float, bool> Choose_Action(u64 salted_seed, Ctx& context)
- {
- PRG::prg random_generator(salted_seed);
-
- // Invoke the default scorer function
- vector<float> scores = m_default_scorer.Score_Actions(context);
- u32 num_scores = (u32)scores.size();
- if (num_scores != m_num_actions)
- {
- throw std::invalid_argument("The number of scores returned by the scorer must equal number of actions");
- }
-
- u32 i = 0;
-
- float max_score = -FLT_MAX;
- for (i = 0; i < num_scores; i++)
- {
- if (max_score < scores[i])
- {
- max_score = scores[i];
- }
- }
-
- // Create a normalized exponential distribution based on the returned scores
- for (i = 0; i < num_scores; i++)
- {
- scores[i] = exp(m_lambda * (scores[i] - max_score));
- }
-
- // Create a discrete_distribution based on the returned weights. This class handles the
- // case where the sum of the weights is < or > 1, by normalizing agains the sum.
- float total = 0.f;
- for (size_t i = 0; i < num_scores; i++)
- total += scores[i];
-
- float draw = random_generator.Uniform_Unit_Interval();
-
- float sum = 0.f;
- float action_probability = 0.f;
- u32 action_index = num_scores - 1;
- for (u32 i = 0; i < num_scores; i++)
- {
- scores[i] = scores[i] / total;
- sum += scores[i];
- if (sum > draw)
- {
- action_index = i;
- action_probability = scores[i];
- break;
- }
- }
-
- // action id is one-based
- return std::tuple<u32, float, bool>(action_index + 1, action_probability, true);
- }
-
-private:
- IScorer<Ctx>& m_default_scorer;
- float m_lambda;
- u32 m_num_actions;
-
-private:
- friend class MwtExplorer<Ctx>;
-};
-
-///
-/// GenericExplorer provides complete flexibility. You can create any
-/// distribution over actions desired, and it will draw from that.
-///
-template <class Ctx>
-class GenericExplorer : public IExplorer<Ctx>
-{
-public:
- ///
- /// The constructor is the only public member, because this should be used with the MwtExplorer.
- ///
- /// @param default_scorer A function which outputs the probability of each action.
- /// @param num_actions The number of actions to randomize over.
- ///
- GenericExplorer(IScorer<Ctx>& default_scorer, u32 num_actions) :
- m_default_scorer(default_scorer), m_num_actions(num_actions)
- {
- if (m_num_actions < 1)
- {
- throw std::invalid_argument("Number of actions must be at least 1.");
- }
- }
-
-private:
- std::tuple<u32, float, bool> Choose_Action(u64 salted_seed, Ctx& context)
- {
- PRG::prg random_generator(salted_seed);
-
- // Invoke the default scorer function
- vector<float> weights = m_default_scorer.Score_Actions(context);
- u32 num_weights = (u32)weights.size();
- if (num_weights != m_num_actions)
- {
- throw std::invalid_argument("The number of weights returned by the scorer must equal number of actions");
- }
-
- // Create a discrete_distribution based on the returned weights. This class handles the
- // case where the sum of the weights is < or > 1, by normalizing agains the sum.
- float total = 0.f;
- for (size_t i = 0; i < num_weights; i++)
- {
- if (weights[i] < 0)
- {
- throw std::invalid_argument("Scores must be non-negative.");
- }
- total += weights[i];
- }
- if (total == 0)
- {
- throw std::invalid_argument("At least one score must be positive.");
- }
-
- float draw = random_generator.Uniform_Unit_Interval();
-
- float sum = 0.f;
- float action_probability = 0.f;
- u32 action_index = num_weights - 1;
- for (u32 i = 0; i < num_weights; i++)
- {
- weights[i] = weights[i] / total;
- sum += weights[i];
- if (sum > draw)
- {
- action_index = i;
- action_probability = weights[i];
- break;
- }
- }
-
- // action id is one-based
- return std::tuple<u32, float, bool>(action_index + 1, action_probability, true);
- }
-
-private:
- IScorer<Ctx>& m_default_scorer;
- u32 m_num_actions;
-
-private:
- friend class MwtExplorer<Ctx>;
-};
-
-///
-/// The tau-first explorer collects exactly tau uniform random exploration events, and then
-/// uses the default policy thereafter.
-///
-template <class Ctx>
-class TauFirstExplorer : public IExplorer<Ctx>
-{
-public:
-
- ///
- /// The constructor is the only public member, because this should be used with the MwtExplorer.
- ///
- /// @param default_policy A default policy after randomization finishes.
- /// @param tau The number of events to be uniform over.
- /// @param num_actions The number of actions to randomize over.
- ///
- TauFirstExplorer(IPolicy<Ctx>& default_policy, u32 tau, u32 num_actions) :
- m_default_policy(default_policy), m_tau(tau), m_num_actions(num_actions)
- {
- if (m_num_actions < 1)
- {
- throw std::invalid_argument("Number of actions must be at least 1.");
- }
- }
-
-private:
- std::tuple<u32, float, bool> Choose_Action(u64 salted_seed, Ctx& context)
- {
- PRG::prg random_generator(salted_seed);
-
- u32 chosen_action = 0;
- float action_probability = 0.f;
- bool log_action;
- if (m_tau)
- {
- m_tau--;
- u32 actionId = random_generator.Uniform_Int(1, m_num_actions);
- action_probability = 1.f / m_num_actions;
- chosen_action = actionId;
- log_action = true;
- }
- else
- {
- // Invoke the default policy function to get the action
- chosen_action = m_default_policy.Choose_Action(context);
-
- if (chosen_action == 0 || chosen_action > m_num_actions)
- {
- throw std::invalid_argument("Action chosen by default policy is not within valid range.");
- }
-
- action_probability = 1.f;
- log_action = false;
- }
-
- return std::tuple<u32, float, bool>(chosen_action, action_probability, log_action);
- }
-
-private:
- IPolicy<Ctx>& m_default_policy;
- u32 m_tau;
- u32 m_num_actions;
-
-private:
- friend class MwtExplorer<Ctx>;
-};
-
-///
-/// The Bootstrap explorer randomizes over the actions chosen by a set of default policies.
-/// This performs well statistically but can be computationally expensive.
-///
-template <class Ctx>
-class BootstrapExplorer : public IExplorer<Ctx>
-{
-public:
- ///
- /// The constructor is the only public member, because this should be used with the MwtExplorer.
- ///
- /// @param default_policy_functions A set of default policies to be uniform random over.
- /// The policy pointers must be valid throughout the lifetime of this explorer.
- /// @param num_actions The number of actions to randomize over.
- ///
- BootstrapExplorer(vector<unique_ptr<IPolicy<Ctx>>>& default_policy_functions, u32 num_actions) :
- m_default_policy_functions(default_policy_functions),
- m_num_actions(num_actions)
- {
- m_bags = (u32)default_policy_functions.size();
- if (m_num_actions < 1)
- {
- throw std::invalid_argument("Number of actions must be at least 1.");
- }
-
- if (m_bags < 1)
- {
- throw std::invalid_argument("Number of bags must be at least 1.");
- }
- }
-
-private:
- std::tuple<u32, float, bool> Choose_Action(u64 salted_seed, Ctx& context)
- {
- PRG::prg random_generator(salted_seed);
-
- // Select bag
- u32 chosen_bag = random_generator.Uniform_Int(0, m_bags - 1);
-
- // Invoke the default policy function to get the action
- u32 chosen_action = 0;
- u32 action_from_bag = 0;
- vector<u32> actions_selected;
- for (size_t i = 0; i < m_num_actions; i++)
- {
- actions_selected.push_back(0);
- }
-
- // Invoke the default policy function to get the action
- for (u32 current_bag = 0; current_bag < m_bags; current_bag++)
- {
- action_from_bag = m_default_policy_functions[current_bag]->Choose_Action(context);
-
- if (action_from_bag == 0 || action_from_bag > m_num_actions)
- {
- throw std::invalid_argument("Action chosen by default policy is not within valid range.");
- }
-
- if (current_bag == chosen_bag)
- {
- chosen_action = action_from_bag;
- }
- //this won't work if actions aren't 0 to Count
- actions_selected[action_from_bag - 1]++; // action id is one-based
- }
- float action_probability = (float)actions_selected[chosen_action - 1] / m_bags; // action id is one-based
-
- return std::tuple<u32, float, bool>(chosen_action, action_probability, true);
- }
-
-private:
- vector<unique_ptr<IPolicy<Ctx>>>& m_default_policy_functions;
- u32 m_bags;
- u32 m_num_actions;
-
-private:
- friend class MwtExplorer<Ctx>;
-};
-} // End namespace MultiWorldTestingCpp
-/*! @} End of Doxygen Groups*/
+// +// Main interface for clients of the Multiworld testing (MWT) service. +// + +#pragma once + +#include <stdexcept> +#include <float.h> +#include <math.h> +#include <stdio.h> +#include <string.h> +#include <vector> +#include <utility> +#include <memory> +#include <limits.h> +#include <tuple> + +#ifdef MANAGED_CODE +#define PORTING_INTERFACE public +#define MWT_NAMESPACE namespace NativeMultiWorldTesting +#else +#define PORTING_INTERFACE private +#define MWT_NAMESPACE namespace MultiWorldTesting +#endif + +using namespace std; + +#include "utility.h" + +/** \defgroup MultiWorldTestingCpp +\brief C++ implementation, for sample usage see: https://github.com/sidsen/vowpal_wabbit/blob/v0/explore/explore_sample.cpp +*/ + +/*! +* \addtogroup MultiWorldTestingCpp +* @{ +*/ + +//! Interface for C++ version of Multiworld Testing library. +//! For sample usage see: https://github.com/sidsen/vowpal_wabbit/blob/v0/explore/explore_sample.cpp +MWT_NAMESPACE { + +// Forward declarations +template <class Ctx> +class IRecorder; +template <class Ctx> +class IExplorer; + +/// +/// The top-level MwtExplorer class. Using this enables principled and efficient exploration +/// over a set of possible actions, and ensures that the right bits are recorded. +/// +template <class Ctx> +class MwtExplorer +{ +public: + /// + /// Constructor + /// + /// @param appid This should be unique to your experiment or you risk nasty correlation bugs. + /// @param recorder A user-specified class for recording the appropriate bits for use in evaluation and learning. + /// + MwtExplorer(std::string app_id, IRecorder<Ctx>& recorder) : m_recorder(recorder) + { + m_app_id = HashUtils::Compute_Id_Hash(app_id); + } + + /// + /// Chooses an action by invoking an underlying exploration algorithm. This should be a + /// drop-in replacement for any existing policy function. + /// + /// @param explorer An existing exploration algorithm (one of the below) which uses the default policy as a callback. 
+ /// @param unique_key A unique identifier for the experimental unit. This could be a user id, a session id, etc.. + /// @param context The context upon which a decision is made. See SimpleContext below for an example. + /// + u32 Choose_Action(IExplorer<Ctx>& explorer, string unique_key, Ctx& context) + { + u64 seed = HashUtils::Compute_Id_Hash(unique_key); + + std::tuple<u32, float, bool> action_probability_log_tuple = explorer.Choose_Action(seed + m_app_id, context); + + u32 action = std::get<0>(action_probability_log_tuple); + float prob = std::get<1>(action_probability_log_tuple); + + if (std::get<2>(action_probability_log_tuple)) + { + m_recorder.Record(context, action, prob, unique_key); + } + + return action; + } + +private: + u64 m_app_id; + IRecorder<Ctx>& m_recorder; +}; + +/// +/// Exposes a method to record exploration data based on generic contexts. Exploration data +/// is specified as a set of tuples <context, action, probability, key> as described below. An +/// application passes an IRecorder object to the @MwtExplorer constructor. See +/// @StringRecorder for a sample IRecorder object. +/// +template <class Ctx> +class IRecorder +{ +public: + /// + /// Records the exploration data associated with a given decision. + /// + /// @param context A user-defined context for the decision + /// @param action The action chosen by an exploration algorithm given context + /// @param probability The probability the exploration algorithm chose said action + /// @param unique_key A user-defined unique identifer for the decision + /// + virtual void Record(Ctx& context, u32 action, float probability, string unique_key) = 0; +}; + +/// +/// Exposes a method to choose an action given a generic context, and obtain the relevant +/// exploration bits. Invokes IPolicy::Choose_Action internally. Do not implement this +/// interface yourself: instead, use the various exploration algorithms below, which +/// implement it for you. 
+/// +template <class Ctx> +class IExplorer +{ +public: + /// + /// Determines the action to take and the probability with which it was chosen, for a + /// given context. + /// + /// @param salted_seed A PRG seed based on a unique id information provided by the user + /// @param context A user-defined context for the decision + /// @returns The action to take, the probability it was chosen, and a flag indicating + /// whether to record this decision + /// + virtual std::tuple<u32, float, bool> Choose_Action(u64 salted_seed, Ctx& context) = 0; +}; + +/// +/// Exposes a method to choose an action given a generic context. IPolicy objects are +/// passed to (and invoked by) exploration algorithms to specify the default policy behavior. +/// +template <class Ctx> +class IPolicy +{ +public: + /// + /// Determines the action to take for a given context. + /// + /// @param context A user-defined context for the decision + /// @returns The action to take (1-based index) + /// + virtual u32 Choose_Action(Ctx& context) = 0; +}; + +/// +/// Exposes a method for specifying a score (weight) for each action given a generic context. +/// +template <class Ctx> +class IScorer +{ +public: + /// + /// Determines the score of each action for a given context. + /// + /// @param context A user-defined context for the decision + /// @returns A vector of scores indexed by action (1-based) + /// + virtual vector<float> Score_Actions(Ctx& context) = 0; +}; + +/// +/// A sample recorder class that converts the exploration tuple into string format. 
+/// +template <class Ctx> +struct StringRecorder : public IRecorder<Ctx> +{ + void Record(Ctx& context, u32 action, float probability, string unique_key) + { + // Implicitly enforce To_String() API on the context + m_recording.append(to_string((unsigned long)action)); + m_recording.append(" ", 1); + m_recording.append(unique_key); + m_recording.append(" ", 1); + + char prob_str[10] = { 0 }; + NumberUtils::Float_To_String(probability, prob_str); + m_recording.append(prob_str); + + m_recording.append(" | ", 3); + m_recording.append(context.To_String()); + m_recording.append("\n"); + } + + // Gets the content of the recording so far as a string and optionally clears internal content. + string Get_Recording(bool flush = true) + { + if (!flush) + { + return m_recording; + } + string recording = m_recording; + m_recording.clear(); + return recording; + } + +private: + string m_recording; +}; + +/// +/// Represents a feature in a sparse array. +/// +struct Feature +{ + float Value; + u32 Id; + + bool operator==(Feature other_feature) + { + return Id == other_feature.Id; + } +}; + +/// +/// A sample context class that stores a vector of Features. +/// +class SimpleContext +{ +public: + SimpleContext(vector<Feature>& features) : + m_features(features) + { } + + vector<Feature>& Get_Features() + { + return m_features; + } + + string To_String() + { + string out_string; + char feature_str[35] = { 0 }; + for (size_t i = 0; i < m_features.size(); i++) + { + int chars; + if (i == 0) + { + chars = sprintf(feature_str, "%d:", m_features[i].Id); + } + else + { + chars = sprintf(feature_str, " %d:", m_features[i].Id); + } + NumberUtils::print_float(feature_str + chars, m_features[i].Value); + out_string.append(feature_str); + } + return out_string; + } + +private: + vector<Feature>& m_features; +}; + +/// +/// The epsilon greedy exploration algorithm. This is a good choice if you have no idea +/// which actions should be preferred. Epsilon greedy is also computationally cheap. 
+/// +template <class Ctx> +class EpsilonGreedyExplorer : public IExplorer<Ctx> +{ +public: + /// + /// The constructor is the only public member, because this should be used with the MwtExplorer. + /// + /// @param default_policy A default function which outputs an action given a context. + /// @param epsilon The probability of a random exploration. + /// @param num_actions The number of actions to randomize over. + /// + EpsilonGreedyExplorer(IPolicy<Ctx>& default_policy, float epsilon, u32 num_actions) : + m_default_policy(default_policy), m_epsilon(epsilon), m_num_actions(num_actions) + { + if (m_num_actions < 1) + { + throw std::invalid_argument("Number of actions must be at least 1."); + } + + if (m_epsilon < 0 || m_epsilon > 1) + { + throw std::invalid_argument("Epsilon must be between 0 and 1."); + } + } + + ~EpsilonGreedyExplorer() + { + } + +private: + std::tuple<u32, float, bool> Choose_Action(u64 salted_seed, Ctx& context) + { + PRG::prg random_generator(salted_seed); + + // Invoke the default policy function to get the action + u32 chosen_action = m_default_policy.Choose_Action(context); + + if (chosen_action == 0 || chosen_action > m_num_actions) + { + throw std::invalid_argument("Action chosen by default policy is not within valid range."); + } + + float action_probability = 0.f; + float base_probability = m_epsilon / m_num_actions; // uniform probability + + // TODO: check this random generation + if (random_generator.Uniform_Unit_Interval() < 1.f - m_epsilon) + { + action_probability = 1.f - m_epsilon + base_probability; + } + else + { + // Get uniform random action ID + u32 actionId = random_generator.Uniform_Int(1, m_num_actions); + + if (actionId == chosen_action) + { + // IF it matches the one chosen by the default policy + // then increase the probability + action_probability = 1.f - m_epsilon + base_probability; + } + else + { + // Otherwise it's just the uniform probability + action_probability = base_probability; + } + chosen_action = 
actionId; + } + + return std::tuple<u32, float, bool>(chosen_action, action_probability, true); + } + +private: + IPolicy<Ctx>& m_default_policy; + float m_epsilon; + u32 m_num_actions; + +private: + friend class MwtExplorer<Ctx>; +}; + +/// +/// In some cases, different actions have a different scores, and you would prefer to +/// choose actions with large scores. Softmax allows you to do that. +/// +template <class Ctx> +class SoftmaxExplorer : public IExplorer<Ctx> +{ +public: + /// + /// The constructor is the only public member, because this should be used with the MwtExplorer. + /// + /// @param default_scorer A function which outputs a score for each action. + /// @param lambda lambda = 0 implies uniform distribution. Large lambda is equivalent to a max. + /// @param num_actions The number of actions to randomize over. + /// + SoftmaxExplorer(IScorer<Ctx>& default_scorer, float lambda, u32 num_actions) : + m_default_scorer(default_scorer), m_lambda(lambda), m_num_actions(num_actions) + { + if (m_num_actions < 1) + { + throw std::invalid_argument("Number of actions must be at least 1."); + } + } + +private: + std::tuple<u32, float, bool> Choose_Action(u64 salted_seed, Ctx& context) + { + PRG::prg random_generator(salted_seed); + + // Invoke the default scorer function + vector<float> scores = m_default_scorer.Score_Actions(context); + u32 num_scores = (u32)scores.size(); + if (num_scores != m_num_actions) + { + throw std::invalid_argument("The number of scores returned by the scorer must equal number of actions"); + } + + u32 i = 0; + + float max_score = -FLT_MAX; + for (i = 0; i < num_scores; i++) + { + if (max_score < scores[i]) + { + max_score = scores[i]; + } + } + + // Create a normalized exponential distribution based on the returned scores + for (i = 0; i < num_scores; i++) + { + scores[i] = exp(m_lambda * (scores[i] - max_score)); + } + + // Create a discrete_distribution based on the returned weights. 
This class handles the + // case where the sum of the weights is < or > 1, by normalizing agains the sum. + float total = 0.f; + for (size_t i = 0; i < num_scores; i++) + total += scores[i]; + + float draw = random_generator.Uniform_Unit_Interval(); + + float sum = 0.f; + float action_probability = 0.f; + u32 action_index = num_scores - 1; + for (u32 i = 0; i < num_scores; i++) + { + scores[i] = scores[i] / total; + sum += scores[i]; + if (sum > draw) + { + action_index = i; + action_probability = scores[i]; + break; + } + } + + // action id is one-based + return std::tuple<u32, float, bool>(action_index + 1, action_probability, true); + } + +private: + IScorer<Ctx>& m_default_scorer; + float m_lambda; + u32 m_num_actions; + +private: + friend class MwtExplorer<Ctx>; +}; + +/// +/// GenericExplorer provides complete flexibility. You can create any +/// distribution over actions desired, and it will draw from that. +/// +template <class Ctx> +class GenericExplorer : public IExplorer<Ctx> +{ +public: + /// + /// The constructor is the only public member, because this should be used with the MwtExplorer. + /// + /// @param default_scorer A function which outputs the probability of each action. + /// @param num_actions The number of actions to randomize over. 
+ /// + GenericExplorer(IScorer<Ctx>& default_scorer, u32 num_actions) : + m_default_scorer(default_scorer), m_num_actions(num_actions) + { + if (m_num_actions < 1) + { + throw std::invalid_argument("Number of actions must be at least 1."); + } + } + +private: + std::tuple<u32, float, bool> Choose_Action(u64 salted_seed, Ctx& context) + { + PRG::prg random_generator(salted_seed); + + // Invoke the default scorer function + vector<float> weights = m_default_scorer.Score_Actions(context); + u32 num_weights = (u32)weights.size(); + if (num_weights != m_num_actions) + { + throw std::invalid_argument("The number of weights returned by the scorer must equal number of actions"); + } + + // Create a discrete_distribution based on the returned weights. This class handles the + // case where the sum of the weights is < or > 1, by normalizing agains the sum. + float total = 0.f; + for (size_t i = 0; i < num_weights; i++) + { + if (weights[i] < 0) + { + throw std::invalid_argument("Scores must be non-negative."); + } + total += weights[i]; + } + if (total == 0) + { + throw std::invalid_argument("At least one score must be positive."); + } + + float draw = random_generator.Uniform_Unit_Interval(); + + float sum = 0.f; + float action_probability = 0.f; + u32 action_index = num_weights - 1; + for (u32 i = 0; i < num_weights; i++) + { + weights[i] = weights[i] / total; + sum += weights[i]; + if (sum > draw) + { + action_index = i; + action_probability = weights[i]; + break; + } + } + + // action id is one-based + return std::tuple<u32, float, bool>(action_index + 1, action_probability, true); + } + +private: + IScorer<Ctx>& m_default_scorer; + u32 m_num_actions; + +private: + friend class MwtExplorer<Ctx>; +}; + +/// +/// The tau-first explorer collects exactly tau uniform random exploration events, and then +/// uses the default policy thereafter. 
+/// +template <class Ctx> +class TauFirstExplorer : public IExplorer<Ctx> +{ +public: + + /// + /// The constructor is the only public member, because this should be used with the MwtExplorer. + /// + /// @param default_policy A default policy after randomization finishes. + /// @param tau The number of events to be uniform over. + /// @param num_actions The number of actions to randomize over. + /// + TauFirstExplorer(IPolicy<Ctx>& default_policy, u32 tau, u32 num_actions) : + m_default_policy(default_policy), m_tau(tau), m_num_actions(num_actions) + { + if (m_num_actions < 1) + { + throw std::invalid_argument("Number of actions must be at least 1."); + } + } + +private: + std::tuple<u32, float, bool> Choose_Action(u64 salted_seed, Ctx& context) + { + PRG::prg random_generator(salted_seed); + + u32 chosen_action = 0; + float action_probability = 0.f; + bool log_action; + if (m_tau) + { + m_tau--; + u32 actionId = random_generator.Uniform_Int(1, m_num_actions); + action_probability = 1.f / m_num_actions; + chosen_action = actionId; + log_action = true; + } + else + { + // Invoke the default policy function to get the action + chosen_action = m_default_policy.Choose_Action(context); + + if (chosen_action == 0 || chosen_action > m_num_actions) + { + throw std::invalid_argument("Action chosen by default policy is not within valid range."); + } + + action_probability = 1.f; + log_action = false; + } + + return std::tuple<u32, float, bool>(chosen_action, action_probability, log_action); + } + +private: + IPolicy<Ctx>& m_default_policy; + u32 m_tau; + u32 m_num_actions; + +private: + friend class MwtExplorer<Ctx>; +}; + +/// +/// The Bootstrap explorer randomizes over the actions chosen by a set of default policies. +/// This performs well statistically but can be computationally expensive. 
+/// +template <class Ctx> +class BootstrapExplorer : public IExplorer<Ctx> +{ +public: + /// + /// The constructor is the only public member, because this should be used with the MwtExplorer. + /// + /// @param default_policy_functions A set of default policies to be uniform random over. + /// The policy pointers must be valid throughout the lifetime of this explorer. + /// @param num_actions The number of actions to randomize over. + /// + BootstrapExplorer(vector<unique_ptr<IPolicy<Ctx>>>& default_policy_functions, u32 num_actions) : + m_default_policy_functions(default_policy_functions), + m_num_actions(num_actions) + { + m_bags = (u32)default_policy_functions.size(); + if (m_num_actions < 1) + { + throw std::invalid_argument("Number of actions must be at least 1."); + } + + if (m_bags < 1) + { + throw std::invalid_argument("Number of bags must be at least 1."); + } + } + +private: + std::tuple<u32, float, bool> Choose_Action(u64 salted_seed, Ctx& context) + { + PRG::prg random_generator(salted_seed); + + // Select bag + u32 chosen_bag = random_generator.Uniform_Int(0, m_bags - 1); + + // Invoke the default policy function to get the action + u32 chosen_action = 0; + u32 action_from_bag = 0; + vector<u32> actions_selected; + for (size_t i = 0; i < m_num_actions; i++) + { + actions_selected.push_back(0); + } + + // Invoke the default policy function to get the action + for (u32 current_bag = 0; current_bag < m_bags; current_bag++) + { + action_from_bag = m_default_policy_functions[current_bag]->Choose_Action(context); + + if (action_from_bag == 0 || action_from_bag > m_num_actions) + { + throw std::invalid_argument("Action chosen by default policy is not within valid range."); + } + + if (current_bag == chosen_bag) + { + chosen_action = action_from_bag; + } + //this won't work if actions aren't 0 to Count + actions_selected[action_from_bag - 1]++; // action id is one-based + } + float action_probability = (float)actions_selected[chosen_action - 1] / m_bags; // 
action id is one-based + + return std::tuple<u32, float, bool>(chosen_action, action_probability, true); + } + +private: + vector<unique_ptr<IPolicy<Ctx>>>& m_default_policy_functions; + u32 m_bags; + u32 m_num_actions; + +private: + friend class MwtExplorer<Ctx>; +}; +} // End namespace MultiWorldTestingCpp +/*! @} End of Doxygen Groups*/ diff --git a/explore/static/utility.h b/explore/static/utility.h index a6284809..6a2a60d6 100644 --- a/explore/static/utility.h +++ b/explore/static/utility.h @@ -1,286 +1,286 @@ /*******************************************************************/ // Classes declared in this file are intended for internal use only. -/*******************************************************************/
-
-#pragma once
-#include <stdint.h>
-#include <sys/types.h> /* defines size_t */
-
-#ifdef WIN32
-typedef unsigned __int64 u64;
-typedef unsigned __int32 u32;
-typedef unsigned __int16 u16;
-typedef unsigned __int8 u8;
-typedef signed __int64 i64;
-typedef signed __int32 i32;
-typedef signed __int16 i16;
-typedef signed __int8 i8;
-#else
-typedef uint64_t u64;
-typedef uint32_t u32;
-typedef uint16_t u16;
-typedef uint8_t u8;
-typedef int64_t i64;
-typedef int32_t i32;
-typedef int16_t i16;
-typedef int8_t i8;
-#endif
-
-typedef unsigned char byte;
-
-#include <string>
-#include <stdint.h>
-#include <math.h>
-
-/*!
-* \addtogroup MultiWorldTestingCpp
-* @{
-*/
-
-MWT_NAMESPACE {
-
-//
-// MurmurHash3, by Austin Appleby
-//
-// Originals at:
-// http://code.google.com/p/smhasher/source/browse/trunk/MurmurHash3.cpp
-// http://code.google.com/p/smhasher/source/browse/trunk/MurmurHash3.h
-//
-// Notes:
-// 1) this code assumes we can read a 4-byte value from any address
-// without crashing (i.e non aligned access is supported). This is
-// not a problem on Intel/x86/AMD64 machines (including new Macs)
-// 2) It produces different results on little-endian and big-endian machines.
-//
-//-----------------------------------------------------------------------------
-// MurmurHash3 was written by Austin Appleby, and is placed in the public
-// domain. The author hereby disclaims copyright to this source code.
-
-// Note - The x86 and x64 versions do _not_ produce the same results, as the
-// algorithms are optimized for their respective platforms. You can still
-// compile and run any of them on any platform, but your performance with the
-// non-native version will be less than optimal.
-//-----------------------------------------------------------------------------
-
-// Platform-specific functions and macros
-#if defined(_MSC_VER) // Microsoft Visual Studio
-# include <stdint.h>
-
-# include <stdlib.h>
-# define ROTL32(x,y) _rotl(x,y)
-# define BIG_CONSTANT(x) (x)
-
-#else // Other compilers
-# include <stdint.h> /* defines uint32_t etc */
-
- inline uint32_t rotl32(uint32_t x, int8_t r)
- {
- return (x << r) | (x >> (32 - r));
- }
-
-# define ROTL32(x,y) rotl32(x,y)
-# define BIG_CONSTANT(x) (x##LLU)
-
-#endif // !defined(_MSC_VER)
-
-struct murmur_hash {
-
- //-----------------------------------------------------------------------------
- // Block read - if your platform needs to do endian-swapping or can only
- // handle aligned reads, do the conversion here
-private:
- static inline uint32_t getblock(const uint32_t * p, int i)
- {
- return p[i];
- }
-
- //-----------------------------------------------------------------------------
- // Finalization mix - force all bits of a hash block to avalanche
-
- static inline uint32_t fmix(uint32_t h)
- {
- h ^= h >> 16;
- h *= 0x85ebca6b;
- h ^= h >> 13;
- h *= 0xc2b2ae35;
- h ^= h >> 16;
-
- return h;
- }
-
- //-----------------------------------------------------------------------------
-public:
- uint32_t uniform_hash(const void * key, size_t len, uint32_t seed)
- {
- const uint8_t * data = (const uint8_t*)key;
- const int nblocks = (int)len / 4;
-
- uint32_t h1 = seed;
-
- const uint32_t c1 = 0xcc9e2d51;
- const uint32_t c2 = 0x1b873593;
-
- // --- body
- const uint32_t * blocks = (const uint32_t *)(data + nblocks * 4);
-
- for (int i = -nblocks; i; i++) {
- uint32_t k1 = getblock(blocks, i);
-
- k1 *= c1;
- k1 = ROTL32(k1, 15);
- k1 *= c2;
-
- h1 ^= k1;
- h1 = ROTL32(h1, 13);
- h1 = h1 * 5 + 0xe6546b64;
- }
-
- // --- tail
- const uint8_t * tail = (const uint8_t*)(data + nblocks * 4);
-
- uint32_t k1 = 0;
-
- switch (len & 3) {
- case 3: k1 ^= tail[2] << 16;
- case 2: k1 ^= tail[1] << 8;
- case 1: k1 ^= tail[0];
- k1 *= c1; k1 = ROTL32(k1, 15); k1 *= c2; h1 ^= k1;
- }
-
- // --- finalization
- h1 ^= len;
-
- return fmix(h1);
- }
-};
-
-class HashUtils
-{
-public:
- static u64 Compute_Id_Hash(const std::string& unique_id)
- {
- size_t ret = 0;
- const char *p = unique_id.c_str();
- while (*p != '\0')
- if (*p >= '0' && *p <= '9')
- ret = 10 * ret + *(p++) - '0';
- else
- {
- murmur_hash foo;
- return foo.uniform_hash(unique_id.c_str(), unique_id.size(), 0);
- }
- return ret;
- }
-};
-
-const size_t max_int = 100000;
-const float max_float = max_int;
-const float min_float = 0.00001f;
-const size_t max_digits = (size_t) roundf((float) (-log(min_float) / log(10.)));
-
-class NumberUtils
-{
-public:
- static void Float_To_String(float f, char* str)
- {
- int x = (int)f;
- int d = (int)(fabs(f - x) * 100000);
- sprintf(str, "%d.%05d", x, d);
- }
-
- template<bool trailing_zeros>
- static void print_mantissa(char*& begin, float f)
- { // helper for print_float
- char values[10];
- size_t v = (size_t)f;
- size_t digit = 0;
- size_t first_nonzero = 0;
- for (size_t max = 1; max <= v; ++digit)
- {
- size_t max_next = max * 10;
- char v_mod = (char) (v % max_next / max);
- if (!trailing_zeros && v_mod != '\0' && first_nonzero == 0)
- first_nonzero = digit;
- values[digit] = '0' + v_mod;
- max = max_next;
- }
- if (!trailing_zeros)
- for (size_t i = max_digits; i > digit; i--)
- *begin++ = '0';
- while (digit > first_nonzero)
- *begin++ = values[--digit];
- }
-
- static void print_float(char* begin, float f)
- {
- bool sign = false;
- if (f < 0.f)
- sign = true;
- float unsigned_f = fabsf(f);
- if (unsigned_f < max_float && unsigned_f > min_float)
- {
- if (sign)
- *begin++ = '-';
- print_mantissa<true>(begin, unsigned_f);
- unsigned_f -= (size_t)unsigned_f;
- unsigned_f *= max_int;
- if (unsigned_f >= 1.f)
- {
- *begin++ = '.';
- print_mantissa<false>(begin, unsigned_f);
- }
- }
- else if (unsigned_f == 0.)
- *begin++ = '0';
- else
- {
- sprintf(begin, "%g", f);
- return;
- }
- *begin = '\0';
- return;
- }
-};
-
-//A quick implementation similar to drand48 for cross-platform compatibility
-namespace PRG {
- const uint64_t a = 0xeece66d5deece66dULL;
- const uint64_t c = 2147483647;
-
- const int bias = 127 << 23;
-
- union int_float {
- int32_t i;
- float f;
- };
-
- struct prg {
- private:
- uint64_t v;
- public:
- prg() { v = c; }
- prg(uint64_t initial) { v = initial; }
-
- float merand48(uint64_t& initial)
- {
- initial = a * initial + c;
- int_float temp;
- temp.i = ((initial >> 25) & 0x7FFFFF) | bias;
- return temp.f - 1;
- }
-
- float Uniform_Unit_Interval()
- {
- return merand48(v);
- }
-
- uint32_t Uniform_Int(uint32_t low, uint32_t high)
- {
- merand48(v);
- uint32_t ret = low + ((v >> 25) % (high - low + 1));
- return ret;
- }
- };
-}
-}
+/*******************************************************************/ + +#pragma once +#include <stdint.h> +#include <sys/types.h> /* defines size_t */ + +#ifdef WIN32 +typedef unsigned __int64 u64; +typedef unsigned __int32 u32; +typedef unsigned __int16 u16; +typedef unsigned __int8 u8; +typedef signed __int64 i64; +typedef signed __int32 i32; +typedef signed __int16 i16; +typedef signed __int8 i8; +#else +typedef uint64_t u64; +typedef uint32_t u32; +typedef uint16_t u16; +typedef uint8_t u8; +typedef int64_t i64; +typedef int32_t i32; +typedef int16_t i16; +typedef int8_t i8; +#endif + +typedef unsigned char byte; + +#include <string> +#include <stdint.h> +#include <math.h> + +/*! +* \addtogroup MultiWorldTestingCpp +* @{ +*/ + +MWT_NAMESPACE { + +// +// MurmurHash3, by Austin Appleby +// +// Originals at: +// http://code.google.com/p/smhasher/source/browse/trunk/MurmurHash3.cpp +// http://code.google.com/p/smhasher/source/browse/trunk/MurmurHash3.h +// +// Notes: +// 1) this code assumes we can read a 4-byte value from any address +// without crashing (i.e non aligned access is supported). This is +// not a problem on Intel/x86/AMD64 machines (including new Macs) +// 2) It produces different results on little-endian and big-endian machines. +// +//----------------------------------------------------------------------------- +// MurmurHash3 was written by Austin Appleby, and is placed in the public +// domain. The author hereby disclaims copyright to this source code. + +// Note - The x86 and x64 versions do _not_ produce the same results, as the +// algorithms are optimized for their respective platforms. You can still +// compile and run any of them on any platform, but your performance with the +// non-native version will be less than optimal. 
+//----------------------------------------------------------------------------- + +// Platform-specific functions and macros +#if defined(_MSC_VER) // Microsoft Visual Studio +# include <stdint.h> + +# include <stdlib.h> +# define ROTL32(x,y) _rotl(x,y) +# define BIG_CONSTANT(x) (x) + +#else // Other compilers +# include <stdint.h> /* defines uint32_t etc */ + + inline uint32_t rotl32(uint32_t x, int8_t r) + { + return (x << r) | (x >> (32 - r)); + } + +# define ROTL32(x,y) rotl32(x,y) +# define BIG_CONSTANT(x) (x##LLU) + +#endif // !defined(_MSC_VER) + +struct murmur_hash { + + //----------------------------------------------------------------------------- + // Block read - if your platform needs to do endian-swapping or can only + // handle aligned reads, do the conversion here +private: + static inline uint32_t getblock(const uint32_t * p, int i) + { + return p[i]; + } + + //----------------------------------------------------------------------------- + // Finalization mix - force all bits of a hash block to avalanche + + static inline uint32_t fmix(uint32_t h) + { + h ^= h >> 16; + h *= 0x85ebca6b; + h ^= h >> 13; + h *= 0xc2b2ae35; + h ^= h >> 16; + + return h; + } + + //----------------------------------------------------------------------------- +public: + uint32_t uniform_hash(const void * key, size_t len, uint32_t seed) + { + const uint8_t * data = (const uint8_t*)key; + const int nblocks = (int)len / 4; + + uint32_t h1 = seed; + + const uint32_t c1 = 0xcc9e2d51; + const uint32_t c2 = 0x1b873593; + + // --- body + const uint32_t * blocks = (const uint32_t *)(data + nblocks * 4); + + for (int i = -nblocks; i; i++) { + uint32_t k1 = getblock(blocks, i); + + k1 *= c1; + k1 = ROTL32(k1, 15); + k1 *= c2; + + h1 ^= k1; + h1 = ROTL32(h1, 13); + h1 = h1 * 5 + 0xe6546b64; + } + + // --- tail + const uint8_t * tail = (const uint8_t*)(data + nblocks * 4); + + uint32_t k1 = 0; + + switch (len & 3) { + case 3: k1 ^= tail[2] << 16; + case 2: k1 ^= tail[1] << 8; + case 
1: k1 ^= tail[0]; + k1 *= c1; k1 = ROTL32(k1, 15); k1 *= c2; h1 ^= k1; + } + + // --- finalization + h1 ^= len; + + return fmix(h1); + } +}; + +class HashUtils +{ +public: + static u64 Compute_Id_Hash(const std::string& unique_id) + { + size_t ret = 0; + const char *p = unique_id.c_str(); + while (*p != '\0') + if (*p >= '0' && *p <= '9') + ret = 10 * ret + *(p++) - '0'; + else + { + murmur_hash foo; + return foo.uniform_hash(unique_id.c_str(), unique_id.size(), 0); + } + return ret; + } +}; + +const size_t max_int = 100000; +const float max_float = max_int; +const float min_float = 0.00001f; +const size_t max_digits = (size_t) roundf((float) (-log(min_float) / log(10.))); + +class NumberUtils +{ +public: + static void Float_To_String(float f, char* str) + { + int x = (int)f; + int d = (int)(fabs(f - x) * 100000); + sprintf(str, "%d.%05d", x, d); + } + + template<bool trailing_zeros> + static void print_mantissa(char*& begin, float f) + { // helper for print_float + char values[10]; + size_t v = (size_t)f; + size_t digit = 0; + size_t first_nonzero = 0; + for (size_t max = 1; max <= v; ++digit) + { + size_t max_next = max * 10; + char v_mod = (char) (v % max_next / max); + if (!trailing_zeros && v_mod != '\0' && first_nonzero == 0) + first_nonzero = digit; + values[digit] = '0' + v_mod; + max = max_next; + } + if (!trailing_zeros) + for (size_t i = max_digits; i > digit; i--) + *begin++ = '0'; + while (digit > first_nonzero) + *begin++ = values[--digit]; + } + + static void print_float(char* begin, float f) + { + bool sign = false; + if (f < 0.f) + sign = true; + float unsigned_f = fabsf(f); + if (unsigned_f < max_float && unsigned_f > min_float) + { + if (sign) + *begin++ = '-'; + print_mantissa<true>(begin, unsigned_f); + unsigned_f -= (size_t)unsigned_f; + unsigned_f *= max_int; + if (unsigned_f >= 1.f) + { + *begin++ = '.'; + print_mantissa<false>(begin, unsigned_f); + } + } + else if (unsigned_f == 0.) 
+ *begin++ = '0'; + else + { + sprintf(begin, "%g", f); + return; + } + *begin = '\0'; + return; + } +}; + +//A quick implementation similar to drand48 for cross-platform compatibility +namespace PRG { + const uint64_t a = 0xeece66d5deece66dULL; + const uint64_t c = 2147483647; + + const int bias = 127 << 23; + + union int_float { + int32_t i; + float f; + }; + + struct prg { + private: + uint64_t v; + public: + prg() { v = c; } + prg(uint64_t initial) { v = initial; } + + float merand48(uint64_t& initial) + { + initial = a * initial + c; + int_float temp; + temp.i = ((initial >> 25) & 0x7FFFFF) | bias; + return temp.f - 1; + } + + float Uniform_Unit_Interval() + { + return merand48(v); + } + + uint32_t Uniform_Int(uint32_t low, uint32_t high) + { + merand48(v); + uint32_t ret = low + ((v >> 25) % (high - low + 1)); + return ret; + } + }; +} +} /*! @} End of Doxygen Groups*/ diff --git a/explore/tests/MWTExploreTests.h b/explore/tests/MWTExploreTests.h index 447a368f..1c451fb4 100644 --- a/explore/tests/MWTExploreTests.h +++ b/explore/tests/MWTExploreTests.h @@ -1,164 +1,164 @@ -#pragma once
-
-#include "MWTExplorer.h"
-#include "utility.h"
-#include <iomanip>
-#include <iostream>
-#include <sstream>
-
-using namespace MultiWorldTesting;
-
-class TestContext
-{
-
-};
-
-template <class Ctx>
-struct TestInteraction
-{
- Ctx& Context;
- u32 Action;
- float Probability;
- string Unique_Key;
-};
-
-class TestPolicy : public IPolicy<TestContext>
-{
-public:
- TestPolicy(int params, int num_actions) : m_params(params), m_num_actions(num_actions) { }
- u32 Choose_Action(TestContext& context)
- {
- return m_params % m_num_actions + 1; // action id is one-based
- }
-private:
- int m_params;
- int m_num_actions;
-};
-
-class TestScorer : public IScorer<TestContext>
-{
-public:
- TestScorer(int params, int num_actions, bool uniform = true) :
- m_params(params), m_num_actions(num_actions), m_uniform(uniform)
- { }
-
- vector<float> Score_Actions(TestContext& context)
- {
- vector<float> scores;
- if (m_uniform)
- {
- for (u32 i = 0; i < m_num_actions; i++)
- {
- scores.push_back(m_params);
- }
- }
- else
- {
- for (u32 i = 0; i < m_num_actions; i++)
- {
- scores.push_back(m_params + i);
- }
- }
- return scores;
- }
-private:
- int m_params;
- int m_num_actions;
- bool m_uniform;
-};
-
-class FixedScorer : public IScorer<TestContext>
-{
-public:
- FixedScorer(int num_actions, int value) :
- m_num_actions(num_actions), m_value(value)
- { }
-
- vector<float> Score_Actions(TestContext& context)
- {
- vector<float> scores;
- for (u32 i = 0; i < m_num_actions; i++)
- {
- scores.push_back((float)m_value);
- }
- return scores;
- }
-private:
- int m_num_actions;
- int m_value;
-};
-
-class TestSimpleScorer : public IScorer<SimpleContext>
-{
-public:
- TestSimpleScorer(int params, int num_actions) : m_params(params), m_num_actions(num_actions) { }
- vector<float> Score_Actions(SimpleContext& context)
- {
- vector<float> scores;
- for (u32 i = 0; i < m_num_actions; i++)
- {
- scores.push_back(m_params);
- }
- return scores;
- }
-private:
- int m_params;
- int m_num_actions;
-};
-
-class TestSimplePolicy : public IPolicy<SimpleContext>
-{
-public:
- TestSimplePolicy(int params, int num_actions) : m_params(params), m_num_actions(num_actions) { }
- u32 Choose_Action(SimpleContext& context)
- {
- return m_params % m_num_actions + 1; // action id is one-based
- }
-private:
- int m_params;
- int m_num_actions;
-};
-
-class TestSimpleRecorder : public IRecorder<SimpleContext>
-{
-public:
- virtual void Record(SimpleContext& context, u32 action, float probability, string unique_key)
- {
- m_interactions.push_back({ context, action, probability, unique_key });
- }
-
- vector<TestInteraction<SimpleContext>> Get_All_Interactions()
- {
- return m_interactions;
- }
-
-private:
- vector<TestInteraction<SimpleContext>> m_interactions;
-};
-
-// Return action outside valid range
-class TestBadPolicy : public IPolicy<TestContext>
-{
-public:
- u32 Choose_Action(TestContext& context)
- {
- return 100;
- }
-};
-
-class TestRecorder : public IRecorder<TestContext>
-{
-public:
- virtual void Record(TestContext& context, u32 action, float probability, string unique_key)
- {
- m_interactions.push_back({ context, action, probability, unique_key });
- }
-
- vector<TestInteraction<TestContext>> Get_All_Interactions()
- {
- return m_interactions;
- }
-
-private:
- vector<TestInteraction<TestContext>> m_interactions;
-};
+#pragma once + +#include "MWTExplorer.h" +#include "utility.h" +#include <iomanip> +#include <iostream> +#include <sstream> + +using namespace MultiWorldTesting; + +class TestContext +{ + +}; + +template <class Ctx> +struct TestInteraction +{ + Ctx& Context; + u32 Action; + float Probability; + string Unique_Key; +}; + +class TestPolicy : public IPolicy<TestContext> +{ +public: + TestPolicy(int params, int num_actions) : m_params(params), m_num_actions(num_actions) { } + u32 Choose_Action(TestContext& context) + { + return m_params % m_num_actions + 1; // action id is one-based + } +private: + int m_params; + int m_num_actions; +}; + +class TestScorer : public IScorer<TestContext> +{ +public: + TestScorer(int params, int num_actions, bool uniform = true) : + m_params(params), m_num_actions(num_actions), m_uniform(uniform) + { } + + vector<float> Score_Actions(TestContext& context) + { + vector<float> scores; + if (m_uniform) + { + for (u32 i = 0; i < m_num_actions; i++) + { + scores.push_back(m_params); + } + } + else + { + for (u32 i = 0; i < m_num_actions; i++) + { + scores.push_back(m_params + i); + } + } + return scores; + } +private: + int m_params; + int m_num_actions; + bool m_uniform; +}; + +class FixedScorer : public IScorer<TestContext> +{ +public: + FixedScorer(int num_actions, int value) : + m_num_actions(num_actions), m_value(value) + { } + + vector<float> Score_Actions(TestContext& context) + { + vector<float> scores; + for (u32 i = 0; i < m_num_actions; i++) + { + scores.push_back((float)m_value); + } + return scores; + } +private: + int m_num_actions; + int m_value; +}; + +class TestSimpleScorer : public IScorer<SimpleContext> +{ +public: + TestSimpleScorer(int params, int num_actions) : m_params(params), m_num_actions(num_actions) { } + vector<float> Score_Actions(SimpleContext& context) + { + vector<float> scores; + for (u32 i = 0; i < m_num_actions; i++) + { + scores.push_back(m_params); + } + return scores; + } +private: + int m_params; + int 
m_num_actions; +}; + +class TestSimplePolicy : public IPolicy<SimpleContext> +{ +public: + TestSimplePolicy(int params, int num_actions) : m_params(params), m_num_actions(num_actions) { } + u32 Choose_Action(SimpleContext& context) + { + return m_params % m_num_actions + 1; // action id is one-based + } +private: + int m_params; + int m_num_actions; +}; + +class TestSimpleRecorder : public IRecorder<SimpleContext> +{ +public: + virtual void Record(SimpleContext& context, u32 action, float probability, string unique_key) + { + m_interactions.push_back({ context, action, probability, unique_key }); + } + + vector<TestInteraction<SimpleContext>> Get_All_Interactions() + { + return m_interactions; + } + +private: + vector<TestInteraction<SimpleContext>> m_interactions; +}; + +// Return action outside valid range +class TestBadPolicy : public IPolicy<TestContext> +{ +public: + u32 Choose_Action(TestContext& context) + { + return 100; + } +}; + +class TestRecorder : public IRecorder<TestContext> +{ +public: + virtual void Record(TestContext& context, u32 action, float probability, string unique_key) + { + m_interactions.push_back({ context, action, probability, unique_key }); + } + + vector<TestInteraction<TestContext>> Get_All_Interactions() + { + return m_interactions; + } + +private: + vector<TestInteraction<TestContext>> m_interactions; +}; diff --git a/library/ezexample_predict.cc b/library/ezexample_predict.cc index db061f61..22c8a86d 100644 --- a/library/ezexample_predict.cc +++ b/library/ezexample_predict.cc @@ -1,55 +1,55 @@ -#include <stdio.h>
-#include "../vowpalwabbit/parser.h"
-#include "../vowpalwabbit/vw.h"
-#include "../vowpalwabbit/ezexample.h"
-
-using namespace std;
-
-int main(int argc, char *argv[])
-{
- string init_string = "-t -q st --hash all --noconstant --ldf_override s -i ";
- if (argc > 1)
- init_string += argv[1];
- else
- init_string += "train.w";
-
- cerr << "initializing with: '" << init_string << "'" << endl;
-
- // INITIALIZE WITH WHATEVER YOU WOULD PUT ON THE VW COMMAND LINE -- THIS READS IN A MODEL FROM train.w
- vw* vw = VW::initialize(init_string); // "-t -q st --hash all --noconstant --ldf_override s -i train.w");
-
- {
- // HAL'S SPIFFY INTERFACE USING C++ CRAZINESS
- ezexample ex(vw, false); // don't need multiline
- ex(vw_namespace('s'))
- ("p^the_man")
- ("w^the")
- ("w^man")
- (vw_namespace('t'))
- ("p^le_homme")
- ("w^le")
- ("w^homme");
- ex.set_label("1");
- cerr << ex.predict_partial() << endl;
-
- // ex.clear_features();
-
- --ex; // remove the most recent namespace
- ex(vw_namespace('t'))
- ("p^un_homme")
- ("w^un")
- ("w^homme");
- ex.set_label("2");
- cerr << ex.predict_partial() << endl;
-
- --ex; // remove the most recent namespace, and add features with explicit ns
- ex('t', "p^un_homme")
- ('t', "w^un")
- ('t', "w^homme");
- ex.set_label("2");
- cerr << ex.predict_partial() << endl;
- }
-
- // AND FINISH UP
- VW::finish(*vw);
-}
+#include <stdio.h> +#include "../vowpalwabbit/parser.h" +#include "../vowpalwabbit/vw.h" +#include "../vowpalwabbit/ezexample.h" + +using namespace std; + +int main(int argc, char *argv[]) +{ + string init_string = "-t -q st --hash all --noconstant --ldf_override s -i "; + if (argc > 1) + init_string += argv[1]; + else + init_string += "train.w"; + + cerr << "initializing with: '" << init_string << "'" << endl; + + // INITIALIZE WITH WHATEVER YOU WOULD PUT ON THE VW COMMAND LINE -- THIS READS IN A MODEL FROM train.w + vw* vw = VW::initialize(init_string); // "-t -q st --hash all --noconstant --ldf_override s -i train.w"); + + { + // HAL'S SPIFFY INTERFACE USING C++ CRAZINESS + ezexample ex(vw, false); // don't need multiline + ex(vw_namespace('s')) + ("p^the_man") + ("w^the") + ("w^man") + (vw_namespace('t')) + ("p^le_homme") + ("w^le") + ("w^homme"); + ex.set_label("1"); + cerr << ex.predict_partial() << endl; + + // ex.clear_features(); + + --ex; // remove the most recent namespace + ex(vw_namespace('t')) + ("p^un_homme") + ("w^un") + ("w^homme"); + ex.set_label("2"); + cerr << ex.predict_partial() << endl; + + --ex; // remove the most recent namespace, and add features with explicit ns + ex('t', "p^un_homme") + ('t', "w^un") + ('t', "w^homme"); + ex.set_label("2"); + cerr << ex.predict_partial() << endl; + } + + // AND FINISH UP + VW::finish(*vw); +} diff --git a/library/ezexample_predict_threaded.cc b/library/ezexample_predict_threaded.cc index 0fa5b1e6..c7c39e85 100644 --- a/library/ezexample_predict_threaded.cc +++ b/library/ezexample_predict_threaded.cc @@ -1,149 +1,149 @@ -#include <stdio.h>
-#include "../vowpalwabbit/vw.h"
-#include "../vowpalwabbit/ezexample.h"
-
-#include <boost/thread/thread.hpp>
-
-using namespace std;
-
-int runcount = 100;
-
-class Worker
-{
-public:
- Worker(vw & instance, string & vw_init_string, vector<double> & ref)
- : m_vw(instance)
- , m_referenceValues(ref)
- , vw_init_string(vw_init_string)
- { }
-
- void operator()()
- {
- m_vw_parser = VW::initialize(vw_init_string);
- if (m_vw_parser == NULL) {
- cerr << "cannot initialize vw parser" << endl;
- exit(-1);
- }
-
- int errorCount = 0;
- for (int i = 0; i < runcount; ++i)
- {
- vector<double>::iterator it = m_referenceValues.begin();
- ezexample ex(&m_vw, false, m_vw_parser);
-
- ex(vw_namespace('s'))
- ("p^the_man")
- ("w^the")
- ("w^man")
- (vw_namespace('t'))
- ("p^le_homme")
- ("w^le")
- ("w^homme");
- ex.set_label("1");
- if (*it != ex()) { cerr << "fail!" << endl; ++errorCount; }
- //if (*it != pred) { cerr << "fail!" << endl; ++errorCount; }
- //VW::finish_example(m_vw, vec2);
- ++it;
-
- --ex; // remove the most recent namespace
- ex(vw_namespace('t'))
- ("p^un_homme")
- ("w^un")
- ("w^homme");
- ex.set_label("1");
- if (*it != ex()) { cerr << "fail!" << endl; ++errorCount; }
- ++it;
-
- --ex; // remove the most recent namespace
- // add features with explicit ns
- ex('t', "p^un_homme")
- ('t', "w^un")
- ('t', "w^homme");
- ex.set_label("1");
- if (*it != ex()) { cerr << "fail!" << endl; ++errorCount; }
- ++it;
-
- //cout << "."; cout.flush();
- }
- cerr << "error count = " << errorCount << endl;
- VW::finish(*m_vw_parser);
- m_vw_parser = NULL;
- }
-
-private:
- vw & m_vw;
- vw * m_vw_parser;
- vector<double> & m_referenceValues;
- string & vw_init_string;
-};
-
-int main(int argc, char *argv[])
-{
- if (argc != 3)
- {
- cerr << "need two args: threadcount runcount" << endl;
- return 1;
- }
- int threadcount = atoi(argv[1]);
- runcount = atoi(argv[2]);
- // INITIALIZE WITH WHATEVER YOU WOULD PUT ON THE VW COMMAND LINE -- THIS READS IN A MODEL FROM train.w
- string vw_init_string_all = "-t --ldf_override s --quiet -q st --noconstant --hash all -i train.w";
- string vw_init_string_parser = "-t --ldf_override s --quiet -q st --noconstant --hash all --noop"; // this needs to have enough arguments to get the parser right
- vw*vw = VW::initialize(vw_init_string_all);
- vector<double> results;
-
- // HAL'S SPIFFY INTERFACE USING C++ CRAZINESS
- {
- ezexample ex(vw, false);
- ex(vw_namespace('s'))
- ("p^the_man")
- ("w^the")
- ("w^man")
- (vw_namespace('t'))
- ("p^le_homme")
- ("w^le")
- ("w^homme");
- ex.set_label("1");
- results.push_back(ex.predict_partial());
- cerr << "should be near zero = " << ex.predict_partial() << endl;
-
- --ex; // remove the most recent namespace
- ex(vw_namespace('t'))
- ("p^un_homme")
- ("w^un")
- ("w^homme");
- ex.set_label("1");
- results.push_back(ex.predict_partial());
- cerr << "should be near one = " << ex.predict_partial() << endl;
-
- --ex; // remove the most recent namespace
- // add features with explicit ns
- ex('t', "p^un_homme")
- ('t', "w^un")
- ('t', "w^homme");
- ex.set_label("1");
- results.push_back(ex.predict_partial());
- cerr << "should be near one = " << ex.predict_partial() << endl;
- }
-
- if (threadcount == 0)
- {
- Worker w(*vw, vw_init_string_parser, results);
- w();
- }
- else
- {
- boost::thread_group tg;
- for (int t = 0; t < threadcount; ++t)
- {
- cerr << "starting thread " << t << endl;
- boost::thread * pt = tg.create_thread(Worker(*vw, vw_init_string_parser, results));
- }
- tg.join_all();
- cerr << "finished!" << endl;
- }
-
-
- // AND FINISH UP
- VW::finish(*vw);
-}
+#include <stdio.h> +#include "../vowpalwabbit/vw.h" +#include "../vowpalwabbit/ezexample.h" + +#include <boost/thread/thread.hpp> + +using namespace std; + +int runcount = 100; + +class Worker +{ +public: + Worker(vw & instance, string & vw_init_string, vector<double> & ref) + : m_vw(instance) + , m_referenceValues(ref) + , vw_init_string(vw_init_string) + { } + + void operator()() + { + m_vw_parser = VW::initialize(vw_init_string); + if (m_vw_parser == NULL) { + cerr << "cannot initialize vw parser" << endl; + exit(-1); + } + + int errorCount = 0; + for (int i = 0; i < runcount; ++i) + { + vector<double>::iterator it = m_referenceValues.begin(); + ezexample ex(&m_vw, false, m_vw_parser); + + ex(vw_namespace('s')) + ("p^the_man") + ("w^the") + ("w^man") + (vw_namespace('t')) + ("p^le_homme") + ("w^le") + ("w^homme"); + ex.set_label("1"); + if (*it != ex()) { cerr << "fail!" << endl; ++errorCount; } + //if (*it != pred) { cerr << "fail!" << endl; ++errorCount; } + //VW::finish_example(m_vw, vec2); + ++it; + + --ex; // remove the most recent namespace + ex(vw_namespace('t')) + ("p^un_homme") + ("w^un") + ("w^homme"); + ex.set_label("1"); + if (*it != ex()) { cerr << "fail!" << endl; ++errorCount; } + ++it; + + --ex; // remove the most recent namespace + // add features with explicit ns + ex('t', "p^un_homme") + ('t', "w^un") + ('t', "w^homme"); + ex.set_label("1"); + if (*it != ex()) { cerr << "fail!" 
<< endl; ++errorCount; } + ++it; + + //cout << "."; cout.flush(); + } + cerr << "error count = " << errorCount << endl; + VW::finish(*m_vw_parser); + m_vw_parser = NULL; + } + +private: + vw & m_vw; + vw * m_vw_parser; + vector<double> & m_referenceValues; + string & vw_init_string; +}; + +int main(int argc, char *argv[]) +{ + if (argc != 3) + { + cerr << "need two args: threadcount runcount" << endl; + return 1; + } + int threadcount = atoi(argv[1]); + runcount = atoi(argv[2]); + // INITIALIZE WITH WHATEVER YOU WOULD PUT ON THE VW COMMAND LINE -- THIS READS IN A MODEL FROM train.w + string vw_init_string_all = "-t --ldf_override s --quiet -q st --noconstant --hash all -i train.w"; + string vw_init_string_parser = "-t --ldf_override s --quiet -q st --noconstant --hash all --noop"; // this needs to have enough arguments to get the parser right + vw*vw = VW::initialize(vw_init_string_all); + vector<double> results; + + // HAL'S SPIFFY INTERFACE USING C++ CRAZINESS + { + ezexample ex(vw, false); + ex(vw_namespace('s')) + ("p^the_man") + ("w^the") + ("w^man") + (vw_namespace('t')) + ("p^le_homme") + ("w^le") + ("w^homme"); + ex.set_label("1"); + results.push_back(ex.predict_partial()); + cerr << "should be near zero = " << ex.predict_partial() << endl; + + --ex; // remove the most recent namespace + ex(vw_namespace('t')) + ("p^un_homme") + ("w^un") + ("w^homme"); + ex.set_label("1"); + results.push_back(ex.predict_partial()); + cerr << "should be near one = " << ex.predict_partial() << endl; + + --ex; // remove the most recent namespace + // add features with explicit ns + ex('t', "p^un_homme") + ('t', "w^un") + ('t', "w^homme"); + ex.set_label("1"); + results.push_back(ex.predict_partial()); + cerr << "should be near one = " << ex.predict_partial() << endl; + } + + if (threadcount == 0) + { + Worker w(*vw, vw_init_string_parser, results); + w(); + } + else + { + boost::thread_group tg; + for (int t = 0; t < threadcount; ++t) + { + cerr << "starting thread " << t << 
endl; + boost::thread * pt = tg.create_thread(Worker(*vw, vw_init_string_parser, results)); + } + tg.join_all(); + cerr << "finished!" << endl; + } + + + // AND FINISH UP + VW::finish(*vw); +} diff --git a/library/ezexample_train.cc b/library/ezexample_train.cc index a0f66a99..9a8af8e0 100644 --- a/library/ezexample_train.cc +++ b/library/ezexample_train.cc @@ -1,72 +1,72 @@ -#include <stdio.h>
-#include "../vowpalwabbit/parser.h"
-#include "../vowpalwabbit/vw.h"
-#include "../vowpalwabbit/ezexample.h"
-
-using namespace std;
-
-void run(vw*vw) {
- ezexample ex(vw, true); // we're doing csoaa_ldf so we need multiline examples
-
- /// BEGIN FIRST MULTILINE EXAMPLE
- ex(vw_namespace('s'))
- ("p^the_man")
- ("w^the")
- ("w^man")
- (vw_namespace('t'))
- ("p^un_homme")
- ("w^un")
- ("w^homme");
-
- ex.set_label("1:1");
- ex.train();
-
- --ex; // remove the most recent namespace
- ex(vw_namespace('t'))
- ("p^le_homme")
- ("w^le")
- ("w^homme");
-
- ex.set_label("2:0");
- ex.train();
-
- // push it through VW for training
- ex.finish();
-
- /// BEGIN SECOND MULTILINE EXAMPLE
- ex(vw_namespace('s'))
- ("p^a_man")
- ("w^a")
- ("w^man")
- (vw_namespace('t'))
- ("p^un_homme")
- ("w^un")
- ("w^homme");
-
- ex.set_label("1:0");
- ex.train();
-
- --ex; // remove the most recent namespace
- ex(vw_namespace('t'))
- ("p^le_homme")
- ("w^le")
- ("w^homme");
-
- ex.set_label("2:1");
- ex.train();
-
- // push it through VW for training
- ex.finish();
-}
-
-int main(int argc, char *argv[])
-{
- // INITIALIZE WITH WHATEVER YOU WOULD PUT ON THE VW COMMAND LINE -- THIS WILL STORE A MODEL TO train.ezw
- vw* vw = VW::initialize("--hash all -q st --noconstant -f train.w --quiet --csoaa_ldf m");
-
- run(vw);
-
- // AND FINISH UP
- cerr << "ezexample_train finish"<<endl;
- VW::finish(*vw);
-}
+#include <stdio.h> +#include "../vowpalwabbit/parser.h" +#include "../vowpalwabbit/vw.h" +#include "../vowpalwabbit/ezexample.h" + +using namespace std; + +void run(vw*vw) { + ezexample ex(vw, true); // we're doing csoaa_ldf so we need multiline examples + + /// BEGIN FIRST MULTILINE EXAMPLE + ex(vw_namespace('s')) + ("p^the_man") + ("w^the") + ("w^man") + (vw_namespace('t')) + ("p^un_homme") + ("w^un") + ("w^homme"); + + ex.set_label("1:1"); + ex.train(); + + --ex; // remove the most recent namespace + ex(vw_namespace('t')) + ("p^le_homme") + ("w^le") + ("w^homme"); + + ex.set_label("2:0"); + ex.train(); + + // push it through VW for training + ex.finish(); + + /// BEGIN SECOND MULTILINE EXAMPLE + ex(vw_namespace('s')) + ("p^a_man") + ("w^a") + ("w^man") + (vw_namespace('t')) + ("p^un_homme") + ("w^un") + ("w^homme"); + + ex.set_label("1:0"); + ex.train(); + + --ex; // remove the most recent namespace + ex(vw_namespace('t')) + ("p^le_homme") + ("w^le") + ("w^homme"); + + ex.set_label("2:1"); + ex.train(); + + // push it through VW for training + ex.finish(); +} + +int main(int argc, char *argv[]) +{ + // INITIALIZE WITH WHATEVER YOU WOULD PUT ON THE VW COMMAND LINE -- THIS WILL STORE A MODEL TO train.ezw + vw* vw = VW::initialize("--hash all -q st --noconstant -f train.w --quiet --csoaa_ldf m"); + + run(vw); + + // AND FINISH UP + cerr << "ezexample_train finish"<<endl; + VW::finish(*vw); +} diff --git a/library/gd_mf_weights.cc b/library/gd_mf_weights.cc index 4b394775..2c50349f 100644 --- a/library/gd_mf_weights.cc +++ b/library/gd_mf_weights.cc @@ -1,121 +1,121 @@ -#include <stdio.h>
-#include "../vowpalwabbit/parser.h"
-#include "../vowpalwabbit/vw.h"
-#include <fstream>
-#include <iostream>
-#include <string.h>
-#include <boost/program_options.hpp>
-
-using namespace std;
-namespace po = boost::program_options;
-
-
-int main(int argc, char *argv[])
-{
- string infile;
- string outdir(".");
- string vwparams;
-
- po::variables_map vm;
- po::options_description desc("Allowed options");
- desc.add_options()
- ("help,h", "produce help message")
- ("infile,I", po::value<string>(&infile), "input (in vw format) of weights to extract")
- ("outdir,O", po::value<string>(&outdir), "directory to write model files to (default: .)")
- ("vwparams", po::value<string>(&vwparams), "vw parameters for model instantiation (-i model.reg -t ...")
- ;
-
- try {
- po::store(po::parse_command_line(argc, argv, desc), vm);
- po::notify(vm);
- }
- catch(exception & e)
- {
- cout << endl << argv[0] << ": " << e.what() << endl << endl << desc << endl;
- exit(2);
- }
-
- if (vm.count("help") || infile.empty() || vwparams.empty()) {
- cout << "Dumps weights for matrix factorization model (gd_mf)." << endl;
- cout << "The constant will be written to <outdir>/constant." << endl;
- cout << "Linear and quadratic weights corresponding to the input features will be " << endl;
- cout << "written to <outdir>/<ns>.linear and <outdir>/<ns>.quadratic,respectively." << endl;
- cout << endl;
- cout << desc << "\n";
- cout << "Example usage:" << endl;
- cout << " Extract weights for user 42 and item 7 under randomly initialized rank 10 model:" << endl;
- cout << " echo '|u 42 |i 7' | ./gd_mf_weights -I /dev/stdin --vwparams '-q ui --rank 10'" << endl;
- return 1;
- }
-
- // initialize model
- vw* model = VW::initialize(vwparams);
- model->audit = true;
-
- // global model params
- unsigned char left_ns = model->pairs[0][0];
- unsigned char right_ns = model->pairs[0][1];
- weight* weights = model->reg.weight_vector;
- size_t mask = model->reg.weight_mask;
-
- // const char *filename = argv[0];
- FILE* file = fopen(infile.c_str(), "r");
- char* line = NULL;
- size_t len = 0;
- ssize_t read;
-
- // output files
- ofstream constant((outdir + string("/") + string("constant")).c_str()),
- left_linear((outdir + string("/") + string(1, left_ns) + string(".linear")).c_str()),
- left_quadratic((outdir + string("/") + string(1, left_ns) + string(".quadratic")).c_str()),
- right_linear((outdir + string("/") + string(1, right_ns) + string(".linear")).c_str()),
- right_quadratic((outdir + string("/") + string(1, right_ns) + string(".quadratic")).c_str());
-
- example *ec = NULL;
- while ((read = getline(&line, &len, file)) != -1)
- {
- line[strlen(line)-1] = 0; // chop
-
- ec = VW::read_example(*model, line);
-
- // write out features for left namespace
- if (ec->audit_features[left_ns].begin != ec->audit_features[left_ns].end)
- {
- for (audit_data *f = ec->audit_features[left_ns].begin; f != ec->audit_features[left_ns].end; f++)
- {
- left_linear << f->feature << '\t' << weights[f->weight_index & mask];
-
- left_quadratic << f->feature;
- for (size_t k = 1; k <= model->rank; k++)
- left_quadratic << '\t' << weights[(f->weight_index + k) & mask];
- }
- left_linear << endl;
- left_quadratic << endl;
- }
-
- // write out features for right namespace
- if (ec->audit_features[right_ns].begin != ec->audit_features[right_ns].end)
- {
- for (audit_data *f = ec->audit_features[right_ns].begin; f != ec->audit_features[right_ns].end; f++)
- {
- right_linear << f->feature << '\t' << weights[f->weight_index & mask];
-
- right_quadratic << f->feature;
- for (size_t k = 1; k <= model->rank; k++)
- right_quadratic << '\t' << weights[(f->weight_index + k + model->rank) & mask];
- }
- right_linear << endl;
- right_quadratic << endl;
- }
-
- VW::finish_example(*model, ec);
- }
-
- // write constant
- feature* f = ec->atomics[constant_namespace].begin;
- constant << weights[f->weight_index & mask] << endl;
-
- // clean up
- VW::finish(*model);
- fclose(file);
-}
+#include <stdio.h> +#include "../vowpalwabbit/parser.h" +#include "../vowpalwabbit/vw.h" +#include <fstream> +#include <iostream> +#include <string.h> +#include <boost/program_options.hpp> + +using namespace std; +namespace po = boost::program_options; + + +int main(int argc, char *argv[]) +{ + string infile; + string outdir("."); + string vwparams; + + po::variables_map vm; + po::options_description desc("Allowed options"); + desc.add_options() + ("help,h", "produce help message") + ("infile,I", po::value<string>(&infile), "input (in vw format) of weights to extract") + ("outdir,O", po::value<string>(&outdir), "directory to write model files to (default: .)") + ("vwparams", po::value<string>(&vwparams), "vw parameters for model instantiation (-i model.reg -t ...") + ; + + try { + po::store(po::parse_command_line(argc, argv, desc), vm); + po::notify(vm); + } + catch(exception & e) + { + cout << endl << argv[0] << ": " << e.what() << endl << endl << desc << endl; + exit(2); + } + + if (vm.count("help") || infile.empty() || vwparams.empty()) { + cout << "Dumps weights for matrix factorization model (gd_mf)." << endl; + cout << "The constant will be written to <outdir>/constant." << endl; + cout << "Linear and quadratic weights corresponding to the input features will be " << endl; + cout << "written to <outdir>/<ns>.linear and <outdir>/<ns>.quadratic,respectively." 
<< endl; + cout << endl; + cout << desc << "\n"; + cout << "Example usage:" << endl; + cout << " Extract weights for user 42 and item 7 under randomly initialized rank 10 model:" << endl; + cout << " echo '|u 42 |i 7' | ./gd_mf_weights -I /dev/stdin --vwparams '-q ui --rank 10'" << endl; + return 1; + } + + // initialize model + vw* model = VW::initialize(vwparams); + model->audit = true; + + // global model params + unsigned char left_ns = model->pairs[0][0]; + unsigned char right_ns = model->pairs[0][1]; + weight* weights = model->reg.weight_vector; + size_t mask = model->reg.weight_mask; + + // const char *filename = argv[0]; + FILE* file = fopen(infile.c_str(), "r"); + char* line = NULL; + size_t len = 0; + ssize_t read; + + // output files + ofstream constant((outdir + string("/") + string("constant")).c_str()), + left_linear((outdir + string("/") + string(1, left_ns) + string(".linear")).c_str()), + left_quadratic((outdir + string("/") + string(1, left_ns) + string(".quadratic")).c_str()), + right_linear((outdir + string("/") + string(1, right_ns) + string(".linear")).c_str()), + right_quadratic((outdir + string("/") + string(1, right_ns) + string(".quadratic")).c_str()); + + example *ec = NULL; + while ((read = getline(&line, &len, file)) != -1) + { + line[strlen(line)-1] = 0; // chop + + ec = VW::read_example(*model, line); + + // write out features for left namespace + if (ec->audit_features[left_ns].begin != ec->audit_features[left_ns].end) + { + for (audit_data *f = ec->audit_features[left_ns].begin; f != ec->audit_features[left_ns].end; f++) + { + left_linear << f->feature << '\t' << weights[f->weight_index & mask]; + + left_quadratic << f->feature; + for (size_t k = 1; k <= model->rank; k++) + left_quadratic << '\t' << weights[(f->weight_index + k) & mask]; + } + left_linear << endl; + left_quadratic << endl; + } + + // write out features for right namespace + if (ec->audit_features[right_ns].begin != ec->audit_features[right_ns].end) + { + for 
(audit_data *f = ec->audit_features[right_ns].begin; f != ec->audit_features[right_ns].end; f++) + { + right_linear << f->feature << '\t' << weights[f->weight_index & mask]; + + right_quadratic << f->feature; + for (size_t k = 1; k <= model->rank; k++) + right_quadratic << '\t' << weights[(f->weight_index + k + model->rank) & mask]; + } + right_linear << endl; + right_quadratic << endl; + } + + VW::finish_example(*model, ec); + } + + // write constant + feature* f = ec->atomics[constant_namespace].begin; + constant << weights[f->weight_index & mask] << endl; + + // clean up + VW::finish(*model); + fclose(file); +} diff --git a/library/library_example.cc b/library/library_example.cc index 8abfc1f7..d7c186c2 100644 --- a/library/library_example.cc +++ b/library/library_example.cc @@ -1,62 +1,62 @@ -#include <stdio.h>
-#include "../vowpalwabbit/parser.h"
-#include "../vowpalwabbit/vw.h"
-
-using namespace std;
-
-
-inline feature vw_feature_from_string(vw& v, string fstr, unsigned long seed, float val)
-{
- uint32_t foo = VW::hash_feature(v, fstr, seed);
- feature f = { val, foo};
- return f;
-}
-
-int main(int argc, char *argv[])
-{
- vw* model = VW::initialize("--hash all -q st --noconstant -i train.w -f train2.vw");
-
- example *vec2 = VW::read_example(*model, (char*)"|s p^the_man w^the w^man |t p^un_homme w^un w^homme");
- model->learn(vec2);
- cerr << "p2 = " << vec2->pred.scalar << endl;
- VW::finish_example(*model, vec2);
-
- vector< VW::feature_space > ec_info;
- vector<feature> s_features, t_features;
- uint32_t s_hash = VW::hash_space(*model, "s");
- uint32_t t_hash = VW::hash_space(*model, "t");
- s_features.push_back( vw_feature_from_string(*model, "p^the_man", s_hash, 1.0) );
- s_features.push_back( vw_feature_from_string(*model, "w^the", s_hash, 1.0) );
- s_features.push_back( vw_feature_from_string(*model, "w^man", s_hash, 1.0) );
- t_features.push_back( vw_feature_from_string(*model, "p^le_homme", t_hash, 1.0) );
- t_features.push_back( vw_feature_from_string(*model, "w^le", t_hash, 1.0) );
- t_features.push_back( vw_feature_from_string(*model, "w^homme", t_hash, 1.0) );
- ec_info.push_back( VW::feature_space('s', s_features) );
- ec_info.push_back( VW::feature_space('t', t_features) );
- example* vec3 = VW::import_example(*model, ec_info);
-
- model->learn(vec3);
- cerr << "p3 = " << vec3->pred.scalar << endl;
- VW::finish_example(*model, vec3);
-
- VW::finish(*model);
-
- vw* model2 = VW::initialize("--hash all -q st --noconstant -i train2.vw");
- vec2 = VW::read_example(*model2, (char*)" |s p^the_man w^the w^man |t p^un_homme w^un w^homme");
- model2->learn(vec2);
- cerr << "p4 = " << vec2->pred.scalar << endl;
-
- size_t len=0;
- VW::primitive_feature_space* pfs = VW::export_example(*model2, vec2, len);
- for (size_t i = 0; i < len; i++)
- {
- cout << "namespace = " << pfs[i].name;
- for (size_t j = 0; j < pfs[i].len; j++)
- cout << " " << pfs[i].fs[j].weight_index << ":" << pfs[i].fs[j].x << ":" << VW::get_weight(*model2, pfs[i].fs[j].weight_index, 0);
- cout << endl;
- }
-
- VW::finish_example(*model2, vec2);
- VW::finish(*model2);
-}
-
+#include <stdio.h> +#include "../vowpalwabbit/parser.h" +#include "../vowpalwabbit/vw.h" + +using namespace std; + + +inline feature vw_feature_from_string(vw& v, string fstr, unsigned long seed, float val) +{ + uint32_t foo = VW::hash_feature(v, fstr, seed); + feature f = { val, foo}; + return f; +} + +int main(int argc, char *argv[]) +{ + vw* model = VW::initialize("--hash all -q st --noconstant -i train.w -f train2.vw"); + + example *vec2 = VW::read_example(*model, (char*)"|s p^the_man w^the w^man |t p^un_homme w^un w^homme"); + model->learn(vec2); + cerr << "p2 = " << vec2->pred.scalar << endl; + VW::finish_example(*model, vec2); + + vector< VW::feature_space > ec_info; + vector<feature> s_features, t_features; + uint32_t s_hash = VW::hash_space(*model, "s"); + uint32_t t_hash = VW::hash_space(*model, "t"); + s_features.push_back( vw_feature_from_string(*model, "p^the_man", s_hash, 1.0) ); + s_features.push_back( vw_feature_from_string(*model, "w^the", s_hash, 1.0) ); + s_features.push_back( vw_feature_from_string(*model, "w^man", s_hash, 1.0) ); + t_features.push_back( vw_feature_from_string(*model, "p^le_homme", t_hash, 1.0) ); + t_features.push_back( vw_feature_from_string(*model, "w^le", t_hash, 1.0) ); + t_features.push_back( vw_feature_from_string(*model, "w^homme", t_hash, 1.0) ); + ec_info.push_back( VW::feature_space('s', s_features) ); + ec_info.push_back( VW::feature_space('t', t_features) ); + example* vec3 = VW::import_example(*model, ec_info); + + model->learn(vec3); + cerr << "p3 = " << vec3->pred.scalar << endl; + VW::finish_example(*model, vec3); + + VW::finish(*model); + + vw* model2 = VW::initialize("--hash all -q st --noconstant -i train2.vw"); + vec2 = VW::read_example(*model2, (char*)" |s p^the_man w^the w^man |t p^un_homme w^un w^homme"); + model2->learn(vec2); + cerr << "p4 = " << vec2->pred.scalar << endl; + + size_t len=0; + VW::primitive_feature_space* pfs = VW::export_example(*model2, vec2, len); + for (size_t i = 0; i < len; i++) + 
{ + cout << "namespace = " << pfs[i].name; + for (size_t j = 0; j < pfs[i].len; j++) + cout << " " << pfs[i].fs[j].weight_index << ":" << pfs[i].fs[j].x << ":" << VW::get_weight(*model2, pfs[i].fs[j].weight_index, 0); + cout << endl; + } + + VW::finish_example(*model2, vec2); + VW::finish(*model2); +} + diff --git a/vowpalwabbit/accumulate.cc b/vowpalwabbit/accumulate.cc index d6c5e71f..8d15dd59 100644 --- a/vowpalwabbit/accumulate.cc +++ b/vowpalwabbit/accumulate.cc @@ -1,115 +1,115 @@ -/*
-Copyright (c) by respective owners including Yahoo!, Microsoft, and
-individual contributors. All rights reserved. Released under a BSD (revised)
-license as described in the file LICENSE.
- */
-/*
-This implements the allreduce function of MPI. Code primarily by
-Alekh Agarwal and John Langford, with help Olivier Chapelle.
-*/
-
-#include <iostream>
-#include <sys/timeb.h>
-#include <cmath>
-#include <stdint.h>
-#include "accumulate.h"
-#include "global_data.h"
-
-using namespace std;
-
-void add_float(float& c1, const float& c2) {
- c1 += c2;
-}
-
-void accumulate(vw& all, string master_location, regressor& reg, size_t o) {
- uint32_t length = 1 << all.num_bits; //This is size of gradient
- size_t stride = 1 << all.reg.stride_shift;
- float* local_grad = new float[length];
- weight* weights = reg.weight_vector;
- for(uint32_t i = 0;i < length;i++)
- {
- local_grad[i] = weights[stride*i+o];
- }
-
- all_reduce<float, add_float>(local_grad, length, master_location, all.unique_id, all.total, all.node, all.socks);
- for(uint32_t i = 0;i < length;i++)
- {
- weights[stride*i+o] = local_grad[i];
- }
- delete[] local_grad;
-}
-
-float accumulate_scalar(vw& all, string master_location, float local_sum) {
- float temp = local_sum;
- all_reduce<float, add_float>(&temp, 1, master_location, all.unique_id, all.total, all.node, all.socks);
- return temp;
-}
-
-void accumulate_avg(vw& all, string master_location, regressor& reg, size_t o) {
- uint32_t length = 1 << all.num_bits; //This is size of gradient
- size_t stride = 1 << all.reg.stride_shift;
- float* local_grad = new float[length];
- weight* weights = reg.weight_vector;
- float numnodes = (float)all.total;
-
- for(uint32_t i = 0;i < length;i++)
- local_grad[i] = weights[stride*i+o];
-
- all_reduce<float, add_float>(local_grad, length, master_location, all.unique_id, all.total, all.node, all.socks);
- for(uint32_t i = 0;i < length;i++)
- weights[stride*i+o] = local_grad[i]/numnodes;
- delete[] local_grad;
-}
-
-float max_elem(float* arr, int length) {
- float max = arr[0];
- for(int i = 1;i < length;i++)
- if(arr[i] > max) max = arr[i];
- return max;
-}
-
-float min_elem(float* arr, int length) {
- float min = arr[0];
- for(int i = 1;i < length;i++)
- if(arr[i] < min && arr[i] > 0.001) min = arr[i];
- return min;
-}
-
-void accumulate_weighted_avg(vw& all, string master_location, regressor& reg) {
- if(!all.adaptive) {
- cerr<<"Weighted averaging is implemented only for adaptive gradient, use accumulate_avg instead\n";
- return;
- }
- uint32_t length = 1 << all.num_bits; //This is the number of parameters
- size_t stride = 1 << all.reg.stride_shift;
- weight* weights = reg.weight_vector;
-
-
- float* local_weights = new float[length];
-
- for(uint32_t i = 0;i < length;i++)
- local_weights[i] = weights[stride*i+1];
-
-
- //First compute weights for averaging
- all_reduce<float, add_float>(local_weights, length, master_location, all.unique_id, all.total, all.node, all.socks);
-
- for(uint32_t i = 0;i < length;i++) //Compute weighted versions
- if(local_weights[i] > 0) {
- float ratio = weights[stride*i+1]/local_weights[i];
- local_weights[i] = weights[stride*i] * ratio;
- weights[stride*i] *= ratio;
- weights[stride*i+1] *= ratio; //A crude max
- if (all.normalized_updates)
- weights[stride*i+all.normalized_idx] *= ratio; //A crude max
- }
- else {
- local_weights[i] = 0;
- weights[stride*i] = 0;
- }
-
- all_reduce<float, add_float>(weights, length*stride, master_location, all.unique_id, all.total, all.node, all.socks);
-
- delete[] local_weights;
-}
-
+/* +Copyright (c) by respective owners including Yahoo!, Microsoft, and +individual contributors. All rights reserved. Released under a BSD (revised) +license as described in the file LICENSE. + */ +/* +This implements the allreduce function of MPI. Code primarily by +Alekh Agarwal and John Langford, with help Olivier Chapelle. +*/ + +#include <iostream> +#include <sys/timeb.h> +#include <cmath> +#include <stdint.h> +#include "accumulate.h" +#include "global_data.h" + +using namespace std; + +void add_float(float& c1, const float& c2) { + c1 += c2; +} + +void accumulate(vw& all, string master_location, regressor& reg, size_t o) { + uint32_t length = 1 << all.num_bits; //This is size of gradient + size_t stride = 1 << all.reg.stride_shift; + float* local_grad = new float[length]; + weight* weights = reg.weight_vector; + for(uint32_t i = 0;i < length;i++) + { + local_grad[i] = weights[stride*i+o]; + } + + all_reduce<float, add_float>(local_grad, length, master_location, all.unique_id, all.total, all.node, all.socks); + for(uint32_t i = 0;i < length;i++) + { + weights[stride*i+o] = local_grad[i]; + } + delete[] local_grad; +} + +float accumulate_scalar(vw& all, string master_location, float local_sum) { + float temp = local_sum; + all_reduce<float, add_float>(&temp, 1, master_location, all.unique_id, all.total, all.node, all.socks); + return temp; +} + +void accumulate_avg(vw& all, string master_location, regressor& reg, size_t o) { + uint32_t length = 1 << all.num_bits; //This is size of gradient + size_t stride = 1 << all.reg.stride_shift; + float* local_grad = new float[length]; + weight* weights = reg.weight_vector; + float numnodes = (float)all.total; + + for(uint32_t i = 0;i < length;i++) + local_grad[i] = weights[stride*i+o]; + + all_reduce<float, add_float>(local_grad, length, master_location, all.unique_id, all.total, all.node, all.socks); + for(uint32_t i = 0;i < length;i++) + weights[stride*i+o] = local_grad[i]/numnodes; + delete[] local_grad; +} + +float 
max_elem(float* arr, int length) { + float max = arr[0]; + for(int i = 1;i < length;i++) + if(arr[i] > max) max = arr[i]; + return max; +} + +float min_elem(float* arr, int length) { + float min = arr[0]; + for(int i = 1;i < length;i++) + if(arr[i] < min && arr[i] > 0.001) min = arr[i]; + return min; +} + +void accumulate_weighted_avg(vw& all, string master_location, regressor& reg) { + if(!all.adaptive) { + cerr<<"Weighted averaging is implemented only for adaptive gradient, use accumulate_avg instead\n"; + return; + } + uint32_t length = 1 << all.num_bits; //This is the number of parameters + size_t stride = 1 << all.reg.stride_shift; + weight* weights = reg.weight_vector; + + + float* local_weights = new float[length]; + + for(uint32_t i = 0;i < length;i++) + local_weights[i] = weights[stride*i+1]; + + + //First compute weights for averaging + all_reduce<float, add_float>(local_weights, length, master_location, all.unique_id, all.total, all.node, all.socks); + + for(uint32_t i = 0;i < length;i++) //Compute weighted versions + if(local_weights[i] > 0) { + float ratio = weights[stride*i+1]/local_weights[i]; + local_weights[i] = weights[stride*i] * ratio; + weights[stride*i] *= ratio; + weights[stride*i+1] *= ratio; //A crude max + if (all.normalized_updates) + weights[stride*i+all.normalized_idx] *= ratio; //A crude max + } + else { + local_weights[i] = 0; + weights[stride*i] = 0; + } + + all_reduce<float, add_float>(weights, length*stride, master_location, all.unique_id, all.total, all.node, all.socks); + + delete[] local_weights; +} + diff --git a/vowpalwabbit/accumulate.h b/vowpalwabbit/accumulate.h index c01ac5fe..4d507a60 100644 --- a/vowpalwabbit/accumulate.h +++ b/vowpalwabbit/accumulate.h @@ -1,13 +1,13 @@ -/*
-Copyright (c) by respective owners including Yahoo!, Microsoft, and
-individual contributors. All rights reserved. Released under a BSD
-license as described in the file LICENSE.
- */
-//This implements various accumulate functions building on top of allreduce.
-#pragma once
-#include "global_data.h"
-
-void accumulate(vw& all, std::string master_location, regressor& reg, size_t o);
-float accumulate_scalar(vw& all, std::string master_location, float local_sum);
-void accumulate_weighted_avg(vw& all, std::string master_location, regressor& reg);
-void accumulate_avg(vw& all, std::string master_location, regressor& reg, size_t o);
+/* +Copyright (c) by respective owners including Yahoo!, Microsoft, and +individual contributors. All rights reserved. Released under a BSD +license as described in the file LICENSE. + */ +//This implements various accumulate functions building on top of allreduce. +#pragma once +#include "global_data.h" + +void accumulate(vw& all, std::string master_location, regressor& reg, size_t o); +float accumulate_scalar(vw& all, std::string master_location, float local_sum); +void accumulate_weighted_avg(vw& all, std::string master_location, regressor& reg); +void accumulate_avg(vw& all, std::string master_location, regressor& reg, size_t o); diff --git a/vowpalwabbit/lda_core.cc b/vowpalwabbit/lda_core.cc index 90f9e171..23ca9bfe 100644 --- a/vowpalwabbit/lda_core.cc +++ b/vowpalwabbit/lda_core.cc @@ -1,802 +1,802 @@ -/*
-Copyright (c) by respective owners including Yahoo!, Microsoft, and
-individual contributors. All rights reserved. Released under a BSD (revised)
-license as described in the file LICENSE.
- */
-#include <fstream>
-#include <vector>
-#include <float.h>
-#ifdef _WIN32
-#include <winsock2.h>
-#else
-#include <netdb.h>
-#endif
-#include <string.h>
-#include <stdio.h>
-#include <assert.h>
-#include "constant.h"
-#include "gd.h"
-#include "simple_label.h"
-#include "rand48.h"
-#include "reductions.h"
-
-using namespace LEARNER;
-using namespace std;
-
-namespace LDA {
-
-class index_feature {
-public:
- uint32_t document;
- feature f;
- bool operator<(const index_feature b) const { return f.weight_index < b.f.weight_index; }
-};
-
- struct lda {
- v_array<float> Elogtheta;
- v_array<float> decay_levels;
- v_array<float> total_new;
- v_array<example* > examples;
- v_array<float> total_lambda;
- v_array<int> doc_lengths;
- v_array<float> digammas;
- v_array<float> v;
- vector<index_feature> sorted_features;
-
- bool total_lambda_init;
-
- double example_t;
- vw* all;
- };
-
-#ifdef _WIN32
-inline float fmax(float f1, float f2) { return (f1 < f2 ? f2 : f1); }
-inline float fmin(float f1, float f2) { return (f1 > f2 ? f2 : f1); }
-#endif
-
-#define MINEIRO_SPECIAL
-#ifdef MINEIRO_SPECIAL
-
-namespace {
-
-inline float
-fastlog2 (float x)
-{
- union { float f; uint32_t i; } vx = { x };
- union { uint32_t i; float f; } mx = { (vx.i & 0x007FFFFF) | (0x7e << 23) };
- float y = (float)vx.i;
- y *= 1.0f / (float)(1 << 23);
-
- return
- y - 124.22544637f - 1.498030302f * mx.f - 1.72587999f / (0.3520887068f + mx.f);
-}
-
-inline float
-fastlog (float x)
-{
- return 0.69314718f * fastlog2 (x);
-}
-
-inline float
-fastpow2 (float p)
-{
- float offset = (p < 0) ? 1.0f : 0.0f;
- float clipp = (p < -126) ? -126.0f : p;
- int w = (int)clipp;
- float z = clipp - w + offset;
- union { uint32_t i; float f; } v = { (uint32_t)((1 << 23) * (clipp + 121.2740838f + 27.7280233f / (4.84252568f - z) - 1.49012907f * z)) };
-
- return v.f;
-}
-
-inline float
-fastexp (float p)
-{
- return fastpow2 (1.442695040f * p);
-}
-
-inline float
-fastpow (float x,
- float p)
-{
- return fastpow2 (p * fastlog2 (x));
-}
-
-inline float
-fastlgamma (float x)
-{
- float logterm = fastlog (x * (1.0f + x) * (2.0f + x));
- float xp3 = 3.0f + x;
-
- return
- -2.081061466f - x + 0.0833333f / xp3 - logterm + (2.5f + x) * fastlog (xp3);
-}
-
-inline float
-fastdigamma (float x)
-{
- float twopx = 2.0f + x;
- float logterm = fastlog (twopx);
-
- return - (1.0f + 2.0f * x) / (x * (1.0f + x))
- - (13.0f + 6.0f * x) / (12.0f * twopx * twopx)
- + logterm;
-}
-
-#define log fastlog
-#define exp fastexp
-#define powf fastpow
-#define mydigamma fastdigamma
-#define mylgamma fastlgamma
-
-#if defined(__SSE2__) && !defined(VW_LDA_NO_SSE)
-
-#include <emmintrin.h>
-
-typedef __m128 v4sf;
-typedef __m128i v4si;
-
-#define v4si_to_v4sf _mm_cvtepi32_ps
-#define v4sf_to_v4si _mm_cvttps_epi32
-
-static inline float
-v4sf_index (const v4sf x,
- unsigned int i)
-{
- union { v4sf f; float array[4]; } tmp = { x };
-
- return tmp.array[i];
-}
-
-static inline const v4sf
-v4sfl (float x)
-{
- union { float array[4]; v4sf f; } tmp = { { x, x, x, x } };
-
- return tmp.f;
-}
-
-static inline const v4si
-v4sil (uint32_t x)
-{
- uint64_t wide = (((uint64_t) x) << 32) | x;
- union { uint64_t array[2]; v4si f; } tmp = { { wide, wide } };
-
- return tmp.f;
-}
-
-static inline v4sf
-vfastpow2 (const v4sf p)
-{
- v4sf ltzero = _mm_cmplt_ps (p, v4sfl (0.0f));
- v4sf offset = _mm_and_ps (ltzero, v4sfl (1.0f));
- v4sf lt126 = _mm_cmplt_ps (p, v4sfl (-126.0f));
- v4sf clipp = _mm_andnot_ps (lt126, p) + _mm_and_ps (lt126, v4sfl (-126.0f));
- v4si w = v4sf_to_v4si (clipp);
- v4sf z = clipp - v4si_to_v4sf (w) + offset;
-
- const v4sf c_121_2740838 = v4sfl (121.2740838f);
- const v4sf c_27_7280233 = v4sfl (27.7280233f);
- const v4sf c_4_84252568 = v4sfl (4.84252568f);
- const v4sf c_1_49012907 = v4sfl (1.49012907f);
- union { v4si i; v4sf f; } v = {
- v4sf_to_v4si (
- v4sfl (1 << 23) *
- (clipp + c_121_2740838 + c_27_7280233 / (c_4_84252568 - z) - c_1_49012907 * z)
- )
- };
-
- return v.f;
-}
-
-inline v4sf
-vfastexp (const v4sf p)
-{
- const v4sf c_invlog_2 = v4sfl (1.442695040f);
-
- return vfastpow2 (c_invlog_2 * p);
-}
-
-inline v4sf
-vfastlog2 (v4sf x)
-{
- union { v4sf f; v4si i; } vx = { x };
- union { v4si i; v4sf f; } mx = { (vx.i & v4sil (0x007FFFFF)) | v4sil (0x3f000000) };
- v4sf y = v4si_to_v4sf (vx.i);
- y *= v4sfl (1.1920928955078125e-7f);
-
- const v4sf c_124_22551499 = v4sfl (124.22551499f);
- const v4sf c_1_498030302 = v4sfl (1.498030302f);
- const v4sf c_1_725877999 = v4sfl (1.72587999f);
- const v4sf c_0_3520087068 = v4sfl (0.3520887068f);
-
- return y - c_124_22551499
- - c_1_498030302 * mx.f
- - c_1_725877999 / (c_0_3520087068 + mx.f);
-}
-
-inline v4sf
-vfastlog (v4sf x)
-{
- const v4sf c_0_69314718 = v4sfl (0.69314718f);
-
- return c_0_69314718 * vfastlog2 (x);
-}
-
-inline v4sf
-vfastdigamma (v4sf x)
-{
- v4sf twopx = v4sfl (2.0f) + x;
- v4sf logterm = vfastlog (twopx);
-
- return (v4sfl (-48.0f) + x * (v4sfl (-157.0f) + x * (v4sfl (-127.0f) - v4sfl (30.0f) * x))) /
- (v4sfl (12.0f) * x * (v4sfl (1.0f) + x) * twopx * twopx)
- + logterm;
-}
-
-void
-vexpdigammify (vw& all, float* gamma)
-{
- unsigned int n = all.lda;
- float extra_sum = 0.0f;
- v4sf sum = v4sfl (0.0f);
- size_t i;
-
- for (i = 0; i < n && ((uintptr_t) (gamma + i)) % 16 > 0; ++i)
- {
- extra_sum += gamma[i];
- gamma[i] = fastdigamma (gamma[i]);
- }
-
- for (; i + 4 < n; i += 4)
- {
- v4sf arg = _mm_load_ps (gamma + i);
- sum += arg;
- arg = vfastdigamma (arg);
- _mm_store_ps (gamma + i, arg);
- }
-
- for (; i < n; ++i)
- {
- extra_sum += gamma[i];
- gamma[i] = fastdigamma (gamma[i]);
- }
-
- extra_sum += v4sf_index (sum, 0) + v4sf_index (sum, 1) +
- v4sf_index (sum, 2) + v4sf_index (sum, 3);
- extra_sum = fastdigamma (extra_sum);
- sum = v4sfl (extra_sum);
-
- for (i = 0; i < n && ((uintptr_t) (gamma + i)) % 16 > 0; ++i)
- {
- gamma[i] = fmaxf (1e-10f, fastexp (gamma[i] - v4sf_index (sum, 0)));
- }
-
- for (; i + 4 < n; i += 4)
- {
- v4sf arg = _mm_load_ps (gamma + i);
- arg -= sum;
- arg = vfastexp (arg);
- arg = _mm_max_ps (v4sfl (1e-10f), arg);
- _mm_store_ps (gamma + i, arg);
- }
-
- for (; i < n; ++i)
- {
- gamma[i] = fmaxf (1e-10f, fastexp (gamma[i] - v4sf_index (sum, 0)));
- }
-}
-
-void vexpdigammify_2(vw& all, float* gamma, const float* norm)
-{
- size_t n = all.lda;
- size_t i;
-
- for (i = 0; i < n && ((uintptr_t) (gamma + i)) % 16 > 0; ++i)
- {
- gamma[i] = fmaxf (1e-10f, fastexp (fastdigamma (gamma[i]) - norm[i]));
- }
-
- for (; i + 4 < n; i += 4)
- {
- v4sf arg = _mm_load_ps (gamma + i);
- arg = vfastdigamma (arg);
- v4sf vnorm = _mm_loadu_ps (norm + i);
- arg -= vnorm;
- arg = vfastexp (arg);
- arg = _mm_max_ps (v4sfl (1e-10f), arg);
- _mm_store_ps (gamma + i, arg);
- }
-
- for (; i < n; ++i)
- {
- gamma[i] = fmaxf (1e-10f, fastexp (fastdigamma (gamma[i]) - norm[i]));
- }
-}
-
-#define myexpdigammify vexpdigammify
-#define myexpdigammify_2 vexpdigammify_2
-
-#else
-#ifndef _WIN32
-#warning "lda IS NOT using sse instructions"
-#endif
-#define myexpdigammify expdigammify
-#define myexpdigammify_2 expdigammify_2
-
-#endif // __SSE2__
-
-} // end anonymous namespace
-
-#else
-
-#include <boost/math/special_functions/digamma.hpp>
-#include <boost/math/special_functions/gamma.hpp>
-
-using namespace boost::math::policies;
-
-#define mydigamma boost::math::digamma
-#define mylgamma boost::math::lgamma
-#define myexpdigammify expdigammify
-#define myexpdigammify_2 expdigammify_2
-
-#endif // MINEIRO_SPECIAL
-
-float decayfunc(float t, float old_t, float power_t) {
- float result = 1;
- for (float i = old_t+1; i <= t; i += 1)
- result *= (1-powf(i, -power_t));
- return result;
-}
-
-float decayfunc2(float t, float old_t, float power_t)
-{
- float power_t_plus_one = 1.f - power_t;
- float arg = - ( powf(t, power_t_plus_one) -
- powf(old_t, power_t_plus_one));
- return exp ( arg
- / power_t_plus_one);
-}
-
-float decayfunc3(double t, double old_t, double power_t)
-{
- double power_t_plus_one = 1. - power_t;
- double logt = log((float)t);
- double logoldt = log((float)old_t);
- return (float)((old_t / t) * exp((float)(0.5*power_t_plus_one*(-logt*logt + logoldt*logoldt))));
-}
-
-float decayfunc4(double t, double old_t, double power_t)
-{
- if (power_t > 0.99)
- return decayfunc3(t, old_t, power_t);
- else
- return (float)decayfunc2((float)t, (float)old_t, (float)power_t);
-}
-
-void expdigammify(vw& all, float* gamma)
-{
- float sum=0;
- for (size_t i = 0; i<all.lda; i++)
- {
- sum += gamma[i];
- gamma[i] = mydigamma(gamma[i]);
- }
- sum = mydigamma(sum);
- for (size_t i = 0; i<all.lda; i++)
- gamma[i] = fmax(1e-6f, exp(gamma[i] - sum));
-}
-
-void expdigammify_2(vw& all, float* gamma, float* norm)
-{
- for (size_t i = 0; i<all.lda; i++)
- {
- gamma[i] = fmax(1e-6f, exp(mydigamma(gamma[i]) - norm[i]));
- }
-}
-
-float average_diff(vw& all, float* oldgamma, float* newgamma)
-{
- float sum = 0.;
- float normalizer = 0.;
- for (size_t i = 0; i<all.lda; i++) {
- sum += fabsf(oldgamma[i] - newgamma[i]);
- normalizer += newgamma[i];
- }
- return sum / normalizer;
-}
-
-// Returns E_q[log p(\theta)] - E_q[log q(\theta)].
- float theta_kl(vw& all, v_array<float>& Elogtheta, float* gamma)
-{
- float gammasum = 0;
- Elogtheta.erase();
- for (size_t k = 0; k < all.lda; k++) {
- Elogtheta.push_back(mydigamma(gamma[k]));
- gammasum += gamma[k];
- }
- float digammasum = mydigamma(gammasum);
- gammasum = mylgamma(gammasum);
- float kl = -(all.lda*mylgamma(all.lda_alpha));
- kl += mylgamma(all.lda_alpha*all.lda) - gammasum;
- for (size_t k = 0; k < all.lda; k++) {
- Elogtheta[k] -= digammasum;
- kl += (all.lda_alpha - gamma[k]) * Elogtheta[k];
- kl += mylgamma(gamma[k]);
- }
-
- return kl;
-}
-
-float find_cw(vw& all, float* u_for_w, float* v)
-{
- float c_w = 0;
- for (size_t k =0; k<all.lda; k++)
- c_w += u_for_w[k]*v[k];
-
- return 1.f / c_w;
-}
-
- v_array<float> new_gamma = v_init<float>();
- v_array<float> old_gamma = v_init<float>();
-// Returns an estimate of the part of the variational bound that
-// doesn't have to do with beta for the entire corpus for the current
-// setting of lambda based on the document passed in. The value is
-// divided by the total number of words in the document This can be
-// used as a (possibly very noisy) estimate of held-out likelihood.
- float lda_loop(vw& all, v_array<float>& Elogtheta, float* v,weight* weights,example* ec, float power_t)
-{
- new_gamma.erase();
- old_gamma.erase();
-
- for (size_t i = 0; i < all.lda; i++)
- {
- new_gamma.push_back(1.f);
- old_gamma.push_back(0.f);
- }
- size_t num_words =0;
- for (unsigned char* i = ec->indices.begin; i != ec->indices.end; i++)
- num_words += ec->atomics[*i].end - ec->atomics[*i].begin;
-
- float xc_w = 0;
- float score = 0;
- float doc_length = 0;
- do
- {
- memcpy(v,new_gamma.begin,sizeof(float)*all.lda);
- myexpdigammify(all, v);
-
- memcpy(old_gamma.begin,new_gamma.begin,sizeof(float)*all.lda);
- memset(new_gamma.begin,0,sizeof(float)*all.lda);
-
- score = 0;
- size_t word_count = 0;
- doc_length = 0;
- for (unsigned char* i = ec->indices.begin; i != ec->indices.end; i++)
- {
- feature *f = ec->atomics[*i].begin;
- for (; f != ec->atomics[*i].end; f++)
- {
- float* u_for_w = &weights[(f->weight_index&all.reg.weight_mask)+all.lda+1];
- float c_w = find_cw(all, u_for_w,v);
- xc_w = c_w * f->x;
- score += -f->x*log(c_w);
- size_t max_k = all.lda;
- for (size_t k =0; k<max_k; k++) {
- new_gamma[k] += xc_w*u_for_w[k];
- }
- word_count++;
- doc_length += f->x;
- }
- }
- for (size_t k =0; k<all.lda; k++)
- new_gamma[k] = new_gamma[k]*v[k]+all.lda_alpha;
- }
- while (average_diff(all, old_gamma.begin, new_gamma.begin) > all.lda_epsilon);
-
- ec->topic_predictions.erase();
- ec->topic_predictions.resize(all.lda);
- memcpy(ec->topic_predictions.begin,new_gamma.begin,all.lda*sizeof(float));
-
- score += theta_kl(all, Elogtheta, new_gamma.begin);
-
- return score / doc_length;
-}
-
-size_t next_pow2(size_t x) {
- int i = 0;
- x = x > 0 ? x - 1 : 0;
- while (x > 0) {
- x >>= 1;
- i++;
- }
- return ((size_t)1) << i;
-}
-
-void save_load(lda& l, io_buf& model_file, bool read, bool text)
-{
- vw* all = l.all;
- uint32_t length = 1 << all->num_bits;
- uint32_t stride = 1 << all->reg.stride_shift;
-
- if (read)
- {
- initialize_regressor(*all);
- for (size_t j = 0; j < stride*length; j+=stride)
- {
- for (size_t k = 0; k < all->lda; k++) {
- if (all->random_weights) {
- all->reg.weight_vector[j+k] = (float)(-log(frand48()) + 1.0f);
- all->reg.weight_vector[j+k] *= (float)(all->lda_D / all->lda / all->length() * 200);
- }
- }
- all->reg.weight_vector[j+all->lda] = all->initial_t;
- }
- }
-
- if (model_file.files.size() > 0)
- {
- uint32_t i = 0;
- uint32_t text_len;
- char buff[512];
- size_t brw = 1;
- do
- {
- brw = 0;
- size_t K = all->lda;
-
- text_len = sprintf(buff, "%d ", i);
- brw += bin_text_read_write_fixed(model_file,(char *)&i, sizeof (i),
- "", read,
- buff, text_len, text);
- if (brw != 0)
- for (uint32_t k = 0; k < K; k++)
- {
- uint32_t ndx = stride*i+k;
-
- weight* v = &(all->reg.weight_vector[ndx]);
- text_len = sprintf(buff, "%f ", *v + all->lda_rho);
-
- brw += bin_text_read_write_fixed(model_file,(char *)v, sizeof (*v),
- "", read,
- buff, text_len, text);
-
- }
- if (text)
- brw += bin_text_read_write_fixed(model_file,buff,0,
- "", read,
- "\n",1,text);
-
- if (!read)
- i++;
- }
- while ((!read && i < length) || (read && brw >0));
- }
-}
-
- void learn_batch(lda& l)
- {
- if (l.sorted_features.empty()) {
- // This can happen when the socket connection is dropped by the client.
- // If l.sorted_features is empty, then l.sorted_features[0] does not
- // exist, so we should not try to take its address in the beginning of
- // the for loops down there. Since it seems that there's not much to
- // do in this case, we just return.
- for (size_t d = 0; d < l.examples.size(); d++)
- return_simple_example(*l.all, NULL, *l.examples[d]);
- l.examples.erase();
- return;
- }
-
- float eta = -1;
- float minuseta = -1;
-
- if (l.total_lambda.size() == 0)
- {
- for (size_t k = 0; k < l.all->lda; k++)
- l.total_lambda.push_back(0.f);
-
- size_t stride = 1 << l.all->reg.stride_shift;
- for (size_t i =0; i <= l.all->reg.weight_mask;i+=stride)
- for (size_t k = 0; k < l.all->lda; k++)
- l.total_lambda[k] += l.all->reg.weight_vector[i+k];
- }
-
- l.example_t++;
- l.total_new.erase();
- for (size_t k = 0; k < l.all->lda; k++)
- l.total_new.push_back(0.f);
-
- size_t batch_size = l.examples.size();
-
- sort(l.sorted_features.begin(), l.sorted_features.end());
-
- eta = l.all->eta * powf((float)l.example_t, - l.all->power_t);
- minuseta = 1.0f - eta;
- eta *= l.all->lda_D / batch_size;
- l.decay_levels.push_back(l.decay_levels.last() + log(minuseta));
-
- l.digammas.erase();
- float additional = (float)(l.all->length()) * l.all->lda_rho;
- for (size_t i = 0; i<l.all->lda; i++) {
- l.digammas.push_back(mydigamma(l.total_lambda[i] + additional));
- }
-
-
- weight* weights = l.all->reg.weight_vector;
-
- size_t last_weight_index = -1;
- for (index_feature* s = &l.sorted_features[0]; s <= &l.sorted_features.back(); s++)
- {
- if (last_weight_index == s->f.weight_index)
- continue;
- last_weight_index = s->f.weight_index;
- float* weights_for_w = &(weights[s->f.weight_index & l.all->reg.weight_mask]);
- float decay = fmin(1.0, exp(l.decay_levels.end[-2] - l.decay_levels.end[(int)(-1 - l.example_t+weights_for_w[l.all->lda])]));
- float* u_for_w = weights_for_w + l.all->lda+1;
-
- weights_for_w[l.all->lda] = (float)l.example_t;
- for (size_t k = 0; k < l.all->lda; k++)
- {
- weights_for_w[k] *= decay;
- u_for_w[k] = weights_for_w[k] + l.all->lda_rho;
- }
- myexpdigammify_2(*l.all, u_for_w, l.digammas.begin);
- }
-
- for (size_t d = 0; d < batch_size; d++)
- {
- float score = lda_loop(*l.all, l.Elogtheta, &(l.v[d*l.all->lda]), weights, l.examples[d],l.all->power_t);
- if (l.all->audit)
- GD::print_audit_features(*l.all, *l.examples[d]);
- // If the doc is empty, give it loss of 0.
- if (l.doc_lengths[d] > 0) {
- l.all->sd->sum_loss -= score;
- l.all->sd->sum_loss_since_last_dump -= score;
- }
- return_simple_example(*l.all, NULL, *l.examples[d]);
- }
-
- for (index_feature* s = &l.sorted_features[0]; s <= &l.sorted_features.back();)
- {
- index_feature* next = s+1;
- while(next <= &l.sorted_features.back() && next->f.weight_index == s->f.weight_index)
- next++;
-
- float* word_weights = &(weights[s->f.weight_index & l.all->reg.weight_mask]);
- for (size_t k = 0; k < l.all->lda; k++) {
- float new_value = minuseta*word_weights[k];
- word_weights[k] = new_value;
- }
-
- for (; s != next; s++) {
- float* v_s = &(l.v[s->document*l.all->lda]);
- float* u_for_w = &weights[(s->f.weight_index & l.all->reg.weight_mask) + l.all->lda + 1];
- float c_w = eta*find_cw(*l.all, u_for_w, v_s)*s->f.x;
- for (size_t k = 0; k < l.all->lda; k++) {
- float new_value = u_for_w[k]*v_s[k]*c_w;
- l.total_new[k] += new_value;
- word_weights[k] += new_value;
- }
- }
- }
- for (size_t k = 0; k < l.all->lda; k++) {
- l.total_lambda[k] *= minuseta;
- l.total_lambda[k] += l.total_new[k];
- }
-
- l.sorted_features.resize(0);
-
- l.examples.erase();
- l.doc_lengths.erase();
- }
-
- void learn(lda& l, learner& base, example& ec)
- {
- size_t num_ex = l.examples.size();
- l.examples.push_back(&ec);
- l.doc_lengths.push_back(0);
- for (unsigned char* i = ec.indices.begin; i != ec.indices.end; i++) {
- feature* f = ec.atomics[*i].begin;
- for (; f != ec.atomics[*i].end; f++) {
- index_feature temp = {(uint32_t)num_ex, *f};
- l.sorted_features.push_back(temp);
- l.doc_lengths[num_ex] += (int)f->x;
- }
- }
- if (++num_ex == l.all->minibatch)
- learn_batch(l);
- }
-
- // placeholder
- void predict(lda& l, learner& base, example& ec)
- {
- learn(l, base, ec);
- }
-
- void end_pass(lda& l)
- {
- if (l.examples.size())
- learn_batch(l);
- }
-
-void end_examples(lda& l)
-{
- for (size_t i = 0; i < l.all->length(); i++) {
- weight* weights_for_w = & (l.all->reg.weight_vector[i << l.all->reg.stride_shift]);
- float decay = fmin(1.0, exp(l.decay_levels.last() - l.decay_levels.end[(int)(-1- l.example_t +weights_for_w[l.all->lda])]));
- for (size_t k = 0; k < l.all->lda; k++)
- weights_for_w[k] *= decay;
- }
-}
-
- void finish_example(vw& all, lda&, example& ec)
-{}
-
- void finish(lda& ld)
- {
- ld.sorted_features.~vector<index_feature>();
- ld.Elogtheta.delete_v();
- ld.decay_levels.delete_v();
- ld.total_new.delete_v();
- ld.examples.delete_v();
- ld.total_lambda.delete_v();
- ld.doc_lengths.delete_v();
- ld.digammas.delete_v();
- ld.v.delete_v();
- }
-
-learner* setup(vw&all, po::variables_map& vm)
-{
- lda* ld = (lda*)calloc_or_die(1,sizeof(lda));
- ld->sorted_features = vector<index_feature>();
- ld->total_lambda_init = 0;
- ld->all = &all;
- ld->example_t = all.initial_t;
-
- po::options_description lda_opts("LDA options");
- lda_opts.add_options()
- ("lda_alpha", po::value<float>(&all.lda_alpha), "Prior on sparsity of per-document topic weights")
- ("lda_rho", po::value<float>(&all.lda_rho), "Prior on sparsity of topic distributions")
- ("lda_D", po::value<float>(&all.lda_D), "Number of documents")
- ("lda_epsilon", po::value<float>(&all.lda_epsilon), "Loop convergence threshold")
- ("minibatch", po::value<size_t>(&all.minibatch), "Minibatch size, for LDA");
-
- vm = add_options(all, lda_opts);
-
- float temp = ceilf(logf((float)(all.lda*2+1)) / logf (2.f));
- all.reg.stride_shift = (size_t)temp;
- all.random_weights = true;
- all.add_constant = false;
-
- std::stringstream ss;
- ss << " --lda " << all.lda;
- all.file_options.append(ss.str());
-
- if (all.eta > 1.)
- {
- cerr << "your learning rate is too high, setting it to 1" << endl;
- all.eta = min(all.eta,1.f);
- }
-
- if (vm.count("minibatch")) {
- size_t minibatch2 = next_pow2(all.minibatch);
- all.p->ring_size = all.p->ring_size > minibatch2 ? all.p->ring_size : minibatch2;
- }
-
- ld->v.resize(all.lda*all.minibatch);
-
- ld->decay_levels.push_back(0.f);
-
- learner* l = new learner(ld, 1 << all.reg.stride_shift);
- l->set_learn<lda,learn>();
- l->set_predict<lda,predict>();
- l->set_save_load<lda,save_load>();
- l->set_finish_example<lda,finish_example>();
- l->set_end_examples<lda,end_examples>();
- l->set_end_pass<lda,end_pass>();
- l->set_finish<lda,finish>();
-
- return l;
-}
-}
+/* +Copyright (c) by respective owners including Yahoo!, Microsoft, and +individual contributors. All rights reserved. Released under a BSD (revised) +license as described in the file LICENSE. + */ +#include <fstream> +#include <vector> +#include <float.h> +#ifdef _WIN32 +#include <winsock2.h> +#else +#include <netdb.h> +#endif +#include <string.h> +#include <stdio.h> +#include <assert.h> +#include "constant.h" +#include "gd.h" +#include "simple_label.h" +#include "rand48.h" +#include "reductions.h" + +using namespace LEARNER; +using namespace std; + +namespace LDA { + +class index_feature { +public: + uint32_t document; + feature f; + bool operator<(const index_feature b) const { return f.weight_index < b.f.weight_index; } +}; + + struct lda { + v_array<float> Elogtheta; + v_array<float> decay_levels; + v_array<float> total_new; + v_array<example* > examples; + v_array<float> total_lambda; + v_array<int> doc_lengths; + v_array<float> digammas; + v_array<float> v; + vector<index_feature> sorted_features; + + bool total_lambda_init; + + double example_t; + vw* all; + }; + +#ifdef _WIN32 +inline float fmax(float f1, float f2) { return (f1 < f2 ? f2 : f1); } +inline float fmin(float f1, float f2) { return (f1 > f2 ? f2 : f1); } +#endif + +#define MINEIRO_SPECIAL +#ifdef MINEIRO_SPECIAL + +namespace { + +inline float +fastlog2 (float x) +{ + union { float f; uint32_t i; } vx = { x }; + union { uint32_t i; float f; } mx = { (vx.i & 0x007FFFFF) | (0x7e << 23) }; + float y = (float)vx.i; + y *= 1.0f / (float)(1 << 23); + + return + y - 124.22544637f - 1.498030302f * mx.f - 1.72587999f / (0.3520887068f + mx.f); +} + +inline float +fastlog (float x) +{ + return 0.69314718f * fastlog2 (x); +} + +inline float +fastpow2 (float p) +{ + float offset = (p < 0) ? 1.0f : 0.0f; + float clipp = (p < -126) ? 
-126.0f : p; + int w = (int)clipp; + float z = clipp - w + offset; + union { uint32_t i; float f; } v = { (uint32_t)((1 << 23) * (clipp + 121.2740838f + 27.7280233f / (4.84252568f - z) - 1.49012907f * z)) }; + + return v.f; +} + +inline float +fastexp (float p) +{ + return fastpow2 (1.442695040f * p); +} + +inline float +fastpow (float x, + float p) +{ + return fastpow2 (p * fastlog2 (x)); +} + +inline float +fastlgamma (float x) +{ + float logterm = fastlog (x * (1.0f + x) * (2.0f + x)); + float xp3 = 3.0f + x; + + return + -2.081061466f - x + 0.0833333f / xp3 - logterm + (2.5f + x) * fastlog (xp3); +} + +inline float +fastdigamma (float x) +{ + float twopx = 2.0f + x; + float logterm = fastlog (twopx); + + return - (1.0f + 2.0f * x) / (x * (1.0f + x)) + - (13.0f + 6.0f * x) / (12.0f * twopx * twopx) + + logterm; +} + +#define log fastlog +#define exp fastexp +#define powf fastpow +#define mydigamma fastdigamma +#define mylgamma fastlgamma + +#if defined(__SSE2__) && !defined(VW_LDA_NO_SSE) + +#include <emmintrin.h> + +typedef __m128 v4sf; +typedef __m128i v4si; + +#define v4si_to_v4sf _mm_cvtepi32_ps +#define v4sf_to_v4si _mm_cvttps_epi32 + +static inline float +v4sf_index (const v4sf x, + unsigned int i) +{ + union { v4sf f; float array[4]; } tmp = { x }; + + return tmp.array[i]; +} + +static inline const v4sf +v4sfl (float x) +{ + union { float array[4]; v4sf f; } tmp = { { x, x, x, x } }; + + return tmp.f; +} + +static inline const v4si +v4sil (uint32_t x) +{ + uint64_t wide = (((uint64_t) x) << 32) | x; + union { uint64_t array[2]; v4si f; } tmp = { { wide, wide } }; + + return tmp.f; +} + +static inline v4sf +vfastpow2 (const v4sf p) +{ + v4sf ltzero = _mm_cmplt_ps (p, v4sfl (0.0f)); + v4sf offset = _mm_and_ps (ltzero, v4sfl (1.0f)); + v4sf lt126 = _mm_cmplt_ps (p, v4sfl (-126.0f)); + v4sf clipp = _mm_andnot_ps (lt126, p) + _mm_and_ps (lt126, v4sfl (-126.0f)); + v4si w = v4sf_to_v4si (clipp); + v4sf z = clipp - v4si_to_v4sf (w) + offset; + + const v4sf 
c_121_2740838 = v4sfl (121.2740838f); + const v4sf c_27_7280233 = v4sfl (27.7280233f); + const v4sf c_4_84252568 = v4sfl (4.84252568f); + const v4sf c_1_49012907 = v4sfl (1.49012907f); + union { v4si i; v4sf f; } v = { + v4sf_to_v4si ( + v4sfl (1 << 23) * + (clipp + c_121_2740838 + c_27_7280233 / (c_4_84252568 - z) - c_1_49012907 * z) + ) + }; + + return v.f; +} + +inline v4sf +vfastexp (const v4sf p) +{ + const v4sf c_invlog_2 = v4sfl (1.442695040f); + + return vfastpow2 (c_invlog_2 * p); +} + +inline v4sf +vfastlog2 (v4sf x) +{ + union { v4sf f; v4si i; } vx = { x }; + union { v4si i; v4sf f; } mx = { (vx.i & v4sil (0x007FFFFF)) | v4sil (0x3f000000) }; + v4sf y = v4si_to_v4sf (vx.i); + y *= v4sfl (1.1920928955078125e-7f); + + const v4sf c_124_22551499 = v4sfl (124.22551499f); + const v4sf c_1_498030302 = v4sfl (1.498030302f); + const v4sf c_1_725877999 = v4sfl (1.72587999f); + const v4sf c_0_3520087068 = v4sfl (0.3520887068f); + + return y - c_124_22551499 + - c_1_498030302 * mx.f + - c_1_725877999 / (c_0_3520087068 + mx.f); +} + +inline v4sf +vfastlog (v4sf x) +{ + const v4sf c_0_69314718 = v4sfl (0.69314718f); + + return c_0_69314718 * vfastlog2 (x); +} + +inline v4sf +vfastdigamma (v4sf x) +{ + v4sf twopx = v4sfl (2.0f) + x; + v4sf logterm = vfastlog (twopx); + + return (v4sfl (-48.0f) + x * (v4sfl (-157.0f) + x * (v4sfl (-127.0f) - v4sfl (30.0f) * x))) / + (v4sfl (12.0f) * x * (v4sfl (1.0f) + x) * twopx * twopx) + + logterm; +} + +void +vexpdigammify (vw& all, float* gamma) +{ + unsigned int n = all.lda; + float extra_sum = 0.0f; + v4sf sum = v4sfl (0.0f); + size_t i; + + for (i = 0; i < n && ((uintptr_t) (gamma + i)) % 16 > 0; ++i) + { + extra_sum += gamma[i]; + gamma[i] = fastdigamma (gamma[i]); + } + + for (; i + 4 < n; i += 4) + { + v4sf arg = _mm_load_ps (gamma + i); + sum += arg; + arg = vfastdigamma (arg); + _mm_store_ps (gamma + i, arg); + } + + for (; i < n; ++i) + { + extra_sum += gamma[i]; + gamma[i] = fastdigamma (gamma[i]); + } + + extra_sum += 
v4sf_index (sum, 0) + v4sf_index (sum, 1) + + v4sf_index (sum, 2) + v4sf_index (sum, 3); + extra_sum = fastdigamma (extra_sum); + sum = v4sfl (extra_sum); + + for (i = 0; i < n && ((uintptr_t) (gamma + i)) % 16 > 0; ++i) + { + gamma[i] = fmaxf (1e-10f, fastexp (gamma[i] - v4sf_index (sum, 0))); + } + + for (; i + 4 < n; i += 4) + { + v4sf arg = _mm_load_ps (gamma + i); + arg -= sum; + arg = vfastexp (arg); + arg = _mm_max_ps (v4sfl (1e-10f), arg); + _mm_store_ps (gamma + i, arg); + } + + for (; i < n; ++i) + { + gamma[i] = fmaxf (1e-10f, fastexp (gamma[i] - v4sf_index (sum, 0))); + } +} + +void vexpdigammify_2(vw& all, float* gamma, const float* norm) +{ + size_t n = all.lda; + size_t i; + + for (i = 0; i < n && ((uintptr_t) (gamma + i)) % 16 > 0; ++i) + { + gamma[i] = fmaxf (1e-10f, fastexp (fastdigamma (gamma[i]) - norm[i])); + } + + for (; i + 4 < n; i += 4) + { + v4sf arg = _mm_load_ps (gamma + i); + arg = vfastdigamma (arg); + v4sf vnorm = _mm_loadu_ps (norm + i); + arg -= vnorm; + arg = vfastexp (arg); + arg = _mm_max_ps (v4sfl (1e-10f), arg); + _mm_store_ps (gamma + i, arg); + } + + for (; i < n; ++i) + { + gamma[i] = fmaxf (1e-10f, fastexp (fastdigamma (gamma[i]) - norm[i])); + } +} + +#define myexpdigammify vexpdigammify +#define myexpdigammify_2 vexpdigammify_2 + +#else +#ifndef _WIN32 +#warning "lda IS NOT using sse instructions" +#endif +#define myexpdigammify expdigammify +#define myexpdigammify_2 expdigammify_2 + +#endif // __SSE2__ + +} // end anonymous namespace + +#else + +#include <boost/math/special_functions/digamma.hpp> +#include <boost/math/special_functions/gamma.hpp> + +using namespace boost::math::policies; + +#define mydigamma boost::math::digamma +#define mylgamma boost::math::lgamma +#define myexpdigammify expdigammify +#define myexpdigammify_2 expdigammify_2 + +#endif // MINEIRO_SPECIAL + +float decayfunc(float t, float old_t, float power_t) { + float result = 1; + for (float i = old_t+1; i <= t; i += 1) + result *= (1-powf(i, 
-power_t)); + return result; +} + +float decayfunc2(float t, float old_t, float power_t) +{ + float power_t_plus_one = 1.f - power_t; + float arg = - ( powf(t, power_t_plus_one) - + powf(old_t, power_t_plus_one)); + return exp ( arg + / power_t_plus_one); +} + +float decayfunc3(double t, double old_t, double power_t) +{ + double power_t_plus_one = 1. - power_t; + double logt = log((float)t); + double logoldt = log((float)old_t); + return (float)((old_t / t) * exp((float)(0.5*power_t_plus_one*(-logt*logt + logoldt*logoldt)))); +} + +float decayfunc4(double t, double old_t, double power_t) +{ + if (power_t > 0.99) + return decayfunc3(t, old_t, power_t); + else + return (float)decayfunc2((float)t, (float)old_t, (float)power_t); +} + +void expdigammify(vw& all, float* gamma) +{ + float sum=0; + for (size_t i = 0; i<all.lda; i++) + { + sum += gamma[i]; + gamma[i] = mydigamma(gamma[i]); + } + sum = mydigamma(sum); + for (size_t i = 0; i<all.lda; i++) + gamma[i] = fmax(1e-6f, exp(gamma[i] - sum)); +} + +void expdigammify_2(vw& all, float* gamma, float* norm) +{ + for (size_t i = 0; i<all.lda; i++) + { + gamma[i] = fmax(1e-6f, exp(mydigamma(gamma[i]) - norm[i])); + } +} + +float average_diff(vw& all, float* oldgamma, float* newgamma) +{ + float sum = 0.; + float normalizer = 0.; + for (size_t i = 0; i<all.lda; i++) { + sum += fabsf(oldgamma[i] - newgamma[i]); + normalizer += newgamma[i]; + } + return sum / normalizer; +} + +// Returns E_q[log p(\theta)] - E_q[log q(\theta)]. 
+ float theta_kl(vw& all, v_array<float>& Elogtheta, float* gamma) +{ + float gammasum = 0; + Elogtheta.erase(); + for (size_t k = 0; k < all.lda; k++) { + Elogtheta.push_back(mydigamma(gamma[k])); + gammasum += gamma[k]; + } + float digammasum = mydigamma(gammasum); + gammasum = mylgamma(gammasum); + float kl = -(all.lda*mylgamma(all.lda_alpha)); + kl += mylgamma(all.lda_alpha*all.lda) - gammasum; + for (size_t k = 0; k < all.lda; k++) { + Elogtheta[k] -= digammasum; + kl += (all.lda_alpha - gamma[k]) * Elogtheta[k]; + kl += mylgamma(gamma[k]); + } + + return kl; +} + +float find_cw(vw& all, float* u_for_w, float* v) +{ + float c_w = 0; + for (size_t k =0; k<all.lda; k++) + c_w += u_for_w[k]*v[k]; + + return 1.f / c_w; +} + + v_array<float> new_gamma = v_init<float>(); + v_array<float> old_gamma = v_init<float>(); +// Returns an estimate of the part of the variational bound that +// doesn't have to do with beta for the entire corpus for the current +// setting of lambda based on the document passed in. The value is +// divided by the total number of words in the document This can be +// used as a (possibly very noisy) estimate of held-out likelihood. 
+ float lda_loop(vw& all, v_array<float>& Elogtheta, float* v,weight* weights,example* ec, float power_t) +{ + new_gamma.erase(); + old_gamma.erase(); + + for (size_t i = 0; i < all.lda; i++) + { + new_gamma.push_back(1.f); + old_gamma.push_back(0.f); + } + size_t num_words =0; + for (unsigned char* i = ec->indices.begin; i != ec->indices.end; i++) + num_words += ec->atomics[*i].end - ec->atomics[*i].begin; + + float xc_w = 0; + float score = 0; + float doc_length = 0; + do + { + memcpy(v,new_gamma.begin,sizeof(float)*all.lda); + myexpdigammify(all, v); + + memcpy(old_gamma.begin,new_gamma.begin,sizeof(float)*all.lda); + memset(new_gamma.begin,0,sizeof(float)*all.lda); + + score = 0; + size_t word_count = 0; + doc_length = 0; + for (unsigned char* i = ec->indices.begin; i != ec->indices.end; i++) + { + feature *f = ec->atomics[*i].begin; + for (; f != ec->atomics[*i].end; f++) + { + float* u_for_w = &weights[(f->weight_index&all.reg.weight_mask)+all.lda+1]; + float c_w = find_cw(all, u_for_w,v); + xc_w = c_w * f->x; + score += -f->x*log(c_w); + size_t max_k = all.lda; + for (size_t k =0; k<max_k; k++) { + new_gamma[k] += xc_w*u_for_w[k]; + } + word_count++; + doc_length += f->x; + } + } + for (size_t k =0; k<all.lda; k++) + new_gamma[k] = new_gamma[k]*v[k]+all.lda_alpha; + } + while (average_diff(all, old_gamma.begin, new_gamma.begin) > all.lda_epsilon); + + ec->topic_predictions.erase(); + ec->topic_predictions.resize(all.lda); + memcpy(ec->topic_predictions.begin,new_gamma.begin,all.lda*sizeof(float)); + + score += theta_kl(all, Elogtheta, new_gamma.begin); + + return score / doc_length; +} + +size_t next_pow2(size_t x) { + int i = 0; + x = x > 0 ? 
x - 1 : 0; + while (x > 0) { + x >>= 1; + i++; + } + return ((size_t)1) << i; +} + +void save_load(lda& l, io_buf& model_file, bool read, bool text) +{ + vw* all = l.all; + uint32_t length = 1 << all->num_bits; + uint32_t stride = 1 << all->reg.stride_shift; + + if (read) + { + initialize_regressor(*all); + for (size_t j = 0; j < stride*length; j+=stride) + { + for (size_t k = 0; k < all->lda; k++) { + if (all->random_weights) { + all->reg.weight_vector[j+k] = (float)(-log(frand48()) + 1.0f); + all->reg.weight_vector[j+k] *= (float)(all->lda_D / all->lda / all->length() * 200); + } + } + all->reg.weight_vector[j+all->lda] = all->initial_t; + } + } + + if (model_file.files.size() > 0) + { + uint32_t i = 0; + uint32_t text_len; + char buff[512]; + size_t brw = 1; + do + { + brw = 0; + size_t K = all->lda; + + text_len = sprintf(buff, "%d ", i); + brw += bin_text_read_write_fixed(model_file,(char *)&i, sizeof (i), + "", read, + buff, text_len, text); + if (brw != 0) + for (uint32_t k = 0; k < K; k++) + { + uint32_t ndx = stride*i+k; + + weight* v = &(all->reg.weight_vector[ndx]); + text_len = sprintf(buff, "%f ", *v + all->lda_rho); + + brw += bin_text_read_write_fixed(model_file,(char *)v, sizeof (*v), + "", read, + buff, text_len, text); + + } + if (text) + brw += bin_text_read_write_fixed(model_file,buff,0, + "", read, + "\n",1,text); + + if (!read) + i++; + } + while ((!read && i < length) || (read && brw >0)); + } +} + + void learn_batch(lda& l) + { + if (l.sorted_features.empty()) { + // This can happen when the socket connection is dropped by the client. + // If l.sorted_features is empty, then l.sorted_features[0] does not + // exist, so we should not try to take its address in the beginning of + // the for loops down there. Since it seems that there's not much to + // do in this case, we just return. 
+ for (size_t d = 0; d < l.examples.size(); d++) + return_simple_example(*l.all, NULL, *l.examples[d]); + l.examples.erase(); + return; + } + + float eta = -1; + float minuseta = -1; + + if (l.total_lambda.size() == 0) + { + for (size_t k = 0; k < l.all->lda; k++) + l.total_lambda.push_back(0.f); + + size_t stride = 1 << l.all->reg.stride_shift; + for (size_t i =0; i <= l.all->reg.weight_mask;i+=stride) + for (size_t k = 0; k < l.all->lda; k++) + l.total_lambda[k] += l.all->reg.weight_vector[i+k]; + } + + l.example_t++; + l.total_new.erase(); + for (size_t k = 0; k < l.all->lda; k++) + l.total_new.push_back(0.f); + + size_t batch_size = l.examples.size(); + + sort(l.sorted_features.begin(), l.sorted_features.end()); + + eta = l.all->eta * powf((float)l.example_t, - l.all->power_t); + minuseta = 1.0f - eta; + eta *= l.all->lda_D / batch_size; + l.decay_levels.push_back(l.decay_levels.last() + log(minuseta)); + + l.digammas.erase(); + float additional = (float)(l.all->length()) * l.all->lda_rho; + for (size_t i = 0; i<l.all->lda; i++) { + l.digammas.push_back(mydigamma(l.total_lambda[i] + additional)); + } + + + weight* weights = l.all->reg.weight_vector; + + size_t last_weight_index = -1; + for (index_feature* s = &l.sorted_features[0]; s <= &l.sorted_features.back(); s++) + { + if (last_weight_index == s->f.weight_index) + continue; + last_weight_index = s->f.weight_index; + float* weights_for_w = &(weights[s->f.weight_index & l.all->reg.weight_mask]); + float decay = fmin(1.0, exp(l.decay_levels.end[-2] - l.decay_levels.end[(int)(-1 - l.example_t+weights_for_w[l.all->lda])])); + float* u_for_w = weights_for_w + l.all->lda+1; + + weights_for_w[l.all->lda] = (float)l.example_t; + for (size_t k = 0; k < l.all->lda; k++) + { + weights_for_w[k] *= decay; + u_for_w[k] = weights_for_w[k] + l.all->lda_rho; + } + myexpdigammify_2(*l.all, u_for_w, l.digammas.begin); + } + + for (size_t d = 0; d < batch_size; d++) + { + float score = lda_loop(*l.all, l.Elogtheta, 
&(l.v[d*l.all->lda]), weights, l.examples[d],l.all->power_t); + if (l.all->audit) + GD::print_audit_features(*l.all, *l.examples[d]); + // If the doc is empty, give it loss of 0. + if (l.doc_lengths[d] > 0) { + l.all->sd->sum_loss -= score; + l.all->sd->sum_loss_since_last_dump -= score; + } + return_simple_example(*l.all, NULL, *l.examples[d]); + } + + for (index_feature* s = &l.sorted_features[0]; s <= &l.sorted_features.back();) + { + index_feature* next = s+1; + while(next <= &l.sorted_features.back() && next->f.weight_index == s->f.weight_index) + next++; + + float* word_weights = &(weights[s->f.weight_index & l.all->reg.weight_mask]); + for (size_t k = 0; k < l.all->lda; k++) { + float new_value = minuseta*word_weights[k]; + word_weights[k] = new_value; + } + + for (; s != next; s++) { + float* v_s = &(l.v[s->document*l.all->lda]); + float* u_for_w = &weights[(s->f.weight_index & l.all->reg.weight_mask) + l.all->lda + 1]; + float c_w = eta*find_cw(*l.all, u_for_w, v_s)*s->f.x; + for (size_t k = 0; k < l.all->lda; k++) { + float new_value = u_for_w[k]*v_s[k]*c_w; + l.total_new[k] += new_value; + word_weights[k] += new_value; + } + } + } + for (size_t k = 0; k < l.all->lda; k++) { + l.total_lambda[k] *= minuseta; + l.total_lambda[k] += l.total_new[k]; + } + + l.sorted_features.resize(0); + + l.examples.erase(); + l.doc_lengths.erase(); + } + + void learn(lda& l, learner& base, example& ec) + { + size_t num_ex = l.examples.size(); + l.examples.push_back(&ec); + l.doc_lengths.push_back(0); + for (unsigned char* i = ec.indices.begin; i != ec.indices.end; i++) { + feature* f = ec.atomics[*i].begin; + for (; f != ec.atomics[*i].end; f++) { + index_feature temp = {(uint32_t)num_ex, *f}; + l.sorted_features.push_back(temp); + l.doc_lengths[num_ex] += (int)f->x; + } + } + if (++num_ex == l.all->minibatch) + learn_batch(l); + } + + // placeholder + void predict(lda& l, learner& base, example& ec) + { + learn(l, base, ec); + } + + void end_pass(lda& l) + { + if 
(l.examples.size()) + learn_batch(l); + } + +void end_examples(lda& l) +{ + for (size_t i = 0; i < l.all->length(); i++) { + weight* weights_for_w = & (l.all->reg.weight_vector[i << l.all->reg.stride_shift]); + float decay = fmin(1.0, exp(l.decay_levels.last() - l.decay_levels.end[(int)(-1- l.example_t +weights_for_w[l.all->lda])])); + for (size_t k = 0; k < l.all->lda; k++) + weights_for_w[k] *= decay; + } +} + + void finish_example(vw& all, lda&, example& ec) +{} + + void finish(lda& ld) + { + ld.sorted_features.~vector<index_feature>(); + ld.Elogtheta.delete_v(); + ld.decay_levels.delete_v(); + ld.total_new.delete_v(); + ld.examples.delete_v(); + ld.total_lambda.delete_v(); + ld.doc_lengths.delete_v(); + ld.digammas.delete_v(); + ld.v.delete_v(); + } + +learner* setup(vw&all, po::variables_map& vm) +{ + lda* ld = (lda*)calloc_or_die(1,sizeof(lda)); + ld->sorted_features = vector<index_feature>(); + ld->total_lambda_init = 0; + ld->all = &all; + ld->example_t = all.initial_t; + + po::options_description lda_opts("LDA options"); + lda_opts.add_options() + ("lda_alpha", po::value<float>(&all.lda_alpha), "Prior on sparsity of per-document topic weights") + ("lda_rho", po::value<float>(&all.lda_rho), "Prior on sparsity of topic distributions") + ("lda_D", po::value<float>(&all.lda_D), "Number of documents") + ("lda_epsilon", po::value<float>(&all.lda_epsilon), "Loop convergence threshold") + ("minibatch", po::value<size_t>(&all.minibatch), "Minibatch size, for LDA"); + + vm = add_options(all, lda_opts); + + float temp = ceilf(logf((float)(all.lda*2+1)) / logf (2.f)); + all.reg.stride_shift = (size_t)temp; + all.random_weights = true; + all.add_constant = false; + + std::stringstream ss; + ss << " --lda " << all.lda; + all.file_options.append(ss.str()); + + if (all.eta > 1.) 
+ { + cerr << "your learning rate is too high, setting it to 1" << endl; + all.eta = min(all.eta,1.f); + } + + if (vm.count("minibatch")) { + size_t minibatch2 = next_pow2(all.minibatch); + all.p->ring_size = all.p->ring_size > minibatch2 ? all.p->ring_size : minibatch2; + } + + ld->v.resize(all.lda*all.minibatch); + + ld->decay_levels.push_back(0.f); + + learner* l = new learner(ld, 1 << all.reg.stride_shift); + l->set_learn<lda,learn>(); + l->set_predict<lda,predict>(); + l->set_save_load<lda,save_load>(); + l->set_finish_example<lda,finish_example>(); + l->set_end_examples<lda,end_examples>(); + l->set_end_pass<lda,end_pass>(); + l->set_finish<lda,finish>(); + + return l; +} +} diff --git a/vowpalwabbit/log_multi.cc b/vowpalwabbit/log_multi.cc index 68f52f06..bfc25288 100644 --- a/vowpalwabbit/log_multi.cc +++ b/vowpalwabbit/log_multi.cc @@ -1,549 +1,549 @@ -/*\t
-
-Copyright (c) by respective owners including Yahoo!, Microsoft, and
-individual contributors. All rights reserved. Released under a BSD (revised)
-license as described in the file LICENSE.node
-*/
-#include <float.h>
-#include <math.h>
-#include <stdio.h>
-#include <sstream>
-
-#include "reductions.h"
-#include "simple_label.h"
-#include "multiclass.h"
-#include "vw.h"
-
-using namespace std;
-using namespace LEARNER;
-
-namespace LOG_MULTI
-{
- class node_pred
- {
- public:
-
- double Ehk;
- float norm_Ehk;
- uint32_t nk;
- uint32_t label;
- uint32_t label_count;
-
- bool operator==(node_pred v){
- return (label == v.label);
- }
-
- bool operator>(node_pred v){
- if(label > v.label) return true;
- return false;
- }
-
- bool operator<(node_pred v){
- if(label < v.label) return true;
- return false;
- }
-
- node_pred(uint32_t l)
- {
- label = l;
- Ehk = 0.f;
- norm_Ehk = 0;
- nk = 0;
- label_count = 0;
- }
- };
-
- typedef struct
- {//everyone has
- uint32_t parent;//the parent node
- v_array<node_pred> preds;//per-class state
- uint32_t min_count;//the number of examples reaching this node (if it's a leaf) or the minimum reaching any grandchild.
-
- bool internal;//internal or leaf
-
- //internal nodes have
- uint32_t base_predictor;//id of the base predictor
- uint32_t left;//left child
- uint32_t right;//right child
- float norm_Eh;//the average margin at the node
- double Eh;//total margin at the node
- uint32_t n;//total events at the node
-
- //leaf has
- uint32_t max_count;//the number of samples of the most common label
- uint32_t max_count_label;//the most common label
- } node;
-
- struct log_multi
- {
- uint32_t k;
- vw* all;
-
- v_array<node> nodes;
-
- uint32_t max_predictors;
- uint32_t predictors_used;
-
- bool progress;
- uint32_t swap_resist;
-
- uint32_t nbofswaps;
- };
-
- inline void init_leaf(node& n)
- {
- n.internal = false;
- n.preds.erase();
- n.base_predictor = 0;
- n.norm_Eh = 0;
- n.Eh = 0;
- n.n = 0;
- n.max_count = 0;
- n.max_count_label = 1;
- n.left = 0;
- n.right = 0;
- }
-
- inline node init_node()
- {
- node node;
-
- node.parent = 0;
- node.min_count = 0;
- node.preds = v_init<node_pred>();
- init_leaf(node);
-
- return node;
- }
-
- void init_tree(log_multi& d)
- {
- d.nodes.push_back(init_node());
- d.nbofswaps = 0;
- }
-
- inline uint32_t min_left_right(log_multi& b, node& n)
- {
- return min(b.nodes[n.left].min_count, b.nodes[n.right].min_count);
- }
-
- inline uint32_t find_switch_node(log_multi& b)
- {
- uint32_t node = 0;
- while(b.nodes[node].internal)
- if(b.nodes[b.nodes[node].left].min_count
- < b.nodes[b.nodes[node].right].min_count)
- node = b.nodes[node].left;
- else
- node = b.nodes[node].right;
- return node;
- }
-
- inline void update_min_count(log_multi& b, uint32_t node)
- {//Constant time min count update.
- while(node != 0)
- {
- uint32_t prev = node;
- node = b.nodes[node].parent;
-
- if (b.nodes[node].min_count == b.nodes[prev].min_count)
- break;
- else
- b.nodes[node].min_count = min_left_right(b,b.nodes[node]);
- }
- }
-
- void display_tree_dfs(log_multi& b, node node, uint32_t depth)
- {
- for (uint32_t i = 0; i < depth; i++)
- cout << "\t";
- cout << node.min_count << " " << node.left
- << " " << node.right;
- cout << " label = " << node.max_count_label << " labels = ";
- for (size_t i = 0; i < node.preds.size(); i++)
- cout << node.preds[i].label << ":" << node.preds[i].label_count << "\t";
- cout << endl;
-
- if (node.internal)
- {
- cout << "Left";
- display_tree_dfs(b, b.nodes[node.left], depth+1);
-
- cout << "Right";
- display_tree_dfs(b, b.nodes[node.right], depth+1);
- }
- }
-
- bool children(log_multi& b, uint32_t& current, uint32_t& class_index, uint32_t label)
- {
- class_index = (uint32_t)b.nodes[current].preds.unique_add_sorted(node_pred(label));
- b.nodes[current].preds[class_index].label_count++;
-
- if(b.nodes[current].preds[class_index].label_count > b.nodes[current].max_count)
- {
- b.nodes[current].max_count = b.nodes[current].preds[class_index].label_count;
- b.nodes[current].max_count_label = b.nodes[current].preds[class_index].label;
- }
-
- if (b.nodes[current].internal)
- return true;
- else if( b.nodes[current].preds.size() > 1
- && (b.predictors_used < b.max_predictors
- || b.nodes[current].min_count - b.nodes[current].max_count > b.swap_resist*(b.nodes[0].min_count + 1)))
- { //need children and we can make them.
- uint32_t left_child;
- uint32_t right_child;
- if (b.predictors_used < b.max_predictors)
- {
- left_child = (uint32_t)b.nodes.size();
- b.nodes.push_back(init_node());
- right_child = (uint32_t)b.nodes.size();
- b.nodes.push_back(init_node());
- b.nodes[current].base_predictor = b.predictors_used++;
- }
- else
- {
- uint32_t swap_child = find_switch_node(b);
- uint32_t swap_parent = b.nodes[swap_child].parent;
- uint32_t swap_grandparent = b.nodes[swap_parent].parent;
- if (b.nodes[swap_child].min_count != b.nodes[0].min_count)
- cout << "glargh " << b.nodes[swap_child].min_count << " != " << b.nodes[0].min_count << endl;
- b.nbofswaps++;
-
- uint32_t nonswap_child;
- if(swap_child == b.nodes[swap_parent].right)
- nonswap_child = b.nodes[swap_parent].left;
- else
- nonswap_child = b.nodes[swap_parent].right;
-
- if(swap_parent == b.nodes[swap_grandparent].left)
- b.nodes[swap_grandparent].left = nonswap_child;
- else
- b.nodes[swap_grandparent].right = nonswap_child;
- b.nodes[nonswap_child].parent = swap_grandparent;
- update_min_count(b, nonswap_child);
-
- init_leaf(b.nodes[swap_child]);
- left_child = swap_child;
- b.nodes[current].base_predictor = b.nodes[swap_parent].base_predictor;
- init_leaf(b.nodes[swap_parent]);
- right_child = swap_parent;
- }
- b.nodes[current].left = left_child;
- b.nodes[left_child].parent = current;
- b.nodes[current].right = right_child;
- b.nodes[right_child].parent = current;
-
- b.nodes[left_child].min_count = b.nodes[current].min_count/2;
- b.nodes[right_child].min_count = b.nodes[current].min_count - b.nodes[left_child].min_count;
- update_min_count(b, left_child);
-
- b.nodes[left_child].max_count_label = b.nodes[current].max_count_label;
- b.nodes[right_child].max_count_label = b.nodes[current].max_count_label;
-
- b.nodes[current].internal = true;
- }
- return b.nodes[current].internal;
- }
-
- void train_node(log_multi& b, learner& base, example& ec, uint32_t& current, uint32_t& class_index)
- {
- if(b.nodes[current].norm_Eh > b.nodes[current].preds[class_index].norm_Ehk)
- ec.l.simple.label = -1.f;
- else
- ec.l.simple.label = 1.f;
-
- base.learn(ec, b.nodes[current].base_predictor);
-
- ec.l.simple.label = FLT_MAX;
- base.predict(ec, b.nodes[current].base_predictor);
-
- b.nodes[current].Eh += (double)ec.partial_prediction;
- b.nodes[current].preds[class_index].Ehk += (double)ec.partial_prediction;
- b.nodes[current].n++;
- b.nodes[current].preds[class_index].nk++;
-
- b.nodes[current].norm_Eh = (float)b.nodes[current].Eh / b.nodes[current].n;
- b.nodes[current].preds[class_index].norm_Ehk = (float)b.nodes[current].preds[class_index].Ehk / b.nodes[current].preds[class_index].nk;
- }
-
- void verify_min_dfs(log_multi& b, node node)
- {
- if (node.internal)
- {
- if (node.min_count != min_left_right(b, node))
- {
- cout << "badness! " << endl;
- display_tree_dfs(b, b.nodes[0], 0);
- }
- verify_min_dfs(b, b.nodes[node.left]);
- verify_min_dfs(b, b.nodes[node.right]);
- }
- }
-
- size_t sum_count_dfs(log_multi& b, node node)
- {
- if (node.internal)
- return sum_count_dfs(b, b.nodes[node.left]) + sum_count_dfs(b, b.nodes[node.right]);
- else
- return node.min_count;
- }
-
- inline uint32_t descend(node& n, float prediction)
- {
- if (prediction < 0)
- return n.left;
- else
- return n.right;
- }
-
- void predict(log_multi& b, learner& base, example& ec)
- {
- MULTICLASS::multiclass mc = ec.l.multi;
-
- label_data simple_temp;
- simple_temp.initial = 0.0;
- simple_temp.weight = 0.0;
- simple_temp.label = FLT_MAX;
- ec.l.simple = simple_temp;
- uint32_t cn = 0;
- while(b.nodes[cn].internal)
- {
- base.predict(ec, b.nodes[cn].base_predictor);
- cn = descend(b.nodes[cn], ec.pred.scalar);
- }
- ec.pred.multiclass = b.nodes[cn].max_count_label;
- ec.l.multi = mc;
- }
-
- void learn(log_multi& b, learner& base, example& ec)
- {
- // verify_min_dfs(b, b.nodes[0]);
-
- if (ec.l.multi.label == (uint32_t)-1 || !b.all->training || b.progress)
- predict(b,base,ec);
-
- if(b.all->training && (ec.l.multi.label != (uint32_t)-1) && !ec.test_only) //if training the tree
- {
- MULTICLASS::multiclass mc = ec.l.multi;
-
- uint32_t class_index = 0;
- label_data simple_temp;
- simple_temp.initial = 0.0;
- simple_temp.weight = mc.weight;
- ec.l.simple = simple_temp;
-
- uint32_t cn = 0;
-
- while(children(b, cn, class_index, mc.label))
- {
- train_node(b, base, ec, cn, class_index);
- cn = descend(b.nodes[cn], ec.pred.scalar);
- }
-
- b.nodes[cn].min_count++;
- update_min_count(b, cn);
-
- ec.l.multi = mc;
- }
- }
-
- void save_node_stats(log_multi& d)
- {
- FILE *fp;
- uint32_t i, j;
- uint32_t total;
- log_multi* b = &d;
-
- fp = fopen("atxm_debug.csv", "wt");
-
- for(i = 0; i < b->nodes.size(); i++)
- {
- fprintf(fp, "Node: %4d, Internal: %1d, Eh: %7.4f, n: %6d, \n", (int) i, (int) b->nodes[i].internal, b->nodes[i].Eh / b->nodes[i].n, b->nodes[i].n);
-
- fprintf(fp, "Label:, ");
- for(j = 0; j < b->nodes[i].preds.size(); j++)
- {
- fprintf(fp, "%6d,", (int) b->nodes[i].preds[j].label);
- }
- fprintf(fp, "\n");
-
- fprintf(fp, "Ehk:, ");
- for(j = 0; j < b->nodes[i].preds.size(); j++)
- {
- fprintf(fp, "%7.4f,", b->nodes[i].preds[j].Ehk / b->nodes[i].preds[j].nk);
- }
- fprintf(fp, "\n");
-
- total = 0;
-
- fprintf(fp, "nk:, ");
- for(j = 0; j < b->nodes[i].preds.size(); j++)
- {
- fprintf(fp, "%6d,", (int) b->nodes[i].preds[j].nk);
- total += b->nodes[i].preds[j].nk;
- }
- fprintf(fp, "\n");
-
- fprintf(fp, "max(lab:cnt:tot):, %3d,%6d,%7d,\n", (int) b->nodes[i].max_count_label, (int) b->nodes[i].max_count, (int) total);
- fprintf(fp, "left: %4d, right: %4d", (int) b->nodes[i].left, (int) b->nodes[i].right);
- fprintf(fp, "\n\n");
- }
-
- fclose(fp);
- }
-
- void finish(log_multi& b)
- {
- save_node_stats(b);
- cout << "used " << b.nbofswaps << " swaps" << endl;
- }
-
- void save_load_tree(log_multi& b, io_buf& model_file, bool read, bool text)
- {
- if (model_file.files.size() > 0)
- {
- char buff[512];
-
- uint32_t text_len = sprintf(buff, "k = %d ",b.k);
- bin_text_read_write_fixed(model_file,(char*)&b.max_predictors, sizeof(b.k), "", read, buff, text_len, text);
- uint32_t temp = (uint32_t)b.nodes.size();
- text_len = sprintf(buff, "nodes = %d ",temp);
- bin_text_read_write_fixed(model_file,(char*)&temp, sizeof(temp), "", read, buff, text_len, text);
- if (read)
- for (uint32_t j = 1; j < temp; j++)
- b.nodes.push_back(init_node());
- text_len = sprintf(buff, "max_predictors = %d ",b.max_predictors);
- bin_text_read_write_fixed(model_file,(char*)&b.max_predictors, sizeof(b.max_predictors), "", read, buff, text_len, text);
-
- text_len = sprintf(buff, "predictors_used = %d ",b.predictors_used);
- bin_text_read_write_fixed(model_file,(char*)&b.predictors_used, sizeof(b.predictors_used), "", read, buff, text_len, text);
-
- text_len = sprintf(buff, "progress = %d ",b.progress);
- bin_text_read_write_fixed(model_file,(char*)&b.progress, sizeof(b.progress), "", read, buff, text_len, text);
-
- text_len = sprintf(buff, "swap_resist = %d\n",b.swap_resist);
- bin_text_read_write_fixed(model_file,(char*)&b.swap_resist, sizeof(b.swap_resist), "", read, buff, text_len, text);
-
- for (size_t j = 0; j < b.nodes.size(); j++)
- {//Need to read or write nodes.
- node& n = b.nodes[j];
- text_len = sprintf(buff, " parent = %d",n.parent);
- bin_text_read_write_fixed(model_file,(char*)&n.parent, sizeof(n.parent), "", read, buff, text_len, text);
-
- uint32_t temp = (uint32_t)n.preds.size();
- text_len = sprintf(buff, " preds = %d",temp);
- bin_text_read_write_fixed(model_file,(char*)&temp, sizeof(temp), "", read, buff, text_len, text);
- if (read)
- for (uint32_t k = 0; k < temp; k++)
- n.preds.push_back(node_pred(1));
-
- text_len = sprintf(buff, " min_count = %d",n.min_count);
- bin_text_read_write_fixed(model_file,(char*)&n.min_count, sizeof(n.min_count), "", read, buff, text_len, text);
-
- uint32_t text_len = sprintf(buff, " internal = %d",n.internal);
- bin_text_read_write_fixed(model_file,(char*)&n.internal, sizeof(n.internal), "", read, buff, text_len, text)
-;
-
- if (n.internal)
- {
- text_len = sprintf(buff, " base_predictor = %d",n.base_predictor);
- bin_text_read_write_fixed(model_file,(char*)&n.base_predictor, sizeof(n.base_predictor), "", read, buff, text_len, text);
-
- text_len = sprintf(buff, " left = %d",n.left);
- bin_text_read_write_fixed(model_file,(char*)&n.left, sizeof(n.left), "", read, buff, text_len, text);
-
- text_len = sprintf(buff, " right = %d",n.right);
- bin_text_read_write_fixed(model_file,(char*)&n.right, sizeof(n.right), "", read, buff, text_len, text);
-
- text_len = sprintf(buff, " norm_Eh = %f",n.norm_Eh);
- bin_text_read_write_fixed(model_file,(char*)&n.norm_Eh, sizeof(n.norm_Eh), "", read, buff, text_len, text);
-
- text_len = sprintf(buff, " Eh = %f",n.Eh);
- bin_text_read_write_fixed(model_file,(char*)&n.Eh, sizeof(n.Eh), "", read, buff, text_len, text);
-
- text_len = sprintf(buff, " n = %d\n",n.n);
- bin_text_read_write_fixed(model_file,(char*)&n.n, sizeof(n.n), "", read, buff, text_len, text);
- }
- else
- {
- text_len = sprintf(buff, " max_count = %d",n.max_count);
- bin_text_read_write_fixed(model_file,(char*)&n.max_count, sizeof(n.max_count), "", read, buff, text_len, text);
- text_len = sprintf(buff, " max_count_label = %d\n",n.max_count_label);
- bin_text_read_write_fixed(model_file,(char*)&n.max_count_label, sizeof(n.max_count_label), "", read, buff, text_len, text);
- }
-
- for (size_t k = 0; k < n.preds.size(); k++)
- {
- node_pred& p = n.preds[k];
-
- text_len = sprintf(buff, " Ehk = %f",p.Ehk);
- bin_text_read_write_fixed(model_file,(char*)&p.Ehk, sizeof(p.Ehk), "", read, buff, text_len, text);
-
- text_len = sprintf(buff, " norm_Ehk = %f",p.norm_Ehk);
- bin_text_read_write_fixed(model_file,(char*)&p.norm_Ehk, sizeof(p.norm_Ehk), "", read, buff, text_len, text);
-
- text_len = sprintf(buff, " nk = %d",p.nk);
- bin_text_read_write_fixed(model_file,(char*)&p.nk, sizeof(p.nk), "", read, buff, text_len, text);
-
- text_len = sprintf(buff, " label = %d",p.label);
- bin_text_read_write_fixed(model_file,(char*)&p.label, sizeof(p.label), "", read, buff, text_len, text);
-
- text_len = sprintf(buff, " label_count = %d\n",p.label_count);
- bin_text_read_write_fixed(model_file,(char*)&p.label_count, sizeof(p.label_count), "", read, buff, text_len, text);
- }
- }
- }
- }
-
- void finish_example(vw& all, log_multi&, example& ec)
- {
- MULTICLASS::output_example(all, ec);
- VW::finish_example(all, &ec);
- }
-
- learner* setup(vw& all, po::variables_map& vm) //learner setup
- {
- log_multi* data = (log_multi*)calloc(1, sizeof(log_multi));
-
- po::options_description opts("TXM Online options");
- opts.add_options()
- ("no_progress", "disable progressive validation")
- ("swap_resistance", po::value<uint32_t>(&(data->swap_resist))->default_value(4), "higher = more resistance to swap, default=4");
-
- vm = add_options(all, opts);
-
- data->k = (uint32_t)vm["log_multi"].as<size_t>();
-
- //append log_multi with nb_actions to options_from_file so it is saved to regressor later
- std::stringstream ss;
- ss << " --log_multi " << data->k;
- all.file_options.append(ss.str());
-
- if (vm.count("no_progress"))
- data->progress = false;
- else
- data->progress = true;
-
- data->all = &all;
- (all.p->lp) = MULTICLASS::mc_label;
-
- string loss_function = "quantile";
- float loss_parameter = 0.5;
- delete(all.loss);
- all.loss = getLossFunction(&all, loss_function, loss_parameter);
-
- data->max_predictors = data->k - 1;
-
- learner* l = new learner(data, all.l, data->max_predictors);
- l->set_save_load<log_multi,save_load_tree>();
- l->set_learn<log_multi,learn>();
- l->set_predict<log_multi,predict>();
- l->set_finish_example<log_multi,finish_example>();
- l->set_finish<log_multi,finish>();
-
- init_tree(*data);
-
- return l;
- }
-}
+/*\t + +Copyright (c) by respective owners including Yahoo!, Microsoft, and +individual contributors. All rights reserved. Released under a BSD (revised) +license as described in the file LICENSE.node +*/ +#include <float.h> +#include <math.h> +#include <stdio.h> +#include <sstream> + +#include "reductions.h" +#include "simple_label.h" +#include "multiclass.h" +#include "vw.h" + +using namespace std; +using namespace LEARNER; + +namespace LOG_MULTI +{ + class node_pred + { + public: + + double Ehk; + float norm_Ehk; + uint32_t nk; + uint32_t label; + uint32_t label_count; + + bool operator==(node_pred v){ + return (label == v.label); + } + + bool operator>(node_pred v){ + if(label > v.label) return true; + return false; + } + + bool operator<(node_pred v){ + if(label < v.label) return true; + return false; + } + + node_pred(uint32_t l) + { + label = l; + Ehk = 0.f; + norm_Ehk = 0; + nk = 0; + label_count = 0; + } + }; + + typedef struct + {//everyone has + uint32_t parent;//the parent node + v_array<node_pred> preds;//per-class state + uint32_t min_count;//the number of examples reaching this node (if it's a leaf) or the minimum reaching any grandchild. 
+ + bool internal;//internal or leaf + + //internal nodes have + uint32_t base_predictor;//id of the base predictor + uint32_t left;//left child + uint32_t right;//right child + float norm_Eh;//the average margin at the node + double Eh;//total margin at the node + uint32_t n;//total events at the node + + //leaf has + uint32_t max_count;//the number of samples of the most common label + uint32_t max_count_label;//the most common label + } node; + + struct log_multi + { + uint32_t k; + vw* all; + + v_array<node> nodes; + + uint32_t max_predictors; + uint32_t predictors_used; + + bool progress; + uint32_t swap_resist; + + uint32_t nbofswaps; + }; + + inline void init_leaf(node& n) + { + n.internal = false; + n.preds.erase(); + n.base_predictor = 0; + n.norm_Eh = 0; + n.Eh = 0; + n.n = 0; + n.max_count = 0; + n.max_count_label = 1; + n.left = 0; + n.right = 0; + } + + inline node init_node() + { + node node; + + node.parent = 0; + node.min_count = 0; + node.preds = v_init<node_pred>(); + init_leaf(node); + + return node; + } + + void init_tree(log_multi& d) + { + d.nodes.push_back(init_node()); + d.nbofswaps = 0; + } + + inline uint32_t min_left_right(log_multi& b, node& n) + { + return min(b.nodes[n.left].min_count, b.nodes[n.right].min_count); + } + + inline uint32_t find_switch_node(log_multi& b) + { + uint32_t node = 0; + while(b.nodes[node].internal) + if(b.nodes[b.nodes[node].left].min_count + < b.nodes[b.nodes[node].right].min_count) + node = b.nodes[node].left; + else + node = b.nodes[node].right; + return node; + } + + inline void update_min_count(log_multi& b, uint32_t node) + {//Constant time min count update. 
+ while(node != 0) + { + uint32_t prev = node; + node = b.nodes[node].parent; + + if (b.nodes[node].min_count == b.nodes[prev].min_count) + break; + else + b.nodes[node].min_count = min_left_right(b,b.nodes[node]); + } + } + + void display_tree_dfs(log_multi& b, node node, uint32_t depth) + { + for (uint32_t i = 0; i < depth; i++) + cout << "\t"; + cout << node.min_count << " " << node.left + << " " << node.right; + cout << " label = " << node.max_count_label << " labels = "; + for (size_t i = 0; i < node.preds.size(); i++) + cout << node.preds[i].label << ":" << node.preds[i].label_count << "\t"; + cout << endl; + + if (node.internal) + { + cout << "Left"; + display_tree_dfs(b, b.nodes[node.left], depth+1); + + cout << "Right"; + display_tree_dfs(b, b.nodes[node.right], depth+1); + } + } + + bool children(log_multi& b, uint32_t& current, uint32_t& class_index, uint32_t label) + { + class_index = (uint32_t)b.nodes[current].preds.unique_add_sorted(node_pred(label)); + b.nodes[current].preds[class_index].label_count++; + + if(b.nodes[current].preds[class_index].label_count > b.nodes[current].max_count) + { + b.nodes[current].max_count = b.nodes[current].preds[class_index].label_count; + b.nodes[current].max_count_label = b.nodes[current].preds[class_index].label; + } + + if (b.nodes[current].internal) + return true; + else if( b.nodes[current].preds.size() > 1 + && (b.predictors_used < b.max_predictors + || b.nodes[current].min_count - b.nodes[current].max_count > b.swap_resist*(b.nodes[0].min_count + 1))) + { //need children and we can make them. 
+ uint32_t left_child; + uint32_t right_child; + if (b.predictors_used < b.max_predictors) + { + left_child = (uint32_t)b.nodes.size(); + b.nodes.push_back(init_node()); + right_child = (uint32_t)b.nodes.size(); + b.nodes.push_back(init_node()); + b.nodes[current].base_predictor = b.predictors_used++; + } + else + { + uint32_t swap_child = find_switch_node(b); + uint32_t swap_parent = b.nodes[swap_child].parent; + uint32_t swap_grandparent = b.nodes[swap_parent].parent; + if (b.nodes[swap_child].min_count != b.nodes[0].min_count) + cout << "glargh " << b.nodes[swap_child].min_count << " != " << b.nodes[0].min_count << endl; + b.nbofswaps++; + + uint32_t nonswap_child; + if(swap_child == b.nodes[swap_parent].right) + nonswap_child = b.nodes[swap_parent].left; + else + nonswap_child = b.nodes[swap_parent].right; + + if(swap_parent == b.nodes[swap_grandparent].left) + b.nodes[swap_grandparent].left = nonswap_child; + else + b.nodes[swap_grandparent].right = nonswap_child; + b.nodes[nonswap_child].parent = swap_grandparent; + update_min_count(b, nonswap_child); + + init_leaf(b.nodes[swap_child]); + left_child = swap_child; + b.nodes[current].base_predictor = b.nodes[swap_parent].base_predictor; + init_leaf(b.nodes[swap_parent]); + right_child = swap_parent; + } + b.nodes[current].left = left_child; + b.nodes[left_child].parent = current; + b.nodes[current].right = right_child; + b.nodes[right_child].parent = current; + + b.nodes[left_child].min_count = b.nodes[current].min_count/2; + b.nodes[right_child].min_count = b.nodes[current].min_count - b.nodes[left_child].min_count; + update_min_count(b, left_child); + + b.nodes[left_child].max_count_label = b.nodes[current].max_count_label; + b.nodes[right_child].max_count_label = b.nodes[current].max_count_label; + + b.nodes[current].internal = true; + } + return b.nodes[current].internal; + } + + void train_node(log_multi& b, learner& base, example& ec, uint32_t& current, uint32_t& class_index) + { + 
if(b.nodes[current].norm_Eh > b.nodes[current].preds[class_index].norm_Ehk) + ec.l.simple.label = -1.f; + else + ec.l.simple.label = 1.f; + + base.learn(ec, b.nodes[current].base_predictor); + + ec.l.simple.label = FLT_MAX; + base.predict(ec, b.nodes[current].base_predictor); + + b.nodes[current].Eh += (double)ec.partial_prediction; + b.nodes[current].preds[class_index].Ehk += (double)ec.partial_prediction; + b.nodes[current].n++; + b.nodes[current].preds[class_index].nk++; + + b.nodes[current].norm_Eh = (float)b.nodes[current].Eh / b.nodes[current].n; + b.nodes[current].preds[class_index].norm_Ehk = (float)b.nodes[current].preds[class_index].Ehk / b.nodes[current].preds[class_index].nk; + } + + void verify_min_dfs(log_multi& b, node node) + { + if (node.internal) + { + if (node.min_count != min_left_right(b, node)) + { + cout << "badness! " << endl; + display_tree_dfs(b, b.nodes[0], 0); + } + verify_min_dfs(b, b.nodes[node.left]); + verify_min_dfs(b, b.nodes[node.right]); + } + } + + size_t sum_count_dfs(log_multi& b, node node) + { + if (node.internal) + return sum_count_dfs(b, b.nodes[node.left]) + sum_count_dfs(b, b.nodes[node.right]); + else + return node.min_count; + } + + inline uint32_t descend(node& n, float prediction) + { + if (prediction < 0) + return n.left; + else + return n.right; + } + + void predict(log_multi& b, learner& base, example& ec) + { + MULTICLASS::multiclass mc = ec.l.multi; + + label_data simple_temp; + simple_temp.initial = 0.0; + simple_temp.weight = 0.0; + simple_temp.label = FLT_MAX; + ec.l.simple = simple_temp; + uint32_t cn = 0; + while(b.nodes[cn].internal) + { + base.predict(ec, b.nodes[cn].base_predictor); + cn = descend(b.nodes[cn], ec.pred.scalar); + } + ec.pred.multiclass = b.nodes[cn].max_count_label; + ec.l.multi = mc; + } + + void learn(log_multi& b, learner& base, example& ec) + { + // verify_min_dfs(b, b.nodes[0]); + + if (ec.l.multi.label == (uint32_t)-1 || !b.all->training || b.progress) + predict(b,base,ec); + + 
if(b.all->training && (ec.l.multi.label != (uint32_t)-1) && !ec.test_only) //if training the tree + { + MULTICLASS::multiclass mc = ec.l.multi; + + uint32_t class_index = 0; + label_data simple_temp; + simple_temp.initial = 0.0; + simple_temp.weight = mc.weight; + ec.l.simple = simple_temp; + + uint32_t cn = 0; + + while(children(b, cn, class_index, mc.label)) + { + train_node(b, base, ec, cn, class_index); + cn = descend(b.nodes[cn], ec.pred.scalar); + } + + b.nodes[cn].min_count++; + update_min_count(b, cn); + + ec.l.multi = mc; + } + } + + void save_node_stats(log_multi& d) + { + FILE *fp; + uint32_t i, j; + uint32_t total; + log_multi* b = &d; + + fp = fopen("atxm_debug.csv", "wt"); + + for(i = 0; i < b->nodes.size(); i++) + { + fprintf(fp, "Node: %4d, Internal: %1d, Eh: %7.4f, n: %6d, \n", (int) i, (int) b->nodes[i].internal, b->nodes[i].Eh / b->nodes[i].n, b->nodes[i].n); + + fprintf(fp, "Label:, "); + for(j = 0; j < b->nodes[i].preds.size(); j++) + { + fprintf(fp, "%6d,", (int) b->nodes[i].preds[j].label); + } + fprintf(fp, "\n"); + + fprintf(fp, "Ehk:, "); + for(j = 0; j < b->nodes[i].preds.size(); j++) + { + fprintf(fp, "%7.4f,", b->nodes[i].preds[j].Ehk / b->nodes[i].preds[j].nk); + } + fprintf(fp, "\n"); + + total = 0; + + fprintf(fp, "nk:, "); + for(j = 0; j < b->nodes[i].preds.size(); j++) + { + fprintf(fp, "%6d,", (int) b->nodes[i].preds[j].nk); + total += b->nodes[i].preds[j].nk; + } + fprintf(fp, "\n"); + + fprintf(fp, "max(lab:cnt:tot):, %3d,%6d,%7d,\n", (int) b->nodes[i].max_count_label, (int) b->nodes[i].max_count, (int) total); + fprintf(fp, "left: %4d, right: %4d", (int) b->nodes[i].left, (int) b->nodes[i].right); + fprintf(fp, "\n\n"); + } + + fclose(fp); + } + + void finish(log_multi& b) + { + save_node_stats(b); + cout << "used " << b.nbofswaps << " swaps" << endl; + } + + void save_load_tree(log_multi& b, io_buf& model_file, bool read, bool text) + { + if (model_file.files.size() > 0) + { + char buff[512]; + + uint32_t text_len = 
sprintf(buff, "k = %d ",b.k); + bin_text_read_write_fixed(model_file,(char*)&b.max_predictors, sizeof(b.k), "", read, buff, text_len, text); + uint32_t temp = (uint32_t)b.nodes.size(); + text_len = sprintf(buff, "nodes = %d ",temp); + bin_text_read_write_fixed(model_file,(char*)&temp, sizeof(temp), "", read, buff, text_len, text); + if (read) + for (uint32_t j = 1; j < temp; j++) + b.nodes.push_back(init_node()); + text_len = sprintf(buff, "max_predictors = %d ",b.max_predictors); + bin_text_read_write_fixed(model_file,(char*)&b.max_predictors, sizeof(b.max_predictors), "", read, buff, text_len, text); + + text_len = sprintf(buff, "predictors_used = %d ",b.predictors_used); + bin_text_read_write_fixed(model_file,(char*)&b.predictors_used, sizeof(b.predictors_used), "", read, buff, text_len, text); + + text_len = sprintf(buff, "progress = %d ",b.progress); + bin_text_read_write_fixed(model_file,(char*)&b.progress, sizeof(b.progress), "", read, buff, text_len, text); + + text_len = sprintf(buff, "swap_resist = %d\n",b.swap_resist); + bin_text_read_write_fixed(model_file,(char*)&b.swap_resist, sizeof(b.swap_resist), "", read, buff, text_len, text); + + for (size_t j = 0; j < b.nodes.size(); j++) + {//Need to read or write nodes. 
+ node& n = b.nodes[j]; + text_len = sprintf(buff, " parent = %d",n.parent); + bin_text_read_write_fixed(model_file,(char*)&n.parent, sizeof(n.parent), "", read, buff, text_len, text); + + uint32_t temp = (uint32_t)n.preds.size(); + text_len = sprintf(buff, " preds = %d",temp); + bin_text_read_write_fixed(model_file,(char*)&temp, sizeof(temp), "", read, buff, text_len, text); + if (read) + for (uint32_t k = 0; k < temp; k++) + n.preds.push_back(node_pred(1)); + + text_len = sprintf(buff, " min_count = %d",n.min_count); + bin_text_read_write_fixed(model_file,(char*)&n.min_count, sizeof(n.min_count), "", read, buff, text_len, text); + + uint32_t text_len = sprintf(buff, " internal = %d",n.internal); + bin_text_read_write_fixed(model_file,(char*)&n.internal, sizeof(n.internal), "", read, buff, text_len, text) +; + + if (n.internal) + { + text_len = sprintf(buff, " base_predictor = %d",n.base_predictor); + bin_text_read_write_fixed(model_file,(char*)&n.base_predictor, sizeof(n.base_predictor), "", read, buff, text_len, text); + + text_len = sprintf(buff, " left = %d",n.left); + bin_text_read_write_fixed(model_file,(char*)&n.left, sizeof(n.left), "", read, buff, text_len, text); + + text_len = sprintf(buff, " right = %d",n.right); + bin_text_read_write_fixed(model_file,(char*)&n.right, sizeof(n.right), "", read, buff, text_len, text); + + text_len = sprintf(buff, " norm_Eh = %f",n.norm_Eh); + bin_text_read_write_fixed(model_file,(char*)&n.norm_Eh, sizeof(n.norm_Eh), "", read, buff, text_len, text); + + text_len = sprintf(buff, " Eh = %f",n.Eh); + bin_text_read_write_fixed(model_file,(char*)&n.Eh, sizeof(n.Eh), "", read, buff, text_len, text); + + text_len = sprintf(buff, " n = %d\n",n.n); + bin_text_read_write_fixed(model_file,(char*)&n.n, sizeof(n.n), "", read, buff, text_len, text); + } + else + { + text_len = sprintf(buff, " max_count = %d",n.max_count); + bin_text_read_write_fixed(model_file,(char*)&n.max_count, sizeof(n.max_count), "", read, buff, text_len, text); 
+ text_len = sprintf(buff, " max_count_label = %d\n",n.max_count_label); + bin_text_read_write_fixed(model_file,(char*)&n.max_count_label, sizeof(n.max_count_label), "", read, buff, text_len, text); + } + + for (size_t k = 0; k < n.preds.size(); k++) + { + node_pred& p = n.preds[k]; + + text_len = sprintf(buff, " Ehk = %f",p.Ehk); + bin_text_read_write_fixed(model_file,(char*)&p.Ehk, sizeof(p.Ehk), "", read, buff, text_len, text); + + text_len = sprintf(buff, " norm_Ehk = %f",p.norm_Ehk); + bin_text_read_write_fixed(model_file,(char*)&p.norm_Ehk, sizeof(p.norm_Ehk), "", read, buff, text_len, text); + + text_len = sprintf(buff, " nk = %d",p.nk); + bin_text_read_write_fixed(model_file,(char*)&p.nk, sizeof(p.nk), "", read, buff, text_len, text); + + text_len = sprintf(buff, " label = %d",p.label); + bin_text_read_write_fixed(model_file,(char*)&p.label, sizeof(p.label), "", read, buff, text_len, text); + + text_len = sprintf(buff, " label_count = %d\n",p.label_count); + bin_text_read_write_fixed(model_file,(char*)&p.label_count, sizeof(p.label_count), "", read, buff, text_len, text); + } + } + } + } + + void finish_example(vw& all, log_multi&, example& ec) + { + MULTICLASS::output_example(all, ec); + VW::finish_example(all, &ec); + } + + learner* setup(vw& all, po::variables_map& vm) //learner setup + { + log_multi* data = (log_multi*)calloc(1, sizeof(log_multi)); + + po::options_description opts("TXM Online options"); + opts.add_options() + ("no_progress", "disable progressive validation") + ("swap_resistance", po::value<uint32_t>(&(data->swap_resist))->default_value(4), "higher = more resistance to swap, default=4"); + + vm = add_options(all, opts); + + data->k = (uint32_t)vm["log_multi"].as<size_t>(); + + //append log_multi with nb_actions to options_from_file so it is saved to regressor later + std::stringstream ss; + ss << " --log_multi " << data->k; + all.file_options.append(ss.str()); + + if (vm.count("no_progress")) + data->progress = false; + else + 
data->progress = true; + + data->all = &all; + (all.p->lp) = MULTICLASS::mc_label; + + string loss_function = "quantile"; + float loss_parameter = 0.5; + delete(all.loss); + all.loss = getLossFunction(&all, loss_function, loss_parameter); + + data->max_predictors = data->k - 1; + + learner* l = new learner(data, all.l, data->max_predictors); + l->set_save_load<log_multi,save_load_tree>(); + l->set_learn<log_multi,learn>(); + l->set_predict<log_multi,predict>(); + l->set_finish_example<log_multi,finish_example>(); + l->set_finish<log_multi,finish>(); + + init_tree(*data); + + return l; + } +} |