| 1 | #!/usr/bin/env python
|
| 2 | from __future__ import print_function
|
| 3 | """
|
| 4 | asdl_cpp.py
|
| 5 |
|
| 6 | Turn an ASDL schema into C++ code.
|
| 7 |
|
| 8 | TODO:
|
| 9 | - Optional fields
|
| 10 | - in osh, it's only used in two places:
|
| 11 | - arith_expr? for slice length
|
| 12 | - word? for var replace
|
| 13 | - So you're already using pointers, can encode the NULL pointer.
|
| 14 |
|
| 15 | - Change everything to use references instead of pointers? Non-nullable.
|
| 16 | - Unify ClassDefVisitor and MethodBodyVisitor.
|
| 17 | - Whether you need a separate method body should be a flag.
|
| 18 | - offset calculations are duplicated
|
| 19 | - generate a C++ pretty-printer
|
| 20 |
|
| 21 | Technically we don't even need alignment? I guess the reason is to increase
|
| 22 | address space. If 1, then we have 16MiB of code. If 4, then we have 64 MiB.
|
| 23 |
|
| 24 | Everything is decoded on the fly, or is a char*, which I don't think has to be
|
| 25 | aligned (because the natural alignment would be 1 byte anyway.)
|
| 26 | """
|
| 27 |
|
| 28 | import sys
|
| 29 |
|
| 30 | from asdl import asdl_ as asdl
|
| 31 | from asdl import encode
|
| 32 | from asdl import visitor
|
| 33 |
|
| 34 | from osh.meta import Id
|
| 35 |
|
class ChainOfVisitors:
    """Run several visitors, in order, over the same module."""

    def __init__(self, *visitors):
        """Remember the visitors in the order they were given."""
        self.visitors = visitors

    def VisitModule(self, module):
        """Call VisitModule(module) on each visitor in turn."""
        for each_visitor in self.visitors:
            each_visitor.VisitModule(module)
|
| 43 |
|
| 44 |
|
# ASDL builtin type name -> C++ type name used in the generated code.
_BUILTINS = dict(
    string='char*',  # A read-only string is a char*
    int='int',
    bool='bool',
    id='Id',  # Application specific hack for now
)
|
| 51 |
|
class ForwardDeclareVisitor(visitor.AsdlVisitor):
    """Print forward declarations.

    ASDL allows forward references of types, but C++ doesn't, so every
    compound sum and product gets a `class <name>_t;` line up front.
    """

    def _Declare(self, name, depth):
        """Emit the forward declaration for the class generated from `name`."""
        self.Emit("class %s_t;" % (name,), depth)

    def VisitCompoundSum(self, sum, name, depth):
        """Forward-declare the base class of a compound sum type."""
        self._Declare(name, depth)

    def VisitProduct(self, product, name, depth):
        """Forward-declare the class of a product type."""
        self._Declare(name, depth)

    def EmitFooter(self):
        """Separate the declarations from what follows with a blank line."""
        self.Emit("", 0)  # blank line
