diff options
Diffstat (limited to 'moses/TranslationModel/ProbingPT/storing.cpp')
-rw-r--r-- | moses/TranslationModel/ProbingPT/storing.cpp | 315 |
1 file changed, 226 insertions, 89 deletions
diff --git a/moses/TranslationModel/ProbingPT/storing.cpp b/moses/TranslationModel/ProbingPT/storing.cpp index 01128c1e4..baf6ae91e 100644 --- a/moses/TranslationModel/ProbingPT/storing.cpp +++ b/moses/TranslationModel/ProbingPT/storing.cpp @@ -1,161 +1,298 @@ +#include <sys/stat.h> +#include <boost/foreach.hpp> +#include "line_splitter.hh" #include "storing.hh" +#include "StoreTarget.h" +#include "StoreVocab.h" +#include "moses/Util.h" +#include "moses/InputFileStream.h" -BinaryFileWriter::BinaryFileWriter (std::string basepath) : os ((basepath + "/binfile.dat").c_str(), std::ios::binary) +using namespace std; + +namespace Moses { - binfile.reserve(10000); //Reserve part of the vector to avoid realocation - it = binfile.begin(); - dist_from_start = 0; //Initialize variables - extra_counter = 0; -} -void BinaryFileWriter::write (std::vector<unsigned char> * bytes) +/////////////////////////////////////////////////////////////////////// +void Node::Add(Table &table, const SourcePhrase &sourcePhrase, size_t pos) { - binfile.insert(it, bytes->begin(), bytes->end()); //Insert the bytes - //Keep track of the offsets - it += bytes->size(); - dist_from_start = distance(binfile.begin(),it); - //Flush the vector to disk every once in a while so that we don't consume too much ram - if (dist_from_start > 9000) { - flush(); + if (pos < sourcePhrase.size()) { + uint64_t vocabId = sourcePhrase[pos]; + + Node *child; + Children::iterator iter = m_children.find(vocabId); + if (iter == m_children.end()) { + // New node. 
Write other children then discard them + BOOST_FOREACH(Children::value_type &valPair, m_children) { + Node &otherChild = valPair.second; + otherChild.Write(table); + } + m_children.clear(); + + // create new node + child = &m_children[vocabId]; + assert(!child->done); + child->key = key + (vocabId << pos); + } else { + child = &iter->second; + } + + child->Add(table, sourcePhrase, pos + 1); + } else { + // this node was written previously 'cos it has rules + done = true; } } -void BinaryFileWriter::flush () +void Node::Write(Table &table) { - //Cast unsigned char to char before writing... - os.write((char *)&binfile[0], dist_from_start); - //Clear the vector: - binfile.clear(); - binfile.reserve(10000); - extra_counter += dist_from_start; //Keep track of the total number of bytes. - it = binfile.begin(); //Reset iterator - dist_from_start = distance(binfile.begin(),it); //Reset dist from start -} + //cerr << "START write " << done << " " << key << endl; + BOOST_FOREACH(Children::value_type &valPair, m_children) { + Node &child = valPair.second; + child.Write(table); + } -BinaryFileWriter::~BinaryFileWriter () -{ - os.close(); - binfile.clear(); + if (!done) { + // save + Entry sourceEntry; + sourceEntry.value = NONE; + sourceEntry.key = key; + + //Put into table + table.Insert(sourceEntry); + } } -void createProbingPT(const char * phrasetable_path, const char * target_path, - const char * num_scores, const char * is_reordering) +/////////////////////////////////////////////////////////////////////// +void createProbingPT(const std::string &phrasetable_path, + const std::string &basepath, int num_scores, int num_lex_scores, + bool log_prob, int max_cache_size, bool scfg) { + std::cerr << "Starting..." << std::endl; + //Get basepath and create directory if missing - std::string basepath(target_path); mkdir(basepath.c_str(), S_IRWXU | S_IRWXG | S_IROTH | S_IXOTH); - //Set up huffman and serialize decoder maps. 
- Huffman huffmanEncoder(phrasetable_path); //initialize - huffmanEncoder.assign_values(); - huffmanEncoder.produce_lookups(); - huffmanEncoder.serialize_maps(target_path); + StoreTarget storeTarget(basepath); //Get uniq lines: - unsigned long uniq_entries = huffmanEncoder.getUniqLines(); + unsigned long uniq_entries = countUniqueSource(phrasetable_path); //Source phrase vocabids - std::map<uint64_t, std::string> source_vocabids; + StoreVocab<uint64_t> sourceVocab(basepath + "/source_vocabids"); //Read the file - util::FilePiece filein(phrasetable_path); + util::FilePiece filein(phrasetable_path.c_str()); //Init the probing hash table size_t size = Table::Size(uniq_entries, 1.2); char * mem = new char[size]; memset(mem, 0, size); - Table table(mem, size); + Table sourceEntries(mem, size); - BinaryFileWriter binfile(basepath); //Init the binary file writer. - - line_text prev_line; //Check if the source phrase of the previous line is the same + std::priority_queue<CacheItem*, std::vector<CacheItem*>, CacheItemOrderer> cache; + float totalSourceCount = 0; //Keep track of the size of each group of target phrases - uint64_t entrystartidx = 0; - //uint64_t line_num = 0; - + size_t line_num = 0; //Read everything and processs - while(true) { + std::string prevSource; + + Node sourcePhrases; + sourcePhrases.done = true; + sourcePhrases.key = 0; + + while (true) { try { //Process line read line_text line; - line = splitLine(filein.ReadLine()); - //Add source phrases to vocabularyIDs - add_to_map(&source_vocabids, line.source_phrase); + line = splitLine(filein.ReadLine(), scfg); + //cerr << "line=" << line.source_phrase << endl; - if ((binfile.dist_from_start + binfile.extra_counter) == 0) { - prev_line = line; //For the first iteration assume the previous line is - } //The same as this one. 
+ ++line_num; + if (line_num % 1000000 == 0) { + std::cerr << line_num << " " << std::flush; + } - if (line.source_phrase != prev_line.source_phrase) { + //Add source phrases to vocabularyIDs + add_to_map(sourceVocab, line.source_phrase); + + if (prevSource.empty()) { + // 1st line + prevSource = line.source_phrase.as_string(); + storeTarget.Append(line, log_prob, scfg); + } else if (prevSource == line.source_phrase) { + //If we still have the same line, just append to it: + storeTarget.Append(line, log_prob, scfg); + } else { + assert(prevSource != line.source_phrase); //Create a new entry even + // save + uint64_t targetInd = storeTarget.Save(); + + // next line + storeTarget.Append(line, log_prob, scfg); + //Create an entry for the previous source phrase: - Entry pesho; - pesho.value = entrystartidx; + Entry sourceEntry; + sourceEntry.value = targetInd; //The key is the sum of hashes of individual words bitshifted by their position in the phrase. //Probably not entirerly correct, but fast and seems to work fine in practise. - pesho.key = 0; - std::vector<uint64_t> vocabid_source = getVocabIDs(prev_line.source_phrase); - for (int i = 0; i < vocabid_source.size(); i++) { - pesho.key += (vocabid_source[i] << i); + std::vector<uint64_t> vocabid_source = getVocabIDs(prevSource); + if (scfg) { + // storing prefixes? 
+ sourcePhrases.Add(sourceEntries, vocabid_source); } - pesho.bytes_toread = binfile.dist_from_start + binfile.extra_counter - entrystartidx; + sourceEntry.key = getKey(vocabid_source); + /* + cerr << "prevSource=" << prevSource << flush + << " vocabids=" << Debug(vocabid_source) << flush + << " key=" << sourceEntry.key << endl; + */ //Put into table - table.Insert(pesho); + sourceEntries.Insert(sourceEntry); - entrystartidx = binfile.dist_from_start + binfile.extra_counter; //Designate start idx for new entry + // update cache - CURRENT source phrase, not prev + if (max_cache_size) { + std::string countStr = line.counts.as_string(); + countStr = Trim(countStr); + if (!countStr.empty()) { + std::vector<float> toks = Tokenize<float>(countStr); + //cerr << "CACHE:" << line.source_phrase << " " << countStr << " " << toks[1] << endl; - //Encode a line and write it to disk. - std::vector<unsigned char> encoded_line = huffmanEncoder.full_encode_line(line); - binfile.write(&encoded_line); + if (toks.size() >= 2) { + totalSourceCount += toks[1]; - //Set prevLine - prev_line = line; + // compute key for CURRENT source + std::vector<uint64_t> currVocabidSource = getVocabIDs(line.source_phrase.as_string()); + uint64_t currKey = getKey(currVocabidSource); - } else { - //If we still have the same line, just append to it: - std::vector<unsigned char> encoded_line = huffmanEncoder.full_encode_line(line); - binfile.write(&encoded_line); + CacheItem *item = new CacheItem( + Trim(line.source_phrase.as_string()), + currKey, + toks[1]); + cache.push(item); + + if (max_cache_size > 0 && cache.size() > max_cache_size) { + cache.pop(); + } + } + } + } + + //Set prevLine + prevSource = line.source_phrase.as_string(); } } catch (util::EndOfFileException e) { - std::cerr << "Reading phrase table finished, writing remaining files to disk." << std::endl; - binfile.flush(); + std::cerr + << "Reading phrase table finished, writing remaining files to disk." 
+ << std::endl; //After the final entry is constructed we need to add it to the phrase_table //Create an entry for the previous source phrase: - Entry pesho; - pesho.value = entrystartidx; + uint64_t targetInd = storeTarget.Save(); + + Entry sourceEntry; + sourceEntry.value = targetInd; + //The key is the sum of hashes of individual words. Probably not entirerly correct, but fast - pesho.key = 0; - std::vector<uint64_t> vocabid_source = getVocabIDs(prev_line.source_phrase); - for (int i = 0; i < vocabid_source.size(); i++) { - pesho.key += (vocabid_source[i] << i); - } - pesho.bytes_toread = binfile.dist_from_start + binfile.extra_counter - entrystartidx; + std::vector<uint64_t> vocabid_source = getVocabIDs(prevSource); + sourceEntry.key = getKey(vocabid_source); + //Put into table - table.Insert(pesho); + sourceEntries.Insert(sourceEntry); break; } } - serialize_table(mem, size, (basepath + "/probing_hash.dat").c_str()); + sourcePhrases.Write(sourceEntries); + + storeTarget.SaveAlignment(); + + serialize_table(mem, size, (basepath + "/probing_hash.dat")); + + sourceVocab.Save(); - serialize_map(&source_vocabids, (basepath + "/source_vocabids").c_str()); + serialize_cache(cache, (basepath + "/cache"), totalSourceCount); delete[] mem; //Write configfile std::ofstream configfile; configfile.open((basepath + "/config").c_str()); - configfile << API_VERSION << '\n'; - configfile << uniq_entries << '\n'; - configfile << num_scores << '\n'; - configfile << is_reordering << '\n'; + configfile << "API_VERSION\t" << API_VERSION << '\n'; + configfile << "uniq_entries\t" << uniq_entries << '\n'; + configfile << "num_scores\t" << num_scores << '\n'; + configfile << "num_lex_scores\t" << num_lex_scores << '\n'; + configfile << "log_prob\t" << log_prob << '\n'; configfile.close(); } + +size_t countUniqueSource(const std::string &path) +{ + size_t ret = 0; + InputFileStream strme(path); + + std::string line, prevSource; + while (std::getline(strme, line)) { + 
std::vector<std::string> toks = TokenizeMultiCharSeparator(line, "|||"); + assert(toks.size() != 0); + + if (prevSource != toks[0]) { + prevSource = toks[0]; + ++ret; + } + } + + return ret; +} + +void serialize_cache( + std::priority_queue<CacheItem*, std::vector<CacheItem*>, CacheItemOrderer> &cache, + const std::string &path, float totalSourceCount) +{ + std::vector<const CacheItem*> vec(cache.size()); + + size_t ind = cache.size() - 1; + while (!cache.empty()) { + const CacheItem *item = cache.top(); + vec[ind] = item; + cache.pop(); + --ind; + } + + std::ofstream os(path.c_str()); + + os << totalSourceCount << std::endl; + for (size_t i = 0; i < vec.size(); ++i) { + const CacheItem *item = vec[i]; + os << item->count << "\t" << item->sourceKey << "\t" << item->source << std::endl; + delete item; + } + + os.close(); +} + +uint64_t getKey(const std::vector<uint64_t> &vocabid_source) +{ + return getKey(vocabid_source.data(), vocabid_source.size()); +} + +std::vector<uint64_t> CreatePrefix(const std::vector<uint64_t> &vocabid_source, size_t endPos) +{ + assert(endPos < vocabid_source.size()); + + std::vector<uint64_t> ret(endPos + 1); + for (size_t i = 0; i <= endPos; ++i) { + ret[i] = vocabid_source[i]; + } + return ret; +} + +} + |