"""
the module to find minimum format trees from sample strings
この module はサンプル文字列から直接 (つまり、フォーマット文字列を用いずに) フォーマット木を推測します。利用可能なサンプル文字列の個数がひとつしかない場合での利用が想定されています。
フォーマット木に対する評価関数を固定しておき、すべてのサンプル文字列とマッチするフォーマット木の中で最小のものを求めるという形で実装されています。
たとえば
::
3
1 2
3 4 1 2
2 4 1
および
::
1
2 0 8
というサンプル文字列から
::
sequence([
item("N"),
newline(),
loop(counter="i", size="N", sequence([
item("K_i"),
loop(counter="j", size="K_i",
item("A", indices=("i", "j"))
),
newline(),
])),
])
のようなフォーマット木 (:any:`FormatNode`) を作ります。
この例の場合は
::
sequence([
item("N"),
newline(),
loop(counter="i", size="N - 1", sequence([
item("K_i"),
loop(counter="j", size="K_i - 1",
item("A", indices=("i", "j"))
),
item("B", indices="i"),
newline(),
])),
item("L"),
loop(counter="i", size="L",
item("C", indices="i")
),
newline(),
])
というフォーマット木もこれらのサンプルにマッチしますが、これは木の大きさが最小ではないので作られません。
内部のデータ構造は Haskell 風に書くと以下のような感じになります。
`LoopNode` が持つふたつの `Int` は、ループの回数を表現する変数の de Bruijn index およびその変数を修正するための -1, 0, 1 のいずれかの数です。
木の一部は構築途中である場合があります。
:: haskell
data Token
= IntToken Int
| StringToken
| NewlineToken
data Node m
= LoopNode Int Int (m (Node m)) (m (Node m))
| IntNode (m (Node m))
| StringNode (m (Node m))
| NewlineNode (m (Node m))
| EOFNode
match :: Node Maybe -> [Token] -> Maybe MatchState
...
size :: Node Maybe -> Int
size (LoopNode _ delta body next) = 1 + abs delta + size body + size next
size (IntNode next) = 1 + size next
size (StringNode next) = 1 + size next
size (NewlineNode next) = 1 + size next
size EOFNode = 1
"""
import abc
import heapq
import itertools
import string
from logging import getLogger
from typing import *
import onlinejudge_template.analyzer.node_util as node_util
from onlinejudge_template.analyzer.match import FormatMatchError, match_format
from onlinejudge_template.types import *
logger = getLogger(__name__)
class _Token(abc.ABC):
    """Base class of the tokens a sample string is split into.

    Each token records the ``row`` and ``column`` at which it was found.
    """

    row: int
    column: int

    def __init__(self, *, row: int, column: int):
        self.row, self.column = row, column

    def __repr__(self) -> str:
        kind = type(self).__name__
        return f"{kind}(L{self.row}C{self.column})"
class _IntToken(_Token):
    """A token that carries an integer ``value``."""

    value: int

    def __init__(self, *, value: int, row: int, column: int):
        self.value = value
        super().__init__(row=row, column=column)

    def __repr__(self) -> str:
        kind = type(self).__name__
        return f"{kind}(L{self.row}C{self.column}, value={self.value})"
class _StringToken(_Token):
    """A token that carries the original word as a string ``value``."""

    value: str

    def __init__(self, *, value: str, row: int, column: int):
        self.value = value
        super().__init__(row=row, column=column)
class _NewlineToken(_Token):
    """A token that marks the end of a line; it carries no value."""
class _MatchState(NamedTuple):
    """A snapshot of a matching run over one token sequence."""

    # the whole token sequence being matched
    tokens: List[_Token]
    # index of the next token to be consumed; equals len(tokens) when everything is consumed
    offset: int
    # values of the integers in scope, most recently matched first
    env: List[int]
