Welcome to mirror list, hosted at ThFree Co, Russian Federation.

torch9luaffi.lua « benchmark - github.com/torch/argcheck.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: 161e592a62d1e2eee0c13a6f9e6ba8d25c988b4b (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
local argcheck = require 'argcheck'
local ffi = require 'ffi'

local env = require 'argcheck.env'

-- Command-line arguments (positional):
--   arg[1]  SZ     side length of the square input tensors (required)
--   arg[2]  N      number of timed iterations (required)
--   arg[3]  scale  `value` passed to the cadd-style overload; 1 (the default)
--                  omits it so argcheck's default value is exercised
--   arg[4]  '1' to pass debug=true to argcheck
--   arg[5]  '1' to call `add` with named (table) arguments
local SZ  = tonumber(arg[1])
local N   = tonumber(arg[2])
local scale = tonumber(arg[3]) or 1
local dbg = arg[4] == '1'
local named = arg[5] == '1'

-- Fail fast: a missing SZ or N would otherwise surface much later as an
-- opaque FFI conversion error (nil -> long) or a non-numeric 'for' limit.
if not SZ or not N then
   error(string.format('usage: %s SZ N [scale] [debug 0|1] [named 0|1]',
                       tostring(arg[0])))
end

if named then
   print('warning: using named arguments!')
end

--- Type predicate override installed into argcheck.env.
-- This benchmark only ever creates THDoubleTensor objects through the FFI,
-- so any FFI object is treated as a 'torch.DoubleTensor' and as nothing else.
-- LuaJIT's type() reports FFI objects as 'cdata' (not 'userdata'); both are
-- accepted so that userdata-based tensors are still recognized.
-- @param obj       value being checked
-- @param typename  type name declared in the argcheck rule
-- @return true when obj is considered an instance of typename
function env.istype(obj, typename)
   local kind = type(obj)
   if kind == 'cdata' or kind == 'userdata' then
      return typename == 'torch.DoubleTensor'
   end
   return kind == typename
end

-- C declarations for the subset of the TH library called below through the
-- FFI. The structs are declared as incomplete typedefs: the Lua code only
-- ever passes pointers to them around and never touches their fields.
-- NOTE(review): these prototypes must match the installed libTH ABI — confirm
-- against the TH headers when upgrading the library.
ffi.cdef[[

typedef struct THLongStorage THLongStorage;
THLongStorage* THLongStorage_newWithSize2(long, long);
void THLongStorage_free(THLongStorage *storage);

typedef struct THGenerator THGenerator;
THGenerator* THGenerator_new();
void THRandom_manualSeed(THGenerator *_generator, unsigned long the_seed_);

typedef struct THDoubleTensor THDoubleTensor;
THDoubleTensor *THDoubleTensor_new(void);
void THDoubleTensor_free(THDoubleTensor *self);
void THDoubleTensor_rand(THDoubleTensor *r_, THGenerator *_generator, THLongStorage *size);
void THDoubleTensor_add(THDoubleTensor *r_, THDoubleTensor *t, double value);
void THDoubleTensor_cadd(THDoubleTensor *r_, THDoubleTensor *t, double value, THDoubleTensor *src);
double THDoubleTensor_normall(THDoubleTensor *t, double value);

]]

-- Load the TH shared library; its file name depends on the platform.
local libname = ffi.os == 'OSX' and 'libTH.dylib' or 'libTH.so'
local status, C = pcall(ffi.load, libname)
if not status then
   -- On failure pcall returns the loader's message in C; keep it so the
   -- actual cause (missing file, bad architecture, ...) is visible.
   error('please specify path to libTH in your (DY)LD_LIBRARY_PATH ('
         .. tostring(C) .. ')')
end

-- Lua-side methods for THDoubleTensor cdata objects. This table becomes the
-- __index of the ctype's metatable below; the name DoubleTensor is then
-- rebound to the ctype itself, so DoubleTensor() allocates a new tensor.
local DoubleTensor = {}

--- Constructor installed as the ctype's __new metamethod.
-- Allocates an empty TH tensor and registers THDoubleTensor_free as its
-- finalizer, so the C object is released when the cdata is collected.
-- Any arguments given to DoubleTensor(...) are ignored.
-- @return THDoubleTensor* cdata
local function DoubleTensor_new()
   return ffi.gc(C.THDoubleTensor_new(), C.THDoubleTensor_free)
end

--- Norm of the tensor via THDoubleTensor_normall
-- (presumably the l-norm over all elements — confirm against TH).
-- @param l  norm order, defaults to 2
-- @return the norm as a plain Lua number
function DoubleTensor:norm(l)
   l = l or 2
   return tonumber(C.THDoubleTensor_normall(self, l))
end

-- Kept local: these were previously leaked into the global environment.
local DoubleTensor_mt = {__index = DoubleTensor, __new = DoubleTensor_new}
DoubleTensor = ffi.metatype('THDoubleTensor', DoubleTensor_mt)

-- One generator with a fixed seed is shared by every rand() call, so the
-- benchmark inputs are the same from one run to the next.
local _gen = C.THGenerator_new()
C.THRandom_manualSeed(_gen, 1111)

--- Create an a-by-b double tensor filled by THDoubleTensor_rand.
-- The temporary size storage is freed before returning.
-- @param a  number of rows
-- @param b  number of columns
-- @return a new DoubleTensor cdata
local function rand(a, b)
   local tensor = DoubleTensor()
   local dims = C.THLongStorage_newWithSize2(a, b)
   C.THDoubleTensor_rand(tensor, _gen, dims)
   C.THLongStorage_free(dims)
   return tensor
end

-- `add` accumulates argcheck overloads for every tensor type listed below.
-- `dotgraph` keeps the second value returned by the last argcheck call; it is
-- written to disk afterwards when non-nil.
local add
local dotgraph

for _, RealTensor in ipairs{'torch.ByteTensor', 'torch.ShortTensor', 'torch.FloatTensor',
                            'torch.LongTensor', 'torch.IntTensor', 'torch.CharTensor',
                            'torch.DoubleTensor'} do

   -- add([res,] src, value) -> THDoubleTensor_add; allocates res if omitted.
   -- `overload = add` chains onto the overloads registered by earlier
   -- iterations. Without it each iteration threw away the previous types
   -- and only the last type's signatures survived the loop.
   add = argcheck{
      overload = add,
      {name="res", type=RealTensor, opt=true},
      {name="src", type=RealTensor},
      {name="value", type="number"},
      call =
         function(res, src, value)
            res = res or DoubleTensor()
            C.THDoubleTensor_add(res, src, value)
            return res
         end
   }

   -- add([res,] src1, [value,] src2) -> THDoubleTensor_cadd; value defaults to 1.
   add, dotgraph = argcheck{
      debug=dbg,
      overload = add,
      {name="res", type=RealTensor, opt=true},
      {name="src1", type=RealTensor},
      {name="value", type="number", default=1},
      {name="src2", type=RealTensor},
      call =
         function(res, src1, value, src2)
            -- Use the FFI constructor like the overload above: the global
            -- `torch` is never loaded by this script.
            res = res or DoubleTensor()
            C.THDoubleTensor_cadd(res, src1, value, src2)
            return res
         end
   }

end

-- Save the graph text returned by argcheck, if any, to argtree.dot
-- (presumably Graphviz DOT, given the extension — confirm with argcheck docs).
-- assert() turns an open failure into a clear error instead of indexing nil.
if dotgraph then
   local f = assert(io.open('argtree.dot', 'w'))
   f:write(dotgraph)
   f:close()
end

-- Two SZ x SZ input tensors drawn from the seeded generator.
local x = rand(SZ, SZ)
local y = rand(SZ, SZ)

-- Initial norms, reported so results can be compared across runs.
print('x', x:norm())
print('y', y:norm())   -- y's own norm (previously printed x's by mistake)
print('running')

-- Timed section: N iterations, each making one call to each `add` overload,
-- with arguments passed either positionally or as one named table (arg[5]).
-- When scale == 1 the cadd-style call leaves `value` out, so argcheck's
-- default is exercised; any other scale is passed explicitly.
if not named then
   local started = os.clock()
   if scale ~= 1 then
      for _ = 1, N do
         add(y, x, 5)
         add(y, x, scale, y)
      end
   else
      for _ = 1, N do
         add(y, x, 5)
         add(y, x, y)
      end
   end
   local elapsed = os.clock() - started
   print('time (s)', elapsed)
else
   local started = os.clock()
   if scale ~= 1 then
      for _ = 1, N do
         add{res=y, src=x, value=5}
         add{res=y, src1=x, value=scale, src2=y}
      end
   else
      for _ = 1, N do
         add{res=y, src=x, value=5}
         add{res=y, src1=x, src2=y}
      end
   end
   local elapsed = os.clock() - started
   print('time (s)', elapsed)
end

-- Final norms after the timed loop, as a quick sanity check of the results.
for _, entry in ipairs{{'x', x}, {'y', y}} do
   print(entry[1], entry[2]:norm())
end