diff options
Diffstat (limited to 'src/models/transformer.h')
-rwxr-xr-x | src/models/transformer.h | 4 |
1 files changed, 2 insertions, 2 deletions
diff --git a/src/models/transformer.h b/src/models/transformer.h index 6a7db643..f6aa4ff1 100755 --- a/src/models/transformer.h +++ b/src/models/transformer.h @@ -177,7 +177,7 @@ public: void collectOneHead(Expr weights, int dimBeam) { // select first head, this is arbitrary as the choice does not really matter - auto head0 = index_select(weights, 0, -3); + auto head0 = slice(weights, -3, 0); int dimBatchBeam = head0->shape()[-4]; int srcWords = head0->shape()[-1]; @@ -193,7 +193,7 @@ public: // @TODO: make splitting obsolete alignments_.clear(); for(int i = 0; i < trgWords; ++i) { - alignments_.push_back(marian::step(head0, i, -1)); // [tgt index][-4: beam depth, -3: max src length, -2: batch size, -1: 1] + alignments_.push_back(slice(head0, -1, i)); // [tgt index][-4: beam depth, -3: max src length, -2: batch size, -1: 1] } } |