"""
the module to guess simple format trees from sample strings

This module guesses typical format trees directly from sample strings (that is, without using a format string). It is intended to be used when only a single sample string is available.
It is implemented by enumerating simple format trees in advance and pattern-matching the sample string against them.

For example, from the sample string

::

    6
    1 3 8 7 10 2

it builds a format tree (:any:`FormatNode`) like

::

    sequence([
        item("N"),
        newline(),
        loop(counter="i", size="N",
            item("A", indices="i")
        ),
        newline(),
    ])
"""
import functools
import itertools
import re
from logging import getLogger
from typing import *
import onlinejudge_template.analyzer.variables
from onlinejudge_template.analyzer.match import FormatMatchError, match_format
from onlinejudge_template.types import *
logger = getLogger(__name__)
class SimplePatternMatchingError(AnalyzerError):
    """Analyzer error specific to the simple-pattern-matching analysis.

    NOTE(review): nothing in the visible part of this module raises it;
    presumably it is raised by the guessing entry point -- confirm.
    """
# simple patterns: a single line holding one to four scalar variables
# (named 'a', 'b', 'c', 'd' in that order) terminated by one newline
_simple_patterns = [
    SequenceNode(items=[*(ItemNode(name=name) for name in names), NewlineNode()])
    for names in ('a', 'ab', 'abc', 'abcd')
]
# individual names for the patterns above, shortest first
_one_pattern, _two_pattern, _three_pattern, _four_pattern = _simple_patterns
# simple patterns (vertical versions): two to four scalar variables, each on
# a line of its own; every row below corresponds to one line of the input
_vertical_two_pattern = SequenceNode(items=[
    ItemNode(name='a'), NewlineNode(),
    ItemNode(name='b'), NewlineNode(),
])
_vertical_three_pattern = SequenceNode(items=[
    ItemNode(name='a'), NewlineNode(),
    ItemNode(name='b'), NewlineNode(),
    ItemNode(name='c'), NewlineNode(),
])
_vertical_four_pattern = SequenceNode(items=[
    ItemNode(name='a'), NewlineNode(),
    ItemNode(name='b'), NewlineNode(),
    ItemNode(name='c'), NewlineNode(),
    ItemNode(name='d'), NewlineNode(),
])
# This list must hold the vertical patterns defined right above.  It used to
# repeat the horizontal ``_two_pattern`` / ``_three_pattern`` /
# ``_four_pattern``, which registered those twice and left the vertical
# patterns unused.
_vertical_simple_patterns = [
    _vertical_two_pattern,
    _vertical_three_pattern,
    _vertical_four_pattern,
]
# one vector patterns: the length ``n`` on the first line, then the vector ``a``
# (every row of a pattern below corresponds to one line of the input)

# horizontal layout: all elements of ``a`` share a single line
_length_and_vector_pattern = SequenceNode(items=[
    ItemNode(name='n'), NewlineNode(),
    LoopNode(name='i', size='n', body=ItemNode(name='a', indices=['i'])), NewlineNode(),
])

# vertical layout: each element of ``a`` is followed by its own newline
_length_and_vertical_vector_pattern = SequenceNode(items=[
    ItemNode(name='n'), NewlineNode(),
    LoopNode(name='i', size='n', body=SequenceNode(items=[ItemNode(name='a', indices=['i']), NewlineNode()])),
])

_one_vector_patterns = [_length_and_vector_pattern, _length_and_vertical_vector_pattern]
# one vector patterns (with data): like the one vector patterns, but an extra
# scalar ``k`` shares the first line with the length ``n``, either after it
# ("length data") or before it ("data length")

# "n k" header, vector on a single line
_length_data_and_vector_pattern = SequenceNode(items=[
    ItemNode(name='n'), ItemNode(name='k'), NewlineNode(),
    LoopNode(name='i', size='n', body=ItemNode(name='a', indices=['i'])), NewlineNode(),
])

# "k n" header, vector on a single line
_data_length_and_vector_pattern = SequenceNode(items=[
    ItemNode(name='k'), ItemNode(name='n'), NewlineNode(),
    LoopNode(name='i', size='n', body=ItemNode(name='a', indices=['i'])), NewlineNode(),
])

# "n k" header, one vector element per line
_length_data_and_vertical_vector_pattern = SequenceNode(items=[
    ItemNode(name='n'), ItemNode(name='k'), NewlineNode(),
    LoopNode(name='i', size='n', body=SequenceNode(items=[ItemNode(name='a', indices=['i']), NewlineNode()])),
])

# "k n" header, one vector element per line
_data_length_and_vertical_vector_pattern = SequenceNode(items=[
    ItemNode(name='k'), ItemNode(name='n'), NewlineNode(),
    LoopNode(name='i', size='n', body=SequenceNode(items=[ItemNode(name='a', indices=['i']), NewlineNode()])),
])

_one_vector_with_data_patterns = [
    _length_data_and_vector_pattern,
    _data_length_and_vector_pattern,
    _length_data_and_vertical_vector_pattern,
    _data_length_and_vertical_vector_pattern,
]
# two vectors patterns: the length ``n`` on the first line, then two vectors
# ``a`` and ``b`` that both have ``n`` elements

# row-wise layout: all of ``a`` on one line, then all of ``b`` on the next
_length_and_two_vector_pattern = SequenceNode(items=[
    ItemNode(name='n'), NewlineNode(),
    LoopNode(name='i', size='n', body=ItemNode(name='a', indices=['i'])), NewlineNode(),
    LoopNode(name='i', size='n', body=ItemNode(name='b', indices=['i'])), NewlineNode(),
])

