diff options
author | Jaume Zaragoza <ZJaume@users.noreply.github.com> | 2022-09-23 13:50:24 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-09-23 13:50:24 +0300 |
commit | 770e70c794441cf431287872a0818fdf07b14597 (patch) | |
tree | 933b7de684d326c36b6033fa1b301b69a975950a /bicleaner_ai | |
parent | a2eb410cdb7c8c39736ec798c276308d7893581f (diff) |
Add test suite (#21)
* Basic test for full model training
* Extend full train test
* Add train lite test
* Ensure reproducibility of frequence noise
* Unit test for noise generation
* Add Tokenizer class test
* Remove old test corpus file
* Add classifier tests
Download files on pytest setup to the test dir to avoid downloading it
every time.
Test normal, calibrated and raw modes.
* Download models only in classifier test
* Delete args object to avoid interference between tests
Diffstat (limited to 'bicleaner_ai')
-rwxr-xr-x | bicleaner_ai/bicleaner_ai_classifier.py | 6 | ||||
-rwxr-xr-x | bicleaner_ai/bicleaner_ai_train.py | 10 | ||||
-rw-r--r-- | bicleaner_ai/training.py | 2 |
3 files changed, 9 insertions, 9 deletions
diff --git a/bicleaner_ai/bicleaner_ai_classifier.py b/bicleaner_ai/bicleaner_ai_classifier.py index a9afdab..ea11824 100755 --- a/bicleaner_ai/bicleaner_ai_classifier.py +++ b/bicleaner_ai/bicleaner_ai_classifier.py @@ -26,13 +26,13 @@ except (ImportError, SystemError): logging_level = 0 -# All the scripts should have an initialization according with the usage. Template: -def initialization(): +# Argument parsing +def initialization(argv = None): global logging_level # Validating & parsing arguments parser, groupO, _ = argument_parser() - args = parser.parse_args() + args = parser.parse_args(argv) # Set up logging logging_setup(args) diff --git a/bicleaner_ai/bicleaner_ai_train.py b/bicleaner_ai/bicleaner_ai_train.py index a9acc4a..fbca32c 100755 --- a/bicleaner_ai/bicleaner_ai_train.py +++ b/bicleaner_ai/bicleaner_ai_train.py @@ -36,7 +36,7 @@ except (SystemError, ImportError): logging_level = 0 # Argument parsing -def initialization(): +def get_arguments(argv = None): global logging_level parser = argparse.ArgumentParser(prog=os.path.basename(sys.argv[0]), formatter_class=argparse.ArgumentDefaultsHelpFormatter, description=__doc__) @@ -101,8 +101,9 @@ def initialization(): groupL.add_argument('--debug', action='store_true', help='Debug logging mode') groupL.add_argument('--logfile', type=argparse.FileType('a'), default=sys.stderr, help="Store log to a file") - args = parser.parse_args() + return parser.parse_args(argv) +def initialization(args): if args.freq_ratio > 0 and args.target_word_freqs is None: raise Exception("Frequence based noise needs target language word frequencies") if args.mono_train is None and args.classifier_type != 'xlmr': @@ -152,8 +153,6 @@ def initialization(): else: tf.get_logger().setLevel('CRITICAL') - return args - # Main loop of the program def perform_training(args): time_start = default_timer() @@ -226,7 +225,8 @@ def perform_training(args): "model_name": model_name, "batch_size": args.batch_size, "epochs": args.epochs, - "steps_per_epoch": args.steps_per_epoch + "steps_per_epoch": args.steps_per_epoch, + "vocab_size": args.vocab_size if 'vocab_size' in args else None, } # Avoid overriding settings with None model_settings = {k:v for k,v in model_settings.items() if v is not None } diff --git a/bicleaner_ai/training.py b/bicleaner_ai/training.py index eb45840..f11cde3 100644 --- a/bicleaner_ai/training.py +++ b/bicleaner_ai/training.py @@ -223,7 +223,7 @@ def replace_freq_words(sentence, double_linked_zipf_freqs): wfreq = double_linked_zipf_freqs.get_word_freq(w) alternatives = double_linked_zipf_freqs.get_words_for_freq(wfreq) if alternatives is not None: - alternatives = list(alternatives) + alternatives = list(sorted(alternatives)) # Avoid replace with the same word if w.lower() in alternatives: |