mindef-overdracht/llm-throughput-tests-mindef-metadateren/create_test_dataset.py
2026-06-02 11:46:20 +02:00

339 lines
11 KiB
Python

#!/usr/bin/env python3
"""
Create bucketed test dataset for LLM benchmarking.
Uses multiple strategies to fill all token buckets:
1. Natural conversations from UltraChat dataset
2. Concatenation of shorter conversations for larger buckets
Buckets aligned with benchmark input_tokens: 100, 500, 1k, 2k, 5k, 10k
Outputs 128 unique conversations per bucket for comprehensive testing.
Usage:
python create_test_dataset.py
python create_test_dataset.py --output test_conversations.json
python create_test_dataset.py --buckets 1000 5000 10000 --chains_per_bucket 64
"""
import argparse
import json
import random
from collections import defaultdict
from pathlib import Path
import tiktoken
from datasets import load_dataset
# Default buckets aligned with typical benchmark configurations.
# Used as the default for the --buckets command-line option.
DEFAULT_BUCKETS = [100, 500, 1_000, 2_000, 5_000, 10_000]
# Default number of conversations kept per bucket (--chains_per_bucket).
CHAINS_PER_BUCKET = 128
# Default HuggingFace dataset name passed to load_dataset (--dataset).
DATASET_NAME = "HuggingFaceH4/ultrachat_200k"
# Name of the tiktoken encoding used for all token counting in this script.
ENCODING_NAME = "cl100k_base"
def count_tokens(messages: list[dict], encoding: tiktoken.Encoding) -> int:
    """Return the estimated token total for a chat-style message list.

    Each message contributes the encoded length of its ``content`` and its
    ``role`` (missing or ``None`` values count as empty strings) plus a fixed
    4-token formatting overhead. A further 2 tokens are added once for the
    conversation as a whole, so an empty list yields 2.
    """

    def _encoded_len(text: str) -> int:
        # disallowed_special=() makes special-token text encode as plain text
        # instead of raising.
        return len(encoding.encode(text, disallowed_special=()))

    message_tokens = sum(
        _encoded_len(entry.get("content", "") or "")
        + _encoded_len(entry.get("role", "") or "")
        + 4  # Message formatting overhead
        for entry in messages
    )
    return message_tokens + 2  # Conversation formatting overhead
def get_bucket(
    token_count: int,
    buckets: list[int],
    tolerance: float = 0.2,
) -> int | None:
    """Find the bucket whose tolerance window contains ``token_count``.

    Args:
        token_count: Token count of the conversation to classify.
        buckets: Candidate bucket targets, checked in the order given.
        tolerance: Allowed relative deviation from a bucket target. The
            default of 0.2 accepts counts within 20% of the target, matching
            the default tolerance of ``concatenate_conversations``.

    Returns:
        The first bucket ``b`` in ``buckets`` for which
        ``b * (1 - tolerance) <= token_count <= b * (1 + tolerance)``,
        or ``None`` when no bucket matches. If windows overlap, the earlier
        bucket in ``buckets`` wins.
    """
    lower_factor = 1 - tolerance
    upper_factor = 1 + tolerance
    return next(
        (
            bucket
            for bucket in buckets
            if bucket * lower_factor <= token_count <= bucket * upper_factor
        ),
        None,
    )
def format_ultrachat_messages(messages: list[dict]) -> list[dict]:
    """Normalise UltraChat messages into OpenAI-style chat messages.

    Messages with missing, ``None`` or empty ``content`` are dropped. A role
    other than ``user``, ``assistant`` or ``system`` (including a missing
    role) is replaced with ``user``. Each kept message is reduced to exactly
    the ``role`` and ``content`` keys.
    """
    known_roles = ("user", "assistant", "system")
    cleaned = []
    for entry in messages:
        text = entry.get("content", "") or ""
        if not text:
            continue
        speaker = entry.get("role", "user")
        cleaned.append(
            {
                "role": speaker if speaker in known_roles else "user",
                "content": text,
            }
        )
    return cleaned
