diff --git a/vgslify/generator.py b/vgslify/generator.py index 1ad5315..f3092ac 100644 --- a/vgslify/generator.py +++ b/vgslify/generator.py @@ -28,13 +28,15 @@ def __init__(self, backend: str = "auto") -> None: backend : str, optional The backend to use for building the model. Can be "tensorflow", "torch", or "auto". Default is "auto", which will attempt to automatically detect the available backend. + model_name : str, optional + The name of the model, by default "VGSL_Model" """ self.backend = self._detect_backend(backend) self.layer_factory_class, self.layer_constructors = self._initialize_backend_and_factory( self.backend) self.layer_factory = self.layer_factory_class() - def generate_model(self, model_spec: str) -> Any: + def generate_model(self, model_spec: str, model_name: str = "VGSL_Model") -> Any: """ Build the model based on the VGSL spec string. @@ -51,7 +53,7 @@ def generate_model(self, model_spec: str) -> Any: Any The built model using the specified backend. """ - return self._process_layers(model_spec, return_history=False) + return self._process_layers(model_spec, return_history=False, model_name=model_name) def generate_history(self, model_spec: str) -> List[Any]: """ @@ -73,7 +75,7 @@ def generate_history(self, model_spec: str) -> List[Any]: """ return self._process_layers(model_spec, return_history=True) - def _process_layers(self, model_spec: str, return_history: bool = False) -> Any: + def _process_layers(self, model_spec: str, return_history: bool = False, model_name: str = "VGSL_Model") -> Any: """ Process the VGSL specification string to build the model or generate a history of layers. @@ -83,6 +85,8 @@ def _process_layers(self, model_spec: str, return_history: bool = False) -> Any: The VGSL specification string defining the model architecture. return_history : bool, optional If True, returns a list of constructed layers (history) instead of the final model. + model_name : str, optional + The name of the model, by default "VGSL_Model" Returns ------- @@ -112,7 +116,7 @@ def _process_layers(self, model_spec: str, return_history: bool = False) -> Any: return history # Build and return the final model - return self.layer_factory.build() + return self.layer_factory.build(name=model_name) def construct_layer(self, spec: str) -> Any: """