Welcome to mirror list, hosted at ThFree Co, Russian Federation.

SpatialColorTransform.lua - github.com/clementfarabet/lua---nnx.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: 941a391b5f0ac9205a2be50d0bf53b583a11e2c6 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
-- Declare the nn.SpatialColorTransform class, deriving from nn.Module.
-- `parent` is the base class; its __init and type methods are called
-- explicitly from the overrides below.
local SpatialColorTransform, parent = torch.class('nn.SpatialColorTransform', 'nn.Module')

-- Description text handed to xlua.usage() in the constructor.
local help_desc = 
[[Provides a set of widely used/known color space transforms,
for images: RGB->YUV, YUV->RGB, RGB->Y transforms, and 
more exotic transforms such as RGB->Normed-RGB]]

-- Usage example handed to xlua.usage() in the constructor; this is help
-- text only (a string), it is never executed by this module.
local help_example = 
[[-- transforms an RGB image into a YUV image:
converter = nn.SpatialColorTransform('rgb2yuv')
rgb = image.lena()
yuv = converter:forward(rgb) 
image.display(yuv) ]]

-- Coefficients of the transforms that are a fixed per-pixel linear map,
-- keyed by transform name. Each inner list is one output plane's weights
-- over the 3 input planes; every bias of these maps is zero.
local linearCoefficients = {
   yuv2rgb = {
      {1,        0,        1.13983}, -- R
      {1,       -0.39465, -0.58060}, -- G
      {1,        2.03211,  0      }, -- B
   },
   rgb2yuv = {
      { 0.299,    0.587,    0.114  }, -- Y
      {-0.14713, -0.28886,  0.436  }, -- U
      { 0.615,   -0.51499, -0.10001}, -- V
   },
   rgb2y = {
      { 0.299,    0.587,    0.114  }, -- Y
   },
}

-- Transforms that have no linear layer; updateOutput computes them with
-- functions from the image package.
local nonLinearTransforms = {
   ['hsl2rgb'] = true,
   ['hsv2rgb'] = true,
   ['rgb2hsl'] = true,
   ['rgb2hsv'] = true,
   ['rgb2nrgb'] = true,
   ['rgb2y+nrgb'] = true,
}

--- Build a color-space converter.
-- @param transform name of the conversion (see the usage string below).
--   Linear conversions get a fixed nn.SpatialLinear(3, n) whose weights are
--   copied from `linearCoefficients`; the others only record the name.
--   Any other value raises an xlua usage error.
function SpatialColorTransform:__init(transform)
   parent.__init(self)

   -- the image package supplies the non-linear conversions
   xlua.require('image',true)

   self.usage = xlua.usage(
      'nn.SpatialColorTransform', help_desc, help_example,
      {type='string', req=true,
       help='transform = yuv2rgb | rgb2yuv | rgb2y | hsl2rgb | hsv2rgb | rgb2hsl | rgb2hsv | rgb2nrgb | rgb2y+nrgb'}
   )

   self.transform = transform
   local rows = linearCoefficients[transform]
   if rows then
      self.islinear = true
      local nOutputs = #rows
      local layer = nn.SpatialLinear(3, nOutputs)
      self.linear = layer
      for r = 1, nOutputs do
         local coeffs = rows[r]
         for c = 1, 3 do
            layer.weight[r][c] = coeffs[c]
         end
         layer.bias[r] = 0
      end
   elseif nonLinearTransforms[transform] then
      self.islinear = false
   else
      xlua.error('transform required','nn.SpatialColorTransform',self.usage)
   end
end

--- Forward pass: convert `input` to the configured color space.
-- Linear transforms delegate to the internal nn.SpatialLinear layer; the
-- others call the matching image.* conversion, reusing self.output as the
-- destination buffer.
-- @param input image tensor. The 'rgb2y+nrgb' branch reads input:size(2)
--   and input:size(3) as the spatial dims — NOTE(review): presumably a
--   (planes, height, width) tensor for every transform; confirm with callers.
-- @return self.output ('rgb2y+nrgb' yields 4 planes: Y, then normalized RGB)
function SpatialColorTransform:updateOutput(input)
   if self.islinear then
      self.output = self.linear:updateOutput(input)
   else
      local transform = self.transform
      if transform == 'rgb2hsl' then
         self.output = image.rgb2hsl(input, self.output)
      elseif transform == 'rgb2hsv' then
         self.output = image.rgb2hsv(input, self.output)
      elseif transform == 'hsl2rgb' then
         self.output = image.hsl2rgb(input, self.output)
      elseif transform == 'hsv2rgb' then
         -- accepted by __init and listed in the usage text, so it must be
         -- dispatched here rather than falling through as a silent no-op
         self.output = image.hsv2rgb(input, self.output)
      elseif transform == 'rgb2nrgb' then
         self.output = image.rgb2nrgb(input, self.output)
      elseif transform == 'rgb2y+nrgb' then
         -- plane 1 holds Y, planes 2..4 hold the normalized RGB channels
         self.output:resize(4, input:size(2), input:size(3))
         image.rgb2y(input, self.output:narrow(1,1,1))
         image.rgb2nrgb(input, self.output:narrow(1,2,3))
      else
         -- never return a stale buffer for a transform we do not handle
         xlua.error('unsupported transform: ' .. tostring(transform),
                    'SpatialColorTransform.updateOutput')
      end
   end
   return self.output
end

--- Backward pass with respect to the input.
-- Only linear transforms are differentiable here: the gradient comes from the
-- internal nn.SpatialLinear layer. Any other transform raises an xlua error.
-- @param input the tensor given to the forward pass
-- @param gradOutput gradient with respect to the module's output
-- @return self.gradInput
function SpatialColorTransform:updateGradInput(input, gradOutput)
   if not self.islinear then
      -- hsl/hsv/nrgb conversions have no backward implementation
      xlua.error('updateGradInput not implemented for non-linear transforms',
                 'SpatialColorTransform.updateGradInput')
   else
      local linearLayer = self.linear
      self.gradInput = linearLayer:updateGradInput(input, gradOutput)
   end
   return self.gradInput
end

--- Convert the module's tensors to another tensor type.
-- Runs the base-class conversion and, for linear transforms, also converts
-- the internal linear layer.
-- @param tensor_type tensor type name, e.g. 'torch.FloatTensor'
-- @param ... extra arguments forwarded unchanged (e.g. a tensor cache in
--   newer nn versions — NOTE(review): confirm against the nn version in use)
-- @return the result of nn.Module.type (expected to be the module itself),
--   so the call can be chained; the previous version returned nil
function SpatialColorTransform:type(tensor_type, ...)
   local result = parent.type(self, tensor_type, ...)
   if self.islinear then
      self.linear:type(tensor_type, ...)
   end
   return result
end