Welcome to mirror list, hosted at ThFree Co, Russian Federation.

SpatialMSECriterion.lua - github.com/clementfarabet/lua---nnx.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: 4c2ced863d4a7f4b4aaa05386eb4e3b0fcd2ca18 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
-- Register the class 'nn.SpatialMSECriterion' with 'nn.MSECriterion' named as
-- its parent class. `parent` is the second value returned by torch.class; the
-- methods below use it to call the inherited __init / write / read.
local SpatialMSECriterion, parent = torch.class('nn.SpatialMSECriterion', 'nn.MSECriterion')

-- Constructor. Runs the parent constructor, then hands the varargs to
-- xlua.unpack_class together with the argument declarations below. The other
-- methods of this class read the results as fields of self:
--   resampleTarget  ratio used to resample the target            (default 1)
--   nbGradients     gradients to backpropagate: -1 = all, >=1 = N (default -1)
--   sizeAverage     true: forward() returns a mean, false: a sum  (default true)
--   ignoreClass     class index to neutralise, or false for none  (default false)
function SpatialMSECriterion:__init(...)
   parent.__init(self)

   xlua.unpack_class(self, {...},
      'nn.SpatialMSECriterion',
      'A spatial extension of the MSECriterion class.\n'
         ..' Provides a set of parameters to deal with spatial mini-batch training.',
      {arg='resampleTarget', type='number', help='ratio to resample target (target is a KxHxW tensor)', default=1},
      {arg='nbGradients', type='number', help='number of gradients to backpropagate (-1:all, >=1:nb)', default=-1},
      -- sizeAverage is a flag (default true, tested with `if self.sizeAverage`),
      -- so it is declared as a boolean rather than a number.
      {arg='sizeAverage', type='boolean', help='if true, forward() returns an average instead of a sum of errors', default=true},
      {arg='ignoreClass', type='number', help='all gradients for this class will be zeroed', default=false}
   )
end

-- Bring `target` into a form that can be compared element-wise with `input`.
-- The adjusted target is stored in self.target (backward() reads it from
-- there) and is also returned.
function SpatialMSECriterion:adjustTarget(input, target)
   local ratio = self.resampleTarget

   -- Step 1: a 2-dim target is taken to be a map of class labels, one per
   -- location. It is expanded into a buffer shaped like `input`, pre-filled
   -- with -1 (the negative target), so the problem becomes a plain
   -- mean-square regression. The buffer is kept on self and reused.
   if target:dim() == 2 then
      if not self.newtarget then
         self.newtarget = torch.Tensor()
      end
      local expanded = self.newtarget
      expanded:resizeAs(input):fill(-1)
      input.nn.SpatialMSECriterion_retarget(expanded, target)
      target = expanded
   end

   -- Step 2: when the (rescaled) target width disagrees with the input width,
   -- the target is assumed to be at the resolution of the original data while
   -- `input` is a network output that lost its borders to convolutions.
   -- A window is cropped out of the target to compensate; downsampling and
   -- strides are handled separately in step 3.
   local widthMatches = (target:size(3) * ratio) == input:size(3)
   if not widthMatches then
      local cropH = input:size(2) / ratio
      local cropW = input:size(3) / ratio
      local top = math.floor((target:size(2) - (input:size(2) - 1) / ratio) / 2) + 1
      local left = math.floor((target:size(3) - (input:size(3) - 1) / ratio) / 2) + 1
      target = target:narrow(2, top, cropH):narrow(3, left, cropW)
   end

   -- Step 3: resample the target to the spatial size of the input, to
   -- compensate for downsampling/pooling in the network.
   if ratio ~= 1 then
      local resized = torch.Tensor(target:size(1), input:size(2), input:size(3))
      image.scale(target, resized, 'simple')
      target = resized
   end

   -- Step 4: optionally neutralise one class by forcing its whole plane to
   -- -1, the negative target of the MSE regression setup.
   if self.ignoreClass then
      local ignoredPlane = target:select(1, self.ignoreClass)
      ignoredPlane:fill(-1)
   end

   self.target = target
   return target
end

