OILS / ysh / func_proc.py View on Github | oilshell.org

546 lines, 364 significant
1#!/usr/bin/env python2
2"""
3User-defined funcs and procs
4"""
5from __future__ import print_function
6
7from _devbuild.gen.id_kind_asdl import Id
8from _devbuild.gen.runtime_asdl import cmd_value
9from _devbuild.gen.syntax_asdl import (proc_sig, proc_sig_e, Param, ParamGroup,
10 NamedArg, Func, loc, ArgList, expr,
11 expr_e, expr_t)
12from _devbuild.gen.value_asdl import (value, value_e, value_t, ProcDefaults,
13 LeftName)
14
15from core import error
16from core import state
17from core import vm
18from frontend import lexer
19from frontend import typed_args
20from mycpp import mylib
21from mycpp.mylib import log, NewDict
22
23from typing import List, Tuple, Dict, Optional, cast, TYPE_CHECKING
24if TYPE_CHECKING:
25 from _devbuild.gen.syntax_asdl import command, loc_t
26 from osh import cmd_eval
27 from ysh import expr_eval
28
29_ = log
30
# TODO:
# - use _EvalExpr() more?
# - evaluate everything under a single `with state.ctx_YshExpr` -- probably
#   faster
# - although EvalExpr() can take param.blame_tok
35
36
def _DisallowMutableDefault(val, blame_loc):
    # type: (value_t, loc_t) -> None
    """Raise a TypeErr if a param's default value is a List or a Dict.

    Defaults are evaluated once, when the func/proc is defined, and the same
    value is then bound on every call that omits the arg.  A mutable default
    would therefore be shared between calls.
    """
    tag = val.tag()
    if tag == value_e.List or tag == value_e.Dict:
        raise error.TypeErr(val, "Default values can't be mutable", blame_loc)
41
42
43def _EvalPosDefaults(expr_ev, pos_params):
44 # type: (expr_eval.ExprEvaluator, List[Param]) -> List[value_t]
45 """Shared between func and proc: Eval defaults for positional params"""
46
47 no_val = None # type: value_t
48 pos_defaults = [no_val] * len(pos_params)
49 for i, p in enumerate(pos_params):
50 if p.default_val:
51 val = expr_ev.EvalExpr(p.default_val, p.blame_tok)
52 _DisallowMutableDefault(val, p.blame_tok)
53 pos_defaults[i] = val
54 return pos_defaults
55
56
def _EvalNamedDefaults(expr_ev, named_params):
    # type: (expr_eval.ExprEvaluator, List[Param]) -> Dict[str, value_t]
    """Evaluate the default values of named params.

    Shared between func and proc.  Only params that HAVE a default get an
    entry in the returned dict, keyed by param name; a missing key means
    "no default".
    """
    named_defaults = NewDict()  # type: Dict[str, value_t]
    # The original looped with enumerate() but never used the index.
    for p in named_params:
        if p.default_val:
            val = expr_ev.EvalExpr(p.default_val, p.blame_tok)
            _DisallowMutableDefault(val, p.blame_tok)
            named_defaults[p.name] = val
    return named_defaults
68
69
def EvalFuncDefaults(
        expr_ev,  # type: expr_eval.ExprEvaluator
        func,  # type: Func
):
    # type: (...) -> Tuple[List[value_t], Dict[str, value_t]]
    """Evaluate a func's default args at DEFINITION time, not call time.

    Returns (pos_defaults, named_defaults).  Each element is None when the
    func has no param group of that kind.
    """
    pos_defaults = None  # type: List[value_t]
    named_defaults = None  # type: Dict[str, value_t]

    if func.positional:
        pos_defaults = _EvalPosDefaults(expr_ev, func.positional.params)
    if func.named:
        named_defaults = _EvalNamedDefaults(expr_ev, func.named.params)

    return pos_defaults, named_defaults
