def _load_hub_dataset_with_tolerance(repo_id, console):
    """Load a LeRobot dataset by repo ID, relaxing ``tolerance_s`` on sync errors.

    The dataset is first loaded with default settings. If that fails with a
    timestamp-tolerance error, loading is retried with ``tolerance_s`` set to
    20, 40, 60, 80 and 100 in turn.

    Args:
        repo_id: Repository ID passed straight to ``LeRobotDataset``.
        console: Object with a ``print`` method used for status messages.

    Returns:
        The loaded ``LeRobotDataset`` instance.

    Raises:
        Exception: Any non-tolerance error is re-raised immediately; if every
            tolerance value fails, the last tolerance error is raised.
    """
    tolerance_marker = "timestamps unexpectedly violate the tolerance"

    try:
        return LeRobotDataset(repo_id)
    except Exception as repo_error:
        if tolerance_marker not in str(repo_error):
            raise
        # Keep a reference: the `as` name is unbound once the handler exits.
        last_error = repo_error

    console.print("[yellow]Dataset has sync issues, trying with increasing tolerance...[/yellow]")
    for tolerance in range(20, 101, 20):
        console.print(f"[yellow]Trying with tolerance_s={tolerance}...[/yellow]")
        try:
            return LeRobotDataset(repo_id, tolerance_s=float(tolerance))
        except Exception as retry_error:
            # Only tolerance violations are worth retrying with a larger value.
            if tolerance_marker not in str(retry_error):
                raise
            last_error = retry_error

    # Every tolerance value was rejected.
    raise last_error


def _resolve_episode_index(episode_id, num_episodes):
    """Translate a user-supplied episode identifier into a valid episode index.

    Args:
        episode_id: ``None`` (first episode), an integer-like value such as
            ``"3"``, or a name of the form ``"episode_3"``.
        num_episodes: Number of episodes available in the dataset.

    Returns:
        An index ``i`` with ``0 <= i < num_episodes``.

    Raises:
        ValueError: If the identifier matches no episode or the index is out
            of range (negative indices are rejected rather than wrapping).
    """
    if episode_id is None:
        episode_index = 0
    else:
        try:
            episode_index = int(episode_id)
        except ValueError:
            # Not numeric: fall back to the "episode_<index>" naming scheme.
            episode_index = next(
                (idx for idx in range(num_episodes) if episode_id == f"episode_{idx}"),
                None,
            )
            if episode_index is None:
                raise ValueError(
                    f"Episode {episode_id} not found. Available episodes: {list(range(num_episodes))}"
                ) from None

    if not 0 <= episode_index < num_episodes:
        raise ValueError(f"Episode index {episode_index} out of range. Available episodes: {num_episodes}")
    return episode_index


def _log_frame_to_rerun(frame_data, camera_keys, console):
    """Log one dataset frame (timing, images, action/state, metadata) to rerun.

    Args:
        frame_data: Mapping returned by indexing the dataset; must contain
            ``"frame_index"`` and ``"timestamp"`` entries exposing ``.item()``.
        camera_keys: Camera keys whose images should be logged when present.
        console: Object with a ``print`` method used for warnings.
    """
    rr.set_time_sequence("frame_index", frame_data["frame_index"].item())
    rr.set_time_seconds("timestamp", frame_data["timestamp"].item())

    for key in camera_keys:
        if key not in frame_data:
            continue
        # Best effort: one broken camera stream must not abort the whole run.
        try:
            image = to_hwc_uint8_numpy(frame_data[key])
            rr.log(key, rr.Image(image))
        except Exception as e:
            console.print(f"[yellow]Warning:[/yellow] Could not visualize camera {key}: {e}")

    # One scalar series per dimension of the action / observed state vectors.
    for source_key, entity_prefix in (("action", "action"), ("observation.state", "state")):
        if source_key in frame_data:
            for dim_idx, val in enumerate(frame_data[source_key]):
                rr.log(f"{entity_prefix}/{dim_idx}", rr.Scalars([val.item()]))

    # Optional per-frame metadata, logged under its own key.
    for meta_key in ("next.done", "next.reward", "next.success"):
        if meta_key in frame_data:
            rr.log(meta_key, rr.Scalars([frame_data[meta_key].item()]))


