{
 "metadata": {
  "name": "",
  "signature": "sha256:3c586c0a6f1851786f8d3784d3af46a36319bb2212ee8682ec13ebff8953ab5f"
 },
 "nbformat": 3,
 "nbformat_minor": 0,
 "worksheets": [
  {
   "cells": [
    {
     "cell_type": "markdown",
     "metadata": {},
     "source": [
      "This tutorial walks you through writing learning-to-search code using the VW Python interface. Once you've completed this, you can graduate to the C++ version, which will be faster for the computer but more painful for you.\n",
      "\n",
      "The \"learning to search\" paradigm solves problems that look like the following. You have a sequence of decisions to make. At the end of making these decisions, the world tells you how bad your decisions were. You want to condition later decisions on earlier decisions. But thankfully, at \"training time,\" you have access to an *oracle* that will tell you the right answer."
     ]
    },
    {
     "cell_type": "heading",
     "level": 1,
     "metadata": {},
     "source": [
      "A basic part of speech tagger"
     ]
    },
    {
     "cell_type": "markdown",
     "metadata": {},
     "source": [
      "Let's start with a simple example: sequence labeling for part-of-speech tagging. The goal is to take a sequence of words (\"the monster ate a big sandwich\") and label them with their parts of speech (in this case: Det Noun Verb Det Adj Noun).\n",
      "\n",
      "We will solve this problem with left-to-right search. That is, we'll label the first word, then the second, then the third, and so on.\n",
      "\n",
      "For any vw project in Python, we have to start by importing the pyvw library:"
     ]
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "import pyvw"
     ],
     "language": "python",
     "metadata": {},
     "outputs": [],
     "prompt_number": 37
    },
    {
     "cell_type": "markdown",
     "metadata": {},
     "source": [
      "Now, let's define our data. We'll start by defining the labels (one annoying thing is that labels in vw have to be integers), followed by a small dataset of tagged sentences:"
     ]
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "DET  = 1\n",
      "NOUN = 2\n",
      "VERB = 3\n",
      "ADJ  = 4\n",
      "my_dataset = [ [(DET , 'the'),\n",
      "                (NOUN, 'monster'),\n",
      "                (VERB, 'ate'),\n",
      "                (DET , 'a'),\n",
      "                (ADJ , 'big'),\n",
      "                (NOUN, 'sandwich')],\n",
      "               [(DET , 'the'),\n",
      "                (NOUN, 'sandwich'),\n",
      "                (VERB, 'was'),\n",
      "                (ADJ , 'tasty')],\n",
      "               [(NOUN, 'it'),\n",
      "                (VERB, 'ate'),\n",
      "                (NOUN, 'it'),\n",
      "                (ADJ , 'all')] ]\n",
      "print(my_dataset[2])"
     ],
     "language": "python",
     "metadata": {},
     "outputs": [
      {
       "output_type": "stream",
       "stream": "stdout",
       "text": [
        "[(2, 'it'), (3, 'ate'), (2, 'it'), (4, 'all')]\n"
       ]
      }
     ],
     "prompt_number": 38
    },
    {
     "cell_type": "markdown",
     "metadata": {},
     "source": [
      "Here we have an example of a (correctly) tagged sentence.\n",
      "\n",
      "Now, we need to write the structured prediction code. To do this, we have to write a new class that derives from the `pyvw.SearchTask` class.\n",
      "\n",
      "This class *must* have two functions: `__init__` and `_run`.\n",
      "\n",
      "The initialization function takes three arguments (plus `self`): a vw object (`vw`), a search object (`sch`), and the number of actions (`num_actions`) that this object has been initialized with. Within the initialization function, we must first initialize the parent class, and then we can set whatever options we want via `sch.set_options(...)`. Of course we can also do whatever additional initialization we like.\n",
      "\n",
      "The `_run` function executes the sequence of decisions on a given input. The input will be of whatever type our data is (so, in the above example, it will be a list of (label,word) pairs).\n",
      "\n",
      "Here is a basic implementation of sequence labeling:"
     ]
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "class SequenceLabeler(pyvw.SearchTask):\n",
      "    def __init__(self, vw, sch, num_actions):\n",
      "        pyvw.SearchTask.__init__(self, vw, sch, num_actions)\n",
      "        sch.set_options( sch.AUTO_HAMMING_LOSS | sch.AUTO_CONDITION_FEATURES )\n",
      "\n",
      "    def _run(self, sentence):\n",
      "        output = []\n",
      "        for n in range(len(sentence)):\n",
      "            pos,word = sentence[n]\n",
      "            with self.vw.example({'w': [word]}) as ex:\n",
      "                pred = self.sch.predict(examples=ex, my_tag=n+1, oracle=pos, condition=(n,'p'))\n",
      "                output.append(pred)\n",
      "        return output"
     ],
     "language": "python",
     "metadata": {},
     "outputs": [],
     "prompt_number": 39
    },
    {
     "cell_type": "markdown",
     "metadata": {},
     "source": [
      "Let's unpack this a bit.\n",
      "\n",
      "The `__init__` function is simple. It first calls the parent initializer and then sets two options designed to make the programmer's life easier. The first is `AUTO_HAMMING_LOSS`. Remember earlier we said that once the sequence of decisions is made, you have to say how bad it was? This option says that we want the loss to be computed automatically by comparing the individual decisions to the oracle decisions and counting the incorrect ones.\n",
      "\n",
      "The second is `AUTO_CONDITION_FEATURES`. This is a bit subtler. Later in the `_run` function, we will say that the label of the `n`th word depends on the label of the `n-1`th word. In order to get the underlying classifier to *pay attention* to that conditioning, we need to add features. We could do that manually (we'll do this later) or we can ask vw to do it automatically for us. For simplicity, we choose the latter.\n",
      "\n",
      "The `_run` function takes a sentence (list of pos/word pairs) as input. We loop over each word position `n` in the sentence and extract the `pos,word` pair. We then construct a VW example that consists of a single feature (the `word`) in the 'w' namespace. Given that example `ex`, we make a search-based prediction by calling `self.sch.predict(...)`. This is a fairly complicated function that takes a number of arguments. Here, we are calling it with the following:\n",
      "\n",
      " - `examples=ex`: This tells the predictor what features to use.\n",
      " - `my_tag=n+1`: In general, we want to condition the prediction of the `n`th label on the `n-1`th label. In order to do this, we have to give each prediction a \"name\" so that we can refer back to it in the future. This name needs to be an integer `>= 1`. So we'll call the first word `1`, the second word `2`, and so on. It has to be `n+1` and not `n` because of the 1-based requirement.\n",
      " - `oracle=pos`: As mentioned before, on training data, we need to tell the system what the \"true\" (or \"best\") decision is at this point in time. Here, it is the given part of speech label.\n",
      " - `condition=(n,'p')`: This says that this prediction depends on the output of whichever-prediction-was-called-`n`, and that the \"nature\" of that condition is called 'p' (for \"predecessor\" in this case, though this is entirely up to you).\n",
      "\n",
      "Now, we're ready to train the model. We do this in three steps. First, we initialize a vw object, telling it that we have a `--search` task with 4 labels, that the specific type of `--search_task` is `hook` (you will always use the `hook` task), and that we want it to be quiet and use a larger `ring_size` (you can ignore the `ring_size` for now)."
     ]
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "vw = pyvw.vw(\"--search 4 --quiet --search_task hook --ring_size 1024\")"
     ],
     "language": "python",
     "metadata": {},
     "outputs": [],
     "prompt_number": 40
    },
    {
     "cell_type": "markdown",
     "metadata": {},
     "source": [
      "Next, we need to initialize the search task. We use the `vw.init_search_task` function to do this:"
     ]
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "sequenceLabeler = vw.init_search_task(SequenceLabeler)"
     ],
     "language": "python",
     "metadata": {},
     "outputs": [],
     "prompt_number": 41
    },
    {
     "cell_type": "markdown",
     "metadata": {},
     "source": [
      "Finally, we can train on the dataset we defined earlier, using `sequenceLabeler.learn` (the `.learn` function is inherited from the `pyvw.SearchTask` class). Note that we pass `my_dataset.__iter__` itself rather than calling it: `.learn` expects something it can call to get an iterator over the data. Whatever type of data that iterator yields is what gets fed to your `_run` function."
     ]
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "sequenceLabeler.learn(my_dataset.__iter__)"
     ],
     "language": "python",
     "metadata": {},
     "outputs": [],
     "prompt_number": 42
    },
    {
     "cell_type": "markdown",
     "metadata": {},
     "source": [
      "Of course, we want to see if it's learned anything. So let's create a single test example:"
     ]
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "test_example = [ (0,w) for w in \"the sandwich ate a monster\".split() ]\n",
      "print(test_example)"
     ],
     "language": "python",
     "metadata": {},
     "outputs": [
      {
       "output_type": "stream",
       "stream": "stdout",
       "text": [
        "[(0, 'the'), (0, 'sandwich'), (0, 'ate'), (0, 'a'), (0, 'monster')]\n"
       ]
      }
     ],
     "prompt_number": 43
    },
    {
     "cell_type": "markdown",
     "metadata": {},
     "source": [
      "We've used `0` as the labels so you can be sure that vw isn't cheating and it's actually making predictions:"
     ]
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "out = sequenceLabeler.predict(test_example)\n",
      "print(out)"
     ],
     "language": "python",
     "metadata": {},
     "outputs": [
      {
       "output_type": "stream",
       "stream": "stdout",
       "text": [
        "[1, 2, 3, 1, 2]\n"
       ]
      }
     ],
     "prompt_number": 44
    },
    {
     "cell_type": "markdown",
     "metadata": {},
     "source": [
      "If we look back at our POS tag definitions, this is DET NOUN VERB DET NOUN, which is indeed correct!"
     ]
    },
    {
     "cell_type": "heading",
     "level": 1,
     "metadata": {},
     "source": [
      "Removing the AUTO features"
     ]
    },
    {
     "cell_type": "markdown",
     "metadata": {},
     "source": [
      "In the above example we used both AUTO_HAMMING_LOSS and AUTO_CONDITION_FEATURES. To make more explicit what these are doing, let's rewrite our `SequenceLabeler` class without them! Here's a version that gets rid of both simultaneously. It is only modestly more complex:"
     ]
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "class SequenceLabeler2(pyvw.SearchTask):\n",
      "    def __init__(self, vw, sch, num_actions):\n",
      "        pyvw.SearchTask.__init__(self, vw, sch, num_actions)\n",
      "\n",
      "    def _run(self, sentence):\n",
      "        output = []\n",
      "        loss = 0.\n",
      "        for n in range(len(sentence)):\n",
      "            pos,word = sentence[n]\n",
      "            prevPred = output[n-1] if n > 0 else '<s>'\n",
      "            with self.vw.example({'w': [word], 'p': [prevPred]}) as ex:\n",
      "                pred = self.sch.predict(examples=ex, my_tag=n+1, oracle=pos, condition=(n,'p'))\n",
      "                output.append(pred)\n",
      "                if pred != pos:\n",
      "                    loss += 1.\n",
      "        self.sch.loss(loss)\n",
      "        return output\n",
      "    \n",
      "sequenceLabeler2 = vw.init_search_task(SequenceLabeler2)\n",
      "sequenceLabeler2.learn(my_dataset.__iter__)\n",
      "print(sequenceLabeler2.predict( [(0,w) for w in \"the sandwich ate a monster\".split()] ))"
     ],
     "language": "python",
     "metadata": {},
     "outputs": [
      {
       "output_type": "stream",
       "stream": "stdout",
       "text": [
        "[1, 2, 3, 1, 2]\n"
       ]
      }
     ],
     "prompt_number": 45
    },
    {
     "cell_type": "markdown",
     "metadata": {},
     "source": [
      "If executed correctly, this should have printed `[1, 2, 3, 1, 2]`.\n",
      "\n",
      "There are essentially two things that changed here. In order to get rid of AUTO_HAMMING_LOSS, we had to keep track of how many errors the predictor had made. This is done by checking whether `pred != pos` inside the inner loop, and then at the end calling `self.sch.loss(loss)` to tell the search procedure how well we did.\n",
      "\n",
      "In order to get rid of AUTO_CONDITION_FEATURES, we need to explicitly add the previous prediction as features to the example we are predicting with. Here, we've done this by extracting the previous prediction (`prevPred`) and explicitly adding it as a feature (in the 'p' namespace) during the example construction.\n",
      "\n",
      "**Important Note:** even though we're not using AUTO_CONDITION_FEATURES, we *still* must tell the search process that this prediction depends on the previous prediction. We need to do this because the learning algorithm automatically memoizes certain computations, so it needs to know that a memoized prediction *might have been different* if a previous decision were different."
     ]
    },
    {
     "cell_type": "heading",
     "level": 1,
     "metadata": {},
     "source": [
      "Very silly Covington-esque dependency parsing"
     ]
    },
    {
     "cell_type": "markdown",
     "metadata": {},
     "source": [
      "Let's also try a variant of dependency parsing to see that this works for more than just sequence-labeling tasks. First we need to define some data:"
     ]
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "# the label for each word is its parent, or -1 for root\n",
      "my_dataset = [ [(\"the\",      1),   # 0\n",
      "                (\"monster\",  2),   # 1\n",
      "                (\"ate\",     -1),   # 2\n",
      "                (\"a\",        5),   # 3\n",
      "                (\"big\",      5),   # 4\n",
      "                (\"sandwich\", 2) ]  # 5\n",
      "                ,\n",
      "               [(\"the\",      1),   # 0\n",
      "                (\"sandwich\", 2),   # 1\n",
      "                (\"is\",      -1),   # 2\n",
      "                (\"tasty\",    2)]   # 3\n",
      "                ,\n",
      "               [(\"a\",        1),   # 0\n",
      "                (\"sandwich\", 2),   # 1\n",
      "                (\"ate\",     -1),   # 2\n",
      "                (\"itself\",   2),   # 3\n",
      "                ]\n",
      "                ]"
     ],
     "language": "python",
     "metadata": {},
     "outputs": [],
     "prompt_number": 46
    },
    {
     "cell_type": "markdown",
     "metadata": {},
     "source": [
      "For instance, in the first sentence, the parent of \"the\" is \"monster\"; the parent of \"monster\" is \"ate\"; and \"ate\" is the root.\n",
      "\n",
      "The basic idea of a Covington-style dependency parser is to loop over all O(N^2) word pairs and ask if one is the parent of the other. In a real parser you would want to make sure that you don't have cycles, that you have a unique root and (perhaps) that the resulting graph is projective. I'm not doing that here. Hopefully I'll add a shift-reduce parser example later that *does* do this. Here's an implementation of this idea:"
     ]
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "class CovingtonDepParser(pyvw.SearchTask):\n",
      "    def __init__(self, vw, sch, num_actions):\n",
      "        pyvw.SearchTask.__init__(self, vw, sch, num_actions)\n",
      "        sch.set_options( sch.AUTO_HAMMING_LOSS | sch.AUTO_CONDITION_FEATURES )\n",
      "\n",
      "    def _run(self, sentence):\n",
      "        N = len(sentence)\n",
      "        # initialize our output so everything is a root\n",
      "        output = [-1 for i in range(N)]\n",
      "        for n in range(N):\n",
      "            wordN,parN = sentence[n]\n",
      "            for m in range(-1,N):\n",
      "                if m == n: continue\n",
      "                wordM = sentence[m][0] if m >= 0 else \"*root*\"\n",
      "                # ask the question: is m the parent of n?\n",
      "                isParent = 2 if m == parN else 1\n",
      "\n",
      "                # construct an example\n",
      "                dir = 'l' if m < n else 'r'\n",
      "                with self.vw.example({'a': [wordN, dir + '_' + wordN], 'b': [wordM, dir + '_' + wordN], 'p': [wordN + '_' + wordM, dir + '_' + wordN + '_' + wordM],\n",
      "                                      'd': [ str(m-n <= d) + '<=' + str(d) for d in [-8, -4, -2, -1, 1, 2, 4, 8] ] +\n",
      "                                           [ str(m-n >= d) + '>=' + str(d) for d in [-8, -4, -2, -1, 1, 2, 4, 8] ] }) as ex:\n",
      "                    pred = self.sch.predict(examples  = ex,\n",
      "                                            my_tag    = (m+1)*N + n + 1,\n",
      "                                            oracle    = isParent,\n",
      "                                            condition = [ (max(0, (m  )*N + n + 1), 'p'),\n",
      "                                                          (max(0, (m+1)*N + n    ), 'q') ])\n",
      "                    if pred == 2:\n",
      "                        output[n] = m\n",
      "                        break\n",
      "        return output"
     ],
     "language": "python",
     "metadata": {},
     "outputs": [],
     "prompt_number": 47
    },
    {
     "cell_type": "markdown",
     "metadata": {},
     "source": [
      "In this, `output` stores the predicted tree and is initialized with every word being a root. We then loop over every word (`n`) and every possible parent (`m`, which can be -1, though that's really kind of unnecessary).\n",
      "\n",
      "The features are basically the words under consideration, the words paired with the direction of the edge, the pair of words, and then a bunch of (binned) distance features.\n",
      "\n",
      "We can train and run this parser with:"
     ]
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "vw = pyvw.vw(\"--search 2 --quiet --search_task hook --ring_size 1024\")\n",
      "task = vw.init_search_task(CovingtonDepParser)\n",
      "for p in range(10): # do ten passes over the training data\n",
      "    task.learn(my_dataset.__iter__)\n",
      "print('testing')\n",
      "print(task.predict( [(w,-1) for w in \"the monster ate a sandwich\".split()] ))\n",
      "print('should have printed [ 1 2 -1 4 2 ]')"
     ],
     "language": "python",
     "metadata": {},
     "outputs": [
      {
       "output_type": "stream",
       "stream": "stdout",
       "text": [
        "testing\n",
        "[1, 2, -1, 4, 2]\n",
        "should have printed [ 1 2 -1 4 2 ]\n"
       ]
      }
     ],
     "prompt_number": 48
    },
    {
     "cell_type": "markdown",
     "metadata": {},
     "source": [
      "One could argue that a more natural way to do this would be with LDF rather than the inner loop over `m`. We'll do that next."
     ]
    },
    {
     "cell_type": "heading",
     "level": 1,
     "metadata": {},
     "source": [
      "LDF-based Covington-style dependency parser"
     ]
    },
    {
     "cell_type": "markdown",
     "metadata": {},
     "source": [
      "One of the weirdnesses in the previous parser implementation is that it makes N-many binary decisions per word (\"is word `m` my parent?\") rather than a single N-way decision. The latter makes more sense.\n",
      "\n",
      "The challenge is that you cannot set this up as a vanilla multiclass classification problem because (a) the number of \"classes\" is a function of the input (a length N sentence will have N classes) and (b) class \"1\" and \"2\" don't mean anything consistently across examples.\n",
      "\n",
      "The way around this is label-dependent features (LDF). In LDF mode, the class ids are (essentially -- see caveat below) irrelevant. Instead, you simply provide features that depend on the label (hence \"LDF\"). In particular, for each possible label, you provide a *different* feature vector, and the goal of learning is to pick one of those as the \"correct\" one.\n",
      "\n",
      "Here's a re-implementation of Covington using LDF:"
     ]
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "class CovingtonDepParserLDF(pyvw.SearchTask):\n",
      "    def __init__(self, vw, sch, num_actions):\n",
      "        pyvw.SearchTask.__init__(self, vw, sch, num_actions)\n",
      "        sch.set_options( sch.AUTO_HAMMING_LOSS | sch.IS_LDF | sch.AUTO_CONDITION_FEATURES )\n",
      "\n",
      "    def makeExample(self, sentence, n, m):\n",
      "        wordN = sentence[n][0]\n",
      "        wordM = sentence[m][0] if m >= 0 else '*ROOT*'\n",
      "        dir   = 'l' if m < n else 'r'\n",
      "        ex = self.vw.example( { 'a': [wordN, dir + '_' + wordN],\n",
      "                                'b': [wordM, dir + '_' + wordN],\n",
      "                                'p': [wordN + '_' + wordM, dir + '_' + wordN + '_' + wordM],\n",
      "                                'd': [ str(m-n <= d) + '<=' + str(d) for d in [-8, -4, -2, -1, 1, 2, 4, 8] ] +\n",
      "                                     [ str(m-n >= d) + '>=' + str(d) for d in [-8, -4, -2, -1, 1, 2, 4, 8] ] },\n",
      "                              labelType=self.vw.lCostSensitive)\n",
      "        ex.set_label_string(str(m+2) + \":0\")\n",
      "        return ex\n",
      "            \n",
      "    def _run(self, sentence):\n",
      "        N = len(sentence)\n",
      "        # initialize our output so everything is a root\n",
      "        output = [-1 for i in range(N)]\n",
      "        for n in range(N):\n",
      "            # make LDF examples\n",
      "            examples = [ self.makeExample(sentence,n,m) for m in range(-1,N) if n != m ]\n",
      "\n",
      "            # truth\n",
      "            parN = sentence[n][1]\n",
      "            oracle = parN+1 if parN < n else parN   # +1 because root (-1) sits at index 0; no shift past n because n==m is excluded\n",
      "\n",
      "            # make a prediction\n",
      "            pred = self.sch.predict(examples  = examples,\n",
      "                                    my_tag    = n+1,\n",
      "                                    oracle    = oracle,\n",
      "                                    condition = [ (n, 'p'), (n-1, 'q') ] )\n",
      "\n",
      "            output[n] = pred-1 if pred <= n else pred # undo the index shift used for the oracle above\n",
      "\n",
      "            for ex in examples: ex.finish()  # clean up\n",
      "            \n",
      "        return output\n"
     ],
     "language": "python",
     "metadata": {},
     "outputs": [],
     "prompt_number": 49
    },
    {
     "cell_type": "markdown",
     "metadata": {},
     "source": [
      "There are a few things going on here. Let's focus first on the `__init__` function. The only difference here is that when we call `sch.set_options` we provide `sch.IS_LDF` to declare that this is an LDF task.\n",
      "\n",
      "Let's skip the `makeExample` function for a minute and look at the `_run` function. You should recognize most of this from the non-LDF version. We initialize the `output` (parent) of every word to `-1` (meaning that every word is connected to the root).\n",
      "\n",
      "For each word `n`, we construct `N`-many examples: one for every -1..(N-1), except for the current word `n` because you cannot have self-loops. If we were being more clever, we would only do the ones that won't result in the creation of a cycle, but we're not being clever.\n",
      "\n",
      "Now, because the \"labels\" are just examples, it's a bit more complicated to specify the oracle. The oracle is an *index* into the examples list. So if `oracle` is the oracle action, then `examples[oracle]` is the corresponding example. We compute the oracle as follows. `parN` is the *actual* parent, which is going to be something in the range `-1 .. (N-1)`. If `parN < n` (this is a left arrow), then the oracle index is `parN+1` because the root (`-1`) is index `0` and so on. If `parN > n` (note: it cannot be equal to `n`) then, because `n == m` is left out of the examples list, the correct index is just `parN`. Phew.\n",
      "\n",
      "We then ask for a prediction. Now, instead of giving a single example, we give the list of examples. The tag works the same way, as does the conditioning.\n",
      "\n",
      "Once we get a prediction out (called `pred`) we need to figure out what parent it actually corresponds to. This is simply un-doing the computation two paragraphs ago!\n",
      "\n",
      "Finally -- and this is skippable if you trust the Python garbage collector -- we tell VW that we're done with all the examples we created. We do this just to be pedantic; it's optional.\n",
      "\n",
      "Okay, now let's go back to the `makeExample` function. This takes two word ids (`n` and `m`) and makes an example that roughly says \"what would it look like if I had an edge from `n` to `m`?\" We construct basically the same features as before. There are two major changes, though:\n",
      "\n",
      "1. When we run `self.vw.example(...)` we provide `labelType=self.vw.lCostSensitive` as an argument. This is because under the hood, vw treats LDF examples as cost-sensitive classification examples. This means they need to have cost-sensitive labels, so that's how we need to create them.\n",
      "\n",
      "2. We explicitly set the label of this example to `str(m+2)+\":0\"`. What is this? Well, this is _optional_ but recommended. Here's the issue. In LDF mode, recall that labels have no intrinsic meaning. This means that when vw does auto-conditioning, it's not really clear what to use as the \"previous prediction.\" By giving explicit label names (in this case, m+2) we're recording the position of the last parent, which may be useful for predicting the next parent. We could avoid this necessity if we did our own feature engineering on the history, but for now, this seems to capture the right intuition.\n",
      "\n",
      "Given all this, we can now train and test our parser:"
     ]
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "vw = pyvw.vw(\"--search 0 --csoaa_ldf m --search_task hook --ring_size 1024 --quiet\")\n",
      "task = vw.init_search_task(CovingtonDepParserLDF)\n",
      "for p in range(2): # do two passes over the training data\n",
      "    task.learn(my_dataset.__iter__)\n",
      "print(task.predict( [(w,-1) for w in \"the monster ate a sandwich\".split()] ))"
     ],
     "language": "python",
     "metadata": {},
     "outputs": [
      {
       "output_type": "stream",
       "stream": "stdout",
       "text": [
        "[1, 2, -1, 4, 2]\n"
       ]
      }
     ],
     "prompt_number": 50
    },
    {
     "cell_type": "markdown",
     "metadata": {},
     "source": [
      "The correct parse of this sentence is `[1, 2, -1, 4, 2]` which is what this should have printed.\n",
      "\n",
      "There are two major things to notice in the initialization of VW. The first is that we say `--search 0`. The zero labels argument to `--search` declares that this is going to be an LDF task. We also have to tell VW that we want an LDF-enabled cost-sensitive learner, which is what `--csoaa_ldf m` does (if you're wondering, `m` means \"multiline\" -- just treat it as something you have to do). The rest should be familiar."
     ]
    },
    {
     "cell_type": "heading",
     "level": 1,
     "metadata": {},
     "source": [
      "A simple word-alignment model"
     ]
    },
    {
     "cell_type": "markdown",
     "metadata": {},
     "source": [
      "Okay, as a last example we'll do a simple word alignment model in the spirit of the IBM models. Note that this will be a *supervised* model; doing unsupervised stuff is a bit trickier.\n",
      "\n",
      "Here's some word alignment data. The dataset is triples of `E, A, F` where `A[i]` = list of words `E[i]` aligned to, or `[]` for null-aligned:"
     ]
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "my_dataset = [\n",
      "    ( \"the blue house\".split(),\n",
      "      ([0], [2], [1]),\n",
      "      \"la maison bleue\".split() ),\n",
      "    ( \"the house\".split(),\n",
      "      ([0], [1]),\n",
      "      \"la maison\".split() ),\n",
      "    ( \"the flower\".split(),\n",
      "      ([0], [1]),\n",
      "      \"la fleur\".split() )\n",
      "    ]"
     ],
     "language": "python",
     "metadata": {},
     "outputs": [],
     "prompt_number": 51
    },
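    {
     "cell_type": "markdown",
     "metadata": {},
     "source": [
      "To make the format concrete, here is a small sketch that pairs each English word with the foreign words it is aligned to. It carries its own copy of the first training triple so that it runs on its own; the helper name `aligned_pairs` is ours for illustration and is not part of `pyvw`:"
     ]
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "def aligned_pairs(E, A, F):\n",
      "    # pair each English word with the foreign words at the positions listed in A[i]\n",
      "    return [ (E[i], [F[j] for j in A[i]]) for i in range(len(E)) ]\n",
      "\n",
      "example = ( 'the blue house'.split(), ([0], [2], [1]), 'la maison bleue'.split() )\n",
      "for e, fs in aligned_pairs(*example):\n",
      "    # an empty list of foreign words means the English word is null-aligned\n",
      "    print('%s -> %s' % (e, ' '.join(fs) if fs else 'Null'))"
     ],
     "language": "python",
     "metadata": {},
     "outputs": []
    },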
    {
     "cell_type": "markdown",
     "metadata": {},
     "source": [
      "It's going to be useful to compute alignment mismatches at the word level between true alignments (like `[1,2]`) and predicted alignments (like `[2,3,4]`). We use intersection-over-union error:"
     ]
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "def alignmentError(true, sys):\n",
      "    t = set(true)\n",
      "    s = set(sys)\n",
      "    if len(t | s) == 0: return 0.\n",
      "    return 1. - float(len(t & s)) / float(len(t | s))"
     ],
     "language": "python",
     "metadata": {},
     "outputs": [],
     "prompt_number": 52
    },
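    {
     "cell_type": "markdown",
     "metadata": {},
     "source": [
      "As a quick sanity check, the snippet below applies the same intersection-over-union error to the two alignments mentioned above. It is a standalone sketch: `iou_error` is just a local copy of `alignmentError` so that the cell runs on its own."
     ]
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "def iou_error(true, sys):\n",
      "    # same computation as alignmentError: 1 - |true & sys| / |true | sys|\n",
      "    t, s = set(true), set(sys)\n",
      "    if len(t | s) == 0: return 0.\n",
      "    return 1. - float(len(t & s)) / float(len(t | s))\n",
      "\n",
      "print(iou_error([1,2], [2,3,4]))  # 0.75: one shared position out of four in the union\n",
      "print(iou_error([1,2], [1,2]))    # 0.0: identical alignments\n",
      "print(iou_error([], []))          # 0.0: both null-aligned"
     ],
     "language": "python",
     "metadata": {},
     "outputs": []
    },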
    {
     "cell_type": "markdown",
     "metadata": {},
     "source": [
      "And now we can define our structured prediction task. This is also an LDF problem. Basically for each word on the English side, we'll loop over all possible phrases on the Foreign side to which it could align (maximum phrase length of three). For each of these options we'll create an example to be fed into the LDF classifier. We also ensure that the same foreign word cannot be covered by multiple English words, though this might not be a good idea in general."
     ]
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "class WordAligner(pyvw.SearchTask):\n",
      "    def __init__(self, vw, sch, num_actions):\n",
      "        pyvw.SearchTask.__init__(self, vw, sch, num_actions)\n",
      "        sch.set_options( sch.AUTO_HAMMING_LOSS | sch.IS_LDF | sch.AUTO_CONDITION_FEATURES )\n",
      "\n",
      "    def makeExample(self, E, F, i, j0, l):\n",
      "        f  = 'Null' if j0 is None else [ F[j0+k] for k in range(l+1) ]\n",
      "        ex = self.vw.example( { 'e': E[i],\n",
      "                                'f': f,\n",
      "                                'p': '_'.join(f),\n",
      "                                'l': str(l),\n",
      "                                'o': [str(i-j0), str(i-j0-l)] if j0 is not None else [] },\n",
      "                              labelType = self.vw.lCostSensitive )\n",
      "        lab = 'Null' if j0 is None else str(j0+l)\n",
      "        ex.set_label_string(lab + ':0')\n",
      "        return ex\n",
      "        \n",
      "    def _run(self, alignedSentence):\n",
      "        E,A,F = alignedSentence\n",
      "\n",
      "        # for each E word, we pick a F span\n",
      "        covered = {}  # which F words have been covered so far?\n",
      "        output  = []\n",
      "        \n",
      "        for i in range(len(E)):\n",
      "            examples = []  # contains vw examples\n",
      "            spans    = []  # contains triples (alignment error, index in examples, [range])\n",
      "            \n",
      "            # empty span:\n",
      "            examples.append( self.makeExample(E, F, i, None, None) )\n",
      "            spans.append( (alignmentError(A[i], []), 0, []) )\n",
      "\n",
      "            # non-empty spans\n",
      "            for j0 in range(len(F)):\n",
      "                for l in range(3):   # max phrase length of 3\n",
      "                    if j0+l >= len(F): break\n",
      "                    if covered.has_key(j0+l): break\n",
      "\n",
      "                    id = len(examples)\n",
      "                    examples.append( self.makeExample(E, F, i, j0, l) )\n",
      "                    spans.append( (alignmentError(A[i], range(j0,j0+l+1)), id, range(j0,j0+l+1)) )\n",
      "\n",
      "            sortedSpans = []\n",
      "            for s in spans: sortedSpans.append(s)\n",
      "            sortedSpans.sort()\n",
      "            oracle = []\n",
      "            for id in range(len(sortedSpans)):\n",
      "                if sortedSpans[id][0] > sortedSpans[0][0]: break\n",
      "                oracle.append( sortedSpans[id][1] )\n",
      "                \n",
      "            pred = self.sch.predict(examples  = examples,\n",
      "                                    my_tag    = i+1,\n",
      "                                    oracle    = oracle,\n",
      "                                    condition = [ (i, 'p'), (i-1, 'q') ] )\n",
      "\n",
      "            for ex in examples: ex.finish()\n",
      "\n",
      "            output.append( spans[pred][2] )\n",
      "            for j in spans[pred][2]:\n",
      "                covered[j] = True\n",
      "\n",
      "        return output"
     ],
     "language": "python",
     "metadata": {},
     "outputs": [],
     "prompt_number": 53
    },
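    {
     "cell_type": "markdown",
     "metadata": {},
     "source": [
      "The span enumeration and oracle selection in `_run` can be tried out without VW. The following is a minimal sketch under that assumption: `candidate_spans` and `oracle_indices` are hypothetical helper names, `iou_error` is a local copy of `alignmentError`, and no VW examples are built or predictions made."
     ]
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "def iou_error(true, sys):\n",
      "    t, s = set(true), set(sys)\n",
      "    if len(t | s) == 0: return 0.\n",
      "    return 1. - float(len(t & s)) / float(len(t | s))\n",
      "\n",
      "def candidate_spans(nF, covered, maxLen=3):\n",
      "    # the empty (null) span first, then every span of up to maxLen uncovered F positions\n",
      "    spans = [ [] ]\n",
      "    for j0 in range(nF):\n",
      "        for l in range(maxLen):\n",
      "            if j0+l >= nF or (j0+l) in covered: break\n",
      "            spans.append( list(range(j0, j0+l+1)) )\n",
      "    return spans\n",
      "\n",
      "def oracle_indices(truth, spans):\n",
      "    # indices of all candidate spans that tie for the lowest error against truth\n",
      "    errs = [ iou_error(truth, s) for s in spans ]\n",
      "    return [ k for k in range(len(spans)) if errs[k] == min(errs) ]\n",
      "\n",
      "spans = candidate_spans(3, covered=set([0]))  # F has 3 words, position 0 already used\n",
      "print(spans)                        # [[], [1], [1, 2], [2]]\n",
      "print(oracle_indices([2], spans))   # [3]: only the span [2] has zero error\n",
      "print(oracle_indices([0], spans))   # [0, 1, 2, 3]: every candidate is equally wrong"
     ],
     "language": "python",
     "metadata": {},
     "outputs": []
    },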
    {
     "cell_type": "markdown",
     "metadata": {},
     "source": [
      "The only really complicated thing here is computing the oracle. What we do is, for each possible alignment, compute an intersection-over-union error rate. The oracle is then that alignment that achieves the smallest (local) error rate. This is not perfect, but is good enough. One interesting thing here is that now the `oracle` could be a *list*; this is completely supported by the underlying algorithms.\n",
      "\n",
      "We can train and test this model to make sure it does the right thing:"
     ]
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "vw = pyvw.vw(\"--search 0 --csoaa_ldf m --search_task hook --ring_size 1024 --quiet -q ef -q ep\")\n",
      "task = vw.init_search_task(WordAligner)\n",
      "for p in range(10):\n",
      "    task.learn(my_dataset.__iter__)\n",
      "print task.predict( (\"the blue flower\".split(), ([],[],[]), \"la fleur bleue\".split()) )"
     ],
     "language": "python",
     "metadata": {},
     "outputs": [
      {
       "output_type": "stream",
       "stream": "stdout",
       "text": [
        "[[0], [2], [1]]\n"
       ]
      }
     ],
     "prompt_number": 54
    },
    {
     "cell_type": "markdown",
     "metadata": {},
     "source": [
      "If this worked correctly, it should have printed `[[0], [2], [1]]`."
     ]
    }
   ],
   "metadata": {}
  }
 ]
}