diff options
author | Sam Steingold <sds@gnu.org> | 2014-10-01 00:19:08 +0400 |
---|---|---|
committer | Sam Steingold <sds@gnu.org> | 2014-10-01 00:19:51 +0400 |
commit | f1b7fc29700ce04aa21e3e1a65b279a2303efe80 (patch) | |
tree | 02b83ac61cb77003105f7a1df05e046e77c6ed70 /demo | |
parent | a812761767db06710d5fd51b3489714ced4103ca (diff) |
add demo/ocr (http://ai.stanford.edu/~btaskar/ocr/)
Diffstat (limited to 'demo')
-rw-r--r-- | demo/ocr/Makefile | 58 | ||||
-rw-r--r-- | demo/ocr/ocr2vw.py | 62 |
2 files changed, 120 insertions, 0 deletions
diff --git a/demo/ocr/Makefile b/demo/ocr/Makefile new file mode 100644 index 00000000..87a07e6f --- /dev/null +++ b/demo/ocr/Makefile @@ -0,0 +1,58 @@ +# Handwritten words dataset collected by +# Rob Kassel at MIT Spoken Language Systems Group +# http://ai.stanford.edu/~btaskar/ocr/ + +VW = ../../vowpalwabbit/vw +VW_OPTS = -b 24 -l 0.02 --nn 40 + +RM = rm -f + +help: + @echo handwritten words dataset collected by + @echo Rob Kassel at MIT Spoken Language Systems Group + @echo http://ai.stanford.edu/~btaskar/ocr/ + @echo $$ make letter.confusion + +letter.data.gz: + wget http://ai.stanford.edu/~btaskar/ocr/letter.data.gz + +letter.names: + wget http://ai.stanford.edu/~btaskar/ocr/letter.names + +letter.vw: ocr2vw.py letter.data.gz letter.names + python $^ $@ $@.test + cut -d' ' -f1 $@ | sort | uniq -c | sort -n + +# category count +CATN = 26 + +letter.model: letter.vw + time $(VW) --oaa $(CATN) --final_regressor $@ \ + --adaptive --invariant --holdout_off \ + --loss_function logistic --passes 100 \ + $(VW_OPTS) --data $< -k --cache_file $<.cache + $(RM) $<.cache + +letter.predictions: letter.model + time $(VW) --testonly --initial_regressor $< --predictions $@ \ + --data letter.vw.test + +# taken almost verbatim from ../mnist/Makefile +CONFUSION='++$$n; $$p=int($$F[0]); $$l=ord($$F[1])-ord("a")+1; \ + ++$$c if $$p != $$l; \ + ++$$m{"$$l:$$p"}; } { \ + print "$* test errors: $$c out of $$n = " . \ + sprintf("%.2f%%",100*$$c/$$n) . \ + "\nconfusion matrix (rows = truth, columns = prediction):"; \ + foreach $$true (1 .. $(CATN)) { \ + print join "\t", map { $$m{"$$true:$$_"} || 0 } (1 .. 
"""Convert the OCR letter dataset (letter.data) to Vowpal Wabbit format.

Every tenth input line is written to the test file; every other line is
written to the training file.  Requires Python 3.
"""

import argparse
import gzip


def read_letter_names(fn):
    """Return the field names stored in *fn*, one name per line.

    Trailing whitespace is stripped from every line.  A one-line summary
    is printed to stdout.
    """
    with open(fn, encoding="utf-8") as ins:
        ret = [line.rstrip() for line in ins]
    print("Read %d names from %s" % (len(ret), fn))
    return ret


def find_pixel_start(names):
    """Return the index of the first name that starts with ``"p_"``.

    Raises:
        ValueError: if no name in *names* starts with ``"p_"``.
    """
    for i, name in enumerate(names):
        if name.startswith("p_"):
            return i
    raise ValueError("No pixel data", names)


def data2vw(ifn, train, test, names):
    """Split the tab-separated file *ifn* into VW train and test files.

    Args:
        ifn: input path; read through gzip when it ends with ".gz".
        train: output path for the training examples.
        test: output path for the test examples (every 10th input line).
        names: field names, one per input column; must contain "id",
            "letter" and at least one name starting with "p_".

    Each output line has the form
    ``<label> 1 <letter>-<id>|Pixel <name>:<value> ...`` where the label
    is 1 for "a" through 26 for "z" and only the fields from the first
    "p_" field onward whose value is not "0" are listed.

    Raises:
        ValueError: if "id"/"letter"/pixel fields are missing from
            *names*, a line has the wrong number of fields, or the letter
            field is not a single character in "a".."z".
    """
    # Resolve the column layout before touching any file, so a bad names
    # list does not leave empty output files behind.
    id_pos = names.index("id")
    letter_pos = names.index("letter")
    pixel_start = find_pixel_start(names)
    pixel_names = names[pixel_start:]
    iopener = gzip.open if ifn.endswith(".gz") else open
    lineno = 0
    trainN = 0
    testN = 0
    # gzip.open defaults to binary mode, so text mode is requested
    # explicitly; newline="\n" keeps the output line endings independent
    # of the platform.
    with iopener(ifn, "rt", encoding="utf-8") as ins, \
            open(train, "w", encoding="utf-8", newline="\n") as trainS, \
            open(test, "w", encoding="utf-8", newline="\n") as testS:
        for lineno, line in enumerate(ins, 1):
            # rstrip() removes the newline and any other trailing
            # whitespace (including trailing tabs) before splitting.
            vals = line.rstrip().split('\t')
            if len(vals) != len(names):
                raise ValueError("Bad field count",
                                 len(vals), len(names), vals, names)
            char = vals[letter_pos]
            # The label formula below only yields 1..26 for "a".."z".
            if len(char) != 1 or not 'a' <= char <= 'z':
                raise ValueError("Bad letter", char)
            if lineno % 10 == 0:
                testN += 1
                outs = testS
            else:
                trainN += 1
                outs = trainS
            parts = ["%d 1 %s-%s|Pixel" % (ord(char) - ord('a') + 1,
                                           char, vals[id_pos])]
            parts.extend(' %s:%s' % (name, val)
                         for name, val in zip(pixel_names, vals[pixel_start:])
                         if val != '0')
            parts.append('\n')
            outs.write(''.join(parts))
    print("Read %d lines from %s; wrote %d lines into %s and %d lines into %s" % (
        lineno, ifn, trainN, train, testN, test))


def main(argv=None):
    """Parse the command line (*argv*, default ``sys.argv[1:]``) and convert."""
    parser = argparse.ArgumentParser(description='Convert letter.data to VW format')
    parser.add_argument('input', help='path to letter.data[.gz]')
    parser.add_argument('names', help='path to letter.names')
    # argparse %-formats help strings, so a literal percent sign is "%%".
    parser.add_argument('train', help='VW train file location (90%%)')
    parser.add_argument('test', help='VW test file location (10%%)')
    args = parser.parse_args(argv)
    data2vw(args.input, args.train, args.test, read_letter_names(args.names))


if __name__ == '__main__':
    main()