2025-06-09 17:53:19 +08:00

228 lines
12 KiB
Python

# myapp/schemas.py
from marshmallow import fields, validate, ValidationError, Schema, validates_schema
from marshmallow.validate import OneOf
from bson.objectid import ObjectId, InvalidId
# Import Marshmallow instance from extensions
# Assumes 'ma = Marshmallow()' is defined in myapp/extensions.py
# and initialized in myapp/__init__.py's create_app()
try:
    # Preferred path: reuse the shared instance exported by the extensions module.
    from .extensions import ma
except ImportError:
    # The extensions module (or its 'ma' attribute) could not be imported, so a
    # stand-alone instance is created to let the schema classes below be defined.
    # NOTE(review): this fallback instance is not bound to a Flask app in this file.
    print("WARNING: Flask-Marshmallow instance 'ma' not found in extensions. Falling back.")
    from flask_marshmallow import Marshmallow

    ma = Marshmallow()
# --- Custom Validators (Optional but useful) ---
def _validate_object_id(value):
    """Validate that ``value`` can be parsed as a BSON ObjectId.

    Args:
        value: Candidate identifier (e.g. a 24-character hex string).

    Raises:
        ValidationError: If ``value`` is ``None`` or ``ObjectId`` rejects it.
    """
    # ObjectId(None) does not raise -- it generates a brand-new id -- so None
    # has to be rejected explicitly or it would slip through as "valid".
    if value is None:
        raise ValidationError("Invalid ObjectId format.")
    try:
        ObjectId(value)
    except (InvalidId, TypeError, ValueError) as exc:
        # Keep the original cause attached for easier debugging.
        raise ValidationError("Invalid ObjectId format.") from exc
def _is_alphabetic_or_empty(value):
"""Validator for keywords: allows empty string or purely alphabetic."""
if value is not None and value != "" and not value.isalpha():
raise ValidationError("Keyword must be alphabetic if not empty.")
return True # Pass validation if empty or alphabetic
# --- Base Schema for common fields ---
class BaseSchema(ma.Schema):
    """Output fields shared by the serializing schemas: ``id`` and timestamps."""

    # String form of the dumped mapping's "_id" entry (output only).
    # NOTE(review): a missing "_id" is rendered as the literal string "None".
    id = fields.Function(lambda doc: str(doc.get("_id")), dump_only=True)

    # Creation / modification timestamps in ISO format (output only).
    createdAt = fields.DateTime(dump_only=True, format="iso")
    updatedAt = fields.DateTime(dump_only=True, format="iso")
# --- User Schemas (for auth blueprint) ---
class UserRegistrationSchema(ma.Schema):
    """Validate the payload submitted when registering a new user.

    All three fields are required; ``password`` is accepted on load only and
    is never included in dumped output.
    """

    username = fields.String(
        required=True,
        validate=validate.Length(
            min=3,
            max=64,
            error="Username must be between 3 and 64 characters.",
        ),
    )
    # Fields do not take an ``error`` argument (it is only accepted by
    # validators), so the custom text is supplied through the field's
    # "invalid" error-message key instead.
    email = fields.Email(
        required=True,
        error_messages={"invalid": "Invalid email format."},
    )
    password = fields.String(
        required=True,
        load_only=True,
        validate=validate.Length(min=8, error="Password must be at least 8 characters."),
    )
class UserLoginSchema(ma.Schema):
    """Validate login input: both credentials must be present."""

    username = fields.String(required=True)
    # Accepted on load only; never part of dumped output.
    password = fields.String(load_only=True, required=True)
class UserSchema(BaseSchema):
    """Serialize a user document for output.

    Only the fields declared here and on ``BaseSchema`` (``id``, ``createdAt``,
    ``updatedAt``) are dumped, so a stored password never appears in the result.
    """

    username = fields.String(dump_only=True)
    email = fields.Email(dump_only=True)
    # Deliberately no ``Meta.exclude = ("password",)``: marshmallow rejects
    # exclude entries that do not name a declared field (ValueError when the
    # schema is instantiated), and an undeclared field is not serialized anyway.
