@@ -52,21 +52,21 @@ def get_bundled_script_path(shell: str) -> Path:
5252 shell_dir = Path (__file__ ).parent / "shell"
5353 if shell == "zsh" :
5454 return shell_dir / "shelloracle.zsh"
55- elif shell == "bash" :
55+ else :
5656 return shell_dir / "shelloracle.bash"
5757
5858
5959def get_script_path (shell : str ) -> Path :
6060 if shell == "zsh" :
6161 return Path .home () / ".shelloracle.zsh"
62- elif shell == "bash" :
62+ else :
6363 return Path .home () / ".shelloracle.bash"
6464
6565
6666def get_rc_path (shell : str ) -> Path :
6767 if shell == "zsh" :
6868 return Path .home () / ".zshrc"
69- elif shell == "bash" :
69+ else :
7070 return Path .home () / ".bashrc"
7171
7272
@@ -91,7 +91,7 @@ def update_rc(shell: str) -> None:
9191 print_info (f"Successfully updated { replace_home_with_tilde (rc_path )} " )
9292
9393
94- def get_settings (provider : Provider ) -> Iterator [tuple [str , Setting ]]:
94+ def get_settings (provider : type [ Provider ] ) -> Iterator [tuple [str , Setting ]]:
9595 settings = inspect .getmembers (provider , predicate = lambda p : isinstance (p , Setting ))
9696
9797 def correct_name_setting ():
@@ -103,7 +103,7 @@ def correct_name_setting():
103103 yield from correct_name_setting ()
104104
105105
106- def write_shelloracle_config (provider : Provider , settings : dict [str , Any ]) -> None :
106+ def write_shelloracle_config (provider : type [ Provider ] , settings : dict [str , Any ]) -> None :
107107 config = tomlkit .document ()
108108
109109 shor_table = tomlkit .table ()
@@ -134,11 +134,14 @@ def install_keybindings() -> None:
134134 update_rc (shell )
135135
136136
137- def user_configure_settings (provider : Provider ) -> dict [str , Any ]:
137+ def user_configure_settings (provider : type [ Provider ] ) -> dict [str , Any ]:
138138 settings = {}
139139 for name , setting in get_settings (provider ):
140140 user_input = prompt (f"{ name } : " , default = str (setting .default ))
141- type_ = type (setting .default ) if setting .default else str
141+ if setting .default :
142+ type_ = type (setting .default )
143+ else :
144+ type_ = str
142145 value = type_ (user_input )
143146 settings [name ] = value
144147 return settings
@@ -151,7 +154,7 @@ def case_correct_user_input(user_input: str, options: Sequence[str]) -> str | No
151154 return None
152155
153156
154- def user_select_provider () -> Provider :
157+ def user_select_provider () -> type [ Provider ] :
155158 providers = list_providers ()
156159 completer = WordCompleter (providers , ignore_case = True )
157160 user_selected_provider = prompt (f"Choose your LLM provider ({ ', ' .join (providers )} ): " , completer = completer )
0 commit comments