88
89
def EvalProcDefaults(expr_ev, sig):
    # type: (expr_eval.ExprEvaluator, proc_sig.Closed) -> ProcDefaults
    """Evaluate a proc's default args at DEFINITION time, not call time.

    Word defaults must be Str.  A block param's default must be Command or
    Null, and is stored as the last entry of the positional defaults.
    """
    # Word params: a list parallel to sig.word.params, None = no default.
    word_defaults = None  # type: List[value_t]
    if sig.word:
        word_defaults = []
        for param in sig.word.params:
            word_default = None  # type: value_t
            if param.default_val:
                word_default = expr_ev.EvalExpr(param.default_val,
                                                param.blame_tok)
                if word_default.tag() != value_e.Str:
                    raise error.TypeErr(
                        word_default,
                        'Default val for word param must be Str',
                        param.blame_tok)
            word_defaults.append(word_default)

    # Stays None when there are no positional params, unless a block param
    # below needs somewhere to put its default.
    pos_defaults = None  # type: List[value_t]
    if sig.positional:
        pos_defaults = _EvalPosDefaults(expr_ev, sig.positional.params)

    block_param = sig.block_param
    if block_param:
        # The block param is bound like one more positional param, so you can
        # also pass it explicitly:
        #     cd /tmp (myblock)
        # Python None here means "no default", which differs from value.Null.
        block_default = None  # type: value_t
        if block_param.default_val:
            block_default = expr_ev.EvalExpr(block_param.default_val,
                                             block_param.blame_tok)
            # TODO: it can only be ^() or null
            if block_default.tag() not in (value_e.Null, value_e.Command):
                raise error.TypeErr(
                    block_default,
                    "Default value for block should be Command or Null",
                    block_param.blame_tok)

        if pos_defaults is None:
            pos_defaults = []
        pos_defaults.append(block_default)

    named_defaults = None  # type: Dict[str, value_t]
    if sig.named:
        named_defaults = _EvalNamedDefaults(expr_ev, sig.named.params)

    return ProcDefaults(word_defaults, pos_defaults, named_defaults)
140
141
142def _EvalPosArgs(expr_ev, exprs, pos_args):
143 # type: (expr_eval.ExprEvaluator, List[expr_t], List[value_t]) -> None
144 """Shared between func and proc: evaluate positional args."""
145
146 for e in exprs:
147 UP_e = e
148 if e.tag() == expr_e.Spread:
149 e = cast(expr.Spread, UP_e)
150 val = expr_ev._EvalExpr(e.child)
151 if val.tag() != value_e.List:
152 raise error.TypeErr(val, 'Spread expected a List', e.left)
153 pos_args.extend(cast(value.List, val).items)
154 else:
155 pos_args.append(expr_ev._EvalExpr(e))
156
157
def _EvalNamedArgs(expr_ev, named_exprs):
    # type: (expr_eval.ExprEvaluator, List[NamedArg]) -> Dict[str, value_t]
    """Evaluate named args into a dict keyed by arg name.

    Shared between func and proc.  A spread arg (...d) must evaluate to a
    Dict, whose entries are merged into the result.
    """
    named_args = NewDict()  # type: Dict[str, value_t]
    for named in named_exprs:
        arg_expr = named.value
        if arg_expr.tag() == expr_e.Spread:
            spread = cast(expr.Spread, arg_expr)
            spread_val = expr_ev._EvalExpr(spread.child)
            if spread_val.tag() != value_e.Dict:
                raise error.TypeErr(spread_val, 'Spread expected a Dict',
                                    spread.left)
            named_args.update(cast(value.Dict, spread_val).d)
        else:
            # Goes through EvalExpr() with the arg's name token as the
            # blame location.
            arg_val = expr_ev.EvalExpr(arg_expr, named.name)
            arg_name = lexer.TokenVal(named.name)
            named_args[arg_name] = arg_val

    return named_args
