Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

render rollout merge #87

Merged
merged 26 commits into from
Oct 23, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
18f3b7c
render rollout merge
Naveen-Raj-M Sep 9, 2024
e22406a
Update config.yaml
Naveen-Raj-M Sep 24, 2024
23bfe0c
Merge branch 'v2' of https://github.com/geoelements/gns into v2
Naveen-Raj-M Oct 1, 2024
8d5d7b4
Added unit test for render-rollout merge
Naveen-Raj-M Oct 1, 2024
2f9ab01
Merge branch 'v2' of https://github.com/Naveen-Raj-M/gns into v2
Naveen-Raj-M Oct 1, 2024
5146e57
deleted debugging fixtures
Naveen-Raj-M Oct 2, 2024
a115cb0
update config
Naveen-Raj-M Oct 4, 2024
2bcf615
add test for VTK rendering
Naveen-Raj-M Oct 4, 2024
cdf6f45
bug fix for NoneType material property
Naveen-Raj-M Oct 4, 2024
f9a2a9e
remove test_rendering and temp directory
Naveen-Raj-M Oct 4, 2024
57fad31
add test for vtk rendering
Naveen-Raj-M Oct 5, 2024
b387599
modify config for render_rollout merge
Naveen-Raj-M Oct 5, 2024
db3aec3
update to merge render-rollout
Naveen-Raj-M Oct 5, 2024
74b1e43
set default mode to gif
Naveen-Raj-M Oct 5, 2024
e08fcf5
rewrite 'rendering' function in an extensible way
Naveen-Raj-M Oct 10, 2024
3b88f65
improve readability and consistency
Naveen-Raj-M Oct 10, 2024
9fbafbb
update rendering options
Naveen-Raj-M Oct 10, 2024
81bc3f0
run black
Oct 11, 2024
f82c674
minor fix on viewpoint_rotation type
Oct 12, 2024
a861ec4
improve logging and reformat with black
Oct 12, 2024
7110352
refactor: move n_files function to a separate count_n_files.py in uti…
Oct 12, 2024
a86e58f
rename count_n_files.py to file_utils.py
Naveen-Raj-M Oct 13, 2024
b72770a
minor fix on module import
Naveen-Raj-M Oct 14, 2024
94c8519
run black
Naveen-Raj-M Oct 20, 2024
037b020
minor fix on raising error
Naveen-Raj-M Oct 22, 2024
ef41a87
add package for reading vtk files
Naveen-Raj-M Oct 22, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion gns/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ class LoggingConfig:
class GifConfig:
step_stride: int = 3
vertical_camera_angle: int = 20
viewpoint_rotation: int = 0.3
viewpoint_rotation: float = 0.3
change_yz: bool = False

@dataclass
Expand Down
172 changes: 89 additions & 83 deletions test/test_vtk.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,13 @@
from gns.train import rendering
from omegaconf import DictConfig
from typing import Tuple
from utils.count_n_files import n_files

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


@pytest.fixture
def cfg_vtk() -> DictConfig:
"""
Expand All @@ -24,11 +26,8 @@ def cfg_vtk() -> DictConfig:
DictConfig: Configuration dictionary for VTK rendering mode.
"""
logger.info("Setting up VTK configuration.")
return DictConfig({
'rendering': {
'mode': 'vtk'
}
})
return DictConfig({"rendering": {"mode": "vtk"}})


