跳到主要内容

C0726N01-FastAPI 依赖注入系统

概述

FastAPI 的依赖注入系统是其最强大的特性之一，它提供了一种声明式的方式来管理应用程序的依赖关系。本章将深入分析依赖注入系统的设计原理、实现机制以及高级用法。

1. 依赖注入系统架构

1.1 核心概念

依赖注入系统基于以下核心概念:

# Core type definitions
DependencyCallable = Callable[..., Any]


class Depends:
    """Marker object that declares a dependency for a parameter.

    Attributes:
        dependency: The callable supplied by the caller, or ``None``.
        use_cache: Flag stored for the resolver; defaults to ``True``.
    """

    def __init__(
        self,
        dependency: Optional[DependencyCallable] = None,
        *,
        use_cache: bool = True,
    ):
        self.use_cache = use_cache
        self.dependency = dependency

class Dependant:
    """Dependency model: a callable plus the parameters it needs, grouped by source.

    Every list argument is optional; omitting it (or passing an empty list)
    gives the instance its own fresh empty list.
    """

    def __init__(
        self,
        *,
        path_params: Optional[List[ModelField]] = None,
        query_params: Optional[List[ModelField]] = None,
        header_params: Optional[List[ModelField]] = None,
        cookie_params: Optional[List[ModelField]] = None,
        body_params: Optional[List[ModelField]] = None,
        # Sub-dependencies are themselves ``Dependant`` nodes (see get_dependant).
        dependencies: Optional[List["Dependant"]] = None,
        security_requirements: Optional[List[SecurityRequirement]] = None,
        name: Optional[str] = None,
        call: Optional[Callable[..., Any]] = None,
        request_param_name: Optional[str] = None,
        websocket_param_name: Optional[str] = None,
        http_connection_param_name: Optional[str] = None,
        response_param_name: Optional[str] = None,
        background_tasks_param_name: Optional[str] = None,
        security_scopes_param_name: Optional[str] = None,
        use_cache: bool = True,
    ):
        # Parameters grouped by where their value comes from.
        self.path_params = path_params or []
        self.query_params = query_params or []
        self.header_params = header_params or []
        self.cookie_params = cookie_params or []
        self.body_params = body_params or []

        # Dependency tree and security metadata.
        self.dependencies = dependencies or []
        self.security_requirements = security_requirements or []

        # Identity of this node.
        self.name = name
        self.call = call

        # Names of the parameters that receive framework-provided objects.
        self.request_param_name = request_param_name
        self.websocket_param_name = websocket_param_name
        self.http_connection_param_name = http_connection_param_name
        self.response_param_name = response_param_name
        self.background_tasks_param_name = background_tasks_param_name
        self.security_scopes_param_name = security_scopes_param_name

        self.use_cache = use_cache

1.2 依赖解析流程

依赖解析遵循以下流程:

  1. 函数签名分析: 分析端点函数的参数类型和注解
  2. 依赖图构建: 递归分析所有依赖关系
  3. 参数分类: 将参数分类为路径、查询、请求体等类型
  4. 依赖实例化: 按照依赖顺序实例化所有依赖
  5. 缓存管理: 管理依赖实例的生命周期和缓存
