Welcome to mirror list, hosted at ThFree Co, Russian Federation.

Convert.lua - github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: 855338dd676fdee838169360909acafaa97a6785 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
------------------------------------------------------------------------
--[ nn.Convert ]--
-- Module to convert between different data formats
-- nn.Convert('bchw', 'bf') or nn.Convert('chw', 'f')
-- Automatically converts input to same type as self.output
-- Simplest use is for automatic input type conversions : nn.Convert()
------------------------------------------------------------------------
local _ = require 'moses'
local Convert, parent = torch.class("nn.Convert", "nn.Container")

-- Constructor.
-- inputShape  : optional string with one character per input axis
--               (e.g. 'bchw'); defaults to 'b*'
-- outputShape : optional string with one character per output axis;
--               defaults to inputShape and may only be given together with it
-- A 'b' (batch) axis is prepended to any shape that does not contain one.
-- The wildcard shape 'b*' must be used for both shapes or for neither.
-- Raises an error when only outputShape is given or when a shape is not a
-- string.
function Convert:__init(inputShape, outputShape)
   if outputShape and not inputShape then
      error"Expecting non-nil arg 1 when arg 2 is provided"
   end
   inputShape = inputShape or 'b*'
   outputShape = outputShape or inputShape
   if type(inputShape) ~= 'string' or type(outputShape) ~= 'string' then
      error"Expecting shapes to be strings"
   end
   -- axis labels are plain characters, so every lookup uses a plain
   -- (non-pattern) find: '.' or '%' in a shape must not act as magic
   self.inputShape = inputShape:find('b', 1, true) and inputShape or ('b'..inputShape)
   self.outputShape = outputShape:find('b', 1, true) and outputShape or ('b'..outputShape)
   self.inputBatchDim = self.inputShape:find('b', 1, true)
   self.outputBatchDim = self.outputShape:find('b', 1, true)
   if self.inputShape == 'b*' or self.outputShape == 'b*' then
      assert(self.inputShape == 'b*' and self.outputShape == 'b*', 'Both or neither shapes must be b*')
      -- -1 keeps updateOutput's "input:dim() < self.nInputDim" test false,
      -- so wildcard shapes never take the non-batch path
      self.nInputDim = -1
      self.nOutputDim = -1
      self.transposition = true
   else
      -- number of dims in batch mode
      self.nInputDim = #self.inputShape
      self.nOutputDim = #self.outputShape
      -- is the outputShape just a transposition of the inputShape?
      -- (self.transposition stays nil when the number of axes differs)
      if self.nInputDim == self.nOutputDim then
         self.transposition = true
         for i=1,self.nInputDim do
            local axis = self.inputShape:sub(i,i)
            if not self.outputShape:find(axis, 1, true) then
               self.transposition = false
               break
            end
         end
      end
   end
   parent.__init(self)
end

-- Lazily builds the module that performs the actual shape conversion.
-- Called from updateOutput the first time an input is seen.
-- input : the (batched) tensor handed to the shape-specific builder
-- Stores the result in self.converter and self.modules[1], cast to the
-- tensor type of self.output. Raises when there is no method named after
-- self.outputShape for a conversion that is not a transposition.
function Convert:buildConverter(input)
   local built
   if self.transposition then
      built = self:transpose(self.outputShape)
   else
      -- non-transpositions are dispatched to the method named after the
      -- output shape (e.g. self.bf, self.b, self.bt)
      local builder = self[self.outputShape]
      if torch.type(builder) ~= 'function' then
         error(string.format("Unrecognized conversion of shape %s to %s", self.inputShape, self.outputShape))
      end
      built = builder(self, input)
   end
   self.converter = built

   assert(torch.isTensor(self.output), "Expecting Tensor output")

   local outputType = torch.type(self.output)
   self.converter:type(outputType)

   self.modules[1] = self.converter
end

-- Forward pass.
-- input : a tensor, with or without the batch axis
-- Returns self.output, the converted tensor. Side effects: sets
-- self.batchMode and may (re)use the buffers self._input, self.__input and
-- self.__output; builds self.converter on first use.
function Convert:updateOutput(input)
   assert(torch.isTensor(input), "expecting Tensor")

   -- an input of another tensor type is copied into a buffer that has the
   -- same type as self.output
   local outputType = torch.type(self.output)
   if not torch.isTypeOf(input, outputType) then
      if not self._input then
         self._input = self.output.new()
      end
      local typed = self._input
      typed:resize(input:size())
      typed:copy(input)
      input = typed
   end

   self.batchMode = true
   local lacksBatchAxis = input:dim() < self.nInputDim
   if lacksBatchAxis then
      -- non-batch mode: view the sample as a batch of one by inserting a
      -- singleton axis at the batch position
      local batchedSize = input:size():totable()
      table.insert(batchedSize, self.inputBatchDim, 1)
      if not self.__input then
         self.__input = input.new()
      end
      local batched = self.__input
      batched:set(input)
      batched:resize(table.unpack(batchedSize))
      input = batched
      self.batchMode = false
   end

   if not self.converter then
      self:buildConverter(input)
   end

   self.output = self.converter:updateOutput(input)

   if self.batchMode then
      return self.output
   end

   -- non-batch mode: drop the batch axis again before returning
   local sampleSize = self.output:size():totable()
   table.remove(sampleSize, self.outputBatchDim)
   if not self.__output then
      self.__output = self.output.new()
   end
   local sample = self.__output
   sample:set(self.output)
   sample:resize(table.unpack(sampleSize))
   self.output = sample
   return self.output