class UserUpdateSchema(ma.Schema):
    """Validate an account-update payload; every field may be omitted."""

    username = fields.String(
        validate=validate.Length(min=3, max=64),
    )
    email = fields.Email()
    # Accepted on load only; never part of dumped output.
    password = fields.String(
        load_only=True,
        validate=validate.Length(min=8),
    )
# --- API Key Schemas (for api_keys blueprint) ---
# Provider names accepted by the API-key schemas.
ALLOWED_API_PROVIDERS = ["Gemini", "Deepseek", "Chatgpt"]


class APIKeyCreateSchema(ma.Schema):
    """Validate the payload used to store a new API key."""

    # Must be one of ALLOWED_API_PROVIDERS.
    name = fields.String(
        required=True,
        validate=validate.OneOf(
            ALLOWED_API_PROVIDERS,
            error=f"Provider name must be one of: {ALLOWED_API_PROVIDERS}",
        ),
    )
    # Only a minimum length is enforced on the key itself.
    key = fields.String(
        required=True,
        validate=validate.Length(min=5, error="API Key seems too short."),
    )
    # Loads as False when the client leaves it out.
    selected = fields.Boolean(load_default=False)
class APIKeyUpdateSchema(ma.Schema):
    """Validate an API-key update payload; no field is required."""

    name = fields.String(
        validate=validate.OneOf(
            ALLOWED_API_PROVIDERS,
            error=f"Provider name must be one of: {ALLOWED_API_PROVIDERS}",
        ),
    )
    key = fields.String(
        validate=validate.Length(min=5),
    )
    selected = fields.Boolean()
class APIKeySchema(BaseSchema):
    """Serialize a stored API key for output.

    Extends ``BaseSchema`` (``id``, ``createdAt``, ``updatedAt``) with the
    owning user id, provider name, key value and selection flag.
    """

    # Owning user's id, stringified.
    uid = fields.Function(lambda doc: str(doc.get("uid")), dump_only=True)
    name = fields.String(dump_only=True)
    # NOTE(review): the key is dumped in full. Consider masking it, e.g. with a
    # Function field that keeps only the first and last four characters.
    key = fields.String(dump_only=True)
    selected = fields.Boolean(dump_only=True)
# --- Project Schemas (for projects blueprint) ---
class KeywordSchema(ma.Schema):
    """One keyword entry: a word and its percentage weight."""

    # Checked by _is_alphabetic_or_empty (empty or purely alphabetic).
    word = fields.String(
        required=True,
        validate=_is_alphabetic_or_empty,
    )
    # Must lie within 0..100.
    percentage = fields.Float(
        required=True,
        validate=validate.Range(min=0, max=100),
    )
class ProjectCreateSchema(ma.Schema):
    """Validate the payload for creating a project; only ``name`` is required."""

    name = fields.String(
        required=True,
        validate=validate.Length(
            min=1,
            max=100,
            error="Project name must be between 1 and 100 characters.",
        ),
    )
    # Optional free-text fields with upper length bounds.
    topic = fields.String(
        validate=validate.Length(max=200),
    )
    description = fields.String(
        validate=validate.Length(max=1000),
    )
class ProjectUpdateSchema(ma.Schema):
    """Validate a project-update payload; every listed field may be omitted."""

    name = fields.String(
        validate=validate.Length(min=1, max=100),
    )
    topic = fields.String(
        validate=validate.Length(max=200),
    )
    description = fields.String(
        validate=validate.Length(max=1000),
    )
    # Each entry is a string checked by _validate_object_id.
    collaborators = fields.List(
        fields.String(validate=_validate_object_id),
    )
    # Each entry is validated against KeywordSchema.
    keywords = fields.List(
        fields.Nested(KeywordSchema),
    )