def get_dependant(
    *,
    path: str,
    call: Callable[..., Any],
    name: Optional[str] = None,
    security_requirements: Optional[List[SecurityRequirement]] = None,
    use_cache: bool = True,
) -> Dependant:
    """Build the ``Dependant`` tree describing what ``call`` needs.

    Args:
        path: Route path template, used to recognise path parameters.
        call: The endpoint or dependency callable to inspect.
        name: Parameter name under which the resolved value is injected.
        security_requirements: Requirements attached to this node.
        use_cache: Whether the resolved value of this node may be cached.

    Returns:
        A ``Dependant`` whose parameter lists and sub-dependencies are filled in.

    Raises:
        TypeError: If a ``Depends()`` parameter has neither an explicit
            callable nor a type annotation to fall back on.
    """
    # 1. Read the signature.
    signature = inspect.signature(call)
    # Loop-invariant: parse the path template once, not once per parameter.
    path_param_names = get_path_param_names(path)

    # 2. Create the node for this callable.
    dependant = Dependant(
        call=call,
        name=name,
        security_requirements=security_requirements or [],
        use_cache=use_cache,
    )

    # 3. Analyse every parameter.
    for param_name, param in signature.parameters.items():
        param_field = analyze_param(
            param_name=param_name,
            annotation=param.annotation,
            value=param.default,
            is_path_param=param_name in path_param_names,
        )
        field_info = param_field.field_info

        # 4. Classify the parameter.
        if isinstance(field_info, Depends):
            # ``Depends()`` without a callable uses the annotation as the
            # dependency (FastAPI's documented shortcut for class dependencies).
            sub_call = field_info.dependency
            if sub_call is None:
                sub_call = param.annotation
            if sub_call is None or sub_call is inspect.Parameter.empty:
                raise TypeError(
                    f"Parameter '{param_name}' uses Depends() without a "
                    "dependency and has no type annotation to fall back on"
                )
            sub_dependant = get_dependant(
                path=path,
                call=sub_call,
                name=param_name,
                use_cache=field_info.use_cache,
            )
            dependant.dependencies.append(sub_dependant)
        elif is_path_param(param_name, path):
            dependant.path_params.append(param_field)
        elif is_query_param(param_field):
            dependant.query_params.append(param_field)
        elif is_header_param(param_field):
            dependant.header_params.append(param_field)
        elif is_cookie_param(param_field):
            dependant.cookie_params.append(param_field)
        elif is_body_param(param_field):
            dependant.body_params.append(param_field)
        elif is_special_param(param_name):
            # Special parameters (Request, WebSocket, ...).
            set_special_param_name(dependant, param_name, param.annotation)

    return dependant

2. 依赖声明和使用

2.1 基础依赖声明

# Simple dependency function
def get_db():
    """Yield a database session and close it once the consumer is done."""
    session = SessionLocal()
    try:
        yield session
    finally:
        # Runs whether or not the consumer raised.
        session.close()

# Using the dependency
@app.get("/users/")
async def read_users(db: Session = Depends(get_db)):
    """Return every ``User`` row through the injected session."""
    all_users = db.query(User).all()
    return all_users

2.2 依赖链

# Base dependency
def get_db():
    """Provide a session from ``SessionLocal`` and always close it afterwards."""
    connection = SessionLocal()
    try:
        yield connection
    finally:
        connection.close()

# Higher-level dependency built on top of the base dependency
def get_user_service(db: Session = Depends(get_db)):
    """Wrap the injected session in a ``UserService``."""
    service = UserService(db)
    return service

# Final endpoint
@app.get("/users/{user_id}")
async def get_user(
    user_id: int,
    user_service: UserService = Depends(get_user_service),
):
    """Look up one user by id through the injected ``UserService``."""
    found = user_service.get_user(user_id)
    return found

2.3 类作为依赖

class DatabaseManager:
    """Owns one connection obtained from ``create_connection``."""

    def __init__(self):
        self.connection = create_connection()

    def get_session(self):
        """Return a session created by the underlying connection."""
        connection = self.connection
        return connection.session()

    def close(self):
        """Close the underlying connection."""
        connection = self.connection
        connection.close()

# Using a class as a dependency
@app.get("/data")
async def get_data(db_manager: DatabaseManager = Depends(DatabaseManager)):
    """Obtain a session from the injected manager and return sample data."""
    session = db_manager.get_session()
    # ... use ``session`` here ...
    payload = {"data": "example"}
    return payload

2.4 子依赖

# Authentication dependency
async def get_current_user(token: str = Depends(oauth2_scheme)):
    """Decode the token into a user; respond 401 when decoding yields nothing."""
    user = decode_token(token)
    if user:
        return user
    raise HTTPException(status_code=401, detail="Invalid token")

# Permission-check dependency
async def get_admin_user(current_user: User = Depends(get_current_user)):
    """Pass the authenticated user through only when ``is_admin`` is truthy."""
    if current_user.is_admin:
        return current_user
    raise HTTPException(status_code=403, detail="Not enough permissions")

# Using multi-level dependencies
@app.delete("/users/{user_id}")
async def delete_user(
    user_id: int,
    admin_user: User = Depends(get_admin_user),  # pulls in the authentication check too
    db: Session = Depends(get_db),
):
    """Delete the user with ``user_id``; reachable only by administrators."""
    matching_users = db.query(User).filter(User.id == user_id)
    matching_users.delete()
    db.commit()
    return {"message": "User deleted"}

3. 依赖解析实现

3.1 参数分析

