-rw-r--r-- | src/graph/expression_operators.cu |  7
-rw-r--r-- | src/graph/expression_operators.h  |  1
-rw-r--r-- | src/graph/node_operators_unary.h  | 95
-rw-r--r-- | src/kernels/tensor_operators.cu   |  3
-rw-r--r-- | src/models/transformer.h          | 45
-rw-r--r-- | src/tensors/allocator.h           |  1
-rw-r--r-- | src/training/validator.h          |  4
-rw-r--r-- | src/translator/scorers.h          |  3
8 files changed, 116 insertions, 43 deletions
diff --git a/src/graph/expression_operators.cu b/src/graph/expression_operators.cu
index 623dd325..f80c700c 100644
--- a/src/graph/expression_operators.cu
+++ b/src/graph/expression_operators.cu
@@ -120,6 +120,13 @@ Expr concatenate(const std::vector<Expr>& concats, keywords::axis_k ax) {
   return Expression<ConcatenateNodeOp>(concats, ax);
 }
 
+Expr repeat(Expr a, size_t repeats, keywords::axis_k ax) {
+  if(repeats == 1)
+    return a;
+  return concatenate(std::vector<Expr>(repeats, a), ax);
+}
+
+
 Expr reshape(Expr a, Shape shape) {
   return Expression<ReshapeNodeOp>(a, shape);
 }
diff --git a/src/graph/expression_operators.h b/src/graph/expression_operators.h
index aaf11b13..93d746e5 100644
--- a/src/graph/expression_operators.h
+++ b/src/graph/expression_operators.h
@@ -73,6 +73,7 @@ Expr transpose(Expr a);
 Expr transpose(Expr a, const std::vector<int>& axes);
 
 Expr concatenate(const std::vector<Expr>& concats, keywords::axis_k ax = 0);
+Expr repeat(Expr a, size_t repeats, keywords::axis_k ax = 0);
 
 Expr reshape(Expr a, Shape shape);
 
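A minimal usage sketch of the new repeat() helper declared above; the expression `a`, its shape, and the `using namespace marian::keywords;` context are illustrative assumptions, not part of the patch:

    // Assume `a` is an Expr of shape [1, dimBatch, dimTime, dimModel].
    auto same  = repeat(a, 1, axis = -4);  // repeats == 1 is the identity:
                                           // `a` itself is returned and no
                                           // node is added to the graph
    auto tiled = repeat(a, 8, axis = -4);  // one ConcatenateNodeOp over eight
                                           // references to `a`; shape becomes
                                           // [8, dimBatch, dimTime, dimModel]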
diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h
index 90995ddf..ef11998a 100644
--- a/src/graph/node_operators_unary.h
+++ b/src/graph/node_operators_unary.h
@@ -39,6 +39,25 @@ public:
   }
 
   const std::string type() { return "scalar_add"; }
+
+  virtual size_t hash() {
+    if(!hash_) {
+      hash_ = NaryNodeOp::hash();
+      boost::hash_combine(hash_, scalar_);
+    }
+    return hash_;
+  }
+
+  virtual bool equal(Expr node) {
+    if(!NaryNodeOp::equal(node))
+      return false;
+    auto cnode = std::dynamic_pointer_cast<ScalarAddNodeOp>(node);
+    if(!cnode)
+      return false;
+    if(scalar_ != cnode->scalar_)
+      return false;
+    return true;
+  }
 };
 
 struct ScalarMultNodeOp : public UnaryNodeOp {
@@ -61,6 +80,25 @@ public:
   }
 
   const std::string type() { return "scalar_add"; }
+
+  virtual size_t hash() {
+    if(!hash_) {
+      hash_ = NaryNodeOp::hash();
+      boost::hash_combine(hash_, scalar_);
+    }
+    return hash_;
+  }
+
+  virtual bool equal(Expr node) {
+    if(!NaryNodeOp::equal(node))
+      return false;
+    auto cnode = std::dynamic_pointer_cast<ScalarMultNodeOp>(node);
+    if(!cnode)
+      return false;
+    if(scalar_ != cnode->scalar_)
+      return false;
+    return true;
+  }
 };
 
 struct LogitNodeOp : public UnaryNodeOp {
@@ -256,6 +294,25 @@ struct PReLUNodeOp : public UnaryNodeOp {
 
   const std::string type() { return "PReLU"; }
 
+  virtual size_t hash() {
+    if(!hash_) {
+      hash_ = NaryNodeOp::hash();
+      boost::hash_combine(hash_, alpha_);
+    }
+    return hash_;
+  }
+
+  virtual bool equal(Expr node) {
+    if(!NaryNodeOp::equal(node))
+      return false;
+    auto cnode = std::dynamic_pointer_cast<PReLUNodeOp>(node);
+    if(!cnode)
+      return false;
+    if(alpha_ != cnode->alpha_)
+      return false;
+    return true;
+  }
+
 private:
   float alpha_{0.01};
 };
@@ -546,8 +603,6 @@ struct SqrtNodeOp : public UnaryNodeOp {
 };
 
 struct SquareNodeOp : public UnaryNodeOp {
-  float epsilon_;
-
   template <typename... Args>
   SquareNodeOp(Args... args) : UnaryNodeOp(args...) {}
 
@@ -586,16 +641,16 @@ struct RowsNodeOp : public UnaryNodeOp {
   template <typename... Args>
   RowsNodeOp(Expr a, const std::vector<size_t>& indeces, Args... args)
       : UnaryNodeOp(a, keywords::shape = newShape(a, indeces), args...),
-        indeces_(indeces) {}
+        indices_(indeces) {}
 
   NodeOps forwardOps() {
     // @TODO: solve this with a tensor!
 
-    return {NodeOp(CopyRows(val_, child(0)->val(), indeces_))};
+    return {NodeOp(CopyRows(val_, child(0)->val(), indices_))};
   }
 
   NodeOps backwardOps() {
-    return {NodeOp(PasteRows(child(0)->grad(), adj_, indeces_))};
+    return {NodeOp(PasteRows(child(0)->grad(), adj_, indices_))};
   }
 
   template <class... Args>
@@ -614,7 +669,7 @@ struct RowsNodeOp : public UnaryNodeOp {
   virtual size_t hash() {
     if(!hash_) {
       size_t seed = NaryNodeOp::hash();
-      for(auto i : indeces_)
+      for(auto i : indices_)
         boost::hash_combine(seed, i);
       hash_ = seed;
     }
@@ -627,28 +682,28 @@ struct RowsNodeOp : public UnaryNodeOp {
     Ptr<RowsNodeOp> cnode = std::dynamic_pointer_cast<RowsNodeOp>(node);
     if(!cnode)
       return false;
-    if(indeces_ != cnode->indeces_)
+    if(indices_ != cnode->indices_)
       return false;
     return true;
   }
 
-  std::vector<size_t> indeces_;
+  std::vector<size_t> indices_;
 };
 
 struct ColsNodeOp : public UnaryNodeOp {
   template <typename... Args>
   ColsNodeOp(Expr a, const std::vector<size_t>& indeces, Args... args)
       : UnaryNodeOp(a, keywords::shape = newShape(a, indeces), args...),
-        indeces_(indeces) {}
+        indices_(indeces) {}
 
   NodeOps forwardOps() {
     // @TODO: solve this with a tensor!
 
-    return {NodeOp(CopyCols(val_, child(0)->val(), indeces_))};
+    return {NodeOp(CopyCols(val_, child(0)->val(), indices_))};
   }
 
   NodeOps backwardOps() {
-    return {NodeOp(PasteCols(child(0)->grad(), adj_, indeces_))};
+    return {NodeOp(PasteCols(child(0)->grad(), adj_, indices_))};
   }
 
   template <class... Args>
@@ -665,7 +720,7 @@ struct ColsNodeOp : public UnaryNodeOp {
   virtual size_t hash() {
     if(!hash_) {
       size_t seed = NaryNodeOp::hash();
-      for(auto i : indeces_)
+      for(auto i : indices_)
         boost::hash_combine(seed, i);
       hash_ = seed;
     }
@@ -678,27 +733,27 @@ struct ColsNodeOp : public UnaryNodeOp {
     Ptr<ColsNodeOp> cnode = std::dynamic_pointer_cast<ColsNodeOp>(node);
     if(!cnode)
       return false;
-    if(indeces_ != cnode->indeces_)
+    if(indices_ != cnode->indices_)
      return false;
     return true;
   }
 
-  std::vector<size_t> indeces_;
+  std::vector<size_t> indices_;
 };
 
 struct SelectNodeOp : public UnaryNodeOp {
   SelectNodeOp(Expr a, int axis, const std::vector<size_t>& indeces)
       : UnaryNodeOp(a, keywords::shape = newShape(a, axis, indeces)),
-        indeces_(indeces) {}
+        indices_(indeces) {}
 
   NodeOps forwardOps() {
     return {NodeOp(
-        Select(graph()->allocator(), val_, child(0)->val(), axis_, indeces_))};
+        Select(graph()->allocator(), val_, child(0)->val(), axis_, indices_))};
   }
 
   NodeOps backwardOps() {
     return {NodeOp(
-        Insert(graph()->allocator(), child(0)->grad(), adj_, axis_, indeces_))};
+        Insert(graph()->allocator(), child(0)->grad(), adj_, axis_, indices_))};
   }
 
   Shape newShape(Expr a, int axis, const std::vector<size_t>& indeces) {
@@ -716,7 +771,7 @@ struct SelectNodeOp : public UnaryNodeOp {
     if(!hash_) {
       size_t seed = NaryNodeOp::hash();
       boost::hash_combine(seed, axis_);
-      for(auto i : indeces_)
+      for(auto i : indices_)
         boost::hash_combine(seed, i);
       hash_ = seed;
     }
@@ -731,12 +786,12 @@ struct SelectNodeOp : public UnaryNodeOp {
       return false;
     if(axis_ != cnode->axis_)
       return false;
-    if(indeces_ != cnode->indeces_)
+    if(indices_ != cnode->indices_)
       return false;
     return true;
   }
 
-  std::vector<size_t> indeces_;
+  std::vector<size_t> indices_;
   int axis_{0};
 };
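The hash()/equal() overrides added above all instantiate one pattern: lazily cache a structural hash that folds the node's parameter into the base hash of type and children, so that common-subexpression elimination never merges two nodes that differ only in that parameter. A standalone sketch of just this pattern; ScalarNodeSketch and baseHash_ are hypothetical stand-ins, not Marian types:

    #include <boost/functional/hash.hpp>
    #include <cstddef>

    struct ScalarNodeSketch {
      size_t hash_{0};   // 0 doubles as "not computed yet"
      size_t baseHash_;  // stand-in for NaryNodeOp::hash() (type + children)
      float scalar_;     // the node parameter, e.g. scalar_ or alpha_

      // Computed once, then served from the cache on later calls.
      size_t hash() {
        if(!hash_) {
          hash_ = baseHash_;
          boost::hash_combine(hash_, scalar_);
        }
        return hash_;
      }

      // Marian additionally checks the dynamic type via
      // std::dynamic_pointer_cast before comparing parameters.
      bool equal(const ScalarNodeSketch& other) const {
        return baseHash_ == other.baseHash_ && scalar_ == other.scalar_;
      }
    };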
diff --git a/src/kernels/tensor_operators.cu b/src/kernels/tensor_operators.cu
index ef01aec2..402bb67b 100644
--- a/src/kernels/tensor_operators.cu
+++ b/src/kernels/tensor_operators.cu
@@ -34,7 +34,6 @@ bool IsNan(Tensor in) {
 
 void ConcatCont(Tensor out, const std::vector<Tensor>& inputs, int axis) {
   cudaSetDevice(out->getDevice());
-
   int step = 1;
   for(int i = 0; i < axis; ++i)
     step *= out->shape()[i];
@@ -1355,7 +1354,7 @@ void Att(Tensor out, Tensor va, Tensor context, Tensor state) {
   cudaSetDevice(out->getDevice());
 
   size_t m = out->shape().elements() / out->shape().back();
-  
+
   size_t dims = context->shape().size();
   size_t k = context->shape()[dims - 1];
   size_t b = context->shape()[dims - 2];
diff --git a/src/models/transformer.h b/src/models/transformer.h
index 352aa14a..92030c83 100644
--- a/src/models/transformer.h
+++ b/src/models/transformer.h
@@ -50,7 +50,7 @@ public:
     // convert 0/1 mask to transformer style -inf mask
     auto ms = mask->shape();
     mask = (1 - mask) * -99999999.f;
-    return reshape(mask, {ms[-3], 1, ms[-2], ms[-1]});
+    return reshape(mask, {ms[-3], 1, ms[-2], ms[-1]}) ;
   }
 
   Expr SplitHeads(Expr input, int dimHeads) {
@@ -171,9 +171,10 @@ public:
     // @TODO: do this better
     int dimBeamQ = q->shape()[-4];
     int dimBeamK = k->shape()[-4];
-    if(dimBeamQ != dimBeamK) {
-      k = concatenate(std::vector<Expr>(dimBeamQ, k), axis = -4);
-      v = concatenate(std::vector<Expr>(dimBeamQ, v), axis = -4);
+    int dimBeam = dimBeamQ / dimBeamK;
+    if(dimBeam > 1) {
+      k = repeat(k, dimBeam, axis = -4);
+      v = repeat(v, dimBeam, axis = -4);
     }
 
     auto weights = softmax(bdot(q, k, false, true, scale) + mask);
@@ -237,6 +238,7 @@ public:
       // apply multi-head attention to downscaled inputs
      auto output
           = Attention(graph, options, prefix, qh, kh, vh, masks[i], inference);
+
       output = JoinHeads(output, q->shape()[-4]);
 
       outputs.push_back(output);
@@ -442,7 +444,6 @@ public:
     // to make RNN-based decoders and beam search work with this. We are looking
     // into makeing this more natural.
     auto context = TransposeTimeBatch(layer);
-    debug(context, "context");
 
     return New<EncoderState>(context, batchMask, batch);
   }
@@ -457,12 +458,23 @@ public:
                    std::vector<Ptr<EncoderState>> &encStates)
       : DecoderState(states, probs, encStates) {}
 
-  virtual Ptr<DecoderState> select(const std::vector<size_t> &selIdx) {
+  virtual Ptr<DecoderState> select(const std::vector<size_t> &selIdx, int beamSize) {
     rnn::States selectedStates;
 
-    for(auto state : states_)
-      selectedStates.push_back(
-          {marian::select(state.output, -4, selIdx), nullptr});
+    int dimDepth = states_[0].output->shape()[-1];
+    int dimTime = states_[0].output->shape()[-2];
+    int dimBatch = selIdx.size() / beamSize;
+
+    std::vector<size_t> selIdx2;
+    for(auto i : selIdx)
+      for(int j = 0; j < dimTime; ++j)
+        selIdx2.push_back(i * dimTime + j);
+
+    for(auto state : states_) {
+      auto sel = rows(flatten_2d(state.output), selIdx2);
+      sel = reshape(sel, {beamSize, dimBatch, dimTime, dimDepth});
+      selectedStates.push_back({sel, nullptr});
+    }
 
     return New<TransformerState>(selectedStates, probs_, encStates_);
   }
@@ -487,8 +499,6 @@ public:
     auto embeddings = state->getTargetEmbeddings();
     auto decoderMask = state->getTargetMask();
 
-    debug(embeddings, "embeddings");
-
     // dropout target words
     float dropoutTrg = inference_ ? 0 : opt<float>("dropout-trg");
     if(dropoutTrg) {
@@ -500,6 +510,9 @@ public:
     //************************************************************************//
 
     int dimEmb = embeddings->shape()[-1];
+    int dimBeam = 1;
+    if(embeddings->shape().size() > 3)
+      dimBeam = embeddings->shape()[-4];
 
     // according to paper embeddings are scaled by \sqrt(d_m)
     auto scaledEmbeddings = std::sqrt(dimEmb) * embeddings;
@@ -531,6 +544,8 @@ public:
       decoderMask = reshape(TransposeTimeBatch(decoderMask),
                             {1, dimBatch, 1, dimTrgWords});
       selfMask = selfMask * decoderMask;
+      //if(dimBeam > 1)
+      //  selfMask = repeat(selfMask, dimBeam, axis = -4);
     }
 
     selfMask = InverseMask(selfMask);
@@ -551,6 +566,8 @@ public:
       encoderMask = reshape(TransposeTimeBatch(encoderMask),
                             {1, dimBatch, 1, dimSrcWords});
       encoderMask = InverseMask(encoderMask);
+      if(dimBeam > 1)
+        encoderMask = repeat(encoderMask, dimBeam, axis = -4);
 
       encoderContexts.push_back(encoderContext);
       encoderMasks.push_back(encoderMask);
@@ -560,8 +577,7 @@ public:
     for(int i = 1; i <= opt<int>("dec-depth"); ++i) {
       auto values = query;
       if(prevDecoderStates.size() > 0)
-        values
-            = concatenate({prevDecoderStates[i - 1].output, query}, axis = -2);
+        values = concatenate({prevDecoderStates[i - 1].output, query}, axis = -2);
 
       decoderStates.push_back({values, nullptr});
 
@@ -640,9 +656,6 @@ public:
 
     Expr logits = output->apply(decoderContext);
 
-    debug(logits, "logits");
-
-
     // return unormalized(!) probabilities
     return New<TransformerState>(
         decoderStates, logits, state->getEncoderStates());
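The index arithmetic in TransformerState::select above is the subtle part: beam search hands back one index per surviving hypothesis, but the cached decoder states carry a time dimension, so each hypothesis index must be expanded to the dimTime consecutive rows of the state viewed as a [beamSize * dimBatch * dimTime, dimDepth] matrix. A self-contained sketch of that expansion; expandSelection is a hypothetical helper, not part of the patch:

    #include <cstddef>
    #include <vector>

    // For each selected hypothesis index i, emit the dimTime row indices
    // i * dimTime .. i * dimTime + dimTime - 1 of the flattened state.
    std::vector<size_t> expandSelection(const std::vector<size_t>& selIdx,
                                        int dimTime) {
      std::vector<size_t> selIdx2;
      selIdx2.reserve(selIdx.size() * dimTime);
      for(auto i : selIdx)
        for(int j = 0; j < dimTime; ++j)
          selIdx2.push_back(i * dimTime + j);
      return selIdx2;
    }

    // E.g. selIdx = {2, 0} with dimTime = 3 selects rows {6, 7, 8, 0, 1, 2},
    // which rows() gathers and reshape() folds back to
    // [beamSize, dimBatch, dimTime, dimDepth].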
diff --git a/src/tensors/allocator.h b/src/tensors/allocator.h
index 8a84e4d7..72d4805a 100644
--- a/src/tensors/allocator.h
+++ b/src/tensors/allocator.h
@@ -179,7 +179,6 @@ public:
     auto ptr = gap.data();
     auto mp = New<MemoryPiece>(ptr, bytes);
     allocated_[ptr] = mp;
-
     return mp;
   }
 
diff --git a/src/training/validator.h b/src/training/validator.h
index 721b5629..8761b591 100644
--- a/src/training/validator.h
+++ b/src/training/validator.h
@@ -224,8 +224,8 @@ public:
 
     // Temporary options for translation
     auto opts = New<Config>(*options_);
-    opts->set("mini-batch", 1);
-    opts->set("maxi-batch", 1);
+    //opts->set("mini-batch", 1);
+    //opts->set("maxi-batch", 1);
     opts->set("max-length", 1000);
 
     // Create corpus
diff --git a/src/translator/scorers.h b/src/translator/scorers.h
index 1cabbf7a..e963efb0 100644
--- a/src/translator/scorers.h
+++ b/src/translator/scorers.h
@@ -83,8 +83,7 @@ public:
   virtual Ptr<ScorerState> startState(Ptr<ExpressionGraph> graph,
                                       Ptr<data::CorpusBatch> batch) {
     graph->switchParams(getName());
-    auto state = encdec_->startState(graph, batch);
-    return New<ScorerWrapperState>(state);
+    return New<ScorerWrapperState>(encdec_->startState(graph, batch));
   }
 
   virtual Ptr<ScorerState> step(Ptr<ExpressionGraph> graph,