OILS / mycpp / pass_state_test.py View on Github | oilshell.org

541 lines, 299 significant
1#!/usr/bin/env python3
2"""
3pass_state_test.py: Tests for pass_state.py
4"""
5from __future__ import print_function
6
7import unittest
8
9import pass_state # module under test
10
11
12class VirtualTest(unittest.TestCase):
13
14 def testVirtual(self):
15 """
16 Example:
17
18 class Base(object):
19 def method(self): # we don't know if this is virtual yet
20 pass
21 def x(self):
22 pass
23
24 class Derived(Base):
25 def method(self): # now it's virtual!
26 pass
27 def y(self):
28 pass
29 """
30 v = pass_state.Virtual()
31 v.OnMethod(('Base',), 'method')
32 v.OnMethod(('Base',), 'x')
33 v.OnSubclass(('Base',), ('Derived',))
34 v.OnMethod(('Derived',), 'method')
35 v.OnMethod(('Derived',), 'y')
36
37 v.Calculate()
38
39 print(v.virtuals)
40 self.assertEqual({(('Base',), 'method'): None,
41 (('Derived',), 'method'): (('Base',), 'method')},
42 v.virtuals)
43
44 self.assertEqual(True, v.IsVirtual(('Base',), 'method'))
45 self.assertEqual(True, v.IsVirtual(('Derived',), 'method'))
46 self.assertEqual(False, v.IsVirtual(('Derived',), 'y'))
47
48 self.assertEqual(False, v.IsVirtual(('Klass',), 'z'))
49
50 self.assertEqual(True, v.HasVTable(('Base',)))
51 self.assertEqual(True, v.HasVTable(('Derived',)))
52
53 self.assertEqual(False, v.HasVTable(('Klass',)))
54
55 def testNoInit(self):
56 v = pass_state.Virtual()
57 v.OnMethod(('Base',), '__init__')
58 v.OnSubclass(('Base',), ('Derived',))
59 v.OnMethod(('Derived',), '__init__')
60 v.Calculate()
61 self.assertEqual(False, v.HasVTable(('Base',)))
62 self.assertEqual(False, v.HasVTable(('Derived',)))
63
64 def testCanReorderFields(self):
65 """
66 class Base(object):
67 def __init__(self):
68 self.s = '' # pointer
69 self.i = 42
70
71 class Derived(Base):
72 def __init__(self):
73 Base.__init__()
74 self.mylist = [] # type: List[str]
75
76 Note: we can't reorder these, even though there are no virtual methods.
77 """
78 v = pass_state.Virtual()
79 v.OnSubclass(('Base2',), ('Derived2',))
80 v.Calculate()
81
82 self.assertEqual(False, v.CanReorderFields(('Base2',)))
83 self.assertEqual(False, v.CanReorderFields(('Derived2',)))
84
85 self.assertEqual(True, v.CanReorderFields(('Klass2',)))
86
87 def testBaseCollision(self):
88 v = pass_state.Virtual()
89 v.OnSubclass(('moduleA', 'Base',), ('foo', 'Derived',))
90 with self.assertRaises(AssertionError):
91 v.OnSubclass(('moduleB', 'Base',), ('bar', 'Derived',))
92
93 def testSubclassMapping(self):
94 v = pass_state.Virtual()
95 v.OnMethod(('moduleA', 'Base',), 'frobnicate')
96 v.OnSubclass(('moduleA', 'Base',), ('foo', 'Derived',))
97 v.OnMethod(('foo', 'Derived',), 'frobnicate')
98 v.OnSubclass(('moduleA', 'Base',), ('bar', 'Derived',))
99 v.OnMethod(('bar', 'Derived',), 'frobnicate')
100 v.Calculate()
101 self.assertEqual(
102 (('moduleA', 'Base'), 'frobnicate'),
103 v.virtuals[(('foo', 'Derived'), 'frobnicate')])
104 self.assertEqual(
105 (('moduleA', 'Base'), 'frobnicate'),
106 v.virtuals[(('bar', 'Derived'), 'frobnicate')])
107 self.assertEqual(
108 None,
109 v.virtuals[(('moduleA', 'Base'), 'frobnicate')])
110
111
class DummyFact(pass_state.Fact):
    """Minimal Fact used by the CFG tests.

    It carries an integer tag and renders it after the function name and
    statement number as a comma-separated string.
    """

    def __init__(self, n: int) -> None:
        self.n = n  # tag echoed as the last field by Generate()

    def name(self) -> str:
        return 'dummy'

    def Generate(self, func: str, statement: int) -> str:
        """Return '<func>,<statement>,<n>', e.g. 'foo,1,1'."""
        fields = (func, statement, self.n)
        return ','.join(str(field) for field in fields)
