diff options
| | | |
|---|---|---|
| author | ZJaume <jzaragoza@prompsit.com> | 2022-07-22 14:30:50 +0300 |
| committer | Jaume Zaragoza <ZJaume@users.noreply.github.com> | 2022-07-27 15:20:55 +0300 |
| commit | 55fe50d9223a0ab6975b8a38b6018a1f8d6bafb1 (patch) | |
| tree | 1d6aed8f8e49a7c4b5b5efc331f720d1e14acc63 | |
| parent | 6383cddf2d4662e7e1c6c15de67daa002665a95a (diff) | |
Download models from HFHub if possible
-rwxr-xr-x | bicleaner_ai/bicleaner_ai_classifier.py | 22 |
-rw-r--r-- | bicleaner_ai/classify.py | 14 |
2 files changed, 33 insertions, 3 deletions
```diff
diff --git a/bicleaner_ai/bicleaner_ai_classifier.py b/bicleaner_ai/bicleaner_ai_classifier.py
index 0d5f6c8..e13e9c9 100755
--- a/bicleaner_ai/bicleaner_ai_classifier.py
+++ b/bicleaner_ai/bicleaner_ai_classifier.py
@@ -49,6 +49,28 @@ def initialization():
     else:
         args.processes = max(1, cpu_count()-1)
 
+    # Try to download the model if not a valid path
+    if not args.offline or is_dir:
+        from huggingface_hub import snapshot_download, model_info
+        from huggingface_hub.utils import RepositoryNotFoundError
+        from requests.exceptions import HTTPError
+        try:
+            # Check if it exists at the HF Hub
+            model_info(args.model, token=args.auth_token)
+        except RepositoryNotFoundError:
+            logging.debug(
+                    f"Model {args.model} not found at HF Hub, trying local storage")
+            args.metadata = args.model + '/metadata.yaml'
+        else:
+            logging.info(f"Downloading the model {args.model}")
+            # Download all the model files from the hub
+            cache_path = snapshot_download(args.model,
+                                           use_auth_token=args.auth_token)
+            # Set metadata path to the cache location of the model
+            args.metadata = cache_path + '/metadata.yaml'
+    else:
+        args.metadata = args.model + '/metadata.yaml'
+
     # Load metadata YAML
     args = load_metadata(args, parser)
 
diff --git a/bicleaner_ai/classify.py b/bicleaner_ai/classify.py
index 901c64a..1306692 100644
--- a/bicleaner_ai/classify.py
+++ b/bicleaner_ai/classify.py
@@ -32,7 +32,7 @@ def argument_parser():
     ## Input file. Try to open it to check if it exists
     parser.add_argument('input', type=argparse.FileType('rt'), default=None, help="Tab-separated files to be classified")
     parser.add_argument('output', nargs='?', type=argparse.FileType('w'), default=sys.stdout, help="Output of the classification")
-    parser.add_argument('metadata', type=argparse.FileType('r'), default=None, help="Training metadata (YAML file)")
+    parser.add_argument('model', type=str, default=None, help="Path to model directory or HuggingFace Hub model identifier (such as 'bitextor/bicleaner-ai-full-en-fr')")
 
     # Options group
     groupO = parser.add_argument_group('Optional')
@@ -60,6 +60,10 @@ def argument_parser():
     groupO.add_argument('--run_all_rules', default=False, action='store_true', help="Run all rules of Hardrules instead of stopping at first discard")
     groupO.add_argument('--rules_config', type=argparse.FileType('r'), default=None, help="Hardrules configuration file")
 
+    # HuggingFace Hub options
+    groupO.add_argument('--offline', default=False, action='store_true', help="Don't try to download the model, instead try directly to load from local storage")
+    groupO.add_argument('--auth_token', default=None, type=str, help="Auth token for the Hugging Face Hub")
+
     # Logging group
     groupL = parser.add_argument_group('Logging')
     groupL.add_argument('-q', '--quiet', action='store_true', help='Silent logging mode')
@@ -72,10 +76,11 @@ def argument_parser():
 
 # Load metadata, classifier, lm_filter and porn_removal
 def load_metadata(args, parser):
+    metadata_file = open(args.metadata)
     try:
         # Load YAML
-        metadata_yaml = yaml.safe_load(args.metadata)
-        yamlpath = os.path.dirname(os.path.abspath(args.metadata.name))
+        metadata_yaml = yaml.safe_load(metadata_file)
+        yamlpath = os.path.dirname(os.path.abspath(args.metadata))
         metadata_yaml["yamlpath"] = yamlpath
 
         # Read language pair and tokenizers
@@ -134,6 +139,9 @@ def load_metadata(args, parser):
         logging.error("Error loading metadata")
         traceback.print_exc()
         sys.exit(1)
+    finally:
+        if not metadata_file.closed:
+            metadata_file.close()
 
     # Ensure that directory exists; if not, create it
     if not os.path.exists(args.tmp_dir):
```