C0726N01-FastAPI 安全认证系统
概述
FastAPI的安全认证系统基于OpenAPI标准,提供了完整的身份验证和授权解决方案。本章将深入分析安全系统的架构设计、认证机制实现以及各种安全模式的具体应用。
1. 安全系统架构
1.1 核心组件
FastAPI安全系统由以下核心组件构成(以下代码为便于说明的简化示意,并非FastAPI源码原文):
import hashlib
import secrets
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from enum import Enum
from typing import Any, Dict, List, Optional, Set

import jwt
from fastapi import Depends, FastAPI, HTTPException, Request, Security, status
from fastapi.security import (
    APIKeyCookie,
    APIKeyHeader,
    APIKeyQuery,
    HTTPAuthorizationCredentials,
    HTTPBasic,
    HTTPBasicCredentials,
    HTTPBearer,
    OAuth2PasswordBearer,
    OAuth2PasswordRequestForm,
    SecurityScopes,
)
from pydantic import BaseModel, Field
# Security base class
class SecurityBase:
    """Base class for security schemes (simplified illustration).

    Subclasses must implement ``_get_security_model`` (the model describing the
    scheme) and ``__call__`` (extract credential data from a request).

    Args:
        scheme_name: Identifier of the scheme; defaults to the subclass name.
        auto_error: Whether callers such as ``solve_security`` should raise a
            401 when ``__call__`` returns None. Defaults to True.
    """

    def __init__(self, *, scheme_name: Optional[str] = None, auto_error: bool = True):
        self.scheme_name = scheme_name or self.__class__.__name__
        self.auto_error = auto_error
        # Annotations are quoted: SecurityBaseModel is defined after this class,
        # and an unquoted return annotation would raise NameError at
        # class-definition time.
        self.model: "SecurityBaseModel" = self._get_security_model()

    def _get_security_model(self) -> "SecurityBaseModel":
        """Return the model describing this scheme; subclasses must override."""
        raise NotImplementedError()

    def __call__(self, request: "Request") -> Optional[str]:
        """Extract security information from ``request``; subclasses must override."""
        raise NotImplementedError()
# Security model base class
class SecurityBaseModel(BaseModel):
    """Base pydantic model describing a security scheme.

    NOTE(review): presumably mirrors an OpenAPI security-scheme object; confirm.
    """
    # Populated from input under the alias "type" (the attribute is named
    # ``type_`` to avoid shadowing the builtin).
    type_: str = Field(alias="type")
    # Optional human-readable description of the scheme.
    description: Optional[str] = None
# Security requirement
@dataclass
class SecurityRequirement:
    """Associate a security scheme with the scopes a route requires from it."""

    security_scheme: SecurityBase
    # default_factory gives every instance its own list.
    scopes: List[str] = field(default_factory=list)

    @property
    def security_scopes(self) -> List[str]:
        """Alias of ``scopes``; returns the same list object, not a copy."""
        required = self.scopes
        return required
1.2 安全处理流程
安全认证遵循以下处理流程:
- 安全方案注册: 在路由中声明安全依赖
- 请求拦截: 从HTTP请求中提取认证信息
- 身份验证: 验证用户身份和凭据
- 权限检查: 检查用户权限和作用域
- 上下文注入: 将用户信息注入到请求上下文
- 错误处理: 处理认证和授权失败
async def solve_security(
    request: Request,
    dependant: Dependant,
    security_scopes: SecurityScopes,
    dependency_overrides_provider: Any = None,
) -> Dict[str, Any]:
    """Resolve every security requirement declared on ``dependant``.

    Simplified illustration of the resolution loop.

    Args:
        request: Incoming request passed to each security scheme.
        dependant: Object exposing ``security_requirements``.
        security_scopes: Scopes forwarded to ``verify_security_scopes``.
        dependency_overrides_provider: Accepted for interface compatibility;
            not used here.

    Returns:
        Mapping of ``scheme_name`` to the value each scheme extracted.
        Schemes that yield None with ``auto_error`` disabled are omitted.

    Raises:
        HTTPException: 401 when a scheme yields None and its ``auto_error`` is
            true (or the scheme has no ``auto_error`` attribute).
    """
    import inspect

    values: Dict[str, Any] = {}
    for requirement in dependant.security_requirements:
        scheme = requirement.security_scheme
        required_scopes = requirement.scopes

        # Schemes may implement __call__ synchronously (as SecurityBase does)
        # or as a coroutine; awaiting a plain value would raise TypeError.
        security_value = scheme(request)
        if inspect.isawaitable(security_value):
            security_value = await security_value

        if security_value is None:
            # SecurityBase-style schemes may lack ``auto_error``; default to
            # failing loudly rather than crashing with AttributeError.
            if getattr(scheme, "auto_error", True):
                raise HTTPException(
                    status_code=status.HTTP_401_UNAUTHORIZED,
                    detail="认证失败",
                    headers={"WWW-Authenticate": scheme.scheme_name}
                )
            continue

        if required_scopes:
            await verify_security_scopes(
                security_value, required_scopes, security_scopes
            )
        values[scheme.scheme_name] = security_value
    return values
