"""
the module to parse format strings and construct format trees
この module はフォーマット文字列を構文解析しフォーマット木を作ります。
たとえば
::
N
P_0 P_1 \cdots P_{N-1}
Q_0 Q_1 \cdots Q_{N-1}
という入力フォーマット文字列が与えられれば
::
sequence([
item("N"),
newline(),
loop(counter="i", size="N",
item("P", indices="i")
),
newline(),
loop(counter="i", size="N",
item("Q", indices="i")
),
newline(),
])
に相当する木構造 (:any:`FormatNode`) を返します。
"""
import abc
import re
from logging import getLogger
from typing import *
import ply.lex as lex
import ply.yacc as yacc
from onlinejudge_template.analyzer.simplify import simplify
from onlinejudge_template.types import *
logger = getLogger(__name__)
# Token names for the ply lexer/parser defined below.  The commented-out
# entries are kept for reference; the current lexer discards the
# corresponding characters instead of emitting them as tokens.
tokens = (
    'NEWLINE',
    # 'SPACE',

    # 'DOLLAR',
    # 'VAR_OPEN',
    # 'VAR_CLOSE',
    'FONTSPEC',

    'IDENT',
    'NUMBER',
    'UNDERSCORE',
    'LBRACE',
    'RBRACE',
    'COMMA',

    'ADD',
    'SUB',
    'MUL',
    'DIV',

    'VDOTS',
    'DOTS',
)
def build_lexer() -> lex.Lexer:
    """Build a ply lexer for TeX-ish format strings.

    Whitespace, TeX spacing commands, and math-mode delimiters are
    discarded; everything else is tokenized according to the module-level
    ``tokens`` tuple.  Note that ply matches function rules in definition
    order, then string rules by decreasing regex length, so the order
    below is significant.
    """
    def t_NEWLINE(t: lex.LexToken) -> lex.LexToken:
        r"""(\r?\n|<br>)"""
        # keep the line counter in sync for error messages
        t.lexer.lineno += 1
        return t

    t_ignore = ' \t'

    def t_tex_space(t: lex.LexToken) -> None:
        r"""(\\[ ]|\\,|\\:|\\;|\\!|~|\\quad|\\qquad|\\hspace\{[^}]+\})"""
        # TeX spacing commands carry no meaning for the format; drop them
        return None

    def t_math_mode(t: lex.LexToken) -> None:
        r"""(\$|\\\(|\\\)|\\\[|\\\]|<var>|</var>)"""
        # math-mode delimiters (TeX and HTML <var>) are ignored as well
        return None

    def t_error(t: lex.LexToken) -> None:
        # NOTE(review): t.lexpos is an absolute offset into the input, not a
        # column number, although the message labels it "column".
        raise FormatStringParserError("lexer: unexpected character: '{}' at line {} column {}".format(t.value[0], t.lineno, t.lexpos))

    # t_DOLLAR = r'\$'
    # t_VAR_OPEN = r'<\s*[vV][aA][rR]\s*>'
    # t_VAR_CLOSE = r'<\s*/\s*[vV][aA][rR]\s*>'

    def t_FONTSPEC(t: lex.LexToken) -> lex.LexToken:
        r"""\\(rm|mathrm|mathtt|mathbf|mathit|mathscr|mathcal|mathfrak|mathbb)"""
        return t

    # simple token rules (ply sorts these by decreasing regex length)
    t_IDENT = r'[A-Za-z]+'
    t_NUMBER = r'[0-9]+'
    t_UNDERSCORE = r'_'
    t_LBRACE = r'{'
    t_RBRACE = r'}'
    t_COMMA = r','
    t_ADD = r'\+'
    t_SUB = r'-'
    t_MUL = r'(\*|×|\\times)'
    t_DIV = r'/'
    t_DOTS = r'(\.\.\.*|…|\\dots|\\ldots|\\cdots)'
    t_VDOTS = r'(:|⋮|\\vdots)'

    return lex.lex()
