OILS / osh / prompt.py View on Github | oilshell.org

373 lines, 214 significant
1"""
2prompt.py: A LIBRARY for prompt evaluation.
3
4User interface details should go in core/ui.py.
5"""
6from __future__ import print_function
7
8import time as time_
9
10from _devbuild.gen.id_kind_asdl import Id, Id_t
11from _devbuild.gen.syntax_asdl import (loc, command_t, source, CompoundWord)
12from _devbuild.gen.value_asdl import (value, value_e, value_t)
13from core import alloc
14from core import main_loop
15from core import error
16from core import pyos
17from core import state
18from display import ui
19from frontend import consts
20from frontend import match
21from frontend import reader
22from mycpp import mylib
23from mycpp.mylib import log, tagswitch
24from osh import word_
25from pylib import os_path
26
27import libc # gethostname()
28import posix_ as posix
29
30from typing import Dict, List, Tuple, Optional, cast, TYPE_CHECKING
31if TYPE_CHECKING:
32 from core.state import Mem
33 from frontend.parse_lib import ParseContext
34 from osh import cmd_eval
35 from osh import word_eval
36 from ysh import expr_eval
37
# Bind `log` to a throwaway name.  NOTE(review): presumably this keeps the
# `log` import referenced while no debug logging is active -- confirm.
_ = log
39
40#
41# Prompt Evaluation
42#
43
# Template wrapped around error text that is embedded in the prompt string.
_ERROR_FMT = '<Error: %s> '
# Error text returned by Evaluator._ReplaceBackslashCodes() when the \[ and \]
# markers in a prompt don't pair up.
_UNBALANCED_ERROR = r'Unbalanced \[ and \]'
46
47
class _PromptEvaluatorCache(object):
    """Memoize values that prompt rendering needs over and over.

    Holds the effective user ID plus a small name -> string table, so each
    underlying lookup happens at most once per instance.
    """

    def __init__(self):
        # type: () -> None
        self.cache = {}  # type: Dict[str, str]
        self.euid = -1  # sentinel meaning "not looked up yet"

    def _GetEuid(self):
        # type: () -> int
        """Return the effective user ID, querying the OS only the first time."""
        if self.euid != -1:
            return self.euid

        self.euid = posix.geteuid()
        return self.euid

    def Get(self, name):
        # type: (str) -> str
        """Return the string for name, computing and storing it on first use.

        Recognized names are '$', 'hostname' and 'user'.  Any other name
        raises AssertionError, and nothing is stored for it.
        """
        hit = self.cache.get(name)
        if hit is not None:
            return hit

        if name == 'user':  # for \u
            result = pyos.GetUserName(self._GetEuid())

        elif name == 'hostname':  # for \h and \H
            result = libc.gethostname()

        elif name == '$':  # \$
            if self._GetEuid() == 0:
                result = '#'
            else:
                result = '$'

        else:
            raise AssertionError(name)

        self.cache[name] = result
        return result
