diff options
Diffstat (limited to 'src/data/shortlist.cpp')
-rw-r--r-- | src/data/shortlist.cpp | 260 |
1 file changed, 259 insertions, 1 deletion
diff --git a/src/data/shortlist.cpp b/src/data/shortlist.cpp index 396c6ba4..79d685e0 100644 --- a/src/data/shortlist.cpp +++ b/src/data/shortlist.cpp @@ -3,6 +3,8 @@ #include "marian.h" #include "layers/lsh.h" +#include <queue> + namespace marian { namespace data { @@ -279,7 +281,9 @@ Ptr<ShortlistGenerator> createShortlistGenerator(Ptr<Options> options, std::vector<std::string> vals = options->get<std::vector<std::string>>("shortlist"); ABORT_IF(vals.empty(), "No path to shortlist given"); std::string fname = vals[0]; - if(filesystem::Path(fname).extension().string() == ".bin") { + if(isBinaryShortlist(fname)){ + return New<BinaryShortlistGenerator>(options, srcVocab, trgVocab, srcIdx, trgIdx, shared); + } else if(filesystem::Path(fname).extension().string() == ".bin") { return New<QuicksandShortlistGenerator>(options, srcVocab, trgVocab, srcIdx, trgIdx, shared); } else { return New<LexicalShortlistGenerator>(options, srcVocab, trgVocab, srcIdx, trgIdx, shared); @@ -287,5 +291,259 @@ Ptr<ShortlistGenerator> createShortlistGenerator(Ptr<Options> options, } } +bool isBinaryShortlist(const std::string& fileName){ + uint64_t magic; + io::InputFileStream in(fileName); + in.read((char*)(&magic), sizeof(magic)); + return in && (magic == BINARY_SHORTLIST_MAGIC); +} + +void BinaryShortlistGenerator::contentCheck() { + bool failFlag = 0; + // The offset table has to be within the size of shortlists. + for(int i = 0; i < wordToOffsetSize_-1; i++) + failFlag |= wordToOffset_[i] >= shortListsSize_; + + // The last element of wordToOffset_ must equal shortListsSize_ + failFlag |= wordToOffset_[wordToOffsetSize_-1] != shortListsSize_; + + // The vocabulary indices have to be within the vocabulary size. 
+ size_t vSize = trgVocab_->size(); + for(int j = 0; j < shortListsSize_; j++) + failFlag |= shortLists_[j] >= vSize; + ABORT_IF(failFlag, "Error: shortlist indices are out of bounds"); +} + +// load shortlist from buffer +void BinaryShortlistGenerator::load(const void* ptr_void, size_t blobSize, bool check /*= true*/) { + /* File layout: + * header + * wordToOffset array + * shortLists array + */ + ABORT_IF(blobSize < sizeof(Header), "Shortlist length {} too short to have a header", blobSize); + + const char *ptr = static_cast<const char*>(ptr_void); + const Header &header = *reinterpret_cast<const Header*>(ptr); + ptr += sizeof(Header); + ABORT_IF(header.magic != BINARY_SHORTLIST_MAGIC, "Incorrect magic in binary shortlist"); + + uint64_t expectedSize = sizeof(Header) + header.wordToOffsetSize * sizeof(uint64_t) + header.shortListsSize * sizeof(WordIndex); + ABORT_IF(expectedSize != blobSize, "Shortlist header claims file size should be {} but file is {}", expectedSize, blobSize); + + if (check) { + uint64_t checksumActual = util::hashMem<uint64_t, uint64_t>(&header.firstNum, (blobSize - sizeof(header.magic) - sizeof(header.checksum)) / sizeof(uint64_t)); + ABORT_IF(checksumActual != header.checksum, "checksum check failed: this binary shortlist is corrupted"); + } + + firstNum_ = header.firstNum; + bestNum_ = header.bestNum; + LOG(info, "[data] Lexical short list firstNum {} and bestNum {}", firstNum_, bestNum_); + + wordToOffsetSize_ = header.wordToOffsetSize; + shortListsSize_ = header.shortListsSize; + + // Offsets right after header. + wordToOffset_ = reinterpret_cast<const uint64_t*>(ptr); + ptr += wordToOffsetSize_ * sizeof(uint64_t); + + shortLists_ = reinterpret_cast<const WordIndex*>(ptr); + + // Verify offsets and vocab ids are within bounds if requested by user. 
+ if(check) + contentCheck(); +} + +// load shortlist from file +void BinaryShortlistGenerator::load(const std::string& filename, bool check /*=true*/) { + std::error_code error; + mmapMem_.map(filename, error); + ABORT_IF(error, "Error mapping file: {}", error.message()); + load(mmapMem_.data(), mmapMem_.mapped_length(), check); +} + +BinaryShortlistGenerator::BinaryShortlistGenerator(Ptr<Options> options, + Ptr<const Vocab> srcVocab, + Ptr<const Vocab> trgVocab, + size_t srcIdx /*= 0*/, + size_t /*trgIdx = 1*/, + bool shared /*= false*/) + : options_(options), + srcVocab_(srcVocab), + trgVocab_(trgVocab), + srcIdx_(srcIdx), + shared_(shared) { + + std::vector<std::string> vals = options_->get<std::vector<std::string>>("shortlist"); + ABORT_IF(vals.empty(), "No path to shortlist file given"); + std::string fname = vals[0]; + + if(isBinaryShortlist(fname)){ + bool check = vals.size() > 1 ? std::stoi(vals[1]) : 1; + LOG(info, "[data] Loading binary shortlist as {} {}", fname, check); + load(fname, check); + } + else{ + firstNum_ = vals.size() > 1 ? std::stoi(vals[1]) : 100; + bestNum_ = vals.size() > 2 ? std::stoi(vals[2]) : 100; + float threshold = vals.size() > 3 ? 
std::stof(vals[3]) : 0; + LOG(info, "[data] Importing text lexical shortlist as {} {} {} {}", + fname, firstNum_, bestNum_, threshold); + import(fname, threshold); + } +} + +BinaryShortlistGenerator::BinaryShortlistGenerator(const void *ptr_void, + const size_t blobSize, + Ptr<const Vocab> srcVocab, + Ptr<const Vocab> trgVocab, + size_t srcIdx /*= 0*/, + size_t /*trgIdx = 1*/, + bool shared /*= false*/, + bool check /*= true*/) + : srcVocab_(srcVocab), + trgVocab_(trgVocab), + srcIdx_(srcIdx), + shared_(shared) { + load(ptr_void, blobSize, check); +} + +Ptr<Shortlist> BinaryShortlistGenerator::generate(Ptr<data::CorpusBatch> batch) const { + auto srcBatch = (*batch)[srcIdx_]; + size_t srcVocabSize = srcVocab_->size(); + size_t trgVocabSize = trgVocab_->size(); + + // Since V=trgVocab_->size() is not large, anchor the time and space complexity to O(V). + // Attempt to squeeze the truth tables into CPU cache + std::vector<bool> srcTruthTable(srcVocabSize, 0); // holds selected source words + std::vector<bool> trgTruthTable(trgVocabSize, 0); // holds selected target words + + // add firstNum most frequent words + for(WordIndex i = 0; i < firstNum_ && i < trgVocabSize; ++i) + trgTruthTable[i] = 1; + + // collect unique words from source + // add aligned target words: mark trgTruthTable[word] to 1 + for(auto word : srcBatch->data()) { + WordIndex srcIndex = word.toWordIndex(); + if(shared_) + trgTruthTable[srcIndex] = 1; + // If srcIndex has not been encountered, add the corresponding target words + if (!srcTruthTable[srcIndex]) { + for (uint64_t j = wordToOffset_[srcIndex]; j < wordToOffset_[srcIndex+1]; j++) + trgTruthTable[shortLists_[j]] = 1; + srcTruthTable[srcIndex] = 1; + } + } + + // Due to the 'multiple-of-eight' issue, the following O(N) patch is inserted + size_t trgTruthTableOnes = 0; // counter for no. 
of selected target words + for (size_t i = 0; i < trgVocabSize; i++) { + if(trgTruthTable[i]) + trgTruthTableOnes++; + } + + // Ensure that the generated vocabulary items from a shortlist are a multiple-of-eight + // This is necessary until intgemm supports non-multiple-of-eight matrices. + for (size_t i = firstNum_; i < trgVocabSize && trgTruthTableOnes%8!=0; i++){ + if (!trgTruthTable[i]){ + trgTruthTable[i] = 1; + trgTruthTableOnes++; + } + } + + // turn selected indices into vector and sort (Bucket sort: O(V)) + std::vector<WordIndex> indices; + for (WordIndex i = 0; i < trgVocabSize; i++) { + if(trgTruthTable[i]) + indices.push_back(i); + } + + return New<Shortlist>(indices); +} + +void BinaryShortlistGenerator::dump(const std::string& fileName) const { + ABORT_IF(mmapMem_.is_open(),"No need to dump again"); + LOG(info, "[data] Saving binary shortlist dump to {}", fileName); + saveBlobToFile(fileName); +} + +void BinaryShortlistGenerator::import(const std::string& filename, double threshold) { + io::InputFileStream in(filename); + std::string src, trg; + + // Read text file + std::vector<std::unordered_map<WordIndex, float>> srcTgtProbTable(srcVocab_->size()); + float prob; + + while(in >> trg >> src >> prob) { + if(src == "NULL" || trg == "NULL") + continue; + + auto sId = (*srcVocab_)[src].toWordIndex(); + auto tId = (*trgVocab_)[trg].toWordIndex(); + + if(srcTgtProbTable[sId][tId] < prob) + srcTgtProbTable[sId][tId] = prob; + } + + // Create priority queue and count + std::vector<std::priority_queue<std::pair<float, WordIndex>>> vpq; + uint64_t shortListsSize = 0; + + vpq.resize(srcTgtProbTable.size()); + for(WordIndex sId = 0; sId < srcTgtProbTable.size(); sId++) { + uint64_t shortListsSizeCurrent = 0; + for(auto entry : srcTgtProbTable[sId]) { + if (entry.first>=threshold) { + vpq[sId].push(std::make_pair(entry.second, entry.first)); + if(shortListsSizeCurrent < bestNum_) + shortListsSizeCurrent++; + } + } + shortListsSize += shortListsSizeCurrent; + } + 
+ wordToOffsetSize_ = vpq.size() + 1; + shortListsSize_ = shortListsSize; + + // Generate a binary blob + blob_.resize(sizeof(Header) + wordToOffsetSize_ * sizeof(uint64_t) + shortListsSize_ * sizeof(WordIndex)); + struct Header* pHeader = (struct Header *)blob_.data(); + pHeader->magic = BINARY_SHORTLIST_MAGIC; + pHeader->firstNum = firstNum_; + pHeader->bestNum = bestNum_; + pHeader->wordToOffsetSize = wordToOffsetSize_; + pHeader->shortListsSize = shortListsSize_; + uint64_t* wordToOffset = (uint64_t*)((char *)pHeader + sizeof(Header)); + WordIndex* shortLists = (WordIndex*)((char*)wordToOffset + wordToOffsetSize_*sizeof(uint64_t)); + + uint64_t shortlistIdx = 0; + for (size_t i = 0; i < wordToOffsetSize_ - 1; i++) { + wordToOffset[i] = shortlistIdx; + for(int popcnt = 0; popcnt < bestNum_ && !vpq[i].empty(); popcnt++) { + shortLists[shortlistIdx] = vpq[i].top().second; + shortlistIdx++; + vpq[i].pop(); + } + } + wordToOffset[wordToOffsetSize_-1] = shortlistIdx; + + // Sort word indices for each shortlist + for(int i = 1; i < wordToOffsetSize_; i++) { + std::sort(&shortLists[wordToOffset[i-1]], &shortLists[wordToOffset[i]]); + } + pHeader->checksum = (uint64_t)util::hashMem<uint64_t>((uint64_t *)blob_.data()+2, + blob_.size()/sizeof(uint64_t)-2); + + wordToOffset_ = wordToOffset; + shortLists_ = shortLists; +} + +void BinaryShortlistGenerator::saveBlobToFile(const std::string& fileName) const { + io::OutputFileStream outTop(fileName); + outTop.write(blob_.data(), blob_.size()); +} + } // namespace data } // namespace marian |