Source code for onlinejudge_template.analyzer.output_types

"""
the module to analyze output types

この module は出力形式のパターンを分析し簡単化します。

たとえば
::

    N
    A_1 A_2 ... A_N

のような出力フォーマットの場合、戻り値を `std::pair<int, std::vector<int> >` とするのでなく、戻り値を `std::vector<int>` のみとし `int N = A.size();` として `N` を復元するようにした方が便利です。この module このような最適化を目的としています。
"""

from typing import *

from onlinejudge_template.analyzer.simplify import simplify
from onlinejudge_template.types import *


# TODO: remove this
def _get_variable_on_code(*, decl: VarDecl, indices: List[Expr], decls: Dict[VarName, VarDecl]) -> str:
    var = str(decl.name)
    for index, base in zip(indices, decl.bases):
        i = simplify(Expr(f"""{index} - ({base})"""))
        var = f"""{var}[{i}]"""
    return var


[docs]def match_indices(*, indices: List[Expr], names: List[str]) -> bool: if len(indices) != len(names): return False for index, name in zip(indices, names): if index not in (name + ' - 1', name, name + ' + 1'): return False return True
[docs]def analyze_output_type(*, output_format: FormatNode, output_variables: Dict[VarName, VarDecl], constants: Dict[VarName, ConstantDecl]) -> Optional[OutputType]: node = output_format decls = output_variables # pattern: # ans if isinstance(node, SequenceNode) and len(node.items) == 2: item0 = node.items[0] item1 = node.items[1] if isinstance(item0, ItemNode) and isinstance(item1, NewlineNode): type = decls[item0.name].type if type is not None: if 'YES' in constants and 'NO' in constants and type == VarType.String: return YesNoOutputType(name=Expr('ans'), yes='YES', no='NO') if 'FIRST' in constants and 'SECOND' in constants and type == VarType.String: return YesNoOutputType(name=Expr('ans'), yes='FIRST', no='SECOND') return OneOutputType(name=Expr('ans'), type=type) # pattern: # x y if isinstance(node, SequenceNode) and len(node.items) == 3: item0 = node.items[0] item1 = node.items[1] item2 = node.items[2] if isinstance(item0, ItemNode) and isinstance(item1, ItemNode) and isinstance(item2, NewlineNode): name1 = item0.name name2 = item1.name type1 = decls[name1].type type2 = decls[name2].type return TwoOutputType(name1=Expr(name1), type1=type1, name2=Expr(name2), type2=type2, print_newline_after_item=False) # pattern: # x # y if isinstance(node, SequenceNode) and len(node.items) == 4: item0 = node.items[0] item1 = node.items[1] item2 = node.items[2] item3 = node.items[3] if isinstance(item0, ItemNode) and isinstance(item1, NewlineNode) and isinstance(item2, ItemNode) and isinstance(item3, NewlineNode): name1 = item0.name name2 = item2.name type1 = decls[name1].type type2 = decls[name2].type return TwoOutputType(name1=Expr(name1), type1=type1, name2=Expr(name2), type2=type2, print_newline_after_item=False) # pattern: # a_1 ... a_n if isinstance(node, SequenceNode) and len(node.items) == 2: item0 = node.items[0] item1 = node.items[1] if isinstance(item1, NewlineNode): if isinstance(item0, LoopNode) and isinstance(item0.body, ItemNode) and match_indices(indices=item0.body.indices, names=[item0.name]): type = decls[item0.body.name].type subscripted_name = _get_variable_on_code(decl=decls[item0.body.name], indices=item0.body.indices, decls=decls) counter_name = item0.name return VectorOutputType(name=VarName('ans'), type=type, subscripted_name=subscripted_name, counter_name=counter_name, print_size=False, print_newline_after_size=False, print_newline_after_item=False) # pattern: # n # a_1 ... a_n if isinstance(node, SequenceNode) and len(node.items) == 4: item0 = node.items[0] item1 = node.items[1] item2 = node.items[2] item3 = node.items[3] if isinstance(item0, ItemNode) and isinstance(item1, NewlineNode) and isinstance(item3, NewlineNode): if isinstance(item2, LoopNode) and isinstance(item2.body, ItemNode) and item2.size == item0.name and match_indices(indices=item2.body.indices, names=[item2.name]): type = decls[item2.body.name].type subscripted_name = _get_variable_on_code(decl=decls[item2.body.name], indices=item2.body.indices, decls=decls) counter_name = item2.name return VectorOutputType(name=VarName('ans'), type=type, subscripted_name=subscripted_name, counter_name=counter_name, print_size=True, print_newline_after_size=True, print_newline_after_item=False) # pattern: # n a_1 ... a_n if isinstance(node, SequenceNode) and len(node.items) == 3: item0 = node.items[0] item1 = node.items[1] item2 = node.items[2] if isinstance(item0, ItemNode) and isinstance(item2, NewlineNode): if isinstance(item1, LoopNode) and isinstance(item1.body, ItemNode) and item1.size == item0.name and match_indices(indices=item1.body.indices, names=[item1.name]): type = decls[item1.body.name].type subscripted_name = _get_variable_on_code(decl=decls[item1.body.name], indices=item1.body.indices, decls=decls) counter_name = item1.name return VectorOutputType(name=VarName('ans'), type=type, subscripted_name=subscripted_name, counter_name=counter_name, print_size=True, print_newline_after_size=False, print_newline_after_item=False) # pattern: # a_1 # ... # a_n if isinstance(node, SequenceNode) and len(node.items) == 3: item0 = node.items[0] if isinstance(item0, LoopNode) and isinstance(item0.body, SequenceNode) and len(item0.body.items) == 2: item1 = node.items[0] item2 = node.items[1] if isinstance(item1, ItemNode) and isinstance(item2, NewlineNode) and match_indices(indices=item1.indices, names=[item0.name]): type = decls[item1.name].type subscripted_name = _get_variable_on_code(decl=decls[item1.name], indices=item1.indices, decls=decls) counter_name = item0.name return VectorOutputType(name=VarName('ans'), type=type, subscripted_name=subscripted_name, counter_name=counter_name, print_size=False, print_newline_after_size=False, print_newline_after_item=True) # pattern: # n # a_1 # ... # a_n if isinstance(node, SequenceNode) and len(node.items) == 3: item0 = node.items[0] item1 = node.items[1] item2 = node.items[2] if isinstance(item0, ItemNode) and isinstance(item1, NewlineNode): if isinstance(item2, LoopNode) and isinstance(item2.body, SequenceNode) and len(item2.body.items) == 2: item3 = node.items[0] item4 = node.items[1] if isinstance(item3, ItemNode) and isinstance(item4, NewlineNode) and item2.size == item0.name and match_indices(indices=item3.indices, names=[item2.name]): type = decls[item3.name].type subscripted_name = _get_variable_on_code(decl=decls[item3.name], indices=item3.indices, decls=decls) counter_name = item2.name return VectorOutputType(name=VarName('ans'), type=type, subscripted_name=subscripted_name, counter_name=counter_name, print_size=True, print_newline_after_size=True, print_newline_after_item=True) return None