OILS / asdl / gen_python.py View on Github | oilshell.org

610 lines, 394 significant
1#!/usr/bin/env python2
2"""gen_python.py: Generate Python code from an ASDL schema."""
3from __future__ import print_function
4
5from collections import defaultdict
6
7from asdl import ast
8from asdl import visitor
9from asdl.util import log
10
11_ = log # shut up lint
12
# Maps the name of a built-in ASDL type to the type name that _MyPyType()
# returns for it.  Named types that are not resolved to a Sum, Product or Use
# node are looked up here.
_PRIMITIVES = {
    'string': 'str',
    'int': 'int',
    'BigInt': 'mops.BigInt',
    'float': 'float',
    'bool': 'bool',
    'any': 'Any',
    # TODO: frontend/syntax.asdl should properly import id enum instead of
    # hard-coding it here.
    'id': 'Id_t',
}
24
25
def _MyPyType(typ):
    """Translate an ASDL type expression into a MyPy type string.

    Args:
      typ: an ast.ParameterizedType (Dict, List or Optional) or an
        ast.NamedType.

    Returns:
      The type as a string, e.g. 'List[str]', 'Optional[word_t]' or 'int'.

    Raises:
      AssertionError: if typ is neither kind of node, or if it is a
        parameterized type other than Dict, List or Optional.
      KeyError: if a named type falls through to _PRIMITIVES and is not
        listed there.
    """
    if isinstance(typ, ast.ParameterizedType):
        type_name = typ.type_name

        if type_name == 'Dict':
            k_type = _MyPyType(typ.children[0])
            v_type = _MyPyType(typ.children[1])
            return 'Dict[%s, %s]' % (k_type, v_type)

        if type_name == 'List':
            return 'List[%s]' % _MyPyType(typ.children[0])

        if type_name == 'Optional':
            return 'Optional[%s]' % _MyPyType(typ.children[0])

        # An unknown container used to fall off the end and return None, which
        # callers would then format into a type comment as the text "None".
        # Fail loudly instead, the same way _DefaultValue() does.
        raise AssertionError(type_name)

    if isinstance(typ, ast.NamedType):
        if typ.resolved:
            if isinstance(typ.resolved, ast.Sum):  # includes SimpleSum
                return '%s_t' % typ.name
            if isinstance(typ.resolved, ast.Product):
                return typ.name
            if isinstance(typ.resolved, ast.Use):
                return ast.TypeNameHeuristic(typ.name)

        # 'id' falls through here
        return _PRIMITIVES[typ.name]

    raise AssertionError()
55
56
def _DefaultValue(typ, mypy_type):
    """Return the expression that CreateNull() passes for one field.

    The result is a snippet of Python source.  Where the placeholder is None,
    it is wrapped in a cast to mypy_type so that the generated code still
    passes mypy --strict.

    CreateNull() deliberately sidesteps the type system: the caller is
    expected to fill in every field afterwards, and if it does, readers of
    those fields can rely on them at runtime.

    Args:
      typ: an ast.ParameterizedType or ast.NamedType describing the field.
      mypy_type: the MyPy type string for the same field.

    Raises:
      AssertionError: for an unknown container type, or when typ is neither
        a parameterized nor a named type.
    """
    if isinstance(typ, ast.ParameterizedType):
        container = typ.type_name

        if container == 'List':
            return "[] if alloc_lists else cast('%s', None)" % mypy_type

        # TODO: Dict can respect alloc_dicts=True
        if container in ('Optional', 'Dict'):
            return "cast('%s', None)" % mypy_type

        raise AssertionError(container)

    if not isinstance(typ, ast.NamedType):
        raise AssertionError()

    name = typ.name

    # Placeholder literals for the built-in scalar types.
    literal = {
        'id': '-1',  # hard-coded HACK
        'int': '-1',
        'BigInt': 'mops.BigInt(-1)',
        'bool': 'False',
        'float': '0.0',  # or should it be NaN?
        'string': "''",
    }.get(name)
    if literal is not None:
        return literal

    resolved = typ.resolved
    if isinstance(resolved, ast.SimpleSum):
        # Just make it the first variant.  We could define "Undef" for each
        # enum, but it doesn't seem worth it.
        return '%s_e.%s' % (name, resolved.types[0].name)

    # CompoundSum or Product type
    return 'cast(%s, None)' % mypy_type
