Welcome to mirror list, hosted at ThFree Co, Russian Federation.

HwcmScorer.cpp « mert - github.com/moses-smt/mosesdecoder.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: bb3cd4382c11e5ecd8fa2c39118b6b2043baed14 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
#include "HwcmScorer.h"

#include <fstream>

#include "ScoreStats.h"
#include "Util.h"

#include "util/tokenize_piece.hh"

// HWCM score (Liu and Gildea, 2005). Implements F1 instead of precision for better modelling of hypothesis length.
// assumes dependency trees on target side (generated by scripts/training/wrappers/conll2mosesxml.py ; use with option --brackets for reference).
// reads reference trees from separate file {REFERENCE_FILE}.trees to support mix of string-based and tree-based metrics.

using namespace std;

namespace MosesTuning
{


// Constructs the scorer, forwarding the scorer name "HWCM" and the given
// config string to the StatisticsBasedScorer base class.
HwcmScorer::HwcmScorer(const string& config)
  : StatisticsBasedScorer("HWCM",config) {}

// Destructor; the body is empty, so no explicit cleanup is performed here.
HwcmScorer::~HwcmScorer() {}

void HwcmScorer::setReferenceFiles(const vector<string>& referenceFiles)
{
  // For each line in the reference file, create a tree object
  if (referenceFiles.size() != 1) {
    throw runtime_error("HWCM only supports a single reference");
  }
  m_ref_trees.clear();
  m_ref_hwc.clear();
  ifstream in((referenceFiles[0] + ".trees").c_str());
  if (!in) {
    throw runtime_error("Unable to open " + referenceFiles[0] + ".trees");
  }
  string line;
  while (getline(in,line)) {
    line = this->preprocessSentence(line);
    TreePointer tree (boost::make_shared<InternalTree>(line));
    m_ref_trees.push_back(tree);
    vector<map<string, int> > hwc (kHwcmOrder);
    vector<string> history(kHwcmOrder);
    extractHeadWordChain(tree, history, hwc);
    m_ref_hwc.push_back(hwc);
    vector<int> totals(kHwcmOrder);
    for (size_t i = 0; i < kHwcmOrder; i++) {
      for (map<string, int>::const_iterator it = m_ref_hwc.back()[i].begin(); it != m_ref_hwc.back()[i].end(); it++) {
        totals[i] += it->second;
      }
    }
    m_ref_lengths.push_back(totals);
  }
  TRACE_ERR(endl);

}

// Recursively collects the head-word chains of the subtree rooted at `tree`.
//
//   tree    - current node; a node whose GetLength() is not positive is ignored.
//   history - history[k] is the chain of k+1 head words ending at the nearest
//             ancestor that had a head, or empty if no such chain exists.
//             It is only read here.
//   hwc     - output; hwc[k] maps each chain of k+1 space-separated head
//             words to the number of times it was seen.
//
// A node for which getHead() returns "" is transparent: its children are
// visited with the unchanged history.
void HwcmScorer::extractHeadWordChain(TreePointer tree, vector<string> & history, vector<map<string, int> > & hwc)
{
  if (!(tree->GetLength() > 0)) {
    return;
  }

  const std::vector<TreePointer> & children = tree->GetChildren();
  const string head = getHead(tree);

  // No head at this node: descend without extending any chain.
  if (head.empty()) {
    for (size_t c = 0; c < children.size(); ++c) {
      extractHeadWordChain(children[c], history, hwc);
    }
    return;
  }

  // The head alone is a chain of length one.
  vector<string> extended(kHwcmOrder);
  extended[0] = head;
  ++hwc[0][head];

  // Append the head to every ancestor chain, giving chains of `order`+1 words.
  for (size_t order = 1; order < kHwcmOrder; ++order) {
    const string & ancestors = history[order-1];
    if (ancestors.empty()) {
      continue;
    }
    const string chain = ancestors + " " + head;
    ++hwc[order][chain];
    // Chains already at maximum length are counted but not handed down.
    if (order + 1 < kHwcmOrder) {
      extended[order] = chain;
    }
  }

  for (size_t c = 0; c < children.size(); ++c) {
    extractHeadWordChain(children[c], extended, hwc);
  }
}

// Returns the head word of the constituent rooted at `tree`.
//
// Assumption (only true for dependency parses): each constituent has a
// preterminal child, and the terminal below that preterminal is the head.
// A child counts as a preterminal when its GetLength() is 1 and its first
// child is a terminal. If there are several preterminals the first one wins;
// if there is none, the empty string is returned.
string HwcmScorer::getHead(TreePointer tree)
{
  const std::vector<TreePointer> & children = tree->GetChildren();

  for (size_t c = 0; c < children.size(); ++c) {
    const TreePointer & candidate = children[c];
    if (candidate->GetLength() != 1) {
      continue;
    }

    const TreePointer & word = candidate->GetChildren()[0];
    if (word->IsTerminal()) {
      return word->GetLabel();
    }
  }

  return "";
}

void HwcmScorer::prepareStats(size_t sid, const string& text, ScoreStats& entry)
{
  if (sid >= m_ref_trees.size()) {
    stringstream msg;
    msg << "Sentence id (" << sid << ") not found in reference set";
    throw runtime_error(msg.str());
  }

  string sentence = this->preprocessSentence(text);

  // if sentence has '|||', assume that tree is in second position (n-best-list);
  // otherwise, assume it is in first position (calling 'evaluate' with tree as reference)
  util::TokenIter<util::MultiCharacter> it(sentence, util::MultiCharacter("|||"));
  ++it;
  if (it) {
    sentence = it->as_string();
  }

  TreePointer tree (boost::make_shared<InternalTree>(sentence));
  vector<map<string, int> > hwc_test (kHwcmOrder);
  vector<string> history(kHwcmOrder);
  extractHeadWordChain(tree, history, hwc_test);

  ostringstream stats;
  for (size_t i = 0; i < kHwcmOrder; i++) {
    int correct = 0;
    int test_total = 0;
    for (map<string, int>::const_iterator it = hwc_test[i].begin(); it != hwc_test[i].end(); it++) {
      test_total += it->second;
      map<string, int>::const_iterator it2 = m_ref_hwc[sid][i].find(it->first);
      if (it2 != m_ref_hwc[sid][i].end()) {
        correct += std::min(it->second, it2->second);
      }
    }
    stats << correct << " " << test_total << " " << m_ref_lengths[sid][i] << " " ;
  }

  string stats_str = stats.str();
  entry.set(stats_str);
}

// Computes the HWCM F1 score from aggregated statistics.
//
// comps holds three values per order i (0 <= i < kHwcmOrder):
//   comps[3*i]   - matched chains,
//   comps[3*i+1] - chains in the hypothesis,
//   comps[3*i+2] - chains in the reference.
//
// Precision and recall are each averaged over all orders; an order whose
// total is 0 contributes 0 to the respective average. The result is the
// harmonic mean of the two averages, or 0 when both averages are 0.
float HwcmScorer::calculateScore(const vector<ScoreStatsType>& comps) const
{
  float precision = 0;
  float recall = 0;
  for (size_t i = 0; i < kHwcmOrder; i++) {
    const float matches = comps[i*3];
    const float test_total = comps[1+(i*3)];
    const float ref_total = comps[2+(i*3)];
    if (test_total > 0) {
      precision += matches/test_total;
    }
    if (ref_total > 0) {
      recall += matches/ref_total;
    }
  }

  precision /= static_cast<float>(kHwcmOrder);
  recall /= static_cast<float>(kHwcmOrder);

  // With no matches at all, precision + recall is exactly 0 and the F1
  // formula would evaluate 0/0 (NaN); report a score of 0 instead.
  const float denominator = precision + recall;
  if (denominator == 0.0f) {
    return 0.0f;
  }
  return (2*precision*recall)/denominator; // f1-score
}

}