From 03cd62e8f94aa93c9f2a2b7d09849c31674dae1f Mon Sep 17 00:00:00 2001 From: Taku Kudo Date: Sat, 1 Sep 2018 23:45:08 +0900 Subject: Added is_binary flag to filesystem API. --- src/filesystem.cc | 31 ++++++++++++++++--------------- src/filesystem.h | 10 ++++++---- src/sentencepiece_processor.cc | 2 +- src/trainer_interface.cc | 2 +- 4 files changed, 24 insertions(+), 21 deletions(-) diff --git a/src/filesystem.cc b/src/filesystem.cc index 2bf2bbb..4f077a4 100644 --- a/src/filesystem.cc +++ b/src/filesystem.cc @@ -17,20 +17,17 @@ #include "util.h" -#ifdef OS_WIN -#define OUTPUT_MODE std::ios::binary | std::ios::out -#else -#define OUTPUT_MODE std::ios::out -#endif - namespace sentencepiece { namespace filesystem { class PosixReadableFile : public ReadableFile { public: - PosixReadableFile(absl::string_view filename) - : is_(filename.empty() ? &std::cin - : new std::ifstream(WPATH(filename.data()))) { + PosixReadableFile(absl::string_view filename, bool is_binary = false) + : is_(filename.empty() + ? &std::cin + : new std::ifstream(WPATH(filename.data()), + is_binary ? std::ios::binary | std::ios::in + : std::ios::out)) { if (!*is_) status_ = util::StatusBuilder(util::error::NOT_FOUND) << "\"" << filename.data() << "\": " << util::StrError(errno); @@ -63,10 +60,12 @@ class PosixReadableFile : public ReadableFile { class PosixWritableFile : public WritableFile { public: - PosixWritableFile(absl::string_view filename) + PosixWritableFile(absl::string_view filename, bool is_binary = false) : os_(filename.empty() ? &std::cout - : new std::ofstream(WPATH(filename.data()), OUTPUT_MODE)) { + : new std::ofstream(WPATH(filename.data()), + is_binary ? std::ios::binary | std::ios::out + : std::ios::out)) { if (!*os_) status_ = util::StatusBuilder(util::error::PERMISSION_DENIED) << "\"" << filename.data() << "\": " << util::StrError(errno); @@ -90,12 +89,14 @@ class PosixWritableFile : public WritableFile { std::ostream *os_; }; -std::unique_ptr NewReadableFile(absl::string_view filename) { - return port::MakeUnique(filename); +std::unique_ptr NewReadableFile(absl::string_view filename, + bool is_binary) { + return port::MakeUnique(filename, is_binary); } -std::unique_ptr NewWritableFile(absl::string_view filename) { - return port::MakeUnique(filename); +std::unique_ptr NewWritableFile(absl::string_view filename, + bool is_binary) { + return port::MakeUnique(filename, is_binary); } } // namespace filesystem diff --git a/src/filesystem.h b/src/filesystem.h index c2de7c9..488321d 100644 --- a/src/filesystem.h +++ b/src/filesystem.h @@ -29,7 +29,7 @@ namespace filesystem { class ReadableFile { public: ReadableFile() {} - explicit ReadableFile(absl::string_view filename) {} + explicit ReadableFile(absl::string_view filename, bool is_binary = false) {} virtual ~ReadableFile() {} virtual util::Status status() const = 0; @@ -40,7 +40,7 @@ class ReadableFile { class WritableFile { public: WritableFile() {} - explicit WritableFile(absl::string_view filename) {} + explicit WritableFile(absl::string_view filename, bool is_binary = false) {} virtual ~WritableFile() {} virtual util::Status status() const = 0; @@ -48,8 +48,10 @@ class WritableFile { virtual bool WriteLine(absl::string_view text) = 0; }; -std::unique_ptr NewReadableFile(absl::string_view filename); -std::unique_ptr NewWritableFile(absl::string_view filename); +std::unique_ptr NewReadableFile(absl::string_view filename, + bool is_binary = false); +std::unique_ptr NewWritableFile(absl::string_view filename, + bool is_binary = false); } // namespace filesystem } // namespace sentencepiece diff --git a/src/sentencepiece_processor.cc b/src/sentencepiece_processor.cc index 1e2ddcb..8d7139d 100644 --- a/src/sentencepiece_processor.cc +++ b/src/sentencepiece_processor.cc @@ -43,7 +43,7 @@ SentencePieceProcessor::SentencePieceProcessor() {} SentencePieceProcessor::~SentencePieceProcessor() {} util::Status SentencePieceProcessor::Load(util::min_string_view filename) { - auto input = filesystem::NewReadableFile(string_util::ToSV(filename)); + auto input = filesystem::NewReadableFile(string_util::ToSV(filename), true); RETURN_IF_ERROR(input->status()); std::string proto; CHECK_OR_RETURN(input->ReadAll(&proto)); diff --git a/src/trainer_interface.cc b/src/trainer_interface.cc index fc8dd28..36313b7 100644 --- a/src/trainer_interface.cc +++ b/src/trainer_interface.cc @@ -353,7 +353,7 @@ util::Status TrainerInterface::SaveModel(absl::string_view filename) const { LOG(INFO) << "Saving model: " << filename; ModelProto model_proto; RETURN_IF_ERROR(Serialize(&model_proto)); - auto output = filesystem::NewWritableFile(filename.data()); + auto output = filesystem::NewWritableFile(filename.data(), true); RETURN_IF_ERROR(output->status()); output->Write(model_proto.SerializeAsString()); return util::OkStatus(); -- cgit v1.2.3