diff options
Diffstat (limited to 'DataList.lua')
-rw-r--r-- | DataList.lua | 19 |
1 files changed, 17 insertions, 2 deletions
diff --git a/DataList.lua b/DataList.lua index 99b117a..4922e8b 100644 --- a/DataList.lua +++ b/DataList.lua @@ -13,6 +13,8 @@ function DataList:__init() self.nbClass = 0 self.ClassName = {} self.nbSamples = 0 + self.targetIsProbability = false + self.spatialTarget = false end function DataList:__tostring__() @@ -30,8 +32,21 @@ function DataList:__index__(key) elmt = ((elmt-1) % classSize) + 1 -- create target vector on the fly - self.datasets[class][elmt][2] = torch.Tensor(1,1,self.nbClass):fill(-1) - self.datasets[class][elmt][2][1][1][class] = 1 + if self.spatialTarget then + if self.targetIsProbability then + self.datasets[class][elmt][2] = torch.Tensor(self.nbClass,1,1):zero() + else + self.datasets[class][elmt][2] = torch.Tensor(self.nbClass,1,1):fill(-1) + end + self.datasets[class][elmt][2][class][1][1] = 1 + else + if self.targetIsProbability then + self.datasets[class][elmt][2] = torch.Tensor(self.nbClass):zero() + else + self.datasets[class][elmt][2] = torch.Tensor(self.nbClass):fill(-1) + end + self.datasets[class][elmt][2][class] = 1 + end -- apply hook on sample local sample = self.datasets[class][elmt] |