179
180
def _EvalArgList(
        expr_ev,  # type: expr_eval.ExprEvaluator
        args,  # type: ArgList
        me=None  # type: Optional[value_t]
):
    # type: (...) -> Tuple[List[value_t], Optional[Dict[str, value_t]]]
    """Evaluate the arg list of a func call.

    Returns (pos_args, named_args).  named_args is None when the call has no
    named-arg section.  If 'me' (the self/this receiver) is given, it becomes
    the first positional arg.

    This is logically a PRIVATE METHOD of ExprEvaluator.  It lives in this
    file to sit next to EvalTypedArgsToProc, which is similar.

    It must only be called under the EvalExpr() wrapper:

        with state.ctx_YshExpr(...)  # required to call this
            ...
    """
    pos_args = []  # type: List[value_t]
    if me:
        pos_args.append(me)
    _EvalPosArgs(expr_ev, args.pos_args, pos_args)

    if args.named_args is None:
        return pos_args, None
    return pos_args, _EvalNamedArgs(expr_ev, args.named_args)
209
210
def EvalTypedArgsToProc(
        expr_ev,  # type: expr_eval.ExprEvaluator
        mutable_opts,  # type: state.MutableOpts
        node,  # type: command.Simple
        cmd_val,  # type: cmd_value.Argv
):
    # type: (...) -> None
    """Evaluate typed, named, and block args of a proc call into cmd_val.

    Sets cmd_val.typed_args and cmd_val.pos_args.  Sets cmd_val.named_args
    only when the call has a named-arg section.

        p [x]        lazy: each arg is wrapped, unevaluated, in value.Expr
        p (x)        eager: args are evaluated under state.ctx_YshExpr
        p { ... }    the block is appended, unevaluated, as a value.Block
    """
    cmd_val.typed_args = node.typed_args

    # We only got here if the call looks like
    #     p (x)
    #     p { echo hi }
    #     p () { echo hi }
    # So allocate this unconditionally.
    cmd_val.pos_args = []

    ty = node.typed_args
    if ty:
        if ty.left.id == Id.Op_LBracket:  # assert [42 === x]
            # Defer evaluation by wrapping in value.Expr
            for exp in ty.pos_args:
                cmd_val.pos_args.append(value.Expr(exp))
            # TODO: ...spread is illegal

            n1 = ty.named_args
            if n1 is not None:
                cmd_val.named_args = NewDict()
                for named_arg in n1:
                    name = lexer.TokenVal(named_arg.name)
                    cmd_val.named_args[name] = value.Expr(named_arg.value)
                # TODO: ...spread is illegal

        else:  # json write (x)
            with state.ctx_YshExpr(mutable_opts):  # What EvalExpr() does
                _EvalPosArgs(expr_ev, ty.pos_args, cmd_val.pos_args)

                if ty.named_args is not None:
                    cmd_val.named_args = _EvalNamedArgs(
                        expr_ev, ty.named_args)

    # Pass the unevaluated block, as the last positional arg.
    if node.block:
        cmd_val.pos_args.append(value.Block(node.block))

        # Important invariant: cmd_val looks the same for
        #     eval (^(echo hi))
        #     eval { echo hi }
        if not cmd_val.typed_args:
            cmd_val.typed_args = ArgList.CreateNull()

            # Give the synthesized ArgList the brace locations, for error
            # messages like: ls { echo invalid }
            #
            # Only do this for the ArgList we just created.  When there ARE
            # typed args, cmd_val.typed_args is the very same object as
            # node.typed_args.  Assigning to it would overwrite the ( )
            # locations that BindProcArgs() blames, and would mutate the
            # syntax tree node.
            cmd_val.typed_args.left = node.block.brace_group.left
            cmd_val.typed_args.right = node.block.brace_group.right
