Skip to content

Commit

Permalink
Merge pull request #383 from yyamano/issue-334
Browse files Browse the repository at this point in the history
Fixed broken test (Added option for inputting encoding of file #334)
  • Loading branch information
koxudaxi committed Sep 7, 2023
2 parents 9730a6d + 3ad6e66 commit fb88eba
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 4 deletions.
18 changes: 14 additions & 4 deletions fastapi_code_generator/__main__.py
Expand Up @@ -43,7 +43,8 @@ def dynamic_load_module(module_path: Path) -> Any:

@app.command()
def main(
input_file: typer.FileText = typer.Option(..., "--input", "-i"),
encoding: str = typer.Option("utf-8", "--encoding", "-e"),
input_file: str = typer.Option(..., "--input", "-i"),
output_dir: Path = typer.Option(..., "--output", "-o"),
model_file: str = typer.Option(None, "--model-file", "-m"),
template_dir: Optional[Path] = typer.Option(None, "--template-dir", "-t"),
Expand All @@ -57,8 +58,12 @@ def main(
),
disable_timestamp: bool = typer.Option(False, "--disable-timestamp"),
) -> None:
input_name: str = input_file.name
input_text: str = input_file.read()
input_name: str = input_file
input_text: str

with open(input_file, encoding=encoding) as f:
input_text = f.read()

if model_file:
model_path = Path(model_file).with_suffix('.py')
else:
Expand All @@ -68,6 +73,7 @@ def main(
return generate_code(
input_name,
input_text,
encoding,
output_dir,
template_dir,
model_path,
Expand All @@ -80,6 +86,7 @@ def main(
return generate_code(
input_name,
input_text,
encoding,
output_dir,
template_dir,
model_path,
Expand All @@ -103,6 +110,7 @@ def _get_most_of_reference(data_type: DataType) -> Optional[Reference]:
def generate_code(
input_name: str,
input_text: str,
encoding: str,
output_dir: Path,
template_dir: Optional[Path],
model_path: Optional[Path] = None,
Expand Down Expand Up @@ -219,7 +227,9 @@ def generate_code(
header += f"\n# timestamp: {timestamp}"

for path, code in results.items():
with output_dir.joinpath(path.with_suffix(".py")).open("wt") as file:
with output_dir.joinpath(path.with_suffix(".py")).open(
"wt", encoding=encoding
) as file:
print(header, file=file)
print("", file=file)
print(code.rstrip(), file=file)
Expand Down
8 changes: 8 additions & 0 deletions tests/test_generate.py
Expand Up @@ -22,6 +22,8 @@

SPECIFIC_TAGS = 'Wild Boars, Fat Cats'

ENCODING = 'utf-8'


@pytest.mark.parametrize(
"oas_file", (DATA_DIR / OPEN_API_DEFAULT_TEMPLATE_DIR_NAME).glob("*.yaml")
Expand All @@ -33,6 +35,7 @@ def test_generate_default_template(oas_file):
generate_code(
input_name=oas_file.name,
input_text=oas_file.read_text(),
encoding=ENCODING,
output_dir=output_dir,
template_dir=None,
)
Expand All @@ -54,6 +57,7 @@ def test_generate_custom_security_template(oas_file):
generate_code(
input_name=oas_file.name,
input_text=oas_file.read_text(),
encoding=ENCODING,
output_dir=output_dir,
template_dir=DATA_DIR / 'custom_template' / 'security',
)
Expand All @@ -79,6 +83,7 @@ def test_generate_remote_ref(mocker):
generate_code(
input_name=oas_file.name,
input_text=oas_file.read_text(),
encoding=ENCODING,
output_dir=output_dir,
template_dir=None,
)
Expand All @@ -105,6 +110,7 @@ def test_disable_timestamp(oas_file):
generate_code(
input_name=oas_file.name,
input_text=oas_file.read_text(),
encoding=ENCODING,
output_dir=output_dir,
template_dir=None,
disable_timestamp=True,
Expand All @@ -130,6 +136,7 @@ def test_generate_using_routers(oas_file):
generate_code(
input_name=oas_file.name,
input_text=oas_file.read_text(),
encoding=ENCODING,
output_dir=output_dir,
template_dir=BUILTIN_MODULAR_TEMPLATE_DIR,
generate_routers=True,
Expand Down Expand Up @@ -166,6 +173,7 @@ def test_generate_modify_specific_routers(oas_file):
generate_code(
input_name=oas_file.name,
input_text=oas_file.read_text(),
encoding=ENCODING,
output_dir=output_dir,
template_dir=BUILTIN_MODULAR_TEMPLATE_DIR,
generate_routers=True,
Expand Down

0 comments on commit fb88eba

Please sign in to comment.