Skip to content

Commit

Permalink
Add model name when generating
Browse files Browse the repository at this point in the history
  • Loading branch information
TimKoornstra committed Oct 12, 2024
1 parent 7004cb5 commit f87fe93
Showing 1 changed file with 8 additions and 4 deletions.
12 changes: 8 additions & 4 deletions vgslify/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,15 @@ def __init__(self, backend: str = "auto") -> None:
backend : str, optional
The backend to use for building the model. Can be "tensorflow", "torch", or "auto".
Default is "auto", which will attempt to automatically detect the available backend.
model_name : str, optional
The name of the model, by default "VGSL_Model"
"""
self.backend = self._detect_backend(backend)
self.layer_factory_class, self.layer_constructors = self._initialize_backend_and_factory(
self.backend)
self.layer_factory = self.layer_factory_class()

def generate_model(self, model_spec: str) -> Any:
def generate_model(self, model_spec: str, model_name: str = "VGSL_Model") -> Any:
"""
Build the model based on the VGSL spec string.
Expand All @@ -51,7 +53,7 @@ def generate_model(self, model_spec: str) -> Any:
Any
The built model using the specified backend.
"""
return self._process_layers(model_spec, return_history=False)
return self._process_layers(model_spec, return_history=False, model_name=model_name)

def generate_history(self, model_spec: str) -> List[Any]:
"""
Expand All @@ -73,7 +75,7 @@ def generate_history(self, model_spec: str) -> List[Any]:
"""
return self._process_layers(model_spec, return_history=True)

def _process_layers(self, model_spec: str, return_history: bool = False) -> Any:
def _process_layers(self, model_spec: str, return_history: bool = False, model_name: str = "VGSL_Model") -> Any:
"""
Process the VGSL specification string to build the model or generate a history of layers.
Expand All @@ -83,6 +85,8 @@ def _process_layers(self, model_spec: str, return_history: bool = False) -> Any:
The VGSL specification string defining the model architecture.
return_history : bool, optional
If True, returns a list of constructed layers (history) instead of the final model.
model_name : str, optional
The name of the model, by default "VGSL_Model"
Returns
-------
Expand Down Expand Up @@ -112,7 +116,7 @@ def _process_layers(self, model_spec: str, return_history: bool = False) -> Any:
return history

# Build and return the final model
return self.layer_factory.build()
return self.layer_factory.build(name=model_name)

def construct_layer(self, spec: str) -> Any:
"""
Expand Down

0 comments on commit f87fe93

Please sign in to comment.