diff options
author | Taku Kudo <taku@google.com> | 2020-10-23 18:20:11 +0300 |
---|---|---|
committer | Taku Kudo <taku@google.com> | 2020-10-23 18:20:11 +0300 |
commit | 910f804f720632caf4c31a66d09a6ca68bc1f287 (patch) | |
tree | 4046b1ffd5d2af6d2f8c83093badbe00b299849b /src/util.cc | |
parent | d8c4b0405649d788a46744c646be85d83823f01c (diff) |
validate the range of piece in Python module
Diffstat (limited to 'src/util.cc')
-rw-r--r-- | src/util.cc | 19 |
1 files changed, 15 insertions, 4 deletions
diff --git a/src/util.cc b/src/util.cc index d3946e1..e9ef6e6 100644 --- a/src/util.cc +++ b/src/util.cc @@ -12,11 +12,20 @@ // See the License for the specific language governing permissions and // limitations under the License.! -#include <iostream> - #include "util.h" +#include <iostream> + namespace sentencepiece { +namespace { +constexpr unsigned int kDefaultSeed = static_cast<unsigned int>(-1); +static unsigned int g_seed = kDefaultSeed; +} // namespace + +void SetRandomGeneratorSeed(unsigned int seed) { + if (seed != kDefaultSeed) g_seed = seed; +} + namespace string_util { // mblen sotres the number of bytes consumed after decoding. @@ -144,7 +153,8 @@ class RandomGeneratorStorage { std::mt19937 *Get() { auto *result = static_cast<std::mt19937 *>(pthread_getspecific(key_)); if (result == nullptr) { - result = new std::mt19937(std::random_device{}()); + result = new std::mt19937(g_seed == kDefaultSeed ? std::random_device{}() + : g_seed); pthread_setspecific(key_, result); } return result; @@ -162,7 +172,8 @@ std::mt19937 *GetRandomGenerator() { } #else std::mt19937 *GetRandomGenerator() { - thread_local static std::mt19937 mt(std::random_device{}()); + thread_local static std::mt19937 mt( + g_seed == kDefaultSeed ? std::random_device{}() : g_seed); return &mt; } #endif |