def concatenate_conversations(
    conversations: list[list[dict]],
    target_tokens: int,
    encoding: tiktoken.Encoding,
    tolerance: float = 0.2
) -> list[dict] | None:
    """Concatenate randomly chosen conversations to reach a target token count.

    Conversations are visited in a random order (drawn from the module-level
    ``random`` state) and greedily appended whenever they still fit below the
    upper bound. Consecutive conversations are joined by a ``user`` separator
    message. The caller's ``conversations`` list is not modified.

    Args:
        conversations: Pool of conversations, each a list of message dicts.
        target_tokens: Desired token count of the combined conversation.
        encoding: Tokenizer forwarded to ``count_tokens``.
        tolerance: Allowed relative deviation from ``target_tokens``.

    Returns:
        A message list whose ``count_tokens`` value lies within
        ``target_tokens * (1 - tolerance)`` and
        ``target_tokens * (1 + tolerance)``, or ``None`` when the pool cannot
        produce one.
    """
    target_min = target_tokens * (1 - tolerance)
    target_max = target_tokens * (1 + tolerance)

    separator = {"role": "user", "content": "---\nNew conversation:\n---"}
    # count_tokens adds a fixed per-conversation overhead on top of the
    # per-message cost. Measuring the empty conversation yields that overhead,
    # so it is counted once for the combined result rather than once per
    # joined conversation.
    overhead = count_tokens([], encoding)
    separator_tokens = count_tokens([separator], encoding) - overhead

    # Shuffle a copy: same random order as shuffling in place, but the
    # caller's list keeps its original order.
    pool = list(conversations)
    random.shuffle(pool)

    result: list[dict] = []
    current_tokens = overhead
    for conv in pool:
        if not conv:
            continue  # Nothing to add; avoid a dangling separator.
        added_tokens = count_tokens(conv, encoding) - overhead
        if result:
            # The separator is part of the cost of appending this conversation.
            added_tokens += separator_tokens
        # Skip if this would exceed target
        if current_tokens + added_tokens > target_max:
            continue
        if result:
            result.append(dict(separator))
        result.extend(conv)
        current_tokens += added_tokens
        # Check if we've reached target
        if current_tokens >= target_min:
            break

    # The loop never exceeds target_max, so only the lower bound can fail.
    if not result or current_tokens < target_min:
        return None
    return result
def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the dataset generator."""
    parser = argparse.ArgumentParser(
        description="Create bucketed test dataset for LLM benchmarking",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
# Default configuration (128 conversations per bucket)
python create_test_dataset.py
# Custom buckets
python create_test_dataset.py --buckets 1000 5000 10000
# Fewer conversations per bucket
python create_test_dataset.py --chains_per_bucket 64
# Custom output location
python create_test_dataset.py --output data/conversations.json
"""
    )
    parser.add_argument(
        "--output",
        type=str,
        default="test_conversations.json",
        help="Output file path (default: test_conversations.json)"
    )
    parser.add_argument(
        "--buckets",
        type=int,
        nargs='+',
        default=DEFAULT_BUCKETS,
        help="Token count buckets (default: 100 500 1000 2000 5000 10000)"
    )
    parser.add_argument(
        "--chains_per_bucket",
        type=int,
        default=CHAINS_PER_BUCKET,
        help=f"Number of conversations per bucket (default: {CHAINS_PER_BUCKET})"
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="Random seed for reproducibility (default: 42)"
    )
    parser.add_argument(
        "--dataset",
        type=str,
        default=DATASET_NAME,
        help=f"HuggingFace dataset name (default: {DATASET_NAME})"
    )
    return parser


def _collect_chains(
    dataset,
    encoding: tiktoken.Encoding,
    buckets: list[int],
) -> tuple[dict[int, list[dict]], list[list[dict]]]:
    """Format every dataset row and sort the natural conversations into buckets.

    Returns the per-bucket chain records and the list of all formatted
    conversations (bucketed or not), which is the pool for synthetic chains.
    """
    bucketed_chains: dict[int, list[dict]] = defaultdict(list)
    all_conversations: list[list[dict]] = []
    for idx, row in enumerate(dataset):
        messages = row.get("messages", [])
        if not messages:
            continue
        formatted = format_ultrachat_messages(messages)
        if not formatted:
            continue
        token_count = count_tokens(formatted, encoding)
        bucket = get_bucket(token_count, buckets)
        all_conversations.append(formatted)
        if bucket is not None:
            bucketed_chains[bucket].append(
                {
                    "messages": formatted,
                    "token_count": token_count,
                    "bucket": bucket,
                    "original_index": idx,
                    "synthetic": False,
                }
            )
        if (idx + 1) % 50000 == 0:
            print(f" Processed {idx + 1:,} chains...")
    return bucketed_chains, all_conversations


def _fill_sparse_buckets(
    bucketed_chains: dict[int, list[dict]],
    all_conversations: list[list[dict]],
    buckets: list[int],
    chains_per_bucket: int,
    encoding: tiktoken.Encoding,
) -> None:
    """Top up under-filled buckets in place with concatenated synthetic chains."""
    for bucket in buckets:
        needed = chains_per_bucket - len(bucketed_chains[bucket])
        if needed <= 0:
            continue
        print(f" Creating {needed} synthetic chains for {bucket:,} token bucket...")
        created = 0
        # Bounded retries: each attempt may fail to reach the bucket range.
        for _ in range(needed * 20):
            if created >= needed:
                break
            # A shallow copy keeps the pool order stable across attempts even
            # if the callee shuffles its argument in place; the inner message
            # lists are never mutated, so they need no copying.
            synthetic = concatenate_conversations(
                list(all_conversations),
                bucket,
                encoding
            )
            if not synthetic:
                continue
            bucketed_chains[bucket].append(
                {
                    "messages": synthetic,
                    "token_count": count_tokens(synthetic, encoding),
                    "bucket": bucket,
                    "original_index": -1,
                    "synthetic": True,
                }
            )
            created += 1
        if created < needed:
            print(f" Only created {created}/{needed} synthetic chains")


