diff options
author | Martin Junczys-Dowmunt <Marcin.JunczysDowmunt@microsoft.com> | 2021-03-26 19:17:12 +0300 |
---|---|---|
committer | Martin Junczys-Dowmunt <Marcin.JunczysDowmunt@microsoft.com> | 2021-03-26 19:17:12 +0300 |
commit | 7d1f941242928c976640a20f37e1bd9ac10011e8 (patch) | |
tree | a8f895b2d26bc1d947fe8a5fcb215d88a747dd6f /src | |
parent | 08bb158974597e92c3b5b0e20d938697bf6146b8 (diff) |
Merged PR 18309: Cleaner suppression of unwanted output words
This PR adds cleaner suppression of unwanted output words. We identified a situation where SPM with byte-fallback can generate random bytes with output-sampling.
That is particularly harmful when one of those random bytes happens to be a newline symbol. Here we suppress the newline in output unless it is explicitly wanted.
Diffstat (limited to 'src')
-rw-r--r-- | src/common/config_parser.cpp | 2 | ||||
-rw-r--r-- | src/data/sentencepiece_vocab.cpp | 27 | ||||
-rw-r--r-- | src/data/shortlist.h | 7 | ||||
-rw-r--r-- | src/data/vocab.cpp | 21 | ||||
-rw-r--r-- | src/data/vocab.h | 7 | ||||
-rw-r--r-- | src/data/vocab_base.h | 5 | ||||
-rw-r--r-- | src/translator/beam_search.cpp | 31 | ||||
-rw-r--r-- | src/translator/helpers.cpp | 23 | ||||
-rw-r--r-- | src/translator/helpers.cu | 64 | ||||
-rw-r--r-- | src/translator/helpers.h | 11 |
10 files changed, 135 insertions, 63 deletions
diff --git a/src/common/config_parser.cpp b/src/common/config_parser.cpp index 602509c5..6495db0e 100644 --- a/src/common/config_parser.cpp +++ b/src/common/config_parser.cpp @@ -651,6 +651,8 @@ void ConfigParser::addOptionsTranslation(cli::CLIWrapper& cli) { "Subtract (arg * translation length) from translation score"); cli.add<bool>("--allow-unk", "Allow unknown words to appear in output"); + cli.add<bool>("--allow-special", + "Allow special symbols to appear in output, e.g. for SentencePiece with byte-fallback do not suppress the newline symbol"); cli.add<bool>("--n-best", "Generate n-best list"); cli.add<std::string>("--alignment", diff --git a/src/data/sentencepiece_vocab.cpp b/src/data/sentencepiece_vocab.cpp index c168f6e3..090d478b 100644 --- a/src/data/sentencepiece_vocab.cpp +++ b/src/data/sentencepiece_vocab.cpp @@ -39,6 +39,20 @@ private: // Keeps sentences segmented into subword units bool keepEncoded_{false}; + // Contains control characters added to vocab due to byte-fallback + std::vector<Word> controlChars_; + + // Creates the first 32 control characters as done in byte-fallback and checks if they exist in the vocab. + // This makes sure that we do not waste computational effort on suppression if they don't actually appear. 
+ void populateControlChars() { + for(int i = 0; i < 32; ++i) { + std::string bytePiece = fmt::format("<0x{:02X}>", i); // 0 becomes <0x00>, 10 becomes <0x0A>, note uppercase A and lowercase x + auto id = spm_->PieceToId(bytePiece); + if(id != spm_->unk_id()) + controlChars_.push_back(Word::fromWordIndex(id)); + } + } + // Sample from one file, based on first algorithm from: // https://en.wikipedia.org/wiki/Reservoir_sampling void reservoirSampling(std::vector<std::string>& sample, size_t& seenLines, @@ -262,11 +276,24 @@ public: "SentencePiece vocabulary error: {}", status.ToString()); + populateControlChars(); + return spm_->GetPieceSize(); } std::string toUpper(const std::string& line) const override { return utils::utf8ToUpper(line); } std::string toEnglishTitleCase(const std::string& line) const override { return utils::toEnglishTitleCase(line); } + + // SentencePiece with byte-fallback may generate control symbols with output sampling. + // Let's mark them as special and suppress them later on output. This is generally safe + // for UTF-8 since control chars are not used as partial bytes in multi-byte sequences. + // They only appear in single-byte chars as themselves and this is what we suppress. 
+ void addSpecialWords(std::vector<Word>& special) const override { + special.reserve(special.size() + controlChars_.size()); + for(auto c : controlChars_) + special.push_back(c); + } + }; #endif // USE_SENTENCEPIECE diff --git a/src/data/shortlist.h b/src/data/shortlist.h index ab6a087b..f0467640 100644 --- a/src/data/shortlist.h +++ b/src/data/shortlist.h @@ -13,6 +13,7 @@ #include <vector> #include <iostream> #include <algorithm> +#include <limits> namespace marian { namespace data { @@ -22,18 +23,20 @@ private: std::vector<WordIndex> indices_; // // [packed shortlist index] -> word index, used to select columns from output embeddings public: + static constexpr WordIndex npos{std::numeric_limits<WordIndex>::max()}; // used to identify invalid shortlist entries similar to std::string::npos + Shortlist(const std::vector<WordIndex>& indices) : indices_(indices) {} const std::vector<WordIndex>& indices() const { return indices_; } WordIndex reverseMap(int idx) { return indices_[idx]; } - int tryForwardMap(WordIndex wIdx) { + WordIndex tryForwardMap(WordIndex wIdx) { auto first = std::lower_bound(indices_.begin(), indices_.end(), wIdx); if(first != indices_.end() && *first == wIdx) // check if element not less than wIdx has been found and if equal to wIdx return (int)std::distance(indices_.begin(), first); // return coordinate if found else - return -1; // return -1 if not found + return npos; // return npos if not found, @TODO: replace with std::optional once we switch to C++17? 
} }; diff --git a/src/data/vocab.cpp b/src/data/vocab.cpp index 07ac479e..8a3d49c7 100644 --- a/src/data/vocab.cpp +++ b/src/data/vocab.cpp @@ -138,6 +138,27 @@ Word Vocab::getEosId() const { return vImpl_->getEosId(); } // return UNK symbol id Word Vocab::getUnkId() const { return vImpl_->getUnkId(); } +std::vector<Word> Vocab::suppressedIds(bool suppressUnk, bool suppressSpecial) const { + std::vector<Word> ids; + if(suppressUnk) { + auto unkId = getUnkId(); + if(unkId != Word::NONE) + ids.push_back(unkId); + } + if(suppressSpecial) + vImpl_->addSpecialWords(/*in/out=*/ids); + return ids; +} + +std::vector<WordIndex> Vocab::suppressedIndices(bool suppressUnk, bool suppressSpecial) const { + std::vector<WordIndex> indices; + for(Word word : suppressedIds(suppressUnk, suppressSpecial)) + indices.push_back(word.toWordIndex()); + + vImpl_->transcodeToShortlistInPlace(indices.data(), indices.size()); + return indices; +} + // for corpus augmentation: convert string to all-caps std::string Vocab::toUpper(const std::string& line) const { return vImpl_->toUpper(line); } diff --git a/src/data/vocab.h b/src/data/vocab.h index 9a40ba16..2ab6b2b0 100644 --- a/src/data/vocab.h +++ b/src/data/vocab.h @@ -70,6 +70,13 @@ public: // return UNK symbol id Word getUnkId() const; + // return a set of Word ids that should be suppressed based on the underlying vocabulary implementation. + // Arguments mosty likely provided based on outside options like --allow-unk etc. + std::vector<Word> suppressedIds(bool suppressUnk = true, bool suppressSpecial = true) const; + + // same as suppressedIds but return numeric word indices into the embedding matrices + std::vector<WordIndex> suppressedIndices(bool suppressUnk = true, bool suppressSpecial = true) const; + // for corpus augmentation: convert string to all-caps // @TODO: Consider a different implementation where this does not show on the vocab interface, // but instead as additional options passed to vocab instantiation. 
diff --git a/src/data/vocab_base.h b/src/data/vocab_base.h index 8c214c97..fc512026 100644 --- a/src/data/vocab_base.h +++ b/src/data/vocab_base.h @@ -49,9 +49,12 @@ public: virtual std::string toUpper(const std::string& line) const { return line; } virtual std::string toEnglishTitleCase(const std::string& line) const { return line; } - // this function is an identity mapping for default vocabularies, hence do nothing + // Identity mapping for default vocabularies, hence do nothing virtual void transcodeToShortlistInPlace(WordIndex* ptr, size_t num) const { ptr; num; } + // Populates vector `special` with special words like "\n" etc. + virtual void addSpecialWords(std::vector<Word>& special) const { special; } + virtual void createFake() = 0; virtual Word randWord() const { diff --git a/src/translator/beam_search.cpp b/src/translator/beam_search.cpp index 5c1989a6..91dde6e6 100644 --- a/src/translator/beam_search.cpp +++ b/src/translator/beam_search.cpp @@ -258,7 +258,6 @@ Histories BeamSearch::search(Ptr<ExpressionGraph> graph, Ptr<data::CorpusBatch> // We will use the prefix "currentBatch.." whenever we refer to batch dimension that can change due to batch-pruning. const int origDimBatch = (int)batch->size(); const auto trgEosId = trgVocab_->getEosId(); - const auto trgUnkId = trgVocab_->getUnkId(); auto getNBestList = createGetNBestListFn(beamSize_, origDimBatch, graph->getDeviceId()); @@ -298,13 +297,23 @@ Histories BeamSearch::search(Ptr<ExpressionGraph> graph, Ptr<data::CorpusBatch> const_cast<std::vector<bool>&>(emptyBatchEntries).push_back(batch->front()->data()[origBatchIdx] == srcEosId); // const_cast during construction } - // determine index of UNK in the log prob vectors if we want to suppress it in the decoding process - int unkColId = -1; - if (trgUnkId != Word::NONE && !options_->get<bool>("allow-unk", false)) { // do we need to suppress unk? - unkColId = factoredVocab ? 
factoredVocab->getUnkIndex() : trgUnkId.toWordIndex(); // what's the raw index of unk in the log prob vector? - auto shortlist = scorers_[0]->getShortlist(); // first shortlist is generally ok, @TODO: make sure they are the same across scorers? - if (shortlist) - unkColId = shortlist->tryForwardMap(unkColId); // use shifted postion of unk in case of using a shortlist, shortlist may have removed unk which results in -1 + Expr suppressedWordIndices; + bool suppressUnk = !options_->get<bool>("allow-unk", false); + bool suppressSpecial = !options_->get<bool>("allow-special", false); + if (suppressUnk || suppressSpecial) { // do we need to suppress unk or special? + std::vector<WordIndex> suppressed = trgVocab_->suppressedIndices(suppressUnk, suppressSpecial); + + auto shortlist = scorers_[0]->getShortlist(); // first shortlist is generally ok, @TODO: make sure they are the same across scorers? + if(shortlist) // check if suppressed words are allowed by the shortlist, if not, remove + suppressed.erase(std::remove_if(suppressed.begin(), + suppressed.end(), + [&](WordIndex i) { + return shortlist->tryForwardMap(i) == data::Shortlist::npos; + }), + suppressed.end()); + + if(!suppressed.empty()) + suppressedWordIndices = graph->indices(suppressed); } // the decoding process updates the following state information in each output time step: @@ -453,10 +462,8 @@ Histories BeamSearch::search(Ptr<ExpressionGraph> graph, Ptr<data::CorpusBatch> //********************************************************************** // suppress specific symbols if not at right positions - if(unkColId != -1 && factorGroup == 0) - suppressWord(expandedPathScores, unkColId); - for(auto state : states) - state->blacklist(expandedPathScores, batch); + if(suppressedWordIndices && factorGroup == 0) + suppressWords(expandedPathScores, suppressedWordIndices); //********************************************************************** // perform beam search diff --git a/src/translator/helpers.cpp 
b/src/translator/helpers.cpp index f4b75da0..e37c6a4a 100644 --- a/src/translator/helpers.cpp +++ b/src/translator/helpers.cpp @@ -13,29 +13,30 @@ namespace marian { namespace cpu { -void SetColumn(Tensor in_, size_t col, float value) { - int nRows = in_->shape().elements() / in_->shape()[-1]; - int nColumns = in_->shape()[-1]; +void SetColumns(Tensor in, Tensor indices, float value) { + int nRows = in->shape().elements() / in->shape()[-1]; + int nColumns = in->shape()[-1]; + int nSuppress = indices->shape()[-1]; - float* in = in_->data(); for(int rowNumber = 0; rowNumber < nRows; ++rowNumber) { - auto index = col + rowNumber * nColumns; - in[index] = value; + float* row = in->data() + rowNumber * nColumns; + for(int i = 0; i < nSuppress; ++i) + row[indices->data<WordIndex>()[i]] = value; } } -void suppressWord(Expr logProbs, WordIndex wordIndex) { - SetColumn(logProbs->val(), wordIndex, std::numeric_limits<float>::lowest()); +void suppressWords(Expr logProbs, Expr wordIndices) { + SetColumns(logProbs->val(), wordIndices->val(), std::numeric_limits<float>::lowest()); } } // namespace cpu -void suppressWord(Expr logProbs, WordIndex wordIndex) { +void suppressWords(Expr logProbs, Expr wordIndices) { if(logProbs->val()->getBackend()->getDeviceId().type == DeviceType::cpu) { - cpu::suppressWord(logProbs, wordIndex); + cpu::suppressWords(logProbs, wordIndices); } #ifdef CUDA_FOUND else { - gpu::suppressWord(logProbs, wordIndex); + gpu::suppressWords(logProbs, wordIndices); } #endif } diff --git a/src/translator/helpers.cu b/src/translator/helpers.cu index 48a079fc..f0272609 100644 --- a/src/translator/helpers.cu +++ b/src/translator/helpers.cu @@ -1,8 +1,3 @@ -/* All or part of this file was contributed by Intel under license: - * Copyright (C) 2017-2018 Intel Corporation - * SPDX-License-Identifier: MIT - */ - #include <cuda.h> #include <limits> @@ -17,39 +12,50 @@ namespace marian { namespace gpu { template <typename T> -__global__ void gSetColumn(T* d_in, - size_t 
n_columns, - size_t n_rows, - size_t noColumn, - T value) { - size_t rowNumber = threadIdx.x + blockDim.x * blockIdx.x; - size_t index = noColumn + rowNumber * n_columns; - - if(index < n_columns * n_rows) { - d_in[index] = value; +__global__ void gSetColumns(T* out, + int rows, + int cols, + const IndexType* wordIndices, + int numIndices, + T value) { + for(int bid = 0; bid < rows; bid += gridDim.x) { + int j = bid + blockIdx.x; + if(j < rows) { + T* rowOut = out + j * cols; + for(int tid = 0; tid < numIndices; tid += blockDim.x) { + int i = tid + threadIdx.x; + if(i < numIndices) + rowOut[wordIndices[i]] = value; + } + } } } -void SetColumn(Tensor in, size_t col, float value) { - int nRows = in->shape().elements() / in->shape()[-1]; - int nColumns = in->shape()[-1]; +void SetColumns(Tensor in, Tensor wordIndices, float value) { + matchOrAbort<IndexType>(wordIndices->type()); + + int rows = in->shape().elements() / in->shape().back(); + int cols = in->shape().back(); + + int numIndices = wordIndices->size(); - int nBlocks = nRows / 512 + ((nRows % 512 == 0) ? 
0 : 1); - int nThreads = std::min(512, nRows); + int threads = std::min(MAX_THREADS, numIndices); + int blocks = std::min(MAX_BLOCKS, rows); - if(in->type() == Type::float32) { - gSetColumn<<<nBlocks, nThreads>>>(in->data<float>(), nColumns, nRows, col, value); + if(in->type() == Type::float32) { + gSetColumns<<<blocks, threads>>>(in->data<float>(), rows, cols, wordIndices->data<WordIndex>(), numIndices, value); #if COMPILE_FP16 - } else if(in->type() == Type::float16) { - gSetColumn<<<nBlocks, nThreads>>>(in->data<half>(), nColumns, nRows, col, (half)value); + } else if(in->type() == Type::float16) { + gSetColumns<<<blocks, threads>>>(in->data<half>(), rows, cols, wordIndices->data<WordIndex>(), numIndices, (half)value); #endif - } else { - ABORT("suppressWord not implemented for type {}", in->type()); - } + } else { + ABORT("suppressWord not implemented for type {}", in->type()); + } } -void suppressWord(Expr probs, WordIndex wordIndex) { - SetColumn(probs->val(), wordIndex, NumericLimits<float>(probs->value_type()).lowest); +void suppressWords(Expr probs, Expr wordIndices) { + SetColumns(probs->val(), wordIndices->val(), NumericLimits<float>(probs->value_type()).lowest); } + } // namespace gpu } // namespace marian diff --git a/src/translator/helpers.h b/src/translator/helpers.h index 71b1eb20..68159fa3 100644 --- a/src/translator/helpers.h +++ b/src/translator/helpers.h @@ -1,8 +1,3 @@ -/* All or part of this file was contributed by Intel under license: - * Copyright (C) 2017-2018 Intel Corporation - * SPDX-License-Identifier: MIT - */ - #pragma once #include "graph/expression_graph.h" @@ -11,13 +6,13 @@ namespace marian { namespace cpu { -void suppressWord(Expr logProbs, WordIndex wordIndex); +void suppressWords(Expr logProbs, Expr wordIndices); } namespace gpu { -void suppressWord(Expr logProbs, WordIndex wordIndex); +void suppressWords(Expr logProbs, Expr wordIndices); } -void suppressWord(Expr logProbs, WordIndex wordIndex); +void suppressWords(Expr 
logProbs, Expr wordIndices); } // namespace marian |