From 69e34a3203fce4d7bd147f57cb1d75a46dc70682 Mon Sep 17 00:00:00 2001 From: Tomasz Dwojak Date: Thu, 29 Sep 2016 11:53:00 +0100 Subject: Add missed during merge python script: download_models.py --- scripts/download_models.py | 123 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 123 insertions(+) create mode 100755 scripts/download_models.py (limited to 'scripts') diff --git a/scripts/download_models.py b/scripts/download_models.py new file mode 100755 index 00000000..39b3a991 --- /dev/null +++ b/scripts/download_models.py @@ -0,0 +1,123 @@ +#!/usr/bin/env python2 +# -*- coding: utf-8 -*- + +import argparse +import sys +import os +import requests +from clint.textui import progress + +BASE_URL = "http://data.statmt.org/rsennrich/wmt16_systems/{}-{}/{}" + +CONFIG_TEMPLATE = """ +# Paths are relative to config file location +relative-paths: yes + +# performance settings +beam-size: 12 +devices: [0] +normalize: yes +threads-per-device: 1 + +# scorer configuration +scorers: + F0: + path: ./model.npz + type: Nematus + +# scorer weights +weights: + F0: 1.0 + +bpe: ./{}{}.bpe + +# vocabularies +source-vocab: ./vocab.{}.json +target-vocab: ./vocab.{}.json +""" + +def download_with_progress(path, url): + r = requests.get(url, stream=True) + with open(path, 'wb') as f: + total_length = int(r.headers.get('content-length')) + for chunk in progress.bar(r.iter_content(chunk_size=(1024 ** 2)), + expected_size=(total_length/(1024 ** 2)) + 1): + if chunk: + f.write(chunk) + f.flush() + + +def parse_args(): + """ parse command arguments """ + parser = argparse.ArgumentParser() + parser.add_argument("-w", dest="workdir", default='.') + parser.add_argument('-m', dest="model", default='en-de') + parser.add_argument('-f', dest="force", default=False) + return parser.parse_args() + + +def make_workdir(path): + """ Create a directory. """ + workdir = os.path.abspath(path) + + try: + os.makedirs(workdir) + except OSError: + pass + + +def download_model(model, workdir, force=False): + """ download Rico Sennrich's WMT16 model: to . """ + make_workdir(workdir) + download_model_parts(model, workdir, force) + create_base_config(model, workdir) + + +def download_model_parts(model, workdir, force=False): + src = model.split('-')[0] + trg = model.split('-')[1] + + model_parts = ["model.npz", + "vocab.{}.json".format(src), + "vocab.{}.json".format(trg), + "{}{}.bpe".format(src, trg), + "truecase-model.{}".format(src)] + + for part in model_parts: + download_file(src, trg, part, workdir, force) + + +def download_file(src, trg, name, workdir, force=False): + path = os.path.join(workdir, name) + if not os.path.exists(path): + full_url = BASE_URL.format(src, trg, name) + print >> sys.stderr, "Downloading: {} to {}".format(full_url, path) + download_with_progress(path, full_url) + elif force: + full_url = BASE_URL.format(src, trg, name) + print >> sys.stderr, "Force downloading: {}".format(full_url) + download_with_progress(path, full_url) + else: + print >> sys.stderr, "File {} exists. Skipped".format(path) + + +def create_base_config(model, model_dir): + src = model.split('-')[0] + trg = model.split('-')[1] + config = CONFIG_TEMPLATE.format(src, trg, src, trg) + + with open("{}/config.yml".format(model_dir), 'w') as config_file: + config_file.write(config) + + +def main(): + """ main """ + args = parse_args() + + print >> sys.stderr, "Downloading {} to {}".format(args.model, + args.workdir) + download_model(args.model, args.workdir, args.force) + + +if __name__ == "__main__": + main() -- cgit v1.2.3