Welcome to mirror list, hosted at ThFree Co, Russian Federation.

states.h « models « src - github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: c2f9ee05a22410a7fd301d70e60618b4760c90a9 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
#pragma once

#include "marian.h"
#include "layers/generic.h" // @HACK: for factored embeddings only so far
#include "rnn/types.h"

namespace marian {

class EncoderState {
private:
  Expr context_;
  Expr mask_;       // [beam depth=1, max length, batch size, vector dim=1] source mask
  Ptr<data::CorpusBatch> batch_;

public:
  EncoderState(Expr context, Expr mask, Ptr<data::CorpusBatch> batch)
      : context_(context), mask_(mask), batch_(batch) {}

  EncoderState() {}
  virtual ~EncoderState() {}

  virtual Expr getContext()   const { return context_;   }
  virtual Expr getAttended()  const { return context_;   }
  virtual Expr getMask()      const { return mask_;      } // source batch mask; may have additional positions suppressed

  virtual const Words& getSourceWords() {
    return batch_->front()->data();
  }

  // Sub-select active batch entries from encoder context and context mask
  Ptr<EncoderState> select(const std::vector<IndexType>& batchIndices) { // [batchIndex] indices of active batch entries
    // Dimension -2 is OK for both, RNN and Transformer models as the encoder context in Transformer gets transposed to the same dimension layout
    return New<EncoderState>(index_select(context_, -2, batchIndices), index_select(mask_, -2, batchIndices), batch_);
  }
};

class DecoderState {
protected:
  rnn::States states_; // states of individual decoder layers
  Logits logProbs_;
  std::vector<Ptr<EncoderState>> encStates_;
  Ptr<data::CorpusBatch> batch_;

  Expr targetHistoryEmbeddings_; // decoder history (teacher-forced or from decoding), embedded
  Expr targetMask_;
  Words targetWords_;            // target labels

  // Keep track of current target token position during translation
  size_t position_{0};

public:
  DecoderState(const rnn::States& states,
               Logits logProbs,
               const std::vector<Ptr<EncoderState>>& encStates,
               Ptr<data::CorpusBatch> batch)
      : states_(states), logProbs_(logProbs), encStates_(encStates), batch_(batch) {}
  virtual ~DecoderState() {}

  // @TODO: Do we need all these to be virtual?
  virtual const std::vector<Ptr<EncoderState>>& getEncoderStates() const {
    return encStates_;
  }

  virtual Logits getLogProbs() const { return logProbs_; }
  virtual void setLogProbs(Logits logProbs) { logProbs_ = logProbs; }

  // @TODO: should this be a constructor? Then derived classes can call this without the New<> in the loop
  virtual Ptr<DecoderState> select(const std::vector<IndexType>& hypIndices,   // [beamIndex * activeBatchSize + batchIndex]
                                   const std::vector<IndexType>& batchIndices, // [batchIndex]
                                   int beamSize) const {

    std::vector<Ptr<EncoderState>> newEncStates;
    for(auto& es : encStates_)
      // If the size of the batch dimension of the encoder state context changed, subselect the correct batch entries
      newEncStates.push_back(es->getContext()->shape()[-2] == batchIndices.size() ? es : es->select(batchIndices));

    // hypindices matches batchIndices in terms of batch dimension, so we only need hypIndices
    auto selectedState = New<DecoderState>(
        states_.select(hypIndices, beamSize, /*isBatchMajor=*/false), logProbs_, newEncStates, batch_);

    // Set positon of new state based on the target token position of current state
    selectedState->setPosition(getPosition());
    return selectedState;
  }

  virtual const rnn::States& getStates() const { return states_; }

  virtual Expr getTargetHistoryEmbeddings() const { return targetHistoryEmbeddings_; };
  virtual void setTargetHistoryEmbeddings(Expr targetHistoryEmbeddings) { targetHistoryEmbeddings_ = targetHistoryEmbeddings; }

  virtual const Words& getTargetWords() const { return targetWords_; };
  virtual void setTargetWords(const Words& targetWords) { targetWords_ = targetWords; }

  virtual Expr getTargetMask() const { return targetMask_; };
  virtual void setTargetMask(Expr targetMask) { targetMask_ = targetMask; }

  virtual const Words& getSourceWords() const {
    return getEncoderStates()[0]->getSourceWords();
  }

  Ptr<data::CorpusBatch> getBatch() const { return batch_; }

  // Set current target token position in state when decoding
  size_t getPosition() const { return position_; }

  // Set current target token position in state when decoding
  void setPosition(size_t position) { position_ = position; }

  virtual void blacklist(Expr /*totalCosts*/, Ptr<data::CorpusBatch> /*batch*/) {}
};

/**
 * Classifier output based on DecoderState
 * @TODO: should be unified with DecoderState or not be used at all as Classifier do not really have stateful output.
 */
class ClassifierState {
private:
  Expr logProbs_;
  std::vector<Ptr<EncoderState>> encStates_;
  Ptr<data::CorpusBatch> batch_;

  Expr targetMask_;
  Words targetWords_;

public:
  virtual ~ClassifierState() {}
  virtual Expr getLogProbs() const { return logProbs_; }
  virtual void setLogProbs(Expr logProbs) { logProbs_ = logProbs; }

  virtual const Words& getTargetWords() const { return targetWords_; };
  virtual void setTargetWords(const Words& targetWords) { targetWords_ = targetWords; }

  virtual Expr getTargetMask() const { return targetMask_; };

  virtual void setTargetMask(Expr targetMask) { targetMask_ = targetMask; }
};

}  // namespace marian