84
85
class Evaluator(object):
    r"""Evaluate the prompt mini-language.

    bash has a very silly algorithm:
    1. replace backslash codes, except any $ in those values get quoted into \$.
    2. Parse the word as if it's in a double quoted context, and then evaluate
       the word.

    Haven't done this from POSIX:
    http://pubs.opengroup.org/onlinepubs/9699919799/utilities/V3_chap02.html

      The shell shall replace each instance of the character '!' in PS1 with
      the history file number of the next command to be typed. Escaping the
      '!' with another '!' (that is, "!!" ) shall place the literal character
      '!' in the prompt.
    """

    def __init__(self, lang, version_str, parse_ctx, mem):
        # type: (str, str, ParseContext, Mem) -> None
        """
        Args:
          lang: 'osh' or 'ysh'; anything else fails an assertion
          version_str: the string that the v prompt code expands to
          parse_ctx: used to build a word parser for the substituted prompt
          mem: shell state that PWD, HOME, PS1 and renderPrompt are read from
        """
        # These start out as None and are not set anywhere in this class.
        # NOTE(review): presumably the caller assigns them after construction
        # to break an import/construction cycle -- see CheckCircularDeps().
        self.word_ev = None  # type: word_eval.AbstractWordEvaluator
        self.expr_ev = None  # type: expr_eval.ExprEvaluator
        self.global_io = None  # type: value.IO

        assert lang in ('osh', 'ysh'), lang
        self.lang = lang
        self.version_str = version_str
        self.parse_ctx = parse_ctx
        self.mem = mem
        # Cache to save syscalls / libc calls.
        self.cache = _PromptEvaluatorCache()

        # These caches should reduce memory pressure a bit.  We don't want to
        # reparse the prompt twice every time you hit enter.
        self.tokens_cache = {}  # type: Dict[str, List[Tuple[Id_t, str]]]
        self.parse_cache = {}  # type: Dict[str, CompoundWord]

    def CheckCircularDeps(self):
        # type: () -> None
        """Assert that word_ev has been assigned."""
        assert self.word_ev is not None

    def PromptVal(self, what):
        # type: (str) -> str
        """Expand one prompt code, given without its backslash.

        _io->promptVal('$')

        'D' requires a format argument that this entry point has no way to
        receive, so it produces an error string instead of a time.
        """
        if what == 'D':
            # TODO: wrap strftime(), time(), localtime(), etc. so users can do
            # it themselves
            #
            # Raw literal: \D is not a recognized escape sequence, so a plain
            # '...' literal here draws an invalid-escape warning on Python 3.
            return _ERROR_FMT % r'\D{} not in promptVal()'
        else:
            # Could make hostname -> h alias, etc.
            return self.PromptSubst(what)

    def PromptSubst(self, ch, arg=None):
        # type: (str, Optional[str]) -> str
        r"""Expand a single prompt code.

        Args:
          ch: the character after the backslash, e.g. 'u' for \u
          arg: the text between the braces of \D{...}.  Must not be None when
               ch is 'D'; it is not read for any other code.

        Returns:
          The expansion, or an _ERROR_FMT message when the code can't be
          expanded.
        """
        if ch == '$':  # So the user can tell if they're root or not.
            r = self.cache.Get('$')

        elif ch == 'u':
            r = self.cache.Get('user')

        elif ch == 'h':
            hostname = self.cache.Get('hostname')
            # foo.com -> foo
            r, _ = mylib.split_once(hostname, '.')

        elif ch == 'H':
            r = self.cache.Get('hostname')

        elif ch == 's':
            r = self.lang

        elif ch == 'v':
            r = self.version_str

        elif ch == 'A':
            now = time_.time()
            r = time_.strftime('%H:%M', time_.localtime(now))

        elif ch == 'D':  # \D{%H:%M} is the only one with a suffix
            now = time_.time()
            assert arg is not None
            if len(arg) == 0:
                # In bash y.tab.c uses %X when string is empty
                # This doesn't seem to match exactly, but meh for now.
                fmt = '%X'
            else:
                fmt = arg
            r = time_.strftime(fmt, time_.localtime(now))

        elif ch == 'w':
            try:
                pwd = state.GetString(self.mem, 'PWD')
                # doesn't have to exist
                home = state.MaybeString(self.mem, 'HOME')
                # Shorten to ~/mydir
                r = ui.PrettyDir(pwd, home)
            except error.Runtime as e:
                # Show the failure inside the prompt rather than raising
                r = _ERROR_FMT % e.UserErrorString()

        elif ch == 'W':
            val = self.mem.GetValue('PWD')
            if val.tag() == value_e.Str:
                str_val = cast(value.Str, val)
                r = os_path.basename(str_val.s)
            else:
                r = _ERROR_FMT % 'PWD is not a string'

        else:
            # e.g. \e \r \n \\
            r = consts.LookupCharPrompt(ch)

            # TODO: Handle more codes
            # R(r'\\[adehHjlnrstT@AuvVwW!#$\\]', Id.PS_Subst),
            if r is None:
                r = _ERROR_FMT % (r'\%s is invalid or unimplemented in $PS1' %
                                  ch)

        return r

    def _ReplaceBackslashCodes(self, tokens):
        # type: (List[Tuple[Id_t, str]]) -> str
        """Turn lexed prompt tokens into the string that gets parsed as a word.

        Args:
          tokens: (token id, token text) pairs, as stored in tokens_cache

        Returns:
          The concatenated replacement text, or an _ERROR_FMT message when
          the non-printing markers are unbalanced.

        Raises:
          AssertionError: on a token id that no branch below handles
        """
        ret = []  # type: List[str]
        non_printing = 0  # currently open \[ markers
        for id_, s in tokens:
            # BadBackslash means they should have escaped with \\.  TODO: Make
            # it an error.  'echo -e' has a similar issue.
            if id_ in (Id.PS_Literals, Id.PS_BadBackslash):
                ret.append(s)

            elif id_ == Id.PS_Octal3:
                # Drop the backslash, read the digits as octal, and wrap the
                # result into the range of a single byte.
                i = int(s[1:], 8)
                ret.append(chr(i % 256))

            elif id_ == Id.PS_LBrace:
                non_printing += 1
                ret.append('\x01')

            elif id_ == Id.PS_RBrace:
                non_printing -= 1
                if non_printing < 0:  # e.g. \]\[
                    return _ERROR_FMT % _UNBALANCED_ERROR

                ret.append('\x02')

            elif id_ == Id.PS_Subst:  # \u \h \w etc.
                ch = s[1]
                arg = None  # type: Optional[str]
                if ch == 'D':
                    # Strip the leading \D{ and the trailing }
                    arg = s[3:-1]  # \D{%H:%M}
                r = self.PromptSubst(ch, arg=arg)

                # Quote $ so that the word evaluation in EvalPrompt() doesn't
                # treat substituted text as an expansion.  This is the bash
                # hack described in the class docstring.
                ret.append(r.replace('$', '\\$'))

            else:
                raise AssertionError('Invalid token %r %r' % (id_, s))

        # mismatched brackets, see https://github.com/oilshell/oil/pull/256
        if non_printing != 0:
            return _ERROR_FMT % _UNBALANCED_ERROR

        return ''.join(ret)

    def EvalPrompt(self, UP_val):
        # type: (value_t) -> str
        """Perform the two evaluations that bash does.

        Used by $PS1 and ${x@P}.

        Args:
          UP_val: the prompt value; anything that is not a Str yields ''

        Returns:
          The prompt text after backslash-code replacement followed by word
          evaluation.
        """
        if UP_val.tag() != value_e.Str:
            return ''  # e.g. if the user does 'unset PS1'

        val = cast(value.Str, UP_val)

        # Parse backslash escapes (cached)
        tokens = self.tokens_cache.get(val.s)
        if tokens is None:
            tokens = match.Ps1Tokens(val.s)
            self.tokens_cache[val.s] = tokens

        # Replace values.
        ps1_str = self._ReplaceBackslashCodes(tokens)

        # Parse it like a double-quoted word (cached).  TODO: This could be
        # done on mem.SetValue(), so we get the error earlier.
        # NOTE: This is copied from the PS4 logic in Tracer.
        #
        # NOTE(review): the key is the string AFTER substitution, so a prompt
        # containing \A, \D{} or \w adds an entry whenever the clock or the
        # directory changes, and nothing in this class evicts entries.
        ps1_word = self.parse_cache.get(ps1_str)
        if ps1_word is None:
            w_parser = self.parse_ctx.MakeWordParserForPlugin(ps1_str)
            try:
                ps1_word = w_parser.ReadForPlugin()
            except error.Parse as e:
                # The error word is cached too, so a bad prompt is parsed once
                ps1_word = word_.ErrorWord("<ERROR: Can't parse PS1: %s>" %
                                           e.UserErrorString())
            self.parse_cache[ps1_str] = ps1_word

        # Evaluate, e.g. "${debian_chroot}\u" -> '\u'
        val2 = self.word_ev.EvalForPlugin(ps1_word)
        return val2.s

    def EvalFirstPrompt(self):
        # type: () -> str
        """Return the text of the first prompt.

        A renderPrompt value of type Func takes precedence: it is called with
        global_io, and its Str result is returned as is.  Otherwise $PS1 is
        evaluated, and the result is prefixed with 'ysh ' when lang is 'ysh'.
        """
        # First try calling renderPrompt()
        UP_func_val = self.mem.GetValue('renderPrompt')
        if UP_func_val.tag() == value_e.Func:
            func_val = cast(value.Func, UP_func_val)

            assert self.global_io is not None
            pos_args = [self.global_io]  # type: List[value_t]
            val = self.expr_ev.PluginCall(func_val, pos_args)

            UP_val = val
            with tagswitch(val) as case:
                if case(value_e.Str):
                    val = cast(value.Str, UP_val)
                    return val.s
                else:
                    msg = 'renderPrompt() should return Str, got %s' % ui.ValType(
                        val)
                    return _ERROR_FMT % msg

        # Now try evaluating $PS1

        ps1_val = self.mem.GetValue('PS1')
        prompt_str = self.EvalPrompt(ps1_val)

        # Add string to show it's YSH.  The user can disable this with
        #
        # func renderPrompt() {
        #   return ("${PS1@P}")
        # }
        if self.lang == 'ysh':
            prompt_str = 'ysh ' + prompt_str

        return prompt_str
