Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/sentencepiece.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTaku Kudo <taku@google.com>2020-10-23 18:20:11 +0300
committerTaku Kudo <taku@google.com>2020-10-23 18:20:11 +0300
commit910f804f720632caf4c31a66d09a6ca68bc1f287 (patch)
tree4046b1ffd5d2af6d2f8c83093badbe00b299849b /src/util.cc
parentd8c4b0405649d788a46744c646be85d83823f01c (diff)
validate the range of piece in Python module
Diffstat (limited to 'src/util.cc')
-rw-r--r--src/util.cc19
1 files changed, 15 insertions, 4 deletions
diff --git a/src/util.cc b/src/util.cc
index d3946e1..e9ef6e6 100644
--- a/src/util.cc
+++ b/src/util.cc
@@ -12,11 +12,20 @@
// See the License for the specific language governing permissions and
// limitations under the License.!
-#include <iostream>
-
#include "util.h"
+#include <iostream>
+
namespace sentencepiece {
+namespace {
+constexpr unsigned int kDefaultSeed = static_cast<unsigned int>(-1);
+static unsigned int g_seed = kDefaultSeed;
+} // namespace
+
+void SetRandomGeneratorSeed(unsigned int seed) {
+ if (seed != kDefaultSeed) g_seed = seed;
+}
+
namespace string_util {
// mblen sotres the number of bytes consumed after decoding.
@@ -144,7 +153,8 @@ class RandomGeneratorStorage {
std::mt19937 *Get() {
auto *result = static_cast<std::mt19937 *>(pthread_getspecific(key_));
if (result == nullptr) {
- result = new std::mt19937(std::random_device{}());
+ result = new std::mt19937(g_seed == kDefaultSeed ? std::random_device{}()
+ : g_seed);
pthread_setspecific(key_, result);
}
return result;
@@ -162,7 +172,8 @@ std::mt19937 *GetRandomGenerator() {
}
#else
std::mt19937 *GetRandomGenerator() {
- thread_local static std::mt19937 mt(std::random_device{}());
+ thread_local static std::mt19937 mt(
+ g_seed == kDefaultSeed ? std::random_device{}() : g_seed);
return &mt;
}
#endif