diff options
author | Soumith Chintala <soumith@gmail.com> | 2014-09-26 20:14:36 +0400 |
---|---|---|
committer | Soumith Chintala <soumith@gmail.com> | 2014-09-26 20:47:31 +0400 |
commit | 2da6353dc28457c5b8dd758d552026c9aebdbcba (patch) | |
tree | 05e4b1ea5985e3f91c48d716004400fb44957231 /test | |
parent | 440ea8adeda929efc637316585446eb9996bd6fe (diff) |
adding serialization to unit tests and fixing descriptor checks. Fixes #4
Diffstat (limited to 'test')
-rw-r--r-- | test/test.lua | 28 |
1 files changed, 28 insertions, 0 deletions
diff --git a/test/test.lua b/test/test.lua index 0ca3ae5..78a4fdc 100644 --- a/test/test.lua +++ b/test/test.lua @@ -63,6 +63,12 @@ function cudnntest.SpatialConvolution_backward() gconv.weight:copy(sconv.weight) gconv.bias:copy(sconv.bias) gconv:forward(input) + + -- serialize and deserialize + torch.save('modelTemp.t7', gconv) + gconv = torch.load('modelTemp.t7') + + gconv:forward(input) gconv:zeroGradParameters() local rescuda = gconv:backward(input, gradOutput) cutorch.synchronize() @@ -98,6 +104,10 @@ function cudnntest.SpatialMaxPooling() cutorch.synchronize() local gconv = cudnn.SpatialMaxPooling(ki,kj,si,sj):cuda() local rescuda = gconv:forward(input) + -- serialize and deserialize + torch.save('modelTemp.t7', gconv) + gconv = torch.load('modelTemp.t7') + local rescuda = gconv:forward(input) local resgrad = gconv:backward(input, gradOutput) cutorch.synchronize() local error = rescuda:float() - groundtruth:float() @@ -126,6 +136,12 @@ function cudnntest.ReLU() cutorch.synchronize() local gconv = cudnn.ReLU(ki,kj,si,sj):cuda() local rescuda = gconv:forward(input) + + -- serialize and deserialize + torch.save('modelTemp.t7', gconv) + gconv = torch.load('modelTemp.t7') + + local rescuda = gconv:forward(input) local resgrad = gconv:backward(input, gradOutput) cutorch.synchronize() local error = rescuda:float() - groundtruth:float() @@ -154,6 +170,12 @@ function cudnntest.Tanh() cutorch.synchronize() local gconv = cudnn.Tanh(ki,kj,si,sj):cuda() local rescuda = gconv:forward(input) + + -- serialize and deserialize + torch.save('modelTemp.t7', gconv) + gconv = torch.load('modelTemp.t7') + + local rescuda = gconv:forward(input) local resgrad = gconv:backward(input, gradOutput) cutorch.synchronize() local error = rescuda:float() - groundtruth:float() @@ -182,6 +204,12 @@ function cudnntest.Sigmoid() cutorch.synchronize() local gconv = cudnn.Tanh(ki,kj,si,sj):cuda() local rescuda = gconv:forward(input) + + -- serialize and deserialize + torch.save('modelTemp.t7', gconv) + gconv = torch.load('modelTemp.t7') + + local rescuda = gconv:forward(input) local resgrad = gconv:backward(input, gradOutput) cutorch.synchronize() local error = rescuda:float() - groundtruth:float() |