class _MatchStop(Exception):
    """Signal that matching was interrupted; ``state`` is the state at that moment."""

    def __init__(self, state: _MatchState):
        # NOTE: the base-class initializer is intentionally not given any arguments, as before.
        self.state = state
class _Node(abc.ABC):
    """Base class of the internal tree nodes.

    These nodes are similar to FormatNode but are easier to use for optimization.
    """

    def __repr__(self) -> str:
        kind = type(self).__name__
        return f"{kind}()"
class _PlaceholderNode(_Node):
    """A hole in a tree under construction, to be replaced by a real node later."""
class _EOFNode(_Node):
    """A leaf that terminates a chain of nodes."""
class _SimpleNonLeafNode(_Node):
    """A node with exactly one successor, stored in ``next``."""

    next: _Node

    def __init__(self, *, next: _Node):
        self.next = next

    def __repr__(self) -> str:
        kind = type(self).__name__
        return f"{kind}(next={self.next})"
class _IntNode(_SimpleNonLeafNode):
    """A node that stands for a single integer item."""
class _StringNode(_SimpleNonLeafNode):
    """A node that stands for a single string item."""
class _NewlineNode(_SimpleNonLeafNode):
    """A node that stands for a line break."""
class _LoopNode(_Node):
    """A node that repeats ``body`` and then continues with ``next``.

    The number of repetitions is the value of the variable referred to by
    ``index`` plus ``delta``.
    """

    index: int  # de Bruijn index of the variable that gives the loop size
    delta: int  # correction added to that variable; one of -1, 0, 1
    body: _Node
    next: _Node

    def __init__(self, *, index: int, delta: int = 0, body: _Node, next: _Node):
        assert delta in (-1, 0, 1)
        self.index, self.delta = index, delta
        self.body, self.next = body, next

    def __repr__(self) -> str:
        kind = type(self).__name__
        return f"{kind}(index={self.index}, delta={self.delta}, body={self.body}, next={self.next})"
def get_tree_size(node: _Node) -> int:
    """Return the size of a tree, which is the cost minimized by the search.

    Every node counts 1, and a loop additionally counts ``abs(delta)`` plus
    the size of its body. Placeholders count as 1, too.
    """

    total = 0
    current = node
    # Walk along the chain of `next` links; only loop bodies need recursion.
    while True:
        if isinstance(current, (_PlaceholderNode, _EOFNode)):
            return total + 1
        if isinstance(current, _LoopNode):
            total += 1 + abs(current.delta) + get_tree_size(current.body)
        elif isinstance(current, _SimpleNonLeafNode):
            total += 1
        else:
            assert False
        current = current.next
def run_match(node: _Node, state: _MatchState) -> Optional[_MatchState]:
    """Match a (possibly incomplete) tree against tokens, starting from ``state``.

    :param node: the tree to match
    :param state: the tokens, the current offset and the integers in scope
    :returns: the state after the tree is matched, or :any:`None` if the tokens don't match the tree
    :raises _MatchStop: when a placeholder is reached. The exception holds the state at that point.
    """

    if isinstance(node, _PlaceholderNode):
        raise _MatchStop(state)

    if isinstance(node, _EOFNode):
        return state

    if isinstance(node, (_IntNode, _StringNode, _NewlineNode)):
        # These nodes consume exactly one token.
        assert 0 <= state.offset <= len(state.tokens)
        if state.offset >= len(state.tokens):
            return None
        token = state.tokens[state.offset]

        if isinstance(node, _IntNode):
            if not isinstance(token, _IntToken):
                return None
            # A matched integer becomes visible to the following nodes as a possible loop size.
            env = [token.value] + state.env
        elif isinstance(node, _StringNode):
            # An int is a str. `101` is an int but `1010100101010101010100111111101010101` may be a str. `10.1` is also a str.
            if not isinstance(token, (_StringToken, _IntToken)):
                return None
            env = state.env
        else:
            if not isinstance(token, _NewlineToken):
                return None
            env = state.env

        advanced = _MatchState(tokens=state.tokens, offset=state.offset + 1, env=env)
        return run_match(node.next, advanced)

    if isinstance(node, _LoopNode):
        assert 0 <= node.index < len(state.env)
        repeat = state.env[node.index] + node.delta
        if repeat <= 0:
            # loops of zero times cause some problems because some placeholders may be skipped
            return None
        cursor = state
        for _ in range(repeat):
            after_body = run_match(node.body, cursor)
            if after_body is None:
                return None
            # Keep the progress in tokens but forget the integers bound inside the body.
            cursor = _MatchState(tokens=cursor.tokens, offset=after_body.offset, env=cursor.env)
        return run_match(node.next, cursor)

    assert False