2. HTTP Basic 认证
2.1 基本实现
from fastapi.security import HTTPBasic, HTTPBasicCredentials
import secrets
# HTTP Basic authentication scheme. HTTPBasicAuth, EnhancedHTTPBasicAuth and
# the /token login route all depend on this shared module-level instance.
basic_auth = HTTPBasic()
# Demo user database keyed by username.
# Passwords are stored in plaintext purely for illustration; never do this in
# a real system.
fake_users_db = {
    "admin": dict(
        username="admin",
        password="secret123",
        email="admin@example.com",
        roles=["admin"],
    ),
    "user": dict(
        username="user",
        password="password123",
        email="user@example.com",
        roles=["user"],
    ),
}
class HTTPBasicAuth:
    """HTTP Basic authentication dependency backed by ``fake_users_db``.

    Args:
        realm: Realm advertised in the ``WWW-Authenticate`` challenge.
    """

    def __init__(self, realm: str = "Secure Area"):
        self.realm = realm
        # NOTE(review): not consulted by __call__, which depends on the
        # module-level ``basic_auth`` (a Depends default cannot reference self).
        # The realm therefore only appears in the 401 challenge below.
        self.security = HTTPBasic(realm=realm)

    def authenticate_user(self, username: str, password: str) -> Optional[dict]:
        """Return the user record when the credentials match, otherwise None.

        Passwords are compared as UTF-8 bytes with ``secrets.compare_digest``.
        Given ``str`` arguments, that function raises TypeError on any
        non-ASCII character, which turned such login attempts into 500 errors.
        The comparison also runs for unknown usernames, so response time
        depends less on whether the account exists.
        """
        user = fake_users_db.get(username)
        stored = user["password"] if user else ""
        password_ok = secrets.compare_digest(
            stored.encode("utf-8"), password.encode("utf-8")
        )
        if not user or not password_ok:
            return None
        return user

    def __call__(self, credentials: HTTPBasicCredentials = Depends(basic_auth)):
        """FastAPI dependency: return the authenticated user or raise 401."""
        user = self.authenticate_user(credentials.username, credentials.password)
        if not user:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="用户名或密码错误",
                headers={"WWW-Authenticate": f'Basic realm="{self.realm}"'}
            )
        return user
# Shared Basic-auth dependency instance.
http_basic_auth = HTTPBasicAuth()


# Example route
@app.get("/basic-protected")
async def basic_protected_route(current_user: dict = Depends(http_basic_auth)):
    """Return a greeting with the Basic-authenticated user's name and roles."""
    username = current_user["username"]
    roles = current_user["roles"]
    return {"message": "访问成功", "user": username, "roles": roles}
2.2 增强的Basic认证
import bcrypt
from datetime import datetime, timedelta
from typing import Optional, Dict, Any
class EnhancedHTTPBasicAuth:
    """HTTP Basic authentication with bcrypt password hashes and account lockout.

    After ``max_attempts`` failed logins for a username, further attempts are
    rejected with 423. The lock lasts until ``lockout_duration`` seconds have
    passed since the most recent failure.

    Args:
        realm: Realm advertised in the ``WWW-Authenticate`` challenge.
        max_attempts: Number of failures that triggers a lockout.
        lockout_duration: Lockout window in seconds.
    """

    def __init__(
        self,
        realm: str = "Secure Area",
        max_attempts: int = 5,
        lockout_duration: int = 300  # 5 minutes
    ):
        self.realm = realm
        # NOTE(review): not consulted by __call__, which depends on the
        # module-level ``basic_auth`` scheme.
        self.security = HTTPBasic(realm=realm)
        self.max_attempts = max_attempts
        self.lockout_duration = lockout_duration
        # username -> {"count", "first_attempt", "locked_until"}, held in process
        # memory only. NOTE(review): unknown usernames are recorded too, so
        # this grows with every distinct name tried. Entries are dropped only
        # on success or when an expired entry is next checked.
        self.failed_attempts: Dict[str, Dict[str, Any]] = {}

    def is_account_locked(self, username: str) -> bool:
        """Return True if ``username`` has reached ``max_attempts`` within the window.

        An entry whose ``locked_until`` has passed is discarded, which also
        resets its failure count.
        """
        attempt_info = self.failed_attempts.get(username)
        if attempt_info is None:
            return False
        if datetime.utcnow() > attempt_info["locked_until"]:
            del self.failed_attempts[username]
            return False
        return attempt_info["count"] >= self.max_attempts

    def record_failed_attempt(self, username: str):
        """Count one failure and move the lockout deadline to now + lockout_duration."""
        now = datetime.utcnow()
        locked_until = now + timedelta(seconds=self.lockout_duration)
        attempt_info = self.failed_attempts.get(username)
        if attempt_info is None:
            self.failed_attempts[username] = {
                "count": 1,
                "first_attempt": now,
                "locked_until": locked_until
            }
        else:
            attempt_info["count"] += 1
            attempt_info["locked_until"] = locked_until

    def clear_failed_attempts(self, username: str):
        """Forget all recorded failures for ``username`` (no-op if there are none)."""
        self.failed_attempts.pop(username, None)

    def verify_password(self, plain_password: str, hashed_password: str) -> bool:
        """Return True if ``plain_password`` matches the bcrypt ``hashed_password``.

        bcrypt raises ValueError when the stored value is not a valid bcrypt
        hash. The plaintext demo passwords in ``fake_users_db`` are one such
        case. That situation is treated as a mismatch instead of surfacing as
        a 500 error.
        """
        try:
            return bcrypt.checkpw(
                plain_password.encode('utf-8'),
                hashed_password.encode('utf-8')
            )
        except ValueError:
            return False

    def authenticate_user(self, username: str, password: str) -> Optional[dict]:
        """Return the user record for valid credentials, otherwise None.

        Raises:
            HTTPException: 423 when the account is currently locked.
        """
        if self.is_account_locked(username):
            raise HTTPException(
                status_code=status.HTTP_423_LOCKED,
                detail="账户已被锁定,请稍后再试"
            )

        user = fake_users_db.get(username)
        # NOTE(review): unknown users skip the bcrypt check, so they are
        # answered faster than known users. Consider a dummy hash check if
        # username enumeration matters.
        if not user or not self.verify_password(password, user["password"]):
            self.record_failed_attempt(username)
            return None

        # Successful login resets the failure counter.
        self.clear_failed_attempts(username)
        return user

    def __call__(self, credentials: HTTPBasicCredentials = Depends(basic_auth)):
        """FastAPI dependency: return the authenticated user or raise 401/423."""
        user = self.authenticate_user(credentials.username, credentials.password)
        if not user:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="用户名或密码错误",
                headers={"WWW-Authenticate": f'Basic realm="{self.realm}"'}
            )
        return user
