OILS / ysh / func_proc.py View on Github | oilshell.org

576 lines, 382 significant
1#!/usr/bin/env python2
2"""
3User-defined funcs and procs
4"""
5from __future__ import print_function
6
7from _devbuild.gen.id_kind_asdl import Id
8from _devbuild.gen.runtime_asdl import cmd_value, ProcArgs
9from _devbuild.gen.syntax_asdl import (proc_sig, proc_sig_e, Param, ParamGroup,
10 NamedArg, Func, loc, ArgList, expr,
11 expr_e, expr_t)
12from _devbuild.gen.value_asdl import (value, value_e, value_t, ProcDefaults,
13 LeftName)
14
15from core import error
16from core.error import e_die
17from core import state
18from core import vm
19from frontend import lexer
20from frontend import typed_args
21from mycpp import mylib
22from mycpp.mylib import log, NewDict
23
24from typing import List, Tuple, Dict, Optional, cast, TYPE_CHECKING
25if TYPE_CHECKING:
26 from _devbuild.gen.syntax_asdl import command, loc_t
27 from osh import cmd_eval
28 from ysh import expr_eval
29
# Reference `log` so the import is used; presumably kept around for ad-hoc
# debugging and to keep linters from flagging it as unused.
_ = log
31
# TODO:
# - use _EvalExpr() more?
# - wrap evaluation in a single `with state.ctx_YshExpr` -- probably faster
# - although EvalExpr() can take param.blame_tok
36
37
def _DisallowMutableDefault(val, blame_loc):
    # type: (value_t, loc_t) -> None
    """Raise a TypeErr if a default value is a List or a Dict.

    Defaults are evaluated once, at definition time, so a mutable default
    would presumably be shared by every call that relies on it.
    """
    tag = val.tag()
    if tag == value_e.List or tag == value_e.Dict:
        raise error.TypeErr(val, "Default values can't be mutable", blame_loc)
42
43
def _EvalPosDefaults(expr_ev, pos_params):
    # type: (expr_eval.ExprEvaluator, List[Param]) -> List[value_t]
    """Evaluate the defaults of positional params; shared by funcs and procs.

    Returns:
      A list parallel to pos_params.  The entry for a param without a
      default is None.

    Raises:
      error.TypeErr: if a default evaluates to a List or Dict.
    """
    pos_defaults = []  # type: List[value_t]
    for param in pos_params:
        dv = None  # type: value_t
        if param.default_val:
            dv = expr_ev.EvalExpr(param.default_val, param.blame_tok)
            _DisallowMutableDefault(dv, param.blame_tok)
        pos_defaults.append(dv)
    return pos_defaults
56
57
def _EvalNamedDefaults(expr_ev, named_params):
    # type: (expr_eval.ExprEvaluator, List[Param]) -> Dict[str, value_t]
    """Evaluate the defaults of named params; shared by funcs and procs.

    Returns:
      A dict from param name to default value.  Params without a default
      have no entry.

    Raises:
      error.TypeErr: if a default evaluates to a List or Dict.
    """
    named_defaults = NewDict()  # type: Dict[str, value_t]
    for p in named_params:
        if p.default_val:
            val = expr_ev.EvalExpr(p.default_val, p.blame_tok)
            _DisallowMutableDefault(val, p.blame_tok)
            named_defaults[p.name] = val
    return named_defaults
69
70
def EvalFuncDefaults(
        expr_ev,  # type: expr_eval.ExprEvaluator
        func,  # type: Func
):
    # type: (...) -> Tuple[List[value_t], Dict[str, value_t]]
    """Evaluate a func's default args once, when it is DEFINED (not called).

    Returns:
      (pos_defaults, named_defaults).  Either is None when the func has no
      param group of that kind.
    """
    pos_defaults = None  # type: List[value_t]
    named_defaults = None  # type: Dict[str, value_t]

    if func.positional:
        pos_defaults = _EvalPosDefaults(expr_ev, func.positional.params)
    if func.named:
        named_defaults = _EvalNamedDefaults(expr_ev, func.named.params)

    return pos_defaults, named_defaults
