OILS / mycpp / const_pass.py View on Github | oilshell.org

531 lines, 335 significant
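"""
const_pass.py - AST pass that collects constants.

Immutable string literals are moved to the top level of the generated C++
program for efficiency: each one becomes a declaration such as
'GLOBAL_STR(str0, "foo");' (or GLOBAL_STR2), and the generated name is
recorded in const_lookup, keyed by the literal's AST node.
"""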
7import json
8
9from typing import overload, Union, Optional, Dict, List
10
11import mypy
12from mypy.visitor import ExpressionVisitor, StatementVisitor
13from mypy.nodes import (Expression, Statement, ExpressionStmt, StrExpr,
14 ComparisonExpr, NameExpr, MemberExpr)
15
16from mypy.types import Type
17
18from mycpp.crash import catch_errors
19from mycpp import format_strings
20from mycpp.util import log
21from mycpp import util
22
# Placeholder for the visitors' result type.  The visit_* methods in this file
# return nothing (they are called only to fill const_lookup / const_code), so
# T is simply None in ExpressionVisitor[T] and in the '-> T' annotations.
T = None  # TODO: Make it type check?
24
25
class UnsupportedException(Exception):
    """Signals an AST construct the pass cannot handle.

    Collect.accept() catches this exception and carries on with the next
    node instead of aborting the whole pass.
    """
28
29
class Collect(ExpressionVisitor[T], StatementVisitor[None]):
    """Walk a typed mypy AST and hoist string literals into global constants.

    Every StrExpr reached by the traversal gets a unique C++ name (str0,
    str1, ...).  A GLOBAL_STR / GLOBAL_STR2 declaration for it is appended to
    const_code, and the name is recorded in const_lookup keyed by the StrExpr
    node.  All other visit_* methods only recurse into child nodes so that
    nested literals are found; they produce no output of their own.
    """

    # These modules are special; their contents are currently all built-in
    # primitives, so visit_mypy_file() skips them.  A lot of them are
    # brought in by 'import typing'.
    _BUILTIN_MODULES = frozenset([
        '__future__', 'sys', 'types', 'typing', 'abc', '_ast', 'ast',
        '_weakrefset', 'collections', 'cStringIO', 're', 'builtins'
    ])

    def __init__(self, types: Dict[Expression, Type],
                 const_lookup: Dict[Expression, str],
                 const_code: List[str]) -> None:
        """
        Args:
          types: expression -> type map from mypy.  Used to decide whether an
            assignment's rvalue is visited, and for debug logging.
          const_lookup: output; filled with StrExpr node -> C++ constant name.
          const_code: output; one C++ declaration line appended per literal.
        """
        self.types = types
        self.const_lookup = const_lookup
        self.const_code = const_code
        self.unique_id = 0  # suffix for the next 'str%d' name

        # Path of the module being visited.  It is passed to catch_errors()
        # together with the node's line.  visit_mypy_file() sets it; it is
        # initialized here so accept() never hits a missing attribute.
        self.module_path = None  # type: Optional[str]

        self.indent = 0  # nesting depth, only used by self.log()

    def out(self, msg: str, *args) -> None:
        """Append one line to const_code, %-formatting msg when args given."""
        if args:
            msg = msg % args
        self.const_code.append(msg)

    @staticmethod
    def _is_docstring(stmt: Statement) -> bool:
        """True for a bare string-literal statement, e.g. a docstring."""
        return isinstance(stmt, ExpressionStmt) and isinstance(
            stmt.expr, StrExpr)

    #
    # COPIED from IRBuilder
    #

    @overload
    def accept(self, node: Expression) -> T:
        ...

    @overload
    def accept(self, node: Statement) -> None:
        ...

    def accept(self, node: Union[Statement, Expression]) -> Optional[T]:
        """Dispatch node to its visit_* method inside catch_errors().

        UnsupportedException is swallowed so the pass keeps going and can
        report further problems; the result is then None.
        """
        with catch_errors(self.module_path, node.line):
            if isinstance(node, Expression):
                try:
                    return node.accept(self)
                except UnsupportedException:
                    # The IRBuilder original allocated a typed temp here via
                    # alloc_temp()/node_type(), but this class has neither
                    # method.  Nothing is computed here, so return None.
                    return None
            else:
                try:
                    node.accept(self)
                except UnsupportedException:
                    pass
                return None

    def log(self, msg: str, *args) -> None:
        """Indented debug logging; disabled (change the 0 to enable)."""
        if 0:  # quiet
            ind_str = self.indent * ' '
            log(ind_str + msg, *args)

    # Not in superclasses:

    def visit_mypy_file(self, o: 'mypy.nodes.MypyFile') -> T:
        if o.fullname in self._BUILTIN_MODULES:
            return

        self.module_path = o.path

        self.indent += 1
        for node in o.defs:
            if self._is_docstring(node):  # skip module docstring
                continue
            self.accept(node)
        self.indent -= 1

    # LITERALS

    def visit_int_expr(self, o: 'mypy.nodes.IntExpr') -> T:
        self.log('IntExpr %d', o.value)

    def visit_str_expr(self, o: 'mypy.nodes.StrExpr') -> T:
        # - Need new BigStr() everywhere because "foo" doesn't match BigStr* :-(
        # So each literal becomes a named global instead.
        id_ = 'str%d' % self.unique_id
        self.unique_id += 1

        raw_string = format_strings.DecodeMyPyString(o.value)

        macro = 'GLOBAL_STR2' if util.SMALL_STR else 'GLOBAL_STR'
        self.out('%s(%s, %s);', macro, id_, json.dumps(raw_string))

        self.const_lookup[o] = id_

    def visit_bytes_expr(self, o: 'mypy.nodes.BytesExpr') -> T:
        pass

    def visit_unicode_expr(self, o: 'mypy.nodes.UnicodeExpr') -> T:
        pass

    def visit_float_expr(self, o: 'mypy.nodes.FloatExpr') -> T:
        pass

    def visit_complex_expr(self, o: 'mypy.nodes.ComplexExpr') -> T:
        pass

    # Expression

    def visit_ellipsis(self, o: 'mypy.nodes.EllipsisExpr') -> T:
        pass

    def visit_star_expr(self, o: 'mypy.nodes.StarExpr') -> T:
        pass

    def visit_name_expr(self, o: 'mypy.nodes.NameExpr') -> T:
        pass

    def visit_member_expr(self, o: 'mypy.nodes.MemberExpr') -> T:
        if o.expr:
            self.accept(o.expr)

    def visit_yield_from_expr(self, o: 'mypy.nodes.YieldFromExpr') -> T:
        pass

    def visit_yield_expr(self, o: 'mypy.nodes.YieldExpr') -> T:
        pass

    def visit_call_expr(self, o: 'mypy.nodes.CallExpr') -> T:
        self.log('CallExpr')
        self.accept(o.callee)  # could be f() or obj.method()

        self.indent += 1
        for arg in o.args:
            self.accept(arg)
        self.indent -= 1

    def visit_op_expr(self, o: 'mypy.nodes.OpExpr') -> T:
        self.log('OpExpr')
        self.indent += 1
        self.accept(o.left)
        self.accept(o.right)
        self.indent -= 1

    def visit_comparison_expr(self, o: 'mypy.nodes.ComparisonExpr') -> T:
        self.log('ComparisonExpr')
        self.log(' operators %s', o.operators)
        self.indent += 1
        for operand in o.operands:
            self.indent += 1
            self.accept(operand)
            self.indent -= 1
        self.indent -= 1

    def visit_cast_expr(self, o: 'mypy.nodes.CastExpr') -> T:
        # cast(T, expr): the wrapped expression may contain literals.
        self.accept(o.expr)

    def visit_reveal_expr(self, o: 'mypy.nodes.RevealExpr') -> T:
        pass

    def visit_super_expr(self, o: 'mypy.nodes.SuperExpr') -> T:
        pass

    def visit_assignment_expr(self, o: 'mypy.nodes.AssignmentExpr') -> T:
        pass

    def visit_unary_expr(self, o: 'mypy.nodes.UnaryExpr') -> T:
        # e.g. a[-1] or 'not x'
        self.accept(o.expr)

    def visit_list_expr(self, o: 'mypy.nodes.ListExpr') -> T:
        # Lists are MUTABLE, so we can't generate constants for them at the
        # top level -- but we still want the string literals inside.
        for item in o.items:
            self.accept(item)

    def visit_dict_expr(self, o: 'mypy.nodes.DictExpr') -> T:
        for k, v in o.items:
            self.accept(k)
            self.accept(v)

    def visit_tuple_expr(self, o: 'mypy.nodes.TupleExpr') -> T:
        for item in o.items:
            self.accept(item)

    def visit_set_expr(self, o: 'mypy.nodes.SetExpr') -> T:
        pass

    def visit_index_expr(self, o: 'mypy.nodes.IndexExpr') -> T:
        self.accept(o.base)
        self.accept(o.index)

    def visit_type_application(self, o: 'mypy.nodes.TypeApplication') -> T:
        pass

    def visit_lambda_expr(self, o: 'mypy.nodes.LambdaExpr') -> T:
        pass

    def visit_list_comprehension(self, o: 'mypy.nodes.ListComprehension') -> T:
        gen = o.generator  # GeneratorExpr

        # We might use all of these, so collect constants.  Every
        # 'for ... in ... if ...' clause is visited, not only the first, so
        # literals in later clauses are not missed.  With a single clause the
        # visiting order is left, index, sequence, conditions.
        self.accept(gen.left_expr)
        for index_expr, seq, conds in zip(gen.indices, gen.sequences,
                                          gen.condlists):
            self.accept(index_expr)
            self.accept(seq)
            for c in conds:
                self.accept(c)

    def visit_set_comprehension(self, o: 'mypy.nodes.SetComprehension') -> T:
        pass

    def visit_dictionary_comprehension(
            self, o: 'mypy.nodes.DictionaryComprehension') -> T:
        pass

    def visit_generator_expr(self, o: 'mypy.nodes.GeneratorExpr') -> T:
        pass

    def visit_slice_expr(self, o: 'mypy.nodes.SliceExpr') -> T:
        if o.begin_index:
            self.accept(o.begin_index)

        if o.end_index:
            self.accept(o.end_index)

        if o.stride:
            self.accept(o.stride)

    def visit_conditional_expr(self, o: 'mypy.nodes.ConditionalExpr') -> T:
        self.accept(o.cond)
        self.accept(o.if_expr)
        self.accept(o.else_expr)

    def visit_backquote_expr(self, o: 'mypy.nodes.BackquoteExpr') -> T:
        pass

    def visit_type_var_expr(self, o: 'mypy.nodes.TypeVarExpr') -> T:
        pass

    def visit_type_alias_expr(self, o: 'mypy.nodes.TypeAliasExpr') -> T:
        pass

    def visit_namedtuple_expr(self, o: 'mypy.nodes.NamedTupleExpr') -> T:
        pass

    def visit_enum_call_expr(self, o: 'mypy.nodes.EnumCallExpr') -> T:
        pass

    def visit_typeddict_expr(self, o: 'mypy.nodes.TypedDictExpr') -> T:
        pass

    def visit_newtype_expr(self, o: 'mypy.nodes.NewTypeExpr') -> T:
        pass

    def visit__promote_expr(self, o: 'mypy.nodes.PromoteExpr') -> T:
        pass

    def visit_await_expr(self, o: 'mypy.nodes.AwaitExpr') -> T:
        pass

    def visit_temp_node(self, o: 'mypy.nodes.TempNode') -> T:
        pass

    def visit_assignment_stmt(self, o: 'mypy.nodes.AssignmentStmt') -> T:
        self.log('AssignmentStmt')

        for lval in o.lvalues:
            # Some lvalues have no entry in self.types; that only affects
            # this log line.
            self.log(' lval %s :: %s', lval, self.types.get(lval))
            self.accept(lval)

        # mypy records no type for some rvalues (this seems to happen only
        # for Ellipsis, e.g. in the abc module); such rvalues are skipped.
        if o.rvalue in self.types:
            self.indent += 1
            self.accept(o.rvalue)
            self.indent -= 1

    def visit_for_stmt(self, o: 'mypy.nodes.ForStmt') -> T:
        self.log('ForStmt')
        self.accept(o.index)  # index var expression
        self.accept(o.expr)  # the thing being iterated over
        self.accept(o.body)
        if o.else_body:
            raise AssertionError("can't translate for-else")

    def visit_with_stmt(self, o: 'mypy.nodes.WithStmt') -> T:
        assert len(o.expr) == 1, o.expr
        self.accept(o.expr[0])
        self.accept(o.body)

    def visit_del_stmt(self, o: 'mypy.nodes.DelStmt') -> T:
        self.accept(o.expr)

    def visit_func_def(self, o: 'mypy.nodes.FuncDef') -> T:
        typ = o.type  # presumably None when the def has no annotations
        self.log('FuncDef %s :: %s', o.name, typ)
        if typ is not None:
            # Logging only; guarded so a missing type can't crash the pass.
            for t, name in zip(typ.arg_types, typ.arg_names):
                self.log(' arg %s %s', t, name)
            self.log(' ret %s', typ.ret_type)

        self.indent += 1
        for arg in o.arguments:
            # Default values, e.g. foo='', may contain literals.
            if arg.initializer:
                self.accept(arg.initializer)

            self.log('Argument %s', arg.variable)
            self.log(' type_annotation %s', arg.type_annotation)
            self.log(' initializer %s', arg.initializer)
            self.log(' kind %s', arg.kind)

        self.accept(o.body)
        self.indent -= 1

    def visit_overloaded_func_def(self,
                                  o: 'mypy.nodes.OverloadedFuncDef') -> T:
        pass

    def visit_class_def(self, o: 'mypy.nodes.ClassDef') -> T:
        self.log('const_pass ClassDef %s', o.name)
        for b in o.base_type_exprs:
            self.log(' base_type_expr %s', b)
        self.indent += 1
        self.accept(o.defs)
        self.indent -= 1

    def visit_global_decl(self, o: 'mypy.nodes.GlobalDecl') -> T:
        pass

    def visit_nonlocal_decl(self, o: 'mypy.nodes.NonlocalDecl') -> T:
        pass

    def visit_decorator(self, o: 'mypy.nodes.Decorator') -> T:
        pass

    def visit_var(self, o: 'mypy.nodes.Var') -> T:
        pass

    # Module structure

    def visit_import(self, o: 'mypy.nodes.Import') -> T:
        pass

    def visit_import_from(self, o: 'mypy.nodes.ImportFrom') -> T:
        pass

    def visit_import_all(self, o: 'mypy.nodes.ImportAll') -> T:
        pass

    # Statements

    def visit_block(self, block: 'mypy.nodes.Block') -> T:
        self.log('Block')
        self.indent += 1
        for stmt in block.body:
            if self._is_docstring(stmt):  # ignore docstring-like statements
                continue
            self.accept(stmt)
        self.indent -= 1

    def visit_expression_stmt(self, o: 'mypy.nodes.ExpressionStmt') -> T:
        self.log('ExpressionStmt')
        self.indent += 1
        self.accept(o.expr)
        self.indent -= 1

    def visit_operator_assignment_stmt(
            self, o: 'mypy.nodes.OperatorAssignmentStmt') -> T:
        # e.g. s += 'foo' or d['key'] += 1 -- both sides may hold literals.
        self.log('OperatorAssignmentStmt')
        self.accept(o.lvalue)
        self.accept(o.rvalue)

    def visit_while_stmt(self, o: 'mypy.nodes.WhileStmt') -> T:
        self.log('WhileStmt')
        self.accept(o.expr)
        self.accept(o.body)

    def visit_return_stmt(self, o: 'mypy.nodes.ReturnStmt') -> T:
        self.log('ReturnStmt')
        if o.expr:
            self.accept(o.expr)

    def visit_assert_stmt(self, o: 'mypy.nodes.AssertStmt') -> T:
        pass

    def visit_if_stmt(self, o: 'mypy.nodes.IfStmt') -> T:
        # Copied from cppgen_pass.py
        # Not sure why this wouldn't be true
        assert len(o.expr) == 1, o.expr

        cond = o.expr[0]

        # Omit anything that looks like if __name__ == ...
        if (isinstance(cond, ComparisonExpr) and
                isinstance(cond.operands[0], NameExpr) and
                cond.operands[0].name == '__name__'):
            return

        # Omit if TYPE_CHECKING blocks. They contain type expressions that
        # don't type check!
        if isinstance(cond, NameExpr) and cond.name == 'TYPE_CHECKING':
            return

        if isinstance(cond, MemberExpr):
            # mylib.CPP: just take the if block
            if cond.name == 'CPP':
                for node in o.body:
                    self.accept(node)
                return
            # mylib.PYTHON: just take the else block, if any
            if cond.name == 'PYTHON':
                if o.else_body:
                    self.accept(o.else_body)
                return

        self.log('IfStmt')
        self.indent += 1
        for e in o.expr:
            self.accept(e)

        for node in o.body:
            self.accept(node)

        if o.else_body:
            self.accept(o.else_body)
        self.indent -= 1

    def visit_break_stmt(self, o: 'mypy.nodes.BreakStmt') -> T:
        pass

    def visit_continue_stmt(self, o: 'mypy.nodes.ContinueStmt') -> T:
        pass

    def visit_pass_stmt(self, o: 'mypy.nodes.PassStmt') -> T:
        pass

    def visit_raise_stmt(self, o: 'mypy.nodes.RaiseStmt') -> T:
        if o.expr:
            self.accept(o.expr)

    def visit_try_stmt(self, o: 'mypy.nodes.TryStmt') -> T:
        self.accept(o.body)
        for handler in o.handlers:
            self.accept(handler)

        # Also collect literals from the else/finally clauses, so they are
        # available to whichever later pass translates (or rejects) them.
        if o.else_body:
            self.accept(o.else_body)
        if o.finally_body:
            self.accept(o.finally_body)

    def visit_print_stmt(self, o: 'mypy.nodes.PrintStmt') -> T:
        pass

    def visit_exec_stmt(self, o: 'mypy.nodes.ExecStmt') -> T:
        pass