diff --git a/streamlit_folium/__init__.py b/streamlit_folium/__init__.py index b62abaf..a9bc74c 100644 --- a/streamlit_folium/__init__.py +++ b/streamlit_folium/__init__.py @@ -9,6 +9,7 @@ import branca import folium +import folium.elements import folium.plugins import streamlit as st import streamlit.components.v1 as components @@ -22,6 +23,7 @@ _component_func = components.declare_component( "st_folium", url="http://localhost:3001" ) + else: parent_dir = os.path.dirname(os.path.abspath(__file__)) build_dir = os.path.join(parent_dir, "frontend/build") @@ -367,6 +369,8 @@ def bounds_to_dict(bounds_list: list[list[float]]) -> dict[str, dict[str, float] st.code(layer_control_string) def walk(fig): + if isinstance(fig, branca.colormap.ColorMap): + yield fig if isinstance(fig, folium.plugins.DualMap): yield from walk(fig.m1) yield from walk(fig.m2) @@ -376,10 +380,16 @@ def walk(fig): for child in fig._children.values(): yield from walk(child) - css_links = [] - js_links = [] + css_links: list[str] = [] + js_links: list[str] = [] for elem in walk(folium_map): + if isinstance(elem, branca.colormap.ColorMap): + # manually add d3.js + js_links.insert( + 0, "https://cdnjs.cloudflare.com/ajax/libs/d3/3.5.5/d3.min.js" + ) + js_links.insert(0, "https://d3js.org/d3.v4.min.js") css_links.extend([href for _, href in elem.default_css]) js_links.extend([src for _, src in elem.default_js]) diff --git a/streamlit_folium/frontend/src/index.tsx b/streamlit_folium/frontend/src/index.tsx index 9d05c43..bcf54b4 100644 --- a/streamlit_folium/frontend/src/index.tsx +++ b/streamlit_folium/frontend/src/index.tsx @@ -144,7 +144,7 @@ function onLayerClick(e: any) { debouncedUpdateComponentValue(window.map) } -function getPixelatedStyles(pixelated: boolean) { +function getPixelatedStyles(pixelated: boolean) { if (pixelated) { const styles = ` .leaflet-image-layer { @@ -164,7 +164,6 @@ function getPixelatedStyles(pixelated: boolean) { } ` return styles - } window.initComponent = (map: any, return_on_hover: boolean) => { @@ -190,7 +189,7 @@ window.initComponent = (map: any, return_on_hover: boolean) => { * the component is initially loaded, and then again every time the * component gets new data from Python. */ -function onRender(event: Event): void { +async function onRender(event: Event) { // Get the RenderData from the event const data = (event as CustomEvent).detail @@ -209,30 +208,54 @@ function onRender(event: Event): void { const layer_control: string = data.args["layer_control"] const pixelated: boolean = data.args["pixelated"] - var finalizeOnRender = () => { + // load scripts + const loadScripts = async () => { + for (const link of js_links) { + // use promise to load scripts synchronously + await new Promise((resolve, reject) => { + const script = document.createElement("script") + script.src = link + script.async = false + script.onload = resolve + script.onerror = reject + window.document.body.appendChild(script) + }) + } + + css_links.forEach((link) => { + const linkTag = document.createElement("link") + linkTag.rel = "stylesheet" + linkTag.href = link + window.document.head.appendChild(linkTag) + }) + + const style = document.createElement("style") + style.innerHTML = getPixelatedStyles(pixelated) + window.document.head.appendChild(style) + } + + // finalize rendering + const finalizeOnRender = () => { if ( feature_group !== window.__GLOBAL_DATA__.last_feature_group || layer_control !== window.__GLOBAL_DATA__.last_layer_control ) { + // remove previous feature group and layer control if (window.feature_group && window.feature_group.length > 0) { window.feature_group.forEach((layer: Layer) => { - window.map.removeLayer(layer); - }); + window.map.removeLayer(layer) + }) } if (window.layer_control) { window.map.removeControl(window.layer_control) } + // update feature group and layer control cache window.__GLOBAL_DATA__.last_feature_group = feature_group window.__GLOBAL_DATA__.last_layer_control = layer_control - if (feature_group){ - // Though using `eval` is generally a bad idea, we're using it here - // because we're evaluating code that we've generated ourselves on the - // Python side. This is safe because we're not evaluating user input, so this - // couldn't be used to execute arbitrary code. - + if (feature_group) { // eslint-disable-next-line eval(feature_group + layer_control) for (let key in window.map._layers) { @@ -296,7 +319,6 @@ function onRender(event: Event): void { document.body.appendChild(a) } - const render_script = document.createElement("script") // HACK -- update the folium-generated JS to add, most importantly, // the map to this global variable so that it can be used elsewhere // in the script. @@ -322,60 +344,27 @@ function onRender(event: Event): void { parent_div?.classList.remove("single") parent_div?.classList.add("double") } + } + await loadScripts().then(() => { + const render_script = document.createElement("script") - // This is only loaded once, from the onload callback - var postLoad = () => { - if (!window.map) { - render_script.innerHTML = + if (!window.map) { + render_script.innerHTML = script + - `window.map = map_div; window.initComponent(map_div, ${return_on_hover});` - document.body.appendChild(render_script) - const html_div = document.createElement("div") - html_div.innerHTML = html - document.body.appendChild(html_div) - const styles = getPixelatedStyles(pixelated) - var styleSheet = document.createElement("style") - styleSheet.innerText = styles - document.head.appendChild(styleSheet) - } - finalizeOnRender(); - } - - if (js_links.length === 0) { - postLoad(); - } else { - // make sure dependent js files are loaded - // before we initialize the component - var count = 0; - js_links.forEach((elem) => { - var scr = document.createElement('script'); - scr.src = elem; - scr.async = false; - scr.onload = () => { - count -= 1; - if(count === 0) { - setTimeout(postLoad, 0); - } - }; - document.head.appendChild(scr); - count += 1; - }); + `window.map = map_div; window.initComponent(map_div, ${return_on_hover});` + document.body.appendChild(render_script) + const html_div = document.createElement("div") + html_div.innerHTML = html + document.body.appendChild(html_div) + const styles = getPixelatedStyles(pixelated) + var styleSheet = document.createElement("style") + styleSheet.innerText = styles + document.head.appendChild(styleSheet) } - - // css is okay regardless loading order - css_links.forEach((elem) => { - var link = document.createElement('link'); - link.rel = "stylesheet"; - link.type = "text/css"; - link.href = elem; - document.head.appendChild(link); - }); - Streamlit.setFrameHeight() - } - } else { - finalizeOnRender(); + finalizeOnRender() + }) } - + finalizeOnRender() } // Attach our `onRender` handler to Streamlit's render event.