diff options
Diffstat (limited to 'src/codegen_fp16fp32.cc')
-rw-r--r-- | src/codegen_fp16fp32.cc | 134 |
1 file changed, 58 insertions, 76 deletions
diff --git a/src/codegen_fp16fp32.cc b/src/codegen_fp16fp32.cc index 17bb113..7c8e10c 100644 --- a/src/codegen_fp16fp32.cc +++ b/src/codegen_fp16fp32.cc @@ -47,20 +47,35 @@ int main() { {2, "AVX2", { - {1, 1, 0}, - {2, 1, 0}, - {3, 1, 0}, - {4, 1, 0}, - {5, 1, 0}, - {6, 1, 0}, - {7, 1, 0}, - {8, 1, 0}, - {9, 1, 0}, - {10, 1, 0}, - {11, 1, 0}, - {12, 1, 0}, - {13, 1, 0}, - {14, 1, 0}, + // 4x3 register layout + // {1, 3, 0}, + // {2, 3, 0}, + // {3, 3, 0}, + // {4, 3, 0}, + + // 6x2 register layout + {1, 2, 0}, + {2, 2, 0}, + {3, 2, 0}, + {4, 2, 0}, + {5, 2, 0}, + {6, 2, 0}, + + // 14x1 register layout + // {1, 1, 0}, + // {2, 1, 0}, + // {3, 1, 0}, + // {4, 1, 0}, + // {5, 1, 0}, + // {6, 1, 0}, + // {7, 1, 0}, + // {8, 1, 0}, + // {9, 1, 0}, + // {10, 1, 0}, + // {11, 1, 0}, + // {12, 1, 0}, + // {13, 1, 0}, + // {14, 1, 0}, }}}; // open all files @@ -159,7 +174,6 @@ int main() { string vAtmp = "ymm" + to_string(last_free_ymmreg++); // produce register block of B col - assert(ukernel_shape[k][1] == 1); vector<string> vBcol(ukernel_shape[k][1]); for (auto c = 0; c < ukernel_shape[k][1]; c++) { @@ -228,82 +242,50 @@ int main() { srcfile << "\n"; - if (ukernel_shape[k][0] <= 13) { - addi(srcfile, "vcvtph2ps ymm15, XMMWORD PTR [r10 + 0]"); - addi(srcfile, "mov r11, 16"); - } else { - addi(srcfile, "mov r11, 0"); - } - srcfile << "\n"; string label = "loop_inner%="; addi(srcfile, label + ":"); srcfile << "\n"; - if (ukernel_shape[k][0] <= 13) { - auto a_offset = 0, unroll_factor = 2; - for (auto u = 0; u < unroll_factor; u++) { - string breg = (u == 0) ? "ymm14" : "ymm15"; - string breg_rev = (u == 0) ? 
"ymm15" : "ymm14"; - - addi( - srcfile, - "vcvtph2ps " + breg + ",XMMWORD PTR [r10 + r11 + " + - to_string(u * 16) + "]"); - addi(srcfile, "inc r14"); - for (auto r = 0; r < vCtile.size(); r++) { - addi( - srcfile, - "vbroadcastss " + vAtmp + ",DWORD PTR [r9+" + - to_string(a_offset) + "]"); - addi( - srcfile, - "vfmadd231ps " + vCtile[r][0] + "," + breg_rev + "," + - vAtmp); - if (u == 1 && r == vCtile.size() / 2) - addi(srcfile, "add r11, 32"); - a_offset += 4; - } - if (u < unroll_factor - 1) { - addi(srcfile, "cmp r14, r8"); - addi(srcfile, "jge " + exitlabel); - } - } - - addi(srcfile, "add r9," + to_string(a_offset)); - addi(srcfile, "cmp r14, r8"); - addi(srcfile, "jl " + label); - - srcfile << "\n"; + for (int c = 0; c < vCtile[0].size(); c++) { + addi( + srcfile, + "vcvtph2ps " + vBcol[c] + ",XMMWORD PTR [r10 + " + + to_string(16 * c) + "]"); + } - addi(srcfile, exitlabel + ":"); - } else { + for (int r = 0; r < vCtile.size(); r++) { addi( srcfile, - "vcvtph2ps " + vBcol[0] + ",XMMWORD PTR [r10 + r11]"); - for (auto r = 0; r < vCtile.size(); r++) { + "vbroadcastss " + vAtmp + ",DWORD PTR [r9+" + + to_string(4 * r) + "]"); + for (int c = 0; c < vCtile[0].size(); c++) { addi( srcfile, - "vbroadcastss " + vAtmp + ",DWORD PTR [r9+" + - to_string(4 * r) + "]"); - addi( - srcfile, - "vfmadd231ps " + vCtile[r][0] + "," + vBcol[0] + "," + + "vfmadd231ps " + vCtile[r][c] + "," + vBcol[c] + "," + vAtmp); } + } - addi( - srcfile, - "add r9," + to_string(4 * ukernel_shape[k][0]), - fixedA); // move A ptr - addi(srcfile, "add r11, 16"); + addi( + srcfile, + "add r9," + to_string(4 * ukernel_shape[k][0]), + fixedA); // move A ptr - addi(srcfile, "inc r14"); - addi(srcfile, "cmp r14, r8"); - addi(srcfile, "jl " + label); - } + addi( + srcfile, + "add r10," + to_string(16 * ukernel_shape[k][1]), + fixedA); // move B ptr + + addi(srcfile, "inc r14"); + addi(srcfile, "cmp r14, r8"); + addi(srcfile, "jl " + label); + + srcfile << "\n"; + + addi(srcfile, exitlabel + ":"); - 
addi(srcfile, "add r10, rsi"); + // addi(srcfile, "add r10, rsi"); srcfile << "\n"; // end marker |