Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'src/models/transformer.h')
-rwxr-xr-xsrc/models/transformer.h4
1 files changed, 2 insertions, 2 deletions
diff --git a/src/models/transformer.h b/src/models/transformer.h
index 6a7db643..f6aa4ff1 100755
--- a/src/models/transformer.h
+++ b/src/models/transformer.h
@@ -177,7 +177,7 @@ public:
void collectOneHead(Expr weights, int dimBeam) {
// select first head, this is arbitrary as the choice does not really matter
- auto head0 = index_select(weights, 0, -3);
+ auto head0 = slice(weights, -3, 0);
int dimBatchBeam = head0->shape()[-4];
int srcWords = head0->shape()[-1];
@@ -193,7 +193,7 @@ public:
// @TODO: make splitting obsolete
alignments_.clear();
for(int i = 0; i < trgWords; ++i) {
- alignments_.push_back(marian::step(head0, i, -1)); // [tgt index][-4: beam depth, -3: max src length, -2: batch size, -1: 1]
+ alignments_.push_back(slice(head0, -1, i)); // [tgt index][-4: beam depth, -3: max src length, -2: batch size, -1: 1]
}
}