112
113
def _HNodeExpr(abbrev, typ, var_name):
    # type: (str, ast.TypeExpr, str) -> str
    """Build the source expression that turns one value into an hnode.

    NOTE: despite the type comment above, this returns a pair
    (code_str, none_guard), and callers unpack it as such.

    Args:
      abbrev: name of the tree method to call on child objects.
      typ: type of the value; an Optional wrapper is stripped first.
      var_name: source text of the expression holding the value.

    Returns:
      code_str: Python source that evaluates to an hnode.
      none_guard: True when code_str calls a method on var_name, so the
        caller may need to handle None first.

    Raises:
      AssertionError: if the (unwrapped) type is neither parameterized nor
        named.
    """
    if typ.IsOptional():
        typ = typ.children[0]  # descend one level

    if isinstance(typ, ast.ParameterizedType):
        return '%s.%s()' % (var_name, abbrev), True

    if not isinstance(typ, ast.NamedType):
        raise AssertionError()

    type_name = typ.name

    # Built-in types that print as a single leaf; each template takes var_name.
    leaf_templates = {
        'bool': "hnode.Leaf('T' if %s else 'F', color_e.OtherConst)",
        'int': 'hnode.Leaf(str(%s), color_e.OtherConst)',
        'BigInt': 'hnode.Leaf(mops.ToStr(%s), color_e.OtherConst)',
        'float': 'hnode.Leaf(str(%s), color_e.OtherConst)',
        'string': 'NewLeaf(%s, color_e.StringConst)',
        # TODO: Remove this.  Used for value.Obj().
        'any': 'hnode.External(%s)',
        # was meta.UserType.  This assumes it's Id, which is a simple
        # SumType.  TODO: Remove this.
        'id': 'hnode.Leaf(Id_str(%s), color_e.UserType)',
    }
    template = leaf_templates.get(type_name)
    if template is not None:
        return template % var_name, False

    if typ.resolved and isinstance(typ.resolved, ast.SimpleSum):
        leaf = 'hnode.Leaf(%s_str(%s), color_e.TypeName)' % (type_name,
                                                             var_name)
        return leaf, False

    # Anything else is an object with its own tree method.
    return '%s.%s(trav=trav)' % (var_name, abbrev), True