end

-- Backward pass with respect to the input.
-- input      : the tensor that was given to the last updateOutput call
-- gradOutput : gradient with respect to the output of that call
-- Returns self.gradInput, which has the size of `input` and, when the
-- forward pass had to convert the input's tensor type, the type of `input`.
function Convert:updateGradInput(input, gradOutput)
   local input_ = input
   -- updateOutput only copies into self._input when the input's type differs
   -- from the module's type, but the buffer outlives that call. Re-check the
   -- type here so that a buffer left over from an earlier forward pass is
   -- not mistaken for the current input.
   local converted = self._input ~= nil
      and not torch.isTypeOf(input_, torch.type(self._input))
   if converted then
      input = self._input
   end
   if not self.batchMode then
      -- non-batch mode: use the batched view built by updateOutput and give
      -- gradOutput the batched size of the converter's output
      input = self.__input
      self.__gradOutput = self.__gradOutput or gradOutput.new()
      self.__gradOutput:set(gradOutput):resize(self.converter.output:size())
      gradOutput = self.__gradOutput
   end

   local gradInput = self.converter:updateGradInput(input, gradOutput)

   if not self.batchMode then
      -- strip the batch axis again: same size as the caller's input
      self.__gradInput = self.__gradInput or gradInput.new()
      self.__gradInput:set(gradInput):resize(input_:size())
      gradInput = self.__gradInput
   end
   if converted then
      -- hand the gradient back in the caller's tensor type and size
      if torch.type(self._gradInput) ~= torch.type(input_) then
         self._gradInput = input_.new()
      end
      self._gradInput:resize(input_:size()):copy(gradInput)
      self.gradInput = self._gradInput
   else
      self.gradInput = gradInput
   end

   return self.gradInput
end

-- Forwards parameter-gradient accumulation to the converter, using the same
-- tensors the converter saw during updateOutput / updateGradInput.
-- input      : the tensor given to the last updateOutput call
-- gradOutput : gradient with respect to the module's output
-- scale      : passed through unchanged to the converter
function Convert:accGradParameters(input, gradOutput, scale)
   if not self.batchMode then
      -- non-batch mode (batchMode == false) is the case in which
      -- updateOutput / updateGradInput created the batched buffers
      input = self.__input or input
      gradOutput = self.__gradOutput or gradOutput
   elseif self._input and not torch.isTypeOf(input, torch.type(self._input)) then
      -- the forward pass worked on a type-converted copy of the input
      input = self._input
   end
   self.converter:accGradParameters(input, gradOutput, scale)
end

-- Forwards the accumulate-and-update step to the converter, using the same
-- tensors the converter saw during updateOutput / updateGradInput.
-- input      : the tensor given to the last updateOutput call
-- gradOutput : gradient with respect to the module's output
-- lr         : passed through unchanged to the converter
function Convert:accUpdateGradParameters(input, gradOutput, lr)
   if not self.batchMode then
      -- non-batch mode (batchMode == false) is the case in which
      -- updateOutput / updateGradInput created the batched buffers
      input = self.__input or input
      gradOutput = self.__gradOutput or gradOutput
   elseif self._input and not torch.isTypeOf(input, torch.type(self._input)) then
      -- the forward pass worked on a type-converted copy of the input
      input = self._input
   end
   self.converter:accUpdateGradParameters(input, gradOutput, lr)
end

-- Builds the converter for output shape 'bf' (batch x feature).
-- input : batched tensor, used to measure the number of elements per sample
-- Returns an nn module: a Reshape, a Transpose, a Sequential of both, or an
-- Identity, depending on the input shape. Input shape 'bt' is rejected.
function Convert:bf(input)
   local batchAxis = self:findAxis('b', self.inputShape)
   local nAxes = #self.inputShape
   if self.inputShape == 'bt' then
      error"Conversion of shape bt to bf not supported: open an issue on github"
   end

   -- input shape 'b': every sample becomes a single feature
   if nAxes == 1 then
      return nn.Reshape(1)
   end

   -- bring the batch axis to the front when it is somewhere else
   local toFront = (batchAxis ~= 1) and nn.Transpose({1, batchAxis}) or nil

   -- two axes: nothing to flatten
   if nAxes <= 2 then
      return toFront or nn.Identity()
   end

   -- more than two axes: flatten all non-batch axes into one
   local sampleSize = input:select(self:findAxis('b'),1):nElement()
   local flatten = nn.Reshape(sampleSize)
   if not toFront then
      return flatten
   end
   local pipeline = nn.Sequential()
   pipeline:add(toFront)
   pipeline:add(flatten)
   return pipeline
