# Source code for onlinejudge_template.analyzer.node_util

import string
from typing import *

import onlinejudge_template.analyzer.simplify as simplify
from onlinejudge_template.types import *


def remove_superfluous_sequence_nodes(node: FormatNode) -> FormatNode:
    """Flatten nested ``SequenceNode``s and unwrap single-element sequences.

    The tree is rebuilt recursively: a child of a ``SequenceNode`` that is
    itself a ``SequenceNode`` is spliced into its parent, and a sequence left
    with exactly one child is replaced by that child.  ``LoopNode`` bodies are
    processed recursively; ``ItemNode`` and ``NewlineNode`` are returned as-is.

    :param node: root of the format tree to normalize.
    :returns: the normalized tree.
    :raises AssertionError: if ``node`` is not one of the known node types.
    """
    if isinstance(node, ItemNode):
        return node
    elif isinstance(node, NewlineNode):
        return node
    elif isinstance(node, SequenceNode):
        items = []
        for item in node.items:
            item = remove_superfluous_sequence_nodes(item)
            # the recursive call already flattened ``item``, so splicing one level is enough
            if isinstance(item, SequenceNode):
                items.extend(item.items)
            else:
                items.append(item)
        if len(items) == 1:
            return items[0]
        return SequenceNode(items=items)
    elif isinstance(node, LoopNode):
        return LoopNode(size=node.size, name=node.name, body=remove_superfluous_sequence_nodes(node.body))
    else:
        # raise explicitly: ``assert False`` is stripped under ``python -O`` and would return None
        raise AssertionError(f'unexpected node type: {type(node).__name__}')
def _find_unused_name(letters: str, prefix: str, *, used: Set[VarName]) -> VarName:
    """Return the first candidate name that is not in ``used``.

    Candidates are each character of ``letters`` in order, followed by
    ``prefix`` + three uppercase ASCII letters (``prefix + 'AAA'``,
    ``prefix + 'AAB'``, ...).

    :raises AssertionError: if every candidate is already taken.
    """
    for c in map(VarName, letters):
        if c not in used:
            return c
    for c1 in string.ascii_uppercase:
        for c2 in string.ascii_uppercase:
            for c3 in string.ascii_uppercase:
                s = VarName(prefix + c1 + c2 + c3)
                if s not in used:
                    return s
    # raise explicitly: ``assert False`` is stripped under ``python -O`` and would return None
    raise AssertionError('no unused name is available')


def _get_nice_variable_name(*, used: Set[VarName]) -> VarName:
    """Return an unused name for a data variable.

    Single letters are tried first, skipping ``i``-``l`` which
    ``_get_nice_counter_name`` hands out to loop counters; then ``aAAA``,
    ``aAAB``, ... are tried.
    """
    return _find_unused_name('abcdefgh' + 'mnopqrstuvwxyz', 'a', used=used)


def _get_nice_counter_name(*, used: Set[VarName]) -> VarName:
    """Return an unused name for a loop counter: ``i``, ``j``, ``k``, ``l``, then ``iAAA``, ``iAAB``, ..."""
    return _find_unused_name('ijkl', 'i', used=used)


# A variable already named `testcases` (the number of test cases) keeps that name instead of being renamed.
testcases_varname: VarName = VarName('testcases')


def _rename_variable_nicely_dfs(node: FormatNode, *, replace: Dict[VarName, VarName], used: Set[VarName]) -> FormatNode:
    """Rebuild ``node`` with every variable and loop counter given a short name.

    :param node: subtree to rename.
    :param replace: original name -> new name for every name currently bound;
        updated in place as items and loop counters are visited.
    :param used: new names currently taken; updated in place.  Loop-counter
        names are released when their loop ends (so sibling loops can reuse
        them), while data-variable names stay taken for the rest of the walk.
    :returns: the renamed subtree.
    :raises AssertionError: on an unknown node type.
    """
    if isinstance(node, ItemNode):
        # keep `testcases` only when that name is still free, so a reserved name is never reused
        if node.name == testcases_varname and testcases_varname not in used:
            name = testcases_varname
        else:
            name = _get_nice_variable_name(used=used)
        # indices may only mention names bound so far, so rename them before binding this item
        indices = [simplify.rename_variables_in_expr(index, replace=replace) for index in node.indices]
        assert node.name not in replace, f'name {node.name} is already bound'
        replace[node.name] = name
        used.add(name)
        return ItemNode(name=name, indices=indices)
    elif isinstance(node, NewlineNode):
        return NewlineNode()
    elif isinstance(node, SequenceNode):
        # children are visited left to right so earlier items are bound before later ones
        items = [_rename_variable_nicely_dfs(item, replace=replace, used=used) for item in node.items]
        return SequenceNode(items=items)
    elif isinstance(node, LoopNode):
        name = _get_nice_counter_name(used=used)
        # the size is renamed before the counter is bound, so it cannot refer to the counter
        size = simplify.rename_variables_in_expr(node.size, replace=replace)
        assert node.name not in replace, f'name {node.name} is already bound'
        replace[node.name] = name
        used.add(name)
        body = _rename_variable_nicely_dfs(node.body, replace=replace, used=used)
        # the counter goes out of scope after the loop; free its name for later loops
        used.remove(name)
        replace.pop(node.name)
        return LoopNode(size=size, name=name, body=body)
    else:
        # raise explicitly: ``assert False`` is stripped under ``python -O`` and would return None
        raise AssertionError(f'unexpected node type: {type(node).__name__}')
def rename_variable_nicely(node: FormatNode, *, used: Optional[Set[VarName]] = None) -> FormatNode:
    """Return a copy of ``node`` whose variables and loop counters are renamed to short names.

    The renaming itself is done by ``_rename_variable_nicely_dfs``.

    :param node: root of the format tree to rename.
    :param used: names to treat as already taken when choosing new names.
        The set is copied, so the caller's set is never modified.
    :returns: the renamed tree.
    """
    # copy instead of ``used or set()``: the old form mutated the caller's set
    # only when it happened to be non-empty
    reserved: Set[VarName] = set(used) if used is not None else set()
    return _rename_variable_nicely_dfs(node, replace={}, used=reserved)