|
| 65 |
|
| 66 |
|
class ClassDefVisitor(visitor.AsdlVisitor):
    """Generate C++ classes and type-safe enums.

    Class declarations are emitted through self.Emit() while the schema is
    visited.  Accessor bodies that need a static_cast<> to a generated class
    are buffered in self.footer and written out by EmitFooter(), after all
    classes have been declared.
    """

    # C++ statement computing the position `a` of element `index` inside an
    # array.  The size accessor reads Int(0), so elements start one slot later.
    # NOTE(review): the literal 3 looks like the encoded integer width; it is
    # not derived from enc_params -- confirm before changing either.
    _ARRAY_OFFSET = 'int a = (index+1) * 3;'

    def __init__(self, f, enc_params, type_lookup, enum_types=None):
        """
        Args:
          f: output stream, handed to visitor.AsdlVisitor.__init__.
          enc_params: object providing `ref_width` and `pointer_type`.
          type_lookup: object whose ByTypeName() resolves an ASDL type name.
          enum_types: optional container of C++ type names that should be
            treated as enums even though they don't end in '_e'.
        """
        visitor.AsdlVisitor.__init__(self, f)
        self.ref_width = enc_params.ref_width
        self.type_lookup = type_lookup
        self.enum_types = enum_types or {}
        self.pointer_type = enc_params.pointer_type
        self.footer = []  # lines

    def _GetCppType(self, field):
        """Return a string for the C++ name of the type."""
        type_name = field.type

        builtin = _BUILTINS.get(type_name)
        if builtin is not None:
            return builtin

        typ = self.type_lookup.ByTypeName(type_name)
        if isinstance(typ, asdl.Sum) and asdl.is_simple(typ):
            # Use the enum instead of the class.
            return "%s_e" % type_name

        # - Pointer for optional type.
        # - ints and strings should generally not be optional?  We don't have
        #   them in osh yet, so leave it out for now.
        if field.opt:
            return "%s_t*" % type_name

        return "%s_t&" % type_name

    def EmitFooter(self):
        """Write the buffered out-of-class accessor bodies to the output."""
        for line in self.footer:
            self.f.write(line)

    def _EmitEnum(self, sum, name, depth):
        """Emit `enum class <name>_e` with one member per variant of `sum`."""
        # Tags start at 1; zero is reserved.
        members = ["%s = %d" % (variant.name, tag)
                   for tag, variant in enumerate(sum.types, 1)]

        self.Emit("enum class %s_e : uint8_t {" % name, depth)
        self.Emit(", ".join(members), depth + 1)
        self.Emit("};", depth)
        self.Emit("", depth)

    def VisitSimpleSum(self, sum, name, depth):
        """A simple sum becomes just an enum."""
        self._EmitEnum(sum, name, depth)

    def VisitCompoundSum(self, sum, name, depth):
        """Emit the tag enum, the `<name>_t` base class, and each variant."""
        subs = {'name': name}

        self._EmitEnum(sum, name, depth)

        self.Emit("class %(name)s_t : public Obj {" % subs, depth)
        self.Emit(" public:", depth)
        # All sum types have a tag
        self.Emit("%(name)s_e tag() const {" % subs, depth + 1)
        self.Emit("return static_cast<%(name)s_e>(bytes_[0]);" % subs,
                  depth + 2)
        self.Emit("}", depth + 1)
        self.Emit("};", depth)
        self.Emit("", depth)

        # TODO: This should be replaced with a call to the generic
        # self.VisitChildren()
        super_name = "%s_t" % name
        for variant in sum.types:
            self.VisitConstructor(variant, super_name, depth)

        # rudimentary attribute handling
        # NOTE(review): these declarations are emitted after the base class
        # has been closed and after the variant classes -- confirm whether
        # they were meant to go inside the class body.
        for field in sum.attributes:
            attr_type = str(field.type)
            assert attr_type in asdl.builtin_types, attr_type
            self.Emit("%s %s;" % (attr_type, field.name), depth + 1)

    def VisitConstructor(self, cons, def_name, depth):
        """Emit a class deriving from `def_name` for a variant with fields.

        Variants without fields produce no class.
        """
        if not cons.fields:
            return

        self.Emit("class %s : public %s {" % (cons.name, def_name), depth)
        self.Emit(" public:", depth)
        offset = 1  # for the ID
        for field in cons.fields:
            self.VisitField(field, cons.name, offset, depth + 1)
            offset += self.ref_width
        self.Emit("};", depth)
        self.Emit("", depth)

    def VisitProduct(self, product, name, depth):
        """Emit the `<name>_t` class for a product type."""
        type_name = '%s_t' % name

        self.Emit("class %s : public Obj {" % type_name, depth)
        self.Emit(" public:", depth)
        offset = 0
        for field in product.fields:
            self.VisitField(field, type_name, offset, depth + 1)
            offset += self.ref_width

        for field in product.attributes:
            # rudimentary attribute handling
            attr_type = str(field.type)
            assert attr_type in asdl.builtin_types, attr_type
            self.Emit("%s %s;" % (attr_type, field.name), depth + 1)
        self.Emit("};", depth)
        self.Emit("", depth)

    def _AccessorKind(self, ctype):
        """Classify a C++ type name as 'int', 'enum', 'str' or 'obj'."""
        if ctype in ('bool', 'int'):
            return 'int'
        if ctype.endswith('_e') or ctype in self.enum_types:
            return 'enum'
        if ctype == 'char*':
            return 'str'
        return 'obj'

    def _DeferAccessorBody(self, signature, prelude, body, subs,
                           qualified_name):
        """Buffer an out-of-class accessor definition in self.footer.

        Args:
          signature: template for the function signature, without ' {'.
          prelude: optional first body line, used verbatim (no formatting).
          body: template for the return statement.
          subs: substitutions for the templates.
          qualified_name: 'Class::field', used in place of the bare name.
        """
        qualified = dict(subs, maybe_qual_name=qualified_name)
        # depth 0 for bodies
        self.footer.extend(
            visitor.FormatLines((signature + ' {') % qualified, 0))
        if prelude:
            self.footer.extend(visitor.FormatLines(prelude, 1))
        self.footer.extend(visitor.FormatLines(body % qualified, 1))
        self.footer.append('}\n\n')

    def VisitField(self, field, type_name, offset, depth):
        """Emit the accessor(s) for one field of class `type_name`.

        Even though they are inline, some of them can't be in the class {},
        because static_cast<> requires inheritance relationships to be already
        declared.  We have to print all the classes first, then all the bodies
        that might use static_cast<>.

        http://stackoverflow.com/questions/5808758/why-is-a-static-cast-from-a-pointer-to-base-to-a-pointer-to-derived-invalid

        Args:
          field: schema field; its `name`, `type`, `seq` and `opt` are read.
          type_name: C++ class that owns the accessor.
          offset: position of the field within the object, passed to the
            generated Int()/Ref()/Str()/Optional() calls.
          depth: indentation depth for the in-class lines.
        """
        ctype = self._GetCppType(field)
        name = field.name
        # Inside the class the accessor goes by its bare name; the deferred
        # body substitutes 'Class::name' instead.
        subs = {
            'ctype': ctype,
            'name': name,
            'pointer_type': self.pointer_type,
            'offset': offset,
            'maybe_qual_name': name,
        }
        kind = self._AccessorKind(ctype)

        if field.seq:  # Array/repeated
            # For size accessor, follow the ref, and then it's the first
            # integer.
            self.Emit(
                'inline int %(name)s_size(const %(pointer_type)s* base) const {'
                % subs, depth)
            self.Emit('return Ref(base, %(offset)d).Int(0);' % subs, depth + 1)
            self.Emit("}", depth)

            signature = (
                'inline const %(ctype)s %(maybe_qual_name)s('
                'const %(pointer_type)s* base, int index) const')
            prelude = self._ARRAY_OFFSET
            headers = dict.fromkeys(('int', 'enum', 'str'), signature + ' {')
            bodies = {
                'int': 'return Ref(base, %(offset)d).Int(a);',
                'enum': (
                    'return static_cast<const %(ctype)s>('
                    'Ref(base, %(offset)d).Int(a));'),
                'str': 'return Ref(base, %(offset)d).Str(base, a);',
                # This static_cast<> (downcast) causes problems if put within
                # "class {}".
                'obj': (
                    'return static_cast<const %(ctype)s>('
                    'Ref(base, %(offset)d).Ref(base, a));'),
            }

        else:  # not repeated
            simple = "inline %(ctype)s %(maybe_qual_name)s() const {"
            signature = (
                'inline const %(ctype)s %(maybe_qual_name)s('
                'const %(pointer_type)s* base) const')
            prelude = None
            headers = {
                'int': simple,
                'enum': simple,
                'str': signature + " {",
            }
            bodies = {
                'int': 'return Int(%(offset)d);',
                'enum': 'return static_cast<const %(ctype)s>(Int(%(offset)d));',
                'str': 'return Str(base, %(offset)d);',
                'obj': (
                    'return static_cast<const %(ctype)s>('
                    'Optional(base, %(offset)d));'
                    if field.opt else
                    'return static_cast<const %(ctype)s>('
                    'Ref(base, %(offset)d));'),
            }

        body = bodies[kind]

        if kind == 'obj':
            # Write function prototype now; write body later.
            self._DeferAccessorBody(signature, prelude, body, subs,
                                    '%s::%s' % (type_name, name))
            self.Emit((signature + ';') % subs, depth)
            return

        self.Emit(headers[kind] % subs, depth)
        if prelude:
            self.Emit(prelude, depth + 1)
        self.Emit(body % subs, depth + 1)
        self.Emit("}", depth)
