Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/moses-smt/mosesdecoder.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author: phikoehn <pkoehn@inf.ed.ac.uk>, 2013-05-31 15:28:57 +0400
committer: phikoehn <pkoehn@inf.ed.ac.uk>, 2013-05-31 15:28:57 +0400
commit: d1650a5aa7621f24b8bbf59da23071d45f145ad1 (patch)
tree: 6faad1cb2c19e8ff2d6230fc00b0d36171ed0b7b
parent: 68501f5a363d63cfcfd2a9fec3683bddbe0dd4e8 (diff)
basic support for alternate weight settings
-rw-r--r-- moses/InputType.h 14
-rw-r--r-- moses/Manager.cpp 15
-rw-r--r-- moses/Parameter.cpp 2
-rw-r--r-- moses/Sentence.cpp 8
-rw-r--r-- moses/StaticData.cpp 72
-rw-r--r-- moses/StaticData.h 24
-rw-r--r-- moses/TranslationOptionCollection.cpp 18
7 files changed, 140 insertions, 13 deletions
diff --git a/moses/InputType.h b/moses/InputType.h
index d0106e5ca..89efe9e7c 100644
--- a/moses/InputType.h
+++ b/moses/InputType.h
@@ -48,10 +48,12 @@ protected:
long m_translationId; //< contiguous Id
long m_documentId;
long m_topicId;
+ std::string m_weightSetting;
std::vector<std::string> m_topicIdAndProb;
bool m_useTopicId;
bool m_useTopicIdAndProb;
bool m_hasMetaData;
+ bool m_specifiesWeightSetting;
long m_segId;
ReorderingConstraint m_reorderingConstraint; /**< limits on reordering specified either by "-mp" switch or xml tags */
std::string m_textType;
@@ -109,6 +111,18 @@ public:
std::string GetTextType() const {
return m_textType;
}
+  /** Store the flag saying whether this input names a weight setting. */
+  void SetSpecifiesWeightSetting(bool specifiesWeightSetting) {
+    this->m_specifiesWeightSetting = specifiesWeightSetting;
+  }
+  /** @return the flag last stored by SetSpecifiesWeightSetting(). */
+  bool GetSpecifiesWeightSetting() const {
+    return this->m_specifiesWeightSetting;
+  }
+  /**
+   * Store the name of the weight setting requested for this input.
+   * The name is taken by const reference so that the only copy made is the
+   * one into m_weightSetting.
+   */
+  void SetWeightSetting(const std::string &settingName) {
+    m_weightSetting = settingName;
+  }
+  /** @return a copy of the stored weight-setting name. */
+  std::string GetWeightSetting() const {
+    return this->m_weightSetting;
+  }
void SetTextType(std::string type) {
m_textType = type;
}
diff --git a/moses/Manager.cpp b/moses/Manager.cpp
index 76809f224..bb2260030 100644
--- a/moses/Manager.cpp
+++ b/moses/Manager.cpp
@@ -80,7 +80,22 @@ void Manager::ProcessSentence()
{
// reset statistics
ResetSentenceStats(m_source);
+
+ // check if alternate weight setting is used
+ // this is not thread safe! it changes StaticData
+ if (StaticData::Instance().GetHasAlternateWeightSettings()) {
+ std::cerr << "config defines weight setting\n";
+ if (m_source.GetSpecifiesWeightSetting()) {
+ std::cerr << "sentence specifies weight setting\n";
+ std::cerr << "calling SetWeightSetting( " << m_source.GetWeightSetting() << ")\n";
+ StaticData::Instance().SetWeightSetting(m_source.GetWeightSetting());
+ }
+ else {
+ StaticData::Instance().SetWeightSetting("default");
+ }
+ }
+ // get translation options
Timer getOptionsTime;
getOptionsTime.start();
m_transOptColl->CreateTranslationOptions();
diff --git a/moses/Parameter.cpp b/moses/Parameter.cpp
index e16b6d08f..e58dc95e2 100644
--- a/moses/Parameter.cpp
+++ b/moses/Parameter.cpp
@@ -194,7 +194,7 @@ Parameter::Parameter()
AddParam("feature", "");
AddParam("print-id", "prefix translations with id. Default if false");
-
+ AddParam("alternate-weight-setting", "aws", "alternate set of weights to used per xml specification");
}
Parameter::~Parameter()
diff --git a/moses/Sentence.cpp b/moses/Sentence.cpp
index 8e76b0f03..40e218e56 100644
--- a/moses/Sentence.cpp
+++ b/moses/Sentence.cpp
@@ -110,6 +110,13 @@ int Sentence::Read(std::istream& in,const std::vector<FactorType>& factorOrder)
this->SetUseTopicIdAndProb(true);
}
}
+ if (meta.find("weight-setting") != meta.end()) {
+ this->SetWeightSetting(meta["weight-setting"]);
+ this->SetSpecifiesWeightSetting(true);
+ }
+ else {
+ this->SetSpecifiesWeightSetting(false);
+ }
// parse XML markup in translation line
//const StaticData &staticData = StaticData::Instance();
@@ -156,6 +163,7 @@ int Sentence::Read(std::istream& in,const std::vector<FactorType>& factorOrder)
}
+ // reordering walls and zones
m_reorderingConstraint.InitializeWalls( GetSize() );
// set reordering walls, if "-monotone-at-punctuation" is set
diff --git a/moses/StaticData.cpp b/moses/StaticData.cpp
index 85e015e36..67a311c09 100644
--- a/moses/StaticData.cpp
+++ b/moses/StaticData.cpp
@@ -570,6 +570,7 @@ bool StaticData::LoadData(Parameter *parameter)
vector<string> toks = Tokenize(line);
const string &feature = toks[0];
+ //int featureIndex = GetFeatureIndex(featureIndexMap, feature);
if (feature == "GlobalLexicalModel") {
GlobalLexicalModel *model = new GlobalLexicalModel(line);
@@ -706,7 +707,6 @@ bool StaticData::LoadData(Parameter *parameter)
UserMessage::Add("Unknown feature function:" + feature);
return false;
}
-
}
CollectFeatureFunctions();
@@ -738,6 +738,10 @@ bool StaticData::LoadData(Parameter *parameter)
//cerr << endl << "m_allWeights=" << m_allWeights << endl;
+ // alternate weight settings
+ if (m_parameter->GetParam("alternate-weight-setting").size() > 0) {
+ ProcessAlternateWeightSettings();
+ }
return true;
}
@@ -1181,6 +1185,70 @@ bool StaticData::CheckWeights() const
return true;
}
-} // namespace
+/**
+ * Build the table of named weight settings from the
+ * "alternate-weight-setting" parameter.
+ *
+ * A copy of the current m_allWeights is stored under the name "default".
+ * Every parameter line is then read in order:
+ *  - a line starting with "id=" opens a new setting whose name is the text
+ *    after "id="; the name must not have been used before;
+ *  - any other line is a weight line: its first token is a feature function
+ *    name followed by "=", the remaining tokens are parsed as float weights
+ *    and assigned to that feature function in the most recently opened
+ *    setting.
+ *
+ * Malformed or unknown feature names are reported on stderr; after all lines
+ * have been read a CHECK fails if any such error was seen.
+ */
+void StaticData::ProcessAlternateWeightSettings() {
+  const vector<string> &weightSpecification = m_parameter->GetParam("alternate-weight-setting");
+
+  // get mapping from feature names to feature functions
+  map<string,FeatureFunction*> nameToFF;
+  const std::vector<FeatureFunction*> &ffs = FeatureFunction::GetFeatureFunctions();
+  for (size_t i = 0; i < ffs.size(); ++i) {
+    nameToFF[ ffs[i]->GetScoreProducerDescription() ] = ffs[i];
+  }
+
+  // copy main weight setting as default
+  m_weightSetting["default"] = new ScoreComponentCollection( m_allWeights );
+
+  // go through specification in config file
+  string currentId = "";
+  bool hasErrors = false;
+  for (size_t i=0; i<weightSpecification.size(); ++i) {
+    const string &line = weightSpecification[i];
+
+    // identifier line (with optional additional specifications)
+    if (line.find("id=") == 0) {
+      vector<string> tokens = Tokenize(line);
+      vector<string> idArgs = Tokenize(tokens[0], "=");
+      // the name is read from idArgs[1]; make sure that element exists
+      CHECK(idArgs.size() >= 2);
+      currentId = idArgs[1];
+      cerr << "alternate weight setting " << currentId << endl;
+      CHECK(m_weightSetting.find(currentId) == m_weightSetting.end());
+      m_weightSetting[ currentId ] = new ScoreComponentCollection;
+
+      // other specifications
+      for(size_t j=1; j<tokens.size(); j++) {
+        vector<string> specArgs = Tokenize(tokens[j], "=");
+        if (!specArgs.empty() && specArgs[0] == "weight-file") {
+          // TODO: support for sparse weights
+        }
+      }
+    }
+
+    // weight lines
+    else {
+      CHECK(currentId != "");
+      vector<string> tokens = Tokenize(line);
+      CHECK(tokens.size() >= 2);
+
+      // get name and weight values
+      string name = tokens[0];
+      if (name.empty() || name[name.size() - 1] != '=') {
+        // without this test the last character of the name would be dropped
+        // unconditionally and the weights could land on a different feature
+        cerr << "ERROR: alternate weight setting " << currentId << " has a weight line starting with " << name << " where a feature name followed by \"=\" is expected" << endl;
+        hasErrors = true;
+      }
+      else {
+        name = name.substr(0, name.size() - 1); // remove trailing "="
+        vector<float> weights(tokens.size() - 1);
+        for (size_t j = 1; j < tokens.size(); ++j) {
+          weights[j - 1] = Scan<float>(tokens[j]);
+        }
+        // check if a valid name
+        map<string,FeatureFunction*>::iterator ffLookUp = nameToFF.find(name);
+        if (ffLookUp == nameToFF.end()) {
+          cerr << "ERROR: alternate weight setting " << currentId << " specifies weight(s) for " << name << " but there is no such feature function" << endl;
+          hasErrors = true;
+        }
+        else {
+          m_weightSetting[ currentId ]->Assign( ffLookUp->second, weights);
+        }
+      }
+    }
+  }
+  CHECK(!hasErrors);
+}
+
+} // namespace
diff --git a/moses/StaticData.h b/moses/StaticData.h
index 5a1cec213..ac847d944 100644
--- a/moses/StaticData.h
+++ b/moses/StaticData.h
@@ -75,7 +75,7 @@ protected:
std::vector<const GenerationDictionary*> m_generationDictionary;
Parameter *m_parameter;
std::vector<FactorType> m_inputFactorOrder, m_outputFactorOrder;
- ScoreComponentCollection m_allWeights;
+ mutable ScoreComponentCollection m_allWeights;
std::vector<DecodeGraph*> m_decodeGraphs;
std::vector<size_t> m_decodeGraphBackoff;
@@ -206,6 +206,9 @@ protected:
int m_threadCount;
long m_startTranslationId;
+
+ // alternate weight settings
+ std::map< std::string, ScoreComponentCollection* > m_weightSetting;
StaticData();
@@ -658,6 +661,24 @@ public:
return m_nBestIncludesSegmentation;
}
+  /** @return true if at least one named weight setting is stored. */
+  bool GetHasAlternateWeightSettings() const {
+    return !m_weightSetting.empty();
+  }
+
+  /**
+   * Copy the named weight setting into m_allWeights.
+   *
+   * If settingName is not in m_weightSetting the entry called "default" is
+   * used instead. Progress messages are written to stderr. The method is
+   * const yet assigns to m_allWeights, and it takes no lock while doing so.
+   */
+  void SetWeightSetting(const std::string &settingName) const {
+    std::cerr << "SetWeightSetting( " << settingName << ")\n";
+    CHECK(GetHasAlternateWeightSettings());
+    std::map< std::string, ScoreComponentCollection* >::const_iterator setting =
+      m_weightSetting.find( settingName );
+    // if not found, resort to default
+    std::cerr << "using weight setting " << settingName << std::endl;
+    if (setting == m_weightSetting.end()) {
+      setting = m_weightSetting.find( "default" );
+      std::cerr << "not found, using default weight setting instead\n";
+    }
+    // the fallback lookup can fail too; never dereference end()
+    CHECK(setting != m_weightSetting.end());
+    m_allWeights = *(setting->second);
+  }
+
float GetWeightWordPenalty() const;
float GetWeightUnknownWordPenalty() const;
@@ -688,6 +709,7 @@ public:
void CollectFeatureFunctions();
bool CheckWeights() const;
+ void ProcessAlternateWeightSettings();
void SetTemporaryMultiModelWeightsVector(std::vector<float> weights) const {
diff --git a/moses/TranslationOptionCollection.cpp b/moses/TranslationOptionCollection.cpp
index 2d7024c7a..20a51f5a3 100644
--- a/moses/TranslationOptionCollection.cpp
+++ b/moses/TranslationOptionCollection.cpp
@@ -385,10 +385,12 @@ void TranslationOptionCollection::CreateTranslationOptions()
// ... and that end at endPos
for (size_t endPos = startPos ; endPos < startPos + maxSize ; endPos++) {
if (graphInd > 0 && // only skip subsequent graphs
- decodeGraphBackoff[graphInd] != 0 && // use of backoff specified
- (endPos-startPos+1 >= decodeGraphBackoff[graphInd] || // size exceeds backoff limit or ...
- m_collection[startPos][endPos-startPos].size() > 0)) { // no phrases found so far
- VERBOSE(3,"No backoff to graph " << graphInd << " for span [" << startPos << ";" << endPos << "]" << endl);
+ decodeGraphBackoff[graphInd] != 0 && // limited use of backoff specified
+ (endPos-startPos+1 > decodeGraphBackoff[graphInd] || // size exceeds backoff limit or ...
+ m_collection[startPos][endPos-startPos].size() > 0)) { // already covered
+ VERBOSE(3,"No backoff to graph " << graphInd << " for span [" << startPos << ";" << endPos << "]");
+ VERBOSE(3,", length limit: " << decodeGraphBackoff[graphInd]);
+ VERBOSE(3,", found so far: " << m_collection[startPos][endPos-startPos].size() << endl);
// do not create more options
continue;
}
@@ -505,11 +507,10 @@ void TranslationOptionCollection::CreateTranslationOptionsForRange(
, startPos, endPos, adhereTableLimit );
// do rest of decode steps
- int indexStep = 0;
+ int indexStep = 1;
- for (++iterStep ; iterStep != decodeGraph.end() ; ++iterStep) {
-
- const DecodeStep &decodeStep = **iterStep;
+ for (++iterStep; iterStep != decodeGraph.end() ; ++iterStep, ++indexStep) {
+ const DecodeStep &decodeStep = **iterStep;
PartialTranslOptColl* newPtoc = new PartialTranslOptColl;
// go thru each intermediate trans opt just created
@@ -531,7 +532,6 @@ void TranslationOptionCollection::CreateTranslationOptionsForRange(
delete oldPtoc;
oldPtoc = newPtoc;
- indexStep++;
} // for (++iterStep
// add to fully formed translation option list