89
90
def _EvalWordDefaults(expr_ev, word_params):
    # type: (expr_eval.ExprEvaluator, List[Param]) -> List[value_t]
    """Evaluate the defaults of a proc's word params, which must be Str.

    Returns a list parallel to word_params; None means "no default".
    """
    word_defaults = []  # type: List[value_t]
    for param in word_params:
        dv = None  # type: value_t
        if param.default_val:
            dv = expr_ev.EvalExpr(param.default_val, param.blame_tok)
            if dv.tag() != value_e.Str:
                raise error.TypeErr(
                    dv, 'Default val for word param must be Str',
                    param.blame_tok)
        word_defaults.append(dv)
    return word_defaults


def EvalProcDefaults(expr_ev, sig):
    # type: (expr_eval.ExprEvaluator, proc_sig.Closed) -> ProcDefaults
    """Evaluate a proc's default args once, when it is DEFINED (not called).

    Word, positional, named, and block defaults are evaluated in that order.
    Each field of the result is None when the signature lacks that kind of
    param.

    Raises:
      error.TypeErr: for a non-Str word default, a mutable positional/named
        default, or a block default that is neither Command nor Null.
    """
    word_defaults = None  # type: List[value_t]
    if sig.word:
        word_defaults = _EvalWordDefaults(expr_ev, sig.word.params)

    pos_defaults = None  # type: List[value_t]
    if sig.positional:
        pos_defaults = _EvalPosDefaults(expr_ev, sig.positional.params)

    named_defaults = None  # type: Dict[str, value_t]
    if sig.named:
        named_defaults = _EvalNamedDefaults(expr_ev, sig.named.params)

    # cd /tmp (; ; myblock)
    # None means there's no default, which is different than value.Null
    block_default = None  # type: value_t
    block_param = sig.block_param
    if block_param and block_param.default_val:
        block_default = expr_ev.EvalExpr(block_param.default_val,
                                         block_param.blame_tok)
        # It can only be ^() or null
        if block_default.tag() not in (value_e.Null, value_e.Command):
            raise error.TypeErr(
                block_default,
                "Default value for block should be Command or Null",
                block_param.blame_tok)

    return ProcDefaults(word_defaults, pos_defaults, named_defaults,
                        block_default)
139
140
def _EvalPosArgs(expr_ev, exprs, pos_args):
    # type: (expr_eval.ExprEvaluator, List[expr_t], List[value_t]) -> None
    """Evaluate positional arg exprs, appending their values to pos_args.

    A spread arg ...x must evaluate to a List, whose items are appended one
    by one.  Shared between funcs and procs.  Uses the private _EvalExpr(),
    so callers wrap it in state.ctx_YshExpr.

    Raises:
      error.TypeErr: if a spread arg isn't a List.
    """
    for arg_expr in exprs:
        if arg_expr.tag() != expr_e.Spread:
            pos_args.append(expr_ev._EvalExpr(arg_expr))
            continue

        spread = cast(expr.Spread, arg_expr)
        spread_val = expr_ev._EvalExpr(spread.child)
        if spread_val.tag() != value_e.List:
            raise error.TypeErr(spread_val, 'Spread expected a List',
                                spread.left)
        pos_args.extend(cast(value.List, spread_val).items)
155
156
def _EvalNamedArgs(expr_ev, named_exprs):
    # type: (expr_eval.ExprEvaluator, List[NamedArg]) -> Dict[str, value_t]
    """Evaluate named arg exprs into a new dict from name to value.

    A spread arg ...d must evaluate to a Dict, whose entries are merged in.
    Later args overwrite earlier entries with the same name.  Shared between
    funcs and procs.

    Raises:
      error.TypeErr: if a spread arg isn't a Dict.
    """
    named_args = NewDict()  # type: Dict[str, value_t]
    for arg in named_exprs:
        if arg.value.tag() == expr_e.Spread:
            spread = cast(expr.Spread, arg.value)
            spread_val = expr_ev._EvalExpr(spread.child)
            if spread_val.tag() != value_e.Dict:
                raise error.TypeErr(spread_val, 'Spread expected a Dict',
                                    spread.left)
            named_args.update(cast(value.Dict, spread_val).d)
        else:
            # Evaluate the value first, then look up the name, as before.
            val = expr_ev.EvalExpr(arg.value, arg.name)
            key = lexer.TokenVal(arg.name)
            named_args[key] = val

    return named_args
