Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/clementfarabet/lua---nnx.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'DataList.lua')
-rw-r--r--DataList.lua19
1 files changed, 17 insertions, 2 deletions
diff --git a/DataList.lua b/DataList.lua
index 99b117a..4922e8b 100644
--- a/DataList.lua
+++ b/DataList.lua
@@ -13,6 +13,8 @@ function DataList:__init()
self.nbClass = 0
self.ClassName = {}
self.nbSamples = 0
+ self.targetIsProbability = false
+ self.spatialTarget = false
end
function DataList:__tostring__()
@@ -30,8 +32,21 @@ function DataList:__index__(key)
elmt = ((elmt-1) % classSize) + 1
-- create target vector on the fly
- self.datasets[class][elmt][2] = torch.Tensor(1,1,self.nbClass):fill(-1)
- self.datasets[class][elmt][2][1][1][class] = 1
+ if self.spatialTarget then
+ if self.targetIsProbability then
+ self.datasets[class][elmt][2] = torch.Tensor(self.nbClass,1,1):zero()
+ else
+ self.datasets[class][elmt][2] = torch.Tensor(self.nbClass,1,1):fill(-1)
+ end
+ self.datasets[class][elmt][2][class][1][1] = 1
+ else
+ if self.targetIsProbability then
+ self.datasets[class][elmt][2] = torch.Tensor(self.nbClass):zero()
+ else
+ self.datasets[class][elmt][2] = torch.Tensor(self.nbClass):fill(-1)
+ end
+ self.datasets[class][elmt][2][class] = 1
+ end
-- apply hook on sample
local sample = self.datasets[class][elmt]