Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/moses-smt/mosesdecoder.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
Diffstat (limited to 'moses/TranslationModel/ProbingPT/quering.cpp')
-rw-r--r-- moses/TranslationModel/ProbingPT/quering.cpp | 221
1 file changed, 80 insertions, 141 deletions
diff --git a/moses/TranslationModel/ProbingPT/quering.cpp b/moses/TranslationModel/ProbingPT/quering.cpp
index bd1d61a1e..ef980ef06 100644
--- a/moses/TranslationModel/ProbingPT/quering.cpp
+++ b/moses/TranslationModel/ProbingPT/quering.cpp
@@ -1,73 +1,80 @@
#include "quering.hh"
+#include "util/exception.hh"
-unsigned char * read_binary_file(const char * filename, size_t filesize)
-{
- //Get filesize
- int fd;
- unsigned char * map;
-
- fd = open(filename, O_RDONLY);
-
- if (fd == -1) {
- perror("Error opening file for reading");
- exit(EXIT_FAILURE);
- }
-
- map = (unsigned char *)mmap(0, filesize, PROT_READ, MAP_SHARED, fd, 0);
- if (map == MAP_FAILED) {
- close(fd);
- perror("Error mmapping the file");
- exit(EXIT_FAILURE);
- }
+using namespace std;
- return map;
-}
+namespace Moses
+{
-QueryEngine::QueryEngine(const char * filepath) : decoder(filepath)
+QueryEngine::QueryEngine(const char * filepath)
{
//Create filepaths
std::string basepath(filepath);
std::string path_to_hashtable = basepath + "/probing_hash.dat";
- std::string path_to_data_bin = basepath + "/binfile.dat";
std::string path_to_source_vocabid = basepath + "/source_vocabids";
+ std::string alignPath = basepath + "/Alignments.dat";
///Source phrase vocabids
- read_map(&source_vocabids, path_to_source_vocabid.c_str());
+ read_map(source_vocabids, path_to_source_vocabid.c_str());
- //Target phrase vocabIDs
- vocabids = decoder.get_target_lookup_map();
+ // alignments
+ read_alignments(alignPath);
//Read config file
+ boost::unordered_map<std::string, std::string> keyValue;
+
+ std::ifstream config((basepath + "/config").c_str());
std::string line;
- std::ifstream config ((basepath + "/config").c_str());
+ while (getline(config, line)) {
+ std::vector<std::string> toks = Tokenize(line, "\t");
+ UTIL_THROW_IF2(toks.size() != 2, "Wrong config format:" << line);
+ keyValue[ toks[0] ] = toks[1];
+ }
+
+ bool found;
//Check API version:
- getline(config, line);
- if (atoi(line.c_str()) != API_VERSION) {
- std::cerr << "The ProbingPT API has changed, please rebinarize your phrase tables." << std::endl;
+ int version;
+ found = Get(keyValue, "API_VERSION", version);
+ if (!found) {
+ std::cerr << "Old or corrupted version of ProbingPT. Please rebinarize your phrase tables." << std::endl;
+ }
+ else if (version != API_VERSION) {
+ std::cerr << "The ProbingPT API has changed. " << version << "!="
+ << API_VERSION << " Please rebinarize your phrase tables." << std::endl;
exit(EXIT_FAILURE);
}
+
//Get tablesize.
- getline(config, line);
- int tablesize = atoi(line.c_str());
+ int tablesize;
+ found = Get(keyValue, "uniq_entries", tablesize);
+ if (!found) {
+ std::cerr << "uniq_entries not found" << std::endl;
+ exit(EXIT_FAILURE);
+ }
+
//Number of scores
- getline(config, line);
- num_scores = atoi(line.c_str());
- //do we have a reordering table
- getline(config, line);
- std::transform(line.begin(), line.end(), line.begin(), ::tolower); //Get the boolean in lowercase
- is_reordering = false;
- if (line == "true") {
- is_reordering = true;
- std::cerr << "WARNING. REORDERING TABLES NOT SUPPORTED YET." << std::endl;
+ found = Get(keyValue, "num_scores", num_scores);
+ if (!found) {
+ std::cerr << "num_scores not found" << std::endl;
+ exit(EXIT_FAILURE);
}
- config.close();
- //Mmap binary table
- struct stat filestatus;
- stat(path_to_data_bin.c_str(), &filestatus);
- binary_filesize = filestatus.st_size;
- binary_mmaped = read_binary_file(path_to_data_bin.c_str(), binary_filesize);
+ //How may scores from lex reordering models
+ found = Get(keyValue, "num_lex_scores", num_lex_scores);
+ if (!found) {
+ std::cerr << "num_lex_scores not found" << std::endl;
+ exit(EXIT_FAILURE);
+ }
+
+ // have the scores been log() and FloorScore()?
+ found = Get(keyValue, "log_prob", logProb);
+ if (!found) {
+ std::cerr << "logProb not found" << std::endl;
+ exit(EXIT_FAILURE);
+ }
+
+ config.close();
//Read hashtable
table_filesize = Table::Size(tablesize, 1.2);
@@ -81,118 +88,50 @@ QueryEngine::QueryEngine(const char * filepath) : decoder(filepath)
QueryEngine::~QueryEngine()
{
//Clear mmap content from memory.
- munmap(binary_mmaped, binary_filesize);
munmap(mem, table_filesize);
}
-std::pair<bool, std::vector<target_text> > QueryEngine::query(std::vector<uint64_t> source_phrase)
+uint64_t QueryEngine::getKey(uint64_t source_phrase[], size_t size) const
{
- bool found;
- std::vector<target_text> translation_entries;
- const Entry * entry;
//TOO SLOW
//uint64_t key = util::MurmurHashNative(&source_phrase[0], source_phrase.size());
- uint64_t key = 0;
- for (int i = 0; i < source_phrase.size(); i++) {
- key += (source_phrase[i] << i);
- }
-
-
- found = table.Find(key, entry);
-
- if (found) {
- //The phrase that was searched for was found! We need to get the translation entries.
- //We will read the largest entry in bytes and then filter the unnecesarry with functions
- //from line_splitter
- uint64_t initial_index = entry -> GetValue();
- unsigned int bytes_toread = entry -> bytes_toread;
-
- //ASK HIEU FOR MORE EFFICIENT WAY TO DO THIS!
- std::vector<unsigned char> encoded_text; //Assign to the vector the relevant portion of the array.
- encoded_text.reserve(bytes_toread);
- for (int i = 0; i < bytes_toread; i++) {
- encoded_text.push_back(binary_mmaped[i+initial_index]);
- }
-
- //Get only the translation entries necessary
- translation_entries = decoder.full_decode_line(encoded_text, num_scores);
-
- }
-
- std::pair<bool, std::vector<target_text> > output (found, translation_entries);
-
- return output;
-
+ return getKey(source_phrase, size);
}
-std::pair<bool, std::vector<target_text> > QueryEngine::query(StringPiece source_phrase)
+std::pair<bool, uint64_t> QueryEngine::query(uint64_t key)
{
- bool found;
- std::vector<target_text> translation_entries;
- const Entry * entry;
- //Convert source frase to VID
- std::vector<uint64_t> source_phrase_vid = getVocabIDs(source_phrase);
- //TOO SLOW
- //uint64_t key = util::MurmurHashNative(&source_phrase_vid[0], source_phrase_vid.size());
- uint64_t key = 0;
- for (int i = 0; i < source_phrase_vid.size(); i++) {
- key += (source_phrase_vid[i] << i);
- }
-
- found = table.Find(key, entry);
-
-
- if (found) {
- //The phrase that was searched for was found! We need to get the translation entries.
- //We will read the largest entry in bytes and then filter the unnecesarry with functions
- //from line_splitter
- uint64_t initial_index = entry -> GetValue();
- unsigned int bytes_toread = entry -> bytes_toread;
- //At the end of the file we can't readd + largest_entry cause we get a segfault.
- std::cerr << "Entry size is bytes is: " << bytes_toread << std::endl;
-
- //ASK HIEU FOR MORE EFFICIENT WAY TO DO THIS!
- std::vector<unsigned char> encoded_text; //Assign to the vector the relevant portion of the array.
- encoded_text.reserve(bytes_toread);
- for (int i = 0; i < bytes_toread; i++) {
- encoded_text.push_back(binary_mmaped[i+initial_index]);
- }
-
- //Get only the translation entries necessary
- translation_entries = decoder.full_decode_line(encoded_text, num_scores);
+ std::pair<bool, uint64_t> ret;
+ const Entry * entry;
+ ret.first = table.Find(key, entry);
+ if (ret.first) {
+ ret.second = entry->value;
}
-
- std::pair<bool, std::vector<target_text> > output (found, translation_entries);
-
- return output;
-
+ return ret;
}
-void QueryEngine::printTargetInfo(std::vector<target_text> target_phrases)
+void QueryEngine::read_alignments(const std::string &alignPath)
{
- int entries = target_phrases.size();
+ std::ifstream strm(alignPath.c_str());
- for (int i = 0; i<entries; i++) {
- std::cout << "Entry " << i+1 << " of " << entries << ":" << std::endl;
- //Print text
- std::cout << getTargetWordsFromIDs(target_phrases[i].target_phrase, &vocabids) << "\t";
+ string line;
+ while (getline(strm, line)) {
+ vector<string> toks = Tokenize(line, "\t ");
+ UTIL_THROW_IF2(toks.size() == 0, "Corrupt alignment file");
- //Print probabilities:
- for (int j = 0; j<target_phrases[i].prob.size(); j++) {
- std::cout << target_phrases[i].prob[j] << " ";
+ uint32_t alignInd = Scan<uint32_t>(toks[0]);
+ if (alignInd >= alignColl.size()) {
+ alignColl.resize(alignInd + 1);
}
- std::cout << "\t";
-
- //Print word_all1
- for (int j = 0; j<target_phrases[i].word_all1.size(); j++) {
- if (j%2 == 0) {
- std::cout << (short)target_phrases[i].word_all1[j] << "-";
- } else {
- std::cout << (short)target_phrases[i].word_all1[j] << " ";
- }
+
+ Alignments &aligns = alignColl[alignInd];
+ for (size_t i = 1; i < toks.size(); ++i) {
+ size_t pos = Scan<size_t>(toks[i]);
+ aligns.push_back(pos);
}
- std::cout << std::endl;
}
}
+
+}
+