add dataset seed parameter to config #41

Open

github-actions bot opened this issue Aug 4, 2024 · 0 comments

github-actions bot commented Aug 4, 2024

# TODO: add dataset seed parameter to config

import numpy as np
from pathlib import Path
from safetensors import safe_open
from safetensors.torch import save_file
import torch
from torch.utils.data import Dataset

import bevy_zeroverse


# TODO: add sample-level world rotation augment
class View:
    def __init__(self, color, depth, normal, view_from_world, fovy, width, height):
        self.color = color
        self.depth = depth
        self.normal = normal
        self.view_from_world = view_from_world
        self.fovy = fovy
        self.width = width
        self.height = height

    @classmethod
    def from_rust(cls, rust_view, width, height):
        width = int(width)
        height = int(height)

        def reshape_data(data, dtype):
            # raw buffers arrive as flat bytes; reinterpret as packed RGBA pixels
            return np.frombuffer(data, dtype=dtype).reshape(height, width, 4)

        if len(rust_view.color) == 0:
            print("empty color buffer")

        if len(rust_view.depth) == 0:
            print("empty depth buffer")

        if len(rust_view.normal) == 0:
            print("empty normal buffer")

        color = reshape_data(rust_view.color, np.uint8)
        depth = reshape_data(rust_view.depth, np.uint8)
        normal = reshape_data(rust_view.normal, np.uint8)

        view_from_world = np.array(rust_view.view_from_world)
        fovy = rust_view.fovy
        return cls(color, depth, normal, view_from_world, fovy, width, height)

    def to_tensors(self):
        color_tensor = torch.tensor(self.color, dtype=torch.uint8)
        depth_tensor = torch.tensor(self.depth, dtype=torch.uint8)
        normal_tensor = torch.tensor(self.normal, dtype=torch.uint8)

        # force the alpha channel to fully opaque
        color_tensor[..., 3] = 255
        depth_tensor[..., 3] = 255
        normal_tensor[..., 3] = 255

        view_from_world_tensor = torch.tensor(self.view_from_world, dtype=torch.float32)
        fovy_tensor = torch.tensor(self.fovy, dtype=torch.float32)
        return {
            'color': color_tensor,
            'depth': depth_tensor,
            'normal': normal_tensor,
            'view_from_world': view_from_world_tensor,
            'fovy': fovy_tensor
        }

class Sample:
    def __init__(self, views):
        self.views = views

    @classmethod
    def from_rust(cls, rust_sample, width, height):
        views = [View.from_rust(view, width, height) for view in rust_sample.views]
        return cls(views)

    def to_tensors(self):
        tensor_dict = {
            'color': [],
            'depth': [],
            'normal': [],
            'view_from_world': [],
            'fovy': []
        }

        if len(self.views) == 0:
            print("empty views")
            return tensor_dict

        for view in self.views:
            tensors = view.to_tensors()
            for key in tensor_dict:
                tensor_dict[key].append(tensors[key])

        for key in tensor_dict:
            tensor_dict[key] = torch.stack(tensor_dict[key], dim=0)

        return tensor_dict
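

# A minimal sketch of the sample-level world rotation augment noted in the TODO on the
# View class above. It assumes `view_from_world` is a 4x4 world-to-view matrix; composing
# every view in a sample with the same random world rotation (here, about the +Y axis)
# rotates the scene while preserving relative camera geometry. The function name and the
# choice of up axis are illustrative, not part of the existing API.
def augment_sample_world_rotation(tensor_dict):
    angle = np.random.uniform(0.0, 2.0 * np.pi)
    c, s = np.cos(angle), np.sin(angle)
    # homogeneous 4x4 rotation about the world +Y axis
    rotation = torch.tensor([
        [  c, 0.0,   s, 0.0],
        [0.0, 1.0, 0.0, 0.0],
        [ -s, 0.0,   c, 0.0],
        [0.0, 0.0, 0.0, 1.0],
    ], dtype=torch.float32)

    # view_from_world has shape (num_views, 4, 4); apply the same rotation to each view
    augmented = dict(tensor_dict)
    augmented['view_from_world'] = tensor_dict['view_from_world'] @ rotation
    return augmented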


# TODO: add dataset seed parameter to config
class BevyZeroverseDataset(Dataset):
    def __init__(self, editor, headless, num_cameras, width, height, num_samples):
        self.editor = editor
        self.headless = headless
        self.num_cameras = num_cameras
        self.width = width
        self.height = height
        self.num_samples = int(num_samples)
        self.initialized = False

    def initialize(self):
        config = bevy_zeroverse.BevyZeroverseConfig()
        config.editor = self.editor
        config.headless = self.headless
        config.num_cameras = self.num_cameras
        config.width = self.width
        config.height = self.height
        config.scene_type = bevy_zeroverse.ZeroverseSceneType.Room
        bevy_zeroverse.initialize(config)
        self.initialized = True

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        if not self.initialized:
            self.initialize()

        rust_sample = bevy_zeroverse.next()
        sample = Sample.from_rust(rust_sample, self.width, self.height)
        return sample.to_tensors()


    def chunk_and_save(self, output_dir: Path, bytes_per_chunk: int):
        chunk_size = 0
        chunk_index = 0
        chunk = []
        original_samples = []

        def save_chunk():
            nonlocal chunk_size, chunk_index, chunk, original_samples
            chunk_key = f"{chunk_index:0>6}"
            print(f"saving chunk {chunk_key} of {self.num_samples} ({chunk_size / 1e6:.2f} MB).")
            output_dir.mkdir(exist_ok=True, parents=True)
            file_path = output_dir / f"{chunk_key}.safetensors"

            # Flatten the tensors dictionary
            flat_tensors = {f"{i}_{key}": tensor for i, sample in enumerate(chunk) for key, tensor in sample.items()}
            save_file(flat_tensors, str(file_path))

            chunk_size = 0
            chunk_index += 1
            chunk = []

        for idx in range(self.num_samples):
            sample = self[idx]
            sample_size = sum(tensor.numel() * tensor.element_size() for tensor in sample.values())
            chunk.append(sample)
            original_samples.append(sample)
            chunk_size += sample_size

            print(f"    added sample {idx} to chunk ({sample_size / 1e6:.2f} MB).")
            if chunk_size >= bytes_per_chunk:
                save_chunk()

        if chunk_size > 0:
            save_chunk()

        return original_samples


class ChunkedDataset(Dataset):
    def __init__(self, output_dir: Path):
        self.output_dir = output_dir
        self.chunk_files = sorted(output_dir.glob("*.safetensors"))
        self.chunks = [self.load_chunk(chunk_file) for chunk_file in self.chunk_files]

    def load_chunk(self, file_path: Path):
        with safe_open(str(file_path), framework="pt", device="cpu") as f:
            return {key: f.get_tensor(key) for key in f.keys()}

    def __len__(self):
        return len(self.chunk_files)

    def __getitem__(self, idx):
        chunk = self.chunks[idx]

        # Restructure the keys to remove sample indices
        batch = {}
        # sort by the numeric sample prefix so samples keep their original order
        # even when a chunk holds ten or more samples
        for key, tensor in sorted(chunk.items(), key=lambda kv: int(kv[0].split('_', 1)[0])):
            _, base_key = key.split('_', 1)
            if base_key not in batch:
                batch[base_key] = []

            batch[base_key].append(tensor)

        # Stack the tensors for each key
        batch = {key: torch.stack(tensors, dim=0) for key, tensors in batch.items()}
        return batch
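

# A minimal sketch of the dataset seed parameter this TODO asks for. It assumes the
# Rust-side bevy_zeroverse.BevyZeroverseConfig exposes a `seed` field, which is
# hypothetical; the actual bindings may name or thread the seed differently. The
# subclass below only illustrates accepting a seed in the constructor, forwarding it
# to the config, and seeding the Python-side RNGs so sampling is reproducible.
class SeededBevyZeroverseDataset(BevyZeroverseDataset):
    def __init__(self, editor, headless, num_cameras, width, height, num_samples, seed=None):
        super().__init__(editor, headless, num_cameras, width, height, num_samples)
        self.seed = seed

    def initialize(self):
        config = bevy_zeroverse.BevyZeroverseConfig()
        config.editor = self.editor
        config.headless = self.headless
        config.num_cameras = self.num_cameras
        config.width = self.width
        config.height = self.height
        config.scene_type = bevy_zeroverse.ZeroverseSceneType.Room

        if self.seed is not None:
            config.seed = self.seed        # hypothetical field on the Rust config
            torch.manual_seed(self.seed)   # keep any Python-side sampling reproducible
            np.random.seed(self.seed)

        bevy_zeroverse.initialize(config)
        self.initialized = True


# Example usage (values are illustrative):
#     dataset = SeededBevyZeroverseDataset(
#         editor=False, headless=True, num_cameras=4,
#         width=640, height=360, num_samples=100, seed=42,
#     )
#     dataset.chunk_and_save(Path("./data/zeroverse"), bytes_per_chunk=256 * 1024 * 1024)
#     chunked = ChunkedDataset(Path("./data/zeroverse"))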