Welcome to mirror list, hosted at ThFree Co, Russian Federation.

FeatureLPPooling.c « generic « THNN « lib - github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: 25a58dbe6ac3d4e0597d41a6eb202d1bf568c232 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
#ifndef TH_GENERIC_FILE
#define TH_GENERIC_FILE "generic/FeatureLPPooling.c"
#else

#ifndef FEATURE_LP_DEFS
#define FEATURE_LP_DEFS

/*
 * Sizes and strides of a tensor viewed as a 4-d
 * [batch dim][feature dim][opt dim 1][opt dim 2] layout. Strides are in
 * elements, and axes the tensor lacks are described as size 1 / stride 1.
 */
typedef struct {
  size_t size[4];
  size_t stride[4];
} FeatureLPPoolingSizes;

/*
 * Returns the element offset of index (batch, feature, opt1, opt2) in the
 * tensor described by `s`: the stride-weighted sum of the four indices.
 *
 * Declared `static inline` so the function has internal linkage. A plain
 * C99 `inline` definition with no external definition elsewhere leaves an
 * unresolved symbol whenever the compiler decides not to inline a call.
 */
static inline size_t flpGetOffset(const FeatureLPPoolingSizes* s,
                                  size_t batch,
                                  size_t feature,
                                  size_t opt1,
                                  size_t opt2) {
  return s->stride[0] * batch +
    s->stride[1] * feature +
    s->stride[2] * opt1 +
    s->stride[3] * opt2;
}

/*
 * Number of window positions when a window of `width` elements slides with
 * step `stride` over `inputSize` elements (no padding):
 *   (inputSize - width) / stride + 1.
 * Returns 0 when the window does not fit (inputSize < width) rather than
 * wrapping around in unsigned arithmetic. `stride` must be non-zero.
 *
 * Declared `static inline` so the function has internal linkage. A plain
 * C99 `inline` definition alone provides no external definition to link
 * against when a call is not inlined.
 */
static inline size_t flpOutputSize(size_t inputSize,
                                   size_t width,
                                   size_t stride) {
  if (inputSize < width) {
    return 0;
  }
  return ((inputSize - width) / stride) + 1;
}

#endif // FEATURE_LP_DEFS

/*
 * Describes tensor `t` as a 4-d [batch][feature][opt dim 1][opt dim 2]
 * layout. The tensor's dimensions are copied in order, starting at the
 * batch axis in batch mode and at the feature axis otherwise. Axes the
 * tensor does not have keep size 1 and stride 1.
 *
 * Asserts that a 1-d tensor is not in batch mode and that a 4-d tensor is.
 * A tensor of any other rank (outside 1-4) gets the all-ones description.
 */
FeatureLPPoolingSizes
THNN_(FeatureLPPooling_upcastCPU)(THTensor* t, bool batchMode) {
  const int rank = THTensor_(nDimension)(t);

  FeatureLPPoolingSizes desc;
  for (int axis = 0; axis < 4; ++axis) {
    desc.size[axis] = 1;
    desc.stride[axis] = 1;
  }

  switch (rank) {
    case 1:
      THAssert(!batchMode);
      break;
    case 4:
      THAssert(batchMode);
      break;
    case 2:
    case 3:
      break;
    default:
      // Unsupported rank: leave the all-ones description.
      return desc;
  }

  // First 4-d axis that receives a tensor dimension: 0 (batch) or 1 (feature).
  // Rank 1 always maps onto the feature axis; rank 4 always fills all axes.
  const int first = (rank == 4 || (rank != 1 && batchMode)) ? 0 : 1;

  for (int d = 0; d < rank; ++d) {
    desc.size[first + d] = THTensor_(size)(t, d);
    desc.stride[first + d] = THTensor_(stride)(t, d);
  }

  return desc;
}

