// src/tensors/gpu/dropout.cu -- from github.com/marian-nmt/marian.git
#include <cuda.h>
#include <curand.h>
#include <stdio.h>
#include <stdlib.h>

#include <algorithm>  // std::min

// clang-format off
#include "tensors/tensor_operators.h"
#include "tensors/gpu/backend.h"
// clang-format on

#define CUDA_CALL(x)                                  \
  do {                                                \
    if((x) != cudaSuccess) {                          \
      printf("Error at %s:%d\n", __FILE__, __LINE__); \
      exit(1);                                        \
    }                                                 \
  } while(0)

#define CURAND_CALL(x)                                \
  do {                                                \
    if((x) != CURAND_STATUS_SUCCESS) {                \
      printf("Error at %s:%d\n", __FILE__, __LINE__); \
      exit(1);                                        \
    }                                                 \
  } while(0)

namespace marian {
namespace gpu {

// Turns uniform random draws in (0, 1] into an inverted-dropout mask:
// each element becomes 1/p with probability p (the keep probability) and
// 0 otherwise, so expected activations are preserved at training time.
__global__ void gScale(float* data, int n, float p) {
  int index = threadIdx.x + blockIdx.x * blockDim.x;

  // grid-stride loop over all n elements
  while(index < n) {
    data[index] = (data[index] < p) / p;
    index += gridDim.x * blockDim.x;
  }
}

// Fills `tensor` in place with a dropout mask: entries are 1/(1-p) with
// probability 1-p and 0 with probability p, where p is the dropout rate.
void Dropout(Tensor tensor, float p) {
  auto gpuBackend
      = std::static_pointer_cast<gpu::Backend>(tensor->getBackend());
  curandGenerator_t gen = gpuBackend->getCurandGenerator();
  int n = tensor->size();

  // Fill the tensor with uniform random numbers in (0, 1], then threshold
  // and scale them on the device.
  CURAND_CALL(curandGenerateUniform(gen, tensor->data(), n));

  int numThreads = std::min(n, 512);
  int numBlocks = n / numThreads + (n % numThreads != 0);  // ceiling division

  gScale<<<numBlocks, numThreads>>>(tensor->data(), n, 1.f - p);
}
}  // namespace gpu
}  // namespace marian
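
// Standalone usage sketch (illustration only, not part of marian): the same
// mask generation can be exercised with a raw cuRAND generator instead of the
// backend-owned one used in Dropout() above. Everything below is a hedged
// example; names such as `d_mask` and the chosen sizes are hypothetical.
//
//   int n = 1024;
//   float dropRate = 0.1f;
//   float* d_mask;
//   CUDA_CALL(cudaMalloc(&d_mask, n * sizeof(float)));
//
//   curandGenerator_t gen;
//   CURAND_CALL(curandCreateGenerator(&gen, CURAND_RNG_PSEUDO_DEFAULT));
//   CURAND_CALL(curandSetPseudoRandomGeneratorSeed(gen, 1234ULL));
//   CURAND_CALL(curandGenerateUniform(gen, d_mask, n));
//
//   int numThreads = std::min(n, 512);
//   int numBlocks = n / numThreads + (n % numThreads != 0);
//   marian::gpu::gScale<<<numBlocks, numThreads>>>(d_mask, n, 1.f - dropRate);
//
//   CURAND_CALL(curandDestroyGenerator(gen));
//   CUDA_CALL(cudaFree(d_mask));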