diff options
Diffstat (limited to 'src/codegen_fp16fp32.cc')
-rw-r--r-- | src/codegen_fp16fp32.cc | 432 |
1 files changed, 354 insertions, 78 deletions
diff --git a/src/codegen_fp16fp32.cc b/src/codegen_fp16fp32.cc index 7c8e10c..8f80593 100644 --- a/src/codegen_fp16fp32.cc +++ b/src/codegen_fp16fp32.cc @@ -18,10 +18,16 @@ using namespace std; +void addi(ofstream& of, string i, string asmstr = "", bool disable = false) { + if (disable == false) + of << " " + i + " //\"" + asmstr + "\\t\\n\"" + "\n"; +} +#if 0 void addi(ofstream& of, string i, bool disable = false) { if (disable == false) of << " \"" + i + "\\t\\n\"" + "\n"; } +#endif struct ISA { unsigned avx; // 1, 2 or 3 @@ -88,7 +94,8 @@ int main() { " * This source code is licensed under the BSD-style license found in the\n" " * LICENSE file in the root directory of this source tree.\n" " */\n"; - srcfile << "#include \"FbgemmFP16UKernelsAvx2.h\"\n\n"; + srcfile << "#include \"FbgemmFP16UKernelsAvx2.h\"\n"; + srcfile << "#include <immintrin.h>\n\n"; srcfile << "namespace fbgemm {\n\n"; if (iaca) { srcfile << "#include \"iacaMarks.h\"\n"; @@ -111,6 +118,11 @@ int main() { hdrfile << "namespace fbgemm {\n\n"; hdrfile << "using fp16 = float16;\n"; hdrfile << "using fp32 = float;\n"; + hdrfile << "#ifdef _MSC_VER\n"; + hdrfile << " #define NOINLINE_ATTR __declspec(noinline)\n"; + hdrfile << "#else\n"; + hdrfile << " #define NOINLINE_ATTR __attribute__((noinline))\n"; + hdrfile << "#endif\n"; hdrfile << "struct GemmParams {\n uint64_t k;\n float* A;\n const fp16* B;\n" " float* beta;\n uint64_t accum;\n float* C;\n uint64_t ldc;\n" @@ -158,8 +170,9 @@ int main() { fargs = "(" + p1 + ")"; +#if 1 fheader[k] = - "void __attribute__((noinline)) " + funcname[k] + fargs; + "void NOINLINE_ATTR " + funcname[k] + fargs; srcfile << fheader[k] << " {\n"; unsigned last_free_ymmreg = 0; @@ -183,85 +196,92 @@ int main() { assert(last_free_ymmreg <= 16); - srcfile << " asm volatile(\n"; + //srcfile << " asm volatile(\n"; - srcfile << "#if !defined(__clang__)" - << "\n"; - addi(srcfile, "mov r14, %[gp]"); - srcfile << "#else\n"; - addi(srcfile, "mov %[gp], %%r14"); - addi(srcfile, ".intel_syntax noprefix"); - srcfile << "#endif\n"; + //srcfile << "#if !defined(__clang__)" + //<< "\n"; + addi(srcfile, "char* r14 = (char*)gp;", "mov r14, %[gp]"); + //srcfile << "#else\n"; + //addi(srcfile, "mov %[gp], %%r14"); + //addi(srcfile, ".intel_syntax noprefix"); + //srcfile << "#endif\n"; srcfile << "\n // Copy parameters\n"; - srcfile << " // k\n"; - addi(srcfile, "mov r8, [r14 + 0]"); - srcfile << " // A\n"; - addi(srcfile, "mov r9, [r14 + 8]"); - srcfile << " // B\n"; - addi(srcfile, "mov r10, [r14 + 16]"); - srcfile << " // beta\n"; - addi(srcfile, "mov r15, [r14 + 24]"); - srcfile << " // accum\n"; - addi(srcfile, "mov rdx, [r14 + 32]"); - srcfile << " // C\n"; - addi(srcfile, "mov r12, [r14 + 40]"); - srcfile << " // ldc\n"; - addi(srcfile, "mov r13, [r14 + 48]"); - srcfile << " // b_block_cols\n"; - addi(srcfile, "mov rdi, [r14 + 56]"); - srcfile << " // b_block_size\n"; - addi(srcfile, "mov rsi, [r14 + 64]"); + srcfile << " // k\n"; addi(srcfile, "uint64_t r8 = *(uint64_t *)((char*)r14 + 0 );", "mov r8, [r14 + 0]"); + srcfile << " // A\n"; addi(srcfile, "float* r9 = *(float* *)((char*)r14 + 8 );", "mov r9, [r14 + 8]"); + srcfile << " // B\n"; addi(srcfile, "const fp16* r10 = *(const fp16**)((char*)r14 + 16);", "mov r10, [r14 + 16]"); + srcfile << " // beta\n"; addi(srcfile, "float* r15 = *(float* *)((char*)r14 + 24);", "mov r15, [r14 + 24]"); + srcfile << " // accum\n"; addi(srcfile, "uint64_t rdx = *(uint64_t *)((char*)r14 + 32);", "mov rdx, [r14 + 32]"); + srcfile << " // C\n"; addi(srcfile, "float* r12 = *(float* *)((char*)r14 + 40);", "mov r12, [r14 + 40]"); + srcfile << " // ldc\n"; addi(srcfile, "uint64_t r13 = *(uint64_t *)((char*)r14 + 48);", "mov r13, [r14 + 48]"); + srcfile << " // b_block_cols\n"; addi(srcfile, "uint64_t rdi = *(uint64_t *)((char*)r14 + 56);", "mov rdi, [r14 + 56]"); + srcfile << " // b_block_size\n"; addi(srcfile, "uint64_t rsi = *(uint64_t *)((char*)r14 + 64);", "mov rsi, [r14 + 64]"); srcfile << " // Make copies of A and C\n"; - addi(srcfile, "mov rax, r9"); - addi(srcfile, "mov rcx, r12"); + addi(srcfile, "float* rax = r9;", "mov rax, r9"); + addi(srcfile, "float* rcx = r12;", "mov rcx, r12"); srcfile << "\n"; - addi(srcfile, "mov rbx, 0"); + addi(srcfile, "uint64_t rbx = 0;", "mov rbx, 0"); string exitlabel = "L_exit%="; string label2 = "loop_outter%="; - addi(srcfile, label2 + ":"); - addi(srcfile, "mov r14, 0"); + addi(srcfile, "for (; rbx < rdi; ++rbx) {", "inc rbx; cmp rbx, rdi; jl " + label2); + addi(srcfile, "// ", label2 + ":"); + addi(srcfile, " uint64_t r14_i = 0;", "mov r14, 0"); // set all vCtile regs to zeros for (auto r = 0; r < vCtile.size(); r++) { for (auto c = 0; c < vCtile[r].size(); c++) { addi( srcfile, + " __m256 " + vCtile[r][c] + " = _mm256_setzero_ps();", "vxorps " + vCtile[r][c] + "," + vCtile[r][c] + "," + vCtile[r][c]); } } // start marker - if (iaca) { - addi(srcfile, "mov ebx, 111"); - addi(srcfile, ".byte 0x64, 0x67, 0x90"); - } + //if (iaca) { + // addi(srcfile, "mov ebx, 111"); + // addi(srcfile, ".byte 0x64, 0x67, 0x90"); + //} - srcfile << "\n"; + //srcfile << "\n"; srcfile << "\n"; string label = "loop_inner%="; - addi(srcfile, label + ":"); - srcfile << "\n"; + addi(srcfile, " for (; r14_i < r8; ++r14_i) {", "inc r14; cmp r14, r8; jl " + label); + addi(srcfile, " // " + label + ":"); + //srcfile << "\n"; for (int c = 0; c < vCtile[0].size(); c++) { addi( + srcfile, + " auto fp16mem" + to_string(16 * c) + " = _mm_load_si128((__m128i*)((char*)r10 + " + to_string(16 * c) + "));", + "vcvtph2ps " + vBcol[c] + ",XMMWORD PTR [r10 + " + + to_string(16 * c) + "]"); + addi( srcfile, + " auto " + vBcol[c] + " = _mm256_cvtph_ps(fp16mem" + to_string(16 * c) + ");", "vcvtph2ps " + vBcol[c] + ",XMMWORD PTR [r10 + " + to_string(16 * c) + "]"); } for (int r = 0; r < vCtile.size(); r++) { + //addi( + // srcfile, + // ((r == 0) ? " auto " + vAtmp : "" + vAtmp) + " = _mm256_broadcastss_ps(r9 + " + to_string(4 * r) + ");", + // "vbroadcastss " + vAtmp + ",DWORD PTR [r9+" + + // to_string(4 * r) + "]"); addi( srcfile, + ((r == 0) ? " auto " + vAtmp : " " + vAtmp) + " = _mm256_broadcast_ss((float*)((char*)r9 + " + to_string(4 * r) + "));", "vbroadcastss " + vAtmp + ",DWORD PTR [r9+" + to_string(4 * r) + "]"); for (int c = 0; c < vCtile[0].size(); c++) { addi( srcfile, + " " + vCtile[r][c] + " = _mm256_fmadd_ps(" + vAtmp + ", " + vBcol[c] + ", " + vCtile[r][c] + ");", "vfmadd231ps " + vCtile[r][c] + "," + vBcol[c] + "," + vAtmp); } @@ -269,21 +289,25 @@ int main() { addi( srcfile, + " r9 = (float*)((char*)r9 + " + to_string(4 * ukernel_shape[k][0]) + ");", "add r9," + to_string(4 * ukernel_shape[k][0]), fixedA); // move A ptr addi( srcfile, + " r10 = (fp16*)((char*)r10 + " + to_string(16 * ukernel_shape[k][1]) + ");", "add r10," + to_string(16 * ukernel_shape[k][1]), fixedA); // move A ptr - addi(srcfile, "inc r14"); - addi(srcfile, "cmp r14, r8"); - addi(srcfile, "jl " + label); + addi(srcfile, " }", "inc r14; cmp r14, r8; jl " + label2); + // move to for loop + //addi(srcfile, "inc r14"); + //addi(srcfile, "cmp r14, r8"); + //addi(srcfile, "jl " + label); - srcfile << "\n"; + //srcfile << "\n"; - addi(srcfile, exitlabel + ":"); + //addi(srcfile, exitlabel + ":"); // addi(srcfile, "add r10, rsi"); srcfile << "\n"; @@ -294,29 +318,33 @@ int main() { addi(srcfile, ".byte 0x64, 0x67, 0x90"); } - addi(srcfile, "cmp rdx, 1"); - addi(srcfile, "je L_accum%="); - srcfile << " // Dump C\n"; + //addi(srcfile, "cmp rdx, 1"); + addi(srcfile, " if(rdx != 1) {", "cmp rdx, 1; je L_accum%="); + + srcfile << " // Dump C\n"; for (auto r = 0; r < vCtile.size(); r++) { for (auto c = 0; c < vCtile[r].size(); c++) { addi( srcfile, + " _mm256_storeu_ps((float*)((char*)r12 + " + to_string(32 * c) + "), " + vCtile[r][c] + ");", "vmovups YMMWORD PTR [r12 + " + to_string(32 * c) + "], " + vCtile[r][c], fixedC); } - addi(srcfile, "add r12, r13", fixedC); // move C ptr + if (r != vCtile.size() - 1) + addi(srcfile, " r12 = (float*)((char*)r12 + r13);", "add r12, r13", fixedC); // move C ptr } - addi(srcfile, "jmp L_done%="); + addi(srcfile, " } else {", "jmp L_done%="); - srcfile << "\n"; - addi(srcfile, "L_accum%=:"); - srcfile << " // Dump C with accumulate\n"; + //srcfile << "\n"; + //addi(srcfile, "L_accum%=:"); + srcfile << " // Dump C with accumulate\n"; string r_spare = (s.avx == 1) ? "ymm14" : "ymm15"; addi( srcfile, + " auto " + r_spare + " = _mm256_broadcast_ss((float*)r15);", "vbroadcastss " + r_spare + string(",DWORD PTR [r15]"), fixedC); // store out C @@ -326,18 +354,29 @@ int main() { case 1: addi( srcfile, + "not supported", string("vmulps ymm15, ") + r_spare + comma + "YMMWORD PTR [r12 + " + to_string(32 * c) + "]", fixedC); addi( srcfile, + "not supported", "vaddps " + vCtile[r][c] + "," + vCtile[r][c] + "," + "ymm15", fixedC); break; case 2: + //if (r == 0) { + addi( + srcfile, + ((r == 0) ? " auto r12_" + to_string(32 * c) : " r12_" + to_string(32 * c)) + " = _mm256_load_ps((float*)((char*)r12 + " + to_string(32 * c) + "));", + "vfmadd231ps " + vCtile[r][c] + "," + r_spare + "," + + "YMMWORD PTR [r12 + " + to_string(32 * c) + "]", + fixedC); + //} addi( srcfile, + " " + vCtile[r][c] + " = _mm256_fmadd_ps(r12_" + to_string(32 * c) + ", " + r_spare + ", " + vCtile[r][c] + ");", "vfmadd231ps " + vCtile[r][c] + "," + r_spare + "," + "YMMWORD PTR [r12 + " + to_string(32 * c) + "]", fixedC); @@ -347,46 +386,283 @@ int main() { } addi( srcfile, + " _mm256_storeu_ps((float*)((char*)r12 + " + to_string(32 * c) + "), " + vCtile[r][c] + ");", "vmovups YMMWORD PTR [r12 + " + to_string(32 * c) + "], " + vCtile[r][c], fixedC); } - addi(srcfile, "add r12, r13", fixedC); // move C ptr + if (r != vCtile.size() - 1) + addi(srcfile, " r12 = (float*)((char*)r12 + r13);", "add r12, r13", fixedC); // move C ptr } - srcfile << "\n"; - addi(srcfile, "L_done%=:"); + //srcfile << "\n"; + addi(srcfile, " }", "L_done%=:"); - srcfile << "\n // next outer iteration\n"; + srcfile << "\n // next outer iteration\n"; // C addi( srcfile, + " rcx = (float*)((char*)rcx + " + to_string(32 * ukernel_shape[k][1]) + ");", "add rcx, " + to_string(32 * ukernel_shape[k][1]), fixedC); - addi(srcfile, "mov r12, rcx", fixedC); + addi(srcfile, " r12 = rcx;", "mov r12, rcx", fixedC); // A - addi(srcfile, "mov r9, rax"); - - addi(srcfile, "inc rbx"); - addi(srcfile, "cmp rbx, rdi"); - addi(srcfile, "jl " + label2); - - // output - srcfile << " :\n"; - // input - srcfile << " : [gp] \"rm\"(gp)\n"; - - // clobbered - srcfile - << " : \"r8\",\n \"r9\",\n \"r10\",\n" - " \"r11\",\n \"r15\",\n \"r13\",\n" - " \"r14\",\n \"rax\",\n \"rcx\",\n" - " \"rdx\",\n \"rsi\",\n \"rdi\",\n" - " \"rbx\",\n \"r12\",\n" - " \"memory\");\n"; - srcfile << "}\n"; + addi(srcfile, " r9 = rax;", "mov r9, rax"); + + // move to top for looop + //addi(srcfile, "inc rbx"); + //addi(srcfile, "cmp rbx, rdi"); + //addi(srcfile, "jl " + label2); + addi(srcfile, "}", "inc rbx; cmp rbx, rdi; jl " + label2); + + //// output + //srcfile << " :\n"; + //// input + //srcfile << " : [gp] \"rm\"(gp)\n"; + + //// clobbered + //srcfile + // << " : \"r8\",\n \"r9\",\n \"r10\",\n" + // " \"r11\",\n \"r15\",\n \"r13\",\n" + // " \"r14\",\n \"rax\",\n \"rcx\",\n" + // " \"rdx\",\n \"rsi\",\n \"rdi\",\n" + // " \"rbx\",\n \"r12\",\n" + // " \"memory\");\n"; + srcfile << "}\n\n"; + } + +#else + fheader[k] = + "void __attribute__((noinline)) " + funcname[k] + fargs; + srcfile << fheader[k] << " {\n"; + + unsigned last_free_ymmreg = 0; + // produce register block of C + vector<vector<string>> vCtile(ukernel_shape[k][0]); + for (auto r = 0; r < ukernel_shape[k][0]; r++) + for (auto c = 0; c < ukernel_shape[k][1]; c++) { + vCtile[r].push_back("ymm" + to_string(last_free_ymmreg)); + last_free_ymmreg++; + } + assert(last_free_ymmreg <= 14); + + string vAtmp = "ymm" + to_string(last_free_ymmreg++); + // produce register block of B col + vector<string> vBcol(ukernel_shape[k][1]); + + for (auto c = 0; c < ukernel_shape[k][1]; c++) { + vBcol[c] = ("ymm" + to_string(last_free_ymmreg)); + last_free_ymmreg++; + } + + assert(last_free_ymmreg <= 16); + + srcfile << " asm volatile(\n"; + + srcfile << "#if !defined(__clang__)" + << "\n"; + addi(srcfile, "mov r14, %[gp]"); + srcfile << "#else\n"; + addi(srcfile, "mov %[gp], %%r14"); + addi(srcfile, ".intel_syntax noprefix"); + srcfile << "#endif\n"; + + srcfile << "\n // Copy parameters\n"; + srcfile << " // k\n"; + addi(srcfile, "mov r8, [r14 + 0]"); + srcfile << " // A\n"; + addi(srcfile, "mov r9, [r14 + 8]"); + srcfile << " // B\n"; + addi(srcfile, "mov r10, [r14 + 16]"); + srcfile << " // beta\n"; + addi(srcfile, "mov r15, [r14 + 24]"); + srcfile << " // accum\n"; + addi(srcfile, "mov rdx, [r14 + 32]"); + srcfile << " // C\n"; + addi(srcfile, "mov r12, [r14 + 40]"); + srcfile << " // ldc\n"; + addi(srcfile, "mov r13, [r14 + 48]"); + srcfile << " // b_block_cols\n"; + addi(srcfile, "mov rdi, [r14 + 56]"); + srcfile << " // b_block_size\n"; + addi(srcfile, "mov rsi, [r14 + 64]"); + srcfile << " // Make copies of A and C\n"; + addi(srcfile, "mov rax, r9"); + addi(srcfile, "mov rcx, r12"); + srcfile << "\n"; + + addi(srcfile, "mov rbx, 0"); + + string exitlabel = "L_exit%="; + string label2 = "loop_outter%="; + addi(srcfile, label2 + ":"); + addi(srcfile, "mov r14, 0"); + + // set all vCtile regs to zeros + for (auto r = 0; r < vCtile.size(); r++) { + for (auto c = 0; c < vCtile[r].size(); c++) { + addi( + srcfile, + "vxorps " + vCtile[r][c] + "," + vCtile[r][c] + "," + + vCtile[r][c]); + } + } + + // start marker + if (iaca) { + addi(srcfile, "mov ebx, 111"); + addi(srcfile, ".byte 0x64, 0x67, 0x90"); + } + + srcfile << "\n"; + + srcfile << "\n"; + string label = "loop_inner%="; + addi(srcfile, label + ":"); + srcfile << "\n"; + + for (int c = 0; c < vCtile[0].size(); c++) { + addi( + srcfile, + "vcvtph2ps " + vBcol[c] + ",XMMWORD PTR [r10 + " + + to_string(16 * c) + "]"); } + for (int r = 0; r < vCtile.size(); r++) { + addi( + srcfile, + "vbroadcastss " + vAtmp + ",DWORD PTR [r9+" + + to_string(4 * r) + "]"); + for (int c = 0; c < vCtile[0].size(); c++) { + addi( + srcfile, + "vfmadd231ps " + vCtile[r][c] + "," + vBcol[c] + "," + + vAtmp); + } + } + + addi( + srcfile, + "add r9," + to_string(4 * ukernel_shape[k][0]), + fixedA); // move A ptr + + addi( + srcfile, + "add r10," + to_string(16 * ukernel_shape[k][1]), + fixedA); // move A ptr + + addi(srcfile, "inc r14"); + addi(srcfile, "cmp r14, r8"); + addi(srcfile, "jl " + label); + + srcfile << "\n"; + + addi(srcfile, exitlabel + ":"); + + // addi(srcfile, "add r10, rsi"); + srcfile << "\n"; + + // end marker + if (iaca) { + addi(srcfile, "mov ebx, 222"); + addi(srcfile, ".byte 0x64, 0x67, 0x90"); + } + + addi(srcfile, "cmp rdx, 1"); + addi(srcfile, "je L_accum%="); + srcfile << " // Dump C\n"; + + for (auto r = 0; r < vCtile.size(); r++) { + for (auto c = 0; c < vCtile[r].size(); c++) { + addi( + srcfile, + "vmovups YMMWORD PTR [r12 + " + to_string(32 * c) + + "], " + vCtile[r][c], + fixedC); + } + addi(srcfile, "add r12, r13", fixedC); // move C ptr + } + addi(srcfile, "jmp L_done%="); + + srcfile << "\n"; + addi(srcfile, "L_accum%=:"); + srcfile << " // Dump C with accumulate\n"; + + string r_spare = (s.avx == 1) ? "ymm14" : "ymm15"; + addi( + srcfile, + "vbroadcastss " + r_spare + string(",DWORD PTR [r15]"), + fixedC); + // store out C + for (auto r = 0; r < vCtile.size(); r++) { + for (auto c = 0; c < vCtile[r].size(); c++) { + switch (s.avx) { + case 1: + addi( + srcfile, + string("vmulps ymm15, ") + r_spare + comma + + "YMMWORD PTR [r12 + " + to_string(32 * c) + "]", + fixedC); + addi( + srcfile, + "vaddps " + vCtile[r][c] + "," + vCtile[r][c] + "," + + "ymm15", + fixedC); + break; + case 2: + addi( + srcfile, + "vfmadd231ps " + vCtile[r][c] + "," + r_spare + "," + + "YMMWORD PTR [r12 + " + to_string(32 * c) + "]", + fixedC); + break; + default: + assert(0); + } + addi( + srcfile, + "vmovups YMMWORD PTR [r12 + " + to_string(32 * c) + + "], " + vCtile[r][c], + fixedC); + } + addi(srcfile, "add r12, r13", fixedC); // move C ptr + } + + srcfile << "\n"; + addi(srcfile, "L_done%=:"); + + srcfile << "\n // next outer iteration\n"; + // C + addi( + srcfile, + "add rcx, " + to_string(32 * ukernel_shape[k][1]), + fixedC); + addi(srcfile, "mov r12, rcx", fixedC); + // A + addi(srcfile, "mov r9, rax"); + + addi(srcfile, "inc rbx"); + addi(srcfile, "cmp rbx, rdi"); + addi(srcfile, "jl " + label2); + + // output + srcfile << " :\n"; + // input + srcfile << " : [gp] \"rm\"(gp)\n"; + + // clobbered + srcfile + << " : \"r8\",\n \"r9\",\n \"r10\",\n" + " \"r11\",\n \"r15\",\n \"r13\",\n" + " \"r14\",\n \"rax\",\n \"rcx\",\n" + " \"rdx\",\n \"rsi\",\n \"rdi\",\n" + " \"rbx\",\n \"r12\",\n" + " \"memory\");\n"; + srcfile << "}\n"; + } + +#endif + for (unsigned k = 0; k < ukernel_shape.size(); k++) { hdrfile << fheader[k] << ";\n"; } |