178
179
def _EvalArgList(
        expr_ev,  # type: expr_eval.ExprEvaluator
        args,  # type: ArgList
        me=None  # type: Optional[value_t]
):
    # type: (...) -> Tuple[List[value_t], Optional[Dict[str, value_t]]]
    """Evaluate the arg list of a func call.

    Conceptually a PRIVATE METHOD of ExprEvaluator.  It lives in THIS FILE so
    that it sits next to EvalTypedArgsToProc(), which does the same job for
    procs.

    It's only valid to call this inside the EvalExpr() wrapper:

        with state.ctx_YshExpr(...)  # required to call this
            ...

    Args:
      me: if given, it becomes the first positional arg (self/this).

    Returns:
      (pos_args, named_args).  named_args is None when the call has no
      named-arg section.
    """
    named_args = None  # type: Dict[str, value_t]

    pos_args = []  # type: List[value_t]
    if me:  # the self/this argument comes first
        pos_args.append(me)
    _EvalPosArgs(expr_ev, args.pos_args, pos_args)

    if args.named_args is not None:
        named_args = _EvalNamedArgs(expr_ev, args.named_args)

    return pos_args, named_args
208
209
def EvalTypedArgsToProc(
        expr_ev,  # type: expr_eval.ExprEvaluator
        mutable_opts,  # type: state.MutableOpts
        node,  # type: command.Simple
        proc_args,  # type: ProcArgs
):
    # type: (...) -> None
    """Evaluate the typed (positional and named) and block args of a proc call.

    Word args are not handled here.  Results are stored on proc_args:

      p (x; n=1)     pos_args / named_args hold the evaluated values
      p [x; n=1]     they hold unevaluated value.Expr wrappers (lazy args)
      p (; ; blk)    block_arg is the value of the block expression
      p { echo hi }  block_arg is a value.Block wrapping the literal

    Fails (via e_die) if the call has both a block expression and a block
    literal.
    """
    proc_args.typed_args = node.typed_args

    # We only got here if the call looks like
    #   p (x)
    #   p { echo hi }
    #   p () { echo hi }
    # So allocate this unconditionally
    proc_args.pos_args = []

    ty = node.typed_args
    if ty:
        if ty.left.id == Id.Op_LBracket:  # assert [42 === x]
            # Defer evaluation by wrapping in value.Expr
            for exp in ty.pos_args:
                proc_args.pos_args.append(value.Expr(exp))
            # TODO: ...spread is illegal

            n1 = ty.named_args
            if n1 is not None:
                proc_args.named_args = NewDict()
                for named_arg in n1:
                    name = lexer.TokenVal(named_arg.name)
                    proc_args.named_args[name] = value.Expr(named_arg.value)
                # TODO: ...spread is illegal

        else:  # json write (x)
            with state.ctx_YshExpr(mutable_opts):  # What EvalExpr() does
                _EvalPosArgs(expr_ev, ty.pos_args, proc_args.pos_args)

                if ty.named_args is not None:
                    proc_args.named_args = _EvalNamedArgs(
                        expr_ev, ty.named_args)

        if ty.block_expr and node.block:
            e_die("Can't accept both block expression and block literal",
                  node.block.brace_group.left)

        # p ( ; ; block) is an expression to be evaluated
        if ty.block_expr:
            # fallback location is (
            proc_args.block_arg = expr_ev.EvalExpr(ty.block_expr, ty.left)

    # p { echo hi } is an unevaluated block
    if node.block:
        # TODO: consolidate value.Block (holds LiteralBlock) and value.Command
        proc_args.block_arg = value.Block(node.block)

        # Add location info so the cmd_val looks the same for both:
        #   cd /tmp (; ; ^(echo hi))
        #   cd /tmp { echo hi }
        if not proc_args.typed_args:
            proc_args.typed_args = ArgList.CreateNull()

            # Also add locations for error message: ls { echo invalid }
            #
            # Only do this for the ArgList allocated just above.  Otherwise
            # proc_args.typed_args IS node.typed_args (assigned at the top of
            # this function), and overwriting its tokens would mutate the
            # syntax tree -- e.g. ty.left.id == Id.Op_LBracket would change
            # the next time this node is evaluated.
            proc_args.typed_args.left = node.block.brace_group.left
            proc_args.typed_args.right = node.block.brace_group.right