3. Bearer Token 认证
3.1 JWT Token 实现
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
import jwt
from datetime import datetime, timedelta
from typing import Optional, Dict, Any
# Bearer-token authentication scheme. BearerTokenAuth.__call__ depends on this
# module-level instance.
bearer_auth = HTTPBearer()
class JWTTokenService:
    """Create and verify signed JWT access/refresh tokens with PyJWT.

    Args:
        secret_key: Key used both to sign and to verify tokens.
        algorithm: Signing algorithm name passed to PyJWT.
        access_token_expire_minutes: Default access-token lifetime.
        refresh_token_expire_days: Refresh-token lifetime.
    """

    def __init__(
        self,
        secret_key: str,
        algorithm: str = "HS256",
        access_token_expire_minutes: int = 30,
        refresh_token_expire_days: int = 7
    ):
        self.secret_key = secret_key
        self.algorithm = algorithm
        self.access_token_expire_minutes = access_token_expire_minutes
        self.refresh_token_expire_days = refresh_token_expire_days

    def _encode(self, data: Dict[str, Any], lifetime: timedelta, token_type: str) -> str:
        """Copy ``data``, add the exp/iat/type claims and sign the result."""
        # A single clock read keeps exp - iat exactly equal to ``lifetime``.
        # Naive UTC datetimes are converted by PyJWT as UTC.
        issued_at = datetime.utcnow()
        to_encode = data.copy()
        to_encode.update({
            "exp": issued_at + lifetime,
            "iat": issued_at,
            "type": token_type
        })
        return jwt.encode(to_encode, self.secret_key, algorithm=self.algorithm)

    def create_access_token(
        self,
        data: Dict[str, Any],
        expires_delta: Optional[timedelta] = None
    ) -> str:
        """Create an access token carrying ``data``.

        Args:
            data: Claims to embed (e.g. ``sub``); the dict itself is not modified.
            expires_delta: Custom lifetime. It is used whenever it is not None,
                so ``timedelta(0)`` is honoured instead of being replaced by
                the default.
        """
        if expires_delta is None:
            expires_delta = timedelta(minutes=self.access_token_expire_minutes)
        return self._encode(data, expires_delta, "access")

    def create_refresh_token(self, data: Dict[str, Any]) -> str:
        """Create a refresh token valid for ``refresh_token_expire_days``."""
        return self._encode(
            data, timedelta(days=self.refresh_token_expire_days), "refresh"
        )

    def verify_token(self, token: str) -> Dict[str, Any]:
        """Decode and verify ``token``, returning its payload.

        Raises:
            HTTPException: 401 for expired, malformed or badly signed tokens.
        """
        try:
            return jwt.decode(
                token, self.secret_key, algorithms=[self.algorithm]
            )
        # ExpiredSignatureError subclasses InvalidTokenError, so it must be
        # caught first. PyJWT has no ``JWTError`` (that name belongs to
        # python-jose); referencing it caused an AttributeError and a 500.
        except jwt.ExpiredSignatureError:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="令牌已过期",
                headers={"WWW-Authenticate": "Bearer"}
            )
        except jwt.InvalidTokenError:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="无效的令牌",
                headers={"WWW-Authenticate": "Bearer"}
            )

    @staticmethod
    def _unauthorized(detail: str) -> HTTPException:
        """Build a 401 carrying the Bearer challenge header (RFC 6750)."""
        return HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=detail,
            headers={"WWW-Authenticate": "Bearer"}
        )

    def get_current_user(self, token: str) -> dict:
        """Resolve the ``fake_users_db`` record named by an access token.

        Raises:
            HTTPException: 401 if the token is invalid, is not an access
                token, lacks ``sub``, or names an unknown user.
        """
        payload = self.verify_token(token)
        if payload.get("type") != "access":
            raise self._unauthorized("无效的令牌类型")

        username = payload.get("sub")
        if username is None:
            raise self._unauthorized("令牌中缺少用户信息")

        user = fake_users_db.get(username)
        if user is None:
            raise self._unauthorized("用户不存在")
        return user
# JWT service instance.
# The signing key comes from the JWT_SECRET_KEY environment variable. The
# placeholder fallback keeps the demo runnable, but anyone who knows it can
# forge tokens, so it must never be used in production.
import os