class ParserNode(abc.ABC):
    """Base class for the intermediate tree nodes produced by the Yacc parser.

    Every node carries the source location (1-based line and column) of the
    token it was built from.
    """
    line: int
    column: int

    def __init__(self, *, line: int, column: int):
        self.line = line
        self.column = column

    def __repr__(self) -> str:
        # list public attributes sorted by name, with the source location last
        attrs = sorted(name for name in dir(self) if not name.startswith('_'))
        attrs.remove('line')
        attrs.remove('column')
        attrs += ['line', 'column']
        body = ', '.join(f"{name}={getattr(self, name)!r}" for name in attrs)
        return f"{self.__class__.__name__}({body})"
class SequenceParserNode(ParserNode):
    """A parser node holding an ordered sequence of child nodes."""
    items: List[ParserNode]

    def __init__(self, *, items: List[ParserNode], line: int, column: int):
        # remember the children, then record the source location
        self.items = items
        super().__init__(line=line, column=column)
class NewlineParserNode(ParserNode):
    """A parser node representing a newline in the format string."""
class ItemParserNode(ParserNode):
    """A parser node for one variable occurrence, e.g. ``A`` or ``A_{i, j}``."""
    name: VarName
    indices: Tuple[Expr, ...]

    def __init__(self, *, name: str, indices: Tuple[str, ...] = (), line: int, column: int):
        super().__init__(line=line, column=column)
        # wrap the raw strings in the project's variable/expression types
        self.name = VarName(name)
        self.indices = tuple(Expr(index) for index in indices)
class DotsParserNode(ParserNode):
    """A parser node for a dots (``...``) pattern between two endpoint nodes."""
    first: ParserNode
    last: ParserNode

    def __init__(self, *, first: ParserNode, last: ParserNode, line: int, column: int):
        # remember both endpoints, then record the source location
        self.first = first
        self.last = last
        super().__init__(line=line, column=column)
