init commit
This commit is contained in:
122
scripts/validate_m3.py
Normal file
122
scripts/validate_m3.py
Normal file
@@ -0,0 +1,122 @@
|
||||
"""
|
||||
Quick validation script for the fine-tuned BGE-M3 model
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import torch
|
||||
import numpy as np
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
# Add project root to Python path
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from models.bge_m3 import BGEM3Model
|
||||
|
||||
def test_model_embeddings(model_path, test_queries=None):
    """Smoke-test that a saved BGE-M3 checkpoint can produce dense embeddings.

    Loads the tokenizer and model from *model_path*, encodes each query,
    prints per-embedding statistics, and — when more than one query is
    given — reports pairwise cosine similarities.

    Args:
        model_path: Directory (or hub id) of the checkpoint to load.
        test_queries: Optional list of query strings. Defaults to a small
            Chinese probe set: the original training query, a related
            query, and an unrelated one.

    Returns:
        True if every query produced a dense embedding, False otherwise
        (including any load/inference failure).
    """
    if test_queries is None:
        test_queries = [
            "什么是深度学习?",  # Original training query
            "什么是机器学习?",  # Related query
            "今天天气怎么样?"  # Unrelated query
        ]

    print(f"🧪 Testing model from: {model_path}")

    try:
        # Load tokenizer and model from the same checkpoint directory.
        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        model = BGEM3Model.from_pretrained(model_path)
        model.eval()

        print("✅ Model and tokenizer loaded successfully")

        # One dense embedding (numpy array) per query, in order.
        embeddings = []

        for i, query in enumerate(test_queries):
            print(f"\n📝 Testing query {i+1}: '{query}'")

            # Tokenize a single query; truncate to the model's max length.
            inputs = tokenizer(
                query,
                return_tensors='pt',
                padding=True,
                truncation=True,
                max_length=512
            )

            # Inference only — no gradients needed.
            with torch.no_grad():
                outputs = model.forward(
                    input_ids=inputs['input_ids'],
                    attention_mask=inputs['attention_mask'],
                    return_dense=True,
                    return_dict=True
                )

            if 'dense' in outputs:
                embedding = outputs['dense'].cpu().numpy()  # type: ignore[index]
                embeddings.append(embedding)

                print(f" ✅ Embedding shape: {embedding.shape}")
                print(f" 📊 Embedding norm: {np.linalg.norm(embedding):.4f}")
                print(f" 📈 Mean activation: {embedding.mean():.6f}")
                print(f" 📉 Std activation: {embedding.std():.6f}")
            else:
                # Fail fast: the whole point of the check is the dense head.
                print(" ❌ No dense embedding returned")
                return False

        # Pairwise cosine similarity between all query embeddings.
        if len(embeddings) >= 2:
            print("\n🔗 Similarity Analysis:")
            # Hoist flatten/norm out of the O(n^2) pair loop — they are
            # per-embedding invariants.
            flat_vecs = [e.flatten() for e in embeddings]
            norms = [np.linalg.norm(e) for e in embeddings]
            for i in range(len(embeddings)):
                for j in range(i+1, len(embeddings)):
                    sim = np.dot(flat_vecs[i], flat_vecs[j]) / (norms[i] * norms[j])
                    print(f" Query {i+1} vs Query {j+1}: {sim:.4f}")

        print("\n🎉 Model validation successful!")
        return True

    except Exception as e:
        # Broad catch is deliberate: this is the top-level boundary of a
        # validation script; any failure means "validation failed".
        print(f"❌ Model validation failed: {e}")
        return False
|
||||
|
||||
def compare_models(original_path, finetuned_path):
    """Run the embedding smoke test on both checkpoints with one probe query.

    Args:
        original_path: Path of the base (pre-fine-tuning) checkpoint.
        finetuned_path: Path of the fine-tuned checkpoint.

    Returns:
        True only when both checkpoints pass test_model_embeddings.
    """
    print("🔄 Comparing original vs fine-tuned model...")

    # Same single probe query for both models so results are comparable.
    probe_query = "什么是深度学习?"

    print("\n📍 Original Model:")
    base_ok = test_model_embeddings(original_path, [probe_query])

    print("\n📍 Fine-tuned Model:")
    tuned_ok = test_model_embeddings(finetuned_path, [probe_query])

    return base_ok and tuned_ok
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("🚀 BGE-M3 Model Validation")
|
||||
print("=" * 50)
|
||||
|
||||
# Test fine-tuned model
|
||||
finetuned_model_path = "./test_m3_output/final_model"
|
||||
success = test_model_embeddings(finetuned_model_path)
|
||||
|
||||
# Optional: Compare with original model
|
||||
original_model_path = "./models/bge-m3"
|
||||
compare_models(original_model_path, finetuned_model_path)
|
||||
|
||||
if success:
|
||||
print("\n✅ Fine-tuning validation: PASSED")
|
||||
print(" The model can generate embeddings successfully!")
|
||||
else:
|
||||
print("\n❌ Fine-tuning validation: FAILED")
|
||||
print(" The model has issues generating embeddings!")
|
||||
Reference in New Issue
Block a user