bge_finetune/scripts/validate_m3.py
2025-07-22 16:55:25 +08:00

123 lines
4.1 KiB
Python

"""
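Quick validation script for the fine-tuned BGE-M3 model.

Loads a model and its tokenizer, embeds a few sample queries, prints basic
statistics for each dense embedding, and reports pairwise cosine similarities.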
"""
import os
import sys
import torch
import numpy as np
from transformers import AutoTokenizer
# Add project root to Python path
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from models.bge_m3 import BGEM3Model
def test_model_embeddings(model_path, test_queries=None):
"""Test if the model can generate embeddings"""
if test_queries is None:
test_queries = [
"什么是深度学习?", # Original training query
"什么是机器学习?", # Related query
"今天天气怎么样?" # Unrelated query
]
print(f"🧪 Testing model from: {model_path}")
try:
# Load tokenizer and model
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = BGEM3Model.from_pretrained(model_path)
model.eval()
print("✅ Model and tokenizer loaded successfully")
# Test embeddings generation
embeddings = []
for i, query in enumerate(test_queries):
print(f"\n📝 Testing query {i+1}: '{query}'")
# Tokenize
inputs = tokenizer(
query,
return_tensors='pt',
padding=True,
truncation=True,
max_length=512
)
# Get embeddings
with torch.no_grad():
outputs = model.forward(
input_ids=inputs['input_ids'],
attention_mask=inputs['attention_mask'],
return_dense=True,
return_dict=True
)
if 'dense' in outputs:
embedding = outputs['dense'].cpu().numpy() # type: ignore[index]
embeddings.append(embedding)
print(f" ✅ Embedding shape: {embedding.shape}")
print(f" 📊 Embedding norm: {np.linalg.norm(embedding):.4f}")
print(f" 📈 Mean activation: {embedding.mean():.6f}")
print(f" 📉 Std activation: {embedding.std():.6f}")
else:
print(f" ❌ No dense embedding returned")
return False
# Compare embeddings
if len(embeddings) >= 2:
print(f"\n🔗 Similarity Analysis:")
for i in range(len(embeddings)):
for j in range(i+1, len(embeddings)):
sim = np.dot(embeddings[i].flatten(), embeddings[j].flatten()) / (
np.linalg.norm(embeddings[i]) * np.linalg.norm(embeddings[j])
)
print(f" Query {i+1} vs Query {j+1}: {sim:.4f}")
print(f"\n🎉 Model validation successful!")
return True
except Exception as e:
print(f"❌ Model validation failed: {e}")
return False
def compare_models(original_path, finetuned_path):
    """Run the embedding check on the original and the fine-tuned model.

    Both models are tested with the same single query, original first, and
    both checks always run regardless of the first outcome.

    Args:
        original_path: Location of the original model.
        finetuned_path: Location of the fine-tuned model.

    Returns:
        The result of ``orig_ok and tuned_ok``: truthy only when both checks
        succeeded.
    """
    print("🔄 Comparing original vs fine-tuned model...")
    probe_query = "什么是深度学习?"

    runs = (
        ("\n📍 Original Model:", original_path),
        ("\n📍 Fine-tuned Model:", finetuned_path),
    )
    outcomes = []
    for heading, path in runs:
        print(heading)
        # A fresh one-element list per call, as each model is tested alone.
        outcomes.append(test_model_embeddings(path, [probe_query]))

    orig_ok, tuned_ok = outcomes
    return orig_ok and tuned_ok
if __name__ == "__main__":
    # Script entry point: validate the fine-tuned model, print an
    # informational comparison run, and exit with a status that reflects
    # the validation result.
    print("🚀 BGE-M3 Model Validation")
    print("=" * 50)

    # Default locations; each may be overridden positionally:
    #   validate_m3.py [FINETUNED_MODEL_PATH [ORIGINAL_MODEL_PATH]]
    cli_args = sys.argv[1:]
    finetuned_model_path = cli_args[0] if len(cli_args) >= 1 else "./test_m3_output/final_model"
    original_model_path = cli_args[1] if len(cli_args) >= 2 else "./models/bge-m3"

    # Test fine-tuned model
    success = test_model_embeddings(finetuned_model_path)

    # Compare with original model. Informational only: its result does not
    # affect the PASSED/FAILED verdict below.
    compare_models(original_model_path, finetuned_model_path)

    if success:
        print("\n✅ Fine-tuning validation: PASSED")
        print(" The model can generate embeddings successfully!")
    else:
        print("\n❌ Fine-tuning validation: FAILED")
        print(" The model has issues generating embeddings!")

    # Propagate the verdict so shells and CI can detect a failed validation.
    sys.exit(0 if success else 1)