diff options
author | Sergey Yershov <yershov@corp.mail.ru> | 2016-05-31 20:27:25 +0300 |
---|---|---|
committer | Sergey Yershov <yershov@corp.mail.ru> | 2016-05-31 20:27:25 +0300 |
commit | 669fcc90b3041141ccd8ab9f3fe76f960c33d6c8 (patch) | |
tree | 8d8304c7c685e4d3376ff2a574abcc5c909bf753 /tools | |
parent | 3e212a01d1ffe635670ea746b3b4bd725a63fd9c (diff) |
Review fixes
Diffstat (limited to 'tools')
-rwxr-xr-x | tools/python/booking_hotels_quality.py | 22 |
1 file changed, 18 insertions, 4 deletions
diff --git a/tools/python/booking_hotels_quality.py b/tools/python/booking_hotels_quality.py index 00d9b29138..d1a58f8a7d 100755 --- a/tools/python/booking_hotels_quality.py +++ b/tools/python/booking_hotels_quality.py @@ -9,6 +9,7 @@ import argparse import base64 import json import logging +import matplotlib.pyplot as plt import os import pickle import time @@ -17,7 +18,11 @@ import urllib2 # init logging logging.basicConfig(level=logging.DEBUG, format='[%(asctime)s] %(levelname)s: %(message)s') + def load_binary_list(path): + """ + Loads binary classifier output + """ bits = [] with open(path, 'r') as fd: for line in fd: @@ -26,7 +31,11 @@ def load_binary_list(path): bits.append(1 if line[0] == 'y' else 0) return bits + def load_score_list(path): + """ + Loads list of scores + """ scores = [] with open(path, 'r') as fd: for line in fd: @@ -35,15 +44,17 @@ def load_score_list(path): scores.append(float(line[line.rfind(':')+2:])) return scores + def process_options(): - parser = argparse.ArgumentParser(description='Download and process booking hotels.') + parser = argparse.ArgumentParser(description="Download and process booking hotels.") parser.add_argument("-v", "--verbose", action="store_true", dest="verbose") parser.add_argument("-q", "--quiet", action="store_false", dest="verbose") parser.add_argument("--reference_list", dest="reference_list", help="Path to data files") parser.add_argument("--sample_list", dest="sample_list", help="Name and destination for output file") - parser.add_argument("--show", dest="show", default=False, action="store_true", help="Show graph for precision and recall") + parser.add_argument("--show", dest="show", default=False, action="store_true", + help="Show graph for precision and recall") options = parser.parse_args() @@ -53,6 +64,7 @@ def process_options(): return options + def main(): options = process_options() reference = load_binary_list(options.reference_list) @@ -60,12 +72,14 @@ def main(): precision, recall, threshold = 
metrics.precision_recall_curve(reference, sample) aa = zip(precision, recall, threshold) - print("Optimal thrashold: {2} for precision: {0} and recall: {1}".format(*max(aa, key=lambda (p, r, t): p*r/(p+r)))) + print("Optimal threshold: {2} for precision: {0} and recall: {1}".format(*max(aa, key=lambda (p, r, t): p*r/(p+r)))) print("AUC: {0}".format(metrics.roc_auc_score(reference, sample))) if options.show: - import matplotlib.pyplot as plt plt.plot(recall, precision) + plt.title("Precision/Recall") + plt.ylabel("Precision") + plt.xlabel("Recall") plt.show() |