Skip to content

Commit

Permalink
Merge pull request #408 from airtai/master
Browse files Browse the repository at this point in the history
Update pydantic to v2 and update datamodel-code-generator to 0.25.6
  • Loading branch information
koxudaxi committed May 2, 2024
2 parents fb88eba + f11f922 commit e8e381f
Show file tree
Hide file tree
Showing 30 changed files with 740 additions and 960 deletions.
25 changes: 19 additions & 6 deletions .github/workflows/test.yml
Expand Up @@ -10,23 +10,24 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: [3.7, 3.8, 3.9]
python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
os: [ubuntu-latest, windows-latest, macos-latest]

steps:
- uses: actions/checkout@v1
- uses: actions/cache@v1
- uses: actions/checkout@v4
- uses: actions/cache@v4
with:
path: ~/.cache/pip
key: ${{ runner.os }}-pip-${{ hashFiles('**/poetry.lock') }}
restore-keys: |
${{ runner.os }}-pip-
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v1
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip --disable-pip-version-check
python -m pip install poetry
poetry install
- name: Lint
Expand All @@ -43,9 +44,21 @@ jobs:
./scripts/poetry_test.bat
- name: Upload coverage to Codecov
if: matrix.os == 'ubuntu-latest'
uses: codecov/codecov-action@v1
uses: codecov/codecov-action@v4
with:
token: ${{ secrets.CODECOV_TOKEN }}
file: ./coverage.xml
flags: unittests
# fail_ci_if_error: true
check: # This job does nothing and is only used for the branch protection
if: github.event.pull_request.draft == false

needs:
- test

runs-on: ubuntu-latest

steps:
- name: Decide whether the needed jobs succeeded or failed
uses: re-actors/alls-green@release/v1 # nosemgrep
with:
jobs: ${{ toJSON(needs) }}
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
@@ -1,12 +1,12 @@
repos:
- repo: https://github.com/psf/black
rev: 23.7.0
rev: 24.4.2
hooks:
- id: black
files: "^fastapi_code_generator|^tests"
exclude: "^tests/data"
- repo: https://github.com/pycqa/isort
rev: 5.12.0
rev: 5.13.2
hooks:
- id: isort
files: "^fastapi_code_generator|^tests"
Expand Down
4 changes: 2 additions & 2 deletions fastapi_code_generator/__main__.py
Expand Up @@ -77,7 +77,7 @@ def main(
output_dir,
template_dir,
model_path,
enum_field_as_literal,
enum_field_as_literal, # type: ignore[arg-type]
custom_visitors=custom_visitors,
disable_timestamp=disable_timestamp,
generate_routers=generate_routers,
Expand Down Expand Up @@ -131,7 +131,7 @@ def generate_code(
BUILTIN_MODULAR_TEMPLATE_DIR if generate_routers else BUILTIN_TEMPLATE_DIR
)
if enum_field_as_literal:
parser = OpenAPIParser(input_text, enum_field_as_literal=enum_field_as_literal)
parser = OpenAPIParser(input_text, enum_field_as_literal=enum_field_as_literal) # type: ignore[arg-type]
else:
parser = OpenAPIParser(input_text)
with chdir(output_dir):
Expand Down
82 changes: 58 additions & 24 deletions fastapi_code_generator/parser.py
Expand Up @@ -25,13 +25,12 @@
LiteralType,
OpenAPIScope,
PythonVersion,
cached_property,
snooper_to_methods,
)
from datamodel_code_generator.imports import Import, Imports
from datamodel_code_generator.model import DataModel, DataModelFieldBase
from datamodel_code_generator.model import pydantic as pydantic_model
from datamodel_code_generator.model.pydantic import DataModelField
from datamodel_code_generator.model.pydantic import CustomRootType, DataModelField
from datamodel_code_generator.parser.jsonschema import JsonSchemaObject
from datamodel_code_generator.parser.openapi import MediaObject
from datamodel_code_generator.parser.openapi import OpenAPIParser as OpenAPIModelParser
Expand All @@ -43,7 +42,8 @@
ResponseObject,
)
from datamodel_code_generator.types import DataType, DataTypeManager, StrictTypes
from pydantic import BaseModel
from datamodel_code_generator.util import cached_property
from pydantic import BaseModel, ValidationInfo

RE_APPLICATION_JSON_PATTERN: Pattern[str] = re.compile(r'^application/.*json$')

Expand Down Expand Up @@ -72,7 +72,7 @@ def __get_validators__(cls) -> Any:
yield cls.validate

@classmethod
def validate(cls, v: Any) -> Any:
def validate(cls, v: Any, info: ValidationInfo) -> Any:
return cls(v)

@property
Expand All @@ -91,8 +91,8 @@ def camelcase(self) -> str:
class Argument(CachedPropertyModel):
name: UsefulStr
type_hint: UsefulStr
default: Optional[UsefulStr]
default_value: Optional[UsefulStr]
default: Optional[UsefulStr] = None
default_value: Optional[UsefulStr] = None
required: bool

def __str__(self) -> str:
Expand All @@ -108,20 +108,20 @@ def argument(self) -> str:
class Operation(CachedPropertyModel):
method: UsefulStr
path: UsefulStr
operationId: Optional[UsefulStr]
description: Optional[str]
summary: Optional[str]
operationId: Optional[UsefulStr] = None
description: Optional[str] = None
summary: Optional[str] = None
parameters: List[Dict[str, Any]] = []
responses: Dict[UsefulStr, Any] = {}
deprecated: bool = False
imports: List[Import] = []
security: Optional[List[Dict[str, List[str]]]] = None
tags: Optional[List[str]]
tags: Optional[List[str]] = []
arguments: str = ''
snake_case_arguments: str = ''
request: Optional[Argument] = None
response: str = ''
additional_responses: Dict[str, Dict[str, str]] = {}
additional_responses: Dict[Union[str, int], Dict[str, str]] = {}
return_type: str = ''

