diff options
Diffstat (limited to 'src/collectives/device/all_reduce.h')
-rw-r--r-- | src/collectives/device/all_reduce.h | 12 |
1 files changed, 6 insertions, 6 deletions
diff --git a/src/collectives/device/all_reduce.h b/src/collectives/device/all_reduce.h index ea89a71..9b058cc 100644 --- a/src/collectives/device/all_reduce.h +++ b/src/collectives/device/all_reduce.h @@ -1,10 +1,10 @@ /************************************************************************* - * Copyright (c) 2015-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2015-2019, NVIDIA CORPORATION. All rights reserved. * * See LICENSE.txt for license information ************************************************************************/ -#include "core.h" +#include "devcomm.h" #include "primitives.h" #include "collectives.h" @@ -13,7 +13,7 @@ __device__ void ncclAllReduceRingKernel(struct CollectiveArgs* args) { const int tid = threadIdx.x; const int nthreads = blockDim.x - 1; const int bid = args->bid; - struct ncclComm* comm = args->comm; + struct ncclDevComm* comm = args->comm; struct ncclChannel* channel = comm->channels+blockIdx.x; struct ncclRing* ring = &channel->ring; const ssize_t size = args->N; @@ -87,7 +87,7 @@ __device__ void ncclAllReduceTreeKernel(struct CollectiveArgs* args) { const int tid = threadIdx.x; const int nthreads = blockDim.x - 1; const int bid = args->bid; - struct ncclComm* comm = args->comm; + struct ncclDevComm* comm = args->comm; struct ncclChannel* channel = comm->channels+blockIdx.x; struct ncclTree* tree = &channel->tree; const ssize_t size = args->N; @@ -139,7 +139,7 @@ __device__ void ncclAllReduceRingLLKernel(struct CollectiveArgs* args) { const int tid = threadIdx.x; const int bid = args->bid; const int nthreads = args->nThreads; - struct ncclComm* comm = args->comm; + struct ncclDevComm* comm = args->comm; struct ncclChannel* channel = comm->channels+blockIdx.x; struct ncclRing* ring = &channel->ring; @@ -214,7 +214,7 @@ __device__ void ncclAllReduceTreeLLKernel(struct CollectiveArgs* args) { const int tid = threadIdx.x; const int nthreads = args->nThreads; const int bid = args->bid; - struct ncclComm* comm = args->comm; + struct ncclDevComm* comm = args->comm; struct ncclChannel* channel = comm->channels+blockIdx.x; struct ncclTree* tree = &channel->tree; const ssize_t size = args->N; |