C0726N01-FastAPI 依赖注入系统
概述
FastAPI的依赖注入系统是其最强大的特性之一,它提供了一种声明式的方式来管理应用程序的依赖关系。本章将深入分析依赖注入系统的设计原理、实现机制以及高级用法。
1. 依赖注入系统架构
1.1 核心概念
依赖注入系统基于以下核心概念:
# Core type definitions
DependencyCallable = Callable[..., Any]


class Depends:
    """Marker placed as a parameter default to request an injected value.

    Args:
        dependency: Callable that produces the value. ``None`` is accepted
            and stored unchanged.
        use_cache: Flag stored for the resolver; ``True`` by default.
    """

    def __init__(
        self,
        dependency: Optional[Callable[..., Any]] = None,
        *,
        use_cache: bool = True
    ):
        # Plain attribute storage; all resolution logic lives elsewhere.
        self.use_cache = use_cache
        self.dependency = dependency
class Dependant:
    """Dependency model for one callable (an endpoint or a dependency).

    The ``*_params`` lists group the callable's parameters by where their
    values come from. ``dependencies`` holds the nested ``Dependant`` nodes
    built for parameters declared with ``Depends``. Each ``*_param_name``
    attribute is the name of the parameter (if any) that receives the
    corresponding framework object.
    """

    def __init__(
        self,
        *,
        path_params: Optional[List["ModelField"]] = None,
        query_params: Optional[List["ModelField"]] = None,
        header_params: Optional[List["ModelField"]] = None,
        cookie_params: Optional[List["ModelField"]] = None,
        body_params: Optional[List["ModelField"]] = None,
        # get_dependant appends Dependant nodes here, not Depends markers.
        dependencies: Optional[List["Dependant"]] = None,
        security_requirements: Optional[List["SecurityRequirement"]] = None,
        name: Optional[str] = None,
        call: Optional[Callable[..., Any]] = None,
        request_param_name: Optional[str] = None,
        websocket_param_name: Optional[str] = None,
        http_connection_param_name: Optional[str] = None,
        response_param_name: Optional[str] = None,
        background_tasks_param_name: Optional[str] = None,
        security_scopes_param_name: Optional[str] = None,
        use_cache: bool = True,
    ):
        # ``or []`` gives every instance its own fresh list, so appending to
        # one node's lists never leaks into another node.
        self.path_params = path_params or []
        self.query_params = query_params or []
        self.header_params = header_params or []
        self.cookie_params = cookie_params or []
        self.body_params = body_params or []
        self.dependencies = dependencies or []
        self.security_requirements = security_requirements or []
        self.name = name
        self.call = call
        self.request_param_name = request_param_name
        self.websocket_param_name = websocket_param_name
        self.http_connection_param_name = http_connection_param_name
        self.response_param_name = response_param_name
        self.background_tasks_param_name = background_tasks_param_name
        self.security_scopes_param_name = security_scopes_param_name
        self.use_cache = use_cache
1.2 依赖解析流程
依赖解析遵循以下流程:
- 函数签名分析: 分析端点函数的参数类型和注解
- 依赖图构建: 递归分析所有依赖关系
- 参数分类: 将参数分类为路径、查询、请求体等类型
- 依赖实例化: 按照依赖顺序实例化所有依赖
- 缓存管理: 在单个请求内缓存依赖结果(use_cache=True 时同一依赖在该请求中只调用一次,不同请求之间不共享),并在请求结束后执行 yield 依赖的清理代码
def get_dependant(
    *,
    path: str,
    call: Callable[..., Any],
    name: Optional[str] = None,
    security_requirements: Optional[List[SecurityRequirement]] = None,
    use_cache: bool = True,
) -> Dependant:
    """Build the ``Dependant`` tree describing ``call``'s parameters.

    Args:
        path: Route path template; its placeholders decide which parameters
            count as path parameters.
        call: Endpoint or dependency callable to inspect.
        name: Parameter name under which this node's value is injected into
            its parent (``None`` for the endpoint itself).
        security_requirements: Security requirements attached to this node.
        use_cache: Whether this node's value may be cached per request.

    Returns:
        A ``Dependant`` whose parameter lists and nested ``dependencies`` are
        filled in recursively.

    Raises:
        TypeError: if a parameter uses ``Depends()`` without a callable and
            without a type annotation to fall back on.
    """
    # 1. Function signature
    signature = inspect.signature(call)
    # Parse the path placeholders once instead of once per parameter.
    path_param_names = get_path_param_names(path)

    # 2. Node for this callable
    dependant = Dependant(
        call=call,
        name=name,
        security_requirements=security_requirements or [],
        use_cache=use_cache,
    )

    # 3. Analyse and classify each parameter
    for param_name, param in signature.parameters.items():
        param_field = analyze_param(
            param_name=param_name,
            annotation=param.annotation,
            value=param.default,
            is_path_param=param_name in path_param_names,
        )
        field_info = param_field.field_info

        if field_info and isinstance(field_info, Depends):
            # ``Depends()`` with no explicit callable means "use the
            # annotation itself" (e.g. ``m: DatabaseManager = Depends()``).
            # Passing None on would crash inside inspect.signature.
            sub_call = field_info.dependency
            if sub_call is None:
                if param.annotation is inspect.Parameter.empty:
                    raise TypeError(
                        f"Depends() on parameter {param_name!r} needs a "
                        "dependency callable or a type annotation"
                    )
                sub_call = param.annotation
            sub_dependant = get_dependant(
                path=path,
                call=sub_call,
                name=param_name,
                use_cache=field_info.use_cache,
            )
            dependant.dependencies.append(sub_dependant)
        elif is_path_param(param_name, path):
            dependant.path_params.append(param_field)
        elif is_query_param(param_field):
            dependant.query_params.append(param_field)
        elif is_header_param(param_field):
            dependant.header_params.append(param_field)
        elif is_cookie_param(param_field):
            dependant.cookie_params.append(param_field)
        elif is_body_param(param_field):
            dependant.body_params.append(param_field)
        elif is_special_param(param_name):
            # Special parameters (Request, WebSocket, ...)
            set_special_param_name(dependant, param_name, param.annotation)

    return dependant
