diff options
author | Qianqian Zhu <qianqian.zhu@hotmail.com> | 2021-07-11 08:56:58 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-07-11 08:56:58 +0300 |
commit | 42f0b8b74bba16fed646c8af7b2f75e02af7a85c (patch) | |
tree | 5ce1c621aa65f286ff82c8374f9e8605a7b873b3 | |
parent | 7e6ea51841025d5abdb6fdb1fc33dc4907355dc9 (diff) |
Binary shortlist (#856)
Co-authored-by: Kenneth Heafield <github@kheafield.com>
-rw-r--r-- | CHANGELOG.md | 2 | ||||
-rw-r--r-- | src/command/marian_conv.cpp | 26 | ||||
-rw-r--r-- | src/common/hash.h | 14 | ||||
-rw-r--r-- | src/data/shortlist.cpp | 260 | ||||
-rw-r--r-- | src/data/shortlist.h | 74 | ||||
-rw-r--r-- | src/translator/translator.h | 10 |
6 files changed, 372 insertions, 14 deletions
diff --git a/CHANGELOG.md b/CHANGELOG.md index b95ffb5e..a9e24f57 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -28,6 +28,8 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. - Add unit tests for binary files. - Fix compilation with OMP - Compute aligned memory sizes using exact sizing +- Support for loading lexical shortlist from a binary blob +- Integrate a shortlist converter (which can convert a text lexical shortlist to a binary shortlist) into marian-conv with --shortlist option ### Fixed - Added support to MPIWrappest::bcast (and similar) for count of type size_t diff --git a/src/command/marian_conv.cpp b/src/command/marian_conv.cpp index e0e89d2b..943f61d4 100644 --- a/src/command/marian_conv.cpp +++ b/src/command/marian_conv.cpp @@ -3,7 +3,7 @@ #include "tensors/cpu/expression_graph_packable.h" #include "onnx/expression_graph_onnx_exporter.h" #include "layers/lsh.h" - +#include "data/shortlist.h" #include <sstream> int main(int argc, char** argv) { @@ -16,7 +16,8 @@ int main(int argc, char** argv) { YAML::Node config; // @TODO: get rid of YAML::Node here entirely to avoid the pattern. Currently not fixing as it requires more changes to the Options object. auto cli = New<cli::CLIWrapper>( config, - "Convert a model in the .npz format and normal memory layout to a mmap-able binary model which could be in normal memory layout or packed memory layout", + "Convert a model in the .npz format and normal memory layout to a mmap-able binary model which could be in normal memory layout or packed memory layout\n" + "or convert a text lexical shortlist to a binary shortlist with {--shortlist,-s} option", "Allowed options", "Examples:\n" " ./marian-conv -f model.npz -t model.bin --gemm-type packed16"); @@ -30,9 +31,30 @@ int main(int argc, char** argv) { "Encode output matrix and optional rotation matrix into model file. " "arg1: number of bits in LSH encoding, arg2: name of output weights matrix")->implicit_val("1024 Wemb"); cli->add<std::vector<std::string>>("--vocabs,-V", "Vocabulary file, required for ONNX export"); + cli->add<std::vector<std::string>>("--shortlist,-s", "Shortlist conversion: filePath firstNum bestNum threshold"); + cli->add<std::string>("--dump-shortlist,-d", "Binary shortlist dump path","lex.bin"); cli->parse(argc, argv); options->merge(config); } + + // shortlist conversion: + // ./marian-conv --shortlist lex.esen.s2t 100 100 0 --dump-shortlist lex.esen.bin --vocabs vocab.esen.spm vocab.esen.spm + if(options->hasAndNotEmpty("shortlist")){ + auto vocabPaths = options->get<std::vector<std::string>>("vocabs"); + auto dumpPath = options->get<std::string>("dump-shortlist"); + + Ptr<Vocab> srcVocab = New<Vocab>(options, 0); + srcVocab->load(vocabPaths[0]); + Ptr<Vocab> trgVocab = New<Vocab>(options, 1); + trgVocab->load(vocabPaths[1]); + + Ptr<const data::ShortlistGenerator> binaryShortlistGenerator + = New<data::BinaryShortlistGenerator>(options, srcVocab, trgVocab, 0, 1, vocabPaths[0] == vocabPaths[1]); + binaryShortlistGenerator->dump(dumpPath); + LOG(info, "Dumping of the shortlist is finished"); + return 0; + } + auto modelFrom = options->get<std::string>("from"); auto modelTo = options->get<std::string>("to"); diff --git a/src/common/hash.h b/src/common/hash.h index 1b24dbe2..7aca30de 100644 --- a/src/common/hash.h +++ b/src/common/hash.h @@ -10,20 +10,20 @@ template <class T> using hash = std::hash<T>; // This combinator is based on boost::hash_combine, but uses // std::hash as the hash implementation. Used as a drop-in // replacement for boost::hash_combine. -template <class T> -inline void hash_combine(std::size_t& seed, T const& v) { +template <class T, class HashType = std::size_t> +inline void hash_combine(HashType& seed, T const& v) { hash<T> hasher; - seed ^= hasher(v) + 0x9e3779b9 + (seed<<6) + (seed>>2); + seed ^= static_cast<HashType>(hasher(v)) + 0x9e3779b9 + (seed<<6) + (seed>>2); } // Hash a whole chunk of memory, mostly used for diagnostics -template <class T> -inline size_t hashMem(const T* beg, size_t len) { - size_t seed = 0; +template <class T, class HashType = std::size_t> +inline HashType hashMem(const T* beg, size_t len) { + HashType seed = 0; for(auto it = beg; it < beg + len; ++it) hash_combine(seed, *it); return seed; } } -}
\ No newline at end of file +} diff --git a/src/data/shortlist.cpp b/src/data/shortlist.cpp index 396c6ba4..79d685e0 100644 --- a/src/data/shortlist.cpp +++ b/src/data/shortlist.cpp @@ -3,6 +3,8 @@ #include "marian.h" #include "layers/lsh.h" +#include <queue> + namespace marian { namespace data { @@ -279,7 +281,9 @@ Ptr<ShortlistGenerator> createShortlistGenerator(Ptr<Options> options, std::vector<std::string> vals = options->get<std::vector<std::string>>("shortlist"); ABORT_IF(vals.empty(), "No path to shortlist given"); std::string fname = vals[0]; - if(filesystem::Path(fname).extension().string() == ".bin") { + if(isBinaryShortlist(fname)){ + return New<BinaryShortlistGenerator>(options, srcVocab, trgVocab, srcIdx, trgIdx, shared); + } else if(filesystem::Path(fname).extension().string() == ".bin") { return New<QuicksandShortlistGenerator>(options, srcVocab, trgVocab, srcIdx, trgIdx, shared); } else { return New<LexicalShortlistGenerator>(options, srcVocab, trgVocab, srcIdx, trgIdx, shared); @@ -287,5 +291,259 @@ Ptr<ShortlistGenerator> createShortlistGenerator(Ptr<Options> options, } } +bool isBinaryShortlist(const std::string& fileName){ + uint64_t magic; + io::InputFileStream in(fileName); + in.read((char*)(&magic), sizeof(magic)); + return in && (magic == BINARY_SHORTLIST_MAGIC); +} + +void BinaryShortlistGenerator::contentCheck() { + bool failFlag = 0; + // The offset table has to be within the size of shortlists. + for(int i = 0; i < wordToOffsetSize_-1; i++) + failFlag |= wordToOffset_[i] >= shortListsSize_; + + // The last element of wordToOffset_ must equal shortListsSize_ + failFlag |= wordToOffset_[wordToOffsetSize_-1] != shortListsSize_; + + // The vocabulary indices have to be within the vocabulary size. + size_t vSize = trgVocab_->size(); + for(int j = 0; j < shortListsSize_; j++) + failFlag |= shortLists_[j] >= vSize; + ABORT_IF(failFlag, "Error: shortlist indices are out of bounds"); +} + +// load shortlist from buffer +void BinaryShortlistGenerator::load(const void* ptr_void, size_t blobSize, bool check /*= true*/) { + /* File layout: + * header + * wordToOffset array + * shortLists array + */ + ABORT_IF(blobSize < sizeof(Header), "Shortlist length {} too short to have a header", blobSize); + + const char *ptr = static_cast<const char*>(ptr_void); + const Header &header = *reinterpret_cast<const Header*>(ptr); + ptr += sizeof(Header); + ABORT_IF(header.magic != BINARY_SHORTLIST_MAGIC, "Incorrect magic in binary shortlist"); + + uint64_t expectedSize = sizeof(Header) + header.wordToOffsetSize * sizeof(uint64_t) + header.shortListsSize * sizeof(WordIndex); + ABORT_IF(expectedSize != blobSize, "Shortlist header claims file size should be {} but file is {}", expectedSize, blobSize); + + if (check) { + uint64_t checksumActual = util::hashMem<uint64_t, uint64_t>(&header.firstNum, (blobSize - sizeof(header.magic) - sizeof(header.checksum)) / sizeof(uint64_t)); + ABORT_IF(checksumActual != header.checksum, "checksum check failed: this binary shortlist is corrupted"); + } + + firstNum_ = header.firstNum; + bestNum_ = header.bestNum; + LOG(info, "[data] Lexical short list firstNum {} and bestNum {}", firstNum_, bestNum_); + + wordToOffsetSize_ = header.wordToOffsetSize; + shortListsSize_ = header.shortListsSize; + + // Offsets right after header. + wordToOffset_ = reinterpret_cast<const uint64_t*>(ptr); + ptr += wordToOffsetSize_ * sizeof(uint64_t); + + shortLists_ = reinterpret_cast<const WordIndex*>(ptr); + + // Verify offsets and vocab ids are within bounds if requested by user. + if(check) + contentCheck(); +} + +// load shortlist from file +void BinaryShortlistGenerator::load(const std::string& filename, bool check /*=true*/) { + std::error_code error; + mmapMem_.map(filename, error); + ABORT_IF(error, "Error mapping file: {}", error.message()); + load(mmapMem_.data(), mmapMem_.mapped_length(), check); +} + +BinaryShortlistGenerator::BinaryShortlistGenerator(Ptr<Options> options, + Ptr<const Vocab> srcVocab, + Ptr<const Vocab> trgVocab, + size_t srcIdx /*= 0*/, + size_t /*trgIdx = 1*/, + bool shared /*= false*/) + : options_(options), + srcVocab_(srcVocab), + trgVocab_(trgVocab), + srcIdx_(srcIdx), + shared_(shared) { + + std::vector<std::string> vals = options_->get<std::vector<std::string>>("shortlist"); + ABORT_IF(vals.empty(), "No path to shortlist file given"); + std::string fname = vals[0]; + + if(isBinaryShortlist(fname)){ + bool check = vals.size() > 1 ? std::stoi(vals[1]) : 1; + LOG(info, "[data] Loading binary shortlist as {} {}", fname, check); + load(fname, check); + } + else{ + firstNum_ = vals.size() > 1 ? std::stoi(vals[1]) : 100; + bestNum_ = vals.size() > 2 ? std::stoi(vals[2]) : 100; + float threshold = vals.size() > 3 ? std::stof(vals[3]) : 0; + LOG(info, "[data] Importing text lexical shortlist as {} {} {} {}", + fname, firstNum_, bestNum_, threshold); + import(fname, threshold); + } +} + +BinaryShortlistGenerator::BinaryShortlistGenerator(const void *ptr_void, + const size_t blobSize, + Ptr<const Vocab> srcVocab, + Ptr<const Vocab> trgVocab, + size_t srcIdx /*= 0*/, + size_t /*trgIdx = 1*/, + bool shared /*= false*/, + bool check /*= true*/) + : srcVocab_(srcVocab), + trgVocab_(trgVocab), + srcIdx_(srcIdx), + shared_(shared) { + load(ptr_void, blobSize, check); +} + +Ptr<Shortlist> BinaryShortlistGenerator::generate(Ptr<data::CorpusBatch> batch) const { + auto srcBatch = (*batch)[srcIdx_]; + size_t srcVocabSize = srcVocab_->size(); + size_t trgVocabSize = trgVocab_->size(); + + // Since V=trgVocab_->size() is not large, anchor the time and space complexity to O(V). + // Attempt to squeeze the truth tables into CPU cache + std::vector<bool> srcTruthTable(srcVocabSize, 0); // holds selected source words + std::vector<bool> trgTruthTable(trgVocabSize, 0); // holds selected target words + + // add firstNum most frequent words + for(WordIndex i = 0; i < firstNum_ && i < trgVocabSize; ++i) + trgTruthTable[i] = 1; + + // collect unique words from source + // add aligned target words: mark trgTruthTable[word] to 1 + for(auto word : srcBatch->data()) { + WordIndex srcIndex = word.toWordIndex(); + if(shared_) + trgTruthTable[srcIndex] = 1; + // If srcIndex has not been encountered, add the corresponding target words + if (!srcTruthTable[srcIndex]) { + for (uint64_t j = wordToOffset_[srcIndex]; j < wordToOffset_[srcIndex+1]; j++) + trgTruthTable[shortLists_[j]] = 1; + srcTruthTable[srcIndex] = 1; + } + } + + // Due to the 'multiple-of-eight' issue, the following O(N) patch is inserted + size_t trgTruthTableOnes = 0; // counter for no. of selected target words + for (size_t i = 0; i < trgVocabSize; i++) { + if(trgTruthTable[i]) + trgTruthTableOnes++; + } + + // Ensure that the generated vocabulary items from a shortlist are a multiple-of-eight + // This is necessary until intgemm supports non-multiple-of-eight matrices. + for (size_t i = firstNum_; i < trgVocabSize && trgTruthTableOnes%8!=0; i++){ + if (!trgTruthTable[i]){ + trgTruthTable[i] = 1; + trgTruthTableOnes++; + } + } + + // turn selected indices into vector and sort (Bucket sort: O(V)) + std::vector<WordIndex> indices; + for (WordIndex i = 0; i < trgVocabSize; i++) { + if(trgTruthTable[i]) + indices.push_back(i); + } + + return New<Shortlist>(indices); +} + +void BinaryShortlistGenerator::dump(const std::string& fileName) const { + ABORT_IF(mmapMem_.is_open(),"No need to dump again"); + LOG(info, "[data] Saving binary shortlist dump to {}", fileName); + saveBlobToFile(fileName); +} + +void BinaryShortlistGenerator::import(const std::string& filename, double threshold) { + io::InputFileStream in(filename); + std::string src, trg; + + // Read text file + std::vector<std::unordered_map<WordIndex, float>> srcTgtProbTable(srcVocab_->size()); + float prob; + + while(in >> trg >> src >> prob) { + if(src == "NULL" || trg == "NULL") + continue; + + auto sId = (*srcVocab_)[src].toWordIndex(); + auto tId = (*trgVocab_)[trg].toWordIndex(); + + if(srcTgtProbTable[sId][tId] < prob) + srcTgtProbTable[sId][tId] = prob; + } + + // Create priority queue and count + std::vector<std::priority_queue<std::pair<float, WordIndex>>> vpq; + uint64_t shortListsSize = 0; + + vpq.resize(srcTgtProbTable.size()); + for(WordIndex sId = 0; sId < srcTgtProbTable.size(); sId++) { + uint64_t shortListsSizeCurrent = 0; + for(auto entry : srcTgtProbTable[sId]) { + if (entry.first>=threshold) { + vpq[sId].push(std::make_pair(entry.second, entry.first)); + if(shortListsSizeCurrent < bestNum_) + shortListsSizeCurrent++; + } + } + shortListsSize += shortListsSizeCurrent; + } + + wordToOffsetSize_ = vpq.size() + 1; + shortListsSize_ = shortListsSize; + + // Generate a binary blob + blob_.resize(sizeof(Header) + wordToOffsetSize_ * sizeof(uint64_t) + shortListsSize_ * sizeof(WordIndex)); + struct Header* pHeader = (struct Header *)blob_.data(); + pHeader->magic = BINARY_SHORTLIST_MAGIC; + pHeader->firstNum = firstNum_; + pHeader->bestNum = bestNum_; + pHeader->wordToOffsetSize = wordToOffsetSize_; + pHeader->shortListsSize = shortListsSize_; + uint64_t* wordToOffset = (uint64_t*)((char *)pHeader + sizeof(Header)); + WordIndex* shortLists = (WordIndex*)((char*)wordToOffset + wordToOffsetSize_*sizeof(uint64_t)); + + uint64_t shortlistIdx = 0; + for (size_t i = 0; i < wordToOffsetSize_ - 1; i++) { + wordToOffset[i] = shortlistIdx; + for(int popcnt = 0; popcnt < bestNum_ && !vpq[i].empty(); popcnt++) { + shortLists[shortlistIdx] = vpq[i].top().second; + shortlistIdx++; + vpq[i].pop(); + } + } + wordToOffset[wordToOffsetSize_-1] = shortlistIdx; + + // Sort word indices for each shortlist + for(int i = 1; i < wordToOffsetSize_; i++) { + std::sort(&shortLists[wordToOffset[i-1]], &shortLists[wordToOffset[i]]); + } + pHeader->checksum = (uint64_t)util::hashMem<uint64_t>((uint64_t *)blob_.data()+2, + blob_.size()/sizeof(uint64_t)-2); + + wordToOffset_ = wordToOffset; + shortLists_ = shortLists; +} + +void BinaryShortlistGenerator::saveBlobToFile(const std::string& fileName) const { + io::OutputFileStream outTop(fileName); + outTop.write(blob_.data(), blob_.size()); +} + } // namespace data } // namespace marian diff --git a/src/data/shortlist.h b/src/data/shortlist.h index d3841b21..f15e5455 100644 --- a/src/data/shortlist.h +++ b/src/data/shortlist.h @@ -256,7 +256,6 @@ public: bestNum_ = vals.size() > 2 ? std::stoi(vals[2]) : 100; float threshold = vals.size() > 3 ? std::stof(vals[3]) : 0; std::string dumpPath = vals.size() > 4 ? vals[4] : ""; - LOG(info, "[data] Loading lexical shortlist as {} {} {} {}", fname, @@ -392,5 +391,78 @@ Ptr<ShortlistGenerator> createShortlistGenerator(Ptr<Options> options, size_t trgIdx = 1, bool shared = false); +// Magic signature for binary shortlist: +// ASCII and Unicode text files never start with the following 64 bits +const uint64_t BINARY_SHORTLIST_MAGIC = 0xF11A48D5013417F5; + +bool isBinaryShortlist(const std::string& fileName); + +class BinaryShortlistGenerator : public ShortlistGenerator { +private: + Ptr<Options> options_; + Ptr<const Vocab> srcVocab_; + Ptr<const Vocab> trgVocab_; + + size_t srcIdx_; + bool shared_{false}; + + uint64_t firstNum_{100}; // baked into binary header + uint64_t bestNum_{100}; // baked into binary header + + // shortlist is stored in a skip list + // [&shortLists_[wordToOffset_[word]], &shortLists_[wordToOffset_[word+1]]) + // is a sorted array of word indices in the shortlist for word + mio::mmap_source mmapMem_; + uint64_t wordToOffsetSize_; + uint64_t shortListsSize_; + const uint64_t *wordToOffset_; + const WordIndex *shortLists_; + std::vector<char> blob_; // binary blob + + struct Header { + uint64_t magic; // BINARY_SHORTLIST_MAGIC + uint64_t checksum; // util::hashMem<uint64_t, uint64_t> from &firstNum to end of file. + uint64_t firstNum; // Limits used to create the shortlist. + uint64_t bestNum; + uint64_t wordToOffsetSize; // Length of wordToOffset_ array. + uint64_t shortListsSize; // Length of shortLists_ array. + }; + + void contentCheck(); + // load shortlist from buffer + void load(const void* ptr_void, size_t blobSize, bool check = true); + // load shortlist from file + void load(const std::string& filename, bool check=true); + // import text shortlist from file + void import(const std::string& filename, double threshold); + // save blob to file (called by dump) + void saveBlobToFile(const std::string& filename) const; + +public: + BinaryShortlistGenerator(Ptr<Options> options, + Ptr<const Vocab> srcVocab, + Ptr<const Vocab> trgVocab, + size_t srcIdx = 0, + size_t /*trgIdx*/ = 1, + bool shared = false); + + // construct directly from buffer + BinaryShortlistGenerator(const void* ptr_void, + const size_t blobSize, + Ptr<const Vocab> srcVocab, + Ptr<const Vocab> trgVocab, + size_t srcIdx = 0, + size_t /*trgIdx*/ = 1, + bool shared = false, + bool check = true); + + ~BinaryShortlistGenerator(){ + mmapMem_.unmap(); + } + + virtual Ptr<Shortlist> generate(Ptr<data::CorpusBatch> batch) const override; + virtual void dump(const std::string& fileName) const override; +}; + } // namespace data } // namespace marian diff --git a/src/translator/translator.h b/src/translator/translator.h index 8cc301b4..0829f98e 100644 --- a/src/translator/translator.h +++ b/src/translator/translator.h @@ -274,11 +274,15 @@ public: trgVocab_ = New<Vocab>(options_, vocabPaths.size() - 1); trgVocab_->load(vocabPaths.back()); + auto srcVocab = srcVocabs_.front(); + + std::vector<int> lshOpts = options_->get<std::vector<int>>("output-approx-knn"); + ABORT_IF(lshOpts.size() != 0 && lshOpts.size() != 2, "--output-approx-knn takes 2 parameters"); // load lexical shortlist - if(options_->hasAndNotEmpty("shortlist")) - shortlistGenerator_ = New<data::LexicalShortlistGenerator>( - options_, srcVocabs_.front(), trgVocab_, 0, 1, vocabPaths.front() == vocabPaths.back()); + if (lshOpts.size() == 2 || options_->hasAndNotEmpty("shortlist")) { + shortlistGenerator_ = data::createShortlistGenerator(options_, srcVocab, trgVocab_, lshOpts, 0, 1, vocabPaths.front() == vocabPaths.back()); + } // get device IDs auto devices = Config::getDevices(options_); |