162
163
164class GenMyPyVisitor(visitor.AsdlVisitor):
165 """Generate Python code with MyPy type annotations."""
166
167 def __init__(self,
168 f,
169 abbrev_mod_entries=None,
170 pretty_print_methods=True,
171 py_init_n=False,
172 simple_int_sums=None):
173
174 visitor.AsdlVisitor.__init__(self, f)
175 self.abbrev_mod_entries = abbrev_mod_entries or []
176 self.pretty_print_methods = pretty_print_methods
177 self.py_init_n = py_init_n
178
179 # For Id to use different code gen. It's used like an integer, not just
180 # like an enum.
181 self.simple_int_sums = simple_int_sums or []
182
183 self._shared_type_tags = {}
184 self._product_counter = 64 # matches asdl/gen_cpp.py
185
186 self._products = []
187 self._product_bases = defaultdict(list)
188
189 def _EmitDict(self, name, d, depth):
190 self.Emit('_%s_str = {' % name, depth)
191 for k in sorted(d):
192 self.Emit('%d: %r,' % (k, d[k]), depth + 1)
193 self.Emit('}', depth)
194 self.Emit('', depth)
195
196 def VisitSimpleSum(self, sum, sum_name, depth):
197 int_to_str = {}
198 variants = []
199 for i, variant in enumerate(sum.types):
200 tag_num = i + 1
201 tag_str = '%s.%s' % (sum_name, variant.name)
202 int_to_str[tag_num] = tag_str
203 variants.append((variant, tag_num))
204
205 add_suffix = not ('no_namespace_suffix' in sum.generate)
206 gen_integers = 'integers' in sum.generate
207
208 if gen_integers:
209 self.Emit('%s_t = int # type alias for integer' % sum_name)
210 self.Emit('')
211
212 i_name = ('%s_i' % sum_name) if add_suffix else sum_name
213
214 self.Emit('class %s(object):' % i_name, depth)
215
216 for variant, tag_num in variants:
217 line = ' %s = %d' % (variant.name, tag_num)
218 self.Emit(line, depth)
219
220 # Help in sizing array. Note that we're 1-based.
221 line = ' %s = %d' % ('ARRAY_SIZE', len(variants) + 1)
222 self.Emit(line, depth)
223
224 else:
225 # First emit a type
226 self.Emit('class %s_t(pybase.SimpleObj):' % sum_name, depth)
227 self.Emit(' pass', depth)
228 self.Emit('', depth)
229
230 # Now emit a namespace
231 e_name = ('%s_e' % sum_name) if add_suffix else sum_name
232 self.Emit('class %s(object):' % e_name, depth)
233
234 for variant, tag_num in variants:
235 line = ' %s = %s_t(%d)' % (variant.name, sum_name, tag_num)
236 self.Emit(line, depth)
237
238 self.Emit('', depth)
239
240 self._EmitDict(sum_name, int_to_str, depth)
241
242 self.Emit('def %s_str(val):' % sum_name, depth)
243 self.Emit(' # type: (%s_t) -> str' % sum_name, depth)
244 self.Emit(' return _%s_str[val]' % sum_name, depth)
245 self.Emit('', depth)
246
    def _EmitCodeForField(self, abbrev, field, counter):
        """Generate code that returns an hnode for a field.

        The emitted statements build an hnode for self.<field.name> and append
        it to the local list L as Field(name, hnode).  There are four shapes,
        chosen from the field's type: List, Dict, Optional, and everything
        else.

        Args:
          abbrev: name of the tree method to call on child objects; passed
            through to _HNodeExpr().
          field: the field being printed; its .typ and .name are used.
          counter: makes the generated local names (x<N>, i<N>, k<N>, v<N>)
            unique per field.
        """
        # NOTE(review): in this copy of the file every emitted line starts
        # with a single space, so nested generated statements share one
        # indent level -- confirm the string literals against the canonical
        # source.
        out_val_name = 'x%d' % counter

        if field.typ.IsList():
            iter_name = 'i%d' % counter

            typ = field.typ
            if typ.type_name == 'Optional':  # descend one level
                typ = typ.children[0]
            item_type = typ.children[0]

            # A None list emits no Field at all.
            self.Emit(' if self.%s is not None: # List' % field.name)
            self.Emit(' %s = hnode.Array([])' % out_val_name)
            self.Emit(' for %s in self.%s:' % (iter_name, field.name))
            child_code_str, none_guard = _HNodeExpr(abbrev, item_type,
                                                    iter_name)

            if none_guard:  # e.g. for List[Optional[value_t]]
                # TODO: could consolidate with asdl/runtime.py NewLeaf(), which
                # also uses _ to mean None/nullptr
                self.Emit(
                    ' h = (hnode.Leaf("_", color_e.OtherConst) if %s is None else %s)'
                    % (iter_name, child_code_str))
                self.Emit(' %s.children.append(h)' % out_val_name)
            else:
                self.Emit(' %s.children.append(%s)' %
                          (out_val_name, child_code_str))

            self.Emit(' L.append(Field(%r, %s))' %
                      (field.name, out_val_name))

        elif field.typ.IsDict():
            k = 'k%d' % counter
            v = 'v%d' % counter

            typ = field.typ
            if typ.type_name == 'Optional':  # descend one level
                typ = typ.children[0]

            k_typ = typ.children[0]
            v_typ = typ.children[1]

            # The none_guard results are ignored for dict keys and values.
            k_code_str, _ = _HNodeExpr(abbrev, k_typ, k)
            v_code_str, _ = _HNodeExpr(abbrev, v_typ, v)

            # The dict is printed as an Array: a "Dict" marker leaf followed
            # by alternating key and value nodes.
            self.Emit(' if self.%s is not None: # Dict' % field.name)
            self.Emit(' m = hnode.Leaf("Dict", color_e.OtherConst)')
            self.Emit(' %s = hnode.Array([m])' % out_val_name)
            self.Emit(' for %s, %s in self.%s.iteritems():' %
                      (k, v, field.name))
            self.Emit(' %s.children.append(%s)' %
                      (out_val_name, k_code_str))
            self.Emit(' %s.children.append(%s)' %
                      (out_val_name, v_code_str))
            self.Emit(' L.append(Field(%r, %s))' %
                      (field.name, out_val_name))

        elif field.typ.IsOptional():
            typ = field.typ.children[0]

            # A None value emits no Field at all.
            self.Emit(' if self.%s is not None: # Optional' % field.name)
            child_code_str, _ = _HNodeExpr(abbrev, typ, 'self.%s' % field.name)
            self.Emit(' %s = %s' % (out_val_name, child_code_str))
            self.Emit(' L.append(Field(%r, %s))' %
                      (field.name, out_val_name))

        else:
            var_name = 'self.%s' % field.name
            code_str, obj_none_guard = _HNodeExpr(abbrev, field.typ, var_name)
            depth = self.current_depth
            if obj_none_guard:  # to satisfy MyPy type system
                self.Emit(' assert self.%s is not None' % field.name)
            self.Emit(' %s = %s' % (out_val_name, code_str), depth)

            self.Emit(' L.append(Field(%r, %s))' % (field.name, out_val_name),
                      depth)