2. 依赖声明和使用
2.1 基础依赖声明
# A simple dependency function
def get_db():
    """Yield a database session and close it once the request is done."""
    db = SessionLocal()
    try:
        yield db
    finally:
        db.close()


# Using the dependency.
# Plain ``def`` on purpose: ``db.query(...).all()`` on a synchronous Session
# (assumed SQLAlchemy ORM) is blocking I/O. FastAPI runs sync endpoints in a
# worker thread, whereas inside ``async def`` it would block the event loop.
@app.get("/users/")
def read_users(db: Session = Depends(get_db)):
    return db.query(User).all()
2.2 依赖链
# Base dependency
def get_db():
    """Yield a database session, closing it when the request finishes."""
    db = SessionLocal()
    try:
        yield db
    finally:
        db.close()


# Higher-level dependency built on top of the base one
def get_user_service(db: Session = Depends(get_db)):
    """Wrap the request's session in a UserService."""
    return UserService(db)


# Final endpoint.
# ``def`` rather than ``async def``: UserService works on a synchronous
# Session, so its queries presumably block. FastAPI runs sync endpoints in
# a worker thread instead of on the event loop.
@app.get("/users/{user_id}")
def get_user(
    user_id: int,
    user_service: UserService = Depends(get_user_service)
):
    return user_service.get_user(user_id)
2.3 类作为依赖
class DatabaseManager:
    """Owns one connection; used directly as a class dependency below."""

    def __init__(self):
        self.connection = create_connection()

    def get_session(self):
        return self.connection.session()

    def close(self):
        self.connection.close()


# Using a class as a dependency: FastAPI calls DatabaseManager() for the
# request. It cannot tear a plain class instance down, so the endpoint must
# close the connection itself or it leaks on every request.
# Plain ``def`` because connection handling here is presumably blocking.
@app.get("/data")
def get_data(db_manager: DatabaseManager = Depends(DatabaseManager)):
    try:
        session = db_manager.get_session()
        # use the session here
        return {"data": "example"}
    finally:
        db_manager.close()
2.4 子依赖
# Authentication dependency
async def get_current_user(token: str = Depends(oauth2_scheme)):
    """Decode the bearer token into a user, or reject the request with 401."""
    user = decode_token(token)
    if not user:
        raise HTTPException(
            status_code=401,
            detail="Invalid token",
            # RFC 7235: a 401 response must carry a WWW-Authenticate challenge.
            headers={"WWW-Authenticate": "Bearer"},
        )
    return user


# Permission-check dependency
async def get_admin_user(current_user: User = Depends(get_current_user)):
    """Return the current user if they are an admin, otherwise raise 403."""
    if not current_user.is_admin:
        raise HTTPException(status_code=403, detail="Not enough permissions")
    return current_user


# Multi-level dependency in use.
# ``def``: the synchronous Session calls below are blocking I/O.
@app.delete("/users/{user_id}")
def delete_user(
    user_id: int,
    admin_user: User = Depends(get_admin_user),  # implies the auth check too
    db: Session = Depends(get_db),
):
    # Only administrators reach this point.
    # Query.delete() returns the number of matched rows; 0 means no such user.
    deleted = db.query(User).filter(User.id == user_id).delete()
    if not deleted:
        raise HTTPException(status_code=404, detail="User not found")
    db.commit()
    return {"message": "User deleted"}
3. 依赖解析实现
3.1 参数分析
def _lenient_issubclass(annotation: Any, cls: type) -> bool:
    """Return True if ``annotation`` is a class derived from ``cls``; never raises."""
    return isinstance(annotation, type) and issubclass(annotation, cls)


def analyze_param(
    *,
    param_name: str,
    annotation: Any,
    value: Any,
    is_path_param: bool = False,
) -> ModelField:
    """Turn one function parameter into a ``ModelField``.

    Args:
        param_name: The parameter's name.
        annotation: Its declared annotation (``Parameter.empty`` if absent).
        value: Its declared default (``Parameter.empty`` if absent).
        is_path_param: Whether the name is a placeholder in the route path.

    Returns:
        The field built by the matching ``create_*_field`` helper.
    """
    # 1. Missing default. ``is`` rather than ``==``: ``==`` would invoke the
    # default object's own ``__eq__``, which may raise or return a non-bool
    # (e.g. array-like defaults).
    if value is Parameter.empty:
        # Path parameters are always required.
        value = Path(...) if is_path_param else Required

    # 2. Framework objects. Subclass checks so that custom Request/Response
    # subclasses (e.g. a JSONResponse annotation) are recognised too.
    if _lenient_issubclass(annotation, Request):
        return create_request_field(param_name)
    elif _lenient_issubclass(annotation, WebSocket):
        return create_websocket_field(param_name)
    elif _lenient_issubclass(annotation, Response):
        return create_response_field(param_name)
    elif _lenient_issubclass(annotation, BackgroundTasks):
        return create_background_tasks_field(param_name)

    # 3. Dependency injection
    if isinstance(value, Depends):
        return create_dependency_field(param_name, annotation, value)

    # 4. Explicit parameter declarations
    if isinstance(value, (Path, Query, Header, Cookie, Body, Form, File)):
        return create_param_field(param_name, annotation, value)

    # 5. Fallback
    return create_default_field(param_name, annotation, value)
3.2 依赖实例化
async def solve_dependencies(
    *,
    request: Union[Request, WebSocket],
    dependant: Dependant,
    body: Optional[Union[Dict[str, Any], FormData]] = None,
    background_tasks: Optional[BackgroundTasks] = None,
    response: Optional[Response] = None,
    dependency_overrides_provider: Optional[Any] = None,
    dependency_cache: Optional[Dict[Tuple[Callable[..., Any], Tuple[str, ...]], Any]] = None,
) -> Tuple[Dict[str, Any], List[ErrorWrapper], Optional[BackgroundTasks], Response]:
    """Resolve every parameter of ``dependant`` for one request.

    Request-derived parameters (path, query, header, cookie, body) and the
    special framework objects are filled in first. Then each
    sub-dependency is solved recursively and its callable is invoked via
    ``call_dependency``.

    Returns:
        ``(values, errors, background_tasks, response)``: keyword arguments
        for ``dependant.call``, the collected validation errors, the
        ``BackgroundTasks`` instance (created on demand) and ``response``.
    """
    # Share one cache across the whole recursive solve, so a dependency used
    # several times in the same request is called once. Without this,
    # callers that pass no cache would get no caching at all.
    if dependency_cache is None:
        dependency_cache = {}

    values: Dict[str, Any] = {}
    errors: List[ErrorWrapper] = []

    # 1. Path parameters
    path_values, path_errors = get_path_params_values(
        request=request,
        dependant=dependant
    )
    values.update(path_values)
    errors.extend(path_errors)

    # 2. Query parameters
    query_values, query_errors = await get_query_params_values(
        request=request,
        dependant=dependant
    )
    values.update(query_values)
    errors.extend(query_errors)

    # 3. Header parameters
    header_values, header_errors = get_header_params_values(
        request=request,
        dependant=dependant
    )
    values.update(header_values)
    errors.extend(header_errors)

    # 4. Cookie parameters
    cookie_values, cookie_errors = get_cookie_params_values(
        request=request,
        dependant=dependant
    )
    values.update(cookie_values)
    errors.extend(cookie_errors)

    # 5. Body parameters (only when the callable declares any)
    if dependant.body_params:
        body_values, body_errors = await get_body_params_values(
            request=request,
            dependant=dependant,
            body=body
        )
        values.update(body_values)
        errors.extend(body_errors)

    # 6. Special framework objects
    if dependant.request_param_name:
        values[dependant.request_param_name] = request
    if dependant.websocket_param_name:
        values[dependant.websocket_param_name] = request
    if dependant.http_connection_param_name:
        # Both Request and WebSocket are HTTP connections.
        values[dependant.http_connection_param_name] = request
    if dependant.background_tasks_param_name:
        if background_tasks is None:
            background_tasks = BackgroundTasks()
        values[dependant.background_tasks_param_name] = background_tasks
    if dependant.response_param_name:
        values[dependant.response_param_name] = response

    # 7. Sub-dependencies.
    # NOTE(review): ``dependency_overrides_provider`` is only forwarded; this
    # simplified resolver never looks up overrides, so
    # ``app.dependency_overrides`` is not honoured here.
    for sub_dependant in dependant.dependencies:
        # The recursive call returns a 4-tuple; unpacking it into two names
        # raised "too many values to unpack". Keep its background_tasks, so
        # an instance created by a sub-dependency is shared with the caller
        # and with later siblings.
        (
            sub_values,
            sub_errors,
            background_tasks,
            _sub_response,
        ) = await solve_dependencies(
            request=request,
            dependant=sub_dependant,
            body=body,
            background_tasks=background_tasks,
            response=response,
            dependency_overrides_provider=dependency_overrides_provider,
            dependency_cache=dependency_cache,
        )
        if sub_errors:
            errors.extend(sub_errors)
            continue

        # Invoke the dependency callable
        dependency_value = await call_dependency(
            dependency=sub_dependant,
            values=sub_values,
            dependency_cache=dependency_cache,
        )
        if sub_dependant.name:
            values[sub_dependant.name] = dependency_value

    return values, errors, background_tasks, response
3.3 依赖调用
async def call_dependency(
    *,
    dependency: Dependant,
    values: Dict[str, Any],
    dependency_cache: Optional[Dict[Tuple[Callable[..., Any], Tuple[str, ...]], Any]] = None,
    cleanup_callbacks: Optional[List[Callable[[], Any]]] = None,
) -> Any:
    """Invoke ``dependency.call`` and return the value to inject.

    Args:
        dependency: Node whose ``call`` is invoked.
        values: Resolved values. Only those whose names are parameters of
            ``dependency.call`` are passed to it.
        dependency_cache: Per-request cache. It is consulted and filled only
            when ``dependency.use_cache`` is true.
        cleanup_callbacks: Optional list that receives one zero-argument
            callable per generator dependency. Calling it, and awaiting the
            result when it is a coroutine, resumes the generator past its
            ``yield`` so its teardown code runs. These callables fit
            ``DependencyCache.add_cleanup``.

    Returns:
        The call's return value. For (async) generator dependencies it is the
        first yielded value, or ``None`` if nothing was yielded.
    """
    # 1. Cache lookup
    cache_key = None
    if dependency.use_cache and dependency_cache is not None:
        cache_key = get_dependency_cache_key(dependency, values)
        if cache_key in dependency_cache:
            return dependency_cache[cache_key]

    # 2. Keyword arguments: fetch the accepted names once, not per value.
    accepted_names = get_function_param_names(dependency.call)
    kwargs = {
        param_name: param_value
        for param_name, param_value in values.items()
        if param_name in accepted_names
    }

    # 3. Call. Awaiting whatever comes back, rather than inspecting the
    # function, also covers callable instances with ``async def __call__``.
    result = dependency.call(**kwargs)
    if inspect.isawaitable(result):
        result = await result

    # 4. Generator dependencies
    cleanup: Optional[Callable[[], Any]] = None
    if inspect.isasyncgen(result):
        agen = result
        try:
            dependency_value = await agen.__anext__()
        except StopAsyncIteration:  # async generators never raise StopIteration
            dependency_value = None

        async def _finish_async_gen() -> None:
            # Resume past ``yield`` so the code after it (commit, finally) runs.
            try:
                await agen.__anext__()
            except StopAsyncIteration:
                pass

        cleanup = _finish_async_gen
    elif inspect.isgenerator(result):
        gen = result
        try:
            dependency_value = next(gen)
        except StopIteration:
            dependency_value = None

        def _finish_gen() -> None:
            # Resume past ``yield``; the default swallows StopIteration.
            next(gen, None)

        cleanup = _finish_gen
    else:
        dependency_value = result

    if cleanup is not None:
        if cleanup_callbacks is not None:
            cleanup_callbacks.append(cleanup)
        elif hasattr(dependency, "cleanup_functions"):
            # Legacy behaviour: stash the raw generator on the node.
            dependency.cleanup_functions.append(result)

    # 5. Cache store
    if cache_key is not None:
        dependency_cache[cache_key] = dependency_value
    return dependency_value
4. 依赖缓存机制
4.1 缓存键生成
def get_dependency_cache_key(
    dependency: Dependant,
    values: Dict[str, Any]
) -> Tuple[Callable[..., Any], Tuple[str, ...]]:
    """Return the cache key for ``dependency``.

    The key pairs the callable with the sorted ``scope`` values of its
    security requirements, so the same callable required with different
    scopes is cached separately. ``values`` is accepted for interface
    compatibility but does not influence the key.
    """
    requirements = dependency.security_requirements
    if not requirements:
        return (dependency.call, ())
    ordered_scopes = sorted(requirement.scope for requirement in requirements)
    return (dependency.call, tuple(ordered_scopes))
4.2 缓存生命周期
class DependencyCache:
    """Store for resolved dependency values plus their teardown callables."""

    def __init__(self):
        self._cache: Dict[Any, Any] = {}
        self._cleanup_functions: List[Callable] = []

    def get(self, key: Any, default: Any = None) -> Any:
        """Return the value cached under ``key``, or ``default`` (None) if absent."""
        return self._cache.get(key, default)

    def set(self, key: Any, value: Any) -> None:
        """Cache ``value`` under ``key``, replacing any previous value."""
        self._cache[key] = value

    def add_cleanup(self, cleanup_func: Callable) -> None:
        """Register a zero-argument teardown callable (sync or async)."""
        self._cleanup_functions.append(cleanup_func)

    async def cleanup(self) -> None:
        """Run the teardown callables newest-first, then drop all state.

        The reverse order mirrors ``contextlib.ExitStack``: something set up
        later (a service) may rely on something set up earlier (its DB
        session), so it must be torn down first. A failing callable is
        logged and the remaining ones still run. The state is cleared even if
        cleanup is interrupted, e.g. by cancellation.
        """
        import inspect

        try:
            for cleanup_func in reversed(self._cleanup_functions):
                try:
                    outcome = cleanup_func()
                    # Await any awaitable result, not only those of
                    # ``async def`` functions (e.g. a bound ``aclose``).
                    if inspect.isawaitable(outcome):
                        await outcome
                except Exception as exc:
                    # Log with traceback but keep tearing the rest down.
                    logger.exception("Error during dependency cleanup: %s", exc)
        finally:
            self._cache.clear()
            self._cleanup_functions.clear()
4.3 请求级缓存
# Request-scoped caching example
def get_expensive_resource():
    """Create the expensive resource.

    FastAPI caches dependency results per request only (``use_cache=True``
    is the default). Within one request, every parameter that depends on
    this function shares a single call. Each new request calls it again,
    whichever endpoint it hits.
    """
    print("Creating expensive resource...")  # printed once per request
    return ExpensiveResource()


@app.get("/endpoint1")
async def endpoint1(
    resource: ExpensiveResource = Depends(get_expensive_resource)
):
    return {"endpoint": "1", "resource_id": resource.id}


@app.get("/endpoint2")
async def endpoint2(
    resource: ExpensiveResource = Depends(get_expensive_resource)  # fresh per request; the cache does not span requests
):
    return {"endpoint": "2", "resource_id": resource.id}
5. 高级依赖模式
5.1 条件依赖
def get_db_connection(use_replica: bool = Query(False)):
    """Return a replica connection when ``use_replica`` is true, else the master."""
    connect = get_replica_db if use_replica else get_master_db
    return connect()


@app.get("/data")
async def get_data(
    db = Depends(get_db_connection)  # the query parameter picks the database
):
    sql = "SELECT * FROM data"
    return db.query(sql)
5.2 工厂模式依赖
class ServiceFactory:
    """Creates each known service lazily and then reuses that one instance."""

    # Builders are lambdas, so the service classes are only looked up when
    # first requested.
    _BUILDERS = {
        "user": lambda: UserService(),
        "order": lambda: OrderService(),
    }

    def __init__(self):
        self._services = {}

    def get_service(self, service_type: str):
        """Return the shared instance for ``service_type``.

        Raises:
            ValueError: for a service type with no registered builder.
        """
        if service_type in self._services:
            return self._services[service_type]
        builder = self._BUILDERS.get(service_type)
        if builder is None:
            raise ValueError(f"Unknown service type: {service_type}")
        instance = builder()
        self._services[service_type] = instance
        return instance
# One module-wide factory, so each service type is built at most once
service_factory = ServiceFactory()


def get_user_service():
    """Dependency: the shared UserService instance."""
    return service_factory.get_service("user")


def get_order_service():
    """Dependency: the shared OrderService instance."""
    return service_factory.get_service("order")


@app.get("/users/{user_id}")
async def get_user(
    user_id: int,
    user_service = Depends(get_user_service)
):
    found = user_service.get_user(user_id)
    return found
5.3 上下文管理器依赖
from contextlib import asynccontextmanager
def get_db_transaction():
    """Yield a session inside a transaction: commit on success, roll back on error.

    This is deliberately a plain generator. FastAPI drives ``yield``
    dependencies itself; decorating the function with
    ``@asynccontextmanager`` made ``Depends`` inject the context-manager
    object instead of the session. It is a sync generator because the
    Session calls are blocking, and FastAPI runs sync dependencies in a
    worker thread.
    """
    db = SessionLocal()
    transaction = db.begin()
    try:
        yield db
        transaction.commit()
    except Exception:
        transaction.rollback()
        raise
    finally:
        db.close()


# Using the transactional dependency
@app.post("/users/")
def create_user(
    user_data: UserCreate,
    db: Session = Depends(get_db_transaction),
):
    user = User(**user_data.dict())
    db.add(user)
    # The commit happens in the dependency's teardown, which may not have
    # run yet when ``user`` is serialized. Flush now so the INSERT executes
    # and database-generated fields (e.g. the primary key) are populated.
    db.flush()
    db.refresh(user)
    # The transaction is committed or rolled back after this function returns.
    return user
5.4 依赖覆盖
# Production dependency
def get_settings():
    """Build the application settings."""
    return Settings()


# Replacement dependency for tests
def get_test_settings():
    """Settings flagged for testing, backed by an in-memory SQLite database."""
    test_values = {"testing": True, "database_url": "sqlite:///:memory:"}
    return Settings(**test_values)


# Install the override.
# NOTE(review): this runs at import time here. In a real project it belongs
# in test setup (see the fixture pattern in section 7.1); otherwise
# production requests would receive the test settings too.
app.dependency_overrides[get_settings] = get_test_settings


@app.get("/config")
async def get_config(settings: Settings = Depends(get_settings)):
    # While the override is installed this resolves to get_test_settings().
    database_url = settings.database_url
    return {"database_url": database_url}
6. 安全依赖
6.1 认证依赖
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
security = HTTPBearer()


async def get_current_user(
    credentials: HTTPAuthorizationCredentials = Depends(security)
) -> User:
    """Return the user named by the bearer JWT's ``sub`` claim.

    Raises:
        HTTPException: 401 with a ``WWW-Authenticate: Bearer`` challenge when
            the token cannot be decoded, has no ``sub`` claim, or names an
            unknown user.
    """
    # RFC 7235: every 401 response must carry a WWW-Authenticate challenge.
    challenge = {"WWW-Authenticate": "Bearer"}
    token = credentials.credentials
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except JWTError:
        # ``from None`` keeps decoder internals out of the client-facing chain.
        raise HTTPException(
            status_code=401, detail="Invalid token", headers=challenge
        ) from None

    username = payload.get("sub")
    if username is None:
        raise HTTPException(status_code=401, detail="Invalid token", headers=challenge)

    user = get_user_by_username(username)
    if user is None:
        raise HTTPException(status_code=401, detail="User not found", headers=challenge)
    return user


@app.get("/protected")
async def protected_endpoint(
    current_user: User = Depends(get_current_user)
):
    return {"message": f"Hello, {current_user.username}!"}
6.2 权限依赖
from functools import wraps
from typing import List
def require_permissions(required_permissions: List[str]):
    """Build a dependency that demands every permission in ``required_permissions``.

    The returned checker resolves the current user and looks up their
    permissions. It raises 403 naming the first missing permission (in the
    order given); otherwise it returns the user.
    """
    def permission_checker(
        current_user: User = Depends(get_current_user)
    ) -> User:
        granted = get_user_permissions(current_user.id)
        missing = [perm for perm in required_permissions if perm not in granted]
        if missing:
            raise HTTPException(
                status_code=403,
                detail=f"Permission '{missing[0]}' required"
            )
        return current_user

    return permission_checker


# Using the permission dependency
@app.delete("/users/{user_id}")
async def delete_user(
    user_id: int,
    current_user: User = Depends(require_permissions(["user:delete"]))
):
    # Reached only by users holding the "user:delete" permission
    delete_user_by_id(user_id)
    return {"message": "User deleted"}
6.3 作用域依赖
from fastapi.security import SecurityScopes
async def get_current_user_with_scopes(
    security_scopes: SecurityScopes,
    token: str = Depends(oauth2_scheme)
) -> User:
    """Authenticate the bearer token and enforce the route's required scopes.

    Raises 401 (with a ``WWW-Authenticate`` challenge listing the required
    scopes) when the token is invalid or names an unknown user, and 403 when
    the token lacks any scope listed in ``security_scopes``.
    """
    authenticate_value = (
        f'Bearer scope="{security_scopes.scope_str}"'
        if security_scopes.scopes
        else "Bearer"
    )
    credentials_exception = HTTPException(
        status_code=401,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": authenticate_value},
    )

    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
        username: str = payload.get("sub")
        if username is None:
            raise credentials_exception
        token_scopes = payload.get("scopes", [])
    except JWTError:
        raise credentials_exception

    user = get_user_by_username(username)
    if user is None:
        raise credentials_exception

    missing_scopes = [
        scope for scope in security_scopes.scopes if scope not in token_scopes
    ]
    if missing_scopes:
        raise HTTPException(
            status_code=403,
            detail="Not enough permissions",
            headers={"WWW-Authenticate": authenticate_value},
        )
    return user


# Requiring a scope on a route
@app.get("/users/me/items/")
async def read_own_items(
    current_user: User = Security(get_current_user_with_scopes, scopes=["items:read"])
):
    items = get_user_items(current_user.id)
    return items