275
276
def _BindWords(
        proc_name,  # type: str
        group,  # type: ParamGroup
        defaults,  # type: List[value_t]
        cmd_val,  # type: cmd_value.Argv
        mem,  # type: state.Mem
        blame_loc,  # type: loc_t
):
    # type: (...) -> None
    """Bind the word args in cmd_val.argv[1:] to a proc's word params.

    Missing trailing words fall back to defaults evaluated at definition
    time.  Extra words go to the ...rest param as a List of Str.  If there's
    no rest param, extra words are an error.

    Raises:
      error.Expr: if a param without a default wasn't passed, or on too many
        words.
    """
    argv = cmd_val.argv[1:]  # argv[0] is not a word arg
    num_args = len(argv)
    for i, p in enumerate(group.params):
        if i < num_args:
            val = value.Str(argv[i])  # type: value_t
        else:  # default args were evaluated on definition
            val = defaults[i]
            if val is None:
                raise error.Expr(
                    "proc %r wasn't passed word param %r" %
                    (proc_name, p.name), blame_loc)

        mem.SetLocalName(LeftName(p.name, p.blame_tok), val)

    # ...rest

    num_params = len(group.params)
    rest = group.rest_of
    if rest:
        lval = LeftName(rest.name, rest.blame_tok)

        items = [value.Str(s)
                 for s in argv[num_params:]]  # type: List[value_t]
        rest_val = value.List(items)
        mem.SetLocalName(lval, rest_val)
    else:
        if num_args > num_params:
            # Point to the first extra word.  argv[num_params] is
            # cmd_val.argv[num_params + 1]; this assumes arg_locs is indexed
            # like cmd_val.argv.  Bounds-check the actual index rather than
            # mere non-emptiness, so a short arg_locs can't raise IndexError.
            first_extra = num_params + 1
            if first_extra < len(cmd_val.arg_locs):
                extra_loc = cmd_val.arg_locs[first_extra]  # type: loc_t
            else:
                extra_loc = loc.Missing

            # Too many arguments.
            raise error.Expr(
                "proc %r takes %d words, but got %d" %
                (proc_name, num_params, num_args), extra_loc)
324
325
def _BindTyped(
        code_name,  # type: str
        group,  # type: Optional[ParamGroup]
        defaults,  # type: List[value_t]
        pos_args,  # type: Optional[List[value_t]]
        mem,  # type: state.Mem
        blame_loc,  # type: loc_t
):
    # type: (...) -> None
    """Bind typed positional args to the positional params of a func or proc.

    Missing trailing args fall back to defaults evaluated at definition time.
    Extra args go to the ...rest param as a List.  If there's no rest param,
    or no param group at all, extra args are an error.

    Raises:
      error.Expr: if a param without a default wasn't passed, or on too many
        args.
    """
    if pos_args is None:
        pos_args = []

    num_args = len(pos_args)

    if not group:
        # No positional params: every arg is one too many.  Previously this
        # case skipped the check below and silently dropped the args.
        if num_args > 0:
            raise error.Expr(
                "%r takes %d typed args, but got %d" %
                (code_name, 0, num_args), blame_loc)
        return

    for i, p in enumerate(group.params):
        if i < num_args:
            val = pos_args[i]
        else:  # default args were evaluated on definition
            val = defaults[i]
            if val is None:
                raise error.Expr(
                    "%r wasn't passed typed param %r" %
                    (code_name, p.name), blame_loc)

        mem.SetLocalName(LeftName(p.name, p.blame_tok), val)

    # ...rest

    num_params = len(group.params)
    rest = group.rest_of
    if rest:
        lval = LeftName(rest.name, rest.blame_tok)

        rest_val = value.List(pos_args[num_params:])
        mem.SetLocalName(lval, rest_val)
    else:
        if num_args > num_params:
            # Too many arguments.
            raise error.Expr(
                "%r takes %d typed args, but got %d" %
                (code_name, num_params, num_args), blame_loc)
