diff options
Diffstat (limited to 'src/layers/constructors.h')
-rw-r--r-- | src/layers/constructors.h | 70 |
1 files changed, 40 insertions, 30 deletions
diff --git a/src/layers/constructors.h b/src/layers/constructors.h index e25449aa..9e9de207 100644 --- a/src/layers/constructors.h +++ b/src/layers/constructors.h @@ -1,8 +1,8 @@ #pragma once +#include "layers/embedding.h" #include "layers/factory.h" #include "layers/generic.h" -#include "layers/embedding.h" #include "layers/output.h" namespace marian { @@ -45,6 +45,7 @@ struct LogitLayerFactory : public Factory { // @TODO: In the long run, I hope we can get rid of the abstract factories altogether. class OutputFactory : public LogitLayerFactory { using LogitLayerFactory::LogitLayerFactory; + protected: std::string tiedTransposedName_; Ptr<data::Shortlist> shortlist_; @@ -55,9 +56,7 @@ public: return Accumulator<OutputFactory>(*this); } - void setShortlist(Ptr<data::Shortlist> shortlist) { - shortlist_ = shortlist; - } + void setShortlist(Ptr<data::Shortlist> shortlist) { shortlist_ = shortlist; } Ptr<IUnaryLogitLayer> construct(Ptr<ExpressionGraph> graph) override { auto output = New<Output>(graph, options_); @@ -89,8 +88,7 @@ protected: std::vector<Ptr<IUnaryLayer>> layers_; public: - MLP(Ptr<ExpressionGraph> graph, Ptr<Options> options) - : graph_(graph), options_(options) {} + MLP(Ptr<ExpressionGraph> graph, Ptr<Options> options) : graph_(graph), options_(options) {} Expr apply(const std::vector<Expr>& av) override { Expr output; @@ -106,46 +104,53 @@ public: } Logits applyAsLogits(const std::vector<Expr>& av) override { - // same as apply() except for the last layer, we invoke applyAsLogits(), which has a different return type + // same as apply() except for the last layer, we invoke applyAsLogits(), which has a different + // return type auto lastLayer = std::dynamic_pointer_cast<IUnaryLogitLayer>(layers_.back()); - ABORT_IF(!lastLayer, "MLP::applyAsLogits() was called on an MLP whose last layer is not an IUnaryLogitLayer"); - if (layers_.size() == 1) { - if (av.size() == 1) + ABORT_IF( + !lastLayer, + "MLP::applyAsLogits() was called on an MLP whose last layer is not an IUnaryLogitLayer"); + if(layers_.size() == 1) { + if(av.size() == 1) return lastLayer->applyAsLogits(av[0]); else return lastLayer->applyAsLogits(av); - } - else { + } else { Expr output; - if (av.size() == 1) + if(av.size() == 1) output = layers_[0]->apply(av[0]); else output = layers_[0]->apply(av); - for (size_t i = 1; i < layers_.size() - 1; ++i) + for(size_t i = 1; i < layers_.size() - 1; ++i) output = layers_[i]->apply(output); return lastLayer->applyAsLogits(output); } } - Expr apply(Expr e) override { return apply(std::vector<Expr>{ e }); } - Logits applyAsLogits(Expr e) override { return applyAsLogits(std::vector<Expr>{ e }); } + Expr apply(Expr e) override { return apply(std::vector<Expr>{e}); } + Logits applyAsLogits(Expr e) override { return applyAsLogits(std::vector<Expr>{e}); } void push_back(Ptr<IUnaryLayer> layer) { layers_.push_back(layer); } void push_back(Ptr<IUnaryLogitLayer> layer) { layers_.push_back(layer); } void setShortlist(Ptr<data::Shortlist> shortlist) override final { auto p = tryAsHasShortlist(); - ABORT_IF(!p, "setShortlist() called on an MLP with an output layer that does not support short lists"); + ABORT_IF( + !p, + "setShortlist() called on an MLP with an output layer that does not support short lists"); p->setShortlist(shortlist); } void clear() override final { auto p = tryAsHasShortlist(); - if (p) + if(p) p->clear(); } + private: - Ptr<IHasShortList> tryAsHasShortlist() const { return std::dynamic_pointer_cast<IHasShortList>(layers_.back()); } + Ptr<IHasShortList> tryAsHasShortlist() const { + return std::dynamic_pointer_cast<IHasShortList>(layers_.back()); + } }; /** @@ -154,6 +159,7 @@ private: */ class MLPFactory : public Factory { using Factory::Factory; + private: std::vector<Ptr<LayerFactory>> layers_; @@ -177,23 +183,27 @@ public: // which will go away if we get rid of the abstract factories, and instead just construct // all layers immediately, which is my long-term goal for Marian. private: - template<class WrappedFactory> + template <class WrappedFactory> class AsLayerFactory : public LayerFactory { - WrappedFactory us; + WrappedFactory us; + public: - AsLayerFactory(const WrappedFactory& wrapped) : us(wrapped) {} - Ptr<IUnaryLayer> construct(Ptr<ExpressionGraph> graph) override final { - auto p = std::static_pointer_cast<IUnaryLayer>(us.construct(graph)); - ABORT_IF(!p, "Attempted to cast a Factory to LayerFactory that isn't one"); - return p; - } + AsLayerFactory(const WrappedFactory& wrapped) : us(wrapped) {} + Ptr<IUnaryLayer> construct(Ptr<ExpressionGraph> graph) override final { + auto p = std::static_pointer_cast<IUnaryLayer>(us.construct(graph)); + ABORT_IF(!p, "Attempted to cast a Factory to LayerFactory that isn't one"); + return p; + } }; - template<class WrappedFactory> - static inline AsLayerFactory<WrappedFactory> asLayerFactory(const WrappedFactory& wrapped) { return wrapped; } + template <class WrappedFactory> + static inline AsLayerFactory<WrappedFactory> asLayerFactory(const WrappedFactory& wrapped) { + return wrapped; + } + public: Accumulator<MLPFactory> push_back(const Accumulator<OutputFactory>& lf) { push_back(AsLayerFactory<OutputFactory>(lf)); - //layers_.push_back(New<AsLayerFactory<OutputFactory>>(asLayerFactory((OutputFactory&)lf))); + // layers_.push_back(New<AsLayerFactory<OutputFactory>>(asLayerFactory((OutputFactory&)lf))); return Accumulator<MLPFactory>(*this); } }; |