def build_parser(*, input: str) -> yacc.LRParser:
    """Build a ply LALR parser for the token stream produced by :func:`build_lexer`.

    :param input: the original format string; it is kept only to translate
        absolute token positions into column numbers for node locations and
        error messages
    """
    def find_column(lexpos: int) -> int:
        # convert an absolute offset in `input` into a 1-based column number
        line_start = input.rfind('\n', 0, lexpos) + 1
        return lexpos - line_start + 1

    def loc(p: yacc.YaccProduction) -> Dict[str, int]:
        # source location of the production's first symbol, as kwargs for
        # the ParserNode constructors
        return {
            'line': p.lineno(1),
            'column': find_column(p.lexpos(1)),
        }

    def p_main(p: yacc.YaccProduction) -> None:
        """main : lines main
                | lines"""
        if len(p) == 3:
            p[0] = SequenceParserNode(items=[p[1]] + p[2].items, **loc(p))
        elif len(p) == 2:
            p[0] = SequenceParserNode(items=[p[1]], **loc(p))

    def p_lines(p: yacc.YaccProduction) -> None:
        """lines : line
                 | line VDOTS newline line
                 | line DOTS newline line"""
        if len(p) == 2:
            p[0] = p[1]
        elif len(p) == 5:
            # vertical dots between two lines: the lines are the endpoints
            p[0] = DotsParserNode(first=p[1], last=p[4], **loc(p))

    def p_newline(p: yacc.YaccProduction) -> None:
        """newline : NEWLINE"""
        p[0] = NewlineParserNode(**loc(p))

    def p_line(p: yacc.YaccProduction) -> None:
        """line : items newline"""
        # a line is its items followed by the trailing newline node
        p[0] = SequenceParserNode(items=p[1].items + [p[2]], **loc(p))

    def p_items(p: yacc.YaccProduction) -> None:
        """items : item DOTS item items
                 | item DOTS item
                 | item items
                 | item"""
        if len(p) == 5:
            dots = DotsParserNode(first=p[1], last=p[3], **loc(p))
            p[0] = SequenceParserNode(items=[dots] + p[4].items, **loc(p))
        elif len(p) == 4:
            # was a plain `if` upstream; `elif` matches the sibling branches
            # (behavior unchanged since len(p) is fixed per production)
            dots = DotsParserNode(first=p[1], last=p[3], **loc(p))
            p[0] = SequenceParserNode(items=[dots], **loc(p))
        elif len(p) == 3:
            p[0] = SequenceParserNode(items=[p[1]] + p[2].items, **loc(p))
        elif len(p) == 2:
            p[0] = SequenceParserNode(items=[p[1]], **loc(p))

    def p_item(p: yacc.YaccProduction) -> None:
        """item : IDENT
                | IDENT UNDERSCORE NUMBER
                | IDENT UNDERSCORE IDENT
                | IDENT UNDERSCORE LBRACE exprs RBRACE
                | FONTSPEC LBRACE item RBRACE
                | LBRACE FONTSPEC item RBRACE"""
        if len(p) == 2:
            p[0] = ItemParserNode(name=p[1], indices=(), **loc(p))
        elif len(p) == 4:
            p[0] = ItemParserNode(name=p[1], indices=(p[3], ), **loc(p))
        elif len(p) == 6:
            p[0] = ItemParserNode(name=p[1], indices=p[4], **loc(p))
        elif len(p) == 5:
            # font specifiers are transparent: unwrap the inner item
            p[0] = p[3]

    def p_exprs(p: yacc.YaccProduction) -> None:
        """exprs : expr COMMA exprs
                 | expr"""
        if len(p) == 4:
            p[0] = (p[1], *p[3])
        elif len(p) == 2:
            p[0] = (p[1], )

    def p_expr(p: yacc.YaccProduction) -> None:
        """expr : IDENT
                | NUMBER
                | NUMBER IDENT
                | IDENT binop expr
                | NUMBER binop expr"""
        if len(p) == 2:
            p[0] = p[1]
        elif len(p) == 3:
            # juxtaposition like "2N" means multiplication
            p[0] = f"""{p[1]} * {p[2]}"""
        elif len(p) == 4:
            p[0] = f"""{p[1]} {p[2]} {p[3]}"""

    def p_binop(p: yacc.YaccProduction) -> None:
        """binop : ADD
                 | SUB
                 | MUL
                 | DIV"""
        p[0] = p[1]

    def p_error(t: Optional[lex.LexToken]) -> None:
        if t is None:
            # ply calls p_error(None) at unexpected end of input; the
            # original code crashed here with AttributeError on t.type
            raise FormatStringParserError("parser: unexpected end of input")
        raise FormatStringParserError("parser: unexpected token: {} \"{}\" at line {} column {}".format(t.type, t.value, t.lineno, find_column(t.lexpos)))

    return yacc.yacc(debug=False, write_tables=False)
def list_used_names(node: FormatNode) -> Set[str]:
    """Collect every variable name (item names and loop counters) used in *node*."""
    if isinstance(node, ItemNode):
        return {node.name}
    if isinstance(node, NewlineNode):
        return set()
    if isinstance(node, SequenceNode):
        collected: Set[str] = set()
        for child in node.items:
            collected.update(list_used_names(child))
        return collected
    if isinstance(node, LoopNode):
        return {node.name} | list_used_names(node.body)
    assert False