def analyze_param(
    *,
    param_name: str,
    annotation: Any,
    value: Any,
    is_path_param: bool = False,
) -> ModelField:
    """Turn one function parameter into a ``ModelField``.

    Args:
        param_name: Name of the parameter.
        annotation: Its type annotation as reported by ``inspect``.
        value: Its default, or ``Parameter.empty`` when there is none.
        is_path_param: Whether the name appears in the route path.

    Returns:
        The field produced by the matching ``create_*_field`` helper.
    """
    # 1. Fill in a default when none was given.
    # Identity comparison: ``==`` would call the default value's own
    # ``__eq__``, which may be overloaded (e.g. array-like defaults).
    if value is Parameter.empty:
        if is_path_param:
            # Path parameters are always required.
            value = Path(...)
        else:
            value = Required

    # 2. Framework-provided types short-circuit everything else.
    if annotation == Request:
        return create_request_field(param_name)
    if annotation == WebSocket:
        return create_websocket_field(param_name)
    if annotation == Response:
        return create_response_field(param_name)
    if annotation == BackgroundTasks:
        return create_background_tasks_field(param_name)

    # 3. Dependency injection.
    if isinstance(value, Depends):
        return create_dependency_field(param_name, annotation, value)

    # 4. Explicit parameter declarations.
    if isinstance(value, (Path, Query, Header, Cookie, Body, Form, File)):
        return create_param_field(param_name, annotation, value)

    # 5. Everything else.
    return create_default_field(param_name, annotation, value)

3.2 依赖实例化

async def solve_dependencies(
    *,
    request: Union[Request, WebSocket],
    dependant: Dependant,
    body: Optional[Union[Dict[str, Any], FormData]] = None,
    background_tasks: Optional[BackgroundTasks] = None,
    response: Optional[Response] = None,
    dependency_overrides_provider: Optional[Any] = None,
    dependency_cache: Optional[Dict[Tuple[Callable[..., Any], Tuple[str, ...]], Any]] = None,
) -> Tuple[Dict[str, Any], List[ErrorWrapper], Optional[BackgroundTasks], Optional[Response]]:
    """Resolve every parameter and sub-dependency of ``dependant``.

    Args:
        request: The incoming request or websocket.
        dependant: The node whose inputs are resolved.
        body: Already-parsed request body, if any.
        background_tasks: Shared task collection; created on demand.
        response: Response object handed to parameters that ask for it.
        dependency_overrides_provider: Only forwarded to recursive calls here.
        dependency_cache: Cache shared by the whole resolution tree; a new
            one is created when omitted.

    Returns:
        ``(values, errors, background_tasks, response)``.
    """
    values: Dict[str, Any] = {}
    errors: List[ErrorWrapper] = []

    # One cache for the whole tree. ``is None`` (not ``or {}``) keeps a
    # caller-supplied empty dict, so the caller can observe the entries.
    if dependency_cache is None:
        dependency_cache = {}

    # 1. Path parameters.
    path_values, path_errors = get_path_params_values(
        request=request,
        dependant=dependant,
    )
    values.update(path_values)
    errors.extend(path_errors)

    # 2. Query parameters.
    # NOTE(review): only this extractor is awaited - confirm against the helpers.
    query_values, query_errors = await get_query_params_values(
        request=request,
        dependant=dependant,
    )
    values.update(query_values)
    errors.extend(query_errors)

    # 3. Header parameters.
    header_values, header_errors = get_header_params_values(
        request=request,
        dependant=dependant,
    )
    values.update(header_values)
    errors.extend(header_errors)

    # 4. Cookie parameters.
    cookie_values, cookie_errors = get_cookie_params_values(
        request=request,
        dependant=dependant,
    )
    values.update(cookie_values)
    errors.extend(cookie_errors)

    # 5. Body parameters.
    if dependant.body_params:
        body_values, body_errors = await get_body_params_values(
            request=request,
            dependant=dependant,
            body=body,
        )
        values.update(body_values)
        errors.extend(body_errors)

    # 6. Framework-provided objects.
    if dependant.request_param_name:
        values[dependant.request_param_name] = request
    if dependant.websocket_param_name:
        values[dependant.websocket_param_name] = request
    if dependant.http_connection_param_name:
        # The request/websocket object is the connection itself.
        values[dependant.http_connection_param_name] = request
    if dependant.background_tasks_param_name:
        if background_tasks is None:
            background_tasks = BackgroundTasks()
        values[dependant.background_tasks_param_name] = background_tasks
    if dependant.response_param_name:
        values[dependant.response_param_name] = response

    # 7. Sub-dependencies.
    for dependency in dependant.dependencies:
        # This function returns four values, so unpack four. Keep the
        # background-task collection a nested dependency may have created.
        (
            dependency_values,
            dependency_errors,
            background_tasks,
            _sub_response,
        ) = await solve_dependencies(
            request=request,
            dependant=dependency,
            body=body,
            background_tasks=background_tasks,
            response=response,
            dependency_overrides_provider=dependency_overrides_provider,
            dependency_cache=dependency_cache,
        )

        if dependency_errors:
            # Do not call a dependency whose inputs failed validation.
            errors.extend(dependency_errors)
            continue

        dependency_value = await call_dependency(
            dependency=dependency,
            values=dependency_values,
            dependency_cache=dependency_cache,
        )

        if dependency.name:
            values[dependency.name] = dependency_value

    return values, errors, background_tasks, response

