Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/moses-smt/mosesdecoder.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author: Rico Sennrich <rico.sennrich@gmx.ch> 2014-09-16 14:12:14 +0400
committer: Rico Sennrich <rico.sennrich@gmx.ch> 2014-09-22 13:49:20 +0400
commit: f40bb2c53c2dcd832bde9e987c921171e2d1e581 (patch)
tree: 235b3b3a82a36ef1a92f31a42623e6187869f882 /mert/HwcmScorer.cpp
parent: 2c66ae5e34fb1165a9e0d996305d33d8318fb1bc (diff)
HWCM for MERT
Diffstat (limited to 'mert/HwcmScorer.cpp')
-rw-r--r-- mert/HwcmScorer.cpp 165
1 file changed, 165 insertions, 0 deletions
diff --git a/mert/HwcmScorer.cpp b/mert/HwcmScorer.cpp
new file mode 100644
index 000000000..5e6adec52
--- /dev/null
+++ b/mert/HwcmScorer.cpp
@@ -0,0 +1,165 @@
+#include "HwcmScorer.h"
+
+#include <fstream>
+
+#include "ScoreStats.h"
+#include "Util.h"
+
+#include "util/tokenize_piece.hh"
+
+// HWCM score (Liu and Gildea, 2005). Implements F1 instead of precision for better modelling of hypothesis length.
+// Assumes dependency trees on the target side (generated by scripts/training/wrappers/conll2mosesxml.py; use the option --brackets for the reference).
+// Reads reference trees from a separate file, {REFERENCE_FILE}.trees, to support a mix of string-based and tree-based metrics.
+
+using namespace std;
+
+namespace MosesTuning
+{
+
+
+// Constructs the scorer by forwarding the fixed scorer name "HWCM" and the
+// caller-supplied configuration string to the StatisticsBasedScorer base.
+HwcmScorer::HwcmScorer(const string& config)
+  : StatisticsBasedScorer("HWCM", config)
+{
+}
+
+// Destructor: performs no explicit cleanup of its own.
+HwcmScorer::~HwcmScorer()
+{
+}
+
+// Loads the reference trees and caches their head-word-chain statistics.
+//
+// referenceFiles must hold exactly one path. The trees are read from that path
+// with ".trees" appended, one tree per line. For each line, the parsed tree,
+// its per-order head-word-chain counts and the per-order sum of those counts
+// are appended to m_ref_trees, m_ref_hwc and m_ref_lengths respectively, so
+// the three containers are indexed by the same sentence id.
+//
+// Throws runtime_error if the number of references is not one, or if the
+// .trees file cannot be opened.
+void HwcmScorer::setReferenceFiles(const vector<string>& referenceFiles)
+{
+  if (referenceFiles.size() != 1) {
+    throw runtime_error("HWCM only supports a single reference");
+  }
+
+  // Reset all three per-sentence caches together. Leaving m_ref_lengths
+  // populated would make a repeated call append behind stale entries, so
+  // m_ref_lengths[sid] would no longer correspond to m_ref_hwc[sid].
+  m_ref_trees.clear();
+  m_ref_hwc.clear();
+  m_ref_lengths.clear();
+
+  ifstream in((referenceFiles[0] + ".trees").c_str());
+  if (!in) {
+    throw runtime_error("Unable to open " + referenceFiles[0] + ".trees");
+  }
+
+  string line;
+  while (getline(in,line)) {
+    line = this->preprocessSentence(line);
+    TreePointer tree (boost::make_shared<InternalTree>(line));
+    m_ref_trees.push_back(tree);
+
+    // One count map per chain order; history starts out as all-empty strings.
+    vector<map<string, int> > hwc (kHwcmOrder);
+    vector<string> history(kHwcmOrder);
+    extractHeadWordChain(tree, history, hwc);
+    m_ref_hwc.push_back(hwc);
+
+    // Total number of chains of each order in this reference (used as the
+    // recall denominator when scoring).
+    vector<int> totals(kHwcmOrder);
+    for (size_t i = 0; i < kHwcmOrder; i++) {
+      for (map<string, int>::const_iterator it = hwc[i].begin(); it != hwc[i].end(); ++it) {
+        totals[i] += it->second;
+      }
+    }
+    m_ref_lengths.push_back(totals);
+  }
+  TRACE_ERR(endl);
+
+}
+
+// Recursively collects head-word chains from `tree` into `hwc`.
+//
+// hwc[k] counts chains of k+1 head words; a chain is stored as its words
+// joined by single spaces, outermost head first. history[k] is the chain of
+// k+1 heads that ends at the nearest enclosing node with a head (empty string
+// if there is no such chain). Nodes for which getHead() returns an empty
+// string contribute nothing and hand `history` down to their children
+// unchanged.
+void HwcmScorer::extractHeadWordChain(TreePointer tree, vector<string> & history, vector<map<string, int> > & hwc)
+{
+  if (!(tree->GetLength() > 0)) {
+    return;
+  }
+
+  const std::vector<TreePointer> & kids = tree->GetChildren();
+  const string head = getHead(tree);
+
+  if (head.empty()) {
+    // No head at this node: descend with the inherited history.
+    for (size_t k = 0; k < kids.size(); ++k) {
+      extractHeadWordChain(kids[k], history, hwc);
+    }
+    return;
+  }
+
+  // The unigram chain is just the head itself.
+  vector<string> extended(kHwcmOrder);
+  extended[0] = head;
+  ++hwc[0][head];
+
+  // Extend each inherited chain of `order` words by this head.
+  for (size_t order = 1; order < kHwcmOrder; ++order) {
+    const string & prefix = history[order-1];
+    if (prefix.empty()) {
+      continue;
+    }
+    const string chain = prefix + " " + head;
+    ++hwc[order][chain];
+    // A chain of maximal length is never extended further, so it is not kept.
+    if (order+1 < kHwcmOrder) {
+      extended[order] = chain;
+    }
+  }
+
+  for (size_t k = 0; k < kids.size(); ++k) {
+    extractHeadWordChain(kids[k], extended, hwc);
+  }
+}
+
+// Returns the head word of `tree`, or an empty string if it has none.
+//
+// Assumption (only true for dependency parses): each constituent has a
+// preterminal child, and the terminal below that preterminal is the head.
+// If a constituent has several preterminals, the first one is picked; if it
+// has none, the empty string is returned.
+string HwcmScorer::getHead(TreePointer tree)
+{
+  const std::vector<TreePointer> & kids = tree->GetChildren();
+  for (size_t k = 0; k < kids.size(); ++k) {
+    const TreePointer & candidate = kids[k];
+
+    // A preterminal has exactly one child, and that child is a terminal.
+    if (candidate->GetLength() != 1) {
+      continue;
+    }
+    if (!candidate->GetChildren()[0]->IsTerminal()) {
+      continue;
+    }
+    return candidate->GetChildren()[0]->GetLabel();
+  }
+  return string();
+}
+
+// Computes the sufficient statistics of one hypothesis against reference `sid`.
+//
+// For each chain order, three integers are written to `entry`, separated (and
+// followed) by a single space: the clipped number of matching chains, the
+// number of chains in the hypothesis, and the number of chains in the
+// reference.
+//
+// Throws runtime_error if `sid` is not a valid index into the loaded references.
+void HwcmScorer::prepareStats(size_t sid, const string& text, ScoreStats& entry)
+{
+  if (sid >= m_ref_trees.size()) {
+    stringstream error;
+    error << "Sentence id (" << sid << ") not found in reference set";
+    throw runtime_error(error.str());
+  }
+
+  string sentence = this->preprocessSentence(text);
+
+  // If the sentence contains '|||', the tree is taken from the second field
+  // (n-best-list format); otherwise the whole sentence is treated as the tree
+  // (calling 'evaluate' with a tree as reference).
+  util::TokenIter<util::MultiCharacter> field(sentence, util::MultiCharacter("|||"));
+  ++field;
+  if (field) {
+    sentence = field->as_string();
+  }
+
+  TreePointer tree (boost::make_shared<InternalTree>(sentence));
+  vector<map<string, int> > hyp_hwc (kHwcmOrder);
+  vector<string> history(kHwcmOrder);
+  extractHeadWordChain(tree, history, hyp_hwc);
+
+  ostringstream stats;
+  for (size_t order = 0; order < kHwcmOrder; ++order) {
+    const map<string, int> & ref_counts = m_ref_hwc[sid][order];
+    int matched = 0;
+    int hyp_total = 0;
+    for (map<string, int>::const_iterator hyp = hyp_hwc[order].begin(); hyp != hyp_hwc[order].end(); ++hyp) {
+      hyp_total += hyp->second;
+      const map<string, int>::const_iterator ref = ref_counts.find(hyp->first);
+      if (ref == ref_counts.end()) {
+        continue;
+      }
+      // Clip: a chain cannot match more often than it occurs in the reference.
+      matched += std::min(hyp->second, ref->second);
+    }
+    stats << matched << " " << hyp_total << " " << m_ref_lengths[sid][order] << " " ;
+  }
+
+  string serialized = stats.str();
+  entry.set(serialized);
+}
+
+// Computes the HWCM F1 score from accumulated statistics.
+//
+// `comps` holds, for each chain order i, three consecutive values:
+// comps[3*i] matches, comps[3*i+1] hypothesis total, comps[3*i+2] reference
+// total. Precision and recall are each averaged over all kHwcmOrder orders
+// (an order with a zero total contributes 0), and their harmonic mean is
+// returned. Returns 0 when both averages are 0, i.e. when nothing matched.
+float HwcmScorer::calculateScore(const vector<int>& comps) const
+{
+  float precision = 0;
+  float recall = 0;
+  for (size_t i = 0; i < kHwcmOrder; i++) {
+    float matches = comps[i*3];
+    float test_total = comps[1+(i*3)];
+    float ref_total = comps[2+(i*3)];
+    if (test_total > 0) {
+      precision += matches/test_total;
+    }
+    if (ref_total > 0) {
+      recall += matches/ref_total;
+    }
+  }
+
+  precision /= (float)kHwcmOrder;
+  recall /= (float)kHwcmOrder;
+
+  // Without this guard, a hypothesis with no matching chains would evaluate
+  // 0/0 and return NaN instead of the worst score.
+  const float denominator = precision + recall;
+  if (denominator == 0) {
+    return 0;
+  }
+  return (2*precision*recall)/denominator; // f1-score
+}
+
+} \ No newline at end of file