def visualize_lerobot_dataset(
    dataset_path: Optional[str] = None,
    repo_id: Optional[str] = None,
    episode_id: Optional[str] = None,
    timestep_range: Optional[str] = None,
    camera: Optional[str] = None,
):
    """Visualize a LeRobot dataset using rerun.io.

    Exactly one of ``dataset_path`` and ``repo_id`` must be given. A local
    ``dataset_path`` is handed to ``_create_direct_rerun_visualization`` and
    its result is returned. A ``repo_id`` is loaded through ``LeRobotDataset``
    and the selected frames are logged to a spawned rerun viewer, after which
    the function waits for Enter before disconnecting.

    Args:
        dataset_path: Path to the LeRobot dataset directory (if None, uses repo_id)
        repo_id: Hugging Face repository ID (if None, uses dataset_path)
        episode_id: Specific episode ID to visualize (if None, shows first episode)
        timestep_range: Timestep range as 'start,end' (if None, shows all timesteps)
        camera: Camera name to visualize (if None, shows all cameras)

    Returns:
        The result of ``_create_direct_rerun_visualization`` for a local
        dataset; ``None`` for a ``repo_id`` dataset.

    Raises:
        ValueError: If neither or both of ``dataset_path``/``repo_id`` are
            given, or the episode, timestep range or camera is invalid.
        FileNotFoundError: If ``dataset_path`` does not exist.
        RuntimeError: If the local dataset cannot be loaded or visualized.
    """
    # Validate inputs
    if dataset_path is None and repo_id is None:
        raise ValueError("Either dataset_path or repo_id must be provided")
    if dataset_path is not None and repo_id is not None:
        raise ValueError("Only one of dataset_path or repo_id should be provided")

    # Parse timestep range
    timestep_range_tuple = parse_timestep_range(timestep_range)

    # `console` is always a fresh local here, so no existence check is needed.
    from rich.console import Console
    console = Console()

    # Load the LeRobot dataset
    console.print("[green]Loading LeRobot dataset...[/green]")
    try:
        if repo_id is not None:
            dataset = _load_hub_dataset_with_tolerance(repo_id, console)
        else:
            # Load local dataset using root parameter
            from pathlib import Path

            dataset_path_obj = Path(dataset_path).resolve()
            if not dataset_path_obj.exists():
                raise FileNotFoundError(f"Dataset path does not exist: {dataset_path}")

            # Check if this looks like a LeRobot dataset (warning only).
            # NOTE(review): the tip printed on failure mentions 'meta/' and
            # 'videos/', which this check does not look for - confirm which
            # layout is authoritative.
            expected_entries = ("data", "episodes", "dataset_info.json")
            if not any((dataset_path_obj / entry).exists() for entry in expected_entries):
                console.print(f"[yellow]Warning:[/yellow] Directory {dataset_path} doesn't look like a LeRobot dataset.")
                console.print("[yellow]Expected to find 'data/', 'episodes/', or 'dataset_info.json'[/yellow]")

            dataset_name = dataset_path_obj.name
            console.print(f"[green]Loading local dataset:[/green] {dataset_name} from {dataset_path}")
            try:
                # Direct visualization approach - bypass LeRobot's data loading issues
                from dataphy.dataset.registry import create_dataset_loader, DatasetFormat

                dataphy_loader = create_dataset_loader(dataset_path, DatasetFormat.LEROBOT)

                # Get dataset info
                dataset_info = dataphy_loader.load_info()
                console.print(f"[green]Dataset Info:[/green] {dataset_info.num_episodes} episodes, {dataset_info.total_timesteps} timesteps")

                # Create direct rerun visualization
                return _create_direct_rerun_visualization(
                    dataphy_loader,
                    dataset_path_obj,
                    episode_id,
                    timestep_range_tuple,
                    camera,
                    console,
                )
            except Exception as fallback_error:
                console.print(f"[red]Direct visualization also failed:[/red] {fallback_error}")
                console.print("[red]All methods failed to load the local dataset.[/red]")
                console.print("[yellow]Tip:[/yellow] Make sure your dataset directory contains:")
                console.print(" • data/ directory with parquet files")
                console.print(" • meta/ directory with info.json")
                console.print(" • videos/ directory (optional)")
                # Chain explicitly so the root cause stays in the traceback.
                raise RuntimeError("Failed to load local dataset with any method") from fallback_error
    except Exception as e:
        console.print(f"[red]Error loading dataset:[/red] {e}")
        raise

    # Get episode information
    num_episodes = len(dataset.episode_data_index["from"])
    console.print(f"[green]Available episodes:[/green] {num_episodes}")

    # Select episode (rejects negative and out-of-range indices)
    episode_index = _resolve_episode_index(episode_id, num_episodes)
    if episode_id is None:
        console.print(f"[green]No episode specified, using first episode:[/green] {episode_index}")

    # Get episode frame range
    from_idx = dataset.episode_data_index["from"][episode_index].item()
    to_idx = dataset.episode_data_index["to"][episode_index].item()
    total_frames = to_idx - from_idx
    console.print(f"[green]Episode {episode_index}:[/green] frames {from_idx} to {to_idx} ({total_frames} total)")

    # Apply timestep range if specified
    if timestep_range_tuple:
        start, end = timestep_range_tuple
        if start < 0 or end > total_frames or start >= end:
            raise ValueError(f"Invalid timestep range {start}-{end}. Episode has {total_frames} frames.")
        frame_start = from_idx + start
        frame_end = from_idx + end
        console.print(f"[green]Visualizing frames:[/green] {frame_start}-{frame_end} ({end - start} total)")
    else:
        frame_start = from_idx
        frame_end = to_idx
        console.print(f"[green]Visualizing all frames:[/green] {frame_start}-{frame_end}")

    # Resolve and validate cameras BEFORE spawning the viewer, so a bad
    # `camera` argument fails fast without leaving a viewer window open.
    camera_keys = dataset.meta.camera_keys
    console.print(f"[green]Available cameras:[/green] {camera_keys}")
    if camera:
        if camera not in camera_keys:
            raise ValueError(f"Camera {camera} not found. Available cameras: {camera_keys}")
        camera_keys = [camera]

    # Initialize rerun
    rr.init("dataphy-lerobot-viewer", spawn=True)
    # Manually call python garbage collector after `rr.init` to avoid hanging
    gc.collect()

    num_frames = frame_end - frame_start
    try:
        # Visualize each frame
        console.print("[green]Visualizing frames...[/green]")
        for frames_done, frame_idx in enumerate(range(frame_start, frame_end), start=1):
            _log_frame_to_rerun(dataset[frame_idx], camera_keys, console)

            # Progress indicator: every 10th frame and the final frame
            is_last = frames_done == num_frames
            if frames_done % 10 == 0 or is_last:
                progress = frames_done / num_frames
                print(f"\rCreating visualization: {progress:.1%} ({frames_done}/{num_frames})", end='', flush=True)
                if is_last:
                    print()

        # Keep the viewer open
        try:
            input("Press Enter to close the visualization...")
        except KeyboardInterrupt:
            console.print("\n[yellow]Visualization interrupted[/yellow]")
    finally:
        # Disconnect even when logging a frame fails part-way through.
        rr.disconnect()