3.3 依赖调用

async def call_dependency(
    *,
    dependency: Dependant,
    values: Dict[str, Any],
    dependency_cache: Optional[Dict[Tuple[Callable[..., Any], Tuple[str, ...]], Any]] = None,
) -> Any:
    """Call ``dependency.call`` with the matching ``values`` and return its result.

    Coroutine functions are awaited. For generator and async-generator
    results the first yielded item is the value (``None`` if nothing is
    yielded). Results are stored in ``dependency_cache`` when caching is
    enabled for the dependency and a cache is supplied.
    """
    caching = dependency.use_cache and dependency_cache is not None
    cache_key = None

    # 1. Cache lookup (the key is computed once and reused for the store).
    if caching:
        cache_key = get_dependency_cache_key(dependency, values)
        if cache_key in dependency_cache:
            return dependency_cache[cache_key]

    # 2. Keep only the values the callable accepts. The accepted names are
    # loop-invariant, so they are looked up once.
    accepted_names = get_function_param_names(dependency.call)
    kwargs = {
        param_name: param_value
        for param_name, param_value in values.items()
        if param_name in accepted_names
    }

    # 3. Invoke.
    if asyncio.iscoroutinefunction(dependency.call):
        result = await dependency.call(**kwargs)
    else:
        result = dependency.call(**kwargs)

    # 4. Generators: take the first yielded item as the value.
    if inspect.isasyncgen(result) or inspect.isgenerator(result):
        try:
            if inspect.isasyncgen(result):
                dependency_value = await result.__anext__()
            else:
                dependency_value = next(result)
        except (StopIteration, StopAsyncIteration):
            # An exhausted async generator raises StopAsyncIteration.
            dependency_value = None

        # Register the generator for later cleanup.
        # NOTE(review): only done when the dependant exposes
        # ``cleanup_functions``; otherwise the generator is never finalised here.
        if hasattr(dependency, 'cleanup_functions'):
            dependency.cleanup_functions.append(result)
    else:
        dependency_value = result

    # 5. Cache store.
    if caching:
        dependency_cache[cache_key] = dependency_value

    return dependency_value

4. 依赖缓存机制

4.1 缓存键生成

def get_dependency_cache_key(
    dependency: Dependant,
    values: Dict[str, Any]
) -> Tuple[Callable[..., Any], Tuple[str, ...]]:
    """Return ``(callable, sorted scopes)`` as the cache key.

    ``values`` is accepted but does not take part in the key.
    """
    requirements = dependency.security_requirements
    if not requirements:
        # No requirements: the scope component is the empty tuple.
        return (dependency.call, ())

    scope_names = [requirement.scope for requirement in requirements]
    scope_names.sort()
    return (dependency.call, tuple(scope_names))

4.2 缓存生命周期

class DependencyCache:
    """Key/value cache plus the cleanup callbacks that belong to it."""

    def __init__(self):
        self._cache: Dict[Any, Any] = {}
        self._cleanup_functions: List[Callable] = []

    def get(self, key: Any) -> Any:
        """Return the cached value for ``key``, or ``None`` when absent."""
        return self._cache.get(key)

    def set(self, key: Any, value: Any) -> None:
        """Store ``value`` under ``key``."""
        self._cache[key] = value

    def add_cleanup(self, cleanup_func: Callable) -> None:
        """Register a zero-argument callback (sync or async) for ``cleanup``."""
        self._cleanup_functions.append(cleanup_func)

    async def cleanup(self) -> None:
        """Run all cleanup callbacks, then drop every cached value.

        Callbacks run last-registered-first, so resources acquired later
        (which may depend on earlier ones) are released first. A failing
        callback is logged and does not stop the remaining ones.
        """
        import inspect

        try:
            # pop() gives LIFO order and also drains callbacks that are
            # registered while cleanup is running.
            while self._cleanup_functions:
                cleanup_func = self._cleanup_functions.pop()
                try:
                    outcome = cleanup_func()
                    # Covers coroutine functions and plain callables
                    # (lambdas, partials) that return an awaitable.
                    if inspect.isawaitable(outcome):
                        await outcome
                except Exception as e:
                    # Record the failure but keep going.
                    logger.error(f"Error during dependency cleanup: {e}")
        finally:
            # Reset even if cleanup is interrupted (e.g. cancellation).
            self._cache.clear()
            self._cleanup_functions.clear()

