github.com/marian-nmt/marian.git
-rw-r--r--  src/graph/expression_operators.cu    7
-rw-r--r--  src/graph/expression_operators.h     1
-rw-r--r--  src/graph/node_operators_unary.h    95
-rw-r--r--  src/kernels/tensor_operators.cu      3
-rw-r--r--  src/models/transformer.h            45
-rw-r--r--  src/tensors/allocator.h              1
-rw-r--r--  src/training/validator.h             4
-rw-r--r--  src/translator/scorers.h             3
8 files changed, 116 insertions(+), 43 deletions(-)
diff --git a/src/graph/expression_operators.cu b/src/graph/expression_operators.cu
index 623dd325..f80c700c 100644
--- a/src/graph/expression_operators.cu
+++ b/src/graph/expression_operators.cu
@@ -120,6 +120,13 @@ Expr concatenate(const std::vector<Expr>& concats, keywords::axis_k ax) {
return Expression<ConcatenateNodeOp>(concats, ax);
}
+Expr repeat(Expr a, size_t repeats, keywords::axis_k ax) {
+ if(repeats == 1)
+ return a;
+ return concatenate(std::vector<Expr>(repeats, a), ax);
+}
+
+
Expr reshape(Expr a, Shape shape) {
return Expression<ReshapeNodeOp>(a, shape);
}
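
The repeat helper added above is thin sugar over concatenate: it returns its input unchanged when repeats == 1 and otherwise stacks repeats copies along the requested axis. A minimal usage sketch, with hypothetical shapes and a hypothetical mask expression (not part of the commit):

    // Tile a per-sentence mask across 4 beam hypotheses:
    // {1, dimBatch, 1, dimWords} -> {4, dimBatch, 1, dimWords}.
    Expr tiled = repeat(mask, 4, axis = -4);
    // With repeats == 1 the input is returned as-is, so callers can
    // pass the beam size unconditionally (see src/models/transformer.h).
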
diff --git a/src/graph/expression_operators.h b/src/graph/expression_operators.h
index aaf11b13..93d746e5 100644
--- a/src/graph/expression_operators.h
+++ b/src/graph/expression_operators.h
@@ -73,6 +73,7 @@ Expr transpose(Expr a);
Expr transpose(Expr a, const std::vector<int>& axes);
Expr concatenate(const std::vector<Expr>& concats, keywords::axis_k ax = 0);
+Expr repeat(Expr a, size_t repeats, keywords::axis_k ax = 0);
Expr reshape(Expr a, Shape shape);
diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h
index 90995ddf..ef11998a 100644
--- a/src/graph/node_operators_unary.h
+++ b/src/graph/node_operators_unary.h
@@ -39,6 +39,25 @@ public:
}
const std::string type() { return "scalar_add"; }
+
+ virtual size_t hash() {
+ if(!hash_) {
+ hash_ = NaryNodeOp::hash();
+ boost::hash_combine(hash_, scalar_);
+ }
+ return hash_;
+ }
+
+ virtual bool equal(Expr node) {
+ if(!NaryNodeOp::equal(node))
+ return false;
+ auto cnode = std::dynamic_pointer_cast<ScalarAddNodeOp>(node);
+ if(!cnode)
+ return false;
+ if(scalar_ != cnode->scalar_)
+ return false;
+ return true;
+ }
};
struct ScalarMultNodeOp : public UnaryNodeOp {
@@ -61,6 +80,25 @@ public:
}
const std::string type() { return "scalar_mult"; }
+
+ virtual size_t hash() {
+ if(!hash_) {
+ hash_ = NaryNodeOp::hash();
+ boost::hash_combine(hash_, scalar_);
+ }
+ return hash_;
+ }
+
+ virtual bool equal(Expr node) {
+ if(!NaryNodeOp::equal(node))
+ return false;
+ auto cnode = std::dynamic_pointer_cast<ScalarMultNodeOp>(node);
+ if(!cnode)
+ return false;
+ if(scalar_ != cnode->scalar_)
+ return false;
+ return true;
+ }
};
struct LogitNodeOp : public UnaryNodeOp {
@@ -256,6 +294,25 @@ struct PReLUNodeOp : public UnaryNodeOp {
const std::string type() { return "PReLU"; }
+ virtual size_t hash() {
+ if(!hash_) {
+ hash_ = NaryNodeOp::hash();
+ boost::hash_combine(hash_, alpha_);
+ }
+ return hash_;
+ }
+
+ virtual bool equal(Expr node) {
+ if(!NaryNodeOp::equal(node))
+ return false;
+ auto cnode = std::dynamic_pointer_cast<PReLUNodeOp>(node);
+ if(!cnode)
+ return false;
+ if(alpha_ != cnode->alpha_)
+ return false;
+ return true;
+ }
+
private:
float alpha_{0.01};
};
@@ -546,8 +603,6 @@ struct SqrtNodeOp : public UnaryNodeOp {
};
struct SquareNodeOp : public UnaryNodeOp {
- float epsilon_;
-
template <typename... Args>
SquareNodeOp(Args... args) : UnaryNodeOp(args...) {}
@@ -586,16 +641,16 @@ struct RowsNodeOp : public UnaryNodeOp {
template <typename... Args>
RowsNodeOp(Expr a, const std::vector<size_t>& indeces, Args... args)
: UnaryNodeOp(a, keywords::shape = newShape(a, indeces), args...),
- indeces_(indeces) {}
+ indices_(indeces) {}
NodeOps forwardOps() {
// @TODO: solve this with a tensor!
- return {NodeOp(CopyRows(val_, child(0)->val(), indeces_))};
+ return {NodeOp(CopyRows(val_, child(0)->val(), indices_))};
}
NodeOps backwardOps() {
- return {NodeOp(PasteRows(child(0)->grad(), adj_, indeces_))};
+ return {NodeOp(PasteRows(child(0)->grad(), adj_, indices_))};
}
template <class... Args>
@@ -614,7 +669,7 @@ struct RowsNodeOp : public UnaryNodeOp {
virtual size_t hash() {
if(!hash_) {
size_t seed = NaryNodeOp::hash();
- for(auto i : indeces_)
+ for(auto i : indices_)
boost::hash_combine(seed, i);
hash_ = seed;
}
@@ -627,28 +682,28 @@ struct RowsNodeOp : public UnaryNodeOp {
Ptr<RowsNodeOp> cnode = std::dynamic_pointer_cast<RowsNodeOp>(node);
if(!cnode)
return false;
- if(indeces_ != cnode->indeces_)
+ if(indices_ != cnode->indices_)
return false;
return true;
}
- std::vector<size_t> indeces_;
+ std::vector<size_t> indices_;
};
struct ColsNodeOp : public UnaryNodeOp {
template <typename... Args>
ColsNodeOp(Expr a, const std::vector<size_t>& indeces, Args... args)
: UnaryNodeOp(a, keywords::shape = newShape(a, indeces), args...),
- indeces_(indeces) {}
+ indices_(indeces) {}
NodeOps forwardOps() {
// @TODO: solve this with a tensor!
- return {NodeOp(CopyCols(val_, child(0)->val(), indeces_))};
+ return {NodeOp(CopyCols(val_, child(0)->val(), indices_))};
}
NodeOps backwardOps() {
- return {NodeOp(PasteCols(child(0)->grad(), adj_, indeces_))};
+ return {NodeOp(PasteCols(child(0)->grad(), adj_, indices_))};
}
template <class... Args>
@@ -665,7 +720,7 @@ struct ColsNodeOp : public UnaryNodeOp {
virtual size_t hash() {
if(!hash_) {
size_t seed = NaryNodeOp::hash();
- for(auto i : indeces_)
+ for(auto i : indices_)
boost::hash_combine(seed, i);
hash_ = seed;
}
@@ -678,27 +733,27 @@ struct ColsNodeOp : public UnaryNodeOp {
Ptr<ColsNodeOp> cnode = std::dynamic_pointer_cast<ColsNodeOp>(node);
if(!cnode)
return false;
- if(indeces_ != cnode->indeces_)
+ if(indices_ != cnode->indices_)
return false;
return true;
}
- std::vector<size_t> indeces_;
+ std::vector<size_t> indices_;
};
struct SelectNodeOp : public UnaryNodeOp {
SelectNodeOp(Expr a, int axis, const std::vector<size_t>& indeces)
: UnaryNodeOp(a, keywords::shape = newShape(a, axis, indeces)),
- indeces_(indeces) {}
+ indices_(indeces) {}
NodeOps forwardOps() {
return {NodeOp(
- Select(graph()->allocator(), val_, child(0)->val(), axis_, indeces_))};
+ Select(graph()->allocator(), val_, child(0)->val(), axis_, indices_))};
}
NodeOps backwardOps() {
return {NodeOp(
- Insert(graph()->allocator(), child(0)->grad(), adj_, axis_, indeces_))};
+ Insert(graph()->allocator(), child(0)->grad(), adj_, axis_, indices_))};
}
Shape newShape(Expr a, int axis, const std::vector<size_t>& indeces) {
@@ -716,7 +771,7 @@ struct SelectNodeOp : public UnaryNodeOp {
if(!hash_) {
size_t seed = NaryNodeOp::hash();
boost::hash_combine(seed, axis_);
- for(auto i : indeces_)
+ for(auto i : indices_)
boost::hash_combine(seed, i);
hash_ = seed;
}
@@ -731,12 +786,12 @@ struct SelectNodeOp : public UnaryNodeOp {
return false;
if(axis_ != cnode->axis_)
return false;
- if(indeces_ != cnode->indeces_)
+ if(indices_ != cnode->indices_)
return false;
return true;
}
- std::vector<size_t> indeces_;
+ std::vector<size_t> indices_;
int axis_{0};
};
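
All of the hash()/equal() overrides added in this file follow the same memoization pattern: reuse the base-class hash over children and node type, fold the node-specific payload (scalar_, alpha_, or the index vector) into it, and have equal() compare that payload after a dynamic_pointer_cast. This keeps common-subexpression elimination in the graph from merging nodes that differ only in their parameters. A condensed sketch of the pattern, using the ScalarAddNodeOp names from the diff:

    virtual size_t hash() {
      if(!hash_) {                            // computed once, then cached
        hash_ = NaryNodeOp::hash();           // covers children and type
        boost::hash_combine(hash_, scalar_);  // node-specific payload
      }
      return hash_;
    }

    virtual bool equal(Expr node) {
      if(!NaryNodeOp::equal(node))  // same children and node type?
        return false;
      auto cnode = std::dynamic_pointer_cast<ScalarAddNodeOp>(node);
      return cnode && scalar_ == cnode->scalar_;  // same payload?
    }
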
diff --git a/src/kernels/tensor_operators.cu b/src/kernels/tensor_operators.cu
index ef01aec2..402bb67b 100644
--- a/src/kernels/tensor_operators.cu
+++ b/src/kernels/tensor_operators.cu
@@ -34,7 +34,6 @@ bool IsNan(Tensor in) {
void ConcatCont(Tensor out, const std::vector<Tensor>& inputs, int axis) {
cudaSetDevice(out->getDevice());
-
int step = 1;
for(int i = 0; i < axis; ++i)
step *= out->shape()[i];
@@ -1355,7 +1354,7 @@ void Att(Tensor out, Tensor va, Tensor context, Tensor state) {
cudaSetDevice(out->getDevice());
size_t m = out->shape().elements() / out->shape().back();
-
+
size_t dims = context->shape().size();
size_t k = context->shape()[dims - 1];
size_t b = context->shape()[dims - 2];
diff --git a/src/models/transformer.h b/src/models/transformer.h
index 352aa14a..92030c83 100644
--- a/src/models/transformer.h
+++ b/src/models/transformer.h
@@ -50,7 +50,7 @@ public:
// convert 0/1 mask to transformer style -inf mask
auto ms = mask->shape();
mask = (1 - mask) * -99999999.f;
- return reshape(mask, {ms[-3], 1, ms[-2], ms[-1]});
+ return reshape(mask, {ms[-3], 1, ms[-2], ms[-1]}) ;
}
Expr SplitHeads(Expr input, int dimHeads) {
@@ -171,9 +171,10 @@ public:
// @TODO: do this better
int dimBeamQ = q->shape()[-4];
int dimBeamK = k->shape()[-4];
- if(dimBeamQ != dimBeamK) {
- k = concatenate(std::vector<Expr>(dimBeamQ, k), axis = -4);
- v = concatenate(std::vector<Expr>(dimBeamQ, v), axis = -4);
+ int dimBeam = dimBeamQ / dimBeamK;
+ if(dimBeam > 1) {
+ k = repeat(k, dimBeam, axis = -4);
+ v = repeat(v, dimBeam, axis = -4);
}
auto weights = softmax(bdot(q, k, false, true, scale) + mask);
@@ -237,6 +238,7 @@ public:
// apply multi-head attention to downscaled inputs
auto output
= Attention(graph, options, prefix, qh, kh, vh, masks[i], inference);
+
output = JoinHeads(output, q->shape()[-4]);
outputs.push_back(output);
@@ -442,7 +444,6 @@ public:
// to make RNN-based decoders and beam search work with this. We are looking
// into making this more natural.
auto context = TransposeTimeBatch(layer);
- debug(context, "context");
return New<EncoderState>(context, batchMask, batch);
}
@@ -457,12 +458,23 @@ public:
std::vector<Ptr<EncoderState>> &encStates)
: DecoderState(states, probs, encStates) {}
- virtual Ptr<DecoderState> select(const std::vector<size_t> &selIdx) {
+ virtual Ptr<DecoderState> select(const std::vector<size_t> &selIdx, int beamSize) {
rnn::States selectedStates;
- for(auto state : states_)
- selectedStates.push_back(
- {marian::select(state.output, -4, selIdx), nullptr});
+ int dimDepth = states_[0].output->shape()[-1];
+ int dimTime = states_[0].output->shape()[-2];
+ int dimBatch = selIdx.size() / beamSize;
+
+ std::vector<size_t> selIdx2;
+ for(auto i : selIdx)
+ for(int j = 0; j < dimTime; ++j)
+ selIdx2.push_back(i * dimTime + j);
+
+ for(auto state : states_) {
+ auto sel = rows(flatten_2d(state.output), selIdx2);
+ sel = reshape(sel, {beamSize, dimBatch, dimTime, dimDepth});
+ selectedStates.push_back({sel, nullptr});
+ }
return New<TransformerState>(selectedStates, probs_, encStates_);
}
@@ -487,8 +499,6 @@ public:
auto embeddings = state->getTargetEmbeddings();
auto decoderMask = state->getTargetMask();
- debug(embeddings, "embeddings");
-
// dropout target words
float dropoutTrg = inference_ ? 0 : opt<float>("dropout-trg");
if(dropoutTrg) {
@@ -500,6 +510,9 @@ public:
//************************************************************************//
int dimEmb = embeddings->shape()[-1];
+ int dimBeam = 1;
+ if(embeddings->shape().size() > 3)
+ dimBeam = embeddings->shape()[-4];
// according to paper embeddings are scaled by \sqrt(d_m)
auto scaledEmbeddings = std::sqrt(dimEmb) * embeddings;
@@ -531,6 +544,8 @@ public:
decoderMask = reshape(TransposeTimeBatch(decoderMask),
{1, dimBatch, 1, dimTrgWords});
selfMask = selfMask * decoderMask;
+ //if(dimBeam > 1)
+ // selfMask = repeat(selfMask, dimBeam, axis = -4);
}
selfMask = InverseMask(selfMask);
@@ -551,6 +566,8 @@ public:
encoderMask = reshape(TransposeTimeBatch(encoderMask),
{1, dimBatch, 1, dimSrcWords});
encoderMask = InverseMask(encoderMask);
+ if(dimBeam > 1)
+ encoderMask = repeat(encoderMask, dimBeam, axis = -4);
encoderContexts.push_back(encoderContext);
encoderMasks.push_back(encoderMask);
@@ -560,8 +577,7 @@ public:
for(int i = 1; i <= opt<int>("dec-depth"); ++i) {
auto values = query;
if(prevDecoderStates.size() > 0)
- values
- = concatenate({prevDecoderStates[i - 1].output, query}, axis = -2);
+ values = concatenate({prevDecoderStates[i - 1].output, query}, axis = -2);
decoderStates.push_back({values, nullptr});
@@ -640,9 +656,6 @@ public:
Expr logits = output->apply(decoderContext);
- debug(logits, "logits");
-
-
// return unnormalized(!) probabilities
return New<TransformerState>(
decoderStates, logits, state->getEncoderStates());
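
The rewritten TransformerState::select gathers beam hypotheses without a dedicated 4-D gather: each decoder state of shape {beam, batch, time, depth} is flattened to a 2-D matrix whose rows are (hypothesis, time step) pairs, the required rows are copied with rows(), and the result is reshaped back. Since the layout is row-major with depth innermost, hypothesis i occupies rows i * dimTime through i * dimTime + dimTime - 1, which is exactly what the selIdx2 expansion enumerates. A worked sketch with hypothetical sizes (not part of the commit):

    // dimTime = 3; selecting hypotheses {2, 0} from the flattened state.
    // selIdx = {2, 0} expands to selIdx2 = {6, 7, 8, 0, 1, 2}:
    // every time step of hypothesis 2, then every time step of hypothesis 0.
    std::vector<size_t> selIdx = {2, 0}, selIdx2;
    for(auto i : selIdx)
      for(int j = 0; j < 3; ++j)
        selIdx2.push_back(i * 3 + j);
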
diff --git a/src/tensors/allocator.h b/src/tensors/allocator.h
index 8a84e4d7..72d4805a 100644
--- a/src/tensors/allocator.h
+++ b/src/tensors/allocator.h
@@ -179,7 +179,6 @@ public:
auto ptr = gap.data();
auto mp = New<MemoryPiece>(ptr, bytes);
allocated_[ptr] = mp;
-
return mp;
}
diff --git a/src/training/validator.h b/src/training/validator.h
index 721b5629..8761b591 100644
--- a/src/training/validator.h
+++ b/src/training/validator.h
@@ -224,8 +224,8 @@ public:
// Temporary options for translation
auto opts = New<Config>(*options_);
- opts->set("mini-batch", 1);
- opts->set("maxi-batch", 1);
+ //opts->set("mini-batch", 1);
+ //opts->set("maxi-batch", 1);
opts->set("max-length", 1000);
// Create corpus
diff --git a/src/translator/scorers.h b/src/translator/scorers.h
index 1cabbf7a..e963efb0 100644
--- a/src/translator/scorers.h
+++ b/src/translator/scorers.h
@@ -83,8 +83,7 @@ public:
virtual Ptr<ScorerState> startState(Ptr<ExpressionGraph> graph,
Ptr<data::CorpusBatch> batch) {
graph->switchParams(getName());
- auto state = encdec_->startState(graph, batch);
- return New<ScorerWrapperState>(state);
+ return New<ScorerWrapperState>(encdec_->startState(graph, batch));
}
virtual Ptr<ScorerState> step(Ptr<ExpressionGraph> graph,