Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorQianqian Zhu <qianqian.zhu@hotmail.com>2021-07-11 08:56:58 +0300
committerGitHub <noreply@github.com>2021-07-11 08:56:58 +0300
commit42f0b8b74bba16fed646c8af7b2f75e02af7a85c (patch)
tree5ce1c621aa65f286ff82c8374f9e8605a7b873b3
parent7e6ea51841025d5abdb6fdb1fc33dc4907355dc9 (diff)
Binary shortlist (#856)
Co-authored-by: Kenneth Heafield <github@kheafield.com>
-rw-r--r--CHANGELOG.md2
-rw-r--r--src/command/marian_conv.cpp26
-rw-r--r--src/common/hash.h14
-rw-r--r--src/data/shortlist.cpp260
-rw-r--r--src/data/shortlist.h74
-rw-r--r--src/translator/translator.h10
6 files changed, 372 insertions, 14 deletions
diff --git a/CHANGELOG.md b/CHANGELOG.md
index b95ffb5e..a9e24f57 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -28,6 +28,8 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
- Add unit tests for binary files.
- Fix compilation with OMP
- Compute aligned memory sizes using exact sizing
+- Support for loading lexical shortlist from a binary blob
+- Integrate a shortlist converter (which can convert a text lexical shortlist to a binary shortlist) into marian-conv with --shortlist option
### Fixed
- Added support to MPIWrappest::bcast (and similar) for count of type size_t
diff --git a/src/command/marian_conv.cpp b/src/command/marian_conv.cpp
index e0e89d2b..943f61d4 100644
--- a/src/command/marian_conv.cpp
+++ b/src/command/marian_conv.cpp
@@ -3,7 +3,7 @@
#include "tensors/cpu/expression_graph_packable.h"
#include "onnx/expression_graph_onnx_exporter.h"
#include "layers/lsh.h"
-
+#include "data/shortlist.h"
#include <sstream>
int main(int argc, char** argv) {
@@ -16,7 +16,8 @@ int main(int argc, char** argv) {
YAML::Node config; // @TODO: get rid of YAML::Node here entirely to avoid the pattern. Currently not fixing as it requires more changes to the Options object.
auto cli = New<cli::CLIWrapper>(
config,
- "Convert a model in the .npz format and normal memory layout to a mmap-able binary model which could be in normal memory layout or packed memory layout",
+ "Convert a model in the .npz format and normal memory layout to a mmap-able binary model which could be in normal memory layout or packed memory layout\n"
+ "or convert a text lexical shortlist to a binary shortlist with {--shortlist,-s} option",
"Allowed options",
"Examples:\n"
" ./marian-conv -f model.npz -t model.bin --gemm-type packed16");
@@ -30,9 +31,30 @@ int main(int argc, char** argv) {
"Encode output matrix and optional rotation matrix into model file. "
"arg1: number of bits in LSH encoding, arg2: name of output weights matrix")->implicit_val("1024 Wemb");
cli->add<std::vector<std::string>>("--vocabs,-V", "Vocabulary file, required for ONNX export");
+ cli->add<std::vector<std::string>>("--shortlist,-s", "Shortlist conversion: filePath firstNum bestNum threshold");
+ cli->add<std::string>("--dump-shortlist,-d", "Binary shortlist dump path","lex.bin");
cli->parse(argc, argv);
options->merge(config);
}
+
+ // shortlist conversion:
+ // ./marian-conv --shortlist lex.esen.s2t 100 100 0 --dump-shortlist lex.esen.bin --vocabs vocab.esen.spm vocab.esen.spm
+ if(options->hasAndNotEmpty("shortlist")){
+ auto vocabPaths = options->get<std::vector<std::string>>("vocabs");
+ auto dumpPath = options->get<std::string>("dump-shortlist");
+
+ Ptr<Vocab> srcVocab = New<Vocab>(options, 0);
+ srcVocab->load(vocabPaths[0]);
+ Ptr<Vocab> trgVocab = New<Vocab>(options, 1);
+ trgVocab->load(vocabPaths[1]);
+
+ Ptr<const data::ShortlistGenerator> binaryShortlistGenerator
+ = New<data::BinaryShortlistGenerator>(options, srcVocab, trgVocab, 0, 1, vocabPaths[0] == vocabPaths[1]);
+ binaryShortlistGenerator->dump(dumpPath);
+ LOG(info, "Dumping of the shortlist is finished");
+ return 0;
+ }
+
auto modelFrom = options->get<std::string>("from");
auto modelTo = options->get<std::string>("to");
diff --git a/src/common/hash.h b/src/common/hash.h
index 1b24dbe2..7aca30de 100644
--- a/src/common/hash.h
+++ b/src/common/hash.h
@@ -10,20 +10,20 @@ template <class T> using hash = std::hash<T>;
// This combinator is based on boost::hash_combine, but uses
// std::hash as the hash implementation. Used as a drop-in
// replacement for boost::hash_combine.
-template <class T>
-inline void hash_combine(std::size_t& seed, T const& v) {
+template <class T, class HashType = std::size_t>
+inline void hash_combine(HashType& seed, T const& v) {
hash<T> hasher;
- seed ^= hasher(v) + 0x9e3779b9 + (seed<<6) + (seed>>2);
+ seed ^= static_cast<HashType>(hasher(v)) + 0x9e3779b9 + (seed<<6) + (seed>>2);
}
// Hash a whole chunk of memory, mostly used for diagnostics
-template <class T>
-inline size_t hashMem(const T* beg, size_t len) {
- size_t seed = 0;
+template <class T, class HashType = std::size_t>
+inline HashType hashMem(const T* beg, size_t len) {
+ HashType seed = 0;
for(auto it = beg; it < beg + len; ++it)
hash_combine(seed, *it);
return seed;
}
}
-} \ No newline at end of file
+}
diff --git a/src/data/shortlist.cpp b/src/data/shortlist.cpp
index 396c6ba4..79d685e0 100644
--- a/src/data/shortlist.cpp
+++ b/src/data/shortlist.cpp
@@ -3,6 +3,8 @@
#include "marian.h"
#include "layers/lsh.h"
+#include <queue>
+
namespace marian {
namespace data {
@@ -279,7 +281,9 @@ Ptr<ShortlistGenerator> createShortlistGenerator(Ptr<Options> options,
std::vector<std::string> vals = options->get<std::vector<std::string>>("shortlist");
ABORT_IF(vals.empty(), "No path to shortlist given");
std::string fname = vals[0];
- if(filesystem::Path(fname).extension().string() == ".bin") {
+ if(isBinaryShortlist(fname)){
+ return New<BinaryShortlistGenerator>(options, srcVocab, trgVocab, srcIdx, trgIdx, shared);
+ } else if(filesystem::Path(fname).extension().string() == ".bin") {
return New<QuicksandShortlistGenerator>(options, srcVocab, trgVocab, srcIdx, trgIdx, shared);
} else {
return New<LexicalShortlistGenerator>(options, srcVocab, trgVocab, srcIdx, trgIdx, shared);
@@ -287,5 +291,259 @@ Ptr<ShortlistGenerator> createShortlistGenerator(Ptr<Options> options,
}
}
+bool isBinaryShortlist(const std::string& fileName){
+ uint64_t magic;
+ io::InputFileStream in(fileName);
+ in.read((char*)(&magic), sizeof(magic));
+ return in && (magic == BINARY_SHORTLIST_MAGIC);
+}
+
+void BinaryShortlistGenerator::contentCheck() {
+ bool failFlag = 0;
+ // The offset table has to be within the size of shortlists.
+ for(int i = 0; i < wordToOffsetSize_-1; i++)
+ failFlag |= wordToOffset_[i] >= shortListsSize_;
+
+ // The last element of wordToOffset_ must equal shortListsSize_
+ failFlag |= wordToOffset_[wordToOffsetSize_-1] != shortListsSize_;
+
+ // The vocabulary indices have to be within the vocabulary size.
+ size_t vSize = trgVocab_->size();
+ for(int j = 0; j < shortListsSize_; j++)
+ failFlag |= shortLists_[j] >= vSize;
+ ABORT_IF(failFlag, "Error: shortlist indices are out of bounds");
+}
+
+// load shortlist from buffer
+void BinaryShortlistGenerator::load(const void* ptr_void, size_t blobSize, bool check /*= true*/) {
+ /* File layout:
+ * header
+ * wordToOffset array
+ * shortLists array
+ */
+ ABORT_IF(blobSize < sizeof(Header), "Shortlist length {} too short to have a header", blobSize);
+
+ const char *ptr = static_cast<const char*>(ptr_void);
+ const Header &header = *reinterpret_cast<const Header*>(ptr);
+ ptr += sizeof(Header);
+ ABORT_IF(header.magic != BINARY_SHORTLIST_MAGIC, "Incorrect magic in binary shortlist");
+
+ uint64_t expectedSize = sizeof(Header) + header.wordToOffsetSize * sizeof(uint64_t) + header.shortListsSize * sizeof(WordIndex);
+ ABORT_IF(expectedSize != blobSize, "Shortlist header claims file size should be {} but file is {}", expectedSize, blobSize);
+
+ if (check) {
+ uint64_t checksumActual = util::hashMem<uint64_t, uint64_t>(&header.firstNum, (blobSize - sizeof(header.magic) - sizeof(header.checksum)) / sizeof(uint64_t));
+ ABORT_IF(checksumActual != header.checksum, "checksum check failed: this binary shortlist is corrupted");
+ }
+
+ firstNum_ = header.firstNum;
+ bestNum_ = header.bestNum;
+ LOG(info, "[data] Lexical short list firstNum {} and bestNum {}", firstNum_, bestNum_);
+
+ wordToOffsetSize_ = header.wordToOffsetSize;
+ shortListsSize_ = header.shortListsSize;
+
+ // Offsets right after header.
+ wordToOffset_ = reinterpret_cast<const uint64_t*>(ptr);
+ ptr += wordToOffsetSize_ * sizeof(uint64_t);
+
+ shortLists_ = reinterpret_cast<const WordIndex*>(ptr);
+
+ // Verify offsets and vocab ids are within bounds if requested by user.
+ if(check)
+ contentCheck();
+}
+
+// load shortlist from file
+void BinaryShortlistGenerator::load(const std::string& filename, bool check /*=true*/) {
+ std::error_code error;
+ mmapMem_.map(filename, error);
+ ABORT_IF(error, "Error mapping file: {}", error.message());
+ load(mmapMem_.data(), mmapMem_.mapped_length(), check);
+}
+
+BinaryShortlistGenerator::BinaryShortlistGenerator(Ptr<Options> options,
+ Ptr<const Vocab> srcVocab,
+ Ptr<const Vocab> trgVocab,
+ size_t srcIdx /*= 0*/,
+ size_t /*trgIdx = 1*/,
+ bool shared /*= false*/)
+ : options_(options),
+ srcVocab_(srcVocab),
+ trgVocab_(trgVocab),
+ srcIdx_(srcIdx),
+ shared_(shared) {
+
+ std::vector<std::string> vals = options_->get<std::vector<std::string>>("shortlist");
+ ABORT_IF(vals.empty(), "No path to shortlist file given");
+ std::string fname = vals[0];
+
+ if(isBinaryShortlist(fname)){
+ bool check = vals.size() > 1 ? std::stoi(vals[1]) : 1;
+ LOG(info, "[data] Loading binary shortlist as {} {}", fname, check);
+ load(fname, check);
+ }
+ else{
+ firstNum_ = vals.size() > 1 ? std::stoi(vals[1]) : 100;
+ bestNum_ = vals.size() > 2 ? std::stoi(vals[2]) : 100;
+ float threshold = vals.size() > 3 ? std::stof(vals[3]) : 0;
+ LOG(info, "[data] Importing text lexical shortlist as {} {} {} {}",
+ fname, firstNum_, bestNum_, threshold);
+ import(fname, threshold);
+ }
+}
+
+BinaryShortlistGenerator::BinaryShortlistGenerator(const void *ptr_void,
+ const size_t blobSize,
+ Ptr<const Vocab> srcVocab,
+ Ptr<const Vocab> trgVocab,
+ size_t srcIdx /*= 0*/,
+ size_t /*trgIdx = 1*/,
+ bool shared /*= false*/,
+ bool check /*= true*/)
+ : srcVocab_(srcVocab),
+ trgVocab_(trgVocab),
+ srcIdx_(srcIdx),
+ shared_(shared) {
+ load(ptr_void, blobSize, check);
+}
+
+Ptr<Shortlist> BinaryShortlistGenerator::generate(Ptr<data::CorpusBatch> batch) const {
+ auto srcBatch = (*batch)[srcIdx_];
+ size_t srcVocabSize = srcVocab_->size();
+ size_t trgVocabSize = trgVocab_->size();
+
+ // Since V=trgVocab_->size() is not large, anchor the time and space complexity to O(V).
+ // Attempt to squeeze the truth tables into CPU cache
+ std::vector<bool> srcTruthTable(srcVocabSize, 0); // holds selected source words
+ std::vector<bool> trgTruthTable(trgVocabSize, 0); // holds selected target words
+
+ // add firstNum most frequent words
+ for(WordIndex i = 0; i < firstNum_ && i < trgVocabSize; ++i)
+ trgTruthTable[i] = 1;
+
+ // collect unique words from source
+ // add aligned target words: mark trgTruthTable[word] to 1
+ for(auto word : srcBatch->data()) {
+ WordIndex srcIndex = word.toWordIndex();
+ if(shared_)
+ trgTruthTable[srcIndex] = 1;
+ // If srcIndex has not been encountered, add the corresponding target words
+ if (!srcTruthTable[srcIndex]) {
+ for (uint64_t j = wordToOffset_[srcIndex]; j < wordToOffset_[srcIndex+1]; j++)
+ trgTruthTable[shortLists_[j]] = 1;
+ srcTruthTable[srcIndex] = 1;
+ }
+ }
+
+ // Due to the 'multiple-of-eight' issue, the following O(N) patch is inserted
+ size_t trgTruthTableOnes = 0; // counter for no. of selected target words
+ for (size_t i = 0; i < trgVocabSize; i++) {
+ if(trgTruthTable[i])
+ trgTruthTableOnes++;
+ }
+
+ // Ensure that the generated vocabulary items from a shortlist are a multiple-of-eight
+ // This is necessary until intgemm supports non-multiple-of-eight matrices.
+ for (size_t i = firstNum_; i < trgVocabSize && trgTruthTableOnes%8!=0; i++){
+ if (!trgTruthTable[i]){
+ trgTruthTable[i] = 1;
+ trgTruthTableOnes++;
+ }
+ }
+
+ // turn selected indices into vector and sort (Bucket sort: O(V))
+ std::vector<WordIndex> indices;
+ for (WordIndex i = 0; i < trgVocabSize; i++) {
+ if(trgTruthTable[i])
+ indices.push_back(i);
+ }
+
+ return New<Shortlist>(indices);
+}
+
+void BinaryShortlistGenerator::dump(const std::string& fileName) const {
+ ABORT_IF(mmapMem_.is_open(),"No need to dump again");
+ LOG(info, "[data] Saving binary shortlist dump to {}", fileName);
+ saveBlobToFile(fileName);
+}
+
+void BinaryShortlistGenerator::import(const std::string& filename, double threshold) {
+ io::InputFileStream in(filename);
+ std::string src, trg;
+
+ // Read text file
+ std::vector<std::unordered_map<WordIndex, float>> srcTgtProbTable(srcVocab_->size());
+ float prob;
+
+ while(in >> trg >> src >> prob) {
+ if(src == "NULL" || trg == "NULL")
+ continue;
+
+ auto sId = (*srcVocab_)[src].toWordIndex();
+ auto tId = (*trgVocab_)[trg].toWordIndex();
+
+ if(srcTgtProbTable[sId][tId] < prob)
+ srcTgtProbTable[sId][tId] = prob;
+ }
+
+ // Create priority queue and count
+ std::vector<std::priority_queue<std::pair<float, WordIndex>>> vpq;
+ uint64_t shortListsSize = 0;
+
+ vpq.resize(srcTgtProbTable.size());
+ for(WordIndex sId = 0; sId < srcTgtProbTable.size(); sId++) {
+ uint64_t shortListsSizeCurrent = 0;
+ for(auto entry : srcTgtProbTable[sId]) {
+ if (entry.first>=threshold) {
+ vpq[sId].push(std::make_pair(entry.second, entry.first));
+ if(shortListsSizeCurrent < bestNum_)
+ shortListsSizeCurrent++;
+ }
+ }
+ shortListsSize += shortListsSizeCurrent;
+ }
+
+ wordToOffsetSize_ = vpq.size() + 1;
+ shortListsSize_ = shortListsSize;
+
+ // Generate a binary blob
+ blob_.resize(sizeof(Header) + wordToOffsetSize_ * sizeof(uint64_t) + shortListsSize_ * sizeof(WordIndex));
+ struct Header* pHeader = (struct Header *)blob_.data();
+ pHeader->magic = BINARY_SHORTLIST_MAGIC;
+ pHeader->firstNum = firstNum_;
+ pHeader->bestNum = bestNum_;
+ pHeader->wordToOffsetSize = wordToOffsetSize_;
+ pHeader->shortListsSize = shortListsSize_;
+ uint64_t* wordToOffset = (uint64_t*)((char *)pHeader + sizeof(Header));
+ WordIndex* shortLists = (WordIndex*)((char*)wordToOffset + wordToOffsetSize_*sizeof(uint64_t));
+
+ uint64_t shortlistIdx = 0;
+ for (size_t i = 0; i < wordToOffsetSize_ - 1; i++) {
+ wordToOffset[i] = shortlistIdx;
+ for(int popcnt = 0; popcnt < bestNum_ && !vpq[i].empty(); popcnt++) {
+ shortLists[shortlistIdx] = vpq[i].top().second;
+ shortlistIdx++;
+ vpq[i].pop();
+ }
+ }
+ wordToOffset[wordToOffsetSize_-1] = shortlistIdx;
+
+ // Sort word indices for each shortlist
+ for(int i = 1; i < wordToOffsetSize_; i++) {
+ std::sort(&shortLists[wordToOffset[i-1]], &shortLists[wordToOffset[i]]);
+ }
+ pHeader->checksum = (uint64_t)util::hashMem<uint64_t>((uint64_t *)blob_.data()+2,
+ blob_.size()/sizeof(uint64_t)-2);
+
+ wordToOffset_ = wordToOffset;
+ shortLists_ = shortLists;
+}
+
+void BinaryShortlistGenerator::saveBlobToFile(const std::string& fileName) const {
+ io::OutputFileStream outTop(fileName);
+ outTop.write(blob_.data(), blob_.size());
+}
+
} // namespace data
} // namespace marian
diff --git a/src/data/shortlist.h b/src/data/shortlist.h
index d3841b21..f15e5455 100644
--- a/src/data/shortlist.h
+++ b/src/data/shortlist.h
@@ -256,7 +256,6 @@ public:
bestNum_ = vals.size() > 2 ? std::stoi(vals[2]) : 100;
float threshold = vals.size() > 3 ? std::stof(vals[3]) : 0;
std::string dumpPath = vals.size() > 4 ? vals[4] : "";
-
LOG(info,
"[data] Loading lexical shortlist as {} {} {} {}",
fname,
@@ -392,5 +391,78 @@ Ptr<ShortlistGenerator> createShortlistGenerator(Ptr<Options> options,
size_t trgIdx = 1,
bool shared = false);
+// Magic signature for binary shortlist:
+// ASCII and Unicode text files never start with the following 64 bits
+const uint64_t BINARY_SHORTLIST_MAGIC = 0xF11A48D5013417F5;
+
+bool isBinaryShortlist(const std::string& fileName);
+
+class BinaryShortlistGenerator : public ShortlistGenerator {
+private:
+ Ptr<Options> options_;
+ Ptr<const Vocab> srcVocab_;
+ Ptr<const Vocab> trgVocab_;
+
+ size_t srcIdx_;
+ bool shared_{false};
+
+ uint64_t firstNum_{100}; // baked into binary header
+ uint64_t bestNum_{100}; // baked into binary header
+
+ // shortlist is stored in a skip list
+ // [&shortLists_[wordToOffset_[word]], &shortLists_[wordToOffset_[word+1]])
+ // is a sorted array of word indices in the shortlist for word
+ mio::mmap_source mmapMem_;
+ uint64_t wordToOffsetSize_;
+ uint64_t shortListsSize_;
+ const uint64_t *wordToOffset_;
+ const WordIndex *shortLists_;
+ std::vector<char> blob_; // binary blob
+
+ struct Header {
+ uint64_t magic; // BINARY_SHORTLIST_MAGIC
+ uint64_t checksum; // util::hashMem<uint64_t, uint64_t> from &firstNum to end of file.
+ uint64_t firstNum; // Limits used to create the shortlist.
+ uint64_t bestNum;
+ uint64_t wordToOffsetSize; // Length of wordToOffset_ array.
+ uint64_t shortListsSize; // Length of shortLists_ array.
+ };
+
+ void contentCheck();
+ // load shortlist from buffer
+ void load(const void* ptr_void, size_t blobSize, bool check = true);
+ // load shortlist from file
+ void load(const std::string& filename, bool check=true);
+ // import text shortlist from file
+ void import(const std::string& filename, double threshold);
+ // save blob to file (called by dump)
+ void saveBlobToFile(const std::string& filename) const;
+
+public:
+ BinaryShortlistGenerator(Ptr<Options> options,
+ Ptr<const Vocab> srcVocab,
+ Ptr<const Vocab> trgVocab,
+ size_t srcIdx = 0,
+ size_t /*trgIdx*/ = 1,
+ bool shared = false);
+
+ // construct directly from buffer
+ BinaryShortlistGenerator(const void* ptr_void,
+ const size_t blobSize,
+ Ptr<const Vocab> srcVocab,
+ Ptr<const Vocab> trgVocab,
+ size_t srcIdx = 0,
+ size_t /*trgIdx*/ = 1,
+ bool shared = false,
+ bool check = true);
+
+ ~BinaryShortlistGenerator(){
+ mmapMem_.unmap();
+ }
+
+ virtual Ptr<Shortlist> generate(Ptr<data::CorpusBatch> batch) const override;
+ virtual void dump(const std::string& fileName) const override;
+};
+
} // namespace data
} // namespace marian
diff --git a/src/translator/translator.h b/src/translator/translator.h
index 8cc301b4..0829f98e 100644
--- a/src/translator/translator.h
+++ b/src/translator/translator.h
@@ -274,11 +274,15 @@ public:
trgVocab_ = New<Vocab>(options_, vocabPaths.size() - 1);
trgVocab_->load(vocabPaths.back());
+ auto srcVocab = srcVocabs_.front();
+
+ std::vector<int> lshOpts = options_->get<std::vector<int>>("output-approx-knn");
+ ABORT_IF(lshOpts.size() != 0 && lshOpts.size() != 2, "--output-approx-knn takes 2 parameters");
// load lexical shortlist
- if(options_->hasAndNotEmpty("shortlist"))
- shortlistGenerator_ = New<data::LexicalShortlistGenerator>(
- options_, srcVocabs_.front(), trgVocab_, 0, 1, vocabPaths.front() == vocabPaths.back());
+ if (lshOpts.size() == 2 || options_->hasAndNotEmpty("shortlist")) {
+ shortlistGenerator_ = data::createShortlistGenerator(options_, srcVocab, trgVocab_, lshOpts, 0, 1, vocabPaths.front() == vocabPaths.back());
+ }
// get device IDs
auto devices = Config::getDevices(options_);