Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/clementfarabet/lua---nnx.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorClement Farabet <clement.farabet@gmail.com>2011-07-07 10:25:08 +0400
committerClement Farabet <clement.farabet@gmail.com>2011-07-07 10:25:08 +0400
commit603dae08fd2eb10964ee9e546b5e90d7528aaf6e (patch)
tree17caa733a045e739b094683005674a0fcf706dc4 /DataList.lua
parentc20ef531db79a83bb751b65798e4d1b77dc7833d (diff)
Added old DataSet clsases.
Diffstat (limited to 'DataList.lua')
-rw-r--r--DataList.lua58
1 files changed, 58 insertions, 0 deletions
diff --git a/DataList.lua b/DataList.lua
new file mode 100644
index 0000000..5b95c48
--- /dev/null
+++ b/DataList.lua
@@ -0,0 +1,58 @@
+--------------------------------------------------------------------------------
+-- DataList: a container for plain DataSets.
+-- Each sub dataset represents one class.
+--
+-- Authors: Corda, Farabet
+--------------------------------------------------------------------------------
+
+local DataList, parent = torch.class('nn.DataList','nn.DataSet')
+
+function DataList:__init()
+ parent.__init(self)
+ self.datasets = {}
+ self.nbClass = 0
+ self.ClassName = {}
+ self.nbSamples = 0
+end
+
+function DataList:__tostring__()
+ str = 'DataList'
+ str = str .. ' + nb samples : '..self.nbSamples
+ str = str .. ' + nb classes : '..self.nbClass
+ return str
+end
+
+function DataList:__index__(key)
+ if type(key)=='number' and self.nbClass>0 and key <= self.nbSamples then
+ local class = ((key-1) % self.nbClass) + 1
+ local classSize = self.datasets[class]:size()
+ local elmt = math.floor((key-1)/self.nbClass) + 1
+ elmt = ((elmt-1) % classSize) + 1
+
+ -- create target vector on the fly
+ self.datasets[class][elmt][2] = torch.Tensor(1,1,self.nbClass):fill(-1)
+ self.datasets[class][elmt][2][1][1][class] = 1
+
+ -- apply hook on sample
+ local sample = self.datasets[class][elmt]
+ if self.hookOnSample then
+ sample = self.hookOnSample(self,sample)
+ end
+
+ return sample,true
+ end
+ -- if key is not a number this should return nil
+ return rawget(self, key)
+end
+
+function DataList:appendDataSet(dataSet,className)
+ table.insert(self.datasets,dataSet)
+ if self.nbSamples == 0 then
+ self.nbSamples = dataSet:size()
+ else
+ self.nbSamples = math.floor(math.max(self.nbSamples/self.nbClass,dataSet:size()))
+ end
+ self.nbClass = self.nbClass + 1
+ self.nbSamples = self.nbSamples * self.nbClass
+ table.insert(self.ClassName,self.nbClass,className)
+end