end

-- Builds the converter for output shape 'b': one scalar per example, so the
-- batch is a vector.
-- input : batched tensor, used to check the size of the feature axis
-- Returns an nn.Select module; raises for unsupported input shapes or when
-- the feature axis holds more than one element.
function Convert:b(input)
   -- raises if the input shape has no batch axis
   self:findAxis('b')

   local shape = self.inputShape
   if shape == 'bt' or shape == 'tb' then
      -- keep only the first set of classes
      return nn.Select(self:findAxis('t'), 1)
   end

   if shape == 'bf' or shape == 'fb' then
      -- only meaningful when the feature axis has size 1
      local featureAxis = self:findAxis('f')
      if input:size(featureAxis) > 1 then
         error("Cannot convert shape "..self.inputShape.." to b when feature > 1")
      end
      return nn.Select(featureAxis, 1)
   end

   error("Cannot convert shape "..self.inputShape.." to shape b")
end

-- Pass-through converter: returns a module that leaves the data as it is.
function Convert:default()
   local passthrough = nn.Identity()
   return passthrough
end

-- Builds the converter for output shape 'bt' (multi-class: batch x target).
-- Only input shape 'b' is supported; anything else raises an error.
function Convert:bt()
   -- raises if the input shape has no batch axis
   self:findAxis('b')

   if self.inputShape ~= 'b' then
      error("cannot convert shape '"..self.inputShape.."' to bt")
   end
   -- each scalar example becomes a target vector of length 1
   return nn.Reshape(1)
end

-- Generic axis permutation from self.inputShape to newShape.
-- newShape : string holding the axis characters in the desired order
-- Returns nn.Identity when the shapes are equal, otherwise an nn.Transpose
-- built from the pairwise swaps that turn the input order into newShape.
function Convert:transpose(newShape)
   if newShape == self.inputShape then
      return nn.Identity()
   end

   -- current axis order, one character per entry; updated as swaps are made
   local axes = {}
   for axis in self.inputShape:gmatch('.') do
      axes[#axes + 1] = axis
   end

   local swaps = {}
   for target = 1, #newShape do
      local current = _.indexOf(axes, newShape:sub(target, target))
      if current ~= target then
         -- move the wanted axis into place and remember the swap
         axes[target], axes[current] = axes[current], axes[target]
         swaps[#swaps + 1] = {current, target}
      end
   end
   return nn.Transpose(table.unpack(swaps))
end

-- Returns the 1-based position of an axis within a shape string.
-- axis_char : axis label to look for; matched literally, not as a Lua pattern
-- shape     : shape string to search (defaults to self.inputShape)
-- silent    : when truthy, return nil instead of raising if the axis is absent
-- Raises (reported at the caller's position) when the axis is missing and
-- silent is not set.
function Convert:findAxis(axis_char, shape, silent)
   shape = shape or self.inputShape
   -- plain find: labels such as '.' or '%' must not be read as pattern magic
   local axis_pos = shape:find(axis_char, 1, true)
   if (not silent) and (not axis_pos) then
      error("Provided shape '"..shape.."' has no axis '"..axis_char.."'", 2)
   end
   return axis_pos
end

-- Drops the private buffers created by updateOutput and updateGradInput.
-- Both methods re-create them on demand with "x = x or ...new()".
local function releaseBuffers(self)
   self._input = nil
   self._gradInput = nil
   self.__input = nil
   self.__output = nil
   self.__gradInput = nil
   self.__gradOutput = nil
end

-- Clears intermediate state: the private conversion buffers, plus whatever
-- the parent class (nn.Container) clears.
-- Returns the parent's result so calls can be chained.
-- NOTE(review): nn.Container:clearState presumably resets output/gradInput,
-- recurses into self.modules and returns self -- confirm against the
-- installed nn version.
function Convert:clearState()
   releaseBuffers(self)
   return parent.clearState(self)
end

-- Gets or sets the tensor type of the module.
-- type        : type name to convert to; nil queries the current type
-- tensorCache : optional cache forwarded unchanged to the parent
-- The buffers are only dropped on an actual conversion. A plain query must
-- not discard them, since updateGradInput still needs self.__input after a
-- non-batch forward pass.
function Convert:type(type, tensorCache)
   if type then
      releaseBuffers(self)
   end
   return parent.type(self, type, tensorCache)
end