Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'src/translator/translator.h')
-rwxr-xr-xsrc/translator/translator.h36
1 files changed, 33 insertions, 3 deletions
diff --git a/src/translator/translator.h b/src/translator/translator.h
index cc68a4f0..15eb9870 100755
--- a/src/translator/translator.h
+++ b/src/translator/translator.h
@@ -1,5 +1,7 @@
#pragma once
+#include <string>
+
#include "data/batch_generator.h"
#include "data/corpus.h"
#include "data/shortlist.h"
@@ -245,10 +247,14 @@ public:
}
std::string run(const std::string& input) override {
- auto corpus_ = New<data::TextInput>(std::vector<std::string>({input}), srcVocabs_, options_);
+ // split tab-separated input into fields if necessary
+ auto inputs = options_->get<bool>("tsv", false)
+ ? convertTsvToLists(input, options_->get<size_t>("tsv-fields", 1))
+ : std::vector<std::string>({input});
+ auto corpus_ = New<data::TextInput>(inputs, srcVocabs_, options_);
data::BatchGenerator<data::TextInput> batchGenerator(corpus_, options_);
- auto collector = New<StringCollector>();
+ auto collector = New<StringCollector>(options_->get<bool>("quiet-translation", false));
auto printer = New<OutputPrinter>(options_, trgVocab_);
size_t batchId = 0;
@@ -258,7 +264,6 @@ public:
ThreadPool threadPool_(numDevices_, numDevices_);
for(auto batch : batchGenerator) {
-
auto task = [=](size_t id) {
thread_local Ptr<ExpressionGraph> graph;
thread_local std::vector<Ptr<Scorer>> scorers;
@@ -287,5 +292,30 @@ public:
auto translations = collector->collect(options_->get<bool>("n-best"));
return utils::join(translations, "\n");
}
+
+private:
+ // Converts a multi-line input with tab-separated source(s) and target sentences into separate lists
+ // of sentences from source(s) and target sides, e.g.
+ // "src1 \t trg1 \n src2 \t trg2" -> ["src1 \n src2", "trg1 \n trg2"]
+ std::vector<std::string> convertTsvToLists(const std::string& inputText, size_t numFields) {
+ std::vector<std::string> outputFields(numFields);
+
+ std::string line;
+ std::vector<std::string> lineFields(numFields);
+ std::istringstream inputStream(inputText);
+ bool first = true;
+ while(std::getline(inputStream, line)) {
+ utils::splitTsv(line, lineFields, numFields);
+ for(size_t i = 0; i < numFields; ++i) {
+ if(!first)
+ outputFields[i] += "\n"; // join sentences with a new line sign
+ outputFields[i] += lineFields[i];
+ }
+ if(first)
+ first = false;
+ }
+
+ return outputFields;
+ }
};
} // namespace marian