Welcome to the mirror list, hosted at ThFree Co, Russian Federation.

github.com/clementfarabet/lua---nnx.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
authorMarco Scoffier <github@metm.org>2011-09-15 11:03:55 +0400
committerMarco Scoffier <github@metm.org>2011-09-15 11:03:55 +0400
commitab6e8c5fe0cec06682d6bc0bf7ae3c518e934d7f (patch)
treece8b07a81acd93547abc32cba36365a5a8517d41
parent28d484e2b0be59aceb3addea2b23b706523100c8 (diff)
changes to run notMNIST with padding
-rw-r--r--DataList.lua20
-rw-r--r--DataSet.lua20
2 files changed, 29 insertions, 11 deletions
diff --git a/DataList.lua b/DataList.lua
index 43617ed..287d7dd 100644
--- a/DataList.lua
+++ b/DataList.lua
@@ -12,6 +12,7 @@ function DataList:__init()
self.datasets = {}
self.nbClass = 0
self.ClassName = {}
+ self.ClassRev = {}
self.nbSamples = 0
self.targetIsProbability = false
self.spatialTarget = false
@@ -67,12 +68,17 @@ end
function DataList:appendDataSet(dataSet,className)
table.insert(self.datasets,dataSet)
- if self.nbSamples == 0 then
- self.nbSamples = dataSet:size()
- else
- self.nbSamples = math.floor(math.max(self.nbSamples/self.nbClass,dataSet:size()))
+ -- if self.nbSamples == 0 then
+ -- self.nbSamples = dataSet:size()
+ -- else
+ -- print(self.nbSamples/self.nbClass,dataSet:size())
+ -- self.nbSamples = math.floor(math.max(self.nbSamples/self.nbClass,
+ -- dataSet:size()))
+ -- end
+ if not self.ClassRev[className] then
+ self.ClassRev[className] = true
+ self.nbClass = self.nbClass + 1
+ table.insert(self.ClassName,self.nbClass,className)
end
- self.nbClass = self.nbClass + 1
- self.nbSamples = self.nbSamples * self.nbClass
- table.insert(self.ClassName,self.nbClass,className)
+ self.nbSamples = self.nbSamples + dataSet:size()
end
diff --git a/DataSet.lua b/DataSet.lua
index 36cb7c9..4bb618c 100644
--- a/DataSet.lua
+++ b/DataSet.lua
@@ -30,7 +30,8 @@ end
function lDataSet:load(...)
-- parse args
- local args, dataSetFolder, nbSamplesRequired, cacheFile, channels, sampleSize
+ local args, dataSetFolder, nbSamplesRequired, cacheFile, channels,
+ sampleSize,padding
= xlua.unpack(
{...},
'DataSet.load', nil,
@@ -38,7 +39,8 @@ function lDataSet:load(...)
{arg='nbSamplesRequired', type='number', help='number of patches to load', default='all'},
{arg='cacheFile', type='string', help='path to file to cache files'},
{arg='channels', type='number', help='nb of channels', default=1},
- {arg='sampleSize', type='table', help='resize all sample: {w,h}'}
+ {arg='sampleSize', type='table', help='resize all sample: {c,w,h}'},
+ {arg='padding', type='boolean', help='center sample in w,h dont rescale'}
)
self.cacheFileName = cacheFile or self.cacheFileName
@@ -115,7 +117,7 @@ end
function lDataSet:append(...)
-- parse args
local args, dataSetFolder, channels, nbSamplesRequired, useLabelPiped,
- useDirAsLabel, nbLabels, sampleSize
+ useDirAsLabel, nbLabels, sampleSize, padding
= xlua.unpack(
{...},
'DataSet:append', 'append a folder to the dataset object',
@@ -125,7 +127,8 @@ function lDataSet:append(...)
{arg='useLabelPiped', type='boolean', help='flag to use the filename as output value',default=false},
{arg='useDirAsLabel', type='boolean', help='flag to use the directory as label',default=false},
{arg='nbLabels', type='number', help='how many classes (goes with useDirAsLabel)', default=1},
- {arg='sampleSize', type='table', help='resize all sample: {w,h}'}
+ {arg='sampleSize', type='table', help='resize all sample: {c,w,h}'},
+ {arg='padding',type='boolean',help='do we padd all the inputs in w,h'}
)
-- parse args
local files = sys.dir(dataSetFolder)
@@ -180,6 +183,15 @@ function lDataSet:append(...)
-- rescale ?
if sampleSize then
inputs = torch.Tensor(channels, sampleSize[2], sampleSize[3])
+ if padding then
+ offw = math.floor((sampleSize[2] - input[2])*0.5)
+ offh = math.floor((sampleSize[3] - input[3])*0.5)
+ if offw >= 0 and offh >= 0 then
+ inputs:narrow(2,offw,input[2]):narrow(3,offh,input[3]):copy(input)
+ else
+ print('reverse crop not implemented w,h must be larger than all data points')
+ end
+ end
image.scale(input, inputs, 'bilinear')
else
inputs = input