class ProjectSchema(BaseSchema):
    """Serialize a full project document for output.

    Adds owner, collaborators, descriptive text, keywords and last-activity
    information to the ``id``/``createdAt``/``updatedAt`` fields inherited
    from ``BaseSchema``.
    """

    ownerId = fields.Function(lambda obj: str(obj.get("ownerId")), dump_only=True)
    # Each collaborator id is stringified individually. A Function field nested
    # inside a List is handed the parent object rather than the list element,
    # so String (which serializes through str()) is used as the inner field.
    collaborators = fields.List(fields.String(), dump_only=True)
    # NOTE(review): the passkey is part of every dump; confirm it is really needed.
    passkey = fields.String(dump_only=True)
    name = fields.String(dump_only=True)
    topic = fields.String(dump_only=True)
    description = fields.String(dump_only=True)
    summary = fields.String(dump_only=True)
    keywords = fields.List(fields.Nested(KeywordSchema), dump_only=True)
    # Stringified only when the stored value is an ObjectId; otherwise None.
    lastActivityBy = fields.Function(
        lambda obj: (
            str(obj.get("lastActivityBy"))
            if isinstance(obj.get("lastActivityBy"), ObjectId)
            else None
        ),
        dump_only=True,
    )
class ProjectListSchema(ma.Schema):
    """Serialize the compact per-project entry used in project listings."""

    # String form of the mapping's "_id" entry.
    id = fields.Function(lambda doc: str(doc.get("_id")), dump_only=True)
    name = fields.String(dump_only=True)
    updatedAt = fields.DateTime(dump_only=True, format="iso")
# --- URL Schemas (for urls blueprint) ---
class URLCreateSchema(ma.Schema):
    """Validate the payload for adding a URL: a required http(s) address."""

    # Fields do not take an ``error`` argument, so the custom text is supplied
    # through the field's "invalid" error-message key instead.
    url = fields.URL(
        required=True,
        schemes={"http", "https"},
        error_messages={"invalid": "Invalid URL format."},
    )
class URLUpdateSchema(ma.Schema):
    """Validate a URL-update payload; only these fields are accepted, all optional."""

    title = fields.String(
        validate=validate.Length(max=500),
    )
    starred = fields.Boolean()
    note = fields.String()
    # Each entry is validated against KeywordSchema.
    keywords = fields.List(
        fields.Nested(KeywordSchema),
    )
class URLSchema(BaseSchema):
    """Serialize a full URL document for output.

    Extends ``BaseSchema`` (``id``, ``createdAt``, ``updatedAt``) with the
    owning project id, the address and its metadata.
    """

    # Owning project's id, stringified.
    projectId = fields.Function(lambda doc: str(doc.get("projectId")), dump_only=True)
    url = fields.URL(dump_only=True)
    title = fields.String(dump_only=True)
    # May be None.
    favicon = fields.String(allow_none=True, dump_only=True)
    starred = fields.Boolean(dump_only=True)
    note = fields.String(dump_only=True)
    keywords = fields.List(fields.Nested(KeywordSchema), dump_only=True)
    summary = fields.String(dump_only=True)
    # NOTE(review): this field is dump_only, so the OneOf validator presumably
    # never runs; it mainly documents the expected status values -- confirm.
    processingStatus = fields.String(
        dump_only=True,
        validate=validate.OneOf(["pending", "processing", "completed", "failed"]),
    )
class URLListSchema(ma.Schema):
    """Serialize the compact per-URL entry used in URL listings."""

    # String form of the mapping's "_id" entry.
    id = fields.Function(lambda doc: str(doc.get("_id")), dump_only=True)
    title = fields.String(dump_only=True)
    url = fields.URL(dump_only=True)
class URLSearchResultSchema(URLListSchema):
    """Serialize a search hit; currently identical to ``URLListSchema``.

    No fields are added here -- ``id``, ``title`` and ``url`` are inherited.
    """
