diff options
author | Kenneth Heafield <github@kheafield.com> | 2013-04-25 22:42:30 +0400 |
---|---|---|
committer | Kenneth Heafield <github@kheafield.com> | 2013-04-25 22:42:30 +0400 |
commit | f1d366381033c0caae18f8d15305ded38734bdbf (patch) | |
tree | 22b0cbd3acc337a995701629bf9facbe179f5618 | |
parent | 8a1e944bb428a0af9f6c82c26e5633361ce4052c (diff) |
Back FactorCollection with a memory pool. Less memory for large vocabularies.
27 files changed, 169 insertions, 179 deletions
diff --git a/OnDiskPt/OnDiskWrapper.cpp b/OnDiskPt/OnDiskWrapper.cpp index 743a77db1..3a1773c0a 100644 --- a/OnDiskPt/OnDiskWrapper.cpp +++ b/OnDiskPt/OnDiskWrapper.cpp @@ -207,8 +207,7 @@ Word *OnDiskWrapper::ConvertFromMoses(Moses::FactorDirection /* direction */ size_t factorType = factorsVec[0]; const Moses::Factor *factor = origWord.GetFactor(factorType); CHECK(factor); - string str = factor->GetString(); - strme << str; + strme << factor->GetString(); for (size_t ind = 1 ; ind < factorsVec.size() ; ++ind) { size_t factorType = factorsVec[ind]; @@ -218,8 +217,7 @@ Word *OnDiskWrapper::ConvertFromMoses(Moses::FactorDirection /* direction */ break; } CHECK(factor); - string str = factor->GetString(); - strme << "|" << str; + strme << "|" << factor->GetString(); } // for (size_t factorType bool found; diff --git a/moses/ChartParser.cpp b/moses/ChartParser.cpp index ea55a46a6..5331a5fe4 100644 --- a/moses/ChartParser.cpp +++ b/moses/ChartParser.cpp @@ -49,8 +49,7 @@ void ChartParserUnknown::Process(const Word &sourceWord, const WordsRange &range size_t isDigit = 0; if (staticData.GetDropUnknown()) { const Factor *f = sourceWord[0]; // TODO hack. shouldn't know which factor is surface - const string &s = f->GetString(); - isDigit = s.find_first_of("0123456789"); + isDigit = f->GetString().find_first_of("0123456789"); if (isDigit == string::npos) isDigit = 0; else diff --git a/moses/Factor.h b/moses/Factor.h index ac1b591ed..87e8f8028 100644 --- a/moses/Factor.h +++ b/moses/Factor.h @@ -26,6 +26,7 @@ Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA #include <string> #include "TypeDef.h" #include "Util.h" +#include "util/string_piece.hh" namespace Moses { @@ -44,8 +45,9 @@ class Factor friend class FactorCollection; friend struct FactorFriend; - // FactorCollection writes here. - std::string m_string; + // FactorCollection writes here. + // This is mutable so the pointer can be changed to pool-backed memory. + mutable StringPiece m_string; size_t m_id; //! protected constructor. only friend class, FactorCollection, is allowed to create Factor objects @@ -59,7 +61,7 @@ class Factor public: //! original string representation of the factor - inline const std::string &GetString() const { + StringPiece GetString() const { return m_string; } //! contiguous ID diff --git a/moses/FactorCollection.cpp b/moses/FactorCollection.cpp index 849830f4d..969bb39d1 100644 --- a/moses/FactorCollection.cpp +++ b/moses/FactorCollection.cpp @@ -27,6 +27,7 @@ Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA #include <string> #include "FactorCollection.h" #include "Util.h" +#include "util/pool.hh" using namespace std; @@ -36,42 +37,23 @@ FactorCollection FactorCollection::s_instance; const Factor *FactorCollection::AddFactor(const StringPiece &factorString) { -// Sorry this is so complicated. Can't we just require everybody to use Boost >= 1.42? The issue is that I can't check BOOST_VERSION unless we have Boost. -#ifdef WITH_THREADS - -#if BOOST_VERSION < 104200 FactorFriend to_ins; - to_ins.in.m_string.assign(factorString.data(), factorString.size()); -#endif // BOOST_VERSION + to_ins.in.m_string = factorString; + to_ins.in.m_id = m_factorId; + // If we're threaded, hope a read-only lock is sufficient. +#ifdef WITH_THREADS { // read=lock scope boost::shared_lock<boost::shared_mutex> read_lock(m_accessLock); -#if BOOST_VERSION >= 104200 - // If this line doesn't compile, upgrade your Boost. - Set::const_iterator i = m_set.find(factorString, HashFactor(), EqualsFactor()); -#else // BOOST_VERSION Set::const_iterator i = m_set.find(to_ins); -#endif // BOOST_VERSION if (i != m_set.end()) return &i->in; } boost::unique_lock<boost::shared_mutex> lock(m_accessLock); -#if BOOST_VERSION >= 104200 - FactorFriend to_ins; - to_ins.in.m_string.assign(factorString.data(), factorString.size()); -#endif // BOOST_VERSION - -#else // WITH_THREADS - -#if BOOST_VERSION >= 104200 - Set::const_iterator i = m_set.find(factorString, HashFactor(), EqualsFactor()); - if (i != m_set.end()) return &i->in; -#endif - FactorFriend to_ins; - to_ins.in.m_string.assign(factorString.data(), factorString.size()); - #endif // WITH_THREADS - to_ins.in.m_id = m_factorId; std::pair<Set::iterator, bool> ret(m_set.insert(to_ins)); if (ret.second) { + ret.first->in.m_string.set( + memcpy(m_string_backing.Allocate(factorString.size()), factorString.data(), factorString.size()), + factorString.size()); m_factorId++; } return &ret.first->in; diff --git a/moses/FactorCollection.h b/moses/FactorCollection.h index 9a01766f4..e7749244f 100644 --- a/moses/FactorCollection.h +++ b/moses/FactorCollection.h @@ -33,6 +33,7 @@ Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA #include <string> #include "util/string_piece.hh" +#include "util/pool.hh" #include "Factor.h" namespace Moses @@ -62,27 +63,20 @@ class FactorCollection friend std::ostream& operator<<(std::ostream&, const FactorCollection&); struct HashFactor : public std::unary_function<const FactorFriend &, std::size_t> { - std::size_t operator()(const StringPiece &str) const { - return util::MurmurHashNative(str.data(), str.size()); - } std::size_t operator()(const FactorFriend &factor) const { - return (*this)(factor.in.GetString()); + return util::MurmurHashNative(factor.in.m_string.data(), factor.in.m_string.size()); } }; struct EqualsFactor : public std::binary_function<const FactorFriend &, const FactorFriend &, bool> { bool operator()(const FactorFriend &left, const FactorFriend &right) const { return left.in.GetString() == right.in.GetString(); } - bool operator()(const FactorFriend &left, const StringPiece &right) const { - return left.in.GetString() == right; - } - bool operator()(const StringPiece &left, const FactorFriend &right) const { - return left == right.in.GetString(); - } }; typedef boost::unordered_set<FactorFriend, HashFactor, EqualsFactor> Set; Set m_set; + util::Pool m_string_backing; + static FactorCollection s_instance; #ifdef WITH_THREADS //reader-writer lock @@ -117,6 +111,5 @@ public: }; - } #endif diff --git a/moses/FeatureVector.cpp b/moses/FeatureVector.cpp index c01775fd1..6cbddf3d7 100644 --- a/moses/FeatureVector.cpp +++ b/moses/FeatureVector.cpp @@ -26,6 +26,7 @@ #include <stdexcept> #include "FeatureVector.h" +#include "util/string_piece_hash.hh" using namespace std; @@ -41,12 +42,12 @@ namespace Moses { boost::shared_mutex FName::m_idLock; #endif - void FName::init(const string& name) { + void FName::init(const StringPiece &name) { #ifdef WITH_THREADS //reader lock boost::shared_lock<boost::shared_mutex> lock(m_idLock); #endif - Name2Id::iterator i = name2id.find(name); + Name2Id::iterator i = FindStringPiece(name2id, name); if (i != name2id.end()) { m_id = i->second; } else { @@ -55,15 +56,15 @@ namespace Moses { lock.unlock(); boost::unique_lock<boost::shared_mutex> write_lock(m_idLock); #endif - //Need to check again if the id is in the map, as someone may have added - //it while we were waiting on the writer lock. - if (i != name2id.end()) { - m_id = i->second; - } else { - m_id = name2id.size(); - name2id[name] = m_id; - id2name.push_back(name); + std::pair<std::string, size_t> to_ins; + to_ins.first.assign(name.data(), name.size()); + to_ins.second = name2id.size(); + std::pair<Name2Id::iterator, bool> res(name2id.insert(to_ins)); + if (res.second) { + // TODO this should be string pointers backed by the hash table. + id2name.push_back(to_ins.first); } + m_id = res.first->second; } } diff --git a/moses/FeatureVector.h b/moses/FeatureVector.h index 983248076..4401e3c03 100644 --- a/moses/FeatureVector.h +++ b/moses/FeatureVector.h @@ -45,6 +45,7 @@ #endif #include "util/check.hh" +#include "util/string_piece.hh" namespace Moses { @@ -68,9 +69,13 @@ namespace Moses { //A feature name can either be initialised as a pair of strings, //which will be concatenated with a SEP between them, or as //a single string, which will be used as-is. - explicit FName(const std::string root, const std::string name) - {init(root + SEP + name);} - explicit FName(const std::string& name) + FName(const StringPiece &root, const StringPiece &name) { + std::string assembled(root.data(), root.size()); + assembled += SEP; + assembled.append(name.data(), name.size()); + init(assembled); + } + explicit FName(const StringPiece &name) {init(name);} const std::string& name() const; @@ -89,7 +94,7 @@ namespace Moses { static void eraseId(size_t id); private: - void init(const std::string& name); + void init(const StringPiece& name); size_t m_id; #ifdef WITH_THREADS //reader-writer lock diff --git a/moses/GlobalLexicalModelUnlimited.cpp b/moses/GlobalLexicalModelUnlimited.cpp index f1de65bd0..cd8299e46 100644 --- a/moses/GlobalLexicalModelUnlimited.cpp +++ b/moses/GlobalLexicalModelUnlimited.cpp @@ -3,6 +3,8 @@ #include "StaticData.h" #include "InputFileStream.h" #include "UserMessage.h" +#include "util/string_piece_hash.hh" +#include "util/murmur_hash.hh" using namespace std; @@ -57,11 +59,11 @@ void GlobalLexicalModelUnlimited::Evaluate(const Hypothesis& cur_hypo, ScoreComp const TargetPhrase& targetPhrase = cur_hypo.GetCurrTargetPhrase(); for(int targetIndex = 0; targetIndex < targetPhrase.GetSize(); targetIndex++ ) { - string targetString = targetPhrase.GetWord(targetIndex).GetString(0); // TODO: change for other factors + StringPiece targetString = targetPhrase.GetWord(targetIndex).GetString(0); // TODO: change for other factors if (m_ignorePunctuation) { // check if first char is punctuation - char firstChar = targetString.at(0); + char firstChar = targetString[0]; CharHash::const_iterator charIterator = m_punctuationHash.find( firstChar ); if(charIterator != m_punctuationHash.end()) continue; @@ -76,23 +78,24 @@ void GlobalLexicalModelUnlimited::Evaluate(const Hypothesis& cur_hypo, ScoreComp accumulator->SparsePlusEquals(feature.str(), 1); } - StringHash alreadyScored; + boost::unordered_set<uint64_t> alreadyScored; for(int sourceIndex = 0; sourceIndex < input.GetSize(); sourceIndex++ ) { - string sourceString = input.GetWord(sourceIndex).GetString(0); // TODO: change for other factors + const StringPiece sourceString = input.GetWord(sourceIndex).GetString(0); // TODO: change for other factors if (m_ignorePunctuation) { // check if first char is punctuation - char firstChar = sourceString.at(0); + char firstChar = sourceString[0]; CharHash::const_iterator charIterator = m_punctuationHash.find( firstChar ); if(charIterator != m_punctuationHash.end()) continue; } + const uint64_t sourceHash = util::MurmurHashNative(sourceString.data(), sourceString.size()); - if ( alreadyScored.find(sourceString) == alreadyScored.end()) { + if (alreadyScored.find(sourceHash) == alreadyScored.end()) { bool sourceExists, targetExists; if (!m_unrestricted) { - sourceExists = m_vocabSource.find( sourceString ) != m_vocabSource.end(); - targetExists = m_vocabTarget.find( targetString) != m_vocabTarget.end(); + sourceExists = FindStringPiece(m_vocabSource, sourceString) != m_vocabSource.end(); + targetExists = FindStringPiece(m_vocabTarget, targetString) != m_vocabTarget.end(); } // no feature if vocab is in use and both words are not in restricted vocabularies @@ -107,15 +110,15 @@ void GlobalLexicalModelUnlimited::Evaluate(const Hypothesis& cur_hypo, ScoreComp feature << "<s>,"; feature << sourceString; accumulator->SparsePlusEquals(feature.str(), 1); - alreadyScored[sourceString] = 1; + alreadyScored.insert(sourceHash); } // add source words to the right of current source word as context for(int contextIndex = sourceIndex+1; contextIndex < input.GetSize(); contextIndex++ ) { - string contextString = input.GetWord(contextIndex).GetString(0); // TODO: change for other factors + StringPiece contextString = input.GetWord(contextIndex).GetString(0); // TODO: change for other factors bool contextExists; if (!m_unrestricted) - contextExists = m_vocabSource.find( contextString ) != m_vocabSource.end(); + contextExists = FindStringPiece(m_vocabSource, contextString ) != m_vocabSource.end(); if (m_unrestricted || contextExists) { stringstream feature; @@ -126,7 +129,7 @@ void GlobalLexicalModelUnlimited::Evaluate(const Hypothesis& cur_hypo, ScoreComp feature << ","; feature << contextString; accumulator->SparsePlusEquals(feature.str(), 1); - alreadyScored[sourceString] = 1; + alreadyScored.insert(sourceHash); } } } @@ -135,7 +138,7 @@ void GlobalLexicalModelUnlimited::Evaluate(const Hypothesis& cur_hypo, ScoreComp int globalTargetIndex = cur_hypo.GetSize() - targetPhrase.GetSize() + targetIndex; // 1) source-target pair, trigger source word (can be discont.) and adjacent target word (bigram) - string targetContext; + StringPiece targetContext; if (globalTargetIndex > 0) targetContext = cur_hypo.GetWord(globalTargetIndex-1).GetString(0); // TODO: change for other factors else @@ -143,23 +146,23 @@ void GlobalLexicalModelUnlimited::Evaluate(const Hypothesis& cur_hypo, ScoreComp if (sourceIndex == 0) { string sourceTrigger = "<s>"; - AddFeature(accumulator, alreadyScored, sourceTrigger, sourceString, + AddFeature(accumulator, sourceTrigger, sourceString, targetContext, targetString); } else for(int contextIndex = sourceIndex-1; contextIndex >= 0; contextIndex-- ) { - string sourceTrigger = input.GetWord(contextIndex).GetString(0); // TODO: change for other factors + StringPiece sourceTrigger = input.GetWord(contextIndex).GetString(0); // TODO: change for other factors bool sourceTriggerExists = false; if (!m_unrestricted) - sourceTriggerExists = m_vocabSource.find( sourceTrigger ) != m_vocabSource.end(); + sourceTriggerExists = FindStringPiece(m_vocabSource, sourceTrigger) != m_vocabSource.end(); if (m_unrestricted || sourceTriggerExists) - AddFeature(accumulator, alreadyScored, sourceTrigger, sourceString, + AddFeature(accumulator, sourceTrigger, sourceString, targetContext, targetString); } // 2) source-target pair, adjacent source word (bigram) and trigger target word (can be discont.) - string sourceContext; + StringPiece sourceContext; if (sourceIndex-1 >= 0) sourceContext = input.GetWord(sourceIndex-1).GetString(0); // TODO: change for other factors else @@ -167,18 +170,18 @@ void GlobalLexicalModelUnlimited::Evaluate(const Hypothesis& cur_hypo, ScoreComp if (globalTargetIndex == 0) { string targetTrigger = "<s>"; - AddFeature(accumulator, alreadyScored, sourceContext, sourceString, + AddFeature(accumulator, sourceContext, sourceString, targetTrigger, targetString); } else for(int globalContextIndex = globalTargetIndex-1; globalContextIndex >= 0; globalContextIndex-- ) { - string targetTrigger = cur_hypo.GetWord(globalContextIndex).GetString(0); // TODO: change for other factors + StringPiece targetTrigger = cur_hypo.GetWord(globalContextIndex).GetString(0); // TODO: change for other factors bool targetTriggerExists = false; if (!m_unrestricted) - targetTriggerExists = m_vocabTarget.find( targetTrigger ) != m_vocabTarget.end(); + targetTriggerExists = FindStringPiece(m_vocabTarget, targetTrigger) != m_vocabTarget.end(); if (m_unrestricted || targetTriggerExists) - AddFeature(accumulator, alreadyScored, sourceContext, sourceString, + AddFeature(accumulator, sourceContext, sourceString, targetTrigger, targetString); } } @@ -195,19 +198,19 @@ void GlobalLexicalModelUnlimited::Evaluate(const Hypothesis& cur_hypo, ScoreComp bool targetTriggerExists = true; if (m_unrestricted || (sourceTriggerExists && targetTriggerExists)) - AddFeature(accumulator, alreadyScored, sourceTrigger, sourceString, + AddFeature(accumulator, sourceTrigger, sourceString, targetTrigger, targetString); } else { // iterate backwards over target for(int globalContextIndex = globalTargetIndex-1; globalContextIndex >= 0; globalContextIndex-- ) { - string targetTrigger = cur_hypo.GetWord(globalContextIndex).GetString(0); // TODO: change for other factors + StringPiece targetTrigger = cur_hypo.GetWord(globalContextIndex).GetString(0); // TODO: change for other factors bool targetTriggerExists = false; if (!m_unrestricted) - targetTriggerExists = m_vocabTarget.find( targetTrigger ) != m_vocabTarget.end(); + targetTriggerExists = FindStringPiece(m_vocabTarget, targetTrigger) != m_vocabTarget.end(); if (m_unrestricted || (sourceTriggerExists && targetTriggerExists)) - AddFeature(accumulator, alreadyScored, sourceTrigger, sourceString, + AddFeature(accumulator, sourceTrigger, sourceString, targetTrigger, targetString); } } @@ -216,29 +219,29 @@ void GlobalLexicalModelUnlimited::Evaluate(const Hypothesis& cur_hypo, ScoreComp else { // iterate backwards over source for(int contextIndex = sourceIndex-1; contextIndex >= 0; contextIndex-- ) { - string sourceTrigger = input.GetWord(contextIndex).GetString(0); // TODO: change for other factors + StringPiece sourceTrigger = input.GetWord(contextIndex).GetString(0); // TODO: change for other factors bool sourceTriggerExists = false; if (!m_unrestricted) - sourceTriggerExists = m_vocabSource.find( sourceTrigger ) != m_vocabSource.end(); + sourceTriggerExists = FindStringPiece(m_vocabSource, sourceTrigger) != m_vocabSource.end(); if (globalTargetIndex == 0) { string targetTrigger = "<s>"; bool targetTriggerExists = true; if (m_unrestricted || (sourceTriggerExists && targetTriggerExists)) - AddFeature(accumulator, alreadyScored, sourceTrigger, sourceString, + AddFeature(accumulator, sourceTrigger, sourceString, targetTrigger, targetString); } else { // iterate backwards over target for(int globalContextIndex = globalTargetIndex-1; globalContextIndex >= 0; globalContextIndex-- ) { - string targetTrigger = cur_hypo.GetWord(globalContextIndex).GetString(0); // TODO: change for other factors + StringPiece targetTrigger = cur_hypo.GetWord(globalContextIndex).GetString(0); // TODO: change for other factors bool targetTriggerExists = false; if (!m_unrestricted) - targetTriggerExists = m_vocabTarget.find( targetTrigger ) != m_vocabTarget.end(); + targetTriggerExists = FindStringPiece(m_vocabTarget, targetTrigger) != m_vocabTarget.end(); if (m_unrestricted || (sourceTriggerExists && targetTriggerExists)) - AddFeature(accumulator, alreadyScored, sourceTrigger, sourceString, + AddFeature(accumulator, sourceTrigger, sourceString, targetTrigger, targetString); } } @@ -252,8 +255,7 @@ void GlobalLexicalModelUnlimited::Evaluate(const Hypothesis& cur_hypo, ScoreComp feature << "~"; feature << sourceString; accumulator->SparsePlusEquals(feature.str(), 1); - //alreadyScored.insert( &inputWord ); - alreadyScored[sourceString] = 1; + alreadyScored.insert(sourceHash); } } } @@ -262,8 +264,8 @@ void GlobalLexicalModelUnlimited::Evaluate(const Hypothesis& cur_hypo, ScoreComp } void GlobalLexicalModelUnlimited::AddFeature(ScoreComponentCollection* accumulator, - StringHash alreadyScored, string sourceTrigger, string sourceWord, string targetTrigger, - string targetWord) const { + StringPiece sourceTrigger, StringPiece sourceWord, StringPiece targetTrigger, + StringPiece targetWord) const { stringstream feature; feature << "glm_"; feature << targetTrigger; @@ -274,7 +276,8 @@ void GlobalLexicalModelUnlimited::AddFeature(ScoreComponentCollection* accumulat feature << ","; feature << sourceWord; accumulator->SparsePlusEquals(feature.str(), 1); - alreadyScored[sourceWord] = 1; + // BUG(ehasler): this did nothing because alreadyScored was passed by value not reference. + //alreadyScored[sourceWord] = 1; } } diff --git a/moses/GlobalLexicalModelUnlimited.h b/moses/GlobalLexicalModelUnlimited.h index 307461db0..2358e1d19 100644 --- a/moses/GlobalLexicalModelUnlimited.h +++ b/moses/GlobalLexicalModelUnlimited.h @@ -12,9 +12,12 @@ #include "FeatureFunction.h" #include "FactorTypeSet.h" #include "Sentence.h" - #include "FFState.h" +#include "util/string_piece.hh" +#include <boost/unordered_set.hpp> +#include <boost/unordered_map.hpp> + #ifdef WITH_THREADS #include <boost/thread/tss.hpp> #endif @@ -35,8 +38,8 @@ class InputType; class GlobalLexicalModelUnlimited : public StatelessFeatureFunction { + // TODO(ehasler): This should be an array of size 256. typedef std::map< char, short > CharHash; - typedef std::map< std::string, short > StringHash; struct ThreadLocalStorage { @@ -64,8 +67,8 @@ private: float m_sparseProducerWeight; bool m_ignorePunctuation; - std::set<std::string> m_vocabSource; - std::set<std::string> m_vocabTarget; + boost::unordered_set<std::string> m_vocabSource; + boost::unordered_set<std::string> m_vocabTarget; public: GlobalLexicalModelUnlimited(const std::vector< FactorType >& inFactors, const std::vector< FactorType >& outFactors, @@ -137,9 +140,9 @@ public: void SetSparseProducerWeight(float weight) { m_sparseProducerWeight = weight; } float GetSparseProducerWeight() const { return m_sparseProducerWeight; } - void AddFeature(ScoreComponentCollection* accumulator, StringHash alreadyScored, - std::string sourceTrigger, std::string sourceWord, std::string targetTrigger, - std::string targetWord) const; + void AddFeature(ScoreComponentCollection* accumulator, + StringPiece sourceTrigger, StringPiece sourceWord, StringPiece targetTrigger, + StringPiece targetWord) const; }; } diff --git a/moses/LM/IRST.cpp b/moses/LM/IRST.cpp index 2748fa1ba..a477fe20d 100644 --- a/moses/LM/IRST.cpp +++ b/moses/LM/IRST.cpp @@ -154,7 +154,8 @@ int LanguageModelIRST::GetLmID( const Factor *factor ) const if ((factorId >= m_lmIdLookup.size()) || (m_lmIdLookup[factorId] == m_empty)) { if (d->incflag()==1) { - std::string s = factor->GetString(); + const StringPiece &f = factor->GetString(); + std::string s(f.data(), f.size()); int code = d->encode(s.c_str()); ////////// diff --git a/moses/Phrase.cpp b/moses/Phrase.cpp index 2e020ef69..7981abdfd 100644 --- a/moses/Phrase.cpp +++ b/moses/Phrase.cpp @@ -265,9 +265,8 @@ bool Phrase::Contains(const vector< vector<string> > &subPhraseVector FactorType factorType = inputFactor[currFactorIndex]; for (size_t currSubPos = 0 ; currSubPos < subSize ; currSubPos++) { size_t currThisPos = currSubPos + currStartPos; - const string &subStr = subPhraseVector[currSubPos][currFactorIndex] - ,&thisStr = GetFactor(currThisPos, factorType)->GetString(); - if (subStr != thisStr) { + const string &subStr = subPhraseVector[currSubPos][currFactorIndex]; + if (subStr != GetFactor(currThisPos, factorType)->GetString()) { match = false; break; } diff --git a/moses/PhrasePairFeature.cpp b/moses/PhrasePairFeature.cpp index 020292748..ba0a7343d 100644 --- a/moses/PhrasePairFeature.cpp +++ b/moses/PhrasePairFeature.cpp @@ -5,7 +5,7 @@ #include "TargetPhrase.h" #include "Hypothesis.h" #include "TranslationOption.h" -#include <boost/algorithm/string.hpp> +#include "util/string_piece_hash.hh" using namespace std; @@ -182,10 +182,10 @@ void PhrasePairFeature::Evaluate( // range over source words to get context for(size_t contextIndex = 0; contextIndex < input.GetSize(); contextIndex++ ) { - string sourceTrigger = input.GetWord(contextIndex).GetFactor(m_sourceFactorId)->GetString(); + StringPiece sourceTrigger = input.GetWord(contextIndex).GetFactor(m_sourceFactorId)->GetString(); if (m_ignorePunctuation) { // check if trigger is punctuation - char firstChar = sourceTrigger.at(0); + char firstChar = sourceTrigger.data()[0]; CharHash::const_iterator charIterator = m_punctuationHash.find( firstChar ); if(charIterator != m_punctuationHash.end()) continue; @@ -193,7 +193,7 @@ void PhrasePairFeature::Evaluate( bool sourceTriggerExists = false; if (!m_unrestricted) - sourceTriggerExists = m_vocabSource.find( sourceTrigger ) != m_vocabSource.end(); + sourceTriggerExists = (FindStringPiece(m_vocabSource, sourceTrigger) != m_vocabSource.end()); if (m_unrestricted || sourceTriggerExists) { ostringstream namestr; diff --git a/moses/PhrasePairFeature.h b/moses/PhrasePairFeature.h index d7cc3ea48..ac51aa61e 100644 --- a/moses/PhrasePairFeature.h +++ b/moses/PhrasePairFeature.h @@ -1,6 +1,7 @@ #ifndef moses_PhrasePairFeature_h #define moses_PhrasePairFeature_h +#include <boost/unordered_set.hpp> #include <stdexcept> #include "Factor.h" @@ -13,11 +14,11 @@ namespace Moses { * Phrase pair feature: complete source/target phrase pair **/ class PhrasePairFeature: public StatelessFeatureFunction { - + // TODO(ehasler): This should be bool ispunct[256]; typedef std::map< char, short > CharHash; typedef std::vector< std::set<std::string> > DocumentVector; - std::set<std::string> m_vocabSource; + boost::unordered_set<std::string> m_vocabSource; //std::set<std::string> m_vocabTarget; DocumentVector m_vocabDomain; FactorType m_sourceFactorId; diff --git a/moses/ScoreComponentCollection.h b/moses/ScoreComponentCollection.h index be23e03fd..32b8cb1c9 100644 --- a/moses/ScoreComponentCollection.h +++ b/moses/ScoreComponentCollection.h @@ -200,7 +200,7 @@ public: } //For features which have an unbounded number of components - void SparseMinusEquals(const std::string& full_name, float score) + void SparseMinusEquals(const StringPiece &full_name, float score) { FName fname(full_name); m_scores[fname] -= score; @@ -240,7 +240,7 @@ public: } //For features which have an unbounded number of components - void PlusEquals(const ScoreProducer*sp, const std::string& name, float score) + void PlusEquals(const ScoreProducer*sp, const StringPiece &name, float score) { CHECK(sp->GetNumScoreComponents() == ScoreProducer::unlimited); FName fname(sp->GetScoreProducerDescription(),name); @@ -248,7 +248,7 @@ public: } //For features which have an unbounded number of components - void SparsePlusEquals(const std::string& full_name, float score) + void SparsePlusEquals(const StringPiece &full_name, float score) { FName fname(full_name); m_scores[fname] += score; diff --git a/moses/SourceWordDeletionFeature.cpp b/moses/SourceWordDeletionFeature.cpp index c312a3b03..082e0900b 100644 --- a/moses/SourceWordDeletionFeature.cpp +++ b/moses/SourceWordDeletionFeature.cpp @@ -6,6 +6,7 @@ #include "ChartHypothesis.h" #include "ScoreComponentCollection.h" #include "TranslationOption.h" +#include "util/string_piece_hash.hh" namespace Moses { @@ -70,9 +71,9 @@ void SourceWordDeletionFeature::ComputeFeatures(const TargetPhrase& targetPhrase if (!aligned[i]) { Word w = targetPhrase.GetSourcePhrase().GetWord(i); if (!w.IsNonTerminal()) { - const string &word = w.GetFactor(m_factorType)->GetString(); + const StringPiece &word = w.GetFactor(m_factorType)->GetString(); if (word != "<s>" && word != "</s>") { - if (!m_unrestricted && m_vocab.find( word ) == m_vocab.end()) { + if (!m_unrestricted && FindStringPiece(m_vocab, word) == m_vocab.end()) { accumulator->PlusEquals(this,"OTHER",1); } else { diff --git a/moses/SourceWordDeletionFeature.h b/moses/SourceWordDeletionFeature.h index d34aa92f5..b503b4670 100644 --- a/moses/SourceWordDeletionFeature.h +++ b/moses/SourceWordDeletionFeature.h @@ -2,7 +2,7 @@ #define moses_SourceWordDeletionFeature_h #include <string> -#include <map> +#include <boost/unordered_set.hpp> #include "FeatureFunction.h" #include "FactorCollection.h" @@ -15,7 +15,7 @@ namespace Moses */ class SourceWordDeletionFeature : public StatelessFeatureFunction { private: - std::set<std::string> m_vocab; + boost::unordered_set<std::string> m_vocab; FactorType m_factorType; bool m_unrestricted; diff --git a/moses/TargetBigramFeature.cpp b/moses/TargetBigramFeature.cpp index a9ad2216b..64942e947 100644 --- a/moses/TargetBigramFeature.cpp +++ b/moses/TargetBigramFeature.cpp @@ -3,6 +3,7 @@ #include "TargetPhrase.h" #include "Hypothesis.h" #include "ScoreComponentCollection.h" +#include "util/string_piece_hash.hh" namespace Moses { @@ -71,24 +72,26 @@ FFState* TargetBigramFeature::Evaluate(const Hypothesis& cur_hypo, f1 = targetPhrase.GetWord(i-1).GetFactor(m_factorType); } const Factor* f2 = targetPhrase.GetWord(i).GetFactor(m_factorType); - const string& w1 = f1->GetString(); - const string& w2 = f2->GetString(); + StringPiece w1(f1->GetString()), w2(f2->GetString()); // skip bigrams if they don't belong to a given restricted vocabulary - if (m_vocab.size() && - (m_vocab.find(w1) == m_vocab.end() || m_vocab.find(w2) == m_vocab.end())) { + if (m_vocab.size() && + (FindStringPiece(m_vocab, w1) == m_vocab.end() || FindStringPiece(m_vocab, w2) == m_vocab.end())) { continue; } - - string name(w1 +":"+w2); + string name(w1.data(), w1.size()); + name += ':'; + name.append(w2.data(), w2.size()); accumulator->PlusEquals(this,name,1); } if (cur_hypo.GetWordsBitmap().IsComplete()) { - const string& w1 = targetPhrase.GetWord(targetPhrase.GetSize()-1).GetFactor(m_factorType)->GetString(); + StringPiece w1(targetPhrase.GetWord(targetPhrase.GetSize()-1).GetFactor(m_factorType)->GetString()); const string& w2 = EOS_; - if (m_vocab.empty() || (m_vocab.find(w1) != m_vocab.end())) { - string name(w1 +":"+w2); + if (m_vocab.empty() || (FindStringPiece(m_vocab, w1) != m_vocab.end())) { + string name(w1.data(), w1.size()); + name += ':'; + name += w2; accumulator->PlusEquals(this,name,1); } return NULL; diff --git a/moses/TargetBigramFeature.h b/moses/TargetBigramFeature.h index 76b4f6ef7..50bfc8e2c 100644 --- a/moses/TargetBigramFeature.h +++ b/moses/TargetBigramFeature.h @@ -4,6 +4,8 @@ #include <string> #include <map> +#include <boost/unordered_set.hpp> + #include "FactorCollection.h" #include "FeatureFunction.h" #include "FFState.h" @@ -56,7 +58,7 @@ public: private: FactorType m_factorType; Word m_bos; - std::set<std::string> m_vocab; + boost::unordered_set<std::string> m_vocab; }; } diff --git a/moses/TargetNgramFeature.cpp b/moses/TargetNgramFeature.cpp index 7973cedce..24a484f59 100644 --- a/moses/TargetNgramFeature.cpp +++ b/moses/TargetNgramFeature.cpp @@ -5,6 +5,8 @@ #include "ScoreComponentCollection.h" #include "ChartHypothesis.h" +#include "util/string_piece_hash.hh" + namespace Moses { using namespace std; @@ -94,9 +96,9 @@ FFState* TargetNgramFeature::Evaluate(const Hypothesis& cur_hypo, for (size_t n = m_n; n >= smallest_n; --n) { // iterate over ngram size for (size_t i = 0; i < targetPhrase.GetSize(); ++i) { // const string& curr_w = targetPhrase.GetWord(i).GetFactor(m_factorType)->GetString(); - const string& curr_w = targetPhrase.GetWord(i).GetString(m_factorType); + const StringPiece& curr_w = targetPhrase.GetWord(i).GetString(m_factorType); - if (m_vocab.size() && (m_vocab.find(curr_w) == m_vocab.end())) continue; // skip ngrams + if (m_vocab.size() && (FindStringPiece(m_vocab, curr_w) == m_vocab.end())) continue; // skip ngrams if (n > 1) { // can we build an ngram at this position? ("<s> this" --> cannot build 3gram at this position) @@ -172,8 +174,8 @@ FFState* TargetNgramFeature::Evaluate(const Hypothesis& cur_hypo, void TargetNgramFeature::appendNgram(const Word& word, bool& skip, stringstream &ngram) const { // const string& w = word.GetFactor(m_factorType)->GetString(); - const string& w = word.GetString(m_factorType); - if (m_vocab.size() && (m_vocab.find(w) == m_vocab.end())) skip = true; + const StringPiece& w = word.GetString(m_factorType); + if (m_vocab.size() && (FindStringPiece(m_vocab, w) == m_vocab.end())) skip = true; else { ngram << w; ngram << ":"; @@ -215,7 +217,7 @@ FFState* TargetNgramFeature::EvaluateChart(const ChartHypothesis& cur_hypo, int makeSuffix = true; // beginning/end of sentence symbol <s>,</s>? - string factorZero = word.GetString(0); + StringPiece factorZero = word.GetString(0); if (factorZero.compare("<s>") == 0) prefixTerminals++; // end of sentence symbol </s>? @@ -396,7 +398,7 @@ void TargetNgramFeature::MakePrefixNgrams(std::vector<const Word*> &contextFacto for (size_t i=k+offset; i <= end_pos; ++i) { if (i > k+offset) ngram << ":"; - string factorZero = (*contextFactor[i]).GetString(0); + StringPiece factorZero = (*contextFactor[i]).GetString(0); if (m_factorType == 0 || factorZero.compare("<s>") == 0 || factorZero.compare("</s>") == 0) ngram << factorZero; else @@ -417,7 +419,7 @@ void TargetNgramFeature::MakeSuffixNgrams(std::vector<const Word*> &contextFacto for (int start_pos=end_pos-1; (start_pos >= 0) && (end_pos-start_pos < m_n); --start_pos) { ngram << m_baseName; for (size_t j=start_pos; j <= end_pos; ++j){ - string factorZero = (*contextFactor[j]).GetString(0); + StringPiece factorZero = (*contextFactor[j]).GetString(0); if (m_factorType == 0 || factorZero.compare("<s>") == 0 || factorZero.compare("</s>") == 0) ngram << factorZero; else diff --git a/moses/TargetNgramFeature.h b/moses/TargetNgramFeature.h index 0aa98be7d..c26198b2a 100644 --- a/moses/TargetNgramFeature.h +++ b/moses/TargetNgramFeature.h @@ -13,6 +13,8 @@ #include "ChartHypothesis.h" #include "ChartManager.h" +#include <boost/unordered_set.hpp> + namespace Moses { @@ -213,7 +215,7 @@ public: private: FactorType m_factorType; Word m_bos; - std::set<std::string> m_vocab; + boost::unordered_set<std::string> m_vocab; size_t m_n; bool m_lower_ngrams; diff --git a/moses/TargetWordInsertionFeature.cpp b/moses/TargetWordInsertionFeature.cpp index 3b9bf36ba..4420bd7a5 100644 --- a/moses/TargetWordInsertionFeature.cpp +++ b/moses/TargetWordInsertionFeature.cpp @@ -6,6 +6,7 @@ #include "ChartHypothesis.h" #include "ScoreComponentCollection.h" #include "TranslationOption.h" +#include "util/string_piece_hash.hh" namespace Moses { @@ -73,9 +74,9 @@ void TargetWordInsertionFeature::ComputeFeatures(const TargetPhrase& targetPhras if (!aligned[i]) { Word w = targetPhrase.GetWord(i); if (!w.IsNonTerminal()) { - const string &word = w.GetFactor(m_factorType)->GetString(); + const StringPiece &word = w.GetFactor(m_factorType)->GetString(); if (word != "<s>" && word != "</s>") { - if (!m_unrestricted && m_vocab.find( word ) == m_vocab.end()) { + if (!m_unrestricted && FindStringPiece(m_vocab, word) == m_vocab.end()) { accumulator->PlusEquals(this,"OTHER",1); } else { diff --git a/moses/TargetWordInsertionFeature.h b/moses/TargetWordInsertionFeature.h index a7a149db6..7a1e3770a 100644 --- a/moses/TargetWordInsertionFeature.h +++ b/moses/TargetWordInsertionFeature.h @@ -2,7 +2,7 @@ #define moses_TargetWordInsertionFeature_h #include <string> -#include <map> +#include <boost/unordered_set.hpp> #include "FeatureFunction.h" #include "FactorCollection.h" @@ -15,7 +15,7 @@ namespace Moses */ class TargetWordInsertionFeature : public StatelessFeatureFunction { private: - std::set<std::string> m_vocab; + boost::unordered_set<std::string> m_vocab; FactorType m_factorType; bool m_unrestricted; diff --git a/moses/TranslationOptionCollection.cpp b/moses/TranslationOptionCollection.cpp index 3d553a458..553e68aa7 100644 --- a/moses/TranslationOptionCollection.cpp +++ b/moses/TranslationOptionCollection.cpp @@ -207,7 +207,7 @@ void TranslationOptionCollection::ProcessOneUnknownWord(const Word &sourceWord,s size_t isDigit = 0; const Factor *f = sourceWord[0]; // TODO hack. shouldn't know which factor is surface - const string &s = f->GetString(); + const StringPiece &s = f->GetString(); bool isEpsilon = (s=="" || s==EPSILON); if (StaticData::Instance().GetDropUnknown()) { diff --git a/moses/Word.cpp b/moses/Word.cpp index 2c1ac09ea..69d382c8a 100644 --- a/moses/Word.cpp +++ b/moses/Word.cpp @@ -87,13 +87,8 @@ std::string Word::GetString(const vector<FactorType> factorType,bool endWithBlan return strme.str(); } -std::string Word::GetString(FactorType factorType) const -{ - const Factor *factor = m_factorArray[factorType]; - if (factor != NULL) - return factor->GetString(); - else - return NULL; +StringPiece Word::GetString(FactorType factorType) const { + return m_factorArray[factorType]->GetString(); } class StrayFactorException : public util::Exception {}; diff --git a/moses/Word.h b/moses/Word.h index 70875d75c..d650fb67e 100644 --- a/moses/Word.h +++ b/moses/Word.h @@ -102,7 +102,7 @@ public: * these debugging functions. */ std::string GetString(const std::vector<FactorType> factorType,bool endWithBlank) const; - std::string GetString(FactorType factorType) const; + StringPiece GetString(FactorType factorType) const; TO_STRING(); //! transitive comparison of Word objects diff --git a/moses/WordTranslationFeature.cpp b/moses/WordTranslationFeature.cpp index 6fd5040d6..908274c2b 100644 --- a/moses/WordTranslationFeature.cpp +++ b/moses/WordTranslationFeature.cpp @@ -7,7 +7,7 @@ #include "ChartHypothesis.h" #include "ScoreComponentCollection.h" #include "TranslationOption.h" -#include <boost/algorithm/string.hpp> +#include "util/string_piece_hash.hh" namespace Moses { @@ -25,14 +25,11 @@ bool WordTranslationFeature::Load(const std::string &filePathSource, const std:: std::string line; while (getline(inFileSource, line)) { - std::set<std::string> terms; + m_vocabDomain.resize(m_vocabDomain.size() + 1); vector<string> termVector; boost::split(termVector, line, boost::is_any_of("\t ")); for (size_t i=0; i < termVector.size(); ++i) - terms.insert(termVector[i]); - - // add term set for current document - m_vocabDomain.push_back(terms); + m_vocabDomain.back().insert(termVector[i]); } inFileSource.close(); @@ -89,24 +86,24 @@ void WordTranslationFeature::Evaluate if (m_factorTypeSource == 0 && ws.IsNonTerminal()) continue; Word wt = targetPhrase.GetWord(targetIndex); if (m_factorTypeSource == 0 && wt.IsNonTerminal()) continue; - string sourceWord = ws.GetFactor(m_factorTypeSource)->GetString(); - string targetWord = wt.GetFactor(m_factorTypeTarget)->GetString(); + StringPiece sourceWord = ws.GetFactor(m_factorTypeSource)->GetString(); + StringPiece targetWord = wt.GetFactor(m_factorTypeTarget)->GetString(); if (m_ignorePunctuation) { // check if source or target are punctuation - char firstChar = sourceWord.at(0); + char firstChar = sourceWord.data()[0]; CharHash::const_iterator charIterator = m_punctuationHash.find( firstChar ); if(charIterator != m_punctuationHash.end()) continue; - firstChar = targetWord.at(0); + firstChar = targetWord.data()[0]; charIterator = m_punctuationHash.find( firstChar ); if(charIterator != m_punctuationHash.end()) continue; } if (!m_unrestricted) { - if (m_vocabSource.find(sourceWord) == m_vocabSource.end()) + if (FindStringPiece(m_vocabSource, sourceWord) == m_vocabSource.end()) sourceWord = "OTHER"; - if (m_vocabTarget.find(targetWord) == m_vocabTarget.end()) + if (FindStringPiece(m_vocabTarget, targetWord) == m_vocabTarget.end()) targetWord = "OTHER"; } @@ -167,7 +164,7 @@ void WordTranslationFeature::Evaluate else { // range over domain trigger words (keywords) const long docid = input.GetDocumentId(); - for (set<string>::const_iterator p = m_vocabDomain[docid].begin(); p != m_vocabDomain[docid].end(); ++p) { + for (boost::unordered_set<string>::const_iterator p = m_vocabDomain[docid].begin(); p != m_vocabDomain[docid].end(); ++p) { string sourceTrigger = *p; stringstream feature; feature << "wt_"; @@ -196,10 +193,10 @@ void WordTranslationFeature::Evaluate // range over source words to get context for(size_t contextIndex = 0; contextIndex < input.GetSize(); contextIndex++ ) { if (contextIndex == globalSourceIndex) continue; - string sourceTrigger = input.GetWord(contextIndex).GetFactor(m_factorTypeSource)->GetString(); + StringPiece sourceTrigger = input.GetWord(contextIndex).GetFactor(m_factorTypeSource)->GetString(); if (m_ignorePunctuation) { // check if trigger is punctuation - char firstChar = sourceTrigger.at(0); + char firstChar = sourceTrigger.data()[0]; CharHash::const_iterator charIterator = m_punctuationHash.find( firstChar ); if(charIterator != m_punctuationHash.end()) continue; @@ -208,9 +205,9 @@ void WordTranslationFeature::Evaluate const long docid = input.GetDocumentId(); bool sourceTriggerExists = false; if (m_domainTrigger) - sourceTriggerExists = m_vocabDomain[docid].find( sourceTrigger ) != m_vocabDomain[docid].end(); + sourceTriggerExists = FindStringPiece(m_vocabDomain[docid], sourceTrigger) != m_vocabDomain[docid].end(); else if (!m_unrestricted) - sourceTriggerExists = m_vocabSource.find( sourceTrigger ) != m_vocabSource.end(); + sourceTriggerExists = FindStringPiece(m_vocabSource, sourceTrigger) != m_vocabSource.end(); if (m_domainTrigger) { if (sourceTriggerExists) { @@ -304,24 +301,24 @@ void WordTranslationFeature::EvaluateChart( if (m_factorTypeSource == 0 && ws.IsNonTerminal()) continue; Word wt = targetPhrase.GetWord(targetIndex); if (m_factorTypeSource == 0 && wt.IsNonTerminal()) continue; - string sourceWord = ws.GetFactor(m_factorTypeSource)->GetString(); - string targetWord = wt.GetFactor(m_factorTypeTarget)->GetString(); + StringPiece sourceWord = ws.GetFactor(m_factorTypeSource)->GetString(); + StringPiece targetWord = wt.GetFactor(m_factorTypeTarget)->GetString(); if (m_ignorePunctuation) { // check if source or target are punctuation - char firstChar = sourceWord.at(0); + char firstChar = sourceWord[0]; CharHash::const_iterator charIterator = m_punctuationHash.find( firstChar ); if(charIterator != m_punctuationHash.end()) continue; - firstChar = targetWord.at(0); + firstChar = targetWord[0]; charIterator = m_punctuationHash.find( firstChar ); if(charIterator != m_punctuationHash.end()) continue; } if (!m_unrestricted) { - if (m_vocabSource.find(sourceWord) == m_vocabSource.end()) + if (FindStringPiece(m_vocabSource, sourceWord) == m_vocabSource.end()) sourceWord = "OTHER"; - if (m_vocabTarget.find(targetWord) == m_vocabTarget.end()) + if (FindStringPiece(m_vocabTarget, targetWord) == m_vocabTarget.end()) targetWord = "OTHER"; } diff --git a/moses/WordTranslationFeature.h b/moses/WordTranslationFeature.h index 7f74ae4e3..bac948219 100644 --- a/moses/WordTranslationFeature.h +++ b/moses/WordTranslationFeature.h @@ -2,7 +2,7 @@ #define moses_WordTranslationFeature_h #include <string> -#include <map> +#include <boost/unordered_set.hpp> #include "FeatureFunction.h" #include "FactorCollection.h" @@ -18,11 +18,11 @@ namespace Moses class WordTranslationFeature : public StatelessFeatureFunction { typedef std::map< char, short > CharHash; - typedef std::vector< std::set<std::string> > DocumentVector; + typedef std::vector< boost::unordered_set<std::string> > DocumentVector; private: - std::set<std::string> m_vocabSource; - std::set<std::string> m_vocabTarget; + boost::unordered_set<std::string> m_vocabSource; + boost::unordered_set<std::string> m_vocabTarget; DocumentVector m_vocabDomain; FactorType m_factorTypeSource; FactorType m_factorTypeTarget; |