324
    def _GenClass(self,
                  ast_node,
                  class_name,
                  base_classes,
                  tag_num,
                  class_ns=''):
        """Used for both Sum variants ("constructors") and Product types.

        Emits, in order: the class header with _type_tag and __slots__,
        __init__, CreateNull() (only when there are fields and py_init_n is
        false), and -- unless pretty_print_methods is false -- PrettyTree(),
        _AbbreviatedTree() and AbbreviatedTree().

        Args:
          ast_node: node whose .fields describe the class members.
          class_name: name of the emitted class.
          base_classes: sequence of base class names, joined with ', '.
          tag_num: integer emitted as the class's _type_tag.
          class_ns: for variants like value.Str
        """
        # NOTE(review): in this copy of the file every emitted line starts
        # with a single space, so the generated class body, method bodies and
        # nested statements share one indent level -- confirm the string
        # literals against the canonical source.
        self.Emit('class %s(%s):' % (class_name, ', '.join(base_classes)))
        self.Emit(' _type_tag = %d' % tag_num)

        all_fields = ast_node.fields

        field_names = [f.name for f in all_fields]

        # repr() of a tuple of names is valid source for the __slots__ value.
        quoted_fields = repr(tuple(field_names))
        self.Emit(' __slots__ = %s' % quoted_fields)
        self.Emit('')

        #
        # __init__
        #

        args = [f.name for f in ast_node.fields]

        self.Emit(' def __init__(self, %s):' % ', '.join(args))

        # One MyPy type and one CreateNull() default per field, in field order.
        arg_types = []
        default_vals = []
        for f in ast_node.fields:
            mypy_type = _MyPyType(f.typ)
            arg_types.append(mypy_type)

            d_str = _DefaultValue(f.typ, mypy_type)
            default_vals.append(d_str)

        self.Emit(' # type: (%s) -> None' % ', '.join(arg_types),
                  reflow=False)

        if not all_fields:
            self.Emit(' pass')  # for types like NoOp

        for f in ast_node.fields:
            # don't wrap the type comment
            self.Emit(' self.%s = %s' % (f.name, f.name), reflow=False)

        self.Emit('')

        # Name shown in the printed tree, e.g. with the 'value.' namespace.
        pretty_cls_name = '%s%s' % (class_ns, class_name)

        if len(all_fields) and not self.py_init_n:
            self.Emit(' @staticmethod')
            self.Emit(' def CreateNull(alloc_lists=False):')
            self.Emit(' # type: () -> %s%s' % (class_ns, class_name))
            self.Emit(' return %s%s(%s)' %
                      (class_ns, class_name, ', '.join(default_vals)),
                      reflow=False)
            self.Emit('')

        if not self.pretty_print_methods:
            return

        #
        # PrettyTree
        #

        self.Emit(' def PrettyTree(self, trav=None):')
        self.Emit(' # type: (Optional[TraversalState]) -> hnode_t')
        self.Emit(' trav = trav or TraversalState()')
        self.Emit(' heap_id = id(self)')
        self.Emit(' if heap_id in trav.seen:')
        # cut off recursion
        self.Emit(' return hnode.AlreadySeen(heap_id)')
        self.Emit(' trav.seen[heap_id] = True')

        self.Emit(' out_node = NewRecord(%r)' % pretty_cls_name)
        self.Emit(' L = out_node.fields')
        self.Emit('')

        # Use the runtime type to be more like asdl/format.py
        for local_id, field in enumerate(all_fields):
            #log('%s :: %s', field_name, field_desc)
            self.Indent()
            self._EmitCodeForField('PrettyTree', field, local_id)
            self.Dedent()
            self.Emit('')
        self.Emit(' return out_node')
        self.Emit('')

        #
        # _AbbreviatedTree
        #

        self.Emit(' def _AbbreviatedTree(self, trav=None):')
        self.Emit(' # type: (Optional[TraversalState]) -> hnode_t')
        self.Emit(' trav = trav or TraversalState()')
        self.Emit(' heap_id = id(self)')
        self.Emit(' if heap_id in trav.seen:')
        # cut off recursion
        self.Emit(' return hnode.AlreadySeen(heap_id)')
        self.Emit(' trav.seen[heap_id] = True')
        self.Emit(' out_node = NewRecord(%r)' % pretty_cls_name)
        self.Emit(' L = out_node.fields')

        for local_id, field in enumerate(ast_node.fields):
            self.Indent()
            self._EmitCodeForField('AbbreviatedTree', field, local_id)
            self.Dedent()
            self.Emit('')
        self.Emit(' return out_node')
        self.Emit('')

        # Public entry point: try the user's '_<ClassName>' function when one
        # is registered, otherwise go straight to _AbbreviatedTree().
        self.Emit(' def AbbreviatedTree(self, trav=None):')
        self.Emit(' # type: (Optional[TraversalState]) -> hnode_t')
        abbrev_name = '_%s' % class_name
        if abbrev_name in self.abbrev_mod_entries:
            self.Emit(' p = %s(self)' % abbrev_name)
            # If the user function didn't return anything, fall back.
            self.Emit(
                ' return p if p else self._AbbreviatedTree(trav=trav)')
        else:
            self.Emit(' return self._AbbreviatedTree(trav=trav)')
        self.Emit('')
