github.com/marian-nmt/sentencepiece.git
author     Taku Kudo <taku@google.com>    2018-05-06 18:51:20 +0300
committer  Taku Kudo <taku@google.com>    2018-05-06 18:51:20 +0300
commit     cf0eb82d65850172a4661e47668938ab82cb2c76 (patch)
tree       fa6414db5714c6fd15f725d5a9e28cbcb14e20c2 /src
parent     31153b117294830ab41ff3e9ee4f0a7323f16d8d (diff)
CHECK to Status migration for Trainer.
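
This commit migrates the Trainer classes from fatal CHECK* assertions to status propagation: Train(), LoadSentences(), Serialize(), and the Save*() helpers now return util::Status, construction-time validation is recorded in a status_ member, and the new CHECK_*_OR_RETURN macros (built on util::StatusBuilder) turn failed checks into INTERNAL statuses instead of aborting. A minimal sketch of the pattern, using the macro names introduced in util.h and common.h below; the class shown is illustrative, not the exact code:

    // Before: a failed check aborts the whole process.
    void Trainer::Train() {
      CHECK(normalizer_spec_.escape_whitespaces());
      LoadSentences();
      Save();
    }

    // After: failures are reported to the caller as util::Status.
    util::Status Trainer::Train() {
      RETURN_IF_ERROR(status());  // errors recorded at construction time
      CHECK_OR_RETURN(normalizer_spec_.escape_whitespaces());  // returns an INTERNAL status on failure
      RETURN_IF_ERROR(LoadSentences());
      return Save();  // Save() itself now returns util::Status
    }

Callers forward the result (for example, SentencePieceTrainer::Train() wraps the call in RETURN_IF_ERROR), the command-line tools assert it with CHECK_OK, and tests use EXPECT_OK / EXPECT_NOT_OK instead of EXPECT_DEATH.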
Diffstat (limited to 'src')
-rw-r--r--  src/bpe_model_trainer.cc          20
-rw-r--r--  src/bpe_model_trainer.h            2
-rw-r--r--  src/bpe_model_trainer_test.cc      2
-rw-r--r--  src/char_model_trainer.cc         16
-rw-r--r--  src/char_model_trainer.h           2
-rw-r--r--  src/char_model_trainer_test.cc     2
-rw-r--r--  src/common.h                      19
-rw-r--r--  src/error.cc                       5
-rw-r--r--  src/model_interface.cc             7
-rw-r--r--  src/sentencepiece_processor.cc   107
-rw-r--r--  src/sentencepiece_processor.h      5
-rw-r--r--  src/sentencepiece_trainer.cc      60
-rw-r--r--  src/spm_decode_main.cc             2
-rw-r--r--  src/spm_encode_main.cc             2
-rw-r--r--  src/spm_export_vocab_main.cc       3
-rw-r--r--  src/spm_normalize_main.cc          4
-rw-r--r--  src/trainer_interface.cc         146
-rw-r--r--  src/trainer_interface.h           26
-rw-r--r--  src/trainer_interface_test.cc     13
-rw-r--r--  src/unigram_model.cc               6
-rw-r--r--  src/unigram_model_trainer.cc      12
-rw-r--r--  src/unigram_model_trainer.h        2
-rw-r--r--  src/unigram_model_trainer_test.cc  2
-rw-r--r--  src/util.cc                       20
-rw-r--r--  src/util.h                        41
-rw-r--r--  src/util_test.cc                   3
-rw-r--r--  src/word_model_trainer.cc         16
-rw-r--r--  src/word_model_trainer.h           2
-rw-r--r--  src/word_model_trainer_test.cc     4
29 files changed, 322 insertions, 229 deletions
diff --git a/src/bpe_model_trainer.cc b/src/bpe_model_trainer.cc
index 078ca56..e2ffb42 100644
--- a/src/bpe_model_trainer.cc
+++ b/src/bpe_model_trainer.cc
@@ -167,11 +167,13 @@ void Trainer::UpdateActiveSymbols() {
active_symbols_.insert(symbols.begin(), symbols.begin() + size);
}
-void Trainer::Train() {
+util::Status Trainer::Train() {
+ RETURN_IF_ERROR(status());
+
LOG(INFO) << "Starts training with : \n" << trainer_spec_.Utf8DebugString();
- CHECK(normalizer_spec_.escape_whitespaces());
- CHECK_EQ(TrainerSpec::BPE, trainer_spec_.model_type());
+ CHECK_OR_RETURN(normalizer_spec_.escape_whitespaces());
+ CHECK_EQ_OR_RETURN(TrainerSpec::BPE, trainer_spec_.model_type());
symbols_.clear();
allocated_.clear();
@@ -179,7 +181,7 @@ void Trainer::Train() {
active_symbols_.clear();
// Load all sentences
- LoadSentences();
+ RETURN_IF_ERROR(LoadSentences());
if (trainer_spec_.split_by_whitespace()) {
SplitSentencesByWhitespace();
@@ -202,7 +204,7 @@ void Trainer::Train() {
const int vocab_size =
trainer_spec_.vocab_size() - meta_pieces_.size() - required_chars_.size();
- CHECK_GE(vocab_size, 0);
+ CHECK_GE_OR_RETURN(vocab_size, 0);
// We may see duplicated pieces that are extracted with different path.
// In real segmentation phase, we can consider them as one symbol.
@@ -210,7 +212,7 @@ void Trainer::Train() {
std::unordered_set<std::string> dup;
// Main loop.
- CHECK(final_pieces_.empty());
+ CHECK_OR_RETURN(final_pieces_.empty());
while (final_pieces_.size() < static_cast<size_t>(vocab_size)) {
constexpr int kUpdateActiveSymbolsInteval = 100;
if (final_pieces_.size() % kUpdateActiveSymbolsInteval == 0) {
@@ -269,7 +271,7 @@ void Trainer::Train() {
// when left_symbol == right_symbol.
continue;
}
- CHECK_NOTNULL(symbols_[pos.sid][pos.right]);
+ CHECK_OR_RETURN(symbols_[pos.sid][pos.right]);
// We have three bigrams [prev, left], [left, right], [right, next],
// which are affected with this symbol replacement.
@@ -301,9 +303,9 @@ void Trainer::Train() {
-static_cast<float>(final_pieces_.size()));
}
- Save();
-
port::STLDeleteElements(&allocated_);
+
+ return Save();
}
} // namespace bpe
} // namespace sentencepiece
diff --git a/src/bpe_model_trainer.h b/src/bpe_model_trainer.h
index 41e01a8..30056ae 100644
--- a/src/bpe_model_trainer.h
+++ b/src/bpe_model_trainer.h
@@ -29,7 +29,7 @@ class Trainer : public TrainerInterface {
const NormalizerSpec &normalizer_spec)
: TrainerInterface::TrainerInterface(trainer_spec, normalizer_spec) {}
- void Train() override;
+ util::Status Train() override;
private:
// Symbol represents a character or symbol bigram.
diff --git a/src/bpe_model_trainer_test.cc b/src/bpe_model_trainer_test.cc
index 2336dd5..0610ae5 100644
--- a/src/bpe_model_trainer_test.cc
+++ b/src/bpe_model_trainer_test.cc
@@ -51,7 +51,7 @@ std::string RunTrainer(const std::vector<std::string> &input, int size) {
trainer.Train();
SentencePieceProcessor processor;
- processor.Load(model_prefix + ".model");
+ EXPECT_OK(processor.Load(model_prefix + ".model"));
const auto &model = processor.model_proto();
std::vector<std::string> pieces;
diff --git a/src/char_model_trainer.cc b/src/char_model_trainer.cc
index 6984fb5..dcff1b0 100644
--- a/src/char_model_trainer.cc
+++ b/src/char_model_trainer.cc
@@ -22,16 +22,18 @@
namespace sentencepiece {
namespace character {
-void Trainer::Train() {
+util::Status Trainer::Train() {
+ RETURN_IF_ERROR(status());
+
LOG(INFO) << "Starts training with : \n" << trainer_spec_.Utf8DebugString();
- CHECK(normalizer_spec_.escape_whitespaces());
- CHECK_EQ(TrainerSpec::CHAR, trainer_spec_.model_type());
+ CHECK_OR_RETURN(normalizer_spec_.escape_whitespaces());
+ CHECK_EQ_OR_RETURN(TrainerSpec::CHAR, trainer_spec_.model_type());
- LoadSentences();
+ RETURN_IF_ERROR(LoadSentences());
const int vocab_size = trainer_spec_.vocab_size() - meta_pieces_.size();
- CHECK_GE(vocab_size, 0);
+ CHECK_GE_OR_RETURN(vocab_size, 0);
uint64 sum = 0;
for (const auto &it : required_chars_) {
@@ -40,7 +42,7 @@ void Trainer::Train() {
const float logsum = log(sum);
- CHECK(final_pieces_.empty());
+ CHECK_OR_RETURN(final_pieces_.empty());
for (const auto &it : Sorted(required_chars_)) {
if (final_pieces_.size() == static_cast<size_t>(vocab_size)) {
break;
@@ -49,7 +51,7 @@ void Trainer::Train() {
log(it.second) - logsum);
}
- Save();
+ return Save();
}
} // namespace character
} // namespace sentencepiece
diff --git a/src/char_model_trainer.h b/src/char_model_trainer.h
index 366f145..b57e3b8 100644
--- a/src/char_model_trainer.h
+++ b/src/char_model_trainer.h
@@ -28,7 +28,7 @@ class Trainer : public TrainerInterface {
const NormalizerSpec &normalizer_spec)
: TrainerInterface::TrainerInterface(trainer_spec, normalizer_spec) {}
- void Train() override;
+ util::Status Train() override;
};
} // namespace character
} // namespace sentencepiece
diff --git a/src/char_model_trainer_test.cc b/src/char_model_trainer_test.cc
index d60e383..f577748 100644
--- a/src/char_model_trainer_test.cc
+++ b/src/char_model_trainer_test.cc
@@ -50,7 +50,7 @@ std::string RunTrainer(const std::vector<std::string> &input, int size) {
trainer.Train();
SentencePieceProcessor processor;
- processor.Load(model_prefix + ".model");
+ EXPECT_OK(processor.Load(model_prefix + ".model"));
const auto &model = processor.model_proto();
std::vector<std::string> pieces;
diff --git a/src/common.h b/src/common.h
index 83b183a..803d9bc 100644
--- a/src/common.h
+++ b/src/common.h
@@ -161,8 +161,6 @@ enum LogSeverity {
std::cerr << __FILE__ << "(" << __LINE__ << ") [" \
<< #condition << "] "
-#define CHECK_IFS(a, b) CHECK((a)) << "No such file or directory: [" << b << "]"
-#define CHECK_OFS(a, b) CHECK((a)) << "Permission denied: [" << b << "]"
#define CHECK_STREQ(a, b) CHECK_EQ(std::string(a), std::string(b))
#define CHECK_EQ(a, b) CHECK((a) == (b))
#define CHECK_NE(a, b) CHECK((a) != (b))
@@ -176,14 +174,21 @@ enum LogSeverity {
#define FRIEND_TEST(a, b) friend class a##_Test_##b;
-#define CHECK_OK(status) \
- CHECK_EQ(status.code(), ::sentencepiece::util::error::OK) << status.ToString()
-#define CHECK_NOT_OK(status) \
- CHECK_NE(status.code(), ::sentencepiece::util::error::OK) << status.ToString()
+#define CHECK_OK(expr) \
+ do { \
+ const auto _status = expr; \
+ CHECK(_status.ok()) << _status.ToString(); \
+ } while (0)
+
+#define CHECK_NOT_OK(expr) \
+ do { \
+ const auto _status = expr; \
+ CHECK(!_status.ok()) << _status.ToString(); \
+ } while (0)
#define RETURN_IF_ERROR(expr) \
do { \
- const util::Status _status = expr; \
+ const auto _status = expr; \
if (!_status.ok()) return _status; \
} while (0)
diff --git a/src/error.cc b/src/error.cc
index b9787b3..93baced 100644
--- a/src/error.cc
+++ b/src/error.cc
@@ -59,6 +59,11 @@ const char* Status::error_message() const {
return ok() ? "" : rep_->error_message.c_str();
}
+void Status::set_error_message(const char* str) {
+ if (rep_ == nullptr) rep_.reset(new Rep);
+ rep_->error_message = str;
+}
+
error::Code Status::code() const { return ok() ? error::OK : rep_->code; }
std::string Status::ToString() const {
diff --git a/src/model_interface.cc b/src/model_interface.cc
index 059e8bf..62ecf17 100644
--- a/src/model_interface.cc
+++ b/src/model_interface.cc
@@ -63,7 +63,8 @@ void ModelInterface::InitializePieces(bool enable_user_defined) {
const auto &sp = model_proto_->pieces(i);
if (!enable_user_defined &&
sp.type() == ModelProto::SentencePiece::USER_DEFINED) {
- status_ = util::InternalError("User defined symbol is not supported.");
+ status_ = util::StatusBuilder(util::error::INTERNAL)
+ << "user defined symbol is not supported.";
return;
}
@@ -72,7 +73,8 @@ void ModelInterface::InitializePieces(bool enable_user_defined) {
sp.type() == ModelProto::SentencePiece::USER_DEFINED);
if (!port::InsertIfNotPresent(
is_normal_piece ? &pieces_ : &reserved_id_map_, sp.piece(), i)) {
- status_ = util::InternalError(sp.piece() + " is already defined.");
+ status_ = util::StatusBuilder(util::error::INTERNAL)
+ << "\"" << sp.piece() << "\" is already defined.";
return;
}
@@ -94,7 +96,6 @@ std::vector<StringPiece> SplitIntoWords(StringPiece text) {
if (begin == text.data() || StringPiece(begin, mblen) == kSpaceSymbol) {
result.emplace_back(begin, 0); // add empty string piece.
}
- CHECK(!result.empty());
result.back() =
StringPiece(result.back().data(), result.back().size() + mblen);
begin += mblen;
diff --git a/src/sentencepiece_processor.cc b/src/sentencepiece_processor.cc
index 6b29a56..119165e 100644
--- a/src/sentencepiece_processor.cc
+++ b/src/sentencepiece_processor.cc
@@ -40,20 +40,22 @@ SentencePieceProcessor::~SentencePieceProcessor() {}
util::Status SentencePieceProcessor::Load(const std::string &filename) {
std::ifstream ifs(filename.c_str(), std::ios::binary | std::ios::in);
if (!ifs) {
- return util::NotFoundError(std::string("Cannot open ") + filename);
+ return util::StatusBuilder(util::error::NOT_FOUND)
+ << "\"" << filename << "\": " << std::strerror(errno);
}
return Load(&ifs);
}
+void SentencePieceProcessor::LoadOrDie(const std::string &filename) {
+ CHECK_OK(Load(filename));
+}
+
util::Status SentencePieceProcessor::Load(std::istream *is) {
- if (is == nullptr)
- return util::InternalError("input ifstream is null");
+ CHECK_OR_RETURN(is) << "input ifstream is null";
model_proto_ = port::MakeUnique<ModelProto>();
- if (!model_proto_->ParseFromIstream(is)) {
- return util::InternalError("Model file is broken");
- }
+ CHECK_OR_RETURN(model_proto_->ParseFromIstream(is)) << "Model file is broken";
model_ = ModelFactory::Create(*model_proto_);
normalizer_ =
@@ -73,32 +75,27 @@ util::Status SentencePieceProcessor::SetDecodeExtraOptions(
}
util::Status SentencePieceProcessor::status() const {
- if (model_ == nullptr)
- return util::InternalError("Model is not initialized.");
- if (normalizer_ == nullptr)
- return util::InternalError("Normalizer is not initialized.");
- if (!model_->status().ok()) return model_->status();
- if (!normalizer_->status().ok()) return normalizer_->status();
-
+ CHECK_OR_RETURN(model_) << "Model is not initialized.";
+ CHECK_OR_RETURN(normalizer_) << "Normalizer is not initialized.";
+ RETURN_IF_ERROR(model_->status());
+ RETURN_IF_ERROR(normalizer_->status());
return util::OkStatus();
}
-#define CHECK_OR_RETURN_STATUS_STL(container) \
- RETURN_IF_ERROR(status()); \
- if (container == nullptr) \
- return util::InternalError("output container is null"); \
+#define CHECK_OR_RETURN_STATUS_STL(container) \
+ RETURN_IF_ERROR(status()); \
+ CHECK_OR_RETURN(container) << "output container is null"; \
container->clear();
-#define CHECK_OR_RETURN_STATUS_PROTO(proto) \
- RETURN_IF_ERROR(status()); \
- if (proto == nullptr) \
- return util::InternalError("output proto is null"); \
+#define CHECK_OR_RETURN_STATUS_PROTO(proto) \
+ RETURN_IF_ERROR(status()); \
+ CHECK_OR_RETURN(proto) << "output proto is null"; \
proto->Clear();
//////////////////////////////////////////////////////////////
// Simple API.
-util::Status SentencePieceProcessor::Encode(const std::string &input,
- std::vector<std::string> *pieces) const {
+util::Status SentencePieceProcessor::Encode(
+ const std::string &input, std::vector<std::string> *pieces) const {
CHECK_OR_RETURN_STATUS_STL(pieces);
SentencePieceText spt;
@@ -123,11 +120,11 @@ util::Status SentencePieceProcessor::Encode(const std::string &input,
return util::OkStatus();
}
-util::Status SentencePieceProcessor::Decode(const std::vector<std::string> &pieces,
- std::string *detokenized) const {
+util::Status SentencePieceProcessor::Decode(
+ const std::vector<std::string> &pieces, std::string *detokenized) const {
CHECK_OR_RETURN_STATUS_STL(detokenized);
- SentencePieceText spt;
+ SentencePieceText spt;
RETURN_IF_ERROR(Decode(pieces, &spt));
*detokenized = std::move(spt.text());
@@ -137,8 +134,8 @@ util::Status SentencePieceProcessor::Decode(const std::vector<std::string> &piec
util::Status SentencePieceProcessor::Decode(const std::vector<int> &ids,
std::string *detokenized) const {
CHECK_OR_RETURN_STATUS_STL(detokenized);
- SentencePieceText spt;
+ SentencePieceText spt;
RETURN_IF_ERROR(Decode(ids, &spt));
*detokenized = std::move(spt.text());
@@ -219,9 +216,7 @@ util::Status SentencePieceProcessor::PopulateSentencePieceText(
const StringPiece w = p.first; // piece
const int id = p.second; // id
- if (w.empty()) {
- return util::InternalError("Empty piece is not allowed.");
- }
+ CHECK_OR_RETURN(!w.empty()) << "Empty piece is not allowed.";
const bool is_unk = IsUnknown(id);
@@ -287,7 +282,8 @@ util::Status SentencePieceProcessor::Encode(const std::string &input,
RETURN_IF_ERROR(normalizer_->Normalize(input, &normalized, &norm_to_orig));
const auto result = model_->Encode(normalized);
- RETURN_IF_ERROR(PopulateSentencePieceText(input, normalized, norm_to_orig, result, spt));
+ RETURN_IF_ERROR(
+ PopulateSentencePieceText(input, normalized, norm_to_orig, result, spt));
return util::OkStatus();
}
@@ -301,25 +297,22 @@ util::Status SentencePieceProcessor::NBestEncode(
std::vector<size_t> norm_to_orig;
RETURN_IF_ERROR(normalizer_->Normalize(input, &normalized, &norm_to_orig));
- const auto nbests =
- model_->NBestEncode(normalized, nbest_size);
- if (nbests.empty()) {
- return util::InternalError("NBestEncode returns empty result");
- }
+ const auto nbests = model_->NBestEncode(normalized, nbest_size);
+ CHECK_OR_RETURN(!nbests.empty()) << "NBestEncode returns empty result.";
for (const auto &result : nbests) {
auto *spt = nbest_spt->add_nbests();
spt->set_score(result.second);
- RETURN_IF_ERROR(PopulateSentencePieceText(input, normalized, norm_to_orig, result.first,
- spt));
+ RETURN_IF_ERROR(PopulateSentencePieceText(input, normalized, norm_to_orig,
+ result.first, spt));
}
return util::OkStatus();
}
-util::Status SentencePieceProcessor::SampleEncode(const std::string &input,
- int nbest_size, float alpha,
- SentencePieceText *spt) const {
+util::Status SentencePieceProcessor::SampleEncode(
+ const std::string &input, int nbest_size, float alpha,
+ SentencePieceText *spt) const {
CHECK_OR_RETURN_STATUS_PROTO(spt);
if (nbest_size > 512 || nbest_size == 0) {
@@ -333,13 +326,11 @@ util::Status SentencePieceProcessor::SampleEncode(const std::string &input,
if (nbest_size == 1) {
const auto result = model_->Encode(normalized);
- RETURN_IF_ERROR(PopulateSentencePieceText(input, normalized, norm_to_orig, result, spt));
+ RETURN_IF_ERROR(PopulateSentencePieceText(input, normalized, norm_to_orig,
+ result, spt));
} else if (nbest_size > 1) {
- const auto nbests =
- model_->NBestEncode(normalized, nbest_size);
- if (nbests.empty()) {
- return util::InternalError("NBestEncode returns empty result");
- }
+ const auto nbests = model_->NBestEncode(normalized, nbest_size);
+ CHECK_OR_RETURN(!nbests.empty()) << "NBestEncode returns empty result.";
std::vector<float> probs(nbests.size(), 0.0);
for (size_t i = 0; i < nbests.size(); ++i) {
@@ -353,14 +344,15 @@ util::Status SentencePieceProcessor::SampleEncode(const std::string &input,
} else if (nbest_size < 0) {
const auto result = model_->SampleEncode(normalized, alpha);
- RETURN_IF_ERROR(PopulateSentencePieceText(input, normalized, norm_to_orig, result, spt));
+ RETURN_IF_ERROR(PopulateSentencePieceText(input, normalized, norm_to_orig,
+ result, spt));
}
return util::OkStatus();
}
-util::Status SentencePieceProcessor::Decode(const std::vector<std::string> &pieces,
- SentencePieceText *spt) const {
+util::Status SentencePieceProcessor::Decode(
+ const std::vector<std::string> &pieces, SentencePieceText *spt) const {
CHECK_OR_RETURN_STATUS_PROTO(spt);
auto DecodeSentencePiece = [&](StringPiece piece, int id,
@@ -412,10 +404,11 @@ util::Status SentencePieceProcessor::Decode(const std::vector<int> &ids,
return Decode(pieces, spt);
}
-#define CHECK_STATUS_OR_RETURN_DEFAULT(value) \
- if (!status().ok()) { \
- LOG(ERROR) << status().error_message() << "\nReturns default value " << value; \
- return value; \
+#define CHECK_STATUS_OR_RETURN_DEFAULT(value) \
+ if (!status().ok()) { \
+ LOG(ERROR) << status().error_message() << "\nReturns default value " \
+ << value; \
+ return value; \
}
int SentencePieceProcessor::GetPieceSize() const {
@@ -474,7 +467,7 @@ util::Status SentencePieceProcessor::ApplyExtraOptions(
piece->set_piece("<s>");
} break;
default:
- return util::InternalError("Unknown extra_option type");
+ return util::InternalError("unknown extra_option type.");
}
}
@@ -492,8 +485,8 @@ util::Status SentencePieceProcessor::ParseExtraOptions(
{"reverse", SentencePieceProcessor::REVERSE}};
for (const auto &s : string_util::Split(extra_option, ":")) {
const auto it = extra_option_map.find(s);
- if (it == extra_option_map.end())
- return util::InternalError(std::string("option ") + s + " is not available.");
+ CHECK_OR_RETURN(it != extra_option_map.end())
+ << "option \"" << s << "\" is not available.";
extra_options->push_back(it->second);
}
return util::OkStatus();
diff --git a/src/sentencepiece_processor.h b/src/sentencepiece_processor.h
index 30854eb..251ea0b 100644
--- a/src/sentencepiece_processor.h
+++ b/src/sentencepiece_processor.h
@@ -124,6 +124,7 @@ class Status {
bool operator!=(const Status &s) const;
inline bool ok() const { return rep_ == nullptr; }
+ void set_error_message(const char *str);
const char *error_message() const;
error::Code code() const;
std::string ToString() const;
@@ -145,6 +146,10 @@ class SentencePieceProcessor {
// Returns false if `filename` cannot be loaded.
virtual util::Status Load(const std::string &filename);
+ // Loads model from `filename`.
+ // Crash if `filename` cannot be loaded.
+ virtual void LoadOrDie(const std::string &filename);
+
// Loads model from `is`.
// Returns false if `is` cannot be loaded.
virtual util::Status Load(std::istream *is);
diff --git a/src/sentencepiece_trainer.cc b/src/sentencepiece_trainer.cc
index be74035..c9ac133 100644
--- a/src/sentencepiece_trainer.cc
+++ b/src/sentencepiece_trainer.cc
@@ -32,8 +32,7 @@ static constexpr char kDefaultNormalizerName[] = "nfkc";
// static
util::Status SentencePieceTrainer::Train(const TrainerSpec &trainer_spec) {
NormalizerSpec normalizer_spec;
- Train(trainer_spec, normalizer_spec);
- return util::OkStatus();
+ return Train(trainer_spec, normalizer_spec);
}
// static
@@ -42,9 +41,8 @@ util::Status SentencePieceTrainer::Train(
auto copied_normalizer_spec = normalizer_spec;
if (!copied_normalizer_spec.normalization_rule_tsv().empty()) {
- if (!copied_normalizer_spec.precompiled_charsmap().empty()) {
- return util::InternalError("precompiled_charsmap is already defined.");
- }
+ CHECK_OR_RETURN(copied_normalizer_spec.precompiled_charsmap().empty())
+ << "precompiled_charsmap is already defined.";
const auto chars_map = normalizer::Builder::BuildMapFromFile(
copied_normalizer_spec.normalization_rule_tsv());
@@ -64,7 +62,7 @@ util::Status SentencePieceTrainer::Train(
}
auto trainer = TrainerFactory::Create(trainer_spec, copied_normalizer_spec);
- trainer->Train();
+ RETURN_IF_ERROR(trainer->Train());
return util::OkStatus();
}
@@ -76,16 +74,15 @@ util::Status SentencePieceTrainer::SetProtoField(
const auto *descriptor = message->GetDescriptor();
const auto *reflection = message->GetReflection();
- if (descriptor == nullptr || reflection == nullptr) {
- return util::InternalError("Reflection is not supported.");
- }
+ CHECK_OR_RETURN(descriptor != nullptr && reflection != nullptr)
+ << "reflection is not supported.";
const auto *field = descriptor->FindFieldByName(std::string(field_name));
if (field == nullptr) {
- return util::NotFoundError(std::string("Unknown field name \"") +
- field_name + "\" in " +
- descriptor->DebugString());
+ return util::StatusBuilder(util::error::NOT_FOUND)
+ << "unknown field name \"" << field_name << "\" in\n"
+ << descriptor->DebugString();
}
std::vector<std::string> values = {value};
@@ -97,16 +94,16 @@ util::Status SentencePieceTrainer::SetProtoField(
else \
reflection->Set##METHOD_TYPE(message, field, v);
-#define DEFINE_SET_FIELD(PROTO_TYPE, CPP_TYPE, FUNC_PREFIX, METHOD_TYPE, \
- EMPTY) \
- case google::protobuf::FieldDescriptor::CPPTYPE_##PROTO_TYPE: { \
- CPP_TYPE v; \
- if (!string_util::lexical_cast(value.empty() ? EMPTY : value, &v)) \
- return util::InvalidArgumentError(std::string("Cannot parse \"") + \
- value + "\" as \"" + \
- field->type_name() + "\"."); \
- SET_FIELD(METHOD_TYPE, v); \
- break; \
+#define DEFINE_SET_FIELD(PROTO_TYPE, CPP_TYPE, FUNC_PREFIX, METHOD_TYPE, \
+ EMPTY) \
+ case google::protobuf::FieldDescriptor::CPPTYPE_##PROTO_TYPE: { \
+ CPP_TYPE v; \
+ if (!string_util::lexical_cast(value.empty() ? EMPTY : value, &v)) \
+ return util::StatusBuilder(util::error::INVALID_ARGUMENT) \
+ << "cannot parse \"" << value << "\" as \"" << field->type_name() \
+ << "\"."; \
+ SET_FIELD(METHOD_TYPE, v); \
+ break; \
}
for (const auto &value : values) {
@@ -125,17 +122,16 @@ util::Status SentencePieceTrainer::SetProtoField(
const auto *enum_value =
field->enum_type()->FindValueByName(string_util::ToUpper(value));
if (enum_value == nullptr)
- return util::InvalidArgumentError(
- std::string("Unknown enumeration value of \"") + value +
- "\" for field \"" + field->name() + "\".");
+ return util::StatusBuilder(util::error::INVALID_ARGUMENT)
+ << "unknown enumeration value of \"" << value
+ << "\" for field \"" << field->name() << "\".";
SET_FIELD(Enum, enum_value);
break;
}
default:
- return util::UnimplementedError(std::string("Proto type \"") +
- field->cpp_type_name() +
- "\" is not supported.");
- break;
+ return util::StatusBuilder(util::error::UNIMPLEMENTED)
+ << "proto type \"" << field->cpp_type_name()
+ << "\" is not supported.";
}
}
@@ -146,10 +142,8 @@ util::Status SentencePieceTrainer::SetProtoField(
util::Status SentencePieceTrainer::MergeSpecsFromArgs(
const std::string &args, TrainerSpec *trainer_spec,
NormalizerSpec *normalizer_spec) {
- if (trainer_spec == nullptr || normalizer_spec == nullptr) {
- return util::InternalError(
- "`trainer_spec` and `normalizer_spec` must not be null.");
- }
+ CHECK_OR_RETURN(trainer_spec) << "`trainer_spec` must not be null.";
+ CHECK_OR_RETURN(normalizer_spec) << "`normalizer_spec` must not be null.";
if (args.empty()) return util::OkStatus();
diff --git a/src/spm_decode_main.cc b/src/spm_decode_main.cc
index baf91a6..03f118a 100644
--- a/src/spm_decode_main.cc
+++ b/src/spm_decode_main.cc
@@ -37,6 +37,7 @@ int main(int argc, char *argv[]) {
CHECK_OK(sp.SetDecodeExtraOptions(FLAGS_extra_options));
sentencepiece::io::OutputBuffer output(FLAGS_output);
+ CHECK_OK(output.status());
if (rest_args.empty()) {
rest_args.push_back(""); // empty means that reading from stdin.
@@ -88,6 +89,7 @@ int main(int argc, char *argv[]) {
for (const auto &filename : rest_args) {
sentencepiece::io::InputBuffer input(filename);
+ CHECK_OK(input.status());
while (input.ReadLine(&line)) {
const auto pieces = sentencepiece::string_util::Split(line, " ");
process(pieces);
diff --git a/src/spm_encode_main.cc b/src/spm_encode_main.cc
index f8f0ed6..5e50005 100644
--- a/src/spm_encode_main.cc
+++ b/src/spm_encode_main.cc
@@ -40,6 +40,7 @@ int main(int argc, char *argv[]) {
CHECK_OK(sp.SetEncodeExtraOptions(FLAGS_extra_options));
sentencepiece::io::OutputBuffer output(FLAGS_output);
+ CHECK_OK(output.status());
if (rest_args.empty()) {
rest_args.push_back(""); // empty means that reading from stdin.
@@ -109,6 +110,7 @@ int main(int argc, char *argv[]) {
for (const auto &filename : rest_args) {
sentencepiece::io::InputBuffer input(filename);
+ CHECK_OK(input.status());
while (input.ReadLine(&line)) {
if (line.empty()) {
output.WriteLine("");
diff --git a/src/spm_export_vocab_main.cc b/src/spm_export_vocab_main.cc
index 955ca4c..cb20ebf 100644
--- a/src/spm_export_vocab_main.cc
+++ b/src/spm_export_vocab_main.cc
@@ -26,9 +26,10 @@ DEFINE_string(output_format, "txt", "output format. choose from txt or proto");
int main(int argc, char *argv[]) {
sentencepiece::flags::ParseCommandLineFlags(argc, argv);
sentencepiece::SentencePieceProcessor sp;
- sp.Load(FLAGS_model);
+ CHECK_OK(sp.Load(FLAGS_model));
sentencepiece::io::OutputBuffer output(FLAGS_output);
+ CHECK_OK(output.status());
if (FLAGS_output_format == "txt") {
for (const auto &piece : sp.model_proto().pieces()) {
diff --git a/src/spm_normalize_main.cc b/src/spm_normalize_main.cc
index ee68f13..c64b6a2 100644
--- a/src/spm_normalize_main.cc
+++ b/src/spm_normalize_main.cc
@@ -37,7 +37,7 @@ int main(int argc, char *argv[]) {
if (FLAGS_normalization_rule_tsv.empty() && !FLAGS_model.empty()) {
sentencepiece::SentencePieceProcessor sp;
- sp.Load(FLAGS_model);
+ CHECK_OK(sp.Load(FLAGS_model));
spec = sp.model_proto().normalizer_spec();
} else if (!FLAGS_normalization_rule_tsv.empty() && FLAGS_model.empty()) {
const auto cmap = sentencepiece::normalizer::Builder::BuildMapFromFile(
@@ -57,6 +57,7 @@ int main(int argc, char *argv[]) {
sentencepiece::normalizer::Normalizer normalizer(spec);
sentencepiece::io::OutputBuffer output(FLAGS_output);
+ CHECK_OK(output.status());
if (rest_args.empty()) {
rest_args.push_back(""); // empty means that read from stdin.
@@ -65,6 +66,7 @@ int main(int argc, char *argv[]) {
std::string line;
for (const auto &filename : rest_args) {
sentencepiece::io::InputBuffer input(filename);
+ CHECK_OK(input.status());
while (input.ReadLine(&line)) {
output.WriteLine(normalizer.Normalize(line));
}
diff --git a/src/trainer_interface.cc b/src/trainer_interface.cc
index ce9b499..01b5b7d 100644
--- a/src/trainer_interface.cc
+++ b/src/trainer_interface.cc
@@ -43,29 +43,35 @@ const char TrainerInterface::kBOS[] = "<s>";
const char TrainerInterface::kEOS[] = "</s>";
const char TrainerInterface::kPAD[] = "<pad>";
-TrainerInterface::TrainerInterface(const TrainerSpec &trainer_spec,
- const NormalizerSpec &normalizer_spec)
- : trainer_spec_(trainer_spec), normalizer_spec_(normalizer_spec) {
- InitMetaPieces();
-
- CHECK(!trainer_spec_.model_prefix().empty());
- CHECK_GT(trainer_spec_.input().size(), 0);
- CHECK_GT(trainer_spec_.vocab_size(), 0);
+namespace {
+util::Status VerifySpec(const TrainerSpec &trainer_spec) {
+ CHECK_OR_RETURN(!trainer_spec.model_prefix().empty());
+ CHECK_GT_OR_RETURN(trainer_spec.input().size(), 0);
+ CHECK_GT_OR_RETURN(trainer_spec.vocab_size(), 0);
#define CHECK_RANGE(variable, minval, maxval) \
- CHECK(variable >= minval && variable <= maxval)
-
- CHECK_RANGE(trainer_spec_.character_coverage(), 0.98, 1.0);
- CHECK_RANGE(trainer_spec_.input_sentence_size(), 100, 100000000);
- CHECK_RANGE(trainer_spec_.max_sentencepiece_length(), 1, 512);
- CHECK_RANGE(trainer_spec_.mining_sentence_size(), 100, 5000000);
- CHECK_RANGE(trainer_spec_.num_sub_iterations(), 1, 10);
- CHECK_RANGE(trainer_spec_.num_threads(), 1, 128);
- CHECK_RANGE(trainer_spec_.seed_sentencepiece_size(), 1000, 5000000);
- CHECK_RANGE(trainer_spec_.shrinking_factor(), 0.5, 0.95);
- CHECK_RANGE(trainer_spec_.training_sentence_size(), 100, 100000000);
-
+ CHECK_OR_RETURN(variable >= minval && variable <= maxval)
+
+ CHECK_RANGE(trainer_spec.character_coverage(), 0.98, 1.0);
+ CHECK_RANGE(trainer_spec.input_sentence_size(), 100, 100000000);
+ CHECK_RANGE(trainer_spec.max_sentencepiece_length(), 1, 512);
+ CHECK_RANGE(trainer_spec.mining_sentence_size(), 100, 5000000);
+ CHECK_RANGE(trainer_spec.num_sub_iterations(), 1, 10);
+ CHECK_RANGE(trainer_spec.num_threads(), 1, 128);
+ CHECK_RANGE(trainer_spec.seed_sentencepiece_size(), 1000, 5000000);
+ CHECK_RANGE(trainer_spec.shrinking_factor(), 0.5, 0.95);
+ CHECK_RANGE(trainer_spec.training_sentence_size(), 100, 100000000);
#undef CHECK_RANGE
+
+ return util::OkStatus();
+}
+} // namespace
+
+TrainerInterface::TrainerInterface(const TrainerSpec &trainer_spec,
+ const NormalizerSpec &normalizer_spec)
+ : trainer_spec_(trainer_spec), normalizer_spec_(normalizer_spec) {
+ status_ = VerifySpec(trainer_spec_);
+ if (status_.ok()) status_ = InitMetaPieces();
}
TrainerInterface::~TrainerInterface() {}
@@ -127,15 +133,16 @@ bool TrainerInterface::IsValidSentencePiece(
return true;
}
-void TrainerInterface::LoadSentences() {
- CHECK(sentences_.empty());
- CHECK(required_chars_.empty());
+util::Status TrainerInterface::LoadSentences() {
+ RETURN_IF_ERROR(status());
+ CHECK_OR_RETURN(sentences_.empty());
+ CHECK_OR_RETURN(required_chars_.empty());
const normalizer::Normalizer normalizer(normalizer_spec_);
- CHECK(trainer_spec_.input_format().empty() ||
- trainer_spec_.input_format() == "text" ||
- trainer_spec_.input_format() == "tsv")
+ CHECK_OR_RETURN(trainer_spec_.input_format().empty() ||
+ trainer_spec_.input_format() == "text" ||
+ trainer_spec_.input_format() == "tsv")
<< "Supported formats are 'text' and 'tsv'.";
const bool is_tsv = trainer_spec_.input_format() == "tsv";
@@ -144,15 +151,16 @@ void TrainerInterface::LoadSentences() {
LOG(INFO) << "Loading corpus: " << filename;
std::string sentence;
io::InputBuffer input(filename);
+ RETURN_IF_ERROR(input.status());
while (input.ReadLine(&sentence)) {
int64 freq = 1;
if (is_tsv) {
const std::vector<std::string> v = string_util::Split(sentence, "\t");
- CHECK_EQ(v.size(), 2)
+ CHECK_EQ_OR_RETURN(v.size(), 2)
<< "Input format must be: word <tab> freq. " << sentence;
sentence = v[0];
freq = std::atoll(v[1].c_str());
- CHECK_GE(freq, 1);
+ CHECK_GE_OR_RETURN(freq, 1);
}
constexpr int kMaxLines = 2048;
@@ -171,7 +179,7 @@ void TrainerInterface::LoadSentences() {
LOG(INFO) << "Loading: " << normalized
<< "\tsize=" << sentences_.size();
}
- CHECK(normalized.find(" ") == std::string::npos)
+ CHECK_OR_RETURN(normalized.find(" ") == std::string::npos)
<< "Normalized string must not include spaces";
if (normalized.empty()) {
LOG(WARNING) << "Empty string found. removed";
@@ -198,7 +206,7 @@ END:
if (c == 0x0020) {
// UTF8ToUnicodeText returns a white space if the text
// contains an interchange-invalid character.
- CHECK(w.first.find(" ") == std::string::npos)
+ CHECK_OR_RETURN(w.first.find(" ") == std::string::npos)
<< "space must not be included in normalized string.";
continue;
}
@@ -217,13 +225,13 @@ END:
break;
}
accumulated_chars_count += w.second;
- CHECK_NE(w.first, 0x0020)
+ CHECK_NE_OR_RETURN(w.first, 0x0020)
<< "space must not be included in normalized string.";
required_chars_.insert(w);
}
LOG(INFO) << "alphabet size=" << required_chars_.size();
- CHECK(!port::ContainsKey(required_chars_, kUNKChar));
+ CHECK_OR_RETURN(!port::ContainsKey(required_chars_, kUNKChar));
// Replaces rare characters (characters not included in required_chars_)
// with kUNKChar.
@@ -242,17 +250,19 @@ END:
// +3 for meta pieces.
if (trainer_spec_.model_type() != TrainerSpec::WORD &&
trainer_spec_.model_type() != TrainerSpec::CHAR) {
- CHECK_LT(static_cast<int>(required_chars_.size() + meta_pieces_.size()),
- trainer_spec_.vocab_size())
+ CHECK_LT_OR_RETURN(
+ static_cast<int>(required_chars_.size() + meta_pieces_.size()),
+ trainer_spec_.vocab_size())
<< "Vocabulary size is smaller than required_chars. "
<< trainer_spec_.vocab_size() << " vs "
- << required_chars_.size() + meta_pieces_.size()
- << ". "
+ << required_chars_.size() + meta_pieces_.size() << ". "
<< "Increase vocab_size or decrease character_coverage with "
<< "--character_coverage option.";
}
LOG(INFO) << "Done! " << sentences_.size() << " sentences are loaded";
+
+ return util::OkStatus();
}
void TrainerInterface::SplitSentencesByWhitespace() {
@@ -267,31 +277,29 @@ void TrainerInterface::SplitSentencesByWhitespace() {
sentences_ = Sorted(tokens);
LOG(INFO) << "Done! " << sentences_.size();
}
-// #endif
-void TrainerInterface::Serialize(ModelProto *model_proto) const {
+util::Status TrainerInterface::Serialize(ModelProto *model_proto) const {
// Duplicated sentencepiece is not allowed.
std::unordered_set<std::string> dup;
- auto CheckPiece = [&dup](const std::string &piece) {
- CHECK(!piece.empty());
- CHECK(dup.insert(piece).second) << piece << " is already defined";
- };
+#define CHECK_PIECE(piece) \
+ CHECK_OR_RETURN(!piece.empty()); \
+ CHECK_OR_RETURN(dup.insert(piece).second) << piece << " is already defined";
for (const auto &w : meta_pieces_) {
auto *sp = model_proto->add_pieces();
sp->set_piece(w.first);
sp->set_type(w.second);
sp->set_score(0.0);
- CHECK_NE(ModelProto::SentencePiece::NORMAL, sp->type());
- CheckPiece(sp->piece());
+ CHECK_NE_OR_RETURN(ModelProto::SentencePiece::NORMAL, sp->type());
+ CHECK_PIECE(sp->piece());
}
for (const auto &w : final_pieces_) {
auto *sp = model_proto->add_pieces();
sp->set_piece(w.first);
sp->set_score(w.second);
- CheckPiece(sp->piece());
+ CHECK_PIECE(sp->piece());
}
*(model_proto->mutable_trainer_spec()) = trainer_spec_;
@@ -299,45 +307,55 @@ void TrainerInterface::Serialize(ModelProto *model_proto) const {
if (!trainer_spec_.hard_vocab_limit() ||
trainer_spec_.model_type() == TrainerSpec::CHAR) {
- CHECK_GE(trainer_spec_.vocab_size(), model_proto->pieces_size());
- CHECK_GE(trainer_spec_.vocab_size(), static_cast<int>(dup.size()));
+ CHECK_GE_OR_RETURN(trainer_spec_.vocab_size(), model_proto->pieces_size());
+ CHECK_GE_OR_RETURN(trainer_spec_.vocab_size(),
+ static_cast<int>(dup.size()));
model_proto->mutable_trainer_spec()->set_vocab_size(
model_proto->pieces_size());
} else {
- CHECK(trainer_spec_.vocab_size() == model_proto->pieces_size() &&
- trainer_spec_.vocab_size() == static_cast<int>(dup.size()))
+ CHECK_OR_RETURN(trainer_spec_.vocab_size() == model_proto->pieces_size() &&
+ trainer_spec_.vocab_size() == static_cast<int>(dup.size()))
<< "Use --hard_vocab_limit=false to make the vocab size `soft limit`.";
}
+
+ return util::OkStatus();
}
-void TrainerInterface::SaveModel(StringPiece filename) const {
+util::Status TrainerInterface::SaveModel(StringPiece filename) const {
LOG(INFO) << "Saving model: " << filename;
ModelProto model_proto;
- Serialize(&model_proto);
+ RETURN_IF_ERROR(Serialize(&model_proto));
std::ofstream ofs(filename.data(), OUTPUT_MODE);
- CHECK_OFS(ofs, filename.to_string());
- CHECK(model_proto.SerializeToOstream(&ofs));
+ CHECK_OR_RETURN(ofs) << "\"" << filename.data()
+ << "\": " << std::strerror(errno);
+ CHECK_OR_RETURN(model_proto.SerializeToOstream(&ofs));
+ return util::OkStatus();
}
-void TrainerInterface::SaveVocab(StringPiece filename) const {
+util::Status TrainerInterface::SaveVocab(StringPiece filename) const {
LOG(INFO) << "Saving vocabs: " << filename;
ModelProto model_proto;
Serialize(&model_proto);
io::OutputBuffer output(filename);
+ RETURN_IF_ERROR(output.status());
+
for (const auto &piece : model_proto.pieces()) {
std::ostringstream os;
os << piece.piece() << "\t" << piece.score();
- output.WriteLine(os.str());
+ CHECK_OR_RETURN(output.WriteLine(os.str()));
}
+
+ return util::OkStatus();
}
-void TrainerInterface::Save() const {
- SaveModel(trainer_spec_.model_prefix() + ".model");
- SaveVocab(trainer_spec_.model_prefix() + ".vocab");
+util::Status TrainerInterface::Save() const {
+ RETURN_IF_ERROR(SaveModel(trainer_spec_.model_prefix() + ".model"));
+ RETURN_IF_ERROR(SaveVocab(trainer_spec_.model_prefix() + ".vocab"));
+ return util::OkStatus();
}
-void TrainerInterface::InitMetaPieces() {
- CHECK(meta_pieces_.empty());
+util::Status TrainerInterface::InitMetaPieces() {
+ CHECK_OR_RETURN(meta_pieces_.empty());
std::vector<std::pair<int, std::string>> ids;
if (trainer_spec_.unk_id() >= 0)
@@ -354,17 +372,17 @@ void TrainerInterface::InitMetaPieces() {
int prev_id = -1;
bool has_unk = false;
for (const auto &p : ids) {
- CHECK_EQ(prev_id + 1, p.first)
+ CHECK_EQ_OR_RETURN(prev_id + 1, p.first)
<< "ID for `" << p.second << "` must be " << prev_id + 1;
prev_id = p.first;
- CHECK_EQ(static_cast<int>(meta_pieces_.size()), p.first);
+ CHECK_EQ_OR_RETURN(static_cast<int>(meta_pieces_.size()), p.first);
if (p.second == kUNK) has_unk = true;
meta_pieces_.emplace_back(
p.second, (p.second == kUNK ? ModelProto::SentencePiece::UNKNOWN
: ModelProto::SentencePiece::CONTROL));
}
- CHECK(has_unk) << kUNK << " must be defined.";
+ CHECK_OR_RETURN(has_unk) << kUNK << " must be defined.";
for (const auto &w : trainer_spec_.control_symbols()) {
meta_pieces_.emplace_back(w, ModelProto::SentencePiece::CONTROL);
@@ -373,6 +391,8 @@ void TrainerInterface::InitMetaPieces() {
for (const auto &w : trainer_spec_.user_defined_symbols()) {
meta_pieces_.emplace_back(w, ModelProto::SentencePiece::USER_DEFINED);
}
+
+ return util::OkStatus();
}
} // namespace sentencepiece
diff --git a/src/trainer_interface.h b/src/trainer_interface.h
index 8ec4df7..cefb1cd 100644
--- a/src/trainer_interface.h
+++ b/src/trainer_interface.h
@@ -22,6 +22,7 @@
#include "common.h"
#include "sentencepiece_model.pb.h"
+#include "sentencepiece_processor.h"
#include "util.h"
namespace sentencepiece {
@@ -65,7 +66,9 @@ class TrainerInterface {
virtual ~TrainerInterface();
- virtual void Train() {}
+ virtual util::Status Train() { return status(); }
+
+ virtual util::Status status() const { return status_; }
FRIEND_TEST(TrainerInterfaceTest, IsValidSentencePieceTest);
FRIEND_TEST(TrainerInterfaceTest, OverrideSpecialPiecesTest);
@@ -79,7 +82,7 @@ class TrainerInterface {
// Loads all sentences from spec.input().
// It loads at most input_sentence_size sentences.
- void LoadSentences();
+ util::Status LoadSentences();
// Splits all sentencecs by whitespaces and
// replace the |sentences_| with tokenized string.
@@ -89,7 +92,7 @@ class TrainerInterface {
void SplitSentencesByWhitespace();
// Save model files into spec.model_prefix().
- void Save() const;
+ util::Status Save() const;
// Set of characters which must be included in the final vocab.
// The value of this map stores the frequency.
@@ -109,24 +112,27 @@ class TrainerInterface {
// Reserved control pieces. e.g., <unk>, <s>, </s>.
// The index corresponds to vocab id.
- std::vector<std::pair<std::string,
- ModelProto::SentencePiece::Type>> meta_pieces_;
+ std::vector<std::pair<std::string, ModelProto::SentencePiece::Type>>
+ meta_pieces_;
+
+ // Detect errors on initialization.
+ util::Status status_;
private:
// Serialize final_pieces_ to |model_proto|.
- void Serialize(ModelProto *model_proto) const;
+ util::Status Serialize(ModelProto *model_proto) const;
// Saves the best sentence split with the current model for debugging.
- void SaveSplits(StringPiece filename) const;
+ util::Status SaveSplits(StringPiece filename) const;
// Saves model file.
- void SaveModel(StringPiece filename) const;
+ util::Status SaveModel(StringPiece filename) const;
// Saves vocabulary file for NMT.
- void SaveVocab(StringPiece filename) const;
+ util::Status SaveVocab(StringPiece filename) const;
// Initializes `meta_pieces_` from TrainerSpec.
- void InitMetaPieces();
+ util::Status InitMetaPieces();
};
} // namespace sentencepiece
#endif // TRAINER_INTERFACE_H_
diff --git a/src/trainer_interface_test.cc b/src/trainer_interface_test.cc
index 055bb09..8fb138e 100644
--- a/src/trainer_interface_test.cc
+++ b/src/trainer_interface_test.cc
@@ -171,15 +171,18 @@ TEST(TrainerInterfaceTest, OverrideSpecialPiecesTest) {
trainer_spec.set_unk_id(0);
trainer_spec.set_bos_id(-1);
trainer_spec.set_eos_id(2);
- EXPECT_DEATH(TrainerInterface trainer(trainer_spec, normalizer_spec));
+ TrainerInterface trainer(trainer_spec, normalizer_spec);
+ EXPECT_NOT_OK(trainer.status());
+ }
+ {
// UNK is not defined.
trainer_spec.set_unk_id(-1);
trainer_spec.set_bos_id(0);
trainer_spec.set_eos_id(1);
- EXPECT_DEATH(TrainerInterface trainer(trainer_spec, normalizer_spec));
+ TrainerInterface trainer(trainer_spec, normalizer_spec);
+ EXPECT_NOT_OK(trainer.status());
}
-
}
TEST(TrainerInterfaceTest, SerializeTest) {
@@ -198,7 +201,7 @@ TEST(TrainerInterfaceTest, SerializeTest) {
TrainerInterface trainer(trainer_spec, normalizer_spec);
trainer.final_pieces_ = final_pieces;
ModelProto model_proto;
- EXPECT_DEATH(trainer.Serialize(&model_proto))
+ EXPECT_NOT_OK(trainer.Serialize(&model_proto));
}
{
@@ -207,7 +210,7 @@ TEST(TrainerInterfaceTest, SerializeTest) {
TrainerInterface trainer(trainer_spec, normalizer_spec);
trainer.final_pieces_ = final_pieces;
ModelProto model_proto;
- trainer.Serialize(&model_proto);
+ EXPECT_OK(trainer.Serialize(&model_proto));
EXPECT_EQ(6, model_proto.trainer_spec().vocab_size());
for (int i = 3; i < 6; ++i) {
EXPECT_EQ(final_pieces[i - 3].first, model_proto.pieces(i).piece());
diff --git a/src/unigram_model.cc b/src/unigram_model.cc
index dce7c59..778681b 100644
--- a/src/unigram_model.cc
+++ b/src/unigram_model.cc
@@ -434,7 +434,7 @@ void ModelBase::BuildTrie(std::vector<std::pair<StringPiece, int>> *pieces) {
if (!status().ok()) return;
if (pieces->empty()) {
- status_ = util::InternalError("No pieces are loaded.");
+ status_ = util::InternalError("no pieces are loaded.");
return;
}
@@ -453,7 +453,7 @@ void ModelBase::BuildTrie(std::vector<std::pair<StringPiece, int>> *pieces) {
trie_ = port::MakeUnique<Darts::DoubleArray>();
if (trie_->build(key.size(), const_cast<char **>(&key[0]), nullptr,
&value[0]) != 0) {
- status_ = util::InternalError("Cannot build double-array.");
+ status_ = util::InternalError("cannot build double-array.");
return;
}
@@ -471,7 +471,7 @@ void ModelBase::BuildTrie(std::vector<std::pair<StringPiece, int>> *pieces) {
pieces_.clear();
if (trie_results_size_ == 0)
- status_ = util::InternalError("No entry is found in the trie.");
+ status_ = util::InternalError("no entry is found in the trie.");
}
Model::Model(const ModelProto &model_proto) {
diff --git a/src/unigram_model_trainer.cc b/src/unigram_model_trainer.cc
index ef19170..40e58d7 100644
--- a/src/unigram_model_trainer.cc
+++ b/src/unigram_model_trainer.cc
@@ -466,14 +466,18 @@ TrainerModel::SentencePieces Trainer::FinalizeSentencePieces(
return Sorted(final_sentencepieces);
}
-void Trainer::Train() {
+util::Status Trainer::Train() {
+ RETURN_IF_ERROR(status());
+
LOG(INFO) << "Starts training with : \n" << trainer_spec_.Utf8DebugString();
- CHECK(normalizer_spec_.escape_whitespaces());
+ CHECK_EQ_OR_RETURN(TrainerSpec::UNIGRAM, trainer_spec_.model_type());
+ CHECK_OR_RETURN(normalizer_spec_.escape_whitespaces());
TrainerModel model(trainer_spec_, normalizer_spec_);
- LoadSentences();
+ RETURN_IF_ERROR(model.status());
+ RETURN_IF_ERROR(LoadSentences());
auto seed_sentencepieces = MakeSeedSentencePieces();
model.SetSentencePieces(std::move(seed_sentencepieces));
@@ -522,7 +526,7 @@ void Trainer::Train() {
// Finally, adjusts the size of sentencepices to be |vocab_size|.
final_pieces_ = FinalizeSentencePieces(model);
- Save();
+ return Save();
}
} // namespace unigram
} // namespace sentencepiece
diff --git a/src/unigram_model_trainer.h b/src/unigram_model_trainer.h
index 10138dd..358e67f 100644
--- a/src/unigram_model_trainer.h
+++ b/src/unigram_model_trainer.h
@@ -68,7 +68,7 @@ class Trainer : public TrainerInterface {
const NormalizerSpec &normalizer_spec)
: TrainerInterface::TrainerInterface(trainer_spec, normalizer_spec) {}
- void Train() override;
+ util::Status Train() override;
private:
FRIEND_TEST(TrainerTest, IsValidSentencePieceTest);
diff --git a/src/unigram_model_trainer_test.cc b/src/unigram_model_trainer_test.cc
index 096f70e..a88dc71 100644
--- a/src/unigram_model_trainer_test.cc
+++ b/src/unigram_model_trainer_test.cc
@@ -49,7 +49,7 @@ TEST(UnigramTrainerTest, EndToEndTest) {
test::ScopedTempFile sf("tmp_model");
trainer_spec.set_model_prefix(sf.filename());
unigram::Trainer trainer(trainer_spec, normalizer_spec);
- trainer.Train();
+ EXPECT_OK(trainer.Train());
SentencePieceProcessor sp;
EXPECT_OK(sp.Load(std::string(sf.filename()) + ".model"));
diff --git a/src/util.cc b/src/util.cc
index d06d953..915f01f 100644
--- a/src/util.cc
+++ b/src/util.cc
@@ -220,7 +220,9 @@ namespace io {
InputBuffer::InputBuffer(StringPiece filename)
: is_(filename.empty() ? &std::cin
: new std::ifstream(WPATH(filename.data()))) {
- CHECK_IFS(*is_, filename.data());
+ if (!*is_)
+ status_ = util::StatusBuilder(util::error::NOT_FOUND)
+ << "\"" << filename.data() << "\": " << std::strerror(errno);
}
InputBuffer::~InputBuffer() {
@@ -229,6 +231,8 @@ InputBuffer::~InputBuffer() {
}
}
+util::Status InputBuffer::status() const { return status_; }
+
bool InputBuffer::ReadLine(std::string *line) {
return static_cast<bool>(std::getline(*is_, *line));
}
@@ -237,7 +241,9 @@ OutputBuffer::OutputBuffer(StringPiece filename)
: os_(filename.empty()
? &std::cout
: new std::ofstream(WPATH(filename.data()), OUTPUT_MODE)) {
- CHECK_OFS(*os_, filename.data());
+ if (!*os_)
+ status_ = util::StatusBuilder(util::error::PERMISSION_DENIED)
+ << "\"" << filename.data() << "\": " << std::strerror(errno);
}
OutputBuffer::~OutputBuffer() {
@@ -246,13 +252,15 @@ OutputBuffer::~OutputBuffer() {
}
}
-void OutputBuffer::Write(StringPiece text) {
+util::Status OutputBuffer::status() const { return status_; }
+
+bool OutputBuffer::Write(StringPiece text) {
os_->write(text.data(), text.size());
+ return os_->good();
}
-void OutputBuffer::WriteLine(StringPiece text) {
- Write(text);
- Write("\n");
+bool OutputBuffer::WriteLine(StringPiece text) {
+ return Write(text) && Write("\n");
}
} // namespace io
} // namespace sentencepiece
diff --git a/src/util.h b/src/util.h
index cf30882..f3b8158 100644
--- a/src/util.h
+++ b/src/util.h
@@ -17,6 +17,7 @@
#include <algorithm>
#include <fstream>
+#include <memory>
#include <sstream>
#include <string>
#include <vector>
@@ -201,21 +202,25 @@ namespace io {
class InputBuffer {
public:
explicit InputBuffer(StringPiece filename);
+ util::Status status() const;
~InputBuffer();
bool ReadLine(std::string *line);
private:
+ util::Status status_;
std::istream *is_;
};
class OutputBuffer {
public:
explicit OutputBuffer(StringPiece filename);
+ util::Status status() const;
~OutputBuffer();
- void Write(StringPiece text);
- void WriteLine(StringPiece text);
+ bool Write(StringPiece text);
+ bool WriteLine(StringPiece text);
private:
+ util::Status status_;
std::ostream *os_;
};
} // namespace io
@@ -390,7 +395,37 @@ DECLARE_ERROR(DataLoss, DATA_LOSS)
DECLARE_ERROR(Unknown, UNKNOWN)
DECLARE_ERROR(PermissionDenied, PERMISSION_DENIED)
DECLARE_ERROR(Unauthenticated, UNAUTHENTICATED)
-} // namespace util
+class StatusBuilder {
+ public:
+ explicit StatusBuilder(error::Code code) : code_(code) {}
+
+ template <typename T>
+ StatusBuilder &operator<<(const T &value) {
+ os_ << value;
+ return *this;
+ }
+
+ operator Status() const { return Status(code_, os_.str()); }
+
+ private:
+ error::Code code_;
+ std::ostringstream os_;
+};
+
+#define CHECK_OR_RETURN(condition) \
+ if (condition) { \
+ } else /* NOLINT */ \
+ return ::sentencepiece::util::StatusBuilder(util::error::INTERNAL) \
+ << __FILE__ << "(" << __LINE__ << ") [" << #condition << "] "
+
+#define CHECK_EQ_OR_RETURN(a, b) CHECK_OR_RETURN((a) == (b))
+#define CHECK_NE_OR_RETURN(a, b) CHECK_OR_RETURN((a) != (b))
+#define CHECK_GE_OR_RETURN(a, b) CHECK_OR_RETURN((a) >= (b))
+#define CHECK_LE_OR_RETURN(a, b) CHECK_OR_RETURN((a) <= (b))
+#define CHECK_GT_OR_RETURN(a, b) CHECK_OR_RETURN((a) > (b))
+#define CHECK_LT_OR_RETURN(a, b) CHECK_OR_RETURN((a) < (b))
+
+} // namespace util
} // namespace sentencepiece
#endif // UTIL_H_
diff --git a/src/util_test.cc b/src/util_test.cc
index 224ea50..e9d3829 100644
--- a/src/util_test.cc
+++ b/src/util_test.cc
@@ -449,7 +449,8 @@ TEST(UtilTest, InputOutputBufferTest) {
}
TEST(UtilTest, InputOutputBufferInvalidFileTest) {
- EXPECT_DEATH(io::InputBuffer input("__UNKNOWN__FILE__"));
+ io::InputBuffer input("__UNKNOWN__FILE__");
+ EXPECT_NOT_OK(input.status());
}
TEST(UtilTest, STLDeleteELementsTest) {
diff --git a/src/word_model_trainer.cc b/src/word_model_trainer.cc
index 9ada68b..a16898e 100644
--- a/src/word_model_trainer.cc
+++ b/src/word_model_trainer.cc
@@ -23,13 +23,15 @@
namespace sentencepiece {
namespace word {
-void Trainer::Train() {
+util::Status Trainer::Train() {
+ RETURN_IF_ERROR(status());
+
LOG(INFO) << "Starts training with : \n" << trainer_spec_.Utf8DebugString();
- CHECK(normalizer_spec_.escape_whitespaces());
- CHECK_EQ(TrainerSpec::WORD, trainer_spec_.model_type());
+ CHECK_OR_RETURN(normalizer_spec_.escape_whitespaces());
+ CHECK_EQ_OR_RETURN(TrainerSpec::WORD, trainer_spec_.model_type());
- LoadSentences();
+ RETURN_IF_ERROR(LoadSentences());
std::unordered_map<std::string, uint64> freq;
for (const auto &it : sentences_) {
@@ -39,7 +41,7 @@ void Trainer::Train() {
}
const int vocab_size = trainer_spec_.vocab_size() - meta_pieces_.size();
- CHECK_GE(vocab_size, 0);
+ CHECK_GE_OR_RETURN(vocab_size, 0);
uint64 sum = 0;
for (const auto &it : freq) {
@@ -48,7 +50,7 @@ void Trainer::Train() {
const float logsum = log(sum);
- CHECK(final_pieces_.empty());
+ CHECK_OR_RETURN(final_pieces_.empty());
for (const auto &it : Sorted(freq)) {
if (it.first.find(kUNKStr) != std::string::npos) {
continue;
@@ -59,7 +61,7 @@ void Trainer::Train() {
final_pieces_.emplace_back(it.first, log(it.second) - logsum);
}
- Save();
+ return Save();
}
} // namespace word
} // namespace sentencepiece
diff --git a/src/word_model_trainer.h b/src/word_model_trainer.h
index 672c5ac..1a1aecc 100644
--- a/src/word_model_trainer.h
+++ b/src/word_model_trainer.h
@@ -32,7 +32,7 @@ class Trainer : public TrainerInterface {
const NormalizerSpec &normalizer_spec)
: TrainerInterface::TrainerInterface(trainer_spec, normalizer_spec) {}
- void Train() override;
+ util::Status Train() override;
};
} // namespace word
} // namespace sentencepiece
diff --git a/src/word_model_trainer_test.cc b/src/word_model_trainer_test.cc
index 87d1bfa..35a6eb5 100644
--- a/src/word_model_trainer_test.cc
+++ b/src/word_model_trainer_test.cc
@@ -48,10 +48,10 @@ std::string RunTrainer(const std::vector<std::string> &input, int size) {
normalizer_spec.set_add_dummy_prefix(true);
Trainer trainer(trainer_spec, normalizer_spec);
- trainer.Train();
+ EXPECT_OK(trainer.Train());
SentencePieceProcessor processor;
- processor.Load(model_prefix + ".model");
+ EXPECT_OK(processor.Load(model_prefix + ".model"));
const auto &model = processor.model_proto();
std::vector<std::string> pieces;