# column-wise layout: ``n`` lines, each holding the pair ``a[i] b[i]``
_length_and_vertical_two_vector_pattern = SequenceNode(items=[
    ItemNode(name='n'), NewlineNode(),
    LoopNode(name='i', size='n', body=SequenceNode(items=[
        ItemNode(name='a', indices=['i']), ItemNode(name='b', indices=['i']), NewlineNode(),
    ])),
])

_two_vectors_patterns = [_length_and_two_vector_pattern, _length_and_vertical_two_vector_pattern]
# every pre-defined pattern in one flat list; the groups are concatenated in
# the order in which they are defined above
_all_patterns: List[FormatNode] = list(itertools.chain(
    _simple_patterns,
    _vertical_simple_patterns,
    _one_vector_patterns,
    _one_vector_with_data_patterns,
    _two_vectors_patterns,
))
@functools.lru_cache(maxsize=None)
def list_all_patterns() -> List[Tuple[FormatNode, Dict[VarName, VarDecl]]]:
    """list_all_patterns lists all pre-defined patterns.

    The result is computed on the first call and cached by ``lru_cache``, so
    every call returns the same list object; callers must not mutate it.

    :returns: one ``(pattern, variables)`` pair per entry of ``_all_patterns``, in the same order, where ``variables`` is the result of ``list_declared_variables`` for that pattern
    :raises AssertionError: if ``list_declared_variables`` rejects a pre-defined pattern, which means a pattern in this module is broken
    """

    results: List[Tuple[FormatNode, Dict[VarName, VarDecl]]] = []
    for pattern in _all_patterns:
        try:
            variables = onlinejudge_template.analyzer.variables.list_declared_variables(pattern)
        except onlinejudge_template.analyzer.variables.DeclaredVariablesError as e:
            # Raise explicitly instead of ``assert False``: asserts are
            # removed under ``python -O`` and the broken pattern would then
            # be dropped silently.
            raise AssertionError(f"invalid pre-defined pattern: {pattern}") from e
        results.append((pattern, variables))
    return results
def _rename_variables_if_conflicts_dfs(node: FormatNode, *, mapping: Dict[VarName, Expr], env: Dict[VarName, VarDecl]) -> FormatNode:
    """Rebuild ``node`` so that no item variable has a name that is a key of ``env``.

    The tree is walked depth-first in input order. An item whose name is in
    ``env`` is renamed to the name followed by the smallest positive integer
    that makes it absent from ``env``. Whole-word occurrences of renamed
    variables are then substituted in item indices and loop sizes. Loop
    counter names are kept as they are.

    :param node: the format tree to process; it is not modified
    :param mapping: renames performed so far (old name -> new name); this dict is updated in place while walking, so a rename made earlier in the tree is visible to nodes visited later
    :param env: the variables whose names must be avoided
    :returns: the rewritten tree; newline nodes and items that need no change are returned as the same objects, sequences and loops are always rebuilt
    :raises AssertionError: if ``node`` is of an unknown node type
    """
    def rename(s: str) -> str:
        # replace whole-word occurrences of every variable renamed so far
        for a, b in mapping.items():
            s = re.sub(r'\b' + re.escape(a) + r'\b', b, s)
        return s

    if isinstance(node, ItemNode):
        assert node.name not in mapping  # because there are only such patterns
        if node.name in env:
            # pick the first numeric suffix which does not clash with env
            for i in itertools.count(1):
                new_name = VarName(node.name + str(i))
                if new_name not in env:
                    mapping[node.name] = Expr(new_name)
                    break
        # Indices are rewritten even when the item itself keeps its name:
        # they may mention a variable which was renamed earlier in the tree.
        indices = list(map(rename, node.indices))
        if node.name in mapping:
            return ItemNode(name=mapping[node.name], indices=indices)
        if indices == list(node.indices):
            return node  # nothing changed; keep the original object
        return ItemNode(name=node.name, indices=indices)

    elif isinstance(node, NewlineNode):
        return node

    elif isinstance(node, SequenceNode):
        # children are processed in order, so later items see earlier renames
        items: List[FormatNode] = [_rename_variables_if_conflicts_dfs(item, mapping=mapping, env=env) for item in node.items]
        return SequenceNode(items=items)

    elif isinstance(node, LoopNode):
        body = _rename_variables_if_conflicts_dfs(node.body, mapping=mapping, env=env)
        return LoopNode(name=node.name, size=rename(node.size), body=body)

    else:
        # explicit raise instead of ``assert False``, which is removed under
        # ``python -O`` and would let the function return None
        raise AssertionError(f"unknown node type: {type(node).__name__}")
def rename_variables_if_conflicts(node: FormatNode, *, env: Dict[VarName, VarDecl]) -> FormatNode:
    """rename_variables_if_conflicts renames the item variables of a format tree which conflict with already declared variables.

    An item whose name is a key of ``env`` gets the smallest positive integer
    suffix which makes the name absent from ``env`` (e.g. ``a`` becomes ``a1``),
    and loop sizes referring to a renamed variable are updated accordingly.
    Loop counter names are not changed.

    :param node: the format tree to process; it is not modified
    :param env: the variables whose names must be avoided
    :returns: the rewritten format tree
    """

    return _rename_variables_if_conflicts_dfs(node, mapping={}, env=env)