diff --git a/fastdata/core.py b/fastdata/core.py index 442f4f8..53d0080 100644 --- a/fastdata/core.py +++ b/fastdata/core.py @@ -42,6 +42,7 @@ def generate(self, temp: float = 1., sp: str = "You are a helpful assistant.", max_workers: int = 64) -> list[dict]: + "For every input in INPUTS, fill PROMPT_TEMPLATE and generate a value fitting SCHEMA" def process_input(input_data): try: @@ -57,10 +58,9 @@ def process_input(input_data): return None results = [] - with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: - futures = [executor.submit(process_input, input_data) for input_data in inputs] - for future in tqdm(concurrent.futures.as_completed(futures), total=len(inputs)): - result = future.result() - results.append(result) - - return results + with tqdm(total=len(inputs)) as pbar: + with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: + futures = [executor.submit(process_input, input_data) for input_data in inputs] + for completed_future in concurrent.futures.as_completed(futures): + pbar.update(1) + return [f.result() for f in futures]