diff --git a/openalex.py b/openalex.py
index 429618e..8d73aee 100644
--- a/openalex.py
+++ b/openalex.py
@@ -7,11 +7,12 @@ from collections import defaultdict
 from datetime import datetime
 
+import matplotlib.pyplot as plt
 import numpy
 import pandas
 
 from setup import get_base_parser, get_verbosity_parser, setup_logger
-from utils import create_session
+from utils import create_session, region_map
 
 
 DEFAULT_PROTOCOL = "(clinicaltrial[Filter] NOT editorial)"
@@ -45,15 +46,18 @@ def read_dataset(fpath):
     df = pandas.read_csv(
         fpath,
-        delimiter="\t",
+        delimiter="%%",
         names=[
             "pmid",
+            "title",
             "accession",
             "abstract",
-            "pubdate",
+            "pub_types",
+            "journal_date",
+            "epub_date",
         ],
-        parse_dates=["pubdate"],
-        na_values="-",
+        parse_dates=["journal_date", "epub_date"],
+        dtype={"pmid": str},
     )
     return df
@@ -73,11 +77,14 @@ def build_cohort(args):
         logging.error("Is edirect installed?")
     # The date MUST be included in the query with [dp] (rather than
     # -mindate -maxdate) in order for 10k+ queries to work
-    # cmd = f"efetch -db pubmed -id 30553130 -format xml | xtract -pattern PubmedArticle -sep '|' -def '-' -element MedlineCitation/PMID -element AccessionNumber -element AbstractText -block PubDate -sep '-' -element Year,Month,Day > {output_file}"
-    cmd = f"esearch -db pubmed -query '({start_date}:{end_date}[dp]) AND ({protocol})' | efetch -format xml | xtract -pattern PubmedArticle -sep '|' -def '-' -element MedlineCitation/PMID -element AccessionNumber -element AbstractText -block PubDate -sep '-' -element Year,Month,Day > {output_file}"
+    # cmd = f"efetch -db pubmed -id 30553130 -format xml"
+    # cmd += f" | xtract -pattern PubmedArticle -def '' -sep '|' -tab '%%' -element MedlineCitation/PMID -element ArticleTitle -element AccessionNumber -element AbstractText -element PublicationType -block Journal -sep '-' -tab '%%' -element Year,Month -block ArticleDate -sep '-' -element Year,Month,Day > {output_file}"
+
+    cmd = f"esearch -db pubmed -query '({start_date}:{end_date}[dp]) AND ({protocol})' | efetch -format xml | xtract -pattern PubmedArticle -def '' -sep '|' -tab '%%' -element MedlineCitation/PMID -element ArticleTitle -element AccessionNumber -element AbstractText -element PublicationType -block Journal -sep '-' -tab '%%' -element Year,Month -block ArticleDate -sep '-' -element Year,Month,Day > {output_file}"
     logging.info(cmd)
     edirect.pipeline(cmd)
 
+    # could do as stringio
     df = read_dataset(output_file)
     df = split_bar(df, columns=["accession"])
     df = get_ids_from_abstract(df)
@@ -198,6 +205,27 @@ def query_openalex(args):
     )
 
 
+def make_site_map(args):
+    input_file = args.input_file
+    output_file = args.output_file
+    last_author = args.last_author
+
+    df = pandas.read_csv(input_file)
+    if last_author:
+        df = df[df.author_position == "last"]
+        title = "Last Author Affiliation by WHO Region: Trials in Pubmed 2018-2023"
+    else:
+        df = df[df.author_position == "first"]
+        title = "First Author Affiliation by WHO Region: Trials in Pubmed 2018-2023"
+
+    counts = (
+        df.groupby(["pmid", "country"]).author_name.nunique().groupby("country").sum()
+    )
+    region_map(counts)
+    plt.suptitle(title)
+    plt.savefig(output_file, bbox_inches="tight")
+
+
 if __name__ == "__main__":
     verbosity_parser = get_verbosity_parser()
     base_parser = get_base_parser()
@@ -247,6 +275,19 @@ def query_openalex(args):
         required=True,
         help="Output file name to write openalex cohort",
    )
+
+    site_map_parser = subparsers.add_parser("site_map", parents=[base_parser])
+    site_map_parser.add_argument(
+        "--output-file",
+        type=pathlib.Path,
+        required=True,
+        help="Output file to save map",
+    )
+    site_map_parser.add_argument(
+        "--last-author", action="store_true", help="Use last author rather than first"
+    )
+    site_map_parser.set_defaults(func=make_site_map)
+
     args = openalex_parser.parse_args()
     if hasattr(args, "func"):
         setup_logger(args.verbosity)
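Reviewer note: pandas treats the new multi-character `%%` delimiter as a regex and falls back to the slower python engine (with a `ParserWarning` under the default C engine). A minimal sketch of the new `read_dataset()` contract, assuming a hypothetical export file name:

```python
import pandas

# Sketch only: "pubmed_export.txt" stands in for the output_file written
# by the esearch | efetch | xtract pipeline above.
df = pandas.read_csv(
    "pubmed_export.txt",
    delimiter="%%",       # multi-char separator is parsed as a regex
    engine="python",      # make the engine fallback explicit, silencing the warning
    names=[
        "pmid", "title", "accession", "abstract",
        "pub_types", "journal_date", "epub_date",
    ],
    parse_dates=["journal_date", "epub_date"],
    dtype={"pmid": str},  # keep PMIDs as strings so they never become floats
)
```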
+ help="Output file to save map", + ) + site_map_parser.add_argument( + "--last-author", action="store_true", help="Use last author rather than first" + ) + site_map_parser.set_defaults(func=make_site_map) + args = openalex_parser.parse_args() if hasattr(args, "func"): setup_logger(args.verbosity) diff --git a/pubmed.py b/pubmed.py index 79446fb..7112cd9 100644 --- a/pubmed.py +++ b/pubmed.py @@ -2,16 +2,19 @@ import logging import multiprocessing as mp import os +import pathlib import re import shutil import sys from functools import partial from io import StringIO +import matplotlib.pyplot as plt import pandas -from setup import get_env_setting, get_full_parser, setup_logger +from setup import get_env_setting, get_full_parser, get_verbosity_parser, setup_logger from utils import ( + REGISTRY_MAP, create_session, filter_unindexed, load_trials, @@ -26,7 +29,7 @@ def analyse_metadata(args): input_file, parse_dates=["Date_enrollment", "epub_date", "journal_date"], index_col=[0], - dtype={"pmid": "str"}, + dtype={"pmid": str}, ) df = df[~(df.title.str.contains("protocol", flags=re.IGNORECASE) is True)] import code @@ -136,7 +139,8 @@ def add_pubmed_metadata(args): n = args.chunk_size df = pandas.read_csv(input_file, dtype={"pmids": "str"}) - df["source"] = df.trial_id.str[0:3] + df["source"] = df.trial_id.str[0:3].str.upper() + df.loc[df.source.str.startswith("NL"), "source"] = "NTR" unique_pmids = df.pmids.dropna().unique() try: sys.path.insert(1, os.path.dirname(shutil.which("xtract"))) @@ -166,20 +170,56 @@ def add_pubmed_metadata(args): df.to_csv(output_file) +def reported_over_time(args): + input_file = args.input_file + df = pandas.read_csv(input_file) + + fig, ax = plt.subplots(figsize=(12, 6)) + df["enrollment_year"] = pandas.to_datetime(df.Date_enrollment).dt.strftime("%Y") + df["source"] = df.source.map(REGISTRY_MAP) + counts = df.groupby(["enrollment_year", "source"]).agg( + {"trial_id": "count", "pmids": "count"} + ) + counts["pcnt"] = 100 * (counts.pmids / counts.trial_id) + to_plot = counts.reset_index().pivot(index="source", columns=["enrollment_year"])[ + "pcnt" + ] + to_plot = to_plot.sort_values("2014", ascending=False) + to_plot.plot.bar(ax=ax) + plt.legend(loc="upper left", bbox_to_anchor=(1, 1), title="Enrollment Year") + plt.title( + "Percent of trials with trial id in Pubmed Accession or Abstract by registry" + ) + plt.xlabel("Registry") + plt.ylabel("Percent (%)") + plt.xticks(rotation=45) + plt.savefig("percent_reported.png", bbox_inches="tight") + + if __name__ == "__main__": + verb = get_verbosity_parser() parent = get_full_parser() - pubmed_parser = argparse.ArgumentParser(parents=[parent]) + pubmed_parser = argparse.ArgumentParser() subparsers = pubmed_parser.add_subparsers() - query_parser = subparsers.add_parser("query") + query_parser = subparsers.add_parser("query", parents=[parent]) query_parser.set_defaults(func=trials_in_pubmed) - metadata_parser = subparsers.add_parser("metadata") + metadata_parser = subparsers.add_parser("metadata", parents=[parent]) metadata_parser.set_defaults(func=add_pubmed_metadata) - analyse_parser = subparsers.add_parser("analyse") + analyse_parser = subparsers.add_parser("analyse", parents=[parent]) analyse_parser.set_defaults(func=analyse_metadata) + reported_parser = subparsers.add_parser("percent-reported", parents=[verb]) + reported_parser.add_argument( + "--input-file", + required=True, + type=pathlib.Path, + help="Cohort file with discovered pmids", + ) + reported_parser.set_defaults(func=reported_over_time) + args = 
     if hasattr(args, "func"):
         setup_logger(args.verbosity)
diff --git a/query_ror.py b/query_ror.py
index af07428..0512369 100644
--- a/query_ror.py
+++ b/query_ror.py
@@ -7,18 +7,34 @@
 import time
 import urllib
 
+import matplotlib.pyplot as plt
 import pandas
-
-from setup import get_base_parser, get_full_parser, setup_logger
+import plotly.graph_objects as go
+import schemdraw
+from schemdraw import flow
+
+from setup import (
+    get_base_parser,
+    get_full_parser,
+    get_results_parser,
+    get_verbosity_parser,
+    setup_logger,
+)
 from utils import (
     add_suffix,
     append_safe,
     create_session,
     filter_unindexed,
+    load_glob,
     load_trials,
+    map_who,
+    match_paths,
     preprocess_trial_file,
     query,
+    region_map,
+    region_pie,
     remove_surrounding_double_quotes,
+    world_map,
 )
@@ -439,8 +455,167 @@ def update_metadata(args):
     trials.to_csv(output_name, index=False)
 
 
+def make_map(args):
+    input_files = args.input_files
+    file_filter = args.file_filter
+    plot_world = args.plot_world
+    country_column = args.country_column
+    title = args.title
+    df = load_glob(input_files, file_filter)
+
+    sources = sorted(df.source.unique())
+    if country_column not in df.columns:
+        raise argparse.ArgumentTypeError(f"Input data does not have {country_column}")
+
+    counts = df.groupby(country_column).trial_id.size()
+    if plot_world:
+        world_map(counts, country_column=country_column)
+    else:
+        region_map(counts, country_column=country_column)
+    plt.suptitle(f"{title}\nData from: {' '.join(sources)} ({file_filter})")
+    plt.savefig(f"{'_'.join(sources)}_map.png", bbox_inches="tight")
+
+
+def org_region(args):
+    input_files = args.input_files
+    df = load_glob(input_files, "ror")
+    sources = sorted(df.source.unique())
+    region_pie(df)
+    plt.suptitle(
+        f"Sponsor Type by WHO Region with Registry Data\nData from: {' '.join(sources)}"
+    )
+    plt.savefig(f"{'_'.join(sources)}_sponsor_by_region.png", bbox_inches="tight")
+
+
+def site_sponsor(args):
+    sponsor_files = args.sponsor_files
+    site_files = args.site_files
+    file_filter = args.file_filter
+
+    site_df = load_glob(site_files, file_filter)
+    site_df["who_region"] = map_who(site_df.country_ror)
+    sponsor_df = load_glob(sponsor_files, file_filter)
+    sponsor_df["sponsor_who_region"] = map_who(sponsor_df.country_ror)
+    merged = site_df.merge(
+        sponsor_df, left_on="trial_id", right_on="trial_id", how="left"
+    )
+    counts = (
+        merged.groupby(["who_region", "sponsor_who_region"])
+        .trial_id.count()
+        .reset_index()
+    )
+    # Map nodes to node ids
+    who_map = {name: index for index, name in enumerate(counts.who_region.unique())}
+    who_sponsor_map = {
+        name: index + len(who_map)
+        for index, name in enumerate(counts.sponsor_who_region.unique())
+    }
+    link = dict(
+        source=list(counts.who_region.map(who_map)),
+        target=list(counts.sponsor_who_region.map(who_sponsor_map)),
+        value=list(counts.trial_id),
+    )
+    data = go.Sankey(
+        link=link, node=dict(label=list(who_map.keys()) + list(who_sponsor_map.keys()))
+    )
+    fig = go.Figure(data)
+    sources = sorted(set(merged.source_x).intersection(set(merged.source_y)))
+    fig.update_layout(
+        title=f"Mapping Trials Sites Country to Sponsor Country by WHO Region (data from: {' '.join(sources)})",
+    )
+    fig.write_html("sankey.html")
+
+
+def flowchart(args):
+    input_files = args.input_files
+    df = load_glob(input_files, "manual")
+
+    total = df.shape[0]
+
+    individual = df.individual
+    no_manual = df.no_manual_match | df.name.isnull()
+
+    leftover = df[~(individual | no_manual)]
+
+    ror_manual = (
+        leftover.ror.isnull()
+        & leftover.name_manual.isnull()
+        & leftover.ror_manual.notnull()
+    )
+    ror_fixed = (
+        leftover.ror.notnull()
+        & leftover.name_manual.isnull()
+        & leftover.ror_manual.notnull()
+    )  # 32
+    ror_right = (
+        leftover.ror.notnull()
+        & leftover.name_manual.isnull()
+        & leftover.ror_manual.isnull()
+    )
+
+    ror_any = ror_manual | ror_fixed | ror_right
+
+    manual = leftover.name_manual.notnull()
+
+    assert (ror_any.sum() + manual.sum()) == len(leftover)
+
+    with schemdraw.Drawing() as d:
+        d.config(fontsize=10)
+        d += flow.Start(w=6, h=2).label(f"Total trials\nn={total}")
+        d += flow.Arrow().down(d.unit / 2)
+        d += (step1 := flow.Box(w=0, h=0))
+        d += flow.Arrow().down(d.unit / 2)
+        d += (step2 := flow.Box(w=0, h=0))
+
+        d += flow.Arrow().theta(-135)
+        d += (
+            flow.Box(w=6, h=4)
+            .label(f"ROR resolved\nn={ror_any.sum()}")
+            .label(f"\n\n\n\n(n={ror_fixed.sum()} ROR manually corrected)", fontsize=8)
+            .label(
+                f"\n\n\n\n\n\n\n(n={ror_manual.sum()} ROR manually resolved)",
+                fontsize=8,
+            )
+        )
+
+        d.move_from(step2.S)
+        d += flow.Arrow().theta(-45)
+        d += flow.Box(w=6, h=4).label(f"Name manually resolved\nn={manual.sum()}")
+
+        # Exclusions
+        d.config(fontsize=8)
+        d += flow.Arrow().right(d.unit / 4).at(step1.E)
+        d += flow.Box(w=6, h=1).anchor("W").label(f"Individual\nn={individual.sum()}")
+        d += flow.Arrow().right(d.unit / 4).at(step2.E)
+        d += (
+            flow.Box(w=6, h=1)
+            .anchor("W")
+            .label(f"No manual match\nn={no_manual.sum()}")
+        )
+
+    output_name = "_".join(sorted(df.source.unique()))
+    plt.savefig(f"{output_name}_flowchart")
+    orgs = leftover.organization_type
+    if "manual_org_type" in leftover.columns:
+        orgs = orgs.fillna(leftover.manual_org_type)
+    orgs.value_counts().to_csv(f"{output_name}_orgs.csv")
+
+
+def multisite(args):
+    input_files = args.input_files
+    df = load_glob(input_files, "none")
+    counts = df.groupby("trial_id").trial_id.count()
+    table = (
+        (counts > 1)
+        .value_counts()
+        .rename(index={False: "Single Site", True: "Multi-Site"})
+    )
+    output_name = "_".join(sorted(df.source.unique()))
+    table.to_csv(f"{output_name}_single_multi.csv")
+
+
 if __name__ == "__main__":
     base = get_base_parser()
+    results = get_results_parser()
     parent = get_full_parser()
     ror_parser = argparse.ArgumentParser()
     subparsers = ror_parser.add_subparsers()
@@ -514,6 +689,54 @@ def update_metadata(args):
         help="Add newly resolved ids to the mapping dict",
     )
 
+    map_parser = subparsers.add_parser("map", parents=[results])
+    map_parser.add_argument(
+        "--plot-world",
+        action="store_true",
+        help="Plot a world map, rather than by WHO region",
+    )
+    map_parser.add_argument("--title", type=str, help="Title for plot", required=True)
+    map_parser.add_argument(
+        "--country-column",
+        type=str,
+        help="Name of country column to use",
+        default="country",
+    )
+    map_parser.set_defaults(func=make_map)
+
+    org_parser = subparsers.add_parser("sponsor-org", parents=[results])
+    org_parser.set_defaults(func=org_region)
+
+    flowchart_parser = subparsers.add_parser("flowchart", parents=[results])
+    flowchart_parser.set_defaults(func=flowchart)
+
+    multisite_parser = subparsers.add_parser("multisite", parents=[results])
+    multisite_parser.set_defaults(func=multisite)
+
+    verb = get_verbosity_parser()
+    site_sponsor_parser = subparsers.add_parser("site-sponsor", parents=[verb])
+    site_sponsor_parser.add_argument(
+        "--site-files",
+        required=True,
+        action="append",
+        type=match_paths,
+        help="One or more glob patterns for matching input files",
+    )
+    site_sponsor_parser.add_argument(
+        "--sponsor-files",
+        required=True,
+        action="append",
+        type=match_paths,
+        help="One or more glob patterns for matching input files",
+    )
+    site_sponsor_parser.add_argument(
+        "--file-filter",
+        choices=["manual", "ror", "country"],
+        default="country",
+        help="Filter registry data",
+    )
+    site_sponsor_parser.set_defaults(func=site_sponsor)
+
     args = ror_parser.parse_args()
     if hasattr(args, "func"):
         setup_logger(args.verbosity)
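Reviewer note: the easiest thing to get wrong in `site_sponsor()` is the node-id arithmetic. Site regions take ids `0..n-1` and sponsor regions are offset by `len(who_map)`, so the two sides of the Sankey never share a node even when the region names coincide. A toy sketch with invented counts:

```python
import plotly.graph_objects as go

# Invented rows: (site region, sponsor region, number of trials)
rows = [("Europe", "Americas", 5), ("Africa", "Europe", 3)]
who_map = {"Europe": 0, "Africa": 1}            # site-side node ids
who_sponsor_map = {"Americas": 2, "Europe": 3}  # sponsor ids offset by len(who_map)
link = dict(
    source=[who_map[site] for site, _, _ in rows],
    target=[who_sponsor_map[sponsor] for _, sponsor, _ in rows],
    value=[n for _, _, n in rows],
)
fig = go.Figure(go.Sankey(
    link=link, node=dict(label=list(who_map) + list(who_sponsor_map))
))
```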
+ action="append", + type=match_paths, + help="One or more glob patterns for matching input files", + ) + site_sponsor_parser.add_argument( + "--file-filter", + choices=["manual", "ror", "country"], + default="country", + help="Filter registry data", + ) + site_sponsor_parser.set_defaults(func=site_sponsor) + args = ror_parser.parse_args() if hasattr(args, "func"): setup_logger(args.verbosity) diff --git a/requirements.dev.in b/requirements.dev.in index 58744bb..09bd1cd 100644 --- a/requirements.dev.in +++ b/requirements.dev.in @@ -12,11 +12,13 @@ matplotlib numpy pandas pip-tools +plotly pre-commit pytest requests requests-cache ruff schemdraw +seaborn unidecode vcrpy diff --git a/requirements.dev.txt b/requirements.dev.txt index 0b1a65a..6eb1ed5 100644 --- a/requirements.dev.txt +++ b/requirements.dev.txt @@ -68,7 +68,9 @@ iniconfig==2.0.0 kiwisolver==1.4.5 # via matplotlib matplotlib==3.8.0 - # via -r requirements.dev.in + # via + # -r requirements.dev.in + # seaborn multidict==6.0.4 # via yarl mypy-extensions==1.0.0 @@ -81,6 +83,7 @@ numpy==1.26.0 # contourpy # matplotlib # pandas + # seaborn # shapely packaging==23.1 # via @@ -88,12 +91,14 @@ packaging==23.1 # build # geopandas # matplotlib + # plotly # pytest pandas==2.1.0 # via # -r requirements.dev.in # country-converter # geopandas + # seaborn pathspec==0.11.2 # via black pillow==10.0.1 @@ -105,6 +110,8 @@ platformdirs==3.10.0 # black # requests-cache # virtualenv +plotly==5.17.0 + # via -r requirements.dev.in pluggy==1.3.0 # via pytest pre-commit==3.4.0 @@ -137,6 +144,8 @@ ruff==0.0.290 # via -r requirements.dev.in schemdraw==0.17 # via -r requirements.dev.in +seaborn==0.13.0 + # via -r requirements.dev.in shapely==2.0.1 # via geopandas six==1.16.0 @@ -144,6 +153,8 @@ six==1.16.0 # fiona # python-dateutil # url-normalize +tenacity==8.2.3 + # via plotly tomli==2.0.1 # via # black diff --git a/results.py b/results.py deleted file mode 100644 index e7a3c87..0000000 --- a/results.py +++ /dev/null @@ -1,287 +0,0 @@ -import argparse -import glob -import pathlib -from itertools import chain - -import geopandas -import matplotlib.pyplot as plt -import pandas -import schemdraw -from schemdraw import flow - -from setup import get_verbosity_parser, setup_logger -from utils import ( - add_suffix, - convert_country_simple, - map_who, -) - - -def get_resolved_filter(df): - assert (df.individual.dtype == "bool") and (df.no_manual_match.dtype == "bool") - return (~df.individual) & (~df.no_manual_match) - - -def normalize_name(df): - resolved_filter = get_resolved_filter(df) - df.loc[resolved_filter, "name_normalized"] = df.name_manual.fillna(df.name_resolved) - return df - - -def clean_individual_manual(df): - df["individual"] = df["individual"].fillna(0).astype(bool) - df["no_manual_match"] = df["no_manual_match"].fillna(0).astype(bool) - return df - - -def flowchart(args, df): - df = clean_individual_manual(df) - df = normalize_name(df) - total = df.shape[0] - - individual = df.individual - no_manual = df.no_manual_match | df.name.isnull() - - leftover = df[~(individual | no_manual)] - - ror_manual = ( - leftover.ror.isnull() - & leftover.name_manual.isnull() - & leftover.ror_manual.notnull() - ) - ror_fixed = ( - leftover.ror.notnull() - & leftover.name_manual.isnull() - & leftover.ror_manual.notnull() - ) # 32 - ror_right = ( - leftover.ror.notnull() - & leftover.name_manual.isnull() - & leftover.ror_manual.isnull() - ) - - ror_any = ror_manual | ror_fixed | ror_right - - manual = leftover.name_manual.notnull() - - assert (ror_any.sum() + 
-    ror_fixed = (
-        leftover.ror.notnull()
-        & leftover.name_manual.isnull()
-        & leftover.ror_manual.notnull()
-    )  # 32
-    ror_right = (
-        leftover.ror.notnull()
-        & leftover.name_manual.isnull()
-        & leftover.ror_manual.isnull()
-    )
-
-    ror_any = ror_manual | ror_fixed | ror_right
-
-    manual = leftover.name_manual.notnull()
-
-    assert (ror_any.sum() + manual.sum()) == len(leftover)
-
-    with schemdraw.Drawing() as d:
-        d.config(fontsize=10)
-        d += flow.Start(w=6, h=2).label(f"Total trials\nn={total}")
-        d += flow.Arrow().down(d.unit / 2)
-        d += (step1 := flow.Box(w=0, h=0))
-        d += flow.Arrow().down(d.unit / 2)
-        d += (step2 := flow.Box(w=0, h=0))
-
-        d += flow.Arrow().theta(-135)
-        d += (
-            flow.Box(w=6, h=4)
-            .label(f"ROR resolved\nn={ror_any.sum()}")
-            .label(f"\n\n\n\n(n={ror_fixed.sum()} ROR manually corrected)", fontsize=8)
-            .label(
-                f"\n\n\n\n\n\n\n(n={ror_manual.sum()} ROR manually resolved)",
-                fontsize=8,
-            )
-        )
-
-        d.move_from(step2.S)
-        d += flow.Arrow().theta(-45)
-        d += flow.Box(w=6, h=4).label(f"Name manually resolved\nn={manual.sum()}")
-
-        # Exclusions
-        d.config(fontsize=8)
-        d += flow.Arrow().right(d.unit / 4).at(step1.E)
-        d += flow.Box(w=6, h=1).anchor("W").label(f"Individual\nn={individual.sum()}")
-        d += flow.Arrow().right(d.unit / 4).at(step2.E)
-        d += (
-            flow.Box(w=6, h=1)
-            .anchor("W")
-            .label(f"No manual match\nn={no_manual.sum()}")
-        )
-
-    output_name = "_".join(df.source.unique())
-    plt.savefig(f"{output_name}_flowchart")
-    orgs = leftover.organization_type.fillna(leftover.manual_org_type)
-    orgs.groupby(orgs).count().sort_values(ascending=False).to_csv(
-        f"{output_name}_orgs.csv"
-    )
-
-
-def sponsor_map(args, df):
-    df = clean_individual_manual(df)
-    df = normalize_name(df)
-    column_to_map = "trial_id"
-
-    world = geopandas.read_file(geopandas.datasets.get_path("naturalearth_lowres"))
-    fig, ax = plt.subplots(1, 1, figsize=(15, 10))
-    world.boundary.plot(ax=ax)
-
-    df["iso_a3"] = convert_country_simple(df["country"], to="ISO3")
-    counts = df.groupby("iso_a3")[column_to_map].size()
-    merged = world.merge(counts, on="iso_a3")
-
-    world.boundary.plot(ax=ax)
-    merged.plot(
-        column=column_to_map,
-        cmap="YlOrRd",
-        ax=ax,
-        legend=True,
-        legend_kwds={"label": "Number of Trials"},
-    )
-    ax.set_xticklabels([])
-    ax.set_yticklabels([])
-    output_name = "_".join(df.source.unique())
-    registries = " and ".join(df.source.str.upper().unique())
-    ax.set_title(f"{registries} Sponsors by Country")
-    plt.savefig(f"{output_name}_map")
-
-
-def site_map(args, df):
-    output_name = "_".join(df.source.unique())
-    df["who_region"] = map_who(df.country)
-
-    world = geopandas.read_file(geopandas.datasets.get_path("naturalearth_lowres"))
-    world["who_region"] = map_who(world.iso_a3)
-    fig, ax = plt.subplots(1, 1, figsize=(15, 10))
-    world.boundary.plot(ax=ax)
-
-    # Trial sites per WHO region (sites with multiple trials counted multiple times)
-    counts = df.groupby("who_region").trial_id.count()
-    merged = world.merge(counts, on="who_region")
-
-    world.boundary.plot(ax=ax)
-    merged.plot(
-        column="trial_id",
-        cmap="YlOrRd",
-        ax=ax,
-        legend=True,
-        legend_kwds={"label": "Number of Trials"},
-    )
-    ax.set_xticklabels([])
-    ax.set_yticklabels([])
-    output_name = "_".join(df.source.unique())
-    registries = " and ".join(df.source.str.upper().unique())
-    ax.set_title(f"{registries} Trial Sites by WHO Region")
-    plt.savefig(f"{output_name}_trial_map")
-
-
-def write_table(df, output_dir, input_file, suffix):
-    path = add_suffix(output_dir, input_file, suffix)
-    df.to_csv(path)
-
-
-def over_time(args, df):
-    df["enrollment_year"] = pandas.to_datetime(df.enrollment_date).dt.strftime("%Y")
-    df["registration_year"] = pandas.to_datetime(df.registration_date).dt.strftime("%Y")
-    by_enroll_date = df.groupby("enrollment_year").trial_id.count().reset_index()
-    by_reg_date = df.groupby("registration_year").trial_id.count().reset_index()
-
-    fig, ax = plt.subplots(figsize=(10, 6))
-    ax.bar(
-        by_reg_date["registration_year"].astype(int) - 0.2,
-        by_reg_date["trial_id"],
-        0.4,
-        label="Registration year",
-    )
-    ax.bar(
-        by_enroll_date["enrollment_year"].astype(int) + 0.2,
-        by_enroll_date["trial_id"],
-        0.4,
-        label="Enrollment year",
-    )
-    ax.legend(bbox_to_anchor=(1.3, 1.05))
-    registries = " and ".join(df.source.str.upper().unique())
-    ax.set_title(f"New trials enrolled or registered in {registries}")
-    fig.tight_layout()
-    output_name = "_".join(df.source.unique())
-    plt.savefig(f"{output_name}_registrations_over_time")
-
-
-def sites(args, df):
-    counts = df.groupby("trial_id").trial_id.count()
-    table = (
-        (counts > 1)
-        .value_counts()
-        .rename(index={False: "Single Site", True: "Multi-Site"})
-    )
-    output_name = "_".join(df.source.unique())
-    table.to_csv(f"{output_name}_single_multi.csv")
-
-
-def get_path(*args):
-    return pathlib.Path(*args).resolve()
-
-
-def match_paths(pattern):
-    return [get_path(x) for x in glob.glob(pattern)]
-
-
-if __name__ == "__main__":
-    verb = get_verbosity_parser()
-    results_parser = argparse.ArgumentParser()
-    subparsers = results_parser.add_subparsers()
-
-    map_parser = subparsers.add_parser("map", parents=[verb])
-    map_parser.add_argument(
-        "--input-files",
-        required=True,
-        action="append",
-        type=match_paths,
-        help="One or more glob patterns for matching input files",
-    )
-    map_parser.set_defaults(func=sponsor_map)
-
-    site_map_parser = subparsers.add_parser("site_map", parents=[verb])
-    site_map_parser.add_argument(
-        "--input-files",
-        required=True,
-        action="append",
-        type=match_paths,
-        help="One or more glob patterns for matching input files",
-    )
-    site_map_parser.set_defaults(func=site_map)
-
-    flowchart_parser = subparsers.add_parser("flowchart", parents=[verb])
-    flowchart_parser.add_argument(
-        "--input-files",
-        required=True,
-        action="append",
-        type=match_paths,
-        help="One or more glob patterns for matching input files",
-    )
-    flowchart_parser.set_defaults(func=flowchart)
-
-    time_parser = subparsers.add_parser("time", parents=[verb])
-    time_parser.add_argument(
-        "--input-files",
-        required=True,
-        action="append",
-        type=match_paths,
-        help="One or more glob patterns for matching input files",
-    )
-    time_parser.set_defaults(func=over_time)
-
-    sites_parser = subparsers.add_parser("sites", parents=[verb])
-    sites_parser.add_argument(
-        "--input-files",
-        required=True,
-        action="append",
-        type=match_paths,
-        help="One or more glob patterns for matching input files",
-    )
-    sites_parser.set_defaults(func=sites)
-
-    args = results_parser.parse_args()
-    filenames_flat = list(chain.from_iterable(args.input_files))
-    frames = []
-    for input_file in filenames_flat:
-        temp = pandas.read_csv(input_file)
-        frames.append(temp)
-    df = pandas.concat(frames, ignore_index=True)
-
-    if hasattr(args, "func"):
-        setup_logger(args.verbosity)
-        args.func(args, df)
-    else:
-        results_parser.print_help()
diff --git a/setup.py b/setup.py
index 4e206e1..15e2a07 100644
--- a/setup.py
+++ b/setup.py
@@ -3,6 +3,8 @@
 import pathlib
 from os import environ
 
+from utils import match_paths
+
 
 def setup_logger(verbosity):
     logging_level = logging.ERROR
@@ -43,6 +45,25 @@ def get_verbosity_parser():
     return verb
 
 
+def get_results_parser():
+    verb = get_verbosity_parser()
+    res = argparse.ArgumentParser(add_help=False, parents=[verb])
+    res.add_argument(
+        "--input-files",
+        required=True,
+        action="append",
+        type=match_paths,
+        help="One or more glob patterns for matching input files",
+    )
+    res.add_argument(
+        "--file-filter",
+        choices=["manual", "ror", "country"],
+        default="country",
+        help="Filter registry data",
+    )
+    return res
+
+
 def get_base_parser():
     verb = get_verbosity_parser()
     base = argparse.ArgumentParser(add_help=False, parents=[verb])
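Reviewer note: with the shared flags factored into `setup.py`, every results-style subcommand inherits `--input-files` and `--file-filter` by parser inheritance. A rough sketch of how a caller sees this (the glob pattern is hypothetical):

```python
import argparse
from setup import get_results_parser

results = get_results_parser()
parser = argparse.ArgumentParser()
subparsers = parser.add_subparsers()
subparsers.add_parser("map", parents=[results])

# --input-files is action="append" with a glob-expanding type, so each
# repeated flag contributes a *list* of resolved paths:
args = parser.parse_args(["map", "--input-files", "data/*_sites.csv"])
# args.input_files -> [[PosixPath(...), ...]]; args.file_filter -> "country"
```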
diff --git a/utils.py b/utils.py
index 535e157..57248f1 100644
--- a/utils.py
+++ b/utils.py
@@ -1,18 +1,25 @@
 import difflib
+import glob
 import html
 import json
 import logging
+import pathlib
 import re
 import sys
 import time
 from ast import literal_eval
+from itertools import chain
 
 import country_converter as coco
+import geopandas
+import matplotlib.pyplot as plt
 import numpy
 import pandas
 import requests
 import requests_cache
+import seaborn as sns
 from requests_cache import NEVER_EXPIRE, CachedSession
+from shapely.geometry import MultiPolygon
 from unidecode import unidecode
@@ -36,6 +43,99 @@
     "NOT (animals [mh] NOT humans [mh])"
 )
 
+REGISTRY_MAP = {
+    "ACT": "ANZCTR",
+    "CTR": "CRiS",
+    "CHI": "ChiCTR",
+    "DRK": "DRKS",
+    "EUC": "EUCTR",
+    "IRC": "IRCT",
+    "ITM": "ITMCTR",
+    "JRP": "JPRN",
+    "KCT": "KCTR",
+    "LBC": "LBCTR",
+    "NCT": "ClinicalTrials.gov",
+    "NTR": "NTR",
+    "PAC": "PACTRN",
+    "PER": "REPEC",
+    "RBR": "ReBec",
+    "RPC": "RPCEC",
+    "SLC": "SLCTR",
+    "TCT": "TCTR",
+}
+
+
+def get_path(*args):
+    return pathlib.Path(*args).resolve()
+
+
+def match_paths(pattern):
+    return [get_path(x) for x in glob.glob(pattern)]
+
+
+def load_glob(filenames, file_filter):
+    filenames_flat = list(chain.from_iterable(filenames))
+    frames = []
+    for input_file in filenames_flat:
+        df = pandas.read_csv(input_file)
+
+        # NOTE: have not removed individual/no manual/ror wrong
+        # NOTE: ror metadata i.e. country_ror might be wrong (if ror_wrong)
+        if file_filter == "manual":
+            if (
+                len(
+                    set(
+                        [
+                            "individual",
+                            "no_manual_match",
+                            "name_manual",
+                            "name_resolved",
+                        ]
+                    )
+                    - set(df.columns)
+                )
+                == 0
+            ):
+                df["individual"] = df["individual"].fillna(0).astype(bool)
+                df["no_manual_match"] = df["no_manual_match"].fillna(0).astype(bool)
+                resolved_filter = (~df.individual) & (~df.no_manual_match)
+                df.loc[resolved_filter, "name_normalized"] = df.name_manual.fillna(
+                    df.name_resolved
+                )
+                if "manual_org_type" in df.columns:
+                    # Prefer manual
+                    df.organization_type = df.manual_org_type.fillna(
+                        df.organization_type
+                    )
+            else:
+                logging.info(f"Skipping {input_file}: has not been manually resolved")
+                continue
+
+        elif file_filter == "ror":
+            # TODO: use manual fixes if they exist?
+            # Filter for those that ror resolved
+            if "ror" in df.columns and "organization_type" in df.columns:
+                df = df[df.ror.notnull()]
+                # TODO: could exclude Company here (high error rate)
+                # df = df[df.organization_type != "Company"]
+            else:
+                logging.info(f"Skipping {input_file}: does not have ROR columns")
+                continue
+        elif file_filter == "country":
+            # Skip the dataset if it has no country data
+            if "country" not in df.columns:
+                logging.info(f"Skipping {input_file}: has no country data")
+                continue
+
+        # TODO: do we need to merge so they have the same columns? Fillna
+        logging.info(f"Adding {input_file}")
+        frames.append(df)
+    if len(frames) > 0:
+        return pandas.concat(frames, ignore_index=True)
+    else:
+        logging.error(f"No data passed the {file_filter} filter")
+        sys.exit(1)
+
 
 def append_safe(df, filepath):
     if pandas.io.common.file_exists(filepath):
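Reviewer note: `load_glob()` expects the nested list that `action="append"` produces, one inner list per `--input-files` flag. Calling it directly looks roughly like this (the file name is hypothetical):

```python
from utils import load_glob, match_paths

# One list per --input-files flag, each a list of resolved paths.
# Files failing the filter are skipped with a log line; if nothing
# survives, load_glob logs an error and exits via sys.exit(1).
df = load_glob([match_paths("data/anzctr_sponsors.csv")], file_filter="manual")
```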
@@ -421,6 +521,7 @@ def map_country(country_column):
 
 
 # https://ourworldindata.org/grapher/who-regions
+# TODO: country codes that are not listed as WHO countries
 def map_who(country_column):
     """
     Map country to WHO region
@@ -621,3 +722,134 @@ def preprocess_trial_file(args):
             )
         ]
     ].to_csv(output_dir / f"{source}_sites.csv")
+
+
+def world_map(counts, country_column="country", legend_title="Number of Trials"):
+    """
+    Counts is a series indexed by iso2 country
+    """
+    world = geopandas.read_file(geopandas.datasets.get_path("naturalearth_lowres"))
+    world["country"] = convert_country_simple(world["iso_a3"], to="iso2")
+    fig, ax = plt.subplots(1, 1, figsize=(16, 10))
+    world.boundary.plot(ax=ax)
+
+    column_name = counts.name
+    counts = counts.reset_index()
+
+    merged = world.merge(counts, left_on="country", right_on=country_column)
+    merged.plot(
+        column=column_name,
+        cmap="YlOrRd",
+        ax=ax,
+        legend=True,
+        legend_kwds={"label": f"{legend_title}"},
+    )
+    ax.set_xticklabels([])
+    ax.set_yticklabels([])
+
+
+def region_map(counts, country_column="country", legend_title="Number of Trials"):
+    """
+    Counts is a series indexed by iso2 country
+    """
+    world = geopandas.read_file(geopandas.datasets.get_path("naturalearth_lowres"))
+    world["country"] = convert_country_simple(world["iso_a3"], to="iso2")
+    world["who_region"] = map_who(world["country"])
+
+    # Remove geometries that leave large gaps/impact scaling
+    world.loc[world["country"] == "FR", "geometry"] = (
+        world[world["country"] == "FR"].iloc[0].geometry.geoms[1]
+    )
+    world.loc[world["country"] == "FJ", "geometry"] = MultiPolygon(
+        list(world[world["country"] == "FJ"].iloc[0].geometry.geoms)[0:2]
+    )
+    ru_shapes = list(world[world["country"] == "RU"].iloc[0].geometry.geoms)
+    world.loc[world["country"] == "RU", "geometry"] = MultiPolygon(
+        ru_shapes[0:10] + ru_shapes[13:]
+    )
+
+    column_name = counts.name
+    counts = counts.reset_index()
+    merged = world.merge(counts, left_on="country", right_on=country_column)
+
+    fig, axs = plt.subplots(2, 3, figsize=(20, 10))
+    axs = axs.flat
+
+    for i, region in enumerate(merged.groupby("who_region")):
+        ax = axs[i]
+        region_name, region_df = region
+        # NOTE: plot WHOLE region boundary, not just those with counts
+        region_boundary = world[world.who_region == region_name]
+        region_boundary.boundary.plot(ax=ax)
+        region_df.plot(
+            column=column_name,
+            cmap="YlOrRd",
+            ax=ax,
+            legend=True,
+            legend_kwds={"label": f"{legend_title}"},
+        )
+        ax.set_title(f"{region_name} Trial Sites")
+        ax.set_xticklabels([])
+        ax.set_yticklabels([])
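Reviewer note: both map helpers take a *named* Series of counts keyed by ISO2 code; the Series name becomes the plotted column after `reset_index()`, and the index name must match `country_column` for the merge. A sketch with invented counts:

```python
import pandas
from utils import region_map

counts = pandas.Series({"GB": 120, "KE": 15, "BR": 40}, name="trial_id")
counts.index.name = "country"  # must match country_column for the merge
region_map(counts)             # draws one panel per WHO region
```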
+ df["who_region"] = map_who(df["country"]) + grouped = df.groupby("who_region") + + orgs = df.organization_type.unique() + colors = dict(zip(orgs, sns.color_palette("colorblind", len(orgs)))) + + fig, axs = plt.subplots(2, 3, figsize=(20, 10)) + axs = axs.flat + + for i, region in enumerate(grouped): + region_name, data = region + ax = axs[i] + counts = ( + data.groupby("organization_type") + .trial_id.count() + .sort_values(ascending=False) + ) + labels = [f"{label} {count}" for label, count in counts.items()] + region_colors = list(counts.index.map(colors)) + ax.pie( + counts, + labels=labels, + startangle=140, + colors=region_colors, + labeldistance=None, + ) + ax.legend(bbox_to_anchor=(1.0, 0, 0.5, 1)) + ax.set_title(f"{region_name}") + + +def over_time(df, column="trial_id"): + df["enrollment_year"] = pandas.to_datetime(df.enrollment_date).dt.strftime("%Y") + df["registration_year"] = pandas.to_datetime(df.registration_date).dt.strftime("%Y") + by_enroll_date = df.groupby("enrollment_year")[column].count().reset_index() + by_reg_date = df.groupby("registration_year")[column].count().reset_index() + + fig, ax = plt.subplots(figsize=(10, 6)) + ax.bar( + by_reg_date["registration_year"].astype(int) - 0.2, + by_reg_date["trial_id"], + 0.4, + label="Registration year", + ) + ax.bar( + by_enroll_date["enrollment_year"].astype(int) + 0.2, + by_enroll_date["trial_id"], + 0.4, + label="Enrollment year", + ) + ax.legend(bbox_to_anchor=(1.3, 1.05)) + registries = " and ".join(df.source.str.upper().unique()) + ax.set_title(f"New trials enrolled or registered in {registries}") + fig.tight_layout() + output_name = "_".join(df.source.unique()) + plt.savefig(f"{output_name}_registrations_over_time")