def count_placeholder(node: _Node) -> int:
    """Return the number of placeholders contained in a tree."""

    found = 0
    current = node
    # Walk along the chain of `next` links; only loop bodies need recursion.
    while True:
        if isinstance(current, _PlaceholderNode):
            return found + 1
        if isinstance(current, _EOFNode):
            return found
        if isinstance(current, _LoopNode):
            found += count_placeholder(current.body)
        elif not isinstance(current, _SimpleNonLeafNode):
            assert False
        current = current.next
def get_replaced_first_placeholder(node: _Node, subst: _Node) -> Optional[_Node]:
    """Return a copy of a tree whose first placeholder is replaced with ``subst``.

    The body of a loop is searched before the nodes which follow the loop.
    The given tree is not modified; the untouched subtrees are shared with the result.

    :returns: the new tree, or :any:`None` if the tree has no placeholders
    """

    if isinstance(node, _PlaceholderNode):
        return subst

    if isinstance(node, _EOFNode):
        return None

    if isinstance(node, _SimpleNonLeafNode):
        new_tail = get_replaced_first_placeholder(node.next, subst)
        return None if new_tail is None else type(node)(next=new_tail)

    if isinstance(node, _LoopNode):
        new_body = get_replaced_first_placeholder(node.body, subst)
        if new_body is not None:
            return _LoopNode(index=node.index, delta=node.delta, body=new_body, next=node.next)
        new_tail = get_replaced_first_placeholder(node.next, subst)
        if new_tail is None:
            return None
        return _LoopNode(index=node.index, delta=node.delta, body=node.body, next=new_tail)

    assert False
class _PriorityQueue:
    """A min-priority queue of nodes. Nodes of the same cost are popped in the order of insertion."""

    def __init__(self) -> None:
        # Each entry is (cost, serial number, node).
        self._heap: List[Tuple[int, int, "_Node"]] = []
        self._counter = itertools.count()

    def push(self, cost: int, node: "_Node") -> None:
        """Add a node with the given cost."""

        # The serial number is always distinct, so comparison of entries never reaches the nodes themselves.
        serial = next(self._counter)
        heapq.heappush(self._heap, (cost, serial, node))

    def pop(self) -> "_Node":
        """Remove and return the node which has the smallest cost value.

        :raises IndexError: if the queue is empty
        """

        entry = heapq.heappop(self._heap)
        return entry[2]

    def empty(self) -> bool:
        """Return whether the queue has no items."""

        return len(self._heap) == 0
def tokenize_content(content: str) -> Iterator[_Token]:
    """Split a sample string into int tokens, string tokens and newline tokens.

    A word becomes an int token only when :func:`int` accepts it and the value is
    between 0 and an upper bound derived from the number of words and lines of the
    content. Other words become string tokens. A newline token is made for each
    line which ends with ``\\n``.
    """

    lines = content.splitlines(keepends=True)
    # The int tokens are tokens which can be used as loop sizes. Only small integers satisfy this condition.
    int_max = len(content.split()) + len(lines) + 3

    for row, line in enumerate(lines):
        words = line.split()
        for column, word in enumerate(words):
            try:
                number: Optional[int] = int(word)
            except ValueError:
                number = None
            if number is not None and 0 <= number <= int_max:
                yield _IntToken(value=number, row=row, column=column)
            else:
                yield _StringToken(value=word, row=row, column=column)
        if line.endswith('\n'):  # including "\r\n"
            yield _NewlineToken(row=row, column=len(words))
