1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
|
#include "layers/loss.h"
namespace marian {
// Constructs the loss object selected by the "cost-type" option.
// During inference, label smoothing (a training-time regularizer) is
// forced off regardless of the configured value.
Ptr<LossBase> LossFactory(Ptr<Options> options, bool inference) {
  const float smoothing
      = inference ? 0.f : options->get<float>("label-smoothing");
  const auto costType = options->get<std::string>("cost-type", "ce-mean");

  if(costType == "ce-mean-words")
    return New<CrossEntropyMeanWordsLoss>(smoothing);
  if(costType == "ce-sum")
    return New<CrossEntropySumLoss>(smoothing);
  if(costType == "perplexity")
    return New<PerplexityLoss>(smoothing);
  if(costType == "ce-rescore")
    return New<CrossEntropyRescoreLoss>(smoothing);

  // "ce-mean", "cross-entropy", and any unrecognized value all fall back
  // to the per-sentence mean cross-entropy loss.
  return New<CrossEntropyMeanLoss>(smoothing);
}
// Per-token cross-entropy between the model distribution (logits) and the
// reference labels (indices), optionally label-smoothed, masked, and
// weighted. Returns one cost value per target position.
Expr LossBase::getCrossEntropy(Expr logits,
                               Expr indices,
                               Expr mask,
                               Expr weights) {
  using namespace keywords;

  auto cost = cross_entropy(logits, indices);

  if(smoothing_ > 0) {
    // @TODO: add this to CE kernels instead
    // Uniform label smoothing: blend the one-hot CE with the mean
    // log-probability over the vocabulary axis.
    auto uniformCost = mean(logsoftmax(logits), axis = -1);
    cost = (1 - smoothing_) * cost - smoothing_ * uniformCost;
  }

  // Zero out padded positions, then apply optional per-token weights.
  if(mask)
    cost = cost * mask;
  if(weights)
    cost = cost * weights;

  return cost;
}
// Cross-entropy summed over time steps, then averaged over the batch:
// mean-per-sentence cost.
Expr CrossEntropyMeanLoss::getCost(Expr logits,
                                   Expr indices,
                                   Expr mask,
                                   Expr weights) {
  using namespace keywords;
  auto tokenCost = getCrossEntropy(logits, indices, mask, weights);
  // Axis convention here: -3 is the time (word) axis, -2 is the batch
  // (sentence) axis.
  auto perSentence = sum(tokenCost, axis = -3);
  return mean(perSentence, axis = -2);
}
// Total cross-entropy divided by the number of non-masked tokens:
// mean-per-word cost. NOTE(review): assumes a non-null mask — a null mask
// would make the denominator invalid; confirm callers always pass one.
Expr CrossEntropyMeanWordsLoss::getCost(Expr logits,
                                        Expr indices,
                                        Expr mask,
                                        Expr weights) {
  using namespace keywords;
  auto tokenCost = getCrossEntropy(logits, indices, mask, weights);
  // -3: time (word) axis, -2: batch (sentence) axis.
  auto totalCost = sum(sum(tokenCost, axis = -3), axis = -2);
  auto wordCount = sum(sum(mask, axis = -3), axis = -2);
  return totalCost / wordCount;
}
// Cross-entropy summed over both time steps and the batch: total cost.
Expr CrossEntropySumLoss::getCost(Expr logits,
                                  Expr indices,
                                  Expr mask,
                                  Expr weights) {
  using namespace keywords;
  auto tokenCost = getCrossEntropy(logits, indices, mask, weights);
  // -3: time (word) axis, -2: batch (sentence) axis.
  auto perSentence = sum(tokenCost, axis = -3);
  return sum(perSentence, axis = -2);
}
// Corpus perplexity: exp of the per-word mean cross-entropy.
// NOTE(review): like the mean-words loss, this divides by the mask sum,
// so it assumes a non-null mask — confirm callers always pass one.
Expr PerplexityLoss::getCost(Expr logits,
                             Expr indices,
                             Expr mask,
                             Expr weights) {
  using namespace keywords;
  auto tokenCost = getCrossEntropy(logits, indices, mask, weights);
  // -3: time (word) axis, -2: batch (sentence) axis.
  auto totalCost = sum(sum(tokenCost, axis = -3), axis = -2);
  auto wordCount = sum(sum(mask, axis = -3), axis = -2);
  return exp(totalCost / wordCount);
}
// Per-sentence score for rescoring: the negated cross-entropy summed over
// the time axis (-3), i.e. one log-likelihood value per sentence rather
// than a single scalar cost.
Expr CrossEntropyRescoreLoss::getCost(Expr logits,
                                      Expr indices,
                                      Expr mask,
                                      Expr weights) {
  using namespace keywords;
  auto tokenCost = getCrossEntropy(logits, indices, mask, weights);
  return -sum(tokenCost, axis = -3);
}
}
|