Skip to content

Commit

Permalink
Queue: Isolate Models in Processes (#199)
Browse files Browse the repository at this point in the history
* Spawn process for model

* Remove sleeps

* Fix comments and exceptions

* Only set start method once

* Raise same exception

* Propagate exception
  • Loading branch information
haroldrubio authored Dec 23, 2024
1 parent 21c453e commit 7dc429f
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 8 deletions.
37 changes: 33 additions & 4 deletions expertise/service/expertise.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from openreview import OpenReviewException
from enum import Enum
from threading import Lock
import multiprocessing
from bullmq import Queue, Worker
from expertise.execute_expertise import execute_create_dataset, execute_expertise
import asyncio
Expand Down Expand Up @@ -69,6 +70,9 @@ def __init__(self, config, logger):
self.optional_fields = ['model', 'model_params', 'exclusion_inv', 'token', 'baseurl', 'baseurl_v2', 'paper_invitation', 'paper_id']
self.path_fields = ['work_dir', 'scores_path', 'publications_path', 'submissions_path']

if multiprocessing.get_start_method(allow_none=True) != 'spawn':
multiprocessing.set_start_method('spawn', force=True)

def set_client(self, client):
    """Attach the OpenReview client used for subsequent API calls."""
    self.client = client

Expand Down Expand Up @@ -204,6 +208,18 @@ def update_status(self, config, new_status, desc=None):
config.mdate = int(time.time() * 1000)
self.redis.save_job(config)

@staticmethod
def expertise_worker(config_json, queue):
try:
config = json.loads(config_json)
execute_expertise(config=config)
except Exception as e:
queue.put(e)
finally:
# Cleanup resources
torch.cuda.empty_cache()
gc.collect()

async def worker_process(self, job, token):
job_id = job.data['job_id']
user_id = job.data['user_id']
Expand All @@ -218,19 +234,32 @@ async def worker_process(self, job, token):
baseurl=config.baseurl_v2
)
try:
# Create dataset
execute_create_dataset(openreview_client, openreview_client_v2, config=config.to_json())
self.update_status(config, JobStatus.RUN_EXPERTISE)
execute_expertise(config=config.to_json())

queue = multiprocessing.Queue() # Queue for exception handling
config_json = json.dumps(config.to_json()) # Serialize config
process = multiprocessing.Process(target=ExpertiseService.expertise_worker, args=(config_json, queue))
process.start()
process.join()

if not queue.empty():
exception = queue.get()
raise exception # Re-raise the exception from the subprocess

# Update job status
self.update_status(config, JobStatus.COMPLETED)

# Explicitly cleanup resources
torch.cuda.empty_cache()
gc.collect()
except Exception as e:
self.update_status(config, JobStatus.ERROR, str(e))
# Re raise exception so that it appears in the queue
exception = e.with_traceback(e.__traceback__)
raise exception
finally:
# Cleanup resources
torch.cuda.empty_cache()
gc.collect()

def _get_job_name(self, request):
job_name_parts = [request.get('name', 'No name provided')]
Expand Down
5 changes: 1 addition & 4 deletions tests/test_expertise_apiv2.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,6 @@ def test_venueid_v2(self, openreview_client, openreview_context, celery_session_
)
assert response.status_code == 200, f'{response.json}'
job_id = response.json['jobId']
time.sleep(2)
response = test_client.get('/expertise/status', headers=openreview_client.headers, query_string={'jobId': f'{job_id}'}).json
assert response['name'] == 'test_run'
assert response['status'] != 'Error'
Expand Down Expand Up @@ -406,7 +405,6 @@ def test_submission_content_v2(self, openreview_client, openreview_context, cele
)
assert response.status_code == 200, f'{response.json}'
job_id = response.json['jobId']
time.sleep(2)
response = test_client.get('/expertise/status', headers=openreview_client.headers, query_string={'jobId': f'{job_id}'}).json
assert response['name'] == 'test_run'
assert response['status'] != 'Error'
Expand Down Expand Up @@ -508,7 +506,6 @@ def test_specter2_scincl(self, openreview_client, openreview_context, celery_ses
)
assert response.status_code == 200, f'{response.json}'
job_id = response.json['jobId']
time.sleep(2)
response = test_client.get('/expertise/status', headers=openreview_client.headers, query_string={'jobId': f'{job_id}'}).json
assert response['name'] == 'test_run'
assert response['status'] != 'Error'
Expand Down Expand Up @@ -583,4 +580,4 @@ def test_specter2_scincl(self, openreview_client, openreview_context, celery_ses
assert len(submission_id) >= 1
assert len(profile_id) >= 1
assert profile_id.startswith('~')
assert score >= 0 and score <= 1
assert score >= 0 and score <= 1

0 comments on commit 7dc429f

Please sign in to comment.