Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'src/data/shortlist.cpp')
-rw-r--r--src/data/shortlist.cpp260
1 files changed, 259 insertions, 1 deletions
diff --git a/src/data/shortlist.cpp b/src/data/shortlist.cpp
index 396c6ba4..79d685e0 100644
--- a/src/data/shortlist.cpp
+++ b/src/data/shortlist.cpp
@@ -3,6 +3,8 @@
#include "marian.h"
#include "layers/lsh.h"
+#include <queue>
+
namespace marian {
namespace data {
@@ -279,7 +281,9 @@ Ptr<ShortlistGenerator> createShortlistGenerator(Ptr<Options> options,
std::vector<std::string> vals = options->get<std::vector<std::string>>("shortlist");
ABORT_IF(vals.empty(), "No path to shortlist given");
std::string fname = vals[0];
- if(filesystem::Path(fname).extension().string() == ".bin") {
+ if(isBinaryShortlist(fname)){
+ return New<BinaryShortlistGenerator>(options, srcVocab, trgVocab, srcIdx, trgIdx, shared);
+ } else if(filesystem::Path(fname).extension().string() == ".bin") {
return New<QuicksandShortlistGenerator>(options, srcVocab, trgVocab, srcIdx, trgIdx, shared);
} else {
return New<LexicalShortlistGenerator>(options, srcVocab, trgVocab, srcIdx, trgIdx, shared);
@@ -287,5 +291,259 @@ Ptr<ShortlistGenerator> createShortlistGenerator(Ptr<Options> options,
}
}
+bool isBinaryShortlist(const std::string& fileName){
+ uint64_t magic;
+ io::InputFileStream in(fileName);
+ in.read((char*)(&magic), sizeof(magic));
+ return in && (magic == BINARY_SHORTLIST_MAGIC);
+}
+
+void BinaryShortlistGenerator::contentCheck() {
+ bool failFlag = 0;
+ // The offset table has to be within the size of shortlists.
+ for(int i = 0; i < wordToOffsetSize_-1; i++)
+ failFlag |= wordToOffset_[i] >= shortListsSize_;
+
+ // The last element of wordToOffset_ must equal shortListsSize_
+ failFlag |= wordToOffset_[wordToOffsetSize_-1] != shortListsSize_;
+
+ // The vocabulary indices have to be within the vocabulary size.
+ size_t vSize = trgVocab_->size();
+ for(int j = 0; j < shortListsSize_; j++)
+ failFlag |= shortLists_[j] >= vSize;
+ ABORT_IF(failFlag, "Error: shortlist indices are out of bounds");
+}
+
+// load shortlist from buffer
+void BinaryShortlistGenerator::load(const void* ptr_void, size_t blobSize, bool check /*= true*/) {
+ /* File layout:
+ * header
+ * wordToOffset array
+ * shortLists array
+ */
+ ABORT_IF(blobSize < sizeof(Header), "Shortlist length {} too short to have a header", blobSize);
+
+ const char *ptr = static_cast<const char*>(ptr_void);
+ const Header &header = *reinterpret_cast<const Header*>(ptr);
+ ptr += sizeof(Header);
+ ABORT_IF(header.magic != BINARY_SHORTLIST_MAGIC, "Incorrect magic in binary shortlist");
+
+ uint64_t expectedSize = sizeof(Header) + header.wordToOffsetSize * sizeof(uint64_t) + header.shortListsSize * sizeof(WordIndex);
+ ABORT_IF(expectedSize != blobSize, "Shortlist header claims file size should be {} but file is {}", expectedSize, blobSize);
+
+ if (check) {
+ uint64_t checksumActual = util::hashMem<uint64_t, uint64_t>(&header.firstNum, (blobSize - sizeof(header.magic) - sizeof(header.checksum)) / sizeof(uint64_t));
+ ABORT_IF(checksumActual != header.checksum, "checksum check failed: this binary shortlist is corrupted");
+ }
+
+ firstNum_ = header.firstNum;
+ bestNum_ = header.bestNum;
+ LOG(info, "[data] Lexical short list firstNum {} and bestNum {}", firstNum_, bestNum_);
+
+ wordToOffsetSize_ = header.wordToOffsetSize;
+ shortListsSize_ = header.shortListsSize;
+
+ // Offsets right after header.
+ wordToOffset_ = reinterpret_cast<const uint64_t*>(ptr);
+ ptr += wordToOffsetSize_ * sizeof(uint64_t);
+
+ shortLists_ = reinterpret_cast<const WordIndex*>(ptr);
+
+ // Verify offsets and vocab ids are within bounds if requested by user.
+ if(check)
+ contentCheck();
+}
+
+// load shortlist from file
+void BinaryShortlistGenerator::load(const std::string& filename, bool check /*=true*/) {
+ std::error_code error;
+ mmapMem_.map(filename, error);
+ ABORT_IF(error, "Error mapping file: {}", error.message());
+ load(mmapMem_.data(), mmapMem_.mapped_length(), check);
+}
+
+BinaryShortlistGenerator::BinaryShortlistGenerator(Ptr<Options> options,
+ Ptr<const Vocab> srcVocab,
+ Ptr<const Vocab> trgVocab,
+ size_t srcIdx /*= 0*/,
+ size_t /*trgIdx = 1*/,
+ bool shared /*= false*/)
+ : options_(options),
+ srcVocab_(srcVocab),
+ trgVocab_(trgVocab),
+ srcIdx_(srcIdx),
+ shared_(shared) {
+
+ std::vector<std::string> vals = options_->get<std::vector<std::string>>("shortlist");
+ ABORT_IF(vals.empty(), "No path to shortlist file given");
+ std::string fname = vals[0];
+
+ if(isBinaryShortlist(fname)){
+ bool check = vals.size() > 1 ? std::stoi(vals[1]) : 1;
+ LOG(info, "[data] Loading binary shortlist as {} {}", fname, check);
+ load(fname, check);
+ }
+ else{
+ firstNum_ = vals.size() > 1 ? std::stoi(vals[1]) : 100;
+ bestNum_ = vals.size() > 2 ? std::stoi(vals[2]) : 100;
+ float threshold = vals.size() > 3 ? std::stof(vals[3]) : 0;
+ LOG(info, "[data] Importing text lexical shortlist as {} {} {} {}",
+ fname, firstNum_, bestNum_, threshold);
+ import(fname, threshold);
+ }
+}
+
+BinaryShortlistGenerator::BinaryShortlistGenerator(const void *ptr_void,
+ const size_t blobSize,
+ Ptr<const Vocab> srcVocab,
+ Ptr<const Vocab> trgVocab,
+ size_t srcIdx /*= 0*/,
+ size_t /*trgIdx = 1*/,
+ bool shared /*= false*/,
+ bool check /*= true*/)
+ : srcVocab_(srcVocab),
+ trgVocab_(trgVocab),
+ srcIdx_(srcIdx),
+ shared_(shared) {
+ load(ptr_void, blobSize, check);
+}
+
+Ptr<Shortlist> BinaryShortlistGenerator::generate(Ptr<data::CorpusBatch> batch) const {
+ auto srcBatch = (*batch)[srcIdx_];
+ size_t srcVocabSize = srcVocab_->size();
+ size_t trgVocabSize = trgVocab_->size();
+
+ // Since V=trgVocab_->size() is not large, anchor the time and space complexity to O(V).
+ // Attempt to squeeze the truth tables into CPU cache
+ std::vector<bool> srcTruthTable(srcVocabSize, 0); // holds selected source words
+ std::vector<bool> trgTruthTable(trgVocabSize, 0); // holds selected target words
+
+ // add firstNum most frequent words
+ for(WordIndex i = 0; i < firstNum_ && i < trgVocabSize; ++i)
+ trgTruthTable[i] = 1;
+
+ // collect unique words from source
+ // add aligned target words: mark trgTruthTable[word] to 1
+ for(auto word : srcBatch->data()) {
+ WordIndex srcIndex = word.toWordIndex();
+ if(shared_)
+ trgTruthTable[srcIndex] = 1;
+ // If srcIndex has not been encountered, add the corresponding target words
+ if (!srcTruthTable[srcIndex]) {
+ for (uint64_t j = wordToOffset_[srcIndex]; j < wordToOffset_[srcIndex+1]; j++)
+ trgTruthTable[shortLists_[j]] = 1;
+ srcTruthTable[srcIndex] = 1;
+ }
+ }
+
+ // Due to the 'multiple-of-eight' issue, the following O(N) patch is inserted
+ size_t trgTruthTableOnes = 0; // counter for no. of selected target words
+ for (size_t i = 0; i < trgVocabSize; i++) {
+ if(trgTruthTable[i])
+ trgTruthTableOnes++;
+ }
+
+ // Ensure that the generated vocabulary items from a shortlist are a multiple-of-eight
+ // This is necessary until intgemm supports non-multiple-of-eight matrices.
+ for (size_t i = firstNum_; i < trgVocabSize && trgTruthTableOnes%8!=0; i++){
+ if (!trgTruthTable[i]){
+ trgTruthTable[i] = 1;
+ trgTruthTableOnes++;
+ }
+ }
+
+ // turn selected indices into vector and sort (Bucket sort: O(V))
+ std::vector<WordIndex> indices;
+ for (WordIndex i = 0; i < trgVocabSize; i++) {
+ if(trgTruthTable[i])
+ indices.push_back(i);
+ }
+
+ return New<Shortlist>(indices);
+}
+
+void BinaryShortlistGenerator::dump(const std::string& fileName) const {
+ ABORT_IF(mmapMem_.is_open(),"No need to dump again");
+ LOG(info, "[data] Saving binary shortlist dump to {}", fileName);
+ saveBlobToFile(fileName);
+}
+
+void BinaryShortlistGenerator::import(const std::string& filename, double threshold) {
+ io::InputFileStream in(filename);
+ std::string src, trg;
+
+ // Read text file
+ std::vector<std::unordered_map<WordIndex, float>> srcTgtProbTable(srcVocab_->size());
+ float prob;
+
+ while(in >> trg >> src >> prob) {
+ if(src == "NULL" || trg == "NULL")
+ continue;
+
+ auto sId = (*srcVocab_)[src].toWordIndex();
+ auto tId = (*trgVocab_)[trg].toWordIndex();
+
+ if(srcTgtProbTable[sId][tId] < prob)
+ srcTgtProbTable[sId][tId] = prob;
+ }
+
+ // Create priority queue and count
+ std::vector<std::priority_queue<std::pair<float, WordIndex>>> vpq;
+ uint64_t shortListsSize = 0;
+
+ vpq.resize(srcTgtProbTable.size());
+ for(WordIndex sId = 0; sId < srcTgtProbTable.size(); sId++) {
+ uint64_t shortListsSizeCurrent = 0;
+ for(auto entry : srcTgtProbTable[sId]) {
+ if (entry.first>=threshold) {
+ vpq[sId].push(std::make_pair(entry.second, entry.first));
+ if(shortListsSizeCurrent < bestNum_)
+ shortListsSizeCurrent++;
+ }
+ }
+ shortListsSize += shortListsSizeCurrent;
+ }
+
+ wordToOffsetSize_ = vpq.size() + 1;
+ shortListsSize_ = shortListsSize;
+
+ // Generate a binary blob
+ blob_.resize(sizeof(Header) + wordToOffsetSize_ * sizeof(uint64_t) + shortListsSize_ * sizeof(WordIndex));
+ struct Header* pHeader = (struct Header *)blob_.data();
+ pHeader->magic = BINARY_SHORTLIST_MAGIC;
+ pHeader->firstNum = firstNum_;
+ pHeader->bestNum = bestNum_;
+ pHeader->wordToOffsetSize = wordToOffsetSize_;
+ pHeader->shortListsSize = shortListsSize_;
+ uint64_t* wordToOffset = (uint64_t*)((char *)pHeader + sizeof(Header));
+ WordIndex* shortLists = (WordIndex*)((char*)wordToOffset + wordToOffsetSize_*sizeof(uint64_t));
+
+ uint64_t shortlistIdx = 0;
+ for (size_t i = 0; i < wordToOffsetSize_ - 1; i++) {
+ wordToOffset[i] = shortlistIdx;
+ for(int popcnt = 0; popcnt < bestNum_ && !vpq[i].empty(); popcnt++) {
+ shortLists[shortlistIdx] = vpq[i].top().second;
+ shortlistIdx++;
+ vpq[i].pop();
+ }
+ }
+ wordToOffset[wordToOffsetSize_-1] = shortlistIdx;
+
+ // Sort word indices for each shortlist
+ for(int i = 1; i < wordToOffsetSize_; i++) {
+ std::sort(&shortLists[wordToOffset[i-1]], &shortLists[wordToOffset[i]]);
+ }
+ pHeader->checksum = (uint64_t)util::hashMem<uint64_t>((uint64_t *)blob_.data()+2,
+ blob_.size()/sizeof(uint64_t)-2);
+
+ wordToOffset_ = wordToOffset;
+ shortLists_ = shortLists;
+}
+
+void BinaryShortlistGenerator::saveBlobToFile(const std::string& fileName) const {
+ io::OutputFileStream outTop(fileName);
+ outTop.write(blob_.data(), blob_.size());
+}
+
} // namespace data
} // namespace marian