Welcome to mirror list, hosted at ThFree Co, Russian Federation.

loss.cpp « layers « src - github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: 66eabd8a63fef4a6457fe962a228621440299e97 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
#include "layers/loss.h"

namespace marian {

// @TODO, simplify this. Currently here for back-compat
// Creates the label-wise loss selected by the options.
//
// Options read:
//   label-smoothing   - passed to the loss as `smoothing`; not read and forced to 0 when
//                       `inference` is true
//   factor-weight     - passed to the loss as `factorWeight` (default 1.0)
//   cost-type         - "ce-rescore" selects RescorerLoss; any other value (default "ce-mean")
//                       falls through to the checks below
//   unlikelihood-loss - if true (default false) selects SequenceUnlikelihoodLoss
//
// Precedence: "ce-rescore" wins over unlikelihood-loss; everything else gets CrossEntropyLoss.
// Aborts if unlikelihood loss is requested without word-level data weighting.
Ptr<LabelwiseLoss> newLoss(Ptr<Options> options, bool inference) {
  const float smoothing = inference ? 0.f : options->get<float>("label-smoothing");
  const float factorWeight = options->get<float>("factor-weight", 1.0f);
  const std::string costType = options->get<std::string>("cost-type", "ce-mean");
  const bool unlikelihood = options->get<bool>("unlikelihood-loss", false);

  if(costType == "ce-rescore") { // returns per-batch-item scores (while ce-mean reduces over batch)
    return New<RescorerLoss>();
  } else if(unlikelihood) {
    // Unlikelihood training needs BOTH a data-weighting source AND word-level weights
    // (per-target-label scores), so abort if EITHER requirement is missing.
    // Thanks to short-circuiting, the weighting type is only queried when data-weighting is set.
    ABORT_IF(!options->hasAndNotEmpty("data-weighting")
             || options->get<std::string>("data-weighting-type") != "word",
             "Unlikelihood loss training requires error annotation in form of per-target-label scores");
    // this is a mix of CE-loss and unlikelihood loss depending on values given for data-weighting
    return New<SequenceUnlikelihoodLoss>(smoothing, factorWeight);
  } else {  // same as ce-mean  --@TODO: better check all allowed values, and fail for invalid ones. E.g. what about ce-sum?
    return New<CrossEntropyLoss>(smoothing, factorWeight);
  }
}

// see loss.h for detailed explanations of each class
// Creates the multi-loss aggregator selected by option "multi-loss-type" (default "sum").
// Accepted values: "sum", "scaled", "mean"; any other value aborts.
Ptr<MultiRationalLoss> newMultiLoss(Ptr<Options> options) {
  const std::string multiLossType = options->get<std::string>("multi-loss-type", "sum");
  if(multiLossType == "sum")          // sum of sums
    return New<SumMultiRationalLoss>();
  else if(multiLossType == "scaled")  // sum of scaled sums, first element is reference scale
    return New<ScaledMultiRationalLoss>();
  else if(multiLossType == "mean")    // sum of means
    return New<MeanMultiRationalLoss>();
  else
    ABORT("Unknown multi-loss-type {}", multiLossType);
}

}  // namespace marian