diff options
-rw-r--r-- | src/bpe_model_test.cc | 2 | ||||
-rw-r--r-- | src/bpe_model_trainer_test.cc | 22 | ||||
-rw-r--r-- | src/builder_test.cc | 3 | ||||
-rw-r--r-- | src/sentencepiece_processor.cc | 2 | ||||
-rw-r--r-- | src/sentencepiece_processor_test.cc | 14 | ||||
-rw-r--r-- | src/sentencepiece_trainer_test.cc | 35 | ||||
-rw-r--r-- | src/testharness.cc | 20 | ||||
-rw-r--r-- | src/testharness.h | 3 | ||||
-rw-r--r-- | src/trainer_interface.cc | 6 | ||||
-rw-r--r-- | src/unigram_model_trainer_test.cc | 26 | ||||
-rw-r--r-- | src/util.cc | 40 | ||||
-rw-r--r-- | src/util.h | 22 | ||||
-rw-r--r-- | src/util_test.cc | 23 |
13 files changed, 158 insertions, 60 deletions
diff --git a/src/bpe_model_test.cc b/src/bpe_model_test.cc index cc60880..4b067f6 100644 --- a/src/bpe_model_test.cc +++ b/src/bpe_model_test.cc @@ -103,7 +103,7 @@ TEST(BPEModelTest, EncodeTest) { EXPECT_EQ("d", result[6].first); // all unknown. - result = model.Encode("xyz東京"); + result = model.Encode(u8"xyz東京"); EXPECT_EQ(5, result.size()); EXPECT_EQ("x", result[0].first); EXPECT_EQ("y", result[1].first); diff --git a/src/bpe_model_trainer_test.cc b/src/bpe_model_trainer_test.cc index 2061952..7a9c17d 100644 --- a/src/bpe_model_trainer_test.cc +++ b/src/bpe_model_trainer_test.cc @@ -89,15 +89,13 @@ TEST(BPETrainerTest, BasicTest) { TEST(BPETrainerTest, EndToEndTest) { const test::ScopedTempFile sf("tmp_model"); + const std::string input = + util::JoinPath(FLAGS_data_dir, "wagahaiwa_nekodearu.txt"); - EXPECT_OK(SentencePieceTrainer::Train(std::string("--model_prefix=") + - sf.filename() + - " --input=" + FLAGS_data_dir + - "/wagahaiwa_nekodearu.txt" - " --vocab_size=8000" - " --normalization_rule_name=identity" - " --model_type=bpe" - " --control_symbols=<ctrl>")); + EXPECT_OK(SentencePieceTrainer::Train(string_util::StrCat( + "--model_prefix=", sf.filename(), " --input=", input, + " --vocab_size=8000 --normalization_rule_name=identity" + " --model_type=bpe --control_symbols=<ctrl>"))); SentencePieceProcessor sp; EXPECT_OK(sp.Load(std::string(sf.filename()) + ".model")); @@ -117,10 +115,10 @@ TEST(BPETrainerTest, EndToEndTest) { u8"。", &tok)); EXPECT_EQ(WS - " 吾輩 《 わが はい 》 は猫 である 。 名前 はまだ 無い 。 " - "どこで 生 れた か とん と見 当 《 けんとう 》 が つかぬ 。 " - "何でも 薄 暗 いじ め じ め した 所で ニャー ニャー 泣 いていた " - "事 だけは 記憶 している 。", + u8" 吾輩 《 わが はい 》 は猫 である 。 名前 はまだ 無い 。 " + u8"どこで 生 れた か とん と見 当 《 けんとう 》 が つかぬ 。 " + u8"何でも 薄 暗 いじ め じ め した 所で ニャー ニャー 泣 いていた " + u8"事 だけは 記憶 している 。", string_util::Join(tok, " ")); } diff --git a/src/builder_test.cc b/src/builder_test.cc index fd2d06c..e76fe40 100644 --- a/src/builder_test.cc +++ b/src/builder_test.cc @@ -138,7 +138,8 @@ TEST(BuilderTest, CompileCharsMap) { TEST(BuilderTest, LoadCharsMapTest) { Builder::CharsMap chars_map; - EXPECT_OK(Builder::LoadCharsMap(FLAGS_data_dir + "/nfkc.tsv", &chars_map)); + EXPECT_OK(Builder::LoadCharsMap(util::JoinPath(FLAGS_data_dir, "nfkc.tsv"), + &chars_map)); std::string precompiled, expected; EXPECT_OK(Builder::CompileCharsMap(chars_map, &precompiled)); diff --git a/src/sentencepiece_processor.cc b/src/sentencepiece_processor.cc index 28c467c..872db63 100644 --- a/src/sentencepiece_processor.cc +++ b/src/sentencepiece_processor.cc @@ -43,7 +43,7 @@ SentencePieceProcessor::SentencePieceProcessor() {} SentencePieceProcessor::~SentencePieceProcessor() {} util::Status SentencePieceProcessor::Load(util::min_string_view filename) { - std::ifstream ifs(filename.data(), std::ios::binary | std::ios::in); + std::ifstream ifs(WPATH(filename.data()), std::ios::binary | std::ios::in); if (!ifs) { return util::StatusBuilder(util::error::NOT_FOUND) << "\"" << filename.data() << "\": " << util::StrError(errno); diff --git a/src/sentencepiece_processor_test.cc b/src/sentencepiece_processor_test.cc index 4e6797d..30bdf8f 100644 --- a/src/sentencepiece_processor_test.cc +++ b/src/sentencepiece_processor_test.cc @@ -334,8 +334,10 @@ TEST(SentencepieceProcessorTest, NBestEncodeTest) { auto mock = MakeUnique<MockModel>(); const NBestEncodeResult result = { - {{{WS "ABC", 3}, {WS "DE", 4}, {"F", 0}, {"</s>", 2}}, 1.0}, - {{{WS "AB", 5}, {WS "CD", 6}, {"EF", 7}, {"</s>", 2}}, 0.9}}; + {{{WS "ABC", 3}, {WS "DE", 4}, {"F", 0}, {"</s>", 2}}, + static_cast<float>(1.0)}, + {{{WS "AB", 5}, {WS "CD", 6}, {"EF", 7}, {"</s>", 2}}, + static_cast<float>(0.9)}}; mock->SetNBestEncodeResult(kInput, result); sp.SetModel(std::move(mock)); @@ -382,8 +384,10 @@ TEST(SentencepieceProcessorTest, SampleEncodeTest) { const EncodeResult result = { {WS "ABC", 3}, {WS "DE", 4}, {"F", 0}, {"</s>", 2}}; const NBestEncodeResult nbest_result = { - {{{WS "ABC", 3}, {WS "DE", 4}, {"F", 0}, {"</s>", 2}}, 1.0}, - {{{WS "AB", 5}, {WS "CD", 6}, {"EF", 7}, {"</s>", 2}}, 0.1}}; + {{{WS "ABC", 3}, {WS "DE", 4}, {"F", 0}, {"</s>", 2}}, + static_cast<float>(1.0)}, + {{{WS "AB", 5}, {WS "CD", 6}, {"EF", 7}, {"</s>", 2}}, + static_cast<float>(0.1)}}; mock->SetNBestEncodeResult(kInput, nbest_result); mock->SetEncodeResult(kInput, result); @@ -627,7 +631,7 @@ TEST(SentencePieceProcessorTest, EndToEndTest) { test::ScopedTempFile sf("model"); { - std::ofstream ofs(sf.filename(), OUTPUT_MODE); + std::ofstream ofs(WPATH(sf.filename()), OUTPUT_MODE); CHECK(model_proto.SerializeToOstream(&ofs)); } diff --git a/src/sentencepiece_trainer_test.cc b/src/sentencepiece_trainer_test.cc index ead34c1..3fa57b8 100644 --- a/src/sentencepiece_trainer_test.cc +++ b/src/sentencepiece_trainer_test.cc @@ -24,25 +24,26 @@ namespace sentencepiece { namespace { TEST(SentencePieceTrainerTest, TrainFromArgsTest) { - std::string input = FLAGS_data_dir + "/botchan.txt"; - SentencePieceTrainer::Train(std::string("--input=") + input + - " --model_prefix=m --vocab_size=1000"); - SentencePieceTrainer::Train(std::string("--input=") + input + - " --model_prefix=m --vocab_size=1000 " - "--model_type=bpe"); - SentencePieceTrainer::Train(std::string("--input=") + input + - " --model_prefix=m --vocab_size=1000 " - "--model_type=char"); - SentencePieceTrainer::Train(std::string("--input=") + input + - " --model_prefix=m --vocab_size=1000 " - "--model_type=word"); + std::string input = util::JoinPath(FLAGS_data_dir, "botchan.txt"); + SentencePieceTrainer::Train(string_util::StrCat( + "--input=", input, " --model_prefix=m --vocab_size=1000")); + SentencePieceTrainer::Train(string_util::StrCat( + "--input=", input, " --model_prefix=m --vocab_size=1000 ", + "--model_type=bpe")); + SentencePieceTrainer::Train(string_util::StrCat( + "--input=", input, " --model_prefix=m --vocab_size=1000 ", + "--model_type=char")); + SentencePieceTrainer::Train(string_util::StrCat( + "--input=", input, " --model_prefix=m --vocab_size=1000 ", + "--model_type=word")); } TEST(SentencePieceTrainerTest, TrainWithCustomNormalizationRule) { - SentencePieceTrainer::Train("--input=" + FLAGS_data_dir + - "/botchan.txt --model_prefix=m --vocab_size=1000 " - "--normalization_rule_tsv=" + - FLAGS_data_dir + "/nfkc.tsv"); + std::string input = util::JoinPath(FLAGS_data_dir, "botchan.txt"); + std::string rule = util::JoinPath(FLAGS_data_dir, "nfkc.tsv"); + SentencePieceTrainer::Train(string_util::StrCat( + "--input=", input, " --model_prefix=m --vocab_size=1000 ", + "--normalization_rule_tsv=", rule)); } TEST(SentencePieceTrainerTest, TrainErrorTest) { @@ -55,7 +56,7 @@ TEST(SentencePieceTrainerTest, TrainErrorTest) { TEST(SentencePieceTrainerTest, TrainTest) { TrainerSpec trainer_spec; - trainer_spec.add_input(FLAGS_data_dir + "/botchan.txt"); + trainer_spec.add_input(util::JoinPath(FLAGS_data_dir, "botchan.txt")); trainer_spec.set_model_prefix("m"); trainer_spec.set_vocab_size(1000); NormalizerSpec normalizer_spec; diff --git a/src/testharness.cc b/src/testharness.cc index afdfac3..35bf893 100644 --- a/src/testharness.cc +++ b/src/testharness.cc @@ -66,12 +66,24 @@ int RunAllTests() { return 0; } -ScopedTempFile::ScopedTempFile(const std::string &filename) { +ScopedTempFile::ScopedTempFile(absl::string_view filename) { char pid[64]; - snprintf(pid, sizeof(pid), "%u", getpid()); - filename_ = "/tmp/.XXX.tmp." + filename + "." + pid; + snprintf(pid, sizeof(pid), "%u", +#ifdef OS_WIN + static_cast<uint32>(::GetCurrentProcessId()) +#else + ::getpid() +#endif + ); + filename_ = string_util::StrCat(".XXX.tmp.", filename, ".", pid); } -ScopedTempFile::~ScopedTempFile() { ::unlink(filename_.c_str()); } +ScopedTempFile::~ScopedTempFile() { +#ifdef OS_WIN + ::DeleteFile(WPATH(filename_.c_str())); +#else + ::unlink(filename_.c_str()); +#endif +} } // namespace test } // namespace sentencepiece diff --git a/src/testharness.h b/src/testharness.h index 63ff510..ceb29ac 100644 --- a/src/testharness.h +++ b/src/testharness.h @@ -20,6 +20,7 @@ #include <sstream> #include <string> #include "common.h" +#include "third_party/absl/strings/string_view.h" namespace sentencepiece { namespace test { @@ -33,7 +34,7 @@ int RunAllTests(); class ScopedTempFile { public: - explicit ScopedTempFile(const std::string &filename); + explicit ScopedTempFile(absl::string_view filename); ~ScopedTempFile(); const char *filename() const { return filename_.c_str(); } diff --git a/src/trainer_interface.cc b/src/trainer_interface.cc index 145b135..aa3a0e0 100644 --- a/src/trainer_interface.cc +++ b/src/trainer_interface.cc @@ -123,8 +123,8 @@ bool TrainerInterface::IsValidSentencePiece( } // Do not allow a piece to include multiple Unicode scripts // when split_by_unicode_script() is true (default = true). - if (prev_script != static_cast<unicode_script::ScriptType>(-1) && prev_script != s && - trainer_spec_.split_by_unicode_script()) { + if (prev_script != static_cast<unicode_script::ScriptType>(-1) && + prev_script != s && trainer_spec_.split_by_unicode_script()) { return false; } prev_script = s; @@ -344,7 +344,7 @@ util::Status TrainerInterface::SaveModel(absl::string_view filename) const { LOG(INFO) << "Saving model: " << filename; ModelProto model_proto; RETURN_IF_ERROR(Serialize(&model_proto)); - std::ofstream ofs(filename.data(), OUTPUT_MODE); + std::ofstream ofs(WPATH(filename.data()), OUTPUT_MODE); CHECK_OR_RETURN(ofs) << "\"" << filename.data() << "\": " << util::StrError(errno); CHECK_OR_RETURN(model_proto.SerializeToOstream(&ofs)); diff --git a/src/unigram_model_trainer_test.cc b/src/unigram_model_trainer_test.cc index 153b3aa..508b7cb 100644 --- a/src/unigram_model_trainer_test.cc +++ b/src/unigram_model_trainer_test.cc @@ -39,19 +39,17 @@ TEST(UnigramTrainerTest, TrainerModelTest) { TEST(UnigramTrainerTest, EndToEndTest) { const test::ScopedTempFile sf("tmp_model"); + const std::string input = + util::JoinPath(FLAGS_data_dir, "wagahaiwa_nekodearu.txt"); - EXPECT_OK(SentencePieceTrainer::Train( - std::string("--model_prefix=") + sf.filename() + - " --input=" + FLAGS_data_dir + - "/wagahaiwa_nekodearu.txt" - " --vocab_size=8000" - " --normalization_rule_name=identity" - " --model_type=unigram" - " --user_defined_symbols=<user>" // Allows duplicated symbol - " --control_symbols=<ctrl>")); + EXPECT_OK(SentencePieceTrainer::Train(string_util::StrCat( + "--model_prefix=", sf.filename(), " --input=", input, + " --vocab_size=8000 --normalization_rule_name=identity", + " --model_type=unigram --user_defined_symbols=<user>", + " --control_symbols=<ctrl>"))); SentencePieceProcessor sp; - EXPECT_OK(sp.Load(std::string(sf.filename()) + ".model")); + EXPECT_OK(sp.Load(string_util::StrCat(sf.filename(), ".model"))); EXPECT_EQ(8000, sp.GetPieceSize()); const int cid = sp.PieceToId("<ctrl>"); @@ -71,10 +69,10 @@ TEST(UnigramTrainerTest, EndToEndTest) { u8"。", &tok)); EXPECT_EQ(WS - " 吾輩 《 わが はい 》 は 猫 である 。 名前 はまだ 無い 。 " - "どこ で 生 れた か とん と 見当 《 けん とう 》 が つか ぬ 。 " - "何でも 薄 暗 い じめ じめ した 所で ニャーニャー " - "泣 い ていた 事 だけは 記憶 している 。", + u8" 吾輩 《 わが はい 》 は 猫 である 。 名前 はまだ 無い 。 " + u8"どこ で 生 れた か とん と 見当 《 けん とう 》 が つか ぬ 。 " + u8"何でも 薄 暗 い じめ じめ した 所で ニャーニャー " + u8"泣 い ていた 事 だけは 記憶 している 。", string_util::Join(tok, " ")); } diff --git a/src/util.cc b/src/util.cc index 5dbdd72..4bf3550 100644 --- a/src/util.cc +++ b/src/util.cc @@ -283,6 +283,44 @@ std::string StrError(int errnum) { os << str << " Error #" << errnum; return os.str(); } - } // namespace util + +#ifdef OS_WIN +namespace win32 { +std::wstring Utf8ToWide(const std::string &input) { + int output_length = + ::MultiByteToWideChar(CP_UTF8, 0, input.c_str(), -1, nullptr, 0); + output_length = output_length <= 0 ? 0 : output_length - 1; + if (output_length == 0) { + return L""; + } + std::unique_ptr<wchar_t[]> input_wide(new wchar_t[output_length + 1]); + const int result = ::MultiByteToWideChar(CP_UTF8, 0, input.c_str(), -1, + input_wide.get(), output_length + 1); + std::wstring output; + if (result > 0) { + output.assign(input_wide.get()); + } + return output; +} + +std::string WideToUtf8(const std::wstring &input) { + const int output_length = ::WideCharToMultiByte(CP_UTF8, 0, input.c_str(), -1, + nullptr, 0, nullptr, nullptr); + if (output_length == 0) { + return ""; + } + + std::unique_ptr<char[]> input_encoded(new char[output_length + 1]); + const int result = + ::WideCharToMultiByte(CP_UTF8, 0, input.c_str(), -1, input_encoded.get(), + output_length + 1, nullptr, nullptr); + std::string output; + if (result > 0) { + output.assign(input_encoded.get()); + } + return output; +} +} // namespace win32 +#endif } // namespace sentencepiece @@ -115,6 +115,15 @@ std::string Join(const std::vector<std::string> &tokens, std::string Join(const std::vector<int> &tokens, absl::string_view delim); +inline std::string StrCat(absl::string_view str) { + return std::string(str.data(), str.size()); +} + +template <typename... T> +inline std::string StrCat(absl::string_view first, const T &... rest) { + return std::string(first) + StrCat(rest...); +} + std::string StringReplace(absl::string_view s, absl::string_view oldsub, absl::string_view newsub, bool replace_all); @@ -416,6 +425,19 @@ void STLDeleteElements(std::vector<T *> *vec) { namespace util { +inline std::string JoinPath(absl::string_view path) { + return std::string(path.data(), path.size()); +} + +template <typename... T> +inline std::string JoinPath(absl::string_view first, const T &... rest) { +#ifdef OS_WIN + return JoinPath(first) + "\\" + JoinPath(rest...); +#else + return JoinPath(first) + "/" + JoinPath(rest...); +#endif +} + std::string StrError(int errnum); inline Status OkStatus() { return Status(); } diff --git a/src/util_test.cc b/src/util_test.cc index 6772b4a..420b798 100644 --- a/src/util_test.cc +++ b/src/util_test.cc @@ -195,6 +195,17 @@ TEST(UtilTest, JoinIntTest) { EXPECT_EQ(string_util::Join(tokens, ""), "102-45"); } +TEST(UtilTest, StrCatTest) { + EXPECT_EQ("", string_util::StrCat("")); + EXPECT_EQ("ab", string_util::StrCat("ab")); + EXPECT_EQ("ab", string_util::StrCat("ab", "")); + EXPECT_EQ("abc", string_util::StrCat("ab", "c")); + EXPECT_EQ("abc", string_util::StrCat("ab", "", "", "c")); + std::string a = "foo"; + std::string b = "bar"; + EXPECT_EQ("foobar", string_util::StrCat(a, b)); +} + TEST(UtilTest, StringViewTest) { absl::string_view s; EXPECT_EQ(0, s.find("", 0)); @@ -548,4 +559,16 @@ TEST(UtilTest, StatusTest) { EXPECT_TRUE(s.ToString().find("message") != std::string::npos); } } + +TEST(UtilTest, JoinPathTest) { +#ifdef OS_WIN + EXPECT_EQ("foo\\bar\\buz", util::JoinPath("foo", "bar", "buz")); + EXPECT_EQ("foo\\\\buz", util::JoinPath("foo", "", "buz")); +#else + EXPECT_EQ("foo/bar/buz", util::JoinPath("foo", "bar", "buz")); + EXPECT_EQ("foo//buz", util::JoinPath("foo", "", "buz")); +#endif + EXPECT_EQ("foo", util::JoinPath("foo")); + EXPECT_EQ("", util::JoinPath("")); +} } // namespace sentencepiece |