Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/moses-smt/vowpal_wabbit.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/demo
diff options
context:
space:
mode:
authorSam Steingold <sds@gnu.org>2014-10-01 00:19:08 +0400
committerSam Steingold <sds@gnu.org>2014-10-01 00:19:51 +0400
commitf1b7fc29700ce04aa21e3e1a65b279a2303efe80 (patch)
tree02b83ac61cb77003105f7a1df05e046e77c6ed70 /demo
parenta812761767db06710d5fd51b3489714ced4103ca (diff)
add demo/ocr (http://ai.stanford.edu/~btaskar/ocr/)
Diffstat (limited to 'demo')
-rw-r--r--demo/ocr/Makefile58
-rw-r--r--demo/ocr/ocr2vw.py62
2 files changed, 120 insertions, 0 deletions
diff --git a/demo/ocr/Makefile b/demo/ocr/Makefile
new file mode 100644
index 00000000..87a07e6f
--- /dev/null
+++ b/demo/ocr/Makefile
@@ -0,0 +1,58 @@
+# Handwritten words dataset collected by
+# Rob Kassel at MIT Spoken Language Systems Group
+# http://ai.stanford.edu/~btaskar/ocr/
+
+# Locally-built vw binary (relative to demo/ocr/) and training
+# hyper-parameters: 24-bit feature table, learning rate 0.02, 40 hidden units.
+VW = ../../vowpalwabbit/vw
+VW_OPTS = -b 24 -l 0.02 --nn 40
+
+RM = rm -f
+
+# Default goal: describe the demo and how to run it end to end.
+help:
+ @echo handwritten words dataset collected by
+ @echo Rob Kassel at MIT Spoken Language Systems Group
+ @echo http://ai.stanford.edu/~btaskar/ocr/
+ @echo $$ make letter.confusion
+
+# Fetch the raw dataset (gzipped, tab-separated records).
+letter.data.gz:
+ wget http://ai.stanford.edu/~btaskar/ocr/letter.data.gz
+
+# Fetch the column-name list that describes letter.data fields.
+letter.names:
+ wget http://ai.stanford.edu/~btaskar/ocr/letter.names
+
+# Convert to VW format; ocr2vw.py writes $@ (train split) and $@.test.
+# The cut|sort|uniq pipeline prints the per-label record counts.
+# NOTE(review): this recipe creates letter.vw.test in addition to its
+# target, so that file has no rule of its own.
+letter.vw: ocr2vw.py letter.data.gz letter.names
+ python $^ $@ $@.test
+ cut -d' ' -f1 $@ | sort | uniq -c | sort -n
+
+# category count (one class per lowercase letter a-z)
+CATN = 26
+
+# Train a 26-way one-against-all neural-net model with logistic loss.
+letter.model: letter.vw
+ time $(VW) --oaa $(CATN) --final_regressor $@ \
+ --adaptive --invariant --holdout_off \
+ --loss_function logistic --passes 100 \
+ $(VW_OPTS) --data $< -k --cache_file $<.cache
+ $(RM) $<.cache
+
+# Score the held-out test split with the trained model.
+letter.predictions: letter.model
+ time $(VW) --testonly --initial_regressor $< --predictions $@ \
+ --data letter.vw.test
+
+# taken almost verbatim from ../mnist/Makefile
+# Perl body injected into "perl -lane"'s implicit per-line loop: it
+# tallies the error count and a CATN x CATN confusion hash, then the
+# unbalanced "} {" closes that loop and opens a block that runs once
+# at the end to print the error rate and the matrix.  $$ escapes are
+# for Make; $* is the pattern stem of the %.confusion rule below.
+CONFUSION='++$$n; $$p=int($$F[0]); $$l=ord($$F[1])-ord("a")+1; \
+ ++$$c if $$p != $$l; \
+ ++$$m{"$$l:$$p"}; } { \
+ print "$* test errors: $$c out of $$n = " . \
+ sprintf("%.2f%%",100*$$c/$$n) . \
+ "\nconfusion matrix (rows = truth, columns = prediction):"; \
+ foreach $$true (1 .. $(CATN)) { \
+ print join "\t", map { $$m{"$$true:$$_"} || 0 } (1 .. $(CATN)); \
+ }'
+
+# Compare predictions against the truth encoded in the tag column;
+# leaves the report in $@ and echoes it.
+%.confusion: %.predictions
+ @perl -lane $(CONFUSION) $< > $@
+ @cat $@
+
+clean:
+ $(RM) letter.*
+
+# NOTE(review): "help" is a command, not a file -- consider listing it
+# in .PHONY as well.
+.PHONY: clean
diff --git a/demo/ocr/ocr2vw.py b/demo/ocr/ocr2vw.py
new file mode 100644
index 00000000..ce70ad75
--- /dev/null
+++ b/demo/ocr/ocr2vw.py
@@ -0,0 +1,62 @@
+# convert letter.data to letter.vw
+
+# Read one column name per line from letter.names; returns the list of
+# names (trailing whitespace stripped) and logs how many were read.
+def read_letter_names (fn):
+ ret = list()
+ with open(fn) as ins:
+ for line in ins:
+ ret.append(line.rstrip())
+ print "Read %d names from %s" % (len(ret),fn)
+ return ret
+
+# Return the index of the first pixel column (name starting with "p_").
+# Raises ValueError if the name list contains no pixel columns.
+def find_pixel_start (names):
+ for i in range(len(names)):
+ if names[i].startswith("p_"):
+ return i
+ raise ValueError("No pixel data",names)
+
+# Convert the tab-separated letter.data file (possibly gzipped) into VW
+# multiclass format, writing a deterministic 90/10 train/test split to
+# the `train` and `test` paths.  `names` gives the column layout (from
+# read_letter_names) and must match each record's field count exactly.
+def data2vw (ifn, train, test, names):
+ lineno = 0
+ trainN = 0
+ testN = 0
+ if ifn.endswith(".gz"):
+ import gzip
+ iopener = gzip.open
+ else:
+ iopener = open
+ id_pos = names.index("id")
+ letter_pos = names.index("letter")
+ pixel_start = find_pixel_start(names)
+ with iopener(ifn) as ins, open(train,"wb") as trainS, open(test,"wb") as testS:
+ for line in ins:
+ lineno += 1
+ vals = line.rstrip().split('\t')
+ if len(vals) != len(names):
+ raise ValueError("Bad field count",
+ len(vals),len(names),vals,names)
+ char = vals[letter_pos]
+ if len(char) != 1:
+ raise ValueError("Bad letter",char)
+ # every 10th record goes to the test split, the rest to train
+ if lineno % 10 == 0:
+ testN += 1
+ outs = testS
+ else:
+ trainN += 1
+ outs = trainS
+ # VW line: label 1..26 (a=1), importance 1, tag "<letter>-<id>",
+ # then the "Pixel" namespace
+ outs.write("%d 1 %s-%s|Pixel" % (ord(char)-ord('a')+1,char,vals[id_pos]))
+ # emit only non-zero pixels as sparse name:value features
+ for i in range(pixel_start,len(names)):
+ if vals[i] != '0':
+ outs.write(' %s:%s' % (names[i],vals[i]))
+ outs.write('\n')
+ print "Read %d lines from %s; wrote %d lines into %s and %d lines into %s" % (
+ lineno,ifn,trainN,train,testN,test)
+
+
+if __name__ == '__main__':
+ import argparse
+ parser = argparse.ArgumentParser(description='Convert letters.data to VW format')
+ parser.add_argument('input',help='path to letter.data[.gz]')
+ parser.add_argument('names',help='path to letter.names')
+ parser.add_argument('train',help='VW train file location (90%)')
+ parser.add_argument('test',help='VW test file location (10%)')
+ args = parser.parse_args()
+ data2vw(args.input,args.train,args.test,read_letter_names(args.names))