Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/clementfarabet/lua---nnx.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorClement Farabet <clement.farabet@gmail.com>2011-07-27 23:59:43 +0400
committerClement Farabet <clement.farabet@gmail.com>2011-07-27 23:59:43 +0400
commit4fe018c20f291e00b7bca2849c4cf0c9e1415903 (patch)
tree3e954a5329c27d6651bf0c551a0016860f60418f
parent1e180fefe5018accab7b17cb67658a42a5a37fbe (diff)
Added a filter to sample patches in a smarter way.
-rw-r--r--DataSetLabelMe.lua7
-rw-r--r--generic/DataSetLabelMe.c51
2 files changed, 40 insertions, 18 deletions
diff --git a/DataSetLabelMe.lua b/DataSetLabelMe.lua
index 960d07f..fed7b70 100644
--- a/DataSetLabelMe.lua
+++ b/DataSetLabelMe.lua
@@ -30,6 +30,7 @@ function DataSetLabelMe:__init(...)
{arg='nbPatchPerSample', type='number', help='number of patches to extract from each image', default=100},
{arg='patchSize', type='number', help='size of patches to extract from images', default=64},
{arg='samplingMode', type='string', help='patch sampling method: random | equal', default='random'},
+ {arg='samplingFilter', type='table', help='a filter to sample patches: {ratio=,size=,step}'},
{arg='labelType', type='string', help='type of label returned: center | pixelwise', default='center'},
{arg='labelGenerator', type='function', help='a function to generate sample+target (bypasses labelType)'},
{arg='infiniteSet', type='boolean', help='if true, the set can be indexed to infinity, looping around samples', default=false},
@@ -358,12 +359,16 @@ function DataSetLabelMe:parseMask(existing_tags)
end
end
end
+ -- use filter
+ local filter = self.samplingFilter or {ratio=0, size=self.patchSize, step=4}
+ -- extract labels
local mask = self.currentMask
local x_start = math.ceil(self.patchSize/2)
local x_end = mask:size(2) - math.ceil(self.patchSize/2)
local y_start = math.ceil(self.patchSize/2)
local y_end = mask:size(1) - math.ceil(self.patchSize/2)
- mask.nn.DataSetLabelMe_extract(tags, mask, x_start, x_end, y_start, y_end, self.currentIndex)
+ mask.nn.DataSetLabelMe_extract(tags, mask, x_start, x_end, y_start, y_end, self.currentIndex,
+ filter.ratio, filter.size, filter.step)
return tags
end
diff --git a/generic/DataSetLabelMe.c b/generic/DataSetLabelMe.c
index 0f74419..878c44b 100644
--- a/generic/DataSetLabelMe.c
+++ b/generic/DataSetLabelMe.c
@@ -12,33 +12,50 @@ static int nn_(DataSetLabelMe_extract)(lua_State *L)
int y_start = lua_tonumber(L, 5);
int y_end = lua_tonumber(L, 6);
int idx = lua_tonumber(L, 7);
+ float filter_ratio = lua_tonumber(L, 8);
+ int filter_size = lua_tonumber(L, 9);
+ int filter_step = lua_tonumber(L, 10);
+ float ratio = 1;
int x,y,label,tag,size;
THShortStorage *data;
for (x=x_start; x<=x_end; x++) {
for (y=y_start; y<=y_end; y++) {
- label = THTensor_(get2d)(mask, y-1, x-1); // label = mask[x][y]
- lua_rawgeti(L, tags, label); // tag = tags[label]
- tag = lua_gettop(L);
- lua_pushstring(L, "size"); lua_rawget(L, tag); // size = tag.size
- size = lua_tonumber(L,-1); lua_pop(L,1);
- lua_pushstring(L, "size"); lua_pushnumber(L, size+3); lua_rawset(L, tag); // tag.size = size + 3
- lua_pushstring(L, "data"); lua_rawget(L, tag); // data = tag.data
- data = luaT_checkudata(L, -1, torch_ShortStorage_id); lua_pop(L, 1);
- data->data[size] = x; // data[size+1] = x
- data->data[size+1] = y; // data[size+1] = y
- data->data[size+2] = idx; // data[size+1] = idx
- lua_pop(L, 1);
+ // label = mask[x][y]
+ label = THTensor_(get2d)(mask, y-1, x-1);
+
+ // optional filter: insures that at least N% of local pixels belong to the same class
+ if (filter_ratio > 0) {
+ int kx,ky,count=0,good=0;
+ for (kx=MAX(1,x-filter_size/2); kx<=MIN(x_end,x+filter_size/2); kx+=filter_step) {
+ for (ky=MAX(1,y-filter_size/2); ky<=MIN(y_end,y+filter_size/2); ky+=filter_step) {
+ int other = THTensor_(get2d)(mask, ky-1, kx-1);
+ if (other == label) good++;
+ count++;
+ }
+ }
+ ratio = (float)good/(float)count;
+ }
+
+ // if filter(s) satisfied, then append label
+ if (ratio >= filter_ratio) {
+ lua_rawgeti(L, tags, label); // tag = tags[label]
+ tag = lua_gettop(L);
+ lua_pushstring(L, "size"); lua_rawget(L, tag); // size = tag.size
+ size = lua_tonumber(L,-1); lua_pop(L,1);
+ lua_pushstring(L, "size"); lua_pushnumber(L, size+3); lua_rawset(L, tag); // tag.size = size + 3
+ lua_pushstring(L, "data"); lua_rawget(L, tag); // data = tag.data
+ data = luaT_checkudata(L, -1, torch_ShortStorage_id); lua_pop(L, 1);
+ data->data[size] = x; // data[size+1] = x
+ data->data[size+1] = y; // data[size+1] = y
+ data->data[size+2] = idx; // data[size+1] = idx
+ lua_pop(L, 1);
+ }
}
}
return 0;
}
-static int nn_(DataSetLabelMe_backward)(lua_State *L)
-{
-
-}
-
static const struct luaL_Reg nn_(DataSetLabelMe__) [] = {
{"DataSetLabelMe_extract", nn_(DataSetLabelMe_extract)},
{NULL, NULL}