github.com/bitextor/bicleaner-ai.git
author    ZJaume <jzaragoza@prompsit.com>  2022-08-09 16:46:00 +0300
committer ZJaume <jzaragoza@prompsit.com>  2022-08-09 16:46:00 +0300
commit    af09cc4c92edb5f34294b716076a4e439bd35e8a (patch)
tree      574aa2846004e0ccfe3db87e1176b1dee685501e
parent    341eca49932a089c09c2264d3570ef0f3d44554f (diff)
parent    87f17a6079f9cabc14feb9a2cb594061dc99aa0f (diff)

Merge branch 'master' into noise
-rw-r--r--  CHANGELOG.md                              119
-rw-r--r--  README.md                                  19
-rwxr-xr-x  bicleaner_ai/bicleaner_ai_classifier.py    28
-rwxr-xr-x  bicleaner_ai/bicleaner_ai_train.py         12
-rw-r--r--  bicleaner_ai/classify.py                   15
-rw-r--r--  bicleaner_ai/datagen.py                     2
-rw-r--r--  bicleaner_ai/layers.py                     16
-rw-r--r--  bicleaner_ai/models.py                     66
-rw-r--r--  requirements.txt                            2
-rwxr-xr-x  scripts/bicleaner-ai-download              64
-rw-r--r--  scripts/bicleaner-ai-download-hf           11
-rwxr-xr-x  setup.py                                    4
-rwxr-xr-x  utils/download-pack.sh                     58

13 files changed, 217 insertions(+), 199 deletions(-)
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 5702013..df978af 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,21 +1,41 @@
-Unreleased:
-* Update to Hardrules 2.0
- * Rules can be parametrized with `--rules_config config.yaml`
- * Some rules have been refactored with better names.
- * `--run_all_rules` mode to run each rule instead of stoppping at first discard
- * Language identification with [FastSpell](https://github.com/mbanon/fastspell)
-* Huge memory improvements during training.
+# Changelog
+All notable changes to this project will be documented in this file.
+
+The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
+and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
+
+## [Unreleased]:
+### Added
+* Upload full models to Hugging Face Hub.
+* Automatic download of full models.
* Hide Tensorflow and Transformers logging messages in executable scripts.
-* Update HF Transformers, no longer needed single GPU for prediction.
+* Redirect Keras prediction progress bar to stderr.
+* Huge memory improvements during training.
+* Speed improvements using `longest` padding instead of `max_length`.
+### Changed
+* Update to Hardrules 2.3
+ * Rules can be parametrized with `--rules_config config.yaml`
+ * Some rules have been refactored with better names.
+ * `--run_all_rules` mode to run each rule instead of stopping at first discard
+ * Language identification with [FastSpell](https://github.com/mbanon/fastspell)
+ * Easier installation! Now KenLM comes pre-compiled.
+* Now the BICLEANER\_AI\_THREADS environment variable controls the number of threads.
+* Update HF Transformers.
+* Update TensorFlow minimum version.
+* Rename `download-packs.sh` to `bicleaner-ai-download`.
+* Set inter/intra\_op parallelism to 0 by default.
+* Add citation info to README.
+### Fixed
* Avoid generating empty sentences in omit noise.
* Restore capital letters at the beginning of the sentence in frequency noise.
* Fix loading lite models in Python versions other than 3.8.
+* Fix unbound variable `lm_stats`.
* Other minor fixes.
-Bicleaner AI 1.0.1:
+## Bicleaner AI 1.0.1:
* Update hardrules to 1.2: adds score only mode.
-Bicleaner AI 1.0:
+## Bicleaner AI 1.0:
* Bicleaner train changes:
* Separate most of the training logic in the BaseModel class.
* Re-factor synthetic noise build function.
@@ -27,82 +47,3 @@ Bicleaner AI 1.0:
* Bicleaner classify changes:
* Change old classifier by new neural models.
* Move hardrules into a separate package.
-
-Bicleaner 0.15:
-* Bicleaner train changes:
- * Qmax bug fixing.
- * Classifier training uses the number of processes given by argument.
-* Bicleaner classify changes:
- * Refactored classifier scripts: code cleaning and remove lot of duplicated code.
- * Buffered tokenization: improve speed of external tokenization tokenizing blocks of lines instead of line by line.
-
-Bicleaner 0.14:
-* Bicleaner hardrules changes:
- * New rule: filter out sentences containing gluedWordsLikeThis.
- * Rule change: Relaxed c_different_language rule for similar languages.
- * New rule: filter out porn sentences using FastText classifier.
- * Parameters changed: `-s/--source_lang` and `-t/--target_lang` are no longer mandatory (if a metadata .yaml file is provided)
-* Bicleaner train changes:
- * Default classifier is now `extra_trees`
- * New parameters: `-f` and `-F`, source and target word frequency dictionaries.
- * New qmax features:
- * `qmax_nosmooth_nolimit_freq`: removes OOV smoothing, word limits and weights each target word with its monolingual probability using the word frequency dictionary.
- * `qmax_nosmooth_nolimit_cummulated_prob_zipf_freq`: uses accumulated probability instead of maximum and splits the score into quartiles based on word frequencies.
- * Added more bilingual dictionary coverage features, splitting them into quartiles based on monolingual word frequencies.
- * Added new noise function that synthesizes negative samples cutting sentences and replacing words (this is not used by default, needs more testing).
- * Changed classifier training behavior and use grid search.
- * Removed `bicleaner_train_lite.py`
- * Removed parameters: `-g` (`--good_examples`) and `-w` (`--wrong_examples`):
- * Now, training automatically uses one half of the input file for good examples and the other half to synthesize wrong examples.
- * Of this partitions, 90% will be used for training and the remaining 10% for testing.
- * New parameter: `--relative_paths` allows to save model files paths relative instead of absolute (useful for training distributable models)
- * Changed logging info messages, now more informative.
-* Other
- * Now using [sacremoses](https://github.com/alvations/sacremoses) instead of [mosestokenizer](https://github.com/luismsgomes/mosestokenizer)
- * New script: `./utils/download-pack.sh` allows to download language packs for a given language pair.
-
-
-Bicleaner 0.13:
-* Bicleaner hardrules changes:
- * Rule change: Relaxed c_minimal_length to accept 3-word sentences
- * New feature: LM filtering (moved from Bicleaner Classify)
- * New parameter: `--disable_lm_filter`, `--metadata` and `--lm_threshold`, to support LM filtering
-* Bicleaner training changes:
- * New parameter: Features relying on language identification can be disabled with flag `--disable_lang_ident` (this will be outputed in the .yaml file and used by Bicleaner clasifier)
- * New feature: Debug mode now gives information on random forest feature importances
- * Parameter change: --noisy_examples_file_sl and --noisy_examples_file_tl are now optional
- * Parameter change: input now must be more than 10K sentences long
- * Removed INFO messages when processes starting/ending (except when debugging)
-* Bicleaner classifier changes:
- * `--disable_lang_ident` flag is now read from the .yaml file
- * Removed feature: LM filtering (moved to Bicleaner Hardrules)
- * New parameter: `--disable_lm_filter`
- * Removed parameters: `--keep_lm_result`, `--threshold`
-* Other:
- * Updated requirements
-
-
-
-Bicleaner 0.12:
-* Bicleaner hardrules changes:
- * New rule: c_identical_wo_punct to reject sentences only different in punctuation (and it's case insensitive)
- * New rule: Sentences containing "Re:" are rejected
- * Rule change: c_minimal_length now rejects sentences with both sides <= 3 words (instead of only one)
- * Rule change: c_identical and c_identical_wo_digits now is case insensitive
- * Rule change: Breadcrumbs rule now split into c_no_breadcrumbs1 and c_no_breadcrumbs2
- * Rule change: Breadcrumbs2 now includes character "·" in the rejected characters
- * Rule change: c_length now compares byte length ratio (will avoid rejecting valid sentences due to length ratio when comparing languages with different alphabets)
- * Changed behaviour for `--annotated_output` argument in hardrules. See README.md for more information.
- * New parameter: `--disable_lang_ident` flag to avoid applying rules that need to identify the language
-* Bicleaner classify changes:
- * Now using only 3 decimal places for Bicleaner score and LM score
- * Removed INFO messages when processes starting/ending (except when debugging)
- * New parameter: '--disable_hardrules' flag to avoid applying hardrules
- * New parameter: '--disable_lang_ident' flag to avoid applying rules that need to identify the language
- * New parameter: '--score_only' flag to output only Bicleaner scores (proposed by [@kirefu](https://github.com/kirefu))
-* Bicleaner features changes:
- * Fixed bug when probability in prob_dict is 0 (issue [#19](https://github.com/bitextor/bicleaner/issues/19))
-* Other:
- * Fixed sklearn version to 0.19.1
- * Added utilities for training: `shuffle.py` and `dict_pruner.py`
- * Updated instalation guides in readme
diff --git a/README.md b/README.md
index be46422..0c98599 100644
--- a/README.md
+++ b/README.md
@@ -8,7 +8,7 @@ indicates the likelihood of a pair of sentences being mutual translations (with
Sentence pairs considered very noisy are scored with 0.
Although a training tool (`bicleaner-ai-train`) is provided, you may want to use the available ready-to-use language packages.
-Please, visit https://github.com/bitextor/bicleaner-ai-data/releases/latest or use `./utils/download-pack.sh` to download the latest language packages.
+Please use `bicleaner-ai-download` to download the latest language packages, or visit the [GitHub releases](https://github.com/bitextor/bicleaner-ai-data/releases/latest) for lite models and the [Hugging Face Hub](https://huggingface.co/bitextor) for full models (available since v2.0).
Visit our [Wiki](https://github.com/bitextor/bicleaner-ai/wiki/How-to-train-your-Bicleaner-AI) for a detailed example on Bicleaner training.
## Citation
@@ -47,21 +47,12 @@ The use of XLMRoberta and 1:10 positive to negative ratio were inspired in the w
- TensorFlow >= 2.6.5
- CUDA 11.2 (for training and inference with full models)
-Bicleaner AI is written in Python and can be installed using `pip`:
+Bicleaner AI is written in Python and can be installed using `pip`.
+It also requires the [KenLM](https://github.com/kpu/kenlm) Python bindings with support for 7-gram language models.
+You can easily install it by running the following command:
```bash
-pip install bicleaner-ai
-```
-
-Bicleaner AI requires the [KenLM](https://github.com/kpu/kenlm) Python bindings with support for 7-gram language models. You can easily install it by running the following commands:
-
-```bash
-git clone https://github.com/kpu/kenlm
-cd kenlm
-pip install . --install-option="--max_order 7"
-mkdir -p build && cd build
-cmake .. -DKENLM_MAX_ORDER=7 -DCMAKE_INSTALL_PREFIX:PATH=/your/prefix/path
-make -j all install
+pip install bicleaner-ai https://github.com/kpu/kenlm/archive/master.zip --install-option="--max_order 7"
```
Hardrules uses [FastSpell](https://github.com/mbanon/fastspell) that requires `python-dev` and `libhunspell-dev`:
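
Returning to the KenLM requirement above: a quick way to confirm the bindings were built with 7-gram support is to inspect a loaded model's order. A minimal sketch, assuming `kenlm` was installed with `--max_order 7` as shown; `lm.arpa` is a hypothetical model file:

```python
# Hedged sanity check for the KenLM build; "lm.arpa" is a hypothetical path.
import kenlm

model = kenlm.Model("lm.arpa")
print(model.order)  # a max_order-7 build can report orders up to 7
print(model.score("this is a test", bos=True, eos=True))  # log10 probability
```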
diff --git a/bicleaner_ai/bicleaner_ai_classifier.py b/bicleaner_ai/bicleaner_ai_classifier.py
index 0d5f6c8..ba57d5f 100755
--- a/bicleaner_ai/bicleaner_ai_classifier.py
+++ b/bicleaner_ai/bicleaner_ai_classifier.py
@@ -49,6 +49,34 @@ def initialization():
else:
args.processes = max(1, cpu_count()-1)
+ # Try to download the model if not a valid path
+ hub_not_found = False
+ if not args.offline:
+ from huggingface_hub import snapshot_download, model_info
+ from huggingface_hub.utils import RepositoryNotFoundError
+ from requests.exceptions import HTTPError
+ try:
+ # Check if it exists at the HF Hub
+ model_info(args.model, token=args.auth_token)
+ except RepositoryNotFoundError:
+ hub_not_found = True
+ args.metadata = args.model + '/metadata.yaml'
+ else:
+ logging.info(f"Downloading the model {args.model}")
+ # Download all the model files from the hub
+ cache_path = snapshot_download(args.model,
+ use_auth_token=args.auth_token)
+ # Set metadata path to the cache location of the model
+ args.metadata = cache_path + '/metadata.yaml'
+ else:
+ args.metadata = args.model + '/metadata.yaml'
+
+ if not os.path.isfile(args.metadata):
+ if hub_not_found:
+ logging.error(
+ f"Model {args.model} not found at HF Hub")
+ raise FileNotFoundError(f"model {args.model} no such file")
+
# Load metadata YAML
args = load_metadata(args, parser)
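
The hunk above tries the Hugging Face Hub first and falls back to a local directory. A minimal standalone sketch of that logic, assuming `huggingface_hub` is installed (`resolve_metadata` is a hypothetical helper name):

```python
# Sketch of the download-or-local fallback added above; error handling simplified.
from huggingface_hub import snapshot_download, model_info
from huggingface_hub.utils import RepositoryNotFoundError

def resolve_metadata(model, auth_token=None, offline=False):
    if offline:
        return model + '/metadata.yaml'      # local directory only
    try:
        model_info(model, token=auth_token)  # check the repo exists on the Hub
    except RepositoryNotFoundError:
        return model + '/metadata.yaml'      # not on the Hub: treat as local path
    cache_path = snapshot_download(model, use_auth_token=auth_token)
    return cache_path + '/metadata.yaml'     # metadata inside the HF cache
```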
diff --git a/bicleaner_ai/bicleaner_ai_train.py b/bicleaner_ai/bicleaner_ai_train.py
index a8a650d..ebda149 100755
--- a/bicleaner_ai/bicleaner_ai_train.py
+++ b/bicleaner_ai/bicleaner_ai_train.py
@@ -49,6 +49,7 @@ def initialization():
groupM.add_argument('--parallel_valid', type=argparse.FileType('r'), default=None, required=True, help="TSV file containing parallel sentences for validation")
groupO = parser.add_argument_group('Options')
+ groupO.add_argument('--model_name', type=str, default=None, help='The name of the model. For the XLMR models it will be used as the name in Hugging Face Hub.')
groupO.add_argument('-S', '--source_tokenizer_command', help="Source language tokenizer full command")
groupO.add_argument('-T', '--target_tokenizer_command', help="Target language tokenizer full command")
groupO.add_argument('-f', '--source_word_freqs', type=argparse.FileType('r'), default=None, required=False, help="L language gzipped list of word frequencies")
@@ -211,7 +212,18 @@ def perform_training(args):
args.parallel_train.close()
args.parallel_valid.close()
+ # Define the model name
+ if args.model_name is None:
+ model_name = 'bitextor/bicleaner-ai'
+ if args.classifier_type in ['dec_attention', 'transformer']:
+ model_name += f'-lite-{args.source_lang}-{args.target_lang}'
+ else:
+ model_name += f'-full-{args.source_lang}-{args.target_lang}'
+ else:
+ model_name = args.model_name
+
model_settings = {
+ "model_name": model_name,
"batch_size": args.batch_size,
"epochs": args.epochs,
"steps_per_epoch": args.steps_per_epoch
diff --git a/bicleaner_ai/classify.py b/bicleaner_ai/classify.py
index 0061702..1306692 100644
--- a/bicleaner_ai/classify.py
+++ b/bicleaner_ai/classify.py
@@ -21,6 +21,7 @@ except (ImportError, SystemError):
__author__ = "Jaume Zaragoza"
__version__ = "Version 1.0 # 14/06/2021 #"
__version__ = "Version 1.0.1 # 16/06/2021 #"
+__version__ = "Version 2.0"
# Create an argument parser and add all the arguments
@@ -31,7 +32,7 @@ def argument_parser():
## Input file. Try to open it to check if it exists
parser.add_argument('input', type=argparse.FileType('rt'), default=None, help="Tab-separated files to be classified")
parser.add_argument('output', nargs='?', type=argparse.FileType('w'), default=sys.stdout, help="Output of the classification")
- parser.add_argument('metadata', type=argparse.FileType('r'), default=None, help="Training metadata (YAML file)")
+ parser.add_argument('model', type=str, default=None, help="Path to model directory or HuggingFace Hub model identifier (such as 'bitextor/bicleaner-ai-full-en-fr')")
# Options group
groupO = parser.add_argument_group('Optional')
@@ -59,6 +60,10 @@ def argument_parser():
groupO.add_argument('--run_all_rules', default=False, action='store_true', help="Run all rules of Hardrules instead of stopping at first discard")
groupO.add_argument('--rules_config', type=argparse.FileType('r'), default=None, help="Hardrules configuration file")
+ # HuggingFace Hub options
+ groupO.add_argument('--offline', default=False, action='store_true', help="Don't try to download the model, instead try directly to load from local storage")
+ groupO.add_argument('--auth_token', default=None, type=str, help="Auth token for the Hugging Face Hub")
+
# Logging group
groupL = parser.add_argument_group('Logging')
groupL.add_argument('-q', '--quiet', action='store_true', help='Silent logging mode')
@@ -71,10 +76,11 @@ def argument_parser():
# Load metadata, classifier, lm_filter and porn_removal
def load_metadata(args, parser):
+ metadata_file = open(args.metadata)
try:
# Load YAML
- metadata_yaml = yaml.safe_load(args.metadata)
- yamlpath = os.path.dirname(os.path.abspath(args.metadata.name))
+ metadata_yaml = yaml.safe_load(metadata_file)
+ yamlpath = os.path.dirname(os.path.abspath(args.metadata))
metadata_yaml["yamlpath"] = yamlpath
# Read language pair and tokenizers
@@ -133,6 +139,9 @@ def load_metadata(args, parser):
logging.error("Error loading metadata")
traceback.print_exc()
sys.exit(1)
+ finally:
+ if not metadata_file.closed:
+ metadata_file.close()
# Ensure that directory exists; if not, create it
if not os.path.exists(args.tmp_dir):
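
Since the positional argument is now a path (or Hub identifier) rather than an open file handle, the metadata YAML is opened explicitly and always closed in a `finally` block. A condensed sketch of the resulting pattern:

```python
# Condensed sketch of the metadata-loading pattern after this change.
import os
import yaml

def load_metadata_yaml(metadata_path):
    metadata_file = open(metadata_path)
    try:
        metadata = yaml.safe_load(metadata_file)
        # Remember the directory so model files can be resolved relative to it
        metadata["yamlpath"] = os.path.dirname(os.path.abspath(metadata_path))
        return metadata
    finally:
        if not metadata_file.closed:
            metadata_file.close()
```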
diff --git a/bicleaner_ai/datagen.py b/bicleaner_ai/datagen.py
index be8a6ba..b624ffc 100644
--- a/bicleaner_ai/datagen.py
+++ b/bicleaner_ai/datagen.py
@@ -182,7 +182,7 @@ class ConcatSentenceGenerator(SentenceGenerator):
else:
# Tokenize with Transformers tokenizer that concatenates internally
dataset = self.encoder(text1, text2,
- padding='max_length',
+ padding='longest',
truncation=True,
max_length=self.maxlen,
return_tensors='np',
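
The practical effect of the one-line change above: with `padding='longest'` a batch is padded only to its longest member instead of always to `max_length`, so short batches yield much smaller tensors. A hedged illustration (`xlm-roberta-base` is just an example checkpoint):

```python
# Compare padding='longest' with padding='max_length' on a sentence-pair batch.
from transformers import XLMRobertaTokenizerFast

tok = XLMRobertaTokenizerFast.from_pretrained('xlm-roberta-base')
src, tgt = ['short pair'], ['a slightly longer target sentence']
longest = tok(src, tgt, padding='longest', truncation=True,
              max_length=150, return_tensors='np')
fixed = tok(src, tgt, padding='max_length', truncation=True,
            max_length=150, return_tensors='np')
print(longest['input_ids'].shape)  # second dim = longest sequence in the batch
print(fixed['input_ids'].shape)    # always (1, 150)
```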
diff --git a/bicleaner_ai/layers.py b/bicleaner_ai/layers.py
index 7d19aa0..95264f0 100644
--- a/bicleaner_ai/layers.py
+++ b/bicleaner_ai/layers.py
@@ -67,18 +67,22 @@ class TransformerBlock(layers.Layer):
ffn_output = self.dropout2(ffn_output, training=training)
return self.layernorm2(out1 + ffn_output)
-class BCClassificationHead(layers.Layer):
- """Head for sentence-level classification tasks."""
+class BicleanerAIClassificationHead(layers.Layer):
+ """
+ Head for Bicleaner sentence classification tasks.
+ It reads BicleanerAIConfig to be able to change
+ classifier layer parameters (size, activation and dropout)
+ """
- def __init__(self, config, hidden_size, dropout, activation, **kwargs):
+ def __init__(self, config, **kwargs):
super().__init__(**kwargs)
self.dense = layers.Dense(
- hidden_size,
+ config.head_hidden_size,
kernel_initializer=get_initializer(config.initializer_range),
- activation=activation,
+ activation=config.head_activation,
name="dense",
)
- self.dropout = layers.Dropout(dropout)
+ self.dropout = layers.Dropout(config.head_dropout)
self.out_proj = layers.Dense(
config.num_labels,
kernel_initializer=get_initializer(config.initializer_range),
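
With the hyperparameters moved into the config, the head can be rebuilt from a serialized config alone. A hedged construction example, assuming the package is importable as `bicleaner_ai` and using the defaults that `XLMRBicleanerAIConfig` defines below in models.py:

```python
# Hedged sketch: the head reads its size, dropout and activation from the config.
from bicleaner_ai.models import XLMRBicleanerAIConfig
from bicleaner_ai.layers import BicleanerAIClassificationHead

config = XLMRBicleanerAIConfig(head_hidden_size=2048, head_dropout=0.1,
                               head_activation='relu', num_labels=2)
head = BicleanerAIClassificationHead(config, name='bicleaner_ai_classification_head')
```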
diff --git a/bicleaner_ai/models.py b/bicleaner_ai/models.py
index e81826e..4062ee0 100644
--- a/bicleaner_ai/models.py
+++ b/bicleaner_ai/models.py
@@ -1,4 +1,8 @@
-from transformers import TFXLMRobertaForSequenceClassification, XLMRobertaTokenizerFast
+from transformers import (
+ TFXLMRobertaForSequenceClassification,
+ XLMRobertaConfig,
+ XLMRobertaTokenizerFast
+)
from transformers.modeling_tf_outputs import TFSequenceClassifierOutput
from transformers.optimization_tf import create_optimizer
from tensorflow.keras.optimizers.schedules import InverseTimeDecay
@@ -29,7 +33,7 @@ try:
from .layers import (
TransformerBlock,
TokenAndPositionEmbedding,
- BCClassificationHead)
+ BicleanerAIClassificationHead)
except (SystemError, ImportError):
import decomposable_attention
from metrics import FScore, MatthewsCorrCoef
@@ -40,7 +44,7 @@ except (SystemError, ImportError):
from layers import (
TransformerBlock,
TokenAndPositionEmbedding,
- BCClassificationHead)
+ BicleanerAIClassificationHead)
def calibrate_output(y_true, y_pred):
''' Platt calibration
@@ -219,7 +223,8 @@ class BaseModel(ModelInterface):
generator = self.get_generator(batch_size, shuffle=False)
generator.load((x1, x2, None))
- y_pred = self.model.predict(generator, verbose=verbose)
+ with redirect_stdout(sys.stderr):
+ y_pred = self.model.predict(generator, verbose=verbose)
# Obtain logits if model returns HF output
if isinstance(y_pred, TFSequenceClassifierOutput):
y_pred = y_pred.logits
@@ -481,9 +486,7 @@ class BCXLMRoberta(BaseModel):
self.tokenizer = None
self.settings = {
- "model_file": "model.tf",
- "vocab_file": "vocab",
- "model": 'jplu/tf-xlm-roberta-base',
+ "base_model": 'jplu/tf-xlm-roberta-base',
"batch_size": 16,
"maxlen": 150,
"n_classes": 2,
@@ -521,7 +524,7 @@ class BCXLMRoberta(BaseModel):
def load_model(self, model_file):
settings = self.settings
- tf_model = TFXLMRBicleaner.from_pretrained(
+ tf_model = TFXLMRBicleanerAI.from_pretrained(
model_file,
num_labels=settings["n_classes"],
head_hidden_size=settings["n_hidden"],
@@ -532,9 +535,8 @@ class BCXLMRoberta(BaseModel):
def load(self):
''' Load fine-tuned model '''
- vocab_file = self.dir + '/' + self.settings["vocab_file"]
- self.tokenizer = XLMRobertaTokenizerFast.from_pretrained(vocab_file)
- self.model = self.load_model(self.dir+'/'+self.settings["model_file"])
+ self.tokenizer = XLMRobertaTokenizerFast.from_pretrained(self.dir)
+ self.model = self.load_model(self.dir)
def softmax_pos_prob(self, x):
# Compute softmax probability of the second (positive) class
@@ -571,7 +573,7 @@ class BCXLMRoberta(BaseModel):
logging.info("Loading training set")
self.tokenizer = XLMRobertaTokenizerFast.from_pretrained(
- self.settings["model"])
+ self.settings["base_model"])
train_generator = self.get_generator(self.settings["batch_size"],
shuffle=True)
train_generator.load(train_set)
@@ -582,8 +584,6 @@ class BCXLMRoberta(BaseModel):
shuffle=False)
dev_generator.load(dev_set, ignore_tags=False)
- model_filename = self.dir + '/' + self.settings["model_file"]
- vocab_filename = self.dir + '/' + self.settings["vocab_file"]
earlystop = EarlyStopping(monitor='val_f1',
mode='max',
patience=self.settings["patience"],
@@ -594,12 +594,13 @@ class BCXLMRoberta(BaseModel):
strategy = tf.distribute.MirroredStrategy()
num_devices = strategy.num_replicas_in_sync
with strategy.scope():
- self.model = self.load_model(self.settings["model"])
+ self.model = self.load_model(self.settings["base_model"])
self.model.compile(optimizer=self.settings["optimizer"],
loss=SparseCategoricalCrossentropy(
from_logits=True),
metrics=[FScore(argmax=True),
MatthewsCorrCoef(argmax=True)])
+ self.model.config._name_or_path = self.settings["model_name"]
if logging.getLogger().level == logging.DEBUG:
self.model.summary()
@@ -616,8 +617,8 @@ class BCXLMRoberta(BaseModel):
batch_size=self.settings["batch_size"],
callbacks=[earlystop],
verbose=verbose)
- self.model.save_pretrained(model_filename)
- self.tokenizer.save_pretrained(vocab_filename)
+ self.model.save_pretrained(self.dir)
+ self.tokenizer.save_pretrained(self.dir)
y_true = dev_generator.y
with redirect_stdout(sys.stderr):
@@ -634,13 +635,26 @@ class BCXLMRoberta(BaseModel):
return y_true, y_pred, dev_generator.tags
-class TFXLMRBicleaner(TFXLMRobertaForSequenceClassification):
- """Model for sentence-level classification tasks."""
+class XLMRBicleanerAIConfig(XLMRobertaConfig):
+ '''
+ Bicleaner AI XLMR configuration class
+ adds the config parameters for the classification layer
+ '''
+ def __init__(self, head_hidden_size=2048, head_dropout=0.1,
+ head_activation='relu', **kwargs):
+ super().__init__(**kwargs)
+ self.head_hidden_size = head_hidden_size
+ self.head_dropout = head_dropout
+ self.head_activation = head_activation
+
+class TFXLMRBicleanerAI(TFXLMRobertaForSequenceClassification):
+ '''
+ Model for Bicleaner sentence-level classification tasks.
+    Overrides the XLMRoberta classification layer.
+ '''
+ config_class = XLMRBicleanerAIConfig
- def __init__(self, config, head_hidden_size, head_dropout, head_activation):
- super().__init__(config)
- self.classifier = BCClassificationHead(config,
- head_hidden_size,
- head_dropout,
- head_activation,
- name='bc_classification_head')
+ def __init__(self, config, *inputs, **kwargs):
+ super().__init__(config, *inputs, **kwargs)
+ self.classifier = BicleanerAIClassificationHead(config,
+ name='bicleaner_ai_classification_head')
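
Setting `config_class` is what makes the Hub models directly loadable: `from_pretrained` can rebuild the custom head from the saved config with no extra constructor arguments. A hedged sketch (the identifier is the example from the README; assumes the package is importable):

```python
# Hedged sketch: config_class wiring restores the custom head automatically.
from bicleaner_ai.models import TFXLMRBicleanerAI

model = TFXLMRBicleanerAI.from_pretrained('bitextor/bicleaner-ai-full-en-fr')
print(model.config.head_hidden_size)  # restored from the saved config
```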
diff --git a/requirements.txt b/requirements.txt
index 1d07d4b..8373732 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -5,7 +5,7 @@ pytest
toolwrapper
joblib
sacremoses
-bicleaner-hardrules>=2.1
+bicleaner-hardrules>=2.3,<3.0
sentencepiece
tensorflow>=2.6.5
glove-python-binary==0.2.0
diff --git a/scripts/bicleaner-ai-download b/scripts/bicleaner-ai-download
new file mode 100755
index 0000000..df74755
--- /dev/null
+++ b/scripts/bicleaner-ai-download
@@ -0,0 +1,64 @@
+#!/bin/bash
+
+usage() {
+ echo "Script to download Bicleaner AI language packs."
+    echo "It will try to download {lite,full}-lang1-lang2 and, if it does not exist, {lite,full}-lang2-lang1."
+    echo
+    echo "Usage: `basename $0` <lang1> <lang2> {lite,full} <download_path>"
+ echo " <lang1> Language 1."
+ echo " <lang2> Language 2."
+ echo " {lite,full} Download lite or full model."
+ echo " <download_path> Path where downloaded language pack should be placed. Will be ignored for full models."
+}
+
+invalid_url(){
+ wget -S --spider -o - $1 | grep -q '404 Not Found'
+}
+
+if [[ $# -lt 3 ]]
+then
+ echo "Wrong number of arguments: $@" >&2
+ usage >&2
+ exit 1
+fi
+
+URL="https://github.com/bitextor/bicleaner-ai-data/releases/latest/download"
+L1=$1
+L2=$2
+if [ "$3" != "lite" ] && [ "$3" != "full" ]; then
+ echo "Model type must be 'lite' or 'full' not '$3'" 1>&2
+ usage >&2
+ exit 1
+fi
+TYPE=$3
+if [ "$4" != "" ]; then
+ DOWNLOAD_PATH=$4
+else
+ DOWNLOAD_PATH="."
+fi
+
+if [ "$TYPE" == "full" ]; then
+ # Download from HF Hub
+    bicleaner-ai-download-hf bitextor/bicleaner-ai-full-$L1-$L2
+else
+ # Download from github bitextor/bicleaner-ai-data
+ # and decompress tgz in the desired directory
+ if invalid_url $URL/$TYPE-$L1-$L2.tgz
+ then
+ >&2 echo $L1-$L2 language pack does not exist, trying $L2-$L1...
+ if invalid_url $URL/$TYPE-$L2-$L1.tgz
+ then
+ >&2 echo $L2-$L1 language pack does not exist
+ else
+ wget -P $DOWNLOAD_PATH $URL/$TYPE-$L2-$L1.tgz
+ tar xvf $DOWNLOAD_PATH/$TYPE-$L2-$L1.tgz -C $DOWNLOAD_PATH
+ rm $DOWNLOAD_PATH/$TYPE-$L2-$L1.tgz
+ fi
+ else
+ wget -P $DOWNLOAD_PATH $URL/$TYPE-$L1-$L2.tgz
+ tar xvf $DOWNLOAD_PATH/$TYPE-$L1-$L2.tgz -C $DOWNLOAD_PATH
+ rm $DOWNLOAD_PATH/$TYPE-$L1-$L2.tgz
+ fi
+fi
+
+echo Finished
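
With the package installed, a typical invocation of this script would be, e.g., `bicleaner-ai-download en fr lite ./models` for a lite pack; for full models the download path is ignored and the files land in the Hugging Face cache.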
diff --git a/scripts/bicleaner-ai-download-hf b/scripts/bicleaner-ai-download-hf
new file mode 100644
index 0000000..ccd38e9
--- /dev/null
+++ b/scripts/bicleaner-ai-download-hf
@@ -0,0 +1,11 @@
+#!/usr/bin/env python
+from huggingface_hub import snapshot_download
+from argparse import ArgumentParser
+
+parser = ArgumentParser(description='Download Bicleaner AI full models from the Hugging Face Hub')
+parser.add_argument('model', type=str, help='Hugging Face Bicleaner AI model identifier (e.g. "bitextor/bicleaner-ai-full-en-fr")')
+parser.add_argument('-t', '--auth_token', default=None, type=str, help='Authentication token for private models downloading')
+
+args = parser.parse_args()
+
+snapshot_download(args.model, use_auth_token=args.auth_token)
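
Likewise, a full model can be fetched straight into the local Hugging Face cache with, e.g., `bicleaner-ai-download-hf bitextor/bicleaner-ai-full-en-fr`, adding `-t <token>` for private models.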
diff --git a/setup.py b/setup.py
index 26bb1e0..21508fd 100755
--- a/setup.py
+++ b/setup.py
@@ -9,7 +9,7 @@ with open("requirements.txt") as rf:
setuptools.setup(
name="bicleaner-ai",
- version="1.0.2",
+ version="2.0",
install_requires=requirements,
license="GNU General Public License v3.0",
author="Prompsit Language Engineering",
@@ -41,5 +41,7 @@ setuptools.setup(
scripts=[
"scripts/bicleaner-ai-classify",
"scripts/bicleaner-ai-train",
+ "scripts/bicleaner-ai-download",
+ "scripts/bicleaner-ai-download-hf",
]
)
diff --git a/utils/download-pack.sh b/utils/download-pack.sh
deleted file mode 100755
index 357fad2..0000000
--- a/utils/download-pack.sh
+++ /dev/null
@@ -1,58 +0,0 @@
-#!/bin/bash
-
-usage() {
- echo "Script to download Bicleaner AI language packs."
- echo "It will try to download {lite,full}-lang1-lang2.tgz and if it does not exist it will try {lite,full}-lang2-lang1.tgz ."
- echo
- echo "Usage: `basename $0` <lang1> <lang2> <download_path>"
- echo " <lang1> Language 1."
- echo " <lang2> Language 2."
- echo " {lite,full} Download lite or full model."
- echo " <download_path> Path where downloaded language pack should be placed."
-}
-
-invalid_url(){
- wget -S --spider -o - $1 | grep -q '404 Not Found'
-}
-
-if [[ $# -lt 3 ]]
-then
- echo "Wrong number of arguments: $@" >&2
- usage >&2
- exit 1
-fi
-
-URL="https://github.com/bitextor/bicleaner-ai-data/releases/latest/download"
-L1=$1
-L2=$2
-if [ "$3" != "lite" ] && [ "$3" != "full" ]; then
- echo "Model type must be 'lite' or 'full' not '$3'" 1>&2
- usage >&2
- exit 1
-fi
-TYPE=$3
-if [ "$4" != "" ]; then
- DOWNLOAD_PATH=$4
-else
- DOWNLOAD_PATH="."
-fi
-
-
-if invalid_url $URL/$TYPE-$L1-$L2.tgz
-then
- >&2 echo $L1-$L2 language pack does not exist, trying $L2-$L1...
- if invalid_url $URL/$TYPE-$L2-$L1.tgz
- then
- >&2 echo $L2-$L1 language pack does not exist
- else
- wget -P $DOWNLOAD_PATH $URL/$TYPE-$L2-$L1.tgz
- tar xvf $DOWNLOAD_PATH/$TYPE-$L2-$L1.tgz -C $DOWNLOAD_PATH
- rm $DOWNLOAD_PATH/$TYPE-$L2-$L1.tgz
- fi
-else
- wget -P $DOWNLOAD_PATH $URL/$TYPE-$L1-$L2.tgz
- tar xvf $DOWNLOAD_PATH/$TYPE-$L1-$L2.tgz -C $DOWNLOAD_PATH
- rm $DOWNLOAD_PATH/$TYPE-$L1-$L2.tgz
-fi
-
-echo Finished