120
121
122class ControlFlowGraphTest(unittest.TestCase):
123
124 def testLinear(self):
125 cfg = pass_state.ControlFlowGraph()
126
127 a = cfg.AddStatement()
128 b = cfg.AddStatement()
129 c = cfg.AddStatement()
130 d = cfg.AddStatement()
131
132 cfg.AddFact(b, DummyFact(1))
133 cfg.AddFact(d, DummyFact(99))
134 cfg.AddFact(d, DummyFact(7))
135
136 expected_edges = {
137 (0, a), (a, b), (b, c), (c, d),
138 }
139 self.assertEqual(expected_edges, cfg.edges)
140
141 self.assertEqual(1, len(cfg.facts[b]))
142 self.assertEqual('foo,1,1', cfg.facts[b][0].Generate('foo', 1))
143 self.assertEqual('dummy', cfg.facts[b][0].name())
144 self.assertEqual(2, len(cfg.facts[d]))
145 self.assertEqual('bar,1,99', cfg.facts[d][0].Generate('bar', 1))
146 self.assertEqual('bar,2,7', cfg.facts[d][1].Generate('bar', 2))
147
148 def testBranches(self):
149 cfg = pass_state.ControlFlowGraph()
150
151 main0 = cfg.AddStatement()
152
153 # branch condition facts all get attached to this statement
154 branch_point = cfg.AddStatement()
155
156 # first statement in if block
157 with pass_state.CfgBranchContext(cfg, branch_point) as branch_ctx:
158 with branch_ctx.AddBranch() as arm0:
159 arm0_a = cfg.AddStatement() # block statement 2
160 arm0_b = cfg.AddStatement() # block statement 2
161 arm0_c = cfg.AddStatement() # block statement 3
162
163 # frist statement in elif block
164 with branch_ctx.AddBranch() as arm1:
165 arm1_a = cfg.AddStatement()
166 arm1_b = cfg.AddStatement() # block statement 2
167
168 # frist statement in else block
169 with branch_ctx.AddBranch() as arm2:
170 arm2_a = cfg.AddStatement()
171 arm2_b = cfg.AddStatement() # block statement 2
172
173 self.assertEqual(arm0_c, arm0.exit)
174 self.assertEqual(arm1_b, arm1.exit)
175 self.assertEqual(arm2_b, arm2.exit)
176
177 join = cfg.AddStatement()
178 end = cfg.AddStatement()
179
180 """
181 We expecte a graph like this.
182
183 begin
184 |
185 main0
186 |
187 v
188 branch_point
189 / | \
190 arm0_a arm1_a arm2_a
191 | | |
192 arm0_b arm1_b arm2_b
193 | | |
194 arm0_c | |
195 | | /
196 \ | /
197 \ | /
198 \ | /
199 \ | /
200 join
201 |
202 end
203 """
204 expected_edges = {
205 (0, main0),
206 (main0, branch_point),
207 (branch_point, arm0_a), (branch_point, arm1_a), (branch_point, arm2_a),
208 (arm0_a, arm0_b), (arm0_b, arm0_c),
209 (arm1_a, arm1_b),
210 (arm2_a, arm2_b),
211 (arm0_c, join), (arm1_b, join), (arm2_b, join),
212 (join, end),
213 }
214 self.assertEqual(expected_edges, cfg.edges)
215
216 def testDeadends(self):
217 """
218 Make sure we don't create orphans in the presence of continue, return,
219 raise, etc...
220 """
221
222 cfg = pass_state.ControlFlowGraph()
223 with pass_state.CfgBranchContext(cfg, cfg.AddStatement()) as branch_ctx:
224 with branch_ctx.AddBranch() as branchA: # if
225 ret = cfg.AddStatement() # return
226 cfg.AddDeadend(ret)
227
228 """
229 while ...:
230 if ...:
231 continue
232 else:
233 print(...)
234
235 print(...)
236 """
237 with pass_state.CfgLoopContext(cfg) as loop:
238 branch_point = cfg.AddStatement()
239 with pass_state.CfgBranchContext(cfg, branch_point) as branch_ctx:
240 with branch_ctx.AddBranch() as branchB: # if
241 cont = cfg.AddStatement() # continue
242 loop.AddContinue(cont)
243
244 with branch_ctx.AddBranch() as branchC: # else
245 innerC = cfg.AddStatement()
246
247 end = cfg.AddStatement()
248 expected_edges = {
249 (0, branchA.entry),
250 (branchA.entry, ret),
251 (branchA.entry, loop.entry),
252 (loop.entry, branchB.entry),
253 (branch_point, cont),
254 (cont, loop.entry),
255 (branch_point, innerC),
256 (innerC, end),
257 (innerC, loop.entry),
258 }
259 self.assertEqual(expected_edges, cfg.edges)
260
261 def testNedstedIf(self):
262 """
263 The mypy AST represents else-if as nested if-statements inside the else arm.
264 """
265 cfg = pass_state.ControlFlowGraph()
266
267 outer_branch_point = cfg.AddStatement()
268 with pass_state.CfgBranchContext(cfg, outer_branch_point) as branch_ctx:
269 with branch_ctx.AddBranch() as branch0: # if
270 branch0_a = cfg.AddStatement()
271
272 with branch_ctx.AddBranch() as branch1: # else
273 with branch1.AddBranch(cfg.AddStatement()) as branch2: # if
274 branch2_a = cfg.AddStatement()
275
276 branch1_a = cfg.AddStatement()
277
278 end = cfg.AddStatement()
279
280 """
281 We expect a graph like this.
282
283 begin
284 |
285 outer_branch_point +------
286 | | \
287 branch0_a | branch2.entry
288 | | |
289 | | branch2_a
290 | | |
291 | | /
292 | | /
293 | | /
294 | branch1_a
295 | /
296 | /
297 | /
298 | /
299 end _____/
300 """
301 expected_edges = {
302 (0, outer_branch_point),
303 (outer_branch_point, branch0_a),
304 (outer_branch_point, branch2.entry),
305 (branch2.entry, branch2_a),
306 (branch2_a, branch1_a),
307 (branch0.exit, end),
308 (branch1.exit, end),
309 (branch2.exit, end),
310 }
311 self.assertEqual(expected_edges, cfg.edges)
312
313
314 def testLoops(self):
315 cfg = pass_state.ControlFlowGraph()
316
317 with pass_state.CfgLoopContext(cfg) as loopA:
318 branch_point = cfg.AddStatement()
319 with pass_state.CfgBranchContext(cfg, branch_point) as branch_ctx:
320 with branch_ctx.AddBranch() as arm0:
321 arm0_a = cfg.AddStatement()
322 arm0_b = cfg.AddStatement()
323
324 with branch_ctx.AddBranch() as arm1:
325 arm1_a = cfg.AddStatement()
326 arm1_b = cfg.AddStatement()
327
328 self.assertEqual(arm0_b, arm0.exit)
329 self.assertEqual(arm1_b, arm1.exit)
330
331 with pass_state.CfgLoopContext(cfg) as loopB:
332 innerB = cfg.AddStatement()
333
334 self.assertEqual(innerB, loopB.exit)
335
336 end = cfg.AddStatement()
337
338 """
339 We expecte a graph like this:.
340
341 begin
342 |
343 loopA <------+
344 | |
345 v |
346 branch_point |
347 / \ |
348 arm0_a arm2_a |
349 | | |
350 arm0_b arm2_b |
351 \ / |
352 \ / |
353 loopB <-+ |
354 | | |
355 innerB -+---+
356 |
357 end
358 """
359 expected_edges = {
360 (0, loopA.entry),
361 (loopA.entry, branch_point),
362 (branch_point, arm0_a), (branch_point, arm1_a),
363 (arm0_a, arm0_b),
364 (arm1_a, arm1_b),
365 (arm0_b, loopB.entry), (arm1_b, loopB.entry),
366 (loopB.entry, innerB),
367 (innerB, loopA.entry), (innerB, loopB.entry),
368 (innerB, end),
369 }
370 self.assertEqual(expected_edges, cfg.edges)
371
372
373 def testLoops2(self):
374 cfg = pass_state.ControlFlowGraph()
375
376 with pass_state.CfgLoopContext(cfg) as loopA:
377 with pass_state.CfgLoopContext(cfg) as loopB:
378 innerB = cfg.AddStatement()
379
380 innerA = cfg.AddStatement()
381
382 end = cfg.AddStatement()
383
384 expected_edges = {
385 (0, loopA.entry),
386 (loopA.entry, loopB.entry),
387 (loopB.entry, innerB),
388 (innerB, innerA),
389 (innerB, loopB.entry),
390 (innerA, loopA.entry),
391 (innerA, end),
392 }
393 self.assertEqual(expected_edges, cfg.edges)
394
395 def testDeepTry(self):
396 """
397 A code snippet like the following.
398
399 1 while i < n:
400 2 for prog in cases:
401 3 try:
402 4 result = f(prog)
403 except ParseError as e:
404 5 num_exceptions += 1
405 6 continue
406 7 i += 1
407
408 8 mylib.MaybeCollect() # manual GC point
409
410 9 log('num_exceptions = %d', num_exceptions)
411 """
412 cfg = pass_state.ControlFlowGraph()
413
414 with pass_state.CfgLoopContext(cfg) as loopA:
415 with pass_state.CfgLoopContext(cfg) as loopB:
416 with pass_state.CfgBlockContext(cfg) as try_block:
417 try_s0 = cfg.AddStatement()
418
419 with pass_state.CfgBlockContext(cfg, try_block.exit) as except_block:
420 except_s0 = cfg.AddStatement()
421 cont = cfg.AddStatement()
422 loopB.AddContinue(cont)
423
424 a_s0 = cfg.AddStatement()
425 a_s1 = cfg.AddStatement()
426
427 log_stmt = cfg.AddStatement()
428 end = cfg.AddStatement()
429
430 expected_edges = {
431 (0, loopA.entry),
432 (loopA.entry, loopB.entry),
433 (loopB.entry, try_block.entry),
434 (try_block.entry, try_s0),
435 (try_s0, except_s0),
436 (try_s0, loopB.entry),
437 (except_s0, cont),
438 (cont, loopB.entry),
439 (try_block.exit, a_s0),
440 (a_s0, a_s1),
441 (a_s1, loopA.entry),
442 (a_s1, log_stmt),
443 (log_stmt, end),
444 }
445 self.assertEqual(expected_edges, cfg.edges)
446
447 def testLoopWithDanglingBlocks(self):
448 """
449 for i in xrange(1000000):
450 try:
451 with ctx_DirStack(d, 'foo') as _:
452 if i % 10000 == 0:
453 raise MyError()
454 pass
455 except MyError:
456 log('exception')
457 """
458 cfg = pass_state.ControlFlowGraph()
459
460 with pass_state.CfgLoopContext(cfg) as loop:
461 with pass_state.CfgBranchContext(cfg, cfg.AddStatement()) as try_ctx:
462 with try_ctx.AddBranch() as try_block:
463 with pass_state.CfgBlockContext(cfg, cfg.AddStatement()) as with_block:
464 with pass_state.CfgBranchContext(cfg, cfg.AddStatement()) as if_ctx:
465 with if_ctx.AddBranch() as if_block:
466 s_raise = cfg.AddStatement()
467 cfg.AddDeadend(s_raise)
468
469 pass_stmt = cfg.AddStatement()
470
471 with try_ctx.AddBranch(try_block.exit) as except_block:
472 log_stmt = cfg.AddStatement()
473
474 expected_edges = {
475 (0, loop.entry),
476 (loop.entry, try_block.entry),
477 (try_block.entry, with_block.entry),
478 (with_block.entry, if_block.entry),
479 (if_block.entry, s_raise),
480 (if_block.entry, pass_stmt),
481 (pass_stmt, loop.entry),
482 (pass_stmt, log_stmt),
483 (log_stmt, loop.entry),
484 }
485 self.assertEqual(expected_edges, cfg.edges)
486
487 def testLoopBreak(self):
488 """
489 while ...:
490 if ...:
491 break
492
493 else:
494 try:
495 pass
496
497 except ...:
498 break
499
500 pass
501
502 pass
503 """
504 cfg = pass_state.ControlFlowGraph()
505
506 with pass_state.CfgLoopContext(cfg) as loop:
507 with pass_state.CfgBranchContext(cfg, cfg.AddStatement()) as if_ctx:
508 with if_ctx.AddBranch() as if_block:
509 break1 = cfg.AddStatement()
510 loop.AddBreak(break1)
511
512 with if_ctx.AddBranch() as else_block:
513 with pass_state.CfgBranchContext(cfg, cfg.AddStatement()) as try_ctx:
514 with try_ctx.AddBranch() as try_block:
515 pass1 = cfg.AddStatement()
516
517 with try_ctx.AddBranch(try_block.exit) as except_block:
518 break2 = cfg.AddStatement()
519 loop.AddBreak(break2)
520
521 pass2 = cfg.AddStatement()
522
523 pass3 = cfg.AddStatement()
524
525 expected_edges = {
526 (0, loop.entry),
527 (loop.entry, if_block.entry),
528 (if_block.entry, break1),
529 (if_block.entry, try_block.entry),
530 (try_block.entry, pass1),
531 (pass1, break2),
532 (pass1, pass2),
533 (pass2, loop.entry),
534 (pass2, pass3),
535 (break1, pass3),
536 (break2, pass3),
537 }
538 self.assertEqual(expected_edges, cfg.edges)
539
if __name__ == '__main__':
    # Run the TestCase classes in this module when executed as a script.
    unittest.main()