339 lines
11 KiB
Python
339 lines
11 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
Create bucketed test dataset for LLM benchmarking.
|
|
|
|
Uses multiple strategies to fill all token buckets:
|
|
1. Natural conversations from UltraChat dataset
|
|
2. Concatenation of shorter conversations for larger buckets
|
|
|
|
Buckets aligned with benchmark input_tokens: 100, 500, 1k, 2k, 5k, 10k
|
|
Outputs 128 unique conversations per bucket for comprehensive testing.
|
|
|
|
Usage:
|
|
python create_test_dataset.py
|
|
python create_test_dataset.py --output test_conversations.json
|
|
python create_test_dataset.py --buckets 1000 5000 10000 --chains_per_bucket 64
|
|
"""
|
|
|
|
import argparse
|
|
import json
|
|
import random
|
|
from collections import defaultdict
|
|
from pathlib import Path
|
|
|
|
import tiktoken
|
|
from datasets import load_dataset
|
|
|
|
# Default buckets aligned with typical benchmark configurations.
# Each value is a target token count; a conversation belongs to a bucket
# when its token count falls within the tolerance window around the target.
DEFAULT_BUCKETS = [100, 500, 1_000, 2_000, 5_000, 10_000]

# Number of conversation chains written to the output for every bucket.
CHAINS_PER_BUCKET = 128

# HuggingFace dataset identifier used as the source of conversations.
DATASET_NAME = "HuggingFaceH4/ultrachat_200k"

# Name of the tiktoken encoding used for all token counting.
ENCODING_NAME = "cl100k_base"
|
|
|
|
|
|
def count_tokens(messages: list[dict], encoding: tiktoken.Encoding) -> int:
    """Return the estimated token count of a whole conversation.

    Every message contributes the encoded length of its ``content`` and
    ``role`` fields (missing or ``None`` values count as empty strings) plus
    a fixed formatting overhead of 4 tokens. A further 2 tokens are added
    once for the conversation itself, so an empty conversation counts as 2.

    Args:
        messages: Message dicts, read via their ``"content"`` and ``"role"`` keys.
        encoding: Tokenizer exposing ``encode(text, disallowed_special=())``.

    Returns:
        The total token estimate as an integer.
    """
    per_message_overhead = 4
    conversation_overhead = 2

    def encoded_length(text: str) -> int:
        # disallowed_special=() lets text that looks like a special token
        # be encoded as ordinary text instead of raising.
        return len(encoding.encode(text, disallowed_special=()))

    return conversation_overhead + sum(
        encoded_length(entry.get("content", "") or "")
        + encoded_length(entry.get("role", "") or "")
        + per_message_overhead
        for entry in messages
    )
|
|
|
|
|
|
def get_bucket(token_count: int, buckets: list[int]) -> int | None:
|
|
"""Find the appropriate bucket for a token count (within 20% of target)."""
|
|
for bucket in buckets:
|
|
if bucket * 0.8 <= token_count <= bucket * 1.2:
|
|
return bucket
|
|
return None
|
|
|
|
|
|
def format_ultrachat_messages(messages: list[dict]) -> list[dict]:
    """Normalise UltraChat messages into OpenAI-style chat messages.

    Messages whose content is missing, ``None`` or empty are dropped. Any
    role other than ``user``, ``assistant`` or ``system`` (including a
    missing role) is replaced by ``user``.

    Args:
        messages: Raw message dicts, read via their ``"role"`` and
            ``"content"`` keys.

    Returns:
        A new list of ``{"role": ..., "content": ...}`` dicts in the
        original order.
    """
    known_roles = ("user", "assistant", "system")
    cleaned = []
    for entry in messages:
        speaker = entry.get("role", "user")
        text = entry.get("content", "") or ""
        if not text:
            # Nothing to send for this turn; skip it entirely.
            continue
        cleaned.append(
            {
                "role": speaker if speaker in known_roles else "user",
                "content": text,
            }
        )
    return cleaned
|
|
|
|
|
|
def concatenate_conversations(
    conversations: list[list[dict]],
    target_tokens: int,
    encoding: tiktoken.Encoding,
    tolerance: float = 0.2,
) -> list[dict] | None:
    """Concatenate randomly ordered conversations to reach a token target.

    Conversations are visited in a random order (drawn from the module-level
    ``random`` generator) and appended greedily. A conversation that would
    push the total above ``target_tokens * (1 + tolerance)`` is skipped.
    Consecutive conversations are joined by a separator message from the
    ``user`` role. The search stops as soon as the total reaches
    ``target_tokens * (1 - tolerance)``.

    The running total is kept equal to what ``count_tokens`` reports for the
    accumulated message list, so a returned chain is always inside the
    tolerance window.

    Args:
        conversations: Pool of conversations (each a list of message dicts).
            The list is not modified.
        target_tokens: Desired token count of the combined conversation.
        encoding: Tokenizer passed through to ``count_tokens``.
        tolerance: Allowed relative deviation from ``target_tokens``.

    Returns:
        The combined message list, or ``None`` if the pool cannot produce a
        chain within the tolerance window.
    """
    target_min = target_tokens * (1 - tolerance)
    target_max = target_tokens * (1 + tolerance)

    separator = {"role": "user", "content": "---\nNew conversation:\n---"}

    # count_tokens adds a fixed per-conversation overhead. Measure it on an
    # empty conversation so the cost of the messages alone can be isolated
    # and summed without counting that overhead once per joined piece.
    base_overhead = count_tokens([], encoding)
    separator_cost = count_tokens([separator], encoding) - base_overhead

    # Shuffle a private copy so the caller's list keeps its order.
    shuffled = list(conversations)
    random.shuffle(shuffled)

    result: list[dict] = []
    current_tokens = base_overhead

    for conv in shuffled:
        if not conv:
            continue

        added = count_tokens(conv, encoding) - base_overhead
        if result:
            # Every conversation after the first is preceded by a separator.
            added += separator_cost

        # Skip if this would exceed the upper bound (separator included).
        if current_tokens + added > target_max:
            continue

        if result:
            result.append(dict(separator))
        result.extend(conv)
        current_tokens += added

        if current_tokens >= target_min:
            return result

    # The pool was exhausted without reaching the lower bound.
    return None
|
|
|
|
|
|
def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the dataset generator."""
    parser = argparse.ArgumentParser(
        description="Create bucketed test dataset for LLM benchmarking",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
# Default configuration (128 conversations per bucket)
python create_test_dataset.py

# Custom buckets
python create_test_dataset.py --buckets 1000 5000 10000

# Fewer conversations per bucket
python create_test_dataset.py --chains_per_bucket 64

# Custom output location
python create_test_dataset.py --output data/conversations.json
""",
    )
    parser.add_argument(
        "--output",
        type=str,
        default="test_conversations.json",
        help="Output file path (default: test_conversations.json)",
    )
    parser.add_argument(
        "--buckets",
        type=int,
        nargs='+',
        default=DEFAULT_BUCKETS,
        help="Token count buckets (default: 100 500 1000 2000 5000 10000)",
    )
    parser.add_argument(
        "--chains_per_bucket",
        type=int,
        default=CHAINS_PER_BUCKET,
        help=f"Number of conversations per bucket (default: {CHAINS_PER_BUCKET})",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="Random seed for reproducibility (default: 42)",
    )
    parser.add_argument(
        "--dataset",
        type=str,
        default=DATASET_NAME,
        help=f"HuggingFace dataset name (default: {DATASET_NAME})",
    )
    return parser


def _chain_record(
    messages: list[dict],
    token_count: int,
    bucket: int,
    original_index: int,
    synthetic: bool,
) -> dict:
    """Build one output entry; key order defines the JSON field order."""
    return {
        "messages": messages,
        "token_count": token_count,
        "bucket": bucket,
        "original_index": original_index,
        "synthetic": synthetic,
    }


def _collect_natural_chains(
    dataset,
    buckets: list[int],
    encoding: tiktoken.Encoding,
) -> tuple[dict[int, list[dict]], list[list[dict]]]:
    """Format every dataset row and sort it into its natural token bucket.

    Returns the per-bucket entries and the list of all usable formatted
    conversations (the pool later used for concatenation).
    """
    bucketed_chains: dict[int, list[dict]] = defaultdict(list)
    all_conversations: list[list[dict]] = []

    for idx, row in enumerate(dataset):
        messages = row.get("messages", [])
        if not messages:
            continue

        formatted = format_ultrachat_messages(messages)
        if not formatted:
            continue

        token_count = count_tokens(formatted, encoding)
        bucket = get_bucket(token_count, buckets)

        all_conversations.append(formatted)

        if bucket is not None:
            bucketed_chains[bucket].append(
                _chain_record(formatted, token_count, bucket, idx, False)
            )

        if (idx + 1) % 50000 == 0:
            print(f" Processed {idx + 1:,} chains...")

    return bucketed_chains, all_conversations


def _fill_sparse_buckets(
    bucketed_chains: dict[int, list[dict]],
    all_conversations: list[list[dict]],
    buckets: list[int],
    chains_per_bucket: int,
    encoding: tiktoken.Encoding,
) -> None:
    """Top up under-filled buckets in place with concatenated conversations."""
    for bucket in buckets:
        needed = chains_per_bucket - len(bucketed_chains[bucket])
        if needed <= 0:
            continue

        print(f" Creating {needed} synthetic chains for {bucket:,} token bucket...")
        max_attempts = needed * 20
        created = 0

        for _ in range(max_attempts):
            if created >= needed:
                break

            # A shallow copy is enough to keep the shared pool's order
            # intact even if the callee shuffles its argument in place.
            synthetic = concatenate_conversations(
                list(all_conversations),
                bucket,
                encoding,
            )
            if not synthetic:
                continue

            token_count = count_tokens(synthetic, encoding)
            # Only keep chains whose measured size really belongs to this
            # bucket; a mislabelled chain would skew the benchmark input.
            if get_bucket(token_count, [bucket]) is None:
                continue

            bucketed_chains[bucket].append(
                _chain_record(synthetic, token_count, bucket, -1, True)
            )
            created += 1

        if created < needed:
            print(f" Only created {created}/{needed} synthetic chains")


def main():
    """Generate the bucketed benchmark dataset and write it as JSON.

    Steps: parse the CLI arguments, load the source dataset and tokenizer,
    sort natural conversations into token buckets, synthesise extra chains
    for under-filled buckets, sample the requested number of chains per
    bucket, write the result to the output path and print a summary.

    Raises:
        SystemExit: With status 2 for invalid arguments (via argparse) and
            status 1 when the dataset or tokenizer cannot be loaded, so
            that calling scripts can detect the failure.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    if args.chains_per_bucket < 1:
        parser.error("--chains_per_bucket must be a positive integer")
    if any(value <= 0 for value in args.buckets):
        parser.error("--buckets values must be positive integers")

    random.seed(args.seed)
    # De-duplicate so a repeated bucket is not processed (and overwritten) twice.
    buckets = sorted(set(args.buckets))

    print("="*60)
    print("LLM Benchmark Dataset Generator")
    print("="*60)
    print(f"Output: {args.output}")
    print(f"Buckets: {buckets}")
    print(f"Conversations per bucket: {args.chains_per_bucket}")
    print(f"Random seed: {args.seed}")
    print("="*60)

    print(f"\nLoading dataset: {args.dataset}")
    try:
        dataset = load_dataset(args.dataset, split="train_sft")
    except Exception as e:
        print(f"Error loading dataset: {e}")
        print("Make sure you have internet connection and the 'datasets' package installed:")
        print(" pip install datasets")
        raise SystemExit(1) from e

    print(f"Initializing tokenizer: {ENCODING_NAME}")
    try:
        encoding = tiktoken.get_encoding(ENCODING_NAME)
    except Exception as e:
        print(f"Error loading tokenizer: {e}")
        print("Make sure you have 'tiktoken' installed:")
        print(" pip install tiktoken")
        raise SystemExit(1) from e

    print(f"\nProcessing {len(dataset)} conversation chains...")
    bucketed_chains, all_conversations = _collect_natural_chains(
        dataset, buckets, encoding
    )

    print(f"\nTotal conversations collected: {len(all_conversations):,}")
    print("\nNatural bucket distribution:")
    print("-" * 60)

    for bucket in buckets:
        count = len(bucketed_chains[bucket])
        status = "!" if count >= args.chains_per_bucket else f" need {args.chains_per_bucket - count} more"
        print(f" {bucket:>6,} tokens: {count:>5,} chains {status}")

    # Generate synthetic conversations for sparse buckets
    print("\nGenerating synthetic chains for sparse buckets...")
    _fill_sparse_buckets(
        bucketed_chains,
        all_conversations,
        buckets,
        args.chains_per_bucket,
        encoding,
    )

    print("\nFinal bucket distribution:")
    print("-" * 60)

    final_dataset = {}
    total_natural = 0
    total_synthetic = 0

    for bucket in buckets:
        chains = bucketed_chains[bucket]
        count = len(chains)

        if count >= args.chains_per_bucket:
            selected = random.sample(chains, args.chains_per_bucket)
        else:
            print(f" {bucket:>6,} tokens: {count:>5,} chains insufficient (target: {args.chains_per_bucket})")
            selected = chains  # Use what we have

        natural = sum(1 for c in selected if not c.get("synthetic", False))
        synthetic = len(selected) - natural
        total_natural += natural
        total_synthetic += synthetic

        print(f" {bucket:>6,} tokens: {len(selected):>3} chains ({natural} natural, {synthetic} synthetic)")

        final_dataset[str(bucket)] = selected

    # Save dataset
    output_path = Path(args.output)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(final_dataset, f, indent=2, ensure_ascii=False)

    print("-" * 60)
    print(f"\n Dataset saved to: {output_path}")

    total_chains = sum(len(chains) for chains in final_dataset.values())
    print(f"\nTotal chains: {total_chains:,}")
    print(f"Natural conversations: {total_natural:,}")
    print(f"Synthetic conversations: {total_synthetic:,}")

    print("\nBucket summary:")
    for bucket in buckets:
        chains = final_dataset.get(str(bucket), [])
        if chains:
            token_counts = [c["token_count"] for c in chains]
            avg_tokens = sum(token_counts) / len(token_counts)
            min_tokens = min(token_counts)
            max_tokens = max(token_counts)
            print(f" {bucket:>6,} tokens: {len(chains):>3} chains, "
                  f"avg={avg_tokens:>6,.0f}, min={min_tokens:>6,}, max={max_tokens:>6,}")

    print("\n" + "="*60)
    print("To use this dataset with benchmark:")
    print("="*60)
    print(f" python benchmark_llm.py --dataset {args.output} ...")
    print("="*60)


if __name__ == "__main__":
    main()
|