跳到主要内容

C0726N01-FastAPI 参数处理系统

概述

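FastAPI的参数处理系统是其核心功能之一,负责从HTTP请求中提取、验证和转换各种类型的参数。本章将深入分析参数处理系统的设计原理、实现机制以及各种参数类型的处理方式。注意:文中的实现代码是为说明原理而简化的示意代码,并非FastAPI的实际源码(实际实现位于 fastapi/params.py、fastapi/dependencies/utils.py 等模块,类型转换与验证由Pydantic完成)。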

1. 参数系统架构

1.1 参数类型体系

FastAPI支持以下参数类型:

# Parameter type hierarchy
class FieldInfo:
    """Base class carrying the metadata declared for one request parameter.

    Each named keyword is stored as an attribute of the same name; any other
    keyword arguments are collected, unchanged, into ``self.extra``.
    """

    def __init__(
        self,
        default: Any = Undefined,
        *,
        alias: Optional[str] = None,
        title: Optional[str] = None,
        description: Optional[str] = None,
        examples: Optional[Dict[str, Any]] = None,
        deprecated: Optional[bool] = None,
        include_in_schema: bool = True,
        json_schema_extra: Optional[Dict[str, Any]] = None,
        **extra: Any,
    ):
        # Default value declared for the parameter.
        self.default = default
        # Naming / documentation metadata.
        self.alias, self.title, self.description = alias, title, description
        self.examples, self.deprecated = examples, deprecated
        # Schema-generation switches.
        self.include_in_schema = include_in_schema
        self.json_schema_extra = json_schema_extra
        # Remaining keywords (e.g. constraints such as ge=/le= passed to Path()).
        self.extra = extra

# Concrete parameter kinds. Each subclass only chooses its default value and
# forwards every other keyword argument unchanged to FieldInfo.
class Path(FieldInfo):
    """Path parameter; the default is ``...`` (i.e. required)."""

    def __init__(self, default: Any = ..., **kwargs):
        super().__init__(default=default, **kwargs)


class Query(FieldInfo):
    """Query-string parameter."""

    def __init__(self, default: Any = Undefined, **kwargs):
        super().__init__(default=default, **kwargs)


class Header(FieldInfo):
    """HTTP header parameter."""

    def __init__(self, default: Any = Undefined, **kwargs):
        super().__init__(default=default, **kwargs)


class Cookie(FieldInfo):
    """Cookie parameter."""

    def __init__(self, default: Any = Undefined, **kwargs):
        super().__init__(default=default, **kwargs)


class Body(FieldInfo):
    """Request-body parameter."""

    def __init__(self, default: Any = Undefined, **kwargs):
        super().__init__(default=default, **kwargs)


class Form(FieldInfo):
    """Form-field parameter."""

    def __init__(self, default: Any = Undefined, **kwargs):
        super().__init__(default=default, **kwargs)


class File(FieldInfo):
    """Uploaded-file parameter."""

    def __init__(self, default: Any = Undefined, **kwargs):
        super().__init__(default=default, **kwargs)

1.2 参数处理流程

参数处理遵循以下流程:

  1. 参数识别: 根据函数签名识别参数类型
  2. 数据提取: 从HTTP请求中提取原始数据
  3. 类型转换: 将字符串数据转换为目标类型(在FastAPI中,这一步与下一步的数据验证由Pydantic一并完成)
  4. 数据验证: 使用Pydantic进行数据验证
  5. 错误处理: 收集和格式化验证错误
  6. 值注入: 将处理后的值注入到函数调用中
def analyze_param(
    *,
    param_name: str,
    annotation: Any,
    value: Any,
    is_path_param: bool = False,
) -> ModelField:
    """Analyse one endpoint-function parameter and build its field model.

    Args:
        param_name: Name of the parameter in the function signature.
        annotation: The parameter's type annotation.
        value: The parameter's default as reported by ``inspect.signature``
            (``Parameter.empty`` when none was declared).
        is_path_param: True when the parameter appears in the route path.

    Returns:
        The field model built by the matching ``create_*_field`` helper.

    NOTE(review): the ``isinstance`` checks assume ``Depends``/``Path``/...
    are the marker *classes* defined in this chapter; the names exported by
    the ``fastapi`` package itself are factory functions -- confirm which
    ones are in scope.
    """
    # 1. No declared default means the parameter is *required*: path
    #    parameters become Path(...), everything else a required Query(...).
    #    (Query(None) would silently turn it into an optional parameter.)
    #    ``Parameter.empty`` is a sentinel, so compare by identity; ``==``
    #    could dispatch to an arbitrary ``__eq__`` on the default value.
    if value is Parameter.empty:
        value = Path(...) if is_path_param else Query(...)

    # 2. Special types (SPECIAL_TYPES is defined elsewhere) short-circuit.
    if annotation in SPECIAL_TYPES:
        return create_special_field(param_name, annotation)

    # 3. Dependency injection.
    if isinstance(value, Depends):
        return create_dependency_field(param_name, annotation, value)

    # 4. Explicit parameter markers. File/Form are tested before Body so the
    #    most specific marker wins even if File subclasses Form and Form
    #    subclasses Body (as in FastAPI's own hierarchy).
    if isinstance(value, Path):
        return create_path_field(param_name, annotation, value)
    if isinstance(value, Query):
        return create_query_field(param_name, annotation, value)
    if isinstance(value, Header):
        return create_header_field(param_name, annotation, value)
    if isinstance(value, Cookie):
        return create_cookie_field(param_name, annotation, value)
    if isinstance(value, File):
        return create_file_field(param_name, annotation, value)
    if isinstance(value, Form):
        return create_form_field(param_name, annotation, value)
    if isinstance(value, Body):
        return create_body_field(param_name, annotation, value)

    # Plain default values (e.g. ``skip: int = 0``).
    return create_default_field(param_name, annotation, value)

2. 路径参数处理

2.1 路径参数定义

from fastapi import Path
from typing import Optional
from enum import Enum

# Basic path parameter
@app.get("/users/{user_id}")
async def get_user(user_id: int = Path(..., description="用户ID")):
    return dict(user_id=user_id)


# Path parameter with validation
@app.get("/items/{item_id}")
async def get_item(
    item_id: int = Path(
        ...,
        ge=1,      # inclusive lower bound
        le=1000,   # inclusive upper bound
        title="商品ID",
        description="商品的唯一标识符",
        # NOTE(review): `example` is deprecated in newer FastAPI releases in
        # favour of `examples=[...]` -- confirm the target version.
        example=123
    )
):
    return dict(item_id=item_id)


# String path parameter spanning several segments (":path" convertor)
@app.get("/files/{file_path:path}")
async def get_file(
    file_path: str = Path(
        ...,
        title="文件路径",
        description="文件的完整路径",
        min_length=1,
        max_length=255,
        # NOTE(review): this pattern still allows ".." segments (e.g.
        # "../etc/passwd"); if the value is ever used to open files, resolve
        # it and check it stays inside the intended directory.
        regex=r"^[a-zA-Z0-9/._-]+$"
    )
):
    return dict(file_path=file_path)


# Enum path parameter: only the listed values are accepted
class ModelName(str, Enum):
    alexnet = "alexnet"
    resnet = "resnet"
    lenet = "lenet"


@app.get("/models/{model_name}")
async def get_model(model_name: ModelName = Path(..., description="模型名称")):
    message = f"使用模型: {model_name.value}"
    return {"model_name": model_name, "message": message}

2.2 路径参数提取

import re
from typing import Dict, List, Tuple, Pattern

def compile_path(
    path: str
) -> Tuple[Pattern[str], Dict[str, Any], List[str]]:
    """Compile a route template such as ``/users/{user_id:int}`` into a regex.

    Args:
        path: Route template. Parameters are written ``{name}`` or
            ``{name:type}`` with ``type`` one of ``str``, ``int``, ``float``,
            ``path``.

    Returns:
        ``(regex, convertors, names)``: the compiled regex (anchored at the
        end; ``Pattern.match`` anchors the start) with one named group per
        parameter, a mapping parameter name -> convertor callable, and the
        parameter names in template order.

    Raises:
        ValueError: if a parameter name is repeated or its type is unknown.
    """
    path_regex = ""
    param_convertors = {}
    param_names = []

    idx = 0
    for match in re.finditer(r"{([^}]+)}", path):
        param_name, param_type, param_convertor = parse_param(match.group(1))
        # A repeated name would otherwise surface as an opaque re.error
        # ("redefinition of group name") from re.compile below.
        if param_name in param_convertors:
            raise ValueError(
                f"Duplicated parameter name '{param_name}' in path '{path}'"
            )
        param_names.append(param_name)
        param_convertors[param_name] = param_convertor

        # Literal text between the previous parameter and this one.
        path_regex += re.escape(path[idx:match.start()])

        # ``path`` parameters may span '/', all others match one segment.
        if param_type == "path":
            path_regex += f"(?P<{param_name}>.+)"
        else:
            path_regex += f"(?P<{param_name}>[^/]+)"

        idx = match.end()

    # Literal text after the last parameter.
    path_regex += re.escape(path[idx:])

    return re.compile(path_regex + "$"), param_convertors, param_names


def parse_param(param: str) -> Tuple[str, str, Callable]:
    """Split a template parameter spec into ``(name, type, convertor)``.

    ``"user_id"`` -> ``("user_id", "str", str)`` and
    ``"user_id:int"`` -> ``("user_id", "int", int)``.

    Raises:
        ValueError: if the ``:type`` suffix names an unsupported convertor.
            (An unknown type such as the typo ``"itn"`` used to fall back to
            ``str`` silently.)
    """
    convertors = {"str": str, "int": int, "float": float, "path": str}
    param_name, sep, param_type = param.partition(":")
    if not sep:
        return param_name, "str", str
    if param_type not in convertors:
        raise ValueError(f"Unknown path convertor '{param_type}' in '{{{param}}}'")
    return param_name, param_type, convertors[param_type]

def extract_path_params(
    path: str,
    path_regex: Pattern[str],
    param_convertors: Dict[str, Callable]
) -> Dict[str, Any]:
    """Match *path* against *path_regex* and convert each captured value.

    Args:
        path: Request path, e.g. ``"/items/7"``.
        path_regex: Compiled route regex with one named group per parameter.
        param_convertors: Parameter name -> convertor; unknown names use ``str``.

    Returns:
        Parameter name -> converted value, or ``{}`` if the path does not match.

    Raises:
        ValueError: if a captured value cannot be converted; the underlying
            ``ValueError``/``TypeError`` is chained as ``__cause__``.
    """
    match = path_regex.match(path)
    if match is None:
        return {}

    params = {}
    for name, raw in match.groupdict().items():
        convertor = param_convertors.get(name, str)
        try:
            params[name] = convertor(raw)
        except (ValueError, TypeError) as exc:
            raise ValueError(f"Invalid value for path parameter '{name}': {raw}") from exc

    return params

2.3 路径参数验证

from pydantic import validator, ValidationError
from typing import Any, Dict, List

class PathParamValidator:
    """Coerce and validate a single path-parameter value for one field."""

    def __init__(self, field: ModelField):
        self.field = field
        self.field_info = field.field_info

    def validate(self, value: Any) -> Any:
        """Coerce *value* to the field's scalar type, then validate it.

        Returns:
            The value produced by ``self.field.validate``.

        Raises:
            ValidationError: the field's own error, unchanged, when
                ``self.field.validate`` reports one; otherwise a
                ``value_error.path_param`` error when coercion fails (for
                example a non-numeric int, or a bool string outside
                true/1/yes/on and false/0/no/off).
        """
        try:
            field_type = self.field.type_
            # 1. Coerce the raw path string for simple scalar types.
            if field_type is int:
                value = int(value)
            elif field_type is float:
                value = float(value)
            elif field_type is bool:
                lowered = str(value).lower()
                if lowered in ('true', '1', 'yes', 'on'):
                    value = True
                elif lowered in ('false', '0', 'no', 'off'):
                    value = False
                else:
                    # Previously any unrecognised string silently became False.
                    raise ValueError(f"Invalid boolean value: {value!r}")

            # 2. Full validation by the field itself.
            validated_value, error = self.field.validate(value, {}, loc=(self.field.name,))

            if error:
                raise ValidationError([error], self.field.type_)

            return validated_value

        except ValidationError:
            # Propagate the field's detailed error as-is. This clause must come
            # first: Pydantic's ValidationError is itself a ValueError, and the
            # generic handler below used to re-wrap it into a vaguer message.
            raise
        except (ValueError, TypeError) as e:
            raise ValidationError(
                [{
                    'loc': (self.field.name,),
                    'msg': f'Invalid value for path parameter: {str(e)}',
                    'type': 'value_error.path_param',
                    'input': value
                }],
                self.field.type_
            ) from e

3. 查询参数处理

3.1 查询参数定义

from fastapi import Query
from typing import Optional, List, Union
from datetime import datetime
from enum import Enum

# Basic query parameters
@app.get("/items/")
async def read_items(
    skip: int = Query(0, description="跳过的项目数", ge=0),
    limit: int = Query(10, description="返回的项目数", ge=1, le=100),
    q: Optional[str] = Query(None, description="搜索关键词", min_length=3, max_length=50)
):
    return dict(skip=skip, limit=limit, q=q)


# List-valued query parameters (e.g. ?tags=a&tags=b)
@app.get("/items/search/")
async def search_items(
    tags: List[str] = Query([], description="标签列表"),
    categories: List[int] = Query([], description="分类ID列表")
):
    return dict(tags=tags, categories=categories)


# Composite query parameters
class SortOrder(str, Enum):
    asc = "asc"
    desc = "desc"


@app.get("/products/")
async def get_products(
    # Pagination
    page: int = Query(1, description="页码", ge=1),
    size: int = Query(20, description="每页大小", ge=1, le=100),

    # Search
    keyword: Optional[str] = Query(
        None,
        description="搜索关键词",
        min_length=2,
        max_length=100,
        regex=r"^[\w\s-]+$",
    ),

    # Filters
    min_price: Optional[float] = Query(None, description="最低价格", ge=0),
    max_price: Optional[float] = Query(None, description="最高价格", ge=0),
    in_stock: Optional[bool] = Query(None, description="是否有库存"),

    # Sorting
    sort_by: Optional[str] = Query("created_at", description="排序字段"),
    sort_order: SortOrder = Query(SortOrder.desc, description="排序方向"),

    # Creation-time window
    created_after: Optional[datetime] = Query(None, description="创建时间起始"),
    created_before: Optional[datetime] = Query(None, description="创建时间结束")
):
    pagination = {"page": page, "size": size}
    search = {"keyword": keyword}
    filters = {
        "price_range": [min_price, max_price],
        "in_stock": in_stock,
        "date_range": [created_after, created_before],
    }
    sorting = {"field": sort_by, "order": sort_order}
    return {
        "pagination": pagination,
        "search": search,
        "filters": filters,
        "sorting": sorting,
    }

3.2 查询参数提取

from urllib.parse import parse_qsl, unquote_plus
from typing import Dict, List, Any, Optional

def extract_query_params(
    query_string: bytes,
    fields: List[ModelField]
) -> Dict[str, Any]:
    """Extract the declared query parameters from a raw query string.

    Args:
        query_string: Raw, still percent-encoded query-string bytes; may be
            empty.
        fields: Declared query fields (``name``, ``alias``, ``default`` plus
            whatever ``process_query_param_value`` reads).

    Returns:
        Field name -> value. Present parameters go through
        ``process_query_param_value``. Absent ones get ``field.default``,
        unless that default is ``...`` (required); those are omitted so a
        later validation step can report them.
    """
    # An empty query string must still reach the defaults loop below;
    # returning {} early used to drop every default value.
    decoded = query_string.decode('utf-8') if query_string else ""

    # parse_qsl already percent-decodes keys and values and maps '+' to ' '.
    # A second unquote_plus pass corrupted values that legitimately contain
    # '+' or '%': "a%2Bb" -> "a+b" -> "a b".
    parsed_params = parse_qsl(
        decoded,
        keep_blank_values=True,
        strict_parsing=False
    )

    # Group by key: a single occurrence stays a str, repeats become a list.
    params = {}
    for key, value in parsed_params:
        if key not in params:
            params[key] = value
        elif isinstance(params[key], list):
            params[key].append(value)
        else:
            params[key] = [params[key], value]

    # Map request keys (alias, else name) onto field names.
    result = {}
    for field in fields:
        field_alias = field.alias or field.name
        if field_alias in params:
            result[field.name] = process_query_param_value(params[field_alias], field)
        elif field.default is not Ellipsis:
            result[field.name] = field.default

    return result

def process_query_param_value(raw_value: Any, field: ModelField) -> Any:
    """Shape a raw query value to match *field*, then convert it.

    List-shaped fields always receive a list (a lone value is wrapped).
    Scalar fields that received several values keep only the last one.
    """
    if field.shape == SHAPE_LIST:
        items = raw_value if isinstance(raw_value, list) else [raw_value]
        return [convert_query_param_value(item, field.type_) for item in items]

    # Scalar field given a repeated key: the last occurrence wins.
    scalar = raw_value[-1] if isinstance(raw_value, list) else raw_value
    return convert_query_param_value(scalar, field.type_)

def convert_query_param_value(value: str, target_type: type) -> Any:
    """Convert one raw query-string value to *target_type*.

    Supported targets:

    - str, int, float.
    - bool: case-insensitive true/1/yes/on or false/0/no/off.
    - datetime: ISO 8601, with a trailing ``Z`` meaning UTC.
    - ``Optional[X]`` / ``Union[X, None]`` (and ``X | None`` on 3.10+):
      converts to the first non-None member.

    Any other type is called with the raw string.

    Raises:
        ValueError: if the value cannot be converted, including boolean
            strings outside the recognised spellings (these used to become
            False silently).
    """
    import types
    from typing import get_args, get_origin

    if target_type is str:
        return value
    if target_type is int:
        return int(value)
    if target_type is float:
        return float(value)
    if target_type is bool:
        lowered = value.lower()
        if lowered in ('true', '1', 'yes', 'on'):
            return True
        if lowered in ('false', '0', 'no', 'off'):
            return False
        raise ValueError(f"Invalid boolean value: {value!r}")
    if target_type is datetime:
        # Older datetime.fromisoformat rejects a trailing 'Z'; rewrite only
        # that suffix rather than every 'Z' in the string.
        if value.endswith('Z'):
            value = value[:-1] + '+00:00'
        return datetime.fromisoformat(value)

    origin = get_origin(target_type)
    union_type = getattr(types, "UnionType", None)  # ``X | Y`` (Python 3.10+)
    if origin is Union or (union_type is not None and origin is union_type):
        non_none_types = [t for t in get_args(target_type) if t is not type(None)]
        if non_none_types:
            return convert_query_param_value(value, non_none_types[0])
        return None

    # Fallback: let the type parse the raw string itself.
    return target_type(value)

3.3 查询参数验证

class QueryParamValidator:
    """Validate query parameters field by field, collecting every error."""

    def __init__(self, fields: List[ModelField]):
        # Fields indexed by their Python name.
        self.fields = {field.name: field for field in fields}

    def validate_all(self, params: Dict[str, Any]) -> Tuple[Dict[str, Any], List[ErrorWrapper]]:
        """Validate *params* against every declared field.

        Returns:
            ``(validated_params, errors)``. Absent fields take their default,
            or yield a "field required" error when the default is ``...``.
        """
        validated_params = {}
        errors = []

        for field_name, field in self.fields.items():
            if field_name not in params:
                if field.default is not Ellipsis:
                    validated_params[field_name] = field.default
                else:
                    # Required parameter is missing.
                    errors.append(ErrorWrapper(
                        ValueError("field required"),
                        loc=(field_name,)
                    ))
                continue

            try:
                validated_params[field_name] = self.validate_field(field, params[field_name])
            except ValidationError as e:
                for error in e.errors():
                    errors.append(ErrorWrapper(
                        ValueError(error['msg']),
                        loc=self._error_loc(field_name, error.get('loc', ()))
                    ))

        return validated_params, errors

    @staticmethod
    def _error_loc(field_name: str, loc: Any) -> tuple:
        """Return *loc* as a tuple, prefixed with *field_name* unless it already starts with it.

        ``validate_field`` validates with ``loc=(field.name,)``, so reported
        locations normally begin with the field name already; prefixing them
        unconditionally produced duplicates such as ``('q', 'q')``.
        """
        loc = tuple(loc)
        if loc[:1] == (field_name,):
            return loc
        return (field_name,) + loc

    def validate_field(self, field: ModelField, value: Any) -> Any:
        """Validate one value with the field's own validator.

        Raises:
            ValidationError: if the field reports an error.
        """
        validated_value, error = field.validate(value, {}, loc=(field.name,))

        if error:
            raise ValidationError([error], field.type_)

        return validated_value

4. 请求体参数处理

4.1 JSON请求体

from pydantic import BaseModel, Field
from typing import Optional, List, Dict, Any
from datetime import datetime
from enum import Enum

# Basic model
class User(BaseModel):
    name: str = Field(..., description="用户名", min_length=1, max_length=100)
    email: str = Field(..., description="邮箱", regex=r'^[\w\.-]+@[\w\.-]+\.\w+$')
    age: Optional[int] = Field(None, description="年龄", ge=0, le=150)
    is_active: bool = Field(True, description="是否激活")

    class Config:
        # Extra schema data: an example request payload.
        schema_extra = {
            "example": {
                "name": "张三",
                "email": "zhangsan@example.com",
                "age": 25,
                "is_active": True
            }
        }


@app.post("/users/")
async def create_user(user: User):
    return dict(message="用户创建成功", user=user)


# Nested models
class Address(BaseModel):
    street: str = Field(..., description="街道", min_length=1)
    city: str = Field(..., description="城市", min_length=1)
    country: str = Field(..., description="国家", min_length=1)
    postal_code: Optional[str] = Field(None, description="邮政编码")


class UserProfile(BaseModel):
    user: User
    address: Address
    # default_factory gives each instance its own fresh container / timestamp.
    tags: List[str] = Field(description="标签列表", default_factory=list)
    metadata: Dict[str, Any] = Field(description="元数据", default_factory=dict)
    created_at: datetime = Field(description="创建时间", default_factory=datetime.now)


@app.post("/profiles/")
async def create_profile(profile: UserProfile):
    return dict(message="用户档案创建成功", profile=profile)


# Several body parameters in one request
class Item(BaseModel):
    name: str
    description: Optional[str] = None
    price: float = Field(..., gt=0)
    tax: Optional[float] = None


@app.put("/items/{item_id}")
async def update_item(
    item_id: int,
    item: Item,
    user: User,
    importance: int = Body(..., description="重要性等级", gt=0, le=5)
):
    result = dict(item_id=item_id, item=item, user=user, importance=importance)
    return result

4.2 请求体提取和验证

import json
from typing import Any, Dict, List, Optional, Tuple
from pydantic import ValidationError

async def extract_body_params(
    request: Request,
    body_fields: List[ModelField]
) -> Tuple[Dict[str, Any], List[ErrorWrapper]]:
    """Read the request body and dispatch it to a parser chosen by Content-Type.

    Returns:
        ``(values, errors)`` from the selected parser; ``({}, [])`` when no
        body fields are declared; the result of ``handle_empty_body`` when
        the body is empty. Any exception raised while reading or parsing is
        turned into a single error located at ``("body",)``.
    """
    if not body_fields:
        return {}, []

    try:
        body_bytes = await request.body()

        if not body_bytes:
            return handle_empty_body(body_fields)

        # Media types are case-insensitive and may carry parameters such as
        # "; charset=utf-8"; compare only the bare, lower-cased type. A plain
        # substring test missed "Application/JSON" and matched unrelated
        # values that merely contain "application/json".
        content_type = request.headers.get("content-type", "")
        media_type = content_type.split(";", 1)[0].strip().lower()

        # "+json" structured-syntax suffix types (e.g. application/problem+json,
        # application/merge-patch+json) are JSON as well.
        is_json = media_type == "application/json" or (
            media_type.startswith("application/") and media_type.endswith("+json")
        )

        if is_json:
            return await parse_json_body(body_bytes, body_fields)
        if media_type == "application/x-www-form-urlencoded":
            return await parse_form_body(request, body_fields)
        if media_type == "multipart/form-data":
            return await parse_multipart_body(request, body_fields)
        return await parse_raw_body(body_bytes, body_fields)

    except Exception as e:
        # Boundary handler: report parser failures as a body-level error.
        return {}, [ErrorWrapper(
            ValueError(f"请求体解析错误: {str(e)}"),
            loc=("body",)
        )]

async def parse_json_body(
    body_bytes: bytes,
    body_fields: List[ModelField]
) -> Tuple[Dict[str, Any], List[ErrorWrapper]]:
    """Decode a UTF-8 JSON body and validate it against *body_fields*.

    Undecodable bytes or malformed JSON yield one error located at
    ``("body",)``; otherwise the result of ``validate_body_fields`` is
    returned.
    """
    try:
        text = body_bytes.decode('utf-8')
        payload = json.loads(text)
    except (json.JSONDecodeError, UnicodeDecodeError) as exc:
        failure = ErrorWrapper(
            ValueError(f"无效的JSON格式: {str(exc)}"),
            loc=("body",)
        )
        return {}, [failure]

    return validate_body_fields(payload, body_fields)

def validate_body_fields(
    body_data: Any,
    body_fields: List[ModelField]
) -> Tuple[Dict[str, Any], List[ErrorWrapper]]:
    """Validate decoded request-body data against the declared body fields.

    With exactly one body field, the whole payload is validated as that
    field. With several fields, the payload must be a JSON object; each
    field is read from the key ``field.alias or field.name``. An absent
    field takes its default, or is reported as missing when the default is
    ``...``.

    Returns:
        ``(validated_data, errors)`` keyed by field name.
    """
    validated_data = {}
    errors = []

    if len(body_fields) == 1:
        # Single body field: validate the whole payload as that field.
        field = body_fields[0]
        value, field_errors = _validate_body_value(field, body_data)
        if field_errors:
            errors.extend(field_errors)
        else:
            validated_data[field.name] = value
        return validated_data, errors

    # Several body fields: they are embedded as keys of one JSON object.
    if not isinstance(body_data, dict):
        errors.append(ErrorWrapper(
            ValueError("多个请求体字段需要JSON对象"),
            loc=("body",)
        ))
        return validated_data, errors

    for field in body_fields:
        key = field.alias or field.name
        if key in body_data:
            value, field_errors = _validate_body_value(field, body_data[key])
            if field_errors:
                errors.extend(field_errors)
            else:
                validated_data[field.name] = value
        elif field.default is not Ellipsis:
            validated_data[field.name] = field.default
        else:
            errors.append(ErrorWrapper(
                ValueError("字段必需"),
                loc=(field.name,)
            ))

    return validated_data, errors


def _validate_body_value(field: Any, value: Any) -> Tuple[Any, List[Any]]:
    """Validate *value* for *field*; return ``(validated_value, errors)`` (errors empty on success)."""
    try:
        validated_value, error = field.validate(value, {}, loc=(field.name,))
    except ValidationError as e:
        wrapped = []
        for item in e.errors():
            loc = tuple(item.get('loc', ()))
            # field.validate was called with loc=(field.name,), so locations
            # normally start with the field name already; prefixing them again
            # produced duplicates such as ('item', 'item', 'price').
            if loc[:1] != (field.name,):
                loc = (field.name,) + loc
            wrapped.append(ErrorWrapper(ValueError(item['msg']), loc=loc))
        return None, wrapped

    if error:
        return None, [error]
    return validated_value, []

5. 表单参数处理

5.1 表单参数定义

from fastapi import Form, File, UploadFile
from typing import List, Optional

# Basic form
@app.post("/login/")
async def login(
    username: str = Form(..., min_length=3, max_length=50),
    password: str = Form(..., min_length=6),
    remember_me: bool = Form(False)
):
    # The password is only length-validated here; it is not echoed back.
    return dict(username=username, remember_me=remember_me, message="登录成功")


# Single file upload
@app.post("/upload/")
async def upload_file(
    file: UploadFile = File(..., description="上传的文件"),
    description: Optional[str] = Form(None, description="文件描述")
):
    # NOTE: reads the whole upload into memory just to measure its size.
    data = await file.read()
    return dict(
        filename=file.filename,
        content_type=file.content_type,
        size=len(data),
        description=description,
    )


# Multiple file upload
@app.post("/upload-multiple/")
async def upload_multiple_files(
    files: List[UploadFile] = File(..., description="多个文件"),
    category: str = Form(..., description="文件分类"),
    tags: List[str] = Form([], description="文件标签")
):
    summaries = []
    for upload in files:
        data = await upload.read()
        summaries.append(dict(
            filename=upload.filename,
            content_type=upload.content_type,
            size=len(data),
        ))

    return dict(files=summaries, category=category, tags=tags)

# Form fields + file upload + structured metadata.
# A request body is either multipart/form-data or application/json, never
# both, so a JSON ``Body(...)`` parameter cannot be combined with Form/File
# parameters (the previous ``metadata: dict = Body(...)`` could never be
# satisfied). The structured metadata is therefore sent as a form field
# holding a JSON object string and decoded here.
@app.post("/submit-form/")
async def submit_form(
    # Form fields
    name: str = Form(...),
    email: str = Form(...),

    # File field
    avatar: Optional[UploadFile] = File(None),

    # JSON object serialised into a form field, e.g. '{"page": "signup"}'
    metadata: str = Form(...),

    # Query parameter
    source: str = Query("web")
):
    import json
    from fastapi import HTTPException

    try:
        metadata_obj = json.loads(metadata)
    except json.JSONDecodeError as exc:
        raise HTTPException(status_code=422, detail=f"metadata 不是有效的JSON: {exc}") from exc
    if not isinstance(metadata_obj, dict):
        raise HTTPException(status_code=422, detail="metadata 必须是JSON对象")

    avatar_info = None
    if avatar:
        contents = await avatar.read()
        avatar_info = {
            "filename": avatar.filename,
            "content_type": avatar.content_type,
            "size": len(contents)
        }

    return {
        "name": name,
        "email": email,
        "avatar": avatar_info,
        "metadata": metadata_obj,
        "source": source
    }

5.2 表单数据提取

from starlette.datastructures import FormData, UploadFile
from typing import Dict, List, Any, Tuple

async def extract_form_params(
    request: Request,
    form_fields: List[ModelField]
) -> Tuple[Dict[str, Any], List[ErrorWrapper]]:
    """Parse the request's form body and validate it against *form_fields*.

    Any exception raised while parsing or processing the form becomes a
    single error located at ``("form",)`` instead of propagating.
    """
    try:
        parsed_form = await request.form()
        outcome = process_form_data(parsed_form, form_fields)
    except Exception as exc:
        failure = ErrorWrapper(
            ValueError(f"表单数据解析错误: {str(exc)}"),
            loc=("form",)
        )
        return {}, [failure]
    return outcome

def process_form_data(
    form_data: FormData,
    form_fields: List[ModelField]
) -> Tuple[Dict[str, Any], List[ErrorWrapper]]:
    """Validate parsed form data against the declared form/file fields.

    The form key of each field is ``field.alias or field.name``.

    - List-shaped fields collect every occurrence of the key (``getlist``);
      scalar fields use ``form_data[key]``.
    - File fields must hold ``UploadFile`` objects; other fields go through
      ``convert_form_value``.
    - Absent fields take their default, or are reported as required when the
      default is ``...``.

    Returns:
        ``(validated_data, errors)`` keyed by field name.
    """
    validated_data = {}
    errors = []

    for field in form_fields:
        field_name = field.alias or field.name

        try:
            if field_name not in form_data:
                if field.default is not Ellipsis:
                    validated_data[field.name] = field.default
                else:
                    errors.append(ErrorWrapper(
                        ValueError("字段必需"),
                        loc=(field.name,)
                    ))
                continue

            is_list = field.shape == SHAPE_LIST
            # getlist keeps every occurrence (repeated keys, several files);
            # indexing keeps one. A List[UploadFile] field used to receive a
            # single bare UploadFile instead of a list.
            if is_list:
                raw_values = form_data.getlist(field_name)
            else:
                raw_values = [form_data[field_name]]

            if _is_file_field(field):
                if all(isinstance(v, UploadFile) for v in raw_values):
                    validated_data[field.name] = raw_values if is_list else raw_values[0]
                else:
                    errors.append(ErrorWrapper(
                        ValueError("期望文件类型"),
                        loc=(field.name,)
                    ))
            else:
                converted = [convert_form_value(v, field.type_) for v in raw_values]
                validated_data[field.name] = converted if is_list else converted[0]

        except (ValueError, TypeError) as e:
            errors.append(ErrorWrapper(
                ValueError(f"字段验证错误: {str(e)}"),
                loc=(field.name,)
            ))

    return validated_data, errors


def _is_file_field(field: Any) -> bool:
    """Return True when *field* was declared with a ``File`` marker."""
    if isinstance(File, type):
        return isinstance(field.field_info, File)
    # ``File`` imported from the fastapi package is a factory function, not
    # a class, and isinstance(x, <function>) raises TypeError -- which the
    # caller's except clause then reported as a bogus validation error for
    # every submitted form field. Match on the marker's class name instead.
    return type(field.field_info).__name__ == "File"

def convert_form_value(value: Any, target_type: type) -> Any:
    """Convert one raw form value to *target_type*.

    Uploaded files are passed through unchanged.

    Conversions:

    - str, int, float: converted directly.
    - bool: case-insensitive true/1/yes/on or false/0/no/off.
    - Anything else: ``target_type`` is called with the value.

    Raises:
        ValueError: if conversion fails, including boolean strings outside
            the recognised spellings (these used to become False silently).
    """
    if isinstance(value, UploadFile):
        return value

    if target_type is str:
        return str(value)
    if target_type is int:
        return int(value)
    if target_type is float:
        return float(value)
    if target_type is bool:
        lowered = str(value).lower()
        if lowered in ('true', '1', 'yes', 'on'):
            return True
        if lowered in ('false', '0', 'no', 'off'):
            return False
        raise ValueError(f"Invalid boolean value: {value!r}")
    return target_type(value)

6. 头部参数处理

6.1 头部参数定义

from fastapi import Header
from typing import Optional, List

# Basic header parameters
# NOTE(review): "/items/" and the name `read_items` are reused by the query
# and cookie examples; in one module/app these definitions would collide --
# confirm the snippets are meant to live in separate apps.
@app.get("/items/")
async def read_items(
    user_agent: Optional[str] = Header(None, description="用户代理"),
    x_token: Optional[str] = Header(None, description="认证令牌"),
    accept_language: Optional[str] = Header(None, description="接受语言", alias="accept-language")
):
    return dict(user_agent=user_agent, token=x_token, language=accept_language)


# Custom header handling
@app.get("/api/data")
async def get_data(
    # Authentication
    authorization: Optional[str] = Header(None, description="授权头部"),

    # Content negotiation (accept_encoding is declared but not echoed back)
    accept: str = Header("application/json", description="接受的内容类型"),
    accept_encoding: Optional[str] = Header(None, alias="accept-encoding"),

    # Custom headers. X-Forwarded-For is client-supplied unless set by a
    # trusted proxy, so it should not be trusted as the real client IP.
    x_request_id: Optional[str] = Header(None, description="请求ID", alias="x-request-id"),
    x_forwarded_for: Optional[str] = Header(None, description="转发IP", alias="x-forwarded-for"),

    # Conditional-request (cache) headers
    if_none_match: Optional[str] = Header(None, description="ETag匹配", alias="if-none-match"),
    if_modified_since: Optional[str] = Header(None, description="修改时间", alias="if-modified-since")
):
    cache_headers = {
        "if_none_match": if_none_match,
        "if_modified_since": if_modified_since,
    }
    return {
        "auth": authorization,
        "accept": accept,
        "request_id": x_request_id,
        "client_ip": x_forwarded_for,
        "cache_headers": cache_headers,
    }


# Header validation
@app.post("/secure-endpoint")
async def secure_endpoint(
    x_api_key: str = Header(
        ...,
        description="API密钥",
        alias="x-api-key",
        min_length=32,
        max_length=64,
        regex=r'^[a-zA-Z0-9]+$'
    ),
    content_type: str = Header(
        "application/json",
        description="内容类型",
        alias="content-type",
        regex=r'^application/json(;.*)?$'
    )
):
    # Only a short prefix of the key is echoed back.
    masked_key = x_api_key[:8] + "..."
    return {"message": "访问授权成功", "api_key": masked_key}

6.2 头部参数提取

def extract_header_params(
    headers: Dict[str, str],
    header_fields: List[ModelField]
) -> Tuple[Dict[str, Any], List[ErrorWrapper]]:
    """Look up each declared header case-insensitively, or apply its default.

    The header key is the field alias (else name), lower-cased, with '_'
    replaced by '-' (so ``user_agent`` matches ``User-Agent``). List-shaped
    fields split the header value on commas and strip each part.

    Returns:
        ``(values, errors)`` keyed by field name.
    """
    lowered_headers = {name.lower(): val for name, val in headers.items()}
    values = {}
    problems = []

    for field in header_fields:
        name = field.name
        wire_name = (field.alias or name).lower().replace('_', '-')

        try:
            if wire_name not in lowered_headers:
                if field.default is not Ellipsis:
                    values[name] = field.default
                else:
                    problems.append(ErrorWrapper(
                        ValueError(f"头部字段 '{wire_name}' 必需"),
                        loc=(name,)
                    ))
                continue

            raw = lowered_headers[wire_name]
            # Comma-separated headers (e.g. Accept-Language: en,zh-CN;q=0.9)
            if field.shape == SHAPE_LIST:
                values[name] = [part.strip() for part in raw.split(',')]
            else:
                values[name] = raw

        except Exception as e:
            problems.append(ErrorWrapper(
                ValueError(f"头部字段验证错误: {str(e)}"),
                loc=(name,)
            ))

    return values, problems

7. Cookie参数处理

7.1 Cookie参数定义

from fastapi import Cookie
from typing import Optional

# Basic cookie parameters
# NOTE(review): "/items/" and `read_items` are also defined in the query and
# header examples; confirm these snippets are not registered on one app.
@app.get("/items/")
async def read_items(
    session_id: Optional[str] = Cookie(None, description="会话ID"),
    user_preferences: Optional[str] = Cookie(None, description="用户偏好"),
    theme: str = Cookie("light", description="主题设置")
):
    return dict(session_id=session_id, preferences=user_preferences, theme=theme)


# Cookie validation
@app.get("/dashboard")
async def dashboard(
    auth_token: str = Cookie(
        ...,
        description="认证令牌",
        min_length=20,
        max_length=100
    ),
    csrf_token: Optional[str] = Cookie(
        None,
        description="CSRF令牌",
        alias="csrf-token"
    )
):
    # Only a prefix of the token is echoed back.
    masked_token = auth_token[:10] + "..."
    return {
        "message": "欢迎访问仪表板",
        "auth_token": masked_token,
        "csrf_protected": csrf_token is not None
    }

7.2 Cookie参数提取

def extract_cookie_params(
    cookies: Dict[str, str],
    cookie_fields: List[ModelField]
) -> Tuple[Dict[str, Any], List[ErrorWrapper]]:
    """Look up each declared cookie, percent-decode it, or apply its default.

    The cookie key is the field alias, else its name.

    Returns:
        ``(values, errors)`` keyed by field name, with one error per
        required cookie (default ``...``) that is missing.
    """
    def _decode(raw):
        # Best effort: keep the raw value if URL-decoding fails.
        try:
            from urllib.parse import unquote
            return unquote(raw)
        except Exception:
            return raw

    values = {}
    problems = []

    for field in cookie_fields:
        name = field.name
        key = field.alias or name

        try:
            if key in cookies:
                values[name] = _decode(cookies[key])
            elif field.default is not Ellipsis:
                values[name] = field.default
            else:
                problems.append(ErrorWrapper(
                    ValueError(f"Cookie '{key}' 必需"),
                    loc=(name,)
                ))
        except Exception as e:
            problems.append(ErrorWrapper(
                ValueError(f"Cookie字段验证错误: {str(e)}"),
                loc=(name,)
            ))

    return values, problems

8. 参数验证系统

8.1 验证规则

from pydantic import validator, root_validator
from typing import Any, Dict

class AdvancedQueryParams(BaseModel):
    """高级查询参数模型"""
    # (Docstring kept verbatim: Pydantic may publish it as the schema
    # description. English: "advanced query parameter model".)

    # Basic fields
    keyword: Optional[str] = Field(None, min_length=2, max_length=100)
    page: int = Field(1, ge=1, le=1000)
    size: int = Field(20, ge=1, le=100)

    # Price range
    min_price: Optional[float] = Field(None, ge=0)
    max_price: Optional[float] = Field(None, ge=0)

    # Date range
    start_date: Optional[datetime] = None
    end_date: Optional[datetime] = None

    @validator('keyword')
    def validate_keyword(cls, v):
        if v is not None:
            # Collapse every run of whitespace into a single space.
            v = ' '.join(v.split())
            # Reject HTML-significant characters.
            if any(char in v for char in ['<', '>', '&', '"', "'"]):
                raise ValueError('关键词不能包含特殊字符')
            # Field(min_length=2) is checked *before* this validator, so a
            # value such as "   " or " a " passed it and then collapsed to
            # fewer than 2 characters here; enforce the minimum again.
            if len(v) < 2:
                raise ValueError('关键词至少需要2个非空白字符')
        return v

    @validator('max_price')
    def validate_max_price(cls, v, values):
        if v is not None and 'min_price' in values and values['min_price'] is not None:
            if v <= values['min_price']:
                raise ValueError('最高价格必须大于最低价格')
        return v

    @validator('end_date')
    def validate_end_date(cls, v, values):
        if v is not None and 'start_date' in values and values['start_date'] is not None:
            if v <= values['start_date']:
                raise ValueError('结束日期必须晚于开始日期')
        return v

    # skip_on_failure=True: only check the range once the fields themselves
    # are valid. A bare @root_validator is rejected outright by Pydantic v2.
    @root_validator(skip_on_failure=True)
    def validate_date_range(cls, values):
        start_date = values.get('start_date')
        end_date = values.get('end_date')

        if start_date and end_date:
            # The range may span at most one year.
            if (end_date - start_date).days > 365:
                raise ValueError('日期范围不能超过一年')

        return values


# Use the model as a class dependency: its fields become query parameters.
# NOTE(review): errors raised by the validators above surface while FastAPI
# instantiates the dependency -- confirm they are returned as 422 rather
# than 500 in the target FastAPI version.
@app.get("/advanced-search")
async def advanced_search(params: AdvancedQueryParams = Depends()):
    return {"search_params": params.dict()}

8.2 自定义验证器

from typing import Callable, Any
from pydantic.validators import str_validator

def create_custom_validator(validation_func: Callable[[Any], Any]):
    """Adapt a plain ``value -> value`` function to the ``(cls, v)`` validator signature.

    The returned function ignores ``cls`` and returns ``validation_func(v)``.
    """
    def validator_wrapper(cls, v):
        result = validation_func(v)
        return result

    return validator_wrapper

# Email validator
def validate_email(email: str) -> str:
    """Check *email* against a basic address pattern and return it lower-cased.

    Raises:
        ValueError: if the entire string does not match the pattern.
    """
    import re
    pattern = r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$'
    # fullmatch instead of match: '$' also matches just before a trailing
    # newline, so "a@b.com\n" used to be accepted (and returned with "\n").
    if not re.fullmatch(pattern, email):
        raise ValueError('无效的邮箱格式')
    return email.lower()

# Mobile phone number validator
def validate_phone(phone: str) -> str:
    """Validate a mainland-China mobile number and return it unchanged.

    The accepted format is 11 ASCII digits: "1", then a digit 3-9, then
    nine more digits.

    Raises:
        ValueError: if *phone* is not exactly such a number.
    """
    import re
    # Mainland China mobile number format
    pattern = r'^1[3-9]\d{9}$'
    # fullmatch rejects a trailing "\n" that '$' alone would accept, and
    # re.ASCII stops \d from matching non-ASCII digits such as "８".
    if not re.fullmatch(pattern, phone, flags=re.ASCII):
        raise ValueError('无效的手机号格式')
    return phone

# Resident ID card number validator
def validate_id_card(id_card: str) -> str:
    """Check the format of an 18-character ID card number; return it upper-cased.

    Only the layout is checked (region prefix, birth date ranges, sequence
    digits, final digit or X); the checksum is not verified.

    Raises:
        ValueError: if *id_card* does not match the format exactly.
    """
    import re
    # Simplified format check
    pattern = r'^[1-9]\d{5}(18|19|20)\d{2}(0[1-9]|1[0-2])(0[1-9]|[12]\d|3[01])\d{3}[\dXx]$'
    # fullmatch rejects a trailing "\n" that '$' alone would accept, and
    # re.ASCII stops \d from matching non-ASCII digits.
    if not re.fullmatch(pattern, id_card, flags=re.ASCII):
        raise ValueError('无效的身份证号格式')
    return id_card.upper()

class UserRegistration(BaseModel):
    """用户注册模型"""
    # (Docstring kept verbatim: Pydantic may publish it as the schema
    # description. English: "user registration model".)

    username: str = Field(..., min_length=3, max_length=20)
    email: str = Field(...)
    phone: str = Field(...)
    id_card: Optional[str] = Field(None)

    # Reuse the standalone validator functions as field validators.
    _validate_email = validator('email', allow_reuse=True)(validate_email)
    _validate_phone = validator('phone', allow_reuse=True)(validate_phone)
    _validate_id_card = validator('id_card', allow_reuse=True)(validate_id_card)

    @validator('username')
    def validate_username(cls, v):
        import re
        # fullmatch: with re.match + '$', "name\n" was accepted because '$'
        # also matches right before a trailing newline.
        if not re.fullmatch(r'[a-zA-Z0-9_]+', v):
            raise ValueError('用户名只能包含字母、数字和下划线')
        return v

8.3 错误处理和格式化

from fastapi import HTTPException
from fastapi.exceptions import RequestValidationError
from pydantic import ValidationError
from typing import List, Dict, Any

class ValidationErrorHandler:
    """Turn validation errors into a uniform ``{"detail", "errors"}`` payload.

    Accepted error shapes:

    - ``ErrorWrapper``-like objects, i.e. having ``loc_tuple()`` and ``exc``.
    - Plain error dicts as returned by ``.errors()`` on Pydantic's
      ``ValidationError`` / FastAPI's ``RequestValidationError``, with keys
      ``loc``, ``msg``, ``type`` and optionally ``input``.
    """

    @staticmethod
    def _format_one(error: Any) -> Dict[str, Any]:
        """Return the ``field/message/type/input`` dict for one error."""
        if isinstance(error, dict):
            # ``.errors()`` dicts: calling loc_tuple() on them used to raise
            # AttributeError inside the RequestValidationError handler.
            loc, details = error.get("loc", ()), error
        else:
            loc, details = error.loc_tuple(), error.exc

        field = ".".join(str(part) for part in loc)

        if isinstance(details, dict):
            # Also covers an ErrorWrapper built around an error dict, which
            # previously yielded message=str(dict) and type="dict".
            return {
                "field": field,
                "message": details.get("msg", ""),
                "type": details.get("type", ""),
                "input": details.get("input"),
            }

        # ErrorWrapper around an exception instance.
        return {
            "field": field,
            "message": str(details),
            "type": details.__class__.__name__,
            "input": getattr(details, 'input', None),
        }

    @staticmethod
    def format_validation_errors(errors: List[Any]) -> Dict[str, Any]:
        """Format *errors* (wrappers and/or dicts) into the response payload."""
        return {
            "detail": "参数验证失败",
            "errors": [ValidationErrorHandler._format_one(error) for error in errors]
        }

    @staticmethod
    def create_validation_exception(errors: List[Any]) -> "HTTPException":
        """Build a 422 HTTPException whose detail is the formatted payload."""
        error_response = ValidationErrorHandler.format_validation_errors(errors)
        return HTTPException(
            status_code=422,
            detail=error_response
        )

# Global exception handlers
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
    """Respond with 422 and the formatted errors of a request validation failure."""
    payload = ValidationErrorHandler.format_validation_errors(exc.errors())
    return JSONResponse(status_code=422, content=payload)


@app.exception_handler(ValidationError)
async def pydantic_validation_exception_handler(request: Request, exc: ValidationError):
    """Respond with 422 and the formatted errors of a Pydantic ValidationError.

    Each error dict from ``exc.errors()`` is wrapped in an ``ErrorWrapper``
    located at the dict's ``loc`` (empty tuple if absent).
    """
    wrapped = []
    for err in exc.errors():
        wrapped.append(ErrorWrapper(err, loc=err.get('loc', ())))
    payload = ValidationErrorHandler.format_validation_errors(wrapped)
    return JSONResponse(status_code=422, content=payload)

9. 性能优化

9.1 参数解析缓存

from functools import lru_cache
from typing import Dict, List, Tuple

class ParameterCache:
    """Per-instance, size-bounded LRU caches for parameter parsing results.

    Each cache is a plain dict kept in recency order: a hit is re-inserted
    at the end, and the oldest entries are evicted once ``max_size`` is
    exceeded.

    This replaces ``functools.lru_cache`` on methods, which:

    - kept every instance alive for the cache's lifetime (keyed on ``self``);
    - ignored ``max_size`` (the bounds were hard-coded at 1000 / 5000).
    """

    def __init__(self, max_size: int = 1000):
        self.max_size = max_size
        self._field_cache = {}
        self._validation_cache = {}

    def _lookup(self, cache, key, compute):
        """Return ``cache[key]``, computing and storing it on a miss."""
        if key in cache:
            # Re-insert so the entry becomes the most recently used.
            value = cache.pop(key)
            cache[key] = value
            return value
        value = compute()
        cache[key] = value
        # Evict least-recently-used entries beyond the configured bound.
        while len(cache) > self.max_size:
            del cache[next(iter(cache))]
        return value

    def get_compiled_fields(self, func_signature: str) -> "List[ModelField]":
        """Return the compiled field list for *func_signature*, caching it."""
        return self._lookup(
            self._field_cache,
            func_signature,
            lambda: self._compile_fields_from_signature(func_signature),
        )

    def _compile_fields_from_signature(self, signature: str) -> "List[ModelField]":
        """Compile fields from a signature string.

        Placeholder: the real implementation is not shown; returns None.
        """
        pass

    def get_validation_result(self, field_name: str, value_hash: int, field_config: str):
        """Return the cached validation result for the given key.

        Placeholder: no result is computed yet, so this always returns None.
        """
        return self._lookup(
            self._validation_cache,
            (field_name, value_hash, field_config),
            lambda: None,
        )


# Global cache instance
parameter_cache = ParameterCache()

9.2 批量参数处理

class BatchParameterProcessor:
"""批量参数处理器"""

def __init__(self):
    # Parameter kind -> bound processor method, e.g. 'path' -> process_path_params.
    self.processors = {
        kind: getattr(self, f"process_{kind}_params")
        for kind in ('path', 'query', 'header', 'cookie', 'body')
    }

async def process_all_params(
    self,
    request: Request,
    dependant: Dependant
) -> Tuple[Dict[str, Any], List[ErrorWrapper]]:
    """Run the processor for every parameter kind the dependant declares.

    The processors are awaited concurrently with ``asyncio.gather`` (one
    event loop: concurrent, not parallel), and their ``(values, errors)``
    results are merged.

    Returns:
        ``(values, errors)``. A processor that raises an ``Exception`` adds
        an error located at ``('processing',)``. A processor returning
        ``None`` (such as an unimplemented stub) contributes nothing.

    Raises:
        asyncio.CancelledError (or any other non-``Exception``
        ``BaseException`` returned by ``gather``) is re-raised rather than
        unpacked as a result.
    """
    import asyncio  # not imported anywhere in the visible code

    all_values = {}
    all_errors = []

    tasks = []
    if dependant.path_params:
        tasks.append(self.process_path_params(request, dependant.path_params))
    if dependant.query_params:
        tasks.append(self.process_query_params(request, dependant.query_params))
    if dependant.header_params:
        tasks.append(self.process_header_params(request, dependant.header_params))
    if dependant.cookie_params:
        tasks.append(self.process_cookie_params(request, dependant.cookie_params))
    if dependant.body_params:
        tasks.append(self.process_body_params(request, dependant.body_params))

    # Wait for all processors; exceptions are returned, not raised.
    results = await asyncio.gather(*tasks, return_exceptions=True)

    for result in results:
        if isinstance(result, Exception):
            all_errors.append(ErrorWrapper(result, loc=('processing',)))
        elif isinstance(result, BaseException):
            # With return_exceptions=True gather also returns CancelledError,
            # a BaseException (not an Exception) since Python 3.8; unpacking
            # it below would hide the cancellation behind a TypeError.
            raise result
        elif result is None:
            # Stub processors return None; unpacking None raised TypeError.
            continue
        else:
            values, errors = result
            all_values.update(values)
            all_errors.extend(errors)

    return all_values, all_errors

async def process_path_params(self, request: Request, fields: List[ModelField]):
    """Process path parameters.

    Placeholder: not implemented yet, so it always returns None.
    """
    return None

async def process_query_params(self, request: Request, fields: List[ModelField]):
    """Process query parameters.

    Placeholder: not implemented yet, so it always returns None.
    """
    return None

# The remaining processor methods (header/cookie/body) are omitted...