Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/soumith/cudnn.torch.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorSoumith Chintala <soumith@gmail.com>2014-09-26 20:14:36 +0400
committerSoumith Chintala <soumith@gmail.com>2014-09-26 20:47:31 +0400
commit2da6353dc28457c5b8dd758d552026c9aebdbcba (patch)
tree05e4b1ea5985e3f91c48d716004400fb44957231 /test
parent440ea8adeda929efc637316585446eb9996bd6fe (diff)
adding serialization to unit tests and fixing descriptor checks. Fixes #4
Diffstat (limited to 'test')
-rw-r--r--test/test.lua28
1 files changed, 28 insertions, 0 deletions
diff --git a/test/test.lua b/test/test.lua
index 0ca3ae5..78a4fdc 100644
--- a/test/test.lua
+++ b/test/test.lua
@@ -63,6 +63,12 @@ function cudnntest.SpatialConvolution_backward()
gconv.weight:copy(sconv.weight)
gconv.bias:copy(sconv.bias)
gconv:forward(input)
+
+ -- serialize and deserialize
+ torch.save('modelTemp.t7', gconv)
+ gconv = torch.load('modelTemp.t7')
+
+ gconv:forward(input)
gconv:zeroGradParameters()
local rescuda = gconv:backward(input, gradOutput)
cutorch.synchronize()
@@ -98,6 +104,10 @@ function cudnntest.SpatialMaxPooling()
cutorch.synchronize()
local gconv = cudnn.SpatialMaxPooling(ki,kj,si,sj):cuda()
local rescuda = gconv:forward(input)
+ -- serialize and deserialize
+ torch.save('modelTemp.t7', gconv)
+ gconv = torch.load('modelTemp.t7')
+ local rescuda = gconv:forward(input)
local resgrad = gconv:backward(input, gradOutput)
cutorch.synchronize()
local error = rescuda:float() - groundtruth:float()
@@ -126,6 +136,12 @@ function cudnntest.ReLU()
cutorch.synchronize()
local gconv = cudnn.ReLU(ki,kj,si,sj):cuda()
local rescuda = gconv:forward(input)
+
+ -- serialize and deserialize
+ torch.save('modelTemp.t7', gconv)
+ gconv = torch.load('modelTemp.t7')
+
+ local rescuda = gconv:forward(input)
local resgrad = gconv:backward(input, gradOutput)
cutorch.synchronize()
local error = rescuda:float() - groundtruth:float()
@@ -154,6 +170,12 @@ function cudnntest.Tanh()
cutorch.synchronize()
local gconv = cudnn.Tanh(ki,kj,si,sj):cuda()
local rescuda = gconv:forward(input)
+
+ -- serialize and deserialize
+ torch.save('modelTemp.t7', gconv)
+ gconv = torch.load('modelTemp.t7')
+
+ local rescuda = gconv:forward(input)
local resgrad = gconv:backward(input, gradOutput)
cutorch.synchronize()
local error = rescuda:float() - groundtruth:float()
@@ -182,6 +204,12 @@ function cudnntest.Sigmoid()
cutorch.synchronize()
local gconv = cudnn.Tanh(ki,kj,si,sj):cuda()
local rescuda = gconv:forward(input)
+
+ -- serialize and deserialize
+ torch.save('modelTemp.t7', gconv)
+ gconv = torch.load('modelTemp.t7')
+
+ local rescuda = gconv:forward(input)
local resgrad = gconv:backward(input, gradOutput)
cutorch.synchronize()
local error = rescuda:float() - groundtruth:float()