Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/bitextor/bicleaner-ai.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author: Jaume Zaragoza <ZJaume@users.noreply.github.com> 2022-09-23 13:50:24 +0300
committer: GitHub <noreply@github.com> 2022-09-23 13:50:24 +0300
commit: 770e70c794441cf431287872a0818fdf07b14597 (patch)
tree: 933b7de684d326c36b6033fa1b301b69a975950a /bicleaner_ai
parent: a2eb410cdb7c8c39736ec798c276308d7893581f (diff)
Add test suite (#21)
* Basic test for full model training
* Extend full train test
* Add train lite test
* Ensure reproducibility of frequency noise
* Unit test for noise generation
* Add Tokenizer class test
* Remove old test corpus file
* Add classifier tests
  Download files on pytest setup to the test dir to avoid downloading them every time. Test normal, calibrated and raw modes.
* Download models only in classifier test
* Delete args object to avoid interference between tests
Diffstat (limited to 'bicleaner_ai')
-rwxr-xr-x  bicleaner_ai/bicleaner_ai_classifier.py  6
-rwxr-xr-x  bicleaner_ai/bicleaner_ai_train.py  10
-rw-r--r--  bicleaner_ai/training.py  2
3 files changed, 9 insertions, 9 deletions
diff --git a/bicleaner_ai/bicleaner_ai_classifier.py b/bicleaner_ai/bicleaner_ai_classifier.py
index a9afdab..ea11824 100755
--- a/bicleaner_ai/bicleaner_ai_classifier.py
+++ b/bicleaner_ai/bicleaner_ai_classifier.py
@@ -26,13 +26,13 @@ except (ImportError, SystemError):
logging_level = 0
-# All the scripts should have an initialization according with the usage. Template:
-def initialization():
+# Argument parsing
+def initialization(argv = None):
global logging_level
# Validating & parsing arguments
parser, groupO, _ = argument_parser()
- args = parser.parse_args()
+ args = parser.parse_args(argv)
# Set up logging
logging_setup(args)
diff --git a/bicleaner_ai/bicleaner_ai_train.py b/bicleaner_ai/bicleaner_ai_train.py
index a9acc4a..fbca32c 100755
--- a/bicleaner_ai/bicleaner_ai_train.py
+++ b/bicleaner_ai/bicleaner_ai_train.py
@@ -36,7 +36,7 @@ except (SystemError, ImportError):
logging_level = 0
# Argument parsing
-def initialization():
+def get_arguments(argv = None):
global logging_level
parser = argparse.ArgumentParser(prog=os.path.basename(sys.argv[0]), formatter_class=argparse.ArgumentDefaultsHelpFormatter, description=__doc__)
@@ -101,8 +101,9 @@ def initialization():
groupL.add_argument('--debug', action='store_true', help='Debug logging mode')
groupL.add_argument('--logfile', type=argparse.FileType('a'), default=sys.stderr, help="Store log to a file")
- args = parser.parse_args()
+ return parser.parse_args(argv)
+def initialization(args):
if args.freq_ratio > 0 and args.target_word_freqs is None:
raise Exception("Frequence based noise needs target language word frequencies")
if args.mono_train is None and args.classifier_type != 'xlmr':
@@ -152,8 +153,6 @@ def initialization():
else:
tf.get_logger().setLevel('CRITICAL')
- return args
-
# Main loop of the program
def perform_training(args):
time_start = default_timer()
@@ -226,7 +225,8 @@ def perform_training(args):
"model_name": model_name,
"batch_size": args.batch_size,
"epochs": args.epochs,
- "steps_per_epoch": args.steps_per_epoch
+ "steps_per_epoch": args.steps_per_epoch,
+ "vocab_size": args.vocab_size if 'vocab_size' in args else None,
}
# Avoid overriding settings with None
model_settings = {k:v for k,v in model_settings.items() if v is not None }
diff --git a/bicleaner_ai/training.py b/bicleaner_ai/training.py
index eb45840..f11cde3 100644
--- a/bicleaner_ai/training.py
+++ b/bicleaner_ai/training.py
@@ -223,7 +223,7 @@ def replace_freq_words(sentence, double_linked_zipf_freqs):
wfreq = double_linked_zipf_freqs.get_word_freq(w)
alternatives = double_linked_zipf_freqs.get_words_for_freq(wfreq)
if alternatives is not None:
- alternatives = list(alternatives)
+ alternatives = list(sorted(alternatives))
# Avoid replace with the same word
if w.lower() in alternatives: