Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/sentencepiece.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
path: root/src
diff options
context:
space:
mode:
author: Taku Kudo <taku@google.com> 2018-07-12 19:24:46 +0300
committer: Taku Kudo <taku@google.com> 2018-07-12 19:24:46 +0300
commit: 256c6f5bb731c567c897999e4dca35e171f3b212 (patch)
tree: 7684b61ae6a175a3e5fc6d9f3423261bae7fb6b1 /src
parent: 983c0f5aeb26d6963c3adef94b12e2ea1595dac9 (diff)
Added new API to get bos/eos/unk/pad ids
Diffstat (limited to 'src')
-rw-r--r-- src/model_interface.cc | 5
-rw-r--r-- src/model_interface.h | 5
-rw-r--r-- src/sentencepiece_processor.cc | 25
-rw-r--r-- src/sentencepiece_processor.h | 15
-rw-r--r-- src/sentencepiece_processor_test.cc | 5
-rw-r--r-- src/trainer_interface.cc | 34
-rw-r--r-- src/trainer_interface.h | 5
7 files changed, 70 insertions, 24 deletions
diff --git a/src/model_interface.cc b/src/model_interface.cc
index 5cbb1a5..2e9d685 100644
--- a/src/model_interface.cc
+++ b/src/model_interface.cc
@@ -70,6 +70,11 @@ std::string PrefixMatcher::GlobalReplace(absl::string_view w,
return result;
}
+const char ModelInterface::kUNK[] = "<unk>";
+const char ModelInterface::kBOS[] = "<s>";
+const char ModelInterface::kEOS[] = "</s>";
+const char ModelInterface::kPAD[] = "<pad>";
+
ModelInterface::ModelInterface(const ModelProto &model_proto)
: model_proto_(&model_proto), status_(util::OkStatus()) {}
ModelInterface::~ModelInterface() {}
diff --git a/src/model_interface.h b/src/model_interface.h
index 04b733c..e1bdada 100644
--- a/src/model_interface.h
+++ b/src/model_interface.h
@@ -65,6 +65,11 @@ class ModelInterface {
using PieceToIdMap =
std::unordered_map<absl::string_view, int, string_util::string_view_hash>;
+ static const char kUNK[];
+ static const char kBOS[];
+ static const char kEOS[];
+ static const char kPAD[];
+
// `model_proto` should not be deleted until ModelInterface is destroyed.
explicit ModelInterface(const ModelProto &model_proto);
ModelInterface() {}
diff --git a/src/sentencepiece_processor.cc b/src/sentencepiece_processor.cc
index 40ffcb6..c6b9ed5 100644
--- a/src/sentencepiece_processor.cc
+++ b/src/sentencepiece_processor.cc
@@ -21,6 +21,7 @@
#include "common.h"
#include "model_factory.h"
+#include "model_interface.h"
#include "normalizer.h"
#include "sentencepiece.pb.h"
#include "unigram_model.h"
@@ -521,6 +522,30 @@ bool SentencePieceProcessor::IsUnused(int id) const {
return model_->IsUnused(id);
}
+int SentencePieceProcessor::unk_id() const {
+ const int id = PieceToId(ModelInterface::kUNK);
+ if (IsUnknown(id)) return id;
+ return -1;
+}
+
+int SentencePieceProcessor::bos_id() const {
+ const int id = PieceToId(ModelInterface::kBOS);
+ if (IsControl(id)) return id;
+ return -1;
+}
+
+int SentencePieceProcessor::eos_id() const {
+ const int id = PieceToId(ModelInterface::kEOS);
+ if (IsControl(id)) return id;
+ return -1;
+}
+
+int SentencePieceProcessor::pad_id() const {
+ const int id = PieceToId(ModelInterface::kPAD);
+ if (IsControl(id)) return id;
+ return -1;
+}
+
// static
util::Status SentencePieceProcessor::ApplyExtraOptions(
const std::vector<ExtraOption> &extra_options,
diff --git a/src/sentencepiece_processor.h b/src/sentencepiece_processor.h
index e795921..1cd6c54 100644
--- a/src/sentencepiece_processor.h
+++ b/src/sentencepiece_processor.h
@@ -384,6 +384,21 @@ class SentencePieceProcessor {
// Returns true if `id` is unused symbol.
virtual bool IsUnused(int id) const;
+ // Returns the reserved id.
+ // Returns -1 if not defined.
+
+ // Returns unknown (<unk>) id.
+ virtual int unk_id() const;
+
+ // Returns BOS (<s>) id.
+ virtual int bos_id() const;
+
+ // Returns EOS (</s>) id.
+ virtual int eos_id() const;
+
+ // Returns PAD (<pad>) id.
+ virtual int pad_id() const;
+
#ifndef SWIG
//////////////////////////////////////////////////////////////
// Model management.
diff --git a/src/sentencepiece_processor_test.cc b/src/sentencepiece_processor_test.cc
index 0a960b0..b0f1eb9 100644
--- a/src/sentencepiece_processor_test.cc
+++ b/src/sentencepiece_processor_test.cc
@@ -624,6 +624,11 @@ TEST(SentencePieceProcessorTest, EndToEndTest) {
EXPECT_FALSE(sp.IsControl(6));
EXPECT_FALSE(sp.IsControl(7));
+ EXPECT_EQ(0, sp.unk_id());
+ EXPECT_EQ(1, sp.bos_id());
+ EXPECT_EQ(2, sp.eos_id());
+ EXPECT_EQ(-1, sp.pad_id());
+
{
std::vector<std::string> sps;
const std::vector<std::string> expected_str = {WS, "ab", "c"};
diff --git a/src/trainer_interface.cc b/src/trainer_interface.cc
index 33625b4..646bc1d 100644
--- a/src/trainer_interface.cc
+++ b/src/trainer_interface.cc
@@ -41,11 +41,6 @@ const char TrainerInterface::kUNKStr[] = "\xe2\x96\x85";
const char32 TrainerInterface::kUPPBoundaryChar = L'\u0009';
const char TrainerInterface::kUPPBoundaryStr[] = "\t";
-const char TrainerInterface::kUNK[] = "<unk>";
-const char TrainerInterface::kBOS[] = "<s>";
-const char TrainerInterface::kEOS[] = "</s>";
-const char TrainerInterface::kPAD[] = "<pad>";
-
namespace {
util::Status VerifySpec(const TrainerSpec &trainer_spec) {
CHECK_OR_RETURN(!trainer_spec.model_prefix().empty());
@@ -385,21 +380,22 @@ util::Status TrainerInterface::InitMetaPieces() {
auto insert_id = [&has_unk, this](int id, const std::string &w) -> bool {
if (id < 0) return true;
if (id >= trainer_spec_.vocab_size() ||
- meta_pieces_.find(id) != meta_pieces_.end() || (has_unk && w == kUNK))
+ meta_pieces_.find(id) != meta_pieces_.end() ||
+ (has_unk && w == ModelInterface::kUNK))
return false;
- if (w == kUNK) has_unk = true;
- meta_pieces_[id] =
- std::make_pair(w, w == kUNK ? ModelProto::SentencePiece::UNKNOWN
- : ModelProto::SentencePiece::CONTROL);
+ if (w == ModelInterface::kUNK) has_unk = true;
+ meta_pieces_[id] = std::make_pair(
+ w, w == ModelInterface::kUNK ? ModelProto::SentencePiece::UNKNOWN
+ : ModelProto::SentencePiece::CONTROL);
return true;
};
- CHECK_OR_RETURN(insert_id(trainer_spec_.unk_id(), kUNK));
- CHECK_OR_RETURN(insert_id(trainer_spec_.bos_id(), kBOS));
- CHECK_OR_RETURN(insert_id(trainer_spec_.eos_id(), kEOS));
- CHECK_OR_RETURN(insert_id(trainer_spec_.pad_id(), kPAD));
+ CHECK_OR_RETURN(insert_id(trainer_spec_.unk_id(), ModelInterface::kUNK));
+ CHECK_OR_RETURN(insert_id(trainer_spec_.bos_id(), ModelInterface::kBOS));
+ CHECK_OR_RETURN(insert_id(trainer_spec_.eos_id(), ModelInterface::kEOS));
+ CHECK_OR_RETURN(insert_id(trainer_spec_.pad_id(), ModelInterface::kPAD));
- CHECK_OR_RETURN(has_unk) << kUNK << " must be defined.";
+ CHECK_OR_RETURN(has_unk) << ModelInterface::kUNK << " must be defined.";
std::set<std::string> dup;
@@ -412,17 +408,17 @@ util::Status TrainerInterface::InitMetaPieces() {
return false;
}
- if (w == kUNK) {
+ if (w == ModelInterface::kUNK) {
LOG(ERROR) << "<unk> must not be defined with --control_symbols and "
"--user_defined_symbols.";
return false;
}
- if (w == kBOS && trainer_spec_.bos_id() >= 0) {
+ if (w == ModelInterface::kBOS && trainer_spec_.bos_id() >= 0) {
meta_pieces_[trainer_spec_.bos_id()].second = type;
- } else if (w == kEOS && trainer_spec_.eos_id() >= 0) {
+ } else if (w == ModelInterface::kEOS && trainer_spec_.eos_id() >= 0) {
meta_pieces_[trainer_spec_.eos_id()].second = type;
- } else if (w == kPAD && trainer_spec_.pad_id() >= 0) {
+ } else if (w == ModelInterface::kPAD && trainer_spec_.pad_id() >= 0) {
meta_pieces_[trainer_spec_.pad_id()].second = type;
} else {
while (meta_pieces_.find(id) != meta_pieces_.end()) ++id;
diff --git a/src/trainer_interface.h b/src/trainer_interface.h
index b4abb12..5f3b1ae 100644
--- a/src/trainer_interface.h
+++ b/src/trainer_interface.h
@@ -59,11 +59,6 @@ class TrainerInterface {
static const char kUNKStr[];
static const char kUPPBoundaryStr[];
- static const char kUNK[];
- static const char kBOS[];
- static const char kEOS[];
- static const char kPAD[];
-
TrainerInterface(const TrainerSpec &trainer_spec,
const NormalizerSpec &normalizer_spec);