Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/clementfarabet/lua---nnx.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorClement Farabet <clement.farabet@gmail.com>2011-07-11 01:30:33 +0400
committerClement Farabet <clement.farabet@gmail.com>2011-07-11 01:30:33 +0400
commit2cd92a9591e849dfd3330dfb13abf7fe9977b597 (patch)
tree34d9cbe4f69772ab79302ed5b3fa907fa0f8c68a /SpatialColorTransform.lua
parent40636e9b6afd287a46fac1a668fb0b876f9ead6f (diff)
added color transform module
Diffstat (limited to 'SpatialColorTransform.lua')
-rw-r--r--SpatialColorTransform.lua141
1 files changed, 141 insertions, 0 deletions
diff --git a/SpatialColorTransform.lua b/SpatialColorTransform.lua
new file mode 100644
index 0000000..1b181ad
--- /dev/null
+++ b/SpatialColorTransform.lua
@@ -0,0 +1,141 @@
+local SpatialColorTransform, parent = torch.class('nn.SpatialColorTransform', 'nn.Module')
+
+local help_desc =
+[[Provides a set of widely used/known color space transforms,
+for images: RGB->YUV, YUV->RGB, RGB->Y transforms, and
+more exotic transforms such as RGB->Normed-RGB]]
+
+local help_example =
+[[-- transforms an RGB image into a YUV image:
+converter = nn.SpatialColorTransform('rgb2yuv')
+rgb = image.lena()
+yuv = converter:forward(rgb)
+image.display(yuv) ]]
+
+function SpatialColorTransform:__init(type)
+ -- parent init
+ parent.__init(self)
+
+ -- require the image package
+ xlua.require('image',true)
+
+ -- usage
+ self.usage = xlua.usage(
+ 'nn.SpatialColorTransform', help_desc, help_example,
+ {type='string', req=true,
+ help='transform = yuv2rgb | rgb2yuv | rgb2y | hsl2rgb | hsv2rgb | rgb2hsl | rgb2hsv | rgb2nrgb | rgb2y+nrgb'}
+ )
+
+ -- transform type
+ self.type = type
+ if type == 'yuv2rgb' then
+ self.islinear = true
+ self.linear = nn.SpatialLinear(3,3)
+ -- R
+ self.linear.weight[1][1] = 1
+ self.linear.weight[1][2] = 0
+ self.linear.weight[1][3] = 1.13983
+ self.linear.bias[1] = 0
+ -- G
+ self.linear.weight[2][1] = 1
+ self.linear.weight[2][2] = -0.39465
+ self.linear.weight[2][3] = -0.58060
+ self.linear.bias[2] = 0
+ -- B
+ self.linear.weight[3][1] = 1
+ self.linear.weight[3][2] = 2.03211
+ self.linear.weight[3][3] = 0
+ self.linear.bias[3] = 0
+ elseif type == 'rgb2yuv' then
+ self.islinear = true
+ self.linear = nn.SpatialLinear(3,3)
+ -- Y
+ self.linear.weight[1][1] = 0.299
+ self.linear.weight[1][2] = 0.587
+ self.linear.weight[1][3] = 0.114
+ self.linear.bias[1] = 0
+ -- U
+ self.linear.weight[2][1] = -0.14713
+ self.linear.weight[2][2] = -0.28886
+ self.linear.weight[2][3] = 0.436
+ self.linear.bias[2] = 0
+ -- V
+ self.linear.weight[3][1] = 0.615
+ self.linear.weight[3][2] = -0.51499
+ self.linear.weight[3][3] = -0.10001
+ self.linear.bias[3] = 0
+ elseif type == 'rgb2y' then
+ self.islinear = true
+ self.linear = nn.SpatialLinear(3,1)
+ -- Y
+ self.linear.weight[1][1] = 0.299
+ self.linear.weight[1][2] = 0.587
+ self.linear.weight[1][3] = 0.114
+ self.linear.bias[1] = 0
+ elseif type == 'hsl2rgb' then
+ self.islinear = false
+ elseif type == 'hsv2rgb' then
+ self.islinear = false
+ elseif type == 'rgb2hsl' then
+ self.islinear = false
+ elseif type == 'rgb2hsv' then
+ self.islinear = false
+ elseif type == 'rgb2nrgb' then
+ self.islinear = false
+ elseif type == 'rgb2y+nrgb' then
+ self.islinear = false
+ else
+ xlua.error('transform required','nn.SpatialColorTransform',self.usage)
+ end
+end
+
+function SpatialColorTransform:forward(input)
+ if self.islinear then
+ self.output = self.linear:forward(input)
+ else
+ if self.type == 'rgb2hsl' then
+ self.output = image.rgb2hsl(input, self.output)
+ elseif self.type == 'rgb2hsv' then
+ self.output = image.rgb2hsv(input, self.output)
+ elseif self.type == 'hsl2rgb' then
+ self.output = image.hsl2rgb(input, self.output)
+ elseif self.type == 'rgb2hsv' then
+ self.output = image.rgb2hsv(input, self.output)
+ elseif self.type == 'rgb2nrgb' then
+ self.output = image.rgb2nrgb(input, self.output)
+ elseif self.type == 'rgb2y+nrgb' then
+ self.output:resize(4, input:size(2), input:size(3))
+ image.rgb2y(input, self.output:narrow(1,1,1))
+ image.rgb2nrgb(input, self.output:narrow(1,2,3))
+ end
+ end
+ return self.output
+end
+
+function SpatialColorTransform:backward(input, gradOutput)
+ if self.islinear then
+ self.gradInput = self.linear:backward(input, gradOutput)
+ else
+ xlua.error('backward not implemented for non-linear transforms',
+ 'SpatialColorTransform.backward')
+ end
+ return self.gradInput
+end
+
+function SpatialColorTransform:write(file)
+ parent.write(self, file)
+ file:writeObject(self.type)
+ file:writeBool(self.islinear)
+ if self.islinear then
+ file:writeObject(self.linear)
+ end
+end
+
+function SpatialColorTransform:read(file)
+ parent.read(self, file)
+ self.type = file:readObject()
+ self.islinear = file:readBool()
+ if self.islinear then
+ self.linear = file:readObject()
+ end
+end