Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
Diffstat (limited to 'src/layers/constructors.h')
-rw-r--r-- src/layers/constructors.h 70
1 files changed, 40 insertions, 30 deletions
diff --git a/src/layers/constructors.h b/src/layers/constructors.h
index e25449aa..9e9de207 100644
--- a/src/layers/constructors.h
+++ b/src/layers/constructors.h
@@ -1,8 +1,8 @@
#pragma once
+#include "layers/embedding.h"
#include "layers/factory.h"
#include "layers/generic.h"
-#include "layers/embedding.h"
#include "layers/output.h"
namespace marian {
@@ -45,6 +45,7 @@ struct LogitLayerFactory : public Factory {
// @TODO: In the long run, I hope we can get rid of the abstract factories altogether.
class OutputFactory : public LogitLayerFactory {
using LogitLayerFactory::LogitLayerFactory;
+
protected:
std::string tiedTransposedName_;
Ptr<data::Shortlist> shortlist_;
@@ -55,9 +56,7 @@ public:
return Accumulator<OutputFactory>(*this);
}
- void setShortlist(Ptr<data::Shortlist> shortlist) {
- shortlist_ = shortlist;
- }
+ void setShortlist(Ptr<data::Shortlist> shortlist) { shortlist_ = shortlist; }
Ptr<IUnaryLogitLayer> construct(Ptr<ExpressionGraph> graph) override {
auto output = New<Output>(graph, options_);
@@ -89,8 +88,7 @@ protected:
std::vector<Ptr<IUnaryLayer>> layers_;
public:
- MLP(Ptr<ExpressionGraph> graph, Ptr<Options> options)
- : graph_(graph), options_(options) {}
+ MLP(Ptr<ExpressionGraph> graph, Ptr<Options> options) : graph_(graph), options_(options) {}
Expr apply(const std::vector<Expr>& av) override {
Expr output;
@@ -106,46 +104,53 @@ public:
}
Logits applyAsLogits(const std::vector<Expr>& av) override {
- // same as apply() except for the last layer, we invoke applyAsLogits(), which has a different return type
+ // same as apply() except for the last layer, we invoke applyAsLogits(), which has a different
+ // return type
auto lastLayer = std::dynamic_pointer_cast<IUnaryLogitLayer>(layers_.back());
- ABORT_IF(!lastLayer, "MLP::applyAsLogits() was called on an MLP whose last layer is not an IUnaryLogitLayer");
- if (layers_.size() == 1) {
- if (av.size() == 1)
+ ABORT_IF(
+ !lastLayer,
+ "MLP::applyAsLogits() was called on an MLP whose last layer is not an IUnaryLogitLayer");
+ if(layers_.size() == 1) {
+ if(av.size() == 1)
return lastLayer->applyAsLogits(av[0]);
else
return lastLayer->applyAsLogits(av);
- }
- else {
+ } else {
Expr output;
- if (av.size() == 1)
+ if(av.size() == 1)
output = layers_[0]->apply(av[0]);
else
output = layers_[0]->apply(av);
- for (size_t i = 1; i < layers_.size() - 1; ++i)
+ for(size_t i = 1; i < layers_.size() - 1; ++i)
output = layers_[i]->apply(output);
return lastLayer->applyAsLogits(output);
}
}
- Expr apply(Expr e) override { return apply(std::vector<Expr>{ e }); }
- Logits applyAsLogits(Expr e) override { return applyAsLogits(std::vector<Expr>{ e }); }
+ Expr apply(Expr e) override { return apply(std::vector<Expr>{e}); }
+ Logits applyAsLogits(Expr e) override { return applyAsLogits(std::vector<Expr>{e}); }
void push_back(Ptr<IUnaryLayer> layer) { layers_.push_back(layer); }
void push_back(Ptr<IUnaryLogitLayer> layer) { layers_.push_back(layer); }
void setShortlist(Ptr<data::Shortlist> shortlist) override final {
auto p = tryAsHasShortlist();
- ABORT_IF(!p, "setShortlist() called on an MLP with an output layer that does not support short lists");
+ ABORT_IF(
+ !p,
+ "setShortlist() called on an MLP with an output layer that does not support short lists");
p->setShortlist(shortlist);
}
void clear() override final {
auto p = tryAsHasShortlist();
- if (p)
+ if(p)
p->clear();
}
+
private:
- Ptr<IHasShortList> tryAsHasShortlist() const { return std::dynamic_pointer_cast<IHasShortList>(layers_.back()); }
+ Ptr<IHasShortList> tryAsHasShortlist() const {
+ return std::dynamic_pointer_cast<IHasShortList>(layers_.back());
+ }
};
/**
@@ -154,6 +159,7 @@ private:
*/
class MLPFactory : public Factory {
using Factory::Factory;
+
private:
std::vector<Ptr<LayerFactory>> layers_;
@@ -177,23 +183,27 @@ public:
// which will go away if we get rid of the abstract factories, and instead just construct
// all layers immediately, which is my long-term goal for Marian.
private:
- template<class WrappedFactory>
+ template <class WrappedFactory>
class AsLayerFactory : public LayerFactory {
- WrappedFactory us;
+ WrappedFactory us;
+
public:
- AsLayerFactory(const WrappedFactory& wrapped) : us(wrapped) {}
- Ptr<IUnaryLayer> construct(Ptr<ExpressionGraph> graph) override final {
- auto p = std::static_pointer_cast<IUnaryLayer>(us.construct(graph));
- ABORT_IF(!p, "Attempted to cast a Factory to LayerFactory that isn't one");
- return p;
- }
+ AsLayerFactory(const WrappedFactory& wrapped) : us(wrapped) {}
+ Ptr<IUnaryLayer> construct(Ptr<ExpressionGraph> graph) override final {
+ auto p = std::static_pointer_cast<IUnaryLayer>(us.construct(graph));
+ ABORT_IF(!p, "Attempted to cast a Factory to LayerFactory that isn't one");
+ return p;
+ }
};
- template<class WrappedFactory>
- static inline AsLayerFactory<WrappedFactory> asLayerFactory(const WrappedFactory& wrapped) { return wrapped; }
+ template <class WrappedFactory>
+ static inline AsLayerFactory<WrappedFactory> asLayerFactory(const WrappedFactory& wrapped) {
+ return wrapped;
+ }
+
public:
Accumulator<MLPFactory> push_back(const Accumulator<OutputFactory>& lf) {
push_back(AsLayerFactory<OutputFactory>(lf));
- //layers_.push_back(New<AsLayerFactory<OutputFactory>>(asLayerFactory((OutputFactory&)lf)));
+ // layers_.push_back(New<AsLayerFactory<OutputFactory>>(asLayerFactory((OutputFactory&)lf)));
return Accumulator<MLPFactory>(*this);
}
};