Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/FBGEMM.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
Diffstat (limited to 'src/codegen_fp16fp32.cc')
-rw-r--r-- src/codegen_fp16fp32.cc 134
1 files changed, 58 insertions, 76 deletions
diff --git a/src/codegen_fp16fp32.cc b/src/codegen_fp16fp32.cc
index 17bb113..7c8e10c 100644
--- a/src/codegen_fp16fp32.cc
+++ b/src/codegen_fp16fp32.cc
@@ -47,20 +47,35 @@ int main() {
{2,
"AVX2",
{
- {1, 1, 0},
- {2, 1, 0},
- {3, 1, 0},
- {4, 1, 0},
- {5, 1, 0},
- {6, 1, 0},
- {7, 1, 0},
- {8, 1, 0},
- {9, 1, 0},
- {10, 1, 0},
- {11, 1, 0},
- {12, 1, 0},
- {13, 1, 0},
- {14, 1, 0},
+ // 4x3 register layout
+ // {1, 3, 0},
+ // {2, 3, 0},
+ // {3, 3, 0},
+ // {4, 3, 0},
+
+ // 6x2 register layout
+ {1, 2, 0},
+ {2, 2, 0},
+ {3, 2, 0},
+ {4, 2, 0},
+ {5, 2, 0},
+ {6, 2, 0},
+
+ // 14x1 register layout
+ // {1, 1, 0},
+ // {2, 1, 0},
+ // {3, 1, 0},
+ // {4, 1, 0},
+ // {5, 1, 0},
+ // {6, 1, 0},
+ // {7, 1, 0},
+ // {8, 1, 0},
+ // {9, 1, 0},
+ // {10, 1, 0},
+ // {11, 1, 0},
+ // {12, 1, 0},
+ // {13, 1, 0},
+ // {14, 1, 0},
}}};
// open all files
@@ -159,7 +174,6 @@ int main() {
string vAtmp = "ymm" + to_string(last_free_ymmreg++);
// produce register block of B col
- assert(ukernel_shape[k][1] == 1);
vector<string> vBcol(ukernel_shape[k][1]);
for (auto c = 0; c < ukernel_shape[k][1]; c++) {
@@ -228,82 +242,50 @@ int main() {
srcfile << "\n";
- if (ukernel_shape[k][0] <= 13) {
- addi(srcfile, "vcvtph2ps ymm15, XMMWORD PTR [r10 + 0]");
- addi(srcfile, "mov r11, 16");
- } else {
- addi(srcfile, "mov r11, 0");
- }
-
srcfile << "\n";
string label = "loop_inner%=";
addi(srcfile, label + ":");
srcfile << "\n";
- if (ukernel_shape[k][0] <= 13) {
- auto a_offset = 0, unroll_factor = 2;
- for (auto u = 0; u < unroll_factor; u++) {
- string breg = (u == 0) ? "ymm14" : "ymm15";
- string breg_rev = (u == 0) ? "ymm15" : "ymm14";
-
- addi(
- srcfile,
- "vcvtph2ps " + breg + ",XMMWORD PTR [r10 + r11 + " +
- to_string(u * 16) + "]");
- addi(srcfile, "inc r14");
- for (auto r = 0; r < vCtile.size(); r++) {
- addi(
- srcfile,
- "vbroadcastss " + vAtmp + ",DWORD PTR [r9+" +
- to_string(a_offset) + "]");
- addi(
- srcfile,
- "vfmadd231ps " + vCtile[r][0] + "," + breg_rev + "," +
- vAtmp);
- if (u == 1 && r == vCtile.size() / 2)
- addi(srcfile, "add r11, 32");
- a_offset += 4;
- }
- if (u < unroll_factor - 1) {
- addi(srcfile, "cmp r14, r8");
- addi(srcfile, "jge " + exitlabel);
- }
- }
-
- addi(srcfile, "add r9," + to_string(a_offset));
- addi(srcfile, "cmp r14, r8");
- addi(srcfile, "jl " + label);
-
- srcfile << "\n";
+ for (int c = 0; c < vCtile[0].size(); c++) {
+ addi(
+ srcfile,
+ "vcvtph2ps " + vBcol[c] + ",XMMWORD PTR [r10 + " +
+ to_string(16 * c) + "]");
+ }
- addi(srcfile, exitlabel + ":");
- } else {
+ for (int r = 0; r < vCtile.size(); r++) {
addi(
srcfile,
- "vcvtph2ps " + vBcol[0] + ",XMMWORD PTR [r10 + r11]");
- for (auto r = 0; r < vCtile.size(); r++) {
+ "vbroadcastss " + vAtmp + ",DWORD PTR [r9+" +
+ to_string(4 * r) + "]");
+ for (int c = 0; c < vCtile[0].size(); c++) {
addi(
srcfile,
- "vbroadcastss " + vAtmp + ",DWORD PTR [r9+" +
- to_string(4 * r) + "]");
- addi(
- srcfile,
- "vfmadd231ps " + vCtile[r][0] + "," + vBcol[0] + "," +
+ "vfmadd231ps " + vCtile[r][c] + "," + vBcol[c] + "," +
vAtmp);
}
+ }
- addi(
- srcfile,
- "add r9," + to_string(4 * ukernel_shape[k][0]),
- fixedA); // move A ptr
- addi(srcfile, "add r11, 16");
+ addi(
+ srcfile,
+ "add r9," + to_string(4 * ukernel_shape[k][0]),
+ fixedA); // move A ptr
- addi(srcfile, "inc r14");
- addi(srcfile, "cmp r14, r8");
- addi(srcfile, "jl " + label);
- }
+ addi(
+ srcfile,
+ "add r10," + to_string(16 * ukernel_shape[k][1]),
+ fixedA); // move B ptr
+
+ addi(srcfile, "inc r14");
+ addi(srcfile, "cmp r14, r8");
+ addi(srcfile, "jl " + label);
+
+ srcfile << "\n";
+
+ addi(srcfile, exitlabel + ":");
- addi(srcfile, "add r10, rsi");
+ // addi(srcfile, "add r10, rsi");
srcfile << "\n";
// end marker