OILS / mycpp / control_flow_pass.py View on Github | oilshell.org

423 lines, 292 significant
1"""
2control_flow_pass.py - AST pass that builds a control flow graph.
3"""
4import collections
5from typing import overload, Union, Optional, Dict
6
7import mypy
8from mypy.visitor import ExpressionVisitor, StatementVisitor
9from mypy.nodes import (Block, Expression, Statement, ExpressionStmt, StrExpr,
10 CallExpr, FuncDef, IfStmt, NameExpr, MemberExpr)
11
12from mypy.types import CallableType, Instance, Type, UnionType
13
14from mycpp.crash import catch_errors
15from mycpp.util import join_name, split_py_name
16from mycpp import util
17from mycpp import pass_state
18
# Result type parameter handed to ExpressionVisitor.  None of the
# expression visitors in this pass produce a value, so it is just None.
T = None  # TODO: Make it type check?


class UnsupportedException(Exception):
    """Signals a construct this pass can't handle.

    Caught in Build.accept().
    """
24
25
class Build(ExpressionVisitor[T], StatementVisitor[None]):
    """AST visitor that builds one control flow graph per function.

    Inside a function body, every statement except a Block gets a fresh
    statement ID in that function's ControlFlowGraph (see accept()), and
    every call whose callee can be resolved is recorded as a FunctionCall
    fact on the current statement and in self.callees.  Loops, branches,
    `with` blocks, returns, raises, break/continue and try/except shape the
    graph through the pass_state context helpers.
    """

    def __init__(self, types: Dict[Expression, Type], virtual, local_vars,
                 imported_names):
        """
        Args:
          types: mypy's inferred type for each expression node.
          virtual: virtual-method table; must provide IsVirtual() and a
            `virtuals` mapping (see visit_func_def).
          local_vars: maps a FuncDef to a sequence of (name, type) pairs for
            its locals; a type may be a str, Instance or UnionType (see
            resolve_callee).
          imported_names: names that refer to imported modules; used to
            recognise `module.func()` calls.
        """
        self.types = types
        # Function name path (from split_py_name) -> ControlFlowGraph,
        # created on first access.
        self.cfgs = collections.defaultdict(pass_state.ControlFlowGraph)
        # Statement that facts are currently attached to, in current_cfg().
        self.current_statement_id = None
        # Name path of the class being visited, or None outside a class.
        self.current_class_name = None
        # FuncDef being visited, or None outside a function.
        self.current_func_node = None
        # Enclosing loop contexts, innermost last; break/continue use [-1].
        self.loop_stack = []
        self.virtual = virtual
        self.local_vars = local_vars
        self.imported_names = imported_names
        self.callees = {}  # CallExpr node -> SymbolPath of the callee
        # Path of the module being visited; set by visit_mypy_file().
        self.module_path = None

    def current_cfg(self):
        """Return the CFG of the function being visited, or None outside one."""
        if not self.current_func_node:
            return None

        return self.cfgs[split_py_name(self.current_func_node.fullname)]

    @staticmethod
    def _is_docstring_like(stmt) -> bool:
        """True for a bare string-literal statement, e.g. a docstring."""
        return (isinstance(stmt, ExpressionStmt) and
                isinstance(stmt.expr, StrExpr))

    @staticmethod
    def _union_method(union: UnionType, method: str) -> util.SymbolPath:
        """Name `method` on the first member of a two-member union.

        NOTE(review): assumes an Optional-like union whose items[0] is a
        class Instance (it must have `.type.fullname`) -- confirm.
        """
        assert len(union.items) == 2
        return split_py_name(union.items[0].type.fullname) + (method, )

    def resolve_callee(self, o: CallExpr) -> Optional[util.SymbolPath]:
        """
        Returns the fully qualified name of the callee in the given call
        expression, or None for calls this pass does not track.

        Member functions are prefixed by the names of the classes that contain
        them. For example, the name of the callee in the last statement of the
        snippet below is `module.SomeObject.Foo`.

            x = module.SomeObject()
            x.Foo()

        Free-functions defined in the local module are referred to by their
        normal fully qualified names. The function `foo` in a module called
        `moduleA` is named `moduleA.foo`. Calls to free-functions defined
        in imported modules are named the same way.

        Raises AssertionError for callees that are neither a NameExpr nor a
        MemberExpr.
        """
        callee = o.callee

        if isinstance(callee, NameExpr):
            return split_py_name(callee.fullname)

        if not isinstance(callee, MemberExpr):
            # Don't currently get here
            raise AssertionError()

        method = callee.name
        receiver = callee.expr

        if isinstance(receiver, NameExpr):
            if receiver.name in self.imported_names:
                # module.func()
                return split_py_name(receiver.fullname) + (method, )

            if receiver.name == 'self':
                assert self.current_class_name
                return self.current_class_name + (method, )

            # Method call on a local variable: use its recorded type.
            local_type = None
            for name, t in self.local_vars.get(self.current_func_node, []):
                if name == receiver.name:
                    local_type = t
                    break

            if not local_type:
                # context or exception handler. probably safe to ignore.
                return None

            if isinstance(local_type, str):
                return split_py_name(local_type) + (method, )

            if isinstance(local_type, Instance):
                return split_py_name(local_type.type.fullname) + (method, )

            if isinstance(local_type, UnionType):
                # Name the called method, not the receiver variable.
                return self._union_method(local_type, method)

            assert not isinstance(local_type, CallableType)
            # primitive type or string. don't care.
            return None

        # Any other receiver expression: use mypy's inferred type for it.
        t = self.types.get(receiver)
        if isinstance(t, Instance):
            return split_py_name(t.type.fullname) + (method, )

        if isinstance(t, UnionType):
            return self._union_method(t, method)

        if receiver and getattr(receiver, 'fullname', None):
            return split_py_name(receiver.fullname) + (method, )

        # constructors of things that we don't care about.
        return None

    #
    # COPIED from IRBuilder
    #

    @overload
    def accept(self, node: Expression) -> T:
        ...

    @overload
    def accept(self, node: Statement) -> None:
        ...

    def accept(self, node: Union[Statement, Expression]) -> Optional[T]:
        """Visit `node` inside catch_errors() for its module path and line.

        Inside a function, each statement other than a Block first gets a
        new statement ID in the current CFG, so facts recorded while
        visiting it attach to that statement.  UnsupportedException is
        swallowed so the walk continues with the next node.
        """
        with catch_errors(self.module_path, node.line):
            if isinstance(node, Expression):
                try:
                    return node.accept(self)
                except UnsupportedException:
                    # Keep going so later errors are still reported.  The
                    # IRBuilder original allocated a typed temp here; this
                    # pass has no alloc_temp()/node_type(), and its
                    # expression visitors produce no value.
                    return None

            try:
                cfg = self.current_cfg()
                # Most statements have empty visitors because they don't
                # require any special logic. Create statements for them
                # here. Don't create statements from blocks to avoid
                # stuttering.
                if cfg and not isinstance(node, Block):
                    self.current_statement_id = cfg.AddStatement()

                node.accept(self)
            except UnsupportedException:
                pass
            return None

    # Not in superclasses:

    def visit_mypy_file(self, o: 'mypy.nodes.MypyFile') -> T:
        if util.ShouldSkipPyFile(o):
            return

        self.module_path = o.path

        for node in o.defs:
            # skip module docstring (and any other bare string statement)
            if self._is_docstring_like(node):
                continue
            self.accept(node)

    # LITERALS

    def _visit_loop(self, head: Expression, body: Block) -> None:
        """Visit a loop's head expression and body in a CfgLoopContext.

        The context is entered at the current statement and is on
        loop_stack while `body` is visited, so break/continue find it.
        """
        cfg = self.current_cfg()
        with pass_state.CfgLoopContext(
                cfg, entry=self.current_statement_id) as loop:
            self.accept(head)
            self.loop_stack.append(loop)
            try:
                self.accept(body)
            finally:
                self.loop_stack.pop()

    def visit_for_stmt(self, o: 'mypy.nodes.ForStmt') -> T:
        # NOTE(review): o.else_body (for/else) is not visited.
        self._visit_loop(o.expr, o.body)

    def _handle_switch(self, expr, o, cfg):
        """Visit the body of a `with switch/str_switch/tagswitch(...)` block.

        The with-body must be a single IfStmt.  util._collect_cases fills
        `cases` with (case expression, body) pairs and returns the default
        block, if any; each case body and the default become one branch.
        `expr` is the switch call itself, already visited by the caller.
        """
        assert len(o.body.body) == 1, o.body.body
        if_node = o.body.body[0]
        assert isinstance(if_node, IfStmt), if_node
        cases = []
        default_block = util._collect_cases(self.module_path, if_node, cases)
        with pass_state.CfgBranchContext(
                cfg, self.current_statement_id) as branch_ctx:
            for case_expr, case_body in cases:
                assert case_expr is not None, case_expr
                self.accept(case_expr)
                with branch_ctx.AddBranch():
                    self.accept(case_body)

            if default_block:
                with branch_ctx.AddBranch():
                    self.accept(default_block)

    def visit_with_stmt(self, o: 'mypy.nodes.WithStmt') -> T:
        cfg = self.current_cfg()
        assert len(o.expr) == 1, o.expr
        expr = o.expr[0]
        assert isinstance(expr, CallExpr), expr
        self.accept(expr)

        if expr.callee.name in ('switch', 'str_switch', 'tagswitch'):
            self._handle_switch(expr, o, cfg)
        else:
            with pass_state.CfgBlockContext(cfg, self.current_statement_id):
                self.accept(o.body)

    def visit_func_def(self, o: 'mypy.nodes.FuncDef') -> T:
        if o.name == '__repr__':  # Don't translate
            return

        # For virtual methods, pretend that the method on the base class calls
        # the same method on every subclass. This way call sites using the
        # abstract base class will over-approximate the set of call paths they
        # can take when checking if they can reach MaybeCollect().
        if self.current_class_name and self.virtual.IsVirtual(
                self.current_class_name, o.name):
            key = (self.current_class_name, o.name)
            base = self.virtual.virtuals[key]
            if base:
                sub = join_name(self.current_class_name + (o.name, ),
                                delim='.')
                base_key = base[0] + (base[1], )
                cfg = self.cfgs[base_key]
                cfg.AddFact(0, pass_state.FunctionCall(sub))

        # Save and restore the enclosing context so that a nested def does
        # not detach the rest of its enclosing function from that
        # function's CFG.
        outer_func = self.current_func_node
        outer_statement = self.current_statement_id
        self.current_func_node = o
        try:
            self.accept(o.body)
        finally:
            self.current_func_node = outer_func
            self.current_statement_id = outer_statement

    def visit_class_def(self, o: 'mypy.nodes.ClassDef') -> T:
        # Restore (rather than clear) the class name afterwards so a nested
        # class doesn't break `self.X()` resolution in its enclosing class.
        outer_class = self.current_class_name
        self.current_class_name = split_py_name(o.fullname)
        try:
            for stmt in o.defs.body:
                # Ignore things that look like docstrings
                if self._is_docstring_like(stmt):
                    continue

                if isinstance(stmt, FuncDef) and stmt.name == '__repr__':
                    continue

                self.accept(stmt)
        finally:
            self.current_class_name = outer_class

    # Statements

    def visit_assignment_stmt(self, o: 'mypy.nodes.AssignmentStmt') -> T:
        for lval in o.lvalues:
            self.accept(lval)

        self.accept(o.rvalue)

    def visit_block(self, block: 'mypy.nodes.Block') -> T:
        for stmt in block.body:
            # Ignore things that look like docstrings
            if self._is_docstring_like(stmt):
                continue

            self.accept(stmt)

    def visit_expression_stmt(self, o: 'mypy.nodes.ExpressionStmt') -> T:
        self.accept(o.expr)

    def visit_while_stmt(self, o: 'mypy.nodes.WhileStmt') -> T:
        # NOTE(review): o.else_body (while/else) is not visited.
        self._visit_loop(o.expr, o.body)

    def visit_return_stmt(self, o: 'mypy.nodes.ReturnStmt') -> T:
        cfg = self.current_cfg()
        if cfg:
            cfg.AddDeadend(self.current_statement_id)

        if o.expr:
            self.accept(o.expr)

    def visit_if_stmt(self, o: 'mypy.nodes.IfStmt') -> T:
        cfg = self.current_cfg()
        for cond in o.expr:
            self.accept(cond)

        with pass_state.CfgBranchContext(
                cfg, self.current_statement_id) as branch_ctx:
            # Presumably mypy nests elif chains in else_body, so o.body
            # normally holds one block; give each block its own branch.
            for body in o.body:
                with branch_ctx.AddBranch():
                    self.accept(body)

            if o.else_body:
                with branch_ctx.AddBranch():
                    self.accept(o.else_body)

    def visit_break_stmt(self, o: 'mypy.nodes.BreakStmt') -> T:
        if self.loop_stack:
            self.loop_stack[-1].AddBreak(self.current_statement_id)

    def visit_continue_stmt(self, o: 'mypy.nodes.ContinueStmt') -> T:
        if self.loop_stack:
            self.loop_stack[-1].AddContinue(self.current_statement_id)

    def visit_raise_stmt(self, o: 'mypy.nodes.RaiseStmt') -> T:
        cfg = self.current_cfg()
        if cfg:
            cfg.AddDeadend(self.current_statement_id)

        if o.expr:
            self.accept(o.expr)

    def visit_try_stmt(self, o: 'mypy.nodes.TryStmt') -> T:
        cfg = self.current_cfg()
        with pass_state.CfgBranchContext(cfg,
                                         self.current_statement_id) as try_ctx:
            with try_ctx.AddBranch() as try_block:
                self.accept(o.body)

            # Each except handler is its own branch; AddBranch is given the
            # try body's exit (presumably as the branch's predecessor).
            for handler in o.handlers:
                with try_ctx.AddBranch(try_block.exit):
                    self.accept(handler)

        # NOTE(review): o.else_body and o.finally_body are not visited, so
        # calls made there are not recorded -- confirm whether mycpp input
        # can contain them.

    def visit_del_stmt(self, o: 'mypy.nodes.DelStmt') -> T:
        self.accept(o.expr)

    # Expressions

    def visit_member_expr(self, o: 'mypy.nodes.MemberExpr') -> T:
        self.accept(o.expr)

    def visit_yield_expr(self, o: 'mypy.nodes.YieldExpr') -> T:
        # A bare `yield` has no operand.
        if o.expr is not None:
            self.accept(o.expr)

    def visit_op_expr(self, o: 'mypy.nodes.OpExpr') -> T:
        self.accept(o.left)
        self.accept(o.right)

    def visit_comparison_expr(self, o: 'mypy.nodes.ComparisonExpr') -> T:
        for operand in o.operands:
            self.accept(operand)

    def visit_unary_expr(self, o: 'mypy.nodes.UnaryExpr') -> T:
        self.accept(o.expr)

    def visit_list_expr(self, o: 'mypy.nodes.ListExpr') -> T:
        for item in o.items:
            self.accept(item)

    def visit_dict_expr(self, o: 'mypy.nodes.DictExpr') -> T:
        for k, v in o.items:
            # `{**other}` entries have no key expression.
            if k is not None:
                self.accept(k)
            self.accept(v)

    def visit_tuple_expr(self, o: 'mypy.nodes.TupleExpr') -> T:
        for item in o.items:
            self.accept(item)

    def visit_index_expr(self, o: 'mypy.nodes.IndexExpr') -> T:
        self.accept(o.base)
        # The subscript may contain calls too, e.g. d[f()].
        self.accept(o.index)

    def visit_slice_expr(self, o: 'mypy.nodes.SliceExpr') -> T:
        if o.begin_index:
            self.accept(o.begin_index)

        if o.end_index:
            self.accept(o.end_index)

        if o.stride:
            self.accept(o.stride)

    def visit_conditional_expr(self, o: 'mypy.nodes.ConditionalExpr') -> T:
        self.accept(o.cond)
        self.accept(o.if_expr)
        self.accept(o.else_expr)

    def visit_call_expr(self, o: 'mypy.nodes.CallExpr') -> T:
        # Inside a function, record the resolved callee as a FunctionCall
        # fact on the current statement; then visit callee and arguments.
        cfg = self.current_cfg()
        if self.current_func_node:
            full_callee = self.resolve_callee(o)
            if full_callee:
                self.callees[o] = full_callee
                cfg.AddFact(
                    self.current_statement_id,
                    pass_state.FunctionCall(join_name(full_callee, delim='.')))

        self.accept(o.callee)
        for arg in o.args:
            self.accept(arg)