#!/usr/bin/perl -w
# vim: ts=4 sw=4 nosmarttab
#
# Goal: improve learning (smaller generalization error)
#
# How: Inspect the learning process in real time. Single out
#      the examples leading to its biggest jumps in loss
#      (deviations from convergence) and drill down to uncover
#      the features that underlie these biggest deviations.
#
# The process has two steps:
#   1) Train with --progress 1 to capture the examples that cause
#      the biggest divergences (increase in progressive loss) in
#      the learning process.
#   2) Audit the examples from step 1) and drill down to their
#      top weighted features (the features that presumably are
#      throwing the learning process off course, i.e. cause
#      'un-learning').
#
# First step:
#   input:  VW training data and args (to generate stderr with -P 1)
#   output: N biggest error pairs (example_number delta_error_since_last)
#
# Second step:
#   input:  The same training set and args as the first step,
#           plus the output from the 1st step (N biggest errors)
#   output: A dump of the biggest-error examples, each with its
#           top-N highest-weight features.
#
# Caveats:
#   Running vw-top-errors only points out the top suspects.
#   Top suspects are not necessarily "guilty".
#   For example: if the constant feature is large and has a large
#   relative weight, it may be singled out as a strong feature in
#   examples that are deviant.  It is entirely possible that it is
#   also a strong feature in other/most examples.
#
#   So use and enjoy this utility, but be aware of its limitations.
#
# -- Copyright (c) ariel faigon, 2014
#    This software is released under the same license as vowpal-wabbit
#    itself.  See the file LICENSE.
#
use Getopt::Std;
use vars qw($opt_v $opt_V $opt_s $opt_a);

my %ExampleDelta = ();
my $TopN;
my $TopWeights;

my $DefaultTopN = 10;
my $DefaultTopWeights = 5;

my $VWCmd;

# v() - print to STDERR only when -v (verbose) is in effect
sub v {
    return unless $opt_v;
    if (@_ == 1) {
        print STDERR @_;
    } else {
        printf STDERR @_;
    }
}

# V() - print to STDERR only when -V (more verbose) is in effect
sub V {
    return unless $opt_V;
    if (@_ == 1) {
        print STDERR @_;
    } else {
        printf STDERR @_;
    }
}

sub usage(@) {
    print STDERR @_, "\n" if @_;
    die "Usage: $0 [options] [args] vw ...
    Will run 2 passes:
        1) Learning with --progress 1: catch the biggest errors
        2) Learning with --audit: intercept the biggest errors
           and print their highest weights.

    $0 options (must precede the 'vw' and its args):
        -s  Use sum-of-weights instead of mean-of-weights
            in 'top adverse features' summary
        -a  Use absolute rather than (the default) relative error
            (since-last / mean loss) for top N deviant examples
        -v  verbose
        -V  more verbose

    $0 args (must precede the 'vw' and its args):
        Integer:  how many (top error) examples to print
                  (Default: $DefaultTopN)
        Integer:  how many top weights to print for each
                  (top error) example (Default: $DefaultTopWeights)
";
}
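# An illustrative invocation (the training file and vw args here are
# hypothetical; substitute your own):
#
#   vw-top-errors 20 3 vw --loss_function logistic train.vw
#
# This would print the 20 examples with the biggest jumps in relative
# progressive loss, then re-run training with --audit and print the
# 3 highest |weight| features in each of those examples.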
sub get_args {
    $0 =~ s{.*/}{};

    # Separate our own args from vw and its args...
    my (@our_args, @vw_args);
    my $saw_vw = 0;
    foreach my $arg (@ARGV) {
        if ($saw_vw) {
            push(@vw_args, $arg);
            next;
        } elsif ($arg =~ /^\S*vw$/) {
            $saw_vw = 1;
            push(@vw_args, $arg);
            next;
        }
        if ($arg =~ /^\d+$/) {
            # 1st integer arg is TopN, 2nd is TopWeights
            if (! defined $TopN) {
                $TopN = $arg;
            } elsif (! defined $TopWeights) {
                $TopWeights = $arg;
            } else {
                print STDERR "Too many integer args: ignoring $arg\n";
            }
        }
        push(@our_args, $arg);
    }
    $TopN = $DefaultTopN unless (defined $TopN);
    $TopWeights = $DefaultTopWeights unless (defined $TopWeights);

    unless (@vw_args >= 2) {
        usage("Expecting a 'vw' argument followed by training-args");
    }
    $VWCmd = "@vw_args";

    @ARGV = @our_args;
    getopts('vVsa');

    # Make wide-char output safe from warnings
    binmode STDOUT, ":utf8";
}

#
# collect_errors($vw_stderr)
#   Reads vw progress output and collects delta-loss magnitudes
#   in the hash %ExampleDelta.
#
#   We're looking for positive deltas, i.e. examples where the
#   loss went up instead of down.
#
sub collect_errors($) {
    my $vw_stderr = shift;

    my $good_lines = 0;
    my ($prev_avgloss, $prev_sincelast);
    my $stderr = '';

    while (<$vw_stderr>) {
        unless (/^([0-9.]+)\s+([0-9.]+)\s+([0-9.]+)\s/) {
            # Not a progress line; keep it, it may be needed
            # if 'vw' crashes etc.
            $stderr .= "\t$_";
            next;
        }
        $good_lines++;
        my ($avgloss, $sincelast, $example_no) = ($1, $2, $3);
        if (defined $prev_sincelast) {
            my $delta_since_last = $sincelast - $prev_sincelast;
            if ($opt_a) {
                # absolute error (rare need)
                $ExampleDelta{$example_no} = $delta_since_last;
            } else {
                # relative error (default)
                my $relative_error = $delta_since_last / $avgloss;
                $ExampleDelta{$example_no} = $relative_error;
            }
        }
        ($prev_avgloss, $prev_sincelast) = ($avgloss, $sincelast);
    }
    wait;
    if ($?) {
        printf STDERR "%s: vw exited w/ status %d\nvw stderr:\n%s\n",
                        $0, $? >> 8, $stderr;
        exit 1;
    }
    unless ($good_lines > 0) {
        printf STDERR "%s: no progress lines from vw training\n%s\n",
                        $0, $stderr;
        exit 1;
    }
}
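# For reference, the vw --progress 1 stderr lines parsed by
# collect_errors() above look roughly like this (illustrative
# numbers; column layout per vw's progress banner):
#
#   average    since       example  example    current  current  current
#   loss       last        counter  weight     label    predict  features
#   0.693147   0.693147          1      1.0    -1.0000   0.0000        4
#   0.719039   0.744931          2      2.0     1.0000  -0.1252        4
#
# Only the first three columns matter here: the change in the
# 'since last' loss between successive lines is what goes into
# %ExampleDelta, keyed by the example counter.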
#
# by_abs_weight_desc($a, $b)
#   compare function to sort features by abs(weight), largest first.
#   Each feature is an audit token whose last ':'-separated field
#   is the weight, optionally followed by a '@...' suffix which
#   is stripped before comparing.
#
sub by_abs_weight_desc {
    my $weight1 = (split(':', $a))[-1];
    my $weight2 = (split(':', $b))[-1];
    $weight1 =~ s/\@.*$//;
    $weight2 =~ s/\@.*$//;
    abs($weight2) <=> abs($weight1);
}

#
# audit_top_weights($vwh, $top_weights, @example_nos)
#   $vwh           file handle of vw audit stream
#   $top_weights   number of weights to print per example
#   @example_nos   example numbers to cover
#
sub audit_top_weights($$@) {
    my ($vw_stderr, $top_weights, @example_nos) = @_;
    unless (@example_nos) {
        die "$0: audit_top_weights: no wanted examples given\n";
    }
    # Make sure we have example numbers in order from low to high
    @example_nos = sort { $a <=> $b } @example_nos;

    my %feature_weights = ();
    my %feature_examples = ();

    printf "=== Top-weighting underlying %d features" .
            " @ diverging examples:\n", $top_weights;

    # -- Scan the vw audit output
    my $example_count = 0;
    while (<$vw_stderr>) {
        last unless (@example_nos);

        # No need to skip regular progress lines because
        # we run with --quiet
        # next if (/^[0-9.]+\s+[0-9.]+\s+\d+\s+[0-9]+\.[0-9]\s+/);

        # Look for an audit-preceding line (starts with a number:
        # the example's prediction/label header)
        next unless (/^-?[0-9.]+\s/);

        # Found an example:
        $example_count++;

        my $wanted_example = $example_nos[0];
        if ($example_count == $wanted_example) {
            # Found the next wanted one
            shift @example_nos;

            # -- print the matching label/tag etc. header
            printf "Example #%d: %s", $wanted_example, $_;

            # Next line has the features and weights
            my $audit_line = <$vw_stderr>;
            chomp $audit_line;
            $audit_line =~ s/^\s+//;
            my @features = split(/\s+/, $audit_line);
            my @sorted_features = sort by_abs_weight_desc @features;

            printf "\t%s\t%-8s\t%s\n",
                    'Feature', 'Weight', 'Full-audit-data';
            for (my $i = 0; $i < $TopWeights; $i++) {
                my $feature = $sorted_features[$i];
                next unless (defined $feature);
                my ($name, $hash, $value, $weight) = split(':', $feature);
                unless (defined $weight) {
                    # In multiple passes (using a cache) audit data
                    # doesn't have the feature name, only its hashed
                    # value, so all fields get shifted by one...
                    printf STDERR
                        "Undefined weight in audit data: feature=%s\n",
                        $feature;
                    $weight = $value;
                    $value = $hash;
                    $hash = $name;
                    $name = "[$hash]";
                }
                $weight =~ s/\@.*$//;
                printf "\t%s\t%7.6f\t%s\n", $name, $weight, $feature;

                $feature_weights{$name} += abs($weight);
                unless (exists $feature_examples{$name}) {
                    $feature_examples{$name} = {};
                }
                $feature_examples{$name}->{$wanted_example} = $weight;
            }
        }
    }
    print "\n";

    printf "=== Top adverse features (in descending %s weight)\n",
            ($opt_s ? 'cumulative' : 'mean');
    printf "%-20s\t%-10s %s\n",
            'Name', ($opt_s ? 'Sum-Weight' : 'Mean-Weight'), 'Examples';

    unless ($opt_s) {
        # Switch from sum-of-weights to mean-of-weights
        foreach my $name (keys %feature_weights) {
            my @examples = keys %{$feature_examples{$name}};
            my $example_count = scalar @examples;
            $feature_weights{$name} /= $example_count;
        }
    }
    foreach my $name (sort { $feature_weights{$b} <=> $feature_weights{$a} }
                            keys %feature_weights) {
        my @example_list =
            sort { $feature_examples{$name}->{$b} <=>
                   $feature_examples{$name}->{$a} }
                 keys %{$feature_examples{$name}};
        printf "%-20s\t%-10.6f @example_list\n",
                $name, $feature_weights{$name};
    }
}

#
# by_delta
#   sort delta loss numerically, descending
#
sub by_delta {
    $ExampleDelta{$b} <=> $ExampleDelta{$a};
}

#
# biggest_error_examples(N)
#   list of the top N delta-loss example numbers
#
sub biggest_error_examples($) {
    my $howmany = shift;
    my @sorted_examples = sort by_delta keys %ExampleDelta;
    # Don't slice past the end when there are fewer examples than N
    $howmany = @sorted_examples if ($howmany > @sorted_examples);
    @sorted_examples[0 .. $howmany-1];
}

#
# print_errors(@examples)
#   Print the top N loss example numbers and their absolute or
#   relative loss.
#
sub print_errors(@) {
    printf "=== Top-%d (highest delta loss) diverging examples:\n",
            scalar(@_);
    printf "Example\t%s-Loss\n", ($opt_a ? 'Absolute' : 'Relative');
    foreach my $example (@_) {
        printf "%d\t%g\n", $example, $ExampleDelta{$example};
    }
    print "\n";
}

#
# first_pass(N)
#   First training pass to capture progressive loss numbers.
#
sub first_pass($) {
    my $top_n = shift;

    # Force --progress 1, stripping any user-supplied
    # --progress/--quiet flags that would interfere with parsing
    my $vw_cmd = $VWCmd;
    $vw_cmd =~ s/(?:--progress|-P)\s*[0-9.]+\b//;
    $vw_cmd =~ s/(?:--quiet\b)//;
    $vw_cmd .= " --progress 1";

    v("+----- 1st pass: find top %d errors\n", $top_n);
    v("| 1st pass: \$vw_cmd: %s\n", $vw_cmd);

    open(my $vwh, "$vw_cmd 2>&1 |")
        or die "$0: can't run 1st pass: $vw_cmd: $!\n";
    collect_errors($vwh);
    close $vwh;

    my @top_error_examples = biggest_error_examples($top_n);
    print_errors(@top_error_examples);

    v("+--- 1st pass: done!\n\n");

    @top_error_examples;
}
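# Both passes rewrite the user's vw command line before running it.
# For instance, a (hypothetical) "vw -P 100 --quiet train.vw" becomes
#
#   1st pass:   vw train.vw --progress 1
#   2nd pass:   vw train.vw --quiet --audit
#
# (modulo leftover whitespace), so the audit stream read below
# contains no progress lines, only per-example audit records.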
#
# second_pass($top_weights, @example_nos)
#   2nd training pass with --audit to capture the individual
#   features causing the biggest loss jumps.
#
sub second_pass($@) {
    my ($top_weights, @example_nos) = @_;

    v("+----- 2nd pass: --audit to find top bad features\n");
    v("| 2nd pass: \$top_weights: %d\n", $top_weights);
    v("| 2nd pass: \@example_nos: @example_nos\n");

    # Strip conflicting user flags, then force --quiet --audit
    my $vw_cmd = $VWCmd;
    $vw_cmd =~ s/(?:--progress|-P)\s*[0-9.]+\b//;
    $vw_cmd =~ s/(?:--quiet\b)//;
    $vw_cmd =~ s/(?:--audit|-a\b)//;
    $vw_cmd .= " --quiet --audit";

    v("| 2nd pass: \$vw_cmd: %s\n", $vw_cmd);

    open(my $vwh, "$vw_cmd 2>&1 |")
        or die "$0: can't run 2nd pass: $vw_cmd: $!\n";
    audit_top_weights($vwh, $top_weights, @example_nos);
    close $vwh;

    v("+--- 2nd pass: done!\n");
}

# -- main
get_args;
my @top_errors = first_pass($TopN);
second_pass($TopWeights, @top_errors);
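# Illustrative output shape (all numbers and feature names below are
# made up; actual values depend on your data and vw version):
#
#   === Top-2 (highest delta loss) diverging examples:
#   Example Relative-Loss
#   137     1.88235
#   42      0.95122
#
#   === Top-weighting underlying 3 features @ diverging examples:
#   Example #42: 1.000000 ...
#       Feature Weight      Full-audit-data
#       height  -0.612300   height:12024:1:-0.6123
#       ...
#
#   === Top adverse features (in descending mean weight)
#   Name                    Mean-Weight Examples
#   height                  0.612300    42 137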