Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/moses-smt/mosesdecoder.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTetsuo Kiso <tetsuo-s@is.naist.jp>2012-03-18 00:58:40 +0400
committerTetsuo Kiso <tetsuo-s@is.naist.jp>2012-03-18 00:58:40 +0400
commit6b95a19eda818fb772767a0037c70a7bbb6c32e5 (patch)
treee1b7d608005bcc33ee00646263583e83d03a53a2
parent918bcafb808fe3067a4d689607bffb7dbbf0a914 (diff)
Create Reference class to clean up BleuScorer.
- Add a unit test for Reference.
- Move functions to calculate the reference length from BleuScorer to Reference.
-rw-r--r--mert/BleuScorer.cpp99
-rw-r--r--mert/BleuScorer.h16
-rw-r--r--mert/Jamfile9
-rw-r--r--mert/Reference.h78
-rw-r--r--mert/ReferenceTest.cpp116
5 files changed, 236 insertions, 82 deletions
diff --git a/mert/BleuScorer.cpp b/mert/BleuScorer.cpp
index 4063d9acf..f143df66b 100644
--- a/mert/BleuScorer.cpp
+++ b/mert/BleuScorer.cpp
@@ -7,6 +7,7 @@
#include <iostream>
#include <stdexcept>
#include "Ngram.h"
+#include "Reference.h"
#include "Util.h"
namespace {
@@ -19,7 +20,6 @@ const char REFLEN_CLOSEST[] = "closest";
} // namespace
-
BleuScorer::BleuScorer(const string& config)
: StatisticsBasedScorer("BLEU", config),
m_ref_length_type(CLOSEST) {
@@ -60,9 +60,8 @@ size_t BleuScorer::countNgrams(const string& line, NgramCounts& counts,
void BleuScorer::setReferenceFiles(const vector<string>& referenceFiles)
{
- //make sure reference data is clear
- m_ref_counts.reset();
- m_ref_lengths.clear();
+ // Make sure reference data is clear
+ m_references.reset();
ClearEncoder();
//load reference data
@@ -77,12 +76,10 @@ void BleuScorer::setReferenceFiles(const vector<string>& referenceFiles)
while (getline(refin,line)) {
line = this->applyFactors(line);
if (i == 0) {
- NgramCounts *counts = new NgramCounts; //these get leaked
- m_ref_counts.push_back(counts);
- vector<size_t> lengths;
- m_ref_lengths.push_back(lengths);
+ Reference* ref = new Reference;
+ m_references.push_back(ref); // Take ownership of the Reference object.
}
- if (m_ref_counts.size() <= sid) {
+ if (m_references.size() <= sid) {
throw runtime_error("File " + referenceFiles[i] + " has too many sentences");
}
NgramCounts counts;
@@ -94,13 +91,13 @@ void BleuScorer::setReferenceFiles(const vector<string>& referenceFiles)
const NgramCounts::Value newcount = ci->second;
NgramCounts::Value oldcount = 0;
- m_ref_counts[sid]->lookup(ngram, &oldcount);
+ m_references[sid]->get_counts()->lookup(ngram, &oldcount);
if (newcount > oldcount) {
- m_ref_counts[sid]->operator[](ngram) = newcount;
+ m_references[sid]->get_counts()->operator[](ngram) = newcount;
}
}
//add in the length
- m_ref_lengths[sid].push_back(length);
+ m_references[sid]->push_back(length);
if (sid > 0 && sid % 100 == 0) {
TRACE_ERR(".");
}
@@ -112,7 +109,7 @@ void BleuScorer::setReferenceFiles(const vector<string>& referenceFiles)
void BleuScorer::prepareStats(size_t sid, const string& text, ScoreStats& entry)
{
- if (sid >= m_ref_counts.size()) {
+ if (sid >= m_references.size()) {
stringstream msg;
msg << "Sentence id (" << sid << ") not found in reference set";
throw runtime_error(msg.str());
@@ -123,20 +120,8 @@ void BleuScorer::prepareStats(size_t sid, const string& text, ScoreStats& entry)
string sentence = this->applyFactors(text);
const size_t length = countNgrams(sentence, testcounts, kBleuNgramOrder);
- // Calculate effective reference length.
- switch (m_ref_length_type) {
- case SHORTEST:
- CalcShortest(sid, stats);
- break;
- case AVERAGE:
- CalcAverage(sid, stats);
- break;
- case CLOSEST:
- CalcClosest(sid, length, stats);
- break;
- default:
- throw runtime_error("Unsupported reflength strategy");
- }
+ const int reference_len = CalcReferenceLength(sid, length);
+ stats.push_back(reference_len);
//precision on each ngram type
for (NgramCounts::const_iterator testcounts_it = testcounts.begin();
@@ -146,7 +131,7 @@ void BleuScorer::prepareStats(size_t sid, const string& text, ScoreStats& entry)
NgramCounts::Value correct = 0;
NgramCounts::Value v = 0;
- if (m_ref_counts[sid]->lookup(testcounts_it->first, &v)) {
+ if (m_references[sid]->get_counts()->lookup(testcounts_it->first, &v)) {
correct = min(v, guess);
}
stats[len * 2 - 2] += correct;
@@ -174,6 +159,23 @@ float BleuScorer::calculateScore(const vector<int>& comps) const
return exp(logbleu);
}
+int BleuScorer::CalcReferenceLength(size_t sentence_id, size_t length) {
+ switch (m_ref_length_type) {
+ case AVERAGE:
+ return m_references[sentence_id]->CalcAverage();
+ break;
+ case CLOSEST:
+ return m_references[sentence_id]->CalcClosest(length);
+ break;
+ case SHORTEST:
+ return m_references[sentence_id]->CalcShortest();
+ break;
+ default:
+ cerr << "unknown reference types." << endl;
+ exit(1);
+ }
+}
+
void BleuScorer::dump_counts(ostream* os,
const NgramCounts& counts) const {
for (NgramCounts::const_iterator it = counts.begin();
@@ -191,44 +193,3 @@ void BleuScorer::dump_counts(ostream* os,
*os << endl;
}
-void BleuScorer::CalcAverage(size_t sentence_id,
- vector<ScoreStatsType>& stats) const {
- int total = 0;
- for (size_t i = 0;
- i < m_ref_lengths[sentence_id].size(); ++i) {
- total += m_ref_lengths[sentence_id][i];
- }
- const float mean = static_cast<float>(total) /
- m_ref_lengths[sentence_id].size();
- stats.push_back(static_cast<ScoreStatsType>(mean));
-}
-
-void BleuScorer::CalcClosest(size_t sentence_id,
- size_t length,
- vector<ScoreStatsType>& stats) const {
- int min_diff = INT_MAX;
- int min_idx = 0;
- for (size_t i = 0; i < m_ref_lengths[sentence_id].size(); ++i) {
- const int reflength = m_ref_lengths[sentence_id][i];
- const int length_diff = abs(reflength - static_cast<int>(length));
-
- // Look for the closest reference
- if (length_diff < abs(min_diff)) {
- min_diff = reflength - length;
- min_idx = i;
- // if two references has the same closest length, take the shortest
- } else if (length_diff == abs(min_diff)) {
- if (reflength < static_cast<int>(m_ref_lengths[sentence_id][min_idx])) {
- min_idx = i;
- }
- }
- }
- stats.push_back(m_ref_lengths[sentence_id][min_idx]);
-}
-
-void BleuScorer::CalcShortest(size_t sentence_id,
- vector<ScoreStatsType>& stats) const {
- const int shortest = *min_element(m_ref_lengths[sentence_id].begin(),
- m_ref_lengths[sentence_id].end());
- stats.push_back(shortest);
-}
diff --git a/mert/BleuScorer.h b/mert/BleuScorer.h
index c35d4ad1d..d58277a41 100644
--- a/mert/BleuScorer.h
+++ b/mert/BleuScorer.h
@@ -15,6 +15,7 @@ using namespace std;
const int kBleuNgramOrder = 4;
class NgramCounts;
+class Reference;
/**
* Bleu scoring
@@ -30,6 +31,8 @@ public:
virtual float calculateScore(const vector<int>& comps) const;
virtual size_t NumberOfScores() const { return 2 * kBleuNgramOrder + 1; }
+ int CalcReferenceLength(size_t sentence_id, size_t length);
+
private:
enum ReferenceLengthType {
AVERAGE,
@@ -44,19 +47,10 @@ private:
void dump_counts(std::ostream* os, const NgramCounts& counts) const;
- // For calculating effective reference length.
- void CalcAverage(size_t sentence_id,
- vector<ScoreStatsType>& stats) const;
- void CalcClosest(size_t sentence_id, size_t length,
- vector<ScoreStatsType>& stats) const;
- void CalcShortest(size_t sentence_id,
- vector<ScoreStatsType>& stats) const;
-
ReferenceLengthType m_ref_length_type;
- // data extracted from reference files
- ScopedVector<NgramCounts> m_ref_counts;
- vector<vector<size_t> > m_ref_lengths;
+ // reference translations.
+ ScopedVector<Reference> m_references;
// no copying allowed
BleuScorer(const BleuScorer&);
diff --git a/mert/Jamfile b/mert/Jamfile
index 47f52b1ab..8879253d5 100644
--- a/mert/Jamfile
+++ b/mert/Jamfile
@@ -6,9 +6,13 @@ lib mert_lib :
Util.cpp
FileStream.cpp
Timer.cpp
-ScoreStats.cpp ScoreArray.cpp ScoreData.cpp
+ScoreStats.cpp
+ScoreArray.cpp
+ScoreData.cpp
ScoreDataIterator.cpp
-FeatureStats.cpp FeatureArray.cpp FeatureData.cpp
+FeatureStats.cpp
+FeatureArray.cpp
+FeatureData.cpp
FeatureDataIterator.cpp
Data.cpp
BleuScorer.cpp
@@ -47,6 +51,7 @@ alias programs : mert extractor evaluator pro ;
unit-test feature_data_test : FeatureDataTest.cpp mert_lib ..//boost_unit_test_framework ;
unit-test data_test : DataTest.cpp mert_lib ..//boost_unit_test_framework ;
unit-test ngram_test : NgramTest.cpp mert_lib ..//boost_unit_test_framework ;
+unit-test reference_test : ReferenceTest.cpp mert_lib ..//boost_unit_test_framework ;
unit-test timer_test : TimerTest.cpp mert_lib ..//boost_unit_test_framework ;
unit-test util_test : UtilTest.cpp mert_lib ..//boost_unit_test_framework ;
diff --git a/mert/Reference.h b/mert/Reference.h
new file mode 100644
index 000000000..de5a6fecc
--- /dev/null
+++ b/mert/Reference.h
@@ -0,0 +1,78 @@
+#ifndef MERT_REFERENCE_H_
+#define MERT_REFERENCE_H_
+
+#include <algorithm>
+#include <climits>
+#include <iostream>
+#include <vector>
+
+#include "Ngram.h"
+
+// Reference class is a reference translation for an output translation.
+class Reference {
+ public:
+ // for m_length
+ typedef std::vector<size_t>::iterator iterator;
+ typedef std::vector<size_t>::const_iterator const_iterator;
+
+ Reference() : m_counts(new NgramCounts) { }
+ ~Reference() { delete m_counts; }
+
+ NgramCounts* get_counts() { return m_counts; }
+ const NgramCounts* get_counts() const { return m_counts; }
+
+ iterator begin() { return m_length.begin(); }
+ const_iterator begin() const { return m_length.begin(); }
+ iterator end() { return m_length.end(); }
+ const_iterator end() const { return m_length.end(); }
+
+ void push_back(size_t len) { m_length.push_back(len); }
+
+ size_t num_references() const { return m_length.size(); }
+
+ int CalcAverage() const;
+ int CalcClosest(size_t length) const;
+ int CalcShortest() const;
+
+ private:
+ NgramCounts* m_counts;
+
+ // multiple reference lengths
+ std::vector<size_t> m_length;
+};
+
+inline int Reference::CalcAverage() const {
+ int total = 0;
+ for (size_t i = 0; i < m_length.size(); ++i) {
+ total += m_length[i];
+ }
+ return static_cast<int>(
+ static_cast<float>(total) / m_length.size());
+}
+
+inline int Reference::CalcClosest(size_t length) const {
+ int min_diff = INT_MAX;
+ int closest_ref_id = 0; // an index of the closest reference translation
+ for (size_t i = 0; i < m_length.size(); ++i) {
+ const int ref_length = m_length[i];
+ const int length_diff = abs(ref_length - static_cast<int>(length));
+ const int abs_min_diff = abs(min_diff);
+ // Look for the closest reference
+ if (length_diff < abs_min_diff) {
+ min_diff = ref_length - length;
+ closest_ref_id = i;
+ // if two references have the same closest length, take the shortest
+ } else if (length_diff == abs_min_diff) {
+ if (ref_length < static_cast<int>(m_length[closest_ref_id])) {
+ closest_ref_id = i;
+ }
+ }
+ }
+ return static_cast<int>(m_length[closest_ref_id]);
+}
+
+inline int Reference::CalcShortest() const {
+ return *std::min_element(m_length.begin(), m_length.end());
+}
+
+#endif // MERT_REFERENCE_H_
diff --git a/mert/ReferenceTest.cpp b/mert/ReferenceTest.cpp
new file mode 100644
index 000000000..454768195
--- /dev/null
+++ b/mert/ReferenceTest.cpp
@@ -0,0 +1,116 @@
+#include "Reference.h"
+
+#define BOOST_TEST_MODULE MertReference
+#include <boost/test/unit_test.hpp>
+
+BOOST_AUTO_TEST_CASE(refernece_count) {
+ Reference ref;
+ BOOST_CHECK(ref.get_counts() != NULL);
+}
+
+BOOST_AUTO_TEST_CASE(refernece_length_iterator) {
+ Reference ref;
+ ref.push_back(4);
+ ref.push_back(2);
+ BOOST_REQUIRE(ref.num_references() == 2);
+
+ Reference::iterator it = ref.begin();
+ BOOST_CHECK_EQUAL(*it, 4);
+ ++it;
+ BOOST_CHECK_EQUAL(*it, 2);
+ ++it;
+ BOOST_CHECK(it == ref.end());
+}
+
+BOOST_AUTO_TEST_CASE(refernece_length_average) {
+ {
+ Reference ref;
+ ref.push_back(4);
+ ref.push_back(1);
+ BOOST_CHECK_EQUAL(2, ref.CalcAverage());
+ }
+
+ {
+ Reference ref;
+ ref.push_back(4);
+ ref.push_back(3);
+ BOOST_CHECK_EQUAL(3, ref.CalcAverage());
+ }
+
+ {
+ Reference ref;
+ ref.push_back(4);
+ ref.push_back(3);
+ ref.push_back(4);
+ ref.push_back(5);
+ BOOST_CHECK_EQUAL(4, ref.CalcAverage());
+ }
+}
+
+BOOST_AUTO_TEST_CASE(refernece_length_closest) {
+ {
+ Reference ref;
+ ref.push_back(4);
+ ref.push_back(1);
+ BOOST_REQUIRE(ref.num_references() == 2);
+
+ BOOST_CHECK_EQUAL(1, ref.CalcClosest(2));
+ BOOST_CHECK_EQUAL(1, ref.CalcClosest(1));
+ BOOST_CHECK_EQUAL(4, ref.CalcClosest(3));
+ BOOST_CHECK_EQUAL(4, ref.CalcClosest(4));
+ BOOST_CHECK_EQUAL(4, ref.CalcClosest(5));
+ }
+
+ {
+ Reference ref;
+ ref.push_back(4);
+ ref.push_back(3);
+ BOOST_REQUIRE(ref.num_references() == 2);
+
+ BOOST_CHECK_EQUAL(3, ref.CalcClosest(1));
+ BOOST_CHECK_EQUAL(3, ref.CalcClosest(2));
+ BOOST_CHECK_EQUAL(3, ref.CalcClosest(3));
+ BOOST_CHECK_EQUAL(4, ref.CalcClosest(4));
+ BOOST_CHECK_EQUAL(4, ref.CalcClosest(5));
+ }
+
+ {
+ Reference ref;
+ ref.push_back(4);
+ ref.push_back(3);
+ ref.push_back(4);
+ ref.push_back(5);
+ BOOST_REQUIRE(ref.num_references() == 4);
+
+ BOOST_CHECK_EQUAL(3, ref.CalcClosest(1));
+ BOOST_CHECK_EQUAL(3, ref.CalcClosest(2));
+ BOOST_CHECK_EQUAL(3, ref.CalcClosest(3));
+ BOOST_CHECK_EQUAL(4, ref.CalcClosest(4));
+ BOOST_CHECK_EQUAL(5, ref.CalcClosest(5));
+ }
+}
+
+BOOST_AUTO_TEST_CASE(refernece_length_shortest) {
+ {
+ Reference ref;
+ ref.push_back(4);
+ ref.push_back(1);
+ BOOST_CHECK_EQUAL(1, ref.CalcShortest());
+ }
+
+ {
+ Reference ref;
+ ref.push_back(4);
+ ref.push_back(3);
+ BOOST_CHECK_EQUAL(3, ref.CalcShortest());
+ }
+
+ {
+ Reference ref;
+ ref.push_back(4);
+ ref.push_back(3);
+ ref.push_back(4);
+ ref.push_back(5);
+ BOOST_CHECK_EQUAL(3, ref.CalcShortest());
+ }
+}