diff options
Diffstat (limited to 'src/primitives.h')
-rw-r--r-- | src/primitives.h | 206 |
1 files changed, 0 insertions, 206 deletions
diff --git a/src/primitives.h b/src/primitives.h deleted file mode 100644 index bcaeca8..0000000 --- a/src/primitives.h +++ /dev/null @@ -1,206 +0,0 @@ -/************************************************************************* - * Copyright (c) 2016, NVIDIA CORPORATION. All rights reserved. - * - * See LICENSE.txt for license information - ************************************************************************/ - -#ifndef PRIMITIVES_H_ -#define PRIMITIVES_H_ - -#include <type_traits> -#include "copy_kernel.h" // for FuncPassA -#include "reduce_kernel.h" // for reduction funcs - - -/* Defines primitive operations: Copy, Reduce, DoubleCopy, and ReduceCopy. - * - * In order to reduce the reptetion of template arguments, the operations - * are bundled as static methods of the Primitives class. - * - * Each primitive operation copies/reduces a contiguous buffer and syncs - * an optional set of flags against a sub-step counter. The sync value is - * based on the step parameter. Sync flags must be of type WaitFlag or - * PostFlag. The primitive routines wait for all WaitFlag args to attain - * at least a value of SUBSTEPS*(step-1)+substep+1 (i.e. completion of - * corresponding substep by previous step) before executing the transfer. - * After each substep is transfered, all PostFlag arguments get updated to - * the value SUBSTEPS*step+substep+1. - */ - - -class WaitFlag { - volatile int * const flag; - const int shift; - public: - __device__ __forceinline__ - WaitFlag(volatile int * const flag, const int shift) : flag(flag), shift(shift) { } - __device__ __forceinline__ - void wait(int val) { while (*flag < (val + shift)) /*SPIN*/; } -}; - - -class PostFlag { - volatile int * const flag; - const int shift; - public: - __device__ __forceinline__ - PostFlag(volatile int* const flag, const int shift) : flag(flag), shift(shift) { } - __device__ __forceinline__ - void post(int val) { *flag = (val + shift); } -}; - - -// Helper to check if any argument is of type T. -// e.g. AnyAre<WaitFlag>(Flag1, Flag2, ...) -template<typename T> __device__ __forceinline__ -bool AnyAre() { return false; } - -template<typename T, typename FIRST_T, typename... TAIL_Ts> -__device__ __forceinline__ -bool AnyAre(FIRST_T first, TAIL_Ts... tail) { - return std::is_same<T, FIRST_T>::value || AnyAre<T>(tail...); -} - - -// Wait on all WaitFlags, ignore PostFlags -__device__ __forceinline__ -void WaitOnFlags(int val) { } - -template <typename... TAIL_Ts> __device__ __forceinline__ -void WaitOnFlags(int val, WaitFlag flag, TAIL_Ts... tail) { - flag.wait(val); - WaitOnFlags(val, tail...); -} - -template <typename... TAIL_Ts> __device__ __forceinline__ -void WaitOnFlags(int val, PostFlag, TAIL_Ts... tail) { - WaitOnFlags(val, tail...); -} - - -// Post all PostFlags, ingnore WaitFlags -__device__ __forceinline__ -void PostToFlags(int val) { } - -template <typename... TAIL_Ts> __device__ __forceinline__ -void PostToFlags(int val, WaitFlag flag, TAIL_Ts... tail) { - PostToFlags(val, tail...); -} - -template <typename... TAIL_Ts> __device__ __forceinline__ -void PostToFlags(int val, PostFlag flag, TAIL_Ts... tail) { - flag.post(val); - PostToFlags(val, tail...); -} - - -// Create pointer arithmetic syntax that doesn't break for nullptr_t -template <typename Tptr> __device__ __forceinline__ -Tptr ptradd(Tptr ptr, int i) { - return ptr + i; -} - -__device__ __forceinline__ -std::nullptr_t ptradd(std::nullptr_t ptr, int i) { - return nullptr; -} - - -// Implementation of primitive types -template <int THREADS, int UNROLL, int SUBSTEPS, typename T, typename REDOP=FuncSum<T> > -class Primitives { - private: - template <typename SRC2_T, // either T* or nullptr_t - typename DST2_T, // either T* or nullptr_t - typename... SYNC_Ts> // either WaitFunc or PostFunc - static __device__ __forceinline__ void - GenericOp(const T* src1, - const SRC2_T src2, - T* dst1, - DST2_T dst2, - int len, int maxoffset, int step, SYNC_Ts... flags) { - - enum { noSrc2 = std::is_same<SRC2_T, std::nullptr_t>::value }; - enum { noDst2 = std::is_same<DST2_T, std::nullptr_t>::value }; - static_assert(noSrc2 || std::is_same<SRC2_T, const T*>::value, - "src2 must be of type T* or nullptr_t"); - static_assert(noDst2 || std::is_same<DST2_T, T*>::value, - "dst2 must be of type T* or nullptr_t"); - - using OpType = typename std::conditional<noSrc2, FuncPassA<T>, REDOP>::type; - - if (threadIdx.x < THREADS) { - int sliceSize = len / SUBSTEPS; - int sliceOffset = 0; - #pragma unroll 1 - for (int sub=0; sub<SUBSTEPS; ++sub) { - if (AnyAre<WaitFlag>(flags...)) { - if (threadIdx.x == 0) { - WaitOnFlags(SUBSTEPS*step + sub + 1, flags...); - } - asm volatile ("bar.sync 1, %0;" :: "r"(THREADS)); - } - ReduceOrCopy - < - UNROLL, - THREADS, - OpType, - T, - !std::is_same<DST2_T, std::nullptr_t>::value, // HAS_DEST1 - !std::is_same<SRC2_T, std::nullptr_t>::value // HAS_SRC1 - > - ( - threadIdx.x, - ptradd(dst1, sliceOffset), - ptradd(dst2, sliceOffset), - ptradd(src1, sliceOffset), - ptradd(src2, sliceOffset), - min(sliceSize, maxoffset-sliceOffset) - ); - if (AnyAre<PostFlag>(flags...)) { - __syncthreads(); - } - sliceOffset += sliceSize; - } - } else { - for(int sub=0; sub<SUBSTEPS; ++sub) { - if (AnyAre<PostFlag>(flags...)) { - __syncthreads(); - __threadfence_system(); - PostToFlags(SUBSTEPS*step + sub + 1, flags...); - } - } - } - } - - public: - template <typename... SYNC_Ts> - static __device__ __forceinline__ void - Copy(const T* src, T* dst, - int len, int maxOffset, int step, SYNC_Ts... flags) { - GenericOp(src, nullptr, dst, nullptr, len, maxOffset, step, flags...); - } - - template <typename... SYNC_Ts> - static __device__ __forceinline__ void - DoubleCopy(const T* src, T* dst1, T* dst2, - int len, int maxOffset, int step, SYNC_Ts... flags) { - GenericOp(src, nullptr, dst1, dst2, len, maxOffset, step, flags...); - } - - template <typename... SYNC_Ts> - static __device__ __forceinline__ void - Reduce(const T* src1, const T* src2, T* dst, - int len, int maxOffset, int step, SYNC_Ts... flags) { - GenericOp(src1, src2, dst, nullptr, len, maxOffset, step, flags...); - } - - template <typename... SYNC_Ts> - static __device__ __forceinline__ void - ReduceCopy(const T* src1, const T* src2, T* dst1, T* dst2, - int len, int maxOffset, int step, SYNC_Ts... flags) { - GenericOp(src1, src2, dst1, dst2, len, maxOffset, step, flags...); - } -}; - -#endif // end include guard |