@pytest.fixture
def dummy_pkl_data() -> dict:
Expand All @@ -38,7 +37,6 @@ def dummy_pkl_data() -> dict:
Returns:
Dict: A dictionary containing dummy rollout data.
"""
logger.info("Generating dummy pickle data.")
n_timesteps = 2
n_particles = 3
dim = 2
Expand All @@ -48,90 +46,83 @@ def dummy_pkl_data() -> dict:
# Generate random predictions and ground truth positions
predictions = np.random.rand(n_timesteps, n_particles, dim)
ground_truth_positions = np.random.randn(n_timesteps, n_particles, dim)
loss = (predictions - ground_truth_positions)**2
loss = (predictions - ground_truth_positions) ** 2

# Rollout dictionary to store all relevant information
dummy_rollout = {
"initial_positions": np.random.rand(n_init_pos, n_particles, dim),
"predicted_rollout": predictions,
"ground_truth_rollout": ground_truth_positions,
"particle_types": np.full(n_particles, 5),
"metadata": {
"bounds": [[0.0, 1.0], [0.0, 1.0]]
},
"loss": loss.mean()
"metadata": {"bounds": [[0.0, 1.0], [0.0, 1.0]]},
# MSE loss between predictions and ground truth positions
"loss": loss.mean(),
}
logger.info("Dummy pickle data generated successfully.")
logger.info(
f"Generated dummy data: {n_particles} particles over {n_timesteps} time steps in {dim} dimensions, "
f"with {n_init_pos} initial positions"
)

except Exception as e:
logger.error(f"Failed to generate dummy pickle data: {e}")
raise

return dummy_rollout


@pytest.fixture
def temp_dir_with_file(dummy_pkl_data: dict) -> Tuple[str, str]:
"""
kks32 marked this conversation as resolved.
Show resolved Hide resolved
Fixture for generating a temporary directory with dummy pickle data for testing.
Fixture to create a temporary directory and a pickle file containing the
provided dummy data for testing purposes.

Returns:
Tuple[str, str]: Path to the temporary directory and the pickle file name.
Args:
dummy_pkl_data (dict): A dictionary containing the dummy data to be
serialized and stored in a temporary pickle file.

Yields:
Tuple[str, str]: A tuple containing:
- The path to the temporary directory where the pickle file is stored.
- The base name of the pickle file (without the '.pkl' extension).
"""
temp_dir = tempfile.mkdtemp()
logger.info(f"Created temporary directory: {temp_dir}")

try:
with tempfile.NamedTemporaryFile(dir=temp_dir, suffix='.pkl', delete=False) as temp_file:
with tempfile.NamedTemporaryFile(
dir=temp_dir, suffix=".pkl", delete=False
) as temp_file:
pkl_file_path = temp_file.name

with open(pkl_file_path, 'wb') as f:
with open(pkl_file_path, "wb") as f:
pickle.dump(dummy_pkl_data, f)

# get the base file name without any extension
file_name = os.path.splitext(os.path.basename(pkl_file_path))[0]
logger.info(f"created temporary '.pkl' file: {file_name}")
temp_dir += '/'
temp_dir += "/"

yield temp_dir, file_name

except Exception as e:
logger.error(f"Failed to create a Temporary file for the input rollout data: {e}")
logger.error(
f"Failed to create a Temporary file for the input rollout data: {e}"
)

finally:
shutil.rmtree(temp_dir)



def n_files(directory: str, extension: str) -> int:
"""
Count the number of files with a specific extension in a directory.

Args:
dir (str): Directory path.
extension (str): File extension to count.

Returns:
int: Number of files with the specified extension.
"""
try:
pattern = os.path.join(directory, f'*.{extension}')
file_count = len(glob.glob(pattern))
logger.info(f"Counted {file_count} files with extension '{extension}' in {directory}.")
return file_count
except Exception as e:
logger.error(f"Error counting files in {directory} with extension '{extension}': {e}")
return 0

def verify_vtk_files(rollout_data: dict, label: str, temp_dir: str) -> None:
"""
Verify the integrity of VTK files against expected data.

This function checks VTK files (VTU and VTR) for specific properties, ensuring that the
displacement, particle types, color maps, and bounds match the expected values in the
This function checks VTK files (VTU and VTR) for specific properties, ensuring that the
displacement, particle types, color maps, and bounds match the expected values in the
provided rollout data.

Args:
rollout_data (dict): A dictionary containing
rollout_data (dict): A dictionary containing
- 'inital positions'
- 'predicted_rollout'
- 'ground_truth_rollout'
Expand All @@ -141,11 +132,10 @@ def verify_vtk_files(rollout_data: dict, label: str, temp_dir: str) -> None:
temp_dir (str): The temporary directory where the VTK files are stored.

