Welcome to mirror list, hosted at ThFree Co, Russian Federation.

ClassNLLCriterion.cu « generic « THCUNN « lib - github.com/torch/cunn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: 62925a8fa124848489abba69d033e5cfe4aaa57c (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
#ifndef THC_GENERIC_FILE
#define THC_GENERIC_FILE "generic/ClassNLLCriterion.cu"
#else

// Host-side entry point that validates its arguments and launches the
// ClassNLLCriterion forward kernel on the current THC stream.
//
// Parameters:
//   state        - THC state; supplies the stream used for the launch.
//   input        - tensor with at most 2 dimensions; the size of its last
//                  dimension is taken as the number of classes.
//   target       - index tensor with at most 1 dimension.
//   output       - tensor whose data pointer is handed to the kernel.
//   sizeAverage  - forwarded unchanged to the kernel.
//   weights      - optional (may be NULL); when given it must contain exactly
//                  one element per class.
//   total_weight - tensor whose data pointer is handed to the kernel.
//
// Raises a TH error when target has more than one dimension, when input has
// more than two dimensions, or when the weight count differs from the number
// of classes.
void THNN_(ClassNLLCriterion_updateOutput)(
           THCState *state,
           THCTensor *input,
           THCIndexTensor *target,
           THCTensor *output,
           bool sizeAverage,
           THCTensor *weights,
           THCTensor *total_weight) {
  if (THCIndexTensor_(nDimension)(state, target) > 1) {
    THError("multi-target not supported");
  }

  const int n_dims = THCTensor_(nDimension)(state, input);
  const int n_classes = THCTensor_(size)(state, input, n_dims - 1);
  const bool has_weights = (weights != NULL);

  // The argument count passed to the check depends on whether the optional
  // weights tensor takes part.
  if (has_weights) {
    THCUNN_assertSameGPU_generic(
      state, 5, input, target, weights, output, total_weight
    );
  } else {
    THCUNN_assertSameGPU_generic(
      state, 4, input, target, output, total_weight
    );
  }

  THArgCheck(n_dims <= 2, 2, "vector or matrix expected");
  if (has_weights && THCTensor_(nElement)(state, weights) != n_classes) {
    THError("weight tensor should be defined either for all or no classes");
  }

  // Contiguous versions of the read-only operands; each one is released at
  // the end of this function.
  THCTensor *input_c = THCTensor_(newContiguous)(state, input);
  THCTensor *weights_c =
      has_weights ? THCTensor_(newContiguous)(state, weights) : NULL;
  THCIndexTensor *target_c = THCIndexTensor_(newContiguous)(state, target);

  real *input_data = THCTensor_(data)(state, input_c);
  real *weights_data = has_weights ? THCTensor_(data)(state, weights_c) : NULL;
  THCIndex_t *target_data = THCIndexTensor_(data)(state, target_c);
  real *output_data = THCTensor_(data)(state, output);
  real *total_weight_data = THCTensor_(data)(state, total_weight);

  switch (n_dims) {
    case 1:
      // Vector input: launched with a single thread in a single block.
      cunn_ClassNLLCriterion_updateOutput_kernel1<real>
        <<<1, 1, 0, THCState_getCurrentStream(state)>>>(
          output_data,
          total_weight_data,
          input_data,
          target_data,
          weights_data,
          sizeAverage,
          n_classes
      );
      break;

    case 2:
      // Matrix input: launched as one block of NTHREADS threads.
      cunn_ClassNLLCriterion_updateOutput_kernel<real, accreal>
        <<<1, NTHREADS, 0, THCState_getCurrentStream(state)>>>(
          output_data,
          total_weight_data,
          input_data,
          target_data,
          weights_data,
          sizeAverage,
          THCTensor_(size)(state, input_c, 0),
          THCTensor_(size)(state, input_c, 1),
          n_classes
      );
      break;

    default:
      break;
  }
  // Surface launch-configuration errors from the kernel launch above.
  THCudaCheck(cudaGetLastError());

  if (has_weights) {
    THCTensor_(free)(state, weights_c);
  }
  THCIndexTensor_(free)(state, target_c);
  THCTensor_(free)(state, input_c);
}

// Host-side entry point that validates its arguments and launches the
// ClassNLLCriterion backward kernel on the current THC stream.
//
// Parameters:
//   state        - THC state; supplies the stream used for the launch.
//   input        - tensor with at most 2 dimensions; only its sizes are used
//                  here (the size of the last dimension is the class count).
//   target       - index tensor with at most 1 dimension.
//   gradInput    - must be contiguous; its data pointer is handed to the
//                  kernel. NOTE(review): it is neither resized nor cleared
//                  here - presumably the caller prepares it; confirm.
//   sizeAverage  - forwarded unchanged to the kernel.
//   weights      - optional (may be NULL); when given it must contain exactly
//                  one element per class.
//   total_weight - tensor whose data pointer is handed to the kernel.
//
// Raises a TH error when target has more than one dimension, when gradInput
// is not contiguous, when input has more than two dimensions, or when the
// weight count differs from the number of classes.
void THNN_(ClassNLLCriterion_updateGradInput)(
           THCState *state,
           THCTensor *input,
           THCIndexTensor *target,
           THCTensor *gradInput,
           bool sizeAverage,
           THCTensor *weights,
           THCTensor *total_weight) {
  if (THCIndexTensor_(nDimension)(state, target) > 1) {
    THError("multi-target not supported");
  }

  int n_dims = THCTensor_(nDimension)(state, input);
  int n_classes = THCTensor_(size)(state, input, n_dims - 1);

  THArgCheck(THCTensor_(isContiguous)(state, gradInput), 4, "gradInput must be contiguous");

  // Use the same device-placement check as updateOutput in this file, so both
  // entry points validate their tensors in the same way.
  if (weights) {
    THCUNN_assertSameGPU_generic(
      state, 5, weights, input, target, gradInput, total_weight
    );
  }
  else {
    THCUNN_assertSameGPU_generic(
      state, 4, input, target, gradInput, total_weight
    );
  }

  THArgCheck(n_dims <= 2, 2, "vector or matrix expected");
  if (weights && THCTensor_(nElement)(state, weights) != n_classes) {
    THError("weight tensor should be defined either for all or no classes");
  }

  // Contiguous versions of the read-only operands; released at the end of
  // this function.
  weights = weights ? THCTensor_(newContiguous)(state, weights) : NULL;
  target = THCIndexTensor_(newContiguous)(state, target);

  real *weights_data = weights ? THCTensor_(data)(state, weights) : NULL;
  real *gradInput_data = THCTensor_(data)(state, gradInput);
  THCIndex_t  *target_data = THCIndexTensor_(data)(state, target);
  real *total_weight_data = THCTensor_(data)(state, total_weight);

  if (n_dims == 1) {
    // Vector input: launched with a single thread in a single block.
    cunn_ClassNLLCriterion_updateGradInput_kernel1<real>
      <<<1, 1, 0, THCState_getCurrentStream(state)>>>(
        gradInput_data,
        weights_data,
        target_data,
        total_weight_data,
        sizeAverage,
        n_classes
    );
  } else {
    // Matrix input: launched as one block of NTHREADS threads.
    cunn_ClassNLLCriterion_updateGradInput_kernel<real>
      <<<1, NTHREADS, 0, THCState_getCurrentStream(state)>>>(
        gradInput_data,
        target_data,
        weights_data,
        total_weight_data,
        sizeAverage,
        THCTensor_(size)(state, input, 0),
        THCTensor_(size)(state, input, 1),
        n_classes
    );
  }
  // Surface launch-configuration errors from the kernel launch above.
  THCudaCheck(cudaGetLastError());

  if (weights) {
    THCTensor_(free)(state, weights);
  }
  THCIndexTensor_(free)(state, target);
}

#endif