def list_next_possible_node(states: List[_MatchState]) -> Iterator[_Node]:
    """List the nodes which may be put at the placeholder where the given states stopped.

    :param states: the states of all instances at the same placeholder. This must not be empty, and all states must have environments of the same length.
    :returns: an EOF node first, and then the nodes which are compatible with the next tokens of all instances. The non-leaf nodes have fresh placeholders as their children.
    """

    # validate a set of states
    assert states
    for state in states:
        assert 0 <= state.offset <= len(state.tokens)
    env_size = len(states[0].env)
    assert all(len(state.env) == env_size for state in states)

    # EOF is always a candidate. It is also what closes the body of a loop.
    yield _EOFNode()

    # When some instance has no more tokens, no nodes other than EOF can match.
    if any(state.offset == len(state.tokens) for state in states):
        return

    upcoming = [state.tokens[state.offset] for state in states]

    item_class: Type[_SimpleNonLeafNode]
    if all(isinstance(token, _IntToken) for token in upcoming):
        # when all next tokens are int tokens
        item_class = _IntNode
    elif all(isinstance(token, (_StringToken, _IntToken)) for token in upcoming):
        # when all next tokens are string tokens (an int token is acceptable as a string)
        item_class = _StringNode
    elif all(isinstance(token, _NewlineToken) for token in upcoming):
        # when all next tokens are newline tokens
        yield _NewlineNode(next=_PlaceholderNode())
        # don't yield loop node here
        return
    else:
        return

    yield item_class(next=_PlaceholderNode())
    for index in range(env_size):
        for delta in (-1, 0, 1):
            # run_match() rejects loops which iterate zero or fewer times, so a loop
            # is worth proposing only if it iterates at least once in every instance.
            if all(1 <= state.env[index] + delta for state in states):
                yield _LoopNode(index=index, delta=delta, body=item_class(next=_PlaceholderNode()), next=_PlaceholderNode())
def _construct_minimum_input_format_internal_tree(*, instances: List[List[_Token]], initial_env: Optional[List[List[int]]] = None, iteration_limit: int = 10000, size_limit: int = 20, initial_node: _Node = _PlaceholderNode()) -> Optional[_Node]:
    """Search for a tree without placeholders which matches all instances.

    Candidate trees are kept in a priority queue and taken out in order of tree size.
    A candidate which still matches every instance is extended at its first placeholder.

    :param instances: the token sequences of the samples
    :param initial_env: the integers in scope at the beginning, for each instance. An empty environment is used when omitted.
    :param iteration_limit: the limit on the number of candidates to extend
    :param size_limit: candidates larger than this size are not examined
    :param initial_node: the tree to start the search from
    :returns: the first complete tree found, or :any:`None` if the candidates or the iterations run out
    """

    queue = _PriorityQueue()
    queue.push(get_tree_size(initial_node), initial_node)

    while not queue.empty():
        candidate = queue.pop()

        # Run the candidate on every instance, and stop at the first instance which doesn't match.
        reached: List[_MatchState] = []
        for idx, tokens in enumerate(instances):
            start_env = [] if initial_env is None else initial_env[idx]
            try:
                outcome = run_match(candidate, _MatchState(tokens=tokens, offset=0, env=start_env))
            except _MatchStop as stop:
                # The matching reached a placeholder.
                outcome = stop.state
            else:
                if outcome is None:
                    break
                if outcome.offset != len(outcome.tokens):
                    break  # matching finished before EOF
            reached.append(outcome)
        if len(reached) != len(instances):
            continue

        # A tree without placeholders which consumes all tokens of all instances is the answer.
        if not count_placeholder(candidate) and all(s.offset == len(s.tokens) for s in reached):
            return candidate

        # Extend the candidate at its first placeholder.
        for replacement in list_next_possible_node(reached):
            grown = get_replaced_first_placeholder(candidate, replacement)
            assert grown is not None
            grown_size = get_tree_size(grown)
            if grown_size <= size_limit:
                queue.push(grown_size, grown)

        # timeout. This function doesn't have good time complexity, so may take too long time.
        iteration_limit -= 1
        if iteration_limit < 0:
            break

    return None
