OILS / builtin / printf_osh.py View on Github | oilshell.org

508 lines, 337 significant
1#!/usr/bin/env python2
2"""Builtin_printf.py."""
3from __future__ import print_function
4
5import time as time_ # avoid name conflict
6
7from _devbuild.gen import arg_types
8from _devbuild.gen.id_kind_asdl import Id, Kind, Id_t, Kind_t
9from _devbuild.gen.runtime_asdl import cmd_value
10from _devbuild.gen.syntax_asdl import (
11 loc,
12 loc_e,
13 loc_t,
14 source,
15 Token,
16 CompoundWord,
17 printf_part,
18 printf_part_e,
19 printf_part_t,
20)
21from _devbuild.gen.types_asdl import lex_mode_e, lex_mode_t
22from _devbuild.gen.value_asdl import (value, value_e)
23
24from core import alloc
25from core import error
26from core.error import e_die, p_die
27from core import state
28from core import vm
29from frontend import flag_util
30from frontend import consts
31from frontend import lexer
32from frontend import match
33from frontend import reader
34from mycpp import mylib
35from mycpp.mylib import log
36from osh import sh_expr_eval
37from osh import word_compile
38from data_lang import j8_lite
39
40import posix_ as posix
41
42from typing import Dict, List, TYPE_CHECKING, cast
43
44if TYPE_CHECKING:
45 from core import ui
46 from frontend import parse_lib
47
48_ = log
49
50
class _FormatStringParser(object):
    """Parser that turns a printf format string into a list of printf_part.

    Tokens are pulled from the main lexer, using the PrintfOuter mode for
    literal text and the PrintfPercent mode for the inside of a % directive.

    Grammar:

      width         = Num | Star
      precision     = Dot (Num | Star | Zero)?
      fmt           = Percent (Flag | Zero)* width? precision? (Type | Time)
      part          = Char_* | Format_EscapedPercent | fmt
      printf_format = part* Eof_Real

    Time is the token for the bash-style %(strftime)T conversion.
    """

    def __init__(self, lexer):
        # type: (lexer.Lexer) -> None
        self.lexer = lexer

        # One token of lookahead.  These hold placeholders until the first
        # call to _Next().
        self.cur_token = None  # type: Token
        self.token_type = Id.Undefined_Tok  # type: Id_t
        self.token_kind = Kind.Undefined  # type: Kind_t

    def _Next(self, lex_mode):
        # type: (lex_mode_t) -> None
        """Read one token in the given lexer mode and make it current."""
        tok = self.lexer.Read(lex_mode)
        self.cur_token = tok
        self.token_type = tok.id
        self.token_kind = consts.GetKind(tok.id)

    def _ParseFormatStr(self):
        # type: () -> printf_part_t
        """Parse the 'fmt' production; the current token is the Percent.

        On return the current token is the conversion type, which the caller
        is expected to step over.  Unsupported flags and conversions, and
        malformed directives, are reported with p_die().
        """
        percent = printf_part.Percent.CreateNull(alloc_lists=True)

        self._Next(lex_mode_e.PrintfPercent)  # step over the %

        # (Flag | Zero)*
        while (self.token_type == Id.Format_Flag or
               self.token_type == Id.Format_Zero):
            # space and + could be implemented
            flag_str = lexer.TokenVal(
                self.cur_token)  # allocation will be cached
            if flag_str in '# +':
                p_die("osh printf doesn't support the %r flag" % flag_str,
                      self.cur_token)

            percent.flags.append(self.cur_token)
            self._Next(lex_mode_e.PrintfPercent)

        # width?
        if (self.token_type == Id.Format_Num or
                self.token_type == Id.Format_Star):
            percent.width = self.cur_token
            self._Next(lex_mode_e.PrintfPercent)

        # precision?  A bare dot is recorded as the precision token unless a
        # number, star or zero follows it.
        if self.token_type == Id.Format_Dot:
            percent.precision = self.cur_token
            self._Next(lex_mode_e.PrintfPercent)  # step over the dot
            if (self.token_type == Id.Format_Num or
                    self.token_type == Id.Format_Star or
                    self.token_type == Id.Format_Zero):
                percent.precision = self.cur_token
                self._Next(lex_mode_e.PrintfPercent)

        # (Type | Time) is mandatory; anything else is a parse error.
        if self.token_type == Id.Unknown_Tok:
            p_die('Invalid printf format character', self.cur_token)

        if (self.token_type != Id.Format_Type and
                self.token_type != Id.Format_Time):
            p_die('Expected a printf format character', self.cur_token)

        percent.type = self.cur_token

        # Extra validation that is not expressed in the grammar.
        conv = lexer.TokenVal(percent.type)  # allocation will be cached
        if conv in 'eEfFgG':
            p_die("osh printf doesn't support floating point", percent.type)
        # %c could be implemented, but needs utf-8 decoding.
        if conv == 'c':
            p_die("osh printf doesn't support single characters (bytes)",
                  percent.type)

        return percent

    def Parse(self):
        # type: () -> List[printf_part_t]
        """Parse the whole format string and return its parts in order."""
        result = []  # type: List[printf_part_t]

        self._Next(lex_mode_e.PrintfOuter)
        while True:
            tok_id = self.token_type

            if (self.token_kind == Kind.Char or
                    tok_id == Id.Format_EscapedPercent or
                    tok_id == Id.Unknown_Backslash):
                # As with echo -e, Unknown_Backslash is not an error here even
                # under shopt -u parse_backslash, because this happens at
                # runtime rather than parse time.  Users should use $'' or the
                # future static printf ${x %.3f}.
                result.append(printf_part.Literal(self.cur_token))

            elif tok_id == Id.Format_Percent:
                result.append(self._ParseFormatStr())

            elif tok_id == Id.Eof_Real or tok_id == Id.Eol_Tok:
                # Id.Eol_Tok: special case for format string of '\x00'.
                return result

            else:
                raise AssertionError(tok_id)

            self._Next(lex_mode_e.PrintfOuter)
