diff options
author | phikoehn <pkoehn@inf.ed.ac.uk> | 2013-05-31 15:28:57 +0400 |
---|---|---|
committer | phikoehn <pkoehn@inf.ed.ac.uk> | 2013-05-31 15:28:57 +0400 |
commit | d1650a5aa7621f24b8bbf59da23071d45f145ad1 (patch) | |
tree | 6faad1cb2c19e8ff2d6230fc00b0d36171ed0b7b | |
parent | 68501f5a363d63cfcfd2a9fec3683bddbe0dd4e8 (diff) |
basic support for alternate weight settings
-rw-r--r-- | moses/InputType.h | 14 | ||||
-rw-r--r-- | moses/Manager.cpp | 15 | ||||
-rw-r--r-- | moses/Parameter.cpp | 2 | ||||
-rw-r--r-- | moses/Sentence.cpp | 8 | ||||
-rw-r--r-- | moses/StaticData.cpp | 72 | ||||
-rw-r--r-- | moses/StaticData.h | 24 | ||||
-rw-r--r-- | moses/TranslationOptionCollection.cpp | 18 |
7 files changed, 140 insertions, 13 deletions
diff --git a/moses/InputType.h b/moses/InputType.h index d0106e5ca..89efe9e7c 100644 --- a/moses/InputType.h +++ b/moses/InputType.h @@ -48,10 +48,12 @@ protected: long m_translationId; //< contiguous Id long m_documentId; long m_topicId; + std::string m_weightSetting; std::vector<std::string> m_topicIdAndProb; bool m_useTopicId; bool m_useTopicIdAndProb; bool m_hasMetaData; + bool m_specifiesWeightSetting; long m_segId; ReorderingConstraint m_reorderingConstraint; /**< limits on reordering specified either by "-mp" switch or xml tags */ std::string m_textType; @@ -109,6 +111,18 @@ public: std::string GetTextType() const { return m_textType; } + void SetSpecifiesWeightSetting(bool specifiesWeightSetting) { + m_specifiesWeightSetting = specifiesWeightSetting; + } + bool GetSpecifiesWeightSetting() const { + return m_specifiesWeightSetting; + } + void SetWeightSetting(std::string settingName) { + m_weightSetting = settingName; + } + std::string GetWeightSetting() const { + return m_weightSetting; + } void SetTextType(std::string type) { m_textType = type; } diff --git a/moses/Manager.cpp b/moses/Manager.cpp index 76809f224..bb2260030 100644 --- a/moses/Manager.cpp +++ b/moses/Manager.cpp @@ -80,7 +80,22 @@ void Manager::ProcessSentence() { // reset statistics ResetSentenceStats(m_source); + + // check if alternate weight setting is used + // this is not thread safe! it changes StaticData + if (StaticData::Instance().GetHasAlternateWeightSettings()) { + std::cerr << "config defines weight setting\n"; + if (m_source.GetSpecifiesWeightSetting()) { + std::cerr << "sentence specifies weight setting\n"; + std::cerr << "calling SetWeightSetting( " << m_source.GetWeightSetting() << ")\n"; + StaticData::Instance().SetWeightSetting(m_source.GetWeightSetting()); + } + else { + StaticData::Instance().SetWeightSetting("default"); + } + } + // get translation options Timer getOptionsTime; getOptionsTime.start(); m_transOptColl->CreateTranslationOptions(); diff --git a/moses/Parameter.cpp b/moses/Parameter.cpp index e16b6d08f..e58dc95e2 100644 --- a/moses/Parameter.cpp +++ b/moses/Parameter.cpp @@ -194,7 +194,7 @@ Parameter::Parameter() AddParam("feature", ""); AddParam("print-id", "prefix translations with id. Default if false"); - + AddParam("alternate-weight-setting", "aws", "alternate set of weights to used per xml specification"); } Parameter::~Parameter() diff --git a/moses/Sentence.cpp b/moses/Sentence.cpp index 8e76b0f03..40e218e56 100644 --- a/moses/Sentence.cpp +++ b/moses/Sentence.cpp @@ -110,6 +110,13 @@ int Sentence::Read(std::istream& in,const std::vector<FactorType>& factorOrder) this->SetUseTopicIdAndProb(true); } } + if (meta.find("weight-setting") != meta.end()) { + this->SetWeightSetting(meta["weight-setting"]); + this->SetSpecifiesWeightSetting(true); + } + else { + this->SetSpecifiesWeightSetting(false); + } // parse XML markup in translation line //const StaticData &staticData = StaticData::Instance(); @@ -156,6 +163,7 @@ int Sentence::Read(std::istream& in,const std::vector<FactorType>& factorOrder) } + // reordering walls and zones m_reorderingConstraint.InitializeWalls( GetSize() ); // set reordering walls, if "-monotone-at-punction" is set diff --git a/moses/StaticData.cpp b/moses/StaticData.cpp index 85e015e36..67a311c09 100644 --- a/moses/StaticData.cpp +++ b/moses/StaticData.cpp @@ -570,6 +570,7 @@ bool StaticData::LoadData(Parameter *parameter) vector<string> toks = Tokenize(line); const string &feature = toks[0]; + //int featureIndex = GetFeatureIndex(featureIndexMap, feature); if (feature == "GlobalLexicalModel") { GlobalLexicalModel *model = new GlobalLexicalModel(line); @@ -706,7 +707,6 @@ bool StaticData::LoadData(Parameter *parameter) UserMessage::Add("Unknown feature function:" + feature); return false; } - } CollectFeatureFunctions(); @@ -738,6 +738,10 @@ bool StaticData::LoadData(Parameter *parameter) //cerr << endl << "m_allWeights=" << m_allWeights << endl; + // alternate weight settings + if (m_parameter->GetParam("alternate-weight-setting").size() > 0) { + ProcessAlternateWeightSettings(); + } return true; } @@ -1181,6 +1185,70 @@ bool StaticData::CheckWeights() const return true; } -} // namespace +void StaticData::ProcessAlternateWeightSettings() { + const vector<string> &weightSpecification = m_parameter->GetParam("alternate-weight-setting"); + + // get mapping from feature names to feature functions + map<string,FeatureFunction*> nameToFF; + const std::vector<FeatureFunction*> &ffs = FeatureFunction::GetFeatureFunctions(); + for (size_t i = 0; i < ffs.size(); ++i) { + nameToFF[ ffs[i]->GetScoreProducerDescription() ] = ffs[i]; + } + + // copy main weight setting as default + m_weightSetting["default"] = new ScoreComponentCollection( m_allWeights ); + + // go through specification in config file + string currentId = ""; + bool hasErrors = false; + for (size_t i=0; i<weightSpecification.size(); ++i) { + + // identifier line (with optional additional specifications) + if (weightSpecification[i].find("id=") == 0) { + vector<string> tokens = Tokenize(weightSpecification[i]); + vector<string> args = Tokenize(tokens[0], "="); + currentId = args[1]; + cerr << "alternate weight setting " << currentId << endl; + CHECK(m_weightSetting.find(currentId) == m_weightSetting.end()); + m_weightSetting[ currentId ] = new ScoreComponentCollection; + + // other specifications + for(size_t j=1; j<tokens.size(); j++) { + vector<string> args = Tokenize(tokens[j], "="); + if (args[0] == "weight-file") { + // TODO: support for sparse weights + } + } + } + + // weight lines + else { + CHECK(currentId != ""); + vector<string> tokens = Tokenize(weightSpecification[i]); + CHECK(tokens.size() >= 2); + + // get name and weight values + string name = tokens[0]; + name = name.substr(0, name.size() - 1); // remove trailing "=" + vector<float> weights(tokens.size() - 1); + for (size_t i = 1; i < tokens.size(); ++i) { + float weight = Scan<float>(tokens[i]); + weights[i - 1] = weight; + } + // check if a valid nane + map<string,FeatureFunction*>::iterator ffLookUp = nameToFF.find(name); + if (ffLookUp == nameToFF.end()) { + cerr << "ERROR: alternate weight setting " << currentId << " specifies weight(s) for " << name << " but there is no such feature function" << endl; + hasErrors = true; + } + else { + m_weightSetting[ currentId ]->Assign( nameToFF[name], weights); + } + } + } + CHECK(!hasErrors); +} + +} // namespace diff --git a/moses/StaticData.h b/moses/StaticData.h index 5a1cec213..ac847d944 100644 --- a/moses/StaticData.h +++ b/moses/StaticData.h @@ -75,7 +75,7 @@ protected: std::vector<const GenerationDictionary*> m_generationDictionary; Parameter *m_parameter; std::vector<FactorType> m_inputFactorOrder, m_outputFactorOrder; - ScoreComponentCollection m_allWeights; + mutable ScoreComponentCollection m_allWeights; std::vector<DecodeGraph*> m_decodeGraphs; std::vector<size_t> m_decodeGraphBackoff; @@ -206,6 +206,9 @@ protected: int m_threadCount; long m_startTranslationId; + + // alternate weight settings + std::map< std::string, ScoreComponentCollection* > m_weightSetting; StaticData(); @@ -658,6 +661,24 @@ public: return m_nBestIncludesSegmentation; } + bool GetHasAlternateWeightSettings() const { + return m_weightSetting.size() > 0; + } + + void SetWeightSetting(const std::string &settingName) const { + std::cerr << "SetWeightSetting( " << settingName << ")\n"; + CHECK(GetHasAlternateWeightSettings()); + std::map< std::string, ScoreComponentCollection* >::const_iterator i = + m_weightSetting.find( settingName ); + // if not found, resort to default + std::cerr << "using weight setting " << settingName << std::endl; + if (i == m_weightSetting.end()) { + i = m_weightSetting.find( "default" ); + std::cerr << "not found, using default weight setting instead\n"; + } + m_allWeights = *(i->second); + } + float GetWeightWordPenalty() const; float GetWeightUnknownWordPenalty() const; @@ -688,6 +709,7 @@ public: void CollectFeatureFunctions(); bool CheckWeights() const; + void ProcessAlternateWeightSettings(); void SetTemporaryMultiModelWeightsVector(std::vector<float> weights) const { diff --git a/moses/TranslationOptionCollection.cpp b/moses/TranslationOptionCollection.cpp index 2d7024c7a..20a51f5a3 100644 --- a/moses/TranslationOptionCollection.cpp +++ b/moses/TranslationOptionCollection.cpp @@ -385,10 +385,12 @@ void TranslationOptionCollection::CreateTranslationOptions() // ... and that end at endPos for (size_t endPos = startPos ; endPos < startPos + maxSize ; endPos++) { if (graphInd > 0 && // only skip subsequent graphs - decodeGraphBackoff[graphInd] != 0 && // use of backoff specified - (endPos-startPos+1 >= decodeGraphBackoff[graphInd] || // size exceeds backoff limit or ... - m_collection[startPos][endPos-startPos].size() > 0)) { // no phrases found so far - VERBOSE(3,"No backoff to graph " << graphInd << " for span [" << startPos << ";" << endPos << "]" << endl); + decodeGraphBackoff[graphInd] != 0 && // limited use of backoff specified + (endPos-startPos+1 > decodeGraphBackoff[graphInd] || // size exceeds backoff limit or ... + m_collection[startPos][endPos-startPos].size() > 0)) { // already covered + VERBOSE(3,"No backoff to graph " << graphInd << " for span [" << startPos << ";" << endPos << "]"); + VERBOSE(3,", length limit: " << decodeGraphBackoff[graphInd]); + VERBOSE(3,", found so far: " << m_collection[startPos][endPos-startPos].size() << endl); // do not create more options continue; } @@ -505,11 +507,10 @@ void TranslationOptionCollection::CreateTranslationOptionsForRange( , startPos, endPos, adhereTableLimit ); // do rest of decode steps - int indexStep = 0; + int indexStep = 1; - for (++iterStep ; iterStep != decodeGraph.end() ; ++iterStep) { - - const DecodeStep &decodeStep = **iterStep; + for (++iterStep; iterStep != decodeGraph.end() ; ++iterStep, ++indexStep) { + const DecodeStep &decodeStep = **iterStep; PartialTranslOptColl* newPtoc = new PartialTranslOptColl; // go thru each intermediate trans opt just created @@ -531,7 +532,6 @@ void TranslationOptionCollection::CreateTranslationOptionsForRange( delete oldPtoc; oldPtoc = newPtoc; - indexStep++; } // for (++iterStep // add to fully formed translation option list |