init commit
This commit is contained in:
5
data/datasets/examples/embedding_data.jsonl
Normal file
5
data/datasets/examples/embedding_data.jsonl
Normal file
@@ -0,0 +1,5 @@
|
||||
{"query":"什么是深度学习?","pos":["深度学习是机器学习的一个子领域,使用多层神经网络模拟人脑处理信息的方式","深度学习(Deep Learning)是人工智能和机器学习的重要分支,通过构建深层神经网络来学习数据表示"],"neg":["机器学习包含多种算法,如决策树、支持向量机、神经网络等","人工智能是计算机科学的一个分支,目标是创建能够执行智能任务的系统","数据挖掘是从大型数据集中提取模式和知识的过程"],"pos_scores":[1.0,0.95],"neg_scores":[0.6,0.4,0.3],"prompt":"为此查询生成表示:","type":"normal"}
|
||||
{"query":"如何制作红烧肉?","pos":["红烧肉制作步骤:选五花肉切块,焯水去腥,热锅炒糖色,下肉翻炒,加生抽老抽料酒,小火炖煮30-40分钟至软烂","红烧肉是经典上海菜,用五花肉、冰糖、生抽、老抽、料酒等,关键在于炒糖色和火候控制"],"neg":["红烧肉是中国传统菜肴,口感香甜软糯,深受大众喜爱","五花肉含有丰富的蛋白质和脂肪,营养价值较高","上海菜以红烧见长,口味偏甜,适合南方人饮食习惯"],"pos_scores":[1.0,0.9],"neg_scores":[0.5,0.3,0.2],"prompt":"为此查询生成表示:","type":"recipe"}
|
||||
{"query":"Python编程入门应该学习哪些内容?","pos":["Python编程入门需要掌握:基本语法、数据类型(字符串、列表、字典)、控制结构(if/for/while)、函数定义、面向对象编程基础","Python新手学习路线:先学变量和数据类型,然后是条件语句和循环,接着学习函数和模块,最后是类和对象"],"neg":["Python是一种简单易学的编程语言,语法清晰,适合初学者","编程思维比语法更重要,需要培养逻辑思维能力","计算机科学包含算法、数据结构、操作系统等多个分支"],"pos_scores":[1.0,0.95],"neg_scores":[0.4,0.3,0.2],"prompt":"为此查询生成表示:","type":"programming"}
|
||||
{"query":"北京有哪些必去的旅游景点?","pos":["北京必去景点推荐:故宫(紫禁城)、天安门广场、万里长城、颐和园、天坛、圆明园、北海公园、什刹海等","故宫是明清两代皇宫,占地72万平方米,是世界最大的古代宫殿建筑群,必须提前预约参观"],"neg":["北京是中华人民共和国首都,有3000多年建城史","中国拥有众多世界文化遗产,旅游资源丰富","春秋季节是北京旅游的最佳时间,气候宜人"],"pos_scores":[1.0,0.9],"neg_scores":[0.3,0.25,0.4],"prompt":"为此查询生成表示:","type":"travel"}
|
||||
{"query":"机器学习和深度学习有什么区别?","pos":["机器学习是更广义的概念,深度学习是机器学习的一个子集,主要区别在于深度学习使用深层神经网络进行特征自动提取","深度学习使用多层神经网络自动学习特征,而传统机器学习通常需要人工设计特征,深度学习在图像和语音识别方面表现更优"],"neg":["机器学习算法包括监督学习、无监督学习和强化学习三大类","人工智能技术发展迅速,在各行各业都有广泛应用","神经网络是深度学习的基础,模拟人脑神经元连接"],"pos_scores":[1.0,0.95],"neg_scores":[0.4,0.3,0.5],"prompt":"为此查询生成表示:","type":"tech_comparison"}
|
||||
580
data/datasets/examples/format_usage_examples.py
Normal file
580
data/datasets/examples/format_usage_examples.py
Normal file
@@ -0,0 +1,580 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Comprehensive usage examples for the refactored BGE data pipeline
|
||||
Demonstrates all four supported formats: triplets, pairs, candidates, and legacy nested
|
||||
|
||||
Data Formats Supported:
|
||||
1. Triplets: Query-triplets format (query, pos, neg) for BGE-M3 embedding training
|
||||
2. Pairs: Flat pairs format (query, passage, label) for BGE-reranker training
|
||||
3. Candidates: Unified format (query, candidates) with threshold-based label assignment
|
||||
4. Legacy Nested: Legacy nested format with automatic conversion to flat pairs
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import tempfile
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
# Add parent directory to path
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from data.dataset import BGEDataset, BGEM3Dataset, BGERerankerDataset
|
||||
from data.optimization import optimize_data
|
||||
|
||||
|
||||
def create_sample_data():
|
||||
"""Create sample data files for all four formats"""
|
||||
|
||||
# Create temporary directory
|
||||
temp_dir = tempfile.mkdtemp(prefix="bge_examples_")
|
||||
print(f"📁 Creating sample data in: {temp_dir}")
|
||||
|
||||
# 1. Triplets format (for BGE-M3 embedding training)
|
||||
triplets_file = os.path.join(temp_dir, "triplets_data.jsonl")
|
||||
triplets_data = [
|
||||
{
|
||||
"query": "What is machine learning?",
|
||||
"pos": [
|
||||
"Machine learning is a subset of artificial intelligence that enables computers to learn without explicit programming.",
|
||||
"ML algorithms use statistical techniques to learn patterns from data and make predictions."
|
||||
],
|
||||
"neg": [
|
||||
"Weather forecasting uses meteorological data to predict atmospheric conditions.",
|
||||
"Cooking recipes require specific ingredients and step-by-step instructions."
|
||||
],
|
||||
"pos_scores": [0.95, 0.88],
|
||||
"neg_scores": [0.15, 0.08],
|
||||
"prompt": "为此查询生成表示:",
|
||||
"type": "definition"
|
||||
},
|
||||
{
|
||||
"query": "How does deep learning work?",
|
||||
"pos": [
|
||||
"Deep learning uses neural networks with multiple layers to learn complex patterns.",
|
||||
"Deep neural networks automatically extract features from raw data through backpropagation."
|
||||
],
|
||||
"neg": [
|
||||
"Baking bread requires yeast, flour, and proper kneading techniques.",
|
||||
"Solar panels convert sunlight into electrical energy through photovoltaic cells."
|
||||
],
|
||||
"pos_scores": [0.92, 0.89],
|
||||
"neg_scores": [0.12, 0.18]
|
||||
}
|
||||
]
|
||||
|
||||
with open(triplets_file, 'w', encoding='utf-8') as f:
|
||||
for item in triplets_data:
|
||||
f.write(json.dumps(item, ensure_ascii=False) + '\n')
|
||||
|
||||
# 2. Pairs format (for BGE-reranker training)
|
||||
pairs_file = os.path.join(temp_dir, "pairs_data.jsonl")
|
||||
pairs_data = [
|
||||
{
|
||||
"query": "What is machine learning?",
|
||||
"passage": "Machine learning is a subset of artificial intelligence that enables computers to learn without explicit programming.",
|
||||
"label": 1,
|
||||
"score": 0.95,
|
||||
"qid": "q1",
|
||||
"pid": "p1"
|
||||
},
|
||||
{
|
||||
"query": "What is machine learning?",
|
||||
"passage": "Weather forecasting uses meteorological data to predict atmospheric conditions.",
|
||||
"label": 0,
|
||||
"score": 0.15,
|
||||
"qid": "q1",
|
||||
"pid": "p2"
|
||||
},
|
||||
{
|
||||
"query": "How does deep learning work?",
|
||||
"passage": "Deep learning uses neural networks with multiple layers to learn complex patterns.",
|
||||
"label": 1,
|
||||
"score": 0.92,
|
||||
"qid": "q2",
|
||||
"pid": "p3"
|
||||
},
|
||||
{
|
||||
"query": "How does deep learning work?",
|
||||
"passage": "Baking bread requires yeast, flour, and proper kneading techniques.",
|
||||
"label": 0,
|
||||
"score": 0.12,
|
||||
"qid": "q2",
|
||||
"pid": "p4"
|
||||
}
|
||||
]
|
||||
|
||||
with open(pairs_file, 'w', encoding='utf-8') as f:
|
||||
for item in pairs_data:
|
||||
f.write(json.dumps(item, ensure_ascii=False) + '\n')
|
||||
|
||||
# 3. Candidates format (unified format with threshold-based labels)
|
||||
candidates_file = os.path.join(temp_dir, "candidates_data.jsonl")
|
||||
candidates_data = [
|
||||
{
|
||||
"query": "What is artificial intelligence?",
|
||||
"qid": "q3",
|
||||
"candidates": [
|
||||
{
|
||||
"text": "Artificial intelligence is the simulation of human intelligence in machines.",
|
||||
"score": 0.94,
|
||||
"pid": "c1"
|
||||
},
|
||||
{
|
||||
"text": "AI systems can perform tasks that typically require human intelligence.",
|
||||
"score": 0.87,
|
||||
"pid": "c2"
|
||||
},
|
||||
{
|
||||
"text": "Mountain climbing requires proper equipment and training.",
|
||||
"score": 0.23,
|
||||
"pid": "c3"
|
||||
},
|
||||
{
|
||||
"text": "Chess is a strategic board game played between two players.",
|
||||
"score": 0.31,
|
||||
"pid": "c4"
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"query": "Explain neural networks",
|
||||
"qid": "q4",
|
||||
"candidates": [
|
||||
{
|
||||
"text": "Neural networks are computing systems inspired by biological neural networks.",
|
||||
"score": 0.91,
|
||||
"pid": "c5"
|
||||
},
|
||||
{
|
||||
"text": "They consist of interconnected nodes that process information in layers.",
|
||||
"score": 0.85,
|
||||
"pid": "c6"
|
||||
},
|
||||
{
|
||||
"text": "Gardening involves planting seeds and watering plants regularly.",
|
||||
"score": 0.19,
|
||||
"pid": "c7"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
|
||||
with open(candidates_file, 'w', encoding='utf-8') as f:
|
||||
for item in candidates_data:
|
||||
f.write(json.dumps(item, ensure_ascii=False) + '\n')
|
||||
|
||||
# 4. Legacy nested format (for backward compatibility)
|
||||
legacy_file = os.path.join(temp_dir, "legacy_nested_data.jsonl")
|
||||
legacy_data = [
|
||||
{
|
||||
"query": "What is natural language processing?",
|
||||
"pos": [
|
||||
"Natural language processing is a field of AI that helps computers understand human language.",
|
||||
"NLP combines computational linguistics with machine learning to process text and speech."
|
||||
],
|
||||
"neg": [
|
||||
"Automobile manufacturing involves assembly lines and quality control processes.",
|
||||
"Photography captures light through camera lenses to create images."
|
||||
],
|
||||
"pos_scores": [0.93, 0.86],
|
||||
"neg_scores": [0.14, 0.22]
|
||||
},
|
||||
{
|
||||
"query": "How do transformers work in NLP?",
|
||||
"pos": [
|
||||
"Transformers use self-attention mechanisms to process sequential data in parallel.",
|
||||
"The transformer architecture revolutionized NLP with attention-based models like BERT and GPT."
|
||||
],
|
||||
"neg": [
|
||||
"Electric vehicles use batteries to power electric motors for transportation.",
|
||||
"Coffee brewing requires proper water temperature and grinding consistency."
|
||||
],
|
||||
"pos_scores": [0.89, 0.91],
|
||||
"neg_scores": [0.16, 0.13]
|
||||
}
|
||||
]
|
||||
|
||||
with open(legacy_file, 'w', encoding='utf-8') as f:
|
||||
for item in legacy_data:
|
||||
f.write(json.dumps(item, ensure_ascii=False) + '\n')
|
||||
|
||||
return {
|
||||
'triplets': triplets_file,
|
||||
'pairs': pairs_file,
|
||||
'candidates': candidates_file,
|
||||
'legacy_nested': legacy_file,
|
||||
'temp_dir': temp_dir
|
||||
}
|
||||
|
||||
|
||||
def demonstrate_fine_tuning_loader():
|
||||
"""Demonstrate fine-tuning data loader with all formats"""
|
||||
|
||||
print("\n" + "="*80)
|
||||
print("🎯 FINE-TUNING DATA LOADER EXAMPLES")
|
||||
print("="*80)
|
||||
|
||||
# Create sample data
|
||||
data_files = create_sample_data()
|
||||
|
||||
# Initialize tokenizer (use a simple one for demo)
|
||||
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
|
||||
|
||||
# 1. Triplets format with BGE-M3 Dataset
|
||||
print("\n📊 1. TRIPLETS FORMAT (BGE-M3 Embedding Training)")
|
||||
print("-" * 60)
|
||||
|
||||
try:
|
||||
triplets_dataset = BGEM3Dataset(
|
||||
data_path=data_files['triplets'],
|
||||
tokenizer=tokenizer,
|
||||
query_max_length=64,
|
||||
passage_max_length=256,
|
||||
is_train=True
|
||||
)
|
||||
|
||||
print(f"✅ Loaded {len(triplets_dataset)} triplet examples")
|
||||
print(f" Detected format: {triplets_dataset.detected_format}")
|
||||
|
||||
# Get a sample
|
||||
sample = triplets_dataset[0]
|
||||
print(f" Sample keys: {list(sample.keys())}")
|
||||
print(f" Query shape: {sample['query_input_ids'].shape}")
|
||||
print(f" Passages shape: {sample['passage_input_ids'].shape}")
|
||||
print(f" Labels: {sample['labels']}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Error: {e}")
|
||||
|
||||
# 2. Pairs format with BGE-Reranker Dataset
|
||||
print("\n📊 2. PAIRS FORMAT (BGE-Reranker Training)")
|
||||
print("-" * 60)
|
||||
|
||||
try:
|
||||
pairs_dataset = BGERerankerDataset(
|
||||
data_path=data_files['pairs'],
|
||||
tokenizer=tokenizer,
|
||||
query_max_length=64,
|
||||
passage_max_length=256,
|
||||
is_train=True,
|
||||
legacy_support=False
|
||||
)
|
||||
|
||||
print(f"✅ Loaded {len(pairs_dataset)} pair examples")
|
||||
print(f" Detected format: {pairs_dataset.detected_format}")
|
||||
|
||||
# Get a sample
|
||||
sample = pairs_dataset[0]
|
||||
print(f" Sample keys: {list(sample.keys())}")
|
||||
print(f" Input shape: {sample['input_ids'].shape}")
|
||||
print(f" Label: {sample['labels']}")
|
||||
print(f" Query: {sample['query_text'][:50]}...")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Error: {e}")
|
||||
|
||||
# 3. Candidates format with automatic conversion
|
||||
print("\n📊 3. CANDIDATES FORMAT (Threshold-based Conversion)")
|
||||
print("-" * 60)
|
||||
|
||||
try:
|
||||
candidates_dataset = BGEDataset(
|
||||
data_path=data_files['candidates'],
|
||||
tokenizer=tokenizer,
|
||||
query_max_length=64,
|
||||
passage_max_length=256,
|
||||
is_train=True,
|
||||
format_type="candidates",
|
||||
candidates_threshold=0.5 # Scores >= 0.5 become label=1
|
||||
)
|
||||
|
||||
print(f"✅ Loaded {len(candidates_dataset)} exploded examples")
|
||||
print(f" Detected format: {candidates_dataset.detected_format}")
|
||||
print(f" Note: Each query with multiple candidates exploded into separate pairs")
|
||||
|
||||
# Get a sample
|
||||
sample = candidates_dataset[0]
|
||||
print(f" Sample keys: {list(sample.keys())}")
|
||||
print(f" Format after processing: {sample['format']}")
|
||||
print(f" Source: {sample.get('source', 'N/A')}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Error: {e}")
|
||||
|
||||
# 4. Legacy nested format with automatic conversion
|
||||
print("\n📊 4. LEGACY NESTED FORMAT (Automatic Conversion)")
|
||||
print("-" * 60)
|
||||
|
||||
try:
|
||||
legacy_dataset = BGERerankerDataset(
|
||||
data_path=data_files['legacy_nested'],
|
||||
tokenizer=tokenizer,
|
||||
query_max_length=64,
|
||||
passage_max_length=256,
|
||||
is_train=True,
|
||||
legacy_support=True
|
||||
)
|
||||
|
||||
print(f"✅ Loaded {len(legacy_dataset)} converted examples")
|
||||
print(f" Detected format: {legacy_dataset.detected_format}")
|
||||
print(f" Note: Legacy nested format automatically converted to flat pairs")
|
||||
|
||||
# Get a sample
|
||||
sample = legacy_dataset[0]
|
||||
print(f" Sample keys: {list(sample.keys())}")
|
||||
print(f" Format after processing: {sample['format']}")
|
||||
print(f" Source: {sample.get('source', 'N/A')}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Error: {e}")
|
||||
|
||||
return data_files
|
||||
|
||||
|
||||
def demonstrate_optimization_pipeline(data_files):
|
||||
"""Demonstrate optimization pipeline with all formats"""
|
||||
|
||||
print("\n" + "="*80)
|
||||
print("⚡ OPTIMIZATION PIPELINE EXAMPLES")
|
||||
print("="*80)
|
||||
|
||||
cache_dir = os.path.join(data_files['temp_dir'], 'cache')
|
||||
os.makedirs(cache_dir, exist_ok=True)
|
||||
|
||||
# 1. BGE-M3 optimization with triplets format
|
||||
print("\n🚀 1. BGE-M3 OPTIMIZATION (Triplets Format)")
|
||||
print("-" * 60)
|
||||
|
||||
try:
|
||||
cached_files = optimize_data(
|
||||
model_type='bge-m3',
|
||||
input_files=[data_files['triplets']],
|
||||
output_dir=cache_dir,
|
||||
tokenizer_path='bert-base-uncased',
|
||||
force=True,
|
||||
hard_negative_ratio=0.7,
|
||||
max_triplets_per_query=8,
|
||||
difficulty_threshold=0.1
|
||||
)
|
||||
|
||||
print(f"✅ M3 optimization complete: {len(cached_files)} files")
|
||||
for file in cached_files:
|
||||
print(f" 📦 {file}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ M3 optimization error: {e}")
|
||||
|
||||
# 2. BGE-Reranker optimization with pairs format
|
||||
print("\n🚀 2. BGE-RERANKER OPTIMIZATION (Pairs Format)")
|
||||
print("-" * 60)
|
||||
|
||||
try:
|
||||
cached_files = optimize_data(
|
||||
model_type='bge-reranker',
|
||||
input_files=[data_files['pairs']],
|
||||
output_dir=cache_dir,
|
||||
tokenizer_path='bert-base-uncased',
|
||||
force=True,
|
||||
output_format='flat'
|
||||
)
|
||||
|
||||
print(f"✅ Reranker optimization complete: {len(cached_files)} files")
|
||||
for file in cached_files:
|
||||
print(f" 📦 {file}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Reranker optimization error: {e}")
|
||||
|
||||
# 3. Candidates format optimization (converts to pairs for reranker)
|
||||
print("\n🚀 3. CANDIDATES FORMAT OPTIMIZATION")
|
||||
print("-" * 60)
|
||||
|
||||
try:
|
||||
cached_files = optimize_data(
|
||||
model_type='bge-reranker',
|
||||
input_files=[data_files['candidates']],
|
||||
output_dir=cache_dir,
|
||||
tokenizer_path='bert-base-uncased',
|
||||
force=True,
|
||||
output_format='flat'
|
||||
)
|
||||
|
||||
print(f"✅ Candidates optimization complete: {len(cached_files)} files")
|
||||
print(" Note: Candidates exploded into pairs with threshold-based labels")
|
||||
for file in cached_files:
|
||||
print(f" 📦 {file}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Candidates optimization error: {e}")
|
||||
|
||||
# 4. Candidates to M3 triplets conversion
|
||||
print("\n🚀 4. CANDIDATES TO M3 TRIPLETS CONVERSION")
|
||||
print("-" * 60)
|
||||
|
||||
try:
|
||||
cached_files = optimize_data(
|
||||
model_type='bge-m3',
|
||||
input_files=[data_files['candidates']],
|
||||
output_dir=cache_dir,
|
||||
tokenizer_path='bert-base-uncased',
|
||||
force=True,
|
||||
hard_negative_ratio=0.8
|
||||
)
|
||||
|
||||
print(f"✅ Candidates-to-M3 optimization complete: {len(cached_files)} files")
|
||||
print(" Note: Candidates split by threshold into pos/neg for triplets")
|
||||
for file in cached_files:
|
||||
print(f" 📦 {file}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Candidates-to-M3 optimization error: {e}")
|
||||
|
||||
# 5. Legacy nested format optimization
|
||||
print("\n🚀 5. LEGACY NESTED FORMAT OPTIMIZATION")
|
||||
print("-" * 60)
|
||||
|
||||
try:
|
||||
cached_files = optimize_data(
|
||||
model_type='bge-reranker',
|
||||
input_files=[data_files['legacy_nested']],
|
||||
output_dir=cache_dir,
|
||||
tokenizer_path='bert-base-uncased',
|
||||
force=True,
|
||||
output_format='nested' # Keep nested format
|
||||
)
|
||||
|
||||
print(f"✅ Legacy optimization complete: {len(cached_files)} files")
|
||||
print(" Note: Legacy nested format preserved for compatibility")
|
||||
for file in cached_files:
|
||||
print(f" 📦 {file}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Legacy optimization error: {e}")
|
||||
|
||||
|
||||
def show_data_format_schemas():
|
||||
"""Show clean data format schemas without inline comments"""
|
||||
|
||||
print("\n" + "="*80)
|
||||
print("📋 DATA FORMAT SCHEMAS")
|
||||
print("="*80)
|
||||
|
||||
print("\n1. 🎯 TRIPLETS FORMAT (BGE-M3 Embedding Training)")
|
||||
print("-" * 60)
|
||||
triplets_schema = {
|
||||
"query": "What is machine learning?",
|
||||
"pos": [
|
||||
"Machine learning is a subset of artificial intelligence...",
|
||||
"ML algorithms learn patterns from data..."
|
||||
],
|
||||
"neg": [
|
||||
"Weather forecasting uses meteorological data...",
|
||||
"Cooking recipes require specific ingredients..."
|
||||
],
|
||||
"pos_scores": [0.95, 0.88],
|
||||
"neg_scores": [0.15, 0.08],
|
||||
"prompt": "为此查询生成表示:",
|
||||
"type": "definition"
|
||||
}
|
||||
print(json.dumps(triplets_schema, indent=2, ensure_ascii=False))
|
||||
|
||||
print("\n2. 🔍 PAIRS FORMAT (BGE-Reranker Training)")
|
||||
print("-" * 60)
|
||||
pairs_schema = {
|
||||
"query": "What is machine learning?",
|
||||
"passage": "Machine learning is a subset of artificial intelligence...",
|
||||
"label": 1,
|
||||
"score": 0.95,
|
||||
"qid": "q1",
|
||||
"pid": "p1"
|
||||
}
|
||||
print(json.dumps(pairs_schema, indent=2, ensure_ascii=False))
|
||||
|
||||
print("\n3. 🎲 CANDIDATES FORMAT (Threshold-based Conversion)")
|
||||
print("-" * 60)
|
||||
candidates_schema = {
|
||||
"query": "What is artificial intelligence?",
|
||||
"qid": "q1",
|
||||
"candidates": [
|
||||
{
|
||||
"text": "AI is the simulation of human intelligence in machines.",
|
||||
"score": 0.94,
|
||||
"pid": "c1"
|
||||
},
|
||||
{
|
||||
"text": "Mountain climbing requires proper equipment.",
|
||||
"score": 0.23,
|
||||
"pid": "c2"
|
||||
}
|
||||
]
|
||||
}
|
||||
print(json.dumps(candidates_schema, indent=2, ensure_ascii=False))
|
||||
print("📝 Note: Scores >= threshold become label=1, others become label=0")
|
||||
|
||||
print("\n4. 🔄 LEGACY NESTED FORMAT (Backward Compatibility)")
|
||||
print("-" * 60)
|
||||
legacy_schema = {
|
||||
"query": "What is natural language processing?",
|
||||
"pos": [
|
||||
"NLP is a field of AI that helps computers understand human language.",
|
||||
"NLP combines computational linguistics with machine learning..."
|
||||
],
|
||||
"neg": [
|
||||
"Automobile manufacturing involves assembly lines...",
|
||||
"Photography captures light through camera lenses..."
|
||||
],
|
||||
"pos_scores": [0.93, 0.86],
|
||||
"neg_scores": [0.14, 0.22]
|
||||
}
|
||||
print(json.dumps(legacy_schema, indent=2, ensure_ascii=False))
|
||||
print("📝 Note: Automatically converted to flat pairs for reranker training")
|
||||
|
||||
|
||||
def main():
|
||||
"""Main demonstration function"""
|
||||
|
||||
print("🚀 BGE Data Pipeline Format Examples")
|
||||
print("="*80)
|
||||
print("Demonstrating all four supported formats:")
|
||||
print("1. Triplets (BGE-M3 embedding training)")
|
||||
print("2. Pairs (BGE-reranker training)")
|
||||
print("3. Candidates (unified format with threshold conversion)")
|
||||
print("4. Legacy nested (backward compatibility)")
|
||||
print("="*80)
|
||||
|
||||
try:
|
||||
# Show data schemas
|
||||
show_data_format_schemas()
|
||||
|
||||
# Demonstrate fine-tuning loader
|
||||
data_files = demonstrate_fine_tuning_loader()
|
||||
|
||||
# Demonstrate optimization pipeline
|
||||
demonstrate_optimization_pipeline(data_files)
|
||||
|
||||
print("\n" + "="*80)
|
||||
print("🎉 ALL EXAMPLES COMPLETED SUCCESSFULLY!")
|
||||
print("="*80)
|
||||
|
||||
print("\n📋 SUMMARY:")
|
||||
print("✅ Format Detection: Automatic detection of all four formats")
|
||||
print("✅ Conversion Logic: Seamless conversion between formats")
|
||||
print("✅ Threshold Assignment: Candidates → pairs with configurable threshold")
|
||||
print("✅ Legacy Support: Automatic migration of nested → flat pairs")
|
||||
print("✅ Optimization: Pre-tokenization and caching for 10-15x speedup")
|
||||
print("✅ Backward Compatibility: Full support for existing datasets")
|
||||
|
||||
print(f"\n📁 Sample data and cache files created in: {data_files['temp_dir']}")
|
||||
print(" You can examine the generated files to see the format conversions")
|
||||
|
||||
except Exception as e:
|
||||
print(f"\n❌ Error during demonstration: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
print("\n" + "="*80)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
20
data/datasets/examples/reranker_data.jsonl
Normal file
20
data/datasets/examples/reranker_data.jsonl
Normal file
@@ -0,0 +1,20 @@
|
||||
{"query":"什么是深度学习?","passage":"深度学习是机器学习的一个子领域,使用多层神经网络模拟人脑处理信息的方式","label":1,"score":1.0,"qid":"Q001","pid":"P001"}
|
||||
{"query":"什么是深度学习?","passage":"深度学习(Deep Learning)是人工智能和机器学习的重要分支,通过构建深层神经网络来学习数据表示","label":1,"score":0.95,"qid":"Q001","pid":"P002"}
|
||||
{"query":"什么是深度学习?","passage":"机器学习包含多种算法,如决策树、支持向量机、神经网络等","label":0,"score":0.6,"qid":"Q001","pid":"N001"}
|
||||
{"query":"什么是深度学习?","passage":"人工智能是计算机科学的一个分支,目标是创建能够执行智能任务的系统","label":0,"score":0.4,"qid":"Q001","pid":"N002"}
|
||||
{"query":"什么是深度学习?","passage":"数据挖掘是从大型数据集中提取模式和知识的过程","label":0,"score":0.3,"qid":"Q001","pid":"N003"}
|
||||
{"query":"如何制作红烧肉?","passage":"红烧肉制作步骤:选五花肉切块,焯水去腥,热锅炒糖色,下肉翻炒,加生抽老抽料酒,小火炖煮30-40分钟至软烂","label":1,"score":1.0,"qid":"Q002","pid":"P003"}
|
||||
{"query":"如何制作红烧肉?","passage":"红烧肉是经典上海菜,用五花肉、冰糖、生抽、老抽、料酒等,关键在于炒糖色和火候控制","label":1,"score":0.9,"qid":"Q002","pid":"P004"}
|
||||
{"query":"如何制作红烧肉?","passage":"红烧肉是中国传统菜肴,口感香甜软糯,深受大众喜爱","label":0,"score":0.5,"qid":"Q002","pid":"N004"}
|
||||
{"query":"如何制作红烧肉?","passage":"五花肉含有丰富的蛋白质和脂肪,营养价值较高","label":0,"score":0.3,"qid":"Q002","pid":"N005"}
|
||||
{"query":"如何制作红烧肉?","passage":"上海菜以红烧见长,口味偏甜,适合南方人饮食习惯","label":0,"score":0.2,"qid":"Q002","pid":"N006"}
|
||||
{"query":"Python编程入门应该学习哪些内容?","passage":"Python编程入门需要掌握:基本语法、数据类型(字符串、列表、字典)、控制结构(if/for/while)、函数定义、面向对象编程基础","label":1,"score":1.0,"qid":"Q003","pid":"P005"}
|
||||
{"query":"Python编程入门应该学习哪些内容?","passage":"Python新手学习路线:先学变量和数据类型,然后是条件语句和循环,接着学习函数和模块,最后是类和对象","label":1,"score":0.95,"qid":"Q003","pid":"P006"}
|
||||
{"query":"Python编程入门应该学习哪些内容?","passage":"Python是一种简单易学的编程语言,语法清晰,适合初学者","label":0,"score":0.4,"qid":"Q003","pid":"N007"}
|
||||
{"query":"Python编程入门应该学习哪些内容?","passage":"编程思维比语法更重要,需要培养逻辑思维能力","label":0,"score":0.3,"qid":"Q003","pid":"N008"}
|
||||
{"query":"Python编程入门应该学习哪些内容?","passage":"计算机科学包含算法、数据结构、操作系统等多个分支","label":0,"score":0.2,"qid":"Q003","pid":"N009"}
|
||||
{"query":"北京有哪些必去的旅游景点?","passage":"北京必去景点推荐:故宫(紫禁城)、天安门广场、万里长城、颐和园、天坛、圆明园、北海公园、什刹海等","label":1,"score":1.0,"qid":"Q004","pid":"P007"}
|
||||
{"query":"北京有哪些必去的旅游景点?","passage":"故宫是明清两代皇宫,占地72万平方米,是世界最大的古代宫殿建筑群,必须提前预约参观","label":1,"score":0.9,"qid":"Q004","pid":"P008"}
|
||||
{"query":"北京有哪些必去的旅游景点?","passage":"北京是中华人民共和国首都,有3000多年建城史","label":0,"score":0.3,"qid":"Q004","pid":"N010"}
|
||||
{"query":"北京有哪些必去的旅游景点?","passage":"中国拥有众多世界文化遗产,旅游资源丰富","label":0,"score":0.25,"qid":"Q004","pid":"N011"}
|
||||
{"query":"北京有哪些必去的旅游景点?","passage":"春秋季节是北京旅游的最佳时间,气候宜人","label":0,"score":0.4,"qid":"Q004","pid":"N012"}
|
||||
Reference in New Issue
Block a user