Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/sentencepiece.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorTaku Kudo <taku@google.com>2018-06-08 12:44:58 +0300
committerTaku Kudo <taku@google.com>2018-06-08 16:10:13 +0300
commitdb1eeac98580e3aa5a9e523ece5cdc1dc839b333 (patch)
tree9b4261d1cadf7a6764d983b83db6ca82a8d323d3 /src
parent54ccef78b800625a58cbdbac1245d77c9b744e84 (diff)
Allows defining duplicated user-defined symbols
Diffstat (limited to 'src')
-rw-r--r--src/bpe_model_trainer_test.cc10
-rw-r--r--src/model_interface.cc15
-rw-r--r--src/model_interface.h3
-rw-r--r--src/model_interface_test.cc77
-rw-r--r--src/trainer_interface.cc9
-rw-r--r--src/unigram_model_trainer.cc1
-rw-r--r--src/unigram_model_trainer_test.cc2
7 files changed, 78 insertions, 39 deletions
diff --git a/src/bpe_model_trainer_test.cc b/src/bpe_model_trainer_test.cc
index 222dd11..71d49ba 100644
--- a/src/bpe_model_trainer_test.cc
+++ b/src/bpe_model_trainer_test.cc
@@ -28,7 +28,9 @@ namespace {
// Space symbol
#define WS "\xe2\x96\x81"
-std::string RunTrainer(const std::vector<std::string> &input, int size) {
+std::string RunTrainer(
+ const std::vector<std::string> &input, int size,
+ const std::vector<std::string> &user_defined_symbols = {}) {
test::ScopedTempFile input_scoped_file("input");
test::ScopedTempFile model_scoped_file("model");
const std::string input_file = input_scoped_file.filename();
@@ -50,6 +52,10 @@ std::string RunTrainer(const std::vector<std::string> &input, int size) {
normalizer_spec.set_name("identity");
normalizer_spec.set_add_dummy_prefix(false);
+ for (const auto &w : user_defined_symbols) {
+ trainer_spec.add_user_defined_symbols(w);
+ }
+
Trainer trainer(trainer_spec, normalizer_spec);
EXPECT_OK(trainer.Train());
@@ -74,6 +80,8 @@ TEST(BPETrainerTest, BasicTest) {
RunTrainer({"pen", "pineapple", "apple"}, 20));
EXPECT_EQ("he ll llo hello hellohe el lo oh hel ohe e h l o",
RunTrainer({"hellohe"}, 20));
+ EXPECT_EQ("app le en in ine pen " WS "le pine e l n p i " WS,
+ RunTrainer({"pen", "pineapple", "apple"}, 20, {"app"}));
}
TEST(BPETrainerTest, EndToEndTest) {
diff --git a/src/model_interface.cc b/src/model_interface.cc
index 656334f..255d1be 100644
--- a/src/model_interface.cc
+++ b/src/model_interface.cc
@@ -54,6 +54,21 @@ int PrefixMatcher::PrefixMatch(StringPiece w, bool *found) const {
return mblen;
}
+std::string PrefixMatcher::GlobalReplace(StringPiece w, StringPiece out) const {
+ std::string result;
+ while (!w.empty()) {
+ bool found = false;
+ const int mblen = PrefixMatch(w, &found);
+ if (found) {
+ result.append(out.data(), out.size());
+ } else {
+ result.append(w.data(), mblen);
+ }
+ w.remove_prefix(mblen);
+ }
+ return result;
+}
+
ModelInterface::ModelInterface(const ModelProto &model_proto)
: model_proto_(&model_proto), status_(util::OkStatus()) {}
ModelInterface::~ModelInterface() {}
diff --git a/src/model_interface.h b/src/model_interface.h
index df96b2e..f70c58a 100644
--- a/src/model_interface.h
+++ b/src/model_interface.h
@@ -50,6 +50,9 @@ class PrefixMatcher {
// If no entry is found, consumes one Unicode character.
int PrefixMatch(StringPiece w, bool *found = nullptr) const;
+ // Replaces entries in `w` with `out`.
+ std::string GlobalReplace(StringPiece w, StringPiece out) const;
+
private:
std::unique_ptr<Darts::DoubleArray> trie_;
};
diff --git a/src/model_interface_test.cc b/src/model_interface_test.cc
index c5c2fb8..51dbdae 100644
--- a/src/model_interface_test.cc
+++ b/src/model_interface_test.cc
@@ -266,43 +266,48 @@ TEST(ModelInterfaceTest, SplitIntoWordsTest) {
}
TEST(ModelInterfaceTest, PrefixMatcherTest) {
- {
- const PrefixMatcher matcher({"abc", "ab", "xy", "京都"});
- bool found;
- EXPECT_EQ(1, matcher.PrefixMatch("test", &found));
- EXPECT_FALSE(found);
- EXPECT_EQ(3, matcher.PrefixMatch("abcd", &found));
- EXPECT_TRUE(found);
- EXPECT_EQ(2, matcher.PrefixMatch("abxy", &found));
- EXPECT_TRUE(found);
- EXPECT_EQ(1, matcher.PrefixMatch("x", &found));
- EXPECT_FALSE(found);
- EXPECT_EQ(2, matcher.PrefixMatch("xyz", &found));
- EXPECT_TRUE(found);
- EXPECT_EQ(6, matcher.PrefixMatch("京都大学", &found));
- EXPECT_TRUE(found);
- EXPECT_EQ(3, matcher.PrefixMatch("東京大学", &found));
- EXPECT_FALSE(found);
- }
+ const PrefixMatcher matcher({"abc", "ab", "xy", "京都"});
+ bool found;
+ EXPECT_EQ(1, matcher.PrefixMatch("test", &found));
+ EXPECT_FALSE(found);
+ EXPECT_EQ(3, matcher.PrefixMatch("abcd", &found));
+ EXPECT_TRUE(found);
+ EXPECT_EQ(2, matcher.PrefixMatch("abxy", &found));
+ EXPECT_TRUE(found);
+ EXPECT_EQ(1, matcher.PrefixMatch("x", &found));
+ EXPECT_FALSE(found);
+ EXPECT_EQ(2, matcher.PrefixMatch("xyz", &found));
+ EXPECT_TRUE(found);
+ EXPECT_EQ(6, matcher.PrefixMatch("京都大学", &found));
+ EXPECT_TRUE(found);
+ EXPECT_EQ(3, matcher.PrefixMatch("東京大学", &found));
+ EXPECT_FALSE(found);
+
+ EXPECT_EQ("", matcher.GlobalReplace("", ""));
+ EXPECT_EQ("", matcher.GlobalReplace("abc", ""));
+ EXPECT_EQ("--de-pqr", matcher.GlobalReplace("xyabcdeabpqr", "-"));
+}
- {
- const PrefixMatcher matcher({});
- bool found;
- EXPECT_EQ(1, matcher.PrefixMatch("test", &found));
- EXPECT_FALSE(found);
- EXPECT_EQ(1, matcher.PrefixMatch("abcd", &found));
- EXPECT_FALSE(found);
- EXPECT_EQ(1, matcher.PrefixMatch("abxy", &found));
- EXPECT_FALSE(found);
- EXPECT_EQ(1, matcher.PrefixMatch("x", &found));
- EXPECT_FALSE(found);
- EXPECT_EQ(1, matcher.PrefixMatch("xyz", &found));
- EXPECT_FALSE(found);
- EXPECT_EQ(3, matcher.PrefixMatch("京都大学", &found));
- EXPECT_FALSE(found);
- EXPECT_EQ(3, matcher.PrefixMatch("東京大学", &found));
- EXPECT_FALSE(found);
- }
+TEST(ModelInterfaceTest, PrefixMatcherWithEmptyTest) {
+ const PrefixMatcher matcher({});
+ bool found;
+ EXPECT_EQ(1, matcher.PrefixMatch("test", &found));
+ EXPECT_FALSE(found);
+ EXPECT_EQ(1, matcher.PrefixMatch("abcd", &found));
+ EXPECT_FALSE(found);
+ EXPECT_EQ(1, matcher.PrefixMatch("abxy", &found));
+ EXPECT_FALSE(found);
+ EXPECT_EQ(1, matcher.PrefixMatch("x", &found));
+ EXPECT_FALSE(found);
+ EXPECT_EQ(1, matcher.PrefixMatch("xyz", &found));
+ EXPECT_FALSE(found);
+ EXPECT_EQ(3, matcher.PrefixMatch("京都大学", &found));
+ EXPECT_FALSE(found);
+ EXPECT_EQ(3, matcher.PrefixMatch("東京大学", &found));
+ EXPECT_FALSE(found);
+
+ EXPECT_EQ("", matcher.GlobalReplace("", ""));
+ EXPECT_EQ("abc", matcher.GlobalReplace("abc", ""));
}
} // namespace
diff --git a/src/trainer_interface.cc b/src/trainer_interface.cc
index 1203f68..ff8aa21 100644
--- a/src/trainer_interface.cc
+++ b/src/trainer_interface.cc
@@ -24,6 +24,7 @@
#include <vector>
#include "model_factory.h"
+#include "model_interface.h"
#include "normalizer.h"
#include "sentencepiece_processor.h"
#include "unicode_script.h"
@@ -152,6 +153,10 @@ util::Status TrainerInterface::LoadSentences() {
const bool is_tsv = trainer_spec_.input_format() == "tsv";
+ std::set<StringPiece> meta_pieces_set;
+ for (const auto &it : meta_pieces_) meta_pieces_set.insert(it.second.first);
+ const PrefixMatcher meta_pieces_matcher(meta_pieces_set);
+
for (const auto &filename : trainer_spec_.input()) {
LOG(INFO) << "Loading corpus: " << filename;
std::string sentence;
@@ -177,6 +182,9 @@ util::Status TrainerInterface::LoadSentences() {
continue;
}
+ // Escapes meta symbols so that they are not extracted as normal pieces.
+ sentence = meta_pieces_matcher.GlobalReplace(sentence, " ");
+
// Normalizes sentence with Normalizer.
// whitespaces are replaced with kWSChar.
const std::string normalized = normalizer.Normalize(sentence);
@@ -184,6 +192,7 @@ util::Status TrainerInterface::LoadSentences() {
LOG(INFO) << "Loading: " << normalized
<< "\tsize=" << sentences_.size();
}
+
CHECK_OR_RETURN(normalized.find(" ") == std::string::npos)
<< "Normalized string must not include spaces";
if (normalized.empty()) {
diff --git a/src/unigram_model_trainer.cc b/src/unigram_model_trainer.cc
index 2d9a04b..d3c1326 100644
--- a/src/unigram_model_trainer.cc
+++ b/src/unigram_model_trainer.cc
@@ -109,7 +109,6 @@ void TrainerModel::SetSentencePieces(SentencePieces &&sentencepieces) {
TrainerModel::SentencePieces Trainer::MakeSeedSentencePieces() const {
CHECK(!sentences_.empty());
CHECK(!required_chars_.empty());
- CHECK(port::ContainsKey(required_chars_, kWSChar));
// Merges all sentences into one array with 0x0000 delimiter.
std::vector<char32> array;
diff --git a/src/unigram_model_trainer_test.cc b/src/unigram_model_trainer_test.cc
index 6845da1..aa60427 100644
--- a/src/unigram_model_trainer_test.cc
+++ b/src/unigram_model_trainer_test.cc
@@ -42,7 +42,7 @@ TEST(UnigramTrainerTest, EndToEndTest) {
" --vocab_size=8000"
" --normalization_rule_name=identity"
" --model_type=unigram"
- " --user_defined_symbols=<user>"
+ " --user_defined_symbols=<user>" // Allows duplicated symbol
" --control_symbols=<ctrl>"));
SentencePieceProcessor sp;