/*
 * Resizes `toResize` to the pooled-output shape for `input`.
 *
 * The result has the same dimensionality as `input`. Every dimension is
 * copied except the feature dimension (dim 1 in batch mode, dim 0
 * otherwise), which becomes flpOutputSize(featureSize, width, stride).
 *
 * Asserts that `input` has 1-4 dimensions: more than 1 in batch mode,
 * fewer than 4 otherwise.
 */
void
THNN_(FeatureLPPooling_resizeForOutputCPU)(THTensor* toResize,
                                           THTensor* input,
                                           bool batchMode,
                                           int width,
                                           int stride) {
  int inputDim = THTensor_(nDimension)(input);
  THAssert(inputDim >= 1 && inputDim <= 4);

  if (batchMode) {
    THAssert(inputDim > 1);
  } else {
    THAssert(inputDim < 4);
  }

  // Only the feature dimension is pooled. The output size is computed on
  // that dimension alone. Previously it was also computed on dim 0, which in
  // batch mode is the batch size and may be smaller than `width`.
  const int featureDim = batchMode ? 1 : 0;

  long sizes[4] = { 0, 0, 0, 0 };
  for (int i = 0; i < inputDim && i < 4; ++i) {
    sizes[i] = THTensor_(size)(input, i);
  }
  sizes[featureDim] =
    (long) flpOutputSize(sizes[featureDim], width, stride);

  switch (inputDim) {
    case 1:
      THTensor_(resize1d)(toResize, sizes[0]);
      break;
    case 2:
      THTensor_(resize2d)(toResize, sizes[0], sizes[1]);
      break;
    case 3:
      THTensor_(resize3d)(toResize, sizes[0], sizes[1], sizes[2]);
      break;
    case 4:
      THTensor_(resize4d)(toResize, sizes[0], sizes[1], sizes[2], sizes[3]);
      break;
    default:
      // No other ranks are handled.
      break;
  }
}

/*
 * Resizes `toResize` so its dimensionality and per-dimension sizes equal
 * those of `src`. Asserts that `src` has between 1 and 4 dimensions.
 */
void
THNN_(FeatureLPPooling_resizeCPU)(THTensor* toResize,
                                  THTensor* src) {
  const int nDim = THTensor_(nDimension)(src);
  THAssert(nDim >= 1 && nDim <= 4);

  switch (nDim) {
    case 1:
      THTensor_(resize1d)(toResize, THTensor_(size)(src, 0));
      break;
    case 2:
      THTensor_(resize2d)(toResize,
                          THTensor_(size)(src, 0),
                          THTensor_(size)(src, 1));
      break;
    case 3:
      THTensor_(resize3d)(toResize,
                          THTensor_(size)(src, 0),
                          THTensor_(size)(src, 1),
                          THTensor_(size)(src, 2));
      break;
    case 4:
      THTensor_(resize4d)(toResize,
                          THTensor_(size)(src, 0),
                          THTensor_(size)(src, 1),
                          THTensor_(size)(src, 2),
                          THTensor_(size)(src, 3));
      break;
    default:
      // No other ranks are handled.
      break;
  }
}

/*
 * Forward pass of feature LP pooling.
 *
 * For every batch/opt position, output feature j is
 *   (sum_{i = 0}^{width - 1} x[j * stride + i]^power)^(1 / power),
 * where x runs along the feature dimension of `input` (dim 1 in batch
 * mode, dim 0 otherwise). `output` is resized to the pooled shape.
 *
 * Argument checks, with the THArgCheck index in parentheses:
 *   input rank (2), width in [2, 16] (5), stride in [1, 4] (6),
 *   feature size >= width (2).
 * NOTE(review): `power` is not validated; power == 0 gives non-finite results.
 */
