From cb4709c82e2100250595205c8e052454b80c1ca4 Mon Sep 17 00:00:00 2001 From: jyx-su <108294040+jyx-su@users.noreply.github.com> Date: Fri, 1 Nov 2024 14:21:18 -0700 Subject: [PATCH 01/15] Simple initial implementation --- src/server/tasks/medagentbench/__init__.py | 0 src/server/tasks/medagentbench/funcs.py | 40 +++++++ src/server/tasks/medagentbench/patient.py | 113 ++++++++++++++++++ .../tasks/medagentbench/patient_data.json | 46 +++++++ src/server/tasks/medagentbench/utils.py | 0 5 files changed, 199 insertions(+) create mode 100644 src/server/tasks/medagentbench/__init__.py create mode 100644 src/server/tasks/medagentbench/funcs.py create mode 100644 src/server/tasks/medagentbench/patient.py create mode 100644 src/server/tasks/medagentbench/patient_data.json create mode 100644 src/server/tasks/medagentbench/utils.py diff --git a/src/server/tasks/medagentbench/__init__.py b/src/server/tasks/medagentbench/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/server/tasks/medagentbench/funcs.py b/src/server/tasks/medagentbench/funcs.py new file mode 100644 index 0000000..17fb725 --- /dev/null +++ b/src/server/tasks/medagentbench/funcs.py @@ -0,0 +1,40 @@ +from patient import Patient + +def filter_patients(patients, **kwargs): + """ + Filters a list of Patient objects based on provided attributes. + + Args: + patients (list): List of Patient objects. + **kwargs: Attribute-value pairs to filter by (e.g., name="John Doe", dob="1975-04-12"). + + Returns: + list: List of filtered Patient objects that match all specified criteria. + """ + filtered_patients = [] + + for patient in patients: + match = True + for key, value in kwargs.items(): + if not hasattr(patient, key) or getattr(patient, key) != value: + match = False + break + if match: + filtered_patients.append(patient) + + return filtered_patients + + +# Sample list of Patient objects +patients = [ + Patient.from_json('patient_data.json'), #Patient(patient_id="12345", name="John Doe", gender="Male", dob="1979-05-10", address="123 Main St", phone="555-1234", email="johndoe@example.com"), + Patient(patient_id="67890", name="Jane Smith", gender="Female", dob="1994-07-20", address="456 Elm St", phone="555-5678", email="janesmith@example.com"), + # Add more Patient objects as needed +] + +# Filter by name and DOB +filtered = filter_patients(patients, name="John Doe", dob="1979-05-10") + +# Display results +for patient in filtered: + print(patient) diff --git a/src/server/tasks/medagentbench/patient.py b/src/server/tasks/medagentbench/patient.py new file mode 100644 index 0000000..5a27bde --- /dev/null +++ b/src/server/tasks/medagentbench/patient.py @@ -0,0 +1,113 @@ +import json + +class Patient: + def __init__(self, patient_id, name, dob, gender, address, phone, email, medications=None, lab_tests=None, conditions=None, allergies=None, notes=None): + # Demographic information + self.patient_id = patient_id + self.name = name + self.dob = dob + self.gender = gender + self.address = address + self.phone = phone + self.email = email + + # Medical data + self.medications = medications if medications else [] + self.lab_tests = lab_tests if lab_tests else {} + self.conditions = conditions if conditions else [] + self.allergies = allergies if allergies else [] + self.notes = notes if notes else [] + + @classmethod + def from_json(cls, json_file_path): + with open(json_file_path, 'r') as file: + data = json.load(file) + + # Initialize from JSON data + return cls( + patient_id=data.get("patient_id"), + name=data.get("name"), + dob=data.get("dob"), + gender=data.get("gender"), + address=data.get("address"), + phone=data.get("phone"), + email=data.get("email"), + medications=data.get("medications"), + lab_tests=data.get("lab_tests"), + conditions=data.get("conditions"), + allergies=data.get("allergies"), + notes=data.get("notes") + ) + + def add_medication(self, medication_name, dosage, frequency, start_date, end_date=None): + medication = { + "medication_name": medication_name, + "dosage": dosage, + "frequency": frequency, + "start_date": start_date, + "end_date": end_date + } + self.medications.append(medication) + + def add_lab_test(self, test_name, result, units, date): + if test_name not in self.lab_tests: + self.lab_tests[test_name] = [] + self.lab_tests[test_name].append({ + "result": result, + "units": units, + "date": date + }) + + def add_condition(self, condition_name, diagnosis_date, status): + condition = { + "condition_name": condition_name, + "diagnosis_date": diagnosis_date, + "status": status + } + self.conditions.append(condition) + + def add_allergy(self, allergen, reaction, severity): + allergy = { + "allergen": allergen, + "reaction": reaction, + "severity": severity + } + self.allergies.append(allergy) + + def add_note(self, note, date): + self.notes.append({ + "note": note, + "date": date + }) + + def get_medications(self): + return self.medications + + def get_lab_tests(self): + return self.lab_tests + + def get_conditions(self): + return self.conditions + + def get_allergies(self): + return self.allergies + + def get_notes(self): + return self.notes + + def __str__(self): + return f"Patient ID: {self.patient_id}, Name: {self.name}, DOB: {self.dob}, Gender: {self.gender}" + +if __name__ == '__main__': + # Initialize a Patient object from a JSON file + patient = Patient.from_json("patient_data.json") + + # Print the patient information + print(patient) + + # Access the patient's data + print(patient.get_medications()) + print(patient.get_lab_tests()) + print(patient.get_conditions()) + print(patient.get_allergies()) + print(patient.get_notes()) diff --git a/src/server/tasks/medagentbench/patient_data.json b/src/server/tasks/medagentbench/patient_data.json new file mode 100644 index 0000000..9053e51 --- /dev/null +++ b/src/server/tasks/medagentbench/patient_data.json @@ -0,0 +1,46 @@ +{ + "patient_id": "12345", + "name": "John Doe", + "dob": "1979-05-10", + "gender": "Male", + "address": "123 Main St", + "phone": "555-1234", + "email": "johndoe@example.com", + "medications": [ + { + "medication_name": "Aspirin", + "dosage": "100mg", + "frequency": "Once daily", + "start_date": "2024-01-01" + } + ], + "lab_tests": { + "Hemoglobin": [ + { + "result": 13.5, + "units": "g/dL", + "date": "2024-01-15" + } + ] + }, + "conditions": [ + { + "condition_name": "Hypertension", + "diagnosis_date": "2023-05-10", + "status": "Ongoing" + } + ], + "allergies": [ + { + "allergen": "Peanuts", + "reaction": "Hives", + "severity": "Severe" + } + ], + "notes": [ + { + "note": "Patient is responding well to treatment.", + "date": "2024-01-20" + } + ] +} diff --git a/src/server/tasks/medagentbench/utils.py b/src/server/tasks/medagentbench/utils.py new file mode 100644 index 0000000..e69de29 From 4228d4df2914959087130d773a6ce02813460b75 Mon Sep 17 00:00:00 2001 From: jyx-su <108294040+jyx-su@users.noreply.github.com> Date: Wed, 6 Nov 2024 23:15:23 +0000 Subject: [PATCH 02/15] Add FHIR server yaml --- application.yaml | 329 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 329 insertions(+) create mode 100644 application.yaml diff --git a/application.yaml b/application.yaml new file mode 100644 index 0000000..e46ce6d --- /dev/null +++ b/application.yaml @@ -0,0 +1,329 @@ +#Uncomment the "servlet" and "context-path" lines below to make the fhir endpoint available at /example/path/fhir instead of the default value of /fhir +server: + # servlet: + # context-path: /example/path + port: 8080 +#Adds the option to go to eg. http://localhost:8080/actuator/health for seeing the running configuration +#see https://docs.spring.io/spring-boot/docs/current/reference/html/actuator.html#actuator.endpoints +management: + #The following configuration will enable the actuator endpoints at /actuator/health, /actuator/info, /actuator/prometheus, /actuator/metrics. For security purposes, only /actuator/health is enabled by default. + endpoints: + enabled-by-default: false + web: + exposure: + include: 'health' # or e.g. 'info,health,prometheus,metrics' or '*' for all' + endpoint: + info: + enabled: true + metrics: + enabled: true + health: + enabled: true + probes: + enabled: true + group: + liveness: + include: + - livenessState + - readinessState + prometheus: + enabled: true + prometheus: + metrics: + export: + enabled: true +spring: + main: + allow-circular-references: true + flyway: + enabled: false + baselineOnMigrate: true + fail-on-missing-locations: false + datasource: + url: 'jdbc:h2:/data/test_db;DB_CLOSE_DELAY=-1;AUTO_SERVER=TRUE' + #url: 'jdbc:h2:file:./target/database/h2' + #url: jdbc:h2:mem:test_mem + username: sa + password: null + driverClassName: org.h2.Driver + max-active: 15 + + # database connection pool size + hikari: + maximum-pool-size: 10 + jpa: + properties: + hibernate.format_sql: false + hibernate.show_sql: false + + #Hibernate dialect is automatically detected except Postgres and H2. + #If using H2, then supply the value of ca.uhn.fhir.jpa.model.dialect.HapiFhirH2Dialect + #If using postgres, then supply the value of ca.uhn.fhir.jpa.model.dialect.HapiFhirPostgresDialect + hibernate.dialect: ca.uhn.fhir.jpa.model.dialect.HapiFhirH2Dialect + # hibernate.hbm2ddl.auto: update + # hibernate.jdbc.batch_size: 20 + # hibernate.cache.use_query_cache: false + # hibernate.cache.use_second_level_cache: false + # hibernate.cache.use_structured_entries: false + # hibernate.cache.use_minimal_puts: false + + ### These settings will enable fulltext search with lucene or elastic + hibernate.search.enabled: false + ### lucene parameters +# hibernate.search.backend.type: lucene +# hibernate.search.backend.analysis.configurer: ca.uhn.fhir.jpa.search.HapiHSearchAnalysisConfigurers$HapiLuceneAnalysisConfigurer +# hibernate.search.backend.directory.type: local-filesystem +# hibernate.search.backend.directory.root: target/lucenefiles +# hibernate.search.backend.lucene_version: lucene_current + ### elastic parameters ===> see also elasticsearch section below <=== +# hibernate.search.backend.type: elasticsearch +# hibernate.search.backend.analysis.configurer: ca.uhn.fhir.jpa.search.HapiHSearchAnalysisConfigurers$HapiElasticAnalysisConfigurer +hapi: + fhir: + ### This flag when enabled to true, will avail evaluate measure operations from CR Module. + ### Flag is false by default, can be passed as command line argument to override. + cr: + enabled: false + caregaps: + reporter: "default" + section_author: "default" + cql: + use_embedded_libraries: true + compiler: + ### These are low-level compiler options. + ### They are not typically needed by most users. + # validate_units: true + # verify_only: false + # compatibility_level: "1.5" + error_level: Info + signature_level: All + # analyze_data_requirements: false + # collapse_data_requirements: false + # translator_format: JSON + # enable_date_range_optimization: true + enable_annotations: true + enable_locators: true + enable_results_type: true + enable_detailed_errors: true + # disable_list_traversal: false + # disable_list_demotion: false + # enable_interval_demotion: false + # enable_interval_promotion: false + # disable_method_invocation: false + # require_from_keyword: false + # disable_default_model_info_load: false + runtime: + debug_logging_enabled: false + # enable_validation: false + # enable_expression_caching: true + terminology: + valueset_preexpansion_mode: REQUIRE # USE_IF_PRESENT, REQUIRE, IGNORE + valueset_expansion_mode: PERFORM_NAIVE_EXPANSION # AUTO, USE_EXPANSION_OPERATION, PERFORM_NAIVE_EXPANSION + valueset_membership_mode: USE_EXPANSION # AUTO, USE_VALIDATE_CODE_OPERATION, USE_EXPANSION + code_lookup_mode: USE_VALIDATE_CODE_OPERATION # AUTO, USE_VALIDATE_CODE_OPERATION, USE_CODESYSTEM_URL + data: + search_parameter_mode: FILTER_IN_MEMORY # AUTO, USE_SEARCH_PARAMETERS, FILTER_IN_MEMORY + terminology_parameter_mode: FILTER_IN_MEMORY # AUTO, USE_VALUE_SET_URL, USE_INLINE_CODES, FILTER_IN_MEMORY + profile_mode: DECLARED # ENFORCED, DECLARED, OPTIONAL, TRUST, OFF + + cdshooks: + enabled: false + clientIdHeaderName: client_id + + ### This enables the swagger-ui at /fhir/swagger-ui/index.html as well as the /fhir/api-docs (see https://hapifhir.io/hapi-fhir/docs/server_plain/openapi.html) + openapi_enabled: true + ### This is the FHIR version. Choose between, DSTU2, DSTU3, R4 or R5 + fhir_version: R4 + ### Flag is false by default. This flag enables runtime installation of IG's. + ig_runtime_upload_enabled: false + ### This flag when enabled to true, will avail evaluate measure operations from CR Module. + + ### enable to use the ApacheProxyAddressStrategy which uses X-Forwarded-* headers + ### to determine the FHIR server address + # use_apache_address_strategy: false + ### forces the use of the https:// protocol for the returned server address. + ### alternatively, it may be set using the X-Forwarded-Proto header. + # use_apache_address_strategy_https: false + ### enables the server to overwrite defaults on HTML, css, etc. under the url pattern of eg. /content/custom ** + ### Folder with custom content MUST be named custom. If omitted then default content applies + #custom_content_path: ./custom + ### enables the server host custom content. If e.g. the value ./configs/app is supplied then the content + ### will be served under /web/app + #app_content_path: ./configs/app + ### enable to set the Server URL + # server_address: http://hapi.fhir.org/baseR4 + # defer_indexing_for_codesystems_of_size: 101 + # install_transitive_ig_dependencies: true + #implementationguides: + ### example from registry (packages.fhir.org) + # swiss: + # name: swiss.mednet.fhir + # version: 0.8.0 + # reloadExisting: false + # installMode: STORE_AND_INSTALL + # example not from registry + # ips_1_0_0: + # packageUrl: https://build.fhir.org/ig/HL7/fhir-ips/package.tgz + # name: hl7.fhir.uv.ips + # version: 1.0.0 + # supported_resource_types: + # - Patient + # - Observation + ################################################## + # Allowed Bundle Types for persistence (defaults are: COLLECTION,DOCUMENT,MESSAGE) + ################################################## + # allowed_bundle_types: COLLECTION,DOCUMENT,MESSAGE,TRANSACTION,TRANSACTIONRESPONSE,BATCH,BATCHRESPONSE,HISTORY,SEARCHSET + # allow_cascading_deletes: true + # allow_contains_searches: true + # allow_external_references: true + # allow_multiple_delete: true + # allow_override_default_search_params: true + # auto_create_placeholder_reference_targets: false + # mass_ingestion_mode_enabled: false + ### tells the server to automatically append the current version of the target resource to references at these paths + # auto_version_reference_at_paths: Device.patient, Device.location, Device.parent, DeviceMetric.parent, DeviceMetric.source, Observation.device, Observation.subject + # ips_enabled: false + # default_encoding: JSON + # default_pretty_print: true + # default_page_size: 20 + # delete_expunge_enabled: true + # enable_repository_validating_interceptor: true + # enable_index_missing_fields: false + # enable_index_of_type: true + # enable_index_contained_resource: false + # upliftedRefchains_enabled: true + # resource_dbhistory_enabled: false + ### !!Extended Lucene/Elasticsearch Indexing is still a experimental feature, expect some features (e.g. _total=accurate) to not work as expected!! + ### more information here: https://hapifhir.io/hapi-fhir/docs/server_jpa/elastic.html + advanced_lucene_indexing: false + bulk_export_enabled: false + bulk_import_enabled: false + # language_search_parameter_enabled: true + # enforce_referential_integrity_on_delete: false + # This is an experimental feature, and does not fully support _total and other FHIR features. + # enforce_referential_integrity_on_delete: false + # enforce_referential_integrity_on_write: false + # etag_support_enabled: true + # expunge_enabled: true + # client_id_strategy: ALPHANUMERIC + # server_id_strategy: SEQUENTIAL_NUMERIC + # fhirpath_interceptor_enabled: false + # filter_search_enabled: true + # graphql_enabled: true + narrative_enabled: false + mdm_enabled: false + mdm_rules_json_location: "mdm-rules.json" + # local_base_urls: + # - https://hapi.fhir.org/baseR4 + logical_urls: + - http://terminology.hl7.org/* + - https://terminology.hl7.org/* + - http://snomed.info/* + - https://snomed.info/* + - http://unitsofmeasure.org/* + - https://unitsofmeasure.org/* + - http://loinc.org/* + - https://loinc.org/* + # partitioning: + # allow_references_across_partitions: false + # partitioning_include_in_search_hashes: false + cors: + allow_Credentials: true + # These are allowed_origin patterns, see: https://docs.spring.io/spring-framework/docs/current/javadoc-api/org/springframework/web/cors/CorsConfiguration.html#setAllowedOriginPatterns-java.util.List- + allowed_origin: + - '*' + + # Search coordinator thread pool sizes + search-coord-core-pool-size: 20 + search-coord-max-pool-size: 100 + search-coord-queue-capacity: 200 + + # Search Prefetch Thresholds. + + # This setting sets the number of search results to prefetch. For example, if this list + # is set to [100, 1000, -1] then the server will initially load 100 results and not + # attempt to load more. If the user requests subsequent page(s) of results and goes + # past 100 results, the system will load the next 900 (up to the following threshold of 1000). + # The system will progressively work through these thresholds. + # A threshold of -1 means to load all results. Note that if the final threshold is a + # number other than -1, the system will never prefetch more than the given number. + search_prefetch_thresholds: 13,503,2003,-1 + + # comma-separated package names, will be @ComponentScan'ed by Spring to allow for creating custom Spring beans + #custom-bean-packages: + + # comma-separated list of fully qualified interceptor classes. + # classes listed here will be fetched from the Spring context when combined with 'custom-bean-packages', + # or will be instantiated via reflection using an no-arg contructor; then registered with the server + #custom-interceptor-classes: + + # comma-separated list of fully qualified provider classes. + # classes listed here will be fetched from the Spring context when combined with 'custom-bean-packages', + # or will be instantiated via reflection using an no-arg contructor; then registered with the server + #custom-provider-classes: + + # Threadpool size for BATCH'ed GETs in a bundle. + # bundle_batch_pool_size: 10 + # bundle_batch_pool_max_size: 50 + + # logger: + # error_format: 'ERROR - ${requestVerb} ${requestUrl}' + # format: >- + # Path[${servletPath}] Source[${requestHeader.x-forwarded-for}] + # Operation[${operationType} ${operationName} ${idOrResourceName}] + # UA[${requestHeader.user-agent}] Params[${requestParameters}] + # ResponseEncoding[${responseEncodingNoDefault}] + # log_exceptions: true + # name: fhirtest.access + # max_binary_size: 104857600 + # max_page_size: 200 + # retain_cached_searches_mins: 60 + # reuse_cached_search_results_millis: 60000 + tester: + home: + name: Local Tester + server_address: 'http://localhost:8080/fhir' + refuse_to_fetch_third_party_urls: false + fhir_version: R4 + global: + name: Global Tester + server_address: "http://hapi.fhir.org/baseR4" + refuse_to_fetch_third_party_urls: false + fhir_version: R4 + # validation: + # requests_enabled: true + # responses_enabled: true + # binary_storage_enabled: true + inline_resource_storage_below_size: 4000 +# bulk_export_enabled: true +# subscription: +# resthook_enabled: true +# websocket_enabled: false +# email: +# from: some@test.com +# host: google.com +# port: +# username: +# password: +# auth: +# startTlsEnable: +# startTlsRequired: +# quitWait: +# lastn_enabled: true +# store_resource_in_lucene_index_enabled: true +### This is configuration for normalized quantity search level default is 0 +### 0: NORMALIZED_QUANTITY_SEARCH_NOT_SUPPORTED - default +### 1: NORMALIZED_QUANTITY_STORAGE_SUPPORTED +### 2: NORMALIZED_QUANTITY_SEARCH_SUPPORTED +# normalized_quantity_search_level: 2 +#elasticsearch: +# debug: +# pretty_print_json_log: false +# refresh_after_write: false +# enabled: false +# password: SomePassword +# required_index_status: YELLOW +# rest_url: 'localhost:9200' +# protocol: 'http' +# schema_management_strategy: CREATE +# username: SomeUsername From 49b063c1a1f7861d4629cfd9ec36cc7def9d8858 Mon Sep 17 00:00:00 2001 From: jyx-su <108294040+jyx-su@users.noreply.github.com> Date: Wed, 6 Nov 2024 23:17:10 +0000 Subject: [PATCH 03/15] script for starting the FHIR server --- start_fhir.sh | 5 +++++ 1 file changed, 5 insertions(+) create mode 100644 start_fhir.sh diff --git a/start_fhir.sh b/start_fhir.sh new file mode 100644 index 0000000..9ef63f9 --- /dev/null +++ b/start_fhir.sh @@ -0,0 +1,5 @@ +cd MedAgentBench-dev + +docker run -d -p 8080:8080 -v $(pwd)/:/configs -e "--spring.config.location=file:///configs/application.yaml" -v /home/yx/my_h2_data:/data hapiproject/hapi:latest + + From c85e7f0184724a5142cd5c57c4e5f44a15769f2b Mon Sep 17 00:00:00 2001 From: jyx-su <108294040+jyx-su@users.noreply.github.com> Date: Fri, 8 Nov 2024 14:11:51 -0800 Subject: [PATCH 04/15] add ref_sol for task1 --- MedAgentBench/ref_sol.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) create mode 100644 MedAgentBench/ref_sol.py diff --git a/MedAgentBench/ref_sol.py b/MedAgentBench/ref_sol.py new file mode 100644 index 0000000..7fe9037 --- /dev/null +++ b/MedAgentBench/ref_sol.py @@ -0,0 +1,28 @@ +import requests +fhir_base_url = "http://34.132.86.17:8080/fhir/" + +def task1_sol(name, dob): + url = f'{fhir_base_url}Patient' + if not isinstance(dob, str): + dob = dob.strftime('%Y-%m-%d') + last_name = name.split(' ')[-1] + first_name = ' '.join(name.split(' ')[:-1]) + + params = { + 'birthdate': dob, + 'given': first_name, + 'family': last_name + } + + try: + response = requests.get(url, params=params) + response.raise_for_status() # Raises HTTPError for bad responses (4xx and 5xx) + data = response.json() + if data['total'] == 0: + return "Patient not found" + return data['entry'][0]['resource']['identifier'][0]['value'] + except requests.exceptions.RequestException as e: + print(f"An error occurred: {e}") + +assert 'S6540602' == task1_sol('Bobby Klein', '1954-01-02') +assert 'Patient not found' == task1_sol('Kyle', '1999-01-01') \ No newline at end of file From 24f7192598d367acf71153e0fb1908f1d5685e3f Mon Sep 17 00:00:00 2001 From: jyx-su <108294040+jyx-su@users.noreply.github.com> Date: Fri, 8 Nov 2024 14:12:05 -0800 Subject: [PATCH 05/15] Move code --- .../application.yaml | 0 start_fhir.sh => MedAgentBench/start_fhir.sh | 0 src/server/tasks/medagentbench/task.py | 26 +++++++++++++++++++ 3 files changed, 26 insertions(+) rename application.yaml => MedAgentBench/application.yaml (100%) rename start_fhir.sh => MedAgentBench/start_fhir.sh (100%) create mode 100644 src/server/tasks/medagentbench/task.py diff --git a/application.yaml b/MedAgentBench/application.yaml similarity index 100% rename from application.yaml rename to MedAgentBench/application.yaml diff --git a/start_fhir.sh b/MedAgentBench/start_fhir.sh similarity index 100% rename from start_fhir.sh rename to MedAgentBench/start_fhir.sh diff --git a/src/server/tasks/medagentbench/task.py b/src/server/tasks/medagentbench/task.py new file mode 100644 index 0000000..8d2e1e3 --- /dev/null +++ b/src/server/tasks/medagentbench/task.py @@ -0,0 +1,26 @@ +#Structure documentation https://github.com/THUDM/AgentBench/blob/main/docs/Extension_en.md +from src.server.task import Task, Session +from src.typings import TaskOutput, SampleStatus, AgentOutputStatus + +class MedAgentBench(Task): + def __init__(self, *args, **kwargs) -> None: + super().__init__(name="MedAgentBench", *args, **kwargs) + + def get_indices(self) -> List[Any]: + return list(range(10)) + + async def start_sample(self, index, session: Session): + print("task start sample") + for loop_times in range(3): + await asyncio.sleep(1) + res = await session.action( + {"role": "user", "content": "Loop: %d" % loop_times} + ) + print("TASK", res.content) + return TaskSampleExecutionResult( + status=SampleStatus.COMPLETED, + result={"result": "ok"}, + ) + + def calculate_overall(self, results: List[TaskOutput]) -> Dict[str, Any]: + return {"score": 0.4} \ No newline at end of file From 26f7cb0826cf628896513576b6291410cb8bfc41 Mon Sep 17 00:00:00 2001 From: jyx-su <108294040+jyx-su@users.noreply.github.com> Date: Fri, 8 Nov 2024 14:27:12 -0800 Subject: [PATCH 06/15] Add example of json --- MedAgentBench/test_data_nov8.json | 1 + 1 file changed, 1 insertion(+) create mode 100644 MedAgentBench/test_data_nov8.json diff --git a/MedAgentBench/test_data_nov8.json b/MedAgentBench/test_data_nov8.json new file mode 100644 index 0000000..6e0c809 --- /dev/null +++ b/MedAgentBench/test_data_nov8.json @@ -0,0 +1 @@ +[{"id": "task1_1", "instruction": "What's the MRN of the patient with name Debra Dunn and DOB of 1969-05-12? If the patient does not exist, the answer should be \"Patient not found\"", "context": "", "sol": "S6551923"}, {"id": "task1_2", "instruction": "What's the MRN of the patient with name Bobby Klein and DOB of 1954-01-02? If the patient does not exist, the answer should be \"Patient not found\"", "context": "", "sol": "S6540602"}, {"id": "task1_3", "instruction": "What's the MRN of the patient with name Tina Reid and DOB of 1953-10-18? If the patient does not exist, the answer should be \"Patient not found\"", "context": "", "sol": "S3213957"}, {"id": "task1_4", "instruction": "What's the MRN of the patient with name Kevin Vasquez and DOB of 1953-11-19? If the patient does not exist, the answer should be \"Patient not found\"", "context": "", "sol": "S6200102"}, {"id": "task1_5", "instruction": "What's the MRN of the patient with name Christopher Cruz and DOB of 1940-08-28? If the patient does not exist, the answer should be \"Patient not found\"", "context": "", "sol": "S0658561"}, {"id": "task1_6", "instruction": "What's the MRN of the patient with name Brandon Williams and DOB of 1968-05-19? If the patient does not exist, the answer should be \"Patient not found\"", "context": "", "sol": "S6549951"}, {"id": "task1_7", "instruction": "What's the MRN of the patient with name Erica Castro and DOB of 1960-12-16? If the patient does not exist, the answer should be \"Patient not found\"", "context": "", "sol": "S6550627"}, {"id": "task1_8", "instruction": "What's the MRN of the patient with name Mary Rivera and DOB of 1953-02-28? If the patient does not exist, the answer should be \"Patient not found\"", "context": "", "sol": "S1023381"}, {"id": "task1_9", "instruction": "What's the MRN of the patient with name Debra Dunn and DOB of 1969-05-15? If the patient does not exist, the answer should be \"Patient not found\"", "context": "", "sol": "Patient not found"}, {"id": "task1_10", "instruction": "What's the MRN of the patient with name Kyle Jia and DOB of 1969-05-15? If the patient does not exist, the answer should be \"Patient not found\"", "context": "", "sol": "Patient not found"}] \ No newline at end of file From 8189592906326ccd0a618347918d6794c61d5b66 Mon Sep 17 00:00:00 2001 From: GengLongling Date: Tue, 12 Nov 2024 13:41:52 -0800 Subject: [PATCH 07/15] Add files via upload --- MedAgentBench/ref_sol.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/MedAgentBench/ref_sol.py b/MedAgentBench/ref_sol.py index 7fe9037..52c05fc 100644 --- a/MedAgentBench/ref_sol.py +++ b/MedAgentBench/ref_sol.py @@ -25,4 +25,6 @@ def task1_sol(name, dob): print(f"An error occurred: {e}") assert 'S6540602' == task1_sol('Bobby Klein', '1954-01-02') -assert 'Patient not found' == task1_sol('Kyle', '1999-01-01') \ No newline at end of file +assert 'Patient not found' == task1_sol('Kyle', '1999-01-01') + +# note: does not consider empty string \ No newline at end of file From c2553514c1c54d39f98768b375482c905c2799a1 Mon Sep 17 00:00:00 2001 From: GengLongling Date: Tue, 12 Nov 2024 13:42:39 -0800 Subject: [PATCH 08/15] Add files via upload reference solution 1-3 added --- MedAgentBench/ref_sol2.py | 73 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 73 insertions(+) create mode 100644 MedAgentBench/ref_sol2.py diff --git a/MedAgentBench/ref_sol2.py b/MedAgentBench/ref_sol2.py new file mode 100644 index 0000000..66e6df1 --- /dev/null +++ b/MedAgentBench/ref_sol2.py @@ -0,0 +1,73 @@ +import requests +from datetime import datetime +import math + +# Define the base URL for the FHIR server +fhir_base_url = "http://34.132.86.17:8080/fhir/" + + +# Function to calculate the age from the birthdate +def calculate_age(dob: str, today: datetime): + # Convert the birthdate string to a datetime object + birth_date = datetime.strptime(dob, '%Y-%m-%d') + + # Calculate the difference in years + age = today.year - birth_date.year + + # If the birthday hasn't occurred yet this year, subtract 1 from the age + if (today.month, today.day) < (birth_date.month, birth_date.day): + age -= 1 + + # Round up the age + return math.ceil(age) + + +# Function to get the patient's age from MRN +def task2_sol(mrn: str): + # Construct the URL for the FHIR Patient Search endpoint, using MRN in the identifier query parameter + url = f'{fhir_base_url}Patient?identifier={mrn}' + + try: + # Send the GET request to search for the patient by MRN + response = requests.get(url) + response.raise_for_status() # Raises HTTPError for bad responses (4xx and 5xx) + + # Parse the JSON response from the search + data = response.json() + + # Check if we found any patient data + if data['total'] == 0: + return "Patient not found" + + # Iterate through the search results to find the matching MRN + for entry in data['entry']: + patient = entry['resource'] + + # Extract the MRN from the patient resource + patient_mrn = next((identifier['value'] for identifier in patient.get('identifier', []) + if identifier['type']['coding'][0]['code'] == 'MR'), None) + + # Check if the MRN matches the requested MRN + if patient_mrn == mrn: + # Extract the birthdate from the patient resource + birthdate = patient.get('birthDate', None) + if not birthdate: + return "Birthdate not available" + + # Calculate the patient's age + today = datetime(2023, 11, 13) # Fixed today's date as per the example + age = calculate_age(birthdate, today) + + return age + + # If no matching MRN found, return not found + return "Patient not found" + + except requests.exceptions.RequestException as e: + print(f"An error occurred: {e}") + return "Error occurred while retrieving patient data" + + +# Test +print(task2_sol('S3236936')) # Test with a valid MRN (replace with the MRN you want to search for) +print(task2_sol('INVALID_MRN')) # Test with an invalid MRN From e098fe957d0ce6419c499dea9c6b3cae1a851198 Mon Sep 17 00:00:00 2001 From: GengLongling Date: Tue, 12 Nov 2024 13:42:59 -0800 Subject: [PATCH 09/15] Add files via upload reference solution 1-3 added --- MedAgentBench/ref_sol3.py | 88 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 88 insertions(+) create mode 100644 MedAgentBench/ref_sol3.py diff --git a/MedAgentBench/ref_sol3.py b/MedAgentBench/ref_sol3.py new file mode 100644 index 0000000..c2b6e10 --- /dev/null +++ b/MedAgentBench/ref_sol3.py @@ -0,0 +1,88 @@ +# import requests +# +# # Define the base URL for the FHIR server +# fhir_base_url = "http://34.132.86.17:8080/fhir/" +# +# +# def record_blood_pressure(systolic: float, diastolic: float, effective_date: str): +# """ +# Record a blood pressure observation with dynamically calculated ID. +# +# :param systolic: Systolic blood pressure value +# :param diastolic: Diastolic blood pressure value +# :param effective_date: The date/time when the observation was taken (in ISO 8601 format) +# :return: "Done" if successful, error message otherwise +# """ +# +# # Construct the URL for the Observation endpoint +# observation_url = f"{fhir_base_url}Observation" +# +# # First, send a GET request to fetch all existing observations +# try: +# response = requests.get(observation_url) +# response.raise_for_status() # Check for any errors with the GET request +# data = response.json() +# +# # Step 1: Get the current highest ID +# max_id = 0 +# for entry in data['entry']: +# obs_id = int(entry['resource']['id']) # Convert id to integer to compare +# if obs_id > max_id: +# max_id = obs_id +# +# # Step 2: Set the new ID to max_id + 1 +# new_id = max_id + 1 +# +# # Step 3: Create the new blood pressure observation entry with the calculated new_id +# new_observation = { +# "fullUrl": f"http://34.132.86.17:8080/fhir/Observation/{new_id}", +# "resource": { +# "resourceType": "Observation", +# "id": str(new_id), # ID is now dynamically set +# "status": "final", +# "category": [{ +# "coding": [{ +# "system": "http://terminology.hl7.org/CodeSystem/observation-category", +# "code": "vital-signs", +# "display": "Vital Signs" +# }] +# }], +# "code": { +# "coding": [{ +# "system": "http://loinc.org", +# "code": "85354-9", +# "display": "Blood pressure systolic and diastolic" +# }], +# "text": "Blood pressure systolic and diastolic" +# }, +# "valueQuantity": { +# "value": f'{systolic}/{diastolic}', +# "unit": "mm Hg", +# "system": "http://unitsofmeasure.org", +# "code": "mm Hg" +# }, +# "effectiveDateTime": effective_date, +# "issued": effective_date +# }, +# "search": { +# "mode": "match" +# } +# } +# +# # Step 4: Append the new observation to the "entry" list +# data['entry'].append(new_observation) +# +# # Step 5: Send the POST request to the server with the updated bundle +# post_response = requests.post(observation_url, json=data) +# post_response.raise_for_status() +# +# # If the POST request was successful, return "Done" +# return "Done" +# +# except requests.exceptions.RequestException as e: +# print(f"An error occurred: {e}") +# return "Error occurred while recording blood pressure" + + +# test +print(record_blood_pressure(118, 77, '2023-11-13T10:00:00+00:00')) From e055d163c206be6f81557af711b0f8de46038ed2 Mon Sep 17 00:00:00 2001 From: GengLongling Date: Tue, 12 Nov 2024 14:00:45 -0800 Subject: [PATCH 10/15] Add files via upload reference solution 4 added --- MedAgentBench/ref_sol4.py | 73 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 73 insertions(+) create mode 100644 MedAgentBench/ref_sol4.py diff --git a/MedAgentBench/ref_sol4.py b/MedAgentBench/ref_sol4.py new file mode 100644 index 0000000..0a5a4d0 --- /dev/null +++ b/MedAgentBench/ref_sol4.py @@ -0,0 +1,73 @@ +import requests +from datetime import datetime, timedelta + +# Base FHIR URL for your system +fhir_base_url = "http://34.132.86.17:8080/fhir/" + + +def get_latest_observation_value(mrn, current_time_str): + # Define the URL for the Observation resource + url = f'{fhir_base_url}Observation' + + # Parse the current time from the string + current_time = datetime.fromisoformat(current_time_str) + + # Define the 24 hours window (from current time) + twenty_four_hours_ago = current_time - timedelta(hours=24) + + try: + # Make the request to the Observation endpoint (no parameters) + response = requests.get(url) + response.raise_for_status() # Raises HTTPError for bad responses (4xx and 5xx) + + data = response.json() + + # if data['total'] == 0: + # return "No observations found" + + # Initialize to track the most recent observation + latest_observation = None + + # Loop through all entries in the response + for entry in data['entry']: + observation = entry['resource'] + + # Extract MRN from the observation and check if it matches the input MRN + if 'subject' in observation and observation['subject']['identifier']['value'] == mrn: + effective_date_time = datetime.fromisoformat(observation['effectiveDateTime']) + + # Check if the observation is within the last 24 hours + if twenty_four_hours_ago <= effective_date_time <= current_time: + if latest_observation is None or effective_date_time > latest_observation['effectiveDateTime']: + latest_observation = { + 'value': observation['valueQuantity']['value'], + 'unit': observation['valueQuantity']['unit'], + 'effectiveDateTime': effective_date_time + } + + # If we found a valid observation, return the value + if latest_observation: + return latest_observation['value'], latest_observation['unit'] + else: + return "No observation found within the last 24 hours" + + except requests.exceptions.RequestException as e: + print(f"An error occurred: {e}") + return "Error retrieving data" + + +# test 1: +# mrn = "S6488980" +# current_time_str = "2023-11-13T10:15:00+00:00" # current time + +# test 2: +mrn = "S6488980" +current_time_str = "2023-03-25T21:05:00+00:00" # current time + +# Call the function and print the result +result = get_latest_observation_value(mrn, current_time_str) +if isinstance(result, tuple): + value, unit = result + print(f"The latest value within the last 24 hours is: {value} {unit}") +else: + print(result) From f525ca60dabc155a61a3da90ec75633514459b28 Mon Sep 17 00:00:00 2001 From: GengLongling Date: Tue, 12 Nov 2024 14:30:11 -0800 Subject: [PATCH 11/15] Add files via upload sample reference solution 5-7 --- MedAgentBench/ref_sol5.py | 109 ++++++++++++++++++++++++++++++++++++++ MedAgentBench/ref_sol6.py | 85 +++++++++++++++++++++++++++++ MedAgentBench/ref_sol7.py | 90 +++++++++++++++++++++++++++++++ 3 files changed, 284 insertions(+) create mode 100644 MedAgentBench/ref_sol5.py create mode 100644 MedAgentBench/ref_sol6.py create mode 100644 MedAgentBench/ref_sol7.py diff --git a/MedAgentBench/ref_sol5.py b/MedAgentBench/ref_sol5.py new file mode 100644 index 0000000..ec90a7f --- /dev/null +++ b/MedAgentBench/ref_sol5.py @@ -0,0 +1,109 @@ +import requests +from datetime import datetime +import re + +# Base FHIR URL for your system +fhir_base_url = "http://34.132.86.17:8080/fhir/" + + +def post_medication_request(mrn, dosage_instruction_text, dose_value, dose_unit, medication_codeable_concept): + # Define the URL for MedicationRequest + url = f'{fhir_base_url}MedicationRequest' + + # Fetch the existing bundle and its entries to find the maximum entry ID + # Assuming we have the current bundle data (here we're just simulating it) + # You should fetch this from the FHIR server with a GET request or keep it in memory + response = requests.get(f"{fhir_base_url}MedicationRequest") + + if response.status_code != 200: + return f"Error: Unable to retrieve existing MedicationRequests - {response.status_code}" + + # Extract the current entries from the bundle + bundle = response.json() + + # List to track the current entry IDs + existing_ids = [] + + # Iterate over the existing entries to extract the numeric part of the IDs + for entry in bundle.get('entry', []): + match = re.search(r'(\d+)', entry['fullUrl']) # Extract numbers from the URL (e.g., 'MedicationRequest/39054') + if match: + existing_ids.append(int(match.group(1))) # Append the numeric ID to the list + + # Calculate the new entry ID by finding the max of existing IDs and adding 1 + new_entry_id = max(existing_ids, default=0) + 1 # Default to 0 if no entries exist + + # Prepare the new entry that will be added to the 'entry' list + new_entry = { + "fullUrl": f"{fhir_base_url}MedicationRequest/{new_entry_id}", # New entry URL with calculated ID + "resource": { + "status": "active", # This assumes the medication request is active + "intent": "order", + "medicationCodeableConcept": { + "text": medication_codeable_concept # The medication passed in the input + }, + "subject": { + "identifier": { + "system": "http://terminology.hl7.org/CodeSystem/v2-0203", + "value": mrn # Patient's MRN + } + }, + "authoredOn": datetime.now().isoformat(), # Current timestamp + "dosageInstruction": [ + { + "text": dosage_instruction_text, + "timing": { + "code": { + "text": dosage_instruction_text # Dosage instruction text + } + }, + "doseAndRate": [ + { + "doseQuantity": { + "value": dose_value, # Dosage value (quantity) + "unit": dose_unit # Unit (e.g., mg, mL) + } + } + ] + } + ] + }, + "search": { + "mode": "match" + } + } + + # Define the request body as a Bundle with the new entry + bundle = { + "resourceType": "Bundle", + "type": "transaction", + "entry": [new_entry] # Append the new entry into the 'entry' list + } + + try: + # Make the POST request to create the MedicationRequest + response = requests.post(url, json=bundle) + response.raise_for_status() # Raises HTTPError for bad responses (4xx and 5xx) + + # If the request is successful, return the full URL of the newly created resource + if response.status_code == 201: + new_entry_response = response.json() + return new_entry_response.get('entry', [{}])[0].get('fullUrl', 'URL not available') + else: + return f"Error: {response.status_code} - {response.text}" + + except requests.exceptions.RequestException as e: + print(f"An error occurred: {e}") + return "Error making the POST request" + + +# test: +mrn = "S1023381" # Example MRN +dosage_instruction_text = "BID" # Dosage instruction (e.g., "twice a day") +dose_value = 350 # Dosage amount (e.g., 350) +dose_unit = "mg" # Unit of measurement (e.g., mg) +medication_codeable_concept = "carisoprodol (Soma) 350 mg tablet" # Example medication + +# Call the function and print the result +result = post_medication_request(mrn, dosage_instruction_text, dose_value, dose_unit, medication_codeable_concept) +print(result) diff --git a/MedAgentBench/ref_sol6.py b/MedAgentBench/ref_sol6.py new file mode 100644 index 0000000..0c041ac --- /dev/null +++ b/MedAgentBench/ref_sol6.py @@ -0,0 +1,85 @@ +# import requests +# from datetime import datetime, timedelta +# +# # Base FHIR URL for your system +# fhir_base_url = "http://34.132.86.17:8080/fhir/" +# +# # Function to create a new Procedure entry +# def post_procedure_request(medication_request_id, loinc_code, loinc_display, service_request_id=None): +# # Define the URL for Procedure creation +# url = f'{fhir_base_url}Procedure' +# +# # Generate a new ID for the Procedure entry (for simplicity, assuming next available integer ID) +# # You can change this to any logic you'd like. Here it's just a static incremented ID +# new_id = 1000 # Starting point for new procedure IDs, can be modified for dynamic handling +# +# # Set the date for "completed" (next day if no service) +# if service_request_id is None: +# complete_time = (datetime.now() + timedelta(days=1)).isoformat() # Set to the next day if no service request +# else: +# complete_time = datetime.now().isoformat() # Current time if service exists +# +# # Prepare the Procedure entry data +# procedure_entry = { +# "resourceType": "Procedure", +# "id": str(new_id), # New ID for Procedure entry +# "status": "completed", # Procedure status as completed +# "code": { +# "coding": [ +# { +# "system": "http://loinc.org", # LOINC Code System +# "code": loinc_code, # LOINC Code for Serum Magnesium Test +# "display": loinc_display # Display name for the test +# } +# ], +# "text": "Morning serum magnesium test" # Free text description +# }, +# "subject": { +# "reference": f"Patient/{medication_request_id}" # Reference to the Patient (using MedicationRequest ID) +# }, +# "encounter": { +# "reference": f"Encounter/{service_request_id}" if service_request_id else None # Reference to ServiceRequest ID (optional) +# }, +# "performedDateTime": complete_time, # Set the performed datetime (next day if no service) +# "note": [ +# { +# "text": "Morning serum magnesium test to be completed next day." +# } +# ] +# } +# +# # Prepare the request body as a Bundle with the new Procedure entry +# bundle = { +# "resourceType": "Bundle", +# "type": "transaction", +# "entry": [ +# { +# "resource": procedure_entry +# } +# ] +# } +# +# try: +# # Make the POST request to create the Procedure entry +# response = requests.post(url, json=bundle) +# response.raise_for_status() # Raises HTTPError for bad responses (4xx and 5xx) +# +# # If the request is successful, return 'done' +# if response.status_code == 201: +# return "done" +# else: +# return f"Error: {response.status_code} - {response.text}" +# +# except requests.exceptions.RequestException as e: +# print(f"An error occurred: {e}") +# return "Error making the POST request" +# +# # test: +# medication_request_id = "39054" # Example MedicationRequest ID +# service_request_id = "39056" # Example ServiceRequest ID (set to None if no service request) +# loinc_code = "2503-9" # LOINC code for Serum Magnesium Test +# loinc_display = "Serum magnesium (test)" # Display for the Serum Magnesium Test +# +# # Call the function to create a Procedure entry +# result = post_procedure_request(medication_request_id, loinc_code, loinc_display, service_request_id=None) +# print(result) # Expected output: "done" or error message diff --git a/MedAgentBench/ref_sol7.py b/MedAgentBench/ref_sol7.py new file mode 100644 index 0000000..3f68103 --- /dev/null +++ b/MedAgentBench/ref_sol7.py @@ -0,0 +1,90 @@ +# import requests +# from datetime import datetime +# +# # Base FHIR URL for your system +# fhir_base_url = "http://34.132.86.17:8080/fhir/" +# +# +# def post_service_request(mrn, referral_text, snomed_code, display_text): +# # Define the URL for ServiceRequest +# url = f'{fhir_base_url}ServiceRequest' +# +# # Fetch the current entries from the FHIR server to determine the highest current ID +# response = requests.get(url) +# +# if response.status_code != 200: +# print(f"Error fetching existing ServiceRequests: {response.status_code}") +# return "Error fetching existing entries" +# +# data = response.json() +# +# # Find the highest entry ID +# existing_ids = [int(entry['resource']['id']) for entry in data['entry'] if entry['resource']['id'].isdigit()] +# new_entry_id = max(existing_ids) + 1 if existing_ids else 1 +# +# # Prepare the new ServiceRequest entry with SNOMED code and display +# new_entry = { +# "resourceType": "ServiceRequest", +# "id": str(new_entry_id), # Assign the calculated ID +# "status": "active", # The status of the service request +# "intent": "order", # The intent of the request (an order) +# "code": { +# "coding": [ +# { +# "system": "http://snomed.info/sct", # SNOMED CT system URL +# "code": snomed_code, # SNOMED code from input +# "display": display_text # Display text from input +# } +# ], +# "text": display_text # Free text description +# }, +# "subject": { +# "identifier": { +# "system": "http://terminology.hl7.org/CodeSystem/v2-0203", # Identifier system +# "value": mrn # Patient's MRN +# } +# }, +# "authoredOn": datetime.now().isoformat(), # Current timestamp +# "note": [ +# { +# "text": referral_text # Referral details (free text) +# } +# ] +# } +# +# # Define the request body as a Bundle with the new entry +# bundle = { +# "resourceType": "Bundle", +# "type": "transaction", +# "entry": [ +# { +# "resource": new_entry +# } +# ] +# } +# +# try: +# # Make the POST request to create the ServiceRequest +# response = requests.post(url, json=bundle) +# response.raise_for_status() # Raises HTTPError for bad responses (4xx and 5xx) +# +# # If the request is successful, return 'done' +# if response.status_code == 201: +# return "done" +# else: +# return f"Error: {response.status_code} - {response.text}" +# +# except requests.exceptions.RequestException as e: +# print(f"An error occurred: {e}") +# return "Error making the POST request" +# +# +# # test: +# mrn = "S1023381" # Example MRN +# referral_text = "Acute left knee injury, imaging showing ACL tear." # Referral details (free text) +# snomed_code = "306181000000106" # SNOMED code for orthopedic surgery referral +# display_text = "Order orthopedic surgery referral" # Display text for the code +# +# # Call the function and print the result +# result = post_service_request(mrn, referral_text, snomed_code, display_text) +# print(result) From c015b04cab77bdcfd1f93fb82e6492df3af47df5 Mon Sep 17 00:00:00 2001 From: Danny Park Date: Fri, 15 Nov 2024 13:19:35 -0800 Subject: [PATCH 12/15] Initial Implementation of utilizing the Task class in Extend AgentBench --- MedAgentBench/task.py | 56 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 56 insertions(+) create mode 100644 MedAgentBench/task.py diff --git a/MedAgentBench/task.py b/MedAgentBench/task.py new file mode 100644 index 0000000..5a23f38 --- /dev/null +++ b/MedAgentBench/task.py @@ -0,0 +1,56 @@ +import asyncio +from typing import List, Dict, Any +from pydantic import BaseModel +from agentbench import Task, TaskSampleExecutionResult, SampleStatus, TaskOutput, ChatHistoryItem, Session + +class FHIRTask(Task): + def __init__(self, test_data: List[Dict[str, Any]], *args, **kwargs): + """ + Initialize the FHIRTask with test data. + :param test_data: List of test cases, each containing 'id', 'instruction', and 'sol' (expected solution). + """ + super().__init__(name="fhir-task", *args, **kwargs) + self.data = test_data # Dynamically load test data + + def get_indices(self) -> List[int]: + # Return indices for each test sample + return list(range(len(self.data))) + + async def start_sample(self, index: int, session: Session) -> TaskSampleExecutionResult: + # Get the test sample + sample = self.data[index] + instruction = sample["instruction"] + expected_sol = sample["sol"] + + # Parse the instruction to extract patient name and DOB + name = instruction.split("name ")[1].split(" and DOB")[0].strip() + dob = instruction.split("DOB of ")[1].split("?")[0].strip() + + # Call your existing task1_sol function to get the result + result = task1_sol(name, dob) + + # Compare the result with the expected solution + status = SampleStatus.COMPLETED if result == expected_sol else SampleStatus.TASK_ERROR + + # Return the result and status + return TaskSampleExecutionResult( + status=status, + result={ + "id": sample["id"], + "instruction": instruction, + "expected": expected_sol, + "actual": result, + "status": "correct" if result == expected_sol else "incorrect" + } + ) + + def calculate_overall(self, results: List[TaskOutput]) -> Dict[str, Any]: + # Calculate the overall score and error rate + total_samples = len(results) + correct_samples = sum(1 for result in results if result.result["status"] == "correct") + score = correct_samples / total_samples if total_samples > 0 else 0 + return { + "total_samples": total_samples, + "correct_samples": correct_samples, + "score": score + } From 8742080bcb4f8ace55572aa5fb81aae8918c885e Mon Sep 17 00:00:00 2001 From: Danny Park Date: Tue, 10 Dec 2024 22:54:35 -0800 Subject: [PATCH 13/15] FHIR API version of the original code --- configs/agents/api_agents.yaml | 8 ++ configs/assignments/default.yaml | 4 +- configs/tasks/dbbench.yaml | 2 +- src/server/tasks/dbbench/Interaction.py | 164 +++++++++------------ src/server/tasks/dbbench/__init__.py | 180 +++++++----------------- 5 files changed, 135 insertions(+), 223 deletions(-) diff --git a/configs/agents/api_agents.yaml b/configs/agents/api_agents.yaml index 30a46f4..dbf5fd6 100644 --- a/configs/agents/api_agents.yaml +++ b/configs/agents/api_agents.yaml @@ -21,3 +21,11 @@ text-davinci-002: body: model: "text-davinci-002" max_tokens: 512 + +gpt-3.5-turbo: + import: "./openai-chat.yaml" + parameters: + name: "gpt-3.5-turbo" + body: + model: "gpt-3.5-turbo" + max_tokens: 512 diff --git a/configs/assignments/default.yaml b/configs/assignments/default.yaml index aec7229..1a8e585 100644 --- a/configs/assignments/default.yaml +++ b/configs/assignments/default.yaml @@ -5,11 +5,11 @@ concurrency: dbbench-std: 5 os-std: 5 agent: - gpt-3.5-turbo-0613: 5 + gpt-3.5-turbo: 5 assignments: # List[Assignment] | Assignment - agent: # "task": List[str] | str , "agent": List[str] | str - - gpt-3.5-turbo-0613 + - gpt-3.5-turbo task: - dbbench-std - os-std diff --git a/configs/tasks/dbbench.yaml b/configs/tasks/dbbench.yaml index 7e0b7cd..57f84ab 100644 --- a/configs/tasks/dbbench.yaml +++ b/configs/tasks/dbbench.yaml @@ -12,4 +12,4 @@ dbbench-dev: dbbench-std: parameters: name: dbbench-std - data_file: "data/dbbench/standard.jsonl" + data_file: "MedAgentBench/test_data_nov8.json" diff --git a/src/server/tasks/dbbench/Interaction.py b/src/server/tasks/dbbench/Interaction.py index 4d11d76..767fa41 100644 --- a/src/server/tasks/dbbench/Interaction.py +++ b/src/server/tasks/dbbench/Interaction.py @@ -1,109 +1,85 @@ -import docker -import mysql.connector +import requests import random -import socket import time -from docker.models import containers -from typing import Optional, Union, Sequence, Dict, Any +from typing import Optional, Dict, Any -class Container: - port = 13000 - password = "password" +class FHIRClient: + def __init__(self, base_url: str = "http://34.170.56.151:8080/fhir/"): + self.base_url = base_url + self.session = requests.Session() + self.session.headers.update({"Content-Type": "application/fhir+json"}) + self.verify_connection() - def __init__(self, image: str = "mysql"): - self.deleted = False - self.image = image - self.client = docker.from_env() - p = Container.port + random.randint(0, 10000) - while self.is_port_open(p): - p += random.randint(0, 20) - self.port = p - self.container: containers.Container = self.client.containers.run( - image, - name=f"mysql_{self.port}", - environment={ - "MYSQL_ROOT_PASSWORD": self.password, - }, - ports={"3306": self.port}, - detach=True, - tty=True, - stdin_open=True, - remove=True, - ) - - time.sleep(1) - - retry = 0 - while True: - try: - self.conn = mysql.connector.connect( - host="127.0.0.1", - user="root", - password=self.password, - port=self.port, - pool_reset_session=True, - ) - except mysql.connector.errors.OperationalError: - time.sleep(1) - except mysql.connector.InterfaceError: - if retry > 10: - raise - time.sleep(5) - else: - break - retry += 1 - - def delete(self): - self.container.stop() - self.deleted = True + def verify_connection(self): + try: + response = self.session.get(f"{self.base_url}metadata") + response.raise_for_status() + print("Connected to FHIR server successfully!") + except requests.exceptions.RequestException as e: + print(f"Error connecting to FHIR server: {e}") + raise - def __del__(self): + def create_resource(self, resource_type: str, resource_data: Dict[str, Any]) -> Optional[Dict[str, Any]]: + """ + Create a resource on the FHIR server. + :param resource_type: The FHIR resource type (e.g., 'Patient', 'Observation'). + :param resource_data: The resource data in FHIR JSON format. + :return: The created resource or None if there was an error. + """ + url = f"{self.base_url}{resource_type}" try: - if not self.deleted: - self.delete() - except: - pass + response = self.session.post(url, json=resource_data) + response.raise_for_status() + return response.json() + except requests.exceptions.RequestException as e: + print(f"Error creating resource: {e}") + return None - def execute( - self, - sql: str, - database: str = None, - data: Union[Sequence, Dict[str, Any]] = (), - ) -> Optional[str]: - self.conn.reconnect() + def get_resource(self, resource_type: str, resource_id: str) -> Optional[Dict[str, Any]]: + """ + Retrieve a resource from the FHIR server. + :param resource_type: The FHIR resource type (e.g., 'Patient'). + :param resource_id: The ID of the resource to retrieve. + :return: The resource data or None if there was an error. + """ + url = f"{self.base_url}{resource_type}/{resource_id}" try: - with self.conn.cursor() as cursor: - if database: - cursor.execute(f"use `{database}`;") - cursor.fetchall() - cursor.execute(sql, data, multi=True) - result = cursor.fetchall() - result = str(result) - self.conn.commit() - except Exception as e: - result = str(e) - if len(result) > 800: - result = result[:800] + "[TRUNCATED]" - return result + response = self.session.get(url) + response.raise_for_status() + return response.json() + except requests.exceptions.RequestException as e: + print(f"Error retrieving resource: {e}") + return None - def is_port_open(self, port) -> bool: + def search_resources(self, resource_type: str, search_params: Dict[str, str]) -> Optional[Dict[str, Any]]: + """ + Search for resources on the FHIR server. + :param resource_type: The FHIR resource type (e.g., 'Patient'). + :param search_params: A dictionary of search parameters. + :return: The search results or None if there was an error. + """ + url = f"{self.base_url}{resource_type}" try: - self.client.containers.get(f"mysql_{port}") - return True - except Exception: - pass - # Create a socket object - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + response = self.session.get(url, params=search_params) + response.raise_for_status() + return response.json() + except requests.exceptions.RequestException as e: + print(f"Error searching resources: {e}") + return None + def delete_resource(self, resource_type: str, resource_id: str) -> bool: + """ + Delete a resource from the FHIR server. + :param resource_type: The FHIR resource type (e.g., 'Patient'). + :param resource_id: The ID of the resource to delete. + :return: True if deletion was successful, False otherwise. + """ + url = f"{self.base_url}{resource_type}/{resource_id}" try: - # Try to connect to the specified port - sock.connect(("localhost", port)) - # If the connection succeeds, the port is occupied + response = self.session.delete(url) + response.raise_for_status() return True - except ConnectionRefusedError: - # If the connection is refused, the port is not occupied + except requests.exceptions.RequestException as e: + print(f"Error deleting resource: {e}") return False - finally: - # Close the socket - sock.close() diff --git a/src/server/tasks/dbbench/__init__.py b/src/server/tasks/dbbench/__init__.py index 1e1171b..57a45f9 100644 --- a/src/server/tasks/dbbench/__init__.py +++ b/src/server/tasks/dbbench/__init__.py @@ -4,58 +4,34 @@ from src.server.task import Task, Session from src.typings import TaskOutput, SampleStatus, AgentOutputStatus -from .Interaction import Container +from .Interaction import FHIRClient big_prompt = """ -I will ask you a question, then you should help me operate a MySQL database with SQL to answer the question. +I will ask you a question, then you should help me operate a FHIR server to answer the question. You have to explain the problem and your solution to me and write down your thoughts. After thinking and explaining thoroughly, every round you can choose to operate or to answer. -your operation should be like this: +Your operation should be like this: Action: Operation -```sql -SELECT * FROM table WHERE condition; +```json +{ "resourceType": "Patient", "query": { "name": "John" } } ``` -You MUST put SQL in markdown format without any other comments. Your SQL should be in one line. -Every time you can only execute one SQL statement. I will only execute the statement in the first SQL code block. Every time you write a SQL, I will execute it for you and give you the output. -If you are done operating, and you want to commit your final answer, then write down: -Action: Answer -Final Answer: ["ANSWER1", "ANSWER2", ...] -DO NOT write this pattern unless you are sure about your answer. I expect an accurate and correct answer. -Your answer should be accurate. Your answer must be exactly the same as the correct answer. -If the question is about modifying the database, then after done operation, your answer field can be anything. -If your response cannot match any pattern I mentioned earlier, you will be judged as FAIL immediately. -Your input will be raw MySQL response, you have to deal with it by yourself. +You MUST put JSON in markdown format without any other comments. +Your query must be valid FHIR JSON. Every time you can only execute one operation. I will only execute the operation in the first JSON code block. +You are allowed to use the following FHIR API calls: Patient.Search and Observation.Search. +If you are done operating, and you want to commit your final answer, then write down: Action: Answer Final Answer: ["ANSWER1", "ANSWER2", ...] +DO NOT write this pattern unless you are sure about your answer. I expect an accurate and correct answer. +Your answer should be accurate. Your answer must be exactly the same as the correct answer. +If the question is about modifying the database, then after done operation, your answer field can be anything. +If your response cannot match any pattern I mentioned earlier, you will be judged as FAIL immediately. +Your input will be raw FHIR server response, and you have to deal with it by yourself. """ -def build_init_sql(entry): - name = entry["table"]["table_name"] - columns = ",".join( - [ - f"`{column['name']}` TEXT" - for column in entry["table"]["table_info"]["columns"] - ] - ) - column_names = ",".join( - [f"`{column['name']}`" for column in entry["table"]["table_info"]["columns"]] - ) - items = [] - items_data = () - for row in entry["table"]["table_info"]["rows"]: - item = "(" - for col in row: - item += "%s," - items_data += (col,) - item = item[:-1] + ")" - items.append(item) - items = ",".join(items) - sql = f"""CREATE DATABASE IF NOT EXISTS `{name}`; -USE `{name}`; -CREATE TABLE IF NOT EXISTS `{name}` ({columns}); -INSERT INTO `{name}` ({column_names}) VALUES {items}; -COMMIT; -""" - return sql, items_data +def build_init_resources(entry): + """ Builds FHIR resource initialization data from the entry. """ + resources = entry["table"]["table_info"]["rows"] + resource_type = entry["table"]["table_name"] + return resource_type, resources class DBBench(Task): @@ -72,50 +48,59 @@ def __init__(self, **configs): data = [json.loads(line) for line in f.readlines()] for entry in data: - if entry["type"][0] in ("INSERT", "DELETE", "UPDATE"): - ans = entry.pop("answer_md5") - else: - ans = entry.pop("label") + ans = entry.pop("sol") inp = entry self.dataset.append((inp, ans)) - self.container = Container() + self.server = FHIRClient() def get_indices(self) -> List[Any]: return list(range(len(self.dataset))) async def start_sample(self, index: int, session: Session) -> TaskOutput: - entry = self.dataset[index][0] - container = self.container - init_sql, init_data = build_init_sql(entry) - container.execute(init_sql, data=init_data) - db = entry["table"]["table_name"] + entry = self.dataset[index][0] # (inp, ans) + fhir_client = self.server + resource_type, resources = build_init_resources(entry) + + # Initialize resources + for resource in resources: + fhir_client.create_resource(resource_type, resource) + session.inject({"role": "user", "content": big_prompt}) session.inject({"role": "agent", "content": "Ok."}) - prompt = entry["description"] + "\n" + entry["add_description"] + prompt = entry["instruction"] + "\n" + entry["context"] session.inject({"role": "user", "content": prompt}) + res = (await session.action()).content or "" answer = "" finish_reason = SampleStatus.COMPLETED + try: action = re.search(r"Action: (.*?)\n", res) rounds = 0 while action and action.group(1) == "Operation" and rounds < self.max_round: - res = re.search(r"```sql\n([\s\S]*?)\n```", res) + res = re.search(r"```json\n([\s\S]*?)\n```", res) if not res: finish_reason = SampleStatus.AGENT_VALIDATION_FAILED break - sql = res.group(1).strip() - sql = sql.replace("\n", " ") - response = container.execute(sql, db) + fhir_query = json.loads(res.group(1).strip()) + print(fhir_query) + resource_type = fhir_query.get("resourceType") + query = fhir_query.get("query") + + # Perform the FHIR search operation + response = fhir_client.search_resources(resource_type, query) + if response: - session.inject({"role": "user", "content": response}) + session.inject({"role": "user", "content": json.dumps(response)}) else: - session.inject({"role": "user", "content": ""}) + session.inject({"role": "user", "content": "No results found."}) + res = await session.action() if res.status == AgentOutputStatus.AGENT_CONTEXT_LIMIT: finish_reason = SampleStatus.AGENT_CONTEXT_LIMIT break + res = res.content action = re.search(r"Action: (.*?)\n", res) rounds += 1 @@ -134,19 +119,7 @@ async def start_sample(self, index: int, session: Session) -> TaskOutput: finish_reason = SampleStatus.UNKNOWN else: error = "" - if entry["type"][0] in ("INSERT", "DELETE", "UPDATE"): - columns = ",".join( - [ - f"`{column['name']}`" - for column in entry["table"]["table_info"]["columns"] - ] - ) - md5_query = ( - f"select md5(group_concat(rowhash order by rowhash)) as hash " - f"from( SELECT substring(MD5(CONCAT_WS(',', {columns})), 1, 5) AS rowhash FROM `{db}`) as sub;" - ) - answer = container.execute(md5_query, db) - container.execute(f"drop database `{db}`") + return TaskOutput( status=finish_reason, result={ @@ -171,6 +144,7 @@ def calculate_overall(self, results: List[TaskOutput]) -> Dict[str, Any]: @property def metrics(self) -> Dict[str, Callable[[List[Dict[str, Any]], List[str]], float]]: + # Same as before but adapted for FHIR context def factory(typ): def acc(inp: List[Dict[str, Any]], tar: List[str]) -> float: correct = 0 @@ -179,31 +153,9 @@ def acc(inp: List[Dict[str, Any]], tar: List[str]) -> float: if not entry: continue ans, t = entry["answer"], entry["type"] - if t != typ and not ( - typ == "SELECT" and t not in ("INSERT", "UPDATE") - ): + if t != typ: continue - if t in ("INSERT", "DELETE", "UPDATE"): - correct += ans == cor - else: - try: - ans = list(eval(ans)) - except: - ans = [ans] - if len(ans) == 1 and len(cor) == 1: - try: - correct += float(ans[0]) == float(cor[0]) - except (ValueError, TypeError): - correct += ans[0] == cor[0] - else: - print(ans, cor) - else: - try: - cor = set(cor) - ans = set(ans) - correct += ans == cor - except: - pass + correct += ans == cor total += 1 if total == 0: print(f"WARNING: {typ} does not exist!") @@ -212,35 +164,11 @@ def acc(inp: List[Dict[str, Any]], tar: List[str]) -> float: return acc - types = [ - "other", - "counting", - "comparison", - "ranking", - "aggregation-SUM", - "aggregation-MIN", - "aggregation-MAX", - "aggregation-AVG", - "SELECT", - "INSERT", - "UPDATE", - ] - - ret = {} - for typ in types: - ret[typ + "_accuracy"] = factory(typ) - - ret["overall_cat_accuracy"] = ( - lambda inp, tar: sum( - [ - ret[typ + "_accuracy"](inp, tar) - for typ in ("SELECT", "INSERT", "UPDATE") - ] - ) - / 3 - ) - + types = ["SELECT", "CREATE", "DELETE", "UPDATE"] + ret = {typ + "_accuracy": factory(typ) for typ in types} + ret["overall_accuracy"] = lambda inp, tar: sum(ret.values()) / len(ret) return ret def release(self): - self.container.delete() + # No explicit release needed for FHIRClient + pass From 82af6b2185cc2a699a8bb1d57946de119f2b528b Mon Sep 17 00:00:00 2001 From: Danny Park Date: Fri, 13 Dec 2024 07:40:58 -0800 Subject: [PATCH 14/15] slight updates --- configs/agents/api_agents.yaml | 6 +++--- configs/agents/openai-chat.yaml | 2 +- configs/assignments/default.yaml | 8 ++++---- src/server/tasks/dbbench/Interaction.py | 1 + 4 files changed, 9 insertions(+), 8 deletions(-) diff --git a/configs/agents/api_agents.yaml b/configs/agents/api_agents.yaml index dbf5fd6..67b1254 100644 --- a/configs/agents/api_agents.yaml +++ b/configs/agents/api_agents.yaml @@ -22,10 +22,10 @@ text-davinci-002: model: "text-davinci-002" max_tokens: 512 -gpt-3.5-turbo: +gpt-4o-mini: import: "./openai-chat.yaml" parameters: - name: "gpt-3.5-turbo" + name: "gpt-4o-mini" body: - model: "gpt-3.5-turbo" + model: "gpt-4o-mini" max_tokens: 512 diff --git a/configs/agents/openai-chat.yaml b/configs/agents/openai-chat.yaml index 53eff77..912d478 100644 --- a/configs/agents/openai-chat.yaml +++ b/configs/agents/openai-chat.yaml @@ -3,7 +3,7 @@ parameters: url: https://api.openai.com/v1/chat/completions headers: Content-Type: application/json - Authorization: Bearer <% PUT-YOUR-OPENAI-KEY-HERE %> + Authorization: Bearer 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 body: temperature: 0 prompter: diff --git a/configs/assignments/default.yaml b/configs/assignments/default.yaml index 1a8e585..e5cb694 100644 --- a/configs/assignments/default.yaml +++ b/configs/assignments/default.yaml @@ -3,15 +3,15 @@ import: definition.yaml concurrency: task: dbbench-std: 5 - os-std: 5 + # os-std: 5 agent: - gpt-3.5-turbo: 5 + gpt-4o-mini: 5 assignments: # List[Assignment] | Assignment - agent: # "task": List[str] | str , "agent": List[str] | str - - gpt-3.5-turbo + - gpt-4o-mini task: - dbbench-std - - os-std + # - os-std output: "outputs/{TIMESTAMP}" diff --git a/src/server/tasks/dbbench/Interaction.py b/src/server/tasks/dbbench/Interaction.py index 767fa41..eccbff3 100644 --- a/src/server/tasks/dbbench/Interaction.py +++ b/src/server/tasks/dbbench/Interaction.py @@ -61,6 +61,7 @@ def search_resources(self, resource_type: str, search_params: Dict[str, str]) -> """ url = f"{self.base_url}{resource_type}" try: + search_params["_format"] = "json" response = self.session.get(url, params=search_params) response.raise_for_status() return response.json() From 8e1dd31808baa110c26e334d1fcb9947c0078aa9 Mon Sep 17 00:00:00 2001 From: Danny Park Date: Fri, 13 Dec 2024 07:41:49 -0800 Subject: [PATCH 15/15] slight adjustments --- configs/agents/openai-chat.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/configs/agents/openai-chat.yaml b/configs/agents/openai-chat.yaml index 912d478..53eff77 100644 --- a/configs/agents/openai-chat.yaml +++ b/configs/agents/openai-chat.yaml @@ -3,7 +3,7 @@ parameters: url: https://api.openai.com/v1/chat/completions headers: Content-Type: application/json - Authorization: Bearer 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 + Authorization: Bearer <% PUT-YOUR-OPENAI-KEY-HERE %> body: temperature: 0 prompter: