Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/image.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSoumith Chintala <soumith@gmail.com>2015-01-14 23:00:28 +0300
committerSoumith Chintala <soumith@gmail.com>2015-01-14 23:00:28 +0300
commit23c312e3ba6d9050a340cb76b9c8edb27c0dabd2 (patch)
treee3ef55333d064ae1cb0478bda1d5fb474f13f48c
parent4661af509e8b268717baf2732511bb0374f08246 (diff)
parent81cfeaf4461893952befa5d8a8122cf7d8b2484f (diff)
Merge pull request #39 from nagadomi/rotate_bilinear
Add bilinear mode for rotate()
-rw-r--r--README.md4
-rw-r--r--generic/image.c98
-rwxr-xr-xinit.lua28
-rw-r--r--test/test_rotate.lua75
4 files changed, 200 insertions, 5 deletions
diff --git a/README.md b/README.md
index 7af68eb..88c7948 100644
--- a/README.md
+++ b/README.md
@@ -77,9 +77,11 @@ Rescale the height and width of image `src` to fit the dimensions of
Tensor `dst`.
<a name="image.rotate"/>
-### [res] image.rotate([dst,], src, theta) ###
+### [res] image.rotate([dst,], src, theta, [mode]) ###
Rotates image `src` by `theta` radians.
If `dst` is specified it is used to store the results of the rotation.
+Variable `mode` specifies type of interpolation to be used. Valid values include
+*simple*(the default) or *bilinear* interpolation.
<a name="image.hflip"/>
### [res] image.hflip([dst,] src) ###
diff --git a/generic/image.c b/generic/image.c
index 377467d..eed7f5f 100644
--- a/generic/image.c
+++ b/generic/image.c
@@ -338,6 +338,103 @@ static int image_(Main_rotate)(lua_State *L)
}
return 0;
}
+static int image_(Main_rotateBilinear)(lua_State *L)
+{
+ THTensor *Tsrc = luaT_checkudata(L, 1, torch_Tensor);
+ THTensor *Tdst = luaT_checkudata(L, 2, torch_Tensor);
+ float theta = luaL_checknumber(L, 3);
+ real *src, *dst;
+ long dst_stride0, dst_stride1, dst_stride2, dst_width, dst_height, dst_depth;
+ long src_stride0, src_stride1, src_stride2, src_width, src_height, src_depth;
+ long i, j, k;
+ float xc, yc;
+ float id,jd;
+ long ii_0, ii_1, jj_0, jj_1;
+
+ luaL_argcheck(L, Tsrc->nDimension==2 || Tsrc->nDimension==3, 1, "rotate: src not 2 or 3 dimensional");
+ luaL_argcheck(L, Tdst->nDimension==2 || Tdst->nDimension==3, 2, "rotate: dst not 2 or 3 dimensional");
+
+ src= THTensor_(data)(Tsrc);
+ dst= THTensor_(data)(Tdst);
+
+ dst_stride0 = 0;
+ dst_stride1 = Tdst->stride[Tdst->nDimension-2];
+ dst_stride2 = Tdst->stride[Tdst->nDimension-1];
+ dst_depth = 0;
+ dst_height = Tdst->size[Tdst->nDimension-2];
+ dst_width = Tdst->size[Tdst->nDimension-1];
+ if(Tdst->nDimension == 3) {
+ dst_stride0 = Tdst->stride[0];
+ dst_depth = Tdst->size[0];
+ }
+
+ src_stride0 = 0;
+ src_stride1 = Tsrc->stride[Tsrc->nDimension-2];
+ src_stride2 = Tsrc->stride[Tsrc->nDimension-1];
+ src_depth = 0;
+ src_height = Tsrc->size[Tsrc->nDimension-2];
+ src_width = Tsrc->size[Tsrc->nDimension-1];
+ if(Tsrc->nDimension == 3) {
+ src_stride0 = Tsrc->stride[0];
+ src_depth = Tsrc->size[0];
+ }
+
+ if( Tsrc->nDimension==3 && Tdst->nDimension==3 && ( src_depth!=dst_depth) )
+ luaL_error(L, "image.rotate: src and dst depths do not match");
+
+ if( (Tsrc->nDimension!=Tdst->nDimension) )
+ luaL_error(L, "image.rotate: src and dst depths do not match");
+
+ xc=src_width/2.0;
+ yc=src_height/2.0;
+
+ for(j = 0; j < dst_height; j++) {
+ jd=j;
+ for(i = 0; i < dst_width; i++) {
+ float val = -1;
+ real ri, rj, wi, wj;
+ id= i;
+ ri = cos(theta)*(id-xc)-sin(theta)*(jd-yc);
+ rj = cos(theta)*(jd-yc)+sin(theta)*(id-xc);
+
+ ii_0=(long)floor(ri);
+ ii_1=ii_0 + 1;
+ jj_0=(long)floor(rj);
+ jj_1=jj_0 + 1;
+ wi = ri - ii_0;
+ wj = rj - jj_0;
+ ii_0+=(long) xc; ii_1+=(long) xc; jj_0+=(long) yc;jj_1+=(long) yc;
+
+ /* rotated corners are blank */
+ if(ii_1>src_width-1) val=0;
+ if(jj_1>src_height-1) val=0;
+ if(ii_0<0) val=0;
+ if(jj_0<0) val=0;
+
+ if(Tsrc->nDimension==2) {
+ if(val==-1)
+ val = (1.0 - wi) * (1.0 - wj) * src[ii_0*src_stride2+jj_0*src_stride1]
+ + wi * (1.0 - wj) * src[ii_1*src_stride2+jj_0*src_stride1]
+ + (1.0 - wi) * wj * src[ii_0*src_stride2+jj_1*src_stride1]
+ + wi * wj * src[ii_1*src_stride2+jj_1*src_stride1];
+ dst[i*dst_stride2+j*dst_stride1] = val;
+ } else {
+ int do_copy=0; if(val==-1) do_copy=1;
+ for(k=0;k<src_depth;k++) {
+ if(do_copy) {
+ val = (1.0 - wi) * (1.0 - wj) * src[ii_0*src_stride2+jj_0*src_stride1+k*src_stride0]
+ + wi * (1.0 - wj) * src[ii_1*src_stride2+jj_0*src_stride1+k*src_stride0]
+ + (1.0 - wi) * wj * src[ii_0*src_stride2+jj_1*src_stride1+k*src_stride0]
+ + wi * wj * src[ii_1*src_stride2+jj_1*src_stride1+k*src_stride0];
+ }
+ dst[i*dst_stride2+j*dst_stride1+k*dst_stride0] = val;
+ }
+ }
+ }
+ }
+ return 0;
+}
+
static int image_(Main_cropNoScale)(lua_State *L)
{
@@ -1074,6 +1171,7 @@ static const struct luaL_Reg image_(Main__) [] = {
{"scaleSimple", image_(Main_scaleSimple)},
{"scaleBilinear", image_(Main_scaleBilinear)},
{"rotate", image_(Main_rotate)},
+ {"rotateBilinear", image_(Main_rotateBilinear)},
{"translate", image_(Main_translate)},
{"cropNoScale", image_(Main_cropNoScale)},
{"warp", image_(Main_warp)},
diff --git a/init.lua b/init.lua
index 6e078c3..6bdb21a 100755
--- a/init.lua
+++ b/init.lua
@@ -499,12 +499,23 @@ rawset(image, 'scale', scale)
-- rotate
--
local function rotate(...)
- local dst,src,theta
+ local dst,src,theta, mode
local args = {...}
- if select('#',...) == 3 then
+ if select('#',...) == 4 then
dst = args[1]
src = args[2]
theta = args[3]
+ mode = args[4]
+ elseif select('#',...) == 3 then
+ if type(args[2]) == 'number' then
+ src = args[1]
+ theta = args[2]
+ mode = args[3]
+ else
+ dst = args[1]
+ src = args[2]
+ theta = args[3]
+ end
elseif select('#',...) == 2 then
src = args[1]
theta = args[2]
@@ -513,15 +524,24 @@ local function rotate(...)
'rotate an image by theta radians', nil,
{type='torch.Tensor', help='input image', req=true},
{type='number', help='rotation angle (in radians)', req=true},
+ {type='string', help='mode: simple | bilinear', default='simple'},
'',
{type='torch.Tensor', help='destination', req=true},
{type='torch.Tensor', help='input image', req=true},
- {type='number', help='rotation angle (in radians)', req=true}))
+ {type='number', help='rotation angle (in radians)', req=true},
+ {type='string', help='mode: simple | bilinear', default='simple'}))
dok.error('incorrect arguments', 'image.rotate')
end
dst = dst or src.new()
dst:resizeAs(src)
- src.image.rotate(src,dst,theta)
+ mode = mode or 'simple'
+ if mode == 'simple' then
+ src.image.rotate(src,dst,theta)
+ elseif mode == 'bilinear' then
+ src.image.rotateBilinear(src,dst,theta)
+ else
+ dok.error('mode must be one of: simple | bilinear', 'image.rotate')
+ end
return dst
end
rawset(image, 'rotate', rotate)
diff --git a/test/test_rotate.lua b/test/test_rotate.lua
new file mode 100644
index 0000000..fb39ac3
--- /dev/null
+++ b/test/test_rotate.lua
@@ -0,0 +1,75 @@
+require 'image'
+
+torch.setdefaulttensortype('torch.FloatTensor')
+torch.setnumthreads(16)
+
+local function test_rotate(src, mode)
+ torch.manualSeed(11)
+ local mean_dist = 0.0
+ for i = 1, 10 do
+ local theta = torch.uniform(0, 2 * math.pi)
+ local d1, d2, d3, d4
+
+ -- rotate
+ if mode then
+ d1 = image.rotate(src, theta, mode)
+ d2 = src.new():resizeAs(src)
+ image.rotate(d2, src, theta, mode)
+ else
+ d1 = image.rotate(src, theta)
+ d2 = src.new():resizeAs(src)
+ image.rotate(d2, src, theta)
+ end
+ -- revert
+ local revert = 2 * math.pi - theta
+ if mode then
+ d3 = image.rotate(d1, revert, mode)
+ d4 = src.new():resizeAs(src)
+ image.rotate(d4, d2, revert, mode)
+ else
+ d3 = image.rotate(d1, revert)
+ d4 = src.new():resizeAs(src)
+ image.rotate(d4, d2, revert)
+ end
+
+ -- diff
+ if src:dim() == 3 then
+ local cs = image.crop(src, src:size(2) / 4, src:size(3) / 4, src:size(2) / 4 * 3, src:size(3) / 4 * 3)
+ local c3 = image.crop(d3, src:size(2) / 4, src:size(3) / 4, src:size(2) / 4 * 3, src:size(3) / 4 * 3)
+ local c4 = image.crop(d4, src:size(2) / 4, src:size(3) / 4, src:size(2) / 4 * 3, src:size(3) / 4 * 3)
+
+ mean_dist = mean_dist + cs:dist(c3)
+ mean_dist = mean_dist + cs:dist(c4)
+ elseif src:dim() == 2 then
+ local cs = image.crop(src, src:size(1) / 4, src:size(2) / 4, src:size(1) / 4 * 3, src:size(2) / 4 * 3)
+ local c3 = image.crop(d3, src:size(1) / 4, src:size(2) / 4, src:size(1) / 4 * 3, src:size(2) / 4 * 3)
+ local c4 = image.crop(d4, src:size(1) / 4, src:size(2) / 4, src:size(1) / 4 * 3, src:size(2) / 4 * 3)
+ mean_dist = mean_dist + cs:dist(c3)
+ mean_dist = mean_dist + cs:dist(c4)
+ end
+ if i == 1 then
+ --[[
+ image.display(src)
+ image.display(d1)
+ image.display(d2)
+ image.display(d3)
+ image.display(d4)
+ --]]
+ end
+ end
+ if mode then
+ print("mode = " .. mode .. ", mean dist: " .. mean_dist / (10 * 2))
+ else
+ print("mode = nil, mean dist: " .. mean_dist / (10 * 2))
+ end
+end
+local src = image.scale(image.lena(), 128, 128, 'bilinear')
+print("** dim3")
+test_rotate(src, nil)
+test_rotate(src, 'simple')
+test_rotate(src, 'bilinear')
+print("** dim2")
+src = src:select(1, 1)
+test_rotate(src, nil)
+test_rotate(src, 'simple')
+test_rotate(src, 'bilinear')