Skip to content

Commit

Permalink
Handle error via rich error console
Browse files Browse the repository at this point in the history
  • Loading branch information
dudeperf3ct committed Feb 19, 2025
1 parent d73742e commit fcddc75
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 6 deletions.
9 changes: 5 additions & 4 deletions ecoml/src/ecoml/ecoml.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
from rich.console import Console
from rich.table import Table

console = Console(stderr=True)
console = Console()
error_console = Console(stderr=True, style="bold red")
app = typer.Typer(no_args_is_help=True)

CONFIG = {"jetson_orin": {"pytorch": {"low": 5, "average": 7, "high": 10}}}
Expand All @@ -21,7 +22,7 @@ def validate_model(model_path: str):
model_summary = json.load(file)
return True, model_summary
except json.JSONDecodeError:
console.print("Invalid JSON file.")
error_console.print("Invalid JSON file.")
return False, None
return False, None

Expand Down Expand Up @@ -62,13 +63,13 @@ def predict(
If --verbose is used, a detailed summary of predictions is provided.
"""
cfg = CONFIG["jetson_orin"]["pytorch"]
success = validate_model(model)
success, _ = validate_model(model)
if success:
from ecoml.infer import run_inference

run_inference(model, power_profiles=cfg, verbose=verbose)
else:
console.print("Expected PyTorch model summary as a JSON file")
error_console.print("Expected PyTorch model summary as a JSON file")


if __name__ == "__main__":
Expand Down
5 changes: 3 additions & 2 deletions ecoml/src/ecoml/infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
from ecoml.data_preparation.pytorch_utils import read_layers_info
from ecoml.model_builder.model_inference import InferenceModel

console = Console(stderr=True)
console = Console()
error_console = Console(stderr=True, style="bold red")


def get_metrics(df: pd.DataFrame, cfg: dict[str, int]) -> pd.DataFrame:
Expand Down Expand Up @@ -150,7 +151,7 @@ def run_inference(
data["layer_type"].append(layer_info.layer_type)

if not len(data):
console.print(
error_console.print(
"Looks like there are no convolutional, pooling or linear layers in the model"
)
return
Expand Down

0 comments on commit fcddc75

Please sign in to comment.