Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: Optimizations and GIL release. #88

Open
wants to merge 3 commits into
base: master
Choose a base branch
from

Conversation

sabonerune
Copy link
Contributor

Release and optimize GIL.

If you release GIL only for mecab_dict_index, there will be no big problem.

mutex is required to release further GIL.

Otherwise problems will occur when multiple threads execute functions in __init__.py.

@sabonerune
Copy link
Contributor Author

Lock has been added.
I think that at least fatal destruction will not occur if an instance is operated from multiple threads at the same time.
However, I cannot guarantee that the result will be correct. (Especially HTS_Engine)

Copy link
Owner

@r9y9 r9y9 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for your work! Could you give some more details for the motivation and how this PR fixes it? I am not sure the changes are the right direction or not. In particular,

I would appreciate it if you give me some more information!

pyopenjtalk/openjtalk.pyx Show resolved Hide resolved
@r9y9
Copy link
Owner

r9y9 commented Dec 10, 2024

However, I cannot guarantee that the result will be correct. (Especially HTS_Engine)

I think it's OK not to care HTS_engine too much. People in these days probably don't use HTS anymore. Rare case may exist for scientific research though.

Copy link
Contributor Author

@sabonerune sabonerune left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The reason we added mutexes to OpenJTalk and HTSEngine is to provide protection if users use them directly.
So the mutex in this PR is not required.

However, the mutex in #87 is required.
The mutex #87 protects dictionary downloads, dictionary updates, and operations on the HTSEngine.

RLock in HTSEngine may not be necessary.
Because if there is a mutex inside the HTSEngine it will prevent a segmentation fault.
However, the accuracy of the results cannot be guaranteed when such operations are performed.

pyopenjtalk/openjtalk.pyx Show resolved Hide resolved
@sabonerune sabonerune requested a review from r9y9 December 11, 2024 10:33
@r9y9
Copy link
Owner

r9y9 commented Dec 11, 2024

Could you share some code example that does not work correctly without this PR but works OK with this PR? I'd love to understand how this PR works.

However, the mutex in #87 is required.
The mutex #87 protects dictionary downloads, dictionary updates, and operations on the HTSEngine.

Yes, I understand this point and that's why I already merged #87. What I'm not fully sure is whether we must (or better) need RLock/nogil in this PR.

@sabonerune
Copy link
Contributor Author

sabonerune commented Dec 11, 2024

Could you share some code example that does not work correctly without this PR but works OK with this PR?

OpenJTalk example:

from concurrent.futures import ThreadPoolExecutor

from pyopenjtalk import OPEN_JTALK_DICT_DIR
from pyopenjtalk.openjtalk import OpenJTalk

ojt = OpenJTalk(OPEN_JTALK_DICT_DIR)
text = "こんにちは"

if __name__ == "__main__":
    with ThreadPoolExecutor() as e:
        futures = [e.submit(ojt.run_frontend, text) for _ in range(32)]
        results = [i.result() for i in futures]
        first = results[0]
        for i in results[1:-1]:
            assert first == i
        print("run_frontend() done")

If OpenJTalk did not have a mutex, this code would likely cause a fatal problem such as a segmentation fault.


HTSEngine exsample.

import math
from concurrent.futures import ThreadPoolExecutor

from pyopenjtalk import DEFAULT_HTS_VOICE, extract_fullcontext
from pyopenjtalk.htsengine import HTSEngine

labels = extract_fullcontext("こんにちは")
hts = HTSEngine(DEFAULT_HTS_VOICE)


def synthesize(speed):
    hts.set_speed(speed)
    hts.add_half_tone(0.0)
    return hts.synthesize(labels)


if __name__ == "__main__":
    with ThreadPoolExecutor() as e:
        count = 32
        futures = [e.submit(synthesize, (i + 1) / count) for i in range(count)]
        results = [i.result() for i in futures]
        last_len = len(results[-1])
        try:
            for i, wave in enumerate(results):
                speed = (i + 1) / count
                assert math.isclose(last_len / len(wave), speed, rel_tol=10**-2)
        finally:
            print("synthesize() done")

This code will speed unpredictably even with a mutex.

@r9y9
Copy link
Owner

r9y9 commented Dec 13, 2024

Thank you for the examples. I understood RLock is for users who directly call OpenJTalk or HTSEngine and #87 for global instances, so they are for different purposes.

If I understand correctly, this PR is for preventing segfaults when multiple threads are trying to access OpenJTalk/HTSEngine directly but is NOT for speeding up execution by multi-threading. Is my understanding correct? Looking at C code quickly, I believe OpenJTalk and HTSEngine are not written to be thread-safe. So we have to use mutex for the entire execution from Cython/Python side, and thus there's little performance benefit by multi-threading.

If my understanding is correct (i.e purpose is to prevent segfaults), this PR looks OK to me. If your purpose includes performance optimization, I'd like to know how much speedup we can get.

@r9y9
Copy link
Owner

r9y9 commented Dec 13, 2024

As a side note, it'd be better to have a test like:

    ojt = OpenJTalk(OPEN_JTALK_DICT_DIR)

    texts =  [
        "今日もいい天気ですね",
        "こんにちは",
        "マルチスレッドプログラミング",
        "テストです",
        "Pythonはプログラミング言語です",
        "日本語テキストを音声合成します",
    ]

    # Test consistency between single and multi-threaded runs
    # make sure no corruptions happen in OJT internal
    results_s = [ojt.run_frontend(text) for text in texts]
    results_m = []
    with ThreadPoolExecutor() as e:
        futures = [e.submit(ojt.run_frontend, text) for text in texts]
        results_m = [i.result() for i in futures]
    for i, (s, m) in enumerate(zip(results_s, results_m)):
        assert len(s) == len(m)
        for s_, m_ in zip(s, m):
            # full context must exactly match
            assert s_ == m_

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants