Skip to content

LeRobot

dataphy.visualization.lerobot

LeRobot dataset visualization using rerun.io.

Functions

to_hwc_uint8_numpy(chw_float32_torch: torch.Tensor) -> np.ndarray

Convert torch tensor to numpy array for rerun visualization.

Source code in src/dataphy/visualization/lerobot.py
def to_hwc_uint8_numpy(chw_float32_torch: torch.Tensor) -> np.ndarray:
    """Convert a channel-first float32 torch image into an HWC uint8 numpy array.

    The input is expected to be a (C, H, W) float32 tensor; values are scaled
    by 255 and truncated to uint8 (assumes values in [0, 1] — out-of-range
    values would wrap on the uint8 cast).
    """
    assert chw_float32_torch.dtype == torch.float32
    assert chw_float32_torch.ndim == 3
    channels, height, width = chw_float32_torch.shape
    # Sanity check for channel-first layout: channel count is smaller than either spatial dim.
    assert channels < height and channels < width, f"expect channel first images, but instead {chw_float32_torch.shape}"
    scaled_uint8 = (chw_float32_torch * 255).type(torch.uint8)
    # Reorder (C, H, W) -> (H, W, C) as rerun expects, then hand off to numpy.
    return scaled_uint8.permute(1, 2, 0).numpy()

parse_timestep_range(timestep_range: Optional[str]) -> Optional[Tuple[int, int]]

Parse timestep range string into tuple.

Source code in src/dataphy/visualization/lerobot.py
def parse_timestep_range(timestep_range: Optional[str]) -> Optional[Tuple[int, int]]:
    """Parse a timestep range string into a (start, end) tuple.

    Args:
        timestep_range: Range expressed as "start,end" (e.g. "0,100"), or
            None / empty string to mean "all timesteps".

    Returns:
        (start, end) as ints, or None when no range was given.

    Raises:
        ValueError: if the string is not exactly two comma-separated integers.
    """
    if not timestep_range:
        return None

    try:
        start, end = map(int, timestep_range.split(","))
    except ValueError:
        # Suppress the low-level int()/unpacking chained traceback: the message
        # below already states exactly what shape the input must have.
        raise ValueError("timestep-range must be in format 'start,end'") from None
    return (start, end)

visualize_lerobot_dataset(dataset_path: Optional[str] = None, repo_id: Optional[str] = None, episode_id: Optional[str] = None, timestep_range: Optional[str] = None, camera: Optional[str] = None)

Visualize a LeRobot dataset using rerun.io.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `dataset_path` | `Optional[str]` | Path to the LeRobot dataset directory (if None, uses repo_id) | `None` |
| `repo_id` | `Optional[str]` | Hugging Face repository ID (if None, uses dataset_path) | `None` |
| `episode_id` | `Optional[str]` | Specific episode ID to visualize (if None, shows first episode) | `None` |
| `timestep_range` | `Optional[str]` | Timestep range as 'start,end' (if None, shows all timesteps) | `None` |
| `camera` | `Optional[str]` | Camera name to visualize (if None, shows all cameras) | `None` |
Source code in src/dataphy/visualization/lerobot.py
def visualize_lerobot_dataset(
    dataset_path: Optional[str] = None,
    repo_id: Optional[str] = None,
    episode_id: Optional[str] = None,
    timestep_range: Optional[str] = None,
    camera: Optional[str] = None,
):
    """Visualize a LeRobot dataset using rerun.io.

    Exactly one of ``dataset_path`` or ``repo_id`` must be provided. Local
    datasets are delegated to a dataphy loader and a direct rerun pipeline
    (``_create_direct_rerun_visualization``); hub datasets are loaded via
    ``LeRobotDataset`` and streamed frame-by-frame into a spawned rerun viewer.

    Args:
        dataset_path: Path to the LeRobot dataset directory (if None, uses repo_id)
        repo_id: Hugging Face repository ID (if None, uses dataset_path)
        episode_id: Specific episode ID to visualize (if None, shows first episode)
        timestep_range: Timestep range as 'start,end' (if None, shows all timesteps)
        camera: Camera name to visualize (if None, shows all cameras)

    Raises:
        ValueError: if neither or both of dataset_path/repo_id are given, or if
            the episode, timestep range, or camera selection is invalid.
        FileNotFoundError: if dataset_path does not exist.
        RuntimeError: if a local dataset cannot be loaded by any method.
    """
    # Validate inputs: the two dataset sources are mutually exclusive.
    if dataset_path is None and repo_id is None:
        raise ValueError("Either dataset_path or repo_id must be provided")

    if dataset_path is not None and repo_id is not None:
        raise ValueError("Only one of dataset_path or repo_id should be provided")

    # Parse timestep range up front so a malformed string fails before any loading.
    timestep_range_tuple = parse_timestep_range(timestep_range)

    # Load the LeRobot dataset.
    # NOTE(review): 'console' is never assigned earlier in this function, so the
    # locals() guard is always true here — it just lazily builds a rich Console.
    if 'console' not in locals():
        from rich.console import Console
        console = Console()
    console.print(f"[green]Loading LeRobot dataset...[/green]")
    try:
        if repo_id is not None:
            # Try standard LeRobot loading for repo_id first, then with tolerance if needed
            try:
                dataset = LeRobotDataset(repo_id)
            except Exception as repo_error:
                # Some hub datasets fail timestamp-sync validation; retry with a
                # progressively larger tolerance_s (20, 40, 60, 80, 100).
                if "timestamps unexpectedly violate the tolerance" in str(repo_error):
                    console.print(f"[yellow]Dataset has sync issues, trying with increasing tolerance...[/yellow]")
                    last_error = repo_error
                    for tolerance in range(20, 101, 20):
                        try:
                            console.print(f"[yellow]Trying with tolerance_s={tolerance}...[/yellow]")
                            dataset = LeRobotDataset(repo_id, tolerance_s=float(tolerance))
                            break
                        except Exception as e:
                            last_error = e
                            # Last attempt: give up and surface the most recent failure.
                            if tolerance == 100:
                                raise last_error
                            # A non-tolerance error won't be fixed by retrying; re-raise now.
                            if "timestamps unexpectedly violate the tolerance" not in str(e):
                                raise e
                    else:
                        # for/else fires only when the loop never breaks.
                        # NOTE(review): unreachable in practice — the tolerance == 100
                        # branch above already raises on the final iteration.
                        raise last_error
                else:
                    raise repo_error
        else:
            # Load local dataset using root parameter
            from pathlib import Path
            dataset_path_obj = Path(dataset_path).resolve()

            if not dataset_path_obj.exists():
                raise FileNotFoundError(f"Dataset path does not exist: {dataset_path}")

            # Heuristic sanity check only — a warning, not a hard failure, since
            # layouts vary. A LeRobot dataset normally contains one of these.
            has_data_dir = (dataset_path_obj / "data").exists()
            has_episode_dir = (dataset_path_obj / "episodes").exists()
            has_dataset_info = (dataset_path_obj / "dataset_info.json").exists()

            if not (has_data_dir or has_episode_dir or has_dataset_info):
                console.print(f"[yellow]Warning:[/yellow] Directory {dataset_path} doesn't look like a LeRobot dataset.")
                console.print("[yellow]Expected to find 'data/', 'episodes/', or 'dataset_info.json'[/yellow]")

            dataset_name = dataset_path_obj.name
            console.print(f"[green]Loading local dataset:[/green] {dataset_name} from {dataset_path}")

            try:
                # Direct visualization approach - bypass LeRobot's data loading issues
                from dataphy.dataset.registry import create_dataset_loader, DatasetFormat

                dataphy_loader = create_dataset_loader(dataset_path, DatasetFormat.LEROBOT)

                # Get dataset info
                dataset_info = dataphy_loader.load_info()
                console.print(f"[green]Dataset Info:[/green] {dataset_info.num_episodes} episodes, {dataset_info.total_timesteps} timesteps")

                # Create direct rerun visualization.
                # The local-dataset path returns here; everything after the outer
                # try/except only runs for repo_id (hub) datasets.
                return _create_direct_rerun_visualization(
                    dataphy_loader,
                    dataset_path_obj,
                    episode_id,
                    timestep_range_tuple,
                    camera,
                    console
                )

            except Exception as fallback_error:
                console.print(f"[red]Direct visualization also failed:[/red] {fallback_error}")
                console.print("[red]All methods failed to load the local dataset.[/red]")
                console.print("[yellow]Tip:[/yellow] Make sure your dataset directory contains:")
                console.print("  • data/ directory with parquet files")
                console.print("  • meta/ directory with info.json")
                console.print("  • videos/ directory (optional)")
                raise RuntimeError("Failed to load local dataset with any method")
    except Exception as e:
        # Top-level boundary: log the failure for the user, then propagate.
        console.print(f"[red]Error loading dataset:[/red] {e}")
        raise

    # Get episode information.
    # episode_data_index maps episode index -> half-open [from, to) frame range.
    episode_indices = list(range(len(dataset.episode_data_index["from"])))
    console.print(f"[green]Available episodes:[/green] {len(episode_indices)}")

    # Select episode
    if episode_id is None:
        episode_index = 0
        console.print(f"[green]No episode specified, using first episode:[/green] {episode_index}")
    else:
        # Try to parse episode_id as index
        try:
            episode_index = int(episode_id)
        except ValueError:
            # Try to find episode by name — accepts "episode_<n>" or the bare "<n>".
            episode_index = None
            for i, ep_idx in enumerate(episode_indices):
                if f"episode_{ep_idx}" == episode_id or str(ep_idx) == episode_id:
                    episode_index = ep_idx
                    break

            if episode_index is None:
                raise ValueError(f"Episode {episode_id} not found. Available episodes: {episode_indices}")

    # Bounds check covers both the default and the user-supplied index.
    if episode_index >= len(episode_indices):
        raise ValueError(f"Episode index {episode_index} out of range. Available episodes: {len(episode_indices)}")

    # Get episode frame range (global frame indices into the flat dataset).
    from_idx = dataset.episode_data_index["from"][episode_index].item()
    to_idx = dataset.episode_data_index["to"][episode_index].item()
    total_frames = to_idx - from_idx

    console.print(f"[green]Episode {episode_index}:[/green] frames {from_idx} to {to_idx} ({total_frames} total)")

    # Apply timestep range if specified; start/end are episode-relative offsets.
    if timestep_range_tuple:
        start, end = timestep_range_tuple
        if start < 0 or end > total_frames or start >= end:
            raise ValueError(f"Invalid timestep range {start}-{end}. Episode has {total_frames} frames.")
        frame_start = from_idx + start
        frame_end = from_idx + end
        console.print(f"[green]Visualizing frames:[/green] {frame_start}-{frame_end} ({end - start} total)")
    else:
        frame_start = from_idx
        frame_end = to_idx
        console.print(f"[green]Visualizing all frames:[/green] {frame_start}-{frame_end}")

    # Initialize rerun (spawn=True launches the viewer process).
    rr.init(f"dataphy-lerobot-viewer", spawn=True)

    # Manually call python garbage collector after `rr.init` to avoid hanging
    gc.collect()

    # Get camera keys
    camera_keys = dataset.meta.camera_keys
    console.print(f"[green]Available cameras:[/green] {camera_keys}")

    # Filter cameras if specified
    if camera:
        if camera not in camera_keys:
            raise ValueError(f"Camera {camera} not found. Available cameras: {camera_keys}")
        camera_keys = [camera]

    # Visualize each frame
    console.print(f"[green]Visualizing frames...[/green]")
    for frame_idx in range(frame_start, frame_end):
        # Get frame data
        frame_data = dataset[frame_idx]

        # Set timing: both a discrete frame index and a wall-clock timestamp timeline.
        rr.set_time_sequence("frame_index", frame_data["frame_index"].item())
        rr.set_time_seconds("timestamp", frame_data["timestamp"].item())

        # Display each camera image; a bad frame is a warning, not a fatal error.
        for key in camera_keys:
            if key in frame_data:
                try:
                    image = to_hwc_uint8_numpy(frame_data[key])
                    rr.log(key, rr.Image(image))
                except Exception as e:
                    console.print(f"[yellow]Warning:[/yellow] Could not visualize camera {key}: {e}")

        # Display each dimension of action space as its own scalar timeline.
        if "action" in frame_data:
            for dim_idx, val in enumerate(frame_data["action"]):
                rr.log(f"action/{dim_idx}", rr.Scalars([val.item()]))

        # Display each dimension of observed state space
        if "observation.state" in frame_data:
            for dim_idx, val in enumerate(frame_data["observation.state"]):
                rr.log(f"state/{dim_idx}", rr.Scalars([val.item()]))

        # Display additional metadata (present only in some datasets).
        if "next.done" in frame_data:
            rr.log("next.done", rr.Scalars([frame_data["next.done"].item()]))

        if "next.reward" in frame_data:
            rr.log("next.reward", rr.Scalars([frame_data["next.reward"].item()]))

        if "next.success" in frame_data:
            rr.log("next.success", rr.Scalars([frame_data["next.success"].item()]))

        # Progress indicator: update every 10 frames and always on the last frame.
        if (frame_idx - frame_start + 1) % 10 == 0 or frame_idx == frame_end - 1:
            progress = (frame_idx - frame_start + 1) / (frame_end - frame_start)
            print(f"\rCreating visualization: {progress:.1%} ({frame_idx - frame_start + 1}/{frame_end - frame_start})", end='', flush=True)
            if frame_idx == frame_end - 1:
                print()

    # Keep the viewer open until the user dismisses it, then detach from rerun.
    try:
        input("Press Enter to close the visualization...")
    except KeyboardInterrupt:
        console.print("\n[yellow]Visualization interrupted[/yellow]")
    finally:
        rr.disconnect()