diff options
author | ariel faigon <github.2009@yendor.com> | 2015-01-04 10:11:31 +0300 |
---|---|---|
committer | ariel faigon <github.2009@yendor.com> | 2015-01-04 10:11:31 +0300 |
commit | afe136b52b503239166fe5fee22d3b2c9fd144cd (patch) | |
tree | 02346b7ad036ffbdfe5cad3674636b8ab72b70b1 | |
parent | f1f859a19c8a3c70ad7b4706d78ae5cac6bdca36 (diff) | |
parent | 143522150dedcd953c8ac4f3a0114c9195f4efb6 (diff) |
Merge branch 'master' of git://github.com/JohnLangford/vowpal_wabbit
102 files changed, 4150 insertions, 4284 deletions
@@ -20,17 +20,21 @@ UNAME := $(shell uname) LIBS = -l boost_program_options -l pthread -l z BOOST_INCLUDE = -I /usr/include BOOST_LIBRARY = -L /usr/lib +NPROCS := 1 ifeq ($(UNAME), Linux) BOOST_LIBRARY += -L /usr/lib/x86_64-linux-gnu + NPROCS:=$(shell grep -c ^processor /proc/cpuinfo) endif ifeq ($(UNAME), FreeBSD) LIBS = -l boost_program_options -l pthread -l z -l compat BOOST_INCLUDE = -I /usr/local/include + NPROCS:=$(shell grep -c ^processor /proc/cpuinfo) endif ifeq "CYGWIN" "$(findstring CYGWIN,$(UNAME))" LIBS = -l boost_program_options-mt -l pthread -l z BOOST_INCLUDE = -I /usr/include + NPROCS:=$(shell grep -c ^processor /proc/cpuinfo) endif ifeq ($(UNAME), Darwin) LIBS = -lboost_program_options-mt -lboost_serialization-mt -l pthread -l z @@ -46,6 +50,7 @@ ifeq ($(UNAME), Darwin) BOOST_INCLUDE += -I /opt/local/include BOOST_LIBRARY += -L /opt/local/lib endif + NPROCS:=$(shell sysctl -n hw.ncpu) endif #LIBS = -l boost_program_options-gcc34 -l pthread -l z @@ -84,7 +89,7 @@ spanning_tree: cd cluster; $(MAKE) vw: - cd vowpalwabbit; $(MAKE) -j 8 things + cd vowpalwabbit; $(MAKE) -j $(NPROCS) things active_interactor: cd vowpalwabbit; $(MAKE) @@ -117,3 +122,5 @@ clean: ifneq ($(JAVA_HOME),) cd java && $(MAKE) clean endif + +.PHONY: all clean install diff --git a/Makefile.am b/Makefile.am index 8bc91364..abb79bb2 100644 --- a/Makefile.am +++ b/Makefile.am @@ -36,7 +36,6 @@ noinst_HEADERS = vowpalwabbit/accumulate.h \ vowpalwabbit/lda_core.h \ vowpalwabbit/log_multi.h \ vowpalwabbit/lrq.h \ - vowpalwabbit/memory.h \ vowpalwabbit/mf.h \ vowpalwabbit/multiclass.h \ vowpalwabbit/network.h \ diff --git a/demo/dna/Makefile b/demo/dna/Makefile index 93efac3b..477cfb98 100644 --- a/demo/dna/Makefile +++ b/demo/dna/Makefile @@ -27,3 +27,5 @@ quaddna2vw: quaddna2vw.cpp %.perf: %.test.predictions perf.check perl.check zsh.check ./do-perf $< + +.PHONY: all clean diff --git a/demo/entityrelation/Makefile b/demo/entityrelation/Makefile index c287d97c..c3db3a69 100644 --- 
a/demo/entityrelation/Makefile +++ b/demo/entityrelation/Makefile @@ -8,7 +8,7 @@ eval_script=./evaluationER.py all: @cat README.md clean: - rm -f *.model *.predictions ER_*.vw *.cache evaluationER.py er.zip *~ + rm -f *.model *.predictions ER_*.vw *.cache evaluationER.py er.zip *~ %.check: @test -x "$$(which $*)" || { \ @@ -33,5 +33,4 @@ er.test.predictions: ER_test.vw er.model er.perf: ER_test.vw er.test.predictions python.check @$(eval_script) $< er.test.predictions - - +.PHONY: all clean diff --git a/demo/movielens/Makefile b/demo/movielens/Makefile index 68be5310..46c92e24 100644 --- a/demo/movielens/Makefile +++ b/demo/movielens/Makefile @@ -39,7 +39,7 @@ ml-%.ratings.train.vw: ml-%/ratings.dat @printf "%s test MAE is %3.3f\n" $* $$(cat $*.results) #--------------------------------------------------------------------- -# linear model (no interaction terms) +# linear model (no interaction terms) #--------------------------------------------------------------------- linear.results: ml-1m.ratings.test.vw ml-1m.ratings.train.vw @@ -154,3 +154,5 @@ lrqdropouthogwild.results: ml-1m.ratings.test.vw ml-1m.ratings.train.vw do-lrq-h -d $(word 1,$+) -p \ >(perl -lane '$$s+=abs(($$F[0]-$$F[1])); } { \ 1; print $$s/$$.;' > $@) + +.PHONY: all clean shootout diff --git a/demo/normalized/Makefile b/demo/normalized/Makefile index 8268e82f..f1378b80 100644 --- a/demo/normalized/Makefile +++ b/demo/normalized/Makefile @@ -19,8 +19,8 @@ SHUFFLE='BEGIN { srand 69; }; \ $$b[$$i] = $$_; } { print grep { defined $$_ } @b;' #--------------------------------------------------------------------- -# bank marketing -# +# bank marketing +# # normalization really helps. The columns have units of euros, seconds, # days, and years; in addition there are categorical variables. 
#--------------------------------------------------------------------- @@ -57,10 +57,10 @@ bankbestmin=1e-2 bankbestmax=10 banknonormbestmin=1e-8 banknonormbestmax=1e-3 -banktimeestimate=1 +banktimeestimate=1 #--------------------------------------------------------------------- -# covertype +# covertype #--------------------------------------------------------------------- covtype.data.gz: @@ -88,11 +88,11 @@ covertypebestmin=1e-2 covertypebestmax=10 covertypenonormbestmin=1e-8 covertypenonormbestmax=1e-5 -covertypetimeestimate=5 +covertypetimeestimate=5 #--------------------------------------------------------------------- -# million song database -# +# million song database +# # normalization is helpful. #--------------------------------------------------------------------- @@ -142,7 +142,7 @@ MSDnonormbestmax=1e-5 MSDtimeestimate=15 #--------------------------------------------------------------------- -# census-income (KDD) +# census-income (KDD) #--------------------------------------------------------------------- census-income.data.gz: @@ -180,7 +180,7 @@ censusnonormbestmax=1e-4 censustimeestimate=5 #--------------------------------------------------------------------- -# Statlog (Shuttle) +# Statlog (Shuttle) #--------------------------------------------------------------------- shuttle.trn.Z: @@ -211,7 +211,7 @@ shuttlenonormbestmax=1e-3 shuttletimeestimate=1 #--------------------------------------------------------------------- -# CT slices +# CT slices # # normalization doesn't help much #--------------------------------------------------------------------- @@ -256,7 +256,7 @@ CTslicenonormbestmax=1 CTslicetimeestimate=15 #--------------------------------------------------------------------- -# common routines +# common routines #--------------------------------------------------------------------- %.best: %.data @@ -269,8 +269,10 @@ CTslicetimeestimate=15 @printf "WARNING: this step takes about %s minutes\n" $($*timeestimate) @./hypersearch 
$($*nonormbestmin) $($*nonormbestmax) '$(MAKE)' '$*.%.nonormlearn' > $@ -%.resultsprint: +%.resultsprint: @printf "%20.20s\t%9.3g\t%9.3g\t%9.3g\t%9.3g\n" "$*" $$(cut -f1 $*.best) $$(cut -f2 $*.best) $$(cut -f1 $*.nonormbest) $$(cut -f2 $*.nonormbest) only.%: %.best %.nonormbest all.results.pre %.resultsprint @true + +.PHONY: all all.results all.results.pre diff --git a/explore/clr/explore_clr_wrapper.cpp b/explore/clr/explore_clr_wrapper.cpp index 0564482b..be022af1 100644 --- a/explore/clr/explore_clr_wrapper.cpp +++ b/explore/clr/explore_clr_wrapper.cpp @@ -1,18 +1,18 @@ -// vw_explore_clr_wrapper.cpp : Defines the exported functions for the DLL application.
-//
-
-#define WIN32_LEAN_AND_MEAN
-#include <Windows.h>
-
-#include "explore_clr_wrapper.h"
-
-using namespace System;
-using namespace System::Collections;
-using namespace System::Collections::Generic;
-using namespace System::Runtime::InteropServices;
-using namespace msclr::interop;
-using namespace NativeMultiWorldTesting;
-
-namespace MultiWorldTesting {
-
-}
+// vw_explore_clr_wrapper.cpp : Defines the exported functions for the DLL application. +// + +#define WIN32_LEAN_AND_MEAN +#include <Windows.h> + +#include "explore_clr_wrapper.h" + +using namespace System; +using namespace System::Collections; +using namespace System::Collections::Generic; +using namespace System::Runtime::InteropServices; +using namespace msclr::interop; +using namespace NativeMultiWorldTesting; + +namespace MultiWorldTesting { + +} diff --git a/explore/clr/explore_clr_wrapper.h b/explore/clr/explore_clr_wrapper.h index a6cede4b..c73c95d4 100644 --- a/explore/clr/explore_clr_wrapper.h +++ b/explore/clr/explore_clr_wrapper.h @@ -1,438 +1,438 @@ -#pragma once
-#include "explore_interop.h"
-
-/*!
-* \addtogroup MultiWorldTestingCsharp
-* @{
-*/
-namespace MultiWorldTesting {
-
- /// <summary>
- /// The epsilon greedy exploration class.
- /// </summary>
- /// <remarks>
- /// This is a good choice if you have no idea which actions should be preferred.
- /// Epsilon greedy is also computationally cheap.
- /// </remarks>
- /// <typeparam name="Ctx">The Context type.</typeparam>
- generic <class Ctx>
- public ref class EpsilonGreedyExplorer : public IExplorer<Ctx>, public PolicyCallback<Ctx>
- {
- public:
- /// <summary>
- /// The constructor is the only public member, because this should be used with the MwtExplorer.
- /// </summary>
- /// <param name="defaultPolicy">A default function which outputs an action given a context.</param>
- /// <param name="epsilon">The probability of a random exploration.</param>
- /// <param name="numActions">The number of actions to randomize over.</param>
- EpsilonGreedyExplorer(IPolicy<Ctx>^ defaultPolicy, float epsilon, UInt32 numActions)
- {
- this->defaultPolicy = defaultPolicy;
- m_explorer = new NativeMultiWorldTesting::EpsilonGreedyExplorer<NativeContext>(*GetNativePolicy(), epsilon, (u32)numActions);
- }
-
- ~EpsilonGreedyExplorer()
- {
- delete m_explorer;
- }
-
- internal:
- virtual UInt32 InvokePolicyCallback(Ctx context, int index) override
- {
- return defaultPolicy->ChooseAction(context);
- }
-
- NativeMultiWorldTesting::EpsilonGreedyExplorer<NativeContext>* Get()
- {
- return m_explorer;
- }
-
- private:
- IPolicy<Ctx>^ defaultPolicy;
- NativeMultiWorldTesting::EpsilonGreedyExplorer<NativeContext>* m_explorer;
- };
-
- /// <summary>
- /// The tau-first exploration class.
- /// </summary>
- /// <remarks>
- /// The tau-first explorer collects precisely tau uniform random
- /// exploration events, and then uses the default policy.
- /// </remarks>
- /// <typeparam name="Ctx">The Context type.</typeparam>
- generic <class Ctx>
- public ref class TauFirstExplorer : public IExplorer<Ctx>, public PolicyCallback<Ctx>
- {
- public:
- /// <summary>
- /// The constructor is the only public member, because this should be used with the MwtExplorer.
- /// </summary>
- /// <param name="defaultPolicy">A default policy after randomization finishes.</param>
- /// <param name="tau">The number of events to be uniform over.</param>
- /// <param name="numActions">The number of actions to randomize over.</param>
- TauFirstExplorer(IPolicy<Ctx>^ defaultPolicy, UInt32 tau, UInt32 numActions)
- {
- this->defaultPolicy = defaultPolicy;
- m_explorer = new NativeMultiWorldTesting::TauFirstExplorer<NativeContext>(*GetNativePolicy(), tau, (u32)numActions);
- }
-
- ~TauFirstExplorer()
- {
- delete m_explorer;
- }
-
- internal:
- virtual UInt32 InvokePolicyCallback(Ctx context, int index) override
- {
- return defaultPolicy->ChooseAction(context);
- }
-
- NativeMultiWorldTesting::TauFirstExplorer<NativeContext>* Get()
- {
- return m_explorer;
- }
-
- private:
- IPolicy<Ctx>^ defaultPolicy;
- NativeMultiWorldTesting::TauFirstExplorer<NativeContext>* m_explorer;
- };
-
- /// <summary>
- /// The epsilon greedy exploration class.
- /// </summary>
- /// <remarks>
- /// In some cases, different actions have a different scores, and you
- /// would prefer to choose actions with large scores. Softmax allows
- /// you to do that.
- /// </remarks>
- /// <typeparam name="Ctx">The Context type.</typeparam>
- generic <class Ctx>
- public ref class SoftmaxExplorer : public IExplorer<Ctx>, public ScorerCallback<Ctx>
- {
- public:
- /// <summary>
- /// The constructor is the only public member, because this should be used with the MwtExplorer.
- /// </summary>
- /// <param name="defaultScorer">A function which outputs a score for each action.</param>
- /// <param name="lambda">lambda = 0 implies uniform distribution. Large lambda is equivalent to a max.</param>
- /// <param name="numActions">The number of actions to randomize over.</param>
- SoftmaxExplorer(IScorer<Ctx>^ defaultScorer, float lambda, UInt32 numActions)
- {
- this->defaultScorer = defaultScorer;
- m_explorer = new NativeMultiWorldTesting::SoftmaxExplorer<NativeContext>(*GetNativeScorer(), lambda, (u32)numActions);
- }
-
- ~SoftmaxExplorer()
- {
- delete m_explorer;
- }
-
- internal:
- virtual List<float>^ InvokeScorerCallback(Ctx context) override
- {
- return defaultScorer->ScoreActions(context);
- }
-
- NativeMultiWorldTesting::SoftmaxExplorer<NativeContext>* Get()
- {
- return m_explorer;
- }
-
- private:
- IScorer<Ctx>^ defaultScorer;
- NativeMultiWorldTesting::SoftmaxExplorer<NativeContext>* m_explorer;
- };
-
- /// <summary>
- /// The generic exploration class.
- /// </summary>
- /// <remarks>
- /// GenericExplorer provides complete flexibility. You can create any
- /// distribution over actions desired, and it will draw from that.
- /// </remarks>
- /// <typeparam name="Ctx">The Context type.</typeparam>
- generic <class Ctx>
- public ref class GenericExplorer : public IExplorer<Ctx>, public ScorerCallback<Ctx>
- {
- public:
- /// <summary>
- /// The constructor is the only public member, because this should be used with the MwtExplorer.
- /// </summary>
- /// <param name="defaultScorer">A function which outputs the probability of each action.</param>
- /// <param name="numActions">The number of actions to randomize over.</param>
- GenericExplorer(IScorer<Ctx>^ defaultScorer, UInt32 numActions)
- {
- this->defaultScorer = defaultScorer;
- m_explorer = new NativeMultiWorldTesting::GenericExplorer<NativeContext>(*GetNativeScorer(), (u32)numActions);
- }
-
- ~GenericExplorer()
- {
- delete m_explorer;
- }
-
- internal:
- virtual List<float>^ InvokeScorerCallback(Ctx context) override
- {
- return defaultScorer->ScoreActions(context);
- }
-
- NativeMultiWorldTesting::GenericExplorer<NativeContext>* Get()
- {
- return m_explorer;
- }
-
- private:
- IScorer<Ctx>^ defaultScorer;
- NativeMultiWorldTesting::GenericExplorer<NativeContext>* m_explorer;
- };
-
- /// <summary>
- /// The bootstrap exploration class.
- /// </summary>
- /// <remarks>
- /// The Bootstrap explorer randomizes over the actions chosen by a set of
- /// default policies. This performs well statistically but can be
- /// computationally expensive.
- /// </remarks>
- /// <typeparam name="Ctx">The Context type.</typeparam>
- generic <class Ctx>
- public ref class BootstrapExplorer : public IExplorer<Ctx>, public PolicyCallback<Ctx>
- {
- public:
- /// <summary>
- /// The constructor is the only public member, because this should be used with the MwtExplorer.
- /// </summary>
- /// <param name="defaultPolicies">A set of default policies to be uniform random over.</param>
- /// <param name="numActions">The number of actions to randomize over.</param>
- BootstrapExplorer(cli::array<IPolicy<Ctx>^>^ defaultPolicies, UInt32 numActions)
- {
- this->defaultPolicies = defaultPolicies;
- if (this->defaultPolicies == nullptr)
- {
- throw gcnew ArgumentNullException("The specified array of default policy functions cannot be null.");
- }
-
- m_explorer = new NativeMultiWorldTesting::BootstrapExplorer<NativeContext>(*GetNativePolicies((u32)defaultPolicies->Length), (u32)numActions);
- }
-
- ~BootstrapExplorer()
- {
- delete m_explorer;
- }
-
- internal:
- virtual UInt32 InvokePolicyCallback(Ctx context, int index) override
- {
- if (index < 0 || index >= defaultPolicies->Length)
- {
- throw gcnew InvalidDataException("Internal error: Index of interop bag is out of range.");
- }
- return defaultPolicies[index]->ChooseAction(context);
- }
-
- NativeMultiWorldTesting::BootstrapExplorer<NativeContext>* Get()
- {
- return m_explorer;
- }
-
- private:
- cli::array<IPolicy<Ctx>^>^ defaultPolicies;
- NativeMultiWorldTesting::BootstrapExplorer<NativeContext>* m_explorer;
- };
-
- /// <summary>
- /// The top level MwtExplorer class. Using this makes sure that the
- /// right bits are recorded and good random actions are chosen.
- /// </summary>
- /// <typeparam name="Ctx">The Context type.</typeparam>
- generic <class Ctx>
- public ref class MwtExplorer : public RecorderCallback<Ctx>
- {
- public:
- /// <summary>
- /// Constructor.
- /// </summary>
- /// <param name="appId">This should be unique to each experiment to avoid correlation bugs.</param>
- /// <param name="recorder">A user-specified class for recording the appropriate bits for use in evaluation and learning.</param>
- MwtExplorer(String^ appId, IRecorder<Ctx>^ recorder)
- {
- this->appId = appId;
- this->recorder = recorder;
- }
-
- /// <summary>
- /// Choose_Action should be drop-in replacement for any existing policy function.
- /// </summary>
- /// <param name="explorer">An existing exploration algorithm (one of the above) which uses the default policy as a callback.</param>
- /// <param name="unique_key">A unique identifier for the experimental unit. This could be a user id, a session id, etc...</param>
- /// <param name="context">The context upon which a decision is made. See SimpleContext above for an example.</param>
- /// <returns>An unsigned 32-bit integer representing the 1-based chosen action.</returns>
- UInt32 ChooseAction(IExplorer<Ctx>^ explorer, String^ unique_key, Ctx context)
- {
- String^ salt = this->appId;
- NativeMultiWorldTesting::MwtExplorer<NativeContext> mwt(marshal_as<std::string>(salt), *GetNativeRecorder());
-
- GCHandle selfHandle = GCHandle::Alloc(this);
- IntPtr selfPtr = (IntPtr)selfHandle;
-
- GCHandle contextHandle = GCHandle::Alloc(context);
- IntPtr contextPtr = (IntPtr)contextHandle;
-
- GCHandle explorerHandle = GCHandle::Alloc(explorer);
- IntPtr explorerPtr = (IntPtr)explorerHandle;
-
- NativeContext native_context(selfPtr.ToPointer(), explorerPtr.ToPointer(), contextPtr.ToPointer());
- u32 action = 0;
- if (explorer->GetType() == EpsilonGreedyExplorer<Ctx>::typeid)
- {
- EpsilonGreedyExplorer<Ctx>^ epsilonGreedyExplorer = (EpsilonGreedyExplorer<Ctx>^)explorer;
- action = mwt.Choose_Action(*epsilonGreedyExplorer->Get(), marshal_as<std::string>(unique_key), native_context);
- }
- else if (explorer->GetType() == TauFirstExplorer<Ctx>::typeid)
- {
- TauFirstExplorer<Ctx>^ tauFirstExplorer = (TauFirstExplorer<Ctx>^)explorer;
- action = mwt.Choose_Action(*tauFirstExplorer->Get(), marshal_as<std::string>(unique_key), native_context);
- }
- else if (explorer->GetType() == SoftmaxExplorer<Ctx>::typeid)
- {
- SoftmaxExplorer<Ctx>^ softmaxExplorer = (SoftmaxExplorer<Ctx>^)explorer;
- action = mwt.Choose_Action(*softmaxExplorer->Get(), marshal_as<std::string>(unique_key), native_context);
- }
- else if (explorer->GetType() == GenericExplorer<Ctx>::typeid)
- {
- GenericExplorer<Ctx>^ genericExplorer = (GenericExplorer<Ctx>^)explorer;
- action = mwt.Choose_Action(*genericExplorer->Get(), marshal_as<std::string>(unique_key), native_context);
- }
- else if (explorer->GetType() == BootstrapExplorer<Ctx>::typeid)
- {
- BootstrapExplorer<Ctx>^ bootstrapExplorer = (BootstrapExplorer<Ctx>^)explorer;
- action = mwt.Choose_Action(*bootstrapExplorer->Get(), marshal_as<std::string>(unique_key), native_context);
- }
-
- explorerHandle.Free();
- contextHandle.Free();
- selfHandle.Free();
-
- return action;
- }
-
- internal:
- virtual void InvokeRecorderCallback(Ctx context, UInt32 action, float probability, String^ unique_key) override
- {
- recorder->Record(context, action, probability, unique_key);
- }
-
- private:
- IRecorder<Ctx>^ recorder;
- String^ appId;
- };
-
- /// <summary>
- /// Represents a feature in a sparse array.
- /// </summary>
- [StructLayout(LayoutKind::Sequential)]
- public value struct Feature
- {
- float Value;
- UInt32 Id;
- };
-
- /// <summary>
- /// A sample recorder class that converts the exploration tuple into string format.
- /// </summary>
- /// <typeparam name="Ctx">The Context type.</typeparam>
- generic <class Ctx> where Ctx : IStringContext
- public ref class StringRecorder : public IRecorder<Ctx>, public ToStringCallback<Ctx>
- {
- public:
- StringRecorder()
- {
- m_string_recorder = new NativeMultiWorldTesting::StringRecorder<NativeStringContext>();
- }
-
- ~StringRecorder()
- {
- delete m_string_recorder;
- }
-
- virtual void Record(Ctx context, UInt32 action, float probability, String^ uniqueKey)
- {
- GCHandle contextHandle = GCHandle::Alloc(context);
- IntPtr contextPtr = (IntPtr)contextHandle;
-
- NativeStringContext native_context(contextPtr.ToPointer(), GetCallback());
- m_string_recorder->Record(native_context, (u32)action, probability, marshal_as<string>(uniqueKey));
- }
-
- /// <summary>
- /// Gets the content of the recording so far as a string and clears internal content.
- /// </summary>
- /// <returns>
- /// A string with recording content.
- /// </returns>
- String^ GetRecording()
- {
- // Workaround for C++-CLI bug which does not allow default value for parameter
- return GetRecording(true);
- }
-
- /// <summary>
- /// Gets the content of the recording so far as a string and optionally clears internal content.
- /// </summary>
- /// <param name="flush">A boolean value indicating whether to clear the internal content.</param>
- /// <returns>
- /// A string with recording content.
- /// </returns>
- String^ GetRecording(bool flush)
- {
- return gcnew String(m_string_recorder->Get_Recording(flush).c_str());
- }
-
- private:
- NativeMultiWorldTesting::StringRecorder<NativeStringContext>* m_string_recorder;
- };
-
- /// <summary>
- /// A sample context class that stores a vector of Features.
- /// </summary>
- public ref class SimpleContext : public IStringContext
- {
- public:
- SimpleContext(cli::array<Feature>^ features)
- {
- Features = features;
-
- // TODO: add another constructor overload for native SimpleContext to avoid copying feature values
- m_features = new vector<NativeMultiWorldTesting::Feature>();
- for (int i = 0; i < features->Length; i++)
- {
- m_features->push_back({ features[i].Value, features[i].Id });
- }
-
- m_native_context = new NativeMultiWorldTesting::SimpleContext(*m_features);
- }
-
- String^ ToString() override
- {
- return gcnew String(m_native_context->To_String().c_str());
- }
-
- ~SimpleContext()
- {
- delete m_native_context;
- }
-
- public:
- cli::array<Feature>^ GetFeatures() { return Features; }
-
- internal:
- cli::array<Feature>^ Features;
-
- private:
- vector<NativeMultiWorldTesting::Feature>* m_features;
- NativeMultiWorldTesting::SimpleContext* m_native_context;
- };
-}
-
+#pragma once +#include "explore_interop.h" + +/*! +* \addtogroup MultiWorldTestingCsharp +* @{ +*/ +namespace MultiWorldTesting { + + /// <summary> + /// The epsilon greedy exploration class. + /// </summary> + /// <remarks> + /// This is a good choice if you have no idea which actions should be preferred. + /// Epsilon greedy is also computationally cheap. + /// </remarks> + /// <typeparam name="Ctx">The Context type.</typeparam> + generic <class Ctx> + public ref class EpsilonGreedyExplorer : public IExplorer<Ctx>, public PolicyCallback<Ctx> + { + public: + /// <summary> + /// The constructor is the only public member, because this should be used with the MwtExplorer. + /// </summary> + /// <param name="defaultPolicy">A default function which outputs an action given a context.</param> + /// <param name="epsilon">The probability of a random exploration.</param> + /// <param name="numActions">The number of actions to randomize over.</param> + EpsilonGreedyExplorer(IPolicy<Ctx>^ defaultPolicy, float epsilon, UInt32 numActions) + { + this->defaultPolicy = defaultPolicy; + m_explorer = new NativeMultiWorldTesting::EpsilonGreedyExplorer<NativeContext>(*GetNativePolicy(), epsilon, (u32)numActions); + } + + ~EpsilonGreedyExplorer() + { + delete m_explorer; + } + + internal: + virtual UInt32 InvokePolicyCallback(Ctx context, int index) override + { + return defaultPolicy->ChooseAction(context); + } + + NativeMultiWorldTesting::EpsilonGreedyExplorer<NativeContext>* Get() + { + return m_explorer; + } + + private: + IPolicy<Ctx>^ defaultPolicy; + NativeMultiWorldTesting::EpsilonGreedyExplorer<NativeContext>* m_explorer; + }; + + /// <summary> + /// The tau-first exploration class. + /// </summary> + /// <remarks> + /// The tau-first explorer collects precisely tau uniform random + /// exploration events, and then uses the default policy. 
+ /// </remarks> + /// <typeparam name="Ctx">The Context type.</typeparam> + generic <class Ctx> + public ref class TauFirstExplorer : public IExplorer<Ctx>, public PolicyCallback<Ctx> + { + public: + /// <summary> + /// The constructor is the only public member, because this should be used with the MwtExplorer. + /// </summary> + /// <param name="defaultPolicy">A default policy after randomization finishes.</param> + /// <param name="tau">The number of events to be uniform over.</param> + /// <param name="numActions">The number of actions to randomize over.</param> + TauFirstExplorer(IPolicy<Ctx>^ defaultPolicy, UInt32 tau, UInt32 numActions) + { + this->defaultPolicy = defaultPolicy; + m_explorer = new NativeMultiWorldTesting::TauFirstExplorer<NativeContext>(*GetNativePolicy(), tau, (u32)numActions); + } + + ~TauFirstExplorer() + { + delete m_explorer; + } + + internal: + virtual UInt32 InvokePolicyCallback(Ctx context, int index) override + { + return defaultPolicy->ChooseAction(context); + } + + NativeMultiWorldTesting::TauFirstExplorer<NativeContext>* Get() + { + return m_explorer; + } + + private: + IPolicy<Ctx>^ defaultPolicy; + NativeMultiWorldTesting::TauFirstExplorer<NativeContext>* m_explorer; + }; + + /// <summary> + /// The epsilon greedy exploration class. + /// </summary> + /// <remarks> + /// In some cases, different actions have a different scores, and you + /// would prefer to choose actions with large scores. Softmax allows + /// you to do that. + /// </remarks> + /// <typeparam name="Ctx">The Context type.</typeparam> + generic <class Ctx> + public ref class SoftmaxExplorer : public IExplorer<Ctx>, public ScorerCallback<Ctx> + { + public: + /// <summary> + /// The constructor is the only public member, because this should be used with the MwtExplorer. + /// </summary> + /// <param name="defaultScorer">A function which outputs a score for each action.</param> + /// <param name="lambda">lambda = 0 implies uniform distribution. 
Large lambda is equivalent to a max.</param> + /// <param name="numActions">The number of actions to randomize over.</param> + SoftmaxExplorer(IScorer<Ctx>^ defaultScorer, float lambda, UInt32 numActions) + { + this->defaultScorer = defaultScorer; + m_explorer = new NativeMultiWorldTesting::SoftmaxExplorer<NativeContext>(*GetNativeScorer(), lambda, (u32)numActions); + } + + ~SoftmaxExplorer() + { + delete m_explorer; + } + + internal: + virtual List<float>^ InvokeScorerCallback(Ctx context) override + { + return defaultScorer->ScoreActions(context); + } + + NativeMultiWorldTesting::SoftmaxExplorer<NativeContext>* Get() + { + return m_explorer; + } + + private: + IScorer<Ctx>^ defaultScorer; + NativeMultiWorldTesting::SoftmaxExplorer<NativeContext>* m_explorer; + }; + + /// <summary> + /// The generic exploration class. + /// </summary> + /// <remarks> + /// GenericExplorer provides complete flexibility. You can create any + /// distribution over actions desired, and it will draw from that. + /// </remarks> + /// <typeparam name="Ctx">The Context type.</typeparam> + generic <class Ctx> + public ref class GenericExplorer : public IExplorer<Ctx>, public ScorerCallback<Ctx> + { + public: + /// <summary> + /// The constructor is the only public member, because this should be used with the MwtExplorer. 
+ /// </summary> + /// <param name="defaultScorer">A function which outputs the probability of each action.</param> + /// <param name="numActions">The number of actions to randomize over.</param> + GenericExplorer(IScorer<Ctx>^ defaultScorer, UInt32 numActions) + { + this->defaultScorer = defaultScorer; + m_explorer = new NativeMultiWorldTesting::GenericExplorer<NativeContext>(*GetNativeScorer(), (u32)numActions); + } + + ~GenericExplorer() + { + delete m_explorer; + } + + internal: + virtual List<float>^ InvokeScorerCallback(Ctx context) override + { + return defaultScorer->ScoreActions(context); + } + + NativeMultiWorldTesting::GenericExplorer<NativeContext>* Get() + { + return m_explorer; + } + + private: + IScorer<Ctx>^ defaultScorer; + NativeMultiWorldTesting::GenericExplorer<NativeContext>* m_explorer; + }; + + /// <summary> + /// The bootstrap exploration class. + /// </summary> + /// <remarks> + /// The Bootstrap explorer randomizes over the actions chosen by a set of + /// default policies. This performs well statistically but can be + /// computationally expensive. + /// </remarks> + /// <typeparam name="Ctx">The Context type.</typeparam> + generic <class Ctx> + public ref class BootstrapExplorer : public IExplorer<Ctx>, public PolicyCallback<Ctx> + { + public: + /// <summary> + /// The constructor is the only public member, because this should be used with the MwtExplorer. 
+ /// </summary> + /// <param name="defaultPolicies">A set of default policies to be uniform random over.</param> + /// <param name="numActions">The number of actions to randomize over.</param> + BootstrapExplorer(cli::array<IPolicy<Ctx>^>^ defaultPolicies, UInt32 numActions) + { + this->defaultPolicies = defaultPolicies; + if (this->defaultPolicies == nullptr) + { + throw gcnew ArgumentNullException("The specified array of default policy functions cannot be null."); + } + + m_explorer = new NativeMultiWorldTesting::BootstrapExplorer<NativeContext>(*GetNativePolicies((u32)defaultPolicies->Length), (u32)numActions); + } + + ~BootstrapExplorer() + { + delete m_explorer; + } + + internal: + virtual UInt32 InvokePolicyCallback(Ctx context, int index) override + { + if (index < 0 || index >= defaultPolicies->Length) + { + throw gcnew InvalidDataException("Internal error: Index of interop bag is out of range."); + } + return defaultPolicies[index]->ChooseAction(context); + } + + NativeMultiWorldTesting::BootstrapExplorer<NativeContext>* Get() + { + return m_explorer; + } + + private: + cli::array<IPolicy<Ctx>^>^ defaultPolicies; + NativeMultiWorldTesting::BootstrapExplorer<NativeContext>* m_explorer; + }; + + /// <summary> + /// The top level MwtExplorer class. Using this makes sure that the + /// right bits are recorded and good random actions are chosen. + /// </summary> + /// <typeparam name="Ctx">The Context type.</typeparam> + generic <class Ctx> + public ref class MwtExplorer : public RecorderCallback<Ctx> + { + public: + /// <summary> + /// Constructor. 
+ /// </summary> + /// <param name="appId">This should be unique to each experiment to avoid correlation bugs.</param> + /// <param name="recorder">A user-specified class for recording the appropriate bits for use in evaluation and learning.</param> + MwtExplorer(String^ appId, IRecorder<Ctx>^ recorder) + { + this->appId = appId; + this->recorder = recorder; + } + + /// <summary> + /// Choose_Action should be drop-in replacement for any existing policy function. + /// </summary> + /// <param name="explorer">An existing exploration algorithm (one of the above) which uses the default policy as a callback.</param> + /// <param name="unique_key">A unique identifier for the experimental unit. This could be a user id, a session id, etc...</param> + /// <param name="context">The context upon which a decision is made. See SimpleContext above for an example.</param> + /// <returns>An unsigned 32-bit integer representing the 1-based chosen action.</returns> + UInt32 ChooseAction(IExplorer<Ctx>^ explorer, String^ unique_key, Ctx context) + { + String^ salt = this->appId; + NativeMultiWorldTesting::MwtExplorer<NativeContext> mwt(marshal_as<std::string>(salt), *GetNativeRecorder()); + + GCHandle selfHandle = GCHandle::Alloc(this); + IntPtr selfPtr = (IntPtr)selfHandle; + + GCHandle contextHandle = GCHandle::Alloc(context); + IntPtr contextPtr = (IntPtr)contextHandle; + + GCHandle explorerHandle = GCHandle::Alloc(explorer); + IntPtr explorerPtr = (IntPtr)explorerHandle; + + NativeContext native_context(selfPtr.ToPointer(), explorerPtr.ToPointer(), contextPtr.ToPointer()); + u32 action = 0; + if (explorer->GetType() == EpsilonGreedyExplorer<Ctx>::typeid) + { + EpsilonGreedyExplorer<Ctx>^ epsilonGreedyExplorer = (EpsilonGreedyExplorer<Ctx>^)explorer; + action = mwt.Choose_Action(*epsilonGreedyExplorer->Get(), marshal_as<std::string>(unique_key), native_context); + } + else if (explorer->GetType() == TauFirstExplorer<Ctx>::typeid) + { + TauFirstExplorer<Ctx>^ tauFirstExplorer = 
(TauFirstExplorer<Ctx>^)explorer; + action = mwt.Choose_Action(*tauFirstExplorer->Get(), marshal_as<std::string>(unique_key), native_context); + } + else if (explorer->GetType() == SoftmaxExplorer<Ctx>::typeid) + { + SoftmaxExplorer<Ctx>^ softmaxExplorer = (SoftmaxExplorer<Ctx>^)explorer; + action = mwt.Choose_Action(*softmaxExplorer->Get(), marshal_as<std::string>(unique_key), native_context); + } + else if (explorer->GetType() == GenericExplorer<Ctx>::typeid) + { + GenericExplorer<Ctx>^ genericExplorer = (GenericExplorer<Ctx>^)explorer; + action = mwt.Choose_Action(*genericExplorer->Get(), marshal_as<std::string>(unique_key), native_context); + } + else if (explorer->GetType() == BootstrapExplorer<Ctx>::typeid) + { + BootstrapExplorer<Ctx>^ bootstrapExplorer = (BootstrapExplorer<Ctx>^)explorer; + action = mwt.Choose_Action(*bootstrapExplorer->Get(), marshal_as<std::string>(unique_key), native_context); + } + + explorerHandle.Free(); + contextHandle.Free(); + selfHandle.Free(); + + return action; + } + + internal: + virtual void InvokeRecorderCallback(Ctx context, UInt32 action, float probability, String^ unique_key) override + { + recorder->Record(context, action, probability, unique_key); + } + + private: + IRecorder<Ctx>^ recorder; + String^ appId; + }; + + /// <summary> + /// Represents a feature in a sparse array. + /// </summary> + [StructLayout(LayoutKind::Sequential)] + public value struct Feature + { + float Value; + UInt32 Id; + }; + + /// <summary> + /// A sample recorder class that converts the exploration tuple into string format. 
+ /// </summary> + /// <typeparam name="Ctx">The Context type.</typeparam> + generic <class Ctx> where Ctx : IStringContext + public ref class StringRecorder : public IRecorder<Ctx>, public ToStringCallback<Ctx> + { + public: + StringRecorder() + { + m_string_recorder = new NativeMultiWorldTesting::StringRecorder<NativeStringContext>(); + } + + ~StringRecorder() + { + delete m_string_recorder; + } + + virtual void Record(Ctx context, UInt32 action, float probability, String^ uniqueKey) + { + GCHandle contextHandle = GCHandle::Alloc(context); + IntPtr contextPtr = (IntPtr)contextHandle; + + NativeStringContext native_context(contextPtr.ToPointer(), GetCallback()); + m_string_recorder->Record(native_context, (u32)action, probability, marshal_as<string>(uniqueKey)); + } + + /// <summary> + /// Gets the content of the recording so far as a string and clears internal content. + /// </summary> + /// <returns> + /// A string with recording content. + /// </returns> + String^ GetRecording() + { + // Workaround for C++-CLI bug which does not allow default value for parameter + return GetRecording(true); + } + + /// <summary> + /// Gets the content of the recording so far as a string and optionally clears internal content. + /// </summary> + /// <param name="flush">A boolean value indicating whether to clear the internal content.</param> + /// <returns> + /// A string with recording content. + /// </returns> + String^ GetRecording(bool flush) + { + return gcnew String(m_string_recorder->Get_Recording(flush).c_str()); + } + + private: + NativeMultiWorldTesting::StringRecorder<NativeStringContext>* m_string_recorder; + }; + + /// <summary> + /// A sample context class that stores a vector of Features. 
+ /// </summary> + public ref class SimpleContext : public IStringContext + { + public: + SimpleContext(cli::array<Feature>^ features) + { + Features = features; + + // TODO: add another constructor overload for native SimpleContext to avoid copying feature values + m_features = new vector<NativeMultiWorldTesting::Feature>(); + for (int i = 0; i < features->Length; i++) + { + m_features->push_back({ features[i].Value, features[i].Id }); + } + + m_native_context = new NativeMultiWorldTesting::SimpleContext(*m_features); + } + + String^ ToString() override + { + return gcnew String(m_native_context->To_String().c_str()); + } + + ~SimpleContext() + { + delete m_native_context; + } + + public: + cli::array<Feature>^ GetFeatures() { return Features; } + + internal: + cli::array<Feature>^ Features; + + private: + vector<NativeMultiWorldTesting::Feature>* m_features; + NativeMultiWorldTesting::SimpleContext* m_native_context; + }; +} + /*! @} End of Doxygen Groups*/
\ No newline at end of file diff --git a/explore/clr/explore_interface.h b/explore/clr/explore_interface.h index 3212b4d8..dfa17697 100644 --- a/explore/clr/explore_interface.h +++ b/explore/clr/explore_interface.h @@ -1,88 +1,88 @@ -#pragma once
-
-using namespace System;
-using namespace System::Collections::Generic;
-
-/** \defgroup MultiWorldTestingCsharp
-\brief C# implementation, for sample usage see: https://github.com/sidsen/vowpal_wabbit/blob/v0/cs_test/ExploreOnlySample.cs
-*/
-
-/*!
-* \addtogroup MultiWorldTestingCsharp
-* @{
-*/
-
-//! Interface for C# version of Multiworld Testing library.
-//! For sample usage see: https://github.com/sidsen/vowpal_wabbit/blob/v0/cs_test/ExploreOnlySample.cs
-namespace MultiWorldTesting {
-
-/// <summary>
-/// Represents a recorder that exposes a method to record exploration data based on generic contexts.
-/// </summary>
-/// <typeparam name="Ctx">The Context type.</typeparam>
-/// <remarks>
-/// Exploration data is specified as a set of tuples <context, action, probability, key> as described below. An
-/// application passes an IRecorder object to the @MwtExplorer constructor. See
-/// @StringRecorder for a sample IRecorder object.
-/// </remarks>
-generic <class Ctx>
-public interface class IRecorder
-{
-public:
- /// <summary>
- /// Records the exploration data associated with a given decision.
- /// </summary>
- /// <param name="context">A user-defined context for the decision.</param>
- /// <param name="action">Chosen by an exploration algorithm given context.</param>
- /// <param name="probability">The probability of the chosen action given context.</param>
- /// <param name="uniqueKey">A user-defined identifer for the decision.</param>
- virtual void Record(Ctx context, UInt32 action, float probability, String^ uniqueKey) = 0;
-};
-
-/// <summary>
-/// Exposes a method for choosing an action given a generic context. IPolicy objects are
-/// passed to (and invoked by) exploration algorithms to specify the default policy behavior.
-/// </summary>
-/// <typeparam name="Ctx">The Context type.</typeparam>
-generic <class Ctx>
-public interface class IPolicy
-{
-public:
- /// <summary>
- /// Determines the action to take for a given context.
- /// </summary>
- /// <param name="context">A user-defined context for the decision.</param>
- /// <returns>Index of the action to take (1-based)</returns>
- virtual UInt32 ChooseAction(Ctx context) = 0;
-};
-
-/// <summary>
-/// Exposes a method for specifying a score (weight) for each action given a generic context.
-/// </summary>
-/// <typeparam name="Ctx">The Context type.</typeparam>
-generic <class Ctx>
-public interface class IScorer
-{
-public:
- /// <summary>
- /// Determines the score of each action for a given context.
- /// </summary>
- /// <param name="context">A user-defined context for the decision.</param>
- /// <returns>Vector of scores indexed by action (1-based).</returns>
- virtual List<float>^ ScoreActions(Ctx context) = 0;
-};
-
-generic <class Ctx>
-public interface class IExplorer
-{
-};
-
-public interface class IStringContext
-{
-public:
- virtual String^ ToString() = 0;
-};
-
-}
-
+#pragma once + +using namespace System; +using namespace System::Collections::Generic; + +/** \defgroup MultiWorldTestingCsharp +\brief C# implementation, for sample usage see: https://github.com/sidsen/vowpal_wabbit/blob/v0/cs_test/ExploreOnlySample.cs +*/ + +/*! +* \addtogroup MultiWorldTestingCsharp +* @{ +*/ + +//! Interface for C# version of Multiworld Testing library. +//! For sample usage see: https://github.com/sidsen/vowpal_wabbit/blob/v0/cs_test/ExploreOnlySample.cs +namespace MultiWorldTesting { + +/// <summary> +/// Represents a recorder that exposes a method to record exploration data based on generic contexts. +/// </summary> +/// <typeparam name="Ctx">The Context type.</typeparam> +/// <remarks> +/// Exploration data is specified as a set of tuples <context, action, probability, key> as described below. An +/// application passes an IRecorder object to the @MwtExplorer constructor. See +/// @StringRecorder for a sample IRecorder object. +/// </remarks> +generic <class Ctx> +public interface class IRecorder +{ +public: + /// <summary> + /// Records the exploration data associated with a given decision. + /// </summary> + /// <param name="context">A user-defined context for the decision.</param> + /// <param name="action">Chosen by an exploration algorithm given context.</param> + /// <param name="probability">The probability of the chosen action given context.</param> + /// <param name="uniqueKey">A user-defined identifer for the decision.</param> + virtual void Record(Ctx context, UInt32 action, float probability, String^ uniqueKey) = 0; +}; + +/// <summary> +/// Exposes a method for choosing an action given a generic context. IPolicy objects are +/// passed to (and invoked by) exploration algorithms to specify the default policy behavior. +/// </summary> +/// <typeparam name="Ctx">The Context type.</typeparam> +generic <class Ctx> +public interface class IPolicy +{ +public: + /// <summary> + /// Determines the action to take for a given context. 
+ /// </summary> + /// <param name="context">A user-defined context for the decision.</param> + /// <returns>Index of the action to take (1-based)</returns> + virtual UInt32 ChooseAction(Ctx context) = 0; +}; + +/// <summary> +/// Exposes a method for specifying a score (weight) for each action given a generic context. +/// </summary> +/// <typeparam name="Ctx">The Context type.</typeparam> +generic <class Ctx> +public interface class IScorer +{ +public: + /// <summary> + /// Determines the score of each action for a given context. + /// </summary> + /// <param name="context">A user-defined context for the decision.</param> + /// <returns>Vector of scores indexed by action (1-based).</returns> + virtual List<float>^ ScoreActions(Ctx context) = 0; +}; + +generic <class Ctx> +public interface class IExplorer +{ +}; + +public interface class IStringContext +{ +public: + virtual String^ ToString() = 0; +}; + +} + /*! @} End of Doxygen Groups*/
\ No newline at end of file diff --git a/explore/clr/explore_interop.h b/explore/clr/explore_interop.h index 99466945..a6cb9327 100644 --- a/explore/clr/explore_interop.h +++ b/explore/clr/explore_interop.h @@ -1,358 +1,358 @@ -#pragma once
-
-#define MANAGED_CODE
-
-#include "explore_interface.h"
-#include "MWTExplorer.h"
-
-#include <msclr\marshal_cppstd.h>
-
-using namespace System;
-using namespace System::Collections::Generic;
-using namespace System::IO;
-using namespace System::Runtime::InteropServices;
-using namespace System::Xml::Serialization;
-using namespace msclr::interop;
-
-namespace MultiWorldTesting {
-
-// Policy callback
-private delegate UInt32 ClrPolicyCallback(IntPtr explorerPtr, IntPtr contextPtr, int index);
-typedef u32 Native_Policy_Callback(void* explorer, void* context, int index);
-
-// Scorer callback
-private delegate void ClrScorerCallback(IntPtr explorerPtr, IntPtr contextPtr, IntPtr scores, IntPtr size);
-typedef void Native_Scorer_Callback(void* explorer, void* context, float* scores[], u32* size);
-
-// Recorder callback
-private delegate void ClrRecorderCallback(IntPtr mwtPtr, IntPtr contextPtr, UInt32 action, float probability, IntPtr uniqueKey);
-typedef void Native_Recorder_Callback(void* mwt, void* context, u32 action, float probability, void* unique_key);
-
-// ToString callback
-private delegate void ClrToStringCallback(IntPtr contextPtr, IntPtr stringValue);
-typedef void Native_To_String_Callback(void* explorer, void* string_value);
-
-// NativeContext travels through interop space and contains instances of Mwt, Explorer, Context
-// used for triggering callback for Policy, Scorer, Recorder
-class NativeContext
-{
-public:
- NativeContext(void* clr_mwt, void* clr_explorer, void* clr_context)
- {
- m_clr_mwt = clr_mwt;
- m_clr_explorer = clr_explorer;
- m_clr_context = clr_context;
- }
-
- void* Get_Clr_Mwt()
- {
- return m_clr_mwt;
- }
-
- void* Get_Clr_Context()
- {
- return m_clr_context;
- }
-
- void* Get_Clr_Explorer()
- {
- return m_clr_explorer;
- }
-
-private:
- void* m_clr_mwt;
- void* m_clr_context;
- void* m_clr_explorer;
-};
-
-class NativeStringContext
-{
-public:
- NativeStringContext(void* clr_context, Native_To_String_Callback* func)
- {
- m_clr_context = clr_context;
- m_func = func;
- }
-
- string To_String()
- {
- string value;
- m_func(m_clr_context, &value);
- return value;
- }
-private:
- void* m_clr_context;
- Native_To_String_Callback* m_func;
-};
-
-// NativeRecorder listens to callback event and reroute it to the managed Recorder instance
-class NativeRecorder : public NativeMultiWorldTesting::IRecorder<NativeContext>
-{
-public:
- NativeRecorder(Native_Recorder_Callback* native_func)
- {
- m_func = native_func;
- }
-
- void Record(NativeContext& context, u32 action, float probability, string unique_key)
- {
- GCHandle uniqueKeyHandle = GCHandle::Alloc(gcnew String(unique_key.c_str()));
- IntPtr uniqueKeyPtr = (IntPtr)uniqueKeyHandle;
-
- m_func(context.Get_Clr_Mwt(), context.Get_Clr_Context(), action, probability, uniqueKeyPtr.ToPointer());
-
- uniqueKeyHandle.Free();
- }
-private:
- Native_Recorder_Callback* m_func;
-};
-
-// NativePolicy listens to callback event and reroute it to the managed Policy instance
-class NativePolicy : public NativeMultiWorldTesting::IPolicy<NativeContext>
-{
-public:
- NativePolicy(Native_Policy_Callback* func, int index = -1)
- {
- m_func = func;
- m_index = index;
- }
-
- u32 Choose_Action(NativeContext& context)
- {
- return m_func(context.Get_Clr_Explorer(), context.Get_Clr_Context(), m_index);
- }
-
-private:
- Native_Policy_Callback* m_func;
- int m_index;
-};
-
-class NativeScorer : public NativeMultiWorldTesting::IScorer<NativeContext>
-{
-public:
- NativeScorer(Native_Scorer_Callback* func)
- {
- m_func = func;
- }
-
- vector<float> Score_Actions(NativeContext& context)
- {
- float* scores = nullptr;
- u32 num_scores = 0;
- m_func(context.Get_Clr_Explorer(), context.Get_Clr_Context(), &scores, &num_scores);
-
- // It's ok if scores is null, vector will be empty
- vector<float> scores_vector(scores, scores + num_scores);
- delete[] scores;
-
- return scores_vector;
- }
-private:
- Native_Scorer_Callback* m_func;
-};
-
-// Triggers callback to the Policy instance to choose an action
-generic <class Ctx>
-public ref class PolicyCallback abstract
-{
-internal:
- virtual UInt32 InvokePolicyCallback(Ctx context, int index) = 0;
-
- PolicyCallback()
- {
- policyCallback = gcnew ClrPolicyCallback(&PolicyCallback<Ctx>::InteropInvoke);
- IntPtr policyCallbackPtr = Marshal::GetFunctionPointerForDelegate(policyCallback);
- m_callback = static_cast<Native_Policy_Callback*>(policyCallbackPtr.ToPointer());
- m_native_policy = nullptr;
- m_native_policies = nullptr;
- }
-
- ~PolicyCallback()
- {
- delete m_native_policy;
- delete m_native_policies;
- }
-
- NativePolicy* GetNativePolicy()
- {
- if (m_native_policy == nullptr)
- {
- m_native_policy = new NativePolicy(m_callback);
- }
- return m_native_policy;
- }
-
- vector<unique_ptr<NativeMultiWorldTesting::IPolicy<NativeContext>>>* GetNativePolicies(int count)
- {
- if (m_native_policies == nullptr)
- {
- m_native_policies = new vector<unique_ptr<NativeMultiWorldTesting::IPolicy<NativeContext>>>();
- for (int i = 0; i < count; i++)
- {
- m_native_policies->push_back(unique_ptr<NativeMultiWorldTesting::IPolicy<NativeContext>>(new NativePolicy(m_callback, i)));
- }
- }
-
- return m_native_policies;
- }
-
- static UInt32 InteropInvoke(IntPtr callbackPtr, IntPtr contextPtr, int index)
- {
- GCHandle callbackHandle = (GCHandle)callbackPtr;
- PolicyCallback<Ctx>^ callback = (PolicyCallback<Ctx>^)callbackHandle.Target;
-
- GCHandle contextHandle = (GCHandle)contextPtr;
- Ctx context = (Ctx)contextHandle.Target;
-
- return callback->InvokePolicyCallback(context, index);
- }
-
-private:
- ClrPolicyCallback^ policyCallback;
-
-private:
- NativePolicy* m_native_policy;
- vector<unique_ptr<NativeMultiWorldTesting::IPolicy<NativeContext>>>* m_native_policies;
- Native_Policy_Callback* m_callback;
-};
-
-// Triggers callback to the Recorder instance to record interaction data
-generic <class Ctx>
-public ref class RecorderCallback abstract
-{
-internal:
- virtual void InvokeRecorderCallback(Ctx context, UInt32 action, float probability, String^ unique_key) = 0;
-
- RecorderCallback()
- {
- recorderCallback = gcnew ClrRecorderCallback(&RecorderCallback<Ctx>::InteropInvoke);
- IntPtr recorderCallbackPtr = Marshal::GetFunctionPointerForDelegate(recorderCallback);
- Native_Recorder_Callback* callback = static_cast<Native_Recorder_Callback*>(recorderCallbackPtr.ToPointer());
- m_native_recorder = new NativeRecorder(callback);
- }
-
- ~RecorderCallback()
- {
- delete m_native_recorder;
- }
-
- NativeRecorder* GetNativeRecorder()
- {
- return m_native_recorder;
- }
-
- static void InteropInvoke(IntPtr mwtPtr, IntPtr contextPtr, UInt32 action, float probability, IntPtr uniqueKeyPtr)
- {
- GCHandle mwtHandle = (GCHandle)mwtPtr;
- RecorderCallback<Ctx>^ callback = (RecorderCallback<Ctx>^)mwtHandle.Target;
-
- GCHandle contextHandle = (GCHandle)contextPtr;
- Ctx context = (Ctx)contextHandle.Target;
-
- GCHandle uniqueKeyHandle = (GCHandle)uniqueKeyPtr;
- String^ uniqueKey = (String^)uniqueKeyHandle.Target;
-
- callback->InvokeRecorderCallback(context, action, probability, uniqueKey);
- }
-
-private:
- ClrRecorderCallback^ recorderCallback;
-
-private:
- NativeRecorder* m_native_recorder;
-};
-
-// Triggers callback to the Recorder instance to record interaction data
-generic <class Ctx>
-public ref class ScorerCallback abstract
-{
-internal:
- virtual List<float>^ InvokeScorerCallback(Ctx context) = 0;
-
- ScorerCallback()
- {
- scorerCallback = gcnew ClrScorerCallback(&ScorerCallback<Ctx>::InteropInvoke);
- IntPtr scorerCallbackPtr = Marshal::GetFunctionPointerForDelegate(scorerCallback);
- Native_Scorer_Callback* callback = static_cast<Native_Scorer_Callback*>(scorerCallbackPtr.ToPointer());
- m_native_scorer = new NativeScorer(callback);
- }
-
- ~ScorerCallback()
- {
- delete m_native_scorer;
- }
-
- NativeScorer* GetNativeScorer()
- {
- return m_native_scorer;
- }
-
- static void InteropInvoke(IntPtr callbackPtr, IntPtr contextPtr, IntPtr scoresPtr, IntPtr sizePtr)
- {
- GCHandle callbackHandle = (GCHandle)callbackPtr;
- ScorerCallback<Ctx>^ callback = (ScorerCallback<Ctx>^)callbackHandle.Target;
-
- GCHandle contextHandle = (GCHandle)contextPtr;
- Ctx context = (Ctx)contextHandle.Target;
-
- List<float>^ scoreList = callback->InvokeScorerCallback(context);
-
- if (scoreList == nullptr || scoreList->Count == 0)
- {
- return;
- }
-
- u32* num_scores = (u32*)sizePtr.ToPointer();
- *num_scores = (u32)scoreList->Count;
-
- float* scores = new float[*num_scores];
- for (u32 i = 0; i < *num_scores; i++)
- {
- scores[i] = scoreList[i];
- }
-
- float** native_scores = (float**)scoresPtr.ToPointer();
- *native_scores = scores;
- }
-
-private:
- ClrScorerCallback^ scorerCallback;
-
-private:
- NativeScorer* m_native_scorer;
-};
-
-// Triggers callback to the Context instance to perform ToString() operation
-generic <class Ctx> where Ctx : IStringContext
-public ref class ToStringCallback
-{
-internal:
- ToStringCallback()
- {
- toStringCallback = gcnew ClrToStringCallback(&ToStringCallback<Ctx>::InteropInvoke);
- IntPtr toStringCallbackPtr = Marshal::GetFunctionPointerForDelegate(toStringCallback);
- m_callback = static_cast<Native_To_String_Callback*>(toStringCallbackPtr.ToPointer());
- }
-
- Native_To_String_Callback* GetCallback()
- {
- return m_callback;
- }
-
- static void InteropInvoke(IntPtr contextPtr, IntPtr stringPtr)
- {
- GCHandle contextHandle = (GCHandle)contextPtr;
- Ctx context = (Ctx)contextHandle.Target;
-
- string* out_string = (string*)stringPtr.ToPointer();
- *out_string = marshal_as<string>(context->ToString());
- }
-
-private:
- ClrToStringCallback^ toStringCallback;
-
-private:
- Native_To_String_Callback* m_callback;
-};
-
+#pragma once + +#define MANAGED_CODE + +#include "explore_interface.h" +#include "MWTExplorer.h" + +#include <msclr\marshal_cppstd.h> + +using namespace System; +using namespace System::Collections::Generic; +using namespace System::IO; +using namespace System::Runtime::InteropServices; +using namespace System::Xml::Serialization; +using namespace msclr::interop; + +namespace MultiWorldTesting { + +// Policy callback +private delegate UInt32 ClrPolicyCallback(IntPtr explorerPtr, IntPtr contextPtr, int index); +typedef u32 Native_Policy_Callback(void* explorer, void* context, int index); + +// Scorer callback +private delegate void ClrScorerCallback(IntPtr explorerPtr, IntPtr contextPtr, IntPtr scores, IntPtr size); +typedef void Native_Scorer_Callback(void* explorer, void* context, float* scores[], u32* size); + +// Recorder callback +private delegate void ClrRecorderCallback(IntPtr mwtPtr, IntPtr contextPtr, UInt32 action, float probability, IntPtr uniqueKey); +typedef void Native_Recorder_Callback(void* mwt, void* context, u32 action, float probability, void* unique_key); + +// ToString callback +private delegate void ClrToStringCallback(IntPtr contextPtr, IntPtr stringValue); +typedef void Native_To_String_Callback(void* explorer, void* string_value); + +// NativeContext travels through interop space and contains instances of Mwt, Explorer, Context +// used for triggering callback for Policy, Scorer, Recorder +class NativeContext +{ +public: + NativeContext(void* clr_mwt, void* clr_explorer, void* clr_context) + { + m_clr_mwt = clr_mwt; + m_clr_explorer = clr_explorer; + m_clr_context = clr_context; + } + + void* Get_Clr_Mwt() + { + return m_clr_mwt; + } + + void* Get_Clr_Context() + { + return m_clr_context; + } + + void* Get_Clr_Explorer() + { + return m_clr_explorer; + } + +private: + void* m_clr_mwt; + void* m_clr_context; + void* m_clr_explorer; +}; + +class NativeStringContext +{ +public: + NativeStringContext(void* clr_context, Native_To_String_Callback* 
func) + { + m_clr_context = clr_context; + m_func = func; + } + + string To_String() + { + string value; + m_func(m_clr_context, &value); + return value; + } +private: + void* m_clr_context; + Native_To_String_Callback* m_func; +}; + +// NativeRecorder listens to callback event and reroute it to the managed Recorder instance +class NativeRecorder : public NativeMultiWorldTesting::IRecorder<NativeContext> +{ +public: + NativeRecorder(Native_Recorder_Callback* native_func) + { + m_func = native_func; + } + + void Record(NativeContext& context, u32 action, float probability, string unique_key) + { + GCHandle uniqueKeyHandle = GCHandle::Alloc(gcnew String(unique_key.c_str())); + IntPtr uniqueKeyPtr = (IntPtr)uniqueKeyHandle; + + m_func(context.Get_Clr_Mwt(), context.Get_Clr_Context(), action, probability, uniqueKeyPtr.ToPointer()); + + uniqueKeyHandle.Free(); + } +private: + Native_Recorder_Callback* m_func; +}; + +// NativePolicy listens to callback event and reroute it to the managed Policy instance +class NativePolicy : public NativeMultiWorldTesting::IPolicy<NativeContext> +{ +public: + NativePolicy(Native_Policy_Callback* func, int index = -1) + { + m_func = func; + m_index = index; + } + + u32 Choose_Action(NativeContext& context) + { + return m_func(context.Get_Clr_Explorer(), context.Get_Clr_Context(), m_index); + } + +private: + Native_Policy_Callback* m_func; + int m_index; +}; + +class NativeScorer : public NativeMultiWorldTesting::IScorer<NativeContext> +{ +public: + NativeScorer(Native_Scorer_Callback* func) + { + m_func = func; + } + + vector<float> Score_Actions(NativeContext& context) + { + float* scores = nullptr; + u32 num_scores = 0; + m_func(context.Get_Clr_Explorer(), context.Get_Clr_Context(), &scores, &num_scores); + + // It's ok if scores is null, vector will be empty + vector<float> scores_vector(scores, scores + num_scores); + delete[] scores; + + return scores_vector; + } +private: + Native_Scorer_Callback* m_func; +}; + +// Triggers callback 
to the Policy instance to choose an action +generic <class Ctx> +public ref class PolicyCallback abstract +{ +internal: + virtual UInt32 InvokePolicyCallback(Ctx context, int index) = 0; + + PolicyCallback() + { + policyCallback = gcnew ClrPolicyCallback(&PolicyCallback<Ctx>::InteropInvoke); + IntPtr policyCallbackPtr = Marshal::GetFunctionPointerForDelegate(policyCallback); + m_callback = static_cast<Native_Policy_Callback*>(policyCallbackPtr.ToPointer()); + m_native_policy = nullptr; + m_native_policies = nullptr; + } + + ~PolicyCallback() + { + delete m_native_policy; + delete m_native_policies; + } + + NativePolicy* GetNativePolicy() + { + if (m_native_policy == nullptr) + { + m_native_policy = new NativePolicy(m_callback); + } + return m_native_policy; + } + + vector<unique_ptr<NativeMultiWorldTesting::IPolicy<NativeContext>>>* GetNativePolicies(int count) + { + if (m_native_policies == nullptr) + { + m_native_policies = new vector<unique_ptr<NativeMultiWorldTesting::IPolicy<NativeContext>>>(); + for (int i = 0; i < count; i++) + { + m_native_policies->push_back(unique_ptr<NativeMultiWorldTesting::IPolicy<NativeContext>>(new NativePolicy(m_callback, i))); + } + } + + return m_native_policies; + } + + static UInt32 InteropInvoke(IntPtr callbackPtr, IntPtr contextPtr, int index) + { + GCHandle callbackHandle = (GCHandle)callbackPtr; + PolicyCallback<Ctx>^ callback = (PolicyCallback<Ctx>^)callbackHandle.Target; + + GCHandle contextHandle = (GCHandle)contextPtr; + Ctx context = (Ctx)contextHandle.Target; + + return callback->InvokePolicyCallback(context, index); + } + +private: + ClrPolicyCallback^ policyCallback; + +private: + NativePolicy* m_native_policy; + vector<unique_ptr<NativeMultiWorldTesting::IPolicy<NativeContext>>>* m_native_policies; + Native_Policy_Callback* m_callback; +}; + +// Triggers callback to the Recorder instance to record interaction data +generic <class Ctx> +public ref class RecorderCallback abstract +{ +internal: + virtual void 
InvokeRecorderCallback(Ctx context, UInt32 action, float probability, String^ unique_key) = 0; + + RecorderCallback() + { + recorderCallback = gcnew ClrRecorderCallback(&RecorderCallback<Ctx>::InteropInvoke); + IntPtr recorderCallbackPtr = Marshal::GetFunctionPointerForDelegate(recorderCallback); + Native_Recorder_Callback* callback = static_cast<Native_Recorder_Callback*>(recorderCallbackPtr.ToPointer()); + m_native_recorder = new NativeRecorder(callback); + } + + ~RecorderCallback() + { + delete m_native_recorder; + } + + NativeRecorder* GetNativeRecorder() + { + return m_native_recorder; + } + + static void InteropInvoke(IntPtr mwtPtr, IntPtr contextPtr, UInt32 action, float probability, IntPtr uniqueKeyPtr) + { + GCHandle mwtHandle = (GCHandle)mwtPtr; + RecorderCallback<Ctx>^ callback = (RecorderCallback<Ctx>^)mwtHandle.Target; + + GCHandle contextHandle = (GCHandle)contextPtr; + Ctx context = (Ctx)contextHandle.Target; + + GCHandle uniqueKeyHandle = (GCHandle)uniqueKeyPtr; + String^ uniqueKey = (String^)uniqueKeyHandle.Target; + + callback->InvokeRecorderCallback(context, action, probability, uniqueKey); + } + +private: + ClrRecorderCallback^ recorderCallback; + +private: + NativeRecorder* m_native_recorder; +}; + +// Triggers callback to the Recorder instance to record interaction data +generic <class Ctx> +public ref class ScorerCallback abstract +{ +internal: + virtual List<float>^ InvokeScorerCallback(Ctx context) = 0; + + ScorerCallback() + { + scorerCallback = gcnew ClrScorerCallback(&ScorerCallback<Ctx>::InteropInvoke); + IntPtr scorerCallbackPtr = Marshal::GetFunctionPointerForDelegate(scorerCallback); + Native_Scorer_Callback* callback = static_cast<Native_Scorer_Callback*>(scorerCallbackPtr.ToPointer()); + m_native_scorer = new NativeScorer(callback); + } + + ~ScorerCallback() + { + delete m_native_scorer; + } + + NativeScorer* GetNativeScorer() + { + return m_native_scorer; + } + + static void InteropInvoke(IntPtr callbackPtr, IntPtr contextPtr, 
IntPtr scoresPtr, IntPtr sizePtr) + { + GCHandle callbackHandle = (GCHandle)callbackPtr; + ScorerCallback<Ctx>^ callback = (ScorerCallback<Ctx>^)callbackHandle.Target; + + GCHandle contextHandle = (GCHandle)contextPtr; + Ctx context = (Ctx)contextHandle.Target; + + List<float>^ scoreList = callback->InvokeScorerCallback(context); + + if (scoreList == nullptr || scoreList->Count == 0) + { + return; + } + + u32* num_scores = (u32*)sizePtr.ToPointer(); + *num_scores = (u32)scoreList->Count; + + float* scores = new float[*num_scores]; + for (u32 i = 0; i < *num_scores; i++) + { + scores[i] = scoreList[i]; + } + + float** native_scores = (float**)scoresPtr.ToPointer(); + *native_scores = scores; + } + +private: + ClrScorerCallback^ scorerCallback; + +private: + NativeScorer* m_native_scorer; +}; + +// Triggers callback to the Context instance to perform ToString() operation +generic <class Ctx> where Ctx : IStringContext +public ref class ToStringCallback +{ +internal: + ToStringCallback() + { + toStringCallback = gcnew ClrToStringCallback(&ToStringCallback<Ctx>::InteropInvoke); + IntPtr toStringCallbackPtr = Marshal::GetFunctionPointerForDelegate(toStringCallback); + m_callback = static_cast<Native_To_String_Callback*>(toStringCallbackPtr.ToPointer()); + } + + Native_To_String_Callback* GetCallback() + { + return m_callback; + } + + static void InteropInvoke(IntPtr contextPtr, IntPtr stringPtr) + { + GCHandle contextHandle = (GCHandle)contextPtr; + Ctx context = (Ctx)contextHandle.Target; + + string* out_string = (string*)stringPtr.ToPointer(); + *out_string = marshal_as<string>(context->ToString()); + } + +private: + ClrToStringCallback^ toStringCallback; + +private: + Native_To_String_Callback* m_callback; +}; + }
\ No newline at end of file diff --git a/explore/explore.cpp b/explore/explore.cpp index ebc1c14e..f44ac353 100644 --- a/explore/explore.cpp +++ b/explore/explore.cpp @@ -1,73 +1,73 @@ -// explore.cpp : Timing code to measure performance of MWT Explorer library
-
-#include "MWTExplorer.h"
-#include <chrono>
-#include <tuple>
-#include <iostream>
-
-using namespace std;
-using namespace std::chrono;
-
-using namespace MultiWorldTesting;
-
-class MySimplePolicy : public IPolicy<SimpleContext>
-{
-public:
- u32 Choose_Action(SimpleContext& context)
- {
- return (u32)1;
- }
-};
-
-const u32 num_actions = 10;
-
-void Clock_Explore()
-{
- float epsilon = .2f;
- string unique_key = "key";
- int num_features = 1000;
- int num_iter = 10000;
- int num_warmup = 100;
- int num_interactions = 1;
-
- // pre-create features
- vector<Feature> features;
- for (int i = 0; i < num_features; i++)
- {
- Feature f = {0.5, i+1};
- features.push_back(f);
- }
-
- long long time_init = 0, time_choose = 0;
- for (int iter = 0; iter < num_iter + num_warmup; iter++)
- {
- high_resolution_clock::time_point t1 = high_resolution_clock::now();
- StringRecorder<SimpleContext> recorder;
- MwtExplorer<SimpleContext> mwt("test", recorder);
- MySimplePolicy default_policy;
- EpsilonGreedyExplorer<SimpleContext> explorer(default_policy, epsilon, num_actions);
- high_resolution_clock::time_point t2 = high_resolution_clock::now();
- time_init += iter < num_warmup ? 0 : duration_cast<chrono::microseconds>(t2 - t1).count();
-
- t1 = high_resolution_clock::now();
- SimpleContext appContext(features);
- for (int i = 0; i < num_interactions; i++)
- {
- mwt.Choose_Action(explorer, unique_key, appContext);
- }
- t2 = high_resolution_clock::now();
- time_choose += iter < num_warmup ? 0 : duration_cast<chrono::microseconds>(t2 - t1).count();
- }
-
- cout << "# iterations: " << num_iter << ", # interactions: " << num_interactions << ", # context features: " << num_features << endl;
- cout << "--- PER ITERATION ---" << endl;
- cout << "Init: " << (double)time_init / num_iter << " micro" << endl;
- cout << "Choose Action: " << (double)time_choose / (num_iter * num_interactions) << " micro" << endl;
- cout << "--- TOTAL TIME ---: " << (time_init + time_choose) << " micro" << endl;
-}
-
-int main(int argc, char* argv[])
-{
- Clock_Explore();
- return 0;
-}
+// explore.cpp : Timing code to measure performance of MWT Explorer library + +#include "MWTExplorer.h" +#include <chrono> +#include <tuple> +#include <iostream> + +using namespace std; +using namespace std::chrono; + +using namespace MultiWorldTesting; + +class MySimplePolicy : public IPolicy<SimpleContext> +{ +public: + u32 Choose_Action(SimpleContext& context) + { + return (u32)1; + } +}; + +const u32 num_actions = 10; + +void Clock_Explore() +{ + float epsilon = .2f; + string unique_key = "key"; + int num_features = 1000; + int num_iter = 10000; + int num_warmup = 100; + int num_interactions = 1; + + // pre-create features + vector<Feature> features; + for (int i = 0; i < num_features; i++) + { + Feature f = {0.5, i+1}; + features.push_back(f); + } + + long long time_init = 0, time_choose = 0; + for (int iter = 0; iter < num_iter + num_warmup; iter++) + { + high_resolution_clock::time_point t1 = high_resolution_clock::now(); + StringRecorder<SimpleContext> recorder; + MwtExplorer<SimpleContext> mwt("test", recorder); + MySimplePolicy default_policy; + EpsilonGreedyExplorer<SimpleContext> explorer(default_policy, epsilon, num_actions); + high_resolution_clock::time_point t2 = high_resolution_clock::now(); + time_init += iter < num_warmup ? 0 : duration_cast<chrono::microseconds>(t2 - t1).count(); + + t1 = high_resolution_clock::now(); + SimpleContext appContext(features); + for (int i = 0; i < num_interactions; i++) + { + mwt.Choose_Action(explorer, unique_key, appContext); + } + t2 = high_resolution_clock::now(); + time_choose += iter < num_warmup ? 
0 : duration_cast<chrono::microseconds>(t2 - t1).count(); + } + + cout << "# iterations: " << num_iter << ", # interactions: " << num_interactions << ", # context features: " << num_features << endl; + cout << "--- PER ITERATION ---" << endl; + cout << "Init: " << (double)time_init / num_iter << " micro" << endl; + cout << "Choose Action: " << (double)time_choose / (num_iter * num_interactions) << " micro" << endl; + cout << "--- TOTAL TIME ---: " << (time_init + time_choose) << " micro" << endl; +} + +int main(int argc, char* argv[]) +{ + Clock_Explore(); + return 0; +} diff --git a/explore/static/MWTExplorer.h b/explore/static/MWTExplorer.h index dd5e3044..4c7009b2 100644 --- a/explore/static/MWTExplorer.h +++ b/explore/static/MWTExplorer.h @@ -1,669 +1,669 @@ -//
-// Main interface for clients of the Multiworld testing (MWT) service.
-//
-
-#pragma once
-
-#include <stdexcept>
-#include <float.h>
-#include <math.h>
-#include <stdio.h>
-#include <string.h>
-#include <vector>
-#include <utility>
-#include <memory>
-#include <limits.h>
-#include <tuple>
-
-#ifdef MANAGED_CODE
-#define PORTING_INTERFACE public
-#define MWT_NAMESPACE namespace NativeMultiWorldTesting
-#else
-#define PORTING_INTERFACE private
-#define MWT_NAMESPACE namespace MultiWorldTesting
-#endif
-
-using namespace std;
-
-#include "utility.h"
-
-/** \defgroup MultiWorldTestingCpp
-\brief C++ implementation, for sample usage see: https://github.com/sidsen/vowpal_wabbit/blob/v0/explore/explore_sample.cpp
-*/
-
-/*!
-* \addtogroup MultiWorldTestingCpp
-* @{
-*/
-
-//! Interface for C++ version of Multiworld Testing library.
-//! For sample usage see: https://github.com/sidsen/vowpal_wabbit/blob/v0/explore/explore_sample.cpp
-MWT_NAMESPACE {
-
-// Forward declarations
-template <class Ctx>
-class IRecorder;
-template <class Ctx>
-class IExplorer;
-
-///
-/// The top-level MwtExplorer class. Using this enables principled and efficient exploration
-/// over a set of possible actions, and ensures that the right bits are recorded.
-///
-template <class Ctx>
-class MwtExplorer
-{
-public:
- ///
- /// Constructor
- ///
- /// @param appid This should be unique to your experiment or you risk nasty correlation bugs.
- /// @param recorder A user-specified class for recording the appropriate bits for use in evaluation and learning.
- ///
- MwtExplorer(std::string app_id, IRecorder<Ctx>& recorder) : m_recorder(recorder)
- {
- m_app_id = HashUtils::Compute_Id_Hash(app_id);
- }
-
- ///
- /// Chooses an action by invoking an underlying exploration algorithm. This should be a
- /// drop-in replacement for any existing policy function.
- ///
- /// @param explorer An existing exploration algorithm (one of the below) which uses the default policy as a callback.
- /// @param unique_key A unique identifier for the experimental unit. This could be a user id, a session id, etc..
- /// @param context The context upon which a decision is made. See SimpleContext below for an example.
- ///
- u32 Choose_Action(IExplorer<Ctx>& explorer, string unique_key, Ctx& context)
- {
- u64 seed = HashUtils::Compute_Id_Hash(unique_key);
-
- std::tuple<u32, float, bool> action_probability_log_tuple = explorer.Choose_Action(seed + m_app_id, context);
-
- u32 action = std::get<0>(action_probability_log_tuple);
- float prob = std::get<1>(action_probability_log_tuple);
-
- if (std::get<2>(action_probability_log_tuple))
- {
- m_recorder.Record(context, action, prob, unique_key);
- }
-
- return action;
- }
-
-private:
- u64 m_app_id;
- IRecorder<Ctx>& m_recorder;
-};
-
-///
-/// Exposes a method to record exploration data based on generic contexts. Exploration data
-/// is specified as a set of tuples <context, action, probability, key> as described below. An
-/// application passes an IRecorder object to the @MwtExplorer constructor. See
-/// @StringRecorder for a sample IRecorder object.
-///
-template <class Ctx>
-class IRecorder
-{
-public:
- ///
- /// Records the exploration data associated with a given decision.
- ///
- /// @param context A user-defined context for the decision
- /// @param action The action chosen by an exploration algorithm given context
- /// @param probability The probability the exploration algorithm chose said action
- /// @param unique_key A user-defined unique identifer for the decision
- ///
- virtual void Record(Ctx& context, u32 action, float probability, string unique_key) = 0;
-};
-
-///
-/// Exposes a method to choose an action given a generic context, and obtain the relevant
-/// exploration bits. Invokes IPolicy::Choose_Action internally. Do not implement this
-/// interface yourself: instead, use the various exploration algorithms below, which
-/// implement it for you.
-///
-template <class Ctx>
-class IExplorer
-{
-public:
- ///
- /// Determines the action to take and the probability with which it was chosen, for a
- /// given context.
- ///
- /// @param salted_seed A PRG seed based on a unique id information provided by the user
- /// @param context A user-defined context for the decision
- /// @returns The action to take, the probability it was chosen, and a flag indicating
- /// whether to record this decision
- ///
- virtual std::tuple<u32, float, bool> Choose_Action(u64 salted_seed, Ctx& context) = 0;
-};
-
-///
-/// Exposes a method to choose an action given a generic context. IPolicy objects are
-/// passed to (and invoked by) exploration algorithms to specify the default policy behavior.
-///
-template <class Ctx>
-class IPolicy
-{
-public:
- ///
- /// Determines the action to take for a given context.
- ///
- /// @param context A user-defined context for the decision
- /// @returns The action to take (1-based index)
- ///
- virtual u32 Choose_Action(Ctx& context) = 0;
-};
-
-///
-/// Exposes a method for specifying a score (weight) for each action given a generic context.
-///
-template <class Ctx>
-class IScorer
-{
-public:
- ///
- /// Determines the score of each action for a given context.
- ///
- /// @param context A user-defined context for the decision
- /// @returns A vector of scores indexed by action (1-based)
- ///
- virtual vector<float> Score_Actions(Ctx& context) = 0;
-};
-
-///
-/// A sample recorder class that converts the exploration tuple into string format.
-///
-template <class Ctx>
-struct StringRecorder : public IRecorder<Ctx>
-{
- void Record(Ctx& context, u32 action, float probability, string unique_key)
- {
- // Implicitly enforce To_String() API on the context
- m_recording.append(to_string((unsigned long)action));
- m_recording.append(" ", 1);
- m_recording.append(unique_key);
- m_recording.append(" ", 1);
-
- char prob_str[10] = { 0 };
- NumberUtils::Float_To_String(probability, prob_str);
- m_recording.append(prob_str);
-
- m_recording.append(" | ", 3);
- m_recording.append(context.To_String());
- m_recording.append("\n");
- }
-
- // Gets the content of the recording so far as a string and optionally clears internal content.
- string Get_Recording(bool flush = true)
- {
- if (!flush)
- {
- return m_recording;
- }
- string recording = m_recording;
- m_recording.clear();
- return recording;
- }
-
-private:
- string m_recording;
-};
-
-///
-/// Represents a feature in a sparse array.
-///
-struct Feature
-{
- float Value;
- u32 Id;
-
- bool operator==(Feature other_feature)
- {
- return Id == other_feature.Id;
- }
-};
-
-///
-/// A sample context class that stores a vector of Features.
-///
-class SimpleContext
-{
-public:
- SimpleContext(vector<Feature>& features) :
- m_features(features)
- { }
-
- vector<Feature>& Get_Features()
- {
- return m_features;
- }
-
- string To_String()
- {
- string out_string;
- char feature_str[35] = { 0 };
- for (size_t i = 0; i < m_features.size(); i++)
- {
- int chars;
- if (i == 0)
- {
- chars = sprintf(feature_str, "%d:", m_features[i].Id);
- }
- else
- {
- chars = sprintf(feature_str, " %d:", m_features[i].Id);
- }
- NumberUtils::print_float(feature_str + chars, m_features[i].Value);
- out_string.append(feature_str);
- }
- return out_string;
- }
-
-private:
- vector<Feature>& m_features;
-};
-
-///
-/// The epsilon greedy exploration algorithm. This is a good choice if you have no idea
-/// which actions should be preferred. Epsilon greedy is also computationally cheap.
-///
-template <class Ctx>
-class EpsilonGreedyExplorer : public IExplorer<Ctx>
-{
-public:
- ///
- /// The constructor is the only public member, because this should be used with the MwtExplorer.
- ///
- /// @param default_policy A default function which outputs an action given a context.
- /// @param epsilon The probability of a random exploration.
- /// @param num_actions The number of actions to randomize over.
- ///
- EpsilonGreedyExplorer(IPolicy<Ctx>& default_policy, float epsilon, u32 num_actions) :
- m_default_policy(default_policy), m_epsilon(epsilon), m_num_actions(num_actions)
- {
- if (m_num_actions < 1)
- {
- throw std::invalid_argument("Number of actions must be at least 1.");
- }
-
- if (m_epsilon < 0 || m_epsilon > 1)
- {
- throw std::invalid_argument("Epsilon must be between 0 and 1.");
- }
- }
-
- ~EpsilonGreedyExplorer()
- {
- }
-
-private:
- std::tuple<u32, float, bool> Choose_Action(u64 salted_seed, Ctx& context)
- {
- PRG::prg random_generator(salted_seed);
-
- // Invoke the default policy function to get the action
- u32 chosen_action = m_default_policy.Choose_Action(context);
-
- if (chosen_action == 0 || chosen_action > m_num_actions)
- {
- throw std::invalid_argument("Action chosen by default policy is not within valid range.");
- }
-
- float action_probability = 0.f;
- float base_probability = m_epsilon / m_num_actions; // uniform probability
-
- // TODO: check this random generation
- if (random_generator.Uniform_Unit_Interval() < 1.f - m_epsilon)
- {
- action_probability = 1.f - m_epsilon + base_probability;
- }
- else
- {
- // Get uniform random action ID
- u32 actionId = random_generator.Uniform_Int(1, m_num_actions);
-
- if (actionId == chosen_action)
- {
- // IF it matches the one chosen by the default policy
- // then increase the probability
- action_probability = 1.f - m_epsilon + base_probability;
- }
- else
- {
- // Otherwise it's just the uniform probability
- action_probability = base_probability;
- }
- chosen_action = actionId;
- }
-
- return std::tuple<u32, float, bool>(chosen_action, action_probability, true);
- }
-
-private:
- IPolicy<Ctx>& m_default_policy;
- float m_epsilon;
- u32 m_num_actions;
-
-private:
- friend class MwtExplorer<Ctx>;
-};
-
-///
-/// In some cases, different actions have a different scores, and you would prefer to
-/// choose actions with large scores. Softmax allows you to do that.
-///
-template <class Ctx>
-class SoftmaxExplorer : public IExplorer<Ctx>
-{
-public:
- ///
- /// The constructor is the only public member, because this should be used with the MwtExplorer.
- ///
- /// @param default_scorer A function which outputs a score for each action.
- /// @param lambda lambda = 0 implies uniform distribution. Large lambda is equivalent to a max.
- /// @param num_actions The number of actions to randomize over.
- ///
- SoftmaxExplorer(IScorer<Ctx>& default_scorer, float lambda, u32 num_actions) :
- m_default_scorer(default_scorer), m_lambda(lambda), m_num_actions(num_actions)
- {
- if (m_num_actions < 1)
- {
- throw std::invalid_argument("Number of actions must be at least 1.");
- }
- }
-
-private:
- std::tuple<u32, float, bool> Choose_Action(u64 salted_seed, Ctx& context)
- {
- PRG::prg random_generator(salted_seed);
-
- // Invoke the default scorer function
- vector<float> scores = m_default_scorer.Score_Actions(context);
- u32 num_scores = (u32)scores.size();
- if (num_scores != m_num_actions)
- {
- throw std::invalid_argument("The number of scores returned by the scorer must equal number of actions");
- }
-
- u32 i = 0;
-
- float max_score = -FLT_MAX;
- for (i = 0; i < num_scores; i++)
- {
- if (max_score < scores[i])
- {
- max_score = scores[i];
- }
- }
-
- // Create a normalized exponential distribution based on the returned scores
- for (i = 0; i < num_scores; i++)
- {
- scores[i] = exp(m_lambda * (scores[i] - max_score));
- }
-
- // Create a discrete_distribution based on the returned weights. This class handles the
- // case where the sum of the weights is < or > 1, by normalizing agains the sum.
- float total = 0.f;
- for (size_t i = 0; i < num_scores; i++)
- total += scores[i];
-
- float draw = random_generator.Uniform_Unit_Interval();
-
- float sum = 0.f;
- float action_probability = 0.f;
- u32 action_index = num_scores - 1;
- for (u32 i = 0; i < num_scores; i++)
- {
- scores[i] = scores[i] / total;
- sum += scores[i];
- if (sum > draw)
- {
- action_index = i;
- action_probability = scores[i];
- break;
- }
- }
-
- // action id is one-based
- return std::tuple<u32, float, bool>(action_index + 1, action_probability, true);
- }
-
-private:
- IScorer<Ctx>& m_default_scorer;
- float m_lambda;
- u32 m_num_actions;
-
-private:
- friend class MwtExplorer<Ctx>;
-};
-
-///
-/// GenericExplorer provides complete flexibility. You can create any
-/// distribution over actions desired, and it will draw from that.
-///
-template <class Ctx>
-class GenericExplorer : public IExplorer<Ctx>
-{
-public:
- ///
- /// The constructor is the only public member, because this should be used with the MwtExplorer.
- ///
- /// @param default_scorer A function which outputs the probability of each action.
- /// @param num_actions The number of actions to randomize over.
- ///
- GenericExplorer(IScorer<Ctx>& default_scorer, u32 num_actions) :
- m_default_scorer(default_scorer), m_num_actions(num_actions)
- {
- if (m_num_actions < 1)
- {
- throw std::invalid_argument("Number of actions must be at least 1.");
- }
- }
-
-private:
- std::tuple<u32, float, bool> Choose_Action(u64 salted_seed, Ctx& context)
- {
- PRG::prg random_generator(salted_seed);
-
- // Invoke the default scorer function
- vector<float> weights = m_default_scorer.Score_Actions(context);
- u32 num_weights = (u32)weights.size();
- if (num_weights != m_num_actions)
- {
- throw std::invalid_argument("The number of weights returned by the scorer must equal number of actions");
- }
-
- // Create a discrete_distribution based on the returned weights. This class handles the
- // case where the sum of the weights is < or > 1, by normalizing agains the sum.
- float total = 0.f;
- for (size_t i = 0; i < num_weights; i++)
- {
- if (weights[i] < 0)
- {
- throw std::invalid_argument("Scores must be non-negative.");
- }
- total += weights[i];
- }
- if (total == 0)
- {
- throw std::invalid_argument("At least one score must be positive.");
- }
-
- float draw = random_generator.Uniform_Unit_Interval();
-
- float sum = 0.f;
- float action_probability = 0.f;
- u32 action_index = num_weights - 1;
- for (u32 i = 0; i < num_weights; i++)
- {
- weights[i] = weights[i] / total;
- sum += weights[i];
- if (sum > draw)
- {
- action_index = i;
- action_probability = weights[i];
- break;
- }
- }
-
- // action id is one-based
- return std::tuple<u32, float, bool>(action_index + 1, action_probability, true);
- }
-
-private:
- IScorer<Ctx>& m_default_scorer;
- u32 m_num_actions;
-
-private:
- friend class MwtExplorer<Ctx>;
-};
-
-///
-/// The tau-first explorer collects exactly tau uniform random exploration events, and then
-/// uses the default policy thereafter.
-///
-template <class Ctx>
-class TauFirstExplorer : public IExplorer<Ctx>
-{
-public:
-
- ///
- /// The constructor is the only public member, because this should be used with the MwtExplorer.
- ///
- /// @param default_policy A default policy after randomization finishes.
- /// @param tau The number of events to be uniform over.
- /// @param num_actions The number of actions to randomize over.
- ///
- TauFirstExplorer(IPolicy<Ctx>& default_policy, u32 tau, u32 num_actions) :
- m_default_policy(default_policy), m_tau(tau), m_num_actions(num_actions)
- {
- if (m_num_actions < 1)
- {
- throw std::invalid_argument("Number of actions must be at least 1.");
- }
- }
-
-private:
- std::tuple<u32, float, bool> Choose_Action(u64 salted_seed, Ctx& context)
- {
- PRG::prg random_generator(salted_seed);
-
- u32 chosen_action = 0;
- float action_probability = 0.f;
- bool log_action;
- if (m_tau)
- {
- m_tau--;
- u32 actionId = random_generator.Uniform_Int(1, m_num_actions);
- action_probability = 1.f / m_num_actions;
- chosen_action = actionId;
- log_action = true;
- }
- else
- {
- // Invoke the default policy function to get the action
- chosen_action = m_default_policy.Choose_Action(context);
-
- if (chosen_action == 0 || chosen_action > m_num_actions)
- {
- throw std::invalid_argument("Action chosen by default policy is not within valid range.");
- }
-
- action_probability = 1.f;
- log_action = false;
- }
-
- return std::tuple<u32, float, bool>(chosen_action, action_probability, log_action);
- }
-
-private:
- IPolicy<Ctx>& m_default_policy;
- u32 m_tau;
- u32 m_num_actions;
-
-private:
- friend class MwtExplorer<Ctx>;
-};
-
-///
-/// The Bootstrap explorer randomizes over the actions chosen by a set of default policies.
-/// This performs well statistically but can be computationally expensive.
-///
-template <class Ctx>
-class BootstrapExplorer : public IExplorer<Ctx>
-{
-public:
- ///
- /// The constructor is the only public member, because this should be used with the MwtExplorer.
- ///
- /// @param default_policy_functions A set of default policies to be uniform random over.
- /// The policy pointers must be valid throughout the lifetime of this explorer.
- /// @param num_actions The number of actions to randomize over.
- ///
- BootstrapExplorer(vector<unique_ptr<IPolicy<Ctx>>>& default_policy_functions, u32 num_actions) :
- m_default_policy_functions(default_policy_functions),
- m_num_actions(num_actions)
- {
- m_bags = (u32)default_policy_functions.size();
- if (m_num_actions < 1)
- {
- throw std::invalid_argument("Number of actions must be at least 1.");
- }
-
- if (m_bags < 1)
- {
- throw std::invalid_argument("Number of bags must be at least 1.");
- }
- }
-
-private:
- std::tuple<u32, float, bool> Choose_Action(u64 salted_seed, Ctx& context)
- {
- PRG::prg random_generator(salted_seed);
-
- // Select bag
- u32 chosen_bag = random_generator.Uniform_Int(0, m_bags - 1);
-
- // Invoke the default policy function to get the action
- u32 chosen_action = 0;
- u32 action_from_bag = 0;
- vector<u32> actions_selected;
- for (size_t i = 0; i < m_num_actions; i++)
- {
- actions_selected.push_back(0);
- }
-
- // Invoke the default policy function to get the action
- for (u32 current_bag = 0; current_bag < m_bags; current_bag++)
- {
- action_from_bag = m_default_policy_functions[current_bag]->Choose_Action(context);
-
- if (action_from_bag == 0 || action_from_bag > m_num_actions)
- {
- throw std::invalid_argument("Action chosen by default policy is not within valid range.");
- }
-
- if (current_bag == chosen_bag)
- {
- chosen_action = action_from_bag;
- }
- //this won't work if actions aren't 0 to Count
- actions_selected[action_from_bag - 1]++; // action id is one-based
- }
- float action_probability = (float)actions_selected[chosen_action - 1] / m_bags; // action id is one-based
-
- return std::tuple<u32, float, bool>(chosen_action, action_probability, true);
- }
-
-private:
- vector<unique_ptr<IPolicy<Ctx>>>& m_default_policy_functions;
- u32 m_bags;
- u32 m_num_actions;
-
-private:
- friend class MwtExplorer<Ctx>;
-};
-} // End namespace MultiWorldTestingCpp
-/*! @} End of Doxygen Groups*/
+// +// Main interface for clients of the Multiworld testing (MWT) service. +// + +#pragma once + +#include <stdexcept> +#include <float.h> +#include <math.h> +#include <stdio.h> +#include <string.h> +#include <vector> +#include <utility> +#include <memory> +#include <limits.h> +#include <tuple> + +#ifdef MANAGED_CODE +#define PORTING_INTERFACE public +#define MWT_NAMESPACE namespace NativeMultiWorldTesting +#else +#define PORTING_INTERFACE private +#define MWT_NAMESPACE namespace MultiWorldTesting +#endif + +using namespace std; + +#include "utility.h" + +/** \defgroup MultiWorldTestingCpp +\brief C++ implementation, for sample usage see: https://github.com/sidsen/vowpal_wabbit/blob/v0/explore/explore_sample.cpp +*/ + +/*! +* \addtogroup MultiWorldTestingCpp +* @{ +*/ + +//! Interface for C++ version of Multiworld Testing library. +//! For sample usage see: https://github.com/sidsen/vowpal_wabbit/blob/v0/explore/explore_sample.cpp +MWT_NAMESPACE { + +// Forward declarations +template <class Ctx> +class IRecorder; +template <class Ctx> +class IExplorer; + +/// +/// The top-level MwtExplorer class. Using this enables principled and efficient exploration +/// over a set of possible actions, and ensures that the right bits are recorded. +/// +template <class Ctx> +class MwtExplorer +{ +public: + /// + /// Constructor + /// + /// @param appid This should be unique to your experiment or you risk nasty correlation bugs. + /// @param recorder A user-specified class for recording the appropriate bits for use in evaluation and learning. + /// + MwtExplorer(std::string app_id, IRecorder<Ctx>& recorder) : m_recorder(recorder) + { + m_app_id = HashUtils::Compute_Id_Hash(app_id); + } + + /// + /// Chooses an action by invoking an underlying exploration algorithm. This should be a + /// drop-in replacement for any existing policy function. + /// + /// @param explorer An existing exploration algorithm (one of the below) which uses the default policy as a callback. 
+ /// @param unique_key A unique identifier for the experimental unit. This could be a user id, a session id, etc.. + /// @param context The context upon which a decision is made. See SimpleContext below for an example. + /// + u32 Choose_Action(IExplorer<Ctx>& explorer, string unique_key, Ctx& context) + { + u64 seed = HashUtils::Compute_Id_Hash(unique_key); + + std::tuple<u32, float, bool> action_probability_log_tuple = explorer.Choose_Action(seed + m_app_id, context); + + u32 action = std::get<0>(action_probability_log_tuple); + float prob = std::get<1>(action_probability_log_tuple); + + if (std::get<2>(action_probability_log_tuple)) + { + m_recorder.Record(context, action, prob, unique_key); + } + + return action; + } + +private: + u64 m_app_id; + IRecorder<Ctx>& m_recorder; +}; + +/// +/// Exposes a method to record exploration data based on generic contexts. Exploration data +/// is specified as a set of tuples <context, action, probability, key> as described below. An +/// application passes an IRecorder object to the @MwtExplorer constructor. See +/// @StringRecorder for a sample IRecorder object. +/// +template <class Ctx> +class IRecorder +{ +public: + /// + /// Records the exploration data associated with a given decision. + /// + /// @param context A user-defined context for the decision + /// @param action The action chosen by an exploration algorithm given context + /// @param probability The probability the exploration algorithm chose said action + /// @param unique_key A user-defined unique identifer for the decision + /// + virtual void Record(Ctx& context, u32 action, float probability, string unique_key) = 0; +}; + +/// +/// Exposes a method to choose an action given a generic context, and obtain the relevant +/// exploration bits. Invokes IPolicy::Choose_Action internally. Do not implement this +/// interface yourself: instead, use the various exploration algorithms below, which +/// implement it for you. 
+/// +template <class Ctx> +class IExplorer +{ +public: + /// + /// Determines the action to take and the probability with which it was chosen, for a + /// given context. + /// + /// @param salted_seed A PRG seed based on a unique id information provided by the user + /// @param context A user-defined context for the decision + /// @returns The action to take, the probability it was chosen, and a flag indicating + /// whether to record this decision + /// + virtual std::tuple<u32, float, bool> Choose_Action(u64 salted_seed, Ctx& context) = 0; +}; + +/// +/// Exposes a method to choose an action given a generic context. IPolicy objects are +/// passed to (and invoked by) exploration algorithms to specify the default policy behavior. +/// +template <class Ctx> +class IPolicy +{ +public: + /// + /// Determines the action to take for a given context. + /// + /// @param context A user-defined context for the decision + /// @returns The action to take (1-based index) + /// + virtual u32 Choose_Action(Ctx& context) = 0; +}; + +/// +/// Exposes a method for specifying a score (weight) for each action given a generic context. +/// +template <class Ctx> +class IScorer +{ +public: + /// + /// Determines the score of each action for a given context. + /// + /// @param context A user-defined context for the decision + /// @returns A vector of scores indexed by action (1-based) + /// + virtual vector<float> Score_Actions(Ctx& context) = 0; +}; + +/// +/// A sample recorder class that converts the exploration tuple into string format. 
+/// +template <class Ctx> +struct StringRecorder : public IRecorder<Ctx> +{ + void Record(Ctx& context, u32 action, float probability, string unique_key) + { + // Implicitly enforce To_String() API on the context + m_recording.append(to_string((unsigned long)action)); + m_recording.append(" ", 1); + m_recording.append(unique_key); + m_recording.append(" ", 1); + + char prob_str[10] = { 0 }; + NumberUtils::Float_To_String(probability, prob_str); + m_recording.append(prob_str); + + m_recording.append(" | ", 3); + m_recording.append(context.To_String()); + m_recording.append("\n"); + } + + // Gets the content of the recording so far as a string and optionally clears internal content. + string Get_Recording(bool flush = true) + { + if (!flush) + { + return m_recording; + } + string recording = m_recording; + m_recording.clear(); + return recording; + } + +private: + string m_recording; +}; + +/// +/// Represents a feature in a sparse array. +/// +struct Feature +{ + float Value; + u32 Id; + + bool operator==(Feature other_feature) + { + return Id == other_feature.Id; + } +}; + +/// +/// A sample context class that stores a vector of Features. +/// +class SimpleContext +{ +public: + SimpleContext(vector<Feature>& features) : + m_features(features) + { } + + vector<Feature>& Get_Features() + { + return m_features; + } + + string To_String() + { + string out_string; + char feature_str[35] = { 0 }; + for (size_t i = 0; i < m_features.size(); i++) + { + int chars; + if (i == 0) + { + chars = sprintf(feature_str, "%d:", m_features[i].Id); + } + else + { + chars = sprintf(feature_str, " %d:", m_features[i].Id); + } + NumberUtils::print_float(feature_str + chars, m_features[i].Value); + out_string.append(feature_str); + } + return out_string; + } + +private: + vector<Feature>& m_features; +}; + +/// +/// The epsilon greedy exploration algorithm. This is a good choice if you have no idea +/// which actions should be preferred. Epsilon greedy is also computationally cheap. 
+/// +template <class Ctx> +class EpsilonGreedyExplorer : public IExplorer<Ctx> +{ +public: + /// + /// The constructor is the only public member, because this should be used with the MwtExplorer. + /// + /// @param default_policy A default function which outputs an action given a context. + /// @param epsilon The probability of a random exploration. + /// @param num_actions The number of actions to randomize over. + /// + EpsilonGreedyExplorer(IPolicy<Ctx>& default_policy, float epsilon, u32 num_actions) : + m_default_policy(default_policy), m_epsilon(epsilon), m_num_actions(num_actions) + { + if (m_num_actions < 1) + { + throw std::invalid_argument("Number of actions must be at least 1."); + } + + if (m_epsilon < 0 || m_epsilon > 1) + { + throw std::invalid_argument("Epsilon must be between 0 and 1."); + } + } + + ~EpsilonGreedyExplorer() + { + } + +private: + std::tuple<u32, float, bool> Choose_Action(u64 salted_seed, Ctx& context) + { + PRG::prg random_generator(salted_seed); + + // Invoke the default policy function to get the action + u32 chosen_action = m_default_policy.Choose_Action(context); + + if (chosen_action == 0 || chosen_action > m_num_actions) + { + throw std::invalid_argument("Action chosen by default policy is not within valid range."); + } + + float action_probability = 0.f; + float base_probability = m_epsilon / m_num_actions; // uniform probability + + // TODO: check this random generation + if (random_generator.Uniform_Unit_Interval() < 1.f - m_epsilon) + { + action_probability = 1.f - m_epsilon + base_probability; + } + else + { + // Get uniform random action ID + u32 actionId = random_generator.Uniform_Int(1, m_num_actions); + + if (actionId == chosen_action) + { + // IF it matches the one chosen by the default policy + // then increase the probability + action_probability = 1.f - m_epsilon + base_probability; + } + else + { + // Otherwise it's just the uniform probability + action_probability = base_probability; + } + chosen_action = 
actionId; + } + + return std::tuple<u32, float, bool>(chosen_action, action_probability, true); + } + +private: + IPolicy<Ctx>& m_default_policy; + float m_epsilon; + u32 m_num_actions; + +private: + friend class MwtExplorer<Ctx>; +}; + +/// +/// In some cases, different actions have a different scores, and you would prefer to +/// choose actions with large scores. Softmax allows you to do that. +/// +template <class Ctx> +class SoftmaxExplorer : public IExplorer<Ctx> +{ +public: + /// + /// The constructor is the only public member, because this should be used with the MwtExplorer. + /// + /// @param default_scorer A function which outputs a score for each action. + /// @param lambda lambda = 0 implies uniform distribution. Large lambda is equivalent to a max. + /// @param num_actions The number of actions to randomize over. + /// + SoftmaxExplorer(IScorer<Ctx>& default_scorer, float lambda, u32 num_actions) : + m_default_scorer(default_scorer), m_lambda(lambda), m_num_actions(num_actions) + { + if (m_num_actions < 1) + { + throw std::invalid_argument("Number of actions must be at least 1."); + } + } + +private: + std::tuple<u32, float, bool> Choose_Action(u64 salted_seed, Ctx& context) + { + PRG::prg random_generator(salted_seed); + + // Invoke the default scorer function + vector<float> scores = m_default_scorer.Score_Actions(context); + u32 num_scores = (u32)scores.size(); + if (num_scores != m_num_actions) + { + throw std::invalid_argument("The number of scores returned by the scorer must equal number of actions"); + } + + u32 i = 0; + + float max_score = -FLT_MAX; + for (i = 0; i < num_scores; i++) + { + if (max_score < scores[i]) + { + max_score = scores[i]; + } + } + + // Create a normalized exponential distribution based on the returned scores + for (i = 0; i < num_scores; i++) + { + scores[i] = exp(m_lambda * (scores[i] - max_score)); + } + + // Create a discrete_distribution based on the returned weights. 
This class handles the + // case where the sum of the weights is < or > 1, by normalizing agains the sum. + float total = 0.f; + for (size_t i = 0; i < num_scores; i++) + total += scores[i]; + + float draw = random_generator.Uniform_Unit_Interval(); + + float sum = 0.f; + float action_probability = 0.f; + u32 action_index = num_scores - 1; + for (u32 i = 0; i < num_scores; i++) + { + scores[i] = scores[i] / total; + sum += scores[i]; + if (sum > draw) + { + action_index = i; + action_probability = scores[i]; + break; + } + } + + // action id is one-based + return std::tuple<u32, float, bool>(action_index + 1, action_probability, true); + } + +private: + IScorer<Ctx>& m_default_scorer; + float m_lambda; + u32 m_num_actions; + +private: + friend class MwtExplorer<Ctx>; +}; + +/// +/// GenericExplorer provides complete flexibility. You can create any +/// distribution over actions desired, and it will draw from that. +/// +template <class Ctx> +class GenericExplorer : public IExplorer<Ctx> +{ +public: + /// + /// The constructor is the only public member, because this should be used with the MwtExplorer. + /// + /// @param default_scorer A function which outputs the probability of each action. + /// @param num_actions The number of actions to randomize over. 
+ /// + GenericExplorer(IScorer<Ctx>& default_scorer, u32 num_actions) : + m_default_scorer(default_scorer), m_num_actions(num_actions) + { + if (m_num_actions < 1) + { + throw std::invalid_argument("Number of actions must be at least 1."); + } + } + +private: + std::tuple<u32, float, bool> Choose_Action(u64 salted_seed, Ctx& context) + { + PRG::prg random_generator(salted_seed); + + // Invoke the default scorer function + vector<float> weights = m_default_scorer.Score_Actions(context); + u32 num_weights = (u32)weights.size(); + if (num_weights != m_num_actions) + { + throw std::invalid_argument("The number of weights returned by the scorer must equal number of actions"); + } + + // Create a discrete_distribution based on the returned weights. This class handles the + // case where the sum of the weights is < or > 1, by normalizing agains the sum. + float total = 0.f; + for (size_t i = 0; i < num_weights; i++) + { + if (weights[i] < 0) + { + throw std::invalid_argument("Scores must be non-negative."); + } + total += weights[i]; + } + if (total == 0) + { + throw std::invalid_argument("At least one score must be positive."); + } + + float draw = random_generator.Uniform_Unit_Interval(); + + float sum = 0.f; + float action_probability = 0.f; + u32 action_index = num_weights - 1; + for (u32 i = 0; i < num_weights; i++) + { + weights[i] = weights[i] / total; + sum += weights[i]; + if (sum > draw) + { + action_index = i; + action_probability = weights[i]; + break; + } + } + + // action id is one-based + return std::tuple<u32, float, bool>(action_index + 1, action_probability, true); + } + +private: + IScorer<Ctx>& m_default_scorer; + u32 m_num_actions; + +private: + friend class MwtExplorer<Ctx>; +}; + +/// +/// The tau-first explorer collects exactly tau uniform random exploration events, and then +/// uses the default policy thereafter. 
+/// +template <class Ctx> +class TauFirstExplorer : public IExplorer<Ctx> +{ +public: + + /// + /// The constructor is the only public member, because this should be used with the MwtExplorer. + /// + /// @param default_policy A default policy after randomization finishes. + /// @param tau The number of events to be uniform over. + /// @param num_actions The number of actions to randomize over. + /// + TauFirstExplorer(IPolicy<Ctx>& default_policy, u32 tau, u32 num_actions) : + m_default_policy(default_policy), m_tau(tau), m_num_actions(num_actions) + { + if (m_num_actions < 1) + { + throw std::invalid_argument("Number of actions must be at least 1."); + } + } + +private: + std::tuple<u32, float, bool> Choose_Action(u64 salted_seed, Ctx& context) + { + PRG::prg random_generator(salted_seed); + + u32 chosen_action = 0; + float action_probability = 0.f; + bool log_action; + if (m_tau) + { + m_tau--; + u32 actionId = random_generator.Uniform_Int(1, m_num_actions); + action_probability = 1.f / m_num_actions; + chosen_action = actionId; + log_action = true; + } + else + { + // Invoke the default policy function to get the action + chosen_action = m_default_policy.Choose_Action(context); + + if (chosen_action == 0 || chosen_action > m_num_actions) + { + throw std::invalid_argument("Action chosen by default policy is not within valid range."); + } + + action_probability = 1.f; + log_action = false; + } + + return std::tuple<u32, float, bool>(chosen_action, action_probability, log_action); + } + +private: + IPolicy<Ctx>& m_default_policy; + u32 m_tau; + u32 m_num_actions; + +private: + friend class MwtExplorer<Ctx>; +}; + +/// +/// The Bootstrap explorer randomizes over the actions chosen by a set of default policies. +/// This performs well statistically but can be computationally expensive. 
+/// +template <class Ctx> +class BootstrapExplorer : public IExplorer<Ctx> +{ +public: + /// + /// The constructor is the only public member, because this should be used with the MwtExplorer. + /// + /// @param default_policy_functions A set of default policies to be uniform random over. + /// The policy pointers must be valid throughout the lifetime of this explorer. + /// @param num_actions The number of actions to randomize over. + /// + BootstrapExplorer(vector<unique_ptr<IPolicy<Ctx>>>& default_policy_functions, u32 num_actions) : + m_default_policy_functions(default_policy_functions), + m_num_actions(num_actions) + { + m_bags = (u32)default_policy_functions.size(); + if (m_num_actions < 1) + { + throw std::invalid_argument("Number of actions must be at least 1."); + } + + if (m_bags < 1) + { + throw std::invalid_argument("Number of bags must be at least 1."); + } + } + +private: + std::tuple<u32, float, bool> Choose_Action(u64 salted_seed, Ctx& context) + { + PRG::prg random_generator(salted_seed); + + // Select bag + u32 chosen_bag = random_generator.Uniform_Int(0, m_bags - 1); + + // Invoke the default policy function to get the action + u32 chosen_action = 0; + u32 action_from_bag = 0; + vector<u32> actions_selected; + for (size_t i = 0; i < m_num_actions; i++) + { + actions_selected.push_back(0); + } + + // Invoke the default policy function to get the action + for (u32 current_bag = 0; current_bag < m_bags; current_bag++) + { + action_from_bag = m_default_policy_functions[current_bag]->Choose_Action(context); + + if (action_from_bag == 0 || action_from_bag > m_num_actions) + { + throw std::invalid_argument("Action chosen by default policy is not within valid range."); + } + + if (current_bag == chosen_bag) + { + chosen_action = action_from_bag; + } + //this won't work if actions aren't 0 to Count + actions_selected[action_from_bag - 1]++; // action id is one-based + } + float action_probability = (float)actions_selected[chosen_action - 1] / m_bags; // 
action id is one-based + + return std::tuple<u32, float, bool>(chosen_action, action_probability, true); + } + +private: + vector<unique_ptr<IPolicy<Ctx>>>& m_default_policy_functions; + u32 m_bags; + u32 m_num_actions; + +private: + friend class MwtExplorer<Ctx>; +}; +} // End namespace MultiWorldTestingCpp +/*! @} End of Doxygen Groups*/ diff --git a/explore/static/utility.h b/explore/static/utility.h index a6284809..6a2a60d6 100644 --- a/explore/static/utility.h +++ b/explore/static/utility.h @@ -1,286 +1,286 @@ /*******************************************************************/ // Classes declared in this file are intended for internal use only. -/*******************************************************************/
-
-#pragma once
-#include <stdint.h>
-#include <sys/types.h> /* defines size_t */
-
// Short fixed-width integer aliases used throughout the library.
// When WIN32 is defined the compiler's sized __intN keywords are used;
// otherwise the <stdint.h> types are used.
#ifdef WIN32
using u64 = unsigned __int64;
using u32 = unsigned __int32;
using u16 = unsigned __int16;
using u8  = unsigned __int8;
using i64 = signed __int64;
using i32 = signed __int32;
using i16 = signed __int16;
using i8  = signed __int8;
#else
using u64 = uint64_t;
using u32 = uint32_t;
using u16 = uint16_t;
using u8  = uint8_t;
using i64 = int64_t;
using i32 = int32_t;
using i16 = int16_t;
using i8  = int8_t;
#endif

// Raw octet alias.
using byte = unsigned char;
-
-#include <string>
-#include <stdint.h>
-#include <math.h>
-
-/*!
-* \addtogroup MultiWorldTestingCpp
-* @{
-*/
-
-MWT_NAMESPACE {
-
-//
-// MurmurHash3, by Austin Appleby
-//
-// Originals at:
-// http://code.google.com/p/smhasher/source/browse/trunk/MurmurHash3.cpp
-// http://code.google.com/p/smhasher/source/browse/trunk/MurmurHash3.h
-//
-// Notes:
-// 1) this code assumes we can read a 4-byte value from any address
-// without crashing (i.e non aligned access is supported). This is
-// not a problem on Intel/x86/AMD64 machines (including new Macs)
-// 2) It produces different results on little-endian and big-endian machines.
-//
-//-----------------------------------------------------------------------------
-// MurmurHash3 was written by Austin Appleby, and is placed in the public
-// domain. The author hereby disclaims copyright to this source code.
-
-// Note - The x86 and x64 versions do _not_ produce the same results, as the
-// algorithms are optimized for their respective platforms. You can still
-// compile and run any of them on any platform, but your performance with the
-// non-native version will be less than optimal.
-//-----------------------------------------------------------------------------
-
// Platform-specific functions and macros
//
// ROTL32(x, y) rotates the 32-bit value x left by y bits.
// BIG_CONSTANT(x) spells a 64-bit literal for the current compiler.
#if defined(_MSC_VER) // Microsoft Visual Studio
#  include <stdint.h>

#  include <stdlib.h>
#  define ROTL32(x,y) _rotl(x,y)
#  define BIG_CONSTANT(x) (x)

#else // Other compilers
#  include <stdint.h> /* defines uint32_t etc */

/// Rotate x left by r bits.
///
/// Only the low five bits of r are used, so r == 0 and r == 32 both
/// return x unchanged. Both shift counts are masked into 0..31 because
/// shifting a 32-bit value by 32 is undefined behaviour, which the
/// unmasked form hit for r == 0.
inline uint32_t rotl32(uint32_t x, int8_t r)
{
    const unsigned shift = (unsigned)r & 31u;
    return (x << shift) | (x >> ((32u - shift) & 31u));
}

#  define ROTL32(x,y) rotl32(x,y)
#  define BIG_CONSTANT(x) (x##LLU)

#endif // !defined(_MSC_VER)
-
/// MurmurHash3, 32-bit x86 variant (see the provenance notes above).
///
/// Blocks are read in the machine's native byte order, so results differ
/// between little-endian and big-endian hosts.
struct murmur_hash {

private:
    //-----------------------------------------------------------------------------
    // Block read - returns the i-th 4-byte block starting at p.
    // The bytes are copied one at a time into a local word instead of
    // dereferencing a (possibly misaligned) uint32_t pointer, so the key may
    // start at any address. Byte order of the result is the host's.
    static inline uint32_t getblock(const uint8_t * p, size_t i)
    {
        uint32_t block;
        unsigned char * out = (unsigned char *)&block;
        const uint8_t * in = p + 4 * i;
        out[0] = in[0];
        out[1] = in[1];
        out[2] = in[2];
        out[3] = in[3];
        return block;
    }

    // Left rotation for the fixed counts used below (15 and 13); r must be in 1..31.
    static inline uint32_t rotl(uint32_t x, unsigned r)
    {
        return (x << r) | (x >> (32u - r));
    }

    //-----------------------------------------------------------------------------
    // Finalization mix - force all bits of a hash block to avalanche
    static inline uint32_t fmix(uint32_t h)
    {
        h ^= h >> 16;
        h *= 0x85ebca6b;
        h ^= h >> 13;
        h *= 0xc2b2ae35;
        h ^= h >> 16;

        return h;
    }

    //-----------------------------------------------------------------------------
public:
    /// Hash len bytes starting at key.
    ///
    /// @param key   Start of the data; may be unaligned. Not read when len is 0.
    /// @param len   Number of bytes to hash.
    /// @param seed  Initial hash state; different seeds give different hashes.
    /// @return      The 32-bit hash value.
    uint32_t uniform_hash(const void * key, size_t len, uint32_t seed)
    {
        const uint8_t * data = (const uint8_t*)key;
        const size_t nblocks = len / 4;

        uint32_t h1 = seed;

        const uint32_t c1 = 0xcc9e2d51;
        const uint32_t c2 = 0x1b873593;

        // --- body: mix every complete 4-byte block, in order
        for (size_t i = 0; i < nblocks; ++i) {
            uint32_t k1 = getblock(data, i);

            k1 *= c1;
            k1 = rotl(k1, 15);
            k1 *= c2;

            h1 ^= k1;
            h1 = rotl(h1, 13);
            h1 = h1 * 5 + 0xe6546b64;
        }

        // --- tail: the 0-3 bytes left after the last complete block
        const uint8_t * tail = data + nblocks * 4;

        uint32_t k1 = 0;

        switch (len & 3) {
        case 3: k1 ^= (uint32_t)tail[2] << 16; // fall through
        case 2: k1 ^= (uint32_t)tail[1] << 8;  // fall through
        case 1: k1 ^= tail[0];
            k1 *= c1; k1 = rotl(k1, 15); k1 *= c2; h1 ^= k1;
        }

        // --- finalization: only the low 32 bits of the length are mixed in
        h1 ^= (uint32_t)len;

        return fmix(h1);
    }
};
-
-class HashUtils
-{
-public:
- static u64 Compute_Id_Hash(const std::string& unique_id)
- {
- size_t ret = 0;
- const char *p = unique_id.c_str();
- while (*p != '\0')
- if (*p >= '0' && *p <= '9')
- ret = 10 * ret + *(p++) - '0';
- else
- {
- murmur_hash foo;
- return foo.uniform_hash(unique_id.c_str(), unique_id.size(), 0);
- }
- return ret;
- }
-};
-
// Limits for NumberUtils::print_float: magnitudes strictly between min_float
// and max_float are printed digit by digit with up to max_digits decimals;
// anything else (except zero) falls back to "%g".
const size_t max_int = 100000;
const float max_float = max_int;
const float min_float = 0.00001f;
const size_t max_digits = (size_t) roundf((float) (-log(min_float) / log(10.)));

/// Float-to-text helpers. None of them take a buffer size; the caller must
/// supply a buffer large enough for the formatted value plus the NUL.
class NumberUtils
{
public:
    /// Write f as "<integer part>.<five decimals>" (truncated, not rounded).
    ///
    /// @param f    Value to format; its integer part must fit in an int.
    /// @param str  Destination buffer, NUL-terminated on return.
    static void Float_To_String(float f, char* str)
    {
        int x = (int)f;
        int d = (int)(fabs(f - x) * 100000);
        // "%d" cannot show a sign when the integer part truncates to 0,
        // so values in (-1, 0) need the minus written explicitly.
        if (f < 0.f && x == 0 && d != 0)
            *str++ = '-';
        sprintf(str, "%d.%05d", x, d);
    }

    /// Helper for print_float: append the decimal digits of (size_t)f at begin.
    ///
    /// trailing_zeros == true  : integer part - digits are written as-is
    ///                           (nothing is written for 0).
    /// trailing_zeros == false : fractional part scaled by max_int - padded
    ///                           with leading zeros to max_digits and with
    ///                           trailing zeros removed.
    ///
    /// @param begin  Write cursor, advanced past the digits written. No NUL
    ///               is appended.
    /// @param f      Non-negative value; only its integer part is used.
    template<bool trailing_zeros>
    static void print_mantissa(char*& begin, float f)
    {
        char values[10]; // digits, least significant first
        size_t v = (size_t)f;
        size_t digit = 0;
        size_t first_nonzero = 0;
        // Explicit flag: index 0 is a valid position, so first_nonzero itself
        // cannot double as the "not found yet" marker.
        bool seen_nonzero = false;
        // The digit bound keeps values[] in range even if a caller passes a
        // value with more than ten digits.
        for (size_t max = 1; max <= v && digit < sizeof(values); ++digit)
        {
            size_t max_next = max * 10;
            char v_mod = (char) (v % max_next / max);
            if (!trailing_zeros && v_mod != '\0' && !seen_nonzero)
            {
                first_nonzero = digit;
                seen_nonzero = true;
            }
            values[digit] = '0' + v_mod;
            max = max_next;
        }
        if (!trailing_zeros)
            for (size_t i = max_digits; i > digit; i--)
                *begin++ = '0';
        // Emit most significant first, stopping at the lowest non-zero digit.
        while (digit > first_nonzero)
            *begin++ = values[--digit];
    }

    /// Write a compact decimal form of f into begin, NUL-terminated.
    ///
    /// Values with min_float < |f| < max_float are written as
    /// [-]<integer digits>[.<up to max_digits decimals>]; a zero integer part
    /// is omitted (0.5 is written ".5"). Zero is written "0". Everything else
    /// is formatted with "%g".
    static void print_float(char* begin, float f)
    {
        bool sign = false;
        if (f < 0.f)
            sign = true;
        float unsigned_f = fabsf(f);
        if (unsigned_f < max_float && unsigned_f > min_float)
        {
            if (sign)
                *begin++ = '-';
            print_mantissa<true>(begin, unsigned_f);
            // Keep only the fraction and scale it to an integer number of
            // 1/max_int units.
            unsigned_f -= (size_t)unsigned_f;
            unsigned_f *= max_int;
            if (unsigned_f >= 1.f)
            {
                *begin++ = '.';
                print_mantissa<false>(begin, unsigned_f);
            }
        }
        else if (unsigned_f == 0.)
            *begin++ = '0';
        else
        {
            sprintf(begin, "%g", f);
            return;
        }
        *begin = '\0';
        return;
    }
};
-
// A quick implementation similar to drand48 for cross-platform compatibility.
// The generator is a 64-bit linear congruential sequence: state = a * state + c.
namespace PRG {
    const uint64_t a = 0xeece66d5deece66dULL; // multiplier
    const uint64_t c = 2147483647;            // increment, also the default seed

    // Bit pattern of the float 1.0f's exponent; OR-ing 23 mantissa bits into
    // it yields a float in [1, 2).
    const int bias = 127 << 23;

    // Used to reinterpret the assembled bit pattern as a float.
    // NOTE(review): reading a union member other than the one last written
    // relies on compiler support for this idiom.
    union int_float {
        int32_t i;
        float f;
    };

    /// Small deterministic pseudo-random generator; the same seed always
    /// produces the same sequence.
    struct prg {
    private:
        uint64_t v; // current state
    public:
        prg() { v = c; }
        prg(uint64_t initial) { v = initial; }

        /// Advance the state passed in and return a float in [0, 1) built
        /// from bits 25..47 of the new state.
        float merand48(uint64_t& initial)
        {
            initial = a * initial + c;
            int_float temp;
            temp.i = ((initial >> 25) & 0x7FFFFF) | bias;
            return temp.f - 1;
        }

        /// Next value in [0, 1).
        float Uniform_Unit_Interval()
        {
            return merand48(v);
        }

        /// Next integer in [low, high], both ends inclusive.
        ///
        /// The generator is advanced exactly once per call. If the range is
        /// empty (high == low - 1), low is returned. The result is the state's
        /// upper bits reduced modulo the range size, so it is only
        /// approximately uniform.
        uint32_t Uniform_Int(uint32_t low, uint32_t high)
        {
            merand48(v);
            const uint32_t span = high - low + 1;
            if (span == 0)
            {
                // The range size wrapped to zero; a modulo here would divide
                // by zero. Either the range covers all 2^32 values, or it is
                // empty.
                if (low == 0 && high == 0xFFFFFFFFu)
                    return (uint32_t)(v >> 25);
                return low;
            }
            return low + (uint32_t)((v >> 25) % span);
        }
    };
}
-}
+/*******************************************************************/ + +#pragma once +#include <stdint.h> +#include <sys/types.h> /* defines size_t */ + +#ifdef WIN32 +typedef unsigned __int64 u64; +typedef unsigned __int32 u32; +typedef unsigned __int16 u16; +typedef unsigned __int8 u8; +typedef signed __int64 i64; +typedef signed __int32 i32; +typedef signed __int16 i16; +typedef signed __int8 i8; +#else +typedef uint64_t u64; +typedef uint32_t u32; +typedef uint16_t u16; +typedef uint8_t u8; +typedef int64_t i64; +typedef int32_t i32; +typedef int16_t i16; +typedef int8_t i8; +#endif + +typedef unsigned char byte; + +#include <string> +#include <stdint.h> +#include <math.h> + +/*! +* \addtogroup MultiWorldTestingCpp +* @{ +*/ + +MWT_NAMESPACE { + +// +// MurmurHash3, by Austin Appleby +// +// Originals at: +// http://code.google.com/p/smhasher/source/browse/trunk/MurmurHash3.cpp +// http://code.google.com/p/smhasher/source/browse/trunk/MurmurHash3.h +// +// Notes: +// 1) this code assumes we can read a 4-byte value from any address +// without crashing (i.e non aligned access is supported). This is +// not a problem on Intel/x86/AMD64 machines (including new Macs) +// 2) It produces different results on little-endian and big-endian machines. +// +//----------------------------------------------------------------------------- +// MurmurHash3 was written by Austin Appleby, and is placed in the public +// domain. The author hereby disclaims copyright to this source code. + +// Note - The x86 and x64 versions do _not_ produce the same results, as the +// algorithms are optimized for their respective platforms. You can still +// compile and run any of them on any platform, but your performance with the +// non-native version will be less than optimal. 
+//----------------------------------------------------------------------------- + +// Platform-specific functions and macros +#if defined(_MSC_VER) // Microsoft Visual Studio +# include <stdint.h> + +# include <stdlib.h> +# define ROTL32(x,y) _rotl(x,y) +# define BIG_CONSTANT(x) (x) + +#else // Other compilers +# include <stdint.h> /* defines uint32_t etc */ + + inline uint32_t rotl32(uint32_t x, int8_t r) + { + return (x << r) | (x >> (32 - r)); + } + +# define ROTL32(x,y) rotl32(x,y) +# define BIG_CONSTANT(x) (x##LLU) + +#endif // !defined(_MSC_VER) + +struct murmur_hash { + + //----------------------------------------------------------------------------- + // Block read - if your platform needs to do endian-swapping or can only + // handle aligned reads, do the conversion here +private: + static inline uint32_t getblock(const uint32_t * p, int i) + { + return p[i]; + } + + //----------------------------------------------------------------------------- + // Finalization mix - force all bits of a hash block to avalanche + + static inline uint32_t fmix(uint32_t h) + { + h ^= h >> 16; + h *= 0x85ebca6b; + h ^= h >> 13; + h *= 0xc2b2ae35; + h ^= h >> 16; + + return h; + } + + //----------------------------------------------------------------------------- +public: + uint32_t uniform_hash(const void * key, size_t len, uint32_t seed) + { + const uint8_t * data = (const uint8_t*)key; + const int nblocks = (int)len / 4; + + uint32_t h1 = seed; + + const uint32_t c1 = 0xcc9e2d51; + const uint32_t c2 = 0x1b873593; + + // --- body + const uint32_t * blocks = (const uint32_t *)(data + nblocks * 4); + + for (int i = -nblocks; i; i++) { + uint32_t k1 = getblock(blocks, i); + + k1 *= c1; + k1 = ROTL32(k1, 15); + k1 *= c2; + + h1 ^= k1; + h1 = ROTL32(h1, 13); + h1 = h1 * 5 + 0xe6546b64; + } + + // --- tail + const uint8_t * tail = (const uint8_t*)(data + nblocks * 4); + + uint32_t k1 = 0; + + switch (len & 3) { + case 3: k1 ^= tail[2] << 16; + case 2: k1 ^= tail[1] << 8; + case 
1: k1 ^= tail[0]; + k1 *= c1; k1 = ROTL32(k1, 15); k1 *= c2; h1 ^= k1; + } + + // --- finalization + h1 ^= len; + + return fmix(h1); + } +}; + +class HashUtils +{ +public: + static u64 Compute_Id_Hash(const std::string& unique_id) + { + size_t ret = 0; + const char *p = unique_id.c_str(); + while (*p != '\0') + if (*p >= '0' && *p <= '9') + ret = 10 * ret + *(p++) - '0'; + else + { + murmur_hash foo; + return foo.uniform_hash(unique_id.c_str(), unique_id.size(), 0); + } + return ret; + } +}; + +const size_t max_int = 100000; +const float max_float = max_int; +const float min_float = 0.00001f; +const size_t max_digits = (size_t) roundf((float) (-log(min_float) / log(10.))); + +class NumberUtils +{ +public: + static void Float_To_String(float f, char* str) + { + int x = (int)f; + int d = (int)(fabs(f - x) * 100000); + sprintf(str, "%d.%05d", x, d); + } + + template<bool trailing_zeros> + static void print_mantissa(char*& begin, float f) + { // helper for print_float + char values[10]; + size_t v = (size_t)f; + size_t digit = 0; + size_t first_nonzero = 0; + for (size_t max = 1; max <= v; ++digit) + { + size_t max_next = max * 10; + char v_mod = (char) (v % max_next / max); + if (!trailing_zeros && v_mod != '\0' && first_nonzero == 0) + first_nonzero = digit; + values[digit] = '0' + v_mod; + max = max_next; + } + if (!trailing_zeros) + for (size_t i = max_digits; i > digit; i--) + *begin++ = '0'; + while (digit > first_nonzero) + *begin++ = values[--digit]; + } + + static void print_float(char* begin, float f) + { + bool sign = false; + if (f < 0.f) + sign = true; + float unsigned_f = fabsf(f); + if (unsigned_f < max_float && unsigned_f > min_float) + { + if (sign) + *begin++ = '-'; + print_mantissa<true>(begin, unsigned_f); + unsigned_f -= (size_t)unsigned_f; + unsigned_f *= max_int; + if (unsigned_f >= 1.f) + { + *begin++ = '.'; + print_mantissa<false>(begin, unsigned_f); + } + } + else if (unsigned_f == 0.) 
+ *begin++ = '0'; + else + { + sprintf(begin, "%g", f); + return; + } + *begin = '\0'; + return; + } +}; + +//A quick implementation similar to drand48 for cross-platform compatibility +namespace PRG { + const uint64_t a = 0xeece66d5deece66dULL; + const uint64_t c = 2147483647; + + const int bias = 127 << 23; + + union int_float { + int32_t i; + float f; + }; + + struct prg { + private: + uint64_t v; + public: + prg() { v = c; } + prg(uint64_t initial) { v = initial; } + + float merand48(uint64_t& initial) + { + initial = a * initial + c; + int_float temp; + temp.i = ((initial >> 25) & 0x7FFFFF) | bias; + return temp.f - 1; + } + + float Uniform_Unit_Interval() + { + return merand48(v); + } + + uint32_t Uniform_Int(uint32_t low, uint32_t high) + { + merand48(v); + uint32_t ret = low + ((v >> 25) % (high - low + 1)); + return ret; + } + }; +} +} /*! @} End of Doxygen Groups*/ diff --git a/explore/tests/MWTExploreTests.h b/explore/tests/MWTExploreTests.h index 447a368f..1c451fb4 100644 --- a/explore/tests/MWTExploreTests.h +++ b/explore/tests/MWTExploreTests.h @@ -1,164 +1,164 @@ -#pragma once
-
-#include "MWTExplorer.h"
-#include "utility.h"
-#include <iomanip>
-#include <iostream>
-#include <sstream>
-
-using namespace MultiWorldTesting;
-
/// Empty context type; the test policies and scorers never read it.
class TestContext {};
-
-template <class Ctx>
-struct TestInteraction
-{
- Ctx& Context;
- u32 Action;
- float Probability;
- string Unique_Key;
-};
-
-class TestPolicy : public IPolicy<TestContext>
-{
-public:
- TestPolicy(int params, int num_actions) : m_params(params), m_num_actions(num_actions) { }
- u32 Choose_Action(TestContext& context)
- {
- return m_params % m_num_actions + 1; // action id is one-based
- }
-private:
- int m_params;
- int m_num_actions;
-};
-
-class TestScorer : public IScorer<TestContext>
-{
-public:
- TestScorer(int params, int num_actions, bool uniform = true) :
- m_params(params), m_num_actions(num_actions), m_uniform(uniform)
- { }
-
- vector<float> Score_Actions(TestContext& context)
- {
- vector<float> scores;
- if (m_uniform)
- {
- for (u32 i = 0; i < m_num_actions; i++)
- {
- scores.push_back(m_params);
- }
- }
- else
- {
- for (u32 i = 0; i < m_num_actions; i++)
- {
- scores.push_back(m_params + i);
- }
- }
- return scores;
- }
-private:
- int m_params;
- int m_num_actions;
- bool m_uniform;
-};
-
-class FixedScorer : public IScorer<TestContext>
-{
-public:
- FixedScorer(int num_actions, int value) :
- m_num_actions(num_actions), m_value(value)
- { }
-
- vector<float> Score_Actions(TestContext& context)
- {
- vector<float> scores;
- for (u32 i = 0; i < m_num_actions; i++)
- {
- scores.push_back((float)m_value);
- }
- return scores;
- }
-private:
- int m_num_actions;
- int m_value;
-};
-
-class TestSimpleScorer : public IScorer<SimpleContext>
-{
-public:
- TestSimpleScorer(int params, int num_actions) : m_params(params), m_num_actions(num_actions) { }
- vector<float> Score_Actions(SimpleContext& context)
- {
- vector<float> scores;
- for (u32 i = 0; i < m_num_actions; i++)
- {
- scores.push_back(m_params);
- }
- return scores;
- }
-private:
- int m_params;
- int m_num_actions;
-};
-
-class TestSimplePolicy : public IPolicy<SimpleContext>
-{
-public:
- TestSimplePolicy(int params, int num_actions) : m_params(params), m_num_actions(num_actions) { }
- u32 Choose_Action(SimpleContext& context)
- {
- return m_params % m_num_actions + 1; // action id is one-based
- }
-private:
- int m_params;
- int m_num_actions;
-};
-
-class TestSimpleRecorder : public IRecorder<SimpleContext>
-{
-public:
- virtual void Record(SimpleContext& context, u32 action, float probability, string unique_key)
- {
- m_interactions.push_back({ context, action, probability, unique_key });
- }
-
- vector<TestInteraction<SimpleContext>> Get_All_Interactions()
- {
- return m_interactions;
- }
-
-private:
- vector<TestInteraction<SimpleContext>> m_interactions;
-};
-
-// Return action outside valid range
-class TestBadPolicy : public IPolicy<TestContext>
-{
-public:
- u32 Choose_Action(TestContext& context)
- {
- return 100;
- }
-};
-
-class TestRecorder : public IRecorder<TestContext>
-{
-public:
- virtual void Record(TestContext& context, u32 action, float probability, string unique_key)
- {
- m_interactions.push_back({ context, action, probability, unique_key });
- }
-
- vector<TestInteraction<TestContext>> Get_All_Interactions()
- {
- return m_interactions;
- }
-
-private:
- vector<TestInteraction<TestContext>> m_interactions;
-};
+#pragma once + +#include "MWTExplorer.h" +#include "utility.h" +#include <iomanip> +#include <iostream> +#include <sstream> + +using namespace MultiWorldTesting; + +class TestContext +{ + +}; + +template <class Ctx> +struct TestInteraction +{ + Ctx& Context; + u32 Action; + float Probability; + string Unique_Key; +}; + +class TestPolicy : public IPolicy<TestContext> +{ +public: + TestPolicy(int params, int num_actions) : m_params(params), m_num_actions(num_actions) { } + u32 Choose_Action(TestContext& context) + { + return m_params % m_num_actions + 1; // action id is one-based + } +private: + int m_params; + int m_num_actions; +}; + +class TestScorer : public IScorer<TestContext> +{ +public: + TestScorer(int params, int num_actions, bool uniform = true) : + m_params(params), m_num_actions(num_actions), m_uniform(uniform) + { } + + vector<float> Score_Actions(TestContext& context) + { + vector<float> scores; + if (m_uniform) + { + for (u32 i = 0; i < m_num_actions; i++) + { + scores.push_back(m_params); + } + } + else + { + for (u32 i = 0; i < m_num_actions; i++) + { + scores.push_back(m_params + i); + } + } + return scores; + } +private: + int m_params; + int m_num_actions; + bool m_uniform; +}; + +class FixedScorer : public IScorer<TestContext> +{ +public: + FixedScorer(int num_actions, int value) : + m_num_actions(num_actions), m_value(value) + { } + + vector<float> Score_Actions(TestContext& context) + { + vector<float> scores; + for (u32 i = 0; i < m_num_actions; i++) + { + scores.push_back((float)m_value); + } + return scores; + } +private: + int m_num_actions; + int m_value; +}; + +class TestSimpleScorer : public IScorer<SimpleContext> +{ +public: + TestSimpleScorer(int params, int num_actions) : m_params(params), m_num_actions(num_actions) { } + vector<float> Score_Actions(SimpleContext& context) + { + vector<float> scores; + for (u32 i = 0; i < m_num_actions; i++) + { + scores.push_back(m_params); + } + return scores; + } +private: + int m_params; + int 
m_num_actions; +}; + +class TestSimplePolicy : public IPolicy<SimpleContext> +{ +public: + TestSimplePolicy(int params, int num_actions) : m_params(params), m_num_actions(num_actions) { } + u32 Choose_Action(SimpleContext& context) + { + return m_params % m_num_actions + 1; // action id is one-based + } +private: + int m_params; + int m_num_actions; +}; + +class TestSimpleRecorder : public IRecorder<SimpleContext> +{ +public: + virtual void Record(SimpleContext& context, u32 action, float probability, string unique_key) + { + m_interactions.push_back({ context, action, probability, unique_key }); + } + + vector<TestInteraction<SimpleContext>> Get_All_Interactions() + { + return m_interactions; + } + +private: + vector<TestInteraction<SimpleContext>> m_interactions; +}; + +// Return action outside valid range +class TestBadPolicy : public IPolicy<TestContext> +{ +public: + u32 Choose_Action(TestContext& context) + { + return 100; + } +}; + +class TestRecorder : public IRecorder<TestContext> +{ +public: + virtual void Record(TestContext& context, u32 action, float probability, string unique_key) + { + m_interactions.push_back({ context, action, probability, unique_key }); + } + + vector<TestInteraction<TestContext>> Get_All_Interactions() + { + return m_interactions; + } + +private: + vector<TestInteraction<TestContext>> m_interactions; +}; diff --git a/java/pom.xml b/java/pom.xml index 5ef652bb..db48a576 100644 --- a/java/pom.xml +++ b/java/pom.xml @@ -1,12 +1,11 @@ <?xml version="1.0" encoding="UTF-8"?> -<project xmlns="http://maven.apache.org/POM/4.0.0" - xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" - xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"> +<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"> <modelVersion>4.0.0</modelVersion> 
<groupId>vw</groupId> <artifactId>vw-jni-${os.version}-${os.arch}</artifactId> - <version>1.0-SNAPSHOT</version> + <version>1.0.0-SNAPSHOT</version> + <name>Vowpal Wabbit JNI Layer</name> <properties> <slf4j.version>1.7.6</slf4j.version> @@ -28,14 +27,6 @@ </dependencies> <build> - <resources> - <resource> - <directory>${project.build.directory}</directory> - <includes> - <include>vw_jni.lib</include> - </includes> - </resource> - </resources> <testResources> <testResource> <directory>${project.build.directory}</directory> @@ -68,7 +59,7 @@ <configuration> <exportAntProperties>true</exportAntProperties> <target> - <exec executable="make" failonerror="true"/> + <exec executable="make" failonerror="true" /> </target> </configuration> </execution> @@ -82,7 +73,7 @@ <exportAntProperties>true</exportAntProperties> <target> <exec executable="make" failonerror="true"> - <arg value="clean"/> + <arg value="clean" /> </exec> </target> </configuration> @@ -145,6 +136,53 @@ <redirectTestOutputToFile>true</redirectTestOutputToFile> </configuration> </plugin> + <plugin> + <artifactId>maven-resources-plugin</artifactId> + <version>2.7</version> + <executions> + <execution> + <id>copy-resources</id> + <phase>process-classes</phase> + <goals> + <goal>copy-resources</goal> + </goals> + <configuration> + <outputDirectory>${project.build.outputDirectory}</outputDirectory> + <resources> + <resource> + <directory>${project.build.directory}</directory> + <includes> + <include>vw_jni.lib</include> + </includes> + </resource> + </resources> + </configuration> + </execution> + </executions> + </plugin> + <plugin> + <groupId>org.apache.maven.plugins</groupId> + <artifactId>maven-release-plugin</artifactId> + <version>2.4.2</version> + <dependencies> + <dependency> + <groupId>org.apache.maven.scm</groupId> + <artifactId>maven-scm-provider-gitexe</artifactId> + <version>1.3</version> + </dependency> + </dependencies> + <executions> + <execution> + <id>default</id> + <goals> + 
<goal>perform</goal> + </goals> + <configuration> + <pomFileName>java/pom.xml</pomFileName> + </configuration> + </execution> + </executions> + </plugin> </plugins> </build> -</project>
\ No newline at end of file +</project> diff --git a/java/src/main/c++/vw_VWScorer.cc b/java/src/main/c++/vw_VWScorer.cc index 652d722b..dce1869e 100644 --- a/java/src/main/c++/vw_VWScorer.cc +++ b/java/src/main/c++/vw_VWScorer.cc @@ -4,13 +4,11 @@ vw* vw; -JNIEXPORT void JNICALL Java_vw_VWScorer_initialize - (JNIEnv *env, jobject obj, jstring command) { +JNIEXPORT void JNICALL Java_vw_VWScorer_initialize (JNIEnv *env, jobject obj, jstring command) { vw = VW::initialize(env->GetStringUTFChars(command, NULL)); } -JNIEXPORT jfloat JNICALL Java_vw_VWScorer_getPrediction - (JNIEnv *env, jobject obj, jstring example_string) { +JNIEXPORT jfloat JNICALL Java_vw_VWScorer_getPrediction (JNIEnv *env, jobject obj, jstring example_string) { example *vec2 = VW::read_example(*vw, env->GetStringUTFChars(example_string, NULL)); vw->l->predict(*vec2); float prediction; @@ -18,7 +16,7 @@ JNIEXPORT jfloat JNICALL Java_vw_VWScorer_getPrediction prediction = vec2->pred.scalar; else prediction = vec2->pred.multiclass; - VW::finish_example(*vw, vec2); + VW::finish_example(*vw, vec2); return prediction; } diff --git a/java/src/test/java/vw/VWScorerTest.java b/java/src/test/java/vw/VWScorerTest.java index 740fefce..05b6ae61 100644 --- a/java/src/test/java/vw/VWScorerTest.java +++ b/java/src/test/java/vw/VWScorerTest.java @@ -3,6 +3,8 @@ package vw; import org.junit.BeforeClass; import org.junit.Test; +import java.io.IOException; + import static org.junit.Assert.assertEquals; /** @@ -12,19 +14,23 @@ public class VWScorerTest { private static VWScorer scorer; @BeforeClass - public static void setup() { - scorer = new VWScorer("-i src/test/resources/house.model --quiet -t"); + public static void setup() throws IOException, InterruptedException { + // Since we want this test to continue to work between VW changes, we can't store the model + // Instead, we'll make a new model for each test + String vwModel = "target/house.model"; + Runtime.getRuntime().exec("../vowpalwabbit/vw -d 
src/test/resources/house.vw -f " + vwModel).waitFor(); + scorer = new VWScorer("--quiet -t -i " + vwModel); } @Test public void testBlank() { float prediction = scorer.getPrediction("| "); - assertEquals(0.07496, prediction, 0.00001); + assertEquals(0.075, prediction, 0.001); } @Test public void testLine() { float prediction = scorer.getPrediction("| price:0.23 sqft:0.25 age:0.05 2006"); - assertEquals(0.118467, prediction, 0.00001); + assertEquals(0.118, prediction, 0.001); } } diff --git a/java/src/test/resources/house.model b/java/src/test/resources/house.model Binary files differdeleted file mode 100644 index 20a3f310..00000000 --- a/java/src/test/resources/house.model +++ /dev/null diff --git a/java/src/test/resources/house.vw b/java/src/test/resources/house.vw new file mode 100644 index 00000000..87da95e5 --- /dev/null +++ b/java/src/test/resources/house.vw @@ -0,0 +1,3 @@ +0 | price:.23 sqft:.25 age:.05 2006 +1 2 'second_house | price:.18 sqft:.15 age:.35 1976 +0 1 0.5 'third_house | price:.53 sqft:.32 age:.87 1924 diff --git a/library/Makefile b/library/Makefile index 8f7deb35..5ca18f35 100644 --- a/library/Makefile +++ b/library/Makefile @@ -40,3 +40,4 @@ gd_mf_weights: gd_mf_weights.cc ../vowpalwabbit/libvw.a clean: rm -f *.o ezexample_predict ezexample_train library_example test_search recommend ezexample_predict_threaded +.PHONY: all clean diff --git a/library/ezexample_predict.cc b/library/ezexample_predict.cc index db061f61..22c8a86d 100644 --- a/library/ezexample_predict.cc +++ b/library/ezexample_predict.cc @@ -1,55 +1,55 @@ -#include <stdio.h>
-#include "../vowpalwabbit/parser.h"
-#include "../vowpalwabbit/vw.h"
-#include "../vowpalwabbit/ezexample.h"
-
-using namespace std;
-
-int main(int argc, char *argv[])
-{
- string init_string = "-t -q st --hash all --noconstant --ldf_override s -i ";
- if (argc > 1)
- init_string += argv[1];
- else
- init_string += "train.w";
-
- cerr << "initializing with: '" << init_string << "'" << endl;
-
- // INITIALIZE WITH WHATEVER YOU WOULD PUT ON THE VW COMMAND LINE -- THIS READS IN A MODEL FROM train.w
- vw* vw = VW::initialize(init_string); // "-t -q st --hash all --noconstant --ldf_override s -i train.w");
-
- {
- // HAL'S SPIFFY INTERFACE USING C++ CRAZINESS
- ezexample ex(vw, false); // don't need multiline
- ex(vw_namespace('s'))
- ("p^the_man")
- ("w^the")
- ("w^man")
- (vw_namespace('t'))
- ("p^le_homme")
- ("w^le")
- ("w^homme");
- ex.set_label("1");
- cerr << ex.predict_partial() << endl;
-
- // ex.clear_features();
-
- --ex; // remove the most recent namespace
- ex(vw_namespace('t'))
- ("p^un_homme")
- ("w^un")
- ("w^homme");
- ex.set_label("2");
- cerr << ex.predict_partial() << endl;
-
- --ex; // remove the most recent namespace, and add features with explicit ns
- ex('t', "p^un_homme")
- ('t', "w^un")
- ('t', "w^homme");
- ex.set_label("2");
- cerr << ex.predict_partial() << endl;
- }
-
- // AND FINISH UP
- VW::finish(*vw);
-}
+#include <stdio.h> +#include "../vowpalwabbit/parser.h" +#include "../vowpalwabbit/vw.h" +#include "../vowpalwabbit/ezexample.h" + +using namespace std; + +int main(int argc, char *argv[]) +{ + string init_string = "-t -q st --hash all --noconstant --ldf_override s -i "; + if (argc > 1) + init_string += argv[1]; + else + init_string += "train.w"; + + cerr << "initializing with: '" << init_string << "'" << endl; + + // INITIALIZE WITH WHATEVER YOU WOULD PUT ON THE VW COMMAND LINE -- THIS READS IN A MODEL FROM train.w + vw* vw = VW::initialize(init_string); // "-t -q st --hash all --noconstant --ldf_override s -i train.w"); + + { + // HAL'S SPIFFY INTERFACE USING C++ CRAZINESS + ezexample ex(vw, false); // don't need multiline + ex(vw_namespace('s')) + ("p^the_man") + ("w^the") + ("w^man") + (vw_namespace('t')) + ("p^le_homme") + ("w^le") + ("w^homme"); + ex.set_label("1"); + cerr << ex.predict_partial() << endl; + + // ex.clear_features(); + + --ex; // remove the most recent namespace + ex(vw_namespace('t')) + ("p^un_homme") + ("w^un") + ("w^homme"); + ex.set_label("2"); + cerr << ex.predict_partial() << endl; + + --ex; // remove the most recent namespace, and add features with explicit ns + ex('t', "p^un_homme") + ('t', "w^un") + ('t', "w^homme"); + ex.set_label("2"); + cerr << ex.predict_partial() << endl; + } + + // AND FINISH UP + VW::finish(*vw); +} diff --git a/library/ezexample_predict_threaded.cc b/library/ezexample_predict_threaded.cc index 0fa5b1e6..c7c39e85 100644 --- a/library/ezexample_predict_threaded.cc +++ b/library/ezexample_predict_threaded.cc @@ -1,149 +1,149 @@ -#include <stdio.h>
-#include "../vowpalwabbit/vw.h"
-#include "../vowpalwabbit/ezexample.h"
-
-#include <boost/thread/thread.hpp>
-
-using namespace std;
-
-int runcount = 100;
-
-class Worker
-{
-public:
- Worker(vw & instance, string & vw_init_string, vector<double> & ref)
- : m_vw(instance)
- , m_referenceValues(ref)
- , vw_init_string(vw_init_string)
- { }
-
- void operator()()
- {
- m_vw_parser = VW::initialize(vw_init_string);
- if (m_vw_parser == NULL) {
- cerr << "cannot initialize vw parser" << endl;
- exit(-1);
- }
-
- int errorCount = 0;
- for (int i = 0; i < runcount; ++i)
- {
- vector<double>::iterator it = m_referenceValues.begin();
- ezexample ex(&m_vw, false, m_vw_parser);
-
- ex(vw_namespace('s'))
- ("p^the_man")
- ("w^the")
- ("w^man")
- (vw_namespace('t'))
- ("p^le_homme")
- ("w^le")
- ("w^homme");
- ex.set_label("1");
- if (*it != ex()) { cerr << "fail!" << endl; ++errorCount; }
- //if (*it != pred) { cerr << "fail!" << endl; ++errorCount; }
- //VW::finish_example(m_vw, vec2);
- ++it;
-
- --ex; // remove the most recent namespace
- ex(vw_namespace('t'))
- ("p^un_homme")
- ("w^un")
- ("w^homme");
- ex.set_label("1");
- if (*it != ex()) { cerr << "fail!" << endl; ++errorCount; }
- ++it;
-
- --ex; // remove the most recent namespace
- // add features with explicit ns
- ex('t', "p^un_homme")
- ('t', "w^un")
- ('t', "w^homme");
- ex.set_label("1");
- if (*it != ex()) { cerr << "fail!" << endl; ++errorCount; }
- ++it;
-
- //cout << "."; cout.flush();
- }
- cerr << "error count = " << errorCount << endl;
- VW::finish(*m_vw_parser);
- m_vw_parser = NULL;
- }
-
-private:
- vw & m_vw;
- vw * m_vw_parser;
- vector<double> & m_referenceValues;
- string & vw_init_string;
-};
-
-int main(int argc, char *argv[])
-{
- if (argc != 3)
- {
- cerr << "need two args: threadcount runcount" << endl;
- return 1;
- }
- int threadcount = atoi(argv[1]);
- runcount = atoi(argv[2]);
- // INITIALIZE WITH WHATEVER YOU WOULD PUT ON THE VW COMMAND LINE -- THIS READS IN A MODEL FROM train.w
- string vw_init_string_all = "-t --ldf_override s --quiet -q st --noconstant --hash all -i train.w";
- string vw_init_string_parser = "-t --ldf_override s --quiet -q st --noconstant --hash all --noop"; // this needs to have enough arguments to get the parser right
- vw*vw = VW::initialize(vw_init_string_all);
- vector<double> results;
-
- // HAL'S SPIFFY INTERFACE USING C++ CRAZINESS
- {
- ezexample ex(vw, false);
- ex(vw_namespace('s'))
- ("p^the_man")
- ("w^the")
- ("w^man")
- (vw_namespace('t'))
- ("p^le_homme")
- ("w^le")
- ("w^homme");
- ex.set_label("1");
- results.push_back(ex.predict_partial());
- cerr << "should be near zero = " << ex.predict_partial() << endl;
-
- --ex; // remove the most recent namespace
- ex(vw_namespace('t'))
- ("p^un_homme")
- ("w^un")
- ("w^homme");
- ex.set_label("1");
- results.push_back(ex.predict_partial());
- cerr << "should be near one = " << ex.predict_partial() << endl;
-
- --ex; // remove the most recent namespace
- // add features with explicit ns
- ex('t', "p^un_homme")
- ('t', "w^un")
- ('t', "w^homme");
- ex.set_label("1");
- results.push_back(ex.predict_partial());
- cerr << "should be near one = " << ex.predict_partial() << endl;
- }
-
- if (threadcount == 0)
- {
- Worker w(*vw, vw_init_string_parser, results);
- w();
- }
- else
- {
- boost::thread_group tg;
- for (int t = 0; t < threadcount; ++t)
- {
- cerr << "starting thread " << t << endl;
- boost::thread * pt = tg.create_thread(Worker(*vw, vw_init_string_parser, results));
- }
- tg.join_all();
- cerr << "finished!" << endl;
- }
-
-
- // AND FINISH UP
- VW::finish(*vw);
-}
+#include <stdio.h> +#include "../vowpalwabbit/vw.h" +#include "../vowpalwabbit/ezexample.h" + +#include <boost/thread/thread.hpp> + +using namespace std; + +int runcount = 100; + +class Worker +{ +public: + Worker(vw & instance, string & vw_init_string, vector<double> & ref) + : m_vw(instance) + , m_referenceValues(ref) + , vw_init_string(vw_init_string) + { } + + void operator()() + { + m_vw_parser = VW::initialize(vw_init_string); + if (m_vw_parser == NULL) { + cerr << "cannot initialize vw parser" << endl; + exit(-1); + } + + int errorCount = 0; + for (int i = 0; i < runcount; ++i) + { + vector<double>::iterator it = m_referenceValues.begin(); + ezexample ex(&m_vw, false, m_vw_parser); + + ex(vw_namespace('s')) + ("p^the_man") + ("w^the") + ("w^man") + (vw_namespace('t')) + ("p^le_homme") + ("w^le") + ("w^homme"); + ex.set_label("1"); + if (*it != ex()) { cerr << "fail!" << endl; ++errorCount; } + //if (*it != pred) { cerr << "fail!" << endl; ++errorCount; } + //VW::finish_example(m_vw, vec2); + ++it; + + --ex; // remove the most recent namespace + ex(vw_namespace('t')) + ("p^un_homme") + ("w^un") + ("w^homme"); + ex.set_label("1"); + if (*it != ex()) { cerr << "fail!" << endl; ++errorCount; } + ++it; + + --ex; // remove the most recent namespace + // add features with explicit ns + ex('t', "p^un_homme") + ('t', "w^un") + ('t', "w^homme"); + ex.set_label("1"); + if (*it != ex()) { cerr << "fail!" 
<< endl; ++errorCount; } + ++it; + + //cout << "."; cout.flush(); + } + cerr << "error count = " << errorCount << endl; + VW::finish(*m_vw_parser); + m_vw_parser = NULL; + } + +private: + vw & m_vw; + vw * m_vw_parser; + vector<double> & m_referenceValues; + string & vw_init_string; +}; + +int main(int argc, char *argv[]) +{ + if (argc != 3) + { + cerr << "need two args: threadcount runcount" << endl; + return 1; + } + int threadcount = atoi(argv[1]); + runcount = atoi(argv[2]); + // INITIALIZE WITH WHATEVER YOU WOULD PUT ON THE VW COMMAND LINE -- THIS READS IN A MODEL FROM train.w + string vw_init_string_all = "-t --ldf_override s --quiet -q st --noconstant --hash all -i train.w"; + string vw_init_string_parser = "-t --ldf_override s --quiet -q st --noconstant --hash all --noop"; // this needs to have enough arguments to get the parser right + vw*vw = VW::initialize(vw_init_string_all); + vector<double> results; + + // HAL'S SPIFFY INTERFACE USING C++ CRAZINESS + { + ezexample ex(vw, false); + ex(vw_namespace('s')) + ("p^the_man") + ("w^the") + ("w^man") + (vw_namespace('t')) + ("p^le_homme") + ("w^le") + ("w^homme"); + ex.set_label("1"); + results.push_back(ex.predict_partial()); + cerr << "should be near zero = " << ex.predict_partial() << endl; + + --ex; // remove the most recent namespace + ex(vw_namespace('t')) + ("p^un_homme") + ("w^un") + ("w^homme"); + ex.set_label("1"); + results.push_back(ex.predict_partial()); + cerr << "should be near one = " << ex.predict_partial() << endl; + + --ex; // remove the most recent namespace + // add features with explicit ns + ex('t', "p^un_homme") + ('t', "w^un") + ('t', "w^homme"); + ex.set_label("1"); + results.push_back(ex.predict_partial()); + cerr << "should be near one = " << ex.predict_partial() << endl; + } + + if (threadcount == 0) + { + Worker w(*vw, vw_init_string_parser, results); + w(); + } + else + { + boost::thread_group tg; + for (int t = 0; t < threadcount; ++t) + { + cerr << "starting thread " << t << 
endl; + boost::thread * pt = tg.create_thread(Worker(*vw, vw_init_string_parser, results)); + } + tg.join_all(); + cerr << "finished!" << endl; + } + + + // AND FINISH UP + VW::finish(*vw); +} diff --git a/library/ezexample_train.cc b/library/ezexample_train.cc index a0f66a99..9a8af8e0 100644 --- a/library/ezexample_train.cc +++ b/library/ezexample_train.cc @@ -1,72 +1,72 @@ -#include <stdio.h>
-#include "../vowpalwabbit/parser.h"
-#include "../vowpalwabbit/vw.h"
-#include "../vowpalwabbit/ezexample.h"
-
-using namespace std;
-
-void run(vw*vw) {
- ezexample ex(vw, true); // we're doing csoaa_ldf so we need multiline examples
-
- /// BEGIN FIRST MULTILINE EXAMPLE
- ex(vw_namespace('s'))
- ("p^the_man")
- ("w^the")
- ("w^man")
- (vw_namespace('t'))
- ("p^un_homme")
- ("w^un")
- ("w^homme");
-
- ex.set_label("1:1");
- ex.train();
-
- --ex; // remove the most recent namespace
- ex(vw_namespace('t'))
- ("p^le_homme")
- ("w^le")
- ("w^homme");
-
- ex.set_label("2:0");
- ex.train();
-
- // push it through VW for training
- ex.finish();
-
- /// BEGIN SECOND MULTILINE EXAMPLE
- ex(vw_namespace('s'))
- ("p^a_man")
- ("w^a")
- ("w^man")
- (vw_namespace('t'))
- ("p^un_homme")
- ("w^un")
- ("w^homme");
-
- ex.set_label("1:0");
- ex.train();
-
- --ex; // remove the most recent namespace
- ex(vw_namespace('t'))
- ("p^le_homme")
- ("w^le")
- ("w^homme");
-
- ex.set_label("2:1");
- ex.train();
-
- // push it through VW for training
- ex.finish();
-}
-
-int main(int argc, char *argv[])
-{
- // INITIALIZE WITH WHATEVER YOU WOULD PUT ON THE VW COMMAND LINE -- THIS WILL STORE A MODEL TO train.ezw
- vw* vw = VW::initialize("--hash all -q st --noconstant -f train.w --quiet --csoaa_ldf m");
-
- run(vw);
-
- // AND FINISH UP
- cerr << "ezexample_train finish"<<endl;
- VW::finish(*vw);
-}
+#include <stdio.h> +#include "../vowpalwabbit/parser.h" +#include "../vowpalwabbit/vw.h" +#include "../vowpalwabbit/ezexample.h" + +using namespace std; + +void run(vw*vw) { + ezexample ex(vw, true); // we're doing csoaa_ldf so we need multiline examples + + /// BEGIN FIRST MULTILINE EXAMPLE + ex(vw_namespace('s')) + ("p^the_man") + ("w^the") + ("w^man") + (vw_namespace('t')) + ("p^un_homme") + ("w^un") + ("w^homme"); + + ex.set_label("1:1"); + ex.train(); + + --ex; // remove the most recent namespace + ex(vw_namespace('t')) + ("p^le_homme") + ("w^le") + ("w^homme"); + + ex.set_label("2:0"); + ex.train(); + + // push it through VW for training + ex.finish(); + + /// BEGIN SECOND MULTILINE EXAMPLE + ex(vw_namespace('s')) + ("p^a_man") + ("w^a") + ("w^man") + (vw_namespace('t')) + ("p^un_homme") + ("w^un") + ("w^homme"); + + ex.set_label("1:0"); + ex.train(); + + --ex; // remove the most recent namespace + ex(vw_namespace('t')) + ("p^le_homme") + ("w^le") + ("w^homme"); + + ex.set_label("2:1"); + ex.train(); + + // push it through VW for training + ex.finish(); +} + +int main(int argc, char *argv[]) +{ + // INITIALIZE WITH WHATEVER YOU WOULD PUT ON THE VW COMMAND LINE -- THIS WILL STORE A MODEL TO train.ezw + vw* vw = VW::initialize("--hash all -q st --noconstant -f train.w --quiet --csoaa_ldf m"); + + run(vw); + + // AND FINISH UP + cerr << "ezexample_train finish"<<endl; + VW::finish(*vw); +} diff --git a/library/gd_mf_weights.cc b/library/gd_mf_weights.cc index 4b394775..268739a4 100644 --- a/library/gd_mf_weights.cc +++ b/library/gd_mf_weights.cc @@ -52,12 +52,17 @@ int main(int argc, char *argv[]) vw* model = VW::initialize(vwparams);
model->audit = true;
+ string target("--rank ");
+ size_t loc = vwparams.find(target);
+ const char* location = vwparams.c_str()+loc+target.size();
+ size_t rank = atoi(location);
+
// global model params
unsigned char left_ns = model->pairs[0][0];
unsigned char right_ns = model->pairs[0][1];
weight* weights = model->reg.weight_vector;
size_t mask = model->reg.weight_mask;
-
+
// const char *filename = argv[0];
FILE* file = fopen(infile.c_str(), "r");
char* line = NULL;
@@ -86,7 +91,7 @@ int main(int argc, char *argv[]) left_linear << f->feature << '\t' << weights[f->weight_index & mask];
left_quadratic << f->feature;
- for (size_t k = 1; k <= model->rank; k++)
+ for (size_t k = 1; k <= rank; k++)
left_quadratic << '\t' << weights[(f->weight_index + k) & mask];
}
left_linear << endl;
@@ -101,8 +106,8 @@ int main(int argc, char *argv[]) right_linear << f->feature << '\t' << weights[f->weight_index & mask];
right_quadratic << f->feature;
- for (size_t k = 1; k <= model->rank; k++)
- right_quadratic << '\t' << weights[(f->weight_index + k + model->rank) & mask];
+ for (size_t k = 1; k <= rank; k++)
+ right_quadratic << '\t' << weights[(f->weight_index + k + rank) & mask];
}
right_linear << endl;
right_quadratic << endl;
diff --git a/library/library_example.cc b/library/library_example.cc index 8abfc1f7..d7c186c2 100644 --- a/library/library_example.cc +++ b/library/library_example.cc @@ -1,62 +1,62 @@ -#include <stdio.h>
-#include "../vowpalwabbit/parser.h"
-#include "../vowpalwabbit/vw.h"
-
-using namespace std;
-
-
-inline feature vw_feature_from_string(vw& v, string fstr, unsigned long seed, float val)
-{
- uint32_t foo = VW::hash_feature(v, fstr, seed);
- feature f = { val, foo};
- return f;
-}
-
-int main(int argc, char *argv[])
-{
- vw* model = VW::initialize("--hash all -q st --noconstant -i train.w -f train2.vw");
-
- example *vec2 = VW::read_example(*model, (char*)"|s p^the_man w^the w^man |t p^un_homme w^un w^homme");
- model->learn(vec2);
- cerr << "p2 = " << vec2->pred.scalar << endl;
- VW::finish_example(*model, vec2);
-
- vector< VW::feature_space > ec_info;
- vector<feature> s_features, t_features;
- uint32_t s_hash = VW::hash_space(*model, "s");
- uint32_t t_hash = VW::hash_space(*model, "t");
- s_features.push_back( vw_feature_from_string(*model, "p^the_man", s_hash, 1.0) );
- s_features.push_back( vw_feature_from_string(*model, "w^the", s_hash, 1.0) );
- s_features.push_back( vw_feature_from_string(*model, "w^man", s_hash, 1.0) );
- t_features.push_back( vw_feature_from_string(*model, "p^le_homme", t_hash, 1.0) );
- t_features.push_back( vw_feature_from_string(*model, "w^le", t_hash, 1.0) );
- t_features.push_back( vw_feature_from_string(*model, "w^homme", t_hash, 1.0) );
- ec_info.push_back( VW::feature_space('s', s_features) );
- ec_info.push_back( VW::feature_space('t', t_features) );
- example* vec3 = VW::import_example(*model, ec_info);
-
- model->learn(vec3);
- cerr << "p3 = " << vec3->pred.scalar << endl;
- VW::finish_example(*model, vec3);
-
- VW::finish(*model);
-
- vw* model2 = VW::initialize("--hash all -q st --noconstant -i train2.vw");
- vec2 = VW::read_example(*model2, (char*)" |s p^the_man w^the w^man |t p^un_homme w^un w^homme");
- model2->learn(vec2);
- cerr << "p4 = " << vec2->pred.scalar << endl;
-
- size_t len=0;
- VW::primitive_feature_space* pfs = VW::export_example(*model2, vec2, len);
- for (size_t i = 0; i < len; i++)
- {
- cout << "namespace = " << pfs[i].name;
- for (size_t j = 0; j < pfs[i].len; j++)
- cout << " " << pfs[i].fs[j].weight_index << ":" << pfs[i].fs[j].x << ":" << VW::get_weight(*model2, pfs[i].fs[j].weight_index, 0);
- cout << endl;
- }
-
- VW::finish_example(*model2, vec2);
- VW::finish(*model2);
-}
-
+#include <stdio.h> +#include "../vowpalwabbit/parser.h" +#include "../vowpalwabbit/vw.h" + +using namespace std; + + +inline feature vw_feature_from_string(vw& v, string fstr, unsigned long seed, float val) +{ + uint32_t foo = VW::hash_feature(v, fstr, seed); + feature f = { val, foo}; + return f; +} + +int main(int argc, char *argv[]) +{ + vw* model = VW::initialize("--hash all -q st --noconstant -i train.w -f train2.vw"); + + example *vec2 = VW::read_example(*model, (char*)"|s p^the_man w^the w^man |t p^un_homme w^un w^homme"); + model->learn(vec2); + cerr << "p2 = " << vec2->pred.scalar << endl; + VW::finish_example(*model, vec2); + + vector< VW::feature_space > ec_info; + vector<feature> s_features, t_features; + uint32_t s_hash = VW::hash_space(*model, "s"); + uint32_t t_hash = VW::hash_space(*model, "t"); + s_features.push_back( vw_feature_from_string(*model, "p^the_man", s_hash, 1.0) ); + s_features.push_back( vw_feature_from_string(*model, "w^the", s_hash, 1.0) ); + s_features.push_back( vw_feature_from_string(*model, "w^man", s_hash, 1.0) ); + t_features.push_back( vw_feature_from_string(*model, "p^le_homme", t_hash, 1.0) ); + t_features.push_back( vw_feature_from_string(*model, "w^le", t_hash, 1.0) ); + t_features.push_back( vw_feature_from_string(*model, "w^homme", t_hash, 1.0) ); + ec_info.push_back( VW::feature_space('s', s_features) ); + ec_info.push_back( VW::feature_space('t', t_features) ); + example* vec3 = VW::import_example(*model, ec_info); + + model->learn(vec3); + cerr << "p3 = " << vec3->pred.scalar << endl; + VW::finish_example(*model, vec3); + + VW::finish(*model); + + vw* model2 = VW::initialize("--hash all -q st --noconstant -i train2.vw"); + vec2 = VW::read_example(*model2, (char*)" |s p^the_man w^the w^man |t p^un_homme w^un w^homme"); + model2->learn(vec2); + cerr << "p4 = " << vec2->pred.scalar << endl; + + size_t len=0; + VW::primitive_feature_space* pfs = VW::export_example(*model2, vec2, len); + for (size_t i = 0; i < len; i++) + 
{ + cout << "namespace = " << pfs[i].name; + for (size_t j = 0; j < pfs[i].len; j++) + cout << " " << pfs[i].fs[j].weight_index << ":" << pfs[i].fs[j].x << ":" << VW::get_weight(*model2, pfs[i].fs[j].weight_index, 0); + cout << endl; + } + + VW::finish_example(*model2, vec2); + VW::finish(*model2); +} + diff --git a/python/Makefile b/python/Makefile index b268663d..845fb15f 100644 --- a/python/Makefile +++ b/python/Makefile @@ -43,3 +43,5 @@ pylibvw.o: pylibvw.cc clean: rm -f *.o $(PYLIBVW) + +.PHONY: all clean things diff --git a/test/test-sets/ref/ml100k_small.stderr b/test/test-sets/ref/ml100k_small.stderr index c8bcbb51..92b91e64 100644 --- a/test/test-sets/ref/ml100k_small.stderr +++ b/test/test-sets/ref/ml100k_small.stderr @@ -3,7 +3,6 @@ Num weight bits = 16 learning rate = 10 initial_t = 1 power_t = 0.5 -rank = 10 using no cache Reading datafile = test-sets/ml100k_small_test num sources = 1 diff --git a/test/train-sets/ref/ml100k_small.stderr b/test/train-sets/ref/ml100k_small.stderr index 3f3d6e25..462dca37 100644 --- a/test/train-sets/ref/ml100k_small.stderr +++ b/test/train-sets/ref/ml100k_small.stderr @@ -6,7 +6,6 @@ learning rate = 0.05 initial_t = 1 power_t = 0 decay_learning_rate = 0.97 -rank = 10 creating cache_file = train-sets/ml100k_small_train.cache Reading datafile = train-sets/ml100k_small_train num sources = 1 diff --git a/test/train-sets/ref/sequencespan_data.nonldf-bilou.test.stderr b/test/train-sets/ref/sequencespan_data.nonldf-bilou.test.stderr index aad69ced..5b3afa22 100644 --- a/test/train-sets/ref/sequencespan_data.nonldf-bilou.test.stderr +++ b/test/train-sets/ref/sequencespan_data.nonldf-bilou.test.stderr @@ -1,10 +1,10 @@ only testing +switching to BILOU encoding for sequence span labeling Num weight bits = 18 learning rate = 10 initial_t = 1 power_t = 0.5 predictions = sequencespan_data.predict -switching to BILOU encoding for sequence span labeling using no cache Reading datafile = train-sets/sequencespan_data num sources = 1 
diff --git a/test/train-sets/ref/sequencespan_data.nonldf-bilou.train.stderr b/test/train-sets/ref/sequencespan_data.nonldf-bilou.train.stderr index f0c914e0..a4fc72a6 100644 --- a/test/train-sets/ref/sequencespan_data.nonldf-bilou.train.stderr +++ b/test/train-sets/ref/sequencespan_data.nonldf-bilou.train.stderr @@ -1,10 +1,10 @@ final_regressor = models/sequencespan_data.model +switching to BILOU encoding for sequence span labeling Num weight bits = 18 learning rate = 10 initial_t = 1 power_t = 0.5 decay_learning_rate = 1 -switching to BILOU encoding for sequence span labeling creating cache_file = train-sets/sequencespan_data.cache Reading datafile = train-sets/sequencespan_data num sources = 1 diff --git a/test/train-sets/ref/sequencespan_data.nonldf.train.stderr b/test/train-sets/ref/sequencespan_data.nonldf.train.stderr index a75bbfd5..2ea5e7e1 100644 --- a/test/train-sets/ref/sequencespan_data.nonldf.train.stderr +++ b/test/train-sets/ref/sequencespan_data.nonldf.train.stderr @@ -1,10 +1,10 @@ final_regressor = models/sequencespan_data.model +no rollout! Num weight bits = 18 learning rate = 10 initial_t = 1 power_t = 0.5 decay_learning_rate = 1 -no rollout! 
creating cache_file = train-sets/sequencespan_data.cache Reading datafile = train-sets/sequencespan_data num sources = 1 diff --git a/vowpalwabbit/Makefile b/vowpalwabbit/Makefile index 79893902..9ecd681c 100644 --- a/vowpalwabbit/Makefile +++ b/vowpalwabbit/Makefile @@ -52,3 +52,4 @@ install: $(BINARIES) clean: rm -f *.o *.d $(BINARIES) *~ $(MANPAGES) libvw.a +.PHONY: all clean install test things diff --git a/vowpalwabbit/Makefile.am b/vowpalwabbit/Makefile.am index 676d6f4e..291a4702 100644 --- a/vowpalwabbit/Makefile.am +++ b/vowpalwabbit/Makefile.am @@ -4,7 +4,7 @@ liballreduce_la_SOURCES = allreduce.cc bin_PROGRAMS = vw active_interactor -libvw_la_SOURCES = hash.cc memory.cc global_data.cc io_buf.cc parse_regressor.cc parse_primitives.cc unique_sort.cc cache.cc rand48.cc simple_label.cc multiclass.cc oaa.cc ect.cc autolink.cc binary.cc lrq.cc cost_sensitive.cc csoaa.cc cb.cc cb_algs.cc search.cc search_sequencetask.cc search_dep_parser.cc search_hooktask.cc search_multiclasstask.cc search_entityrelationtask.cc parse_example.cc scorer.cc network.cc parse_args.cc accumulate.cc gd.cc learner.cc lda_core.cc gd_mf.cc mf.cc bfgs.cc noop.cc print.cc example.cc parser.cc loss_functions.cc sender.cc nn.cc bs.cc cbify.cc topk.cc stagewise_poly.cc log_multi.cc active.cc kernel_svm.cc best_constant.cc ftrl_proximal.cc +libvw_la_SOURCES = hash.cc global_data.cc io_buf.cc parse_regressor.cc parse_primitives.cc unique_sort.cc cache.cc rand48.cc simple_label.cc multiclass.cc oaa.cc ect.cc autolink.cc binary.cc lrq.cc cost_sensitive.cc csoaa.cc cb.cc cb_algs.cc search.cc search_sequencetask.cc search_dep_parser.cc search_hooktask.cc search_multiclasstask.cc search_entityrelationtask.cc parse_example.cc scorer.cc network.cc parse_args.cc accumulate.cc gd.cc learner.cc lda_core.cc gd_mf.cc mf.cc bfgs.cc noop.cc print.cc example.cc parser.cc loss_functions.cc sender.cc nn.cc bs.cc cbify.cc topk.cc stagewise_poly.cc log_multi.cc active.cc kernel_svm.cc best_constant.cc 
ftrl_proximal.cc libvw_c_wrapper_la_SOURCES = vwdll.cpp diff --git a/vowpalwabbit/accumulate.h b/vowpalwabbit/accumulate.h index c01ac5fe..4d507a60 100644 --- a/vowpalwabbit/accumulate.h +++ b/vowpalwabbit/accumulate.h @@ -1,13 +1,13 @@ -/*
-Copyright (c) by respective owners including Yahoo!, Microsoft, and
-individual contributors. All rights reserved. Released under a BSD
-license as described in the file LICENSE.
- */
-//This implements various accumulate functions building on top of allreduce.
-#pragma once
-#include "global_data.h"
-
-void accumulate(vw& all, std::string master_location, regressor& reg, size_t o);
-float accumulate_scalar(vw& all, std::string master_location, float local_sum);
-void accumulate_weighted_avg(vw& all, std::string master_location, regressor& reg);
-void accumulate_avg(vw& all, std::string master_location, regressor& reg, size_t o);
+/* +Copyright (c) by respective owners including Yahoo!, Microsoft, and +individual contributors. All rights reserved. Released under a BSD +license as described in the file LICENSE. + */ +//This implements various accumulate functions building on top of allreduce. +#pragma once +#include "global_data.h" + +void accumulate(vw& all, std::string master_location, regressor& reg, size_t o); +float accumulate_scalar(vw& all, std::string master_location, float local_sum); +void accumulate_weighted_avg(vw& all, std::string master_location, regressor& reg); +void accumulate_avg(vw& all, std::string master_location, regressor& reg, size_t o); diff --git a/vowpalwabbit/active.cc b/vowpalwabbit/active.cc index a1070be3..54094331 100644 --- a/vowpalwabbit/active.cc +++ b/vowpalwabbit/active.cc @@ -151,27 +151,34 @@ namespace ACTIVE { VW::finish_example(all,&ec); } - base_learner* setup(vw& all, po::variables_map& vm) + base_learner* setup(vw& all) {//parse and set arguments - active& data = calloc_or_die<active>(); - - po::options_description active_opts("Active Learning options"); - active_opts.add_options() + new_options(all, "Active Learning options") + ("active", "enable active learning"); + if (missing_required(all)) return NULL; + new_options(all) ("simulation", "active learning simulation mode") - ("mellowness", po::value<float>(&(data.active_c0)), "active learning mellowness parameter c_0. Default 8") - ; - vm = add_options(all, active_opts); + ("mellowness", po::value<float>(), "active learning mellowness parameter c_0. 
Default 8"); + add_options(all); + + active& data = calloc_or_die<active>(); + data.active_c0 = 8; data.all=&all; + if (all.vm.count("mellowness")) + data.active_c0 = all.vm["mellowness"].as<float>(); + + base_learner* base = setup_base(all); + //Create new learner learner<active>* ret; - if (vm.count("simulation")) - ret = &init_learner(&data, all.l, predict_or_learn_simulation<true>, + if (all.vm.count("simulation")) + ret = &init_learner(&data, base, predict_or_learn_simulation<true>, predict_or_learn_simulation<false>); else { all.active = true; - ret = &init_learner(&data, all.l, predict_or_learn_active<true>, + ret = &init_learner(&data, base, predict_or_learn_active<true>, predict_or_learn_active<false>); ret->set_finish_example(return_active_example); } diff --git a/vowpalwabbit/active.h b/vowpalwabbit/active.h index d71950ab..b1145e8c 100644 --- a/vowpalwabbit/active.h +++ b/vowpalwabbit/active.h @@ -1,4 +1,2 @@ #pragma once -namespace ACTIVE { - LEARNER::base_learner* setup(vw& all, po::variables_map& vm); -} +namespace ACTIVE { LEARNER::base_learner* setup(vw& all); } diff --git a/vowpalwabbit/autolink.cc b/vowpalwabbit/autolink.cc index 7cdaecef..707b38a6 100644 --- a/vowpalwabbit/autolink.cc +++ b/vowpalwabbit/autolink.cc @@ -37,15 +37,21 @@ namespace ALINK { ec.total_sum_feat_sq -= sum_sq; } - LEARNER::base_learner* setup(vw& all, po::variables_map& vm) + LEARNER::base_learner* setup(vw& all) { + new_options(all,"Autolink options") + ("autolink", po::value<size_t>(), "create link function with polynomial d"); + if(missing_required(all)) return NULL; + autolink& data = calloc_or_die<autolink>(); - data.d = (uint32_t)vm["autolink"].as<size_t>(); + data.d = (uint32_t)all.vm["autolink"].as<size_t>(); data.stride_shift = all.reg.stride_shift; *all.file_options << " --autolink " << data.d; - LEARNER::learner<autolink>& ret = init_learner(&data, all.l, predict_or_learn<true>, + LEARNER::base_learner* base = setup_base(all); + + LEARNER::learner<autolink>& 
ret = init_learner(&data, base, predict_or_learn<true>, predict_or_learn<false>); return make_base(ret); } diff --git a/vowpalwabbit/autolink.h b/vowpalwabbit/autolink.h index 3bb70dc1..e7bfcf51 100644 --- a/vowpalwabbit/autolink.h +++ b/vowpalwabbit/autolink.h @@ -1,4 +1,2 @@ #pragma once -namespace ALINK { - LEARNER::base_learner* setup(vw& all, po::variables_map& vm); -} +namespace ALINK { LEARNER::base_learner* setup(vw& all); } diff --git a/vowpalwabbit/bfgs.cc b/vowpalwabbit/bfgs.cc index fc4ee851..a14d085c 100644 --- a/vowpalwabbit/bfgs.cc +++ b/vowpalwabbit/bfgs.cc @@ -63,6 +63,9 @@ namespace BFGS struct bfgs { vw* all; + int m; + float rel_threshold; // termination threshold + double wolfe1_bound; size_t final_pass; @@ -247,7 +250,7 @@ void bfgs_iter_start(vw& all, bfgs& b, float* mem, int& lastj, double importance origin = 0; for(uint32_t i = 0; i < length; i++, mem+=b.mem_stride, w+=stride) { - if (all.m>0) + if (b.m>0) mem[(MEM_XT+origin)%b.mem_stride] = w[W_XT]; mem[(MEM_GT+origin)%b.mem_stride] = w[W_GT]; g1_Hg1 += w[W_GT] * w[W_GT] * w[W_COND]; @@ -272,7 +275,7 @@ void bfgs_iter_middle(vw& all, bfgs& b, float* mem, double* rho, double* alpha, float* w0 = w; // implement conjugate gradient - if (all.m==0) { + if (b.m==0) { double g_Hy = 0.; double g_Hg = 0.; double y = 0.; @@ -374,7 +377,7 @@ void bfgs_iter_middle(vw& all, bfgs& b, float* mem, double* rho, double* alpha, mem = mem0; w = w0; - lastj = (lastj<all.m-1) ? lastj+1 : all.m-1; + lastj = (lastj<b.m-1) ? 
lastj+1 : b.m-1; origin = (origin+b.mem_stride-2)%b.mem_stride; for(uint32_t i = 0; i < length; i++, mem+=b.mem_stride, w+=stride) { mem[(MEM_GT+origin)%b.mem_stride] = w[W_GT]; @@ -633,9 +636,9 @@ int process_pass(vw& all, bfgs& b) { /********************************************************************/ else { double rel_decrease = (b.previous_loss_sum-b.loss_sum)/b.previous_loss_sum; - if (!nanpattern((float)rel_decrease) && b.backstep_on && fabs(rel_decrease)<all.rel_threshold) { + if (!nanpattern((float)rel_decrease) && b.backstep_on && fabs(rel_decrease)<b.rel_threshold) { fprintf(stdout, "\nTermination condition reached in pass %ld: decrease in loss less than %.3f%%.\n" - "If you want to optimize further, decrease termination threshold.\n", (long int)b.current_pass+1, all.rel_threshold*100.0); + "If you want to optimize further, decrease termination threshold.\n", (long int)b.current_pass+1, b.rel_threshold*100.0); status = LEARN_CONV; } b.previous_loss_sum = b.loss_sum; @@ -913,7 +916,7 @@ void save_load(bfgs& b, io_buf& model_file, bool read, bool text) throw exception(); } } - int m = all->m; + int m = b.m; b.mem_stride = (m==0) ? 
CG_EXTRA : 2*m; b.mem = (float*) malloc(sizeof(float)*all->length()*(b.mem_stride)); @@ -965,10 +968,23 @@ void save_load(bfgs& b, io_buf& model_file, bool read, bool text) b.backstep_on = true; } -base_learner* setup(vw& all, po::variables_map& vm) +base_learner* setup(vw& all) { + new_options(all, "LBFGS options") + ("bfgs", "use bfgs optimization") + ("conjugate_gradient", "use conjugate gradient based optimization"); + if (missing_required(all)) return NULL; + new_options(all) + ("hessian_on", "use second derivative in line search") + ("mem", po::value<uint32_t>()->default_value(15), "memory in bfgs") + ("termination", po::value<float>()->default_value(0.001f),"Termination threshold"); + add_options(all); + + po::variables_map& vm = all.vm; bfgs& b = calloc_or_die<bfgs>(); b.all = &all; + b.m = vm["mem"].as<uint32_t>(); + b.rel_threshold = vm["termination"].as<float>(); b.wolfe1_bound = 0.01; b.first_hessian_on=true; b.first_pass = true; @@ -979,16 +995,6 @@ base_learner* setup(vw& all, po::variables_map& vm) b.no_win_counter = 0; b.early_stop_thres = 3; - po::options_description bfgs_opts("LBFGS options"); - - bfgs_opts.add_options() - ("hessian_on", "use second derivative in line search") - ("mem", po::value<int>(&(all.m)), "memory in bfgs") - ("conjugate_gradient", "use conjugate gradient based optimization") - ("termination", po::value<float>(&(all.rel_threshold)),"Termination threshold"); - - vm = add_options(all, bfgs_opts); - if(!all.holdout_set_off) { all.sd->holdout_best_loss = FLT_MAX; @@ -996,11 +1002,11 @@ base_learner* setup(vw& all, po::variables_map& vm) b.early_stop_thres = vm["early_terminate"].as< size_t>(); } - if (vm.count("hessian_on") || all.m==0) { + if (vm.count("hessian_on") || b.m==0) { all.hessian_on = true; } if (!all.quiet) { - if (all.m>0) + if (b.m>0) cerr << "enabling BFGS based optimization "; else cerr << "enabling conjugate gradient optimization via BFGS "; diff --git a/vowpalwabbit/bfgs.h b/vowpalwabbit/bfgs.h index 
1960662b..e66a0df0 100644 --- a/vowpalwabbit/bfgs.h +++ b/vowpalwabbit/bfgs.h @@ -4,6 +4,4 @@ individual contributors. All rights reserved. Released under a BSD license as described in the file LICENSE. */ #pragma once -namespace BFGS { - LEARNER::base_learner* setup(vw& all, po::variables_map& vm); -} +namespace BFGS { LEARNER::base_learner* setup(vw& all); } diff --git a/vowpalwabbit/binary.cc b/vowpalwabbit/binary.cc index 04810e26..d9afc63b 100644 --- a/vowpalwabbit/binary.cc +++ b/vowpalwabbit/binary.cc @@ -28,10 +28,16 @@ namespace BINARY { } } - LEARNER::base_learner* setup(vw& all, po::variables_map& vm) - { +LEARNER::base_learner* setup(vw& all) + {//parse and set arguments + new_options(all,"Binary options") + ("binary", "report loss as binary classification on -1,1"); + if(missing_required(all)) return NULL; + + //Create new learner LEARNER::learner<char>& ret = - LEARNER::init_learner<char>(NULL, all.l, predict_or_learn<true>, predict_or_learn<false>); + LEARNER::init_learner<char>(NULL, setup_base(all), + predict_or_learn<true>, predict_or_learn<false>); return make_base(ret); } } diff --git a/vowpalwabbit/binary.h b/vowpalwabbit/binary.h index 609de90b..7b79946d 100644 --- a/vowpalwabbit/binary.h +++ b/vowpalwabbit/binary.h @@ -1,4 +1,2 @@ #pragma once -namespace BINARY { - LEARNER::base_learner* setup(vw& all, po::variables_map& vm); -} +namespace BINARY { LEARNER::base_learner* setup(vw& all); } diff --git a/vowpalwabbit/bs.cc b/vowpalwabbit/bs.cc index 78828632..f66403dd 100644 --- a/vowpalwabbit/bs.cc +++ b/vowpalwabbit/bs.cc @@ -239,28 +239,27 @@ namespace BS { d.pred_vec.~vector(); } - base_learner* setup(vw& all, po::variables_map& vm) + base_learner* setup(vw& all) { + new_options(all, "Bootstrap options") + ("bootstrap,B", po::value<size_t>(), "bootstrap mode with k rounds by online importance resampling"); + if (missing_required(all)) return NULL; + new_options(all)("bs_type", po::value<string>(), "prediction type {mean,vote}"); + 
add_options(all); + bs& data = calloc_or_die<bs>(); data.ub = FLT_MAX; data.lb = -FLT_MAX; - - po::options_description bs_options("Bootstrap options"); - bs_options.add_options() - ("bs_type", po::value<string>(), "prediction type {mean,vote}"); - - vm = add_options(all, bs_options); - - data.B = (uint32_t)vm["bootstrap"].as<size_t>(); + data.B = (uint32_t)all.vm["bootstrap"].as<size_t>(); //append bs with number of samples to options_from_file so it is saved to regressor later *all.file_options << " --bootstrap " << data.B; std::string type_string("mean"); - if (vm.count("bs_type")) + if (all.vm.count("bs_type")) { - type_string = vm["bs_type"].as<std::string>(); + type_string = all.vm["bs_type"].as<std::string>(); if (type_string.compare("mean") == 0) { data.bs_type = BS_TYPE_MEAN; @@ -280,7 +279,7 @@ namespace BS { data.pred_vec.reserve(data.B); data.all = &all; - learner<bs>& l = init_learner(&data, all.l, predict_or_learn<true>, + learner<bs>& l = init_learner(&data, setup_base(all), predict_or_learn<true>, predict_or_learn<false>, data.B); l.set_finish_example(finish_example); l.set_finish(finish); diff --git a/vowpalwabbit/bs.h b/vowpalwabbit/bs.h index e05bcd05..062c195a 100644 --- a/vowpalwabbit/bs.h +++ b/vowpalwabbit/bs.h @@ -11,7 +11,7 @@ license as described in the file LICENSE. 
namespace BS { - LEARNER::base_learner* setup(vw& all, po::variables_map& vm); + LEARNER::base_learner* setup(vw& all); void print_result(int f, float res, float weight, v_array<char> tag, float lb, float ub); void output_example(vw& all, example* ec, float lb, float ub); diff --git a/vowpalwabbit/cb_algs.cc b/vowpalwabbit/cb_algs.cc index bf9a6e5b..b14328d3 100644 --- a/vowpalwabbit/cb_algs.cc +++ b/vowpalwabbit/cb_algs.cc @@ -436,36 +436,35 @@ namespace CB_ALGS VW::finish_example(all, &ec); } - base_learner* setup(vw& all, po::variables_map& vm) + base_learner* setup(vw& all) { + new_options(all, "CB options") + ("cb", po::value<size_t>(), "Use contextual bandit learning with <k> costs"); + if (missing_required(all)) return NULL; + new_options(all) + ("cb_type", po::value<string>(), "contextual bandit method to use in {ips,dm,dr}") + ("eval", "Evaluate a policy rather than optimizing."); + add_options(all); + cb& c = calloc_or_die<cb>(); c.all = &all; - uint32_t nb_actions = (uint32_t)vm["cb"].as<size_t>(); - //append cb with nb_actions to file_options so it is saved to regressor later - - po::options_description cb_opts("CB options"); - cb_opts.add_options() - ("cb_type", po::value<string>(), "contextual bandit method to use in {ips,dm,dr}") - ("eval", "Evaluate a policy rather than optimizing.") - ; - - vm = add_options(all, cb_opts); + uint32_t nb_actions = (uint32_t)all.vm["cb"].as<size_t>(); *all.file_options << " --cb " << nb_actions; all.sd->k = nb_actions; bool eval = false; - if (vm.count("eval")) + if (all.vm.count("eval")) eval = true; size_t problem_multiplier = 2;//default for DR - if (vm.count("cb_type")) + if (all.vm.count("cb_type")) { std::string type_string; - type_string = vm["cb_type"].as<std::string>(); + type_string = all.vm["cb_type"].as<std::string>(); *all.file_options << " --cb_type " << type_string; if (type_string.compare("dr") == 0) @@ -496,6 +495,15 @@ namespace CB_ALGS *all.file_options << " --cb_type dr"; } + if 
(count(all.args.begin(), all.args.end(),"--csoaa") == 0) + { + all.args.push_back("--csoaa"); + stringstream ss; + ss << all.vm["cb"].as<size_t>(); + all.args.push_back(ss.str()); + } + + base_learner* base = setup_base(all); if (eval) all.p->lp = CB_EVAL::cb_eval; else @@ -504,18 +512,18 @@ namespace CB_ALGS learner<cb>* l; if (eval) { - l = &init_learner(&c, all.l, learn_eval, predict_eval, problem_multiplier); + l = &init_learner(&c, base, learn_eval, predict_eval, problem_multiplier); l->set_finish_example(eval_finish_example); } else { - l = &init_learner(&c, all.l, predict_or_learn<true>, predict_or_learn<false>, + l = &init_learner(&c, base, predict_or_learn<true>, predict_or_learn<false>, problem_multiplier); l->set_finish_example(finish_example); } // preserve the increment of the base learner since we are // _adding_ to the number of problems rather than multiplying. - l->increment = all.l->increment; + l->increment = base->increment; l->set_init_driver(init_driver); l->set_finish(finish); diff --git a/vowpalwabbit/cb_algs.h b/vowpalwabbit/cb_algs.h index e989b5a1..0593fb6c 100644 --- a/vowpalwabbit/cb_algs.h +++ b/vowpalwabbit/cb_algs.h @@ -6,8 +6,7 @@ license as described in the file LICENSE. 
#pragma once //TODO: extend to handle CSOAA_LDF and WAP_LDF namespace CB_ALGS { - - LEARNER::base_learner* setup(vw& all, po::variables_map& vm); + LEARNER::base_learner* setup(vw& all); template <bool is_learn> float get_cost_pred(vw& all, CB::cb_class* known_cost, example& ec, uint32_t index, uint32_t base) diff --git a/vowpalwabbit/cbify.cc b/vowpalwabbit/cbify.cc index d8176228..d72011a2 100644 --- a/vowpalwabbit/cbify.cc +++ b/vowpalwabbit/cbify.cc @@ -371,24 +371,35 @@ namespace CBIFY { void finish(cbify& data) { CB::cb_label.delete_label(&data.cb_label); } - base_learner* setup(vw& all, po::variables_map& vm) + base_learner* setup(vw& all) {//parse and set arguments - cbify& data = calloc_or_die<cbify>(); - - data.all = &all; - po::options_description cb_opts("CBIFY options"); - cb_opts.add_options() + new_options(all, "CBIFY options") + ("cbify", po::value<size_t>(), "Convert multiclass on <k> classes into a contextual bandit problem and solve"); + if (missing_required(all)) return NULL; + new_options(all) ("first", po::value<size_t>(), "tau-first exploration") ("epsilon",po::value<float>() ,"epsilon-greedy exploration") ("bag",po::value<size_t>() ,"bagging-based exploration") ("cover",po::value<size_t>() ,"bagging-based exploration"); - - vm = add_options(all, cb_opts); - + add_options(all); + + po::variables_map& vm = all.vm; + cbify& data = calloc_or_die<cbify>(); + data.all = &all; data.k = (uint32_t)vm["cbify"].as<size_t>(); *all.file_options << " --cbify " << data.k; + if (count(all.args.begin(), all.args.end(),"--cb") == 0) + { + all.args.push_back("--cb"); + stringstream ss; + ss << vm["cbify"].as<size_t>(); + all.args.push_back(ss.str()); + } + base_learner* base = setup_base(all); + all.p->lp = MULTICLASS::mc_label; + learner<cbify>* l; data.recorder.reset(new vw_recorder()); data.mwt_explorer.reset(new MwtExplorer<vw_context>("vw", *data.recorder.get())); @@ -403,7 +414,7 @@ namespace CBIFY { epsilon = vm["epsilon"].as<float>(); 
data.scorer.reset(new vw_cover_scorer(epsilon, cover, (u32)data.k)); data.generic_explorer.reset(new GenericExplorer<vw_context>(*data.scorer.get(), (u32)data.k)); - l = &init_learner(&data, all.l, predict_or_learn_cover<true>, + l = &init_learner(&data, base, predict_or_learn_cover<true>, predict_or_learn_cover<false>, cover + 1); } else if (vm.count("bag")) @@ -414,7 +425,7 @@ namespace CBIFY { data.policies.push_back(unique_ptr<IPolicy<vw_context>>(new vw_policy(i))); } data.bootstrap_explorer.reset(new BootstrapExplorer<vw_context>(data.policies, (u32)data.k)); - l = &init_learner(&data, all.l, predict_or_learn_bag<true>, + l = &init_learner(&data, base, predict_or_learn_bag<true>, predict_or_learn_bag<false>, bags); } else if (vm.count("first") ) @@ -422,7 +433,7 @@ namespace CBIFY { uint32_t tau = (uint32_t)vm["first"].as<size_t>(); data.policy.reset(new vw_policy()); data.tau_explorer.reset(new TauFirstExplorer<vw_context>(*data.policy.get(), (u32)tau, (u32)data.k)); - l = &init_learner(&data, all.l, predict_or_learn_first<true>, + l = &init_learner(&data, base, predict_or_learn_first<true>, predict_or_learn_first<false>, 1); } else @@ -432,7 +443,7 @@ namespace CBIFY { epsilon = vm["epsilon"].as<float>(); data.policy.reset(new vw_policy()); data.greedy_explorer.reset(new EpsilonGreedyExplorer<vw_context>(*data.policy.get(), epsilon, (u32)data.k)); - l = &init_learner(&data, all.l, predict_or_learn_greedy<true>, + l = &init_learner(&data, base, predict_or_learn_greedy<true>, predict_or_learn_greedy<false>, 1); } diff --git a/vowpalwabbit/cbify.h b/vowpalwabbit/cbify.h index aed26b75..2ea0a627 100644 --- a/vowpalwabbit/cbify.h +++ b/vowpalwabbit/cbify.h @@ -4,6 +4,4 @@ individual contributors. All rights reserved. Released under a BSD license as described in the file LICENSE. 
*/ #pragma once -namespace CBIFY { - LEARNER::base_learner* setup(vw& all, po::variables_map& vm); -} +namespace CBIFY { LEARNER::base_learner* setup(vw& all); } diff --git a/vowpalwabbit/csoaa.cc b/vowpalwabbit/csoaa.cc index 350d40a3..908369b7 100644 --- a/vowpalwabbit/csoaa.cc +++ b/vowpalwabbit/csoaa.cc @@ -13,9 +13,7 @@ license as described in the file LICENSE. #include "gd.h" // GD::foreach_feature() needed in subtract_example() using namespace std; - using namespace LEARNER; - using namespace COST_SENSITIVE; namespace CSOAA { @@ -68,21 +66,25 @@ namespace CSOAA { VW::finish_example(all, &ec); } - base_learner* setup(vw& all, po::variables_map& vm) + base_learner* setup(vw& all) { + new_options(all, "CSOAA options") + ("csoaa", po::value<size_t>(), "Use one-against-all multiclass learning with <k> costs"); + if(missing_required(all)) return NULL; + csoaa& c = calloc_or_die<csoaa>(); c.all = &all; //first parse for number of actions uint32_t nb_actions = 0; - nb_actions = (uint32_t)vm["csoaa"].as<size_t>(); + nb_actions = (uint32_t)all.vm["csoaa"].as<size_t>(); //append csoaa with nb_actions to file_options so it is saved to regressor later *all.file_options << " --csoaa " << nb_actions; all.p->lp = cs_label; all.sd->k = nb_actions; - learner<csoaa>& l = init_learner(&c, all.l, predict_or_learn<true>, + learner<csoaa>& l = init_learner(&c, setup_base(all), predict_or_learn<true>, predict_or_learn<false>, nb_actions); l.set_finish_example(finish_example); return make_base(l); @@ -645,15 +647,17 @@ namespace LabelDict { } } - base_learner* setup(vw& all, po::variables_map& vm) + base_learner* setup(vw& all) { - po::options_description ldf_opts("LDF Options"); - ldf_opts.add_options() - ("ldf_override", po::value<string>(), "Override singleline or multiline from csoaa_ldf or wap_ldf, eg if stored in file") - ; - - vm = add_options(all, ldf_opts); - + new_options(all, "LDF Options") + ("csoaa_ldf", po::value<string>(), "Use one-against-all multiclass learning with 
label dependent features. Specify singleline or multiline.") + ("wap_ldf", po::value<string>(), "Use weighted all-pairs multiclass learning with label dependent features. Specify singleline or multiline."); + if (missing_required(all)) return NULL; + new_options(all) + ("ldf_override", po::value<string>(), "Override singleline or multiline from csoaa_ldf or wap_ldf, eg if stored in file"); + add_options(all); + + po::variables_map& vm = all.vm; ldf& ld = calloc_or_die<ldf>(); ld.all = &all; @@ -708,7 +712,7 @@ namespace LabelDict { ld.read_example_this_loop = 0; ld.need_to_clear = false; - learner<ldf>& l = init_learner(&ld, all.l, predict_or_learn<true>, predict_or_learn<false>); + learner<ldf>& l = init_learner(&ld, setup_base(all), predict_or_learn<true>, predict_or_learn<false>); if (ld.is_singleline) l.set_finish_example(finish_singleline_example); else diff --git a/vowpalwabbit/csoaa.h b/vowpalwabbit/csoaa.h index 79f5e4d2..45e4edf2 100644 --- a/vowpalwabbit/csoaa.h +++ b/vowpalwabbit/csoaa.h @@ -4,17 +4,14 @@ individual contributors. All rights reserved. Released under a BSD license as described in the file LICENSE. 
*/ #pragma once -namespace CSOAA { - LEARNER::base_learner* setup(vw& all, po::variables_map& vm); -} +namespace CSOAA { LEARNER::base_learner* setup(vw& all); } namespace CSOAA_AND_WAP_LDF { - LEARNER::base_learner* setup(vw& all, po::variables_map& vm); - -namespace LabelDict { - bool ec_is_example_header(example& ec); // example headers look like "0:-1" or just "shared" - void add_example_namespaces_from_example(example& target, example& source); - void del_example_namespaces_from_example(example& target, example& source); -} + LEARNER::base_learner* setup(vw& all); + namespace LabelDict { + bool ec_is_example_header(example& ec);// example headers look like "0:-1" or just "shared" + void add_example_namespaces_from_example(example& target, example& source); + void del_example_namespaces_from_example(example& target, example& source); + } } diff --git a/vowpalwabbit/ect.cc b/vowpalwabbit/ect.cc index dea87040..1a071eb2 100644 --- a/vowpalwabbit/ect.cc +++ b/vowpalwabbit/ect.cc @@ -17,7 +17,6 @@ license as described in the file LICENSE. 
#include "reductions.h" #include "multiclass.h" #include "simple_label.h" -#include "vw.h" using namespace std; using namespace LEARNER; @@ -51,8 +50,6 @@ namespace ECT uint32_t last_pair; v_array<bool> tournaments_won; - - vw* all; }; bool exists(v_array<size_t> db) @@ -184,7 +181,7 @@ namespace ECT return e.last_pair + (eliminations-1); } - uint32_t ect_predict(vw& all, ect& e, base_learner& base, example& ec) + uint32_t ect_predict(ect& e, base_learner& base, example& ec) { if (e.k == (size_t)1) return 1; @@ -228,7 +225,7 @@ namespace ECT return false; } - void ect_train(vw& all, ect& e, base_learner& base, example& ec) + void ect_train(ect& e, base_learner& base, example& ec) { if (e.k == 1)//nothing to do return; @@ -318,25 +315,21 @@ namespace ECT } void predict(ect& e, base_learner& base, example& ec) { - vw* all = e.all; - MULTICLASS::multiclass mc = ec.l.multi; if (mc.label == 0 || (mc.label > e.k && mc.label != (uint32_t)-1)) cout << "label " << mc.label << " is not in {1,"<< e.k << "} This won't work right." 
<< endl; - ec.pred.multiclass = ect_predict(*all, e, base, ec); + ec.pred.multiclass = ect_predict(e, base, ec); ec.l.multi = mc; } void learn(ect& e, base_learner& base, example& ec) { - vw* all = e.all; - MULTICLASS::multiclass mc = ec.l.multi; predict(e, base, ec); uint32_t pred = ec.pred.multiclass; - if (mc.label != (uint32_t)-1 && all->training) - ect_train(*all, e, base, ec); + if (mc.label != (uint32_t)-1) + ect_train(e, base, ec); ec.l.multi = mc; ec.pred.multiclass = pred; } @@ -360,36 +353,26 @@ namespace ECT e.tournaments_won.delete_v(); } - void finish_example(vw& all, ect&, example& ec) { MULTICLASS::finish_example(all, ec); } - - base_learner* setup(vw& all, po::variables_map& vm) + base_learner* setup(vw& all) { - ect& data = calloc_or_die<ect>(); - po::options_description ect_opts("ECT options"); - ect_opts.add_options() - ("error", po::value<size_t>(), "error in ECT"); + new_options(all, "Error Correcting Tournament options") + ("ect", po::value<size_t>(), "Use error correcting tournament with <k> labels"); + if (missing_required(all)) return NULL; + new_options(all)("error", po::value<size_t>()->default_value(0), "error in ECT"); + add_options(all); - vm = add_options(all, ect_opts); - - //first parse for number of actions - data.k = (int)vm["ect"].as<size_t>(); - - //append ect with nb_actions to options_from_file so it is saved to regressor later - if (vm.count("error")) { - data.errors = (uint32_t)vm["error"].as<size_t>(); - } else - data.errors = 0; + ect& data = calloc_or_die<ect>(); + data.k = (int)all.vm["ect"].as<size_t>(); + data.errors = (uint32_t)all.vm["error"].as<size_t>(); //append error flag to options_from_file so it is saved in regressor file later *all.file_options << " --ect " << data.k << " --error " << data.errors; - all.p->lp = MULTICLASS::mc_label; size_t wpp = create_circuit(all, data, data.k, data.errors+1); - data.all = &all; - learner<ect>& l = init_learner(&data, all.l, learn, predict, wpp); - 
l.set_finish_example(finish_example); + learner<ect>& l = init_learner(&data, setup_base(all), learn, predict, wpp); + l.set_finish_example(MULTICLASS::finish_example<ect>); + all.p->lp = MULTICLASS::mc_label; l.set_finish(finish); - return make_base(l); } } diff --git a/vowpalwabbit/ect.h b/vowpalwabbit/ect.h index 81129791..c02d3848 100644 --- a/vowpalwabbit/ect.h +++ b/vowpalwabbit/ect.h @@ -4,7 +4,4 @@ individual contributors. All rights reserved. Released under a BSD license as described in the file LICENSE. */ #pragma once -namespace ECT -{ - LEARNER::base_learner* setup(vw&, po::variables_map&); -} +namespace ECT { LEARNER::base_learner* setup(vw&); } diff --git a/vowpalwabbit/ftrl_proximal.cc b/vowpalwabbit/ftrl_proximal.cc index 6216058b..5eecfc26 100644 --- a/vowpalwabbit/ftrl_proximal.cc +++ b/vowpalwabbit/ftrl_proximal.cc @@ -55,8 +55,8 @@ namespace FTRL { }; void update_accumulated_state(weight* w, float ftrl_alpha) { - double ng2 = w[W_G2] + w[W_GT]*w[W_GT]; - double sigma = (sqrt(ng2) - sqrt(w[W_G2]))/ ftrl_alpha; + float ng2 = w[W_G2] + w[W_GT]*w[W_GT]; + float sigma = (sqrtf(ng2) - sqrtf(w[W_G2]))/ ftrl_alpha; w[W_ZT] += w[W_GT] - sigma * w[W_XT]; w[W_G2] = ng2; } @@ -107,7 +107,7 @@ namespace FTRL { if (fabs_zt <= d.l1_lambda) { w[W_XT] = 0.; } else { - double step = 1/(d.l2_lambda + (d.ftrl_beta + sqrt(w[W_G2]))/d.ftrl_alpha); + float step = 1/(d.l2_lambda + (d.ftrl_beta + sqrtf(w[W_G2]))/d.ftrl_alpha); w[W_XT] = step * flag * (d.l1_lambda - fabs_zt); } } @@ -129,7 +129,7 @@ namespace FTRL { label_data& ld = ec.l.simple; ec.loss = all.loss->getLoss(all.sd, ec.updated_prediction, ld.label) * ld.weight; if (b.progressive_validation) { - float v = 1./(1 + exp(-ec.updated_prediction)); + float v = 1.f/(1 + exp(-ec.updated_prediction)); fprintf(b.fo, "%.6f\t%d\n", v, (int)(ld.label * ld.weight)); } } @@ -177,35 +177,27 @@ namespace FTRL { ec.pred.scalar = ftrl_predict(*all,ec); } - base_learner* setup(vw& all, po::variables_map& vm) + base_learner* 
setup(vw& all) { + new_options(all, "FTRL options") + ("ftrl", "Follow the Regularized Leader"); + if (missing_required(all)) return NULL; + new_options(all) + ("ftrl_alpha", po::value<float>()->default_value(0.0), "Learning rate for FTRL-proximal optimization") + ("ftrl_beta", po::value<float>()->default_value(0.1f), "FTRL beta") + ("progressive_validation", po::value<string>()->default_value("ftrl.evl"), "File to record progressive validation for ftrl-proximal"); + add_options(all); + ftrl& b = calloc_or_die<ftrl>(); b.all = &all; - b.ftrl_beta = 0.0; - b.ftrl_alpha = 0.1; - - po::options_description ftrl_opts("FTRL options"); - - ftrl_opts.add_options() - ("ftrl_alpha", po::value<float>(&(b.ftrl_alpha)), "Learning rate for FTRL-proximal optimization") - ("ftrl_beta", po::value<float>(&(b.ftrl_beta)), "FTRL beta") - ("progressive_validation", po::value<string>()->default_value("ftrl.evl"), "File to record progressive validation for ftrl-proximal"); - - vm = add_options(all, ftrl_opts); - - if (vm.count("ftrl_alpha")) { - b.ftrl_alpha = vm["ftrl_alpha"].as<float>(); - } - - if (vm.count("ftrl_beta")) { - b.ftrl_beta = vm["ftrl_beta"].as<float>(); - } + b.ftrl_beta = all.vm["ftrl_beta"].as<float>(); + b.ftrl_alpha = all.vm["ftrl_alpha"].as<float>(); all.reg.stride_shift = 2; // NOTE: for more parameter storage b.progressive_validation = false; - if (vm.count("progressive_validation")) { - std::string filename = vm["progressive_validation"].as<string>(); + if (all.vm.count("progressive_validation")) { + std::string filename = all.vm["progressive_validation"].as<string>(); b.fo = fopen(filename.c_str(), "w"); assert(b.fo != NULL); b.progressive_validation = true; diff --git a/vowpalwabbit/ftrl_proximal.h b/vowpalwabbit/ftrl_proximal.h index 59bf4653..dd495d3e 100644 --- a/vowpalwabbit/ftrl_proximal.h +++ b/vowpalwabbit/ftrl_proximal.h @@ -3,10 +3,5 @@ Copyright (c) by respective owners including Yahoo!, Microsoft, and individual contributors. All rights reserved. 
Released under a BSD license as described in the file LICENSE. */ -#ifndef FTRL_PROXIMAL_H -#define FTRL_PROXIMAL_H - -namespace FTRL { - LEARNER::base_learner* setup(vw& all, po::variables_map& vm); -} -#endif +#pragma once +namespace FTRL { LEARNER::base_learner* setup(vw& all); } diff --git a/vowpalwabbit/gd.cc b/vowpalwabbit/gd.cc index 851fbea7..8460bab8 100644 --- a/vowpalwabbit/gd.cc +++ b/vowpalwabbit/gd.cc @@ -844,8 +844,16 @@ uint32_t ceil_log_2(uint32_t v) return 1 + ceil_log_2(v >> 1); } -base_learner* setup(vw& all, po::variables_map& vm) +base_learner* setup(vw& all) { + new_options(all, "Gradient Descent options") + ("sgd", "use regular stochastic gradient descent update.") + ("adaptive", "use adaptive, individual learning rates.") + ("invariant", "use safe/importance aware updates.") + ("normalized", "use per feature normalized updates") + ("exact_adaptive_norm", "use current default invariant normalized adaptive update rule"); + add_options(all); + po::variables_map& vm = all.vm; gd& g = calloc_or_die<gd>(); g.all = &all; g.all->normalized_sum_norm_x = 0; diff --git a/vowpalwabbit/gd.h b/vowpalwabbit/gd.h index de3964eb..a5413eae 100644 --- a/vowpalwabbit/gd.h +++ b/vowpalwabbit/gd.h @@ -24,7 +24,7 @@ namespace GD{ void compute_update(example* ec); void offset_train(regressor ®, example* &ec, float update, size_t offset); void train_one_example_single_thread(regressor& r, example* ex); - LEARNER::base_learner* setup(vw& all, po::variables_map& vm); + LEARNER::base_learner* setup(vw& all); void save_load_regressor(vw& all, io_buf& model_file, bool read, bool text); void save_load_online_state(vw& all, io_buf& model_file, bool read, bool text); void output_and_account_example(example* ec); diff --git a/vowpalwabbit/gd_mf.cc b/vowpalwabbit/gd_mf.cc index b57328f4..8c51b0ce 100644 --- a/vowpalwabbit/gd_mf.cc +++ b/vowpalwabbit/gd_mf.cc @@ -26,10 +26,12 @@ using namespace LEARNER; namespace GDMF { struct gdmf { vw* all; + uint32_t rank; }; -void 
mf_print_offset_features(vw& all, example& ec, size_t offset) +void mf_print_offset_features(gdmf& d, example& ec, size_t offset) { + vw& all = *d.all; weight* weights = all.reg.weight_vector; size_t mask = all.reg.weight_mask; for (unsigned char* i = ec.indices.begin; i != ec.indices.end; i++) @@ -53,7 +55,7 @@ void mf_print_offset_features(vw& all, example& ec, size_t offset) if (ec.atomics[(int)(*i)[0]].size() > 0 && ec.atomics[(int)(*i)[1]].size() > 0) { /* print out nsk^feature:hash:value:weight:nsk^feature^:hash:value:weight:prod_weights */ - for (size_t k = 1; k <= all.rank; k++) + for (size_t k = 1; k <= d.rank; k++) { for (audit_data* f = ec.audit_features[(int)(*i)[0]].begin; f!= ec.audit_features[(int)(*i)[0]].end; f++) for (audit_data* f2 = ec.audit_features[(int)(*i)[1]].begin; f2!= ec.audit_features[(int)(*i)[1]].end; f2++) @@ -62,11 +64,11 @@ void mf_print_offset_features(vw& all, example& ec, size_t offset) <<"(" << ((f->weight_index + offset +k) & mask) << ")" << ':' << f->x; cout << ':' << weights[(f->weight_index + offset + k) & mask]; - cout << ':' << f2->space << k << '^' << f2->feature << ':' << ((f2->weight_index+k+all.rank)&mask) - <<"(" << ((f2->weight_index + offset +k+all.rank) & mask) << ")" << ':' << f2->x; - cout << ':' << weights[(f2->weight_index + offset + k+all.rank) & mask]; + cout << ':' << f2->space << k << '^' << f2->feature << ':' << ((f2->weight_index+k+d.rank)&mask) + <<"(" << ((f2->weight_index + offset +k+d.rank) & mask) << ")" << ':' << f2->x; + cout << ':' << weights[(f2->weight_index + offset + k+d.rank) & mask]; - cout << ':' << weights[(f->weight_index + offset + k) & mask] * weights[(f2->weight_index + offset + k + all.rank) & mask]; + cout << ':' << weights[(f->weight_index + offset + k) & mask] * weights[(f2->weight_index + offset + k + d.rank) & mask]; } } } @@ -77,17 +79,25 @@ void mf_print_offset_features(vw& all, example& ec, size_t offset) cout << endl; } -void mf_print_audit_features(vw& all, example& ec, 
size_t offset) +void mf_print_audit_features(gdmf& d, example& ec, size_t offset) { - print_result(all.stdout_fileno,ec.pred.scalar,-1,ec.tag); - mf_print_offset_features(all, ec, offset); + print_result(d.all->stdout_fileno,ec.pred.scalar,-1,ec.tag); + mf_print_offset_features(d, ec, offset); } -float mf_predict(vw& all, example& ec) +float mf_predict(gdmf& d, example& ec) { + vw& all = *d.all; label_data& ld = ec.l.simple; float prediction = ld.initial; + for (vector<string>::iterator i = d.all->pairs.begin(); i != d.all->pairs.end();i++) + { + ec.num_features -= ec.atomics[(int)(*i)[0]].size() * ec.atomics[(int)(*i)[1]].size(); + ec.num_features += ec.atomics[(int)(*i)[0]].size() * d.rank; + ec.num_features += ec.atomics[(int)(*i)[1]].size() * d.rank; + } + // clear stored predictions ec.topic_predictions.erase(); @@ -107,18 +117,18 @@ float mf_predict(vw& all, example& ec) { if (ec.atomics[(int)(*i)[0]].size() > 0 && ec.atomics[(int)(*i)[1]].size() > 0) { - for (uint32_t k = 1; k <= all.rank; k++) + for (uint32_t k = 1; k <= d.rank; k++) { // x_l * l^k - // l^k is from index+1 to index+all.rank + // l^k is from index+1 to index+d.rank //float x_dot_l = sd_offset_add(weights, mask, ec.atomics[(int)(*i)[0]].begin, ec.atomics[(int)(*i)[0]].end, k); float x_dot_l = 0.; GD::foreach_feature<float, GD::vec_add>(all.reg.weight_vector, all.reg.weight_mask, ec.atomics[(int)(*i)[0]].begin, ec.atomics[(int)(*i)[0]].end, x_dot_l, k); // x_r * r^k - // r^k is from index+all.rank+1 to index+2*all.rank - //float x_dot_r = sd_offset_add(weights, mask, ec.atomics[(int)(*i)[1]].begin, ec.atomics[(int)(*i)[1]].end, k+all.rank); + // r^k is from index+d.rank+1 to index+2*d.rank + //float x_dot_r = sd_offset_add(weights, mask, ec.atomics[(int)(*i)[1]].begin, ec.atomics[(int)(*i)[1]].end, k+d.rank); float x_dot_r = 0.; - GD::foreach_feature<float,GD::vec_add>(all.reg.weight_vector, all.reg.weight_mask, ec.atomics[(int)(*i)[1]].begin, ec.atomics[(int)(*i)[1]].end, x_dot_r, 
k+all.rank); + GD::foreach_feature<float,GD::vec_add>(all.reg.weight_vector, all.reg.weight_mask, ec.atomics[(int)(*i)[1]].begin, ec.atomics[(int)(*i)[1]].end, x_dot_r, k+d.rank); prediction += x_dot_l * x_dot_r; @@ -146,7 +156,7 @@ float mf_predict(vw& all, example& ec) ec.loss = all.loss->getLoss(all.sd, ec.pred.scalar, ld.label) * ld.weight; if (all.audit) - mf_print_audit_features(all, ec, 0); + mf_print_audit_features(d, ec, 0); return ec.pred.scalar; } @@ -158,55 +168,55 @@ void sd_offset_update(weight* weights, size_t mask, feature* begin, feature* end weights[(f->weight_index + offset) & mask] += update * f->x - regularization * weights[(f->weight_index + offset) & mask]; } -void mf_train(vw& all, example& ec) -{ - weight* weights = all.reg.weight_vector; - size_t mask = all.reg.weight_mask; - label_data& ld = ec.l.simple; - - // use final prediction to get update size + void mf_train(gdmf& d, example& ec) + { + vw& all = *d.all; + weight* weights = all.reg.weight_vector; + size_t mask = all.reg.weight_mask; + label_data& ld = ec.l.simple; + + // use final prediction to get update size // update = eta_t*(y-y_hat) where eta_t = eta/(3*t^p) * importance weight - float eta_t = all.eta/pow(ec.example_t,all.power_t) / 3.f * ld.weight; - float update = all.loss->getUpdate(ec.pred.scalar, ld.label, eta_t, 1.); //ec.total_sum_feat_sq); - - float regularization = eta_t * all.l2_lambda; - - // linear update - for (unsigned char* i = ec.indices.begin; i != ec.indices.end; i++) - sd_offset_update(weights, mask, ec.atomics[*i].begin, ec.atomics[*i].end, 0, update, regularization); - - // quadratic update - for (vector<string>::iterator i = all.pairs.begin(); i != all.pairs.end();i++) - { - if (ec.atomics[(int)(*i)[0]].size() > 0 && ec.atomics[(int)(*i)[1]].size() > 0) - { - - // update l^k weights - for (size_t k = 1; k <= all.rank; k++) - { - // r^k \cdot x_r - float r_dot_x = ec.topic_predictions[2*k]; - // l^k <- l^k + update * (r^k \cdot x_r) * x_l - 
sd_offset_update(weights, mask, ec.atomics[(int)(*i)[0]].begin, ec.atomics[(int)(*i)[0]].end, k, update*r_dot_x, regularization); - } - - // update r^k weights - for (size_t k = 1; k <= all.rank; k++) - { - // l^k \cdot x_l - float l_dot_x = ec.topic_predictions[2*k-1]; - // r^k <- r^k + update * (l^k \cdot x_l) * x_r - sd_offset_update(weights, mask, ec.atomics[(int)(*i)[1]].begin, ec.atomics[(int)(*i)[1]].end, k+all.rank, update*l_dot_x, regularization); - } - - } - } - if (all.triples.begin() != all.triples.end()) { - cerr << "cannot use triples in matrix factorization" << endl; - throw exception(); - } -} - + float eta_t = all.eta/pow(ec.example_t,all.power_t) / 3.f * ld.weight; + float update = all.loss->getUpdate(ec.pred.scalar, ld.label, eta_t, 1.); //ec.total_sum_feat_sq); + + float regularization = eta_t * all.l2_lambda; + + // linear update + for (unsigned char* i = ec.indices.begin; i != ec.indices.end; i++) + sd_offset_update(weights, mask, ec.atomics[*i].begin, ec.atomics[*i].end, 0, update, regularization); + + // quadratic update + for (vector<string>::iterator i = all.pairs.begin(); i != all.pairs.end();i++) + { + if (ec.atomics[(int)(*i)[0]].size() > 0 && ec.atomics[(int)(*i)[1]].size() > 0) + { + + // update l^k weights + for (size_t k = 1; k <= d.rank; k++) + { + // r^k \cdot x_r + float r_dot_x = ec.topic_predictions[2*k]; + // l^k <- l^k + update * (r^k \cdot x_r) * x_l + sd_offset_update(weights, mask, ec.atomics[(int)(*i)[0]].begin, ec.atomics[(int)(*i)[0]].end, k, update*r_dot_x, regularization); + } + // update r^k weights + for (size_t k = 1; k <= d.rank; k++) + { + // l^k \cdot x_l + float l_dot_x = ec.topic_predictions[2*k-1]; + // r^k <- r^k + update * (l^k \cdot x_l) * x_r + sd_offset_update(weights, mask, ec.atomics[(int)(*i)[1]].begin, ec.atomics[(int)(*i)[1]].end, k+d.rank, update*l_dot_x, regularization); + } + + } + } + if (all.triples.begin() != all.triples.end()) { + cerr << "cannot use triples in matrix factorization" << endl; 
+ throw exception(); + } + } + void save_load(gdmf& d, io_buf& model_file, bool read, bool text) { vw* all = d.all; @@ -231,7 +241,7 @@ void mf_train(vw& all, example& ec) do { brw = 0; - size_t K = all->rank*2+1; + size_t K = d.rank*2+1; text_len = sprintf(buff, "%d ", i); brw += bin_text_read_write_fixed(model_file,(char *)&i, sizeof (i), @@ -272,58 +282,59 @@ void mf_train(vw& all, example& ec) all->current_pass++; } - void predict(gdmf& d, base_learner& base, example& ec) - { - vw* all = d.all; - - mf_predict(*all,ec); - } + void predict(gdmf& d, base_learner&, example& ec) { mf_predict(d,ec); } void learn(gdmf& d, base_learner& base, example& ec) { - vw* all = d.all; + vw& all = *d.all; - predict(d, base, ec); - if (all->training && ec.l.simple.label != FLT_MAX) - mf_train(*all, ec); + mf_predict(d, ec); + if (all.training && ec.l.simple.label != FLT_MAX) + mf_train(d, ec); } - base_learner* setup(vw& all, po::variables_map& vm) + base_learner* setup(vw& all) { + new_options(all, "Gdmf options") + ("rank", po::value<uint32_t>(), "rank for matrix factorization."); + if(missing_required(all)) return NULL; + gdmf& data = calloc_or_die<gdmf>(); data.all = &all; + data.rank = all.vm["rank"].as<uint32_t>(); + *all.file_options << " --rank " << data.rank; // store linear + 2*rank weights per index, round up to power of two - float temp = ceilf(logf((float)(all.rank*2+1)) / logf (2.f)); + float temp = ceilf(logf((float)(data.rank*2+1)) / logf (2.f)); all.reg.stride_shift = (size_t) temp; all.random_weights = true; - if ( vm.count("adaptive") ) + if ( all.vm.count("adaptive") ) { cerr << "adaptive is not implemented for matrix factorization" << endl; throw exception(); } - if ( vm.count("normalized") ) + if ( all.vm.count("normalized") ) { cerr << "normalized is not implemented for matrix factorization" << endl; throw exception(); } - if ( vm.count("exact_adaptive_norm") ) + if ( all.vm.count("exact_adaptive_norm") ) { cerr << "normalized adaptive updates is not 
implemented for matrix factorization" << endl; throw exception(); } - if (vm.count("bfgs") || vm.count("conjugate_gradient")) + if (all.vm.count("bfgs") || all.vm.count("conjugate_gradient")) { cerr << "bfgs is not implemented for matrix factorization" << endl; throw exception(); } - if(!vm.count("learning_rate") && !vm.count("l")) + if(!all.vm.count("learning_rate") && !all.vm.count("l")) all.eta = 10; //default learning rate to 10 for non default update rule //default initial_t to 1 instead of 0 - if(!vm.count("initial_t")) { + if(!all.vm.count("initial_t")) { all.sd->t = 1.f; all.sd->weighted_unlabeled_examples = 1.f; all.initial_t = 1.f; diff --git a/vowpalwabbit/gd_mf.h b/vowpalwabbit/gd_mf.h index a0ce2f87..09623e61 100644 --- a/vowpalwabbit/gd_mf.h +++ b/vowpalwabbit/gd_mf.h @@ -4,12 +4,4 @@ individual contributors. All rights reserved. Released under a BSD license as described in the file LICENSE. */ #pragma once -#include <math.h> -#include "example.h" -#include "parse_regressor.h" -#include "parser.h" -#include "gd.h" - -namespace GDMF{ - LEARNER::base_learner* setup(vw& all, po::variables_map& vm); -} +namespace GDMF{ LEARNER::base_learner* setup(vw& all); } diff --git a/vowpalwabbit/global_data.cc b/vowpalwabbit/global_data.cc index 825977d5..cd288b52 100644 --- a/vowpalwabbit/global_data.cc +++ b/vowpalwabbit/global_data.cc @@ -214,23 +214,44 @@ void compile_limits(vector<string> limits, uint32_t* dest, bool quiet) } } -po::variables_map add_options(vw& all, po::options_description& opts) +void add_options(vw& all, po::options_description& opts) { all.opts.add(opts); po::variables_map new_vm; - //parse local opts once for notifications. po::parsed_options parsed = po::command_line_parser(all.args). style(po::command_line_style::default_style ^ po::command_line_style::allow_guessing). options(opts).allow_unregistered().run(); po::store(parsed, new_vm); po::notify(new_vm); - //parse all opts for a complete variable map. 
- parsed = po::command_line_parser(all.args). + + for (po::variables_map::iterator it=new_vm.begin(); it!=new_vm.end(); ++it) + all.vm.insert(*it); +} + +void add_options(vw& all) +{ + add_options(all, *all.new_opts); + delete all.new_opts; +} + +bool missing_required(vw& all) +{ + //parse local opts once for notifications. + po::parsed_options parsed = po::command_line_parser(all.args). style(po::command_line_style::default_style ^ po::command_line_style::allow_guessing). - options(all.opts).allow_unregistered().run(); + options(*all.new_opts).allow_unregistered().run(); + po::variables_map new_vm; po::store(parsed, new_vm); - return new_vm; + all.opts.add(*all.new_opts); + delete all.new_opts; + for (po::variables_map::iterator it=new_vm.begin(); it!=new_vm.end(); ++it) + all.vm.insert(*it); + + if (new_vm.size() == 0) // required are missing; + return true; + else + return false; } vw::vw() @@ -247,6 +268,7 @@ vw::vw() reg_mode = 0; current_pass = 0; + reduction_stack=v_init<LEARNER::base_learner* (*)(vw&)>(); data_filename = ""; @@ -260,13 +282,7 @@ vw::vw() default_bits = true; daemon = false; num_children = 10; - lda_alpha = 0.1f; - lda_rho = 0.1f; - lda_D = 10000.; - lda_epsilon = 0.001f; - minibatch = 1; span_server = ""; - m = 15; save_resume = false; random_positive_weights = false; @@ -276,8 +292,6 @@ vw::vw() power_t = 0.5; eta = 0.5; //default learning rate for normalized adaptive updates, this is switched to 10 by default for the other updates (see parse_args.cc) numpasses = 1; - rel_threshold = 0.001f; - rank = 0; final_prediction_sink.begin = final_prediction_sink.end=final_prediction_sink.end_array = NULL; raw_prediction = -1; diff --git a/vowpalwabbit/global_data.h b/vowpalwabbit/global_data.h index bdf76b46..7e6ec74b 100644 --- a/vowpalwabbit/global_data.h +++ b/vowpalwabbit/global_data.h @@ -193,12 +193,13 @@ struct vw { bool bfgs; bool hessian_on; - int m; bool save_resume; double normalized_sum_norm_x; po::options_description opts; + 
po::options_description* new_opts; + po::variables_map vm; std::stringstream* file_options; vector<std::string> args; @@ -217,10 +218,6 @@ struct vw { float power_t;//the power on learning rate decay. int reg_mode; - size_t minibatch; - - float rel_threshold; // termination threshold - size_t pass_length; size_t numpasses; size_t passes_complete; @@ -262,19 +259,14 @@ struct vw { size_t normalized_idx; //offset idx where the norm is stored (1 or 2 depending on whether adaptive is true) uint32_t lda; - float lda_alpha; - float lda_rho; - float lda_D; - float lda_epsilon; std::string text_regressor_name; std::string inv_hash_regressor_name; - std::string span_server; size_t length () { return ((size_t)1) << num_bits; }; - uint32_t rank; + v_array<LEARNER::base_learner* (*)(vw&)> reduction_stack; //Prediction output v_array<int> final_prediction_sink; // set to send global predictions to. @@ -321,4 +313,11 @@ void get_prediction(int sock, float& res, float& weight); void compile_gram(vector<string> grams, uint32_t* dest, char* descriptor, bool quiet); void compile_limits(vector<string> limits, uint32_t* dest, bool quiet); int print_tag(std::stringstream& ss, v_array<char> tag); -po::variables_map add_options(vw& all, po::options_description& opts); +void add_options(vw& all, po::options_description& opts); +inline po::options_description_easy_init new_options(vw& all, std::string name = "\0") +{ + all.new_opts = new po::options_description(name); + return all.new_opts->add_options(); +} +bool missing_required(vw& all); +void add_options(vw& all); diff --git a/vowpalwabbit/kernel_svm.cc b/vowpalwabbit/kernel_svm.cc index 9a88f17e..ed917cda 100644 --- a/vowpalwabbit/kernel_svm.cc +++ b/vowpalwabbit/kernel_svm.cc @@ -648,7 +648,7 @@ namespace KSVM else { for(size_t i = 0;i < params.pool_pos;i++) { - float queryp = 2.0f/(1.0f + expf((float)(params.active_c*fabs(scores[i]))*pow(params.pool[i]->ex.example_counter,0.5f))); + float queryp = 2.0f/(1.0f + 
expf((float)(params.active_c*fabs(scores[i]))*(float)pow(params.pool[i]->ex.example_counter,0.5f))); if(rand() < queryp) { svm_example* fec = params.pool[i]; fec->ex.l.simple.weight *= 1/queryp; @@ -790,13 +790,14 @@ namespace KSVM cerr<<"Done with finish \n"; } - - LEARNER::base_learner* setup(vw &all, po::variables_map& vm) { - po::options_description desc("KSVM options"); - desc.add_options() + LEARNER::base_learner* setup(vw &all) { + new_options(all, "KSVM options") + ("ksvm", "kernel svm"); + if (missing_required(all)) return NULL; + new_options(all) ("reprocess", po::value<size_t>(), "number of reprocess steps for LASVM") - ("active", "do active learning") - ("active_c", po::value<double>(), "parameter for query prob") + // ("active", "do active learning") + //("active_c", po::value<double>(), "parameter for query prob") ("pool_greedy", "use greedy selection on mini pools") ("para_active", "do parallel active learning") ("pool_size", po::value<size_t>(), "size of pools for active learning") @@ -805,8 +806,9 @@ namespace KSVM ("bandwidth", po::value<float>(), "bandwidth of rbf kernel") ("degree", po::value<int>(), "degree of poly kernel") ("lambda", po::value<double>(), "saving regularization for test time"); - vm = add_options(all, desc); + add_options(all); + po::variables_map& vm = all.vm; string loss_function = "hinge"; float loss_parameter = 0.0; delete all.loss; diff --git a/vowpalwabbit/kernel_svm.h b/vowpalwabbit/kernel_svm.h index 7a65a051..288e6729 100644 --- a/vowpalwabbit/kernel_svm.h +++ b/vowpalwabbit/kernel_svm.h @@ -4,7 +4,4 @@ individual contributors. All rights reserved. Released under a BSD (revised) license as described in the file LICENSE. 
*/ #pragma once -namespace KSVM -{ -LEARNER::base_learner* setup(vw &all, po::variables_map& vm); -} +namespace KSVM { LEARNER::base_learner* setup(vw &all); } diff --git a/vowpalwabbit/lda_core.cc b/vowpalwabbit/lda_core.cc index 8da81a4d..e616ab25 100644 --- a/vowpalwabbit/lda_core.cc +++ b/vowpalwabbit/lda_core.cc @@ -1,799 +1,818 @@ -/*
-Copyright (c) by respective owners including Yahoo!, Microsoft, and
-individual contributors. All rights reserved. Released under a BSD (revised)
-license as described in the file LICENSE.
- */
-#include <fstream>
-#include <vector>
-#include <float.h>
-#ifdef _WIN32
-#include <winsock2.h>
-#else
-#include <netdb.h>
-#endif
-#include <string.h>
-#include <stdio.h>
-#include <assert.h>
-#include "constant.h"
-#include "gd.h"
-#include "simple_label.h"
-#include "rand48.h"
-#include "reductions.h"
-
-using namespace LEARNER;
-using namespace std;
-
-namespace LDA {
-
-class index_feature {
-public:
- uint32_t document;
- feature f;
- bool operator<(const index_feature b) const { return f.weight_index < b.f.weight_index; }
-};
-
- struct lda {
- v_array<float> Elogtheta;
- v_array<float> decay_levels;
- v_array<float> total_new;
- v_array<example* > examples;
- v_array<float> total_lambda;
- v_array<int> doc_lengths;
- v_array<float> digammas;
- v_array<float> v;
- vector<index_feature> sorted_features;
-
- bool total_lambda_init;
-
- double example_t;
- vw* all;
- };
-
-#ifdef _WIN32
-inline float fmax(float f1, float f2) { return (f1 < f2 ? f2 : f1); }
-inline float fmin(float f1, float f2) { return (f1 > f2 ? f2 : f1); }
-#endif
-
-#define MINEIRO_SPECIAL
-#ifdef MINEIRO_SPECIAL
-
-namespace {
-
-inline float
-fastlog2 (float x)
-{
- union { float f; uint32_t i; } vx = { x };
- union { uint32_t i; float f; } mx = { (vx.i & 0x007FFFFF) | (0x7e << 23) };
- float y = (float)vx.i;
- y *= 1.0f / (float)(1 << 23);
-
- return
- y - 124.22544637f - 1.498030302f * mx.f - 1.72587999f / (0.3520887068f + mx.f);
-}
-
-inline float
-fastlog (float x)
-{
- return 0.69314718f * fastlog2 (x);
-}
-
-inline float
-fastpow2 (float p)
-{
- float offset = (p < 0) ? 1.0f : 0.0f;
- float clipp = (p < -126) ? -126.0f : p;
- int w = (int)clipp;
- float z = clipp - w + offset;
- union { uint32_t i; float f; } v = { (uint32_t)((1 << 23) * (clipp + 121.2740838f + 27.7280233f / (4.84252568f - z) - 1.49012907f * z)) };
-
- return v.f;
-}
-
-inline float
-fastexp (float p)
-{
- return fastpow2 (1.442695040f * p);
-}
-
-inline float
-fastpow (float x,
- float p)
-{
- return fastpow2 (p * fastlog2 (x));
-}
-
-inline float
-fastlgamma (float x)
-{
- float logterm = fastlog (x * (1.0f + x) * (2.0f + x));
- float xp3 = 3.0f + x;
-
- return
- -2.081061466f - x + 0.0833333f / xp3 - logterm + (2.5f + x) * fastlog (xp3);
-}
-
-inline float
-fastdigamma (float x)
-{
- float twopx = 2.0f + x;
- float logterm = fastlog (twopx);
-
- return - (1.0f + 2.0f * x) / (x * (1.0f + x))
- - (13.0f + 6.0f * x) / (12.0f * twopx * twopx)
- + logterm;
-}
-
-#define log fastlog
-#define exp fastexp
-#define powf fastpow
-#define mydigamma fastdigamma
-#define mylgamma fastlgamma
-
-#if defined(__SSE2__) && !defined(VW_LDA_NO_SSE)
-
-#include <emmintrin.h>
-
-typedef __m128 v4sf;
-typedef __m128i v4si;
-
-#define v4si_to_v4sf _mm_cvtepi32_ps
-#define v4sf_to_v4si _mm_cvttps_epi32
-
-static inline float
-v4sf_index (const v4sf x,
- unsigned int i)
-{
- union { v4sf f; float array[4]; } tmp = { x };
-
- return tmp.array[i];
-}
-
-static inline const v4sf
-v4sfl (float x)
-{
- union { float array[4]; v4sf f; } tmp = { { x, x, x, x } };
-
- return tmp.f;
-}
-
-static inline const v4si
-v4sil (uint32_t x)
-{
- uint64_t wide = (((uint64_t) x) << 32) | x;
- union { uint64_t array[2]; v4si f; } tmp = { { wide, wide } };
-
- return tmp.f;
-}
-
-static inline v4sf
-vfastpow2 (const v4sf p)
-{
- v4sf ltzero = _mm_cmplt_ps (p, v4sfl (0.0f));
- v4sf offset = _mm_and_ps (ltzero, v4sfl (1.0f));
- v4sf lt126 = _mm_cmplt_ps (p, v4sfl (-126.0f));
- v4sf clipp = _mm_andnot_ps (lt126, p) + _mm_and_ps (lt126, v4sfl (-126.0f));
- v4si w = v4sf_to_v4si (clipp);
- v4sf z = clipp - v4si_to_v4sf (w) + offset;
-
- const v4sf c_121_2740838 = v4sfl (121.2740838f);
- const v4sf c_27_7280233 = v4sfl (27.7280233f);
- const v4sf c_4_84252568 = v4sfl (4.84252568f);
- const v4sf c_1_49012907 = v4sfl (1.49012907f);
- union { v4si i; v4sf f; } v = {
- v4sf_to_v4si (
- v4sfl (1 << 23) *
- (clipp + c_121_2740838 + c_27_7280233 / (c_4_84252568 - z) - c_1_49012907 * z)
- )
- };
-
- return v.f;
-}
-
-inline v4sf
-vfastexp (const v4sf p)
-{
- const v4sf c_invlog_2 = v4sfl (1.442695040f);
-
- return vfastpow2 (c_invlog_2 * p);
-}
-
-inline v4sf
-vfastlog2 (v4sf x)
-{
- union { v4sf f; v4si i; } vx = { x };
- union { v4si i; v4sf f; } mx = { (vx.i & v4sil (0x007FFFFF)) | v4sil (0x3f000000) };
- v4sf y = v4si_to_v4sf (vx.i);
- y *= v4sfl (1.1920928955078125e-7f);
-
- const v4sf c_124_22551499 = v4sfl (124.22551499f);
- const v4sf c_1_498030302 = v4sfl (1.498030302f);
- const v4sf c_1_725877999 = v4sfl (1.72587999f);
- const v4sf c_0_3520087068 = v4sfl (0.3520887068f);
-
- return y - c_124_22551499
- - c_1_498030302 * mx.f
- - c_1_725877999 / (c_0_3520087068 + mx.f);
-}
-
-inline v4sf
-vfastlog (v4sf x)
-{
- const v4sf c_0_69314718 = v4sfl (0.69314718f);
-
- return c_0_69314718 * vfastlog2 (x);
-}
-
-inline v4sf
-vfastdigamma (v4sf x)
-{
- v4sf twopx = v4sfl (2.0f) + x;
- v4sf logterm = vfastlog (twopx);
-
- return (v4sfl (-48.0f) + x * (v4sfl (-157.0f) + x * (v4sfl (-127.0f) - v4sfl (30.0f) * x))) /
- (v4sfl (12.0f) * x * (v4sfl (1.0f) + x) * twopx * twopx)
- + logterm;
-}
-
-void
-vexpdigammify (vw& all, float* gamma)
-{
- unsigned int n = all.lda;
- float extra_sum = 0.0f;
- v4sf sum = v4sfl (0.0f);
- size_t i;
-
- for (i = 0; i < n && ((uintptr_t) (gamma + i)) % 16 > 0; ++i)
- {
- extra_sum += gamma[i];
- gamma[i] = fastdigamma (gamma[i]);
- }
-
- for (; i + 4 < n; i += 4)
- {
- v4sf arg = _mm_load_ps (gamma + i);
- sum += arg;
- arg = vfastdigamma (arg);
- _mm_store_ps (gamma + i, arg);
- }
-
- for (; i < n; ++i)
- {
- extra_sum += gamma[i];
- gamma[i] = fastdigamma (gamma[i]);
- }
-
- extra_sum += v4sf_index (sum, 0) + v4sf_index (sum, 1) +
- v4sf_index (sum, 2) + v4sf_index (sum, 3);
- extra_sum = fastdigamma (extra_sum);
- sum = v4sfl (extra_sum);
-
- for (i = 0; i < n && ((uintptr_t) (gamma + i)) % 16 > 0; ++i)
- {
- gamma[i] = fmaxf (1e-10f, fastexp (gamma[i] - v4sf_index (sum, 0)));
- }
-
- for (; i + 4 < n; i += 4)
- {
- v4sf arg = _mm_load_ps (gamma + i);
- arg -= sum;
- arg = vfastexp (arg);
- arg = _mm_max_ps (v4sfl (1e-10f), arg);
- _mm_store_ps (gamma + i, arg);
- }
-
- for (; i < n; ++i)
- {
- gamma[i] = fmaxf (1e-10f, fastexp (gamma[i] - v4sf_index (sum, 0)));
- }
-}
-
-void vexpdigammify_2(vw& all, float* gamma, const float* norm)
-{
- size_t n = all.lda;
- size_t i;
-
- for (i = 0; i < n && ((uintptr_t) (gamma + i)) % 16 > 0; ++i)
- {
- gamma[i] = fmaxf (1e-10f, fastexp (fastdigamma (gamma[i]) - norm[i]));
- }
-
- for (; i + 4 < n; i += 4)
- {
- v4sf arg = _mm_load_ps (gamma + i);
- arg = vfastdigamma (arg);
- v4sf vnorm = _mm_loadu_ps (norm + i);
- arg -= vnorm;
- arg = vfastexp (arg);
- arg = _mm_max_ps (v4sfl (1e-10f), arg);
- _mm_store_ps (gamma + i, arg);
- }
-
- for (; i < n; ++i)
- {
- gamma[i] = fmaxf (1e-10f, fastexp (fastdigamma (gamma[i]) - norm[i]));
- }
-}
-
-#define myexpdigammify vexpdigammify
-#define myexpdigammify_2 vexpdigammify_2
-
-#else
-#ifndef _WIN32
-#warning "lda IS NOT using sse instructions"
-#endif
-#define myexpdigammify expdigammify
-#define myexpdigammify_2 expdigammify_2
-
-#endif // __SSE2__
-
-} // end anonymous namespace
-
-#else
-
-#include <boost/math/special_functions/digamma.hpp>
-#include <boost/math/special_functions/gamma.hpp>
-
-using namespace boost::math::policies;
-
-#define mydigamma boost::math::digamma
-#define mylgamma boost::math::lgamma
-#define myexpdigammify expdigammify
-#define myexpdigammify_2 expdigammify_2
-
-#endif // MINEIRO_SPECIAL
-
-float decayfunc(float t, float old_t, float power_t) {
- float result = 1;
- for (float i = old_t+1; i <= t; i += 1)
- result *= (1-powf(i, -power_t));
- return result;
-}
-
-float decayfunc2(float t, float old_t, float power_t)
-{
- float power_t_plus_one = 1.f - power_t;
- float arg = - ( powf(t, power_t_plus_one) -
- powf(old_t, power_t_plus_one));
- return exp ( arg
- / power_t_plus_one);
-}
-
-float decayfunc3(double t, double old_t, double power_t)
-{
- double power_t_plus_one = 1. - power_t;
- double logt = log((float)t);
- double logoldt = log((float)old_t);
- return (float)((old_t / t) * exp((float)(0.5*power_t_plus_one*(-logt*logt + logoldt*logoldt))));
-}
-
-float decayfunc4(double t, double old_t, double power_t)
-{
- if (power_t > 0.99)
- return decayfunc3(t, old_t, power_t);
- else
- return (float)decayfunc2((float)t, (float)old_t, (float)power_t);
-}
-
-void expdigammify(vw& all, float* gamma)
-{
- float sum=0;
- for (size_t i = 0; i<all.lda; i++)
- {
- sum += gamma[i];
- gamma[i] = mydigamma(gamma[i]);
- }
- sum = mydigamma(sum);
- for (size_t i = 0; i<all.lda; i++)
- gamma[i] = fmax(1e-6f, exp(gamma[i] - sum));
-}
-
-void expdigammify_2(vw& all, float* gamma, float* norm)
-{
- for (size_t i = 0; i<all.lda; i++)
- {
- gamma[i] = fmax(1e-6f, exp(mydigamma(gamma[i]) - norm[i]));
- }
-}
-
-float average_diff(vw& all, float* oldgamma, float* newgamma)
-{
- float sum = 0.;
- float normalizer = 0.;
- for (size_t i = 0; i<all.lda; i++) {
- sum += fabsf(oldgamma[i] - newgamma[i]);
- normalizer += newgamma[i];
- }
- return sum / normalizer;
-}
-
-// Returns E_q[log p(\theta)] - E_q[log q(\theta)].
- float theta_kl(vw& all, v_array<float>& Elogtheta, float* gamma)
-{
- float gammasum = 0;
- Elogtheta.erase();
- for (size_t k = 0; k < all.lda; k++) {
- Elogtheta.push_back(mydigamma(gamma[k]));
- gammasum += gamma[k];
- }
- float digammasum = mydigamma(gammasum);
- gammasum = mylgamma(gammasum);
- float kl = -(all.lda*mylgamma(all.lda_alpha));
- kl += mylgamma(all.lda_alpha*all.lda) - gammasum;
- for (size_t k = 0; k < all.lda; k++) {
- Elogtheta[k] -= digammasum;
- kl += (all.lda_alpha - gamma[k]) * Elogtheta[k];
- kl += mylgamma(gamma[k]);
- }
-
- return kl;
-}
-
-float find_cw(vw& all, float* u_for_w, float* v)
-{
- float c_w = 0;
- for (size_t k =0; k<all.lda; k++)
- c_w += u_for_w[k]*v[k];
-
- return 1.f / c_w;
-}
-
- v_array<float> new_gamma = v_init<float>();
- v_array<float> old_gamma = v_init<float>();
-// Returns an estimate of the part of the variational bound that
-// doesn't have to do with beta for the entire corpus for the current
-// setting of lambda based on the document passed in. The value is
-// divided by the total number of words in the document This can be
-// used as a (possibly very noisy) estimate of held-out likelihood.
- float lda_loop(vw& all, v_array<float>& Elogtheta, float* v,weight* weights,example* ec, float power_t)
-{
- new_gamma.erase();
- old_gamma.erase();
-
- for (size_t i = 0; i < all.lda; i++)
- {
- new_gamma.push_back(1.f);
- old_gamma.push_back(0.f);
- }
- size_t num_words =0;
- for (unsigned char* i = ec->indices.begin; i != ec->indices.end; i++)
- num_words += ec->atomics[*i].end - ec->atomics[*i].begin;
-
- float xc_w = 0;
- float score = 0;
- float doc_length = 0;
- do
- {
- memcpy(v,new_gamma.begin,sizeof(float)*all.lda);
- myexpdigammify(all, v);
-
- memcpy(old_gamma.begin,new_gamma.begin,sizeof(float)*all.lda);
- memset(new_gamma.begin,0,sizeof(float)*all.lda);
-
- score = 0;
- size_t word_count = 0;
- doc_length = 0;
- for (unsigned char* i = ec->indices.begin; i != ec->indices.end; i++)
- {
- feature *f = ec->atomics[*i].begin;
- for (; f != ec->atomics[*i].end; f++)
- {
- float* u_for_w = &weights[(f->weight_index&all.reg.weight_mask)+all.lda+1];
- float c_w = find_cw(all, u_for_w,v);
- xc_w = c_w * f->x;
- score += -f->x*log(c_w);
- size_t max_k = all.lda;
- for (size_t k =0; k<max_k; k++) {
- new_gamma[k] += xc_w*u_for_w[k];
- }
- word_count++;
- doc_length += f->x;
- }
- }
- for (size_t k =0; k<all.lda; k++)
- new_gamma[k] = new_gamma[k]*v[k]+all.lda_alpha;
- }
- while (average_diff(all, old_gamma.begin, new_gamma.begin) > all.lda_epsilon);
-
- ec->topic_predictions.erase();
- ec->topic_predictions.resize(all.lda);
- memcpy(ec->topic_predictions.begin,new_gamma.begin,all.lda*sizeof(float));
-
- score += theta_kl(all, Elogtheta, new_gamma.begin);
-
- return score / doc_length;
-}
-
-size_t next_pow2(size_t x) {
- int i = 0;
- x = x > 0 ? x - 1 : 0;
- while (x > 0) {
- x >>= 1;
- i++;
- }
- return ((size_t)1) << i;
-}
-
-void save_load(lda& l, io_buf& model_file, bool read, bool text)
-{
- vw* all = l.all;
- uint32_t length = 1 << all->num_bits;
- uint32_t stride = 1 << all->reg.stride_shift;
-
- if (read)
- {
- initialize_regressor(*all);
- for (size_t j = 0; j < stride*length; j+=stride)
- {
- for (size_t k = 0; k < all->lda; k++) {
- if (all->random_weights) {
- all->reg.weight_vector[j+k] = (float)(-log(frand48()) + 1.0f);
- all->reg.weight_vector[j+k] *= (float)(all->lda_D / all->lda / all->length() * 200);
- }
- }
- all->reg.weight_vector[j+all->lda] = all->initial_t;
- }
- }
-
- if (model_file.files.size() > 0)
- {
- uint32_t i = 0;
- uint32_t text_len;
- char buff[512];
- size_t brw = 1;
- do
- {
- brw = 0;
- size_t K = all->lda;
-
- text_len = sprintf(buff, "%d ", i);
- brw += bin_text_read_write_fixed(model_file,(char *)&i, sizeof (i),
- "", read,
- buff, text_len, text);
- if (brw != 0)
- for (uint32_t k = 0; k < K; k++)
- {
- uint32_t ndx = stride*i+k;
-
- weight* v = &(all->reg.weight_vector[ndx]);
- text_len = sprintf(buff, "%f ", *v + all->lda_rho);
-
- brw += bin_text_read_write_fixed(model_file,(char *)v, sizeof (*v),
- "", read,
- buff, text_len, text);
-
- }
- if (text)
- brw += bin_text_read_write_fixed(model_file,buff,0,
- "", read,
- "\n",1,text);
-
- if (!read)
- i++;
- }
- while ((!read && i < length) || (read && brw >0));
- }
-}
-
- void learn_batch(lda& l)
- {
- if (l.sorted_features.empty()) {
- // This can happen when the socket connection is dropped by the client.
- // If l.sorted_features is empty, then l.sorted_features[0] does not
- // exist, so we should not try to take its address in the beginning of
- // the for loops down there. Since it seems that there's not much to
- // do in this case, we just return.
- for (size_t d = 0; d < l.examples.size(); d++)
- return_simple_example(*l.all, NULL, *l.examples[d]);
- l.examples.erase();
- return;
- }
-
- float eta = -1;
- float minuseta = -1;
-
- if (l.total_lambda.size() == 0)
- {
- for (size_t k = 0; k < l.all->lda; k++)
- l.total_lambda.push_back(0.f);
-
- size_t stride = 1 << l.all->reg.stride_shift;
- for (size_t i =0; i <= l.all->reg.weight_mask;i+=stride)
- for (size_t k = 0; k < l.all->lda; k++)
- l.total_lambda[k] += l.all->reg.weight_vector[i+k];
- }
-
- l.example_t++;
- l.total_new.erase();
- for (size_t k = 0; k < l.all->lda; k++)
- l.total_new.push_back(0.f);
-
- size_t batch_size = l.examples.size();
-
- sort(l.sorted_features.begin(), l.sorted_features.end());
-
- eta = l.all->eta * powf((float)l.example_t, - l.all->power_t);
- minuseta = 1.0f - eta;
- eta *= l.all->lda_D / batch_size;
- l.decay_levels.push_back(l.decay_levels.last() + log(minuseta));
-
- l.digammas.erase();
- float additional = (float)(l.all->length()) * l.all->lda_rho;
- for (size_t i = 0; i<l.all->lda; i++) {
- l.digammas.push_back(mydigamma(l.total_lambda[i] + additional));
- }
-
-
- weight* weights = l.all->reg.weight_vector;
-
- size_t last_weight_index = -1;
- for (index_feature* s = &l.sorted_features[0]; s <= &l.sorted_features.back(); s++)
- {
- if (last_weight_index == s->f.weight_index)
- continue;
- last_weight_index = s->f.weight_index;
- float* weights_for_w = &(weights[s->f.weight_index & l.all->reg.weight_mask]);
- float decay = fmin(1.0, exp(l.decay_levels.end[-2] - l.decay_levels.end[(int)(-1 - l.example_t+weights_for_w[l.all->lda])]));
- float* u_for_w = weights_for_w + l.all->lda+1;
-
- weights_for_w[l.all->lda] = (float)l.example_t;
- for (size_t k = 0; k < l.all->lda; k++)
- {
- weights_for_w[k] *= decay;
- u_for_w[k] = weights_for_w[k] + l.all->lda_rho;
- }
- myexpdigammify_2(*l.all, u_for_w, l.digammas.begin);
- }
-
- for (size_t d = 0; d < batch_size; d++)
- {
- float score = lda_loop(*l.all, l.Elogtheta, &(l.v[d*l.all->lda]), weights, l.examples[d],l.all->power_t);
- if (l.all->audit)
- GD::print_audit_features(*l.all, *l.examples[d]);
- // If the doc is empty, give it loss of 0.
- if (l.doc_lengths[d] > 0) {
- l.all->sd->sum_loss -= score;
- l.all->sd->sum_loss_since_last_dump -= score;
- }
- return_simple_example(*l.all, NULL, *l.examples[d]);
- }
-
- for (index_feature* s = &l.sorted_features[0]; s <= &l.sorted_features.back();)
- {
- index_feature* next = s+1;
- while(next <= &l.sorted_features.back() && next->f.weight_index == s->f.weight_index)
- next++;
-
- float* word_weights = &(weights[s->f.weight_index & l.all->reg.weight_mask]);
- for (size_t k = 0; k < l.all->lda; k++) {
- float new_value = minuseta*word_weights[k];
- word_weights[k] = new_value;
- }
-
- for (; s != next; s++) {
- float* v_s = &(l.v[s->document*l.all->lda]);
- float* u_for_w = &weights[(s->f.weight_index & l.all->reg.weight_mask) + l.all->lda + 1];
- float c_w = eta*find_cw(*l.all, u_for_w, v_s)*s->f.x;
- for (size_t k = 0; k < l.all->lda; k++) {
- float new_value = u_for_w[k]*v_s[k]*c_w;
- l.total_new[k] += new_value;
- word_weights[k] += new_value;
- }
- }
- }
- for (size_t k = 0; k < l.all->lda; k++) {
- l.total_lambda[k] *= minuseta;
- l.total_lambda[k] += l.total_new[k];
- }
-
- l.sorted_features.resize(0);
-
- l.examples.erase();
- l.doc_lengths.erase();
- }
-
- void learn(lda& l, base_learner& base, example& ec)
- {
- size_t num_ex = l.examples.size();
- l.examples.push_back(&ec);
- l.doc_lengths.push_back(0);
- for (unsigned char* i = ec.indices.begin; i != ec.indices.end; i++) {
- feature* f = ec.atomics[*i].begin;
- for (; f != ec.atomics[*i].end; f++) {
- index_feature temp = {(uint32_t)num_ex, *f};
- l.sorted_features.push_back(temp);
- l.doc_lengths[num_ex] += (int)f->x;
- }
- }
- if (++num_ex == l.all->minibatch)
- learn_batch(l);
- }
-
- // placeholder
- void predict(lda& l, base_learner& base, example& ec)
- {
- learn(l, base, ec);
- }
-
- void end_pass(lda& l)
- {
- if (l.examples.size())
- learn_batch(l);
- }
-
-void end_examples(lda& l)
-{
- for (size_t i = 0; i < l.all->length(); i++) {
- weight* weights_for_w = & (l.all->reg.weight_vector[i << l.all->reg.stride_shift]);
- float decay = fmin(1.0, exp(l.decay_levels.last() - l.decay_levels.end[(int)(-1- l.example_t +weights_for_w[l.all->lda])]));
- for (size_t k = 0; k < l.all->lda; k++)
- weights_for_w[k] *= decay;
- }
-}
-
- void finish_example(vw& all, lda&, example& ec)
-{}
-
- void finish(lda& ld)
- {
- ld.sorted_features.~vector<index_feature>();
- ld.Elogtheta.delete_v();
- ld.decay_levels.delete_v();
- ld.total_new.delete_v();
- ld.examples.delete_v();
- ld.total_lambda.delete_v();
- ld.doc_lengths.delete_v();
- ld.digammas.delete_v();
- ld.v.delete_v();
- }
-
-base_learner* setup(vw&all, po::variables_map& vm)
-{
- lda& ld = calloc_or_die<lda>();
- ld.sorted_features = vector<index_feature>();
- ld.total_lambda_init = 0;
- ld.all = &all;
- ld.example_t = all.initial_t;
-
- po::options_description lda_opts("LDA options");
- lda_opts.add_options()
- ("lda_alpha", po::value<float>(&all.lda_alpha), "Prior on sparsity of per-document topic weights")
- ("lda_rho", po::value<float>(&all.lda_rho), "Prior on sparsity of topic distributions")
- ("lda_D", po::value<float>(&all.lda_D), "Number of documents")
- ("lda_epsilon", po::value<float>(&all.lda_epsilon), "Loop convergence threshold")
- ("minibatch", po::value<size_t>(&all.minibatch), "Minibatch size, for LDA");
-
- vm = add_options(all, lda_opts);
-
- float temp = ceilf(logf((float)(all.lda*2+1)) / logf (2.f));
- all.reg.stride_shift = (size_t)temp;
- all.random_weights = true;
- all.add_constant = false;
-
- *all.file_options << " --lda " << all.lda;
-
- if (all.eta > 1.)
- {
- cerr << "your learning rate is too high, setting it to 1" << endl;
- all.eta = min(all.eta,1.f);
- }
-
- if (vm.count("minibatch")) {
- size_t minibatch2 = next_pow2(all.minibatch);
- all.p->ring_size = all.p->ring_size > minibatch2 ? all.p->ring_size : minibatch2;
- }
-
- ld.v.resize(all.lda*all.minibatch);
-
- ld.decay_levels.push_back(0.f);
-
- learner<lda>& l = init_learner(&ld, learn, 1 << all.reg.stride_shift);
- l.set_predict(predict);
- l.set_save_load(save_load);
- l.set_finish_example(finish_example);
- l.set_end_examples(end_examples);
- l.set_end_pass(end_pass);
- l.set_finish(finish);
-
- return make_base(l);
-}
-}
+/* +Copyright (c) by respective owners including Yahoo!, Microsoft, and +individual contributors. All rights reserved. Released under a BSD (revised) +license as described in the file LICENSE. + */ +#include <fstream> +#include <vector> +#include <float.h> +#ifdef _WIN32 +#include <winsock2.h> +#else +#include <netdb.h> +#endif +#include <string.h> +#include <stdio.h> +#include <assert.h> +#include "constant.h" +#include "gd.h" +#include "simple_label.h" +#include "rand48.h" +#include "reductions.h" + +using namespace LEARNER; +using namespace std; + +namespace LDA { + +class index_feature { +public: + uint32_t document; + feature f; + bool operator<(const index_feature b) const { return f.weight_index < b.f.weight_index; } +}; + + struct lda { + uint32_t topics; + float lda_alpha; + float lda_rho; + float lda_D; + float lda_epsilon; + size_t minibatch; + + v_array<float> Elogtheta; + v_array<float> decay_levels; + v_array<float> total_new; + v_array<example* > examples; + v_array<float> total_lambda; + v_array<int> doc_lengths; + v_array<float> digammas; + v_array<float> v; + vector<index_feature> sorted_features; + + bool total_lambda_init; + + double example_t; + vw* all; + }; + +#ifdef _WIN32 +inline float fmax(float f1, float f2) { return (f1 < f2 ? f2 : f1); } +inline float fmin(float f1, float f2) { return (f1 > f2 ? f2 : f1); } +#endif + +#define MINEIRO_SPECIAL +#ifdef MINEIRO_SPECIAL + +namespace { + +inline float +fastlog2 (float x) +{ + union { float f; uint32_t i; } vx = { x }; + union { uint32_t i; float f; } mx = { (vx.i & 0x007FFFFF) | (0x7e << 23) }; + float y = (float)vx.i; + y *= 1.0f / (float)(1 << 23); + + return + y - 124.22544637f - 1.498030302f * mx.f - 1.72587999f / (0.3520887068f + mx.f); +} + +inline float +fastlog (float x) +{ + return 0.69314718f * fastlog2 (x); +} + +inline float +fastpow2 (float p) +{ + float offset = (p < 0) ? 1.0f : 0.0f; + float clipp = (p < -126) ? 
-126.0f : p; + int w = (int)clipp; + float z = clipp - w + offset; + union { uint32_t i; float f; } v = { (uint32_t)((1 << 23) * (clipp + 121.2740838f + 27.7280233f / (4.84252568f - z) - 1.49012907f * z)) }; + + return v.f; +} + +inline float +fastexp (float p) +{ + return fastpow2 (1.442695040f * p); +} + +inline float +fastpow (float x, + float p) +{ + return fastpow2 (p * fastlog2 (x)); +} + +inline float +fastlgamma (float x) +{ + float logterm = fastlog (x * (1.0f + x) * (2.0f + x)); + float xp3 = 3.0f + x; + + return + -2.081061466f - x + 0.0833333f / xp3 - logterm + (2.5f + x) * fastlog (xp3); +} + +inline float +fastdigamma (float x) +{ + float twopx = 2.0f + x; + float logterm = fastlog (twopx); + + return - (1.0f + 2.0f * x) / (x * (1.0f + x)) + - (13.0f + 6.0f * x) / (12.0f * twopx * twopx) + + logterm; +} + +#define log fastlog +#define exp fastexp +#define powf fastpow +#define mydigamma fastdigamma +#define mylgamma fastlgamma + +#if defined(__SSE2__) && !defined(VW_LDA_NO_SSE) + +#include <emmintrin.h> + +typedef __m128 v4sf; +typedef __m128i v4si; + +#define v4si_to_v4sf _mm_cvtepi32_ps +#define v4sf_to_v4si _mm_cvttps_epi32 + +static inline float +v4sf_index (const v4sf x, + unsigned int i) +{ + union { v4sf f; float array[4]; } tmp = { x }; + + return tmp.array[i]; +} + +static inline const v4sf +v4sfl (float x) +{ + union { float array[4]; v4sf f; } tmp = { { x, x, x, x } }; + + return tmp.f; +} + +static inline const v4si +v4sil (uint32_t x) +{ + uint64_t wide = (((uint64_t) x) << 32) | x; + union { uint64_t array[2]; v4si f; } tmp = { { wide, wide } }; + + return tmp.f; +} + +static inline v4sf +vfastpow2 (const v4sf p) +{ + v4sf ltzero = _mm_cmplt_ps (p, v4sfl (0.0f)); + v4sf offset = _mm_and_ps (ltzero, v4sfl (1.0f)); + v4sf lt126 = _mm_cmplt_ps (p, v4sfl (-126.0f)); + v4sf clipp = _mm_andnot_ps (lt126, p) + _mm_and_ps (lt126, v4sfl (-126.0f)); + v4si w = v4sf_to_v4si (clipp); + v4sf z = clipp - v4si_to_v4sf (w) + offset; + + const v4sf 
c_121_2740838 = v4sfl (121.2740838f); + const v4sf c_27_7280233 = v4sfl (27.7280233f); + const v4sf c_4_84252568 = v4sfl (4.84252568f); + const v4sf c_1_49012907 = v4sfl (1.49012907f); + union { v4si i; v4sf f; } v = { + v4sf_to_v4si ( + v4sfl (1 << 23) * + (clipp + c_121_2740838 + c_27_7280233 / (c_4_84252568 - z) - c_1_49012907 * z) + ) + }; + + return v.f; +} + +inline v4sf +vfastexp (const v4sf p) +{ + const v4sf c_invlog_2 = v4sfl (1.442695040f); + + return vfastpow2 (c_invlog_2 * p); +} + +inline v4sf +vfastlog2 (v4sf x) +{ + union { v4sf f; v4si i; } vx = { x }; + union { v4si i; v4sf f; } mx = { (vx.i & v4sil (0x007FFFFF)) | v4sil (0x3f000000) }; + v4sf y = v4si_to_v4sf (vx.i); + y *= v4sfl (1.1920928955078125e-7f); + + const v4sf c_124_22551499 = v4sfl (124.22551499f); + const v4sf c_1_498030302 = v4sfl (1.498030302f); + const v4sf c_1_725877999 = v4sfl (1.72587999f); + const v4sf c_0_3520087068 = v4sfl (0.3520887068f); + + return y - c_124_22551499 + - c_1_498030302 * mx.f + - c_1_725877999 / (c_0_3520087068 + mx.f); +} + +inline v4sf +vfastlog (v4sf x) +{ + const v4sf c_0_69314718 = v4sfl (0.69314718f); + + return c_0_69314718 * vfastlog2 (x); +} + +inline v4sf +vfastdigamma (v4sf x) +{ + v4sf twopx = v4sfl (2.0f) + x; + v4sf logterm = vfastlog (twopx); + + return (v4sfl (-48.0f) + x * (v4sfl (-157.0f) + x * (v4sfl (-127.0f) - v4sfl (30.0f) * x))) / + (v4sfl (12.0f) * x * (v4sfl (1.0f) + x) * twopx * twopx) + + logterm; +} + +void +vexpdigammify (vw& all, float* gamma) +{ + unsigned int n = all.lda; + float extra_sum = 0.0f; + v4sf sum = v4sfl (0.0f); + size_t i; + + for (i = 0; i < n && ((uintptr_t) (gamma + i)) % 16 > 0; ++i) + { + extra_sum += gamma[i]; + gamma[i] = fastdigamma (gamma[i]); + } + + for (; i + 4 < n; i += 4) + { + v4sf arg = _mm_load_ps (gamma + i); + sum += arg; + arg = vfastdigamma (arg); + _mm_store_ps (gamma + i, arg); + } + + for (; i < n; ++i) + { + extra_sum += gamma[i]; + gamma[i] = fastdigamma (gamma[i]); + } + + extra_sum += 
v4sf_index (sum, 0) + v4sf_index (sum, 1) + + v4sf_index (sum, 2) + v4sf_index (sum, 3); + extra_sum = fastdigamma (extra_sum); + sum = v4sfl (extra_sum); + + for (i = 0; i < n && ((uintptr_t) (gamma + i)) % 16 > 0; ++i) + { + gamma[i] = fmaxf (1e-10f, fastexp (gamma[i] - v4sf_index (sum, 0))); + } + + for (; i + 4 < n; i += 4) + { + v4sf arg = _mm_load_ps (gamma + i); + arg -= sum; + arg = vfastexp (arg); + arg = _mm_max_ps (v4sfl (1e-10f), arg); + _mm_store_ps (gamma + i, arg); + } + + for (; i < n; ++i) + { + gamma[i] = fmaxf (1e-10f, fastexp (gamma[i] - v4sf_index (sum, 0))); + } +} + +void vexpdigammify_2(vw& all, float* gamma, const float* norm) +{ + size_t n = all.lda; + size_t i; + + for (i = 0; i < n && ((uintptr_t) (gamma + i)) % 16 > 0; ++i) + { + gamma[i] = fmaxf (1e-10f, fastexp (fastdigamma (gamma[i]) - norm[i])); + } + + for (; i + 4 < n; i += 4) + { + v4sf arg = _mm_load_ps (gamma + i); + arg = vfastdigamma (arg); + v4sf vnorm = _mm_loadu_ps (norm + i); + arg -= vnorm; + arg = vfastexp (arg); + arg = _mm_max_ps (v4sfl (1e-10f), arg); + _mm_store_ps (gamma + i, arg); + } + + for (; i < n; ++i) + { + gamma[i] = fmaxf (1e-10f, fastexp (fastdigamma (gamma[i]) - norm[i])); + } +} + +#define myexpdigammify vexpdigammify +#define myexpdigammify_2 vexpdigammify_2 + +#else +#ifndef _WIN32 +#warning "lda IS NOT using sse instructions" +#endif +#define myexpdigammify expdigammify +#define myexpdigammify_2 expdigammify_2 + +#endif // __SSE2__ + +} // end anonymous namespace + +#else + +#include <boost/math/special_functions/digamma.hpp> +#include <boost/math/special_functions/gamma.hpp> + +using namespace boost::math::policies; + +#define mydigamma boost::math::digamma +#define mylgamma boost::math::lgamma +#define myexpdigammify expdigammify +#define myexpdigammify_2 expdigammify_2 + +#endif // MINEIRO_SPECIAL + +float decayfunc(float t, float old_t, float power_t) { + float result = 1; + for (float i = old_t+1; i <= t; i += 1) + result *= (1-powf(i, 
-power_t)); + return result; +} + +float decayfunc2(float t, float old_t, float power_t) +{ + float power_t_plus_one = 1.f - power_t; + float arg = - ( powf(t, power_t_plus_one) - + powf(old_t, power_t_plus_one)); + return exp ( arg + / power_t_plus_one); +} + +float decayfunc3(double t, double old_t, double power_t) +{ + double power_t_plus_one = 1. - power_t; + double logt = log((float)t); + double logoldt = log((float)old_t); + return (float)((old_t / t) * exp((float)(0.5*power_t_plus_one*(-logt*logt + logoldt*logoldt)))); +} + +float decayfunc4(double t, double old_t, double power_t) +{ + if (power_t > 0.99) + return decayfunc3(t, old_t, power_t); + else + return (float)decayfunc2((float)t, (float)old_t, (float)power_t); +} + +void expdigammify(vw& all, float* gamma) +{ + float sum=0; + for (size_t i = 0; i<all.lda; i++) + { + sum += gamma[i]; + gamma[i] = mydigamma(gamma[i]); + } + sum = mydigamma(sum); + for (size_t i = 0; i<all.lda; i++) + gamma[i] = fmax(1e-6f, exp(gamma[i] - sum)); +} + +void expdigammify_2(vw& all, float* gamma, float* norm) +{ + for (size_t i = 0; i<all.lda; i++) + { + gamma[i] = fmax(1e-6f, exp(mydigamma(gamma[i]) - norm[i])); + } +} + +float average_diff(vw& all, float* oldgamma, float* newgamma) +{ + float sum = 0.; + float normalizer = 0.; + for (size_t i = 0; i<all.lda; i++) { + sum += fabsf(oldgamma[i] - newgamma[i]); + normalizer += newgamma[i]; + } + return sum / normalizer; +} + +// Returns E_q[log p(\theta)] - E_q[log q(\theta)]. 
+ float theta_kl(lda& l, v_array<float>& Elogtheta, float* gamma) +{ + float gammasum = 0; + Elogtheta.erase(); + for (size_t k = 0; k < l.topics; k++) { + Elogtheta.push_back(mydigamma(gamma[k])); + gammasum += gamma[k]; + } + float digammasum = mydigamma(gammasum); + gammasum = mylgamma(gammasum); + float kl = -(l.topics*mylgamma(l.lda_alpha)); + kl += mylgamma(l.lda_alpha*l.topics) - gammasum; + for (size_t k = 0; k < l.topics; k++) { + Elogtheta[k] -= digammasum; + kl += (l.lda_alpha - gamma[k]) * Elogtheta[k]; + kl += mylgamma(gamma[k]); + } + + return kl; +} + +float find_cw(lda& l, float* u_for_w, float* v) +{ + float c_w = 0; + for (size_t k =0; k<l.topics; k++) + c_w += u_for_w[k]*v[k]; + + return 1.f / c_w; +} + + v_array<float> new_gamma = v_init<float>(); + v_array<float> old_gamma = v_init<float>(); +// Returns an estimate of the part of the variational bound that +// doesn't have to do with beta for the entire corpus for the current +// setting of lambda based on the document passed in. The value is +// divided by the total number of words in the document This can be +// used as a (possibly very noisy) estimate of held-out likelihood. 
+ float lda_loop(lda& l, v_array<float>& Elogtheta, float* v,weight* weights,example* ec, float power_t) +{ + new_gamma.erase(); + old_gamma.erase(); + + for (size_t i = 0; i < l.topics; i++) + { + new_gamma.push_back(1.f); + old_gamma.push_back(0.f); + } + size_t num_words =0; + for (unsigned char* i = ec->indices.begin; i != ec->indices.end; i++) + num_words += ec->atomics[*i].end - ec->atomics[*i].begin; + + float xc_w = 0; + float score = 0; + float doc_length = 0; + do + { + memcpy(v,new_gamma.begin,sizeof(float)*l.topics); + myexpdigammify(*l.all, v); + + memcpy(old_gamma.begin,new_gamma.begin,sizeof(float)*l.topics); + memset(new_gamma.begin,0,sizeof(float)*l.topics); + + score = 0; + size_t word_count = 0; + doc_length = 0; + for (unsigned char* i = ec->indices.begin; i != ec->indices.end; i++) + { + feature *f = ec->atomics[*i].begin; + for (; f != ec->atomics[*i].end; f++) + { + float* u_for_w = &weights[(f->weight_index & l.all->reg.weight_mask)+l.topics+1]; + float c_w = find_cw(l, u_for_w,v); + xc_w = c_w * f->x; + score += -f->x*log(c_w); + size_t max_k = l.topics; + for (size_t k =0; k<max_k; k++) { + new_gamma[k] += xc_w*u_for_w[k]; + } + word_count++; + doc_length += f->x; + } + } + for (size_t k =0; k<l.topics; k++) + new_gamma[k] = new_gamma[k]*v[k]+l.lda_alpha; + } + while (average_diff(*l.all, old_gamma.begin, new_gamma.begin) > l.lda_epsilon); + + ec->topic_predictions.erase(); + ec->topic_predictions.resize(l.topics); + memcpy(ec->topic_predictions.begin,new_gamma.begin,l.topics*sizeof(float)); + + score += theta_kl(l, Elogtheta, new_gamma.begin); + + return score / doc_length; +} + +size_t next_pow2(size_t x) { + int i = 0; + x = x > 0 ? 
x - 1 : 0; + while (x > 0) { + x >>= 1; + i++; + } + return ((size_t)1) << i; +} + +void save_load(lda& l, io_buf& model_file, bool read, bool text) +{ + vw* all = l.all; + uint32_t length = 1 << all->num_bits; + uint32_t stride = 1 << all->reg.stride_shift; + + if (read) + { + initialize_regressor(*all); + for (size_t j = 0; j < stride*length; j+=stride) + { + for (size_t k = 0; k < all->lda; k++) { + if (all->random_weights) { + all->reg.weight_vector[j+k] = (float)(-log(frand48()) + 1.0f); + all->reg.weight_vector[j+k] *= (float)(l.lda_D / all->lda / all->length() * 200); + } + } + all->reg.weight_vector[j+all->lda] = all->initial_t; + } + } + + if (model_file.files.size() > 0) + { + uint32_t i = 0; + uint32_t text_len; + char buff[512]; + size_t brw = 1; + do + { + brw = 0; + size_t K = all->lda; + + text_len = sprintf(buff, "%d ", i); + brw += bin_text_read_write_fixed(model_file,(char *)&i, sizeof (i), + "", read, + buff, text_len, text); + if (brw != 0) + for (uint32_t k = 0; k < K; k++) + { + uint32_t ndx = stride*i+k; + + weight* v = &(all->reg.weight_vector[ndx]); + text_len = sprintf(buff, "%f ", *v + l.lda_rho); + + brw += bin_text_read_write_fixed(model_file,(char *)v, sizeof (*v), + "", read, + buff, text_len, text); + + } + if (text) + brw += bin_text_read_write_fixed(model_file,buff,0, + "", read, + "\n",1,text); + + if (!read) + i++; + } + while ((!read && i < length) || (read && brw >0)); + } +} + + void learn_batch(lda& l) + { + if (l.sorted_features.empty()) { + // This can happen when the socket connection is dropped by the client. + // If l.sorted_features is empty, then l.sorted_features[0] does not + // exist, so we should not try to take its address in the beginning of + // the for loops down there. Since it seems that there's not much to + // do in this case, we just return. 
+ for (size_t d = 0; d < l.examples.size(); d++) + return_simple_example(*l.all, NULL, *l.examples[d]); + l.examples.erase(); + return; + } + + float eta = -1; + float minuseta = -1; + + if (l.total_lambda.size() == 0) + { + for (size_t k = 0; k < l.all->lda; k++) + l.total_lambda.push_back(0.f); + + size_t stride = 1 << l.all->reg.stride_shift; + for (size_t i =0; i <= l.all->reg.weight_mask;i+=stride) + for (size_t k = 0; k < l.all->lda; k++) + l.total_lambda[k] += l.all->reg.weight_vector[i+k]; + } + + l.example_t++; + l.total_new.erase(); + for (size_t k = 0; k < l.all->lda; k++) + l.total_new.push_back(0.f); + + size_t batch_size = l.examples.size(); + + sort(l.sorted_features.begin(), l.sorted_features.end()); + + eta = l.all->eta * powf((float)l.example_t, - l.all->power_t); + minuseta = 1.0f - eta; + eta *= l.lda_D / batch_size; + l.decay_levels.push_back(l.decay_levels.last() + log(minuseta)); + + l.digammas.erase(); + float additional = (float)(l.all->length()) * l.lda_rho; + for (size_t i = 0; i<l.all->lda; i++) { + l.digammas.push_back(mydigamma(l.total_lambda[i] + additional)); + } + + + weight* weights = l.all->reg.weight_vector; + + size_t last_weight_index = -1; + for (index_feature* s = &l.sorted_features[0]; s <= &l.sorted_features.back(); s++) + { + if (last_weight_index == s->f.weight_index) + continue; + last_weight_index = s->f.weight_index; + float* weights_for_w = &(weights[s->f.weight_index & l.all->reg.weight_mask]); + float decay = fmin(1.0, exp(l.decay_levels.end[-2] - l.decay_levels.end[(int)(-1 - l.example_t+weights_for_w[l.all->lda])])); + float* u_for_w = weights_for_w + l.all->lda+1; + + weights_for_w[l.all->lda] = (float)l.example_t; + for (size_t k = 0; k < l.all->lda; k++) + { + weights_for_w[k] *= decay; + u_for_w[k] = weights_for_w[k] + l.lda_rho; + } + myexpdigammify_2(*l.all, u_for_w, l.digammas.begin); + } + + for (size_t d = 0; d < batch_size; d++) + { + float score = lda_loop(l, l.Elogtheta, &(l.v[d*l.all->lda]), weights, 
l.examples[d],l.all->power_t); + if (l.all->audit) + GD::print_audit_features(*l.all, *l.examples[d]); + // If the doc is empty, give it loss of 0. + if (l.doc_lengths[d] > 0) { + l.all->sd->sum_loss -= score; + l.all->sd->sum_loss_since_last_dump -= score; + } + return_simple_example(*l.all, NULL, *l.examples[d]); + } + + for (index_feature* s = &l.sorted_features[0]; s <= &l.sorted_features.back();) + { + index_feature* next = s+1; + while(next <= &l.sorted_features.back() && next->f.weight_index == s->f.weight_index) + next++; + + float* word_weights = &(weights[s->f.weight_index & l.all->reg.weight_mask]); + for (size_t k = 0; k < l.all->lda; k++) { + float new_value = minuseta*word_weights[k]; + word_weights[k] = new_value; + } + + for (; s != next; s++) { + float* v_s = &(l.v[s->document*l.all->lda]); + float* u_for_w = &weights[(s->f.weight_index & l.all->reg.weight_mask) + l.all->lda + 1]; + float c_w = eta*find_cw(l, u_for_w, v_s)*s->f.x; + for (size_t k = 0; k < l.all->lda; k++) { + float new_value = u_for_w[k]*v_s[k]*c_w; + l.total_new[k] += new_value; + word_weights[k] += new_value; + } + } + } + for (size_t k = 0; k < l.all->lda; k++) { + l.total_lambda[k] *= minuseta; + l.total_lambda[k] += l.total_new[k]; + } + + l.sorted_features.resize(0); + + l.examples.erase(); + l.doc_lengths.erase(); + } + + void learn(lda& l, base_learner& base, example& ec) + { + size_t num_ex = l.examples.size(); + l.examples.push_back(&ec); + l.doc_lengths.push_back(0); + for (unsigned char* i = ec.indices.begin; i != ec.indices.end; i++) { + feature* f = ec.atomics[*i].begin; + for (; f != ec.atomics[*i].end; f++) { + index_feature temp = {(uint32_t)num_ex, *f}; + l.sorted_features.push_back(temp); + l.doc_lengths[num_ex] += (int)f->x; + } + } + if (++num_ex == l.minibatch) + learn_batch(l); + } + + // placeholder + void predict(lda& l, base_learner& base, example& ec) + { + learn(l, base, ec); + } + + void end_pass(lda& l) + { + if (l.examples.size()) + learn_batch(l); + 
} + +void end_examples(lda& l) +{ + for (size_t i = 0; i < l.all->length(); i++) { + weight* weights_for_w = & (l.all->reg.weight_vector[i << l.all->reg.stride_shift]); + float decay = fmin(1.0, exp(l.decay_levels.last() - l.decay_levels.end[(int)(-1- l.example_t +weights_for_w[l.all->lda])])); + for (size_t k = 0; k < l.all->lda; k++) + weights_for_w[k] *= decay; + } +} + + void finish_example(vw& all, lda&, example& ec) +{} + + void finish(lda& ld) + { + ld.sorted_features.~vector<index_feature>(); + ld.Elogtheta.delete_v(); + ld.decay_levels.delete_v(); + ld.total_new.delete_v(); + ld.examples.delete_v(); + ld.total_lambda.delete_v(); + ld.doc_lengths.delete_v(); + ld.digammas.delete_v(); + ld.v.delete_v(); + } + + +base_learner* setup(vw&all) +{ + new_options(all, "Lda options") + ("lda", po::value<uint32_t>(), "Run lda with <int> topics") + ("lda_alpha", po::value<float>()->default_value(0.1f), "Prior on sparsity of per-document topic weights") + ("lda_rho", po::value<float>()->default_value(0.1f), "Prior on sparsity of topic distributions") + ("lda_D", po::value<float>()->default_value(10000.), "Number of documents") + ("lda_epsilon", po::value<float>()->default_value(0.001f), "Loop convergence threshold") + ("minibatch", po::value<size_t>()->default_value(1), "Minibatch size, for LDA"); + add_options(all); + po::variables_map& vm= all.vm; + if(!vm.count("lda")) + return NULL; + else + all.lda = vm["lda"].as<uint32_t>(); + + lda& ld = calloc_or_die<lda>(); + + ld.topics = all.lda; + ld.lda_alpha = vm["lda_alpha"].as<float>(); + ld.lda_rho = vm["lda_rho"].as<float>(); + ld.lda_D = vm["lda_D"].as<float>(); + ld.lda_epsilon = vm["lda_epsilon"].as<float>(); + ld.minibatch = vm["minibatch"].as<size_t>(); + ld.sorted_features = vector<index_feature>(); + ld.total_lambda_init = 0; + ld.all = &all; + ld.example_t = all.initial_t; + + float temp = ceilf(logf((float)(all.lda*2+1)) / logf (2.f)); + all.reg.stride_shift = (size_t)temp; + all.random_weights = true; + 
all.add_constant = false; + + *all.file_options << " --lda " << all.lda; + + if (all.eta > 1.) + { + cerr << "your learning rate is too high, setting it to 1" << endl; + all.eta = min(all.eta,1.f); + } + + if (vm.count("minibatch")) { + size_t minibatch2 = next_pow2(ld.minibatch); + all.p->ring_size = all.p->ring_size > minibatch2 ? all.p->ring_size : minibatch2; + } + + ld.v.resize(all.lda*ld.minibatch); + + ld.decay_levels.push_back(0.f); + + learner<lda>& l = init_learner(&ld, learn, 1 << all.reg.stride_shift); + l.set_predict(predict); + l.set_save_load(save_load); + l.set_finish_example(finish_example); + l.set_end_examples(end_examples); + l.set_end_pass(end_pass); + l.set_finish(finish); + + return make_base(l); +} +} diff --git a/vowpalwabbit/lda_core.h b/vowpalwabbit/lda_core.h index 2a065783..e734fcad 100644 --- a/vowpalwabbit/lda_core.h +++ b/vowpalwabbit/lda_core.h @@ -4,6 +4,4 @@ individual contributors. All rights reserved. Released under a BSD license as described in the file LICENSE. */ #pragma once -namespace LDA{ - LEARNER::base_learner* setup(vw&, po::variables_map&); -} +namespace LDA{ LEARNER::base_learner* setup(vw&); } diff --git a/vowpalwabbit/learner.h b/vowpalwabbit/learner.h index 2649429c..b832306c 100644 --- a/vowpalwabbit/learner.h +++ b/vowpalwabbit/learner.h @@ -186,7 +186,7 @@ namespace LEARNER template<class T> learner<T>& init_learner(T* dat, base_learner* base, void (*learn)(T&, base_learner&, example&), - void (*predict)(T&, base_learner&, example&), size_t ws = 1) + void (*predict)(T&, base_learner&, example&), size_t ws) { //the reduction constructor, with separate learn and predict functions learner<T>& ret = calloc_or_die<learner<T> >(); ret = *(learner<T>*)base; diff --git a/vowpalwabbit/log_multi.cc b/vowpalwabbit/log_multi.cc index 226376bd..e673b0d8 100644 --- a/vowpalwabbit/log_multi.cc +++ b/vowpalwabbit/log_multi.cc @@ -12,7 +12,6 @@ license as described in the file LICENSE.node #include "reductions.h"
#include "simple_label.h"
#include "multiclass.h"
-#include "vw.h"
using namespace std;
using namespace LEARNER;
@@ -77,12 +76,11 @@ namespace LOG_MULTI struct log_multi
{
uint32_t k;
- vw* all;
v_array<node> nodes;
- uint32_t max_predictors;
- uint32_t predictors_used;
+ size_t max_predictors;
+ size_t predictors_used;
bool progress;
uint32_t swap_resist;
@@ -199,7 +197,7 @@ namespace LOG_MULTI b.nodes.push_back(init_node());
right_child = (uint32_t)b.nodes.size();
b.nodes.push_back(init_node());
- b.nodes[current].base_predictor = b.predictors_used++;
+ b.nodes[current].base_predictor = (uint32_t)b.predictors_used++;
}
else
{
@@ -319,11 +317,10 @@ namespace LOG_MULTI void learn(log_multi& b, base_learner& base, example& ec)
{
// verify_min_dfs(b, b.nodes[0]);
-
- if (ec.l.multi.label == (uint32_t)-1 || !b.all->training || b.progress)
+ if (ec.l.multi.label == (uint32_t)-1 || b.progress)
predict(b,base,ec);
- if(b.all->training && (ec.l.multi.label != (uint32_t)-1) && !ec.test_only) //if training the tree
+ if((ec.l.multi.label != (uint32_t)-1) && !ec.test_only) //if training the tree
{
MULTICLASS::multiclass mc = ec.l.multi;
@@ -413,10 +410,10 @@ namespace LOG_MULTI if (read)
for (uint32_t j = 1; j < temp; j++)
b.nodes.push_back(init_node());
- text_len = sprintf(buff, "max_predictors = %d ",b.max_predictors);
+ text_len = sprintf(buff, "max_predictors = %ld ",b.max_predictors);
bin_text_read_write_fixed(model_file,(char*)&b.max_predictors, sizeof(b.max_predictors), "", read, buff, text_len, text);
- text_len = sprintf(buff, "predictors_used = %d ",b.predictors_used);
+ text_len = sprintf(buff, "predictors_used = %ld ",b.predictors_used);
bin_text_read_write_fixed(model_file,(char*)&b.predictors_used, sizeof(b.predictors_used), "", read, buff, text_len, text);
text_len = sprintf(buff, "progress = %d ",b.progress);
@@ -496,20 +493,25 @@ namespace LOG_MULTI }
}
- void finish_example(vw& all, log_multi&, example& ec) { MULTICLASS::finish_example(all, ec); }
-
- base_learner* setup(vw& all, po::variables_map& vm) //learner setup
+ base_learner* setup(vw& all) //learner setup
{
- log_multi& data = calloc_or_die<log_multi>();
-
- po::options_description opts("TXM Online options");
- opts.add_options()
+ new_options(all, "Logarithmic Time Multiclass options") + ("log_multi", po::value<size_t>(), "Use online tree for multiclass"); + if (missing_required(all)) return NULL; + new_options(all) ("no_progress", "disable progressive validation")
- ("swap_resistance", po::value<uint32_t>(&(data.swap_resist))->default_value(4), "higher = more resistance to swap, default=4");
-
- vm = add_options(all, opts);
-
+ ("swap_resistance", po::value<uint32_t>(), "higher = more resistance to swap, default=4"); + add_options(all);
+ + po::variables_map& vm = all.vm; +
+ log_multi& data = calloc_or_die<log_multi>(); data.k = (uint32_t)vm["log_multi"].as<size_t>();
+ data.swap_resist = 4;
+
+ if (vm.count("swap_resistance"))
+ data.swap_resist = vm["swap_resistance"].as<uint32_t>();
+
*all.file_options << " --log_multi " << data.k;
if (vm.count("no_progress"))
@@ -517,9 +519,6 @@ namespace LOG_MULTI else
data.progress = true;
- data.all = &all;
- (all.p->lp) = MULTICLASS::mc_label;
-
string loss_function = "quantile";
float loss_parameter = 0.5;
delete(all.loss);
@@ -527,10 +526,11 @@ namespace LOG_MULTI data.max_predictors = data.k - 1;
- learner<log_multi>& l = init_learner(&data, all.l, learn, predict, data.max_predictors);
+ learner<log_multi>& l = init_learner(&data, setup_base(all), learn, predict, data.max_predictors);
l.set_save_load(save_load_tree);
- l.set_finish_example(finish_example);
l.set_finish(finish);
+ l.set_finish_example(MULTICLASS::finish_example<log_multi>); + all.p->lp = MULTICLASS::mc_label; init_tree(data);
diff --git a/vowpalwabbit/log_multi.h b/vowpalwabbit/log_multi.h index 5e1ee3bf..0660334a 100644 --- a/vowpalwabbit/log_multi.h +++ b/vowpalwabbit/log_multi.h @@ -4,7 +4,4 @@ individual contributors. All rights reserved. Released under a BSD license as described in the file LICENSE. */ #pragma once -namespace LOG_MULTI -{ - LEARNER::base_learner* setup(vw& all, po::variables_map& vm); -} +namespace LOG_MULTI { LEARNER::base_learner* setup(vw& all); } diff --git a/vowpalwabbit/lrq.cc b/vowpalwabbit/lrq.cc index 268815d5..13534375 100644 --- a/vowpalwabbit/lrq.cc +++ b/vowpalwabbit/lrq.cc @@ -187,24 +187,34 @@ namespace LRQ { } } - base_learner* setup(vw& all, po::variables_map& vm) + base_learner* setup(vw& all) {//parse and set arguments + new_options(all, "Lrq options") + ("lrq", po::value<vector<string> > (), "use low rank quadratic features"); + if (missing_required(all)) return NULL; + new_options(all) + ("lrqdropout", "use dropout training for low rank quadratic features"); + add_options(all); + + if(!all.vm.count("lrq")) + return NULL; + LRQstate& lrq = calloc_or_die<LRQstate>(); size_t maxk = 0; lrq.all = &all; size_t random_seed = 0; - if (vm.count("random_seed")) random_seed = vm["random_seed"].as<size_t> (); + if (all.vm.count("random_seed")) random_seed = all.vm["random_seed"].as<size_t> (); lrq.initial_seed = lrq.seed = random_seed | 8675309; - if (vm.count("lrqdropout")) + if (all.vm.count("lrqdropout")) lrq.dropout = true; else lrq.dropout = false; *all.file_options << " --lrqdropout "; - lrq.lrpairs = vm["lrq"].as<vector<string> > (); + lrq.lrpairs = all.vm["lrq"].as<vector<string> > (); for (vector<string>::iterator i = lrq.lrpairs.begin (); i != lrq.lrpairs.end (); @@ -242,8 +252,8 @@ namespace LRQ { if(!all.quiet) cerr<<endl; - all.wpp = all.wpp * (1 + maxk); - learner<LRQstate>& l = init_learner(&lrq, all.l, predict_or_learn<true>, + all.wpp = all.wpp * (uint32_t)(1 + maxk); + learner<LRQstate>& l = init_learner(&lrq, setup_base(all), 
predict_or_learn<true>, predict_or_learn<false>, 1 + maxk); l.set_end_pass(reset_seed); diff --git a/vowpalwabbit/lrq.h b/vowpalwabbit/lrq.h index 376bd6e5..b08e24f4 100644 --- a/vowpalwabbit/lrq.h +++ b/vowpalwabbit/lrq.h @@ -1,4 +1,2 @@ #pragma once -namespace LRQ { - LEARNER::base_learner* setup(vw& all, po::variables_map& vm); -} +namespace LRQ { LEARNER::base_learner* setup(vw& all); } diff --git a/vowpalwabbit/main.cc b/vowpalwabbit/main.cc index c7f40326..95e9d2f9 100644 --- a/vowpalwabbit/main.cc +++ b/vowpalwabbit/main.cc @@ -21,11 +21,11 @@ using namespace std; int main(int argc, char *argv[]) { try { - vw *all = parse_args(argc, argv); + vw& all = parse_args(argc, argv); struct timeb t_start, t_end; ftime(&t_start); - if (!all->quiet && !all->bfgs && !all->searchstr) + if (!all.quiet && !all.bfgs && !all.searchstr) { const char * header_fmt = "%-10s %-10s %10s %11s %8s %8s %8s\n"; fprintf(stderr, header_fmt, @@ -36,67 +36,63 @@ int main(int argc, char *argv[]) cerr.precision(5); } - VW::start_parser(*all); - LEARNER::generic_driver(*all); - VW::end_parser(*all); + VW::start_parser(all); + LEARNER::generic_driver(all); + VW::end_parser(all); ftime(&t_end); double net_time = (int) (1000.0 * (t_end.time - t_start.time) + (t_end.millitm - t_start.millitm)); - if(!all->quiet && all->span_server != "") + if(!all.quiet && all.span_server != "") cerr<<"Net time taken by process = "<<net_time/(double)(1000)<<" seconds\n"; - if(all->span_server != "") { - float loss = (float)all->sd->sum_loss; - all->sd->sum_loss = (double)accumulate_scalar(*all, all->span_server, loss); - float weighted_examples = (float)all->sd->weighted_examples; - all->sd->weighted_examples = (double)accumulate_scalar(*all, all->span_server, weighted_examples); - float weighted_labels = (float)all->sd->weighted_labels; - all->sd->weighted_labels = (double)accumulate_scalar(*all, all->span_server, weighted_labels); - float weighted_unlabeled_examples = 
(float)all->sd->weighted_unlabeled_examples; - all->sd->weighted_unlabeled_examples = (double)accumulate_scalar(*all, all->span_server, weighted_unlabeled_examples); - float example_number = (float)all->sd->example_number; - all->sd->example_number = (uint64_t)accumulate_scalar(*all, all->span_server, example_number); - float total_features = (float)all->sd->total_features; - all->sd->total_features = (uint64_t)accumulate_scalar(*all, all->span_server, total_features); + if(all.span_server != "") { + float loss = (float)all.sd->sum_loss; + all.sd->sum_loss = (double)accumulate_scalar(all, all.span_server, loss); + float weighted_examples = (float)all.sd->weighted_examples; + all.sd->weighted_examples = (double)accumulate_scalar(all, all.span_server, weighted_examples); + float weighted_labels = (float)all.sd->weighted_labels; + all.sd->weighted_labels = (double)accumulate_scalar(all, all.span_server, weighted_labels); + float weighted_unlabeled_examples = (float)all.sd->weighted_unlabeled_examples; + all.sd->weighted_unlabeled_examples = (double)accumulate_scalar(all, all.span_server, weighted_unlabeled_examples); + float example_number = (float)all.sd->example_number; + all.sd->example_number = (uint64_t)accumulate_scalar(all, all.span_server, example_number); + float total_features = (float)all.sd->total_features; + all.sd->total_features = (uint64_t)accumulate_scalar(all, all.span_server, total_features); } - if (!all->quiet) + if (!all.quiet) { cerr.precision(6); cerr << endl << "finished run"; - if(all->current_pass == 0) - cerr << endl << "number of examples = " << all->sd->example_number; + if(all.current_pass == 0) + cerr << endl << "number of examples = " << all.sd->example_number; else{ - cerr << endl << "number of examples per pass = " << all->sd->example_number / all->current_pass; - cerr << endl << "passes used = " << all->current_pass; - } - cerr << endl << "weighted example sum = " << all->sd->weighted_examples; - cerr << endl << "weighted label sum 
= " << all->sd->weighted_labels; - if(all->holdout_set_off || (all->sd->holdout_best_loss == FLT_MAX)) - { - cerr << endl << "average loss = " << all->sd->sum_loss / all->sd->weighted_examples; - } - else - { - cerr << endl << "average loss = " << all->sd->holdout_best_loss << " h"; + cerr << endl << "number of examples per pass = " << all.sd->example_number / all.current_pass; + cerr << endl << "passes used = " << all.current_pass; } + cerr << endl << "weighted example sum = " << all.sd->weighted_examples; + cerr << endl << "weighted label sum = " << all.sd->weighted_labels; + if(all.holdout_set_off || (all.sd->holdout_best_loss == FLT_MAX)) + cerr << endl << "average loss = " << all.sd->sum_loss / all.sd->weighted_examples; + else + cerr << endl << "average loss = " << all.sd->holdout_best_loss << " h"; float best_constant; float best_constant_loss; - if (get_best_constant(*all, best_constant, best_constant_loss)) - { + if (get_best_constant(all, best_constant, best_constant_loss)) + { cerr << endl << "best constant = " << best_constant; if (best_constant_loss != FLT_MIN) - cerr << endl << "best constant's loss = " << best_constant_loss; - } - - cerr << endl << "total feature number = " << all->sd->total_features; - if (all->sd->queries > 0) - cerr << endl << "total queries = " << all->sd->queries << endl; + cerr << endl << "best constant's loss = " << best_constant_loss; + } + + cerr << endl << "total feature number = " << all.sd->total_features; + if (all.sd->queries > 0) + cerr << endl << "total queries = " << all.sd->queries << endl; cerr << endl; } - VW::finish(*all); + VW::finish(all); } catch (exception& e) { // vw is implemented as a library, so we use 'throw exception()' // error 'handling' everywhere. 
To reduce stderr pollution diff --git a/vowpalwabbit/memory.cc b/vowpalwabbit/memory.cc deleted file mode 100644 index 7e40bf71..00000000 --- a/vowpalwabbit/memory.cc +++ /dev/null @@ -1,8 +0,0 @@ -#include <stdlib.h> - -void free_it(void* ptr) -{ - if (ptr != NULL) - free(ptr); -} - diff --git a/vowpalwabbit/memory.h b/vowpalwabbit/memory.h index 6d67d51e..af4c0c4c 100644 --- a/vowpalwabbit/memory.h +++ b/vowpalwabbit/memory.h @@ -1,5 +1,4 @@ #pragma once - #include <stdlib.h> #include <iostream> @@ -18,9 +17,6 @@ T* calloc_or_die(size_t nmemb) } template<class T> T& calloc_or_die() -{ - return *calloc_or_die<T>(1); -} - +{ return *calloc_or_die<T>(1); } -void free_it(void* ptr); +inline void free_it(void* ptr) { if (ptr != NULL) free(ptr); } diff --git a/vowpalwabbit/mf.cc b/vowpalwabbit/mf.cc index 4e00be8d..c2e7575a 100644 --- a/vowpalwabbit/mf.cc +++ b/vowpalwabbit/mf.cc @@ -22,7 +22,7 @@ namespace MF { struct mf { vector<string> pairs; - uint32_t rank; + size_t rank; uint32_t increment; @@ -188,22 +188,23 @@ void finish(mf& o) { o.sub_predictions.delete_v(); } - -base_learner* setup(vw& all, po::variables_map& vm) { - mf* data = new mf; - - // copy global data locally - data->all = &all; - data->rank = (uint32_t)vm["new_mf"].as<size_t>(); + base_learner* setup(vw& all) { + new_options(all, "MF options") + ("new_mf", po::value<size_t>(), "rank for reduction-based matrix factorization"); + if(missing_required(all)) return NULL; + + mf& data = calloc_or_die<mf>(); + data.all = &all; + data.rank = (uint32_t)all.vm["new_mf"].as<size_t>(); // store global pairs in local data structure and clear global pairs // for eventual calls to base learner - data->pairs = all.pairs; + data.pairs = all.pairs; all.pairs.clear(); all.random_positive_weights = true; - learner<mf>& l = init_learner(data, all.l, learn, predict<false>, 2*data->rank+1); + learner<mf>& l = init_learner(&data, setup_base(all), learn, predict<false>, 2*data.rank+1); l.set_finish(finish); return 
make_base(l); } diff --git a/vowpalwabbit/mf.h b/vowpalwabbit/mf.h index 99643601..6d4be4f3 100644 --- a/vowpalwabbit/mf.h +++ b/vowpalwabbit/mf.h @@ -4,12 +4,4 @@ individual contributors. All rights reserved. Released under a BSD license as described in the file LICENSE. */ #pragma once -#include <math.h> -#include "example.h" -#include "parse_regressor.h" -#include "parser.h" -#include "gd.h" - -namespace MF{ - LEARNER::base_learner* setup(vw& all, po::variables_map& vm); -} +namespace MF{ LEARNER::base_learner* setup(vw& all); } diff --git a/vowpalwabbit/multiclass.h b/vowpalwabbit/multiclass.h index f48efbfc..14f5d443 100644 --- a/vowpalwabbit/multiclass.h +++ b/vowpalwabbit/multiclass.h @@ -20,6 +20,8 @@ namespace MULTICLASS void finish_example(vw& all, example& ec); + template <class T> void finish_example(vw& all, T&, example& ec) { finish_example(all, ec); } + inline int label_is_test(multiclass* ld) { return ld->label == (uint32_t)-1; } } diff --git a/vowpalwabbit/nn.cc b/vowpalwabbit/nn.cc index bfab6009..bcd42092 100644 --- a/vowpalwabbit/nn.cc +++ b/vowpalwabbit/nn.cc @@ -308,19 +308,20 @@ CONVERSE: // That's right, I'm using goto. So sue me. 
free (n.output_layer.atomics[nn_output_namespace].begin); } - base_learner* setup(vw& all, po::variables_map& vm) + base_learner* setup(vw& all) { - nn& n = calloc_or_die<nn>(); - n.all = &all; - - po::options_description nn_opts("NN options"); - nn_opts.add_options() + new_options(all, "Neural Network options") + ("nn", po::value<size_t>(), "Use sigmoidal feedforward network with <k> hidden units"); + if(missing_required(all)) return NULL; + new_options(all) ("inpass", "Train or test sigmoidal feedforward network with input passthrough.") ("dropout", "Train or test sigmoidal feedforward network using dropout.") ("meanfield", "Train or test sigmoidal feedforward network using mean field."); + add_options(all); - vm = add_options(all, nn_opts); - + po::variables_map& vm = all.vm; + nn& n = calloc_or_die<nn>(); + n.all = &all; //first parse for number of hidden units n.k = (uint32_t)vm["nn"].as<size_t>(); *all.file_options << " --nn " << n.k; @@ -364,8 +365,10 @@ CONVERSE: // That's right, I'm using goto. So sue me. n.xsubi = vm["random_seed"].as<size_t>(); n.save_xsubi = n.xsubi; - n.increment = all.l->increment;//Indexing of output layer is odd. - learner<nn>& l = init_learner(&n, all.l, predict_or_learn<true>, + + base_learner* base = setup_base(all); + n.increment = base->increment;//Indexing of output layer is odd. + learner<nn>& l = init_learner(&n, base, predict_or_learn<true>, predict_or_learn<false>, n.k+1); l.set_finish(finish); l.set_finish_example(finish_example); diff --git a/vowpalwabbit/nn.h b/vowpalwabbit/nn.h index 52e08f46..3b000433 100644 --- a/vowpalwabbit/nn.h +++ b/vowpalwabbit/nn.h @@ -4,10 +4,4 @@ individual contributors. All rights reserved. Released under a BSD license as described in the file LICENSE. 
*/ #pragma once -#include "global_data.h" -#include "parse_args.h" - -namespace NN -{ - LEARNER::base_learner* setup(vw& all, po::variables_map& vm); -} +namespace NN { LEARNER::base_learner* setup(vw& all); } diff --git a/vowpalwabbit/noop.cc b/vowpalwabbit/noop.cc index 0c883a8c..892480ff 100644 --- a/vowpalwabbit/noop.cc +++ b/vowpalwabbit/noop.cc @@ -9,7 +9,11 @@ license as described in the file LICENSE. namespace NOOP { void learn(char&, LEARNER::base_learner&, example&) {} - + LEARNER::base_learner* setup(vw& all) - { return &LEARNER::init_learner<char>(NULL, learn, 1); } + { + new_options(all, "Noop options") ("noop","do no learning"); + if(missing_required(all)) return NULL; + + return &LEARNER::init_learner<char>(NULL, learn, 1); } } diff --git a/vowpalwabbit/noop.h b/vowpalwabbit/noop.h index 5220e1ee..ed660870 100644 --- a/vowpalwabbit/noop.h +++ b/vowpalwabbit/noop.h @@ -4,6 +4,4 @@ individual contributors. All rights reserved. Released under a BSD license as described in the file LICENSE. */ #pragma once -namespace NOOP { - LEARNER::base_learner* setup(vw&); -} +namespace NOOP { LEARNER::base_learner* setup(vw& all); } diff --git a/vowpalwabbit/oaa.cc b/vowpalwabbit/oaa.cc index 2328b00d..185264f5 100644 --- a/vowpalwabbit/oaa.cc +++ b/vowpalwabbit/oaa.cc @@ -7,7 +7,6 @@ license as described in the file LICENSE. 
#include "multiclass.h" #include "simple_label.h" #include "reductions.h" -#include "vw.h" namespace OAA { struct oaa{ @@ -60,21 +59,23 @@ namespace OAA { o.all->print_text(o.all->raw_prediction, outputStringStream.str(), ec.tag); } - void finish_example(vw& all, oaa&, example& ec) { MULTICLASS::finish_example(all, ec); } - - LEARNER::base_learner* setup(vw& all, po::variables_map& vm) + LEARNER::base_learner* setup(vw& all) { + new_options(all, "One-against-all options") + ("oaa", po::value<size_t>(), "Use one-against-all multiclass learning with <k> labels"); + if(missing_required(all)) return NULL; + oaa& data = calloc_or_die<oaa>(); - data.k = vm["oaa"].as<size_t>(); + data.k = all.vm["oaa"].as<size_t>(); data.shouldOutput = all.raw_prediction > 0; data.all = &all; *all.file_options << " --oaa " << data.k; - all.p->lp = MULTICLASS::mc_label; - LEARNER::learner<oaa>& l = init_learner(&data, all.l, predict_or_learn<true>, - predict_or_learn<false>, data.k); - l.set_finish_example(finish_example); + LEARNER::learner<oaa>& l = init_learner(&data, setup_base(all), predict_or_learn<true>, + predict_or_learn<false>, data.k); + l.set_finish_example(MULTICLASS::finish_example<oaa>); + all.p->lp = MULTICLASS::mc_label; return make_base(l); } } diff --git a/vowpalwabbit/oaa.h b/vowpalwabbit/oaa.h index de1b08ab..2bc46649 100644 --- a/vowpalwabbit/oaa.h +++ b/vowpalwabbit/oaa.h @@ -4,5 +4,4 @@ individual contributors. All rights reserved. Released under a BSD license as described in the file LICENSE. */ #pragma once -namespace OAA -{ LEARNER::base_learner* setup(vw& all, po::variables_map& vm); } +namespace OAA { LEARNER::base_learner* setup(vw& all); } diff --git a/vowpalwabbit/parse_args.cc b/vowpalwabbit/parse_args.cc index a760261f..c8ec126d 100644 --- a/vowpalwabbit/parse_args.cc +++ b/vowpalwabbit/parse_args.cc @@ -17,6 +17,7 @@ license as described in the file LICENSE. 
#include "network.h" #include "global_data.h" #include "nn.h" +#include "gd.h" #include "cbify.h" #include "oaa.h" #include "rand48.h" @@ -177,18 +178,17 @@ void parse_affix_argument(vw&all, string str) { free(cstr); } -void parse_diagnostics(vw& all, po::variables_map& vm, int argc) +void parse_diagnostics(vw& all, int argc) { - po::options_description diag_opt("Diagnostic options"); - - diag_opt.add_options() + new_options(all, "Diagnostic options") ("version","Version information") ("audit,a", "print weights of features") ("progress,P", po::value< string >(), "Progress update frequency. int: additive, float: multiplicative") ("quiet", "Don't output disgnostics and progress updates") ("help,h","Look here: http://hunch.net/~vw/ and click on Tutorial."); - - vm = add_options(all, diag_opt); + add_options(all); + + po::variables_map& vm = all.vm; if (vm.count("version")) { /* upon direct query for version -- spit it out to stdout */ @@ -245,11 +245,9 @@ void parse_diagnostics(vw& all, po::variables_map& vm, int argc) } } -void parse_source(vw& all, po::variables_map& vm) +void parse_source(vw& all) { - po::options_description in_opt("Input options"); - - in_opt.add_options() + new_options(all, "Input options") ("data,d", po::value< string >(), "Example Set") ("daemon", "persistent daemon mode on port 26542") ("port", po::value<size_t>(),"port to listen on; use 0 to pick unused port") @@ -261,19 +259,18 @@ void parse_source(vw& all, po::variables_map& vm) ("kill_cache,k", "do not reuse existing cache: create a new one always") ("compressed", "use gzip format whenever possible. If a cache file is being created, this option creates a compressed cache file. 
A mixture of raw-text & compressed inputs are supported with autodetection.") ("no_stdin", "do not default to reading from stdin"); - - vm = add_options(all, in_opt); + add_options(all); // Be friendly: if -d was left out, treat positional param as data file po::positional_options_description p; p.add("data", -1); - vm = po::variables_map(); po::parsed_options pos = po::command_line_parser(all.args). style(po::command_line_style::default_style ^ po::command_line_style::allow_guessing). options(all.opts).positional(p).run(); - vm = po::variables_map(); - po::store(pos, vm); + all.vm = po::variables_map(); + po::store(pos, all.vm); + po::variables_map& vm = all.vm; //begin input source if (vm.count("no_stdin")) @@ -315,10 +312,9 @@ void parse_source(vw& all, po::variables_map& vm) } } -void parse_feature_tweaks(vw& all, po::variables_map& vm) +void parse_feature_tweaks(vw& all) { - po::options_description feature_opt("Feature options"); - feature_opt.add_options() + new_options(all, "Feature options") ("hash", po::value< string > (), "how to hash the features. 
Available options: strings, all") ("ignore", po::value< vector<unsigned char> >(), "ignore namespaces beginning with character <arg>") ("keep", po::value< vector<unsigned char> >(), "keep namespaces beginning with character <arg>") @@ -335,8 +331,9 @@ void parse_feature_tweaks(vw& all, po::variables_map& vm) ("q:", po::value< string >(), ": corresponds to a wildcard for all printable characters") ("cubic", po::value< vector<string> > (), "Create and use cubic features"); + add_options(all); - vm = add_options(all, feature_opt); + po::variables_map& vm = all.vm; //feature manipulation string hash_function("strings"); @@ -545,11 +542,9 @@ void parse_feature_tweaks(vw& all, po::variables_map& vm) all.add_constant = false; } -void parse_example_tweaks(vw& all, po::variables_map& vm) +void parse_example_tweaks(vw& all) { - po::options_description example_opts("Example options"); - - example_opts.add_options() + new_options(all, "Example options") ("testonly,t", "Ignore label information and just test") ("holdout_off", "no holdout data in multiple passes") ("holdout_period", po::value<uint32_t>(&(all.holdout_period)), "holdout period for test only, default 10") @@ -565,9 +560,9 @@ void parse_example_tweaks(vw& all, po::variables_map& vm) ("quantile_tau", po::value<float>()->default_value(0.5), "Parameter \\tau associated with Quantile loss. Defaults to 0.5") ("l1", po::value<float>(&(all.l1_lambda)), "l_1 lambda") ("l2", po::value<float>(&(all.l2_lambda)), "l_2 lambda"); + add_options(all); - vm = add_options(all, example_opts); - + po::variables_map& vm = all.vm; if (vm.count("testonly") || all.eta == 0.) 
{ if (!all.quiet) @@ -621,17 +616,14 @@ void parse_example_tweaks(vw& all, po::variables_map& vm) } } -void parse_output_preds(vw& all, po::variables_map& vm) +void parse_output_preds(vw& all) { - po::options_description out_opt("Output options"); - - out_opt.add_options() + new_options(all, "Output options") ("predictions,p", po::value< string >(), "File to output predictions to") - ("raw_predictions,r", po::value< string >(), "File to output unnormalized predictions to") - ; - - vm = add_options(all, out_opt); + ("raw_predictions,r", po::value< string >(), "File to output unnormalized predictions to"); + add_options(all); + po::variables_map& vm = all.vm; if (vm.count("predictions")) { if (!all.quiet) cerr << "predictions = " << vm["predictions"].as< string >() << endl; @@ -676,21 +668,19 @@ void parse_output_preds(vw& all, po::variables_map& vm) } } -void parse_output_model(vw& all, po::variables_map& vm) +void parse_output_model(vw& all) { - po::options_description output_model("Output model"); - - output_model.add_options() + new_options(all, "Output model") ("final_regressor,f", po::value< string >(), "Final regressor") ("readable_model", po::value< string >(), "Output human-readable final regressor with numeric features") ("invert_hash", po::value< string >(), "Output human-readable final regressor with feature names. 
Computationally expensive.") ("save_resume", "save extra state so learning can be resumed later with new data") ("save_per_pass", "Save the model after every pass over data") ("output_feature_regularizer_binary", po::value< string >(&(all.per_feature_regularizer_output)), "Per feature regularization output file") - ("output_feature_regularizer_text", po::value< string >(&(all.per_feature_regularizer_text)), "Per feature regularization output file, in text"); - - vm = add_options(all, output_model); + ("output_feature_regularizer_text", po::value< string >(&(all.per_feature_regularizer_text)), "Per feature regularization output file, in text"); + add_options(all); + po::variables_map& vm = all.vm; if (vm.count("final_regressor")) { all.final_regressor_name = vm["final_regressor"].as<string>(); if (!all.quiet) @@ -714,68 +704,22 @@ void parse_output_model(vw& all, po::variables_map& vm) all.save_resume = true; } -void parse_base_algorithm(vw& all, po::variables_map& vm) -{ - //base learning algorithm. 
- po::options_description base_opt("base algorithms (these are exclusive)"); - - base_opt.add_options() - ("sgd", "use regular stochastic gradient descent update.") - ("ftrl", "use ftrl-proximal optimization") - ("adaptive", "use adaptive, individual learning rates.") - ("invariant", "use safe/importance aware updates.") - ("normalized", "use per feature normalized updates") - ("exact_adaptive_norm", "use current default invariant normalized adaptive update rule") - ("bfgs", "use bfgs optimization") - ("lda", po::value<uint32_t>(&(all.lda)), "Run lda with <int> topics") - ("rank", po::value<uint32_t>(&(all.rank)), "rank for matrix factorization.") - ("noop","do no learning") - ("print","print examples") - ("ksvm", "kernel svm") - ("sendto", po::value< vector<string> >(), "send examples to <host>"); - - vm = add_options(all, base_opt); - - if (vm.count("bfgs") || vm.count("conjugate_gradient")) - all.l = BFGS::setup(all, vm); - else if (vm.count("lda")) - all.l = LDA::setup(all, vm); - else if (vm.count("ftrl")) - all.l = FTRL::setup(all, vm); - else if (vm.count("noop")) - all.l = NOOP::setup(all); - else if (vm.count("print")) - all.l = PRINT::setup(all); - else if (all.rank > 0) - all.l = GDMF::setup(all, vm); - else if (vm.count("sendto")) - all.l = SENDER::setup(all, vm, all.pairs); - else if (vm.count("ksvm")) { - all.l = KSVM::setup(all, vm); - } - else - { - all.l = GD::setup(all, vm); - all.scorer = all.l; - } -} - -void load_input_model(vw& all, po::variables_map& vm, io_buf& io_temp) +void load_input_model(vw& all, io_buf& io_temp) { // Need to see if we have to load feature mask first or second. 
// -i and -mask are from same file, load -i file first so mask can use it - if (vm.count("feature_mask") && vm.count("initial_regressor") - && vm["feature_mask"].as<string>() == vm["initial_regressor"].as< vector<string> >()[0]) { + if (all.vm.count("feature_mask") && all.vm.count("initial_regressor") + && all.vm["feature_mask"].as<string>() == all.vm["initial_regressor"].as< vector<string> >()[0]) { // load rest of regressor all.l->save_load(io_temp, true, false); io_temp.close_file(); // set the mask, which will reuse -i file we just loaded - parse_mask_regressor_args(all, vm); + parse_mask_regressor_args(all); } else { // load mask first - parse_mask_regressor_args(all, vm); + parse_mask_regressor_args(all); // load rest of regressor all.l->save_load(io_temp, true, false); @@ -783,166 +727,51 @@ void load_input_model(vw& all, po::variables_map& vm, io_buf& io_temp) } } -void parse_scorer_reductions(vw& all, po::variables_map& vm) +LEARNER::base_learner* setup_base(vw& all) { - po::options_description score_mod_opt("Score modifying options (can be combined)"); - - score_mod_opt.add_options() - ("nn", po::value<size_t>(), "Use sigmoidal feedforward network with <k> hidden units") - ("new_mf", po::value<size_t>(), "rank for reduction-based matrix factorization") - ("autolink", po::value<size_t>(), "create link function with polynomial d") - ("lrq", po::value<vector<string> > (), "use low rank quadratic features") - ("lrqdropout", "use dropout training for low rank quadratic features") - ("stage_poly", "use stagewise polynomial feature learning") - ("active", "enable active learning"); - - vm = add_options(all, score_mod_opt); - - if (vm.count("active")) - all.l = ACTIVE::setup(all,vm); - - if(vm.count("nn")) - all.l = NN::setup(all, vm); - - if (vm.count("new_mf")) - all.l = MF::setup(all, vm); - - if(vm.count("autolink")) - all.l = ALINK::setup(all, vm); - - if (vm.count("lrq")) - all.l = LRQ::setup(all, vm); - - if (vm.count("stage_poly")) - all.l = 
StagewisePoly::setup(all, vm); - - all.l = Scorer::setup(all, vm); + LEARNER::base_learner* ret = all.reduction_stack.pop()(all); + if (ret == NULL) + return setup_base(all); + else + return ret; } -LEARNER::base_learner* exclusive_setup(vw& all, po::variables_map& vm, bool& score_consumer, LEARNER::base_learner* (*setup)(vw&, po::variables_map&)) +void parse_reductions(vw& all) { - if (score_consumer) { cerr << "error: cannot specify multiple direct score consumers" << endl; throw exception(); } - score_consumer = true; - return setup(all, vm); -} - -void parse_score_users(vw& all, po::variables_map& vm, bool& got_cs) -{ - po::options_description multiclass_opt("Score user options (these are exclusive)"); - multiclass_opt.add_options() - ("top", po::value<size_t>(), "top k recommendation") - ("binary", "report loss as binary classification on -1,1") - ("oaa", po::value<size_t>(), "Use one-against-all multiclass learning with <k> labels") - ("ect", po::value<size_t>(), "Use error correcting tournament with <k> labels") - ("log_multi", po::value<size_t>(), "Use online tree for multiclass") - ("csoaa", po::value<size_t>(), "Use one-against-all multiclass learning with <k> costs") - ("csoaa_ldf", po::value<string>(), "Use one-against-all multiclass learning with label dependent features. Specify singleline or multiline.") - ("wap_ldf", po::value<string>(), "Use weighted all-pairs multiclass learning with label dependent features. 
Specify singleline or multiline.") - ; - - vm = add_options(all, multiclass_opt); - bool score_consumer = false; - - if(vm.count("top")) - all.l = exclusive_setup(all, vm, score_consumer, TOPK::setup); - - if (vm.count("binary")) - all.l = exclusive_setup(all, vm, score_consumer, BINARY::setup); - - if (vm.count("oaa")) - all.l = exclusive_setup(all, vm, score_consumer, OAA::setup); - - if (vm.count("ect")) - all.l = exclusive_setup(all, vm, score_consumer, ECT::setup); - - if(vm.count("csoaa")) { - all.l = exclusive_setup(all, vm, score_consumer, CSOAA::setup); - all.cost_sensitive = all.l; - got_cs = true; - } - - if(vm.count("log_multi")){ - all.l = exclusive_setup(all, vm, score_consumer, LOG_MULTI::setup); - } - - if(vm.count("csoaa_ldf") || vm.count("csoaa_ldf")) { - all.l = exclusive_setup(all, vm, score_consumer, CSOAA_AND_WAP_LDF::setup); - all.cost_sensitive = all.l; - got_cs = true; - } - - if(vm.count("wap_ldf") || vm.count("wap_ldf") ) { - all.l = exclusive_setup(all, vm, score_consumer, CSOAA_AND_WAP_LDF::setup); - all.cost_sensitive = all.l; - got_cs = true; - } -} - -void parse_cb(vw& all, po::variables_map& vm, bool& got_cs, bool& got_cb) -{ - po::options_description cb_opts("Contextual Bandit options"); - - cb_opts.add_options() - ("cb", po::value<size_t>(), "Use contextual bandit learning with <k> costs") - ("cbify", po::value<size_t>(), "Convert multiclass on <k> classes into a contextual bandit problem and solve"); - - vm = add_options(all,cb_opts); - - if( vm.count("cb")) - { - if(!got_cs) { - if( vm.count("cb") ) vm.insert(pair<string,po::variable_value>(string("csoaa"),vm["cb"])); - else vm.insert(pair<string,po::variable_value>(string("csoaa"),vm["cb"])); - - all.l = CSOAA::setup(all, vm); // default to CSOAA unless wap is specified - all.cost_sensitive = all.l; - got_cs = true; - } - - all.l = CB_ALGS::setup(all, vm); - got_cb = true; - } - - if (vm.count("cbify")) - { - if(!got_cs) { - 
vm.insert(pair<string,po::variable_value>(string("csoaa"),vm["cbify"])); - - all.l = CSOAA::setup(all, vm); // default to CSOAA unless wap is specified - all.cost_sensitive = all.l; - got_cs = true; - } - - if (!got_cb) { - vm.insert(pair<string,po::variable_value>(string("cb"),vm["cbify"])); - all.l = CB_ALGS::setup(all, vm); - got_cb = true; - } - - all.l = CBIFY::setup(all, vm); - } -} - -void parse_search(vw& all, po::variables_map& vm, bool& got_cs, bool& got_cb) -{ - po::options_description search_opts("Search"); - - search_opts.add_options() - ("search", po::value<size_t>(), "use search-based structured prediction, argument=maximum action id or 0 for LDF"); - - vm = add_options(all,search_opts); - - if (vm.count("search")) { - if (!got_cs && !got_cb) { - if( vm.count("search") ) vm.insert(pair<string,po::variable_value>(string("csoaa"),vm["search"])); - else vm.insert(pair<string,po::variable_value>(string("csoaa"),vm["search"])); - - all.l = CSOAA::setup(all, vm); // default to CSOAA unless others have been specified - all.cost_sensitive = all.l; - got_cs = true; - } - all.l = Search::setup(all, vm); - } + //Base algorithms + all.reduction_stack.push_back(GD::setup); + all.reduction_stack.push_back(KSVM::setup); + all.reduction_stack.push_back(FTRL::setup); + all.reduction_stack.push_back(SENDER::setup); + all.reduction_stack.push_back(GDMF::setup); + all.reduction_stack.push_back(PRINT::setup); + all.reduction_stack.push_back(NOOP::setup); + all.reduction_stack.push_back(LDA::setup); + all.reduction_stack.push_back(BFGS::setup); + + //Score Users + all.reduction_stack.push_back(ACTIVE::setup); + all.reduction_stack.push_back(NN::setup); + all.reduction_stack.push_back(MF::setup); + all.reduction_stack.push_back(ALINK::setup); + all.reduction_stack.push_back(LRQ::setup); + all.reduction_stack.push_back(StagewisePoly::setup); + all.reduction_stack.push_back(Scorer::setup); + + //Reductions + all.reduction_stack.push_back(BINARY::setup); + 
all.reduction_stack.push_back(TOPK::setup); + all.reduction_stack.push_back(OAA::setup); + all.reduction_stack.push_back(ECT::setup); + all.reduction_stack.push_back(LOG_MULTI::setup); + all.reduction_stack.push_back(CSOAA::setup); + all.reduction_stack.push_back(CSOAA_AND_WAP_LDF::setup); + all.reduction_stack.push_back(CB_ALGS::setup); + all.reduction_stack.push_back(CBIFY::setup); + all.reduction_stack.push_back(Search::setup); + all.reduction_stack.push_back(BS::setup); + + all.l = setup_base(all); } void add_to_args(vw& all, int argc, char* argv[]) @@ -951,143 +780,107 @@ void add_to_args(vw& all, int argc, char* argv[]) all.args.push_back(string(argv[i])); } -vw* parse_args(int argc, char *argv[]) +vw& parse_args(int argc, char *argv[]) { - vw* all = new vw(); + vw& all = *(new vw()); - add_to_args(*all, argc, argv); + add_to_args(all, argc, argv); size_t random_seed = 0; - all->program_name = argv[0]; + all.program_name = argv[0]; - po::options_description desc("VW options"); - - desc.add_options() + new_options(all, "VW options") ("random_seed", po::value<size_t>(&random_seed), "seed random number generator") - ("ring_size", po::value<size_t>(&(all->p->ring_size)), "size of example ring"); - - po::options_description update_opt("Update options"); + ("ring_size", po::value<size_t>(&(all.p->ring_size)), "size of example ring"); + add_options(all); - update_opt.add_options() - ("learning_rate,l", po::value<float>(&(all->eta)), "Set learning rate") - ("power_t", po::value<float>(&(all->power_t)), "t power value") - ("decay_learning_rate", po::value<float>(&(all->eta_decay_rate)), + new_options(all, "Update options") + ("learning_rate,l", po::value<float>(&(all.eta)), "Set learning rate") + ("power_t", po::value<float>(&(all.power_t)), "t power value") + ("decay_learning_rate", po::value<float>(&(all.eta_decay_rate)), "Set Decay factor for learning_rate between passes") - ("initial_t", po::value<double>(&((all->sd->t))), "initial t value") - ("feature_mask", 
po::value< string >(), "Use existing regressor to determine which parameters may be updated. If no initial_regressor given, also used for initial weights.") - ; - - po::options_description weight_opt("Weight options"); + ("initial_t", po::value<double>(&((all.sd->t))), "initial t value") + ("feature_mask", po::value< string >(), "Use existing regressor to determine which parameters may be updated. If no initial_regressor given, also used for initial weights."); + add_options(all); - weight_opt.add_options() + new_options(all, "Weight options") ("initial_regressor,i", po::value< vector<string> >(), "Initial regressor(s)") - ("initial_weight", po::value<float>(&(all->initial_weight)), "Set all weights to an initial value of 1.") - ("random_weights", po::value<bool>(&(all->random_weights)), "make initial weights random") - ("input_feature_regularizer", po::value< string >(&(all->per_feature_regularizer_input)), "Per feature regularization input file") - ; - - po::options_description cluster_opt("Parallelization options"); - cluster_opt.add_options() - ("span_server", po::value<string>(&(all->span_server)), "Location of server for setting up spanning tree") - ("unique_id", po::value<size_t>(&(all->unique_id)),"unique id used for cluster parallel jobs") - ("total", po::value<size_t>(&(all->total)),"total number of nodes used in cluster parallel job") - ("node", po::value<size_t>(&(all->node)),"node number in cluster parallel job") - ; - - po::options_description other_opt("Other options"); - other_opt.add_options() - ("bootstrap,B", po::value<size_t>(), "bootstrap mode with k rounds by online importance resampling") - ; - - desc.add(update_opt) - .add(weight_opt) - .add(cluster_opt) - .add(other_opt); - - po::variables_map vm = add_options(*all, desc); - + ("initial_weight", po::value<float>(&(all.initial_weight)), "Set all weights to an initial value of 1.") + ("random_weights", po::value<bool>(&(all.random_weights)), "make initial weights random") + 
("input_feature_regularizer", po::value< string >(&(all.per_feature_regularizer_input)), "Per feature regularization input file"); + add_options(all); + + new_options(all, "Parallelization options") + ("span_server", po::value<string>(&(all.span_server)), "Location of server for setting up spanning tree") + ("unique_id", po::value<size_t>(&(all.unique_id)),"unique id used for cluster parallel jobs") + ("total", po::value<size_t>(&(all.total)),"total number of nodes used in cluster parallel job") + ("node", po::value<size_t>(&(all.node)),"node number in cluster parallel job"); + add_options(all); + + po::variables_map& vm = all.vm; msrand48(random_seed); + parse_diagnostics(all, argc); - parse_diagnostics(*all, vm, argc); - - all->sd->weighted_unlabeled_examples = all->sd->t; - all->initial_t = (float)all->sd->t; + all.sd->weighted_unlabeled_examples = all.sd->t; + all.initial_t = (float)all.sd->t; //Input regressor header io_buf io_temp; - parse_regressor_args(*all, vm, io_temp); + parse_regressor_args(all, vm, io_temp); int temp_argc = 0; - char** temp_argv = VW::get_argv_from_string(all->file_options->str(), temp_argc); - add_to_args(*all, temp_argc, temp_argv); + char** temp_argv = VW::get_argv_from_string(all.file_options->str(), temp_argc); + add_to_args(all, temp_argc, temp_argv); for (int i = 0; i < temp_argc; i++) free(temp_argv[i]); free(temp_argv); - po::parsed_options pos = po::command_line_parser(all->args). + po::parsed_options pos = po::command_line_parser(all.args). style(po::command_line_style::default_style ^ po::command_line_style::allow_guessing). 
- options(all->opts).allow_unregistered().run(); + options(all.opts).allow_unregistered().run(); vm = po::variables_map(); po::store(pos, vm); po::notify(vm); - all->file_options->str(""); + all.file_options->str(""); - parse_feature_tweaks(*all, vm); //feature tweaks + parse_feature_tweaks(all); //feature tweaks - parse_example_tweaks(*all, vm); //example manipulation + parse_example_tweaks(all); //example manipulation - parse_output_model(*all, vm); + parse_output_model(all); - parse_base_algorithm(*all, vm); + parse_reductions(all); - if (!all->quiet) + if (!all.quiet) { - cerr << "Num weight bits = " << all->num_bits << endl; - cerr << "learning rate = " << all->eta << endl; - cerr << "initial_t = " << all->sd->t << endl; - cerr << "power_t = " << all->power_t << endl; - if (all->numpasses > 1) - cerr << "decay_learning_rate = " << all->eta_decay_rate << endl; - if (all->rank > 0) - cerr << "rank = " << all->rank << endl; + cerr << "Num weight bits = " << all.num_bits << endl; + cerr << "learning rate = " << all.eta << endl; + cerr << "initial_t = " << all.sd->t << endl; + cerr << "power_t = " << all.power_t << endl; + if (all.numpasses > 1) + cerr << "decay_learning_rate = " << all.eta_decay_rate << endl; } - parse_output_preds(*all, vm); - - parse_scorer_reductions(*all, vm); - - bool got_cs = false; - - parse_score_users(*all, vm, got_cs); - - bool got_cb = false; - - parse_cb(*all, vm, got_cs, got_cb); - - parse_search(*all, vm, got_cs, got_cb); - - - if(vm.count("bootstrap")) - all->l = BS::setup(*all, vm); + parse_output_preds(all); - load_input_model(*all, vm, io_temp); + load_input_model(all, io_temp); - parse_source(*all, vm); + parse_source(all); - enable_sources(*all, vm, all->quiet,all->numpasses); + enable_sources(all, all.quiet, all.numpasses); // force wpp to be a power of 2 to avoid 32-bit overflow uint32_t i = 0; - size_t params_per_problem = all->l->increment; + size_t params_per_problem = all.l->increment; while (params_per_problem > 
(uint32_t)(1 << i)) i++; - all->wpp = (1 << i) >> all->reg.stride_shift; + all.wpp = (1 << i) >> all.reg.stride_shift; if (vm.count("help")) { /* upon direct query for help -- spit it out to stdout */ - cout << "\n" << all->opts << "\n"; + cout << "\n" << all.opts << "\n"; exit(0); } @@ -1152,15 +945,15 @@ namespace VW { s += " --no_stdin"; char** argv = get_argv_from_string(s,argc); - vw* all = parse_args(argc, argv); + vw& all = parse_args(argc, argv); - initialize_parser_datastructures(*all); + initialize_parser_datastructures(all); for(int i = 0; i < argc; i++) free(argv[i]); - free (argv); + free(argv); - return all; + return &all; } void delete_dictionary_entry(substring ss, v_array<feature>*A) { @@ -1182,6 +975,7 @@ namespace VW { all.p->parse_name.delete_v(); free(all.p); free(all.sd); + all.reduction_stack.delete_v(); delete all.file_options; for (size_t i = 0; i < all.final_prediction_sink.size(); i++) if (all.final_prediction_sink[i] != 1) diff --git a/vowpalwabbit/parse_args.h b/vowpalwabbit/parse_args.h index 9e16d5bc..6d150382 100644 --- a/vowpalwabbit/parse_args.h +++ b/vowpalwabbit/parse_args.h @@ -6,4 +6,5 @@ license as described in the file LICENSE. 
#pragma once #include "global_data.h" -vw* parse_args(int argc, char *argv[]); +vw& parse_args(int argc, char *argv[]); +LEARNER::base_learner* setup_base(vw& all); diff --git a/vowpalwabbit/parse_regressor.cc b/vowpalwabbit/parse_regressor.cc index def3a8de..7cb6a21b 100644 --- a/vowpalwabbit/parse_regressor.cc +++ b/vowpalwabbit/parse_regressor.cc @@ -164,11 +164,6 @@ void save_load_header(vw& all, io_buf& model_file, bool read, bool text) "", read, "\n",1, text); - text_len = sprintf(buff, "rank:%d\n", (int)all.rank); - bin_text_read_write_fixed(model_file,(char*)&all.rank, sizeof(all.rank), - "", read, - buff,text_len, text); - text_len = sprintf(buff, "lda:%d\n", (int)all.lda); bin_text_read_write_fixed(model_file,(char*)&all.lda, sizeof(all.lda), "", read, @@ -312,19 +307,19 @@ void parse_regressor_args(vw& all, po::variables_map& vm, io_buf& io_temp) save_load_header(all, io_temp, true, false); } -void parse_mask_regressor_args(vw& all, po::variables_map& vm){ - +void parse_mask_regressor_args(vw& all) +{ + po::variables_map& vm = all.vm; if (vm.count("feature_mask")) { size_t length = ((size_t)1) << all.num_bits; string mask_filename = vm["feature_mask"].as<string>(); if (vm.count("initial_regressor")){ vector<string> init_filename = vm["initial_regressor"].as< vector<string> >(); if(mask_filename == init_filename[0]){//-i and -mask are from same file, just generate mask - return; } } - + //all other cases, including from different file, or -i does not exist, need to read in the mask file io_buf io_temp_mask; io_temp_mask.open_file(mask_filename.c_str(), false, io_buf::READ); diff --git a/vowpalwabbit/parse_regressor.h b/vowpalwabbit/parse_regressor.h index b76b5cdc..069dd6a2 100644 --- a/vowpalwabbit/parse_regressor.h +++ b/vowpalwabbit/parse_regressor.h @@ -20,4 +20,4 @@ void initialize_regressor(vw& all); void save_predictor(vw& all, std::string reg_name, size_t current_pass); void save_load_header(vw& all, io_buf& model_file, bool read, bool text); 
-void parse_mask_regressor_args(vw& all, po::variables_map& vm); +void parse_mask_regressor_args(vw& all); diff --git a/vowpalwabbit/parser.cc b/vowpalwabbit/parser.cc index 44055bab..c71f2956 100644 --- a/vowpalwabbit/parser.cc +++ b/vowpalwabbit/parser.cc @@ -405,10 +405,10 @@ void parse_cache(vw& all, po::variables_map &vm, string source, # define MAP_ANONYMOUS MAP_ANON #endif -void enable_sources(vw& all, po::variables_map& vm, bool quiet, size_t passes) +void enable_sources(vw& all, bool quiet, size_t passes) { all.p->input->current = 0; - parse_cache(all, vm, all.data_filename, quiet); + parse_cache(all, all.vm, all.data_filename, quiet); if (all.daemon || all.active) { @@ -431,8 +431,8 @@ void enable_sources(vw& all, po::variables_map& vm, bool quiet, size_t passes) address.sin_family = AF_INET; address.sin_addr.s_addr = htonl(INADDR_ANY); short unsigned int port = 26542; - if (vm.count("port")) - port = (uint16_t)vm["port"].as<size_t>(); + if (all.vm.count("port")) + port = (uint16_t)all.vm["port"].as<size_t>(); address.sin_port = htons(port); // attempt to bind to socket @@ -449,7 +449,7 @@ void enable_sources(vw& all, po::variables_map& vm, bool quiet, size_t passes) } // write port file - if (vm.count("port_file")) + if (all.vm.count("port_file")) { socklen_t address_size = sizeof(address); if (getsockname(all.p->bound_sock, (sockaddr*)&address, &address_size) < 0) @@ -457,7 +457,7 @@ void enable_sources(vw& all, po::variables_map& vm, bool quiet, size_t passes) cerr << "getsockname: " << strerror(errno) << endl; } ofstream port_file; - port_file.open(vm["port_file"].as<string>().c_str()); + port_file.open(all.vm["port_file"].as<string>().c_str()); if (!port_file.is_open()) { cerr << "error writing port file" << endl; @@ -474,10 +474,10 @@ void enable_sources(vw& all, po::variables_map& vm, bool quiet, size_t passes) throw exception(); } // write pid file - if (vm.count("pid_file")) + if (all.vm.count("pid_file")) { ofstream pid_file; - 
pid_file.open(vm["pid_file"].as<string>().c_str()); + pid_file.open(all.vm["pid_file"].as<string>().c_str()); if (!pid_file.is_open()) { cerr << "error writing pid file" << endl; @@ -597,7 +597,7 @@ void enable_sources(vw& all, po::variables_map& vm, bool quiet, size_t passes) } all.p->resettable = all.p->write_cache || all.daemon; } - else // was: else if (vm.count("data")) + else { if (all.p->input->files.size() > 0) { @@ -840,37 +840,22 @@ void setup_example(vw& all, example* ae) ae->total_sum_feat_sq += ae->sum_feat_sq[*i]; } - if (all.rank == 0) { - for (vector<string>::iterator i = all.pairs.begin(); i != all.pairs.end();i++) - { - ae->num_features - += ae->atomics[(int)(*i)[0]].size() - *ae->atomics[(int)(*i)[1]].size(); - ae->total_sum_feat_sq += ae->sum_feat_sq[(int)(*i)[0]]*ae->sum_feat_sq[(int)(*i)[1]]; - } - - for (vector<string>::iterator i = all.triples.begin(); i != all.triples.end();i++) - { - ae->num_features - += ae->atomics[(int)(*i)[0]].size() - *ae->atomics[(int)(*i)[1]].size() - *ae->atomics[(int)(*i)[2]].size(); - ae->total_sum_feat_sq += ae->sum_feat_sq[(int)(*i)[0]] * ae->sum_feat_sq[(int)(*i)[1]] * ae->sum_feat_sq[(int)(*i)[2]]; - } - - } else { - for (vector<string>::iterator i = all.pairs.begin(); i != all.pairs.end();i++) - { - ae->num_features += ae->atomics[(int)(*i)[0]].size() * all.rank; - ae->num_features += ae->atomics[(int)(*i)[1]].size() * all.rank; - } - for (vector<string>::iterator i = all.triples.begin(); i != all.triples.end();i++) - { - ae->num_features += ae->atomics[(int)(*i)[0]].size() * all.rank; - ae->num_features += ae->atomics[(int)(*i)[1]].size() * all.rank; - ae->num_features += ae->atomics[(int)(*i)[2]].size() * all.rank; - } - } + for (vector<string>::iterator i = all.pairs.begin(); i != all.pairs.end();i++) + { + ae->num_features + += ae->atomics[(int)(*i)[0]].size() + *ae->atomics[(int)(*i)[1]].size(); + ae->total_sum_feat_sq += ae->sum_feat_sq[(int)(*i)[0]]*ae->sum_feat_sq[(int)(*i)[1]]; + } + + for 
(vector<string>::iterator i = all.triples.begin(); i != all.triples.end();i++) + { + ae->num_features + += ae->atomics[(int)(*i)[0]].size() + *ae->atomics[(int)(*i)[1]].size() + *ae->atomics[(int)(*i)[2]].size(); + ae->total_sum_feat_sq += ae->sum_feat_sq[(int)(*i)[0]] * ae->sum_feat_sq[(int)(*i)[1]] * ae->sum_feat_sq[(int)(*i)[2]]; + } } } diff --git a/vowpalwabbit/parser.h b/vowpalwabbit/parser.h index c2d271d1..1c5b8f50 100644 --- a/vowpalwabbit/parser.h +++ b/vowpalwabbit/parser.h @@ -58,7 +58,7 @@ struct parser { parser* new_parser(); -void enable_sources(vw& all, po::variables_map& vm, bool quiet, size_t passes); +void enable_sources(vw& all, bool quiet, size_t passes); bool examples_to_finish(); diff --git a/vowpalwabbit/print.cc b/vowpalwabbit/print.cc index d0dc2765..321a8710 100644 --- a/vowpalwabbit/print.cc +++ b/vowpalwabbit/print.cc @@ -40,11 +40,15 @@ namespace PRINT GD::foreach_feature<vw, print_feature>(*(p.all), ec, *p.all); cout << endl; } - + LEARNER::base_learner* setup(vw& all) { + new_options(all, "Print options") ("print","print examples"); + if(missing_required(all)) return NULL; + print& p = calloc_or_die<print>(); p.all = &all; + size_t length = ((size_t)1) << all.num_bits; all.reg.weight_mask = (length << all.reg.stride_shift) - 1; all.reg.stride_shift = 0; diff --git a/vowpalwabbit/print.h b/vowpalwabbit/print.h index b6a771ed..affd09e8 100644 --- a/vowpalwabbit/print.h +++ b/vowpalwabbit/print.h @@ -4,6 +4,4 @@ individual contributors. All rights reserved. Released under a BSD license as described in the file LICENSE. 
*/ #pragma once -namespace PRINT { - LEARNER::base_learner* setup(vw& all); -} +namespace PRINT { LEARNER::base_learner* setup(vw& all); } diff --git a/vowpalwabbit/reductions.h b/vowpalwabbit/reductions.h index a1c18ecf..35571cb2 100644 --- a/vowpalwabbit/reductions.h +++ b/vowpalwabbit/reductions.h @@ -13,3 +13,4 @@ namespace po = boost::program_options; #include "learner.h" // for core reduction definition #include "global_data.h" // for vw datastructure #include "memory.h" +#include "parse_args.h" diff --git a/vowpalwabbit/scorer.cc b/vowpalwabbit/scorer.cc index 50645ed8..826f8e3b 100644 --- a/vowpalwabbit/scorer.cc +++ b/vowpalwabbit/scorer.cc @@ -31,33 +31,31 @@ namespace Scorer { float id(float in) { return in; } - LEARNER::base_learner* setup(vw& all, po::variables_map& vm) + LEARNER::base_learner* setup(vw& all) { + new_options(all, "Link options") + ("link", po::value<string>()->default_value("identity"), "Specify the link function: identity, logistic or glf1"); + add_options(all); + po::variables_map& vm = all.vm; scorer& s = calloc_or_die<scorer>(); s.all = &all; - po::options_description link_opts("Link options"); - - link_opts.add_options() - ("link", po::value<string>()->default_value("identity"), "Specify the link function: identity, logistic or glf1"); - - vm = add_options(all, link_opts); - + LEARNER::base_learner* base = setup_base(all); LEARNER::learner<scorer>* l; string link = vm["link"].as<string>(); if (!vm.count("link") || link.compare("identity") == 0) - l = &init_learner(&s, all.l, predict_or_learn<true, id>, predict_or_learn<false, id>); + l = &init_learner(&s, base, predict_or_learn<true, id>, predict_or_learn<false, id>); else if (link.compare("logistic") == 0) { *all.file_options << " --link=logistic "; - l = &init_learner(&s, all.l, predict_or_learn<true, logistic>, + l = &init_learner(&s, base, predict_or_learn<true, logistic>, predict_or_learn<false, logistic>); } else if (link.compare("glf1") == 0) { *all.file_options << " 
--link=glf1 "; - l = &init_learner(&s, all.l, predict_or_learn<true, glf1>, + l = &init_learner(&s, base, predict_or_learn<true, glf1>, predict_or_learn<false, glf1>); } else @@ -65,6 +63,8 @@ namespace Scorer { cerr << "Unknown link function: " << link << endl; throw exception(); } - return make_base(*l); + all.scorer = make_base(*l); + + return all.scorer; } } diff --git a/vowpalwabbit/scorer.h b/vowpalwabbit/scorer.h index 2d0ec294..efd95e9a 100644 --- a/vowpalwabbit/scorer.h +++ b/vowpalwabbit/scorer.h @@ -1,4 +1,2 @@ #pragma once -namespace Scorer { - LEARNER::base_learner* setup(vw& all, po::variables_map& vm); -} +namespace Scorer { LEARNER::base_learner* setup(vw& all); } diff --git a/vowpalwabbit/search.cc b/vowpalwabbit/search.cc index fb14e65c..74444206 100644 --- a/vowpalwabbit/search.cc +++ b/vowpalwabbit/search.cc @@ -12,9 +12,8 @@ license as described in the file LICENSE. #include "rand48.h" #include "cost_sensitive.h" #include "multiclass.h" -#include "memory.h" #include "constant.h" -#include "example.h" +#include "reductions.h" #include "cb.h" #include "gd.h" // for GD::foreach_feature #include <math.h> @@ -1669,14 +1668,14 @@ namespace Search { ret = false; } - void handle_condition_options(vw& vw, auto_condition_settings& acset, po::variables_map& vm) { - po::options_description condition_options("Search Auto-conditioning Options"); - condition_options.add_options() - ("search_max_bias_ngram_length", po::value<size_t>(), "add a \"bias\" feature for each ngram up to and including this length. eg., if it's 1 (default), then you get a single feature for each conditional") - ("search_max_quad_ngram_length", po::value<size_t>(), "add bias *times* input features for each ngram up to and including this length (def: 0)") - ("search_condition_feature_value", po::value<float> (), "how much weight should the conditional features get? 
(def: 1.)"); + void handle_condition_options(vw& vw, auto_condition_settings& acset) { + new_options(vw, "Search Auto-conditioning Options") + ("search_max_bias_ngram_length", po::value<size_t>(), "add a \"bias\" feature for each ngram up to and including this length. eg., if it's 1 (default), then you get a single feature for each conditional") + ("search_max_quad_ngram_length", po::value<size_t>(), "add bias *times* input features for each ngram up to and including this length (def: 0)") + ("search_condition_feature_value", po::value<float> (), "how much weight should the conditional features get? (def: 1.)"); + add_options(vw); - vm = add_options(vw, condition_options); + po::variables_map& vm = vw.vm; check_option<size_t>(acset.max_bias_ngram_length, vw, vm, "search_max_bias_ngram_length", false, size_equal, "warning: you specified a different value for --search_max_bias_ngram_length than the one loaded from regressor. proceeding with loaded value: ", ""); @@ -1764,38 +1763,39 @@ namespace Search { delete[] cstr; } - base_learner* setup(vw&all, po::variables_map& vm) { - search& sch = calloc_or_die<search>(); - sch.priv = new search_private(); - search_initialize(&all, sch); - search_private& priv = *sch.priv; - - po::options_description search_opts("Search Options"); - search_opts.add_options() - ("search_task", po::value<string>(), "the search task (use \"--search_task list\" to get a list of available tasks)") - ("search_interpolation", po::value<string>(), "at what level should interpolation happen? [*data|policy]") - ("search_rollout", po::value<string>(), "how should rollouts be executed? [policy|oracle|*mix_per_state|mix_per_roll|none]") - ("search_rollin", po::value<string>(), "how should past trajectories be generated? 
[policy|oracle|*mix_per_state|mix_per_roll]") - - ("search_passes_per_policy", po::value<size_t>(), "number of passes per policy (only valid for search_interpolation=policy) [def=1]") - ("search_beta", po::value<float>(), "interpolation rate for policies (only valid for search_interpolation=policy) [def=0.5]") - - ("search_alpha", po::value<float>(), "annealed beta = 1-(1-alpha)^t (only valid for search_interpolation=data) [def=1e-10]") - - ("search_total_nb_policies", po::value<size_t>(), "if we are going to train the policies through multiple separate calls to vw, we need to specify this parameter and tell vw how many policies are eventually going to be trained") - - ("search_trained_nb_policies", po::value<size_t>(), "the number of trained policies in a file") - - ("search_allowed_transitions",po::value<string>(),"read file of allowed transitions [def: all transitions are allowed]") - ("search_subsample_time", po::value<float>(), "instead of training at all timesteps, use a subset. if value in (0,1), train on a random v%. if v>=1, train on precisely v steps per example") - ("search_neighbor_features", po::value<string>(), "copy features from neighboring lines. 
argument looks like: '-1:a,+2' meaning copy previous line namespace a and next next line from namespace _unnamed_, where ',' separates them") - ("search_rollout_num_steps", po::value<size_t>(), "how many calls of \"loss\" before we stop really predicting on rollouts and switch to oracle (def: 0 means \"infinite\")") - ("search_history_length", po::value<size_t>(), "some tasks allow you to specify how much history their depend on; specify that here [def: 1]") - - ("search_no_caching", "turn off the built-in caching ability (makes things slower, but technically more safe)") - ("search_beam", po::value<size_t>(), "use beam search (arg = beam size, default 0 = no beam)") - ("search_kbest", po::value<size_t>(), "size of k-best list to produce (must be <= beam size)") - ; + base_learner* setup(vw&all) { + new_options(all,"Search Options") + ("search", po::value<size_t>(), "use search-based structured prediction, argument=maximum action id or 0 for LDF"); + if (missing_required(all)) return NULL; + new_options(all) + ("search_task", po::value<string>(), "the search task (use \"--search_task list\" to get a list of available tasks)") + ("search_interpolation", po::value<string>(), "at what level should interpolation happen? [*data|policy]") + ("search_rollout", po::value<string>(), "how should rollouts be executed? [policy|oracle|*mix_per_state|mix_per_roll|none]") + ("search_rollin", po::value<string>(), "how should past trajectories be generated? 
[policy|oracle|*mix_per_state|mix_per_roll]") + + ("search_passes_per_policy", po::value<size_t>(), "number of passes per policy (only valid for search_interpolation=policy) [def=1]") + ("search_beta", po::value<float>(), "interpolation rate for policies (only valid for search_interpolation=policy) [def=0.5]") + + ("search_alpha", po::value<float>(), "annealed beta = 1-(1-alpha)^t (only valid for search_interpolation=data) [def=1e-10]") + + ("search_total_nb_policies", po::value<size_t>(), "if we are going to train the policies through multiple separate calls to vw, we need to specify this parameter and tell vw how many policies are eventually going to be trained") + + ("search_trained_nb_policies", po::value<size_t>(), "the number of trained policies in a file") + + ("search_allowed_transitions",po::value<string>(),"read file of allowed transitions [def: all transitions are allowed]") + ("search_subsample_time", po::value<float>(), "instead of training at all timesteps, use a subset. if value in (0,1), train on a random v%. if v>=1, train on precisely v steps per example") + ("search_neighbor_features", po::value<string>(), "copy features from neighboring lines. 
argument looks like: '-1:a,+2' meaning copy previous line namespace a and next next line from namespace _unnamed_, where ',' separates them") + ("search_rollout_num_steps", po::value<size_t>(), "how many calls of \"loss\" before we stop really predicting on rollouts and switch to oracle (def: 0 means \"infinite\")") + ("search_history_length", po::value<size_t>(), "some tasks allow you to specify how much history their depend on; specify that here [def: 1]") + + ("search_no_caching", "turn off the built-in caching ability (makes things slower, but technically more safe)") + ("search_beam", po::value<size_t>(), "use beam search (arg = beam size, default 0 = no beam)") + ("search_kbest", po::value<size_t>(), "size of k-best list to produce (must be <= beam size)") + ; + add_options(all); + po::variables_map& vm = all.vm; + if (!vm.count("search")) + return NULL; bool has_hook_task = false; for (size_t i=0; i<all.args.size()-1; i++) @@ -1805,9 +1805,12 @@ namespace Search { for (int i = (int)all.args.size()-2; i >= 0; i--) if (all.args[i] == "--search_task" && all.args[i+1] != "hook") all.args.erase(all.args.begin() + i, all.args.begin() + i + 2); + + search& sch = calloc_or_die<search>(); + sch.priv = new search_private(); + search_initialize(&all, sch); + search_private& priv = *sch.priv; - vm = add_options(all, search_opts); - std::string task_string; std::string interpolation_string = "data"; std::string rollout_string = "mix_per_state"; @@ -1955,6 +1958,18 @@ namespace Search { } all.p->emptylines_separate_examples = true; + if (count(all.args.begin(), all.args.end(),"--csoaa") == 0 + && count(all.args.begin(), all.args.end(),"--csoaa_ldf") == 0 + && count(all.args.begin(), all.args.end(),"--wap_ldf") == 0 + && count(all.args.begin(), all.args.end(),"--cb") == 0) + { + all.args.push_back("--csoaa"); + stringstream ss; + ss << vm["search"].as<size_t>(); + all.args.push_back(ss.str()); + } + base_learner* base = setup_base(all); + // default to OAA labels unless 
the task wants to override this (which they can do in initialize) all.p->lp = MC::mc_label; if (priv.task) @@ -1964,7 +1979,7 @@ namespace Search { // set up auto-history if they want it if (priv.auto_condition_features) { - handle_condition_options(all, priv.acset, vm); + handle_condition_options(all, priv.acset); // turn off auto-condition if it's irrelevant if (((priv.acset.max_bias_ngram_length == 0) && (priv.acset.max_quad_ngram_length == 0)) || @@ -1981,7 +1996,8 @@ namespace Search { priv.start_clock_time = clock(); - learner<search>& l = init_learner(&sch, all.l, search_predict_or_learn<true>, + learner<search>& l = init_learner(&sch, base, + search_predict_or_learn<true>, search_predict_or_learn<false>, priv.total_number_of_policies); l.set_finish_example(finish_example); @@ -2063,7 +2079,7 @@ namespace Search { void search::set_num_learners(size_t num_learners) { this->priv->num_learners = num_learners; } - void search::add_program_options(po::variables_map& vm, po::options_description& opts) { vm = add_options( *this->priv->all, opts ); } + void search::add_program_options(po::variables_map& vw, po::options_description& opts) { add_options( *this->priv->all, opts ); } size_t search::get_mask() { return this->priv->all->reg.weight_mask;} size_t search::get_stride_shift() { return this->priv->all->reg.stride_shift;} diff --git a/vowpalwabbit/search.h b/vowpalwabbit/search.h index e129de25..54c35259 100644 --- a/vowpalwabbit/search.h +++ b/vowpalwabbit/search.h @@ -241,8 +241,5 @@ namespace Search { bool size_equal(size_t a, size_t b); // our interface within VW - LEARNER::base_learner* setup(vw&, po::variables_map&); - void search_finish(void*); - void search_drive(void*); - void search_learn(void*,example*); + LEARNER::base_learner* setup(vw&); } diff --git a/vowpalwabbit/sender.cc b/vowpalwabbit/sender.cc index a9ded7e4..ae73cee9 100644 --- a/vowpalwabbit/sender.cc +++ b/vowpalwabbit/sender.cc @@ -96,13 +96,17 @@ void end_examples(sender& s) delete 
s.buf; } - LEARNER::base_learner* setup(vw& all, po::variables_map& vm, vector<string> pairs) -{ + LEARNER::base_learner* setup(vw& all) + { + new_options(all, "Sender options") + ("sendto", po::value< vector<string> >(), "send examples to <host>"); + if(missing_required(all)) return NULL; + sender& s = calloc_or_die<sender>(); s.sd = -1; - if (vm.count("sendto")) + if (all.vm.count("sendto")) { - vector<string> hosts = vm["sendto"].as< vector<string> >(); + vector<string> hosts = all.vm["sendto"].as< vector<string> >(); open_sockets(s, hosts[0]); } diff --git a/vowpalwabbit/sender.h b/vowpalwabbit/sender.h index 9740f159..b8199bf6 100644 --- a/vowpalwabbit/sender.h +++ b/vowpalwabbit/sender.h @@ -4,6 +4,4 @@ individual contributors. All rights reserved. Released under a BSD license as described in the file LICENSE. */ #pragma once -namespace SENDER{ - LEARNER::base_learner* setup(vw& all, po::variables_map& vm, vector<string> pairs); -} +namespace SENDER { LEARNER::base_learner* setup(vw& all); } diff --git a/vowpalwabbit/stagewise_poly.cc b/vowpalwabbit/stagewise_poly.cc index b2e7e150..68065e0c 100644 --- a/vowpalwabbit/stagewise_poly.cc +++ b/vowpalwabbit/stagewise_poly.cc @@ -7,7 +7,7 @@ #include "allreduce.h" #include "accumulate.h" #include "constant.h" -#include "memory.h" +#include "reductions.h" #include "vw.h" //#define MAGIC_ARGUMENT //MAY IT NEVER DIE @@ -656,17 +656,13 @@ namespace StagewisePoly //#endif //DEBUG } - - base_learner *setup(vw &all, po::variables_map &vm) + base_learner *setup(vw &all) { - stagewise_poly& poly = calloc_or_die<stagewise_poly>(); - poly.all = &all; - - depthsbits_create(poly); - sort_data_create(poly); + new_options(all, "Stagewise poly options") + ("stage_poly", "use stagewise polynomial feature learning"); + if (missing_required(all)) return NULL; - po::options_description sp_opt("Stagewise poly options"); - sp_opt.add_options() + new_options(all) ("sched_exponent", po::value<float>(), "exponent controlling quantity of 
included features") ("batch_sz", po::value<uint32_t>(), "multiplier on batch size before including more features") ("batch_sz_no_doubling", "batch_sz does not double") @@ -674,7 +670,13 @@ namespace StagewisePoly ("magic_argument", po::value<float>(), "magical feature flag") #endif //MAGIC_ARGUMENT ; - vm = add_options(all, sp_opt); + add_options(all); + + po::variables_map &vm = all.vm; + stagewise_poly& poly = calloc_or_die<stagewise_poly>(); + poly.all = &all; + depthsbits_create(poly); + sort_data_create(poly); poly.sched_exponent = vm.count("sched_exponent") ? vm["sched_exponent"].as<float>() : 1.f; poly.batch_sz = vm.count("batch_sz") ? vm["batch_sz"].as<uint32_t>() : 1000; @@ -698,7 +700,7 @@ namespace StagewisePoly //following is so that saved models know to load us. *all.file_options << " --stage_poly"; - learner<stagewise_poly>& l = init_learner(&poly, all.l, learn, predict); + learner<stagewise_poly>& l = init_learner(&poly, setup_base(all), learn, predict); l.set_finish(finish); l.set_save_load(save_load); l.set_finish_example(finish_example); diff --git a/vowpalwabbit/stagewise_poly.h b/vowpalwabbit/stagewise_poly.h index 983b4382..60478e81 100644 --- a/vowpalwabbit/stagewise_poly.h +++ b/vowpalwabbit/stagewise_poly.h @@ -4,7 +4,4 @@ individual contributors. All rights reserved. Released under a BSD license as described in the file LICENSE. 
*/ #pragma once -namespace StagewisePoly -{ - LEARNER::base_learner *setup(vw &all, po::variables_map &vm); -} +namespace StagewisePoly { LEARNER::base_learner *setup(vw &all); } diff --git a/vowpalwabbit/topk.cc b/vowpalwabbit/topk.cc index 445bdb23..52f03ad9 100644 --- a/vowpalwabbit/topk.cc +++ b/vowpalwabbit/topk.cc @@ -16,15 +16,12 @@ namespace TOPK { struct compare_scored_examples { bool operator()(scored_example const& a, scored_example const& b) const - { - return a.first > b.first; - } + { return a.first > b.first; } }; struct topk{ uint32_t B; //rec number priority_queue<scored_example, vector<scored_example>, compare_scored_examples > pr_queue; - vw* all; }; void print_result(int f, priority_queue<scored_example, vector<scored_example>, compare_scored_examples > &pr_queue) @@ -72,43 +69,46 @@ namespace TOPK { if (example_is_newline(ec)) for (int* sink = all.final_prediction_sink.begin; sink != all.final_prediction_sink.end; sink++) TOPK::print_result(*sink, d.pr_queue); - + print_update(all, ec); } - + template <bool is_learn> void predict_or_learn(topk& d, LEARNER::base_learner& base, example& ec) { if (example_is_newline(ec)) return;//do not predict newline - + if (is_learn) base.learn(ec); else base.predict(ec); - + if(d.pr_queue.size() < d.B) d.pr_queue.push(make_pair(ec.pred.scalar, ec.tag)); - + else if(d.pr_queue.top().first < ec.pred.scalar) - { - d.pr_queue.pop(); - d.pr_queue.push(make_pair(ec.pred.scalar, ec.tag)); - } + { + d.pr_queue.pop(); + d.pr_queue.push(make_pair(ec.pred.scalar, ec.tag)); + } } - + void finish_example(vw& all, topk& d, example& ec) { TOPK::output_example(all, d, ec); VW::finish_example(all, &ec); } - LEARNER::base_learner* setup(vw& all, po::variables_map& vm) + LEARNER::base_learner* setup(vw& all) { + new_options(all, "TOP K options") + ("top", po::value<size_t>(), "top k recommendation"); + if(missing_required(all)) return NULL; + topk& data = calloc_or_die<topk>(); - data.B = (uint32_t)vm["top"].as<size_t>(); - 
data.all = &all; + data.B = (uint32_t)all.vm["top"].as<size_t>(); - LEARNER::learner<topk>& l = init_learner(&data, all.l, predict_or_learn<true>, + LEARNER::learner<topk>& l = init_learner(&data, setup_base(all), predict_or_learn<true>, predict_or_learn<false>); l.set_finish_example(finish_example); diff --git a/vowpalwabbit/topk.h b/vowpalwabbit/topk.h index 866d94c5..6e9973ad 100644 --- a/vowpalwabbit/topk.h +++ b/vowpalwabbit/topk.h @@ -4,15 +4,4 @@ individual contributors. All rights reserved. Released under a BSD license as described in the file LICENSE. */ #pragma once -#include "io_buf.h" -#include "parse_primitives.h" -#include "global_data.h" -#include "example.h" -#include "parse_args.h" -#include "v_hashmap.h" -#include "simple_label.h" - -namespace TOPK -{ - LEARNER::base_learner* setup(vw& all, po::variables_map& vm); -} +namespace TOPK { LEARNER::base_learner* setup(vw& all); } diff --git a/vowpalwabbit/vw_static.vcxproj b/vowpalwabbit/vw_static.vcxproj index b555aef2..3df09387 100644 --- a/vowpalwabbit/vw_static.vcxproj +++ b/vowpalwabbit/vw_static.vcxproj @@ -1,4 +1,4 @@ -<?xml version="1.0" encoding="utf-8"?> +<?xml version="1.0" encoding="utf-8"?> <Project DefaultTargets="Build" ToolsVersion="12.0" xmlns="http://schemas.microsoft.com/developer/msbuild/2003"> <ItemGroup Label="ProjectConfigurations"> <ProjectConfiguration Include="Debug|Win32"> @@ -249,6 +249,7 @@ <ClInclude Include="ect.h" /> <ClInclude Include="example.h" /> <ClInclude Include="gd.h" /> + <ClInclude Include="ftrl_proximal.h" /> <ClInclude Include="memory.h" /> <ClInclude Include="multiclass.h" /> <ClInclude Include="cost_sensitive.h" /> @@ -305,8 +306,8 @@ <ClCompile Include="ect.cc" /> <ClCompile Include="example.cc" /> <ClCompile Include="gd.cc" /> + <ClCompile Include="ftrl_proximal.cc" /> <ClCompile Include="kernel_svm.cc" /> - <ClCompile Include="memory.cc" /> <ClCompile Include="multiclass.cc" /> <ClCompile Include="cost_sensitive.cc" /> <ClCompile Include="cb_algs.cc" /> 
diff --git a/vowpalwabbit/vwdll.cpp b/vowpalwabbit/vwdll.cpp index 6235705b..501730b5 100644 --- a/vowpalwabbit/vwdll.cpp +++ b/vowpalwabbit/vwdll.cpp @@ -52,7 +52,7 @@ extern "C" adjust_used_index(*pointer); pointer->do_reset_source = true; VW::start_parser(*pointer,false); - pointer->l->driver(pointer); + LEARNER::generic_driver(*pointer); VW::end_parser(*pointer); } else |