diff options
author | Sasank Chilamkurthy <sasankchilamkurthy@gmail.com> | 2016-08-23 16:57:41 +0300 |
---|---|---|
committer | Sasank Chilamkurthy <sasankchilamkurthy@gmail.com> | 2016-08-23 16:59:12 +0300 |
commit | e3941d9b0a053e73bc70c10cd5c806b11d0ac6e5 (patch) | |
tree | e2b8db3017032d9d0f61ae8b8e4179c5a72f2202 | |
parent | 797fcb101b76c9b329b3ee83349b0f6adeacac94 (diff) |
Add affine transform function
-rw-r--r-- | init.lua | 121 |
1 files changed, 121 insertions, 0 deletions
@@ -969,6 +969,127 @@ end rawset(image, 'warp', warp) ---------------------------------------------------------------------- +-- affine transform +-- +local function affinetransform(...) + local dst,src,matrix + local mode = 'bilinear' + local translation = torch.Tensor{0,0} + local clamp_mode = 'clamp' + local pad_value = 0 + local args = {...} + local nargs = select('#',...) + local bad_args = false + if nargs == 2 then + src = args[1] + matrix = args[2] + elseif nargs >= 3 then + if type(args[3]) == 'string' then + -- No destination tensor + src = args[1] + matrix = args[2] + mode = args[3] + if nargs >= 4 then translation = args[4] end + if nargs >= 5 then clamp_mode = args[5] end + if nargs >= 6 then + assert(clamp_mode == 'pad', 'pad_value can only be specified if' .. + ' clamp_mode = "pad"') + pad_value = args[6] + end + if nargs >= 7 then bad_args = true end + else + -- With Destination tensor + dst = args[1] + src = args[2] + matrix = args[3] + if nargs >= 4 then mode = args[4] end + if nargs >= 5 then translation = args[5] end + if nargs >= 6 then clamp_mode = args[6] end + if nargs >= 7 then + assert(clamp_mode == 'pad', 'pad_value can only be specified if' .. + ' clamp_mode = "pad"') + pad_value = args[7] + end + if nargs >= 8 then bad_args = true end + end + end + if bad_args then + print(dok.usage('image.warp', + 'warp an image, according to given affine transform', nil, + {type='torch.Tensor', help='input image (KxHxW)', req=true}, + {type='torch.Tensor', help='(y,x) affine translation matrix', req=true}, + {type='string', help='mode: lanczos | bicubic | bilinear | simple', default='bilinear'}, + {type='torch.Tensor', help='extra (y,x) translation to be done before transform', default=torch.Tensor{0,0}}, + {type='string', help='clamp mode: how to handle interp of samples off the input image (clamp | pad)', default='clamp'}, + '', + {type='torch.Tensor', help='input image (KxHxW)', req=true}, + {type='torch.Tensor', help='(y,x) affine translation matrix', req=true}, + {type='string', help='mode: lanczos | bicubic | bilinear | simple', default='bilinear'}, + {type='torch.Tensor', help='extra (y,x) translation to be done before transform', default=torch.Tensor{0,0}}, + {type='string', help='clamp mode: how to handle interp of samples off the input image (clamp | pad)', default='clamp'}, + {type='number', help='pad value: value to pad image. Can only be set when clamp mode equals "pad"', default=0})) + dok.error('incorrect arguments', 'image.warp') + end + -- This is a little messy, but convert mode string to an enum + if (mode == 'simple') then + mode = 0 + elseif (mode == 'bilinear') then + mode = 1 + elseif (mode == 'bicubic') then + mode = 2 + elseif (mode == 'lanczos') then + mode = 3 + else + dok.error('Incorrect arguments (mode is not lanczos | bicubic | bilinear | simple)!', 'image.warp') + end + if (clamp_mode == 'clamp') then + clamp_mode = 0 + elseif (clamp_mode == 'pad') then + clamp_mode = 1 + else + dok.error('Incorrect arguments (clamp_mode is not clamp | pad)!', 'image.warp') + end + + local dim2 = false + if src:nDimension() == 2 then + dim2 = true + src = src:reshape(1,src:size(1),src:size(2)) + end + dst = dst or src.new() + dst:resize(src:size(1), src:size(2), src:size(3)) + + -- create field + local height = src:size(2) + local width = src:size(3) + + local grid_y = torch.ger( torch.linspace(-1,1,height), torch.ones(width) ) + local grid_x = torch.ger( torch.ones(height), torch.linspace(-1,1,width) ) + + local grid_xy = torch.FloatTensor() + grid_xy:resize(2,height,width) + grid_xy[1] = grid_y * ((height-1)/2) * -1 + grid_xy[2] = grid_x * ((width-1)/2) * -1 + local view_xy = grid_xy:reshape(2,height*width) + + local field = torch.mm(matrix, view_xy) + field = (grid_xy - field:reshape( 2, height, width )):double() + + -- offset field for translation + translation = torch.Tensor(translation) + field[1] = field[1] - translation[1] + field[2] = field[2] - translation[2] + + + local offset_mode = true + src.image.warp(dst, src, field, mode, offset_mode, clamp_mode, pad_value) + if dim2 then + dst = dst[1] + end + return dst +end +rawset(image, 'affinetransform', affinetransform) + +---------------------------------------------------------------------- -- hflip -- local function hflip(...) |