From 5d79c12900a166b78fb4ce856940a1dc5df8d29f Mon Sep 17 00:00:00 2001 From: Taku Kudo Date: Sat, 24 Oct 2020 09:33:02 +0900 Subject: add SetRandomGeneratorSeed --- python/src/sentencepiece/__init__.py | 4 ++ python/src/sentencepiece/sentencepiece.i | 1 + python/src/sentencepiece/sentencepiece_wrap.cxx | 95 +++++++++++++++++++++++++ src/spm_encode_main.cc | 3 + src/spm_train_main.cc | 4 ++ src/util.cc | 10 +-- src/util.h | 4 +- 7 files changed, 116 insertions(+), 5 deletions(-) diff --git a/python/src/sentencepiece/__init__.py b/python/src/sentencepiece/__init__.py index 001ffc7..566f810 100644 --- a/python/src/sentencepiece/__init__.py +++ b/python/src/sentencepiece/__init__.py @@ -370,6 +370,9 @@ class SentencePieceProcessor(object): # Register SentencePieceProcessor in _sentencepiece: _sentencepiece.SentencePieceProcessor_swigregister(SentencePieceProcessor) + +def SetRandomGeneratorSeed(seed): + return _sentencepiece.SetRandomGeneratorSeed(seed) class SentencePieceTrainer(object): thisown = property(lambda x: x.this.own(), lambda x, v: x.this.own(v), doc="The membership flag") @@ -516,6 +519,7 @@ for m in [ _add_snake_case(SentencePieceProcessor) _add_snake_case(SentencePieceTrainer) +set_random_generator_seed = SetRandomGeneratorSeed diff --git a/python/src/sentencepiece/sentencepiece.i b/python/src/sentencepiece/sentencepiece.i index 6522d1f..40938e4 100644 --- a/python/src/sentencepiece/sentencepiece.i +++ b/python/src/sentencepiece/sentencepiece.i @@ -740,4 +740,5 @@ for m in [ _add_snake_case(SentencePieceProcessor) _add_snake_case(SentencePieceTrainer) +set_random_generator_seed = SetRandomGeneratorSeed %} diff --git a/python/src/sentencepiece/sentencepiece_wrap.cxx b/python/src/sentencepiece/sentencepiece_wrap.cxx index 7e2e85d..a358b39 100644 --- a/python/src/sentencepiece/sentencepiece_wrap.cxx +++ b/python/src/sentencepiece/sentencepiece_wrap.cxx @@ -3301,6 +3301,70 @@ SWIGINTERN sentencepiece::util::bytes sentencepiece_SentencePieceProcessor_Decod "piece id is out of range."); return self->DecodeIdsAsSerializedProto(ids); } + +SWIGINTERN int +SWIG_AsVal_unsigned_SS_long (PyObject *obj, unsigned long *val) +{ +#if PY_VERSION_HEX < 0x03000000 + if (PyInt_Check(obj)) { + long v = PyInt_AsLong(obj); + if (v >= 0) { + if (val) *val = v; + return SWIG_OK; + } else { + return SWIG_OverflowError; + } + } else +#endif + if (PyLong_Check(obj)) { + unsigned long v = PyLong_AsUnsignedLong(obj); + if (!PyErr_Occurred()) { + if (val) *val = v; + return SWIG_OK; + } else { + PyErr_Clear(); + return SWIG_OverflowError; + } + } +#ifdef SWIG_PYTHON_CAST_MODE + { + int dispatch = 0; + unsigned long v = PyLong_AsUnsignedLong(obj); + if (!PyErr_Occurred()) { + if (val) *val = v; + return SWIG_AddCast(SWIG_OK); + } else { + PyErr_Clear(); + } + if (!dispatch) { + double d; + int res = SWIG_AddCast(SWIG_AsVal_double (obj,&d)); + if (SWIG_IsOK(res) && SWIG_CanCastAsInteger(&d, 0, ULONG_MAX)) { + if (val) *val = (unsigned long)(d); + return res; + } + } + } +#endif + return SWIG_TypeError; +} + + +SWIGINTERN int +SWIG_AsVal_unsigned_SS_int (PyObject * obj, unsigned int *val) +{ + unsigned long v; + int res = SWIG_AsVal_unsigned_SS_long (obj, &v); + if (SWIG_IsOK(res)) { + if ((v > UINT_MAX)) { + return SWIG_OverflowError; + } else { + if (val) *val = static_cast< unsigned int >(v); + } + } + return res; +} + SWIGINTERN void sentencepiece_SentencePieceTrainer__TrainFromString(absl::string_view arg){ const auto _status = sentencepiece::SentencePieceTrainer::Train(arg); if (!_status.ok()) throw _status; @@ -4977,6 +5041,36 @@ SWIGINTERN PyObject *SentencePieceProcessor_swiginit(PyObject *SWIGUNUSEDPARM(se return SWIG_Python_InitShadowInstance(args); } +SWIGINTERN PyObject *_wrap_SetRandomGeneratorSeed(PyObject *SWIGUNUSEDPARM(self), PyObject *args) { + PyObject *resultobj = 0; + unsigned int arg1 ; + unsigned int val1 ; + int ecode1 = 0 ; + PyObject *swig_obj[1] ; + + if (!args) SWIG_fail; + swig_obj[0] = args; + ecode1 = SWIG_AsVal_unsigned_SS_int(swig_obj[0], &val1); + if (!SWIG_IsOK(ecode1)) { + SWIG_exception_fail(SWIG_ArgError(ecode1), "in method '" "SetRandomGeneratorSeed" "', argument " "1"" of type '" "unsigned int""'"); + } + arg1 = static_cast< unsigned int >(val1); + { + try { + sentencepiece::SetRandomGeneratorSeed(arg1); + ReleaseResultObject(resultobj); + } + catch (const sentencepiece::util::Status &status) { + SWIG_exception(ToSwigError(status.code()), status.ToString().c_str()); + } + } + resultobj = SWIG_Py_Void(); + return resultobj; +fail: + return NULL; +} + + SWIGINTERN PyObject *_wrap_SentencePieceTrainer__TrainFromString(PyObject *SWIGUNUSEDPARM(self), PyObject *args) { PyObject *resultobj = 0; absl::string_view arg1 ; @@ -5307,6 +5401,7 @@ static PyMethodDef SwigMethods[] = { { "SentencePieceProcessor_DecodeIdsAsSerializedProtoWithCheck", _wrap_SentencePieceProcessor_DecodeIdsAsSerializedProtoWithCheck, METH_VARARGS, NULL}, { "SentencePieceProcessor_swigregister", SentencePieceProcessor_swigregister, METH_O, NULL}, { "SentencePieceProcessor_swiginit", SentencePieceProcessor_swiginit, METH_VARARGS, NULL}, + { "SetRandomGeneratorSeed", _wrap_SetRandomGeneratorSeed, METH_O, NULL}, { "SentencePieceTrainer__TrainFromString", _wrap_SentencePieceTrainer__TrainFromString, METH_O, NULL}, { "SentencePieceTrainer__TrainFromMap", _wrap_SentencePieceTrainer__TrainFromMap, METH_O, NULL}, { "SentencePieceTrainer__TrainFromMap2", _wrap_SentencePieceTrainer__TrainFromMap2, METH_VARARGS, NULL}, diff --git a/src/spm_encode_main.cc b/src/spm_encode_main.cc index c0c94db..f151ecf 100644 --- a/src/spm_encode_main.cc +++ b/src/spm_encode_main.cc @@ -61,6 +61,9 @@ int main(int argc, char *argv[]) { rest_args.push_back(absl::GetFlag(FLAGS_input)); } + if (absl::GetFlag(FLAGS_random_seed) != -1) + sentencepiece::SetRandomGeneratorSeed(absl::GetFlag(FLAGS_random_seed)); + if (rest_args.empty()) rest_args.push_back(""); // empty means that reading from stdin. diff --git a/src/spm_train_main.cc b/src/spm_train_main.cc index 847b7e7..b72a138 100644 --- a/src/spm_train_main.cc +++ b/src/spm_train_main.cc @@ -137,6 +137,7 @@ ABSL_FLAG(std::string, unk_surface, kDefaultTrainerSpec.unk_surface(), ABSL_FLAG(bool, train_extremely_large_corpus, kDefaultTrainerSpec.train_extremely_large_corpus(), "Increase bit depth for unigram tokenization."); +ABSL_FLAG(int32, random_seed, -1, "Seed value for random generator."); int main(int argc, char *argv[]) { sentencepiece::ParseCommandLineFlags(argv[0], &argc, &argv, true); @@ -148,6 +149,9 @@ int main(int argc, char *argv[]) { CHECK(!absl::GetFlag(FLAGS_input).empty()); CHECK(!absl::GetFlag(FLAGS_model_prefix).empty()); + if (absl::GetFlag(FLAGS_random_seed) != -1) + sentencepiece::SetRandomGeneratorSeed(absl::GetFlag(FLAGS_random_seed)); + auto load_lines = [](absl::string_view filename) { std::vector lines; auto input = sentencepiece::filesystem::NewReadableFile(filename); diff --git a/src/util.cc b/src/util.cc index e9ef6e6..58225ae 100644 --- a/src/util.cc +++ b/src/util.cc @@ -26,6 +26,10 @@ void SetRandomGeneratorSeed(unsigned int seed) { if (seed != kDefaultSeed) g_seed = seed; } +uint32 GetRandomGeneratorSeed() { + return g_seed == kDefaultSeed ? std::random_device{}() : g_seed; +} + namespace string_util { // mblen sotres the number of bytes consumed after decoding. @@ -153,8 +157,7 @@ class RandomGeneratorStorage { std::mt19937 *Get() { auto *result = static_cast(pthread_getspecific(key_)); if (result == nullptr) { - result = new std::mt19937(g_seed == kDefaultSeed ? std::random_device{}() - : g_seed); + result = new std::mt19937(GetRandomGeneratorSeed()); pthread_setspecific(key_, result); } return result; @@ -172,8 +175,7 @@ std::mt19937 *GetRandomGenerator() { } #else std::mt19937 *GetRandomGenerator() { - thread_local static std::mt19937 mt( - g_seed == kDefaultSeed ? std::random_device{}() : g_seed); + thread_local static std::mt19937 mt(GetRandomGeneratorSeed()); return &mt; } #endif diff --git a/src/util.h b/src/util.h index 176a363..673e8f6 100644 --- a/src/util.h +++ b/src/util.h @@ -49,6 +49,8 @@ std::ostream &operator<<(std::ostream &out, const std::vector &v) { return out; } +uint32 GetRandomGeneratorSeed(); + // String utilities namespace string_util { @@ -306,7 +308,7 @@ template class ReservoirSampler { public: explicit ReservoirSampler(std::vector *sampled, size_t size) - : sampled_(sampled), size_(size), engine_(std::random_device{}()) {} + : sampled_(sampled), size_(size), engine_(GetRandomGeneratorSeed()) {} explicit ReservoirSampler(std::vector *sampled, size_t size, size_t seed) : sampled_(sampled), size_(size), engine_(seed) {} virtual ~ReservoirSampler() {} -- cgit v1.2.3