Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/bitextor/bicleaner-ai.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorZJaume <jzaragoza@prompsit.com>2022-07-22 14:30:50 +0300
committerJaume Zaragoza <ZJaume@users.noreply.github.com>2022-07-27 15:20:55 +0300
commit55fe50d9223a0ab6975b8a38b6018a1f8d6bafb1 (patch)
tree1d6aed8f8e49a7c4b5b5efc331f720d1e14acc63
parent6383cddf2d4662e7e1c6c15de67daa002665a95a (diff)
Download models from HFHub if possible
-rwxr-xr-xbicleaner_ai/bicleaner_ai_classifier.py22
-rw-r--r--bicleaner_ai/classify.py14
2 files changed, 33 insertions, 3 deletions
diff --git a/bicleaner_ai/bicleaner_ai_classifier.py b/bicleaner_ai/bicleaner_ai_classifier.py
index 0d5f6c8..e13e9c9 100755
--- a/bicleaner_ai/bicleaner_ai_classifier.py
+++ b/bicleaner_ai/bicleaner_ai_classifier.py
@@ -49,6 +49,28 @@ def initialization():
else:
args.processes = max(1, cpu_count()-1)
+ # Try to download the model if not a valid path
+ if not args.offline or is_dir:
+ from huggingface_hub import snapshot_download, model_info
+ from huggingface_hub.utils import RepositoryNotFoundError
+ from requests.exceptions import HTTPError
+ try:
+ # Check if it exists at the HF Hub
+ model_info(args.model, token=args.auth_token)
+ except RepositoryNotFoundError:
+ logging.debug(
+ f"Model {args.model} not found at HF Hub, trying local storage")
+ args.metadata = args.model + '/metadata.yaml'
+ else:
+ logging.info(f"Downloading the model {args.model}")
+ # Download all the model files from the hub
+ cache_path = snapshot_download(args.model,
+ use_auth_token=args.auth_token)
+ # Set metadata path to the cache location of the model
+ args.metadata = cache_path + '/metadata.yaml'
+ else:
+ args.metadata = args.model + '/metadata.yaml'
+
# Load metadata YAML
args = load_metadata(args, parser)
diff --git a/bicleaner_ai/classify.py b/bicleaner_ai/classify.py
index 901c64a..1306692 100644
--- a/bicleaner_ai/classify.py
+++ b/bicleaner_ai/classify.py
@@ -32,7 +32,7 @@ def argument_parser():
## Input file. Try to open it to check if it exists
parser.add_argument('input', type=argparse.FileType('rt'), default=None, help="Tab-separated files to be classified")
parser.add_argument('output', nargs='?', type=argparse.FileType('w'), default=sys.stdout, help="Output of the classification")
- parser.add_argument('metadata', type=argparse.FileType('r'), default=None, help="Training metadata (YAML file)")
+ parser.add_argument('model', type=str, default=None, help="Path to model directory or HuggingFace Hub model identifier (such as 'bitextor/bicleaner-ai-full-en-fr')")
# Options group
groupO = parser.add_argument_group('Optional')
@@ -60,6 +60,10 @@ def argument_parser():
groupO.add_argument('--run_all_rules', default=False, action='store_true', help="Run all rules of Hardrules instead of stopping at first discard")
groupO.add_argument('--rules_config', type=argparse.FileType('r'), default=None, help="Hardrules configuration file")
+ # HuggingFace Hub options
+ groupO.add_argument('--offline', default=False, action='store_true', help="Don't try to download the model, instead try directly to load from local storage")
+ groupO.add_argument('--auth_token', default=None, type=str, help="Auth token for the Hugging Face Hub")
+
# Logging group
groupL = parser.add_argument_group('Logging')
groupL.add_argument('-q', '--quiet', action='store_true', help='Silent logging mode')
@@ -72,10 +76,11 @@ def argument_parser():
# Load metadata, classifier, lm_filter and porn_removal
def load_metadata(args, parser):
+ metadata_file = open(args.metadata)
try:
# Load YAML
- metadata_yaml = yaml.safe_load(args.metadata)
- yamlpath = os.path.dirname(os.path.abspath(args.metadata.name))
+ metadata_yaml = yaml.safe_load(metadata_file)
+ yamlpath = os.path.dirname(os.path.abspath(args.metadata))
metadata_yaml["yamlpath"] = yamlpath
# Read language pair and tokenizers
@@ -134,6 +139,9 @@ def load_metadata(args, parser):
logging.error("Error loading metadata")
traceback.print_exc()
sys.exit(1)
+ finally:
+ if not metadata_file.closed:
+ metadata_file.close()
# Ensure that directory exists; if not, create it
if not os.path.exists(args.tmp_dir):