Welcome to mirror list, hosted at ThFree Co, Russian Federation.

test_rotate.lua « test - github.com/torch/image.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
blob: 8f7ef91a6e3fe3f2bf7a7b3938c4ad82f5888d6f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
require 'image'

-- Script-wide torch configuration: use 16 threads, and make
-- torch.FloatTensor the default tensor type.
torch.setnumthreads(16)
torch.setdefaulttensortype('torch.FloatTensor')

-- Number of random angles tried per test_rotate call when the caller does
-- not pass an explicit count.
local DEFAULT_NUM_TRIALS = 10

--- Crop the central region of an image, from 1/4 to 3/4 of each spatial size.
-- image.crop takes (x1, y1, x2, y2) with x running along the width (the last
-- tensor dimension) and y along the height, so the width must feed the x
-- coordinates; feeding the height into x only gives the right region for
-- square inputs. Bounds are floored so sizes that are not multiples of 4
-- still yield integer coordinates.
-- @param img 2D (H x W) or 3D (C x H x W) image tensor
-- @return the cropped tensor
local function center_crop(img)
   local ndim = img:dim()
   local height, width = img:size(ndim - 1), img:size(ndim)
   local x1, y1 = math.floor(width / 4), math.floor(height / 4)
   local x2, y2 = math.floor(width * 3 / 4), math.floor(height * 3 / 4)
   return image.crop(img, x1, y1, x2, y2)
end

--- Rotate img by theta and return a newly allocated result.
-- When mode is nil the argument is left out of the call entirely (rather
-- than passed as nil), keeping the short image.rotate call form under test.
local function rotate_new(img, theta, mode)
   if mode then
      return image.rotate(img, theta, mode)
   end
   return image.rotate(img, theta)
end

--- Rotate img by theta into the preallocated tensor dst and return dst.
-- As with rotate_new, a nil mode is omitted from the call rather than passed.
local function rotate_into(dst, img, theta, mode)
   if mode then
      image.rotate(dst, img, theta, mode)
   else
      image.rotate(dst, img, theta)
   end
   return dst
end

--- Round-trip rotation test for image.rotate.
-- With a fixed seed (11), draws n_trials random angles theta in [0, 2*pi).
-- For each angle it rotates src by theta and then by 2*pi - theta, using both
-- the allocating form and the destination-tensor form of image.rotate. It
-- then prints the mean L2 distance between the central crop of src and the
-- central crop of each round-tripped image. Only the centre is compared,
-- presumably because pixels near the borders do not survive a rotation
-- round trip.
-- @param src 2D or 3D image tensor
-- @param mode optional mode string forwarded to image.rotate (e.g. 'simple',
--   'bilinear'); when nil the argument is omitted from every call
-- @param n_trials optional number of random angles (default 10)
local function test_rotate(src, mode, n_trials)
   n_trials = n_trials or DEFAULT_NUM_TRIALS
   assert(n_trials >= 1, "test_rotate: n_trials must be at least 1")
   local ndim = src:dim()
   if ndim ~= 2 and ndim ~= 3 then
      -- Previously any other rank silently reported a mean distance of 0.
      error("test_rotate: expected a 2D or 3D tensor, got " .. ndim .. "D", 2)
   end

   torch.manualSeed(11)
   -- src is never modified, so its reference crop is computed only once.
   local ref = center_crop(src)
   local total_dist = 0.0
   for _ = 1, n_trials do
      local theta = torch.uniform(0, 2 * math.pi)

      -- rotate
      local d1 = rotate_new(src, theta, mode)
      local d2 = rotate_into(src.new():resizeAs(src), src, theta, mode)

      -- revert
      local revert = 2 * math.pi - theta
      local d3 = rotate_new(d1, revert, mode)
      local d4 = rotate_into(src.new():resizeAs(src), d2, revert, mode)

      -- diff
      total_dist = total_dist + ref:dist(center_crop(d3))
      total_dist = total_dist + ref:dist(center_crop(d4))
      --[[
      image.display(src)
      image.display(d1)
      image.display(d2)
      image.display(d3)
      image.display(d4)
      --]]
   end
   -- Two distances (allocating + destination form) are accumulated per trial.
   print("mode = " .. tostring(mode) .. ", mean dist: " .. total_dist / (n_trials * 2))
end
-- Modes exercised for every input image. The leading nil runs test_rotate
-- with no explicit mode. The count is stored explicitly because ipairs / #
-- would stop at (or be confused by) the nil in slot 1.
local ROTATE_MODES = { n = 3, nil, 'simple', 'bilinear' }

-- Run test_rotate on img once per entry of ROTATE_MODES, in order.
local function run_all_modes(img)
   for idx = 1, ROTATE_MODES.n do
      test_rotate(img, ROTATE_MODES[idx])
   end
end

-- 128x128 test image; the 2D case reuses its first slice along dimension 1.
local lena128 = image.scale(image.lena(), 128, 128, 'bilinear')
print("** dim3")
run_all_modes(lena128)
print("** dim2")
run_all_modes(lena128:select(1, 1))