Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/sentencepiece.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTaku Kudo <taku@google.com>2020-05-20 07:45:49 +0300
committerTaku Kudo <taku@google.com>2020-05-20 07:45:49 +0300
commitd48247191a6d50e469ed1a4a36e877befffd1851 (patch)
tree83b5aba87746aaa7bd5dbafe26ed2628e0bbab74
parentb254e84528acdd1a2802d29922ae3496e8989be1 (diff)
0.1.91 pre-release
-rw-r--r--VERSION2
-rw-r--r--python/VERSION2
-rw-r--r--python/sentencepiece.i16
-rw-r--r--python/sentencepiece_wrap.cxx52
-rw-r--r--src/bpe_model.h4
-rw-r--r--src/builder.cc9
-rw-r--r--src/compile_charsmap_main.cc6
-rw-r--r--src/model_interface.h6
-rw-r--r--src/sentencepiece_processor.cc20
-rw-r--r--src/sentencepiece_processor.h12
-rw-r--r--src/sentencepiece_processor_test.cc4
-rw-r--r--src/sentencepiece_trainer.cc33
-rw-r--r--src/sentencepiece_trainer.h15
-rw-r--r--src/sentencepiece_trainer_test.cc3
-rw-r--r--src/trainer_interface.cc18
-rw-r--r--src/trainer_interface.h14
-rw-r--r--src/unigram_model.cc4
-rw-r--r--src/unigram_model.h4
-rw-r--r--third_party/absl/strings/str_replace.h1
19 files changed, 118 insertions, 107 deletions
diff --git a/VERSION b/VERSION
index 591c92d..496a825 100644
--- a/VERSION
+++ b/VERSION
@@ -1 +1 @@
-0.1.90
+0.1.91
diff --git a/python/VERSION b/python/VERSION
index 591c92d..496a825 100644
--- a/python/VERSION
+++ b/python/VERSION
@@ -1 +1 @@
-0.1.90
+0.1.91
diff --git a/python/sentencepiece.i b/python/sentencepiece.i
index 17c255e..ee79311 100644
--- a/python/sentencepiece.i
+++ b/python/sentencepiece.i
@@ -396,27 +396,27 @@ class PySentenceIterator : public sentencepiece::SentenceIterator {
return;
}
- static void _TrainFromMap(const std::map<std::string, std::string> &args) {
+ static void _TrainFromMap(const std::unordered_map<std::string, std::string> &args) {
const auto _status = sentencepiece::SentencePieceTrainer::Train(args);
if (!_status.ok()) throw _status;
return;
}
- static void _TrainFromMap2(const std::map<std::string, std::string> &args,
+ static void _TrainFromMap2(const std::unordered_map<std::string, std::string> &args,
SentenceIterator *iter) {
const auto _status = sentencepiece::SentencePieceTrainer::Train(args, iter);
if (!_status.ok()) throw _status;
return;
}
- static sentencepiece::util::bytes _TrainFromMap3(const std::map<std::string, std::string> &args) {
+ static sentencepiece::util::bytes _TrainFromMap3(const std::unordered_map<std::string, std::string> &args) {
sentencepiece::util::bytes model_proto;
const auto _status = sentencepiece::SentencePieceTrainer::Train(args, nullptr, &model_proto);
if (!_status.ok()) throw _status;
return model_proto;
}
- static sentencepiece::util::bytes _TrainFromMap4(const std::map<std::string, std::string> &args,
+ static sentencepiece::util::bytes _TrainFromMap4(const std::unordered_map<std::string, std::string> &args,
SentenceIterator *iter) {
sentencepiece::util::bytes model_proto;
const auto _status = sentencepiece::SentencePieceTrainer::Train(args, iter, &model_proto);
@@ -596,12 +596,12 @@ class PySentenceIterator : public sentencepiece::SentenceIterator {
$1 = out;
}
-%typemap(in) const std::map<std::string, std::string> & {
- std::map<std::string, std::string> *out = nullptr;
+%typemap(in) const std::unordered_map<std::string, std::string> & {
+ std::unordered_map<std::string, std::string> *out = nullptr;
if (PyDict_Check($input)) {
PyObject *key, *value;
Py_ssize_t pos = 0;
- out = new std::map<std::string, std::string>;
+ out = new std::unordered_map<std::string, std::string>;
while (PyDict_Next($input, &pos, &key, &value)) {
const PyInputString key_ustring(key);
const PyInputString value_ustring(value);
@@ -652,7 +652,7 @@ class PySentenceIterator : public sentencepiece::SentenceIterator {
delete $1;
}
-%typemap(freearg) const std::map<std::string, std::string> & {
+%typemap(freearg) const std::unordered_map<std::string, std::string> & {
delete $1;
}
diff --git a/python/sentencepiece_wrap.cxx b/python/sentencepiece_wrap.cxx
index 4ca73d4..bd7f6a1 100644
--- a/python/sentencepiece_wrap.cxx
+++ b/python/sentencepiece_wrap.cxx
@@ -2664,8 +2664,8 @@ SWIGINTERN PyObject *SWIG_PyStaticMethod_New(PyObject *SWIGUNUSEDPARM(self), PyO
#define SWIGTYPE_p_sentencepiece__SentenceIterator swig_types[1]
#define SWIGTYPE_p_sentencepiece__SentencePieceProcessor swig_types[2]
#define SWIGTYPE_p_sentencepiece__SentencePieceTrainer swig_types[3]
-#define SWIGTYPE_p_std__mapT_std__string_std__string_t swig_types[4]
-#define SWIGTYPE_p_std__string swig_types[5]
+#define SWIGTYPE_p_std__string swig_types[4]
+#define SWIGTYPE_p_std__unordered_mapT_std__string_std__string_t swig_types[5]
#define SWIGTYPE_p_std__vectorT_int_t swig_types[6]
#define SWIGTYPE_p_std__vectorT_std__string_t swig_types[7]
static swig_type_info *swig_types[9];
@@ -3290,23 +3290,23 @@ SWIGINTERN void sentencepiece_SentencePieceTrainer__TrainFromString(absl::string
if (!_status.ok()) throw _status;
return;
}
-SWIGINTERN void sentencepiece_SentencePieceTrainer__TrainFromMap(std::map< std::string,std::string > const &args){
+SWIGINTERN void sentencepiece_SentencePieceTrainer__TrainFromMap(std::unordered_map< std::string,std::string > const &args){
const auto _status = sentencepiece::SentencePieceTrainer::Train(args);
if (!_status.ok()) throw _status;
return;
}
-SWIGINTERN void sentencepiece_SentencePieceTrainer__TrainFromMap2(std::map< std::string,std::string > const &args,sentencepiece::SentenceIterator *iter){
+SWIGINTERN void sentencepiece_SentencePieceTrainer__TrainFromMap2(std::unordered_map< std::string,std::string > const &args,sentencepiece::SentenceIterator *iter){
const auto _status = sentencepiece::SentencePieceTrainer::Train(args, iter);
if (!_status.ok()) throw _status;
return;
}
-SWIGINTERN sentencepiece::util::bytes sentencepiece_SentencePieceTrainer__TrainFromMap3(std::map< std::string,std::string > const &args){
+SWIGINTERN sentencepiece::util::bytes sentencepiece_SentencePieceTrainer__TrainFromMap3(std::unordered_map< std::string,std::string > const &args){
sentencepiece::util::bytes model_proto;
const auto _status = sentencepiece::SentencePieceTrainer::Train(args, nullptr, &model_proto);
if (!_status.ok()) throw _status;
return model_proto;
}
-SWIGINTERN sentencepiece::util::bytes sentencepiece_SentencePieceTrainer__TrainFromMap4(std::map< std::string,std::string > const &args,sentencepiece::SentenceIterator *iter){
+SWIGINTERN sentencepiece::util::bytes sentencepiece_SentencePieceTrainer__TrainFromMap4(std::unordered_map< std::string,std::string > const &args,sentencepiece::SentenceIterator *iter){
sentencepiece::util::bytes model_proto;
const auto _status = sentencepiece::SentencePieceTrainer::Train(args, iter, &model_proto);
if (!_status.ok()) throw _status;
@@ -4995,17 +4995,17 @@ fail:
SWIGINTERN PyObject *_wrap_SentencePieceTrainer__TrainFromMap(PyObject *SWIGUNUSEDPARM(self), PyObject *args) {
PyObject *resultobj = 0;
- std::map< std::string,std::string > *arg1 = 0 ;
+ std::unordered_map< std::string,std::string > *arg1 = 0 ;
PyObject *swig_obj[1] ;
if (!args) SWIG_fail;
swig_obj[0] = args;
{
- std::map<std::string, std::string> *out = nullptr;
+ std::unordered_map<std::string, std::string> *out = nullptr;
if (PyDict_Check(swig_obj[0])) {
PyObject *key, *value;
Py_ssize_t pos = 0;
- out = new std::map<std::string, std::string>;
+ out = new std::unordered_map<std::string, std::string>;
while (PyDict_Next(swig_obj[0], &pos, &key, &value)) {
const PyInputString key_ustring(key);
const PyInputString value_ustring(value);
@@ -5026,7 +5026,7 @@ SWIGINTERN PyObject *_wrap_SentencePieceTrainer__TrainFromMap(PyObject *SWIGUNUS
}
{
try {
- sentencepiece_SentencePieceTrainer__TrainFromMap((std::map< std::string,std::string > const &)*arg1);
+ sentencepiece_SentencePieceTrainer__TrainFromMap((std::unordered_map< std::string,std::string > const &)*arg1);
ReleaseResultObject(resultobj);
}
catch (const sentencepiece::util::Status &status) {
@@ -5048,17 +5048,17 @@ fail:
SWIGINTERN PyObject *_wrap_SentencePieceTrainer__TrainFromMap2(PyObject *SWIGUNUSEDPARM(self), PyObject *args) {
PyObject *resultobj = 0;
- std::map< std::string,std::string > *arg1 = 0 ;
+ std::unordered_map< std::string,std::string > *arg1 = 0 ;
sentencepiece::SentenceIterator *arg2 = (sentencepiece::SentenceIterator *) 0 ;
PyObject *swig_obj[2] ;
if (!SWIG_Python_UnpackTuple(args, "SentencePieceTrainer__TrainFromMap2", 2, 2, swig_obj)) SWIG_fail;
{
- std::map<std::string, std::string> *out = nullptr;
+ std::unordered_map<std::string, std::string> *out = nullptr;
if (PyDict_Check(swig_obj[0])) {
PyObject *key, *value;
Py_ssize_t pos = 0;
- out = new std::map<std::string, std::string>;
+ out = new std::unordered_map<std::string, std::string>;
while (PyDict_Next(swig_obj[0], &pos, &key, &value)) {
const PyInputString key_ustring(key);
const PyInputString value_ustring(value);
@@ -5089,7 +5089,7 @@ SWIGINTERN PyObject *_wrap_SentencePieceTrainer__TrainFromMap2(PyObject *SWIGUNU
}
{
try {
- sentencepiece_SentencePieceTrainer__TrainFromMap2((std::map< std::string,std::string > const &)*arg1,arg2);
+ sentencepiece_SentencePieceTrainer__TrainFromMap2((std::unordered_map< std::string,std::string > const &)*arg1,arg2);
ReleaseResultObject(resultobj);
}
catch (const sentencepiece::util::Status &status) {
@@ -5117,18 +5117,18 @@ fail:
SWIGINTERN PyObject *_wrap_SentencePieceTrainer__TrainFromMap3(PyObject *SWIGUNUSEDPARM(self), PyObject *args) {
PyObject *resultobj = 0;
- std::map< std::string,std::string > *arg1 = 0 ;
+ std::unordered_map< std::string,std::string > *arg1 = 0 ;
PyObject *swig_obj[1] ;
sentencepiece::util::bytes result;
if (!args) SWIG_fail;
swig_obj[0] = args;
{
- std::map<std::string, std::string> *out = nullptr;
+ std::unordered_map<std::string, std::string> *out = nullptr;
if (PyDict_Check(swig_obj[0])) {
PyObject *key, *value;
Py_ssize_t pos = 0;
- out = new std::map<std::string, std::string>;
+ out = new std::unordered_map<std::string, std::string>;
while (PyDict_Next(swig_obj[0], &pos, &key, &value)) {
const PyInputString key_ustring(key);
const PyInputString value_ustring(value);
@@ -5149,7 +5149,7 @@ SWIGINTERN PyObject *_wrap_SentencePieceTrainer__TrainFromMap3(PyObject *SWIGUNU
}
{
try {
- result = sentencepiece_SentencePieceTrainer__TrainFromMap3((std::map< std::string,std::string > const &)*arg1);
+ result = sentencepiece_SentencePieceTrainer__TrainFromMap3((std::unordered_map< std::string,std::string > const &)*arg1);
ReleaseResultObject(resultobj);
}
catch (const sentencepiece::util::Status &status) {
@@ -5173,18 +5173,18 @@ fail:
SWIGINTERN PyObject *_wrap_SentencePieceTrainer__TrainFromMap4(PyObject *SWIGUNUSEDPARM(self), PyObject *args) {
PyObject *resultobj = 0;
- std::map< std::string,std::string > *arg1 = 0 ;
+ std::unordered_map< std::string,std::string > *arg1 = 0 ;
sentencepiece::SentenceIterator *arg2 = (sentencepiece::SentenceIterator *) 0 ;
PyObject *swig_obj[2] ;
sentencepiece::util::bytes result;
if (!SWIG_Python_UnpackTuple(args, "SentencePieceTrainer__TrainFromMap4", 2, 2, swig_obj)) SWIG_fail;
{
- std::map<std::string, std::string> *out = nullptr;
+ std::unordered_map<std::string, std::string> *out = nullptr;
if (PyDict_Check(swig_obj[0])) {
PyObject *key, *value;
Py_ssize_t pos = 0;
- out = new std::map<std::string, std::string>;
+ out = new std::unordered_map<std::string, std::string>;
while (PyDict_Next(swig_obj[0], &pos, &key, &value)) {
const PyInputString key_ustring(key);
const PyInputString value_ustring(value);
@@ -5215,7 +5215,7 @@ SWIGINTERN PyObject *_wrap_SentencePieceTrainer__TrainFromMap4(PyObject *SWIGUNU
}
{
try {
- result = sentencepiece_SentencePieceTrainer__TrainFromMap4((std::map< std::string,std::string > const &)*arg1,arg2);
+ result = sentencepiece_SentencePieceTrainer__TrainFromMap4((std::unordered_map< std::string,std::string > const &)*arg1,arg2);
ReleaseResultObject(resultobj);
}
catch (const sentencepiece::util::Status &status) {
@@ -5311,8 +5311,8 @@ static swig_type_info _swigt__p_char = {"_p_char", "char *", 0, 0, (void*)0, 0};
static swig_type_info _swigt__p_sentencepiece__SentenceIterator = {"_p_sentencepiece__SentenceIterator", "sentencepiece::SentenceIterator *", 0, 0, (void*)0, 0};
static swig_type_info _swigt__p_sentencepiece__SentencePieceProcessor = {"_p_sentencepiece__SentencePieceProcessor", "sentencepiece::SentencePieceProcessor *", 0, 0, (void*)0, 0};
static swig_type_info _swigt__p_sentencepiece__SentencePieceTrainer = {"_p_sentencepiece__SentencePieceTrainer", "sentencepiece::SentencePieceTrainer *", 0, 0, (void*)0, 0};
-static swig_type_info _swigt__p_std__mapT_std__string_std__string_t = {"_p_std__mapT_std__string_std__string_t", "std::map< std::string,std::string > *", 0, 0, (void*)0, 0};
static swig_type_info _swigt__p_std__string = {"_p_std__string", "sentencepiece::util::bytes *|std::string *", 0, 0, (void*)0, 0};
+static swig_type_info _swigt__p_std__unordered_mapT_std__string_std__string_t = {"_p_std__unordered_mapT_std__string_std__string_t", "std::unordered_map< std::string,std::string > *", 0, 0, (void*)0, 0};
static swig_type_info _swigt__p_std__vectorT_int_t = {"_p_std__vectorT_int_t", "std::vector< int > *", 0, 0, (void*)0, 0};
static swig_type_info _swigt__p_std__vectorT_std__string_t = {"_p_std__vectorT_std__string_t", "std::vector< std::string > *", 0, 0, (void*)0, 0};
@@ -5321,8 +5321,8 @@ static swig_type_info *swig_type_initial[] = {
&_swigt__p_sentencepiece__SentenceIterator,
&_swigt__p_sentencepiece__SentencePieceProcessor,
&_swigt__p_sentencepiece__SentencePieceTrainer,
- &_swigt__p_std__mapT_std__string_std__string_t,
&_swigt__p_std__string,
+ &_swigt__p_std__unordered_mapT_std__string_std__string_t,
&_swigt__p_std__vectorT_int_t,
&_swigt__p_std__vectorT_std__string_t,
};
@@ -5331,8 +5331,8 @@ static swig_cast_info _swigc__p_char[] = { {&_swigt__p_char, 0, 0, 0},{0, 0, 0,
static swig_cast_info _swigc__p_sentencepiece__SentenceIterator[] = { {&_swigt__p_sentencepiece__SentenceIterator, 0, 0, 0},{0, 0, 0, 0}};
static swig_cast_info _swigc__p_sentencepiece__SentencePieceProcessor[] = { {&_swigt__p_sentencepiece__SentencePieceProcessor, 0, 0, 0},{0, 0, 0, 0}};
static swig_cast_info _swigc__p_sentencepiece__SentencePieceTrainer[] = { {&_swigt__p_sentencepiece__SentencePieceTrainer, 0, 0, 0},{0, 0, 0, 0}};
-static swig_cast_info _swigc__p_std__mapT_std__string_std__string_t[] = { {&_swigt__p_std__mapT_std__string_std__string_t, 0, 0, 0},{0, 0, 0, 0}};
static swig_cast_info _swigc__p_std__string[] = { {&_swigt__p_std__string, 0, 0, 0},{0, 0, 0, 0}};
+static swig_cast_info _swigc__p_std__unordered_mapT_std__string_std__string_t[] = { {&_swigt__p_std__unordered_mapT_std__string_std__string_t, 0, 0, 0},{0, 0, 0, 0}};
static swig_cast_info _swigc__p_std__vectorT_int_t[] = { {&_swigt__p_std__vectorT_int_t, 0, 0, 0},{0, 0, 0, 0}};
static swig_cast_info _swigc__p_std__vectorT_std__string_t[] = { {&_swigt__p_std__vectorT_std__string_t, 0, 0, 0},{0, 0, 0, 0}};
@@ -5341,8 +5341,8 @@ static swig_cast_info *swig_cast_initial[] = {
_swigc__p_sentencepiece__SentenceIterator,
_swigc__p_sentencepiece__SentencePieceProcessor,
_swigc__p_sentencepiece__SentencePieceTrainer,
- _swigc__p_std__mapT_std__string_std__string_t,
_swigc__p_std__string,
+ _swigc__p_std__unordered_mapT_std__string_std__string_t,
_swigc__p_std__vectorT_int_t,
_swigc__p_std__vectorT_std__string_t,
};
diff --git a/src/bpe_model.h b/src/bpe_model.h
index 243664f..c6e1abe 100644
--- a/src/bpe_model.h
+++ b/src/bpe_model.h
@@ -42,6 +42,10 @@ class Model : public ModelInterface {
// When alpha <= 0.0, no sampling is performed.
EncodeResult SampleEncode(absl::string_view normalized,
float alpha) const override;
+
+ bool IsSampleEncodeAvailable() const override { return true; }
+
+ bool IsNBestEncodeAvailable() const override { return false; }
};
} // namespace bpe
} // namespace sentencepiece
diff --git a/src/builder.cc b/src/builder.cc
index 7e8ca98..d9442d3 100644
--- a/src/builder.cc
+++ b/src/builder.cc
@@ -54,14 +54,7 @@ Builder::Chars UnicodeNormalize(UNormalizationMode mode,
const std::string utf8 = string_util::UnicodeTextToUTF8(input);
CHECK(!utf8.empty());
- icu::UnicodeString ustr;
- const size_t utf8_length = utf8.size();
- UChar *utf16 = ustr.getBuffer(utf8.size() + 1);
- int32 utf16_length = 0;
- icu::ErrorCode icuerrorcode;
- u_strFromUTF8Lenient(utf16, ustr.getCapacity(), &utf16_length, utf8.data(),
- utf8_length, icuerrorcode);
- ustr.releaseBuffer(utf16_length);
+ icu::UnicodeString ustr = icu::UnicodeString::fromUTF8(utf8.c_str());
UErrorCode status = U_ZERO_ERROR;
icu::UnicodeString dst;
diff --git a/src/compile_charsmap_main.cc b/src/compile_charsmap_main.cc
index 21f1ee8..e8fc072 100644
--- a/src/compile_charsmap_main.cc
+++ b/src/compile_charsmap_main.cc
@@ -25,7 +25,6 @@
#include "third_party/absl/strings/string_view.h"
using sentencepiece::normalizer::Builder;
-using util::Status;
DEFINE_bool(output_precompiled_header, false, "make normalization_rule.h file");
@@ -157,8 +156,9 @@ struct BinaryBlob {
int main(int argc, char **argv) {
sentencepiece::flags::ParseCommandLineFlags(argv[0], &argc, &argv, true);
- const std::vector<
- std::pair<std::string, std::function<Status(Builder::CharsMap *)>>>
+ const std::vector<std::pair<
+ std::string,
+ std::function<sentencepiece::util::Status(Builder::CharsMap *)>>>
kRuleList = {{"nfkc", Builder::BuildNFKCMap},
{"nmt_nfkc", Builder::BuildNmtNFKCMap},
{"nfkc_cf", Builder::BuildNFKC_CFMap},
diff --git a/src/model_interface.h b/src/model_interface.h
index 98a4798..27dad99 100644
--- a/src/model_interface.h
+++ b/src/model_interface.h
@@ -106,6 +106,12 @@ class ModelInterface {
return EncodeResult();
}
+ // Return true if SampleEncode returns a valid result.
+ virtual bool IsSampleEncodeAvailable() const { return false; }
+
+ // Return true if NBestEncode returns a valid result.
+ virtual bool IsNBestEncodeAvailable() const { return false; }
+
// Returns the vocab id of `piece`.
// Returns UNK(0) if `piece` is unknown
virtual int PieceToId(absl::string_view piece) const;
diff --git a/src/sentencepiece_processor.cc b/src/sentencepiece_processor.cc
index 4263a2f..a4dd575 100644
--- a/src/sentencepiece_processor.cc
+++ b/src/sentencepiece_processor.cc
@@ -12,8 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.!
-#include "sentencepiece_processor.h"
-
#include <map>
#include <set>
#include <utility>
@@ -24,6 +22,7 @@
#include "model_factory.h"
#include "model_interface.h"
#include "normalizer.h"
+#include "sentencepiece_processor.h"
#include "third_party/absl/memory/memory.h"
#include "third_party/absl/strings/numbers.h"
#include "third_party/absl/strings/str_cat.h"
@@ -446,6 +445,9 @@ util::Status SentencePieceProcessor::NBestEncode(
std::vector<size_t> norm_to_orig;
RETURN_IF_ERROR(normalizer_->Normalize(input, &normalized, &norm_to_orig));
+ CHECK_OR_RETURN(model_->IsNBestEncodeAvailable())
+ << "NBestEncode is not available for the current model.";
+
const auto nbests = model_->NBestEncode(normalized, nbest_size);
CHECK_OR_RETURN(!nbests.empty()) << "NBestEncode returns empty result.";
@@ -470,7 +472,13 @@ util::Status SentencePieceProcessor::SampleEncode(
std::vector<size_t> norm_to_orig;
RETURN_IF_ERROR(normalizer_->Normalize(input, &normalized, &norm_to_orig));
- if (nbest_size == 1 || nbest_size == 0) {
+ if (!model_->IsNBestEncodeAvailable() || nbest_size < 0) {
+ CHECK_OR_RETURN(model_->IsSampleEncodeAvailable())
+ << "SampleEncode is not available for the current model.";
+ const auto result = model_->SampleEncode(normalized, alpha);
+ RETURN_IF_ERROR(PopulateSentencePieceText(input, normalized, norm_to_orig,
+ result, spt));
+ } else if (nbest_size == 1 || nbest_size == 0) {
const auto result = model_->Encode(normalized);
RETURN_IF_ERROR(PopulateSentencePieceText(input, normalized, norm_to_orig,
result, spt));
@@ -487,11 +495,6 @@ util::Status SentencePieceProcessor::SampleEncode(
std::discrete_distribution<int> dist(probs.begin(), probs.end());
RETURN_IF_ERROR(PopulateSentencePieceText(input, normalized, norm_to_orig,
nbests[dist(*mt)].first, spt));
-
- } else if (nbest_size < 0) {
- const auto result = model_->SampleEncode(normalized, alpha);
- RETURN_IF_ERROR(PopulateSentencePieceText(input, normalized, norm_to_orig,
- result, spt));
}
return util::OkStatus();
@@ -828,6 +831,5 @@ util::Status SaveModelProto(absl::string_view filename,
return util::OkStatus();
}
-
} // namespace io
} // namespace sentencepiece
diff --git a/src/sentencepiece_processor.h b/src/sentencepiece_processor.h
index 2b31cb1..019eddf 100644
--- a/src/sentencepiece_processor.h
+++ b/src/sentencepiece_processor.h
@@ -286,9 +286,8 @@ class SentencePieceProcessor {
//
// - BPE (--model_type=bpe):
// `alpha` is the merge probability `p` in https://arxiv.org/abs/1910.13267
- // when alpha<=0, no sampling is performed but the best segmentation is
- // returned. Nbest-based sampling is not supported so you need to specify
- // nbest_size = 0 in BPE.
+ // N-best-based sampling is not supported, so the nbest_size parameter is
+ // ignored in BPE.
virtual util::Status SampleEncode(absl::string_view input, int nbest_size,
float alpha,
std::vector<std::string> *pieces) const;
@@ -503,13 +502,10 @@ namespace io {
// io::LoadModelProto("//path/spm.model", model_proto.get());
// SentencePieceProcessor sp;
// CHECK_OK(sp.Load(std::move(model_proto)));
-util::Status LoadModelProto(absl::string_view filename,
- ModelProto *model_proto);
+util::Status LoadModelProto(absl::string_view, ModelProto *model_proto);
// Saves `model_proto` as `filename`.
-util::Status SaveModelProto(absl::string_view filename,
- const ModelProto &model_proto);
-
+util::Status SaveModelProto(absl::string_view, const ModelProto &model_proto);
} // namespace io
#endif // SWIG
} // namespace sentencepiece
diff --git a/src/sentencepiece_processor_test.cc b/src/sentencepiece_processor_test.cc
index 3e00404..bceba2c 100644
--- a/src/sentencepiece_processor_test.cc
+++ b/src/sentencepiece_processor_test.cc
@@ -63,6 +63,10 @@ class MockModel : public ModelInterface {
return nbest_output_;
}
+ bool IsSampleEncodeAvailable() const override { return true; }
+
+ bool IsNBestEncodeAvailable() const override { return true; }
+
bool IsControl(int id) const { return id == 1 || id == 2; }
bool IsUnknown(int id) const { return id == 0; }
diff --git a/src/sentencepiece_trainer.cc b/src/sentencepiece_trainer.cc
index 10c5b6f..e36aa9c 100644
--- a/src/sentencepiece_trainer.cc
+++ b/src/sentencepiece_trainer.cc
@@ -12,8 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.!
-#include "sentencepiece_trainer.h"
-
#include <string>
#include <vector>
@@ -22,6 +20,7 @@
#include "builtin_pb/sentencepiece_model.pb.h"
#include "common.h"
#include "normalizer.h"
+#include "sentencepiece_trainer.h"
#include "spec_parser.h"
#include "third_party/absl/strings/str_cat.h"
#include "third_party/absl/strings/str_split.h"
@@ -75,10 +74,15 @@ util::Status SentencePieceTrainer::Train(
LOG(INFO) << "Starts training with : \n" << info;
- trainer->SetSentenceIterator(sentence_iterator);
- trainer->SetOutputSerializedModelProto(serialized_model_proto);
+ if (serialized_model_proto) {
+ ModelProto model_proto;
+ RETURN_IF_ERROR(trainer->Train(sentence_iterator, &model_proto));
+ *serialized_model_proto = model_proto.SerializeAsString();
+ } else {
+ RETURN_IF_ERROR(trainer->Train(sentence_iterator, nullptr));
+ }
- return trainer->Train();
+ return util::OkStatus();
}
// static
@@ -100,7 +104,7 @@ util::Status SentencePieceTrainer::MergeSpecsFromArgs(
if (args.empty()) return util::OkStatus();
- std::map<std::string, std::string> kwargs;
+ std::unordered_map<std::string, std::string> kwargs;
for (auto arg : absl::StrSplit(args, " ")) {
absl::ConsumePrefix(&arg, "--");
std::string key, value;
@@ -120,8 +124,9 @@ util::Status SentencePieceTrainer::MergeSpecsFromArgs(
// static
util::Status SentencePieceTrainer::MergeSpecsFromArgs(
- const std::map<std::string, std::string> &kwargs, TrainerSpec *trainer_spec,
- NormalizerSpec *normalizer_spec, NormalizerSpec *denormalizer_spec) {
+ const std::unordered_map<std::string, std::string> &kwargs,
+ TrainerSpec *trainer_spec, NormalizerSpec *normalizer_spec,
+ NormalizerSpec *denormalizer_spec) {
CHECK_OR_RETURN(trainer_spec) << "`trainer_spec` must not be null.";
CHECK_OR_RETURN(normalizer_spec) << "`normalizer_spec` must not be null.";
CHECK_OR_RETURN(denormalizer_spec) << "`denormalizer_spec` must not be null.";
@@ -174,7 +179,7 @@ util::Status SentencePieceTrainer::Train(absl::string_view args,
// static
util::Status SentencePieceTrainer::Train(
- const std::map<std::string, std::string> &kwargs,
+ const std::unordered_map<std::string, std::string> &kwargs,
SentenceIterator *sentence_iterator, std::string *serialized_model_proto) {
TrainerSpec trainer_spec;
NormalizerSpec normalizer_spec;
@@ -216,11 +221,11 @@ util::Status SentencePieceTrainer::PopulateNormalizerSpec(
// static
util::Status SentencePieceTrainer::PopulateModelTypeFromString(
absl::string_view type, TrainerSpec *spec) {
- static const std::map<std::string, TrainerSpec::ModelType> kModelTypeMap = {
- {"unigram", TrainerSpec::UNIGRAM},
- {"bpe", TrainerSpec::BPE},
- {"word", TrainerSpec::WORD},
- {"char", TrainerSpec::CHAR}};
+ static const std::unordered_map<std::string, TrainerSpec::ModelType>
+ kModelTypeMap = {{"unigram", TrainerSpec::UNIGRAM},
+ {"bpe", TrainerSpec::BPE},
+ {"word", TrainerSpec::WORD},
+ {"char", TrainerSpec::CHAR}};
const auto it = kModelTypeMap.find(absl::AsciiStrToLower(type));
if (it != kModelTypeMap.end()) {
spec->set_model_type(it->second);
diff --git a/src/sentencepiece_trainer.h b/src/sentencepiece_trainer.h
index 5782741..bb74ab9 100644
--- a/src/sentencepiece_trainer.h
+++ b/src/sentencepiece_trainer.h
@@ -15,8 +15,8 @@
#ifndef SENTENCEPIECE_TRAINER_H_
#define SENTENCEPIECE_TRAINER_H_
-#include <map>
#include <string>
+#include <unordered_map>
#include "sentencepiece_processor.h"
@@ -84,9 +84,10 @@ class SentencePieceTrainer {
// Trains SentencePiece model with mapin `kwargs`.
// e.g., {{"input", "data"}, {"model_prefix, "m"}, {"vocab_size", "8192"}...}
- static util::Status Train(const std::map<std::string, std::string> &kwargs,
- SentenceIterator *sentence_iterator = nullptr,
- std::string *serialized_model_proto = nullptr);
+ static util::Status Train(
+ const std::unordered_map<std::string, std::string> &kwargs,
+ SentenceIterator *sentence_iterator = nullptr,
+ std::string *serialized_model_proto = nullptr);
// Handy function to make a normalizer spec from the pre-compiled
// normalization name. Do not use this method in production as it crashes
@@ -96,12 +97,12 @@ class SentencePieceTrainer {
// Populates necessary fields (precompiled_charmap) from
// `NormalizerSpec::name` or `NormalizerSpec::normalization_rule_tsv`.
static util::Status PopulateNormalizerSpec(NormalizerSpec *normalizer_spec,
- bool is_denomalizer = false);
+ bool is_denormalizer = false);
// Overrides `trainer_spec`, `normalizer_spec`, `denormalizer_spec` with the
- // std::map in `kargs`.
+ // std::unordered_map in `kwargs`.
static util::Status MergeSpecsFromArgs(
- const std::map<std::string, std::string> &kwargs,
+ const std::unordered_map<std::string, std::string> &kwargs,
TrainerSpec *trainer_spec, NormalizerSpec *normalizer_spec,
NormalizerSpec *denormalizer_spec);
diff --git a/src/sentencepiece_trainer_test.cc b/src/sentencepiece_trainer_test.cc
index c95f686..b78b1d2 100644
--- a/src/sentencepiece_trainer_test.cc
+++ b/src/sentencepiece_trainer_test.cc
@@ -12,10 +12,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.!
-#include "sentencepiece_trainer.h"
-
#include "builtin_pb/sentencepiece_model.pb.h"
#include "filesystem.h"
+#include "sentencepiece_trainer.h"
#include "testharness.h"
#include "third_party/absl/strings/str_cat.h"
#include "util.h"
diff --git a/src/trainer_interface.cc b/src/trainer_interface.cc
index 37f7003..5cdb300 100644
--- a/src/trainer_interface.cc
+++ b/src/trainer_interface.cc
@@ -12,8 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.!
-#include "trainer_interface.h"
-
#include <cstdlib>
#include <memory>
#include <set>
@@ -34,6 +32,7 @@
#include "third_party/absl/strings/str_format.h"
#include "third_party/absl/strings/str_join.h"
#include "third_party/absl/strings/str_split.h"
+#include "trainer_interface.h"
#include "unicode_script.h"
#include "util.h"
@@ -50,7 +49,6 @@ const char TrainerInterface::kUPPBoundaryStr[] = "\t";
namespace {
util::Status VerifySpec(const TrainerSpec &trainer_spec) {
- // CHECK_OR_RETURN(!trainer_spec.model_prefix().empty());
CHECK_GT_OR_RETURN(trainer_spec.vocab_size(), 0);
if (trainer_spec.model_type() == TrainerSpec::UNIGRAM ||
@@ -313,10 +311,10 @@ util::Status TrainerInterface::LoadSentences() {
(sentence_iterator_ == nullptr && !trainer_spec_.input().empty()))
<< "SentenceIterator and trainer_spec.input() must be exclusive.";
- CHECK_OR_RETURN((serialized_model_proto_ != nullptr &&
- trainer_spec_.model_prefix().empty()) ||
- (serialized_model_proto_ == nullptr &&
- !trainer_spec_.model_prefix().empty()))
+ CHECK_OR_RETURN(
+ (output_model_proto_ != nullptr &&
+ trainer_spec_.model_prefix().empty()) ||
+ (output_model_proto_ == nullptr && !trainer_spec_.model_prefix().empty()))
<< "ModelProto and trainer_spec.model_prefix() must be exclusive.";
const bool is_tsv = trainer_spec_.input_format() == "tsv";
@@ -647,10 +645,8 @@ util::Status TrainerInterface::SaveVocab(absl::string_view filename) const {
}
util::Status TrainerInterface::Save() const {
- if (serialized_model_proto_) {
- ModelProto model_proto;
- RETURN_IF_ERROR(Serialize(&model_proto));
- *serialized_model_proto_ = model_proto.SerializeAsString();
+ if (output_model_proto_) {
+ RETURN_IF_ERROR(Serialize(output_model_proto_));
} else {
RETURN_IF_ERROR(SaveModel(trainer_spec_.model_prefix() + ".model"));
RETURN_IF_ERROR(SaveVocab(trainer_spec_.model_prefix() + ".vocab"));
diff --git a/src/trainer_interface.h b/src/trainer_interface.h
index 6cd2469..552b206 100644
--- a/src/trainer_interface.h
+++ b/src/trainer_interface.h
@@ -88,13 +88,13 @@ class TrainerInterface {
virtual ~TrainerInterface();
- virtual void SetSentenceIterator(SentenceIterator *sentence_iterator) {
+ // Loads sentences from `sentence_iterator` and stores the model
+ // in `output_model_proto`.
+ virtual util::Status Train(SentenceIterator *sentence_iterator,
+ ModelProto *output_model_proto) {
sentence_iterator_ = sentence_iterator;
- }
-
- virtual void SetOutputSerializedModelProto(
- std::string *serialized_model_proto) {
- serialized_model_proto_ = serialized_model_proto;
+ output_model_proto_ = output_model_proto;
+ return Train();
}
virtual util::Status Train() { return status(); }
@@ -158,7 +158,7 @@ class TrainerInterface {
SentenceIterator *sentence_iterator_ = nullptr;
// Emits model to this proto instead of file.
- std::string *serialized_model_proto_ = nullptr;
+ ModelProto *output_model_proto_ = nullptr;
private:
// Serialize final_pieces_ to |model_proto|.
diff --git a/src/unigram_model.cc b/src/unigram_model.cc
index 8f6cd4b..bd2d99b 100644
--- a/src/unigram_model.cc
+++ b/src/unigram_model.cc
@@ -578,7 +578,7 @@ bool Model::VerifyOutputsEquivalent(absl::string_view expected,
} else {
const int length = p.size();
total_score += IsUserDefinedInlined(id)
- ? (length * max_score_ + 1.0)
+ ? (length * max_score_ - 0.1)
: GetScoreInlined(id);
}
}
@@ -688,7 +688,7 @@ EncodeResult Model::EncodeOptimized(absl::string_view normalized) const {
const auto length = (key_pos - starts_at);
// User defined symbol receives extra bonus to always be selected.
const auto score = IsUserDefinedInlined(ret)
- ? (length * max_score_ + 1.0)
+ ? (length * max_score_ - 0.1)
: GetScoreInlined(ret);
const auto candidate_best_path_score =
score + best_path_score_till_here;
diff --git a/src/unigram_model.h b/src/unigram_model.h
index d67c7c7..df84260 100644
--- a/src/unigram_model.h
+++ b/src/unigram_model.h
@@ -127,6 +127,10 @@ class Model : public ModelInterface {
EncodeResult SampleEncode(absl::string_view normalized,
float theta) const override;
+ bool IsSampleEncodeAvailable() const override { return true; }
+
+ bool IsNBestEncodeAvailable() const override { return true; }
+
// Returns the minimum score in sentence pieces.
// min_score() - 10 is used for the cost of unknown sentence.
float min_score() const { return min_score_; }
diff --git a/third_party/absl/strings/str_replace.h b/third_party/absl/strings/str_replace.h
index f8ea9a0..5cda342 100644
--- a/third_party/absl/strings/str_replace.h
+++ b/third_party/absl/strings/str_replace.h
@@ -57,6 +57,7 @@ inline std::string StrReplaceAll(
std::string prev(s.data(), s.size());
std::string result;
for (const auto &it : patterns) {
+ result.clear();
StringReplace(prev, it.first, it.second, true, &result);
prev = result;
}