Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/clementfarabet/lua---nnx.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--SpatialColorTransform.lua141
-rw-r--r--init.lua1
-rw-r--r--nnx-1.0-1.rockspec1
3 files changed, 143 insertions, 0 deletions
diff --git a/SpatialColorTransform.lua b/SpatialColorTransform.lua
new file mode 100644
index 0000000..1b181ad
--- /dev/null
+++ b/SpatialColorTransform.lua
@@ -0,0 +1,141 @@
+local SpatialColorTransform, parent = torch.class('nn.SpatialColorTransform', 'nn.Module')
+
+local help_desc =
+[[Provides a set of widely used/known color space transforms,
+for images: RGB->YUV, YUV->RGB, RGB->Y transforms, and
+more exotic transforms such as RGB->Normed-RGB]]
+
+local help_example =
+[[-- transforms an RGB image into a YUV image:
+converter = nn.SpatialColorTransform('rgb2yuv')
+rgb = image.lena()
+yuv = converter:forward(rgb)
+image.display(yuv) ]]
+
+function SpatialColorTransform:__init(type)
+ -- parent init
+ parent.__init(self)
+
+ -- require the image package
+ xlua.require('image',true)
+
+ -- usage
+ self.usage = xlua.usage(
+ 'nn.SpatialColorTransform', help_desc, help_example,
+ {type='string', req=true,
+ help='transform = yuv2rgb | rgb2yuv | rgb2y | hsl2rgb | hsv2rgb | rgb2hsl | rgb2hsv | rgb2nrgb | rgb2y+nrgb'}
+ )
+
+ -- transform type
+ self.type = type
+ if type == 'yuv2rgb' then
+ self.islinear = true
+ self.linear = nn.SpatialLinear(3,3)
+ -- R
+ self.linear.weight[1][1] = 1
+ self.linear.weight[1][2] = 0
+ self.linear.weight[1][3] = 1.13983
+ self.linear.bias[1] = 0
+ -- G
+ self.linear.weight[2][1] = 1
+ self.linear.weight[2][2] = -0.39465
+ self.linear.weight[2][3] = -0.58060
+ self.linear.bias[2] = 0
+ -- B
+ self.linear.weight[3][1] = 1
+ self.linear.weight[3][2] = 2.03211
+ self.linear.weight[3][3] = 0
+ self.linear.bias[3] = 0
+ elseif type == 'rgb2yuv' then
+ self.islinear = true
+ self.linear = nn.SpatialLinear(3,3)
+ -- Y
+ self.linear.weight[1][1] = 0.299
+ self.linear.weight[1][2] = 0.587
+ self.linear.weight[1][3] = 0.114
+ self.linear.bias[1] = 0
+ -- U
+ self.linear.weight[2][1] = -0.14713
+ self.linear.weight[2][2] = -0.28886
+ self.linear.weight[2][3] = 0.436
+ self.linear.bias[2] = 0
+ -- V
+ self.linear.weight[3][1] = 0.615
+ self.linear.weight[3][2] = -0.51499
+ self.linear.weight[3][3] = -0.10001
+ self.linear.bias[3] = 0
+ elseif type == 'rgb2y' then
+ self.islinear = true
+ self.linear = nn.SpatialLinear(3,1)
+ -- Y
+ self.linear.weight[1][1] = 0.299
+ self.linear.weight[1][2] = 0.587
+ self.linear.weight[1][3] = 0.114
+ self.linear.bias[1] = 0
+ elseif type == 'hsl2rgb' then
+ self.islinear = false
+ elseif type == 'hsv2rgb' then
+ self.islinear = false
+ elseif type == 'rgb2hsl' then
+ self.islinear = false
+ elseif type == 'rgb2hsv' then
+ self.islinear = false
+ elseif type == 'rgb2nrgb' then
+ self.islinear = false
+ elseif type == 'rgb2y+nrgb' then
+ self.islinear = false
+ else
+ xlua.error('transform required','nn.SpatialColorTransform',self.usage)
+ end
+end
+
+function SpatialColorTransform:forward(input)
+ if self.islinear then
+ self.output = self.linear:forward(input)
+ else
+ if self.type == 'rgb2hsl' then
+ self.output = image.rgb2hsl(input, self.output)
+ elseif self.type == 'rgb2hsv' then
+ self.output = image.rgb2hsv(input, self.output)
+ elseif self.type == 'hsl2rgb' then
+ self.output = image.hsl2rgb(input, self.output)
+ elseif self.type == 'rgb2hsv' then
+ self.output = image.rgb2hsv(input, self.output)
+ elseif self.type == 'rgb2nrgb' then
+ self.output = image.rgb2nrgb(input, self.output)
+ elseif self.type == 'rgb2y+nrgb' then
+ self.output:resize(4, input:size(2), input:size(3))
+ image.rgb2y(input, self.output:narrow(1,1,1))
+ image.rgb2nrgb(input, self.output:narrow(1,2,3))
+ end
+ end
+ return self.output
+end
+
+function SpatialColorTransform:backward(input, gradOutput)
+ if self.islinear then
+ self.gradInput = self.linear:backward(input, gradOutput)
+ else
+ xlua.error('backward not implemented for non-linear transforms',
+ 'SpatialColorTransform.backward')
+ end
+ return self.gradInput
+end
+
+function SpatialColorTransform:write(file)
+ parent.write(self, file)
+ file:writeObject(self.type)
+ file:writeBool(self.islinear)
+ if self.islinear then
+ file:writeObject(self.linear)
+ end
+end
+
+function SpatialColorTransform:read(file)
+ parent.read(self, file)
+ self.type = file:readObject()
+ self.islinear = file:readBool()
+ if self.islinear then
+ self.linear = file:readObject()
+ end
+end
diff --git a/init.lua b/init.lua
index 37b4c4d..0161ba5 100644
--- a/init.lua
+++ b/init.lua
@@ -81,6 +81,7 @@ torch.include('nnx', 'SpatialReSampling.lua')
torch.include('nnx', 'SpatialRecursiveFovea.lua')
torch.include('nnx', 'SpatialFovea.lua')
torch.include('nnx', 'SpatialGraph.lua')
+torch.include('nnx', 'SpatialColorTransform.lua')
-- criterions:
torch.include('nnx', 'SuperCriterion.lua')
diff --git a/nnx-1.0-1.rockspec b/nnx-1.0-1.rockspec
index 636e864..16be7c0 100644
--- a/nnx-1.0-1.rockspec
+++ b/nnx-1.0-1.rockspec
@@ -91,6 +91,7 @@ build = {
install_files(/lua/nnx SparseCriterion.lua)
install_files(/lua/nnx SpatialSparseCriterion.lua)
install_files(/lua/nnx SpatialGraph.lua)
+ install_files(/lua/nnx SpatialColorTransform.lua)
install_files(/lua/nnx SpatialRecursiveFovea.lua)
add_subdirectory (test)
install_targets(/lib nnx)