Welcome to mirror list, hosted at ThFree Co, Russian Federation.

SpatialAutoCropMSECriterion.lua - github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: 97206a0627b71fd726512c5eec2a4d0e24f56c26 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
--[[
   SpatialAutoCropMSECriterion.
   Implements the MSECriterion when the spatial resolution of the input is less than
   or equal to the spatial resolution of the target. It achieves this by center-cropping
   the target to the same spatial resolution as the input; the MSE is then
   calculated between the input and the cropped target.
]]
local SpatialAutoCropMSECriterion, parent = torch.class('nn.SpatialAutoCropMSECriterion', 'nn.MSECriterion')

--- Construct the criterion.
-- @param sizeAverage passed through unchanged to nn.MSECriterion's constructor
function SpatialAutoCropMSECriterion.__init(self, sizeAverage)
   parent.__init(self, sizeAverage)
end

--- Center-crop the two trailing (spatial) dimensions of a 3D or 4D tensor.
-- The spatial dimensions are the last two: (2, 3) for a 3D tensor and
-- (3, 4) for a 4D tensor; any leading dimensions are left untouched.
-- @param input 3D or 4D tensor to crop
-- @param cropSize table {height, width}; each entry must satisfy
--   0 < entry <= the corresponding spatial size of `input`
-- @return the sub-tensor of `input` selected by range indexing over the
--   centered height/width window
local function centerCrop(input, cropSize)
   assert(input:dim() == 3 or input:dim() == 4, "input should be a 3D or  4D tensor")
   assert(#cropSize == 2, "cropSize should have two elements only")

   -- Read the spatial sizes straight from the trailing two dimensions.
   -- (Reshaping a 3D input with :view() just to read its sizes fails on
   -- non-contiguous tensors, and is unnecessary.)
   local heightDim = input:dim() - 1
   local widthDim = input:dim()
   local inputHeight = input:size(heightDim)
   local inputWidth = input:size(widthDim)

   -- Messages name the dimension actually checked (the width check used to
   -- report cropSize[1]/size(3) by copy-paste; 3D inputs reported 4D indices).
   assert(cropSize[1] > 0 and cropSize[1] <= inputHeight,
          string.format("0 < cropSize[1] <= input:size(%d) not satisfied", heightDim))
   assert(cropSize[2] > 0 and cropSize[2] <= inputWidth,
          string.format("0 < cropSize[2] <= input:size(%d) not satisfied", widthDim))

   -- Center the window; when the excess is odd, math.floor puts the extra
   -- dropped row/column on the trailing side.
   local rowStart = 1 + math.floor((inputHeight - cropSize[1]) / 2.0)
   local rowEnd = rowStart + cropSize[1] - 1
   local colStart = 1 + math.floor((inputWidth - cropSize[2]) / 2.0)
   local colEnd = colStart + cropSize[2] - 1

   if input:dim() == 3 then
      return input[{{}, {rowStart, rowEnd}, {colStart, colEnd}}]
   end
   return input[{{}, {}, {rowStart, rowEnd}, {colStart, colEnd}}]
end

--- Return the spatial (height, width) sizes of a tensor.
-- A 4D tensor reports dimensions 3 and 4; any other rank reports
-- dimensions 2 and 3 (the 3D layout).
-- @param tensor object exposing :dim() and :size(i)
-- @return height, width
local function getTensorHeightAndWidth(tensor)
   local hDim = (tensor:dim() == 4) and 3 or 2
   return tensor:size(hDim), tensor:size(hDim + 1)
end

--- Check that `input` is no larger than `target` in both spatial dimensions.
-- Despite the name, equal sizes are accepted (the comparison is <=).
-- @param input tensor whose spatial size is compared
-- @param target tensor it is compared against
-- @return true when input height <= target height and input width <= target width
local function inputResolutionIsSmallerThanTargetResolution(input, target)
   local inH, inW = getTensorHeightAndWidth(input)
   local tgH, tgW = getTensorHeightAndWidth(target)
   if inH > tgH or inW > tgW then
      return false
   end
   return true
end

--- Compute the MSE between `input` and the center crop of `target`.
-- `target` is center-cropped to the spatial size of `input`, then
-- nn.MSECriterion's updateOutput is applied to (input, croppedTarget).
-- @param input 3D or 4D tensor
-- @param target tensor of the same rank whose spatial sizes are >= input's
-- @return the value returned by nn.MSECriterion.updateOutput
function SpatialAutoCropMSECriterion:updateOutput(input, target)
   local nDim = input:dim()
   assert(nDim == target:dim(), "input and target should have the same number of dimensions")
   assert(nDim == 4 or nDim == 3, "input and target must have 3 or 4 dimensions")
   assert(inputResolutionIsSmallerThanTargetResolution(input, target),
          "Spatial resolution of input should be less than or equal to the spatial resolution of the target")

   local cropSize = {getTensorHeightAndWidth(input)}
   return parent.updateOutput(self, input, centerCrop(target, cropSize))
end


--- Gradient of the auto-cropped MSE loss with respect to `input`.
-- nn criteria receive the *target* as the second argument of updateGradInput
-- (Criterion:backward(input, target) forwards it here). The second argument is
-- not a gradOutput. The target may therefore be spatially larger than `input`.
-- It is center-cropped exactly as in updateOutput before delegating to
-- nn.MSECriterion. A target of the same size as `input` is left unchanged
-- by the crop.
-- @param input 3D or 4D tensor
-- @param target tensor of the same rank whose spatial sizes are >= input's
-- @return the value returned by nn.MSECriterion.updateGradInput
function SpatialAutoCropMSECriterion:updateGradInput(input, target)
   assert(input:dim() == target:dim(), "input and target should have the same number of dimensions")
   assert(input:dim() == 4 or input:dim() == 3, "input and target must have 3 or 4 dimensions")
   assert(inputResolutionIsSmallerThanTargetResolution(input, target),
          "Spatial resolution of input should be less than or equal to the spatial resolution of the target")

   -- The gradient must be taken against the same cropped target the loss used.
   local inputHeight, inputWidth = getTensorHeightAndWidth(input)
   local targetCropped = centerCrop(target, {inputHeight, inputWidth})
   return parent.updateGradInput(self, input, targetCropped)
end