Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/bitextor/bicleaner-ai.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorZJaume <jzaragoza@prompsit.com>2022-09-15 17:16:39 +0300
committerZJaume <jzaragoza@prompsit.com>2022-09-15 17:16:39 +0300
commita2eb410cdb7c8c39736ec798c276308d7893581f (patch)
treee5be9e7518aaf74538ef87b41e6463e2ac82a0d4
parent3dd57097af9b0b1e6076d6baf03c94a7a5704ebd (diff)
Set SentencePiece seed
-rwxr-xr-xbicleaner_ai/bicleaner_ai_train.py2
1 files changed, 2 insertions, 0 deletions
diff --git a/bicleaner_ai/bicleaner_ai_train.py b/bicleaner_ai/bicleaner_ai_train.py
index 32e5018..a9acc4a 100755
--- a/bicleaner_ai/bicleaner_ai_train.py
+++ b/bicleaner_ai/bicleaner_ai_train.py
@@ -12,6 +12,7 @@ if 'BICLEANER_AI_THREADS' in os.environ:
from tempfile import TemporaryFile, NamedTemporaryFile, gettempdir
from multiprocessing import cpu_count
from timeit import default_timer
+import sentencepiece as spm
import tensorflow as tf
import numpy as np
import argparse
@@ -112,6 +113,7 @@ def initialization():
random.seed(args.seed)
os.environ["PYTHONHASHSEED"] = str(args.seed)
tf.random.seed = args.seed
+ spm.set_random_generator_seed(args.seed)
if args.gpu is not None:
os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)