@cached_property
Expand Down Expand Up @@ -245,16 +245,22 @@ def parse_info(self) -> Optional[Dict[str, Any]]:
result['servers'] = servers
return result or None

def parse_parameters(self, parameters: ParameterObject, path: List[str]) -> None:
super().parse_parameters(parameters, path)
self._temporary_operation['_parameters'].append(parameters)
def parse_all_parameters(
    self,
    name: str,
    parameters: List[Union[ReferenceObject, ParameterObject]],
    path: List[str],
) -> None:
    """Parse an operation's parameter list and record the raw objects.

    Delegates the actual model parsing to the datamodel-code-generator
    base parser, then stashes the raw parameter objects (which may still
    be unresolved ``ReferenceObject``s) on the operation currently being
    built so they can be processed later (see ``get_parameter_type``).
    """
    super().parse_all_parameters(name, parameters, path)
    self._temporary_operation['_parameters'].extend(parameters)

def get_parameter_type(
self,
parameters: ParameterObject,
parameters: Union[ReferenceObject, ParameterObject],
snake_case: bool,
path: List[str],
) -> Optional[Argument]:
parameters = self.resolve_object(parameters, ParameterObject)
orig_name = parameters.name
if snake_case:
name = stringcase.snakecase(parameters.name)
Expand All @@ -274,7 +280,10 @@ def get_parameter_type(
if not data_type:
if not schema:
schema = parameters.schema_
if schema is None:
raise RuntimeError("schema is None") # pragma: no cover
data_type = self.parse_schema(name, schema, [*path, name])
data_type = self._collapse_root_model(data_type)
if not schema:
return None

Expand All @@ -290,16 +299,18 @@ def get_parameter_type(
self.imports_for_fastapi.append(
Import(from_='fastapi', import_=param_is)
)
default: Optional[
str
] = f"{param_is}({'...' if field.required else repr(schema.default)}, alias='{orig_name}')"
default: Optional[str] = (
f"{param_is}({'...' if field.required else repr(schema.default)}, alias='{orig_name}')"
)
else:
default = repr(schema.default) if schema.has_default else None
self.imports_for_fastapi.append(field.imports)
self.data_types.append(field.data_type)
if field.name is None:
raise RuntimeError("field.name is None") # pragma: no cover
return Argument(
name=field.name,
type_hint=field.type_hint,
name=UsefulStr(field.name),
type_hint=UsefulStr(field.type_hint),
default=default, # type: ignore
default_value=schema.default,
required=field.required,
Expand Down Expand Up @@ -361,11 +372,12 @@ def parse_request_body(
data_type = self.parse_schema(
name, media_obj.schema_, [*path, media_type]
)
data_type = self._collapse_root_model(data_type)
arguments.append(
# TODO: support multiple body
Argument(
name='body', # type: ignore
type_hint=data_type.type_hint,
type_hint=UsefulStr(data_type.type_hint),
required=request_body.required,
)
)
Expand Down Expand Up @@ -406,17 +418,18 @@ def parse_request_body(
)
self._temporary_operation['_request'] = arguments[0] if arguments else None

def parse_responses(
def parse_responses( # type: ignore[override]
self,
name: str,
responses: Dict[str, Union[ResponseObject, ReferenceObject]],
path: List[str],
) -> Dict[str, Dict[str, DataType]]:
data_types = super().parse_responses(name, responses, path)
) -> Dict[Union[str, int], Dict[str, DataType]]:
data_types = super().parse_responses(name, responses, path) # type: ignore[arg-type]
status_code_200 = data_types.get('200')
if status_code_200:
data_type = list(status_code_200.values())[0]
if data_type:
data_type = self._collapse_root_model(data_type)
self.data_types.append(data_type)
else:
data_type = DataType(type='None')
Expand Down Expand Up @@ -466,3 +479,24 @@ def parse_operation(
path=f'/{path_name}', # type: ignore
method=method, # type: ignore
)

def _collapse_root_model(self, data_type: DataType) -> DataType:
    """Collapse a redundant custom root model into its inner data type.

    When ``data_type`` references a ``CustomRootType`` whose children all
    resolve to the same type (or it has a single child), the wrapper model
    carries no extra information: drop the reference, substitute the root
    field's own data type, and remove the now-unused model from
    ``self.results`` so it is not emitted. Otherwise ``data_type`` is
    returned unchanged.
    """
    reference = data_type.reference
    if reference is None:
        return data_type

    children = reference.children
    # BUG FIX: the original check used
    #   functools.reduce(lambda a, b: a == b, children)
    # which compares a *boolean* against the third and subsequent children
    # (wrong for len > 2) and raises TypeError on an empty list. The intent
    # is "all children are the same"; test that explicitly.
    if not children:
        return data_type
    first = children[0]
    if len(children) > 1 and any(child != first for child in children[1:]):
        return data_type

    source = reference.source
    if not isinstance(source, CustomRootType):
        return data_type

    data_type.remove_reference()
    data_type = source.fields[0].data_type
    if source in self.results:
        self.results.remove(source)
    return data_type

0 comments on commit e8e381f

Please sign in to comment.