init commit
This commit is contained in:
584
scripts/evaluate.py
Normal file
584
scripts/evaluate.py
Normal file
@@ -0,0 +1,584 @@
"""
Evaluation script for fine-tuned BGE models
"""
import os
|
||||
import sys
|
||||
import argparse
|
||||
import logging
|
||||
import json
|
||||
import torch
|
||||
from transformers import AutoTokenizer
|
||||
import numpy as np
|
||||
|
||||
# Add parent directory to path
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from models.bge_m3 import BGEM3Model
|
||||
from models.bge_reranker import BGERerankerModel
|
||||
from evaluation.evaluator import RetrievalEvaluator, InteractiveEvaluator
|
||||
from data.preprocessing import DataPreprocessor
|
||||
from utils.logging import setup_logging, get_logger
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def parse_args():
    """Build the CLI for evaluation and return the parsed arguments.

    If no evaluation mode flag is given, defaults to full-pipeline
    evaluation (``--eval_pipeline``).
    """
    parser = argparse.ArgumentParser(description="Evaluate fine-tuned BGE models")

    # Model checkpoints (fine-tuned + baselines for comparison).
    parser.add_argument("--retriever_model", type=str, required=True, help="Path to fine-tuned retriever model")
    parser.add_argument("--reranker_model", type=str, default=None, help="Path to fine-tuned reranker model")
    parser.add_argument("--base_retriever", type=str, default="BAAI/bge-m3", help="Base retriever model for comparison")
    parser.add_argument("--base_reranker", type=str, default="BAAI/bge-reranker-base", help="Base reranker model for comparison")

    # Input data.
    parser.add_argument("--eval_data", type=str, required=True, help="Path to evaluation data file")
    parser.add_argument("--corpus_data", type=str, default=None, help="Path to corpus data for retrieval evaluation")
    parser.add_argument("--data_format", type=str, default="auto",
                        choices=["auto", "jsonl", "json", "csv", "tsv", "dureader", "msmarco"],
                        help="Format of input data")

    # Evaluation settings.
    parser.add_argument("--output_dir", type=str, default="./evaluation_results", help="Directory to save evaluation results")
    parser.add_argument("--batch_size", type=int, default=32, help="Batch size for evaluation")
    parser.add_argument("--max_length", type=int, default=512, help="Maximum sequence length")
    parser.add_argument("--k_values", type=int, nargs="+", default=[1, 5, 10, 20, 100], help="K values for evaluation metrics")

    # Which evaluations to run.
    parser.add_argument("--eval_retriever", action="store_true", help="Evaluate retriever model")
    parser.add_argument("--eval_reranker", action="store_true", help="Evaluate reranker model")
    parser.add_argument("--eval_pipeline", action="store_true", help="Evaluate full retrieval pipeline")
    parser.add_argument("--compare_baseline", action="store_true", help="Compare with baseline models")
    parser.add_argument("--interactive", action="store_true", help="Run interactive evaluation")

    # Pipeline depth.
    parser.add_argument("--retrieval_top_k", type=int, default=100, help="Top-k for initial retrieval")
    parser.add_argument("--rerank_top_k", type=int, default=10, help="Top-k for reranking")

    # Misc.
    parser.add_argument("--device", type=str, default="auto", help="Device to use (auto, cpu, cuda)")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")

    args = parser.parse_args()

    # Fall back to full-pipeline evaluation when no mode was requested.
    requested = (args.eval_retriever, args.eval_reranker, args.eval_pipeline, args.interactive)
    if not any(requested):
        args.eval_pipeline = True

    return args
