Welcome to mirror list, hosted at ThFree Co, Russian Federation.

vwdll.cpp « vowpalwabbit - github.com/moses-smt/vowpal_wabbit.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: 501730b5bb4b885e65445a7c8c8010392445a237 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
#include <cstring>
#include <memory>

#ifdef WIN32
#define USE_CODECVT
#include <codecvt>
#endif

#include <locale>
#include <string>
#include <vector>

#include "vwdll.h"
#include "parser.h"
#include "simple_label.h"
#include "parse_args.h"
#include "vw.h"

// This interface provides "wide" functions for compatibility with .NET interop.
// The default (wide) functions take a 16-bit character pointer, convert it to a UTF-8 string, and pass
// it to the corresponding narrow (8-bit character pointer) function. Both variants are exposed in the C/C++ API,
// so programs using 8-bit characters can call the narrow functions directly without conversion, while
// programs using 16-bit characters can use the default wide versions.
// "Ansi" versions (FcnA instead of Fcn) have only been written for functions that handle strings.

// A future optimization would be to write an inner version of hash_feature that either hashes the
// wide string directly (accepting different hash values) or incorporates the UTF-16 to UTF-8 conversion
// into the hashing, avoiding the allocation of an intermediate string.
 
extern "C"
{
#ifdef USE_CODECVT
	// Wide-string entry point: converts the UTF-16 argument string to UTF-8 and forwards
	// it to VW_InitializeA.
	// codecvt_utf8_utf16 is required here. codecvt_utf8<char16_t> treats each 16-bit unit as a
	// standalone UCS-2 code point, so surrogate pairs (characters outside the BMP) would be
	// rejected or encoded as invalid UTF-8.
	// Note: to_bytes throws std::range_error on malformed input such as an unpaired surrogate.
	VW_DLL_MEMBER VW_HANDLE VW_CALLING_CONV VW_Initialize(const char16_t * pstrArgs)
	{
		std::wstring_convert<std::codecvt_utf8_utf16<char16_t>, char16_t> convert;
		std::string sa(convert.to_bytes(pstrArgs));
		return VW_InitializeA(sa.c_str());
	}
#endif


	// Narrow-string entry point: builds a new vw instance from a command-line style
	// argument string and returns it as an opaque handle.
	VW_DLL_MEMBER VW_HANDLE VW_CALLING_CONV VW_InitializeA(const char * pstrArgs)
	{
		std::string arguments(pstrArgs);
		return static_cast<VW_HANDLE>(VW::initialize(arguments));
	}
	
	// Shuts down the vw instance behind `handle`.
	// When more than one pass was requested, the parser is restarted with do_reset_source set
	// and the generic driver is run (presumably to perform the remaining passes) before
	// teardown. Otherwise only the parser data structures are released. VW::finish runs last
	// in both cases.
	VW_DLL_MEMBER void      VW_CALLING_CONV VW_Finish(VW_HANDLE handle)
	{
		vw& all = *static_cast<vw*>(handle);

		if (all.numpasses > 1)
		{
			adjust_used_index(all);
			all.do_reset_source = true;
			VW::start_parser(all, false);
			LEARNER::generic_driver(all);
			VW::end_parser(all);
		}
		else
		{
			release_parser_datastructures(all);
		}

		VW::finish(all);
	}

	// Builds an example from `len` caller-supplied feature spaces and returns it as an opaque handle.
	VW_DLL_MEMBER VW_EXAMPLE VW_CALLING_CONV VW_ImportExample(VW_HANDLE handle, VW_FEATURE_SPACE* features, size_t len)
	{
		vw& all = *static_cast<vw*>(handle);
		VW::primitive_feature_space* spaces = reinterpret_cast<VW::primitive_feature_space*>(features);
		return static_cast<VW_EXAMPLE>(VW::import_example(all, spaces, len));
	}
	
	// Exports the feature spaces of example `e`; the number of spaces is written through `plen`.
	VW_DLL_MEMBER VW_FEATURE_SPACE VW_CALLING_CONV VW_ExportExample(VW_HANDLE handle, VW_EXAMPLE e, size_t * plen)
	{
		vw& all = *static_cast<vw*>(handle);
		size_t& out_len = *plen;
		return static_cast<VW_FEATURE_SPACE>(VW::export_example(all, static_cast<example*>(e), out_len));
	}

	// Frees `len` feature spaces previously returned by VW_ExportExample.
	VW_DLL_MEMBER void VW_CALLING_CONV VW_ReleaseFeatureSpace(VW_FEATURE_SPACE* features, size_t len)
	{
		VW::releaseFeatureSpace(reinterpret_cast<VW::primitive_feature_space*>(features), len);
	}
#ifdef USE_CODECVT
	// Wide-string variant: converts a UTF-16 example line to UTF-8 and forwards it to VW_ReadExampleA.
	// codecvt_utf8_utf16 (not codecvt_utf8<char16_t>, which is UCS-2 only) is needed so that
	// surrogate pairs are combined into one code point instead of failing or producing invalid UTF-8.
	// Note: to_bytes throws std::range_error on malformed input such as an unpaired surrogate.
	VW_DLL_MEMBER VW_EXAMPLE VW_CALLING_CONV VW_ReadExample(VW_HANDLE handle, const char16_t * line)
	{
		std::wstring_convert<std::codecvt_utf8_utf16<char16_t>, char16_t> convert;
		std::string sa(convert.to_bytes(line));
		return VW_ReadExampleA(handle, sa.c_str());
	}
#endif
	// Parses one example line in VW text format and returns the resulting example handle.
	// VW::read_example takes a mutable char*, and the parser may write into that buffer.
	// Casting away const on the caller's pointer would be undefined behavior if it points at
	// read-only storage: a string literal, or std::string::c_str(), which VW_ReadExample passes.
	// The parser is therefore given a private, writable, NUL-terminated copy.
	// NOTE(review): this assumes the returned example does not keep pointers into the input
	// buffer. The wide VW_ReadExample already passes a temporary buffer, so callers cannot
	// rely on that either.
	VW_DLL_MEMBER VW_EXAMPLE VW_CALLING_CONV VW_ReadExampleA(VW_HANDLE handle, const char * line)
	{
		vw * pointer = static_cast<vw*>(handle);
		std::vector<char> buffer(line, line + std::strlen(line) + 1);
		return static_cast<VW_EXAMPLE>(VW::read_example(*pointer, &buffer[0]));
	}
	
	// Starts the parser for the vw instance behind `handle`; `do_init` is forwarded unchanged.
	VW_DLL_MEMBER void VW_CALLING_CONV VW_StartParser(VW_HANDLE handle, bool do_init)
	{
		VW::start_parser(*static_cast<vw*>(handle), do_init);
	}

	// Stops the parser for the vw instance behind `handle`.
	VW_DLL_MEMBER void VW_CALLING_CONV VW_EndParser(VW_HANDLE handle)
	{
		VW::end_parser(*static_cast<vw*>(handle));
	}

	// Fetches the next example from the instance's parser (VW::get_example on its `p` member).
	VW_DLL_MEMBER VW_EXAMPLE VW_CALLING_CONV VW_GetExample(VW_HANDLE handle)
	{
		vw& all = *static_cast<vw*>(handle);
		return static_cast<VW_EXAMPLE>(VW::get_example(static_cast<parser*>(all.p)));
	}

	// Returns VW::get_label for example `e`.
	VW_DLL_MEMBER float VW_CALLING_CONV VW_GetLabel(VW_EXAMPLE e)
	{
		example* ex = static_cast<example*>(e);
		return VW::get_label(ex);
	}

	// Returns VW::get_topic_prediction for example `e` at index `i`.
	VW_DLL_MEMBER float VW_CALLING_CONV VW_GetTopicPrediction(VW_EXAMPLE e, size_t i)
	{
		example* ex = static_cast<example*>(e);
		return VW::get_topic_prediction(ex, i);
	}

	// Returns VW::get_importance for example `e`.
	VW_DLL_MEMBER float VW_CALLING_CONV VW_GetImportance(VW_EXAMPLE e)
	{
		example* ex = static_cast<example*>(e);
		return VW::get_importance(ex);
	}

	// Returns VW::get_initial for example `e`.
	VW_DLL_MEMBER float VW_CALLING_CONV VW_GetInitial(VW_EXAMPLE e)
	{
		example* ex = static_cast<example*>(e);
		return VW::get_initial(ex);
	}

	// Returns VW::get_prediction for example `e`.
	VW_DLL_MEMBER float VW_CALLING_CONV VW_GetPrediction(VW_EXAMPLE e)
	{
		example* ex = static_cast<example*>(e);
		return VW::get_prediction(ex);
	}

	// Returns VW::get_cost_sensitive_prediction for example `e`.
	VW_DLL_MEMBER float VW_CALLING_CONV VW_GetCostSensitivePrediction(VW_EXAMPLE e)
	{
		example* ex = static_cast<example*>(e);
		return VW::get_cost_sensitive_prediction(ex);
	}

	// Returns VW::get_tag_length for example `e`.
	VW_DLL_MEMBER size_t VW_CALLING_CONV VW_GetTagLength(VW_EXAMPLE e)
	{
		example* ex = static_cast<example*>(e);
		return VW::get_tag_length(ex);
	}

	// Returns VW::get_tag for example `e`. Use VW_GetTagLength for its length.
	VW_DLL_MEMBER const char* VW_CALLING_CONV VW_GetTag(VW_EXAMPLE e)
	{
		example* ex = static_cast<example*>(e);
		return VW::get_tag(ex);
	}

	// Returns VW::get_feature_number for example `e`.
	VW_DLL_MEMBER size_t VW_CALLING_CONV VW_GetFeatureNumber(VW_EXAMPLE e)
	{
		example* ex = static_cast<example*>(e);
		return VW::get_feature_number(ex);
	}

	// Returns the features of example `e`; their count is written through `plen`.
	// Release the result with VW_ReturnFeatures.
	VW_DLL_MEMBER VW_FEATURE VW_CALLING_CONV VW_GetFeatures(VW_HANDLE handle, VW_EXAMPLE e, size_t* plen)
	{
		vw& all = *static_cast<vw*>(handle);
		example* ex = static_cast<example*>(e);
		return VW::get_features(all, ex, *plen);
	}

	// Hands a feature array obtained from VW_GetFeatures back to VW::return_features.
	VW_DLL_MEMBER void VW_CALLING_CONV VW_ReturnFeatures(VW_FEATURE f)
	{
		feature* returned = static_cast<feature*>(f);
		VW::return_features(returned);
	}
	// Passes example `e` to VW::finish_example on the instance behind `handle`.
	VW_DLL_MEMBER void VW_CALLING_CONV VW_FinishExample(VW_HANDLE handle, VW_EXAMPLE e)
	{
		VW::finish_example(*static_cast<vw*>(handle), static_cast<example*>(e));
	}
#ifdef USE_CODECVT
	// Wide-string variant of VW_HashSpaceA: converts the UTF-16 namespace name to UTF-8 first.
	// codecvt_utf8_utf16 decodes surrogate pairs correctly. codecvt_utf8<char16_t> is UCS-2 only,
	// so it would fail on, or mis-encode, characters outside the BMP and change the hash input.
	// Note: to_bytes throws std::range_error on malformed input such as an unpaired surrogate.
	VW_DLL_MEMBER size_t VW_CALLING_CONV VW_HashSpace(VW_HANDLE handle, const char16_t * s)
	{
		std::wstring_convert<std::codecvt_utf8_utf16<char16_t>, char16_t> convert;
		std::string sa(convert.to_bytes(s));
		return VW_HashSpaceA(handle,sa.c_str());
	}
#endif
	// Returns VW::hash_space for the namespace name `s` on the instance behind `handle`.
	VW_DLL_MEMBER size_t VW_CALLING_CONV VW_HashSpaceA(VW_HANDLE handle, const char * s)
	{
		vw& all = *static_cast<vw*>(handle);
		std::string space_name(s);
		return VW::hash_space(all, space_name);
	}

#ifdef USE_CODECVT
	// Wide-string variant of VW_HashFeatureA: converts the UTF-16 feature name to UTF-8 first.
	// codecvt_utf8_utf16 decodes surrogate pairs correctly. codecvt_utf8<char16_t> is UCS-2 only,
	// so it would fail on, or mis-encode, characters outside the BMP and change the hash input.
	// Note: to_bytes throws std::range_error on malformed input such as an unpaired surrogate.
	VW_DLL_MEMBER size_t VW_CALLING_CONV VW_HashFeature(VW_HANDLE handle, const char16_t * s, unsigned long u)
	{
		std::wstring_convert<std::codecvt_utf8_utf16<char16_t>, char16_t> convert;
		std::string sa(convert.to_bytes(s));
		return VW_HashFeatureA(handle,sa.c_str(),u);
	}
#endif

	// Returns VW::hash_feature for feature name `s`; `u` is forwarded unchanged
	// (presumably the namespace hash from VW_HashSpaceA; confirm against vw.h).
	VW_DLL_MEMBER size_t VW_CALLING_CONV VW_HashFeatureA(VW_HANDLE handle, const char * s, unsigned long u)
	{
		vw& all = *static_cast<vw*>(handle);
		std::string feature_name(s);
		return VW::hash_feature(all, feature_name, u);
	}
	
	// Attaches a label to example `e` via VW::add_label (label, weight and base forwarded as-is).
	VW_DLL_MEMBER void  VW_CALLING_CONV VW_AddLabel(VW_EXAMPLE e, float label, float weight, float base)
	{
		VW::add_label(static_cast<example*>(e), label, weight, base);
	}

	// Runs the instance's learn step on example `e`, then returns VW::get_prediction for it.
	VW_DLL_MEMBER float VW_CALLING_CONV VW_Learn(VW_HANDLE handle, VW_EXAMPLE e)
	{
		vw& all = *static_cast<vw*>(handle);
		example* target = static_cast<example*>(e);
		all.learn(target);
		return VW::get_prediction(target);
	}

	// Runs the learner's predict step on example `e` and returns VW::get_prediction for it.
	VW_DLL_MEMBER float VW_CALLING_CONV VW_Predict(VW_HANDLE handle, VW_EXAMPLE e) 
	{
		vw& all = *static_cast<vw*>(handle);
		example* target = static_cast<example*>(e);
		all.l->predict(*target);
		// BUG: VW::get_prediction assumes a particular layout of target->ld. That layout may not
		// match the label type actually in use (e.g. cost-sensitive multi-class), in which case
		// the returned value is garbage.
		return VW::get_prediction(target);
	}

	// Reads one weight. `index` and `offset` are narrowed to uint32_t for VW::get_weight,
	// so values above UINT32_MAX are truncated.
	VW_DLL_MEMBER float VW_CALLING_CONV VW_Get_Weight(VW_HANDLE handle, size_t index, size_t offset)
	{
		vw& all = *static_cast<vw*>(handle);
		return VW::get_weight(all, static_cast<uint32_t>(index), static_cast<uint32_t>(offset));
	}

	// Writes one weight. `index` and `offset` are narrowed to uint32_t for VW::set_weight,
	// so values above UINT32_MAX are truncated.
	VW_DLL_MEMBER void VW_CALLING_CONV VW_Set_Weight(VW_HANDLE handle, size_t index, size_t offset, float value)
	{
		vw& all = *static_cast<vw*>(handle);
		VW::set_weight(all, static_cast<uint32_t>(index), static_cast<uint32_t>(offset), value);
	}

	// Returns VW::num_weights for the instance behind `handle`.
	VW_DLL_MEMBER size_t VW_CALLING_CONV VW_Num_Weights(VW_HANDLE handle)
	{
		return VW::num_weights(*static_cast<vw*>(handle));
	}

	// Returns VW::get_stride for the instance behind `handle`.
	VW_DLL_MEMBER size_t VW_CALLING_CONV VW_Get_Stride(VW_HANDLE handle)
	{
		return VW::get_stride(*static_cast<vw*>(handle));
	}
}