diff options
author | Rico Sennrich <rico.sennrich@gmx.ch> | 2014-09-16 14:12:14 +0400 |
---|---|---|
committer | Rico Sennrich <rico.sennrich@gmx.ch> | 2014-09-22 13:49:20 +0400 |
commit | f40bb2c53c2dcd832bde9e987c921171e2d1e581 (patch) | |
tree | 235b3b3a82a36ef1a92f31a42623e6187869f882 /mert/HwcmScorer.cpp | |
parent | 2c66ae5e34fb1165a9e0d996305d33d8318fb1bc (diff) |
HWCM for MERT
Diffstat (limited to 'mert/HwcmScorer.cpp')
-rw-r--r-- | mert/HwcmScorer.cpp | 165 |
1 files changed, 165 insertions, 0 deletions
diff --git a/mert/HwcmScorer.cpp b/mert/HwcmScorer.cpp new file mode 100644 index 000000000..5e6adec52 --- /dev/null +++ b/mert/HwcmScorer.cpp @@ -0,0 +1,165 @@ +#include "HwcmScorer.h" + +#include <fstream> + +#include "ScoreStats.h" +#include "Util.h" + +#include "util/tokenize_piece.hh" + +// HWCM score (Liu and Gildea, 2005). Implements F1 instead of precision for better modelling of hypothesis length. +// assumes dependency trees on target side (generated by scripts/training/wrappers/conll2mosesxml.py ; use with option --brackets for reference). +// reads reference trees from separate file {REFERENCE_FILE}.trees to support mix of string-based and tree-based metrics. + +using namespace std; + +namespace MosesTuning +{ + + +HwcmScorer::HwcmScorer(const string& config) + : StatisticsBasedScorer("HWCM",config) {} + +HwcmScorer::~HwcmScorer() {} + +void HwcmScorer::setReferenceFiles(const vector<string>& referenceFiles) +{ + // For each line in the reference file, create a tree object + if (referenceFiles.size() != 1) { + throw runtime_error("HWCM only supports a single reference"); + } + m_ref_trees.clear(); + m_ref_hwc.clear(); + ifstream in((referenceFiles[0] + ".trees").c_str()); + if (!in) { + throw runtime_error("Unable to open " + referenceFiles[0] + ".trees"); + } + string line; + while (getline(in,line)) { + line = this->preprocessSentence(line); + TreePointer tree (boost::make_shared<InternalTree>(line)); + m_ref_trees.push_back(tree); + vector<map<string, int> > hwc (kHwcmOrder); + vector<string> history(kHwcmOrder); + extractHeadWordChain(tree, history, hwc); + m_ref_hwc.push_back(hwc); + vector<int> totals(kHwcmOrder); + for (size_t i = 0; i < kHwcmOrder; i++) { + for (map<string, int>::const_iterator it = m_ref_hwc.back()[i].begin(); it != m_ref_hwc.back()[i].end(); it++) { + totals[i] += it->second; + } + } + m_ref_lengths.push_back(totals); + } + TRACE_ERR(endl); + +} + +void HwcmScorer::extractHeadWordChain(TreePointer tree, vector<string> & history, vector<map<string, int> > & hwc) { + + if (tree->GetLength() > 0) { + string head = getHead(tree); + + if (head.empty()) { + for (std::vector<TreePointer>::const_iterator it = tree->GetChildren().begin(); it != tree->GetChildren().end(); ++it) { + extractHeadWordChain(*it, history, hwc); + } + } + else { + vector<string> new_history(kHwcmOrder); + new_history[0] = head; + hwc[0][head]++; + for (size_t hist_idx = 0; hist_idx < kHwcmOrder-1; hist_idx++) { + if (!history[hist_idx].empty()) { + string chain = history[hist_idx] + " " + head; + hwc[hist_idx+1][chain]++; + if (hist_idx+2 < kHwcmOrder) { + new_history[hist_idx+1] = chain; + } + } + } + for (std::vector<TreePointer>::const_iterator it = tree->GetChildren().begin(); it != tree->GetChildren().end(); ++it) { + extractHeadWordChain(*it, new_history, hwc); + } + } + } +} + +string HwcmScorer::getHead(TreePointer tree) { + // assumption (only true for dependency parse: each constituent has a preterminal label, and corresponding terminal is head) + // if constituent has multiple preterminals, first one is picked; if it has no preterminals, empty string is returned + for (std::vector<TreePointer>::const_iterator it = tree->GetChildren().begin(); it != tree->GetChildren().end(); ++it) + { + TreePointer child = *it; + + if (child->GetLength() == 1 && child->GetChildren()[0]->IsTerminal()) { + return child->GetChildren()[0]->GetLabel(); + } + } + return ""; + +} + +void HwcmScorer::prepareStats(size_t sid, const string& text, ScoreStats& entry) +{ + if (sid >= m_ref_trees.size()) { + stringstream msg; + msg << "Sentence id (" << sid << ") not found in reference set"; + throw runtime_error(msg.str()); + } + + string sentence = this->preprocessSentence(text); + + // if sentence has '|||', assume that tree is in second position (n-best-list); + // otherwise, assume it is in first position (calling 'evaluate' with tree as reference) + util::TokenIter<util::MultiCharacter> it(sentence, util::MultiCharacter("|||")); + ++it; + if (it) { + sentence = it->as_string(); + } + + TreePointer tree (boost::make_shared<InternalTree>(sentence)); + vector<map<string, int> > hwc_test (kHwcmOrder); + vector<string> history(kHwcmOrder); + extractHeadWordChain(tree, history, hwc_test); + + ostringstream stats; + for (size_t i = 0; i < kHwcmOrder; i++) { + int correct = 0; + int test_total = 0; + for (map<string, int>::const_iterator it = hwc_test[i].begin(); it != hwc_test[i].end(); it++) { + test_total += it->second; + map<string, int>::const_iterator it2 = m_ref_hwc[sid][i].find(it->first); + if (it2 != m_ref_hwc[sid][i].end()) { + correct += std::min(it->second, it2->second); + } + } + stats << correct << " " << test_total << " " << m_ref_lengths[sid][i] << " " ; + } + + string stats_str = stats.str(); + entry.set(stats_str); +} + +float HwcmScorer::calculateScore(const vector<int>& comps) const +{ + float precision = 0; + float recall = 0; + for (size_t i = 0; i < kHwcmOrder; i++) { + float matches = comps[i*3]; + float test_total = comps[1+(i*3)]; + float ref_total = comps[2+(i*3)]; + if (test_total > 0) { + precision += matches/test_total; + } + if (ref_total > 0) { + recall += matches/ref_total; + } + } + + precision /= (float)kHwcmOrder; + recall /= (float)kHwcmOrder; + return (2*precision*recall)/(precision+recall); // f1-score +} + +}
\ No newline at end of file |