diff options
Diffstat (limited to 'moses/TranslationModel/ProbingPT/quering.cpp')
-rw-r--r-- | moses/TranslationModel/ProbingPT/quering.cpp | 221 |
1 files changed, 80 insertions, 141 deletions
diff --git a/moses/TranslationModel/ProbingPT/quering.cpp b/moses/TranslationModel/ProbingPT/quering.cpp index bd1d61a1e..ef980ef06 100644 --- a/moses/TranslationModel/ProbingPT/quering.cpp +++ b/moses/TranslationModel/ProbingPT/quering.cpp @@ -1,73 +1,80 @@ #include "quering.hh" +#include "util/exception.hh" -unsigned char * read_binary_file(const char * filename, size_t filesize) -{ - //Get filesize - int fd; - unsigned char * map; - - fd = open(filename, O_RDONLY); - - if (fd == -1) { - perror("Error opening file for reading"); - exit(EXIT_FAILURE); - } - - map = (unsigned char *)mmap(0, filesize, PROT_READ, MAP_SHARED, fd, 0); - if (map == MAP_FAILED) { - close(fd); - perror("Error mmapping the file"); - exit(EXIT_FAILURE); - } +using namespace std; - return map; -} +namespace Moses +{ -QueryEngine::QueryEngine(const char * filepath) : decoder(filepath) +QueryEngine::QueryEngine(const char * filepath) { //Create filepaths std::string basepath(filepath); std::string path_to_hashtable = basepath + "/probing_hash.dat"; - std::string path_to_data_bin = basepath + "/binfile.dat"; std::string path_to_source_vocabid = basepath + "/source_vocabids"; + std::string alignPath = basepath + "/Alignments.dat"; ///Source phrase vocabids - read_map(&source_vocabids, path_to_source_vocabid.c_str()); + read_map(source_vocabids, path_to_source_vocabid.c_str()); - //Target phrase vocabIDs - vocabids = decoder.get_target_lookup_map(); + // alignments + read_alignments(alignPath); //Read config file + boost::unordered_map<std::string, std::string> keyValue; + + std::ifstream config((basepath + "/config").c_str()); std::string line; - std::ifstream config ((basepath + "/config").c_str()); + while (getline(config, line)) { + std::vector<std::string> toks = Tokenize(line, "\t"); + UTIL_THROW_IF2(toks.size() != 2, "Wrong config format:" << line); + keyValue[ toks[0] ] = toks[1]; + } + + bool found; //Check API version: - getline(config, line); - if (atoi(line.c_str()) != API_VERSION) { - std::cerr << "The ProbingPT API has changed, please rebinarize your phrase tables." << std::endl; + int version; + found = Get(keyValue, "API_VERSION", version); + if (!found) { + std::cerr << "Old or corrupted version of ProbingPT. Please rebinarize your phrase tables." << std::endl; + } + else if (version != API_VERSION) { + std::cerr << "The ProbingPT API has changed. " << version << "!=" + << API_VERSION << " Please rebinarize your phrase tables." << std::endl; exit(EXIT_FAILURE); } + //Get tablesize. - getline(config, line); - int tablesize = atoi(line.c_str()); + int tablesize; + found = Get(keyValue, "uniq_entries", tablesize); + if (!found) { + std::cerr << "uniq_entries not found" << std::endl; + exit(EXIT_FAILURE); + } + //Number of scores - getline(config, line); - num_scores = atoi(line.c_str()); - //do we have a reordering table - getline(config, line); - std::transform(line.begin(), line.end(), line.begin(), ::tolower); //Get the boolean in lowercase - is_reordering = false; - if (line == "true") { - is_reordering = true; - std::cerr << "WARNING. REORDERING TABLES NOT SUPPORTED YET." << std::endl; + found = Get(keyValue, "num_scores", num_scores); + if (!found) { + std::cerr << "num_scores not found" << std::endl; + exit(EXIT_FAILURE); } - config.close(); - //Mmap binary table - struct stat filestatus; - stat(path_to_data_bin.c_str(), &filestatus); - binary_filesize = filestatus.st_size; - binary_mmaped = read_binary_file(path_to_data_bin.c_str(), binary_filesize); + //How may scores from lex reordering models + found = Get(keyValue, "num_lex_scores", num_lex_scores); + if (!found) { + std::cerr << "num_lex_scores not found" << std::endl; + exit(EXIT_FAILURE); + } + + // have the scores been log() and FloorScore()? + found = Get(keyValue, "log_prob", logProb); + if (!found) { + std::cerr << "logProb not found" << std::endl; + exit(EXIT_FAILURE); + } + + config.close(); //Read hashtable table_filesize = Table::Size(tablesize, 1.2); @@ -81,118 +88,50 @@ QueryEngine::QueryEngine(const char * filepath) : decoder(filepath) QueryEngine::~QueryEngine() { //Clear mmap content from memory. - munmap(binary_mmaped, binary_filesize); munmap(mem, table_filesize); } -std::pair<bool, std::vector<target_text> > QueryEngine::query(std::vector<uint64_t> source_phrase) +uint64_t QueryEngine::getKey(uint64_t source_phrase[], size_t size) const { - bool found; - std::vector<target_text> translation_entries; - const Entry * entry; //TOO SLOW //uint64_t key = util::MurmurHashNative(&source_phrase[0], source_phrase.size()); - uint64_t key = 0; - for (int i = 0; i < source_phrase.size(); i++) { - key += (source_phrase[i] << i); - } - - - found = table.Find(key, entry); - - if (found) { - //The phrase that was searched for was found! We need to get the translation entries. - //We will read the largest entry in bytes and then filter the unnecesarry with functions - //from line_splitter - uint64_t initial_index = entry -> GetValue(); - unsigned int bytes_toread = entry -> bytes_toread; - - //ASK HIEU FOR MORE EFFICIENT WAY TO DO THIS! - std::vector<unsigned char> encoded_text; //Assign to the vector the relevant portion of the array. - encoded_text.reserve(bytes_toread); - for (int i = 0; i < bytes_toread; i++) { - encoded_text.push_back(binary_mmaped[i+initial_index]); - } - - //Get only the translation entries necessary - translation_entries = decoder.full_decode_line(encoded_text, num_scores); - - } - - std::pair<bool, std::vector<target_text> > output (found, translation_entries); - - return output; - + return getKey(source_phrase, size); } -std::pair<bool, std::vector<target_text> > QueryEngine::query(StringPiece source_phrase) +std::pair<bool, uint64_t> QueryEngine::query(uint64_t key) { - bool found; - std::vector<target_text> translation_entries; - const Entry * entry; - //Convert source frase to VID - std::vector<uint64_t> source_phrase_vid = getVocabIDs(source_phrase); - //TOO SLOW - //uint64_t key = util::MurmurHashNative(&source_phrase_vid[0], source_phrase_vid.size()); - uint64_t key = 0; - for (int i = 0; i < source_phrase_vid.size(); i++) { - key += (source_phrase_vid[i] << i); - } - - found = table.Find(key, entry); - - - if (found) { - //The phrase that was searched for was found! We need to get the translation entries. - //We will read the largest entry in bytes and then filter the unnecesarry with functions - //from line_splitter - uint64_t initial_index = entry -> GetValue(); - unsigned int bytes_toread = entry -> bytes_toread; - //At the end of the file we can't readd + largest_entry cause we get a segfault. - std::cerr << "Entry size is bytes is: " << bytes_toread << std::endl; - - //ASK HIEU FOR MORE EFFICIENT WAY TO DO THIS! - std::vector<unsigned char> encoded_text; //Assign to the vector the relevant portion of the array. - encoded_text.reserve(bytes_toread); - for (int i = 0; i < bytes_toread; i++) { - encoded_text.push_back(binary_mmaped[i+initial_index]); - } - - //Get only the translation entries necessary - translation_entries = decoder.full_decode_line(encoded_text, num_scores); + std::pair<bool, uint64_t> ret; + const Entry * entry; + ret.first = table.Find(key, entry); + if (ret.first) { + ret.second = entry->value; } - - std::pair<bool, std::vector<target_text> > output (found, translation_entries); - - return output; - + return ret; } -void QueryEngine::printTargetInfo(std::vector<target_text> target_phrases) +void QueryEngine::read_alignments(const std::string &alignPath) { - int entries = target_phrases.size(); + std::ifstream strm(alignPath.c_str()); - for (int i = 0; i<entries; i++) { - std::cout << "Entry " << i+1 << " of " << entries << ":" << std::endl; - //Print text - std::cout << getTargetWordsFromIDs(target_phrases[i].target_phrase, &vocabids) << "\t"; + string line; + while (getline(strm, line)) { + vector<string> toks = Tokenize(line, "\t "); + UTIL_THROW_IF2(toks.size() == 0, "Corrupt alignment file"); - //Print probabilities: - for (int j = 0; j<target_phrases[i].prob.size(); j++) { - std::cout << target_phrases[i].prob[j] << " "; + uint32_t alignInd = Scan<uint32_t>(toks[0]); + if (alignInd >= alignColl.size()) { + alignColl.resize(alignInd + 1); } - std::cout << "\t"; - - //Print word_all1 - for (int j = 0; j<target_phrases[i].word_all1.size(); j++) { - if (j%2 == 0) { - std::cout << (short)target_phrases[i].word_all1[j] << "-"; - } else { - std::cout << (short)target_phrases[i].word_all1[j] << " "; - } + + Alignments &aligns = alignColl[alignInd]; + for (size_t i = 1; i < toks.size(); ++i) { + size_t pos = Scan<size_t>(toks[i]); + aligns.push_back(pos); } - std::cout << std::endl; } } + +} + |