jwt_service = JWTTokenService(
    secret_key=os.environ.get("JWT_SECRET_KEY", "your-secret-key-here"),
    access_token_expire_minutes=30
)
class BearerTokenAuth:
    """FastAPI dependency that turns a Bearer credential into a user record.

    All validation is delegated to ``JWTTokenService.get_current_user``.
    """

    def __init__(self, jwt_service: JWTTokenService):
        self.jwt_service = jwt_service
        # NOTE(review): this HTTPBearer is not consulted; __call__ depends on
        # the module-level ``bearer_auth`` scheme.
        self.security = HTTPBearer()

    def __call__(
        self,
        credentials: HTTPAuthorizationCredentials = Depends(bearer_auth)
    ):
        """Return the user named by the presented token (raises 401 on failure)."""
        raw_token = credentials.credentials
        service = self.jwt_service
        return service.get_current_user(raw_token)


# Shared Bearer-auth dependency instance.
bearer_token_auth = BearerTokenAuth(jwt_service)
# Login endpoint
@app.post("/token")
async def login(
    credentials: HTTPBasicCredentials = Depends(basic_auth)
):
    """Exchange HTTP Basic credentials for an access/refresh token pair.

    Passwords are compared as UTF-8 bytes. ``secrets.compare_digest`` raises
    TypeError for ``str`` arguments containing non-ASCII characters, which
    previously caused a 500 error.

    NOTE(review): section 4 registers another ``POST /token`` endpoint. If
    both are mounted on one app, only the route registered first will
    handle the path; confirm which one is intended.
    """
    user = fake_users_db.get(credentials.username)
    stored = user["password"] if user else ""
    password_ok = secrets.compare_digest(
        stored.encode("utf-8"), credentials.password.encode("utf-8")
    )
    if not user or not password_ok:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="用户名或密码错误",
            # A Basic challenge lets clients know which scheme to retry with.
            headers={"WWW-Authenticate": "Basic"}
        )

    access_token = jwt_service.create_access_token(
        data={"sub": user["username"], "roles": user["roles"]}
    )
    refresh_token = jwt_service.create_refresh_token(
        data={"sub": user["username"]}
    )
    return {
        "access_token": access_token,
        "refresh_token": refresh_token,
        "token_type": "bearer",
        "expires_in": jwt_service.access_token_expire_minutes * 60
    }
# Route protected by a Bearer token
@app.get("/bearer-protected")
async def bearer_protected_route(current_user: dict = Depends(bearer_token_auth)):
    """Return a greeting with the token holder's name and roles."""
    payload = {"message": "访问成功"}
    payload["user"] = current_user["username"]
    payload["roles"] = current_user["roles"]
    return payload
4. OAuth2 认证
4.1 OAuth2 Password Flow
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from pydantic import BaseModel
from typing import Optional, List
# OAuth2 password-flow scheme with the scopes available in this example.
# OAuth2PasswordBearerWithScopes.get_current_user reads the bearer token
# through this module-level instance.
oauth2_scheme = OAuth2PasswordBearer(
    scopes={"read": "读取权限", "write": "写入权限", "admin": "管理员权限"},
    tokenUrl="token",
)
class Token(BaseModel):
    """Response body returned by the OAuth2 token endpoint."""
    access_token: str
    token_type: str  # create_token always sets "bearer"
    expires_in: int  # access-token lifetime in seconds
    refresh_token: Optional[str] = None
    scope: Optional[str] = None  # space-separated list of granted scopes
class TokenData(BaseModel):
    """Token payload data (username plus granted scopes).

    NOTE(review): not referenced elsewhere in this section.
    """
    username: Optional[str] = None
    # pydantic copies this default per instance, so the list is not shared.
    scopes: List[str] = []
