Welcome to the mirror list, hosted at ThFree Co, Russian Federation.

github.com/moses-smt/vowpal_wabbit.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorariel faigon <github.2009@yendor.com>2015-01-04 10:11:31 +0300
committerariel faigon <github.2009@yendor.com>2015-01-04 10:11:31 +0300
commitafe136b52b503239166fe5fee22d3b2c9fd144cd (patch)
tree02346b7ad036ffbdfe5cad3674636b8ab72b70b1
parentf1f859a19c8a3c70ad7b4706d78ae5cac6bdca36 (diff)
parent143522150dedcd953c8ac4f3a0114c9195f4efb6 (diff)
Merge branch 'master' of git://github.com/JohnLangford/vowpal_wabbit
-rw-r--r--Makefile9
-rw-r--r--Makefile.am1
-rw-r--r--demo/dna/Makefile2
-rw-r--r--demo/entityrelation/Makefile5
-rw-r--r--demo/movielens/Makefile4
-rw-r--r--demo/normalized/Makefile26
-rw-r--r--explore/clr/explore_clr_wrapper.cpp36
-rw-r--r--explore/clr/explore_clr_wrapper.h874
-rw-r--r--explore/clr/explore_interface.h174
-rw-r--r--explore/clr/explore_interop.h714
-rw-r--r--explore/explore.cpp146
-rw-r--r--explore/static/MWTExplorer.h1338
-rw-r--r--explore/static/utility.h566
-rw-r--r--explore/tests/MWTExploreTests.h328
-rw-r--r--java/pom.xml68
-rw-r--r--java/src/main/c++/vw_VWScorer.cc8
-rw-r--r--java/src/test/java/vw/VWScorerTest.java14
-rw-r--r--java/src/test/resources/house.modelbin101 -> 0 bytes
-rw-r--r--java/src/test/resources/house.vw3
-rw-r--r--library/Makefile1
-rw-r--r--library/ezexample_predict.cc110
-rw-r--r--library/ezexample_predict_threaded.cc298
-rw-r--r--library/ezexample_train.cc144
-rw-r--r--library/gd_mf_weights.cc13
-rw-r--r--library/library_example.cc124
-rw-r--r--python/Makefile2
-rw-r--r--test/test-sets/ref/ml100k_small.stderr1
-rw-r--r--test/train-sets/ref/ml100k_small.stderr1
-rw-r--r--test/train-sets/ref/sequencespan_data.nonldf-bilou.test.stderr2
-rw-r--r--test/train-sets/ref/sequencespan_data.nonldf-bilou.train.stderr2
-rw-r--r--test/train-sets/ref/sequencespan_data.nonldf.train.stderr2
-rw-r--r--vowpalwabbit/Makefile1
-rw-r--r--vowpalwabbit/Makefile.am2
-rw-r--r--vowpalwabbit/accumulate.h26
-rw-r--r--vowpalwabbit/active.cc29
-rw-r--r--vowpalwabbit/active.h4
-rw-r--r--vowpalwabbit/autolink.cc12
-rw-r--r--vowpalwabbit/autolink.h4
-rw-r--r--vowpalwabbit/bfgs.cc44
-rw-r--r--vowpalwabbit/bfgs.h4
-rw-r--r--vowpalwabbit/binary.cc12
-rw-r--r--vowpalwabbit/binary.h4
-rw-r--r--vowpalwabbit/bs.cc23
-rw-r--r--vowpalwabbit/bs.h2
-rw-r--r--vowpalwabbit/cb_algs.cc42
-rw-r--r--vowpalwabbit/cb_algs.h3
-rw-r--r--vowpalwabbit/cbify.cc37
-rw-r--r--vowpalwabbit/cbify.h4
-rw-r--r--vowpalwabbit/csoaa.cc32
-rw-r--r--vowpalwabbit/csoaa.h17
-rw-r--r--vowpalwabbit/ect.cc51
-rw-r--r--vowpalwabbit/ect.h5
-rw-r--r--vowpalwabbit/ftrl_proximal.cc44
-rw-r--r--vowpalwabbit/ftrl_proximal.h9
-rw-r--r--vowpalwabbit/gd.cc10
-rw-r--r--vowpalwabbit/gd.h2
-rw-r--r--vowpalwabbit/gd_mf.cc177
-rw-r--r--vowpalwabbit/gd_mf.h10
-rw-r--r--vowpalwabbit/global_data.cc42
-rw-r--r--vowpalwabbit/global_data.h23
-rw-r--r--vowpalwabbit/kernel_svm.cc18
-rw-r--r--vowpalwabbit/kernel_svm.h5
-rw-r--r--vowpalwabbit/lda_core.cc1617
-rw-r--r--vowpalwabbit/lda_core.h4
-rw-r--r--vowpalwabbit/learner.h2
-rw-r--r--vowpalwabbit/log_multi.cc52
-rw-r--r--vowpalwabbit/log_multi.h5
-rw-r--r--vowpalwabbit/lrq.cc22
-rw-r--r--vowpalwabbit/lrq.h4
-rw-r--r--vowpalwabbit/main.cc82
-rw-r--r--vowpalwabbit/memory.cc8
-rw-r--r--vowpalwabbit/memory.h8
-rw-r--r--vowpalwabbit/mf.cc21
-rw-r--r--vowpalwabbit/mf.h10
-rw-r--r--vowpalwabbit/multiclass.h2
-rw-r--r--vowpalwabbit/nn.cc23
-rw-r--r--vowpalwabbit/nn.h8
-rw-r--r--vowpalwabbit/noop.cc8
-rw-r--r--vowpalwabbit/noop.h4
-rw-r--r--vowpalwabbit/oaa.cc19
-rw-r--r--vowpalwabbit/oaa.h3
-rw-r--r--vowpalwabbit/parse_args.cc478
-rw-r--r--vowpalwabbit/parse_args.h3
-rw-r--r--vowpalwabbit/parse_regressor.cc13
-rw-r--r--vowpalwabbit/parse_regressor.h2
-rw-r--r--vowpalwabbit/parser.cc65
-rw-r--r--vowpalwabbit/parser.h2
-rw-r--r--vowpalwabbit/print.cc6
-rw-r--r--vowpalwabbit/print.h4
-rw-r--r--vowpalwabbit/reductions.h1
-rw-r--r--vowpalwabbit/scorer.cc24
-rw-r--r--vowpalwabbit/scorer.h4
-rw-r--r--vowpalwabbit/search.cc108
-rw-r--r--vowpalwabbit/search.h5
-rw-r--r--vowpalwabbit/sender.cc12
-rw-r--r--vowpalwabbit/sender.h4
-rw-r--r--vowpalwabbit/stagewise_poly.cc26
-rw-r--r--vowpalwabbit/stagewise_poly.h5
-rw-r--r--vowpalwabbit/topk.cc36
-rw-r--r--vowpalwabbit/topk.h13
-rw-r--r--vowpalwabbit/vw_static.vcxproj5
-rw-r--r--vowpalwabbit/vwdll.cpp2
102 files changed, 4150 insertions(+), 4284 deletions(-)
diff --git a/Makefile b/Makefile
index 0d53bb54..052a578d 100644
--- a/Makefile
+++ b/Makefile
@@ -20,17 +20,21 @@ UNAME := $(shell uname)
LIBS = -l boost_program_options -l pthread -l z
BOOST_INCLUDE = -I /usr/include
BOOST_LIBRARY = -L /usr/lib
+NPROCS := 1
ifeq ($(UNAME), Linux)
BOOST_LIBRARY += -L /usr/lib/x86_64-linux-gnu
+ NPROCS:=$(shell grep -c ^processor /proc/cpuinfo)
endif
ifeq ($(UNAME), FreeBSD)
LIBS = -l boost_program_options -l pthread -l z -l compat
BOOST_INCLUDE = -I /usr/local/include
+ NPROCS:=$(shell grep -c ^processor /proc/cpuinfo)
endif
ifeq "CYGWIN" "$(findstring CYGWIN,$(UNAME))"
LIBS = -l boost_program_options-mt -l pthread -l z
BOOST_INCLUDE = -I /usr/include
+ NPROCS:=$(shell grep -c ^processor /proc/cpuinfo)
endif
ifeq ($(UNAME), Darwin)
LIBS = -lboost_program_options-mt -lboost_serialization-mt -l pthread -l z
@@ -46,6 +50,7 @@ ifeq ($(UNAME), Darwin)
BOOST_INCLUDE += -I /opt/local/include
BOOST_LIBRARY += -L /opt/local/lib
endif
+ NPROCS:=$(shell sysctl -n hw.ncpu)
endif
#LIBS = -l boost_program_options-gcc34 -l pthread -l z
@@ -84,7 +89,7 @@ spanning_tree:
cd cluster; $(MAKE)
vw:
- cd vowpalwabbit; $(MAKE) -j 8 things
+ cd vowpalwabbit; $(MAKE) -j $(NPROCS) things
active_interactor:
cd vowpalwabbit; $(MAKE)
@@ -117,3 +122,5 @@ clean:
ifneq ($(JAVA_HOME),)
cd java && $(MAKE) clean
endif
+
+.PHONY: all clean install
diff --git a/Makefile.am b/Makefile.am
index 8bc91364..abb79bb2 100644
--- a/Makefile.am
+++ b/Makefile.am
@@ -36,7 +36,6 @@ noinst_HEADERS = vowpalwabbit/accumulate.h \
vowpalwabbit/lda_core.h \
vowpalwabbit/log_multi.h \
vowpalwabbit/lrq.h \
- vowpalwabbit/memory.h \
vowpalwabbit/mf.h \
vowpalwabbit/multiclass.h \
vowpalwabbit/network.h \
diff --git a/demo/dna/Makefile b/demo/dna/Makefile
index 93efac3b..477cfb98 100644
--- a/demo/dna/Makefile
+++ b/demo/dna/Makefile
@@ -27,3 +27,5 @@ quaddna2vw: quaddna2vw.cpp
%.perf: %.test.predictions perf.check perl.check zsh.check
./do-perf $<
+
+.PHONY: all clean
diff --git a/demo/entityrelation/Makefile b/demo/entityrelation/Makefile
index c287d97c..c3db3a69 100644
--- a/demo/entityrelation/Makefile
+++ b/demo/entityrelation/Makefile
@@ -8,7 +8,7 @@ eval_script=./evaluationER.py
all:
@cat README.md
clean:
- rm -f *.model *.predictions ER_*.vw *.cache evaluationER.py er.zip *~
+ rm -f *.model *.predictions ER_*.vw *.cache evaluationER.py er.zip *~
%.check:
@test -x "$$(which $*)" || { \
@@ -33,5 +33,4 @@ er.test.predictions: ER_test.vw er.model
er.perf: ER_test.vw er.test.predictions python.check
@$(eval_script) $< er.test.predictions
-
-
+.PHONY: all clean
diff --git a/demo/movielens/Makefile b/demo/movielens/Makefile
index 68be5310..46c92e24 100644
--- a/demo/movielens/Makefile
+++ b/demo/movielens/Makefile
@@ -39,7 +39,7 @@ ml-%.ratings.train.vw: ml-%/ratings.dat
@printf "%s test MAE is %3.3f\n" $* $$(cat $*.results)
#---------------------------------------------------------------------
-# linear model (no interaction terms)
+# linear model (no interaction terms)
#---------------------------------------------------------------------
linear.results: ml-1m.ratings.test.vw ml-1m.ratings.train.vw
@@ -154,3 +154,5 @@ lrqdropouthogwild.results: ml-1m.ratings.test.vw ml-1m.ratings.train.vw do-lrq-h
-d $(word 1,$+) -p \
>(perl -lane '$$s+=abs(($$F[0]-$$F[1])); } { \
1; print $$s/$$.;' > $@)
+
+.PHONY: all clean shootout
diff --git a/demo/normalized/Makefile b/demo/normalized/Makefile
index 8268e82f..f1378b80 100644
--- a/demo/normalized/Makefile
+++ b/demo/normalized/Makefile
@@ -19,8 +19,8 @@ SHUFFLE='BEGIN { srand 69; }; \
$$b[$$i] = $$_; } { print grep { defined $$_ } @b;'
#---------------------------------------------------------------------
-# bank marketing
-#
+# bank marketing
+#
# normalization really helps. The columns have units of euros, seconds,
# days, and years; in addition there are categorical variables.
#---------------------------------------------------------------------
@@ -57,10 +57,10 @@ bankbestmin=1e-2
bankbestmax=10
banknonormbestmin=1e-8
banknonormbestmax=1e-3
-banktimeestimate=1
+banktimeestimate=1
#---------------------------------------------------------------------
-# covertype
+# covertype
#---------------------------------------------------------------------
covtype.data.gz:
@@ -88,11 +88,11 @@ covertypebestmin=1e-2
covertypebestmax=10
covertypenonormbestmin=1e-8
covertypenonormbestmax=1e-5
-covertypetimeestimate=5
+covertypetimeestimate=5
#---------------------------------------------------------------------
-# million song database
-#
+# million song database
+#
# normalization is helpful.
#---------------------------------------------------------------------
@@ -142,7 +142,7 @@ MSDnonormbestmax=1e-5
MSDtimeestimate=15
#---------------------------------------------------------------------
-# census-income (KDD)
+# census-income (KDD)
#---------------------------------------------------------------------
census-income.data.gz:
@@ -180,7 +180,7 @@ censusnonormbestmax=1e-4
censustimeestimate=5
#---------------------------------------------------------------------
-# Statlog (Shuttle)
+# Statlog (Shuttle)
#---------------------------------------------------------------------
shuttle.trn.Z:
@@ -211,7 +211,7 @@ shuttlenonormbestmax=1e-3
shuttletimeestimate=1
#---------------------------------------------------------------------
-# CT slices
+# CT slices
#
# normalization doesn't help much
#---------------------------------------------------------------------
@@ -256,7 +256,7 @@ CTslicenonormbestmax=1
CTslicetimeestimate=15
#---------------------------------------------------------------------
-# common routines
+# common routines
#---------------------------------------------------------------------
%.best: %.data
@@ -269,8 +269,10 @@ CTslicetimeestimate=15
@printf "WARNING: this step takes about %s minutes\n" $($*timeestimate)
@./hypersearch $($*nonormbestmin) $($*nonormbestmax) '$(MAKE)' '$*.%.nonormlearn' > $@
-%.resultsprint:
+%.resultsprint:
@printf "%20.20s\t%9.3g\t%9.3g\t%9.3g\t%9.3g\n" "$*" $$(cut -f1 $*.best) $$(cut -f2 $*.best) $$(cut -f1 $*.nonormbest) $$(cut -f2 $*.nonormbest)
only.%: %.best %.nonormbest all.results.pre %.resultsprint
@true
+
+.PHONY: all all.results all.results.pre
diff --git a/explore/clr/explore_clr_wrapper.cpp b/explore/clr/explore_clr_wrapper.cpp
index 0564482b..be022af1 100644
--- a/explore/clr/explore_clr_wrapper.cpp
+++ b/explore/clr/explore_clr_wrapper.cpp
@@ -1,18 +1,18 @@
-// vw_explore_clr_wrapper.cpp : Defines the exported functions for the DLL application.
-//
-
-#define WIN32_LEAN_AND_MEAN
-#include <Windows.h>
-
-#include "explore_clr_wrapper.h"
-
-using namespace System;
-using namespace System::Collections;
-using namespace System::Collections::Generic;
-using namespace System::Runtime::InteropServices;
-using namespace msclr::interop;
-using namespace NativeMultiWorldTesting;
-
-namespace MultiWorldTesting {
-
-}
+// vw_explore_clr_wrapper.cpp : Defines the exported functions for the DLL application.
+//
+
+#define WIN32_LEAN_AND_MEAN
+#include <Windows.h>
+
+#include "explore_clr_wrapper.h"
+
+using namespace System;
+using namespace System::Collections;
+using namespace System::Collections::Generic;
+using namespace System::Runtime::InteropServices;
+using namespace msclr::interop;
+using namespace NativeMultiWorldTesting;
+
+namespace MultiWorldTesting {
+
+}
diff --git a/explore/clr/explore_clr_wrapper.h b/explore/clr/explore_clr_wrapper.h
index a6cede4b..c73c95d4 100644
--- a/explore/clr/explore_clr_wrapper.h
+++ b/explore/clr/explore_clr_wrapper.h
@@ -1,438 +1,438 @@
-#pragma once
-#include "explore_interop.h"
-
-/*!
-* \addtogroup MultiWorldTestingCsharp
-* @{
-*/
-namespace MultiWorldTesting {
-
- /// <summary>
- /// The epsilon greedy exploration class.
- /// </summary>
- /// <remarks>
- /// This is a good choice if you have no idea which actions should be preferred.
- /// Epsilon greedy is also computationally cheap.
- /// </remarks>
- /// <typeparam name="Ctx">The Context type.</typeparam>
- generic <class Ctx>
- public ref class EpsilonGreedyExplorer : public IExplorer<Ctx>, public PolicyCallback<Ctx>
- {
- public:
- /// <summary>
- /// The constructor is the only public member, because this should be used with the MwtExplorer.
- /// </summary>
- /// <param name="defaultPolicy">A default function which outputs an action given a context.</param>
- /// <param name="epsilon">The probability of a random exploration.</param>
- /// <param name="numActions">The number of actions to randomize over.</param>
- EpsilonGreedyExplorer(IPolicy<Ctx>^ defaultPolicy, float epsilon, UInt32 numActions)
- {
- this->defaultPolicy = defaultPolicy;
- m_explorer = new NativeMultiWorldTesting::EpsilonGreedyExplorer<NativeContext>(*GetNativePolicy(), epsilon, (u32)numActions);
- }
-
- ~EpsilonGreedyExplorer()
- {
- delete m_explorer;
- }
-
- internal:
- virtual UInt32 InvokePolicyCallback(Ctx context, int index) override
- {
- return defaultPolicy->ChooseAction(context);
- }
-
- NativeMultiWorldTesting::EpsilonGreedyExplorer<NativeContext>* Get()
- {
- return m_explorer;
- }
-
- private:
- IPolicy<Ctx>^ defaultPolicy;
- NativeMultiWorldTesting::EpsilonGreedyExplorer<NativeContext>* m_explorer;
- };
-
- /// <summary>
- /// The tau-first exploration class.
- /// </summary>
- /// <remarks>
- /// The tau-first explorer collects precisely tau uniform random
- /// exploration events, and then uses the default policy.
- /// </remarks>
- /// <typeparam name="Ctx">The Context type.</typeparam>
- generic <class Ctx>
- public ref class TauFirstExplorer : public IExplorer<Ctx>, public PolicyCallback<Ctx>
- {
- public:
- /// <summary>
- /// The constructor is the only public member, because this should be used with the MwtExplorer.
- /// </summary>
- /// <param name="defaultPolicy">A default policy after randomization finishes.</param>
- /// <param name="tau">The number of events to be uniform over.</param>
- /// <param name="numActions">The number of actions to randomize over.</param>
- TauFirstExplorer(IPolicy<Ctx>^ defaultPolicy, UInt32 tau, UInt32 numActions)
- {
- this->defaultPolicy = defaultPolicy;
- m_explorer = new NativeMultiWorldTesting::TauFirstExplorer<NativeContext>(*GetNativePolicy(), tau, (u32)numActions);
- }
-
- ~TauFirstExplorer()
- {
- delete m_explorer;
- }
-
- internal:
- virtual UInt32 InvokePolicyCallback(Ctx context, int index) override
- {
- return defaultPolicy->ChooseAction(context);
- }
-
- NativeMultiWorldTesting::TauFirstExplorer<NativeContext>* Get()
- {
- return m_explorer;
- }
-
- private:
- IPolicy<Ctx>^ defaultPolicy;
- NativeMultiWorldTesting::TauFirstExplorer<NativeContext>* m_explorer;
- };
-
- /// <summary>
- /// The epsilon greedy exploration class.
- /// </summary>
- /// <remarks>
- /// In some cases, different actions have a different scores, and you
- /// would prefer to choose actions with large scores. Softmax allows
- /// you to do that.
- /// </remarks>
- /// <typeparam name="Ctx">The Context type.</typeparam>
- generic <class Ctx>
- public ref class SoftmaxExplorer : public IExplorer<Ctx>, public ScorerCallback<Ctx>
- {
- public:
- /// <summary>
- /// The constructor is the only public member, because this should be used with the MwtExplorer.
- /// </summary>
- /// <param name="defaultScorer">A function which outputs a score for each action.</param>
- /// <param name="lambda">lambda = 0 implies uniform distribution. Large lambda is equivalent to a max.</param>
- /// <param name="numActions">The number of actions to randomize over.</param>
- SoftmaxExplorer(IScorer<Ctx>^ defaultScorer, float lambda, UInt32 numActions)
- {
- this->defaultScorer = defaultScorer;
- m_explorer = new NativeMultiWorldTesting::SoftmaxExplorer<NativeContext>(*GetNativeScorer(), lambda, (u32)numActions);
- }
-
- ~SoftmaxExplorer()
- {
- delete m_explorer;
- }
-
- internal:
- virtual List<float>^ InvokeScorerCallback(Ctx context) override
- {
- return defaultScorer->ScoreActions(context);
- }
-
- NativeMultiWorldTesting::SoftmaxExplorer<NativeContext>* Get()
- {
- return m_explorer;
- }
-
- private:
- IScorer<Ctx>^ defaultScorer;
- NativeMultiWorldTesting::SoftmaxExplorer<NativeContext>* m_explorer;
- };
-
- /// <summary>
- /// The generic exploration class.
- /// </summary>
- /// <remarks>
- /// GenericExplorer provides complete flexibility. You can create any
- /// distribution over actions desired, and it will draw from that.
- /// </remarks>
- /// <typeparam name="Ctx">The Context type.</typeparam>
- generic <class Ctx>
- public ref class GenericExplorer : public IExplorer<Ctx>, public ScorerCallback<Ctx>
- {
- public:
- /// <summary>
- /// The constructor is the only public member, because this should be used with the MwtExplorer.
- /// </summary>
- /// <param name="defaultScorer">A function which outputs the probability of each action.</param>
- /// <param name="numActions">The number of actions to randomize over.</param>
- GenericExplorer(IScorer<Ctx>^ defaultScorer, UInt32 numActions)
- {
- this->defaultScorer = defaultScorer;
- m_explorer = new NativeMultiWorldTesting::GenericExplorer<NativeContext>(*GetNativeScorer(), (u32)numActions);
- }
-
- ~GenericExplorer()
- {
- delete m_explorer;
- }
-
- internal:
- virtual List<float>^ InvokeScorerCallback(Ctx context) override
- {
- return defaultScorer->ScoreActions(context);
- }
-
- NativeMultiWorldTesting::GenericExplorer<NativeContext>* Get()
- {
- return m_explorer;
- }
-
- private:
- IScorer<Ctx>^ defaultScorer;
- NativeMultiWorldTesting::GenericExplorer<NativeContext>* m_explorer;
- };
-
- /// <summary>
- /// The bootstrap exploration class.
- /// </summary>
- /// <remarks>
- /// The Bootstrap explorer randomizes over the actions chosen by a set of
- /// default policies. This performs well statistically but can be
- /// computationally expensive.
- /// </remarks>
- /// <typeparam name="Ctx">The Context type.</typeparam>
- generic <class Ctx>
- public ref class BootstrapExplorer : public IExplorer<Ctx>, public PolicyCallback<Ctx>
- {
- public:
- /// <summary>
- /// The constructor is the only public member, because this should be used with the MwtExplorer.
- /// </summary>
- /// <param name="defaultPolicies">A set of default policies to be uniform random over.</param>
- /// <param name="numActions">The number of actions to randomize over.</param>
- BootstrapExplorer(cli::array<IPolicy<Ctx>^>^ defaultPolicies, UInt32 numActions)
- {
- this->defaultPolicies = defaultPolicies;
- if (this->defaultPolicies == nullptr)
- {
- throw gcnew ArgumentNullException("The specified array of default policy functions cannot be null.");
- }
-
- m_explorer = new NativeMultiWorldTesting::BootstrapExplorer<NativeContext>(*GetNativePolicies((u32)defaultPolicies->Length), (u32)numActions);
- }
-
- ~BootstrapExplorer()
- {
- delete m_explorer;
- }
-
- internal:
- virtual UInt32 InvokePolicyCallback(Ctx context, int index) override
- {
- if (index < 0 || index >= defaultPolicies->Length)
- {
- throw gcnew InvalidDataException("Internal error: Index of interop bag is out of range.");
- }
- return defaultPolicies[index]->ChooseAction(context);
- }
-
- NativeMultiWorldTesting::BootstrapExplorer<NativeContext>* Get()
- {
- return m_explorer;
- }
-
- private:
- cli::array<IPolicy<Ctx>^>^ defaultPolicies;
- NativeMultiWorldTesting::BootstrapExplorer<NativeContext>* m_explorer;
- };
-
- /// <summary>
- /// The top level MwtExplorer class. Using this makes sure that the
- /// right bits are recorded and good random actions are chosen.
- /// </summary>
- /// <typeparam name="Ctx">The Context type.</typeparam>
- generic <class Ctx>
- public ref class MwtExplorer : public RecorderCallback<Ctx>
- {
- public:
- /// <summary>
- /// Constructor.
- /// </summary>
- /// <param name="appId">This should be unique to each experiment to avoid correlation bugs.</param>
- /// <param name="recorder">A user-specified class for recording the appropriate bits for use in evaluation and learning.</param>
- MwtExplorer(String^ appId, IRecorder<Ctx>^ recorder)
- {
- this->appId = appId;
- this->recorder = recorder;
- }
-
- /// <summary>
- /// Choose_Action should be drop-in replacement for any existing policy function.
- /// </summary>
- /// <param name="explorer">An existing exploration algorithm (one of the above) which uses the default policy as a callback.</param>
- /// <param name="unique_key">A unique identifier for the experimental unit. This could be a user id, a session id, etc...</param>
- /// <param name="context">The context upon which a decision is made. See SimpleContext above for an example.</param>
- /// <returns>An unsigned 32-bit integer representing the 1-based chosen action.</returns>
- UInt32 ChooseAction(IExplorer<Ctx>^ explorer, String^ unique_key, Ctx context)
- {
- String^ salt = this->appId;
- NativeMultiWorldTesting::MwtExplorer<NativeContext> mwt(marshal_as<std::string>(salt), *GetNativeRecorder());
-
- GCHandle selfHandle = GCHandle::Alloc(this);
- IntPtr selfPtr = (IntPtr)selfHandle;
-
- GCHandle contextHandle = GCHandle::Alloc(context);
- IntPtr contextPtr = (IntPtr)contextHandle;
-
- GCHandle explorerHandle = GCHandle::Alloc(explorer);
- IntPtr explorerPtr = (IntPtr)explorerHandle;
-
- NativeContext native_context(selfPtr.ToPointer(), explorerPtr.ToPointer(), contextPtr.ToPointer());
- u32 action = 0;
- if (explorer->GetType() == EpsilonGreedyExplorer<Ctx>::typeid)
- {
- EpsilonGreedyExplorer<Ctx>^ epsilonGreedyExplorer = (EpsilonGreedyExplorer<Ctx>^)explorer;
- action = mwt.Choose_Action(*epsilonGreedyExplorer->Get(), marshal_as<std::string>(unique_key), native_context);
- }
- else if (explorer->GetType() == TauFirstExplorer<Ctx>::typeid)
- {
- TauFirstExplorer<Ctx>^ tauFirstExplorer = (TauFirstExplorer<Ctx>^)explorer;
- action = mwt.Choose_Action(*tauFirstExplorer->Get(), marshal_as<std::string>(unique_key), native_context);
- }
- else if (explorer->GetType() == SoftmaxExplorer<Ctx>::typeid)
- {
- SoftmaxExplorer<Ctx>^ softmaxExplorer = (SoftmaxExplorer<Ctx>^)explorer;
- action = mwt.Choose_Action(*softmaxExplorer->Get(), marshal_as<std::string>(unique_key), native_context);
- }
- else if (explorer->GetType() == GenericExplorer<Ctx>::typeid)
- {
- GenericExplorer<Ctx>^ genericExplorer = (GenericExplorer<Ctx>^)explorer;
- action = mwt.Choose_Action(*genericExplorer->Get(), marshal_as<std::string>(unique_key), native_context);
- }
- else if (explorer->GetType() == BootstrapExplorer<Ctx>::typeid)
- {
- BootstrapExplorer<Ctx>^ bootstrapExplorer = (BootstrapExplorer<Ctx>^)explorer;
- action = mwt.Choose_Action(*bootstrapExplorer->Get(), marshal_as<std::string>(unique_key), native_context);
- }
-
- explorerHandle.Free();
- contextHandle.Free();
- selfHandle.Free();
-
- return action;
- }
-
- internal:
- virtual void InvokeRecorderCallback(Ctx context, UInt32 action, float probability, String^ unique_key) override
- {
- recorder->Record(context, action, probability, unique_key);
- }
-
- private:
- IRecorder<Ctx>^ recorder;
- String^ appId;
- };
-
- /// <summary>
- /// Represents a feature in a sparse array.
- /// </summary>
- [StructLayout(LayoutKind::Sequential)]
- public value struct Feature
- {
- float Value;
- UInt32 Id;
- };
-
- /// <summary>
- /// A sample recorder class that converts the exploration tuple into string format.
- /// </summary>
- /// <typeparam name="Ctx">The Context type.</typeparam>
- generic <class Ctx> where Ctx : IStringContext
- public ref class StringRecorder : public IRecorder<Ctx>, public ToStringCallback<Ctx>
- {
- public:
- StringRecorder()
- {
- m_string_recorder = new NativeMultiWorldTesting::StringRecorder<NativeStringContext>();
- }
-
- ~StringRecorder()
- {
- delete m_string_recorder;
- }
-
- virtual void Record(Ctx context, UInt32 action, float probability, String^ uniqueKey)
- {
- GCHandle contextHandle = GCHandle::Alloc(context);
- IntPtr contextPtr = (IntPtr)contextHandle;
-
- NativeStringContext native_context(contextPtr.ToPointer(), GetCallback());
- m_string_recorder->Record(native_context, (u32)action, probability, marshal_as<string>(uniqueKey));
- }
-
- /// <summary>
- /// Gets the content of the recording so far as a string and clears internal content.
- /// </summary>
- /// <returns>
- /// A string with recording content.
- /// </returns>
- String^ GetRecording()
- {
- // Workaround for C++-CLI bug which does not allow default value for parameter
- return GetRecording(true);
- }
-
- /// <summary>
- /// Gets the content of the recording so far as a string and optionally clears internal content.
- /// </summary>
- /// <param name="flush">A boolean value indicating whether to clear the internal content.</param>
- /// <returns>
- /// A string with recording content.
- /// </returns>
- String^ GetRecording(bool flush)
- {
- return gcnew String(m_string_recorder->Get_Recording(flush).c_str());
- }
-
- private:
- NativeMultiWorldTesting::StringRecorder<NativeStringContext>* m_string_recorder;
- };
-
- /// <summary>
- /// A sample context class that stores a vector of Features.
- /// </summary>
- public ref class SimpleContext : public IStringContext
- {
- public:
- SimpleContext(cli::array<Feature>^ features)
- {
- Features = features;
-
- // TODO: add another constructor overload for native SimpleContext to avoid copying feature values
- m_features = new vector<NativeMultiWorldTesting::Feature>();
- for (int i = 0; i < features->Length; i++)
- {
- m_features->push_back({ features[i].Value, features[i].Id });
- }
-
- m_native_context = new NativeMultiWorldTesting::SimpleContext(*m_features);
- }
-
- String^ ToString() override
- {
- return gcnew String(m_native_context->To_String().c_str());
- }
-
- ~SimpleContext()
- {
- delete m_native_context;
- }
-
- public:
- cli::array<Feature>^ GetFeatures() { return Features; }
-
- internal:
- cli::array<Feature>^ Features;
-
- private:
- vector<NativeMultiWorldTesting::Feature>* m_features;
- NativeMultiWorldTesting::SimpleContext* m_native_context;
- };
-}
-
+#pragma once
+#include "explore_interop.h"
+
+/*!
+* \addtogroup MultiWorldTestingCsharp
+* @{
+*/
+namespace MultiWorldTesting {
+
+ /// <summary>
+ /// The epsilon greedy exploration class.
+ /// </summary>
+ /// <remarks>
+ /// This is a good choice if you have no idea which actions should be preferred.
+ /// Epsilon greedy is also computationally cheap.
+ /// </remarks>
+ /// <typeparam name="Ctx">The Context type.</typeparam>
+ generic <class Ctx>
+ public ref class EpsilonGreedyExplorer : public IExplorer<Ctx>, public PolicyCallback<Ctx>
+ {
+ public:
+ /// <summary>
+ /// The constructor is the only public member, because this should be used with the MwtExplorer.
+ /// </summary>
+ /// <param name="defaultPolicy">A default function which outputs an action given a context.</param>
+ /// <param name="epsilon">The probability of a random exploration.</param>
+ /// <param name="numActions">The number of actions to randomize over.</param>
+ EpsilonGreedyExplorer(IPolicy<Ctx>^ defaultPolicy, float epsilon, UInt32 numActions)
+ {
+ this->defaultPolicy = defaultPolicy;
+ m_explorer = new NativeMultiWorldTesting::EpsilonGreedyExplorer<NativeContext>(*GetNativePolicy(), epsilon, (u32)numActions);
+ }
+
+ ~EpsilonGreedyExplorer()
+ {
+ delete m_explorer;
+ }
+
+ internal:
+ virtual UInt32 InvokePolicyCallback(Ctx context, int index) override
+ {
+ return defaultPolicy->ChooseAction(context);
+ }
+
+ NativeMultiWorldTesting::EpsilonGreedyExplorer<NativeContext>* Get()
+ {
+ return m_explorer;
+ }
+
+ private:
+ IPolicy<Ctx>^ defaultPolicy;
+ NativeMultiWorldTesting::EpsilonGreedyExplorer<NativeContext>* m_explorer;
+ };
+
+ /// <summary>
+ /// The tau-first exploration class.
+ /// </summary>
+ /// <remarks>
+ /// The tau-first explorer collects precisely tau uniform random
+ /// exploration events, and then uses the default policy.
+ /// </remarks>
+ /// <typeparam name="Ctx">The Context type.</typeparam>
+ generic <class Ctx>
+ public ref class TauFirstExplorer : public IExplorer<Ctx>, public PolicyCallback<Ctx>
+ {
+ public:
+ /// <summary>
+ /// The constructor is the only public member, because this should be used with the MwtExplorer.
+ /// </summary>
+ /// <param name="defaultPolicy">A default policy after randomization finishes.</param>
+ /// <param name="tau">The number of events to be uniform over.</param>
+ /// <param name="numActions">The number of actions to randomize over.</param>
+ TauFirstExplorer(IPolicy<Ctx>^ defaultPolicy, UInt32 tau, UInt32 numActions)
+ {
+ this->defaultPolicy = defaultPolicy;
+ m_explorer = new NativeMultiWorldTesting::TauFirstExplorer<NativeContext>(*GetNativePolicy(), tau, (u32)numActions);
+ }
+
+ ~TauFirstExplorer()
+ {
+ delete m_explorer;
+ }
+
+ internal:
+ virtual UInt32 InvokePolicyCallback(Ctx context, int index) override
+ {
+ return defaultPolicy->ChooseAction(context);
+ }
+
+ NativeMultiWorldTesting::TauFirstExplorer<NativeContext>* Get()
+ {
+ return m_explorer;
+ }
+
+ private:
+ IPolicy<Ctx>^ defaultPolicy;
+ NativeMultiWorldTesting::TauFirstExplorer<NativeContext>* m_explorer;
+ };
+
+ /// <summary>
+ /// The epsilon greedy exploration class.
+ /// </summary>
+ /// <remarks>
+ /// In some cases, different actions have a different scores, and you
+ /// would prefer to choose actions with large scores. Softmax allows
+ /// you to do that.
+ /// </remarks>
+ /// <typeparam name="Ctx">The Context type.</typeparam>
+ generic <class Ctx>
+ public ref class SoftmaxExplorer : public IExplorer<Ctx>, public ScorerCallback<Ctx>
+ {
+ public:
+ /// <summary>
+ /// The constructor is the only public member, because this should be used with the MwtExplorer.
+ /// </summary>
+ /// <param name="defaultScorer">A function which outputs a score for each action.</param>
+ /// <param name="lambda">lambda = 0 implies uniform distribution. Large lambda is equivalent to a max.</param>
+ /// <param name="numActions">The number of actions to randomize over.</param>
+ SoftmaxExplorer(IScorer<Ctx>^ defaultScorer, float lambda, UInt32 numActions)
+ {
+ this->defaultScorer = defaultScorer;
+ m_explorer = new NativeMultiWorldTesting::SoftmaxExplorer<NativeContext>(*GetNativeScorer(), lambda, (u32)numActions);
+ }
+
+ ~SoftmaxExplorer()
+ {
+ delete m_explorer;
+ }
+
+ internal:
+ virtual List<float>^ InvokeScorerCallback(Ctx context) override
+ {
+ return defaultScorer->ScoreActions(context);
+ }
+
+ NativeMultiWorldTesting::SoftmaxExplorer<NativeContext>* Get()
+ {
+ return m_explorer;
+ }
+
+ private:
+ IScorer<Ctx>^ defaultScorer;
+ NativeMultiWorldTesting::SoftmaxExplorer<NativeContext>* m_explorer;
+ };
+
+ /// <summary>
+ /// The generic exploration class.
+ /// </summary>
+ /// <remarks>
+ /// GenericExplorer provides complete flexibility. You can create any
+ /// distribution over actions desired, and it will draw from that.
+ /// </remarks>
+ /// <typeparam name="Ctx">The Context type.</typeparam>
+ generic <class Ctx>
+ public ref class GenericExplorer : public IExplorer<Ctx>, public ScorerCallback<Ctx>
+ {
+ public:
+ /// <summary>
+ /// The constructor is the only public member, because this should be used with the MwtExplorer.
+ /// </summary>
+ /// <param name="defaultScorer">A function which outputs the probability of each action.</param>
+ /// <param name="numActions">The number of actions to randomize over.</param>
+ GenericExplorer(IScorer<Ctx>^ defaultScorer, UInt32 numActions)
+ {
+ this->defaultScorer = defaultScorer;
+ m_explorer = new NativeMultiWorldTesting::GenericExplorer<NativeContext>(*GetNativeScorer(), (u32)numActions);
+ }
+
+ ~GenericExplorer()
+ {
+ delete m_explorer;
+ }
+
+ internal:
+ virtual List<float>^ InvokeScorerCallback(Ctx context) override
+ {
+ return defaultScorer->ScoreActions(context);
+ }
+
+ NativeMultiWorldTesting::GenericExplorer<NativeContext>* Get()
+ {
+ return m_explorer;
+ }
+
+ private:
+ IScorer<Ctx>^ defaultScorer;
+ NativeMultiWorldTesting::GenericExplorer<NativeContext>* m_explorer;
+ };
+
+ /// <summary>
+ /// The bootstrap exploration class.
+ /// </summary>
+ /// <remarks>
+ /// The Bootstrap explorer randomizes over the actions chosen by a set of
+ /// default policies. This performs well statistically but can be
+ /// computationally expensive.
+ /// </remarks>
+ /// <typeparam name="Ctx">The Context type.</typeparam>
+ generic <class Ctx>
+ public ref class BootstrapExplorer : public IExplorer<Ctx>, public PolicyCallback<Ctx>
+ {
+ public:
+ /// <summary>
+ /// The constructor is the only public member, because this should be used with the MwtExplorer.
+ /// </summary>
+ /// <param name="defaultPolicies">A set of default policies to be uniform random over.</param>
+ /// <param name="numActions">The number of actions to randomize over.</param>
+ BootstrapExplorer(cli::array<IPolicy<Ctx>^>^ defaultPolicies, UInt32 numActions)
+ {
+ this->defaultPolicies = defaultPolicies;
+ if (this->defaultPolicies == nullptr)
+ {
+ throw gcnew ArgumentNullException("The specified array of default policy functions cannot be null.");
+ }
+
+ m_explorer = new NativeMultiWorldTesting::BootstrapExplorer<NativeContext>(*GetNativePolicies((u32)defaultPolicies->Length), (u32)numActions);
+ }
+
+ ~BootstrapExplorer()
+ {
+ delete m_explorer;
+ }
+
+ internal:
+ virtual UInt32 InvokePolicyCallback(Ctx context, int index) override
+ {
+ if (index < 0 || index >= defaultPolicies->Length)
+ {
+ throw gcnew InvalidDataException("Internal error: Index of interop bag is out of range.");
+ }
+ return defaultPolicies[index]->ChooseAction(context);
+ }
+
+ NativeMultiWorldTesting::BootstrapExplorer<NativeContext>* Get()
+ {
+ return m_explorer;
+ }
+
+ private:
+ cli::array<IPolicy<Ctx>^>^ defaultPolicies;
+ NativeMultiWorldTesting::BootstrapExplorer<NativeContext>* m_explorer;
+ };
+
+ /// <summary>
+ /// The top level MwtExplorer class. Using this makes sure that the
+ /// right bits are recorded and good random actions are chosen.
+ /// </summary>
+ /// <typeparam name="Ctx">The Context type.</typeparam>
+ generic <class Ctx>
+ public ref class MwtExplorer : public RecorderCallback<Ctx>
+ {
+ public:
+ /// <summary>
+ /// Constructor.
+ /// </summary>
+ /// <param name="appId">This should be unique to each experiment to avoid correlation bugs.</param>
+ /// <param name="recorder">A user-specified class for recording the appropriate bits for use in evaluation and learning.</param>
+ MwtExplorer(String^ appId, IRecorder<Ctx>^ recorder)
+ {
+ this->appId = appId;
+ this->recorder = recorder;
+ }
+
+ /// <summary>
+ /// Choose_Action should be drop-in replacement for any existing policy function.
+ /// </summary>
+ /// <param name="explorer">An existing exploration algorithm (one of the above) which uses the default policy as a callback.</param>
+ /// <param name="unique_key">A unique identifier for the experimental unit. This could be a user id, a session id, etc...</param>
+ /// <param name="context">The context upon which a decision is made. See SimpleContext above for an example.</param>
+ /// <returns>An unsigned 32-bit integer representing the 1-based chosen action.</returns>
+ UInt32 ChooseAction(IExplorer<Ctx>^ explorer, String^ unique_key, Ctx context)
+ {
+ String^ salt = this->appId;
+ NativeMultiWorldTesting::MwtExplorer<NativeContext> mwt(marshal_as<std::string>(salt), *GetNativeRecorder());
+
+ GCHandle selfHandle = GCHandle::Alloc(this);
+ IntPtr selfPtr = (IntPtr)selfHandle;
+
+ GCHandle contextHandle = GCHandle::Alloc(context);
+ IntPtr contextPtr = (IntPtr)contextHandle;
+
+ GCHandle explorerHandle = GCHandle::Alloc(explorer);
+ IntPtr explorerPtr = (IntPtr)explorerHandle;
+
+ NativeContext native_context(selfPtr.ToPointer(), explorerPtr.ToPointer(), contextPtr.ToPointer());
+ u32 action = 0;
+ if (explorer->GetType() == EpsilonGreedyExplorer<Ctx>::typeid)
+ {
+ EpsilonGreedyExplorer<Ctx>^ epsilonGreedyExplorer = (EpsilonGreedyExplorer<Ctx>^)explorer;
+ action = mwt.Choose_Action(*epsilonGreedyExplorer->Get(), marshal_as<std::string>(unique_key), native_context);
+ }
+ else if (explorer->GetType() == TauFirstExplorer<Ctx>::typeid)
+ {
+ TauFirstExplorer<Ctx>^ tauFirstExplorer = (TauFirstExplorer<Ctx>^)explorer;
+ action = mwt.Choose_Action(*tauFirstExplorer->Get(), marshal_as<std::string>(unique_key), native_context);
+ }
+ else if (explorer->GetType() == SoftmaxExplorer<Ctx>::typeid)
+ {
+ SoftmaxExplorer<Ctx>^ softmaxExplorer = (SoftmaxExplorer<Ctx>^)explorer;
+ action = mwt.Choose_Action(*softmaxExplorer->Get(), marshal_as<std::string>(unique_key), native_context);
+ }
+ else if (explorer->GetType() == GenericExplorer<Ctx>::typeid)
+ {
+ GenericExplorer<Ctx>^ genericExplorer = (GenericExplorer<Ctx>^)explorer;
+ action = mwt.Choose_Action(*genericExplorer->Get(), marshal_as<std::string>(unique_key), native_context);
+ }
+ else if (explorer->GetType() == BootstrapExplorer<Ctx>::typeid)
+ {
+ BootstrapExplorer<Ctx>^ bootstrapExplorer = (BootstrapExplorer<Ctx>^)explorer;
+ action = mwt.Choose_Action(*bootstrapExplorer->Get(), marshal_as<std::string>(unique_key), native_context);
+ }
+
+ explorerHandle.Free();
+ contextHandle.Free();
+ selfHandle.Free();
+
+ return action;
+ }
+
+ internal:
+ virtual void InvokeRecorderCallback(Ctx context, UInt32 action, float probability, String^ unique_key) override
+ {
+ recorder->Record(context, action, probability, unique_key);
+ }
+
+ private:
+ IRecorder<Ctx>^ recorder;
+ String^ appId;
+ };
+
+ /// <summary>
+ /// Represents a feature in a sparse array.
+ /// </summary>
+ [StructLayout(LayoutKind::Sequential)]
+ public value struct Feature
+ {
+ float Value;
+ UInt32 Id;
+ };
+
+ /// <summary>
+ /// A sample recorder class that converts the exploration tuple into string format.
+ /// </summary>
+ /// <typeparam name="Ctx">The Context type.</typeparam>
+ generic <class Ctx> where Ctx : IStringContext
+ public ref class StringRecorder : public IRecorder<Ctx>, public ToStringCallback<Ctx>
+ {
+ public:
+ StringRecorder()
+ {
+ m_string_recorder = new NativeMultiWorldTesting::StringRecorder<NativeStringContext>();
+ }
+
+ ~StringRecorder()
+ {
+ delete m_string_recorder;
+ }
+
+ virtual void Record(Ctx context, UInt32 action, float probability, String^ uniqueKey)
+ {
+ GCHandle contextHandle = GCHandle::Alloc(context);
+ IntPtr contextPtr = (IntPtr)contextHandle;
+
+ NativeStringContext native_context(contextPtr.ToPointer(), GetCallback());
+ m_string_recorder->Record(native_context, (u32)action, probability, marshal_as<string>(uniqueKey));
+ }
+
+ /// <summary>
+ /// Gets the content of the recording so far as a string and clears internal content.
+ /// </summary>
+ /// <returns>
+ /// A string with recording content.
+ /// </returns>
+ String^ GetRecording()
+ {
+ // Workaround for C++-CLI bug which does not allow default value for parameter
+ return GetRecording(true);
+ }
+
+ /// <summary>
+ /// Gets the content of the recording so far as a string and optionally clears internal content.
+ /// </summary>
+ /// <param name="flush">A boolean value indicating whether to clear the internal content.</param>
+ /// <returns>
+ /// A string with recording content.
+ /// </returns>
+ String^ GetRecording(bool flush)
+ {
+ return gcnew String(m_string_recorder->Get_Recording(flush).c_str());
+ }
+
+ private:
+ NativeMultiWorldTesting::StringRecorder<NativeStringContext>* m_string_recorder;
+ };
+
+ /// <summary>
+ /// A sample context class that stores a vector of Features.
+ /// </summary>
+ public ref class SimpleContext : public IStringContext
+ {
+ public:
+ SimpleContext(cli::array<Feature>^ features)
+ {
+ Features = features;
+
+ // TODO: add another constructor overload for native SimpleContext to avoid copying feature values
+ m_features = new vector<NativeMultiWorldTesting::Feature>();
+ for (int i = 0; i < features->Length; i++)
+ {
+ m_features->push_back({ features[i].Value, features[i].Id });
+ }
+
+ m_native_context = new NativeMultiWorldTesting::SimpleContext(*m_features);
+ }
+
+ String^ ToString() override
+ {
+ return gcnew String(m_native_context->To_String().c_str());
+ }
+
+ ~SimpleContext()
+ {
+ delete m_native_context;
+ }
+
+ public:
+ cli::array<Feature>^ GetFeatures() { return Features; }
+
+ internal:
+ cli::array<Feature>^ Features;
+
+ private:
+ vector<NativeMultiWorldTesting::Feature>* m_features;
+ NativeMultiWorldTesting::SimpleContext* m_native_context;
+ };
+}
+
/*! @} End of Doxygen Groups*/ \ No newline at end of file
diff --git a/explore/clr/explore_interface.h b/explore/clr/explore_interface.h
index 3212b4d8..dfa17697 100644
--- a/explore/clr/explore_interface.h
+++ b/explore/clr/explore_interface.h
@@ -1,88 +1,88 @@
-#pragma once
-
-using namespace System;
-using namespace System::Collections::Generic;
-
-/** \defgroup MultiWorldTestingCsharp
-\brief C# implementation, for sample usage see: https://github.com/sidsen/vowpal_wabbit/blob/v0/cs_test/ExploreOnlySample.cs
-*/
-
-/*!
-* \addtogroup MultiWorldTestingCsharp
-* @{
-*/
-
-//! Interface for C# version of Multiworld Testing library.
-//! For sample usage see: https://github.com/sidsen/vowpal_wabbit/blob/v0/cs_test/ExploreOnlySample.cs
-namespace MultiWorldTesting {
-
-/// <summary>
-/// Represents a recorder that exposes a method to record exploration data based on generic contexts.
-/// </summary>
-/// <typeparam name="Ctx">The Context type.</typeparam>
-/// <remarks>
-/// Exploration data is specified as a set of tuples <context, action, probability, key> as described below. An
-/// application passes an IRecorder object to the @MwtExplorer constructor. See
-/// @StringRecorder for a sample IRecorder object.
-/// </remarks>
-generic <class Ctx>
-public interface class IRecorder
-{
-public:
- /// <summary>
- /// Records the exploration data associated with a given decision.
- /// </summary>
- /// <param name="context">A user-defined context for the decision.</param>
- /// <param name="action">Chosen by an exploration algorithm given context.</param>
- /// <param name="probability">The probability of the chosen action given context.</param>
- /// <param name="uniqueKey">A user-defined identifer for the decision.</param>
- virtual void Record(Ctx context, UInt32 action, float probability, String^ uniqueKey) = 0;
-};
-
-/// <summary>
-/// Exposes a method for choosing an action given a generic context. IPolicy objects are
-/// passed to (and invoked by) exploration algorithms to specify the default policy behavior.
-/// </summary>
-/// <typeparam name="Ctx">The Context type.</typeparam>
-generic <class Ctx>
-public interface class IPolicy
-{
-public:
- /// <summary>
- /// Determines the action to take for a given context.
- /// </summary>
- /// <param name="context">A user-defined context for the decision.</param>
- /// <returns>Index of the action to take (1-based)</returns>
- virtual UInt32 ChooseAction(Ctx context) = 0;
-};
-
-/// <summary>
-/// Exposes a method for specifying a score (weight) for each action given a generic context.
-/// </summary>
-/// <typeparam name="Ctx">The Context type.</typeparam>
-generic <class Ctx>
-public interface class IScorer
-{
-public:
- /// <summary>
- /// Determines the score of each action for a given context.
- /// </summary>
- /// <param name="context">A user-defined context for the decision.</param>
- /// <returns>Vector of scores indexed by action (1-based).</returns>
- virtual List<float>^ ScoreActions(Ctx context) = 0;
-};
-
-generic <class Ctx>
-public interface class IExplorer
-{
-};
-
-public interface class IStringContext
-{
-public:
- virtual String^ ToString() = 0;
-};
-
-}
-
+#pragma once
+
+using namespace System;
+using namespace System::Collections::Generic;
+
+/** \defgroup MultiWorldTestingCsharp
+\brief C# implementation, for sample usage see: https://github.com/sidsen/vowpal_wabbit/blob/v0/cs_test/ExploreOnlySample.cs
+*/
+
+/*!
+* \addtogroup MultiWorldTestingCsharp
+* @{
+*/
+
+//! Interface for C# version of Multiworld Testing library.
+//! For sample usage see: https://github.com/sidsen/vowpal_wabbit/blob/v0/cs_test/ExploreOnlySample.cs
+namespace MultiWorldTesting {
+
+/// <summary>
+/// Represents a recorder that exposes a method to record exploration data based on generic contexts.
+/// </summary>
+/// <typeparam name="Ctx">The Context type.</typeparam>
+/// <remarks>
+/// Exploration data is specified as a set of tuples <context, action, probability, key> as described below. An
+/// application passes an IRecorder object to the @MwtExplorer constructor. See
+/// @StringRecorder for a sample IRecorder object.
+/// </remarks>
+generic <class Ctx>
+public interface class IRecorder
+{
+public:
+ /// <summary>
+ /// Records the exploration data associated with a given decision.
+ /// </summary>
+ /// <param name="context">A user-defined context for the decision.</param>
+ /// <param name="action">Chosen by an exploration algorithm given context.</param>
+ /// <param name="probability">The probability of the chosen action given context.</param>
+ /// <param name="uniqueKey">A user-defined identifer for the decision.</param>
+ virtual void Record(Ctx context, UInt32 action, float probability, String^ uniqueKey) = 0;
+};
+
+/// <summary>
+/// Exposes a method for choosing an action given a generic context. IPolicy objects are
+/// passed to (and invoked by) exploration algorithms to specify the default policy behavior.
+/// </summary>
+/// <typeparam name="Ctx">The Context type.</typeparam>
+generic <class Ctx>
+public interface class IPolicy
+{
+public:
+ /// <summary>
+ /// Determines the action to take for a given context.
+ /// </summary>
+ /// <param name="context">A user-defined context for the decision.</param>
+ /// <returns>Index of the action to take (1-based)</returns>
+ virtual UInt32 ChooseAction(Ctx context) = 0;
+};
+
+/// <summary>
+/// Exposes a method for specifying a score (weight) for each action given a generic context.
+/// </summary>
+/// <typeparam name="Ctx">The Context type.</typeparam>
+generic <class Ctx>
+public interface class IScorer
+{
+public:
+ /// <summary>
+ /// Determines the score of each action for a given context.
+ /// </summary>
+ /// <param name="context">A user-defined context for the decision.</param>
+ /// <returns>Vector of scores indexed by action (1-based).</returns>
+ virtual List<float>^ ScoreActions(Ctx context) = 0;
+};
+
+generic <class Ctx>
+public interface class IExplorer
+{
+};
+
+public interface class IStringContext
+{
+public:
+ virtual String^ ToString() = 0;
+};
+
+}
+
/*! @} End of Doxygen Groups*/ \ No newline at end of file
diff --git a/explore/clr/explore_interop.h b/explore/clr/explore_interop.h
index 99466945..a6cb9327 100644
--- a/explore/clr/explore_interop.h
+++ b/explore/clr/explore_interop.h
@@ -1,358 +1,358 @@
-#pragma once
-
-#define MANAGED_CODE
-
-#include "explore_interface.h"
-#include "MWTExplorer.h"
-
-#include <msclr\marshal_cppstd.h>
-
-using namespace System;
-using namespace System::Collections::Generic;
-using namespace System::IO;
-using namespace System::Runtime::InteropServices;
-using namespace System::Xml::Serialization;
-using namespace msclr::interop;
-
-namespace MultiWorldTesting {
-
-// Policy callback
-private delegate UInt32 ClrPolicyCallback(IntPtr explorerPtr, IntPtr contextPtr, int index);
-typedef u32 Native_Policy_Callback(void* explorer, void* context, int index);
-
-// Scorer callback
-private delegate void ClrScorerCallback(IntPtr explorerPtr, IntPtr contextPtr, IntPtr scores, IntPtr size);
-typedef void Native_Scorer_Callback(void* explorer, void* context, float* scores[], u32* size);
-
-// Recorder callback
-private delegate void ClrRecorderCallback(IntPtr mwtPtr, IntPtr contextPtr, UInt32 action, float probability, IntPtr uniqueKey);
-typedef void Native_Recorder_Callback(void* mwt, void* context, u32 action, float probability, void* unique_key);
-
-// ToString callback
-private delegate void ClrToStringCallback(IntPtr contextPtr, IntPtr stringValue);
-typedef void Native_To_String_Callback(void* explorer, void* string_value);
-
-// NativeContext travels through interop space and contains instances of Mwt, Explorer, Context
-// used for triggering callback for Policy, Scorer, Recorder
-class NativeContext
-{
-public:
- NativeContext(void* clr_mwt, void* clr_explorer, void* clr_context)
- {
- m_clr_mwt = clr_mwt;
- m_clr_explorer = clr_explorer;
- m_clr_context = clr_context;
- }
-
- void* Get_Clr_Mwt()
- {
- return m_clr_mwt;
- }
-
- void* Get_Clr_Context()
- {
- return m_clr_context;
- }
-
- void* Get_Clr_Explorer()
- {
- return m_clr_explorer;
- }
-
-private:
- void* m_clr_mwt;
- void* m_clr_context;
- void* m_clr_explorer;
-};
-
-class NativeStringContext
-{
-public:
- NativeStringContext(void* clr_context, Native_To_String_Callback* func)
- {
- m_clr_context = clr_context;
- m_func = func;
- }
-
- string To_String()
- {
- string value;
- m_func(m_clr_context, &value);
- return value;
- }
-private:
- void* m_clr_context;
- Native_To_String_Callback* m_func;
-};
-
-// NativeRecorder listens to callback event and reroute it to the managed Recorder instance
-class NativeRecorder : public NativeMultiWorldTesting::IRecorder<NativeContext>
-{
-public:
- NativeRecorder(Native_Recorder_Callback* native_func)
- {
- m_func = native_func;
- }
-
- void Record(NativeContext& context, u32 action, float probability, string unique_key)
- {
- GCHandle uniqueKeyHandle = GCHandle::Alloc(gcnew String(unique_key.c_str()));
- IntPtr uniqueKeyPtr = (IntPtr)uniqueKeyHandle;
-
- m_func(context.Get_Clr_Mwt(), context.Get_Clr_Context(), action, probability, uniqueKeyPtr.ToPointer());
-
- uniqueKeyHandle.Free();
- }
-private:
- Native_Recorder_Callback* m_func;
-};
-
-// NativePolicy listens to callback event and reroute it to the managed Policy instance
-class NativePolicy : public NativeMultiWorldTesting::IPolicy<NativeContext>
-{
-public:
- NativePolicy(Native_Policy_Callback* func, int index = -1)
- {
- m_func = func;
- m_index = index;
- }
-
- u32 Choose_Action(NativeContext& context)
- {
- return m_func(context.Get_Clr_Explorer(), context.Get_Clr_Context(), m_index);
- }
-
-private:
- Native_Policy_Callback* m_func;
- int m_index;
-};
-
-class NativeScorer : public NativeMultiWorldTesting::IScorer<NativeContext>
-{
-public:
- NativeScorer(Native_Scorer_Callback* func)
- {
- m_func = func;
- }
-
- vector<float> Score_Actions(NativeContext& context)
- {
- float* scores = nullptr;
- u32 num_scores = 0;
- m_func(context.Get_Clr_Explorer(), context.Get_Clr_Context(), &scores, &num_scores);
-
- // It's ok if scores is null, vector will be empty
- vector<float> scores_vector(scores, scores + num_scores);
- delete[] scores;
-
- return scores_vector;
- }
-private:
- Native_Scorer_Callback* m_func;
-};
-
-// Triggers callback to the Policy instance to choose an action
-generic <class Ctx>
-public ref class PolicyCallback abstract
-{
-internal:
- virtual UInt32 InvokePolicyCallback(Ctx context, int index) = 0;
-
- PolicyCallback()
- {
- policyCallback = gcnew ClrPolicyCallback(&PolicyCallback<Ctx>::InteropInvoke);
- IntPtr policyCallbackPtr = Marshal::GetFunctionPointerForDelegate(policyCallback);
- m_callback = static_cast<Native_Policy_Callback*>(policyCallbackPtr.ToPointer());
- m_native_policy = nullptr;
- m_native_policies = nullptr;
- }
-
- ~PolicyCallback()
- {
- delete m_native_policy;
- delete m_native_policies;
- }
-
- NativePolicy* GetNativePolicy()
- {
- if (m_native_policy == nullptr)
- {
- m_native_policy = new NativePolicy(m_callback);
- }
- return m_native_policy;
- }
-
- vector<unique_ptr<NativeMultiWorldTesting::IPolicy<NativeContext>>>* GetNativePolicies(int count)
- {
- if (m_native_policies == nullptr)
- {
- m_native_policies = new vector<unique_ptr<NativeMultiWorldTesting::IPolicy<NativeContext>>>();
- for (int i = 0; i < count; i++)
- {
- m_native_policies->push_back(unique_ptr<NativeMultiWorldTesting::IPolicy<NativeContext>>(new NativePolicy(m_callback, i)));
- }
- }
-
- return m_native_policies;
- }
-
- static UInt32 InteropInvoke(IntPtr callbackPtr, IntPtr contextPtr, int index)
- {
- GCHandle callbackHandle = (GCHandle)callbackPtr;
- PolicyCallback<Ctx>^ callback = (PolicyCallback<Ctx>^)callbackHandle.Target;
-
- GCHandle contextHandle = (GCHandle)contextPtr;
- Ctx context = (Ctx)contextHandle.Target;
-
- return callback->InvokePolicyCallback(context, index);
- }
-
-private:
- ClrPolicyCallback^ policyCallback;
-
-private:
- NativePolicy* m_native_policy;
- vector<unique_ptr<NativeMultiWorldTesting::IPolicy<NativeContext>>>* m_native_policies;
- Native_Policy_Callback* m_callback;
-};
-
-// Triggers callback to the Recorder instance to record interaction data
-generic <class Ctx>
-public ref class RecorderCallback abstract
-{
-internal:
- virtual void InvokeRecorderCallback(Ctx context, UInt32 action, float probability, String^ unique_key) = 0;
-
- RecorderCallback()
- {
- recorderCallback = gcnew ClrRecorderCallback(&RecorderCallback<Ctx>::InteropInvoke);
- IntPtr recorderCallbackPtr = Marshal::GetFunctionPointerForDelegate(recorderCallback);
- Native_Recorder_Callback* callback = static_cast<Native_Recorder_Callback*>(recorderCallbackPtr.ToPointer());
- m_native_recorder = new NativeRecorder(callback);
- }
-
- ~RecorderCallback()
- {
- delete m_native_recorder;
- }
-
- NativeRecorder* GetNativeRecorder()
- {
- return m_native_recorder;
- }
-
- static void InteropInvoke(IntPtr mwtPtr, IntPtr contextPtr, UInt32 action, float probability, IntPtr uniqueKeyPtr)
- {
- GCHandle mwtHandle = (GCHandle)mwtPtr;
- RecorderCallback<Ctx>^ callback = (RecorderCallback<Ctx>^)mwtHandle.Target;
-
- GCHandle contextHandle = (GCHandle)contextPtr;
- Ctx context = (Ctx)contextHandle.Target;
-
- GCHandle uniqueKeyHandle = (GCHandle)uniqueKeyPtr;
- String^ uniqueKey = (String^)uniqueKeyHandle.Target;
-
- callback->InvokeRecorderCallback(context, action, probability, uniqueKey);
- }
-
-private:
- ClrRecorderCallback^ recorderCallback;
-
-private:
- NativeRecorder* m_native_recorder;
-};
-
-// Triggers callback to the Recorder instance to record interaction data
-generic <class Ctx>
-public ref class ScorerCallback abstract
-{
-internal:
- virtual List<float>^ InvokeScorerCallback(Ctx context) = 0;
-
- ScorerCallback()
- {
- scorerCallback = gcnew ClrScorerCallback(&ScorerCallback<Ctx>::InteropInvoke);
- IntPtr scorerCallbackPtr = Marshal::GetFunctionPointerForDelegate(scorerCallback);
- Native_Scorer_Callback* callback = static_cast<Native_Scorer_Callback*>(scorerCallbackPtr.ToPointer());
- m_native_scorer = new NativeScorer(callback);
- }
-
- ~ScorerCallback()
- {
- delete m_native_scorer;
- }
-
- NativeScorer* GetNativeScorer()
- {
- return m_native_scorer;
- }
-
- static void InteropInvoke(IntPtr callbackPtr, IntPtr contextPtr, IntPtr scoresPtr, IntPtr sizePtr)
- {
- GCHandle callbackHandle = (GCHandle)callbackPtr;
- ScorerCallback<Ctx>^ callback = (ScorerCallback<Ctx>^)callbackHandle.Target;
-
- GCHandle contextHandle = (GCHandle)contextPtr;
- Ctx context = (Ctx)contextHandle.Target;
-
- List<float>^ scoreList = callback->InvokeScorerCallback(context);
-
- if (scoreList == nullptr || scoreList->Count == 0)
- {
- return;
- }
-
- u32* num_scores = (u32*)sizePtr.ToPointer();
- *num_scores = (u32)scoreList->Count;
-
- float* scores = new float[*num_scores];
- for (u32 i = 0; i < *num_scores; i++)
- {
- scores[i] = scoreList[i];
- }
-
- float** native_scores = (float**)scoresPtr.ToPointer();
- *native_scores = scores;
- }
-
-private:
- ClrScorerCallback^ scorerCallback;
-
-private:
- NativeScorer* m_native_scorer;
-};
-
-// Triggers callback to the Context instance to perform ToString() operation
-generic <class Ctx> where Ctx : IStringContext
-public ref class ToStringCallback
-{
-internal:
- ToStringCallback()
- {
- toStringCallback = gcnew ClrToStringCallback(&ToStringCallback<Ctx>::InteropInvoke);
- IntPtr toStringCallbackPtr = Marshal::GetFunctionPointerForDelegate(toStringCallback);
- m_callback = static_cast<Native_To_String_Callback*>(toStringCallbackPtr.ToPointer());
- }
-
- Native_To_String_Callback* GetCallback()
- {
- return m_callback;
- }
-
- static void InteropInvoke(IntPtr contextPtr, IntPtr stringPtr)
- {
- GCHandle contextHandle = (GCHandle)contextPtr;
- Ctx context = (Ctx)contextHandle.Target;
-
- string* out_string = (string*)stringPtr.ToPointer();
- *out_string = marshal_as<string>(context->ToString());
- }
-
-private:
- ClrToStringCallback^ toStringCallback;
-
-private:
- Native_To_String_Callback* m_callback;
-};
-
+#pragma once
+
+#define MANAGED_CODE
+
+#include "explore_interface.h"
+#include "MWTExplorer.h"
+
+#include <msclr\marshal_cppstd.h>
+
+using namespace System;
+using namespace System::Collections::Generic;
+using namespace System::IO;
+using namespace System::Runtime::InteropServices;
+using namespace System::Xml::Serialization;
+using namespace msclr::interop;
+
+namespace MultiWorldTesting {
+
+// Policy callback
+private delegate UInt32 ClrPolicyCallback(IntPtr explorerPtr, IntPtr contextPtr, int index);
+typedef u32 Native_Policy_Callback(void* explorer, void* context, int index);
+
+// Scorer callback
+private delegate void ClrScorerCallback(IntPtr explorerPtr, IntPtr contextPtr, IntPtr scores, IntPtr size);
+typedef void Native_Scorer_Callback(void* explorer, void* context, float* scores[], u32* size);
+
+// Recorder callback
+private delegate void ClrRecorderCallback(IntPtr mwtPtr, IntPtr contextPtr, UInt32 action, float probability, IntPtr uniqueKey);
+typedef void Native_Recorder_Callback(void* mwt, void* context, u32 action, float probability, void* unique_key);
+
+// ToString callback
+private delegate void ClrToStringCallback(IntPtr contextPtr, IntPtr stringValue);
+typedef void Native_To_String_Callback(void* explorer, void* string_value);
+
+// NativeContext travels through interop space and contains instances of Mwt, Explorer, Context
+// used for triggering callback for Policy, Scorer, Recorder
+class NativeContext
+{
+public:
+ NativeContext(void* clr_mwt, void* clr_explorer, void* clr_context)
+ {
+ m_clr_mwt = clr_mwt;
+ m_clr_explorer = clr_explorer;
+ m_clr_context = clr_context;
+ }
+
+ void* Get_Clr_Mwt()
+ {
+ return m_clr_mwt;
+ }
+
+ void* Get_Clr_Context()
+ {
+ return m_clr_context;
+ }
+
+ void* Get_Clr_Explorer()
+ {
+ return m_clr_explorer;
+ }
+
+private:
+ void* m_clr_mwt;
+ void* m_clr_context;
+ void* m_clr_explorer;
+};
+
+class NativeStringContext
+{
+public:
+ NativeStringContext(void* clr_context, Native_To_String_Callback* func)
+ {
+ m_clr_context = clr_context;
+ m_func = func;
+ }
+
+ string To_String()
+ {
+ string value;
+ m_func(m_clr_context, &value);
+ return value;
+ }
+private:
+ void* m_clr_context;
+ Native_To_String_Callback* m_func;
+};
+
+// NativeRecorder listens to callback event and reroute it to the managed Recorder instance
+class NativeRecorder : public NativeMultiWorldTesting::IRecorder<NativeContext>
+{
+public:
+ NativeRecorder(Native_Recorder_Callback* native_func)
+ {
+ m_func = native_func;
+ }
+
+ void Record(NativeContext& context, u32 action, float probability, string unique_key)
+ {
+ GCHandle uniqueKeyHandle = GCHandle::Alloc(gcnew String(unique_key.c_str()));
+ IntPtr uniqueKeyPtr = (IntPtr)uniqueKeyHandle;
+
+ m_func(context.Get_Clr_Mwt(), context.Get_Clr_Context(), action, probability, uniqueKeyPtr.ToPointer());
+
+ uniqueKeyHandle.Free();
+ }
+private:
+ Native_Recorder_Callback* m_func;
+};
+
+// NativePolicy listens to callback event and reroute it to the managed Policy instance
+class NativePolicy : public NativeMultiWorldTesting::IPolicy<NativeContext>
+{
+public:
+ NativePolicy(Native_Policy_Callback* func, int index = -1)
+ {
+ m_func = func;
+ m_index = index;
+ }
+
+ u32 Choose_Action(NativeContext& context)
+ {
+ return m_func(context.Get_Clr_Explorer(), context.Get_Clr_Context(), m_index);
+ }
+
+private:
+ Native_Policy_Callback* m_func;
+ int m_index;
+};
+
+class NativeScorer : public NativeMultiWorldTesting::IScorer<NativeContext>
+{
+public:
+ NativeScorer(Native_Scorer_Callback* func)
+ {
+ m_func = func;
+ }
+
+ vector<float> Score_Actions(NativeContext& context)
+ {
+ float* scores = nullptr;
+ u32 num_scores = 0;
+ m_func(context.Get_Clr_Explorer(), context.Get_Clr_Context(), &scores, &num_scores);
+
+ // It's ok if scores is null, vector will be empty
+ vector<float> scores_vector(scores, scores + num_scores);
+ delete[] scores;
+
+ return scores_vector;
+ }
+private:
+ Native_Scorer_Callback* m_func;
+};
+
+// Triggers callback to the Policy instance to choose an action
+generic <class Ctx>
+public ref class PolicyCallback abstract
+{
+internal:
+ virtual UInt32 InvokePolicyCallback(Ctx context, int index) = 0;
+
+ PolicyCallback()
+ {
+ policyCallback = gcnew ClrPolicyCallback(&PolicyCallback<Ctx>::InteropInvoke);
+ IntPtr policyCallbackPtr = Marshal::GetFunctionPointerForDelegate(policyCallback);
+ m_callback = static_cast<Native_Policy_Callback*>(policyCallbackPtr.ToPointer());
+ m_native_policy = nullptr;
+ m_native_policies = nullptr;
+ }
+
+ ~PolicyCallback()
+ {
+ delete m_native_policy;
+ delete m_native_policies;
+ }
+
+ NativePolicy* GetNativePolicy()
+ {
+ if (m_native_policy == nullptr)
+ {
+ m_native_policy = new NativePolicy(m_callback);
+ }
+ return m_native_policy;
+ }
+
+ vector<unique_ptr<NativeMultiWorldTesting::IPolicy<NativeContext>>>* GetNativePolicies(int count)
+ {
+ if (m_native_policies == nullptr)
+ {
+ m_native_policies = new vector<unique_ptr<NativeMultiWorldTesting::IPolicy<NativeContext>>>();
+ for (int i = 0; i < count; i++)
+ {
+ m_native_policies->push_back(unique_ptr<NativeMultiWorldTesting::IPolicy<NativeContext>>(new NativePolicy(m_callback, i)));
+ }
+ }
+
+ return m_native_policies;
+ }
+
+ static UInt32 InteropInvoke(IntPtr callbackPtr, IntPtr contextPtr, int index)
+ {
+ GCHandle callbackHandle = (GCHandle)callbackPtr;
+ PolicyCallback<Ctx>^ callback = (PolicyCallback<Ctx>^)callbackHandle.Target;
+
+ GCHandle contextHandle = (GCHandle)contextPtr;
+ Ctx context = (Ctx)contextHandle.Target;
+
+ return callback->InvokePolicyCallback(context, index);
+ }
+
+private:
+ ClrPolicyCallback^ policyCallback;
+
+private:
+ NativePolicy* m_native_policy;
+ vector<unique_ptr<NativeMultiWorldTesting::IPolicy<NativeContext>>>* m_native_policies;
+ Native_Policy_Callback* m_callback;
+};
+
+// Triggers callback to the Recorder instance to record interaction data
+generic <class Ctx>
+public ref class RecorderCallback abstract
+{
+internal:
+ virtual void InvokeRecorderCallback(Ctx context, UInt32 action, float probability, String^ unique_key) = 0;
+
+ RecorderCallback()
+ {
+ recorderCallback = gcnew ClrRecorderCallback(&RecorderCallback<Ctx>::InteropInvoke);
+ IntPtr recorderCallbackPtr = Marshal::GetFunctionPointerForDelegate(recorderCallback);
+ Native_Recorder_Callback* callback = static_cast<Native_Recorder_Callback*>(recorderCallbackPtr.ToPointer());
+ m_native_recorder = new NativeRecorder(callback);
+ }
+
+ ~RecorderCallback()
+ {
+ delete m_native_recorder;
+ }
+
+ NativeRecorder* GetNativeRecorder()
+ {
+ return m_native_recorder;
+ }
+
+ static void InteropInvoke(IntPtr mwtPtr, IntPtr contextPtr, UInt32 action, float probability, IntPtr uniqueKeyPtr)
+ {
+ GCHandle mwtHandle = (GCHandle)mwtPtr;
+ RecorderCallback<Ctx>^ callback = (RecorderCallback<Ctx>^)mwtHandle.Target;
+
+ GCHandle contextHandle = (GCHandle)contextPtr;
+ Ctx context = (Ctx)contextHandle.Target;
+
+ GCHandle uniqueKeyHandle = (GCHandle)uniqueKeyPtr;
+ String^ uniqueKey = (String^)uniqueKeyHandle.Target;
+
+ callback->InvokeRecorderCallback(context, action, probability, uniqueKey);
+ }
+
+private:
+ ClrRecorderCallback^ recorderCallback;
+
+private:
+ NativeRecorder* m_native_recorder;
+};
+
+// Triggers callback to the Recorder instance to record interaction data
+generic <class Ctx>
+public ref class ScorerCallback abstract
+{
+internal:
+ virtual List<float>^ InvokeScorerCallback(Ctx context) = 0;
+
+ ScorerCallback()
+ {
+ scorerCallback = gcnew ClrScorerCallback(&ScorerCallback<Ctx>::InteropInvoke);
+ IntPtr scorerCallbackPtr = Marshal::GetFunctionPointerForDelegate(scorerCallback);
+ Native_Scorer_Callback* callback = static_cast<Native_Scorer_Callback*>(scorerCallbackPtr.ToPointer());
+ m_native_scorer = new NativeScorer(callback);
+ }
+
+ ~ScorerCallback()
+ {
+ delete m_native_scorer;
+ }
+
+ NativeScorer* GetNativeScorer()
+ {
+ return m_native_scorer;
+ }
+
+ static void InteropInvoke(IntPtr callbackPtr, IntPtr contextPtr, IntPtr scoresPtr, IntPtr sizePtr)
+ {
+ GCHandle callbackHandle = (GCHandle)callbackPtr;
+ ScorerCallback<Ctx>^ callback = (ScorerCallback<Ctx>^)callbackHandle.Target;
+
+ GCHandle contextHandle = (GCHandle)contextPtr;
+ Ctx context = (Ctx)contextHandle.Target;
+
+ List<float>^ scoreList = callback->InvokeScorerCallback(context);
+
+ if (scoreList == nullptr || scoreList->Count == 0)
+ {
+ return;
+ }
+
+ u32* num_scores = (u32*)sizePtr.ToPointer();
+ *num_scores = (u32)scoreList->Count;
+
+ float* scores = new float[*num_scores];
+ for (u32 i = 0; i < *num_scores; i++)
+ {
+ scores[i] = scoreList[i];
+ }
+
+ float** native_scores = (float**)scoresPtr.ToPointer();
+ *native_scores = scores;
+ }
+
+private:
+ ClrScorerCallback^ scorerCallback;
+
+private:
+ NativeScorer* m_native_scorer;
+};
+
+// Triggers callback to the Context instance to perform ToString() operation
+generic <class Ctx> where Ctx : IStringContext
+public ref class ToStringCallback
+{
+internal:
+ ToStringCallback()
+ {
+ toStringCallback = gcnew ClrToStringCallback(&ToStringCallback<Ctx>::InteropInvoke);
+ IntPtr toStringCallbackPtr = Marshal::GetFunctionPointerForDelegate(toStringCallback);
+ m_callback = static_cast<Native_To_String_Callback*>(toStringCallbackPtr.ToPointer());
+ }
+
+ Native_To_String_Callback* GetCallback()
+ {
+ return m_callback;
+ }
+
+ static void InteropInvoke(IntPtr contextPtr, IntPtr stringPtr)
+ {
+ GCHandle contextHandle = (GCHandle)contextPtr;
+ Ctx context = (Ctx)contextHandle.Target;
+
+ string* out_string = (string*)stringPtr.ToPointer();
+ *out_string = marshal_as<string>(context->ToString());
+ }
+
+private:
+ ClrToStringCallback^ toStringCallback;
+
+private:
+ Native_To_String_Callback* m_callback;
+};
+
} \ No newline at end of file
diff --git a/explore/explore.cpp b/explore/explore.cpp
index ebc1c14e..f44ac353 100644
--- a/explore/explore.cpp
+++ b/explore/explore.cpp
@@ -1,73 +1,73 @@
-// explore.cpp : Timing code to measure performance of MWT Explorer library
-
-#include "MWTExplorer.h"
-#include <chrono>
-#include <tuple>
-#include <iostream>
-
-using namespace std;
-using namespace std::chrono;
-
-using namespace MultiWorldTesting;
-
-class MySimplePolicy : public IPolicy<SimpleContext>
-{
-public:
- u32 Choose_Action(SimpleContext& context)
- {
- return (u32)1;
- }
-};
-
-const u32 num_actions = 10;
-
-void Clock_Explore()
-{
- float epsilon = .2f;
- string unique_key = "key";
- int num_features = 1000;
- int num_iter = 10000;
- int num_warmup = 100;
- int num_interactions = 1;
-
- // pre-create features
- vector<Feature> features;
- for (int i = 0; i < num_features; i++)
- {
- Feature f = {0.5, i+1};
- features.push_back(f);
- }
-
- long long time_init = 0, time_choose = 0;
- for (int iter = 0; iter < num_iter + num_warmup; iter++)
- {
- high_resolution_clock::time_point t1 = high_resolution_clock::now();
- StringRecorder<SimpleContext> recorder;
- MwtExplorer<SimpleContext> mwt("test", recorder);
- MySimplePolicy default_policy;
- EpsilonGreedyExplorer<SimpleContext> explorer(default_policy, epsilon, num_actions);
- high_resolution_clock::time_point t2 = high_resolution_clock::now();
- time_init += iter < num_warmup ? 0 : duration_cast<chrono::microseconds>(t2 - t1).count();
-
- t1 = high_resolution_clock::now();
- SimpleContext appContext(features);
- for (int i = 0; i < num_interactions; i++)
- {
- mwt.Choose_Action(explorer, unique_key, appContext);
- }
- t2 = high_resolution_clock::now();
- time_choose += iter < num_warmup ? 0 : duration_cast<chrono::microseconds>(t2 - t1).count();
- }
-
- cout << "# iterations: " << num_iter << ", # interactions: " << num_interactions << ", # context features: " << num_features << endl;
- cout << "--- PER ITERATION ---" << endl;
- cout << "Init: " << (double)time_init / num_iter << " micro" << endl;
- cout << "Choose Action: " << (double)time_choose / (num_iter * num_interactions) << " micro" << endl;
- cout << "--- TOTAL TIME ---: " << (time_init + time_choose) << " micro" << endl;
-}
-
-int main(int argc, char* argv[])
-{
- Clock_Explore();
- return 0;
-}
+// explore.cpp : Timing code to measure performance of MWT Explorer library
+
+#include "MWTExplorer.h"
+#include <chrono>
+#include <tuple>
+#include <iostream>
+
+using namespace std;
+using namespace std::chrono;
+
+using namespace MultiWorldTesting;
+
+class MySimplePolicy : public IPolicy<SimpleContext>
+{
+public:
+ u32 Choose_Action(SimpleContext& context)
+ {
+ return (u32)1;
+ }
+};
+
+const u32 num_actions = 10;
+
+void Clock_Explore()
+{
+ float epsilon = .2f;
+ string unique_key = "key";
+ int num_features = 1000;
+ int num_iter = 10000;
+ int num_warmup = 100;
+ int num_interactions = 1;
+
+ // pre-create features
+ vector<Feature> features;
+ for (int i = 0; i < num_features; i++)
+ {
+ Feature f = {0.5, i+1};
+ features.push_back(f);
+ }
+
+ long long time_init = 0, time_choose = 0;
+ for (int iter = 0; iter < num_iter + num_warmup; iter++)
+ {
+ high_resolution_clock::time_point t1 = high_resolution_clock::now();
+ StringRecorder<SimpleContext> recorder;
+ MwtExplorer<SimpleContext> mwt("test", recorder);
+ MySimplePolicy default_policy;
+ EpsilonGreedyExplorer<SimpleContext> explorer(default_policy, epsilon, num_actions);
+ high_resolution_clock::time_point t2 = high_resolution_clock::now();
+ time_init += iter < num_warmup ? 0 : duration_cast<chrono::microseconds>(t2 - t1).count();
+
+ t1 = high_resolution_clock::now();
+ SimpleContext appContext(features);
+ for (int i = 0; i < num_interactions; i++)
+ {
+ mwt.Choose_Action(explorer, unique_key, appContext);
+ }
+ t2 = high_resolution_clock::now();
+ time_choose += iter < num_warmup ? 0 : duration_cast<chrono::microseconds>(t2 - t1).count();
+ }
+
+ cout << "# iterations: " << num_iter << ", # interactions: " << num_interactions << ", # context features: " << num_features << endl;
+ cout << "--- PER ITERATION ---" << endl;
+ cout << "Init: " << (double)time_init / num_iter << " micro" << endl;
+ cout << "Choose Action: " << (double)time_choose / (num_iter * num_interactions) << " micro" << endl;
+ cout << "--- TOTAL TIME ---: " << (time_init + time_choose) << " micro" << endl;
+}
+
+int main(int argc, char* argv[])
+{
+ Clock_Explore();
+ return 0;
+}
diff --git a/explore/static/MWTExplorer.h b/explore/static/MWTExplorer.h
index dd5e3044..4c7009b2 100644
--- a/explore/static/MWTExplorer.h
+++ b/explore/static/MWTExplorer.h
@@ -1,669 +1,669 @@
-//
-// Main interface for clients of the Multiworld testing (MWT) service.
-//
-
-#pragma once
-
-#include <stdexcept>
-#include <float.h>
-#include <math.h>
-#include <stdio.h>
-#include <string.h>
-#include <vector>
-#include <utility>
-#include <memory>
-#include <limits.h>
-#include <tuple>
-
-#ifdef MANAGED_CODE
-#define PORTING_INTERFACE public
-#define MWT_NAMESPACE namespace NativeMultiWorldTesting
-#else
-#define PORTING_INTERFACE private
-#define MWT_NAMESPACE namespace MultiWorldTesting
-#endif
-
-using namespace std;
-
-#include "utility.h"
-
-/** \defgroup MultiWorldTestingCpp
-\brief C++ implementation, for sample usage see: https://github.com/sidsen/vowpal_wabbit/blob/v0/explore/explore_sample.cpp
-*/
-
-/*!
-* \addtogroup MultiWorldTestingCpp
-* @{
-*/
-
-//! Interface for C++ version of Multiworld Testing library.
-//! For sample usage see: https://github.com/sidsen/vowpal_wabbit/blob/v0/explore/explore_sample.cpp
-MWT_NAMESPACE {
-
-// Forward declarations
-template <class Ctx>
-class IRecorder;
-template <class Ctx>
-class IExplorer;
-
-///
-/// The top-level MwtExplorer class. Using this enables principled and efficient exploration
-/// over a set of possible actions, and ensures that the right bits are recorded.
-///
-template <class Ctx>
-class MwtExplorer
-{
-public:
- ///
- /// Constructor
- ///
- /// @param appid This should be unique to your experiment or you risk nasty correlation bugs.
- /// @param recorder A user-specified class for recording the appropriate bits for use in evaluation and learning.
- ///
- MwtExplorer(std::string app_id, IRecorder<Ctx>& recorder) : m_recorder(recorder)
- {
- m_app_id = HashUtils::Compute_Id_Hash(app_id);
- }
-
- ///
- /// Chooses an action by invoking an underlying exploration algorithm. This should be a
- /// drop-in replacement for any existing policy function.
- ///
- /// @param explorer An existing exploration algorithm (one of the below) which uses the default policy as a callback.
- /// @param unique_key A unique identifier for the experimental unit. This could be a user id, a session id, etc..
- /// @param context The context upon which a decision is made. See SimpleContext below for an example.
- ///
- u32 Choose_Action(IExplorer<Ctx>& explorer, string unique_key, Ctx& context)
- {
- u64 seed = HashUtils::Compute_Id_Hash(unique_key);
-
- std::tuple<u32, float, bool> action_probability_log_tuple = explorer.Choose_Action(seed + m_app_id, context);
-
- u32 action = std::get<0>(action_probability_log_tuple);
- float prob = std::get<1>(action_probability_log_tuple);
-
- if (std::get<2>(action_probability_log_tuple))
- {
- m_recorder.Record(context, action, prob, unique_key);
- }
-
- return action;
- }
-
-private:
- u64 m_app_id;
- IRecorder<Ctx>& m_recorder;
-};
-
-///
-/// Exposes a method to record exploration data based on generic contexts. Exploration data
-/// is specified as a set of tuples <context, action, probability, key> as described below. An
-/// application passes an IRecorder object to the @MwtExplorer constructor. See
-/// @StringRecorder for a sample IRecorder object.
-///
-template <class Ctx>
-class IRecorder
-{
-public:
- ///
- /// Records the exploration data associated with a given decision.
- ///
- /// @param context A user-defined context for the decision
- /// @param action The action chosen by an exploration algorithm given context
- /// @param probability The probability the exploration algorithm chose said action
- /// @param unique_key A user-defined unique identifer for the decision
- ///
- virtual void Record(Ctx& context, u32 action, float probability, string unique_key) = 0;
-};
-
-///
-/// Exposes a method to choose an action given a generic context, and obtain the relevant
-/// exploration bits. Invokes IPolicy::Choose_Action internally. Do not implement this
-/// interface yourself: instead, use the various exploration algorithms below, which
-/// implement it for you.
-///
-template <class Ctx>
-class IExplorer
-{
-public:
- ///
- /// Determines the action to take and the probability with which it was chosen, for a
- /// given context.
- ///
- /// @param salted_seed A PRG seed based on a unique id information provided by the user
- /// @param context A user-defined context for the decision
- /// @returns The action to take, the probability it was chosen, and a flag indicating
- /// whether to record this decision
- ///
- virtual std::tuple<u32, float, bool> Choose_Action(u64 salted_seed, Ctx& context) = 0;
-};
-
-///
-/// Exposes a method to choose an action given a generic context. IPolicy objects are
-/// passed to (and invoked by) exploration algorithms to specify the default policy behavior.
-///
-template <class Ctx>
-class IPolicy
-{
-public:
- ///
- /// Determines the action to take for a given context.
- ///
- /// @param context A user-defined context for the decision
- /// @returns The action to take (1-based index)
- ///
- virtual u32 Choose_Action(Ctx& context) = 0;
-};
-
-///
-/// Exposes a method for specifying a score (weight) for each action given a generic context.
-///
-template <class Ctx>
-class IScorer
-{
-public:
- ///
- /// Determines the score of each action for a given context.
- ///
- /// @param context A user-defined context for the decision
- /// @returns A vector of scores indexed by action (1-based)
- ///
- virtual vector<float> Score_Actions(Ctx& context) = 0;
-};
-
-///
-/// A sample recorder class that converts the exploration tuple into string format.
-///
-template <class Ctx>
-struct StringRecorder : public IRecorder<Ctx>
-{
- void Record(Ctx& context, u32 action, float probability, string unique_key)
- {
- // Implicitly enforce To_String() API on the context
- m_recording.append(to_string((unsigned long)action));
- m_recording.append(" ", 1);
- m_recording.append(unique_key);
- m_recording.append(" ", 1);
-
- char prob_str[10] = { 0 };
- NumberUtils::Float_To_String(probability, prob_str);
- m_recording.append(prob_str);
-
- m_recording.append(" | ", 3);
- m_recording.append(context.To_String());
- m_recording.append("\n");
- }
-
- // Gets the content of the recording so far as a string and optionally clears internal content.
- string Get_Recording(bool flush = true)
- {
- if (!flush)
- {
- return m_recording;
- }
- string recording = m_recording;
- m_recording.clear();
- return recording;
- }
-
-private:
- string m_recording;
-};
-
-///
-/// Represents a feature in a sparse array.
-///
-struct Feature
-{
- float Value;
- u32 Id;
-
- bool operator==(Feature other_feature)
- {
- return Id == other_feature.Id;
- }
-};
-
-///
-/// A sample context class that stores a vector of Features.
-///
-class SimpleContext
-{
-public:
- SimpleContext(vector<Feature>& features) :
- m_features(features)
- { }
-
- vector<Feature>& Get_Features()
- {
- return m_features;
- }
-
- string To_String()
- {
- string out_string;
- char feature_str[35] = { 0 };
- for (size_t i = 0; i < m_features.size(); i++)
- {
- int chars;
- if (i == 0)
- {
- chars = sprintf(feature_str, "%d:", m_features[i].Id);
- }
- else
- {
- chars = sprintf(feature_str, " %d:", m_features[i].Id);
- }
- NumberUtils::print_float(feature_str + chars, m_features[i].Value);
- out_string.append(feature_str);
- }
- return out_string;
- }
-
-private:
- vector<Feature>& m_features;
-};
-
-///
-/// The epsilon greedy exploration algorithm. This is a good choice if you have no idea
-/// which actions should be preferred. Epsilon greedy is also computationally cheap.
-///
-template <class Ctx>
-class EpsilonGreedyExplorer : public IExplorer<Ctx>
-{
-public:
- ///
- /// The constructor is the only public member, because this should be used with the MwtExplorer.
- ///
- /// @param default_policy A default function which outputs an action given a context.
- /// @param epsilon The probability of a random exploration.
- /// @param num_actions The number of actions to randomize over.
- ///
- EpsilonGreedyExplorer(IPolicy<Ctx>& default_policy, float epsilon, u32 num_actions) :
- m_default_policy(default_policy), m_epsilon(epsilon), m_num_actions(num_actions)
- {
- if (m_num_actions < 1)
- {
- throw std::invalid_argument("Number of actions must be at least 1.");
- }
-
- if (m_epsilon < 0 || m_epsilon > 1)
- {
- throw std::invalid_argument("Epsilon must be between 0 and 1.");
- }
- }
-
- ~EpsilonGreedyExplorer()
- {
- }
-
-private:
- std::tuple<u32, float, bool> Choose_Action(u64 salted_seed, Ctx& context)
- {
- PRG::prg random_generator(salted_seed);
-
- // Invoke the default policy function to get the action
- u32 chosen_action = m_default_policy.Choose_Action(context);
-
- if (chosen_action == 0 || chosen_action > m_num_actions)
- {
- throw std::invalid_argument("Action chosen by default policy is not within valid range.");
- }
-
- float action_probability = 0.f;
- float base_probability = m_epsilon / m_num_actions; // uniform probability
-
- // TODO: check this random generation
- if (random_generator.Uniform_Unit_Interval() < 1.f - m_epsilon)
- {
- action_probability = 1.f - m_epsilon + base_probability;
- }
- else
- {
- // Get uniform random action ID
- u32 actionId = random_generator.Uniform_Int(1, m_num_actions);
-
- if (actionId == chosen_action)
- {
- // IF it matches the one chosen by the default policy
- // then increase the probability
- action_probability = 1.f - m_epsilon + base_probability;
- }
- else
- {
- // Otherwise it's just the uniform probability
- action_probability = base_probability;
- }
- chosen_action = actionId;
- }
-
- return std::tuple<u32, float, bool>(chosen_action, action_probability, true);
- }
-
-private:
- IPolicy<Ctx>& m_default_policy;
- float m_epsilon;
- u32 m_num_actions;
-
-private:
- friend class MwtExplorer<Ctx>;
-};
-
-///
-/// In some cases, different actions have a different scores, and you would prefer to
-/// choose actions with large scores. Softmax allows you to do that.
-///
-template <class Ctx>
-class SoftmaxExplorer : public IExplorer<Ctx>
-{
-public:
- ///
- /// The constructor is the only public member, because this should be used with the MwtExplorer.
- ///
- /// @param default_scorer A function which outputs a score for each action.
- /// @param lambda lambda = 0 implies uniform distribution. Large lambda is equivalent to a max.
- /// @param num_actions The number of actions to randomize over.
- ///
- SoftmaxExplorer(IScorer<Ctx>& default_scorer, float lambda, u32 num_actions) :
- m_default_scorer(default_scorer), m_lambda(lambda), m_num_actions(num_actions)
- {
- if (m_num_actions < 1)
- {
- throw std::invalid_argument("Number of actions must be at least 1.");
- }
- }
-
-private:
- std::tuple<u32, float, bool> Choose_Action(u64 salted_seed, Ctx& context)
- {
- PRG::prg random_generator(salted_seed);
-
- // Invoke the default scorer function
- vector<float> scores = m_default_scorer.Score_Actions(context);
- u32 num_scores = (u32)scores.size();
- if (num_scores != m_num_actions)
- {
- throw std::invalid_argument("The number of scores returned by the scorer must equal number of actions");
- }
-
- u32 i = 0;
-
- float max_score = -FLT_MAX;
- for (i = 0; i < num_scores; i++)
- {
- if (max_score < scores[i])
- {
- max_score = scores[i];
- }
- }
-
- // Create a normalized exponential distribution based on the returned scores
- for (i = 0; i < num_scores; i++)
- {
- scores[i] = exp(m_lambda * (scores[i] - max_score));
- }
-
- // Create a discrete_distribution based on the returned weights. This class handles the
- // case where the sum of the weights is < or > 1, by normalizing agains the sum.
- float total = 0.f;
- for (size_t i = 0; i < num_scores; i++)
- total += scores[i];
-
- float draw = random_generator.Uniform_Unit_Interval();
-
- float sum = 0.f;
- float action_probability = 0.f;
- u32 action_index = num_scores - 1;
- for (u32 i = 0; i < num_scores; i++)
- {
- scores[i] = scores[i] / total;
- sum += scores[i];
- if (sum > draw)
- {
- action_index = i;
- action_probability = scores[i];
- break;
- }
- }
-
- // action id is one-based
- return std::tuple<u32, float, bool>(action_index + 1, action_probability, true);
- }
-
-private:
- IScorer<Ctx>& m_default_scorer;
- float m_lambda;
- u32 m_num_actions;
-
-private:
- friend class MwtExplorer<Ctx>;
-};
-
-///
-/// GenericExplorer provides complete flexibility. You can create any
-/// distribution over actions desired, and it will draw from that.
-///
-template <class Ctx>
-class GenericExplorer : public IExplorer<Ctx>
-{
-public:
- ///
- /// The constructor is the only public member, because this should be used with the MwtExplorer.
- ///
- /// @param default_scorer A function which outputs the probability of each action.
- /// @param num_actions The number of actions to randomize over.
- ///
- GenericExplorer(IScorer<Ctx>& default_scorer, u32 num_actions) :
- m_default_scorer(default_scorer), m_num_actions(num_actions)
- {
- if (m_num_actions < 1)
- {
- throw std::invalid_argument("Number of actions must be at least 1.");
- }
- }
-
-private:
- std::tuple<u32, float, bool> Choose_Action(u64 salted_seed, Ctx& context)
- {
- PRG::prg random_generator(salted_seed);
-
- // Invoke the default scorer function
- vector<float> weights = m_default_scorer.Score_Actions(context);
- u32 num_weights = (u32)weights.size();
- if (num_weights != m_num_actions)
- {
- throw std::invalid_argument("The number of weights returned by the scorer must equal number of actions");
- }
-
- // Create a discrete_distribution based on the returned weights. This class handles the
- // case where the sum of the weights is < or > 1, by normalizing agains the sum.
- float total = 0.f;
- for (size_t i = 0; i < num_weights; i++)
- {
- if (weights[i] < 0)
- {
- throw std::invalid_argument("Scores must be non-negative.");
- }
- total += weights[i];
- }
- if (total == 0)
- {
- throw std::invalid_argument("At least one score must be positive.");
- }
-
- float draw = random_generator.Uniform_Unit_Interval();
-
- float sum = 0.f;
- float action_probability = 0.f;
- u32 action_index = num_weights - 1;
- for (u32 i = 0; i < num_weights; i++)
- {
- weights[i] = weights[i] / total;
- sum += weights[i];
- if (sum > draw)
- {
- action_index = i;
- action_probability = weights[i];
- break;
- }
- }
-
- // action id is one-based
- return std::tuple<u32, float, bool>(action_index + 1, action_probability, true);
- }
-
-private:
- IScorer<Ctx>& m_default_scorer;
- u32 m_num_actions;
-
-private:
- friend class MwtExplorer<Ctx>;
-};
-
-///
-/// The tau-first explorer collects exactly tau uniform random exploration events, and then
-/// uses the default policy thereafter.
-///
-template <class Ctx>
-class TauFirstExplorer : public IExplorer<Ctx>
-{
-public:
-
- ///
- /// The constructor is the only public member, because this should be used with the MwtExplorer.
- ///
- /// @param default_policy A default policy after randomization finishes.
- /// @param tau The number of events to be uniform over.
- /// @param num_actions The number of actions to randomize over.
- ///
- TauFirstExplorer(IPolicy<Ctx>& default_policy, u32 tau, u32 num_actions) :
- m_default_policy(default_policy), m_tau(tau), m_num_actions(num_actions)
- {
- if (m_num_actions < 1)
- {
- throw std::invalid_argument("Number of actions must be at least 1.");
- }
- }
-
-private:
- std::tuple<u32, float, bool> Choose_Action(u64 salted_seed, Ctx& context)
- {
- PRG::prg random_generator(salted_seed);
-
- u32 chosen_action = 0;
- float action_probability = 0.f;
- bool log_action;
- if (m_tau)
- {
- m_tau--;
- u32 actionId = random_generator.Uniform_Int(1, m_num_actions);
- action_probability = 1.f / m_num_actions;
- chosen_action = actionId;
- log_action = true;
- }
- else
- {
- // Invoke the default policy function to get the action
- chosen_action = m_default_policy.Choose_Action(context);
-
- if (chosen_action == 0 || chosen_action > m_num_actions)
- {
- throw std::invalid_argument("Action chosen by default policy is not within valid range.");
- }
-
- action_probability = 1.f;
- log_action = false;
- }
-
- return std::tuple<u32, float, bool>(chosen_action, action_probability, log_action);
- }
-
-private:
- IPolicy<Ctx>& m_default_policy;
- u32 m_tau;
- u32 m_num_actions;
-
-private:
- friend class MwtExplorer<Ctx>;
-};
-
-///
-/// The Bootstrap explorer randomizes over the actions chosen by a set of default policies.
-/// This performs well statistically but can be computationally expensive.
-///
-template <class Ctx>
-class BootstrapExplorer : public IExplorer<Ctx>
-{
-public:
- ///
- /// The constructor is the only public member, because this should be used with the MwtExplorer.
- ///
- /// @param default_policy_functions A set of default policies to be uniform random over.
- /// The policy pointers must be valid throughout the lifetime of this explorer.
- /// @param num_actions The number of actions to randomize over.
- ///
- BootstrapExplorer(vector<unique_ptr<IPolicy<Ctx>>>& default_policy_functions, u32 num_actions) :
- m_default_policy_functions(default_policy_functions),
- m_num_actions(num_actions)
- {
- m_bags = (u32)default_policy_functions.size();
- if (m_num_actions < 1)
- {
- throw std::invalid_argument("Number of actions must be at least 1.");
- }
-
- if (m_bags < 1)
- {
- throw std::invalid_argument("Number of bags must be at least 1.");
- }
- }
-
-private:
- std::tuple<u32, float, bool> Choose_Action(u64 salted_seed, Ctx& context)
- {
- PRG::prg random_generator(salted_seed);
-
- // Select bag
- u32 chosen_bag = random_generator.Uniform_Int(0, m_bags - 1);
-
- // Invoke the default policy function to get the action
- u32 chosen_action = 0;
- u32 action_from_bag = 0;
- vector<u32> actions_selected;
- for (size_t i = 0; i < m_num_actions; i++)
- {
- actions_selected.push_back(0);
- }
-
- // Invoke the default policy function to get the action
- for (u32 current_bag = 0; current_bag < m_bags; current_bag++)
- {
- action_from_bag = m_default_policy_functions[current_bag]->Choose_Action(context);
-
- if (action_from_bag == 0 || action_from_bag > m_num_actions)
- {
- throw std::invalid_argument("Action chosen by default policy is not within valid range.");
- }
-
- if (current_bag == chosen_bag)
- {
- chosen_action = action_from_bag;
- }
- //this won't work if actions aren't 0 to Count
- actions_selected[action_from_bag - 1]++; // action id is one-based
- }
- float action_probability = (float)actions_selected[chosen_action - 1] / m_bags; // action id is one-based
-
- return std::tuple<u32, float, bool>(chosen_action, action_probability, true);
- }
-
-private:
- vector<unique_ptr<IPolicy<Ctx>>>& m_default_policy_functions;
- u32 m_bags;
- u32 m_num_actions;
-
-private:
- friend class MwtExplorer<Ctx>;
-};
-} // End namespace MultiWorldTestingCpp
-/*! @} End of Doxygen Groups*/
+//
+// Main interface for clients of the Multiworld testing (MWT) service.
+//
+
+#pragma once
+
+#include <stdexcept>
+#include <float.h>
+#include <math.h>
+#include <stdio.h>
+#include <string.h>
+#include <vector>
+#include <utility>
+#include <memory>
+#include <limits.h>
+#include <tuple>
+
+#ifdef MANAGED_CODE
+#define PORTING_INTERFACE public
+#define MWT_NAMESPACE namespace NativeMultiWorldTesting
+#else
+#define PORTING_INTERFACE private
+#define MWT_NAMESPACE namespace MultiWorldTesting
+#endif
+
+using namespace std;
+
+#include "utility.h"
+
+/** \defgroup MultiWorldTestingCpp
+\brief C++ implementation, for sample usage see: https://github.com/sidsen/vowpal_wabbit/blob/v0/explore/explore_sample.cpp
+*/
+
+/*!
+* \addtogroup MultiWorldTestingCpp
+* @{
+*/
+
+//! Interface for C++ version of Multiworld Testing library.
+//! For sample usage see: https://github.com/sidsen/vowpal_wabbit/blob/v0/explore/explore_sample.cpp
+MWT_NAMESPACE {
+
+// Forward declarations
+template <class Ctx>
+class IRecorder;
+template <class Ctx>
+class IExplorer;
+
+///
+/// The top-level MwtExplorer class. Using this enables principled and efficient exploration
+/// over a set of possible actions, and ensures that the right bits are recorded.
+///
+template <class Ctx>
+class MwtExplorer
+{
+public:
+ ///
+ /// Constructor
+ ///
+ /// @param appid This should be unique to your experiment or you risk nasty correlation bugs.
+ /// @param recorder A user-specified class for recording the appropriate bits for use in evaluation and learning.
+ ///
+ MwtExplorer(std::string app_id, IRecorder<Ctx>& recorder) : m_recorder(recorder)
+ {
+ m_app_id = HashUtils::Compute_Id_Hash(app_id);
+ }
+
+ ///
+ /// Chooses an action by invoking an underlying exploration algorithm. This should be a
+ /// drop-in replacement for any existing policy function.
+ ///
+ /// @param explorer An existing exploration algorithm (one of the below) which uses the default policy as a callback.
+ /// @param unique_key A unique identifier for the experimental unit. This could be a user id, a session id, etc..
+ /// @param context The context upon which a decision is made. See SimpleContext below for an example.
+ ///
+ u32 Choose_Action(IExplorer<Ctx>& explorer, string unique_key, Ctx& context)
+ {
+ u64 seed = HashUtils::Compute_Id_Hash(unique_key);
+
+ std::tuple<u32, float, bool> action_probability_log_tuple = explorer.Choose_Action(seed + m_app_id, context);
+
+ u32 action = std::get<0>(action_probability_log_tuple);
+ float prob = std::get<1>(action_probability_log_tuple);
+
+ if (std::get<2>(action_probability_log_tuple))
+ {
+ m_recorder.Record(context, action, prob, unique_key);
+ }
+
+ return action;
+ }
+
+private:
+ u64 m_app_id;
+ IRecorder<Ctx>& m_recorder;
+};
+
+///
+/// Exposes a method to record exploration data based on generic contexts. Exploration data
+/// is specified as a set of tuples <context, action, probability, key> as described below. An
+/// application passes an IRecorder object to the @MwtExplorer constructor. See
+/// @StringRecorder for a sample IRecorder object.
+///
+template <class Ctx>
+class IRecorder
+{
+public:
+ ///
+ /// Records the exploration data associated with a given decision.
+ ///
+ /// @param context A user-defined context for the decision
+ /// @param action The action chosen by an exploration algorithm given context
+ /// @param probability The probability the exploration algorithm chose said action
+ /// @param unique_key A user-defined unique identifer for the decision
+ ///
+ virtual void Record(Ctx& context, u32 action, float probability, string unique_key) = 0;
+};
+
+///
+/// Exposes a method to choose an action given a generic context, and obtain the relevant
+/// exploration bits. Invokes IPolicy::Choose_Action internally. Do not implement this
+/// interface yourself: instead, use the various exploration algorithms below, which
+/// implement it for you.
+///
+template <class Ctx>
+class IExplorer
+{
+public:
+ ///
+ /// Determines the action to take and the probability with which it was chosen, for a
+ /// given context.
+ ///
+ /// @param salted_seed A PRG seed based on a unique id information provided by the user
+ /// @param context A user-defined context for the decision
+ /// @returns The action to take, the probability it was chosen, and a flag indicating
+ /// whether to record this decision
+ ///
+ virtual std::tuple<u32, float, bool> Choose_Action(u64 salted_seed, Ctx& context) = 0;
+};
+
+///
+/// Exposes a method to choose an action given a generic context. IPolicy objects are
+/// passed to (and invoked by) exploration algorithms to specify the default policy behavior.
+///
+template <class Ctx>
+class IPolicy
+{
+public:
+ ///
+ /// Determines the action to take for a given context.
+ ///
+ /// @param context A user-defined context for the decision
+ /// @returns The action to take (1-based index)
+ ///
+ virtual u32 Choose_Action(Ctx& context) = 0;
+};
+
+///
+/// Exposes a method for specifying a score (weight) for each action given a generic context.
+///
+template <class Ctx>
+class IScorer
+{
+public:
+ ///
+ /// Determines the score of each action for a given context.
+ ///
+ /// @param context A user-defined context for the decision
+ /// @returns A vector of scores indexed by action (1-based)
+ ///
+ virtual vector<float> Score_Actions(Ctx& context) = 0;
+};
+
+///
+/// A sample recorder class that converts the exploration tuple into string format.
+///
+template <class Ctx>
+struct StringRecorder : public IRecorder<Ctx>
+{
+ void Record(Ctx& context, u32 action, float probability, string unique_key)
+ {
+ // Implicitly enforce To_String() API on the context
+ m_recording.append(to_string((unsigned long)action));
+ m_recording.append(" ", 1);
+ m_recording.append(unique_key);
+ m_recording.append(" ", 1);
+
+ char prob_str[10] = { 0 };
+ NumberUtils::Float_To_String(probability, prob_str);
+ m_recording.append(prob_str);
+
+ m_recording.append(" | ", 3);
+ m_recording.append(context.To_String());
+ m_recording.append("\n");
+ }
+
+ // Gets the content of the recording so far as a string and optionally clears internal content.
+ string Get_Recording(bool flush = true)
+ {
+ if (!flush)
+ {
+ return m_recording;
+ }
+ string recording = m_recording;
+ m_recording.clear();
+ return recording;
+ }
+
+private:
+ string m_recording;
+};
+
+///
+/// Represents a feature in a sparse array.
+///
+struct Feature
+{
+ float Value;
+ u32 Id;
+
+ bool operator==(Feature other_feature)
+ {
+ return Id == other_feature.Id;
+ }
+};
+
+///
+/// A sample context class that stores a vector of Features.
+///
+class SimpleContext
+{
+public:
+ SimpleContext(vector<Feature>& features) :
+ m_features(features)
+ { }
+
+ vector<Feature>& Get_Features()
+ {
+ return m_features;
+ }
+
+ string To_String()
+ {
+ string out_string;
+ char feature_str[35] = { 0 };
+ for (size_t i = 0; i < m_features.size(); i++)
+ {
+ int chars;
+ if (i == 0)
+ {
+ chars = sprintf(feature_str, "%d:", m_features[i].Id);
+ }
+ else
+ {
+ chars = sprintf(feature_str, " %d:", m_features[i].Id);
+ }
+ NumberUtils::print_float(feature_str + chars, m_features[i].Value);
+ out_string.append(feature_str);
+ }
+ return out_string;
+ }
+
+private:
+ vector<Feature>& m_features;
+};
+
+///
+/// The epsilon greedy exploration algorithm. This is a good choice if you have no idea
+/// which actions should be preferred. Epsilon greedy is also computationally cheap.
+///
+template <class Ctx>
+class EpsilonGreedyExplorer : public IExplorer<Ctx>
+{
+public:
+ ///
+ /// The constructor is the only public member, because this should be used with the MwtExplorer.
+ ///
+ /// @param default_policy A default function which outputs an action given a context.
+ /// @param epsilon The probability of a random exploration.
+ /// @param num_actions The number of actions to randomize over.
+ ///
+ EpsilonGreedyExplorer(IPolicy<Ctx>& default_policy, float epsilon, u32 num_actions) :
+ m_default_policy(default_policy), m_epsilon(epsilon), m_num_actions(num_actions)
+ {
+ if (m_num_actions < 1)
+ {
+ throw std::invalid_argument("Number of actions must be at least 1.");
+ }
+
+ if (m_epsilon < 0 || m_epsilon > 1)
+ {
+ throw std::invalid_argument("Epsilon must be between 0 and 1.");
+ }
+ }
+
+ ~EpsilonGreedyExplorer()
+ {
+ }
+
+private:
+ std::tuple<u32, float, bool> Choose_Action(u64 salted_seed, Ctx& context)
+ {
+ PRG::prg random_generator(salted_seed);
+
+ // Invoke the default policy function to get the action
+ u32 chosen_action = m_default_policy.Choose_Action(context);
+
+ if (chosen_action == 0 || chosen_action > m_num_actions)
+ {
+ throw std::invalid_argument("Action chosen by default policy is not within valid range.");
+ }
+
+ float action_probability = 0.f;
+ float base_probability = m_epsilon / m_num_actions; // uniform probability
+
+ // TODO: check this random generation
+ if (random_generator.Uniform_Unit_Interval() < 1.f - m_epsilon)
+ {
+ action_probability = 1.f - m_epsilon + base_probability;
+ }
+ else
+ {
+ // Get uniform random action ID
+ u32 actionId = random_generator.Uniform_Int(1, m_num_actions);
+
+ if (actionId == chosen_action)
+ {
+ // IF it matches the one chosen by the default policy
+ // then increase the probability
+ action_probability = 1.f - m_epsilon + base_probability;
+ }
+ else
+ {
+ // Otherwise it's just the uniform probability
+ action_probability = base_probability;
+ }
+ chosen_action = actionId;
+ }
+
+ return std::tuple<u32, float, bool>(chosen_action, action_probability, true);
+ }
+
+private:
+ IPolicy<Ctx>& m_default_policy;
+ float m_epsilon;
+ u32 m_num_actions;
+
+private:
+ friend class MwtExplorer<Ctx>;
+};
+
+///
+/// In some cases, different actions have a different scores, and you would prefer to
+/// choose actions with large scores. Softmax allows you to do that.
+///
+template <class Ctx>
+class SoftmaxExplorer : public IExplorer<Ctx>
+{
+public:
+ ///
+ /// The constructor is the only public member, because this should be used with the MwtExplorer.
+ ///
+ /// @param default_scorer A function which outputs a score for each action.
+ /// @param lambda lambda = 0 implies uniform distribution. Large lambda is equivalent to a max.
+ /// @param num_actions The number of actions to randomize over.
+ ///
+ SoftmaxExplorer(IScorer<Ctx>& default_scorer, float lambda, u32 num_actions) :
+ m_default_scorer(default_scorer), m_lambda(lambda), m_num_actions(num_actions)
+ {
+ if (m_num_actions < 1)
+ {
+ throw std::invalid_argument("Number of actions must be at least 1.");
+ }
+ }
+
+private:
+ std::tuple<u32, float, bool> Choose_Action(u64 salted_seed, Ctx& context)
+ {
+ PRG::prg random_generator(salted_seed);
+
+ // Invoke the default scorer function
+ vector<float> scores = m_default_scorer.Score_Actions(context);
+ u32 num_scores = (u32)scores.size();
+ if (num_scores != m_num_actions)
+ {
+ throw std::invalid_argument("The number of scores returned by the scorer must equal number of actions");
+ }
+
+ u32 i = 0;
+
+ float max_score = -FLT_MAX;
+ for (i = 0; i < num_scores; i++)
+ {
+ if (max_score < scores[i])
+ {
+ max_score = scores[i];
+ }
+ }
+
+ // Create a normalized exponential distribution based on the returned scores
+ for (i = 0; i < num_scores; i++)
+ {
+ scores[i] = exp(m_lambda * (scores[i] - max_score));
+ }
+
+ // Create a discrete_distribution based on the returned weights. This class handles the
+ // case where the sum of the weights is < or > 1, by normalizing agains the sum.
+ float total = 0.f;
+ for (size_t i = 0; i < num_scores; i++)
+ total += scores[i];
+
+ float draw = random_generator.Uniform_Unit_Interval();
+
+ float sum = 0.f;
+ float action_probability = 0.f;
+ u32 action_index = num_scores - 1;
+ for (u32 i = 0; i < num_scores; i++)
+ {
+ scores[i] = scores[i] / total;
+ sum += scores[i];
+ if (sum > draw)
+ {
+ action_index = i;
+ action_probability = scores[i];
+ break;
+ }
+ }
+
+ // action id is one-based
+ return std::tuple<u32, float, bool>(action_index + 1, action_probability, true);
+ }
+
+private:
+ IScorer<Ctx>& m_default_scorer;
+ float m_lambda;
+ u32 m_num_actions;
+
+private:
+ friend class MwtExplorer<Ctx>;
+};
+
+///
+/// GenericExplorer provides complete flexibility. You can create any
+/// distribution over actions desired, and it will draw from that.
+///
+template <class Ctx>
+class GenericExplorer : public IExplorer<Ctx>
+{
+public:
+ ///
+ /// The constructor is the only public member, because this should be used with the MwtExplorer.
+ ///
+ /// @param default_scorer A function which outputs the probability of each action.
+ /// @param num_actions The number of actions to randomize over.
+ ///
+ GenericExplorer(IScorer<Ctx>& default_scorer, u32 num_actions) :
+ m_default_scorer(default_scorer), m_num_actions(num_actions)
+ {
+ if (m_num_actions < 1)
+ {
+ throw std::invalid_argument("Number of actions must be at least 1.");
+ }
+ }
+
+private:
+ std::tuple<u32, float, bool> Choose_Action(u64 salted_seed, Ctx& context)
+ {
+ PRG::prg random_generator(salted_seed);
+
+ // Invoke the default scorer function
+ vector<float> weights = m_default_scorer.Score_Actions(context);
+ u32 num_weights = (u32)weights.size();
+ if (num_weights != m_num_actions)
+ {
+ throw std::invalid_argument("The number of weights returned by the scorer must equal number of actions");
+ }
+
+ // Create a discrete_distribution based on the returned weights. This class handles the
+ // case where the sum of the weights is < or > 1, by normalizing agains the sum.
+ float total = 0.f;
+ for (size_t i = 0; i < num_weights; i++)
+ {
+ if (weights[i] < 0)
+ {
+ throw std::invalid_argument("Scores must be non-negative.");
+ }
+ total += weights[i];
+ }
+ if (total == 0)
+ {
+ throw std::invalid_argument("At least one score must be positive.");
+ }
+
+ float draw = random_generator.Uniform_Unit_Interval();
+
+ float sum = 0.f;
+ float action_probability = 0.f;
+ u32 action_index = num_weights - 1;
+ for (u32 i = 0; i < num_weights; i++)
+ {
+ weights[i] = weights[i] / total;
+ sum += weights[i];
+ if (sum > draw)
+ {
+ action_index = i;
+ action_probability = weights[i];
+ break;
+ }
+ }
+
+ // action id is one-based
+ return std::tuple<u32, float, bool>(action_index + 1, action_probability, true);
+ }
+
+private:
+ IScorer<Ctx>& m_default_scorer;
+ u32 m_num_actions;
+
+private:
+ friend class MwtExplorer<Ctx>;
+};
+
+///
+/// The tau-first explorer collects exactly tau uniform random exploration events, and then
+/// uses the default policy thereafter.
+///
+template <class Ctx>
+class TauFirstExplorer : public IExplorer<Ctx>
+{
+public:
+
+ ///
+ /// The constructor is the only public member, because this should be used with the MwtExplorer.
+ ///
+ /// @param default_policy A default policy after randomization finishes.
+ /// @param tau The number of events to be uniform over.
+ /// @param num_actions The number of actions to randomize over.
+ ///
+ TauFirstExplorer(IPolicy<Ctx>& default_policy, u32 tau, u32 num_actions) :
+ m_default_policy(default_policy), m_tau(tau), m_num_actions(num_actions)
+ {
+ if (m_num_actions < 1)
+ {
+ throw std::invalid_argument("Number of actions must be at least 1.");
+ }
+ }
+
+private:
+ std::tuple<u32, float, bool> Choose_Action(u64 salted_seed, Ctx& context)
+ {
+ PRG::prg random_generator(salted_seed);
+
+ u32 chosen_action = 0;
+ float action_probability = 0.f;
+ bool log_action;
+ if (m_tau)
+ {
+ m_tau--;
+ u32 actionId = random_generator.Uniform_Int(1, m_num_actions);
+ action_probability = 1.f / m_num_actions;
+ chosen_action = actionId;
+ log_action = true;
+ }
+ else
+ {
+ // Invoke the default policy function to get the action
+ chosen_action = m_default_policy.Choose_Action(context);
+
+ if (chosen_action == 0 || chosen_action > m_num_actions)
+ {
+ throw std::invalid_argument("Action chosen by default policy is not within valid range.");
+ }
+
+ action_probability = 1.f;
+ log_action = false;
+ }
+
+ return std::tuple<u32, float, bool>(chosen_action, action_probability, log_action);
+ }
+
+private:
+ IPolicy<Ctx>& m_default_policy;
+ u32 m_tau;
+ u32 m_num_actions;
+
+private:
+ friend class MwtExplorer<Ctx>;
+};
+
+///
+/// The Bootstrap explorer randomizes over the actions chosen by a set of default policies.
+/// This performs well statistically but can be computationally expensive.
+///
+template <class Ctx>
+class BootstrapExplorer : public IExplorer<Ctx>
+{
+public:
+ ///
+ /// The constructor is the only public member, because this should be used with the MwtExplorer.
+ ///
+ /// @param default_policy_functions A set of default policies to be uniform random over.
+ /// The policy pointers must be valid throughout the lifetime of this explorer.
+ /// @param num_actions The number of actions to randomize over.
+ ///
+ BootstrapExplorer(vector<unique_ptr<IPolicy<Ctx>>>& default_policy_functions, u32 num_actions) :
+ m_default_policy_functions(default_policy_functions),
+ m_num_actions(num_actions)
+ {
+ m_bags = (u32)default_policy_functions.size();
+ if (m_num_actions < 1)
+ {
+ throw std::invalid_argument("Number of actions must be at least 1.");
+ }
+
+ if (m_bags < 1)
+ {
+ throw std::invalid_argument("Number of bags must be at least 1.");
+ }
+ }
+
+private:
+ std::tuple<u32, float, bool> Choose_Action(u64 salted_seed, Ctx& context)
+ {
+ PRG::prg random_generator(salted_seed);
+
+ // Select bag
+ u32 chosen_bag = random_generator.Uniform_Int(0, m_bags - 1);
+
+ // Invoke the default policy function to get the action
+ u32 chosen_action = 0;
+ u32 action_from_bag = 0;
+ vector<u32> actions_selected;
+ for (size_t i = 0; i < m_num_actions; i++)
+ {
+ actions_selected.push_back(0);
+ }
+
+ // Invoke the default policy function to get the action
+ for (u32 current_bag = 0; current_bag < m_bags; current_bag++)
+ {
+ action_from_bag = m_default_policy_functions[current_bag]->Choose_Action(context);
+
+ if (action_from_bag == 0 || action_from_bag > m_num_actions)
+ {
+ throw std::invalid_argument("Action chosen by default policy is not within valid range.");
+ }
+
+ if (current_bag == chosen_bag)
+ {
+ chosen_action = action_from_bag;
+ }
+ //this won't work if actions aren't 0 to Count
+ actions_selected[action_from_bag - 1]++; // action id is one-based
+ }
+ float action_probability = (float)actions_selected[chosen_action - 1] / m_bags; // action id is one-based
+
+ return std::tuple<u32, float, bool>(chosen_action, action_probability, true);
+ }
+
+private:
+ vector<unique_ptr<IPolicy<Ctx>>>& m_default_policy_functions;
+ u32 m_bags;
+ u32 m_num_actions;
+
+private:
+ friend class MwtExplorer<Ctx>;
+};
+} // End namespace MultiWorldTestingCpp
+/*! @} End of Doxygen Groups*/
diff --git a/explore/static/utility.h b/explore/static/utility.h
index a6284809..6a2a60d6 100644
--- a/explore/static/utility.h
+++ b/explore/static/utility.h
@@ -1,286 +1,286 @@
/*******************************************************************/
// Classes declared in this file are intended for internal use only.
-/*******************************************************************/
-
-#pragma once
-#include <stdint.h>
-#include <sys/types.h> /* defines size_t */
-
-#ifdef WIN32
-typedef unsigned __int64 u64;
-typedef unsigned __int32 u32;
-typedef unsigned __int16 u16;
-typedef unsigned __int8 u8;
-typedef signed __int64 i64;
-typedef signed __int32 i32;
-typedef signed __int16 i16;
-typedef signed __int8 i8;
-#else
-typedef uint64_t u64;
-typedef uint32_t u32;
-typedef uint16_t u16;
-typedef uint8_t u8;
-typedef int64_t i64;
-typedef int32_t i32;
-typedef int16_t i16;
-typedef int8_t i8;
-#endif
-
-typedef unsigned char byte;
-
-#include <string>
-#include <stdint.h>
-#include <math.h>
-
-/*!
-* \addtogroup MultiWorldTestingCpp
-* @{
-*/
-
-MWT_NAMESPACE {
-
-//
-// MurmurHash3, by Austin Appleby
-//
-// Originals at:
-// http://code.google.com/p/smhasher/source/browse/trunk/MurmurHash3.cpp
-// http://code.google.com/p/smhasher/source/browse/trunk/MurmurHash3.h
-//
-// Notes:
-// 1) this code assumes we can read a 4-byte value from any address
-// without crashing (i.e non aligned access is supported). This is
-// not a problem on Intel/x86/AMD64 machines (including new Macs)
-// 2) It produces different results on little-endian and big-endian machines.
-//
-//-----------------------------------------------------------------------------
-// MurmurHash3 was written by Austin Appleby, and is placed in the public
-// domain. The author hereby disclaims copyright to this source code.
-
-// Note - The x86 and x64 versions do _not_ produce the same results, as the
-// algorithms are optimized for their respective platforms. You can still
-// compile and run any of them on any platform, but your performance with the
-// non-native version will be less than optimal.
-//-----------------------------------------------------------------------------
-
-// Platform-specific functions and macros
-#if defined(_MSC_VER) // Microsoft Visual Studio
-# include <stdint.h>
-
-# include <stdlib.h>
-# define ROTL32(x,y) _rotl(x,y)
-# define BIG_CONSTANT(x) (x)
-
-#else // Other compilers
-# include <stdint.h> /* defines uint32_t etc */
-
- inline uint32_t rotl32(uint32_t x, int8_t r)
- {
- return (x << r) | (x >> (32 - r));
- }
-
-# define ROTL32(x,y) rotl32(x,y)
-# define BIG_CONSTANT(x) (x##LLU)
-
-#endif // !defined(_MSC_VER)
-
-struct murmur_hash {
-
- //-----------------------------------------------------------------------------
- // Block read - if your platform needs to do endian-swapping or can only
- // handle aligned reads, do the conversion here
-private:
- static inline uint32_t getblock(const uint32_t * p, int i)
- {
- return p[i];
- }
-
- //-----------------------------------------------------------------------------
- // Finalization mix - force all bits of a hash block to avalanche
-
- static inline uint32_t fmix(uint32_t h)
- {
- h ^= h >> 16;
- h *= 0x85ebca6b;
- h ^= h >> 13;
- h *= 0xc2b2ae35;
- h ^= h >> 16;
-
- return h;
- }
-
- //-----------------------------------------------------------------------------
-public:
- uint32_t uniform_hash(const void * key, size_t len, uint32_t seed)
- {
- const uint8_t * data = (const uint8_t*)key;
- const int nblocks = (int)len / 4;
-
- uint32_t h1 = seed;
-
- const uint32_t c1 = 0xcc9e2d51;
- const uint32_t c2 = 0x1b873593;
-
- // --- body
- const uint32_t * blocks = (const uint32_t *)(data + nblocks * 4);
-
- for (int i = -nblocks; i; i++) {
- uint32_t k1 = getblock(blocks, i);
-
- k1 *= c1;
- k1 = ROTL32(k1, 15);
- k1 *= c2;
-
- h1 ^= k1;
- h1 = ROTL32(h1, 13);
- h1 = h1 * 5 + 0xe6546b64;
- }
-
- // --- tail
- const uint8_t * tail = (const uint8_t*)(data + nblocks * 4);
-
- uint32_t k1 = 0;
-
- switch (len & 3) {
- case 3: k1 ^= tail[2] << 16;
- case 2: k1 ^= tail[1] << 8;
- case 1: k1 ^= tail[0];
- k1 *= c1; k1 = ROTL32(k1, 15); k1 *= c2; h1 ^= k1;
- }
-
- // --- finalization
- h1 ^= len;
-
- return fmix(h1);
- }
-};
-
-class HashUtils
-{
-public:
- static u64 Compute_Id_Hash(const std::string& unique_id)
- {
- size_t ret = 0;
- const char *p = unique_id.c_str();
- while (*p != '\0')
- if (*p >= '0' && *p <= '9')
- ret = 10 * ret + *(p++) - '0';
- else
- {
- murmur_hash foo;
- return foo.uniform_hash(unique_id.c_str(), unique_id.size(), 0);
- }
- return ret;
- }
-};
-
-const size_t max_int = 100000;
-const float max_float = max_int;
-const float min_float = 0.00001f;
-const size_t max_digits = (size_t) roundf((float) (-log(min_float) / log(10.)));
-
-class NumberUtils
-{
-public:
- static void Float_To_String(float f, char* str)
- {
- int x = (int)f;
- int d = (int)(fabs(f - x) * 100000);
- sprintf(str, "%d.%05d", x, d);
- }
-
- template<bool trailing_zeros>
- static void print_mantissa(char*& begin, float f)
- { // helper for print_float
- char values[10];
- size_t v = (size_t)f;
- size_t digit = 0;
- size_t first_nonzero = 0;
- for (size_t max = 1; max <= v; ++digit)
- {
- size_t max_next = max * 10;
- char v_mod = (char) (v % max_next / max);
- if (!trailing_zeros && v_mod != '\0' && first_nonzero == 0)
- first_nonzero = digit;
- values[digit] = '0' + v_mod;
- max = max_next;
- }
- if (!trailing_zeros)
- for (size_t i = max_digits; i > digit; i--)
- *begin++ = '0';
- while (digit > first_nonzero)
- *begin++ = values[--digit];
- }
-
- static void print_float(char* begin, float f)
- {
- bool sign = false;
- if (f < 0.f)
- sign = true;
- float unsigned_f = fabsf(f);
- if (unsigned_f < max_float && unsigned_f > min_float)
- {
- if (sign)
- *begin++ = '-';
- print_mantissa<true>(begin, unsigned_f);
- unsigned_f -= (size_t)unsigned_f;
- unsigned_f *= max_int;
- if (unsigned_f >= 1.f)
- {
- *begin++ = '.';
- print_mantissa<false>(begin, unsigned_f);
- }
- }
- else if (unsigned_f == 0.)
- *begin++ = '0';
- else
- {
- sprintf(begin, "%g", f);
- return;
- }
- *begin = '\0';
- return;
- }
-};
-
-//A quick implementation similar to drand48 for cross-platform compatibility
-namespace PRG {
- const uint64_t a = 0xeece66d5deece66dULL;
- const uint64_t c = 2147483647;
-
- const int bias = 127 << 23;
-
- union int_float {
- int32_t i;
- float f;
- };
-
- struct prg {
- private:
- uint64_t v;
- public:
- prg() { v = c; }
- prg(uint64_t initial) { v = initial; }
-
- float merand48(uint64_t& initial)
- {
- initial = a * initial + c;
- int_float temp;
- temp.i = ((initial >> 25) & 0x7FFFFF) | bias;
- return temp.f - 1;
- }
-
- float Uniform_Unit_Interval()
- {
- return merand48(v);
- }
-
- uint32_t Uniform_Int(uint32_t low, uint32_t high)
- {
- merand48(v);
- uint32_t ret = low + ((v >> 25) % (high - low + 1));
- return ret;
- }
- };
-}
-}
+/*******************************************************************/
+
+#pragma once
+#include <stdint.h>
+#include <sys/types.h> /* defines size_t */
+
+#ifdef WIN32
+typedef unsigned __int64 u64;
+typedef unsigned __int32 u32;
+typedef unsigned __int16 u16;
+typedef unsigned __int8 u8;
+typedef signed __int64 i64;
+typedef signed __int32 i32;
+typedef signed __int16 i16;
+typedef signed __int8 i8;
+#else
+typedef uint64_t u64;
+typedef uint32_t u32;
+typedef uint16_t u16;
+typedef uint8_t u8;
+typedef int64_t i64;
+typedef int32_t i32;
+typedef int16_t i16;
+typedef int8_t i8;
+#endif
+
+typedef unsigned char byte;
+
+#include <string>
+#include <stdint.h>
+#include <math.h>
+
+/*!
+* \addtogroup MultiWorldTestingCpp
+* @{
+*/
+
+MWT_NAMESPACE {
+
+//
+// MurmurHash3, by Austin Appleby
+//
+// Originals at:
+// http://code.google.com/p/smhasher/source/browse/trunk/MurmurHash3.cpp
+// http://code.google.com/p/smhasher/source/browse/trunk/MurmurHash3.h
+//
+// Notes:
+// 1) this code assumes we can read a 4-byte value from any address
+// without crashing (i.e non aligned access is supported). This is
+// not a problem on Intel/x86/AMD64 machines (including new Macs)
+// 2) It produces different results on little-endian and big-endian machines.
+//
+//-----------------------------------------------------------------------------
+// MurmurHash3 was written by Austin Appleby, and is placed in the public
+// domain. The author hereby disclaims copyright to this source code.
+
+// Note - The x86 and x64 versions do _not_ produce the same results, as the
+// algorithms are optimized for their respective platforms. You can still
+// compile and run any of them on any platform, but your performance with the
+// non-native version will be less than optimal.
+//-----------------------------------------------------------------------------
+
+// Platform-specific functions and macros
+#if defined(_MSC_VER) // Microsoft Visual Studio
+# include <stdint.h>
+
+# include <stdlib.h>
+# define ROTL32(x,y) _rotl(x,y)
+# define BIG_CONSTANT(x) (x)
+
+#else // Other compilers
+# include <stdint.h> /* defines uint32_t etc */
+
+ inline uint32_t rotl32(uint32_t x, int8_t r)
+ {
+ return (x << r) | (x >> (32 - r));
+ }
+
+# define ROTL32(x,y) rotl32(x,y)
+# define BIG_CONSTANT(x) (x##LLU)
+
+#endif // !defined(_MSC_VER)
+
+struct murmur_hash {
+
+ //-----------------------------------------------------------------------------
+ // Block read - if your platform needs to do endian-swapping or can only
+ // handle aligned reads, do the conversion here
+private:
+ static inline uint32_t getblock(const uint32_t * p, int i)
+ {
+ return p[i];
+ }
+
+ //-----------------------------------------------------------------------------
+ // Finalization mix - force all bits of a hash block to avalanche
+
+ static inline uint32_t fmix(uint32_t h)
+ {
+ h ^= h >> 16;
+ h *= 0x85ebca6b;
+ h ^= h >> 13;
+ h *= 0xc2b2ae35;
+ h ^= h >> 16;
+
+ return h;
+ }
+
+ //-----------------------------------------------------------------------------
+public:
+ uint32_t uniform_hash(const void * key, size_t len, uint32_t seed)
+ {
+ const uint8_t * data = (const uint8_t*)key;
+ const int nblocks = (int)len / 4;
+
+ uint32_t h1 = seed;
+
+ const uint32_t c1 = 0xcc9e2d51;
+ const uint32_t c2 = 0x1b873593;
+
+ // --- body
+ const uint32_t * blocks = (const uint32_t *)(data + nblocks * 4);
+
+ for (int i = -nblocks; i; i++) {
+ uint32_t k1 = getblock(blocks, i);
+
+ k1 *= c1;
+ k1 = ROTL32(k1, 15);
+ k1 *= c2;
+
+ h1 ^= k1;
+ h1 = ROTL32(h1, 13);
+ h1 = h1 * 5 + 0xe6546b64;
+ }
+
+ // --- tail
+ const uint8_t * tail = (const uint8_t*)(data + nblocks * 4);
+
+ uint32_t k1 = 0;
+
+ switch (len & 3) {
+ case 3: k1 ^= tail[2] << 16;
+ case 2: k1 ^= tail[1] << 8;
+ case 1: k1 ^= tail[0];
+ k1 *= c1; k1 = ROTL32(k1, 15); k1 *= c2; h1 ^= k1;
+ }
+
+ // --- finalization
+ h1 ^= len;
+
+ return fmix(h1);
+ }
+};
+
+class HashUtils
+{
+public:
+ static u64 Compute_Id_Hash(const std::string& unique_id)
+ {
+ size_t ret = 0;
+ const char *p = unique_id.c_str();
+ while (*p != '\0')
+ if (*p >= '0' && *p <= '9')
+ ret = 10 * ret + *(p++) - '0';
+ else
+ {
+ murmur_hash foo;
+ return foo.uniform_hash(unique_id.c_str(), unique_id.size(), 0);
+ }
+ return ret;
+ }
+};
+
+const size_t max_int = 100000;
+const float max_float = max_int;
+const float min_float = 0.00001f;
+const size_t max_digits = (size_t) roundf((float) (-log(min_float) / log(10.)));
+
+class NumberUtils
+{
+public:
+ static void Float_To_String(float f, char* str)
+ {
+ int x = (int)f;
+ int d = (int)(fabs(f - x) * 100000);
+ sprintf(str, "%d.%05d", x, d);
+ }
+
+ template<bool trailing_zeros>
+ static void print_mantissa(char*& begin, float f)
+ { // helper for print_float
+ char values[10];
+ size_t v = (size_t)f;
+ size_t digit = 0;
+ size_t first_nonzero = 0;
+ for (size_t max = 1; max <= v; ++digit)
+ {
+ size_t max_next = max * 10;
+ char v_mod = (char) (v % max_next / max);
+ if (!trailing_zeros && v_mod != '\0' && first_nonzero == 0)
+ first_nonzero = digit;
+ values[digit] = '0' + v_mod;
+ max = max_next;
+ }
+ if (!trailing_zeros)
+ for (size_t i = max_digits; i > digit; i--)
+ *begin++ = '0';
+ while (digit > first_nonzero)
+ *begin++ = values[--digit];
+ }
+
+ static void print_float(char* begin, float f)
+ {
+ bool sign = false;
+ if (f < 0.f)
+ sign = true;
+ float unsigned_f = fabsf(f);
+ if (unsigned_f < max_float && unsigned_f > min_float)
+ {
+ if (sign)
+ *begin++ = '-';
+ print_mantissa<true>(begin, unsigned_f);
+ unsigned_f -= (size_t)unsigned_f;
+ unsigned_f *= max_int;
+ if (unsigned_f >= 1.f)
+ {
+ *begin++ = '.';
+ print_mantissa<false>(begin, unsigned_f);
+ }
+ }
+ else if (unsigned_f == 0.)
+ *begin++ = '0';
+ else
+ {
+ sprintf(begin, "%g", f);
+ return;
+ }
+ *begin = '\0';
+ return;
+ }
+};
+
+//A quick implementation similar to drand48 for cross-platform compatibility
+namespace PRG {
+ const uint64_t a = 0xeece66d5deece66dULL;
+ const uint64_t c = 2147483647;
+
+ const int bias = 127 << 23;
+
+ union int_float {
+ int32_t i;
+ float f;
+ };
+
+ struct prg {
+ private:
+ uint64_t v;
+ public:
+ prg() { v = c; }
+ prg(uint64_t initial) { v = initial; }
+
+ float merand48(uint64_t& initial)
+ {
+ initial = a * initial + c;
+ int_float temp;
+ temp.i = ((initial >> 25) & 0x7FFFFF) | bias;
+ return temp.f - 1;
+ }
+
+ float Uniform_Unit_Interval()
+ {
+ return merand48(v);
+ }
+
+ uint32_t Uniform_Int(uint32_t low, uint32_t high)
+ {
+ merand48(v);
+ uint32_t ret = low + ((v >> 25) % (high - low + 1));
+ return ret;
+ }
+ };
+}
+}
/*! @} End of Doxygen Groups*/
diff --git a/explore/tests/MWTExploreTests.h b/explore/tests/MWTExploreTests.h
index 447a368f..1c451fb4 100644
--- a/explore/tests/MWTExploreTests.h
+++ b/explore/tests/MWTExploreTests.h
@@ -1,164 +1,164 @@
-#pragma once
-
-#include "MWTExplorer.h"
-#include "utility.h"
-#include <iomanip>
-#include <iostream>
-#include <sstream>
-
-using namespace MultiWorldTesting;
-
-class TestContext
-{
-
-};
-
-template <class Ctx>
-struct TestInteraction
-{
- Ctx& Context;
- u32 Action;
- float Probability;
- string Unique_Key;
-};
-
-class TestPolicy : public IPolicy<TestContext>
-{
-public:
- TestPolicy(int params, int num_actions) : m_params(params), m_num_actions(num_actions) { }
- u32 Choose_Action(TestContext& context)
- {
- return m_params % m_num_actions + 1; // action id is one-based
- }
-private:
- int m_params;
- int m_num_actions;
-};
-
-class TestScorer : public IScorer<TestContext>
-{
-public:
- TestScorer(int params, int num_actions, bool uniform = true) :
- m_params(params), m_num_actions(num_actions), m_uniform(uniform)
- { }
-
- vector<float> Score_Actions(TestContext& context)
- {
- vector<float> scores;
- if (m_uniform)
- {
- for (u32 i = 0; i < m_num_actions; i++)
- {
- scores.push_back(m_params);
- }
- }
- else
- {
- for (u32 i = 0; i < m_num_actions; i++)
- {
- scores.push_back(m_params + i);
- }
- }
- return scores;
- }
-private:
- int m_params;
- int m_num_actions;
- bool m_uniform;
-};
-
-class FixedScorer : public IScorer<TestContext>
-{
-public:
- FixedScorer(int num_actions, int value) :
- m_num_actions(num_actions), m_value(value)
- { }
-
- vector<float> Score_Actions(TestContext& context)
- {
- vector<float> scores;
- for (u32 i = 0; i < m_num_actions; i++)
- {
- scores.push_back((float)m_value);
- }
- return scores;
- }
-private:
- int m_num_actions;
- int m_value;
-};
-
-class TestSimpleScorer : public IScorer<SimpleContext>
-{
-public:
- TestSimpleScorer(int params, int num_actions) : m_params(params), m_num_actions(num_actions) { }
- vector<float> Score_Actions(SimpleContext& context)
- {
- vector<float> scores;
- for (u32 i = 0; i < m_num_actions; i++)
- {
- scores.push_back(m_params);
- }
- return scores;
- }
-private:
- int m_params;
- int m_num_actions;
-};
-
-class TestSimplePolicy : public IPolicy<SimpleContext>
-{
-public:
- TestSimplePolicy(int params, int num_actions) : m_params(params), m_num_actions(num_actions) { }
- u32 Choose_Action(SimpleContext& context)
- {
- return m_params % m_num_actions + 1; // action id is one-based
- }
-private:
- int m_params;
- int m_num_actions;
-};
-
-class TestSimpleRecorder : public IRecorder<SimpleContext>
-{
-public:
- virtual void Record(SimpleContext& context, u32 action, float probability, string unique_key)
- {
- m_interactions.push_back({ context, action, probability, unique_key });
- }
-
- vector<TestInteraction<SimpleContext>> Get_All_Interactions()
- {
- return m_interactions;
- }
-
-private:
- vector<TestInteraction<SimpleContext>> m_interactions;
-};
-
-// Return action outside valid range
-class TestBadPolicy : public IPolicy<TestContext>
-{
-public:
- u32 Choose_Action(TestContext& context)
- {
- return 100;
- }
-};
-
-class TestRecorder : public IRecorder<TestContext>
-{
-public:
- virtual void Record(TestContext& context, u32 action, float probability, string unique_key)
- {
- m_interactions.push_back({ context, action, probability, unique_key });
- }
-
- vector<TestInteraction<TestContext>> Get_All_Interactions()
- {
- return m_interactions;
- }
-
-private:
- vector<TestInteraction<TestContext>> m_interactions;
-};
+#pragma once
+
+#include "MWTExplorer.h"
+#include "utility.h"
+#include <iomanip>
+#include <iostream>
+#include <sstream>
+
+using namespace MultiWorldTesting;
+
+class TestContext
+{
+
+};
+
+template <class Ctx>
+struct TestInteraction
+{
+ Ctx& Context;
+ u32 Action;
+ float Probability;
+ string Unique_Key;
+};
+
+class TestPolicy : public IPolicy<TestContext>
+{
+public:
+ TestPolicy(int params, int num_actions) : m_params(params), m_num_actions(num_actions) { }
+ u32 Choose_Action(TestContext& context)
+ {
+ return m_params % m_num_actions + 1; // action id is one-based
+ }
+private:
+ int m_params;
+ int m_num_actions;
+};
+
+class TestScorer : public IScorer<TestContext>
+{
+public:
+ TestScorer(int params, int num_actions, bool uniform = true) :
+ m_params(params), m_num_actions(num_actions), m_uniform(uniform)
+ { }
+
+ vector<float> Score_Actions(TestContext& context)
+ {
+ vector<float> scores;
+ if (m_uniform)
+ {
+ for (u32 i = 0; i < m_num_actions; i++)
+ {
+ scores.push_back(m_params);
+ }
+ }
+ else
+ {
+ for (u32 i = 0; i < m_num_actions; i++)
+ {
+ scores.push_back(m_params + i);
+ }
+ }
+ return scores;
+ }
+private:
+ int m_params;
+ int m_num_actions;
+ bool m_uniform;
+};
+
+class FixedScorer : public IScorer<TestContext>
+{
+public:
+ FixedScorer(int num_actions, int value) :
+ m_num_actions(num_actions), m_value(value)
+ { }
+
+ vector<float> Score_Actions(TestContext& context)
+ {
+ vector<float> scores;
+ for (u32 i = 0; i < m_num_actions; i++)
+ {
+ scores.push_back((float)m_value);
+ }
+ return scores;
+ }
+private:
+ int m_num_actions;
+ int m_value;
+};
+
+class TestSimpleScorer : public IScorer<SimpleContext>
+{
+public:
+ TestSimpleScorer(int params, int num_actions) : m_params(params), m_num_actions(num_actions) { }
+ vector<float> Score_Actions(SimpleContext& context)
+ {
+ vector<float> scores;
+ for (u32 i = 0; i < m_num_actions; i++)
+ {
+ scores.push_back(m_params);
+ }
+ return scores;
+ }
+private:
+ int m_params;
+ int m_num_actions;
+};
+
+class TestSimplePolicy : public IPolicy<SimpleContext>
+{
+public:
+ TestSimplePolicy(int params, int num_actions) : m_params(params), m_num_actions(num_actions) { }
+ u32 Choose_Action(SimpleContext& context)
+ {
+ return m_params % m_num_actions + 1; // action id is one-based
+ }
+private:
+ int m_params;
+ int m_num_actions;
+};
+
+class TestSimpleRecorder : public IRecorder<SimpleContext>
+{
+public:
+ virtual void Record(SimpleContext& context, u32 action, float probability, string unique_key)
+ {
+ m_interactions.push_back({ context, action, probability, unique_key });
+ }
+
+ vector<TestInteraction<SimpleContext>> Get_All_Interactions()
+ {
+ return m_interactions;
+ }
+
+private:
+ vector<TestInteraction<SimpleContext>> m_interactions;
+};
+
+// Return action outside valid range
+class TestBadPolicy : public IPolicy<TestContext>
+{
+public:
+ u32 Choose_Action(TestContext& context)
+ {
+ return 100;
+ }
+};
+
+class TestRecorder : public IRecorder<TestContext>
+{
+public:
+ virtual void Record(TestContext& context, u32 action, float probability, string unique_key)
+ {
+ m_interactions.push_back({ context, action, probability, unique_key });
+ }
+
+ vector<TestInteraction<TestContext>> Get_All_Interactions()
+ {
+ return m_interactions;
+ }
+
+private:
+ vector<TestInteraction<TestContext>> m_interactions;
+};
diff --git a/java/pom.xml b/java/pom.xml
index 5ef652bb..db48a576 100644
--- a/java/pom.xml
+++ b/java/pom.xml
@@ -1,12 +1,11 @@
<?xml version="1.0" encoding="UTF-8"?>
-<project xmlns="http://maven.apache.org/POM/4.0.0"
- xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
- xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
+<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<groupId>vw</groupId>
<artifactId>vw-jni-${os.version}-${os.arch}</artifactId>
- <version>1.0-SNAPSHOT</version>
+ <version>1.0.0-SNAPSHOT</version>
+ <name>Vowpal Wabbit JNI Layer</name>
<properties>
<slf4j.version>1.7.6</slf4j.version>
@@ -28,14 +27,6 @@
</dependencies>
<build>
- <resources>
- <resource>
- <directory>${project.build.directory}</directory>
- <includes>
- <include>vw_jni.lib</include>
- </includes>
- </resource>
- </resources>
<testResources>
<testResource>
<directory>${project.build.directory}</directory>
@@ -68,7 +59,7 @@
<configuration>
<exportAntProperties>true</exportAntProperties>
<target>
- <exec executable="make" failonerror="true"/>
+ <exec executable="make" failonerror="true" />
</target>
</configuration>
</execution>
@@ -82,7 +73,7 @@
<exportAntProperties>true</exportAntProperties>
<target>
<exec executable="make" failonerror="true">
- <arg value="clean"/>
+ <arg value="clean" />
</exec>
</target>
</configuration>
@@ -145,6 +136,53 @@
<redirectTestOutputToFile>true</redirectTestOutputToFile>
</configuration>
</plugin>
+ <plugin>
+ <artifactId>maven-resources-plugin</artifactId>
+ <version>2.7</version>
+ <executions>
+ <execution>
+ <id>copy-resources</id>
+ <phase>process-classes</phase>
+ <goals>
+ <goal>copy-resources</goal>
+ </goals>
+ <configuration>
+ <outputDirectory>${project.build.outputDirectory}</outputDirectory>
+ <resources>
+ <resource>
+ <directory>${project.build.directory}</directory>
+ <includes>
+ <include>vw_jni.lib</include>
+ </includes>
+ </resource>
+ </resources>
+ </configuration>
+ </execution>
+ </executions>
+ </plugin>
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-release-plugin</artifactId>
+ <version>2.4.2</version>
+ <dependencies>
+ <dependency>
+ <groupId>org.apache.maven.scm</groupId>
+ <artifactId>maven-scm-provider-gitexe</artifactId>
+ <version>1.3</version>
+ </dependency>
+ </dependencies>
+ <executions>
+ <execution>
+ <id>default</id>
+ <goals>
+ <goal>perform</goal>
+ </goals>
+ <configuration>
+ <pomFileName>java/pom.xml</pomFileName>
+ </configuration>
+ </execution>
+ </executions>
+ </plugin>
</plugins>
</build>
-</project> \ No newline at end of file
+</project>
diff --git a/java/src/main/c++/vw_VWScorer.cc b/java/src/main/c++/vw_VWScorer.cc
index 652d722b..dce1869e 100644
--- a/java/src/main/c++/vw_VWScorer.cc
+++ b/java/src/main/c++/vw_VWScorer.cc
@@ -4,13 +4,11 @@
vw* vw;
-JNIEXPORT void JNICALL Java_vw_VWScorer_initialize
- (JNIEnv *env, jobject obj, jstring command) {
+JNIEXPORT void JNICALL Java_vw_VWScorer_initialize (JNIEnv *env, jobject obj, jstring command) {
vw = VW::initialize(env->GetStringUTFChars(command, NULL));
}
-JNIEXPORT jfloat JNICALL Java_vw_VWScorer_getPrediction
- (JNIEnv *env, jobject obj, jstring example_string) {
+JNIEXPORT jfloat JNICALL Java_vw_VWScorer_getPrediction (JNIEnv *env, jobject obj, jstring example_string) {
example *vec2 = VW::read_example(*vw, env->GetStringUTFChars(example_string, NULL));
vw->l->predict(*vec2);
float prediction;
@@ -18,7 +16,7 @@ JNIEXPORT jfloat JNICALL Java_vw_VWScorer_getPrediction
prediction = vec2->pred.scalar;
else
prediction = vec2->pred.multiclass;
- VW::finish_example(*vw, vec2);
+ VW::finish_example(*vw, vec2);
return prediction;
}
diff --git a/java/src/test/java/vw/VWScorerTest.java b/java/src/test/java/vw/VWScorerTest.java
index 740fefce..05b6ae61 100644
--- a/java/src/test/java/vw/VWScorerTest.java
+++ b/java/src/test/java/vw/VWScorerTest.java
@@ -3,6 +3,8 @@ package vw;
import org.junit.BeforeClass;
import org.junit.Test;
+import java.io.IOException;
+
import static org.junit.Assert.assertEquals;
/**
@@ -12,19 +14,23 @@ public class VWScorerTest {
private static VWScorer scorer;
@BeforeClass
- public static void setup() {
- scorer = new VWScorer("-i src/test/resources/house.model --quiet -t");
+ public static void setup() throws IOException, InterruptedException {
+ // Since we want this test to continue to work between VW changes, we can't store the model
+ // Instead, we'll make a new model for each test
+ String vwModel = "target/house.model";
+ Runtime.getRuntime().exec("../vowpalwabbit/vw -d src/test/resources/house.vw -f " + vwModel).waitFor();
+ scorer = new VWScorer("--quiet -t -i " + vwModel);
}
@Test
public void testBlank() {
float prediction = scorer.getPrediction("| ");
- assertEquals(0.07496, prediction, 0.00001);
+ assertEquals(0.075, prediction, 0.001);
}
@Test
public void testLine() {
float prediction = scorer.getPrediction("| price:0.23 sqft:0.25 age:0.05 2006");
- assertEquals(0.118467, prediction, 0.00001);
+ assertEquals(0.118, prediction, 0.001);
}
}
diff --git a/java/src/test/resources/house.model b/java/src/test/resources/house.model
deleted file mode 100644
index 20a3f310..00000000
--- a/java/src/test/resources/house.model
+++ /dev/null
Binary files differ
diff --git a/java/src/test/resources/house.vw b/java/src/test/resources/house.vw
new file mode 100644
index 00000000..87da95e5
--- /dev/null
+++ b/java/src/test/resources/house.vw
@@ -0,0 +1,3 @@
+0 | price:.23 sqft:.25 age:.05 2006
+1 2 'second_house | price:.18 sqft:.15 age:.35 1976
+0 1 0.5 'third_house | price:.53 sqft:.32 age:.87 1924
diff --git a/library/Makefile b/library/Makefile
index 8f7deb35..5ca18f35 100644
--- a/library/Makefile
+++ b/library/Makefile
@@ -40,3 +40,4 @@ gd_mf_weights: gd_mf_weights.cc ../vowpalwabbit/libvw.a
clean:
rm -f *.o ezexample_predict ezexample_train library_example test_search recommend ezexample_predict_threaded
+.PHONY: all clean
diff --git a/library/ezexample_predict.cc b/library/ezexample_predict.cc
index db061f61..22c8a86d 100644
--- a/library/ezexample_predict.cc
+++ b/library/ezexample_predict.cc
@@ -1,55 +1,55 @@
-#include <stdio.h>
-#include "../vowpalwabbit/parser.h"
-#include "../vowpalwabbit/vw.h"
-#include "../vowpalwabbit/ezexample.h"
-
-using namespace std;
-
-int main(int argc, char *argv[])
-{
- string init_string = "-t -q st --hash all --noconstant --ldf_override s -i ";
- if (argc > 1)
- init_string += argv[1];
- else
- init_string += "train.w";
-
- cerr << "initializing with: '" << init_string << "'" << endl;
-
- // INITIALIZE WITH WHATEVER YOU WOULD PUT ON THE VW COMMAND LINE -- THIS READS IN A MODEL FROM train.w
- vw* vw = VW::initialize(init_string); // "-t -q st --hash all --noconstant --ldf_override s -i train.w");
-
- {
- // HAL'S SPIFFY INTERFACE USING C++ CRAZINESS
- ezexample ex(vw, false); // don't need multiline
- ex(vw_namespace('s'))
- ("p^the_man")
- ("w^the")
- ("w^man")
- (vw_namespace('t'))
- ("p^le_homme")
- ("w^le")
- ("w^homme");
- ex.set_label("1");
- cerr << ex.predict_partial() << endl;
-
- // ex.clear_features();
-
- --ex; // remove the most recent namespace
- ex(vw_namespace('t'))
- ("p^un_homme")
- ("w^un")
- ("w^homme");
- ex.set_label("2");
- cerr << ex.predict_partial() << endl;
-
- --ex; // remove the most recent namespace, and add features with explicit ns
- ex('t', "p^un_homme")
- ('t', "w^un")
- ('t', "w^homme");
- ex.set_label("2");
- cerr << ex.predict_partial() << endl;
- }
-
- // AND FINISH UP
- VW::finish(*vw);
-}
+#include <stdio.h>
+#include "../vowpalwabbit/parser.h"
+#include "../vowpalwabbit/vw.h"
+#include "../vowpalwabbit/ezexample.h"
+
+using namespace std;
+
+int main(int argc, char *argv[])
+{
+ string init_string = "-t -q st --hash all --noconstant --ldf_override s -i ";
+ if (argc > 1)
+ init_string += argv[1];
+ else
+ init_string += "train.w";
+
+ cerr << "initializing with: '" << init_string << "'" << endl;
+
+ // INITIALIZE WITH WHATEVER YOU WOULD PUT ON THE VW COMMAND LINE -- THIS READS IN A MODEL FROM train.w
+ vw* vw = VW::initialize(init_string); // "-t -q st --hash all --noconstant --ldf_override s -i train.w");
+
+ {
+ // HAL'S SPIFFY INTERFACE USING C++ CRAZINESS
+ ezexample ex(vw, false); // don't need multiline
+ ex(vw_namespace('s'))
+ ("p^the_man")
+ ("w^the")
+ ("w^man")
+ (vw_namespace('t'))
+ ("p^le_homme")
+ ("w^le")
+ ("w^homme");
+ ex.set_label("1");
+ cerr << ex.predict_partial() << endl;
+
+ // ex.clear_features();
+
+ --ex; // remove the most recent namespace
+ ex(vw_namespace('t'))
+ ("p^un_homme")
+ ("w^un")
+ ("w^homme");
+ ex.set_label("2");
+ cerr << ex.predict_partial() << endl;
+
+ --ex; // remove the most recent namespace, and add features with explicit ns
+ ex('t', "p^un_homme")
+ ('t', "w^un")
+ ('t', "w^homme");
+ ex.set_label("2");
+ cerr << ex.predict_partial() << endl;
+ }
+
+ // AND FINISH UP
+ VW::finish(*vw);
+}
diff --git a/library/ezexample_predict_threaded.cc b/library/ezexample_predict_threaded.cc
index 0fa5b1e6..c7c39e85 100644
--- a/library/ezexample_predict_threaded.cc
+++ b/library/ezexample_predict_threaded.cc
@@ -1,149 +1,149 @@
-#include <stdio.h>
-#include "../vowpalwabbit/vw.h"
-#include "../vowpalwabbit/ezexample.h"
-
-#include <boost/thread/thread.hpp>
-
-using namespace std;
-
-int runcount = 100;
-
-class Worker
-{
-public:
- Worker(vw & instance, string & vw_init_string, vector<double> & ref)
- : m_vw(instance)
- , m_referenceValues(ref)
- , vw_init_string(vw_init_string)
- { }
-
- void operator()()
- {
- m_vw_parser = VW::initialize(vw_init_string);
- if (m_vw_parser == NULL) {
- cerr << "cannot initialize vw parser" << endl;
- exit(-1);
- }
-
- int errorCount = 0;
- for (int i = 0; i < runcount; ++i)
- {
- vector<double>::iterator it = m_referenceValues.begin();
- ezexample ex(&m_vw, false, m_vw_parser);
-
- ex(vw_namespace('s'))
- ("p^the_man")
- ("w^the")
- ("w^man")
- (vw_namespace('t'))
- ("p^le_homme")
- ("w^le")
- ("w^homme");
- ex.set_label("1");
- if (*it != ex()) { cerr << "fail!" << endl; ++errorCount; }
- //if (*it != pred) { cerr << "fail!" << endl; ++errorCount; }
- //VW::finish_example(m_vw, vec2);
- ++it;
-
- --ex; // remove the most recent namespace
- ex(vw_namespace('t'))
- ("p^un_homme")
- ("w^un")
- ("w^homme");
- ex.set_label("1");
- if (*it != ex()) { cerr << "fail!" << endl; ++errorCount; }
- ++it;
-
- --ex; // remove the most recent namespace
- // add features with explicit ns
- ex('t', "p^un_homme")
- ('t', "w^un")
- ('t', "w^homme");
- ex.set_label("1");
- if (*it != ex()) { cerr << "fail!" << endl; ++errorCount; }
- ++it;
-
- //cout << "."; cout.flush();
- }
- cerr << "error count = " << errorCount << endl;
- VW::finish(*m_vw_parser);
- m_vw_parser = NULL;
- }
-
-private:
- vw & m_vw;
- vw * m_vw_parser;
- vector<double> & m_referenceValues;
- string & vw_init_string;
-};
-
-int main(int argc, char *argv[])
-{
- if (argc != 3)
- {
- cerr << "need two args: threadcount runcount" << endl;
- return 1;
- }
- int threadcount = atoi(argv[1]);
- runcount = atoi(argv[2]);
- // INITIALIZE WITH WHATEVER YOU WOULD PUT ON THE VW COMMAND LINE -- THIS READS IN A MODEL FROM train.w
- string vw_init_string_all = "-t --ldf_override s --quiet -q st --noconstant --hash all -i train.w";
- string vw_init_string_parser = "-t --ldf_override s --quiet -q st --noconstant --hash all --noop"; // this needs to have enough arguments to get the parser right
- vw*vw = VW::initialize(vw_init_string_all);
- vector<double> results;
-
- // HAL'S SPIFFY INTERFACE USING C++ CRAZINESS
- {
- ezexample ex(vw, false);
- ex(vw_namespace('s'))
- ("p^the_man")
- ("w^the")
- ("w^man")
- (vw_namespace('t'))
- ("p^le_homme")
- ("w^le")
- ("w^homme");
- ex.set_label("1");
- results.push_back(ex.predict_partial());
- cerr << "should be near zero = " << ex.predict_partial() << endl;
-
- --ex; // remove the most recent namespace
- ex(vw_namespace('t'))
- ("p^un_homme")
- ("w^un")
- ("w^homme");
- ex.set_label("1");
- results.push_back(ex.predict_partial());
- cerr << "should be near one = " << ex.predict_partial() << endl;
-
- --ex; // remove the most recent namespace
- // add features with explicit ns
- ex('t', "p^un_homme")
- ('t', "w^un")
- ('t', "w^homme");
- ex.set_label("1");
- results.push_back(ex.predict_partial());
- cerr << "should be near one = " << ex.predict_partial() << endl;
- }
-
- if (threadcount == 0)
- {
- Worker w(*vw, vw_init_string_parser, results);
- w();
- }
- else
- {
- boost::thread_group tg;
- for (int t = 0; t < threadcount; ++t)
- {
- cerr << "starting thread " << t << endl;
- boost::thread * pt = tg.create_thread(Worker(*vw, vw_init_string_parser, results));
- }
- tg.join_all();
- cerr << "finished!" << endl;
- }
-
-
- // AND FINISH UP
- VW::finish(*vw);
-}
+#include <stdio.h>
+#include "../vowpalwabbit/vw.h"
+#include "../vowpalwabbit/ezexample.h"
+
+#include <boost/thread/thread.hpp>
+
+using namespace std;
+
+int runcount = 100;
+
+class Worker
+{
+public:
+ Worker(vw & instance, string & vw_init_string, vector<double> & ref)
+ : m_vw(instance)
+ , m_referenceValues(ref)
+ , vw_init_string(vw_init_string)
+ { }
+
+ void operator()()
+ {
+ m_vw_parser = VW::initialize(vw_init_string);
+ if (m_vw_parser == NULL) {
+ cerr << "cannot initialize vw parser" << endl;
+ exit(-1);
+ }
+
+ int errorCount = 0;
+ for (int i = 0; i < runcount; ++i)
+ {
+ vector<double>::iterator it = m_referenceValues.begin();
+ ezexample ex(&m_vw, false, m_vw_parser);
+
+ ex(vw_namespace('s'))
+ ("p^the_man")
+ ("w^the")
+ ("w^man")
+ (vw_namespace('t'))
+ ("p^le_homme")
+ ("w^le")
+ ("w^homme");
+ ex.set_label("1");
+ if (*it != ex()) { cerr << "fail!" << endl; ++errorCount; }
+ //if (*it != pred) { cerr << "fail!" << endl; ++errorCount; }
+ //VW::finish_example(m_vw, vec2);
+ ++it;
+
+ --ex; // remove the most recent namespace
+ ex(vw_namespace('t'))
+ ("p^un_homme")
+ ("w^un")
+ ("w^homme");
+ ex.set_label("1");
+ if (*it != ex()) { cerr << "fail!" << endl; ++errorCount; }
+ ++it;
+
+ --ex; // remove the most recent namespace
+ // add features with explicit ns
+ ex('t', "p^un_homme")
+ ('t', "w^un")
+ ('t', "w^homme");
+ ex.set_label("1");
+ if (*it != ex()) { cerr << "fail!" << endl; ++errorCount; }
+ ++it;
+
+ //cout << "."; cout.flush();
+ }
+ cerr << "error count = " << errorCount << endl;
+ VW::finish(*m_vw_parser);
+ m_vw_parser = NULL;
+ }
+
+private:
+ vw & m_vw;
+ vw * m_vw_parser;
+ vector<double> & m_referenceValues;
+ string & vw_init_string;
+};
+
+int main(int argc, char *argv[])
+{
+ if (argc != 3)
+ {
+ cerr << "need two args: threadcount runcount" << endl;
+ return 1;
+ }
+ int threadcount = atoi(argv[1]);
+ runcount = atoi(argv[2]);
+ // INITIALIZE WITH WHATEVER YOU WOULD PUT ON THE VW COMMAND LINE -- THIS READS IN A MODEL FROM train.w
+ string vw_init_string_all = "-t --ldf_override s --quiet -q st --noconstant --hash all -i train.w";
+ string vw_init_string_parser = "-t --ldf_override s --quiet -q st --noconstant --hash all --noop"; // this needs to have enough arguments to get the parser right
+ vw*vw = VW::initialize(vw_init_string_all);
+ vector<double> results;
+
+ // HAL'S SPIFFY INTERFACE USING C++ CRAZINESS
+ {
+ ezexample ex(vw, false);
+ ex(vw_namespace('s'))
+ ("p^the_man")
+ ("w^the")
+ ("w^man")
+ (vw_namespace('t'))
+ ("p^le_homme")
+ ("w^le")
+ ("w^homme");
+ ex.set_label("1");
+ results.push_back(ex.predict_partial());
+ cerr << "should be near zero = " << ex.predict_partial() << endl;
+
+ --ex; // remove the most recent namespace
+ ex(vw_namespace('t'))
+ ("p^un_homme")
+ ("w^un")
+ ("w^homme");
+ ex.set_label("1");
+ results.push_back(ex.predict_partial());
+ cerr << "should be near one = " << ex.predict_partial() << endl;
+
+ --ex; // remove the most recent namespace
+ // add features with explicit ns
+ ex('t', "p^un_homme")
+ ('t', "w^un")
+ ('t', "w^homme");
+ ex.set_label("1");
+ results.push_back(ex.predict_partial());
+ cerr << "should be near one = " << ex.predict_partial() << endl;
+ }
+
+ if (threadcount == 0)
+ {
+ Worker w(*vw, vw_init_string_parser, results);
+ w();
+ }
+ else
+ {
+ boost::thread_group tg;
+ for (int t = 0; t < threadcount; ++t)
+ {
+ cerr << "starting thread " << t << endl;
+ boost::thread * pt = tg.create_thread(Worker(*vw, vw_init_string_parser, results));
+ }
+ tg.join_all();
+ cerr << "finished!" << endl;
+ }
+
+
+ // AND FINISH UP
+ VW::finish(*vw);
+}
diff --git a/library/ezexample_train.cc b/library/ezexample_train.cc
index a0f66a99..9a8af8e0 100644
--- a/library/ezexample_train.cc
+++ b/library/ezexample_train.cc
@@ -1,72 +1,72 @@
-#include <stdio.h>
-#include "../vowpalwabbit/parser.h"
-#include "../vowpalwabbit/vw.h"
-#include "../vowpalwabbit/ezexample.h"
-
-using namespace std;
-
-void run(vw*vw) {
- ezexample ex(vw, true); // we're doing csoaa_ldf so we need multiline examples
-
- /// BEGIN FIRST MULTILINE EXAMPLE
- ex(vw_namespace('s'))
- ("p^the_man")
- ("w^the")
- ("w^man")
- (vw_namespace('t'))
- ("p^un_homme")
- ("w^un")
- ("w^homme");
-
- ex.set_label("1:1");
- ex.train();
-
- --ex; // remove the most recent namespace
- ex(vw_namespace('t'))
- ("p^le_homme")
- ("w^le")
- ("w^homme");
-
- ex.set_label("2:0");
- ex.train();
-
- // push it through VW for training
- ex.finish();
-
- /// BEGIN SECOND MULTILINE EXAMPLE
- ex(vw_namespace('s'))
- ("p^a_man")
- ("w^a")
- ("w^man")
- (vw_namespace('t'))
- ("p^un_homme")
- ("w^un")
- ("w^homme");
-
- ex.set_label("1:0");
- ex.train();
-
- --ex; // remove the most recent namespace
- ex(vw_namespace('t'))
- ("p^le_homme")
- ("w^le")
- ("w^homme");
-
- ex.set_label("2:1");
- ex.train();
-
- // push it through VW for training
- ex.finish();
-}
-
-int main(int argc, char *argv[])
-{
- // INITIALIZE WITH WHATEVER YOU WOULD PUT ON THE VW COMMAND LINE -- THIS WILL STORE A MODEL TO train.ezw
- vw* vw = VW::initialize("--hash all -q st --noconstant -f train.w --quiet --csoaa_ldf m");
-
- run(vw);
-
- // AND FINISH UP
- cerr << "ezexample_train finish"<<endl;
- VW::finish(*vw);
-}
+#include <stdio.h>
+#include "../vowpalwabbit/parser.h"
+#include "../vowpalwabbit/vw.h"
+#include "../vowpalwabbit/ezexample.h"
+
+using namespace std;
+
+void run(vw*vw) {
+ ezexample ex(vw, true); // we're doing csoaa_ldf so we need multiline examples
+
+ /// BEGIN FIRST MULTILINE EXAMPLE
+ ex(vw_namespace('s'))
+ ("p^the_man")
+ ("w^the")
+ ("w^man")
+ (vw_namespace('t'))
+ ("p^un_homme")
+ ("w^un")
+ ("w^homme");
+
+ ex.set_label("1:1");
+ ex.train();
+
+ --ex; // remove the most recent namespace
+ ex(vw_namespace('t'))
+ ("p^le_homme")
+ ("w^le")
+ ("w^homme");
+
+ ex.set_label("2:0");
+ ex.train();
+
+ // push it through VW for training
+ ex.finish();
+
+ /// BEGIN SECOND MULTILINE EXAMPLE
+ ex(vw_namespace('s'))
+ ("p^a_man")
+ ("w^a")
+ ("w^man")
+ (vw_namespace('t'))
+ ("p^un_homme")
+ ("w^un")
+ ("w^homme");
+
+ ex.set_label("1:0");
+ ex.train();
+
+ --ex; // remove the most recent namespace
+ ex(vw_namespace('t'))
+ ("p^le_homme")
+ ("w^le")
+ ("w^homme");
+
+ ex.set_label("2:1");
+ ex.train();
+
+ // push it through VW for training
+ ex.finish();
+}
+
+int main(int argc, char *argv[])
+{
+ // INITIALIZE WITH WHATEVER YOU WOULD PUT ON THE VW COMMAND LINE -- THIS WILL STORE A MODEL TO train.ezw
+ vw* vw = VW::initialize("--hash all -q st --noconstant -f train.w --quiet --csoaa_ldf m");
+
+ run(vw);
+
+ // AND FINISH UP
+ cerr << "ezexample_train finish"<<endl;
+ VW::finish(*vw);
+}
diff --git a/library/gd_mf_weights.cc b/library/gd_mf_weights.cc
index 4b394775..268739a4 100644
--- a/library/gd_mf_weights.cc
+++ b/library/gd_mf_weights.cc
@@ -52,12 +52,17 @@ int main(int argc, char *argv[])
vw* model = VW::initialize(vwparams);
model->audit = true;
+ string target("--rank ");
+ size_t loc = vwparams.find(target);
+ const char* location = vwparams.c_str()+loc+target.size();
+ size_t rank = atoi(location);
+
// global model params
unsigned char left_ns = model->pairs[0][0];
unsigned char right_ns = model->pairs[0][1];
weight* weights = model->reg.weight_vector;
size_t mask = model->reg.weight_mask;
-
+
// const char *filename = argv[0];
FILE* file = fopen(infile.c_str(), "r");
char* line = NULL;
@@ -86,7 +91,7 @@ int main(int argc, char *argv[])
left_linear << f->feature << '\t' << weights[f->weight_index & mask];
left_quadratic << f->feature;
- for (size_t k = 1; k <= model->rank; k++)
+ for (size_t k = 1; k <= rank; k++)
left_quadratic << '\t' << weights[(f->weight_index + k) & mask];
}
left_linear << endl;
@@ -101,8 +106,8 @@ int main(int argc, char *argv[])
right_linear << f->feature << '\t' << weights[f->weight_index & mask];
right_quadratic << f->feature;
- for (size_t k = 1; k <= model->rank; k++)
- right_quadratic << '\t' << weights[(f->weight_index + k + model->rank) & mask];
+ for (size_t k = 1; k <= rank; k++)
+ right_quadratic << '\t' << weights[(f->weight_index + k + rank) & mask];
}
right_linear << endl;
right_quadratic << endl;
diff --git a/library/library_example.cc b/library/library_example.cc
index 8abfc1f7..d7c186c2 100644
--- a/library/library_example.cc
+++ b/library/library_example.cc
@@ -1,62 +1,62 @@
-#include <stdio.h>
-#include "../vowpalwabbit/parser.h"
-#include "../vowpalwabbit/vw.h"
-
-using namespace std;
-
-
-inline feature vw_feature_from_string(vw& v, string fstr, unsigned long seed, float val)
-{
- uint32_t foo = VW::hash_feature(v, fstr, seed);
- feature f = { val, foo};
- return f;
-}
-
-int main(int argc, char *argv[])
-{
- vw* model = VW::initialize("--hash all -q st --noconstant -i train.w -f train2.vw");
-
- example *vec2 = VW::read_example(*model, (char*)"|s p^the_man w^the w^man |t p^un_homme w^un w^homme");
- model->learn(vec2);
- cerr << "p2 = " << vec2->pred.scalar << endl;
- VW::finish_example(*model, vec2);
-
- vector< VW::feature_space > ec_info;
- vector<feature> s_features, t_features;
- uint32_t s_hash = VW::hash_space(*model, "s");
- uint32_t t_hash = VW::hash_space(*model, "t");
- s_features.push_back( vw_feature_from_string(*model, "p^the_man", s_hash, 1.0) );
- s_features.push_back( vw_feature_from_string(*model, "w^the", s_hash, 1.0) );
- s_features.push_back( vw_feature_from_string(*model, "w^man", s_hash, 1.0) );
- t_features.push_back( vw_feature_from_string(*model, "p^le_homme", t_hash, 1.0) );
- t_features.push_back( vw_feature_from_string(*model, "w^le", t_hash, 1.0) );
- t_features.push_back( vw_feature_from_string(*model, "w^homme", t_hash, 1.0) );
- ec_info.push_back( VW::feature_space('s', s_features) );
- ec_info.push_back( VW::feature_space('t', t_features) );
- example* vec3 = VW::import_example(*model, ec_info);
-
- model->learn(vec3);
- cerr << "p3 = " << vec3->pred.scalar << endl;
- VW::finish_example(*model, vec3);
-
- VW::finish(*model);
-
- vw* model2 = VW::initialize("--hash all -q st --noconstant -i train2.vw");
- vec2 = VW::read_example(*model2, (char*)" |s p^the_man w^the w^man |t p^un_homme w^un w^homme");
- model2->learn(vec2);
- cerr << "p4 = " << vec2->pred.scalar << endl;
-
- size_t len=0;
- VW::primitive_feature_space* pfs = VW::export_example(*model2, vec2, len);
- for (size_t i = 0; i < len; i++)
- {
- cout << "namespace = " << pfs[i].name;
- for (size_t j = 0; j < pfs[i].len; j++)
- cout << " " << pfs[i].fs[j].weight_index << ":" << pfs[i].fs[j].x << ":" << VW::get_weight(*model2, pfs[i].fs[j].weight_index, 0);
- cout << endl;
- }
-
- VW::finish_example(*model2, vec2);
- VW::finish(*model2);
-}
-
+#include <stdio.h>
+#include "../vowpalwabbit/parser.h"
+#include "../vowpalwabbit/vw.h"
+
+using namespace std;
+
+
+inline feature vw_feature_from_string(vw& v, string fstr, unsigned long seed, float val)
+{
+ uint32_t foo = VW::hash_feature(v, fstr, seed);
+ feature f = { val, foo};
+ return f;
+}
+
+int main(int argc, char *argv[])
+{
+ vw* model = VW::initialize("--hash all -q st --noconstant -i train.w -f train2.vw");
+
+ example *vec2 = VW::read_example(*model, (char*)"|s p^the_man w^the w^man |t p^un_homme w^un w^homme");
+ model->learn(vec2);
+ cerr << "p2 = " << vec2->pred.scalar << endl;
+ VW::finish_example(*model, vec2);
+
+ vector< VW::feature_space > ec_info;
+ vector<feature> s_features, t_features;
+ uint32_t s_hash = VW::hash_space(*model, "s");
+ uint32_t t_hash = VW::hash_space(*model, "t");
+ s_features.push_back( vw_feature_from_string(*model, "p^the_man", s_hash, 1.0) );
+ s_features.push_back( vw_feature_from_string(*model, "w^the", s_hash, 1.0) );
+ s_features.push_back( vw_feature_from_string(*model, "w^man", s_hash, 1.0) );
+ t_features.push_back( vw_feature_from_string(*model, "p^le_homme", t_hash, 1.0) );
+ t_features.push_back( vw_feature_from_string(*model, "w^le", t_hash, 1.0) );
+ t_features.push_back( vw_feature_from_string(*model, "w^homme", t_hash, 1.0) );
+ ec_info.push_back( VW::feature_space('s', s_features) );
+ ec_info.push_back( VW::feature_space('t', t_features) );
+ example* vec3 = VW::import_example(*model, ec_info);
+
+ model->learn(vec3);
+ cerr << "p3 = " << vec3->pred.scalar << endl;
+ VW::finish_example(*model, vec3);
+
+ VW::finish(*model);
+
+ vw* model2 = VW::initialize("--hash all -q st --noconstant -i train2.vw");
+ vec2 = VW::read_example(*model2, (char*)" |s p^the_man w^the w^man |t p^un_homme w^un w^homme");
+ model2->learn(vec2);
+ cerr << "p4 = " << vec2->pred.scalar << endl;
+
+ size_t len=0;
+ VW::primitive_feature_space* pfs = VW::export_example(*model2, vec2, len);
+ for (size_t i = 0; i < len; i++)
+ {
+ cout << "namespace = " << pfs[i].name;
+ for (size_t j = 0; j < pfs[i].len; j++)
+ cout << " " << pfs[i].fs[j].weight_index << ":" << pfs[i].fs[j].x << ":" << VW::get_weight(*model2, pfs[i].fs[j].weight_index, 0);
+ cout << endl;
+ }
+
+ VW::finish_example(*model2, vec2);
+ VW::finish(*model2);
+}
+
diff --git a/python/Makefile b/python/Makefile
index b268663d..845fb15f 100644
--- a/python/Makefile
+++ b/python/Makefile
@@ -43,3 +43,5 @@ pylibvw.o: pylibvw.cc
clean:
rm -f *.o $(PYLIBVW)
+
+.PHONY: all clean things
diff --git a/test/test-sets/ref/ml100k_small.stderr b/test/test-sets/ref/ml100k_small.stderr
index c8bcbb51..92b91e64 100644
--- a/test/test-sets/ref/ml100k_small.stderr
+++ b/test/test-sets/ref/ml100k_small.stderr
@@ -3,7 +3,6 @@ Num weight bits = 16
learning rate = 10
initial_t = 1
power_t = 0.5
-rank = 10
using no cache
Reading datafile = test-sets/ml100k_small_test
num sources = 1
diff --git a/test/train-sets/ref/ml100k_small.stderr b/test/train-sets/ref/ml100k_small.stderr
index 3f3d6e25..462dca37 100644
--- a/test/train-sets/ref/ml100k_small.stderr
+++ b/test/train-sets/ref/ml100k_small.stderr
@@ -6,7 +6,6 @@ learning rate = 0.05
initial_t = 1
power_t = 0
decay_learning_rate = 0.97
-rank = 10
creating cache_file = train-sets/ml100k_small_train.cache
Reading datafile = train-sets/ml100k_small_train
num sources = 1
diff --git a/test/train-sets/ref/sequencespan_data.nonldf-bilou.test.stderr b/test/train-sets/ref/sequencespan_data.nonldf-bilou.test.stderr
index aad69ced..5b3afa22 100644
--- a/test/train-sets/ref/sequencespan_data.nonldf-bilou.test.stderr
+++ b/test/train-sets/ref/sequencespan_data.nonldf-bilou.test.stderr
@@ -1,10 +1,10 @@
only testing
+switching to BILOU encoding for sequence span labeling
Num weight bits = 18
learning rate = 10
initial_t = 1
power_t = 0.5
predictions = sequencespan_data.predict
-switching to BILOU encoding for sequence span labeling
using no cache
Reading datafile = train-sets/sequencespan_data
num sources = 1
diff --git a/test/train-sets/ref/sequencespan_data.nonldf-bilou.train.stderr b/test/train-sets/ref/sequencespan_data.nonldf-bilou.train.stderr
index f0c914e0..a4fc72a6 100644
--- a/test/train-sets/ref/sequencespan_data.nonldf-bilou.train.stderr
+++ b/test/train-sets/ref/sequencespan_data.nonldf-bilou.train.stderr
@@ -1,10 +1,10 @@
final_regressor = models/sequencespan_data.model
+switching to BILOU encoding for sequence span labeling
Num weight bits = 18
learning rate = 10
initial_t = 1
power_t = 0.5
decay_learning_rate = 1
-switching to BILOU encoding for sequence span labeling
creating cache_file = train-sets/sequencespan_data.cache
Reading datafile = train-sets/sequencespan_data
num sources = 1
diff --git a/test/train-sets/ref/sequencespan_data.nonldf.train.stderr b/test/train-sets/ref/sequencespan_data.nonldf.train.stderr
index a75bbfd5..2ea5e7e1 100644
--- a/test/train-sets/ref/sequencespan_data.nonldf.train.stderr
+++ b/test/train-sets/ref/sequencespan_data.nonldf.train.stderr
@@ -1,10 +1,10 @@
final_regressor = models/sequencespan_data.model
+no rollout!
Num weight bits = 18
learning rate = 10
initial_t = 1
power_t = 0.5
decay_learning_rate = 1
-no rollout!
creating cache_file = train-sets/sequencespan_data.cache
Reading datafile = train-sets/sequencespan_data
num sources = 1
diff --git a/vowpalwabbit/Makefile b/vowpalwabbit/Makefile
index 79893902..9ecd681c 100644
--- a/vowpalwabbit/Makefile
+++ b/vowpalwabbit/Makefile
@@ -52,3 +52,4 @@ install: $(BINARIES)
clean:
rm -f *.o *.d $(BINARIES) *~ $(MANPAGES) libvw.a
+.PHONY: all clean install test things
diff --git a/vowpalwabbit/Makefile.am b/vowpalwabbit/Makefile.am
index 676d6f4e..291a4702 100644
--- a/vowpalwabbit/Makefile.am
+++ b/vowpalwabbit/Makefile.am
@@ -4,7 +4,7 @@ liballreduce_la_SOURCES = allreduce.cc
bin_PROGRAMS = vw active_interactor
-libvw_la_SOURCES = hash.cc memory.cc global_data.cc io_buf.cc parse_regressor.cc parse_primitives.cc unique_sort.cc cache.cc rand48.cc simple_label.cc multiclass.cc oaa.cc ect.cc autolink.cc binary.cc lrq.cc cost_sensitive.cc csoaa.cc cb.cc cb_algs.cc search.cc search_sequencetask.cc search_dep_parser.cc search_hooktask.cc search_multiclasstask.cc search_entityrelationtask.cc parse_example.cc scorer.cc network.cc parse_args.cc accumulate.cc gd.cc learner.cc lda_core.cc gd_mf.cc mf.cc bfgs.cc noop.cc print.cc example.cc parser.cc loss_functions.cc sender.cc nn.cc bs.cc cbify.cc topk.cc stagewise_poly.cc log_multi.cc active.cc kernel_svm.cc best_constant.cc ftrl_proximal.cc
+libvw_la_SOURCES = hash.cc global_data.cc io_buf.cc parse_regressor.cc parse_primitives.cc unique_sort.cc cache.cc rand48.cc simple_label.cc multiclass.cc oaa.cc ect.cc autolink.cc binary.cc lrq.cc cost_sensitive.cc csoaa.cc cb.cc cb_algs.cc search.cc search_sequencetask.cc search_dep_parser.cc search_hooktask.cc search_multiclasstask.cc search_entityrelationtask.cc parse_example.cc scorer.cc network.cc parse_args.cc accumulate.cc gd.cc learner.cc lda_core.cc gd_mf.cc mf.cc bfgs.cc noop.cc print.cc example.cc parser.cc loss_functions.cc sender.cc nn.cc bs.cc cbify.cc topk.cc stagewise_poly.cc log_multi.cc active.cc kernel_svm.cc best_constant.cc ftrl_proximal.cc
libvw_c_wrapper_la_SOURCES = vwdll.cpp
diff --git a/vowpalwabbit/accumulate.h b/vowpalwabbit/accumulate.h
index c01ac5fe..4d507a60 100644
--- a/vowpalwabbit/accumulate.h
+++ b/vowpalwabbit/accumulate.h
@@ -1,13 +1,13 @@
-/*
-Copyright (c) by respective owners including Yahoo!, Microsoft, and
-individual contributors. All rights reserved. Released under a BSD
-license as described in the file LICENSE.
- */
-//This implements various accumulate functions building on top of allreduce.
-#pragma once
-#include "global_data.h"
-
-void accumulate(vw& all, std::string master_location, regressor& reg, size_t o);
-float accumulate_scalar(vw& all, std::string master_location, float local_sum);
-void accumulate_weighted_avg(vw& all, std::string master_location, regressor& reg);
-void accumulate_avg(vw& all, std::string master_location, regressor& reg, size_t o);
+/*
+Copyright (c) by respective owners including Yahoo!, Microsoft, and
+individual contributors. All rights reserved. Released under a BSD
+license as described in the file LICENSE.
+ */
+//This implements various accumulate functions building on top of allreduce.
+#pragma once
+#include "global_data.h"
+
+void accumulate(vw& all, std::string master_location, regressor& reg, size_t o);
+float accumulate_scalar(vw& all, std::string master_location, float local_sum);
+void accumulate_weighted_avg(vw& all, std::string master_location, regressor& reg);
+void accumulate_avg(vw& all, std::string master_location, regressor& reg, size_t o);
diff --git a/vowpalwabbit/active.cc b/vowpalwabbit/active.cc
index a1070be3..54094331 100644
--- a/vowpalwabbit/active.cc
+++ b/vowpalwabbit/active.cc
@@ -151,27 +151,34 @@ namespace ACTIVE {
VW::finish_example(all,&ec);
}
- base_learner* setup(vw& all, po::variables_map& vm)
+ base_learner* setup(vw& all)
{//parse and set arguments
- active& data = calloc_or_die<active>();
-
- po::options_description active_opts("Active Learning options");
- active_opts.add_options()
+ new_options(all, "Active Learning options")
+ ("active", "enable active learning");
+ if (missing_required(all)) return NULL;
+ new_options(all)
("simulation", "active learning simulation mode")
- ("mellowness", po::value<float>(&(data.active_c0)), "active learning mellowness parameter c_0. Default 8")
- ;
- vm = add_options(all, active_opts);
+ ("mellowness", po::value<float>(), "active learning mellowness parameter c_0. Default 8");
+ add_options(all);
+
+ active& data = calloc_or_die<active>();
+ data.active_c0 = 8;
data.all=&all;
+ if (all.vm.count("mellowness"))
+ data.active_c0 = all.vm["mellowness"].as<float>();
+
+ base_learner* base = setup_base(all);
+
//Create new learner
learner<active>* ret;
- if (vm.count("simulation"))
- ret = &init_learner(&data, all.l, predict_or_learn_simulation<true>,
+ if (all.vm.count("simulation"))
+ ret = &init_learner(&data, base, predict_or_learn_simulation<true>,
predict_or_learn_simulation<false>);
else
{
all.active = true;
- ret = &init_learner(&data, all.l, predict_or_learn_active<true>,
+ ret = &init_learner(&data, base, predict_or_learn_active<true>,
predict_or_learn_active<false>);
ret->set_finish_example(return_active_example);
}
diff --git a/vowpalwabbit/active.h b/vowpalwabbit/active.h
index d71950ab..b1145e8c 100644
--- a/vowpalwabbit/active.h
+++ b/vowpalwabbit/active.h
@@ -1,4 +1,2 @@
#pragma once
-namespace ACTIVE {
- LEARNER::base_learner* setup(vw& all, po::variables_map& vm);
-}
+namespace ACTIVE { LEARNER::base_learner* setup(vw& all); }
diff --git a/vowpalwabbit/autolink.cc b/vowpalwabbit/autolink.cc
index 7cdaecef..707b38a6 100644
--- a/vowpalwabbit/autolink.cc
+++ b/vowpalwabbit/autolink.cc
@@ -37,15 +37,21 @@ namespace ALINK {
ec.total_sum_feat_sq -= sum_sq;
}
- LEARNER::base_learner* setup(vw& all, po::variables_map& vm)
+ LEARNER::base_learner* setup(vw& all)
{
+ new_options(all,"Autolink options")
+ ("autolink", po::value<size_t>(), "create link function with polynomial d");
+ if(missing_required(all)) return NULL;
+
autolink& data = calloc_or_die<autolink>();
- data.d = (uint32_t)vm["autolink"].as<size_t>();
+ data.d = (uint32_t)all.vm["autolink"].as<size_t>();
data.stride_shift = all.reg.stride_shift;
*all.file_options << " --autolink " << data.d;
- LEARNER::learner<autolink>& ret = init_learner(&data, all.l, predict_or_learn<true>,
+ LEARNER::base_learner* base = setup_base(all);
+
+ LEARNER::learner<autolink>& ret = init_learner(&data, base, predict_or_learn<true>,
predict_or_learn<false>);
return make_base(ret);
}
diff --git a/vowpalwabbit/autolink.h b/vowpalwabbit/autolink.h
index 3bb70dc1..e7bfcf51 100644
--- a/vowpalwabbit/autolink.h
+++ b/vowpalwabbit/autolink.h
@@ -1,4 +1,2 @@
#pragma once
-namespace ALINK {
- LEARNER::base_learner* setup(vw& all, po::variables_map& vm);
-}
+namespace ALINK { LEARNER::base_learner* setup(vw& all); }
diff --git a/vowpalwabbit/bfgs.cc b/vowpalwabbit/bfgs.cc
index fc4ee851..a14d085c 100644
--- a/vowpalwabbit/bfgs.cc
+++ b/vowpalwabbit/bfgs.cc
@@ -63,6 +63,9 @@ namespace BFGS
struct bfgs {
vw* all;
+ int m;
+ float rel_threshold; // termination threshold
+
double wolfe1_bound;
size_t final_pass;
@@ -247,7 +250,7 @@ void bfgs_iter_start(vw& all, bfgs& b, float* mem, int& lastj, double importance
origin = 0;
for(uint32_t i = 0; i < length; i++, mem+=b.mem_stride, w+=stride) {
- if (all.m>0)
+ if (b.m>0)
mem[(MEM_XT+origin)%b.mem_stride] = w[W_XT];
mem[(MEM_GT+origin)%b.mem_stride] = w[W_GT];
g1_Hg1 += w[W_GT] * w[W_GT] * w[W_COND];
@@ -272,7 +275,7 @@ void bfgs_iter_middle(vw& all, bfgs& b, float* mem, double* rho, double* alpha,
float* w0 = w;
// implement conjugate gradient
- if (all.m==0) {
+ if (b.m==0) {
double g_Hy = 0.;
double g_Hg = 0.;
double y = 0.;
@@ -374,7 +377,7 @@ void bfgs_iter_middle(vw& all, bfgs& b, float* mem, double* rho, double* alpha,
mem = mem0;
w = w0;
- lastj = (lastj<all.m-1) ? lastj+1 : all.m-1;
+ lastj = (lastj<b.m-1) ? lastj+1 : b.m-1;
origin = (origin+b.mem_stride-2)%b.mem_stride;
for(uint32_t i = 0; i < length; i++, mem+=b.mem_stride, w+=stride) {
mem[(MEM_GT+origin)%b.mem_stride] = w[W_GT];
@@ -633,9 +636,9 @@ int process_pass(vw& all, bfgs& b) {
/********************************************************************/
else {
double rel_decrease = (b.previous_loss_sum-b.loss_sum)/b.previous_loss_sum;
- if (!nanpattern((float)rel_decrease) && b.backstep_on && fabs(rel_decrease)<all.rel_threshold) {
+ if (!nanpattern((float)rel_decrease) && b.backstep_on && fabs(rel_decrease)<b.rel_threshold) {
fprintf(stdout, "\nTermination condition reached in pass %ld: decrease in loss less than %.3f%%.\n"
- "If you want to optimize further, decrease termination threshold.\n", (long int)b.current_pass+1, all.rel_threshold*100.0);
+ "If you want to optimize further, decrease termination threshold.\n", (long int)b.current_pass+1, b.rel_threshold*100.0);
status = LEARN_CONV;
}
b.previous_loss_sum = b.loss_sum;
@@ -913,7 +916,7 @@ void save_load(bfgs& b, io_buf& model_file, bool read, bool text)
throw exception();
}
}
- int m = all->m;
+ int m = b.m;
b.mem_stride = (m==0) ? CG_EXTRA : 2*m;
b.mem = (float*) malloc(sizeof(float)*all->length()*(b.mem_stride));
@@ -965,10 +968,23 @@ void save_load(bfgs& b, io_buf& model_file, bool read, bool text)
b.backstep_on = true;
}
-base_learner* setup(vw& all, po::variables_map& vm)
+base_learner* setup(vw& all)
{
+ new_options(all, "LBFGS options")
+ ("bfgs", "use bfgs optimization")
+ ("conjugate_gradient", "use conjugate gradient based optimization");
+ if (missing_required(all)) return NULL;
+ new_options(all)
+ ("hessian_on", "use second derivative in line search")
+ ("mem", po::value<uint32_t>()->default_value(15), "memory in bfgs")
+ ("termination", po::value<float>()->default_value(0.001f),"Termination threshold");
+ add_options(all);
+
+ po::variables_map& vm = all.vm;
bfgs& b = calloc_or_die<bfgs>();
b.all = &all;
+ b.m = vm["mem"].as<uint32_t>();
+ b.rel_threshold = vm["termination"].as<float>();
b.wolfe1_bound = 0.01;
b.first_hessian_on=true;
b.first_pass = true;
@@ -979,16 +995,6 @@ base_learner* setup(vw& all, po::variables_map& vm)
b.no_win_counter = 0;
b.early_stop_thres = 3;
- po::options_description bfgs_opts("LBFGS options");
-
- bfgs_opts.add_options()
- ("hessian_on", "use second derivative in line search")
- ("mem", po::value<int>(&(all.m)), "memory in bfgs")
- ("conjugate_gradient", "use conjugate gradient based optimization")
- ("termination", po::value<float>(&(all.rel_threshold)),"Termination threshold");
-
- vm = add_options(all, bfgs_opts);
-
if(!all.holdout_set_off)
{
all.sd->holdout_best_loss = FLT_MAX;
@@ -996,11 +1002,11 @@ base_learner* setup(vw& all, po::variables_map& vm)
b.early_stop_thres = vm["early_terminate"].as< size_t>();
}
- if (vm.count("hessian_on") || all.m==0) {
+ if (vm.count("hessian_on") || b.m==0) {
all.hessian_on = true;
}
if (!all.quiet) {
- if (all.m>0)
+ if (b.m>0)
cerr << "enabling BFGS based optimization ";
else
cerr << "enabling conjugate gradient optimization via BFGS ";
diff --git a/vowpalwabbit/bfgs.h b/vowpalwabbit/bfgs.h
index 1960662b..e66a0df0 100644
--- a/vowpalwabbit/bfgs.h
+++ b/vowpalwabbit/bfgs.h
@@ -4,6 +4,4 @@ individual contributors. All rights reserved. Released under a BSD
license as described in the file LICENSE.
*/
#pragma once
-namespace BFGS {
- LEARNER::base_learner* setup(vw& all, po::variables_map& vm);
-}
+namespace BFGS { LEARNER::base_learner* setup(vw& all); }
diff --git a/vowpalwabbit/binary.cc b/vowpalwabbit/binary.cc
index 04810e26..d9afc63b 100644
--- a/vowpalwabbit/binary.cc
+++ b/vowpalwabbit/binary.cc
@@ -28,10 +28,16 @@ namespace BINARY {
}
}
- LEARNER::base_learner* setup(vw& all, po::variables_map& vm)
- {
+LEARNER::base_learner* setup(vw& all)
+ {//parse and set arguments
+ new_options(all,"Binary options")
+ ("binary", "report loss as binary classification on -1,1");
+ if(missing_required(all)) return NULL;
+
+ //Create new learner
LEARNER::learner<char>& ret =
- LEARNER::init_learner<char>(NULL, all.l, predict_or_learn<true>, predict_or_learn<false>);
+ LEARNER::init_learner<char>(NULL, setup_base(all),
+ predict_or_learn<true>, predict_or_learn<false>);
return make_base(ret);
}
}
diff --git a/vowpalwabbit/binary.h b/vowpalwabbit/binary.h
index 609de90b..7b79946d 100644
--- a/vowpalwabbit/binary.h
+++ b/vowpalwabbit/binary.h
@@ -1,4 +1,2 @@
#pragma once
-namespace BINARY {
- LEARNER::base_learner* setup(vw& all, po::variables_map& vm);
-}
+namespace BINARY { LEARNER::base_learner* setup(vw& all); }
diff --git a/vowpalwabbit/bs.cc b/vowpalwabbit/bs.cc
index 78828632..f66403dd 100644
--- a/vowpalwabbit/bs.cc
+++ b/vowpalwabbit/bs.cc
@@ -239,28 +239,27 @@ namespace BS {
d.pred_vec.~vector();
}
- base_learner* setup(vw& all, po::variables_map& vm)
+ base_learner* setup(vw& all)
{
+ new_options(all, "Bootstrap options")
+ ("bootstrap,B", po::value<size_t>(), "bootstrap mode with k rounds by online importance resampling");
+ if (missing_required(all)) return NULL;
+ new_options(all)("bs_type", po::value<string>(), "prediction type {mean,vote}");
+ add_options(all);
+
bs& data = calloc_or_die<bs>();
data.ub = FLT_MAX;
data.lb = -FLT_MAX;
-
- po::options_description bs_options("Bootstrap options");
- bs_options.add_options()
- ("bs_type", po::value<string>(), "prediction type {mean,vote}");
-
- vm = add_options(all, bs_options);
-
- data.B = (uint32_t)vm["bootstrap"].as<size_t>();
+ data.B = (uint32_t)all.vm["bootstrap"].as<size_t>();
//append bs with number of samples to options_from_file so it is saved to regressor later
*all.file_options << " --bootstrap " << data.B;
std::string type_string("mean");
- if (vm.count("bs_type"))
+ if (all.vm.count("bs_type"))
{
- type_string = vm["bs_type"].as<std::string>();
+ type_string = all.vm["bs_type"].as<std::string>();
if (type_string.compare("mean") == 0) {
data.bs_type = BS_TYPE_MEAN;
@@ -280,7 +279,7 @@ namespace BS {
data.pred_vec.reserve(data.B);
data.all = &all;
- learner<bs>& l = init_learner(&data, all.l, predict_or_learn<true>,
+ learner<bs>& l = init_learner(&data, setup_base(all), predict_or_learn<true>,
predict_or_learn<false>, data.B);
l.set_finish_example(finish_example);
l.set_finish(finish);
diff --git a/vowpalwabbit/bs.h b/vowpalwabbit/bs.h
index e05bcd05..062c195a 100644
--- a/vowpalwabbit/bs.h
+++ b/vowpalwabbit/bs.h
@@ -11,7 +11,7 @@ license as described in the file LICENSE.
namespace BS
{
- LEARNER::base_learner* setup(vw& all, po::variables_map& vm);
+ LEARNER::base_learner* setup(vw& all);
void print_result(int f, float res, float weight, v_array<char> tag, float lb, float ub);
void output_example(vw& all, example* ec, float lb, float ub);
diff --git a/vowpalwabbit/cb_algs.cc b/vowpalwabbit/cb_algs.cc
index bf9a6e5b..b14328d3 100644
--- a/vowpalwabbit/cb_algs.cc
+++ b/vowpalwabbit/cb_algs.cc
@@ -436,36 +436,35 @@ namespace CB_ALGS
VW::finish_example(all, &ec);
}
- base_learner* setup(vw& all, po::variables_map& vm)
+ base_learner* setup(vw& all)
{
+ new_options(all, "CB options")
+ ("cb", po::value<size_t>(), "Use contextual bandit learning with <k> costs");
+ if (missing_required(all)) return NULL;
+ new_options(all)
+ ("cb_type", po::value<string>(), "contextual bandit method to use in {ips,dm,dr}")
+ ("eval", "Evaluate a policy rather than optimizing.");
+ add_options(all);
+
cb& c = calloc_or_die<cb>();
c.all = &all;
- uint32_t nb_actions = (uint32_t)vm["cb"].as<size_t>();
- //append cb with nb_actions to file_options so it is saved to regressor later
-
- po::options_description cb_opts("CB options");
- cb_opts.add_options()
- ("cb_type", po::value<string>(), "contextual bandit method to use in {ips,dm,dr}")
- ("eval", "Evaluate a policy rather than optimizing.")
- ;
-
- vm = add_options(all, cb_opts);
+ uint32_t nb_actions = (uint32_t)all.vm["cb"].as<size_t>();
*all.file_options << " --cb " << nb_actions;
all.sd->k = nb_actions;
bool eval = false;
- if (vm.count("eval"))
+ if (all.vm.count("eval"))
eval = true;
size_t problem_multiplier = 2;//default for DR
- if (vm.count("cb_type"))
+ if (all.vm.count("cb_type"))
{
std::string type_string;
- type_string = vm["cb_type"].as<std::string>();
+ type_string = all.vm["cb_type"].as<std::string>();
*all.file_options << " --cb_type " << type_string;
if (type_string.compare("dr") == 0)
@@ -496,6 +495,15 @@ namespace CB_ALGS
*all.file_options << " --cb_type dr";
}
+ if (count(all.args.begin(), all.args.end(),"--csoaa") == 0)
+ {
+ all.args.push_back("--csoaa");
+ stringstream ss;
+ ss << all.vm["cb"].as<size_t>();
+ all.args.push_back(ss.str());
+ }
+
+ base_learner* base = setup_base(all);
if (eval)
all.p->lp = CB_EVAL::cb_eval;
else
@@ -504,18 +512,18 @@ namespace CB_ALGS
learner<cb>* l;
if (eval)
{
- l = &init_learner(&c, all.l, learn_eval, predict_eval, problem_multiplier);
+ l = &init_learner(&c, base, learn_eval, predict_eval, problem_multiplier);
l->set_finish_example(eval_finish_example);
}
else
{
- l = &init_learner(&c, all.l, predict_or_learn<true>, predict_or_learn<false>,
+ l = &init_learner(&c, base, predict_or_learn<true>, predict_or_learn<false>,
problem_multiplier);
l->set_finish_example(finish_example);
}
// preserve the increment of the base learner since we are
// _adding_ to the number of problems rather than multiplying.
- l->increment = all.l->increment;
+ l->increment = base->increment;
l->set_init_driver(init_driver);
l->set_finish(finish);
diff --git a/vowpalwabbit/cb_algs.h b/vowpalwabbit/cb_algs.h
index e989b5a1..0593fb6c 100644
--- a/vowpalwabbit/cb_algs.h
+++ b/vowpalwabbit/cb_algs.h
@@ -6,8 +6,7 @@ license as described in the file LICENSE.
#pragma once
//TODO: extend to handle CSOAA_LDF and WAP_LDF
namespace CB_ALGS {
-
- LEARNER::base_learner* setup(vw& all, po::variables_map& vm);
+ LEARNER::base_learner* setup(vw& all);
template <bool is_learn>
float get_cost_pred(vw& all, CB::cb_class* known_cost, example& ec, uint32_t index, uint32_t base)
diff --git a/vowpalwabbit/cbify.cc b/vowpalwabbit/cbify.cc
index d8176228..d72011a2 100644
--- a/vowpalwabbit/cbify.cc
+++ b/vowpalwabbit/cbify.cc
@@ -371,24 +371,35 @@ namespace CBIFY {
void finish(cbify& data)
{ CB::cb_label.delete_label(&data.cb_label); }
- base_learner* setup(vw& all, po::variables_map& vm)
+ base_learner* setup(vw& all)
{//parse and set arguments
- cbify& data = calloc_or_die<cbify>();
-
- data.all = &all;
- po::options_description cb_opts("CBIFY options");
- cb_opts.add_options()
+ new_options(all, "CBIFY options")
+ ("cbify", po::value<size_t>(), "Convert multiclass on <k> classes into a contextual bandit problem and solve");
+ if (missing_required(all)) return NULL;
+ new_options(all)
("first", po::value<size_t>(), "tau-first exploration")
("epsilon",po::value<float>() ,"epsilon-greedy exploration")
("bag",po::value<size_t>() ,"bagging-based exploration")
("cover",po::value<size_t>() ,"bagging-based exploration");
-
- vm = add_options(all, cb_opts);
-
+ add_options(all);
+
+ po::variables_map& vm = all.vm;
+ cbify& data = calloc_or_die<cbify>();
+ data.all = &all;
data.k = (uint32_t)vm["cbify"].as<size_t>();
*all.file_options << " --cbify " << data.k;
+ if (count(all.args.begin(), all.args.end(),"--cb") == 0)
+ {
+ all.args.push_back("--cb");
+ stringstream ss;
+ ss << vm["cbify"].as<size_t>();
+ all.args.push_back(ss.str());
+ }
+ base_learner* base = setup_base(all);
+
all.p->lp = MULTICLASS::mc_label;
+
learner<cbify>* l;
data.recorder.reset(new vw_recorder());
data.mwt_explorer.reset(new MwtExplorer<vw_context>("vw", *data.recorder.get()));
@@ -403,7 +414,7 @@ namespace CBIFY {
epsilon = vm["epsilon"].as<float>();
data.scorer.reset(new vw_cover_scorer(epsilon, cover, (u32)data.k));
data.generic_explorer.reset(new GenericExplorer<vw_context>(*data.scorer.get(), (u32)data.k));
- l = &init_learner(&data, all.l, predict_or_learn_cover<true>,
+ l = &init_learner(&data, base, predict_or_learn_cover<true>,
predict_or_learn_cover<false>, cover + 1);
}
else if (vm.count("bag"))
@@ -414,7 +425,7 @@ namespace CBIFY {
data.policies.push_back(unique_ptr<IPolicy<vw_context>>(new vw_policy(i)));
}
data.bootstrap_explorer.reset(new BootstrapExplorer<vw_context>(data.policies, (u32)data.k));
- l = &init_learner(&data, all.l, predict_or_learn_bag<true>,
+ l = &init_learner(&data, base, predict_or_learn_bag<true>,
predict_or_learn_bag<false>, bags);
}
else if (vm.count("first") )
@@ -422,7 +433,7 @@ namespace CBIFY {
uint32_t tau = (uint32_t)vm["first"].as<size_t>();
data.policy.reset(new vw_policy());
data.tau_explorer.reset(new TauFirstExplorer<vw_context>(*data.policy.get(), (u32)tau, (u32)data.k));
- l = &init_learner(&data, all.l, predict_or_learn_first<true>,
+ l = &init_learner(&data, base, predict_or_learn_first<true>,
predict_or_learn_first<false>, 1);
}
else
@@ -432,7 +443,7 @@ namespace CBIFY {
epsilon = vm["epsilon"].as<float>();
data.policy.reset(new vw_policy());
data.greedy_explorer.reset(new EpsilonGreedyExplorer<vw_context>(*data.policy.get(), epsilon, (u32)data.k));
- l = &init_learner(&data, all.l, predict_or_learn_greedy<true>,
+ l = &init_learner(&data, base, predict_or_learn_greedy<true>,
predict_or_learn_greedy<false>, 1);
}
diff --git a/vowpalwabbit/cbify.h b/vowpalwabbit/cbify.h
index aed26b75..2ea0a627 100644
--- a/vowpalwabbit/cbify.h
+++ b/vowpalwabbit/cbify.h
@@ -4,6 +4,4 @@ individual contributors. All rights reserved. Released under a BSD
license as described in the file LICENSE.
*/
#pragma once
-namespace CBIFY {
- LEARNER::base_learner* setup(vw& all, po::variables_map& vm);
-}
+namespace CBIFY { LEARNER::base_learner* setup(vw& all); }
diff --git a/vowpalwabbit/csoaa.cc b/vowpalwabbit/csoaa.cc
index 350d40a3..908369b7 100644
--- a/vowpalwabbit/csoaa.cc
+++ b/vowpalwabbit/csoaa.cc
@@ -13,9 +13,7 @@ license as described in the file LICENSE.
#include "gd.h" // GD::foreach_feature() needed in subtract_example()
using namespace std;
-
using namespace LEARNER;
-
using namespace COST_SENSITIVE;
namespace CSOAA {
@@ -68,21 +66,25 @@ namespace CSOAA {
VW::finish_example(all, &ec);
}
- base_learner* setup(vw& all, po::variables_map& vm)
+ base_learner* setup(vw& all)
{
+ new_options(all, "CSOAA options")
+ ("csoaa", po::value<size_t>(), "Use one-against-all multiclass learning with <k> costs");
+ if(missing_required(all)) return NULL;
+
csoaa& c = calloc_or_die<csoaa>();
c.all = &all;
//first parse for number of actions
uint32_t nb_actions = 0;
- nb_actions = (uint32_t)vm["csoaa"].as<size_t>();
+ nb_actions = (uint32_t)all.vm["csoaa"].as<size_t>();
//append csoaa with nb_actions to file_options so it is saved to regressor later
*all.file_options << " --csoaa " << nb_actions;
all.p->lp = cs_label;
all.sd->k = nb_actions;
- learner<csoaa>& l = init_learner(&c, all.l, predict_or_learn<true>,
+ learner<csoaa>& l = init_learner(&c, setup_base(all), predict_or_learn<true>,
predict_or_learn<false>, nb_actions);
l.set_finish_example(finish_example);
return make_base(l);
@@ -645,15 +647,17 @@ namespace LabelDict {
}
}
- base_learner* setup(vw& all, po::variables_map& vm)
+ base_learner* setup(vw& all)
{
- po::options_description ldf_opts("LDF Options");
- ldf_opts.add_options()
- ("ldf_override", po::value<string>(), "Override singleline or multiline from csoaa_ldf or wap_ldf, eg if stored in file")
- ;
-
- vm = add_options(all, ldf_opts);
-
+ new_options(all, "LDF Options")
+ ("csoaa_ldf", po::value<string>(), "Use one-against-all multiclass learning with label dependent features. Specify singleline or multiline.")
+ ("wap_ldf", po::value<string>(), "Use weighted all-pairs multiclass learning with label dependent features. Specify singleline or multiline.");
+ if (missing_required(all)) return NULL;
+ new_options(all)
+ ("ldf_override", po::value<string>(), "Override singleline or multiline from csoaa_ldf or wap_ldf, eg if stored in file");
+ add_options(all);
+
+ po::variables_map& vm = all.vm;
ldf& ld = calloc_or_die<ldf>();
ld.all = &all;
@@ -708,7 +712,7 @@ namespace LabelDict {
ld.read_example_this_loop = 0;
ld.need_to_clear = false;
- learner<ldf>& l = init_learner(&ld, all.l, predict_or_learn<true>, predict_or_learn<false>);
+ learner<ldf>& l = init_learner(&ld, setup_base(all), predict_or_learn<true>, predict_or_learn<false>);
if (ld.is_singleline)
l.set_finish_example(finish_singleline_example);
else
diff --git a/vowpalwabbit/csoaa.h b/vowpalwabbit/csoaa.h
index 79f5e4d2..45e4edf2 100644
--- a/vowpalwabbit/csoaa.h
+++ b/vowpalwabbit/csoaa.h
@@ -4,17 +4,14 @@ individual contributors. All rights reserved. Released under a BSD
license as described in the file LICENSE.
*/
#pragma once
-namespace CSOAA {
- LEARNER::base_learner* setup(vw& all, po::variables_map& vm);
-}
+namespace CSOAA { LEARNER::base_learner* setup(vw& all); }
namespace CSOAA_AND_WAP_LDF {
- LEARNER::base_learner* setup(vw& all, po::variables_map& vm);
-
-namespace LabelDict {
- bool ec_is_example_header(example& ec); // example headers look like "0:-1" or just "shared"
- void add_example_namespaces_from_example(example& target, example& source);
- void del_example_namespaces_from_example(example& target, example& source);
-}
+ LEARNER::base_learner* setup(vw& all);
+ namespace LabelDict {
+ bool ec_is_example_header(example& ec);// example headers look like "0:-1" or just "shared"
+ void add_example_namespaces_from_example(example& target, example& source);
+ void del_example_namespaces_from_example(example& target, example& source);
+ }
}
diff --git a/vowpalwabbit/ect.cc b/vowpalwabbit/ect.cc
index dea87040..1a071eb2 100644
--- a/vowpalwabbit/ect.cc
+++ b/vowpalwabbit/ect.cc
@@ -17,7 +17,6 @@ license as described in the file LICENSE.
#include "reductions.h"
#include "multiclass.h"
#include "simple_label.h"
-#include "vw.h"
using namespace std;
using namespace LEARNER;
@@ -51,8 +50,6 @@ namespace ECT
uint32_t last_pair;
v_array<bool> tournaments_won;
-
- vw* all;
};
bool exists(v_array<size_t> db)
@@ -184,7 +181,7 @@ namespace ECT
return e.last_pair + (eliminations-1);
}
- uint32_t ect_predict(vw& all, ect& e, base_learner& base, example& ec)
+ uint32_t ect_predict(ect& e, base_learner& base, example& ec)
{
if (e.k == (size_t)1)
return 1;
@@ -228,7 +225,7 @@ namespace ECT
return false;
}
- void ect_train(vw& all, ect& e, base_learner& base, example& ec)
+ void ect_train(ect& e, base_learner& base, example& ec)
{
if (e.k == 1)//nothing to do
return;
@@ -318,25 +315,21 @@ namespace ECT
}
void predict(ect& e, base_learner& base, example& ec) {
- vw* all = e.all;
-
MULTICLASS::multiclass mc = ec.l.multi;
if (mc.label == 0 || (mc.label > e.k && mc.label != (uint32_t)-1))
cout << "label " << mc.label << " is not in {1,"<< e.k << "} This won't work right." << endl;
- ec.pred.multiclass = ect_predict(*all, e, base, ec);
+ ec.pred.multiclass = ect_predict(e, base, ec);
ec.l.multi = mc;
}
void learn(ect& e, base_learner& base, example& ec)
{
- vw* all = e.all;
-
MULTICLASS::multiclass mc = ec.l.multi;
predict(e, base, ec);
uint32_t pred = ec.pred.multiclass;
- if (mc.label != (uint32_t)-1 && all->training)
- ect_train(*all, e, base, ec);
+ if (mc.label != (uint32_t)-1)
+ ect_train(e, base, ec);
ec.l.multi = mc;
ec.pred.multiclass = pred;
}
@@ -360,36 +353,26 @@ namespace ECT
e.tournaments_won.delete_v();
}
- void finish_example(vw& all, ect&, example& ec) { MULTICLASS::finish_example(all, ec); }
-
- base_learner* setup(vw& all, po::variables_map& vm)
+ base_learner* setup(vw& all)
{
- ect& data = calloc_or_die<ect>();
- po::options_description ect_opts("ECT options");
- ect_opts.add_options()
- ("error", po::value<size_t>(), "error in ECT");
+ new_options(all, "Error Correcting Tournament options")
+ ("ect", po::value<size_t>(), "Use error correcting tournament with <k> labels");
+ if (missing_required(all)) return NULL;
+ new_options(all)("error", po::value<size_t>()->default_value(0), "error in ECT");
+ add_options(all);
- vm = add_options(all, ect_opts);
-
- //first parse for number of actions
- data.k = (int)vm["ect"].as<size_t>();
-
- //append ect with nb_actions to options_from_file so it is saved to regressor later
- if (vm.count("error")) {
- data.errors = (uint32_t)vm["error"].as<size_t>();
- } else
- data.errors = 0;
+ ect& data = calloc_or_die<ect>();
+ data.k = (int)all.vm["ect"].as<size_t>();
+ data.errors = (uint32_t)all.vm["error"].as<size_t>();
//append error flag to options_from_file so it is saved in regressor file later
*all.file_options << " --ect " << data.k << " --error " << data.errors;
- all.p->lp = MULTICLASS::mc_label;
size_t wpp = create_circuit(all, data, data.k, data.errors+1);
- data.all = &all;
- learner<ect>& l = init_learner(&data, all.l, learn, predict, wpp);
- l.set_finish_example(finish_example);
+ learner<ect>& l = init_learner(&data, setup_base(all), learn, predict, wpp);
+ l.set_finish_example(MULTICLASS::finish_example<ect>);
+ all.p->lp = MULTICLASS::mc_label;
l.set_finish(finish);
-
return make_base(l);
}
}
diff --git a/vowpalwabbit/ect.h b/vowpalwabbit/ect.h
index 81129791..c02d3848 100644
--- a/vowpalwabbit/ect.h
+++ b/vowpalwabbit/ect.h
@@ -4,7 +4,4 @@ individual contributors. All rights reserved. Released under a BSD
license as described in the file LICENSE.
*/
#pragma once
-namespace ECT
-{
- LEARNER::base_learner* setup(vw&, po::variables_map&);
-}
+namespace ECT { LEARNER::base_learner* setup(vw&); }
diff --git a/vowpalwabbit/ftrl_proximal.cc b/vowpalwabbit/ftrl_proximal.cc
index 6216058b..5eecfc26 100644
--- a/vowpalwabbit/ftrl_proximal.cc
+++ b/vowpalwabbit/ftrl_proximal.cc
@@ -55,8 +55,8 @@ namespace FTRL {
};
void update_accumulated_state(weight* w, float ftrl_alpha) {
- double ng2 = w[W_G2] + w[W_GT]*w[W_GT];
- double sigma = (sqrt(ng2) - sqrt(w[W_G2]))/ ftrl_alpha;
+ float ng2 = w[W_G2] + w[W_GT]*w[W_GT];
+ float sigma = (sqrtf(ng2) - sqrtf(w[W_G2]))/ ftrl_alpha;
w[W_ZT] += w[W_GT] - sigma * w[W_XT];
w[W_G2] = ng2;
}
@@ -107,7 +107,7 @@ namespace FTRL {
if (fabs_zt <= d.l1_lambda) {
w[W_XT] = 0.;
} else {
- double step = 1/(d.l2_lambda + (d.ftrl_beta + sqrt(w[W_G2]))/d.ftrl_alpha);
+ float step = 1/(d.l2_lambda + (d.ftrl_beta + sqrtf(w[W_G2]))/d.ftrl_alpha);
w[W_XT] = step * flag * (d.l1_lambda - fabs_zt);
}
}
@@ -129,7 +129,7 @@ namespace FTRL {
label_data& ld = ec.l.simple;
ec.loss = all.loss->getLoss(all.sd, ec.updated_prediction, ld.label) * ld.weight;
if (b.progressive_validation) {
- float v = 1./(1 + exp(-ec.updated_prediction));
+ float v = 1.f/(1 + exp(-ec.updated_prediction));
fprintf(b.fo, "%.6f\t%d\n", v, (int)(ld.label * ld.weight));
}
}
@@ -177,35 +177,27 @@ namespace FTRL {
ec.pred.scalar = ftrl_predict(*all,ec);
}
- base_learner* setup(vw& all, po::variables_map& vm)
+ base_learner* setup(vw& all)
{
+ new_options(all, "FTRL options")
+ ("ftrl", "Follow the Regularized Leader");
+ if (missing_required(all)) return NULL;
+ new_options(all)
+ ("ftrl_alpha", po::value<float>()->default_value(0.0), "Learning rate for FTRL-proximal optimization")
+ ("ftrl_beta", po::value<float>()->default_value(0.1f), "FTRL beta")
+ ("progressive_validation", po::value<string>()->default_value("ftrl.evl"), "File to record progressive validation for ftrl-proximal");
+ add_options(all);
+
ftrl& b = calloc_or_die<ftrl>();
b.all = &all;
- b.ftrl_beta = 0.0;
- b.ftrl_alpha = 0.1;
-
- po::options_description ftrl_opts("FTRL options");
-
- ftrl_opts.add_options()
- ("ftrl_alpha", po::value<float>(&(b.ftrl_alpha)), "Learning rate for FTRL-proximal optimization")
- ("ftrl_beta", po::value<float>(&(b.ftrl_beta)), "FTRL beta")
- ("progressive_validation", po::value<string>()->default_value("ftrl.evl"), "File to record progressive validation for ftrl-proximal");
-
- vm = add_options(all, ftrl_opts);
-
- if (vm.count("ftrl_alpha")) {
- b.ftrl_alpha = vm["ftrl_alpha"].as<float>();
- }
-
- if (vm.count("ftrl_beta")) {
- b.ftrl_beta = vm["ftrl_beta"].as<float>();
- }
+ b.ftrl_beta = all.vm["ftrl_beta"].as<float>();
+ b.ftrl_alpha = all.vm["ftrl_alpha"].as<float>();
all.reg.stride_shift = 2; // NOTE: for more parameter storage
b.progressive_validation = false;
- if (vm.count("progressive_validation")) {
- std::string filename = vm["progressive_validation"].as<string>();
+ if (all.vm.count("progressive_validation")) {
+ std::string filename = all.vm["progressive_validation"].as<string>();
b.fo = fopen(filename.c_str(), "w");
assert(b.fo != NULL);
b.progressive_validation = true;
diff --git a/vowpalwabbit/ftrl_proximal.h b/vowpalwabbit/ftrl_proximal.h
index 59bf4653..dd495d3e 100644
--- a/vowpalwabbit/ftrl_proximal.h
+++ b/vowpalwabbit/ftrl_proximal.h
@@ -3,10 +3,5 @@ Copyright (c) by respective owners including Yahoo!, Microsoft, and
individual contributors. All rights reserved. Released under a BSD
license as described in the file LICENSE.
*/
-#ifndef FTRL_PROXIMAL_H
-#define FTRL_PROXIMAL_H
-
-namespace FTRL {
- LEARNER::base_learner* setup(vw& all, po::variables_map& vm);
-}
-#endif
+#pragma once
+namespace FTRL { LEARNER::base_learner* setup(vw& all); }
diff --git a/vowpalwabbit/gd.cc b/vowpalwabbit/gd.cc
index 851fbea7..8460bab8 100644
--- a/vowpalwabbit/gd.cc
+++ b/vowpalwabbit/gd.cc
@@ -844,8 +844,16 @@ uint32_t ceil_log_2(uint32_t v)
return 1 + ceil_log_2(v >> 1);
}
-base_learner* setup(vw& all, po::variables_map& vm)
+base_learner* setup(vw& all)
{
+ new_options(all, "Gradient Descent options")
+ ("sgd", "use regular stochastic gradient descent update.")
+ ("adaptive", "use adaptive, individual learning rates.")
+ ("invariant", "use safe/importance aware updates.")
+ ("normalized", "use per feature normalized updates")
+ ("exact_adaptive_norm", "use current default invariant normalized adaptive update rule");
+ add_options(all);
+ po::variables_map& vm = all.vm;
gd& g = calloc_or_die<gd>();
g.all = &all;
g.all->normalized_sum_norm_x = 0;
diff --git a/vowpalwabbit/gd.h b/vowpalwabbit/gd.h
index de3964eb..a5413eae 100644
--- a/vowpalwabbit/gd.h
+++ b/vowpalwabbit/gd.h
@@ -24,7 +24,7 @@ namespace GD{
void compute_update(example* ec);
void offset_train(regressor &reg, example* &ec, float update, size_t offset);
void train_one_example_single_thread(regressor& r, example* ex);
- LEARNER::base_learner* setup(vw& all, po::variables_map& vm);
+ LEARNER::base_learner* setup(vw& all);
void save_load_regressor(vw& all, io_buf& model_file, bool read, bool text);
void save_load_online_state(vw& all, io_buf& model_file, bool read, bool text);
void output_and_account_example(example* ec);
diff --git a/vowpalwabbit/gd_mf.cc b/vowpalwabbit/gd_mf.cc
index b57328f4..8c51b0ce 100644
--- a/vowpalwabbit/gd_mf.cc
+++ b/vowpalwabbit/gd_mf.cc
@@ -26,10 +26,12 @@ using namespace LEARNER;
namespace GDMF {
struct gdmf {
vw* all;
+ uint32_t rank;
};
-void mf_print_offset_features(vw& all, example& ec, size_t offset)
+void mf_print_offset_features(gdmf& d, example& ec, size_t offset)
{
+ vw& all = *d.all;
weight* weights = all.reg.weight_vector;
size_t mask = all.reg.weight_mask;
for (unsigned char* i = ec.indices.begin; i != ec.indices.end; i++)
@@ -53,7 +55,7 @@ void mf_print_offset_features(vw& all, example& ec, size_t offset)
if (ec.atomics[(int)(*i)[0]].size() > 0 && ec.atomics[(int)(*i)[1]].size() > 0)
{
/* print out nsk^feature:hash:value:weight:nsk^feature^:hash:value:weight:prod_weights */
- for (size_t k = 1; k <= all.rank; k++)
+ for (size_t k = 1; k <= d.rank; k++)
{
for (audit_data* f = ec.audit_features[(int)(*i)[0]].begin; f!= ec.audit_features[(int)(*i)[0]].end; f++)
for (audit_data* f2 = ec.audit_features[(int)(*i)[1]].begin; f2!= ec.audit_features[(int)(*i)[1]].end; f2++)
@@ -62,11 +64,11 @@ void mf_print_offset_features(vw& all, example& ec, size_t offset)
<<"(" << ((f->weight_index + offset +k) & mask) << ")" << ':' << f->x;
cout << ':' << weights[(f->weight_index + offset + k) & mask];
- cout << ':' << f2->space << k << '^' << f2->feature << ':' << ((f2->weight_index+k+all.rank)&mask)
- <<"(" << ((f2->weight_index + offset +k+all.rank) & mask) << ")" << ':' << f2->x;
- cout << ':' << weights[(f2->weight_index + offset + k+all.rank) & mask];
+ cout << ':' << f2->space << k << '^' << f2->feature << ':' << ((f2->weight_index+k+d.rank)&mask)
+ <<"(" << ((f2->weight_index + offset +k+d.rank) & mask) << ")" << ':' << f2->x;
+ cout << ':' << weights[(f2->weight_index + offset + k+d.rank) & mask];
- cout << ':' << weights[(f->weight_index + offset + k) & mask] * weights[(f2->weight_index + offset + k + all.rank) & mask];
+ cout << ':' << weights[(f->weight_index + offset + k) & mask] * weights[(f2->weight_index + offset + k + d.rank) & mask];
}
}
}
@@ -77,17 +79,25 @@ void mf_print_offset_features(vw& all, example& ec, size_t offset)
cout << endl;
}
-void mf_print_audit_features(vw& all, example& ec, size_t offset)
+void mf_print_audit_features(gdmf& d, example& ec, size_t offset)
{
- print_result(all.stdout_fileno,ec.pred.scalar,-1,ec.tag);
- mf_print_offset_features(all, ec, offset);
+ print_result(d.all->stdout_fileno,ec.pred.scalar,-1,ec.tag);
+ mf_print_offset_features(d, ec, offset);
}
-float mf_predict(vw& all, example& ec)
+float mf_predict(gdmf& d, example& ec)
{
+ vw& all = *d.all;
label_data& ld = ec.l.simple;
float prediction = ld.initial;
+ for (vector<string>::iterator i = d.all->pairs.begin(); i != d.all->pairs.end();i++)
+ {
+ ec.num_features -= ec.atomics[(int)(*i)[0]].size() * ec.atomics[(int)(*i)[1]].size();
+ ec.num_features += ec.atomics[(int)(*i)[0]].size() * d.rank;
+ ec.num_features += ec.atomics[(int)(*i)[1]].size() * d.rank;
+ }
+
// clear stored predictions
ec.topic_predictions.erase();
@@ -107,18 +117,18 @@ float mf_predict(vw& all, example& ec)
{
if (ec.atomics[(int)(*i)[0]].size() > 0 && ec.atomics[(int)(*i)[1]].size() > 0)
{
- for (uint32_t k = 1; k <= all.rank; k++)
+ for (uint32_t k = 1; k <= d.rank; k++)
{
// x_l * l^k
- // l^k is from index+1 to index+all.rank
+ // l^k is from index+1 to index+d.rank
//float x_dot_l = sd_offset_add(weights, mask, ec.atomics[(int)(*i)[0]].begin, ec.atomics[(int)(*i)[0]].end, k);
float x_dot_l = 0.;
GD::foreach_feature<float, GD::vec_add>(all.reg.weight_vector, all.reg.weight_mask, ec.atomics[(int)(*i)[0]].begin, ec.atomics[(int)(*i)[0]].end, x_dot_l, k);
// x_r * r^k
- // r^k is from index+all.rank+1 to index+2*all.rank
- //float x_dot_r = sd_offset_add(weights, mask, ec.atomics[(int)(*i)[1]].begin, ec.atomics[(int)(*i)[1]].end, k+all.rank);
+ // r^k is from index+d.rank+1 to index+2*d.rank
+ //float x_dot_r = sd_offset_add(weights, mask, ec.atomics[(int)(*i)[1]].begin, ec.atomics[(int)(*i)[1]].end, k+d.rank);
float x_dot_r = 0.;
- GD::foreach_feature<float,GD::vec_add>(all.reg.weight_vector, all.reg.weight_mask, ec.atomics[(int)(*i)[1]].begin, ec.atomics[(int)(*i)[1]].end, x_dot_r, k+all.rank);
+ GD::foreach_feature<float,GD::vec_add>(all.reg.weight_vector, all.reg.weight_mask, ec.atomics[(int)(*i)[1]].begin, ec.atomics[(int)(*i)[1]].end, x_dot_r, k+d.rank);
prediction += x_dot_l * x_dot_r;
@@ -146,7 +156,7 @@ float mf_predict(vw& all, example& ec)
ec.loss = all.loss->getLoss(all.sd, ec.pred.scalar, ld.label) * ld.weight;
if (all.audit)
- mf_print_audit_features(all, ec, 0);
+ mf_print_audit_features(d, ec, 0);
return ec.pred.scalar;
}
@@ -158,55 +168,55 @@ void sd_offset_update(weight* weights, size_t mask, feature* begin, feature* end
weights[(f->weight_index + offset) & mask] += update * f->x - regularization * weights[(f->weight_index + offset) & mask];
}
-void mf_train(vw& all, example& ec)
-{
- weight* weights = all.reg.weight_vector;
- size_t mask = all.reg.weight_mask;
- label_data& ld = ec.l.simple;
-
- // use final prediction to get update size
+ void mf_train(gdmf& d, example& ec)
+ {
+ vw& all = *d.all;
+ weight* weights = all.reg.weight_vector;
+ size_t mask = all.reg.weight_mask;
+ label_data& ld = ec.l.simple;
+
+ // use final prediction to get update size
// update = eta_t*(y-y_hat) where eta_t = eta/(3*t^p) * importance weight
- float eta_t = all.eta/pow(ec.example_t,all.power_t) / 3.f * ld.weight;
- float update = all.loss->getUpdate(ec.pred.scalar, ld.label, eta_t, 1.); //ec.total_sum_feat_sq);
-
- float regularization = eta_t * all.l2_lambda;
-
- // linear update
- for (unsigned char* i = ec.indices.begin; i != ec.indices.end; i++)
- sd_offset_update(weights, mask, ec.atomics[*i].begin, ec.atomics[*i].end, 0, update, regularization);
-
- // quadratic update
- for (vector<string>::iterator i = all.pairs.begin(); i != all.pairs.end();i++)
- {
- if (ec.atomics[(int)(*i)[0]].size() > 0 && ec.atomics[(int)(*i)[1]].size() > 0)
- {
-
- // update l^k weights
- for (size_t k = 1; k <= all.rank; k++)
- {
- // r^k \cdot x_r
- float r_dot_x = ec.topic_predictions[2*k];
- // l^k <- l^k + update * (r^k \cdot x_r) * x_l
- sd_offset_update(weights, mask, ec.atomics[(int)(*i)[0]].begin, ec.atomics[(int)(*i)[0]].end, k, update*r_dot_x, regularization);
- }
-
- // update r^k weights
- for (size_t k = 1; k <= all.rank; k++)
- {
- // l^k \cdot x_l
- float l_dot_x = ec.topic_predictions[2*k-1];
- // r^k <- r^k + update * (l^k \cdot x_l) * x_r
- sd_offset_update(weights, mask, ec.atomics[(int)(*i)[1]].begin, ec.atomics[(int)(*i)[1]].end, k+all.rank, update*l_dot_x, regularization);
- }
-
- }
- }
- if (all.triples.begin() != all.triples.end()) {
- cerr << "cannot use triples in matrix factorization" << endl;
- throw exception();
- }
-}
-
+ float eta_t = all.eta/pow(ec.example_t,all.power_t) / 3.f * ld.weight;
+ float update = all.loss->getUpdate(ec.pred.scalar, ld.label, eta_t, 1.); //ec.total_sum_feat_sq);
+
+ float regularization = eta_t * all.l2_lambda;
+
+ // linear update
+ for (unsigned char* i = ec.indices.begin; i != ec.indices.end; i++)
+ sd_offset_update(weights, mask, ec.atomics[*i].begin, ec.atomics[*i].end, 0, update, regularization);
+
+ // quadratic update
+ for (vector<string>::iterator i = all.pairs.begin(); i != all.pairs.end();i++)
+ {
+ if (ec.atomics[(int)(*i)[0]].size() > 0 && ec.atomics[(int)(*i)[1]].size() > 0)
+ {
+
+ // update l^k weights
+ for (size_t k = 1; k <= d.rank; k++)
+ {
+ // r^k \cdot x_r
+ float r_dot_x = ec.topic_predictions[2*k];
+ // l^k <- l^k + update * (r^k \cdot x_r) * x_l
+ sd_offset_update(weights, mask, ec.atomics[(int)(*i)[0]].begin, ec.atomics[(int)(*i)[0]].end, k, update*r_dot_x, regularization);
+ }
+ // update r^k weights
+ for (size_t k = 1; k <= d.rank; k++)
+ {
+ // l^k \cdot x_l
+ float l_dot_x = ec.topic_predictions[2*k-1];
+ // r^k <- r^k + update * (l^k \cdot x_l) * x_r
+ sd_offset_update(weights, mask, ec.atomics[(int)(*i)[1]].begin, ec.atomics[(int)(*i)[1]].end, k+d.rank, update*l_dot_x, regularization);
+ }
+
+ }
+ }
+ if (all.triples.begin() != all.triples.end()) {
+ cerr << "cannot use triples in matrix factorization" << endl;
+ throw exception();
+ }
+ }
+
void save_load(gdmf& d, io_buf& model_file, bool read, bool text)
{
vw* all = d.all;
@@ -231,7 +241,7 @@ void mf_train(vw& all, example& ec)
do
{
brw = 0;
- size_t K = all->rank*2+1;
+ size_t K = d.rank*2+1;
text_len = sprintf(buff, "%d ", i);
brw += bin_text_read_write_fixed(model_file,(char *)&i, sizeof (i),
@@ -272,58 +282,59 @@ void mf_train(vw& all, example& ec)
all->current_pass++;
}
- void predict(gdmf& d, base_learner& base, example& ec)
- {
- vw* all = d.all;
-
- mf_predict(*all,ec);
- }
+ void predict(gdmf& d, base_learner&, example& ec) { mf_predict(d,ec); }
void learn(gdmf& d, base_learner& base, example& ec)
{
- vw* all = d.all;
+ vw& all = *d.all;
- predict(d, base, ec);
- if (all->training && ec.l.simple.label != FLT_MAX)
- mf_train(*all, ec);
+ mf_predict(d, ec);
+ if (all.training && ec.l.simple.label != FLT_MAX)
+ mf_train(d, ec);
}
- base_learner* setup(vw& all, po::variables_map& vm)
+ base_learner* setup(vw& all)
{
+ new_options(all, "Gdmf options")
+ ("rank", po::value<uint32_t>(), "rank for matrix factorization.");
+ if(missing_required(all)) return NULL;
+
gdmf& data = calloc_or_die<gdmf>();
data.all = &all;
+ data.rank = all.vm["rank"].as<uint32_t>();
+ *all.file_options << " --rank " << data.rank;
// store linear + 2*rank weights per index, round up to power of two
- float temp = ceilf(logf((float)(all.rank*2+1)) / logf (2.f));
+ float temp = ceilf(logf((float)(data.rank*2+1)) / logf (2.f));
all.reg.stride_shift = (size_t) temp;
all.random_weights = true;
- if ( vm.count("adaptive") )
+ if ( all.vm.count("adaptive") )
{
cerr << "adaptive is not implemented for matrix factorization" << endl;
throw exception();
}
- if ( vm.count("normalized") )
+ if ( all.vm.count("normalized") )
{
cerr << "normalized is not implemented for matrix factorization" << endl;
throw exception();
}
- if ( vm.count("exact_adaptive_norm") )
+ if ( all.vm.count("exact_adaptive_norm") )
{
cerr << "normalized adaptive updates is not implemented for matrix factorization" << endl;
throw exception();
}
- if (vm.count("bfgs") || vm.count("conjugate_gradient"))
+ if (all.vm.count("bfgs") || all.vm.count("conjugate_gradient"))
{
cerr << "bfgs is not implemented for matrix factorization" << endl;
throw exception();
}
- if(!vm.count("learning_rate") && !vm.count("l"))
+ if(!all.vm.count("learning_rate") && !all.vm.count("l"))
all.eta = 10; //default learning rate to 10 for non default update rule
//default initial_t to 1 instead of 0
- if(!vm.count("initial_t")) {
+ if(!all.vm.count("initial_t")) {
all.sd->t = 1.f;
all.sd->weighted_unlabeled_examples = 1.f;
all.initial_t = 1.f;
diff --git a/vowpalwabbit/gd_mf.h b/vowpalwabbit/gd_mf.h
index a0ce2f87..09623e61 100644
--- a/vowpalwabbit/gd_mf.h
+++ b/vowpalwabbit/gd_mf.h
@@ -4,12 +4,4 @@ individual contributors. All rights reserved. Released under a BSD
license as described in the file LICENSE.
*/
#pragma once
-#include <math.h>
-#include "example.h"
-#include "parse_regressor.h"
-#include "parser.h"
-#include "gd.h"
-
-namespace GDMF{
- LEARNER::base_learner* setup(vw& all, po::variables_map& vm);
-}
+namespace GDMF{ LEARNER::base_learner* setup(vw& all); }
diff --git a/vowpalwabbit/global_data.cc b/vowpalwabbit/global_data.cc
index 825977d5..cd288b52 100644
--- a/vowpalwabbit/global_data.cc
+++ b/vowpalwabbit/global_data.cc
@@ -214,23 +214,44 @@ void compile_limits(vector<string> limits, uint32_t* dest, bool quiet)
}
}
-po::variables_map add_options(vw& all, po::options_description& opts)
+void add_options(vw& all, po::options_description& opts)
{
all.opts.add(opts);
po::variables_map new_vm;
-
//parse local opts once for notifications.
po::parsed_options parsed = po::command_line_parser(all.args).
style(po::command_line_style::default_style ^ po::command_line_style::allow_guessing).
options(opts).allow_unregistered().run();
po::store(parsed, new_vm);
po::notify(new_vm);
- //parse all opts for a complete variable map.
- parsed = po::command_line_parser(all.args).
+
+ for (po::variables_map::iterator it=new_vm.begin(); it!=new_vm.end(); ++it)
+ all.vm.insert(*it);
+}
+
+void add_options(vw& all)
+{
+ add_options(all, *all.new_opts);
+ delete all.new_opts;
+}
+
+bool missing_required(vw& all)
+{
+ //parse local opts once for notifications.
+ po::parsed_options parsed = po::command_line_parser(all.args).
style(po::command_line_style::default_style ^ po::command_line_style::allow_guessing).
- options(all.opts).allow_unregistered().run();
+ options(*all.new_opts).allow_unregistered().run();
+ po::variables_map new_vm;
po::store(parsed, new_vm);
- return new_vm;
+ all.opts.add(*all.new_opts);
+ delete all.new_opts;
+ for (po::variables_map::iterator it=new_vm.begin(); it!=new_vm.end(); ++it)
+ all.vm.insert(*it);
+
+ if (new_vm.size() == 0) // required are missing;
+ return true;
+ else
+ return false;
}
vw::vw()
@@ -247,6 +268,7 @@ vw::vw()
reg_mode = 0;
current_pass = 0;
+ reduction_stack=v_init<LEARNER::base_learner* (*)(vw&)>();
data_filename = "";
@@ -260,13 +282,7 @@ vw::vw()
default_bits = true;
daemon = false;
num_children = 10;
- lda_alpha = 0.1f;
- lda_rho = 0.1f;
- lda_D = 10000.;
- lda_epsilon = 0.001f;
- minibatch = 1;
span_server = "";
- m = 15;
save_resume = false;
random_positive_weights = false;
@@ -276,8 +292,6 @@ vw::vw()
power_t = 0.5;
eta = 0.5; //default learning rate for normalized adaptive updates, this is switched to 10 by default for the other updates (see parse_args.cc)
numpasses = 1;
- rel_threshold = 0.001f;
- rank = 0;
final_prediction_sink.begin = final_prediction_sink.end=final_prediction_sink.end_array = NULL;
raw_prediction = -1;
diff --git a/vowpalwabbit/global_data.h b/vowpalwabbit/global_data.h
index bdf76b46..7e6ec74b 100644
--- a/vowpalwabbit/global_data.h
+++ b/vowpalwabbit/global_data.h
@@ -193,12 +193,13 @@ struct vw {
bool bfgs;
bool hessian_on;
- int m;
bool save_resume;
double normalized_sum_norm_x;
po::options_description opts;
+ po::options_description* new_opts;
+ po::variables_map vm;
std::stringstream* file_options;
vector<std::string> args;
@@ -217,10 +218,6 @@ struct vw {
float power_t;//the power on learning rate decay.
int reg_mode;
- size_t minibatch;
-
- float rel_threshold; // termination threshold
-
size_t pass_length;
size_t numpasses;
size_t passes_complete;
@@ -262,19 +259,14 @@ struct vw {
size_t normalized_idx; //offset idx where the norm is stored (1 or 2 depending on whether adaptive is true)
uint32_t lda;
- float lda_alpha;
- float lda_rho;
- float lda_D;
- float lda_epsilon;
std::string text_regressor_name;
std::string inv_hash_regressor_name;
-
std::string span_server;
size_t length () { return ((size_t)1) << num_bits; };
- uint32_t rank;
+ v_array<LEARNER::base_learner* (*)(vw&)> reduction_stack;
//Prediction output
v_array<int> final_prediction_sink; // set to send global predictions to.
@@ -321,4 +313,11 @@ void get_prediction(int sock, float& res, float& weight);
void compile_gram(vector<string> grams, uint32_t* dest, char* descriptor, bool quiet);
void compile_limits(vector<string> limits, uint32_t* dest, bool quiet);
int print_tag(std::stringstream& ss, v_array<char> tag);
-po::variables_map add_options(vw& all, po::options_description& opts);
+void add_options(vw& all, po::options_description& opts);
+inline po::options_description_easy_init new_options(vw& all, std::string name = "\0")
+{
+ all.new_opts = new po::options_description(name);
+ return all.new_opts->add_options();
+}
+bool missing_required(vw& all);
+void add_options(vw& all);
diff --git a/vowpalwabbit/kernel_svm.cc b/vowpalwabbit/kernel_svm.cc
index 9a88f17e..ed917cda 100644
--- a/vowpalwabbit/kernel_svm.cc
+++ b/vowpalwabbit/kernel_svm.cc
@@ -648,7 +648,7 @@ namespace KSVM
else {
for(size_t i = 0;i < params.pool_pos;i++) {
- float queryp = 2.0f/(1.0f + expf((float)(params.active_c*fabs(scores[i]))*pow(params.pool[i]->ex.example_counter,0.5f)));
+ float queryp = 2.0f/(1.0f + expf((float)(params.active_c*fabs(scores[i]))*(float)pow(params.pool[i]->ex.example_counter,0.5f)));
if(rand() < queryp) {
svm_example* fec = params.pool[i];
fec->ex.l.simple.weight *= 1/queryp;
@@ -790,13 +790,14 @@ namespace KSVM
cerr<<"Done with finish \n";
}
-
- LEARNER::base_learner* setup(vw &all, po::variables_map& vm) {
- po::options_description desc("KSVM options");
- desc.add_options()
+ LEARNER::base_learner* setup(vw &all) {
+ new_options(all, "KSVM options")
+ ("ksvm", "kernel svm");
+ if (missing_required(all)) return NULL;
+ new_options(all)
("reprocess", po::value<size_t>(), "number of reprocess steps for LASVM")
- ("active", "do active learning")
- ("active_c", po::value<double>(), "parameter for query prob")
+ // ("active", "do active learning")
+ //("active_c", po::value<double>(), "parameter for query prob")
("pool_greedy", "use greedy selection on mini pools")
("para_active", "do parallel active learning")
("pool_size", po::value<size_t>(), "size of pools for active learning")
@@ -805,8 +806,9 @@ namespace KSVM
("bandwidth", po::value<float>(), "bandwidth of rbf kernel")
("degree", po::value<int>(), "degree of poly kernel")
("lambda", po::value<double>(), "saving regularization for test time");
- vm = add_options(all, desc);
+ add_options(all);
+ po::variables_map& vm = all.vm;
string loss_function = "hinge";
float loss_parameter = 0.0;
delete all.loss;
diff --git a/vowpalwabbit/kernel_svm.h b/vowpalwabbit/kernel_svm.h
index 7a65a051..288e6729 100644
--- a/vowpalwabbit/kernel_svm.h
+++ b/vowpalwabbit/kernel_svm.h
@@ -4,7 +4,4 @@ individual contributors. All rights reserved. Released under a BSD (revised)
license as described in the file LICENSE.
*/
#pragma once
-namespace KSVM
-{
-LEARNER::base_learner* setup(vw &all, po::variables_map& vm);
-}
+namespace KSVM { LEARNER::base_learner* setup(vw &all); }
diff --git a/vowpalwabbit/lda_core.cc b/vowpalwabbit/lda_core.cc
index 8da81a4d..e616ab25 100644
--- a/vowpalwabbit/lda_core.cc
+++ b/vowpalwabbit/lda_core.cc
@@ -1,799 +1,818 @@
-/*
-Copyright (c) by respective owners including Yahoo!, Microsoft, and
-individual contributors. All rights reserved. Released under a BSD (revised)
-license as described in the file LICENSE.
- */
-#include <fstream>
-#include <vector>
-#include <float.h>
-#ifdef _WIN32
-#include <winsock2.h>
-#else
-#include <netdb.h>
-#endif
-#include <string.h>
-#include <stdio.h>
-#include <assert.h>
-#include "constant.h"
-#include "gd.h"
-#include "simple_label.h"
-#include "rand48.h"
-#include "reductions.h"
-
-using namespace LEARNER;
-using namespace std;
-
-namespace LDA {
-
-class index_feature {
-public:
- uint32_t document;
- feature f;
- bool operator<(const index_feature b) const { return f.weight_index < b.f.weight_index; }
-};
-
- struct lda {
- v_array<float> Elogtheta;
- v_array<float> decay_levels;
- v_array<float> total_new;
- v_array<example* > examples;
- v_array<float> total_lambda;
- v_array<int> doc_lengths;
- v_array<float> digammas;
- v_array<float> v;
- vector<index_feature> sorted_features;
-
- bool total_lambda_init;
-
- double example_t;
- vw* all;
- };
-
-#ifdef _WIN32
-inline float fmax(float f1, float f2) { return (f1 < f2 ? f2 : f1); }
-inline float fmin(float f1, float f2) { return (f1 > f2 ? f2 : f1); }
-#endif
-
-#define MINEIRO_SPECIAL
-#ifdef MINEIRO_SPECIAL
-
-namespace {
-
-inline float
-fastlog2 (float x)
-{
- union { float f; uint32_t i; } vx = { x };
- union { uint32_t i; float f; } mx = { (vx.i & 0x007FFFFF) | (0x7e << 23) };
- float y = (float)vx.i;
- y *= 1.0f / (float)(1 << 23);
-
- return
- y - 124.22544637f - 1.498030302f * mx.f - 1.72587999f / (0.3520887068f + mx.f);
-}
-
-inline float
-fastlog (float x)
-{
- return 0.69314718f * fastlog2 (x);
-}
-
-inline float
-fastpow2 (float p)
-{
- float offset = (p < 0) ? 1.0f : 0.0f;
- float clipp = (p < -126) ? -126.0f : p;
- int w = (int)clipp;
- float z = clipp - w + offset;
- union { uint32_t i; float f; } v = { (uint32_t)((1 << 23) * (clipp + 121.2740838f + 27.7280233f / (4.84252568f - z) - 1.49012907f * z)) };
-
- return v.f;
-}
-
-inline float
-fastexp (float p)
-{
- return fastpow2 (1.442695040f * p);
-}
-
-inline float
-fastpow (float x,
- float p)
-{
- return fastpow2 (p * fastlog2 (x));
-}
-
-inline float
-fastlgamma (float x)
-{
- float logterm = fastlog (x * (1.0f + x) * (2.0f + x));
- float xp3 = 3.0f + x;
-
- return
- -2.081061466f - x + 0.0833333f / xp3 - logterm + (2.5f + x) * fastlog (xp3);
-}
-
-inline float
-fastdigamma (float x)
-{
- float twopx = 2.0f + x;
- float logterm = fastlog (twopx);
-
- return - (1.0f + 2.0f * x) / (x * (1.0f + x))
- - (13.0f + 6.0f * x) / (12.0f * twopx * twopx)
- + logterm;
-}
-
-#define log fastlog
-#define exp fastexp
-#define powf fastpow
-#define mydigamma fastdigamma
-#define mylgamma fastlgamma
-
-#if defined(__SSE2__) && !defined(VW_LDA_NO_SSE)
-
-#include <emmintrin.h>
-
-typedef __m128 v4sf;
-typedef __m128i v4si;
-
-#define v4si_to_v4sf _mm_cvtepi32_ps
-#define v4sf_to_v4si _mm_cvttps_epi32
-
-static inline float
-v4sf_index (const v4sf x,
- unsigned int i)
-{
- union { v4sf f; float array[4]; } tmp = { x };
-
- return tmp.array[i];
-}
-
-static inline const v4sf
-v4sfl (float x)
-{
- union { float array[4]; v4sf f; } tmp = { { x, x, x, x } };
-
- return tmp.f;
-}
-
-static inline const v4si
-v4sil (uint32_t x)
-{
- uint64_t wide = (((uint64_t) x) << 32) | x;
- union { uint64_t array[2]; v4si f; } tmp = { { wide, wide } };
-
- return tmp.f;
-}
-
-static inline v4sf
-vfastpow2 (const v4sf p)
-{
- v4sf ltzero = _mm_cmplt_ps (p, v4sfl (0.0f));
- v4sf offset = _mm_and_ps (ltzero, v4sfl (1.0f));
- v4sf lt126 = _mm_cmplt_ps (p, v4sfl (-126.0f));
- v4sf clipp = _mm_andnot_ps (lt126, p) + _mm_and_ps (lt126, v4sfl (-126.0f));
- v4si w = v4sf_to_v4si (clipp);
- v4sf z = clipp - v4si_to_v4sf (w) + offset;
-
- const v4sf c_121_2740838 = v4sfl (121.2740838f);
- const v4sf c_27_7280233 = v4sfl (27.7280233f);
- const v4sf c_4_84252568 = v4sfl (4.84252568f);
- const v4sf c_1_49012907 = v4sfl (1.49012907f);
- union { v4si i; v4sf f; } v = {
- v4sf_to_v4si (
- v4sfl (1 << 23) *
- (clipp + c_121_2740838 + c_27_7280233 / (c_4_84252568 - z) - c_1_49012907 * z)
- )
- };
-
- return v.f;
-}
-
-inline v4sf
-vfastexp (const v4sf p)
-{
- const v4sf c_invlog_2 = v4sfl (1.442695040f);
-
- return vfastpow2 (c_invlog_2 * p);
-}
-
-inline v4sf
-vfastlog2 (v4sf x)
-{
- union { v4sf f; v4si i; } vx = { x };
- union { v4si i; v4sf f; } mx = { (vx.i & v4sil (0x007FFFFF)) | v4sil (0x3f000000) };
- v4sf y = v4si_to_v4sf (vx.i);
- y *= v4sfl (1.1920928955078125e-7f);
-
- const v4sf c_124_22551499 = v4sfl (124.22551499f);
- const v4sf c_1_498030302 = v4sfl (1.498030302f);
- const v4sf c_1_725877999 = v4sfl (1.72587999f);
- const v4sf c_0_3520087068 = v4sfl (0.3520887068f);
-
- return y - c_124_22551499
- - c_1_498030302 * mx.f
- - c_1_725877999 / (c_0_3520087068 + mx.f);
-}
-
-inline v4sf
-vfastlog (v4sf x)
-{
- const v4sf c_0_69314718 = v4sfl (0.69314718f);
-
- return c_0_69314718 * vfastlog2 (x);
-}
-
-inline v4sf
-vfastdigamma (v4sf x)
-{
- v4sf twopx = v4sfl (2.0f) + x;
- v4sf logterm = vfastlog (twopx);
-
- return (v4sfl (-48.0f) + x * (v4sfl (-157.0f) + x * (v4sfl (-127.0f) - v4sfl (30.0f) * x))) /
- (v4sfl (12.0f) * x * (v4sfl (1.0f) + x) * twopx * twopx)
- + logterm;
-}
-
-void
-vexpdigammify (vw& all, float* gamma)
-{
- unsigned int n = all.lda;
- float extra_sum = 0.0f;
- v4sf sum = v4sfl (0.0f);
- size_t i;
-
- for (i = 0; i < n && ((uintptr_t) (gamma + i)) % 16 > 0; ++i)
- {
- extra_sum += gamma[i];
- gamma[i] = fastdigamma (gamma[i]);
- }
-
- for (; i + 4 < n; i += 4)
- {
- v4sf arg = _mm_load_ps (gamma + i);
- sum += arg;
- arg = vfastdigamma (arg);
- _mm_store_ps (gamma + i, arg);
- }
-
- for (; i < n; ++i)
- {
- extra_sum += gamma[i];
- gamma[i] = fastdigamma (gamma[i]);
- }
-
- extra_sum += v4sf_index (sum, 0) + v4sf_index (sum, 1) +
- v4sf_index (sum, 2) + v4sf_index (sum, 3);
- extra_sum = fastdigamma (extra_sum);
- sum = v4sfl (extra_sum);
-
- for (i = 0; i < n && ((uintptr_t) (gamma + i)) % 16 > 0; ++i)
- {
- gamma[i] = fmaxf (1e-10f, fastexp (gamma[i] - v4sf_index (sum, 0)));
- }
-
- for (; i + 4 < n; i += 4)
- {
- v4sf arg = _mm_load_ps (gamma + i);
- arg -= sum;
- arg = vfastexp (arg);
- arg = _mm_max_ps (v4sfl (1e-10f), arg);
- _mm_store_ps (gamma + i, arg);
- }
-
- for (; i < n; ++i)
- {
- gamma[i] = fmaxf (1e-10f, fastexp (gamma[i] - v4sf_index (sum, 0)));
- }
-}
-
-void vexpdigammify_2(vw& all, float* gamma, const float* norm)
-{
- size_t n = all.lda;
- size_t i;
-
- for (i = 0; i < n && ((uintptr_t) (gamma + i)) % 16 > 0; ++i)
- {
- gamma[i] = fmaxf (1e-10f, fastexp (fastdigamma (gamma[i]) - norm[i]));
- }
-
- for (; i + 4 < n; i += 4)
- {
- v4sf arg = _mm_load_ps (gamma + i);
- arg = vfastdigamma (arg);
- v4sf vnorm = _mm_loadu_ps (norm + i);
- arg -= vnorm;
- arg = vfastexp (arg);
- arg = _mm_max_ps (v4sfl (1e-10f), arg);
- _mm_store_ps (gamma + i, arg);
- }
-
- for (; i < n; ++i)
- {
- gamma[i] = fmaxf (1e-10f, fastexp (fastdigamma (gamma[i]) - norm[i]));
- }
-}
-
-#define myexpdigammify vexpdigammify
-#define myexpdigammify_2 vexpdigammify_2
-
-#else
-#ifndef _WIN32
-#warning "lda IS NOT using sse instructions"
-#endif
-#define myexpdigammify expdigammify
-#define myexpdigammify_2 expdigammify_2
-
-#endif // __SSE2__
-
-} // end anonymous namespace
-
-#else
-
-#include <boost/math/special_functions/digamma.hpp>
-#include <boost/math/special_functions/gamma.hpp>
-
-using namespace boost::math::policies;
-
-#define mydigamma boost::math::digamma
-#define mylgamma boost::math::lgamma
-#define myexpdigammify expdigammify
-#define myexpdigammify_2 expdigammify_2
-
-#endif // MINEIRO_SPECIAL
-
-float decayfunc(float t, float old_t, float power_t) {
- float result = 1;
- for (float i = old_t+1; i <= t; i += 1)
- result *= (1-powf(i, -power_t));
- return result;
-}
-
-float decayfunc2(float t, float old_t, float power_t)
-{
- float power_t_plus_one = 1.f - power_t;
- float arg = - ( powf(t, power_t_plus_one) -
- powf(old_t, power_t_plus_one));
- return exp ( arg
- / power_t_plus_one);
-}
-
-float decayfunc3(double t, double old_t, double power_t)
-{
- double power_t_plus_one = 1. - power_t;
- double logt = log((float)t);
- double logoldt = log((float)old_t);
- return (float)((old_t / t) * exp((float)(0.5*power_t_plus_one*(-logt*logt + logoldt*logoldt))));
-}
-
-float decayfunc4(double t, double old_t, double power_t)
-{
- if (power_t > 0.99)
- return decayfunc3(t, old_t, power_t);
- else
- return (float)decayfunc2((float)t, (float)old_t, (float)power_t);
-}
-
-void expdigammify(vw& all, float* gamma)
-{
- float sum=0;
- for (size_t i = 0; i<all.lda; i++)
- {
- sum += gamma[i];
- gamma[i] = mydigamma(gamma[i]);
- }
- sum = mydigamma(sum);
- for (size_t i = 0; i<all.lda; i++)
- gamma[i] = fmax(1e-6f, exp(gamma[i] - sum));
-}
-
-void expdigammify_2(vw& all, float* gamma, float* norm)
-{
- for (size_t i = 0; i<all.lda; i++)
- {
- gamma[i] = fmax(1e-6f, exp(mydigamma(gamma[i]) - norm[i]));
- }
-}
-
-float average_diff(vw& all, float* oldgamma, float* newgamma)
-{
- float sum = 0.;
- float normalizer = 0.;
- for (size_t i = 0; i<all.lda; i++) {
- sum += fabsf(oldgamma[i] - newgamma[i]);
- normalizer += newgamma[i];
- }
- return sum / normalizer;
-}
-
-// Returns E_q[log p(\theta)] - E_q[log q(\theta)].
- float theta_kl(vw& all, v_array<float>& Elogtheta, float* gamma)
-{
- float gammasum = 0;
- Elogtheta.erase();
- for (size_t k = 0; k < all.lda; k++) {
- Elogtheta.push_back(mydigamma(gamma[k]));
- gammasum += gamma[k];
- }
- float digammasum = mydigamma(gammasum);
- gammasum = mylgamma(gammasum);
- float kl = -(all.lda*mylgamma(all.lda_alpha));
- kl += mylgamma(all.lda_alpha*all.lda) - gammasum;
- for (size_t k = 0; k < all.lda; k++) {
- Elogtheta[k] -= digammasum;
- kl += (all.lda_alpha - gamma[k]) * Elogtheta[k];
- kl += mylgamma(gamma[k]);
- }
-
- return kl;
-}
-
-float find_cw(vw& all, float* u_for_w, float* v)
-{
- float c_w = 0;
- for (size_t k =0; k<all.lda; k++)
- c_w += u_for_w[k]*v[k];
-
- return 1.f / c_w;
-}
-
- v_array<float> new_gamma = v_init<float>();
- v_array<float> old_gamma = v_init<float>();
-// Returns an estimate of the part of the variational bound that
-// doesn't have to do with beta for the entire corpus for the current
-// setting of lambda based on the document passed in. The value is
-// divided by the total number of words in the document This can be
-// used as a (possibly very noisy) estimate of held-out likelihood.
- float lda_loop(vw& all, v_array<float>& Elogtheta, float* v,weight* weights,example* ec, float power_t)
-{
- new_gamma.erase();
- old_gamma.erase();
-
- for (size_t i = 0; i < all.lda; i++)
- {
- new_gamma.push_back(1.f);
- old_gamma.push_back(0.f);
- }
- size_t num_words =0;
- for (unsigned char* i = ec->indices.begin; i != ec->indices.end; i++)
- num_words += ec->atomics[*i].end - ec->atomics[*i].begin;
-
- float xc_w = 0;
- float score = 0;
- float doc_length = 0;
- do
- {
- memcpy(v,new_gamma.begin,sizeof(float)*all.lda);
- myexpdigammify(all, v);
-
- memcpy(old_gamma.begin,new_gamma.begin,sizeof(float)*all.lda);
- memset(new_gamma.begin,0,sizeof(float)*all.lda);
-
- score = 0;
- size_t word_count = 0;
- doc_length = 0;
- for (unsigned char* i = ec->indices.begin; i != ec->indices.end; i++)
- {
- feature *f = ec->atomics[*i].begin;
- for (; f != ec->atomics[*i].end; f++)
- {
- float* u_for_w = &weights[(f->weight_index&all.reg.weight_mask)+all.lda+1];
- float c_w = find_cw(all, u_for_w,v);
- xc_w = c_w * f->x;
- score += -f->x*log(c_w);
- size_t max_k = all.lda;
- for (size_t k =0; k<max_k; k++) {
- new_gamma[k] += xc_w*u_for_w[k];
- }
- word_count++;
- doc_length += f->x;
- }
- }
- for (size_t k =0; k<all.lda; k++)
- new_gamma[k] = new_gamma[k]*v[k]+all.lda_alpha;
- }
- while (average_diff(all, old_gamma.begin, new_gamma.begin) > all.lda_epsilon);
-
- ec->topic_predictions.erase();
- ec->topic_predictions.resize(all.lda);
- memcpy(ec->topic_predictions.begin,new_gamma.begin,all.lda*sizeof(float));
-
- score += theta_kl(all, Elogtheta, new_gamma.begin);
-
- return score / doc_length;
-}
-
-size_t next_pow2(size_t x) {
- int i = 0;
- x = x > 0 ? x - 1 : 0;
- while (x > 0) {
- x >>= 1;
- i++;
- }
- return ((size_t)1) << i;
-}
-
-void save_load(lda& l, io_buf& model_file, bool read, bool text)
-{
- vw* all = l.all;
- uint32_t length = 1 << all->num_bits;
- uint32_t stride = 1 << all->reg.stride_shift;
-
- if (read)
- {
- initialize_regressor(*all);
- for (size_t j = 0; j < stride*length; j+=stride)
- {
- for (size_t k = 0; k < all->lda; k++) {
- if (all->random_weights) {
- all->reg.weight_vector[j+k] = (float)(-log(frand48()) + 1.0f);
- all->reg.weight_vector[j+k] *= (float)(all->lda_D / all->lda / all->length() * 200);
- }
- }
- all->reg.weight_vector[j+all->lda] = all->initial_t;
- }
- }
-
- if (model_file.files.size() > 0)
- {
- uint32_t i = 0;
- uint32_t text_len;
- char buff[512];
- size_t brw = 1;
- do
- {
- brw = 0;
- size_t K = all->lda;
-
- text_len = sprintf(buff, "%d ", i);
- brw += bin_text_read_write_fixed(model_file,(char *)&i, sizeof (i),
- "", read,
- buff, text_len, text);
- if (brw != 0)
- for (uint32_t k = 0; k < K; k++)
- {
- uint32_t ndx = stride*i+k;
-
- weight* v = &(all->reg.weight_vector[ndx]);
- text_len = sprintf(buff, "%f ", *v + all->lda_rho);
-
- brw += bin_text_read_write_fixed(model_file,(char *)v, sizeof (*v),
- "", read,
- buff, text_len, text);
-
- }
- if (text)
- brw += bin_text_read_write_fixed(model_file,buff,0,
- "", read,
- "\n",1,text);
-
- if (!read)
- i++;
- }
- while ((!read && i < length) || (read && brw >0));
- }
-}
-
- void learn_batch(lda& l)
- {
- if (l.sorted_features.empty()) {
- // This can happen when the socket connection is dropped by the client.
- // If l.sorted_features is empty, then l.sorted_features[0] does not
- // exist, so we should not try to take its address in the beginning of
- // the for loops down there. Since it seems that there's not much to
- // do in this case, we just return.
- for (size_t d = 0; d < l.examples.size(); d++)
- return_simple_example(*l.all, NULL, *l.examples[d]);
- l.examples.erase();
- return;
- }
-
- float eta = -1;
- float minuseta = -1;
-
- if (l.total_lambda.size() == 0)
- {
- for (size_t k = 0; k < l.all->lda; k++)
- l.total_lambda.push_back(0.f);
-
- size_t stride = 1 << l.all->reg.stride_shift;
- for (size_t i =0; i <= l.all->reg.weight_mask;i+=stride)
- for (size_t k = 0; k < l.all->lda; k++)
- l.total_lambda[k] += l.all->reg.weight_vector[i+k];
- }
-
- l.example_t++;
- l.total_new.erase();
- for (size_t k = 0; k < l.all->lda; k++)
- l.total_new.push_back(0.f);
-
- size_t batch_size = l.examples.size();
-
- sort(l.sorted_features.begin(), l.sorted_features.end());
-
- eta = l.all->eta * powf((float)l.example_t, - l.all->power_t);
- minuseta = 1.0f - eta;
- eta *= l.all->lda_D / batch_size;
- l.decay_levels.push_back(l.decay_levels.last() + log(minuseta));
-
- l.digammas.erase();
- float additional = (float)(l.all->length()) * l.all->lda_rho;
- for (size_t i = 0; i<l.all->lda; i++) {
- l.digammas.push_back(mydigamma(l.total_lambda[i] + additional));
- }
-
-
- weight* weights = l.all->reg.weight_vector;
-
- size_t last_weight_index = -1;
- for (index_feature* s = &l.sorted_features[0]; s <= &l.sorted_features.back(); s++)
- {
- if (last_weight_index == s->f.weight_index)
- continue;
- last_weight_index = s->f.weight_index;
- float* weights_for_w = &(weights[s->f.weight_index & l.all->reg.weight_mask]);
- float decay = fmin(1.0, exp(l.decay_levels.end[-2] - l.decay_levels.end[(int)(-1 - l.example_t+weights_for_w[l.all->lda])]));
- float* u_for_w = weights_for_w + l.all->lda+1;
-
- weights_for_w[l.all->lda] = (float)l.example_t;
- for (size_t k = 0; k < l.all->lda; k++)
- {
- weights_for_w[k] *= decay;
- u_for_w[k] = weights_for_w[k] + l.all->lda_rho;
- }
- myexpdigammify_2(*l.all, u_for_w, l.digammas.begin);
- }
-
- for (size_t d = 0; d < batch_size; d++)
- {
- float score = lda_loop(*l.all, l.Elogtheta, &(l.v[d*l.all->lda]), weights, l.examples[d],l.all->power_t);
- if (l.all->audit)
- GD::print_audit_features(*l.all, *l.examples[d]);
- // If the doc is empty, give it loss of 0.
- if (l.doc_lengths[d] > 0) {
- l.all->sd->sum_loss -= score;
- l.all->sd->sum_loss_since_last_dump -= score;
- }
- return_simple_example(*l.all, NULL, *l.examples[d]);
- }
-
- for (index_feature* s = &l.sorted_features[0]; s <= &l.sorted_features.back();)
- {
- index_feature* next = s+1;
- while(next <= &l.sorted_features.back() && next->f.weight_index == s->f.weight_index)
- next++;
-
- float* word_weights = &(weights[s->f.weight_index & l.all->reg.weight_mask]);
- for (size_t k = 0; k < l.all->lda; k++) {
- float new_value = minuseta*word_weights[k];
- word_weights[k] = new_value;
- }
-
- for (; s != next; s++) {
- float* v_s = &(l.v[s->document*l.all->lda]);
- float* u_for_w = &weights[(s->f.weight_index & l.all->reg.weight_mask) + l.all->lda + 1];
- float c_w = eta*find_cw(*l.all, u_for_w, v_s)*s->f.x;
- for (size_t k = 0; k < l.all->lda; k++) {
- float new_value = u_for_w[k]*v_s[k]*c_w;
- l.total_new[k] += new_value;
- word_weights[k] += new_value;
- }
- }
- }
- for (size_t k = 0; k < l.all->lda; k++) {
- l.total_lambda[k] *= minuseta;
- l.total_lambda[k] += l.total_new[k];
- }
-
- l.sorted_features.resize(0);
-
- l.examples.erase();
- l.doc_lengths.erase();
- }
-
- void learn(lda& l, base_learner& base, example& ec)
- {
- size_t num_ex = l.examples.size();
- l.examples.push_back(&ec);
- l.doc_lengths.push_back(0);
- for (unsigned char* i = ec.indices.begin; i != ec.indices.end; i++) {
- feature* f = ec.atomics[*i].begin;
- for (; f != ec.atomics[*i].end; f++) {
- index_feature temp = {(uint32_t)num_ex, *f};
- l.sorted_features.push_back(temp);
- l.doc_lengths[num_ex] += (int)f->x;
- }
- }
- if (++num_ex == l.all->minibatch)
- learn_batch(l);
- }
-
- // placeholder
- void predict(lda& l, base_learner& base, example& ec)
- {
- learn(l, base, ec);
- }
-
- void end_pass(lda& l)
- {
- if (l.examples.size())
- learn_batch(l);
- }
-
-void end_examples(lda& l)
-{
- for (size_t i = 0; i < l.all->length(); i++) {
- weight* weights_for_w = & (l.all->reg.weight_vector[i << l.all->reg.stride_shift]);
- float decay = fmin(1.0, exp(l.decay_levels.last() - l.decay_levels.end[(int)(-1- l.example_t +weights_for_w[l.all->lda])]));
- for (size_t k = 0; k < l.all->lda; k++)
- weights_for_w[k] *= decay;
- }
-}
-
- void finish_example(vw& all, lda&, example& ec)
-{}
-
- void finish(lda& ld)
- {
- ld.sorted_features.~vector<index_feature>();
- ld.Elogtheta.delete_v();
- ld.decay_levels.delete_v();
- ld.total_new.delete_v();
- ld.examples.delete_v();
- ld.total_lambda.delete_v();
- ld.doc_lengths.delete_v();
- ld.digammas.delete_v();
- ld.v.delete_v();
- }
-
-base_learner* setup(vw&all, po::variables_map& vm)
-{
- lda& ld = calloc_or_die<lda>();
- ld.sorted_features = vector<index_feature>();
- ld.total_lambda_init = 0;
- ld.all = &all;
- ld.example_t = all.initial_t;
-
- po::options_description lda_opts("LDA options");
- lda_opts.add_options()
- ("lda_alpha", po::value<float>(&all.lda_alpha), "Prior on sparsity of per-document topic weights")
- ("lda_rho", po::value<float>(&all.lda_rho), "Prior on sparsity of topic distributions")
- ("lda_D", po::value<float>(&all.lda_D), "Number of documents")
- ("lda_epsilon", po::value<float>(&all.lda_epsilon), "Loop convergence threshold")
- ("minibatch", po::value<size_t>(&all.minibatch), "Minibatch size, for LDA");
-
- vm = add_options(all, lda_opts);
-
- float temp = ceilf(logf((float)(all.lda*2+1)) / logf (2.f));
- all.reg.stride_shift = (size_t)temp;
- all.random_weights = true;
- all.add_constant = false;
-
- *all.file_options << " --lda " << all.lda;
-
- if (all.eta > 1.)
- {
- cerr << "your learning rate is too high, setting it to 1" << endl;
- all.eta = min(all.eta,1.f);
- }
-
- if (vm.count("minibatch")) {
- size_t minibatch2 = next_pow2(all.minibatch);
- all.p->ring_size = all.p->ring_size > minibatch2 ? all.p->ring_size : minibatch2;
- }
-
- ld.v.resize(all.lda*all.minibatch);
-
- ld.decay_levels.push_back(0.f);
-
- learner<lda>& l = init_learner(&ld, learn, 1 << all.reg.stride_shift);
- l.set_predict(predict);
- l.set_save_load(save_load);
- l.set_finish_example(finish_example);
- l.set_end_examples(end_examples);
- l.set_end_pass(end_pass);
- l.set_finish(finish);
-
- return make_base(l);
-}
-}
+/*
+Copyright (c) by respective owners including Yahoo!, Microsoft, and
+individual contributors. All rights reserved. Released under a BSD (revised)
+license as described in the file LICENSE.
+ */
+#include <fstream>
+#include <vector>
+#include <float.h>
+#ifdef _WIN32
+#include <winsock2.h>
+#else
+#include <netdb.h>
+#endif
+#include <string.h>
+#include <stdio.h>
+#include <assert.h>
+#include "constant.h"
+#include "gd.h"
+#include "simple_label.h"
+#include "rand48.h"
+#include "reductions.h"
+
+using namespace LEARNER;
+using namespace std;
+
+namespace LDA {
+
+class index_feature {
+public:
+ uint32_t document;
+ feature f;
+ bool operator<(const index_feature b) const { return f.weight_index < b.f.weight_index; }
+};
+
+ struct lda {
+ uint32_t topics;
+ float lda_alpha;
+ float lda_rho;
+ float lda_D;
+ float lda_epsilon;
+ size_t minibatch;
+
+ v_array<float> Elogtheta;
+ v_array<float> decay_levels;
+ v_array<float> total_new;
+ v_array<example* > examples;
+ v_array<float> total_lambda;
+ v_array<int> doc_lengths;
+ v_array<float> digammas;
+ v_array<float> v;
+ vector<index_feature> sorted_features;
+
+ bool total_lambda_init;
+
+ double example_t;
+ vw* all;
+ };
+
+#ifdef _WIN32
+inline float fmax(float f1, float f2) { return (f1 < f2 ? f2 : f1); }
+inline float fmin(float f1, float f2) { return (f1 > f2 ? f2 : f1); }
+#endif
+
+#define MINEIRO_SPECIAL
+#ifdef MINEIRO_SPECIAL
+
+namespace {
+
+inline float
+fastlog2 (float x)
+{
+ union { float f; uint32_t i; } vx = { x };
+ union { uint32_t i; float f; } mx = { (vx.i & 0x007FFFFF) | (0x7e << 23) };
+ float y = (float)vx.i;
+ y *= 1.0f / (float)(1 << 23);
+
+ return
+ y - 124.22544637f - 1.498030302f * mx.f - 1.72587999f / (0.3520887068f + mx.f);
+}
+
+inline float
+fastlog (float x)
+{
+ return 0.69314718f * fastlog2 (x);
+}
+
+inline float
+fastpow2 (float p)
+{
+ float offset = (p < 0) ? 1.0f : 0.0f;
+ float clipp = (p < -126) ? -126.0f : p;
+ int w = (int)clipp;
+ float z = clipp - w + offset;
+ union { uint32_t i; float f; } v = { (uint32_t)((1 << 23) * (clipp + 121.2740838f + 27.7280233f / (4.84252568f - z) - 1.49012907f * z)) };
+
+ return v.f;
+}
+
+inline float
+fastexp (float p)
+{
+ return fastpow2 (1.442695040f * p);
+}
+
+inline float
+fastpow (float x,
+ float p)
+{
+ return fastpow2 (p * fastlog2 (x));
+}
+
+inline float
+fastlgamma (float x)
+{
+ float logterm = fastlog (x * (1.0f + x) * (2.0f + x));
+ float xp3 = 3.0f + x;
+
+ return
+ -2.081061466f - x + 0.0833333f / xp3 - logterm + (2.5f + x) * fastlog (xp3);
+}
+
+inline float
+fastdigamma (float x)
+{
+ float twopx = 2.0f + x;
+ float logterm = fastlog (twopx);
+
+ return - (1.0f + 2.0f * x) / (x * (1.0f + x))
+ - (13.0f + 6.0f * x) / (12.0f * twopx * twopx)
+ + logterm;
+}
+
+#define log fastlog
+#define exp fastexp
+#define powf fastpow
+#define mydigamma fastdigamma
+#define mylgamma fastlgamma
+
+#if defined(__SSE2__) && !defined(VW_LDA_NO_SSE)
+
+#include <emmintrin.h>
+
+typedef __m128 v4sf;
+typedef __m128i v4si;
+
+#define v4si_to_v4sf _mm_cvtepi32_ps
+#define v4sf_to_v4si _mm_cvttps_epi32
+
+static inline float
+v4sf_index (const v4sf x,
+ unsigned int i)
+{
+ union { v4sf f; float array[4]; } tmp = { x };
+
+ return tmp.array[i];
+}
+
+static inline const v4sf
+v4sfl (float x)
+{
+ union { float array[4]; v4sf f; } tmp = { { x, x, x, x } };
+
+ return tmp.f;
+}
+
+static inline const v4si
+v4sil (uint32_t x)
+{
+ uint64_t wide = (((uint64_t) x) << 32) | x;
+ union { uint64_t array[2]; v4si f; } tmp = { { wide, wide } };
+
+ return tmp.f;
+}
+
+static inline v4sf
+vfastpow2 (const v4sf p)
+{
+ v4sf ltzero = _mm_cmplt_ps (p, v4sfl (0.0f));
+ v4sf offset = _mm_and_ps (ltzero, v4sfl (1.0f));
+ v4sf lt126 = _mm_cmplt_ps (p, v4sfl (-126.0f));
+ v4sf clipp = _mm_andnot_ps (lt126, p) + _mm_and_ps (lt126, v4sfl (-126.0f));
+ v4si w = v4sf_to_v4si (clipp);
+ v4sf z = clipp - v4si_to_v4sf (w) + offset;
+
+ const v4sf c_121_2740838 = v4sfl (121.2740838f);
+ const v4sf c_27_7280233 = v4sfl (27.7280233f);
+ const v4sf c_4_84252568 = v4sfl (4.84252568f);
+ const v4sf c_1_49012907 = v4sfl (1.49012907f);
+ union { v4si i; v4sf f; } v = {
+ v4sf_to_v4si (
+ v4sfl (1 << 23) *
+ (clipp + c_121_2740838 + c_27_7280233 / (c_4_84252568 - z) - c_1_49012907 * z)
+ )
+ };
+
+ return v.f;
+}
+
+inline v4sf
+vfastexp (const v4sf p)
+{
+ const v4sf c_invlog_2 = v4sfl (1.442695040f);
+
+ return vfastpow2 (c_invlog_2 * p);
+}
+
+inline v4sf
+vfastlog2 (v4sf x)
+{
+ union { v4sf f; v4si i; } vx = { x };
+ union { v4si i; v4sf f; } mx = { (vx.i & v4sil (0x007FFFFF)) | v4sil (0x3f000000) };
+ v4sf y = v4si_to_v4sf (vx.i);
+ y *= v4sfl (1.1920928955078125e-7f);
+
+ const v4sf c_124_22551499 = v4sfl (124.22551499f);
+ const v4sf c_1_498030302 = v4sfl (1.498030302f);
+ const v4sf c_1_725877999 = v4sfl (1.72587999f);
+ const v4sf c_0_3520087068 = v4sfl (0.3520887068f);
+
+ return y - c_124_22551499
+ - c_1_498030302 * mx.f
+ - c_1_725877999 / (c_0_3520087068 + mx.f);
+}
+
+inline v4sf
+vfastlog (v4sf x)
+{
+ const v4sf c_0_69314718 = v4sfl (0.69314718f);
+
+ return c_0_69314718 * vfastlog2 (x);
+}
+
+inline v4sf
+vfastdigamma (v4sf x)
+{
+ v4sf twopx = v4sfl (2.0f) + x;
+ v4sf logterm = vfastlog (twopx);
+
+ return (v4sfl (-48.0f) + x * (v4sfl (-157.0f) + x * (v4sfl (-127.0f) - v4sfl (30.0f) * x))) /
+ (v4sfl (12.0f) * x * (v4sfl (1.0f) + x) * twopx * twopx)
+ + logterm;
+}
+
+void
+vexpdigammify (vw& all, float* gamma)
+{
+ unsigned int n = all.lda;
+ float extra_sum = 0.0f;
+ v4sf sum = v4sfl (0.0f);
+ size_t i;
+
+ for (i = 0; i < n && ((uintptr_t) (gamma + i)) % 16 > 0; ++i)
+ {
+ extra_sum += gamma[i];
+ gamma[i] = fastdigamma (gamma[i]);
+ }
+
+ for (; i + 4 < n; i += 4)
+ {
+ v4sf arg = _mm_load_ps (gamma + i);
+ sum += arg;
+ arg = vfastdigamma (arg);
+ _mm_store_ps (gamma + i, arg);
+ }
+
+ for (; i < n; ++i)
+ {
+ extra_sum += gamma[i];
+ gamma[i] = fastdigamma (gamma[i]);
+ }
+
+ extra_sum += v4sf_index (sum, 0) + v4sf_index (sum, 1) +
+ v4sf_index (sum, 2) + v4sf_index (sum, 3);
+ extra_sum = fastdigamma (extra_sum);
+ sum = v4sfl (extra_sum);
+
+ for (i = 0; i < n && ((uintptr_t) (gamma + i)) % 16 > 0; ++i)
+ {
+ gamma[i] = fmaxf (1e-10f, fastexp (gamma[i] - v4sf_index (sum, 0)));
+ }
+
+ for (; i + 4 < n; i += 4)
+ {
+ v4sf arg = _mm_load_ps (gamma + i);
+ arg -= sum;
+ arg = vfastexp (arg);
+ arg = _mm_max_ps (v4sfl (1e-10f), arg);
+ _mm_store_ps (gamma + i, arg);
+ }
+
+ for (; i < n; ++i)
+ {
+ gamma[i] = fmaxf (1e-10f, fastexp (gamma[i] - v4sf_index (sum, 0)));
+ }
+}
+
+void vexpdigammify_2(vw& all, float* gamma, const float* norm)
+{
+ size_t n = all.lda;
+ size_t i;
+
+ for (i = 0; i < n && ((uintptr_t) (gamma + i)) % 16 > 0; ++i)
+ {
+ gamma[i] = fmaxf (1e-10f, fastexp (fastdigamma (gamma[i]) - norm[i]));
+ }
+
+ for (; i + 4 < n; i += 4)
+ {
+ v4sf arg = _mm_load_ps (gamma + i);
+ arg = vfastdigamma (arg);
+ v4sf vnorm = _mm_loadu_ps (norm + i);
+ arg -= vnorm;
+ arg = vfastexp (arg);
+ arg = _mm_max_ps (v4sfl (1e-10f), arg);
+ _mm_store_ps (gamma + i, arg);
+ }
+
+ for (; i < n; ++i)
+ {
+ gamma[i] = fmaxf (1e-10f, fastexp (fastdigamma (gamma[i]) - norm[i]));
+ }
+}
+
+#define myexpdigammify vexpdigammify
+#define myexpdigammify_2 vexpdigammify_2
+
+#else
+#ifndef _WIN32
+#warning "lda IS NOT using sse instructions"
+#endif
+#define myexpdigammify expdigammify
+#define myexpdigammify_2 expdigammify_2
+
+#endif // __SSE2__
+
+} // end anonymous namespace
+
+#else
+
+#include <boost/math/special_functions/digamma.hpp>
+#include <boost/math/special_functions/gamma.hpp>
+
+using namespace boost::math::policies;
+
+#define mydigamma boost::math::digamma
+#define mylgamma boost::math::lgamma
+#define myexpdigammify expdigammify
+#define myexpdigammify_2 expdigammify_2
+
+#endif // MINEIRO_SPECIAL
+
+float decayfunc(float t, float old_t, float power_t) {
+ float result = 1;
+ for (float i = old_t+1; i <= t; i += 1)
+ result *= (1-powf(i, -power_t));
+ return result;
+}
+
+float decayfunc2(float t, float old_t, float power_t)
+{
+ float power_t_plus_one = 1.f - power_t;
+ float arg = - ( powf(t, power_t_plus_one) -
+ powf(old_t, power_t_plus_one));
+ return exp ( arg
+ / power_t_plus_one);
+}
+
+float decayfunc3(double t, double old_t, double power_t)
+{
+ double power_t_plus_one = 1. - power_t;
+ double logt = log((float)t);
+ double logoldt = log((float)old_t);
+ return (float)((old_t / t) * exp((float)(0.5*power_t_plus_one*(-logt*logt + logoldt*logoldt))));
+}
+
+float decayfunc4(double t, double old_t, double power_t)
+{
+ if (power_t > 0.99)
+ return decayfunc3(t, old_t, power_t);
+ else
+ return (float)decayfunc2((float)t, (float)old_t, (float)power_t);
+}
+
+void expdigammify(vw& all, float* gamma)
+{
+ float sum=0;
+ for (size_t i = 0; i<all.lda; i++)
+ {
+ sum += gamma[i];
+ gamma[i] = mydigamma(gamma[i]);
+ }
+ sum = mydigamma(sum);
+ for (size_t i = 0; i<all.lda; i++)
+ gamma[i] = fmax(1e-6f, exp(gamma[i] - sum));
+}
+
+void expdigammify_2(vw& all, float* gamma, float* norm)
+{
+ for (size_t i = 0; i<all.lda; i++)
+ {
+ gamma[i] = fmax(1e-6f, exp(mydigamma(gamma[i]) - norm[i]));
+ }
+}
+
+float average_diff(vw& all, float* oldgamma, float* newgamma)
+{
+ float sum = 0.;
+ float normalizer = 0.;
+ for (size_t i = 0; i<all.lda; i++) {
+ sum += fabsf(oldgamma[i] - newgamma[i]);
+ normalizer += newgamma[i];
+ }
+ return sum / normalizer;
+}
+
+// Returns E_q[log p(\theta)] - E_q[log q(\theta)].
+ float theta_kl(lda& l, v_array<float>& Elogtheta, float* gamma)
+{
+ float gammasum = 0;
+ Elogtheta.erase();
+ for (size_t k = 0; k < l.topics; k++) {
+ Elogtheta.push_back(mydigamma(gamma[k]));
+ gammasum += gamma[k];
+ }
+ float digammasum = mydigamma(gammasum);
+ gammasum = mylgamma(gammasum);
+ float kl = -(l.topics*mylgamma(l.lda_alpha));
+ kl += mylgamma(l.lda_alpha*l.topics) - gammasum;
+ for (size_t k = 0; k < l.topics; k++) {
+ Elogtheta[k] -= digammasum;
+ kl += (l.lda_alpha - gamma[k]) * Elogtheta[k];
+ kl += mylgamma(gamma[k]);
+ }
+
+ return kl;
+}
+
+float find_cw(lda& l, float* u_for_w, float* v)
+{
+ float c_w = 0;
+ for (size_t k =0; k<l.topics; k++)
+ c_w += u_for_w[k]*v[k];
+
+ return 1.f / c_w;
+}
+
+ v_array<float> new_gamma = v_init<float>();
+ v_array<float> old_gamma = v_init<float>();
+// Returns an estimate of the part of the variational bound that
+// doesn't have to do with beta for the entire corpus for the current
+// setting of lambda based on the document passed in. The value is
+// divided by the total number of words in the document This can be
+// used as a (possibly very noisy) estimate of held-out likelihood.
+ float lda_loop(lda& l, v_array<float>& Elogtheta, float* v,weight* weights,example* ec, float power_t)
+{
+ new_gamma.erase();
+ old_gamma.erase();
+
+ for (size_t i = 0; i < l.topics; i++)
+ {
+ new_gamma.push_back(1.f);
+ old_gamma.push_back(0.f);
+ }
+ size_t num_words =0;
+ for (unsigned char* i = ec->indices.begin; i != ec->indices.end; i++)
+ num_words += ec->atomics[*i].end - ec->atomics[*i].begin;
+
+ float xc_w = 0;
+ float score = 0;
+ float doc_length = 0;
+ do
+ {
+ memcpy(v,new_gamma.begin,sizeof(float)*l.topics);
+ myexpdigammify(*l.all, v);
+
+ memcpy(old_gamma.begin,new_gamma.begin,sizeof(float)*l.topics);
+ memset(new_gamma.begin,0,sizeof(float)*l.topics);
+
+ score = 0;
+ size_t word_count = 0;
+ doc_length = 0;
+ for (unsigned char* i = ec->indices.begin; i != ec->indices.end; i++)
+ {
+ feature *f = ec->atomics[*i].begin;
+ for (; f != ec->atomics[*i].end; f++)
+ {
+ float* u_for_w = &weights[(f->weight_index & l.all->reg.weight_mask)+l.topics+1];
+ float c_w = find_cw(l, u_for_w,v);
+ xc_w = c_w * f->x;
+ score += -f->x*log(c_w);
+ size_t max_k = l.topics;
+ for (size_t k =0; k<max_k; k++) {
+ new_gamma[k] += xc_w*u_for_w[k];
+ }
+ word_count++;
+ doc_length += f->x;
+ }
+ }
+ for (size_t k =0; k<l.topics; k++)
+ new_gamma[k] = new_gamma[k]*v[k]+l.lda_alpha;
+ }
+ while (average_diff(*l.all, old_gamma.begin, new_gamma.begin) > l.lda_epsilon);
+
+ ec->topic_predictions.erase();
+ ec->topic_predictions.resize(l.topics);
+ memcpy(ec->topic_predictions.begin,new_gamma.begin,l.topics*sizeof(float));
+
+ score += theta_kl(l, Elogtheta, new_gamma.begin);
+
+ return score / doc_length;
+}
+
+size_t next_pow2(size_t x) {
+ int i = 0;
+ x = x > 0 ? x - 1 : 0;
+ while (x > 0) {
+ x >>= 1;
+ i++;
+ }
+ return ((size_t)1) << i;
+}
+
+void save_load(lda& l, io_buf& model_file, bool read, bool text)
+{
+ vw* all = l.all;
+ uint32_t length = 1 << all->num_bits;
+ uint32_t stride = 1 << all->reg.stride_shift;
+
+ if (read)
+ {
+ initialize_regressor(*all);
+ for (size_t j = 0; j < stride*length; j+=stride)
+ {
+ for (size_t k = 0; k < all->lda; k++) {
+ if (all->random_weights) {
+ all->reg.weight_vector[j+k] = (float)(-log(frand48()) + 1.0f);
+ all->reg.weight_vector[j+k] *= (float)(l.lda_D / all->lda / all->length() * 200);
+ }
+ }
+ all->reg.weight_vector[j+all->lda] = all->initial_t;
+ }
+ }
+
+ if (model_file.files.size() > 0)
+ {
+ uint32_t i = 0;
+ uint32_t text_len;
+ char buff[512];
+ size_t brw = 1;
+ do
+ {
+ brw = 0;
+ size_t K = all->lda;
+
+ text_len = sprintf(buff, "%d ", i);
+ brw += bin_text_read_write_fixed(model_file,(char *)&i, sizeof (i),
+ "", read,
+ buff, text_len, text);
+ if (brw != 0)
+ for (uint32_t k = 0; k < K; k++)
+ {
+ uint32_t ndx = stride*i+k;
+
+ weight* v = &(all->reg.weight_vector[ndx]);
+ text_len = sprintf(buff, "%f ", *v + l.lda_rho);
+
+ brw += bin_text_read_write_fixed(model_file,(char *)v, sizeof (*v),
+ "", read,
+ buff, text_len, text);
+
+ }
+ if (text)
+ brw += bin_text_read_write_fixed(model_file,buff,0,
+ "", read,
+ "\n",1,text);
+
+ if (!read)
+ i++;
+ }
+ while ((!read && i < length) || (read && brw >0));
+ }
+}
+
+ void learn_batch(lda& l)
+ {
+ if (l.sorted_features.empty()) {
+ // This can happen when the socket connection is dropped by the client.
+ // If l.sorted_features is empty, then l.sorted_features[0] does not
+ // exist, so we should not try to take its address in the beginning of
+ // the for loops down there. Since it seems that there's not much to
+ // do in this case, we just return.
+ for (size_t d = 0; d < l.examples.size(); d++)
+ return_simple_example(*l.all, NULL, *l.examples[d]);
+ l.examples.erase();
+ return;
+ }
+
+ float eta = -1;
+ float minuseta = -1;
+
+ if (l.total_lambda.size() == 0)
+ {
+ for (size_t k = 0; k < l.all->lda; k++)
+ l.total_lambda.push_back(0.f);
+
+ size_t stride = 1 << l.all->reg.stride_shift;
+ for (size_t i =0; i <= l.all->reg.weight_mask;i+=stride)
+ for (size_t k = 0; k < l.all->lda; k++)
+ l.total_lambda[k] += l.all->reg.weight_vector[i+k];
+ }
+
+ l.example_t++;
+ l.total_new.erase();
+ for (size_t k = 0; k < l.all->lda; k++)
+ l.total_new.push_back(0.f);
+
+ size_t batch_size = l.examples.size();
+
+ sort(l.sorted_features.begin(), l.sorted_features.end());
+
+ eta = l.all->eta * powf((float)l.example_t, - l.all->power_t);
+ minuseta = 1.0f - eta;
+ eta *= l.lda_D / batch_size;
+ l.decay_levels.push_back(l.decay_levels.last() + log(minuseta));
+
+ l.digammas.erase();
+ float additional = (float)(l.all->length()) * l.lda_rho;
+ for (size_t i = 0; i<l.all->lda; i++) {
+ l.digammas.push_back(mydigamma(l.total_lambda[i] + additional));
+ }
+
+
+ weight* weights = l.all->reg.weight_vector;
+
+ size_t last_weight_index = -1;
+ for (index_feature* s = &l.sorted_features[0]; s <= &l.sorted_features.back(); s++)
+ {
+ if (last_weight_index == s->f.weight_index)
+ continue;
+ last_weight_index = s->f.weight_index;
+ float* weights_for_w = &(weights[s->f.weight_index & l.all->reg.weight_mask]);
+ float decay = fmin(1.0, exp(l.decay_levels.end[-2] - l.decay_levels.end[(int)(-1 - l.example_t+weights_for_w[l.all->lda])]));
+ float* u_for_w = weights_for_w + l.all->lda+1;
+
+ weights_for_w[l.all->lda] = (float)l.example_t;
+ for (size_t k = 0; k < l.all->lda; k++)
+ {
+ weights_for_w[k] *= decay;
+ u_for_w[k] = weights_for_w[k] + l.lda_rho;
+ }
+ myexpdigammify_2(*l.all, u_for_w, l.digammas.begin);
+ }
+
+ for (size_t d = 0; d < batch_size; d++)
+ {
+ float score = lda_loop(l, l.Elogtheta, &(l.v[d*l.all->lda]), weights, l.examples[d],l.all->power_t);
+ if (l.all->audit)
+ GD::print_audit_features(*l.all, *l.examples[d]);
+ // If the doc is empty, give it loss of 0.
+ if (l.doc_lengths[d] > 0) {
+ l.all->sd->sum_loss -= score;
+ l.all->sd->sum_loss_since_last_dump -= score;
+ }
+ return_simple_example(*l.all, NULL, *l.examples[d]);
+ }
+
+ for (index_feature* s = &l.sorted_features[0]; s <= &l.sorted_features.back();)
+ {
+ index_feature* next = s+1;
+ while(next <= &l.sorted_features.back() && next->f.weight_index == s->f.weight_index)
+ next++;
+
+ float* word_weights = &(weights[s->f.weight_index & l.all->reg.weight_mask]);
+ for (size_t k = 0; k < l.all->lda; k++) {
+ float new_value = minuseta*word_weights[k];
+ word_weights[k] = new_value;
+ }
+
+ for (; s != next; s++) {
+ float* v_s = &(l.v[s->document*l.all->lda]);
+ float* u_for_w = &weights[(s->f.weight_index & l.all->reg.weight_mask) + l.all->lda + 1];
+ float c_w = eta*find_cw(l, u_for_w, v_s)*s->f.x;
+ for (size_t k = 0; k < l.all->lda; k++) {
+ float new_value = u_for_w[k]*v_s[k]*c_w;
+ l.total_new[k] += new_value;
+ word_weights[k] += new_value;
+ }
+ }
+ }
+ for (size_t k = 0; k < l.all->lda; k++) {
+ l.total_lambda[k] *= minuseta;
+ l.total_lambda[k] += l.total_new[k];
+ }
+
+ l.sorted_features.resize(0);
+
+ l.examples.erase();
+ l.doc_lengths.erase();
+ }
+
+ void learn(lda& l, base_learner& base, example& ec)
+ {
+ size_t num_ex = l.examples.size();
+ l.examples.push_back(&ec);
+ l.doc_lengths.push_back(0);
+ for (unsigned char* i = ec.indices.begin; i != ec.indices.end; i++) {
+ feature* f = ec.atomics[*i].begin;
+ for (; f != ec.atomics[*i].end; f++) {
+ index_feature temp = {(uint32_t)num_ex, *f};
+ l.sorted_features.push_back(temp);
+ l.doc_lengths[num_ex] += (int)f->x;
+ }
+ }
+ if (++num_ex == l.minibatch)
+ learn_batch(l);
+ }
+
+ // placeholder
+ void predict(lda& l, base_learner& base, example& ec)
+ {
+ learn(l, base, ec);
+ }
+
+ void end_pass(lda& l)
+ {
+ if (l.examples.size())
+ learn_batch(l);
+ }
+
+void end_examples(lda& l)
+{
+ for (size_t i = 0; i < l.all->length(); i++) {
+ weight* weights_for_w = & (l.all->reg.weight_vector[i << l.all->reg.stride_shift]);
+ float decay = fmin(1.0, exp(l.decay_levels.last() - l.decay_levels.end[(int)(-1- l.example_t +weights_for_w[l.all->lda])]));
+ for (size_t k = 0; k < l.all->lda; k++)
+ weights_for_w[k] *= decay;
+ }
+}
+
+ void finish_example(vw& all, lda&, example& ec)
+{}
+
+ void finish(lda& ld)
+ {
+ ld.sorted_features.~vector<index_feature>();
+ ld.Elogtheta.delete_v();
+ ld.decay_levels.delete_v();
+ ld.total_new.delete_v();
+ ld.examples.delete_v();
+ ld.total_lambda.delete_v();
+ ld.doc_lengths.delete_v();
+ ld.digammas.delete_v();
+ ld.v.delete_v();
+ }
+
+
+base_learner* setup(vw&all)
+{
+ new_options(all, "Lda options")
+ ("lda", po::value<uint32_t>(), "Run lda with <int> topics")
+ ("lda_alpha", po::value<float>()->default_value(0.1f), "Prior on sparsity of per-document topic weights")
+ ("lda_rho", po::value<float>()->default_value(0.1f), "Prior on sparsity of topic distributions")
+ ("lda_D", po::value<float>()->default_value(10000.), "Number of documents")
+ ("lda_epsilon", po::value<float>()->default_value(0.001f), "Loop convergence threshold")
+ ("minibatch", po::value<size_t>()->default_value(1), "Minibatch size, for LDA");
+ add_options(all);
+ po::variables_map& vm= all.vm;
+ if(!vm.count("lda"))
+ return NULL;
+ else
+ all.lda = vm["lda"].as<uint32_t>();
+
+ lda& ld = calloc_or_die<lda>();
+
+ ld.topics = all.lda;
+ ld.lda_alpha = vm["lda_alpha"].as<float>();
+ ld.lda_rho = vm["lda_rho"].as<float>();
+ ld.lda_D = vm["lda_D"].as<float>();
+ ld.lda_epsilon = vm["lda_epsilon"].as<float>();
+ ld.minibatch = vm["minibatch"].as<size_t>();
+ ld.sorted_features = vector<index_feature>();
+ ld.total_lambda_init = 0;
+ ld.all = &all;
+ ld.example_t = all.initial_t;
+
+ float temp = ceilf(logf((float)(all.lda*2+1)) / logf (2.f));
+ all.reg.stride_shift = (size_t)temp;
+ all.random_weights = true;
+ all.add_constant = false;
+
+ *all.file_options << " --lda " << all.lda;
+
+ if (all.eta > 1.)
+ {
+ cerr << "your learning rate is too high, setting it to 1" << endl;
+ all.eta = min(all.eta,1.f);
+ }
+
+ if (vm.count("minibatch")) {
+ size_t minibatch2 = next_pow2(ld.minibatch);
+ all.p->ring_size = all.p->ring_size > minibatch2 ? all.p->ring_size : minibatch2;
+ }
+
+ ld.v.resize(all.lda*ld.minibatch);
+
+ ld.decay_levels.push_back(0.f);
+
+ learner<lda>& l = init_learner(&ld, learn, 1 << all.reg.stride_shift);
+ l.set_predict(predict);
+ l.set_save_load(save_load);
+ l.set_finish_example(finish_example);
+ l.set_end_examples(end_examples);
+ l.set_end_pass(end_pass);
+ l.set_finish(finish);
+
+ return make_base(l);
+}
+}
diff --git a/vowpalwabbit/lda_core.h b/vowpalwabbit/lda_core.h
index 2a065783..e734fcad 100644
--- a/vowpalwabbit/lda_core.h
+++ b/vowpalwabbit/lda_core.h
@@ -4,6 +4,4 @@ individual contributors. All rights reserved. Released under a BSD
license as described in the file LICENSE.
*/
#pragma once
-namespace LDA{
- LEARNER::base_learner* setup(vw&, po::variables_map&);
-}
+namespace LDA{ LEARNER::base_learner* setup(vw&); }
diff --git a/vowpalwabbit/learner.h b/vowpalwabbit/learner.h
index 2649429c..b832306c 100644
--- a/vowpalwabbit/learner.h
+++ b/vowpalwabbit/learner.h
@@ -186,7 +186,7 @@ namespace LEARNER
template<class T>
learner<T>& init_learner(T* dat, base_learner* base,
void (*learn)(T&, base_learner&, example&),
- void (*predict)(T&, base_learner&, example&), size_t ws = 1)
+ void (*predict)(T&, base_learner&, example&), size_t ws)
{ //the reduction constructor, with separate learn and predict functions
learner<T>& ret = calloc_or_die<learner<T> >();
ret = *(learner<T>*)base;
diff --git a/vowpalwabbit/log_multi.cc b/vowpalwabbit/log_multi.cc
index 226376bd..e673b0d8 100644
--- a/vowpalwabbit/log_multi.cc
+++ b/vowpalwabbit/log_multi.cc
@@ -12,7 +12,6 @@ license as described in the file LICENSE.node
#include "reductions.h"
#include "simple_label.h"
#include "multiclass.h"
-#include "vw.h"
using namespace std;
using namespace LEARNER;
@@ -77,12 +76,11 @@ namespace LOG_MULTI
struct log_multi
{
uint32_t k;
- vw* all;
v_array<node> nodes;
- uint32_t max_predictors;
- uint32_t predictors_used;
+ size_t max_predictors;
+ size_t predictors_used;
bool progress;
uint32_t swap_resist;
@@ -199,7 +197,7 @@ namespace LOG_MULTI
b.nodes.push_back(init_node());
right_child = (uint32_t)b.nodes.size();
b.nodes.push_back(init_node());
- b.nodes[current].base_predictor = b.predictors_used++;
+ b.nodes[current].base_predictor = (uint32_t)b.predictors_used++;
}
else
{
@@ -319,11 +317,10 @@ namespace LOG_MULTI
void learn(log_multi& b, base_learner& base, example& ec)
{
// verify_min_dfs(b, b.nodes[0]);
-
- if (ec.l.multi.label == (uint32_t)-1 || !b.all->training || b.progress)
+ if (ec.l.multi.label == (uint32_t)-1 || b.progress)
predict(b,base,ec);
- if(b.all->training && (ec.l.multi.label != (uint32_t)-1) && !ec.test_only) //if training the tree
+ if((ec.l.multi.label != (uint32_t)-1) && !ec.test_only) //if training the tree
{
MULTICLASS::multiclass mc = ec.l.multi;
@@ -413,10 +410,10 @@ namespace LOG_MULTI
if (read)
for (uint32_t j = 1; j < temp; j++)
b.nodes.push_back(init_node());
- text_len = sprintf(buff, "max_predictors = %d ",b.max_predictors);
+ text_len = sprintf(buff, "max_predictors = %ld ",b.max_predictors);
bin_text_read_write_fixed(model_file,(char*)&b.max_predictors, sizeof(b.max_predictors), "", read, buff, text_len, text);
- text_len = sprintf(buff, "predictors_used = %d ",b.predictors_used);
+ text_len = sprintf(buff, "predictors_used = %ld ",b.predictors_used);
bin_text_read_write_fixed(model_file,(char*)&b.predictors_used, sizeof(b.predictors_used), "", read, buff, text_len, text);
text_len = sprintf(buff, "progress = %d ",b.progress);
@@ -496,20 +493,25 @@ namespace LOG_MULTI
}
}
- void finish_example(vw& all, log_multi&, example& ec) { MULTICLASS::finish_example(all, ec); }
-
- base_learner* setup(vw& all, po::variables_map& vm) //learner setup
+ base_learner* setup(vw& all) //learner setup
{
- log_multi& data = calloc_or_die<log_multi>();
-
- po::options_description opts("TXM Online options");
- opts.add_options()
+ new_options(all, "Logarithmic Time Multiclass options")
+ ("log_multi", po::value<size_t>(), "Use online tree for multiclass");
+ if (missing_required(all)) return NULL;
+ new_options(all)
("no_progress", "disable progressive validation")
- ("swap_resistance", po::value<uint32_t>(&(data.swap_resist))->default_value(4), "higher = more resistance to swap, default=4");
-
- vm = add_options(all, opts);
-
+ ("swap_resistance", po::value<uint32_t>(), "higher = more resistance to swap, default=4");
+ add_options(all);
+
+ po::variables_map& vm = all.vm;
+
+ log_multi& data = calloc_or_die<log_multi>();
data.k = (uint32_t)vm["log_multi"].as<size_t>();
+ data.swap_resist = 4;
+
+ if (vm.count("swap_resistance"))
+ data.swap_resist = vm["swap_resistance"].as<uint32_t>();
+
*all.file_options << " --log_multi " << data.k;
if (vm.count("no_progress"))
@@ -517,9 +519,6 @@ namespace LOG_MULTI
else
data.progress = true;
- data.all = &all;
- (all.p->lp) = MULTICLASS::mc_label;
-
string loss_function = "quantile";
float loss_parameter = 0.5;
delete(all.loss);
@@ -527,10 +526,11 @@ namespace LOG_MULTI
data.max_predictors = data.k - 1;
- learner<log_multi>& l = init_learner(&data, all.l, learn, predict, data.max_predictors);
+ learner<log_multi>& l = init_learner(&data, setup_base(all), learn, predict, data.max_predictors);
l.set_save_load(save_load_tree);
- l.set_finish_example(finish_example);
l.set_finish(finish);
+ l.set_finish_example(MULTICLASS::finish_example<log_multi>);
+ all.p->lp = MULTICLASS::mc_label;
init_tree(data);
diff --git a/vowpalwabbit/log_multi.h b/vowpalwabbit/log_multi.h
index 5e1ee3bf..0660334a 100644
--- a/vowpalwabbit/log_multi.h
+++ b/vowpalwabbit/log_multi.h
@@ -4,7 +4,4 @@ individual contributors. All rights reserved. Released under a BSD
license as described in the file LICENSE.
*/
#pragma once
-namespace LOG_MULTI
-{
- LEARNER::base_learner* setup(vw& all, po::variables_map& vm);
-}
+namespace LOG_MULTI { LEARNER::base_learner* setup(vw& all); }
diff --git a/vowpalwabbit/lrq.cc b/vowpalwabbit/lrq.cc
index 268815d5..13534375 100644
--- a/vowpalwabbit/lrq.cc
+++ b/vowpalwabbit/lrq.cc
@@ -187,24 +187,34 @@ namespace LRQ {
}
}
- base_learner* setup(vw& all, po::variables_map& vm)
+ base_learner* setup(vw& all)
{//parse and set arguments
+ new_options(all, "Lrq options")
+ ("lrq", po::value<vector<string> > (), "use low rank quadratic features");
+ if (missing_required(all)) return NULL;
+ new_options(all)
+ ("lrqdropout", "use dropout training for low rank quadratic features");
+ add_options(all);
+
+ if(!all.vm.count("lrq"))
+ return NULL;
+
LRQstate& lrq = calloc_or_die<LRQstate>();
size_t maxk = 0;
lrq.all = &all;
size_t random_seed = 0;
- if (vm.count("random_seed")) random_seed = vm["random_seed"].as<size_t> ();
+ if (all.vm.count("random_seed")) random_seed = all.vm["random_seed"].as<size_t> ();
lrq.initial_seed = lrq.seed = random_seed | 8675309;
- if (vm.count("lrqdropout"))
+ if (all.vm.count("lrqdropout"))
lrq.dropout = true;
else
lrq.dropout = false;
*all.file_options << " --lrqdropout ";
- lrq.lrpairs = vm["lrq"].as<vector<string> > ();
+ lrq.lrpairs = all.vm["lrq"].as<vector<string> > ();
for (vector<string>::iterator i = lrq.lrpairs.begin ();
i != lrq.lrpairs.end ();
@@ -242,8 +252,8 @@ namespace LRQ {
if(!all.quiet)
cerr<<endl;
- all.wpp = all.wpp * (1 + maxk);
- learner<LRQstate>& l = init_learner(&lrq, all.l, predict_or_learn<true>,
+ all.wpp = all.wpp * (uint32_t)(1 + maxk);
+ learner<LRQstate>& l = init_learner(&lrq, setup_base(all), predict_or_learn<true>,
predict_or_learn<false>, 1 + maxk);
l.set_end_pass(reset_seed);
diff --git a/vowpalwabbit/lrq.h b/vowpalwabbit/lrq.h
index 376bd6e5..b08e24f4 100644
--- a/vowpalwabbit/lrq.h
+++ b/vowpalwabbit/lrq.h
@@ -1,4 +1,2 @@
#pragma once
-namespace LRQ {
- LEARNER::base_learner* setup(vw& all, po::variables_map& vm);
-}
+namespace LRQ { LEARNER::base_learner* setup(vw& all); }
diff --git a/vowpalwabbit/main.cc b/vowpalwabbit/main.cc
index c7f40326..95e9d2f9 100644
--- a/vowpalwabbit/main.cc
+++ b/vowpalwabbit/main.cc
@@ -21,11 +21,11 @@ using namespace std;
int main(int argc, char *argv[])
{
try {
- vw *all = parse_args(argc, argv);
+ vw& all = parse_args(argc, argv);
struct timeb t_start, t_end;
ftime(&t_start);
- if (!all->quiet && !all->bfgs && !all->searchstr)
+ if (!all.quiet && !all.bfgs && !all.searchstr)
{
const char * header_fmt = "%-10s %-10s %10s %11s %8s %8s %8s\n";
fprintf(stderr, header_fmt,
@@ -36,67 +36,63 @@ int main(int argc, char *argv[])
cerr.precision(5);
}
- VW::start_parser(*all);
- LEARNER::generic_driver(*all);
- VW::end_parser(*all);
+ VW::start_parser(all);
+ LEARNER::generic_driver(all);
+ VW::end_parser(all);
ftime(&t_end);
double net_time = (int) (1000.0 * (t_end.time - t_start.time) + (t_end.millitm - t_start.millitm));
- if(!all->quiet && all->span_server != "")
+ if(!all.quiet && all.span_server != "")
cerr<<"Net time taken by process = "<<net_time/(double)(1000)<<" seconds\n";
- if(all->span_server != "") {
- float loss = (float)all->sd->sum_loss;
- all->sd->sum_loss = (double)accumulate_scalar(*all, all->span_server, loss);
- float weighted_examples = (float)all->sd->weighted_examples;
- all->sd->weighted_examples = (double)accumulate_scalar(*all, all->span_server, weighted_examples);
- float weighted_labels = (float)all->sd->weighted_labels;
- all->sd->weighted_labels = (double)accumulate_scalar(*all, all->span_server, weighted_labels);
- float weighted_unlabeled_examples = (float)all->sd->weighted_unlabeled_examples;
- all->sd->weighted_unlabeled_examples = (double)accumulate_scalar(*all, all->span_server, weighted_unlabeled_examples);
- float example_number = (float)all->sd->example_number;
- all->sd->example_number = (uint64_t)accumulate_scalar(*all, all->span_server, example_number);
- float total_features = (float)all->sd->total_features;
- all->sd->total_features = (uint64_t)accumulate_scalar(*all, all->span_server, total_features);
+ if(all.span_server != "") {
+ float loss = (float)all.sd->sum_loss;
+ all.sd->sum_loss = (double)accumulate_scalar(all, all.span_server, loss);
+ float weighted_examples = (float)all.sd->weighted_examples;
+ all.sd->weighted_examples = (double)accumulate_scalar(all, all.span_server, weighted_examples);
+ float weighted_labels = (float)all.sd->weighted_labels;
+ all.sd->weighted_labels = (double)accumulate_scalar(all, all.span_server, weighted_labels);
+ float weighted_unlabeled_examples = (float)all.sd->weighted_unlabeled_examples;
+ all.sd->weighted_unlabeled_examples = (double)accumulate_scalar(all, all.span_server, weighted_unlabeled_examples);
+ float example_number = (float)all.sd->example_number;
+ all.sd->example_number = (uint64_t)accumulate_scalar(all, all.span_server, example_number);
+ float total_features = (float)all.sd->total_features;
+ all.sd->total_features = (uint64_t)accumulate_scalar(all, all.span_server, total_features);
}
- if (!all->quiet)
+ if (!all.quiet)
{
cerr.precision(6);
cerr << endl << "finished run";
- if(all->current_pass == 0)
- cerr << endl << "number of examples = " << all->sd->example_number;
+ if(all.current_pass == 0)
+ cerr << endl << "number of examples = " << all.sd->example_number;
else{
- cerr << endl << "number of examples per pass = " << all->sd->example_number / all->current_pass;
- cerr << endl << "passes used = " << all->current_pass;
- }
- cerr << endl << "weighted example sum = " << all->sd->weighted_examples;
- cerr << endl << "weighted label sum = " << all->sd->weighted_labels;
- if(all->holdout_set_off || (all->sd->holdout_best_loss == FLT_MAX))
- {
- cerr << endl << "average loss = " << all->sd->sum_loss / all->sd->weighted_examples;
- }
- else
- {
- cerr << endl << "average loss = " << all->sd->holdout_best_loss << " h";
+ cerr << endl << "number of examples per pass = " << all.sd->example_number / all.current_pass;
+ cerr << endl << "passes used = " << all.current_pass;
}
+ cerr << endl << "weighted example sum = " << all.sd->weighted_examples;
+ cerr << endl << "weighted label sum = " << all.sd->weighted_labels;
+ if(all.holdout_set_off || (all.sd->holdout_best_loss == FLT_MAX))
+ cerr << endl << "average loss = " << all.sd->sum_loss / all.sd->weighted_examples;
+ else
+ cerr << endl << "average loss = " << all.sd->holdout_best_loss << " h";
float best_constant; float best_constant_loss;
- if (get_best_constant(*all, best_constant, best_constant_loss))
- {
+ if (get_best_constant(all, best_constant, best_constant_loss))
+ {
cerr << endl << "best constant = " << best_constant;
if (best_constant_loss != FLT_MIN)
- cerr << endl << "best constant's loss = " << best_constant_loss;
- }
-
- cerr << endl << "total feature number = " << all->sd->total_features;
- if (all->sd->queries > 0)
- cerr << endl << "total queries = " << all->sd->queries << endl;
+ cerr << endl << "best constant's loss = " << best_constant_loss;
+ }
+
+ cerr << endl << "total feature number = " << all.sd->total_features;
+ if (all.sd->queries > 0)
+ cerr << endl << "total queries = " << all.sd->queries << endl;
cerr << endl;
}
- VW::finish(*all);
+ VW::finish(all);
} catch (exception& e) {
// vw is implemented as a library, so we use 'throw exception()'
// error 'handling' everywhere. To reduce stderr pollution
diff --git a/vowpalwabbit/memory.cc b/vowpalwabbit/memory.cc
deleted file mode 100644
index 7e40bf71..00000000
--- a/vowpalwabbit/memory.cc
+++ /dev/null
@@ -1,8 +0,0 @@
-#include <stdlib.h>
-
-void free_it(void* ptr)
-{
- if (ptr != NULL)
- free(ptr);
-}
-
diff --git a/vowpalwabbit/memory.h b/vowpalwabbit/memory.h
index 6d67d51e..af4c0c4c 100644
--- a/vowpalwabbit/memory.h
+++ b/vowpalwabbit/memory.h
@@ -1,5 +1,4 @@
#pragma once
-
#include <stdlib.h>
#include <iostream>
@@ -18,9 +17,6 @@ T* calloc_or_die(size_t nmemb)
}
template<class T> T& calloc_or_die()
-{
- return *calloc_or_die<T>(1);
-}
-
+{ return *calloc_or_die<T>(1); }
-void free_it(void* ptr);
+inline void free_it(void* ptr) { if (ptr != NULL) free(ptr); }
diff --git a/vowpalwabbit/mf.cc b/vowpalwabbit/mf.cc
index 4e00be8d..c2e7575a 100644
--- a/vowpalwabbit/mf.cc
+++ b/vowpalwabbit/mf.cc
@@ -22,7 +22,7 @@ namespace MF {
struct mf {
vector<string> pairs;
- uint32_t rank;
+ size_t rank;
uint32_t increment;
@@ -188,22 +188,23 @@ void finish(mf& o) {
o.sub_predictions.delete_v();
}
-
-base_learner* setup(vw& all, po::variables_map& vm) {
- mf* data = new mf;
-
- // copy global data locally
- data->all = &all;
- data->rank = (uint32_t)vm["new_mf"].as<size_t>();
+ base_learner* setup(vw& all) {
+ new_options(all, "MF options")
+ ("new_mf", po::value<size_t>(), "rank for reduction-based matrix factorization");
+ if(missing_required(all)) return NULL;
+
+ mf& data = calloc_or_die<mf>();
+ data.all = &all;
+ data.rank = (uint32_t)all.vm["new_mf"].as<size_t>();
// store global pairs in local data structure and clear global pairs
// for eventual calls to base learner
- data->pairs = all.pairs;
+ data.pairs = all.pairs;
all.pairs.clear();
all.random_positive_weights = true;
- learner<mf>& l = init_learner(data, all.l, learn, predict<false>, 2*data->rank+1);
+ learner<mf>& l = init_learner(&data, setup_base(all), learn, predict<false>, 2*data.rank+1);
l.set_finish(finish);
return make_base(l);
}
diff --git a/vowpalwabbit/mf.h b/vowpalwabbit/mf.h
index 99643601..6d4be4f3 100644
--- a/vowpalwabbit/mf.h
+++ b/vowpalwabbit/mf.h
@@ -4,12 +4,4 @@ individual contributors. All rights reserved. Released under a BSD
license as described in the file LICENSE.
*/
#pragma once
-#include <math.h>
-#include "example.h"
-#include "parse_regressor.h"
-#include "parser.h"
-#include "gd.h"
-
-namespace MF{
- LEARNER::base_learner* setup(vw& all, po::variables_map& vm);
-}
+namespace MF{ LEARNER::base_learner* setup(vw& all); }
diff --git a/vowpalwabbit/multiclass.h b/vowpalwabbit/multiclass.h
index f48efbfc..14f5d443 100644
--- a/vowpalwabbit/multiclass.h
+++ b/vowpalwabbit/multiclass.h
@@ -20,6 +20,8 @@ namespace MULTICLASS
void finish_example(vw& all, example& ec);
+ template <class T> void finish_example(vw& all, T&, example& ec) { finish_example(all, ec); }
+
inline int label_is_test(multiclass* ld)
{ return ld->label == (uint32_t)-1; }
}
diff --git a/vowpalwabbit/nn.cc b/vowpalwabbit/nn.cc
index bfab6009..bcd42092 100644
--- a/vowpalwabbit/nn.cc
+++ b/vowpalwabbit/nn.cc
@@ -308,19 +308,20 @@ CONVERSE: // That's right, I'm using goto. So sue me.
free (n.output_layer.atomics[nn_output_namespace].begin);
}
- base_learner* setup(vw& all, po::variables_map& vm)
+ base_learner* setup(vw& all)
{
- nn& n = calloc_or_die<nn>();
- n.all = &all;
-
- po::options_description nn_opts("NN options");
- nn_opts.add_options()
+ new_options(all, "Neural Network options")
+ ("nn", po::value<size_t>(), "Use sigmoidal feedforward network with <k> hidden units");
+ if(missing_required(all)) return NULL;
+ new_options(all)
("inpass", "Train or test sigmoidal feedforward network with input passthrough.")
("dropout", "Train or test sigmoidal feedforward network using dropout.")
("meanfield", "Train or test sigmoidal feedforward network using mean field.");
+ add_options(all);
- vm = add_options(all, nn_opts);
-
+ po::variables_map& vm = all.vm;
+ nn& n = calloc_or_die<nn>();
+ n.all = &all;
//first parse for number of hidden units
n.k = (uint32_t)vm["nn"].as<size_t>();
*all.file_options << " --nn " << n.k;
@@ -364,8 +365,10 @@ CONVERSE: // That's right, I'm using goto. So sue me.
n.xsubi = vm["random_seed"].as<size_t>();
n.save_xsubi = n.xsubi;
- n.increment = all.l->increment;//Indexing of output layer is odd.
- learner<nn>& l = init_learner(&n, all.l, predict_or_learn<true>,
+
+ base_learner* base = setup_base(all);
+ n.increment = base->increment;//Indexing of output layer is odd.
+ learner<nn>& l = init_learner(&n, base, predict_or_learn<true>,
predict_or_learn<false>, n.k+1);
l.set_finish(finish);
l.set_finish_example(finish_example);
diff --git a/vowpalwabbit/nn.h b/vowpalwabbit/nn.h
index 52e08f46..3b000433 100644
--- a/vowpalwabbit/nn.h
+++ b/vowpalwabbit/nn.h
@@ -4,10 +4,4 @@ individual contributors. All rights reserved. Released under a BSD
license as described in the file LICENSE.
*/
#pragma once
-#include "global_data.h"
-#include "parse_args.h"
-
-namespace NN
-{
- LEARNER::base_learner* setup(vw& all, po::variables_map& vm);
-}
+namespace NN { LEARNER::base_learner* setup(vw& all); }
diff --git a/vowpalwabbit/noop.cc b/vowpalwabbit/noop.cc
index 0c883a8c..892480ff 100644
--- a/vowpalwabbit/noop.cc
+++ b/vowpalwabbit/noop.cc
@@ -9,7 +9,11 @@ license as described in the file LICENSE.
namespace NOOP {
void learn(char&, LEARNER::base_learner&, example&) {}
-
+
LEARNER::base_learner* setup(vw& all)
- { return &LEARNER::init_learner<char>(NULL, learn, 1); }
+ {
+ new_options(all, "Noop options") ("noop","do no learning");
+ if(missing_required(all)) return NULL;
+
+ return &LEARNER::init_learner<char>(NULL, learn, 1); }
}
diff --git a/vowpalwabbit/noop.h b/vowpalwabbit/noop.h
index 5220e1ee..ed660870 100644
--- a/vowpalwabbit/noop.h
+++ b/vowpalwabbit/noop.h
@@ -4,6 +4,4 @@ individual contributors. All rights reserved. Released under a BSD
license as described in the file LICENSE.
*/
#pragma once
-namespace NOOP {
- LEARNER::base_learner* setup(vw&);
-}
+namespace NOOP { LEARNER::base_learner* setup(vw& all); }
diff --git a/vowpalwabbit/oaa.cc b/vowpalwabbit/oaa.cc
index 2328b00d..185264f5 100644
--- a/vowpalwabbit/oaa.cc
+++ b/vowpalwabbit/oaa.cc
@@ -7,7 +7,6 @@ license as described in the file LICENSE.
#include "multiclass.h"
#include "simple_label.h"
#include "reductions.h"
-#include "vw.h"
namespace OAA {
struct oaa{
@@ -60,21 +59,23 @@ namespace OAA {
o.all->print_text(o.all->raw_prediction, outputStringStream.str(), ec.tag);
}
- void finish_example(vw& all, oaa&, example& ec) { MULTICLASS::finish_example(all, ec); }
-
- LEARNER::base_learner* setup(vw& all, po::variables_map& vm)
+ LEARNER::base_learner* setup(vw& all)
{
+ new_options(all, "One-against-all options")
+ ("oaa", po::value<size_t>(), "Use one-against-all multiclass learning with <k> labels");
+ if(missing_required(all)) return NULL;
+
oaa& data = calloc_or_die<oaa>();
- data.k = vm["oaa"].as<size_t>();
+ data.k = all.vm["oaa"].as<size_t>();
data.shouldOutput = all.raw_prediction > 0;
data.all = &all;
*all.file_options << " --oaa " << data.k;
- all.p->lp = MULTICLASS::mc_label;
- LEARNER::learner<oaa>& l = init_learner(&data, all.l, predict_or_learn<true>,
- predict_or_learn<false>, data.k);
- l.set_finish_example(finish_example);
+ LEARNER::learner<oaa>& l = init_learner(&data, setup_base(all), predict_or_learn<true>,
+ predict_or_learn<false>, data.k);
+ l.set_finish_example(MULTICLASS::finish_example<oaa>);
+ all.p->lp = MULTICLASS::mc_label;
return make_base(l);
}
}
diff --git a/vowpalwabbit/oaa.h b/vowpalwabbit/oaa.h
index de1b08ab..2bc46649 100644
--- a/vowpalwabbit/oaa.h
+++ b/vowpalwabbit/oaa.h
@@ -4,5 +4,4 @@ individual contributors. All rights reserved. Released under a BSD
license as described in the file LICENSE.
*/
#pragma once
-namespace OAA
-{ LEARNER::base_learner* setup(vw& all, po::variables_map& vm); }
+namespace OAA { LEARNER::base_learner* setup(vw& all); }
diff --git a/vowpalwabbit/parse_args.cc b/vowpalwabbit/parse_args.cc
index a760261f..c8ec126d 100644
--- a/vowpalwabbit/parse_args.cc
+++ b/vowpalwabbit/parse_args.cc
@@ -17,6 +17,7 @@ license as described in the file LICENSE.
#include "network.h"
#include "global_data.h"
#include "nn.h"
+#include "gd.h"
#include "cbify.h"
#include "oaa.h"
#include "rand48.h"
@@ -177,18 +178,17 @@ void parse_affix_argument(vw&all, string str) {
free(cstr);
}
-void parse_diagnostics(vw& all, po::variables_map& vm, int argc)
+void parse_diagnostics(vw& all, int argc)
{
- po::options_description diag_opt("Diagnostic options");
-
- diag_opt.add_options()
+ new_options(all, "Diagnostic options")
("version","Version information")
("audit,a", "print weights of features")
("progress,P", po::value< string >(), "Progress update frequency. int: additive, float: multiplicative")
("quiet", "Don't output disgnostics and progress updates")
("help,h","Look here: http://hunch.net/~vw/ and click on Tutorial.");
-
- vm = add_options(all, diag_opt);
+ add_options(all);
+
+ po::variables_map& vm = all.vm;
if (vm.count("version")) {
/* upon direct query for version -- spit it out to stdout */
@@ -245,11 +245,9 @@ void parse_diagnostics(vw& all, po::variables_map& vm, int argc)
}
}
-void parse_source(vw& all, po::variables_map& vm)
+void parse_source(vw& all)
{
- po::options_description in_opt("Input options");
-
- in_opt.add_options()
+ new_options(all, "Input options")
("data,d", po::value< string >(), "Example Set")
("daemon", "persistent daemon mode on port 26542")
("port", po::value<size_t>(),"port to listen on; use 0 to pick unused port")
@@ -261,19 +259,18 @@ void parse_source(vw& all, po::variables_map& vm)
("kill_cache,k", "do not reuse existing cache: create a new one always")
("compressed", "use gzip format whenever possible. If a cache file is being created, this option creates a compressed cache file. A mixture of raw-text & compressed inputs are supported with autodetection.")
("no_stdin", "do not default to reading from stdin");
-
- vm = add_options(all, in_opt);
+ add_options(all);
// Be friendly: if -d was left out, treat positional param as data file
po::positional_options_description p;
p.add("data", -1);
- vm = po::variables_map();
po::parsed_options pos = po::command_line_parser(all.args).
style(po::command_line_style::default_style ^ po::command_line_style::allow_guessing).
options(all.opts).positional(p).run();
- vm = po::variables_map();
- po::store(pos, vm);
+ all.vm = po::variables_map();
+ po::store(pos, all.vm);
+ po::variables_map& vm = all.vm;
//begin input source
if (vm.count("no_stdin"))
@@ -315,10 +312,9 @@ void parse_source(vw& all, po::variables_map& vm)
}
}
-void parse_feature_tweaks(vw& all, po::variables_map& vm)
+void parse_feature_tweaks(vw& all)
{
- po::options_description feature_opt("Feature options");
- feature_opt.add_options()
+ new_options(all, "Feature options")
("hash", po::value< string > (), "how to hash the features. Available options: strings, all")
("ignore", po::value< vector<unsigned char> >(), "ignore namespaces beginning with character <arg>")
("keep", po::value< vector<unsigned char> >(), "keep namespaces beginning with character <arg>")
@@ -335,8 +331,9 @@ void parse_feature_tweaks(vw& all, po::variables_map& vm)
("q:", po::value< string >(), ": corresponds to a wildcard for all printable characters")
("cubic", po::value< vector<string> > (),
"Create and use cubic features");
+ add_options(all);
- vm = add_options(all, feature_opt);
+ po::variables_map& vm = all.vm;
//feature manipulation
string hash_function("strings");
@@ -545,11 +542,9 @@ void parse_feature_tweaks(vw& all, po::variables_map& vm)
all.add_constant = false;
}
-void parse_example_tweaks(vw& all, po::variables_map& vm)
+void parse_example_tweaks(vw& all)
{
- po::options_description example_opts("Example options");
-
- example_opts.add_options()
+ new_options(all, "Example options")
("testonly,t", "Ignore label information and just test")
("holdout_off", "no holdout data in multiple passes")
("holdout_period", po::value<uint32_t>(&(all.holdout_period)), "holdout period for test only, default 10")
@@ -565,9 +560,9 @@ void parse_example_tweaks(vw& all, po::variables_map& vm)
("quantile_tau", po::value<float>()->default_value(0.5), "Parameter \\tau associated with Quantile loss. Defaults to 0.5")
("l1", po::value<float>(&(all.l1_lambda)), "l_1 lambda")
("l2", po::value<float>(&(all.l2_lambda)), "l_2 lambda");
+ add_options(all);
- vm = add_options(all, example_opts);
-
+ po::variables_map& vm = all.vm;
if (vm.count("testonly") || all.eta == 0.)
{
if (!all.quiet)
@@ -621,17 +616,14 @@ void parse_example_tweaks(vw& all, po::variables_map& vm)
}
}
-void parse_output_preds(vw& all, po::variables_map& vm)
+void parse_output_preds(vw& all)
{
- po::options_description out_opt("Output options");
-
- out_opt.add_options()
+ new_options(all, "Output options")
("predictions,p", po::value< string >(), "File to output predictions to")
- ("raw_predictions,r", po::value< string >(), "File to output unnormalized predictions to")
- ;
-
- vm = add_options(all, out_opt);
+ ("raw_predictions,r", po::value< string >(), "File to output unnormalized predictions to");
+ add_options(all);
+ po::variables_map& vm = all.vm;
if (vm.count("predictions")) {
if (!all.quiet)
cerr << "predictions = " << vm["predictions"].as< string >() << endl;
@@ -676,21 +668,19 @@ void parse_output_preds(vw& all, po::variables_map& vm)
}
}
-void parse_output_model(vw& all, po::variables_map& vm)
+void parse_output_model(vw& all)
{
- po::options_description output_model("Output model");
-
- output_model.add_options()
+ new_options(all, "Output model")
("final_regressor,f", po::value< string >(), "Final regressor")
("readable_model", po::value< string >(), "Output human-readable final regressor with numeric features")
("invert_hash", po::value< string >(), "Output human-readable final regressor with feature names. Computationally expensive.")
("save_resume", "save extra state so learning can be resumed later with new data")
("save_per_pass", "Save the model after every pass over data")
("output_feature_regularizer_binary", po::value< string >(&(all.per_feature_regularizer_output)), "Per feature regularization output file")
- ("output_feature_regularizer_text", po::value< string >(&(all.per_feature_regularizer_text)), "Per feature regularization output file, in text");
-
- vm = add_options(all, output_model);
+ ("output_feature_regularizer_text", po::value< string >(&(all.per_feature_regularizer_text)), "Per feature regularization output file, in text");
+ add_options(all);
+ po::variables_map& vm = all.vm;
if (vm.count("final_regressor")) {
all.final_regressor_name = vm["final_regressor"].as<string>();
if (!all.quiet)
@@ -714,68 +704,22 @@ void parse_output_model(vw& all, po::variables_map& vm)
all.save_resume = true;
}
-void parse_base_algorithm(vw& all, po::variables_map& vm)
-{
- //base learning algorithm.
- po::options_description base_opt("base algorithms (these are exclusive)");
-
- base_opt.add_options()
- ("sgd", "use regular stochastic gradient descent update.")
- ("ftrl", "use ftrl-proximal optimization")
- ("adaptive", "use adaptive, individual learning rates.")
- ("invariant", "use safe/importance aware updates.")
- ("normalized", "use per feature normalized updates")
- ("exact_adaptive_norm", "use current default invariant normalized adaptive update rule")
- ("bfgs", "use bfgs optimization")
- ("lda", po::value<uint32_t>(&(all.lda)), "Run lda with <int> topics")
- ("rank", po::value<uint32_t>(&(all.rank)), "rank for matrix factorization.")
- ("noop","do no learning")
- ("print","print examples")
- ("ksvm", "kernel svm")
- ("sendto", po::value< vector<string> >(), "send examples to <host>");
-
- vm = add_options(all, base_opt);
-
- if (vm.count("bfgs") || vm.count("conjugate_gradient"))
- all.l = BFGS::setup(all, vm);
- else if (vm.count("lda"))
- all.l = LDA::setup(all, vm);
- else if (vm.count("ftrl"))
- all.l = FTRL::setup(all, vm);
- else if (vm.count("noop"))
- all.l = NOOP::setup(all);
- else if (vm.count("print"))
- all.l = PRINT::setup(all);
- else if (all.rank > 0)
- all.l = GDMF::setup(all, vm);
- else if (vm.count("sendto"))
- all.l = SENDER::setup(all, vm, all.pairs);
- else if (vm.count("ksvm")) {
- all.l = KSVM::setup(all, vm);
- }
- else
- {
- all.l = GD::setup(all, vm);
- all.scorer = all.l;
- }
-}
-
-void load_input_model(vw& all, po::variables_map& vm, io_buf& io_temp)
+void load_input_model(vw& all, io_buf& io_temp)
{
// Need to see if we have to load feature mask first or second.
// -i and -mask are from same file, load -i file first so mask can use it
- if (vm.count("feature_mask") && vm.count("initial_regressor")
- && vm["feature_mask"].as<string>() == vm["initial_regressor"].as< vector<string> >()[0]) {
+ if (all.vm.count("feature_mask") && all.vm.count("initial_regressor")
+ && all.vm["feature_mask"].as<string>() == all.vm["initial_regressor"].as< vector<string> >()[0]) {
// load rest of regressor
all.l->save_load(io_temp, true, false);
io_temp.close_file();
// set the mask, which will reuse -i file we just loaded
- parse_mask_regressor_args(all, vm);
+ parse_mask_regressor_args(all);
}
else {
// load mask first
- parse_mask_regressor_args(all, vm);
+ parse_mask_regressor_args(all);
// load rest of regressor
all.l->save_load(io_temp, true, false);
@@ -783,166 +727,51 @@ void load_input_model(vw& all, po::variables_map& vm, io_buf& io_temp)
}
}
-void parse_scorer_reductions(vw& all, po::variables_map& vm)
+LEARNER::base_learner* setup_base(vw& all)
{
- po::options_description score_mod_opt("Score modifying options (can be combined)");
-
- score_mod_opt.add_options()
- ("nn", po::value<size_t>(), "Use sigmoidal feedforward network with <k> hidden units")
- ("new_mf", po::value<size_t>(), "rank for reduction-based matrix factorization")
- ("autolink", po::value<size_t>(), "create link function with polynomial d")
- ("lrq", po::value<vector<string> > (), "use low rank quadratic features")
- ("lrqdropout", "use dropout training for low rank quadratic features")
- ("stage_poly", "use stagewise polynomial feature learning")
- ("active", "enable active learning");
-
- vm = add_options(all, score_mod_opt);
-
- if (vm.count("active"))
- all.l = ACTIVE::setup(all,vm);
-
- if(vm.count("nn"))
- all.l = NN::setup(all, vm);
-
- if (vm.count("new_mf"))
- all.l = MF::setup(all, vm);
-
- if(vm.count("autolink"))
- all.l = ALINK::setup(all, vm);
-
- if (vm.count("lrq"))
- all.l = LRQ::setup(all, vm);
-
- if (vm.count("stage_poly"))
- all.l = StagewisePoly::setup(all, vm);
-
- all.l = Scorer::setup(all, vm);
+ LEARNER::base_learner* ret = all.reduction_stack.pop()(all);
+ if (ret == NULL)
+ return setup_base(all);
+ else
+ return ret;
}
-LEARNER::base_learner* exclusive_setup(vw& all, po::variables_map& vm, bool& score_consumer, LEARNER::base_learner* (*setup)(vw&, po::variables_map&))
+void parse_reductions(vw& all)
{
- if (score_consumer) { cerr << "error: cannot specify multiple direct score consumers" << endl; throw exception(); }
- score_consumer = true;
- return setup(all, vm);
-}
-
-void parse_score_users(vw& all, po::variables_map& vm, bool& got_cs)
-{
- po::options_description multiclass_opt("Score user options (these are exclusive)");
- multiclass_opt.add_options()
- ("top", po::value<size_t>(), "top k recommendation")
- ("binary", "report loss as binary classification on -1,1")
- ("oaa", po::value<size_t>(), "Use one-against-all multiclass learning with <k> labels")
- ("ect", po::value<size_t>(), "Use error correcting tournament with <k> labels")
- ("log_multi", po::value<size_t>(), "Use online tree for multiclass")
- ("csoaa", po::value<size_t>(), "Use one-against-all multiclass learning with <k> costs")
- ("csoaa_ldf", po::value<string>(), "Use one-against-all multiclass learning with label dependent features. Specify singleline or multiline.")
- ("wap_ldf", po::value<string>(), "Use weighted all-pairs multiclass learning with label dependent features. Specify singleline or multiline.")
- ;
-
- vm = add_options(all, multiclass_opt);
- bool score_consumer = false;
-
- if(vm.count("top"))
- all.l = exclusive_setup(all, vm, score_consumer, TOPK::setup);
-
- if (vm.count("binary"))
- all.l = exclusive_setup(all, vm, score_consumer, BINARY::setup);
-
- if (vm.count("oaa"))
- all.l = exclusive_setup(all, vm, score_consumer, OAA::setup);
-
- if (vm.count("ect"))
- all.l = exclusive_setup(all, vm, score_consumer, ECT::setup);
-
- if(vm.count("csoaa")) {
- all.l = exclusive_setup(all, vm, score_consumer, CSOAA::setup);
- all.cost_sensitive = all.l;
- got_cs = true;
- }
-
- if(vm.count("log_multi")){
- all.l = exclusive_setup(all, vm, score_consumer, LOG_MULTI::setup);
- }
-
- if(vm.count("csoaa_ldf") || vm.count("csoaa_ldf")) {
- all.l = exclusive_setup(all, vm, score_consumer, CSOAA_AND_WAP_LDF::setup);
- all.cost_sensitive = all.l;
- got_cs = true;
- }
-
- if(vm.count("wap_ldf") || vm.count("wap_ldf") ) {
- all.l = exclusive_setup(all, vm, score_consumer, CSOAA_AND_WAP_LDF::setup);
- all.cost_sensitive = all.l;
- got_cs = true;
- }
-}
-
-void parse_cb(vw& all, po::variables_map& vm, bool& got_cs, bool& got_cb)
-{
- po::options_description cb_opts("Contextual Bandit options");
-
- cb_opts.add_options()
- ("cb", po::value<size_t>(), "Use contextual bandit learning with <k> costs")
- ("cbify", po::value<size_t>(), "Convert multiclass on <k> classes into a contextual bandit problem and solve");
-
- vm = add_options(all,cb_opts);
-
- if( vm.count("cb"))
- {
- if(!got_cs) {
- if( vm.count("cb") ) vm.insert(pair<string,po::variable_value>(string("csoaa"),vm["cb"]));
- else vm.insert(pair<string,po::variable_value>(string("csoaa"),vm["cb"]));
-
- all.l = CSOAA::setup(all, vm); // default to CSOAA unless wap is specified
- all.cost_sensitive = all.l;
- got_cs = true;
- }
-
- all.l = CB_ALGS::setup(all, vm);
- got_cb = true;
- }
-
- if (vm.count("cbify"))
- {
- if(!got_cs) {
- vm.insert(pair<string,po::variable_value>(string("csoaa"),vm["cbify"]));
-
- all.l = CSOAA::setup(all, vm); // default to CSOAA unless wap is specified
- all.cost_sensitive = all.l;
- got_cs = true;
- }
-
- if (!got_cb) {
- vm.insert(pair<string,po::variable_value>(string("cb"),vm["cbify"]));
- all.l = CB_ALGS::setup(all, vm);
- got_cb = true;
- }
-
- all.l = CBIFY::setup(all, vm);
- }
-}
-
-void parse_search(vw& all, po::variables_map& vm, bool& got_cs, bool& got_cb)
-{
- po::options_description search_opts("Search");
-
- search_opts.add_options()
- ("search", po::value<size_t>(), "use search-based structured prediction, argument=maximum action id or 0 for LDF");
-
- vm = add_options(all,search_opts);
-
- if (vm.count("search")) {
- if (!got_cs && !got_cb) {
- if( vm.count("search") ) vm.insert(pair<string,po::variable_value>(string("csoaa"),vm["search"]));
- else vm.insert(pair<string,po::variable_value>(string("csoaa"),vm["search"]));
-
- all.l = CSOAA::setup(all, vm); // default to CSOAA unless others have been specified
- all.cost_sensitive = all.l;
- got_cs = true;
- }
- all.l = Search::setup(all, vm);
- }
+ //Base algorithms
+ all.reduction_stack.push_back(GD::setup);
+ all.reduction_stack.push_back(KSVM::setup);
+ all.reduction_stack.push_back(FTRL::setup);
+ all.reduction_stack.push_back(SENDER::setup);
+ all.reduction_stack.push_back(GDMF::setup);
+ all.reduction_stack.push_back(PRINT::setup);
+ all.reduction_stack.push_back(NOOP::setup);
+ all.reduction_stack.push_back(LDA::setup);
+ all.reduction_stack.push_back(BFGS::setup);
+
+ //Score Users
+ all.reduction_stack.push_back(ACTIVE::setup);
+ all.reduction_stack.push_back(NN::setup);
+ all.reduction_stack.push_back(MF::setup);
+ all.reduction_stack.push_back(ALINK::setup);
+ all.reduction_stack.push_back(LRQ::setup);
+ all.reduction_stack.push_back(StagewisePoly::setup);
+ all.reduction_stack.push_back(Scorer::setup);
+
+ //Reductions
+ all.reduction_stack.push_back(BINARY::setup);
+ all.reduction_stack.push_back(TOPK::setup);
+ all.reduction_stack.push_back(OAA::setup);
+ all.reduction_stack.push_back(ECT::setup);
+ all.reduction_stack.push_back(LOG_MULTI::setup);
+ all.reduction_stack.push_back(CSOAA::setup);
+ all.reduction_stack.push_back(CSOAA_AND_WAP_LDF::setup);
+ all.reduction_stack.push_back(CB_ALGS::setup);
+ all.reduction_stack.push_back(CBIFY::setup);
+ all.reduction_stack.push_back(Search::setup);
+ all.reduction_stack.push_back(BS::setup);
+
+ all.l = setup_base(all);
}
void add_to_args(vw& all, int argc, char* argv[])
@@ -951,143 +780,107 @@ void add_to_args(vw& all, int argc, char* argv[])
all.args.push_back(string(argv[i]));
}
-vw* parse_args(int argc, char *argv[])
+vw& parse_args(int argc, char *argv[])
{
- vw* all = new vw();
+ vw& all = *(new vw());
- add_to_args(*all, argc, argv);
+ add_to_args(all, argc, argv);
size_t random_seed = 0;
- all->program_name = argv[0];
+ all.program_name = argv[0];
- po::options_description desc("VW options");
-
- desc.add_options()
+ new_options(all, "VW options")
("random_seed", po::value<size_t>(&random_seed), "seed random number generator")
- ("ring_size", po::value<size_t>(&(all->p->ring_size)), "size of example ring");
-
- po::options_description update_opt("Update options");
+ ("ring_size", po::value<size_t>(&(all.p->ring_size)), "size of example ring");
+ add_options(all);
- update_opt.add_options()
- ("learning_rate,l", po::value<float>(&(all->eta)), "Set learning rate")
- ("power_t", po::value<float>(&(all->power_t)), "t power value")
- ("decay_learning_rate", po::value<float>(&(all->eta_decay_rate)),
+ new_options(all, "Update options")
+ ("learning_rate,l", po::value<float>(&(all.eta)), "Set learning rate")
+ ("power_t", po::value<float>(&(all.power_t)), "t power value")
+ ("decay_learning_rate", po::value<float>(&(all.eta_decay_rate)),
"Set Decay factor for learning_rate between passes")
- ("initial_t", po::value<double>(&((all->sd->t))), "initial t value")
- ("feature_mask", po::value< string >(), "Use existing regressor to determine which parameters may be updated. If no initial_regressor given, also used for initial weights.")
- ;
-
- po::options_description weight_opt("Weight options");
+ ("initial_t", po::value<double>(&((all.sd->t))), "initial t value")
+ ("feature_mask", po::value< string >(), "Use existing regressor to determine which parameters may be updated. If no initial_regressor given, also used for initial weights.");
+ add_options(all);
- weight_opt.add_options()
+ new_options(all, "Weight options")
("initial_regressor,i", po::value< vector<string> >(), "Initial regressor(s)")
- ("initial_weight", po::value<float>(&(all->initial_weight)), "Set all weights to an initial value of 1.")
- ("random_weights", po::value<bool>(&(all->random_weights)), "make initial weights random")
- ("input_feature_regularizer", po::value< string >(&(all->per_feature_regularizer_input)), "Per feature regularization input file")
- ;
-
- po::options_description cluster_opt("Parallelization options");
- cluster_opt.add_options()
- ("span_server", po::value<string>(&(all->span_server)), "Location of server for setting up spanning tree")
- ("unique_id", po::value<size_t>(&(all->unique_id)),"unique id used for cluster parallel jobs")
- ("total", po::value<size_t>(&(all->total)),"total number of nodes used in cluster parallel job")
- ("node", po::value<size_t>(&(all->node)),"node number in cluster parallel job")
- ;
-
- po::options_description other_opt("Other options");
- other_opt.add_options()
- ("bootstrap,B", po::value<size_t>(), "bootstrap mode with k rounds by online importance resampling")
- ;
-
- desc.add(update_opt)
- .add(weight_opt)
- .add(cluster_opt)
- .add(other_opt);
-
- po::variables_map vm = add_options(*all, desc);
-
+ ("initial_weight", po::value<float>(&(all.initial_weight)), "Set all weights to an initial value of 1.")
+ ("random_weights", po::value<bool>(&(all.random_weights)), "make initial weights random")
+ ("input_feature_regularizer", po::value< string >(&(all.per_feature_regularizer_input)), "Per feature regularization input file");
+ add_options(all);
+
+ new_options(all, "Parallelization options")
+ ("span_server", po::value<string>(&(all.span_server)), "Location of server for setting up spanning tree")
+ ("unique_id", po::value<size_t>(&(all.unique_id)),"unique id used for cluster parallel jobs")
+ ("total", po::value<size_t>(&(all.total)),"total number of nodes used in cluster parallel job")
+ ("node", po::value<size_t>(&(all.node)),"node number in cluster parallel job");
+ add_options(all);
+
+ po::variables_map& vm = all.vm;
msrand48(random_seed);
+ parse_diagnostics(all, argc);
- parse_diagnostics(*all, vm, argc);
-
- all->sd->weighted_unlabeled_examples = all->sd->t;
- all->initial_t = (float)all->sd->t;
+ all.sd->weighted_unlabeled_examples = all.sd->t;
+ all.initial_t = (float)all.sd->t;
//Input regressor header
io_buf io_temp;
- parse_regressor_args(*all, vm, io_temp);
+ parse_regressor_args(all, vm, io_temp);
int temp_argc = 0;
- char** temp_argv = VW::get_argv_from_string(all->file_options->str(), temp_argc);
- add_to_args(*all, temp_argc, temp_argv);
+ char** temp_argv = VW::get_argv_from_string(all.file_options->str(), temp_argc);
+ add_to_args(all, temp_argc, temp_argv);
for (int i = 0; i < temp_argc; i++)
free(temp_argv[i]);
free(temp_argv);
- po::parsed_options pos = po::command_line_parser(all->args).
+ po::parsed_options pos = po::command_line_parser(all.args).
style(po::command_line_style::default_style ^ po::command_line_style::allow_guessing).
- options(all->opts).allow_unregistered().run();
+ options(all.opts).allow_unregistered().run();
vm = po::variables_map();
po::store(pos, vm);
po::notify(vm);
- all->file_options->str("");
+ all.file_options->str("");
- parse_feature_tweaks(*all, vm); //feature tweaks
+ parse_feature_tweaks(all); //feature tweaks
- parse_example_tweaks(*all, vm); //example manipulation
+ parse_example_tweaks(all); //example manipulation
- parse_output_model(*all, vm);
+ parse_output_model(all);
- parse_base_algorithm(*all, vm);
+ parse_reductions(all);
- if (!all->quiet)
+ if (!all.quiet)
{
- cerr << "Num weight bits = " << all->num_bits << endl;
- cerr << "learning rate = " << all->eta << endl;
- cerr << "initial_t = " << all->sd->t << endl;
- cerr << "power_t = " << all->power_t << endl;
- if (all->numpasses > 1)
- cerr << "decay_learning_rate = " << all->eta_decay_rate << endl;
- if (all->rank > 0)
- cerr << "rank = " << all->rank << endl;
+ cerr << "Num weight bits = " << all.num_bits << endl;
+ cerr << "learning rate = " << all.eta << endl;
+ cerr << "initial_t = " << all.sd->t << endl;
+ cerr << "power_t = " << all.power_t << endl;
+ if (all.numpasses > 1)
+ cerr << "decay_learning_rate = " << all.eta_decay_rate << endl;
}
- parse_output_preds(*all, vm);
-
- parse_scorer_reductions(*all, vm);
-
- bool got_cs = false;
-
- parse_score_users(*all, vm, got_cs);
-
- bool got_cb = false;
-
- parse_cb(*all, vm, got_cs, got_cb);
-
- parse_search(*all, vm, got_cs, got_cb);
-
-
- if(vm.count("bootstrap"))
- all->l = BS::setup(*all, vm);
+ parse_output_preds(all);
- load_input_model(*all, vm, io_temp);
+ load_input_model(all, io_temp);
- parse_source(*all, vm);
+ parse_source(all);
- enable_sources(*all, vm, all->quiet,all->numpasses);
+ enable_sources(all, all.quiet, all.numpasses);
// force wpp to be a power of 2 to avoid 32-bit overflow
uint32_t i = 0;
- size_t params_per_problem = all->l->increment;
+ size_t params_per_problem = all.l->increment;
while (params_per_problem > (uint32_t)(1 << i))
i++;
- all->wpp = (1 << i) >> all->reg.stride_shift;
+ all.wpp = (1 << i) >> all.reg.stride_shift;
if (vm.count("help")) {
/* upon direct query for help -- spit it out to stdout */
- cout << "\n" << all->opts << "\n";
+ cout << "\n" << all.opts << "\n";
exit(0);
}
@@ -1152,15 +945,15 @@ namespace VW {
s += " --no_stdin";
char** argv = get_argv_from_string(s,argc);
- vw* all = parse_args(argc, argv);
+ vw& all = parse_args(argc, argv);
- initialize_parser_datastructures(*all);
+ initialize_parser_datastructures(all);
for(int i = 0; i < argc; i++)
free(argv[i]);
- free (argv);
+ free(argv);
- return all;
+ return &all;
}
void delete_dictionary_entry(substring ss, v_array<feature>*A) {
@@ -1182,6 +975,7 @@ namespace VW {
all.p->parse_name.delete_v();
free(all.p);
free(all.sd);
+ all.reduction_stack.delete_v();
delete all.file_options;
for (size_t i = 0; i < all.final_prediction_sink.size(); i++)
if (all.final_prediction_sink[i] != 1)
diff --git a/vowpalwabbit/parse_args.h b/vowpalwabbit/parse_args.h
index 9e16d5bc..6d150382 100644
--- a/vowpalwabbit/parse_args.h
+++ b/vowpalwabbit/parse_args.h
@@ -6,4 +6,5 @@ license as described in the file LICENSE.
#pragma once
#include "global_data.h"
-vw* parse_args(int argc, char *argv[]);
+vw& parse_args(int argc, char *argv[]);
+LEARNER::base_learner* setup_base(vw& all);
diff --git a/vowpalwabbit/parse_regressor.cc b/vowpalwabbit/parse_regressor.cc
index def3a8de..7cb6a21b 100644
--- a/vowpalwabbit/parse_regressor.cc
+++ b/vowpalwabbit/parse_regressor.cc
@@ -164,11 +164,6 @@ void save_load_header(vw& all, io_buf& model_file, bool read, bool text)
"", read,
"\n",1, text);
- text_len = sprintf(buff, "rank:%d\n", (int)all.rank);
- bin_text_read_write_fixed(model_file,(char*)&all.rank, sizeof(all.rank),
- "", read,
- buff,text_len, text);
-
text_len = sprintf(buff, "lda:%d\n", (int)all.lda);
bin_text_read_write_fixed(model_file,(char*)&all.lda, sizeof(all.lda),
"", read,
@@ -312,19 +307,19 @@ void parse_regressor_args(vw& all, po::variables_map& vm, io_buf& io_temp)
save_load_header(all, io_temp, true, false);
}
-void parse_mask_regressor_args(vw& all, po::variables_map& vm){
-
+void parse_mask_regressor_args(vw& all)
+{
+ po::variables_map& vm = all.vm;
if (vm.count("feature_mask")) {
size_t length = ((size_t)1) << all.num_bits;
string mask_filename = vm["feature_mask"].as<string>();
if (vm.count("initial_regressor")){
vector<string> init_filename = vm["initial_regressor"].as< vector<string> >();
if(mask_filename == init_filename[0]){//-i and -mask are from same file, just generate mask
-
return;
}
}
-
+
//all other cases, including from different file, or -i does not exist, need to read in the mask file
io_buf io_temp_mask;
io_temp_mask.open_file(mask_filename.c_str(), false, io_buf::READ);
diff --git a/vowpalwabbit/parse_regressor.h b/vowpalwabbit/parse_regressor.h
index b76b5cdc..069dd6a2 100644
--- a/vowpalwabbit/parse_regressor.h
+++ b/vowpalwabbit/parse_regressor.h
@@ -20,4 +20,4 @@ void initialize_regressor(vw& all);
void save_predictor(vw& all, std::string reg_name, size_t current_pass);
void save_load_header(vw& all, io_buf& model_file, bool read, bool text);
-void parse_mask_regressor_args(vw& all, po::variables_map& vm);
+void parse_mask_regressor_args(vw& all);
diff --git a/vowpalwabbit/parser.cc b/vowpalwabbit/parser.cc
index 44055bab..c71f2956 100644
--- a/vowpalwabbit/parser.cc
+++ b/vowpalwabbit/parser.cc
@@ -405,10 +405,10 @@ void parse_cache(vw& all, po::variables_map &vm, string source,
# define MAP_ANONYMOUS MAP_ANON
#endif
-void enable_sources(vw& all, po::variables_map& vm, bool quiet, size_t passes)
+void enable_sources(vw& all, bool quiet, size_t passes)
{
all.p->input->current = 0;
- parse_cache(all, vm, all.data_filename, quiet);
+ parse_cache(all, all.vm, all.data_filename, quiet);
if (all.daemon || all.active)
{
@@ -431,8 +431,8 @@ void enable_sources(vw& all, po::variables_map& vm, bool quiet, size_t passes)
address.sin_family = AF_INET;
address.sin_addr.s_addr = htonl(INADDR_ANY);
short unsigned int port = 26542;
- if (vm.count("port"))
- port = (uint16_t)vm["port"].as<size_t>();
+ if (all.vm.count("port"))
+ port = (uint16_t)all.vm["port"].as<size_t>();
address.sin_port = htons(port);
// attempt to bind to socket
@@ -449,7 +449,7 @@ void enable_sources(vw& all, po::variables_map& vm, bool quiet, size_t passes)
}
// write port file
- if (vm.count("port_file"))
+ if (all.vm.count("port_file"))
{
socklen_t address_size = sizeof(address);
if (getsockname(all.p->bound_sock, (sockaddr*)&address, &address_size) < 0)
@@ -457,7 +457,7 @@ void enable_sources(vw& all, po::variables_map& vm, bool quiet, size_t passes)
cerr << "getsockname: " << strerror(errno) << endl;
}
ofstream port_file;
- port_file.open(vm["port_file"].as<string>().c_str());
+ port_file.open(all.vm["port_file"].as<string>().c_str());
if (!port_file.is_open())
{
cerr << "error writing port file" << endl;
@@ -474,10 +474,10 @@ void enable_sources(vw& all, po::variables_map& vm, bool quiet, size_t passes)
throw exception();
}
// write pid file
- if (vm.count("pid_file"))
+ if (all.vm.count("pid_file"))
{
ofstream pid_file;
- pid_file.open(vm["pid_file"].as<string>().c_str());
+ pid_file.open(all.vm["pid_file"].as<string>().c_str());
if (!pid_file.is_open())
{
cerr << "error writing pid file" << endl;
@@ -597,7 +597,7 @@ void enable_sources(vw& all, po::variables_map& vm, bool quiet, size_t passes)
}
all.p->resettable = all.p->write_cache || all.daemon;
}
- else // was: else if (vm.count("data"))
+ else
{
if (all.p->input->files.size() > 0)
{
@@ -840,37 +840,22 @@ void setup_example(vw& all, example* ae)
ae->total_sum_feat_sq += ae->sum_feat_sq[*i];
}
- if (all.rank == 0) {
- for (vector<string>::iterator i = all.pairs.begin(); i != all.pairs.end();i++)
- {
- ae->num_features
- += ae->atomics[(int)(*i)[0]].size()
- *ae->atomics[(int)(*i)[1]].size();
- ae->total_sum_feat_sq += ae->sum_feat_sq[(int)(*i)[0]]*ae->sum_feat_sq[(int)(*i)[1]];
- }
-
- for (vector<string>::iterator i = all.triples.begin(); i != all.triples.end();i++)
- {
- ae->num_features
- += ae->atomics[(int)(*i)[0]].size()
- *ae->atomics[(int)(*i)[1]].size()
- *ae->atomics[(int)(*i)[2]].size();
- ae->total_sum_feat_sq += ae->sum_feat_sq[(int)(*i)[0]] * ae->sum_feat_sq[(int)(*i)[1]] * ae->sum_feat_sq[(int)(*i)[2]];
- }
-
- } else {
- for (vector<string>::iterator i = all.pairs.begin(); i != all.pairs.end();i++)
- {
- ae->num_features += ae->atomics[(int)(*i)[0]].size() * all.rank;
- ae->num_features += ae->atomics[(int)(*i)[1]].size() * all.rank;
- }
- for (vector<string>::iterator i = all.triples.begin(); i != all.triples.end();i++)
- {
- ae->num_features += ae->atomics[(int)(*i)[0]].size() * all.rank;
- ae->num_features += ae->atomics[(int)(*i)[1]].size() * all.rank;
- ae->num_features += ae->atomics[(int)(*i)[2]].size() * all.rank;
- }
- }
+ for (vector<string>::iterator i = all.pairs.begin(); i != all.pairs.end();i++)
+ {
+ ae->num_features
+ += ae->atomics[(int)(*i)[0]].size()
+ *ae->atomics[(int)(*i)[1]].size();
+ ae->total_sum_feat_sq += ae->sum_feat_sq[(int)(*i)[0]]*ae->sum_feat_sq[(int)(*i)[1]];
+ }
+
+ for (vector<string>::iterator i = all.triples.begin(); i != all.triples.end();i++)
+ {
+ ae->num_features
+ += ae->atomics[(int)(*i)[0]].size()
+ *ae->atomics[(int)(*i)[1]].size()
+ *ae->atomics[(int)(*i)[2]].size();
+ ae->total_sum_feat_sq += ae->sum_feat_sq[(int)(*i)[0]] * ae->sum_feat_sq[(int)(*i)[1]] * ae->sum_feat_sq[(int)(*i)[2]];
+ }
}
}
diff --git a/vowpalwabbit/parser.h b/vowpalwabbit/parser.h
index c2d271d1..1c5b8f50 100644
--- a/vowpalwabbit/parser.h
+++ b/vowpalwabbit/parser.h
@@ -58,7 +58,7 @@ struct parser {
parser* new_parser();
-void enable_sources(vw& all, po::variables_map& vm, bool quiet, size_t passes);
+void enable_sources(vw& all, bool quiet, size_t passes);
bool examples_to_finish();
diff --git a/vowpalwabbit/print.cc b/vowpalwabbit/print.cc
index d0dc2765..321a8710 100644
--- a/vowpalwabbit/print.cc
+++ b/vowpalwabbit/print.cc
@@ -40,11 +40,15 @@ namespace PRINT
GD::foreach_feature<vw, print_feature>(*(p.all), ec, *p.all);
cout << endl;
}
-
+
LEARNER::base_learner* setup(vw& all)
{
+ new_options(all, "Print options") ("print","print examples");
+ if(missing_required(all)) return NULL;
+
print& p = calloc_or_die<print>();
p.all = &all;
+
size_t length = ((size_t)1) << all.num_bits;
all.reg.weight_mask = (length << all.reg.stride_shift) - 1;
all.reg.stride_shift = 0;
diff --git a/vowpalwabbit/print.h b/vowpalwabbit/print.h
index b6a771ed..affd09e8 100644
--- a/vowpalwabbit/print.h
+++ b/vowpalwabbit/print.h
@@ -4,6 +4,4 @@ individual contributors. All rights reserved. Released under a BSD
license as described in the file LICENSE.
*/
#pragma once
-namespace PRINT {
- LEARNER::base_learner* setup(vw& all);
-}
+namespace PRINT { LEARNER::base_learner* setup(vw& all); }
diff --git a/vowpalwabbit/reductions.h b/vowpalwabbit/reductions.h
index a1c18ecf..35571cb2 100644
--- a/vowpalwabbit/reductions.h
+++ b/vowpalwabbit/reductions.h
@@ -13,3 +13,4 @@ namespace po = boost::program_options;
#include "learner.h" // for core reduction definition
#include "global_data.h" // for vw datastructure
#include "memory.h"
+#include "parse_args.h"
diff --git a/vowpalwabbit/scorer.cc b/vowpalwabbit/scorer.cc
index 50645ed8..826f8e3b 100644
--- a/vowpalwabbit/scorer.cc
+++ b/vowpalwabbit/scorer.cc
@@ -31,33 +31,31 @@ namespace Scorer {
float id(float in) { return in; }
- LEARNER::base_learner* setup(vw& all, po::variables_map& vm)
+ LEARNER::base_learner* setup(vw& all)
{
+ new_options(all, "Link options")
+ ("link", po::value<string>()->default_value("identity"), "Specify the link function: identity, logistic or glf1");
+ add_options(all);
+ po::variables_map& vm = all.vm;
scorer& s = calloc_or_die<scorer>();
s.all = &all;
- po::options_description link_opts("Link options");
-
- link_opts.add_options()
- ("link", po::value<string>()->default_value("identity"), "Specify the link function: identity, logistic or glf1");
-
- vm = add_options(all, link_opts);
-
+ LEARNER::base_learner* base = setup_base(all);
LEARNER::learner<scorer>* l;
string link = vm["link"].as<string>();
if (!vm.count("link") || link.compare("identity") == 0)
- l = &init_learner(&s, all.l, predict_or_learn<true, id>, predict_or_learn<false, id>);
+ l = &init_learner(&s, base, predict_or_learn<true, id>, predict_or_learn<false, id>);
else if (link.compare("logistic") == 0)
{
*all.file_options << " --link=logistic ";
- l = &init_learner(&s, all.l, predict_or_learn<true, logistic>,
+ l = &init_learner(&s, base, predict_or_learn<true, logistic>,
predict_or_learn<false, logistic>);
}
else if (link.compare("glf1") == 0)
{
*all.file_options << " --link=glf1 ";
- l = &init_learner(&s, all.l, predict_or_learn<true, glf1>,
+ l = &init_learner(&s, base, predict_or_learn<true, glf1>,
predict_or_learn<false, glf1>);
}
else
@@ -65,6 +63,8 @@ namespace Scorer {
cerr << "Unknown link function: " << link << endl;
throw exception();
}
- return make_base(*l);
+ all.scorer = make_base(*l);
+
+ return all.scorer;
}
}
diff --git a/vowpalwabbit/scorer.h b/vowpalwabbit/scorer.h
index 2d0ec294..efd95e9a 100644
--- a/vowpalwabbit/scorer.h
+++ b/vowpalwabbit/scorer.h
@@ -1,4 +1,2 @@
#pragma once
-namespace Scorer {
- LEARNER::base_learner* setup(vw& all, po::variables_map& vm);
-}
+namespace Scorer { LEARNER::base_learner* setup(vw& all); }
diff --git a/vowpalwabbit/search.cc b/vowpalwabbit/search.cc
index fb14e65c..74444206 100644
--- a/vowpalwabbit/search.cc
+++ b/vowpalwabbit/search.cc
@@ -12,9 +12,8 @@ license as described in the file LICENSE.
#include "rand48.h"
#include "cost_sensitive.h"
#include "multiclass.h"
-#include "memory.h"
#include "constant.h"
-#include "example.h"
+#include "reductions.h"
#include "cb.h"
#include "gd.h" // for GD::foreach_feature
#include <math.h>
@@ -1669,14 +1668,14 @@ namespace Search {
ret = false;
}
- void handle_condition_options(vw& vw, auto_condition_settings& acset, po::variables_map& vm) {
- po::options_description condition_options("Search Auto-conditioning Options");
- condition_options.add_options()
- ("search_max_bias_ngram_length", po::value<size_t>(), "add a \"bias\" feature for each ngram up to and including this length. eg., if it's 1 (default), then you get a single feature for each conditional")
- ("search_max_quad_ngram_length", po::value<size_t>(), "add bias *times* input features for each ngram up to and including this length (def: 0)")
- ("search_condition_feature_value", po::value<float> (), "how much weight should the conditional features get? (def: 1.)");
+ void handle_condition_options(vw& vw, auto_condition_settings& acset) {
+ new_options(vw, "Search Auto-conditioning Options")
+ ("search_max_bias_ngram_length", po::value<size_t>(), "add a \"bias\" feature for each ngram up to and including this length. eg., if it's 1 (default), then you get a single feature for each conditional")
+ ("search_max_quad_ngram_length", po::value<size_t>(), "add bias *times* input features for each ngram up to and including this length (def: 0)")
+ ("search_condition_feature_value", po::value<float> (), "how much weight should the conditional features get? (def: 1.)");
+ add_options(vw);
- vm = add_options(vw, condition_options);
+ po::variables_map& vm = vw.vm;
check_option<size_t>(acset.max_bias_ngram_length, vw, vm, "search_max_bias_ngram_length", false, size_equal,
"warning: you specified a different value for --search_max_bias_ngram_length than the one loaded from regressor. proceeding with loaded value: ", "");
@@ -1764,38 +1763,39 @@ namespace Search {
delete[] cstr;
}
- base_learner* setup(vw&all, po::variables_map& vm) {
- search& sch = calloc_or_die<search>();
- sch.priv = new search_private();
- search_initialize(&all, sch);
- search_private& priv = *sch.priv;
-
- po::options_description search_opts("Search Options");
- search_opts.add_options()
- ("search_task", po::value<string>(), "the search task (use \"--search_task list\" to get a list of available tasks)")
- ("search_interpolation", po::value<string>(), "at what level should interpolation happen? [*data|policy]")
- ("search_rollout", po::value<string>(), "how should rollouts be executed? [policy|oracle|*mix_per_state|mix_per_roll|none]")
- ("search_rollin", po::value<string>(), "how should past trajectories be generated? [policy|oracle|*mix_per_state|mix_per_roll]")
-
- ("search_passes_per_policy", po::value<size_t>(), "number of passes per policy (only valid for search_interpolation=policy) [def=1]")
- ("search_beta", po::value<float>(), "interpolation rate for policies (only valid for search_interpolation=policy) [def=0.5]")
-
- ("search_alpha", po::value<float>(), "annealed beta = 1-(1-alpha)^t (only valid for search_interpolation=data) [def=1e-10]")
-
- ("search_total_nb_policies", po::value<size_t>(), "if we are going to train the policies through multiple separate calls to vw, we need to specify this parameter and tell vw how many policies are eventually going to be trained")
-
- ("search_trained_nb_policies", po::value<size_t>(), "the number of trained policies in a file")
-
- ("search_allowed_transitions",po::value<string>(),"read file of allowed transitions [def: all transitions are allowed]")
- ("search_subsample_time", po::value<float>(), "instead of training at all timesteps, use a subset. if value in (0,1), train on a random v%. if v>=1, train on precisely v steps per example")
- ("search_neighbor_features", po::value<string>(), "copy features from neighboring lines. argument looks like: '-1:a,+2' meaning copy previous line namespace a and next next line from namespace _unnamed_, where ',' separates them")
- ("search_rollout_num_steps", po::value<size_t>(), "how many calls of \"loss\" before we stop really predicting on rollouts and switch to oracle (def: 0 means \"infinite\")")
- ("search_history_length", po::value<size_t>(), "some tasks allow you to specify how much history their depend on; specify that here [def: 1]")
-
- ("search_no_caching", "turn off the built-in caching ability (makes things slower, but technically more safe)")
- ("search_beam", po::value<size_t>(), "use beam search (arg = beam size, default 0 = no beam)")
- ("search_kbest", po::value<size_t>(), "size of k-best list to produce (must be <= beam size)")
- ;
+ base_learner* setup(vw&all) {
+ new_options(all,"Search Options")
+ ("search", po::value<size_t>(), "use search-based structured prediction, argument=maximum action id or 0 for LDF");
+ if (missing_required(all)) return NULL;
+ new_options(all)
+ ("search_task", po::value<string>(), "the search task (use \"--search_task list\" to get a list of available tasks)")
+ ("search_interpolation", po::value<string>(), "at what level should interpolation happen? [*data|policy]")
+ ("search_rollout", po::value<string>(), "how should rollouts be executed? [policy|oracle|*mix_per_state|mix_per_roll|none]")
+ ("search_rollin", po::value<string>(), "how should past trajectories be generated? [policy|oracle|*mix_per_state|mix_per_roll]")
+
+ ("search_passes_per_policy", po::value<size_t>(), "number of passes per policy (only valid for search_interpolation=policy) [def=1]")
+ ("search_beta", po::value<float>(), "interpolation rate for policies (only valid for search_interpolation=policy) [def=0.5]")
+
+ ("search_alpha", po::value<float>(), "annealed beta = 1-(1-alpha)^t (only valid for search_interpolation=data) [def=1e-10]")
+
+ ("search_total_nb_policies", po::value<size_t>(), "if we are going to train the policies through multiple separate calls to vw, we need to specify this parameter and tell vw how many policies are eventually going to be trained")
+
+ ("search_trained_nb_policies", po::value<size_t>(), "the number of trained policies in a file")
+
+ ("search_allowed_transitions",po::value<string>(),"read file of allowed transitions [def: all transitions are allowed]")
+ ("search_subsample_time", po::value<float>(), "instead of training at all timesteps, use a subset. if value in (0,1), train on a random v%. if v>=1, train on precisely v steps per example")
+ ("search_neighbor_features", po::value<string>(), "copy features from neighboring lines. argument looks like: '-1:a,+2' meaning copy previous line namespace a and next next line from namespace _unnamed_, where ',' separates them")
+ ("search_rollout_num_steps", po::value<size_t>(), "how many calls of \"loss\" before we stop really predicting on rollouts and switch to oracle (def: 0 means \"infinite\")")
+ ("search_history_length", po::value<size_t>(), "some tasks allow you to specify how much history their depend on; specify that here [def: 1]")
+
+ ("search_no_caching", "turn off the built-in caching ability (makes things slower, but technically more safe)")
+ ("search_beam", po::value<size_t>(), "use beam search (arg = beam size, default 0 = no beam)")
+ ("search_kbest", po::value<size_t>(), "size of k-best list to produce (must be <= beam size)")
+ ;
+ add_options(all);
+ po::variables_map& vm = all.vm;
+ if (!vm.count("search"))
+ return NULL;
bool has_hook_task = false;
for (size_t i=0; i<all.args.size()-1; i++)
@@ -1805,9 +1805,12 @@ namespace Search {
for (int i = (int)all.args.size()-2; i >= 0; i--)
if (all.args[i] == "--search_task" && all.args[i+1] != "hook")
all.args.erase(all.args.begin() + i, all.args.begin() + i + 2);
+
+ search& sch = calloc_or_die<search>();
+ sch.priv = new search_private();
+ search_initialize(&all, sch);
+ search_private& priv = *sch.priv;
- vm = add_options(all, search_opts);
-
std::string task_string;
std::string interpolation_string = "data";
std::string rollout_string = "mix_per_state";
@@ -1955,6 +1958,18 @@ namespace Search {
}
all.p->emptylines_separate_examples = true;
+ if (count(all.args.begin(), all.args.end(),"--csoaa") == 0
+ && count(all.args.begin(), all.args.end(),"--csoaa_ldf") == 0
+ && count(all.args.begin(), all.args.end(),"--wap_ldf") == 0
+ && count(all.args.begin(), all.args.end(),"--cb") == 0)
+ {
+ all.args.push_back("--csoaa");
+ stringstream ss;
+ ss << vm["search"].as<size_t>();
+ all.args.push_back(ss.str());
+ }
+ base_learner* base = setup_base(all);
+
// default to OAA labels unless the task wants to override this (which they can do in initialize)
all.p->lp = MC::mc_label;
if (priv.task)
@@ -1964,7 +1979,7 @@ namespace Search {
// set up auto-history if they want it
if (priv.auto_condition_features) {
- handle_condition_options(all, priv.acset, vm);
+ handle_condition_options(all, priv.acset);
// turn off auto-condition if it's irrelevant
if (((priv.acset.max_bias_ngram_length == 0) && (priv.acset.max_quad_ngram_length == 0)) ||
@@ -1981,7 +1996,8 @@ namespace Search {
priv.start_clock_time = clock();
- learner<search>& l = init_learner(&sch, all.l, search_predict_or_learn<true>,
+ learner<search>& l = init_learner(&sch, base,
+ search_predict_or_learn<true>,
search_predict_or_learn<false>,
priv.total_number_of_policies);
l.set_finish_example(finish_example);
@@ -2063,7 +2079,7 @@ namespace Search {
void search::set_num_learners(size_t num_learners) { this->priv->num_learners = num_learners; }
- void search::add_program_options(po::variables_map& vm, po::options_description& opts) { vm = add_options( *this->priv->all, opts ); }
+ void search::add_program_options(po::variables_map& vw, po::options_description& opts) { add_options( *this->priv->all, opts ); }
size_t search::get_mask() { return this->priv->all->reg.weight_mask;}
size_t search::get_stride_shift() { return this->priv->all->reg.stride_shift;}
diff --git a/vowpalwabbit/search.h b/vowpalwabbit/search.h
index e129de25..54c35259 100644
--- a/vowpalwabbit/search.h
+++ b/vowpalwabbit/search.h
@@ -241,8 +241,5 @@ namespace Search {
bool size_equal(size_t a, size_t b);
// our interface within VW
- LEARNER::base_learner* setup(vw&, po::variables_map&);
- void search_finish(void*);
- void search_drive(void*);
- void search_learn(void*,example*);
+ LEARNER::base_learner* setup(vw&);
}
diff --git a/vowpalwabbit/sender.cc b/vowpalwabbit/sender.cc
index a9ded7e4..ae73cee9 100644
--- a/vowpalwabbit/sender.cc
+++ b/vowpalwabbit/sender.cc
@@ -96,13 +96,17 @@ void end_examples(sender& s)
delete s.buf;
}
- LEARNER::base_learner* setup(vw& all, po::variables_map& vm, vector<string> pairs)
-{
+ LEARNER::base_learner* setup(vw& all)
+ {
+ new_options(all, "Sender options")
+ ("sendto", po::value< vector<string> >(), "send examples to <host>");
+ if(missing_required(all)) return NULL;
+
sender& s = calloc_or_die<sender>();
s.sd = -1;
- if (vm.count("sendto"))
+ if (all.vm.count("sendto"))
{
- vector<string> hosts = vm["sendto"].as< vector<string> >();
+ vector<string> hosts = all.vm["sendto"].as< vector<string> >();
open_sockets(s, hosts[0]);
}
diff --git a/vowpalwabbit/sender.h b/vowpalwabbit/sender.h
index 9740f159..b8199bf6 100644
--- a/vowpalwabbit/sender.h
+++ b/vowpalwabbit/sender.h
@@ -4,6 +4,4 @@ individual contributors. All rights reserved. Released under a BSD
license as described in the file LICENSE.
*/
#pragma once
-namespace SENDER{
- LEARNER::base_learner* setup(vw& all, po::variables_map& vm, vector<string> pairs);
-}
+namespace SENDER { LEARNER::base_learner* setup(vw& all); }
diff --git a/vowpalwabbit/stagewise_poly.cc b/vowpalwabbit/stagewise_poly.cc
index b2e7e150..68065e0c 100644
--- a/vowpalwabbit/stagewise_poly.cc
+++ b/vowpalwabbit/stagewise_poly.cc
@@ -7,7 +7,7 @@
#include "allreduce.h"
#include "accumulate.h"
#include "constant.h"
-#include "memory.h"
+#include "reductions.h"
#include "vw.h"
//#define MAGIC_ARGUMENT //MAY IT NEVER DIE
@@ -656,17 +656,13 @@ namespace StagewisePoly
//#endif //DEBUG
}
-
- base_learner *setup(vw &all, po::variables_map &vm)
+ base_learner *setup(vw &all)
{
- stagewise_poly& poly = calloc_or_die<stagewise_poly>();
- poly.all = &all;
-
- depthsbits_create(poly);
- sort_data_create(poly);
+ new_options(all, "Stagewise poly options")
+ ("stage_poly", "use stagewise polynomial feature learning");
+ if (missing_required(all)) return NULL;
- po::options_description sp_opt("Stagewise poly options");
- sp_opt.add_options()
+ new_options(all)
("sched_exponent", po::value<float>(), "exponent controlling quantity of included features")
("batch_sz", po::value<uint32_t>(), "multiplier on batch size before including more features")
("batch_sz_no_doubling", "batch_sz does not double")
@@ -674,7 +670,13 @@ namespace StagewisePoly
("magic_argument", po::value<float>(), "magical feature flag")
#endif //MAGIC_ARGUMENT
;
- vm = add_options(all, sp_opt);
+ add_options(all);
+
+ po::variables_map &vm = all.vm;
+ stagewise_poly& poly = calloc_or_die<stagewise_poly>();
+ poly.all = &all;
+ depthsbits_create(poly);
+ sort_data_create(poly);
poly.sched_exponent = vm.count("sched_exponent") ? vm["sched_exponent"].as<float>() : 1.f;
poly.batch_sz = vm.count("batch_sz") ? vm["batch_sz"].as<uint32_t>() : 1000;
@@ -698,7 +700,7 @@ namespace StagewisePoly
//following is so that saved models know to load us.
*all.file_options << " --stage_poly";
- learner<stagewise_poly>& l = init_learner(&poly, all.l, learn, predict);
+ learner<stagewise_poly>& l = init_learner(&poly, setup_base(all), learn, predict);
l.set_finish(finish);
l.set_save_load(save_load);
l.set_finish_example(finish_example);
diff --git a/vowpalwabbit/stagewise_poly.h b/vowpalwabbit/stagewise_poly.h
index 983b4382..60478e81 100644
--- a/vowpalwabbit/stagewise_poly.h
+++ b/vowpalwabbit/stagewise_poly.h
@@ -4,7 +4,4 @@ individual contributors. All rights reserved. Released under a BSD
license as described in the file LICENSE.
*/
#pragma once
-namespace StagewisePoly
-{
- LEARNER::base_learner *setup(vw &all, po::variables_map &vm);
-}
+namespace StagewisePoly { LEARNER::base_learner *setup(vw &all); }
diff --git a/vowpalwabbit/topk.cc b/vowpalwabbit/topk.cc
index 445bdb23..52f03ad9 100644
--- a/vowpalwabbit/topk.cc
+++ b/vowpalwabbit/topk.cc
@@ -16,15 +16,12 @@ namespace TOPK {
struct compare_scored_examples
{
bool operator()(scored_example const& a, scored_example const& b) const
- {
- return a.first > b.first;
- }
+ { return a.first > b.first; }
};
struct topk{
uint32_t B; //rec number
priority_queue<scored_example, vector<scored_example>, compare_scored_examples > pr_queue;
- vw* all;
};
void print_result(int f, priority_queue<scored_example, vector<scored_example>, compare_scored_examples > &pr_queue)
@@ -72,43 +69,46 @@ namespace TOPK {
if (example_is_newline(ec))
for (int* sink = all.final_prediction_sink.begin; sink != all.final_prediction_sink.end; sink++)
TOPK::print_result(*sink, d.pr_queue);
-
+
print_update(all, ec);
}
-
+
template <bool is_learn>
void predict_or_learn(topk& d, LEARNER::base_learner& base, example& ec)
{
if (example_is_newline(ec)) return;//do not predict newline
-
+
if (is_learn)
base.learn(ec);
else
base.predict(ec);
-
+
if(d.pr_queue.size() < d.B)
d.pr_queue.push(make_pair(ec.pred.scalar, ec.tag));
-
+
else if(d.pr_queue.top().first < ec.pred.scalar)
- {
- d.pr_queue.pop();
- d.pr_queue.push(make_pair(ec.pred.scalar, ec.tag));
- }
+ {
+ d.pr_queue.pop();
+ d.pr_queue.push(make_pair(ec.pred.scalar, ec.tag));
+ }
}
-
+
void finish_example(vw& all, topk& d, example& ec)
{
TOPK::output_example(all, d, ec);
VW::finish_example(all, &ec);
}
- LEARNER::base_learner* setup(vw& all, po::variables_map& vm)
+ LEARNER::base_learner* setup(vw& all)
{
+ new_options(all, "TOP K options")
+ ("top", po::value<size_t>(), "top k recommendation");
+ if(missing_required(all)) return NULL;
+
topk& data = calloc_or_die<topk>();
- data.B = (uint32_t)vm["top"].as<size_t>();
- data.all = &all;
+ data.B = (uint32_t)all.vm["top"].as<size_t>();
- LEARNER::learner<topk>& l = init_learner(&data, all.l, predict_or_learn<true>,
+ LEARNER::learner<topk>& l = init_learner(&data, setup_base(all), predict_or_learn<true>,
predict_or_learn<false>);
l.set_finish_example(finish_example);
diff --git a/vowpalwabbit/topk.h b/vowpalwabbit/topk.h
index 866d94c5..6e9973ad 100644
--- a/vowpalwabbit/topk.h
+++ b/vowpalwabbit/topk.h
@@ -4,15 +4,4 @@ individual contributors. All rights reserved. Released under a BSD
license as described in the file LICENSE.
*/
#pragma once
-#include "io_buf.h"
-#include "parse_primitives.h"
-#include "global_data.h"
-#include "example.h"
-#include "parse_args.h"
-#include "v_hashmap.h"
-#include "simple_label.h"
-
-namespace TOPK
-{
- LEARNER::base_learner* setup(vw& all, po::variables_map& vm);
-}
+namespace TOPK { LEARNER::base_learner* setup(vw& all); }
diff --git a/vowpalwabbit/vw_static.vcxproj b/vowpalwabbit/vw_static.vcxproj
index b555aef2..3df09387 100644
--- a/vowpalwabbit/vw_static.vcxproj
+++ b/vowpalwabbit/vw_static.vcxproj
@@ -1,4 +1,4 @@
-<?xml version="1.0" encoding="utf-8"?>
+<?xml version="1.0" encoding="utf-8"?>
<Project DefaultTargets="Build" ToolsVersion="12.0" xmlns="http://schemas.microsoft.com/developer/msbuild/2003">
<ItemGroup Label="ProjectConfigurations">
<ProjectConfiguration Include="Debug|Win32">
@@ -249,6 +249,7 @@
<ClInclude Include="ect.h" />
<ClInclude Include="example.h" />
<ClInclude Include="gd.h" />
+ <ClInclude Include="ftrl_proximal.h" />
<ClInclude Include="memory.h" />
<ClInclude Include="multiclass.h" />
<ClInclude Include="cost_sensitive.h" />
@@ -305,8 +306,8 @@
<ClCompile Include="ect.cc" />
<ClCompile Include="example.cc" />
<ClCompile Include="gd.cc" />
+ <ClCompile Include="ftrl_proximal.cc" />
<ClCompile Include="kernel_svm.cc" />
- <ClCompile Include="memory.cc" />
<ClCompile Include="multiclass.cc" />
<ClCompile Include="cost_sensitive.cc" />
<ClCompile Include="cb_algs.cc" />
diff --git a/vowpalwabbit/vwdll.cpp b/vowpalwabbit/vwdll.cpp
index 6235705b..501730b5 100644
--- a/vowpalwabbit/vwdll.cpp
+++ b/vowpalwabbit/vwdll.cpp
@@ -52,7 +52,7 @@ extern "C"
adjust_used_index(*pointer);
pointer->do_reset_source = true;
VW::start_parser(*pointer,false);
- pointer->l->driver(pointer);
+ LEARNER::generic_driver(*pointer);
VW::end_parser(*pointer);
}
else