void
THNN_(FeatureLPPooling_updateOutput)(
  THNNState *state,
  THTensor *input,
  THTensor *output,
  accreal power,
  int width,
  int stride,
  bool batchMode) {
  int inputDim = THTensor_(nDimension)(input);

  if (batchMode) {
    THArgCheck(inputDim >= 2 && inputDim <= 4, 2,
               "input must be 2-4 dimensions for batch mode");
  } else {
    THArgCheck(inputDim >= 1 && inputDim <= 3, 2,
               "input must be 1-3 dimensions for non-batch mode");
  }

  // Validate width and stride before using them. `width` is compared with an
  // unsigned size below, where a negative value would wrap to a huge one.
  THArgCheck(width >= 2 && width <= 16, 5,
             "width must be between 2 - 16");

  THArgCheck(stride >= 1 && stride <= 4, 6,
             "stride must be between 1 - 4");

  const size_t w = (size_t) width;
  const size_t s = (size_t) stride;

  FeatureLPPoolingSizes inputDesc =
    THNN_(FeatureLPPooling_upcastCPU)(input, batchMode);

  // Make sure the feature dimension is properly sized (argument 2 is `input`)
  THArgCheck(inputDesc.size[1] >= w, 2,
             "input: feature dimension must be >= width");

  // Resize output
  THNN_(FeatureLPPooling_resizeForOutputCPU)(
    output, input, batchMode, width, stride);

  FeatureLPPoolingSizes outputDesc =
    THNN_(FeatureLPPooling_upcastCPU)(output, batchMode);

  real* inputP = THTensor_(data)(input);
  real* outputP = THTensor_(data)(output);

  // Exponent of the final root, hoisted out of the loops.
  const accreal invPower = (accreal) 1 / power;

  // OpenMP 2.0 (e.g. MSVC) requires a signed index on the parallel loop.
  const long nBatch = (long) inputDesc.size[0];

#pragma omp parallel for
  for (long b = 0; b < nBatch; ++b) {
    const size_t batch = (size_t) b;
    for (size_t opt1 = 0; opt1 < inputDesc.size[2]; ++opt1) {
      for (size_t opt2 = 0; opt2 < inputDesc.size[3]; ++opt2) {
        for (size_t outputFeature = 0;
             outputFeature < outputDesc.size[1]; ++outputFeature) {

          accreal v = (accreal) 0;
          for (size_t i = 0; i < w; ++i) {
            size_t inputFeature = outputFeature * s + i;
            // Defensive: the output size guarantees the window fits.
            if (inputFeature >= inputDesc.size[1]) {
              break;
            }

            v +=
              pow(inputP[flpGetOffset(&inputDesc,
                                      batch,
                                      inputFeature,
                                      opt1,
                                      opt2)], power);
          }

          outputP[flpGetOffset(&outputDesc, batch, outputFeature, opt1, opt2)] =
            pow(v, invPower);
        }
      }
    }
  }
}

/*
 * Backward pass of feature LP pooling.
 *
 * With f = (sum_{i in window} x_i^p)^(1/p) for one output position, the
 * partial derivative with respect to an input x_i in that window is
 * (x_i / f)^(p - 1). Each term is scaled by the corresponding gradOutput
 * and accumulated into gradInput, since windows overlap when
 * stride < width. Output positions where f == 0 are skipped.
 *
 * gradInput is resized to `input`'s shape and zeroed before accumulation.
 * `output` is expected to be the forward result for the same input, power,
 * width, stride and batchMode. Only shapes are checked here.
 */