7. 依赖测试
7.1 依赖模拟
import pytest
from fastapi.testclient import TestClient
# Stand-in dependencies
def mock_get_db():
    return MockDatabase()


def mock_get_current_user():
    return User(id=1, username="testuser", email="test@example.com")


# Test fixture
@pytest.fixture
def client():
    """TestClient with the DB and auth dependencies replaced by mocks.

    The previous ``app.dependency_overrides`` are restored afterwards, so
    overrides installed elsewhere survive. The ``finally`` also covers a
    TestClient start-up failure, where code after ``yield`` would never run.
    """
    saved_overrides = dict(app.dependency_overrides)
    app.dependency_overrides[get_db] = mock_get_db
    app.dependency_overrides[get_current_user] = mock_get_current_user
    try:
        with TestClient(app) as test_client:
            yield test_client
    finally:
        app.dependency_overrides.clear()
        app.dependency_overrides.update(saved_overrides)


# Test case
def test_protected_endpoint(client):
    response = client.get("/protected")
    assert response.status_code == 200
    assert response.json()["message"] == "Hello, testuser!"
7.2 依赖隔离测试
def test_dependency_isolation():
    """Dependency results are shared within one request but never across requests."""
    call_count = 0

    def counting_dependency():
        nonlocal call_count
        call_count += 1
        return f"call_{call_count}"

    # Throw-away application for this test
    test_app = FastAPI()

    @test_app.get("/test1")
    async def test1(dep = Depends(counting_dependency)):
        return {"result": dep}

    @test_app.get("/test2")
    async def test2(dep = Depends(counting_dependency)):
        return {"result": dep}

    @test_app.get("/same-request")
    async def same_request(
        first = Depends(counting_dependency),
        second = Depends(counting_dependency),
    ):
        # With use_cache=True (the default) both parameters share one call.
        return {"first": first, "second": second}

    with TestClient(test_app) as client:
        response1 = client.get("/test1")
        response2 = client.get("/test2")
        assert response1.status_code == 200
        assert response2.status_code == 200
        # Each request calls the dependency afresh.
        assert response1.json()["result"] == "call_1"
        assert response2.json()["result"] == "call_2"

        same = client.get("/same-request")
        assert same.status_code == 200
        assert same.json() == {"first": "call_3", "second": "call_3"}

    assert call_count == 3
8. 性能优化
8.1 依赖预编译
class DependencyCompiler:
    """Memoizes ``get_dependant`` results, one per callable."""

    def __init__(self):
        self._compiled_dependencies = {}

    def compile_dependency(self, func: Callable) -> Dependant:
        """Return the Dependant for ``func``, building it with ``path=""`` on first use."""
        compiled = self._compiled_dependencies
        if func not in compiled:
            compiled[func] = get_dependant(path="", call=func)
        return compiled[func]

    def get_compiled_dependency(self, func: Callable) -> Optional[Dependant]:
        """Return the cached Dependant for ``func`` (None if never compiled)."""
        return self._compiled_dependencies.get(func, None)


# Process-wide compiler instance
dependency_compiler = DependencyCompiler()
8.2 依赖池化
from typing import Dict, List
import asyncio
class DependencyPool:
    """Pool of reusable dependency instances, kept separately per factory.

    ``get_dependency`` hands out a pooled instance or builds a new one.
    ``return_dependency`` puts an instance back; once ``max_size`` idle
    instances are held for a factory, further returned ones are closed.
    Intended for use from a single asyncio event loop.
    """

    def __init__(self, max_size: int = 100):
        self.max_size = max_size
        self._pools: Dict[Callable, List[Any]] = {}
        self._locks: Dict[Callable, asyncio.Lock] = {}

    def _lock_for(self, dependency_func: Callable) -> asyncio.Lock:
        """Return (creating on first use) the lock guarding ``dependency_func``'s pool."""
        lock = self._locks.get(dependency_func)
        if lock is None:
            lock = self._locks[dependency_func] = asyncio.Lock()
        return lock

    async def get_dependency(self, dependency_func: Callable) -> Any:
        """Take a pooled instance, or build a new one from ``dependency_func``."""
        async with self._lock_for(dependency_func):
            pool = self._pools.get(dependency_func)
            if pool:
                return pool.pop()
        # Build outside the lock, so a slow factory does not serialize every
        # other caller of the same dependency.
        if asyncio.iscoroutinefunction(dependency_func):
            return await dependency_func()
        return dependency_func()

    async def return_dependency(self, dependency_func: Callable, instance: Any) -> None:
        """Give ``instance`` back; close it instead if the pool is full.

        Works even if ``get_dependency`` was never called for this factory
        (previously a KeyError), and accepts both sync and async ``close``.
        """
        async with self._lock_for(dependency_func):
            pool = self._pools.setdefault(dependency_func, [])
            if len(pool) < self.max_size:
                pool.append(instance)
                return
        # Pool full: discard the instance, closing it when possible.
        close = getattr(instance, "close", None)
        if close is not None:
            outcome = close()
            if asyncio.iscoroutine(outcome):
                await outcome