374
375
def _BindNamed(
        code_name,  # type: str
        group,  # type: ParamGroup
        defaults,  # type: Dict[str, value_t]
        named_args,  # type: Optional[Dict[str, value_t]]
        mem,  # type: state.Mem
        blame_loc,  # type: loc_t
):
    # type: (...) -> None
    """Bind named args to the named params of a func or proc.

    Each param gets the arg of the same name, or else its default (evaluated
    at definition time).  Bound entries are erased from named_args (the dict
    is mutated in place), so whatever remains was not declared.  Leftovers
    are bound to the ...rest param as a Dict (the same dict, not a copy), or
    are an error if there's no rest param.

    Raises:
      error.Expr: if a param without a default wasn't passed, or if an
        undeclared named arg was passed and there's no rest param.
    """
    if named_args is None:
        named_args = NewDict()

    for p in group.params:
        val = named_args.get(p.name)
        if val is None:
            val = defaults.get(p.name)
            if val is None:
                raise error.Expr(
                    "%r wasn't passed named param %r" % (code_name, p.name),
                    blame_loc)

        mem.SetLocalName(LeftName(p.name, p.blame_tok), val)
        # Remove bound args
        mylib.dict_erase(named_args, p.name)

    # ...rest
    rest = group.rest_of
    if rest:
        lval = LeftName(rest.name, rest.blame_tok)
        mem.SetLocalName(lval, value.Dict(named_args))
    else:
        # Every declared param's entry was erased above, so ANY leftover
        # entry is an unknown named arg.  (Comparing the leftover count with
        # len(group.params) let such args through silently.)
        if len(named_args) > 0:
            raise error.Expr(
                "%r got unexpected named args: %s" %
                (code_name, ', '.join(named_args.keys())), blame_loc)
415
416
def _BindFuncArgs(func, rd, mem):
    # type: (value.Func, typed_args.Reader, state.Mem) -> None
    """Bind the positional and named args in rd to the params of func.

    Positional-arg errors blame the call site's left paren.  Named-arg errors
    blame its ; when there is one.

    Raises:
      error.Expr: on missing, extra, or unexpected args.
    """
    node = func.parsed
    blame_loc = rd.LeftParenToken()

    ### Handle positional args

    pos_args = rd.pos_args
    if node.positional:
        _BindTyped(func.name, node.positional, func.pos_defaults, pos_args,
                   mem, blame_loc)
    elif pos_args is not None and len(pos_args) != 0:
        raise error.Expr(
            "Func %r takes no positional args, but got %d" %
            (func.name, len(pos_args)), blame_loc)

    semi = rd.arg_list.semi_tok
    if semi is not None:
        blame_loc = semi

    ### Handle named args

    named_args = rd.named_args
    if node.named:
        _BindNamed(func.name, node.named, func.named_defaults, named_args,
                   mem, blame_loc)
    elif named_args is not None and len(named_args) != 0:
        raise error.Expr(
            "Func %r takes no named args, but got %d" %
            (func.name, len(named_args)), blame_loc)
