bge_finetune/scripts/convert_reranker_dataset.py
2025-07-23 14:54:46 +08:00

146 lines
4.5 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""
convert_jsonl_with_fallback.py
Like the original converter, but when parsing the inner JSON fails
(as it often does when there are unescaped quotes inside the passage),
fall back to a simple regex/string search to pull out passage, label, and score.
"""
import argparse
import json
import os
import re
import sys
from pathlib import Path
def parse_args(argv=None):
    """Parse the converter's command-line options.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, which is what the
            script entry point relies on. Passing an explicit list makes the
            function usable from other code and from tests.

    Returns:
        An ``argparse.Namespace`` with two string attributes: ``input``
        (a .jsonl file or a directory of .jsonl files) and ``output_dir``
        (the directory converted files are written to). Both are required.
    """
    parser = argparse.ArgumentParser(
        description="Convert JSONL with embedded JSON in 'data' to flat JSONL, with fallback."
    )
    parser.add_argument(
        "--input", "-i", required=True,
        help="Path to input .jsonl file or directory containing .jsonl files",
    )
    parser.add_argument(
        "--output-dir", "-o", required=True,
        help="Directory where converted files will be written",
    )
    return parser.parse_args(argv)
def strip_code_fence(s: str) -> str:
    """Return the text inside the first Markdown code fence found in *s*.

    Both ```json ... ``` and bare ``` ... ``` fences are recognised, and
    whitespace directly inside the fence markers is dropped. If *s* contains
    no complete fence, it is returned unchanged.
    """
    match = re.search(r"```(?:json)?\s*(.*?)\s*```", s, flags=re.DOTALL)
    if match is None:
        return s
    return match.group(1)
def fallback_extract(inner: str):
    """Heuristically pull passage, label and score out of malformed JSON text.

    Intended for text that ``json.loads`` rejects, typically because the
    passage contains unescaped double quotes.

    The passage starts right after the opening quote of the ``"passage"``
    value. Its end is the closing quote that is directly followed by a comma
    and the ``"label"`` or ``"score"`` key, so quote-comma sequences inside
    the passage text do not truncate it. This assumes one of those two keys
    comes immediately after the passage. If no such boundary exists, the
    first quote-comma pair after the start is used instead.

    Label and score are only searched for after the end of the passage.

    Args:
        inner: Raw text of the (possibly invalid) inner JSON object.

    Returns:
        A dict ``{"passage": str, "label": int, "score": float}``, or
        ``None`` when any of the three fields cannot be located. The passage
        is the raw slice of *inner*; JSON escape sequences are not decoded.
    """
    # 1) passage start: tolerate any whitespace around the colon, exactly as
    #    the label/score patterns below do.
    m_start = re.search(r'"passage"\s*:\s*"', inner)
    if not m_start:
        return None
    p_start = m_start.end()

    # 2) passage end: prefer the closing quote that is followed by the next
    #    expected key, so that text like `he said "yes", and left` survives.
    boundary = re.compile(r'"\s*,\s*"(?:label|score)"\s*:')
    m_end = boundary.search(inner, p_start)
    if m_end:
        p_end = m_end.start()
    else:
        # No recognisable key follows: use the first `",` after the start.
        p_end = inner.find('",', p_start)
        if p_end < 0:
            return None
    passage = inner[p_start:p_end]
    tail = inner[p_end:]

    # 3) label: a non-negative integer.
    m_lbl = re.search(r'"label"\s*:\s*(\d+)', tail)
    if not m_lbl:
        return None
    label = int(m_lbl.group(1))

    # 4) score: optional sign, integer/decimal digits, optional exponent, so
    #    that values such as -0.5 or 1e-3 are read in full.
    m_sc = re.search(
        r'"score"\s*:\s*(-?(?:\d+\.?\d*|\.\d+)(?:[eE][+-]?\d+)?)', tail
    )
    if not m_sc:
        return None
    score = float(m_sc.group(1))

    return {"passage": passage, "label": label, "score": score}
def process_file(in_path: Path, out_path: Path):
    """Convert one JSONL file with embedded JSON into flat JSONL.

    Each non-blank input line is parsed as a JSON object with a ``query`` and
    a ``data`` field. ``data`` is either a string holding a (possibly
    code-fenced) JSON object or an already-parsed object. One output line
    ``{"query", "passage", "label", "score"}`` is written per input line for
    which all three inner fields can be obtained. If the inner text is not
    valid JSON, or lacks one of the keys, ``fallback_extract`` is tried on the
    text. Lines that cannot be converted are skipped with a message on stderr.

    Args:
        in_path: Source .jsonl file, read as UTF-8.
        out_path: Destination file, created or truncated and written as UTF-8.

    Raises:
        ValueError: If *out_path* refers to the same file as *in_path*.
    """
    # Opening the destination in "w" mode truncates it immediately; converting
    # a file onto itself would wipe the input before a single line is read.
    if out_path.exists() and out_path.samefile(in_path):
        raise ValueError(f"input and output are the same file: {in_path}")

    required = ("passage", "label", "score")

    with in_path.open("r", encoding="utf-8") as fin, \
            out_path.open("w", encoding="utf-8") as fout:
        for lineno, line in enumerate(fin, 1):
            line = line.strip()
            if not line:
                continue
            where = f"line {lineno} in {in_path.name}"

            # parse the outer JSON
            try:
                outer = json.loads(line)
            except json.JSONDecodeError:
                print(f"[WARN] {where}: outer JSON invalid, skipping", file=sys.stderr)
                continue
            # A valid JSON line may still be a list, string or number, none
            # of which has .get().
            if not isinstance(outer, dict):
                print(f"[WARN] {where}: outer JSON is not an object, skipping", file=sys.stderr)
                continue

            query = outer.get("query")
            data_field = outer.get("data")
            if not query or not data_field:
                print(f"[WARN] {where}: missing query or data, skipping", file=sys.stderr)
                continue

            inner = None
            inner_text = None
            if isinstance(data_field, dict):
                # data is already a parsed object; nothing to strip or parse
                inner = data_field
            elif isinstance(data_field, str):
                inner_text = strip_code_fence(data_field)
                # attempt normal JSON parse
                try:
                    inner = json.loads(inner_text)
                except json.JSONDecodeError:
                    inner = None
            else:
                print(f"[WARN] {where}: 'data' has unsupported type "
                      f"{type(data_field).__name__}, skipping", file=sys.stderr)
                continue

            record = None
            # Only a dict can carry the three keys. Without this check a
            # parsed string would be substring-tested and a number would
            # raise TypeError.
            if isinstance(inner, dict) and all(k in inner for k in required):
                record = {
                    "query": query,
                    "passage": inner["passage"],
                    "label": inner["label"],
                    "score": inner["score"],
                }
            elif inner_text is not None:
                # JSON parse gave nothing usable: try the text-based fallback
                fb = fallback_extract(inner_text)
                if fb:
                    record = {"query": query, **fb}
                    print(f"[INFO] {where}: used fallback extraction", file=sys.stderr)

            if record is None:
                print(f"[WARN] {where}: unable to extract fields, skipping", file=sys.stderr)
                continue

            fout.write(json.dumps(record, ensure_ascii=False) + "\n")
def main():
    """Script entry point: convert one .jsonl file or every .jsonl in a directory.

    Reads ``--input`` and ``--output-dir`` from the command line, creates the
    output directory if needed, and writes one converted file per input file
    under the same file name. Progress and errors go to stderr.

    Exits with status 1 when the input path does not exist, when no .jsonl
    files are found, or when an output file would be the input file itself.
    """
    args = parse_args()
    inp = Path(args.input)
    outdir = Path(args.output_dir)

    if not inp.exists():
        print(f"Input path does not exist: {inp}", file=sys.stderr)
        sys.exit(1)

    # sorted() makes the processing order deterministic; glob order is not.
    files = sorted(inp.glob("*.jsonl")) if inp.is_dir() else [inp]
    if not files:
        print(f"No .jsonl files found in {inp}", file=sys.stderr)
        sys.exit(1)

    outdir.mkdir(parents=True, exist_ok=True)
    jobs = [(src, outdir / src.name) for src in files]

    # Writing a file onto itself would truncate the input before it is read,
    # so check every pair before converting anything.
    clobbered = [src for src, dest in jobs if dest.exists() and dest.samefile(src)]
    if clobbered:
        for src in clobbered:
            print(f"[ERROR] output would overwrite input file {src}; "
                  "choose a different --output-dir", file=sys.stderr)
        sys.exit(1)

    for src, dest in jobs:
        print(f"Converting {src} -> {dest}", file=sys.stderr)
        process_file(src, dest)


if __name__ == "__main__":
    main()