# --- Activity Schemas (for activity blueprint) ---
class ActivityCreateSchema(ma.Schema):
    """Validate the payload for recording a new activity-log entry."""

    # String checked by _validate_object_id.
    projectId = fields.String(
        required=True,
        validate=_validate_object_id,
    )
    # Must be a non-empty string.
    activityType = fields.String(
        required=True,
        validate=validate.Length(min=1),
    )
    # Loads as "" when the client leaves it out.
    message = fields.String(load_default="")
class ActivitySchema(BaseSchema):
    """Serialize an activity-log entry for output.

    ``id``, ``createdAt`` and ``updatedAt`` come from ``BaseSchema``.
    ``updatedAt`` is inherited even though log entries are not typically
    modified after creation.
    """

    # Related ids, stringified.
    projectId = fields.Function(lambda doc: str(doc.get("projectId")), dump_only=True)
    userId = fields.Function(lambda doc: str(doc.get("userId")), dump_only=True)
    activityType = fields.String(dump_only=True)
    message = fields.String(dump_only=True)
# --- Dialog Schemas (for dialog blueprint) ---
class MessageSchema(ma.Schema):
    """A single message inside a dialog session."""

    # Restricted to the two listed role names.
    role = fields.String(
        required=True,
        validate=validate.OneOf(
            ["user", "system"],
            error="Role must be 'user' or 'system'.",
        ),
    )
    content = fields.String(required=True)
    # Output only, rendered in ISO format.
    timestamp = fields.DateTime(dump_only=True, format="iso")
class DialogCreateSchema(ma.Schema):
    """Validate the payload for opening a new dialog session."""

    # String checked by _validate_object_id.
    projectId = fields.String(
        required=True,
        validate=_validate_object_id,
    )
    # Both of the following may be omitted.
    sessionId = fields.String()
    startMessage = fields.String()
class DialogSendMessageSchema(ma.Schema):
    """Validate a user message posted to a dialog: non-empty ``content``."""

    content = fields.String(
        required=True,
        validate=validate.Length(
            min=1,
            error="Message content cannot be empty.",
        ),
    )
class DialogSchema(BaseSchema):
    """Serialize a complete dialog session, including its messages.

    Extends ``BaseSchema`` (``id``, ``createdAt``, ``updatedAt``).
    """

    # Related ids, stringified.
    uid = fields.Function(lambda doc: str(doc.get("uid")), dump_only=True)
    projectId = fields.Function(lambda doc: str(doc.get("projectId")), dump_only=True)
    provider = fields.String(dump_only=True)
    sessionId = fields.String(dump_only=True)
    # Session boundaries in ISO format; the end time may be None.
    sessionStartedAt = fields.DateTime(dump_only=True, format="iso")
    sessionEndedAt = fields.DateTime(allow_none=True, dump_only=True, format="iso")
    # Each entry is serialized with MessageSchema.
    messages = fields.List(fields.Nested(MessageSchema), dump_only=True)
class DialogSummarySchema(BaseSchema):
    """Serialize a dialog session for list views, without its messages.

    Extends ``BaseSchema`` (``id``, ``createdAt``, ``updatedAt``). The
    ``messages`` array is left out simply by not declaring a field for it.
    """

    uid = fields.Function(lambda obj: str(obj.get("uid")), dump_only=True)
    projectId = fields.Function(lambda obj: str(obj.get("projectId")), dump_only=True)
    provider = fields.String(dump_only=True)
    sessionId = fields.String(dump_only=True)
    sessionStartedAt = fields.DateTime(format="iso", dump_only=True)
    # May be None.
    sessionEndedAt = fields.DateTime(format="iso", dump_only=True, allow_none=True)
    # Deliberately no ``Meta.exclude = ("messages",)``: marshmallow rejects
    # exclude entries that do not name a declared field (ValueError when the
    # schema is instantiated), and an undeclared field is not serialized anyway.