Skip to content

Commit

Permalink
Refactor complete and acomplete methods to reduce code duplication
Browse files Browse the repository at this point in the history
  • Loading branch information
VP (aider) committed Jul 30, 2024
1 parent 7b05de5 commit b174470
Showing 1 changed file with 20 additions and 27 deletions.
47 changes: 20 additions & 27 deletions llms/llms.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import threading
import queue
import concurrent.futures
import asyncio
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass
from logging import getLogger
Expand Down Expand Up @@ -85,39 +86,31 @@ def count_tokens(self, content: Union[str, List[Dict[str, Any]]]) -> Union[int,
results = [provider.count_tokens(content) for provider in self._providers]
return results if self.n_provider > 1 else results[0]

def complete(self, prompt: str, **kwargs: Any) -> Union[Result, Results]:
def _generate(provider):
result = provider.complete(prompt, **kwargs)
return result

if self.n_provider > 1:
results = []
with ThreadPoolExecutor() as executor:
futures = {
executor.submit(_generate, provider): provider
for provider in self._providers
}
for future in as_completed(futures):
results.append(future.result())
def _process_completion(self, prompt: str, is_async: bool, **kwargs: Any) -> Union[Result, Results]:
async def _async_generate(provider):
return await provider.acomplete(prompt, **kwargs)

return Results(results)
else:
return self._providers[0].complete(prompt, **kwargs)
def _sync_generate(provider):
return provider.complete(prompt, **kwargs)

async def acomplete(
self,
prompt: str,
**kwargs: Any,
) -> Union[Result, Results]:
if self.n_provider > 1:
results = []
for provider in self._providers:
result = await provider.acomplete(prompt, **kwargs)
results.append(result)
if is_async:
async def gather_results():
return await asyncio.gather(*[_async_generate(provider) for provider in self._providers])
results = asyncio.run(gather_results())
else:
with ThreadPoolExecutor() as executor:
results = list(executor.map(_sync_generate, self._providers))
return Results(results)
else:
provider = self._providers[0]
return await provider.acomplete(prompt, **kwargs)
return _async_generate(provider) if is_async else _sync_generate(provider)

def complete(self, prompt: str, **kwargs: Any) -> Union[Result, Results]:
return self._process_completion(prompt, is_async=False, **kwargs)

async def acomplete(self, prompt: str, **kwargs: Any) -> Union[Result, Results]:
return await self._process_completion(prompt, is_async=True, **kwargs)

def complete_stream(self, prompt: str, **kwargs: Any) -> StreamResult:
if self.n_provider > 1:
Expand Down

0 comments on commit b174470

Please sign in to comment.