4.3 请求级缓存

# Request-level caching example
def get_expensive_resource():
    """Create the expensive resource."""
    # Printed once per request that resolves this dependency. The cache is
    # per request, so separate requests each create (and print) again.
    print("Creating expensive resource...")
    return ExpensiveResource()

@app.get("/endpoint1")
async def endpoint1(
    resource: ExpensiveResource = Depends(get_expensive_resource),
):
    """Report the id of the injected resource for endpoint 1."""
    resource_id = resource.id
    return {"endpoint": "1", "resource_id": resource_id}

@app.get("/endpoint2")
async def endpoint2(
    # Resolved afresh for this request: caching only de-duplicates repeated
    # uses of the same dependency within a single request.
    resource: ExpensiveResource = Depends(get_expensive_resource),
):
    """Report the id of the injected resource for endpoint 2."""
    return {"endpoint": "2", "resource_id": resource.id}

5. 高级依赖模式

5.1 条件依赖

def get_db_connection(use_replica: bool = Query(False)):
    """Return the replica database when ``use_replica`` is set, else the master."""
    return get_replica_db() if use_replica else get_master_db()

@app.get("/data")
async def get_data(
    db=Depends(get_db_connection),  # database chosen from the query parameter
):
    """Run the sample query on whichever database was injected."""
    rows = db.query("SELECT * FROM data")
    return rows

5.2 工厂模式依赖

class ServiceFactory:
    """Creates each service on first request and reuses it afterwards."""

    def __init__(self):
        self._services = {}

    def get_service(self, service_type: str):
        """Return the service registered for ``service_type``.

        Raises:
            ValueError: If ``service_type`` is neither ``"user"`` nor ``"order"``.
        """
        registry = self._services
        if service_type in registry:
            return registry[service_type]

        if service_type == "user":
            created = UserService()
        elif service_type == "order":
            created = OrderService()
        else:
            raise ValueError(f"Unknown service type: {service_type}")

        registry[service_type] = created
        return created

# Global factory instance
service_factory = ServiceFactory()


def get_user_service():
    """Dependency returning the factory's ``"user"`` service."""
    user_service = service_factory.get_service("user")
    return user_service


def get_order_service():
    """Dependency returning the factory's ``"order"`` service."""
    order_service = service_factory.get_service("order")
    return order_service

@app.get("/users/{user_id}")
async def get_user(
    user_id: int,
    user_service=Depends(get_user_service),
):
    """Fetch one user through the factory-provided service."""
    found = user_service.get_user(user_id)
    return found

5.3 上下文管理器依赖

from contextlib import asynccontextmanager

@asynccontextmanager
async def get_db_transaction():
    """Database transaction context manager.

    Commits when the ``async with`` body finishes normally, rolls back and
    re-raises on an exception, and always closes the session.
    """
    db = SessionLocal()
    transaction = db.begin()
    try:
        yield db
        transaction.commit()
    except Exception:
        transaction.rollback()
        raise
    finally:
        db.close()


async def get_db_transaction_dependency():
    """Yield-style dependency that wraps ``get_db_transaction``.

    ``Depends`` must be given a plain (async) generator function. Calling
    the ``@asynccontextmanager``-decorated function directly would inject
    the context-manager object instead of the session it yields.
    """
    async with get_db_transaction() as db:
        yield db


# Using the context-manager-backed dependency
@app.post("/users/")
async def create_user(
    user_data: UserCreate,
    db: Session = Depends(get_db_transaction_dependency),
):
    """Add a new user inside the injected transaction."""
    # Transaction handling is automatic.
    user = User(**user_data.dict())
    db.add(user)
    # The dependency's exit code commits or rolls back the transaction.
    return user

5.4 依赖覆盖

# Original dependency
def get_settings():
    """Return the default ``Settings``."""
    settings = Settings()
    return settings


