Welcome to the mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/sentencepiece.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author: Taku Kudo <taku@google.com> 2020-10-24 03:33:02 +0300
committer: Taku Kudo <taku@google.com> 2020-10-24 03:33:02 +0300
commit: 5d79c12900a166b78fb4ce856940a1dc5df8d29f (patch)
tree: c20e1bd544812296c8be309a94cd8a5ddc8d5f5d
parent: 63211c130e320477ebf20e0895f73253a97d340d (diff)
add SetRandomGeneratorSeed
-rw-r--r-- python/src/sentencepiece/__init__.py | 4
-rw-r--r-- python/src/sentencepiece/sentencepiece.i | 1
-rw-r--r-- python/src/sentencepiece/sentencepiece_wrap.cxx | 95
-rw-r--r-- src/spm_encode_main.cc | 3
-rw-r--r-- src/spm_train_main.cc | 4
-rw-r--r-- src/util.cc | 10
-rw-r--r-- src/util.h | 4
7 files changed, 116 insertions, 5 deletions
diff --git a/python/src/sentencepiece/__init__.py b/python/src/sentencepiece/__init__.py
index 001ffc7..566f810 100644
--- a/python/src/sentencepiece/__init__.py
+++ b/python/src/sentencepiece/__init__.py
@@ -370,6 +370,9 @@ class SentencePieceProcessor(object):
# Register SentencePieceProcessor in _sentencepiece:
_sentencepiece.SentencePieceProcessor_swigregister(SentencePieceProcessor)
+
+def SetRandomGeneratorSeed(seed):
+ return _sentencepiece.SetRandomGeneratorSeed(seed)
class SentencePieceTrainer(object):
thisown = property(lambda x: x.this.own(), lambda x, v: x.this.own(v), doc="The membership flag")
@@ -516,6 +519,7 @@ for m in [
_add_snake_case(SentencePieceProcessor)
_add_snake_case(SentencePieceTrainer)
+set_random_generator_seed = SetRandomGeneratorSeed
diff --git a/python/src/sentencepiece/sentencepiece.i b/python/src/sentencepiece/sentencepiece.i
index 6522d1f..40938e4 100644
--- a/python/src/sentencepiece/sentencepiece.i
+++ b/python/src/sentencepiece/sentencepiece.i
@@ -740,4 +740,5 @@ for m in [
_add_snake_case(SentencePieceProcessor)
_add_snake_case(SentencePieceTrainer)
+set_random_generator_seed = SetRandomGeneratorSeed
%}
diff --git a/python/src/sentencepiece/sentencepiece_wrap.cxx b/python/src/sentencepiece/sentencepiece_wrap.cxx
index 7e2e85d..a358b39 100644
--- a/python/src/sentencepiece/sentencepiece_wrap.cxx
+++ b/python/src/sentencepiece/sentencepiece_wrap.cxx
@@ -3301,6 +3301,70 @@ SWIGINTERN sentencepiece::util::bytes sentencepiece_SentencePieceProcessor_Decod
"piece id is out of range.");
return self->DecodeIdsAsSerializedProto(ids);
}
+
+SWIGINTERN int
+SWIG_AsVal_unsigned_SS_long (PyObject *obj, unsigned long *val)
+{
+#if PY_VERSION_HEX < 0x03000000
+ if (PyInt_Check(obj)) {
+ long v = PyInt_AsLong(obj);
+ if (v >= 0) {
+ if (val) *val = v;
+ return SWIG_OK;
+ } else {
+ return SWIG_OverflowError;
+ }
+ } else
+#endif
+ if (PyLong_Check(obj)) {
+ unsigned long v = PyLong_AsUnsignedLong(obj);
+ if (!PyErr_Occurred()) {
+ if (val) *val = v;
+ return SWIG_OK;
+ } else {
+ PyErr_Clear();
+ return SWIG_OverflowError;
+ }
+ }
+#ifdef SWIG_PYTHON_CAST_MODE
+ {
+ int dispatch = 0;
+ unsigned long v = PyLong_AsUnsignedLong(obj);
+ if (!PyErr_Occurred()) {
+ if (val) *val = v;
+ return SWIG_AddCast(SWIG_OK);
+ } else {
+ PyErr_Clear();
+ }
+ if (!dispatch) {
+ double d;
+ int res = SWIG_AddCast(SWIG_AsVal_double (obj,&d));
+ if (SWIG_IsOK(res) && SWIG_CanCastAsInteger(&d, 0, ULONG_MAX)) {
+ if (val) *val = (unsigned long)(d);
+ return res;
+ }
+ }
+ }
+#endif
+ return SWIG_TypeError;
+}
+
+
+SWIGINTERN int
+SWIG_AsVal_unsigned_SS_int (PyObject * obj, unsigned int *val)
+{
+ unsigned long v;
+ int res = SWIG_AsVal_unsigned_SS_long (obj, &v);
+ if (SWIG_IsOK(res)) {
+ if ((v > UINT_MAX)) {
+ return SWIG_OverflowError;
+ } else {
+ if (val) *val = static_cast< unsigned int >(v);
+ }
+ }
+ return res;
+}
+
SWIGINTERN void sentencepiece_SentencePieceTrainer__TrainFromString(absl::string_view arg){
const auto _status = sentencepiece::SentencePieceTrainer::Train(arg);
if (!_status.ok()) throw _status;
@@ -4977,6 +5041,36 @@ SWIGINTERN PyObject *SentencePieceProcessor_swiginit(PyObject *SWIGUNUSEDPARM(se
return SWIG_Python_InitShadowInstance(args);
}
+SWIGINTERN PyObject *_wrap_SetRandomGeneratorSeed(PyObject *SWIGUNUSEDPARM(self), PyObject *args) {
+ PyObject *resultobj = 0;
+ unsigned int arg1 ;
+ unsigned int val1 ;
+ int ecode1 = 0 ;
+ PyObject *swig_obj[1] ;
+
+ if (!args) SWIG_fail;
+ swig_obj[0] = args;
+ ecode1 = SWIG_AsVal_unsigned_SS_int(swig_obj[0], &val1);
+ if (!SWIG_IsOK(ecode1)) {
+ SWIG_exception_fail(SWIG_ArgError(ecode1), "in method '" "SetRandomGeneratorSeed" "', argument " "1"" of type '" "unsigned int""'");
+ }
+ arg1 = static_cast< unsigned int >(val1);
+ {
+ try {
+ sentencepiece::SetRandomGeneratorSeed(arg1);
+ ReleaseResultObject(resultobj);
+ }
+ catch (const sentencepiece::util::Status &status) {
+ SWIG_exception(ToSwigError(status.code()), status.ToString().c_str());
+ }
+ }
+ resultobj = SWIG_Py_Void();
+ return resultobj;
+fail:
+ return NULL;
+}
+
+
SWIGINTERN PyObject *_wrap_SentencePieceTrainer__TrainFromString(PyObject *SWIGUNUSEDPARM(self), PyObject *args) {
PyObject *resultobj = 0;
absl::string_view arg1 ;
@@ -5307,6 +5401,7 @@ static PyMethodDef SwigMethods[] = {
{ "SentencePieceProcessor_DecodeIdsAsSerializedProtoWithCheck", _wrap_SentencePieceProcessor_DecodeIdsAsSerializedProtoWithCheck, METH_VARARGS, NULL},
{ "SentencePieceProcessor_swigregister", SentencePieceProcessor_swigregister, METH_O, NULL},
{ "SentencePieceProcessor_swiginit", SentencePieceProcessor_swiginit, METH_VARARGS, NULL},
+ { "SetRandomGeneratorSeed", _wrap_SetRandomGeneratorSeed, METH_O, NULL},
{ "SentencePieceTrainer__TrainFromString", _wrap_SentencePieceTrainer__TrainFromString, METH_O, NULL},
{ "SentencePieceTrainer__TrainFromMap", _wrap_SentencePieceTrainer__TrainFromMap, METH_O, NULL},
{ "SentencePieceTrainer__TrainFromMap2", _wrap_SentencePieceTrainer__TrainFromMap2, METH_VARARGS, NULL},
diff --git a/src/spm_encode_main.cc b/src/spm_encode_main.cc
index c0c94db..f151ecf 100644
--- a/src/spm_encode_main.cc
+++ b/src/spm_encode_main.cc
@@ -61,6 +61,9 @@ int main(int argc, char *argv[]) {
rest_args.push_back(absl::GetFlag(FLAGS_input));
}
+ if (absl::GetFlag(FLAGS_random_seed) != -1)
+ sentencepiece::SetRandomGeneratorSeed(absl::GetFlag(FLAGS_random_seed));
+
if (rest_args.empty())
rest_args.push_back(""); // empty means that reading from stdin.
diff --git a/src/spm_train_main.cc b/src/spm_train_main.cc
index 847b7e7..b72a138 100644
--- a/src/spm_train_main.cc
+++ b/src/spm_train_main.cc
@@ -137,6 +137,7 @@ ABSL_FLAG(std::string, unk_surface, kDefaultTrainerSpec.unk_surface(),
ABSL_FLAG(bool, train_extremely_large_corpus,
kDefaultTrainerSpec.train_extremely_large_corpus(),
"Increase bit depth for unigram tokenization.");
+ABSL_FLAG(int32, random_seed, -1, "Seed value for random generator.");
int main(int argc, char *argv[]) {
sentencepiece::ParseCommandLineFlags(argv[0], &argc, &argv, true);
@@ -148,6 +149,9 @@ int main(int argc, char *argv[]) {
CHECK(!absl::GetFlag(FLAGS_input).empty());
CHECK(!absl::GetFlag(FLAGS_model_prefix).empty());
+ if (absl::GetFlag(FLAGS_random_seed) != -1)
+ sentencepiece::SetRandomGeneratorSeed(absl::GetFlag(FLAGS_random_seed));
+
auto load_lines = [](absl::string_view filename) {
std::vector<std::string> lines;
auto input = sentencepiece::filesystem::NewReadableFile(filename);
diff --git a/src/util.cc b/src/util.cc
index e9ef6e6..58225ae 100644
--- a/src/util.cc
+++ b/src/util.cc
@@ -26,6 +26,10 @@ void SetRandomGeneratorSeed(unsigned int seed) {
if (seed != kDefaultSeed) g_seed = seed;
}
+uint32 GetRandomGeneratorSeed() {
+ return g_seed == kDefaultSeed ? std::random_device{}() : g_seed;
+}
+
namespace string_util {
// mblen stores the number of bytes consumed after decoding.
@@ -153,8 +157,7 @@ class RandomGeneratorStorage {
std::mt19937 *Get() {
auto *result = static_cast<std::mt19937 *>(pthread_getspecific(key_));
if (result == nullptr) {
- result = new std::mt19937(g_seed == kDefaultSeed ? std::random_device{}()
- : g_seed);
+ result = new std::mt19937(GetRandomGeneratorSeed());
pthread_setspecific(key_, result);
}
return result;
@@ -172,8 +175,7 @@ std::mt19937 *GetRandomGenerator() {
}
#else
std::mt19937 *GetRandomGenerator() {
- thread_local static std::mt19937 mt(
- g_seed == kDefaultSeed ? std::random_device{}() : g_seed);
+ thread_local static std::mt19937 mt(GetRandomGeneratorSeed());
return &mt;
}
#endif
diff --git a/src/util.h b/src/util.h
index 176a363..673e8f6 100644
--- a/src/util.h
+++ b/src/util.h
@@ -49,6 +49,8 @@ std::ostream &operator<<(std::ostream &out, const std::vector<T> &v) {
return out;
}
+uint32 GetRandomGeneratorSeed();
+
// String utilities
namespace string_util {
@@ -306,7 +308,7 @@ template <typename T>
class ReservoirSampler {
public:
explicit ReservoirSampler(std::vector<T> *sampled, size_t size)
- : sampled_(sampled), size_(size), engine_(std::random_device{}()) {}
+ : sampled_(sampled), size_(size), engine_(GetRandomGeneratorSeed()) {}
explicit ReservoirSampler(std::vector<T> *sampled, size_t size, size_t seed)
: sampled_(sampled), size_(size), engine_(seed) {}
virtual ~ReservoirSampler() {}