def zip_nodes(a: FormatNode, b: FormatNode, *, name: VarName, size: Optional[Expr]) -> Tuple[FormatNode, Optional[Expr]]:
    """Merge the two endpoint nodes of a dots (``...``) pattern into one loop body.

    ``a`` comes from the first iteration and ``b`` from the last one.  Both
    nodes must have the same shape; indices that differ between them are
    rewritten in terms of the new loop counter ``name``, and the loop size
    is inferred as ``last - first + 1``.  ``size`` is the size inferred so
    far (``None`` if none yet) and is threaded through recursive calls so
    that every differing index pair must agree on the same size.

    :returns: the merged node and the inferred size (possibly still ``None``)
    :raises FormatStringParserError:
    """
    if isinstance(a, ItemNode) and isinstance(b, ItemNode):
        if a.name != b.name or len(a.indices) != len(b.indices):
            raise FormatStringParserError("semantics: unmatched dots pair: {} and {}".format(a, b))
        indices = []
        for i, j in zip(a.indices, b.indices):
            if simplify(i) == simplify(j):
                # the index is the same at both endpoints; keep it unchanged
                indices.append(i)
            else:
                if size is None:
                    # infer the loop size from the first differing index pair
                    size = simplify(Expr(f"""{j} - {i} + 1"""))
                else:
                    # every differing index pair must imply the same size
                    if simplify(Expr(f"""{j} - {i} + 1""")) != simplify(size):
                        raise FormatStringParserError("semantics: unmatched dots pair: {} and {}".format(a, b))
                # rewrite the index in terms of the new loop counter
                indices.append(simplify(Expr(f"{i} + {name}")))
        return ItemNode(name=a.name, indices=indices), size
    elif isinstance(a, NewlineNode) and isinstance(b, NewlineNode):
        return NewlineNode(), size
    elif isinstance(a, SequenceNode) and isinstance(b, SequenceNode):
        if len(a.items) != len(b.items):
            raise FormatStringParserError("semantics: unmatched dots pair: {} and {}".format(a, b))
        items = []
        for a_i, b_i in zip(a.items, b.items):
            # thread `size` through so all children agree on one loop size
            c_i, size = zip_nodes(a_i, b_i, name=name, size=size)
            items.append(c_i)
        return SequenceNode(items=items), size
    elif isinstance(a, LoopNode) and isinstance(b, LoopNode):
        if a.size != b.size or a.name != b.name:
            raise FormatStringParserError("semantics: unmatched dots pair: {} and {}".format(a, b))
        c, size = zip_nodes(a.body, b.body, name=name, size=size)
        return LoopNode(size=a.size, name=a.name, body=c), size
    else:
        # mismatched node kinds cannot be unified
        raise FormatStringParserError("semantics: unmatched dots pair: {} and {}".format(a, b))
def extend_loop_node(a: FormatNode, b: FormatNode, *, loop: LoopNode) -> Optional[FormatNode]:
    """Try to absorb ``a`` as one extra leading iteration of ``loop``.

    ``b`` is the body of ``loop``.  Returns the extended body whose indices
    are rewritten in terms of the loop counter, or ``None`` when ``a`` does
    not match the shape of ``b``.
    """
    if isinstance(a, ItemNode) and isinstance(b, ItemNode):
        if a.name != b.name or len(a.indices) != len(b.indices):
            return None
        indices = []
        for i, j in zip(a.indices, b.indices):
            # decr_j is b's index evaluated one step before the loop starts,
            # i.e. with the loop counter textually replaced by (-1)
            decr_j = Expr(re.subn(r'\b' + re.escape(loop.name) + r'\b', '(-1)', j)[0])
            if simplify(i) == simplify(decr_j):
                # a is exactly the "iteration -1" of the loop; rewrite its
                # index in terms of the loop counter
                indices.append(simplify(Expr(f"""{i} + {loop.name}""")))
            else:
                return None
        return ItemNode(name=a.name, indices=indices)
    elif isinstance(a, NewlineNode) and isinstance(b, NewlineNode):
        return NewlineNode()
    elif isinstance(a, SequenceNode) and isinstance(b, SequenceNode):
        if len(a.items) != len(b.items):
            return None
        items = []
        for a_i, b_i in zip(a.items, b.items):
            c_i = extend_loop_node(a_i, b_i, loop=loop)
            if c_i is None:
                # any non-matching child aborts the whole extension
                return None
            items.append(c_i)
        return SequenceNode(items=items)
    elif isinstance(a, LoopNode) and isinstance(b, LoopNode):
        if a.size != b.size or a.name != b.name:
            return None
        c = extend_loop_node(a.body, b.body, loop=loop)
        if c is None:
            return None
        return LoopNode(size=a.size, name=a.name, body=c)
    else:
        # mismatched node kinds cannot be merged
        return None
