Welcome to the mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorHieu Hoang <hieuhoang@gmail.com>2017-11-27 02:20:15 +0300
committerHieu Hoang <hieuhoang@gmail.com>2017-11-27 02:20:15 +0300
commitde822a72b51b8c1e3e73b26c6627c61f295d4c71 (patch)
tree69f868e8e0e1ca8574af087df88c10037a7fdbe9
parent689b756172cf9eef61eceb6436750f16bf8f3353 (diff)
MatrixWrapper -> VectorWrapper for sentence lengths
-rw-r--r--src/amun/gpu/decoder/encoder_decoder.h2
-rw-r--r--src/amun/gpu/dl4mt/decoder.h10
-rw-r--r--src/amun/gpu/dl4mt/encoder.cu4
-rw-r--r--src/amun/gpu/dl4mt/encoder.h4
-rw-r--r--src/amun/gpu/mblas/matrix_functions.cu18
-rw-r--r--src/amun/gpu/mblas/matrix_functions.h6
6 files changed, 22 insertions, 22 deletions
diff --git a/src/amun/gpu/decoder/encoder_decoder.h b/src/amun/gpu/decoder/encoder_decoder.h
index 08207e97..ae314bb8 100644
--- a/src/amun/gpu/decoder/encoder_decoder.h
+++ b/src/amun/gpu/decoder/encoder_decoder.h
@@ -63,7 +63,7 @@ class EncoderDecoder : public Scorer {
std::unique_ptr<Encoder> encoder_;
std::unique_ptr<Decoder> decoder_;
mblas::Vector<uint> indices_;
- mblas::IMatrix sentenceLengths_;
+ mblas::Vector<uint> sentenceLengths_;
// set in Encoder::GetContext() to length (maxSentenceLength * batchSize). 1 if it's a word, 0 otherwise
std::unique_ptr<mblas::Matrix> SourceContext_;
diff --git a/src/amun/gpu/dl4mt/decoder.h b/src/amun/gpu/dl4mt/decoder.h
index 46a051bc..8e8cc8cb 100644
--- a/src/amun/gpu/dl4mt/decoder.h
+++ b/src/amun/gpu/dl4mt/decoder.h
@@ -67,7 +67,7 @@ class Decoder {
void InitializeState(CellState& State,
const mblas::Matrix& SourceContext,
const size_t batchSize,
- const mblas::IMatrix &sentenceLengths)
+ const mblas::Vector<uint> &sentenceLengths)
{
using namespace mblas;
@@ -157,7 +157,7 @@ class Decoder {
void GetAlignedSourceContext(mblas::Matrix& AlignedSourceContext,
const CellState& HiddenState,
const mblas::Matrix& SourceContext,
- const mblas::IMatrix &sentenceLengths,
+ const mblas::Vector<uint> &sentenceLengths,
const std::vector<uint>& beamSizes)
{
// mapping = 1/0 whether each position, in each sentence in the batch is actually a valid word
@@ -375,7 +375,7 @@ class Decoder {
const CellState& State,
const mblas::Matrix& Embeddings,
const mblas::Matrix& SourceContext,
- const mblas::IMatrix &sentenceLengths,
+ const mblas::Vector<uint> &sentenceLengths,
const std::vector<uint>& beamSizes,
bool useFusedSoftmax)
{
@@ -418,7 +418,7 @@ class Decoder {
void EmptyState(CellState& State,
const mblas::Matrix& SourceContext,
size_t batchSize,
- const mblas::IMatrix &sentenceLengths)
+ const mblas::Vector<uint> &sentenceLengths)
{
rnn1_.InitializeState(State, SourceContext, batchSize, sentenceLengths);
alignment_.Init(SourceContext);
@@ -469,7 +469,7 @@ class Decoder {
void GetAlignedSourceContext(mblas::Matrix& AlignedSourceContext,
const CellState& HiddenState,
const mblas::Matrix& SourceContext,
- const mblas::IMatrix &sentenceLengths,
+ const mblas::Vector<uint> &sentenceLengths,
const std::vector<uint>& beamSizes) {
alignment_.GetAlignedSourceContext(AlignedSourceContext,
HiddenState,
diff --git a/src/amun/gpu/dl4mt/encoder.cu b/src/amun/gpu/dl4mt/encoder.cu
index 7f66a776..12b58c28 100644
--- a/src/amun/gpu/dl4mt/encoder.cu
+++ b/src/amun/gpu/dl4mt/encoder.cu
@@ -63,7 +63,7 @@ std::vector<std::vector<uint>> GetBatchInput(const Sentences& source, size_t tab
}
void Encoder::Encode(const Sentences& source, size_t tab, mblas::Matrix& context,
- mblas::IMatrix &sentenceLengths)
+ mblas::Vector<uint> &sentenceLengths)
{
size_t maxSentenceLength = GetMaxLength(source, tab);
@@ -72,7 +72,7 @@ void Encoder::Encode(const Sentences& source, size_t tab, mblas::Matrix& context
hSentenceLengths[i] = source.at(i)->GetWords(tab).size();
}
- sentenceLengths.NewSize(source.size(), 1, 1, 1);
+ sentenceLengths.newSize(source.size());
mblas::copy(hSentenceLengths.data(),
hSentenceLengths.size(),
sentenceLengths.data(),
diff --git a/src/amun/gpu/dl4mt/encoder.h b/src/amun/gpu/dl4mt/encoder.h
index 02b38fc2..0224fc32 100644
--- a/src/amun/gpu/dl4mt/encoder.h
+++ b/src/amun/gpu/dl4mt/encoder.h
@@ -73,7 +73,7 @@ class Encoder {
template <class It>
void Encode(It it, It end, mblas::Matrix& Context,
size_t batchSize, bool invert,
- const mblas::IMatrix *sentenceLengths=nullptr)
+ const mblas::Vector<uint> *sentenceLengths=nullptr)
{
InitializeState(batchSize);
@@ -128,7 +128,7 @@ class Encoder {
Encoder(const Weights& model, const YAML::Node& config);
void Encode(const Sentences& words, size_t tab, mblas::Matrix& context,
- mblas::IMatrix &sentenceLengths);
+ mblas::Vector<uint> &sentenceLengths);
private:
std::unique_ptr<Cell> InitForwardCell(const Weights& model, const YAML::Node& config);
diff --git a/src/amun/gpu/mblas/matrix_functions.cu b/src/amun/gpu/mblas/matrix_functions.cu
index e0b42b6b..f4fb84f1 100644
--- a/src/amun/gpu/mblas/matrix_functions.cu
+++ b/src/amun/gpu/mblas/matrix_functions.cu
@@ -18,7 +18,7 @@ Matrix& Swap(Matrix& Out, Matrix& In) {
__global__ void gMean(MatrixWrapper<float> out,
const MatrixWrapper<float> in,
- const MatrixWrapper<uint> sentenceLengths)
+ const VectorWrapper<uint> sentenceLengths)
{
// out = batches * states
// in = max sentence length * states * 1 * batches
@@ -53,7 +53,7 @@ __global__ void gMean(MatrixWrapper<float> out,
void Mean(Matrix& Out,
const Matrix& In,
- const mblas::IMatrix &sentenceLengths)
+ const mblas::Vector<uint> &sentenceLengths)
{
assert(Out.dim(2) == 1);
assert(Out.dim(3) == 1);
@@ -69,7 +69,7 @@ void Mean(Matrix& Out,
MatrixWrapper<float> inWrap(In);
//cerr << "outWrap=" << outWrap.Debug() << endl;
- MatrixWrapper<uint> sentenceLengthsWrap(sentenceLengths, false);
+ VectorWrapper<uint> sentenceLengthsWrap(sentenceLengths);
uint size = outWrap.size();
uint threads = std::min((uint)MAX_THREADS, size);
@@ -435,7 +435,7 @@ Matrix& Prod(Matrix& C, const Matrix& A, const Matrix& B,
__global__ void gSoftMax(MatrixWrapper<float> out,
const VectorWrapper<uint> batchIdsWrap,
- const MatrixWrapper<uint> sentenceLengthsWrap,
+ const VectorWrapper<uint> sentenceLengthsWrap,
uint shareSize)
{
extern __shared__ float _share[];
@@ -520,14 +520,14 @@ __global__ void gSoftMax(MatrixWrapper<float> out,
Matrix& Softmax(Matrix& Out,
const mblas::Vector<uint>& batchIds,
- const mblas::IMatrix &sentenceLengths,
+ const mblas::Vector<uint> &sentenceLengths,
size_t batchSize)
{
size_t maxLength = Out.dim(1);
MatrixWrapper<float> outWrap(Out);
const VectorWrapper<uint> batchIdsWrap(batchIds);
- const MatrixWrapper<uint> sentenceLengthsWrap(sentenceLengths, false);
+ const VectorWrapper<uint> sentenceLengthsWrap(sentenceLengths);
int blocks = batchSize;
int threads = std::min(MAX_THREADS, (int)maxLength);
@@ -681,7 +681,7 @@ void Fill(Matrix& In, float value) {
__global__
void gMapMatrix(MatrixWrapper<float> in,
- const MatrixWrapper<uint> sentenceLengthsWrap,
+ const VectorWrapper<uint> sentenceLengthsWrap,
int i)
{
int tid = threadIdx.x + blockIdx.x * blockDim.x;
@@ -696,7 +696,7 @@ void gMapMatrix(MatrixWrapper<float> in,
}
void MapMatrix(Matrix& state,
- const mblas::IMatrix &sentenceLengths,
+ const mblas::Vector<uint> &sentenceLengths,
size_t i)
{
// blank out rows in the state matrix where the word position i does not exist
@@ -709,7 +709,7 @@ void MapMatrix(Matrix& state,
int numBlocks = (state.size() / numThreads) + ((state.size() % numThreads == 0) ? 0 : 1);
MatrixWrapper<float> stateWrap(state);
- MatrixWrapper<uint> sentenceLengthsWrap(sentenceLengths);
+ VectorWrapper<uint> sentenceLengthsWrap(sentenceLengths);
gMapMatrix<<<numBlocks, numThreads, 0, CudaStreamHandler::GetStream()>>>
(stateWrap, sentenceLengthsWrap, i);
diff --git a/src/amun/gpu/mblas/matrix_functions.h b/src/amun/gpu/mblas/matrix_functions.h
index f62b5fd6..4ac80d5a 100644
--- a/src/amun/gpu/mblas/matrix_functions.h
+++ b/src/amun/gpu/mblas/matrix_functions.h
@@ -95,7 +95,7 @@ Matrix& Swap(Matrix& Out, Matrix& In);
void Mean(Matrix& Out,
const Matrix& In,
- const mblas::IMatrix &sentenceLengths);
+ const mblas::Vector<uint> &sentenceLengths);
void WeightedMean(Matrix& Out,const Matrix& Weights, const Matrix& In, const mblas::Vector<uint>& mapping);
@@ -119,7 +119,7 @@ Matrix& CopyRow(Matrix& Out,
Matrix& Concat(Matrix& Out, const Matrix& In);
void MapMatrix(Matrix& state,
- const mblas::IMatrix &sentenceLengths,
+ const mblas::Vector<uint> &sentenceLengths,
size_t i);
Matrix& CopyRows(Matrix& Out,
@@ -139,7 +139,7 @@ Matrix& Prod(Matrix& C, const Matrix& A, const Matrix& B,
Matrix& Softmax(Matrix& Out,
const mblas::Vector<uint>& batchIds,
- const mblas::IMatrix &sentenceLengths,
+ const mblas::Vector<uint> &sentenceLengths,
size_t batchSize);
Matrix& LogSoftmax(Matrix& Out);