def load_evaluation_data(data_path: str, data_format: str):
    """Load evaluation examples from *data_path* in the given format.

    ``data_format`` may be ``"auto"``, in which case the format is
    sniffed from the file via ``DataPreprocessor._detect_format``.
    Raises ``ValueError`` for an unrecognized format name.
    """
    preprocessor = DataPreprocessor()

    if data_format == "auto":
        data_format = preprocessor._detect_format(data_path)

    # Dispatch table: format name -> zero-arg loader.
    loaders = {
        "jsonl": lambda: preprocessor._load_jsonl(data_path),
        "json": lambda: preprocessor._load_json(data_path),
        "csv": lambda: preprocessor._load_csv(data_path, delimiter=","),
        "tsv": lambda: preprocessor._load_csv(data_path, delimiter="\t"),
        "dureader": lambda: preprocessor._load_dureader(data_path),
        "msmarco": lambda: preprocessor._load_msmarco(data_path),
    }

    try:
        loader = loaders[data_format]
    except KeyError:
        raise ValueError(f"Unsupported format: {data_format}") from None

    return loader()
def evaluate_retriever(args, retriever_model, tokenizer, eval_data):
    """Evaluate the retriever model.

    Builds a corpus from the pos/neg passages found in *eval_data*, maps
    each query to the corpus indices of its positive passages, and runs
    the RetrievalEvaluator pipeline over it.

    Returns a dict of metric name -> value, or a ``{"note": ...}`` dict
    when the data carries no usable corpus/relevance information.
    """
    logger.info("Evaluating retriever model...")

    evaluator = RetrievalEvaluator(
        model=retriever_model,
        tokenizer=tokenizer,
        device=args.device,
        k_values=args.k_values
    )

    queries = [item['query'] for item in eval_data]

    # Corpus = union of all positive and negative passages.
    # NOTE: iteration order of a set is arbitrary, so corpus index order
    # may vary between runs — presumably harmless for these metrics; verify.
    all_passages = set()
    for item in eval_data:
        all_passages.update(item.get('pos', []))
        all_passages.update(item.get('neg', []))

    passages = list(all_passages)
    # Precomputed index map gives O(1) positive lookup; the previous
    # list.index() call inside the loop made this step O(n^2).
    passage_index = {passage: idx for idx, passage in enumerate(passages)}

    # Map each query to the corpus indices of its relevant passages.
    relevant_docs = {}
    for item in eval_data:
        relevant_indices = [passage_index[pos]
                            for pos in item.get('pos', [])
                            if pos in passage_index]
        if relevant_indices:
            relevant_docs[item['query']] = relevant_indices

    if passages and relevant_docs:
        metrics = evaluator.evaluate_retrieval_pipeline(
            queries=queries,
            corpus=passages,
            relevant_docs=relevant_docs,
            batch_size=args.batch_size
        )
    else:
        logger.warning("No valid corpus or relevance data found. Running basic evaluation.")
        # Create dummy evaluation
        metrics = {"note": "Limited evaluation due to data format"}

    return metrics
def evaluate_reranker(args, reranker_model, tokenizer, eval_data):
    """Evaluate the reranker model.

    For every example that has both positives and negatives, scores one
    positive against all negatives and records whether the positive is
    ranked first (accuracy@1) and its reciprocal rank (MRR).
    """
    logger.info("Evaluating reranker model...")

    # NOTE(review): this evaluator instance is never referenced below;
    # kept in case its constructor has required side effects — confirm.
    evaluator = RetrievalEvaluator(
        model=reranker_model,
        tokenizer=tokenizer,
        device=args.device,
        k_values=args.k_values
    )

    n_queries = 0
    n_correct = 0
    mrr_sum = 0.0

    for item in eval_data:
        positives = item.get('pos') or []
        negatives = item.get('neg') or []
        if not positives or not negatives:
            continue

        query = item['query']
        # Candidate list: one positive (index 0) followed by all negatives.
        candidates = positives[:1] + negatives
        scores = reranker_model.compute_score(
            [(query, candidate) for candidate in candidates],
            batch_size=args.batch_size,
        )

        if np.argmax(scores) == 0:
            # Positive ranked on top.
            n_correct += 1
            mrr_sum += 1.0
        else:
            # Rank of the positive (always candidate 0) in descending score order.
            descending = np.argsort(scores)[::-1]
            positive_rank = np.where(descending == 0)[0][0] + 1
            mrr_sum += 1.0 / positive_rank

        n_queries += 1

    if n_queries > 0:
        return {
            'accuracy@1': n_correct / n_queries,
            'mrr': mrr_sum / n_queries,
            'total_queries': n_queries,
        }
    return {"note": "No valid query-passage pairs found for reranker evaluation"}