Raises:
AssertionError: If any of the checks on displacement, particle types, color maps,
AssertionError: If any of the checks on displacement, particle types, color maps,
or bounds fail.
Exception: If there is an error reading or processing the VTK files.
"""
logger.info(f"Verifying VTK files for label: {label}.")
VTU_PREFIX = "points"
VTR_PREFIX = "boundary"
VTU_EXTENSION = "vtu"
Expand All @@ -158,70 +148,84 @@ def verify_vtk_files(rollout_data: dict, label: str, temp_dir: str) -> None:

for time_step in range(positions.shape[0]):
try:
vtu_path = os.path.join(temp_dir, f"{VTU_PREFIX}{time_step}.{VTU_EXTENSION}")
vtu_path = os.path.join(
temp_dir, f"{VTU_PREFIX}{time_step}.{VTU_EXTENSION}"
)
logger.info(f"Reading VTU file: {vtu_path}")
vtu_object = pv.read(vtu_path)

displacement = vtu_object['displacement']
particle_type = vtu_object['particle_type']
color_map = vtu_object['color']
displacement = vtu_object["displacement"]
particle_type = vtu_object["particle_type"]
color_map = vtu_object["color"]

assert np.all(displacement == np.linalg.norm(positions[0] - positions[time_step], axis=1)), (
assert np.all(
displacement
== np.linalg.norm(positions[0] - positions[time_step], axis=1)
), (
f"Displacement mismatch for timestep {time_step}: "
f"expected {np.linalg.norm(positions[0] - positions[time_step], axis=1)}, "
f"got {displacement}"
)
assert np.all(particle_type == rollout_data['particle_types']), (
assert np.all(particle_type == rollout_data["particle_types"]), (
f"Particle type mismatch for timestep {time_step}: "
f"expected {rollout_data['particle_types']}, got {particle_type}"
)
assert np.all(color_map == rollout_data['particle_types']), (
assert np.all(color_map == rollout_data["particle_types"]), (
f"Color map mismatch for timestep {time_step}: "
f"expected {rollout_data['particle_types']}, got {color_map}"
)

except AssertionError as e:
logger.error(f"Assertion failed while verifying {VTU_PREFIX}{time_step}.{VTU_EXTENSION}: {e}")
logger.error(
f"Assertion failed while verifying {VTU_PREFIX}{time_step}.{VTU_EXTENSION}: {e}"
)
raise
except Exception as e:
logger.error(f"Error reading {VTU_PREFIX}{time_step}.{VTU_EXTENSION}: {e}")
raise

try:
vtr_path = os.path.join(temp_dir, f"{VTR_PREFIX}{time_step}.{VTR_EXTENSION}")
vtr_path = os.path.join(
temp_dir, f"{VTR_PREFIX}{time_step}.{VTR_EXTENSION}"
)
logger.info(f"Reading VTR file: {vtr_path}")
vtr_object = pv.read(vtr_path)

bounds = vtr_object.bounds
xmin, xmax, ymin, ymax, zmin, zmax = bounds

assert xmin == rollout_data['metadata']['bounds'][0][0], (
assert xmin == rollout_data["metadata"]["bounds"][0][0], (
f"Xmin mismatch for timestep {time_step}: "
f"expected {rollout_data['metadata']['bounds'][0][0]}, got {xmin}"
)
assert xmax == rollout_data['metadata']['bounds'][0][1], (
assert xmax == rollout_data["metadata"]["bounds"][0][1], (
f"Xmax mismatch for timestep {time_step}: "
f"expected {rollout_data['metadata']['bounds'][0][1]}, got {xmax}"
)
assert ymin == rollout_data['metadata']['bounds'][1][0], (
assert ymin == rollout_data["metadata"]["bounds"][1][0], (
f"Ymin mismatch for timestep {time_step}: "
f"expected {rollout_data['metadata']['bounds'][1][0]}, got {ymin}"
)
assert ymax == rollout_data['metadata']['bounds'][1][1], (
assert ymax == rollout_data["metadata"]["bounds"][1][1], (
f"Ymax mismatch for timestep {time_step}: "
f"expected {rollout_data['metadata']['bounds'][1][1]}, got {ymax}"
)

except AssertionError as e:
logger.error(f"Assertion failed while verifying {VTR_PREFIX}{time_step}.{VTR_EXTENSION}: {e}")
logger.error(
f"Assertion failed while verifying {VTR_PREFIX}{time_step}.{VTR_EXTENSION}: {e}"
)
raise
except Exception as e:
logger.error(f"Error reading {VTR_PREFIX}{time_step}.{VTR_EXTENSION}: {e}")
raise

logger.info("VTK file verification complete.")

def test_rendering_vtk( temp_dir_with_file: Tuple[str, str], cfg_vtk: DictConfig) -> None:

def test_rendering_vtk(
temp_dir_with_file: Tuple[str, str], cfg_vtk: DictConfig
) -> None:
"""
Test the VTK rendering function.

