diff options
author | Ulrich Germann <ugermann@inf.ed.ac.uk> | 2015-05-11 02:34:24 +0300 |
---|---|---|
committer | Ulrich Germann <ugermann@inf.ed.ac.uk> | 2015-05-11 02:34:24 +0300 |
commit | 7da7ce52dab34fda4f2fcae2b37a7696eb773b2c (patch) | |
tree | 05e607e4d26e2dbb819a2214eb16406cace77500 /moses/IOWrapper.cpp | |
parent | db5ccff364508f9348b2d20d6ac57bc9da38bedf (diff) |
Added context buffering in IOWrapper for context-sensitive decoding.
Unfortunately, this seems to slow things down quite a bit.
Diffstat (limited to 'moses/IOWrapper.cpp')
-rw-r--r-- | moses/IOWrapper.cpp | 99 |
1 file changed, 79 insertions, 20 deletions
diff --git a/moses/IOWrapper.cpp b/moses/IOWrapper.cpp index 57717e880..63c8ab5e0 100644 --- a/moses/IOWrapper.cpp +++ b/moses/IOWrapper.cpp @@ -81,13 +81,25 @@ IOWrapper::IOWrapper() ,m_surpressSingleBestOutput(false) + , m_look_ahead(0) + , m_look_back(0) + , m_buffered_ahead(0) + ,spe_src(NULL) ,spe_trg(NULL) ,spe_aln(NULL) { const StaticData &staticData = StaticData::Instance(); + // context buffering for context-sensitive decoding + m_look_ahead = staticData.GetContextParameters().look_ahead; + m_look_back = staticData.GetContextParameters().look_back; + m_inputType = staticData.GetInputType(); + + UTIL_THROW_IF2((m_look_ahead || m_look_back) && m_inputType != SentenceInput, + "Context-sensitive decoding currently works only with sentence input."); + m_currentLine = staticData.GetStartTranslationId(); m_inputFactorOrder = &staticData.GetInputFactorOrder(); @@ -239,40 +251,87 @@ IOWrapper::~IOWrapper() // } boost::shared_ptr<InputType> -IOWrapper::ReadInput() +IOWrapper:: +GetBufferedInput() { - boost::shared_ptr<InputType> source; switch(m_inputType) { - case SentenceInput: - source.reset(new Sentence); - break; - case ConfusionNetworkInput: - source.reset(new ConfusionNet); - break; + case SentenceInput: + return BufferInput<Sentence>(); + case ConfusionNetworkInput: + return BufferInput<ConfusionNet>(); case WordLatticeInput: - source.reset(new WordLattice); - break; + return BufferInput<WordLattice>(); case TreeInputType: - source.reset(new TreeInput); - break; + return BufferInput<TreeInput>(); case TabbedSentenceInput: - source.reset(new TabbedSentence); - break; + return BufferInput<TabbedSentence>(); case ForestInputType: - source.reset(new ForestInput); - break; + return BufferInput<ForestInput>(); default: TRACE_ERR("Unknown input type: " << m_inputType << "\n"); + return boost::shared_ptr<InputType>(); } + +} + +boost::shared_ptr<InputType> +IOWrapper::ReadInput() +{ #ifdef WITH_THREADS boost::lock_guard<boost::mutex> lock(m_lock); #endif - if 
(source->Read(*m_inputStream, *m_inputFactorOrder)) - source->SetTranslationId(m_currentLine++); - else - source.reset(); + boost::shared_ptr<InputType> source = GetBufferedInput(); + if (source) + { + source->SetTranslationId(m_currentLine++); + this->set_context_for(*source); + } + m_past_input.push_back(source); return source; } +void +IOWrapper:: +set_context_for(InputType& source) +{ + boost::shared_ptr<string> context(new string); + list<boost::shared_ptr<InputType> >::iterator m = m_past_input.end(); + // remove obsolete past input from buffer: + if (m_past_input.end() != m_past_input.begin()) + { + for (size_t cnt = 0; cnt < m_look_back && --m != m_past_input.begin(); + cnt += (*m)->GetSize()); + while (m_past_input.begin() != m) m_past_input.pop_front(); + } + // cerr << string(80,'=') << endl; + if (m_past_input.size()) + { + m = m_past_input.begin(); + *context += (*m)->ToString(); + // cerr << (*m)->ToString() << endl; + for (++m; m != m_past_input.end(); ++m) + { + // cerr << "\n" << (*m)->ToString() << endl; + *context += string(" ") + (*m)->ToString(); + } + // cerr << string(80,'-') << endl; + } + // cerr << source.ToString() << endl; + if (m_future_input.size()) + { + // cerr << string(80,'-') << endl; + for (m = m_future_input.begin(); m != m_future_input.end(); ++m) + { + // if (m != m_future_input.begin()) cerr << "\n"; + // cerr << (*m)->ToString() << endl; + if (context->size()) *context += " "; + *context += (*m)->ToString(); + } + } + // cerr << string(80,'=') << endl; + source.SetContext(context); +} + + } // namespace |