diff options
author | Soumith Chintala <soumith@gmail.com> | 2015-08-17 22:02:51 +0300 |
---|---|---|
committer | Soumith Chintala <soumith@gmail.com> | 2015-08-17 22:02:51 +0300 |
commit | de926c7de81cf317ef92b908cacf20c8978efbfc (patch) | |
tree | 0a0cecb15cf7d479e05d1993d95c63545dd46cad | |
parent | 5c2e963f4cca3a6c76dc75d0d7503d77dee57d4f (diff) | |
parent | 869978fd4152a7605815f5d2c8c54066fa3fec16 (diff) |
Merge pull request #93 from colesbury/byte
Support saving torch.ByteTensor images
-rw-r--r-- | init.lua | 45 | ||||
-rwxr-xr-x | test/test_decompress_jpg.lua | 6 |
2 files changed, 24 insertions, 27 deletions
@@ -138,15 +138,22 @@ local function loadPNG(filename, depth, tensortype) end rawset(image, 'loadPNG', loadPNG) +local function clampImage(tensor) + if tensor:type() == 'torch.ByteTensor' then + return tensor + end + local a = torch.Tensor():resize(tensor:size()):copy(tensor) + a.image.saturate(tensor) -- bound btwn 0 and 1 + a:mul(255) -- remap to [0..255] + return a +end + local function savePNG(filename, tensor) if not xlua.require 'libpng' then dok.error('libpng package not found, please install libpng','image.savePNG') end - local MAXVAL = 255 - local a = torch.Tensor():resize(tensor:size()):copy(tensor) - a.image.saturate(a) -- bound btwn 0 and 1 - a:mul(MAXVAL) -- remap to [0..255] - a.libpng.save(filename, a) + tensor = clampImage(tensor) + tensor.libpng.save(filename, tensor) end rawset(image, 'savePNG', savePNG) @@ -222,13 +229,10 @@ local function saveJPG(filename, tensor) if not xlua.require 'libjpeg' then dok.error('libjpeg package not found, please install libjpeg','image.saveJPG') end - local MAXVAL = 255 - local a = torch.Tensor():resize(tensor:size()):copy(tensor) - a.image.saturate(a) -- bound btwn 0 and 1 - a:mul(MAXVAL) -- remap to [0..255] + tensor = clampImage(tensor) local save_to_file = 1 local quality = 75 - a.libjpeg.save(filename, a, save_to_file, quality) + tensor.libjpeg.save(filename, tensor, save_to_file, quality) end rawset(image, 'saveJPG', saveJPG) @@ -244,14 +248,11 @@ local function compressJPG(tensor, quality) dok.error('libjpeg package not found, please install libjpeg', 'image.compressJPG') end - local MAXVAL = 255 - local a = torch.Tensor():resize(tensor:size()):copy(tensor) - a.image.saturate(a) -- bound btwn 0 and 1 - a:mul(MAXVAL) -- remap to [0..255] + tensor = clampImage(tensor) local b = torch.ByteTensor() local save_to_file = 0 quality = quality or 75 - a.libjpeg.save("", a, save_to_file, quality, b) + tensor.libjpeg.save("", tensor, save_to_file, quality, b) return b end rawset(image, 'compressJPG', compressJPG) @@ -280,11 +281,8 @@ local function savePPM(filename, tensor) if tensor:nDimension() ~= 3 or tensor:size(1) ~= 3 then dok.error('can only save 3xHxW images as PPM', 'image.savePPM') end - local MAXVAL = 255 - local a = torch.Tensor():resize(tensor:size()):copy(tensor) - a.image.saturate(a) -- bound btwn 0 and 1 - a:mul(MAXVAL) -- remap to [0..255] - a.libppm.save(filename, a) + tensor = clampImage(tensor) + tensor.libppm.save(filename, tensor) end rawset(image, 'savePPM', savePPM) @@ -293,11 +291,8 @@ local function savePGM(filename, tensor) if tensor:nDimension() == 3 and tensor:size(1) ~= 1 then dok.error('can only save 1xHxW or HxW images as PGM', 'image.savePGM') end - local MAXVAL = 255 - local a = torch.Tensor():resize(tensor:size()):copy(tensor) - a.image.saturate(a) -- bound btwn 0 and 1 - a:mul(MAXVAL) -- remap to [0..255] - a.libppm.save(filename, a) + tensor = clampImage(tensor) + tensor.libppm.save(filename, tensor) end rawset(image, 'savePGM', savePGM) diff --git a/test/test_decompress_jpg.lua b/test/test_decompress_jpg.lua index a64a443..5728eaf 100755 --- a/test/test_decompress_jpg.lua +++ b/test/test_decompress_jpg.lua @@ -51,9 +51,11 @@ function test.LoadInvalid() local img_binary = torch.rand(file_size_bytes):mul(255):byte() -- Now decompress the image from the ByteTensor - local img_from_tensor = image.decompressJPG(img_binary) + local ok, img_from_tensor = pcall(function() + return image.decompressJPG(img_binary) + end) - mytester:assert(img_from_tensor == nil, + mytester:assert(not ok or img_from_tensor == nil, 'A non-nil was returned on an invalid input! ') end |