8.3 依赖图优化
class DependencyGraph:
    """Directed graph of callables, used to compute a resolution order."""

    def __init__(self):
        self.graph = {}
        # Filled in by optimize_resolution_order().
        self.resolved_order = []

    def add_dependency(self, func: Callable, dependencies: List[Callable]):
        """Record that ``func`` depends on every callable in ``dependencies``."""
        self.graph[func] = dependencies

    def topological_sort(self) -> List[Callable]:
        """Return the nodes ordered so that each one follows its dependencies.

        Raises:
            ValueError: if the graph contains a cycle; the message names the
                nodes forming it.
        """
        visited = set()
        path: List[Callable] = []  # nodes on the current DFS path, in order
        on_path = set()
        result = []

        def label(node) -> str:
            return getattr(node, "__name__", repr(node))

        def visit(node):
            if node in on_path:
                cycle = path[path.index(node):] + [node]
                raise ValueError(
                    "Circular dependency detected: "
                    + " -> ".join(label(n) for n in cycle)
                )
            if node in visited:
                return
            on_path.add(node)
            path.append(node)
            for dependency in self.graph.get(node, []):
                visit(dependency)
            path.pop()
            on_path.discard(node)
            visited.add(node)
            result.append(node)

        for node in self.graph:
            if node not in visited:
                visit(node)
        return result

    def optimize_resolution_order(self) -> List[Callable]:
        """Compute the resolution order, store it in ``resolved_order`` and return it."""
        self.resolved_order = self.topological_sort()
        return self.resolved_order
9. 错误处理和调试
9.1 依赖错误处理
class DependencyError(Exception):
    """Root of the dependency-related error hierarchy."""


class CircularDependencyError(DependencyError):
    """Error type for dependencies that (indirectly) depend on themselves."""


class MissingDependencyError(DependencyError):
    """Error type for a dependency that cannot be found."""
def validate_dependencies(dependant: "Dependant") -> List[str]:
    """Report circular dependencies reachable from ``dependant``.

    Nodes are compared by callable identity, not ``__name__``. Two distinct
    functions that share a name (two ``get_db`` helpers, lambdas, ...) are
    therefore not reported as a cycle. Callables without ``__name__``
    (callable instances, ``functools.partial``) are accepted and shown by
    their type name.

    Returns:
        One ``"Circular dependency: a -> b -> a"`` message per edge that
        closes a cycle; an empty list when none is found.
    """
    errors: List[str] = []
    visited = set()

    def display_name(call) -> str:
        return getattr(call, "__name__", type(call).__name__)

    def check_circular(dep, path):
        call = dep.call
        if any(seen is call for seen in path):
            chain = " -> ".join(display_name(c) for c in path + [call])
            errors.append(f"Circular dependency: {chain}")
            return
        if call in visited:
            # Already fully explored through another branch.
            return
        visited.add(call)
        for sub_dep in dep.dependencies:
            check_circular(sub_dep, path + [call])

    check_circular(dependant, [])
    return errors
9.2 依赖调试工具
class DependencyDebugger:
    """Collects timing data and prints dependency trees for debugging."""

    def __init__(self):
        self.call_stack = []
        self.resolution_times = {}

    def trace_dependency_resolution(self, dependant: "Dependant"):
        """Return a ``trace_call(func_name)`` starter for manual timing.

        ``trace_call`` pushes ``func_name`` onto ``call_stack`` and returns
        ``end_trace()``. That callable stores the elapsed seconds in
        ``resolution_times[func_name]`` and pops the stack. ``dependant`` is
        currently unused.
        """
        import time

        def trace_call(func_name: str):
            # perf_counter is monotonic and high-resolution; time.time() can
            # jump with wall-clock adjustments and give negative durations.
            start_time = time.perf_counter()
            self.call_stack.append(func_name)

            def end_trace():
                self.resolution_times[func_name] = time.perf_counter() - start_time
                self.call_stack.pop()

            return end_trace

        return trace_call

    def print_dependency_tree(self, dependant: "Dependant", indent: int = 0):
        """Print ``dependant`` and its sub-dependencies, one space per level."""
        prefix = " " * indent
        call = dependant.call
        # Callable instances (e.g. security schemes) have no __name__.
        name = getattr(call, "__name__", type(call).__name__)
        print(f"{prefix}{name}")
        for dep in dependant.dependencies:
            self.print_dependency_tree(dep, indent + 1)

    def get_performance_report(self) -> Dict[str, float]:
        """Return recorded durations, slowest first."""
        return dict(sorted(
            self.resolution_times.items(),
            key=lambda x: x[1],
            reverse=True
        ))