// Copyright 2016 Google Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License.! #ifndef UTIL_H_ #define UTIL_H_ #include #include #include #include #include #include #include #include #include #include #include "common.h" #include "sentencepiece_processor.h" #include "third_party/absl/strings/string_view.h" #ifdef SPM_NO_THREADLOCAL #include #endif namespace sentencepiece { template std::ostream &operator<<(std::ostream &out, const std::vector &v) { for (const auto n : v) { out << " " << n; } return out; } // String utilities namespace string_util { inline absl::string_view ToSV(util::min_string_view data) { return absl::string_view(data.data(), data.size()); } struct string_view_hash { // DJB hash function. inline size_t operator()(const absl::string_view &sp) const { size_t hash = 5381; for (size_t i = 0; i < sp.size(); ++i) { hash = ((hash << 5) + hash) + sp[i]; } return hash; } }; inline std::string ToLower(absl::string_view arg) { std::string lower_value = std::string(arg); std::transform(lower_value.begin(), lower_value.end(), lower_value.begin(), ::tolower); return lower_value; } inline std::string ToUpper(absl::string_view arg) { std::string upper_value = std::string(arg); std::transform(upper_value.begin(), upper_value.end(), upper_value.begin(), ::toupper); return upper_value; } template inline bool lexical_cast(absl::string_view arg, Target *result) { std::stringstream ss; return (ss << arg.data() && ss >> *result); } template <> inline bool lexical_cast(absl::string_view arg, bool *result) { const char *kTrue[] = {"1", "t", "true", "y", "yes"}; const char *kFalse[] = {"0", "f", "false", "n", "no"}; std::string lower_value = std::string(arg); std::transform(lower_value.begin(), lower_value.end(), lower_value.begin(), ::tolower); for (size_t i = 0; i < 5; ++i) { if (lower_value == kTrue[i]) { *result = true; return true; } else if (lower_value == kFalse[i]) { *result = false; return true; } } return false; } template <> inline bool lexical_cast(absl::string_view arg, std::string *result) { *result = std::string(arg); return true; } std::vector Split(const std::string &str, const std::string &delim, bool allow_empty = false); std::vector SplitPiece(absl::string_view str, absl::string_view delim, bool allow_empty = false); std::string Join(const std::vector &tokens, absl::string_view delim); std::string Join(const std::vector &tokens, absl::string_view delim); inline std::string StrCat(absl::string_view str) { return std::string(str.data(), str.size()); } template inline std::string StrCat(absl::string_view first, const T &... rest) { return std::string(first) + StrCat(rest...); } std::string StringReplace(absl::string_view s, absl::string_view oldsub, absl::string_view newsub, bool replace_all); void StringReplace(absl::string_view s, absl::string_view oldsub, absl::string_view newsub, bool replace_all, std::string *res); template inline bool DecodePOD(absl::string_view str, T *result) { CHECK_NOTNULL(result); if (sizeof(*result) != str.size()) { return false; } memcpy(result, str.data(), sizeof(T)); return true; } template inline std::string EncodePOD(const T &value) { std::string s; s.resize(sizeof(T)); memcpy(const_cast(s.data()), &value, sizeof(T)); return s; } inline bool StartsWith(absl::string_view text, absl::string_view prefix) { return prefix.empty() || (text.size() >= prefix.size() && memcmp(text.data(), prefix.data(), prefix.size()) == 0); } inline bool EndsWith(absl::string_view text, absl::string_view suffix) { return suffix.empty() || (text.size() >= suffix.size() && memcmp(text.data() + (text.size() - suffix.size()), suffix.data(), suffix.size()) == 0); } inline bool ConsumePrefix(absl::string_view *str, absl::string_view expected) { if (!StartsWith(*str, expected)) return false; str->remove_prefix(expected.size()); return true; } template inline std::string IntToHex(T value) { std::ostringstream os; os << std::hex << std::uppercase << value; return os.str(); } template inline T HexToInt(absl::string_view value) { T n; std::istringstream is(value.data()); is >> std::hex >> n; return n; } template inline size_t Itoa(T val, char *s) { char *org = s; if (val < 0) { *s++ = '-'; val = -val; } char *t = s; T mod = 0; while (val) { mod = val % 10; *t++ = static_cast(mod) + '0'; val /= 10; } if (s == t) { *t++ = '0'; } *t = '\0'; std::reverse(s, t); return static_cast(t - org); } template std::string SimpleItoa(T val) { char buf[32]; Itoa(val, buf); return std::string(buf); } // Return length of a single UTF-8 source character inline size_t OneCharLen(const char *src) { return "\1\1\1\1\1\1\1\1\1\1\1\1\2\2\3\4"[(*src & 0xFF) >> 4]; } // Return (x & 0xC0) == 0x80; // Since trail bytes are always in [0x80, 0xBF], we can optimize: inline bool IsTrailByte(char x) { return static_cast(x) < -0x40; } inline bool IsValidCodepoint(char32 c) { return (static_cast(c) < 0xD800) || (c >= 0xE000 && c <= 0x10FFFF); } bool IsStructurallyValid(absl::string_view str); using UnicodeText = std::vector; char32 DecodeUTF8(const char *begin, const char *end, size_t *mblen); inline char32 DecodeUTF8(absl::string_view input, size_t *mblen) { return DecodeUTF8(input.data(), input.data() + input.size(), mblen); } inline bool IsValidDecodeUTF8(absl::string_view input, size_t *mblen) { const char32 c = DecodeUTF8(input, mblen); return c != kUnicodeError || *mblen == 3; } size_t EncodeUTF8(char32 c, char *output); std::string UnicodeCharToUTF8(const char32 c); UnicodeText UTF8ToUnicodeText(absl::string_view utf8); std::string UnicodeTextToUTF8(const UnicodeText &utext); } // namespace string_util // other map/ptr utilties namespace port { template bool ContainsKey(const Collection &collection, const Key &key) { return collection.find(key) != collection.end(); } template const typename Collection::value_type::second_type &FindOrDie( const Collection &collection, const typename Collection::value_type::first_type &key) { typename Collection::const_iterator it = collection.find(key); CHECK(it != collection.end()) << "Map key not found: " << key; return it->second; } template const typename Collection::value_type::second_type &FindWithDefault( const Collection &collection, const typename Collection::value_type::first_type &key, const typename Collection::value_type::second_type &value) { typename Collection::const_iterator it = collection.find(key); if (it == collection.end()) { return value; } return it->second; } template bool InsertIfNotPresent(Collection *const collection, const typename Collection::value_type &vt) { return collection->insert(vt).second; } template bool InsertIfNotPresent( Collection *const collection, const typename Collection::value_type::first_type &key, const typename Collection::value_type::second_type &value) { return InsertIfNotPresent(collection, typename Collection::value_type(key, value)); } template void InsertOrDie(Collection *const collection, const typename Collection::value_type::first_type &key, const typename Collection::value_type::second_type &data) { CHECK(InsertIfNotPresent(collection, key, data)) << "duplicate key"; } // hash inline void mix(uint64 &a, uint64 &b, uint64 &c) { // 64bit version a -= b; a -= c; a ^= (c >> 43); b -= c; b -= a; b ^= (a << 9); c -= a; c -= b; c ^= (b >> 8); a -= b; a -= c; a ^= (c >> 38); b -= c; b -= a; b ^= (a << 23); c -= a; c -= b; c ^= (b >> 5); a -= b; a -= c; a ^= (c >> 35); b -= c; b -= a; b ^= (a << 49); c -= a; c -= b; c ^= (b >> 11); a -= b; a -= c; a ^= (c >> 12); b -= c; b -= a; b ^= (a << 18); c -= a; c -= b; c ^= (b >> 22); } inline uint64 FingerprintCat(uint64 x, uint64 y) { uint64 b = 0xe08c1d668b756f82; // more of the golden ratio mix(x, b, y); return y; } // Trait to select overloads and return types for MakeUnique. template struct MakeUniqueResult { using scalar = std::unique_ptr; }; template struct MakeUniqueResult { using array = std::unique_ptr; }; template struct MakeUniqueResult { using invalid = void; }; // MakeUnique(...) is an early implementation of C++14 std::make_unique. // It is designed to be 100% compatible with std::make_unique so that the // eventual switchover will be a simple renaming operation. template typename MakeUniqueResult::scalar MakeUnique(Args &&... args) { // NOLINT return std::unique_ptr( new T(std::forward(args)...)); // NOLINT(build/c++11) } // Overload for array of unknown bound. // The allocation of arrays needs to use the array form of new, // and cannot take element constructor arguments. template typename MakeUniqueResult::array MakeUnique(size_t n) { return std::unique_ptr(new typename std::remove_extent::type[n]()); } // Reject arrays of known bound. template typename MakeUniqueResult::invalid MakeUnique(Args &&... /* args */) = delete; // NOLINT template void STLDeleteElements(std::vector *vec) { for (auto item : *vec) { delete item; } vec->clear(); } } // namespace port namespace random { std::mt19937 *GetRandomGenerator(); } // namespace random namespace util { inline std::string JoinPath(absl::string_view path) { return std::string(path.data(), path.size()); } template inline std::string JoinPath(absl::string_view first, const T &... rest) { #ifdef OS_WIN return JoinPath(first) + "\\" + JoinPath(rest...); #else return JoinPath(first) + "/" + JoinPath(rest...); #endif } std::string StrError(int errnum); inline Status OkStatus() { return Status(); } #define DECLARE_ERROR(FUNC, CODE) \ inline util::Status FUNC##Error(absl::string_view str) { \ return util::Status(error::CODE, str.data()); \ } \ inline bool Is##FUNC(const util::Status &status) { \ return status.code() == error::CODE; \ } DECLARE_ERROR(Cancelled, CANCELLED) DECLARE_ERROR(InvalidArgument, INVALID_ARGUMENT) DECLARE_ERROR(NotFound, NOT_FOUND) DECLARE_ERROR(AlreadyExists, ALREADY_EXISTS) DECLARE_ERROR(ResourceExhausted, RESOURCE_EXHAUSTED) DECLARE_ERROR(Unavailable, UNAVAILABLE) DECLARE_ERROR(FailedPrecondition, FAILED_PRECONDITION) DECLARE_ERROR(OutOfRange, OUT_OF_RANGE) DECLARE_ERROR(Unimplemented, UNIMPLEMENTED) DECLARE_ERROR(Internal, INTERNAL) DECLARE_ERROR(Aborted, ABORTED) DECLARE_ERROR(DeadlineExceeded, DEADLINE_EXCEEDED) DECLARE_ERROR(DataLoss, DATA_LOSS) DECLARE_ERROR(Unknown, UNKNOWN) DECLARE_ERROR(PermissionDenied, PERMISSION_DENIED) DECLARE_ERROR(Unauthenticated, UNAUTHENTICATED) class StatusBuilder { public: explicit StatusBuilder(error::Code code) : code_(code) {} template StatusBuilder &operator<<(const T &value) { os_ << value; return *this; } operator Status() const { return Status(code_, os_.str()); } private: error::Code code_; std::ostringstream os_; }; #define CHECK_OR_RETURN(condition) \ if (condition) { \ } else /* NOLINT */ \ return ::sentencepiece::util::StatusBuilder(util::error::INTERNAL) \ << __FILE__ << "(" << __LINE__ << ") [" << #condition << "] " #define CHECK_EQ_OR_RETURN(a, b) CHECK_OR_RETURN((a) == (b)) #define CHECK_NE_OR_RETURN(a, b) CHECK_OR_RETURN((a) != (b)) #define CHECK_GE_OR_RETURN(a, b) CHECK_OR_RETURN((a) >= (b)) #define CHECK_LE_OR_RETURN(a, b) CHECK_OR_RETURN((a) <= (b)) #define CHECK_GT_OR_RETURN(a, b) CHECK_OR_RETURN((a) > (b)) #define CHECK_LT_OR_RETURN(a, b) CHECK_OR_RETURN((a) < (b)) } // namespace util } // namespace sentencepiece #endif // UTIL_H_