diff options
author | Taku Kudo <taku@google.com> | 2020-10-24 03:33:02 +0300 |
---|---|---|
committer | Taku Kudo <taku@google.com> | 2020-10-24 03:33:02 +0300 |
commit | 5d79c12900a166b78fb4ce856940a1dc5df8d29f (patch) | |
tree | c20e1bd544812296c8be309a94cd8a5ddc8d5f5d | |
parent | 63211c130e320477ebf20e0895f73253a97d340d (diff) |
add SetRandomGeneratorSeed
-rw-r--r-- | python/src/sentencepiece/__init__.py | 4 | ||||
-rw-r--r-- | python/src/sentencepiece/sentencepiece.i | 1 | ||||
-rw-r--r-- | python/src/sentencepiece/sentencepiece_wrap.cxx | 95 | ||||
-rw-r--r-- | src/spm_encode_main.cc | 3 | ||||
-rw-r--r-- | src/spm_train_main.cc | 4 | ||||
-rw-r--r-- | src/util.cc | 10 | ||||
-rw-r--r-- | src/util.h | 4 |
7 files changed, 116 insertions, 5 deletions
diff --git a/python/src/sentencepiece/__init__.py b/python/src/sentencepiece/__init__.py index 001ffc7..566f810 100644 --- a/python/src/sentencepiece/__init__.py +++ b/python/src/sentencepiece/__init__.py @@ -370,6 +370,9 @@ class SentencePieceProcessor(object): # Register SentencePieceProcessor in _sentencepiece: _sentencepiece.SentencePieceProcessor_swigregister(SentencePieceProcessor) + +def SetRandomGeneratorSeed(seed): + return _sentencepiece.SetRandomGeneratorSeed(seed) class SentencePieceTrainer(object): thisown = property(lambda x: x.this.own(), lambda x, v: x.this.own(v), doc="The membership flag") @@ -516,6 +519,7 @@ for m in [ _add_snake_case(SentencePieceProcessor) _add_snake_case(SentencePieceTrainer) +set_random_generator_seed = SetRandomGeneratorSeed diff --git a/python/src/sentencepiece/sentencepiece.i b/python/src/sentencepiece/sentencepiece.i index 6522d1f..40938e4 100644 --- a/python/src/sentencepiece/sentencepiece.i +++ b/python/src/sentencepiece/sentencepiece.i @@ -740,4 +740,5 @@ for m in [ _add_snake_case(SentencePieceProcessor) _add_snake_case(SentencePieceTrainer) +set_random_generator_seed = SetRandomGeneratorSeed %} diff --git a/python/src/sentencepiece/sentencepiece_wrap.cxx b/python/src/sentencepiece/sentencepiece_wrap.cxx index 7e2e85d..a358b39 100644 --- a/python/src/sentencepiece/sentencepiece_wrap.cxx +++ b/python/src/sentencepiece/sentencepiece_wrap.cxx @@ -3301,6 +3301,70 @@ SWIGINTERN sentencepiece::util::bytes sentencepiece_SentencePieceProcessor_Decod "piece id is out of range."); return self->DecodeIdsAsSerializedProto(ids); } + +SWIGINTERN int +SWIG_AsVal_unsigned_SS_long (PyObject *obj, unsigned long *val) +{ +#if PY_VERSION_HEX < 0x03000000 + if (PyInt_Check(obj)) { + long v = PyInt_AsLong(obj); + if (v >= 0) { + if (val) *val = v; + return SWIG_OK; + } else { + return SWIG_OverflowError; + } + } else +#endif + if (PyLong_Check(obj)) { + unsigned long v = PyLong_AsUnsignedLong(obj); + if (!PyErr_Occurred()) { + if (val) 
*val = v; + return SWIG_OK; + } else { + PyErr_Clear(); + return SWIG_OverflowError; + } + } +#ifdef SWIG_PYTHON_CAST_MODE + { + int dispatch = 0; + unsigned long v = PyLong_AsUnsignedLong(obj); + if (!PyErr_Occurred()) { + if (val) *val = v; + return SWIG_AddCast(SWIG_OK); + } else { + PyErr_Clear(); + } + if (!dispatch) { + double d; + int res = SWIG_AddCast(SWIG_AsVal_double (obj,&d)); + if (SWIG_IsOK(res) && SWIG_CanCastAsInteger(&d, 0, ULONG_MAX)) { + if (val) *val = (unsigned long)(d); + return res; + } + } + } +#endif + return SWIG_TypeError; +} + + +SWIGINTERN int +SWIG_AsVal_unsigned_SS_int (PyObject * obj, unsigned int *val) +{ + unsigned long v; + int res = SWIG_AsVal_unsigned_SS_long (obj, &v); + if (SWIG_IsOK(res)) { + if ((v > UINT_MAX)) { + return SWIG_OverflowError; + } else { + if (val) *val = static_cast< unsigned int >(v); + } + } + return res; +} + SWIGINTERN void sentencepiece_SentencePieceTrainer__TrainFromString(absl::string_view arg){ const auto _status = sentencepiece::SentencePieceTrainer::Train(arg); if (!_status.ok()) throw _status; @@ -4977,6 +5041,36 @@ SWIGINTERN PyObject *SentencePieceProcessor_swiginit(PyObject *SWIGUNUSEDPARM(se return SWIG_Python_InitShadowInstance(args); } +SWIGINTERN PyObject *_wrap_SetRandomGeneratorSeed(PyObject *SWIGUNUSEDPARM(self), PyObject *args) { + PyObject *resultobj = 0; + unsigned int arg1 ; + unsigned int val1 ; + int ecode1 = 0 ; + PyObject *swig_obj[1] ; + + if (!args) SWIG_fail; + swig_obj[0] = args; + ecode1 = SWIG_AsVal_unsigned_SS_int(swig_obj[0], &val1); + if (!SWIG_IsOK(ecode1)) { + SWIG_exception_fail(SWIG_ArgError(ecode1), "in method '" "SetRandomGeneratorSeed" "', argument " "1"" of type '" "unsigned int""'"); + } + arg1 = static_cast< unsigned int >(val1); + { + try { + sentencepiece::SetRandomGeneratorSeed(arg1); + ReleaseResultObject(resultobj); + } + catch (const sentencepiece::util::Status &status) { + SWIG_exception(ToSwigError(status.code()), status.ToString().c_str()); + } + } + 
resultobj = SWIG_Py_Void(); + return resultobj; +fail: + return NULL; +} + + SWIGINTERN PyObject *_wrap_SentencePieceTrainer__TrainFromString(PyObject *SWIGUNUSEDPARM(self), PyObject *args) { PyObject *resultobj = 0; absl::string_view arg1 ; @@ -5307,6 +5401,7 @@ static PyMethodDef SwigMethods[] = { { "SentencePieceProcessor_DecodeIdsAsSerializedProtoWithCheck", _wrap_SentencePieceProcessor_DecodeIdsAsSerializedProtoWithCheck, METH_VARARGS, NULL}, { "SentencePieceProcessor_swigregister", SentencePieceProcessor_swigregister, METH_O, NULL}, { "SentencePieceProcessor_swiginit", SentencePieceProcessor_swiginit, METH_VARARGS, NULL}, + { "SetRandomGeneratorSeed", _wrap_SetRandomGeneratorSeed, METH_O, NULL}, { "SentencePieceTrainer__TrainFromString", _wrap_SentencePieceTrainer__TrainFromString, METH_O, NULL}, { "SentencePieceTrainer__TrainFromMap", _wrap_SentencePieceTrainer__TrainFromMap, METH_O, NULL}, { "SentencePieceTrainer__TrainFromMap2", _wrap_SentencePieceTrainer__TrainFromMap2, METH_VARARGS, NULL}, diff --git a/src/spm_encode_main.cc b/src/spm_encode_main.cc index c0c94db..f151ecf 100644 --- a/src/spm_encode_main.cc +++ b/src/spm_encode_main.cc @@ -61,6 +61,9 @@ int main(int argc, char *argv[]) { rest_args.push_back(absl::GetFlag(FLAGS_input)); } + if (absl::GetFlag(FLAGS_random_seed) != -1) + sentencepiece::SetRandomGeneratorSeed(absl::GetFlag(FLAGS_random_seed)); + if (rest_args.empty()) rest_args.push_back(""); // empty means that reading from stdin. 
diff --git a/src/spm_train_main.cc b/src/spm_train_main.cc index 847b7e7..b72a138 100644 --- a/src/spm_train_main.cc +++ b/src/spm_train_main.cc @@ -137,6 +137,7 @@ ABSL_FLAG(std::string, unk_surface, kDefaultTrainerSpec.unk_surface(), ABSL_FLAG(bool, train_extremely_large_corpus, kDefaultTrainerSpec.train_extremely_large_corpus(), "Increase bit depth for unigram tokenization."); +ABSL_FLAG(int32, random_seed, -1, "Seed value for random generator."); int main(int argc, char *argv[]) { sentencepiece::ParseCommandLineFlags(argv[0], &argc, &argv, true); @@ -148,6 +149,9 @@ int main(int argc, char *argv[]) { CHECK(!absl::GetFlag(FLAGS_input).empty()); CHECK(!absl::GetFlag(FLAGS_model_prefix).empty()); + if (absl::GetFlag(FLAGS_random_seed) != -1) + sentencepiece::SetRandomGeneratorSeed(absl::GetFlag(FLAGS_random_seed)); + auto load_lines = [](absl::string_view filename) { std::vector<std::string> lines; auto input = sentencepiece::filesystem::NewReadableFile(filename); diff --git a/src/util.cc b/src/util.cc index e9ef6e6..58225ae 100644 --- a/src/util.cc +++ b/src/util.cc @@ -26,6 +26,10 @@ void SetRandomGeneratorSeed(unsigned int seed) { if (seed != kDefaultSeed) g_seed = seed; } +uint32 GetRandomGeneratorSeed() { + return g_seed == kDefaultSeed ? std::random_device{}() : g_seed; +} + namespace string_util { // mblen sotres the number of bytes consumed after decoding. @@ -153,8 +157,7 @@ class RandomGeneratorStorage { std::mt19937 *Get() { auto *result = static_cast<std::mt19937 *>(pthread_getspecific(key_)); if (result == nullptr) { - result = new std::mt19937(g_seed == kDefaultSeed ? std::random_device{}() - : g_seed); + result = new std::mt19937(GetRandomGeneratorSeed()); pthread_setspecific(key_, result); } return result; @@ -172,8 +175,7 @@ std::mt19937 *GetRandomGenerator() { } #else std::mt19937 *GetRandomGenerator() { - thread_local static std::mt19937 mt( - g_seed == kDefaultSeed ? 
std::random_device{}() : g_seed); + thread_local static std::mt19937 mt(GetRandomGeneratorSeed()); return &mt; } #endif diff --git a/src/util.h b/src/util.h --- a/src/util.h +++ b/src/util.h @@ -49,6 +49,8 @@ std::ostream &operator<<(std::ostream &out, const std::vector<T> &v) { return out; } +uint32 GetRandomGeneratorSeed(); + // String utilities namespace string_util { @@ -306,7 +308,7 @@ template <typename T> class ReservoirSampler { public: explicit ReservoirSampler(std::vector<T> *sampled, size_t size) - : sampled_(sampled), size_(size), engine_(std::random_device{}()) {} + : sampled_(sampled), size_(size), engine_(GetRandomGeneratorSeed()) {} explicit ReservoirSampler(std::vector<T> *sampled, size_t size, size_t seed) : sampled_(sampled), size_(size), engine_(seed) {} virtual ~ReservoirSampler() {} |