Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/image.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSoumith Chintala <soumith@gmail.com>2015-08-17 22:02:51 +0300
committerSoumith Chintala <soumith@gmail.com>2015-08-17 22:02:51 +0300
commitde926c7de81cf317ef92b908cacf20c8978efbfc (patch)
tree0a0cecb15cf7d479e05d1993d95c63545dd46cad
parent5c2e963f4cca3a6c76dc75d0d7503d77dee57d4f (diff)
parent869978fd4152a7605815f5d2c8c54066fa3fec16 (diff)
Merge pull request #93 from colesbury/byte
Support saving torch.ByteTensor images
-rw-r--r--init.lua45
-rwxr-xr-xtest/test_decompress_jpg.lua6
2 files changed, 24 insertions, 27 deletions
diff --git a/init.lua b/init.lua
index 3111e34..f1658cf 100644
--- a/init.lua
+++ b/init.lua
@@ -138,15 +138,22 @@ local function loadPNG(filename, depth, tensortype)
end
rawset(image, 'loadPNG', loadPNG)
+local function clampImage(tensor)
+ if tensor:type() == 'torch.ByteTensor' then
+ return tensor
+ end
+ local a = torch.Tensor():resize(tensor:size()):copy(tensor)
+ a.image.saturate(tensor) -- bound btwn 0 and 1
+ a:mul(255) -- remap to [0..255]
+ return a
+end
+
local function savePNG(filename, tensor)
if not xlua.require 'libpng' then
dok.error('libpng package not found, please install libpng','image.savePNG')
end
- local MAXVAL = 255
- local a = torch.Tensor():resize(tensor:size()):copy(tensor)
- a.image.saturate(a) -- bound btwn 0 and 1
- a:mul(MAXVAL) -- remap to [0..255]
- a.libpng.save(filename, a)
+ tensor = clampImage(tensor)
+ tensor.libpng.save(filename, tensor)
end
rawset(image, 'savePNG', savePNG)
@@ -222,13 +229,10 @@ local function saveJPG(filename, tensor)
if not xlua.require 'libjpeg' then
dok.error('libjpeg package not found, please install libjpeg','image.saveJPG')
end
- local MAXVAL = 255
- local a = torch.Tensor():resize(tensor:size()):copy(tensor)
- a.image.saturate(a) -- bound btwn 0 and 1
- a:mul(MAXVAL) -- remap to [0..255]
+ tensor = clampImage(tensor)
local save_to_file = 1
local quality = 75
- a.libjpeg.save(filename, a, save_to_file, quality)
+ tensor.libjpeg.save(filename, tensor, save_to_file, quality)
end
rawset(image, 'saveJPG', saveJPG)
@@ -244,14 +248,11 @@ local function compressJPG(tensor, quality)
dok.error('libjpeg package not found, please install libjpeg',
'image.compressJPG')
end
- local MAXVAL = 255
- local a = torch.Tensor():resize(tensor:size()):copy(tensor)
- a.image.saturate(a) -- bound btwn 0 and 1
- a:mul(MAXVAL) -- remap to [0..255]
+ tensor = clampImage(tensor)
local b = torch.ByteTensor()
local save_to_file = 0
quality = quality or 75
- a.libjpeg.save("", a, save_to_file, quality, b)
+ tensor.libjpeg.save("", tensor, save_to_file, quality, b)
return b
end
rawset(image, 'compressJPG', compressJPG)
@@ -280,11 +281,8 @@ local function savePPM(filename, tensor)
if tensor:nDimension() ~= 3 or tensor:size(1) ~= 3 then
dok.error('can only save 3xHxW images as PPM', 'image.savePPM')
end
- local MAXVAL = 255
- local a = torch.Tensor():resize(tensor:size()):copy(tensor)
- a.image.saturate(a) -- bound btwn 0 and 1
- a:mul(MAXVAL) -- remap to [0..255]
- a.libppm.save(filename, a)
+ tensor = clampImage(tensor)
+ tensor.libppm.save(filename, tensor)
end
rawset(image, 'savePPM', savePPM)
@@ -293,11 +291,8 @@ local function savePGM(filename, tensor)
if tensor:nDimension() == 3 and tensor:size(1) ~= 1 then
dok.error('can only save 1xHxW or HxW images as PGM', 'image.savePGM')
end
- local MAXVAL = 255
- local a = torch.Tensor():resize(tensor:size()):copy(tensor)
- a.image.saturate(a) -- bound btwn 0 and 1
- a:mul(MAXVAL) -- remap to [0..255]
- a.libppm.save(filename, a)
+ tensor = clampImage(tensor)
+ tensor.libppm.save(filename, tensor)
end
rawset(image, 'savePGM', savePGM)
diff --git a/test/test_decompress_jpg.lua b/test/test_decompress_jpg.lua
index a64a443..5728eaf 100755
--- a/test/test_decompress_jpg.lua
+++ b/test/test_decompress_jpg.lua
@@ -51,9 +51,11 @@ function test.LoadInvalid()
local img_binary = torch.rand(file_size_bytes):mul(255):byte()
-- Now decompress the image from the ByteTensor
- local img_from_tensor = image.decompressJPG(img_binary)
+ local ok, img_from_tensor = pcall(function()
+ return image.decompressJPG(img_binary)
+ end)
- mytester:assert(img_from_tensor == nil,
+ mytester:assert(not ok or img_from_tensor == nil,
'A non-nil was returned on an invalid input! ')
end