loss.cpp « layers « src - github.com/marian-nmt/marian.git

#include "layers/loss.h"

namespace marian {

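// Maps the "cost-type" option to a concrete loss implementation; label smoothing
// is disabled during inference/rescoring.
//
// Illustrative usage (a sketch; the graph expressions and option setup are assumed
// to exist elsewhere):
//   auto loss = LossFactory(options, /*inference=*/false);
//   auto cost = loss->getCost(logits, indices, mask, /*weights=*/nullptr);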
Ptr<LossBase> LossFactory(Ptr<Options> options, bool inference) {
  float smoothing = inference ? 0.f : options->get<float>("label-smoothing");
  std::string costType = options->get<std::string>("cost-type", "ce-mean");
  if(costType == "ce-mean" || costType == "cross-entropy") {
    return New<CrossEntropyMeanLoss>(smoothing);
  } else if(costType == "ce-mean-words") {
    return New<CrossEntropyMeanWordsLoss>(smoothing);
  } else if(costType == "ce-sum") {
    return New<CrossEntropySumLoss>(smoothing);
  } else if(costType == "perplexity") {
    return New<PerplexityLoss>(smoothing);
  } else if(costType == "ce-rescore") {
    return New<CrossEntropyRescoreLoss>(smoothing);
  } else {  // same as ce-mean
    return New<CrossEntropyMeanLoss>(smoothing);
  }
}

Expr LossBase::getCrossEntropy(Expr logits,
                               Expr indices,
                               Expr mask,
                               Expr weights) {
  using namespace keywords;

  auto ce = cross_entropy(logits, indices);

  if(smoothing_ > 0) {
    // @TODO: add this to CE kernels instead
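    // Label smoothing mixes the one-hot cross-entropy with the mean negative
    // log-probability over the vocabulary; ceq holds log-probabilities (negative
    // values), hence the minus sign below.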
    auto ceq = mean(logsoftmax(logits), axis = -1);
    ce = (1 - smoothing_) * ce - smoothing_ * ceq;
  }

  if(mask)
    ce = ce * mask;

  if(weights)
    ce = ce * weights;

  return ce;
}

Expr CrossEntropyMeanLoss::getCost(Expr logits,
                                   Expr indices,
                                   Expr mask,
                                   Expr weights) {
  using namespace keywords;
  auto ce = getCrossEntropy(logits, indices, mask, weights);
  // Time axis (words): -3
  // Batch axis (sentences): -2
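  // ce-mean: sum over time steps, then average over sentences (cost per sentence).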
  return mean(sum(ce, axis = -3), axis = -2);
}

Expr CrossEntropyMeanWordsLoss::getCost(Expr logits,
                                        Expr indices,
                                        Expr mask,
                                        Expr weights) {
  using namespace keywords;
  auto ce = getCrossEntropy(logits, indices, mask, weights);
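  // ce-mean-words: total cross-entropy divided by the number of unmasked target words.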
  return sum(sum(ce, axis = -3), axis = -2)
         / sum(sum(mask, axis = -3), axis = -2);
}

Expr CrossEntropySumLoss::getCost(Expr logits,
                                  Expr indices,
                                  Expr mask,
                                  Expr weights) {
  using namespace keywords;
  auto ce = getCrossEntropy(logits, indices, mask, weights);
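  // ce-sum: total cross-entropy over all words in the batch.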
  return sum(sum(ce, axis = -3), axis = -2);
}

Expr PerplexityLoss::getCost(Expr logits,
                             Expr indices,
                             Expr mask,
                             Expr weights) {
  using namespace keywords;
  auto ce = getCrossEntropy(logits, indices, mask, weights);
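  // perplexity: exponential of the per-word cross-entropy.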
  return exp(sum(sum(ce, axis = -3), axis = -2)
             / sum(sum(mask, axis = -3), axis = -2));
}

Expr CrossEntropyRescoreLoss::getCost(Expr logits,
                                      Expr indices,
                                      Expr mask,
                                      Expr weights) {
  using namespace keywords;
  auto ce = getCrossEntropy(logits, indices, mask, weights);
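  // ce-rescore: negative sum over the time axis only, i.e. the sentence
  // log-probability, leaving one score per batch entry.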
  return -sum(ce, axis = -3);
}
}  // namespace marian