Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorMartin Junczys-Dowmunt <Marcin.JunczysDowmunt@microsoft.com>2021-03-26 19:17:12 +0300
committerMartin Junczys-Dowmunt <Marcin.JunczysDowmunt@microsoft.com>2021-03-26 19:17:12 +0300
commit7d1f941242928c976640a20f37e1bd9ac10011e8 (patch)
treea8f895b2d26bc1d947fe8a5fcb215d88a747dd6f /src
parent08bb158974597e92c3b5b0e20d938697bf6146b8 (diff)
Merged PR 18309: Cleaner suppression of unwanted output words
This PR adds cleaner suppression of unwanted output words. We identified a situation where SPM with byte-fallback can generate random bytes with output-sampling. That is particularly harmful when that random bytes happens to be a newline symbol. Here we suppress newline in output unless explicitly wanted.
Diffstat (limited to 'src')
-rw-r--r--src/common/config_parser.cpp2
-rw-r--r--src/data/sentencepiece_vocab.cpp27
-rw-r--r--src/data/shortlist.h7
-rw-r--r--src/data/vocab.cpp21
-rw-r--r--src/data/vocab.h7
-rw-r--r--src/data/vocab_base.h5
-rw-r--r--src/translator/beam_search.cpp31
-rw-r--r--src/translator/helpers.cpp23
-rw-r--r--src/translator/helpers.cu64
-rw-r--r--src/translator/helpers.h11
10 files changed, 135 insertions, 63 deletions
diff --git a/src/common/config_parser.cpp b/src/common/config_parser.cpp
index 602509c5..6495db0e 100644
--- a/src/common/config_parser.cpp
+++ b/src/common/config_parser.cpp
@@ -651,6 +651,8 @@ void ConfigParser::addOptionsTranslation(cli::CLIWrapper& cli) {
"Subtract (arg * translation length) from translation score");
cli.add<bool>("--allow-unk",
"Allow unknown words to appear in output");
+ cli.add<bool>("--allow-special",
+ "Allow special symbols to appear in output, e.g. for SentencePiece with byte-fallback do not suppress the newline symbol");
cli.add<bool>("--n-best",
"Generate n-best list");
cli.add<std::string>("--alignment",
diff --git a/src/data/sentencepiece_vocab.cpp b/src/data/sentencepiece_vocab.cpp
index c168f6e3..090d478b 100644
--- a/src/data/sentencepiece_vocab.cpp
+++ b/src/data/sentencepiece_vocab.cpp
@@ -39,6 +39,20 @@ private:
// Keeps sentences segmented into subword units
bool keepEncoded_{false};
+ // Contains control characters added to vocab due to byte-fallback
+ std::vector<Word> controlChars_;
+
+ // Creates the first 32 control characters as done in byte-fallback and checks if they exist in the vocab.
+ // This makes sure that we do not waste computational effort on suppression if they don't actually appear.
+ void populateControlChars() {
+ for(int i = 0; i < 32; ++i) {
+ std::string bytePiece = fmt::format("<0x{:02X}>", i); // 0 becomes <0x00>, 10 becomes <0x0A>, note uppercase A and lowercase x
+ auto id = spm_->PieceToId(bytePiece);
+ if(id != spm_->unk_id())
+ controlChars_.push_back(Word::fromWordIndex(id));
+ }
+ }
+
// Sample from one file, based on first algorithm from:
// https://en.wikipedia.org/wiki/Reservoir_sampling
void reservoirSampling(std::vector<std::string>& sample, size_t& seenLines,
@@ -262,11 +276,24 @@ public:
"SentencePiece vocabulary error: {}",
status.ToString());
+ populateControlChars();
+
return spm_->GetPieceSize();
}
std::string toUpper(const std::string& line) const override { return utils::utf8ToUpper(line); }
std::string toEnglishTitleCase(const std::string& line) const override { return utils::toEnglishTitleCase(line); }
+
+ // SentencePiece with byte-fallback may generate control symbols with output sampling.
+ // Let's mark them as special and suppress them later on output. This is generally safe
+ // for UTF-8 since control chars are not used as partial bytes in multi-byte sequences.
+ // They only appear in single-byte chars as themselves and this is what we suppress.
+ void addSpecialWords(std::vector<Word>& special) const override {
+ special.reserve(special.size() + controlChars_.size());
+ for(auto c : controlChars_)
+ special.push_back(c);
+ }
+
};
#endif // USE_SENTENCEPIECE
diff --git a/src/data/shortlist.h b/src/data/shortlist.h
index ab6a087b..f0467640 100644
--- a/src/data/shortlist.h
+++ b/src/data/shortlist.h
@@ -13,6 +13,7 @@
#include <vector>
#include <iostream>
#include <algorithm>
+#include <limits>
namespace marian {
namespace data {
@@ -22,18 +23,20 @@ private:
std::vector<WordIndex> indices_; // // [packed shortlist index] -> word index, used to select columns from output embeddings
public:
+ static constexpr WordIndex npos{std::numeric_limits<WordIndex>::max()}; // used to identify invalid shortlist entries similar to std::string::npos
+
Shortlist(const std::vector<WordIndex>& indices)
: indices_(indices) {}
const std::vector<WordIndex>& indices() const { return indices_; }
WordIndex reverseMap(int idx) { return indices_[idx]; }
- int tryForwardMap(WordIndex wIdx) {
+ WordIndex tryForwardMap(WordIndex wIdx) {
auto first = std::lower_bound(indices_.begin(), indices_.end(), wIdx);
if(first != indices_.end() && *first == wIdx) // check if element not less than wIdx has been found and if equal to wIdx
return (int)std::distance(indices_.begin(), first); // return coordinate if found
else
- return -1; // return -1 if not found
+ return npos; // return npos if not found, @TODO: replace with std::optional once we switch to C++17?
}
};
diff --git a/src/data/vocab.cpp b/src/data/vocab.cpp
index 07ac479e..8a3d49c7 100644
--- a/src/data/vocab.cpp
+++ b/src/data/vocab.cpp
@@ -138,6 +138,27 @@ Word Vocab::getEosId() const { return vImpl_->getEosId(); }
// return UNK symbol id
Word Vocab::getUnkId() const { return vImpl_->getUnkId(); }
+std::vector<Word> Vocab::suppressedIds(bool suppressUnk, bool suppressSpecial) const {
+ std::vector<Word> ids;
+ if(suppressUnk) {
+ auto unkId = getUnkId();
+ if(unkId != Word::NONE)
+ ids.push_back(unkId);
+ }
+ if(suppressSpecial)
+ vImpl_->addSpecialWords(/*in/out=*/ids);
+ return ids;
+}
+
+std::vector<WordIndex> Vocab::suppressedIndices(bool suppressUnk, bool suppressSpecial) const {
+ std::vector<WordIndex> indices;
+ for(Word word : suppressedIds(suppressUnk, suppressSpecial))
+ indices.push_back(word.toWordIndex());
+
+ vImpl_->transcodeToShortlistInPlace(indices.data(), indices.size());
+ return indices;
+}
+
// for corpus augmentation: convert string to all-caps
std::string Vocab::toUpper(const std::string& line) const { return vImpl_->toUpper(line); }
diff --git a/src/data/vocab.h b/src/data/vocab.h
index 9a40ba16..2ab6b2b0 100644
--- a/src/data/vocab.h
+++ b/src/data/vocab.h
@@ -70,6 +70,13 @@ public:
// return UNK symbol id
Word getUnkId() const;
+ // return a set of Word ids that should be suppressed based on the underlying vocabulary implementation.
+ // Arguments mosty likely provided based on outside options like --allow-unk etc.
+ std::vector<Word> suppressedIds(bool suppressUnk = true, bool suppressSpecial = true) const;
+
+ // same as suppressedIds but return numeric word indices into the embedding matrices
+ std::vector<WordIndex> suppressedIndices(bool suppressUnk = true, bool suppressSpecial = true) const;
+
// for corpus augmentation: convert string to all-caps
// @TODO: Consider a different implementation where this does not show on the vocab interface,
// but instead as additional options passed to vocab instantiation.
diff --git a/src/data/vocab_base.h b/src/data/vocab_base.h
index 8c214c97..fc512026 100644
--- a/src/data/vocab_base.h
+++ b/src/data/vocab_base.h
@@ -49,9 +49,12 @@ public:
virtual std::string toUpper(const std::string& line) const { return line; }
virtual std::string toEnglishTitleCase(const std::string& line) const { return line; }
- // this function is an identity mapping for default vocabularies, hence do nothing
+ // Identity mapping for default vocabularies, hence do nothing
virtual void transcodeToShortlistInPlace(WordIndex* ptr, size_t num) const { ptr; num; }
+ // Populates vector `special` with special words like "\n" etc.
+ virtual void addSpecialWords(std::vector<Word>& special) const { special; }
+
virtual void createFake() = 0;
virtual Word randWord() const {
diff --git a/src/translator/beam_search.cpp b/src/translator/beam_search.cpp
index 5c1989a6..91dde6e6 100644
--- a/src/translator/beam_search.cpp
+++ b/src/translator/beam_search.cpp
@@ -258,7 +258,6 @@ Histories BeamSearch::search(Ptr<ExpressionGraph> graph, Ptr<data::CorpusBatch>
// We will use the prefix "currentBatch.." whenever we refer to batch dimension that can change due to batch-pruning.
const int origDimBatch = (int)batch->size();
const auto trgEosId = trgVocab_->getEosId();
- const auto trgUnkId = trgVocab_->getUnkId();
auto getNBestList = createGetNBestListFn(beamSize_, origDimBatch, graph->getDeviceId());
@@ -298,13 +297,23 @@ Histories BeamSearch::search(Ptr<ExpressionGraph> graph, Ptr<data::CorpusBatch>
const_cast<std::vector<bool>&>(emptyBatchEntries).push_back(batch->front()->data()[origBatchIdx] == srcEosId); // const_cast during construction
}
- // determine index of UNK in the log prob vectors if we want to suppress it in the decoding process
- int unkColId = -1;
- if (trgUnkId != Word::NONE && !options_->get<bool>("allow-unk", false)) { // do we need to suppress unk?
- unkColId = factoredVocab ? factoredVocab->getUnkIndex() : trgUnkId.toWordIndex(); // what's the raw index of unk in the log prob vector?
- auto shortlist = scorers_[0]->getShortlist(); // first shortlist is generally ok, @TODO: make sure they are the same across scorers?
- if (shortlist)
- unkColId = shortlist->tryForwardMap(unkColId); // use shifted postion of unk in case of using a shortlist, shortlist may have removed unk which results in -1
+ Expr suppressedWordIndices;
+ bool suppressUnk = !options_->get<bool>("allow-unk", false);
+ bool suppressSpecial = !options_->get<bool>("allow-special", false);
+ if (suppressUnk || suppressSpecial) { // do we need to suppress unk or special?
+ std::vector<WordIndex> suppressed = trgVocab_->suppressedIndices(suppressUnk, suppressSpecial);
+
+ auto shortlist = scorers_[0]->getShortlist(); // first shortlist is generally ok, @TODO: make sure they are the same across scorers?
+ if(shortlist) // check if suppressed words are allowed by the shortlist, if not, remove
+ suppressed.erase(std::remove_if(suppressed.begin(),
+ suppressed.end(),
+ [&](WordIndex i) {
+ return shortlist->tryForwardMap(i) == data::Shortlist::npos;
+ }),
+ suppressed.end());
+
+ if(!suppressed.empty())
+ suppressedWordIndices = graph->indices(suppressed);
}
// the decoding process updates the following state information in each output time step:
@@ -453,10 +462,8 @@ Histories BeamSearch::search(Ptr<ExpressionGraph> graph, Ptr<data::CorpusBatch>
//**********************************************************************
// suppress specific symbols if not at right positions
- if(unkColId != -1 && factorGroup == 0)
- suppressWord(expandedPathScores, unkColId);
- for(auto state : states)
- state->blacklist(expandedPathScores, batch);
+ if(suppressedWordIndices && factorGroup == 0)
+ suppressWords(expandedPathScores, suppressedWordIndices);
//**********************************************************************
// perform beam search
diff --git a/src/translator/helpers.cpp b/src/translator/helpers.cpp
index f4b75da0..e37c6a4a 100644
--- a/src/translator/helpers.cpp
+++ b/src/translator/helpers.cpp
@@ -13,29 +13,30 @@ namespace marian {
namespace cpu {
-void SetColumn(Tensor in_, size_t col, float value) {
- int nRows = in_->shape().elements() / in_->shape()[-1];
- int nColumns = in_->shape()[-1];
+void SetColumns(Tensor in, Tensor indices, float value) {
+ int nRows = in->shape().elements() / in->shape()[-1];
+ int nColumns = in->shape()[-1];
+ int nSuppress = indices->shape()[-1];
- float* in = in_->data();
for(int rowNumber = 0; rowNumber < nRows; ++rowNumber) {
- auto index = col + rowNumber * nColumns;
- in[index] = value;
+ float* row = in->data() + rowNumber * nColumns;
+ for(int i = 0; i < nSuppress; ++i)
+ row[indices->data<WordIndex>()[i]] = value;
}
}
-void suppressWord(Expr logProbs, WordIndex wordIndex) {
- SetColumn(logProbs->val(), wordIndex, std::numeric_limits<float>::lowest());
+void suppressWords(Expr logProbs, Expr wordIndices) {
+ SetColumns(logProbs->val(), wordIndices->val(), std::numeric_limits<float>::lowest());
}
} // namespace cpu
-void suppressWord(Expr logProbs, WordIndex wordIndex) {
+void suppressWords(Expr logProbs, Expr wordIndices) {
if(logProbs->val()->getBackend()->getDeviceId().type == DeviceType::cpu) {
- cpu::suppressWord(logProbs, wordIndex);
+ cpu::suppressWords(logProbs, wordIndices);
}
#ifdef CUDA_FOUND
else {
- gpu::suppressWord(logProbs, wordIndex);
+ gpu::suppressWords(logProbs, wordIndices);
}
#endif
}
diff --git a/src/translator/helpers.cu b/src/translator/helpers.cu
index 48a079fc..f0272609 100644
--- a/src/translator/helpers.cu
+++ b/src/translator/helpers.cu
@@ -1,8 +1,3 @@
-/* All or part of this file was contributed by Intel under license:
- * Copyright (C) 2017-2018 Intel Corporation
- * SPDX-License-Identifier: MIT
- */
-
#include <cuda.h>
#include <limits>
@@ -17,39 +12,50 @@ namespace marian {
namespace gpu {
template <typename T>
-__global__ void gSetColumn(T* d_in,
- size_t n_columns,
- size_t n_rows,
- size_t noColumn,
- T value) {
- size_t rowNumber = threadIdx.x + blockDim.x * blockIdx.x;
- size_t index = noColumn + rowNumber * n_columns;
-
- if(index < n_columns * n_rows) {
- d_in[index] = value;
+__global__ void gSetColumns(T* out,
+ int rows,
+ int cols,
+ const IndexType* wordIndices,
+ int numIndices,
+ T value) {
+ for(int bid = 0; bid < rows; bid += gridDim.x) {
+ int j = bid + blockIdx.x;
+ if(j < rows) {
+ T* rowOut = out + j * cols;
+ for(int tid = 0; tid < numIndices; tid += blockDim.x) {
+ int i = tid + threadIdx.x;
+ if(i < numIndices)
+ rowOut[wordIndices[i]] = value;
+ }
+ }
}
}
-void SetColumn(Tensor in, size_t col, float value) {
- int nRows = in->shape().elements() / in->shape()[-1];
- int nColumns = in->shape()[-1];
+void SetColumns(Tensor in, Tensor wordIndices, float value) {
+ matchOrAbort<IndexType>(wordIndices->type());
+
+ int rows = in->shape().elements() / in->shape().back();
+ int cols = in->shape().back();
+
+ int numIndices = wordIndices->size();
- int nBlocks = nRows / 512 + ((nRows % 512 == 0) ? 0 : 1);
- int nThreads = std::min(512, nRows);
+ int threads = std::min(MAX_THREADS, numIndices);
+ int blocks = std::min(MAX_BLOCKS, rows);
- if(in->type() == Type::float32) {
- gSetColumn<<<nBlocks, nThreads>>>(in->data<float>(), nColumns, nRows, col, value);
+ if(in->type() == Type::float32) {
+ gSetColumns<<<blocks, threads>>>(in->data<float>(), rows, cols, wordIndices->data<WordIndex>(), numIndices, value);
#if COMPILE_FP16
- } else if(in->type() == Type::float16) {
- gSetColumn<<<nBlocks, nThreads>>>(in->data<half>(), nColumns, nRows, col, (half)value);
+ } else if(in->type() == Type::float16) {
+ gSetColumns<<<blocks, threads>>>(in->data<half>(), rows, cols, wordIndices->data<WordIndex>(), numIndices, (half)value);
#endif
- } else {
- ABORT("suppressWord not implemented for type {}", in->type());
- }
+ } else {
+ ABORT("suppressWord not implemented for type {}", in->type());
+ }
}
-void suppressWord(Expr probs, WordIndex wordIndex) {
- SetColumn(probs->val(), wordIndex, NumericLimits<float>(probs->value_type()).lowest);
+void suppressWords(Expr probs, Expr wordIndices) {
+ SetColumns(probs->val(), wordIndices->val(), NumericLimits<float>(probs->value_type()).lowest);
}
+
} // namespace gpu
} // namespace marian
diff --git a/src/translator/helpers.h b/src/translator/helpers.h
index 71b1eb20..68159fa3 100644
--- a/src/translator/helpers.h
+++ b/src/translator/helpers.h
@@ -1,8 +1,3 @@
-/* All or part of this file was contributed by Intel under license:
- * Copyright (C) 2017-2018 Intel Corporation
- * SPDX-License-Identifier: MIT
- */
-
#pragma once
#include "graph/expression_graph.h"
@@ -11,13 +6,13 @@ namespace marian {
namespace cpu {
-void suppressWord(Expr logProbs, WordIndex wordIndex);
+void suppressWords(Expr logProbs, Expr wordIndices);
}
namespace gpu {
-void suppressWord(Expr logProbs, WordIndex wordIndex);
+void suppressWords(Expr logProbs, Expr wordIndices);
}
-void suppressWord(Expr logProbs, WordIndex wordIndex);
+void suppressWords(Expr logProbs, Expr wordIndices);
} // namespace marian