|
| 289 |
|
| 290 |
|
# Used by osh/ast_gen.py
class CEnumVisitor(visitor.AsdlVisitor):
    """Emit simple sum types as C preprocessor constants."""

    def VisitSimpleSum(self, sum, name, depth):
        """Emit `#define <name>__<variant> <n>` per variant, n counted from 1."""
        # Just use #define, since enums aren't namespaced.
        for number, variant in enumerate(sum.types, 1):
            define = '#define %s__%s %d' % (name, variant.name, number)
            self.Emit(define, depth)
        self.Emit("", depth)
|
| 299 |
|
| 300 |
|
def main(argv):
    """Command-line entry point.

    Usage: <prog> cpp <schema_path>

    The 'cpp' action loads the ASDL schema at <schema_path> and writes the
    generated C++ to stdout.

    Args:
      argv: full argument vector; argv[1] is the action and, for 'cpp',
        argv[2] is the schema path.

    Raises:
      RuntimeError: if the action is missing or unknown, or if the schema path
        is missing.
    """
    if len(argv) < 2:
        raise RuntimeError('Action required')
    action = argv[1]

    # TODO: Also generate a switch/static_cast<> pretty printer in C++!  For
    # debugging.  Might need to detect cycles though.
    if action != 'cpp':
        raise RuntimeError('Invalid action %r' % action)

    # Report a missing path the same way as a missing action, so the caller's
    # RuntimeError handling covers it instead of a bare IndexError escaping.
    if len(argv) < 3:
        raise RuntimeError('Schema path required')
    schema_path = argv[2]

    app_types = {'id': asdl.UserType(Id)}
    with open(schema_path) as input_f:
        module, type_lookup = asdl.LoadSchema(input_f, app_types)

    # TODO: gen_cpp.py should be a library and the application should add Id?
    # Or we should enable ASDL metaprogramming, and let Id be a metaprogrammed
    # simple sum type.

    f = sys.stdout

    # How do mutation of strings, arrays, etc. work?  Are they like C++
    # containers, or their own?  I think they mirror the oil language
    # semantics.
    # Every node should have a mirror.  MutableObj.  MutableRef (pointer).
    # MutableArithVar -- has std::string.  The mirrors are heap allocated.
    # All the mutable ones should support Dump()/Encode()?
    # You can just write more at the end... don't need to disturb existing
    # nodes?  Rewrite pointers.

    alignment = 4
    enc = encode.Params(alignment)
    d = {'pointer_type': enc.pointer_type}

    # Declaration of the Obj base class; the generated classes derive from it.
    f.write("""\
#include <cstdint>

class Obj {
 public:
  // Decode a 3 byte integer from little endian
  inline int Int(int n) const;

  inline const Obj& Ref(const %(pointer_type)s* base, int n) const;

  inline const Obj* Optional(const %(pointer_type)s* base, int n) const;

  // NUL-terminated
  inline const char* Str(const %(pointer_type)s* base, int n) const;

 protected:
  uint8_t bytes_[1]; // first is ID; rest are a payload
};

""" % d)

    # Id should be treated as an enum.
    c = ChainOfVisitors(
        ForwardDeclareVisitor(f),
        ClassDefVisitor(f, enc, type_lookup, enum_types=['Id']))
    c.VisitModule(module)

    # Definitions of the Obj methods declared above.
    f.write("""\
inline int Obj::Int(int n) const {
  return bytes_[n] + (bytes_[n+1] << 8) + (bytes_[n+2] << 16);
}

inline const Obj& Obj::Ref(const %(pointer_type)s* base, int n) const {
  int offset = Int(n);
  return reinterpret_cast<const Obj&>(base[offset]);
}

inline const Obj* Obj::Optional(const %(pointer_type)s* base, int n) const {
  int offset = Int(n);
  if (offset) {
    return reinterpret_cast<const Obj*>(base + offset);
  } else {
    return nullptr;
  }
}

inline const char* Obj::Str(const %(pointer_type)s* base, int n) const {
  int offset = Int(n);
  return reinterpret_cast<const char*>(base + offset);
}
""" % d)
    # uint32_t* and char*/Obj* aren't related, so we need to use
    # reinterpret_cast<>.
    # http://stackoverflow.com/questions/10151834/why-cant-i-static-cast-between-char-and-unsigned-char
|
| 392 |
|
| 393 |
|
if __name__ == '__main__':
    try:
        main(sys.argv)
    except RuntimeError as err:
        # main() reports usage problems as RuntimeError: print the message to
        # stderr and exit with a failing status.
        sys.stderr.write('FATAL: %s\n' % err)
        sys.exit(1)
|