class OAuth2PasswordBearerWithScopes:
    """OAuth2 password flow with role-derived scopes, backed by JWTTokenService.

    Args:
        jwt_service: Service used to sign and verify tokens.
        tokenUrl: Token endpoint URL passed to ``OAuth2PasswordBearer``.
        scopes: Mapping of scope name to description (may be None).
    """

    def __init__(
        self,
        jwt_service: JWTTokenService,
        tokenUrl: str,
        scopes: Optional[Dict[str, str]] = None
    ):
        self.jwt_service = jwt_service
        # NOTE(review): get_current_user depends on the module-level
        # ``oauth2_scheme``, not on this instance attribute.
        self.oauth2_scheme = OAuth2PasswordBearer(
            tokenUrl=tokenUrl,
            scopes=scopes or {}
        )

    def authenticate_user(self, username: str, password: str) -> Optional[dict]:
        """Return the user record for valid credentials, otherwise None.

        Passwords are compared as UTF-8 bytes. ``secrets.compare_digest``
        raises TypeError for non-ASCII ``str`` input, which caused a 500 error.
        """
        user = fake_users_db.get(username)
        stored = user["password"] if user else ""
        password_ok = secrets.compare_digest(
            stored.encode("utf-8"), password.encode("utf-8")
        )
        if not user or not password_ok:
            return None
        return user

    def get_user_scopes(self, user: dict) -> List[str]:
        """Map a user's roles to scopes: everyone gets "read", editors add
        "write", and admins add "write" and "admin"."""
        scopes = ["read"]
        if "admin" in user.get("roles", []):
            scopes.extend(["write", "admin"])
        elif "editor" in user.get("roles", []):
            scopes.append("write")
        return scopes

    async def create_token(self, form_data: OAuth2PasswordRequestForm) -> Token:
        """Authenticate the form credentials and issue scoped tokens.

        If the client requests scopes, all of them must be within the user's
        scopes. Otherwise every user scope is granted.

        Raises:
            HTTPException: 401 for bad credentials; 400 for scopes the user
                does not hold.
        """
        user = self.authenticate_user(form_data.username, form_data.password)
        if not user:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="用户名或密码错误",
                headers={"WWW-Authenticate": "Bearer"}
            )

        user_scopes = self.get_user_scopes(user)
        requested_scopes = form_data.scopes
        if requested_scopes:
            invalid_scopes = set(requested_scopes) - set(user_scopes)
            if invalid_scopes:
                # Sorted so the error message is deterministic.
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail=f"无效的作用域: {', '.join(sorted(invalid_scopes))}"
                )
            granted_scopes = requested_scopes
        else:
            granted_scopes = user_scopes

        token_data = {
            "sub": user["username"],
            "scopes": granted_scopes,
            "roles": user["roles"]
        }
        access_token = self.jwt_service.create_access_token(data=token_data)
        refresh_token = self.jwt_service.create_refresh_token(
            data={"sub": user["username"]}
        )
        return Token(
            access_token=access_token,
            token_type="bearer",
            expires_in=self.jwt_service.access_token_expire_minutes * 60,
            refresh_token=refresh_token,
            scope=" ".join(granted_scopes)
        )

    async def get_current_user(
        self,
        security_scopes: SecurityScopes,
        token: str = Depends(oauth2_scheme)
    ):
        """Return the token's user after checking the scopes the route requires.

        Raises:
            HTTPException: 401 for invalid or non-access tokens, missing
                ``sub``, or unknown users. 403 when required scopes are
                missing from the token.
        """
        payload = self.jwt_service.verify_token(token)

        # Refresh tokens are signed with the same key. Without this check they
        # would be accepted on routes that require no scopes.
        if payload.get("type") != "access":
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="无效的令牌类型",
                headers={"WWW-Authenticate": "Bearer"}
            )

        username = payload.get("sub")
        if username is None:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="令牌中缺少用户信息",
                headers={"WWW-Authenticate": "Bearer"}
            )

        token_scopes = payload.get("scopes", [])
        if security_scopes.scopes:
            missing_scopes = set(security_scopes.scopes) - set(token_scopes)
            if missing_scopes:
                authenticate_value = f'Bearer scope="{security_scopes.scope_str}"'
                raise HTTPException(
                    status_code=status.HTTP_403_FORBIDDEN,
                    detail="权限不足",
                    headers={"WWW-Authenticate": authenticate_value}
                )

        user = fake_users_db.get(username)
        if user is None:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="用户不存在",
                headers={"WWW-Authenticate": "Bearer"}
            )
        return user
# OAuth2 handler instance; the scope table mirrors ``oauth2_scheme``.
oauth2_handler = OAuth2PasswordBearerWithScopes(
    jwt_service=jwt_service,
    tokenUrl="token",
    scopes={"read": "读取权限", "write": "写入权限", "admin": "管理员权限"},
)


# OAuth2 token endpoint
@app.post("/token", response_model=Token)
async def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends()):
    """OAuth2 password-flow token endpoint; delegates to ``oauth2_handler``.

    NOTE(review): section 3 also registers ``POST /token``. If both are on the
    same app, only the route registered first will be reached.
    """
    token = await oauth2_handler.create_token(form_data)
    return token
# Routes that each require a specific OAuth2 scope


@app.get("/oauth2/read")
async def read_data(
    current_user: dict = Security(oauth2_handler.get_current_user, scopes=["read"])
):
    """Route requiring the "read" scope."""
    username = current_user["username"]
    return {"message": "读取数据成功", "user": username}


@app.post("/oauth2/write")
async def write_data(
    data: dict,
    current_user: dict = Security(oauth2_handler.get_current_user, scopes=["write"])
):
    """Route requiring the "write" scope; echoes the submitted data."""
    username = current_user["username"]
    return {"message": "写入数据成功", "user": username, "data": data}


@app.get("/oauth2/admin")
async def admin_only(
    current_user: dict = Security(oauth2_handler.get_current_user, scopes=["admin"])
):
    """Route requiring the "admin" scope."""
    username = current_user["username"]
    return {"message": "管理员操作成功", "user": username}
5. API Key 认证
5.1 API Key 实现
from fastapi.security import APIKeyHeader, APIKeyQuery, APIKeyCookie
from typing import Optional, Dict, List, Set
from dataclasses import dataclass
from datetime import datetime, timedelta
import secrets
import hashlib
@dataclass
class APIKey:
    """Stored metadata for one issued API key (only the key's hash is kept)."""

    key_id: str  # random identifier, used as the manager's dict key
    key_hash: str  # SHA-256 hex digest of the raw key
    name: str
    permissions: Set[str]
    rate_limit: int  # maximum requests per minute
    allowed_ips: Optional[List[str]] = None  # None/empty: no IP restriction
    expires_at: Optional[datetime] = None  # None: never expires
    # default_factory is required here. A plain ``datetime.utcnow()`` default
    # is evaluated once, when the class is defined, so every key would share
    # the same creation time.
    created_at: datetime = field(default_factory=datetime.utcnow)
    last_used: Optional[datetime] = None
    is_active: bool = True  # set to False when the key is revoked