void
THNN_(FeatureLPPooling_updateGradInput)(
  THNNState *state,
  THTensor* gradOutput,
  THTensor* input,
  THTensor* output,
  THTensor* gradInput,
  accreal power,
  int width,
  int stride,
  bool batchMode) {
  int inputDim = THTensor_(nDimension)(input);

  if (batchMode) {
    THArgCheck(inputDim >= 2 && inputDim <= 4, 3,
               "input must be 2-4 dimensions for batch mode");
  } else {
    THArgCheck(inputDim >= 1 && inputDim <= 3, 3,
               "input must be 1-3 dimensions for non-batch mode");
  }

  // Validate width and stride before using them. `width` is compared with
  // unsigned sizes below, where a negative value would wrap to a huge one.
  THArgCheck(width >= 2 && width <= 16, 7,
             "width must be between 2 - 16");

  THArgCheck(stride >= 1 && stride <= 4, 8,
             "stride must be between 1 - 4");

  const size_t w = (size_t) width;
  const size_t s = (size_t) stride;

  FeatureLPPoolingSizes inputDesc =
    THNN_(FeatureLPPooling_upcastCPU)(input, batchMode);
  FeatureLPPoolingSizes gradOutputDesc =
    THNN_(FeatureLPPooling_upcastCPU)(gradOutput, batchMode);
  FeatureLPPoolingSizes outputDesc =
    THNN_(FeatureLPPooling_upcastCPU)(output, batchMode);

  // Make sure the feature dimension is properly sized
  THArgCheck(inputDesc.size[1] >= w, 3,
             "input: feature dimension must be >= width");

  for (int i = 0; i < 4; ++i) {
    THAssertMsg(outputDesc.size[i] == gradOutputDesc.size[i],
                "output and gradOutput sizes do not match");
  }

  // The loops below index output/gradOutput with input's batch and opt
  // extents. Those must match, or the accesses would go out of bounds.
  THArgCheck(outputDesc.size[0] == inputDesc.size[0] &&
             outputDesc.size[2] == inputDesc.size[2] &&
             outputDesc.size[3] == inputDesc.size[3], 4,
             "input and output sizes do not match");

  // Make sure that the input sizes produce the output sizes
  THArgCheck(flpOutputSize(inputDesc.size[1], w, s) ==
             outputDesc.size[1], 3,
             "input and output sizes do not match with respect to "
             "width and stride");

  // Resize `gradInput` based on `input`, then zero it for accumulation
  THNN_(FeatureLPPooling_resizeCPU)(gradInput, input);
  THTensor_(zero)(gradInput);

  FeatureLPPoolingSizes gradInputDesc =
    THNN_(FeatureLPPooling_upcastCPU)(gradInput, batchMode);

  real* gradOutputP = THTensor_(data)(gradOutput);
  real* gradInputP = THTensor_(data)(gradInput);
  real* outputP = THTensor_(data)(output);
  real* inputP = THTensor_(data)(input);

  // Exponent of the derivative term, hoisted out of the loops.
  const accreal gradPower = power - (accreal) 1;

  // OpenMP 2.0 (e.g. MSVC) requires a signed index on the parallel loop.
  const long nBatch = (long) inputDesc.size[0];

#pragma omp parallel for
  for (long b = 0; b < nBatch; ++b) {
    const size_t batch = (size_t) b;
    for (size_t opt1 = 0; opt1 < inputDesc.size[2]; ++opt1) {
      for (size_t opt2 = 0; opt2 < inputDesc.size[3]; ++opt2) {
        for (size_t outputFeature = 0;
             outputFeature < outputDesc.size[1]; ++outputFeature) {

          // Load output (f(x_is)). It is possible that this is zero, in
          // which case we'll ignore this point.
          real outputV =
            outputP[
              flpGetOffset(&outputDesc, batch, outputFeature, opt1, opt2)];

          if (outputV == (real) 0) {
            continue;
          }

          // Same gradOutput element for every input in this window.
          real gradOutputV =
            gradOutputP[
              flpGetOffset(&gradOutputDesc, batch, outputFeature, opt1, opt2)];

          for (size_t i = 0; i < w; ++i) {
            size_t inputFeature = outputFeature * s + i;
            THAssert(inputFeature < inputDesc.size[1]);

            real inputV =
              inputP[
                flpGetOffset(&inputDesc, batch, inputFeature, opt1, opt2)];

            // Calculate grad * (x_i / f(x_is))^(p - 1)
            real v = gradOutputV * pow(inputV / outputV, gradPower);

            gradInputP[
              flpGetOffset(&gradInputDesc, batch, inputFeature, opt1, opt2)]
              += v;
          }
        }
      }
    }
  }
}

#endif