Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/moses-smt/vowpal_wabbit.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJohn Langford <jl@nyclamp.(none)>2013-10-30 21:58:27 +0400
committerJohn Langford <jl@nyclamp.(none)>2013-10-30 21:58:27 +0400
commitd40a24a72b69dc00cd02b153caf3179939030904 (patch)
treec469cb7014387784b81109738221bf37e628a48b /vowpalwabbit/gd_mf.cc
parent53a0a7e5a757761ae9b0780c1a0c2cb648a08faf (diff)
finish is autorecursive
Diffstat (limited to 'vowpalwabbit/gd_mf.cc')
-rw-r--r--vowpalwabbit/gd_mf.cc19
1 files changed, 11 insertions, 8 deletions
diff --git a/vowpalwabbit/gd_mf.cc b/vowpalwabbit/gd_mf.cc
index 79f874d2..2114d558 100644
--- a/vowpalwabbit/gd_mf.cc
+++ b/vowpalwabbit/gd_mf.cc
@@ -24,6 +24,10 @@ license as described in the file LICENSE.
using namespace std;
namespace GDMF {
+ struct gdmf {
+ vw* all;
+ };
+
void mf_local_predict(example* ec, regressor& reg);
float mf_inline_predict(vw& all, example* &ec)
@@ -214,7 +218,7 @@ float mf_predict(vw& all, example* ex)
void save_load(void* d, io_buf& model_file, bool read, bool text)
{
- vw* all = (vw*)d;
+ vw* all = ((gdmf*)d)->all;
uint32_t length = 1 << all->num_bits;
uint32_t stride = all->reg.stride;
@@ -268,7 +272,7 @@ float mf_predict(vw& all, example* ex)
void end_pass(void* d)
{
- vw* all = (vw*)d;
+ vw* all = ((gdmf*)d)->all;
all->eta *= all->eta_decay_rate;
if (all->save_per_pass)
@@ -279,20 +283,19 @@ void end_pass(void* d)
void learn(void* d, example* ec)
{
- vw* all = (vw*)d;
+ vw* all = ((gdmf*)d)->all;
mf_predict(*all,ec);
if (all->training && ((label_data*)(ec->ld))->label != FLT_MAX)
mf_inline_train(*all, ec, ec->eta_round);
}
- void finish(void* d)
- { }
-
learner setup(vw& all)
{
- sl_t sl = {&all, save_load};
- learner l(&all,LEARNER::generic_driver,learn,finish,sl);
+ gdmf* data = (gdmf*)calloc(1,sizeof(gdmf));
+ data->all = &all;
+ sl_t sl = {data, save_load};
+ learner l(data,LEARNER::generic_driver,learn,sl);
l.set_end_pass(end_pass);
return l;
}