def _select_final_dataset(
    bucketed_chains: dict[int, list[dict]],
    buckets: list[int],
    chains_per_bucket: int,
) -> tuple[dict[str, list[dict]], int, int]:
    """Pick the chains written out per bucket and print the distribution.

    Returns the output mapping (bucket as string -> chains) together with the
    total number of natural and synthetic chains selected.
    """
    final_dataset: dict[str, list[dict]] = {}
    total_natural = 0
    total_synthetic = 0
    for bucket in buckets:
        chains = bucketed_chains[bucket]
        count = len(chains)
        if count >= chains_per_bucket:
            selected = random.sample(chains, chains_per_bucket)
        else:
            print(f" {bucket:>6,} tokens: {count:>5,} chains insufficient (target: {chains_per_bucket})")
            selected = chains  # Use what we have
        natural = sum(1 for c in selected if not c.get("synthetic", False))
        synthetic = len(selected) - natural
        total_natural += natural
        total_synthetic += synthetic
        print(f" {bucket:>6,} tokens: {len(selected):>3} chains ({natural} natural, {synthetic} synthetic)")
        final_dataset[str(bucket)] = selected
    return final_dataset, total_natural, total_synthetic


def main() -> int:
    """Generate the bucketed benchmark dataset and write it as JSON.

    Parses the command line, loads the dataset and tokenizer, buckets the
    natural conversations, fills sparse buckets with synthetic chains, then
    saves the selection and prints a summary.

    Returns:
        0 on success, 1 when the dataset or the tokenizer cannot be loaded,
        so the process exit status reflects the failure.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()
    if args.chains_per_bucket < 1:
        # random.sample would raise on a negative size further down.
        parser.error("--chains_per_bucket must be a positive integer")
    random.seed(args.seed)
    # Deduplicate so a repeated --buckets value is not processed (and its
    # chains counted in the totals) more than once.
    buckets = sorted(set(args.buckets))
    print("="*60)
    print("LLM Benchmark Dataset Generator")
    print("="*60)
    print(f"Output: {args.output}")
    print(f"Buckets: {buckets}")
    print(f"Conversations per bucket: {args.chains_per_bucket}")
    print(f"Random seed: {args.seed}")
    print("="*60)
    print(f"\nLoading dataset: {args.dataset}")
    try:
        dataset = load_dataset(args.dataset, split="train_sft")
    except Exception as e:
        print(f"Error loading dataset: {e}")
        print("Make sure you have internet connection and the 'datasets' package installed:")
        print(" pip install datasets")
        return 1
    print(f"Initializing tokenizer: {ENCODING_NAME}")
    try:
        encoding = tiktoken.get_encoding(ENCODING_NAME)
    except Exception as e:
        print(f"Error loading tokenizer: {e}")
        print("Make sure you have 'tiktoken' installed:")
        print(" pip install tiktoken")
        return 1

    print(f"\nProcessing {len(dataset)} conversation chains...")
    bucketed_chains, all_conversations = _collect_chains(dataset, encoding, buckets)
    print(f"\nTotal conversations collected: {len(all_conversations):,}")
    print("\nNatural bucket distribution:")
    print("-" * 60)
    for bucket in buckets:
        count = len(bucketed_chains[bucket])
        status = "!" if count >= args.chains_per_bucket else f" need {args.chains_per_bucket - count} more"
        print(f" {bucket:>6,} tokens: {count:>5,} chains {status}")

    # Generate synthetic conversations for sparse buckets
    print("\nGenerating synthetic chains for sparse buckets...")
    _fill_sparse_buckets(
        bucketed_chains,
        all_conversations,
        buckets,
        args.chains_per_bucket,
        encoding,
    )

    print("\nFinal bucket distribution:")
    print("-" * 60)
    final_dataset, total_natural, total_synthetic = _select_final_dataset(
        bucketed_chains,
        buckets,
        args.chains_per_bucket,
    )

    # Save dataset
    output_path = Path(args.output)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(final_dataset, f, indent=2, ensure_ascii=False)
    print("-" * 60)
    print(f"\n Dataset saved to: {output_path}")
    total_chains = sum(len(chains) for chains in final_dataset.values())
    print(f"\nTotal chains: {total_chains:,}")
    print(f"Natural conversations: {total_natural:,}")
    print(f"Synthetic conversations: {total_synthetic:,}")
    print("\nBucket summary:")
    for bucket in buckets:
        chains = final_dataset.get(str(bucket), [])
        if chains:
            token_counts = [c["token_count"] for c in chains]
            avg_tokens = sum(token_counts) / len(token_counts)
            min_tokens = min(token_counts)
            max_tokens = max(token_counts)
            print(f" {bucket:>6,} tokens: {len(chains):>3} chains, "
                  f"avg={avg_tokens:>6,.0f}, min={min_tokens:>6,}, max={max_tokens:>6,}")
    print("\n" + "="*60)
    print("To use this dataset with benchmark:")
    print("="*60)
    print(f" python benchmark_llm.py --dataset {args.output} ...")
    print("="*60)
    return 0


if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    raise SystemExit(main())