/* Copyright (c) by respective owners including Yahoo!, Microsoft, and individual contributors. All rights reserved. Released under a BSD license as described in the file LICENSE. */ #ifndef SEARN_H #define SEARN_H #include #include "parse_args.h" #include "parse_primitives.h" #include "v_hashmap.h" #include "cost_sensitive.h" #include #define clog_print_audit_features(ec,reg) { print_audit_features(reg, ec); } #define MAX_BRANCHING_FACTOR 128 #define cdbg clog #undef cdbg #define cdbg if (1) {} else clog namespace Searn { struct searn_private; struct searn_task; // options: extern uint32_t AUTO_HISTORY, AUTO_HAMMING_LOSS, EXAMPLES_DONT_CHANGE, IS_LDF; struct searn { // INTERFACE // for managing task-specific data that you want on the heap: template void set_task_data(T*data) { task_data = data; } template T* get_task_data() { return (T*)task_data; } // for setting programmatic options during initialization void set_options(uint32_t opts); // for snapshotting your algorithm's state void snapshot(size_t index, size_t tag, void* data_ptr, size_t sizeof_data, bool used_for_prediction); // for explicitly declaring your loss void loss(float incr_loss, size_t predictions_since_last=1); // for making predictions in regular (non-LDF) mode: uint32_t predict(example* ec, uint32_t one_ystar, v_array* yallowed=NULL); // if there is a single oracle action uint32_t predict(example* ec, v_array* ystar, v_array* yallowed=NULL); // if there are multiple oracle actions // for making predictions in LDF mode: uint32_t predict(example* ecs, size_t ec_len, v_array* ystar, v_array* yallowed=NULL); // if there is a single oracle action uint32_t predict(example* ecs, size_t ec_len, uint32_t one_ystar, v_array* yallowed=NULL); // if there is are multiple oracle action // for generating output (check to see if output().good() before attempting to write!) stringstream& output(); // internal data searn_task* task; searn_private* priv; void* task_data; }; template void check_option(T& ret, vw&all, po::variables_map& vm, po::variables_map& vm_file, const char* opt_name, bool default_to_cmdline, bool(*equal)(T,T), const char* mismatch_error_string, const char* required_error_string); void check_option(bool& ret, vw&all, po::variables_map& vm, po::variables_map& vm_file, const char* opt_name, bool default_to_cmdline, const char* mismatch_error_string); bool string_equal(string a, string b); bool float_equal(float a, float b); bool uint32_equal(uint32_t a, uint32_t b); bool size_equal(size_t a, size_t b); struct searn_task { const char* task_name; void (*initialize)(searn&,size_t&,std::vector&, po::variables_map&, po::variables_map&); void (*finish)(searn&); void (*structured_predict)(searn&, std::vector); }; LEARNER::learner* setup(vw&, std::vector&, po::variables_map&, po::variables_map&); void searn_finish(void*); void searn_drive(void*); void searn_learn(void*,example*); } #endif