Welcome to mirror list, hosted at ThFree Co, Russian Federation.

Threshold.cu « generic « THCUNN « lib - github.com/torch/cunn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: 0b7b79e42aed5b0a72f04ef2bf6d5206e5919c2f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
#ifndef THC_GENERIC_FILE
#define THC_GENERIC_FILE "generic/Threshold.cu"
#else

#include "../common.h"

/*
 * Forward pass of the Threshold module.
 *
 * The elementwise rule is delegated to the ThresholdUpdateOutput /
 * ThresholdUpdateOutputIP functors, which are declared outside this file
 * (presumably x > threshold ? x : val -- confirm against their definition).
 *
 * threshold_, val_ : accreal scalars, narrowed to `real` before use.
 * inplace          : when true, `input` itself is rewritten and `output` is
 *                    then set to it via THCTensor_(set); when false,
 *                    `output` is resized to match `input` and written
 *                    out-of-place, leaving `input` untouched.
 *
 * Errors from the kernel launch are surfaced through
 * THCudaCheck(cudaGetLastError()).
 */
void THNN_(Threshold_updateOutput)(
           THCState *state,
           THCTensor *input,
           THCTensor *output,
           accreal threshold_,
           accreal val_,
           bool inplace)
{
  const real thresh      = ScalarConvert<accreal, real>::to(threshold_);
  const real replacement = ScalarConvert<accreal, real>::to(val_);
  THCUNN_assertSameGPU(state, 2, input, output);

  if (!inplace) {
    // Out-of-place: give output the shape of input, then fill it from input.
    THCTensor_(resizeAs)(state, output, input);
    THC_pointwiseApply2(state, output, input,
                        ThresholdUpdateOutput<real>(thresh, replacement));
  } else {
    // In-place: rewrite input's elements, then point output at input.
    THC_pointwiseApply1(state, input,
                        ThresholdUpdateOutputIP<real>(thresh, replacement));
    THCTensor_(set)(state, output, input);
  }

  // Report any error raised by the launch above.
  THCudaCheck(cudaGetLastError());
}

/*
 * Backward pass of the Threshold module.
 *
 * The gradient is masked elementwise by comparing `input` against
 * `threshold_`, using the ThresholdUpdateGradInput /
 * ThresholdUpdateGradInputIP functors (declared outside this file).
 * `val_` is accepted for signature parity with Threshold_updateOutput but
 * does not take part in the gradient computation.
 *
 * input      : must have the same number of elements as gradOutput
 *              (checked by THCUNN_check_nElement).
 * inplace    : when true, `gradOutput` is rewritten and `gradInput` is then
 *              set to it via THCTensor_(set); when false, `gradInput` is
 *              resized to match `input` and written out-of-place.
 *
 * NOTE(review): the in-place forward pass overwrites `input` with the
 * thresholded values, so in the in-place backward path `input` presumably
 * holds the forward output. The mask would then only be correct when
 * val <= threshold -- confirm that callers enforce this.
 */
void THNN_(Threshold_updateGradInput)(
           THCState *state,
           THCTensor *input,
           THCTensor *gradOutput,
           THCTensor *gradInput,
           accreal threshold_,
           accreal val_,
           bool inplace)
{
  real threshold = ScalarConvert<accreal, real>::to(threshold_);
  // val_ does not affect the gradient; cast to void so the unused parameter
  // does not trigger compiler warnings (no dead conversion is performed).
  (void) val_;
  THCUNN_check_nElement(state, input, gradOutput);
  THCUNN_assertSameGPU(state, 3, input, gradInput, gradOutput);

  if (inplace)
  {
    // Mask gradOutput in place, then point gradInput at it.
    THC_pointwiseApply2(state, gradOutput, input,
      ThresholdUpdateGradInputIP<real>(threshold)
    );
    THCTensor_(set)(state, gradInput, gradOutput);
  }
  else
  {
    // Out-of-place: size gradInput like input, then fill it from
    // input and gradOutput.
    THCTensor_(resizeAs)(state, gradInput, input);
    THC_pointwiseApply3(state, gradInput, input, gradOutput,
       ThresholdUpdateGradInput<real>(threshold)
    );
  }

  // Report any error raised by the launch above.
  THCudaCheck(cudaGetLastError());
}

#endif