157
158
class Printf(vm._Builtin):
    """The printf builtin: printf [-v var] format [argument ...]

    The format string is parsed with _FormatStringParser, and the parsed
    parts are cached per format string in self.parse_cache.
    """

    def __init__(
            self,
            mem,  # type: state.Mem
            parse_ctx,  # type: parse_lib.ParseContext
            unsafe_arith,  # type: sh_expr_eval.UnsafeArith
            errfmt,  # type: ui.ErrorFormatter
    ):
        # type: (...) -> None
        self.mem = mem
        self.parse_ctx = parse_ctx
        self.unsafe_arith = unsafe_arith
        self.errfmt = errfmt
        self.parse_cache = {}  # type: Dict[str, List[printf_part_t]]

        # Used for the special %(...)T argument -2.  This object is
        # initialized in main().
        self.shell_start_time = time_.time()

    def _Format(self, parts, varargs, locs, out):
        # type: (List[printf_part_t], List[str], List[CompoundWord], List[str]) -> int
        """Hairy printf formatting logic.

        Applies the parsed format to the arguments, appending the resulting
        pieces to 'out'.  The format is re-applied as long as unconsumed
        arguments remain ("arg recycling"); a format that consumes no
        argument is applied exactly once.

        Args:
          parts: parsed format string, from _FormatStringParser.Parse()
          varargs: the arguments after the format string
          locs: locations parallel to varargs, used to blame errors
          out: output list that formatted pieces are appended to

        Returns:
          0 on success, or 1 after printing an error about an invalid width,
          precision or integer argument.  Formatting a negative number with
          %o %u %x %X is reported with e_die().
        """
        arg_index = 0
        num_args = len(varargs)
        backslash_c = False

        while True:  # loop over arguments
            for part in parts:  # loop over parsed format string
                UP_part = part
                if part.tag() == printf_part_e.Literal:
                    part = cast(printf_part.Literal, UP_part)
                    token = part.token
                    if token.id == Id.Format_EscapedPercent:
                        s = '%'
                    else:
                        s = word_compile.EvalCStringToken(token)
                    out.append(s)

                elif part.tag() == printf_part_e.Percent:
                    # Note: This case is very long, but hard to refactor
                    # because of the error cases and "recycling" of args!
                    # (arg_index, return 1, etc.)
                    part = cast(printf_part.Percent, UP_part)

                    # TODO: These calculations are independent of the data, so
                    # could be cached
                    flags = []  # type: List[str]
                    if len(part.flags) > 0:
                        for flag_token in part.flags:
                            flags.append(lexer.TokenVal(flag_token))

                    width = -1  # nonexistent
                    if part.width:
                        if part.width.id in (Id.Format_Num, Id.Format_Zero):
                            width_str = lexer.TokenVal(part.width)
                            width_loc = part.width  # type: loc_t
                        elif part.width.id == Id.Format_Star:
                            # %*d takes the width from the next argument
                            if arg_index < num_args:
                                width_str = varargs[arg_index]
                                width_loc = locs[arg_index]
                                arg_index += 1
                            else:
                                width_str = ''  # invalid
                                width_loc = loc.Missing
                        else:
                            raise AssertionError()

                        try:
                            width = int(width_str)
                        except ValueError:
                            # No argument to blame: point at the format
                            if width_loc.tag() == loc_e.Missing:
                                width_loc = part.width
                            self.errfmt.Print_("printf got invalid width %r" %
                                               width_str,
                                               blame_loc=width_loc)
                            return 1

                    precision = -1  # nonexistent
                    if part.precision:
                        if part.precision.id == Id.Format_Dot:
                            # A bare '.' means precision 0
                            precision_str = '0'
                            precision_loc = part.precision  # type: loc_t
                        elif part.precision.id in (Id.Format_Num,
                                                   Id.Format_Zero):
                            precision_str = lexer.TokenVal(part.precision)
                            precision_loc = part.precision
                        elif part.precision.id == Id.Format_Star:
                            # %.*d takes the precision from the next argument
                            if arg_index < num_args:
                                precision_str = varargs[arg_index]
                                precision_loc = locs[arg_index]
                                arg_index += 1
                            else:
                                precision_str = ''
                                precision_loc = loc.Missing
                        else:
                            raise AssertionError()

                        try:
                            precision = int(precision_str)
                        except ValueError:
                            if precision_loc.tag() == loc_e.Missing:
                                precision_loc = part.precision
                            self.errfmt.Print_(
                                'printf got invalid precision %r' %
                                precision_str,
                                blame_loc=precision_loc)
                            return 1

                    # The argument being converted; empty if we ran out.
                    if arg_index < num_args:
                        s = varargs[arg_index]
                        word_loc = locs[arg_index]  # type: loc_t
                        arg_index += 1
                        has_arg = True
                    else:
                        s = ''
                        word_loc = loc.Missing
                        has_arg = False

                    # Note: %s could be lexed into Id.Percent_S.  Although
                    # small string optimization would remove the allocation as
                    # well.
                    typ = lexer.TokenVal(part.type)
                    if typ == 's':
                        if precision >= 0:
                            s = s[:precision]  # truncate

                    elif typ == 'q':
                        # Most shells give \' for single quote, while OSH
                        # gives $'\'' this could matter when SSH'ing.
                        # Ditto for $'\\' vs. '\'

                        s = j8_lite.MaybeShellEncode(s)

                    elif typ == 'b':
                        # Process just like echo -e, except \c handling is
                        # simpler.

                        c_parts = []  # type: List[str]
                        lex = match.EchoLexer(s)
                        while True:
                            id_, tok_val = lex.Next()
                            if id_ == Id.Eol_Tok:  # Note: This is really a NUL terminator
                                break

                            # Note: DummyToken is OK because
                            # EvalCStringToken() doesn't have any syntax
                            # errors.
                            tok = lexer.DummyToken(id_, tok_val)
                            p = word_compile.EvalCStringToken(tok)

                            # Unusual behavior: '\c' aborts processing!
                            if p is None:
                                backslash_c = True
                                break

                            c_parts.append(p)
                        s = ''.join(c_parts)

                    elif part.type.id == Id.Format_Time or typ in 'diouxX':
                        # %(...)T and %d share this complex integer conversion
                        # logic

                        try:
                            # note: spaces like ' -42 ' accepted and normalized
                            d = int(s)
                        except ValueError:
                            # 'a is interpreted as the ASCII value of 'a'
                            if len(s) >= 1 and s[0] in '\'"':
                                # TODO: utf-8 decode s[1:] to be more correct.
                                # Probably depends on issue #366, a utf-8
                                # library.
                                # Note: len(s) == 1 means there is a NUL (0)
                                # after the quote..
                                d = ord(s[1]) if len(s) >= 2 else 0

                            # No argument means -1 for %(...)T as in Bash
                            # Reference Manual 4.2 "If no argument is
                            # specified, conversion behaves as if -1 had been
                            # given."
                            elif (not has_arg and
                                  part.type.id == Id.Format_Time):
                                d = -1

                            else:
                                if has_arg:
                                    blame_loc = word_loc  # type: loc_t
                                else:
                                    blame_loc = part.type
                                self.errfmt.Print_(
                                    'printf expected an integer, got %r' % s,
                                    blame_loc)
                                return 1

                        if part.type.id == Id.Format_Time:
                            # Initialize timezone:
                            #   `localtime' uses the current timezone
                            #   information initialized by `tzset'.  The
                            #   function `tzset' refers to the environment
                            #   variable `TZ'.  When the exported variable
                            #   `TZ' is present, its value should be reflected
                            #   in the real environment variable `TZ' before
                            #   call of `tzset'.
                            #
                            # Note: unlike LANG, TZ doesn't seem to change
                            # behavior if it's not exported.
                            #
                            # TODO: In YSH, provide an API that doesn't rely
                            # on libc's global state.

                            tzcell = self.mem.GetCell('TZ')
                            if (tzcell and tzcell.exported and
                                    tzcell.val.tag() == value_e.Str):
                                tzval = cast(value.Str, tzcell.val)
                                posix.putenv('TZ', tzval.s)

                            time_.tzset()

                            # Handle special values:
                            #   User can specify two special values -1 and -2
                            #   as in Bash Reference Manual 4.2: "Two special
                            #   argument values may be used: -1 represents the
                            #   current time, and -2 represents the time the
                            #   shell was invoked." from
                            #   https://www.gnu.org/software/bash/manual/html_node/Bash-Builtins.html#index-printf
                            if d == -1:  # the current time
                                ts = time_.time()
                            elif d == -2:  # the shell start time
                                ts = self.shell_start_time
                            else:
                                ts = d

                            # typ looks like '(...)T'; strip the leading '('
                            # and the trailing ')T' to get the strftime format
                            s = time_.strftime(typ[1:-2], time_.localtime(ts))
                            if precision >= 0:
                                s = s[:precision]  # truncate

                        else:  # typ in 'diouxX'
                            # Disallowed because it depends on 32- or 64- bit
                            if d < 0 and typ in 'ouxX':
                                e_die(
                                    "Can't format negative number %d with %%%s"
                                    % (d, typ), part.type)

                            if typ == 'o':
                                s = mylib.octal(d)
                            elif typ == 'x':
                                s = mylib.hex_lower(d)
                            elif typ == 'X':
                                s = mylib.hex_upper(d)
                            else:  # diu
                                s = str(d)  # without spaces like ' -42 '

                            # The precision is a minimum number of DIGITS, so
                            # the sign must not be counted: %.3d of -42 is
                            # -042, not -42.
                            negative = s.startswith('-')
                            if negative:
                                num_digits = len(s) - 1
                            else:
                                num_digits = len(s)

                            # There are TWO different ways to ZERO PAD, and
                            # they differ on the negative sign!  See
                            # spec/builtin-printf

                            zero_pad = 0  # no zero padding
                            if width >= 0 and '0' in flags:
                                zero_pad = 1  # style 1
                            elif precision > 0 and num_digits < precision:
                                zero_pad = 2  # style 2

                            if zero_pad:
                                if negative:
                                    digits = s[1:]
                                    sign = '-'
                                    if zero_pad == 1:
                                        # [%06d] -42 becomes [-00042] (6 TOTAL)
                                        n = width - 1
                                    else:
                                        # [%6.6d] -42 becomes [-000042] (1 for
                                        # '-' + 6)
                                        n = precision
                                else:
                                    digits = s
                                    sign = ''
                                    if zero_pad == 1:
                                        n = width
                                    else:
                                        n = precision
                                s = sign + digits.rjust(n, '0')

                    else:
                        raise AssertionError()

                    if width >= 0:
                        if '-' in flags:
                            s = s.ljust(width, ' ')
                        else:
                            s = s.rjust(width, ' ')

                    out.append(s)

                else:
                    raise AssertionError()

                if backslash_c:
                    # \c in a %b argument: ignore the rest of the format
                    # string, e.g. the ']' in printf '[%b]' 'a\cb'
                    break

            if backslash_c:
                # ... and ignore the remaining arguments too, instead of
                # recycling the format: 'printf %b a\cb xx' prints only 'a'
                break

            if arg_index == 0:
                # We went through ALL parts and didn't consume ANY arg.
                # Example: print x y
                break
            if arg_index >= num_args:
                # We printed all args
                break
            # There are more arg: Implement the 'arg recycling' behavior.

        return 0

    def Run(self, cmd_val):
        # type: (cmd_value.Argv) -> int
        """
        printf: printf [-v var] format [argument ...]

        Parses the format string (or reuses the cached parse), formats the
        arguments, then either assigns the result to the -v variable or
        writes it to stdout.

        Returns 0 on success, 2 if the format string fails to parse, or the
        non-zero status of _Format().
        """
        attrs, arg_r = flag_util.ParseCmdVal('printf', cmd_val)
        arg = arg_types.printf(attrs.attrs)

        fmt, fmt_loc = arg_r.ReadRequired2('requires a format string')
        varargs, locs = arg_r.Rest2()

        arena = self.parse_ctx.arena
        if fmt in self.parse_cache:
            parts = self.parse_cache[fmt]
        else:
            line_reader = reader.StringLineReader(fmt, arena)
            # TODO: Make public
            # Named fmt_lexer so it doesn't shadow the frontend.lexer module
            fmt_lexer = self.parse_ctx.MakeLexer(line_reader)
            parser = _FormatStringParser(fmt_lexer)

            with alloc.ctx_SourceCode(arena,
                                      source.ArgvWord('printf', fmt_loc)):
                try:
                    parts = parser.Parse()
                except error.Parse as e:
                    self.errfmt.PrettyPrintError(e)
                    return 2  # parse error

            # Only successfully parsed formats are cached
            self.parse_cache[fmt] = parts

        if 0:  # debugging aid: dump the parsed format
            print()
            for part in parts:
                part.PrettyPrint()
                print()

        out = []  # type: List[str]
        status = self._Format(parts, varargs, locs, out)
        if status != 0:
            return status  # failure

        result = ''.join(out)
        if arg.v is not None:
            # TODO: get the location for arg.v!
            v_loc = loc.Missing
            lval = self.unsafe_arith.ParseLValue(arg.v, v_loc)
            state.BuiltinSetValue(self.mem, lval, value.Str(result))
        else:
            mylib.Stdout().write(result)
        return 0