class APIKeyManager:
    """In-memory registry of API keys with expiry, IP allow-lists and per-minute rate limits."""

    def __init__(self):
        # key_id -> APIKey
        self.api_keys: Dict[str, APIKey] = {}
        # key_id -> {"YYYY-MM-DD HH:MM": request count in that minute}
        self.usage_stats: Dict[str, Dict[str, int]] = {}

    def generate_api_key(self) -> str:
        """Return a new random URL-safe key (32 random bytes)."""
        return secrets.token_urlsafe(32)

    def hash_api_key(self, api_key: str) -> str:
        """Return the SHA-256 hex digest under which a key is stored."""
        return hashlib.sha256(api_key.encode()).hexdigest()

    def create_api_key(
        self,
        name: str,
        permissions: Set[str],
        rate_limit: int = 1000,
        allowed_ips: Optional[List[str]] = None,
        expires_in_days: Optional[int] = None
    ) -> tuple[str, APIKey]:
        """Create and register a new key.

        Returns:
            ``(raw_key, APIKey)``. Only the hash is stored, so the raw key
            must be handed to the client now; it cannot be recovered later.
            A falsy ``expires_in_days`` (None or 0) creates a non-expiring key.
        """
        api_key = self.generate_api_key()
        key_hash = self.hash_api_key(api_key)
        key_id = secrets.token_hex(8)

        expires_at = None
        if expires_in_days:
            expires_at = datetime.utcnow() + timedelta(days=expires_in_days)

        api_key_obj = APIKey(
            key_id=key_id,
            key_hash=key_hash,
            name=name,
            permissions=permissions,
            rate_limit=rate_limit,
            allowed_ips=allowed_ips,
            expires_at=expires_at
        )
        self.api_keys[key_id] = api_key_obj
        return api_key, api_key_obj

    def validate_api_key(
        self,
        api_key: str,
        client_ip: Optional[str] = None
    ) -> Optional[APIKey]:
        """Return the active, unexpired key matching ``api_key``, or None.

        A key with an IP allow-list is accepted only when ``client_ip`` is in
        that list. An unknown caller IP (None) is rejected rather than waved
        through. On success the key's usage is recorded.

        Raises:
            HTTPException: 429 when the key has hit its per-minute limit.
        """
        key_hash = self.hash_api_key(api_key)
        now = datetime.utcnow()

        for key_obj in self.api_keys.values():
            if not key_obj.is_active:
                continue
            if not secrets.compare_digest(key_obj.key_hash, key_hash):
                continue
            if key_obj.expires_at and now > key_obj.expires_at:
                continue
            # Fail closed: previously a missing client_ip skipped this check
            # entirely, which disabled the allow-list.
            if key_obj.allowed_ips and client_ip not in key_obj.allowed_ips:
                continue
            if not self._check_rate_limit(key_obj):
                raise HTTPException(
                    status_code=status.HTTP_429_TOO_MANY_REQUESTS,
                    detail="API调用频率超限"
                )
            key_obj.last_used = now
            self._record_usage(key_obj.key_id)
            return key_obj

        return None

    def _check_rate_limit(self, api_key: APIKey) -> bool:
        """Return True if ``api_key`` is still below its limit for the current minute."""
        current_minute = datetime.utcnow().strftime("%Y-%m-%d %H:%M")
        if api_key.key_id not in self.usage_stats:
            self.usage_stats[api_key.key_id] = {}
        current_count = self.usage_stats[api_key.key_id].get(current_minute, 0)
        return current_count < api_key.rate_limit

    def _record_usage(self, key_id: str):
        """Count one request for ``key_id`` and drop buckets older than one hour."""
        current_minute = datetime.utcnow().strftime("%Y-%m-%d %H:%M")
        if key_id not in self.usage_stats:
            self.usage_stats[key_id] = {}
        self.usage_stats[key_id][current_minute] = (
            self.usage_stats[key_id].get(current_minute, 0) + 1
        )

        # Zero-padded "%Y-%m-%d %H:%M" strings sort chronologically, so plain
        # string comparison finds the stale buckets.
        cutoff_str = (datetime.utcnow() - timedelta(hours=1)).strftime("%Y-%m-%d %H:%M")
        stale = [m for m in self.usage_stats[key_id] if m < cutoff_str]
        for minute in stale:
            del self.usage_stats[key_id][minute]

    def revoke_api_key(self, key_id: str) -> bool:
        """Deactivate ``key_id``; return False if no such key exists."""
        if key_id in self.api_keys:
            self.api_keys[key_id].is_active = False
            return True
        return False

    def get_api_key_info(self, key_id: str) -> Optional[APIKey]:
        """Return the stored APIKey for ``key_id`` (or None)."""
        return self.api_keys.get(key_id)
