diff options
Diffstat (limited to 'src/translator/translator.h')
-rwxr-xr-x | src/translator/translator.h | 36 |
1 files changed, 33 insertions, 3 deletions
diff --git a/src/translator/translator.h b/src/translator/translator.h index cc68a4f0..15eb9870 100755 --- a/src/translator/translator.h +++ b/src/translator/translator.h @@ -1,5 +1,7 @@ #pragma once +#include <string> + #include "data/batch_generator.h" #include "data/corpus.h" #include "data/shortlist.h" @@ -245,10 +247,14 @@ public: } std::string run(const std::string& input) override { - auto corpus_ = New<data::TextInput>(std::vector<std::string>({input}), srcVocabs_, options_); + // split tab-separated input into fields if necessary + auto inputs = options_->get<bool>("tsv", false) + ? convertTsvToLists(input, options_->get<size_t>("tsv-fields", 1)) + : std::vector<std::string>({input}); + auto corpus_ = New<data::TextInput>(inputs, srcVocabs_, options_); data::BatchGenerator<data::TextInput> batchGenerator(corpus_, options_); - auto collector = New<StringCollector>(); + auto collector = New<StringCollector>(options_->get<bool>("quiet-translation", false)); auto printer = New<OutputPrinter>(options_, trgVocab_); size_t batchId = 0; @@ -258,7 +264,6 @@ public: ThreadPool threadPool_(numDevices_, numDevices_); for(auto batch : batchGenerator) { - auto task = [=](size_t id) { thread_local Ptr<ExpressionGraph> graph; thread_local std::vector<Ptr<Scorer>> scorers; @@ -287,5 +292,30 @@ public: auto translations = collector->collect(options_->get<bool>("n-best")); return utils::join(translations, "\n"); } + +private: + // Converts a multi-line input with tab-separated source(s) and target sentences into separate lists + // of sentences from source(s) and target sides, e.g. + // "src1 \t trg1 \n src2 \t trg2" -> ["src1 \n src2", "trg1 \n trg2"] + std::vector<std::string> convertTsvToLists(const std::string& inputText, size_t numFields) { + std::vector<std::string> outputFields(numFields); + + std::string line; + std::vector<std::string> lineFields(numFields); + std::istringstream inputStream(inputText); + bool first = true; + while(std::getline(inputStream, line)) { + utils::splitTsv(line, lineFields, numFields); + for(size_t i = 0; i < numFields; ++i) { + if(!first) + outputFields[i] += "\n"; // join sentences with a new line sign + outputFields[i] += lineFields[i]; + } + if(first) + first = false; + } + + return outputFields; + } }; } // namespace marian |