Expand All @@ -238,7 +242,7 @@ def test_rendering_vtk( temp_dir_with_file: Tuple[str, str], cfg_vtk: DictConfig
except Exception as e:
logger.error("Failed to render with dummy rollout data")
raise

try:
# Define paths for the generated VTK files
VTK_GNS_SUFFIX = "_vtk-GNS"
Expand All @@ -252,27 +256,30 @@ def test_rendering_vtk( temp_dir_with_file: Tuple[str, str], cfg_vtk: DictConfig
rollout = pickle.load(file)

# Count the number of .vtu and .vtr files in the VTK directories
n_vtu_files_gns = n_files(vtk_path_gns, 'vtu')
n_vtu_files_reality = n_files(vtk_path_reality, 'vtu')
n_vtr_files_gns = n_files(vtk_path_gns, 'vtr')
n_vtr_files_reality = n_files(vtk_path_reality, 'vtr')

expected_n_files = rollout["initial_positions"].shape[0] + rollout["predicted_rollout"].shape[0]
n_vtu_files_gns = n_files(vtk_path_gns, "vtu")
n_vtu_files_reality = n_files(vtk_path_reality, "vtu")
n_vtr_files_gns = n_files(vtk_path_gns, "vtr")
n_vtr_files_reality = n_files(vtk_path_reality, "vtr")

expected_n_files = (
rollout["initial_positions"].shape[0]
+ rollout["predicted_rollout"].shape[0]
)
logger.info(f"Expected number of files: {expected_n_files}.")

# Assert that the number of .vtu and .vtr files matches the expected count
assert n_vtu_files_gns == expected_n_files, (
f"Expected {expected_n_files} VTU files in GNS path, got {n_vtu_files_gns}"
)
assert n_vtu_files_reality == expected_n_files, (
f"Expected {expected_n_files} VTU files in Reality path, got {n_vtu_files_reality}"
)
assert n_vtr_files_gns == expected_n_files, (
f"Expected {expected_n_files} VTR files in GNS path, got {n_vtr_files_gns}"
)
assert n_vtr_files_reality == expected_n_files, (
f"Expected {expected_n_files} VTR files in Reality path, got {n_vtr_files_reality}"
)
assert (
n_vtu_files_gns == expected_n_files
), f"Expected {expected_n_files} VTU files in GNS path, got {n_vtu_files_gns}"
assert (
n_vtu_files_reality == expected_n_files
), f"Expected {expected_n_files} VTU files in Reality path, got {n_vtu_files_reality}"
assert (
n_vtr_files_gns == expected_n_files
), f"Expected {expected_n_files} VTR files in GNS path, got {n_vtr_files_gns}"
assert (
n_vtr_files_reality == expected_n_files
), f"Expected {expected_n_files} VTR files in Reality path, got {n_vtr_files_reality}"

logger.info("Verifying VTK files for predicted rollout.")
verify_vtk_files(rollout, "predicted_rollout", vtk_path_gns)
Expand All @@ -283,4 +290,3 @@ def test_rendering_vtk( temp_dir_with_file: Tuple[str, str], cfg_vtk: DictConfig

except Exception as e:
logger.error(f"Rendering test failed: {e}")

29 changes: 29 additions & 0 deletions utils/count_n_files.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import os
Naveen-Raj-M marked this conversation as resolved.
Show resolved Hide resolved
import glob
import logging

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def n_files(directory: str, extension: str) -> int:
"""
Count the number of files with a specific extension in a directory.

Args:
directory (str): Directory path.
extension (str): File extension to count.

Returns:
int: Number of files with the specified extension.
"""
try:
pattern = os.path.join(directory, f"*.{extension}")
file_count = len(glob.glob(pattern))
return file_count
except Exception as e:
logger.error(
f"Error counting files in {directory} with extension '{extension}': {e}"
)
return 0