def analyze_parsed_node(node: ParserNode) -> FormatNode:
    """
    translates an internal representation :any:`ParserNode` to a result tree :any:`FormatNode`

    Sequences are flattened, dots patterns are converted into loops (with a
    freshly chosen counter name), and nodes immediately preceding a loop are
    absorbed into it as an extra iteration when they match its body.

    :raises FormatStringParserError:
    """
    if isinstance(node, ItemParserNode):
        indices = [simplify(index) for index in node.indices]
        return ItemNode(name=node.name, indices=indices)
    elif isinstance(node, NewlineParserNode):
        return NewlineNode()
    elif isinstance(node, SequenceParserNode):
        # `items` is the accumulated output; `que` is a worklist of nodes
        # still to be placed
        items: List[FormatNode] = []
        que: List[FormatNode] = list(map(analyze_parsed_node, node.items))
        while que:
            item, *que = que
            if isinstance(item, SequenceNode):
                # flatten SequenceNode in SequenceNode
                que = item.items + que
            elif isinstance(item, LoopNode) and items:
                # merge FormatNode with LoopNode if possible
                if isinstance(item.body, SequenceNode) and len(items) >= len(item.body.items):
                    # candidate extra iteration: the last len(body.items)
                    # nodes already emitted
                    items_init = items[:-len(item.body.items)]
                    items_tail: FormatNode = SequenceNode(items=items[-len(item.body.items):])
                else:
                    items_init = items[:-1]
                    items_tail = items[-1]
                extended_body = extend_loop_node(items_tail, item.body, loop=item)
                if extended_body is not None:
                    # the preceding nodes were the first iteration; grow the
                    # loop by one and put it back on the worklist to retry
                    extended_loop: FormatNode = LoopNode(size=simplify(Expr(f"""{item.size} + 1""")), name=item.name, body=extended_body)
                    items = items_init
                    que = [extended_loop] + que
                else:
                    items.append(item)
            else:
                items.append(item)
        if len(items) == 1:
            # return the node directly if the length is 1
            return items[0]
        else:
            return SequenceNode(items=items)
    elif isinstance(node, DotsParserNode):
        a = analyze_parsed_node(node.first)
        b = analyze_parsed_node(node.last)

        # find the name of the new loop counter: the first unused letter
        # starting from 'i'
        used_names = list_used_names(a) | list_used_names(b)
        name = VarName('i')
        while name in used_names:
            assert name != VarName('z')
            name = VarName(chr(ord(name) + 1))

        # zip bodies
        c, size = zip_nodes(a, b, name=name, size=None)
        if size is None:
            # no index differed between the endpoints, so no loop size could
            # be inferred
            raise FormatStringParserError("semantics: unmatched dots pair: {} and {}".format(a, b))
        return LoopNode(size=size, name=name, body=c)
    else:
        assert False
def run(pre: str) -> FormatNode:
    """Parse the format string *pre* and return the resulting format tree.

    :raises FormatStringParserError:
    """
    # tokenize with lex (logging from a clone so the stream isn't consumed)
    tokenizer = build_lexer()
    tokenizer.input(pre)
    logger.debug('Lex tokens: %s', list(tokenizer.clone()))

    # build the raw syntax tree with yacc
    tree_builder = build_parser(input=pre)
    parse_tree = tree_builder.parse(lexer=tokenizer)
    logger.debug('Yacc tree: %s', parse_tree)

    # translate it into the result tree
    format_tree = analyze_parsed_node(parse_tree)
    logger.debug('abstract syntax tree: %s', format_tree)
    return format_tree