Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/moses-smt/vowpal_wabbit.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorariel faigon <github.2009@yendor.com>2014-07-19 04:39:05 +0400
committerariel faigon <github.2009@yendor.com>2014-07-19 04:39:05 +0400
commit7e138ac19bb3e4be88201d521249d87f52e378f3 (patch)
tree18e7420642f4c4419f0592e444884df9cc84cec7
parent4d7021eb6b2b307702e6c3f3e93dfd93a762eefe (diff)
parentee9470a5f2c94691fe3075c926b8651c0233db5b (diff)
Merge branch 'master' of git://github.com/JohnLangford/vowpal_wabbit
-rw-r--r--cs_test/Program.cs51
-rw-r--r--cs_test/VowpalWabbitInterface.cs238
-rw-r--r--vowpalwabbit/example.cc2
-rw-r--r--vowpalwabbit/gd.cc4
-rw-r--r--vowpalwabbit/libvw.vcxproj4
-rw-r--r--vowpalwabbit/parser.cc5
-rw-r--r--vowpalwabbit/stagewise_poly.cc173
-rw-r--r--vowpalwabbit/vw.h1
-rw-r--r--vowpalwabbit/vw_static.vcxproj2
-rw-r--r--vowpalwabbit/vwdll.cpp5
-rw-r--r--vowpalwabbit/vwdll.h1
11 files changed, 311 insertions, 175 deletions
diff --git a/cs_test/Program.cs b/cs_test/Program.cs
index 7660e8f2..9f74e416 100644
--- a/cs_test/Program.cs
+++ b/cs_test/Program.cs
@@ -19,6 +19,7 @@ namespace cs_test
RunParserTest();
RunSpeedTest();
RunFlatExampleTestEx();
+ // RunLDAPredict();
//RunVWParse_and_VWLearn();
}
@@ -68,7 +69,7 @@ namespace cs_test
tfeatures[2].x = 1;
featureSpace[1].len = 3;
- IntPtr importedExample = VowpalWabbitInterface.ImportExample(vw, featureSpacePtr, featureSpace.Length);
+ IntPtr importedExample = VowpalWabbitInterface.ImportExample(vw, featureSpacePtr, (IntPtr)featureSpace.Length);
VowpalWabbitInterface.AddLabel(importedExample, 1);
@@ -97,13 +98,13 @@ namespace cs_test
float label = VowpalWabbitInterface.GetLabel(example);
count++;
- int featureSpaceLen = 0;
+ IntPtr featureSpaceLen = (IntPtr)0;
IntPtr featureSpacePtr = VowpalWabbitInterface.ExportExample(vw, example, ref featureSpaceLen);
- VowpalWabbitInterface.FEATURE_SPACE[] featureSpace = new VowpalWabbitInterface.FEATURE_SPACE[featureSpaceLen];
+ VowpalWabbitInterface.FEATURE_SPACE[] featureSpace = new VowpalWabbitInterface.FEATURE_SPACE[(int)featureSpaceLen];
int featureSpace_size = Marshal.SizeOf(typeof(VowpalWabbitInterface.FEATURE_SPACE));
- for (int i = 0; i < featureSpaceLen; i++)
+ for (int i = 0; i < (int)featureSpaceLen; i++)
{
IntPtr curfeatureSpacePos = new IntPtr(featureSpacePtr.ToInt32() + i * featureSpace_size);
featureSpace[i] = (VowpalWabbitInterface.FEATURE_SPACE)Marshal.PtrToStructure(curfeatureSpacePos, typeof(VowpalWabbitInterface.FEATURE_SPACE));
@@ -167,7 +168,7 @@ namespace cs_test
VowpalWabbitInterface.StartParser(vw, false);
- uint stride = VowpalWabbitInterface.Get_Stride(vw);
+ uint stride = (uint)VowpalWabbitInterface.Get_Stride(vw);
int count = 0;
IntPtr example = IntPtr.Zero;
@@ -180,22 +181,22 @@ namespace cs_test
float initial = VowpalWabbitInterface.GetInitial(example);
float label = VowpalWabbitInterface.GetLabel(example);
- UInt32 tag_len = VowpalWabbitInterface.GetTagLength(example);
+ UInt32 tag_len = (UInt32)VowpalWabbitInterface.GetTagLength(example);
byte[] tag = new byte[tag_len];
if (tag_len > 0)
- //Marshal.Copy(VowpalWabbitInterface.GetTag(example), tag, 0, (int)tag_len); //error CS1502: The best overloaded method match for 'System.Runtime.InteropServices.Marshal.Copy(int[], int, System.IntPtr, int)' has some invalid arguments
- ;
- UInt32 num_features = VowpalWabbitInterface.GetFeatureNumber(example);
+ Marshal.Copy(VowpalWabbitInterface.GetTag(example), tag, 0, (int)tag_len);
+
+ UInt32 num_features = (UInt32)VowpalWabbitInterface.GetFeatureNumber(example);
VowpalWabbitInterface.FEATURE[] f;
if (num_features > 0)
{
f = new VowpalWabbitInterface.FEATURE[num_features];
- int feature_count = 0;
+ IntPtr feature_count = (IntPtr)0;
IntPtr ret = VowpalWabbitInterface.GetFeatures(vw, example, ref feature_count);
int feature_size = Marshal.SizeOf(typeof(VowpalWabbitInterface.FEATURE));
- for (int i = 0; i < feature_count; i++)
+ for (int i = 0; i < (int)feature_count; i++)
{
IntPtr curfeaturePos = new IntPtr(ret.ToInt32() + i * feature_size);
f[i] = (VowpalWabbitInterface.FEATURE)Marshal.PtrToStructure(curfeaturePos, typeof(VowpalWabbitInterface.FEATURE));
@@ -219,13 +220,13 @@ namespace cs_test
IntPtr.Zero == ex)
return;
- int featureSpaceLen = 0;
+ IntPtr featureSpaceLen = (IntPtr)0;
IntPtr featureSpacePtr = VowpalWabbitInterface.ExportExample(vw, ex, ref featureSpaceLen);
- this.featureSpace = new VowpalWabbitInterface.FEATURE_SPACE[featureSpaceLen];
+ this.featureSpace = new VowpalWabbitInterface.FEATURE_SPACE[(int)featureSpaceLen];
int featureSpace_size = Marshal.SizeOf(typeof(VowpalWabbitInterface.FEATURE_SPACE));
- for (int i = 0; i < featureSpaceLen; i++)
+ for (int i = 0; i < (int)featureSpaceLen; i++)
{
IntPtr curfeatureSpacePos = new IntPtr(featureSpacePtr.ToInt32() + i * featureSpace_size);
this.featureSpace[i] = (VowpalWabbitInterface.FEATURE_SPACE)Marshal.PtrToStructure(curfeatureSpacePos, typeof(VowpalWabbitInterface.FEATURE_SPACE));
@@ -239,9 +240,27 @@ namespace cs_test
}
}
- VowpalWabbitInterface.ReleaseFeatureSpace(featureSpacePtr, featureSpaceLen);
+ VowpalWabbitInterface.ReleaseFeatureSpace(featureSpacePtr, (IntPtr)featureSpaceLen);
+ }
+ }
+
+ private static void RunLDAPredict()
+ {
+ IntPtr vw = VowpalWabbitInterface.Initialize("-i wiki1k.model -t --quiet");
+
+ IntPtr example = VowpalWabbitInterface.ReadExample(vw, "| 0:1 2049:6 2:3 5592:1 2796:1 6151:1 6154:1 6157:2 6160:2 1027:2 6168:1 4121:1 6170:1 4124:1 29:1 35:1 2088:1 2091:1 2093:2 2095:3 4145:3 5811:1 53:1 58:1 6204:6 66:2 69:2 4167:1 6216:2 75:3 2402:1 86:1 2135:2 3126:1 4185:1 90:4 2144:1 4193:1 99:1 7185:2 2156:1 110:2 2161:1 114:2 1043:1 2165:1 2166:3 119:2 6265:1 4222:3 4224:1 4230:1 705:1 2674:1 6287:1 2192:1 145:7 2198:1 2200:2 4263:1 6312:1 5148:1 4269:3 6320:4 2227:1 4283:1 4285:2 1397:2 197:2 2246:3 2247:12 201:1 4299:1 2253:1 6351:4 6353:1 4306:1 6179:1 212:1 215:3 2264:1 3108:1 2266:1 224:1 4321:1 6372:1 229:1 2281:4 6381:1 4336:1 241:2 6388:1 2294:1 2297:1 1066:1 6402:1 6405:1 6410:7 6412:2 2322:5 2329:2 282:2 6191:1 6428:1 6431:1 6433:1 4386:21 6436:5 4390:3 6439:3 296:3 1415:3 6444:3 2350:2 2354:5 307:1 6457:3 315:1 319:1 4416:4 4419:1 325:1 326:2 6472:1 6474:1 334:2 1421:2 2384:1 1516:1 340:1 4438:1 344:2 6492:5 2401:1 354:1 4452:2 6505:4 402:3 4463:1 2418:1 2451:3 375:1 4472:1 4478:2 4479:2 2437:2 4487:1 4489:2 4493:2 2448:1 5528:1 4498:1 6547:4 6549:1 406:2 2673:1 2456:2 6554:1 4507:1 4513:1 418:3 6563:1 6566:1 5873:1 2472:10 1095:1 6572:1 4525:1 4529:2 2485:2 4535:15 6587:1 444:3 6590:1 449:1 456:1 2509:6 6221:3 6562:1 2467:1 468:1 902:2 2519:1 2607:1 4653:1 6626:1 422:1 2539:6 493:4 494:1 4591:1 6644:2 3156:1 2554:1 509:1 4606:2 2562:1 516:1 2570:2 524:2 6669:1 2576:1 2577:1 4626:1 6678:1 2584:1 6916:2 538:1 7600:1 547:2 549:2 553:9 555:1 2337:1 4655:1 567:1 5679:1 570:2 6722:2 579:2 6727:2 4793:1 586:1 590:4 2643:15 4694:14 4696:6 4698:1 603:3 4700:1 6749:1 6294:1 4704:1 613:1 4710:2 2833:1 6247:1 1469:1 6769:1 6770:1 629:1 4727:1 2682:4 640:1 642:1 6793:1 2703:1 659:6 772:1 664:1 2714:1 1135:4 3525:1 4768:2 674:1 678:1 4783:1 7624:2 690:1 115:1 1481:1 697:4 6843:1 2748:1 2753:2 6262:2 6854:1 4807:1 6856:2 2763:2 6863:1 2770:1 5923:3 6869:1 4824:2 4834:2 1489:1 2793:4 4844:2 4848:2 2801:1 755:1 2807:1 763:1 2815:2 1152:1 2818:2 2820:2 7638:1 778:1 6923:1 2831:4 6929:1 4882:1 4887:2 4888:16 6940:6 798:2 6950:2 4904:2 809:1 4907:1 4909:4 2870:1 4919:3 4922:2 2879:6 4930:1 4932:5 2892:1 842:1 6988:1 846:1 4943:1 6999:3 4952:1 864:1 4966:5 1853:2 2929:1 7026:2 5267:1 4984:1 4987:1 894:1 6440:1 7042:1 7045:1 4998:3 2953:2 7050:1 2955:3 7053:2 5014:1 836:1 5018:1 3443:2 924:4 7071:8 7072:1 930:1 936:3 5033:3 5036:1 942:2 2991:1 5047:1 7096:1 7099:2 3005:1 3006:3 3008:1 962:3 963:1 3013:1 967:2 5065:3 2419:1 5068:1 5070:1 976:1 977:1 7125:1 3031:1 7130:24 3039:3 7137:2 5090:1 5091:2 996:17 997:3 3047:2 7147:1 7149:1 5105:1 3060:1 3062:13 7159:1 5112:1 3066:4 5631:1 1022:1 1023:1 7171:1 5126:4 1032:1 5131:4 3087:1 2904:1 3090:1 7187:3 5147:1 3100:1 7200:2 7201:4 1058:2 7203:5 5156:2 7207:2 1065:6 5162:3 3116:6 5165:1 7214:1 3119:1 7222:1 5180:1 3133:1 1086:2 5183:15 7233:1 5188:2 7239:4 5192:1 1097:1 5194:2 405:1 4621:1 5200:1 3153:1 855:1 7252:2 1112:1 5211:7 7675:2 7264:1 5218:2 2235:1 5220:1 3173:1 1129:2 1130:5 3181:1 1134:1 7279:1 3184:1 3186:1 1139:1 191:1 3197:1 5248:2 5249:1 993:1 2582:1 1160:2 1165:1 7315:2 3223:1 7321:1 3229:2 4293:1 2631:1 7334:7 3239:3 7338:3 3243:1 5293:2 7344:1 7348:1 6345:3 1226:1 1216:2 3041:1 2361:1 3445:1 3273:1 7370:2 3277:1 3280:4 7378:1 7381:1 3287:4 3288:1 3295:2 6520:1 5348:1 5349:5 7398:1 3303:1 5354:1 5357:1 5358:2 7408:1 5365:2 4991:1 5372:2 7421:1 5374:8 5376:1 1921:1 7434:1 3342:1 1295:1 1296:1 3349:3 6361:1 1306:2 1583:1 5409:3 6113:1 2950:1 3975:1 5420:11 7469:1 1928:1 3381:2 1334:1 5001:5 5434:1 7391:2 1341:1 7487:1 1345:2 7491:1 5449:1 1355:1 2957:1 7505:2 5458:6 3114:1 5460:2 3641:2 7512:1 5466:1 5470:1 5350:1 7526:1 7529:1 7531:1 1388:2 5488:1 1395:3 7541:2 7546:1 1258:1 1407:1 3456:2 7555:2 7557:1 7558:1 5511:2 7560:1 7563:1 4674:1 1424:2 7576:4 3483:3 1437:2 5535:3 7584:1 5539:1 1449:1 5231:1 5548:1 5549:5 3503:1 5552:1 1458:1 5556:1 7611:1 3517:2 3317:3 5570:2 1477:6 5576:2 5577:1 3530:1 3531:1 1485:1 5585:1 7210:1 1492:1 5590:2 5591:1 3544:1 118:1 1502:1 3551:1 3558:3 1513:1 5612:1 3565:2 6397:1 5616:1 4691:2 5622:7 7671:1 3577:1 5626:1 6393:1 1532:2 5629:1 3583:2 7683:2 3590:3 7689:1 5644:1 5650:12 7699:1 5654:3 5655:1 3616:1 1569:1 1572:1 4485:3 5678:4 3631:16 5683:1 5686:1 5687:1 5688:2 5689:5 3646:4 3648:3 1608:15 951:1 5718:2 1625:2 3692:2 274:1 1646:4 3695:1 5751:1 5762:2 3727:3 3737:1 1690:3 5787:1 5794:1 3747:3 5799:4 5805:1 5808:5 3763:4 1716:2 287:1 1725:1 5825:1 7559:1 7457:4 3785:2 5834:1 1746:1 3795:1 1751:15 5859:1 1764:6 5863:1 4392:1 1789:1 5896:1 3860:3 1813:5 5912:1 1822:5 1826:1 3875:6 1828:1 3879:3 3880:1 353:2 3885:6 5934:1 3890:1 6451:2 5946:8 5947:1 3901:3 2653:3 3905:2 5955:2 3908:2 1861:1 1862:1 5959:1 1494:1 5431:1 7139:4 3925:4 5974:1 5975:1 3931:1 1884:3 881:1 1888:1 4411:1 3944:2 3948:1 3949:1 3951:2 3956:5 1910:1 3961:1 6010:1 1918:2 6016:1 320:4 5441:1 3976:1 6027:2 3985:1 1947:1 6045:3 4001:1 6811:1 4009:4 1965:1 1966:1 1967:1 328:1 6131:1 4085:2 1985:1 6083:1 4036:1 4039:1 6135:1 1996:3 6093:1 1999:1 1016:1 4054:5 4055:1 4060:1 2016:2 4432:1 4073:1 2028:5 2035:1 6133:1 2039:5 4436:1");
+ float score = VowpalWabbitInterface.Learn(vw, example);
+
+ for (int i = 0; i < 10; i++)
+ {
+ float topicPrediction = VowpalWabbitInterface.GetTopicPrediction(example, (IntPtr)i);
+ Console.Write("{0} ", topicPrediction);
}
+ Console.Write("\n");
+
+ VowpalWabbitInterface.FinishExample(vw, example);
}
+
private static void RunVWParse_and_VWLearn()
{
// parse and cache
@@ -280,7 +299,7 @@ namespace cs_test
GCHandle pinnedFeatureSpace = GCHandle.Alloc(featureSpace, GCHandleType.Pinned);
IntPtr featureSpacePtr = pinnedFeatureSpace.AddrOfPinnedObject();
- IntPtr importedExample = VowpalWabbitInterface.ImportExample(vw, featureSpacePtr, vwInstanceEx.featureSpace.Length);
+ IntPtr importedExample = VowpalWabbitInterface.ImportExample(vw, featureSpacePtr, (IntPtr)vwInstanceEx.featureSpace.Length);
VowpalWabbitInterface.Learn(vw, importedExample);
VowpalWabbitInterface.FinishExample(vw, importedExample);
diff --git a/cs_test/VowpalWabbitInterface.cs b/cs_test/VowpalWabbitInterface.cs
index 2a1b412f..3b63765c 100644
--- a/cs_test/VowpalWabbitInterface.cs
+++ b/cs_test/VowpalWabbitInterface.cs
@@ -1,108 +1,130 @@
-using System;
-using System.Collections.Generic;
-using System.Linq;
-using System.Text;
-using System.Runtime.InteropServices;
-
-namespace Microsoft.Research.MachineLearning
-{
- public sealed class VowpalWabbitInterface
- {
- [StructLayout(LayoutKind.Sequential)]
- public struct FEATURE_SPACE
- {
- public byte name;
- public IntPtr features; // points to a FEATURE[]
- public int len;
- }
-
- [StructLayout(LayoutKind.Sequential)]
- public struct FEATURE
- {
- public float x;
- public uint weight_index;
- }
-
- [DllImport("libvw.dll", EntryPoint = "VW_Initialize", CallingConvention = CallingConvention.StdCall)]
- public static extern IntPtr Initialize([MarshalAs(UnmanagedType.LPWStr)]string arguments);
-
- [DllImport("libvw.dll", EntryPoint = "VW_Finish", CallingConvention = CallingConvention.StdCall)]
- public static extern void Finish(IntPtr vw);
-
- [DllImport("libvw.dll", EntryPoint = "VW_ImportExample", CallingConvention = CallingConvention.StdCall)]
- // features points to a FEATURE_SPACE[]
- public static extern IntPtr ImportExample(IntPtr vw, IntPtr features, int length);
-
- [DllImport("libvw.dll", EntryPoint = "VW_ExportExample", CallingConvention = CallingConvention.StdCall)]
- public static extern IntPtr ExportExample(IntPtr vw, IntPtr example, ref int length);
-
- [DllImport("libvw.dll", EntryPoint = "VW_ReleaseFeatureSpace", CallingConvention = CallingConvention.StdCall)]
- public static extern IntPtr ReleaseFeatureSpace(IntPtr fs, int length);
-
- [DllImport("libvw.dll", EntryPoint = "VW_ReadExample", CallingConvention = CallingConvention.StdCall)]
- public static extern IntPtr ReadExample(IntPtr vw, [MarshalAs(UnmanagedType.LPWStr)]string exampleString);
-
- [DllImport("libvw.dll", EntryPoint = "VW_StartParser", CallingConvention = CallingConvention.StdCall)]
- public static extern void StartParser(IntPtr vw, bool do_init);
-
- [DllImport("libvw.dll", EntryPoint = "VW_EndParser", CallingConvention = CallingConvention.StdCall)]
- public static extern void EndParser(IntPtr vw);
-
- [DllImport("libvw.dll", EntryPoint = "VW_GetExample", CallingConvention = CallingConvention.StdCall)]
- public static extern IntPtr GetExample(IntPtr parser);
-
- [DllImport("libvw.dll", EntryPoint = "VW_FinishExample", CallingConvention = CallingConvention.StdCall)]
- public static extern void FinishExample(IntPtr vw, IntPtr example);
-
- [DllImport("libvw.dll", EntryPoint = "VW_GetLabel", CallingConvention = CallingConvention.StdCall)]
- public static extern float GetLabel(IntPtr example);
-
- [DllImport("libvw.dll", EntryPoint = "VW_GetImportance", CallingConvention = CallingConvention.StdCall)]
- public static extern float GetImportance(IntPtr example);
-
- [DllImport("libvw.dll", EntryPoint = "VW_GetInitial", CallingConvention = CallingConvention.StdCall)]
- public static extern float GetInitial(IntPtr example);
-
- [DllImport("libvw.dll", EntryPoint = "VW_GetPrediction", CallingConvention = CallingConvention.StdCall)]
- public static extern float GetPrediction(IntPtr example);
-
- [DllImport("libvw.dll", EntryPoint = "VW_GetTagLength", CallingConvention = CallingConvention.StdCall)]
- public static extern UInt32 GetTagLength(IntPtr example);
-
- [DllImport("libvw.dll", EntryPoint = "VW_GetTag", CallingConvention = CallingConvention.StdCall)]
- public static extern byte GetTag(IntPtr example);
-
- [DllImport("libvw.dll", EntryPoint = "VW_GetFeatureNumber", CallingConvention = CallingConvention.StdCall)]
- public static extern UInt32 GetFeatureNumber(IntPtr example);
-
- [DllImport("libvw.dll", EntryPoint = "VW_GetFeatures", CallingConvention = CallingConvention.StdCall)]
- public static extern IntPtr GetFeatures(IntPtr vw, IntPtr example, ref int length);
-
- [DllImport("libvw.dll", EntryPoint = "VW_ReturnFeatures", CallingConvention = CallingConvention.StdCall)]
- public static extern void ReturnFeatures(IntPtr features);
-
- [DllImport("libvw.dll", EntryPoint = "VW_HashSpace", CallingConvention = CallingConvention.StdCall)]
- public static extern uint HashSpace(IntPtr vw, [MarshalAs(UnmanagedType.LPWStr)]string s);
-
- [DllImport("libvw.dll", EntryPoint = "VW_HashFeature", CallingConvention = CallingConvention.StdCall)]
- public static extern uint HashFeature(IntPtr vw, [MarshalAs(UnmanagedType.LPWStr)]string s, ulong u);
-
- [DllImport("libvw.dll", EntryPoint = "VW_Learn", CallingConvention = CallingConvention.StdCall)]
- public static extern float Learn(IntPtr vw, IntPtr example);
-
- [DllImport("libvw.dll", EntryPoint = "VW_AddLabel", CallingConvention = CallingConvention.StdCall)]
- public static extern void AddLabel(IntPtr example, float label=float.MaxValue, float weight=1, float initial=0);
-
- [DllImport("libvw.dll", EntryPoint = "VW_Get_Weight", CallingConvention = CallingConvention.StdCall)]
- public static extern float Get_Weight(IntPtr vw, UInt32 index, UInt32 offset);
-
- [DllImport("libvw.dll", EntryPoint = "VW_Set_Weight", CallingConvention = CallingConvention.StdCall)]
- public static extern void Set_Weight(IntPtr vw, UInt32 index, UInt32 offset, float value);
-
- [DllImport("libvw.dll", EntryPoint = "VW_Num_Weights", CallingConvention = CallingConvention.StdCall)]
- public static extern UInt32 Num_Weights(IntPtr vw);
-
- [DllImport("libvw.dll", EntryPoint = "VW_Get_Stride", CallingConvention = CallingConvention.StdCall)]
- public static extern UInt32 Get_Stride(IntPtr vw);
- }
-}
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using System.Runtime.InteropServices;
+
+namespace Microsoft.Research.MachineLearning
+{
+ using SizeT = IntPtr;
+ using VwHandle = IntPtr;
+ using VwFeatureSpace = IntPtr;
+ using VwExample = IntPtr;
+ using VwFeature = IntPtr;
+ using BytePtr = IntPtr;
+
+ public sealed class VowpalWabbitInterface
+ {
+ private const string LIBVW = "libvw.dll";
+
+ [StructLayout(LayoutKind.Sequential)]
+ public struct FEATURE_SPACE
+ {
+ public byte name;
+ public IntPtr features; // points to a FEATURE[]
+ public int len;
+ }
+
+ [StructLayout(LayoutKind.Sequential)]
+ public struct FEATURE
+ {
+ public float x;
+ public uint weight_index;
+ }
+
+ [DllImport(LIBVW, EntryPoint = "VW_Initialize")]
+ public static extern VwHandle Initialize([MarshalAs(UnmanagedType.LPWStr)]string arguments);
+
+ [DllImport(LIBVW, EntryPoint = "VW_Finish")]
+ public static extern void Finish(VwHandle vw);
+
+ [DllImport(LIBVW, EntryPoint = "VW_ImportExample")]
+ // features points to a FEATURE_SPACE[]
+ public static extern VwExample ImportExample(VwHandle vw, VwFeatureSpace features, SizeT length);
+
+ [DllImport(LIBVW, EntryPoint = "VW_ExportExample")]
+ public static extern VwFeatureSpace ExportExample(VwHandle vw, VwExample example, ref SizeT length);
+
+ [DllImport(LIBVW, EntryPoint = "VW_ReleaseFeatureSpace")]
+ public static extern void ReleaseFeatureSpace(VwFeatureSpace fs, SizeT length);
+
+ [DllImport(LIBVW, EntryPoint = "VW_ReadExample")]
+ public static extern VwExample ReadExample(VwHandle vw, [MarshalAs(UnmanagedType.LPWStr)]string exampleString);
+
+ // Have to marshal bools, C++ considers them 4 byte quantities, and C# considers them 1 byte.
+ [DllImport(LIBVW, EntryPoint = "VW_StartParser")]
+ public static extern void StartParser(VwHandle vw, [MarshalAs(UnmanagedType.Bool)]bool do_init);
+
+ [DllImport(LIBVW, EntryPoint = "VW_EndParser")]
+ public static extern void EndParser(VwHandle vw);
+
+ [DllImport(LIBVW, EntryPoint = "VW_GetExample")]
+ public static extern VwExample GetExample(VwHandle parser);
+
+ [DllImport(LIBVW, EntryPoint = "VW_FinishExample")]
+ public static extern void FinishExample(VwHandle vw, VwExample example);
+
+ [DllImport(LIBVW, EntryPoint = "VW_GetTopicPrediction")]
+ public static extern float GetTopicPrediction(VwExample example, SizeT i);
+
+ [DllImport(LIBVW, EntryPoint = "VW_GetLabel")]
+ public static extern float GetLabel(VwExample example);
+
+ [DllImport(LIBVW, EntryPoint = "VW_GetImportance")]
+ public static extern float GetImportance(VwExample example);
+
+ [DllImport(LIBVW, EntryPoint = "VW_GetInitial")]
+ public static extern float GetInitial(VwExample example);
+
+ [DllImport(LIBVW, EntryPoint = "VW_GetPrediction")]
+ public static extern float GetPrediction(VwExample example);
+
+ [DllImport(LIBVW, EntryPoint = "VW_GetTagLength")]
+ public static extern SizeT GetTagLength(VwExample example);
+
+ // Saying this returned a byte was inappropriate, because you were returning
+ // actually a pointer to a seqeunce of bytes. (Not sure what the interpretation
+ // of this should be, utf8 or something?)
+ [DllImport(LIBVW, EntryPoint = "VW_GetTag")]
+ public static extern BytePtr GetTag(VwExample example);
+
+ [DllImport(LIBVW, EntryPoint = "VW_GetFeatureNumber")]
+ public static extern SizeT GetFeatureNumber(VwExample example);
+
+ // Same note regarding ref int vs size_t*
+ [DllImport(LIBVW, EntryPoint = "VW_GetFeatures")]
+ public static extern VwFeature GetFeatures(VwHandle vw, VwExample example, ref SizeT length);
+
+ [DllImport(LIBVW, EntryPoint = "VW_ReturnFeatures")]
+ public static extern void ReturnFeatures(VwExample features);
+
+ [DllImport(LIBVW, EntryPoint = "VW_HashSpace")]
+ public static extern uint HashSpace(VwHandle vw, [MarshalAs(UnmanagedType.LPWStr)]string s);
+
+ // The DLL defines the last argument "u" as being an "unsigned long".
+ // In C++ under current circumstances, both ints and longs are four byte integers.
+ // If you wanted an eight byte integer you should use "long long" (or probably
+ // more appropriately in this circumstance size_t).
+ // In C#, "int" is four bytes, "long" is eight bytes.
+ [DllImport(LIBVW, EntryPoint = "VW_HashFeature")]
+ public static extern uint HashFeature(VwHandle vw, [MarshalAs(UnmanagedType.LPWStr)]string s, uint u);
+
+ [DllImport(LIBVW, EntryPoint = "VW_Learn")]
+ public static extern float Learn(VwHandle vw, VwExample example);
+
+ [DllImport(LIBVW, EntryPoint = "VW_AddLabel")]
+ public static extern void AddLabel(VwExample example, float label = float.MaxValue, float weight = 1, float initial = 0);
+
+ [DllImport(LIBVW, EntryPoint = "VW_Get_Weight")]
+ public static extern float Get_Weight(VwHandle vw, SizeT index, SizeT offset);
+
+ [DllImport(LIBVW, EntryPoint = "VW_Set_Weight")]
+ public static extern void Set_Weight(VwHandle vw, SizeT index, SizeT offset, float value);
+
+ [DllImport(LIBVW, EntryPoint = "VW_Num_Weights")]
+ public static extern SizeT Num_Weights(VwHandle vw);
+
+ [DllImport(LIBVW, EntryPoint = "VW_Get_Stride")]
+ public static extern SizeT Get_Stride(VwHandle vw);
+ }
+}
diff --git a/vowpalwabbit/example.cc b/vowpalwabbit/example.cc
index cdd934a3..40cd2d92 100644
--- a/vowpalwabbit/example.cc
+++ b/vowpalwabbit/example.cc
@@ -119,7 +119,7 @@ feature* get_features(vw& all, example* ec, size_t& feature_map_len)
{
features_and_source fs;
fs.stride_shift = all.reg.stride_shift;
- fs.mask = all.reg.weight_mask >> all.reg.stride_shift;
+ fs.mask = (uint32_t)all.reg.weight_mask >> all.reg.stride_shift;
fs.base = all.reg.weight_vector;
GD::foreach_feature<features_and_source, vec_store>(all, *ec, fs);
feature_map_len = fs.feature_map.size();
diff --git a/vowpalwabbit/gd.cc b/vowpalwabbit/gd.cc
index 6991db8d..36beff8b 100644
--- a/vowpalwabbit/gd.cc
+++ b/vowpalwabbit/gd.cc
@@ -75,14 +75,14 @@ namespace GD
if (normalized) {
if (sqrt_rate)
{
- float avg_norm = g.total_weight / g.normalized_sum_norm_x;
+ float avg_norm = (float) g.total_weight / (float) g.normalized_sum_norm_x;
if (adaptive)
return sqrt(avg_norm);
else
return avg_norm;
}
else
- return powf(g.normalized_sum_norm_x / g.total_weight, g.neg_norm_power);
+ return powf( (float) g.normalized_sum_norm_x / (float) g.total_weight, g.neg_norm_power);
}
return 1.f;
}
diff --git a/vowpalwabbit/libvw.vcxproj b/vowpalwabbit/libvw.vcxproj
index 689d4a62..efaea3c2 100644
--- a/vowpalwabbit/libvw.vcxproj
+++ b/vowpalwabbit/libvw.vcxproj
@@ -142,13 +142,13 @@
<PropertyGroup Condition="'$(Platform)'=='x64'">
<BoostIncludeDir>c:\boost\x64\include\boost-1_55</BoostIncludeDir>
<BoostLibDir>c:\boost\x64\lib</BoostLibDir>
- <ZlibIncludeDir>..\..\zlib-1.2.7</ZlibIncludeDir>
+ <ZlibIncludeDir>..\..\zlib-1.2.8</ZlibIncludeDir>
<ZlibLibDir>$(ZlibIncludeDir)\contrib\vstudio\vc10\x64\ZlibStat$(Configuration)</ZlibLibDir>
</PropertyGroup>
<PropertyGroup Condition="'$(Platform)'=='Win32'">
<BoostIncludeDir>c:\boost\x86\include\boost-1_55</BoostIncludeDir>
<BoostLibDir>c:\boost\x86\lib</BoostLibDir>
- <ZlibIncludeDir>..\..\zlib-1.2.7</ZlibIncludeDir>
+ <ZlibIncludeDir>..\..\zlib-1.2.8</ZlibIncludeDir>
<ZlibLibDir>$(ZlibIncludeDir)\contrib\vstudio\vc10\x86\ZlibStat$(Configuration)</ZlibLibDir>
</PropertyGroup>
<Target Name="DepCheck">
diff --git a/vowpalwabbit/parser.cc b/vowpalwabbit/parser.cc
index 5f1b1654..5a58ef09 100644
--- a/vowpalwabbit/parser.cc
+++ b/vowpalwabbit/parser.cc
@@ -1087,6 +1087,11 @@ example* get_example(parser* p)
}
}
+float get_topic_prediction(example* ec, size_t i)
+{
+ return ec->topic_predictions[i];
+}
+
float get_label(example* ec)
{
return ((label_data*)(ec->ld))->label;
diff --git a/vowpalwabbit/stagewise_poly.cc b/vowpalwabbit/stagewise_poly.cc
index 9612a2d5..8c99f4f0 100644
--- a/vowpalwabbit/stagewise_poly.cc
+++ b/vowpalwabbit/stagewise_poly.cc
@@ -3,6 +3,8 @@
#include "simple_label.h"
#include "allreduce.h"
#include "accumulate.h"
+#include "constant.h"
+#include "memory.h"
#include <float.h>
//#undef NDEBUG
@@ -54,6 +56,7 @@ namespace StagewisePoly
example *original_ec;
uint32_t cur_depth;
bool training;
+ int64_t last_example_counter;
size_t numpasses;
uint32_t next_batch_sz;
bool update_support;
@@ -79,6 +82,26 @@ namespace StagewisePoly
return idx >> poly.all->reg.stride_shift;
}
+ inline uint32_t do_ft_offset(const stagewise_poly &poly, uint32_t idx)
+ {
+ //cout << poly.synth_ec.ft_offset << " " << poly.original_ec->ft_offset << endl;
+ assert(!poly.original_ec || poly.synth_ec.ft_offset == poly.original_ec->ft_offset);
+ return idx + poly.synth_ec.ft_offset;
+ }
+
+ inline uint32_t un_ft_offset(const stagewise_poly &poly, uint32_t idx)
+ {
+ assert(!poly.original_ec || poly.synth_ec.ft_offset == poly.original_ec->ft_offset);
+ if (poly.synth_ec.ft_offset == 0)
+ return idx;
+ else {
+ while (idx < poly.synth_ec.ft_offset) {
+ idx += poly.all->length() << poly.all->reg.stride_shift;
+ }
+ return idx - poly.synth_ec.ft_offset;
+ }
+ }
+
inline uint32_t wid_mask(const stagewise_poly &poly, uint32_t wid)
{
return wid & poly.all->reg.weight_mask;
@@ -91,7 +114,7 @@ namespace StagewisePoly
inline uint32_t constant_feat(const stagewise_poly &poly)
{
- return stride_shift(poly, constant);
+ return stride_shift(poly, constant * poly.all->wpp);
}
inline uint32_t constant_feat_masked(const stagewise_poly &poly)
@@ -107,7 +130,7 @@ namespace StagewisePoly
void depthsbits_create(stagewise_poly &poly)
{
- poly.depthsbits = (uint8_t *) malloc(depthsbits_sizeof(poly));
+ poly.depthsbits = (uint8_t *) calloc_or_die(1, depthsbits_sizeof(poly));
for (uint32_t i = 0; i < poly.all->length() * 2; i += 2) {
poly.depthsbits[i] = default_depth;
poly.depthsbits[i+1] = indicator_bit;
@@ -122,23 +145,27 @@ namespace StagewisePoly
inline bool parent_get(const stagewise_poly &poly, uint32_t wid)
{
assert(wid % stride_shift(poly, 1) == 0);
- return poly.depthsbits[wid_mask_un_shifted(poly, wid) * 2 + 1] & parent_bit;
+ assert(do_ft_offset(poly, wid) % stride_shift(poly, 1) == 0);
+ return poly.depthsbits[wid_mask_un_shifted(poly, do_ft_offset(poly, wid)) * 2 + 1] & parent_bit;
}
inline void parent_toggle(stagewise_poly &poly, uint32_t wid)
{
assert(wid % stride_shift(poly, 1) == 0);
- poly.depthsbits[wid_mask_un_shifted(poly, wid) * 2 + 1] ^= parent_bit;
+ assert(do_ft_offset(poly, wid) % stride_shift(poly, 1) == 0);
+ poly.depthsbits[wid_mask_un_shifted(poly, do_ft_offset(poly, wid)) * 2 + 1] ^= parent_bit;
}
inline bool cycle_get(const stagewise_poly &poly, uint32_t wid)
{
+ //note: intentionally leaving out ft_offset.
assert(wid % stride_shift(poly, 1) == 0);
return poly.depthsbits[wid_mask_un_shifted(poly, wid) * 2 + 1] & cycle_bit;
}
inline void cycle_toggle(stagewise_poly &poly, uint32_t wid)
{
+ //note: intentionally leaving out ft_offset.
assert(wid % stride_shift(poly, 1) == 0);
poly.depthsbits[wid_mask_un_shifted(poly, wid) * 2 + 1] ^= cycle_bit;
}
@@ -146,15 +173,36 @@ namespace StagewisePoly
inline uint8_t min_depths_get(const stagewise_poly &poly, uint32_t wid)
{
assert(wid % stride_shift(poly, 1) == 0);
- return poly.depthsbits[stride_un_shift(poly, wid) * 2];
+ assert(do_ft_offset(poly, wid) % stride_shift(poly, 1) == 0);
+ return poly.depthsbits[stride_un_shift(poly, do_ft_offset(poly, wid)) * 2];
}
inline void min_depths_set(stagewise_poly &poly, uint32_t wid, uint8_t depth)
{
assert(wid % stride_shift(poly, 1) == 0);
- poly.depthsbits[stride_un_shift(poly, wid) * 2] = depth;
+ assert(do_ft_offset(poly, wid) % stride_shift(poly, 1) == 0);
+ poly.depthsbits[stride_un_shift(poly, do_ft_offset(poly, wid)) * 2] = depth;
}
+#ifndef NDEBUG
+ void sanity_check_state(stagewise_poly &poly)
+ {
+ for (uint32_t i = 0; i != poly.all->length(); ++i)
+ {
+ uint32_t wid = stride_shift(poly, i);
+
+ assert( ! cycle_get(poly,wid) );
+
+ assert( ! (min_depths_get(poly, wid) == default_depth && parent_get(poly, wid)) );
+
+ assert( ! (min_depths_get(poly, wid) == default_depth && fabsf(poly.all->reg.weight_vector[wid]) > 0) );
+ //assert( min_depths_get(poly, wid) != default_depth && fabsf(poly.all->reg.weight_vector[wid]) < tolerance );
+
+ assert( ! (poly.depthsbits[wid_mask_un_shifted(poly, wid) * 2 + 1] & ~(parent_bit + cycle_bit + indicator_bit)) );
+ }
+ }
+#endif //NDEBUG
+
//Note. OUTPUT & INPUT masked.
//It is very important that this function is invariant to stride.
inline uint32_t child_wid(const stagewise_poly &poly, uint32_t wi_atomic, uint32_t wi_general)
@@ -197,7 +245,7 @@ namespace StagewisePoly
cout << ", new size " << poly.sd_len << endl;
#endif //DEBUG
free(poly.sd); //okay for null.
- poly.sd = (sort_data *) malloc(poly.sd_len * sizeof(sort_data));
+ poly.sd = (sort_data *) calloc_or_die(poly.sd_len, sizeof(sort_data));
}
assert(len <= poly.sd_len);
}
@@ -235,6 +283,13 @@ namespace StagewisePoly
void sort_data_update_support(stagewise_poly &poly)
{
assert(poly.num_examples);
+
+ //ft_offset affects parent_set / parent_get. This state must be reset at end.
+ uint32_t pop_ft_offset = poly.original_ec->ft_offset;
+ poly.synth_ec.ft_offset = 0;
+ assert(poly.original_ec);
+ poly.original_ec->ft_offset = 0;
+
uint32_t num_new_features = (uint32_t)pow(poly.sum_input_sparsity * 1.0f / poly.num_examples, poly.sched_exponent);
num_new_features = (num_new_features > poly.all->length()) ? (uint32_t)poly.all->length() : num_new_features;
sort_data_ensure_sz(poly, num_new_features);
@@ -302,7 +357,15 @@ namespace StagewisePoly
for (uint32_t depth = 0; depth <= poly.max_depth && depth < sizeof(poly.depths) / sizeof(*poly.depths); ++depth)
cout << " [" << depth << "] = " << poly.depths[depth];
cout << endl;
+
+ cout << "Sanity check after sort... " << flush;
+ sanity_check_state(poly);
+ cout << "done" << endl;
#endif //DEBUG
+
+ //it's okay that these may have been initially unequal; synth_ec value irrelevant so far.
+ poly.original_ec->ft_offset = pop_ft_offset;
+ poly.synth_ec.ft_offset = pop_ft_offset;
}
void synthetic_reset(stagewise_poly &poly, example &ec)
@@ -310,6 +373,26 @@ namespace StagewisePoly
poly.synth_ec.ld = ec.ld;
poly.synth_ec.tag = ec.tag;
poly.synth_ec.example_counter = ec.example_counter;
+
+ /**
+ * Some comments on ft_offset.
+ *
+ * The plan is to do the feature mapping dfs with weight indices ignoring
+ * the ft_offset. This is because ft_offset is then added at the end,
+ * guaranteeing local/strided access on synth_ec. This might not matter
+ * too much in this implementation (where, e.g., --oaa runs one after the
+ * other, not interleaved), but who knows.
+ *
+ * (The other choice is to basically ignore adjusting for ft_offset when
+ * doing the traversal, which means synth_ec.ft_offset is 0 here...)
+ *
+ * Anyway, so here is how ft_offset matters:
+ * - synthetic_create_rec must "normalize it out" of the fed weight value
+ * - parent and min_depths set/get are adjusted for it.
+ * - cycle set/get are not adjusted for it, since it doesn't matter for them.
+ * - operations on the whole weight vector (sorting, save_load, all_reduce)
+ * ignore ft_offset, just treat the thing as a flat vector.
+ **/
poly.synth_ec.ft_offset = ec.ft_offset;
poly.synth_ec.test_only = ec.test_only;
@@ -339,7 +422,8 @@ namespace StagewisePoly
void synthetic_create_rec(stagewise_poly &poly, float v, float &w)
{
- uint32_t wid_atomic = (uint32_t)((&w - poly.all->reg.weight_vector));
+ //Note: need to un_ft_shift since gd::foreach_feature bakes in the offset.
+ uint32_t wid_atomic = wid_mask(poly, un_ft_offset(poly, (uint32_t)((&w - poly.all->reg.weight_vector))));
uint32_t wid_cur = child_wid(poly, wid_atomic, poly.synth_rec_f.weight_index);
assert(wid_atomic % stride_shift(poly, 1) == 0);
@@ -349,12 +433,15 @@ namespace StagewisePoly
//below is run at training time).
if (poly.cur_depth < min_depths_get(poly, wid_cur) && poly.training) {
if (parent_get(poly, wid_cur)) {
- //#ifdef DEBUG
- /* cout
+#ifdef DEBUG
+ cout
<< "FOUND A TRANSPLANT!!! moving [" << wid_cur
<< "] from depth " << (uint32_t) min_depths_get(poly, wid_cur)
- << " to depth " << poly.cur_depth << endl;*/
- //#endif //DEBUG
+ << " to depth " << poly.cur_depth << endl;
+#endif //DEBUG
+ //XXX arguably, should also fear transplants that occured with
+ //a different ft_offset ; e.g., need to look out for cross-reduction
+ //collisions. Have not played with this issue yet...
parent_toggle(poly, wid_cur);
}
min_depths_set(poly, wid_cur, poly.cur_depth);
@@ -395,8 +482,7 @@ namespace StagewisePoly
poly.cur_depth = 0;
poly.synth_rec_f.x = 1.0;
- poly.synth_rec_f.weight_index = constant_feat_masked(poly);
- poly.original_ec = &ec;
+ poly.synth_rec_f.weight_index = constant_feat_masked(poly); //note: not ft_offset'd
poly.training = training;
/*
* Another choice is to mark the constant feature as the single initial
@@ -416,16 +502,17 @@ namespace StagewisePoly
void predict(stagewise_poly &poly, learner &base, example &ec)
{
+ poly.original_ec = &ec;
synthetic_create(poly, ec, false);
base.predict(poly.synth_ec);
- label_data *ld = (label_data *) ec.ld;
- if (ld->label != FLT_MAX)
- ec.loss = poly.all->loss->getLoss(poly.all->sd, ld->prediction, ld->label) * ld->weight;
+ ec.partial_prediction = poly.synth_ec.partial_prediction;
+ ec.updated_prediction = poly.synth_ec.updated_prediction;
}
void learn(stagewise_poly &poly, learner &base, example &ec)
{
bool training = poly.all->training && ((label_data *) ec.ld)->label != FLT_MAX;
+ poly.original_ec = &ec;
if (training) {
if(poly.update_support) {
@@ -435,15 +522,21 @@ namespace StagewisePoly
synthetic_create(poly, ec, training);
base.learn(poly.synth_ec);
- ec.loss = poly.synth_ec.loss;
+ ec.partial_prediction = poly.synth_ec.partial_prediction;
+ ec.updated_prediction = poly.synth_ec.updated_prediction;
if (ec.example_counter
+ //following line is to avoid repeats when multiple reductions on same example.
+ //XXX ideally, would get all "copies" of an example before scheduling the support
+ //update, but how do we know?
+ && poly.last_example_counter != ec.example_counter
&& poly.batch_sz
&& ( (poly.batch_sz_double && !(ec.example_counter % poly.next_batch_sz))
|| (!poly.batch_sz_double && !(ec.example_counter % poly.batch_sz)))) {
poly.next_batch_sz *= 2; //no effect when !poly.batch_sz_double
poly.update_support = (poly.all->span_server == "" || poly.numpasses == 1);
}
+ poly.last_example_counter = ec.example_counter;
} else
predict(poly, base, ec);
}
@@ -477,27 +570,6 @@ namespace StagewisePoly
}
}
-
- void sanity_check_state(stagewise_poly &poly)
- {
- for (uint32_t i = 0; i != poly.all->length(); ++i)
- {
-#ifndef NDEBUG
- uint32_t wid = stride_shift(poly, i);
-#endif //NDEBUG
-
- assert( ! cycle_get(poly,wid) );
-
- assert( ! (min_depths_get(poly, wid) == default_depth && parent_get(poly, wid)) );
-
- assert( ! (min_depths_get(poly, wid) == default_depth && fabsf(poly.all->reg.weight_vector[wid]) > 0) );
- //assert( min_depths_get(poly, wid) != default_depth && fabsf(poly.all->reg.weight_vector[wid]) < tolerance );
-
- assert( ! (poly.depthsbits[wid_mask_un_shifted(poly, wid) * 2 + 1] & ~(parent_bit + cycle_bit + indicator_bit)) );
- }
- }
-
-
void end_pass(stagewise_poly &poly)
{
if (!!poly.batch_sz || (poly.all->span_server != "" && poly.numpasses > 1))
@@ -520,7 +592,8 @@ namespace StagewisePoly
* But it's unclear what the right behavior is in general for either
* case...
*/
- all_reduce<uint8_t, reduce_min_max>(poly.depthsbits, 2*poly.all->length(), all.span_server, all.unique_id, all.total, all.node, all.socks);
+ all_reduce<uint8_t, reduce_min_max>(poly.depthsbits, depthsbits_sizeof(poly),
+ all.span_server, all.unique_id, all.total, all.node, all.socks);
sum_input_sparsity_inc = (uint64_t)accumulate_scalar(all, all.span_server, (float)sum_input_sparsity_inc);
sum_sparsity_inc = (uint64_t)accumulate_scalar(all, all.span_server, (float)sum_sparsity_inc);
@@ -543,11 +616,6 @@ namespace StagewisePoly
poly.update_support = true;
poly.numpasses++;
}
-
-#ifdef DEBUG
- cout << "Sanity after sort\n";
- sanity_check_state(poly);
-#endif //DEBUG
}
void finish_example(vw &all, stagewise_poly &poly, example &ec)
@@ -575,12 +643,20 @@ namespace StagewisePoly
{
if (model_file.files.size() > 0)
bin_text_read_write_fixed(model_file, (char *) poly.depthsbits, depthsbits_sizeof(poly), "", read, "", 0, text);
+
+ //unfortunately, following can't go here since save_load called before gd::save_load and thus
+ //weight vector state uninitialiazed.
+ //#ifdef DEBUG
+ // cout << "Sanity check after save_load... " << flush;
+ // sanity_check_state(poly);
+ // cout << "done" << endl;
+ //#endif //DEBUG
}
learner *setup(vw &all, po::variables_map &vm)
{
- stagewise_poly *poly = (stagewise_poly *) calloc(1, sizeof(stagewise_poly));
+ stagewise_poly *poly = (stagewise_poly *) calloc_or_die(1, sizeof(stagewise_poly));
poly->all = &all;
depthsbits_create(*poly);
@@ -610,10 +686,15 @@ namespace StagewisePoly
poly->sum_sparsity_sync = 0;
poly->sum_input_sparsity_sync = 0;
poly->num_examples_sync = 0;
+ poly->last_example_counter = -1;
poly->numpasses = 1;
poly->update_support = false;
+ poly->original_ec = NULL;
poly->next_batch_sz = poly->batch_sz;
+ //following is so that saved models know to load us.
+ all.file_options.append(" --stage_poly");
+
learner *l = new learner(poly, all.l);
l->set_learn<stagewise_poly, learn>();
l->set_predict<stagewise_poly, predict>();
diff --git a/vowpalwabbit/vw.h b/vowpalwabbit/vw.h
index 4240086d..7226f9b6 100644
--- a/vowpalwabbit/vw.h
+++ b/vowpalwabbit/vw.h
@@ -52,6 +52,7 @@ namespace VW {
void parse_example_label(vw&all, example&ec, string label);
example* new_unused_example(vw& all);
example* get_example(parser* pf);
+ float get_topic_prediction(example*ec, size_t i);//i=0 to max topic -1
float get_label(example*ec);
float get_importance(example*ec);
float get_initial(example*ec);
diff --git a/vowpalwabbit/vw_static.vcxproj b/vowpalwabbit/vw_static.vcxproj
index 2fc8e304..eec82548 100644
--- a/vowpalwabbit/vw_static.vcxproj
+++ b/vowpalwabbit/vw_static.vcxproj
@@ -236,6 +236,7 @@
<ItemGroup>
<ClInclude Include="autolink.h" />
<ClInclude Include="accumulate.h" />
+ <ClInclude Include="active.h" />
<ClInclude Include="allreduce.h" />
<ClInclude Include="bfgs.h" />
<ClInclude Include="binary.h" />
@@ -288,6 +289,7 @@
<ItemGroup>
<ClCompile Include="autolink.cc" />
<ClCompile Include="accumulate.cc" />
+ <ClCompile Include="active.cc" />
<ClCompile Include="allreduce.cc" />
<ClCompile Include="binary.cc" />
<ClCompile Include="bfgs.cc" />
diff --git a/vowpalwabbit/vwdll.cpp b/vowpalwabbit/vwdll.cpp
index dae3e519..961e341f 100644
--- a/vowpalwabbit/vwdll.cpp
+++ b/vowpalwabbit/vwdll.cpp
@@ -113,6 +113,11 @@ extern "C"
return VW::get_label(static_cast<example*>(e));
}
+ VW_DLL_MEMBER float VW_CALLING_CONV VW_GetTopicPrediction(VW_EXAMPLE e, size_t i)
+ {
+ return VW::get_topic_prediction(static_cast<example*>(e), i);
+ }
+
VW_DLL_MEMBER float VW_CALLING_CONV VW_GetImportance(VW_EXAMPLE e)
{
return VW::get_importance(static_cast<example*>(e));
diff --git a/vowpalwabbit/vwdll.h b/vowpalwabbit/vwdll.h
index d6b78b18..820640e4 100644
--- a/vowpalwabbit/vwdll.h
+++ b/vowpalwabbit/vwdll.h
@@ -50,6 +50,7 @@ extern "C"
VW_DLL_MEMBER float VW_CALLING_CONV VW_GetImportance(VW_EXAMPLE e);
VW_DLL_MEMBER float VW_CALLING_CONV VW_GetInitial(VW_EXAMPLE e);
VW_DLL_MEMBER float VW_CALLING_CONV VW_GetPrediction(VW_EXAMPLE e);
+ VW_DLL_MEMBER float VW_CALLING_CONV VW_GetTopicPrediction(VW_EXAMPLE e, size_t i);
VW_DLL_MEMBER size_t VW_CALLING_CONV VW_GetTagLength(VW_EXAMPLE e);
VW_DLL_MEMBER const char* VW_CALLING_CONV VW_GetTag(VW_EXAMPLE e);
VW_DLL_MEMBER size_t VW_CALLING_CONV VW_GetFeatureNumber(VW_EXAMPLE e);