# Replacement used while testing
def get_test_settings():
    """Return ``Settings`` flagged for testing with an in-memory SQLite URL."""
    test_settings = Settings(testing=True, database_url="sqlite:///:memory:")
    return test_settings


# Override the dependency in tests
app.dependency_overrides[get_settings] = get_test_settings

@app.get("/config")
async def get_config(settings: Settings = Depends(get_settings)):
    """Expose the database URL of the injected settings."""
    # Under test, the overriding settings provider is used instead.
    database_url = settings.database_url
    return {"database_url": database_url}

6. 安全依赖

6.1 认证依赖

from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials

security = HTTPBearer()


async def get_current_user(
    credentials: HTTPAuthorizationCredentials = Depends(security)
) -> User:
    """Return the user identified by the bearer token.

    Raises:
        HTTPException: 401 when the token cannot be decoded, carries no
            ``sub`` claim, or names a user that does not exist.
    """
    token = credentials.credentials
    # A 401 response must tell the client which scheme to authenticate with.
    challenge = {"WWW-Authenticate": "Bearer"}

    # Only the decode can raise JWTError, so keep the try body to that call.
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except JWTError as exc:
        raise HTTPException(
            status_code=401, detail="Invalid token", headers=challenge
        ) from exc

    username = payload.get("sub")
    if username is None:
        raise HTTPException(
            status_code=401, detail="Invalid token", headers=challenge
        )

    user = get_user_by_username(username)
    if user is None:
        raise HTTPException(
            status_code=401, detail="User not found", headers=challenge
        )

    return user

@app.get("/protected")
async def protected_endpoint(
    current_user: User = Depends(get_current_user),
):
    """Greet the authenticated user by username."""
    greeting = f"Hello, {current_user.username}!"
    return {"message": greeting}

6.2 权限依赖

from functools import wraps
from typing import List

def require_permissions(required_permissions: List[str]):
    """Build a dependency that enforces ``required_permissions``.

    The returned dependency yields the current user, or raises a 403
    ``HTTPException`` naming the first required permission the user lacks.
    """

    def permission_checker(
        current_user: User = Depends(get_current_user),
    ) -> User:
        granted = get_user_permissions(current_user.id)

        for needed in required_permissions:
            if needed in granted:
                continue
            raise HTTPException(
                status_code=403,
                detail=f"Permission '{needed}' required",
            )

        return current_user

    return permission_checker

# Using the permission dependency
@app.delete("/users/{user_id}")
async def delete_user(
    user_id: int,
    current_user: User = Depends(require_permissions(["user:delete"])),
):
    """Delete a user; only callers holding ``user:delete`` get this far."""
    delete_user_by_id(user_id)
    confirmation = {"message": "User deleted"}
    return confirmation

6.3 作用域依赖

from fastapi.security import SecurityScopes

async def get_current_user_with_scopes(
    security_scopes: SecurityScopes,
    token: str = Depends(oauth2_scheme)
) -> User:
    """Authenticate the token's user and check the required scopes.

    Raises:
        HTTPException: 401 when the token is invalid, has no ``sub`` claim,
            or names an unknown user; 403 when a required scope is missing
            from the token.
    """
    authenticate_value = (
        f'Bearer scope="{security_scopes.scope_str}"'
        if security_scopes.scopes
        else "Bearer"
    )
    challenge_headers = {"WWW-Authenticate": authenticate_value}

    credentials_exception = HTTPException(
        status_code=401,
        detail="Could not validate credentials",
        headers=challenge_headers,
    )

    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
        username: str = payload.get("sub")
        if username is None:
            raise credentials_exception
        token_scopes = payload.get("scopes", [])
    except JWTError:
        raise credentials_exception

    user = get_user_by_username(username)
    if user is None:
        raise credentials_exception

    # Every required scope must be present in the token.
    if any(scope not in token_scopes for scope in security_scopes.scopes):
        raise HTTPException(
            status_code=403,
            detail="Not enough permissions",
            headers={"WWW-Authenticate": authenticate_value},
        )

    return user

# Using scopes
@app.get("/users/me/items/")
async def read_own_items(
    current_user: User = Security(
        get_current_user_with_scopes,
        scopes=["items:read"],
    ),
):
    """Return the items of the authenticated user (needs ``items:read``)."""
    owner_id = current_user.id
    return get_user_items(owner_id)

7. 依赖测试

7.1 依赖模拟

import pytest
from fastapi.testclient import TestClient

# Mock dependencies
def mock_get_db():
    """Stand-in for ``get_db`` that returns a ``MockDatabase``."""
    fake_db = MockDatabase()
    return fake_db


