diff options
author | Clement Farabet <clement.farabet@gmail.com> | 2011-07-07 10:25:08 +0400 |
---|---|---|
committer | Clement Farabet <clement.farabet@gmail.com> | 2011-07-07 10:25:08 +0400 |
commit | 603dae08fd2eb10964ee9e546b5e90d7528aaf6e (patch) | |
tree | 17caa733a045e739b094683005674a0fcf706dc4 /DataList.lua | |
parent | c20ef531db79a83bb751b65798e4d1b77dc7833d (diff) |
Added old DataSet clsases.
Diffstat (limited to 'DataList.lua')
-rw-r--r-- | DataList.lua | 58 |
1 files changed, 58 insertions, 0 deletions
diff --git a/DataList.lua b/DataList.lua new file mode 100644 index 0000000..5b95c48 --- /dev/null +++ b/DataList.lua @@ -0,0 +1,58 @@ +-------------------------------------------------------------------------------- +-- DataList: a container for plain DataSets. +-- Each sub dataset represents one class. +-- +-- Authors: Corda, Farabet +-------------------------------------------------------------------------------- + +local DataList, parent = torch.class('nn.DataList','nn.DataSet') + +function DataList:__init() + parent.__init(self) + self.datasets = {} + self.nbClass = 0 + self.ClassName = {} + self.nbSamples = 0 +end + +function DataList:__tostring__() + str = 'DataList' + str = str .. ' + nb samples : '..self.nbSamples + str = str .. ' + nb classes : '..self.nbClass + return str +end + +function DataList:__index__(key) + if type(key)=='number' and self.nbClass>0 and key <= self.nbSamples then + local class = ((key-1) % self.nbClass) + 1 + local classSize = self.datasets[class]:size() + local elmt = math.floor((key-1)/self.nbClass) + 1 + elmt = ((elmt-1) % classSize) + 1 + + -- create target vector on the fly + self.datasets[class][elmt][2] = torch.Tensor(1,1,self.nbClass):fill(-1) + self.datasets[class][elmt][2][1][1][class] = 1 + + -- apply hook on sample + local sample = self.datasets[class][elmt] + if self.hookOnSample then + sample = self.hookOnSample(self,sample) + end + + return sample,true + end + -- if key is not a number this should return nil + return rawget(self, key) +end + +function DataList:appendDataSet(dataSet,className) + table.insert(self.datasets,dataSet) + if self.nbSamples == 0 then + self.nbSamples = dataSet:size() + else + self.nbSamples = math.floor(math.max(self.nbSamples/self.nbClass,dataSet:size())) + end + self.nbClass = self.nbClass + 1 + self.nbSamples = self.nbSamples * self.nbClass + table.insert(self.ClassName,self.nbClass,className) +end |