-- Evaluate the criterion for `input` against `target`.
-- Returns self.output: the mean of the dense errors when self.sizeAverage is
-- set, their sum otherwise.
function SpatialMSECriterion:forward(input,target)
   -- Normalise the target first: class map -> class distributions, crop for
   -- convolution losses, rescale for striding, blank out the ignored class.
   local adjusted = self:adjustTarget(input, target)

   -- Dense error buffer with one entry per input element; created on first
   -- use and resized on every call.
   if not self.fullOutput then
      self.fullOutput = torch.Tensor()
   end
   self.fullOutput:resizeAs(input)

   -- Compute the dense errors (presumably written into self.fullOutput by
   -- the type-specific routine, since that is what gets reduced below).
   input.nn.SpatialMSECriterion_forward(self, input, adjusted)

   -- Collapse the dense errors into the scalar output.
   local errors = self.fullOutput
   if not self.sizeAverage then
      self.output = errors:sum()
   else
      self.output = errors:mean()
   end
   return self.output
end

-- Compute the gradient of the criterion with respect to `input`.
-- The `target` argument is ignored: the adjusted target cached in
-- self.target by adjustTarget() (called from forward()) is used instead, so
-- forward() must have been called first.
-- self.nbGradients selects how much of the gradient map is kept:
--   -1     the full dense gradient
--    1     only the gradient at the central location
--    else  gradients at nbGradients randomly drawn locations
-- Returns self.gradInput, shaped like `input`; locations that were not
-- selected stay at zero.
function SpatialMSECriterion:backward(input,target)
   -- (1) retrieve adjusted target
   target = self.target
   -- (2) resize input gradient map
   self.gradInput:resizeAs(input):zero()
   -- (3) compute input gradients, based on the nbGradients param
   if self.nbGradients == -1 then
      -- dense gradients, written directly into gradInput
      input.nn.SpatialMSECriterion_backward(self, input, target, self.gradInput)
   else
      -- sparse gradients: compute the dense map into a scratch buffer, then
      -- copy only the selected locations. The buffer is created once and
      -- reused; the previous `torch.Tensor() or self.fullGradInput` had its
      -- operands reversed and allocated a new tensor on every call.
      self.fullGradInput = self.fullGradInput or torch.Tensor()
      self.fullGradInput:resizeAs(input):zero()
      input.nn.SpatialMSECriterion_backward(self, input, target, self.fullGradInput)
      if self.nbGradients == 1 then
         -- only 1 gradient is kept, sampled in the center
         local y = math.ceil(self.gradInput:size(2)/2)
         local x = math.ceil(self.gradInput:size(3)/2)
         self.gradInput:select(3,x):select(2,y):copy(self.fullGradInput:select(3,x):select(2,y))
      else
         -- only N gradients are kept, sampled in random locations.
         -- x indexes dimension 3 and y indexes dimension 2, so each must be
         -- drawn from the extent of the dimension it indexes (x was
         -- previously drawn from size(1)).
         for i = 1,self.nbGradients do
            local x = math.random(1,self.gradInput:size(3))
            local y = math.random(1,self.gradInput:size(2))
            self.gradInput:select(3,x):select(2,y):copy(self.fullGradInput:select(3,x):select(2,y))
         end
      end
   end
   return self.gradInput
end

-- Serialise this criterion to `file`: the parent's data first, then
-- resampleTarget (double), nbGradients (int) and ignoreClass (int).
function SpatialMSECriterion:write(file)
   parent.write(self, file)
   file:writeDouble(self.resampleTarget)
   file:writeInt(self.nbGradients)
   -- read() always consumes exactly one int for ignoreClass, so one must
   -- always be written: the class index when set, or -1 as the "no ignored
   -- class" sentinel. Previously nothing was written when ignoreClass was
   -- set, which lost the value and misaligned everything read afterwards.
   file:writeInt(self.ignoreClass or -1)
end

-- Restore this criterion from `file`: the parent's data first, then
-- resampleTarget (double), nbGradients (int) and ignoreClass (int).
function SpatialMSECriterion:read(file)
   parent.read(self, file)
   self.resampleTarget = file:readDouble()
   self.nbGradients = file:readInt()
   -- -1 is the sentinel write() emits for "no ignored class"; map it back to
   -- false, and keep any other value as the class index.
   local storedClass = file:readInt()
   self.ignoreClass = (storedClass ~= -1) and storedClass
   -- start from a fresh, empty dense-error buffer
   self.fullOutput = torch.Tensor()
end