"""
the module to manipulate mathematical expressions (e.g. ``2 * n + 1``, ``a_i + i``)
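This module simplifies mathematical expressions.
For example, given the expression ``(n + 1) + (n - 1)``, it returns the simplified expression ``2 * n``.
This is implemented by building something like a disjunctive normal form and sorting its terms.
At its core, it solves a competitive-programming problem like the following:

::

    You are given an expression made up of the variables x, y, z, integers, the four arithmetic operations, and parentheses.
    Find an equal expression of the form $\sum k_i x^{a_i} y^{b_i} z^{c_i}$.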
"""
# TODO: move and split this module?
import abc
import fractions
import re
from logging import getLogger
from typing import *
import ply.lex as lex
import ply.yacc as yacc
from onlinejudge_template.types import *
# Module-level logger named after this module; the functions below use it for
# debug messages when parsing, evaluating or simplifying an expression fails.
logger = getLogger(__name__)
class ExprParserError(AnalyzerError):
    """Raised when an expression cannot be tokenized, parsed, evaluated or simplified."""
class _Expr(abc.ABC):
    """Common base class of every node of a parsed expression tree."""
class _Variable(_Expr):
    """_Variable represents a symbol whose value is not fixed.

    ``name`` is the identifier and ``args`` holds the subscript expressions
    (empty for a plain, unsubscripted variable).
    """
    def __init__(self, name: str, *args: _Expr):
        self.name = name
        self.args = args

    def __eq__(self, other: object) -> bool:
        return isinstance(other, self.__class__) and self.name == other.name and self.args == other.args

    def __hash__(self) -> int:
        # Defining __eq__ alone would set __hash__ to None; keep nodes hashable
        # and consistent with the equality above.
        return hash(('_Variable', self.name, self.args))

    def __repr__(self) -> str:
        return '_Variable({})'.format(', '.join(map(repr, (self.name, *self.args))))
class _Function(_Expr):
    """_Function represents an n-ary symbol (n >= 1) whose value is fixed.

    ``value`` is one of the operator names below and ``args`` holds the operands.
    """
    ADD = '__add__'
    SUB = '__sub__'
    MUL = '__mul__'
    DIV = '__div__'
    NEG = '__neg__'

    def __init__(self, value: str, *args: _Expr):
        self.value = value
        self.args = args

    def __eq__(self, other: object) -> bool:
        return isinstance(other, self.__class__) and self.value == other.value and self.args == other.args

    def __hash__(self) -> int:
        # Defining __eq__ alone would set __hash__ to None; keep nodes hashable
        # and consistent with the equality above.
        return hash(('_Function', self.value, self.args))

    def __repr__(self) -> str:
        return '_Function({})'.format(', '.join(map(repr, (self.value, *self.args))))
class _Constant(_Expr):
    """_Constant represents a 0-ary symbol whose value is fixed.

    ``value`` is the integer the node stands for.
    """
    def __init__(self, value: int):
        self.value = value

    def __eq__(self, other: object) -> bool:
        return isinstance(other, self.__class__) and self.value == other.value

    def __hash__(self) -> int:
        # Defining __eq__ alone would set __hash__ to None; keep nodes hashable
        # and consistent with the equality above.
        return hash(('_Constant', self.value))

    def __repr__(self) -> str:
        return '_Constant({!r})'.format(self.value)
# Token names shared by the lexer and the parser built below.
_tokens = (
    # atoms
    'IDENT', 'NUMBER',
    # subscript syntax: a_i, a_{i, j}
    'UNDERSCORE', 'LBRACE', 'RBRACE', 'COMMA',
    # grouping
    'LPAREN', 'RPAREN',
    # arithmetic operators
    'ADD', 'SUB', 'MUL', 'DIV',
)
def _build_lexer() -> lex.Lexer:
    """Build a PLY lexer for the expression language.

    :raises ExprParserError: (while tokenizing) from ``t_error`` when an unexpected character is met
    """
    # NOTE: none of the locals below is referenced explicitly; ``lex.lex()`` is called
    # without arguments and presumably collects ``tokens``, ``t_ignore``, ``t_error``
    # and the ``t_*`` rules from this function's namespace (PLY convention -- confirm
    # against the PLY docs). Do not rename or remove them.
    tokens = _tokens

    # Only the space character is skipped between tokens.
    t_ignore = ' '

    def t_error(t: lex.LexToken) -> None:
        # The value reported as "column" is the raw ``t.lexpos`` of the token.
        raise ExprParserError("lexer: unexpected character: '{}' at line {} column {}".format(t.value[0], t.lineno, t.lexpos))

    # One regular expression per token name listed in ``_tokens``.
    t_IDENT = r'[A-Za-z]+'
    t_NUMBER = r'[0-9]+'
    t_UNDERSCORE = r'_'
    t_LBRACE = r'{'
    t_RBRACE = r'}'
    t_COMMA = r','
    t_LPAREN = r'\('
    t_RPAREN = r'\)'
    t_ADD = r'\+'
    t_SUB = r'-'
    t_MUL = r'\*'
    t_DIV = r'/'

    return lex.lex()
def _build_parser(*, input: str) -> yacc.LRParser:
    # Build a PLY parser that turns the token stream into an ``_Expr`` tree.
    # ``input`` is the text being parsed; it is only used to compute column
    # numbers for error messages.
    #
    # NOTE: the docstrings of the ``p_*`` functions are the grammar productions
    # consumed by ``yacc.yacc()``, not documentation -- do not edit them as prose.
    # ``tokens`` and the ``p_*`` names are presumably discovered by reflection
    # (PLY convention -- confirm against the PLY docs), so keep the names.
    tokens = _tokens

    def find_column(lexpos: int) -> int:
        # 1-based column of ``lexpos`` within its line of ``input``.
        line_start = input.rfind('\n', 0, lexpos) + 1
        return lexpos - line_start + 1

    # NOTE(review): ``loc`` is not called anywhere in this function.
    def loc(p: yacc.YaccProduction) -> Dict[str, int]:
        return {
            'line': p.lineno(1),
            'column': find_column(p.lexpos(1)),
        }

    # Additive level: left-associative ``+`` and ``-``.
    def p_expr(p: yacc.YaccProduction) -> None:
        """expr : expr ADD term
                | expr SUB term
                | term"""
        if len(p) == 4:
            op = {'+': _Function.ADD, '-': _Function.SUB}[p[2]]
            p[0] = _Function(op, p[1], p[3])
        elif len(p) == 2:
            p[0] = p[1]

    # Comma-separated list of expressions, collected into a tuple (used for subscripts).
    def p_exprs(p: yacc.YaccProduction) -> None:
        """exprs : expr COMMA exprs
                 | expr"""
        if len(p) == 4:
            p[0] = (p[1], *p[3])
        elif len(p) == 2:
            p[0] = (p[1], )

    # Multiplicative level: ``*``, ``/`` and unary minus.
    def p_term(p: yacc.YaccProduction) -> None:
        """term : term MUL factor
                | term DIV factor
                | SUB term
                | factor"""
        if len(p) == 4:
            op = {'*': _Function.MUL, '/': _Function.DIV}[p[2]]
            p[0] = _Function(op, p[1], p[3])
        elif len(p) == 3:
            p[0] = _Function(_Function.NEG, p[2])
        elif len(p) == 2:
            p[0] = p[1]

    # Atoms and parenthesized expressions. A number directly followed by a
    # variable or by a parenthesized expression is an implicit multiplication.
    def p_factor(p: yacc.YaccProduction) -> None:
        """factor : number variable
                  | number
                  | variable
                  | LPAREN expr RPAREN
                  | number LPAREN expr RPAREN"""
        if len(p) == 3:
            p[0] = _Function(_Function.MUL, p[1], p[2])
        elif len(p) == 2:
            p[0] = p[1]
        elif len(p) == 4:
            p[0] = p[2]
        elif len(p) == 5:
            p[0] = _Function(_Function.MUL, p[1], p[3])

    def p_number(p: yacc.YaccProduction) -> None:
        """number : NUMBER"""
        p[0] = _Constant(int(p[1]))

    # Plain or subscripted variable: ``a``, ``a_i``, ``a_0`` or ``a_{i, j + 1}``.
    def p_variable(p: yacc.YaccProduction) -> None:
        """variable : IDENT UNDERSCORE IDENT
                    | IDENT UNDERSCORE number
                    | IDENT UNDERSCORE LBRACE exprs RBRACE
                    | IDENT"""
        if len(p) == 4 and isinstance(p[3], str):
            # The subscript is a bare identifier (still a string at this point).
            p[0] = _Variable(p[1], _Variable(p[3]))
        elif len(p) == 4 and isinstance(p[3], _Expr):
            # The subscript is a number, already reduced to a ``_Constant``.
            p[0] = _Variable(p[1], p[3])
        elif len(p) == 6:
            p[0] = _Variable(p[1], *p[4])
        elif len(p) == 2:
            p[0] = _Variable(p[1])

    def p_error(t: Optional[lex.LexToken]) -> None:
        # ``t`` may be None (see the annotation); in that case no position can be reported.
        if t is None:
            raise ExprParserError("parser: something wrong")
        else:
            raise ExprParserError("parser: unexpected token: {} \"{}\" at line {} column {}".format(t.type, t.value, t.lineno, find_column(t.lexpos)))

    return yacc.yacc(debug=False, write_tables=False)
def _parse(s: str) -> _Expr:
    """Parse the expression text ``s`` into an expression tree.

    A lexer and a parser are built for each call. A failure is logged at debug
    level and then re-raised.

    :raises ExprParserError:
    """
    try:
        tokenizer = _build_lexer()
        tokenizer.input(s)
        tree = _build_parser(input=s).parse(lexer=tokenizer)
    except ExprParserError as err:
        logger.debug('failed to parse {}: {}'.format(repr(s), err))
        raise
    return tree
def _format(e: _Expr) -> str:
    """Render an expression tree as text that parses back to an equal expression.

    Parentheses are emitted only where leaving them out would change the meaning
    or make the text unparseable. In particular the right operand of ``-`` and
    ``/`` keeps its parentheses (``a - (b + c)``, ``a / (b * c)``), and a single
    subscript is written without braces only when it is a plain name or a
    non-negative number (``a_i``, ``a_0``, but ``a_{2 * i}`` and ``a_{b_i}``).
    """
    # Binding strength from loosest to tightest. This mirrors the grammar in
    # _build_parser, where a unary minus applies to a whole product.
    additive, unary, multiplicative, atom = 1, 2, 3, 4

    # operator -> (symbol, own strength, minimum strength of the left operand,
    #              minimum strength of the right operand)
    binary_ops = {
        _Function.ADD: (' + ', additive, additive, additive),
        # "a - (b + c)" must keep its parentheses.
        _Function.SUB: (' - ', additive, additive, unary),
        # A product cannot start or continue with a bare unary minus ("a * - b").
        _Function.MUL: (' * ', multiplicative, multiplicative, multiplicative),
        # "a / (b * c)" and "a / (b / c)" must keep their parentheses.
        _Function.DIV: (' / ', multiplicative, multiplicative, atom),
    }

    def with_paren(s: str, *, cur_prec: int, min_prec: int) -> str:
        return s if cur_prec >= min_prec else '(' + s + ')'

    def subscript(arg: _Expr) -> str:
        text = go(arg, prec=0)
        is_plain_number = isinstance(arg, _Constant) and arg.value >= 0
        is_plain_name = isinstance(arg, _Variable) and not arg.args
        if is_plain_number or is_plain_name:
            return text
        # Anything else would be split by the parser ("a_2 * i" is "(a_2) * i").
        return '{' + text + '}'

    def go(e: _Expr, *, prec: int) -> str:
        if isinstance(e, _Variable):
            if not e.args:
                return e.name
            if len(e.args) == 1:
                return e.name + '_' + subscript(e.args[0])
            return e.name + '_{' + ', '.join(go(arg, prec=0) for arg in e.args) + '}'
        if isinstance(e, _Constant):
            # A negative literal is read back as a unary minus, so it binds like one.
            return with_paren(str(e.value), cur_prec=atom if e.value >= 0 else unary, min_prec=prec)
        if isinstance(e, _Function):
            if e.value in binary_ops and len(e.args) == 2:
                symbol, cur_prec, left_prec, right_prec = binary_ops[e.value]
                text = go(e.args[0], prec=left_prec) + symbol + go(e.args[1], prec=right_prec)
                return with_paren(text, cur_prec=cur_prec, min_prec=prec)
            if e.value == _Function.NEG and len(e.args) == 1:
                return with_paren('- ' + go(e.args[0], prec=unary), cur_prec=unary, min_prec=prec)
            raise AssertionError('unknown function: {!r}'.format(e.value))
        raise AssertionError('unknown expression node: {!r}'.format(e))

    return go(e, prec=0)
def _get_subscripted_value(value: Union[int, List[int], List[List[int]], List[List[List[int]]]], args: List[int], *, name_for_error_message: str) -> int:
    """Return the integer obtained by subscripting ``value`` with each index of ``args`` in order.

    :param value: an integer or a (nested) list of integers
    :param args: the indices to apply, outermost first
    :param name_for_error_message: the variable name used in error messages
    :raises ExprParserError: if ``value`` has fewer or more dimensions than ``args`` has indices, or if an index is out of range
    """
    result: Any = value
    for depth, index in enumerate(args):
        if not isinstance(result, list):
            # ``depth`` indices have been consumed so far, so the value takes ``depth`` indices.
            raise ExprParserError('{} is expected to have type int^{} -> int, but actually has type int^{} -> {}'.format(name_for_error_message, len(args), depth, type(result).__name__))
        try:
            result = result[index]
        except IndexError as e:
            # Report it like every other failure so that callers only need to catch ExprParserError.
            raise ExprParserError('index {} is out of range for {} (length {} at depth {})'.format(index, name_for_error_message, len(result), depth)) from e
    if not isinstance(result, int):
        raise ExprParserError('{} is expected to have type int^{} -> int, but actually has type int^{} -> {}'.format(name_for_error_message, len(args), len(args), type(result).__name__))
    return result
def evaluate(s: Expr, *, env: Mapping[VarName, Union[int, List[int], List[List[int]], List[List[List[int]]]]] = {}) -> Optional[int]:
    """evaluate converts the given expr to an integer.

    :param s: the expression text
    :param env: the values of the variables; a subscripted variable is looked up in a (nested) list. It is only read, never modified.
    :return: the integer value, or ``None`` when the expression cannot be parsed, cannot be evaluated (undefined variable, bad subscript, division by zero) or does not evaluate to an integer
    """
    def go(e: _Expr) -> fractions.Fraction:
        if isinstance(e, _Variable):
            if e.name not in env:
                raise ExprParserError('{} is not defined'.format(e.name))
            subscripts: List[fractions.Fraction] = list(map(go, e.args))
            indices: List[int] = []
            for subscript in subscripts:
                if subscript.denominator != 1:
                    raise ExprParserError('indices must be an integer, not fraction: {}[{}]'.format(e.name, ', '.join(map(str, subscripts))))
                indices.append(subscript.numerator)
            try:
                value = _get_subscripted_value(env[VarName(e.name)], indices, name_for_error_message=e.name)
            except IndexError as err:
                # An out-of-range subscript is an evaluation failure, not a crash.
                raise ExprParserError('index out of range: {}[{}]'.format(e.name, ', '.join(map(str, indices)))) from err
            return fractions.Fraction(value)
        elif isinstance(e, _Function):
            operands = list(map(go, e.args))
            if e.value == _Function.ADD and len(e.args) == 2:
                return operands[0] + operands[1]
            elif e.value == _Function.SUB and len(e.args) == 2:
                return operands[0] - operands[1]
            elif e.value == _Function.MUL and len(e.args) == 2:
                return operands[0] * operands[1]
            elif e.value == _Function.DIV and len(e.args) == 2:
                if operands[1] == 0:
                    # Fraction division would raise ZeroDivisionError, which nobody catches.
                    raise ExprParserError('division by zero')
                return operands[0] / operands[1]
            elif e.value == _Function.NEG and len(e.args) == 1:
                return -operands[0]
            else:
                assert False
        elif isinstance(e, _Constant):
            return fractions.Fraction(e.value)
        else:
            assert False

    try:
        expr = _parse(s)
    except ExprParserError as e:
        logger.debug('failed to parse {}: {}'.format(repr(s), e))
        return None
    try:
        evaluated = go(expr)
    except ExprParserError as e:
        logger.debug('failed to evaluate {} in {}: {}'.format(repr(s), env, e))
        return None
    if evaluated.denominator != 1:
        logger.debug('failed to evaluate {} in {}: {} is not an integer'.format(repr(s), env, evaluated))
        return None
    return evaluated.numerator
def _convert_to_dnf(e: _Expr) -> List[Tuple[List[_Expr], List[_Expr]]]:
    """_convert_to_dnf converts exprs to a format like disjunction normal form (DNF).

    The result is a sum of terms. Each term is a pair ``(numerator, denominator)``
    of factor lists and stands for the product of the numerator factors divided
    by the product of the denominator factors. Factors are variables or constants.
    This also simplifies the subscripted exprs.

    :raises ExprParserError: if a subscript cannot be simplified, or if the divisor of a division is still a sum of several terms after simplification
    """
    # TODO: Rename variables. We can see $a_i + 2 a_{i + 1} + 3 a_{i + 2}$ as $x + 2 y + 3 z$ or just $v_0 + 2 v_1 + 3 v_2$. Replacing the variables with their indices will make the implementation simpler.
    if isinstance(e, _Variable):
        args = list(map(_simplify_expr, e.args))
        return [([_Variable(e.name, *args)], [])]
    elif isinstance(e, _Function):
        if e.value == _Function.ADD and len(e.args) == 2:
            lhs = _convert_to_dnf(e.args[0])
            rhs = _convert_to_dnf(e.args[1])
            return lhs + rhs
        elif e.value == _Function.SUB and len(e.args) == 2:
            lhs = _convert_to_dnf(e.args[0])
            rhs = _convert_to_dnf(e.args[1])
            return lhs + [([_Constant(-1), *num], den) for num, den in rhs]
        elif e.value == _Function.MUL and len(e.args) == 2:
            lhs = _convert_to_dnf(e.args[0])
            rhs = _convert_to_dnf(e.args[1])
            # Distribute the product over both sums.
            return [(num1 + num2, den1 + den2) for num1, den1 in lhs for num2, den2 in rhs]
        elif e.value == _Function.DIV and len(e.args) == 2:
            lhs = _convert_to_dnf(e.args[0])
            rhs = _convert_to_dnf(e.args[1])
            if len(rhs) != 1:
                # Division does not distribute over a sum in the divisor:
                # 1 / (1 + 1) is not 1 / 1 + 1 / 1. Try to collapse the divisor
                # into a single term first (e.g. 1 + 1 becomes 2).
                rhs = _convert_to_dnf(_simplify_expr(e.args[1]))
                if len(rhs) != 1:
                    raise ExprParserError('cannot divide by a sum of terms: {}'.format(_format(e.args[1])))
            (num2, den2), = rhs
            # Dividing by num2 / den2 is multiplying by den2 / num2.
            return [(num1 + den2, den1 + num2) for num1, den1 in lhs]
        elif e.value == _Function.NEG and len(e.args) == 1:
            rhs = _convert_to_dnf(e.args[0])
            return [([_Constant(-1), *num], den) for num, den in rhs]
        else:
            assert False
    elif isinstance(e, _Constant):
        return [([e], [])]
    else:
        assert False
def _simplify_dnf(dnf: List[Tuple[List[_Expr], List[_Expr]]]) -> Dict[Tuple[Tuple[str, ...], Tuple[str, ...]], fractions.Fraction]:
    """Merge like terms of a sum of products.

    Each term is reduced to a rational coefficient plus the sorted formatted
    names of its remaining numerator and denominator variables; a variable that
    occurs in both is cancelled. Coefficients of terms sharing the same variables
    are added together.

    :raises ExprParserError: if a term divides by the constant zero
    """
    totals: Dict[Tuple[Tuple[str, ...], Tuple[str, ...]], fractions.Fraction] = {}
    for numerators, denominators in dnf:
        scale = fractions.Fraction(1)
        upper: List[str] = []
        lower: List[str] = []

        # fold constants into the coefficient, keep variables by name
        for factor in numerators:
            if isinstance(factor, _Constant):
                scale *= factor.value
            elif isinstance(factor, _Variable):
                upper.append(_format(factor))
            else:
                assert False

        for factor in denominators:
            if isinstance(factor, _Constant):
                if factor.value == 0:
                    raise ExprParserError('division by zero')
                scale /= factor.value
            elif isinstance(factor, _Variable):
                label = _format(factor)
                try:
                    # cancel against one matching numerator variable, if any
                    upper.remove(label)
                except ValueError:
                    lower.append(label)
            else:
                assert False

        # accumulate the coefficient under the canonical (sorted) key
        key = (tuple(sorted(upper)), tuple(sorted(lower)))
        totals[key] = totals.get(key, fractions.Fraction(0)) + scale
    return totals
def _convert_from_dnf(freq: Dict[Tuple[Tuple[str, ...], Tuple[str, ...]], fractions.Fraction]) -> _Expr:
    """Turn merged terms back into an expression tree.

    The terms are written out as text in sorted order, with the constant term
    moved to the end and zero terms dropped, and the text is parsed again.

    :raises ExprParserError: if a coefficient is not an integer
    """
    ordered = sorted(freq.items())
    if ordered:
        (first_num, first_den), _ = ordered[0]
        if not first_num and not first_den:
            # move the constant factor to the back
            ordered = ordered[1:] + ordered[:1]

    pieces: List[str] = []
    for (num, den), coeff in ordered:
        if coeff == 0:
            continue
        if coeff.denominator != 1:
            raise ExprParserError('coefficient is not an integer: {} \\prod {} / \\prod {}'.format(coeff, num, den))
        negative = coeff < 0
        magnitude = -coeff if negative else coeff

        # construct a string for a factor
        parts: List[str] = []
        if magnitude != 1:
            parts.append(str(magnitude.numerator))
        parts.extend('(' + name + ')' for name in num)
        term = ' * '.join(parts)
        if not term:
            term = '1'
        term += ''.join(' / (' + name + ')' for name in den)
        if negative:
            term = '- ' + term
        if pieces and not term.startswith('-'):
            term = '+ ' + term
        pieces.append(term)

    return _parse(' '.join(pieces) or '0')
def _simplify_expr(e: _Expr) -> _Expr:
    """Simplify an expression tree: expand it into a sum of terms, merge like terms and rebuild the tree.

    :raises ExprParserError:
    """
    terms = _convert_to_dnf(e)
    merged = _simplify_dnf(terms)
    return _convert_from_dnf(merged)
def simplify(s: Expr) -> Expr:
    """simplify converts the given expr to a simple expr.

    :param s: the expression text
    :return: the simplified expression, or ``s`` itself unchanged when it cannot be parsed or simplified (the failure is logged at debug level)
    """
    try:
        expr = _parse(s)
    except ExprParserError as e:
        logger.debug('failed to parse %s: %s', repr(s), e)
        return s
    try:
        simplified = _simplify_expr(expr)
    except ExprParserError as e:
        logger.debug('failed to simplify %s: %s', repr(s), e)
        return s
    return Expr(_format(simplified))
def parse_subscripted_variable(s: str) -> Tuple[str, List[str]]:
    """parse_subscripted_variable is an inverse of format_subscripted_variable.

    :param s: the text of a variable, possibly with subscripts
    :return: the variable name and the text of each of its subscripts
    :raises ExprParserError: if not a variable
    """
    expr = _parse(s)
    if not isinstance(expr, _Variable):
        raise ExprParserError('not a subscripted variable: {}'.format(s))
    return expr.name, [_format(arg) for arg in expr.args]
def rename_variables_in_expr(expr: Expr, *, replace: Dict[VarName, VarName]) -> Expr:
    """Rename the variables of an expression and return it re-formatted.

    :param expr: the expression text
    :param replace: maps old variable names to new ones; names that are not keys are kept
    :raises ExprParserError: if the renamed text cannot be parsed
    """
    def rename(match: 're.Match[str]') -> str:
        name = match.group()
        return str(replace.get(VarName(name), name))

    # Every maximal run of ASCII letters is one identifier for the lexer.
    renamed = re.sub(r'[A-Za-z]+', rename, str(expr))
    return Expr(_format(_parse(renamed)))