452
453
def BindProcArgs(proc, cmd_val, mem):
    # type: (value.Proc, cmd_value.Argv, state.Mem) -> None
    """Bind the word, typed, named, and block args of cmd_val to proc's params.

    Does nothing unless the proc's signature is proc_sig.Closed.  Each kind of
    arg is bound in turn.  The blamed location advances through the call site:
    first word, then the (, then the first ;, then the second ;.

    Raises:
      error.Expr: on missing, extra, or unexpected args of any kind.
    """
    UP_sig = proc.sig
    if UP_sig.tag() != proc_sig_e.Closed:  # proc is-closed ()
        return
    sig = cast(proc_sig.Closed, UP_sig)

    proc_args = cmd_val.proc_args
    # The call site's typed arg list, if any.  Its tokens are used for
    # blame locations below.
    arg_list = proc_args.typed_args if proc_args else None

    # Note: we don't call _BindX() when there is no corresponding param group.
    # This saves a few allocations, because most procs won't have all 3 types
    # of args.

    ### Word args: blame the first word of the command

    blame_loc = loc.Missing  # type: loc_t
    if len(cmd_val.arg_locs) > 0:
        blame_loc = cmd_val.arg_locs[0]

    if sig.word:
        _BindWords(proc.name, sig.word, proc.defaults.for_word, cmd_val, mem,
                   blame_loc)
    elif len(cmd_val.argv) != 1:
        raise error.Expr(
            "Proc %r takes no word args, but got %d" %
            (proc.name, len(cmd_val.argv) - 1), blame_loc)

    ### Typed positional args: blame the ( of the call site

    if arg_list:
        blame_loc = arg_list.left

    pos_args = proc_args.pos_args if proc_args else None
    if sig.positional:
        _BindTyped(proc.name, sig.positional, proc.defaults.for_typed,
                   pos_args, mem, blame_loc)
    elif pos_args is not None and len(pos_args) != 0:
        raise error.Expr(
            "Proc %r takes no typed args, but got %d" %
            (proc.name, len(pos_args)), blame_loc)

    ### Typed named args: blame the first ; of the call site, if present

    if arg_list and arg_list.semi_tok is not None:
        blame_loc = arg_list.semi_tok

    named_args = proc_args.named_args if proc_args else None
    if sig.named:
        _BindNamed(proc.name, sig.named, proc.defaults.for_named, named_args,
                   mem, blame_loc)
    elif named_args is not None and len(named_args) != 0:
        raise error.Expr(
            "Proc %r takes no named args, but got %d" %
            (proc.name, len(named_args)), blame_loc)

    ### Block arg: blame the second ; if present, because value_t doesn't
    ### generally have location info, as opposed to expr_t.

    if arg_list and arg_list.semi_tok2 is not None:
        blame_loc = arg_list.semi_tok2

    block_arg = proc_args.block_arg if proc_args else None
    block_param = sig.block_param

    if not block_param:
        if block_arg is not None:
            raise error.Expr(
                "Proc %r doesn't accept a block argument" % proc.name,
                blame_loc)
        return

    if block_arg is None:
        block_arg = proc.defaults.for_block  # evaluated on definition
    if block_arg is None:
        raise error.Expr(
            "%r wasn't passed block param %r" %
            (proc.name, block_param.name), blame_loc)

    mem.SetLocalName(LeftName(block_param.name, block_param.blame_tok),
                     block_arg)
550
551
def CallUserFunc(
        func,  # type: value.Func
        rd,  # type: typed_args.Reader
        mem,  # type: state.Mem
        cmd_ev,  # type: cmd_eval.CommandEvaluator
):
    # type: (...) -> value_t
    """Call a user-defined func in a new stack frame and return its result.

    Args are bound inside the new frame.  A vm.ValueControlFlow raised by the
    body (presumably from a 'return' statement) supplies the result.  Falling
    off the end of the body returns null.
    """
    with state.ctx_FuncCall(mem, func):  # pushes a new stack frame
        _BindFuncArgs(func, rd, mem)

        try:
            cmd_ev._Execute(func.parsed.body)
            # The body finished without an explicit return
            return value.Null
        except vm.ValueControlFlow as cf:
            return cf.value
        except vm.IntControlFlow:
            # Integer-valued control flow must not escape a func body
            raise AssertionError('IntControlFlow in func')

    # Every path above returns or raises.  This keeps the function's return
    # paths explicit, presumably for the mycpp translation.
    raise AssertionError('unreachable')
574
575
# vim: sw=4