Welcome to mirror list, hosted at ThFree Co, Russian Federation.

TemporalMaxPooling.c « generic - github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: 0111cb5da76b9da0f95e2c82f221b9d208e70d5a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
#ifndef TH_GENERIC_FILE
#define TH_GENERIC_FILE "generic/TemporalMaxPooling.c"
#else

// Forward pass of temporal (1D) max pooling.
//
// Lua stack: 1 = module table (fields kW, dW, indices, output),
//            2 = input tensor, required to be 2D (niframe x framesize).
// A window of kW consecutive rows, advanced by dW rows per output row, is
// reduced independently for every column. Writes into self.output
// (noframe x framesize) the window maxima, and into self.indices (same shape)
// the 0-based row offset of each maximum inside its window, stored as `real`.
// Returns 1 result to Lua (presumably self.output, the last field fetched
// onto the stack — confirm against luaT_getfieldcheckudata).
static int nn_(TemporalMaxPooling_updateOutput)(lua_State *L)
{
  THTensor *input = luaT_checkudata(L, 2, torch_Tensor);
  int kW = luaT_getfieldcheckint(L, 1, "kW");
  int dW = luaT_getfieldcheckint(L, 1, "dW");
  THTensor *indices = luaT_getfieldcheckudata(L, 1, "indices", torch_Tensor);
  THTensor *output = luaT_getfieldcheckudata(L, 1, "output", torch_Tensor);

  // validate everything before allocating, so a failed check cannot leak
  // kW <= 0 would make stored indices meaningless; dW == 0 divides by zero below
  luaL_argcheck(L, kW > 0, 1, "kernel width kW must be positive");
  luaL_argcheck(L, dW > 0, 1, "step dW must be positive");
  luaL_argcheck(L, input->nDimension == 2, 2, "2D tensor expected");
  luaL_argcheck(L, input->size[0] >= kW, 2, "input sequence smaller than kernel size");

  // sizes
  long niframe = input->size[0];
  long framesize = input->size[1];
  long noframe = (niframe - kW) / dW + 1;

  // raw pointer arithmetic below assumes row-major contiguous storage
  input = THTensor_(newContiguous)(input);

  // output holds the maxima, indices the in-window offset of each maximum
  THTensor_(resize2d)(output, noframe, framesize);
  THTensor_(resize2d)(indices, noframe, framesize);

  // get raw pointers
  real *input_data = THTensor_(data)(input);
  real *output_data = THTensor_(data)(output);
  real *indices_data = THTensor_(data)(indices);

  long t;
  for (t = 0; t < noframe; t++)
  {
    real *ip = input_data + t*framesize*dW;
    real *op = output_data + t*framesize;
    real *xp = indices_data + t*framesize;
    long y;
    // every variable written inside the parallel loop body is declared inside
    // it, so each OpenMP thread owns its own copy; a shared counter would race
#pragma omp parallel for private(y)
    for (y = 0; y < framesize; y++)
    {
      // start at offset 0 so the stored index is always a valid window offset,
      // even when no element compares greater than -inf (all -inf or all NaN);
      // the backward pass uses it unchecked as a row offset
      long maxindex = 0;
      real maxval = -THInf;
      long x;
      for (x = 0; x < kW; x++)
      {
        real val = ip[x*framesize + y];
        if (val > maxval)
        {
          maxval = val;
          maxindex = x;
        }
      }

      // set output to local max
      op[y] = maxval;
      xp[y] = (real)maxindex;
    }
  }

  // release the contiguous copy/reference taken above
  THTensor_(free)(input);

  return 1;
}

// Backward pass of temporal (1D) max pooling.
//
// Lua stack: 1 = module table (fields dW, indices, gradInput),
//            2 = input tensor given to the forward pass (2D),
//            3 = gradOutput (2D, same shape as self.indices).
// gradInput is resized like input and zeroed. Then each gradOutput element is
// accumulated into the input row its window's maximum came from, taken as
// t*dW + indices[t][y].
// Returns 1 result to Lua (presumably self.gradInput, the last field fetched
// onto the stack — confirm against luaT_getfieldcheckudata).
static int nn_(TemporalMaxPooling_updateGradInput)(lua_State *L)
{
  THTensor *input = luaT_checkudata(L, 2, torch_Tensor);
  THTensor *gradOutput = luaT_checkudata(L, 3, torch_Tensor);
  int dW = luaT_getfieldcheckint(L, 1, "dW");
  THTensor *indices = luaT_getfieldcheckudata(L, 1, "indices", torch_Tensor);
  THTensor *gradInput = luaT_getfieldcheckudata(L, 1, "gradInput", torch_Tensor);

  // validate shapes before allocating anything: any mismatch here would turn
  // the raw pointer arithmetic below into out-of-bounds reads/writes
  luaL_argcheck(L, gradOutput->nDimension == 2, 3, "2D tensor expected");
  luaL_argcheck(L, indices->nDimension == 2
                   && indices->size[0] == gradOutput->size[0]
                   && indices->size[1] == gradOutput->size[1],
                3, "gradOutput does not match the size of indices");
  luaL_argcheck(L, input->nDimension == 2 && input->size[1] == gradOutput->size[1],
                2, "2D tensor with the same frame size as gradOutput expected");

  // get contiguous gradOutput
  gradOutput = THTensor_(newContiguous)(gradOutput);

  // resize and zero
  THTensor_(resizeAs)(gradInput, input);
  THTensor_(zero)(gradInput);

  // sizes (long, matching the tensor size type; no narrowing to int)
  long noframe = gradOutput->size[0];
  long framesize = gradOutput->size[1];

  // get raw pointers
  real *gradInput_data = THTensor_(data)(gradInput);
  real *gradOutput_data = THTensor_(data)(gradOutput);
  real *indices_data = THTensor_(data)(indices);

  long t, y;
  for (t = 0; t < noframe; t++)
  {
    real *gip = gradInput_data + t*framesize*dW;
    real *gop = gradOutput_data + t*framesize;
    real *xp = indices_data + t*framesize;
    // within one output row each y writes a distinct column, so no race
#pragma omp parallel for private(y)
    for (y = 0; y < framesize; y++)
    {
      // route the gradient to the row that held the window maximum
      long maxindex = (long)xp[y];
      gip[maxindex*framesize + y] += gop[y];
    }
  }

  // release the contiguous copy/reference taken above
  THTensor_(free)(gradOutput);

  return 1;
}

// Name -> C function table for this module. It is passed to
// luaT_registeratname(..., "nn") by TemporalMaxPooling_init, and the
// all-NULL entry terminates the list.
static const struct luaL_Reg nn_(TemporalMaxPooling__) [] = {
  { .name = "TemporalMaxPooling_updateOutput",    .func = nn_(TemporalMaxPooling_updateOutput) },
  { .name = "TemporalMaxPooling_updateGradInput", .func = nn_(TemporalMaxPooling_updateGradInput) },
  { .name = NULL,                                 .func = NULL }
};

// Installs this module's functions on the torch_Tensor metatable under the
// name "nn". The final pop discards the metatable pushed on the first line
// (assuming luaT_registeratname leaves the stack balanced — confirm in luaT).
static void nn_(TemporalMaxPooling_init)(lua_State *L)
{
  luaT_pushmetatable(L, torch_Tensor);                      // metatable now on top
  luaT_registeratname(L, nn_(TemporalMaxPooling__), "nn");  // register under "nn"
  lua_pop(L, 1);                                            // drop the metatable
}

#endif