def evaluate_pipeline(args, retriever_model, reranker_model, tokenizer, eval_data):
    """Evaluate the full retrieval + (optional) reranking pipeline.

    Encodes the pooled pos/neg passages once, then for each query with
    positives: retrieves top-k by dense dot-product similarity and
    records recall@10; when a reranker is supplied, reranks the head of
    the retrieved list and records accuracy@1 and MRR of that order.

    Returns averaged metrics plus ``total_queries``.
    """
    logger.info("Evaluating full pipeline...")

    # Corpus = union of all positive and negative passages.
    all_passages = set()
    for item in eval_data:
        all_passages.update(item.get('pos', []))
        all_passages.update(item.get('neg', []))

    passages = list(all_passages)

    if not passages:
        return {"note": "No passages found for pipeline evaluation"}

    # Pre-encode the corpus once with the retriever.
    logger.info("Encoding corpus...")
    corpus_embeddings = retriever_model.encode(
        passages,
        batch_size=args.batch_size,
        return_numpy=True,
        device=args.device
    )

    total_queries = 0
    pipeline_metrics = {
        'retrieval_recall@10': [],
        'rerank_accuracy@1': [],
        'rerank_mrr': []
    }

    for item in eval_data:
        if not item.get('pos'):
            continue

        query = item['query']
        # Set for O(1) membership tests below (the list scan was O(n) per
        # check); also de-duplicates repeated positives so the recall
        # denominator counts distinct passages.
        relevant_passages = set(item['pos'])

        # Step 1: dense retrieval.
        query_embedding = retriever_model.encode(
            [query],
            return_numpy=True,
            device=args.device
        )
        similarities = np.dot(query_embedding, corpus_embeddings.T).squeeze()
        top_k_indices = np.argsort(similarities)[::-1][:args.retrieval_top_k]
        retrieved_passages = [passages[i] for i in top_k_indices]

        # Recall@10 over the retrieved head (denominator capped at 10).
        retrieved_relevant = sum(1 for p in retrieved_passages[:10] if p in relevant_passages)
        pipeline_metrics['retrieval_recall@10'].append(
            retrieved_relevant / min(len(relevant_passages), 10)
        )

        # Step 2: reranking (only when a reranker is available).
        if reranker_model is not None:
            rerank_candidates = retrieved_passages[:args.rerank_top_k]
            pairs = [(query, candidate) for candidate in rerank_candidates]

            if pairs:
                rerank_scores = reranker_model.compute_score(pairs, batch_size=args.batch_size)
                reranked_indices = np.argsort(rerank_scores)[::-1]
                reranked_passages = [rerank_candidates[i] for i in reranked_indices]

                if reranked_passages[0] in relevant_passages:
                    pipeline_metrics['rerank_accuracy@1'].append(1.0)
                    pipeline_metrics['rerank_mrr'].append(1.0)
                else:
                    pipeline_metrics['rerank_accuracy@1'].append(0.0)
                    # Reciprocal rank of the first relevant passage; 0 if
                    # none survived into the reranked head.
                    for rank, passage in enumerate(reranked_passages, 1):
                        if passage in relevant_passages:
                            pipeline_metrics['rerank_mrr'].append(1.0 / rank)
                            break
                    else:
                        pipeline_metrics['rerank_mrr'].append(0.0)

        total_queries += 1

    # Average each metric that collected at least one sample.
    final_metrics = {metric: np.mean(values)
                     for metric, values in pipeline_metrics.items() if values}
    final_metrics['total_queries'] = total_queries

    return final_metrics
