
github.com/moses-smt/vowpal_wabbit.git
author    Sam Steingold <sds@gnu.org>    2014-12-23 20:09:43 +0300
committer Sam Steingold <sds@gnu.org>    2014-12-23 20:09:43 +0300
commit    ed8c4d38aba3b49159f1b2574028b5cbae96a7f2 (patch)
tree      4ecd16430fb82c8d1d67fcccfd887ce79c02eac6 /vowpalwabbit
parent    1452485ae7248f5a12e3ac909aa8a9dedaf26241 (diff)
convert to unix line endings, like all the other sources
Diffstat (limited to 'vowpalwabbit')
-rw-r--r--  vowpalwabbit/accumulate.cc   230
-rw-r--r--  vowpalwabbit/accumulate.h     26
-rw-r--r--  vowpalwabbit/lda_core.cc    1604
-rw-r--r--  vowpalwabbit/log_multi.cc   1098
4 files changed, 1479 insertions(+), 1479 deletions(-)
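
The commit only normalizes line endings: each changed line in the diffs below is removed and re-added with just the line terminator altered, which is why the insertion and deletion counts match. As a rough illustration (not part of vowpal_wabbit or of this commit), a standalone C++ tool performing the same CRLF-to-LF conversion on a single file could look like this:

    // Minimal sketch, not part of this repository: rewrite one file with LF
    // line endings by dropping every carriage return, the per-file effect of
    // this commit.
    #include <fstream>
    #include <iostream>
    #include <iterator>
    #include <string>

    int main(int argc, char** argv)
    {
      if (argc != 2)
        { std::cerr << "usage: " << argv[0] << " <file>" << std::endl; return 1; }

      std::ifstream in(argv[1], std::ios::binary);
      std::string data((std::istreambuf_iterator<char>(in)),
                       std::istreambuf_iterator<char>());
      in.close();

      std::string out;
      out.reserve(data.size());
      for (char c : data)
        if (c != '\r')        // drop CR, so "\r\n" becomes "\n"
          out.push_back(c);

      std::ofstream o(argv[1], std::ios::binary | std::ios::trunc);
      o << out;
      return 0;
    }

In practice such a conversion is usually done with dos2unix or a gitattributes rule (e.g. eol=lf) rather than a hand-rolled tool; the sketch only shows what 1479 insertions against 1479 deletions amounts to here.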
diff --git a/vowpalwabbit/accumulate.cc b/vowpalwabbit/accumulate.cc
index d6c5e71f..8d15dd59 100644
--- a/vowpalwabbit/accumulate.cc
+++ b/vowpalwabbit/accumulate.cc
@@ -1,115 +1,115 @@
-/*
-Copyright (c) by respective owners including Yahoo!, Microsoft, and
-individual contributors. All rights reserved. Released under a BSD (revised)
-license as described in the file LICENSE.
- */
-/*
-This implements the allreduce function of MPI. Code primarily by
-Alekh Agarwal and John Langford, with help Olivier Chapelle.
-*/
-
-#include <iostream>
-#include <sys/timeb.h>
-#include <cmath>
-#include <stdint.h>
-#include "accumulate.h"
-#include "global_data.h"
-
-using namespace std;
-
-void add_float(float& c1, const float& c2) {
- c1 += c2;
-}
-
-void accumulate(vw& all, string master_location, regressor& reg, size_t o) {
- uint32_t length = 1 << all.num_bits; //This is size of gradient
- size_t stride = 1 << all.reg.stride_shift;
- float* local_grad = new float[length];
- weight* weights = reg.weight_vector;
- for(uint32_t i = 0;i < length;i++)
- {
- local_grad[i] = weights[stride*i+o];
- }
-
- all_reduce<float, add_float>(local_grad, length, master_location, all.unique_id, all.total, all.node, all.socks);
- for(uint32_t i = 0;i < length;i++)
- {
- weights[stride*i+o] = local_grad[i];
- }
- delete[] local_grad;
-}
-
-float accumulate_scalar(vw& all, string master_location, float local_sum) {
- float temp = local_sum;
- all_reduce<float, add_float>(&temp, 1, master_location, all.unique_id, all.total, all.node, all.socks);
- return temp;
-}
-
-void accumulate_avg(vw& all, string master_location, regressor& reg, size_t o) {
- uint32_t length = 1 << all.num_bits; //This is size of gradient
- size_t stride = 1 << all.reg.stride_shift;
- float* local_grad = new float[length];
- weight* weights = reg.weight_vector;
- float numnodes = (float)all.total;
-
- for(uint32_t i = 0;i < length;i++)
- local_grad[i] = weights[stride*i+o];
-
- all_reduce<float, add_float>(local_grad, length, master_location, all.unique_id, all.total, all.node, all.socks);
- for(uint32_t i = 0;i < length;i++)
- weights[stride*i+o] = local_grad[i]/numnodes;
- delete[] local_grad;
-}
-
-float max_elem(float* arr, int length) {
- float max = arr[0];
- for(int i = 1;i < length;i++)
- if(arr[i] > max) max = arr[i];
- return max;
-}
-
-float min_elem(float* arr, int length) {
- float min = arr[0];
- for(int i = 1;i < length;i++)
- if(arr[i] < min && arr[i] > 0.001) min = arr[i];
- return min;
-}
-
-void accumulate_weighted_avg(vw& all, string master_location, regressor& reg) {
- if(!all.adaptive) {
- cerr<<"Weighted averaging is implemented only for adaptive gradient, use accumulate_avg instead\n";
- return;
- }
- uint32_t length = 1 << all.num_bits; //This is the number of parameters
- size_t stride = 1 << all.reg.stride_shift;
- weight* weights = reg.weight_vector;
-
-
- float* local_weights = new float[length];
-
- for(uint32_t i = 0;i < length;i++)
- local_weights[i] = weights[stride*i+1];
-
-
- //First compute weights for averaging
- all_reduce<float, add_float>(local_weights, length, master_location, all.unique_id, all.total, all.node, all.socks);
-
- for(uint32_t i = 0;i < length;i++) //Compute weighted versions
- if(local_weights[i] > 0) {
- float ratio = weights[stride*i+1]/local_weights[i];
- local_weights[i] = weights[stride*i] * ratio;
- weights[stride*i] *= ratio;
- weights[stride*i+1] *= ratio; //A crude max
- if (all.normalized_updates)
- weights[stride*i+all.normalized_idx] *= ratio; //A crude max
- }
- else {
- local_weights[i] = 0;
- weights[stride*i] = 0;
- }
-
- all_reduce<float, add_float>(weights, length*stride, master_location, all.unique_id, all.total, all.node, all.socks);
-
- delete[] local_weights;
-}
-
+/*
+Copyright (c) by respective owners including Yahoo!, Microsoft, and
+individual contributors. All rights reserved. Released under a BSD (revised)
+license as described in the file LICENSE.
+ */
+/*
+This implements the allreduce function of MPI. Code primarily by
+Alekh Agarwal and John Langford, with help Olivier Chapelle.
+*/
+
+#include <iostream>
+#include <sys/timeb.h>
+#include <cmath>
+#include <stdint.h>
+#include "accumulate.h"
+#include "global_data.h"
+
+using namespace std;
+
+void add_float(float& c1, const float& c2) {
+ c1 += c2;
+}
+
+void accumulate(vw& all, string master_location, regressor& reg, size_t o) {
+ uint32_t length = 1 << all.num_bits; //This is size of gradient
+ size_t stride = 1 << all.reg.stride_shift;
+ float* local_grad = new float[length];
+ weight* weights = reg.weight_vector;
+ for(uint32_t i = 0;i < length;i++)
+ {
+ local_grad[i] = weights[stride*i+o];
+ }
+
+ all_reduce<float, add_float>(local_grad, length, master_location, all.unique_id, all.total, all.node, all.socks);
+ for(uint32_t i = 0;i < length;i++)
+ {
+ weights[stride*i+o] = local_grad[i];
+ }
+ delete[] local_grad;
+}
+
+float accumulate_scalar(vw& all, string master_location, float local_sum) {
+ float temp = local_sum;
+ all_reduce<float, add_float>(&temp, 1, master_location, all.unique_id, all.total, all.node, all.socks);
+ return temp;
+}
+
+void accumulate_avg(vw& all, string master_location, regressor& reg, size_t o) {
+ uint32_t length = 1 << all.num_bits; //This is size of gradient
+ size_t stride = 1 << all.reg.stride_shift;
+ float* local_grad = new float[length];
+ weight* weights = reg.weight_vector;
+ float numnodes = (float)all.total;
+
+ for(uint32_t i = 0;i < length;i++)
+ local_grad[i] = weights[stride*i+o];
+
+ all_reduce<float, add_float>(local_grad, length, master_location, all.unique_id, all.total, all.node, all.socks);
+ for(uint32_t i = 0;i < length;i++)
+ weights[stride*i+o] = local_grad[i]/numnodes;
+ delete[] local_grad;
+}
+
+float max_elem(float* arr, int length) {
+ float max = arr[0];
+ for(int i = 1;i < length;i++)
+ if(arr[i] > max) max = arr[i];
+ return max;
+}
+
+float min_elem(float* arr, int length) {
+ float min = arr[0];
+ for(int i = 1;i < length;i++)
+ if(arr[i] < min && arr[i] > 0.001) min = arr[i];
+ return min;
+}
+
+void accumulate_weighted_avg(vw& all, string master_location, regressor& reg) {
+ if(!all.adaptive) {
+ cerr<<"Weighted averaging is implemented only for adaptive gradient, use accumulate_avg instead\n";
+ return;
+ }
+ uint32_t length = 1 << all.num_bits; //This is the number of parameters
+ size_t stride = 1 << all.reg.stride_shift;
+ weight* weights = reg.weight_vector;
+
+
+ float* local_weights = new float[length];
+
+ for(uint32_t i = 0;i < length;i++)
+ local_weights[i] = weights[stride*i+1];
+
+
+ //First compute weights for averaging
+ all_reduce<float, add_float>(local_weights, length, master_location, all.unique_id, all.total, all.node, all.socks);
+
+ for(uint32_t i = 0;i < length;i++) //Compute weighted versions
+ if(local_weights[i] > 0) {
+ float ratio = weights[stride*i+1]/local_weights[i];
+ local_weights[i] = weights[stride*i] * ratio;
+ weights[stride*i] *= ratio;
+ weights[stride*i+1] *= ratio; //A crude max
+ if (all.normalized_updates)
+ weights[stride*i+all.normalized_idx] *= ratio; //A crude max
+ }
+ else {
+ local_weights[i] = 0;
+ weights[stride*i] = 0;
+ }
+
+ all_reduce<float, add_float>(weights, length*stride, master_location, all.unique_id, all.total, all.node, all.socks);
+
+ delete[] local_weights;
+}
+
diff --git a/vowpalwabbit/accumulate.h b/vowpalwabbit/accumulate.h
index c01ac5fe..4d507a60 100644
--- a/vowpalwabbit/accumulate.h
+++ b/vowpalwabbit/accumulate.h
@@ -1,13 +1,13 @@
-/*
-Copyright (c) by respective owners including Yahoo!, Microsoft, and
-individual contributors. All rights reserved. Released under a BSD
-license as described in the file LICENSE.
- */
-//This implements various accumulate functions building on top of allreduce.
-#pragma once
-#include "global_data.h"
-
-void accumulate(vw& all, std::string master_location, regressor& reg, size_t o);
-float accumulate_scalar(vw& all, std::string master_location, float local_sum);
-void accumulate_weighted_avg(vw& all, std::string master_location, regressor& reg);
-void accumulate_avg(vw& all, std::string master_location, regressor& reg, size_t o);
+/*
+Copyright (c) by respective owners including Yahoo!, Microsoft, and
+individual contributors. All rights reserved. Released under a BSD
+license as described in the file LICENSE.
+ */
+//This implements various accumulate functions building on top of allreduce.
+#pragma once
+#include "global_data.h"
+
+void accumulate(vw& all, std::string master_location, regressor& reg, size_t o);
+float accumulate_scalar(vw& all, std::string master_location, float local_sum);
+void accumulate_weighted_avg(vw& all, std::string master_location, regressor& reg);
+void accumulate_avg(vw& all, std::string master_location, regressor& reg, size_t o);
diff --git a/vowpalwabbit/lda_core.cc b/vowpalwabbit/lda_core.cc
index 90f9e171..23ca9bfe 100644
--- a/vowpalwabbit/lda_core.cc
+++ b/vowpalwabbit/lda_core.cc
@@ -1,802 +1,802 @@
-/*
-Copyright (c) by respective owners including Yahoo!, Microsoft, and
-individual contributors. All rights reserved. Released under a BSD (revised)
-license as described in the file LICENSE.
- */
-#include <fstream>
-#include <vector>
-#include <float.h>
-#ifdef _WIN32
-#include <winsock2.h>
-#else
-#include <netdb.h>
-#endif
-#include <string.h>
-#include <stdio.h>
-#include <assert.h>
-#include "constant.h"
-#include "gd.h"
-#include "simple_label.h"
-#include "rand48.h"
-#include "reductions.h"
-
-using namespace LEARNER;
-using namespace std;
-
-namespace LDA {
-
-class index_feature {
-public:
- uint32_t document;
- feature f;
- bool operator<(const index_feature b) const { return f.weight_index < b.f.weight_index; }
-};
-
- struct lda {
- v_array<float> Elogtheta;
- v_array<float> decay_levels;
- v_array<float> total_new;
- v_array<example* > examples;
- v_array<float> total_lambda;
- v_array<int> doc_lengths;
- v_array<float> digammas;
- v_array<float> v;
- vector<index_feature> sorted_features;
-
- bool total_lambda_init;
-
- double example_t;
- vw* all;
- };
-
-#ifdef _WIN32
-inline float fmax(float f1, float f2) { return (f1 < f2 ? f2 : f1); }
-inline float fmin(float f1, float f2) { return (f1 > f2 ? f2 : f1); }
-#endif
-
-#define MINEIRO_SPECIAL
-#ifdef MINEIRO_SPECIAL
-
-namespace {
-
-inline float
-fastlog2 (float x)
-{
- union { float f; uint32_t i; } vx = { x };
- union { uint32_t i; float f; } mx = { (vx.i & 0x007FFFFF) | (0x7e << 23) };
- float y = (float)vx.i;
- y *= 1.0f / (float)(1 << 23);
-
- return
- y - 124.22544637f - 1.498030302f * mx.f - 1.72587999f / (0.3520887068f + mx.f);
-}
-
-inline float
-fastlog (float x)
-{
- return 0.69314718f * fastlog2 (x);
-}
-
-inline float
-fastpow2 (float p)
-{
- float offset = (p < 0) ? 1.0f : 0.0f;
- float clipp = (p < -126) ? -126.0f : p;
- int w = (int)clipp;
- float z = clipp - w + offset;
- union { uint32_t i; float f; } v = { (uint32_t)((1 << 23) * (clipp + 121.2740838f + 27.7280233f / (4.84252568f - z) - 1.49012907f * z)) };
-
- return v.f;
-}
-
-inline float
-fastexp (float p)
-{
- return fastpow2 (1.442695040f * p);
-}
-
-inline float
-fastpow (float x,
- float p)
-{
- return fastpow2 (p * fastlog2 (x));
-}
-
-inline float
-fastlgamma (float x)
-{
- float logterm = fastlog (x * (1.0f + x) * (2.0f + x));
- float xp3 = 3.0f + x;
-
- return
- -2.081061466f - x + 0.0833333f / xp3 - logterm + (2.5f + x) * fastlog (xp3);
-}
-
-inline float
-fastdigamma (float x)
-{
- float twopx = 2.0f + x;
- float logterm = fastlog (twopx);
-
- return - (1.0f + 2.0f * x) / (x * (1.0f + x))
- - (13.0f + 6.0f * x) / (12.0f * twopx * twopx)
- + logterm;
-}
-
-#define log fastlog
-#define exp fastexp
-#define powf fastpow
-#define mydigamma fastdigamma
-#define mylgamma fastlgamma
-
-#if defined(__SSE2__) && !defined(VW_LDA_NO_SSE)
-
-#include <emmintrin.h>
-
-typedef __m128 v4sf;
-typedef __m128i v4si;
-
-#define v4si_to_v4sf _mm_cvtepi32_ps
-#define v4sf_to_v4si _mm_cvttps_epi32
-
-static inline float
-v4sf_index (const v4sf x,
- unsigned int i)
-{
- union { v4sf f; float array[4]; } tmp = { x };
-
- return tmp.array[i];
-}
-
-static inline const v4sf
-v4sfl (float x)
-{
- union { float array[4]; v4sf f; } tmp = { { x, x, x, x } };
-
- return tmp.f;
-}
-
-static inline const v4si
-v4sil (uint32_t x)
-{
- uint64_t wide = (((uint64_t) x) << 32) | x;
- union { uint64_t array[2]; v4si f; } tmp = { { wide, wide } };
-
- return tmp.f;
-}
-
-static inline v4sf
-vfastpow2 (const v4sf p)
-{
- v4sf ltzero = _mm_cmplt_ps (p, v4sfl (0.0f));
- v4sf offset = _mm_and_ps (ltzero, v4sfl (1.0f));
- v4sf lt126 = _mm_cmplt_ps (p, v4sfl (-126.0f));
- v4sf clipp = _mm_andnot_ps (lt126, p) + _mm_and_ps (lt126, v4sfl (-126.0f));
- v4si w = v4sf_to_v4si (clipp);
- v4sf z = clipp - v4si_to_v4sf (w) + offset;
-
- const v4sf c_121_2740838 = v4sfl (121.2740838f);
- const v4sf c_27_7280233 = v4sfl (27.7280233f);
- const v4sf c_4_84252568 = v4sfl (4.84252568f);
- const v4sf c_1_49012907 = v4sfl (1.49012907f);
- union { v4si i; v4sf f; } v = {
- v4sf_to_v4si (
- v4sfl (1 << 23) *
- (clipp + c_121_2740838 + c_27_7280233 / (c_4_84252568 - z) - c_1_49012907 * z)
- )
- };
-
- return v.f;
-}
-
-inline v4sf
-vfastexp (const v4sf p)
-{
- const v4sf c_invlog_2 = v4sfl (1.442695040f);
-
- return vfastpow2 (c_invlog_2 * p);
-}
-
-inline v4sf
-vfastlog2 (v4sf x)
-{
- union { v4sf f; v4si i; } vx = { x };
- union { v4si i; v4sf f; } mx = { (vx.i & v4sil (0x007FFFFF)) | v4sil (0x3f000000) };
- v4sf y = v4si_to_v4sf (vx.i);
- y *= v4sfl (1.1920928955078125e-7f);
-
- const v4sf c_124_22551499 = v4sfl (124.22551499f);
- const v4sf c_1_498030302 = v4sfl (1.498030302f);
- const v4sf c_1_725877999 = v4sfl (1.72587999f);
- const v4sf c_0_3520087068 = v4sfl (0.3520887068f);
-
- return y - c_124_22551499
- - c_1_498030302 * mx.f
- - c_1_725877999 / (c_0_3520087068 + mx.f);
-}
-
-inline v4sf
-vfastlog (v4sf x)
-{
- const v4sf c_0_69314718 = v4sfl (0.69314718f);
-
- return c_0_69314718 * vfastlog2 (x);
-}
-
-inline v4sf
-vfastdigamma (v4sf x)
-{
- v4sf twopx = v4sfl (2.0f) + x;
- v4sf logterm = vfastlog (twopx);
-
- return (v4sfl (-48.0f) + x * (v4sfl (-157.0f) + x * (v4sfl (-127.0f) - v4sfl (30.0f) * x))) /
- (v4sfl (12.0f) * x * (v4sfl (1.0f) + x) * twopx * twopx)
- + logterm;
-}
-
-void
-vexpdigammify (vw& all, float* gamma)
-{
- unsigned int n = all.lda;
- float extra_sum = 0.0f;
- v4sf sum = v4sfl (0.0f);
- size_t i;
-
- for (i = 0; i < n && ((uintptr_t) (gamma + i)) % 16 > 0; ++i)
- {
- extra_sum += gamma[i];
- gamma[i] = fastdigamma (gamma[i]);
- }
-
- for (; i + 4 < n; i += 4)
- {
- v4sf arg = _mm_load_ps (gamma + i);
- sum += arg;
- arg = vfastdigamma (arg);
- _mm_store_ps (gamma + i, arg);
- }
-
- for (; i < n; ++i)
- {
- extra_sum += gamma[i];
- gamma[i] = fastdigamma (gamma[i]);
- }
-
- extra_sum += v4sf_index (sum, 0) + v4sf_index (sum, 1) +
- v4sf_index (sum, 2) + v4sf_index (sum, 3);
- extra_sum = fastdigamma (extra_sum);
- sum = v4sfl (extra_sum);
-
- for (i = 0; i < n && ((uintptr_t) (gamma + i)) % 16 > 0; ++i)
- {
- gamma[i] = fmaxf (1e-10f, fastexp (gamma[i] - v4sf_index (sum, 0)));
- }
-
- for (; i + 4 < n; i += 4)
- {
- v4sf arg = _mm_load_ps (gamma + i);
- arg -= sum;
- arg = vfastexp (arg);
- arg = _mm_max_ps (v4sfl (1e-10f), arg);
- _mm_store_ps (gamma + i, arg);
- }
-
- for (; i < n; ++i)
- {
- gamma[i] = fmaxf (1e-10f, fastexp (gamma[i] - v4sf_index (sum, 0)));
- }
-}
-
-void vexpdigammify_2(vw& all, float* gamma, const float* norm)
-{
- size_t n = all.lda;
- size_t i;
-
- for (i = 0; i < n && ((uintptr_t) (gamma + i)) % 16 > 0; ++i)
- {
- gamma[i] = fmaxf (1e-10f, fastexp (fastdigamma (gamma[i]) - norm[i]));
- }
-
- for (; i + 4 < n; i += 4)
- {
- v4sf arg = _mm_load_ps (gamma + i);
- arg = vfastdigamma (arg);
- v4sf vnorm = _mm_loadu_ps (norm + i);
- arg -= vnorm;
- arg = vfastexp (arg);
- arg = _mm_max_ps (v4sfl (1e-10f), arg);
- _mm_store_ps (gamma + i, arg);
- }
-
- for (; i < n; ++i)
- {
- gamma[i] = fmaxf (1e-10f, fastexp (fastdigamma (gamma[i]) - norm[i]));
- }
-}
-
-#define myexpdigammify vexpdigammify
-#define myexpdigammify_2 vexpdigammify_2
-
-#else
-#ifndef _WIN32
-#warning "lda IS NOT using sse instructions"
-#endif
-#define myexpdigammify expdigammify
-#define myexpdigammify_2 expdigammify_2
-
-#endif // __SSE2__
-
-} // end anonymous namespace
-
-#else
-
-#include <boost/math/special_functions/digamma.hpp>
-#include <boost/math/special_functions/gamma.hpp>
-
-using namespace boost::math::policies;
-
-#define mydigamma boost::math::digamma
-#define mylgamma boost::math::lgamma
-#define myexpdigammify expdigammify
-#define myexpdigammify_2 expdigammify_2
-
-#endif // MINEIRO_SPECIAL
-
-float decayfunc(float t, float old_t, float power_t) {
- float result = 1;
- for (float i = old_t+1; i <= t; i += 1)
- result *= (1-powf(i, -power_t));
- return result;
-}
-
-float decayfunc2(float t, float old_t, float power_t)
-{
- float power_t_plus_one = 1.f - power_t;
- float arg = - ( powf(t, power_t_plus_one) -
- powf(old_t, power_t_plus_one));
- return exp ( arg
- / power_t_plus_one);
-}
-
-float decayfunc3(double t, double old_t, double power_t)
-{
- double power_t_plus_one = 1. - power_t;
- double logt = log((float)t);
- double logoldt = log((float)old_t);
- return (float)((old_t / t) * exp((float)(0.5*power_t_plus_one*(-logt*logt + logoldt*logoldt))));
-}
-
-float decayfunc4(double t, double old_t, double power_t)
-{
- if (power_t > 0.99)
- return decayfunc3(t, old_t, power_t);
- else
- return (float)decayfunc2((float)t, (float)old_t, (float)power_t);
-}
-
-void expdigammify(vw& all, float* gamma)
-{
- float sum=0;
- for (size_t i = 0; i<all.lda; i++)
- {
- sum += gamma[i];
- gamma[i] = mydigamma(gamma[i]);
- }
- sum = mydigamma(sum);
- for (size_t i = 0; i<all.lda; i++)
- gamma[i] = fmax(1e-6f, exp(gamma[i] - sum));
-}
-
-void expdigammify_2(vw& all, float* gamma, float* norm)
-{
- for (size_t i = 0; i<all.lda; i++)
- {
- gamma[i] = fmax(1e-6f, exp(mydigamma(gamma[i]) - norm[i]));
- }
-}
-
-float average_diff(vw& all, float* oldgamma, float* newgamma)
-{
- float sum = 0.;
- float normalizer = 0.;
- for (size_t i = 0; i<all.lda; i++) {
- sum += fabsf(oldgamma[i] - newgamma[i]);
- normalizer += newgamma[i];
- }
- return sum / normalizer;
-}
-
-// Returns E_q[log p(\theta)] - E_q[log q(\theta)].
- float theta_kl(vw& all, v_array<float>& Elogtheta, float* gamma)
-{
- float gammasum = 0;
- Elogtheta.erase();
- for (size_t k = 0; k < all.lda; k++) {
- Elogtheta.push_back(mydigamma(gamma[k]));
- gammasum += gamma[k];
- }
- float digammasum = mydigamma(gammasum);
- gammasum = mylgamma(gammasum);
- float kl = -(all.lda*mylgamma(all.lda_alpha));
- kl += mylgamma(all.lda_alpha*all.lda) - gammasum;
- for (size_t k = 0; k < all.lda; k++) {
- Elogtheta[k] -= digammasum;
- kl += (all.lda_alpha - gamma[k]) * Elogtheta[k];
- kl += mylgamma(gamma[k]);
- }
-
- return kl;
-}
-
-float find_cw(vw& all, float* u_for_w, float* v)
-{
- float c_w = 0;
- for (size_t k =0; k<all.lda; k++)
- c_w += u_for_w[k]*v[k];
-
- return 1.f / c_w;
-}
-
- v_array<float> new_gamma = v_init<float>();
- v_array<float> old_gamma = v_init<float>();
-// Returns an estimate of the part of the variational bound that
-// doesn't have to do with beta for the entire corpus for the current
-// setting of lambda based on the document passed in. The value is
-// divided by the total number of words in the document This can be
-// used as a (possibly very noisy) estimate of held-out likelihood.
- float lda_loop(vw& all, v_array<float>& Elogtheta, float* v,weight* weights,example* ec, float power_t)
-{
- new_gamma.erase();
- old_gamma.erase();
-
- for (size_t i = 0; i < all.lda; i++)
- {
- new_gamma.push_back(1.f);
- old_gamma.push_back(0.f);
- }
- size_t num_words =0;
- for (unsigned char* i = ec->indices.begin; i != ec->indices.end; i++)
- num_words += ec->atomics[*i].end - ec->atomics[*i].begin;
-
- float xc_w = 0;
- float score = 0;
- float doc_length = 0;
- do
- {
- memcpy(v,new_gamma.begin,sizeof(float)*all.lda);
- myexpdigammify(all, v);
-
- memcpy(old_gamma.begin,new_gamma.begin,sizeof(float)*all.lda);
- memset(new_gamma.begin,0,sizeof(float)*all.lda);
-
- score = 0;
- size_t word_count = 0;
- doc_length = 0;
- for (unsigned char* i = ec->indices.begin; i != ec->indices.end; i++)
- {
- feature *f = ec->atomics[*i].begin;
- for (; f != ec->atomics[*i].end; f++)
- {
- float* u_for_w = &weights[(f->weight_index&all.reg.weight_mask)+all.lda+1];
- float c_w = find_cw(all, u_for_w,v);
- xc_w = c_w * f->x;
- score += -f->x*log(c_w);
- size_t max_k = all.lda;
- for (size_t k =0; k<max_k; k++) {
- new_gamma[k] += xc_w*u_for_w[k];
- }
- word_count++;
- doc_length += f->x;
- }
- }
- for (size_t k =0; k<all.lda; k++)
- new_gamma[k] = new_gamma[k]*v[k]+all.lda_alpha;
- }
- while (average_diff(all, old_gamma.begin, new_gamma.begin) > all.lda_epsilon);
-
- ec->topic_predictions.erase();
- ec->topic_predictions.resize(all.lda);
- memcpy(ec->topic_predictions.begin,new_gamma.begin,all.lda*sizeof(float));
-
- score += theta_kl(all, Elogtheta, new_gamma.begin);
-
- return score / doc_length;
-}
-
-size_t next_pow2(size_t x) {
- int i = 0;
- x = x > 0 ? x - 1 : 0;
- while (x > 0) {
- x >>= 1;
- i++;
- }
- return ((size_t)1) << i;
-}
-
-void save_load(lda& l, io_buf& model_file, bool read, bool text)
-{
- vw* all = l.all;
- uint32_t length = 1 << all->num_bits;
- uint32_t stride = 1 << all->reg.stride_shift;
-
- if (read)
- {
- initialize_regressor(*all);
- for (size_t j = 0; j < stride*length; j+=stride)
- {
- for (size_t k = 0; k < all->lda; k++) {
- if (all->random_weights) {
- all->reg.weight_vector[j+k] = (float)(-log(frand48()) + 1.0f);
- all->reg.weight_vector[j+k] *= (float)(all->lda_D / all->lda / all->length() * 200);
- }
- }
- all->reg.weight_vector[j+all->lda] = all->initial_t;
- }
- }
-
- if (model_file.files.size() > 0)
- {
- uint32_t i = 0;
- uint32_t text_len;
- char buff[512];
- size_t brw = 1;
- do
- {
- brw = 0;
- size_t K = all->lda;
-
- text_len = sprintf(buff, "%d ", i);
- brw += bin_text_read_write_fixed(model_file,(char *)&i, sizeof (i),
- "", read,
- buff, text_len, text);
- if (brw != 0)
- for (uint32_t k = 0; k < K; k++)
- {
- uint32_t ndx = stride*i+k;
-
- weight* v = &(all->reg.weight_vector[ndx]);
- text_len = sprintf(buff, "%f ", *v + all->lda_rho);
-
- brw += bin_text_read_write_fixed(model_file,(char *)v, sizeof (*v),
- "", read,
- buff, text_len, text);
-
- }
- if (text)
- brw += bin_text_read_write_fixed(model_file,buff,0,
- "", read,
- "\n",1,text);
-
- if (!read)
- i++;
- }
- while ((!read && i < length) || (read && brw >0));
- }
-}
-
- void learn_batch(lda& l)
- {
- if (l.sorted_features.empty()) {
- // This can happen when the socket connection is dropped by the client.
- // If l.sorted_features is empty, then l.sorted_features[0] does not
- // exist, so we should not try to take its address in the beginning of
- // the for loops down there. Since it seems that there's not much to
- // do in this case, we just return.
- for (size_t d = 0; d < l.examples.size(); d++)
- return_simple_example(*l.all, NULL, *l.examples[d]);
- l.examples.erase();
- return;
- }
-
- float eta = -1;
- float minuseta = -1;
-
- if (l.total_lambda.size() == 0)
- {
- for (size_t k = 0; k < l.all->lda; k++)
- l.total_lambda.push_back(0.f);
-
- size_t stride = 1 << l.all->reg.stride_shift;
- for (size_t i =0; i <= l.all->reg.weight_mask;i+=stride)
- for (size_t k = 0; k < l.all->lda; k++)
- l.total_lambda[k] += l.all->reg.weight_vector[i+k];
- }
-
- l.example_t++;
- l.total_new.erase();
- for (size_t k = 0; k < l.all->lda; k++)
- l.total_new.push_back(0.f);
-
- size_t batch_size = l.examples.size();
-
- sort(l.sorted_features.begin(), l.sorted_features.end());
-
- eta = l.all->eta * powf((float)l.example_t, - l.all->power_t);
- minuseta = 1.0f - eta;
- eta *= l.all->lda_D / batch_size;
- l.decay_levels.push_back(l.decay_levels.last() + log(minuseta));
-
- l.digammas.erase();
- float additional = (float)(l.all->length()) * l.all->lda_rho;
- for (size_t i = 0; i<l.all->lda; i++) {
- l.digammas.push_back(mydigamma(l.total_lambda[i] + additional));
- }
-
-
- weight* weights = l.all->reg.weight_vector;
-
- size_t last_weight_index = -1;
- for (index_feature* s = &l.sorted_features[0]; s <= &l.sorted_features.back(); s++)
- {
- if (last_weight_index == s->f.weight_index)
- continue;
- last_weight_index = s->f.weight_index;
- float* weights_for_w = &(weights[s->f.weight_index & l.all->reg.weight_mask]);
- float decay = fmin(1.0, exp(l.decay_levels.end[-2] - l.decay_levels.end[(int)(-1 - l.example_t+weights_for_w[l.all->lda])]));
- float* u_for_w = weights_for_w + l.all->lda+1;
-
- weights_for_w[l.all->lda] = (float)l.example_t;
- for (size_t k = 0; k < l.all->lda; k++)
- {
- weights_for_w[k] *= decay;
- u_for_w[k] = weights_for_w[k] + l.all->lda_rho;
- }
- myexpdigammify_2(*l.all, u_for_w, l.digammas.begin);
- }
-
- for (size_t d = 0; d < batch_size; d++)
- {
- float score = lda_loop(*l.all, l.Elogtheta, &(l.v[d*l.all->lda]), weights, l.examples[d],l.all->power_t);
- if (l.all->audit)
- GD::print_audit_features(*l.all, *l.examples[d]);
- // If the doc is empty, give it loss of 0.
- if (l.doc_lengths[d] > 0) {
- l.all->sd->sum_loss -= score;
- l.all->sd->sum_loss_since_last_dump -= score;
- }
- return_simple_example(*l.all, NULL, *l.examples[d]);
- }
-
- for (index_feature* s = &l.sorted_features[0]; s <= &l.sorted_features.back();)
- {
- index_feature* next = s+1;
- while(next <= &l.sorted_features.back() && next->f.weight_index == s->f.weight_index)
- next++;
-
- float* word_weights = &(weights[s->f.weight_index & l.all->reg.weight_mask]);
- for (size_t k = 0; k < l.all->lda; k++) {
- float new_value = minuseta*word_weights[k];
- word_weights[k] = new_value;
- }
-
- for (; s != next; s++) {
- float* v_s = &(l.v[s->document*l.all->lda]);
- float* u_for_w = &weights[(s->f.weight_index & l.all->reg.weight_mask) + l.all->lda + 1];
- float c_w = eta*find_cw(*l.all, u_for_w, v_s)*s->f.x;
- for (size_t k = 0; k < l.all->lda; k++) {
- float new_value = u_for_w[k]*v_s[k]*c_w;
- l.total_new[k] += new_value;
- word_weights[k] += new_value;
- }
- }
- }
- for (size_t k = 0; k < l.all->lda; k++) {
- l.total_lambda[k] *= minuseta;
- l.total_lambda[k] += l.total_new[k];
- }
-
- l.sorted_features.resize(0);
-
- l.examples.erase();
- l.doc_lengths.erase();
- }
-
- void learn(lda& l, learner& base, example& ec)
- {
- size_t num_ex = l.examples.size();
- l.examples.push_back(&ec);
- l.doc_lengths.push_back(0);
- for (unsigned char* i = ec.indices.begin; i != ec.indices.end; i++) {
- feature* f = ec.atomics[*i].begin;
- for (; f != ec.atomics[*i].end; f++) {
- index_feature temp = {(uint32_t)num_ex, *f};
- l.sorted_features.push_back(temp);
- l.doc_lengths[num_ex] += (int)f->x;
- }
- }
- if (++num_ex == l.all->minibatch)
- learn_batch(l);
- }
-
- // placeholder
- void predict(lda& l, learner& base, example& ec)
- {
- learn(l, base, ec);
- }
-
- void end_pass(lda& l)
- {
- if (l.examples.size())
- learn_batch(l);
- }
-
-void end_examples(lda& l)
-{
- for (size_t i = 0; i < l.all->length(); i++) {
- weight* weights_for_w = & (l.all->reg.weight_vector[i << l.all->reg.stride_shift]);
- float decay = fmin(1.0, exp(l.decay_levels.last() - l.decay_levels.end[(int)(-1- l.example_t +weights_for_w[l.all->lda])]));
- for (size_t k = 0; k < l.all->lda; k++)
- weights_for_w[k] *= decay;
- }
-}
-
- void finish_example(vw& all, lda&, example& ec)
-{}
-
- void finish(lda& ld)
- {
- ld.sorted_features.~vector<index_feature>();
- ld.Elogtheta.delete_v();
- ld.decay_levels.delete_v();
- ld.total_new.delete_v();
- ld.examples.delete_v();
- ld.total_lambda.delete_v();
- ld.doc_lengths.delete_v();
- ld.digammas.delete_v();
- ld.v.delete_v();
- }
-
-learner* setup(vw&all, po::variables_map& vm)
-{
- lda* ld = (lda*)calloc_or_die(1,sizeof(lda));
- ld->sorted_features = vector<index_feature>();
- ld->total_lambda_init = 0;
- ld->all = &all;
- ld->example_t = all.initial_t;
-
- po::options_description lda_opts("LDA options");
- lda_opts.add_options()
- ("lda_alpha", po::value<float>(&all.lda_alpha), "Prior on sparsity of per-document topic weights")
- ("lda_rho", po::value<float>(&all.lda_rho), "Prior on sparsity of topic distributions")
- ("lda_D", po::value<float>(&all.lda_D), "Number of documents")
- ("lda_epsilon", po::value<float>(&all.lda_epsilon), "Loop convergence threshold")
- ("minibatch", po::value<size_t>(&all.minibatch), "Minibatch size, for LDA");
-
- vm = add_options(all, lda_opts);
-
- float temp = ceilf(logf((float)(all.lda*2+1)) / logf (2.f));
- all.reg.stride_shift = (size_t)temp;
- all.random_weights = true;
- all.add_constant = false;
-
- std::stringstream ss;
- ss << " --lda " << all.lda;
- all.file_options.append(ss.str());
-
- if (all.eta > 1.)
- {
- cerr << "your learning rate is too high, setting it to 1" << endl;
- all.eta = min(all.eta,1.f);
- }
-
- if (vm.count("minibatch")) {
- size_t minibatch2 = next_pow2(all.minibatch);
- all.p->ring_size = all.p->ring_size > minibatch2 ? all.p->ring_size : minibatch2;
- }
-
- ld->v.resize(all.lda*all.minibatch);
-
- ld->decay_levels.push_back(0.f);
-
- learner* l = new learner(ld, 1 << all.reg.stride_shift);
- l->set_learn<lda,learn>();
- l->set_predict<lda,predict>();
- l->set_save_load<lda,save_load>();
- l->set_finish_example<lda,finish_example>();
- l->set_end_examples<lda,end_examples>();
- l->set_end_pass<lda,end_pass>();
- l->set_finish<lda,finish>();
-
- return l;
-}
-}
+/*
+Copyright (c) by respective owners including Yahoo!, Microsoft, and
+individual contributors. All rights reserved. Released under a BSD (revised)
+license as described in the file LICENSE.
+ */
+#include <fstream>
+#include <vector>
+#include <float.h>
+#ifdef _WIN32
+#include <winsock2.h>
+#else
+#include <netdb.h>
+#endif
+#include <string.h>
+#include <stdio.h>
+#include <assert.h>
+#include "constant.h"
+#include "gd.h"
+#include "simple_label.h"
+#include "rand48.h"
+#include "reductions.h"
+
+using namespace LEARNER;
+using namespace std;
+
+namespace LDA {
+
+class index_feature {
+public:
+ uint32_t document;
+ feature f;
+ bool operator<(const index_feature b) const { return f.weight_index < b.f.weight_index; }
+};
+
+ struct lda {
+ v_array<float> Elogtheta;
+ v_array<float> decay_levels;
+ v_array<float> total_new;
+ v_array<example* > examples;
+ v_array<float> total_lambda;
+ v_array<int> doc_lengths;
+ v_array<float> digammas;
+ v_array<float> v;
+ vector<index_feature> sorted_features;
+
+ bool total_lambda_init;
+
+ double example_t;
+ vw* all;
+ };
+
+#ifdef _WIN32
+inline float fmax(float f1, float f2) { return (f1 < f2 ? f2 : f1); }
+inline float fmin(float f1, float f2) { return (f1 > f2 ? f2 : f1); }
+#endif
+
+#define MINEIRO_SPECIAL
+#ifdef MINEIRO_SPECIAL
+
+namespace {
+
+inline float
+fastlog2 (float x)
+{
+ union { float f; uint32_t i; } vx = { x };
+ union { uint32_t i; float f; } mx = { (vx.i & 0x007FFFFF) | (0x7e << 23) };
+ float y = (float)vx.i;
+ y *= 1.0f / (float)(1 << 23);
+
+ return
+ y - 124.22544637f - 1.498030302f * mx.f - 1.72587999f / (0.3520887068f + mx.f);
+}
+
+inline float
+fastlog (float x)
+{
+ return 0.69314718f * fastlog2 (x);
+}
+
+inline float
+fastpow2 (float p)
+{
+ float offset = (p < 0) ? 1.0f : 0.0f;
+ float clipp = (p < -126) ? -126.0f : p;
+ int w = (int)clipp;
+ float z = clipp - w + offset;
+ union { uint32_t i; float f; } v = { (uint32_t)((1 << 23) * (clipp + 121.2740838f + 27.7280233f / (4.84252568f - z) - 1.49012907f * z)) };
+
+ return v.f;
+}
+
+inline float
+fastexp (float p)
+{
+ return fastpow2 (1.442695040f * p);
+}
+
+inline float
+fastpow (float x,
+ float p)
+{
+ return fastpow2 (p * fastlog2 (x));
+}
+
+inline float
+fastlgamma (float x)
+{
+ float logterm = fastlog (x * (1.0f + x) * (2.0f + x));
+ float xp3 = 3.0f + x;
+
+ return
+ -2.081061466f - x + 0.0833333f / xp3 - logterm + (2.5f + x) * fastlog (xp3);
+}
+
+inline float
+fastdigamma (float x)
+{
+ float twopx = 2.0f + x;
+ float logterm = fastlog (twopx);
+
+ return - (1.0f + 2.0f * x) / (x * (1.0f + x))
+ - (13.0f + 6.0f * x) / (12.0f * twopx * twopx)
+ + logterm;
+}
+
+#define log fastlog
+#define exp fastexp
+#define powf fastpow
+#define mydigamma fastdigamma
+#define mylgamma fastlgamma
+
+#if defined(__SSE2__) && !defined(VW_LDA_NO_SSE)
+
+#include <emmintrin.h>
+
+typedef __m128 v4sf;
+typedef __m128i v4si;
+
+#define v4si_to_v4sf _mm_cvtepi32_ps
+#define v4sf_to_v4si _mm_cvttps_epi32
+
+static inline float
+v4sf_index (const v4sf x,
+ unsigned int i)
+{
+ union { v4sf f; float array[4]; } tmp = { x };
+
+ return tmp.array[i];
+}
+
+static inline const v4sf
+v4sfl (float x)
+{
+ union { float array[4]; v4sf f; } tmp = { { x, x, x, x } };
+
+ return tmp.f;
+}
+
+static inline const v4si
+v4sil (uint32_t x)
+{
+ uint64_t wide = (((uint64_t) x) << 32) | x;
+ union { uint64_t array[2]; v4si f; } tmp = { { wide, wide } };
+
+ return tmp.f;
+}
+
+static inline v4sf
+vfastpow2 (const v4sf p)
+{
+ v4sf ltzero = _mm_cmplt_ps (p, v4sfl (0.0f));
+ v4sf offset = _mm_and_ps (ltzero, v4sfl (1.0f));
+ v4sf lt126 = _mm_cmplt_ps (p, v4sfl (-126.0f));
+ v4sf clipp = _mm_andnot_ps (lt126, p) + _mm_and_ps (lt126, v4sfl (-126.0f));
+ v4si w = v4sf_to_v4si (clipp);
+ v4sf z = clipp - v4si_to_v4sf (w) + offset;
+
+ const v4sf c_121_2740838 = v4sfl (121.2740838f);
+ const v4sf c_27_7280233 = v4sfl (27.7280233f);
+ const v4sf c_4_84252568 = v4sfl (4.84252568f);
+ const v4sf c_1_49012907 = v4sfl (1.49012907f);
+ union { v4si i; v4sf f; } v = {
+ v4sf_to_v4si (
+ v4sfl (1 << 23) *
+ (clipp + c_121_2740838 + c_27_7280233 / (c_4_84252568 - z) - c_1_49012907 * z)
+ )
+ };
+
+ return v.f;
+}
+
+inline v4sf
+vfastexp (const v4sf p)
+{
+ const v4sf c_invlog_2 = v4sfl (1.442695040f);
+
+ return vfastpow2 (c_invlog_2 * p);
+}
+
+inline v4sf
+vfastlog2 (v4sf x)
+{
+ union { v4sf f; v4si i; } vx = { x };
+ union { v4si i; v4sf f; } mx = { (vx.i & v4sil (0x007FFFFF)) | v4sil (0x3f000000) };
+ v4sf y = v4si_to_v4sf (vx.i);
+ y *= v4sfl (1.1920928955078125e-7f);
+
+ const v4sf c_124_22551499 = v4sfl (124.22551499f);
+ const v4sf c_1_498030302 = v4sfl (1.498030302f);
+ const v4sf c_1_725877999 = v4sfl (1.72587999f);
+ const v4sf c_0_3520087068 = v4sfl (0.3520887068f);
+
+ return y - c_124_22551499
+ - c_1_498030302 * mx.f
+ - c_1_725877999 / (c_0_3520087068 + mx.f);
+}
+
+inline v4sf
+vfastlog (v4sf x)
+{
+ const v4sf c_0_69314718 = v4sfl (0.69314718f);
+
+ return c_0_69314718 * vfastlog2 (x);
+}
+
+inline v4sf
+vfastdigamma (v4sf x)
+{
+ v4sf twopx = v4sfl (2.0f) + x;
+ v4sf logterm = vfastlog (twopx);
+
+ return (v4sfl (-48.0f) + x * (v4sfl (-157.0f) + x * (v4sfl (-127.0f) - v4sfl (30.0f) * x))) /
+ (v4sfl (12.0f) * x * (v4sfl (1.0f) + x) * twopx * twopx)
+ + logterm;
+}
+
+void
+vexpdigammify (vw& all, float* gamma)
+{
+ unsigned int n = all.lda;
+ float extra_sum = 0.0f;
+ v4sf sum = v4sfl (0.0f);
+ size_t i;
+
+ for (i = 0; i < n && ((uintptr_t) (gamma + i)) % 16 > 0; ++i)
+ {
+ extra_sum += gamma[i];
+ gamma[i] = fastdigamma (gamma[i]);
+ }
+
+ for (; i + 4 < n; i += 4)
+ {
+ v4sf arg = _mm_load_ps (gamma + i);
+ sum += arg;
+ arg = vfastdigamma (arg);
+ _mm_store_ps (gamma + i, arg);
+ }
+
+ for (; i < n; ++i)
+ {
+ extra_sum += gamma[i];
+ gamma[i] = fastdigamma (gamma[i]);
+ }
+
+ extra_sum += v4sf_index (sum, 0) + v4sf_index (sum, 1) +
+ v4sf_index (sum, 2) + v4sf_index (sum, 3);
+ extra_sum = fastdigamma (extra_sum);
+ sum = v4sfl (extra_sum);
+
+ for (i = 0; i < n && ((uintptr_t) (gamma + i)) % 16 > 0; ++i)
+ {
+ gamma[i] = fmaxf (1e-10f, fastexp (gamma[i] - v4sf_index (sum, 0)));
+ }
+
+ for (; i + 4 < n; i += 4)
+ {
+ v4sf arg = _mm_load_ps (gamma + i);
+ arg -= sum;
+ arg = vfastexp (arg);
+ arg = _mm_max_ps (v4sfl (1e-10f), arg);
+ _mm_store_ps (gamma + i, arg);
+ }
+
+ for (; i < n; ++i)
+ {
+ gamma[i] = fmaxf (1e-10f, fastexp (gamma[i] - v4sf_index (sum, 0)));
+ }
+}
+
+void vexpdigammify_2(vw& all, float* gamma, const float* norm)
+{
+ size_t n = all.lda;
+ size_t i;
+
+ for (i = 0; i < n && ((uintptr_t) (gamma + i)) % 16 > 0; ++i)
+ {
+ gamma[i] = fmaxf (1e-10f, fastexp (fastdigamma (gamma[i]) - norm[i]));
+ }
+
+ for (; i + 4 < n; i += 4)
+ {
+ v4sf arg = _mm_load_ps (gamma + i);
+ arg = vfastdigamma (arg);
+ v4sf vnorm = _mm_loadu_ps (norm + i);
+ arg -= vnorm;
+ arg = vfastexp (arg);
+ arg = _mm_max_ps (v4sfl (1e-10f), arg);
+ _mm_store_ps (gamma + i, arg);
+ }
+
+ for (; i < n; ++i)
+ {
+ gamma[i] = fmaxf (1e-10f, fastexp (fastdigamma (gamma[i]) - norm[i]));
+ }
+}
+
+#define myexpdigammify vexpdigammify
+#define myexpdigammify_2 vexpdigammify_2
+
+#else
+#ifndef _WIN32
+#warning "lda IS NOT using sse instructions"
+#endif
+#define myexpdigammify expdigammify
+#define myexpdigammify_2 expdigammify_2
+
+#endif // __SSE2__
+
+} // end anonymous namespace
+
+#else
+
+#include <boost/math/special_functions/digamma.hpp>
+#include <boost/math/special_functions/gamma.hpp>
+
+using namespace boost::math::policies;
+
+#define mydigamma boost::math::digamma
+#define mylgamma boost::math::lgamma
+#define myexpdigammify expdigammify
+#define myexpdigammify_2 expdigammify_2
+
+#endif // MINEIRO_SPECIAL
+
+float decayfunc(float t, float old_t, float power_t) {
+ float result = 1;
+ for (float i = old_t+1; i <= t; i += 1)
+ result *= (1-powf(i, -power_t));
+ return result;
+}
+
+float decayfunc2(float t, float old_t, float power_t)
+{
+ float power_t_plus_one = 1.f - power_t;
+ float arg = - ( powf(t, power_t_plus_one) -
+ powf(old_t, power_t_plus_one));
+ return exp ( arg
+ / power_t_plus_one);
+}
+
+float decayfunc3(double t, double old_t, double power_t)
+{
+ double power_t_plus_one = 1. - power_t;
+ double logt = log((float)t);
+ double logoldt = log((float)old_t);
+ return (float)((old_t / t) * exp((float)(0.5*power_t_plus_one*(-logt*logt + logoldt*logoldt))));
+}
+
+float decayfunc4(double t, double old_t, double power_t)
+{
+ if (power_t > 0.99)
+ return decayfunc3(t, old_t, power_t);
+ else
+ return (float)decayfunc2((float)t, (float)old_t, (float)power_t);
+}
+
+void expdigammify(vw& all, float* gamma)
+{
+ float sum=0;
+ for (size_t i = 0; i<all.lda; i++)
+ {
+ sum += gamma[i];
+ gamma[i] = mydigamma(gamma[i]);
+ }
+ sum = mydigamma(sum);
+ for (size_t i = 0; i<all.lda; i++)
+ gamma[i] = fmax(1e-6f, exp(gamma[i] - sum));
+}
+
+void expdigammify_2(vw& all, float* gamma, float* norm)
+{
+ for (size_t i = 0; i<all.lda; i++)
+ {
+ gamma[i] = fmax(1e-6f, exp(mydigamma(gamma[i]) - norm[i]));
+ }
+}
+
+float average_diff(vw& all, float* oldgamma, float* newgamma)
+{
+ float sum = 0.;
+ float normalizer = 0.;
+ for (size_t i = 0; i<all.lda; i++) {
+ sum += fabsf(oldgamma[i] - newgamma[i]);
+ normalizer += newgamma[i];
+ }
+ return sum / normalizer;
+}
+
+// Returns E_q[log p(\theta)] - E_q[log q(\theta)].
+ float theta_kl(vw& all, v_array<float>& Elogtheta, float* gamma)
+{
+ float gammasum = 0;
+ Elogtheta.erase();
+ for (size_t k = 0; k < all.lda; k++) {
+ Elogtheta.push_back(mydigamma(gamma[k]));
+ gammasum += gamma[k];
+ }
+ float digammasum = mydigamma(gammasum);
+ gammasum = mylgamma(gammasum);
+ float kl = -(all.lda*mylgamma(all.lda_alpha));
+ kl += mylgamma(all.lda_alpha*all.lda) - gammasum;
+ for (size_t k = 0; k < all.lda; k++) {
+ Elogtheta[k] -= digammasum;
+ kl += (all.lda_alpha - gamma[k]) * Elogtheta[k];
+ kl += mylgamma(gamma[k]);
+ }
+
+ return kl;
+}
+
+float find_cw(vw& all, float* u_for_w, float* v)
+{
+ float c_w = 0;
+ for (size_t k =0; k<all.lda; k++)
+ c_w += u_for_w[k]*v[k];
+
+ return 1.f / c_w;
+}
+
+ v_array<float> new_gamma = v_init<float>();
+ v_array<float> old_gamma = v_init<float>();
+// Returns an estimate of the part of the variational bound that
+// doesn't have to do with beta for the entire corpus for the current
+// setting of lambda based on the document passed in. The value is
+// divided by the total number of words in the document This can be
+// used as a (possibly very noisy) estimate of held-out likelihood.
+ float lda_loop(vw& all, v_array<float>& Elogtheta, float* v,weight* weights,example* ec, float power_t)
+{
+ new_gamma.erase();
+ old_gamma.erase();
+
+ for (size_t i = 0; i < all.lda; i++)
+ {
+ new_gamma.push_back(1.f);
+ old_gamma.push_back(0.f);
+ }
+ size_t num_words =0;
+ for (unsigned char* i = ec->indices.begin; i != ec->indices.end; i++)
+ num_words += ec->atomics[*i].end - ec->atomics[*i].begin;
+
+ float xc_w = 0;
+ float score = 0;
+ float doc_length = 0;
+ do
+ {
+ memcpy(v,new_gamma.begin,sizeof(float)*all.lda);
+ myexpdigammify(all, v);
+
+ memcpy(old_gamma.begin,new_gamma.begin,sizeof(float)*all.lda);
+ memset(new_gamma.begin,0,sizeof(float)*all.lda);
+
+ score = 0;
+ size_t word_count = 0;
+ doc_length = 0;
+ for (unsigned char* i = ec->indices.begin; i != ec->indices.end; i++)
+ {
+ feature *f = ec->atomics[*i].begin;
+ for (; f != ec->atomics[*i].end; f++)
+ {
+ float* u_for_w = &weights[(f->weight_index&all.reg.weight_mask)+all.lda+1];
+ float c_w = find_cw(all, u_for_w,v);
+ xc_w = c_w * f->x;
+ score += -f->x*log(c_w);
+ size_t max_k = all.lda;
+ for (size_t k =0; k<max_k; k++) {
+ new_gamma[k] += xc_w*u_for_w[k];
+ }
+ word_count++;
+ doc_length += f->x;
+ }
+ }
+ for (size_t k =0; k<all.lda; k++)
+ new_gamma[k] = new_gamma[k]*v[k]+all.lda_alpha;
+ }
+ while (average_diff(all, old_gamma.begin, new_gamma.begin) > all.lda_epsilon);
+
+ ec->topic_predictions.erase();
+ ec->topic_predictions.resize(all.lda);
+ memcpy(ec->topic_predictions.begin,new_gamma.begin,all.lda*sizeof(float));
+
+ score += theta_kl(all, Elogtheta, new_gamma.begin);
+
+ return score / doc_length;
+}
+
+size_t next_pow2(size_t x) {
+ int i = 0;
+ x = x > 0 ? x - 1 : 0;
+ while (x > 0) {
+ x >>= 1;
+ i++;
+ }
+ return ((size_t)1) << i;
+}
+
+void save_load(lda& l, io_buf& model_file, bool read, bool text)
+{
+ vw* all = l.all;
+ uint32_t length = 1 << all->num_bits;
+ uint32_t stride = 1 << all->reg.stride_shift;
+
+ if (read)
+ {
+ initialize_regressor(*all);
+ for (size_t j = 0; j < stride*length; j+=stride)
+ {
+ for (size_t k = 0; k < all->lda; k++) {
+ if (all->random_weights) {
+ all->reg.weight_vector[j+k] = (float)(-log(frand48()) + 1.0f);
+ all->reg.weight_vector[j+k] *= (float)(all->lda_D / all->lda / all->length() * 200);
+ }
+ }
+ all->reg.weight_vector[j+all->lda] = all->initial_t;
+ }
+ }
+
+ if (model_file.files.size() > 0)
+ {
+ uint32_t i = 0;
+ uint32_t text_len;
+ char buff[512];
+ size_t brw = 1;
+ do
+ {
+ brw = 0;
+ size_t K = all->lda;
+
+ text_len = sprintf(buff, "%d ", i);
+ brw += bin_text_read_write_fixed(model_file,(char *)&i, sizeof (i),
+ "", read,
+ buff, text_len, text);
+ if (brw != 0)
+ for (uint32_t k = 0; k < K; k++)
+ {
+ uint32_t ndx = stride*i+k;
+
+ weight* v = &(all->reg.weight_vector[ndx]);
+ text_len = sprintf(buff, "%f ", *v + all->lda_rho);
+
+ brw += bin_text_read_write_fixed(model_file,(char *)v, sizeof (*v),
+ "", read,
+ buff, text_len, text);
+
+ }
+ if (text)
+ brw += bin_text_read_write_fixed(model_file,buff,0,
+ "", read,
+ "\n",1,text);
+
+ if (!read)
+ i++;
+ }
+ while ((!read && i < length) || (read && brw >0));
+ }
+}
+
+ void learn_batch(lda& l)
+ {
+ if (l.sorted_features.empty()) {
+ // This can happen when the socket connection is dropped by the client.
+ // If l.sorted_features is empty, then l.sorted_features[0] does not
+ // exist, so we should not try to take its address in the beginning of
+ // the for loops down there. Since it seems that there's not much to
+ // do in this case, we just return.
+ for (size_t d = 0; d < l.examples.size(); d++)
+ return_simple_example(*l.all, NULL, *l.examples[d]);
+ l.examples.erase();
+ return;
+ }
+
+ float eta = -1;
+ float minuseta = -1;
+
+ if (l.total_lambda.size() == 0)
+ {
+ for (size_t k = 0; k < l.all->lda; k++)
+ l.total_lambda.push_back(0.f);
+
+ size_t stride = 1 << l.all->reg.stride_shift;
+ for (size_t i =0; i <= l.all->reg.weight_mask;i+=stride)
+ for (size_t k = 0; k < l.all->lda; k++)
+ l.total_lambda[k] += l.all->reg.weight_vector[i+k];
+ }
+
+ l.example_t++;
+ l.total_new.erase();
+ for (size_t k = 0; k < l.all->lda; k++)
+ l.total_new.push_back(0.f);
+
+ size_t batch_size = l.examples.size();
+
+ sort(l.sorted_features.begin(), l.sorted_features.end());
+
+ eta = l.all->eta * powf((float)l.example_t, - l.all->power_t);
+ minuseta = 1.0f - eta;
+ eta *= l.all->lda_D / batch_size;
+ l.decay_levels.push_back(l.decay_levels.last() + log(minuseta));
+
+ l.digammas.erase();
+ float additional = (float)(l.all->length()) * l.all->lda_rho;
+ for (size_t i = 0; i<l.all->lda; i++) {
+ l.digammas.push_back(mydigamma(l.total_lambda[i] + additional));
+ }
+
+
+ weight* weights = l.all->reg.weight_vector;
+
+ size_t last_weight_index = -1;
+ for (index_feature* s = &l.sorted_features[0]; s <= &l.sorted_features.back(); s++)
+ {
+ if (last_weight_index == s->f.weight_index)
+ continue;
+ last_weight_index = s->f.weight_index;
+ float* weights_for_w = &(weights[s->f.weight_index & l.all->reg.weight_mask]);
+ float decay = fmin(1.0, exp(l.decay_levels.end[-2] - l.decay_levels.end[(int)(-1 - l.example_t+weights_for_w[l.all->lda])]));
+ float* u_for_w = weights_for_w + l.all->lda+1;
+
+ weights_for_w[l.all->lda] = (float)l.example_t;
+ for (size_t k = 0; k < l.all->lda; k++)
+ {
+ weights_for_w[k] *= decay;
+ u_for_w[k] = weights_for_w[k] + l.all->lda_rho;
+ }
+ myexpdigammify_2(*l.all, u_for_w, l.digammas.begin);
+ }
+
+ for (size_t d = 0; d < batch_size; d++)
+ {
+ float score = lda_loop(*l.all, l.Elogtheta, &(l.v[d*l.all->lda]), weights, l.examples[d],l.all->power_t);
+ if (l.all->audit)
+ GD::print_audit_features(*l.all, *l.examples[d]);
+ // If the doc is empty, give it loss of 0.
+ if (l.doc_lengths[d] > 0) {
+ l.all->sd->sum_loss -= score;
+ l.all->sd->sum_loss_since_last_dump -= score;
+ }
+ return_simple_example(*l.all, NULL, *l.examples[d]);
+ }
+
+ for (index_feature* s = &l.sorted_features[0]; s <= &l.sorted_features.back();)
+ {
+ index_feature* next = s+1;
+ while(next <= &l.sorted_features.back() && next->f.weight_index == s->f.weight_index)
+ next++;
+
+ float* word_weights = &(weights[s->f.weight_index & l.all->reg.weight_mask]);
+ for (size_t k = 0; k < l.all->lda; k++) {
+ float new_value = minuseta*word_weights[k];
+ word_weights[k] = new_value;
+ }
+
+ for (; s != next; s++) {
+ float* v_s = &(l.v[s->document*l.all->lda]);
+ float* u_for_w = &weights[(s->f.weight_index & l.all->reg.weight_mask) + l.all->lda + 1];
+ float c_w = eta*find_cw(*l.all, u_for_w, v_s)*s->f.x;
+ for (size_t k = 0; k < l.all->lda; k++) {
+ float new_value = u_for_w[k]*v_s[k]*c_w;
+ l.total_new[k] += new_value;
+ word_weights[k] += new_value;
+ }
+ }
+ }
+ for (size_t k = 0; k < l.all->lda; k++) {
+ l.total_lambda[k] *= minuseta;
+ l.total_lambda[k] += l.total_new[k];
+ }
+
+ l.sorted_features.resize(0);
+
+ l.examples.erase();
+ l.doc_lengths.erase();
+ }
+
+ void learn(lda& l, learner& base, example& ec)
+ {
+ size_t num_ex = l.examples.size();
+ l.examples.push_back(&ec);
+ l.doc_lengths.push_back(0);
+ for (unsigned char* i = ec.indices.begin; i != ec.indices.end; i++) {
+ feature* f = ec.atomics[*i].begin;
+ for (; f != ec.atomics[*i].end; f++) {
+ index_feature temp = {(uint32_t)num_ex, *f};
+ l.sorted_features.push_back(temp);
+ l.doc_lengths[num_ex] += (int)f->x;
+ }
+ }
+ if (++num_ex == l.all->minibatch)
+ learn_batch(l);
+ }
+
+ // placeholder
+ void predict(lda& l, learner& base, example& ec)
+ {
+ learn(l, base, ec);
+ }
+
+ void end_pass(lda& l)
+ {
+ if (l.examples.size())
+ learn_batch(l);
+ }
+
+void end_examples(lda& l)
+{
+ for (size_t i = 0; i < l.all->length(); i++) {
+ weight* weights_for_w = & (l.all->reg.weight_vector[i << l.all->reg.stride_shift]);
+ float decay = fmin(1.0, exp(l.decay_levels.last() - l.decay_levels.end[(int)(-1- l.example_t +weights_for_w[l.all->lda])]));
+ for (size_t k = 0; k < l.all->lda; k++)
+ weights_for_w[k] *= decay;
+ }
+}
+
+ void finish_example(vw& all, lda&, example& ec)
+{}
+
+ void finish(lda& ld)
+ {
+ ld.sorted_features.~vector<index_feature>();
+ ld.Elogtheta.delete_v();
+ ld.decay_levels.delete_v();
+ ld.total_new.delete_v();
+ ld.examples.delete_v();
+ ld.total_lambda.delete_v();
+ ld.doc_lengths.delete_v();
+ ld.digammas.delete_v();
+ ld.v.delete_v();
+ }
+
+learner* setup(vw&all, po::variables_map& vm)
+{
+ lda* ld = (lda*)calloc_or_die(1,sizeof(lda));
+ ld->sorted_features = vector<index_feature>();
+ ld->total_lambda_init = 0;
+ ld->all = &all;
+ ld->example_t = all.initial_t;
+
+ po::options_description lda_opts("LDA options");
+ lda_opts.add_options()
+ ("lda_alpha", po::value<float>(&all.lda_alpha), "Prior on sparsity of per-document topic weights")
+ ("lda_rho", po::value<float>(&all.lda_rho), "Prior on sparsity of topic distributions")
+ ("lda_D", po::value<float>(&all.lda_D), "Number of documents")
+ ("lda_epsilon", po::value<float>(&all.lda_epsilon), "Loop convergence threshold")
+ ("minibatch", po::value<size_t>(&all.minibatch), "Minibatch size, for LDA");
+
+ vm = add_options(all, lda_opts);
+
+ float temp = ceilf(logf((float)(all.lda*2+1)) / logf (2.f));
+ all.reg.stride_shift = (size_t)temp;
+ all.random_weights = true;
+ all.add_constant = false;
+
+ std::stringstream ss;
+ ss << " --lda " << all.lda;
+ all.file_options.append(ss.str());
+
+ if (all.eta > 1.)
+ {
+ cerr << "your learning rate is too high, setting it to 1" << endl;
+ all.eta = min(all.eta,1.f);
+ }
+
+ if (vm.count("minibatch")) {
+ size_t minibatch2 = next_pow2(all.minibatch);
+ all.p->ring_size = all.p->ring_size > minibatch2 ? all.p->ring_size : minibatch2;
+ }
+
+ ld->v.resize(all.lda*all.minibatch);
+
+ ld->decay_levels.push_back(0.f);
+
+ learner* l = new learner(ld, 1 << all.reg.stride_shift);
+ l->set_learn<lda,learn>();
+ l->set_predict<lda,predict>();
+ l->set_save_load<lda,save_load>();
+ l->set_finish_example<lda,finish_example>();
+ l->set_end_examples<lda,end_examples>();
+ l->set_end_pass<lda,end_pass>();
+ l->set_finish<lda,finish>();
+
+ return l;
+}
+}
diff --git a/vowpalwabbit/log_multi.cc b/vowpalwabbit/log_multi.cc
index 68f52f06..bfc25288 100644
--- a/vowpalwabbit/log_multi.cc
+++ b/vowpalwabbit/log_multi.cc
@@ -1,549 +1,549 @@
-/*\t
-
-Copyright (c) by respective owners including Yahoo!, Microsoft, and
-individual contributors. All rights reserved. Released under a BSD (revised)
-license as described in the file LICENSE.node
-*/
-#include <float.h>
-#include <math.h>
-#include <stdio.h>
-#include <sstream>
-
-#include "reductions.h"
-#include "simple_label.h"
-#include "multiclass.h"
-#include "vw.h"
-
-using namespace std;
-using namespace LEARNER;
-
-namespace LOG_MULTI
-{
- class node_pred
- {
- public:
-
- double Ehk;
- float norm_Ehk;
- uint32_t nk;
- uint32_t label;
- uint32_t label_count;
-
- bool operator==(node_pred v){
- return (label == v.label);
- }
-
- bool operator>(node_pred v){
- if(label > v.label) return true;
- return false;
- }
-
- bool operator<(node_pred v){
- if(label < v.label) return true;
- return false;
- }
-
- node_pred(uint32_t l)
- {
- label = l;
- Ehk = 0.f;
- norm_Ehk = 0;
- nk = 0;
- label_count = 0;
- }
- };
-
- typedef struct
- {//everyone has
- uint32_t parent;//the parent node
- v_array<node_pred> preds;//per-class state
- uint32_t min_count;//the number of examples reaching this node (if it's a leaf) or the minimum reaching any grandchild.
-
- bool internal;//internal or leaf
-
- //internal nodes have
- uint32_t base_predictor;//id of the base predictor
- uint32_t left;//left child
- uint32_t right;//right child
- float norm_Eh;//the average margin at the node
- double Eh;//total margin at the node
- uint32_t n;//total events at the node
-
- //leaf has
- uint32_t max_count;//the number of samples of the most common label
- uint32_t max_count_label;//the most common label
- } node;
-
- struct log_multi
- {
- uint32_t k;
- vw* all;
-
- v_array<node> nodes;
-
- uint32_t max_predictors;
- uint32_t predictors_used;
-
- bool progress;
- uint32_t swap_resist;
-
- uint32_t nbofswaps;
- };
-
- inline void init_leaf(node& n)
- {
- n.internal = false;
- n.preds.erase();
- n.base_predictor = 0;
- n.norm_Eh = 0;
- n.Eh = 0;
- n.n = 0;
- n.max_count = 0;
- n.max_count_label = 1;
- n.left = 0;
- n.right = 0;
- }
-
- inline node init_node()
- {
- node node;
-
- node.parent = 0;
- node.min_count = 0;
- node.preds = v_init<node_pred>();
- init_leaf(node);
-
- return node;
- }
-
- void init_tree(log_multi& d)
- {
- d.nodes.push_back(init_node());
- d.nbofswaps = 0;
- }
-
- inline uint32_t min_left_right(log_multi& b, node& n)
- {
- return min(b.nodes[n.left].min_count, b.nodes[n.right].min_count);
- }
-
- inline uint32_t find_switch_node(log_multi& b)
- {
- uint32_t node = 0;
- while(b.nodes[node].internal)
- if(b.nodes[b.nodes[node].left].min_count
- < b.nodes[b.nodes[node].right].min_count)
- node = b.nodes[node].left;
- else
- node = b.nodes[node].right;
- return node;
- }
-
- inline void update_min_count(log_multi& b, uint32_t node)
- {//Constant time min count update.
- while(node != 0)
- {
- uint32_t prev = node;
- node = b.nodes[node].parent;
-
- if (b.nodes[node].min_count == b.nodes[prev].min_count)
- break;
- else
- b.nodes[node].min_count = min_left_right(b,b.nodes[node]);
- }
- }
-
- void display_tree_dfs(log_multi& b, node node, uint32_t depth)
- {
- for (uint32_t i = 0; i < depth; i++)
- cout << "\t";
- cout << node.min_count << " " << node.left
- << " " << node.right;
- cout << " label = " << node.max_count_label << " labels = ";
- for (size_t i = 0; i < node.preds.size(); i++)
- cout << node.preds[i].label << ":" << node.preds[i].label_count << "\t";
- cout << endl;
-
- if (node.internal)
- {
- cout << "Left";
- display_tree_dfs(b, b.nodes[node.left], depth+1);
-
- cout << "Right";
- display_tree_dfs(b, b.nodes[node.right], depth+1);
- }
- }
-
- bool children(log_multi& b, uint32_t& current, uint32_t& class_index, uint32_t label)
- {
- class_index = (uint32_t)b.nodes[current].preds.unique_add_sorted(node_pred(label));
- b.nodes[current].preds[class_index].label_count++;
-
- if(b.nodes[current].preds[class_index].label_count > b.nodes[current].max_count)
- {
- b.nodes[current].max_count = b.nodes[current].preds[class_index].label_count;
- b.nodes[current].max_count_label = b.nodes[current].preds[class_index].label;
- }
-
- if (b.nodes[current].internal)
- return true;
- else if( b.nodes[current].preds.size() > 1
- && (b.predictors_used < b.max_predictors
- || b.nodes[current].min_count - b.nodes[current].max_count > b.swap_resist*(b.nodes[0].min_count + 1)))
- { //need children and we can make them.
- uint32_t left_child;
- uint32_t right_child;
- if (b.predictors_used < b.max_predictors)
- {
- left_child = (uint32_t)b.nodes.size();
- b.nodes.push_back(init_node());
- right_child = (uint32_t)b.nodes.size();
- b.nodes.push_back(init_node());
- b.nodes[current].base_predictor = b.predictors_used++;
- }
- else
- {
- uint32_t swap_child = find_switch_node(b);
- uint32_t swap_parent = b.nodes[swap_child].parent;
- uint32_t swap_grandparent = b.nodes[swap_parent].parent;
- if (b.nodes[swap_child].min_count != b.nodes[0].min_count)
- cout << "glargh " << b.nodes[swap_child].min_count << " != " << b.nodes[0].min_count << endl;
- b.nbofswaps++;
-
- uint32_t nonswap_child;
- if(swap_child == b.nodes[swap_parent].right)
- nonswap_child = b.nodes[swap_parent].left;
- else
- nonswap_child = b.nodes[swap_parent].right;
-
- if(swap_parent == b.nodes[swap_grandparent].left)
- b.nodes[swap_grandparent].left = nonswap_child;
- else
- b.nodes[swap_grandparent].right = nonswap_child;
- b.nodes[nonswap_child].parent = swap_grandparent;
- update_min_count(b, nonswap_child);
-
- init_leaf(b.nodes[swap_child]);
- left_child = swap_child;
- b.nodes[current].base_predictor = b.nodes[swap_parent].base_predictor;
- init_leaf(b.nodes[swap_parent]);
- right_child = swap_parent;
- }
- b.nodes[current].left = left_child;
- b.nodes[left_child].parent = current;
- b.nodes[current].right = right_child;
- b.nodes[right_child].parent = current;
-
- b.nodes[left_child].min_count = b.nodes[current].min_count/2;
- b.nodes[right_child].min_count = b.nodes[current].min_count - b.nodes[left_child].min_count;
- update_min_count(b, left_child);
-
- b.nodes[left_child].max_count_label = b.nodes[current].max_count_label;
- b.nodes[right_child].max_count_label = b.nodes[current].max_count_label;
-
- b.nodes[current].internal = true;
- }
- return b.nodes[current].internal;
- }
-
- void train_node(log_multi& b, learner& base, example& ec, uint32_t& current, uint32_t& class_index)
- {
- if(b.nodes[current].norm_Eh > b.nodes[current].preds[class_index].norm_Ehk)
- ec.l.simple.label = -1.f;
- else
- ec.l.simple.label = 1.f;
-
- base.learn(ec, b.nodes[current].base_predictor);
-
- ec.l.simple.label = FLT_MAX;
- base.predict(ec, b.nodes[current].base_predictor);
-
- b.nodes[current].Eh += (double)ec.partial_prediction;
- b.nodes[current].preds[class_index].Ehk += (double)ec.partial_prediction;
- b.nodes[current].n++;
- b.nodes[current].preds[class_index].nk++;
-
- b.nodes[current].norm_Eh = (float)b.nodes[current].Eh / b.nodes[current].n;
- b.nodes[current].preds[class_index].norm_Ehk = (float)b.nodes[current].preds[class_index].Ehk / b.nodes[current].preds[class_index].nk;
- }
-
- void verify_min_dfs(log_multi& b, node node)
- {
- if (node.internal)
- {
- if (node.min_count != min_left_right(b, node))
- {
- cout << "badness! " << endl;
- display_tree_dfs(b, b.nodes[0], 0);
- }
- verify_min_dfs(b, b.nodes[node.left]);
- verify_min_dfs(b, b.nodes[node.right]);
- }
- }
-
- size_t sum_count_dfs(log_multi& b, node node)
- {
- if (node.internal)
- return sum_count_dfs(b, b.nodes[node.left]) + sum_count_dfs(b, b.nodes[node.right]);
- else
- return node.min_count;
- }
-
- inline uint32_t descend(node& n, float prediction)
- {
- if (prediction < 0)
- return n.left;
- else
- return n.right;
- }
-
- void predict(log_multi& b, learner& base, example& ec)
- {
- MULTICLASS::multiclass mc = ec.l.multi;
-
- label_data simple_temp;
- simple_temp.initial = 0.0;
- simple_temp.weight = 0.0;
- simple_temp.label = FLT_MAX;
- ec.l.simple = simple_temp;
- uint32_t cn = 0;
- while(b.nodes[cn].internal)
- {
- base.predict(ec, b.nodes[cn].base_predictor);
- cn = descend(b.nodes[cn], ec.pred.scalar);
- }
- ec.pred.multiclass = b.nodes[cn].max_count_label;
- ec.l.multi = mc;
- }
-
- void learn(log_multi& b, learner& base, example& ec)
- {
- // verify_min_dfs(b, b.nodes[0]);
-
- if (ec.l.multi.label == (uint32_t)-1 || !b.all->training || b.progress)
- predict(b,base,ec);
-
- if(b.all->training && (ec.l.multi.label != (uint32_t)-1) && !ec.test_only) //if training the tree
- {
- MULTICLASS::multiclass mc = ec.l.multi;
-
- uint32_t class_index = 0;
- label_data simple_temp;
- simple_temp.initial = 0.0;
- simple_temp.weight = mc.weight;
- ec.l.simple = simple_temp;
-
- uint32_t cn = 0;
-
- while(children(b, cn, class_index, mc.label))
- {
- train_node(b, base, ec, cn, class_index);
- cn = descend(b.nodes[cn], ec.pred.scalar);
- }
-
- b.nodes[cn].min_count++;
- update_min_count(b, cn);
-
- ec.l.multi = mc;
- }
- }
-
- void save_node_stats(log_multi& d)
- {
- FILE *fp;
- uint32_t i, j;
- uint32_t total;
- log_multi* b = &d;
-
- fp = fopen("atxm_debug.csv", "wt");
-
- for(i = 0; i < b->nodes.size(); i++)
- {
- fprintf(fp, "Node: %4d, Internal: %1d, Eh: %7.4f, n: %6d, \n", (int) i, (int) b->nodes[i].internal, b->nodes[i].Eh / b->nodes[i].n, b->nodes[i].n);
-
- fprintf(fp, "Label:, ");
- for(j = 0; j < b->nodes[i].preds.size(); j++)
- {
- fprintf(fp, "%6d,", (int) b->nodes[i].preds[j].label);
- }
- fprintf(fp, "\n");
-
- fprintf(fp, "Ehk:, ");
- for(j = 0; j < b->nodes[i].preds.size(); j++)
- {
- fprintf(fp, "%7.4f,", b->nodes[i].preds[j].Ehk / b->nodes[i].preds[j].nk);
- }
- fprintf(fp, "\n");
-
- total = 0;
-
- fprintf(fp, "nk:, ");
- for(j = 0; j < b->nodes[i].preds.size(); j++)
- {
- fprintf(fp, "%6d,", (int) b->nodes[i].preds[j].nk);
- total += b->nodes[i].preds[j].nk;
- }
- fprintf(fp, "\n");
-
- fprintf(fp, "max(lab:cnt:tot):, %3d,%6d,%7d,\n", (int) b->nodes[i].max_count_label, (int) b->nodes[i].max_count, (int) total);
- fprintf(fp, "left: %4d, right: %4d", (int) b->nodes[i].left, (int) b->nodes[i].right);
- fprintf(fp, "\n\n");
- }
-
- fclose(fp);
- }
-
- void finish(log_multi& b)
- {
- save_node_stats(b);
- cout << "used " << b.nbofswaps << " swaps" << endl;
- }
-
- void save_load_tree(log_multi& b, io_buf& model_file, bool read, bool text)
- {
- if (model_file.files.size() > 0)
- {
- char buff[512];
-
- uint32_t text_len = sprintf(buff, "k = %d ",b.k);
-      bin_text_read_write_fixed(model_file,(char*)&b.k, sizeof(b.k), "", read, buff, text_len, text);
- uint32_t temp = (uint32_t)b.nodes.size();
- text_len = sprintf(buff, "nodes = %d ",temp);
- bin_text_read_write_fixed(model_file,(char*)&temp, sizeof(temp), "", read, buff, text_len, text);
- if (read)
- for (uint32_t j = 1; j < temp; j++)
- b.nodes.push_back(init_node());
- text_len = sprintf(buff, "max_predictors = %d ",b.max_predictors);
- bin_text_read_write_fixed(model_file,(char*)&b.max_predictors, sizeof(b.max_predictors), "", read, buff, text_len, text);
-
- text_len = sprintf(buff, "predictors_used = %d ",b.predictors_used);
- bin_text_read_write_fixed(model_file,(char*)&b.predictors_used, sizeof(b.predictors_used), "", read, buff, text_len, text);
-
- text_len = sprintf(buff, "progress = %d ",b.progress);
- bin_text_read_write_fixed(model_file,(char*)&b.progress, sizeof(b.progress), "", read, buff, text_len, text);
-
- text_len = sprintf(buff, "swap_resist = %d\n",b.swap_resist);
- bin_text_read_write_fixed(model_file,(char*)&b.swap_resist, sizeof(b.swap_resist), "", read, buff, text_len, text);
-
- for (size_t j = 0; j < b.nodes.size(); j++)
- {//Need to read or write nodes.
- node& n = b.nodes[j];
- text_len = sprintf(buff, " parent = %d",n.parent);
- bin_text_read_write_fixed(model_file,(char*)&n.parent, sizeof(n.parent), "", read, buff, text_len, text);
-
- uint32_t temp = (uint32_t)n.preds.size();
- text_len = sprintf(buff, " preds = %d",temp);
- bin_text_read_write_fixed(model_file,(char*)&temp, sizeof(temp), "", read, buff, text_len, text);
- if (read)
- for (uint32_t k = 0; k < temp; k++)
- n.preds.push_back(node_pred(1));
-
- text_len = sprintf(buff, " min_count = %d",n.min_count);
- bin_text_read_write_fixed(model_file,(char*)&n.min_count, sizeof(n.min_count), "", read, buff, text_len, text);
-
-	  text_len = sprintf(buff, " internal = %d",n.internal);
-	  bin_text_read_write_fixed(model_file,(char*)&n.internal, sizeof(n.internal), "", read, buff, text_len, text);
-
- if (n.internal)
- {
- text_len = sprintf(buff, " base_predictor = %d",n.base_predictor);
- bin_text_read_write_fixed(model_file,(char*)&n.base_predictor, sizeof(n.base_predictor), "", read, buff, text_len, text);
-
- text_len = sprintf(buff, " left = %d",n.left);
- bin_text_read_write_fixed(model_file,(char*)&n.left, sizeof(n.left), "", read, buff, text_len, text);
-
- text_len = sprintf(buff, " right = %d",n.right);
- bin_text_read_write_fixed(model_file,(char*)&n.right, sizeof(n.right), "", read, buff, text_len, text);
-
- text_len = sprintf(buff, " norm_Eh = %f",n.norm_Eh);
- bin_text_read_write_fixed(model_file,(char*)&n.norm_Eh, sizeof(n.norm_Eh), "", read, buff, text_len, text);
-
- text_len = sprintf(buff, " Eh = %f",n.Eh);
- bin_text_read_write_fixed(model_file,(char*)&n.Eh, sizeof(n.Eh), "", read, buff, text_len, text);
-
- text_len = sprintf(buff, " n = %d\n",n.n);
- bin_text_read_write_fixed(model_file,(char*)&n.n, sizeof(n.n), "", read, buff, text_len, text);
- }
- else
- {
- text_len = sprintf(buff, " max_count = %d",n.max_count);
- bin_text_read_write_fixed(model_file,(char*)&n.max_count, sizeof(n.max_count), "", read, buff, text_len, text);
- text_len = sprintf(buff, " max_count_label = %d\n",n.max_count_label);
- bin_text_read_write_fixed(model_file,(char*)&n.max_count_label, sizeof(n.max_count_label), "", read, buff, text_len, text);
- }
-
- for (size_t k = 0; k < n.preds.size(); k++)
- {
- node_pred& p = n.preds[k];
-
- text_len = sprintf(buff, " Ehk = %f",p.Ehk);
- bin_text_read_write_fixed(model_file,(char*)&p.Ehk, sizeof(p.Ehk), "", read, buff, text_len, text);
-
- text_len = sprintf(buff, " norm_Ehk = %f",p.norm_Ehk);
- bin_text_read_write_fixed(model_file,(char*)&p.norm_Ehk, sizeof(p.norm_Ehk), "", read, buff, text_len, text);
-
- text_len = sprintf(buff, " nk = %d",p.nk);
- bin_text_read_write_fixed(model_file,(char*)&p.nk, sizeof(p.nk), "", read, buff, text_len, text);
-
- text_len = sprintf(buff, " label = %d",p.label);
- bin_text_read_write_fixed(model_file,(char*)&p.label, sizeof(p.label), "", read, buff, text_len, text);
-
- text_len = sprintf(buff, " label_count = %d\n",p.label_count);
- bin_text_read_write_fixed(model_file,(char*)&p.label_count, sizeof(p.label_count), "", read, buff, text_len, text);
- }
- }
- }
- }
-
- void finish_example(vw& all, log_multi&, example& ec)
- {
- MULTICLASS::output_example(all, ec);
- VW::finish_example(all, &ec);
- }
-
- learner* setup(vw& all, po::variables_map& vm) //learner setup
- {
- log_multi* data = (log_multi*)calloc(1, sizeof(log_multi));
-
- po::options_description opts("TXM Online options");
- opts.add_options()
- ("no_progress", "disable progressive validation")
- ("swap_resistance", po::value<uint32_t>(&(data->swap_resist))->default_value(4), "higher = more resistance to swap, default=4");
-
- vm = add_options(all, opts);
-
- data->k = (uint32_t)vm["log_multi"].as<size_t>();
-
- //append log_multi with nb_actions to options_from_file so it is saved to regressor later
- std::stringstream ss;
- ss << " --log_multi " << data->k;
- all.file_options.append(ss.str());
-
- if (vm.count("no_progress"))
- data->progress = false;
- else
- data->progress = true;
-
- data->all = &all;
- (all.p->lp) = MULTICLASS::mc_label;
-
- string loss_function = "quantile";
- float loss_parameter = 0.5;
- delete(all.loss);
- all.loss = getLossFunction(&all, loss_function, loss_parameter);
-
- data->max_predictors = data->k - 1;
-
- learner* l = new learner(data, all.l, data->max_predictors);
- l->set_save_load<log_multi,save_load_tree>();
- l->set_learn<log_multi,learn>();
- l->set_predict<log_multi,predict>();
- l->set_finish_example<log_multi,finish_example>();
- l->set_finish<log_multi,finish>();
-
- init_tree(*data);
-
- return l;
- }
-}
+/*
+Copyright (c) by respective owners including Yahoo!, Microsoft, and
+individual contributors. All rights reserved. Released under a BSD (revised)
+license as described in the file LICENSE.
+*/
+#include <float.h>
+#include <math.h>
+#include <stdio.h>
+#include <sstream>
+
+#include "reductions.h"
+#include "simple_label.h"
+#include "multiclass.h"
+#include "vw.h"
+
+using namespace std;
+using namespace LEARNER;
+
+namespace LOG_MULTI
+{
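+  // Per-class statistics stored at each tree node. The comparison operators
+  // below compare by label so that v_array::unique_add_sorted() (used in
+  // children()) can keep a node's preds array sorted and free of duplicates.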
+ class node_pred
+ {
+ public:
+
+ double Ehk;
+ float norm_Ehk;
+ uint32_t nk;
+ uint32_t label;
+ uint32_t label_count;
+
+ bool operator==(node_pred v){
+ return (label == v.label);
+ }
+
+ bool operator>(node_pred v){
+ if(label > v.label) return true;
+ return false;
+ }
+
+ bool operator<(node_pred v){
+ if(label < v.label) return true;
+ return false;
+ }
+
+ node_pred(uint32_t l)
+ {
+ label = l;
+ Ehk = 0.f;
+ norm_Ehk = 0;
+ nk = 0;
+ label_count = 0;
+ }
+ };
+
+ typedef struct
+ {//everyone has
+ uint32_t parent;//the parent node
+ v_array<node_pred> preds;//per-class state
+ uint32_t min_count;//the number of examples reaching this node (if it's a leaf) or the minimum reaching any grandchild.
+
+ bool internal;//internal or leaf
+
+ //internal nodes have
+ uint32_t base_predictor;//id of the base predictor
+ uint32_t left;//left child
+ uint32_t right;//right child
+ float norm_Eh;//the average margin at the node
+ double Eh;//total margin at the node
+ uint32_t n;//total events at the node
+
+ //leaf has
+ uint32_t max_count;//the number of samples of the most common label
+ uint32_t max_count_label;//the most common label
+ } node;
+
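+  // Reduction-wide state: k is the number of class labels (from --log_multi),
+  // max_predictors (k - 1) bounds the number of binary base predictors, and
+  // swap_resist controls how reluctantly a rarely used subtree is recycled.
+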
+ struct log_multi
+ {
+ uint32_t k;
+ vw* all;
+
+ v_array<node> nodes;
+
+ uint32_t max_predictors;
+ uint32_t predictors_used;
+
+ bool progress;
+ uint32_t swap_resist;
+
+ uint32_t nbofswaps;
+ };
+
+ inline void init_leaf(node& n)
+ {
+ n.internal = false;
+ n.preds.erase();
+ n.base_predictor = 0;
+ n.norm_Eh = 0;
+ n.Eh = 0;
+ n.n = 0;
+ n.max_count = 0;
+ n.max_count_label = 1;
+ n.left = 0;
+ n.right = 0;
+ }
+
+ inline node init_node()
+ {
+ node node;
+
+ node.parent = 0;
+ node.min_count = 0;
+ node.preds = v_init<node_pred>();
+ init_leaf(node);
+
+ return node;
+ }
+
+ void init_tree(log_multi& d)
+ {
+ d.nodes.push_back(init_node());
+ d.nbofswaps = 0;
+ }
+
+ inline uint32_t min_left_right(log_multi& b, node& n)
+ {
+ return min(b.nodes[n.left].min_count, b.nodes[n.right].min_count);
+ }
+
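+  // Walk from the root toward the child with the smaller min_count; the leaf
+  // reached is the least-used one, whose predictor can be recycled in a swap.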
+ inline uint32_t find_switch_node(log_multi& b)
+ {
+ uint32_t node = 0;
+ while(b.nodes[node].internal)
+ if(b.nodes[b.nodes[node].left].min_count
+ < b.nodes[b.nodes[node].right].min_count)
+ node = b.nodes[node].left;
+ else
+ node = b.nodes[node].right;
+ return node;
+ }
+
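+  // Propagate a changed count from 'node' toward the root: recompute each
+  // ancestor's min_count from its children, stopping early once an ancestor
+  // already stores the same minimum as the child just visited.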
+ inline void update_min_count(log_multi& b, uint32_t node)
+ {//Constant time min count update.
+ while(node != 0)
+ {
+ uint32_t prev = node;
+ node = b.nodes[node].parent;
+
+ if (b.nodes[node].min_count == b.nodes[prev].min_count)
+ break;
+ else
+ b.nodes[node].min_count = min_left_right(b,b.nodes[node]);
+ }
+ }
+
+ void display_tree_dfs(log_multi& b, node node, uint32_t depth)
+ {
+ for (uint32_t i = 0; i < depth; i++)
+ cout << "\t";
+ cout << node.min_count << " " << node.left
+ << " " << node.right;
+ cout << " label = " << node.max_count_label << " labels = ";
+ for (size_t i = 0; i < node.preds.size(); i++)
+ cout << node.preds[i].label << ":" << node.preds[i].label_count << "\t";
+ cout << endl;
+
+ if (node.internal)
+ {
+ cout << "Left";
+ display_tree_dfs(b, b.nodes[node.left], depth+1);
+
+ cout << "Right";
+ display_tree_dfs(b, b.nodes[node.right], depth+1);
+ }
+ }
+
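+  // Record the example's label at the current node and report whether the
+  // node routes examples further. A leaf is split once it has seen more than
+  // one label and either a spare base predictor remains (two fresh children
+  // are allocated) or the leaf is impure enough to justify a swap, i.e.
+  // min_count - max_count > swap_resist * (root min_count + 1), in which case
+  // the least-used subtree found by find_switch_node() is recycled.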
+ bool children(log_multi& b, uint32_t& current, uint32_t& class_index, uint32_t label)
+ {
+ class_index = (uint32_t)b.nodes[current].preds.unique_add_sorted(node_pred(label));
+ b.nodes[current].preds[class_index].label_count++;
+
+ if(b.nodes[current].preds[class_index].label_count > b.nodes[current].max_count)
+ {
+ b.nodes[current].max_count = b.nodes[current].preds[class_index].label_count;
+ b.nodes[current].max_count_label = b.nodes[current].preds[class_index].label;
+ }
+
+ if (b.nodes[current].internal)
+ return true;
+ else if( b.nodes[current].preds.size() > 1
+ && (b.predictors_used < b.max_predictors
+ || b.nodes[current].min_count - b.nodes[current].max_count > b.swap_resist*(b.nodes[0].min_count + 1)))
+ { //need children and we can make them.
+ uint32_t left_child;
+ uint32_t right_child;
+ if (b.predictors_used < b.max_predictors)
+ {
+ left_child = (uint32_t)b.nodes.size();
+ b.nodes.push_back(init_node());
+ right_child = (uint32_t)b.nodes.size();
+ b.nodes.push_back(init_node());
+ b.nodes[current].base_predictor = b.predictors_used++;
+ }
+ else
+ {
+ uint32_t swap_child = find_switch_node(b);
+ uint32_t swap_parent = b.nodes[swap_child].parent;
+ uint32_t swap_grandparent = b.nodes[swap_parent].parent;
+ if (b.nodes[swap_child].min_count != b.nodes[0].min_count)
+ cout << "glargh " << b.nodes[swap_child].min_count << " != " << b.nodes[0].min_count << endl;
+ b.nbofswaps++;
+
+ uint32_t nonswap_child;
+ if(swap_child == b.nodes[swap_parent].right)
+ nonswap_child = b.nodes[swap_parent].left;
+ else
+ nonswap_child = b.nodes[swap_parent].right;
+
+ if(swap_parent == b.nodes[swap_grandparent].left)
+ b.nodes[swap_grandparent].left = nonswap_child;
+ else
+ b.nodes[swap_grandparent].right = nonswap_child;
+ b.nodes[nonswap_child].parent = swap_grandparent;
+ update_min_count(b, nonswap_child);
+
+ init_leaf(b.nodes[swap_child]);
+ left_child = swap_child;
+ b.nodes[current].base_predictor = b.nodes[swap_parent].base_predictor;
+ init_leaf(b.nodes[swap_parent]);
+ right_child = swap_parent;
+ }
+ b.nodes[current].left = left_child;
+ b.nodes[left_child].parent = current;
+ b.nodes[current].right = right_child;
+ b.nodes[right_child].parent = current;
+
+ b.nodes[left_child].min_count = b.nodes[current].min_count/2;
+ b.nodes[right_child].min_count = b.nodes[current].min_count - b.nodes[left_child].min_count;
+ update_min_count(b, left_child);
+
+ b.nodes[left_child].max_count_label = b.nodes[current].max_count_label;
+ b.nodes[right_child].max_count_label = b.nodes[current].max_count_label;
+
+ b.nodes[current].internal = true;
+ }
+ return b.nodes[current].internal;
+ }
+
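+  // Train the node's binary predictor to send the current class toward one
+  // side of the node's mean margin: label -1 if the class's running mean
+  // (norm_Ehk) is below the node-wide mean (norm_Eh), +1 otherwise, then
+  // refresh both running means with the new partial prediction.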
+ void train_node(log_multi& b, learner& base, example& ec, uint32_t& current, uint32_t& class_index)
+ {
+ if(b.nodes[current].norm_Eh > b.nodes[current].preds[class_index].norm_Ehk)
+ ec.l.simple.label = -1.f;
+ else
+ ec.l.simple.label = 1.f;
+
+ base.learn(ec, b.nodes[current].base_predictor);
+
+ ec.l.simple.label = FLT_MAX;
+ base.predict(ec, b.nodes[current].base_predictor);
+
+ b.nodes[current].Eh += (double)ec.partial_prediction;
+ b.nodes[current].preds[class_index].Ehk += (double)ec.partial_prediction;
+ b.nodes[current].n++;
+ b.nodes[current].preds[class_index].nk++;
+
+ b.nodes[current].norm_Eh = (float)b.nodes[current].Eh / b.nodes[current].n;
+ b.nodes[current].preds[class_index].norm_Ehk = (float)b.nodes[current].preds[class_index].Ehk / b.nodes[current].preds[class_index].nk;
+ }
+
+ void verify_min_dfs(log_multi& b, node node)
+ {
+ if (node.internal)
+ {
+ if (node.min_count != min_left_right(b, node))
+ {
+ cout << "badness! " << endl;
+ display_tree_dfs(b, b.nodes[0], 0);
+ }
+ verify_min_dfs(b, b.nodes[node.left]);
+ verify_min_dfs(b, b.nodes[node.right]);
+ }
+ }
+
+ size_t sum_count_dfs(log_multi& b, node node)
+ {
+ if (node.internal)
+ return sum_count_dfs(b, b.nodes[node.left]) + sum_count_dfs(b, b.nodes[node.right]);
+ else
+ return node.min_count;
+ }
+
+ inline uint32_t descend(node& n, float prediction)
+ {
+ if (prediction < 0)
+ return n.left;
+ else
+ return n.right;
+ }
+
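+  // Route the example from the root: at each internal node the base
+  // predictor's sign picks a child; the reached leaf's most frequent label
+  // becomes the multiclass prediction.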
+ void predict(log_multi& b, learner& base, example& ec)
+ {
+ MULTICLASS::multiclass mc = ec.l.multi;
+
+ label_data simple_temp;
+ simple_temp.initial = 0.0;
+ simple_temp.weight = 0.0;
+ simple_temp.label = FLT_MAX;
+ ec.l.simple = simple_temp;
+ uint32_t cn = 0;
+ while(b.nodes[cn].internal)
+ {
+ base.predict(ec, b.nodes[cn].base_predictor);
+ cn = descend(b.nodes[cn], ec.pred.scalar);
+ }
+ ec.pred.multiclass = b.nodes[cn].max_count_label;
+ ec.l.multi = mc;
+ }
+
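+  // Predict first when the example is unlabeled, training is off, or
+  // progressive validation is enabled; then, when training on a labeled
+  // example, descend the tree training each node's predictor, growing
+  // children as needed, and finally bump the reached leaf's count and push
+  // the new minimum upward.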
+ void learn(log_multi& b, learner& base, example& ec)
+ {
+ // verify_min_dfs(b, b.nodes[0]);
+
+ if (ec.l.multi.label == (uint32_t)-1 || !b.all->training || b.progress)
+ predict(b,base,ec);
+
+ if(b.all->training && (ec.l.multi.label != (uint32_t)-1) && !ec.test_only) //if training the tree
+ {
+ MULTICLASS::multiclass mc = ec.l.multi;
+
+ uint32_t class_index = 0;
+ label_data simple_temp;
+ simple_temp.initial = 0.0;
+ simple_temp.weight = mc.weight;
+ ec.l.simple = simple_temp;
+
+ uint32_t cn = 0;
+
+ while(children(b, cn, class_index, mc.label))
+ {
+ train_node(b, base, ec, cn, class_index);
+ cn = descend(b.nodes[cn], ec.pred.scalar);
+ }
+
+ b.nodes[cn].min_count++;
+ update_min_count(b, cn);
+
+ ec.l.multi = mc;
+ }
+ }
+
+ void save_node_stats(log_multi& d)
+ {
+ FILE *fp;
+ uint32_t i, j;
+ uint32_t total;
+ log_multi* b = &d;
+
+ fp = fopen("atxm_debug.csv", "wt");
+
+ for(i = 0; i < b->nodes.size(); i++)
+ {
+ fprintf(fp, "Node: %4d, Internal: %1d, Eh: %7.4f, n: %6d, \n", (int) i, (int) b->nodes[i].internal, b->nodes[i].Eh / b->nodes[i].n, b->nodes[i].n);
+
+ fprintf(fp, "Label:, ");
+ for(j = 0; j < b->nodes[i].preds.size(); j++)
+ {
+ fprintf(fp, "%6d,", (int) b->nodes[i].preds[j].label);
+ }
+ fprintf(fp, "\n");
+
+ fprintf(fp, "Ehk:, ");
+ for(j = 0; j < b->nodes[i].preds.size(); j++)
+ {
+ fprintf(fp, "%7.4f,", b->nodes[i].preds[j].Ehk / b->nodes[i].preds[j].nk);
+ }
+ fprintf(fp, "\n");
+
+ total = 0;
+
+ fprintf(fp, "nk:, ");
+ for(j = 0; j < b->nodes[i].preds.size(); j++)
+ {
+ fprintf(fp, "%6d,", (int) b->nodes[i].preds[j].nk);
+ total += b->nodes[i].preds[j].nk;
+ }
+ fprintf(fp, "\n");
+
+ fprintf(fp, "max(lab:cnt:tot):, %3d,%6d,%7d,\n", (int) b->nodes[i].max_count_label, (int) b->nodes[i].max_count, (int) total);
+ fprintf(fp, "left: %4d, right: %4d", (int) b->nodes[i].left, (int) b->nodes[i].right);
+ fprintf(fp, "\n\n");
+ }
+
+ fclose(fp);
+ }
+
+ void finish(log_multi& b)
+ {
+ save_node_stats(b);
+ cout << "used " << b.nbofswaps << " swaps" << endl;
+ }
+
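+  // Save or load the tree: global counters first, then every node's fields
+  // and its per-class statistics, written either as raw binary or as readable
+  // text depending on the 'text' flag passed to bin_text_read_write_fixed.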
+ void save_load_tree(log_multi& b, io_buf& model_file, bool read, bool text)
+ {
+ if (model_file.files.size() > 0)
+ {
+ char buff[512];
+
+ uint32_t text_len = sprintf(buff, "k = %d ",b.k);
+      bin_text_read_write_fixed(model_file,(char*)&b.k, sizeof(b.k), "", read, buff, text_len, text);
+ uint32_t temp = (uint32_t)b.nodes.size();
+ text_len = sprintf(buff, "nodes = %d ",temp);
+ bin_text_read_write_fixed(model_file,(char*)&temp, sizeof(temp), "", read, buff, text_len, text);
+ if (read)
+ for (uint32_t j = 1; j < temp; j++)
+ b.nodes.push_back(init_node());
+ text_len = sprintf(buff, "max_predictors = %d ",b.max_predictors);
+ bin_text_read_write_fixed(model_file,(char*)&b.max_predictors, sizeof(b.max_predictors), "", read, buff, text_len, text);
+
+ text_len = sprintf(buff, "predictors_used = %d ",b.predictors_used);
+ bin_text_read_write_fixed(model_file,(char*)&b.predictors_used, sizeof(b.predictors_used), "", read, buff, text_len, text);
+
+ text_len = sprintf(buff, "progress = %d ",b.progress);
+ bin_text_read_write_fixed(model_file,(char*)&b.progress, sizeof(b.progress), "", read, buff, text_len, text);
+
+ text_len = sprintf(buff, "swap_resist = %d\n",b.swap_resist);
+ bin_text_read_write_fixed(model_file,(char*)&b.swap_resist, sizeof(b.swap_resist), "", read, buff, text_len, text);
+
+ for (size_t j = 0; j < b.nodes.size(); j++)
+ {//Need to read or write nodes.
+ node& n = b.nodes[j];
+ text_len = sprintf(buff, " parent = %d",n.parent);
+ bin_text_read_write_fixed(model_file,(char*)&n.parent, sizeof(n.parent), "", read, buff, text_len, text);
+
+ uint32_t temp = (uint32_t)n.preds.size();
+ text_len = sprintf(buff, " preds = %d",temp);
+ bin_text_read_write_fixed(model_file,(char*)&temp, sizeof(temp), "", read, buff, text_len, text);
+ if (read)
+ for (uint32_t k = 0; k < temp; k++)
+ n.preds.push_back(node_pred(1));
+
+ text_len = sprintf(buff, " min_count = %d",n.min_count);
+ bin_text_read_write_fixed(model_file,(char*)&n.min_count, sizeof(n.min_count), "", read, buff, text_len, text);
+
+	  text_len = sprintf(buff, " internal = %d",n.internal);
+	  bin_text_read_write_fixed(model_file,(char*)&n.internal, sizeof(n.internal), "", read, buff, text_len, text);
+
+ if (n.internal)
+ {
+ text_len = sprintf(buff, " base_predictor = %d",n.base_predictor);
+ bin_text_read_write_fixed(model_file,(char*)&n.base_predictor, sizeof(n.base_predictor), "", read, buff, text_len, text);
+
+ text_len = sprintf(buff, " left = %d",n.left);
+ bin_text_read_write_fixed(model_file,(char*)&n.left, sizeof(n.left), "", read, buff, text_len, text);
+
+ text_len = sprintf(buff, " right = %d",n.right);
+ bin_text_read_write_fixed(model_file,(char*)&n.right, sizeof(n.right), "", read, buff, text_len, text);
+
+ text_len = sprintf(buff, " norm_Eh = %f",n.norm_Eh);
+ bin_text_read_write_fixed(model_file,(char*)&n.norm_Eh, sizeof(n.norm_Eh), "", read, buff, text_len, text);
+
+ text_len = sprintf(buff, " Eh = %f",n.Eh);
+ bin_text_read_write_fixed(model_file,(char*)&n.Eh, sizeof(n.Eh), "", read, buff, text_len, text);
+
+ text_len = sprintf(buff, " n = %d\n",n.n);
+ bin_text_read_write_fixed(model_file,(char*)&n.n, sizeof(n.n), "", read, buff, text_len, text);
+ }
+ else
+ {
+ text_len = sprintf(buff, " max_count = %d",n.max_count);
+ bin_text_read_write_fixed(model_file,(char*)&n.max_count, sizeof(n.max_count), "", read, buff, text_len, text);
+ text_len = sprintf(buff, " max_count_label = %d\n",n.max_count_label);
+ bin_text_read_write_fixed(model_file,(char*)&n.max_count_label, sizeof(n.max_count_label), "", read, buff, text_len, text);
+ }
+
+ for (size_t k = 0; k < n.preds.size(); k++)
+ {
+ node_pred& p = n.preds[k];
+
+ text_len = sprintf(buff, " Ehk = %f",p.Ehk);
+ bin_text_read_write_fixed(model_file,(char*)&p.Ehk, sizeof(p.Ehk), "", read, buff, text_len, text);
+
+ text_len = sprintf(buff, " norm_Ehk = %f",p.norm_Ehk);
+ bin_text_read_write_fixed(model_file,(char*)&p.norm_Ehk, sizeof(p.norm_Ehk), "", read, buff, text_len, text);
+
+ text_len = sprintf(buff, " nk = %d",p.nk);
+ bin_text_read_write_fixed(model_file,(char*)&p.nk, sizeof(p.nk), "", read, buff, text_len, text);
+
+ text_len = sprintf(buff, " label = %d",p.label);
+ bin_text_read_write_fixed(model_file,(char*)&p.label, sizeof(p.label), "", read, buff, text_len, text);
+
+ text_len = sprintf(buff, " label_count = %d\n",p.label_count);
+ bin_text_read_write_fixed(model_file,(char*)&p.label_count, sizeof(p.label_count), "", read, buff, text_len, text);
+ }
+ }
+ }
+ }
+
+ void finish_example(vw& all, log_multi&, example& ec)
+ {
+ MULTICLASS::output_example(all, ec);
+ VW::finish_example(all, &ec);
+ }
+
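+  // Reduction setup: k comes from --log_multi, k - 1 binary base predictors
+  // are reserved, and the quantile loss (parameter 0.5) replaces the default
+  // loss. Illustrative invocation (the data file name is a placeholder):
+  //   vw --log_multi 1000 --swap_resistance 4 train.dat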
+ learner* setup(vw& all, po::variables_map& vm) //learner setup
+ {
+ log_multi* data = (log_multi*)calloc(1, sizeof(log_multi));
+
+ po::options_description opts("TXM Online options");
+ opts.add_options()
+ ("no_progress", "disable progressive validation")
+ ("swap_resistance", po::value<uint32_t>(&(data->swap_resist))->default_value(4), "higher = more resistance to swap, default=4");
+
+ vm = add_options(all, opts);
+
+ data->k = (uint32_t)vm["log_multi"].as<size_t>();
+
+ //append log_multi with nb_actions to options_from_file so it is saved to regressor later
+ std::stringstream ss;
+ ss << " --log_multi " << data->k;
+ all.file_options.append(ss.str());
+
+ if (vm.count("no_progress"))
+ data->progress = false;
+ else
+ data->progress = true;
+
+ data->all = &all;
+ (all.p->lp) = MULTICLASS::mc_label;
+
+ string loss_function = "quantile";
+ float loss_parameter = 0.5;
+ delete(all.loss);
+ all.loss = getLossFunction(&all, loss_function, loss_parameter);
+
+ data->max_predictors = data->k - 1;
+
+ learner* l = new learner(data, all.l, data->max_predictors);
+ l->set_save_load<log_multi,save_load_tree>();
+ l->set_learn<log_multi,learn>();
+ l->set_predict<log_multi,predict>();
+ l->set_finish_example<log_multi,finish_example>();
+ l->set_finish<log_multi,finish>();
+
+ init_tree(*data);
+
+ return l;
+ }
+}