def mock_get_current_user():
    """Stand-in for ``get_current_user`` that returns a fixed test user."""
    test_user = User(id=1, username="testuser", email="test@example.com")
    return test_user

# Test configuration
@pytest.fixture
def client():
    """Yield a ``TestClient`` with the database and user dependencies mocked.

    Overrides that existed before the fixture ran are restored afterwards,
    even if starting the client fails.
    """
    # Remember what was installed so teardown does not wipe it.
    previous_overrides = dict(app.dependency_overrides)

    # Install the overrides.
    app.dependency_overrides[get_db] = mock_get_db
    app.dependency_overrides[get_current_user] = mock_get_current_user

    try:
        with TestClient(app) as test_client:
            yield test_client
    finally:
        # Put the override table back exactly as it was.
        app.dependency_overrides.clear()
        app.dependency_overrides.update(previous_overrides)

# Test case
def test_protected_endpoint(client):
    """The protected route greets the mocked user."""
    response = client.get("/protected")
    body = response.json()
    assert response.status_code == 200
    assert body["message"] == "Hello, testuser!"

7.2 依赖隔离测试

def test_dependency_isolation():
    """Each request resolves the dependency again."""
    call_count = 0

    def counting_dependency():
        nonlocal call_count
        call_count += 1
        return f"call_{call_count}"

    # Build a throwaway application.
    test_app = FastAPI()

    @test_app.get("/test1")
    async def test1(dep=Depends(counting_dependency)):
        return {"result": dep}

    @test_app.get("/test2")
    async def test2(dep=Depends(counting_dependency)):
        return {"result": dep}

    with TestClient(test_app) as client:
        # Every request should get its own dependency instance.
        first_body = client.get("/test1").json()
        second_body = client.get("/test2").json()

    assert first_body["result"] == "call_1"
    assert second_body["result"] == "call_2"

8. 性能优化

8.1 依赖预编译

class DependencyCompiler:
    """Memoises the ``Dependant`` built for each callable."""

    def __init__(self):
        self._compiled_dependencies = {}

    def compile_dependency(self, func: Callable) -> Dependant:
        """Return the ``Dependant`` for ``func``, building it on first use."""
        registry = self._compiled_dependencies
        if func not in registry:
            registry[func] = get_dependant(path="", call=func)
        return registry[func]

    def get_compiled_dependency(self, func: Callable) -> Optional[Dependant]:
        """Return the stored ``Dependant`` for ``func``, or ``None``."""
        return self._compiled_dependencies.get(func)


# Global compiler instance
dependency_compiler = DependencyCompiler()

8.2 依赖池化

from typing import Dict, List
import asyncio

class DependencyPool:
    """Pool of reusable dependency instances, one pool per factory callable."""

    def __init__(self, max_size: int = 100):
        self.max_size = max_size
        self._pools: Dict[Callable, List[Any]] = {}
        self._locks: Dict[Callable, asyncio.Lock] = {}

    def _lock_for(self, dependency_func: Callable) -> asyncio.Lock:
        """Return the lock guarding ``dependency_func``'s pool, creating it if needed."""
        lock = self._locks.get(dependency_func)
        if lock is None:
            lock = self._locks[dependency_func] = asyncio.Lock()
        return lock

    async def get_dependency(self, dependency_func: Callable) -> Any:
        """Return a pooled instance, or create one when the pool is empty."""
        async with self._lock_for(dependency_func):
            pool = self._pools.get(dependency_func)
            if pool:
                return pool.pop()

        # Create outside the lock so a slow factory does not block other
        # callers of the same pool.
        if asyncio.iscoroutinefunction(dependency_func):
            return await dependency_func()
        return dependency_func()

    async def return_dependency(self, dependency_func: Callable, instance: Any) -> None:
        """Put ``instance`` back into the pool, or close it when the pool is full."""
        import inspect

        # The lock is created on demand: returning an instance for a factory
        # that never went through get_dependency must not raise KeyError.
        async with self._lock_for(dependency_func):
            pool = self._pools.setdefault(dependency_func, [])
            if len(pool) < self.max_size:
                pool.append(instance)
                return

        # Pool is full: destroy the instance.
        close = getattr(instance, 'close', None)
        if callable(close):
            outcome = close()
            # ``close`` may be sync or async; await only when there is
            # something to await.
            if inspect.isawaitable(outcome):
                await outcome

8.3 依赖图优化

