Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/sentencepiece.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/bpe_model_test.cc2
-rw-r--r--src/bpe_model_trainer_test.cc22
-rw-r--r--src/builder_test.cc3
-rw-r--r--src/sentencepiece_processor.cc2
-rw-r--r--src/sentencepiece_processor_test.cc14
-rw-r--r--src/sentencepiece_trainer_test.cc35
-rw-r--r--src/testharness.cc20
-rw-r--r--src/testharness.h3
-rw-r--r--src/trainer_interface.cc6
-rw-r--r--src/unigram_model_trainer_test.cc26
-rw-r--r--src/util.cc40
-rw-r--r--src/util.h22
-rw-r--r--src/util_test.cc23
13 files changed, 158 insertions, 60 deletions
diff --git a/src/bpe_model_test.cc b/src/bpe_model_test.cc
index cc60880..4b067f6 100644
--- a/src/bpe_model_test.cc
+++ b/src/bpe_model_test.cc
@@ -103,7 +103,7 @@ TEST(BPEModelTest, EncodeTest) {
EXPECT_EQ("d", result[6].first);
// all unknown.
- result = model.Encode("xyz東京");
+ result = model.Encode(u8"xyz東京");
EXPECT_EQ(5, result.size());
EXPECT_EQ("x", result[0].first);
EXPECT_EQ("y", result[1].first);
diff --git a/src/bpe_model_trainer_test.cc b/src/bpe_model_trainer_test.cc
index 2061952..7a9c17d 100644
--- a/src/bpe_model_trainer_test.cc
+++ b/src/bpe_model_trainer_test.cc
@@ -89,15 +89,13 @@ TEST(BPETrainerTest, BasicTest) {
TEST(BPETrainerTest, EndToEndTest) {
const test::ScopedTempFile sf("tmp_model");
+ const std::string input =
+ util::JoinPath(FLAGS_data_dir, "wagahaiwa_nekodearu.txt");
- EXPECT_OK(SentencePieceTrainer::Train(std::string("--model_prefix=") +
- sf.filename() +
- " --input=" + FLAGS_data_dir +
- "/wagahaiwa_nekodearu.txt"
- " --vocab_size=8000"
- " --normalization_rule_name=identity"
- " --model_type=bpe"
- " --control_symbols=<ctrl>"));
+ EXPECT_OK(SentencePieceTrainer::Train(string_util::StrCat(
+ "--model_prefix=", sf.filename(), " --input=", input,
+ " --vocab_size=8000 --normalization_rule_name=identity"
+ " --model_type=bpe --control_symbols=<ctrl>")));
SentencePieceProcessor sp;
EXPECT_OK(sp.Load(std::string(sf.filename()) + ".model"));
@@ -117,10 +115,10 @@ TEST(BPETrainerTest, EndToEndTest) {
u8"。",
&tok));
EXPECT_EQ(WS
- " 吾輩 《 わが はい 》 は猫 である 。 名前 はまだ 無い 。 "
- "どこで 生 れた か とん と見 当 《 けんとう 》 が つかぬ 。 "
- "何でも 薄 暗 いじ め じ め した 所で ニャー ニャー 泣 いていた "
- "事 だけは 記憶 している 。",
+ u8" 吾輩 《 わが はい 》 は猫 である 。 名前 はまだ 無い 。 "
+ u8"どこで 生 れた か とん と見 当 《 けんとう 》 が つかぬ 。 "
+ u8"何でも 薄 暗 いじ め じ め した 所で ニャー ニャー 泣 いていた "
+ u8"事 だけは 記憶 している 。",
string_util::Join(tok, " "));
}
diff --git a/src/builder_test.cc b/src/builder_test.cc
index fd2d06c..e76fe40 100644
--- a/src/builder_test.cc
+++ b/src/builder_test.cc
@@ -138,7 +138,8 @@ TEST(BuilderTest, CompileCharsMap) {
TEST(BuilderTest, LoadCharsMapTest) {
Builder::CharsMap chars_map;
- EXPECT_OK(Builder::LoadCharsMap(FLAGS_data_dir + "/nfkc.tsv", &chars_map));
+ EXPECT_OK(Builder::LoadCharsMap(util::JoinPath(FLAGS_data_dir, "nfkc.tsv"),
+ &chars_map));
std::string precompiled, expected;
EXPECT_OK(Builder::CompileCharsMap(chars_map, &precompiled));
diff --git a/src/sentencepiece_processor.cc b/src/sentencepiece_processor.cc
index 28c467c..872db63 100644
--- a/src/sentencepiece_processor.cc
+++ b/src/sentencepiece_processor.cc
@@ -43,7 +43,7 @@ SentencePieceProcessor::SentencePieceProcessor() {}
SentencePieceProcessor::~SentencePieceProcessor() {}
util::Status SentencePieceProcessor::Load(util::min_string_view filename) {
- std::ifstream ifs(filename.data(), std::ios::binary | std::ios::in);
+ std::ifstream ifs(WPATH(filename.data()), std::ios::binary | std::ios::in);
if (!ifs) {
return util::StatusBuilder(util::error::NOT_FOUND)
<< "\"" << filename.data() << "\": " << util::StrError(errno);
diff --git a/src/sentencepiece_processor_test.cc b/src/sentencepiece_processor_test.cc
index 4e6797d..30bdf8f 100644
--- a/src/sentencepiece_processor_test.cc
+++ b/src/sentencepiece_processor_test.cc
@@ -334,8 +334,10 @@ TEST(SentencepieceProcessorTest, NBestEncodeTest) {
auto mock = MakeUnique<MockModel>();
const NBestEncodeResult result = {
- {{{WS "ABC", 3}, {WS "DE", 4}, {"F", 0}, {"</s>", 2}}, 1.0},
- {{{WS "AB", 5}, {WS "CD", 6}, {"EF", 7}, {"</s>", 2}}, 0.9}};
+ {{{WS "ABC", 3}, {WS "DE", 4}, {"F", 0}, {"</s>", 2}},
+ static_cast<float>(1.0)},
+ {{{WS "AB", 5}, {WS "CD", 6}, {"EF", 7}, {"</s>", 2}},
+ static_cast<float>(0.9)}};
mock->SetNBestEncodeResult(kInput, result);
sp.SetModel(std::move(mock));
@@ -382,8 +384,10 @@ TEST(SentencepieceProcessorTest, SampleEncodeTest) {
const EncodeResult result = {
{WS "ABC", 3}, {WS "DE", 4}, {"F", 0}, {"</s>", 2}};
const NBestEncodeResult nbest_result = {
- {{{WS "ABC", 3}, {WS "DE", 4}, {"F", 0}, {"</s>", 2}}, 1.0},
- {{{WS "AB", 5}, {WS "CD", 6}, {"EF", 7}, {"</s>", 2}}, 0.1}};
+ {{{WS "ABC", 3}, {WS "DE", 4}, {"F", 0}, {"</s>", 2}},
+ static_cast<float>(1.0)},
+ {{{WS "AB", 5}, {WS "CD", 6}, {"EF", 7}, {"</s>", 2}},
+ static_cast<float>(0.1)}};
mock->SetNBestEncodeResult(kInput, nbest_result);
mock->SetEncodeResult(kInput, result);
@@ -627,7 +631,7 @@ TEST(SentencePieceProcessorTest, EndToEndTest) {
test::ScopedTempFile sf("model");
{
- std::ofstream ofs(sf.filename(), OUTPUT_MODE);
+ std::ofstream ofs(WPATH(sf.filename()), OUTPUT_MODE);
CHECK(model_proto.SerializeToOstream(&ofs));
}
diff --git a/src/sentencepiece_trainer_test.cc b/src/sentencepiece_trainer_test.cc
index ead34c1..3fa57b8 100644
--- a/src/sentencepiece_trainer_test.cc
+++ b/src/sentencepiece_trainer_test.cc
@@ -24,25 +24,26 @@ namespace sentencepiece {
namespace {
TEST(SentencePieceTrainerTest, TrainFromArgsTest) {
- std::string input = FLAGS_data_dir + "/botchan.txt";
- SentencePieceTrainer::Train(std::string("--input=") + input +
- " --model_prefix=m --vocab_size=1000");
- SentencePieceTrainer::Train(std::string("--input=") + input +
- " --model_prefix=m --vocab_size=1000 "
- "--model_type=bpe");
- SentencePieceTrainer::Train(std::string("--input=") + input +
- " --model_prefix=m --vocab_size=1000 "
- "--model_type=char");
- SentencePieceTrainer::Train(std::string("--input=") + input +
- " --model_prefix=m --vocab_size=1000 "
- "--model_type=word");
+ std::string input = util::JoinPath(FLAGS_data_dir, "botchan.txt");
+ SentencePieceTrainer::Train(string_util::StrCat(
+ "--input=", input, " --model_prefix=m --vocab_size=1000"));
+ SentencePieceTrainer::Train(string_util::StrCat(
+ "--input=", input, " --model_prefix=m --vocab_size=1000 ",
+ "--model_type=bpe"));
+ SentencePieceTrainer::Train(string_util::StrCat(
+ "--input=", input, " --model_prefix=m --vocab_size=1000 ",
+ "--model_type=char"));
+ SentencePieceTrainer::Train(string_util::StrCat(
+ "--input=", input, " --model_prefix=m --vocab_size=1000 ",
+ "--model_type=word"));
}
TEST(SentencePieceTrainerTest, TrainWithCustomNormalizationRule) {
- SentencePieceTrainer::Train("--input=" + FLAGS_data_dir +
- "/botchan.txt --model_prefix=m --vocab_size=1000 "
- "--normalization_rule_tsv=" +
- FLAGS_data_dir + "/nfkc.tsv");
+ std::string input = util::JoinPath(FLAGS_data_dir, "botchan.txt");
+ std::string rule = util::JoinPath(FLAGS_data_dir, "nfkc.tsv");
+ SentencePieceTrainer::Train(string_util::StrCat(
+ "--input=", input, " --model_prefix=m --vocab_size=1000 ",
+ "--normalization_rule_tsv=", rule));
}
TEST(SentencePieceTrainerTest, TrainErrorTest) {
@@ -55,7 +56,7 @@ TEST(SentencePieceTrainerTest, TrainErrorTest) {
TEST(SentencePieceTrainerTest, TrainTest) {
TrainerSpec trainer_spec;
- trainer_spec.add_input(FLAGS_data_dir + "/botchan.txt");
+ trainer_spec.add_input(util::JoinPath(FLAGS_data_dir, "botchan.txt"));
trainer_spec.set_model_prefix("m");
trainer_spec.set_vocab_size(1000);
NormalizerSpec normalizer_spec;
diff --git a/src/testharness.cc b/src/testharness.cc
index afdfac3..35bf893 100644
--- a/src/testharness.cc
+++ b/src/testharness.cc
@@ -66,12 +66,24 @@ int RunAllTests() {
return 0;
}
-ScopedTempFile::ScopedTempFile(const std::string &filename) {
+ScopedTempFile::ScopedTempFile(absl::string_view filename) {
char pid[64];
- snprintf(pid, sizeof(pid), "%u", getpid());
- filename_ = "/tmp/.XXX.tmp." + filename + "." + pid;
+ snprintf(pid, sizeof(pid), "%u",
+#ifdef OS_WIN
+ static_cast<uint32>(::GetCurrentProcessId())
+#else
+ ::getpid()
+#endif
+ );
+ filename_ = string_util::StrCat(".XXX.tmp.", filename, ".", pid);
}
-ScopedTempFile::~ScopedTempFile() { ::unlink(filename_.c_str()); }
+ScopedTempFile::~ScopedTempFile() {
+#ifdef OS_WIN
+ ::DeleteFile(WPATH(filename_.c_str()));
+#else
+ ::unlink(filename_.c_str());
+#endif
+}
} // namespace test
} // namespace sentencepiece
diff --git a/src/testharness.h b/src/testharness.h
index 63ff510..ceb29ac 100644
--- a/src/testharness.h
+++ b/src/testharness.h
@@ -20,6 +20,7 @@
#include <sstream>
#include <string>
#include "common.h"
+#include "third_party/absl/strings/string_view.h"
namespace sentencepiece {
namespace test {
@@ -33,7 +34,7 @@ int RunAllTests();
class ScopedTempFile {
public:
- explicit ScopedTempFile(const std::string &filename);
+ explicit ScopedTempFile(absl::string_view filename);
~ScopedTempFile();
const char *filename() const { return filename_.c_str(); }
diff --git a/src/trainer_interface.cc b/src/trainer_interface.cc
index 145b135..aa3a0e0 100644
--- a/src/trainer_interface.cc
+++ b/src/trainer_interface.cc
@@ -123,8 +123,8 @@ bool TrainerInterface::IsValidSentencePiece(
}
// Do not allow a piece to include multiple Unicode scripts
// when split_by_unicode_script() is true (default = true).
- if (prev_script != static_cast<unicode_script::ScriptType>(-1) && prev_script != s &&
- trainer_spec_.split_by_unicode_script()) {
+ if (prev_script != static_cast<unicode_script::ScriptType>(-1) &&
+ prev_script != s && trainer_spec_.split_by_unicode_script()) {
return false;
}
prev_script = s;
@@ -344,7 +344,7 @@ util::Status TrainerInterface::SaveModel(absl::string_view filename) const {
LOG(INFO) << "Saving model: " << filename;
ModelProto model_proto;
RETURN_IF_ERROR(Serialize(&model_proto));
- std::ofstream ofs(filename.data(), OUTPUT_MODE);
+ std::ofstream ofs(WPATH(filename.data()), OUTPUT_MODE);
CHECK_OR_RETURN(ofs) << "\"" << filename.data()
<< "\": " << util::StrError(errno);
CHECK_OR_RETURN(model_proto.SerializeToOstream(&ofs));
diff --git a/src/unigram_model_trainer_test.cc b/src/unigram_model_trainer_test.cc
index 153b3aa..508b7cb 100644
--- a/src/unigram_model_trainer_test.cc
+++ b/src/unigram_model_trainer_test.cc
@@ -39,19 +39,17 @@ TEST(UnigramTrainerTest, TrainerModelTest) {
TEST(UnigramTrainerTest, EndToEndTest) {
const test::ScopedTempFile sf("tmp_model");
+ const std::string input =
+ util::JoinPath(FLAGS_data_dir, "wagahaiwa_nekodearu.txt");
- EXPECT_OK(SentencePieceTrainer::Train(
- std::string("--model_prefix=") + sf.filename() +
- " --input=" + FLAGS_data_dir +
- "/wagahaiwa_nekodearu.txt"
- " --vocab_size=8000"
- " --normalization_rule_name=identity"
- " --model_type=unigram"
- " --user_defined_symbols=<user>" // Allows duplicated symbol
- " --control_symbols=<ctrl>"));
+ EXPECT_OK(SentencePieceTrainer::Train(string_util::StrCat(
+ "--model_prefix=", sf.filename(), " --input=", input,
+ " --vocab_size=8000 --normalization_rule_name=identity",
+ " --model_type=unigram --user_defined_symbols=<user>",
+ " --control_symbols=<ctrl>")));
SentencePieceProcessor sp;
- EXPECT_OK(sp.Load(std::string(sf.filename()) + ".model"));
+ EXPECT_OK(sp.Load(string_util::StrCat(sf.filename(), ".model")));
EXPECT_EQ(8000, sp.GetPieceSize());
const int cid = sp.PieceToId("<ctrl>");
@@ -71,10 +69,10 @@ TEST(UnigramTrainerTest, EndToEndTest) {
u8"。",
&tok));
EXPECT_EQ(WS
- " 吾輩 《 わが はい 》 は 猫 である 。 名前 はまだ 無い 。 "
- "どこ で 生 れた か とん と 見当 《 けん とう 》 が つか ぬ 。 "
- "何でも 薄 暗 い じめ じめ した 所で ニャーニャー "
- "泣 い ていた 事 だけは 記憶 している 。",
+ u8" 吾輩 《 わが はい 》 は 猫 である 。 名前 はまだ 無い 。 "
+ u8"どこ で 生 れた か とん と 見当 《 けん とう 》 が つか ぬ 。 "
+ u8"何でも 薄 暗 い じめ じめ した 所で ニャーニャー "
+ u8"泣 い ていた 事 だけは 記憶 している 。",
string_util::Join(tok, " "));
}
diff --git a/src/util.cc b/src/util.cc
index 5dbdd72..4bf3550 100644
--- a/src/util.cc
+++ b/src/util.cc
@@ -283,6 +283,44 @@ std::string StrError(int errnum) {
os << str << " Error #" << errnum;
return os.str();
}
-
} // namespace util
+
+#ifdef OS_WIN
+namespace win32 {
+std::wstring Utf8ToWide(const std::string &input) {
+ int output_length =
+ ::MultiByteToWideChar(CP_UTF8, 0, input.c_str(), -1, nullptr, 0);
+ output_length = output_length <= 0 ? 0 : output_length - 1;
+ if (output_length == 0) {
+ return L"";
+ }
+ std::unique_ptr<wchar_t[]> input_wide(new wchar_t[output_length + 1]);
+ const int result = ::MultiByteToWideChar(CP_UTF8, 0, input.c_str(), -1,
+ input_wide.get(), output_length + 1);
+ std::wstring output;
+ if (result > 0) {
+ output.assign(input_wide.get());
+ }
+ return output;
+}
+
+std::string WideToUtf8(const std::wstring &input) {
+ const int output_length = ::WideCharToMultiByte(CP_UTF8, 0, input.c_str(), -1,
+ nullptr, 0, nullptr, nullptr);
+ if (output_length == 0) {
+ return "";
+ }
+
+ std::unique_ptr<char[]> input_encoded(new char[output_length + 1]);
+ const int result =
+ ::WideCharToMultiByte(CP_UTF8, 0, input.c_str(), -1, input_encoded.get(),
+ output_length + 1, nullptr, nullptr);
+ std::string output;
+ if (result > 0) {
+ output.assign(input_encoded.get());
+ }
+ return output;
+}
+} // namespace win32
+#endif
} // namespace sentencepiece
diff --git a/src/util.h b/src/util.h
index e2af6e9..279338e 100644
--- a/src/util.h
+++ b/src/util.h
@@ -115,6 +115,15 @@ std::string Join(const std::vector<std::string> &tokens,
std::string Join(const std::vector<int> &tokens, absl::string_view delim);
+inline std::string StrCat(absl::string_view str) {
+ return std::string(str.data(), str.size());
+}
+
+template <typename... T>
+inline std::string StrCat(absl::string_view first, const T &... rest) {
+ return std::string(first) + StrCat(rest...);
+}
+
std::string StringReplace(absl::string_view s, absl::string_view oldsub,
absl::string_view newsub, bool replace_all);
@@ -416,6 +425,19 @@ void STLDeleteElements(std::vector<T *> *vec) {
namespace util {
+inline std::string JoinPath(absl::string_view path) {
+ return std::string(path.data(), path.size());
+}
+
+template <typename... T>
+inline std::string JoinPath(absl::string_view first, const T &... rest) {
+#ifdef OS_WIN
+ return JoinPath(first) + "\\" + JoinPath(rest...);
+#else
+ return JoinPath(first) + "/" + JoinPath(rest...);
+#endif
+}
+
std::string StrError(int errnum);
inline Status OkStatus() { return Status(); }
diff --git a/src/util_test.cc b/src/util_test.cc
index 6772b4a..420b798 100644
--- a/src/util_test.cc
+++ b/src/util_test.cc
@@ -195,6 +195,17 @@ TEST(UtilTest, JoinIntTest) {
EXPECT_EQ(string_util::Join(tokens, ""), "102-45");
}
+TEST(UtilTest, StrCatTest) {
+ EXPECT_EQ("", string_util::StrCat(""));
+ EXPECT_EQ("ab", string_util::StrCat("ab"));
+ EXPECT_EQ("ab", string_util::StrCat("ab", ""));
+ EXPECT_EQ("abc", string_util::StrCat("ab", "c"));
+ EXPECT_EQ("abc", string_util::StrCat("ab", "", "", "c"));
+ std::string a = "foo";
+ std::string b = "bar";
+ EXPECT_EQ("foobar", string_util::StrCat(a, b));
+}
+
TEST(UtilTest, StringViewTest) {
absl::string_view s;
EXPECT_EQ(0, s.find("", 0));
@@ -548,4 +559,16 @@ TEST(UtilTest, StatusTest) {
EXPECT_TRUE(s.ToString().find("message") != std::string::npos);
}
}
+
+TEST(UtilTest, JoinPathTest) {
+#ifdef OS_WIN
+ EXPECT_EQ("foo\\bar\\buz", util::JoinPath("foo", "bar", "buz"));
+ EXPECT_EQ("foo\\\\buz", util::JoinPath("foo", "", "buz"));
+#else
+ EXPECT_EQ("foo/bar/buz", util::JoinPath("foo", "bar", "buz"));
+ EXPECT_EQ("foo//buz", util::JoinPath("foo", "", "buz"));
+#endif
+ EXPECT_EQ("foo", util::JoinPath("foo"));
+ EXPECT_EQ("", util::JoinPath(""));
+}
} // namespace sentencepiece