diff options
-rw-r--r-- | SpatialNormalization.lua | 22 | ||||
-rw-r--r-- | init.lua | 3 | ||||
-rw-r--r-- | nnx-1.0-1.rockspec | 1 |
3 files changed, 22 insertions, 4 deletions
diff --git a/SpatialNormalization.lua b/SpatialNormalization.lua index 349098b..e0e5c60 100644 --- a/SpatialNormalization.lua +++ b/SpatialNormalization.lua @@ -186,14 +186,14 @@ function SpatialNormalization:forward(input) -- auto switch to 3-channel self.input = input if (input:nDimension() == 2) then - self.input = torch.Tensor(1,input:size(1),input:size(2)):copy(input) + self.input = input:clone():resize(1,input:size(1),input:size(2)) end -- recompute coef only if necessary if (self.input:size(3) ~= self.coef:size(2)) or (self.input:size(2) ~= self.coef:size(1)) then - local intVals = torch.Tensor(self.nfeatures,self.input:size(2),self.input:size(3)):fill(1) + local intVals = self.input.new(self.nfeatures,self.input:size(2),self.input:size(3)):fill(1) self.coef = self.convo:forward(intVals) - self.coef = torch.Tensor():resizeAs(self.coef):copy(self.coef) + self.coef = self.coef:clone() end -- compute mean @@ -221,7 +221,7 @@ function SpatialNormalization:backward(input, gradOutput) -- auto switch to 3-channel self.input = input if (input:nDimension() == 2) then - self.input = torch.Tensor(1,input:size(1),input:size(2)):copy(input) + self.input = input:clone():resize(1,input:size(1),input:size(2)) end self.gradInput:resizeAs(self.input):zero() @@ -246,6 +246,20 @@ function SpatialNormalization:backward(input, gradOutput) return self.gradInput end +function SpatialNormalization:type(type) + parent.type(self,type) + self.convo:type(type) + self.meanDiviseMod:type(type) + self.subtractMod:type(type) + self.squareMod:type(type) + self.convostd:type(type) + self.sqrtMod:type(type) + self.stdDiviseMod:type(type) + self.thresMod:type(type) + self.diviseMod:type(type) + return self +end + function SpatialNormalization:write(file) parent.write(self,file) file:writeObject(self.kernel) @@ -108,6 +108,9 @@ torch.include('nnx', 'Trainer.lua') torch.include('nnx', 'OnlineTrainer.lua') torch.include('nnx', 'BatchTrainer.lua') +-- conversion helper: +torch.include('nnx', 'Type.lua') + -- datasets: torch.include('nnx', 'DataSet.lua') torch.include('nnx', 'DataList.lua') diff --git a/nnx-1.0-1.rockspec b/nnx-1.0-1.rockspec index 120cafc..38c0af2 100644 --- a/nnx-1.0-1.rockspec +++ b/nnx-1.0-1.rockspec @@ -103,6 +103,7 @@ build = { install_files(/lua/nnx Probe.lua) install_files(/lua/nnx HardShrink.lua) install_files(/lua/nnx Narrow.lua) + install_files(/lua/nnx Type.lua) install_files(/lua/nnx Power.lua) install_files(/lua/nnx Square.lua) install_files(/lua/nnx Sqrt.lua) |