class DependencyGraph:
    """Dependency graph with a topological ordering.

    Attributes:
        graph: Maps each node to the list of nodes it depends on.
        resolved_order: Order computed by the last
            ``optimize_resolution_order`` call; reset when the graph changes.
    """

    def __init__(self):
        self.graph = {}
        self.resolved_order = []

    def add_dependency(self, func: Callable, dependencies: List[Callable]):
        """Record that ``func`` depends on ``dependencies``."""
        self.graph[func] = dependencies
        # Any previously computed order is stale now.
        self.resolved_order = []

    def topological_sort(self) -> List[Callable]:
        """Return the nodes so that every node comes after its dependencies.

        Raises:
            ValueError: If the graph contains a cycle.
        """
        visited = set()
        in_progress = set()
        order = []

        # Iterative depth-first search: an explicit stack avoids
        # RecursionError on long dependency chains.
        for root in self.graph:
            if root in visited:
                continue

            in_progress.add(root)
            stack = [(root, iter(self.graph.get(root, [])))]
            while stack:
                node, pending = stack[-1]
                for child in pending:
                    if child in in_progress:
                        raise ValueError("Circular dependency detected")
                    if child not in visited:
                        in_progress.add(child)
                        stack.append((child, iter(self.graph.get(child, []))))
                        break
                else:
                    # All dependencies of ``node`` are emitted; emit it too.
                    stack.pop()
                    in_progress.discard(node)
                    visited.add(node)
                    order.append(node)

        return order

    def optimize_resolution_order(self) -> List[Callable]:
        """Compute the resolution order, store it, and return a copy."""
        self.resolved_order = self.topological_sort()
        return list(self.resolved_order)

9. 错误处理和调试

9.1 依赖错误处理

class DependencyError(Exception):
    """Base class for dependency errors."""


class CircularDependencyError(DependencyError):
    """A dependency cycle was found."""


class MissingDependencyError(DependencyError):
    """A required dependency is missing."""

def validate_dependencies(dependant: Dependant) -> List[str]:
    """Check the dependency tree for cycles.

    Args:
        dependant: Root of the tree to check.

    Returns:
        One ``"Circular dependency: a -> b -> a"`` message per cycle found;
        an empty list when there is none.
    """
    errors: List[str] = []
    visited = set()

    def display_name(call) -> str:
        # Partials and callable instances have no ``__name__``.
        return getattr(call, "__name__", None) or repr(call)

    def check_circular(dep, ancestors, names):
        call = dep.call
        label = display_name(call)

        # Compare callables by identity, not by name: two different
        # functions may share a ``__name__`` without forming a cycle.
        if call is not None and any(call is seen for seen in ancestors):
            chain = " -> ".join(names + [label])
            errors.append(f"Circular dependency: {chain}")
            return

        if call is not None:
            if call in visited:
                return
            visited.add(call)

        for sub_dep in dep.dependencies:
            check_circular(sub_dep, ancestors + [call], names + [label])

    check_circular(dependant, [], [])
    return errors

9.2 依赖调试工具

class DependencyDebugger:
    """Collects timing and nesting information about dependency resolution.

    Attributes:
        call_stack: Names currently being traced, innermost last.
        resolution_times: Seconds taken by the last trace of each name.
    """

    def __init__(self):
        self.call_stack = []
        self.resolution_times = {}

    def trace_dependency_resolution(self, dependant: Dependant):
        """Return a ``trace_call(func_name)`` function.

        Calling ``trace_call`` starts timing ``func_name`` and returns an
        ``end_trace()`` callback that records the elapsed time.
        ``dependant`` itself is not used.
        """
        import time

        def trace_call(func_name: str):
            # perf_counter is monotonic; wall-clock time can jump and
            # produce negative or skewed durations.
            start_time = time.perf_counter()
            self.call_stack.append(func_name)

            def end_trace():
                self.resolution_times[func_name] = time.perf_counter() - start_time
                self.call_stack.pop()

            return end_trace

        return trace_call

    def print_dependency_tree(self, dependant: Dependant, indent: int = 0):
        """Print the callable names of the tree, one space per nesting level."""
        prefix = " " * indent
        print(f"{prefix}{dependant.call.__name__}")

        for dep in dependant.dependencies:
            self.print_dependency_tree(dep, indent + 1)

    def get_performance_report(self) -> Dict[str, float]:
        """Return the recorded times ordered from slowest to fastest."""
        return dict(sorted(
            self.resolution_times.items(),
            key=lambda item: item[1],
            reverse=True,
        ))