class EnvItem(NamedTuple):
    """A name in scope while internal trees are converted to format trees."""

    # the name of the variable
    name: VarName
    # True for a loop counter; False for an ordinary item
    is_counter: bool
def _convert_to_format_node(node: _Node, *, env: List[EnvItem], used: Set[VarName], fixed_names: List[VarName]) -> FormatNode:
    """Convert an internal tree into a format tree, giving names to items and loop counters.

    :param node: the internal tree. It must not contain placeholders.
    :param env: the names in scope, the innermost first. This contains both integer items and loop counters.
    :param used: the names which are already used. This set is updated in place.
    :param fixed_names: the names to use in preference to fresh names. They are taken from the end, and the list is updated in place.
    """
    def get_fresh_name() -> VarName:
        # allow using fixed name for multiple test cases
        if fixed_names:
            return fixed_names.pop()  # update the list
        for var in map(VarName, string.ascii_letters):
            if var not in used:
                return var
        assert False  # TODO: improve name assigning

    def list_indices(position: int) -> List[VarName]:
        # the counters of the loops which enclose the entry at the position, the outermost first
        return [item.name for item in reversed(env[position + 1:]) if item.is_counter]

    def get_env_position(index: int) -> int:
        # A de Bruijn index in an internal tree counts only integer items, because
        # matching never puts loop counters into its environment. Here `env` contains
        # the counters too, so they must be skipped to find the referred variable.
        remaining = index
        for position, item in enumerate(env):
            if item.is_counter:
                continue
            if remaining == 0:
                return position
            remaining -= 1
        assert False

    if isinstance(node, _EOFNode):
        return SequenceNode(items=[])

    elif isinstance(node, _IntNode) or isinstance(node, _StringNode):
        var = get_fresh_name()
        # Only integers can be referred to as loop sizes, so only they are put into the environment.
        delta: List[EnvItem] = []
        if isinstance(node, _IntNode):
            delta = [EnvItem(var, False)]
        indices = list_indices(-1)
        used.add(var)
        return SequenceNode(items=[
            ItemNode(name=var, indices=indices),
            _convert_to_format_node(node.next, env=delta + env, used=used, fixed_names=fixed_names),
        ])

    elif isinstance(node, _NewlineNode):
        return SequenceNode(items=[
            NewlineNode(),
            _convert_to_format_node(node.next, env=env, used=used, fixed_names=fixed_names),
        ])

    elif isinstance(node, _LoopNode):
        position = get_env_position(node.index)
        size = Expr(env[position].name)
        size_indices = list_indices(position)
        if size_indices:
            size = Expr(str(size) + '_{' + ','.join(size_indices) + '}')
        # The loop was matched with the value of the variable plus delta, so the size must say so.
        if node.delta > 0:
            size = Expr(f"{size} + {node.delta}")
        elif node.delta < 0:
            size = Expr(f"{size} - {-node.delta}")
        var = get_fresh_name()
        used.add(var)
        body = _convert_to_format_node(node.body, env=[EnvItem(var, True)] + env, used=used, fixed_names=fixed_names)
        # The name of the counter can be reused after the loop.
        used.remove(var)
        return SequenceNode(items=[
            LoopNode(size=size, name=var, body=body),
            _convert_to_format_node(node.next, env=env, used=used, fixed_names=fixed_names),
        ])

    elif isinstance(node, _PlaceholderNode):
        assert False

    else:
        assert False