Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/sentencepiece.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'src/model_interface.cc')
-rw-r--r--src/model_interface.cc27
1 files changed, 22 insertions, 5 deletions
diff --git a/src/model_interface.cc b/src/model_interface.cc
index c69cecc..e46a632 100644
--- a/src/model_interface.cc
+++ b/src/model_interface.cc
@@ -20,15 +20,32 @@
namespace sentencepiece {
-const char *ModelInterface::kUNK() { return "<unk>"; }
-const char *ModelInterface::kBOS() { return "<s>"; }
-const char *ModelInterface::kEOS() { return "</s>"; }
-const char *ModelInterface::kPAD() { return "<pad>"; };
-
ModelInterface::ModelInterface(const ModelProto &model_proto)
: model_proto_(&model_proto), status_(util::OkStatus()) {}
ModelInterface::~ModelInterface() {}
+#define RETURN_PIECE(name, default_value) \
+ if (model_proto_->trainer_spec().name().empty()) return default_value; \
+ return model_proto_->trainer_spec().name();
+
+absl::string_view ModelInterface::unk_piece() const {
+ RETURN_PIECE(unk_piece, "<unk>");
+}
+
+absl::string_view ModelInterface::bos_piece() const {
+ RETURN_PIECE(bos_piece, "<s>");
+}
+
+absl::string_view ModelInterface::eos_piece() const {
+ RETURN_PIECE(eos_piece, "</s>");
+}
+
+absl::string_view ModelInterface::pad_piece() const {
+ RETURN_PIECE(pad_piece, "<pad>");
+}
+
+#undef RETURN_PIECE
+
int ModelInterface::PieceToId(absl::string_view piece) const {
auto it = reserved_id_map_.find(piece);
if (it != reserved_id_map_.end()) {