Welcome to mirror list, hosted at ThFree Co, Russian Federation.

View.lua - github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: 542e57e1631d335f6bc32ecfbc3c05b685aab80a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
-- Registers the class 'nn.View' with 'nn.Module' as its parent class.
-- `View` is the table the methods below are attached to; `parent` is used to
-- reach the parent class constructor from View:__init.
local View, parent = torch.class('nn.View', 'nn.Module')

-- Sets the target size of the view.
--
-- Arguments: either a single torch.LongStorage, which is stored as-is, or a
-- list of numbers, which is packed into a new torch.LongStorage.
-- Every entry must be >= 0 or exactly -1, and at most one entry may be -1.
--
-- Side effects: stores the sizes in self.size and the product of all
-- non-negative entries in self.numElements.
-- Returns self so calls can be chained.
function View:resetSize(...)
   local first = ...
   if select('#', ...) == 1 and torch.typename(first) == 'torch.LongStorage' then
      self.size = first
   else
      self.size = torch.LongStorage({...})
   end

   self.numElements = 1
   local sawInferred = false
   for d = 1, #self.size do
      local extent = self.size[d]
      if extent < 0 then
         -- a negative entry is only legal as the single -1 placeholder
         assert(extent == -1, 'size should be positive or -1')
         assert(not sawInferred, 'only one dimension can be at -1')
         sawInferred = true
      else
         self.numElements = self.numElements * extent
      end
   end

   return self
end

-- Constructor. The arguments are forwarded unchanged to View:resetSize, so
-- they are either a single torch.LongStorage or a list of sizes.
function View:__init(...)
   parent.__init(self)
   self:resetSize(...)
   -- no explicit input dimension count until setNumInputDims is called
   self.numInputDims = nil
end

-- Stores numInputDims on the module and returns self for chaining.
-- The value is passed to batchsize() by updateOutput, where it is the number
-- of trailing input dimensions counted against the requested view; any
-- dimensions in front of those are multiplied into the batch size.
function View:setNumInputDims(numInputDims)
   self.numInputDims = numInputDims
   return self
end

-- Works out the batch extent that has to be put in front of `size` when
-- viewing `input`.
--
-- input        : object answering nDimension() and size()
-- size         : requested view sizes (indexable, answers size() and totable())
-- numInputDims : optional number of trailing input dimensions to count against
--                the requested view; when nil, every input dimension is counted
-- numElements  : product of the non-negative entries of `size`
--
-- Returns the batch extent, or nothing when no batch dimension is needed.
-- Raises an error when the counted number of input elements is not a
-- multiple of numElements.
local function batchsize(input, size, numInputDims, numElements)
   local ndim = input:nDimension()
   local dims = input:size()
   local span = numInputDims or ndim

   -- number of elements held by the last `span` input dimensions
   local count = 1
   for d = ndim, ndim - span + 1, -1 do
      count = count * dims[d]
   end

   if count % numElements ~= 0 then
      local msg = string.format(
         'input view (%s) and desired view (%s) do not match',
         table.concat(dims:totable(), 'x'),
         table.concat(size:totable(), 'x'))
      error(msg)
   end

   -- a -1 entry in the requested size absorbs the leftover elements, so the
   -- quotient only counts as batch when no entry is inferred
   local inferred = false
   for k = 1, size:size() do
      if size[k] == -1 then
         inferred = true
         break
      end
   end
   local batch = inferred and 1 or count / numElements

   -- dimensions in front of the last `span` ones always belong to the batch
   for d = ndim - span, 1, -1 do
      batch = batch * dims[d]
   end

   -- special case: a batch of 1 is only reported when numInputDims is set
   -- and the input has more dimensions than that
   local hasLeadingDims = numInputDims and ndim > numInputDims
   if batch ~= 1 or hasLeadingDims then
      return batch
   end
   -- otherwise fall through and return nothing
end

-- Forward pass: sets self.output by calling output:view on `input` with the
-- sizes stored in self.size, preceded by a batch dimension when batchsize()
-- reports one.
--
-- input : the tensor to view; also used to create self.output (input.new())
--         when the module has no output yet
-- Returns self.output.
-- Errors raised by batchsize() (input and requested view do not match)
-- propagate to the caller.
function View:updateOutput(input)
   -- `table.unpack` only exists from Lua 5.2 on; plain Lua 5.1 / LuaJIT builds
   -- without 5.2 compatibility expose the same function as the global `unpack`.
   local unpack = table.unpack or unpack

   self.output = self.output or input.new()
   local bsz = batchsize(input, self.size, self.numInputDims, self.numElements)
   if bsz then
      -- prepend the batch extent to the requested sizes
      self.output:view(input, bsz, unpack(self.size:totable()))
   else
      self.output:view(input, self.size)
   end
   return self.output
end

-- Backward pass: sets self.gradInput by calling gradInput:view on
-- `gradOutput` with the size of `input`.
--
-- input      : the tensor that was given to updateOutput; only its size is read
-- gradOutput : gradient with respect to the module output; also used to create
--              self.gradInput (gradOutput.new()) when none exists yet
-- Returns self.gradInput.
function View:updateGradInput(input, gradOutput)
   if not self.gradInput then
      self.gradInput = gradOutput.new()
   end
   local inputSize = input:size()
   self.gradInput:view(gradOutput, inputSize)
   return self.gradInput
end

-- Builds the printable form of the module: its torch type name followed by
-- the configured sizes, comma-separated, in parentheses.
function View:__tostring__()
   local sizes = table.concat(self.size:totable(), ', ')
   return string.format('%s(%s)', torch.type(self), sizes)
end