"""
the module to generate Python code
この module は Python のコードを生成します。
以下の関数を提供します。
- :func:`read_input`
- :func:`write_output`
- :func:`declare_constants`
- :func:`formal_arguments`
- :func:`actual_arguments`
- :func:`return_type`
- :func:`return_value`
加えて、ランダムケースの生成のために、以下の関数を提供します。
- :func:`generate_input`
- :func:`write_input`
"""
from typing import *
import onlinejudge_template.generator._utils as utils
from onlinejudge_template.analyzer.simplify import simplify
from onlinejudge_template.generator._python import *
from onlinejudge_template.types import *
_DEDENT = '*DEDENT*'
def _join_with_indent(lines: Sequence[str], *, nest: int, data: Dict[str, Any]) -> str:
indent = data['config'].get('indent', ' ' * 4)
buf = []
for line in lines:
if line == _DEDENT:
nest -= 1
continue
buf.append(indent * nest + line)
if line.endswith(':'):
nest += 1
return '\n'.join(buf)
def _get_python_type(type: Optional[VarType]) -> str:
if type == VarType.IndexInt:
return "int"
elif type == VarType.ValueInt:
return "int"
elif type == VarType.Float:
return "float"
elif type == VarType.String:
return "str"
elif type == VarType.Char:
return "str"
elif type is None:
return "str"
else:
assert False
def _get_variable(*, decl: VarDecl, indices: Sequence[Expr]) -> str:
var = str(decl.name)
for index, base in zip(indices, decl.bases):
i = simplify(Expr(f"""{index} - ({base})"""))
var = f"""{var}[{i}]"""
return var
def _declare_variable(name: VarName, dims: List[Expr], *, data: Dict[str, Any]) -> Iterator[str]:
if dims:
ctor = "None"
for dim in reversed(dims):
ctor = f"""[{ctor} for _ in range({dim})]"""
yield f"""{name} = {ctor}"""
def _declare_all_possible_variables(declared: Set[VarName], initialized: Set[VarName], decls: Dict[VarName, VarDecl], data: Dict[str, Any]) -> List[PythonNode]:
"""
:param declared: updated
"""
decl_nodes: List[PythonNode] = []
for var, decl in decls.items():
if var not in declared and all([dep in initialized for dep in decl.depending]):
for line in _declare_variable(var, decl.dims, data=data):
decl_nodes.append(OtherNode(line=line))
declared.add(var)
return decl_nodes
def _declare_constant(decl: ConstantDecl, *, data: Dict[str, Any]) -> str:
if decl.type in (VarType.String, VarType.Char):
value = repr(decl.value)
else:
value = decl.value
return f"""{decl.name} = {value}"""
def _generate_input_dfs(node: FormatNode, *, declared: Set[VarName], initialized: Set[VarName], decls: Dict[VarName, VarDecl], data: Dict[str, Any]) -> PythonNode:
decl_nodes = _declare_all_possible_variables(declared=declared, initialized=initialized, decls=decls, data=data)
if decl_nodes:
return SentencesNode(sentences=decl_nodes + [_generate_input_dfs(node, declared=declared, initialized=initialized, decls=decls, data=data)])
# traverse AST
if isinstance(node, ItemNode):
var = _get_variable(decl=decls[node.name], indices=node.indices)
type_ = decls[node.name].type
initialized.add(node.name)
if type_ == VarType.IndexInt:
return OtherNode(line=f"""{var} = random.randint(1, 1000) # TODO: edit here""")
elif type_ == VarType.ValueInt:
return OtherNode(line=f"""{var} = random.randint(1, 10 ** 9) # TODO: edit here""")
elif type_ == VarType.Float:
return OtherNode(line=f"""{var} = 100.0 * random.random() # TODO: edit here""")
elif type_ == VarType.String:
return OtherNode(line=f"""{var} = ''.join([random.choice('abcde') for _ in range(random.randint(1, 100))]) # TODO: edit here""")
elif type_ == VarType.Char:
return OtherNode(line=f"""{var} = random.choice('abcde') # TODO: edit here""")
else:
return OtherNode(line=f"""{var} = random.randint(1, 10) # TODO: edit here""")
elif isinstance(node, NewlineNode):
return SentencesNode(sentences=[])
elif isinstance(node, SequenceNode):
sentences = []
for item in node.items:
sentences.append(_generate_input_dfs(item, declared=declared, initialized=initialized, decls=decls, data=data))
return SentencesNode(sentences=sentences)
elif isinstance(node, LoopNode):
declared.add(node.name)
body = _generate_input_dfs(node.body, declared=declared, initialized=initialized, decls=decls, data=data)
declared.remove(node.name)
return RangeNode(name=node.name, size=node.size, body=body)
else:
assert False
def _read_input_dfs(node: FormatNode, *, decls: Dict[VarName, VarDecl], data: Dict[str, Any]) -> PythonNode:
if isinstance(node, ItemNode):
decl = decls[node.name]
var = _get_variable(decl=decl, indices=node.indices)
return InputTokensNode(exprs=[(var, decl)])
elif isinstance(node, NewlineNode):
return InputNode(exprs=[])
elif isinstance(node, SequenceNode):
sentences = []
for item in node.items:
sentences.append(_read_input_dfs(item, decls=decls, data=data))
return SentencesNode(sentences=sentences)
elif isinstance(node, LoopNode):
return RangeNode(name=node.name, size=node.size, body=_read_input_dfs(node.body, decls=decls, data=data))
else:
assert False
def _write_output_dfs(node: FormatNode, *, decls: Dict[VarName, VarDecl], data: Dict[str, Any]) -> PythonNode:
if isinstance(node, ItemNode):
var = _get_variable(decl=decls[node.name], indices=node.indices)
return PrintTokensNode(exprs=[var])
elif isinstance(node, NewlineNode):
return PrintNode(exprs=[])
elif isinstance(node, SequenceNode):
sentences = []
for item in node.items:
sentences.append(_write_output_dfs(item, decls=decls, data=data))
return SentencesNode(sentences=sentences)
elif isinstance(node, LoopNode):
return RangeNode(name=node.name, size=node.size, body=_write_output_dfs(node.body, decls=decls, data=data))
else:
assert False
def _optimize_syntax_tree(node: PythonNode, *, data: Dict[str, Any]) -> PythonNode:
if isinstance(node, InputTokensNode):
return node
elif isinstance(node, InputNode):
return node
elif isinstance(node, PrintTokensNode):
return node
elif isinstance(node, PrintNode):
return node
elif isinstance(node, SentencesNode):
sentences: List[PythonNode] = []
que: List[PythonNode] = [_optimize_syntax_tree(sentence, data=data) for sentence in node.sentences]
while que:
sentence: PythonNode = que[0]
que = que[1:]
if sentences:
last: Optional[PythonNode] = sentences[-1]
else:
last = None
if isinstance(last, InputTokensNode) and isinstance(sentence, InputNode):
sentence = InputNode(exprs=last.exprs + sentence.exprs)
sentences.pop()
que = [sentence] + que # type: ignore
elif isinstance(last, PrintTokensNode) and isinstance(sentence, PrintNode):
sentence = PrintNode(exprs=last.exprs + sentence.exprs)
sentences.pop()
que = [sentence] + que # type: ignore
elif isinstance(last, RangeNode) and isinstance(last.body, PrintTokensNode) and len(last.body.exprs) == 1 and isinstance(sentence, PrintNode):
array = f"""*[{last.body.exprs[0]} for {last.name} in range({last.size})]"""
sentence = PrintNode(exprs=[array] + sentence.exprs)
sentences.pop()
que = [sentence] + que # type: ignore
elif isinstance(sentence, SentencesNode):
que = sentence.sentences + que
else:
sentences.append(sentence)
return SentencesNode(sentences=sentences)
elif isinstance(node, RangeNode):
return RangeNode(name=node.name, size=node.size, body=_optimize_syntax_tree(node.body, data=data))
elif isinstance(node, OtherNode):
return node
else:
assert False
def _realize_input_nodes_without_tokens(node: PythonNode, *, declared: Set[VarName], initialized: Set[VarName], decls: Dict[VarName, VarDecl], data: Dict[str, Any]) -> PythonNode:
"""
:raises TokenizedInputRequiredError:
"""
decl_nodes = _declare_all_possible_variables(declared=declared, initialized=initialized, decls=decls, data=data)
if decl_nodes:
return SentencesNode(sentences=decl_nodes + [_realize_input_nodes_without_tokens(node, declared=declared, initialized=initialized, decls=decls, data=data)])
if isinstance(node, InputTokensNode):
raise TokenizedInputRequiredError
elif isinstance(node, InputNode):
for _, decl in node.exprs:
initialized.add(decl.name)
if len(node.exprs) == 0:
return OtherNode(line="""assert input() == ''""")
elif len(node.exprs) == 1:
expr, decl = node.exprs[0]
if _get_python_type(decl.type) == "str":
return OtherNode(line=f"""{expr} = input()""")
else:
return OtherNode(line=f"""{expr} = {_get_python_type(decl.type)}(input())""")
else:
exprs = [expr for expr, _ in node.exprs]
types = [decl.type for _, decl in node.exprs]
if len(set(map(_get_python_type, types))) == 1:
type = types[0]
if _get_python_type(type) == "str":
return OtherNode(line=f"""{', '.join(exprs)} = input().split()""")
else:
return OtherNode(line=f"""{', '.join(exprs)} = map({_get_python_type(type)}, input().split())""")
else:
raise TokenizedInputRequiredError
elif isinstance(node, PrintTokensNode):
return node
elif isinstance(node, PrintNode):
return node
elif isinstance(node, SentencesNode):
sentences: List[PythonNode] = []
for sentence in node.sentences:
sentences.append(_realize_input_nodes_without_tokens(sentence, declared=declared, initialized=initialized, decls=decls, data=data))
return SentencesNode(sentences=sentences)
elif isinstance(node, RangeNode):
declared.add(node.name)
body = _realize_input_nodes_without_tokens(node.body, declared=declared, initialized=initialized, decls=decls, data=data)
declared.remove(node.name)
return RangeNode(name=node.name, size=node.size, body=body)
elif isinstance(node, OtherNode):
return node
else:
assert False
def _realize_input_nodes_with_tokens_dfs(node: PythonNode, tokens: str, *, declared: Set[VarName], initialized: Set[VarName], decls: Dict[VarName, VarDecl], data: Dict[str, Any]) -> PythonNode:
decl_nodes = _declare_all_possible_variables(declared=declared, initialized=initialized, decls=decls, data=data)
if decl_nodes:
return SentencesNode(sentences=decl_nodes + [_realize_input_nodes_with_tokens_dfs(node, tokens, declared=declared, initialized=initialized, decls=decls, data=data)])
if isinstance(node, InputTokensNode) or isinstance(node, InputNode):
for _, decl in node.exprs:
initialized.add(decl.name)
sentences: List[PythonNode] = []
for expr, decl in node.exprs:
if _get_python_type(decl.type) == "str":
node_ = OtherNode(line=f"""{expr} = next({tokens})""")
else:
node_ = OtherNode(line=f"""{expr} = {_get_python_type(decl.type)}(next({tokens}))""")
sentences.append(node_)
return SentencesNode(sentences=sentences)
elif isinstance(node, PrintTokensNode):
return node
elif isinstance(node, PrintNode):
return node
elif isinstance(node, SentencesNode):
sentences = []
for sentence in node.sentences:
sentences.append(_realize_input_nodes_with_tokens_dfs(sentence, tokens, declared=declared, initialized=initialized, decls=decls, data=data))
return SentencesNode(sentences=sentences)
elif isinstance(node, RangeNode):
declared.add(node.name)
body = _realize_input_nodes_with_tokens_dfs(node.body, tokens, declared=declared, initialized=initialized, decls=decls, data=data)
declared.remove(node.name)
return RangeNode(name=node.name, size=node.size, body=body)
elif isinstance(node, OtherNode):
return node
else:
assert False
def _realize_input_nodes_with_tokens(node: PythonNode, tokens: str, *, decls: Dict[VarName, VarDecl], data: Dict[str, Any]) -> PythonNode:
node = SentencesNode(sentences=[
OtherNode(line="""import sys"""),
OtherNode(line=f"""{tokens} = iter(sys.stdin.read().split())"""),
node,
OtherNode(line=f"""assert next({tokens}, None) is None"""),
])
return _realize_input_nodes_with_tokens_dfs(node, tokens, declared=set(), initialized=set(), decls=decls, data=data)
def _serialize_syntax_tree(node: PythonNode, *, data: Dict[str, Any]) -> Iterator[str]:
if isinstance(node, InputTokensNode):
assert False
elif isinstance(node, InputNode):
assert False
elif isinstance(node, PrintTokensNode):
if node.exprs:
yield f"""print({', '.join(node.exprs)}, end=' ')"""
elif isinstance(node, PrintNode):
yield f"""print({', '.join(node.exprs)})"""
elif isinstance(node, SentencesNode):
for item in node.sentences:
yield from _serialize_syntax_tree(item, data=data)
elif isinstance(node, RangeNode):
yield f"""for {node.name} in range({node.size}):"""
yield from _serialize_syntax_tree(node.body, data=data)
yield _DEDENT
elif isinstance(node, OtherNode):
yield node.line
else:
assert False
[docs]def write_output(data: Dict[str, Any], *, nest: int = 1) -> str:
analyzed = utils.get_analyzed(data)
if analyzed.output_format is None or analyzed.output_variables is None:
lines = [
'print(ans) # TODO: edit here',
]
return _join_with_indent(lines, nest=nest, data=data)
node = _write_output_dfs(analyzed.output_format, decls=analyzed.output_variables, data=data)
node = _optimize_syntax_tree(node, data=data)
lines = list(_serialize_syntax_tree(node, data=data))
return _join_with_indent(lines, nest=nest, data=data)
[docs]def actual_arguments(data: Dict[str, Any]) -> str:
analyzed = utils.get_analyzed(data)
if analyzed.input_format is None or analyzed.input_variables is None:
return 'n, a'
return ', '.join(analyzed.input_variables.keys())
[docs]def return_type(data: Dict[str, Any]) -> str:
analyzed = utils.get_analyzed(data)
if analyzed.output_format is None or analyzed.output_variables is None:
return 'Any'
types: List[str] = []
for decl in analyzed.output_variables.values():
type = _get_python_type(decl.type)
for _ in decl.dims:
type = f"""List[{type}]"""
types.append(type)
if len(types) == 0:
return "None"
elif len(types) == 1:
return types[0]
else:
return f"""Tuple[{", ".join(types)}]"""
[docs]def return_value(data: Dict[str, Any]) -> str:
analyzed = utils.get_analyzed(data)
if analyzed.output_format is None or analyzed.output_variables is None:
return 'ans'
return ', '.join(analyzed.output_variables.keys())
[docs]def declare_constants(data: Dict[str, Any], *, nest: int = 0) -> str:
analyzed = utils.get_analyzed(data)
lines: List[str] = []
for decl in analyzed.constants.values():
lines.append(_declare_constant(decl, data=data))
return _join_with_indent(lines, nest=nest, data=data)