def run_interactive_evaluation(args, retriever_model, reranker_model, tokenizer, eval_data):
    """Run an interactive query REPL over the evaluation corpus.

    Reads queries from stdin, retrieves (and reranks when a reranker is
    supplied) against the pooled pos/neg passages, and prints the top
    results. Exits on 'quit'/'exit'/'q' or Ctrl-C.
    """
    logger.info("Starting interactive evaluation...")

    # Corpus = union of all positive and negative passages.
    all_passages = set()
    for item in eval_data:
        all_passages.update(item.get('pos', []))
        all_passages.update(item.get('neg', []))
    passages = list(all_passages)

    if not passages:
        print("No passages found for interactive evaluation")
        return

    evaluator = InteractiveEvaluator(
        retriever=retriever_model,
        reranker=reranker_model,
        tokenizer=tokenizer,
        corpus=passages,
        device=args.device
    )

    print("\n=== Interactive Evaluation ===")
    print("Enter queries to test the retrieval system.")
    print("Type 'quit' to exit, 'sample' to test with sample queries.")

    # A few real queries the user can pick by number.
    sample_queries = [item['query'] for item in eval_data[:5]]
    use_reranker = reranker_model is not None

    while True:
        try:
            user_input = input("\nEnter query: ").strip()
            lowered = user_input.lower()

            if lowered in ('quit', 'exit', 'q'):
                break
            if lowered == 'sample':
                print("\nSample queries:")
                for i, query in enumerate(sample_queries):
                    print(f"{i + 1}. {query}")
                continue

            # A bare number selects the corresponding sample query.
            if user_input.isdigit() and 1 <= int(user_input) <= len(sample_queries):
                query = sample_queries[int(user_input) - 1]
            else:
                query = user_input

            if not query:
                continue

            results = evaluator.search(query, top_k=10, rerank=use_reranker)

            print(f"\nResults for: {query}")
            print("-" * 50)
            for i, (idx, score, text) in enumerate(results):
                print(f"{i + 1}. (Score: {score:.3f}) {text[:100]}...")

        except KeyboardInterrupt:
            print("\nExiting interactive evaluation...")
            break
        except Exception as e:
            print(f"Error: {e}")
