OILS / mycpp / control_flow_pass.py View on Github | oilshell.org

233 lines, 162 significant
1"""
2control_flow_pass.py - AST pass that builds a control flow graph.
3"""
4from typing import overload, Union, Optional, Dict
5
6import mypy
7from mypy.visitor import ExpressionVisitor, StatementVisitor
8from mypy.nodes import (Block, Expression, Statement, ExpressionStmt, StrExpr,
9 ForStmt, WhileStmt, CallExpr, FuncDef, IfStmt)
10
11from mypy.types import Type
12
13from mycpp.crash import catch_errors
14from mycpp.util import split_py_name
15from mycpp import util
16from mycpp import pass_state
17
# Placeholder used as the type parameter of ExpressionVisitor in Build's base
# classes and as the declared return type of its visit_* methods.
T = None # TODO: Make it type check?
19
20
class UnsupportedException(Exception):
    """Signals that a node could not be processed by the pass.

    Build.accept() catches this exception so that traversal can continue.
    """
23
24
class Build(ExpressionVisitor[T], StatementVisitor[None]):
    """AST visitor that builds a control flow graph for each function.

    A new pass_state.ControlFlowGraph is created for every visited FuncDef
    (except `__repr__`) and stored in `self.cfgs` under
    `split_py_name(func.fullname)`.
    """

    def __init__(self, types: Dict[Expression, Type]):
        self.types = types
        # split_py_name(func.fullname) -> pass_state.ControlFlowGraph
        self.cfgs = {}
        # Value returned by the most recent cfg.AddStatement() call.
        self.current_statement_id = None
        # FuncDef whose body is being traversed; None outside of functions.
        self.current_func_node = None
        # Loop contexts of the enclosing loops, innermost last.
        self.loop_stack = []
        # Assigned by visit_mypy_file(). Initialised here so that accept()
        # and _handle_switch() don't fail with AttributeError when they run
        # before any file has been visited.
        self.module_path = None

    def current_cfg(self):
        """Return the graph of the function being visited, or None."""
        if not self.current_func_node:
            return None

        return self.cfgs.get(split_py_name(self.current_func_node.fullname))

    #
    # COPIED from IRBuilder
    #

    @overload
    def accept(self, node: Expression) -> T:
        ...

    @overload
    def accept(self, node: Statement) -> None:
        ...

    def accept(self, node: Union[Statement, Expression]) -> Optional[T]:
        """Dispatch `node` to the matching visit_* method.

        For statements visited inside a function, a new statement is first
        added to the current graph and its ID is remembered in
        `self.current_statement_id`. Blocks and loops are excluded because
        their visitors handle that themselves.

        Returns the visitor's result for expressions and None for statements.
        """
        with catch_errors(self.module_path, node.line):
            if isinstance(node, Expression):
                try:
                    res = node.accept(self)

                # If we hit an error during compilation, we want to
                # keep trying, so we can produce more error
                # messages. Generate a temp of the right type to keep
                # from causing more downstream trouble.
                except UnsupportedException:
                    # NOTE(review): neither alloc_temp nor node_type is
                    # defined in this class (the method was copied from
                    # IRBuilder), so this path looks like it would raise
                    # AttributeError -- confirm before relying on it.
                    res = self.alloc_temp(self.node_type(node))
                return res
            else:
                try:
                    cfg = self.current_cfg()
                    # Most statements have empty visitors because they don't
                    # require any special logic. Create statements for them
                    # here. Blocks and loops are handled by their visitors.
                    if cfg and not isinstance(node,
                                              (Block, ForStmt, WhileStmt)):
                        self.current_statement_id = cfg.AddStatement()

                    node.accept(self)
                except UnsupportedException:
                    pass
                return None

    # Not in superclasses:

    def visit_mypy_file(self, o: 'mypy.nodes.MypyFile') -> T:
        """Visit every top-level definition of a file, unless it is skipped."""
        if util.ShouldSkipPyFile(o):
            return

        self.module_path = o.path

        for node in o.defs:
            # skip module docstring
            if isinstance(node, ExpressionStmt) and isinstance(
                    node.expr, StrExpr):
                continue
            self.accept(node)

    # Helpers

    def _visit_loop(self, o: Union[ForStmt, WhileStmt]) -> None:
        """Visit a loop body inside a new loop context (shared by for/while).

        Does nothing outside of a function. Only `o.body` is traversed; an
        `else` clause of the loop, if any, is not.
        """
        cfg = self.current_cfg()
        if not cfg:
            return

        with pass_state.CfgLoopContext(cfg) as loop:
            self.loop_stack.append(loop)
            try:
                self.accept(o.body)
            finally:
                # Keep the stack balanced even if the body visitor raises.
                self.loop_stack.pop()

    def _add_deadend(self) -> None:
        """Mark the current statement as a dead end (shared by return/raise)."""
        cfg = self.current_cfg()
        if cfg:
            cfg.AddDeadend(self.current_statement_id)

    def _handle_switch(self, expr, o, cfg):
        """Add one branch per case of a switch-style `with` statement.

        The body of `o` must consist of exactly one IfStmt. Its cases are
        collected with util._collect_cases(), which fills `cases` with
        (expression, body) pairs and returns the default block, if any.

        `expr` is the call expression of the `with` statement; it is not
        used here but kept for interface compatibility.
        """
        assert len(o.body.body) == 1, o.body.body
        if_node = o.body.body[0]
        assert isinstance(if_node, IfStmt), if_node
        cases = []
        default_block = util._collect_cases(self.module_path, if_node, cases)
        with pass_state.CfgBranchContext(
                cfg, self.current_statement_id) as branch_ctx:
            # A distinct loop variable name avoids clobbering the `expr`
            # parameter.
            for case_expr, body in cases:
                assert case_expr is not None, case_expr
                with branch_ctx.AddBranch():
                    self.accept(body)

            if default_block:
                with branch_ctx.AddBranch():
                    self.accept(default_block)

    # Definitions

    def visit_func_def(self, o: 'mypy.nodes.FuncDef') -> T:
        """Create a fresh graph for the function and traverse its body."""
        if o.name == '__repr__':  # Don't translate
            return

        self.cfgs[split_py_name(o.fullname)] = pass_state.ControlFlowGraph()

        # Restore the enclosing function afterwards instead of resetting to
        # None, so that statements following a nested def still end up in the
        # enclosing function's graph.
        enclosing_func = self.current_func_node
        self.current_func_node = o
        try:
            self.accept(o.body)
        finally:
            self.current_func_node = enclosing_func

    def visit_class_def(self, o: 'mypy.nodes.ClassDef') -> T:
        """Visit the class body, skipping docstrings and `__repr__`."""
        for stmt in o.defs.body:
            # Ignore things that look like docstrings
            if (isinstance(stmt, ExpressionStmt) and
                    isinstance(stmt.expr, StrExpr)):
                continue

            if isinstance(stmt, FuncDef) and stmt.name == '__repr__':
                continue

            self.accept(stmt)

    # Statements

    def visit_block(self, block: 'mypy.nodes.Block') -> T:
        """Visit each statement of the block, skipping docstrings."""
        for stmt in block.body:
            # Ignore things that look like docstrings
            if (isinstance(stmt, ExpressionStmt) and
                    isinstance(stmt.expr, StrExpr)):
                continue

            self.accept(stmt)

    def visit_expression_stmt(self, o: 'mypy.nodes.ExpressionStmt') -> T:
        """Visit the wrapped expression."""
        self.accept(o.expr)

    def visit_for_stmt(self, o: 'mypy.nodes.ForStmt') -> T:
        """Visit the body of a `for` loop inside a loop context."""
        self._visit_loop(o)

    def visit_while_stmt(self, o: 'mypy.nodes.WhileStmt') -> T:
        """Visit the body of a `while` loop inside a loop context."""
        self._visit_loop(o)

    def visit_with_stmt(self, o: 'mypy.nodes.WithStmt') -> T:
        """Handle `with` statements.

        `with switch(...)`, `with str_switch(...)` and `with tagswitch(...)`
        are treated as switch statements; any other `with` is visited as a
        plain block. The statement must have exactly one context expression,
        and it must be a call.
        """
        cfg = self.current_cfg()
        if not cfg:
            return

        assert len(o.expr) == 1, o.expr
        expr = o.expr[0]
        assert isinstance(expr, CallExpr), expr

        if expr.callee.name in ('switch', 'str_switch', 'tagswitch'):
            self._handle_switch(expr, o, cfg)
        else:
            with pass_state.CfgBlockContext(cfg, self.current_statement_id):
                for stmt in o.body.body:
                    self.accept(stmt)

    def visit_return_stmt(self, o: 'mypy.nodes.ReturnStmt') -> T:
        """Record the `return` statement as a dead end."""
        self._add_deadend()

    def visit_raise_stmt(self, o: 'mypy.nodes.RaiseStmt') -> T:
        """Record the `raise` statement as a dead end."""
        self._add_deadend()

    def visit_if_stmt(self, o: 'mypy.nodes.IfStmt') -> T:
        """Add one branch for the `if` body and one for the `else` body."""
        cfg = self.current_cfg()
        if not cfg:
            return

        with pass_state.CfgBranchContext(
                cfg, self.current_statement_id) as branch_ctx:
            with branch_ctx.AddBranch():
                for node in o.body:
                    self.accept(node)

            if o.else_body:
                with branch_ctx.AddBranch():
                    self.accept(o.else_body)

    def visit_break_stmt(self, o: 'mypy.nodes.BreakStmt') -> T:
        """Register the `break` with the innermost enclosing loop, if any."""
        if self.loop_stack:
            self.loop_stack[-1].AddBreak(self.current_statement_id)

    def visit_continue_stmt(self, o: 'mypy.nodes.ContinueStmt') -> T:
        """Register the `continue` with the innermost enclosing loop, if any."""
        if self.loop_stack:
            self.loop_stack[-1].AddContinue(self.current_statement_id)

    def visit_try_stmt(self, o: 'mypy.nodes.TryStmt') -> T:
        """Add a branch for the `try` body and one per `except` handler.

        Each handler branch is attached to the exit of the try branch. Only
        the try body and the handlers are traversed here.
        """
        cfg = self.current_cfg()
        if not cfg:
            return

        with pass_state.CfgBranchContext(
                cfg, self.current_statement_id) as try_ctx:
            with try_ctx.AddBranch() as try_block:
                self.accept(o.body)

            # The exception types and variables are not needed for the graph.
            for _, _, handler in zip(o.types, o.vars, o.handlers):
                with try_ctx.AddBranch(try_block.exit):
                    self.accept(handler)