Parallelise get_traceback_path #106

Open
szhan opened this issue Jun 20, 2023 · 5 comments
Labels
enhancement New feature or request

Comments

szhan commented Jun 20, 2023

This can be done with the standard-library `concurrent.futures` module.
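As a rough illustration of the idea, here is a minimal sketch using `ThreadPoolExecutor.map`, which returns results in input order. The function `trace_one` is a hypothetical stand-in for a per-sample call such as `get_traceback_path` (its real signature is not shown in this thread), and `parallel_traceback` is an illustrative name only.

```python
import concurrent.futures


def trace_one(path_input):
    # Hypothetical stand-in for one per-sample call (e.g. get_traceback_path);
    # here it just sums the input so the example is runnable.
    return sum(path_input)


def parallel_traceback(inputs, num_workers=2):
    # Run trace_one over all inputs on a thread pool.
    with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
        # Executor.map yields results in the same order as the inputs.
        return list(executor.map(trace_one, inputs))


results = parallel_traceback([[1, 2], [3, 4], [5, 6]])
print(results)
```

Note that `Executor.map` collects its input iterable immediately rather than lazily, so this simple form suits inputs that already fit in memory. Threads will also only give a speed-up if the underlying function releases the GIL, which is an assumption here.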

szhan commented Jun 20, 2023

h/t @benjeffery

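import concurrent.futures
import heapq


def threaded_map(func, args, num_workers):
    """Apply func to each arg on a thread pool, yielding results in input order."""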
    results_buffer = []
    with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
        futures = set()
        next_index = 0
        for i, arg in enumerate(args):
            # +1 so that we're not waiting for the args generator to produce the next arg
            while len(futures) >= num_workers + 1:
                # If there are too many in-progress tasks, wait for one to complete
                done, futures = concurrent.futures.wait(
                    futures, return_when=concurrent.futures.FIRST_COMPLETED
                )
                for future in done:
                    index, result = future.result()
                    if index == next_index:
                        # If this result is the next expected one, yield it immediately
                        yield result
                        next_index += 1
                    else:
                        heapq.heappush(results_buffer, (index, result))
                    # Yield any results from the buffer that are next in line
                    while results_buffer and results_buffer[0][0] == next_index:
                        _, result = heapq.heappop(results_buffer)
                        yield result
                        next_index += 1
            # Wraps the function so we can track the index of the argument
            futures.add(executor.submit(lambda arg, i=i: (i, func(arg)), arg))
        concurrent.futures.wait(futures)
        for future in futures:
            index, result = future.result()
            if index == next_index:
                yield result
                next_index += 1
            else:
                heapq.heappush(results_buffer, (index, result))
        # Yield any remaining results in the buffer
        while results_buffer:
            _, result = heapq.heappop(results_buffer)
            yield result

szhan commented Jun 20, 2023

The above function can be used as follows. It is a generator, so the results have to be consumed, e.g. with list().

def test(a):
    return a[0] + a[1]

c = threaded_map(test, [(1, 2), (3, 4)], num_workers=2)
print(list(c))  # [3, 7]

szhan commented Jun 20, 2023

The second argument could be a list of numpy.ndarray objects containing the sample paths.

szhan commented Jun 20, 2023

It needs only two standard-library modules.

import concurrent.futures
import heapq

szhan commented Jun 26, 2023

It is actually already pretty fast per call when the precision is not too high (< 26).