def main():
    """Entry point: parse args, load models and data, run the requested
    evaluations, and write results to ``<output_dir>/evaluation_results.json``.
    """
    # Reuse the module-level CLI definition instead of duplicating the
    # whole parser here — the previous inline copy had already drifted
    # from parse_args() (no --device default, no default-mode fallback).
    args = parse_args()

    # Config-driven device default when the CLI did not pin one.
    from utils.config_loader import get_config
    config = get_config()
    device = args.device or config.get('hardware', {}).get('device', 'cuda')  # type: ignore[attr-defined]

    # Setup logging
    setup_logging(
        output_dir=args.output_dir,
        log_level=logging.INFO,
        log_to_console=True,
        log_to_file=True
    )

    logger = get_logger(__name__)
    logger.info("Starting model evaluation")
    logger.info(f"Arguments: {args}")

    # Resolve the device string to something torch accepts.
    if device == "auto":
        device = "cuda" if torch.cuda.is_available() else "cpu"
    elif device == "npu":
        # Ascend support is optional: torch.npu only exists with torch_npu.
        npu = getattr(torch, "npu", None)
        if npu is not None and hasattr(npu, "is_available") and npu.is_available():
            device = "npu"
        else:
            logger.warning("NPU (Ascend) is not available. Falling back to CPU.")
            device = "cpu"
    logger.info(f"Using device: {device}")

    # Load models
    logger.info("Loading models...")
    model_source = config.get('model_paths.source', 'huggingface')

    # Retriever is mandatory: abort if it cannot be loaded.
    try:
        retriever_model = BGEM3Model.from_pretrained(args.retriever_model)
        retriever_tokenizer = None  # Already loaded inside BGEM3Model if needed
        retriever_model.to(device)
        retriever_model.eval()
        logger.info(f"Loaded retriever from {args.retriever_model} (source: {model_source})")
    except Exception as e:
        logger.error(f"Failed to load retriever: {e}")
        return

    # Reranker is optional: continue without it on failure.
    reranker_model = None
    reranker_tokenizer = None
    if args.reranker_model:
        try:
            reranker_model = BGERerankerModel(args.reranker_model)
            reranker_tokenizer = None  # Already loaded inside BGERerankerModel if needed
            reranker_model.to(device)
            reranker_model.eval()
            logger.info(f"Loaded reranker from {args.reranker_model} (source: {model_source})")
        except Exception as e:
            logger.warning(f"Failed to load reranker: {e}")

    # Load evaluation data
    logger.info("Loading evaluation data...")
    eval_data = load_evaluation_data(args.eval_data, args.data_format)
    logger.info(f"Loaded {len(eval_data)} evaluation examples")

    os.makedirs(args.output_dir, exist_ok=True)

    # Run the requested evaluations.
    all_results = {}

    if args.eval_retriever:
        retriever_metrics = evaluate_retriever(args, retriever_model, retriever_tokenizer, eval_data)
        all_results['retriever'] = retriever_metrics
        logger.info(f"Retriever metrics: {retriever_metrics}")

    if args.eval_reranker and reranker_model:
        reranker_metrics = evaluate_reranker(args, reranker_model, reranker_tokenizer, eval_data)
        all_results['reranker'] = reranker_metrics
        logger.info(f"Reranker metrics: {reranker_metrics}")

    if args.eval_pipeline:
        pipeline_metrics = evaluate_pipeline(args, retriever_model, reranker_model, retriever_tokenizer, eval_data)
        all_results['pipeline'] = pipeline_metrics
        logger.info(f"Pipeline metrics: {pipeline_metrics}")

    # Optionally score the un-fine-tuned baselines for comparison.
    if args.compare_baseline:
        logger.info("Loading baseline models for comparison...")
        try:
            # NOTE(review): the fine-tuned retriever is loaded via
            # from_pretrained() but the baseline via the constructor —
            # confirm both construction paths are equivalent in BGEM3Model.
            base_retriever = BGEM3Model(args.base_retriever)
            base_retriever.to(device)
            base_retriever.eval()

            base_retriever_metrics = evaluate_retriever(args, base_retriever, retriever_tokenizer, eval_data)
            all_results['baseline_retriever'] = base_retriever_metrics
            logger.info(f"Baseline retriever metrics: {base_retriever_metrics}")

            if reranker_model:
                base_reranker = BGERerankerModel(args.base_reranker)
                base_reranker.to(device)
                base_reranker.eval()

                base_reranker_metrics = evaluate_reranker(args, base_reranker, reranker_tokenizer, eval_data)
                all_results['baseline_reranker'] = base_reranker_metrics
                logger.info(f"Baseline reranker metrics: {base_reranker_metrics}")
        except Exception as e:
            logger.warning(f"Failed to load baseline models: {e}")

    # Save results
    results_file = os.path.join(args.output_dir, "evaluation_results.json")
    with open(results_file, 'w') as f:
        json.dump(all_results, f, indent=2)
    logger.info(f"Evaluation results saved to {results_file}")

    # Print summary
    print("\n" + "=" * 50)
    print("EVALUATION SUMMARY")
    print("=" * 50)
    for model_type, metrics in all_results.items():
        print(f"\n{model_type.upper()}:")
        for metric, value in metrics.items():
            if isinstance(value, (int, float)):
                print(f"  {metric}: {value:.4f}")
            else:
                print(f"  {metric}: {value}")

    # Interactive evaluation
    if args.interactive:
        run_interactive_evaluation(args, retriever_model, reranker_model, retriever_tokenizer, eval_data)


if __name__ == "__main__":
    main()
Reference in New Issue
Block a user