451
    def VisitCompoundSum(self, sum, sum_name, depth):
        """Emit the Python definitions for a sum with at least one field.

        Note that the following is_simple:

          cflow = Break | Continue

        But this is compound:

          cflow = Break | Continue | Return(int val)

        The generated code changes depending on which one it is.

        Raises:
          RuntimeError: if two variants of this sum refer to the same shared
            product type.
        """
        #log('%d variants in %s', len(sum.types), sum_name)

        # We emit THREE Python types for each meta.CompoundType:
        #
        # 1. enum for tag (cflow_e)
        # 2. base class for inheritance (cflow_t)
        # 3. namespace for classes (cflow) -- TODO: Get rid of this one.
        #
        # Should code use cflow_e.tag or isinstance()?
        # isinstance() is better for MyPy I think.  But tag is better for C++.
        # int tag = static_cast<cflow>(node).tag;

        # NOTE(review): in this copy of the file the emitted lines start with
        # at most a single space -- confirm the string literals against the
        # canonical source.
        int_to_str = {}

        # enum for the tag
        self.Emit('class %s_e(object):' % sum_name, depth)

        for i, variant in enumerate(sum.types):
            if variant.shared_type:
                # A variant that reuses a product type takes the tag recorded
                # for that product by VisitProduct().
                tag_num = self._shared_type_tags[variant.shared_type]
                # e.g. DoubleQuoted may have base types expr_t, word_part_t
                base_class = sum_name + '_t'
                bases = self._product_bases[variant.shared_type]
                if base_class in bases:
                    raise RuntimeError(
                        "Two tags in sum %r refer to product type %r" %
                        (sum_name, variant.shared_type))

                else:
                    # Remembered so EmitFooter() can make the product class
                    # inherit from this sum's base class.
                    bases.append(base_class)
            else:
                # Ordinary variants are numbered from 1 in declaration order.
                tag_num = i + 1
            self.Emit(' %s = %d' % (variant.name, tag_num), depth)
            int_to_str[tag_num] = variant.name
        self.Emit('', depth)

        self._EmitDict(sum_name, int_to_str, depth)

        self.Emit('def %s_str(tag, dot=True):' % sum_name, depth)
        self.Emit(' # type: (int, bool) -> str', depth)
        self.Emit(' v = _%s_str[tag]' % sum_name, depth)
        self.Emit(' if dot:', depth)
        self.Emit(' return "%s.%%s" %% v' % sum_name, depth)
        self.Emit(' else:', depth)
        self.Emit(' return v', depth)
        self.Emit('', depth)

        # the base class, e.g. 'oil_cmd'
        self.Emit('class %s_t(pybase.CompoundObj):' % sum_name, depth)
        self.Indent()
        depth = self.current_depth

        # To imitate C++ API
        self.Emit('def tag(self):')
        self.Emit(' # type: () -> int')
        self.Emit(' return self._type_tag')

        # This is what we would do in C++, but we don't need it in Python because
        # every function is virtual.
        # The 'if 0:' branch below never runs; it is kept for reference only.
        if 0:
            #if self.pretty_print_methods:
            for abbrev in 'PrettyTree', '_AbbreviatedTree', 'AbbreviatedTree':
                self.Emit('')
                self.Emit('def %s(self):' % abbrev, depth)
                self.Emit(' # type: () -> hnode_t', depth)
                self.Indent()
                depth = self.current_depth
                self.Emit('UP_self = self', depth)
                self.Emit('', depth)

                for variant in sum.types:
                    if variant.shared_type:
                        subtype_name = variant.shared_type
                    else:
                        subtype_name = '%s__%s' % (sum_name, variant.name)

                    self.Emit(
                        'if self.tag() == %s_e.%s:' % (sum_name, variant.name),
                        depth)
                    self.Emit(' self = cast(%s, UP_self)' % subtype_name,
                              depth)
                    self.Emit(' return self.%s()' % abbrev, depth)

                self.Emit('raise AssertionError()', depth)

                self.Dedent()
                depth = self.current_depth
        else:
            # Otherwise it's empty
            self.Emit('pass', depth)

        self.Dedent()
        depth = self.current_depth
        self.Emit('')

        # Declare any zero argument singleton classes outside of the main
        # "namespace" class.
        for i, variant in enumerate(sum.types):
            if variant.shared_type:
                continue  # Don't generate a class for shared types.
            if len(variant.fields) == 0:
                # We must use the old-style naming here, ie. command__NoOp, in order
                # to support zero field variants as constants.
                class_name = '%s__%s' % (sum_name, variant.name)
                self._GenClass(variant, class_name, (sum_name + '_t', ), i + 1)

        # Class that's just a NAMESPACE, e.g. for value.Str
        self.Emit('class %s(object):' % sum_name, depth)

        self.Indent()

        for i, variant in enumerate(sum.types):
            if variant.shared_type:
                continue

            if len(variant.fields) == 0:
                # Zero-field variants become a single shared instance of the
                # '<sum>__<Variant>' class emitted above.
                self.Emit('%s = %s__%s()' %
                          (variant.name, sum_name, variant.name))
                self.Emit('')
            else:
                # Use fully-qualified name, so we can have osh_cmd.Simple and
                # oil_cmd.Simple.
                fq_name = variant.name
                self._GenClass(variant,
                               fq_name, (sum_name + '_t', ),
                               i + 1,
                               class_ns=sum_name + '.')
        self.Emit(' pass', depth)  # in case every variant is first class

        self.Dedent()
        self.Emit('')
594
595 def VisitProduct(self, product, name, depth):
596 self._shared_type_tags[name] = self._product_counter
597 # Create a tuple of _GenClass args to create LAST. They may inherit from
598 # sum types that have yet to be defined.
599 self._products.append((product, name, depth, self._product_counter))
600 self._product_counter += 1
601
602 def EmitFooter(self):
603 # Now generate all the product types we deferred.
604 for args in self._products:
605 ast_node, name, depth, tag_num = args
606 # Figure out base classes AFTERWARD.
607 bases = self._product_bases[name]
608 if not bases:
609 bases = ('pybase.CompoundObj', )
610 self._GenClass(ast_node, name, bases, tag_num)