class APIKeyAuth:
    """FastAPI dependency that extracts and validates an API key.

    Args:
        api_key_manager: Registry used to validate keys.
        name: Header checked first for the key.
        scheme_name: Scheme identifier (defaults to "APIKey").
        description: Optional scheme description.
        auto_error: Raise 401 on a missing or invalid key instead of returning None.
        trust_proxy_headers: Use X-Forwarded-For / X-Real-IP as the client IP.
            Enable this only behind a reverse proxy that overwrites these
            headers. Otherwise clients can spoof them and bypass per-key IP
            allow-lists. Defaults to False.
    """

    def __init__(
        self,
        api_key_manager: APIKeyManager,
        name: str = "X-API-Key",
        scheme_name: Optional[str] = None,
        description: Optional[str] = None,
        auto_error: bool = True,
        trust_proxy_headers: bool = False
    ):
        self.api_key_manager = api_key_manager
        self.name = name
        self.scheme_name = scheme_name or "APIKey"
        self.description = description
        self.auto_error = auto_error
        self.trust_proxy_headers = trust_proxy_headers

    def __call__(self, request: Request) -> Optional[APIKey]:
        """Return the validated APIKey, or None when ``auto_error`` is False.

        Raises:
            HTTPException: 401 for a missing or invalid key (when
                ``auto_error``); 429 propagated from the rate limiter.
        """
        api_key = self._extract_api_key(request)
        if not api_key:
            if self.auto_error:
                raise HTTPException(
                    status_code=status.HTTP_401_UNAUTHORIZED,
                    detail="未提供API密钥",
                    headers={"WWW-Authenticate": "ApiKey"}
                )
            return None

        client_ip = self._get_client_ip(request)
        key_obj = self.api_key_manager.validate_api_key(
            api_key, client_ip=client_ip
        )
        if not key_obj:
            if self.auto_error:
                raise HTTPException(
                    status_code=status.HTTP_401_UNAUTHORIZED,
                    detail="无效的API密钥",
                    headers={"WWW-Authenticate": "ApiKey"}
                )
            return None
        return key_obj

    def _extract_api_key(self, request: Request) -> Optional[str]:
        """Look for the key in the header, then ?api_key=, then the api_key cookie.

        NOTE(review): keys sent as query parameters tend to end up in access
        logs; consider restricting to the header.
        """
        api_key = request.headers.get(self.name)
        if api_key:
            return api_key
        api_key = request.query_params.get("api_key")
        if api_key:
            return api_key
        return request.cookies.get("api_key")

    def _get_client_ip(self, request: Request) -> Optional[str]:
        """Return the caller's IP address, or None if it is unknown.

        Proxy headers are client-controlled unless a trusted proxy rewrites
        them, so they are consulted only when ``trust_proxy_headers`` is set.
        """
        if self.trust_proxy_headers:
            forwarded_for = request.headers.get("X-Forwarded-For")
            if forwarded_for:
                # NOTE(review): the left-most entry is whatever the client sent
                # if the proxy appends rather than overwrites; confirm setup.
                return forwarded_for.split(",")[0].strip()
            real_ip = request.headers.get("X-Real-IP")
            if real_ip:
                return real_ip

        if request.client:
            return request.client.host
        return None
# API-key authentication example: one shared manager and dependency.
api_key_manager = APIKeyManager()
api_key_auth = APIKeyAuth(api_key_manager)


@app.get("/api/protected")
async def protected_endpoint(api_key: APIKey = Depends(api_key_auth)):
    """Endpoint that requires a valid API key; echoes the key id and permissions."""
    key_id = api_key.key_id
    granted = list(api_key.permissions)
    return {"message": "访问成功", "api_key_id": key_id, "permissions": granted}
6. 基于角色的访问控制 (RBAC)
6.1 权限和角色定义
from enum import Enum
from typing import Set, List, Dict, Optional
from dataclasses import dataclass, field
from datetime import datetime
# Permission enumeration
class Permission(Enum):
    """All permissions known to the system, named "<resource>:<action>"."""
    # User management
    USER_READ = "user:read"
    USER_WRITE = "user:write"
    USER_DELETE = "user:delete"
    # Content (post) management
    POST_READ = "post:read"
    POST_WRITE = "post:write"
    POST_DELETE = "post:delete"
    POST_PUBLISH = "post:publish"
    # System administration
    ADMIN_READ = "admin:read"
    ADMIN_WRITE = "admin:write"
    # API access
    API_READ = "api:read"
    API_WRITE = "api:write"
    API_DELETE = "api:delete"
@dataclass
class Role:
    """A named bundle of permissions."""

    name: str
    description: str
    permissions: Set[Permission] = field(default_factory=set)
    # Marks built-in roles (PermissionManager creates its defaults with True).
    is_system_role: bool = False
    created_at: datetime = field(default_factory=datetime.utcnow)

    def has_permission(self, permission: Permission) -> bool:
        """Tell whether ``permission`` belongs to this role."""
        granted = self.permissions
        return permission in granted

    def add_permission(self, permission: Permission):
        """Grant ``permission`` to this role (no effect if already present)."""
        self.permissions.update((permission,))

    def remove_permission(self, permission: Permission):
        """Revoke ``permission``; revoking an absent permission is not an error."""
        self.permissions.difference_update((permission,))
@dataclass
class UserPermissions:
    """Per-user permission record: assigned roles plus explicit grants and denials.

    Denials take precedence over grants (see PermissionManager.get_user_permissions).
    Granting and denying the same permission are mutually exclusive here.
    """

    user_id: str
    roles: Set[str] = field(default_factory=set)  # names of assigned roles
    direct_permissions: Set[Permission] = field(default_factory=set)
    denied_permissions: Set[Permission] = field(default_factory=set)
    created_at: datetime = field(default_factory=datetime.utcnow)
    updated_at: datetime = field(default_factory=datetime.utcnow)

    def _touch(self) -> None:
        """Stamp ``updated_at`` with the current UTC time."""
        self.updated_at = datetime.utcnow()

    def add_role(self, role_name: str):
        """Assign the role called ``role_name``."""
        self.roles.add(role_name)
        self._touch()

    def remove_role(self, role_name: str):
        """Unassign ``role_name`` (no error if it was not assigned)."""
        self.roles.discard(role_name)
        self._touch()

    def grant_permission(self, permission: Permission):
        """Grant ``permission`` directly, lifting any explicit denial of it."""
        self.direct_permissions.add(permission)
        self.denied_permissions.discard(permission)
        self._touch()

    def deny_permission(self, permission: Permission):
        """Deny ``permission`` explicitly, dropping any direct grant of it."""
        self.denied_permissions.add(permission)
        self.direct_permissions.discard(permission)
        self._touch()