324
325
# Name of the shell variable whose value UserPlugin.Run() parses and executes.
PROMPT_COMMAND = 'PROMPT_COMMAND'
327
328
class UserPlugin(object):
    """Runs the code held in PROMPT_COMMAND, remembering its parse tree.

    Similar to core/dev.py:Tracer, which caches $PS4.
    """

    def __init__(self, mem, parse_ctx, cmd_ev, errfmt):
        # type: (Mem, ParseContext, cmd_eval.CommandEvaluator, ui.ErrorFormatter) -> None
        self.mem = mem
        self.parse_ctx = parse_ctx
        self.cmd_ev = cmd_ev
        self.errfmt = errfmt

        self.arena = parse_ctx.arena
        # PROMPT_COMMAND source text -> the command tree parsed from it
        self.parse_cache = {}  # type: Dict[str, command_t]

    def Run(self):
        # type: () -> None
        """Execute PROMPT_COMMAND when it holds a string.

        Does nothing if the variable is not a Str.  A parse error is printed
        through errfmt, and then nothing is executed or cached.
        """
        UP_val = self.mem.GetValue(PROMPT_COMMAND)
        if UP_val.tag() != value_e.Str:
            return
        code_str = cast(value.Str, UP_val).s

        # PROMPT_COMMAND almost never changes, so the tree is looked up by its
        # source text first.  This avoids memory allocations.
        tree = self.parse_cache.get(code_str)
        if tree is None:
            c_parser = self.parse_ctx.MakeOshParser(
                reader.StringLineReader(code_str, self.arena))

            # NOTE: This is similar to CommandEvaluator.ParseTrapCode().
            origin = source.Variable(PROMPT_COMMAND, loc.Missing)
            with alloc.ctx_SourceCode(self.arena, origin):
                try:
                    tree = main_loop.ParseWholeFile(c_parser)
                except error.Parse as e:
                    self.errfmt.PrettyPrintError(e)
                    return  # don't execute

            self.parse_cache[code_str] = tree

        # Save this so PROMPT_COMMAND can't set $?
        with state.ctx_Registers(self.mem):
            # Catches fatal execution error
            self.cmd_ev.ExecuteAndCatch(tree)