265
266
267def _BindWords(
268 proc_name, # type: str
269 group, # type: ParamGroup
270 defaults, # type: List[value_t]
271 cmd_val, # type: cmd_value.Argv
272 mem, # type: state.Mem
273 blame_loc, # type: loc_t
274):
275 # type: (...) -> None
276
277 argv = cmd_val.argv[1:]
278 num_args = len(argv)
279 for i, p in enumerate(group.params):
280 if i < num_args:
281 val = value.Str(argv[i]) # type: value_t
282 else: # default args were evaluated on definition
283 val = defaults[i]
284 if val is None:
285 raise error.Expr(
286 "proc %r wasn't passed word param %r" %
287 (proc_name, p.name), blame_loc)
288
289 mem.SetLocalName(LeftName(p.name, p.blame_tok), val)
290
291 # ...rest
292
293 num_params = len(group.params)
294 rest = group.rest_of
295 if rest:
296 lval = LeftName(rest.name, rest.blame_tok)
297
298 items = [value.Str(s)
299 for s in argv[num_params:]] # type: List[value_t]
300 rest_val = value.List(items)
301 mem.SetLocalName(lval, rest_val)
302 else:
303 if num_args > num_params:
304 if len(cmd_val.arg_locs):
305 # point to the first extra one
306 extra_loc = cmd_val.arg_locs[num_params + 1] # type: loc_t
307 else:
308 extra_loc = loc.Missing
309
310 # Too many arguments.
311 raise error.Expr(
312 "proc %r takes %d words, but got %d" %
313 (proc_name, num_params, num_args), extra_loc)
314
315
316def _BindTyped(
317 code_name, # type: str
318 group, # type: Optional[ParamGroup]
319 block_param, # type: Optional[Param]
320 defaults, # type: List[value_t]
321 pos_args, # type: Optional[List[value_t]]
322 mem, # type: state.Mem
323 blame_loc, # type: loc_t
324):
325 # type: (...) -> None
326
327 if pos_args is None:
328 pos_args = []
329
330 num_args = len(pos_args)
331 num_params = 0
332
333 i = 0
334
335 if group:
336 for p in group.params:
337 if i < num_args:
338 val = pos_args[i]
339 else:
340 val = defaults[i]
341 if val is None:
342 raise error.Expr(
343 "%r wasn't passed typed param %r" %
344 (code_name, p.name), blame_loc)
345
346 mem.SetLocalName(LeftName(p.name, p.blame_tok), val)
347 i += 1
348 num_params += len(group.params)
349
350 # Special case: treat block param like the next positional arg
351 if block_param:
352 if i < num_args:
353 val = pos_args[i]
354 else:
355 val = defaults[i]
356 if val is None:
357 raise error.Expr(
358 "%r wasn't passed block param %r" %
359 (code_name, block_param.name), blame_loc)
360
361 mem.SetLocalName(LeftName(block_param.name, block_param.blame_tok),
362 val)
363 num_params += 1
364
365 # ...rest
366
367 if group:
368 rest = group.rest_of
369 if rest:
370 lval = LeftName(rest.name, rest.blame_tok)
371
372 rest_val = value.List(pos_args[num_params:])
373 mem.SetLocalName(lval, rest_val)
374 else:
375 if num_args > num_params:
376 # Too many arguments.
377 raise error.Expr(
378 "%r takes %d typed args, but got %d" %
379 (code_name, num_params, num_args), blame_loc)
380
381
382def _BindNamed(
383 code_name, # type: str
384 group, # type: ParamGroup
385 defaults, # type: Dict[str, value_t]
386 named_args, # type: Optional[Dict[str, value_t]]
387 mem, # type: state.Mem
388 blame_loc, # type: loc_t
389):
390 # type: (...) -> None
391
392 if named_args is None:
393 named_args = NewDict()
394
395 for p in group.params:
396 val = named_args.get(p.name)
397 if val is None:
398 val = defaults.get(p.name)
399 if val is None:
400 raise error.Expr(
401 "%r wasn't passed named param %r" % (code_name, p.name),
402 blame_loc)
403
404 mem.SetLocalName(LeftName(p.name, p.blame_tok), val)
405 # Remove bound args
406 mylib.dict_erase(named_args, p.name)
407
408 # ...rest
409 rest = group.rest_of
410 if rest:
411 lval = LeftName(rest.name, rest.blame_tok)
412 mem.SetLocalName(lval, value.Dict(named_args))
413 else:
414 num_args = len(named_args)
415 num_params = len(group.params)
416 if num_args > num_params:
417 # Too many arguments.
418 raise error.Expr(
419 "%r takes %d named args, but got %d" %
420 (code_name, num_params, num_args), blame_loc)
421
422
423def _BindFuncArgs(func, rd, mem):
424 # type: (value.Func, typed_args.Reader, state.Mem) -> None
425
426 node = func.parsed
427 blame_loc = rd.LeftParenToken()
428
429 ### Handle positional args
430
431 if node.positional:
432 _BindTyped(func.name, node.positional, None, func.pos_defaults,
433 rd.pos_args, mem, blame_loc)
434 else:
435 if rd.pos_args is not None:
436 num_pos = len(rd.pos_args)
437 if num_pos != 0:
438 raise error.Expr(
439 "Func %r takes no positional args, but got %d" %
440 (func.name, num_pos), blame_loc)
441
442 semi = rd.arg_list.semi_tok
443 if semi is not None:
444 blame_loc = semi
445
446 ### Handle named args
447
448 if node.named:
449 _BindNamed(func.name, node.named, func.named_defaults, rd.named_args,
450 mem, blame_loc)
451 else:
452 if rd.named_args is not None:
453 num_named = len(rd.named_args)
454 if num_named != 0:
455 raise error.Expr(
456 "Func %r takes no named args, but got %d" %
457 (func.name, num_named), blame_loc)
458
459
def BindProcArgs(proc, cmd_val, mem):
    # type: (value.Proc, cmd_value.Argv, state.Mem) -> None
    """Bind word, typed positional (incl. block), and named args to params.

    Procs whose signature is not proc_sig.Closed are left untouched.
    Passing args of a kind the proc has no params for is an error.
    """
    UP_sig = proc.sig
    if UP_sig.tag() != proc_sig_e.Closed:  # proc is-closed ()
        return
    sig = cast(proc_sig.Closed, UP_sig)

    # _BindX() is only called when the matching param group exists.  Most
    # procs don't have all 3 kinds of params, so this saves allocations.

    # Word-arg errors are blamed on the first word, when it has a location.
    if len(cmd_val.arg_locs) > 0:
        blame_loc = cmd_val.arg_locs[0]  # type: loc_t
    else:
        blame_loc = loc.Missing

    if sig.word:
        _BindWords(proc.name, sig.word, proc.defaults.for_word, cmd_val, mem,
                   blame_loc)
    elif len(cmd_val.argv) != 1:
        raise error.Expr(
            "Proc %r takes no word args, but got %d" %
            (proc.name, len(cmd_val.argv) - 1), blame_loc)

    # Typed positional args, including the block arg if any, are blamed on
    # the opening paren.
    typed = cmd_val.typed_args
    if typed:
        blame_loc = typed.left

    if sig.positional or sig.block_param:
        _BindTyped(proc.name, sig.positional, sig.block_param,
                   proc.defaults.for_typed, cmd_val.pos_args, mem, blame_loc)
    elif cmd_val.pos_args is not None and len(cmd_val.pos_args) != 0:
        raise error.Expr(
            "Proc %r takes no typed args, but got %d" %
            (proc.name, len(cmd_val.pos_args)), blame_loc)

    # Named args are blamed on the semicolon, when there is one.
    if typed and typed.semi_tok is not None:
        blame_loc = typed.semi_tok

    if sig.named:
        _BindNamed(proc.name, sig.named, proc.defaults.for_named,
                   cmd_val.named_args, mem, blame_loc)
    elif cmd_val.named_args is not None and len(cmd_val.named_args) != 0:
        raise error.Expr(
            "Proc %r takes no named args, but got %d" %
            (proc.name, len(cmd_val.named_args)), blame_loc)
523
524
def CallUserFunc(
        func,  # type: value.Func
        rd,  # type: typed_args.Reader
        mem,  # type: state.Mem
        cmd_ev,  # type: cmd_eval.CommandEvaluator
):
    # type: (...) -> value_t
    """Run a user-defined func in a new stack frame and return its result.

    The result is the value carried by a vm.ValueControlFlow raised from the
    body, or value.Null if the body finishes without one.
    """
    with state.ctx_FuncCall(mem, func):  # push a new stack frame
        _BindFuncArgs(func, rd, mem)

        try:
            cmd_ev._Execute(func.parsed.body)
        except vm.ValueControlFlow as e:
            return e.value
        except vm.IntControlFlow:
            raise AssertionError('IntControlFlow in func')

        return value.Null  # implicit return

    raise AssertionError('unreachable')