6.2 权限管理服务
class PermissionManager:
    """In-memory role registry and per-user permission resolution."""

    def __init__(self):
        # role name -> Role
        self.roles: Dict[str, Role] = {}
        # user id -> UserPermissions
        self.user_permissions: Dict[str, UserPermissions] = {}
        self._init_default_roles()

    def _init_default_roles(self):
        """Register the built-in admin/editor/author/user roles."""
        # Super administrator: every permission
        admin_role = Role(
            name="admin",
            description="超级管理员",
            permissions=set(Permission),
            is_system_role=True
        )
        # Editor
        editor_role = Role(
            name="editor",
            description="编辑者",
            permissions={
                Permission.USER_READ,
                Permission.POST_READ,
                Permission.POST_WRITE,
                Permission.POST_PUBLISH,
                Permission.API_READ,
                Permission.API_WRITE
            },
            is_system_role=True
        )
        # Author
        author_role = Role(
            name="author",
            description="作者",
            permissions={
                Permission.USER_READ,
                Permission.POST_READ,
                Permission.POST_WRITE,
                Permission.API_READ
            },
            is_system_role=True
        )
        # Regular user
        user_role = Role(
            name="user",
            description="普通用户",
            permissions={
                Permission.USER_READ,
                Permission.POST_READ,
                Permission.API_READ
            },
            is_system_role=True
        )
        self.roles.update({
            "admin": admin_role,
            "editor": editor_role,
            "author": author_role,
            "user": user_role
        })

    def assign_role(self, user_id: str, role_name: str) -> UserPermissions:
        """Give ``user_id`` the role ``role_name``, creating the user's record if needed.

        Before this method existed, ``user_permissions`` had no way to be
        populated through the manager.

        Raises:
            KeyError: if ``role_name`` is not a registered role.
        """
        if role_name not in self.roles:
            raise KeyError(f"unknown role: {role_name}")
        user_perms = self.user_permissions.get(user_id)
        if user_perms is None:
            user_perms = UserPermissions(user_id=user_id)
            self.user_permissions[user_id] = user_perms
        user_perms.add_role(role_name)
        return user_perms

    def get_user_permissions(self, user_id: str) -> Set[Permission]:
        """Return the user's effective permissions.

        This is the union of role permissions and direct grants, minus
        explicit denials. Unknown users get an empty set, and unknown role
        names are ignored.
        """
        user_perms = self.user_permissions.get(user_id)
        if not user_perms:
            return set()

        role_permissions: Set[Permission] = set()
        for role_name in user_perms.roles:
            role = self.roles.get(role_name)
            if role:
                role_permissions.update(role.permissions)

        # Denials win over both role and direct grants.
        return (role_permissions | user_perms.direct_permissions) - user_perms.denied_permissions

    def check_user_permission(self, user_id: str, permission: Permission) -> bool:
        """Return True if ``permission`` is among the user's effective permissions."""
        return permission in self.get_user_permissions(user_id)
# Permission-checking dependency factory
def require_permissions(required_permissions: List[Permission]):
    """Build a FastAPI dependency that enforces ``required_permissions``.

    This is a dependency factory, not a decorator: use it as
    ``Depends(require_permissions([...]))``. The returned dependency yields
    the authenticated user when every permission is held.

    Effective permissions come from ``permission_manager`` when it has a record
    for the user. Otherwise they are derived from the ``roles`` on the
    authenticated user record, which ``bearer_token_auth`` loads server-side.
    Previously users without an explicit record, including all demo users,
    were rejected for every permission.

    Raises (from the dependency):
        HTTPException: 403 naming the first missing permission, in the order given.
    """
    def permission_dependency(
        current_user: dict = Depends(bearer_token_auth),
        # The lambda looks up the module-level ``permission_manager`` at request
        # time, so the factory can be used before that global is assigned.
        permission_manager: PermissionManager = Depends(lambda: permission_manager)
    ):
        user_id = current_user["username"]
        if user_id in permission_manager.user_permissions:
            granted = permission_manager.get_user_permissions(user_id)
        else:
            granted = set()
            for role_name in current_user.get("roles", []):
                role = permission_manager.roles.get(role_name)
                if role is not None:
                    granted |= role.permissions

        # Resolve once, then test each requirement in the caller's order.
        for permission in required_permissions:
            if permission not in granted:
                raise HTTPException(
                    status_code=status.HTTP_403_FORBIDDEN,
                    detail=f"权限不足,需要权限: {permission.value}"
                )
        return current_user

    return permission_dependency
# Global permission manager. require_permissions resolves this name lazily,
# when a request is handled.
permission_manager = PermissionManager()


# Example endpoint guarded by a permission check
@app.post("/posts")
async def create_post(
    post_data: dict,
    current_user: dict = Depends(require_permissions([Permission.POST_WRITE]))
):
    """Create a post; requires the POST_WRITE permission."""
    author = current_user["username"]
    return {"message": "文章创建成功", "author": author, "data": post_data}