diff options
author | Soumith Chintala <soumith@gmail.com> | 2015-01-14 23:00:28 +0300 |
---|---|---|
committer | Soumith Chintala <soumith@gmail.com> | 2015-01-14 23:00:28 +0300 |
commit | 23c312e3ba6d9050a340cb76b9c8edb27c0dabd2 (patch) | |
tree | e3ef55333d064ae1cb0478bda1d5fb474f13f48c | |
parent | 4661af509e8b268717baf2732511bb0374f08246 (diff) | |
parent | 81cfeaf4461893952befa5d8a8122cf7d8b2484f (diff) |
Merge pull request #39 from nagadomi/rotate_bilinear
Add bilinear mode for rotate()
-rw-r--r-- | README.md | 4 | ||||
-rw-r--r-- | generic/image.c | 98 | ||||
-rwxr-xr-x | init.lua | 28 | ||||
-rw-r--r-- | test/test_rotate.lua | 75 |
4 files changed, 200 insertions, 5 deletions
@@ -77,9 +77,11 @@ Rescale the height and width of image `src` to fit the dimensions of Tensor `dst`. <a name="image.rotate"/> -### [res] image.rotate([dst,], src, theta) ### +### [res] image.rotate([dst,], src, theta, [mode]) ### Rotates image `src` by `theta` radians. If `dst` is specified it is used to store the results of the rotation. +Variable `mode` specifies type of interpolation to be used. Valid values include +*simple*(the default) or *bilinear* interpolation. <a name="image.hflip"/> ### [res] image.hflip([dst,] src) ### diff --git a/generic/image.c b/generic/image.c index 377467d..eed7f5f 100644 --- a/generic/image.c +++ b/generic/image.c @@ -338,6 +338,103 @@ static int image_(Main_rotate)(lua_State *L) } return 0; } +static int image_(Main_rotateBilinear)(lua_State *L) +{ + THTensor *Tsrc = luaT_checkudata(L, 1, torch_Tensor); + THTensor *Tdst = luaT_checkudata(L, 2, torch_Tensor); + float theta = luaL_checknumber(L, 3); + real *src, *dst; + long dst_stride0, dst_stride1, dst_stride2, dst_width, dst_height, dst_depth; + long src_stride0, src_stride1, src_stride2, src_width, src_height, src_depth; + long i, j, k; + float xc, yc; + float id,jd; + long ii_0, ii_1, jj_0, jj_1; + + luaL_argcheck(L, Tsrc->nDimension==2 || Tsrc->nDimension==3, 1, "rotate: src not 2 or 3 dimensional"); + luaL_argcheck(L, Tdst->nDimension==2 || Tdst->nDimension==3, 2, "rotate: dst not 2 or 3 dimensional"); + + src= THTensor_(data)(Tsrc); + dst= THTensor_(data)(Tdst); + + dst_stride0 = 0; + dst_stride1 = Tdst->stride[Tdst->nDimension-2]; + dst_stride2 = Tdst->stride[Tdst->nDimension-1]; + dst_depth = 0; + dst_height = Tdst->size[Tdst->nDimension-2]; + dst_width = Tdst->size[Tdst->nDimension-1]; + if(Tdst->nDimension == 3) { + dst_stride0 = Tdst->stride[0]; + dst_depth = Tdst->size[0]; + } + + src_stride0 = 0; + src_stride1 = Tsrc->stride[Tsrc->nDimension-2]; + src_stride2 = Tsrc->stride[Tsrc->nDimension-1]; + src_depth = 0; + src_height = Tsrc->size[Tsrc->nDimension-2]; + src_width = Tsrc->size[Tsrc->nDimension-1]; + if(Tsrc->nDimension == 3) { + src_stride0 = Tsrc->stride[0]; + src_depth = Tsrc->size[0]; + } + + if( Tsrc->nDimension==3 && Tdst->nDimension==3 && ( src_depth!=dst_depth) ) + luaL_error(L, "image.rotate: src and dst depths do not match"); + + if( (Tsrc->nDimension!=Tdst->nDimension) ) + luaL_error(L, "image.rotate: src and dst depths do not match"); + + xc=src_width/2.0; + yc=src_height/2.0; + + for(j = 0; j < dst_height; j++) { + jd=j; + for(i = 0; i < dst_width; i++) { + float val = -1; + real ri, rj, wi, wj; + id= i; + ri = cos(theta)*(id-xc)-sin(theta)*(jd-yc); + rj = cos(theta)*(jd-yc)+sin(theta)*(id-xc); + + ii_0=(long)floor(ri); + ii_1=ii_0 + 1; + jj_0=(long)floor(rj); + jj_1=jj_0 + 1; + wi = ri - ii_0; + wj = rj - jj_0; + ii_0+=(long) xc; ii_1+=(long) xc; jj_0+=(long) yc;jj_1+=(long) yc; + + /* rotated corners are blank */ + if(ii_1>src_width-1) val=0; + if(jj_1>src_height-1) val=0; + if(ii_0<0) val=0; + if(jj_0<0) val=0; + + if(Tsrc->nDimension==2) { + if(val==-1) + val = (1.0 - wi) * (1.0 - wj) * src[ii_0*src_stride2+jj_0*src_stride1] + + wi * (1.0 - wj) * src[ii_1*src_stride2+jj_0*src_stride1] + + (1.0 - wi) * wj * src[ii_0*src_stride2+jj_1*src_stride1] + + wi * wj * src[ii_1*src_stride2+jj_1*src_stride1]; + dst[i*dst_stride2+j*dst_stride1] = val; + } else { + int do_copy=0; if(val==-1) do_copy=1; + for(k=0;k<src_depth;k++) { + if(do_copy) { + val = (1.0 - wi) * (1.0 - wj) * src[ii_0*src_stride2+jj_0*src_stride1+k*src_stride0] + + wi * (1.0 - wj) * src[ii_1*src_stride2+jj_0*src_stride1+k*src_stride0] + + (1.0 - wi) * wj * src[ii_0*src_stride2+jj_1*src_stride1+k*src_stride0] + + wi * wj * src[ii_1*src_stride2+jj_1*src_stride1+k*src_stride0]; + } + dst[i*dst_stride2+j*dst_stride1+k*dst_stride0] = val; + } + } + } + } + return 0; +} + static int image_(Main_cropNoScale)(lua_State *L) { @@ -1074,6 +1171,7 @@ static const struct luaL_Reg image_(Main__) [] = { {"scaleSimple", image_(Main_scaleSimple)}, {"scaleBilinear", image_(Main_scaleBilinear)}, {"rotate", image_(Main_rotate)}, + {"rotateBilinear", image_(Main_rotateBilinear)}, {"translate", image_(Main_translate)}, {"cropNoScale", image_(Main_cropNoScale)}, {"warp", image_(Main_warp)}, @@ -499,12 +499,23 @@ rawset(image, 'scale', scale) -- rotate -- local function rotate(...) - local dst,src,theta + local dst,src,theta, mode local args = {...} - if select('#',...) == 3 then + if select('#',...) == 4 then dst = args[1] src = args[2] theta = args[3] + mode = args[4] + elseif select('#',...) == 3 then + if type(args[2]) == 'number' then + src = args[1] + theta = args[2] + mode = args[3] + else + dst = args[1] + src = args[2] + theta = args[3] + end elseif select('#',...) == 2 then src = args[1] theta = args[2] @@ -513,15 +524,24 @@ local function rotate(...) 'rotate an image by theta radians', nil, {type='torch.Tensor', help='input image', req=true}, {type='number', help='rotation angle (in radians)', req=true}, + {type='string', help='mode: simple | bilinear', default='simple'}, '', {type='torch.Tensor', help='destination', req=true}, {type='torch.Tensor', help='input image', req=true}, - {type='number', help='rotation angle (in radians)', req=true})) + {type='number', help='rotation angle (in radians)', req=true}, + {type='string', help='mode: simple | bilinear', default='simple'})) dok.error('incorrect arguments', 'image.rotate') end dst = dst or src.new() dst:resizeAs(src) - src.image.rotate(src,dst,theta) + mode = mode or 'simple' + if mode == 'simple' then + src.image.rotate(src,dst,theta) + elseif mode == 'bilinear' then + src.image.rotateBilinear(src,dst,theta) + else + dok.error('mode must be one of: simple | bilinear', 'image.rotate') + end return dst end rawset(image, 'rotate', rotate) diff --git a/test/test_rotate.lua b/test/test_rotate.lua new file mode 100644 index 0000000..fb39ac3 --- /dev/null +++ b/test/test_rotate.lua @@ -0,0 +1,75 @@ +require 'image' + +torch.setdefaulttensortype('torch.FloatTensor') +torch.setnumthreads(16) + +local function test_rotate(src, mode) + torch.manualSeed(11) + local mean_dist = 0.0 + for i = 1, 10 do + local theta = torch.uniform(0, 2 * math.pi) + local d1, d2, d3, d4 + + -- rotate + if mode then + d1 = image.rotate(src, theta, mode) + d2 = src.new():resizeAs(src) + image.rotate(d2, src, theta, mode) + else + d1 = image.rotate(src, theta) + d2 = src.new():resizeAs(src) + image.rotate(d2, src, theta) + end + -- revert + local revert = 2 * math.pi - theta + if mode then + d3 = image.rotate(d1, revert, mode) + d4 = src.new():resizeAs(src) + image.rotate(d4, d2, revert, mode) + else + d3 = image.rotate(d1, revert) + d4 = src.new():resizeAs(src) + image.rotate(d4, d2, revert) + end + + -- diff + if src:dim() == 3 then + local cs = image.crop(src, src:size(2) / 4, src:size(3) / 4, src:size(2) / 4 * 3, src:size(3) / 4 * 3) + local c3 = image.crop(d3, src:size(2) / 4, src:size(3) / 4, src:size(2) / 4 * 3, src:size(3) / 4 * 3) + local c4 = image.crop(d4, src:size(2) / 4, src:size(3) / 4, src:size(2) / 4 * 3, src:size(3) / 4 * 3) + + mean_dist = mean_dist + cs:dist(c3) + mean_dist = mean_dist + cs:dist(c4) + elseif src:dim() == 2 then + local cs = image.crop(src, src:size(1) / 4, src:size(2) / 4, src:size(1) / 4 * 3, src:size(2) / 4 * 3) + local c3 = image.crop(d3, src:size(1) / 4, src:size(2) / 4, src:size(1) / 4 * 3, src:size(2) / 4 * 3) + local c4 = image.crop(d4, src:size(1) / 4, src:size(2) / 4, src:size(1) / 4 * 3, src:size(2) / 4 * 3) + mean_dist = mean_dist + cs:dist(c3) + mean_dist = mean_dist + cs:dist(c4) + end + if i == 1 then + --[[ + image.display(src) + image.display(d1) + image.display(d2) + image.display(d3) + image.display(d4) + --]] + end + end + if mode then + print("mode = " .. mode .. ", mean dist: " .. mean_dist / (10 * 2)) + else + print("mode = nil, mean dist: " .. mean_dist / (10 * 2)) + end +end +local src = image.scale(image.lena(), 128, 128, 'bilinear') +print("** dim3") +test_rotate(src, nil) +test_rotate(src, 'simple') +test_rotate(src, 'bilinear') +print("** dim2") +src = src:select(1, 1) +test_rotate(src, nil) +test_rotate(src, 'simple') +test_rotate(src, 'bilinear') |