Welcome to mirror list, hosted at ThFree Co, Russian Federation.

attention_constructors.h « rnn « src - github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: a878f57f672f730d5feb8f1717ce496856249a00 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
#pragma once

#include "marian.h"

#include "layers/factory.h"
#include "rnn/attention.h"
#include "rnn/constructors.h"
#include "rnn/types.h"

namespace marian {
namespace rnn {

class AttentionFactory : public InputFactory {
protected:
  Ptr<EncoderState> state_;

public:
//  AttentionFactory(Ptr<ExpressionGraph> graph) : InputFactory(graph) {}

  Ptr<CellInput> construct(Ptr<ExpressionGraph> graph) override {
    ABORT_IF(!state_, "EncoderState not set");
    return New<Attention>(graph, options_, state_);
  }

  Accumulator<AttentionFactory> set_state(Ptr<EncoderState> state) {
    state_ = state;
    return Accumulator<AttentionFactory>(*this);
  }

  int dimAttended() {
    ABORT_IF(!state_, "EncoderState not set");
    return state_->getAttended()->shape()[1];
  }
};

typedef Accumulator<AttentionFactory> attention;
}  // namespace rnn
}  // namespace marian