From 6f48e90c39abd5b944a4731d1ab85570e7e5edbe Mon Sep 17 00:00:00 2001 From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com> Date: Tue, 15 Nov 2022 09:03:45 -0800 Subject: [PATCH 01/19] Added `diffrax.citation` --- README.md | 2 + diffrax/__init__.py | 1 + diffrax/adjoint.py | 12 + diffrax/autocitation.py | 547 +++++++++++++++++++++++++++++++ diffrax/heuristics.py | 28 +- docs/api/citation.md | 5 + docs/further_details/citation.md | 2 + mkdocs.yml | 1 + 8 files changed, 592 insertions(+), 6 deletions(-) create mode 100644 diffrax/autocitation.py create mode 100644 docs/api/citation.md diff --git a/README.md b/README.md index 7330e3b3..e3f0e3d3 100644 --- a/README.md +++ b/README.md @@ -65,4 +65,6 @@ Neural networks: [Equinox](https://github.com/patrick-kidger/equinox). Type annotations and runtime checking for PyTrees and shape/dtype of JAX arrays: [jaxtyping](https://github.com/google/jaxtyping). +Computer vision models: [Eqxvision](https://github.com/paganpasta/eqxvision). + SymPy<->JAX conversion; train symbolic expressions via gradient descent: [sympy2jax](https://github.com/google/sympy2jax). diff --git a/diffrax/__init__.py b/diffrax/__init__.py index a038518a..403cb382 100644 --- a/diffrax/__init__.py +++ b/diffrax/__init__.py @@ -5,6 +5,7 @@ NoAdjoint, RecursiveCheckpointAdjoint, ) +from .autocitation import citation, citation_rules from .brownian import AbstractBrownianPath, UnsafeBrownianPath, VirtualBrownianTree from .event import ( AbstractDiscreteTerminatingEvent, diff --git a/diffrax/adjoint.py b/diffrax/adjoint.py index 981447be..889c52b2 100644 --- a/diffrax/adjoint.py +++ b/diffrax/adjoint.py @@ -125,6 +125,18 @@ class RecursiveCheckpointAdjoint(AbstractAdjoint): In addition a binomial checkpointing scheme is used so that memory usage is low. (This checkpointing can increase compile time a bit, though.) + + !!! Reference + + Binomial checkpointing (also known as "treeverse") was introduced in: + ```bibtex + @article{griewank1998treeverse, + title = {Treeverse: An Implementation of Checkpointing for the Reverse or + Adjoint Mode of Computational Differentiation} + author = {Griewank, Andreas and Walther, Andrea}, + year = {1998}, + } + ``` """ def loop(self, *, throw, passed_solver_state, passed_controller_state, **kwargs): diff --git a/diffrax/autocitation.py b/diffrax/autocitation.py new file mode 100644 index 00000000..92f2f974 --- /dev/null +++ b/diffrax/autocitation.py @@ -0,0 +1,547 @@ +import functools as ft +import inspect +import re +from typing import Callable, Optional, Sequence + +import jax +import jax.tree_util as jtu + +from .adjoint import BacksolveAdjoint, RecursiveCheckpointAdjoint +from .brownian import VirtualBrownianTree +from .heuristics import is_cde, is_sde +from .integrate import diffeqsolve +from .misc import adjoint_rms_seminorm +from .solver import ( + AbstractImplicitSolver, + Dopri5, + Dopri8, + Kvaerno3, + Kvaerno4, + Kvaerno5, + LeapfrogMidpoint, + ReversibleHeun, + SemiImplicitEuler, + Tsit5, +) +from .step_size_controller import PIDController + + +def citation(*args, **kwargs): + """Autogenerate a list of BibTeX references for the numerical methods being used. + + **Arguments:** + + `citation` may be called with any subset of the argments to + [`diffrax.diffeqsolve`][]. To generate the citation list it may be easiest + to simply replace `diffeqsolve` with `citation`. + + **Returns:** + + Nothing. Prints a BibTeX file to stdout. + + !!! Example + + ```python + from diffrax import citation, Dopri5, PIDController + citation(solver=Dopri5(), + stepsize_controller=PIDController(pcoeff=0.4, rtol=1e-3, atol=1e-6)) + # % --- AUTOGENERATED REFERENCES PRODUCED USING `diffrax.citation(...)` --- + # % The following references were found for the numerical techniques being used. + # % This does not cover e.g. any modelling techniques being used. + # + # ... + # ... Full output truncated in this example! + # ... Here's what the final entry looks like: + # ... + # + # % The use of a PI-controller to adapt step sizes is from Section IV.2 of: + # @book{hairer2002solving-ii, + # address={Berlin}, + # author={Hairer, E. and Wanner, G.}, + # edition={Second Revised Edition}, + # publisher={Springer}, + # title={{S}olving {O}rdinary {D}ifferential {E}quations {II} {S}tiff and + # {D}ifferential-{A}lgebraic {P}roblems}, + # year={2002} + # } + # % and Sections 1--3 of: + # @article{soderlind2002automatic, + # title={Automatic control and adaptive time-stepping}, + # author={Gustaf S{\"o}derlind}, + # year={2002}, + # journal={Numerical Algorithms}, + # volume={31}, + # pages={281--310} + # } + # + # % --- END AUTOGENERATED REFERENCES --- + ``` + + """ + bound = _diffeqsignature.bind_partial(*args, **kwargs) + kwargs = dict(bound.kwargs) + for arg_name, arg_value in zip(_diffeqsignature.parameters.keys(), bound.args): + kwargs[arg_name] = arg_value + cites = [] + cites.append(_start) + for rule in citation_rules: + rule_parameters = list(inspect.signature(rule).parameters.values()) + needed_keys = set() + has_var = False + for param in rule_parameters: + if param.kind == inspect.Parameter.VAR_KEYWORD: + has_var = True + else: + assert param.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD + if param.default is inspect.Parameter.empty: + needed_keys.add(param.name) + if not set(kwargs).issuperset(needed_keys): + continue + if has_var: + rulekwargs = kwargs + else: + rulekwargs = { + param.name: kwargs[param.name] + for param in rule_parameters + if param.name in kwargs + } + cite = rule(**rulekwargs) + if cite is not None: + cites.append(cite.strip()) + cites.append(_end) + print("\n\n".join(cites)) + + +_diffeqsignature = inspect.signature(diffeqsolve) + + +citation_rules: Sequence[Callable[..., Optional[str]]] = [] + + +_thesis_cite = r""" +phdthesis{kidger2021on, + title={{O}n {N}eural {D}ifferential {E}quations}, + author={Patrick Kidger}, + year={2021}, + school={University of Oxford}, +} +""".strip() + +_start = r""" +% --- AUTOGENERATED REFERENCES PRODUCED USING `diffrax.citation(...)` --- +% The following references were found for the numerical techniques being used. +% This does not cover e.g. any modelling techniques being used. +% If you think a paper is missing from here then open an issue or pull request at +% https://github.com/patrick-kidger/diffrax +""".strip() + +_end = r""" +% --- END AUTOGENERATED REFERENCES --- +""".strip() + + +_reference_regex = re.compile(r"```bibtex([^`]*)```") + + +@ft.lru_cache(maxsize=None) +def _parse_reference(obj, allow_multiple=False): + references = _reference_regex.findall(obj.__doc__) + references = [inspect.cleandoc(ref) for ref in references] + if allow_multiple: + return references + else: + [reference] = references + return reference + + +def _no_tracer(x, name): + if isinstance(x, jax.core.Tracer): + raise RuntimeError( + f"`diffrax.citation` was called with {name} as a traced JAX value. Try " + "running again without this, e.g. using `jax.disable_jit()`." + ) + + +@citation_rules.append +def _diffrax(): + return ( + r""" +% You are using Diffrax, which is citable as: +""" + + _thesis_cite + + r""" + +% You are using Equinox, which is citable as: +@article{kidger2021equinox, + author={Patrick Kidger and Cristian Garcia}, + title={{E}quinox: neural networks in {JAX} via callable {P}y{T}rees and + filtered transformations}, + year={2021}, + journal={Differentiable Programming workshop at Neural Information Processing + Systems 2021} +} + +% You are using JAX, which is citable as: +@software{jax2018github, + author = {James Bradbury and Roy Frostig and Peter Hawkins and Matthew James Johnson + and Chris Leary and Dougal Maclaurin and George Necula and Adam Paszke and + Jake Vander{P}las and Skye Wanderman-{M}ilne and Qiao Zhang}, + title = {{JAX}: composable transformations of {P}ython+{N}um{P}y programs}, + url = {http://github.com/google/jax}, + version = {""" + + str(jax.__version__) + + r"""}, + year = {2018}, +} +""" + ) + + +@citation_rules.append +def _backsolve_adjoint(adjoint, terms=None): + if type(adjoint) is BacksolveAdjoint: + if is_sde(terms): + return ( + r""" + % You are backpropagating through an SDE using optimise-then-discretise + % (`adjoint=BacksolveAdjoint(...)`) + % This technique was introduced in + """ + + _parse_reference(VirtualBrownianTree) + + r""" + % This technique was refined (simplified via rough path theory) in Section 5.2.3 of: + """ + + _thesis_cite + ) + elif is_cde(terms): + return ( + r""" + % You are backpropagating through a CDE using optimise-then-discretise + % (`adjoint=BacksolveAdjoint(...)`) + % This technique was introduced in Section 5.2.2 of: + """ + + _thesis_cite + ) + else: + return ( + r""" +% You are backpropagating through an ODE using optimise-then-discretise +% (`adjoint=BacksolveAdjoint(...)`) +% Many references exist for this technique. For example: +@article{chen2018neuralode, + title={Neural Ordinary Differential Equations}, + author={Chen, Ricky T. Q. and Rubanova, Yulia and Bettencourt, Jesse and Duvenaud, + David}, + journal={Advances in Neural Information Processing Systems}, + year={2018} +} +% In addition, the most modern (6-line) proof of this result can be found in Section +5.1.2.1 of: +""" + + _thesis_cite + ) + + +@citation_rules.append +def _discrete_adjoint(adjoint): + if type(adjoint) in (RecursiveCheckpointAdjoint,): + pieces = [] + pieces.append( + r""" +% You are differentiating using discretise-then-optimise. The following papers may be +% relevant. +""" + ) + if type(adjoint) is RecursiveCheckpointAdjoint: + pieces.append( + r""" +% If using reverse-mode autodifferentiation (backpropagation), then you are +% using binomial checkpointing ("treeverse"), which was introduced in: +""" + + _parse_reference(RecursiveCheckpointAdjoint) + ) + + pieces.append( + r""" +% If using forward-mode autodifferentiation, then this was studied in: +@inproceedings{ma2021comparison, + title={A Comparison of Automatic Differentiation and Continuous Sensitivity Analysis + for Derivatives of Differential Equation Solutions}, + author={Ma, Yingbo and Dixit, Vaibhav and Innes, Michael J and Guo, Xingjian and + Rackauckas, Chris}, + booktitle={2021 IEEE High Performance Extreme Computing Conference (HPEC)}, + year={2021}, + pages={1-9}, + doi={10.1109/HPEC49654.2021.9622796} +} +""" + ) + return "\n".join([p.strip() for p in pieces]) + + +@citation_rules.append +def _virtual_brownian_tree(terms): + is_vbt = lambda x: isinstance(x, VirtualBrownianTree) + leaves = jtu.tree_leaves(terms, is_leaf=is_vbt) + if any(is_vbt(leaf) for leaf in leaves): + return r""" +% You are simulating Brownian motion using a virtual Brownian tree, which was introduced +% in: +""" + _parse_reference( + VirtualBrownianTree + ) + + +@citation_rules.append +def _backsolve_rms_norm(adjoint): + if type(adjoint) is BacksolveAdjoint: + if adjoint_rms_seminorm in jtu.tree_leaves(adjoint): + return r""" +% You are backpropagating using adjoint seminorms, which was introduced in:: +""" + _parse_reference( + adjoint_rms_seminorm + ) + + +@citation_rules.append +def _explicit_solver(solver, terms=None): + if not isinstance(solver, AbstractImplicitSolver) and not is_sde(terms): + return r""" +% You are using an explicit solver, and may wish to cite the standard textbook: +@book{hairer2008solving-i, + address={Berlin}, + author={Hairer, E. and N{\o}rsett, S.P. and Wanner, G.}, + edition={Second Revised Edition}, + publisher={Springer}, + title={{S}olving {O}rdinary {D}ifferential {E}quations {I} {N}onstiff + {P}roblems}, + year={2008} +} +""" + + +@citation_rules.append +def _implicit_solver(solver, terms=None): + if isinstance(solver, AbstractImplicitSolver) and not is_sde(terms): + return r""" +% You are using an implicit solver, and may wish to cite the standard textbook: +@book{hairer2002solving-ii, + address={Berlin}, + author={Hairer, E. and Wanner, G.}, + edition={Second Revised Edition}, + publisher={Springer}, + title={{S}olving {O}rdinary {D}ifferential {E}quations {II} {S}tiff and + {D}ifferential-{A}lgebraic {P}roblems}, + year={2002} +} +""" + + +@citation_rules.append +def _symplectic_solver(solver, terms=None): + if type(solver) is SemiImplicitEuler and not is_sde(terms): + return r""" +You are using a symplectic solver, and may wish to cite the textbook: +@book{hairer2013geometric, + title={Geometric Numerical Integration: Structure-Preserving Algorithms for Ordinary + Differential Equations}, + author={Hairer, E. and Lubich, C. and Wanner, G.}, + isbn={9783662050187}, + series={Springer Series in Computational Mathematics}, + year={2013}, + publisher={Springer Berlin Heidelberg} +} + +""" + + +@citation_rules.append +def _cde(terms): + if is_cde(terms): + return r""" +% You are solving a CDE. These were studied in: +@incollection{kidger2020neuralcde, + title={Neural Controlled Differential Equations for Irregular Time Series}, + author={Kidger, Patrick and Morrill, James and Foster, James and Lyons, Terry}, + booktitle={Advances in Neural Information Processing Systems}, + publisher={Curran Associates, Inc.}, + year={2020}, +} +""" + + +@citation_rules.append +def _sde(terms): + if is_sde(terms): + return r""" +% You are solving an SDE, and may wish to cite the textbook: +@book{kloeden2011numerical, + title={Numerical Solution of Stochastic Differential Equations}, + author={Kloeden, P.E. and Platen, E.}, + isbn={9783540540625}, + series={Stochastic Modelling and Applied Probability}, + year={2011}, + publisher={Springer Berlin Heidelberg} +} +""" + + +@citation_rules.append +def _solvers(solver, saveat=None): + if type(solver) in ( + Tsit5, + Kvaerno3, + Kvaerno4, + Kvaerno5, + ReversibleHeun, + LeapfrogMidpoint, + ): + return ( + r""" +% You are using the """ + + solver.__class__.__name__ + + r""" solver, which was introduced in: +""" + + _parse_reference(solver) + ) + elif type(solver) is Dopri5: + ref1, ref2 = _parse_reference(Dopri5, allow_multiple=True) + assert "Dormand" in ref1 + assert "Prince" in ref1 + assert "Shampine" in ref2 + return ( + r""" +% Dormand--Prince 5(4) was introduced in: +""" + + ref1 + + r""" +% The specific implementation used here is the improved version (different Butcher +% tableau) introduced in: +""" + + ref2 + ) + elif type(solver) is Dopri8: + ref1, ref2 = _parse_reference(Dopri8, allow_multiple=True) + assert "Dormand" in ref1 + assert "Prince" in ref1 + assert "Bogacki" in ref2 + assert "Shampine" in ref2 + msg = ( + r""" +% Dormand--Prince 8(7) was introduced in: +""" + + ref1 + ) + if saveat is not None and (saveat.ts or saveat.dense): + msg += ( + r""" +% Output via `SaveAt(ts=...)` or `SaveAt(dense=True)` is done using the +% Dormand--Prince 8(7) interpolant introduced in: +""" + + ref2 + ) + return msg + + +@citation_rules.append +def _auto_dt0(dt0): + if dt0 is None: + return r""" +% Automatic selection of initial step size is from Section II.4 of: +@book{hairer2008solving-i, + address={Berlin}, + author={Hairer, E. and N{\o}rsett, S.P. and Wanner, G.}, + edition={Second Revised Edition}, + publisher={Springer}, + title={{S}olving {O}rdinary {D}ifferential {E}quations {I} {N}onstiff + {P}roblems}, + year={2008} +} +""" + + +@citation_rules.append +def _pid_controller(stepsize_controller, terms=None): + if type(stepsize_controller) is PIDController: + if is_sde(terms): + return r""" +% The use of PI and PI controllers to adapt step sizes for SDEs are from: +@article{burrage2004adaptive, + title={Adaptive stepsize based on control theory for stochastic + differential equations}, + journal={Journal of Computational and Applied Mathematics}, + volume={170}, + number={2}, + pages={317--336}, + year={2004}, + doi={https://doi.org/10.1016/j.cam.2004.01.027}, + author={P.M. Burrage and R. Herdiana and K. Burrage}, +} +@article{ilie2015adaptive, + author={Ilie, Silvana and Jackson, Kenneth R. and Enright, Wayne H.}, + title={{A}daptive {T}ime-{S}tepping for the {S}trong {N}umerical {S}olution + of {S}tochastic {D}ifferential {E}quations}, + year={2015}, + publisher={Springer-Verlag}, + address={Berlin, Heidelberg}, + volume={68}, + number={4}, + doi={https://doi.org/10.1007/s11075-014-9872-6}, + journal={Numer. Algorithms}, + pages={791–-812}, +} +""" + else: + no_p = stepsize_controller.pcoeff == 0 + no_d = stepsize_controller.dcoeff == 0 + _no_tracer(no_p, "stepsize_controller.pcoeff") + _no_tracer(no_d, "stepsize_controller.dcoeff") + if no_d: + if no_p: + return r""" +% The use of an I-controller to adapt step sizes is from Section II.4 of: +@book{hairer2008solving-i, + address={Berlin}, + author={Hairer, E. and N{\o}rsett, S.P. and Wanner, G.}, + edition={Second Revised Edition}, + publisher={Springer}, + title={{S}olving {O}rdinary {D}ifferential {E}quations {I} {N}onstiff + {P}roblems}, + year={2008} +} +""" + else: + return r""" +% The use of a PI-controller to adapt step sizes is from Section IV.2 of: +@book{hairer2002solving-ii, + address={Berlin}, + author={Hairer, E. and Wanner, G.}, + edition={Second Revised Edition}, + publisher={Springer}, + title={{S}olving {O}rdinary {D}ifferential {E}quations {II} {S}tiff and + {D}ifferential-{A}lgebraic {P}roblems}, + year={2002} +} +% and Sections 1--3 of: +@article{soderlind2002automatic, + title={Automatic control and adaptive time-stepping}, + author={Gustaf S{\"o}derlind}, + year={2002}, + journal={Numerical Algorithms}, + volume={31}, + pages={281--310} +} +""" + else: + return r""" +% The use of a PID controller to adapt step sizes is from: +@article{soderlind2003digital, + title={{D}igital {F}ilters in {A}daptive {T}ime-{S}tepping, + author={Gustaf S{\"o}derlind}, + year={2003}, + journal={ACM Transactions on Mathematical Software}, + volume={20}, + number={1}, + pages={1--26} +} +""" diff --git a/diffrax/heuristics.py b/diffrax/heuristics.py index 43b401cb..41f9eea5 100644 --- a/diffrax/heuristics.py +++ b/diffrax/heuristics.py @@ -2,6 +2,7 @@ from .brownian import AbstractBrownianPath, UnsafeBrownianPath from .custom_types import PyTree +from .path import AbstractPath from .term import AbstractTerm @@ -16,13 +17,28 @@ # really just to catch common errors. # That is, for the power user who implements enough to bypass this check -- probably # they know what they're doing and can handle both of these cases appropriately. +def _is_brownian(x): + return isinstance(x, AbstractBrownianPath) + + +def _is_unsafe_brownian(x): + return isinstance(x, UnsafeBrownianPath) + + +def _is_path(x): + return isinstance(x, AbstractPath) + + def is_sde(terms: PyTree[AbstractTerm]) -> bool: - is_brownian = lambda x: isinstance(x, AbstractBrownianPath) - leaves, _ = jtu.tree_flatten(terms, is_leaf=is_brownian) - return any(is_brownian(leaf) for leaf in leaves) + leaves, _ = jtu.tree_flatten(terms, is_leaf=_is_brownian) + return any(_is_brownian(leaf) for leaf in leaves) def is_unsafe_sde(terms: PyTree[AbstractTerm]) -> bool: - is_brownian = lambda x: isinstance(x, UnsafeBrownianPath) - leaves, _ = jtu.tree_flatten(terms, is_leaf=is_brownian) - return any(is_brownian(leaf) for leaf in leaves) + leaves, _ = jtu.tree_flatten(terms, is_leaf=_is_unsafe_brownian) + return any(_is_unsafe_brownian(leaf) for leaf in leaves) + + +def is_cde(terms: PyTree[AbstractTerm]) -> bool: + leaves, _ = jtu.tree_flatten(terms, is_leaf=_is_path) + return any(_is_path(leaf) and not _is_brownian(leaf) for leaf in leaves) diff --git a/docs/api/citation.md b/docs/api/citation.md new file mode 100644 index 00000000..2cc2d588 --- /dev/null +++ b/docs/api/citation.md @@ -0,0 +1,5 @@ +# Create citations + +Diffrax can autogenerate BibTeX citations for all the numerical methods you use. + +::: diffrax.citation diff --git a/docs/further_details/citation.md b/docs/further_details/citation.md index 16ed698d..3841153d 100644 --- a/docs/further_details/citation.md +++ b/docs/further_details/citation.md @@ -1,3 +1,5 @@ # Citation --8<-- "further_details/.citation.md" + +In addition, see the [Create citations](../api/citation.md) page for how to get Diffrax to autogenerate a list of BibTeX citations for the numerical methods you are using. diff --git a/mkdocs.yml b/mkdocs.yml index ad3020d3..05d9ea99 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -117,6 +117,7 @@ nav: - 'api/saveat.md' - 'api/stepsize_controller.md' - 'api/solution.md' + - 'api/citation.md' - Advanced API: - 'api/adjoints.md' - 'api/events.md' From f6c2edba4294b25251ea5ba5c78966467d1f68ad Mon Sep 17 00:00:00 2001 From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com> Date: Wed, 7 Dec 2022 12:55:45 -0800 Subject: [PATCH 02/19] Removed old+undocumented Fehlberg2 --- diffrax/__init__.py | 1 - diffrax/solver/__init__.py | 1 - diffrax/solver/fehlberg2.py | 26 -------------------------- 3 files changed, 28 deletions(-) delete mode 100644 diffrax/solver/fehlberg2.py diff --git a/diffrax/__init__.py b/diffrax/__init__.py index 403cb382..852f01d9 100644 --- a/diffrax/__init__.py +++ b/diffrax/__init__.py @@ -56,7 +56,6 @@ Dopri8, Euler, EulerHeun, - Fehlberg2, HalfSolver, Heun, ImplicitEuler, diff --git a/diffrax/solver/__init__.py b/diffrax/solver/__init__.py index 8c8dabf9..30964682 100644 --- a/diffrax/solver/__init__.py +++ b/diffrax/solver/__init__.py @@ -12,7 +12,6 @@ from .dopri8 import Dopri8 from .euler import Euler from .euler_heun import EulerHeun -from .fehlberg2 import Fehlberg2 from .heun import Heun from .implicit_euler import ImplicitEuler from .kvaerno3 import Kvaerno3 diff --git a/diffrax/solver/fehlberg2.py b/diffrax/solver/fehlberg2.py deleted file mode 100644 index 5ed58cc5..00000000 --- a/diffrax/solver/fehlberg2.py +++ /dev/null @@ -1,26 +0,0 @@ -import numpy as np - -from ..local_interpolation import ThirdOrderHermitePolynomialInterpolation -from .runge_kutta import AbstractERK, ButcherTableau - - -_fehlberg2_tableau = ButcherTableau( - a_lower=(np.array([1 / 2]), np.array([1 / 256, 255 / 256])), - b_sol=np.array([1 / 512, 255 / 256, 1 / 512]), - b_error=np.array([-1 / 512, 0, 1 / 512]), - c=np.array([1 / 2, 1.0]), -) - - -class Fehlberg2(AbstractERK): - """Fehlberg's method. - - 2nd order explicit Runge--Kutta method. Has an embedded first order method for - adaptive step sizing. - """ - - tableau = _fehlberg2_tableau - interpolation_cls = ThirdOrderHermitePolynomialInterpolation.from_k - - def order(self, terms): - return 2 From d611fad77825f06221e1ef77a22094f1085ea4d9 Mon Sep 17 00:00:00 2001 From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com> Date: Mon, 9 Jan 2023 14:31:03 -0800 Subject: [PATCH 03/19] update to Equinox version 0.10.0 --- benchmarks/scan_stages.py | 2 +- benchmarks/scan_stages_cnf.py | 7 ++++--- benchmarks/small_neural_ode.py | 8 ++++---- diffrax/brownian/path.py | 2 +- diffrax/brownian/tree.py | 2 +- diffrax/global_interpolation.py | 18 +++++++++--------- diffrax/integrate.py | 2 +- examples/kalman_filter.ipynb | 23 ++++++++++++++--------- examples/neural_cde.ipynb | 2 +- setup.py | 2 +- test/test_adjoint.py | 5 +++-- 11 files changed, 40 insertions(+), 33 deletions(-) diff --git a/benchmarks/scan_stages.py b/benchmarks/scan_stages.py index 6255167c..0110326b 100644 --- a/benchmarks/scan_stages.py +++ b/benchmarks/scan_stages.py @@ -53,7 +53,7 @@ def main(scan_stages): t1 = 1 dt0 = None - @eqx.filter_jit + @eqx.filter_jit(donate="none") def solve(y0): return dfx.diffeqsolve( term, solver, t0, t1, dt0, y0, stepsize_controller=stepsize_controller diff --git a/benchmarks/scan_stages_cnf.py b/benchmarks/scan_stages_cnf.py index 41782168..1108819a 100644 --- a/benchmarks/scan_stages_cnf.py +++ b/benchmarks/scan_stages_cnf.py @@ -84,9 +84,10 @@ def main(scan_stages, backsolve): mkey, dkey = jr.split(jr.PRNGKey(0), 2) model = eqx.nn.MLP(2, 2, 10, 2, activation=jnn.gelu, key=mkey) x = jr.normal(dkey, (256, 2)) - solve_ = ft.partial(solve, model, x, scan_stages, backsolve) - print("Compile+run time", timeit.timeit(solve_, number=1)) - print("Run time", timeit.timeit(solve_, number=1)) + solve1 = ft.partial(solve, model, jnp.coyp(x), scan_stages, backsolve) + solve2 = ft.partial(solve, model, jnp.copy(x), scan_stages, backsolve) + print("Compile+run time", timeit.timeit(solve1, number=1)) + print("Run time", timeit.timeit(solve2, number=1)) fire.Fire(main) diff --git a/benchmarks/small_neural_ode.py b/benchmarks/small_neural_ode.py index 95eb2260..1beae093 100644 --- a/benchmarks/small_neural_ode.py +++ b/benchmarks/small_neural_ode.py @@ -185,11 +185,11 @@ def main(batch_size=64, t1=100, multiple=False, grad=False): time_torch(neural_ode_torch, y0_torch, t1, grad) torch_time = time_torch(neural_ode_torch, y0_torch, t1, grad) - time_jax(neural_ode_diffrax, y0_jax, t1, grad) - diffrax_time = time_jax(neural_ode_diffrax, y0_jax, t1, grad) + time_jax(neural_ode_diffrax, jnp.copy(y0_jax), t1, grad) + diffrax_time = time_jax(neural_ode_diffrax, jnp.copy(y0_jax), t1, grad) - time_jax(neural_ode_experimental, y0_jax, t1, grad) - experimental_time = time_jax(neural_ode_experimental, y0_jax, t1, grad) + time_jax(neural_ode_experimental, jnp.copy(y0_jax), t1, grad) + experimental_time = time_jax(neural_ode_experimental, jnp.copy(y0_jax), t1, grad) print( f""" diff --git a/diffrax/brownian/path.py b/diffrax/brownian/path.py index 84019f01..60de8155 100644 --- a/diffrax/brownian/path.py +++ b/diffrax/brownian/path.py @@ -62,7 +62,7 @@ def t0(self): def t1(self): return None - @eqx.filter_jit + @eqx.filter_jit(donate="none") def evaluate(self, t0: Scalar, t1: Scalar, left: bool = True) -> PyTree[Array]: del left t0 = eqxi.nondifferentiable(t0, name="t0") diff --git a/diffrax/brownian/tree.py b/diffrax/brownian/tree.py index 0941d544..2c0f1456 100644 --- a/diffrax/brownian/tree.py +++ b/diffrax/brownian/tree.py @@ -88,7 +88,7 @@ def __init__( ) self.key = split_by_tree(key, self.shape) - @eqx.filter_jit + @eqx.filter_jit(donate="none") def evaluate( self, t0: Scalar, t1: Optional[Scalar] = None, left: bool = True ) -> PyTree[Array]: diff --git a/diffrax/global_interpolation.py b/diffrax/global_interpolation.py index ee1dcdaa..0c5b894e 100644 --- a/diffrax/global_interpolation.py +++ b/diffrax/global_interpolation.py @@ -76,7 +76,7 @@ def _check(_ys): jtu.tree_map(_check, self.ys) - @eqx.filter_jit + @eqx.filter_jit(donate="none") def evaluate( self, t0: Scalar, t1: Optional[Scalar] = None, left: bool = True ) -> PyTree: @@ -130,7 +130,7 @@ def _index(_ys): prev_ys**ω + (next_ys**ω - prev_ys**ω) * (fractional_part / diff_t) ).ω - @eqx.filter_jit + @eqx.filter_jit(donate="none") def derivative(self, t: Scalar, left: bool = True) -> PyTree: r"""Evaluate the derivative of the linear interpolation. Essentially equivalent to `jax.jvp(self.evaluate, (t,), (jnp.ones_like(t),))`. @@ -195,7 +195,7 @@ def _check(d, c, b, a): jtu.tree_map(_check, *self.coeffs) - @eqx.filter_jit + @eqx.filter_jit(donate="none") def evaluate( self, t0: Scalar, t1: Optional[Scalar] = None, left: bool = True ) -> PyTree: @@ -239,7 +239,7 @@ def evaluate( + frac * (ω(b)[index] + frac * (ω(c)[index] + frac * ω(d)[index])) ).ω - @eqx.filter_jit + @eqx.filter_jit(donate="none") def derivative(self, t: Scalar, left: bool = True) -> PyTree: r"""Evaluate the derivative of the cubic interpolation. Essentially equivalent to `jax.jvp(self.evaluate, (t,), (jnp.ones_like(t),))`. @@ -309,7 +309,7 @@ def _get_local_interpolation(self, t: Scalar, left: bool): infos = ω(self.infos)[index].ω return self.interpolation_cls(t0=prev_t, t1=next_t, **infos) - @eqx.filter_jit + @eqx.filter_jit(donate="none") def evaluate( self, t0: Scalar, t1: Optional[Scalar] = None, left: bool = True ) -> PyTree: @@ -320,7 +320,7 @@ def evaluate( # continuous. return self._get_local_interpolation(t0, left).evaluate(t0) - @eqx.filter_jit + @eqx.filter_jit(donate="none") def derivative(self, t: Scalar, left: bool = True) -> PyTree: # Passing `left` doesn't matter on a local interpolation, which is globally # continuous. @@ -420,7 +420,7 @@ def _linear_interpolation( return ys -@eqx.filter_jit +@eqx.filter_jit(donate="none") def linear_interpolation( ts: Array["times"], # noqa: F821 ys: PyTree["times", ...], # noqa: F821 @@ -474,7 +474,7 @@ def _rectilinear_interpolation( return ts, ys -@eqx.filter_jit +@eqx.filter_jit(donate="none") def rectilinear_interpolation( ts: Array["times"], # noqa: F821 ys: PyTree["times", ...], # noqa: F821 @@ -659,7 +659,7 @@ def _backward_hermite_coefficients( return ds, cs, bs, as_ -@eqx.filter_jit +@eqx.filter_jit(donate="none") def backward_hermite_coefficients( ts: Array["times"], # noqa: F821 ys: PyTree["times", ...], # noqa: F821 diff --git a/diffrax/integrate.py b/diffrax/integrate.py index 1bf48e00..29c6a867 100644 --- a/diffrax/integrate.py +++ b/diffrax/integrate.py @@ -510,7 +510,7 @@ def _cond_fun(state): return eqx.tree_at(lambda s: s.result, final_state, result), aux_stats -@eqx.filter_jit +@eqx.filter_jit(donate="none") def diffeqsolve( terms: PyTree[AbstractTerm], solver: AbstractSolver, diff --git a/examples/kalman_filter.ipynb b/examples/kalman_filter.ipynb index 5ed8e5f7..5a1aea20 100644 --- a/examples/kalman_filter.ipynb +++ b/examples/kalman_filter.ipynb @@ -22,7 +22,6 @@ "metadata": {}, "outputs": [], "source": [ - "import functools as ft\n", "from types import SimpleNamespace\n", "from typing import Optional\n", "\n", @@ -320,21 +319,27 @@ " lambda tree: (tree.Q, tree.R), filter_spec, replace=(True, True)\n", " )\n", "\n", - " @eqx.filter_jit\n", - " @ft.partial(eqx.filter_value_and_grad, arg=filter_spec)\n", - " def loss_fn(kmf, ts, ys, xs):\n", + " opt = optax.adam(1e-2)\n", + " opt_state = opt.init(kmf)\n", + "\n", + " @eqx.filter_value_and_grad\n", + " def loss_fn(dynamic_kmf, static_kmf, ts, ys, xs):\n", + " kmf = eqx.combine(dynamic_kmf, static_kmf)\n", " xhats = kmf(ts, ys)\n", " return jnp.mean((xs - xhats) ** 2)\n", "\n", - " opt = optax.adam(1e-2)\n", - " opt_state = opt.init(kmf)\n", + " @eqx.filter_jit\n", + " def make_step(kmf, opt_state, ts, ys, xs):\n", + " dynamic_kmf, static_kmf = eqx.partition(kmf, filter_spec)\n", + " value, grads = loss_fn(dynamic_kmf, static_kmf, ts, ys, xs)\n", + " updates, opt_state = opt.update(grads, opt_state)\n", + " kmf = eqx.apply_updates(kmf, updates)\n", + " return value, kmf, opt_state\n", "\n", " for step in range(n_gradient_steps):\n", - " value, grads = loss_fn(kmf, ts, ys, xs)\n", + " value, kmf, opt_state = make_step(kmf, opt_state, ts, ys, xs)\n", " if step % print_every == 0:\n", " print(\"Current MSE: \", value)\n", - " updates, opt_state = opt.update(grads, opt_state)\n", - " kmf = eqx.apply_updates(kmf, updates)\n", "\n", " print(f\"Final Q: \\n{kmf.Q}\\n Final R: \\n{kmf.R}\")\n", "\n", diff --git a/examples/neural_cde.ipynb b/examples/neural_cde.ipynb index d894c847..c989541f 100644 --- a/examples/neural_cde.ipynb +++ b/examples/neural_cde.ipynb @@ -275,7 +275,7 @@ "\n", " # Training loop like normal.\n", "\n", - " @eqx.filter_jit\n", + " @eqx.filter_jit(donate=\"none\")\n", " def loss(model, ti, label_i, coeff_i):\n", " pred = jax.vmap(model)(ti, coeff_i)\n", " # Binary cross-entropy\n", diff --git a/setup.py b/setup.py index c7329ad3..62c12ae0 100644 --- a/setup.py +++ b/setup.py @@ -46,7 +46,7 @@ python_requires = "~=3.7" -install_requires = ["jax>=0.3.4", "equinox>=0.9.1"] +install_requires = ["jax>=0.3.4", "equinox>=0.10.0"] setuptools.setup( name=name, diff --git a/test/test_adjoint.py b/test/test_adjoint.py index 3bf48ec6..0b9f7aee 100644 --- a/test/test_adjoint.py +++ b/test/test_adjoint.py @@ -76,9 +76,10 @@ def _run(y0__args__term, saveat, adjoint): _run_grad = eqx.filter_jit( jax.grad( lambda d, saveat, adjoint: _run(eqx.combine(d, nondiff), saveat, adjoint) - ) + ), + donate="none", ) - _run_grad_int = eqx.filter_jit(jax.grad(_run, allow_int=True)) + _run_grad_int = eqx.filter_jit(jax.grad(_run, allow_int=True), donate="none") # Yep, test that they're not implemented. We can remove these checks if we ever # do implement them. From cbf944ca228ffcea075c541463a8a1b898b2ec5b Mon Sep 17 00:00:00 2001 From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com> Date: Thu, 17 Nov 2022 14:31:22 -0800 Subject: [PATCH 04/19] The great bounded-while-loop clean-up --- diffrax/__init__.py | 4 +- diffrax/{misc => }/ad.py | 0 diffrax/adjoint.py | 180 ++++++++++++++-- diffrax/bounded_while_loop.py | 75 +++++++ diffrax/integrate.py | 246 ++++----------------- diffrax/{misc => }/misc.py | 2 +- diffrax/misc/__init__.py | 13 -- diffrax/misc/bounded_while_loop.py | 241 --------------------- diffrax/misc/sde_kl_divergence.py | 74 ------- diffrax/nonlinear_solver/base.py | 2 +- diffrax/saveat.py | 10 +- diffrax/step_size_controller/adaptive.py | 6 +- diffrax/step_size_controller/constant.py | 25 +-- docs/api/stepsize_controller.md | 3 +- docs/devdocs/bounded_while_loop.md | 138 ------------ test/helpers.py | 22 -- test/test_adjoint.py | 36 ++-- test/test_bounded_while_loop.py | 264 +++++------------------ test/test_integrate.py | 197 ----------------- 19 files changed, 361 insertions(+), 1177 deletions(-) rename diffrax/{misc => }/ad.py (100%) create mode 100644 diffrax/bounded_while_loop.py rename diffrax/{misc => }/misc.py (99%) delete mode 100644 diffrax/misc/__init__.py delete mode 100644 diffrax/misc/bounded_while_loop.py delete mode 100644 diffrax/misc/sde_kl_divergence.py delete mode 100644 docs/devdocs/bounded_while_loop.md diff --git a/diffrax/__init__.py b/diffrax/__init__.py index 852f01d9..ff90008b 100644 --- a/diffrax/__init__.py +++ b/diffrax/__init__.py @@ -28,7 +28,7 @@ LocalLinearInterpolation, ThirdOrderHermitePolynomialInterpolation, ) -from .misc import adjoint_rms_seminorm, sde_kl_divergence +from .misc import adjoint_rms_seminorm from .nonlinear_solver import ( AbstractNonlinearSolver, NewtonNonlinearSolver, @@ -87,4 +87,4 @@ ) -__version__ = "0.2.2" +__version__ = "0.3.0" diff --git a/diffrax/misc/ad.py b/diffrax/ad.py similarity index 100% rename from diffrax/misc/ad.py rename to diffrax/ad.py diff --git a/diffrax/adjoint.py b/diffrax/adjoint.py index 889c52b2..06427590 100644 --- a/diffrax/adjoint.py +++ b/diffrax/adjoint.py @@ -1,5 +1,7 @@ import abc -from typing import Any, Dict +import functools as ft +import math +from typing import Any, Dict, Optional import equinox as eqx import equinox.internal as eqxi @@ -8,7 +10,8 @@ import jax.tree_util as jtu from equinox.internal import ω -from .misc import implicit_jvp +from .ad import implicit_jvp +from .bounded_while_loop import bounded_while_loop from .saveat import SaveAt from .term import AbstractTerm, AdjointTerm @@ -63,6 +66,23 @@ def _no_transpose_final_state(final_state): return final_state +def _while_loop(cond_fun, body_fun, init_val, max_steps): + if max_steps is None: + return lax.while_loop(cond_fun, body_fun, init_val) + else: + + def _cond_fun(carry): + step, val = carry + return (step < max_steps) & cond_fun(val) + + def _body_fun(carry): + step, val = carry + return step + 1, body_fun(val) + + _, final_val = lax.while_loop(_cond_fun, _body_fun, (0, init_val)) + return final_val + + class AbstractAdjoint(eqx.Module): """Abstract base class for all adjoint methods.""" @@ -120,28 +140,152 @@ class RecursiveCheckpointAdjoint(AbstractAdjoint): solution directly. This is sometimes known as "discretise-then-optimise", or described as "backpropagation through the solver". + Uses a binomial checkpointing scheme to keep memory usage low. + + For most problems this is the preferred technique for backpropagating through a + differential equation. + """ + + def loop(self, *, throw, passed_solver_state, passed_controller_state, **kwargs): + del throw, passed_solver_state, passed_controller_state + return self._loop_fn(**kwargs, while_loop=bounded_while_loop) + + +class RecursiveCheckpointAdjoint2(AbstractAdjoint): + """Backpropagate through [`diffrax.diffeqsolve`][] by differentiating the numerical + solution directly. This is sometimes known as "discretise-then-optimise", or + described as "backpropagation through the solver". + + Uses a binomial checkpointing scheme to keep memory usage low. + For most problems this is the preferred technique for backpropagating through a differential equation. - In addition a binomial checkpointing scheme is used so that memory usage is low. - (This checkpointing can increase compile time a bit, though.) + !!! info - !!! Reference + Note that this cannot be forward-mode autodifferentiated. (E.g. using + `jax.jvp`.) + + ??? cite "References" + + Selecting which steps at which to save checkpoints (and when this is done, which + old checkpoint to evict) is important for minimising the amount of recomputation + performed. + + The implementation here performs "online checkpointing", as the number of steps + is not known in advance. This was developed in: - Binomial checkpointing (also known as "treeverse") was introduced in: ```bibtex - @article{griewank1998treeverse, - title = {Treeverse: An Implementation of Checkpointing for the Reverse or - Adjoint Mode of Computational Differentiation} + @article{stumm2010new, + author = {Stumm, Philipp and Walther, Andrea}, + title = {New Algorithms for Optimal Online Checkpointing}, + journal = {SIAM Journal on Scientific Computing}, + volume = {32}, + number = {2}, + pages = {836--854}, + year = {2010}, + doi = {10.1137/080742439}, + } + + @article{wang2009minimal, + author = {Wang, Qiqi and Moin, Parviz and Iaccarino, Gianluca}, + title = {Minimal Repetition Dynamic Checkpointing Algorithm for Unsteady + Adjoint Calculation}, + journal = {SIAM Journal on Scientific Computing}, + volume = {31}, + number = {4}, + pages = {2549--2567}, + year = {2009}, + doi = {10.1137/080727890}, + } + ``` + + For reference, the classical "offline checkpointing" (also known as "treeverse", + "recursive binary checkpointing", "revolve" etc.) was developed in: + + ```bibtex + @article{griewank1992achieving, + author = {Griewank, Andreas}, + title = {Achieving logarithmic growth of temporal and spatial complexity in + reverse automatic differentiation}, + journal = {Optimization Methods and Software}, + volume = {1}, + number = {1}, + pages = {35--54}, + year = {1992}, + publisher = {Taylor & Francis}, + doi = {10.1080/10556789208805505}, + } + + @article{griewank2000revolve, author = {Griewank, Andreas and Walther, Andrea}, - year = {1998}, + title = {Algorithm 799: Revolve: An Implementation of Checkpointing for the + Reverse or Adjoint Mode of Computational Differentiation}, + year = {2000}, + publisher = {Association for Computing Machinery}, + volume = {26}, + number = {1}, + doi = {10.1145/347837.347846}, + journal = {ACM Trans. Math. Softw.}, + pages = {19--45}, } ``` """ - def loop(self, *, throw, passed_solver_state, passed_controller_state, **kwargs): + checkpoints: Optional[int] = None + + def loop( + self, + *, + max_steps, + throw, + passed_solver_state, + passed_controller_state, + **kwargs, + ): del throw, passed_solver_state, passed_controller_state - return self._loop_fn(**kwargs, is_bounded=True) + if self.checkpoints is None: + if max_steps is None: + raise ValueError( + "Cannot use " + "`diffeqsolve(..., max_steps=None, adjoint=RecursiveCheckpointAdjoint(checkpoints=None))` " # noqa: E501 + "Either specify the number of `checkpoints` to use, or specify the " + "maximum number of steps (and `checkpoints` is chosen " + "automatically as `log2(max_steps)``.)" + ) + # Binomial logarithmic growth is what is needed in classical treeverse. + # + # Moreover this is optimal even in the online case, as provided + # `max_steps >= 21` + # then + # `checkpoints = ceil(log2(max_steps))` + # satisfies + # `max_steps <= (checkpoints + 1)(checkpoints + 2)/2` + # which is the condition for optimality. + # + # Meanwhile if + # `max_steps <= 20` + # then we handle it as a special case, to once again ensure we satisfy + # `max_steps <= (checkpoints + 1)(checkpoints + 2)/2` + # + # The optimality condition is equation (2.2) of + # "New Algorithms for Optimal Online Checkpointing", Stumm and Walther 2010. + # https://tu-dresden.de/mn/math/wir/ressourcen/dateien/forschung/publikationen/pdf2010/new_algorithms_for_optimal_online_checkpointing.pdf + if max_steps <= 20: + checkpoints = 1 + while (checkpoints + 1) * (checkpoints + 2) < 2 * max_steps: + checkpoints += 1 + else: + checkpoints = math.ceil(math.log2(max_steps)) + else: + checkpoints = self.checkpoints + return self._loop_fn( + max_steps=max_steps, + while_loop=ft.partial( + eqxi.checkpointed_while_loop, checkpoints=checkpoints + ), + **kwargs, + ) class NoAdjoint(AbstractAdjoint): @@ -153,9 +297,7 @@ class NoAdjoint(AbstractAdjoint): def loop(self, *, throw, passed_solver_state, passed_controller_state, **kwargs): del throw, passed_solver_state, passed_controller_state - final_state, aux_stats = self._loop_fn(**kwargs, is_bounded=False) - final_state = eqxi.nondifferentiable_backward(final_state) - return final_state, aux_stats + return self._loop_fn(**kwargs, while_loop=_while_loop) def _vf(ys, residual, args__terms, closure): @@ -178,7 +320,7 @@ def _solve(args__terms, closure): solver=solver, saveat=saveat, init_state=init_state, - is_bounded=False, + while_loop=_while_loop, ) # Note that we use .ys not .y here. The former is what is actually returned # by diffeqsolve, so it is the thing we want to attach the tangent to. @@ -260,7 +402,11 @@ def _loop_backsolve(y__args__terms, *, self, throw, init_state, **kwargs): ) del y return self._loop_fn( - args=args, terms=terms, init_state=init_state, is_bounded=False, **kwargs + args=args, + terms=terms, + init_state=init_state, + while_loop=_while_loop, + **kwargs, ) diff --git a/diffrax/bounded_while_loop.py b/diffrax/bounded_while_loop.py new file mode 100644 index 00000000..59d5e0ac --- /dev/null +++ b/diffrax/bounded_while_loop.py @@ -0,0 +1,75 @@ +import functools as ft +import math + +import equinox.internal as eqxi +import jax +import jax.lax as lax +import jax.numpy as jnp +import jax.tree_util as jtu + + +def bounded_while_loop(cond_fun, body_fun, init_val, max_steps, base=16): + """Reverse-mode autodifferentiable while loop. + + Mostly as `lax.while_loop`, with a few small changes. + + Arguments: + cond_fun: function `a -> bool` + body_fun: function `a -> a`. + init_val: pytree of type `a`. + max_steps: integer or `None`. + base: integer. + + Note the extra `max_steps` argument. If this is `None` then `bounded_while_loop` + will fall back to `lax.while_loop` (which is not reverse-mode autodifferentiable). + If it is a non-negative integer then this is the maximum number of steps which may + be taken in the loop, after which the loop will exit unconditionally. + + Note the extra `base` argument. + - Run time will increase slightly as `base` increases. + - Compilation time will decrease substantially as + `math.ceil(math.log(max_steps, base))` decreases. (Which happens as `base` + increases.) + """ + + init_val = jtu.tree_map(jnp.asarray, init_val) + + if max_steps is None: + return lax.while_loop(cond_fun, body_fun, init_val) + + if not isinstance(max_steps, int) or max_steps < 0: + raise ValueError("max_steps must be a non-negative integer") + if max_steps == 0: + return init_val + + def _cond_fun(val, step): + return cond_fun(val) & (step < max_steps) + + init_data = (cond_fun(init_val), init_val, 0) + rounded_max_steps = base ** int(math.ceil(math.log(max_steps, base))) + _, val, _ = _while_loop(_cond_fun, body_fun, init_data, rounded_max_steps, base) + return val + + +def _while_loop(cond_fun, body_fun, data, max_steps, base): + if max_steps == 1: + pred, val, step = data + new_val = body_fun(val) + new_val = jtu.tree_map(ft.partial(lax.select, pred), new_val, val) + new_step = step + 1 + return cond_fun(new_val, new_step), new_val, new_step + else: + + def _call(_data): + return _while_loop(cond_fun, body_fun, _data, max_steps // base, base) + + def _scan_fn(_data, _): + _pred, _, _ = _data + _unvmap_pred = eqxi.unvmap_any(_pred) + return lax.cond(_unvmap_pred, _call, lambda x: x, _data), None + + # Don't put checkpointing on the lowest level + if max_steps != base: + _scan_fn = jax.checkpoint(_scan_fn, prevent_cse=False) + + return lax.scan(_scan_fn, data, xs=None, length=base)[0] diff --git a/diffrax/integrate.py b/diffrax/integrate.py index 29c6a867..f1d1a16f 100644 --- a/diffrax/integrate.py +++ b/diffrax/integrate.py @@ -5,21 +5,21 @@ import equinox as eqx import equinox.internal as eqxi import jax -import jax.lax as lax import jax.numpy as jnp import jax.tree_util as jtu from .adjoint import ( AbstractAdjoint, BacksolveAdjoint, + ImplicitAdjoint, NoAdjoint, RecursiveCheckpointAdjoint, ) +from .bounded_while_loop import bounded_while_loop from .custom_types import Array, Bool, Int, PyTree, Scalar from .event import AbstractDiscreteTerminatingEvent from .global_interpolation import DenseInterpolation from .heuristics import is_sde, is_unsafe_sde -from .misc import bounded_while_loop, HadInplaceUpdate from .saveat import SaveAt from .solution import is_okay, is_successful, RESULTS, Solution from .solver import AbstractItoSolver, AbstractSolver, AbstractStratonovichSolver, Euler @@ -105,7 +105,7 @@ def loop( terms, args, init_state, - is_bounded, + while_loop, ): if saveat.t0: @@ -127,7 +127,7 @@ def loop( def cond_fun(state): return (state.tprev < t1) & is_successful(state.result) - def body_fun(state, inplace): + def body_fun(state): # # Actually do some differential equation solving! Make numerical steps, adapt @@ -215,11 +215,6 @@ def body_fun(state, inplace): # # Store the output produced from this numerical step. - # This is a bit involved, and uses the `inplace` function passed as an argument - # to this body function. - # This is because we need to make in-place updates to store our results, but - # doing is a bit of a hassle inside `bounded_while_loop`. (See its docstring - # for details.) # saveat_ts_index = state.saveat_ts_index @@ -229,97 +224,34 @@ def body_fun(state, inplace): dense_ts = state.dense_ts dense_infos = state.dense_infos dense_save_index = state.dense_save_index - made_inplace_update = False if saveat.ts is not None: - made_inplace_update = True _interpolator = solver.interpolation_cls( t0=state.tprev, t1=state.tnext, **dense_info ) - def _saveat_get(_saveat_ts_index): - return saveat.ts[jnp.minimum(_saveat_ts_index, len(saveat.ts) - 1)] - def _cond_fun(_state): - _saveat_ts_index = _state.saveat_ts_index - _saveat_t = _saveat_get(_saveat_ts_index) return ( keep_step - & (_saveat_t <= state.tnext) - & (_saveat_ts_index < len(saveat.ts)) + & (saveat.ts[_state.saveat_ts_index] <= state.tnext) + & (_state.saveat_ts_index < len(saveat.ts)) ) - def _body_fun(_state, _inplace): - _saveat_ts_index = _state.saveat_ts_index - _ts = _state.ts - _ys = _state.ys - _save_index = _state.save_index - - _saveat_t = _saveat_get(_saveat_ts_index) + def _body_fun(_state): + _saveat_t = saveat.ts[_state.saveat_ts_index] _saveat_y = _interpolator.evaluate(_saveat_t) - - # VOODOO MAGIC - # - # Okay, time for some voodoo that I absolutely don't understand. - # - # Shown in the comment is what I would to write: - # - # _inplace = _inplace.merge(inplace) - # _ts = _inplace(_ts).at[_save_index].set(_saveat_t) - # _ys = jtu.tree_map(lambda __ys, __saveat_y: _inplace(__ys).at[_save_index].set(__saveat_y), _ys, _saveat_y) # noqa: E501 - # - # Seems reasonable, right? Just updating a value. - # - # Below is what we actually run: - - _inplace.merge(inplace) - _pred = cond_fun(state) & _cond_fun(_state) - _ts = _ts.at[_save_index].set( - jnp.where(_pred, _saveat_t, _ts[_save_index]) - ) + _ts = _state.ts.at[_state.save_index].set(_saveat_t) _ys = jtu.tree_map( - lambda __ys, __saveat_y: __ys.at[_save_index].set( - jnp.where(_pred, __saveat_y, __ys[_save_index]) - ), - _ys, + lambda __ys, __saveat_y: __ys.at[_state.save_index].set(__saveat_y), + _state.ys, _saveat_y, ) - - # Some immediate questions you might have: - # - # - Isn't this essentially equivalent to the commented-out version? - # - Nitpick: the commented-out version includes an enhanced cond_fun - # that checks the step count, but it shouldn't matter here. - # - It looks like `_inplace.merge(inplace)` isn't even used? - # - I think it will appear in the jaxpr, interestingly, based off of - # the toy example: - # >>> def f(x, y): - # ... x & y - # ... return x + 1 - # >>> jax.make_jaxpr(f)(1, 2) - # Which is presumably how this manages to affect anything at all. - # - # And you are right. Those are both reasonable questions, at least as - # far as I can see. - # - # And yet for some reason this version will run substantially faster. - # (At time of writing: on the `small_neural_ode.py` benchmark, on the - # CPU.) - # - # ~VOODOO MAGIC - - _saveat_ts_index = _saveat_ts_index + 1 - _save_index = _save_index + 1 - - _ts = HadInplaceUpdate(_ts) - _ys = jtu.tree_map(HadInplaceUpdate, _ys) - return _InnerState( - saveat_ts_index=_saveat_ts_index, + saveat_ts_index=_state.saveat_ts_index + 1, ts=_ts, ys=_ys, - save_index=_save_index, + save_index=_state.save_index + 1, ) init_inner_state = _InnerState( @@ -335,17 +267,16 @@ def _body_fun(_state, _inplace): ys = final_inner_state.ys save_index = final_inner_state.save_index + # TODO: make while loop? def maybe_inplace(i, x, u): - return inplace(x).at[i].set(jnp.where(keep_step, u, x[i])) + return x.at[i].set(jnp.where(keep_step, u, x[i])) if saveat.steps: - made_inplace_update = True ts = maybe_inplace(save_index, ts, tprev) ys = jtu.tree_map(ft.partial(maybe_inplace, save_index), ys, y) save_index = save_index + keep_step if saveat.dense: - made_inplace_update = True dense_ts = maybe_inplace(dense_save_index + 1, dense_ts, tprev) dense_infos = jtu.tree_map( ft.partial(maybe_inplace, dense_save_index), @@ -354,12 +285,6 @@ def maybe_inplace(i, x, u): ) dense_save_index = dense_save_index + keep_step - if made_inplace_update: - ts = HadInplaceUpdate(ts) - ys = jtu.tree_map(HadInplaceUpdate, ys) - dense_ts = HadInplaceUpdate(dense_ts) - dense_infos = jtu.tree_map(HadInplaceUpdate, dense_infos) - new_state = _State( y=y, tprev=tprev, @@ -402,101 +327,7 @@ def maybe_inplace(i, x, u): return new_state - if is_bounded: - # Some privileged optimisations, but for common use cases. - # TODO: make these a method on an AbstractFixedStepSizeController? - # - # These optimisations depend on implementations details of `ConstantStepSize`, - # `StepTo`, and `bounded_while_loop`. - # - # We try to determine the exact number of integration steps that will be made. - # If this is possible then we can use a single `lax.scan`, rather than the - # recursive construction of `bounded_while_loop`. This primarily reduces - # compilation times. - if max_steps is None: - # `bounded_while_loop(..., max_steps=None)` lowers to `lax.while_loop` - # anyway; this is already fast. Don't try to determine the number of steps - # needed. - compiled_num_steps = None - elif isinstance(stepsize_controller, ConstantStepSize) and ( - stepsize_controller.compile_steps is None - or stepsize_controller.compile_steps is True - ): - # We can determine the number of steps quite easily with constant step - # size. - # - # We do so using a `lax.while_loop`. - # - Not just a (t1 - t0)/dt0 division, to avoid floating point errors. - # - lax.while_loop, not just a Python one, to ensure that we match the - # behaviour at runtime; no funny edge cases. - with jax.ensure_compile_time_eval(): - - def _is_finite(_t): - all_finite = eqxi.unvmap_all(jnp.isfinite(_t)) - return not isinstance(all_finite, jax.core.Tracer) and all_finite - - if _is_finite(t0) and _is_finite(t1) and _is_finite(dt0): - - def _cond_fun(_state): - _, _t = _state - return _t < t1 - - def _body_fun(_state): - _step, _t = _state - return _step + 1, _clip_to_end(_t, _t + dt0, t1, True) - - compiled_num_steps, _ = lax.while_loop( - _cond_fun, _body_fun, (0, t0) - ) - compiled_num_steps = eqxi.unvmap_max(compiled_num_steps) - else: - if stepsize_controller.compile_steps is None: - compiled_num_steps = None - else: - assert stepsize_controller.compile_steps is True - raise ValueError( - "Could not determine exact number of steps, but " - "`stepsize_controller.compile_steps=True`" - ) - elif isinstance(stepsize_controller, StepTo) and ( - stepsize_controller.compile_steps is None - or stepsize_controller.compile_steps is True - ): - # The user has explicitly specified the number of steps. - compiled_num_steps = len(stepsize_controller.ts) - 1 - else: - # Else we can't determine the number of steps. - compiled_num_steps = None - - if compiled_num_steps is None or isinstance( - compiled_num_steps, jax.core.Tracer - ): - # If we couldn't determine the number of steps then use the default - # recursive construction. - compiled_num_steps = None - base = 16 - else: - if isinstance(compiled_num_steps, jnp.ndarray): - compiled_num_steps = compiled_num_steps.item() - base = compiled_num_steps - max_steps = min(max_steps, compiled_num_steps) - - final_state = bounded_while_loop( - cond_fun, body_fun, init_state, max_steps, base=base - ) - else: - compiled_num_steps = None - - if max_steps is None: - _cond_fun = cond_fun - else: - - def _cond_fun(state): - return cond_fun(state) & (state.num_steps < max_steps) - - final_state = bounded_while_loop( - _cond_fun, body_fun, init_state, max_steps=None - ) + final_state = while_loop(cond_fun, body_fun, init_state, max_steps) if saveat.t1 and not saveat.steps: # if saveat.steps then the final value is already saved. @@ -506,7 +337,7 @@ def _cond_fun(state): result = jnp.where( cond_fun(final_state), RESULTS.max_steps_reached, final_state.result ) - aux_stats = dict(compiled_num_steps=compiled_num_steps) + aux_stats = dict() return eqx.tree_at(lambda s: s.result, final_state, result), aux_stats @@ -702,23 +533,22 @@ def diffeqsolve( raise ValueError( "`UnsafeBrownianPath` cannot be used with adaptive step sizes." ) - if not isinstance(adjoint, NoAdjoint): + if not isinstance(adjoint, (NoAdjoint, ImplicitAdjoint)): raise ValueError( - "`UnsafeBrownianPath` can only be used with `adjoint=NoAdjoint()`." + "`UnsafeBrownianPath` can only be used with `adjoint=NoAdjoint()` or " + "`adjoint=ImplicitAdjoint()`." ) - # Allow setting e.g. t0 as an int with dt0 as a float. (We need consistent - # types for JAX to be happy with the bounded_while_loop below.) - with jax.ensure_compile_time_eval(): - timelikes = (jnp.array(0.0), t0, t1, dt0, saveat.ts) - timelikes = [x for x in timelikes if x is not None] - dtype = jnp.result_type(*timelikes) - t0 = jnp.asarray(t0, dtype=dtype) - t1 = jnp.asarray(t1, dtype=dtype) - if dt0 is not None: - dt0 = jnp.asarray(dt0, dtype=dtype) - if saveat.ts is not None: - saveat = eqx.tree_at(lambda s: s.ts, saveat, saveat.ts.astype(dtype)) + # Allow setting e.g. t0 as an int with dt0 as a float. + timelikes = (jnp.array(0.0), t0, t1, dt0, saveat.ts) + timelikes = [x for x in timelikes if x is not None] + dtype = jnp.result_type(*timelikes) + t0 = jnp.asarray(t0, dtype=dtype) + t1 = jnp.asarray(t1, dtype=dtype) + if dt0 is not None: + dt0 = jnp.asarray(dt0, dtype=dtype) + if saveat.ts is not None: + saveat = eqx.tree_at(lambda s: s.ts, saveat, saveat.ts.astype(dtype)) # Time will affect state, so need to promote the state dtype as well if necessary. def _promote(yi): @@ -729,14 +559,13 @@ def _promote(yi): del timelikes, dtype # Normalises time: if t0 > t1 then flip things around. - with jax.ensure_compile_time_eval(): - direction = jnp.where(t0 < t1, 1, -1) - t0 = t0 * direction - t1 = t1 * direction - if dt0 is not None: - dt0 = dt0 * direction - if saveat.ts is not None: - saveat = eqx.tree_at(lambda s: s.ts, saveat, saveat.ts * direction) + direction = jnp.where(t0 < t1, 1, -1) + t0 = t0 * direction + t1 = t1 * direction + if dt0 is not None: + dt0 = dt0 * direction + if saveat.ts is not None: + saveat = eqx.tree_at(lambda s: s.ts, saveat, saveat.ts * direction) stepsize_controller = stepsize_controller.wrap(direction) terms = jtu.tree_map( lambda t: WrapTerm(t, direction), @@ -918,7 +747,6 @@ def _promote(yi): "num_accepted_steps": final_state.num_accepted_steps, "num_rejected_steps": final_state.num_rejected_steps, "max_steps": max_steps, - "compiled_num_steps": aux_stats["compiled_num_steps"], } result = final_state.result sol = Solution( diff --git a/diffrax/misc/misc.py b/diffrax/misc.py similarity index 99% rename from diffrax/misc/misc.py rename to diffrax/misc.py index 6ae6797e..755a8efd 100644 --- a/diffrax/misc/misc.py +++ b/diffrax/misc.py @@ -6,7 +6,7 @@ import jax.numpy as jnp import jax.tree_util as jtu -from ..custom_types import Array, PyTree, Scalar +from .custom_types import Array, PyTree, Scalar _itemsize_kind_type = { diff --git a/diffrax/misc/__init__.py b/diffrax/misc/__init__.py deleted file mode 100644 index 4b35bc2a..00000000 --- a/diffrax/misc/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -from .ad import implicit_jvp -from .bounded_while_loop import bounded_while_loop, HadInplaceUpdate -from .misc import ( - adjoint_rms_seminorm, - fill_forward, - force_bitcast_convert_type, - is_tuple_of_ints, - left_broadcast_to, - linear_rescale, - rms_norm, - split_by_tree, -) -from .sde_kl_divergence import sde_kl_divergence diff --git a/diffrax/misc/bounded_while_loop.py b/diffrax/misc/bounded_while_loop.py deleted file mode 100644 index ae3b6c89..00000000 --- a/diffrax/misc/bounded_while_loop.py +++ /dev/null @@ -1,241 +0,0 @@ -import math - -import equinox as eqx -import equinox.internal as eqxi -import jax -import jax.lax as lax -import jax.numpy as jnp -import jax.tree_util as jtu - -from ..custom_types import Array - - -def bounded_while_loop(cond_fun, body_fun, init_val, max_steps, base=16): - """Reverse-mode autodifferentiable while loop. - - Mostly as `lax.while_loop`, with a few small changes. - - Arguments: - cond_fun: function `a -> a` - body_fun: function `a -> b -> a`, where `b` is a function that should be used - instead of performing in-place updates with .at[].set() etc; see below. - init_val: pytree with structure `a`. - max_steps: integer or `None`. - base: integer. - - Limitations with in-place updates.: - The single big limitation is around making in-place updates. Done naively then - the XLA compiler will fail to treat these as in-place and will make a copy - every time. (See JAX issue #8192.) - - Working around this is a bit of a hassle -- as follows -- and it is for this - reason that `body_fun` takes a second argument. - - If you ever have: - - an inplace update... - - ...made to the input to the body_fun... - - ...whose result is returned from the body_fun... - ...then you should use - - ```python - x = inplace(x).at[i].set(u) - x = HadInplaceUpdate(x) - ``` - - in place of - - ```python - x = x.at[i].set(u) - ``` - - where `inplace` is the second argument to `body_fun`, and `HadInplaceUpdate` is - available at `diffrax.misc.HadInplaceUpdate`. - - Internally, `bounded_while_loop` will treat things so as to work around this - limitation of XLA. - - !!! faq - - `HadInplaceUpdate` is available separately (instead of being returned - automatically from `inplace().at[].set()`) in case the in-place update - takes place inside e.g. a `lax.scan` or similar, and you need to maintain - PyTree structures. Just place the `HadInplaceUpdate` at the very end of - `body_fun`. (And applied only to those array(s) that actually had in-place - update(s), if the state is a PyTree.) - - !!! note - - If you need to nest `bounded_while_loop`s, then the two `inplace` functions - can be merged: - - ```python - def body_fun(val, inplace): - ... # stuff (use inplace) - - def inner_body_fun(_val, _inplace): - _inplace = _inplace.merge(inplace) - ... # stuff (use _inplace) - - bounded_while_loop(body_fun=inner_body_fun, ...) - - ... # stuff (use inplace) - - bounded_while_loop(body_fun=body_fun, ...) - ``` - - !!! note - - In-place updates to arrays that are _created_ inside of `body_fun` can be - made as normal. It's just those arrays that are part of the state (that is - passed in and out) that need to be treated specially. - - Note the extra `max_steps` argument. If this is `None` then `bounded_while_loop` - will fall back to `lax.while_loop` (which is not reverse-mode autodifferentiable). - If it is a non-negative integer then this is the maximum number of steps which may - be taken in the loop, after which the loop will exit unconditionally. - - Note the extra `base` argument. - - Run time will increase slightly as `base` increases. - - Compilation time will decrease substantially as - `math.ceil(math.log(max_steps, base))` decreases. (Which happens as `base` - increases.) - """ - - init_val = jtu.tree_map(jnp.asarray, init_val) - - if max_steps is None: - - def _make_update(_new_val): - if isinstance(_new_val, HadInplaceUpdate): - return _new_val.val - else: - return _new_val - - def _body_fun(_val): - inplace = lambda x: x - inplace.pred = True - inplace.merge = lambda x: x - _new_val = body_fun(_val, inplace) - return jtu.tree_map( - _make_update, - _new_val, - is_leaf=lambda x: isinstance(x, HadInplaceUpdate), - ) - - return lax.while_loop(cond_fun, _body_fun, init_val) - - if not isinstance(max_steps, int) or max_steps < 0: - raise ValueError("max_steps must be a non-negative integer") - if max_steps == 0: - return init_val - - def _cond_fun(val, step): - return cond_fun(val) & (step < max_steps) - - init_data = (cond_fun(init_val), init_val, 0) - rounded_max_steps = base ** int(math.ceil(math.log(max_steps, base))) - _, val, _ = _while_loop(_cond_fun, body_fun, init_data, rounded_max_steps, base) - return val - - -class _InplaceUpdate(eqx.Module): - pred: Array[bool] - - def __call__(self, val: Array): - return _InplaceUpdateInner(self.pred, val) - - def merge(self, other: "_InplaceUpdate") -> "_InplaceUpdate": - return _InplaceUpdate(self.pred & other.pred) - - -class _InplaceUpdateInner(eqx.Module): - pred: Array[bool] - val: Array - - @property - def at(self): - return _InplaceUpdateInnerInner(self.pred, self.val) - - -class _InplaceUpdateInnerInner(eqx.Module): - pred: Array[bool] - val: Array - - def __getitem__(self, index: Array): - return _InplaceUpdateInnerInnerInner(self.pred, self.val, index) - - -class _InplaceUpdateInnerInnerInner(eqx.Module): - pred: Array[bool] - val: Array - index: Array - - # TODO: implement other .add() etc. methods if required. - - def set(self, update: Array, **kwargs) -> Array: - old = self.val[self.index] - new = lax.select(self.pred, update, old) - return self.val.at[self.index].set(new, **kwargs) - - -class HadInplaceUpdate(eqx.Module): - val: Array - - -# There's several tricks happening here to work around various limitations of JAX. -# (Also see https://github.com/google/jax/issues/2139#issuecomment-1039293633) -# 1. `unvmap_any` prior to using `lax.cond`. JAX has a problem in that vmap-of-cond -# is converted to a `lax.select`, which executes both branches unconditionally. -# Thus writing this naively, using a plain `lax.cond`, will mean the loop always -# runs to `max_steps` when executing under vmap. Instead we run (only) until every -# batch element has finished. -# 2. Treating in-place updates specially in the body_fun. Specifically we need to -# `lax.select` the update-to-make, not the updated buffer. This is because the -# latter instead results in XLA:CPU failing to determine that the buffer can be -# updated in-place, and instead it makes a copy. c.f. JAX issue #8192. -# This is done through the extra `inplace` argument provided to `body_fun`. -# 3. The use of the `@jax.checkpoint` decorator. Backpropagating through a -# `bounded_while_loop` will otherwise run in θ(max_steps) time, rather than -# θ(number of steps actually taken). See -# https://docs.kidger.site/diffrax/devdocs/bounded_while_loop/ -# 4. The use of `base`. In theory `base=2` is optimal at run time, as it implies the -# fewest superfluous operations. In practice this implies quite deep recursion in -# the construction of the bounded while loop, and this slows down the jaxpr -# creation and the XLA compilation. We choose `base=16` as a reasonable-looking -# compromise between compilation time and run time. -def _while_loop(cond_fun, body_fun, data, max_steps, base): - if max_steps == 1: - pred, val, step = data - - inplace_update = _InplaceUpdate(pred) - new_val = body_fun(val, inplace_update) - - def _make_update(_new_val, _val): - if isinstance(_new_val, HadInplaceUpdate): - return _new_val.val - else: - return lax.select(pred, _new_val, _val) - - new_val = jtu.tree_map( - _make_update, - new_val, - val, - is_leaf=lambda x: isinstance(x, HadInplaceUpdate), - ) - new_step = step + 1 - return cond_fun(new_val, new_step), new_val, new_step - else: - - def _call(_data): - return _while_loop(cond_fun, body_fun, _data, max_steps // base, base) - - def _scan_fn(_data, _): - _pred, _, _ = _data - _unvmap_pred = eqxi.unvmap_any(_pred) - return lax.cond(_unvmap_pred, _call, lambda x: x, _data), None - - # Don't put checkpointing on the lowest level - if max_steps != base: - _scan_fn = jax.checkpoint(_scan_fn, prevent_cse=False) - - return lax.scan(_scan_fn, data, xs=None, length=base)[0] diff --git a/diffrax/misc/sde_kl_divergence.py b/diffrax/misc/sde_kl_divergence.py deleted file mode 100644 index 996ff5f2..00000000 --- a/diffrax/misc/sde_kl_divergence.py +++ /dev/null @@ -1,74 +0,0 @@ -import operator - -import equinox as eqx -import jax.numpy as jnp -import jax.tree_util as jtu - -from ..brownian import AbstractBrownianPath -from ..custom_types import PyTree - - -def _kl(drift1, drift2, diffusion): - inv_diffusion = jnp.linalg.pinv(diffusion) - scale = inv_diffusion @ (drift1 - drift2) - return 0.5 * jnp.sum(scale**2) - - -class _AugDrift(eqx.Module): - drift1: callable - drift2: callable - diffusion: callable - context: callable - - def __call__(self, t, y, args): - y, _ = y - context = self.context(t) - aug_y = jnp.concatenate([y, context], axis=-1) - drift1 = self.drift1(t, aug_y, args) - drift2 = self.drift2(t, y, args) - diffusion = self.diffusion(t, y, args) - kl_divergence = jtu.tree_map(_kl, drift1, drift2, diffusion) - kl_divergence = jtu.tree_reduce(operator.add, kl_divergence) - return drift1, kl_divergence - - -class _AugDiffusion(eqx.Module): - diffusion: callable - - def __call__(self, t, y, args): - y, _ = y - diffusion = self.diffusion(t, y, args) - return diffusion, 0.0 - - -class _AugBrownianPath(eqx.Module): - bm: AbstractBrownianPath - - @property - def t0(self): - return self.bm.t0 - - @property - def t1(self): - return self.bm.t1 - - def evaluate(self, t0, t1): - return self.bm.evaluate(t0, t1), 0.0 - - -def sde_kl_divergence( - *, - drift1: callable, - drift2: callable, - diffusion: callable, - context: callable, - y0: PyTree, - bm: AbstractBrownianPath, -): - aug_y0 = (y0, 0.0) - return ( - _AugDrift(drift1, drift2, diffusion, context), - _AugDiffusion(diffusion), - aug_y0, - _AugBrownianPath(bm), - ) diff --git a/diffrax/nonlinear_solver/base.py b/diffrax/nonlinear_solver/base.py index 24872ae5..ee742420 100644 --- a/diffrax/nonlinear_solver/base.py +++ b/diffrax/nonlinear_solver/base.py @@ -8,8 +8,8 @@ import jax.numpy as jnp import jax.scipy as jsp +from ..ad import implicit_jvp from ..custom_types import Int, PyTree, Scalar -from ..misc import implicit_jvp from ..solution import RESULTS diff --git a/diffrax/saveat.py b/diffrax/saveat.py index 2eccc883..800d6083 100644 --- a/diffrax/saveat.py +++ b/diffrax/saveat.py @@ -1,7 +1,6 @@ from typing import Optional, Sequence, Union import equinox as eqx -import jax import jax.numpy as jnp from .custom_types import Array, Scalar @@ -24,9 +23,12 @@ class SaveAt(eqx.Module): made_jump: bool = False def __post_init__(self): - with jax.ensure_compile_time_eval(): - ts = None if self.ts is None else jnp.asarray(self.ts) - object.__setattr__(self, "ts", ts) + if self.ts is not None: + if len(self.ts) == 0: + ts = None + else: + ts = jnp.asarray(self.ts) + object.__setattr__(self, "ts", ts) if ( not self.t0 and not self.t1 diff --git a/diffrax/step_size_controller/adaptive.py b/diffrax/step_size_controller/adaptive.py index 1dee4850..5297050d 100644 --- a/diffrax/step_size_controller/adaptive.py +++ b/diffrax/step_size_controller/adaptive.py @@ -3,7 +3,6 @@ import equinox as eqx import equinox.internal as eqxi -import jax import jax.lax as lax import jax.numpy as jnp import jax.tree_util as jtu @@ -284,9 +283,8 @@ class PIDController(AbstractAdaptiveStepSizeController): def __post_init__(self): super().__post_init__() - with jax.ensure_compile_time_eval(): - step_ts = None if self.step_ts is None else jnp.asarray(self.step_ts) - jump_ts = None if self.jump_ts is None else jnp.asarray(self.jump_ts) + step_ts = None if self.step_ts is None else jnp.asarray(self.step_ts) + jump_ts = None if self.jump_ts is None else jnp.asarray(self.jump_ts) object.__setattr__(self, "step_ts", step_ts) object.__setattr__(self, "jump_ts", jump_ts) diff --git a/diffrax/step_size_controller/constant.py b/diffrax/step_size_controller/constant.py index 7eaf5605..8547065e 100644 --- a/diffrax/step_size_controller/constant.py +++ b/diffrax/step_size_controller/constant.py @@ -1,7 +1,6 @@ from typing import Callable, Optional, Sequence, Tuple, Union import equinox.internal as eqxi -import jax import jax.numpy as jnp from ..custom_types import Array, Int, PyTree, Scalar @@ -15,8 +14,6 @@ class ConstantStepSize(AbstractStepSizeController): [`diffrax.diffeqsolve`][]. """ - compile_steps: Optional[bool] = False - def wrap(self, direction: Scalar): return self @@ -61,30 +58,13 @@ def adapt_step_size( ) -ConstantStepSize.__init__.__doc__ = """**Arguments:** - -- `compile_steps`: If `True` then the number of steps taken in the differential - equation solve will be baked into the compilation. When this is possible then - this can improve compile times and run times slightly. The downside is that this - implies re-compiling if this changes, and that this is only possible if the exact - number of steps to be taken is known in advance (i.e. `t0`, `t1`, `dt0` cannot be - traced values) -- and an error will be thrown if the exact number of steps could - not be determined. Set to `False` (the default) to not bake in the number of steps. - Set to `None` to attempt to bake in the number of steps, but to fall back to - `False`-behaviour if the number of steps could not be determined (rather than - throwing an error). -""" - - class StepTo(AbstractStepSizeController): """Make steps to just prespecified times.""" ts: Union[Sequence[Scalar], Array["times"]] # noqa: F821 - compile_steps: Optional[bool] = False def __post_init__(self): - with jax.ensure_compile_time_eval(): - object.__setattr__(self, "ts", jnp.asarray(self.ts)) + object.__setattr__(self, "ts", jnp.asarray(self.ts)) if self.ts.ndim != 1: raise ValueError("`ts` must be one-dimensional.") if len(self.ts) < 2: @@ -99,7 +79,7 @@ def wrap(self, direction: Scalar): "`StepTo(ts=...)` must be strictly increasing (or strictly decreasing if " "t0 > t1).", ) - return type(self)(ts=ts, compile_steps=self.compile_steps) + return type(self)(ts=ts) def init( self, @@ -153,5 +133,4 @@ def adapt_step_size( between the `t0` and `t1` (inclusive) passed to [`diffrax.diffeqsolve`][]. Correctness of `ts` with respect to `t0` and `t1` as well as its monotonicity is checked by the implementation. -- `compile_steps`: As [`diffrax.ConstantStepSize.__init__`][]. """ diff --git a/docs/api/stepsize_controller.md b/docs/api/stepsize_controller.md index cdd10b7d..4e323ebf 100644 --- a/docs/api/stepsize_controller.md +++ b/docs/api/stepsize_controller.md @@ -31,8 +31,7 @@ The list of step size controllers is as follows. The most common cases are fixed ::: diffrax.ConstantStepSize selection: - members: - - __init__ + members: false ::: diffrax.StepTo selection: diff --git a/docs/devdocs/bounded_while_loop.md b/docs/devdocs/bounded_while_loop.md deleted file mode 100644 index 54614428..00000000 --- a/docs/devdocs/bounded_while_loop.md +++ /dev/null @@ -1,138 +0,0 @@ -# Bounded while loop - -Some notes on implementing a bounded while loop in JAX. (Note that the bound is required for any hope of reverse-mode autodifferentability, due to the static memory requirements imposed by XLA.) - -Let $n$ be the number of steps actually taken. -Let $m$ be the maximum number of steps allowed. -Let $d$ be the depth of the recursive structure, when one is used. -Let $b$ be the base of the recursive structure, when one is used. -(So roughly $b^d = m$.) - -"Forward time" will refer to the amount of work done on the forward pass. -"Backward time" will refer to the amount of work done on the backawrd pass, including recomputing from checkpoints. -"Compile time" will refer to the size of the jaxpr or XLA HLO. (Which we assume to be proportional, although there are a few exceptions to this.) -"Memory usage" will refer to the maximum amount of memory needed to store an entire forward pass, if we land in the case that $n=m$. - -In practice, because XLA statically allocates memory, then the value specified by "memory usage" is actually allocated when performing a backward pass. And as spatial complexity bounds temporal complexity, then the actual backward time is the maximum of "backward time" and "memory usage". (!) - -We use $O(\ldots)$ to denote the costs involved, as usual. We additionally introduce $I(\ldots)$ to denote the cost of performing identity operations, which are used in some implementations instead of making a step. Identity operations are very cheap but not completely free so we count them separately. - -### Implementation 1: `scan`-`cond` - -This implementation just does a `scan` for m steps, checking `cond` on each one. - -Forward time: $O(n) + I(m)$ -Backward time: $O(n) + I(m)$ -Compile time: $O(1)$ -Memory usage: $O(m)$ - -Verdict: unsuitable, because of the huge memory usage. In addition the runtime $I(m)$ is disdvantageous. - -### Implementation 2: nested `scan`-`cond` - -This is probably the first serious idea you come up with when trying to write a bounded while loop. Do a `scan` for $b$ steps, checking `cond` on each one. Nest that implementation recursively $d$ times, so that you make a total of $m$ steps. That is, nested `scan`-`cond`-`scan`-...-`cond` where there are $d$-many `scan`s each of length $b$. - -Forward time: $O(n) + I(db)$ -Backward time: $O(n) + I(db)$ -Compile time: $O(1)$ -Memory usage: $O(m)$ - -This fixes the $I(m)$ runtime of the previous implementation by nesting things, so that you start making larger identity steps once you're done. Unfortunately the $O(m)$ memory usage (and thus speed on the backward pass) remains, so this is still unsuitable. - -### Implementation 3: treeverse - -Okay, memory usage is an issue. The obvious thing to do is to start thinking about gradient checkpointing, for which treeverse is the known optimality result. Assuming $b=2$ for simplicity/optimality, then this is arrived at by recursively calculating `fn(jax.checkpoint(fn)(x))` where the base case takes `fn` to be a `scan` over $b=2$ steps. - -[Morally speaking this is taking the same tree structure as in Implementation 2 and then adding some checkpoints.] - -Forward time: $O(n) + I(d)$ -Backward time: $O(n \log n) + I(d \log d)$ -Compile time: $O(m)$ -Memory usage: $O(d)$ - -[Assuming $b=2$ and therefore it doesn't appear in these values.] - -Great, we've fixed our memory usage! Note that the additional work needing to recompute from our checkpoints increases our backward computation time slightly. - -Unfortunately the compile time has exploded: every level of our recusion involves calling `fn` twice (once inside the checkpoint, once outside) and by doing so recursively we're making $2^d = m$ such calls. Both the jaxpr and the resulting XLA HLO will be of size $O(m)$, as we've basically just written out the whole loop manually! Compile times are already one of the most serious issues facing the JAX ecosystem, so this is also unacceptable. - -Whilst treeverse is optimal for run time, it is maximally nonoptimal for compile time. - -### Implementation 4: naive checkpointing - -Next let's try naive checkpointing. This just means picking some $\sqrt{m}$ equally-spaced points between $0$ and $m$ and placing a checkpoint at each one. Unlike treeverse, this does not use any recursive checkpointing. [Note that this is the kind of checkpointing you often see used in practice with e.g. ResNets etc.] - -This can be implemented very simply: nest `scan`-`cond`-`checkpoint`-`scan`-`cond`, where the length of each `scan` is $\sqrt{m}$. - -Forward time: $O(n) + I(\sqrt{m})$ -Backward time: $O(n) + I(\sqrt{m})$ -Compile time: $O(1)$ -Memory usage: $O(\sqrt{m})$ - -Each intermediate step is re-computed from a checkpoint precisely once, so the backward pass has the same complexity as the forward pass. - -This is a surprisingly decent option: $O(\sqrt{m})$ represents much worse memory usage (and therefore backward computation time) than we'd like, but this still represents a not-completely-awful trade-off compared to our previous options. - -### Implementation 5: nested `checkpoint`-`scan`-`cond` - -Can we combine the best pieces of implementations 2/3 and 4? In other words, nest `scan`-`checkpoint`-`scan`-`cond`-`checkpoint`-`scan`-`cond`-...`checkpoint`-`scan`-`cond`, with $d$-many `scan`s each of length $b$. As an example, in the $b=2$ case and unrolling any individual `scan` produces something a bit like implementation 3, except with `jax.checkpoint(fn)(jax.checkpoint(fn)(x))` instead. - -Forward time: $O(n) + I(db)$ -Backward time: $O(dn) + I(db)$ -Compile time: $O(1)$ -Memory usage: $O(db)$ - -Overall we have performed $O(dn) + I(db)$ work on the backward pass. This is _liveable_... but still not stellar. That $d$ factor slows the backward pass down by a noticable factor. We see that for this to work, we must choose $b \neq 2$ (often an optimal value), as otherwise $d$ becomes large. In practice I've found that tractable values for an ODE solve are something like $d=3$ and $b=16$, for a maximum number of $16^3 = 4096$ steps. - -This is at least better than implementation 4, in that the memory usage, and therefore the practical backward time, has come down from $O(\sqrt{m})$ to $O(b \log m) (=O(db))$. - -Theoretical justification for these values as follows: - -The forward time and compile time are both as in implementation 2. The memory usage can be found by considering backpropagating from the step just prior to the end: we have saved $b-1$ checkpoints at the top level, and we have saved $b-1$ checkpoints at the second (nested) level, etc., for $d$ levels. - -Now for the runtime of the backward pass. - -Suppose we take a lot of steps, so that $n \approx m$. Consider reconstructing the final iteration of the top-level `scan` from its checkpoint at the start. This takes $O(m/b)$ work (the forward evaluation over the proportion of the overall interval that that final iteration covers), and leaves us with a number of checkpoints along the second-level `scan`. The forward evaluation through each of those in turn takes $O(m/b^2)$ work -- by the same logic -- and there are $b$ many of them, once again requiring $O(m/b)$ overall work. This happens for $d$ many levels, so that the overall amount of work to backpropagate through this final iteration is actually $O(dm/b)$. Now the fact that it was the final iteration didn't actually affect this analysis (that was just for pedagogical simplicity), so we do the above procedure $b$ times, for an overall $O(dm) = O(dn)$ amount of work. Meanwhile we take very few identity steps, so the $I$ term is approximately zero. - -Now suppose that we take very few steps, so that $n$ is much smaller than $m$, and in fact contained within just the first top-level iteration (i.e. $n < m/b$). Then all the latter iterations of the top-level `scan` are just the identity and do not contribute anything to our $O$ measurement, so consider just the first iteration. We are now within our top-level checkpointed region, and so need to recompute all of our checkpoints. Once again suppose $n$ is very small and contained within just the first sub-iteration (that is $n < m/b^2$). Repeat ad nauseum, so that the entirety of our $O$-measured work is contained within the very first bottom-level iteration. This bottom-level iteration takes $O(n)$ work to compute in isolation. However we have recomputed it and then discarded it many times: $d - 1$ times, to be precise. Once when computing the checkpoints for the second-level iteration; once when computing the checkpoint for the third-level iteration; etc. And thus overall we have performed $O(dn)$ work. (Meanwhile, the number of identity steps we have neglected in this analysis cost $I(db)$. Indeed they are counted in an identical manner to the forward pass.) - -### Implementation 6? - -Maybe there's another better way of doing it? I make no claims that the above result is as good as it gets. - -## Coda - -### Optimums - -The theoretical optimum without checkpointing is: - -Forward time: $O(n)$ -Backward time: $O(n)$ -Compile time: $O(1)$ -Memory usage: $O(n)$ - -and with treeverse it is: - -Forward time: $O(n)$ -Backward time: $O(n \log n)$ -Compile time: $O(1)$ -Memory usage: $O(\log n)$ - -It is clearly impossible to obtain the non-checkpointing optimum under the JAX/XLA model of computation, due to the requirement that all memory must be statically allocated in advance. (This is a great pity, as it's also the single best option for most problems.) - -- Ever-so-maybe it might be possible to achieve the treeverse optimal value by writing `bounded_while_loop` as a new primitive? This would certainly reduce the jaxpr size down to $O(1)$, but it's not clear (at least to me) what the size of the backward pass, expressed as an XLA HLO expression, must be -- and compile times are proportional to that too. -- Alternatively, a way to compile a function only once (rather than inlining everything) would also make it possible to represent treeverse, as then `jax.checkpoint(fn)` can be compiled in constant time from `fn`, without introducing an exponential explosion as depth progresses. - -### Higher-order derivatives - -I don't know of anything discussing the interaction between checkpointing schemes and higher-order autodifferentiation. Given that a checkpointing scheme is required for memory usage (and thus backward pass speed) to be tractable, then it's not clear to me what the best approach is when this is a concern that needs to be born in mind. - -### Other implementation complexities - -JAX has a variety of other limitations that must be worked around when building a bounded while loop. Most noticably: - -- Handling `vmap` appropriately, as `vmap`'ing a `cond` produces a `select`. (Which would then always run the entire loop to completion.) -- Handling in-place updates. The recursively nested structures here mean that XLA:CPU is unable to optimise away in-place updates made during the body function of the while loop. (And instead makes copies.) -- Actually getting the compile time that the above asymptotics promise. In particular it is possible to get a compile time that is exponential in the size of the program when using nested `cond`s. - -(See the implementation itself for further thoughts on these.) diff --git a/test/helpers.py b/test/helpers.py index 3d8812b9..265ac94b 100644 --- a/test/helpers.py +++ b/test/helpers.py @@ -1,7 +1,5 @@ import functools as ft -import gc import operator -import time import diffrax import equinox as eqx @@ -85,23 +83,3 @@ def shaped_allclose(x, y, **kwargs): return same_structure and jtu.tree_reduce( operator.and_, jtu.tree_map(allclose, x, y), True ) - - -def time_fn(fn, repeat=1): - fn() # Compile - gc_enabled = gc.isenabled() - if gc_enabled: - gc.collect() - gc.disable() - try: - times = [] - for _ in range(repeat): - start = time.perf_counter_ns() - fn() - end = time.perf_counter_ns() - times.append(end - start) - return min(times) - finally: - if gc_enabled: - gc.enable() - gc.collect() diff --git a/test/test_adjoint.py b/test/test_adjoint.py index 0b9f7aee..a733b4da 100644 --- a/test/test_adjoint.py +++ b/test/test_adjoint.py @@ -1,4 +1,3 @@ -import math from typing import Any import diffrax @@ -13,26 +12,6 @@ from .helpers import shaped_allclose -def test_no_adjoint(): - def fn(y0): - term = diffrax.ODETerm(lambda t, y, args: -y) - t0 = 0 - t1 = 1 - dt0 = 0.1 - solver = diffrax.Dopri5() - adjoint = diffrax.NoAdjoint() - sol = diffrax.diffeqsolve(term, solver, t0, t1, dt0, y0, adjoint=adjoint) - return jnp.sum(sol.ys) - - with pytest.raises(ValueError): - jax.grad(fn)(1.0) - - primal, dual = jax.jvp(fn, (1.0,), (1.0,)) - e_inv = 1 / math.e - assert shaped_allclose(primal, e_inv) - assert shaped_allclose(dual, e_inv) - - class _VectorField(eqx.Module): nondiff_arg: int diff_arg: float @@ -107,21 +86,32 @@ def _convert_float0(x): continue saveat = diffrax.SaveAt(t0=t0, t1=t1, ts=ts) + direct_grads = _run_grad( + diff, saveat, diffrax.adjoint.RecursiveCheckpointAdjoint2() + ) recursive_grads = _run_grad( diff, saveat, diffrax.RecursiveCheckpointAdjoint() ) backsolve_grads = _run_grad(diff, saveat, diffrax.BacksolveAdjoint()) - assert shaped_allclose(recursive_grads, backsolve_grads, atol=1e-5) + assert shaped_allclose(direct_grads, recursive_grads, atol=1e-5) + assert shaped_allclose(direct_grads, backsolve_grads, atol=1e-5) + direct_grads = _run_grad_int( + y0__args__term, + saveat, + diffrax.adjoint.RecursiveCheckpointAdjoint2(), + ) recursive_grads = _run_grad_int( y0__args__term, saveat, diffrax.RecursiveCheckpointAdjoint() ) backsolve_grads = _run_grad_int( y0__args__term, saveat, diffrax.BacksolveAdjoint() ) + direct_grads = jtu.tree_map(_convert_float0, direct_grads) recursive_grads = jtu.tree_map(_convert_float0, recursive_grads) backsolve_grads = jtu.tree_map(_convert_float0, backsolve_grads) - assert shaped_allclose(recursive_grads, backsolve_grads, atol=1e-5) + assert shaped_allclose(direct_grads, recursive_grads, atol=1e-5) + assert shaped_allclose(direct_grads, backsolve_grads, atol=1e-5) def test_adjoint_seminorm(): diff --git a/test/test_bounded_while_loop.py b/test/test_bounded_while_loop.py index c939fdb8..03a504dc 100644 --- a/test/test_bounded_while_loop.py +++ b/test/test_bounded_while_loop.py @@ -3,18 +3,11 @@ # - Test grad time # - Test compile time -import functools as ft - -import diffrax -import equinox as eqx import jax -import jax.lax as lax import jax.numpy as jnp -import jax.random as jrandom -import jax.tree_util as jtu -import numpy as np +from diffrax.bounded_while_loop import bounded_while_loop -from .helpers import shaped_allclose, time_fn +from .helpers import shaped_allclose def test_functional_no_vmap_no_inplace(): @@ -22,28 +15,28 @@ def cond_fun(val): x, step = val return step < 5 - def body_fun(val, _): + def body_fun(val): x, step = val return (x + 0.1, step + 1) init_val = (jnp.array([0.3]), 0) - val = diffrax.misc.bounded_while_loop(cond_fun, body_fun, init_val, max_steps=0) + val = bounded_while_loop(cond_fun, body_fun, init_val, max_steps=0) assert shaped_allclose(val[0], jnp.array([0.3])) and val[1] == 0 - val = diffrax.misc.bounded_while_loop(cond_fun, body_fun, init_val, max_steps=1) + val = bounded_while_loop(cond_fun, body_fun, init_val, max_steps=1) assert shaped_allclose(val[0], jnp.array([0.4])) and val[1] == 1 - val = diffrax.misc.bounded_while_loop(cond_fun, body_fun, init_val, max_steps=2) + val = bounded_while_loop(cond_fun, body_fun, init_val, max_steps=2) assert shaped_allclose(val[0], jnp.array([0.5])) and val[1] == 2 - val = diffrax.misc.bounded_while_loop(cond_fun, body_fun, init_val, max_steps=4) + val = bounded_while_loop(cond_fun, body_fun, init_val, max_steps=4) assert shaped_allclose(val[0], jnp.array([0.7])) and val[1] == 4 - val = diffrax.misc.bounded_while_loop(cond_fun, body_fun, init_val, max_steps=8) + val = bounded_while_loop(cond_fun, body_fun, init_val, max_steps=8) assert shaped_allclose(val[0], jnp.array([0.8])) and val[1] == 5 - val = diffrax.misc.bounded_while_loop(cond_fun, body_fun, init_val, max_steps=None) + val = bounded_while_loop(cond_fun, body_fun, init_val, max_steps=None) assert shaped_allclose(val[0], jnp.array([0.8])) and val[1] == 5 @@ -52,32 +45,30 @@ def cond_fun(val): x, step = val return step < 5 - def body_fun(val, inplace): + def body_fun(val): x, step = val - x = inplace(x).at[jnp.minimum(step + 1, 4)].set(x[step] + 0.1) - step = inplace(step).at[()].set(step + 1) - x = diffrax.misc.HadInplaceUpdate(x) - step = diffrax.misc.HadInplaceUpdate(step) + x = x.at[jnp.minimum(step + 1, 4)].set(x[step] + 0.1) + step = step.at[()].set(step + 1) return x, step init_val = (jnp.array([0.3, 0.3, 0.3, 0.3, 0.3]), 0) - val = diffrax.misc.bounded_while_loop(cond_fun, body_fun, init_val, max_steps=0) + val = bounded_while_loop(cond_fun, body_fun, init_val, max_steps=0) assert shaped_allclose(val[0], jnp.array([0.3, 0.3, 0.3, 0.3, 0.3])) and val[1] == 0 - val = diffrax.misc.bounded_while_loop(cond_fun, body_fun, init_val, max_steps=1) + val = bounded_while_loop(cond_fun, body_fun, init_val, max_steps=1) assert shaped_allclose(val[0], jnp.array([0.3, 0.4, 0.3, 0.3, 0.3])) and val[1] == 1 - val = diffrax.misc.bounded_while_loop(cond_fun, body_fun, init_val, max_steps=2) + val = bounded_while_loop(cond_fun, body_fun, init_val, max_steps=2) assert shaped_allclose(val[0], jnp.array([0.3, 0.4, 0.5, 0.3, 0.3])) and val[1] == 2 - val = diffrax.misc.bounded_while_loop(cond_fun, body_fun, init_val, max_steps=4) + val = bounded_while_loop(cond_fun, body_fun, init_val, max_steps=4) assert shaped_allclose(val[0], jnp.array([0.3, 0.4, 0.5, 0.6, 0.7])) and val[1] == 4 - val = diffrax.misc.bounded_while_loop(cond_fun, body_fun, init_val, max_steps=8) + val = bounded_while_loop(cond_fun, body_fun, init_val, max_steps=8) assert shaped_allclose(val[0], jnp.array([0.3, 0.4, 0.5, 0.6, 0.8])) and val[1] == 5 - val = diffrax.misc.bounded_while_loop(cond_fun, body_fun, init_val, max_steps=None) + val = bounded_while_loop(cond_fun, body_fun, init_val, max_steps=None) assert shaped_allclose(val[0], jnp.array([0.3, 0.4, 0.5, 0.6, 0.8])) and val[1] == 5 @@ -86,50 +77,50 @@ def cond_fun(val): x, step = val return step < 5 - def body_fun(val, _): + def body_fun(val): x, step = val return (x + 0.1, step + 1) init_val = (jnp.array([[0.3], [0.4]]), jnp.array([0, 3])) - val = jax.vmap( - lambda v: diffrax.misc.bounded_while_loop(cond_fun, body_fun, v, max_steps=0) - )(init_val) + val = jax.vmap(lambda v: bounded_while_loop(cond_fun, body_fun, v, max_steps=0))( + init_val + ) assert shaped_allclose(val[0], jnp.array([[0.3], [0.4]])) and jnp.array_equal( val[1], jnp.array([0, 3]) ) - val = jax.vmap( - lambda v: diffrax.misc.bounded_while_loop(cond_fun, body_fun, v, max_steps=1) - )(init_val) + val = jax.vmap(lambda v: bounded_while_loop(cond_fun, body_fun, v, max_steps=1))( + init_val + ) assert shaped_allclose(val[0], jnp.array([[0.4], [0.5]])) and jnp.array_equal( val[1], jnp.array([1, 4]) ) - val = jax.vmap( - lambda v: diffrax.misc.bounded_while_loop(cond_fun, body_fun, v, max_steps=2) - )(init_val) + val = jax.vmap(lambda v: bounded_while_loop(cond_fun, body_fun, v, max_steps=2))( + init_val + ) assert shaped_allclose(val[0], jnp.array([[0.5], [0.6]])) and jnp.array_equal( val[1], jnp.array([2, 5]) ) - val = jax.vmap( - lambda v: diffrax.misc.bounded_while_loop(cond_fun, body_fun, v, max_steps=4) - )(init_val) + val = jax.vmap(lambda v: bounded_while_loop(cond_fun, body_fun, v, max_steps=4))( + init_val + ) assert shaped_allclose(val[0], jnp.array([[0.7], [0.6]])) and jnp.array_equal( val[1], jnp.array([4, 5]) ) - val = jax.vmap( - lambda v: diffrax.misc.bounded_while_loop(cond_fun, body_fun, v, max_steps=8) - )(init_val) + val = jax.vmap(lambda v: bounded_while_loop(cond_fun, body_fun, v, max_steps=8))( + init_val + ) assert shaped_allclose(val[0], jnp.array([[0.8], [0.6]])) and jnp.array_equal( val[1], jnp.array([5, 5]) ) - val = jax.vmap( - lambda v: diffrax.misc.bounded_while_loop(cond_fun, body_fun, v, max_steps=None) - )(init_val) + val = jax.vmap(lambda v: bounded_while_loop(cond_fun, body_fun, v, max_steps=None))( + init_val + ) assert shaped_allclose(val[0], jnp.array([[0.8], [0.6]])) and jnp.array_equal( val[1], jnp.array([5, 5]) ) @@ -140,12 +131,10 @@ def cond_fun(val): x, step, max_step = val return step < max_step - def body_fun(val, inplace): + def body_fun(val): x, step, max_step = val - x = inplace(x).at[jnp.minimum(step + 1, 4)].set(x[step] + 0.1) - step = inplace(step).at[()].set(step + 1) - x = diffrax.misc.HadInplaceUpdate(x) - step = diffrax.misc.HadInplaceUpdate(step) + x = x.at[jnp.minimum(step + 1, 4)].set(x[step] + 0.1) + step = step.at[()].set(step + 1) return x, step, max_step init_val = ( @@ -154,181 +143,44 @@ def body_fun(val, inplace): jnp.array([5, 3]), ) - val = jax.vmap( - lambda v: diffrax.misc.bounded_while_loop(cond_fun, body_fun, v, max_steps=0) - )(init_val) + val = jax.vmap(lambda v: bounded_while_loop(cond_fun, body_fun, v, max_steps=0))( + init_val + ) assert shaped_allclose( val[0], jnp.array([[0.3, 0.3, 0.3, 0.3, 0.3], [0.4, 0.4, 0.4, 0.4, 0.4]]) ) and jnp.array_equal(val[1], jnp.array([0, 1])) - val = jax.vmap( - lambda v: diffrax.misc.bounded_while_loop(cond_fun, body_fun, v, max_steps=1) - )(init_val) + val = jax.vmap(lambda v: bounded_while_loop(cond_fun, body_fun, v, max_steps=1))( + init_val + ) assert shaped_allclose( val[0], jnp.array([[0.3, 0.4, 0.3, 0.3, 0.3], [0.4, 0.4, 0.5, 0.4, 0.4]]) ) and jnp.array_equal(val[1], jnp.array([1, 2])) - val = jax.vmap( - lambda v: diffrax.misc.bounded_while_loop(cond_fun, body_fun, v, max_steps=2) - )(init_val) + val = jax.vmap(lambda v: bounded_while_loop(cond_fun, body_fun, v, max_steps=2))( + init_val + ) assert shaped_allclose( val[0], jnp.array([[0.3, 0.4, 0.5, 0.3, 0.3], [0.4, 0.4, 0.5, 0.6, 0.4]]) ) and jnp.array_equal(val[1], jnp.array([2, 3])) - val = jax.vmap( - lambda v: diffrax.misc.bounded_while_loop(cond_fun, body_fun, v, max_steps=4) - )(init_val) + val = jax.vmap(lambda v: bounded_while_loop(cond_fun, body_fun, v, max_steps=4))( + init_val + ) assert shaped_allclose( val[0], jnp.array([[0.3, 0.4, 0.5, 0.6, 0.7], [0.4, 0.4, 0.5, 0.6, 0.4]]) ) and jnp.array_equal(val[1], jnp.array([4, 3])) - val = jax.vmap( - lambda v: diffrax.misc.bounded_while_loop(cond_fun, body_fun, v, max_steps=8) - )(init_val) + val = jax.vmap(lambda v: bounded_while_loop(cond_fun, body_fun, v, max_steps=8))( + init_val + ) assert shaped_allclose( val[0], jnp.array([[0.3, 0.4, 0.5, 0.6, 0.8], [0.4, 0.4, 0.5, 0.6, 0.4]]) ) and jnp.array_equal(val[1], jnp.array([5, 3])) - val = jax.vmap( - lambda v: diffrax.misc.bounded_while_loop(cond_fun, body_fun, v, max_steps=None) - )(init_val) + val = jax.vmap(lambda v: bounded_while_loop(cond_fun, body_fun, v, max_steps=None))( + init_val + ) assert shaped_allclose( val[0], jnp.array([[0.3, 0.4, 0.5, 0.6, 0.8], [0.4, 0.4, 0.5, 0.6, 0.4]]) ) and jnp.array_equal(val[1], jnp.array([5, 3])) - - -# -# Test speed. Two things are tested: -# - asymptotic computational complexity; -# - speed compared to `lax.while_loop`. -# - - -def _make_update(i, u, v): - return u if i is None else v.at[i].set(u) - - -def _body_fun(body_fun): - def __body_fun(val): - update, index = body_fun(val) - return jtu.tree_map(_make_update, index, update, val) - - return __body_fun - - -def _quadratic_fit(x, y): - return np.polynomial.Polynomial.fit(x, y, deg=2).convert().coef - - -def _test_scaling_max_steps(): - key = jrandom.PRNGKey(567) - expensive_fn = eqx.nn.MLP(in_size=1, out_size=1, width_size=1024, depth=2, key=key) - - def cond_fun(val): - x, step = val - return step < 5 - - def body_fun(val): - x, step = val - return (expensive_fn(x[step, None])[0], step + 1), ( - jnp.minimum(step + 1, 5), - None, - ) - - init_val = ( - jnp.array([[0.3, 0.3, 0.3, 0.3, 0.3], [0.4, 0.4, 0.4, 0.4, 0.4]]), - jnp.array([0, 3]), - ) - - @ft.partial(jax.jit, static_argnums=1) - @ft.partial(jax.vmap, in_axes=(0, None)) - def test_fun(val, max_steps): - return diffrax.misc.bounded_while_loop(cond_fun, body_fun, val, max_steps) - - time16 = time_fn(lambda: test_fun(init_val, 16), repeat=10) - time32 = time_fn(lambda: test_fun(init_val, 32), repeat=10) - time64 = time_fn(lambda: test_fun(init_val, 64), repeat=10) - time128 = time_fn(lambda: test_fun(init_val, 128), repeat=10) - time256 = time_fn(lambda: test_fun(init_val, 256), repeat=10) - maxtime = max(time16, time32, time64, time128, time256) - - # Rescale to fit the graph inside [0, 1] x [0, 1] so that polynomials are actually - # a reasonable thing to use. - _, c1, c2 = _quadratic_fit( - [16 / 256, 32 / 256, 64 / 256, 128 / 256, 256 / 256], - [ - time16 / maxtime, - time32 / maxtime, - time64 / maxtime, - time128 / maxtime, - time256 / maxtime, - ], - ) - # Runtime expected to be O(1) - assert -0.05 < c1 < 0.05 - assert -0.05 < c2 < 0.05 - - @ft.partial(jax.jit, static_argnums=1) - @jax.vmap - def lax_test_fun(val): - return lax.while_loop(cond_fun, _body_fun(body_fun), val) - - lax_time = time_fn(lambda: lax_test_fun(init_val), repeat=10) - - assert maxtime < 2 * lax_time - - -def _test_scaling_num_steps(): - key = jrandom.PRNGKey(567) - expensive_fn = eqx.nn.MLP(in_size=1, out_size=1, width_size=1024, depth=2, key=key) - - def cond_fun(val): - x, step, num_steps = val - return step < num_steps - - def body_fun(val): - x, step, num_steps = val - return (expensive_fn(x[step, None])[0], step + 1, num_steps), ( - jnp.minimum(step + 1, num_steps), - None, - None, - ) - - init_val = (jnp.array([[0.3] * 256, [0.4] * 256]), jnp.array([0, 3])) - - @ft.partial(jax.jit, static_argnums=1) - @ft.partial(jax.vmap, in_axes=(0, None)) - def test_fun(val, num_steps): - return diffrax.misc.bounded_while_loop( - cond_fun, body_fun, (*val, num_steps), max_steps=256 - ) - - time16 = time_fn(lambda: test_fun(init_val, 16), repeat=10) - time32 = time_fn(lambda: test_fun(init_val, 32), repeat=10) - time64 = time_fn(lambda: test_fun(init_val, 64), repeat=10) - time128 = time_fn(lambda: test_fun(init_val, 128), repeat=10) - time256 = time_fn(lambda: test_fun(init_val, 256), repeat=10) - - _, c1, c2 = _quadratic_fit( - [16, 32, 64, 128, 256], [time16, time32, time64, time128, time256] - ) - # Runtime expected to be O(steps taken) - assert 0.95 < c1 < 1.05 - assert -0.05 < c2 < 0.05 - - @ft.partial(jax.jit, static_argnums=1) - @ft.partial(jax.vmap, in_axes=(0, None)) - def lax_test_fun(val, num_steps): - return lax.while_loop(cond_fun, _body_fun(body_fun), (*val, num_steps)) - - lax_time16 = time_fn(lambda: lax_test_fun(init_val, 16), repeat=10) - lax_time32 = time_fn(lambda: lax_test_fun(init_val, 32), repeat=10) - lax_time64 = time_fn(lambda: lax_test_fun(init_val, 64), repeat=10) - lax_time128 = time_fn(lambda: lax_test_fun(init_val, 128), repeat=10) - lax_time256 = time_fn(lambda: lax_test_fun(init_val, 256), repeat=10) - - assert time16 < 2 * lax_time16 - assert time32 < 2 * lax_time32 - assert time64 < 2 * lax_time64 - assert time128 < 2 * lax_time128 - assert time256 < 2 * lax_time256 diff --git a/test/test_integrate.py b/test/test_integrate.py index 41d897e6..c8196613 100644 --- a/test/test_integrate.py +++ b/test/test_integrate.py @@ -336,203 +336,6 @@ def test_semi_implicit_euler(): assert shaped_allclose(sol1.ys, sol2.ys) -def test_compile_time_steps(): - terms = diffrax.ODETerm(lambda t, y, args: -y) - y0 = jnp.array([1.0]) - solver = diffrax.Tsit5() - - sol = diffrax.diffeqsolve( - terms, - solver, - 0, - 1, - None, - y0, - stepsize_controller=diffrax.PIDController(rtol=1e-3, atol=1e-6), - ) - assert sol.stats["compiled_num_steps"] is None - - sol = diffrax.diffeqsolve( - terms, - solver, - 0, - 1, - 0.1, - y0, - stepsize_controller=diffrax.PIDController(rtol=1e-3, atol=1e-6), - ) - assert sol.stats["compiled_num_steps"] is None - - sol = diffrax.diffeqsolve( - terms, - solver, - 0, - 1, - 0.1, - y0, - stepsize_controller=diffrax.ConstantStepSize(compile_steps=True), - ) - assert shaped_allclose(sol.stats["compiled_num_steps"], 10) - - sol = diffrax.diffeqsolve( - terms, - solver, - 0, - 1, - 0.1, - y0, - stepsize_controller=diffrax.ConstantStepSize(compile_steps=None), - ) - assert shaped_allclose(sol.stats["compiled_num_steps"], 10) - - sol = diffrax.diffeqsolve( - terms, - solver, - 0, - 1, - 0.1, - y0, - stepsize_controller=diffrax.ConstantStepSize(compile_steps=False), - ) - assert sol.stats["compiled_num_steps"] is None - - sol = diffrax.diffeqsolve( - terms, - solver, - 0, - 1, - None, - y0, - stepsize_controller=diffrax.StepTo([0, 0.3, 0.5, 1], compile_steps=True), - ) - assert shaped_allclose(sol.stats["compiled_num_steps"], 3) - - sol = diffrax.diffeqsolve( - terms, - solver, - 0, - 1, - None, - y0, - stepsize_controller=diffrax.StepTo([0, 0.3, 0.5, 1], compile_steps=None), - ) - assert shaped_allclose(sol.stats["compiled_num_steps"], 3) - - sol = diffrax.diffeqsolve( - terms, - solver, - 0, - 1, - None, - y0, - stepsize_controller=diffrax.StepTo([0, 0.3, 0.5, 1], compile_steps=False), - ) - assert sol.stats["compiled_num_steps"] is None - - with pytest.raises(ValueError): - sol = jax.jit( - lambda t0: diffrax.diffeqsolve( - terms, - solver, - t0, - 1, - 0.1, - y0, - stepsize_controller=diffrax.ConstantStepSize(compile_steps=True), - ) - )(0) - - sol = jax.jit( - lambda t0: diffrax.diffeqsolve( - terms, - solver, - t0, - 1, - 0.1, - y0, - stepsize_controller=diffrax.ConstantStepSize(compile_steps=None), - ) - )(0) - assert sol.stats["compiled_num_steps"] is None - - sol = jax.jit( - lambda t1: diffrax.diffeqsolve( - terms, - solver, - 0, - t1, - 0.1, - y0, - stepsize_controller=diffrax.ConstantStepSize(compile_steps=None), - ) - )(1) - assert sol.stats["compiled_num_steps"] is None - - sol = jax.jit( - lambda dt0: diffrax.diffeqsolve( - terms, - solver, - 0, - 1, - dt0, - y0, - stepsize_controller=diffrax.ConstantStepSize(compile_steps=None), - ) - )(0.1) - assert sol.stats["compiled_num_steps"] is None - - # Work around JAX issue #9298 - diffeqsolve_nojit = diffrax.diffeqsolve.__wrapped__ - - _t0 = jnp.array([0, 0]) - sol = jax.jit( - lambda: jax.vmap( - lambda t0: diffeqsolve_nojit( - terms, - solver, - t0, - 1, - 0.1, - y0, - stepsize_controller=diffrax.ConstantStepSize(compile_steps=True), - ) - )(_t0) - )() - assert shaped_allclose(sol.stats["compiled_num_steps"], jnp.array([10, 10])) - - _t1 = jnp.array([1, 2]) - sol = jax.jit( - lambda: jax.vmap( - lambda t1: diffeqsolve_nojit( - terms, - solver, - 0, - t1, - 0.1, - y0, - stepsize_controller=diffrax.ConstantStepSize(compile_steps=True), - ) - )(_t1) - )() - assert shaped_allclose(sol.stats["compiled_num_steps"], jnp.array([20, 20])) - - _dt0 = jnp.array([0.1, 0.05]) - sol = jax.jit( - lambda: jax.vmap( - lambda dt0: diffeqsolve_nojit( - terms, - solver, - 0, - 1, - dt0, - y0, - stepsize_controller=diffrax.ConstantStepSize(compile_steps=True), - ) - )(_dt0) - )() - assert shaped_allclose(sol.stats["compiled_num_steps"], jnp.array([20, 20])) - - @pytest.mark.parametrize( "solver", [ From da5cac8f3041000410137fc43a376d79c145d8ca Mon Sep 17 00:00:00 2001 From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com> Date: Sun, 22 Jan 2023 11:46:44 -0800 Subject: [PATCH 05/19] Moved checkpoint handling to Equinox --- diffrax/adjoint.py | 47 ++++++++++------------------------------------ 1 file changed, 10 insertions(+), 37 deletions(-) diff --git a/diffrax/adjoint.py b/diffrax/adjoint.py index 06427590..20a28794 100644 --- a/diffrax/adjoint.py +++ b/diffrax/adjoint.py @@ -1,6 +1,5 @@ import abc import functools as ft -import math from typing import Any, Dict, Optional import equinox as eqx @@ -244,45 +243,19 @@ def loop( **kwargs, ): del throw, passed_solver_state, passed_controller_state - if self.checkpoints is None: - if max_steps is None: - raise ValueError( - "Cannot use " - "`diffeqsolve(..., max_steps=None, adjoint=RecursiveCheckpointAdjoint(checkpoints=None))` " # noqa: E501 - "Either specify the number of `checkpoints` to use, or specify the " - "maximum number of steps (and `checkpoints` is chosen " - "automatically as `log2(max_steps)``.)" - ) - # Binomial logarithmic growth is what is needed in classical treeverse. - # - # Moreover this is optimal even in the online case, as provided - # `max_steps >= 21` - # then - # `checkpoints = ceil(log2(max_steps))` - # satisfies - # `max_steps <= (checkpoints + 1)(checkpoints + 2)/2` - # which is the condition for optimality. - # - # Meanwhile if - # `max_steps <= 20` - # then we handle it as a special case, to once again ensure we satisfy - # `max_steps <= (checkpoints + 1)(checkpoints + 2)/2` - # - # The optimality condition is equation (2.2) of - # "New Algorithms for Optimal Online Checkpointing", Stumm and Walther 2010. - # https://tu-dresden.de/mn/math/wir/ressourcen/dateien/forschung/publikationen/pdf2010/new_algorithms_for_optimal_online_checkpointing.pdf - if max_steps <= 20: - checkpoints = 1 - while (checkpoints + 1) * (checkpoints + 2) < 2 * max_steps: - checkpoints += 1 - else: - checkpoints = math.ceil(math.log2(max_steps)) - else: - checkpoints = self.checkpoints + if self.checkpoints is None and max_steps is None: + # Raise a more informative error than `checkpointed_while_loop` would. + raise ValueError( + "Cannot use " + "`diffeqsolve(..., max_steps=None, adjoint=RecursiveCheckpointAdjoint(checkpoints=None))` " # noqa: E501 + "Either specify the number of `checkpoints` to use, or specify the " + "maximum number of steps (and `checkpoints` is chosen " + "automatically as `log2(max_steps)``.)" + ) return self._loop_fn( max_steps=max_steps, while_loop=ft.partial( - eqxi.checkpointed_while_loop, checkpoints=checkpoints + eqxi.checkpointed_while_loop, checkpoints=self.checkpoints ), **kwargs, ) From 79b9a50e2db4ea2fe9ed6090988e0d84d144efcb Mon Sep 17 00:00:00 2001 From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com> Date: Thu, 26 Jan 2023 13:30:18 -0800 Subject: [PATCH 06/19] RK solvers now do their linear ops at highest precision --- README.md | 2 +- diffrax/solver/base.py | 5 ++++- setup.py | 4 ++-- 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index e3f0e3d3..4692b526 100644 --- a/README.md +++ b/README.md @@ -21,7 +21,7 @@ _From a technical point of view, the internal structure of the library is pretty pip install diffrax ``` -Requires Python >=3.7 and JAX >=0.3.4. +Requires Python >=3.8 and JAX >=0.4.1. ## Documentation diff --git a/diffrax/solver/base.py b/diffrax/solver/base.py index 090f7d84..9a4c191e 100644 --- a/diffrax/solver/base.py +++ b/diffrax/solver/base.py @@ -2,6 +2,7 @@ from typing import Callable, Optional, Tuple, TypeVar import equinox as eqx +import jax.lax as lax import jax.numpy as jnp import jax.tree_util as jtu @@ -17,7 +18,9 @@ def vector_tree_dot(a, b): - return jtu.tree_map(lambda bi: jnp.tensordot(a, bi, axes=1), b) + return jtu.tree_map( + lambda bi: jnp.tensordot(a, bi, axes=1, precision=lax.Precision.HIGHEST), b + ) class _MetaAbstractSolver(type(eqx.Module)): diff --git a/setup.py b/setup.py index 62c12ae0..c8820668 100644 --- a/setup.py +++ b/setup.py @@ -44,9 +44,9 @@ "Topic :: Scientific/Engineering :: Mathematics", ] -python_requires = "~=3.7" +python_requires = "~=3.8" -install_requires = ["jax>=0.3.4", "equinox>=0.10.0"] +install_requires = ["jax>=0.4.1", "equinox>=0.10.0"] setuptools.setup( name=name, From d6ed3719062b9ff779071d742a97c3889bc43713 Mon Sep 17 00:00:00 2001 From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com> Date: Thu, 26 Jan 2023 16:30:31 -0800 Subject: [PATCH 07/19] Adjoint API is now DirectAdjoint and RecursiveCheckpointAdjoint --- diffrax/adjoint.py | 111 +++++++++++++++++++++++++++++++++---------- diffrax/integrate.py | 49 ++++--------------- docs/api/adjoints.md | 17 ++++--- test/test_adjoint.py | 8 +--- 4 files changed, 106 insertions(+), 79 deletions(-) diff --git a/diffrax/adjoint.py b/diffrax/adjoint.py index 20a28794..19251b5f 100644 --- a/diffrax/adjoint.py +++ b/diffrax/adjoint.py @@ -1,5 +1,6 @@ import abc import functools as ft +import warnings from typing import Any, Dict, Optional import equinox as eqx @@ -11,7 +12,9 @@ from .ad import implicit_jvp from .bounded_while_loop import bounded_while_loop +from .heuristics import is_unsafe_sde from .saveat import SaveAt +from .solver import AbstractItoSolver, AbstractStratonovichSolver from .term import AbstractTerm, AdjointTerm @@ -122,7 +125,7 @@ def loop( # `integrate.py`. For convenience we make them available as properties here so all # adjoint methods can access these. @property - def _loop_fn(self): + def _loop(self): from .integrate import loop return loop @@ -134,23 +137,40 @@ def _diffeqsolve(self): return diffeqsolve -class RecursiveCheckpointAdjoint(AbstractAdjoint): - """Backpropagate through [`diffrax.diffeqsolve`][] by differentiating the numerical - solution directly. This is sometimes known as "discretise-then-optimise", or - described as "backpropagation through the solver". +class DirectAdjoint(AbstractAdjoint): + """A variant of [`diffrax.RecursiveCheckpointAdjoint`][]. The differences are that + `DirectAdjoint`: - Uses a binomial checkpointing scheme to keep memory usage low. + - Is less time+memory efficient at reverse-mode autodifferentiation (specifically, + these will increase every time `max_steps` increases passes a power of 16); + - Cannot be reverse-mode autodifferentated if `max_steps is None`; + - Supports forward-mode autodifferentiation. - For most problems this is the preferred technique for backpropagating through a - differential equation. + So unless you need forward-mode autodifferentiation then + [`diffrax.RecursiveCheckpointAdjoint`][] should be preferred. """ - def loop(self, *, throw, passed_solver_state, passed_controller_state, **kwargs): + def loop( + self, + *, + max_steps, + terms, + throw, + passed_solver_state, + passed_controller_state, + **kwargs, + ): del throw, passed_solver_state, passed_controller_state - return self._loop_fn(**kwargs, while_loop=bounded_while_loop) + if is_unsafe_sde(terms) or max_steps is None: + while_loop = _while_loop + else: + while_loop = bounded_while_loop + return self._loop( + **kwargs, max_steps=max_steps, terms=terms, while_loop=while_loop + ) -class RecursiveCheckpointAdjoint2(AbstractAdjoint): +class RecursiveCheckpointAdjoint(AbstractAdjoint): """Backpropagate through [`diffrax.diffeqsolve`][] by differentiating the numerical solution directly. This is sometimes known as "discretise-then-optimise", or described as "backpropagation through the solver". @@ -163,7 +183,7 @@ class RecursiveCheckpointAdjoint2(AbstractAdjoint): !!! info Note that this cannot be forward-mode autodifferentiated. (E.g. using - `jax.jvp`.) + `jax.jvp`.) Try using [`diffrax.DirectAdjoint`][] if that is something you need. ??? cite "References" @@ -236,6 +256,8 @@ class RecursiveCheckpointAdjoint2(AbstractAdjoint): def loop( self, *, + terms, + init_state, max_steps, throw, passed_solver_state, @@ -249,10 +271,23 @@ def loop( "Cannot use " "`diffeqsolve(..., max_steps=None, adjoint=RecursiveCheckpointAdjoint(checkpoints=None))` " # noqa: E501 "Either specify the number of `checkpoints` to use, or specify the " - "maximum number of steps (and `checkpoints` is chosen " - "automatically as `log2(max_steps)``.)" + "maximum number of steps (and `checkpoints` is then chosen " + "automatically as `log(max_steps)`)." ) - return self._loop_fn( + if is_unsafe_sde(terms): + raise ValueError( + "`adjoint=RecursiveCheckpointAdjoint()` does not support " + "`UnsafeBrownianPath`. Consider using `adjoint=DirectAdjoint()` " + "instead." + ) + init_state = eqx.tree_at( + lambda s: (s.ts, s.ys, s.dense_ts, s.dense_infos), + init_state, + replace_fn=eqxi.Buffer, + ) + return self._loop( + terms=terms, + init_state=init_state, max_steps=max_steps, while_loop=ft.partial( eqxi.checkpointed_while_loop, checkpoints=self.checkpoints @@ -261,16 +296,17 @@ def loop( ) -class NoAdjoint(AbstractAdjoint): - """Disable backpropagation through [`diffrax.diffeqsolve`][]. - Forward-mode autodifferentiation (`jax.jvp`) will continue to work as normal. - If you do not need to differentiate the results of [`diffrax.diffeqsolve`][] then - this may sometimes improve the speed at which the differential equation is solved. - """ +RecursiveCheckpointAdjoint.__init__.__doc__ = """ +**Arguments:** - def loop(self, *, throw, passed_solver_state, passed_controller_state, **kwargs): - del throw, passed_solver_state, passed_controller_state - return self._loop_fn(**kwargs, while_loop=_while_loop) +- `checkpoints`: the number of checkpoints to save. The amount of memory used by the + differential equation solve will be roughly equal to the number of checkpoints + multiplied by the size of `y0`. You can speed up backpropagation by allocating more + checkpoints. (So it makes sense to set as many checkpoints as you have memory for.) + This value can also be set to `None` (the default), in which case it will be set to + `log(max_steps)`, for which a theoretical result is available guaranteeing that + backpropagation will take `O(n log n)` time in the number of steps `n <= max_steps`. +""" def _vf(ys, residual, args__terms, closure): @@ -333,7 +369,8 @@ def loop( # `is` check because this may return a Tracer from SaveAt(ts=) if eqx.tree_equal(saveat, SaveAt(t1=True)) is not True: raise ValueError( - "Can only use `adjoint=ImplicitAdjoint()` with `SaveAt(t1=True)`." + "Can only use `adjoint=ImplicitAdjoint()` with " + "`saveat=SaveAt(t1=True)`." ) if not passed_solver_state: @@ -608,6 +645,7 @@ def loop( *, args, terms, + solver, saveat, init_state, passed_solver_state, @@ -620,6 +658,22 @@ def loop( "Cannot use `adjoint=BacksolveAdjoint()` with " "`saveat=Steps(steps=True)` or `saveat=Steps(dense=True)`." ) + if is_unsafe_sde(terms): + raise ValueError( + "`adjoint=BacksolveAdjoint()` does not support `UnsafeBrownianPath`. " + "Consider using `adjoint=DirectAdjoint()` instead." + ) + if isinstance(solver, AbstractItoSolver): + raise NotImplementedError( + f"`{solver.__name__}` converges to the Itô solution. However " + "`BacksolveAdjoint` currently only supports Stratonovich SDEs." + ) + elif not isinstance(solver, AbstractStratonovichSolver): + warnings.warn( + f"{solver.__name__} is not marked as converging to either the Itô " + "or the Stratonovich solution. Note that `BacksolveAdjoint` will " + "only produce the correct solution for Stratonovich SDEs." + ) y = init_state.y sentinel = object() @@ -628,7 +682,12 @@ def loop( ) final_state, aux_stats = _loop_backsolve( - (y, args, terms), self=self, saveat=saveat, init_state=init_state, **kwargs + (y, args, terms), + self=self, + saveat=saveat, + init_state=init_state, + solver=solver, + **kwargs, ) final_state = _no_transpose_final_state(final_state) return final_state, aux_stats diff --git a/diffrax/integrate.py b/diffrax/integrate.py index f1d1a16f..cb43a90e 100644 --- a/diffrax/integrate.py +++ b/diffrax/integrate.py @@ -8,13 +8,7 @@ import jax.numpy as jnp import jax.tree_util as jtu -from .adjoint import ( - AbstractAdjoint, - BacksolveAdjoint, - ImplicitAdjoint, - NoAdjoint, - RecursiveCheckpointAdjoint, -) +from .adjoint import AbstractAdjoint, RecursiveCheckpointAdjoint from .bounded_while_loop import bounded_while_loop from .custom_types import Array, Bool, Int, PyTree, Scalar from .event import AbstractDiscreteTerminatingEvent @@ -384,7 +378,7 @@ def diffeqsolve( - `dt0`: The step size to use for the first step. If using fixed step sizes then this will also be the step size for all other steps. (Except the last one, which may be slightly smaller and clipped to `t1`.) If set as `None` then the - initial step size will be determined automatically if possible. + initial step size will be determined automatically. - `y0`: The initial value. This can be any PyTree of JAX arrays. (Or types that can be coerced to JAX arrays, like Python floats.) - `args`: Any additional arguments to pass to the vector field. @@ -397,13 +391,12 @@ def diffeqsolve( **Other arguments:** - These arguments are infrequently used, and for most purposes you shouldn't need to - understand these. All of these are keyword-only arguments. + These arguments are less frequently used, and for most purposes you shouldn't need + to understand these. All of these are keyword-only arguments. - - `adjoint`: How to backpropagate (and compute forward-mode autoderivatives) of - `diffeqsolve`. Defaults to discretise-then-optimise, which is usually the best - option for most problems. See the page on [Adjoints](./adjoints.md) for more - information. + - `adjoint`: How to differentiate `diffeqsolve`. Defaults to + discretise-then-optimise, which is usually the best option for most problems. + See the page on [Adjoints](./adjoints.md) for more information. - `discrete_terminating_event`: A discrete event at which to terminate the solve early. See the page on [Events](./events.md) for more information. @@ -412,14 +405,7 @@ def diffeqsolve( unconditionally. Can also be set to `None` to allow an arbitrary number of steps, although this - is incompatible with `saveat=SaveAt(steps=True)` or `saveat=SaveAt(dense=True)`, - and can only be backpropagated through if using `adjoint=BacksolveAdjoint()` or - `adjoint=ImplicitAdjoint()`. - - Note that (a) compile times; and (b) backpropagation run times; will increase - as `max_steps` increases. (Specifically, each time `max_steps` passes a power - of 16.) You can reduce these times by using the smallest value of `max_steps` - that is reasonable for your problem. + is incompatible with `saveat=SaveAt(steps=True)` or `saveat=SaveAt(dense=True)`. - `throw`: Whether to raise an exception if the integration fails for any reason. @@ -436,7 +422,7 @@ def diffeqsolve( !!! note - Note that when `jax.vmap`-ing a differential equation solve, then + When `jax.vmap`-ing a differential equation solve, then `throw=True` means that an exception will be raised if any batch element fails. You may prefer to set `throw=False` and inspect the `result` field of the returned solution object, to determine which batch elements @@ -509,18 +495,6 @@ def diffeqsolve( f"`{type(solver).__name__}` is not marked as converging to either the " "Itô or the Stratonovich solution." ) - if isinstance(adjoint, BacksolveAdjoint): - if isinstance(solver, AbstractItoSolver): - raise NotImplementedError( - f"`{solver.__name__}` converges to the Itô solution. However " - "`BacksolveAdjoint` currently only supports Stratonovich SDEs." - ) - elif not isinstance(solver, AbstractStratonovichSolver): - warnings.warn( - f"{solver.__name__} is not marked as converging to either the Itô " - "or the Stratonovich solution. Note that BacksolveAdjoint will " - "only produce the correct solution for Stratonovich SDEs." - ) if isinstance(stepsize_controller, AbstractAdaptiveStepSizeController): # Specific check to not work even if using HalfSolver(Euler()) if isinstance(solver, Euler): @@ -533,11 +507,6 @@ def diffeqsolve( raise ValueError( "`UnsafeBrownianPath` cannot be used with adaptive step sizes." ) - if not isinstance(adjoint, (NoAdjoint, ImplicitAdjoint)): - raise ValueError( - "`UnsafeBrownianPath` can only be used with `adjoint=NoAdjoint()` or " - "`adjoint=ImplicitAdjoint()`." - ) # Allow setting e.g. t0 as an int with dt0 as a float. timelikes = (jnp.array(0.0), t0, t1, dt0, saveat.ts) diff --git a/docs/api/adjoints.md b/docs/api/adjoints.md index 39cef0d4..cc04d63e 100644 --- a/docs/api/adjoints.md +++ b/docs/api/adjoints.md @@ -21,24 +21,27 @@ There are multiple ways to backpropagate through a differential equation (to com members: - loop +Of the following options, [`diffrax.RecursiveCheckpointAdjoint`][] and [`diffrax.BacksolveAdjoint`][] can only be reverse-mode autodifferentiated. [`diffrax.DirectAdjoint`][] and [`diffrax.ImplicitAdjoint`][] support both forward and reverse-mode autodifferentiation. + --- ::: diffrax.RecursiveCheckpointAdjoint selection: - members: false + members: + - __init__ -::: diffrax.NoAdjoint +::: diffrax.BacksolveAdjoint selection: - members: false + members: + - __init__ -::: diffrax.ImplicitAdjoint +::: diffrax.DirectAdjoint selection: members: false -::: diffrax.BacksolveAdjoint +::: diffrax.ImplicitAdjoint selection: - members: - - __init__ + members: false --- diff --git a/test/test_adjoint.py b/test/test_adjoint.py index a733b4da..42ec73e1 100644 --- a/test/test_adjoint.py +++ b/test/test_adjoint.py @@ -86,9 +86,7 @@ def _convert_float0(x): continue saveat = diffrax.SaveAt(t0=t0, t1=t1, ts=ts) - direct_grads = _run_grad( - diff, saveat, diffrax.adjoint.RecursiveCheckpointAdjoint2() - ) + direct_grads = _run_grad(diff, saveat, diffrax.DirectAdjoint()) recursive_grads = _run_grad( diff, saveat, diffrax.RecursiveCheckpointAdjoint() ) @@ -97,9 +95,7 @@ def _convert_float0(x): assert shaped_allclose(direct_grads, backsolve_grads, atol=1e-5) direct_grads = _run_grad_int( - y0__args__term, - saveat, - diffrax.adjoint.RecursiveCheckpointAdjoint2(), + y0__args__term, saveat, diffrax.DirectAdjoint() ) recursive_grads = _run_grad_int( y0__args__term, saveat, diffrax.RecursiveCheckpointAdjoint() From 7584141b5b397980d57620032238c50912a87311 Mon Sep 17 00:00:00 2001 From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com> Date: Fri, 27 Jan 2023 20:18:32 -0800 Subject: [PATCH 08/19] Separate inner/outer while loops --- diffrax/adjoint.py | 51 +++++++++++++++++++++++++++++++++----------- diffrax/integrate.py | 8 +++---- 2 files changed, 42 insertions(+), 17 deletions(-) diff --git a/diffrax/adjoint.py b/diffrax/adjoint.py index 19251b5f..1ec66694 100644 --- a/diffrax/adjoint.py +++ b/diffrax/adjoint.py @@ -1,6 +1,7 @@ import abc import functools as ft import warnings +from dataclasses import fields from typing import Any, Dict, Optional import equinox as eqx @@ -166,7 +167,11 @@ def loop( else: while_loop = bounded_while_loop return self._loop( - **kwargs, max_steps=max_steps, terms=terms, while_loop=while_loop + **kwargs, + max_steps=max_steps, + terms=terms, + inner_while_loop=while_loop, + outer_while_loop=while_loop, ) @@ -257,6 +262,7 @@ def loop( self, *, terms, + saveat, init_state, max_steps, throw, @@ -280,17 +286,35 @@ def loop( "`UnsafeBrownianPath`. Consider using `adjoint=DirectAdjoint()` " "instead." ) - init_state = eqx.tree_at( - lambda s: (s.ts, s.ys, s.dense_ts, s.dense_infos), - init_state, - replace_fn=eqxi.Buffer, - ) + + def inner_buffers(state): + assert type(state).__name__ == "_InnerState" + assert {f.name for f in fields(state)} == { + "ts", + "ys", + "saveat_ts_index", + "saveat_index", + } + return state.ts, state.ys + + def outer_buffers(state): + assert type(state).__name__ == "_State" + return state.ts, state.ys, state.dense_ts, state.dense_infos + return self._loop( terms=terms, + saveat=saveat, init_state=init_state, max_steps=max_steps, - while_loop=ft.partial( - eqxi.checkpointed_while_loop, checkpoints=self.checkpoints + inner_while_loop=ft.partial( + eqxi.checkpointed_while_loop, + checkpoints=(len(saveat.ts),), + buffers=inner_buffers, + ), + outer_while_loop=ft.partial( + eqxi.checkpointed_while_loop, + checkpoints=self.checkpoints, + buffers=outer_buffers, ), **kwargs, ) @@ -322,14 +346,15 @@ def _vf(ys, residual, args__terms, closure): def _solve(args__terms, closure): args, terms = args__terms self, kwargs, solver, saveat, init_state = closure - final_state, aux_stats = self._loop_fn( + final_state, aux_stats = self._loop( **kwargs, args=args, terms=terms, solver=solver, saveat=saveat, init_state=init_state, - while_loop=_while_loop, + inner_while_loop=_while_loop, + outer_while_loop=_while_loop, ) # Note that we use .ys not .y here. The former is what is actually returned # by diffeqsolve, so it is the thing we want to attach the tangent to. @@ -411,12 +436,12 @@ def _loop_backsolve(y__args__terms, *, self, throw, init_state, **kwargs): lambda s: jtu.tree_leaves(s.y), init_state, jtu.tree_leaves(y) ) del y - return self._loop_fn( + return self._loop( args=args, terms=terms, init_state=init_state, - while_loop=_while_loop, - **kwargs, + inner_while_loop=_while_loop, + outer_while_loop=_while_loop**kwargs, ) diff --git a/diffrax/integrate.py b/diffrax/integrate.py index cb43a90e..9637f733 100644 --- a/diffrax/integrate.py +++ b/diffrax/integrate.py @@ -9,7 +9,6 @@ import jax.tree_util as jtu from .adjoint import AbstractAdjoint, RecursiveCheckpointAdjoint -from .bounded_while_loop import bounded_while_loop from .custom_types import Array, Bool, Int, PyTree, Scalar from .event import AbstractDiscreteTerminatingEvent from .global_interpolation import DenseInterpolation @@ -99,7 +98,8 @@ def loop( terms, args, init_state, - while_loop, + inner_while_loop, + outer_while_loop, ): if saveat.t0: @@ -252,7 +252,7 @@ def _body_fun(_state): saveat_ts_index=saveat_ts_index, ts=ts, ys=ys, save_index=save_index ) - final_inner_state = bounded_while_loop( + final_inner_state = inner_while_loop( _cond_fun, _body_fun, init_inner_state, max_steps=len(saveat.ts) ) @@ -321,7 +321,7 @@ def maybe_inplace(i, x, u): return new_state - final_state = while_loop(cond_fun, body_fun, init_state, max_steps) + final_state = outer_while_loop(cond_fun, body_fun, init_state, max_steps) if saveat.t1 and not saveat.steps: # if saveat.steps then the final value is already saved. From 598e3bba9fba07e8bf640022ebdd06e396f0700c Mon Sep 17 00:00:00 2001 From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com> Date: Mon, 30 Jan 2023 16:51:10 -0800 Subject: [PATCH 09/19] Breaking JAX --- diffrax/__init__.py | 2 +- diffrax/bounded_while_loop.py | 139 ++++++++- test/test_bounded_while_loop.py | 489 ++++++++++++++++++++++++++++++-- 3 files changed, 589 insertions(+), 41 deletions(-) diff --git a/diffrax/__init__.py b/diffrax/__init__.py index ff90008b..75a5d268 100644 --- a/diffrax/__init__.py +++ b/diffrax/__init__.py @@ -1,8 +1,8 @@ from .adjoint import ( AbstractAdjoint, BacksolveAdjoint, + DirectAdjoint, ImplicitAdjoint, - NoAdjoint, RecursiveCheckpointAdjoint, ) from .autocitation import citation, citation_rules diff --git a/diffrax/bounded_while_loop.py b/diffrax/bounded_while_loop.py index 59d5e0ac..5378c9ad 100644 --- a/diffrax/bounded_while_loop.py +++ b/diffrax/bounded_while_loop.py @@ -1,6 +1,8 @@ import functools as ft import math +from typing import Any, Callable, Optional, Union +import equinox as eqx import equinox.internal as eqxi import jax import jax.lax as lax @@ -8,23 +10,44 @@ import jax.tree_util as jtu -def bounded_while_loop(cond_fun, body_fun, init_val, max_steps, base=16): +def bounded_while_loop( + cond_fun, + body_fun, + init_val, + max_steps: Optional[int], + *, + buffers: Optional[Callable] = None, + base: int = 16 +): """Reverse-mode autodifferentiable while loop. - Mostly as `lax.while_loop`, with a few small changes. + This only exists to support a few edge cases: + - forward-mode autodiff; + - reading from `buffers`. + You should almost always prefer to use `equinox.internal.checkpointed_while_loop` + instead. - Arguments: - cond_fun: function `a -> bool` - body_fun: function `a -> a`. - init_val: pytree of type `a`. - max_steps: integer or `None`. - base: integer. + Once 'bloops' land in JAX core then this function will be removed. + + **Arguments:** + + - cond_fun: function `a -> bool`. + - body_fun: function `a -> a`. + - init_val: pytree of type `a`. + - max_steps: integer or `None`. + - buffers: function `a -> node or nodes`. + - base: integer. Note the extra `max_steps` argument. If this is `None` then `bounded_while_loop` will fall back to `lax.while_loop` (which is not reverse-mode autodifferentiable). If it is a non-negative integer then this is the maximum number of steps which may be taken in the loop, after which the loop will exit unconditionally. + Note the extra `buffers` argument. This behaves similarly to the same argument for + `equinox.internal.checkpointed_while_loop`: these support efficient in-place updates + but no operation. (Unlike `checkpointed_while_loop`, however, this supports being + read from.) + Note the extra `base` argument. - Run time will increase slightly as `base` increases. - Compilation time will decrease substantially as @@ -47,21 +70,53 @@ def _cond_fun(val, step): init_data = (cond_fun(init_val), init_val, 0) rounded_max_steps = base ** int(math.ceil(math.log(max_steps, base))) - _, val, _ = _while_loop(_cond_fun, body_fun, init_data, rounded_max_steps, base) + if buffers is None: + buffers = lambda _: () + _, val, _ = _while_loop( + _cond_fun, body_fun, init_data, rounded_max_steps, buffers, base + ) return val -def _while_loop(cond_fun, body_fun, data, max_steps, base): +def _while_loop(cond_fun, body_fun, data, max_steps, buffers, base): if max_steps == 1: pred, val, step = data + + tag = object() + + def _buffers(v): + nodes = buffers(v) + tree = jtu.tree_map(_unwrap_buffers, nodes, is_leaf=_is_buffer) + return jtu.tree_leaves(tree) + + val = eqx.tree_at( + _buffers, val, replace_fn=ft.partial(_Buffer, _pred=pred, _tag=tag) + ) new_val = body_fun(val) - new_val = jtu.tree_map(ft.partial(lax.select, pred), new_val, val) + if jax.eval_shape(lambda: val) != jax.eval_shape(lambda: new_val): + raise ValueError("body_fun must have matching input and output structures") + + def _is_our_buffer(x): + return isinstance(x, _Buffer) and x._tag is tag + + def _unwrap_or_select(new_v, v): + if _is_our_buffer(new_v): + assert _is_our_buffer(v) + assert eqx.is_array(new_v._array) + assert eqx.is_array(v._array) + return new_v._array + else: + return lax.select(pred, new_v, v) + + new_val = jtu.tree_map(_unwrap_or_select, new_val, val, is_leaf=_is_our_buffer) new_step = step + 1 return cond_fun(new_val, new_step), new_val, new_step else: def _call(_data): - return _while_loop(cond_fun, body_fun, _data, max_steps // base, base) + return _while_loop( + cond_fun, body_fun, _data, max_steps // base, buffers, base + ) def _scan_fn(_data, _): _pred, _, _ = _data @@ -73,3 +128,63 @@ def _scan_fn(_data, _): _scan_fn = jax.checkpoint(_scan_fn, prevent_cse=False) return lax.scan(_scan_fn, data, xs=None, length=base)[0] + + +def _is_buffer(x): + return isinstance(x, _Buffer) + + +def _unwrap_buffers(x): + while _is_buffer(x): + x = x._array + return x + + +class _Buffer(eqx.Module): + _array: Union[jnp.ndarray, "_Buffer"] + _pred: jnp.ndarray + _tag: object = eqx.static_field() + + def __getitem__(self, item): + return self._array[item] + + def _set(self, pred, item, x): + pred = pred & self._pred + if isinstance(self._array, _Buffer): + array = self._array._set(pred, item, x) + else: + old_x = self._array[item] + x = jnp.where(pred, x, old_x) + array = self._array.at[item].set(x) + return _Buffer(array, self._pred, self._tag) + + @property + def at(self): + return _BufferAt(self) + + @property + def shape(self): + return self._array.shape + + @property + def dtype(self): + return self._array.dtype + + @property + def size(self): + return self._array.size + + +class _BufferAt(eqx.Module): + _buffer: _Buffer + + def __getitem__(self, item): + return _BufferItem(self._buffer, item) + + +class _BufferItem(eqx.Module): + _buffer: _Buffer + _item: Any + + def set(self, x): + return self._buffer._set(True, self._item, x) diff --git a/test/test_bounded_while_loop.py b/test/test_bounded_while_loop.py index 03a504dc..4b32ad80 100644 --- a/test/test_bounded_while_loop.py +++ b/test/test_bounded_while_loop.py @@ -1,10 +1,14 @@ -# TODO: -# - Test forward times -# - Test grad time -# - Test compile time +import functools as ft +import timeit +from typing import Optional +import equinox as eqx import jax +import jax.lax as lax import jax.numpy as jnp +import jax.random as jr +import jax.tree_util as jtu +import pytest from diffrax.bounded_while_loop import bounded_while_loop from .helpers import shaped_allclose @@ -51,24 +55,30 @@ def body_fun(val): step = step.at[()].set(step + 1) return x, step + def buffers(val): + x, step = val + return x + init_val = (jnp.array([0.3, 0.3, 0.3, 0.3, 0.3]), 0) - val = bounded_while_loop(cond_fun, body_fun, init_val, max_steps=0) + val = bounded_while_loop(cond_fun, body_fun, init_val, max_steps=0, buffers=buffers) assert shaped_allclose(val[0], jnp.array([0.3, 0.3, 0.3, 0.3, 0.3])) and val[1] == 0 - val = bounded_while_loop(cond_fun, body_fun, init_val, max_steps=1) + val = bounded_while_loop(cond_fun, body_fun, init_val, max_steps=1, buffers=buffers) assert shaped_allclose(val[0], jnp.array([0.3, 0.4, 0.3, 0.3, 0.3])) and val[1] == 1 - val = bounded_while_loop(cond_fun, body_fun, init_val, max_steps=2) + val = bounded_while_loop(cond_fun, body_fun, init_val, max_steps=2, buffers=buffers) assert shaped_allclose(val[0], jnp.array([0.3, 0.4, 0.5, 0.3, 0.3])) and val[1] == 2 - val = bounded_while_loop(cond_fun, body_fun, init_val, max_steps=4) + val = bounded_while_loop(cond_fun, body_fun, init_val, max_steps=4, buffers=buffers) assert shaped_allclose(val[0], jnp.array([0.3, 0.4, 0.5, 0.6, 0.7])) and val[1] == 4 - val = bounded_while_loop(cond_fun, body_fun, init_val, max_steps=8) + val = bounded_while_loop(cond_fun, body_fun, init_val, max_steps=8, buffers=buffers) assert shaped_allclose(val[0], jnp.array([0.3, 0.4, 0.5, 0.6, 0.8])) and val[1] == 5 - val = bounded_while_loop(cond_fun, body_fun, init_val, max_steps=None) + val = bounded_while_loop( + cond_fun, body_fun, init_val, max_steps=None, buffers=buffers + ) assert shaped_allclose(val[0], jnp.array([0.3, 0.4, 0.5, 0.6, 0.8])) and val[1] == 5 @@ -137,50 +147,473 @@ def body_fun(val): step = step.at[()].set(step + 1) return x, step, max_step + def buffers(val): + x, step, max_step = val + return x + init_val = ( jnp.array([[0.3, 0.3, 0.3, 0.3, 0.3], [0.4, 0.4, 0.4, 0.4, 0.4]]), jnp.array([0, 1]), jnp.array([5, 3]), ) - val = jax.vmap(lambda v: bounded_while_loop(cond_fun, body_fun, v, max_steps=0))( - init_val - ) + val = jax.vmap( + lambda v: bounded_while_loop( + cond_fun, body_fun, v, max_steps=0, buffers=buffers + ) + )(init_val) assert shaped_allclose( val[0], jnp.array([[0.3, 0.3, 0.3, 0.3, 0.3], [0.4, 0.4, 0.4, 0.4, 0.4]]) ) and jnp.array_equal(val[1], jnp.array([0, 1])) - val = jax.vmap(lambda v: bounded_while_loop(cond_fun, body_fun, v, max_steps=1))( - init_val - ) + val = jax.vmap( + lambda v: bounded_while_loop( + cond_fun, body_fun, v, max_steps=1, buffers=buffers + ) + )(init_val) assert shaped_allclose( val[0], jnp.array([[0.3, 0.4, 0.3, 0.3, 0.3], [0.4, 0.4, 0.5, 0.4, 0.4]]) ) and jnp.array_equal(val[1], jnp.array([1, 2])) - val = jax.vmap(lambda v: bounded_while_loop(cond_fun, body_fun, v, max_steps=2))( - init_val - ) + val = jax.vmap( + lambda v: bounded_while_loop( + cond_fun, body_fun, v, max_steps=2, buffers=buffers + ) + )(init_val) assert shaped_allclose( val[0], jnp.array([[0.3, 0.4, 0.5, 0.3, 0.3], [0.4, 0.4, 0.5, 0.6, 0.4]]) ) and jnp.array_equal(val[1], jnp.array([2, 3])) - val = jax.vmap(lambda v: bounded_while_loop(cond_fun, body_fun, v, max_steps=4))( - init_val - ) + val = jax.vmap( + lambda v: bounded_while_loop( + cond_fun, body_fun, v, max_steps=4, buffers=buffers + ) + )(init_val) assert shaped_allclose( val[0], jnp.array([[0.3, 0.4, 0.5, 0.6, 0.7], [0.4, 0.4, 0.5, 0.6, 0.4]]) ) and jnp.array_equal(val[1], jnp.array([4, 3])) - val = jax.vmap(lambda v: bounded_while_loop(cond_fun, body_fun, v, max_steps=8))( - init_val - ) + val = jax.vmap( + lambda v: bounded_while_loop( + cond_fun, body_fun, v, max_steps=8, buffers=buffers + ) + )(init_val) assert shaped_allclose( val[0], jnp.array([[0.3, 0.4, 0.5, 0.6, 0.8], [0.4, 0.4, 0.5, 0.6, 0.4]]) ) and jnp.array_equal(val[1], jnp.array([5, 3])) - val = jax.vmap(lambda v: bounded_while_loop(cond_fun, body_fun, v, max_steps=None))( - init_val - ) + val = jax.vmap( + lambda v: bounded_while_loop( + cond_fun, body_fun, v, max_steps=None, buffers=buffers + ) + )(init_val) assert shaped_allclose( val[0], jnp.array([[0.3, 0.4, 0.5, 0.6, 0.8], [0.4, 0.4, 0.5, 0.6, 0.4]]) ) and jnp.array_equal(val[1], jnp.array([5, 3])) + + +# +# Remaining tests copied from Equinox's tests for `checkpointed_while_loop`. +# + + +def _get_problem(key, *, num_steps: Optional[int]): + valkey1, valkey2, modelkey = jr.split(key, 3) + + def cond_fun(carry): + if num_steps is None: + return True + else: + step, _, _ = carry + return step < num_steps + + def make_body_fun(dynamic_mlp): + mlp = eqx.combine(dynamic_mlp, static_mlp) + + def body_fun(carry): + # A simple new_val = mlp(val) tends to converge to a fixed point in just a + # few iterations, which implies zero gradient... which doesn't make for a + # test that actually tests anything. Making things rotational like this + # keeps things more interesting. + step, val1, val2 = carry + (theta,) = mlp(val1) + real, imag = val1 + z = real + imag * 1j + z = z * jnp.exp(1j * theta) + real = jnp.real(z) + imag = jnp.imag(z) + val1 = jnp.stack([real, imag]) + val2 = val2.at[step % 8].set(real) + return step + 1, val1, val2 + + return body_fun + + init_val1 = jr.normal(valkey1, (2,)) + init_val2 = jr.normal(valkey2, (20,)) + mlp = eqx.nn.MLP(2, 1, 2, 2, key=modelkey) + dynamic_mlp, static_mlp = eqx.partition(mlp, eqx.is_array) + + return cond_fun, make_body_fun, init_val1, init_val2, dynamic_mlp + + +def _while_as_scan(cond, body, init_val, max_steps): + def f(val, _): + val2 = lax.cond(cond(val), body, lambda x: x, val) + return val2, None + + final_val, _ = lax.scan(f, init_val, xs=None, length=max_steps) + return final_val + + +@pytest.mark.parametrize("buffer", (False, True)) +def test_forward(buffer, getkey): + cond_fun, make_body_fun, init_val1, init_val2, mlp = _get_problem( + getkey(), num_steps=5 + ) + body_fun = make_body_fun(mlp) + true_final_carry = lax.while_loop(cond_fun, body_fun, (0, init_val1, init_val2)) + if buffer: + buffer_fn = lambda i: i[2] + else: + buffer_fn = None + final_carry = bounded_while_loop( + cond_fun, + body_fun, + (0, init_val1, init_val2), + max_steps=16, + buffers=buffer_fn, + ) + assert shaped_allclose(final_carry, true_final_carry) + + +@pytest.mark.parametrize("buffer", (False, True)) +def test_backward(buffer, getkey): + cond_fun, make_body_fun, init_val1, init_val2, mlp = _get_problem( + getkey(), num_steps=None + ) + + @jax.jit + @jax.value_and_grad + def true_run(arg): + init_val1, init_val2, mlp = arg + body_fun = make_body_fun(mlp) + _, true_final_val1, true_final_val2 = _while_as_scan( + cond_fun, body_fun, (0, init_val1, init_val2), max_steps=14 + ) + return jnp.sum(true_final_val1) + jnp.sum(true_final_val2) + + @jax.jit + @jax.value_and_grad + def run(arg): + init_val1, init_val2, mlp = arg + if buffer: + buffer_fn = lambda i: i[2] + else: + buffer_fn = None + body_fun = make_body_fun(mlp) + _, final_val1, final_val2 = bounded_while_loop( + cond_fun, + body_fun, + (0, init_val1, init_val2), + max_steps=14, + buffers=buffer_fn, + ) + return jnp.sum(final_val1) + jnp.sum(final_val2) + + true_value, true_grad = true_run((init_val1, init_val2, mlp)) + value, grad = run((init_val1, init_val2, mlp)) + assert shaped_allclose(value, true_value) + assert shaped_allclose(grad, true_grad, rtol=1e-4, atol=1e-4) + + +@pytest.mark.parametrize("buffer", (False, True)) +def test_vmap_primal_unbatched_cond(buffer, getkey): + cond_fun, make_body_fun, init_val1, init_val2, mlp = _get_problem( + getkey(), num_steps=14 + ) + + @jax.jit + @ft.partial(jax.vmap, in_axes=((0, 0, None),)) + @jax.value_and_grad + def true_run(arg): + init_val1, init_val2, mlp = arg + body_fun = make_body_fun(mlp) + _, true_final_val1, true_final_val2 = _while_as_scan( + cond_fun, body_fun, (0, init_val1, init_val2), max_steps=14 + ) + return jnp.sum(true_final_val1) + jnp.sum(true_final_val2) + + @jax.jit + @ft.partial(jax.vmap, in_axes=((0, 0, None),)) + @jax.value_and_grad + def run(arg): + init_val1, init_val2, mlp = arg + if buffer: + buffer_fn = lambda i: i[2] + else: + buffer_fn = None + body_fun = make_body_fun(mlp) + _, final_val1, final_val2 = bounded_while_loop( + cond_fun, + body_fun, + (0, init_val1, init_val2), + max_steps=16, + buffers=buffer_fn, + ) + return jnp.sum(final_val1) + jnp.sum(final_val2) + + init_val1, init_val2 = jtu.tree_map( + lambda x: jr.normal(getkey(), (3,) + x.shape, x.dtype), (init_val1, init_val2) + ) + true_value, true_grad = true_run((init_val1, init_val2, mlp)) + value, grad = run((init_val1, init_val2, mlp)) + assert shaped_allclose(value, true_value) + assert shaped_allclose(grad, true_grad) + + +@pytest.mark.parametrize("buffer", (False, True)) +def test_vmap_primal_batched_cond(buffer, getkey): + cond_fun, make_body_fun, init_val1, init_val2, mlp = _get_problem( + getkey(), num_steps=14 + ) + + @jax.jit + @ft.partial(jax.vmap, in_axes=((0, 0, None), 0)) + @jax.value_and_grad + def true_run(arg, init_step): + init_val1, init_val2, mlp = arg + body_fun = make_body_fun(mlp) + _, true_final_val1, true_final_val2 = _while_as_scan( + cond_fun, body_fun, (init_step, init_val1, init_val2), max_steps=14 + ) + return jnp.sum(true_final_val1) + jnp.sum(true_final_val2) + + @jax.jit + @ft.partial(jax.vmap, in_axes=((0, 0, None), 0)) + @jax.value_and_grad + def run(arg, init_step): + init_val1, init_val2, mlp = arg + if buffer: + buffer_fn = lambda i: i[2] + else: + buffer_fn = None + body_fun = make_body_fun(mlp) + _, final_val1, final_val2 = bounded_while_loop( + cond_fun, + body_fun, + (init_step, init_val1, init_val2), + max_steps=16, + buffers=buffer_fn, + ) + return jnp.sum(final_val1) + jnp.sum(final_val2) + + init_step = jnp.array([0, 1, 2, 3, 5, 10]) + init_val1, init_val2 = jtu.tree_map( + lambda x: jr.normal(getkey(), (6,) + x.shape, x.dtype), (init_val1, init_val2) + ) + true_value, true_grad = true_run((init_val1, init_val2, mlp), init_step) + value, grad = run((init_val1, init_val2, mlp), init_step) + assert shaped_allclose(value, true_value, rtol=1e-4, atol=1e-4) + assert shaped_allclose(grad, true_grad, rtol=1e-4, atol=1e-4) + + +@pytest.mark.parametrize("buffer", (False, True)) +def test_vmap_cotangent(buffer, getkey): + cond_fun, make_body_fun, init_val1, init_val2, mlp = _get_problem( + getkey(), num_steps=14 + ) + + @jax.jit + @jax.jacrev + def true_run(arg): + init_val1, init_val2, mlp = arg + body_fun = make_body_fun(mlp) + _, true_final_val1, true_final_val2 = _while_as_scan( + cond_fun, body_fun, (0, init_val1, init_val2), max_steps=14 + ) + return true_final_val1, true_final_val2 + + @jax.jit + @jax.jacrev + def run(arg): + init_val1, init_val2, mlp = arg + if buffer: + buffer_fn = lambda i: i[2] + else: + buffer_fn = None + body_fun = make_body_fun(mlp) + _, final_val1, final_val2 = bounded_while_loop( + cond_fun, + body_fun, + (0, init_val1, init_val2), + max_steps=16, + buffers=buffer_fn, + ) + return final_val1, final_val2 + + true_jac = true_run((init_val1, init_val2, mlp)) + jac = run((init_val1, init_val2, mlp)) + assert shaped_allclose(jac, true_jac, rtol=1e-4, atol=1e-4) + + +# This tests the possible failure mode of "the buffer doesn't do anything". +# This test takes O(1e-3) seconds with buffer. +# This test takes O(10) seconds without buffer. +# This speed improvement is precisely the reason that buffer exists. +def test_speed_buffer_while(): + size = 16**4 + + @jax.jit + @jax.vmap + def f(init_step, init_xs): + def cond(carry): + step, xs = carry + return step < size + + def body(carry): + step, xs = carry + xs = xs.at[step].set(1) + return step + 1, xs + + def loop(init_xs): + return bounded_while_loop( + cond, + body, + (init_step, init_xs), + max_steps=size, + buffers=lambda i: i[1], + ) + + # Linearize so that we save residuals + return jax.linearize(loop, init_xs) + + # nontrivial batch size is important to ensure that the `.at[].set()` is really a + # scatter, and that XLA doesn't optimise it into a dynamic_update_slice. (Which + # can be switched with `select` in the compiler.) + args = jnp.array([0, 1]), jnp.zeros((2, size)) + f(*args) # compile + + speed = timeit.timeit(lambda: f(*args), number=1) + assert speed < 0.1 + + +# This isn't testing any particular failure mode: just that things generally work. +def test_speed_grad_checkpointed_while(getkey): + mlp = eqx.nn.MLP(2, 1, 2, 2, key=getkey()) + + @jax.jit + @jax.vmap + @jax.grad + def f(init_val, init_step): + def cond(carry): + step, _ = carry + return step < 8 * 16**3 + + def body(carry): + step, val = carry + (theta,) = mlp(val) + real, imag = val + z = real + imag * 1j + z = z * jnp.exp(1j * theta) + real = jnp.real(z) + imag = jnp.imag(z) + return step + 1, jnp.stack([real, imag]) + + _, final_xs = bounded_while_loop( + cond, + body, + (init_step, init_val), + max_steps=16**3, + ) + return jnp.sum(final_xs) + + init_step = jnp.array([0, 10]) + init_val = jr.normal(getkey(), (2, 2)) + + f(init_val, init_step) # compile + speed = timeit.timeit(lambda: f(init_val, init_step), number=1) + # Should take ~0.001 seconds + assert speed < 0.01 + + +# This is deliberately meant to emulate the pattern of saving used in +# `diffrax.diffeqsolve(..., saveat=SaveAt(ts=...))`. +def test_nested_loops(getkey): + @ft.partial(jax.jit, static_argnums=5) + @ft.partial(jax.vmap, in_axes=(0, 0, 0, 0, 0, None)) + def run(step, vals, ts, final_step, cotangents, true): + value, vjp_fn = jax.vjp( + lambda *v: outer_loop(step, v, ts, true, final_step), *vals + ) + cotangents = vjp_fn(cotangents) + return value, cotangents + + def outer_loop(step, vals, ts, true, final_step): + def cond(carry): + step, _ = carry + return step < final_step + + def body(carry): + step, (val1, val2, val3, val4) = carry + mul = 1 + 0.05 * jnp.sin(105 * val1 + 1) + val1 = val1 * mul + return inner_loop(step, (val1, val2, val3, val4), ts, true) + + def buffers(carry): + _, (_, val2, val3, _) = carry + return val2, val3 + + if true: + while_loop = ft.partial(_while_as_scan, max_steps=50) + else: + while_loop = ft.partial(bounded_while_loop, max_steps=50, buffers=buffers) + _, out = while_loop(cond, body, (step, vals)) + return out + + def inner_loop(step, vals, ts, true): + ts_done = jnp.floor(ts[step] + 1) + + def cond(carry): + step, _ = carry + return ts[step] < ts_done + + def body(carry): + step, (val1, val2, val3, val4) = carry + mul = 1 + 0.05 * jnp.sin(100 * val1 + 3) + val1 = val1 * mul + val2 = val2.at[step].set(val1) + val3 = val3.at[step].set(val1) + val4 = val4.at[step].set(val1) + return step + 1, (val1, val2, val3, val4) + + def buffers(carry): + _, (_, _, val3, val4) = carry + return val3, val4 + + if true: + while_loop = ft.partial(_while_as_scan, max_steps=10) + else: + while_loop = ft.partial(bounded_while_loop, max_steps=10, buffers=buffers) + return while_loop(cond, body, (step, vals)) + + step = jnp.array([0, 5]) + val1 = jr.uniform(getkey(), shape=(2,), minval=0.1, maxval=0.7) + val2 = val3 = val4 = jnp.zeros((2, 47)) + ts = jnp.stack([jnp.linspace(0, 19, 47), jnp.linspace(0, 13, 47)]) + final_step = jnp.array([46, 43]) + cotangents = ( + jr.normal(getkey(), (2,)), + jr.normal(getkey(), (2, 47)), + jr.normal(getkey(), (2, 47)), + jr.normal(getkey(), (2, 47)), + ) + + value, grads = run( + step, (val1, val2, val3, val4), ts, final_step, cotangents, False + ) + true_value, true_grads = run( + step, (val1, val2, val3, val4), ts, final_step, cotangents, True + ) + + assert shaped_allclose(value, true_value) + assert shaped_allclose(grads, true_grads, rtol=1e-4, atol=1e-5) From 9a993138e10f8cbb018f48e19c1dca114f220183 Mon Sep 17 00:00:00 2001 From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com> Date: Tue, 14 Feb 2023 00:42:38 -0800 Subject: [PATCH 10/19] Moved bounded_while_loop into Equinox --- .github/workflows/run_tests.yml | 2 +- diffrax/adjoint.py | 309 ++++++++-------- diffrax/bounded_while_loop.py | 190 ---------- diffrax/integrate.py | 19 +- docs/api/adjoints.md | 6 +- test/test_bounded_while_loop.py | 619 -------------------------------- test/test_brownian.py | 13 +- 7 files changed, 184 insertions(+), 974 deletions(-) delete mode 100644 diffrax/bounded_while_loop.py delete mode 100644 test/test_bounded_while_loop.py diff --git a/.github/workflows/run_tests.yml b/.github/workflows/run_tests.yml index 486c8b75..935e17f6 100644 --- a/.github/workflows/run_tests.yml +++ b/.github/workflows/run_tests.yml @@ -7,7 +7,7 @@ jobs: run-tests: strategy: matrix: - python-version: [ 3.7, 3.8, 3.9 ] + python-version: [ 3.8, 3.9 ] os: [ ubuntu-latest ] fail-fast: false runs-on: ${{ matrix.os }} diff --git a/diffrax/adjoint.py b/diffrax/adjoint.py index 1ec66694..40ec4602 100644 --- a/diffrax/adjoint.py +++ b/diffrax/adjoint.py @@ -12,8 +12,7 @@ from equinox.internal import ω from .ad import implicit_jvp -from .bounded_while_loop import bounded_while_loop -from .heuristics import is_unsafe_sde +from .heuristics import is_sde, is_unsafe_sde from .saveat import SaveAt from .solver import AbstractItoSolver, AbstractStratonovichSolver from .term import AbstractTerm, AdjointTerm @@ -23,69 +22,31 @@ def _is_none(x): return x is None -def _no_transpose_final_state(final_state): - y = eqxi.nondifferentiable_backward(final_state.y, name="y") - tprev = eqxi.nondifferentiable_backward(final_state.tprev, name="tprev") - tnext = eqxi.nondifferentiable_backward(final_state.tnext, name="tnext") - solver_state = eqxi.nondifferentiable_backward( - final_state.solver_state, name="solver_state" - ) - controller_state = eqxi.nondifferentiable_backward( - final_state.controller_state, name="controller_state" - ) - ts = eqxi.nondifferentiable_backward(final_state.ts, name="ts") - ys = final_state.ys - dense_ts = eqxi.nondifferentiable_backward(final_state.dense_ts, name="dense_ts") - dense_infos = eqxi.nondifferentiable_backward( - final_state.dense_infos, name="dense_infos" - ) - final_state = eqxi.nondifferentiable_backward(final_state) # no more specific name - final_state = eqx.tree_at( - lambda s: ( - s.y, - s.tprev, - s.tnext, - s.solver_state, - s.controller_state, - s.ts, - s.ys, - s.dense_ts, - s.dense_infos, - ), - final_state, - ( - y, - tprev, - tnext, - solver_state, - controller_state, - ts, - ys, - dense_ts, - dense_infos, - ), - is_leaf=_is_none, +def _only_transpose_ys(final_state): + entries = ( + "y", + "tprev", + "tnext", + "solver_state", + "controller_state", + "ts", + "dense_ts", + "dense_infos", ) + values = { + k: eqxi.nondifferentiable_backward( + getattr(final_state, k), name=k, symbolic=False + ) + for k in entries + } + values["ys"] = final_state.ys + final_state = eqxi.nondifferentiable_backward(final_state, symbolic=False) + get = lambda s: tuple(getattr(s, k) for k in entries + ("ys",)) + replace = tuple(values[k] for k in entries + ("ys",)) + final_state = eqx.tree_at(get, final_state, replace, is_leaf=_is_none) return final_state -def _while_loop(cond_fun, body_fun, init_val, max_steps): - if max_steps is None: - return lax.while_loop(cond_fun, body_fun, init_val) - else: - - def _cond_fun(carry): - step, val = carry - return (step < max_steps) & cond_fun(val) - - def _body_fun(carry): - step, val = carry - return step + 1, body_fun(val) - - _, final_val = lax.while_loop(_cond_fun, _body_fun, (0, init_val)) - return final_val - - class AbstractAdjoint(eqx.Module): """Abstract base class for all adjoint methods.""" @@ -138,41 +99,28 @@ def _diffeqsolve(self): return diffeqsolve -class DirectAdjoint(AbstractAdjoint): - """A variant of [`diffrax.RecursiveCheckpointAdjoint`][]. The differences are that - `DirectAdjoint`: +def _inner_buffers(state): + assert type(state).__name__ == "_InnerState" + assert {f.name for f in fields(state)} == { + "ts", + "ys", + "saveat_ts_index", + "save_index", + } + return state.ts, state.ys - - Is less time+memory efficient at reverse-mode autodifferentiation (specifically, - these will increase every time `max_steps` increases passes a power of 16); - - Cannot be reverse-mode autodifferentated if `max_steps is None`; - - Supports forward-mode autodifferentiation. - So unless you need forward-mode autodifferentiation then - [`diffrax.RecursiveCheckpointAdjoint`][] should be preferred. - """ +def _outer_buffers(state): + assert type(state).__name__ == "_State" + return state.ts, state.ys, state.dense_ts, state.dense_infos - def loop( - self, - *, - max_steps, - terms, - throw, - passed_solver_state, - passed_controller_state, - **kwargs, - ): - del throw, passed_solver_state, passed_controller_state - if is_unsafe_sde(terms) or max_steps is None: - while_loop = _while_loop - else: - while_loop = bounded_while_loop - return self._loop( - **kwargs, - max_steps=max_steps, - terms=terms, - inner_while_loop=while_loop, - outer_while_loop=while_loop, - ) + +_inner_loop = ft.partial(eqxi.while_loop, buffers=_inner_buffers) +_outer_loop = ft.partial(eqxi.while_loop, buffers=_outer_buffers) + + +def _uncallable(*args, **kwargs): + assert False class RecursiveCheckpointAdjoint(AbstractAdjoint): @@ -271,53 +219,50 @@ def loop( **kwargs, ): del throw, passed_solver_state, passed_controller_state - if self.checkpoints is None and max_steps is None: - # Raise a more informative error than `checkpointed_while_loop` would. - raise ValueError( - "Cannot use " - "`diffeqsolve(..., max_steps=None, adjoint=RecursiveCheckpointAdjoint(checkpoints=None))` " # noqa: E501 - "Either specify the number of `checkpoints` to use, or specify the " - "maximum number of steps (and `checkpoints` is then chosen " - "automatically as `log(max_steps)`)." - ) if is_unsafe_sde(terms): raise ValueError( "`adjoint=RecursiveCheckpointAdjoint()` does not support " "`UnsafeBrownianPath`. Consider using `adjoint=DirectAdjoint()` " "instead." ) - - def inner_buffers(state): - assert type(state).__name__ == "_InnerState" - assert {f.name for f in fields(state)} == { - "ts", - "ys", - "saveat_ts_index", - "saveat_index", - } - return state.ts, state.ys - - def outer_buffers(state): - assert type(state).__name__ == "_State" - return state.ts, state.ys, state.dense_ts, state.dense_infos - - return self._loop( + if self.checkpoints is None and max_steps is None: + if saveat.ts is None: + inner_while_loop = _uncallable + else: + inner_while_loop = ft.partial(_inner_loop, kind="lax") + outer_while_loop = ft.partial(_outer_loop, kind="lax") + msg = ( + "Cannot reverse-mode autodifferentiate when using " + "`diffeqsolve(..., max_steps=None, adjoint=RecursiveCheckpointAdjoint(checkpoints=None))`. " # noqa: E501 + "This is because JAX needs to know how much memory to allocate for " + "saving the forward pass. You should either put a bound on the maximum " + "number of steps, or explicitly specify how many checkpoints to use." + ) + else: + if saveat.ts is None: + inner_while_loop = _uncallable + else: + inner_while_loop = ft.partial( + _inner_loop, kind="checkpointed", checkpoints=len(saveat.ts) + ) + outer_while_loop = ft.partial( + _outer_loop, kind="checkpointed", checkpoints=self.checkpoints + ) + msg = None + final_state = self._loop( terms=terms, saveat=saveat, init_state=init_state, max_steps=max_steps, - inner_while_loop=ft.partial( - eqxi.checkpointed_while_loop, - checkpoints=(len(saveat.ts),), - buffers=inner_buffers, - ), - outer_while_loop=ft.partial( - eqxi.checkpointed_while_loop, - checkpoints=self.checkpoints, - buffers=outer_buffers, - ), + inner_while_loop=inner_while_loop, + outer_while_loop=outer_while_loop, **kwargs, ) + if msg is not None: + final_state = eqxi.nondifferentiable_backward( + final_state, msg=msg, symbolic=True + ) + return final_state RecursiveCheckpointAdjoint.__init__.__doc__ = """ @@ -330,9 +275,77 @@ def outer_buffers(state): This value can also be set to `None` (the default), in which case it will be set to `log(max_steps)`, for which a theoretical result is available guaranteeing that backpropagation will take `O(n log n)` time in the number of steps `n <= max_steps`. + +You must pass either `diffeqsolve(..., max_steps=...)` or +`RecursiveCheckpointAdjoint(checkpoints=...)` to be able to backpropagate; otherwise +the computation will not be autodifferentiable. """ +class DirectAdjoint(AbstractAdjoint): + """A variant of [`diffrax.RecursiveCheckpointAdjoint`][]. The differences are that + `DirectAdjoint`: + + - Is less time+memory efficient at reverse-mode autodifferentiation (specifically, + these will increase every time `max_steps` increases passes a power of 16); + - Cannot be reverse-mode autodifferentated if `max_steps is None`; + - Supports forward-mode autodifferentiation. + + So unless you need forward-mode autodifferentiation then + [`diffrax.RecursiveCheckpointAdjoint`][] should be preferred. + + This is not reverse-mode autodifferentiable if `diffeqsolve(..., max_steps=None)`. + """ + + def loop( + self, + *, + max_steps, + terms, + throw, + passed_solver_state, + passed_controller_state, + **kwargs, + ): + del throw, passed_solver_state, passed_controller_state + # TODO: remove the `is_unsafe_sde` guard. + # We need JAX to release bloops, so that we can deprecate `kind="bounded"`. + if is_unsafe_sde(terms): + kind = "lax" + msg = ( + "Cannot reverse-mode autodifferentiate when using " + "`UnsafeBrownianPath`." + ) + elif max_steps is None: + kind = "lax" + msg = ( + "Cannot reverse-mode autodifferentiate when using " + "`diffeqsolve(..., max_steps=None, adjoint=DirectAdjoint())`. " + "This is because JAX needs to know how much memory to allocate for " + "saving the forward pass. You should either put a bound on the maximum " + "number of steps, or switch to " + "`adjoint=RecursiveCheckpointAdjoint(checkpoints=...)`, with an " + "explicitly specified number of checkpoints." + ) + else: + kind = "bounded" + msg = None + inner_while_loop = ft.partial(_inner_loop, kind=kind) + outer_while_loop = ft.partial(_outer_loop, kind=kind) + final_state = self._loop( + **kwargs, + max_steps=max_steps, + terms=terms, + inner_while_loop=inner_while_loop, + outer_while_loop=outer_while_loop, + ) + if msg is not None: + final_state = eqxi.nondifferentiable_backward( + final_state, msg=msg, symbolic=True + ) + return final_state + + def _vf(ys, residual, args__terms, closure): state_no_y, _ = residual t = state_no_y.tprev @@ -353,8 +366,8 @@ def _solve(args__terms, closure): solver=solver, saveat=saveat, init_state=init_state, - inner_while_loop=_while_loop, - outer_while_loop=_while_loop, + inner_while_loop=ft.partial(_inner_loop, kind="lax"), + outer_while_loop=ft.partial(_outer_loop, kind="lax"), ) # Note that we use .ys not .y here. The former is what is actually returned # by diffeqsolve, so it is the thing we want to attach the tangent to. @@ -420,7 +433,7 @@ def loop( final_state = eqx.tree_at( lambda s: s.ys, final_state_no_ys, ys, is_leaf=_is_none ) - final_state = _no_transpose_final_state(final_state) + final_state = _only_transpose_ys(final_state) return final_state, aux_stats @@ -440,8 +453,9 @@ def _loop_backsolve(y__args__terms, *, self, throw, init_state, **kwargs): args=args, terms=terms, init_state=init_state, - inner_while_loop=_while_loop, - outer_while_loop=_while_loop**kwargs, + inner_while_loop=ft.partial(_inner_loop, kind="lax"), + outer_while_loop=ft.partial(_outer_loop, kind="lax"), + **kwargs, ) @@ -583,6 +597,8 @@ def __get(__aug): else: if len(ts) > 1: + # TODO: fold this `_scan_fun` into the `lax.scan`. This will reduce compile + # time. val0 = (ts[-2], ts[-1], ω(ys)[-1].ω, ω(grad_ys)[-1].ω) state, _ = _scan_fun(state, val0, first=True) vals = ( @@ -688,17 +704,20 @@ def loop( "`adjoint=BacksolveAdjoint()` does not support `UnsafeBrownianPath`. " "Consider using `adjoint=DirectAdjoint()` instead." ) - if isinstance(solver, AbstractItoSolver): - raise NotImplementedError( - f"`{solver.__name__}` converges to the Itô solution. However " - "`BacksolveAdjoint` currently only supports Stratonovich SDEs." - ) - elif not isinstance(solver, AbstractStratonovichSolver): - warnings.warn( - f"{solver.__name__} is not marked as converging to either the Itô " - "or the Stratonovich solution. Note that `BacksolveAdjoint` will " - "only produce the correct solution for Stratonovich SDEs." - ) + if is_sde(terms): + if isinstance(solver, AbstractItoSolver): + raise NotImplementedError( + f"`{solver.__class__.__name__}` converges to the Itô solution. " + "However `BacksolveAdjoint` currently only supports Stratonovich " + "SDEs." + ) + elif not isinstance(solver, AbstractStratonovichSolver): + warnings.warn( + f"{solver.___class__._name__} is not marked as converging to " + "either the Itô or the Stratonovich solution. Note that " + "`BacksolveAdjoint` will only produce the correct solution for " + "Stratonovich SDEs." + ) y = init_state.y sentinel = object() @@ -714,5 +733,5 @@ def loop( solver=solver, **kwargs, ) - final_state = _no_transpose_final_state(final_state) + final_state = _only_transpose_ys(final_state) return final_state, aux_stats diff --git a/diffrax/bounded_while_loop.py b/diffrax/bounded_while_loop.py deleted file mode 100644 index 5378c9ad..00000000 --- a/diffrax/bounded_while_loop.py +++ /dev/null @@ -1,190 +0,0 @@ -import functools as ft -import math -from typing import Any, Callable, Optional, Union - -import equinox as eqx -import equinox.internal as eqxi -import jax -import jax.lax as lax -import jax.numpy as jnp -import jax.tree_util as jtu - - -def bounded_while_loop( - cond_fun, - body_fun, - init_val, - max_steps: Optional[int], - *, - buffers: Optional[Callable] = None, - base: int = 16 -): - """Reverse-mode autodifferentiable while loop. - - This only exists to support a few edge cases: - - forward-mode autodiff; - - reading from `buffers`. - You should almost always prefer to use `equinox.internal.checkpointed_while_loop` - instead. - - Once 'bloops' land in JAX core then this function will be removed. - - **Arguments:** - - - cond_fun: function `a -> bool`. - - body_fun: function `a -> a`. - - init_val: pytree of type `a`. - - max_steps: integer or `None`. - - buffers: function `a -> node or nodes`. - - base: integer. - - Note the extra `max_steps` argument. If this is `None` then `bounded_while_loop` - will fall back to `lax.while_loop` (which is not reverse-mode autodifferentiable). - If it is a non-negative integer then this is the maximum number of steps which may - be taken in the loop, after which the loop will exit unconditionally. - - Note the extra `buffers` argument. This behaves similarly to the same argument for - `equinox.internal.checkpointed_while_loop`: these support efficient in-place updates - but no operation. (Unlike `checkpointed_while_loop`, however, this supports being - read from.) - - Note the extra `base` argument. - - Run time will increase slightly as `base` increases. - - Compilation time will decrease substantially as - `math.ceil(math.log(max_steps, base))` decreases. (Which happens as `base` - increases.) - """ - - init_val = jtu.tree_map(jnp.asarray, init_val) - - if max_steps is None: - return lax.while_loop(cond_fun, body_fun, init_val) - - if not isinstance(max_steps, int) or max_steps < 0: - raise ValueError("max_steps must be a non-negative integer") - if max_steps == 0: - return init_val - - def _cond_fun(val, step): - return cond_fun(val) & (step < max_steps) - - init_data = (cond_fun(init_val), init_val, 0) - rounded_max_steps = base ** int(math.ceil(math.log(max_steps, base))) - if buffers is None: - buffers = lambda _: () - _, val, _ = _while_loop( - _cond_fun, body_fun, init_data, rounded_max_steps, buffers, base - ) - return val - - -def _while_loop(cond_fun, body_fun, data, max_steps, buffers, base): - if max_steps == 1: - pred, val, step = data - - tag = object() - - def _buffers(v): - nodes = buffers(v) - tree = jtu.tree_map(_unwrap_buffers, nodes, is_leaf=_is_buffer) - return jtu.tree_leaves(tree) - - val = eqx.tree_at( - _buffers, val, replace_fn=ft.partial(_Buffer, _pred=pred, _tag=tag) - ) - new_val = body_fun(val) - if jax.eval_shape(lambda: val) != jax.eval_shape(lambda: new_val): - raise ValueError("body_fun must have matching input and output structures") - - def _is_our_buffer(x): - return isinstance(x, _Buffer) and x._tag is tag - - def _unwrap_or_select(new_v, v): - if _is_our_buffer(new_v): - assert _is_our_buffer(v) - assert eqx.is_array(new_v._array) - assert eqx.is_array(v._array) - return new_v._array - else: - return lax.select(pred, new_v, v) - - new_val = jtu.tree_map(_unwrap_or_select, new_val, val, is_leaf=_is_our_buffer) - new_step = step + 1 - return cond_fun(new_val, new_step), new_val, new_step - else: - - def _call(_data): - return _while_loop( - cond_fun, body_fun, _data, max_steps // base, buffers, base - ) - - def _scan_fn(_data, _): - _pred, _, _ = _data - _unvmap_pred = eqxi.unvmap_any(_pred) - return lax.cond(_unvmap_pred, _call, lambda x: x, _data), None - - # Don't put checkpointing on the lowest level - if max_steps != base: - _scan_fn = jax.checkpoint(_scan_fn, prevent_cse=False) - - return lax.scan(_scan_fn, data, xs=None, length=base)[0] - - -def _is_buffer(x): - return isinstance(x, _Buffer) - - -def _unwrap_buffers(x): - while _is_buffer(x): - x = x._array - return x - - -class _Buffer(eqx.Module): - _array: Union[jnp.ndarray, "_Buffer"] - _pred: jnp.ndarray - _tag: object = eqx.static_field() - - def __getitem__(self, item): - return self._array[item] - - def _set(self, pred, item, x): - pred = pred & self._pred - if isinstance(self._array, _Buffer): - array = self._array._set(pred, item, x) - else: - old_x = self._array[item] - x = jnp.where(pred, x, old_x) - array = self._array.at[item].set(x) - return _Buffer(array, self._pred, self._tag) - - @property - def at(self): - return _BufferAt(self) - - @property - def shape(self): - return self._array.shape - - @property - def dtype(self): - return self._array.dtype - - @property - def size(self): - return self._array.size - - -class _BufferAt(eqx.Module): - _buffer: _Buffer - - def __getitem__(self, item): - return _BufferItem(self._buffer, item) - - -class _BufferItem(eqx.Module): - _buffer: _Buffer - _item: Any - - def set(self, x): - return self._buffer._set(True, self._item, x) diff --git a/diffrax/integrate.py b/diffrax/integrate.py index 9637f733..812f8fb4 100644 --- a/diffrax/integrate.py +++ b/diffrax/integrate.py @@ -237,9 +237,9 @@ def _body_fun(_state): _saveat_y = _interpolator.evaluate(_saveat_t) _ts = _state.ts.at[_state.save_index].set(_saveat_t) _ys = jtu.tree_map( - lambda __ys, __saveat_y: __ys.at[_state.save_index].set(__saveat_y), - _state.ys, + lambda __saveat_y, __ys: __ys.at[_state.save_index].set(__saveat_y), _saveat_y, + _state.ys, ) return _InnerState( saveat_ts_index=_state.saveat_ts_index + 1, @@ -261,21 +261,20 @@ def _body_fun(_state): ys = final_inner_state.ys save_index = final_inner_state.save_index - # TODO: make while loop? - def maybe_inplace(i, x, u): - return x.at[i].set(jnp.where(keep_step, u, x[i])) + def maybe_inplace(i, u, x): + return x.at[i].set(u, pred=keep_step) if saveat.steps: - ts = maybe_inplace(save_index, ts, tprev) - ys = jtu.tree_map(ft.partial(maybe_inplace, save_index), ys, y) + ts = maybe_inplace(save_index, tprev, ts) + ys = jtu.tree_map(ft.partial(maybe_inplace, save_index), y, ys) save_index = save_index + keep_step if saveat.dense: - dense_ts = maybe_inplace(dense_save_index + 1, dense_ts, tprev) + dense_ts = maybe_inplace(dense_save_index + 1, tprev, dense_ts) dense_infos = jtu.tree_map( ft.partial(maybe_inplace, dense_save_index), - dense_infos, dense_info, + dense_infos, ) dense_save_index = dense_save_index + keep_step @@ -321,7 +320,7 @@ def maybe_inplace(i, x, u): return new_state - final_state = outer_while_loop(cond_fun, body_fun, init_state, max_steps) + final_state = outer_while_loop(cond_fun, body_fun, init_state, max_steps=max_steps) if saveat.t1 and not saveat.steps: # if saveat.steps then the final value is already saved. diff --git a/docs/api/adjoints.md b/docs/api/adjoints.md index cc04d63e..a5870b8d 100644 --- a/docs/api/adjoints.md +++ b/docs/api/adjoints.md @@ -21,7 +21,7 @@ There are multiple ways to backpropagate through a differential equation (to com members: - loop -Of the following options, [`diffrax.RecursiveCheckpointAdjoint`][] and [`diffrax.BacksolveAdjoint`][] can only be reverse-mode autodifferentiated. [`diffrax.DirectAdjoint`][] and [`diffrax.ImplicitAdjoint`][] support both forward and reverse-mode autodifferentiation. +Of the following options, [`diffrax.RecursiveCheckpointAdjoint`][] and [`diffrax.BacksolveAdjoint`][] can only be reverse-mode autodifferentiated. [`diffrax.ImplicitAdjoint`][] and [`diffrax.DirectAdjoint`][] support both forward and reverse-mode autodifferentiation. --- @@ -35,11 +35,11 @@ Of the following options, [`diffrax.RecursiveCheckpointAdjoint`][] and [`diffrax members: - __init__ -::: diffrax.DirectAdjoint +::: diffrax.ImplicitAdjoint selection: members: false -::: diffrax.ImplicitAdjoint +::: diffrax.DirectAdjoint selection: members: false diff --git a/test/test_bounded_while_loop.py b/test/test_bounded_while_loop.py deleted file mode 100644 index 4b32ad80..00000000 --- a/test/test_bounded_while_loop.py +++ /dev/null @@ -1,619 +0,0 @@ -import functools as ft -import timeit -from typing import Optional - -import equinox as eqx -import jax -import jax.lax as lax -import jax.numpy as jnp -import jax.random as jr -import jax.tree_util as jtu -import pytest -from diffrax.bounded_while_loop import bounded_while_loop - -from .helpers import shaped_allclose - - -def test_functional_no_vmap_no_inplace(): - def cond_fun(val): - x, step = val - return step < 5 - - def body_fun(val): - x, step = val - return (x + 0.1, step + 1) - - init_val = (jnp.array([0.3]), 0) - - val = bounded_while_loop(cond_fun, body_fun, init_val, max_steps=0) - assert shaped_allclose(val[0], jnp.array([0.3])) and val[1] == 0 - - val = bounded_while_loop(cond_fun, body_fun, init_val, max_steps=1) - assert shaped_allclose(val[0], jnp.array([0.4])) and val[1] == 1 - - val = bounded_while_loop(cond_fun, body_fun, init_val, max_steps=2) - assert shaped_allclose(val[0], jnp.array([0.5])) and val[1] == 2 - - val = bounded_while_loop(cond_fun, body_fun, init_val, max_steps=4) - assert shaped_allclose(val[0], jnp.array([0.7])) and val[1] == 4 - - val = bounded_while_loop(cond_fun, body_fun, init_val, max_steps=8) - assert shaped_allclose(val[0], jnp.array([0.8])) and val[1] == 5 - - val = bounded_while_loop(cond_fun, body_fun, init_val, max_steps=None) - assert shaped_allclose(val[0], jnp.array([0.8])) and val[1] == 5 - - -def test_functional_no_vmap_inplace(): - def cond_fun(val): - x, step = val - return step < 5 - - def body_fun(val): - x, step = val - x = x.at[jnp.minimum(step + 1, 4)].set(x[step] + 0.1) - step = step.at[()].set(step + 1) - return x, step - - def buffers(val): - x, step = val - return x - - init_val = (jnp.array([0.3, 0.3, 0.3, 0.3, 0.3]), 0) - - val = bounded_while_loop(cond_fun, body_fun, init_val, max_steps=0, buffers=buffers) - assert shaped_allclose(val[0], jnp.array([0.3, 0.3, 0.3, 0.3, 0.3])) and val[1] == 0 - - val = bounded_while_loop(cond_fun, body_fun, init_val, max_steps=1, buffers=buffers) - assert shaped_allclose(val[0], jnp.array([0.3, 0.4, 0.3, 0.3, 0.3])) and val[1] == 1 - - val = bounded_while_loop(cond_fun, body_fun, init_val, max_steps=2, buffers=buffers) - assert shaped_allclose(val[0], jnp.array([0.3, 0.4, 0.5, 0.3, 0.3])) and val[1] == 2 - - val = bounded_while_loop(cond_fun, body_fun, init_val, max_steps=4, buffers=buffers) - assert shaped_allclose(val[0], jnp.array([0.3, 0.4, 0.5, 0.6, 0.7])) and val[1] == 4 - - val = bounded_while_loop(cond_fun, body_fun, init_val, max_steps=8, buffers=buffers) - assert shaped_allclose(val[0], jnp.array([0.3, 0.4, 0.5, 0.6, 0.8])) and val[1] == 5 - - val = bounded_while_loop( - cond_fun, body_fun, init_val, max_steps=None, buffers=buffers - ) - assert shaped_allclose(val[0], jnp.array([0.3, 0.4, 0.5, 0.6, 0.8])) and val[1] == 5 - - -def test_functional_vmap_no_inplace(): - def cond_fun(val): - x, step = val - return step < 5 - - def body_fun(val): - x, step = val - return (x + 0.1, step + 1) - - init_val = (jnp.array([[0.3], [0.4]]), jnp.array([0, 3])) - - val = jax.vmap(lambda v: bounded_while_loop(cond_fun, body_fun, v, max_steps=0))( - init_val - ) - assert shaped_allclose(val[0], jnp.array([[0.3], [0.4]])) and jnp.array_equal( - val[1], jnp.array([0, 3]) - ) - - val = jax.vmap(lambda v: bounded_while_loop(cond_fun, body_fun, v, max_steps=1))( - init_val - ) - assert shaped_allclose(val[0], jnp.array([[0.4], [0.5]])) and jnp.array_equal( - val[1], jnp.array([1, 4]) - ) - - val = jax.vmap(lambda v: bounded_while_loop(cond_fun, body_fun, v, max_steps=2))( - init_val - ) - assert shaped_allclose(val[0], jnp.array([[0.5], [0.6]])) and jnp.array_equal( - val[1], jnp.array([2, 5]) - ) - - val = jax.vmap(lambda v: bounded_while_loop(cond_fun, body_fun, v, max_steps=4))( - init_val - ) - assert shaped_allclose(val[0], jnp.array([[0.7], [0.6]])) and jnp.array_equal( - val[1], jnp.array([4, 5]) - ) - - val = jax.vmap(lambda v: bounded_while_loop(cond_fun, body_fun, v, max_steps=8))( - init_val - ) - assert shaped_allclose(val[0], jnp.array([[0.8], [0.6]])) and jnp.array_equal( - val[1], jnp.array([5, 5]) - ) - - val = jax.vmap(lambda v: bounded_while_loop(cond_fun, body_fun, v, max_steps=None))( - init_val - ) - assert shaped_allclose(val[0], jnp.array([[0.8], [0.6]])) and jnp.array_equal( - val[1], jnp.array([5, 5]) - ) - - -def test_functional_vmap_inplace(): - def cond_fun(val): - x, step, max_step = val - return step < max_step - - def body_fun(val): - x, step, max_step = val - x = x.at[jnp.minimum(step + 1, 4)].set(x[step] + 0.1) - step = step.at[()].set(step + 1) - return x, step, max_step - - def buffers(val): - x, step, max_step = val - return x - - init_val = ( - jnp.array([[0.3, 0.3, 0.3, 0.3, 0.3], [0.4, 0.4, 0.4, 0.4, 0.4]]), - jnp.array([0, 1]), - jnp.array([5, 3]), - ) - - val = jax.vmap( - lambda v: bounded_while_loop( - cond_fun, body_fun, v, max_steps=0, buffers=buffers - ) - )(init_val) - assert shaped_allclose( - val[0], jnp.array([[0.3, 0.3, 0.3, 0.3, 0.3], [0.4, 0.4, 0.4, 0.4, 0.4]]) - ) and jnp.array_equal(val[1], jnp.array([0, 1])) - - val = jax.vmap( - lambda v: bounded_while_loop( - cond_fun, body_fun, v, max_steps=1, buffers=buffers - ) - )(init_val) - assert shaped_allclose( - val[0], jnp.array([[0.3, 0.4, 0.3, 0.3, 0.3], [0.4, 0.4, 0.5, 0.4, 0.4]]) - ) and jnp.array_equal(val[1], jnp.array([1, 2])) - - val = jax.vmap( - lambda v: bounded_while_loop( - cond_fun, body_fun, v, max_steps=2, buffers=buffers - ) - )(init_val) - assert shaped_allclose( - val[0], jnp.array([[0.3, 0.4, 0.5, 0.3, 0.3], [0.4, 0.4, 0.5, 0.6, 0.4]]) - ) and jnp.array_equal(val[1], jnp.array([2, 3])) - - val = jax.vmap( - lambda v: bounded_while_loop( - cond_fun, body_fun, v, max_steps=4, buffers=buffers - ) - )(init_val) - assert shaped_allclose( - val[0], jnp.array([[0.3, 0.4, 0.5, 0.6, 0.7], [0.4, 0.4, 0.5, 0.6, 0.4]]) - ) and jnp.array_equal(val[1], jnp.array([4, 3])) - - val = jax.vmap( - lambda v: bounded_while_loop( - cond_fun, body_fun, v, max_steps=8, buffers=buffers - ) - )(init_val) - assert shaped_allclose( - val[0], jnp.array([[0.3, 0.4, 0.5, 0.6, 0.8], [0.4, 0.4, 0.5, 0.6, 0.4]]) - ) and jnp.array_equal(val[1], jnp.array([5, 3])) - - val = jax.vmap( - lambda v: bounded_while_loop( - cond_fun, body_fun, v, max_steps=None, buffers=buffers - ) - )(init_val) - assert shaped_allclose( - val[0], jnp.array([[0.3, 0.4, 0.5, 0.6, 0.8], [0.4, 0.4, 0.5, 0.6, 0.4]]) - ) and jnp.array_equal(val[1], jnp.array([5, 3])) - - -# -# Remaining tests copied from Equinox's tests for `checkpointed_while_loop`. -# - - -def _get_problem(key, *, num_steps: Optional[int]): - valkey1, valkey2, modelkey = jr.split(key, 3) - - def cond_fun(carry): - if num_steps is None: - return True - else: - step, _, _ = carry - return step < num_steps - - def make_body_fun(dynamic_mlp): - mlp = eqx.combine(dynamic_mlp, static_mlp) - - def body_fun(carry): - # A simple new_val = mlp(val) tends to converge to a fixed point in just a - # few iterations, which implies zero gradient... which doesn't make for a - # test that actually tests anything. Making things rotational like this - # keeps things more interesting. - step, val1, val2 = carry - (theta,) = mlp(val1) - real, imag = val1 - z = real + imag * 1j - z = z * jnp.exp(1j * theta) - real = jnp.real(z) - imag = jnp.imag(z) - val1 = jnp.stack([real, imag]) - val2 = val2.at[step % 8].set(real) - return step + 1, val1, val2 - - return body_fun - - init_val1 = jr.normal(valkey1, (2,)) - init_val2 = jr.normal(valkey2, (20,)) - mlp = eqx.nn.MLP(2, 1, 2, 2, key=modelkey) - dynamic_mlp, static_mlp = eqx.partition(mlp, eqx.is_array) - - return cond_fun, make_body_fun, init_val1, init_val2, dynamic_mlp - - -def _while_as_scan(cond, body, init_val, max_steps): - def f(val, _): - val2 = lax.cond(cond(val), body, lambda x: x, val) - return val2, None - - final_val, _ = lax.scan(f, init_val, xs=None, length=max_steps) - return final_val - - -@pytest.mark.parametrize("buffer", (False, True)) -def test_forward(buffer, getkey): - cond_fun, make_body_fun, init_val1, init_val2, mlp = _get_problem( - getkey(), num_steps=5 - ) - body_fun = make_body_fun(mlp) - true_final_carry = lax.while_loop(cond_fun, body_fun, (0, init_val1, init_val2)) - if buffer: - buffer_fn = lambda i: i[2] - else: - buffer_fn = None - final_carry = bounded_while_loop( - cond_fun, - body_fun, - (0, init_val1, init_val2), - max_steps=16, - buffers=buffer_fn, - ) - assert shaped_allclose(final_carry, true_final_carry) - - -@pytest.mark.parametrize("buffer", (False, True)) -def test_backward(buffer, getkey): - cond_fun, make_body_fun, init_val1, init_val2, mlp = _get_problem( - getkey(), num_steps=None - ) - - @jax.jit - @jax.value_and_grad - def true_run(arg): - init_val1, init_val2, mlp = arg - body_fun = make_body_fun(mlp) - _, true_final_val1, true_final_val2 = _while_as_scan( - cond_fun, body_fun, (0, init_val1, init_val2), max_steps=14 - ) - return jnp.sum(true_final_val1) + jnp.sum(true_final_val2) - - @jax.jit - @jax.value_and_grad - def run(arg): - init_val1, init_val2, mlp = arg - if buffer: - buffer_fn = lambda i: i[2] - else: - buffer_fn = None - body_fun = make_body_fun(mlp) - _, final_val1, final_val2 = bounded_while_loop( - cond_fun, - body_fun, - (0, init_val1, init_val2), - max_steps=14, - buffers=buffer_fn, - ) - return jnp.sum(final_val1) + jnp.sum(final_val2) - - true_value, true_grad = true_run((init_val1, init_val2, mlp)) - value, grad = run((init_val1, init_val2, mlp)) - assert shaped_allclose(value, true_value) - assert shaped_allclose(grad, true_grad, rtol=1e-4, atol=1e-4) - - -@pytest.mark.parametrize("buffer", (False, True)) -def test_vmap_primal_unbatched_cond(buffer, getkey): - cond_fun, make_body_fun, init_val1, init_val2, mlp = _get_problem( - getkey(), num_steps=14 - ) - - @jax.jit - @ft.partial(jax.vmap, in_axes=((0, 0, None),)) - @jax.value_and_grad - def true_run(arg): - init_val1, init_val2, mlp = arg - body_fun = make_body_fun(mlp) - _, true_final_val1, true_final_val2 = _while_as_scan( - cond_fun, body_fun, (0, init_val1, init_val2), max_steps=14 - ) - return jnp.sum(true_final_val1) + jnp.sum(true_final_val2) - - @jax.jit - @ft.partial(jax.vmap, in_axes=((0, 0, None),)) - @jax.value_and_grad - def run(arg): - init_val1, init_val2, mlp = arg - if buffer: - buffer_fn = lambda i: i[2] - else: - buffer_fn = None - body_fun = make_body_fun(mlp) - _, final_val1, final_val2 = bounded_while_loop( - cond_fun, - body_fun, - (0, init_val1, init_val2), - max_steps=16, - buffers=buffer_fn, - ) - return jnp.sum(final_val1) + jnp.sum(final_val2) - - init_val1, init_val2 = jtu.tree_map( - lambda x: jr.normal(getkey(), (3,) + x.shape, x.dtype), (init_val1, init_val2) - ) - true_value, true_grad = true_run((init_val1, init_val2, mlp)) - value, grad = run((init_val1, init_val2, mlp)) - assert shaped_allclose(value, true_value) - assert shaped_allclose(grad, true_grad) - - -@pytest.mark.parametrize("buffer", (False, True)) -def test_vmap_primal_batched_cond(buffer, getkey): - cond_fun, make_body_fun, init_val1, init_val2, mlp = _get_problem( - getkey(), num_steps=14 - ) - - @jax.jit - @ft.partial(jax.vmap, in_axes=((0, 0, None), 0)) - @jax.value_and_grad - def true_run(arg, init_step): - init_val1, init_val2, mlp = arg - body_fun = make_body_fun(mlp) - _, true_final_val1, true_final_val2 = _while_as_scan( - cond_fun, body_fun, (init_step, init_val1, init_val2), max_steps=14 - ) - return jnp.sum(true_final_val1) + jnp.sum(true_final_val2) - - @jax.jit - @ft.partial(jax.vmap, in_axes=((0, 0, None), 0)) - @jax.value_and_grad - def run(arg, init_step): - init_val1, init_val2, mlp = arg - if buffer: - buffer_fn = lambda i: i[2] - else: - buffer_fn = None - body_fun = make_body_fun(mlp) - _, final_val1, final_val2 = bounded_while_loop( - cond_fun, - body_fun, - (init_step, init_val1, init_val2), - max_steps=16, - buffers=buffer_fn, - ) - return jnp.sum(final_val1) + jnp.sum(final_val2) - - init_step = jnp.array([0, 1, 2, 3, 5, 10]) - init_val1, init_val2 = jtu.tree_map( - lambda x: jr.normal(getkey(), (6,) + x.shape, x.dtype), (init_val1, init_val2) - ) - true_value, true_grad = true_run((init_val1, init_val2, mlp), init_step) - value, grad = run((init_val1, init_val2, mlp), init_step) - assert shaped_allclose(value, true_value, rtol=1e-4, atol=1e-4) - assert shaped_allclose(grad, true_grad, rtol=1e-4, atol=1e-4) - - -@pytest.mark.parametrize("buffer", (False, True)) -def test_vmap_cotangent(buffer, getkey): - cond_fun, make_body_fun, init_val1, init_val2, mlp = _get_problem( - getkey(), num_steps=14 - ) - - @jax.jit - @jax.jacrev - def true_run(arg): - init_val1, init_val2, mlp = arg - body_fun = make_body_fun(mlp) - _, true_final_val1, true_final_val2 = _while_as_scan( - cond_fun, body_fun, (0, init_val1, init_val2), max_steps=14 - ) - return true_final_val1, true_final_val2 - - @jax.jit - @jax.jacrev - def run(arg): - init_val1, init_val2, mlp = arg - if buffer: - buffer_fn = lambda i: i[2] - else: - buffer_fn = None - body_fun = make_body_fun(mlp) - _, final_val1, final_val2 = bounded_while_loop( - cond_fun, - body_fun, - (0, init_val1, init_val2), - max_steps=16, - buffers=buffer_fn, - ) - return final_val1, final_val2 - - true_jac = true_run((init_val1, init_val2, mlp)) - jac = run((init_val1, init_val2, mlp)) - assert shaped_allclose(jac, true_jac, rtol=1e-4, atol=1e-4) - - -# This tests the possible failure mode of "the buffer doesn't do anything". -# This test takes O(1e-3) seconds with buffer. -# This test takes O(10) seconds without buffer. -# This speed improvement is precisely the reason that buffer exists. -def test_speed_buffer_while(): - size = 16**4 - - @jax.jit - @jax.vmap - def f(init_step, init_xs): - def cond(carry): - step, xs = carry - return step < size - - def body(carry): - step, xs = carry - xs = xs.at[step].set(1) - return step + 1, xs - - def loop(init_xs): - return bounded_while_loop( - cond, - body, - (init_step, init_xs), - max_steps=size, - buffers=lambda i: i[1], - ) - - # Linearize so that we save residuals - return jax.linearize(loop, init_xs) - - # nontrivial batch size is important to ensure that the `.at[].set()` is really a - # scatter, and that XLA doesn't optimise it into a dynamic_update_slice. (Which - # can be switched with `select` in the compiler.) - args = jnp.array([0, 1]), jnp.zeros((2, size)) - f(*args) # compile - - speed = timeit.timeit(lambda: f(*args), number=1) - assert speed < 0.1 - - -# This isn't testing any particular failure mode: just that things generally work. -def test_speed_grad_checkpointed_while(getkey): - mlp = eqx.nn.MLP(2, 1, 2, 2, key=getkey()) - - @jax.jit - @jax.vmap - @jax.grad - def f(init_val, init_step): - def cond(carry): - step, _ = carry - return step < 8 * 16**3 - - def body(carry): - step, val = carry - (theta,) = mlp(val) - real, imag = val - z = real + imag * 1j - z = z * jnp.exp(1j * theta) - real = jnp.real(z) - imag = jnp.imag(z) - return step + 1, jnp.stack([real, imag]) - - _, final_xs = bounded_while_loop( - cond, - body, - (init_step, init_val), - max_steps=16**3, - ) - return jnp.sum(final_xs) - - init_step = jnp.array([0, 10]) - init_val = jr.normal(getkey(), (2, 2)) - - f(init_val, init_step) # compile - speed = timeit.timeit(lambda: f(init_val, init_step), number=1) - # Should take ~0.001 seconds - assert speed < 0.01 - - -# This is deliberately meant to emulate the pattern of saving used in -# `diffrax.diffeqsolve(..., saveat=SaveAt(ts=...))`. -def test_nested_loops(getkey): - @ft.partial(jax.jit, static_argnums=5) - @ft.partial(jax.vmap, in_axes=(0, 0, 0, 0, 0, None)) - def run(step, vals, ts, final_step, cotangents, true): - value, vjp_fn = jax.vjp( - lambda *v: outer_loop(step, v, ts, true, final_step), *vals - ) - cotangents = vjp_fn(cotangents) - return value, cotangents - - def outer_loop(step, vals, ts, true, final_step): - def cond(carry): - step, _ = carry - return step < final_step - - def body(carry): - step, (val1, val2, val3, val4) = carry - mul = 1 + 0.05 * jnp.sin(105 * val1 + 1) - val1 = val1 * mul - return inner_loop(step, (val1, val2, val3, val4), ts, true) - - def buffers(carry): - _, (_, val2, val3, _) = carry - return val2, val3 - - if true: - while_loop = ft.partial(_while_as_scan, max_steps=50) - else: - while_loop = ft.partial(bounded_while_loop, max_steps=50, buffers=buffers) - _, out = while_loop(cond, body, (step, vals)) - return out - - def inner_loop(step, vals, ts, true): - ts_done = jnp.floor(ts[step] + 1) - - def cond(carry): - step, _ = carry - return ts[step] < ts_done - - def body(carry): - step, (val1, val2, val3, val4) = carry - mul = 1 + 0.05 * jnp.sin(100 * val1 + 3) - val1 = val1 * mul - val2 = val2.at[step].set(val1) - val3 = val3.at[step].set(val1) - val4 = val4.at[step].set(val1) - return step + 1, (val1, val2, val3, val4) - - def buffers(carry): - _, (_, _, val3, val4) = carry - return val3, val4 - - if true: - while_loop = ft.partial(_while_as_scan, max_steps=10) - else: - while_loop = ft.partial(bounded_while_loop, max_steps=10, buffers=buffers) - return while_loop(cond, body, (step, vals)) - - step = jnp.array([0, 5]) - val1 = jr.uniform(getkey(), shape=(2,), minval=0.1, maxval=0.7) - val2 = val3 = val4 = jnp.zeros((2, 47)) - ts = jnp.stack([jnp.linspace(0, 19, 47), jnp.linspace(0, 13, 47)]) - final_step = jnp.array([46, 43]) - cotangents = ( - jr.normal(getkey(), (2,)), - jr.normal(getkey(), (2, 47)), - jr.normal(getkey(), (2, 47)), - jr.normal(getkey(), (2, 47)), - ) - - value, grads = run( - step, (val1, val2, val3, val4), ts, final_step, cotangents, False - ) - true_value, true_grads = run( - step, (val1, val2, val3, val4), ts, final_step, cotangents, True - ) - - assert shaped_allclose(value, true_value) - assert shaped_allclose(grads, true_grads, rtol=1e-4, atol=1e-5) diff --git a/test/test_brownian.py b/test/test_brownian.py index 8e23a76c..4e6b8389 100644 --- a/test/test_brownian.py +++ b/test/test_brownian.py @@ -17,6 +17,11 @@ } +def _make_struct(shape, dtype): + dtype = jax.dtypes.canonicalize_dtype(dtype) + return jax.ShapeDtypeStruct(shape, dtype) + + @pytest.mark.parametrize( "ctr", [diffrax.UnsafeBrownianPath, diffrax.VirtualBrownianTree] ) @@ -61,9 +66,7 @@ def is_tuple_of_ints(obj): for shape, dtype in zip(shapes, dtypes): # Shape to pass as input if dtype is not None: - shape = jtu.tree_map( - jax.ShapeDtypeStruct, shape, dtype, is_leaf=is_tuple_of_ints - ) + shape = jtu.tree_map(_make_struct, shape, dtype, is_leaf=is_tuple_of_ints) if ctr is diffrax.UnsafeBrownianPath: path = ctr(shape, getkey()) @@ -79,9 +82,7 @@ def is_tuple_of_ints(obj): # Expected output shape if dtype is None: - shape = jtu.tree_map( - jax.ShapeDtypeStruct, shape, dtype, is_leaf=is_tuple_of_ints - ) + shape = jtu.tree_map(_make_struct, shape, dtype, is_leaf=is_tuple_of_ints) for _t0 in _vals.values(): for _t1 in _vals.values(): From d30a1a7233ba5c3b65b1278c64cb0867cc98d3c1 Mon Sep 17 00:00:00 2001 From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com> Date: Tue, 14 Feb 2023 01:49:08 -0800 Subject: [PATCH 11/19] Simplified running several benchmarks --- benchmarks/compile_times.py | 48 ++++++++++++++++++++++++++------- benchmarks/scan_stages.py | 24 +++++++++-------- benchmarks/scan_stages_cnf.py | 17 +++++++----- benchmarks/small_neural_ode.py | 20 ++++++++------ diffrax/brownian/path.py | 2 +- diffrax/brownian/tree.py | 2 +- diffrax/global_interpolation.py | 18 ++++++------- diffrax/integrate.py | 2 +- examples/neural_cde.ipynb | 2 +- test/test_adjoint.py | 5 ++-- 10 files changed, 89 insertions(+), 51 deletions(-) diff --git a/benchmarks/compile_times.py b/benchmarks/compile_times.py index f9598c7b..fba0ec3d 100644 --- a/benchmarks/compile_times.py +++ b/benchmarks/compile_times.py @@ -3,7 +3,6 @@ import diffrax as dfx import equinox as eqx -import fire import jax import jax.numpy as jnp import jax.random as jr @@ -31,12 +30,12 @@ def __call__(self, t, y, args): return jnp.stack(y) -def main(inline: bool, scan_stages: bool, grad: bool, adjoint: str): - if adjoint == "direct": +def run(inline: bool, scan_stages: bool, grad: bool, adjoint_name: str): + if adjoint_name == "direct": adjoint = dfx.DirectAdjoint() - elif adjoint == "recursive": + elif adjoint_name == "recursive": adjoint = dfx.RecursiveCheckpointAdjoint() - elif adjoint == "backsolve": + elif adjoint_name == "backsolve": adjoint = dfx.BacksolveAdjoint() else: raise ValueError @@ -72,9 +71,40 @@ def solve(y0): return jnp.sum(sol.ys) solve_ = ft.partial(solve, jnp.array([1.0])) - print("Compile+run time", timeit.timeit(solve_, number=1)) - print("Run time", timeit.timeit(solve_, number=1)) + compile_time = timeit.timeit(solve_, number=1) + print( + f"{inline=}, {scan_stages=}, {grad=}, adjoint={adjoint_name}, {compile_time=}" + ) -if __name__ == "__main__": - fire.Fire(main) +run(inline=False, scan_stages=False, grad=False, adjoint_name="direct") +run(inline=False, scan_stages=False, grad=False, adjoint_name="recursive") +run(inline=False, scan_stages=False, grad=False, adjoint_name="backsolve") + +run(inline=False, scan_stages=False, grad=True, adjoint_name="direct") +run(inline=False, scan_stages=False, grad=True, adjoint_name="recursive") +run(inline=False, scan_stages=False, grad=True, adjoint_name="backsolve") + +run(inline=False, scan_stages=True, grad=False, adjoint_name="direct") +run(inline=False, scan_stages=True, grad=False, adjoint_name="recursive") +run(inline=False, scan_stages=True, grad=False, adjoint_name="backsolve") + +run(inline=False, scan_stages=True, grad=True, adjoint_name="direct") +run(inline=False, scan_stages=True, grad=True, adjoint_name="recursive") +run(inline=False, scan_stages=True, grad=True, adjoint_name="backsolve") + +run(inline=True, scan_stages=False, grad=False, adjoint_name="direct") +run(inline=True, scan_stages=False, grad=False, adjoint_name="recursive") +run(inline=True, scan_stages=False, grad=False, adjoint_name="backsolve") + +run(inline=True, scan_stages=False, grad=True, adjoint_name="direct") +run(inline=True, scan_stages=False, grad=True, adjoint_name="recursive") +run(inline=True, scan_stages=False, grad=True, adjoint_name="backsolve") + +run(inline=True, scan_stages=True, grad=False, adjoint_name="direct") +run(inline=True, scan_stages=True, grad=False, adjoint_name="recursive") +run(inline=True, scan_stages=True, grad=False, adjoint_name="backsolve") + +run(inline=True, scan_stages=True, grad=True, adjoint_name="direct") +run(inline=True, scan_stages=True, grad=True, adjoint_name="recursive") +run(inline=True, scan_stages=True, grad=True, adjoint_name="backsolve") diff --git a/benchmarks/scan_stages.py b/benchmarks/scan_stages.py index 0110326b..a1f443c0 100644 --- a/benchmarks/scan_stages.py +++ b/benchmarks/scan_stages.py @@ -1,14 +1,14 @@ """Benchmarks the effect of `diffrax.AbstractRungeKutta(scan_stages=...)`. -On my CPU-only machine: +On my relatively beefy CPU-only machine: ``` -bash> python scan_stages.py False -Compile+run time 24.38062646985054 -Run time 0.0018830380868166685 +scan_stages=True +Compile+run time 1.8253102810122073 +Run time 0.00017526978626847267 -bash> python scan_stages.py True -Compile+run time 11.418417416978627 -Run time 0.0014536201488226652 +scan_stages=False +Compile+run time 10.679616351146251 +Run time 0.00021236995235085487 ``` """ @@ -17,7 +17,6 @@ import diffrax as dfx import equinox as eqx -import fire import jax.numpy as jnp import jax.random as jr @@ -44,7 +43,7 @@ def __call__(self, t, y, args): return jnp.stack(y) -def main(scan_stages): +def run(scan_stages): vf = VectorField(1, 1, 16, 2, key=jr.PRNGKey(0)) term = dfx.ODETerm(vf) solver = dfx.Dopri8(scan_stages=scan_stages) @@ -53,15 +52,18 @@ def main(scan_stages): t1 = 1 dt0 = None - @eqx.filter_jit(donate="none") + @eqx.filter_jit def solve(y0): return dfx.diffeqsolve( term, solver, t0, t1, dt0, y0, stepsize_controller=stepsize_controller ) solve_ = ft.partial(solve, jnp.array([1.0])) + print(f"scan_stages={scan_stages}") print("Compile+run time", timeit.timeit(solve_, number=1)) print("Run time", timeit.timeit(solve_, number=1)) -fire.Fire(main) +run(scan_stages=True) +print() +run(scan_stages=False) diff --git a/benchmarks/scan_stages_cnf.py b/benchmarks/scan_stages_cnf.py index 1108819a..3b8bbfa9 100644 --- a/benchmarks/scan_stages_cnf.py +++ b/benchmarks/scan_stages_cnf.py @@ -32,7 +32,6 @@ import diffrax import equinox as eqx -import fire import jax import jax.nn as jnn import jax.numpy as jnp @@ -50,7 +49,7 @@ def vector_field_prob(t, input, model): return f, logp -@eqx.filter_vmap(args=(None, 0, None, None)) +@eqx.filter_vmap(in_axes=(None, 0, None, None)) def log_prob(model, y0, scan_stages, backsolve): term = diffrax.ODETerm(vector_field_prob) solver = diffrax.Dopri5(scan_stages=scan_stages) @@ -80,14 +79,18 @@ def solve(model, inputs, scan_stages, backsolve): return -log_prob(model, inputs, scan_stages, backsolve).mean() -def main(scan_stages, backsolve): +def run(scan_stages, backsolve): mkey, dkey = jr.split(jr.PRNGKey(0), 2) model = eqx.nn.MLP(2, 2, 10, 2, activation=jnn.gelu, key=mkey) x = jr.normal(dkey, (256, 2)) - solve1 = ft.partial(solve, model, jnp.coyp(x), scan_stages, backsolve) - solve2 = ft.partial(solve, model, jnp.copy(x), scan_stages, backsolve) - print("Compile+run time", timeit.timeit(solve1, number=1)) + solve2 = ft.partial(solve, model, x, scan_stages, backsolve) + print(f"scan_stages={scan_stages}, backsolve={backsolve}") + print("Compile+run time", timeit.timeit(solve2, number=1)) print("Run time", timeit.timeit(solve2, number=1)) + print() -fire.Fire(main) +run(scan_stages=False, backsolve=False) +run(scan_stages=False, backsolve=True) +run(scan_stages=True, backsolve=False) +run(scan_stages=True, backsolve=True) diff --git a/benchmarks/small_neural_ode.py b/benchmarks/small_neural_ode.py index 1beae093..59b45ea5 100644 --- a/benchmarks/small_neural_ode.py +++ b/benchmarks/small_neural_ode.py @@ -1,9 +1,10 @@ +"""Benchmarks Diffrax vs torchdiffeq vs jax.experimental.ode.odeint""" + import gc import time import diffrax import equinox as eqx -import fire import jax import jax.experimental.ode as experimental import jax.nn as jnn @@ -166,7 +167,7 @@ def time_jax(neural_ode_jax, y0, t1, grad): _eval_jax(neural_ode_jax, y0, t1) -def main(batch_size=64, t1=100, multiple=False, grad=False): +def run(multiple, grad, batch_size=64, t1=100): neural_ode_torch = NeuralODETorch(multiple) neural_ode_diffrax = NeuralODEDiffrax(multiple) neural_ode_experimental = NeuralODEExperimental(multiple) @@ -180,7 +181,7 @@ def main(batch_size=64, t1=100, multiple=False, grad=False): func_torch[2].bias.copy_(torch.tensor(np.asarray(func_jax.layers[1].bias))) y0_jax = jrandom.normal(jrandom.PRNGKey(1), (batch_size, 4)) - y0_torch = torch.tensor(y0_jax.to_py()) + y0_torch = torch.tensor(np.asarray(y0_jax)) time_torch(neural_ode_torch, y0_torch, t1, grad) torch_time = time_torch(neural_ode_torch, y0_torch, t1, grad) @@ -192,13 +193,16 @@ def main(batch_size=64, t1=100, multiple=False, grad=False): experimental_time = time_jax(neural_ode_experimental, jnp.copy(y0_jax), t1, grad) print( - f""" - torch_time={torch_time} - diffrax_time={diffrax_time} - experimetnal_time={experimental_time} + f""" multiple={multiple}, grad={grad} + torch_time={torch_time} + diffrax_time={diffrax_time} +experimental_time={experimental_time} """ ) if __name__ == "__main__": - fire.Fire(main) + run(multiple=False, grad=False) + run(multiple=True, grad=False) + run(multiple=False, grad=True) + run(multiple=True, grad=True) diff --git a/diffrax/brownian/path.py b/diffrax/brownian/path.py index 60de8155..84019f01 100644 --- a/diffrax/brownian/path.py +++ b/diffrax/brownian/path.py @@ -62,7 +62,7 @@ def t0(self): def t1(self): return None - @eqx.filter_jit(donate="none") + @eqx.filter_jit def evaluate(self, t0: Scalar, t1: Scalar, left: bool = True) -> PyTree[Array]: del left t0 = eqxi.nondifferentiable(t0, name="t0") diff --git a/diffrax/brownian/tree.py b/diffrax/brownian/tree.py index 2c0f1456..0941d544 100644 --- a/diffrax/brownian/tree.py +++ b/diffrax/brownian/tree.py @@ -88,7 +88,7 @@ def __init__( ) self.key = split_by_tree(key, self.shape) - @eqx.filter_jit(donate="none") + @eqx.filter_jit def evaluate( self, t0: Scalar, t1: Optional[Scalar] = None, left: bool = True ) -> PyTree[Array]: diff --git a/diffrax/global_interpolation.py b/diffrax/global_interpolation.py index 0c5b894e..ee1dcdaa 100644 --- a/diffrax/global_interpolation.py +++ b/diffrax/global_interpolation.py @@ -76,7 +76,7 @@ def _check(_ys): jtu.tree_map(_check, self.ys) - @eqx.filter_jit(donate="none") + @eqx.filter_jit def evaluate( self, t0: Scalar, t1: Optional[Scalar] = None, left: bool = True ) -> PyTree: @@ -130,7 +130,7 @@ def _index(_ys): prev_ys**ω + (next_ys**ω - prev_ys**ω) * (fractional_part / diff_t) ).ω - @eqx.filter_jit(donate="none") + @eqx.filter_jit def derivative(self, t: Scalar, left: bool = True) -> PyTree: r"""Evaluate the derivative of the linear interpolation. Essentially equivalent to `jax.jvp(self.evaluate, (t,), (jnp.ones_like(t),))`. @@ -195,7 +195,7 @@ def _check(d, c, b, a): jtu.tree_map(_check, *self.coeffs) - @eqx.filter_jit(donate="none") + @eqx.filter_jit def evaluate( self, t0: Scalar, t1: Optional[Scalar] = None, left: bool = True ) -> PyTree: @@ -239,7 +239,7 @@ def evaluate( + frac * (ω(b)[index] + frac * (ω(c)[index] + frac * ω(d)[index])) ).ω - @eqx.filter_jit(donate="none") + @eqx.filter_jit def derivative(self, t: Scalar, left: bool = True) -> PyTree: r"""Evaluate the derivative of the cubic interpolation. Essentially equivalent to `jax.jvp(self.evaluate, (t,), (jnp.ones_like(t),))`. @@ -309,7 +309,7 @@ def _get_local_interpolation(self, t: Scalar, left: bool): infos = ω(self.infos)[index].ω return self.interpolation_cls(t0=prev_t, t1=next_t, **infos) - @eqx.filter_jit(donate="none") + @eqx.filter_jit def evaluate( self, t0: Scalar, t1: Optional[Scalar] = None, left: bool = True ) -> PyTree: @@ -320,7 +320,7 @@ def evaluate( # continuous. return self._get_local_interpolation(t0, left).evaluate(t0) - @eqx.filter_jit(donate="none") + @eqx.filter_jit def derivative(self, t: Scalar, left: bool = True) -> PyTree: # Passing `left` doesn't matter on a local interpolation, which is globally # continuous. @@ -420,7 +420,7 @@ def _linear_interpolation( return ys -@eqx.filter_jit(donate="none") +@eqx.filter_jit def linear_interpolation( ts: Array["times"], # noqa: F821 ys: PyTree["times", ...], # noqa: F821 @@ -474,7 +474,7 @@ def _rectilinear_interpolation( return ts, ys -@eqx.filter_jit(donate="none") +@eqx.filter_jit def rectilinear_interpolation( ts: Array["times"], # noqa: F821 ys: PyTree["times", ...], # noqa: F821 @@ -659,7 +659,7 @@ def _backward_hermite_coefficients( return ds, cs, bs, as_ -@eqx.filter_jit(donate="none") +@eqx.filter_jit def backward_hermite_coefficients( ts: Array["times"], # noqa: F821 ys: PyTree["times", ...], # noqa: F821 diff --git a/diffrax/integrate.py b/diffrax/integrate.py index 812f8fb4..96258d27 100644 --- a/diffrax/integrate.py +++ b/diffrax/integrate.py @@ -334,7 +334,7 @@ def maybe_inplace(i, u, x): return eqx.tree_at(lambda s: s.result, final_state, result), aux_stats -@eqx.filter_jit(donate="none") +@eqx.filter_jit def diffeqsolve( terms: PyTree[AbstractTerm], solver: AbstractSolver, diff --git a/examples/neural_cde.ipynb b/examples/neural_cde.ipynb index c989541f..d894c847 100644 --- a/examples/neural_cde.ipynb +++ b/examples/neural_cde.ipynb @@ -275,7 +275,7 @@ "\n", " # Training loop like normal.\n", "\n", - " @eqx.filter_jit(donate=\"none\")\n", + " @eqx.filter_jit\n", " def loss(model, ti, label_i, coeff_i):\n", " pred = jax.vmap(model)(ti, coeff_i)\n", " # Binary cross-entropy\n", diff --git a/test/test_adjoint.py b/test/test_adjoint.py index 42ec73e1..3c940c6c 100644 --- a/test/test_adjoint.py +++ b/test/test_adjoint.py @@ -55,10 +55,9 @@ def _run(y0__args__term, saveat, adjoint): _run_grad = eqx.filter_jit( jax.grad( lambda d, saveat, adjoint: _run(eqx.combine(d, nondiff), saveat, adjoint) - ), - donate="none", + ) ) - _run_grad_int = eqx.filter_jit(jax.grad(_run, allow_int=True), donate="none") + _run_grad_int = eqx.filter_jit(jax.grad(_run, allow_int=True)) # Yep, test that they're not implemented. We can remove these checks if we ever # do implement them. From c068cfa90a7bba118a27f48101481b818af3d33d Mon Sep 17 00:00:00 2001 From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com> Date: Wed, 15 Feb 2023 17:55:13 -0800 Subject: [PATCH 12/19] Fixes #216: dense interpolation with t0==t1 --- diffrax/global_interpolation.py | 34 +++++++++++++++++++++++++-------- diffrax/integrate.py | 2 +- test/test_saveat_solution.py | 18 +++++++++++++++++ 3 files changed, 45 insertions(+), 9 deletions(-) diff --git a/diffrax/global_interpolation.py b/diffrax/global_interpolation.py index ee1dcdaa..e42675b6 100644 --- a/diffrax/global_interpolation.py +++ b/diffrax/global_interpolation.py @@ -288,6 +288,7 @@ class DenseInterpolation(AbstractGlobalInterpolation): ts_size: Int infos: DenseInfos direction: Scalar + y0: PyTree[Array] interpolation_cls: Type[AbstractLocalInterpolation] = eqx.static_field() def __post_init__(self): @@ -315,26 +316,43 @@ def evaluate( ) -> PyTree: if t1 is not None: return self.evaluate(t1, left=left) - self.evaluate(t0, left=left) - t0 = t0 * self.direction - # Passing `left` doesn't matter on a local interpolation, which is globally - # continuous. - return self._get_local_interpolation(t0, left).evaluate(t0) + t = t0 * self.direction + ts_0 = self.ts[0] + ts_1 = self.ts[self.ts_size - 1] + _to_int = lambda x: jnp.where(x, 1, 0) + index = _to_int(t < ts_0) + _to_int(t <= ts_0) + _to_int(t <= ts_1) + _nan = self.__class__._nan + _y0 = lambda s: s.y0 + _evaluate = ft.partial(self.__class__._evaluate, t=t0, left=left) + return lax.switch(index, [_nan, _evaluate, _y0, _nan], self) @eqx.filter_jit def derivative(self, t: Scalar, left: bool = True) -> PyTree: - # Passing `left` doesn't matter on a local interpolation, which is globally - # continuous. t = t * self.direction - out = self._get_local_interpolation(t, left).derivative(t) + ts_0 = self.ts[0] + ts_1 = self.ts[self.ts_size - 1] + pred = (t >= ts_0) & (t <= ts_1) + _derivative = ft.partial(self.__class__._derivative, t=t, left=left) + _nan = self.__class__._nan + return lax.cond(pred, _derivative, _nan, self) + + def _evaluate(self, t, left): + return self._get_local_interpolation(t, left).evaluate(t, left=left) + + def _derivative(self, t, left): + out = self._get_local_interpolation(t, left).derivative(t, left=left) return (self.direction * out**ω).ω + def _nan(self): + return jtu.tree_map(ft.partial(jnp.full_like, fill_value=jnp.nan), self.y0) + @property def t0(self): return self.ts[0] * self.direction @property def t1(self): - return self.ts[-1] * self.direction + return self.ts[self.ts_size - 1] * self.direction # diff --git a/diffrax/integrate.py b/diffrax/integrate.py index 96258d27..c49759bf 100644 --- a/diffrax/integrate.py +++ b/diffrax/integrate.py @@ -610,7 +610,6 @@ def _promote(yi): ys = jtu.tree_map(lambda y: jnp.full((out_size,) + jnp.shape(y), jnp.inf), y0) result = jnp.array(RESULTS.successful) if saveat.dense: - t0 = eqxi.error_if(t0, t0 == t1, "Cannot save dense output if t0 == t1") if max_steps is None: raise ValueError( "`max_steps=None` is incompatible with `saveat.dense=True`" @@ -701,6 +700,7 @@ def _promote(yi): ts_size=final_state.dense_save_index + 1, interpolation_cls=solver.interpolation_cls, infos=final_state.dense_infos, + y0=y0, direction=direction, ) else: diff --git a/test/test_saveat_solution.py b/test/test_saveat_solution.py index 4788cbba..e86299c7 100644 --- a/test/test_saveat_solution.py +++ b/test/test_saveat_solution.py @@ -144,3 +144,21 @@ def test_saveat_solution(): assert shaped_allclose(sol.derivative(0.2), -0.5 * _y0 * math.exp(-0.05)) assert sol.stats["num_steps"] > 0 assert sol.result == diffrax.RESULTS.successful + + +def test_trivial_dense(): + term = diffrax.ODETerm(lambda t, y, args: -0.5 * y) + y0 = jnp.array([2.1]) + saveat = diffrax.SaveAt(dense=True) + stepsize_controller = diffrax.PIDController(rtol=1e-8, atol=1e-8) + sol = diffrax.diffeqsolve( + term, + t0=2.0, + t1=2.0, + y0=y0, + dt0=None, + solver=diffrax.Dopri5(), + saveat=saveat, + stepsize_controller=stepsize_controller, + ) + assert shaped_allclose(sol.evaluate(2.0), y0) From f077759beeddad946d67a26b8624b04e47d4c9b0 Mon Sep 17 00:00:00 2001 From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com> Date: Thu, 16 Feb 2023 10:42:09 -0800 Subject: [PATCH 13/19] Update versions --- .pre-commit-config.yaml | 4 ++-- README.md | 2 +- setup.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a1c8d723..dad2385a 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -4,13 +4,13 @@ repos: hooks: - id: black - repo: https://github.com/nbQA-dev/nbQA - rev: 1.2.3 + rev: 1.6.3 hooks: - id: nbqa-black - id: nbqa-isort - id: nbqa-flake8 - repo: https://github.com/PyCQA/isort - rev: 5.10.1 + rev: 5.12.0 hooks: - id: isort - repo: https://github.com/pycqa/flake8 diff --git a/README.md b/README.md index 4692b526..cdb73773 100644 --- a/README.md +++ b/README.md @@ -21,7 +21,7 @@ _From a technical point of view, the internal structure of the library is pretty pip install diffrax ``` -Requires Python >=3.8 and JAX >=0.4.1. +Requires Python >=3.8 and JAX >=0.4.3. ## Documentation diff --git a/setup.py b/setup.py index c8820668..2c8bad1f 100644 --- a/setup.py +++ b/setup.py @@ -46,7 +46,7 @@ python_requires = "~=3.8" -install_requires = ["jax>=0.4.1", "equinox>=0.10.0"] +install_requires = ["jax>=0.4.3", "equinox>=0.10.0"] setuptools.setup( name=name, From 86bc0042b6123db2999177662023bd7c7999dc93 Mon Sep 17 00:00:00 2001 From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com> Date: Thu, 16 Feb 2023 14:40:24 -0800 Subject: [PATCH 14/19] Updated autocitation for RecursiveCheckpointAdjoint --- diffrax/autocitation.py | 72 +++++++++++++++++++++++++++++++++-------- docs/api/citation.md | 2 +- mkdocs.yml | 1 - 3 files changed, 60 insertions(+), 15 deletions(-) diff --git a/diffrax/autocitation.py b/diffrax/autocitation.py index 92f2f974..36f83ed1 100644 --- a/diffrax/autocitation.py +++ b/diffrax/autocitation.py @@ -6,7 +6,7 @@ import jax import jax.tree_util as jtu -from .adjoint import BacksolveAdjoint, RecursiveCheckpointAdjoint +from .adjoint import BacksolveAdjoint, DirectAdjoint, RecursiveCheckpointAdjoint from .brownian import VirtualBrownianTree from .heuristics import is_cde, is_sde from .integrate import diffeqsolve @@ -244,23 +244,13 @@ def _backsolve_adjoint(adjoint, terms=None): @citation_rules.append def _discrete_adjoint(adjoint): - if type(adjoint) in (RecursiveCheckpointAdjoint,): + if type(adjoint) in (RecursiveCheckpointAdjoint, DirectAdjoint): pieces = [] pieces.append( r""" -% You are differentiating using discretise-then-optimise. The following papers may be -% relevant. +% You are differentiating using discretise-then-optimise. """ ) - if type(adjoint) is RecursiveCheckpointAdjoint: - pieces.append( - r""" -% If using reverse-mode autodifferentiation (backpropagation), then you are -% using binomial checkpointing ("treeverse"), which was introduced in: -""" - + _parse_reference(RecursiveCheckpointAdjoint) - ) - pieces.append( r""" % If using forward-mode autodifferentiation, then this was studied in: @@ -276,6 +266,62 @@ def _discrete_adjoint(adjoint): } """ ) + if type(adjoint) is RecursiveCheckpointAdjoint: + pieces.append( + r""" +% If using reverse-mode autodifferentiation (backpropagation), then you are using +% online recursive checkpointing in order to minimise memory usage. This was developed +% in: +@article{stumm2010new, + author = {Stumm, Philipp and Walther, Andrea}, + title = {New Algorithms for Optimal Online Checkpointing}, + journal = {SIAM Journal on Scientific Computing}, + volume = {32}, + number = {2}, + pages = {836--854}, + year = {2010}, + doi = {10.1137/080742439}, +} +@article{wang2009minimal, + author = {Wang, Qiqi and Moin, Parviz and Iaccarino, Gianluca}, + title = {Minimal Repetition Dynamic Checkpointing Algorithm for Unsteady + Adjoint Calculation}, + journal = {SIAM Journal on Scientific Computing}, + volume = {31}, + number = {4}, + pages = {2549--2567}, + year = {2009}, + doi = {10.1137/080727890}, +} + +% In addition, the equivalent offline recursive checkpointing scheme (also known as +% "treeverse", "binary checkpointing", or "revolve") was developed in: +@article{griewank1992achieving, + author = {Griewank, Andreas}, + title = {Achieving logarithmic growth of temporal and spatial complexity in + reverse automatic differentiation}, + journal = {Optimization Methods and Software}, + volume = {1}, + number = {1}, + pages = {35--54}, + year = {1992}, + publisher = {Taylor & Francis}, + doi = {10.1080/10556789208805505}, +} +@article{griewank2000revolve, + author = {Griewank, Andreas and Walther, Andrea}, + title = {Algorithm 799: Revolve: An Implementation of Checkpointing for the + Reverse or Adjoint Mode of Computational Differentiation}, + year = {2000}, + publisher = {Association for Computing Machinery}, + volume = {26}, + number = {1}, + doi = {10.1145/347837.347846}, + journal = {ACM Trans. Math. Softw.}, + pages = {19--45}, +} +""" + ) return "\n".join([p.strip() for p in pieces]) diff --git a/docs/api/citation.md b/docs/api/citation.md index 2cc2d588..f68aa9ce 100644 --- a/docs/api/citation.md +++ b/docs/api/citation.md @@ -1,4 +1,4 @@ -# Create citations +# Autocitations Diffrax can autogenerate BibTeX citations for all the numerical methods you use. diff --git a/mkdocs.yml b/mkdocs.yml index 05d9ea99..18cf5380 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -133,4 +133,3 @@ nav: - Developer Documentation: - 'devdocs/predictor_dirk.md' - 'devdocs/adjoint_commutative_noise.md' - - 'devdocs/bounded_while_loop.md' From 43b8601eb320669db4111f7685e892e1e2318ec2 Mon Sep 17 00:00:00 2001 From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com> Date: Mon, 20 Feb 2023 10:04:46 -0800 Subject: [PATCH 15/19] - Added support for SubSaveAt - SaveAt(dense=True) now supports t0==t1 - `AbstractSolver.term_structure` should now be a `PyTree[Type[AbstractTerm]]` rather than a `PyTreeDef`. --- diffrax/__init__.py | 2 +- diffrax/adjoint.py | 204 +++++++++------ diffrax/autocitation.py | 12 +- diffrax/brownian/path.py | 2 +- diffrax/brownian/tree.py | 2 + diffrax/custom_types.py | 2 - diffrax/global_interpolation.py | 47 ++-- diffrax/integrate.py | 355 ++++++++++++++++---------- diffrax/saveat.py | 155 +++++++++-- diffrax/solution.py | 4 +- diffrax/solver/base.py | 6 +- diffrax/solver/euler.py | 3 +- diffrax/solver/euler_heun.py | 7 +- diffrax/solver/implicit_euler.py | 3 +- diffrax/solver/leapfrog_midpoint.py | 3 +- diffrax/solver/milstein.py | 6 +- diffrax/solver/reversible_heun.py | 3 +- diffrax/solver/runge_kutta.py | 2 +- diffrax/solver/semi_implicit_euler.py | 3 +- 19 files changed, 545 insertions(+), 276 deletions(-) diff --git a/diffrax/__init__.py b/diffrax/__init__.py index 75a5d268..bd3acff7 100644 --- a/diffrax/__init__.py +++ b/diffrax/__init__.py @@ -35,7 +35,7 @@ NonlinearSolution, ) from .path import AbstractPath -from .saveat import SaveAt +from .saveat import SaveAt, SubSaveAt from .solution import is_event, is_okay, is_successful, RESULTS, Solution from .solver import ( AbstractAdaptiveSolver, diff --git a/diffrax/adjoint.py b/diffrax/adjoint.py index 40ec4602..fb4d2c36 100644 --- a/diffrax/adjoint.py +++ b/diffrax/adjoint.py @@ -1,11 +1,11 @@ import abc import functools as ft import warnings -from dataclasses import fields from typing import Any, Dict, Optional import equinox as eqx import equinox.internal as eqxi +import jax import jax.lax as lax import jax.numpy as jnp import jax.tree_util as jtu @@ -13,7 +13,7 @@ from .ad import implicit_jvp from .heuristics import is_sde, is_unsafe_sde -from .saveat import SaveAt +from .saveat import save_y, SaveAt, SubSaveAt from .solver import AbstractItoSolver, AbstractStratonovichSolver from .term import AbstractTerm, AdjointTerm @@ -22,28 +22,86 @@ def _is_none(x): return x is None +def _is_subsaveat(x: Any) -> bool: + return isinstance(x, SubSaveAt) + + +def _nondiff_solver_controller_state( + adjoint, init_state, passed_solver_state, passed_controller_state +): + if passed_solver_state: + name = ( + f"When using `adjoint={adjoint.__class__.__name__}()`, then `solver_state`" + ) + solver_fn = ft.partial( + eqxi.nondifferentiable, + name=name, + ) + else: + solver_fn = lax.stop_gradient + if passed_controller_state: + name = ( + f"When using `adjoint={adjoint.__class__.__name__}()`, then " + "`controller_state`" + ) + controller_fn = ft.partial( + eqxi.nondifferentiable, + name=name, + ) + else: + controller_fn = lax.stop_gradient + init_state = eqx.tree_at( + lambda s: s.solver_state, + init_state, + replace_fn=solver_fn, + is_leaf=_is_none, + ) + init_state = eqx.tree_at( + lambda s: s.controller_state, + init_state, + replace_fn=controller_fn, + is_leaf=_is_none, + ) + return init_state + + def _only_transpose_ys(final_state): - entries = ( + from .integrate import SaveState + + is_save_state = lambda x: isinstance(x, SaveState) + + def get_ys(_final_state): + return [ + s.ys + for s in jtu.tree_leaves(_final_state.save_state, is_leaf=is_save_state) + ] + + ys = get_ys(final_state) + + named_nondiff_entries = ( "y", "tprev", "tnext", "solver_state", "controller_state", - "ts", "dense_ts", "dense_infos", ) - values = { - k: eqxi.nondifferentiable_backward( - getattr(final_state, k), name=k, symbolic=False - ) - for k in entries - } - values["ys"] = final_state.ys + named_nondiff_values = tuple( + eqxi.nondifferentiable_backward(getattr(final_state, k), name=k, symbolic=False) + for k in named_nondiff_entries + ) + final_state = eqxi.nondifferentiable_backward(final_state, symbolic=False) - get = lambda s: tuple(getattr(s, k) for k in entries + ("ys",)) - replace = tuple(values[k] for k in entries + ("ys",)) - final_state = eqx.tree_at(get, final_state, replace, is_leaf=_is_none) + + get_named_nondiff_entries = lambda s: tuple( + getattr(s, k) for k in named_nondiff_entries + ) + final_state = eqx.tree_at( + get_named_nondiff_entries, final_state, named_nondiff_values, is_leaf=_is_none + ) + + final_state = eqx.tree_at(get_ys, final_state, ys) return final_state @@ -99,24 +157,8 @@ def _diffeqsolve(self): return diffeqsolve -def _inner_buffers(state): - assert type(state).__name__ == "_InnerState" - assert {f.name for f in fields(state)} == { - "ts", - "ys", - "saveat_ts_index", - "save_index", - } - return state.ts, state.ys - - -def _outer_buffers(state): - assert type(state).__name__ == "_State" - return state.ts, state.ys, state.dense_ts, state.dense_infos - - -_inner_loop = ft.partial(eqxi.while_loop, buffers=_inner_buffers) -_outer_loop = ft.partial(eqxi.while_loop, buffers=_outer_buffers) +_inner_loop = jax.named_call(eqxi.while_loop, name="inner-loop") +_outer_loop = jax.named_call(eqxi.while_loop, name="outer-loop") def _uncallable(*args, **kwargs): @@ -226,10 +268,7 @@ def loop( "instead." ) if self.checkpoints is None and max_steps is None: - if saveat.ts is None: - inner_while_loop = _uncallable - else: - inner_while_loop = ft.partial(_inner_loop, kind="lax") + inner_while_loop = ft.partial(_inner_loop, kind="lax") outer_while_loop = ft.partial(_outer_loop, kind="lax") msg = ( "Cannot reverse-mode autodifferentiate when using " @@ -239,12 +278,7 @@ def loop( "number of steps, or explicitly specify how many checkpoints to use." ) else: - if saveat.ts is None: - inner_while_loop = _uncallable - else: - inner_while_loop = ft.partial( - _inner_loop, kind="checkpointed", checkpoints=len(saveat.ts) - ) + inner_while_loop = ft.partial(_inner_loop, kind="checkpointed") outer_while_loop = ft.partial( _outer_loop, kind="checkpointed", checkpoints=self.checkpoints ) @@ -349,8 +383,12 @@ def loop( def _vf(ys, residual, args__terms, closure): state_no_y, _ = residual t = state_no_y.tprev - # unpack length-1 dimension - y = jtu.tree_map(lambda _y: _y[0], ys) + + def _unpack(_y): + (_y1,) = _y + return _y1 + + y = jtu.tree_map(_unpack, ys) args, terms = args__terms _, _, solver, _, _ = closure return solver.func(terms, t, y, args) @@ -371,8 +409,12 @@ def _solve(args__terms, closure): ) # Note that we use .ys not .y here. The former is what is actually returned # by diffeqsolve, so it is the thing we want to attach the tangent to. - return final_state.ys, ( - eqx.tree_at(lambda s: s.ys, final_state, None), + # + # Note that `final_state.save_state` has type PyTree[SaveState]. To access `.ys` + # we are assuming that this PyTree has trivial structure. This is the case because + # of the guard in `ImplicitAdjoint` that `saveat` be `SaveAt(t1=True)`. + return final_state.save_state.ys, ( + eqx.tree_at(lambda s: s.save_state.ys, final_state, None), aux_stats, ) @@ -410,28 +452,18 @@ def loop( "Can only use `adjoint=ImplicitAdjoint()` with " "`saveat=SaveAt(t1=True)`." ) - - if not passed_solver_state: - init_state = eqx.tree_at( - lambda s: s.solver_state, - init_state, - replace_fn=lax.stop_gradient, - is_leaf=_is_none, - ) - if not passed_controller_state: - init_state = eqx.tree_at( - lambda s: s.controller_state, - init_state, - replace_fn=lax.stop_gradient, - is_leaf=_is_none, - ) - + init_state = _nondiff_solver_controller_state( + self, init_state, passed_solver_state, passed_controller_state + ) closure = (self, kwargs, solver, saveat, init_state) ys, residual = implicit_jvp(_solve, _vf, (args, terms), closure) final_state_no_ys, aux_stats = residual + # Note that `final_state.save_state` has type PyTree[SaveState]. To access `.ys` + # we are assuming that this PyTree has trivial structure. This is the case + # because of the guard that `saveat` be `SaveAt(t1=True)`. final_state = eqx.tree_at( - lambda s: s.ys, final_state_no_ys, ys, is_leaf=_is_none + lambda s: s.save_state.ys, final_state_no_ys, ys, is_leaf=_is_none ) final_state = _only_transpose_ys(final_state) return final_state, aux_stats @@ -445,9 +477,7 @@ def loop( def _loop_backsolve(y__args__terms, *, self, throw, init_state, **kwargs): del throw y, args, terms = y__args__terms - init_state = eqx.tree_at( - lambda s: jtu.tree_leaves(s.y), init_state, jtu.tree_leaves(y) - ) + init_state = eqx.tree_at(lambda s: s.y, init_state, y) del y return self._loop( args=args, @@ -461,8 +491,10 @@ def _loop_backsolve(y__args__terms, *, self, throw, init_state, **kwargs): def _loop_backsolve_fwd(y__args__terms, **kwargs): final_state, aux_stats = _loop_backsolve(y__args__terms, **kwargs) - ts = final_state.ts - ys = final_state.ys + # Note that `final_state.save_state` has type `PyTree[SaveState]`; here we are + # relying on the guard in `BacksolveAdjoint` that it have trivial structure. + ts = final_state.save_state.ts + ys = final_state.save_state.ys return (final_state, aux_stats), (ts, ys) @@ -493,7 +525,9 @@ def _loop_backsolve_bwd( ts, ys = residuals del residuals grad_final_state, _ = grad_final_state__aux_stats - grad_ys = grad_final_state.ys + # Note that `grad_final_state.save_state` has type `PyTree[SaveState]`; here we are + # relying on the guard in `BacksolveAdjoint` that it have trivial structure. + grad_ys = grad_final_state.save_state.ys del grad_final_state, grad_final_state__aux_stats y, args, terms = y__args__terms del y__args__terms @@ -521,7 +555,9 @@ def _loop_backsolve_bwd( kwargs.update(self.kwargs) del self, solver, stepsize_controller, adjoint_terms, dt0, max_steps, throw del y, args, terms - saveat_t0 = saveat.t0 + # Note that `saveat.subs` has type `PyTree[SubSaveAt]`. Here we use the assumption + # (checked in `BacksolveAdjoint`) that it has trivial pytree structure. + saveat_t0 = saveat.subs.t0 del saveat # @@ -675,9 +711,10 @@ def __init__(self, **kwargs): } given_keys = set(kwargs.keys()) diff_keys = given_keys - valid_keys - if len(diff_keys): + if len(diff_keys) > 0: raise ValueError( - f"The following keys are not valid for `BacksolveAdjoint`: {diff_keys}" + "The following keyword argments are not valid for `BacksolveAdjoint`: " + f"{diff_keys}" ) self.kwargs = kwargs @@ -693,11 +730,20 @@ def loop( passed_controller_state, **kwargs, ): - del passed_solver_state, passed_controller_state - if saveat.steps or saveat.dense: + if jtu.tree_structure(saveat.subs, is_leaf=_is_subsaveat) != jtu.tree_structure( + 0 + ): + raise NotImplementedError( + "Cannot use `adjoint=BacksolveAdjoint()` with `SaveAt(subs=...)`." + ) + if saveat.dense or saveat.subs.steps: raise NotImplementedError( "Cannot use `adjoint=BacksolveAdjoint()` with " - "`saveat=Steps(steps=True)` or `saveat=Steps(dense=True)`." + "`saveat=SaveAt(steps=True)` or saveat=SaveAt(dense=True)`." + ) + if saveat.subs.fn is not save_y: + raise NotImplementedError( + "Cannot use `adjoint=BacksolveAdjoint()` with `saveat=SaveAt(fn=...)`." ) if is_unsafe_sde(terms): raise ValueError( @@ -713,16 +759,16 @@ def loop( ) elif not isinstance(solver, AbstractStratonovichSolver): warnings.warn( - f"{solver.___class__._name__} is not marked as converging to " + f"{solver.__class__.__name__} is not marked as converging to " "either the Itô or the Stratonovich solution. Note that " "`BacksolveAdjoint` will only produce the correct solution for " "Stratonovich SDEs." ) y = init_state.y - sentinel = object() - init_state = eqx.tree_at( - lambda s: jtu.tree_leaves(s.y), init_state, replace_fn=lambda _: sentinel + init_state = eqx.tree_at(lambda s: s.y, init_state, object()) + init_state = _nondiff_solver_controller_state( + self, init_state, passed_solver_state, passed_controller_state ) final_state, aux_stats = _loop_backsolve( diff --git a/diffrax/autocitation.py b/diffrax/autocitation.py index 36f83ed1..251ab0be 100644 --- a/diffrax/autocitation.py +++ b/diffrax/autocitation.py @@ -11,6 +11,7 @@ from .heuristics import is_cde, is_sde from .integrate import diffeqsolve from .misc import adjoint_rms_seminorm +from .saveat import SubSaveAt from .solver import ( AbstractImplicitSolver, Dopri5, @@ -432,6 +433,9 @@ def _sde(terms): """ +_is_subsaveat = lambda x: isinstance(x, SubSaveAt) + + @citation_rules.append def _solvers(solver, saveat=None): if type(solver) in ( @@ -478,7 +482,13 @@ def _solvers(solver, saveat=None): """ + ref1 ) - if saveat is not None and (saveat.ts or saveat.dense): + if saveat is not None and ( + saveat.dense + or ( + subsaveat.ts is not None + for subsaveat in jtu.tree_leaves(saveat, is_leaf=_is_subsaveat) + ) + ): msg += ( r""" % Output via `SaveAt(ts=...)` or `SaveAt(dense=True)` is done using the diff --git a/diffrax/brownian/path.py b/diffrax/brownian/path.py index 84019f01..c1d5f95f 100644 --- a/diffrax/brownian/path.py +++ b/diffrax/brownian/path.py @@ -66,7 +66,7 @@ def t1(self): def evaluate(self, t0: Scalar, t1: Scalar, left: bool = True) -> PyTree[Array]: del left t0 = eqxi.nondifferentiable(t0, name="t0") - t1 = eqxi.nondifferentiable(t1, name="t0") + t1 = eqxi.nondifferentiable(t1, name="t1") t0_ = force_bitcast_convert_type(t0, jnp.int32) t1_ = force_bitcast_convert_type(t1, jnp.int32) key = jrandom.fold_in(self.key, t0_) diff --git a/diffrax/brownian/tree.py b/diffrax/brownian/tree.py index 0941d544..ccdd7208 100644 --- a/diffrax/brownian/tree.py +++ b/diffrax/brownian/tree.py @@ -93,9 +93,11 @@ def evaluate( self, t0: Scalar, t1: Optional[Scalar] = None, left: bool = True ) -> PyTree[Array]: del left + t0 = eqxi.nondifferentiable(t0, name="t0") if t1 is None: return self._evaluate(t0) else: + t1 = eqxi.nondifferentiable(t1, name="t1") return jtu.tree_map( lambda x, y: x - y, self._evaluate(t1), diff --git a/diffrax/custom_types.py b/diffrax/custom_types.py index d23ec8af..624f47f4 100644 --- a/diffrax/custom_types.py +++ b/diffrax/custom_types.py @@ -129,5 +129,3 @@ def __class_getitem__(cls, item): DenseInfo = Dict[str, PyTree[Array]] DenseInfos = Dict[str, PyTree[Array["times", ...]]] # noqa: F821 - -PyTreeDef = type(jtu.tree_structure(0)) diff --git a/diffrax/global_interpolation.py b/diffrax/global_interpolation.py index e42675b6..8cd67914 100644 --- a/diffrax/global_interpolation.py +++ b/diffrax/global_interpolation.py @@ -285,11 +285,12 @@ def derivative(self, t: Scalar, left: bool = True) -> PyTree: class DenseInterpolation(AbstractGlobalInterpolation): - ts_size: Int + ts_size: Int # Takes values in {1, 2, 3, ...} infos: DenseInfos - direction: Scalar - y0: PyTree[Array] interpolation_cls: Type[AbstractLocalInterpolation] = eqx.static_field() + direction: Scalar + t0_if_trivial: Array + y0_if_trivial: PyTree[Array] def __post_init__(self): def _check(_infos): @@ -319,22 +320,26 @@ def evaluate( t = t0 * self.direction ts_0 = self.ts[0] ts_1 = self.ts[self.ts_size - 1] - _to_int = lambda x: jnp.where(x, 1, 0) - index = _to_int(t < ts_0) + _to_int(t <= ts_0) + _to_int(t <= ts_1) - _nan = self.__class__._nan - _y0 = lambda s: s.y0 - _evaluate = ft.partial(self.__class__._evaluate, t=t0, left=left) - return lax.switch(index, [_nan, _evaluate, _y0, _nan], self) + pred = (self.ts_size > 1) & (t >= ts_0) & (t <= ts_1) + eval_fn = ft.partial(self.__class__._evaluate, t=t, left=left) + nan_fn = self.__class__._nan + # Use cond to avoid generating nans unless we have to. + out = lax.cond(pred, eval_fn, nan_fn, self) + keep = ft.partial(jnp.where, (t == self.t0_if_trivial) & (self.ts_size == 1)) + return jtu.tree_map(keep, self.y0_if_trivial, out) @eqx.filter_jit def derivative(self, t: Scalar, left: bool = True) -> PyTree: t = t * self.direction + # Note that len(self.ts) == max_steps + 1 > 0 so the indexing is always valid, + # even if we throw it away because self.ts_size == 0. ts_0 = self.ts[0] ts_1 = self.ts[self.ts_size - 1] - pred = (t >= ts_0) & (t <= ts_1) - _derivative = ft.partial(self.__class__._derivative, t=t, left=left) - _nan = self.__class__._nan - return lax.cond(pred, _derivative, _nan, self) + pred = (self.ts_size > 1) & (t >= ts_0) & (t <= ts_1) + deriv_fn = ft.partial(self.__class__._derivative, t=t, left=left) + nan_fn = self.__class__._nan + # Use cond to avoid generating nans unless we have to. + return lax.cond(pred, deriv_fn, nan_fn, self) def _evaluate(self, t, left): return self._get_local_interpolation(t, left).evaluate(t, left=left) @@ -344,15 +349,25 @@ def _derivative(self, t, left): return (self.direction * out**ω).ω def _nan(self): - return jtu.tree_map(ft.partial(jnp.full_like, fill_value=jnp.nan), self.y0) + return jtu.tree_map( + ft.partial(jnp.full_like, fill_value=jnp.nan), self.y0_if_trivial + ) @property def t0(self): - return self.ts[0] * self.direction + # Note that len(self.ts) == max_steps + 1 > 0 so the indexing is always valid, + # even if we throw it away because self.ts_size == 0. + ts_0 = jnp.where(self.ts_size == 1, self.t0_if_trivial, self.ts[0]) + return ts_0 * self.direction @property def t1(self): - return self.ts[self.ts_size - 1] * self.direction + # Note that len(self.ts) == max_steps + 1 > 0 so the indexing is always valid, + # even if we throw it away because self.ts_size == 0. + ts_1 = jnp.where( + self.ts_size == 1, self.t0_if_trivial, self.ts[self.ts_size - 1] + ) + return ts_1 * self.direction # diff --git a/diffrax/integrate.py b/diffrax/integrate.py index c49759bf..feb9b3a5 100644 --- a/diffrax/integrate.py +++ b/diffrax/integrate.py @@ -1,6 +1,7 @@ import functools as ft +import typing import warnings -from typing import Optional +from typing import Any, Callable, Optional import equinox as eqx import equinox.internal as eqxi @@ -8,12 +9,12 @@ import jax.numpy as jnp import jax.tree_util as jtu -from .adjoint import AbstractAdjoint, RecursiveCheckpointAdjoint +from .adjoint import AbstractAdjoint, DirectAdjoint, RecursiveCheckpointAdjoint from .custom_types import Array, Bool, Int, PyTree, Scalar from .event import AbstractDiscreteTerminatingEvent from .global_interpolation import DenseInterpolation from .heuristics import is_sde, is_unsafe_sde -from .saveat import SaveAt +from .saveat import SaveAt, SubSaveAt from .solution import is_okay, is_successful, RESULTS, Solution from .solver import AbstractItoSolver, AbstractSolver, AbstractStratonovichSolver, Euler from .step_size_controller import ( @@ -26,7 +27,14 @@ from .term import AbstractTerm, WrapTerm -class _State(eqx.Module): +class SaveState(eqx.Module): + saveat_ts_index: Int + ts: Array["times"] # noqa: F821 + ys: PyTree[Array["times", ...]] # noqa: F821 + save_index: Int + + +class State(eqx.Module): # Evolving state during the solve y: Array["state"] # noqa: F821 tprev: Scalar @@ -39,36 +47,49 @@ class _State(eqx.Module): num_accepted_steps: Int num_rejected_steps: Int # Output that is .at[].set() updated during the solve (and their indices) - saveat_ts_index: Scalar - ts: Array["times"] # noqa: F821 - ys: PyTree[Array["times", ...]] # noqa: F821 - save_index: Int + save_state: PyTree[SaveState] dense_ts: Optional[Array["times + 1"]] # noqa: F821 dense_infos: Optional[PyTree[Array["times", ...]]] # noqa: F821 dense_save_index: Int -class _InnerState(eqx.Module): - saveat_ts_index: Int - ts: Array["times"] # noqa: F821 - ys: PyTree[Array["times", ...]] # noqa: F821 - save_index: Int +def _is_none(x): + return x is None + + +def _is_subsaveat(x: Any) -> bool: + return isinstance(x, SubSaveAt) + +def _inner_buffers(save_state): + assert type(save_state) is SaveState + return save_state.ts, save_state.ys -def _save(state: _State, t: Scalar) -> _State: - ts = state.ts - ys = state.ys - save_index = state.save_index - y = state.y + +def _outer_buffers(state): + assert type(state) is State + is_save_state = lambda x: isinstance(x, SaveState) + save_states = jtu.tree_leaves(state.save_state, is_leaf=is_save_state) + return ( + [s.ts for s in save_states] + + [s.ys for s in save_states] + + [state.dense_ts, state.dense_infos] + ) + + +def _save( + t: Scalar, y: PyTree[Array], args: PyTree, fn: Callable, save_state: SaveState +) -> SaveState: + ts = save_state.ts + ys = save_state.ys + save_index = save_state.save_index ts = ts.at[save_index].set(t) - ys = jtu.tree_map(lambda ys_, y_: ys_.at[save_index].set(y_), ys, y) + ys = jtu.tree_map(lambda ys_, y_: ys_.at[save_index].set(y_), ys, fn(t, y, args)) save_index = save_index + 1 return eqx.tree_at( - lambda s: [s.ts, s.save_index] + jtu.tree_leaves(s.ys), - state, - [ts, save_index] + jtu.tree_leaves(ys), + lambda s: [s.ts, s.ys, s.save_index], save_state, [ts, ys, save_index] ) @@ -102,13 +123,23 @@ def loop( outer_while_loop, ): - if saveat.t0: - init_state = _save(init_state, t0) if saveat.dense: dense_ts = init_state.dense_ts dense_ts = dense_ts.at[0].set(t0) init_state = eqx.tree_at(lambda s: s.dense_ts, init_state, dense_ts) + def save_t0(subsaveat: SubSaveAt, save_state: SaveState) -> SaveState: + if subsaveat.t0: + save_state = _save(t0, init_state.y, args, subsaveat.fn, save_state) + return save_state + + save_state = jtu.tree_map( + save_t0, saveat.subs, init_state.save_state, is_leaf=_is_subsaveat + ) + init_state = eqx.tree_at( + lambda s: s.save_state, init_state, save_state, is_leaf=_is_none + ) + # Privileged optimisation for the common case of no jumps. We can reduce # solver compile time with this. # TODO: somehow make this a non-priviliged optimisation, i.e. detect when @@ -211,63 +242,78 @@ def body_fun(state): # Store the output produced from this numerical step. # - saveat_ts_index = state.saveat_ts_index - ts = state.ts - ys = state.ys - save_index = state.save_index + interpolator = solver.interpolation_cls( + t0=state.tprev, t1=state.tnext, **dense_info + ) + save_state = state.save_state dense_ts = state.dense_ts dense_infos = state.dense_infos dense_save_index = state.dense_save_index - if saveat.ts is not None: - - _interpolator = solver.interpolation_cls( - t0=state.tprev, t1=state.tnext, **dense_info - ) + def save_ts(subsaveat: SubSaveAt, save_state: SaveState) -> SaveState: + if subsaveat.ts is not None: + save_state = save_ts_impl(subsaveat.ts, subsaveat.fn, save_state) + return save_state - def _cond_fun(_state): + def save_ts_impl(ts, fn, save_state: SaveState) -> SaveState: + def _cond_fun(_save_state): return ( keep_step - & (saveat.ts[_state.saveat_ts_index] <= state.tnext) - & (_state.saveat_ts_index < len(saveat.ts)) + & (ts[_save_state.saveat_ts_index] <= state.tnext) + & (_save_state.saveat_ts_index < len(ts)) ) - def _body_fun(_state): - _saveat_t = saveat.ts[_state.saveat_ts_index] - _saveat_y = _interpolator.evaluate(_saveat_t) - _ts = _state.ts.at[_state.save_index].set(_saveat_t) + def _body_fun(_save_state): + _t = ts[_save_state.saveat_ts_index] + _y = interpolator.evaluate(_t) + _ts = _save_state.ts.at[_save_state.save_index].set(_t) _ys = jtu.tree_map( - lambda __saveat_y, __ys: __ys.at[_state.save_index].set(__saveat_y), - _saveat_y, - _state.ys, + lambda __y, __ys: __ys.at[_save_state.save_index].set(__y), + fn(_t, _y, args), + _save_state.ys, ) - return _InnerState( - saveat_ts_index=_state.saveat_ts_index + 1, + return SaveState( + saveat_ts_index=_save_state.saveat_ts_index + 1, ts=_ts, ys=_ys, - save_index=_state.save_index + 1, + save_index=_save_state.save_index + 1, ) - init_inner_state = _InnerState( - saveat_ts_index=saveat_ts_index, ts=ts, ys=ys, save_index=save_index - ) - - final_inner_state = inner_while_loop( - _cond_fun, _body_fun, init_inner_state, max_steps=len(saveat.ts) + return inner_while_loop( + _cond_fun, + _body_fun, + save_state, + max_steps=len(ts), + buffers=_inner_buffers, + checkpoints=len(ts), ) - saveat_ts_index = final_inner_state.saveat_ts_index - ts = final_inner_state.ts - ys = final_inner_state.ys - save_index = final_inner_state.save_index + save_state = jtu.tree_map( + save_ts, saveat.subs, save_state, is_leaf=_is_subsaveat + ) def maybe_inplace(i, u, x): return x.at[i].set(u, pred=keep_step) - if saveat.steps: - ts = maybe_inplace(save_index, tprev, ts) - ys = jtu.tree_map(ft.partial(maybe_inplace, save_index), y, ys) - save_index = save_index + keep_step + def save_steps(subsaveat: SubSaveAt, save_state: SaveState) -> SaveState: + if subsaveat.steps: + ts = maybe_inplace(save_state.save_index, tprev, save_state.ts) + ys = jtu.tree_map( + ft.partial(maybe_inplace, save_state.save_index), + subsaveat.fn(tprev, y, args), + save_state.ys, + ) + save_index = save_state.save_index + keep_step + save_state = eqx.tree_at( + lambda s: [s.ts, s.ys, s.save_index], + save_state, + [ts, ys, save_index], + ) + return save_state + + save_state = jtu.tree_map( + save_steps, saveat.subs, save_state, is_leaf=_is_subsaveat + ) if saveat.dense: dense_ts = maybe_inplace(dense_save_index + 1, tprev, dense_ts) @@ -278,7 +324,7 @@ def maybe_inplace(i, u, x): ) dense_save_index = dense_save_index + keep_step - new_state = _State( + new_state = State( y=y, tprev=tprev, tnext=tnext, @@ -289,10 +335,7 @@ def maybe_inplace(i, u, x): num_steps=num_steps, num_accepted_steps=num_accepted_steps, num_rejected_steps=num_rejected_steps, - saveat_ts_index=saveat_ts_index, - ts=ts, - ys=ys, - save_index=save_index, + save_state=save_state, dense_ts=dense_ts, dense_infos=dense_infos, dense_save_index=dense_save_index, @@ -320,13 +363,28 @@ def maybe_inplace(i, u, x): return new_state - final_state = outer_while_loop(cond_fun, body_fun, init_state, max_steps=max_steps) + final_state = outer_while_loop( + cond_fun, body_fun, init_state, max_steps=max_steps, buffers=_outer_buffers + ) + + def _save_t1(subsaveat, save_state): + if subsaveat.t1 and not subsaveat.steps: + # If subsaveat.steps then the final value is already saved. + # + # Use `tprev` instead of `t1` in case of an event terminating the solve + # early. (And absent such an event then `tprev == t1`.) + save_state = _save( + final_state.tprev, final_state.y, args, subsaveat.fn, save_state + ) + return save_state + + save_state = jtu.tree_map( + _save_t1, saveat.subs, final_state.save_state, is_leaf=_is_subsaveat + ) + final_state = eqx.tree_at( + lambda s: s.save_state, final_state, save_state, is_leaf=_is_none + ) - if saveat.t1 and not saveat.steps: - # if saveat.steps then the final value is already saved. - # Using `tprev` instead of `t1` in case of an event terminating the solve - # early. (And absent such an event then `tprev == t1`.) - final_state = _save(final_state, final_state.tprev) result = jnp.where( cond_fun(final_state), RESULTS.max_steps_reached, final_state.result ) @@ -334,6 +392,14 @@ def maybe_inplace(i, u, x): return eqx.tree_at(lambda s: s.result, final_state, result), aux_stats +if getattr(typing, "GENERATING_DOCUMENTATION", False): + # Nicer documentation for the default `diffeqsolve(saveat=...)` argument. + # Not using `eqxi.doc_repr` as some IDEs (Helix, at least) show the source code + # of the default argument directly. + class SaveAt(eqx.Module): # noqa: F811 + t1: bool + + @eqx.filter_jit def diffeqsolve( terms: PyTree[AbstractTerm], @@ -348,7 +414,7 @@ def diffeqsolve( stepsize_controller: AbstractStepSizeController = ConstantStepSize(), adjoint: AbstractAdjoint = RecursiveCheckpointAdjoint(), discrete_terminating_event: Optional[AbstractDiscreteTerminatingEvent] = None, - max_steps: Optional[int] = 16**3, + max_steps: Optional[int] = 4096, throw: bool = True, solver_state: Optional[PyTree] = None, controller_state: Optional[PyTree] = None, @@ -474,19 +540,15 @@ def diffeqsolve( term_leaves, term_structure = jtu.tree_flatten( terms, is_leaf=lambda x: isinstance(x, AbstractTerm) ) - raises = False - for leaf in term_leaves: - if not isinstance(leaf, AbstractTerm): - raises = True - del leaf - if term_structure != solver.term_structure: - raises = True - if raises: + term_leaves2, term_structure2 = jtu.tree_flatten(solver.term_structure) + if term_structure != term_structure2 or any( + not isinstance(x, y) for x, y in zip(term_leaves, term_leaves2) + ): raise ValueError( "`terms` must be a PyTree of `AbstractTerms` (such as `ODETerm`), with " f"structure {solver.term_structure}" ) - del term_leaves, term_structure, raises + del term_leaves, term_structure, term_leaves2, term_structure2 if is_sde(terms): if not isinstance(solver, (AbstractItoSolver, AbstractStratonovichSolver)): @@ -501,6 +563,22 @@ def diffeqsolve( "An SDE should not be solved with adaptive step sizes with Euler's " "method, as it may not converge to the correct solution." ) + # TODO: remove these lines. + # + # These are to work around an edge case: on the backward pass, + # RecursiveCheckpointAdjoint currently tries to differentiate the overall + # per-step function wrt all floating-point arrays. In particular this includes + # `state.tprev`, which feeds into the control, which feeds into + # VirtualBrownianTree, which can't be differentiated. + # We're waiting on JAX to offer a way of specifying which arguments to a + # custom_vjp have symbolic zero *tangents* (not cotangents) so that we can more + # precisely determine what to differentiate wrt. + # + # We don't replace this in the case of an unsafe SDE because + # RecursiveCheckpointAdjoint will raise an error in that case anyway, so we + # should let the normal error be raised. + if isinstance(adjoint, RecursiveCheckpointAdjoint) and not is_unsafe_sde(terms): + adjoint = DirectAdjoint() if is_unsafe_sde(terms): if isinstance(stepsize_controller, AbstractAdaptiveStepSizeController): raise ValueError( @@ -508,15 +586,23 @@ def diffeqsolve( ) # Allow setting e.g. t0 as an int with dt0 as a float. - timelikes = (jnp.array(0.0), t0, t1, dt0, saveat.ts) + timelikes = [jnp.array(0.0), t0, t1, dt0] + [ + s.ts for s in jtu.tree_leaves(saveat.subs, is_leaf=_is_subsaveat) + ] timelikes = [x for x in timelikes if x is not None] dtype = jnp.result_type(*timelikes) t0 = jnp.asarray(t0, dtype=dtype) t1 = jnp.asarray(t1, dtype=dtype) if dt0 is not None: dt0 = jnp.asarray(dt0, dtype=dtype) - if saveat.ts is not None: - saveat = eqx.tree_at(lambda s: s.ts, saveat, saveat.ts.astype(dtype)) + + def _get_subsaveat_ts(saveat): + out = [s.ts for s in jtu.tree_leaves(saveat.subs, is_leaf=_is_subsaveat)] + return [x for x in out if x is not None] + + saveat = eqx.tree_at( + _get_subsaveat_ts, saveat, replace_fn=lambda ts: ts.astype(dtype) # noqa: F821 + ) # Time will affect state, so need to promote the state dtype as well if necessary. def _promote(yi): @@ -532,8 +618,9 @@ def _promote(yi): t1 = t1 * direction if dt0 is not None: dt0 = dt0 * direction - if saveat.ts is not None: - saveat = eqx.tree_at(lambda s: s.ts, saveat, saveat.ts * direction) + saveat = eqx.tree_at( + _get_subsaveat_ts, saveat, replace_fn=lambda ts: ts * direction + ) stepsize_controller = stepsize_controller.wrap(direction) terms = jtu.tree_map( lambda t: WrapTerm(t, direction), @@ -547,18 +634,20 @@ def _promote(yi): solver = stepsize_controller.wrap_solver(solver) # Error checking - if saveat.ts is not None: - saveat_ts = eqxi.error_if( - saveat.ts, - saveat.ts[1:] < saveat.ts[:-1], + def _check_subsaveat_ts(ts): + ts = eqxi.error_if( + ts, + ts[1:] < ts[:-1], "saveat.ts must be increasing or decreasing.", ) - saveat_ts = eqxi.error_if( - saveat_ts, - (saveat.ts > t1) | (saveat.ts < t0), + ts = eqxi.error_if( + ts, + (ts > t1) | (ts < t0), "saveat.ts must lie between t0 and t1.", ) - saveat = eqx.tree_at(lambda s: s.ts, saveat, saveat_ts) + return ts + + saveat = eqx.tree_at(_get_subsaveat_ts, saveat, replace_fn=_check_subsaveat_ts) # Initialise states tprev = t0 @@ -584,30 +673,37 @@ def _promote(yi): passed_solver_state = True # Allocate memory to store output. - out_size = 0 - if saveat.t0: - out_size += 1 - if saveat.ts is not None: - out_size += len(saveat.ts) - if saveat.steps: - # We have no way of knowing how many steps we'll actually end up taking, and - # XLA doesn't support dynamic shapes. So we just have to allocate the maximum - # amount of steps we can possibly take. - if max_steps is None: - raise ValueError( - "`max_steps=None` is incompatible with `saveat.steps=True`" - ) - out_size += max_steps - if saveat.t1 and not saveat.steps: - out_size += 1 + def _allocate_output(subsaveat: SubSaveAt) -> SaveState: + out_size = 0 + if subsaveat.t0: + out_size += 1 + if subsaveat.ts is not None: + out_size += len(subsaveat.ts) + if subsaveat.steps: + # We have no way of knowing how many steps we'll actually end up taking, and + # XLA doesn't support dynamic shapes. So we just have to allocate the + # maximum amount of steps we can possibly take. + if max_steps is None: + raise ValueError( + "`max_steps=None` is incompatible with saving at `steps=True`" + ) + out_size += max_steps + if subsaveat.t1 and not subsaveat.steps: + out_size += 1 + saveat_ts_index = 0 + save_index = 0 + ts = jnp.full(out_size, jnp.inf) + struct = eqx.filter_eval_shape(subsaveat.fn, t0, y0, args) + ys = jtu.tree_map(lambda y: jnp.full((out_size,) + y.shape, jnp.inf), struct) + return SaveState( + ts=ts, ys=ys, save_index=save_index, saveat_ts_index=saveat_ts_index + ) + + save_state = jtu.tree_map(_allocate_output, saveat.subs, is_leaf=_is_subsaveat) num_steps = 0 num_accepted_steps = 0 num_rejected_steps = 0 - saveat_ts_index = 0 - save_index = 0 made_jump = False if made_jump is None else made_jump - ts = jnp.full(out_size, jnp.inf) - ys = jtu.tree_map(lambda y: jnp.full((out_size,) + jnp.shape(y), jnp.inf), y0) result = jnp.array(RESULTS.successful) if saveat.dense: if max_steps is None: @@ -627,7 +723,7 @@ def _promote(yi): dense_save_index = None # Initialise state - init_state = _State( + init_state = State( y=y0, tprev=tprev, tnext=tnext, @@ -638,10 +734,7 @@ def _promote(yi): num_steps=num_steps, num_accepted_steps=num_accepted_steps, num_rejected_steps=num_rejected_steps, - saveat_ts_index=saveat_ts_index, - ts=ts, - ys=ys, - save_index=save_index, + save_state=save_state, dense_ts=dense_ts, dense_infos=dense_infos, dense_save_index=dense_save_index, @@ -672,16 +765,15 @@ def _promote(yi): # Finish up # - if saveat.t0 or saveat.t1 or saveat.steps or (saveat.ts is not None): - ts = final_state.ts - ts = ts * direction - ys = final_state.ys - # It's important that we don't do any further postprocessing on `ys` here, as - # it is the `final_state` value that is used when backpropagating via - # optimise-then-discretise. - else: - ts = None - ys = None + is_save_state = lambda x: isinstance(x, SaveState) + ts = jtu.tree_map( + lambda s: s.ts * direction, final_state.save_state, is_leaf=is_save_state + ) + ys = jtu.tree_map(lambda s: s.ys, final_state.save_state, is_leaf=is_save_state) + # It's important that we don't do any further postprocessing on `ys` here, as + # it is the `final_state` value that is used when backpropagating via + # optimise-then-discretise. + if saveat.controller_state: controller_state = final_state.controller_state else: @@ -698,10 +790,11 @@ def _promote(yi): interpolation = DenseInterpolation( ts=final_state.dense_ts, ts_size=final_state.dense_save_index + 1, - interpolation_cls=solver.interpolation_cls, infos=final_state.dense_infos, - y0=y0, + interpolation_cls=solver.interpolation_cls, direction=direction, + t0_if_trivial=t0, + y0_if_trivial=y0, ) else: interpolation = None diff --git a/diffrax/saveat.py b/diffrax/saveat.py index 800d6083..ec57a814 100644 --- a/diffrax/saveat.py +++ b/diffrax/saveat.py @@ -1,26 +1,28 @@ -from typing import Optional, Sequence, Union +from typing import Callable, Optional, Sequence, Union import equinox as eqx import jax.numpy as jnp -from .custom_types import Array, Scalar +from .custom_types import Array, PyTree, Scalar -class SaveAt(eqx.Module): - """Determines what to save as output from the differential equation solve. +def save_y(t, y, args): + return y - Instances of this class should be passed as the `saveat` argument of - [`diffrax.diffeqsolve`][]. + +class SubSaveAt(eqx.Module): + """Used for finer-grained control over what is saved. A PyTree of these should be + passed to `SaveAt(subs=...)`. + + See [`diffrax.SaveAt`][] for more details on how this is used. (This is a + relatively niche feature and most users will probably not need to use `SubSaveAt`.) """ t0: bool = False t1: bool = False ts: Optional[Union[Sequence[Scalar], Array["times"]]] = None # noqa: F821 steps: bool = False - dense: bool = False - solver_state: bool = False - controller_state: bool = False - made_jump: bool = False + fn: Callable = save_y def __post_init__(self): if self.ts is not None: @@ -28,17 +30,67 @@ def __post_init__(self): ts = None else: ts = jnp.asarray(self.ts) - object.__setattr__(self, "ts", ts) - if ( - not self.t0 - and not self.t1 - and self.ts is None - and not self.steps - and not self.dense - ): + self.ts = ts + if not self.t0 and not self.t1 and self.ts is None and not self.steps: raise ValueError("Empty saveat -- nothing will be saved.") +SubSaveAt.__init__.__doc__ = """**Arguments:** + +- `t0`: If `True`, save the initial input `y0`. +- `t1`: If `True`, save the output at `t1`. +- `ts`: Some array of times at which to save the output. +- `steps`: If `True`, save the output at every step of the numerical solver. +- `fn`: A function `fn(t, y, args)` which specifies what to save into `sol.ys` when + using `t0`, `t1`, `ts` or `steps`. Defaults to `fn(t, y, args) -> y`, so that the + evolving solution is saved. This can be useful to save only statistics of your + solution, so as to reduce memory usage. +""" + + +class SaveAt(eqx.Module): + """Determines what to save as output from the differential equation solve. + + Instances of this class should be passed as the `saveat` argument of + [`diffrax.diffeqsolve`][]. + """ + + subs: PyTree[SubSaveAt] = None + dense: bool = False + solver_state: bool = False + controller_state: bool = False + made_jump: bool = False + + def __init__( + self, + *, + t0: bool = False, + t1: bool = False, + ts: Union[None, Sequence[Scalar], Array["times"]] = None, # noqa: F821 + steps: bool = False, + fn: Callable = save_y, + subs: PyTree[SubSaveAt] = None, + dense: bool = False, + solver_state: bool = False, + controller_state: bool = False, + made_jump: bool = False, + ): + if subs is None: + if t0 or t1 or (ts is not None) or steps: + subs = SubSaveAt(t0=t0, t1=t1, ts=ts, steps=steps, fn=fn) + else: + if t0 or t1 or (ts is not None) or steps: + raise ValueError( + "Cannot pass both `subs` and any of `t0`, `t1`, `ts`, `steps` to " + "`SaveAt`." + ) + self.subs = subs + self.dense = dense + self.solver_state = solver_state + self.controller_state = controller_state + self.made_jump = made_jump + + SaveAt.__init__.__doc__ = """**Main Arguments:** - `t0`: If `True`, save the initial input `y0`. @@ -50,11 +102,70 @@ def __post_init__(self): **Other Arguments:** -It is less likely you will need to use these options. +These arguments are used less frequently. + +- `fn`: A function `fn(t, y, args)` which specifies what to save into `sol.ys` when + using `t0`, `t1`, `ts` or `steps`. Defaults to `fn(t, y, args) -> y`, so that the + evolving solution is saved. For example this can be useful to save only statistics + of your solution, so as to reduce memory usage. + +- `subs`: Some PyTree of [`diffrax.SubSaveAt`][], which allows for finer-grained control + over what is saved. Each `SubSaveAt` specifies a combination of a function `fn` and + some times `t0`, `t1`, `ts`, `steps` at which to evaluate it. `sol.ts` and `sol.ys` + will then by PyTrees of the same structure as `subs`, with each leaf of the PyTree + saving what the corresponding `SubSaveAt` specifies. The arguments + `SaveAt(t0=..., t1=..., ts=..., steps=..., fn=...)` are actually just a convenience + for passing a single `SubSaveAt` as + `SaveAt(subs=SubSaveAt(t0=..., t1=..., ts=..., steps=..., fn=...))`. This + functionality can be useful when you need different functions of the output saved + at different times; see the examples below. - `solver_state`: If `True`, save the internal state of the numerical solver at - `t1`. + `t1`; accessible as `sol.solver_state`. + - `controller_state`: If `True`, save the internal state of the step size - controller at `t1`. -- `made_jump`: If `True`, save the internal state of the jump tracker at `t1`. + controller at `t1`; accessible as `sol.controller_state`. + +- `made_jump`: If `True`, save the internal state of the jump tracker at `t1`; + accessible as `sol.made_jump`. + + +!!! Example + + When solving a large PDE system, it may be the case that saving the full output + `y` at all timesteps is too memory-intensive. Instead, we may prefer to save only + the full final value, and only save statistics of the evolving solution. We can do + this by: + ```python + t0 = 0 + t1 = 100 + ts = jnp.linspace(t0, t1, 1000) + + def statistics(t, y, args): + return jnp.mean(y), jnp.std(y) + + final_subsaveat = diffrax.SubSaveAt(t1=True) + evolving_subsaveat = diffrax.SubSaveAt(ts=ts, fn=statistics) + saveat = diffrax.SaveAt(subs=[final_subsaveat, evolving_subsaveat]) + + sol = diffrax.diffeqsolve(..., t0=t0, t1=t1, saveat=saveat) + (y1, evolving_stats) = sol.ys # PyTree of the save structure as `SaveAt(subs=...)`. + evolving_means, evolving_stds = evolving_stats + ``` + + As another example, it may be the case that you are solving a 2-dimensional + ODE, and want to save each component of its solution at different times. (Perhaps + because you are comparing your model against data, and each dimension has data + observed at different times.) This can be done through: + ```python + y0 = (y0_a, y0_b) + ts_a = ... + ts_b = ... + subsaveat_a = diffrax.SubSaveAt(ts=ts_a, fn=lambda t, y, args: y[0]) + subsaveat_b = diffrax.SubSaveAt(ts=ts_b, fn=lambda t, y, args: y[1]) + saveat = diffrax.SaveAt(subs=[subsaveat_a, subsaveat_b]) + sol = diffrax.diffeqsolve(..., y0=y0, saveat=saveat) + y_a, y_b = sol.ys # PyTree of the same structure as `SaveAt(subs=...)`. + # `sol.ts` will equal `(ts_a, ts_b)`. + ``` """ diff --git a/diffrax/solution.py b/diffrax/solution.py index 12f3805e..d91261a9 100644 --- a/diffrax/solution.py +++ b/diffrax/solution.py @@ -112,7 +112,7 @@ def evaluate( """ if self.interpolation is None: raise ValueError( - "Dense solution has not been saved; pass saveat.dense=True." + "Dense solution has not been saved; pass SaveAt(dense=True)." ) return self.interpolation.evaluate(t0, t1, left) @@ -159,6 +159,6 @@ def derivative(self, t: Scalar, left: bool = True) -> PyTree: """ if self.interpolation is None: raise ValueError( - "Dense solution has not been saved; pass saveat.dense=True." + "Dense solution has not been saved; pass SaveAt(dense=True)." ) return self.interpolation.derivative(t, left) diff --git a/diffrax/solver/base.py b/diffrax/solver/base.py index 9a4c191e..854dc579 100644 --- a/diffrax/solver/base.py +++ b/diffrax/solver/base.py @@ -1,12 +1,12 @@ import abc -from typing import Callable, Optional, Tuple, TypeVar +from typing import Callable, Optional, Tuple, Type, TypeVar import equinox as eqx import jax.lax as lax import jax.numpy as jnp import jax.tree_util as jtu -from ..custom_types import Bool, DenseInfo, PyTree, PyTreeDef, Scalar +from ..custom_types import Bool, DenseInfo, PyTree, Scalar from ..heuristics import is_sde from ..local_interpolation import AbstractLocalInterpolation from ..nonlinear_solver import AbstractNonlinearSolver, NewtonNonlinearSolver @@ -43,7 +43,7 @@ class AbstractSolver(eqx.Module, metaclass=_MetaAbstractSolver): @property @abc.abstractmethod - def term_structure(self) -> PyTreeDef: + def term_structure(self) -> PyTree[Type[AbstractTerm]]: """What PyTree structure `terms` should have when used with this solver.""" # On the type: frequently just Type[AbstractLocalInterpolation] diff --git a/diffrax/solver/euler.py b/diffrax/solver/euler.py index eec2f79a..83b105ed 100644 --- a/diffrax/solver/euler.py +++ b/diffrax/solver/euler.py @@ -1,6 +1,5 @@ from typing import Tuple -import jax.tree_util as jtu from equinox.internal import ω from ..custom_types import Bool, DenseInfo, PyTree, Scalar @@ -22,7 +21,7 @@ class Euler(AbstractItoSolver): When used to solve SDEs, converges to the Itô solution. """ - term_structure = jtu.tree_structure(0) + term_structure = AbstractTerm interpolation_cls = LocalLinearInterpolation def order(self, terms): diff --git a/diffrax/solver/euler_heun.py b/diffrax/solver/euler_heun.py index b8f865ca..1713eeda 100644 --- a/diffrax/solver/euler_heun.py +++ b/diffrax/solver/euler_heun.py @@ -1,12 +1,11 @@ from typing import Tuple -import jax.tree_util as jtu from equinox.internal import ω from ..custom_types import Bool, DenseInfo, PyTree, Scalar from ..local_interpolation import LocalLinearInterpolation from ..solution import RESULTS -from ..term import AbstractTerm +from ..term import AbstractTerm, ODETerm from .base import AbstractStratonovichSolver @@ -20,7 +19,7 @@ class EulerHeun(AbstractStratonovichSolver): Used to solve SDEs, and converges to the Stratonovich solution. """ - term_structure = jtu.tree_structure((0, 0)) + term_structure = (ODETerm, AbstractTerm) interpolation_cls = LocalLinearInterpolation def order(self, terms): @@ -31,7 +30,7 @@ def strong_order(self, terms): def step( self, - terms: Tuple[AbstractTerm, AbstractTerm], + terms: Tuple[ODETerm, AbstractTerm], t0: Scalar, t1: Scalar, y0: PyTree, diff --git a/diffrax/solver/implicit_euler.py b/diffrax/solver/implicit_euler.py index 582b3c53..29c38dbc 100644 --- a/diffrax/solver/implicit_euler.py +++ b/diffrax/solver/implicit_euler.py @@ -1,6 +1,5 @@ from typing import Tuple -import jax.tree_util as jtu from equinox.internal import ω from ..custom_types import Bool, DenseInfo, PyTree, Scalar @@ -26,7 +25,7 @@ class ImplicitEuler(AbstractImplicitSolver): A-B-L stable 1st order SDIRK method. Does not support adaptive step sizing. """ - term_structure = jtu.tree_structure(0) + term_structure = AbstractTerm interpolation_cls = LocalLinearInterpolation def order(self, terms): diff --git a/diffrax/solver/leapfrog_midpoint.py b/diffrax/solver/leapfrog_midpoint.py index b563f601..ad6e99e1 100644 --- a/diffrax/solver/leapfrog_midpoint.py +++ b/diffrax/solver/leapfrog_midpoint.py @@ -1,6 +1,5 @@ from typing import Tuple -import jax.tree_util as jtu from equinox.internal import ω from ..custom_types import Bool, DenseInfo, PyTree, Scalar @@ -41,7 +40,7 @@ class LeapfrogMidpoint(AbstractSolver): ``` """ - term_structure = jtu.tree_structure(0) + term_structure = AbstractTerm interpolation_cls = LocalLinearInterpolation def order(self, terms): diff --git a/diffrax/solver/milstein.py b/diffrax/solver/milstein.py index 9264acdc..1e76323a 100644 --- a/diffrax/solver/milstein.py +++ b/diffrax/solver/milstein.py @@ -8,7 +8,7 @@ from ..custom_types import Bool, DenseInfo, PyTree, Scalar from ..local_interpolation import LocalLinearInterpolation from ..solution import RESULTS -from ..term import AbstractTerm +from ..term import AbstractTerm, ODETerm from .base import AbstractItoSolver, AbstractStratonovichSolver @@ -36,7 +36,7 @@ class StratonovichMilstein(AbstractStratonovichSolver): Note that this commutativity condition is not checked. """ # noqa: E501 - term_structure = jtu.tree_structure((0, 0)) + term_structure = (ODETerm, AbstractTerm) interpolation_cls = LocalLinearInterpolation def order(self, terms): @@ -94,7 +94,7 @@ class ItoMilstein(AbstractItoSolver): Note that this commutativity condition is not checked. """ # noqa: E501 - term_structure = jtu.tree_structure((0, 0)) + term_structure = (ODETerm, AbstractTerm) interpolation_cls = LocalLinearInterpolation def order(self, terms): diff --git a/diffrax/solver/reversible_heun.py b/diffrax/solver/reversible_heun.py index eeb86552..cb337af8 100644 --- a/diffrax/solver/reversible_heun.py +++ b/diffrax/solver/reversible_heun.py @@ -1,7 +1,6 @@ from typing import Tuple import jax.lax as lax -import jax.tree_util as jtu from equinox.internal import ω from ..custom_types import Bool, DenseInfo, PyTree, Scalar @@ -34,7 +33,7 @@ class ReversibleHeun(AbstractAdaptiveSolver, AbstractStratonovichSolver): ``` """ - term_structure = jtu.tree_structure(0) + term_structure = AbstractTerm interpolation_cls = LocalLinearInterpolation # TODO use something better than this? def order(self, terms): diff --git a/diffrax/solver/runge_kutta.py b/diffrax/solver/runge_kutta.py index 9d014542..110c1138 100644 --- a/diffrax/solver/runge_kutta.py +++ b/diffrax/solver/runge_kutta.py @@ -195,7 +195,7 @@ class AbstractRungeKutta(AbstractAdaptiveSolver): scan_stages: bool = False - term_structure = jtu.tree_structure(0) + term_structure = AbstractTerm @property @abc.abstractmethod diff --git a/diffrax/solver/semi_implicit_euler.py b/diffrax/solver/semi_implicit_euler.py index d3a09ae4..798fceea 100644 --- a/diffrax/solver/semi_implicit_euler.py +++ b/diffrax/solver/semi_implicit_euler.py @@ -1,6 +1,5 @@ from typing import Tuple -import jax.tree_util as jtu from equinox.internal import ω from ..custom_types import Bool, DenseInfo, PyTree, Scalar @@ -20,7 +19,7 @@ class SemiImplicitEuler(AbstractSolver): Symplectic method. Does not support adaptive step sizing. """ - term_structure = jtu.tree_structure((0, 0)) + term_structure = (AbstractTerm, AbstractTerm) interpolation_cls = LocalLinearInterpolation def order(self, terms): From a2873e4a2bf20d0a13da15b9809b663363a932d3 Mon Sep 17 00:00:00 2001 From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com> Date: Mon, 20 Feb 2023 10:06:45 -0800 Subject: [PATCH 16/19] Added tests for SubSaveAt, and SaveAt(dense=True) with t0 == t1, and backsolve through SDEs. --- test/test_adjoint.py | 56 ++++++++++++++++++++++++ test/test_integrate.py | 13 ++++-- test/test_saveat_solution.py | 85 ++++++++++++++++++++++++++++++++++++ 3 files changed, 150 insertions(+), 4 deletions(-) diff --git a/test/test_adjoint.py b/test/test_adjoint.py index 3c940c6c..15cfe925 100644 --- a/test/test_adjoint.py +++ b/test/test_adjoint.py @@ -236,3 +236,59 @@ def make_step(model, opt_state, target_steady_state): assert shaped_allclose( model.steady_state, target_steady_state, rtol=1e-2, atol=1e-2 ) + + +def test_backprop_ts(getkey): + mlp = eqx.nn.MLP(1, 1, 8, 2, key=jrandom.PRNGKey(0)) + + @eqx.filter_jit + @eqx.filter_value_and_grad + def run(model): + sol = diffrax.diffeqsolve( + diffrax.ODETerm(lambda t, y, args: model(y)), + diffrax.Euler(), + 0, + 1, + 0.1, + jnp.array([1.0]), + saveat=diffrax.SaveAt(ts=jnp.linspace(0, 1, 5)), + ) + return jnp.sum(sol.ys) + + run(mlp) + + +def test_sde_against(getkey): + def f(t, y, args): + k0, _ = args + return -k0 * y + + def g(t, y, args): + _, k1 = args + return k1 * y + + t0 = 0 + t1 = 1 + dt0 = 0.001 + tol = 1e-5 + shape = (2,) + bm = diffrax.VirtualBrownianTree(t0, t1, tol, shape, key=getkey()) + drift = diffrax.ODETerm(f) + diffusion = diffrax.WeaklyDiagonalControlTerm(g, bm) + terms = diffrax.MultiTerm(drift, diffusion) + solver = diffrax.Heun() + + @eqx.filter_jit + @jax.grad + def run(y0__args, adjoint): + y0, args = y0__args + sol = diffrax.diffeqsolve(terms, solver, t0, t1, dt0, y0, args, adjoint=adjoint) + return jnp.sum(sol.ys) + + y0 = jnp.array([1.0, 2.0]) + args = (0.5, 0.1) + grads1 = run((y0, args), diffrax.DirectAdjoint()) + grads2 = run((y0, args), diffrax.BacksolveAdjoint()) + grads3 = run((y0, args), diffrax.RecursiveCheckpointAdjoint()) + assert shaped_allclose(grads1, grads2, rtol=1e-3, atol=1e-3) + assert shaped_allclose(grads1, grads3, rtol=1e-3, atol=1e-3) diff --git a/test/test_integrate.py b/test/test_integrate.py index c8196613..30d7e74b 100644 --- a/test/test_integrate.py +++ b/test/test_integrate.py @@ -205,7 +205,7 @@ def diffusion(t, y, args): bm = diffrax.VirtualBrownianTree( t0=t0, t1=t1, shape=(noise_dim,), tol=2**-15, key=bmkey ) - if solver_ctr.term_structure == jtu.tree_structure(0): + if solver_ctr.term_structure == diffrax.AbstractTerm: terms = diffrax.MultiTerm( diffrax.ODETerm(drift), diffrax.ControlTerm(diffusion, bm) ) @@ -292,8 +292,8 @@ def f(t, y, args): t0 = -4 t1 = -0.3 negdt0 = None if dt0 is None else -dt0 - if saveat.ts is not None: - saveat = diffrax.SaveAt(ts=[-ti for ti in saveat.ts]) + if saveat.subs is not None and saveat.subs.ts is not None: + saveat = diffrax.SaveAt(ts=[-ti for ti in saveat.subs.ts]) sol2 = diffrax.diffeqsolve( diffrax.ODETerm(f), solver_ctr(), @@ -307,7 +307,12 @@ def f(t, y, args): assert shaped_allclose(sol2.t0, -4) assert shaped_allclose(sol2.t1, -0.3) - if saveat.t0 or saveat.t1 or saveat.ts is not None or saveat.steps: + if saveat.subs is not None and ( + saveat.subs.t0 + or saveat.subs.t1 + or saveat.subs.ts is not None + or saveat.subs.steps + ): assert shaped_allclose(sol1.ts, -sol2.ts, equal_nan=True) assert shaped_allclose(sol1.ys, sol2.ys, equal_nan=True) if saveat.dense: diff --git a/test/test_saveat_solution.py b/test/test_saveat_solution.py index e86299c7..f6986356 100644 --- a/test/test_saveat_solution.py +++ b/test/test_saveat_solution.py @@ -1,6 +1,9 @@ +import contextlib import math import diffrax +import equinox as eqx +import jax import jax.numpy as jnp import pytest @@ -162,3 +165,85 @@ def test_trivial_dense(): stepsize_controller=stepsize_controller, ) assert shaped_allclose(sol.evaluate(2.0), y0) + + +@pytest.mark.parametrize( + "adjoint", + [ + diffrax.RecursiveCheckpointAdjoint(), + diffrax.DirectAdjoint(), + diffrax.ImplicitAdjoint(), + diffrax.BacksolveAdjoint(), + ], +) +@pytest.mark.parametrize("multi_subs", [True, False]) +@pytest.mark.parametrize("with_fn", [True, False]) +def test_subsaveat(adjoint, multi_subs, with_fn, getkey): + if with_fn: + mlp = eqx.nn.MLP(3, 1, 32, 2, key=getkey()) + apply = lambda _, x, __: mlp(x) + subsaveat_kwargs = dict(fn=apply) + else: + mlp = lambda x: x + subsaveat_kwargs = dict() + get2 = diffrax.SubSaveAt(t0=True, ts=jnp.linspace(0.5, 1.5, 3), **subsaveat_kwargs) + if multi_subs: + get0 = diffrax.SubSaveAt(steps=True, fn=lambda _, y, __: y[0]) + get1 = diffrax.SubSaveAt( + ts=jnp.linspace(0, 1, 5), t1=True, fn=lambda _, y, __: y[1] + ) + subs = (get0, get1, get2) + else: + subs = get2 + + context = contextlib.nullcontext() + if isinstance(adjoint, diffrax.ImplicitAdjoint): + context = pytest.raises(ValueError) + elif isinstance(adjoint, diffrax.BacksolveAdjoint): + if with_fn or multi_subs: + context = pytest.raises(NotImplementedError) + + term = diffrax.ODETerm(lambda t, y, args: -0.5 * y) + y0 = jnp.array([2.1, 1.1, 0.1]) + saveat = diffrax.SaveAt(subs=subs) + stepsize_controller = diffrax.PIDController(rtol=1e-8, atol=1e-8) + + with context: + sol = diffrax.diffeqsolve( + term, + t0=0, + t1=2, + y0=y0, + dt0=None, + solver=diffrax.Dopri5(), + saveat=saveat, + stepsize_controller=stepsize_controller, + adjoint=adjoint, + ) + steps = sol.stats["num_accepted_steps"] + + sol2 = diffrax.diffeqsolve( + term, + t0=0, + t1=2, + y0=y0, + dt0=None, + solver=diffrax.Dopri5(), + saveat=diffrax.SaveAt(dense=True), + stepsize_controller=stepsize_controller, + ) + + if multi_subs: + ts0, ts1, ts2 = sol.ts + ys0, ys1, ys2 = sol.ys + assert ts0.shape == (4096,) + assert shaped_allclose(ts1, jnp.array([0, 0.25, 0.5, 0.75, 1, 2])) + assert shaped_allclose( + ys0[:steps], jax.vmap(sol2.evaluate)(ts0[:steps])[:, 0] + ) + assert shaped_allclose(ys1, jax.vmap(sol2.evaluate)(ts1)[:, 1]) + else: + ts2 = sol.ts + ys2 = sol.ys + assert shaped_allclose(ts2, jnp.array([0, 0.5, 1.0, 1.5])) + assert shaped_allclose(ys2, jax.vmap(mlp)(jax.vmap(sol2.evaluate)(ts2))) From 9ab1bee0bc28bb5b88bda699d8db71327a973ec3 Mon Sep 17 00:00:00 2001 From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com> Date: Mon, 20 Feb 2023 10:07:04 -0800 Subject: [PATCH 17/19] Updated documentation and examples --- docs/api/saveat.md | 13 +- docs/{ => other_examples}/basic-examples.md | 0 examples/coupled_odes.ipynb | 115 ++++++ examples/neural_sde.ipynb | 202 ++++------ examples/nonlinear_heat_pde.ipynb | 417 ++++++++++++++++++++ mkdocs.yml | 21 +- 6 files changed, 619 insertions(+), 149 deletions(-) rename docs/{ => other_examples}/basic-examples.md (100%) create mode 100644 examples/coupled_odes.ipynb create mode 100644 examples/nonlinear_heat_pde.ipynb diff --git a/docs/api/saveat.md b/docs/api/saveat.md index b432ea06..665560d3 100644 --- a/docs/api/saveat.md +++ b/docs/api/saveat.md @@ -4,11 +4,8 @@ selection: members: - __init__ - - t0 - - t1 - - ts - - steps - - dense - - solver_state - - controller_state - - made_jump + +::: diffrax.SubSaveAt + selection: + members: + - __init__ diff --git a/docs/basic-examples.md b/docs/other_examples/basic-examples.md similarity index 100% rename from docs/basic-examples.md rename to docs/other_examples/basic-examples.md diff --git a/examples/coupled_odes.ipynb b/examples/coupled_odes.ipynb new file mode 100644 index 00000000..a1a6b71e --- /dev/null +++ b/examples/coupled_odes.ipynb @@ -0,0 +1,115 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "1fe536ed", + "metadata": {}, + "source": [ + "# Coupled ODEs" + ] + }, + { + "cell_type": "markdown", + "id": "598ab169-05d8-4733-a6cc-9fa91aa92198", + "metadata": {}, + "source": [ + "This example demonstrates basic functionality for solving a system of coupled ODEs; in this the [Lotka–Volterra](https://en.wikipedia.org/wiki/Lotka%E2%80%93Volterra_equations) equations.\n", + "\n", + "This example is available as a Jupyter notebook [here](https://github.com/patrick-kidger/diffrax/blob/main/examples/coupled_odes.ipynb)." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "6d6bdf63", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "import jax.numpy as jnp\n", + "import matplotlib.pyplot as plt\n", + "from diffrax import diffeqsolve, ODETerm, SaveAt, Tsit5\n", + "\n", + "\n", + "def vector_field(t, y, args):\n", + " prey, predator = y\n", + " α, β, γ, δ = args\n", + " d_prey = α * prey - β * prey * predator\n", + " d_predator = -γ * predator + δ * prey * predator\n", + " d_y = d_prey, d_predator\n", + " return d_y\n", + "\n", + "\n", + "term = ODETerm(vector_field)\n", + "solver = Tsit5()\n", + "t0 = 0\n", + "t1 = 140\n", + "dt0 = 0.1\n", + "y0 = (10.0, 10.0)\n", + "args = (0.1, 0.02, 0.4, 0.02)\n", + "saveat = SaveAt(ts=jnp.linspace(t0, t1, 1000))\n", + "sol = diffeqsolve(term, solver, t0, t1, dt0, y0, args=args, saveat=saveat)" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "9654fd84-19b9-4a0b-bff6-d20f36c4f333", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAh8AAAGfCAYAAAD/BbCUAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy88F64QAAAACXBIWXMAAA9hAAAPYQGoP6dpAACXdElEQVR4nO2dd3yb1b3/39ree8d29l5kQAh7BEJombmlpbSMS8ulDTO0pem9pfsXoLctHZQuCu0tFEpbKJtCgEAgCdk7zk6ceMV2vK39/P44eiQ5cRLLlvQMnffr5ZccSZG+0vFzzud8z3dYFEVRkEgkEolEIkkSVq0NkEgkEolEklpI8SGRSCQSiSSpSPEhkUgkEokkqUjxIZFIJBKJJKlI8SGRSCQSiSSpSPEhkUgkEokkqUjxIZFIJBKJJKlI8SGRSCQSiSSpSPEhkUgkEokkqUjxIZFIJBKJJKnYh/KfH374YZYsWcK9997LY489BoDb7eaBBx7gueeew+PxMH/+fH79619TWlo6oNcMBoPU1dWRnZ2NxWIZinkSiUQikUiShKIodHZ2UlFRgdV6at/GoMXHmjVr+O1vf8u0adP63H///ffz2muv8cILL5Cbm8tdd93F9ddfz0cffTSg162rq6OqqmqwZkkkEolEItGQ2tpaKisrT/mcQYmPrq4ubrrpJn7/+9/zwx/+MHx/e3s7Tz75JM8++yyXXHIJAE899RQTJ05k1apVnH322ad97ezs7LDxOTk5gzFPIpFIJBJJkuno6KCqqiq8jp+KQYmPRYsW8alPfYp58+b1ER/r1q3D5/Mxb9688H0TJkygurqalStX9is+PB4PHo8n/O/Ozk4AcnJypPiQSCQSicRgDCRkImbx8dxzz7F+/XrWrFlzwmMNDQ04nU7y8vL63F9aWkpDQ0O/r7d06VK+973vxWqGRCKRSCQSgxJTtkttbS333nsvzzzzDGlpaXExYMmSJbS3t4d/amtr4/K6EolEIpFI9ElM4mPdunU0NTUxc+ZM7HY7drud5cuX84tf/AK73U5paSler5e2trY+/6+xsZGysrJ+X9PlcoWPWORRi0QikUgk5iemY5dLL72ULVu29LnvtttuY8KECTz44INUVVXhcDhYtmwZCxcuBKCmpoZDhw4xd+7c+FktkUgkEonEsMQkPrKzs5kyZUqf+zIzMyksLAzff/vtt7N48WIKCgrIycnh7rvvZu7cuQPKdJFIJBKJRGJ+hlRkrD9+9rOfYbVaWbhwYZ8iYxKJRCKRSCQAFkVRFK2NiKajo4Pc3Fza29tl/IdEIpFIJAYhlvVb9naRSCQSiUSSVKT4kEgkEolEklSk+JBIJBKJRJJUpPiQSCQSiUSSVKT4kEgkEolEklTinmorSQ1W7m3hoz3NXDCumLNGFmhtjiRGFEXhtS317Gro5OozKhhTcvoulBJ94fUHeWFdLa1dXj57VhUl2fFpeSFJHu09Pp5bcwib1cLn51ST4UydJVmm2kpi5uVNddzz1w0AWCzwxE0zuWJKucZWSWLhf9+q4Vfv7QEg3WHjn189h4nl8nozCoqicMf/rePt7Y0AVOSm8eo951OQ6dTYMslAcfsCXP2rFexq7AJg9vB8/nrH2Thsxj2QkKm2koRxrNvLd/61FRATnqLAQ//ahtsX0NgyyUDZeqSdX78vhEd5bhq9vgDf+dc2dLYPkZyClzfV8fb2Rhw2C4WZTura3fz07RqtzZLEwM+X7WZXYxc5aXYynDbWHjzGc2tSp7GqFB+SmPjDin0c6/ExoSybZQ9cxLC8dJo6Pby1rUFr0yQD5LF3dhNU4FPTynnxq+fitFn55EArOxs6tTZNMgCCQYWf/HsXAPdcMpZffn4GAH9fd5huj19L0yQD5Fi3lydX7Afgfz8znW/MHw/Anz8+oKFVyUWKD8mA8QWCPL/mMAD3XDqWdKeN62cOA+CVTfVamiYZIHVtvby7U7jqF182jrLcNC4aXwzAa5vlGBqBFXuaOdTaQ3aandvPH8ncUYVUF2Tg9gV5d2eT1uZJBsDf1x3G6w8yuSKHyyaVct3MShw2C7ubutjdmBqbACk+JAPm3Z1NNHd5KM52cdmkUgA+Pa0CgOW7muhw+7Q0TzIA/r7uMEEFzh5VwOjiLEB4QABe2yLFhxF4fq1wzS+cWUmG047FYgmP4etyDA2BOoZfOHs4FouF3HQH548NbQJSZAyl+JAMGPVo5erpFeGgqPFl2QwvzMAXUFh38JiW5kkGgDqG18+oDN936cRSbFYL+5u7qWvr1co0yQBw+wK8H/JuXDdjWPj++ZPLAPhoTzPBoIzd0TN7mrrY09SFwxYRjQDzJ4sN3Ud7mrUyLalI8SEZEP5AxKWrej1UZg8XqbZrD7Qm3S7JwDnS1su2ug6sFrh0Ykn4/iyXnUmhTJe1UkDqmpV7W+j2BijLSWPqsNzw/ZMrckh32Ohw+9lztEtDCyWnQ81Qmju6iJw0R/j+M0eIeXTT4XY8fvMH8EvxIRkQG2rbaOvxkZfhYPbw/D6PnTlC/HvtAblw6RlVPM4ank9hlqvPY7NCY7pOCkhdsywUrzNvUglWqyV8v8Nm5YyqPADpgdQ5aszV8Zu4kUWZFGQ68fqDbKvr0MK0pCLFh2RArNzbAsB5Y4qwH5eHPjskPjYdbsMXCCbdNsnAWBUawwtCZ8vRqGMoPR/6ZuUpxlAVkHIToF96vH421rYBcMHYoj6PWSwWZlarmwDzj6EUH5IBsWqfmPTmjCo84bHRxVlkuey4fUH2N3cn2zTJAFAUJTyGZ48+cQynV+YBsKuxUwpIndLU6Wbv0W4sFvqtKqx6PrbVtSfZMslAWX+wDV9AoSI3jeqCjBMen1GdB6TGGErxITktHn8g7MqdO+rESc9isTC2VGRO1MhaEbpkT1MXLd1e0hxWplXmnvB4ZX46mU4bvoDCASkgdcnqfeJIbGJZDnkZJ1YyHV8mSuTvO9otBaROCW8ARhVisVhOeHx8qRjDmkbzx+1I8SE5LZtq2/H4gxRlOcPpmcczITTx7UqRHHWjoU56s4bn47LbTnjcYrEwrkyd+OQY6pHohas/huUJAekNBDnYIgWkHll5mjFUBeTepi78JheQUnxITsvqqCOX/tQ6wLiQYpdVMvXJqtCu+eyR/U96ENl17ZJjqEsi4qP/Ro5Wq4Wx6s65wfw7Z6PR6w2wKRTvMbefo08QAjIjJCAPtPQk0brkI8WH5LSsOySOXM4acfLutWF3oVy4dMn60BjOPsUYjiuVng+90t7jY+9R4c04cyDXoRxD3bG1rh1/UKE0x0Vlfnq/z4kWkGb3IkvxITkliqKw5bAIfpoeCmjrD/WCqT3WkxI56kaiqdNNfbsbiwWm9hPvoaLG7expkrtmvbHliLgGqwsyyD9F51p1DPfKMdQdqtdjWmXeST3IAONKxBjuNnnchxQfklNypK2Xlm4vdqslHNfRH0VZTjKcNhQFDh+TVTL1hCoex4Sykk7GiMJMAGqP9coqmTpj0+E24NTiEWB4aAwPtsqYD72xOXQdTht26jEcUZQaYyjFh+SUqAvX+LJs0hwnBiqqWCyWcOrYIZOfVRoNddI73cJVnpuG3WrB6w/S0OFOhmmSARL2Pp5WfIhr8GBzD4oiBaSeUL1X007hQQZSZh6V4kNySjapaj1UB+JUqDvnAzLSXldsDu2ap59mDO02K1WhiU+Oob5Qx/B016G6cHV6/BzrkY0e9UJ7ry9cA+l0no+wgGyV4kOSwmw50gbQb22I4wlfNCZX7EZCUZTwjut0ng9InV2XkTja6aEuFLMz5TQLV5rDRllOGoBMt9URqueqqiD9lDE7AMMLxCbuaKeHHq8/4bZphRQfkpMSDCqRc8oBiQ9x0RwyuWI3EnXtbpq7RMyO2jzuVKTKrstIqBuA0aeJ2VGpDo2hvA71w6YBeq4AcjMc5KaLhnNmHkMpPiQn5WBrD51uPy67NZyGeSoing+549ILm0MR9uNKTx2zoyI9H/pjU+3ANwAAwwukB1JvDDRmRyUVvMhSfEhOinrOPKkiB4ft9H8qw/JE7npdm1sGu+mEzUfUNOmBTXqV+WLSO9wmM5b0QjhQ8TRHLirqGB6RWWe6IXz0OSxvQM+vSoExlOJDclK214u2zpMrTu+uByjLFWfNvb4AbTLYTRfsCI3hpIqBLVwVeWIM66X40A3qGE4eoPhQx7CuXY6hHmjv9XEkdD1NGuBcWh6aS+tNPIZSfEhOilqtdELZwC6YNIeNoiwRTCUnPn2gjuHEU9RoiaYi5L062uXB6zd3bwkj0N7jo75dpD2Pj3EM1f8n0Rb1GhyWlx6O5Tgd5VFeZLMixYfkpOysDy1c5QOb9ADKc81/0RiF6IVr3AAXrsJMJ067FUWBRlnrQ3N2Ngivx7C8dHLSBrZwVYQXrl55/KkD1DEcqHgEGJYC3ispPiT90tbjDReaGkiwqUrYbW/ii8YoDGbhslgsYZdvnTx60Ry1R8upqgsfjzp+Pd4AHb3mTdU0CjsbBjOGIe+ViTdxUnxI+kW9YCrz08ke4MIFkYvmiFy4NEcdw1g8VwAVqvdKCkjN2RHyPsaya05z2CjIlMefemFnfeyej/LQJq6p040vYM7jz5jExxNPPMG0adPIyckhJyeHuXPn8sYbb4Qfv+iii7BYLH1+7rzzzrgbLUk8NYNQ6xAdsGhexW4UVPERy6QHkYlPHp1pT03IezVhADVaogkHncpNgKYEgwq7Qg3iJsYwhkWZLhw2C0ETH3/GJD4qKyt5+OGHWbduHWvXruWSSy7hmmuuYdu2beHnfPnLX6a+vj788+ijj8bdaEniUV32Aw02VQm7C+WOS3MGO4YVcgx1QTCoxBwwrBKOvZJBp5pypK2XLo8fp83KyFDDuIFgtVqi5lJzjuHpy+VFcdVVV/X5949+9COeeOIJVq1axeTJkwHIyMigrKwsfhZKNGGwu+bSHNVd6Im7TZKBEwwq7Bqk96o0xwVAU4ccQy050tZLtzeA02YNdzodKOoYHjXprtkoqPPo6JKsAdVKiqYsJ41DrT3S83E8gUCA5557ju7ububOnRu+/5lnnqGoqIgpU6awZMkSenrMW6HNrETvuGJduIqzQ5OeFB+acvhYZOGKZccFUJwtBaQeUOt7jBnEwlUix1AXqPEesXquwPxzaUyeD4AtW7Ywd+5c3G43WVlZvPjii0yaNAmAz3/+8wwfPpyKigo2b97Mgw8+SE1NDf/85z9P+noejwePJ/LldnR0DOJjSOLJ4WO99Axy4SoJXTA93gBdHv+AelFI4o965DKmJAt7jAuX2Sc9ozDYDQDIMdQLg/UgQ2QMzSogY14Zxo8fz8aNG2lvb+fvf/87t9xyC8uXL2fSpEnccccd4edNnTqV8vJyLr30Uvbu3cvo0aP7fb2lS5fyve99b/CfQBJ3dgxh4cp02cl02uj2BmjqcJNVnJUIEyWnIZzeF2OmC0QE5NFOD4qiYLFY4mqbZGDEYwzNunAZhZ2DDBgG8wvImI9dnE4nY8aMYdasWSxdupTp06fz85//vN/nzpkzB4A9e/ac9PWWLFlCe3t7+Ke2tjZWkyRxpmYIkx6Y/6IxAvHYNXsDQVknQkMixankwmVE3L4A+5tFk83BHLuUmHwMh+wTDwaDfY5Notm4cSMA5eXlJ/3/LpcLl8s1VDMkcWQoCxeI8+YDLT1y16UhQ1m40hw2ctLsdLj9NHW6yc0YeJ0XSXzw+CML12CuQzXmo7nLQzCoYLVK71Wy2dPURVCBvAxHWAzGgjx2iWLJkiUsWLCA6upqOjs7efbZZ3n//fd566232Lt3L88++yxXXnklhYWFbN68mfvvv58LLriAadOmJcp+SQLY0yTy0seWDNLzkWNuxa53fIFguBX32JLBHXuV5KTR4e7iaKeHsTFUuJXEh/3N3QQVyE6zh3fAsVCY5cRiAX9Q4ViPl8IsucFLNnuPqvNo1qCOLs3uvYpJfDQ1NXHzzTdTX19Pbm4u06ZN46233uKyyy6jtraWd955h8cee4zu7m6qqqpYuHAh//M//5Mo2yUJwB8IhndcYwa5cBVnmVux652DLd34gwqZTlu41HasFGe52NPUJcdQI/Y2Ra7BwSxcDpuVggwnLd1emjo9UnxowN7QJm6w86jqvWrp9uAPBGOOv9M7MYmPJ5988qSPVVVVsXz58iEbJNGWw8d68QaCuOxWhoUaVMVKiVonotOc+el6Z09o4Ro9yIULImNo1l2X3lG9j6OHELBdnO2ipdvL0U4PE09+8i1JEHuODm0MCzKdWC0QVKC120tJzuA2EnrFXFJKMmTUSW9Ucdagz4lVxS4XLm3YO8RJD6KzJaSA1AJ14RrsrhnMHzOgd8ICcpBjaLNaKDKxF1mKD0kf4jnpSfGhDXuG6O4FuXBpTdhlP0TPB8jrUAv8gSAHmkXclRzD/pHiQ9KHeEx6Zk8R0zvx8XxI75VWBIMK+5qHtmuG6Cqn0nuVbGpDx9dpjsEfX4O5PZBSfEj6ED6nLImtsmk06gXT0u01bTtovaIoSlSg29DHUHo+ks+Rtl7cviBOm5Wq/MEvXGbeNesd9RocVTT442sw9xhK8SEJoyhKXFz2+RlO7KELrrnLfBeNnqlvd9PtDWC3WhheOHjxET52MWlTKz2jbgBGFmUOKcNBCkjtiMfxNUjxIUkRjnZ56HT7sVpgxBAWLqvVQn6mE4CWLm+8zJMMAPXIZXhhRszNyKJRUzM73H7pvUoye5uG7n0EUesDoEVuAJJOPLKVIFK2oNmE86gUH5Iw6gVTVZBBmsM2pNcqDImP1m7zXTR6Jl6TXl66A9VbfEyOYVLZE4e4K4DCTLFwyWsw+eyNk+cjvInrNp+AlOJDEiYewaYqBVJ8aEI8js0g5L3KUCc+OYbJJBwwPMQxVK/Btl4fgaAyZLskAyNex9dgbgEpxYckzN6jQ6tsGk1Bply4tCBeOy6QAlIr4uW9yg/15FEUONYjxzBZHO2MOr4uyhjSa5n5GpTiQxImXpMeRB+7mM9dqGfC1U3j6L2SAjJ5tHR5ONbjw2IZ+hjabVbyQgLEjIuXXlGDTasLMnDZh3h8HYrbOdbjI2gy75UUH5IwQ63IF02Bid2FeqW9xxfOLorHGKoTX6sMWEwaqvdxWF466c6hLVwQJSBNGLCoV4ba0yUa9egzEFRo7/UN+fX0hBQfEgC6PH4aQmmVcYn5yJKTXrJRd1zluWlkuWJq29QvZnb56pV4eh9BBn5rQTzH0Gm3kp0mrmWzeSCl+JAAEbVelOUiN+SqHQpy0ks+e+O8cKneK7NNenomnjE7EC0gpfcqWajeq3h4H8G8c6kUHxIgOktiaLUFVOSuOfnEq7CRilknPT0TrywJFSkgk0/8x9Cc16EUHxIg/juuQhmsmHT2xjFmB6JrDMgxTBbxPnYpyBReTFmrJTl0un3h4+v4jaEUHxITE+9JT1242nt9skJmkgj35SmOj/dKej6SS683wJG2XkB6PozKvtCRS3G2i9z0oR9fg3mPzqT4kADxd9nnZzixqBUyZY2BhOP2BahtDbXwjrO7V+6ak4PqfSzIdIa/+6EiBWRyiVd12mjMKiCl+JDgCwQ51BLfhctmtZCXLmsMJIsDLd0EFchJs4f7QQwVdeE61uM1XY0BPbI3zp4rMK/LXq/Eoyv48ZhVQErxIeFgSzf+oEKm00ZZTlrcXjc88cl024QTXaPFYhl8C+9o1KOzoCJKdEsSSzzrQ6jIQnHJJZ4tKlTMKiCl+JAkZOGCSF8COfElnkS4ex02KzmhGgNmO2/WI5GYnfiNYbhCZrcXRZHeq0QTOb7OjttrmrVmkhQfkkhPlzhOemBexa5H4l1bQKUwdIRjtolPj8SzwrCKeg36gwodvf64va7kRLz+IAdDx9fy2OX0SPEhScikBxHFLgNOE08iPB8gBWSy8AeCHGgOxV3FcQxddlu42q0Z27LriUOt3QSCClkue1yPr9US660m815J8SGJe5qtihpw2tYj4wUSSSCosC/O2Uoq4TGUMR8JpfZYL95AkDSHlWF56XF9bbW5nBzDxBKZRzPjenytbgC8gSC9vkDcXldrpPhIcRRFiSowFj9XIUQUe5v0fCSUurZePP4gTpuVyvx4L1zSe5UM1IVrVFEWVmv8Fi6IEh9yDBNKojZxGU4bDpv4mzhmoo2cFB8pTn27mx5vALvVwvDC+IoPddIz0wWjR9RJb2RRJnZbfC/p/NAYtssxTCjxrjAcTWQTIMcwkSQq7spisUQ2ASY6/pTiI8VRF67hhRk44rxw5UnPR1KIdy+JaCICUo5hIknkGKqVNuUmILEkcgzDmwATHZ1J8ZHi7E1Aep9KvjxrTgqJKE6lEjl2kWOYSBLlsgd5/JkMgkEloXNpXrr5jj+l+EhxkrJrNpGrUI8kKlsJIguXPHZJHH3jrhK4CZBjmDAaOqKPrzPi/vp5JhxDKT5SnMSKD7Fwdbj9+GVzuYSgKErc+/JEI49dEs/RTg+dbj9WC4woiv/ClSuDhhOOOo+OKMqM+/E1mDNoWIqPFCccJJUQV2Gkq2OHWxY4SgSt3V7aenxYLCJTIt7INM3Eoy5c1QUZuOy2uL++GeMF9EZ0mm0iMGPQsBQfKUx7j4/mLlF4KBEue7vNSnaowJHcdSUGddIblpdOujP+C1d00LCZChzpiUQeuYD0XiWDxI+h+WKvpPhIYfYc7QSgPDctXAUx3uRlms9dqCcSeeQCkV2zL6DQ7TVPgSM9kciYHYhauLrNs3DpjUQeX4M8dpGYjL1NiTtyUTGju1BPJHoM0x02nHYxTZhp4tMTiTz6hKigYXnskjASmekC5swclOIjhUn0rhlkjYFEk+gxtFgsskx+gkn4rjk0fl0eP16/DPyON209XppDjRcTJT5yUz3V9oknnmDatGnk5OSQk5PD3LlzeeONN8KPu91uFi1aRGFhIVlZWSxcuJDGxsa4Gy2JD4l294KsMZBo9iZ44QLpvUoknW4fDR1uIHELV066A7XViPR+xB/V61GRm0Zmgo6v8zPNV2k4JvFRWVnJww8/zLp161i7di2XXHIJ11xzDdu2bQPg/vvv55VXXuGFF15g+fLl1NXVcf311yfEcMnQSWRxKhVZYyBx9Hj9HGnrBRJ7dCYDFhPHvtCRS3G2K+wljDc2qyX82nITEH+Suonr9Zkm8DsmmXbVVVf1+fePfvQjnnjiCVatWkVlZSVPPvkkzz77LJdccgkATz31FBMnTmTVqlWcffbZ8bNaMmTcvgC1raEW3ok8dpE1BhKGunAVZDrDnS8TgRmD3fRC+MglgeIRxNFLW49PHn8mgETH7EDk+DoQVOhw+xMmVJPJoGM+AoEAzz33HN3d3cydO5d169bh8/mYN29e+DkTJkygurqalStXnvR1PB4PHR0dfX4kiWd/czdBBXLS7BRnuRL2PmYMlNILyVq45LFL4khG3BXIPkuJJNExOwBpDhvpDpFKb5ajl5jFx5YtW8jKysLlcnHnnXfy4osvMmnSJBoaGnA6neTl5fV5fmlpKQ0NDSd9vaVLl5Kbmxv+qaqqivlDSGInfORSkoXFEt8W3tHImI/EERnDxB2bAeTK7sQJY2+Ci1OpmLE8t15IZF+eaMx2/Bmz+Bg/fjwbN25k9erVfOUrX+GWW25h+/btgzZgyZIltLe3h39qa2sH/VqSgZOsXXN44ZI1BuJOsiY9KSATR8TzkZ3Q94nEDMgxjCduX4DaY4k/voYo75VJvMgxh+Y6nU7GjBkDwKxZs1izZg0///nP+exnP4vX66Wtra2P96OxsZGysrKTvp7L5cLlSpzbX9I/yXAVgqwxkEiSN4by6CwReP1BDrYka+GS3qtEsL+5G0URMRlFWYmLu4Lo4H1zCMgh1/kIBoN4PB5mzZqFw+Fg2bJl4cdqamo4dOgQc+fOHerbSOJMshauvHRzuQr1gj8Q5ECLCHRL9BiascaAHjjY0k0gqJDlslOak9gNmNqS3SwLl16InkcTeXwN5usSHpPnY8mSJSxYsIDq6mo6Ozt59tlnef/993nrrbfIzc3l9ttvZ/HixRQUFJCTk8Pdd9/N3LlzZaaLzggEFfY1J2fhUj0fPd4AHn8gIY2zUpGDrT34AgrpDhsVuekJfa9wYzK5a44r0SmaiV648jNlzEciSNbxNaT4sUtTUxM333wz9fX15ObmMm3aNN566y0uu+wyAH72s59htVpZuHAhHo+H+fPn8+tf/zohhksGz+FjPXj9QZx2K5X58W/hHU12mh2rBYKKWLxKcqT4iAeRhSsTqzXRC5f0fCSCZC5cudIDmRCSla0EmK7ScEzi48knnzzl42lpaTz++OM8/vjjQzJKkljUSW9UUSa2BC9c1lCBo2OhGgMlOWkJfb9UIak7rvRIS/ZgUEm42EkVkrlwyXTpxJCMCsMqZgv8lr1dUpBkVOSLJl8WGos7yZz01IyloAKdbn/C3y9V2JOkNFuQ4iMRJPP4GsyX8i7FRwqSzF0zRAKlZMZL/NiT4C6a0bjsNjKc4rhMCsj4EAwq4TotSXHZm6xGhB6obRXH1y67lWF5iY27Aun5kJiAZE56EAmUkgGL8UFRlKR6PkCmTMebI229uH1BnDYr1QWJjbuCiPjw+IO4fYGEv18qED6+Ls5KylGk2TZxUnykGIqiJC3NVkWm28aXhg433d4ANquF4YWJd9mDDFiMN6rnakRRBnZb4qfhLJc9HN9llsVLa5IZswPmq7cjxUeKcbTLQ4fbj9UCI4uStHCZ7KLRGlU8Di/MwGlPziVstl2X1iTbc2WxWOQmIM4k+/harbejBn4bHSk+Ugz1gqkqyCDNkZy010iBI7lwxYNkT3oge4PEGy3GMFeOYVxJtgdZ9T4qCnS4jT+GUnykGHs1XLjaZV+JuJDsSQ8iuy65cMWHZGecgfnqRGiJFnFXTruVzFDgtxnGUIqPFEOLhUvumuOLpmMoBeSQURQl6fECEBX4LcdwyDR1euj0iOPrEUWJDxhWMVOVUyk+UoxwimZSd81SfMSTZGcrQVShMTmGQ6a120tbjw+LJTmp0irS8xE/9objrjKT2jIiMpcaX0BK8ZFiaLNrlmma8aKtx0tzl5h4krpwyaDhuKFeg5X56UmLuwIZ+B1PkllnJxozBX5L8ZFCdLh9NHZ4AG12zWZQ61qjLlwVuWlkumLqjjAkcmVX1LgRPnJJ9sIl43bihhabODDXEbYUHymE6iosyXaRk+ZI2vuqF0y3N4DXH0za+5oRLQIVQXo+4onWC5eM+Rg6Wo2hmQK/pfhIIbS6YLLTHKgdw83gLtQSzRcuE0x6WqP1GJph4dIazcfQBAJSio8UQosIewBbqLMtyF3XUNEi2BSi+kr0+lAU4xc40pK9TVrFC5hn16wlHW4fTZ3i+DoZTQGjyTfRJkCKjxRib5PowJjsSQ9kpH280CpeQBWPgaBCl0d2th0s3R4/de1uQINdc7p5ghW1RPV6lOa4yE7i8TVExe2YYAyl+EghtNo1A+TKXdeQcfsCHD7WCyQ/5iPNYSPNIaYLOYaDR70Gi7KcYU9EspCdbeODVkcuEMlYMsMYSvGRInj8AQ62CM+HFheN7CsxdPYe7UJRxCJSmJnchQtktkQ82KPRkQtExq/HG8Djl51tB4sWVaJVzFRvR4qPFOFAcw9BBbJddkqyXUl/fzPlp2tFdD8QiyXxLbyPx0zBblqh5a45O80uA7/jgJZjKCucSgxHdIqmJguXjPkYMsnuJXE8slLt0NFy4bJGB37LMRw0WlSJVolkLHkN39lWio8UQctJD6JiPuSuedBola2kImt9DB3NxzBdjuFQcPsC1Lb2ABrFfITGL6hAl9fYgd9SfKQIu5o6ARhXqvGkJ3dcg2ZXo1i4xpZma/L+asxAu4zbGRRuX4CDLWLhGqfRGMrA76Gxp6mLYCjuqjgr+cfX0YHfRvdeSfGRIuxqEOJDs4VLxnwMCY8/wP5mETCsmYCURaqGxL6j3QSCCjlp2sRdgWx1MFR2hzdx2ZocX4N5Ar+l+EgBvP5geOEar7H4MPoFoxX7m8XClZ1mpywnTRMbZGOyoaEuXOPLNFy45CZgSKjeR602AGCewG8pPlKAAy3d+IMK2S475bkaLVzpMuZjKEQmPbnjMiq7GrX1PoI8/hwqqgdZq2MzME/gtxQfKUBN6IIZU6pNpgtIz8dQiUx62u+4ZIn8wVHTEBKQGgWbggz8Hiq7mrQXH/kmSbeV4iMF2B3acWl15AKRHVen248/IDvbxoq6a9Zy0pO75qERjhco03LhkmM4WHq8fmpbRYVhTa/DcH8XYwtIKT5SAK2zJCDiKgTocBs7RUwLdjdFjl20wkwFjpJNrzfAoVZtM11AeiCHwu7GSGn8Ag0qDKtESqwbewyl+EgBIrtm7dy9dpuV7DQ7ICPtY8XtC3CgRc100cPC5ZWdbWNkT5MojV+Y6aRIgxRNlTwZezVo9OB9BPPEXknxYXKiFy4tj10gurGVsS+aZKMuXPkZDoqytNtxqePnCyj0eGVvkFiIBJtqtwGAqIwleQ3GjB68j2Ce2CspPkzO3qOiKE5uuoNijWoLqISLVBn8okk2aqzAWA0zXQDSHTactlBnW3n0EhP62TXL8uqDpaZBHwLSLLFXUnyYHPWccrzGCxfI8+bBomZJaO25slgsUTtnKSBjQTfiIxS30+nx45OB3zGhh8B9ME+9HSk+TI5e3L1gnvz0ZLNbBzE7KnLnPDii67RoSU4o7gqgw+CLVzLpdPuoa3cD2gbug4z5kBgEvey4QDYmGyw1OihOpSLHMHa6PH6OtKkpmtoKyD6B33IMB4wqHktzXH0y97QgOubDyIHfMYmPpUuXcuaZZ5KdnU1JSQnXXnstNTU1fZ5z0UUXYbFY+vzceeedcTVaMnD0suMC2ZhsMHR7/Bw+pn1tAZVck+y6konquSrJdoWPPbREHn/Gzm4dbuKMHvgdk/hYvnw5ixYtYtWqVbz99tv4fD4uv/xyuru7+zzvy1/+MvX19eGfRx99NK5GSwZGj9dP7TG1toAOXPZy1xwze5rU2gIuTWsLqJilr0Qy2a2jDQDIwO/BoKdNnFkCv+2nf0qEN998s8+/n376aUpKSli3bh0XXHBB+P6MjAzKysriY6Fk0ETXFijUsLaAioz5iJ0aHcV7gIz5GAw1Ooq7Aun5GAx6qJWkYrFYyMtw0NTpoa3Hy7C8dK1NGhRDivlob28HoKCgoM/9zzzzDEVFRUyZMoUlS5bQ09Nz0tfweDx0dHT0+ZHEBz2pdZAVMgeDnty9IBeuwbBLJ1kSKuHrUI7hgNFT7ByY4zqMyfMRTTAY5L777uPcc89lypQp4fs///nPM3z4cCoqKti8eTMPPvggNTU1/POf/+z3dZYuXcr3vve9wZohOQV6ypIA8/QkSCZ6E5CyMVns7NZBe4NownUi5CZgQLT1eGnq9AB6GkPjC8hBi49FixaxdetWVqxY0ef+O+64I/z71KlTKS8v59JLL2Xv3r2MHj36hNdZsmQJixcvDv+7o6ODqqqqwZoliSLsstewkVU0ctKLnfCuuUwfAlI2JouN9l4fDR1qiqY+xjBP1mqJCXUDMCwvnSzXoJfMuJJrgtirQX2Td911F6+++ioffPABlZWVp3zunDlzANizZ0+/4sPlcuFyaR+PYEb0FuiWG04R8xEMKlit2hY90zsdbh/1odoCY0r0MYZm2HElE9X7WJGbRk6atimaKjL2Kjb0FO+hYoYqpzHFfCiKwl133cWLL77Iu+++y8iRI0/7fzZu3AhAeXn5oAyUDI5Oty9SW0BnC5eiQKfsbHtaVPFYlpOmeW0BFZntEht66Ch9PDL2Kjb0FncF0bU+jDuGMXk+Fi1axLPPPsu//vUvsrOzaWhoACA3N5f09HT27t3Ls88+y5VXXklhYSGbN2/m/vvv54ILLmDatGkJ+QCS/lHVemmOK+xx0Bqn3Uqm00a3N8CxHq9u7NIrOxtE8LVejs1A7ppjRR3DCToaw0jGkhSQA2FHg3r0qaMxDAcNG3cMY/J8PPHEE7S3t3PRRRdRXl4e/nn++ecBcDqdvPPOO1x++eVMmDCBBx54gIULF/LKK68kxHjJydlRLy6YieU5GlvSF7nrGjg7w2Oop0lPLFwefxC3z7gFjpLFjnohPvR0Hcp6OwNHURRdjqEZNgExeT5OV8q1qqqK5cuXD8kgSXzQ4wUD4qI50tZraMWeLNQxnKSjMcxy2bFZLQSCCm09PspybVqbpFsURQkLyAk6FJBGXriSRV27m063H7vVwuhiHcV8mEBAyt4uJmVnyFWoJ3cvmOOsMhkoihI1hvoRHxaLJSprSQrIU3GkrZdOjx+HzcKoIv0sXGqJ/A63j0DQuL1BksHO0AZgTEkWTrt+lstIqwrjzqP6+TYlcSMYVMIXjd48H3LXNTAOH+uly+PHabMyqjhTa3P6kCvHcECoR59jSrJ1tXCpLnsR+C3H8FTo1YNshsBv/VwRkrhx+Fgv3d6AWLiKdLZwyVTNAbEjasflsOnrMjVDml8yCG8AdOZ9VAO/QY7h6dihcw+ykcdPX7OaJC7sCEXYjy3Nwq63hcsEij0Z7NBhrICKGjQsG5OdGvU61PMYGjlmIBno1/Mhxs/jD9Jr0M62+lqZJHFBrxcMyMZkA0VN0dRTsKmK9HwMjJ06zTgDWeV0IPR6AxxoFh3b9SYgM5027KEijUbdyEnxYULCEfY6cxWCOaK0k4Eeg01VcuUYnpZeb4D9LaGFS4djaAa3faLZ3dRJMNQVvFgHXcGjUTvbgnHHUIoPE7JDx7vmSMyHMdV6Mujx+jkQWrj0VONDJV92RT0tNY2dKAoUZbkoztbXwgXRZfLldXgyoj3IFov+WkEYvdaHFB8mo9vj52BLD6Cvinwq0vNxemoaxMJVnO2iUGc7LpAu+4EQyTbT3zUI0ns1EHbo2IMMxo+9kuLDZKju+hKdL1wy5uPk6H3SM/qOKxmou2a9jqGM2zk94THUoQcZjD+GUnyYDDVQUY9BbhDlsu/1nbZibqqi52BTkJkSA0FN0dTrdSiL/Z2a6CJ/0nuVGKT4MBkRta7TCyak1gNBhU6P7GzbH3osyR2NbEx2akRZddXzoVPxIWM+TklDh5v2Xh92q4UxJfqpThtNnsFrJknxYTLUhUuvu+Y0h400h/izk0cvJ6IoSjhgWO+7ZqPuuBJNXbubjlA/EL0uXEbfNScadRM3ujgLl12f/Ysi3itjCkgpPkyEXvuBHI/RFXsiOdLWS6dbf/1AolHHr8cbwOM3ZoGjRKLXfiDRyHo7p0bPRf5U8mWqrUQv6LkfSDSyyunJUSe90cX6Xbiy0+yomYcyZuBE9B5sCjJu53QYYROXGxrDYwY9OtPn7CYZFNt13A8kGpktcXLUhUuvx2YAVqslPIZy53wiO3Rc2VQlOl06KDvbnsD2unZAv8GmILNdJDpi2xFxwUwZpt9JD2TMwKnYGhrDycNyNbbk1IQnPjmGJ7C1Tr0O9TuGqngMKtDllYHf0XR7/OwLlVWfXKHfMTR6xpIUHyZia53YNet50oNIzIDMljiRbeoYVuhbQObKKqf90t7rCxf5m6zjMUxz2Eh3hDrbdssxjGZHfQeKAmU5abqsTqti9Ng5u9YGDJZAIIDPZ8wvPVEcbetkWLaNySXpuN1urc0BwOFwYLP1jRY3ek+CRNHa7eVIWy8Ak3S8cEG0y1cKyGi2h8RjZX56OK5Cr+RlOOhtD9DW66WaDK3N0Q1bDeJBVjOWen0B3L4AaQ59ZuWcDMOJD0VRaGhooK2tTWtTdEUgqHDPWXlYgDRPC/v3t2ptUpi8vDzKysrC/RFkml//bAu560cWZZKd5tDYmlOTb3CXb6JQx3CKjt31KrnpDurb3XITcByqB1nPRy4A2S47Vos4Ouvo9UnxkWhU4VFSUkJGRoYuG/5oQZfbRyCzF6fNxkidZLooikJPTw9NTU0AlJeXA8Z3FyaKrUfUSU/fOy6IZEsYNdI+URhl1wwy9upkRMZQ3+JDDfw+1uOjrddHSU6a1ibFhKHERyAQCAuPwsJCrc3RFe1esNgDZGU4SUvTzx9heno6AE1NTZSUlGCz2QxfHCdRGCFQUUVmLPVPeNdsgDGUsVcn4vYF2N3UBRhFQDqF+DDgdWiogFM1xiMjQ55PHk+vVxR70qPrTR0vdfyMniKWKMLZSjp394LcNfdHj9fP3qOhhctIYyivwzA1DZ0EggqFmU7KDOBJyDVw7JWhxIeKPGo5EbdPiI90p/7Ex/HjJQscnUiH28cBA2RJqMjuxCeiZkmU5rh0nSWhImOvTkT1Pk4elmuIdcbImwBDig9JX/yBIN5AEIB0h/6HNLrAkexsK9gWivcYlpdOfqa+syQgKm5HHp2FUWN2jOD1ABl71R+RMdT/BgAiXcKNuAnQ/0olOS29Ia+Hy27FZtX/kKriwxdQ6PHK3iAQlSVhgHNmiNo1G3DSSxRGKRCnImOvTmSbgeKuIHLsYsTAb/2vVCbh1ltvxWKxYLFYcDqdjBkzhu9///v4/UOvLqiKDz3Ge/RHusOGM1T+3YjuwkSw1UDxHiAbk/XHVoMUiFORsVd98QWC4a7ghrkO5bGLZCBcccUV1NfXs3v3bh544AG++93v8uMf//iE53m9salYNdhUj/Ee/WGxWKJ2zsZT7InAKNVpVdS4nU6PH1/oyC+VcfsC7G4MLVwGGUMZ89GX3Y1deANBstPsVBWka23OgDDyJkCKjyTicrkoKytj+PDhfOUrX2HevHm8/PLL3HrrrVx77bX86Ec/oqKigvHjxwNQW1vLDTfcQF5eHgUFBVxzzTUcOHAAgA8++ACHw0FDQ0Mk2NRh47777uP888/X6iMOGCNfNPEmOktiskGOXXLSIln6HXLxYldjJ/6gQkGmk/Jc/WdJQCReQG4ABFujCsQZIdgUooP3jTeGhqrz0R+KooSPHZJNusM2pD/S9PR0WlpaAFi2bBk5OTm8/fbbgEhLnT9/PnPnzuXDDz/Ebrfzwx/+kCuuuILNmzdzwQUXMGrUKP705z+z4PN3AGAnyDPPPMOjjz469A+XYIzsLow3apZESbaLkmxjLFx2m5XsNDudbj9tvT4Ks/Sf3ZFIogvEGWfhihy7KIpiGLsThVEac0Zj5Ngrw4uPXl+ASQ+9pcl7b//+fDKcsX+FiqKwbNky3nrrLe6++26OHj1KZmYmf/jDH3A6hZL9y1/+QjAY5A9/+EN4UnjqqafIy8vj/fff5/LLL+f222/nj089zYLP34HTZuWN11/D7XZzww03xPVzJoJcGWkfJhxhbxB3vUpehkOIDzmGhioQp6Jmu/iDCt3eAFkuwy8HQ8JoR59g7LgdeeySRF599VWysrJIS0tjwYIFfPazn+W73/0uAFOnTg0LD4BNmzaxZ88esrOzycrKIisri4KCAtxuN3v37gVEEOu+vXvYvH4NaQ4bTz/9NDfccAOZmfoor34qIp4P47kL480WNUvCIIGKKuE0PzmGkUwXA41hmsOK0x4K/E7xoxd/IBhuCqj3ni7R5IWvQeOJD8NL3XSHje3fn6/Ze8fCxRdfzBNPPIHT6aSiogK7PfL1Hy8Yurq6mDVrFs8888wJr1NcXAxASUkJl85fwEt/e4YzJo3jjTfe4P3334/9g2iAjPmIsKm2DYDplXma2hEr4TS/FG/J7vYF2FEvFi4jjaHFYiEv3UFTp4e2Hh+V+VpbpB27m7ro9QXIdtkZVaT/zZuKOo92hQK/HTbj+BMMLz4sFsugjj60IDMzkzFjxgzouTNnzuT555+npKSEnJyT76au+9zNfO2r/8mE0SMYPXo05557brzMTSiytLOgy+NnTyjYdFqVcXZcICvVquyo78AXECW5K/ONkSWhkpchxIcRd87xRN0ATK3MxWo1TuxLTnqk+3V7r48iA8VeGUcmpRg33XQTRUVFXHPNNXz44Yfs37+f999/n3vuuYfDhw8DwlV41vkXk5mVzY8fXsptt92msdUDJ1d2RQVgy+F2FEVUNjVKsKlKnoH7SsSTsOeqKs9wQZtq3EeqX4ebDrcBYgyNhM1qCWeeGW0jF5P4WLp0KWeeeSbZ2dmUlJRw7bXXUlNT0+c5brebRYsWUVhYSFZWFgsXLqSxsTGuRqcCGRkZfPDBB1RXV3P99dczceJEbr/9dtxud9gT0uMLYLVauf6zNxEIBLj55ps1tnrgFIZKiLd2y0kPYLrBvB5AuAx8qo/h5sMi3sNIRy4qqgfymMEWrnizsda4Y2jU6zCm84rly5ezaNEizjzzTPx+P9/61re4/PLL2b59ezhm4f777+e1117jhRdeIDc3l7vuuovrr7+ejz76KCEfwCg8/fTTMT9WVlbGn/70p5P+P7W4WEtTA1deeSXl5eVDMTGpFBj0gok3Ro33ACkgVTYaWECqKdItXR6NLdGOXm+AXaECcWcYzPMB4jo82NJDa7exxjAm8fHmm2/2+ffTTz9NSUkJ69at44ILLqC9vZ0nn3ySZ599lksuuQQQ6aETJ05k1apVnH322fGzXEJDcysbNm7iX//4Gy+//LLW5sREUZZYuJpTeNKDvi57o6EKyJYUFh/tvT72He0GpIA0Ktvq2gkEFUpzXJQZpEBcNAWZIQFpsDEcUqRme7twVRUUFACwbt06fD4f8+bNCz9nwoQJVFdXs3Llyn7Fh8fjweOJLEAdHR1DMSllUBSFL3/hs2zZsI7bv3wHl112mdYmxYR6wXS4jRelHS+aOtzUtbuxWIxVW0ClMCQgU3nXvCV05FJdkGGIbsTHExaQXcZauOLJxtAGYJoBxSNECUiDjeGgxUcwGOS+++7j3HPPZcqUKQA0NDTgdDrJy8vr89zS0lIaGhr6fZ2lS5fyve99b7BmpCy+QJA//O0VLBYLk8uNU1tAJS/dgdUCQQWOdXspyTHejmOobAotXGNLsgxZ4KkwJCBTedds1EBFlbCANJjLPp6o16ERj1wgegyNdR0Oeru5aNEitm7dynPPPTckA5YsWUJ7e3v4p7a2dkivlyqorejTHFZDpYapWK2W8K6r2WCKPV5sVhcug+641PE71uPDn6LN5TaGY3aM57kCKSDB2HFXYNzjz0Ftt+666y5effVVPvjgAyorK8P3l5WV4fV6aWtr6+P9aGxspKysrN/XcrlcuFzGyU3WC2qwaYbDeDtmlYJMJ81d3pSd+DYaON4DID/DgcUCiiIESHF2al3HiqKEx9Dwu+YU3QC0dns51NoDiBofRqTIoEHDMXk+FEXhrrvu4sUXX+Tdd99l5MiRfR6fNWsWDoeDZcuWhe+rqanh0KFDzJ07Nz4WS4CI5yPdGVuVVT0RUezGumjigaIo4R2XURcuu80arvWRigKyocPN0U4PNqvFUCW5oynMjNT5CAQVja1JPuqx2ajizHDFXqNh1MzBmLbNixYt4tlnn+Vf//oX2dnZ4TiO3Nxc0tPTyc3N5fbbb2fx4sUUFBSQk5PD3Xffzdy5c2WmSxyJ7uSbYWDxoab5Ge2iiQcHWnrocPtx2q2ML8vW2pxBU5jl4liPL7TrMu7nGAyqeBxfmm3YTYAaJBtURLG4VOtOHN4AGPTIBSLeK6MdX8ckPp544gkALrrooj73P/XUU9x6660A/OxnP8NqtbJw4UI8Hg/z58/n17/+dVyMlQjcviBBRcFmseCyGzdLpDCFI+031h4DRCMyI2f6GPW8OR5sCB+bGdPrAeCwWclNd9De66O1O/XEh9GPPiESt3Osx0swqBgmBjAm8aEop3fLpaWl8fjjj/P4448P2ijJqenx+gFx5GK0cs7RpPLCte6gEB+zqo3dzSsiIFPv6Gx9aAxnmmAM23t9NHd5GVuqtTXJIxhUwmM4a7hxx1CdRwNBhQ63L9xzSe8Yd8uVwqjxHpn9pGfeeuutXHvttUm2aHBEjl1Sb+Fad7ANMPakBxGXb6odnXn9wXCKphxDY7L3aBcdbj/pDhsTDHz06bRbyQ71dzHS0YsUH0ni1ltvxWKxYLFYcDqdjBkzhu9///v4/f6YX6s75PmIR7zHgQMHsFgsbNy4ccivFSupeuzS6fZR0yCK6c00+MJl1OqKQ2VbXTtef5D8DAcjDdSCvT8iAYuptQlQvY9nVOVhN/DRJ0QyXowkII39jRuMK664gvr6enbv3s0DDzzAd7/7XX784x+f8Dyv9+R/QL5AEK9f1FTQW7CpzxdbcyqjRmkPlU217QQVqMxPp9TgxdVSVUCui3LXG/noEyIC0ki75nigjuHM4XnaGhIHCgx4/CnFRxJxuVyUlZUxfPhwvvKVrzBv3jxefvnl8FHJj370IyoqKhg/fjwAtbW13HDDDeTl5VFQUMA111zDjt17AUhz2EBRWLx4MXl5eRQWFvKNb3zjhLicN998k/POOy/8nE9/+tPs3bs3/LiaLj1jxgwsFks4mDgYDPL973+fyspKXC4XZ5xxRp/ePqrH5Pnnn+fCCy8kLS2NZ555Jqbvo8iglfmGyjoTnDOrpKrLfv0hdeEy/hgWpegYrjtkouvQgPFzxhcfigLebm1+BhCAeyrS09PDXo5ly5ZRU1PD22+/zauvvorP52P+/PlkZ2fz4Ycf8tFHH5GVlcX1V38an9dLhtPGT37yE55++mn++Mc/smLFClpbW3nxxRf7vEd3dzeLFy9m7dq1LFu2DKvVynXXXUcwKLwnn3zyCQDvvPMO9fX1/POf/wTg5z//OT/5yU/43//9XzZv3sz8+fO5+uqr2b17d5/X/+Y3v8m9997Ljh07mD9/fkyfX91xtff68KVQhUwzTXrhKrUp5LJXFMU0AcOQmvV2Wru94YaAM6qMP4ZGLBZn3PKYKr4e+H8V2rz3t+rAGft5r6IoLFu2jLfeeou7776bo0ePkpmZyR/+8AecTvFH9Je//IVgMMgf/vCHsFv3qaeeIjcvjzUrV/DZaz/NY489xpIlS7j++usB+M1vfsNbb73V570WLlzY599//OMfKS4uZvv27UyZMoXi4mIACgsL+1Sh/d///V8efPBBPve5zwHwyCOP8N577/HYY4/1yWS67777wu8fK6nY3yUYVNhgkiwJSM3y3Efaemns8GC3WgzbjCyawnCFzNQZww2hDcDo4kxDNgQ8HiPG7Rjf82EgXn31VbKyskhLS2PBggV89rOf5bvf/S4AU6dODQsPgE2bNrFnzx6ys7PJysoiKyuLgoICPG43hw/ux9fbRX19PXPmzAn/H7vdzuzZs/u85+7du7nxxhsZNWoUOTk5jBgxAoBDhw6d1M6Ojg7q6uo499xz+9x/7rnnsmPHjj73Hf9+sRDd38VI7sKhsLupi06PnwynsSPsVdQdV1sK9XdRvR6TK3IMW1wsmsIUjL0y09EnRDYBzQYaQ+N7PhwZwgOh1XvHwMUXX8wTTzyB0+mkoqICuz3y9Wdm9vWgdHV1MWvWrD5xFD1eP4daeiguLsY5wOJiV111FcOHD+f3v/89FRUVBINBpkyZcsqg1lg43u5YUfu7pMquy0wR9gD5Gc5wf5fWHi8l2eb3XoXre5hk4Uq1DQCYUHyocTsGmkeNLz4slkEdfWhBZmYmY8aMGdBzZ86cyfPPP09JSQk5OTkAHO30QE4vOWkO8vIyKS8vZ/Xq1VxwwQUA+P1+1q1bx8yZMwFoaWmhpqaG3//+95x//vkArFixos/7qN6WQCAQvi8nJ4eKigo++ugjLrzwwvD9H330EWedddYgP33/pNp5s9kmPZvVQn6Gk9Zu0SAwFcSHmWJ2ILJwqf1dbAapkDlYfIFguKeLacbQgMefxt96mZSbbrqJoqIirrnmGj788EP279/Psnff5eGHHqStuR6Ae++9l4cffpiXXnqJnTt38tWvfpW2trbwa+Tn51NYWMjvfvc79uzZw7vvvsvixYv7vE9JSQnp6em8+eabNDY20t4uCid9/etf55FHHuH555+npqaGb37zm2zcuJF77703rp8z1fq7mClLQqUghdJtuz1+dtR3AuZZuPJDFTFFd2Lzj+GO+g7cviC56Q5GFWVpbU5cMOImTooPnZKRkcEHH3xAdXU1119/PRMnTuQb934Vr8dDaWEBAA888ABf/OIXueWWW5g7dy7Z2dlcd9114dewWq0899xzrFu3jilTpnD//fefUFfEbrfzi1/8gt/+9rdUVFRwzTXXAHDPPfewePFiHnjgAaZOncqbb77Jyy+/zNixY+P6OYvUbAkD5acPlqOdHvY3d2OxwEwTRNirFKbQGG6sbSMQVKjITaM8N11rc+KCw2YlL0N0dE0FAbnmgBrwnWeYPiinIzpd2ijdiY1/7GIQnn766ZgfKysr409/+hMAbl+AXY2dWC0WSovEMYzdbuexxx7jscceO+lrz5s3j+3bt/e57/haIF/60pf40pe+1Oc+q9XKd77zHb7zne/0+7ojRowYUK+f01GcLTwfTR3mX7g+2d8KiC6ouRnGbN/dH+oYHu00/xiu3tcCwFkjCzS2JL4UZ7lo6/FxtNNj6C7LAyEyhoUaWxI/CjJF7FVQEQJEvSb1jPR8GIRuT6SZnNXgFRWjUWMEmlJg4fpkv5j0zh5lnkkPUmsMV4cE5ByzjWFOaBPQ6dbYksQSDCqsOaCOoXkEpN1mDXsgjTKGUnwYhO5TNJMzMsXhSS91Fi6z7ZrDC1eHMSa9weL2BdgQasE+x2xjmCICcndTF8d6fKQ7bEwdlqu1OXGl2GBjKMWHAVAUJez5yDJBXYFoSsIue3MvXMe6vexsEIGKphMf2akhIDfVtuH1BynKchm+mdzxlKTI8efqkPdx1vB8HCZIdY8mPJcaZAzN9e2bFG8giC8QxGKxkOE0l+dD3XG1dHtNXaRKdfWOKckKd6A0C6mya/5kf8Rdb/RmcscTjr0y+SYgfGxmsg0ARG8CjDGGUnwYgG6POHJJd9hME52tUpjpxGa1oCjm7qpp6kkvJzUCTtUxPNuUY2h+AakoCqv3mfPoEwh3yDbKGBpSfMQjy8JIhI9cXMY8cjnVeFmtlnCamFEU+2BQ3b1mnPTUHVd7rw+3L3CaZxsTXyAYLhBnpiwJlZIUyFja39xNc5cHp93K9Ko8rc2JO5HYK2OMoaHEh8Mh0hN7eno0tiS5dHuF+DBqsKk6Xur4HU/YbW+QiyZWOtw+ttd1AObLdAHITXeEy/2bdfHacqSdXl+A/AwHY0vMUZgqmkjMh5k3AMLrcUZVHmkOY27kToXRjl0MtZrZbDby8vJoamoCRCEus529Ho/XH8TjdmPBgjXox+02zs5SURR6enpoamoiLy8Pm63/C97sAYvrDhwjqMCIwoywa9RMWCwWirNcHGnrpanTQ1VBbD2PjEC0u95sR58QOXbp9gbo9vgNu9E5FWp9DzMem4Hxsl0M9xemtn1XBYjZ6fH6ae324bRbONRrzIUrLy8vPG79YfYaA2ZNsY2mJEeID7NmLX2y33yFqaLJctnJcNro8QZo6vQw0mTiQ1GUqOvQnGMYvYlTFEX3G3PD/YVZLBbKy8spKSnB5/NpbU7C+d+3anhjaxM3zK7iv6aN1NqcmHE4HCf1eKgYTbHHysrQjmuOSSc9MLf3yhcIhktymzFgWKUk28WBlh6aOtymSyU+1NpDfbsbh83CzOF5WpuTENSMJa8/SEevX/dVlA0nPlRsNttpFzWjoygKb+xo4UhngOkjiklLM6bn43SYucZAe4+PLaEOmueOKdLWmARi5ridzYfb6PL4yc9wMKk8R2tzEkZJdpoQHyYUkCv2NAMwozrfdOUKVNIcNnLTHbT3+mjqdOtefBgq4DTVONTaw5G2Xhw2i7ld9iYuNLZyXwtBBUYXZ1KWa07xCMYLdouFFbuF5+qc0UWmjPdQMXO14Y9C4uM8E28AwFgeSCk+dIyq1meaVa23H4aOOlPXGDD1pKco0LIXelqj4nbMO4am9FwFA9C0AzxdphWQgaDCx3uFgDTlGPrc0LgdAj5Dxc+ZcEUzD6ZeuN5/GN5fCsDoOQ8Aszja6SEYVEy1u/xor0kXrmAA/vEl2PZPsLmYctZSoMx0xy7dHj/rD4l4D9Ndh+52+PO1ULce0guYOvGngNMw5bkHyva6Dtp6fGS57EyvNFc/F1r2wp+vgfZaKBjNuIKlfIQxjj+l50On9FHrY0026R34KCw8ALJX/4QLrJvwBxWO9ZinymldWy/7jnZjtcDZo00WbLruKSE8AAIeJn6yhGpLo+k8H5/sb8UfVKgqSKe60GQpxO98VwgPgN5Wrtj+TdJxm24MVQ/y2aMKsZusnwv/WiSEB0DrXm5tXAoohhhDk42EeVDVerbLzjSTdV/k41+I25k3w5w7AVjsfAkwl9te9VxNr8ojJ03fwV8xEQzAR6ExvOJhGHkB1oCH/7K9Sku3x1Q9ekzrfexqgg3PiN9vfB5yq8lwN/Aftg8M4bKPhY/3qmNosg3AoVVwaCXYnHDbG2BPY3jnBs6y7DTEPCrFh04Jq/XRJlPrrftg15uABc69D867H2xOzqCGGZbdhrhoBoppF65db0HbQUgvgJm3wIXfBGCh7QPylA5aus3jvVph1niP9X+CgAeGzYZx8+HcewC43fYGRzt6NTYufrh9gXBDQNON4erfitvpn4Ph58D0GwH4kv11Q1SqNdGqZi5Mu3DVvCluR54PhaMhuwwmXw/AtbYVNLSbY+JTFIUVe0wa5Fbzmrid9llwZoiJr2wqaRYfC2xrqG/X/8Q3EI52etjZ0AnAXLOVxa95Q9zOvBksFjjj8yjObEZYGxnp3mGaHj3rDx7D4w9Sku1ijJnK4vu9sOcd8fvMW8Tt2V8B4BLrBrrbj2pk2MCR4kOHuH0BPjlgUrW++9/idtwVkfumCPFxhW0NdcfM0bdnV2MXzV0e0h02ZlTnaW1O/FAU2P22+H3c5eLWYgkLyCusn1DfZg4BqbrrJ5XnUJjl0tiaONLVBEfWid/HzRe3zkwYvwAIXYcmGcMVUZs4vVf8jIlDK8HTAZnFUDFT3Fc8Hm/RJOyWIJM7P9Z9A1YpPnTImgOteP1BSnNcjC42UaVBbzcc/Ej8PvbyyP2jLsJjy6TU0oajfp02tsWZD3eLnceZIwtw2U1UDK9hM3Q1giMThp8buX/i1QDMtW6nudkcrQ8+3K0euZjM67Fnmbgtny48jyEsE68CzCUgI2Nosk3cntAGYOzlYI0s49ZJ4jqcx2rdH39K8aFD3q8RC9eF44rNpdaPrIeAF3KGQeGYyP12F0dLzgOgrGWVRsbFl+gxNBWHVovb4eeAPcobUDSG5rThOCwBXIc/0sa2OBIMKuExvGh8icbWxJna0DU28sK+94+5FD92qq1Haa/blXy74kxTp5stR9oBuMCs1+HIC/rcbZ94JQDnWLdR19qRbKtiImbx8cEHH3DVVVdRUVGBxWLhpZde6vP4rbfeisVi6fNzxRVX9P9ikn55r0bsHC8226R3eI24rTxTuOqj8FQL8TG22/iej26Pn9WhRmQXjzfZpBc9hsdxtPhsAEqajS8gt9d30NzlIdNpY/aIfK3NiS+H14rbqrP63u/M5FDmZABchz5MslHx54NdwusxdVhuuO+JKfB7oH6T+P3467B0Kh2WHDItHnr2rUm+bTEQs/jo7u5m+vTpPP744yd9zhVXXEF9fX34569//euQjEwlDrZ0s+9oN3arxXz1PdRz5srZJzzkGnsJABP9O1E8Xcm0Ku58tKcZX0BheGGG6Rp0RcTHiWPYW3k+AKO71ifTooTw3k6xATh3TJG5js08ndC0Xfw+7MQxbCwUArLYBAIysokz2QagYYvIVMoohIJRfR+zWtmTeQYADp0LyJgrnC5YsIAFCxac8jkul+uULdQlJ0d19c4ekW+u2hCKcspdc1H1BI4ohQyztNC5dxXZk+Yl2cD48V5oDC8eX2KuY7PuFji2X/w+bNYJD7vGXEDwYwtVgVoR1JhlXM9deOGaYNzP0C91G0AJQm4V5JSf8HBv5blw6LcM79wgrlmD/v36A0E+2BU6NjPbGJ7CgwzQUDAHuj6g4OjqJBsWGwmJ+Xj//fcpKSlh/PjxfOUrX6GlpeWkz/V4PHR0dPT5SWVMe+TS1SQCFbFA2bQTHk5z2tlmHQ9Azz7j7roUReH90BheZLYdV+MWcVswGtLzTni4tLSM3cowAHyHPkmiYfGltdvLhto2wIRjWL9Z3Fac0e/DacNn41Vs5AbbRC0Xg7L+UBudbtGJeHplntbmxJeGreK2/Ix+H3aXi81dWfcOURBQp8RdfFxxxRX8+c9/ZtmyZTzyyCMsX76cBQsWEAj0/yUsXbqU3Nzc8E9VVVW8TTIMvd4AK0Ml1U234zq6Q9wWjBS1IfrhQPoUAKxH1ibLqrhT09hJfbubNIeVs81WG6IpNIYlE/t9uDDTyWbGAtBrYAH54e6jKApMKMumPDdda3Pii3rkUjql34fLCnLZrowAQKnVd8zAqVA3cReOK8Zmol5RADSGxEfp5H4fTh82mW7FRXqwB47WJNGw2Ii7+Pjc5z7H1VdfzdSpU7n22mt59dVXWbNmDe+//36/z1+yZAnt7e3hn9ra2nibZBhW7WvB4w8yLC+dsWYqiANRC9ekkz6lJU94RLKbNwqXrwF5b6dw9Z4zuog0h4liBSCycJ1kDC0WCwfSxGPKYeMKSDXew3QbAIgsXCcZw4q8dDYERSaa96BxvVemHcNgAI7uFL+fRHxU5GexKTha/OOwfgVkwlNtR40aRVFREXv27On3cZfLRU5OTp+fVOW9KHe9qWIFICI+iiec9Cn+kil4FDtpvmOR2AKDYdogN4Cm0KRXcvIxbMoVAjLj6CZdu3xPRiCosHxXJGbHVAT8kZ3wSRauNIeNXQ7h2Qoa1PNR397LzoZOLBY4f6zJrsPWfeB3gyMD8kf0+5TyvDQ2KEJABnW8CUi4+Dh8+DAtLS2Ul58Y3CSJoCgK7+40abwHRNT6SVz2AKUFOWwLuXzR8UVzMtp7faw7KNqvm642hKJEjeHJvVfBorF0Kuk4Aj2R5xuIjbVtHOvxkZ1mZ6aZKtPCcQvXyJM+rSlXHMm4mreKtE6DoXofz6jKoyDTqbE1cUb1XBVPAGv/ntWiTBdbGAeAX8exVzGLj66uLjZu3MjGjRsB2L9/Pxs3buTQoUN0dXXx9a9/nVWrVnHgwAGWLVvGNddcw5gxY5g/f368bTcVOxs6OXysF6fdyjlmq6ioKFG75pOLj6r8DDYERcxAOC3XQLy3s4lAUGFsSRZVBSZrv95xRJRzttpFwOlJqCzIZlMwlP6nY5fvyXh7eyMgYgVM1dARoDlUOKxobJ+qmMfjKBjJUSUHa9AXCVA1EP/e3gDApWY7cgFo3i1uT+FBtlotNOYIAeloqRHp1Tok5qtr7dq1zJgxgxkzZgCwePFiZsyYwUMPPYTNZmPz5s1cffXVjBs3jttvv51Zs2bx4Ycf4nKZqMhLAnhrm7hgLhhbRIYz5gxofdNRB552sNj6VjY9jqqCDLYGR4h/GHDSU8dw/mQTppmrx2aFY8F+8t1kdUEGW5WQ+GjYkgTD4oeiKPzbzGPYEjr6Lhx7yqdVF2WyNRjyjDQY6zrsdPv4ONTQ0dRjWHTyeRQgu2gYdUoBFhRo3JYEw2In5lXuoosuOmXDmrfeemtIBqUqb20TO67LzXjBqJkuhWP6luQ+jurCjPCxi9KwGUsweModmp5w+wLhGi2mnPTCAcMn33EBDC/MYHlwuPiHwcTHnqYu9jV347RZzZdiC1Hi49QLV1VBBtuV4VzMJsOJj/drjuINBBlVlGmuLrYqAxzD6oJ0tu8fToWtVWzkqs9OgnGxYYyZ3eTUtvawo74DqwXmTSzV2pz406yq9VPvuHLSHLSmVeNRHFi8XdB2IPG2xYkVu5vp9QWoyE1jyjATBk23hNy9ReNO+bTqggy2KUJ8KA1bDRV0qnquzhlTSLaZCvyptOwVt6dduDLYrnogDSYg1TG8bHKp+YL2FSUG8RHZyOlVQErxoQPUC+askQXmC5ACEegGUHjyWAGVYYU57FRCtV4MNPGpY3j55DLzTXoAraHso1PEewAUZ7uosw2jV3Fi8XVH/p8B+Hco3sOUniuIWrhOPYbDowVk4zaRJWMAPH6Tex97WsAtGuWdUFb9OKoLMtke9kBK8SE5Cf/eZvJJT02bPc0FAyGXr3rRGCTuwx8I8s4O9djMhJ4riBIfJ8+SAFHro7Igi51KtbijYVOCDYsPdW29bD7cjsWs3kd3B3SLbLrTiY+KvHRqKaVLScPid0dEi875eG8LXR4/JdkuzjBbVVOIjENuFThOXfyuOnR0Bogj04AvwcbFjhQfGtPc5WHNwVbApPEeEPF8DEB89HUXGsPzsebAMY71+MjPcHDWiAKtzYk/vl7oOCx+H+gYGizuQw00nT0831wdUFVaQ0cumSWQlnvKpzrtVspyM9kRFpDGGsPLJ5diNVtVUxiw5wpE/FytUkKHkg4Bry4rnUrxoTHvbG9EUUTb52F5JivlDOLM/1ioR8QpaguoVEd7PnTqLjwe9cjl0oml5kvPhMj4uXJEJ83TUF2YES7RbRTv1Vtm9z4OMN5DZXihsa7DQFAJp0mbdwwHFu8BkOWyU5jpYoei302ACWdKY/FmOLXPhK5egPbDEPSBzQU5w0779OqCDHYq1QSxQGc9dB1NgpGDJxg0eXomRI7N8kcMqMtpXwGpv0nveFq6PHxyIOR9nGTSMYxh1wyq236E+IcBxMfaA600d3nJTrMzZ6TJ6iSpxCA+4LgjbB1eh1J8aEhrt5cVu5sBuHKqSSvAqkcu+SMGlDZbVZBBD2kcUEKLgM4nvnWHjlHX7ibbZef8sUVam5MYYjg2A1VAVhHAKuIMOhsSaNzQeX1LPYGgwrTKXKoLTVYcTmWAMTsqVccfnem819LLm+oAuGJyGU67SZe1YwfE7QA8yKD/jBeTjpIxeH1LPf6gwpRhOYwqNmFOOkQtXAO7YCry0nHarWwPhs6bdVogR+XljWLSu3xymfkayanEKD5GFGXixsUBJSSo1ZLQOuWVTfUAXDWtQmNLEkjbIXGbN3xATx9ZlMlupVIIyJ4W4YXUKb5AkDe2CoF71fRUGMPqAT19RGEGO3QsIKX40BBVrV9t5gsmhkwXAJvVwqiiTGqCoXRbtbiVDvEHgry+RUzKV59h4jFsjW0MqwsysFstbFfHUMcCsq6tl08OtGKxwKenm9T7CNAe6hY+wIVrdHEWHpzsJ/R33bg9QYYNnY/2NNPa7aUoy8k5o0165OLuiKTZ5lUN6L+MLslij1IhBKS7TXcCUooPjahv72VN6Jz502becakL1wBdhSAmvl1qrY8m/S5cH+1toaXbS2Gmk3PNOulBzN4rh83K8MIMdhjAe/XqZrEBOHNEAeW5Jgz4BlGno0N8TnIHtnANL8zAaoHtAVVA6td7pW7irpxabs6Ab4iIx/R8cGUP6L+oAvIQIVHdpC8BadKR0j+vbqpHUeCsEQVUmDHLRaVNzXQZMeD/Mro4M1Jo7GiNbqtkqkcupp70gkERNAwDdtmDmPjCtT50LD7CRy5m9j52HAElADYnZA0ssD3NYaOqIIOdOheQbl8gXCfJ1GOoHrkMUDwCjCrOBGBboFLcoTPvlUlnTP3zSmjHdZWZXb0Abaq7d+AXzeiSLA4pJXhwiRbgOqySKSa9FDhn7moQ2UoWG2QP/G91dElWZOE6WgN+b4IMHDz7m7vZcqQdm9XClVNMmuUCkV1zbmVMvZJGF2dFan3oVHy8X9NEl8dPRW4as6rztTYncbTFdmwGkOG0MywvnV06PcKW4kMD9jd3s/lwaNIza5YLiFbO7jbxewyKfXRxFgpW9hBS7DpzF4JoYNXp8VOem8bs4Skw6eUMA9vA+1COLs6ijkK6LZlCvKjt3HWE6rk6d0wRhVkmLCymoo5hDNcghDyQqoBs1qeAVI9cPj29wpyFxVTaYws2VRlVnEmNTo+wpfjQgBc3HAFSaNJLzwfXwLN5RhaF3IV+/YqPFzeIo4irTD/pxe65ArFwgYVd6HPnrChKeAxNHfANUVkSsY5hFvUU0G3JgqBfCBAd0d7j450domR86oxhbOJjTEkWNUpoHtXZEbYUH0kmGFT4xzox6S2cefqiW4ZmEOeUAJkuOxW5aZGLRmcLV0uXh2WhSW/hzEqNrUkwgxzD0aF25pt96hjqK2Bx7cFjHGjpIcNpY4GZj1wgatc88JgdINSSXr8C8uXNdXj9QSaUZTO5woSdpKMZtPcqi0NKKR6L/o6wpfhIMh/vbeFIWy/ZaXbzVsRUiTG9L5rRJVnUqOfNOjurfHHDEfyholTjywYWeW5YBun5yElzUJLt0m3Q6d/WiM/1qanlZLoGfpxkSIawcAFs0qmAfGGt+Fz/MavSnJ2koxmC9yqIlX0W/R29SPGRZF5YJy6Ya86oMG9RKpVB7poBJpbnUBMMTXqte0VzMx2gKAp/D3muPjM79s9lOAa5cAFMKM/RZbZEt8fPa6H6LKkxhoNbuPIznZTm6FNA1jR0svlwO3arhetmmNyD7O2BHlEJO9aNnLo52urVX70WKT6SSHuvjzdDlfg+MysFJr1B7poBJpXncJQ8Oiw5oAR1E7C49UgHOxs6cdqtXG3m+iwqg1y4ACaWZ0eC3boaoLs5joYNnte31NPjDTCiMIMzR5g4WBhEqnSHiDEbjAdyok4FpOr1uGRCibnj5iCS6u7MhrS8mP5rQaaT8ty0SOkCHcXPSfGRRF7ZVIfHH2RcaRbTKk/d1toUDGHXPLE8B7BEvB86Ueyq52r+5DJyMxwaW5NgFCVKQMYWLwBCQPaQRr1NLbOuj8XrhSjPlend9V2NoqW6xQbZsYvlSeU57FIqRaPHrkZdNHr0BYK8tFEIqhtSynNVPaDGjscjxlCKj5RFURT++on4I/rMrBSY9GBIno9RxZk47dZIgRwdXDQ9Xn84U+kzs0weaArQ0wq+HvH7ADoSH48aBLjVr5/A4T1NnXyyvxWrBa43e8A3RBaunIqYUqVVJlUIAdmgCkgdxAz8e1sjzV1eirJcXDi+WGtzEk/74L2PIMZwp1rro3Wfbo6wpfhIEhtq29hW14HTbuU/UmHh8rnFTgkGtWt22KyMK83SlWL/18Y6Ot1+hhdmcN4Yk3awjUad9LJKwZEW838fUZiJy25lW7hEt/YL119Wic906cRS85ZTj2YIQd8gds2gLwH5f6sOAHDjWVU4zFpZOJoheJChnyPso/pImU6BkdMHf1kpyoxfNa2C/EynxtYkAfWc2ZEp6nwMgknlUYpd42MXRVH4v9AYfmHOcHPX9lAZ4qRnt1mZUJYd1eNF22yJbo8/nOb+xbNjF8SGZAhB3wDDCzPJcNrY5teHgNzd2MmqfcJzdeNZgxNUhmOQNT5UJlWII+wdwZCnTwcbOZDiIym0dHl4dbOIrr95bqpMeqGeLnlVgzqnBBH3sVut9dFZB73H4mRc7Kw/1Mb2+g5cdiufmZ0CnisY0rGZysTynEi2xNGdosmZRry08QidHj8jUsVzBUMeQ5vVwoSy7KiMF20F5F9WiXnlskml5u6JFc0Qx7AqP4Msl50dAf14kUGKj6Twt7WH8QaCTKvMZXpVntbmJIch7poBplXm0UkG9YQWCg3rffzfygOAqKSYl5ECniuI2xgeUkpwh4sc7YuTcbHRx3N1dop4rmDIu2aAqcNyIz1emrQTkF0eP/9YLzyqXzx7hCY2aMIQx9BqtTC5IiroVCfB+1J8JBh/IBhW619IFVcvxGXXPGVYDk6bNdLWWyPF3tTp5vUtIkX6i6niuYIhxwsAzByeh4KVmvDxmTY750/2t7KzoZM0hzU10txV4iAgZw7Pp1Ypppc0CHhE3R0N+Of6w3R5/IwqyuSc0YWa2JB0/F7oFHMPuUO5DvMjR9jS85EavL61gSNtvRRmOs3ffyCaOEx6LruNycNEqh+gmWL/08cH8AaCzKzOY1plniY2aMIQ4wUAxpZkk+Wyax50+rsPhMfluhmV5k+RVumTKj2Ehas6HwVrVPxV8gVkIKjwhw9FafBbzhmROp6rjsOAAvZ0yBz8UeHM6vzIPNpZLzLZNEaKjwSiKAq/+0DsEm6eO8L8FU2jicOkB+Ki2alhS+gujz/srr/jgtFJf39NicMY2qwWzqjK07RK5u7GTpbtbMJigS+fPzLp768ZPS2RVOncwccpVeanU5TlYkdQOwH55tYGDrX2kJfhSJ2YK+hb5G8I5RlmVufRRQaHFe2PsFWk+EggH+9tYeuRDtIc1tRy10NcPB+gKvaovgSKMkTDYuP5NbV0uP2MLMrkskmlSX1vTfF0RgJ8h3B0BmLi07JK5u8/FF6PyyaWMqp44N2VDY+6cGWVgX3wVUAtFgszq/MicR9JHsM+m7izh5PhNHkvnmjiNI8WZrkYUZgROf7UwdGLFB8J5LchV+8Ns6soSIX0WpWAP6qk89AumhnVeexVKvArVnC3C5dhkvAFgvxxhXD1fun8kdhSxdULkUkvLQ9cQ2ueN2N4fqS8c/shMY5JoqnDzUsb6gD4rwtHJe19dUEc4q5URMyANuJj9f5WNh1ux2W3cvM5I5L63poTJw8yiI1cjY7qJknxkSC2Hmnng11HsVrgS+el2KTXWQ9KAKwOsesaAhV56ZQW5LBfUUt0J++ieXljXTheZ+HMFHL1QlwXrlnD8+myZHFECQUJJnEM/7BiP95AkFnD85k1vCBp76sL4pDponLmiILIwtVeC71tQ37NgfLr94XXY+GsSorM3sfleIbQW+l4zhxZoKt2FVJ8JIjH3hGN0K6aXkF1YYbG1iSZcKDiMLAO/U/svDFFUYo9ObsuXyDIL97dDcDt549MrXgdiBrDoS9cOWkOpldFH70kJ2CxqdPNn0Mp0osuTrF4HYibyx5gemUuSlpuVMxAchavtQda+WDXUWxWC/91QYpt4iBqDId+HYp5VLyO0rQ96UfYxyPFRwLYWNvGOzuasFrg3kvHam1O8mmP36QHcN6Y4ohiT1Kg1Ivrj3CwpYfCTCe3zB2RlPfUFXH0fACcP6YocvSSJLf9b97fh9sX5IyqPC4eX5KU99QVcRxDu83K3FGFURkvyRnDn4U2cZ+ZVcnwwsykvKeuiKP3qqogA3/+aPyKFYunI3I0rhFSfCSAn70tLpjrZlSmVoCbSlv8zikBzhldSA1i0vPVJ37X7PVHvB53XjiaTFcKBbipxHHXDHDe2OKw50NJwsLV2OHmL6tFltLiy8alRiPH4wkvXPEJdj9/bFFSK52u3tfCR3tacNgsLLp4TMLfT3fEMXZOZc7YcvZpcITdH1J8xJm1B1pZHnIT3nNpCl4wEGlIFqeFKz/TiVI8GQBrc03CKyy+sK6Ww8d6KcpypVZhuGji7PmYUZ3HQbtIcw02boNgMC6vezJ+9e4evP4gs4fnc/7YFCmlfjwJFJCBBG8CFEXhJ6FN3A2zq6gqSLGja4hr7JzK+WOLdBN0GrP4+OCDD7jqqquoqKjAYrHw0ksv9XlcURQeeughysvLSU9PZ968eezevTte9uqaYFDhB6+JY4GUdRNC3D0fABMnTaVLScMW9EJL4v6eOt2+sOdq0cWjSXemWKyHSpwXLofNSsXoqXgUOzZfd6T3TwLY09TJs58IAbz48hT1erjbwRPKKoqTgBxRmEFH7jgAlMbtCRWQb21r5JP9rTjt1tT0ekBU3FVlXGLnAM4ZU8TukPjoPLQpLq85WGL+RN3d3UyfPp3HH3+838cfffRRfvGLX/Cb3/yG1atXk5mZyfz583G73UM2Vu+8tPEIm2rbyHTaWHz5OK3N0Y4475oBrpxWEXb5ums3xu11j+dX7+6hucvLqKJMbpqTol4Pvwe6QiWd4yggr5hWGW4UqCTQbf+DV3cQCCrMm1jKOaNT3OuRUQjO+GyCLBYLk6fOxKM4sAd6oO1AXF73eDz+AP/vdbGJu+P8UanTQO54EjCP5qQ5sJZNAcBdp22H4pjFx4IFC/jhD3/Iddddd8JjiqLw2GOP8T//8z9cc801TJs2jT//+c/U1dWd4CHRgvZjrezetjYhr93t8fPImzsBWHTJGEqy0xLyPronGIz7rhlgfGk2h10iY+Hwzk/i9rrRHGju5o8fiboe//PpiTjtKXoq2S7azuPIEItXnLhkYgm7EGKmZd+GuL1uNO/VNLF811EcNgv//amJCXkPQxCH0vj9sWBaJbsU0Zrdc3hzXF9b5amPDnCotYeSbBdfuSgFs5RUEuBBBhg5aTb1SgF7vfmaZrzEdXbdv38/DQ0NzJs3L3xfbm4uc+bMYeXKlf3+H4/HQ0dHR5+fRLBl5Vvk/nwkWX//PEoCvvBfv7+Hxg4PVQXp/Oe5KVTC+Xi6j4rmUxYr5AyL28taLBbSK6cDiZn0FEXhh69txxdQOH9sUWpmR6hER9jH8cgiJ82Br3ASAMcSID48/gA/eFWcY996zghGFqXosSckZNcMosPtIYdIeT20I/6bgMYON796dw8A37hiQmoGe6uoR5NxSLONZu6Zs7ko8Gt+X/FDvAGTiI+GBuGqLS3tW4a6tLQ0/NjxLF26lNzc3PBPVVViOk6OmDgTgHKlkfW74nvevKO+g98uF9VM//vKialXEyIadeHKLgd7fKu6jpt+DgAlPbtp7vLE9bVf39LAOzuasFstPPTpSakZJ6CSoF0zQOWE2QC4WsXRSDx5/N097DvaTVGWk7tTMcU9mjhnuqhYLBZcw6YC0H5gY1xfW1EUvv3SVro8fqZX5XH9jPhtXgxJHKubRlOSncbGhy7nyVvP1NS7q7lfecmSJbS3t4d/amtrE/I+2XnFHHOIiOGPP3ovbq/rDwR58B+b8QcV5k8uZf7k+EQlG5b2+OWlH8/ISbMJYqHY0s6rH8cvWOpYt5fvvCxiEL568RjGlg6tnLjhiWNtgeOZNed8ACqDDazYHr9NwI76jnAlzO9fM4WctBTpXHsyEiggp8wQm4Cinj3UtvbE7XXf2NrAv7c3YrdaeGTh1NTpXHsy4ljd9Hj0EEgfV/FRViYW3sbGxj73NzY2hh87HpfLRU5OTp+fRGGtmAZA+/71tHZ74/Kav/1gH5sPt5OdZucH10xJ7R0zJHThwplJV6bYyW1et4JgHHbOiqLw0MvbaO7yMrYkKzUrYR5PgnZcAGl5ZXTZC7BaFD786MO4vKbHH+BrL2wKbwAWTEnxDQAk7NgFoGyc8F6NsDTy95U74/KaTZ1uHvpXaANw0WgmlCVuHTAE0bFziZhLdUBcxcfIkSMpKytj2bJl4fs6OjpYvXo1c+fOjedbDYrcEeLoZbxygKdCgYVDYe2BVn4aSsv8zlWTKclJ0SDTaBIQbBpNRpWI+yju2sUbW/s/youF59fU8sqmOmxWC4/+xzRcdu13BJqTwB0XgLVcRNt3HdrErsbOIb/e0td3sq2ug/wMh9wAqCRy4coswp1WDMC6NR/T6fYN6eWCQYXFz2+iucvLhLJsFl2Soqm10XQ1QNAHVjtkV2htTUKIWXx0dXWxceNGNm7cCIgg040bN3Lo0CEsFgv33XcfP/zhD3n55ZfZsmULN998MxUVFVx77bVxNn0QlAvPx2TrQZ766ABtPYP3frR2e7nnrxsIBBWuPaOChTNT/HxSJZGeD8BeIc6bJ1oP8tg7u4YUN7CzoYPvviLSzb52+XhmVOfHxUbDk6B4AZWMSnEdjrcc4ufvDK1my1vbGnj64wMA/OSG6XIDAODtgZ5m8XuCNgGu0HVY5dvHUx8dGNJrPf7eHlbsaSbdYeNXn58hNwAQuQZzhoHNnEG3MYuPtWvXMmPGDGbMmAHA4sWLmTFjBg899BAA3/jGN7j77ru54447OPPMM+nq6uLNN98kLU0Hk0KZuGDGWo/g8bjDZ8Sx4vYFuOPPa6lrdzOyKJMfXjdV7rZUEiw+KBVjOMVWy+6mLl7aMLj+BE2dbm5/ei1uX5ALxhWnZtOq/vB7RWVFSOAYCs/HJOtBXttSz5bD7YN6ma1H2rnvuY0A3H7eSC6ZUHrq/5AqqNegKwfS8xLyFpZQrYgJlkP8/sN9tAwyAPz1LfXhSqbfu2YyY0pSPN5KJdHzqA6IWXxcdNFFKIpyws/TTz8NiGjo73//+zQ0NOB2u3nnnXcYN04nBbdyqyAtDwd+xloO8+SK/Ww9EtvEFwgqfP3vm1l78BjZaXZ+f/MsslI5HSwaRUlovAAQFpCjqMOFlx+9viPmia/L4+fLf1rLkbZeRhVl8ovPnSGD21Q6joASBHsaZBYn5j1KRan8qfbDgMKD/9iMLxBbtczDx3q4/U9r6PUFOG9MEd9cMCEBhhqUBHuugLCAnJlWR6fbz/deib1U97qDrdz//EZApEbfMDsxXhpDoqbZJnIMNUbzbJekYrGEF68bq9sJBBW+9sImerwD6xXiCwS5//mNvLKpDrvVwm+/MEsq9Wh6WsAXin7PrUzMe+RUQHo+VgJcVnSM1m4v33pxy4CDT9t7fHzxydVsOtxOfoaDP956JnkZ8U0JNjTRHYkT5c0rHg8WGxnBLsand7C9voNfLhv48cuB5m4++9tVNHZ4GFuSxa+/MBOHLbWmslOiLlz5iRQfQkBOstZitSi8vKmO1zbXD/i/r9rXws1PfoLHH+TSCSV8+9OTEmWpMTmmig/p+TAPZeK8eWFFK4WZTnY2dHLPXzfiP83Oq73Hxx1/XsvLIeHxixtncM6YFC3dfDLUHVdWGdhdiXkPiyW863pwph+71cJb2xr58b9rTvtfa1t7+OzvVrLhUBt5GQ7+9J9nMSKVC1H1R4KDTQHxt1EkvKHfPUvc9Yt39/DihsOn/a/rDh7jht+uFF6r4kz+7/Y5Mq32eNqSsHAVjQOrHZu3g6+dLTp3L/7bRtYfOnba//rq5jpufeoTur0Bzh1TyC8/PwOb9Dz2Rb0OEykgNSYFxYfwfGS0bud3N8/Cabfyzo5Gbnt6zUkDUN+vaeLTv/qQ92qO4rRb+d3Ns7hyankyrTYGyTqnDI1hlXsPDy8UYvKJ9/fy3y9uweMPnPD0YFDh+TWH+NQvPmRnQydFWU6eu+NsplXmJdZOI5KsMQztnOdm1fPl80VF4Af+tonff7CvXy+W2xfg5+/s5rO/XUlTp4cJZdk8f8dcynJ1EEumN44lwWVvd0LReAD+a1wPl0woweMP8sU/rObNk2ShHev28s1/bOauZzfg9gW5ZEIJT95yJhlOeWx9AikQ85F6ox5auGjYwqzqfH79+Znc/dcNfLi7mQt//D43zanmrJEFOO1Wdjd28a+NR1h/qA2AqoJ0fv35WUytzNXOfj2TrAum/AxxW7+R/7iykrYeEfvxzOpDvF9zlC+cPZxplbkEggpbjrTzj/WH2Xe0GxCt3R///MzUbVZ1OpJVW6B0Mmz9OzRuY8n1i+l0+3luTS0/en0HL244wufOqmJsSTa9Pj9rDhzj7+sOc7RTxPZ8amo5j/zHNBlrdTKSKSCbtmE7up1f3ngfX/rTWlbua+HOv6zj/LFFXHPGMCrz02nr8bJiTzMvbaijyyOOuO+8cDRfu3wcdnlcdiLBQKS/khQfJqJ4PNic4OmAYweYN2kkf//KXO5/fiO7Grv49ft7T8iCsVst3HbuCO65dCzZ0sV7chJY2KgPFSLTivrNEPDzpfNHMbIokyX/3MKRtt5wg79osl127p03llvOGSHjA05FuDJmohcucXRG4zasVgtLr5/KpIocHn2zhu31HTz0rxM7bpblpPHfn5rIp6eVy+yyU5GMmA8Q4mML0LiNTJedP99+Fj9+q4Y/rtjPh7ub+XB38wn/ZWJ5Dt+7ejJnjSxIrG1GpjO6xod5PeypJz5sDnHR1G0QPwUjmVyRyxv3XsDrW+p5e3sjOxs68AcVqgsymDOykIWzhqVul9pYSNaOq3AMOLPB2wnNNVA6mUsnlvLBN4p4Yd1hltcc5UCL8HSMLcnivNAuTO6UB0CSj11o3g0+NxZHGjfPHcFV0yp49pNDfLy3mfp2N06blYnlOVw6sYTLJ5WlbqfhgeLugN5Q3EXCxzAiIAEcNivfunIiX5gznL+uOcQn+1tp6/GS4bQzZVgun5pazrljCqVwPB3hDUAlWM1b8yQ1Z+OKmUJ4HFkHU64HwGa1cNX0Cq6abs5qckkhXN00wZOe1QoVZ8CBD8U4hhayNIeNL549nC+ebd4grYQS8ItUW0i89yqnAtLywN0mBGS5qFybn+lk0cVjWHSxrHI5KNSFK70AXAnOxFMFZIsQkDjEBq26MIMHr5Cpz4MmGQHDOiA1txHDZonbuvi39U5ZFCW5QVIVZ4jbI+sT/16pQmcdKAGwOkTGUiKJSnunPn5NAlOeZGZJZJcJkaMEw94PSRxIRp0WHZCi4kP0eKFuowjukQyd3mPiGAQSv2uGSNyHFJDxoy0qZseahKlBHcMj6xL/XqlCMnfNFktkLpVjGD9SoMAYpKr4KBoHzizwdcPR09eHkAwA9YLJLAFHEjJJ1IWrcasoCS4ZOsdCzRaT5e6tFN1ROSwXrriRjDTbaIaFxvDI2uS8XyqQAmm2kKriw2qLpGvWSbd9XGjdJ24LktQjJX+kiBkIeKEp9tLOkn5oDYmPZI2hunA1bQNvd3Le0+wke+EKC0gpPuKGFB8mJ+wulOIjLiRbfFgsUUcvcgzjQrLHMHeYSCVUguIIVDJ0wmm2I5Lzfmr8XOte6GlNznuamRSp8QFSfMizynjRekDcFoxM3nvKuI/4oh675CdxDNXFS16HQ0dR4NgB8Xuyjl0yCiJiVW4Chk77YQj6RS2q7AQHfWtMCouP0KTXuE2kiUmGRrJ3zRAlIKX4iAtajGGljBmIG11N4O0CizW5PUGGydiduNEaKnCZP9LUNT4glcVHbhVkFIlKco1btbbG+KgLV1J3zVExA57O5L2vGelpBXe7+D1ZLnuQC1c8UReu3MrENXbsDykg40dLaAwLR2trRxJIXfEh08Tih7cbukLNpJJ57JJTLs5FlSAcXpO89zUjarBpdjk4M5L3vhVnABboOCzKSksGT9hzleSFS/UiH14rjn4kg0cL76NGpK74AKgM9fM+tEpbO4yOes6clifOgJNJ1dni9tDq5L6v2TiW5EwXFVc2lEwUv8tNwNDQatdcNlXEKPS2Rv6OJINDio8UoVpduFZJxT4UwimaSfR6qFTPEbe1UkAOCS2OzVTCO2fpvRoS6rFLsj0fdlekWq08Phsa8tglRRg2S3QO7KyLdGSVxI6Wal31fBxeK3qTSAaHlgKyKiQgpQdyaLSErkMtFq6wF3ll8t/bLAT8ES9ysgWkBqS2+HBmhBtayYlvCGgpPkomgitHRPk3yf4SgyY8hhqIj+HniNsj62Tm2WBRFG2vQ3UMpfgYPO21IgHC5oKcYVpbk3BSW3wAVM8Vt/KiGTxa1IdQsdqg8kzxu4z7GDxaxXyo75lZIqrVyloRg6OzQbSLsFi16Qmiio+m7bLY2GAJH5uNTE5vJY0x/yc8HdFxH5LBoXWQlDqGMu5jcHi6oKtR/K6FgLRYIovXwY+S//5mQF248qrB7kz++2cWiZ5ZIOfSwRI++jT/kQtI8RE5b27aITqzSmLD1xvphqqV+AjHDEjPx6Bo2SNuMwohPU8bG8LiQ3ogB4VWabbRhI9ePtbOBiMTDjY1f6YLSPEBWSWhC1aRin0wtOwBFEjLFd+lFlTOBotN1IpQu3pKBk7zLnFbNF47G9SFq3a1DBweDHrIkqhWBaQUH4NCq2wljZDiA2Dk+eJ2/wfa2mFEohcui0UbG5yZkXRNOYaxc7RG3BaP086GkkngyhWBw41btLPDqOhh4VIFZP0mcZQniQ1VQKZAjQ+Q4kMw8kJxu2+5tnYYkaMh8aHlwgUwKjSG++UYxkxzSHxo6fmw2iKxO/s/1M4Oo6Jeh0VjtLMhrwpyq0VjNBl/FRs+dyTou1jD6zCJSPEBMPICcdu0DbqOamuL0dDDwgURAbn/A1kwLlaad4tbvQjIfe9raobh8Hsjno/iidraMio0l8oxjI2WPaJNRFouZJVqbU1SkOIDRKR26RTx+wHpto+J8MKlsfioPBPsaSJrQz1GkJyegD/i7i3SWnxcLG4PfizrfcRC617hbXDlQE6FtraoY7j3fU3NMBxHd4rb4gnaHV8nGSk+VOTRS+wEAxHxofXC5UiLctvLMRwwx/aLwkaODMip1NaWkomQVQb+XhF4KhkYTTvEbbGGcVcqoy4St41boKtJU1MMRbT4SBGk+FCRMQOx03YQAh7hccir1toaKSAHg+olKhqrfWEjiyWyeO17T1NTDEU4YFgHsQKZRVA2Tfwuj14GjhQfKczwc0Sfl2MHIjnzklOjTnqFY0XAoNaoAvLAhxDwaWuLUVAnPa09VyqjVbe9FB8D5qjq+dA43kNFFZByDAdOU+g6LJHiI/VwZUdKre/6t7a2GIXGreK2dJK2dqiUz4CMIvB0yHL5AyU8hpO1tUNFXbjqN0F3i6amGIaw50MnC5cqIPe9J4O/B4LfE9nw6mUMk4AUH9GMmy9ud72prR1GoSFUj0EN1tUaqxXGXi5+3/WWtrYYhQZVfEzV1g6V7LLQ35MCe5dpbY3+8XsjFWr1smuungv2dOisj4hbyclp2QNKQNS5yS7X2pqkEXfx8d3vfheLxdLnZ8IEnVwUp2NsSHwc/EgWyRkI6sJVphPxATBOio8B4+2JpGjqagyvELc1r2trhxForglluuTqpxOqIz3i/ah5Q1tbjIC6iSuZqH3AcBJJiOdj8uTJ1NfXh39WrFiRiLeJP0VjRWOtgFcGS50Ob3fEVagXzwfA6EtE7E7L7kgKqaR/mnaI2gIZRfqqLTD+SnG7+x2xs5ecnPpN4rZ8mr4WrvELxK0UkKenfrO4LZ+urR1JJiHiw263U1ZWFv4pKipKxNvEH4tFHr0MlMbtgCJaoWvV06U/0nIjsTu7ZezOKVHLmJdN0dfCVTFDiCFvJxw0yMZFK9SFS80w0QvjrgAsULcBOuq1tkbfNKjiQ2djmGASIj52795NRUUFo0aN4qabbuLQoUMnfa7H46Gjo6PPj6ZExwwEA9raomeiFy69Id32AyMc76GzMbRao8ZQuu1PiV4XrqwS0fAR5EbuVChKZAz1JiATTNzFx5w5c3j66ad58803eeKJJ9i/fz/nn38+nZ2d/T5/6dKl5Obmhn+qqqribVJsjDhf7J67m2TGxKnQ68IFMPHT4vbAClku/1SowYBlOgk2jUY9eql5Q2ZMnIxgMBIvoMeFSx69nJ62g+BuB6sjpTJdIAHiY8GCBXzmM59h2rRpzJ8/n9dff522tjb+9re/9fv8JUuW0N7eHv6pra2Nt0mxYXfChKvE79te1NYWPVO3Xtzq8Zwyf4Rw3StB2PmK1tbok4A/Kl5Ah2M46kJwZEJ7LRxZp7U1+qR1n+gCbE/TT52WaMZ/StzufQ96j2lri15Rj81KJoq1J4VIeKptXl4e48aNY8+ePf0+7nK5yMnJ6fOjOZOvFbfbX5ZHL/3hc0c8H6prVW9MulbcSgHZP0d3gK8HnNn6XLgc6TAh5P3Y8ndtbdErDSHxWDoZbHZtbemPkglQMlmU798hNwH9otdjsySQcPHR1dXF3r17KS83UP7yyAshLU8evZyMhi1iQskogrzhWlvTP6qAlEcv/aN6E4bN0Ed12v6YslDcbntRbgL643BoDMvP0NSMUzLlenG79R/a2qFXDq8RtxUztLVDA+IuPr72ta+xfPlyDhw4wMcff8x1112HzWbjxhtvjPdbJQ67EyaE4ga2/lNbW/TIkbXitnK2vrIkook+etnxL62t0R+HQ2M4TKeeK4DRl4r4q64G0elW0he1+V7VHG3tOBWqgNz/AXQ2amuL3ggGIgJSz2OYIOIuPg4fPsyNN97I+PHjueGGGygsLGTVqlUUFxfH+60Sy9TQRbP1H7K99/EYYeECmPIf4nbjX7W1Q4+ong+9HpuB2ARMvFr8vlUevfTB547E7FSdpa0tp6JgJAybJTYB21/S2hp90bRDpJM7s6BEJy0qkkjcxcdzzz1HXV0dHo+Hw4cP89xzzzF69Oh4v03iGXmhqBjobpPR2scT9nzM0taO0zHtBrDYhL1q4yYJeDojbdiH6XwMp4YE5NYXRUVWiaB+ozj6zCwRXj49o24CNslNQB9Uz1XlbP0efSYQ2dvlZFhtcMbnxe8b/qKtLXqis0F0/sUCFTO1tubUZJVE6kVslGMYpvYTQIHcatFLRc+MuADyqsHTDjte1toa/RA+cjlLv0efKtNuEKmkdRsi2R2S0HVISh65gBQfp0YVH3vfhfYj2tqiFw6EKk6WTYX0PE1NGRAzbhK3m56HgE9bW/SCOoYjztPWjoFgtcKMm8Xv6/6krS16wkgLV2ZRpPbOejmGYcKeDx0fmyUQKT5ORcEoGH4uoEjvh8r+D8TtyAu0tWOgjL0cMotF5pKstCg48KG4NYL4ACEgLVY49DEc3aW1NdoTDEQEpNpKQO/MvEXcbn5BHp8BtNXCsf3iWLjqTK2t0QQpPk7HrNvE7donZZMriFq4ztfWjoFic8CML4jfV/9WW1v0gKcTjoQKxI00yBjmVEQ6Tq97WlNTdEH9RhGL5so1TormyAtFWr6nHbbJDMJw49Jhs0RGVwoixcfpmHQNZJdDV6MsWNV+RFRVtFhhuEF2XABnflnsMA58KM+cD60CJSAWgrxqra0ZOGfeLm7X/1mUo05l9r4nbkeer8/iYv1htcLs/xS/r3xclszfFxrDURdpaoaWSPFxOuzOyMS36tepfdHsXy5uy88wllrPHRYpOpbq3g91x2UUr4fKmHmi94W3U8Z+qGNotIVr1q0irbRpO+xdprU12hEMwr7QXGq0MYwjUnwMhFm3gc0l3J2pXOxI7TA69jJt7RgMc74ibrf8TWTspCrqGI4x2BhaLDD3LvH76t+kbvCwtycSqDjqYm1tiZX0PJgZCh7++FeamqIpjVuhp1n0LqpMzXgPkOJjYGQWRTJflj+irS1a4feIrB+IpK8aiaozoepsCHjho59rbY02NO+G1r0i7XH0JVpbEzvTbhB1LTqOwJYXtLZGG/a+K/6G86qh0ID1k+bcKY5A970HdRu1tkYb1A3AyAtSrplcNFJ8DJTzF4tJe//y1PR+HPhQdNDMLtd3L4lTcdGD4nbtH1PT+6FOeiPOgzQdNHCMFbsL5n5V/L78kdT0fqgN2iZcpf/6Hv2RPzxScv29H2lri1aonbbV9OMURYqPgZJXHcmaeP9hbW3RAnXhGjdfBI8ZkVEXi7oIfndqej/UMRx/pbZ2DIWz7hDej2MHUi/93e+FXaExnHiVtrYMhYu+Kbwfu/8Nh1ZrbU1yOXZANOa02GDcAq2t0RSDriIaEe392POO1tYkj4APtr0kfp9gYLVusYiJD2DNk6FKrSlC++FIh+bxBp70nJlw/gPi9+WPplbfpQMfikyfzBJ993M5HYWjI8X/3v1BagXx73hV3I44FzILtbVFY6T4iIW8arHzAnjzW6nj9t37ngiQyigyfnT2qIvFZwh44N/f1tqa5LHlBUCB4edBXpXW1gyNWbdCTiV01sHHv9DamuSxJdRcb8KnjN8L5IJvgM0pBNXO17S2Jnls+Zu4VRsmpjBSfMTKhd+AjEJorhGxA6nA5ufE7dT/EEW7jIzFAvOXilolO16OVGw1M4oiysuDCNo0Oo40uPz74vcPfwLHDmprTzJwd0S6wqrB70YmrwrOuVv8/uaS1Kh6Wr9ZdCK2OSNxLymMFB+xkp4HF/+3+P3dH0JHnabmJJzeY7Az1NV32me1tSVelE6C2aHaLa89YH7Xfd0GOLpDpItPukZra+LD5OtFlV2/WyxeZmfbi+DrgaJx5knPPP8B4cFqPwQrfqq1NYln4zPidvyVkFGgrS06QIqPwTDrVlEW19MBr9xn7jPL9f8H/l4omWycUs4D4eJvQVYpNO+C5SYPIP7kd+J20tXGaAY4ECwWuPLHYLVDzWuRIwkzoiiRhmwzvmDMLJf+cGbCFUvF7yt+JrwCZsXbA5tD3scZX9TWFp0gxcdgsNrgml8L99nut2DTX7W2KDEE/PDJ78XvZ99pnkkPxM7j0z8Tv3/0czi8Tlt7EkVXE2z9h/hdLbRmFkomwgVfF7+/tti8XshDK+HIOuG5mn6j1tbEl4lXifiHoB/+eYd5vZAbnxFe5LzhMNpgxeEShBQfg6VkQiRz4rWvwdEabe1JBDteFi7RjEKY+hmtrYk/Ez4lPpcShL/fKiYHs/HJ70VRqsozoXKW1tbEn/MfEB45dzu8eKfo+Go21LTwMz4PWSXa2hJvLBaxCcgsgaM74e2HtLYo/gT8sDJU0XXuXcYPFo4TUnwMhXPvE1XqfN3wt5vB2621RfEjGIjUMznzy+BI19aeRHHl/4rdSNuh0OIV1Nqi+NHdIvoRQaQ0udmwOeC634EjQ6TAL/u+1hbFl/rNsOtNwBIJ0DQbmUVwTWhx/uS3keBos7D1HyKtP70gkmIskeJjSFhtsPBJyCoTqv0fXzbPzmvz30RGT3p+pKqkGUnPgxv+LFzau94UdQfMwkc/E1Vpy6aZO7WveFxk8froMfPEfygK/Pt/xO9TFhqznPpAGTc/coT2yj1wZL229sQLnzsyp8xdJOJcJIAUH0MnqySyeNW8Bm98w/gBqO6OyA7yvPuN1cF2MFScAVc9Jn5f8VNY8wctrYkPR3dFOvhe+pBxq9IOlCkL4Zx7xO8v3hlpO29kdr8tvDk2J1yaAjVpLloCYy8XGUzPfAaa92ht0dBZ/QS010LOMDjbxJu4QWDyGSlJVM+Bhb8HLGLheuc7xhYg7/5AFHAqGBUpqmZ2zvg8XPQt8ftrX4ONz2prz1BQFHj1fhHrMeYy0Y4+FZj3XZFKHPTBczfBoVVaWzR43B0iiBZgzn9B/ghNzUkKqie5fLooavh/1xq7CnHLXng/1Ij0km+DM0Nbe3SGFB/xYtI1IvUPRIDYm0uMGT+w991Ihsunf2beWI/+uPAbcOaXAAVe+qpxi8it/g0cXCHiID71E3NlKZ0Kqw2u/72oYOvrhj9fC7v+rbVVg+Pf/y12zHnD4cJvam1N8kjLgZv+AYVjxOd/cj40btPaqtgJ+MUc4u8Vf4/TP6e1RbpDio94ctaX4VOhYjmrn4C/fRE8XdraFAvtR0TcCgrMus34pdRjxWIRAahn/RcQ8h78+3+MFcdTuyYSJ3DZ90UX0VTC7oLP/TXkvu+F5240XgO69f8H6/8sfr/mcXBlaWtPsskqhltehZJJ0NUATy2A/R9qbVVsvP0Q1K4CZxZc/cvU2QDEgBQf8ebM28Xuy+aCna/CH+dD826trTo9vW3wzH8Id2fZVLjC5IW3TobFAgsegQsfFP/++Jfw189Bd7O2dg2Elr3C1qAfJl0b8uKkIM4M+NyzIo066Id/LYKX7zFGDYl9yyPHLRf/N4w8X1t7tCKnHG57XXShdrfDn68WhciM4E1e+0dY9bj4/donRE8wyQlI8ZEIpt0At74mctcbt8JvzofVv9PvhdPdAn+5Hpq2i6qfn/2L6J+RqlgsogLqf/wR7Gmi9fevz46UmdcjR2vgz9cI8Vg+XWR/pPJuS03BvehbgEVUCP3DpaJYl17Z/6EQjwGvKL51/te0tkhb0vPhiy+JwmpKEN75LjyzUN+9fNb/GV4NiccLviGqCkv6xaIo+oqM7OjoIDc3l/b2dnJycrQ2Z2h01MFLX4F974t/V54lPAp6KvbUsldElrfuhbQ8IZrKpmhtlX6o3ywqLx7dIf498Sq47AdQMFJbu6I5+DH89UZwt4mz8tveFK5riWDvu/CPL0FPi2goeNZ/hRpE6qi/xoa/iFYNQR+MvkQcHaXyBiAatbz8698Q3agdGcIzOedO/XxHwaAI1Fd71Jx1Byx4NOU2ALGs31J8JJpgMJIB4wt1bpx8HZy3GMqnaWeXosCG/4M3vimC83Kr4Qt/h+Lx2tmkV3xueO+HsPJxsQOzOUVMzDl3a9uePuCD5Y+Izq5KEIbNhs//DTILtbNJr3Q3iyBwtaW5K0ek5p71ZW373fQegze/BZtC2VWTroXrfpNagd4DpXkPvHKvCKYGkb564Tdg2ue0FSHHDsDLd0c6ZJ+3WGS3mD29vR+k+NAjHXWw7AeRSQbEDmfmzaLLod2VPFsOrQoFRK0W/x5+rkhxyylPng1GpHE7vPUt2BeqIWG1i8Vixhdg5IXJm2yCQRFP9M53hccKRMfhTz8m0/lOx55l8PZ3oHGL+LcjQ3x3Z94OpVOSt1P1e4SLfvmj0N0EWMRCeuE3U3LRGjDBoOil9d6PoOOIuC+jUDT7nHVrcuMrettE2fSVj4uNpT1d1AtK4cwWKT70TMMWETi17UWxWwVx3DHxKlHlb9RF4MqO//v2tkHNGyIY6vAn4j5HhuhPI/sNDBxFEcdoK34mCkCp5FaJMRx7uRBzdmf837uzAba9JDxpLaEg5owiuPJRUWRLMjCCQdj2T/jwp9AUlcZZNE6IyfELRFVYmz3+792yVzQZ2/CMyOQAKBwrslqq58T//cyKzy3mspWPQ8fhyP2VZ4qyB2MvF+MZbzGpKFC/UQjHzS+At1PcP/xckdVi5iq0A0CKDyPQuk+c8278qyjopWJ1iGyTytnCjV4yQRT7ikWQKIpw59ZvgsNr4dDHIpgt6Iu8xxk3imA86e0YPHUbxdHVlhdERL6KI1M0O6ucBcNmQdF4USQqFtewogixUbdBiMUDH8HhNUDocnXliOJT596bGLGaCiiKiJf55LdCmAe8kcec2TB8rojTKp0ExRPEGMYi0oNBcZ03bBJdk/e8I1oWqGSXwwVfgxk3J0aspgIBP9S8Dmt+H0rHjVrOMouFKKicLTogF0+EnIrYBEnAJ4K5GzYLj/GeZX3FTvEEkZU08aqUi+/oDyk+jEQwIM4Kd70Ju96CY/v7f15miRAK6QUiUM6RISZCi02ICncHeDrE2faxg+BpP/E1iifA5OuFezK7NKEfK6XwuUVGzO63REnsrsZ+nmQRZ9RZJWL80guEGLFYxRgGPJEx7GoS58hqjFA0lWeJFNIzbpSiI56426HmTdHJ+cCHfcWkitUB2WXiJ7NEHHE50kVafdAnjlJ8vWL8Ouugo16MazQWq/BuzrwldNwqRUfc6GyAHa+In9rVokz78djTQ2NYLmKjHBkio83mFGPl94oYuM5G6KwXP0H/ca+RJjpiz7wFRpwvj8mikOLDyBw7ILwVR9aJ5kote0T65GDIqxaLVeWZIr6keFxcTZX0QzAodreH18KRtcI70rpPiIpYsViF67jyTPEz9jKxc5MklmBApMgfWCGynZq2i93v8UJiINjToHSyOMYZdaEQHun5cTdZchx+j5g/D66Ahq2i8WfzblAGUTDQlSO80eVniHl0xLkyIPgkSPFhNtzt0Lpf7Kh6W6GnVeyKlaCYKK02cYGk5YiJLX+EKMssgw/1gaKIMTu2H7qPipTPnlaxmAWDoQwau4j9ScsVnpH8kSKORO6M9UEwIILGOxvEbrj7qPBy+HvFQmdzih97mkhzzq4QnsqcysTEjkhix+8VQarqGPa0CO+Ir1ccudlc4nqzpwvPsDqG2RXSuzFApPiQSCQSiUSSVGJZvxMm5x5//HFGjBhBWloac+bM4ZNPPknUW0kkEolEIjEQCREfzz//PIsXL+Y73/kO69evZ/r06cyfP5+mpqZEvJ1EIpFIJBIDkRDx8dOf/pQvf/nL3HbbbUyaNInf/OY3ZGRk8Mc/GrRFuUQikUgkkrgRd/Hh9XpZt24d8+bNi7yJ1cq8efNYuXLlCc/3eDx0dHT0+ZFIJBKJRGJe4i4+mpubCQQClJb2rSNRWlpKQ0PDCc9funQpubm54Z+qKg17ZUgkEolEIkk4mucPLVmyhPb29vBPbW2t1iZJJBKJRCJJIHFPQC8qKsJms9HY2LfKY2NjI2VlZSc83+Vy4XIlsamaRCKRSCQSTYm758PpdDJr1iyWLVsWvi8YDLJs2TLmzp0b77eTSCQSiURiMBJSem/x4sXccsstzJ49m7POOovHHnuM7u5ubrvttkS8nUQikUgkEgOREPHx2c9+lqNHj/LQQw/R0NDAGWecwZtvvnlCEKpEIpFIJJLUQ5ZXl0gkEolEMmR0UV5dIpFIJBKJpD+k+JBIJBKJRJJUdNfrWT0FkpVOJRKJRCIxDuq6PZBoDt2Jj87OTgBZ6VQikUgkEgPS2dlJbm7uKZ+ju4DTYDBIXV0d2dnZWCyWuL52R0cHVVVV1NbWpmQwa6p/fpDfQap/fpDfgfz8qf35IXHfgaIodHZ2UlFRgdV66qgO3Xk+rFYrlZWVCX2PnJyclP2jA/n5QX4Hqf75QX4H8vOn9ueHxHwHp/N4qMiAU4lEIpFIJElFig+JRCKRSCRJJaXEh8vl4jvf+U7KNrJL9c8P8jtI9c8P8juQnz+1Pz/o4zvQXcCpRCKRSCQSc5NSng+JRCKRSCTaI8WHRCKRSCSSpCLFh0QikUgkkqQixYdEIpFIJJKkkjLi4/HHH2fEiBGkpaUxZ84cPvnkE61NShhLly7lzDPPJDs7m5KSEq699lpqamr6PMftdrNo0SIKCwvJyspi4cKFNDY2amRxYnn44YexWCzcd9994fvM/vmPHDnCF77wBQoLC0lPT2fq1KmsXbs2/LiiKDz00EOUl5eTnp7OvHnz2L17t4YWx5dAIMC3v/1tRo4cSXp6OqNHj+YHP/hBn54TZvoOPvjgA6666ioqKiqwWCy89NJLfR4fyGdtbW3lpptuIicnh7y8PG6//Xa6urqS+CmGxqm+A5/Px4MPPsjUqVPJzMykoqKCm2++mbq6uj6vYeTv4HR/A9HceeedWCwWHnvssT73J/Pzp4T4eP7551m8eDHf+c53WL9+PdOnT2f+/Pk0NTVpbVpCWL58OYsWLWLVqlW8/fbb+Hw+Lr/8crq7u8PPuf/++3nllVd44YUXWL58OXV1dVx//fUaWp0Y1qxZw29/+1umTZvW534zf/5jx45x7rnn4nA4eOONN9i+fTs/+clPyM/PDz/n0Ucf5Re/+AW/+c1vWL16NZmZmcyfPx+3262h5fHjkUce4YknnuBXv/oVO3bs4JFHHuHRRx/ll7/8Zfg5ZvoOuru7mT59Oo8//ni/jw/ks950001s27aNt99+m1dffZUPPviAO+64I1kfYcic6jvo6elh/fr1fPvb32b9+vX885//pKamhquvvrrP84z8HZzub0DlxRdfZNWqVVRUVJzwWFI/v5ICnHXWWcqiRYvC/w4EAkpFRYWydOlSDa1KHk1NTQqgLF++XFEURWlra1McDofywgsvhJ+zY8cOBVBWrlyplZlxp7OzUxk7dqzy9ttvKxdeeKFy7733Kopi/s//4IMPKuedd95JHw8Gg0pZWZny4x//OHxfW1ub4nK5lL/+9a/JMDHhfOpTn1L+8z//s899119/vXLTTTcpimLu7wBQXnzxxfC/B/JZt2/frgDKmjVrws954403FIvFohw5ciRptseL47+D/vjkk08UQDl48KCiKOb6Dk72+Q8fPqwMGzZM2bp1qzJ8+HDlZz/7WfixZH9+03s+vF4v69atY968eeH7rFYr8+bNY+XKlRpaljza29sBKCgoAGDdunX4fL4+38mECROorq421XeyaNEiPvWpT/X5nGD+z//yyy8ze/ZsPvOZz1BSUsKMGTP4/e9/H358//79NDQ09Pn8ubm5zJkzxxSfH+Ccc85h2bJl7Nq1C4BNmzaxYsUKFixYAKTGd6AykM+6cuVK8vLymD17dvg58+bNw2q1snr16qTbnAza29uxWCzk5eUB5v8OgsEgX/ziF/n617/O5MmTT3g82Z9fd43l4k1zczOBQIDS0tI+95eWlrJz506NrEoewWCQ++67j3PPPZcpU6YA0NDQgNPpDF90KqWlpTQ0NGhgZfx57rnnWL9+PWvWrDnhMbN//n379vHEE0+wePFivvWtb7FmzRruuecenE4nt9xyS/gz9ndNmOHzA3zzm9+ko6ODCRMmYLPZCAQC/OhHP+Kmm24CSInvQGUgn7WhoYGSkpI+j9vtdgoKCkz3fYCI+XrwwQe58cYbw43VzP4dPPLII9jtdu65555+H0/25ze9+Eh1Fi1axNatW1mxYoXWpiSN2tpa7r33Xt5++23S0tK0NifpBINBZs+ezf/7f/8PgBkzZrB161Z+85vfcMstt2hsXXL429/+xjPPPMOzzz7L5MmT2bhxI/fddx8VFRUp8x1I+sfn83HDDTegKApPPPGE1uYkhXXr1vHzn/+c9evXY7FYtDYHSIGA06KiImw22wmZDI2NjZSVlWlkVXK46667ePXVV3nvvfeorKwM319WVobX66Wtra3P883ynaxbt46mpiZmzpyJ3W7HbrezfPlyfvGLX2C32yktLTX15y8vL2fSpEl97ps4cSKHDh0CCH9GM18TX//61/nmN7/J5z73OaZOncoXv/hF7r//fpYuXQqkxnegMpDPWlZWdkIAvt/vp7W11VTfhyo8Dh48yNtvv92nnbyZv4MPP/yQpqYmqqurw3PiwYMHeeCBBxgxYgSQ/M9vevHhdDqZNWsWy5YtC98XDAZZtmwZc+fO1dCyxKEoCnfddRcvvvgi7777LiNHjuzz+KxZs3A4HH2+k5qaGg4dOmSK7+TSSy9ly5YtbNy4Mfwze/ZsbrrppvDvZv7855577gmp1bt27WL48OEAjBw5krKysj6fv6Ojg9WrV5vi84PIbrBa+05vNpuNYDAIpMZ3oDKQzzp37lza2tpYt25d+DnvvvsuwWCQOXPmJN3mRKAKj927d/POO+9QWFjY53Ezfwdf/OIX2bx5c585saKigq9//eu89dZbgAafP+4hrDrkueeeU1wul/L0008r27dvV+644w4lLy9PaWho0Nq0hPCVr3xFyc3NVd5//32lvr4+/NPT0xN+zp133qlUV1cr7777rrJ27Vpl7ty5yty5czW0OrFEZ7soirk//yeffKLY7XblRz/6kbJ7927lmWeeUTIyMpS//OUv4ec8/PDDSl5envKvf/1L2bx5s3LNNdcoI0eOVHp7ezW0PH7ccsstyrBhw5RXX31V2b9/v/LPf/5TKSoqUr7xjW+En2Om76Czs1PZsGGDsmHDBgVQfvrTnyobNmwIZ3IM5LNeccUVyowZM5TVq1crK1asUMaOHavceOONWn2kmDnVd+D1epWrr75aqaysVDZu3NhnXvR4POHXMPJ3cLq/geM5PttFUZL7+VNCfCiKovzyl79UqqurFafTqZx11lnKqlWrtDYpYQD9/jz11FPh5/T29ipf/epXlfz8fCUjI0O57rrrlPr6eu2MTjDHiw+zf/5XXnlFmTJliuJyuZQJEyYov/vd7/o8HgwGlW9/+9tKaWmp4nK5lEsvvVSpqanRyNr409HRodx7771KdXW1kpaWpowaNUr57//+7z4LjZm+g/fee6/fa/6WW25RFGVgn7WlpUW58cYblaysLCUnJ0e57bbblM7OTg0+zeA41Xewf//+k86L7733Xvg1jPwdnO5v4Hj6Ex/J/PwWRYkq+SeRSCQSiUSSYEwf8yGRSCQSiURfSPEhkUgkEokkqUjxIZFIJBKJJKlI8SGRSCQSiSSpSPEhkUgkEokkqUjxIZFIJBKJJKlI8SGRSCQSiSSpSPEhkUgkEokkqUjxIZFIJBKJJKlI8SGRSCQSiSSpSPEhkUgkEokkqUjxIZFIJBKJJKn8fz9Z9IrPVEeNAAAAAElFTkSuQmCC\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "plt.plot(sol.ts, sol.ys[0], label=\"Prey\")\n", + "plt.plot(sol.ts, sol.ys[1], label=\"Predator\")\n", + "plt.legend()" + ] + } + ], + "metadata": { + "jupytext": { + "formats": "ipynb,py:light" + }, + "kernelspec": { + "display_name": "py38", + "language": "python", + "name": "py38" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.16" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/neural_sde.ipynb b/examples/neural_sde.ipynb index 9a8de6d3..b79967e6 100644 --- a/examples/neural_sde.ipynb +++ b/examples/neural_sde.ipynb @@ -56,21 +56,14 @@ "\n", "!!! danger \"Advanced example\"\n", "\n", - " This is a pretty advanced example." + " This is an advanced example, due to the complexity of the modelling techniques used." ] }, { "cell_type": "code", "execution_count": 1, "id": "350ecd31-c6f3-4cff-adbc-2f880c40f11a", - "metadata": { - "execution": { - "iopub.execute_input": "2022-02-04T16:29:28.459951Z", - "iopub.status.busy": "2022-02-04T16:29:28.458984Z", - "iopub.status.idle": "2022-02-04T16:29:30.627192Z", - "shell.execute_reply": "2022-02-04T16:29:30.626238Z" - } - }, + "metadata": {}, "outputs": [], "source": [ "from typing import Union\n", @@ -98,14 +91,7 @@ "cell_type": "code", "execution_count": 2, "id": "df41f97b-8b00-49c4-84fe-b35f340b7be5", - "metadata": { - "execution": { - "iopub.execute_input": "2022-02-04T16:29:30.632152Z", - "iopub.status.busy": "2022-02-04T16:29:30.631233Z", - "iopub.status.idle": "2022-02-04T16:29:30.633125Z", - "shell.execute_reply": "2022-02-04T16:29:30.633856Z" - } - }, + "metadata": {}, "outputs": [], "source": [ "def lipswish(x):\n", @@ -124,14 +110,7 @@ "cell_type": "code", "execution_count": 3, "id": "592dad43-7a89-4485-8b74-7855931d2526", - "metadata": { - "execution": { - "iopub.execute_input": "2022-02-04T16:29:30.643907Z", - "iopub.status.busy": "2022-02-04T16:29:30.643041Z", - "iopub.status.idle": "2022-02-04T16:29:31.035324Z", - "shell.execute_reply": "2022-02-04T16:29:31.036001Z" - } - }, + "metadata": {}, "outputs": [], "source": [ "class VectorField(eqx.Module):\n", @@ -213,14 +192,7 @@ "cell_type": "code", "execution_count": 4, "id": "a4c157fe-4c86-4e15-9020-b523b517ebce", - "metadata": { - "execution": { - "iopub.execute_input": "2022-02-04T16:29:31.050073Z", - "iopub.status.busy": "2022-02-04T16:29:31.049057Z", - "iopub.status.idle": "2022-02-04T16:29:31.227187Z", - "shell.execute_reply": "2022-02-04T16:29:31.228029Z" - } - }, + "metadata": {}, "outputs": [], "source": [ "class NeuralSDE(eqx.Module):\n", @@ -261,6 +233,7 @@ " def __call__(self, ts, *, key):\n", " t0 = ts[0]\n", " t1 = ts[-1]\n", + " # Very large dt0 for computational speed\n", " dt0 = 1.0\n", " init_key, bm_key = jrandom.split(key, 2)\n", " init = jrandom.normal(init_key, (self.initial_noise_size,))\n", @@ -274,11 +247,7 @@ " solver = diffrax.ReversibleHeun()\n", " y0 = self.initial(init)\n", " saveat = diffrax.SaveAt(ts=ts)\n", - " # We happen to know from our dataset that we're not going to take many steps.\n", - " # Specifying a smallest-possible upper bound speeds things up.\n", - " sol = diffrax.diffeqsolve(\n", - " terms, solver, t0, t1, dt0, y0, saveat=saveat, max_steps=64\n", - " )\n", + " sol = diffrax.diffeqsolve(terms, solver, t0, t1, dt0, y0, saveat=saveat)\n", " return jax.vmap(self.readout)(sol.ys)\n", "\n", "\n", @@ -322,9 +291,7 @@ " # The output at `t1` has seen the entire path of a sample. This is needed to\n", " # actually learn the evolving trajectory.\n", " saveat = diffrax.SaveAt(t0=True, t1=True)\n", - " sol = diffrax.diffeqsolve(\n", - " terms, solver, t0, t1, dt0, y0, saveat=saveat, max_steps=64\n", - " )\n", + " sol = diffrax.diffeqsolve(terms, solver, t0, t1, dt0, y0, saveat=saveat)\n", " return jax.vmap(self.readout)(sol.ys)\n", "\n", " @eqx.filter_jit\n", @@ -355,14 +322,7 @@ "cell_type": "code", "execution_count": 5, "id": "a181d457-2ff5-4eac-8943-ca9e83faeb26", - "metadata": { - "execution": { - "iopub.execute_input": "2022-02-04T16:29:31.238504Z", - "iopub.status.busy": "2022-02-04T16:29:31.237394Z", - "iopub.status.idle": "2022-02-04T16:29:31.519912Z", - "shell.execute_reply": "2022-02-04T16:29:31.520705Z" - } - }, + "metadata": {}, "outputs": [], "source": [ "@jax.jit\n", @@ -394,7 +354,7 @@ " ts = jnp.linspace(t0, t1, t_size)\n", " saveat = diffrax.SaveAt(ts=ts)\n", " sol = diffrax.diffeqsolve(\n", - " terms, solver, t0, t1, dt0, y0, saveat=saveat, adjoint=diffrax.NoAdjoint()\n", + " terms, solver, t0, t1, dt0, y0, saveat=saveat, adjoint=diffrax.DirectAdjoint()\n", " )\n", "\n", " # Make the data irregularly sampled\n", @@ -436,14 +396,7 @@ "cell_type": "code", "execution_count": 6, "id": "f7ec8e37-1aaa-4623-9601-9b21175708eb", - "metadata": { - "execution": { - "iopub.execute_input": "2022-02-04T16:29:31.530490Z", - "iopub.status.busy": "2022-02-04T16:29:31.529497Z", - "iopub.status.idle": "2022-02-04T16:29:31.634361Z", - "shell.execute_reply": "2022-02-04T16:29:31.635027Z" - } - }, + "metadata": {}, "outputs": [], "source": [ "@eqx.filter_jit\n", @@ -504,14 +457,7 @@ "cell_type": "code", "execution_count": 7, "id": "b0581722-97fb-4771-94da-c65f9929e0f1", - "metadata": { - "execution": { - "iopub.execute_input": "2022-02-04T16:29:31.650019Z", - "iopub.status.busy": "2022-02-04T16:29:31.648906Z", - "iopub.status.idle": "2022-02-04T16:29:31.705044Z", - "shell.execute_reply": "2022-02-04T16:29:31.705864Z" - } - }, + "metadata": {}, "outputs": [], "source": [ "def main(\n", @@ -528,7 +474,6 @@ " dataset_size=8192,\n", " seed=5678,\n", "):\n", - "\n", " key = jrandom.PRNGKey(seed)\n", " (\n", " data_key,\n", @@ -626,81 +571,74 @@ "execution_count": 8, "id": "f182fe77-e4d2-4094-88c5-926cf2b1f8dd", "metadata": { - "execution": { - "iopub.execute_input": "2022-02-04T16:29:31.711821Z", - "iopub.status.busy": "2022-02-04T16:29:31.710636Z", - "iopub.status.idle": "2022-02-04T21:47:38.622142Z", - "shell.execute_reply": "2022-02-04T21:47:38.623120Z" - } + "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Step: 0, Loss: 0.1617398304598672\n", - "Step: 200, Loss: 4.86433390208653\n", - "Step: 400, Loss: 7.129980427878244\n", - "Step: 600, Loss: 9.915551458086286\n", - "Step: 800, Loss: 13.451773507254464\n", - "Step: 1000, Loss: 8.164145742143903\n", - "Step: 1200, Loss: 5.45476382119315\n", - "Step: 1400, Loss: 2.8523939677647183\n", - "Step: 1600, Loss: 1.5683379343577795\n", - "Step: 1800, Loss: 0.5781421405928475\n", - "Step: 2000, Loss: 0.40823133076940266\n", - "Step: 2200, Loss: 0.842534065246582\n", - "Step: 2400, Loss: 1.0200202294758387\n", - "Step: 2600, Loss: 0.9040745667048863\n", - "Step: 2800, Loss: 0.9775767070906503\n", - "Step: 3000, Loss: 0.7866051537649972\n", - "Step: 3200, Loss: 1.1655586957931519\n", - "Step: 3400, Loss: 1.0307511942727225\n", - "Step: 3600, Loss: 1.2704946994781494\n", - "Step: 3800, Loss: 1.0042534044810705\n", - "Step: 4000, Loss: 1.5494119099208288\n", - "Step: 4200, Loss: 1.1781179734638758\n", - "Step: 4400, Loss: 1.4706323657717024\n", - "Step: 4600, Loss: 0.517096242734364\n", - "Step: 4800, Loss: -3.1678489616939\n", - "Step: 5000, Loss: -0.6181566289493016\n", - "Step: 5200, Loss: -1.2799221788133894\n", - "Step: 5400, Loss: 0.6105378525597709\n", - "Step: 5600, Loss: 5.683326925550189\n", - "Step: 5800, Loss: 2.9931929452078685\n", - "Step: 6000, Loss: 0.5538083400045123\n", - "Step: 6200, Loss: 0.30910458096436094\n", - "Step: 6400, Loss: -0.20523044999156678\n", - "Step: 6600, Loss: 0.6073118192808968\n", - "Step: 6800, Loss: 1.1460884383746557\n", - "Step: 7000, Loss: 0.9030835287911552\n", - "Step: 7200, Loss: 0.8061422961098808\n", - "Step: 7400, Loss: -0.16337597050837108\n", - "Step: 7600, Loss: 0.21688391161816462\n", - "Step: 7800, Loss: 0.32648008848939625\n", - "Step: 8000, Loss: 0.623529851436615\n", - "Step: 8200, Loss: 1.4328223807471139\n", - "Step: 8400, Loss: 0.6255699864455632\n", - "Step: 8600, Loss: 0.37481165677309036\n", - "Step: 8800, Loss: 0.4862654720033918\n", - "Step: 9000, Loss: 0.604121344430106\n", - "Step: 9200, Loss: 0.5833924242428371\n", - "Step: 9400, Loss: 1.328011427606855\n", - "Step: 9600, Loss: 0.37051604262420107\n", - "Step: 9800, Loss: -0.7500091024807521\n", - "Step: 9999, Loss: -2.032062990324838\n" + "Step: 0, Loss: 0.13390611750738962\n", + "Step: 200, Loss: 4.786926678248814\n", + "Step: 400, Loss: 7.736175605228969\n", + "Step: 600, Loss: 10.103722981044225\n", + "Step: 800, Loss: 11.831081799098424\n", + "Step: 1000, Loss: 7.418417045048305\n", + "Step: 1200, Loss: 6.938951356070382\n", + "Step: 1400, Loss: 2.881302390779768\n", + "Step: 1600, Loss: 1.5363099915640694\n", + "Step: 1800, Loss: 1.0079529796327864\n", + "Step: 2000, Loss: 0.936917781829834\n", + "Step: 2200, Loss: 0.9594544768333435\n", + "Step: 2400, Loss: 1.247592806816101\n", + "Step: 2600, Loss: 0.9021680951118469\n", + "Step: 2800, Loss: 0.861811808177403\n", + "Step: 3000, Loss: 1.1381437267575945\n", + "Step: 3200, Loss: 1.5369644505637032\n", + "Step: 3400, Loss: 1.3387839964457922\n", + "Step: 3600, Loss: 1.0477747491427831\n", + "Step: 3800, Loss: 1.7565655538014002\n", + "Step: 4000, Loss: 1.8188678196498327\n", + "Step: 4200, Loss: 1.4719816957201277\n", + "Step: 4400, Loss: 1.4189972026007516\n", + "Step: 4600, Loss: 0.6867345826966422\n", + "Step: 4800, Loss: 0.6138326355389186\n", + "Step: 5000, Loss: 0.5908999613353184\n", + "Step: 5200, Loss: 0.579599814755576\n", + "Step: 5400, Loss: -0.8964726499148777\n", + "Step: 5600, Loss: -4.22784035546439\n", + "Step: 5800, Loss: 1.8623723132269723\n", + "Step: 6000, Loss: -0.17913252328123366\n", + "Step: 6200, Loss: 1.2232166869299752\n", + "Step: 6400, Loss: 1.1680303982325964\n", + "Step: 6600, Loss: -0.5765694592680249\n", + "Step: 6800, Loss: 0.5931433950151715\n", + "Step: 7000, Loss: 0.12497492773192269\n", + "Step: 7200, Loss: 0.5957097922052655\n", + "Step: 7400, Loss: 0.33551327671323505\n", + "Step: 7600, Loss: 0.5243289640971592\n", + "Step: 7800, Loss: 0.797236042363303\n", + "Step: 8000, Loss: 0.5341930559703282\n", + "Step: 8200, Loss: 1.1995042221886771\n", + "Step: 8400, Loss: -0.5231874521289553\n", + "Step: 8600, Loss: -0.42040516648973736\n", + "Step: 8800, Loss: 1.384656548500061\n", + "Step: 9000, Loss: 1.4223246574401855\n", + "Step: 9200, Loss: 0.2646511915538992\n", + "Step: 9400, Loss: -0.046253203813518794\n", + "Step: 9600, Loss: 0.738983656678881\n", + "Step: 9800, Loss: 1.1247712458883012\n", + "Step: 9999, Loss: -0.44179755449295044\n" ] }, { "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAa4AAAEeCAYAAADSP/HvAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAEAAElEQVR4nOydZ5gc1ZWw39u5e6Yn55yzNCNpRjlnlBCInAwYsI3Tru3PYe1ds3jXaxvv2sYJG0zOUSjnHGdGmhykyTnHns6hvh89AiEkISEJgan3eeaZ7ls31a3qOnXOPfdcIUkSMjIyMjIyXxQU17sDMjIyMjIyl4MsuGRkZGRkvlDIgktGRkZG5guF6np3QEZGRuaLzIkTJ8JUKtUzQA6yMnC18QCVLpfroSlTpvSeSZQFl4yMjMwVoFKpnomIiMgMDQ0dUigUsrfbVcTj8Yi+vr6s7u7uZ4A1Z9LltwMZGRmZKyMnNDR0VBZaVx+FQiGFhoaO4NVmP0y/Tv2RkZGR+WdBIQuta8f42H5EVsmCS0ZGRkbmC4UsuGRkZGS+4CiVyikZGRlZqamp2QsXLkzp7+9Xfpp6nnzyyeD77rsv7mr372ojCy4ZGRmZLzhardZTW1tbXVdXVxUQEOB64oknQq93n64lsuCSkZGR+Sdi+vTp5o6ODg1AVVWVds6cOanZ2dmZU6ZMSS8pKdEBvPrqq/4TJ07MyMzMzJo5c2ZaW1vbF8rDXBZcMjIyMv8kuFwu9u7da1y7du0wwEMPPRT/l7/8pbWqqqrmiSeeaP/GN74RB7BkyZKx0tLS2pqamupbbrll8PHHH4+4rh2/TL5QUlbmwggh9gEvS5L0zFWsMxx4C5gE/F2SpO9frbovsf1m4CFJknZdhbrm4x2fmCut61O0nQA0AWpJklzXuK3HgBRJku65lu183hFC3I/33pl9ifmbx/PvEkL8G5AkSdJDn6btHQ0Yey2oz3zv72hKCAyLbleqNa7RgZ5QhVLl8g0IHrrU+sIMOJcmYzrfscbGxji1Wu202+2KjIyMrJ6eHnVycrJt7dq1oyMjI4qSkhLfW2+9NflMfofDIQCampo0a9eujenr61M7HA5FbGys/dOc6/XiSyu4xh/004EzD5IOSZLSzzp+F/A/QAiwE3hQkqTBz7qf15lHgH7AT/qCbSMghJCAVEmS6q93X2Qun2vxInapSJL0y0vJd6aPpaWlH0n/mJCZkPjBc6O+3uSn0WgccXHBl/0s6enpCR4YGAjJyso6dSYtKSmpFUCr1UbU1tZWm0wmxfz581N/9atfhT366KP9RqPRVVtbW31uXd/61rfivvvd73bffffdI5s2bTI+/vjjUZfbn+vJl91U+C1JknzH/84WWtnA34B7gXDAAvzlOvXxehIPVF9IaAkhvhQvPsLLl/23ctX4stw3Ho/nM2/TaDR6nnzyyda//OUv4Uaj0RMTE+N49tlnA8/05+jRo3oAk8mkjIuLcwI8//zzwZ95R68Q+cd4fu4GNkqSdECSpDHg34GbhRDGczOOP9R+J4ToFUKMCiEqhBA548dWCiFKxtPbxs04Z8olCCEkIcQD48eGhBBfF0IUCCHKhRDDQog/nZX/fiHEYSHEn4QQI0KIWiHEogudgBDiQSFEzXi924UQ8Z/U33PKPw98BfihEGJMCLFYCPGYEOJtIcTLQohR4H4hRJQQYoMQYlAIUS+EePisOh4TQrw1nt803laaEOIn4+23CSGWfsK1KBBCVI+fx3NCCN1Z9T883ubgeB+ixtMPjGcpG+/77WeV+f54211CiAcuMn77hBD/LYQ4jPfFJUkIkSGE2Dne3ikhxG1n5b/gtf4khBA/FkI0jI9RtRDiprOO3S+EOCSE+O34GDQJIW4463iiEGL/eNmdeC0EF2vrh+Pn3imEeGj8HkwZP6Ydb6dVCNEjhHhKCKEfPzZfCNF+ofG7xLI/EkJ0A88JIQKFEJuEEH3j57VJCBEznv+/gTnAn8av35/G0y82/sHj98CoEKIQ+MA8doFxuFcI0SKEGBBC/PScY48JIV4e/6wbv38HhPc3WSSECD+7j93d3XFNTU1xAMXFxVO6urpCy8vLcyoqKiacSbNardoz9btcLlVNTU3qyZMnJ1VXV6fbbDYNgM1m0xQXF085W+DV1NSkd3d3h5jNZl17e3u8xWLxPXny5KSTJ0/mAdTX1ye0trZ+oC11d3eHGI3G5JSUFPVvf/vbzBdeeKHtueeeC0lPT89KSUmZ8tprr0WXl5fnPPzww9o777wzPTs7OzM4ONgFiOrq6vTBwcE4m80WXFdXl3Sx8bvefCnefC7C/wghfgWcAn4qSdK+8fRs4MiZTJIkNQghHEAacOKcOpYCc8ePjQAZwPD4MTNwH1CFN2TJTiFEqSRJ688qPw1IHa9jA7ANWAyogRIhxFuSJO0/K+/beB9ONwPvCiESzzVhCiFuBP4NWA3UAT8GXgNmfkJ/P0CSpPuFEADtkiT9bLze2cCNwK3j56UFtgOVQNR4XTuFEA2SJO0Zr2r1eJn7gWfH8z8DRI+n/Q1IPLf9s7gbWDY+lhuBnwE/E0IsxGvKXTo+vr8FXgfmSpI0V3hNhblnTIXCO8cVAfiPt70EeFsIsV6SpAvNN9wL3ID3/vAZP8//GE+bMH6ulZIkVXNp1/pCNOB9CHbjHduXhRApkiR1jR+fBryA97o/AvxDCBE9rgm/ChwdH4dpwGbg/fM1IoRYDnwPWIR3zu3v52T5Fd4Hfh7gHK/7P4CfjB+/2PhdStkgvFq8AjAAzwG3AUq898afgLWSJP1UCDGLs0yFQggfvCb7C43/nwEbEIn3fto+fo7nG4cs4K/ACuA43vvoQnOfXxk/51jAPn5+1rP7GBER8bPExMT+MwWGh4cDMjMzaxQKxXlVruHh4eDk5OQ6o9FobmlpiWlsbEw82/x3Pnx8fGwxMTEt55oKz2CxWEqGh4eNXV1d0SkpKXX79u2ztrS0xFit1piDBw+eAq8ANRqNJCcn16SlpSkXLFiQGR8f3xEUFDRaV1eXqNfrrT//+c9PeTweMTY25nOx/lxvvswa14+AJLw/wr8DG4UQZ97SfPE+1M9mBPiYxoX3R2rE+9AWkiTVnHngSJK0T5KkCkmSPJIkleMVHvPOKf8LSZJskiTtwPvwe02SpF5JkjqAg3gdI87QC/xekiSnJElv4H2grjxPn74O/M94X1zAL4E84dW6LtjfS+SoJEnrJUny4H2QzgJ+NH4OpXiF0n1n5T8oSdL28X68BYQCv5IkyYlX0CQIIQIu0t6fJElqGxfO/w3cOZ5+N/CsJEknJUmy431AzhBeR4gL4QQeHx+/LcAYkH6R/M9LklQ13vflQLMkSc9JkuSSJKkEeAevoLnUa31eJEl6S5KkzvGyb+B92Zh6VpYWSZKeliTJjVeARQLhQog4oAD4d0mS7JIkHcAr3C/EbcBz4+dkAR47c0B431IeAf5VkqRBSZJMeO+bO84qf97xu8SyHuDn4/20SpI0IEnSO5IkWcbz//cnjNcqLjD+QgglsA74D0mSzJIkVY6P04W4Bdg0blGx47WoXMiu5wSC8Tq8uCVJOiFJ0uhF6iYyMrJbrVa7lUrleU3sRqNxxN/ff0yhUEixsbEdFovF12azqc+X93IYGBgICgoKGjAajZYzdVutVp8zGh1AREREt0qlcut0Ooevr6/JYrEYwDsn7HA4tA6HQ61UKiV/f/+xK+3PteRLK7gkSTouSZJp/If0AnAY7xsYeH+QfucU8YOPe/aMaxZ/wvvG1yuE+LsQwg9ACDFNCLF33BwyglegnGvK6Tnrs/U8333P+t5xznxTC15N51zigT+MmzaGgUFAANEX6+8l0nbW5yjgzIPq7D5Fn/X93PPpH38An/kOHz3Hi7V39vlGjX8HYNykO3BO2+cycI5Xn+Uy2o4Hpp0Z0/FxvRuvJnGp1/q8CCHuE0KUnlVvzjllu898GBc4jPc7ChiSJMl8Vt4WLkzUOed09udQvFrQibP6sW08/QwXGr9LKdsnSZLtrHM2CCH+Nm6uGwUOAAHjQuh8XGz8Q/Faj869Vy5pHMbHb+ACeV/Cq729Lrzm1d8IIS4qZDQajeNix9Vq9QfHVSqVR6lUuhwOh+ZiZS4Fp9Op0Wg0H3gHjtftttvtH/RXo9E4z3wWQng8Ho8CIC4url2SJGpqajIrKiqye3p6PtfzXl9awXUeJLwPd/Cae3LPHBBCJOE1i50+b0FJelKSpClAFl4T3P8bP/QqXvNfrCRJ/sBTZ7XxaYgef7s9QxzQeZ58bcDXJEkKOOtPL0nSkU/o76VwtuDsBILER+f+4oCOy6jvk4g9p+4z59uJ92EGfGBKCr7KbZ99rm3A/nPG1FeSpG+MH/9U13pcC34a+BYQLElSAF6T5KXcJ11A4Pi5n+Fi4Xq6+KhJ7Oyx7cf7IpF91vn5S5J0McF+OWXP1T6+j1fbnSZJkh9e8zV8eN7n5r/Y+Pfh9Q4+9165EF1n5xVCGPDeOx9jXLv8T0mSsvCa2lfxoUXhQk5LF2naK2DOfHa5XAq3263SaDSOM6bFM8JkPO8lT+eo1WqHw+H4YC7N7XYr3G63UqvVOi9WDkCj0biSk5Nb8vLyyuPi4lra29vjz56X+7zxpRRcQogAIcSy8YlXlRDibrw/nG3jWV4BVgsh5ow/FB4H3j1HszhTV8H427Yar6nPxodmByNejcQmhJgK3HWFXQ8DviOEUAshbgUygS3nyfcU8BPh9Y5ECOE/nv+T+ntZSJLUhncu8H/Gx3Ii8FXg5U9T3wX4phAiRggRBPwUeGM8/TXgASFEnhBCi9c0dVySpObx4z14TcFXi01A2vikvnr8r0AIkTl+/NNeax+8D8A+AOF1ePiYs8z5kCSpBSgG/lMIoRmfg1x9kSJv4h2zzPGH9b+fVZcHrwD9nRAibLwv0UKIZZfQj09T1ohX2A2PX9ufn3P83Ot3wfEf1+DfBR4b1+Sy8M5NXYi3gVVCiNlCCA3e3/d5n4VCiAVCiAnjmuAoXtPhmd/Lp7rHTCaT/8jIiK/H4xHt7e3RBoPBrNPpnBqNxqVSqZx9fX3BkiTR09MT7HQ6PxAearXa6XQ6NR6P57ySMSgoaHBwcDB4bGxM7/F4RGtra7RerzfrdLqLaoAA/f39gWc0M5VKdUar/twugflSCi68jg//hfdh0Q98G++k8GkASZKq8Jp6XsE7r2QEHr1AXX54f7RDeM0TA8AT48ceBR4XQpjwTiq/eYX9Po7XkaMf75zALZIkfczEIUnSe8Cv8Zo3RvG+wZ/xRLtYfz8NdwIJeDWg9/DOY1zxguGzeBXYATTidWL4L4DxNv4d7zxHF17HgLPnVB4DXhg3K93GFTL+0rJ0vI1OvOa7X+PVxOFTXmvJ61jwv3gdLHrwOh0cvoyu3YXXKWMQ78P/xYu0tRV4EtgL1APHxg+dMS/96Ez6+H2zi4vPAZ7N5Zb9PaDHey8f48OXxjP8AbhFeD0On7yE8f8WXrNlN/A8XseP8zL++/4m3nurC+9vof0C2SPwCrpRoAbYj9d8+EEfu7u7Y5uammIvUP5jBAQEDHR2dkaWlpbmWSwWQ2JiYuOZY3Fxcc29vb0RJSUleVarVa/X681nlTPpdDprWVlZbklJSe659QYGBpoiIyM7Gxsbk8vKynIdDoc2OTm58dx858NsNvvU1NRknjx5clJDQ0NKdHR0q16vdwBUVFRk9/b2Bl3q+X0WCOmLta70S4u4zEgAMjKfxLi2WAlopWsc0eOfmbKysubc3Nz+T84p82kpKysLyc3NTTjz/cuqccnIfCkRQtwkvGuuAvFqLBtlofXFp62tTbV69erEmJiYCdnZ2Zl5eXkZL774YsD16MumTZuMO3fuvGx3+ujo6AldXV2XNKcnCy4ZmS8XX8Nr/m4A3MA3Lp5d5vOOx+Nh9erVKXPmzBlrb2+vqKqqqnnzzTcb29rarthT8UI4nRf299izZ4/x4MGDl+LU86mRBdcXBEmSnpfNhDJXiiRJy8c9/oIkSbpJurw1fDKfQzZu3GhUq9XSD3/4w74zaWlpaY6f/vSnvS6Xi6997WsxOTk5mWlpaVlPPPFECHi1oqlTp6YvX748KTExMXvNmjWJZyJ2HDx40FBQUJCenZ2dOXv27NSWlhY1wNSpU9MffPDB2JycnMz/+q//Cj/f1iinTp3SvPjii6FPPfVUeEZGRta2bdt8Ozs7VcuWLUvOycnJzMnJydyxY4cPQHd3t3LWrFmpKSkp2bfffnv85UxbfdkjZ8jIyMh8oamoqNBPnDjRcr5jv//970P8/f3dlZWVNVarVRQUFGSsXr16FKCmpkZfWlramJCQ4JwyZUrGzp07fefPn2/+zne+E7d58+b6qKgo19NPPx34gx/8IPqtt95qBm90+crKyhqAvr4+5R133FGrUCj4v//7v5DHH3884umnn26/7777+nx9fd2PP/54D8Dq1asTv/e97/UsW7ZsrK6uTrNs2bLUxsbGqh//+MdRM2bMGPvtb3/b9frrr/u/+eabl7TuET7ngiskJERKSEi43t2QkZGRuSC/+c1vqK6u/nBN4YkaxOB5dyH5VEhBRqQpmRc8brFYcDqdVFdXhwH84he/4OTJk6jVaqKiojh9+jQbNmyIAhgbG+PgwYMT1Go1EyZMwG63Tzx16hQZGRmUlpam22w26urqWLhwYS54zZChoaFUV1cHAyxZssS3uro6FOD06dM88cQT9PX14XQ6iY6Oprq6Otxut6NUKqmuro4BOHLkCHV1dR94JZrNZoqKiqYcO3aMP/zhD1RXV0elp6e7/Pz83Oee24X4XAuuhIQEiouLr3c3ZGRkZC5ITU0NmZlnCZasrM+0/UWLFvH444+TNd7ua6+9Rn9/P/n5+RiNRv72t7+xbNlHl9Tt27ePoKCgD8qEhIQQHh5OUlISEyZM4OjRox9rx2AwkJ2d/UGZRx99lB/96EesWbOGffv28dhjj5GVlUVoaCi+vr4f5BNCUFpaik6n+0h9Op2O1NRUkpKSqKysdPDh0oZP5JLnuIQQzwpvVOjKs9IeE0J0jIerKRVCrLhA2eXCG825Xgjx40ttU0ZGRkbm4ixcuBCbzcZf//rXD9IsFq/lcNmyZfz1r3/9wJni9OnTmM3m89YDkJ6eTl9f3weCy+l0UlVVdd68IyMjREd7I6y98MKHoSGNRiMm04ca59KlS/njH//4wfcz+5fNnTuXV199FYD9+/crR0dHLxTu62NcjnPG83gDjZ7L7yRJyhv/+1gUh/EV53/GuwA2C7hzfGW7jIyMjMwVIoRg/fr17N+/n8TERKZOncpXvvIVfv3rX/PQQw+RlZXF5MmTycnJ4Wtf+xou14VXP2g0Gt5++21+9KMfkZubS15eHkeOHDlv3scee4xbb72VKVOmEBLy4fTU6tWree+998jLy+PgwYM8+eSTFBcXM3HiRLKysnjqqacA+PnPf86BAwfIzs5m9+7dysjIyE+M8PHBOV+OJ4fwRt7eJEnSmf2mHgPGJEn67UXKzAAekyRp2fj3nwBIkvQ/n9Refn6+JJsKZWRkPs98zFQoc9lUVlZacnJyai50/FosQP6W8G58+Oz4osZzieajUZvbuXgEbxkZGRkZmQtypYLrr3y4eVwX3phrV4QQ4hEhRLEQorivr++TC8jIyMjIfKm4IsElSVLP+OZqZ6JDTz1Ptg4+ut1ADBfZekKSpL9LkpQvSVJ+aGjohbLJyMjIfG6QY75eO8aj4X9kB4srElxCiMizvt6EN2DnuRQBqUKIxPEtBO7Au2+RjIyMzBcenU7HwMCALLyuAR6PR/T19flzjmy55HVcQojXgPlAiBCiHe8WCvOFEHl4921pxhsHDSFEFPCMJEkrJElyCSG+hXcXUSXe7dbP718pIyMj8wUjJiaG9vZ25KmNT093d7fK7XafL3KGB6h0uVwPnZ34ud7WRPYqlJGRkbn+mB3QPvrxdH8dRFyFcLpCiBOSJOVfan45yK6MjIyMzEXZ1gBuCTzn+bsefK5DPsnIyMjIXF96zOCvhazPka+crHHJyMjIyFyQQ60wJ+569+KjyIJLRkZGRua89JnBTwvaz5ltThZcMjIyMv/kPF8GrSOXX+7g51DbAnmOS0ZGRuafmi6TRNpwG9Xd0bSPKpkZe+G8tpIaHLVNCI0akx0ihsFZDU6XG0VwAD6Lp39m/b4YsuCSkZGR+SfFY7XT8MwusiaGoyosoy04lndNuaxOF6jP2kTE1TOAZccRNBNS8bvTuzvVzlpYkQK6cSlhK6rEsq8Iw/wCANzDJhxV9ehnTfqsT0sWXDIyMjL/jDjburHsOU7vrEXMnupdbJVa10LggU1sLktg9pocgtUuzNsOgVKJ8fblCI0agAEL+Kg/FFoAuoIcrEdKsR4uQTsxjbH3duF318rrcWqy4JKRkZH5Z8N66CTuYROmG1cTYfrQlUGTGk9UajyLKxs5/ufNpAdDxJpZqEI/urHH/lavtnUu+pl5jO04zOATzxH0k4cRWs21PpXzIjtnyMjIyPyTINkdmN7YhsLfiO+qeZT2KpgU8fF8vjlJzPreKkrmrPqY0Bq0gkH1UW3rDB6rHXdHL7ppExktr2dP0zU6kU9A1rhkZGRkvuBIkoTtaBnO5g58ls9GGeSPJIHNBXr1+cvoVGB3gSSBEB+m72+BZcnnacPuwPT6Foy3LEVh9GHvX/eS5aqAb99xbU7qIsiCS0ZGRuYLjKOhFdvhUnTTJ6KfmfdBetsoxPpfvGxiIDQOQ/K40mV3gUKA4RxhJzmcjL6+Fd+1i1AYfegxg0YF/jPyzq3yM0E2FcrIyMh8AXEPjjD6+lZcnX0Y712NJi3hI8fLeiA3/OJ1TAiD8p6PljmfaXFsw158V85FGegHwIEqM2lDzWBzXNlJfEpkjUtGRkbmC4bkcDL2/h6Md65AodN+/LgEDvf556nORqP8MFiuQngXKRdEfTSPvaoeVVQYyhCvWlbTB+n7NqNNiEQ/feLVOqXLQta4ZGRkZK4DTvenL2vZcxyfG+acV2gBNA9D/CeYCc+QHgy1/d6tS/Tqj853eWx27Cdr0M3I9X6XoOpEBxG9rRhvXfbpT+AKkTUuGRkZmc+YQ63Qa/ZqOWcjAWqFNz6gv3b8vw4CdB/m9ZjMeCw2VBHn23fRS0UvLD+PO/v5yAyB9095127lR370mHnTfnxWz0eMS7Mj9XYmrX+R4F8++kHa9UAWXDIyMjKfId1jMGKHmzPPf9zpBpMDRmwwaofWUTDZYVWa97h5+2F8ls++YP1nzIQa5QWzfASlwqtldZpgTvyH6R+YCAOMAFidEu5X1hO6YNK416KEq7kTdWL0pTV0FZFNhTIyMjKfES4P7Go8v7v5GdRKCNJ7Pf5yI2BevFewjNjA1dWHwuiDwtdwwfINQ5AcdHn9iveHYfuH3881EQIUPrOXtGDwXT0fAOueQiT79XHOkAWXjIyMzGfEtnpYkgyqy3zyLkiAPc1g2VOIYeE0wKtZnY/KXsi5zE0fBywfdeQ410TYteU4wteAX0oUCoMOR30rktuFJiPx8hq6SsimQhkZGZnPgNp+71xVuM/llzWoIbC1EUtMLH5qFWYHvFoJweOKV4QPpAVDsB7cEh8JoHspjNghUAdWJxTta6Bg3EQoSRKD6/dx1B3FAlUT+gUL8JjM2I6WYbxnFVV9kH0ddkaWNS4ZGRmZa4zZASXdMOsiW4pcDEmSmNBRyeHQCQBsqYc7c+DmDLgpHRIDoKoP3qqGiedZu9U9duG6e8wQ6gN5EfB+jRvLwZN0ZuQiud0Mv7aNvZpkFhYEoQn282538u4ufNctYdAq6DV/uvO5UmSNS0ZGRuYaIkmwqQ5Wp33U1fxysB0pxTgnjyCVYFu914Xddzy+rRAQafT+nY9hG/ylGH4488MyZ3OiE2bp+lHsLYetzUwJMXFkVzXGom1UGuOZmVyL869VKCZl0v8ff0IdG4l52yHqeyUyglWQsODTndQVIAsuGRkZmWvI0XaYGHZ+oXEpSA4nzuZO9LMmkWuFXx+GXy++9PIHWuDrU7wu+Ge7yHvGLFiLqtCd6EU9KYS+aTPQFg1hE1bSt7/P+pu+yuJZ4XSYJYa63CRHJRIVHYZ+9mQABjY3kht6fVSuSxZcQohngVVAryRJOeNpTwCrAQfQADwgSdLweco2AybADbgkScq/4p7LyMjIfM7pMcOAlYvuOvxJmLcfxrB0JpIE2xu8wqd5GBICPrnsmMOrkUUZ4VjNGGPHm5DausDjQfgYGEhKx+fOAgzRUHKgn+ThFrarYwmOiaZICieszcGEwXrSojwcL2oi6V+XA2B3SYgTZVSkRFAw9dOf26flcua4ngeWn5O2E8iRJGkicBr4yUXKL5AkKU8WWjIyMl8GXB7Y0QA3XOJC4PPW0TcEgCo0kKPt3jiC8+LhWPulld/fAnPjYPSlDeS1llJLML43L8Z423J8V86lTBFObrjXqzBs6yZMdgg0qnG0dvP9hneIPXGYgKNH0CTHEX7rYir7vPUe2noaj8XGpCUZn/7kroBLFlySJB0ABs9J2yFJkmv86zEg5ir2TUZGRuYLy/YGWJJ0+a7vZ2PZfhifZbMYtEKfBTJCvBE0UoLgVP/Fy9pcYHeDtaOP14YiiV47mwbfKITC2yGPR8LpltCp4FDpILojx1FMziF6VgbLfv8ArZKRuoIFKOKjceRNALWKrXXwXrVEw55KopJCeL078OKduEZcTa/CB4GtFzgmATuEECeEEI9cxTZlZGRkPlMkySs0rM4L5zk9AH4aiPD99O3YK+rQpMWDWs22+o9qblMi4UTXhddygXdOK9IHtr5VSeScbH550OuK3+BV4qh/7xgZuzdi6h/D//HfMFIwDf+qCrLNbeibm0kJUaBtb6NSGcbvj8HpQQjQQ8+hSmI0dhpyZzAx7NOf35VwVQSXEOKngAt45QJZZkuSNBm4AfimEGLuRep6RAhRLIQo7uvruxrdk5GRkbkqtI3C61XeUEzrT0H1eR5RFqdXqMyO+/TtSC4X9pIatAU5HGiFaTEfDeEkhLf+Pc3nL+9wec2Jtf0Sub5W1kwykBUGjYOws9Gbp6VpmNRbZ1F9179jGOwl2uAi8/ffxbhuCc7mTiLTwkjev5U3lZlE+nq9E5WSh9iWasIyonAYfKns/fTneCVcseASQtyP12njbkk6v/yXJKlj/H8v8B5wwek8SZL+LklSviRJ+aGh12Flm4yMjMw5jNjg3RpoGITbsqAgGu7I9sYUfLfWa5Y7w8bTsCr107u+A1h2H8ewaBq9FsGY48ONHs8mzt8b17DT9NF0lwd+e9QbPSNsqIvMyd59SlamejWmIQu8vbefkAAV7T/7C62qQIIzY4hZlIsVFe/UwK60+Rw/3M7Q6U5GlXrW13oXKGc0luKwuanOnkHTEITqL6LyXUOuyB1eCLEc+CEwT5IkywXy+AAKSZJM45+XAo9fSbsyMjIynwWS5NVQHG6vN9/ZOwMLAdOivZrIe7UwOQKGbJATBsbz7zZySbiHTXgsVhRR4eyqhNuzL5x3SRK8Vgl35Hjn0kbteIWMHswuWDVUjXblHFwdPSibO4mtGiBL5aF0ezV2Hxd/Sb+RB/V7CJmRgrV7iH2b6lm+MgUfjaBPM4WW48e5X1HN75XZ7KpzMfvwacp9EzhWpSVxoImZ6tPw3c9+e5NL1riEEK8BR4F0IUS7EOKrwJ8AI7BTCFEqhHhqPG+UEGLLeNFw4JAQogwoBDZLkrTtqp6FjIyMzDWgedgbpmlV2se3sz9DgM6rfQ1YvRrYlYZAMm87hM+y2exugrnxF3fuUCpgUaI3cG/riFcrDPPxamKrUyU8fYOMvb2TsS0HUCVEMfmB+WzPWoIrI4WiiByy2qsJ9VOgXL6ALZlLWaDrQ1F4ErcHirafIut338b2wnv8Y6mdn6lOEG7wELBoKgmjHawteoeJi66PV+Ela1ySJN15nuR/XCBvJ7Bi/HMjkHu+fDIyMjKfZ0p7vCa2T0KIK1urdQZHXQvqmHA6XTokCWL9Lpx3+Ol38H/gRiKNKjac9q4ZM2rgpXL4ZYEJ99MbQKHAsHg6Y+/twlFVj090OKqWFuw2N5JWR6YY5FTqFJrafbg5A3wmzsBefpqjf9hMjo+LBmUIpoKpjPzyb/Q41XgKChho7OWO/c+zYcUjtJlD+c6Vn/ZlI0fOkJGRkTkPLo/XVHip+1pdKZIkYTtahv7u1eyv9sYivBDO5g7M2w6CQsHhuTcxJRJO1pnpO1HD8ooSfNoMCJ0avwduwrL1EH73rsF2sprCHacIqG9BqRVYR0eIcw/zdPxMlvm5UdXUM1rTSOewB4PkQFvXxJ7dzay9axptd71HoHCwM6KAvD1P88SK7yI0AagHL9zHa4ksuGRkZGTOQ1XvZxv53HrwBLpZk9jZKFic+PHdkc9m5Ln1+P33v1L536+SNfI6ITEBNHQa6K3uQVq9hAY/J8mHdzK2fg+24+VITiddY4L2vcdYHO6gyqxD2zXCO1Nmk3diN8U7XexMTGMwZjl9WsHXmrdQZIgn/uheTFtbsQWHYpyZxNw3nqJp1TpumOyLnwFMF1kScC2RBZeMjIzMeagbhHUX2KX4auOx2HB19tE9IR/DCIRfZP2Xo6UTR2UdJ7fUkDMpFoMOOmcuoHZrB4YsXzrCEjD21ZF38xIkqw1tXgbv2ONxuiWydEbGXnsDo8lJn280YeE+hM2fxqQQLRtPQ4wa1qbDkT8KjuffwIw9b9E96EA5NEzTez0UTllEU8IUbj68haqEPCZOjgY+I5X0LGTBJSMjI3MOFqd3Y8UrcWn/JBxu6BiFlhHwbDrE2LTZ7CuCh/I+ms9lsmAfHkNts2IrrsJWXIXn9hsxTVhAePQYY1sOcvjPO0gQEvP+32oah+HUtgb2Zi8h5eg2XkxZQUowlPQIFloH6ZS0WG12+MOP8SCYkexdnxZqgKp+2Fg4gtOmZEH5djZNv5E4QyFBI/3syFzEjbXbmDBaTvzDqxl89TDRr+2G/3zg2g3SBZAFl4yMjMw5nOzyRqe4VrSOeOMI5oRBjtSHLlmHapIRRSPUD0r42MeItA7g6Bqg9I2jKPsHSH5oBSFrF2I/Uc3BlNlUd8OCBF9OdAmizT0MBEUQVn6C8CB/RvwU7Dg+RMVQCNnhgqlRkKSz0vhuA8r2HvZ953H+pX4fvXMXsOE02F1e78UbkuHQHw4zzWjiz4k3M7H+JE2KQDrTEqlWh6OaexcPb/4dm4yx3BYrEXn32ms3SBdBFlwyMjIy59A9dmWRLz6JE13e9VkaJYw8ewhlbhr1L+wkQ+EhzFdwdNQXZVYwxz1xTM3vJvSWr3LozztJHbPhl5ZAaa/ge9JxCv8yzFhlI1GWfvx1YA3KpnRTOaJ7iIUmE55bVhEVCE8WwteHihk7dIjtP/gVqqRYFHoT6k076GwUxBglJkVAWaebjJrjPH7rz/FXCFQD/WicKhozJmF0Qr9Tjev7j7Lq2X+gD9LiTo9EFfLZxyuUBZeMjIzMWQxavVEirhV2uwttYyuOmibGGloRajVKow+npy1gbbYKhYD5LvjxbrjNVcXp2CxOdOqJnD+F7l/9meJVa1lasgW/u6fx2mgg31XaaRyNY/63l1D2t+3kxqmpX3c7fb9+lvBDe7HMmckD2Qb6b3+VwcRMeiMTeawAni7Nwp6UhT4NSp2wzwauxlImPfwQkw0Kagqb0SXHEtHaiSdExc1R3hBXv6305YXEEDwDw9hPVqPNvoLw95+SqxlkV0ZGRuYLT2GHN6TT1cQ9NIplfzGmN7ZR88IeMv2dGJbNQhUeQsB37kadmoCkVOFwQ1EH/Mc+yAqBxsJmJsxIYE06RA534eswU3O8mZK5q3mpL5yVzQfYkzYPTXYKg9Vt+M6aiLOkBtufX8I3LoxWjLB+G6af/i9+SjddX32YYD3sbvaaK8t6IDnIG8qqYxSmejrIKoil3QSPamo4OuyDJzGe+/O868SC9BCih6frfGlqGKJ+0eqrO1CXiCy4ZGRkZMaRJG/0C78rCNn0kfpcLkZf2YytsAJNegLG25dTN3sp8flJjPzjXVRRoTgb2ug43UO4Y4Q3TtgpbnXx3Xw390f2snxWCH8uEhzZWovmnY2YVq8gLyuQOfoBVI1NTMoJpEvlj39uCj2v7yC1uZIDax4ke0o0/Tevoyo0ncbECSSnBtETmYhvajT3T4KuUa/XJJI3aHByICRbOrlxYSQtI+BrH2NYqSdrqAG/vFT+ftIbOSTCF/6Q2Uas1s6hlFnUP7PlE8fgWiCbCmVkZGTGaR2BeP+rU5ckSZje3olh0TRUESEADAzZCevsZGj9ThR+vqhiI5DMVhrrBnFb7QQM21kY5sbY7Wb4WDktsRmsUO6ktMNNcnY6exJm8fBsHeo31hNnc7Nnxk2k7T9G26EBQqPj+It+Gou1VnyTIinpUxAYHYKPMZDBba/TPn8NvSPw5+NQP+SNyvHDmfCLgzDqgNXD5ewLWsRb1XDXwElOJE4i33GEwyYVd0/wbodySxaMvlJGWIiezp4xNAxdncG6TGTBJSMjIzNOSTesuIQQT5eCectBdFOycHX1YT10EoCqQQ05GQGo4qPwv3vVB3nbFNAz5o07GBsBHpud0/3gWbGYadGg++NG3m7zJ3yeL0HOYTpP1NLRPEzokWb6191LuWE6s8IdpO/YS0WFxO65i5gQKsGoibE3D1DT7UQU5HK606s1zU/wtjVih/tzYVOZFbvQ8IvjahbEehipNDMjfYSu8Fh+OBNeqfT288WDowQe7EQ3L5/owj1s/MoPWHt1huuykAWXjIyMDOD2gOecEE+Sx4OzqQNN8uUFIrQeLkEVGog6Poqx9bsx3rYcSYLeGphUvQPDmoW0jXrbtLrgYCukBYFW5d2gsmZDKa4peUyLBntNI2EBavqt4WTu3sOgqpPK+FzUWXpCrUOY33uTrrlr6QxMYEprPfrqZir7FSiNArtWT27DaZRTJ/J+m4qFiV63d4MagvVwuN0bkPe20ZP83W8SN6eAo/Q0htxUuotOM/+rcwnxgUhfrxk1teQoxmQD5buOUrH2Hv5fUi/w2e8mKc9xycjIyOD1mMs6J8ST7Vg5tiOlWA+XXHI99qp6PGYLuqkTsB4pRTdzEm6PVzhJ3X0UDup4548HKPrNuwz3mNjf7N3Pa3ESHGiBPx6XOF3dT78xhMZuO9bCCpqO1ZPt6EI3OZPNrRqGJS010xfzfO6tDNxxBzdVbCT/H7+nv3uMfXNuYl5eIPq1i4ntbsRHslPiCSGpvxF/HWSGeAX0jkZICACXW2Kgc5hRn0AengRTxxowZCYRpHCSHKGhuheOtIOvcGI9UsrJbkHB3AQWuJoIfvXVq3oNLhVZcMnIyMgApwYgPfjD75LHg7OhDb97V4NSiXnXsU+sw9XRg6OmEcOSmVgcEtXlvbw/Gs6mOu+OxFMr9zFZ9KKcO42Au1cSsH8fVaXd/Gimd3sUCVA3NrFiZSILEmBs036OE4kxxJehZTdQ+eZRzG4FjuWLSArwRrzo8o/A7ydfp9ulw7e9ma/rThNeW0re3/+PIKOSgxGT8L1xEXmdlZS1u3ivFt6s9m5A+WI5pPWc5mRgGr9fBq/t7MEeEkr4cDf6xCj+VAivVcHk+kKSdryP0mZm8dxItNMnEvXS84jrsIYLZMElIyMjg83lNdOdHeLJdqQUT0EefyyEttSJqEIDGdu0H0mSGLJ6TWdnkDwerIdOMvCb5+hIyOTdGji8tZaoGenckuXd8LHg1CF8RwYw3XQjEfGBzM/Q8kzyKpJbq0jqPMWhVq/wultxig3qDDw19SRlhhFdcowXUlbQvq2IafohVv5oBXGBSur+sp7V1dv4qqGRrrf2sD9/JVsK1lHRK1DbrWhrTxFwspj86oOwYRunI9MJPnaEUTtkh4DTAxnB0F5Uz+QFKYQawL+shFMpk6nYW8sGTQZTImGe3zDT3R24X36Pif/zNTSSm75fv4DOOkaH5fqIEHmOS0ZG5ktPcedHQzxJbjfO5k6q4iezLnp8zZM6nQXxWqR3drIpdQlBBsHykFGsh05iGrLS2GGh/6YHSenrZvFwJY6SagJ+8ACSJFHxTiFJ1UUE/ua77D+lYHWaN1Zhq0kQdMMiWluL8O8vJGl2Bj5BPuT623njmVpCF+UTYAM/g4qbeo+Q+qO72DhiIN/eytagWB78fi7ud7YiNfeR3tWCXqvg1UV3kRMKyTFHcJssuAZH6R0TTC05RLB1iEaHnqSb87khTfD87kEMPoHcmCLYc3KIgBAfOlQqNFYLT6R14Clpofrd4wz6+xGbHY+nvBa3zY7yWBF9M2YQF3iV1g1cJrLgkpGR+dJzbogn66ES9LMn0z0Gs2Ihygi5TtjZmIAhXEvctg0EGNVscvvimpZPQJYPEw5vYdb8MCAMR0MgqtAAHGWnsHb3MdarIvLhtZidCrQq7xqqF8q8dYf5wC51ATHu0/T/+XVKbr+bzG37WfS1edQ+v4uBxHTyN71M6reXoosJI6AB9j1TyAprN5pmf0otep6c/iD37fgr9T5h3HB6D6N7B9izZB359ccYjAmnxaShPamAOfWHuK9pG7Fl3Rwsi8ZZM8DEr81mxAb9O4o4ljKDGa88T4TOhdKThis4mNGVN+C/fgM+y2fg7B9k+Pcv05s5gfC+bgKnLL8u10sWXDIyMl9qBq1eE90ZJJcLV3sPyln5qBs+TNerYU067FKE0RoUy0hOIhZjIJMjIKHjFExO/yCvrbAC461LESoV71S4cL27nS1SIscOgsPjjTzfOATToiExAMq6QR2fxsmViUzpbMAVF0NofQ0lSi3mwkpmrMlENzkLAM/wKJqGRtL//DUGfvFH/rb0x2gHhlFbLeT/bC32/mHcL25jQsl+XK2tbFj4MI8mNtJ4rJCBwECq0qbSf7gOj6uNqX1dvNuyBGe5mQKtiiWn9pI62kBJ8BRcvQO0VXYTMzhI+42rcdYVMvr8++BnxKbUoX1kLSUxecz8jK7T2ciCS0ZG5ktNcSdMi/nwu/XACfRzp1De543efgbJ4cR68CSq0n5uWZ6Ju7kCT7uNvaap+NTUE/3ACgBcfUMog/wRKhW7To5if/8gy++fyiGHdx7t32fCb45AnL/XFb2mHxL8oWJ7BYtGqwkd6qJ07T00PP0KklZLe1gC/0+zkILjXgGb+vxLvL/gVlSvFFESNI/RmlYWnNrL1onLmWGCqRXHqV23lj2nXUy17mVx9S66owMJufsGTjWZuM1dTMe2ahxpuYz4BePadZjF0Q4G/P2I723ElJuHc/laRo+8x3BLL6Gt9URrtZjX70YoBcWL1mHUCY5ETyFqzAZcw8COF0AWXDIyMl9qRu3gPz5VIzlduLr7MSycRksNTM4Aj9mKdX8RHpMF3ezJDERNQ58JZCfhsdqZ++4uKoub8Fs+jF9UIGMb9+LyC2Dvb7bRpzKSctNMdlsDCdB5F/vubfa6oQ+ZXDSdHmatvp1Tr+4nRRNI2iQj4TPziPjdTzl6/3cxHyolTDi4oW0jxfqZdHSN4mtR4vYP4MQpLYUZs3lo118pSPehJTedGMcgW7oMHPDRMcdSSlFKAYHxIei2b8BhO8JgbA5byofJe+LfWbR/C+4uG6O9pxkqbMIVEYHC2kGhfzLr2o5Q1W7BWFND+7z5+L32BihgLGMC7ZpAVi2KRvv03wlOCoXMmz/zayYLLhkZmS8trSMQ4/fhd8v+YgzzC3B5vB6G9pNVOOvbMCyejjI4gJZhiANMdvDRgEKvRWXQMumx+9j/ThkFmkGaq/toumUprVG+RBkhIhrmBMEbVeD0B70Koke7iNhzhPp+D9YwOxvn3snX5vtS//4BggdNND78KD7t3cTXFTMybz7HFq5GbN7F2obDbF14D/HVRZTnzuPfssaw1QTS19HHyqhjGDadIFAVxi/W9LG3vBtX9gRmHd8IWjstg8PcUPcyhSvvJuXwUdynmxg0e5hZU0Sf2ojobWYkPY1o2yCjL5/Ad8hKW2outn1lRIX5U0sQgylTmNZ+ksa/FONZs4ykRYnX5brJ7vAyMjJfWkq6YVKE97PkcOLuH0IVFcapfonMskNINgfG25ejDA4AoLIPUoLgnVp4rQKKmh00D3jYP+JH4+Q5vNIRiPHrd3B01Jfp0XBXDoS99y4NPQ6ijd59uIxaSGkqZ9SlxDppEr/OuocC3Qjpx/eQd+tUjrVDY1IugaEGhn0CSW8qJebALnK1o9T4JeA8dIJ4P4kfzNcw+Nx6csNBtHYgZaZS6ZdIUkESA398lZSju/nhrv8jxAdWhowwr2IXqv4B1m56ivYuM7+84ce8cfdPsCm0OD0CH70KXVggU6ytVBcsoMepQZRUkH66mA6PD00peQS1NzNwqoPGrzyEOiSA1mc2XpfrJgsuGRmZzwxHXQujr2zG2dx5vbuCRwKXxzvvBGA5cALD/AIkp4u+l7YQnxeLftYkAMYcUNsPJzrht0dBrQC1Ek6sL+VlXS4DFojQuRjoM/NvJ/353jSYEw/u7j4kp4tTB+voGvPGQazulejdWYwjIBBlciwx+7Zze6IF490rURaX0TZ5Bv1HKgh6+RU0OjWtc5eQ+d6rxNaeJKzpFEFtjQSVnyTwv36DFBBAu38Ujd/8LpYdhwmvq2Q0IZnqlbfjiYlChAdjyszBVF6Pzm4hLdCNY2CUvGAXiyLtLN/6D8ozpmELCGIsOZVhq4fqQ/UEPvUMfT7B+LpsNM1cRPfNtxKidWMaslA3fRH4+uL/0iuY1n32ZkK4TMElhHhWCNErhKg8Ky1ICLFTCFE3/v+8S6mFEF8Zz1MnhPjKlXZcRkbmi4OjvpXRlzbi7hvCePsynM0djL6+FffgyHXpjyRJ1PR7wx+dobBymH5Jx+hLGxmeORN9ejzVvfCrQ3CkDYZt3v2obs70uscbVBKLDH389JZwjFo4va2E4LwkMoJhXwuc6gfrwRNId9yIs6GNQD1srwdzUTX2MTsBbU3YX9mAY9YM/GflcqLBStHmKpTHTzLLz8QpfRShdy1H9Y9X6UlIo1hEsn7NN1D5+zI4cTIvZa1m1vdXYx620tTrQlFZi27RdPzjQlC9v4XJ986l5867mfK331KkiiF41gQy/2UdvdNnY2hvw2ffXhwN7YxExJKQFEjbPV9BPzLEsNoXH61A6bSjjAymJy4Fc2ElocM9bFv9CNMnBhL3+kv0pU+go6juuly/y53jeh74E/DiWWk/BnZLkvQrIcSPx7//6OxCQogg4OdAPt6oJieEEBskSbo+MfFlZGQ+Exz1rdiOlaNOisF490qEwvuubJhfgMdqx7LziPf7kpko9J/NYlbHqSash0roGfMlPwpMAu+EVmMLTa85CbtrBXGSlpo+ONYBCgFLk2FLnXdea0okvFsDqxTNeCYnoNNBmjSI4dhWohbmURKajo8aRnpH6akcYuDUDoaGbPi63KQGKsg+/C5NagNkpjFSMUJsvB+vV0L2C89jWbKAvthspu95EXtXO72bDmOaOQf9uqWk//wxKivLcD/6AG0eNZqdBznyx1a0XR0EWg30T59F0lAvVVvLiPNxUZk8hen9tYz4+lC07kHiJqgIqTrK3DvzOfqsmdk7N9Iam0F8ThTH2nW4n36L8hU3M/+XP8EUGUPWYCNutQb/ynIMuen4DlhYZakg7OQRYlpaaTgVQMqNHiDvM7luZ3NZgkuSpANCiIRzkm8E5o9/fgHYxzmCC1gG7JQkaRBACLETWA68dnndlZGR+aJg3nEEhUGH8a4VHwiss1HotfiuWYB7cISxd3ainz0ZdULUNe2Tu38IW0ktmvtuYqxR4De+9Kr/eA2qGQF05ubTNyKI84eWEYg0QrgvDFhgTzP8Yr434oVaCbaTtRhvXUr/sWp6NpSgunstKTEw0FaLuq8PjhZSmTKdpkEP3zFWE9hcSMm+elSN3UT/9if86HQMj88p4pXydiKc7YSH6NjuMPDgoWcpOj2Ee/EipJIq6u75BivbT9K+cCmZmw9jf2QVc2oO49u8k0NjRtQGNQMJU3D0dmPIn8Bg3SDLpydSZHVj/5+n2HjLv/Cd0YN4itTUbzuOa3cVQ5ETaPiv/2HXaSczCreQajehbq9F99g2lKmxDPuGk9ZWxXtxS0lZMpGupn6m//yrSM/vYMgm8HEJTP/1E2rDgkm/2IBfI67GHFe4JEld45+7gfDz5IkG2s763j6eJiMj80+IraQGhY8e/ezJ5xVaZ6MM8sd45w2XFYH90+Cx2RnbuA/juiWUdAsmnRXiqbW0jXhphICBHqr7oNcMy1O8W4zMiBkPRhvkXYTcNARJYhRh0DH2/l6ONLkIjAkkwehh8Nn1BP3xL3TuLaeneRCdy84Dt6dgHrEy9IeXcHb10ZeWzQtDMaQHQ1vOFO6q3YStppF9E5YyZ7AKbVQIHqHAMzRK5YylGDQC0dXDkQEDtsAgZj33vwQWHaN3xmwS+lo4mLsEj5+R28wlDOkDMSZE0K7yZ/Lff0enWTCztRh9YwM+ZeXEBSrIuW8RhlUL+HZVFHVjahJaa9EMDaJXgTUunpHn/opJ6Oh168huK0d/upbFBUGU/OAvBFSUMhoVx9F560hO9GNN2jW9ZBfkqjpnSJIk4TUFfmqEEI8IIYqFEMV9fX1XqWcyMjKfFc7WLlyt3R84NlwKQqlEFRGCq6PnmvRJkiTG3tqB702LEWoVnSbvDsBncJysIjwvCcXhQoYHLCxN9jpudI/BOzWwtwkEcKgVDrSC8dBhXO09FEbnMTE/ipGaNkZ/9zwdgy5GE1NxaHV03biOQINA+/QLtEi+jA7bsTT1UB+ezMlu7yaOM52tZGQHE2wb5mj5ENmpRqwGPzxqDe72LqzTpnL7wHFe6AsjvXg/WcESnqMlKGfm83L8EgZuuomZ5XsIamukecFyuopPk79/PbXFbQy09FP/4NfJTvdHstsJ+unDOL/7MCWvHWLgtW1858Bf+elLP6RUH4PJ4EfHlJnYsrPpe+J5wtvq+M3KH1K0YB1t6hA6Tvdg1fry5rJH2CISGYpPoWdnEWXFHdfken0SV0Nw9QghIgHG//eeJ08HcPZObDHjaR9DkqS/S5KUL0lSfmho6PmyyMjIfE5xj5iw7i/GZ838yy6rnzMZy8GTV79TeHcj1s+ahDLASFGzi7T2KqTx8O6jb27HHhNLdVASRQVLmVe2g7ZBF3887tWuKnu9Luy+Gu/29fXlXZRsKePfI9aye0czR3/wDxqbhukyhPLiTd+jPyiCgIEehnInI/yMNCVN5PTMpVSEpBLU2UzoqUq+ceJlfji6l6YXdyBUanxGh5hxeAPteyuo2HiS0LJiOgOiWPLib9n6132sPPAajdPm0Zqeh+mu29iyq52orkYC02N4denXKSjfQ+WWchb7D2NwWskyt1MWmUPUtDTMb23HNG06e58+xJbhEIb8QggVNsJ9BYq4aFYuikHzjbsxdY/Q3zmCtqqakjX3EBXtzyLrKRIL9xK9Yws+Tgvz3/wrt4UNEdDWSPqh7SRUHL8m1+uTuBqCawNwxkvwK8D758mzHVgqhAgc9zpcOp4mIyPzT4LkcDL23m58b12KOHt/kEtEaNQo/X1x91/cZ8v03i5GX92MZHdcUr224iqUQf6ok2Jo6HfjemcLSf4eTC9txPT+Hgb6rXTPnMf2ehgROmrzZnPor7sZtEG0H/z3ApgXDynBcJuuheUj5YjgIGIP7mbdnECmq/uZNDUaz1038e1lAYjoCAK1LqLWv0Xkob0MtA2R8NKzHI/MxRKfQG9qDsbUaFpquoh55EaO1dlwx8cQ5BpjcNEiSlURjBiDmBKrQhkZyrR7Z3M4bymGqFDUMWG8ErOQqPoKFrz1Vzbbo8lVDRJmVBDlGWUMNaaBMRwmK5kLM1G+vYlidygl2XMYmjyF2a/+kZDaCqKG2km292JeNJ8DKbNwF1eSphkjbLATzeAgETo3uZohVAISAyTqQxI5seJOAmZNJDtBx4yGo3RWdWBKSL7s63w1uFx3+NeAo0C6EKJdCPFV4FfAEiFEHbB4/DtCiHwhxDMA404ZvwCKxv8eP+OoISMj88VHkiRMb+/Ad/V8FLpP7x2on1eAZX/xBY/bK+tRhYfgc8McRl/djKvv4kLOcboZV0cv+hm5DFgkmp7fhu+SGWwwTOC4I5jKQ02crBhACgshOcgbrV0ZHkJoXjJBJ4v4ej70WCA/CvoKT9F5uBrfoX5q9VFMuHsuQc+/SM+0WWjDAvHLz0RhMRNfUUjG7BRmlu1C8W/fYpH/EKaEFGbXHmTYNwhbRjqlXRKvTb2LXWVmBtuHaDMpqZ+1hOG/vsHcos24lGoCB7qImZrC/m4NLp2OdHsXVruHkOOHOXLb16mKyWH2tpdY6NNP7YRZhKeEk9BSTV9KFhaPkj63mvB9O3F866t0Hakh+43n6UPPYGoGgUM9RKaEkjc7kaWRNnzeeJuRziGsc2fRvO4OYmz9TD2yGavZjjkylsLFt5H25ot0Z+VRvaeWvslTOX7LQ9Sa9Z/6Wl8JQpKuaErqmpKfny8VF1/4JpaRkfl8YN56EHVqPJqUuE/O/AmMrd+DfuFUlH6+H0n3mK2Mrd/j9VIUAsnpYuzdXWhyUtBmp3wkr6OhFdvRclTxkehnT8bmgv1/3EnG0hw2WKKIrjrJ4jQVFdE5tP/yOazR0Yzq/RkNj2bdDXH8oxQWtB1nyYwQdqiSmddXTunOapSB/hw1GclNMzJ5qI7TB0/R88jXUdacIiwvAeebW5h80xSEQmDec4xyewDa4UGOaROIaaigNTGHEOcoQQXpKCtq0B8+SvR3bmffiSGOxuVz999+ikG4CPrX+xhduIB97Spijh0gZNd2WpOy8V04FXtKCpW9kDnSjOnACdLCVSg6e5jz/VXse6MER9Vp0oIEqqOFjEXFEJEVg9XHj43Tb8Fl8CX87TeZd/hdgn7+DTxd/Qxv3M+wRWIoPZtBu8Ceno5bpSaiqxFjXze29HR8jh4jxD2GWnIR8r//D116IrsqLOTUFxNx09wrvuZCiBOSJOVfan45coaMjMwVYSupQeHne1WEFoB+4VSse4s+lj62cR++a+Z/YIYUahXG25fj7hvCvPMokiThON3sXejcPYDxrhUY5kwBBPv/vo8JC9N5aTCKKZZW5obb+aN7IlJLB/oF02D2NNKO7eZh+wmOHOlg0AqnM6dSvq8O/ebtOHYfJnj2RN5MuwExZsbd0Mqp2gFOLVtHxaYSetInEnJwPynmThwNLYxt3IfPgmkY9+7llC4Cv5F+nCYrRqWLREc/Ew5sJvD4EYbQ0PbkG2Sd3Md3f/8oIaO9DAZHUTFhDq/VqvDRQOC+3bRlTwGthu6wBEwOWBRhY/5INb6WUcxdgxy74S52moJp7bMTmJfCsMWFyEhGr1Nz5NZHqLzxbgyBvvTtPYlfRwvNsRm0/PZ1GvZU0jki0ekbRo/an5QlE7npzhxix7rIUw6Q+/ASTB2D6P20GGdOwGfVPEzPv4/HZmd6uoGy4OvjVigH2ZWRkfnUuLr6cDZ1YLx58VWrU+nni+R04rHaP1iUbDtRhSY9AYXR52P5DfMLvALrhffRZCRhvGfVR+bYDjx/mLS8GP4+HM+iSDsTCkupmLeaPBt0b6qnKHk6Bd2NJHz/TjqPVzDTUYw6UcHExEiiE2Zhe+4tgr73FXotBg6th5+IMZJOV9EVkchQTAKu0jbiKgvp77egWrUUzUAfo197kPWlJtItTsKPHUI9PEz1hJmcmLycnL697Otxo45TEnbTPJr/+1nydSM4rGP8csWPmXNsI3H/9xRfX5KNnzqERvsY+5fdhHpkGL+NO0lNDWDqWB2lgxrUMRG4TpbiNvhQ87f3mGepx7/oNJ7wUFoWLSM1L5p5216hoWWUkJBg8o8VEh3pQ8+9d+D80/OEWIcYUIMiPIQMXxuZk6IY+fPL+J7ux/iTOxnrHmY0IRnlPV8hJMw7ptaiKvp/9iQ+axYwN+d6rOKSNS4ZGZlPicdiw7zjCL43LrjqdevnFWDd79W63KNjOOvb0E3KvGB+TVoC/vevRT994gdCy2O1U/H8bvwjA3jFncq8eJhUuhfHkgWM2AXLU6Ct147aV09gRzNFPkn0Ll5GlI+HmKoTdJzuQV9YhP6BW9nQbuDn+2FpuIXUwztoC44lxTOI9nghS/ID6DNLSC0d7GkWPOfO5C+n/Bho6CEyL5H+4CgsxgD8IgO5Ycs/kGrqsVvsBLQ3of/ZL7FZ3fxP7I2cmjibEOswR5fcSf1dD2I0DzH8/V/Tl5yBvquDuNOlpPfW4WhoY+NgMMcXrqMqeTKZlg7ynvk9C2r30efRYouIxHjPSrL3bkD//Ct07ytF77Rh3L6LkJRw9PMKiD60jyDnGHVBCbQHx2NLSyU5P57+nz6JQ1Iw+Iv/QBEcyMkmOzd+cw7lfQJJknAPm1AGGdFPnYjp5U2M/eaZq37tLwV5jktGRuaykSQJ08ub8F278Lxa0NVg9PWtGNctwfTGNnxvXozCcGkbFkpOF5a9hbiGRtkRPZ1mRQCZIbB4pBJUSnb5ZJIfCXtqbAxsO0rw9CwailsIWzYVgxpuDBmi5NXDSDY7kkJQLKKIE2NY7W5S92zCZbFhS06hMyqZqEN7UKUlYrc6CNS4EYEBtEWlUuXyJ79oO9tjZzJqDGK2/ygZ/i582lpQHjiCTaHBPXUyoZKVDYYcYg/soDF1ElmDDTgKpmCKjiX82EGcTokGdQgJnmEs8fG0uQzk7l1PzK3zqXf6ohocJHbnZnS9PYwkp+GJiyYxWMHokA1L/ygjweEM+IaQ2VnNmKQmLkKHfvZkcLmwHjxJZ+sIxdG5ZDt7iPcME/rY16kyxOFvGWbs7e1oslOJsfZxqk/Co1CSnuiLNiwARZA/o8++h35+AT4Lpl7xtb7cOS7ZVCgjI3PZmDcfQD93yjUTWgD6WZMYevJlfJbPviShJXk8WA+dxNXWg2HhVMpcoXR0QJQP3BA6iqW0He265XSXwd4WmDDUwPr4ZNSHyxiZOp+vRFgwBup4tjyQRfMnM/r3NxnpHGJ1YhP/OuXb3H7oZWz+QZyeNY386TGY2s1ENvsRHujEZuplw0OPYQ8IpLIbZh14B+fACFk5TmJHShg2RKMpLMfd00NFSAaxRoiZlcPR6jFORk9hfvEm9vlGELhgCoMxiSRveANlZw+Nkj+D9y5iNCKEMLWD1K2bOfLD/yJ233Z6C7JY8eL3UagUKOZNwzxtLnE7N1OXt4iItoNE5iXQ65fEwhAHTT5phCoc0NuNZfN+nB19qKLDaPCLYap+mBCFjYopN5CwtxrTcDkhDeX0Fcxn/qRopIhJNNR6tdh2JWgUMLWplIA7V6BOivmEq3JtkE2FMjIyl4XtRBXKkEDU8dc2rqA6NgKfFXM/5jF4PhwWB6X/tx5PZAR+d6+E8FBq+qDDBHdmS5g37sNnzQJer/RuSbIqFbbtaWfChFCEQvAzQzmnX99PxR830tM+wistRkrHDEijY+wb8OX3R/4XTXgwvUofGDbR/+YOjBs3c2TOTex2RVMYlEGhLZD9TaDq7SagqY6JDy/DeNdKAsN9yfM1kzY1Dr3HiScijJFJU7B1DdA+KvHL5jdpzp/DWEAIma5e8l5+ir5OE3+a+wi9KVmMljfgu3U7Ef/3OzLClHzHdIiEhnJW/vRraNxOhu65i/6YJNKrj2NQQ3pDGcE58bQ8/A0Slk5C6LTU33I3wdlxBP/Xd5Bcbox3LGewpZ/YmpOENJ/GOGMCs4OtOCQFLquDo/d9hzn3zcQdHs5b1YJFiZARDAFamBdhp76skw2eRFot10f3kTUuGRmZT8RjtuJsbMfZ1A5C4Lt6/mfSrjYz6RPz2Fyw9+mDaBcuZIcrAGMjSJI3ysW9E8C6+xi6WZPY1aGlwwRrg/rZ+nILoQaJ/q1HWWLowxCZwuLv52M1O2j+3z30nuomxtxL4YyVzO6vwNXQSvOab5KzazdCreZQcBpTW/bT5R9JYv8uymatIqSqjOHQJGbvf5fGhSuo7zBh7N7PaHkDrvkFdPgGsGXpZApqDpGqs2A5Vo7l9n9Ff6iao8YC7jZXcGKfEz+zG9++bh44+Cz2ebPZG5eL8+QhVv3uK7h7Bxj8x3uIkCCq0qYRPCOLnHuXUF8/hOJ/f4XvbctwD5twDo3R3jFGZu0xCmesZGbJAQyLZzDytzfRf/s+Wv/yHr12FcnTMjFkxuB3/1qEEOQCe5thqj+4JW+4q9VpEKCDEIN3483qNw4z7d5ZmPbtwXRwFL629lrfBh9DFlwyMjIfIHk8uAdGcPcM4OrqwzM0AkIgDHrUSTEYls68ogXGV5tRO2zb30VetJ43zQF8MxPMTviXbZAaBKPNPRxqdFLpE0uW0U7S0cNYI1T4oMG3tAhfXw3GP/wLVQ4t2w7BqF1DVegM/u3kb+jX+xMzKxNrhYUXIvKZ/vi/E62xUWu1M6GplLLv/YwbdrxId2QUlQm53HHoJZYcfht3gD86lQWDn52M0oNUrVmFfdokBl7ejGvWdMr8DYS//keeue8xvpvspvvJRqaEKEhamUNYjpa6FjNvxnyDpcXvYzzdwD0d9fjolXT8bD+q3j46IpMwOTxo4iJwm228/VIZcztO4FKrkewOlCGBHDWmMvG5P9F25zoGXtnCwbRkMgrbMA8qGA7OIbL/H/gF6gldVoB2UiaWrQfR3zCXnY1g1HiF1Ds1cFOGN9TVGbJtHZxuaqTqR0XEz84g5oGV1+W6y4JLRuZLjntgGPP2wwiNGhQKlEH+KMOD0eVnowgwfqrwTZ8FAxbYWidxQ3shB6av5ttJsLvJux3JQ5OgbRRObSihadYi5reUUFzcTdSamVRp1Uj/93dcIbFEBSg53KOhcRhuz4a3quErTVsh0J/Uh2/F/t9Pc2LOSlJqD5PaWUt32kQGouI5teRGbkjX4HrXymj2BP6layt7IjKxRebgq1WQ9cbzWKOiKfIE0Jo6hayik9Qm5jE5EnxjYxl9S8/k/e8jra+myCcJW04erb6hmHW+GMQQt+bpOHrMH+O6m1lU9DyW8hqaguKx/PhRNg0GMffkdgbURobSs5h4cg9DGiMDdiPmXa0oPG5cgYGYfIM49cJeuh/4KkPDdoz/+B3Bv/kBmRvewtrWge61P2Gc6N3MY6RtgJ1vVTNleRa+anj/FNyaBbpxCeHqHcR6oBjzziNEpSfScPfXqfTzJ9YKcZoLX6NrhTzHJSPzJUaSJMxbDmJctwTjLUsx3rwYw/wCtJlJKAP9PrdCq9MEOxphTW8h+nn5oFDgr4M+C/RbwO4Bk9lFjDTGnKKtFDsC0d56AzfF2eh9bRejcUm4w0JRLZlNyL7dfLtAomFAor+qjcSKQtxzZnDiH7vI/NoKJnRV0zthMo7wMHRDA6i+cgvTU/RUvH4ES0oasVMSaTjZzrDQcb+hmdv1LYg504jwmDFlZhG0dQutJ5qYPCMWHw0EtdZTtOZuknR27P0jtATFYbG6qT3RTvfbe+k63cvIi++T2F2PfsNmerYX0pcxkaYJ09g35E/SaAfarETWxZgwdrSy9J4pTBxqYPE/vsPsECtxp06iCAvm+flfpX3patL2bubRiteJnZrC2FvbOb2rgqHpM0nP8Qqtsm44GJ7HXGUXno4etjXAbdkfCi1bSQ2WgydwDQyjnZBG4L/cS0G2P3aXV8BdD2TBJSPzT4izveeS9reyF1WizU1HaK/Da/OnpGEIjrXDLVGjCJOJEk00+VFeYTZggShf2FoHqe1VjHYPM7x6JRNmJrBgqIpf/6WKWr84gk8Uoqyto+ZgPd0NfRx55P/ofWMnq155gsHAMPoPlbPy67MJT4/iwJybmb7nbfqMYZy672H0L79ByHvvMKVyH92p2VQeqGfwkYeYN1xN3al+dvhkMlE9jM/8fAYXLUHv74Ow29m9tZ7CDijZW8+sRDUx9RX4DPRCgB+erAzUkpuUnAg2Zy5h7bJYmJRDYkM5jrvW0Xj7vYzFJmAsPM6qo2+SXnWMDr9IYrR2rDuPoAjyw7LjCE1OA43L1pIxP53ZCUoiU0LwtQzTduw0KcvzqA9JJM45yNFFt7C/2St4HG64ORPqpi6gY8MRlm1/AeXoKJLHw9iGvUg2BwqVEmlkjMBv3fXB/mpGDSy7PjF2ZVOhjMw/G7bCClzd/SgCjNjLTqHNPX90A4/ZiqOhDb87V3zGPfz0VPRC64h37sX06kF81y2mowFCDVDUCclBsKsBVErw276T9pBU1K9u5pTHwE6/UEzaIBb3l6FQK8lYlM2YXSJQH0xFlSDi6GEGDYFo4xOIemANO50GXj4pMX/DP3DbHBxZcBsJ82YwseoQAxVWnKMWug7XcONjN1LcBZJeR+C6BYy+sYcBm4PCZYuZqh/Do7ZjeGA1a97fQs+f1jMpwEbwnzbQN2zjUMpM6rKmcU/9HrQGwZ6RCBZYD9O44yihhggs0bH8PWYZ9/jCc6ZQCvIn0V3cScyEGIYaO0gZGWCkrQnj73/Knn2tdObE8PD9mbS8sA3zhFgKjm4iOlaD7ds/5fl2X+ZX/xFTQAipk2PJDAWN0ut4sb4WEgMUpEwMwV49Sv9Pn0TyuNFNycbR2I40Zsb/4Vs/ci3aRr3Bh68HsuCSkfknQfJ4MG/Yhyo6DN813mgWpnd3oQwNRBUV9rH85i0H8Fl55QFSAe8WIyolQqm8KvWdj6IOGHPCylSwldaiyUig0aylcQj8tFDd692csdcMs4er6dIEEOQr8HHbob2PIPMg2QYLfhVHcaYmM7hkCfPj4TvbICxgEFV5OUHmEep14ayv1lA/AuFt9cQ1V1M9YTYB2fHsPOVkusKNob2Vvoxcgv3VvFQhMB8tJ1sTRM2zB1lRs4/R9AyqK/vJM5/E7XYReKqE7dpoVtgPo2yxszdrFlp1GyqjDz/Y9Ev0C6bSOmU2+W5B1AvP0ds9RvMNN+AaGcNPK3iuDH48C9r+to86EUi/MZFliTaGuozsDVtBeK8K/8bTzP7mKpRKwclhHRlb3yEmUoXvTctQ+BlRrz+JRiWQ4mJwe7wbZfpo4PVKWJIEYXo3psFRfNcuxHrwBLg96GZPwl5Si98d9370eo/HrbhelmRZcMlcNyz7isDjQTcz73PlqfZFxDNmYezdXRgWTUMVHf5Buu+NCxh9eRPGdUtQ+Bo+SHfUtaAMD/5YBPZPg6tnAPPmAygD/ZA8ng+faoAqNgLd1AmfOFfm6uhBEeh/wYXG+1u8cy4LEsBjs+OoqKNu8WreKoUHc2FXE5icECTsJI/UkfXOS7hiopjiaeN9eww6q5kIcx+K7ASk0CASlk3iYLud/c0avjO0j57XdzLsctKUO4OayAymn9hBjDaYJSc2UBYRz7BfMCofI7fteYnSRhuBPoHoEyKI17poPFFG+kgrOp0Sj8JDb0Iqh5bex+zyI1i2bUXnr0fqGWSpXktYhIEaXSA+zU0E+SigII2R/W1Ig2NM3vMu5RFZpORG0WvuxNbazeDMOaT4QYwfxCgsqFuq2HvHt5h0+F2KdH4IoxHNnClkVh7laP4MPD2C0VEH6p5eklxt+Dx4L8qQQLZUSqxu2EutNgxLchrrMuHdWq/wumXcCcN6uBxlRDCOijr8vnIjkstF37/+hsDv3M3RdugZ82pnkb7gkbz/rxey4JK5LrgHR/CMjqEryMG8+QBIEvqZeefVDGQujCRJOE81YyuswHfdEhQ+H90fSSiV3rBJ7+zE755VCKUSyeXCdrQM472rr7h9j9mKeetB/O5bjVB9/HFiLz+N6ZXN+Kyeh9Lf+JFj7mETtuPluIdGwenCY7MT8PAtHxx3NndiPV5Oeb+CUB3E+oMJKO0WHEuaxyIgPRherYLqHolbqzcxpPQhdbSNzFumsr3WxR5dIN0uNSseWcGhKhOLtz5P/7TZHAjMJuXIbqZ3tzDa0UWxbxLmO1cwI8JFQkszr+bewD0lb0FnL+aF8/BraCb3jRM0hSUgJSTivHkZYacrKVGGMf+5P5GyJIfuThfuWdNQnx5lcP9JQo5uZOT+exHPPEfx9JXcGm+hsmaINn8/MqK0lPQKCo7vRW0Q6Do76L/3VvyffRt3gh/1s5dSUHKI51xL6GqFl9dK7P7m8wQtX8ANb/0RzdolcLIGT2sTIw4j21xups4KIVLr4MX/3cKdU/zQjESgig7jQCsU1B9HmxBBbI+ZDcFJrBSw7qzQj5IkYTtRjTIyBM/QKGNvbcfZ2IHPvWs4vPUUodMVTJ8ax5ANusdgSx1EGUGjgimRV3wbXTZyrEKZ64Lpre34rJz3wRu2x2bHdrQMV2cfmqwktHkZn1uPtuuN5PHgONWMo6oePB7UiTFo87MvOl6urj6sx8ow3rQY85YDaCdloooMvbJ+uFyMvrL5Y9rcuXisdsyb9qGKCkM3Mw97aS2OU80ojAaEWo27dxBFaCDWgyfQpMShSU/E2dyBMjKMXWGTyI5QkhLkrautz8mO320j0l9JozKYo1F5pEVrmS+1wZsb6PMJJnd6LJqRYcq3V3EgagoP35HMqRNtqEdHMPWZ8J2YQnBHE64dB6mLzuBo/g0km7tIu3UWWiUUvXiQqXTjaOlic9xs7hktwlTTzH+sfZz/LH6KkNsWM9I1TMixI3RkTkK0tKGursUzJZfQkuMczVpAYls1h+bchE9tLbFD7Ry/55skHdmNr2mIyLgAIk4cpdGsYfq6SQStnMXWx9ejHRshIsKXgGADB6vHWKLqpDljCj9WzOHW/uNMPrELXXo8PTeswlV1mmklu4n49XfY98RmxgJC8Avxpbuxn9zsIOp8o1mUrsY0bKU8PIv8P/8aw4o5VBR3ovzKLSAEU6M/vEbmnUexl53C58YFSCNjKHwNjFY1cdQcQK7Ug66yCuMtS9FPmwjAuzVeh46rhRyrUOZzj6u7H4Wf70fMQgqdFsOCqd49lSrqML28Ee2UbLRZ18lt6XOI5HYz9t5u8HjQpCfiu3bhebWc86GKDEWTFIvpvV0ItfrKhZYkMfbuLnyWzbqo0AJQ6LUYb12Gvaqe4T+/hsLPF1V4MO7BEXT52fgsncng715EqFU4GtvxDJtQBBgpcgSRE6EgeVxotXRZOfqn7QzPWki1yo/TNf38p+Ug/YUOfCtO0jB1PuF97QTWVnF6Rzmj0emEBWkZKqqhIzidG20tlIyMUFPdT0BDN3l5mewvuI+8ol0kfm0NQ1boNbmJx4S530S4fZTZvWUc1UbQPGMq01qKMY4Okjnagv6GfNw3ZFD+y62sfHAOh58YZNpQHQ6lg7DaMmJjdKxLcmF9dStP3v9L5gZ6mB00ypC/niz1AEUhMRi/fxulPj6oByBg7XwMr77FlIcXs/mxd1FGJxL/+H00//srfN/0NsreXiavy0ebkciYq4+SoiKaUrJo+NteqrNmEDE1jaCt7xIVryYpO4zE3HTerZYIO7yFWSldUJCNZ2AYi8HIohjB+lrvsoEQA7iGRhl7bxdhT/6Esff3YphfQMf6wxwqWMHaTIFBDR7XEvq+8z+oosKQoiJQK2HE5jXhrrkOO5vIGpfMZ87oq5sx3rLUu+D1AkiShL24CkdtE7qZeWiSYz/DHn4+sRVVovD3RZOW8OnrOFGFdmI6Qn1l76zmHUdQxUZcUkimsxl5/n1UYUFo87NRhXklkqOuhaE/vIR7YAT9jFwCvnknxV0KlJXVpPQ1Yrx9OafbbfS+uYuWucsoPNBMUG87CgFKBdyYYMfx8nrcYxZSHllN59bjNJrVGPx0pChNlOriUAUaURi0DDX14vAIMuMNPJu0kjV1u3g3dw2rQ4YZPF5Dans1bdEpRPa3o6w9TXPaJIoC0ok3dzHt9BEqvvdvTJkSzrRoeLlcIvu3v2RoSj4BXW34S3bGugY5okukcflaUkoPkXtgM1pcON0ShtEhtMmxdHv01GbPIDtSSd7ds9EHGOh8diP7WgTLjf0c2HKa6AUT0K6Yj+ZYEa6eAaocfoSH6ZFUavRqUEaEoB0dxrm/kITffZ+Dzx4kcEoa8+d9GPT25XIwHjrI1KLtBP3nNxkqa6Spz8WMB+fhcMObVbA23orpdy+gXjwLJuVg/fvrDHm0tC5ZwYoJOhRnKfHuoRH6fvx72r7/A0pMPrglmBwJ06LPvcqXz+VqXLLgkvlMcTZ34mzuwDC/4JLySx4PtiOlOFu6UBh0Xk0twBeFvxGFny/KYP9r6sn2eWL0lc0fbFt/PbGdqEKy2r3bY1yE0m6o7vOGXpoUCe7aBjwm8wfmJvC+oAz++h94xixos5Jx9Q4ykDWR5pRcliR550Jrn96GSaFF3HwD+4720FXaTMF9c5mXAL8+DCFvvkFB1UHCrQP4Sg5GhBaPTod56RK00SG4N+8hPVrLZkM2TpOFngVLqHEHsObYm/i3NdOanU9lUAoLbXWY58zGNyIA7de/j1lSMxwVDxYrISnhRCSE8GTMcu7Khv7DFbS2jXJ771FMkTEY506i6vldNIen4J+ZQG14Gmue/k8iI32wKTT4DPRSZA2gZNXd3NG6C7NNIs3ZizohCs+omU7Jh3ptGAm1Jzjum0SG2ozSYWfCxFDsLV20H6giMicGn/xs7BWnUUaGUtQuYVown5mVe3k1ci7f+Uoq2vH3keo+GLZB6EsvMjxgISg3mTbfCLKTfQnNTQSga8BO8z+2oLDb8STFo6urx+Xjw1MJq1g7UUN8AMQHgP+435TNBUf2NmF75k2m/u177GpXc3v21fEsvFzBJS9AlvlMsR48gX7OxR94ZyMUCvSzJ+N390p8Vs9DOykDZaA/njELjuoGxt7few17+/nB2dqFKjb8ugstV+8gztauTxRaHglq+uHOHAgywPpqD0Wbq7BNnPCRfJbdx3D3DuJ33xp0U7Iwe5R0Hq1lUYKEJMG2Cht2u4usWC2bT5qxFlYSvGoO3WPwVDFkO7uZdvowUWM97J65lpKAZJSjowz7haA16hkacZC6IItKpz+Jx/eyYIIP0Z31/FBVQlrXKXokA+8lLsTRP8JA5gQ6NYF49h1FoVazc9JK7HYX6V01mAOC2R4+BacbXi9xEH/yMCsOvs5v5n+bnvJm6qt6MYQGYBROqiPSyeuqIGSwG1NtK6a77+CoOxzX/Xdyy5FX6I7PICwvEU12Co7GDgYPllHXOMpwWhZFNz1AcIAOtdWCdfYMAh5ahyoyFGnFIsq//SNqsqZRlb+QcmU4/XfcQZw0yoFmD/7pMZzs9o7pmMO73i3f00WAaRBLQDAnugV1FV2Ud7kZeHcvo69uRvfiG6S5+kkz2ClYnEbKhAiOxxfwlap3CdO58dFAcYfX+/DdGth0GibOSkS9dDYjv32WRH+P7A4v88+P41QT6tT4T60hCZXKG0cvyJ8zRsax9/fiHh27Km7dn2dsx8rwXbvoencD66GT+Cyf/Yn5ijpharT3bTw5ECIrTmJfPZnCTsGIDSaGQ7JihLF3d6OblYcmNZ4xB1TbK5mi76f2aCMbW7XM3vQiQ1PzeXd3J1OrH6dy5nI0r7xOl38kATNyiH7h74Q4RmkMTyZoZJDYgVaO3PQgqT31aLdsx7loOUVjetSDJhQrl/JO/lqWhFvofHkbJ77yIxy7j7Dk2HuoUuKJnpRAf6sD/xde5njWHPRuG7kF0ehW55C+5zAji5cwxQOK17fQawV3wXQeTTeTODGAka4qqkrasEyaxQ1DpSheeRmzycbpf/sRnv01RE+MY+GqKIYawrG/+jotubm0ZKfj31OJ7Y+/Y/RYA/lbX6W9tpvR/AI2rn6EtP1b2eaWCKwdwPIv36B5ULCytoyQ2xZjf/5tsl5/ksYxFQvWzcZVs4M3BmfT1SRo6bFzZ9Qwg39+BxEXia6onMioKBJGhokNGaHKqsOt05O9cC7asgr87l0FQGttL0GSmxlfnc+urcfIf2gWyYEfva6DVlAunEXry93kbn0Pz0M3o1B89tJLFlwyVxWP2YrQqD82hyJJErbCCoz3XLkL9tkYFk3DsvMovjdd/4f6tcIzZkFoNBedE7zauDywvcH7Od4fEgLAR3KAy/WJmzq6PdA0BFOjvEu6JJsdV2cfgXPzWYZXG6totVP4q+dxr1pHjrUTtyR4/xTkRxsoHorF/c4h1tp70T/+NX53KgC/2s20rbyL5j4nuT42ptpbUL14CGNXG8aeTvCNIMBtRrlwJiGNtVQ98BC2YQvT3vgbA8FRdM5bQpKpE6MGGv6+CXvvMNLoMW6aoMbzj90cvOP3bKiRyNv8FtpgIx6nhzmpWkYq6vExj2CLiiJl3xam35zH/t7TiIkZDCxdzJRtrzE8ZMVUWI1u3iLio/zxvL8DbWc3Pf7BOLftJ7WvAc1dN7L70afwtHex/rYfkjfWQvCBE9iVPvDUGwRGRRK4ai5DM9zoh0Z5VCrDuTYH2y//QvozP0UVKhjtHSEo1IBz425GXt7EaHgMIcnR6JqbsRwpYUZoN+q8DCL8A6nbX48wRNDlk8bcXyyncW8Vlso6Qm9ZSERYEGOtfZTtOcWYNoasUYH+0GEaehxMe2wZ6iBBpl8Nx0sHmDk5+CPXtqrXG6Vk4M4b6fvLX2j+xctM/vm9578RriGy4JK5akiSxNg7OxEG3Qeeb5rsZIRK5Q09NDH9qpu6zni0ecYsn+jd9kXFeqAY/dwpn1l7FqfXNLQk2Tu/0TIMh9tAHC3HHjsJbYNXkMX5fxiI9QwuD7xRBRaHxO6nD6IwWzDU1+G3Zj6pdhdqrQoheUjYswVpcgwWvZ1Duiy2b4OcUBhMzcJy7BRxJYW0LF1O47OFzB7uYX/BcvTBfsx54SnM0XGEt9Rh7R4gwGGmI3MSux/+OWtrt+EpL8eZmkWPTyhRkYLYv/07PpsOcqrFTEKAGp+nn6dLZyT9Z49Q3SwxeORdzJOmk/a3P2GPySFzsIFelxYpOoKCU0f4bdY6vlv1BknTMzHtPUzb22+QvWo+fzBkEOPU07SxELtGx0BMGi2nBrAMqsgcHiPUZeXArJXkD9Zg9zFiK6wmdrSTvv/9GT+N1BHSYsKZloN2+Rz2P7GRseULKFQoGYzyRgGZoWlh8LG/0BscgdvjfVAHHj7ISGMZ7tombEsWMnD/g8wxDjO2YS+hv/4e1S8VItkVFPTXYe88xajKgKf8JJUuM6KuEWNkEJ3bigjUgyoqjBmrc5CCAine20Dnq8U0PPQNVgd6f5/xN8+m8f+2Yp+4Gq3qw99sr8X7l9h8iqHYRCZM+Khg+6y4YsElhEgH3jgrKQn4D0mSfn9WnvnA+0DTeNK7kiQ9fqVty3y+cFTWo5mQim5SptetvbaJsfV7wOPBY7Hh95Ubr0m7+kXTsOw+ju+NC65J/dcTye3GPTKGMsj/M2mv3wJb6z+6D1N6CKQFS5iO9eA3L58xBzQNe7cQsbsBCQxq7z5YAhjss3BP6y58l+UjDHrGDqjp1mo4+PQ+cLkIsI8xpvVhZEo+xuJiKqbkMiMUTnTBaWsA9544gH96DAO1dRgT0inziWN2SxGm92rRLJ1F/qbXsQyYCPA34Bxx8KtbHud/XScI66nltZQZJM9OR3WyDM+8PCrMBqrrJUL1DloKG5GUSm565n62NUBC+XGOxubjm+dDXs1RdCMa9NU1VE5aRX53JZtX3MXDO16lPTyB/mmLGBvQ4HZW4LO/jNlpEkHvPEm30KNFRXBPG8YJWeSujKbs/wrpnzGbacE2CpLC8bt7JSPPvofHP47Od7Zh8B3DUzAB3/FwW9apBSzpLMRnyQy2N8DkCHDuaUI7OYuIlFRq/+1vJCzMIXLDe4zMKsAnU0nlrfdzczwIEYDxzhWMvLkDk8VNzNJ0PDtPIxkM9HdayVuRh9AosQ97UKUFsjVnOXfmfHi97eWnobWbkII0lJEaXq+CpEDIj1STsTSHwvWlzLllEuDVpC12D7pd+3CmhjD9X6+u9eRyuGLBJUnSKSAPQAihBDqA986T9aAkSauutD2ZzyeSJGEvrfnAFCiEQJuZhDYzCcnjQXI4r5ljgdLPF8ntxmO2fixyxBcd+4lqdPnZn0lbzcNwvAPuyAb1OdOQzoY21ONLEnw1MCHM+wdec2DzMCQGwuGDbUxuKMX/Lm8Uj9FXN+N/y1ICNGoy5qfTZ4an32oif/AUHh8fOt06tCpBcSco3E4yj2ynSxXIa3FzWLTledoD0+iZP5ele5+hJCSMyTvXY4mNQ9ldgmQGKSWBqRV7GQwzE75oFhZ7Os+PhfNv+gpKVDm8WKYiXWNg+tH32Zq7kjSDjf94s48JB7aQqTWhiU8jbAAcpdUsrqtgSKMna6SZmMUTMB/Yi6t7ALVLif2vr+C3dBZBSjN12hkEFp3AP1BPQGcrA8tvpqdjhGAfLTXvF6MaHGbi3XPoKqpnODobnnsPz5gZd1YGPqVtqNIjsZefQhnohzItEWdUJJ7KCkwDZixOH/y7Wxk9WUPgt+4kMMDIjg4nQa+9QdAjt1De4cQeFM+aDMUHjhEKnZYj+cvJO/oEY4+uR/Xd2zjd5SD+nlmofVW4egeQzDYcxZVMHJaoLlQR6++9cINqX8yxcZh9LaxJ985JNg55HTKCApOQ2rYxNmTBN9BAU7sZ26u7SFo3nWnTvGHFrtdv7mqbChcBDZIktVzlemU+59iOV1wwJp1QKBDXOBahYcFULHsL8V0175q281njqGvFeNe1j95e1g3tJrgt6/zuzbbiKoy3LDlv2eMdUNMnMba3iLEBJ/EPrUIIgb2qHnVizAdzc1YnPHtglFVbn6VtyUpCT1VinDMBtS9EGlyk7N7KsewJlDnULKvfS3BiMJkdVSz/2y5OBKajDDZwRGQjuWFKajL+5SW8tvprPHxnOofeLePUsREaZ4bzLxNhf/EUejafpD98Et8o3MhAeCwD8xbSPNjOHbueYzAxndJl93GkXZCitbBixx5G1L5smn838+fFEO1vQfX2+0SumIG9rBbD8gVYD+2iKWcquo4+gmrKMbqt+K1bjL6hnLEBJa1xswjZe5w0gwP7jsNE33MjFadGCImPxr7lAM0+MSQunYRK4UIR5M/Qn17FFJ9EalIyCqMPVT/7OzOnpzGw5zCa1HgsewtxtfeQ0T5AX3A00UdP0mLIZFmahEpyASrcwyYan9lMbHUTYXcvpC00AcOWI2j8QwkKMXgXmmvV+K5diCYzCf3hEjYlLSA9C9wSHKqGyUU7cC5d+MF1Twr0/vWYYc+subz35H5W3prN0dfKsK1Yxqxp49FuzFbv/PLahdfilrwoV1tw3QG8doFjM4QQZUAn8ANJkqquctsy1wnJ5cJZ14L+KsS++7QoA/2QbHY8VjsK/T9HwF5nSyfq+Mhr7gK/rxlUCm/U9fPhMZlRGHTnjdJxpA0UNhs3V+/mOXcmN96ShBDeOUdb6SlUt66gYxQGbbBpVwc5G16l6+ZbkLoGKCvpo8R3KhMFBL/9HoPCwczyt5AyklE1dVPvF0O4Z4DW1TfR0SnIrTqEXRIELshHvbua+ugMdDYLLe8dZqLazfcDpqIzQ90gVGsiCW85zI+r9tIflUiZfxL5ERIT60oQdjNPJS5iqFEwyTBG2qZ3cEoCq8HIrdZyAosb6N5zjMCZE9GkxcO6FTT/v9/j9khEWw+jGRlGrJzO6S4nETfMpf+5Lfg11OM+eIDwQC3asACCvnsPtpIanB4f7CerCfz+/RzyJDI9yztu9uoGtFOyaC7vJjmkm9EuG86wcFQ9PSimTkCTlYx54z5UkSH4Z8XTt6+aiolziWxuR+XRMvzUm0hOF/aOPk7nzOWGZ25HoRC4fRIZ/slh8p76GUofHc6mDtydfR/sFiCAmT7DHGoLYNAisTRRonSHh8XRH7+24T5wx2QNh97pofRdG0Uz1/KL6R/ei7YTVegKPhtrwLlcNcElhNAAa4CfnOfwSSBekqQxIcQKYD1w3p+JEOIR4BGAuLi4q9U9mWuI9VCJdxfa64xhfgHWfYX43DDnenflqmA7Xo7vTYuvWf0eCTaehpQgyL5IBCjv2ruPO4ccagVdSzNpzRWoVy4guM2Xwg6wOCX6XtrL8PwF6NsEQZYhOjcdI8rHn2mrJ2K4KZdNx4ZY3PAPQqPt9Lo0+DQ2YFwzF9vq+eQ2nqBpQjr7/WZyc/1OWmw6emxOnAEBaB68Hef+Qwx6NFhD/Zg5fAqpsos3QieTrigjUtHJjnojEwYbCBUmrH0jaMM1zA614l+8E0ZGObj4LlJqilAW5KHdspOs4Qbip8ahKavkVPYEkkxNtBj94eYV7PRLRS+5yRwbQikkFMMCZ98ACpOFpFWL2a5NQzu0hZold7PqyBvoRwYx3H4r5i0HUSdGo87Kw1HrZjgmkbAB77i5h02MPPsuuqkTcCnVqOxWChOmMMlyCMvOI2jzcxBKBUH/9jC2QyWYdx5BfPMByvoM3GjqpCt7Eum+DiSPh93KVJanq1AowNnTT+pzf0f5s69i3XEY/dwpWI+UYrxrBe6BYZytXUguF/o/PU1vaC6JgYLhoR4izA5sh0+iyUlFGfBhIGRXRw/mnUfJe3QFh145TojCjr/uQ49S5+kWlCGBH9mN4LPiampcNwAnJUnqOfeAJEmjZ33eIoT4ixAiRJKk/vPk/Tvwd/BGzriK/ZO5BkgOJ67u/kuOhHEtUYYE4jGZ8djs12SbFNuJKlSRoZ9JBHuPyYzQaa84NNOFsLu88xhz4rxbZpyPfgso8aAcGUMZ+NFMhxpd+B44SFqCL45bVvF0ucCohUEzBO8uZebSdHSpSsxbd1M6oKZ++kKy6oo5npCPtWiYmacPUTF5Lvnvv8h+dyTBahcap50c1QDOiZkcLa3k7qa3Ganvwm4ycF/nScoLFjPh979HO9CPIy2V/i47zv3H2L3kXkonzGWytRXNiWKWCAd1wp/p1cUoU+LYMHUN2W3lxLz9HmUT5lKdEMZ3p49S9fQfmRFkYyg8BIZGCPrxV6FsiK7txRx59KdMjIhgmauDoX/9H5wjo6jnF+Csa0ao1WgLshnpHyXwZ48z7FSQqTERXJCOrbgKz+Ao/g+sRRUdTsahSmpisxjrgrk+g5jeKsJjc2CYPxXlsnkMT4KxXe8SdLIIjdKBMisZ/fSJCJUKy44juPsGMf3k+7SN6ZkTLeHs9qcxOJGJGd5doDM14KuRMO8uxLxhLxH/+SiqiGDMe44z9OfX8Vk8HdOrW1CGBqJOisHnhrnY/AKR7OF0p8YytGUHC/51AQwMYTtejmfEhCQE7t5BlKFB+N27GqFUUj99IetO7cZaGAceCXtlHe7eAdRx1yE0PFdXcN3JBcyEQogIoEeSJEkIMRVvxI6Bq9i2zHXCsvsYhgVTr3c3PkA/rwDr/mJ8ls26qvW6uvpwdfR67foHTiC0arTZKahT4j7YyvzT4B42Yd526EMBNR6CTbLa8blG83UjdthwClanefdXOh+l3d6dhoPrTzHin4G91puuUUJ/Uz9hxw4xMG8WDaGh6Nu9XoXLkqGnZYiTDQN0ZE5i9L93UplagC0kgBCri7yRYdQVlYTv2kWL1kDkA2k0iBzS332fgVtupSMlB9vRXRQpo2kLTiL//RfYP/NOFlbuojI4jpDSYqqzJ3IsPZnwwU4SdU10TJtLUncdVpOd3oREthbcjo9wcf+WJ2n/zrdRHi5iafsxIn1c1MRP4NDMG5nSWsaBZjORIf74jnYTHBtIfVAqx4vHCDteRMnqu/ja6gisb2yi/69v4DFZ0M/IxX7wBO7hMSSdhtF39qHyuIlTSNgM0SRVF+JWuAj5xbe85kWnG2d7N6rCE1jCJqAvrUYxIQD96vmYN+3HsHg6ZQOQEQKV+hgmSSdRhoeinZiGZfcxAn/4IGO7jnNs5hp83XpuTAdXVz9blSHUDniXHCQFQLqti9EXC/GMWQj64QOoIoJx9Q9h2XIQt/X/s/fXYZKc570+flczD/QwM9PuzsziLDOKWWaS7RzbsfNz4sQ5J/kmOTEex3ZMsmyLV6yVlplhaHeYmWe6p5mhfn+0vJIilmVLWs99XXvtTFV1dV1d7/RT7/M+z+fjAZXidTJhogjH4hez5ep+GvJTMUoDMG8lZLJGUsFSKaLTjbK8ANHrxfn8cXwhKPIYyEhQ4DnTCHIZsoxk9Ddv+NBaUD6QwCUIghbYBHzxNdu+BCCK4q+A24AHBEEIAh7gLvGjLJK4wLsi7HQTdnmQJX44vRxvhizRiE8iwfHcsTfsE71+5JnJqF55on0veM42obt5w/VCA9EfwNfeh2PvIZTlBSjL32KB6G0I2Rw4XzyB4e7tCErFe379+2HaCSeH4PaSN/Zg/ZELYxAOR1S/7U2D6O/eji8UkXBqOtkHAyNY9+wkSi0jHIJ9PaCWg9Udpuz0WWyrd9B2Zox0ZTRX/dEkeqC0p4mREStiYI4L675M6aXD+LuDZFUWkxA6wIxER1rzRaarV5By+Az5p88ykphLugFmDAno3TYaylaS2dND+qJ48oeuYohRofH6uKzPQbO4hMyOa2SaRokZGyB/XTHnE7NZ3vhjNKZZLq7chSuzCG9YQsKWpSgfe4qcyR5O3fMZ5pKyqPzVj5EkZZCRacCzsgLzZ74TafxGQFlZiDwlAffVHsZuvhU2raMmT4kqxcjxv/0Do0vXIh3tpErrwnXwHK6XToNMStjmJGSxk+Lqoz+rkvaSbSwK+EAqQaJSMmSFxfYBogU/EoeDQL8X/c0bUFaXMvbjJ2naeidrS6KI14LZDa3HB7FkFmD3wbZEJ8K5y/i1amQp8UiT4wnZXQwePEigvZe8r9yOPC0R53NHUS161X/k1DAsz5CgFQqp+sG/oijLIzA4jjQ+BuWiIiQxhjesqZ7tg+U6G9oYFRKVkrDXh+m7P0MWY0C9tuZDkSFbENld4B3x940QnJhFGmOIiNvGRARuXftOoV5X8waDwI8ygaEJPFdakUbpUK+uflelvCGTBW9D+1uundn3HkK7ZeUb0mlve067E+dzx9Dfvf0v6v78VEckaL2ZSo8oRnqzYtQRc8D5ERPd5/uZWrwMpRSKXOP0n+9ly1cjFWiTDjg6EOnl2pkPA0+f5aS6kGl9HIvOHeCRwl2UJQjIZ6f56oXfIr9rBwqFjPYuC/Ebl9B1vJ3Ynk5G5bEUKRy4p62cl2dS03AYo9vCdFEFZe0X+O3GL2PZtJm7nvhP6jfcysoL+1CPjjK8ZCV2ewB10EeWdYS5rEI0I0PIbTZ8MbHE9bSD348m5Met0CAYo4mLU+NFhmnUwsUvfZuZrAK+ceE3xG+txds9zLHeINUnn0NTXUKgexhZbiqhZbWMtk8ilBaQ3XKJ5J9/5/rn9Z9/d4hP6kYhGORy3R5u3pyGIIDdB5P//jtmU7MZqlzOroGTWJp6mAmpcG3bTO3aXM5enEJ/4hTlyQJSmRRJtB7VttVceaYJ1eQEi/7XLnoCOjrnIEYFVZcPE7NtBQd/X0+iXsKSO2vxdQ7gvdSCEGekUZ9NrN2ENi+dq4ZsVqZD0twowZEpNBuWMmqD7pkQq6av4njmKFJjFMZ//OLbpqNtPjg3AjsLXt0WnDHjb+9HlpVCaHYe9fLKP3lcLqjDL/CB4rl4LaLcvaiYsM1B2GInZHUQdrqQJcV/IIP2wyBksuA+14SgUKDdXve2T42OZ4+i3b76LaWORJ8f+5MHr68HvON7f0hBa8YVkexZn/3GfeIrhRoFxkj6ShThyI8PU/OZdRhjlASnTbQ+V8/clm2AQFCMpA1HrTBkA2//GJnWCUarllHbfYFnxDyGNYnssFxjrbWDoYJKNgeHOFCyhZwLx3Hs3M60NYjs//sxsrwMkl0mrumzWH7iaXyiFH2cFjEY5lTGUqJ1ClLtkySEnISmTdiS0riSVMGlqk0s8g5TeuA5jJYpNEIQVCqQCBhts4hWOyG1Fq8hiiT7NEOphaSZRjGFFCCTkry2ivGOCaIdZmLVINpdjEn05H/nfvwPPY0gCCju2MFpQzF1x/cSdfsmgiOTKJeUoshN53yLjZanLnDL2Dm0t21ibsVqOk90ESotIqHhEnHjg6T/w6dQv6LU5W3pxv7ofiSbV9N/oQ/pyBiS27aTZx1Ff9smpmxhen/4FBk7ahk2ZiJ56QiJ6xeRFbQQGJvC9fxx5IXZUJDDvmEFt8aZCU3M4vnMvZxxx7HBdI1orQT10grCYqR4xuSGVb1n0ZblcfLCNHVMEDZZMLzSYuGtb7tecfhmvNAdSQFrXqM25tx/Bs2aaiR67Qc2NheMJBd4A6LP/55TUaIo4j58HqkxGu3mFZGN8TFv/6KPEdK4GPQ3b8TfPYTnbBOat6iKDNmdCAr52+rzCUoF2k3LcR04+7ZfAhApunA+dwz9XX/eoCWGQnjONCJo1EgMWqRROppmdawp1RApin6VYDgi8bQsLSLjBNB9eYi0bCPGGCUhqwP74Ys0VG8nyiswbHtlkdrqQ9Hdy526CeatPi6u2E1V0ELnpB/ZkgQ+23WSHlUyfeEo5rpn+GRyHWv3X2Bi8XIMFsjta6M/LZPSlnqu5C4j0TLHlD4RBAnDhkTmVNGsaDqCIhzEV1zIr+75Ll+aPoZ44AJjS4tID1owNDZSaB1CoVMRDAmEpWEMGilheQz+YBh5TTkF6ysYeOw4OXMjDGYUYZ2xsyjGj+TAUbLijcz6BCaNKai+cSc09zD7g0eIr8pGu62OdmUy1T/5EbKlJTj3nSL6y3fiPtdMvyGdy787zZKZDjq+8U3Ujc34j7egC4VQP/wwyt5eer/1LboGIw8BggDqDjc6IQbpC5ewr1nDiRV3st7dx5w2l6ZLUmL7eihMiGOibZzS+EkUSjsK+yRBmwPR6cbwpTvRblhK2OpA/ZsG5qKTGN92FwFR4DbpAKGgC/XSyNquRIDVmRH5ruOsIvz8CVZszkO4Nobh/l3IXvlblmg1+PtHUeS9sYJ71gVa+euDFrwir/YBBq33w0LgusERRRHzfzyI8f98+V3nosVQCOdzx1BWFf1JpoUfBxRF2a+kQmfetKzXc7IezYal73geWWoiksFxfO39KMvy3vSYsMOF49mj6O/c9mfvNfN3DSFo1MhSEwjbnfhGppF1OAn12HFF6dBsWYkgkTDngp/WgxTY+splh91eJs+2s+7vdhJ2e3G+cJxzi7cjs7hRn7tMTvcQgbh4EvRR7L6lgH31Umb62vFrexjt62Nw+Saqjz9DRVKI9LlpugdtzCXlsHJRkHi3lG5dDPM2cF8cprbjAk9v/ByZw50UXz6GV6Yg3ShndMxCChJGH/gqaYdeQjMywsqXf8+cUcVgYh6fuvwoVpTMeiQ4jAkE5u3EKwUyokBRnsvxdhelsQGyti3B4/ThN8YhCThIvnQWY0Ii45v2sPQ//xfWnz6OpqkXXWEitgefIMuoomfjZlITQBJrQPa93xO9sRbV8gpEl4fA4DhT9X10tDxLzsQgZZ9dj3ZxHGJRHfb//TM0N61HiC6AFVkU+AYINTYgz03HPjyD5exVRu75BOKcGanTyRerwsgfaaJDlsC9sSMkLc3G2xxAqlcgS0tAtq4Ky3/8FsMXbyc0No0sLRHnM0cQNGoSty3nQljDulhIdc7gvTqA7tY3Nohr5LC7WEIoazXOvQfR7Vn/Ovkw9fpaHI++jDw96Q0Pt2dHIgU8ryUwOoU8PemDHq7vmYXAdYMTmjYhS47H39qLsvKdPbbDXh/Op4+g2bT8T7Z3/7ig3V6H/dGXMdyz43UK7GGXBzEcftdPl5q6Jdj3HkKWlvi6fpiwy4P75BVEnx/9HVvfUV0dIDhtwt85QNjpRgwEr2+X6DTvqmLS3zUQKSZ5pQilYxJiM8Gmh6nuKTzff4mOqjX0EoPJE1nT+uxLsCgJqi6dor1yHYqhEBVnD+NbspipZ05hjFUxMubi9q9sRuwb4mC/wA9O+bF1z3DLd+4g6me/pM+YSfVzP8JbU81j+evIvXAUXTzEJ0eR9IffMlVUScZgOyXWQYRrx3DEJrJROkHR4Cl8fg9hg44JvxynVMpUyWKixydp+/v/w5EGK9/7zRfpMeaQnxKHfGIMuVxPSmYqs9Ik0pwOHGXlhAJzjHZPE2dykvzzb6PSK+j+5z+QHHYyPR9A0EURnxoF21fR39RD7LyNtKd/SOcv9lOyZzmyuGjSjzQy1Wkl6koLwcQEZJnJBDoG0N+3k16zQIOshKxrF3EPBInbuhRBgKDHRXjOTH/jGD6rE2dZOaJdDjINivppVDKIX1LKyqQAqvwEvO0OBv/ue7jSM1hWFkbwebH99llUi0uQZaci0Wux/r9HkUTpcB84S3ByFt3OtehuihQIFdtB74B00Y7rVD36e3e87sFUFMHihRknTM/7kL1wkOjbNpFv0PHaZLYgCGh3rsH58mn0t22+vn3MDglarhtT/hFfUyeabe9sa/PnZiFw3eD4e0fQ3bQe98kr7xi4xHAYx5MH0d266Yb3t3otglSKbucanC+det0fr/t0w3vuT9PdtB7n00fQ37cT0e2NBCx/AM2GZe+6eEMURdzHLqHdXodEr31dMHWfbXzL1M7114fDWDwiB7pl1ysHGyZhQzaITkgqTOaqegc1l87gcMXxxfuqSNTBV2vg6cfaaVBnsVTvJf7xZ7kW1nBhcprZmvXcNnwScccKzsgTaY/LJk0cJv/hp4jfs5rMlssMDs6S6+nj4ZWfoCjsouTky2CxYtm+FdeMFeGem2mT51LUeYWZ0/UE9MmEv/Q5fPuOIQ6N0l+7AZlBR3B0CsU9u7DlVzPggUkT3BNs48SyPVQ2nsQkleCSxzCTlMVycx9xghSiNMT45jkmprJo4iLRMiX9P9hLYooeb/ViBJ+JkSmBpL99APHHPyL6M1/GrjciFqZAYysSkxnrE42IM2ZUaYlM5VYgcZqI/eJtBI4eRrN+KQMWgc45kOg1eCbNxGxdDoEgzmOX8HcOMPLZLxIn+MjuuIauJhZBKgOpBEGShOP5E2hW1xC2Wpk9P8JA8ygpZbkU37UORBHP2UYSfvT/Q1DIsT95kODYNLHf+TwSvRb7Iy8hy0p53cw/RQ/1gz5MDSew7dxJ67gEsyfSVA6RZHCsGhLlAYrPHiL2ixsZDus4Pgj+UKSitDgu4nAsjY1ClpqIr63venXshVG4reSN4zLs8/9F12XfioXAdYMTmjEhXb0EqTGakMmCNO6t16m8l1pQr67+qwpaf0QaF4M8IxlvcyeqxSWIPj+iy/2eVdklKiXqtTXYfvU00kQjmvVL31O1IYC/cwBFaS5SY/Qb9qnrlkRSO1kpb1nSb20f4Zo2iztKI2sdNm/ki2p9Nlydhl82Ql6MnMdiNrInqg/7EweY2bwemz1A0sQAdo8aWYeJEUk0P83ejTss4d9nTjGbV8IfTInkh6B7xI3s6jW6d32Ougv7MDSdQtDo6NfHsaLtFEPZpch0UeSP9dIv1VE4fZXGgh3kNVxhYj5IbpSe+cxCNM8eZNXp57BXLWL0pttZ9ZN/Zf/6T9CtrSbDCfJwEJXFhvLMRcaTypEW1uKITcSSV8DdL/6UYJQBr93NhF6LxmahUj7NmCqeNIPIiU//HeunGikc6+RXoTIKF8kpk1pwJMQSmjGjT9Bi6p3Ed+4nxKlVOIwJGAw6CAQwdrYznltM1f4DOBvaGRSj6POP4Q6AXgHzjiBLM9Q4njmKZuMygnYXYwnZVIycQ33P9teNG//QBKLTja9jgD6HjLmUYtb9aCtKpQx/7zC+a93o792JIJFE1C0CQWTJcYgeH4JBh3bTcsIuz+vucdhiI+vkSTo2biZBKqcoKhKoXlstKgaDOJ44hPaWtUhj9BQQKb6ByNpX5xxcnQFEiEupJO/kAZKyU+n3aciNiciAvZZA30ikR+0jwELg+itAEATUKxe9rfWHGAoRGJ5EvXLRX/jqPjqoastxPH0YeVYq3qZO1Kvfn4yVPCOZqAfufF/9LaIo4mvuQn/fG40UnH6onxBYvWEp7uOX39SJOBiGK8f72PCpVQTa+1CU5XFlQmBJMvx3A1g88LfL4L8b4VvLYWlaPmFnKvZnj+A508jg4lWU3bSKi4e6GYjJJiRIuGn6Egc0qfS7MtiWC13TAT4/dpj5nDhWdD6PSuPBHR3LxeqteHLzWDt8CYVtFJ9EhnfnVra98Eu8n78f6enDWFKzWdHxPJ6keFbNXEVobKVr5+0MFdew6uf/wemyjRzNXcMSDdRF2XH84gniBrtpL1uJPb8ceVQeZe1NxOz9CfEFRjQrqzjeF2TJtTNYYhLpXbGVCWM6cvsIa6avknj+NH0r1pD/zEnKNhThfOIA0oRYECF0tQtFTCzhu2/GmB1P/8sNeNavxBaTgGRknEarBn9ZNjJFKnPqGJJvriU0PIaku4/qNAOqlDiUm5YRdnnodSlZkRxC7Hj9w463ox/7g88ie+B+DgRSWZwES+Je2dfUQWjGjO72LQTHpvFcvIY0NgrDfTshHMb++AH0t256nVKLKIp4L14jOG2i6is73rLoSgyHcew9jGbLyjd9ANLIoTol8g9gzgWdtevo+fUpJjbt4N7yN57T19r7oQjqvhkLgesG5rXmihKdBtHrQwwG3/RJ3Xup5WNb2v5BotuzHvuTBxFUSmRJce/7PO+3KfOPsy3Xy6cRlAo0m1cgCAKiGFG7yI2Fy6FEqvydBOcs16vDILKu8WKXyDKNi+C+I0jKC5h4aD9tgRweiiuhOF4gJwZ+dAlWpMPStMjrQlYHwd4RZKtq6BuDeD/cFzPJMyWLuWfkKgW1WvaqCqjvh+EpN2sPPYEQCKCds+KNj8bfNMJAUQ2VxhD6k08gOF3IBIGcHUsInn0J69QcGf/0HaILShlraUUdp2JMpiOxpYum6i2YJAkkHj7CrKjkUN1drE0HgwLGnztD/PQMeoOShuotfLPlUaQKKeHhUXo+9UV81lEMFztIQ0rc7jqihibQuscJS6ScqtzIxv/3T7TGxtHeZCFr1xYykmzYVYqIZFNJLhK9GrVMTpckGkMghLOgkNCoCV1LN6aCUhJwkfLTn+HOSKcYF+1uC4GMdLZ/ehUyzavpMmdjJxO5pVQM9yBZVHR9u+vEZdzHLzN01/1MhFO5uejVpm/P5VZEvx95bjqOJw4iS0tEf+umV3uqJBL0t2/B8dQhDPfuRFDICdmduPadQrm4GP3bPGCKoojj6SNo1tW8a3GAeC2sKdXgC+YRsjQi+spf5+gghkIQDr/nxv0/Fwt9XDcw3qtdSPTa6+sh/v5RwvM2VLWvf5wSQyEcew9juHfHh3GZHzmCEzMglf5Jgev9IIoijsf2g1pJoHMQSbSesNNN1Bdv54Q1muK4SLn6mRGIkwVIP3kY/X07rwfJI/2gabmG6sJlxj/5acIKJY2T4O8e5G/kHcSUZ/GLQBmtc8L1L9Gw10fcU08zFJXGmeJ1VGmdlL30BB1x+WQUJ5Evd9KRV43b5qalcYpdz/4XYiBI87pbiAk4UZbkkPz73xH9lbsZHrSw17gcXWI0/5LQw8x//JYor4OAyUowJYmQzcVYTBo5KWrcJ68wmZrLkruWMf/4IVxmJxM7b8KqjcXmA81APxkDbUQtysdUVE7si/tQGlTYc/IpvbsOaVcfL5yaoW76KqEv3k9erITA8AQdrTNk+83Ib9/KyL//gQN5G5AoZGzcU0T283sJuTwglSHIpciM0UhT4pn3CDT3u9nw/z5D6H//iNhvfZpnZo2Ix84yG51EWV8DyuER+r/2Te6pVr6hYOHKzw6R8altaA8cQn/nVgBsf9jHeNc0Hbd9gsV5GvJjXz3ee62bwMAYotuDojAbZXXpW8qGhczWiGBvYRbBoQm0u9a+bXHPH13IVdVlyLNS3tc49LX0RFKWPv/1bWG7C9XySpTFOe/rnO/EQh/XAtcJDE+g27n2+u/y3HQcr/hmvRbPxWuoV1T9ZS/uI8yHoXYN4L3ciq9vBIlSQex3I+ppvqtdDPzTb4hPSSPtb28G1KzJhJd65Ojzi1A0tNOTUc6pYVCNjrLm7AnS/v0rlGuV/LwB7F74wedzcPhz+PGTQ+SMHuSeGiMnbJXsrFbhfewwLXIFKz63BsU0rEzT0nUlkTt008w99AJz23egHTpOaMLCntEe+rXJaBJjSNWFWfPZrXTf9m3MX/sKcScPUK8o4xPyeuRuge5zE2QMD6JdXIjxD//Kucfq6bVLWX/oYWxtLvQBDyXeKUIvHMFq9TJZuQyHzojTG6ZitgfZeDfBomy0Sgldw2bKU7RYLR7ii9Pong3jVGeyTLxKU24tmyeHkZTUINWqMeWuJPuJX+D4+n8g/u4XJHQ6CLq9qB98hCm1iszqUnytvchz0oj69M1MOMD0v39F3PbNOH+1l6j8TNp+fQi/S8vWODvtDh+KrGTarRL2XHoO5bJ7Xn/PrC7scg2Jch8upRwxEKL7fz/ElNpI2re+wJ0Jrw9I/u4h/N1DEAigv3v7O85gpMZoJMYo3McuoSjKxnXo3Ks7/9gg9lpEEeWi4vcdtACUlYXvqgL5w2QhcN3IBIKvk3MRBAFZUhzBqbnrpe5iKERwdBrNm9hWLPCXQQwGcR29iPvYJeQlOUR/8qbrsyhPcQltf1PCtukrmP7PL1BWFiJRK1kbFrk8DpLhFkbih1lkULI4R0F4ZTH6KBXfuxCRZErURbQEr0xAflU2OxLH0BikbB+/RNPjXQgeL6X//gUcEgnquRkGHztKmjpM45QC5de+hndgjBhlmIvyBALOfqr2LEERH0VIFDj/zQeJqigldrCXp9QVLP70RuRqqHX00fHNA7g2bSC+II7xPxzE1jjF8muXsEsl6EU//rR0FBuX0dQ0hS9LgeXzn8Nh8xJ/5TxiTgZuQxT+6Ggkba1Umvuxa5UkbFyCsr+dqOF+pNMzRG1bzr6+GLxM4m/uZGjjLooefgKJTIb85k14f/0k6dU1RJ05TUppKnZNFOfiylgc7EYWDNHhN+B47gDlty5F9Jm4FIzCF59Pzvh5Sgw+4r98B2qzirOP13NzgofhET/G7iGURa9KjzQfaKdsSynui9ewekXaP/dDYvasYcOty99wnwPDk3gaOxCDQaLu2fGu0m6iKBI22zD+4xc+uAF3A7AQuG5QxGAQ3kSDTLWiCtehc+hvifg8eS5c/asuyPgwCDvdBMamCY5NE7Y7I1VkKfHIs1LRb199vfw9GIaD/XBbMShLlqJaUoLr0Hkkeg2aNTWsFBT8n5eWU/TSkxgKjdg6nRjS47k6IdIyI7A7H8YckaAVFkEuiJzvdKNO1yEdsjHhVhAuLCG49zyt0yLS1ATsoo7ewipKnSOYekZJOXaQnvhcVowPo1NLyHRNcRkNqiMn0aWlMp6eR3RPF+KtnyHk97Os5wwdRzvJuX0V1nOttMk1RPX3kNA7glOjJ1niwRcdx9DK9Ri7B+kz5lCu87Co8WUOzMeQdksdkh/9ipicFKJ7mxiValGtqyXsCxKdk0RQr0A3No13fAr12hoqJ67S0jnPxk8uZfalc1RMjyOJjebFHV8is/E8Vceex7+ogkkU5OTEo7cNcbVsFeGCHBJfPk7tbYvxnG0iPG9n5dc+gSocoKW9kfRcI06ZmhE7fOnrtYgPDSKx2mg53E51Zgq+q124x2ZRvHQBWVs0s10jDC1dw6r/+1mUqW/sfwxOzeE+XY8YDhP1yprVu8HX0I6y+sMxa/wo8/79GBb4SBMYnkSe+cZ0gUSthFAI0R9ADAYJjs38SWmFBd4doihif+IA1l8+hftMAwRDqJaWo79tM/q7tuGtb0NRUfC6NOXBPtiY/WoTqDRKj+GubYSLCzn285M88V/nmR+cQWbQYUtMo0mWyqlgCoe/t587R05i/s2z+PpH8QZBLgXbyByyoSGuaHMRh8bYHOcgY8dS+hfV4czJo4h5lsvmqDy4F6/NTYU4S/fffptU6wT+sEDR9kVYkjLQPr8PxfQUsx6ImZvE5Q5S/OLj1Jx8jhFTCNWtm3FfaWdKn4B7YII+nwZLbAr9u29nvG4DM2ojGS1XuBRVyJq+M2Ssr+CguoTsaJG4b3+XsFRGtHMenceBuHsTM4poyv7xbkImK/72AXztAygqixi/+9vk9FwlYbCba+eGkKbEE563cWTxToxzE6xN8hH/z18iCj/ztgCKyiKkM7Ns+NQyqsRZCnMNeC9cQ3/bZtTLKlCODOPdd5yJTdtRKaQcaHJycxFEqwWiP38rmrFhQuXF9PzkWbzNnfQNO0msycf2+c8wtP1m1v/DzW8atEImC66D5xBDIQx3v3Ul4JuNGX/P8J9tXenjzEJxxg2K69A51KsWv6nqQ2B0isDIJIRF5FkpbxrgFvjgEP0BHE8fIWR1oFpURGBsGolei3pZBVJjNN6GdlzHLhH7D5+7niJsnorMkqpfc2tGbPCTyxE19iIj2LtHqdj/BNQtwzE6S17LRfaXbkPpsFE43Ipg0KHxOTFmxFH6t7fg+cPz2FespPfJM0wlZTNdXEXFoaeZEPSUZalRDA8jS4ojvSaHsMnCyWklmo5OjKN9CFodvjv20NM2i1kdxdK20yRMjzCy/SYGiCY6UYd+fg5nSIraaUc1MUby7AjenVuZvjbE2O7bqXnpYcRpE96vfpazylzyj++jYrAJW0YOoaICJFca8YsCGW4TGo+Djk9+CafKwNInf4lMAGV1CXj9KJaUELLYmTx2Fe36WkIGPVefqac6Q8qALI7h9du43d6MND6GQP8YgdFJrNn5+C+1EK+XIE2MRVAqUFUWoV5fiyAIhMNhWr/yE7y37cGUnovf7mZZ/0VS7nrVgdp1pgHbQy8wGpVCvMfCrCqW6MJUHANTFFQkIY3Rv9oB/MfvVUEAAUJmG4a7t78n/yrPpRakxqgbXnYNFoozFniFsMP1llJF8ozkiCGcVPKW4rILfDD8UZ9QvaYGf9cg6pWLUBMR7/VebiE0a8F9uoH4n3z7etCaccG4PeKHBRH5nf09cGoEPlkB+UZomwhSOdVM8k//lgylj6bGGZo10ahmnRgMCn5z/7+CRMJqvYUV7ScZvu+f8ApS+j2JlM0Ok7I4h6N2G9du+xTlR59jOKRl5aIc5LMz2B7dz7RUR8VAL16NlsmcYqrTZQwdPYk0oZA1/iFsJcXoUlRY5pxUuocoTcniytA0fm0c1c1HUUVrmC8vZvpiJ9Nli1l24SU0TitHV+zAnFxDWASHUo8iKxVZQhze01dIcM4QDoF/aRV9xYuxTTtY+dLPiP2Hz6HbvRbXy6dRFGYhL8jC8ejLDH36c1T91/dQ/OxfmO8XGR5oZnLHKm6euYyolINMSshiQ7t5JTqlgtM+DcnhSdTVpSjyM64344vBIG0PHYdli5k1ezkdhPsrNBgmpYQs9usN5No1NZi/+3My0oJ0G/MwxGmZ3HkTNZcPor9r2/Viide2QgTnLLj2n0Z/59b3FLTEcJhA/yjq5bs+kHF4o7GQKrwBEd+s2uh/oCjPf98Ntn/tREwk+3G8cBzH04ex7z2E+2wjYaf7dccFp004XziB/s5t+DsHUK96dS1RatBFZrpiGOO//Q3SV77U/CE4NgBbc0U65+DJdniyDQYt8NAu2FEQUT3IvHCCjsrVNJsUfPOynu6GMQ4kL8MQ8uDxBClR2PiivIOVJTqQy2n59x8irF3BuuAgyfdvYSQlj623lJB56RSekhLKW89iefwQw00jdMfnITrcOKPjkGWloZYLnI8q5mL1NgqMEnpkcaSEHTy6+tNkdzfjik/k6ZkYZquqWX71GH5RIDBjZjIqmczyVJaPNpLW18JcaSXzW7bi9sNY9wwru85wIbGSsRkfEreT0JyVcF42BiFAeHE5m2QTRO1ajX7POtyHziHPy0Cen4nzxROoVi5CPThI3H9+jeEfP83a6UbUg4Os6ruAJBREnpGEc+9htJtXIC6pYLptlNhrTRxIqOFwTCUvmWKYckTWG4d/e4CZymqqPrmWQNcgu/JCDNmgtWgZ7hNXIvdcFLH8Yi/SjGQ8rgChDauQBQPUXDmEJFqPIAgIEsnrgpa3vg3P2UYM9+16z2o03kstqFYs9FW+FQupwhuQ4IwZf9fge9bZW+DtcZ9uIDRrRpDLkRdkosjPfLWQYmIGb1MnYbcXWWoCAY0OycBQRLFbEHA8fQTDXdsACHt8uA6cQZYUh2rloutfdjYv/KoJUnTgefkkolrNcOVyEnXwQIUfR2M3B89MoZqZJiYvmYmiSpz6WHwSObOPHsTjFxlYtJK6XDmfevGHqGrK6Tvdxfjnv4hifh4/EsqHmpnqnibTO8uUqEW3ahHNh9pImhrCGRVHoteCJDkeddgPGWn4zjZwbsf9FKYokDy9H51lFn2CAalGyWUxiUXdl5DmZ+GetRLttaMsziE0Nk33yk0kXTxLVEEqsyE1LkHGfFiJJz4JiydEZZqS2JlR5rZtR/2Lh4hvbcb6nW9SYhrAEZIyPzBDWroOzepqCASRpSQgz07F+uBzSA1a/KKE4ZYxErLjsDb1EOezY3KFkfp9+PVRhPR6Jr/2NcI6HQlnT2KwzqGVC5xbfwefqIBAGK40zeE6dpHBus08UKdmwALPH5vkq0mTaNZU02OCwafPkrE4E/uhCwQQCBsMxI/2Ykw3Ev+FW3C9dIqQ3Yn+lk3X+/5EfwDnS6eQZ6eiWvLeCyv+GvsqF4wkF8Bzvhl5TtrrpGIW+NMIma0RF+Q3kVn6n4x0zXClcRZVTRmb8wRo6UCi0yDPTsPXOYC/vR/tzjVIo/X4Q3CoH0LhSHqwLAHi1CK9Dx1Gl5+K51ovok7LuEdOa2IRSaoQSeMDDKXkE5ydR26eR2edI3O8h5ZbPslQdAb5V05wa8fL6OdNWFV6rIZ4YpP0hAQJlw35xO9cRcEffk3QF8A3ZSJuZhS/Ssvk8tV0JBSS1HSZysFG5Eo5JkM8Ce55zizaRFaqjvEly9n18Pd5WlPJ6quHMUSpGNYlk61woc9OJtDWx3xMApeMpSjFAHmNZwkvqcSTnoknJpaE4V5M4xYKpntxaKOQy6VMh1WMli1lqbmD8NQcNlFOpi6EdGIKZU0puj3rCQ5N4LvWjW7POtQrF9Pzk+eQKeX0JeVTJs5i/a/HUWxZReySfKR9Q4TmLIguF7K0ZHxdAyiLc9CsX8qgNpmgUk2RdwpP1xAv5W1ArZRyeSKihr4zH9JPHIyoWCjkeB0+hp45Q1zQgfHTu3E8cQgxFESeHI92Wx32vYfQ37YJ1/6zSGINKPIzcR29iG7nmjeVWno3uM81IU9NRJ6T9r5e/3FkIXAtgOOZI+hu2/y+ZYcWeCPOfafQbFj6jusUwTA81QH3lEXs248OQsHpAxRtX4T1h79HvaIKwz07rh/7dAdsz48c2zXqIckySf3+dlItE7TH5mJQC7h0MfiW1jA86WZ1x0k8u3fwRLuA3R8p0thy/BFUSbEsiQ8xOu4g+cJpWvVZWEsr2LAkGvWVeiz7zuIsKyfntlV0nO8nOGVGb5pmUmVEHfRAdRWzPikme4DK3nrwB+isXo9dqiLHNU1hxyXM27ah7O4lOD6NOuBBH/JhTkwjxmPHt3k90ssN2L0ifk8ApVREnZOKOSqe0qvn6EkuRJqXiZCTjjA6gau5m4zpQeYTUpGNjJFz5xrU6fHYK6oYPdCAvq0FWUcXodJi4uOUyGKjkaUmILrchO0uegM6ErfU0nlllMzD+0j9wh6kdjshkxVBIUeek4ZEr8XfM4RyRRXec81o19UScnk40+FiTZmehyQVqOVQmgCLmGPofDcpqiASREJTc2i31SGJ1uM504iqtozg1ByCUoG/ox9BoyY4bSI0NRfprdNqCJosiP4Ahnu2vysn7Dfjr3G2BQvFGQu8wkLQ+uAQg0FEr+9dLa4fH4T1WZElxigV3BxrYsBm5cLvz1O4fSOCwxI5pxhxHV6fFSb4wlFaxwXkOjVXDClodEp6PEmU6HzErywjNuDg8vnjrB3pJGpxAX/bHsLilfGFJdDfayLBMYfkE9txPfkYGecvoV69GM/Sm7miLcDc3cj6YRvhwmK0phlcP3oInTKKoZQCQmE5ifYpjNFKrB0taDIyiLHa8BrjiNlQS8ipZy6rlMzm41gENcqDx9DJwnQl5BMbdDMvlSEPBxE31aGZGMFlc6IJiYiZWUizUphZsYqYp57GavfhK9BjSDQS8AVQj44Rsllw/vYnXPl/B7m9MAG5JEx43o7pOz8hrySVMG50//5lTI8eZHzMh27pCrJLE5HEROHvGcbR4+bUoJqbFW7Sv3Yb+ps24GvrIzA2hW776uv3Q7t5Bf7+UdTLK19Vg4iHv+sJsUs6xEr3APQFCSfHUbB9ERKVkpDNieOpQwTGphD7RpDnpCFNisN9/DL6+3YSMltRFGThrW/DcNc2BIWcsMuD3Okm0DeC46nDqOuWvGfDRVEUcZ+sX+irfBcsBK4bjLDLg6B+Z6PCBd493sbOSCn2OzDpiNhKJOlEPFfaUC0qxvKjh0nfsZqEsVkaCmqRHTmJ/ZqPxnklRXEwfr6LM4FCErZkI4rQNgl5Jy6hqChClqYmdOAk3aMmtDIF5tQE/sO/lLXX9vPV79QxpzVi+90jRG+vQv3M7xkfmCXxf3+Fax1W1ti7STYq8NQfxzk8gyY/naGKxXSJsUjcbgpH2piPTyHHMcCc2YNaLzLoklAbmsPhEzkQTGK6oJDEi2fQtV5Fdetmws8fxmN3UeBsQ5CCt7iE5H/5G0J3PcBw5VIs1XXEOMyoQ36y50fJ+PG/EXB6OZ9aTs2lUzjnp1EhYo+OR9i4mp82yvhf8+0kf/1TBOcszDx7Cv+/fJv4ihhm/+Y/CHQNEn/nJrQXrjFpjGefvIgNMhumo01cLN7Fuooo4h7rQ7n7bhwvHEcQJGjfxP3A19yJ7paN2H1wbBCSRRdf7ztExvpyFMVr3tAMLIuPwfCJ3XhO1l9XQ/d1DaIoykYQBJQVhXivtBJ2e68/zEgVcqQxBuTpSYjBIJ5zzXjONqGuW4w8I/ktx0zIYsfX3kdoag4AeXbaQl/lu2AhVXiD4WvpQVAqULxGlmaBPw37EwciunJvM4sNi5EKwLtKwXfqCsHpObxXu1GU5iKRStHdthmJRkVwYoaL58dJ3LAEhRSOff8AbXXb0SkE5tygNs2y+cxequa60G2vo33HXVwb8rD4sV/SFY5BVpyLYUkhdV2n6eo0k2cfRxDDzCRnIVlcxvS0G/VtW3B2DJH+vf/El5yMO7+Axlk5YbmMU+WbuHXgOIagG21TE/q5SaLddizGJDRhH06VnpjMOGwJqZwWMqjsuoglIRXR66O05TxIBXxqLTF2E/7oWFQeJzZVFFqbGbVSgt8XQpuTjKq2jLDby9jxVuYrF9FevZ41Bx5m6oEvEzc2wFxzP6GxKWqSwigyk/Acu4QpNpn8L+/BV9+KGAwT9Zmb8TW0gVqFoJDjPNVIX+ccZ7feS9LcGFsa96HdVod6RRWqqqI3bewNe3w4D52jefFGLF7Ykguh/cfRblz+js7WzpdPo15RhdQYHbEYuefVMWD5yaOolla8raOCGAziOX+V4MQsgvJNlDLCYSTRepRl+UiT4/+qsyQLqcK/cgJDE2i3133Yl3HDEJyYQZaSgCAIBGfMEAwiqJQISgWCSnFdb+7UEKzOBHHWRHB2PtLcnZNKcHSaqPt3XVf0vhhMRGduptME420TyNNSWJQi0D0XaTD+fPsJaitiELsTsQ9Oc+HKDOlj3by05bPE6QS29Z5EduR5OqOTyZzoRYzRoairZWhKjm54HvPOnZSFvWh/+yDujEysmzcT9/uH8ez4LJcMBfzr5Z+jMqgYaRoh3mdBa5llNiWL527/Jp96+v+iI8BQbAbluiCrzh/FIlUh9fkobD6DW6lGY1ATX5CGryuEfnYat85A3+qtpJuGMPR0Y9hYijYlFk9jByPKOA58/p+pme8mfbQXS+ViDI/vpWvNVuJWRGN86mnEzFy8F5sY+9bfYTDPYNt7iJDNCZlpOB89QrB6EeGjZwg4PBy6+QFS41uIlgRIDdkwZxUQk52Gqrb8Tb/0Q/M2hp86Q3PxCpbGRJT1Q/M2PFLpOwYtAM3G5bhePoWqtgJZeuLr3kNVW45qcfHbvl6QyRYqe/9MfGCBSxCEYcABhIDg/4yeQuSu/xewHXADnxJFsfmDev8FIoiBwLvWQVvgnfFcakG3ay0hsxX30QvIczMi7sheH6IvQNjrwxmXhC9zEWm6MPbnzwMC+ru2Yv3V0whqJXNmLxNKGLTCrAtypQZmJuwk9rSSes8muq0wZIVfbxcZf7mPsWEvg7fdR2OnjfIDLzKRVUhUSgy3l4rsn68gfdVm0vc/j5wwEq2Wzn1NqBMSsXzyfjbJp+j4xoOoTRZUW8vIPfQi55dupKbEwGdO/ppr8z40XhFLQjIpnaOY49OY1Sewvf55xrNLyZgfxeh38pCrlNwEP8agi5zxbnxKNWxYTTDsx2K1k1KWzXjSBnwnL1Ez245aKWFGKse0eTvx1mHOLt7OH3rVrJq+isfmQE2Ay1495YKSgqMvohgdJU4tMn/UzMyttxF0hwk5goQENUr7HN5VqxHWLENptTC/uBbbr59m00sPYr73Hm7JU+J86CKmb36OU5NeVj26n6g7Nl+3lA+ZrZiP1tNqU6Jev4nbCtTX2xrdxy+hfY1jwtshUSuRxEbj3H+a6C/d8bp9C24KHy4f9IxrnSiKprfYtw3If+XfUuCXr/y/wAeEGArBW/j6LPDeCXt9ESUEpQLXCyfQ3b7l+pfjHxFFuPTjw2ysceE63IxDbWA2IYPZ46NEDTqYfuAB0s+fIFahoS2cQqIWPBXlpF+4RFWJln/tlDHhgF9th64nz1GiCzAqMdJ+cZgErRTd/Bw5OSnMq0UePjBJXV8TtgEVijMX6IlLRbw2RUgiRaVWkvr//TMj49NES8BTWsLstSGcDgl1NX7Ujz9IWCqlxDVH44SKRb5psFkZLCrBmBlPzlgHlppl+AOxBEMSVp3bR0ClIilejcw0w3h+OaWj3YjlRQS8AmNZZbws5PLVz0YTeuJF7M4gqTtWYTp9lh+X76apV8PnL/yOqOI0/ElahPqr7PbbmYlPJ2pskjl9DFK9nLOf/3vuX22gqX2eKsk0flc8YmYM0Usz8J46RFfeEoLD3eTOD+PSaam6cBB3gxzVoiKKC2OITYMjyrWsf/Qg0euW4G3uosutZqpqNZtKVGhfkz0MTpuQGHRv62f1P9GsrUaWmvCRMVBcIMJf8m7sAR4RI4tqlwVBiBYEIVkUxam/4DXc0ASGJt5zJdMCb433UgvqFVUEhieRJcW9LmiJYmT29PzRCdK9AVof+BFBrRaNQsC4WiDl1Hni//PrSKMFvHkbOPPjg8iqVayoieX0sIGqsXb+OfUBpHK4swQOdAdZfvII4ylpnPQmkJCiJ2BxEKMVuBqfTXb9RXZ7BmmJjqfiymkcHh/qoA8hNwW5y4HH72PO5md4590UdlzBMefCFRVLsWMc++/7EIUQCGALSUmVevBbndSXrqPy6zchGZ3gVHIuJWNtmKITSL5wlqAxEVvVYrzHDzO48242VOp5NGkN2e1XyN33IOd1FXxjexjVnAzznm24/TLCejWjpgBZTz3GYvc83q8/gOiwoq1vQLN5FdK0WPIvXaZ+KoMi2yjXvv1T1hZF8UhXgE8M1qPZspyQyYKyLA+/3cPxRdspGmkjJjEWc1I65Suy0G1eQchkQXilKCJRC7tr9OxT7aSotYXunLWszFeyNOqN99N9uuG6K8K7RZDJFkRuP4J8kIFLBI4KgiACvxZF8Tf/Y38qMPaa38df2bYQuD4AAqNTeC+3XndgXeBPQxRFgpNzqNfWYH/0Zfy37KRjEqbsIpLZOaTdfZimHawZaiT909vw+UdBrUS3Zz2+K60oP3/L9QbjZ7sEhM1bWPP8b2lzrSc6O4tml54Yr43FZRpqksIc/ekphsxhRuRylmd6OJSQTa28h6Pln2X9wUeYWLKC7sNX8Wui8Ha0MVNQRczoGDZBDYh4NHqSv//3LP3atzHFpDCzbQ9JHjOuq3aIE7mgTiLePIMv1khuZz3z6mhOb/80ypRkKguy2HD5EjPnRoixdTOSkE6sbY70hvMEY6Khf5hD+koWNzyGfGyc9jU7WTHbyoQvjcp7d7KvDfJ+/H0CRTnsi15KfnUMtngDqUdPEWObRVi3HG3IQ8yZ08wdukJFWgLZv/8umblRPNgkkn7qKDFfXYfnbBPS+Bg8q1fR9LMDrP5sJra0RQyd6mDVLTUEugYArmsM/hGNHO6oktOdXs1dcW+udhYYnUKWELuQRr9B+CAD1ypRFCcEQUgAjgmC0C2K4tn3ehJBEL4AfAEgIyPjA7y8GxdfRz/+3mH09+74q65M+qDwBqGzfgyzMoPgC10QXYTOJJB+8ST5Uj+u2ATOl5aze+U8XlMKo08dQ19djOzEOUJLSjDctxNBJsMfgmc6IVoWJLfxJEO6ZAw/+zUquUC7Np0vHPgJ4yfjmCXE4r5BnNoo8ryz9A7Fkhc3xvnCOlAomCyuJPXKGfQjg8TqNFy+5TNobRY6Ukso9U6iG+onyjSG9JNfxamPYt6YRNmxFxCsNvwSGb5N6yienYZ5DzNOO6LFjiwhia9cehDxsgSZeRrvrAm9P8h8Vh7+kJymnfeyrPs8TZpMDBtr2d5yiHBODE2JNZRLXBRuXM74wWM81zqNWoDoaAX7bNFUdp8gXRNCZcxCj4vWqjrKhgbR2GYYtgcxff97bJSOoyrNQxqC3JaLDGSXYpFqUbi9TKCnexQ2fW4tbY+fxrFlE8tcg6gqdhEcmSDsdL9pP51UAqVvdBS5judcM/o7t/wZR80Cf0n+LOXwgiD8H8ApiuIPX7Pt18BpURSffOX3HmDt26UKF8rh3xnPxWuEXR60m97ouLrAe0d8pax9Wf0h0u5Zj/fZoxju34X7dAOy1AR6DJkMWiPSQK4n9nNaW0Dp3t/hiYph5P5Po27vwLZ2Hb1m8IehWj5P/IUztJbX4Tp4lthoJXEnjpG/u4be833477oFx4kr5MyPMK6IxRyQ0Xnf58lPViLOmSmwDiMeOkXWqUOEdXqc0XFYsvKQz5tJjpXj8YYZVcRQMteHyxXkRM0OKuf78dbWEGpoYeCmuxAGhth19kksXtAOD0YqIg16NLFaTGElsxId8aowsf/1D5yzx/DweQfb2w5SovchiiLVKhtPlu6hXGah+IXHUS+vQL24BEtGDnt/eBp/UiJTKbmM2uAmTzvaPRvYmCMyfde3CDo9tFRvRLdzDfq2Vmo+sw79KxnX0wd6yJc78VQv4dSvT1MYL0FZmoc0M4UBC9SZ21C3tqG/eQOy1ESCEzMEhiZQr1r8nu6pv3+U0Ix5obH3I8yHUg4vCIIWkIii6Hjl583Av/6Pw14CvioIwl4iRRm2hfWt948oiriPXkQSrV8IWh8gB/thZYwLY7yKwIWrqNfWEDJZmJ+y8WNfDasz4aYi8PWNYJp1kd39ArGrq5DGx1BWl4DL3UejdYZgOBFNRwdtI5OcLdnFmqefoWyslZHYDJQx8Zwcl6FKy0F+4DQOfTwdwWi8MXFkTfURv+8hpEMj6IQAQnwcUrOJUHQML33zeyS3N1PknSTaEwQ/DEVnEOcwYddGE5D4qfWOMbJ+M87zV1k+0IrhWTeJPiv95hAaMcTs6i3oMhIoTZMjddjpuDTF9KatEHTwh+4YOkywtUzN6iELEz2zRN++mZOLllI5O0Lig8/ivecWDmcsR331GkNHWtDt3kTlL3/CSXUhN09dIPYzO8lQemj9xm/ROPz0/vjnrJhoZuyh3zP7mc9cD1pzg3NIhsdI/cpGvFe7iE6JxuCZplObTJUc7iwFKMcXp7lurilNScB9rhn1e7yn3kst6O/b+UEOkwU+ZD6oErRE4LwgCC1APXBAFMXDgiB8SRCEL71yzEFgEOgHHgS+/AG9918lrpdPI8tIRr204sO+lI8tYRH29YDNF/m9YTKy2B978RyqmnLC81ZkaYmY9p3jF1GruLXhWfwvHuPMvz3L0Fe+T3eXCakxmqnyJUzZwvTMhGgrWkbbM5dJOHOceVeY+uQKVl98kYL+q1z+xP9CYTRw5ct/zzxqzLfdhq+ggK5FdRQJFmqbjlFoFFGNjiDXKPHceQtR0gA6t4Pxm26l/NTLBJdXMzdool+XyuGCdcRNjhA91E9gzoIhN4W8WJG8+lOUzPRw+Iv/xHRJFY6RGeIFH6aSClJVQcw2PxdL13Ll0hRCXhYZDRd4NmE5TdOgk4ZZ/dQv6IvP4+Dnv8vhlGWk+83of/cIF5Mq+INhOWXxMJxfxVh6IUWn93MutZo7Zy6w/qsbKZnswPzUUaJ6ujA++K/cUiohfXEW3kVVGK5cYexUK2GXh9anLlH7uXUEp00EhiZYsq0MryjlzjKBktek/JQludd/FgQBQSlH9Pnf1f0Vg0HcJ6+gKMlZSKHfYCwoZ3wMCYzPEOgdRrN+oZvgT+FQPxQYoX4CsqPB7IGN/n7CNieh2XnUa2twNnXz07EEFittxCbqSDaN0ddrxp6SQbx5io6lG+lxqVjkHCZK9PCYWMyO83uxukNIgiEKMzWorkbs4QfPdKFLjUUSE4Xb6sHTP4qg0VDYeIr5qASyYyX8fPH95FhGSK7MIGH/S4SmzGSYh5lctpqC5bn0NI+T657kRN0dGJ97jri+Th666Vssd/WTbx5AKkhQzkyTdFMdHpWOs0f6SJgYRPaNz6IZG2HKJaF6ZwWT//0MLemVKLxeLpSvJ/3iKUo+sxHZD39B0s2ruZhZS1Y0yNwuDD/4KcqsJHyfvpcmswKpBEascE85XP71MeI8Vu64vxx/Wx/XYvIxnj9NokFC3He+AEDfbw9h2biR6kw5+/f1kdjTiuL2bVSmSHHsPYThvp14zjYhy0xBkZv+tvfMPzBKeN6OqqbsLY8JTs3hudwCgSDKRcUo8jM/uEGzwJ+FBeWMvwI8ZxvR37Gw0Pyn0D4L0SrIjYE4NXzvAtyW68N3uQv1puWELHZC/iBPnrdRe081UY88RvbFMZRVRVSa+zlt9mD2eIiRXOHvYxy0ROVw9fwwy1STnI0to8ZzjYzdtQRPXcSXmcUwRny6WJpX3kyPXU5SgpUVI3spOXsQAZGE5Wl8s+h+Vg5cJmlROlnnjjPnDpFpm2KgejVFKwqY/f2LpLntNN/+KZJ+8QvUXhcnt3yCuyfPcjRzBZq5WYbyKzEoR7kav5OMtnrWzvZTv/telM9dJDolBn3IQ+OPX+TIjgeIS9JTNNrKGmGcmg0JzPztt7my+5M8q6gm1gr3Fgfxffe/6c3LpqF8DTkhBYuSoWsWUg0w5xTxeYLc+/3b8Dd1cH7RZnIunSDeqEJVkQ+Af3iSHmkcOzLlCALs3JPPxbF8atNFHI8fQHfzBvxdQyDyjkELQJ6TjrPxyBsCV8hix9fWS3B8BllyPNqtdUjUyrc4ywIfdxYC18eMwOhURIJooSHyfTPvgV4z3FIcsRZ5uRf+eQ10PXyOhpKlVB88h/qenTz5L4co2VVL+LePkK/3orlnB7JEI/0xWQRbJrm27S529p9EpZTgtrqobTnJ2ZwVrPYPIF1cQn7TWSx+F+l3rGNqXxtb5NPM/Nc/YHOHsSakopmdwqHUYSmrJOp8G2v8x1harEc92EZvWE+K1EpbxUossigUjx7HX1TJsKij/A9/QK2WM2ZMYonWjr1qDTUz89h3bib3wmVmdm3j71/4HmOWIFq9gqX7H2FOa6RPF82l7DXMLMknTSlS5xqDsUmKGk/hMsbgyMnjMzOnaFXOE6uEi79qoCt3EUF5DNboZGzTUJ0MJg98aQn85+PjfGNLKhKJwGFDGRVDzcS655Hlp6NcFJFD6tzfTOGd266XqEsEWJUBriMXUS2rIGxz4u8fQX/zu+uvEgQBpBICw5P4+0cJz1sj542JQlmSEzGeXOCGZyFV+DHD/vgB9HduWQhcb4E/FAlEZfFQGPfG/cEw7G2HO0pBJoHnuiIagzFTowTGppnuneFa5VpsZ5oo90wwk5TFMqUZhUaO/tZNdD14hBeVxSxS2qjwT2LJyOXSxQnSzh7D5RfJ8JmZKa8mpyyR5+fjiHZa6cgo44tXHmFWpuf5dZ9id6kU7w9+Tap5gnN1t5B8/iQ6l424jFisMi0a0yzGsAuf3oDL6ce5aAm/K9rD8ksvsenUU7ji4nEHBZQqOdHyELqybGKXFNBxohutxUSCQcq1iSBx5ilCUQb61u3g6oodRGsk1MW6uXbgGgk+K2nFyaS6ZklJUnNpFMx1q0m3T2E614ZJG0NSooby+QGce3bRMitwZSxEbE8nRYYAB2IXsajhGBl3raNhTs4S7yiZ/S3IE4woHDbkt21DcbmBNpeGnbe/3gXY195PyGxFWZaH69D599zGEZw2ERgcR1GY9b7NGhf4aLFgJHkDExiaIDA2tfBU+RZ4g5FAtC0P+udh2AobcyD2NWVoL/XA0jTQySOFGXXmNuJj5HgbOpAmxCKLi6Hn8hD+gTHGP/N5VsW60HR2EvQGuBJbjG9kEsPUOLFqyEzT8my/nElbmFZNBuuFcYqunqPw5loOaksJNrYxuGw9GScO4rG5MeviGKpZTVxLExUznQxpEphJzqOm/iC5YQsSmYS4gAMNAaaiUlGNDONdVkt9Yhml+/ci9fk4W7ODwvkBsh3TlH1yHVE7VuHvGsT2uxfoN4NxtA+roMKlMRAWJJyq2UlWkhqj14LFK6DUqwkvKqd+WkLu5ZMk7qmjWxKHPyKsQf0k7J5vpLbtNKG8HBxbN5MdI5DS1oTE7eJKXCkmDzjPNrE70cnJLfexqvccivFxQhot4d2b8B46j98fwpubR87KfKJfo7AUsthxHbmAbvc6HE8dwnDvzoWm4AUWAteNjP2x/ejv3va+3VVvZDyBSNDaXQiGV5Y2vEE4MQRSAdZnQ+vMH/2y4NQwbJttRKNX4r7Ugr+jH9HtIZCTzfiki9nPfo7cnCgyzxxhIKjlYlQRGS1XuKbPYvGVwySVpPF8KId4xxzn1blkxwjEBexkZerRiX6cB8+SU5XOw8oqbn7mJ/gNUfTo09FMjDObnEVC2MVYWh7ZA20snmjl+c99lx2P/YiYGBXDfhUxw314YuNQOeyEQmGUOWkk5iUwMBskdWqQsW9+izUbs/GPTGL95V5843PMzzo5n7KE8axiCoMmkm2TjNfUEcrK4JNroxiwCLQ8cppi5xihiRkaE8uQSCXUpsKR/jBRClidKyP20gVUNSVIdFoEqQRBpURdtwRpjAFRhJ/XQ+VAI9JjJ0mdGkK/ehGCUhGRxTLG4G/tIeqzt7xhNiSGw9gfeQn97ZtxPHsM/a2b3pU55wI3PguB6wbFPzBGaGruPTdf/jXg9MML3XBzEejeaMnEnAtODkf2pekj9iGbgoOEJ6bxTJoIt3ShrChA/Ymbeem3V+jw6Ukd7cG1bi3yhmbwevFUVDDRMsKaM88ynpZPvyGNqYRM3Bo9i0evkjIzwkheBXIhTFH9SfQ+J2FRJEoWxpyaxYgyHldYysbO40gcDoYySzDp4zF6rbTsuY8dD/47Jp0RjWkGHUHm6+qIu9rIgD6FZJ8FXZQalzdEjGue2eQs1KkJ6Hq7kc/N4TNE05y/lGfXfZqwP8BXX/oBiWVptN36SW4qFvAE4eJwiKKOy9Tvb2Oyogb1qiXEaeDsMLj8Yeo8g1T0XaGgrxlFRQF4fYRszojHlUTAEwCLF7wBsDpDJA91ErusmKgVlQSHJ5BE6RER8TZ2kPjL7yKRv3EW5dx/BmVlId5L11CvrkaW9Ca53AX+KlkIXDco9kdfjqwFLKi/vw67L5Lyu7U4oln3VogiHOsNEm0zUWLux33yCv7oaPp75xGLCyn68i4eOmpG2dqGf+M6CuQOjL/8NZertxJrnSXq+FFyp/swJ2Xyu13fQJ2XzuJEkZzv/1+M8gCHtnyKq8o0yq+dZs+xP6C5awfuoxcIma0ENRqmFDEUTHThU+vQB90EfAGkUgmBrExks3NIovT4zTZ0tnmcUbEE5Er8UjnJPgsyrQK7BzwZmcTmJTNjDZDe14IiLRFz7XJmE7O42uckxmmm3NRLYNbCzBe/wJZ1qUi7egmOTTPpgBMnR4m/YyPFtRm0mWXEaaF5UqT0yd8SCkN6mp6iVXkoinOZCSnpdSmwBCJrqQlaKDTCmWODqBqbWZqrIP7TuwEIzphxPHmQsFRCaGgC43c+/wY9QV/XIKHZeQSZFIlBh7I8/882Jhb4+LEQuG5A/H0jhOYsCx5A/4P5SSunWhzs2JSO6m1qVXxBOPl4AwVYSM6KxdvUge6eHVz6h4eJUUGrIZuX01ZR2X2J8c07UM/NYmhrobinnijbHCaFnnSPCavGyDPL76LYNsxMYRllfQ3Em6a4tP0eSs4eIs06hm9qHu/GdQRPXGC4ajnrTz+NzxtA77ZjSslCmxKLVRuDdHQcdVkehvrLaIw6Jt1y4mZHUdy+DX97P+MzXkSVCllNBeHUJKItsyR2tWDrHsefl4Ppb77MRWkGOgWcGYHMsW6qPGNEp8fgmzCxWDZPekUayvJ8hrVJtDxTz3kxifLeBqrz1JQ/sJ2GkQD6H/6U2coaBo3ZBEcnUK5fjl4JSVooint1fTA4b+f4b89jMqawI92LtjCD4MgkwYlZpPGxqJdVYH/8ALK0REJz8xju3n7dkThkd+J6+Qzq9bX4rrShu2n9X2B0LPBxYiFw3YAszLZeT9jpZnr/RVrMclaW6WF8Cs2GpW+aejK74eSxISqC01hrl+J5fD+9tWuI/cl/Y07PwbJ6DR5jPPEPPoRLo0dlnUcR8BIX9iDGRpPY0oAiHMChMnCmdD2F4jyx0gAuq4ectsvU1+3BmZLG6jVpKH74C4TkRE7IsqkMzaC+epXR6FSyZgaJV4bwhiXYFFoUiIyVLCJ+boKMb97FxWeaqK4/Rux//C86Lw4TOldPvHmKcE4mMyojrrAAyQmIHh/m5AzKl2cxMThPhtxDyzR4vEHSipPxZ2Sgb7mKcvdGhn5/iJh7ttFnBvOcC/35C+xapCVnTSGHu4JkNpzD2T5A8JYdhOVK4oZ7yf78Do4MCNxcBOqAB8ezR5GolXjb+umZ8jObVcT6YgX+802oVy9BVVOOPC3x+mfteOYIIYcL7bY6PCfr0d+zPbL90ZfR7lmP88UTGO7dsVARu8AbWAhcNxjea90QDKGqLn3ng29wRJ8f98krWOa9XMxdzi21OmQSEP0BXCeu4LN7CKxZgV2uZdIBHSaYG7OweugyLcu3Ijt8Ek9OLuXn9mNJySBaLcF/9BxmTSzxUh/OsIxAYhISiYByYgypyURYoSTLNExbahmjX/wKCQYpsvEJSn/2fWyLq5FkpRPd3Y531sLZpbvIw4qisRmp10soOpqigasEjEZCGi0KhYQxp5TfLb2Pz13dy2RUEqkTg2T45pjduQvqr6EYGkGmkCLJScerUDMmjSG2MBldnIG88S6cyekcXLQTN3J6zJF1u5AIpbFBVl95mdIv70KQy6j/7UmeTVpJYqyS9NOH6SldRt3QZea3bIFzDciee4nSFAUJ22pRlOahKM5BkEhwB+CFFj/bWg6gLssl0DvMkZSlTOoS+UoNBK5cQwwE31DZKooijqcOo926Ctfh86hqSgn0jYIgoMjPxHutG3XdEmTxMW9xdxf4a2YhcN1ABCdn8Vxufc/mdx93xECQsMNF2O6K/O9wETJbCbs8OGqXct4Ty23FMOOKpMnUrzzAq31ujFcuElCqMNcuRyMJsejkC3ToMhgetLJ9ZRwj9YNcnJaiC3mJb22mOa+GumvH6cksQ/LF+7Bd6SC56yo2uR6VJEhWZzMNBctJy49neXKI/pAe3/NHiFeLyOpqEC82YQ7JMYwMolHL6dWn4ok2kj3ZQ5x5Em9CMp6EREIra7noM5I02k9C42VGa+uovHYK9eQE5sQM1B4nPlGCLiWGjHs3octN4Tn9IuLU0DQFOe1XaFGng0ZNRvdVemrXkxcLF8cior9L206RWFdGmxBPnxm8Q5NEWWax6GPRz8+SLXdTGBXEc7aRkRk/+dsXc0mRSZXGSfrOV6XDwm4vE//yG/q0KSzfWcoLsiLmPAJfKXTg3n8GZXk+ysrCN9yz1yq3+1p7CTtcIAiIPj+SaD2iP7Cgq7nAW7IQuG4Qwk43zuePR1KEfyXl76F5K/Pf+x1iIIh22yqkUfqI1bpBi8SgYzykoX4ionghESJeVzcXRV7bZYK+eQiFYbZtmIIj+4jzWZlfUsszUdWUyW3oBvrA5yPNY0Le0ERXUQ2pLhP1JXWsS/JjaR8m0NFHQKbEm5vD6q7TXEsopuHWz/BPu6PoOdKC9eljJM+Nkry8iJBSwdiFbpRj40zsuhlLYjpRRw6TPNaHVhJEYncxm5gGpUUMzoukzw4R8AaJ1wkEzFZUajldaWVk9LcS0uvJzI4mkJXJVEjNVXkK2foQGXUlaLISOferkwRv2k6aHgKnLlO5LJ1D/lRSNCEWT3dwdcxPR241BiWkGaAqUcT5+H5sE/NY7QEyeq8xv3IVXm0Ui1dmELOslLAIR/deo8AokLU8D9fJK7hPXEH9iZtoTSjmxR7I1It8NtxCaHwa7c61b2l779x/Bs3amuvl7c4XT6JaWo6gVOA+fgn9HQsGpwu8NQuB6wZADAZxPHEQ3a2bkGjfq4nDxw8xHMb53DHc55qJ+fr9SNTKiM36XduuKyocHYjoC+4ujBQMzLsjJe5GTSRYxWsiFtyt+xqpvngQt0LNsCqRKwmlrNZYqF4UB6EwHc9dIrv5AsPZZYzGZjC6eAXbo+exH7+Ma2SGKBUU5xrQ9PcxZhNxVlYRzM0mEB2DfN8hUmeHMKZEIYk2MDNmhUAQW3QCKbZJfGMz+FQaYovT8HcOMp1dhGx6llAwhEIhwZqYRkCmoHBtEfLTFxnQJpEasDKVlo8sPhqZWkF2XTEnlblsm7iMwmHDNzTB+NkOtJX5mKqqGTJmsXWRnpZ//B1zKZlE6eRc1WVTsjqP6hSBVEPkM51sGWX6n35GrMfGXGklBZXJSGOjUa1c9Lp0Xdjrp/4f/wBONwIi9poaSE8lKVGDXiUhtuEyysXFKEvz3vr+BYM4njqC4d4dr24LhbA/8hJIJBju2na9UGOBBd6MhcB1A+B49ijqVYv/Kvpc/AOj2B58DnlBFlGf3H19dhkYmsB7rQvdTRt4sl1gwgHb86BxEsYd0DkHdRkRGxKFVETV04v06X2kumaxBKXIZRJM9iAZljF8+igGK5Yjb+0geaKfztIVXFu1g9Q0PZ/0XePpmFoG919hw8A5lJtWkuiZZ/LAZbwFBbSWrCRqbpK8KydJCDjwx8ejXVnJcUMJZYeewSFVk5YZjfdCM36JjGTRSdDjx61Qg8vLZFIWgkJJ17qdVMnmqYkP4nn6IJ7EZGbGrVhuuZn1pRr8HX14167hv6eTWXPtKIVbyokuzuClRhdVHRewFxQQvNCEuq+P/oxSDCVZ+GfmUW1fw5JkSNRFPs+w18f8c6foPXqN3EQFQ6p4Uoe6Sfv9v7yhKMLb1IG/awj1xmW4TtSjKs1BGm2IpGedbkSfH9WyCiSqtxerdR27hKI453WFGgCheRuiP/BXMY4X+NNYCFwfc9ynG5DGxaAse+sn3I87YjiMv2MAz5VWgkMT6D+1B+WbWE/4Ogc4cnqSwcV1JOsgWg1ViZGZVZcJijEz8JsD+PpGUE1MoPG5cGoMmFeuxjxspkppwVdcyLhTRvyh/RjmJjm/dDfHd3+OtNlh6kKjdCcVknr0IGmTvcQl6vH1DGMPy1DL4UTtHmpdg1RMtoPNwXRKNlK/n7MFq1nXehSHKCezKo3pcx1Yk9LwlZWhuHoN1fQUfSlFeHUGiqzDFO2pQXbyHBPL1+Kbt6NvacEcVmBIiibsdJNVkYL8nj28NKah5toJxlet56wtih4z3D54jLzbVpKZosGohs6pIL4nX0Z2xw78ZxqwZOayqS4JiAQi77VeWpsnyRXsxH7pdi6cGWNFuQFBKrnuABwYn8Hx5EEElQJZohFRJLJ29Rrvq3dL2O3F9fJp9HcupAIXeP8s2Jp8jPF19EM4fEMGLVEUCQyO43ulSlJelAUSgdjvfvEtn+gf9uYi13rJ7ahnucGBUhIGoH5cJD1sY3zfaQR/iIQYNfK1VUzmliJZWUvT/zuCa+MeZpQh0kc6SW89j9Jpp7tuK2J5BbuP/4EolcB8ahYrf/sDzGElI5VL8PW0EhUUUSugZcvtbGk/S6IqhCkuFWtdOdK5Oc5XbGTbb/8Ni1tETIxnpHGYsdwq+oprqDjzEgqfi/Mb76Sir4GoZAllBcn4z1xkLD4L47GjqAhyLb6YWJcZSTiEPb+Aw/MyRs/7WWtuQnLnTlQBOXFBWKaaJzCt5NkRDZpJyI+FWLWMDZ9ez6mHjjGzYQsbm/bzUng50Q31pOfFYj7TSeauVejx0q9OJl03jGb1Epz7TmJ/7hj+pk6QStHtWYe8IAtfYzvB2Xm8V1pR5GW8Z91A97FLaDav+FOHxwILvCcWZlwfEcRgEMfewxhuMItx0R/AfbaJ0Nw88pw0VIuKERRyXEcuIC3KRZmRxP8UBhdF+HlDpHE4KxrWnXgCeUIsquoS5lpG6DnWivFqA/NyPd4li5jSJTIfkjNQtYqYgwfI1wcxukxMhTXIlHKiujs5eMsDCOEwi869hEwUEZLjKW88xaAiDmm0lsLBVtxqHd5lSymTzDNyuQ/jN+5nYsbLTP8MqoAX/WAfMdZZxHkbCrWCqZxiLmfWoJufpaLjAoJCTlflKopcE1iTM1gpmyY4PIlpzo1OJ0eTYqRNmohap2Llf9yHIJEw/ff/RW/HHDGZcWTvXsrZ5MU4ZBqyY6D47EH0t2xEUCowueHCWOSzKZVY6HzmAumzw1ii4ilRu1AW5dD5zAVaPvllysfaKNhSQfu+JlatyyLQNUDQ5sR7qp7Y734JRX4G3ovXCIxMoV69BHl6EiGzFffJK+hvf/c+byGLHc/ZJnR71n2wg2aBvzoWUoUfUzznm5GlJSHPSvmwL+UDIezy4D7dgOhyR/p3kl/1Yw/OzuO61MILWeuIUkZSfzIhUg2XFQUPt4AnCLWpsFo2jb+tF1l8DKIAR393hcypPuqXbccZEBgsqmbpqeeYd4WJH+wCjYaxFWuZLlmE2uche+/DnNh8PwGXj6+e+BkKhZSsvGhmT15lWtARMESROT+KVAIBvZ4ovxtrTDyK2ko6pkUKZ3sw+JxMS/XYVq8h9pFHMReW0lW2isSrl1E57WhCPlRaJerEaHTJ0TinrCT3tSHPz+LyjnvJritB9uhT9FqlKDRKdn5rMxKtGl/nICcazCzNU9Hc72YotYDa0avYHX6WlEXSiIqcdILTJsI2B2IohOtaH2f9CdSnVLJBP8+GlUn4ugaYaR9j9O77qZNOY957hEl5LOrJMbJvX40oQmhuHs36Whx7DyPRqlGtqHqDcaP3aheiL4B62bsrW3c8ffhtKw0XWODdshC4PoaIYsQN9uM+2xJFkdDsPN5LLYjhMJo11W9UCBdFHI+9zNnF21iispKUE4cgCATDMGaDJ9ogLMLOQqiK8WPfewjD/bsYaZ2k919/h1UdTcPamynrbeBs6hJuvbSXTm06OT3NmBPS8C9ZREfFSnJlboof/w3tG29mpHOGT595EGHTakIhEfPJRkR/kGRlEJXTRiAUxpmQiipWy1xuCdNSPUndLWQ3nCWsUGJZtJgExxzugQkGKlfQVlZHTlcDVoWeNNcsSf3teLKy0OhVyN0uYotSOb/5HlqERMqMITQHDuNevJgVMS4yc2KQJxkJO91c+/0p1HduJ1Yj0PqDpzGnZmOwm1mZHCJsd6AozEaWHB9RXY/W4znTiCc+kX84J+Or2j7ExhbCnX3Ii3OZyS6iNkNGYHgC1fIKAv2j6G7djOvl0yiLc673XonB4NsqVzheOI56edU7FlQExmfwdw+h3bjsTx02CyywELg+jngbO5DoNCiKsj/sS3nXiIEgvvY+guMziD7/9e1SYzSqmrK3tKvwXLjKqCQaq8VD0WwvirwM1KsWExbh6Q7w9o6wxtxGXJwKf/cw9pJi2q7NITl1Hl+0kbm0HPTmWVpUaSweuMJkYhaF/dd4cuNnMVfWkOyZo+LaGeL6OzlSuYPY2XHWD1/El5YGA8M4pSpErYYSgx9XbDzp3VcJh6EtKhdLVBxajZSSlvPoxkfw6Q2E3D4UQR9+uZr5pHTkWSkoHTa6Y3NIt06gnxyDFdV4AiJDpbUYolU4O4dRJcSgUktJ6Gknd8cS9LkpSPRaRI+X4Ow8s08epT+5ALlKgYBIscrNmXAKrqW13FchcH4MzJ5XPzchECD2xRfplcaRJveyemk8qsUlIJfT96uXScqKRXC78Ta0IwaCEV1LUUS7YzXSKP27v6/BIPbH9mO4Z8fbrnfZH9uP/q6tC/JNC3wgLASujxmRGch+DPfv+rAv5V0RsjvxnGlEdHtRLipClp6MRP325dKvfa3lwHkuBxNYlx7meHI1sUcOU353HS9OaPHNWtg1fgFZeQFjT51izuTBZExF09pKfIycC0Wr8bqDaIYGSLJMEhuwMxuVxNGbv8T94U6i3VYaMquJ2bsXtyBHOTdHtNSPVVSQMdHHhcVbSA870LutYHNQMNKKSxfNierdLE0Hw1A/ujPnUIQDCEnxjOqSkRijsc3aCeqjYOlikhou4HL7iXPPI/H4oLYKs6BGV5xNgSFAy+UxejbfREZeLNUX9hNVtwhBKSfsdBOcmoOwyNyEjSeFQsrXFbE2MyLb9FynyPL6g7h37+Dxdri3PFKMIYbD+K5143jmKM6CQn4Rs4bv7dbwWtlKMRhEDIZwHbtEaGwa/X07kRp07/seB+cseM40oL9t85vu93UOIDrdqGrL3/d7LLDAa1kIXB8zfG19iKEQqqqiD/tS3pbA2DTeS9cQ1CrUq5e8p6f4P2J/8iCX59VUV8UxmV+OLwSZMjfPff8Y89GJLG05zunEKvwSBVGCn8k1m4g5e5p1082cSKlhOqhkbfsJ7EkpZM8McbV4GXkT3cQbZIgyGXMhBYVtl1D4vXRGZSJIpMh9brLMo0hWLMERm4Ch/jLCvJWQTI4pKgGnXEOSxINyZhrFvBkxPg5PYgo2V4jQimpcQQm9O24l79JJcp55DCcKFITwyJSIu7fgWVaLLieFFXkKjv7flxlZt4V7U6yETl1Ed9OGiPliIIj9sZcxvXgGmzuES6Ii7e4NxOki0afDKic2RoVhYphmaQqKrBRSjXLSZocxz3u5ps3E0jHC3Kat7C6MqLb/T0IeL/Pf/Rkx3/zU69YT3y/epg4IhV8XnERRhEAQx95D6O/fdb05fIEF/lQWAtfHDPtj+yOyTh/BL4Gw0423vo3gjBlZagLqZZXvy2a9ZRpGL3YjnLmEfHMdKzfl8XJXkLyuJs5fs6CbGMM4M0pPZjkrnX0key1MxSQzPBckxTLOcHwuzcllVAomzqVXs/H444yU11LhHifh7z9Nn12GZXCGsscfxDdv5+UlN3Hn6T8Qa55C6fNiTs1ESIgjurcTMRhCUCpQBXyIGjX+sIDEbickkRK1djEupZ4WQw6l64uw/PppvInJ6O3zSNo6ac+rZjwtH4dCh3rXOkrN/eRPdlOyoZhTp0aZjklljzCIIj4azboaBJkMT2Mnw799mVldAvqVFShHx/Aboki9ehkhKZ4rG24jOUlDvNRHQ4+b1R0nidqyjGOdfto1aaDTsrLzFBW7l2BINLzlZ2z+t1+j2bT8A9UDdJ9uIGS2vrpBIiBIpSgXl7yh2XiBBf4U/uKBSxCEdOARIJFIgdhvRFH8r/9xzFpgHzD0yqbnRVH813c6940euPz9oxGfreWVH/alXEf0+fE2dxEYnkCi06CqKXvfygeiCMcGIcptJflXv+La7ntILEnn8ItdxA33Me5VEKURqLL0YxC9ZNTk0BiKp0mfy/CQlfsuP4ZbouRc7goWj1zFJcjRuu24k1KYTcpiNjEDR1hGwVgnSTPDKFRyGgqWsWPfr4m2zqIxqFAHvIhuHx5PAAkiSoUEEBH0OjwBEY1Rz1xsCsbaYgbHnTij46iujOP8hUnmDXHETw2T1XsN247tDEtjUMglLL9/JefGoM8M6dog2mMnyLlymsLKRDQ1ZSgrChB1erp+8hwOd5C4OzaRcOoEwbQULhSt5pZlBqYccPHkADXnXsIZlmLOLaYiAaYvdTKZmI1EoyJvYzlJqXo8J+vf0sNKFEVch87jb+0h9u8/9yfc+QUW+PD4MAJXMpAsimKzIAh6oAm4SRTFztccsxb4liiK76ls7kYPXPbHD6C/e9tHxmcrODmL69glNGtqkGUmv+dZoBgK4TpwFtEfgJQkjkiyyfPOEH75GHtzNpOfpiGxtYFRRRyusTkq7liGbHSMRI8ZpddNl09P/0yAflUit/Qe5XxWLVMSAyuaj2DJyCF+YoCcyX6Gl68lresqEq8XiRhGHvQjen2oHDakoSBSCYjxRpSEcSEn6PETJQ8TTEvF5/AQpZNhDUiJqilmyKOkojoF39Jqzs+p0R45irm+G29cPNKsVFLam5i6427sPshxThKTnYTH4iBX6iBBHuBYhw9/Sze6f3yAHYt1iG4PjrZ+rv3uJBm31JEg8+F8/jihL97PSSGLW4sjslUOP2QYoH0OUtsbcRsTcKZlUKR0ktHegG7HatzHL+M+00js330aacwbZ1vBaRPuoxcJO93obt+yYBmywMeWv7hyhiiKU8DUKz87BEHoAlKBzrd94V85gbFpZKkJH5mgFXa6cR+/jOG+ne9LjT7s9r5SQFBAp0NFfZuPlR1PYPP5mDd5uCd0EPOQnkFTGEFqI62uEvnQMEl97RwzVlGfUsmOh/6DxOREkqRjPFd7O7VnX0SriqZhx73ETY+wuPU0V8vrWHH5OEGZgr4lK4np70VvmSPWbcet0qAKBzDFJWEwTYPHgyspjcDuTZhKikk9+BJaKcxY/Miq8pjpn6X4czsZbR9j4h//gEyh4nLpKvI3JrC06SSBxh7Sv3ob/rEOxgbnCbo8WBPjqVqRgzJax+O9EhLa9pP4b18lpFazvw925qtpaTFRUFeI1jpDwOHC9rWv0B6OZUcWvNgNajlYvTDtjBRmsGwJS069jHFdCoJMh6PRB6EwysUlhMxWXC+fRrWsAkVBFhApxnAfvYgogjQ5DkVczELQWuCvig90jUsQhCzgLFAmiqL9NdvXAs8B48AkkdlXxzud70aecdn3HkJ/6yYE+YdfTiwGg5HZ322b36BGL4oi/s4BZMnxSGOjXrcv5PEx5VfQ329FOHIaf5wRMQx2dRRLO04zoE0h2DeENuQnlJyIdGKK8YQstFo5sd2thIbGmY5KQR4fhcsvoA75COv1mMuq0DY0MJRRQnp2DDlPPYp2bprhpFySYuS4/WGCEhmJ/Z1INQokDgdeuZpQUiKq2jLkDVcJB8M47r6N7OEu/J39EcmpSROWlAx68pYg+gNklybi6B1nqKQa36JFTLlgvaWN1CP7cRkTyQqakWelcrJkA5bYZOQ+N2v7z9O+bDO+gIju4GHW3V+LPNHIM52RKsCxY00k11+k6tu3E5o20TIr0JFSil4Jsy6w+0CvgOwYqEmBBG3kswxOzOBr60O7dRWB8RkC/aOEZsxo96xDUCrwnG8mNG1CUZKLr6kTzablBAbHQSp91w3DCyzwUeVD0yoUBEFHJDh9/bVB6xWagUxRFJ2CIGwHXgTy3+I8XwC+AJCRkfFBXd5HiuCMGWmM4SMRtACcL5xAu2Xl64KWGArhvdJGYHAcRWkunnPNCAoZmo3L8SDj4vF+pG1dxAad5EwMo5QJ2CVpmGZcpA0PMFBQitXkIt4yj2ZdDRdI4WTaTnacfQKZZZZxQYJQUUs4HOZoQhXLh+sZ3bOHmBf3kfDyC8Q7TBS2XkAR8mOOSaa3bBnZrikkvcPowgIejQ5XRgZypwtpWgaWFSsonurC0dFFwOFBkZ9J8nPP4HF4QCGLaEDmpRGKSiZ+US5TBeX8aFqPPgtqYn0I5y7xueh5nH1jCE4XWcuMsPtu2p65QPssuIFCo4Zur5bFEhPN53pYuaMYRZIRiNitHPvDZbRnzjH+d18nPWjnyNFpzOs3sjUtsh5mcsNdpVBg5A0yV7LURLxNnQTnLMjTEnEfPo8sK/W6jqOmbgkhm4NAzzD6+3fha+pEDIXQvCKcu8ACf018IDMuQRDkwH7giCiKP34Xxw8D1aIomt7uuBt1xuV4+jDa3eve0S7iL4H7xBWkScbrfkthrw/P2SZCJguqpRWvkwWyjZloe/oS3uQUygLTGArTCHQPE/YHmNPFMX2pC+nwKKNFi3DNuyltPIHPEE3AF0TpcuCXyBhJKcAoeDHlFDHukZFgGqfANYFhdhLB5UEuBlH53HiVaiQpiXRHZ5E0MYDGaSMkU6D0OJHGGpDoNARNNqSIzEQnES0PI1VI0VrNKGJ1hCfnEDQqpMYYAgg4MnPoKqxlOLuUeJ2UiRErcrMZ/7QJrcvO7WUCk50ThExW2r7x90hVKqYccHZEZEfbQTJ2LyMh08iMJYDw379n6Z21pG+suv6ZOV88QX/DCHvXf5asGNDsO8Dyv91FdpyMR1oia1pfWhJJE74Voj+A46lD6O/bRchkRRqjf9MGX19bH8GpObQL4rYL3CB8GMUZAvAwMC+K4tff4pgkYEYURVEQhFrgWSIzsLd98xsxcIVMFrz1bWi3r/7QriHs8uA524ivZxjR4UJRnENo3kZgaAIxEES1tALdtlXX1S9cfjg1HJFiqksPE/jHHxCYmI3MGkRwJqYwOmzFFZLQXLCC3Ll+cjsbENKS6FMkMO+TMpGaR7FzjNSRLoZis7Coo0l2zxEV8mCcGGIwvYhS9zjhGTN+UYLTEIvo8ZJgmwGJgNLvRR7044iJQyqTorHME9aqcan0hAUJcp8HjdcJgkBYo2Vo7WbsSgPzaTmoVFLiZsaYG7cSdnux+0R8CjU+rR65To3o8eLww3x6DtnRIsqlVXhDMGqFyiT4yqIgzicPotm1jiNNNvLPHma8rBqvRkep1Ep4eIJhl4yz6TUsXRxP0uH9lN+3BqfGwEPNEePL+ypA9i6WM33t/YSd7rdM//l7h/H3DKPbtfaDGg4LLPCh82EErlXAOaANCL+y+TtABoAoir8SBOGrwANAEPAAfyuK4sV3OveNGLgczx5Fu2PNu1ab+CAJzlnwnG0EQUCWmoCvpQdZUhxhtxd5ehLKRcVINCqCM2a8De2EHS6kmSm8qChmR4GA7NQ57E8cQhKtQxIbhXpZJcfkecz89kWUUgjExZFXnIBj3ol8eJR6dRadxUsxih62PvFjXFFGeqvXkjDWjyCV4MjMwT3vYNWZ54m3TOGXKrBpo3EYjGRM9qHxOgkplHhVKkS5CueWjQTMdqKuNXN+z6fpTy1m6+AZas1duK71YveGcC5bgVopwaiTELcoj1BiAg/bMzjqjANBQCODjChYkRxknaeP8fp+jshyCZWVUHzuIIqbN4Nchj8IZ0dheVok4MiDfgwv7ydlrJ/ErdX4m7tQ3baV9kAU8rk5XA4feVuqMJw8xaWYQnQ5KZwZgbJ42JH/xtTg22F/8iC63euup25FUSQwMIbvWjcSjQrNtrqPZN/fAgu8XxYakD/ChCx2POeb/+JPy4HBcTyXWyJddlIJgYExCIdR1S1BXVP2hoKMPyKKIqf295Bdfwat2448OxWUCuSJRpoKVnDov09R1n2FnJJEBnbcSig2hvHnzlJz6QAXNt3NcNFiFOcvc//Bn3O5ZguJGpGo2SnmSsuRujyoenso6G1C5o9U0fkUShSIyH0eglIZg+XLSDZPYHBbCO7aimp4mFD/CJI7dzLXO41SDOLvHUYzN0M4IY7UJ79P18vNJBUl48vL45EWuDIJRhWsTAdnAJTTU+xwduD0BLHmFDJgzOLeSgGjx8qZF9qwrqxDr4TLE3BrEZQmvNIr1dyD/WovyZ/YhqCQ42vtRQwGkcbFYDnXwsUlW9liakbQqlFUFvPr5ojp5fL0N/1o35aQ3Yn76EXUK6rw1rcT9vpQ5KShrCp6Xw3gCyzwUWchcH2EcTx//A1FEH9OgnPzOJ48SNjnR5ZgJDRvQxpjQL1qMYq8Ny98CXt9BHqG8feNMGMPY1JGUb0hH4lOw/yvn6HVIqd32IXSZiErTsai7RU8OqwhqqkRhc+LTK2gsWoDEz4FaV1XWdlxEnNJBfESH1OCFos2ltzuRrQWE9EOMwaXFacuGrlGieBwEkKgL6cSrU5O6mAXSq0C446V2FoHscw4mC2uxGWMJycqjHHfPuQ6FaqacoS/e4DO/c2MyWN5ikKGbRCnjhRN6MI+nBeuoTCbyCxNRr+0lJBCyUs9UJYAuTGwpPk4ms0r2TemRi4Bux/uKRPxt/bia+2FskJaY/OpThVQeFx4LrXgfPk0uls2ciRuMevFUeQzM4RXL2d/H1SnRKoM3y++lh5CVgeq6tK/2HhZYIEPi4XA9RElZHfiOdXwrk33QlYHgeEJgiNTSKJ0qFdUvePTtiiKhKbm8F3rwX2+GcIi6s3LCc9ZItbtqxa/rQqG51ILjmeOoCjIIpSTyUWzmvUpfubGrYw8cwabzU8wM4NgbCxpc0PYYxMZNYdxGGIZLaumovMCdpMLlComFdGsv/QCkthoBkprGZIbyRloIWluFKN1hiT7DIIoYjEm4VVqsBni0KqkWPIKKJnrI1orRVFRQHtMLtPPncEYduNds4q0kiTiH34Yf9cw4qZVBOITmdiyg6unB+j1qGhLKSNFD58oh55uM+rmq7gDIrKlVWxbGc+gFSbs4PTDA1VBggdOMhtUMNxrovr/dysqmcB3TojcLfSSM92LprKA9rgCRsacLHEOMdgyjsGopWp7BX6nh86mCXzZ2VQMNjK6dgvtcwK7CkCzMDFaYIF3zULg+gghiiKeM40oSnLxXriKZtPyt7T7CJkseC5ei6hOAJIoPfLs1Ig7rcmC58JVBLUKzdoaJHotYYeLwMgUwYkZwg7X9fcTrQ5EEXR71hGat+FrbEe7Y80berBee43BaROOx17G19aHZv1SBJWCywMBDH4nE7M+AkPjZJmH0RemMztqxZSQxnhyHsaeNvqT8khThYgZ6UPq8pAwM0I4HCbWOQ8SAY9Sh4QwklAIJUGkgQCCVMBuMHKpsA6tAtI0QaJGhzHVraFivA2JGMY2Moc9IOBMTCHj7vX8xF/GytPPUdR8lpm4NJ6+/Vvo7POYUGMwT5OkA1NpFUZliITRfopM/QRiYmnPXkRxmgqtMiJBZZAFmZnxsDHBifv5YygKs/FcbiXg8zNqCuEKChQVx+FNT6fLImXi/9/enQdJct2Hnf++zDqy7urqo/o+p4/pOXpOzAwwGAADgBiCAEhQvCWalLyWFZIdWq+9smnHxu467LUcG+GV9oi1tRR3ZYYkkgIFkCCJawbnnJi75+r7vq+qrruyKvPtH9UEB8CAADkgBt14n4iO7sx6Vf1+PYn44b18+XuLOTojsKsrgKezGRpqODoiODoKHRF4+PLPCGtFjt/1KFUhB/vrfyOXkqJsaCpxfUzY6SypH76EcfcO8lcGyJ25QuRf/sO3le6RUmJeGSR/dRC9PIznnp3vmdigNApLv3iSzPPHweXE2VyLFgmhuV2lBprA3dOFq6OJ9E9ew1FTiftAD4Vrw9jJNLJQLCXGYhGZLyBNE2nZmANjONsasRCMxeBq3MlcWhIdvsHmvrN4XTDvq2AWHylPkPrEDBXLsyxEmzAySVypBKHMKhaCOX8lTfEpdJeDnOHDsmycxQK6y0FS9yA3tTAS3cTFyCa+evy7kMuzWtNA8sAB2s68Ssq0yeUl5x/6PKl7DpKZWKD+5eepmR7CsExefuAruPf10DxxA2N8jOrVOax8gWl8NGYXKav0kd+ymbNmBH05RocnS00A6gIwlYCi0Nna4sXsG8XZXItrWzvZV84S+t3PcXYaBhaKVMxN4LYL5BoauL/by1gcrpd2JEHXYEslbCqDc7MwPp2mKHQe6jao9P2GLypF2aBU4noPUkqK47PkLl4Hy373Mi9N4D307h173/U5tk3udC/OtgYc0fJbtinOLpJ+8SSB33oYze8l9czLuPdvJ//mldIWJvu2Y14ZxFpZxb21HdfWTe+7SsyKJcgcO03uzSt4H9xf2pQwly8tuFiLDylL90bmlwn/06+hR0Kknj6Ge3sHemUZOJ0Ip6NURuhkL2l/kMWRJY5W7iD69z/kVM+DZBwG9xz/Ea0Lg/isPAlfGd7ECktlNbgcGiKbpWjZxL1h2iev48unAUlBc2BpDlyWSdHlpmh4SFdGMQsSp9uJJ5tGODW0ZAotnyOYXiXt9pGKVOKjiObQSGpupjp6OP6Zb9By7Syb3zxKmZ0lGSpn+9YK+jr3sHjiGmJ+gWJ5BeVBnfRcHEdHM/sPt2JLyJ+8xPjYKkZbHbu/cR96wMdsqpR4oj7YFoXchevITA7PwV3kzl9DeD2sNrZyeqp0T2wpA5YstX8/BauUzDS1yE9Rfm0qcb1DcW6J3JtXSku+G2tw79p8ywd/pVkg9dPX0CNhPId23zKRFMZnyLx6Fs+BHgqj01ixBN779rxt/6P85X7MoQn8nzuM0HXsdJb088fxHt5H7tw1ClNzFEencbbU4e7pxNXViuY13t0fKbHjSYoTs5h9o6Xj1ST+zz14y4Rpp7Okf/oajvpqnJtbSX7vOQp9owR/70nc2zsQQmCOTDL1kzc5vezmzUyQbVdOkPUGaEpMMbbrILV9l4mOXMdwCjAMrlZ1kBEuTm85TPPkDSxbcveN16hcnka3LTRdQ3cIVh0BZkNV6JZNuRfi3dsYOPJbZJ89Rvf0NdzSwiMsjOlJCkUbpy44/sAXqSdNbbkTx9wc2uWraKaJKFr4rRyOtgbCd23mlWyUNA6i4/2UV/nRD+5hdTkDi8ucd9XTdaCFHVO9hM0kmfomXgt28lC7g8j0KLlz1/E/+eDb/r6FsRnyl/vxPXF/abR7uR/P1x7je9cEX936wZ61UhTlw6US103MgTGKM4ullVm/ZArune/Jne4tbXm+Nvqyc3kyzx1HBHx4D9/1VmFcaRbIvH6e4twSelWE3MlLOFvrCXzxETSPG9sssPIf/p9SMqmPYuzuRq8oFUOVlkVhdBrz+jAyl0e4XMhCAWs1ieYxwOVELwvhbKxGForkL/UT+MLDiJ9PC66RZoHkM8fIne7F3dOJ5jXIXexD8xk425vI9w6QPNlLIp4jqRmk3V6C8UVCuSQObMhksRHYhoGWz5P3Bxmp7eR6eSs7xy5g+QOIYpHIwiRlyRVybi8JfxhfeZCix+Bk7W5CU6PULoxTZSfwFfJYHg9iYQnNthF+D3mHm3wBEpXVVNUEmArWEHMHOd2wiwPHf0QkvcJSUzvz3TvZoy9SrKlm5dIw19p2QVsLj914gfnObRjjYzS782S2bWPY8nNo9BSm28Pc5h08txyiYEN3JXgcUB+Aei2N4/ljeA/uwtlajxVLkP7Ja7j3bCF/7hqube24ezp5dkCwv/4XdQMVRfloqcT1azCHJ8i9efWt2oHC68bsG8PZVIujpY78ueu4NrdgJ9LIbO7tb7ZsspduIFfTODub8d63B7N/jPzF65h9o/iOHCT0j7/0vlOBqR+9jDQL6NFy7GSmNA0IFEansFNZfI/ei7u7Dc3vJXdlkOTf/hQ7k8OKJfAc2oPvvj3YyTT53gGMg7uZGosx+ewp4gtpXul5mJCh0Xz6FRoHewlm4rhyWbKhMjwUQUrsWILFQCXStvEUc5TnE1gOF8VCEVfBRLh0UoEIaacHt9/DWKQeS+i4E3HK5yapsDK43RqXW3fTPNWHWVeH57EHOD1poS+uYLa1cPD886Tcfk4XKnCZOXZoy+TcXi6GN+FzSKzOTVyt6sLpEOyKSqyXT3Do2jFCd/cQqQvjvm8vL45p1J0/RXuFwPvQARalwdERONQEDWu3D3NFmE7C5CqsZCWBkyfR83mCk2PoVWU4Nm/C2LsVn0swlSjNGu+t/dAvK0VRPiCVuH4FZv8ouTev4mytx9i/HaHrpSm5ZBprYaVU9HRyDndPJ86Gahz10beN3PI3RsifvYrv0XvRK8rIXxsi/p+/j3t7J9I0CXzxCPbKKrmzV/A/+dAtR31SStLPvoqzreGteoE/P5956SSaz4txoIfi2Azp186S+uFLFBIZbJ+ffMEib/gQ+TxCSnJSY9UdJIuOsRoj6/ETSS3jT6+StnV0p46RTmJ5vcTbNmEsLtPfvovrWiUPn38Wr5Un4NFxrcYwCxJsC1c+i4Yk4w+TMfz4HUXMUARtNcGy8FC1MAEeg2S0lpG6Ltpj41yt38bKpk42H3+OtuQMISuLtZpmKtrMYjCKpypIT4uPgakMWUvgum8fy1t6qA4IHBpUeuHoC0P0PP8UhS2bCUqT4JEDDJwZo8efoebxAyTdAY5PlO4vPdz67ik+WSxiDk5g9o2AWcTK5DCjUayD+8haGpkCZAul9++q+U1dYYqifBAqcb2Pn2/Tkb9wA2dHE8Zd295zNGQlUuQv9qH5PGh+L8LvoTA2Q2F4EvPaEFo4iKuj6a2FHnpZED1aTvy//B2h33sSY0cXUNrrKvX0UYy7d75VtNZOZSguxUj98CX0SAgt8It5Ks3vJX9lAMfurcQtJ7E3b8Crp9ASCbItLYiyEJ69WwlvbiCkWcxdn2J4OE7G1nAsrxCZGmW2qYvAyCAjrkoi8Tmql2cIZOIs1DQx0biZ2UAU/9I8rmyKbRNXmK5rwyjmqZkfYybaRHhlgcrYPCsVdUzcdZCay2/i9rpIGEGc09P4MwkMM8OxXY+SbWgh6zY43PsSWanhzyWJrs5TbGwg09HBvCdC392PcHeXjxeG4X/oXCT1ylmujGdp2VROdcTJ2VmN4VWNRzo0YucHMGcW2fEf/xGa38dCymb2xfN07Gtm1RRcurSIf2WRLn8Ow8G7F9qsXdOu9iZcm1tVtQlF+ZhTiesW7Gwe88oAhbFp7GweO5HG95lDOFvqbpm0pJTkTl6iOD2P59AerHiS7OvnsOaX0Wur0CvCeO/bix4OvO19uTevUFxYwfvovWRfPoO1sIJeV4XR04UwXCS+8zTFhRVcXS1oQR9m7wDufdsxdm2mMLXA+OlBFhMW7nMXKJSV4ZmYwO0AA4vg4b04/Qbu7jZcO7rIHjvN3Gyai2kfo7UdpEIVNI9exfXGSSL9V5ESVt1+AnaOaHIRo5gnEYhwo3E7IpslYGWZc4RomR+ivJjCU8iyqvvIePxULM2Q8AZ5/vBv48znqZodJVbfQji3Svu510n4wgxXt7FSWcdDZ39CeWIRpwOWK+qRXgNfZxODD3yaXVNXmKhpo+1QFxrwt6/H+OP8GUYJEs/Y7O3049jexQ+u2OSX47QNXCLZN4Hl8VC+pREKBRACpwZBp2Qhp+GrDrN9exXe+ooPfN9SUZSPN5W4bmIOjJG/1Icw3Li2bEJ4DTIvnsT/xAOlhRvjMwA42xpwbdmEZrixlmKkn3sD9+4taAEfuTO9oGt47tmFo+rdNXyklJhXh8hd6kPzexBOJzKTBU1DWjbW7ALF+RWwbfyP349WFiJ34RoyX8Bz/17yE/MMXltgtrqZpi011Jx4DSMaBlvibK0nd+F6qT5gTSVy93b6n79E6umXGYi2syANamaHaR3sxZtcAVtSdLgZa99OKL5I49wItpTMROp5s/MeymPzVCUXmKttZd5dxoHBk1TYKRYoJYCozFDu15jzV/J6836Wow00Lo5jdbTh/vHPqJkdI+P107Y4QsYXIhUoI9pdh6shyt9V7GfX6Dnu/XQHKzMJroxlOd51H0GfTnRxEtE3RDDiYXTLXvYOnMJDET0WIzm+QFTL0Vht4NnZhaOzBTMYIlAZeGv1Z7ZQ2oSxPlia2lMUZWNRiesmUsq3RlSFkSmyZ3oJfPFTb9vjSNo2hZEp8lcGyV/qw87mcLXUI3wenC11GHu33nr5fN4ke/wi5sAo0iyiR8txdbXgam96+4aMa32QZoHs8QuYw1OY14YoOhzMJSSJhiaaGoOU5ZPYqTR6OIj3of2YI9Ms/vQEVxu3M5g1cMfjBGcmSBh+ksJgz7mXcFt5CoaXdKiMXFFD87gp5kzqpocwzAxDzdvpb+mh4+opnG4HZlMTjqVlmkauEHJJrFweO5MDlwtqqlitbeRqZRffbf804dUlnjj3NO50korlafK+ICGXjSk1nj7yjzgweoaDZWlGUjpLje3cN3+RISPK3IVRZCRMlauAodnM53RW/RFWQxWUixxdl99ABoJkO9pJdnTRenc73e23ruqhKMong0pct5C/3E9hbAbfE/ffcmqwGEsQ/9//Gj0cwFEfLa3sW1ktrezTtFKi8xqlihOZHHYmR3Fqbu0eVzOeAz3vWQjV7B/F7BtDFgoUUjkWTt9gsbwOZyJOQ0sEr9+FeW2IWKiS60YtA+07aXv5Z2iWxckHvsDuC0dxttQhz1zEMTOHP7GCO5vC0h0UnG5ybi/paA2rmoE3vkI0u4SvkOPV3Y/SdOM8TbPDpDq7iDsDlE2NEEmv4LQLOJNJBLDS0k7iK1/m1bo9XF1xUHH5PLtH3+TA2Bm0TJaCP0Cqopq0y0ti335ObDtMT+/r9Fx8hYTtpGL/ZmqdeU4bzRjj49y92Uvw0E5mw7WcmNbZ51yi//gg7dYy4fEhQt98Eldn823/myqKsnGoxHWTo6Mw/sJFmjwm7Z/dR30Q7KlZsm9cwFpcoTi/jMzkKIxP47lvL46mGsQ7/xxSIosW2DbC60G4naVt4Le042x+7zXUxcUYmRdOIDY1MVjdweRkkvCpk1Tf1U50dozlHbuZeuYEiYEpFurbKHcU2DzeizE1Rd4f5M3qrQSv9uKLLeLPpinqOvFwFeFUHOF2koxUkXZ7iWEwF6rG7xZs0pL4rlzGzmSJrC6R8Iexa2vQY3G8qTiaJnAUTNKGn4nKZp7f93mu122hbm6UlpkBDl18jtbFUQAWK2qZ7bmLWW8FldkYlWEX151V7Jy7TkUuxsof/QF799dx/tgA5qtnCI4NU++XuLtbmVy2sOIJNMAMBOnYFkWzbfyffeB9K5MoivLJoxLXTQafOsHAWIqR8hasqTlqB3rRkQzVd5FrasaojrDt8utYTxwhHPHgcVJaDOEA92oM/Y0zOIWNK+zHf88OHJFgqXTU5ByO6opbrlaTZoH486cYi9kM99xDVjipj01SfuMKCU+Q/oST8UANPSd+Rr61Da2hBv3KdXwXLpDwlxHzljGn+Xno7LM47CJOu0AsUk3eF6ByZQbT4SYnnCyWVRNy2pRnVvDOz+KMx9AtC1vTiAcrmG/pRIZCMLeA23BQkVomqzn57oGvc7p1P5ouKBd57p29yLajf0/b2FVcdoFL9z9OxbZWfDOTLMynadocJbH/bn74xjI7xRLOVILCl5/E7XPTf22Bvceewjs/h29THbFHHkFaNmVVfkbdlRxqFjQEf1GO6ucPbiuKotxMJa6b3Hh9gOWfnsRl5nDLAskvPMnRQj2zaQgmlmg6f5Ib9x5Bc7vQBAjAmUnRdvkUtuFmesc+bKcbLZ2i4dJpaoavYft8LG/ZgTu+gi0F05t7MCuqcGrgGejHOzBA37YDZMIVGLJAw/gNPIvz5FZS2JpGmV5gU995VusaIZUhEF+i0NLEyKFHOBYLEx24wpePfRukJJxaYbmmmazbS0EK6uZGcNhFdK+HXFWUpaxAFovUzI6iOXXGNm3jVOd9VCQXaIhPE50eJuyUFLMml2q20N/cQ7iQpHF1hlpnHmkWcPf340+vMlO7iYmvf5OWpTGWdB9nK7rJhCOYozPULk6w9WAr0dE+DjzWTW45Re9Tp6m6fA68Bm3/+ht4dm2maMPRkdICisPNaiGFoigfjEpcN3n9WpobPzmPkU2z0tBKaGyIPjNIS08t9aN9TDzwCIsFBwsJm0IiQ/O1syBtbnQfIOPyIk2TmpkR6udHsb1eFru2U56PUzfRD5pATyQJTYwhUynmwjX0t+1Eczpomh0kmE0iTRMjuUoovUKxvh7DNglPjDK04x4WHAE8yRhDlZvIxdKEY/M8duFH+DIJTLeHgsfLjx/7xxTzRR49+ldUxObpbdnJ//mZPyEZjHDP9Zf59Lkf0zbVx0IwytXWnYw2dmOhEc0usWXoHIaZZc4fZT5SS8hMISwLI+KnprmcyPANCn0jOHTBzJZd7D/UzMlgB88525lMadQFoLsCfmuzZODV69T+8Ad4l5dI5W0SWYm7rpLgtlY2/dPPo0dCTKzCGxPwYAtU+z/Ei0BRlA1PJa6bTH7nJ2TPXMEV8mOZRUZXbPLxFBXpFbKhMEV0UqEIVjiMy1NaWefNZ/BMT+IdHMAVjyMdDvJuDzlLkLU0slInrbtZqKhnomULuUiEiJ2jaWYIT3yJ2Ug9ObeH5oHL+FcXsR1O8m4vC54IKeHgmXt/h6+88h0aFsexiha+TIpgPkEwFSceKsdpW6x4I8xV1lO+PEPdyhQpfxn9O++ldX6I6NI0iaxF1tbw5tKcb9/P+T0PI3xeWmYG6R44Q+30MKPVm+jtOkA6UMaS7iNTV89D5hBNl04RGBxgvqyGltlBhrr2cmLPES55m9g5cJpGc4WacidNjgye8THovYG/kMaqrKBw7wGGm7bwuc9uwuN3lwr3WvDiMPhccH/Tu58FVhRFeT8qcd1k+KnjLJ++TqahkaLUODOn0dkSYKzoxeHQ0GyLypVZgkvz6GYO4glEMknR4WJ60xaGt+wnFqlGF+AQkqAwqVmepHrgKs5CHtf0DFnTIu4vZ7qhg8qZUSITQ6QdHnION7bTxUq4CsPMsWvgFImdu4ieO4WdybIQqsGKhDGwcS0v40onCCZXmK9uJldeTtP8KEZqFduS4HJSdHuYdQSwsyYCiUezSdY1MVPRgJ3KkHZ4aJ7qZ7msmu898vskHD5qlia4//oxNk9dQ7MtdGkjnQ5wOtEKRX760NdZ7OxBZDM82HeM4L4tyIlZqsf6ENOz6GOTsGsr9v33UNi5ndPzDj7TDhXe0g6/C2k4NwOfaiudUxRF+XX8qonL8f5N1i+tYOJYjeObnGHcW423bSsTcUmPbxFnIoPIZLCSGexcloIUFFqbyW/ZAt0d1Ht0ulxgFHI4+4cwx6ZYzcGgu5pBPULD4CWubrmX5VAlPedf4e6f/hWW14e5aRONq/O4Kp1ksEhNX8VKpIl5y+h65nukPX4ubX8AkHSO9uIxU/iyGUyfn+uf/gLl8XmqBvvJp3LEQjWc2PMofcFGaicHaVsYYmb3FjqunyGQinHFU89IuJOgkWTLyAX+7t7fZqqymQcvvED3Uj+aQyNn+Dlx8Em6e4+T8AbJ6y56ew4Rb24nb0HNxBifWr1CMOql6ukf4N3SyuLsAjFfOfI///fUVrip9sFzQ/BER2n7r6kEpAvgd8JXt6pRlqIoH60NPeI69r88g/7sC0gE2VwRfyGD5fawWllDzhsg7wtgOV248lmMXIqsP8x8TTM5t5eKqREqJobxJmJcb9rKdHkTjbND9AyfBQQakvDqIpplIYCCw4kEjHyGvMtDxhtAQ6LppcK2NjDQ0oNTFmgfvVJ6nkzXcRRNioaHXGUl+vwiNgKhaYxt3sWoswJLaIRTK5geH1O+KtonrnG5Yx+Xeu6nc/gi5YvT1K1MsBKuZsfgGbz5DCON3bg0iXC7CAoTd0UQPVqB3tqA3LcLv6GRNyU3vv8GrclpKufGyYcjSKcT19wsjm9+kV2H25ESppIwvAI7ayD07uewFUVRbpuaKrzJqf/jZ2SffRVWkxQtCwMbI5PClU0jkBR1BwXDS84bQOoavtUVtEIBywZb0xBSknd7CaVXcBWLSCGxDA+Wy03e62cp2kB4eoy820cotYJTSAoeH9I0ccdiSNvGFhpFrw8RjRCMLyOyOXKhMGlbxx1bQegaRbeBns1R9Hmx8kUuN2xnuqIRy+fHbWaRUrBsBInEF0hG6whoFu0Lg+iRABULUxSFg6Ao4Otpp/reLYifvYzMmwhdQ7hdOOuqkK1NxBZTDC1YxJIFgvPTdCwO4w0ZOFsbEEi8R+6lsLWboKGGUIqifHTUVOFN2pNTpOITzKcFXoeNcLvR68oIhmooFooUZ5ax0xmsRAYrb2I5nAhdx2XlcBpuRDhEviixdB9FNDK+ACkjSM5l4MqmqevvxSGLsLpCLFTBVF07zlyGurlRnIaB060jikU88WW0lTlsIbA1BznpxGOZ2F4fKeEkpRlc67yLJV8lsxUNHBw+QdXKLKuFcgoeLzFvGR1Lw0TLnBjLvdgSNMNF08XX8VWXEdrThXA6cFR6sM5fQtvShnA4yDtc5KqiLL50msW5cYxEnE5HDiObQtpFjCfuRSAxdmzGtbZL8q3rfyiKonx8fCiJSwhxBPhzQAe+LaX803e87gb+K7AbWAa+LKUc+zB+9y+z4q9grH4L044QZS6L6mIS38oisVie1VAFK+1N+OZmMdIJzGCIvCWQmQzDjVspX5omujiJBC53HWaioYugmWTr0DnahntxFvLYEpbLqohX1lI/M8SeC0dBCCy3gak7KSZySCmZbdzMUMduPB6drvFeguNjFCydVcvJRGUzcV+EmvgcW6eukxr1kvf6qLFzeJYGyTg9hFMxPLZJ0efDUR7C43Nhuf3M/em/p5jMUPbssxSjUUxfAN12IlccpPfvIzg9TtnfvYh7Vzd7miO4OprIHDuNnUijV4Rxdbfh7ul8300uFUVRPk5ue6pQCKEDA8DDwBRwFviqlPL6TW3+ENgupfwDIcRXgCellF9+v8++3anCp77fx8UfXaQ1NkZlepmUy8eivxJ/Kkbz7BDeXIKsJ0DO8OGyCvhzSYTTgdMqkmpuJb3vLspcNpUnXkdMzyLzJgXDR6KuHvfSIp5CHpFMIleTSF0nXl2HnsqgFQsUNQeW0EAIDMvEchtkHG5GypoYiTRxpWkHS/4KGpcm2Dt2lmIoRCZSgS5tKpamcWfTWLqDLUPnkE31aPt24PIZuA0nbq8LQ7MR5y4hZ+bxPX4/jvIwwuUAXceOJbCTaYozi5T9d/8A4XKS/N5z5M5dw3NwF94H7lKllxRF+dj4yO9xCSEOAP+TlPKRteNvAUgp/8NNbV5Ya3NKCOEA5oBK+T6//HYT17/8nR9wb+9LeD0OcDrwiiIeLNxOgbetDn9HA1p5GHs5jq1paI216IYTmUiTefEE5rVhLNMily+SdRpIIXDYNrlQiJzTgyOfo79miAAAD4hJREFUw7Ykx3c/gm9hjrqlCUy3gT+TYLGyHpdDQ7MsRgK1xIWHH+94HOFwELEzbJm6ys7xC7ixGI+2YRg6rUELf8DNwpYdlA3303P2KI3//KsITcPO5pDpLHYmi3l9hHxvP1o4SMWf/jM0l5PC2AzZ189h3LUNvbqc1DMv47lnJ5lXz2JeG8Zz+C78j9+6yLCiKMqddCfucdUBkzcdTwH73quNlLIohFgFyoGld36YEOL3gd8HaGxsvK2OrbR28iI6GV8Id3MdbTvrqGsI4tJBFApos3N4Tp7G2zuKnkxRsC5hFiwKNhiZNKu+OrKRACIYwCuKuJKrZISL4NwU8w4faX8lY01b2LQ0gW24KBhe3th1hInu3Ww//zLD4Ub6arpAQvfCIF+cO8GcM8yesy+wUNXI8QOPUdNeTWfUQcRb2rg3ls6x82/+kuamAOV/8S2Err8Vj53Oknr2VbSKMnyP34+xbzvJ7z+PvRTD2d1K4OuPU5xbYuXf/Rf0uioSf/scnn3bCH3js2oXYEVRNowPY8T1BeCIlPK/WTv+OrBPSvlPbmpzda3N1Nrx8FqbdyWum93uiGvp1Yt8989O4Dcz+FJx/NlVfMIiVMzgLOQp6E6S/jAZfwgdcFsmgdgium1R1DQ000RYNslghEub9nKqcS+bpm6w0tBCt7nA6IHD1Bx7Di2TpX5lkpEHHsUxMUF9fy9DtZ1MB6qZDUapis/jySZpnu6nKhNj9k/+BZua/ET94NbBlrCUAS700n30Gar+8EsYO7reFkvu4g3y14awUxmE04keCWItx8n3DuKoimAtxXA01mBeG8bd04mxazPG/u1v23tMURTl4+hOjLimgYabjuvXzt2qzdTaVGGI0iKN3yiPLLC/0uTyrI5eW4U/rZOIZYkZIXIuDy5d4MPEL01SgTLysQJL9Z2shKOMNHaz6I2QsBy4c2ke6H+Nf/bSn5GPRknNXiSJi9bXn2euuRNTc/Dc7ifwjy9RP73Acx2HcRfyVKaW6VgcYrxzJ0FPiAfu8dK5tRJ7sRdf191I28ZajFGYnifxX5/Fml/C98QDFIYnMYcmSluuDE+WpgW9HuxkBtfWNtybW0GCI1pB6D99rrSCcHCc1N/8jODXHsVzeJ+aElQUZcP6MEZcDkqLMx6klKDOAl+TUl67qc0fAdtuWpzxeSnll97vs293xPXq5QQXT0+Rffk0/lwKMxCkwZVnXAtjpJME0zGybi9zVY14CnkmNm1nvHYTyYKGpkHIJelITtJ1/AVC81PM1rQQnR3HdjgIpmKs+sJ4zAyz5Q0kpU5W83Bl1/3U1/mpr/USCLjQkLT/7Ck228sEDu/Fs7+H/LUhCgPj4ChNA2beOI/m92L0dL5VhsJaXqUwNo2ruw1XVwuFoQn8Tz6IcLuQ6SzSttGDfuxcnvRzb6BHwngO7VYJS1GUdecjH3Gt3bP6J8ALlJbDf0dKeU0I8W+Bc1LKHwN/CXxXCDEErABfud3f+0FU/uhpDj33BtOBauZ1P96VLGkB7f4kc81dnKx5iFRFNcsYFGyo8ECjnicyOkjzxZOUL09TvjxH2hMg4fbjSa4ysmUvhqFzbvc95PrHcY+NETVjtIYF2w538qmhS4y5N5MKhKnKx9l2/SR1X9iNu7uNwuQcqWeOIQw3zu420s++SnFsGr2mguA/eAJnawP2cpz00dM4W+sx9m3D7BtFZnIEvvLpt5KS8JcKA+avDZG/cAPfZw6hR0IfxZ9UURTljtvQlTNOvDHJ3KkblDltXpj34dvWjq85ykwKynJxjLl5rOk5srEMFZklKiZH8KdiGA6B7XJimRYIQaarE/v+gwi/j8uLGn0pF6JY5Eh+gCP7ylmdWGAg6SCbLlK8axcd41epO/kKDpeOe992dN8vHuu1Uxly569jx5M46qsQDgfu7R1Iy8ZajiMLRTSfB3QNY1sHrp4OsGzseBIrlsCOJUrfV1ZxtjVgHOhRoyxFUdY1VfLpJrlLN0hen2Apr/PamGQ6IRGZLJ3FRTzLCwQTKxgUcbt00lLDzpQeKjbdXtAg39TM4v6D5K8Poy8vM1zRhjPg4Z7kAHW9b5IWLhZrm6G7g6agJLI4jfvKNaQA78MH8OzsxhwcRzh0ZN7EHBxHFm30qjLsuWUK0/M4ohWABE1DC3jRK8pwttQDIAtFpGkidB09EkIrC6KXBdHCQYTHrRKWoigbgir5dJNT6TATo1NUZZa5a3WOhYsjNLpyrOYkDq8bUSxgZ3JYtsTtdmNXV+FxCOygn8n6Ts5VdjNulbPNM0+7a479oy+gmzky1bUUOttp2buJ7tlprPnL2ANJtPIy3I8dAqAwtUDm6Glce7rJnbhE/sYQejiAXhZCZnI4muso+5PfQw8H7vBfSVEUZX3Z0Ilr93QvHVdeQloSC8FMWzOvEaTaaZJ3exDVFWyqchKwTZZSNn2Ucc1Th5FJU5ZfpWfwLPeZKcS2DnIPfZZsuUHL9XN4VxYx7t6Jo7oCLXwYO5kh9dPXcG9qxM7mkdkcwjAozi6QfuoljP3bCP/hl9HLw2qUpCiKcps2dOIadFXRt+leVsNVOF06Hb0ncaZNIjuaSBt+JoK1/LjggolptNVVPI4MLY1J/HuaqawPUxuAphCUpWNkX3sTkXThfexuNP8vdk2UUpL+/vOEvvYZ9ZCvoijKR2BDJ66m1jB186NoQyewaqqZ/ObnGR51M7WyhGd8GrM4hF5eTsX+zdy/K0x7RMLsAoX+fqzLMQBkwSJfHsL32P1onndvSJV9+U08B3eqpKUoivIR2dCJKzM0zcylKeYbukkXDTzXZjgY9HI+FObA4w9ysFlH025+h4D6KM766FtnpJTvOb1XnFvCTmdwtd1eaSpFURTlg9vQiSt2z0FmNh+kpQzayiC4NmD6zK/wGe+VtKRtk37hBMHf/lU+TVEURbldGzpx9URLX78JmZdO4X3gLlULUFEU5SOmvX8T5Z2K0/NgWTgba+50VxRFUT5x1HDhV5S71Id5Y4TAlx65011RFEX5RFIjrg/ITmdJfO85KBQJfvXRt+2TpSiKonx01IjrA8hd6sO8Poz/iQfe9gyXoiiK8tFTieuXKC7GyBw7jautgeDX1OpBRVGUjwOVuN5BFovkzl2nMDKJXhnB//j9pWrtiqIoyseCSlxritPzZM9cgWIR9+4tGPu2qbqCiqIoH0Of6MRlpzJkT13GWlzBUVuF79F70Yx3l3VSFEVRPj4+cYlLWhb53gHM/jE0r4GxvwdHVeROd0tRFEX5gD4xias4t0TudC8yb+La3kHgy0fUVKCiKMo6tKETlzQL5M5epTAxi6OqHO/DB9RCC0VRlHVuQyeuwuQcjroonnt23umuKIqiKB+SDZ24XG0Nd7oLiqIoyodMlXxSFEVR1hWVuBRFUZR15bamCoUQ/yvwOGACw8DvSinjt2g3BiQBCyhKKffczu9VFEVRPrlud8T1ErBVSrkdGAC+9UvaPiCl3KGSlqIoinI7bitxSSlflFIW1w5PA/W33yVFURRFeW8f5j2u3wOee4/XJPCiEOK8EOL3f9mHCCF+XwhxTghxbnFx8UPsnqIoirIRvO89LiHEUaD6Fi/9Gynlj9ba/BugCPz1e3zMQSnltBCiCnhJCNEnpXz9Vg2llH8B/AXAnj175AeIQVEURfkEed/EJaV86Je9LoT4JvAY8KCU8paJRko5vfZ9QQjxNHAXcMvEpSiKoii/zG1NFQohjgB/Ajwhpcy8RxufECLw85+BTwFXb+f3KoqiKJ9c4j0GSR/szUIMAW5gee3UaSnlHwghaoFvSykfFUK0Ak+vve4A/kZK+e8/4OcvAuO/dgdLKoCl2/yMj6ONGheo2NarjRrbRo0LPj6xNUkpKz9o49tKXOuBEOLcRlyCv1HjAhXberVRY9uoccH6jU1VzlAURVHWFZW4FEVRlHXlk5C4/uJOd+A3ZKPGBSq29WqjxrZR44J1GtuGv8elKIqibCyfhBGXoiiKsoFs2MQlhDgihOgXQgwJIf7Vne7P7RBCfEcIsSCEuHrTuYgQ4iUhxODa97I72cdflxCiQQjxihDiuhDimhDij9fOr+v4hBCGEOJNIcTltbj+57XzLUKIM2vX5feFEK473ddflxBCF0JcFEL8ZO14Q8QmhBgTQlwRQlwSQpxbO7eur0cAIURYCPGUEKJPCHFDCHFgvca1IROXEEIH/i/g00A38FUhRPed7dVt+f+AI+8496+AY1LKduDY2vF6VAT+uZSyG9gP/NHav9V6jy8PHJZS9gA7gCNCiP3AfwT+NynlJiAG/MM718Xb9sfAjZuON1Js79zNYr1fjwB/DjwvpewCeij9263PuKSUG+4LOAC8cNPxt4Bv3el+3WZMzcDVm477gZq1n2uA/jvdxw8pzh8BD2+k+AAvcAHYR+lhT8fa+bddp+vpi9JOEMeAw8BPALGBYhsDKt5xbl1fj0AIGGVtXcN6j2tDjriAOmDypuOptXMbSVRKObv28xwQvZOd+TAIIZqBncAZNkB8a1Npl4AFSnvXDQNx+YutgNbzdflnlMq92WvH5Wyc2G61m8V6vx5bgEXg/12b3v32Wgm+dRnXRk1cnyiy9L9L63p5qBDCD/wQ+G+llImbX1uv8UkpLSnlDkqjk7uArjvbow+HEOIxYEFKef5O9+U35KCUchelWw1/JIQ4dPOL6/R6dAC7gP9bSrkTSPOOacH1FNdGTVzTQMNNx/Vr5zaSeSFEDcDa94U73J9fmxDCSSlp/bWU8u/XTm+Y+KSUceAVStNnYSHEz3dlWK/X5T3AE0KIMeB7lKYL/5yNERvypt0sKNVZvYv1fz1OAVNSyjNrx09RSmTrMq6NmrjOAu1rq5xcwFeAH9/hPn3Yfgx8Y+3nb1C6N7TuCCEE8JfADSnlf7rppXUdnxCiUggRXvvZQ+m+3Q1KCewLa83WXVwAUspvSSnrpZTNlP7bellK+dtsgNh+yW4W6/p6lFLOAZNCiM61Uw8C11mncW3YB5CFEI9SmofXge/ID1iR/uNICPG3wP2UKjnPA/8j8AzwA6CRUgX9L0kpV+5QF39tQoiDwBvAFX5xv+RfU7rPtW7jE0JsB/6K0vWnAT+QUv7btd0SvgdEgIvA70gp83eup7dHCHE/8C+klI9thNjeazcLIUQ56/h6BBBC7AC+DbiAEeB3Wbs2WWdxbdjEpSiKomxMG3WqUFEURdmgVOJSFEVR1hWVuBRFUZR1RSUuRVEUZV1RiUtRFEVZV1TiUhRFUdYVlbgURVGUdUUlLkVRFGVd+f8BTHr4OBxrzc0AAAAASUVORK5CYII=\n", + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAn4AAAHeCAYAAAAFJAYTAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy88F64QAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOzddXhUV/rA8e+4xF2QJETRBIK7W9FSKNAWKXUXavvb6u5Wtr4VKBVaqFGktLi7RwgSSEIEAsR9Mj5zf3+kTJsmaNFyPs+Tp50r5565czN5OfIemSRJEoIgCIIgCMLfnvxaV0AQBEEQBEG4OkTgJwiCIAiCcJMQgZ8gCIIgCMJNQgR+giAIgiAINwkR+AmCIAiCINwklNe6AoIgCIIgXB0OhwObzXatqyFcISqVCoVCcc5jROAnCIIgCH9zkiRRWFhIZWXlta6KcIV5e3sTHByMTCZrdL8I/ARBEAThb+5M0BcYGIherz9rUCDcuCRJwmg0UlxcDEBISEijx4nATxAEQRD+xhwOhyvo8/Pzu9bVEa4gnU4HQHFxMYGBgY12+4rJHYIgCILwN3ZmTJ9er7/GNRGuhjOf89nGcorATxAEQRBuAqJ79+Zwvs9ZBH6CIAiCIAg3CRH4CYIgCIJw05o2bRpjxoy51tW4akTgJwiCIAjCdWnatGnIZDJkMhkqlYqIiAieffZZzGbzta7aDUvM6hUEQRAE4bo1dOhQ5s6di81mIzk5malTpyKTyXjrrbeuddVuSKLFTxAEQRCE65ZGoyE4OJhmzZoxZswYBg4cyLp16wBwOp288cYbREREoNPpiI+PZ9GiRa5zHQ4HM2bMcO2PjY3lww8/vFZv5bogWvwEQRAEQbghHDp0iJ07dxIWFgbAG2+8wbfffsvs2bOJjo5m69at3HnnnQQEBNCnTx+cTidNmzZl4cKF+Pn5sXPnTu677z5CQkKYMGHCNX4314Zo8RP+Nl555ZVrkq5g/vz5xMXFoVKp8Pb2vurXvxh9+/alTZs2V+Va06ZNw93d/apc668KDw9n2rRp17oaLtfqWb4Z/dV73bdvX/r27et6nZeXh0wm4+uvv/7rlTuPr7/+GplMRl5enmtbeHg4I0aMuOLXBqipqSEpKYmamporep3ly5fj7u6OVqulbdu2FBcX88wzz2CxWHj99df56quvGDJkCC1atGDatGnceeedfPbZZ0Dd2rWvvvoqHTt2JCIigjvuuIPp06fz008/XdE6X89Ei98NZPPmzfTr16/Rfbt27aJr1671tu3cuZNnn32WlJQUPD09mTBhAq+//voN88f4RnD06FGmTZvG0KFDef7552+6BKlGo5H//ve/Df74CcLldPr0aebMmcOYMWNISEi41tW5Ij799FP0ev1V/wfIvDSoOM88idpaAzKZDL3erd52q1VNaWkI/iY1avWFXc9HC1PiL66O/fr1Y9asWdTW1vL++++jVCoZN24chw8fxmg0MmjQoD/Vy0r79u1drz/55BO++uorTpw4gclkwmq1/m2fowshAr8b0GOPPUanTp3qbYuKiqr3ev/+/QwYMICWLVvy3nvvcfLkSd555x2ysrJYtWrV1azu39rmzZtxOp18+OGHDT6Dm4HRaOTVV18FEIGfcMWcPn2aV199lfDw8Ov+D3ZYWBgmkwmVSnVR53366af4+/tfVOB31113MXHiRDQazUXW8ncXEoQdPnwcpVJJbGxsve2SpEaSQn6bdXvJVTgvNzc31/frV199RXx8PF9++aWr92LFihU0adKk3jln7smPP/7IzJkzeffdd+nWrRseHh68/fbb7Nmz58pV+DonAr8bUK9evbjtttvOecw//vEPfHx82Lx5M56enkBdF8C9997L2rVrGTx48NWo6t/emcWwz9fFK0kSZrPZtY6icPHEPfz7MJvNqNVq5PK/32gjmUyGVqu9oteora3Fzc0NhULR6FqsV8uZNCtXk1wu5x//+AdPPfUUmZmZaDQaTpw4QZ8+fRo9fseOHXTv3p2HHnrItS07O/tqVfe69Pf7rbtJ1NTUYLfbG91XXV3NunXruPPOO11BH8CUKVNwd3e/oLENH330Ea1bt0av1+Pj40PHjh35/vvvXfuPHz/OQw89RGxsLDqdDj8/P8aPH19vrAn8PgZl+/btPPbYYwQEBODt7c3999+P1WqlsrKSKVOm4OPjg4+PD88++yySJLnOPzNe5p133uH9998nLCwMnU5Hnz59OHTo0AXdq2+//ZbExER0Oh2+vr5MnDiR/Pz8esdkZWUxbtw4goOD0Wq1NG3alIkTJ1JVVXXWcsPDw3n55ZcBCAgIQCaT8corr7j2jRgxgjVr1tCxY0d0Op1rzElOTg7jx4/H19cXvV5P165dWbFiRb2yN2/ejEwm46effuLVV1+lSZMmeHh4cNttt1FVVYXFYuGJJ54gMDAQd3d3pk+fjsViuaD7AZCcnEz37t3R6XREREQwe/bsBscUFxczY8YMgoKC0Gq1xMfH880337j25+XlERAQAMCrr77q+iNw5h6ccerUKcaMGYO7uzsBAQHMnDkTh8Nx3jqe6x5WVlbyxBNP0KxZMzQaDVFRUbz11ls4nc56Zbzzzjt0794dPz8/dDodiYmJ9Wb8XawLLU8mk/HII4+wdOlS2rRpg0ajoXXr1qxevbrBsdu3b6dTp05otVoiIyNd7/FCffLJJ7Ro0QKdTkfnzp3Ztm1bo13vFouFl19+maioKDQaDc2aNePZZ59t8NxcTN1PnTrF3XffTVBQkOu4r776qt4xZ57lH3/8kX/+8580adIEvV5PdXU15eXlzJw5k7Zt2+Lu7o6npyfDhg0jLS2t3vlnejimT5/ues7+OIZuz549DB06FC8vL/R6PX369GHHjh2X/V7PmTOHyMjIevf6zxob41dYWMj06dNp2rQpGo2GkJAQRo8e7fq+DA8P5/Dhw2zZssX1/s58fme+Q7ds2cJDDz1EYGAgTZs2rbfvz9+7AGvXriUhIQGtVsstt9yC0Wist//06dMkJSU1OK+0tJSkpCTXc3HgwAFMJpNrPF9SUhIZGRnA2cf4lZeXk56eTnJyMvv37ycnJwer1VrvmNzcXFJSUrBarRw7doyUlBT2799Pfn5+vb8BUNd1m56eTkpKCikpKbRq1QqZTMZnn33GzJkzefLJJ/nmm2/Izs4mJSWFjz76yPVdFR0dTVJSEmvWrCEzM5MXX3yRffv2NXjfNxPR4ncDmj59OgaDAYVCQa9evXj77bfp2LGja//Bgwex2+31tgGo1WoSEhJITU09Z/mff/45jz32GLfddhuPP/44ZrOZAwcOsGfPHiZPngzAvn372LlzJxMnTqRp06bk5eUxa9Ys+vbtS3p6eoOxbo8++ijBwcG8+uqr7N69mzlz5uDt7c3OnTtp3rw5r7/+OitXruTtt9+mTZs2TJkypd758+bNo6amhocffhiz2cyHH35I//79OXjwIEFBQWd9L//5z3948cUXmTBhAvfccw8lJSV89NFH9O7dm9TUVLy9vbFarQwZMgSLxeKq56lTp1i+fDmVlZV4eXk1WvYHH3zAvHnz+Pnnn5k1axbu7u60a9fOtT8jI4NJkyZx//33c++99xIbG0tRURHdu3fHaDTy2GOP4efnxzfffMOoUaNYtGgRY8eOrXeNN954A51Ox/PPP8+xY8f46KOPUKlUyOVyKioqeOWVV9i9ezdff/01ERERvPTSS+f8bAEqKioYPnw4EyZMYNKkSfz00088+OCDqNVq7r77bgBMJhN9+/bl2LFjPPLII0RERLBw4UKmTZtGZWUljz/+OAEBAcyaNYsHH3yQsWPHcuuttwLUuwcOh4MhQ4bQpUsX3nnnHdavX8+7775LZGQkDz744Hnr2tg9NBqN9OnTh1OnTnH//ffTvHlzdu7cyQsvvEBBQQEffPCB6/wPP/yQUaNGcccdd2C1Wvnxxx8ZP348y5cv55Zbbjnv9f/sYsrbvn07S5Ys4aGHHsLDw4P//e9/jBs3jhMnTuDn5wfU/a4OHjyYgIAAXnnlFex2Oy+//PI5n+k/mjVrFo888gi9evXiySefJC8vjzFjxuDj4+MKDqAu5cWoUaPYvn079913Hy1btuTgwYO8//77ZGZmsnTp0ouue1FREV27dnUFigEBAaxatYoZM2ZQXV3NE088Ua/Mf/3rX6jVambOnInFYkGtVpOens7SpUsZP348ERERFBUV8dlnn9GnTx/S09MJDQ2lZcuWvPbaa7z00kvcd9999OrVC4Du3bsDsHHjRoYNG0ZiYiIvv/wycrmcuXPn0r9/f7Zt20bnzp0vy73+8ssvuf/+++nevTtPPPEEOTk5jBo1Cl9fX5o1a3bOc8+MR3v00UcJDw+nuLiYdevWceLECcLDw/nggw949NFHcXd35//+7/8AGtTroYceIiAggJdeeona2tpzXi8rK4vbb7+dBx54gKlTp7Jq1SpKSkowGAwX3RrZrFkz8vPzkcvlhISEAJyzG7u0tJS8vDzc3Nxo2rQpNpuN4uJiDAYDrVq1Qqn8PeyQJImsrCzXsTU1NRQVFaHRaAgMDATAZrNhNBpRKBSuZ9psNnPHHXfw3//+l9zcXAICAnjjjTfIycnB29ubDh068I9//AOA+++/n9TUVG6//XZkMhmTJk3ioYceurmHPEnCDWPHjh3SuHHjpC+//FL65ZdfpDfeeEPy8/OTtFqtlJKS4jpu4cKFEiBt3bq1QRnjx4+XgoODz3md0aNHS61btz7nMUajscG2Xbt2SYA0b94817a5c+dKgDRkyBDJ6XS6tnfr1k2SyWTSAw884Npmt9ulpk2bSn369HFty83NlQBJp9NJJ0+edG3fs2ePBEhPPvmka9vLL78s/fGRzsvLkxQKhfSf//ynXj0PHjwoKZVK1/bU1FQJkBYuXHjO99yYM9csKSmptz0sLEwCpNWrV9fb/sQTT0iAtG3bNte2mpoaKSIiQgoPD5ccDockSZK0adMmCZDatGkjWa1W17GTJk2SZDKZNGzYsHrlduvWTQoLCztvffv06SMB0rvvvuvaZrFYpISEBCkwMNB1rQ8++EACpG+//dZ1nNVqlbp16ya5u7tL1dXVkiRJUklJiQRIL7/8coNrTZ06VQKk1157rd729u3bS4mJieet69nu4b/+9S/Jzc1NyszMrLf9+eeflxQKhXTixAnXtj8/p1arVWrTpo3Uv3//BteaOnXqeet0oeUBklqtlo4dO+balpaWJgHSRx995No2ZswYSavVSsePH3dtS09PlxQKhXS+r2eLxSL5+flJnTp1kmw2m2v7119/LQH1fo/mz58vyeXyes+dJEnS7NmzJUDasWPHRdd9xowZUkhIiFRaWlqvzIkTJ0peXl6ue3XmWW7RokWD+2c2m13P/Bm5ubmSRqOp99zs27dPAqS5c+fWO9bpdErR0dENvl+MRqMUEREhDRo0yLXtr9xrq9UqBQYGSgkJCZLFYnFtnzNnToN7feY760xdKyoqJEB6++23z3mN1q1b1yvnjDPfoT179pTsdnuj+3Jzc13bzvzeLF682LWtpKREWrt2rXTw4EHXtlOnTkn79u1rcL2SkhJp3759ktlsdm07dOiQdPTo0QbHVldXS/v27XN9HzgcDik1NVU6dOhQvc+1oqJC2rdvX73v8JycHGnfvn3SqVOn6pV5+PBh6fDhw67Xx48fl1JSUup9vsK5mUwmKT09XTKZTI3uF129N5Du3buzaNEi7r77bkaNGsXzzz/P7t27kclkvPDCC67jTCYTQKMDfrVarWv/2Xh7e3Py5MlzNof/cZyVzWajrKyMqKgovL29SUlJaXD8jBkz6o0F6dKlC5IkMWPGDNc2hUJBx44dycnJaXD+mDFj6g3e7dy5M126dGHlypVnreOSJUtwOp1MmDCB0tJS109wcDDR0dFs2rQJwNWit2bNmgbdIX9FREQEQ4YMqbdt5cqVdO7cmZ49e7q2ubu7c99995GXl0d6enq946dMmVLvX9dn7tuZlrk/bs/Pzz9r9/8fKZVK7r//ftdrtVrN/fffT3FxMcnJya56BgcHM2nSJNdxKpWKxx57DIPBwJYtWy7gDtR54IEH6r3u1atXo59xYxq7hwsXLqRXr174+PjU+1wHDhyIw+Fg69atrmP/+JxWVFRQVVVFr169Gn1GL8TFlDdw4EAiIyNdr9u1a4enp6frvTscDtasWcOYMWNo3ry567iWLVs2eM+NSUpKoqysjHvvvbdeK8odd9yBj49PvWMXLlxIy5YtiYuLq3fP+vfvD+D6XbjQukuSxOLFixk5ciSSJNUrc8iQIVRVVTW4J1OnTm0wPlOj0bjG+TkcDsrKynB3dyc2NvaCPqP9+/eTlZXF5MmTKSsrc9WhtraWAQMGsHXrVpxO52W518XFxTzwwAOo/zB9ddq0aWftEThDp9OhVqvZvHkzFRUV573W2dx7770XPJ4vNDS0Xu+Bu7s7bm5umM1mbDbbJdfhfIxGI3a7nYCAgHrjN729vdFqtY0OnTkzXOSPdf3j8AOlUonD4aC6uvqK1ftmI7p6b3BRUVGMHj2aJUuW4HA4UCgUri/XxsZ8Xcjg+Oeee47169fTuXNnoqKiGDx4MJMnT6ZHjx6uY0wmE2+88QZz587l1KlT9cZkNPbL/ccvW/g92PpzF4mXl1ejX47R0dENtsXExJxzvGJWVhaSJDV6LvzeXREREcFTTz3Fe++9x3fffUevXr0YNWoUd95553m/1M8lIiKiwbbjx4/TpUuXBttbtmzp2v/HPHsXc9+cTidVVVWurrizCQ0Nxc2tflqGmJgYoG58UteuXTl+/DjR0dENBt//sZ4XQqvVNvhi9/HxueA/gI3dw6ysLA4cONCg3DPOTLiBuvxf//73v9m/f3+934dLHZB+MeX9+bOD+u+9pKQEk8nU6PMZGxt7zn/UwO+fwZ9nkyuVSsLDw+tty8rK4siRIxd0zy607pWVlcyZM4c5c+ZcUJmNfZZnZsR/+umn5Obm1hv7eb7nGOreF9QFlWdzZkzs5bjXfz5fpVLRokWLc56r0Wh46623ePrppwkKCqJr166MGDGCKVOmEBwcfM5z/6ix+3c2UVFRDZ7JM993FovlomccX6gz4/ga607WarUYDIZ6286sv/tHZwK9MwICAigvLycrKwuVSoWXlxc+Pj5/6bv5ZicCv7+BZs2aYbVaqa2txdPT0zUOo6CgoMGxBQUFhIaGnrO8li1bkpGRwfLly1m9ejWLFy/m008/5aWXXnKl7nj00UeZO3cuTzzxBN26dcPLywuZTMbEiRMbDLAHzvov1ca2S38a2HupnE4nMpmMVatWNXqdP+YzfPfdd5k2bRq//PILa9eu5bHHHuONN95g9+7d9cZKXYzLMfv0Yu4bXL57d7n81RmHjd1Dp9PJoEGDePbZZxs950wQu23bNkaNGkXv3r359NNPCQkJQaVSMXfu3HoTlS7UxZZ3PX1GTqeTtm3b8t577zW6/8//kDhf3c/8jt95551nDbr+ONYTGv8sX3/9dV588UXuvvtu/vWvf+Hr64tcLueJJ55o9Hvkz84c8/bbb581zcufW5CuhSeeeIKRI0eydOlS1qxZw4svvsgbb7zBxo0b6+WbO5erNZv9aj6fF/IPMJVKRatWraiurqaqqoqqqipKS0vx8/O7qGBY+J0I/P4GcnJy0Gq1rkCmTZs2KJVKkpKS6i1JY7Va2b9//wUtU+Pm5sbtt9/O7bffjtVq5dZbb+U///kPL7zwAlqtlkWLFjF16lTeffdd1zlms5nKysrL/v7g93/Z/1FmZmaDlo0/ioyMRJIkIiIiXMHAubRt25a2bdvyz3/+k507d9KjRw9mz57Nv//9779S9XrCwsJcM+L+6OjRo679V9rp06dd6SDOyMzMBHDdz7CwMA4cOIDT6azX6vfnel6L1SUiIyMxGAwMHDjwnMctXrwYrVbLmjVr6g17mDt37iVd93KXFxAQgE6na/TZbuwZ+bMzn8GxY8fqJXa32+3k5eXVC7wiIyNJS0tjwIABl+UzCwgIwMPDA4fDcd7P4VwWLVpEv379+PLLL+ttr6ysxN/f3/X6bHU+0x3t6el5znpcrnudlZXl6h6HumEuubm5xMefPxleZGQkTz/9NE8//TRZWVkkJCTw7rvv8u233wKX93fp2LFjSJJUr0ybzYZarXY9u2eCe7vdXm+owJ9n316MM93gZrO5XkaJM9vUF5rl+U/kcjne3t54e3sjSRInTpygpKSEkJCQK5465+9IjPG7gZSUlDTYlpaWxq+//srgwYNdf6C9vLwYOHAg3377bb1p9vPnz8dgMDB+/PhzXqesrKzea7VaTatWrZAkyTU+RKFQNPiX4UcffXRBaTouxdKlSzl16pTr9d69e9mzZw/Dhg076zm33norCoWCV199tUFdJUlyvc/q6uoGY+Patm2LXC6/7C0Fw4cPZ+/evezatcu1rba2ljlz5hAeHk6rVq0u6/UaY7fb66WxsFqtfPbZZwQEBJCYmOiqZ2FhIQsWLKh33kcffYS7u7srZ9aZ2dtXKuBvzIQJE9i1axdr1qxpsK+ystL1WSoUCmQyWb1nMi8vr8EM1gt1JcobMmQIS5cu5cSJE67tR44cafS9/VnHjh3x8/Pj888/r/f8fvfddw260idMmMCpU6f4/PPPG5RjMpnOO0u0sbqPGzeOxYsXN5pWqbHvqrOV8+ffzYULF9b7XQdc/0j583OWmJhIZGQk77zzToNuxD/W43Lc64CAAGbPnl0vMPr666/P++wbjUbM5vpLY0RGRuLh4VHv+8XNze2y/R6dPn2an3/+2fXaYDBQW1uLVqt1da2eCQD/eN/OjLP8M7lcfkHjh/V6PUqlkpKSknottlVVVZjN5kvqnv3zdWUymav1848t0GdW5BDOT7T43UBuv/12dDod3bt3JzAwkPT0dObMmYNer+fNN9+sd+x//vMfunfvTp8+fbjvvvs4efIk7777LoMHD2bo0KHnvM7gwYMJDg6mR48eBAUFceTIET7++GNuueUWPDw8ABgxYgTz58/Hy8uLVq1asWvXLtavX39B43IuRVRUFD179uTBBx/EYrHwwQcf4Ofnd9buPqj7cv33v//NCy+84Epz4eHhQW5uLj///DP33XcfM2fOZOPGjTzyyCOMHz+emJgY7HY78+fPd/1xu5yef/55fvjhB4YNG8Zjjz2Gr68v33zzDbm5uSxevPiqJLQNDQ3lrbfeIi8vj5iYGBYsWMD+/fuZM2eO64/Cfffdx2effca0adNITk4mPDycRYsWsWPHDj744APXc6DT6WjVqhULFiwgJiYGX19f2rRpc0XXA37mmWf49ddfGTFiBNOmTSMxMZHa2loOHjzIokWLyMvLw9/fn1tuuYX33nuPoUOHMnnyZIqLi/nkk0+IioriwIEDF33dy10e1OU/XL16Nb169eKhhx5yBdetW7c+b5lqtZpXXnmFRx99lP79+zNhwgTy8vL4+uuviYyMrNfac9ddd/HTTz/xwAMPsGnTJnr06IHD4eDo0aP89NNPrlyJF+PNN99k06ZNdOnShXvvvZdWrVpRXl5OSkoK69evp7y8/LxljBgxgtdee43p06fTvXt3Dh48yHfffddg3FxkZCTe3t7Mnj0bDw8P3Nzc6NKlCxEREXzxxRcMGzaM1q1bM336dJo0acKpU6fYtGkTnp6eLFu27C/fa5VKxb///W/uv/9++vfvz+23305ubi5z58497xi/zMxMBgwYwIQJE1zpTH7++WeKioqYOHGi67jExERmzZrFv//9b6KioggMDKzXungxYmJimDFjBvv27SMoKIiVK1fy1FNP1UsR4+npiVqtJi8vzzXWsLS0FKVS2SCA0uv1lJSUcPr0abRaLUqlskGLHtQFiGfSe2VkZODr6+tK56JWqy84dc4f5eXlYbfb8fT0RKVSYbVaKS4uRq/Xu1r7bDYbhw8fFt2/F+oKzigWLrMPP/xQ6ty5s+Tr6ysplUopJCREuvPOO6WsrKxGj9+2bZvUvXt3SavVSgEBAdLDDz/smnZ/Lp999pnUu3dvyc/PT9JoNFJkZKT0zDPPSFVVVa5jKioqpOnTp0v+/v6Su7u7NGTIEOno0aMN0mKcSTfw57QBZ0uDMnXqVMnNzc31+kxqhLffflt69913pWbNmkkajUbq1auXlJaW1miZf7Z48WKpZ8+ekpubm+Tm5ibFxcVJDz/8sJSRkSFJUl1agbvvvluKjIyUtFqt5OvrK/Xr109av379ee/VudK53HLLLY2ek52dLd12222St7e3pNVqpc6dO0vLly+vd8yZFBh/TjFzsffzz/r06SO1bt1aSkpKkrp16yZptVopLCxM+vjjjxscW1RU5PqM1Wq11LZt2wbpNCRJknbu3CklJiZKarW6XmqXP3+Wf67r+ZzrHtbU1EgvvPCCFBUVJanVasnf31/q3r279M4779RLf/Pll19K0dHRkkajkeLi4qS5c+c2ev0LTedyoeUB0sMPP9zoe/rzdbZs2eK6fy1atJBmz559wfdIkiTpf//7nxQWFiZpNBqpc+fO0o4dO6TExERp6NCh9Y6zWq3SW2+9JbVu3VrSaDSSj4+PlJiYKL366qv1frcvpu5FRUXSww8/LDVr1kxSqVRScHCwNGDAAGnOnDmuY872LEtSXTqXp59+WgoJCZF0Op3Uo0cPadeuXVKfPn0apDb55ZdfpFatWklKpbJBapfU1FTp1ltvdX1nhYWFSRMmTJA2bNhQr4y/eq8//fRTKSIiQtJoNFLHjh2lrVu3Nqjrn9O5lJaWSg8//LAUFxcnubm5SV5eXlKXLl2kn376qV7ZhYWF0i233CJ5eHjUSxFztt/5P+77czqXW265RVqzZo3Url07SaPRSP369ZOSkpIapPcwGAxSenq6lJSUJKWlpUmFhYWNpnOxWq1SZmamlJycLO3bt8+V2uXP6VzOKCsrkw4fPiwlJSVJqampUnZ2dr00OJJU972bnJzc4D39Oc1MeXm5lJGRIaWmprrqmZeXV+/33Gw2S/v27ZNycnIalHczOl86F5kkXWejwQXhD/Ly8oiIiODtt99m5syZ17o6gnDdczqdBAQEcOuttzbatSvcfMxmM7m5uURERIgxcTeB833eYoyfIAjCDcpsNjcYIzdv3jzKy8sbLNkmCMLfx5kVXy6FCPwEQRBuULt376ZDhw68/vrrfPbZZ9x///3cc889tGnT5ryTuAThRlJYWMjjjz9OVFQUWq2WoKAgevTowaxZsy5r4v0r6a8Ea5eTmNwhCIJwgwoPD6dZs2b873//o7y8HF9fX6ZMmcKbb755yakzBOF6k5OTQ48ePfD29ub111+nbdu2aDQaDh48yJw5c2jSpAmjRo26JnWTJAmHw1EvJc71TrT4Cde18PBwJEkS4/sEoRHh4eH8+uuvFBYWYrVaKSws5KuvvnItcC8IfwcPPfRQvdy0LVu2pEWLFowePZoVK1YwcuRIoC7dzz333ENAQACenp7079+ftLQ0VzmvvPIKCQkJzJ8/n/DwcLy8vJg4cWK9tGdOp5M33niDiIgIdDod8fHxLFq0yLV/8+bNroUBEhMT0Wg0bN++nezsbEaPHk1QUBDu7u506tSJ9evXu87r27cvx48f58knn0Qmk9Wbdb99+3Z69eqFTqejWbNmPPbYY/VSLBUXFzNy5Eh0Oh0RERF89913f+l+isBPEARBEITrUllZGWvXruXhhx9usNTkGWeCqPHjx1NcXMyqVatITk6mQ4cODBgwoF5qoezsbJYuXcry5ctZvnw5W7ZsqZcO7Y033mDevHnMnj2bw4cP8+STT3LnnXc2WJ/8+eef58033+TIkSO0a9cOg8HA8OHD2bBhA6mpqQwdOpSRI0e68kYuWbKEpk2b8tprr1FQUOBaWSs7O5uhQ4cybtw4Dhw4wIIFC9i+fTuPPPKI61rTpk0jPz+fTZs2sWjRIj799NMGSyJelKs3wVgQBEEQhKvtfOk9rme7d++WAGnJkiX1tvv5+blSdD377LPStm3bJE9Pz3ppaCRJkiIjI6XPPvtMkqS6VFJ6vb5e+plnnnlG6tKliyRJdWlh9Hq9tHPnznplzJgxQ5o0aZIkSb+nJ1q6dOl56966dWvpo48+cr0OCwuT3n///QZl33ffffW2bdu2TZLL5ZLJZJIyMjIkQNq7d69r/5EjRySgQVlnnO/zvnE6pS+Q0+nk9OnTeHh4XJPlpARBEAThemK1WnE6nTgcjiu2utKVcqa+Z+p/xq5du3A6ndx1112YzWZSU1MxGAwNFhEwmUxkZWXhcDhwOp2Eh4ej1+tdZQUFBVFcXIzD4SAjIwOj0cigQYPqlWG1WklISKh3/9q3b1+vPgaDgVdffZVVq1ZRUFCA3W7HZDKRl5dX77g/v4+0tDQOHDhQr/tWkiScTifHjh0jKysLpVJJfHw8VqsVlUpFXFwc3t7el3xP/3aB3+nTpxssOC4IgiAIN6uwsDBmz56NyWRqsE+zfDuy6qs3K1by1GMZ0fOCjzcajchkMjZt2tToWuYOh4Pi4mKcTif+/v7Mnj27wTEeHh6kpqZSWFiIzWYjNTXVte/UqVOuwPHM8oPvvvtug3GyKpWK1NRUjh07BtR10f6xu/WNN95gz549PP744zRr1gyNRsNzzz3HqVOnXNezWq2cPHmy3vVLS0sZO3Yst99+e4N619TUkJOTgyRJ7N+/H7lcTrt27f7yxK2/XeB3Zimp/Pz8RpeUEQRBEISbidVqpaioiPDw8IYJfdu3vzaVuggDBw5k6dKlvP766w3G+bm7uxMYGMiwYcOYNWsW7dq1Izw8vNFygoOD0ev1tP/De966dStqtZr27dsTFRXFAw88gFarPess4aqqKgDatWtXr9UtMzOTe++9l6eeegqoawEsLi4mMDDQdT13d3dCQkLqXb9bt24UFRWd9XpeXl7MnDkTu92OWq1GLpeTkZHxl9Z1/tsFfme6dz09PUXgJwiCINz0zGYzJSUlKBQKFArFta7ORZs1axY9evSgS5cuvPLKK7Rr1w65XM6+ffvIyMigY8eODBkyhG7dujFu3Dj++9//EhMTw+nTp1mxYgVjx46lY8eOrrXQ/3gP/rjN29ubmTNn8vTTTwPQs2dPqqqq2LFjB56enkydOtV17p/vZXR0NEuXLmX06NHIZDJefPFFnE4nMpnMdVx4eDjbt29n8uTJaDQa/P39ef755+natSuPP/4499xzD25ubqSnp7Nu3To+/vhjWrVqxdChQ3nkkUd47LHHcDqdPP300+h0uku+n2JWryAIgiAI163IyEhSU1MZOHAgL7zwAvHx8XTs2JGPPvqImTNn8q9//QuZTMbKlSvp3bs306dPJyYmhokTJ3L8+HGCgoIu+Fr/+te/ePHFF3njjTdo2bIlQ4cOZcWKFURERJzzvPfeew8fHx+6d+/OyJEjGTJkCB06dKh3zGuvvUZeXh6RkZEEBAQAdS2HW7ZsITMzk169etG+fXteeuklQkNDXefNnTuXkJAQ7r//fsaPH8999933l1I2/e3W6q2ursbLy4uqqirR4icIgiDc9MRavTc+h8NBamoq7du3P2+rrVirVxAEQRAEQQBE4CcIgiAIgnDTEIGfIAiCIAjCTUIEfoIgCIIgCDcJEfgJgiAIgiDcJETgJwiCIAg3gb9ZEg/hLM73OYvATxAEQRD+xlQqFVC3/Jnw93fmcz7zuf/Z327lDkEQBEEQfndmVYoza8vq9XrXKlfCjcHhcAB1OfrOlsdPkiSMRiPFxcV4e3uf9TgR+AmCIAjC31xwcDCAK/gTbixOp5PS0lLy8vJcy8ydjbe3t+vzboxYuUMQBEEQbhIOhwObzXatqyFcJIPBQMeOHUlKSsLd3f2sx6lUqvOu7CFa/ARBEAThJqFQKM4bGAjXH6vVyvHjx1Gr1X952T0xuUMQBEEQBOEmIQI/QRAEQRCEm4QI/ARBEARBEC6SJEmYU49g3LzvWlfloogxfoIgCIIgCBfBlnOSmiXrUEU0xe2W3te6OhdFBH6CIAiCIAgXwFFaQc2idciUCrxmjEPhc+NlDxGBnyAIgiAIwjk4a00Ylm3GUVCC+22DUTU7e568650I/ARBEARBEBoh2e0YN+7BknwEtxF90Ewcdq2r9JeJwE8QBEEQBKERhmWbkWs1+Dx3N7LzrJhxo/h7vAtBEARBEITLyFFtwHYsH/3Qnn+boA9E4CcIgiAIgtBA7dKNuI/uh0wmu9ZVuaxE4CcIgiAIgvAH9qIyHJU1qGPCr3VVLjsR+AmCIAiCIPyBYck6PMYNvNbVuCJE4CcIgiAIgvAbW+4pZGo1yiZB17oqV8QlB35bt25l5MiRhIaGIpPJWLp0ab3906ZNQyaT1fsZOnToecv95JNPCA8PR6vV0qVLF/bu3XupVRQEQRAEQbgohiXrcR874FpX44q55MCvtraW+Ph4Pvnkk7MeM3ToUAoKClw/P/zwwznLXLBgAU899RQvv/wyKSkpxMfHM2TIEIqLiy+1moIgCIIgCBfEcjALZdMgFL5e17oqV8wl5/EbNmwYw4adO5GhRqMhOPjCs1u/99573HvvvUyfPh2A2bNns2LFCr766iuef/75S62qIAiCIAjCOUmSRO3yLXg/dse1rsoVdUXH+G3evJnAwEBiY2N58MEHKSsrO+uxVquV5ORkBg78fTClXC5n4MCB7Nq166znWSwWqqur6/0IgiAIgiBcDPOuNDTxMcjddNe6KlfUFQv8hg4dyrx589iwYQNvvfUWW7ZsYdiwYTgcjkaPLy0txeFwEBRUfzBlUFAQhYWFZ73OG2+8gZeXl+unWbNml/V9CIIgCILw9ybZ7Rg37UU/qPu1rsoVd8WWbJs4caLr/9u2bUu7du2IjIxk8+bNDBhw+QZNvvDCCzz11FOu19XV1SL4EwRBEAThghk37EHfpyMy1d9/Jdurls6lRYsW+Pv7c+zYsUb3+/v7o1AoKCoqqre9qKjonOMENRoNnp6e9X4EQRAEQRAuhNNoxpJyBG33hAb7TDtSMW7cc/UrdQVdtcDv5MmTlJWVERIS0uh+tVpNYmIiGzZscG1zOp1s2LCBbt26Xa1qCoIgCIJwE6lduRW34b0arMfrqKjGtDUZ0+4DSJJ0jWp3+V1y4GcwGNi/fz/79+8HIDc3l/3793PixAkMBgPPPPMMu3fvJi8vjw0bNjB69GiioqIYMmSIq4wBAwbw8ccfu14/9dRTfP7553zzzTccOXKEBx98kNraWtcsX0EQBEEQhMvFUVmDLe8U6nYxDfZVf7ccj8nDUceGYz2Scw1qd2Vccmd2UlIS/fr1c70+M85u6tSpzJo1iwMHDvDNN99QWVlJaGgogwcP5l//+hcajcZ1TnZ2NqWlpa7Xt99+OyUlJbz00ksUFhaSkJDA6tWrG0z4EARBEARB+KtqV23DfVQ/ZDJZve3mfYdQBvqiCgtF7q6n5qfVaFpFXqNaXl4y6e/Ufknd5A4vLy+qqqrEeD9BEARBEBolORxU/HcuPs/PqBf4OWtNVHwwH99npiNTqwCo+GA+XjNuRe7hdk3qejljG7FWryAIgiAINx3roWOo20Y3aO2r+XEVHrcNcgV9ALo+HTFuTb7aVbwiROAnCIIgCMJNx7Q9BV2P9vW2WQ4fA4UcdWxEve2adjFY0jIaTPKQJAmn2XLF63o5icBPEARBEISbitNgRLLaUPj83m0qWawYlm7EY8LQBsfLFArUcRFY07PrbbceyaF2+ZYrXt/LSQR+giAIgiDcVEy7D6DtllBvW83idbgN74Vcr230HH2fjpi2JNXbZly7E/3AGyvlnAj8BEEQBEG4qVhS0tF2aOl6bcs5ibOqBm37lmc9R+HnjWSz46g2AGA/VYTcTYfC2+OK1/dyEoGfIAiCIAg3DXtBCQp/H9fkDclup2bBajzuGHHec3W9EzH9NsmjdtV23Ib3uqJ1vRJE4CcIgiAIwk3DtDUZXe9E1+vaFdvQ9eqAwtP9vOdq4mOxHMjEXlmN02BE2eTGyzMsAj9BEARBEG4KktOJLeckqshmQF3rny0nH+2fZveejUwuR90ygpqvf0E/qG5s342WDlkEfoIgCIIg3BSsh7NRt4p05e4zLNuCx6ThDXL5nYuuazymbcmoW0ViTj5M7bLNV6ayV8glL9kmCIIgCIJwJdlyTuKoqkEymnEazb/911T3X5MZuVaD54xbLzhwM21PdqVrkaw2nJXVKIP9L6pOlkPHUIaFYt5/FPPmJLwfv+Oi39e1JAI/QRAEQRCuO9ZjJzAs3YAmoSVyvRZloC8yvRa5XotMr0Om02BYuhFbdj7qqObnLc9Za0IyWlD4eQNg2X8UTULcRdVJcjox7z2I2+DuVP3vOwI/+Scy5Y0VSt1YtRUEQRAE4aZQu3o7XtPGoPD3Oesx+j4dqV21/YICP/Peg2i7tnO9Nu0+gOe00RdVJ0tyOuqYMMxJh1E0C0am01zU+dcDMcZPEARBEITrir2wFJlScc6gD0AZEoCjtALJajtvmebkw2g7tq4rv7gc69EcKj+Yj2nPgQuuV+3anViP5OJ191i07eOwHs4+/0nXGRH4CYIgCIJwXaldtQ23YReWI0+b2Bpzcvo5j7EXliL3dMeWe4rK2Qso/9dnaDu0wufZuzGu331B6+1a0rOxZh7HY+JQlKGB6Pt0wrh+F5aDWRdUz+uFCPwEQRAEQbhuOKoNOMurUYWFXtDx2q7tMO9OO+t+p8FI5Sc/YMvMw3IwE/dxg1GFh+JxxwjkWg3uI/pgWLrxnNeQJIny1+fgOWUk6phwABS+Xlgzj2MvLrvg93Y9EGP8BEEQBEG4bhjX7UI/uPsFHy930yHTqnGUV6Hw9XJtlySJmgWrsReW4qisIeD9Z5ErlTgqqpFp1ch/G5+niY/FuHkf9sLSs87wrfpiMYpAX9wGdHVtc5osyNQqHFU1l/hOrw3R4icIgiAIwnVBslixZR5H3Sbqos7T9eiAaXtKvW2mLUnINCrcBnVD378z8t9m35p3p6HtFl/3//uPYj9VhMftQ6n5cVWjZdeu24XlQCY+j91Vf/vKrcj9vMAhEjgLgiAIgiBcNOOWJHS9Ey8qoTKAuk0U1sPZrlU0bHmnsaQcwX3MAEzbUtD1/H2JNsuBTDTtYpAkidrlW6j+bgWKID+UTQIxpx5xHSdJEjVL1mPamYoqLBRV+O9dz7ZTRVR/8wtug7rjOX7wX3zXV5cI/ARBEARBuOYkpxPzvkNou7S94HNs+YU4jWZkcjnKiCbYsvNx1pqo/m45XveOQ7JYcdYaUQbUzQ62F5ai8PdBplRi2pGKzE2HTK/FsGgdbiP6ULtyW11iZ7OFqjmLMG1NwpqRh75/F9c1LUdyKH3+fbwfuwPdH9LD3CjEGD9BEARBEK45895DaBNbXXBC5NqVW7Fm54NTQrJYkbnpqP5uOXJ3PR63D0Xu4YZxaxLaTm2AupU6an5cheRwUv7frzDvO4Tn9LHItGqqv1iMNSMXe3E5RQ+8BjJwG9YbubcnCqMFR0UVkiRhXLUNa0Ye6laR6P8w3u9GIgI/QRAEQRCuKUmSMG3ei/cTd13QsYaFa0Aux/uRychkMiRJwp5fSNHD/0bZLATDkvUoQgKw7DmArm8nyt/9GpBhPXwM35ceAKeEMtgf9+F1KWPULZpRu3ob2n6dqXjjc7wfnoTtSA7Omlo87x9P7YotWA9koY4LRx7sj75dzEV3R18vRFevIAiCIAjXlPVIDqoWTZFrz70ShiRJVH/zC3IPNzxuG+wKvmQyGZLZgjqyOZ53jcDnmekogvyQ+XigbNEMn8fuxGPsAPQDu6IOb4Jx7Y56M4eVTQKx7M/AsiuNwE/+iT33FEjg/9ojqIL9cJZVggy0PdrjOF2MpvXFTT65nojATxAEQRCEa8q4dif6wT3OeYxkt1M1ewGqFk0bJHd2VBuoWbgW33/eh2VX3Uoc1pR05Dod5p2pVLw/j7L/zAGnE3PSYZwmiyt1i724nIp3v8Zz2mgkixVVSEBdmTW12E8WUv3dCvxeeRicEjXfrcB9ZN/LfwOuItHVKwiCIAjCNWM7WYTcww2Ft8dZj5EsVipnLUDXs4Nr2TXXPqeT6s8X43nXSJT+Psi0aozbklFGh2HPPYnPI5ORnE7KX5uNuk001d8uR65WUf7fr1D4eOIoq8RzxjiUAXWTPgy/bkLu44njSDY1P63BZ+Z05DoN+iHdKfvXbLwfvP1K35IrSgR+giAIgiBcM7Urt+E+ss9Z9ztrTVR+8gNut/RutIvVsGQ9mk6tUTUPAUDbI4HKD7/F9/l7MZutQF1XsiYhDnVcBMoAH3yevRsAR2klCh8P14QSbcfWGHek4igqRabVIsnlmPcdRN+7I+akw2jat8KafQJ1ZPPLfRuuGtHVKwiCIAjCNeEor0IyW1D+1r3aYH9VDZX/+w73cYMaDfosaRk4K2vQ9+7o2iY5nEhOJ7ZTRaij6wI00479aLsn1OX061WXJ1Amk7la+f5I17FVXaCY2BJLyhGqv/4Fe0k5zioD3g+Mx7B4vStfoCRJSFbb5bodV4UI/ARBEARBuCZqV2/HbWjjY/skq43Kj3+oWx83slmD/U6DEcPyLXhOGfX7OZKEad0u9P26YN6dhiqyWV1evspqFAE+dXkCO7c5Z51Muw8gd9Mj12uRagxYs45zevzTOCuqMW3ci0ytwrhuJ5IkYT2aS+XsBX/tJlxlIvATBEEQBOGqcxrN2E8WoY4Jb3S/aXcauh7tUTYJanR/zcI1eNw2CJla5dpmSctAHROO26BuWA5kovDzxpKWgSYhFktKOpr42HPmCXQYjJh3HcD3pQeomrMITcc2BHzwHDgcaHt2QN06ElVsOFVfLKH8ra8oeeYdHGVVf+k+XG0i8BMEQRAE4aozbU1C16fjWfebd6Wh+21N3T+zZudjNTv4wRzh2iZJEsY1dWla5D5eSFYbktWGaXca2q7xGDftRd+30znrVDX7JzQ9EjCt3437rQNxVtVgST6Ctncitsw87KeKcR/WC58n7kIdF4Hc0x3PqaMv7QZcIyLwEwRBEAThqpIkCXNyeoMZumfYck6iDA1EplE3PNfppOanNSxtNZTsCjD9NsTuTGuf3E2HLScfbcfWmHakgs2Bs7QSZaAfcjfdWetkLyzFtD0FVZNAVNFhqEIC6lK+OOy4D+iKum00tuOnqV27E223eGrX7USu06CKaHJZ7snVIgI/QRAEQRAumi3vNI5qwyWda0nLQNM2GplC0eh+44bd6Ac2viSaccMe0kLbEh3mTudQOFlTv7UPwHYsH7ehvaj5eQPazm2oXbMd/ZCz5wmUHA4qPvoOdWw4ksmC26BuWI+dwPvRyXhOHYO2azzmPQfwnDoaR3E5tSu31dVdrrjhVvAQgZ8gCIIgCBfFabJQPf9XKj/+AcvBrIs+37RxL/p+nRsv22DEWV3b6ExfR2UNJ7amkx/fmb7h0MwT8qvAsv+oq7UPwJZb12JoP1mIzF2PZHegDPI7a31qf92Ms8aITK3C884RdbN1jWYU3p7I9dq6Hzc9jtJKPO64BXtBCfbjp5H7e2MvKLno938ticBPEARBEISLYliyHvcxA/B9eirmfYeo/n4Fkt1+QefaC0qQueuRe7g1ut+0LRld78RG952et4otHYYxuW1dK1szL8ivqt/aJ0kSzppaqr5YhN//3Ufl7J9wO0drny3nJJaMXGzZ+fg+NwOZQoGzrBKFv0+943R9OmLamoRMJkPh7YFMo0LVPARzWsYFve/rhQj8BEEQBEG4YLaTRTgrq+u6ajVqvO4eizqqORXvfI29sPS85xvX7cLtD+vk/pEkSZhTj6Lp0LLBvpqD2ewu03L70FBUv/UQ+2jBeeAo6tgIV2uf/UQBlkPH8LxrJJLNjqOkHNlZgkzJaqNq/jLseafxvGskcnc9ANbM46ii6ydpVsdFYD2ai9PpxJKejbZbAuqwEGS/5fS7UYjATxAEQRCECyJJEjU/rsRj4rB627Wd2+J1zziq5y/DtD3lrOc7jWbsRWWowkMb3W89nI26ZUSDsX9Om53dczYQf88gfOrNz5AI3rcD3aC6QNJpslD+zte4j+mHafO+umXeeiVS+8umBteyF5ZS/uF8sNqQe3vgPuL31UNsx040SDMjk8nQtI6i9pdNyN30aNrG4Da8N/qB3c76fq9HIvATBEEQBOGCmHfuRx3XAoWfd4N9Cn8ffJ6agr24nMrZP+E0mhscY9qWjK5X4924AMZNe9D3bTj2b/tX29D17UxMs/qzci37jyKLjqBUpqtbz/eT71GGBmDLPIEqqjnqqObUdQpLrrF4TrOFmgWrqflpDaqwJmjax6FqHuJq7YO6oFDRyJhAXa8OGBavA7kcbUIsADLVjbX6rQj8BEEQBEE4L2etCeOWpLOutAEgUyjwuHUg+n6dqHh/HvbTxa59v6dwadXouY6ySpDJUPh41tt+8EgFtmMn6D62fk4/yenEuHoH7oO7k19qp/LjH9B2aot5537cJw9H0zYaha8XyOXo+neh5uf1mHbup+Ldb1DHReD96GRsx06AQlEvGHWaLcjUqkZn68o0apwmM1KNAYW/D5LTiaOi+kJu33VDBH6CIAiCIJxX3YSO/udc+eIMdWwE3g9NpHr+MiSnEwDrgUw0rSLPer5x0170/bvU21ZcCzlfrqTb48ORy+sHYsb1u9F2bkuonwrzFwtQRjbFuDUZbZd2aKLDsBzIRN02Gl3P9ph3pWHakow1Mw/fZ+9GEx+LLSOP0uDmlKfloG4d6SrXlnMSVYumjdbRtCsNbbd4HDVGDL9sovz1zzFtTT7v/bieiMBPEARBEIRzsh0/jbPWiKZV5PkP/o3CxxNtxzYY1+0C6gI7Xf/GU7hIdjvWzOOoW7ZwbTPbYdmidDol+KFvWj+1i6OiGsv+o2j7dsRzwSIqnGrsJwrxGDcQdVzdah6WAxmofgsAa5dtxvfVh5BqTa6uWePmfeyXB7NDHYbZ8Xs4ZMs83ugycpIkYVy/C9uxE1gPZKAKD0UdH4vCz+uC78n1QAR+giAIgiCclSRJ1CxYjcftw85/8J/o+nfGciATS3o2Mr0Ohad7o8eZk9PRJrZyda9KEny310yvzO2EjO/f4Pia71fgMXEYNd/8AuUVWNw88H50Mvb8QtTRYXVdsNW11HzzC9ou7XAfOxC5SgUqFbYTBThrTUi1JrQZGXS6rSPzDtRdExq2+ElWG8YtSZT+3/+Qak0og/zxmDQcp9OJLTMPbY/2F31friUR+AmCIAiCcFambcl14+X+NPbuQshkMjwnD6f8v1/hNujss19N25LR9ezger02W6LN+p8Jnzq4wbJt5tQjyP28kamUGLcko+/XmdIhQzHbZdiy81FFNsOWcxLsdjTxsWhatkDfty4Hn/vofhh+3YRp536c7VqitZmJjvElzAs25NYFuZLdjkytAsCamUf523NBBuroMLwevB1FgA+6vp2o/GA+bkN6YFy786Lvy7V0yYHf1q1bGTlyJKGhochkMpYuXeraZ7PZeO6552jbti1ubm6EhoYyZcoUTp8+fc4yX3nlFWQyWb2fuLi4S62iIAiCIAh/gdNgxLxjvys58qWQ+3oD4CitaHS//VQRCh8vVx6+IyVgXL2d1h2bNuhydZot1K7chsfYAdQsWY86LgJ9r0SaetQt3eY0GJG76zEnp+Mor0b/W7CpbBKEo6QChY8ncp0W48Y9FJWY0fWqCzaHREJWOeQcLam3Yohx/W68H5qIrnsCjqJSnJU1aNrF4CgsA7kM0479qC+i+/t6cMmBX21tLfHx8XzyyScN9hmNRlJSUnjxxRdJSUlhyZIlZGRkMGrUqPOW27p1awoKClw/27dvv9QqCoIgCILwF9QsWov7rQPOuqbuhTBtS8Zrxq3Urt2Js9bUYL9xwx70A+rW5S01wpaNefRy5uM2rGeDYw0/b8B9RB8khxNLyhFX7r1mnnDylAG5R11KFuPanXiMH1xvIok2sTXmpMOoE+JwFJRSnZJJaLe6xiWZDKbGw7ZNx7FHhAF1QaRktaHw8cSScgRN+5ZY0jJQR4dj3LAbtyE9sOXko2oWfMn35lq45OQzw4YNY9iwxvv7vby8WLduXb1tH3/8MZ07d+bEiRM0b9680fMAlEolwcE31k0UBEEQhL8bW85JJKsddWzEJZchSRLmpMP4Pjsdhb8PNYvW4jV1tGu/02zBXliKKjwUqwO+3Wlg3JG1+D47pUE6FduJApzlVWjiY6ldswO5TuuajdvMC45tyUcVFYbtVBHOqhq0XdrVO1/btR1Vs39C7uuJskkApQ43Ovj9HgbpVdDHfpyFpoHcJ4FpRyq17dtzshia7tyP591jsaRlULt2B+63DsS4fhcyvfaS7821ctXG+FVVVSGTyfD29j7ncVlZWYSGhtKiRQvuuOMOTpw4cc7jLRYL1dXV9X4EQRAEQbh0kiRRs3gdHhOH/qVyrAez0LRsgUypRNOyBVhtWLOOu/abd+5H1z0BSYJv9zsZunMxwXePRP6ngKpuxZBVeEy+pW527cY9aDq1QiavC2N8tCBlH0cd2Yyq2T/hPm5Qg8BR7qYDtRLb8dPIfb1RVNeg/lNDppe5mrgYb5ZngSX1CGv0cWw7UM1JowJHYSkydz2SxVa3Vq9SgcLbA8Oam2SM38Uwm80899xzTJo0CU/Psw8O7dKlC19//TWrV69m1qxZ5Obm0qtXL2pqas56zhtvvIGXl5frp1mzZlfiLQiCIAjCTcN6OBtVeOhZZ+FeCMlup3bDbnQDfs/N5zFpODUL1yLZ7XWtgXsOou3clg25EL1jLWH926JqHtKgLOOGPWgS4uq6XdMyAND1/D3pskwG+uJCJKUCW3Y+brf0brROci+PuvLsMhShgViz8137nAYjMr2W3s2h5thpTrgHY5UpuaM2mUORiaSvT69b03fycAy/bMJ9dH8Uvt5YUg5f8j26Fq544Gez2ZgwYQKSJDFr1qxzHjts2DDGjx9Pu3btGDJkCCtXrqSyspKffvrprOe88MILVFVVuX7y8/PPeqwgCIIgCOdXu3obbsN6/aUyqr9dgSX5MHIPN9c2ubsefZ+O1K7Yhi0jD2VEEzKrVRTvPkI7d1O9mb1nOCprsCQdQj+wbhygcdNeZFpN/ZQrkoSbUqJwwQZUUc3PGrA6isqw5Z2mpE173Eb1o3bZZtc+W3Y+6ugwZDIYUriPb9w60jFYwpGezfhx0Zg27OZ4l944q+oao5ShgdgrqvCcdMtfuk9X2xUN/M4EfcePH2fdunXnbO1rjLe3NzExMRw7duysx2g0Gjw9Pev9CIIgCIJwaSxHclA2Caq3du3FsheWYli4Bv2g7g3SnWi7J2DLzqfmpzVY+/Rk9d5yBubuwPPOEY2WVfPDCjwmDkcml2MvKEEymtC0ja7Xles4XYyHVkZVrQNNu5jG61RQgtzHE5xOTti0RER4ofD1crX6WTPz6nIAWm04i8vwahFMxpLdSG46HLn5NFcaOdgikYPfbMJ9TH8kqw1sdlThoZd8n66FKxb4nQn6srKyWL9+PX5+DRc7Ph+DwUB2djYhIQ2bfQVBEARBuPyMK7edtav0Qlgz86j4YD7KmDDsJwqonvcr1d8uw15UhiRJyGQytD3bY87IY16mhrH7luB73zhX7rw/sqRlIPf0cAVXtWt3ItNo0P0pabIl8zheJacpc/NFk9B4Gjjjpr2oIpqg690R2d5Ugt1BP6gbpk17gbrJI8rmwZiT08kNa80ww2E6Lv+Wvek17HnmC7Z1HYX1ZDG7CpQ8m+bPvB8zWUo0b9xgyUcueVavwWCo1xKXm5vL/v378fX1JSQkhNtuu42UlBSWL1+Ow+GgsLAQAF9fX9TqumSMAwYMYOzYsTzyyCMAzJw5k5EjRxIWFsbp06d5+eWXUSgUTJo06a+8R0EQBEEQLoA16ziKQN9LHttn2n0A88663HbWg5n4/uNe7CWVlL82C0eVAanWhNxDj/VoHiVtE+j306cE3jMEZYBPg7KcJguG5VvweXpq3etaE46iUnBK9XLtAZjW78KnT3uO7Smo1wW85ThEeEMzNzv24wU4qgy4jR2E8s1fwGxBGRKAo7IaR00tSBIypRLTzv0c8ohn8MLPCfj4GZqdKqbsnzvodlsMhl824P5/g/mxGFqmpHOizyD6trykW3XNXHKLX1JSEu3bt6d9+7qo+6mnnqJ9+/a89NJLnDp1il9//ZWTJ0+SkJBASEiI62fnzt+bfLOzsyktLXW9PnnyJJMmTSI2NpYJEybg5+fH7t27CQgIaHB9QRAEQRAur9oVW3H7LTfexZAkCcOyzViPZOP12GSsaRloOrdFplCgCvbD7//uRTKa8XlqCto+nZB76Mk3awh01tbLtffH8qq+WITHbYOQazVAXT5ARWhQg65cyWrDciATj0FdsGj1rpm+VRZIKYCVWbDxlyMo4lqAyUKhzhd5x3jMew4AoOventplW1A1D8FRWkHpiTLabf4VnwdvR9UkCNO2FPT9u1Dx8Q8gk6EK9ueOVg4MpTVsqPQmwvuib9c1dcktfn379kU6s7BdI86174y8vLx6r3/88cdLrY4gCIIgCH+BNTsfha/XRS/NJjkcVH/zK4pAXzynjcGWnY+z1oi+b2fXMcomQeh6JVLzwyqchlp0D04m95Q/I+JvoeLDb5G761FFNHEdX/vLJtRxLVw5BCVJwpxyBLmXG9rh9SedGFZsRRUdhj0jD1NsDEZbXU6+DTkwLApi/SBteSqr9E3o3bUDuZUQ0KsN5p++Q9+3E9qOran6eineD95OxSc/UlBqIaZbLMqIUCo+mI+2U2twODEnHcZRWoFktWHLOo5/23B8tHWtin3DL/m2X3VirV5BEARBEOpa+y5ibJ9pRypV836l9Nn3UDYPwX1EH2QyGcZNe1EEBzTovtV1bYejyoDteAGH8Cc+CGQaNd4PTaT6+xXYC+t6AC1pGdiLy12zeAEs+4+ijg1HMltR+Hq5tjsNRkzbU9D17YQlLQPP9jGcrAazHfIq64I+Z1kFYYFqeppzWaSMY/UxiAhQo/DxxF5QgkytQrI7MG5LpiYtC4e3N26BXpjW7MTn8Ttxllai8PdB1SwYjzEDqPp8Eeakw+wLas2DHWmQC/B6JwI/QRAEQbjJ2fJOI3fXo/Dzdm07V8+d02yhds0OrOnZaOLjsGXnU/7fr6j437cYlm5A0za6btbrn8iUcmQqJelpRXT4bZEuubse7wcmUPXlEqyZx6ldtQ2vaaORkLE+B4oMYNq4F7mXO9rE1vXKMyxZjzqqOerIZjhrTYSGuHOyGjbnQe+wuvx+xq3JqMJC8IpryqPdlNRY4LtDYOuciHHTPpxOJ84qA+akw5TalITpbKiaBuH10ESKJD2pKUXsWprGquh+LJTFsk0XxdGfdrDNGsQ/NkJ5w1Xormsi8BMEQRCEm5xh+WbcRvZ1vZYsVkpf+AB7QUmjx9f+uglHWSU+T9yF55SReN8/Ht9n70bTLQGZWg2SRMUH8zFu2OM6x15SgWR3oH98GqErfsHNYXbtU/h54zFpGMVPvYXHXSORqVUcr4Rj5bBhVxGby91J33oMa3ybuvpJEoalG0GpQLJYQZJQRTajmVddS9+BYkgMAcnpxHokB9uJQnR9O2NxQKsAGNcSvjaEceq7NRSMexK5jyeWg8cwyjWEP38Xuh7tkclk/Lq9lFBvOXER7gzq6MOwKOjSJRh7UBBtD27l0c7ccETgJwiCIAg3MdvJImRqdb2u2ZqfN6Dv15mahWsbHO+oqKbqm1/w/b/7UAb719tnWrcLt2E98bhtMD7PTMecko6jrBKA2uWbcR/Rl4MWD9zHD6bqqyWuVkVJkjCu2o73w5Mx/LgayW4nvbRu7NwtJ3cyeFICeo2cH3N0/G+3k9QPl2GTKXCfOAzJasN6JAdNfCw+WjhYDN2a1LX2Zf26F6ebW113rY8nxyshzBtCPeBBzVEqouI4FJWI+Wge1hoT3qP7oPCvuw9HSkC3N4nsvBp+Du/Huhw4XgXmfYeZlzCeMT5lHF29n/wbbKVYEfgJgiAIwk2sdtkm3Ef1db22nSjAUVqB25AeKHy9sBzJce1zmi2UvvABnlNHowoNrFeOZLdj2X8U99H9AZDJZHjecQvV3y7HXl6Fs8qAMiyEQ0cqiHOWYcs9Selz7yNJErXLt6CKao77sJ7o+nWi6vNFHCuTiFCb6gLHgmJaDGzL/e3sTNq7AHvTJswP6cPcTVWcVnljyTnlSuNyqga6NIHiYiOlH35Pyv5izFYHlZ/8QNW7c4n6fi5lb31J1etzaN0hlKjMVAo0Pmya/Cj+y5dR/f0KHAYTaw+bCTiWTtc+4Tw02JveYVBggJUbTmJo0oz1nUdRunk/t2Ssv+Kf0eV0ybN6BUEQBEG4sdkLSkAmd7XcSZJEzY+r8Lr3NgDcx/Sn8uPvUcdFgN1B5cffI9Nr8bh1IJJU16p2hmnvQRS+niiD6hZssB47gbPWiMxNR+nMd1CGhVLy1lz8jT7oh4SjeuFeKj/6nsKp/4e2azu8H7wdAG2HVpgrjYStW4693At9n04YN+/D695xVHw4H89B3ekcH0sni5XitelkeQWQnWGl8qAcLw1E+tSNu0t671fadokGnY7v+03jjvZKUjJgSjzYF69C37MDiiZBaBevw7NdLCd79uew1kInHx+Ovv4tcVVOAuQmPMcNRCYDNxVkHy0hPNKftwbKWHBYgaxlLFVV5Vf3Q/uLROAnCIIgCDcpwy/1W/uM63ejad/SldJF7q5HkxCHaVsylgOZKAJ80bSLpdCkZFYyjI2D9r9N0jD8tAb3cYORHA6M63ZR9dUSZGo1MpUSa24+/v95jGRnAN520IXVneP14O0UTX8R6+FjSEYzMjcdAPmtOxKQtbEu4Ht4Eo7qWkpf+ABVbATGDXuoXbsTu0JJxtqDWNu2JmFQO9Qt4OUtoJTBpz+fZlheFn5T+iAZjNzXWckXqXUBobq4CEtJOZ4Th1GzaB2KQD+yvJrzj34q8gI7sX7xHvb3msqoN58kMDoAR2kl1So35qTAhNpDbIhtg1IBxypgpiWdiBlTruZH9peJrl5BEARBuAnZC0vB6UT5W5eto7wKS+qRemlUAHQDulD50feo20RjP1VMdqtO/HgY2gTCj4fA4QSHyYwlPRt7QTFlr3xKzY+rCHh7JiHfvokmPgZ1i2ZUffUzqQW4ZvNKVhs1c5cS9PW/UQb7U/r8+zgq6gbMHT5loZmPDJlOS9lLH2M9kIHbiL54jBuE96OT8X1mOps7jyQkMoAWx9JIqnHn0ySI8YMHEyWab1iJzceHdVsK2BHWGbsT7moLp6sl0j9biefkW5Bsdsy70zDq3LF37YReBa0SQoiwlRO0ezulPsEEvPcMJfNW8MUeK1PaQc2hXCI7hbMwHaJrTyMLDmB5XsOl5q5nIvATBEEQhJtQ7fIt9WbyVn+3HI9Jw5H9of9WkiQMP61F368zlpQjHPaLIrlUxdR2cKoanOWVbP5sM8X3v4rC0x1th9bItBr8356JKiwUZ60JyWon4ON/Ys49RfNv5+Iuq0vzYtqdhrZ7AqogP3xfuAdFkB+lL35E9fcr8PxyHgFh/nhOGYmzrJKgL1/D/ZbeqJoF4zQYyZ39C0Fz5+JeVYpHbDP6Ht1EiLWCKB9Y8GUyklZDq6Ht6O1vpEmEL0sz4L090L8ohZyASNaUe2PalYbM24NTNdCrZ91awFYHHGrelt5rv0X16HSe3+XG10F9mHhsLQHWKvKc7ni7K6mxQq/je1jXtAu9w67qx/aXicBPEARBEG4y1mMnKKpyoGpW1/xmTj6MMtDP9fqM2l83I3fXo713Ajkr9mFr14Zp8bA+F245uo6nCleztDYEk8WJ32uPULtyK553jXTNEDau3Yl+cHcUei0nn3+eJlUFlL/5JdbsfCz7DqPt3LbuQhJoE1thPXaCih/WIPf1xLwjlZrvV+I+uj9Kfx8c1Qaq5y+jau5StqjCiW7hifvYATirDNROuo1WKxcRTTmDSlLxktlYU6Bju3crfLQwoz109DQSlZXKqY7dWZ0lse6HVAwnSyns0IXg35Ym3pQHre1FWJxyDumb0rs5BHaO5XihmaKFG6lt1Zr1uaCxW9DVVNLcXoHzs/lX62O7LETgJwiCIAg3EafZwsl5q/kqZgRbjoPTZKF21Xbcxw5wHSNZrBh+3YSz1ohlUF/mL8zBfXB3Oh7ZTo0VSgtrCCgvIOSxiYzpF0hugZmKVbtwG9EHVfMQ13WsR3PRxMcCkFqpo9k9I1G2aILh5w1Ys/JwGk1Ufr6Igqn/h2H5FvT9OlGjdafp8Qx0vRNRx4Sj6dKW6h9XUfX5IrRd2pI+4nbaHN1L4D1jqPpiCYrQALZk2mhx91DKn/4v7Ya3QR/iwyDjUTqOaMvGXPhwDxgXraHdjIE83E3Ja01zUbtrOJpeilv/LhwpgdQCWHPEiuGHVWR06k8fdRFx/tCjGWT1v4Wj8zZQEByOUgaBe3dyMu0Eka+9gsHN+xp8ipdOTO4QBEEQhBuMo9qAs7oWVdOgiz635oeV7GnTjw4t9BwphaCVa4kY3R9USiyHj2HanoqzxoC2U1sKe/RlyX4Zk07soMnDt1H93Qo27yhkwLFk3EbULe+WsOVX9nt6sMyzHXfFRLquY9pWt5SaTCbDbAerE3x7tKUi+QDKID9kaiVF976CrkcC/v95FFVYKDK5nKU7TQxbNx/D6u1Y0zLQdmqD24g+eE4cRq1VonLmXHr1bUHNjytxVtVg69WNgNUbSB3Yl5iYUCw/rkQzYTpVKTuIDtEzJQQqDx9noU3ik4pw4rMhYdVWEgqPYnz6drJ1GrIr4Kd06LZ6EXkxCTQb3xfr5j2UjRkFQKhe4lRQBMrPvyXAV0P0zvVUtG5H6dQp+Iwefnk+1KtEtPgJgiAIwg1EcjqpmrOQmvnLsBzMuqhzLWkZ1DoUnGwazaES8CvKJ/tICUUp2VS8+SW2zOO43zoQ35nTSW3RkbU5Mh70OYFXE1/kHm7IRg1C/fNyAgwlqGPCAahdtI6Q8X2xJ7RjTgrYnWeudRRtYisA0oog/rcY1fPOEVT/tBqn2UrwnJfxfXoa6oimyORyJAmqZDqCnpyM7UQBuu4J6If2xJZ1gooPv+XwpJeIcZSh9NAjKRSUTpvGvvRqgv01eC5YSNQzE5GplTTdtJaiuHZ198vhoHrROjwnDmVmNwgylJC98RDHdMFoeiTirobUQoiyl9LvdBJNHp3AwD6hxNpLGNTMRj9VAeHffkUgRroqi1Hk5HG6fRcS2vjS7aHBtAy4LB/rVSMCP0EQBEG4gRiWrEfbpR0+T0/FuGkvpp37L+g8p8GIYfkWtnYYiq8O7izfTcy7b+EboGe5rhXyJ2fgPnYAygAfth2HnAp4sCPY123DbXgvANZX+RDpKEfVohkAlZ8vQuampeWUgViddTN25+4He2klcnc9MnXdjNd6s3nNFhTeHmCzu1bJOCO/Gpp6glyjRhXkh0ynxXGyEHVsOFWt21LWrj0xnzyDZkA3UrLNZMb34Pj+fHZZfIlS1GA9dAzPaWPRnzxOxelKTHsOYFi6keLWCYSF6pEhETJrFq2n9KVZbBCzl54ic99xwrwk7khawLHo9vRo5wWAOjackufex7hhD3k1CrJHjkcT5EuIqQyzJKfyluFITglLRt5f/kyvJhH4CYIgCMINwnIwC0d5NfpeicjUKrwfnoj1SA61a3ac99zq+b+iHDeMIpsaDmfgtX0bkS/cRdrwibTp3JyvUmVYHbD9BORUwnj5MSzJh5DrtSj8vDFY4XSBEd9QbyxpGRhWbcOwdAPeT9yFTAYjYyC3Elr6w7pFB9F2iQdwdfN6aOrqYdy4B0WgHyjkOMqr6tUxvQRaB1I35nB0f7wfmIDb8N7IPN1J/zWZrk+NwuqAL7cbCQ7QovHS01JdQ2j+MQ4FtaRi6SZsOfkobVb8VizHtHM/1fOXUVJYQ7ibDcOS9Uh2O6rgAHSJLUlc9SOWNdvwfPt/GE8UUzJoKJLDQdp329i3OYdUg54vI4ZRmFdGYFoy77S7k0q5Fm/JwkvZoXz7z19J3Z1/mT/lK0uM8RMEQRCEG4CjohrDr5vwnTnNtU2mUOB591gMi9dRs2A17hOG1EvHcoZp534UgX5sVTQj3s+ONG8jyiAP3AZ0ZpoDPt4HXULh1S11K19MCa+hevZmrJl5KJsEYlixlc2qCAatX4aECafZiuFQDiiUuA/qDkALH1ibDdF+kJebxfre3RhBXTdvwm/dvHXLs23F9+UHUXi6UTX3Z5AkZEoligBfKqr86B6vxbz/KL4j+wDgNJpJ//Bn1NMmotMpmZUEg8zZ+Ib7cfSLH3HLzCChbxxVeUcoTMuhPCSc6GfvZrE5hrhd3xL40T8oWVuM7J9vYagyYJ02kT1r0shU+NOx6jRuZcXoFRLFNSYql23hp+/LqYjvgOXWe4hc/TNBH32Mxc+PyrsmE7ZwC15eajROI48veIWqKisnut5N1wZ3/PolAj9BEARBuM5JTidVXy7Ga9poZBp1vX0ymQyP2wZTu3Yn1V8sxvPuscgUChxOUMjBUVaJaWsS7k9OIz0Zuh3ejm+EH7qOEchkMrRKuDsBXttat+qFpwaMyzej69cJ5DLULVtQtW43PhnLcTOXont2OunhCRS/+DExERGu7lyA0bGwanc5Y6I8yXEqSS6o6+a947esLcaNe5D7eqJp2QIA36en1b0/mx1TUQXWzaWYf16OMsiPqg++BY0Kq83Jro5DmdzKk0/3wYgY8P73SrKOlOKr0hBYko86eii+Hu54J0STk3qSr+xx6PUKbA9OBx1YfKyU+IZSkZ+NefFOfCw1tD6wleOdu1NbZqBZbjoE+nPbsXXI5DK0lRa0BhM/hzWn/MAJ+szoSeaKVYz75St87hqJLKEl9g++IvTr95C56a/GI3DZiMBPEARBEK5zZ8b1KZucfRav2+DumHYfoPLjH6ieNIFZB9V0CZXotuwX/KeMYlehku7uVVQezCHa24RMG4t570Ekm50TxQ4STjrQSHY8D9ZwInkfvj9vQNs9HlVYKHvveYq4lK14tPBjdQ4E7/iFuAAZWX2GEfOHOoR4gH96GlUd4xnXEt7dBTpVXTev5HRS9fki/F99uEHdZSoluZoAQiOcqGua4f3gRCQJ8l/4iGy5D50mNufzJDu31h5C88j3FBwrYsW0F5lRsZ2Ayd0xrt6OJiEO9z6JNN21n9CACpbO309KcQ6lHv448gpZ3q07zXq1wFtlQ/3NXFQyOXkJCfQ9upmSIzVU9xtNs6eH4zRbqP5hJaZtKRiyVXQ9ug3Z8RhKT1bSaXBXmr8whapPf0Tx4G3kL1hFiiKYOx/rcQU+9StDBH6CIAiCcB2zHMzCWVGNx22Dz3usrms75B56Ut/8lsefuo2qZVvYVyTHbeF+KvJKSDyRgt0zFMlYhbPGiEyj4li1kiMGDRO6KzlaqaR65VHyItui8pUR+uRtGG1wYpeFNlk5zInsx9gRMqJ947EbLSxLUzPsT3VoX5XNEnrzqAL8dHWrYQDUrtmBwscTdWxEo3VPL4GO29fhPn04NgcsXpSB0R5Kk4rTZM/6hXGGHGRl5WzoOhZDSCH/N84P2y9ueNw+lOqvf0HTLqYuxYvBiHnuYmQtB/JdVHfu/PVDWge5E3r0V/xefYii+1+jUO+Baua9xM5bjKqJO3sfep7m6UnYC0tRBvvjPX0su7/fQ8Dmn8gZczvO3SdoVZBDk8+fwLwjFUVoANndBlL02T+Iuve2v/gJX10i8BMEQRCE65SjohrDss34Pj31ws+JiSK7j574r3/AOzuf/tPHssUSxCFFDWZ3D6LDPfDpGIomPpa0Qth5Eu4dDko5xFdUs2e9kkBzFavb34m2EpILIDhpN2mxXXi0swzdbz27Sr0GjRJMNlzb7KeLcW/iR5CngiMldaldKs1gKq3GuHZnvSXi/qzqQDa+TX0we/syb30FfTcuJc3iSYFDRZ+8JIrDYtg38j680g8x5JYoSD2Euk0UtSu3YS8opmr1DkwmO4d7DadF6k6G6zVEnzRTfMcddG6pwD53IUX3vIy1vIbqhG6kBSdye0IyZVYFLQa2paZtc/I+XUTkP6dSe+wUu1Yf5dhzbzE5Qc6R+9+kidKEXKuhZsd+do+eSuWhUnxaxdLE68aaJ3tj1VYQBEEQbhKucX1TRzUY13cuu/IhoZ0fMpWSgHdm4tY7kSx9MA9Ubie7+0BKU7L5WRbN7pOwPR/uaV8X9AHULttMm8RQTjWPoXu0li9TYN4+K61KMhl3RxtXgHdGjB9klv3+2rT7ALpuCQyPhmVZYJdgSCSkzVqJws8bXbeERutcZZYI37UJ08C+/PhNGv3ff5EjMj/8ik+i89CS6fDE0bcHPVu60azgGOFdIjFuS6bslU8pXbCW8qYtOBjdGadOx+BBYUSO60Hw0M4UJHZHdjwf28fzcJaUI/d0x1pQhn9eFj2+eheOZpPSohM9w2QMSPRmV2wPKj7+no0fr2dj34mMaqPEtGY75UHNKOzYjcLH32Rlh5GoVAqGZqxnc/thZEc1/p6uVyLwEwRBEISrrHbVNqo+X4Rh+RbMyYexny5Gsttd+w0rtlLy1Nuo42PPOa7vzyQJcnZmEP7dN7iN6ofCz5vsCog5tBtlt/aEWcvo2SOExKZyTtfAvR1Apag711FRjb2kHFtGLsNmdOP7g3CkDPqd3OtagePPWvrDkdIz15awZeShig3HXQ0d7afpYsylbVkWJ61a7BotCh9P17kZpfBVal26l+wVyWjlEgdmzqLj2oUcad4Ov+PH8PBQo1LJKXn0EVrsXM/aDCst3S3YK2qoXL+PbXc8QlZAFJ5PTKGLuoyoD57E/MsGHGWVlL06C73NSLXaHUdxGZqObbCeKCRz3GS23vMMzdydmGwSLXasw/TfOUhffE9IRhoFizZzvNTOXcXbKF+whgM784h2s5KpCiA1qDWJuUn0UxWQXqFECg6kT9ilPQPXiujqFQRBEISryFFVg/VoLp7Tx2IvKMFRUILxaB6O4jIkhwNHSQWO4nJULZpiTUrHKFeg69cJmfzcbTXOmlqOzVlBtKTHb+ZU5Nq6xHmbDtQwtDSTo8Om027JMnTjOhPtC9G+9c+vXbYZZUggVl8/vj6sZEzSYhY370ur0+n83PseRpfVtfD9UagHFBjq/t+eX4iyaRCS0Yxh4x5iZ/+EPNCXCpudVqNGkmZtxplRikYbrEgxMHT/cna9dgTfI4cwtktEahJMWbWKnA49iVQraTF1CP32bGFheQ1bmnZieNKv5OhDsD7wAfq4VgztFYxC8kYX5UfRvw9i8PFEHROGPb8Qz8m3EO3vg/n7bZhiYtEaTdTK1FTcdhtxvkpYUk5yyx7EdmqO2lmB7UQhEXu3stk7lo4nUihX1VIZ2JRS7yC6ObLIP16E2wtTaXZgA6de+5xF/R7lsY7QSDx8XRMtfoIgCIJwFRlX70A/tCcKbw80LVug798FzztuwefJKXjfPwG5hxt+rzyE7wv34PPsdCRJouKtr7AcyWm0PEmSMO1IpfLjH9gT0412D49AplFjLyylqBaabFhD0KTBHDjtILS2GFWz4AZlOMqrsJdUcPrwSea6JTKkKAmFSs5DW2YTNrQDNuTMP1DXSvdHMhm4qaHaLGFYugHb8QIqZy3AtDUZj2ljkKmUyLRqPL9fQK5XE6osdef9eAhG7lmEX34O6ugwPrjtn2wY/zBuOPhu+EO0Mp6kR5yekIUL0HZuS9z2NczTJFC24yBuOVmE1RbS+vExOL74AVvuKaq+XII8yA/fF+7B+/4JeNw+DPvpYqwLV9Kp6CC7vGOpPFlBblwHSmolEn6eh9zfh58TxpAb0wG3gd1w1hhZev+rfHv78yg//y8FuHGkdRfaFh1BOpxJuNKI7X9zqT2YTcGxYnqkbyZMXkup8XI/IVeWCPwEQRAE4Spx1pqwHT+NOq7hzFZJkqj66mc87xyJKiwUmVKJTKHAbWBXvB+djGXfISo/+QF7SYXrHHtJBZUfzMdRWon8sWkYQprhqwNrejbl/5lDyufriAtW4GjeFLfMDNwT4xqtV9WvW0ixeJER24lHWpnwPJjG7q634OUw0kTn4Mku0NwL3tsNB4r+8H4MRtof3MaJV7/AvO8QnvfciiY+Fn3fTthPFGA/VYxuWE/kWg0Dk1ewZfEBkk7YiVjzC7rkFA57hFFYbkPeoS2t1y7hvRZjeKarjI6nD+DML6DQw5+dH64kP7OYB1d/SJPsw7gtX4VUWkHlR9+j8PPC97kZqJoGouuWgLOmtu5etowmzS+GqhobvoO60L8qnVO7MzgdFsOQ795DhhODbyDxh7dTlpRBxSc/sLnDMA7KA2gXCP896s22uL64JyURWnEa+QuPUHDvfRyZdDdJhXI8np5ByPEMDj7/GWk/7bq8D8kVJrp6BUEQBOEqMW7cg35Al0bHyxnX7kQdE4YqPLTBPrm7Hs8po7CfKqJm3q8ow0ORu+mwHMrC844RKEMCWJEFPeuW0CXn1z2k6OPwW7KKPWNHk7IZ2uzbz9djR8M+0CqhmScMjoQTx6vYn1pGlLuDlneNpmr2AiqHDSN+fyqeU0dh2XcYz3Yx3NvBhzaB8J/t8EBLEx32rMZZXUvTxPZs8x3M0Np0sNqwHMxEEx+LaVsyPo/fiXnPAfT9O6OXy3EczsL49Tzi7x5A4RaQb9lJsFrJkO/e53SXXjw9OoB9P2whfncGpUFNaZq6Fr/QpjSVO1HKoUrvRYC/F+6Th+MzYxwATrMFp9GMJj6OmqMn2O7TmqOlEoPd5QRoqrFt34eqbQw6ixG/zZvxnd4HKT2do2EdCTh0kpaLPiNDq0G1IZfuPqG0a6Fn50nwUEP77H0ENvFirX8MicGwcmUuxuBQqhLaY4xtR9YHPxLm59vg87qeicBPEARBEK4CyWrDcjALtxF9GuyzHT+N9UgO3o/fec4ylE2C8H5qCpaUdJy1JnyenoZMJsMp1eXBGx4FhZV2SnYdIU4poZ/Qh+4+Btz2rSQxzEFIP3fsTrA4YOGa06xNLkW1Yh3tldVodd6UPvsu9qJyTuaaiQrzQt9rJJqWkVR/sxSfp6fRramMiFMZrHpzCydHD+KOqRH4yEH+2gq0t7emcs5C5HodktGMrk9H9IO6UfnR9wR/9yY1C9cSsG8n2aMnEPDhp+Tpg1DEt2CNPooxyUtJTP2Vk0d30nHjeo6360JUbjryGbdTWmCljfkklqO5OLwCsMU1xXGyGIfRwrxMDSMq0lHGt2arogk169KIGxNC913LsDdvSpJHC5STBxL5+aeUhTQn4O1/8F25ljuO5bHJpw3905aw4e7n2OHfmtDCXLrkp3C0pIaspq3JDG+Lh4eKE35e1MxdwoKhIwnZtJllQ8dDEpSaFGi6TWCwuZABV+qhuQJEV68gCIIg/IHTZMGanX/ZyzVtS0HXq0OD1j6n2UL1t8vrllq7gJkCMpkMbWJr9L07uo4/VAytA+oCuhWL0mmiMFEc0JTA6mK044eiOnkKdU0NSE5UClBlZeO+Zh0lJytp5e9EF+qP93MzkHt7EvjZy+wbO43wh8YiUypRBvujiY+jdtlmqr5cgj4zk8mzprHAGsGdS8Fql/AtO03RnCVIZite08egjg1HHdEUR0EJmk6tKX/zC7KL7VSNGYV930HMWj2+JaepkWuIzj/CjlffI6dTH5oe2INnh2halmRRFhrOoexaWvraUQT4Itns+BaeIAtvPMYNIm/+WqossOzHg8yVt6FpTCC98/fSZOXP1EZHc/iHbcT5SYQu+I6Tx0rYOOkxesXpuCu4lI3mAFrPn8PphC7QoQ0x/jK0rVow9NXbqLjrDvy8VIzfvYCuJ5IZMTmegK4tGTD/PWJivBnRwR2rA8K9oGsLFffe2uyyPytXkgj8BEEQBOEPjKu3U/3lEiRJumxlSk4npt1p6LrFN9hX890K3Ef3R+HpfsnlbzsBPZrBD+uL6PLtx2S7hdDu1anousWT9fNu/DwU6Pt2omrWAhwV1WTN34D7AxPRVVWgDgvFY+IwTKu24T6iD7kmDS186s9WVQT5UfXFYlQxYXjeNRKFTk2/cNAp4dl3DxOYcZAKq5yA/z6FMiQA27ETKMObYFi6AUdlDdXHi/nJpxOlzSLpW5SCscZMdmgs1oIyAr2V3P75S8SXHkXTNBB7fiEKDx0dPnua7vf0xbF+Jx53jkAZ4IvK34vgPdtJNnqw+bARVfpR2uhqabZvO75Pv4D1aA7lp6s5/tmvtJ4xBN8QTzyLC7Dp3Oi3dxk/f7AJx9a9SFYb7uUlJA+8jbFxda2lE1rB8iyQadQY27fHc1h3jnYbSPnq3ZgPZ+PrraGFoYCK/cfIKKu738Mif0+Hc6MQgZ8gCIIg/MZptmA9moumU2usB7MuW7nmvYfQdmyNTFk3wmpZJpSbwLQrDZmbDk2bqEsuu9QIemMNB9//Gb/16zFLSnrOepQmwXp0fTtRujGZkCAd+n6d0XZPoOi+V9nRth99A014VxVjMNqRu+lxVtagiY9l32noFArGrUnYK2qo+upnrPuPEjj7Jcw79yM5HGSUQUs/if/5pDJ65RckK0PZMPERV8oZa/ZJajftoeqLJZR4BrFGHcOde35g2NqvQaNGXm3A01KDekgvWgYr8Zk+BnvOSaxHj6NqGozHuMHUfL0Uxc69aBLiqJqzEGVkMyoTO2NU6ch5bwEBhhJGzPkXzcrzaZ23ny2THqaqW3fyjCo6/PAKfj3bYvh5I8ej48kbNoYh706hS49mHP9+A7G//Mjx2ARkOcdJLZDw19WtMKKS141/rLVB7/LDFOv9WayIpV8LOdbTJQSP78/JtSncUbCFwhqJ+IYTpK97YoyfIAiCIPzGuKFu8oUqLoLqr35G0y7mL5cpSRKmTXvxfvIuoG4Zs8N789HsLqRD9l58/3lfg3MsaRnIdBrUMeHnLNtptnDgq23E5OezLHogwz2OEqaJQ+9Rl8PPKcmo8vJHV1td97qmlqIOnel4YAvmanea+ihJ79AXtx9W4vXg7UgSnKqBUJWFkvnLcFb9gN+rD6Np2QIATZd2VH2+mBMWP9pU5eJo2Ywodytbpt/P8nQlg5pbaZGVinHdTuwKFWlj7iCp9xj6hUPw5x9i9/Jmf44dYv1YH9uP+N3prJ44nvCNa2haZkIT3xpVE3+cZgsyNy3KJkGYwyOomfkmx4aMpmlpAWEdmuO1dBW24GCsFTUoDEYi35nJga+SOVwMwye0R+2wUjj9/5CsNk6cqGbQPW2QKZUExgRzsrqao8NuZXVQHx6zHiLvjXU83DuCH/YlEhXtQ0IQpBdLaPKOo6qoAIscz+AKVj/yHMd/3oBSrcRN7iB25Q/Y44bjOF2Cpm30X35OrhYR+AmCIAgCv02+SMvA7YV7kMlkyPQ67AUlKEMC/lK51oNZqOLCXQmVjx430j9lBca8QmS9o6j88FsA5G46lKFBKJsGUrt8K5LDgdc945B7uYNE3bIc/PZfScJyMIva7amk+PVgd+IAXuojI+Cpj/F67REkCWqscLpaIshZizI0EOOOVIy7D7IlfjR37F+MaeNe/HslsvfISTSd26Lw9uBYObTwAdP6XRRaVAT271GXluVEAdYjOUhWG9asExQnDGfoP6ZR8f483Pp15h89QffflWzcUkDFoDjUrdpj3X+U2n59wQHBSgsV1XZ2HS6haWUpywfchU+vThQ4O/Kc9hCV1bkYqyqoah9PyvDbKNH5oKqpJmjPdkL3rMFfp6F1ymaUft5g0lEZEUNRSBinJt1JSPJuYvvfT2TnjpSOHE3uvB/x//RHFP5eHH3hJZo+9x/C4gKRJIm0Fz7H4e3N0f6jiZC7YWgRSn47iYrcbNruXkvbdDPHIhMY5KnEkFvA5jH30Dz7AKXpqyiRNGxLHEt0WQ4hyespbdKEwin/wPef9/+l5+NqE4GfIAiCIADGzfvQ9/l9woTb4O4Y1+3Cc8qov1Ru7bqdeN8/wfX61Pw1tAzz4WTPrhzrkkCXJnXbnQYj9lNFWI/lY805iaZNJCXPvIu+XydkCgWcWblDJgMZqJqHsmLYDFanK5gzEJpbSym22lBHNiO1EBalQ4eqbDontsB9WCcK7niOgj4D6bNrGV7ThlJx/DSqsGCab8+kdOo03IB9p6FHqJ3y538mO64nJ/eX0Gr9LoKfmIjX/ROQ6zTknDAQ/cEPVH9XgHHdLrS9O2JcsoYBt3bnvYJh7Dl0lDtzSnDGtmaHwZthUXBq3mq+CBrFjF3voPFy40hkBxKSDxC5cw3pBUdQWUwUd+hMafPWlOh86u7/qXwCC48T1b4Jdn8FlqR03Ib1wjrjDg4ftNHiuWcItp+mpKSEtZMeY/SuhYR89THVGflk3HknVouD5fuMPNXEHclu5+SPmzAVllPcpSd7qt14vT/8kgG3t5axSB7FAm0UQweZ2DgnmT6rZ5HVpgut5BUMNR3i4w53ELvsV8JHj6dc34L84NH0/uo9alq1wc/h+EvPx9UmAj9BEAThpifZ7Zj3HcL3+RmubarwUGoWleGsNSF30114WU4nVV8sQdkkEGVIAMpAP+TuegAs2fkYTXbcVTW0HxnPnBRcgZ/cXY86NgJHWRVe08eg79sJ054D2LJO4HnniPrXkODnozBvH7w3CFr4QtnrP+M2sh+SBBtyoWsTKFuWQsgLA6n8+AeccgUVaTkkdgrBsHQD+lF9qZm3jLBXnmH3KRnNveq6eT2XreSE1pPoEA0tXruf71aepveWjcT16ADA/lp34gcnUv7kKyiaBoPFitQqlv2apmhUMvrK8tBmZ/N177u4062AykO1HDjppN2wIOI22sn18SLQU8ktafuo0EiYvLwJ8PQmwHQaTc1h1CdqUMe1oCpzA0dlaip9g3DLPI6iaSBOg5HdGUZqZG4EPnUXJV8sJLZbHJ00eSweMo3EhV9h7NCVssxC9E4rD9Uswbd/B6oXrmPfURPxpjK+ajscUzV8vR+i/eD9vdAvrG6Sxvv7tcQV5WNuH0/JiHG0PpSCblB3UnLi6GI7Sdb2JCYMaoJz4TI83n2e7YcNhFPy1x6+q0xM7hAEQRBueqbtqeh6tK9rWfsDfe+OmLYmXVRZxnW7UEU0QRkaQPnbX2HLL8S0+wBOi5UT36xB36M9yvBQdCoZfjrIr/79XEmSMG1PAbkMe3E5ui7t6uq3+0C9a+zIh535dQmbE0Lqgk3z7jQ8JgxmfxG09IdoNzPexzIoevA1jBv3UNipO6FaO54ThuD79DRUAb4ofDxp4qcmvxqyyqFVZS6W/Ucok7TETB+IuxqmjwzlqGdz9v64F4DcSvDJyUSmVOA+uDtZdz/A8lQjfZZ+Qb+iFPrm7ORUk0g6J6/n8Dcb8f3gE2T9u/Ng6UZkOg3bPGJ5wJpEWakR36KTlMe24bjTHZ/3nsfnuRnI3PUUP/YG1pR0WsiqycypRT+oO9p2sWgHdaN2/q/4aCRaj0okbGJ/Dnfsz0/xY+iQm0zGk8/gqbDTojKfNXc+hfsT01C3ieLo91uoVWpJ8o7huElNlC+EeUPXIDs9m0iszoaRMRC0cQ0VngGke7fAP7YJ+skj+VbRhk6hsCqqP3Fp24n6dQHBj0/miNMHc0gIxoAba4aHCPwEQRCEm4qjtAL7qd/XHZMcDkw7UtH1bN/gWE1iK8wpR5Cczgsru6IaS1oG+oFdUQb6ou/XBZ/H7sBZU0vJE29RfKqKuKN70HRsDUBf72r2rUqnZvE6Kt6bR/mbX2BJPYJksVHz/QpqflqN25j+mLYmYT9d7LrOskww2eGeukY4jJv2ogpvAgoFO/cW0e3QJmzjH6SZtYwK70D8lnzE1kGTif6/KdT+uhHj5n3Urt+N96OTMW7cTYwvrD5oouWudZgqjdT07YvWs66VUq2A0Q/3wpqazrfrSgi1V2Jatwtnq2i2Vrrj+HkNEx/uStRL0wmxVWHckUrL5noq/UPw9tOzsvftjMtcS/WsH9kf2ZH0nkNpsngByoJCrCHBhNSWUNmhE6t+yaDizS+xHs7G447hNFk5m9NPPYVdrmBfejUZ3QaydkcJR+T+VO9I46FVMhaEDSAxzoM79/5A56dHMW1sGNmd+2Isq+F+SxKv71Mz73+72a8KJTR1D2tbDkQlr2vZbFl9At1Hc2i9djFPdZHIW52M0ehA6a6jNLoVuZVQZqobXvlQJ1BmZaH30lHh1BAdouJoGfSqzuDg9uxLexCvERH4CYIgCDcNSZKo/uZXquctw3LoGADm3QfQdmrjSrXyRzKFAk18LJb9Ry+o/JrvV+AxcRgymQzDim2439ILubseXbd4lM1DSB11B+r9BzEsXEv523NxX70WS0kVtInD+5FJuA3tiee0MbgN6obPE3ehigmn8sNv0bSLoWruUpxmC8mn6wKXGLWBoBNZGFZspfw/c5AkOPL6d8Qf24dt/U5Ox8YTNrwz2wdPZlOJG32bSzgLSzBu2oc19xS+z05H3SoSe+4pOgfYUCxaRkDfeCoKa2g1vlu99yWXy+j+3Fi8l/6CaslyJJmcbI+mdPvHbSR2bUrN/+YjWawEVBdRFNicoPxMuu38FZ+sI0y5vyMbw7tTYZQwVZm4q3QHpTYlTfMzCB2ciNvBg+grSpB7ebBlwEQkixWvKaNJKYC1ORChtXC6BvYFtiZnTw41A/rRMz+JJ2OriZFXUjPrR7zuHouySRAbc8GWd4rSuLZsX5PFw4tfR7ptBMObGAkPUFEbFU2gXmKmYSt9srcR9dI0Tng3I+7HubQvPsLezsPQZx+jWZdo9pyClVl1axQ79+6nz6kklg2+m/TugzDMX4a7Co4v34fFw+uvP5hXkRjjJwiCINw0LCnpKCOa4D6iD5Wf/YSjqgbTliR8Z0476zn6vp2o+mwh2g6tkCQJ45odmHbux/eZ6cg93H4vOy0DuY8nquYhOMoqkSwWlE2CAKj5aQ2VgwYRZpChH9QNx4TRnK6pa7UzFsKccogGPJYd5nSfARiS61bhkKQ4ZP2iCNqzE/+MGixPfsPbXe5lgimNAQdSsXeOQqbTIPf1IuC9Z5ifpGRG2TaqAjrjvu0Qmib+NGkbxsF16XSt3IEzIZag2S9ROWsByGR1q4B0i6fs65+xe3piWLaV9P4jmezXsF1I4eeNKbQJbZI3caLUSsWt/QlyA7onoPD3puixN5AOnsDbbEE7eSzvKEfQx5pNx7lfk5CaR35cAl8PfojPcr4ks8JEU5MB7dffs23C3Yx5agCf7oOElQs50m8IPVQqvjsI5qP5ZFbIMLeMQ5LJUQX68ob3AWxtgjD836v0iY9jXvdx5Bv82LUJTpeYecrLjI+xBEeYjuW6RMJPHEWjU3Ekz4inwkHvDd/SckALvMdOZvERGR37R2J9ezmZrXoQoLSg0Sj4PlNFhQkia08zLHMDW0r1lE2ciH+1gvXOKDQ1ebivWIn/wRQSHht4mZ/SK0u0+AmCIAg3Bclqo3b1Dtxv6Y1MrcL74UmY1u9CstpAdfZ2ELm7HrmPJ+bDx6h8fz7I5XjNuJXKz35CsttdZRuWb8H91rogoHbVdtyG9QLAmp2PZHeQpAuj7enDaBJb80UqFBnAYoe2QVBUC619HbSQVTO8sw9T2sGjneDxLvBYdyW3P9mb3h/eh666nJmr32F80U6iX7izbpbrgUx0XeNJK1MSX5WNbf0OyheuJbCyEFV0OL5fzMWrtADvJ6bgNrw3Ch9PdD3aY1y3i/9n766j4zqvxe9/h0EzYmZmtGTLzMwYJ3bixGFsuG3aNE3TtIE2nDScOHYMcUwxs2W2JdliZmaNZjQahvcP9Tr1bfve2/6Kt/NZSyvWOWfOHGnOWdp5nmfvDSBNiER/8CxhU1Poqu7EY/FU/lTnOI3BgXdvB+4OExK7jUq/WK4eLKflze3ovzuDQSRDIACTfwDfjF2Lu5+KdXdl8LuoJXgVXydq7hjEDhtVfeCwOZB4qZDERRC4ZhaVvbCeKrRCBRWqCHZXQVmPkzuM1/HtaWWFpAXhe18wyd6Gfn8esvR4hMvnk+eTjsHDhzPNoJLC6+JL+Ax2IvH3RrBsLpa1K/A+cYwCtygcuhHGHvwaw5yZhC2fhMkmoLPbgN/ePch//TS1A5Bx7Ft6oxLJDYZgNcw/uYXLeS0MGqG4SoupoYOuqzUMDBrJ2PEZdcog8gt7//iX9S/MNeLn4uLi4vIfYeTweZSzxyOQSUc3CAQgkSBLiGR4ywHUty++0XnivxMH+zL40kf4v/nsjVE85Yxx6L7aj/vdK9DvO4XbgskI5TJs3f3Y+zVIY8NxOp3odx3H48G1tFXC7NYmiiZNY4wSZkV/f/6OYbDVNROQFonqzyQQH+hWo33mR2S+/WssFV2MHL+IODyIkeOXkGWn0PHy52T21yAen4b1egsKTxG669VcW3ArCeFKrmtg7OiyPRRTs9H85ktkuenovj5I89gpjN2/mSNR41gZIcJqh7euwhO5o2v8AKr3FeCTFYO9pxanxM4D5bvoCYoif/witK39aDoqmG+tpCspk67Nh7n78flU9IpIuJ4HCVGU9Ap4tvFdKqKyGNe5H893n0GdncR4bwdbD3cSfnoPSVnjaN+zg5ImI3d7QKS2kn4vL74OnUnxah8iEyFq3+dU+8dzfDiOtWXfMfHCF+zq82ZGoBnD+bMEfPgzbL2DHM9r59ZZQkRz0rBuPojWKSFyZhoD9bXoPzxHbbuJyRIRZycv5nyJJx1j5vGDyp8y7BVFmQ2yDc14tjfREZ+Bf2cjD1ZcIX7NZHb16YkydOA/YwwRFhmhE/99ijfD/8OI37lz51iyZAnBwcEIBAL27dt3036n08kLL7xAUFAQCoWC2bNnU1f3P7e/+eCDD4iMjEQul5Obm0t+fv5fe4kuLi4uLi4A2AeGsNS3IB+XdmObuaQGWVI06lVzkESHMvTB9tHRvz/gMJnRfr4H+4AOeXYKQg/1jX3y7BREfl4Mbz2IrXvgxlTw8NZDCGRSHCNGjHkFyLKS6EBFlF2D0MuD8x0ipkbcfH3TIqDuVAXysSl/8voLOsBkA4NNQGB8IEJPNfq9p+h75rcIxCLa77ob1dhk3JdNR5t3DX3OWNzXLqDE7sPsNCXTI+Fs6/c1oAUCAerVcxl88XcMxCejTgjH2dxBefIEVNLR5BG7A5qHfv970BvQXi4j3FeMtkeH6O611J4sxz5jEqsnqMkpPk3O5HDkUgHbpt9Nl1cwuje/wtCjYWrhEc7OvZ2C7Dn4yh0kHN2NXCrku15Pur84gPn9LaRt/wxzWirRORFc8krC6OuPd38npj4tXosmc97ky6QwAQ1DAq4Igqgv7eLxCUJ8713Bt9mrWKLqoaO2F4GPF8M7jlB1tISE0ovYf/Uu3QcuIbWakQucDF2rxhoViXLDcvKWbKRoxQZSxgQjEMCcSDvSlFiitB2MXC3Db+c3mN29ED1+DwM/fpaY957GeSGfKb4mPs5ah3xCBoH3LKV+/1+W9f3P9lcHfiMjI2RkZPDBBx/8yf2vv/467777Lh999BFXr17Fzc2NefPmYTKZ/uw5v/nmG5566il+/vOfc/36dTIyMpg3bx69vf9ew6guLi4uLv9ahnceRX3L/BvFmZ1OJ4bjl1DOnQiAYvIYlDNy0by9BYfeAIxO0Wre+Ap5bjru6xehnDMew5mbByOUi6ehP3gWWfroqI/pahmi8ECsLZ1oP/0W45USlLPHU9gJmV3l1EWmkh34/Sjaf/GSO5H19TLs7f9H196mg0vtMDMSaO2g26GgXuhNq08E0sRoRCGBtL27k8Syy5iKq+kPjiS0pQp9RSPOq9dJ9HEiE0OaPxR2/sGJhULsg1rKnb7ENJWhd4hIT/bmSB30GWBFItQPjh46uOskPdNmYb9aglaooLbXTnB2DPnnmhk8U0SZfwLxB76hJWM82UFgS4xjolrL2J88iaOnH2NoOJKGBoxJSbjpBpF7uTGt6BgnZ63jpCMUrUDBqauDfPPhVRgaQmXQoZyQQWVoCoVXOpgcPlpGZsAAjTHpzNGUIMLJex8Uk3XsWzQrluOeGcfF1fcj+sHdXMhegH9bPdZeDYdu+QH501bQ8+D9xHfXUuAWxb5WBWODoHdk9LPoNUBkTyPGmFj2j19NwIU8IlqqyLl3FisypGj0NtzOXcRt6QxC/WXMvLCb4ugcosfHMv7e6f+Pd+c/1l8d+C1YsICXX36ZFStW/NE+p9PJ22+/zfPPP8+yZctIT09n8+bNdHZ2/tHI4B968803ue+++9i4cSPJycl89NFHKJVKvvjii7/2Ml1cXFxc/sNZqpsQuimRhAd9v62yAUlk8E2FmWWpsajXzkfzztcMf3uMkSPn8Xr8dmSpsaP7MxMxl9Xh/INODaYL11HftghTfjmWulYMp64gjQ1HMSUbS1MHksgQBEIhjUPg0VjPWUUsU/7baB+AramDsLQQzreOBqZ2B7Rq4Xg9/PQ0GK3w0nkIyz+HRDtEzI/WQ0QorcVtNK6/i+DhHmxl1VjrWqCkktBfPMilubcRre/C8vuM5OkRcK7Rhrm6Cf2+0+i+2ofvK0/g8/kX+I6JoTkxm4U9+eyogNtSIdJzdMTP2t5DR7eBuDAlfTUd1KROIL2/FqEQcs7v5+yOQrLjFPR36xHOmoyfScPUY9vwfHQdQj8vJLHhxB/Yyexv3qf7ciVqnQafH96F/5JJLDm1manlp5j+s1WcXbQB67rVRA20EDk5gW0Rs+gUutPkUGO5VoHWDNMiQRwZiqa0kUM/3EqIoR/fH96FwdsPYUsbJwjnvgNgGTbS1zPCTyc/wUCHloGoeLKW52BFyGPGK+wrtXD9SAWeO3fx5odlJPlA2YkKNLFJpAaJ8XOMoBJYkcSF091nJHv/VmQZCahXzsbz4bUk2Xop3nEZm8WGm/RPLIj8F/Z3Se5oamqiu7ub2bO/z3Tx8PAgNzeXy5cv/8nXWCwWrl27dtNrhEIhs2fP/rOvcXFxcXFx+f/jdDjQ7z11I+niv4wc/T754g9JIoLxfGDN6H8fue1Gxw0YnR6VZydjKqwARqc/jReKUC2agseDtzDwwnso50zA2tiOQClHlhyDra2bjsZBwpzDtNkUZIeJ/2i0D8BUWE7U9BRqBmDQCO9chSvtcKoZnsyFH06ERDTkDtcTvSgHr8EeQi+exjvYg+p3dqMa1iDLTcc6pMca4I/RIaJL6UvwQ6sYeOE9dN8eY+TtL5l8dCvVl5uQJkbh/dx9NBmlKN3lDF2uQDN3Lq2FTcR52HA4Ybh7iHadk7KPj7I5aj4dx6/R2jCEU6HAdDaf2gYt8kPHaOszUbXrCgMWMbV4kX38W6oWraFZ48DW2k33b17BmhCHt9xJ4KVziBRSGDZij4uhX+WL6ZF72WyOZbJ8APdNW9BMn8GSWcE8e/odPCV2jqTOR33pEm+MG6a028mc/mKaitq46pXEuh/PJiVIzDhdHWnTYrl3jAC10Mqd5XuwR0UgdEJoRz15yjicwyN0h8aif/NLlhz4COuglmtZM1E2NZFw/giJ9n6WT/djsLqdwN4W9OvXYjh8js7XN+O9cgbysakAWEpqCbhjIQ1hSZT85DMMeQV/uxv2H+DvktzR3d0NQEBAwE3bAwICbuz77/r7+7Hb7X/yNdXVf75+ktlsxmw23/hep9P92WNdXFxcXP6zGE7nIx+XdlMAZ6ltRhzge1Mplj8k8vVC5Ov1J/cppmQz9LvtSGPD0e87jfqWeQjEYhwDWqTpCRjPFuKw2xHYHXg+tg6n0UzNG/tIH5fIpcBkNoT/6eu0NrajWjOPUDO8fB6emQAXWmF5ImQFQVUf5Fw8gMjPG/m0HAZ//gECsRiLXcKEkjwG0zIZ6rMjSclE/shdXN99ioXmE5i95SCXYa1vw/vZu8hAzPsFkJ4ANgfUf3EM+f0bsT3/Co0zPKl3T8G7poJf6eNYufM9YuOyMYSE4O6nxvfrSyjNWvzO7KVjzVqilk6g82ktE5qv0eYWgJtVh/jgcU7fsY7BDivtT71BVGAAl4Y9SBnoRWIYwSqRYZg3naP+OYi/uIR7gDsndNEENVYRe/U0eZJIxpw8xrFrCox6NTUSd/xNQ+RPWoro8e0kBSo45RXC4emP8pykCEdPOMIgP4wXi5CtWcjZOri19ghHOxXUpCxgna0cyUAdysIdlNbIuCiKJjlRTILMwLHIdKYkq5BlL0V59AiSllbseiMBO3fg8PGmLGc2yekCLly28kCO+43PynixiMi7ViC9JKO2/AyR3RqUf/pj/Zf0b1/O5ZVXXsHDw+PGV1hY2D/7klxcXFxc/gU49AbMBeUoZoy9sc3pdKLffwa3JdP+qnMKlXIEUind9/0cc3kdkpgwnA4HwzuP4vnwWhTTsjEcPMu2gMl0OFWIA3yoi87AtvsoMZMTkPyJ0T5bVx+iAB/K+wS0aSFINVrexeGECaGjx1yp0BJcfBXvZ+7CWtOMY3gEYaAPjivX8JiaRVyUClVfF1/FLULmo6Zx3nJiXroPryfuwPeXj2Hv12C8VIxcDKl+kN8BH22qot6iRGIyUTdmMtEXTvLwY9ncob/GhJKTZD29kgkXD1CBHzHvvoGgrQNZfARBExKYOy0E8ZfbMa+/BaFUTLDYiBIr6SOt5Hz2FtOu7kfZ30uVQY5izwEEhaUM2iT0BoZzrcuJT38nlXNWsNc3l9l7P+aW018gCPBFmBxP+ksbCVA6OT1jLXPumcTE+osskLQRpeugXerNnuiZqGJDaMkrY/e9H/PzD+uo6rDwYb074soahEYDqsE+MnuqiTl/nCqfWAZvv4PvJq+je8I0LCNmdqUtYezJb7FZbSyOhyhfESWTF1H17Ad41NcgmTqWIbGSTTUKyizuyH8/TGbXDoNAgFasJOHEfvJip7A5fO5fdS/9s/xdAr/AwNG+dT09PTdt7+npubHvv/P19UUkEv1FrwF47rnn0Gq1N77a2tr+H6/excXFxeX/guHdJ1CtnHVTiRZLaS2SiGBEf5Cd+5cSSEQIVUqU08cy9P42Rg6fRzYmGZGHGmt9KwZfPxLMPeyogIZBcKYnotUYyRX0/MnzmQrKKQtO5Uo7PDEe/NzgUN3oaB/AiAUiPv8Y7w1LEKrd0O8/g61vEN2+PHRrVqFIi8WmGcbdpCcqwY9XLkDsHwxYShOjEHp7YL5ehbmqkemRcHlfCRM/fBX39Gic+cUU5MxDbTeQXzJIhUaEo6KGnb+7RB8KLKcuERDli4fQSoDCgVgowNIzyME5dyEuKMYj3Af/llo6JF4kPrsW8bI5jB+owss4hKdJT3LeAVq9QqiJSidAYuPSuke5VDbEhM3v8ODmFxB5qDjyxK+oXrKWsPFxbNndyJDKm8nZvmjEbsyuPcNMdw3al3+GvnMAZ3s3WdZO/ARGhLcuZfmFbWyTZ3D6VBOp2z/H7VoR/ZHxfDP+Vmpjs7iQMgOHRMyIZbQfcWByKFNC7EgnZuN37AhnW6DuWhtec3Op6LThGDEiCg+hXQcX2kYDJY0R2nVwblcxO1VZ5H16Hq8wLzLmpxHy79W44+8T+EVFRREYGMipU6dubNPpdFy9epUJEyb8yddIpVKys7Nveo3D4eDUqVN/9jUAMpkMd3f3m75cXFxcXP6zWVu7cBqMSBOibmxzOp2MHD6P28Kpf/V5nTYbjiE9/h//HNWymSgmZjL00U4kCZFYGloxni+iYM39JBzZzS1l+3nzshP3mmrEd6zA8O1RHCbzTedzOOH6uSb6QqK4N2s0wzQrEERXr2P+chfar76j6qefEFJXjqWuha47nmP4m6OYi2vpSkgj8xcbUMzMxVJSQ9f625l4+lumBFrpGoGdFaPTuQKBAMXYVGRjkhnZdxrr4TOM3fExHy16iqTTB7Gdu0pG6VkS0JD4xiuElRXi29eBv66Xbx/+FQKHndAzx3CzmxGolPi++iQlqVOI/24HEdcv4iZ2Yk+MwxgZSa/Cm9aELOTRIYhFAs5nz8cgkDCm+jLO2EgGIuO4f6wQ9YIpxPY34hfqyYSlaTSOSDjbDF1aO5Mq86gZO4OZLZfhk68Jvn8FKi8l8xPFFE5fwW3XdyHbc4DN0+/FUlqLs3eApcXf8UHlJ3htXEapXwI7M1fQZFbw67i1dHsHo5bC1HBYmwwD8cmIyqpYe2c6U2PEpF07jVHtQYqli8D+Ng5lLaHgUjuSIydxEzvRW+CRw/BinhNdYQ3eUhseQ710TZxORT/I/8Qo7r+yvzrw0+v1FBcXU1xcDIwmdBQXF9Pa2opAIOCJJ57g5ZdfZv/+/ZSVlbFhwwaCg4NZvnz5jXPMmjWL999//8b3Tz31FJ9++ilfffUVVVVVPPTQQ4yMjLBx48a/+gd0cXFxcfnP4nQ6R8u3rF1w03bTlVKk6fE3ZfL+pUyFlciykxEp5KPfF1Tg++vHGd6yH80bX2ENDQGFDPWc8SiuF5FzZg+9lysYszQT1YpZDG858P25bPDFGS1qXxUrUsUIBKOBpX7TXvxG+hFOGw+A6Hge6uRIpDFhOK02nEYzQxvvoOe5H6GQidB9tBOPh29Fe7GUptgMFhbs5fY0J1Ge8O5V6DeMlqsxXikGiZj+H72JJiSKHF0D0ekhmM12ltySQvwrD6KcM4EGn0hCBHoCqkpY+vnLpNVexWYw47liBuqVs7EKxfT+/AOCK68Tsn4Oylm5aI0OxrywnrpPDyM6dQ653YQ1OZHZ86JRCWzUBSeQfnQ3mRunMS4ElhQfoNUzFP/3nqOvdZDEHV8SOdBMXGU+iphQYt58naHtR/AO98FS38LI8Uv8PM/JLdky5vjrcfP3oEIdgVU3giY0igCJFenSmXR2jVCdNYUEPwFeUjvTCg6yRN7J3VlQOTCarbxyQRiaqlacgM+6+cguXmFArEL/4XYGVT6Il87h+swV3DlZzZiD2/AUmPFRwmxJJypPBf4l19CsWE6gWoDNDlF/ejnov6y/OrmjsLCQGTNm3Pj+qaeeAuDOO+9k06ZN/PCHP2RkZIT777+foaEhJk+ezNGjR5HL5Tde09DQQH9//43v165dS19fHy+88ALd3d1kZmZy9OjRP0r4cHFxcXFx+XNMV0qRxkUi8v5+Ds5ps2E4fRXvZzeOjtoZTIjcVX/5uS8W4fHgLQCYK+oRusmRpcRgOHoBSWI0nadLGRcQgmrBZAb9AxBd15HacJ38rrVMTo7BXFqLKb8MQ2IS284NMbfxHAH+YvSHzmFr78Fw/BJGeRCT4/048bkGv752yMgk/rYpjJwrxNCtoXn2YrbFrSLWDsbiauz9GmRrF1HmriHr7Hd4zYhGv/8MY6bmEJHuzo68QaY2X8Zz1wkE3h5IH7odY42RcVMjyL9SQ/uzP2Tc8f10+XhTXDZAQlslUoEDUaQ/guF+rA4rejd3AvpHM4ev3/4SSLyJmJGG+51LaX/lKyyBwfh1tyBatYDAF1/DorJRPmUFY69cpMEKnWERpOtbUHR2Yan3wq2ynN4Ft1BU3M870ukkrBzLw9d2I925j1ZlAPkLbqdsxngeTzUg1OoofnEzMz/+NePjFKhuX4j7nmpus1XSq/CmMGwMk4qOY7w2jHxoCMWtc+jVwYSSk8RNjGFk/yHe8r2TRs1oSzmdRUBkpCevH9DgH+ZFnDCAkOYqoruqyU+bQVqKN8JeOCHLpSTRn8cLt3Bs5nqmV13AaulB+sMHSXAT3yi5o/w364EmcDr/q473/w06nQ4PDw+0Wq1r2tfFxcXlP4x9UIv242/xeuYuBH/Qf9dw6ioIBaNt1rYewt47gNNiReTvjTQpBmly9E2BoNPpvFHs+b/Y+jTod5/A88FbcFptDL7+BV5PbkC/9xSSqBBE4UGc/OUe0rsqCPzyZU72umH/yWuMX5TMteIBYp5cQ3SEiu5ff0Fxq4W0CZFIrl5DvX4RDpMZ49kChm9ZRYEkjFXefRiuVXJx5zXkmgFqpi8muCSfYF03QW88za5+X9L8we+Z54j98TrKgtPYVQVP+zYhybuAVaunP78Gd5EdkZc7bVY5wzFxpGqbqH7qORq+PUfmns20qwOYeeBXdLdp6Hj2baLlZsQ2CwKVEntnH7XuYYj6BxAZDfja9SgiAtntm8uGFxci3nMQxYxxnN1XTqzCjKdVjzM7g12bi5l5eS/XFq4j/vh3dIbHM5icSvq5Q0Q9tgrD6atI4yIYsorY2+VOomgIkcVMcsVlRLMmUWJ2p19jRuhwsjBbhR4p+8vMLC4/itTPAwti2rxDCeltYUfAZPwsWqoWrGbq1+9yIWM26rkTabtax3xtKVN+sopvdtejKC6la9lKPORwuB7G91WQrdRSF5PJ4vy96Bs7cSst4bv7f44jIY4f5MIzx2HIBOvd23D77iBRDaWEffkS4iC/G/dEfsdov+U/VZvxb+lvGdv822f1uri4uLi4wGiwpv1y72jG7h/EbE6LFeOVEhRTs7Frh7F39+P15Aa8fng3bgun4jSZGd56kMHXPmfwrc30PvU6mjc2/dH5jWcLUE4fzRAeOXQO5azxmEtqwOlEMTGT2ivNuE3MQDE5G+1HO7mc30NIsBtiiYgJP15B9W930lffy2WCSBBq8c2IQujnhbW5A+PZQlSr5lLSYmG8vo6+Fz6g/2fvI+jspfThZ4gKVxHS1Uh3RAJHtb6k+MG4xqsMOmVY0tM42QRx3hCQGYUiN52BilaqVm/gSuZsOj2DiYzyJKw0n4oeJ4XVw7R1mzkWMRl37QBFj7zF2aMNRBp6oKMLcVggIj8vrMvmMZIQj9pqQKhWYjFYqO8wkbZqHMraamQZCQxfKWdQ5kHgmCiEPh4M7ThKqE2DecRC8pXjOAL8KF95B+k1+Vh9vOn/cj+WqkbM1Y2UBScTzDCq1mbCGsqRpsRSnD2LWJUVtUqCUmDltD2EV8OXkzYrCWVyFAHv/5QLP/oVUXOzMEVEML39KnWJOaRHKtCqvJg4XEdTi45xZXnkjVuChwymzo2lbESBpLiUSE+YEALzFscirq7Du7aSQ3p/1O0tdI6bTOCFPBaGGMnvgJWJIBNDgzoUSV8fwxLlH5UASvIFgd3297up/w5cgZ+Li4uLy/8JIwfPIvL3Rvv5HkwFFd9vP34J5ezxCEQiRg6cxW3xaHKHQCBAHOCDcvpY1OsXI/LzBqsdxfSxmAorMBWU3TiH027HUtuCJCESW58Ga0MbomA/jBevo75tdC1hy7UmEp0DYLFgGDYz/os38TcP4dDoUIgh8/n1HHtpD/E9tQT/4BZGzuRja+7AWt+KYlIm1q5+jI0d2PYco7+ghk6ZN0H6Xu62l5K+72vCJ8QxZXkqdb02Ctod1L/6NR4/fZAd5VDeC6uTRq9VmhBJP0omVZ9j0YJIHA/dSeGgjI7nf4pPWiQ+v32LrMtH8BeZuPCT19AgI/PAVrSdGtqVfrS26ujxD6ezrB11WTkemFC3t9L2yGPUhiWTWHEF3eb9SJKiaazsJc3SiaW6CeO5awwX1+JbVoRBqcbTX01R2lTsCXF4OM1Ev/9D+lDgFIsw17QQtGUzUW1V7J+6DvfsBBwOJ16ffUlv2hgCn97A4MMPUjQkZ+72d0itvIzXT+9n4NAldCIF/VNn8NLERxnOzOKBDAfkXaJxwkwGfYOZuutjsp9YSlKIlE+uw55qaJg8h9TafCa5aYj0gklxMnIEPaS2leFfXUqZyJ9vU5bSOG0+ATu/obTTRrDayZ3950n/3W8RyyTsWfQQQ+9tu9HSD0BaXUNGSd7f/+b+G3IFfi4uLi4u/1ZM+WWj9dT+gKWhDePlEmw9A3j/6B7MxVUAOEaMWMrqkI9Lwz40jL1n4OZMX6sN/cGzaD/aiWLKGLx/uBH10hn4vPw4A7/8GPvAEADm0lpkGQkIBAL0u46hXDqd4S0H8Lh/DQKRiP6mfnxLCrEfOoPbillcCMxEnhyDraoBw8UiBl/7nDKdFD9/Jb1ST6zNnZiulOB+z0rc71yGEyHXrrQjKKlEl1+Fm7871sgIfBKDGfjpO5hrWzBdKsJ4rYppJ7ezcesvUY1NoVAQyKFaEArA8/dL6PuPXWUgJxe39FgU41KJPfgt2U8s5fSwDy9G3oJTpaItM5fzidNxenowbXY0Ps11RL/+KLG50XhnxiDd8i1OgxG/iiKEJiO6yBh2Bk9l0obJnDP74nBzY+Dlj7FdL8f9xElGatto7TSg8QlCJhPBhGzevfc3bE9ZiqG6mRGlmq4vDyHXDaG3i7l2zxM4RGJkKbHcduh3tPZbKXn8x0je/QVd5e2E7NhKZfUQJ+RxyGQiBFIJ1uomSqq19PUa2FwK6aFSmlfeSvWBAgyFFcROT8ZdBsGM8EGzN3Oj4bFx4C6DadFi9mQtp+GDfaT62Cn6/BTFGhkDoVEE97YQmBLGZWEw9ogwPlKOY/r53fR9dYCIjlokQghkhJ7QGC7nLrwR/DmMZob2n6V17B93gPlX9m+2JNHFxcXF5T+Z025n5Mh5OHEZ9S3zkMZF4DCZ0fx2E5LYMLx+cDsCiRjHsGH02ENncVs8DYFAwMiBvBuFm51O52iSxckrKGeMw+vZjTfW9DltNsQ+HigXTmXglx/j87MHMJ6/jvtdyzDXNOEwmjHsz0N963wcGh3DXx+kptWEb2YcCn8lTomUrrpeZgSJcVtyP8ZTVxjKu07ImQ14R/nTFptG2yubUA5rGfziO9qjkiiPyaZhXA6PBB1BHpbI8KVSxDIJyuRobF5qEIsRB/rQbBARWXcdcUYk4RuXcE+ojZONYnDCaxdhjK8Vv115xE/Kxtahpf+rQxSEjKFJF8hDOfDiWTGBchtXBUpyLa14tFjRfbwV94fWMvzVfvze+TG+3u50lVfAxfMQ4ItTo6ExOI5Qu5a0Zem4HX+Tqx6xJFw/j3jKWIYHesjLXcOC6EJqawaRHCzD/ek7OX5JRqgaZhaeQujuhvDQSSwpSXwy+0Gm7PuCa+nT0Mt9GOfZQZ00mMNlUkK9xWjTFnFmsI+44wd42qbl4/Hr2BPhQ0zBFcIbe5kSeIZPoxdxWwqEeoj4Th2Eb18N5882MaV/kKAHVzK95Bz7Q2aRWHeNzPAgqpTBzBzny7aWFJY+9SqKxdOZ+vpd9G/8KZc9AxkcM50gGwSrQTw2DuNH++nQifFWmwiMDaXRHsF9Nfs4Pm454vELSXxjG1qlB3v9Z5Kul5HyT3wm/lKuwM/FxcXF5d+GpboJWWYSytnj0X6+B2tzJyOHzyEOC8DryQ03ijVL4yMxXavE2taNas087Bod9t5BpPGRWNu6Gd5+GGlCFN7PbgSBAGtNM+aqBqyN7eB0IvJQ4xjUYq5pomP1kwiEAsxltVhqm3GbPwVpRgIjB84idFehWDmHmgNtpNhrkSZEcu39w8TrjNjbG7G5yRmJjadRryL+/FHcFk4hqqSWar9QKpZsRJSawCRtDUs7Cii7cAqJcABL7yA6gYyg3hbkD83FeCYfU345bkum07OziInP341cJsZ0tZTGr/uY32InO1pGX7svgZfOcVkehnDGSpK+285gfieRv1zKQj843QSBXY0EdTeR2a/F696VKF/8JcPh/nilJ+DU6DAcPIuxqol2uR+iID3+gSrqQ+LwN2nxyD+CYNWteAaoid+3j5Pjl+OZM4HQ2hLW9l3CY+Usah76HQKBgMcrg5kfA119RqIunkJgtaCcmEH8608TWyhA5rQxPVFOBXLGz1yMU6NkyfEdnJq2lntzJRR1+9EZegcTYywMXZARIICWzAmUeMdy+5e/YNHtXhT3TKCo3UbwcA+ZS9OR7dzCByt+zI8SpSRe2syaiAFKd13Du8edAaOatjGTCKwrJyZCRciEEARWM+b2HsTjMzngiGJeLJQ0Gbi36Bvi71+I6Jn3GPGMZ+jENV67722W9l9H+c1u9s5bSZo6nZRD35D71Biyq85C4l/XCeafwRX4ubi4uLj82zDll+E2fzJCNwUej95Gz90/w947SPCrT97UoUM2JonBVz/D66k7R0f7DubhtmwGhnOFmK6WoZiag62jB827XyMQCJHEhCFLjkG1eDowui7QoTcgCfbDbjSDw4EsJxVZahySmDAwW3C/cxkiL3eKuiFJcw4HwzgQMFTWSPbzt6IezsA6dzr787Tc0rkd4ZyJ9G4+zJBIgdTbm4zWEmqjQwhdNZWS7ikEfPQptupBHCMGzFYZYbOysNa3IfLxRJoQSd8rn9P/0I9xyw7nVBM0xMRzQQo5c8DPvYP+V7YhtFqISvAh5vAmdE4B8YluePk66bhSi27rJaYrRJTkzkM4MEDugZ20xUTQv3o5Iz/5BGV4ANLThfRL3PHKisU3Mo2u7y4hePVhNHIvol78GfqjIWhPXKHfM4C6xLEsH2gjN90LS7UGh0iEW10dLQHRLFV0MXGomeK9V3AOaPD6wTokMeHYFUr8Tn7HyVnrWTFYTOZICZL77mfPCSHEybgz/2tEWesQCmTMjxNwulPGY+PgfAtYHOAf5Ydpw1rUTY3E9dYjslmxzs6m7FI5KfGexMgMfHJdxvo584l6ZzNuubm8q57I9JYrLH37OS4tvJ33osfw0OYtGDoHaB83g+sZMxgwQmPDEPPydrI9ZwFTX9uMZcxkovwl+Nj03B09QnX4WExGERuLvqGlRUf1vY+w5OM38Xz/p/+kp+Gv41rj5+Li4uLyb8HpdGLvGUAc5IfT4UD7/nYEUgm+v36coTc3Y+sZ+P5gqQRbVz+SmDDsGh22ngEMZ/IxXizCabZg7x1AlpGA1xN34PX0naiWz0SaEIVAIsZcWoPTYMTz4Vvxe/cnOHV6fH/9OCPfnUa9fjGqxdNQrZiFyGu0rMalNoi0DGC32mh4ZQs1d96Psq0VwZhUPi8RstS9lxafCL665UcMh0cS6Cdn7PuPMeE39zJTU8HpX3xL+Y4L+Dn0iP29MUrd0L79Kt7P3oV+13HMFfUI3JTU/einRJ47yVsXbHjIICcY1FJI6q+n+tMjVM1cwjd+EyiZfwulazYycNedmFu76f/lR5zPa0O4cS3+/gr6B0xkV1ygR+dAlBhD4Feb8FcJaBR7cy0oHU+7AT9/N646AsFk4viAGoe3F+fGLaL6yfeo14n4ZPVzjCk4Tv7pOs6dbOCIfw6bnt2LxQ74+xLwzrtc7RIS11uLJT4Oh96I29yJFFxsRWKzMm9BNBecQbiF+PDTPAHBKoidFEd59kz0721BMGJAJYVJYbAwFvI7QSGGWG9YcfdYpgZZUcybhPj8ZUQlFXiPiaNr3Xruq9mPze7EGRhAUekAuzvUeLQ1Iygq59erf4GuYxDl3gPsECVhKyhh94IHscTGEKjpYvLpnSz4xSp+0nUAq0qFz5MbGPuzWwl/ZCUL9eWoZDD1ljG0tesRjIxg6BygdNwcrPWt/5wH4q/kGvFzcXFxcfm3YK1tRhIfidNqQ/PRN9jqWvB96VHEgb6IQwLQfbYbt4VTkGUkMLLvFMrpOdi7+hjecxJLZSNCdxVu8yainD3hphp//525tBa3BVMQyKTY6uuRj0lCv/MYXk/cgfbjnXg+etuNXr8DBlBYDIiNBiwNbVwKzWZ9lgT7KR2ftvsjspqp2XkWn2fu5KlwGDrsAXJ/dJv3I02JJWz9bISVnXQ88BKdHiqC5mbTaXUHdzVHz7aTPGKC5k6ML/6Iz04J2ThlEhsr9uI9eTVvX7QxteQM45xNuK2bRPKeY2yWRaE6cAjT8ACmIS1dbn60esfhDAom8KNvCS/JI0HsiXb2bOwnz2O3NNIblkTVQ0+yJkWAPyPUrvkxDV8cReobyJB/COM2v4d+wgRC9H3YxBLcU6JYritmnLQbzcVSbE4Bnnl5eLY24hSJESkiMQ8OMnG4GJ2ml4bZ64mYEo5TKKR/+3E6lq1m2oE9VBvVHFUEk3rpGMIV8xkTBF8bIqicupAZR7+mTnYrWYnufFAI4R7QqIHnJoPITY44wIfIuhJqN96G4ev9tC+7lSOtEiaqYkmvvUqhMo240FgyivPoEHviePgOvPplLLplCZ2lbUieeJ7A3ET8yq6jCPRixvU8eu9Zh+XAYewdvbQ+/AJdAwIivASEpidjev9rQhdOpL+mE7NYxcwN2SS+s4nfbHydoxY57/yjHoK/AVfg5+Li4uLyb8GUX4ZiRi66TfsQ2B2o185HHOgLgNjPC6+n70S3eT+ma5UgFKGcOZ7Bt7/GeOoKno/cinrtgv+xXduNUcXfn9d4rhD12gX0/+QdvJ7diCwlhqH3t+P52DpE7irOtcIkczPmslr6ZJ5op01FfOgELyffjsII9zaeIPiB6ciiZJgKypHnpuM0mhHHhSMAhj76hv6zFXhnxYLeQMOJEgYj4rFcqmDM7i/IT5+OuraGinIHs6KEzM9OQPdNI6V3vMSUIT1Of19s01MoOlqO5XIj8ZFChDoYm+KBXdtFnXaYhLxD+I2Lp3pYitYp5dSs9Xj1d+KZMxd7Vy/fjt3AXdLRxJa+Lj29br4Mj40hsaaQ7x56hfVHfgft+TQ0D2MPCaYnLJ6JxSdQ5KaibeyiXhGCyE1Bl9yHVC8rZQtWM7bmIo5hA5JAH2zd/SgmL6d8x0X6I2OZd3EXB0InYctNosUCU0pPMHTmHGWzp7I6Cd7WBbPs7pV4vLudk/OWMS9BhqDNSEmzAWmpEYPZhF0goP13e2j1CCf8kXVMabjEcPQ0pBETue/4Jg5dG8JzVg5zpwbyeYsXk+JFWKWjI4ZhgVbOefpyLHoKS09spiYygyML72B9/im0JYVI338Jc5WQicHw/BlwImFGs4Izl4cYc/wwZ2asxXbmDIKJS5mRt5PWVbcBkr/nrf835ZrqdXFxcXH5l+d0OrG19yIK8sVS34pALkU+KeumYwRSCR73rkKaEIliUhaatzZjOHIe/49+hsfdK/9XPXrtPQOIfh/0OfQGnGbraPB36wJMF4sQhwTgfucydJ/uxu6A+kHwOnoUp8XKmcAcbms4wYcZtxLoJ+fnkW142Q3I0uNxOp2jGcRzJ46uNTx0DmtDG7aeAYQNzciqapBW1yJTyZHV1ROdd4ShkAjcrCZkq+ZhunAdVXcHPS9+SMemI9j6NAwFhNKr8qOwQovHiBbtwgVMeP0e4u6YRdvZChxOMIsk2G5ZxoHxa/Aw6znywPPEWXpR+6jwjfBh2cYc/P2UdOvh0SPw+gfl1Fvc0Ajd+N0dL5HdVozI252rgemE3DKL4WEzmYe/wefJOyi2eOIQimh75DHkuiEGvQKR5mbSGpdB6I/uHB1ldVMwNGkKw706uo7k49dcS/3c5TSFJvHsROgagcIxswm3apAXXkMpgbkxsKPPl2NT1hJ1/QJN357Hv66c+bIuTlWascqVnOqWYzNZMIaFMxgei7a0gTTFMCM2AW63LyXl7CGaQhLoVPgS5i2itAfSA8DmgO5ffY7u7g285zOLkoeepW/YQU7ZWaSnzvHuwmf5rHK0Lo7DCdFeo4WxZ61MZeP1b4iYk8n9ccPMDjRw789mY5k8gScsV/72N/zfkWvEz8XFxcXlX561oQ1xVAjDXx/E1t2PzwsP/VFLtf9i6+hl+M2vUK2Zh2PEiCwp5n/9PubSGmTp8QAYL1xHHB6EQ6dHvWYug699gWJiJpLQAISeaspKe8lt78Ba14JGqCRc38mF1bfRb1Tx5jgb2t8ew+ORW4HR/sGyzASEchlOJxi8fbFdKGL4fBHaVSvwKbqI6rH1lHx5iZSJYZi7+sEhIbi9DGF7Cb6V7UgCvDidOIGa9T9mSrSIXfUSXlrpjZ/SQc8zb+AYHkH4+Ta8B4eomruAcznTWVR8AJ/+ToyHv6XJqSZ3/1eYLCCKCCEn2Jve9KmElcHRBkjzh8WNZxicORub3kjGiT0U5oxHE5hFdkcJ5b7jUHp74v3EEpo/2Y97xyDeK2cScuEMeqUHodYhdrvNJMMf9IfPYa1pRBwZQmBuPHVP/5phdRD6uzdwoVvKi9PgcB2MDYJTzQJE45dyd/E3bN+rQJSZTFUfjAg9qZi/mrJe+OV08JLDxv3QrDcw9/yXnFh6D6u8erl6pYKGzDlMOHqChGUr2V5kY2xiFL7X8vllXy4R3mJqBmBcMBQdLCV+2MkeVQaOEThFGON9vVibv53zd/wARZgfUZ4wMwq8FYwm7vjC78748FRXJcMv3kfPa5soWbOalmZImxaHR1jc/+Pd/Y/lCvxcXFxcXP7l6Q+exd49gNBdidezG//s6J25oh79obP4vvUjjKevolo1G3NxNYrJY/5X72OpbMTjwVtG6/xdrwIBeD26DoFYjGLyGAx5hbjNHo9i+li63j1IVrgYq91Ol16I6NnVHBn04pMFdgz7z6CYMgaRuwqHzU73kas0rN9IbQHIa2rwbLWTMmyleckaEnUtSCKDOT+gxFNowa27E8WIAYdGhzQ5Gr3WiMPPG5GfFyF9raTueZfmARuPaTuwbPegTaPDIJAQlSHAatOBUMS4oTrifrYT9wA1+uomJDJvgpRubAufR1nmDOYFjeB+bAfuzZsYqxUxN9Cd1noZwZpOcky19EdEM3ShgytFRUj8ZQwjIvLkYdoe+wEdCX4MfXqU8DlZ7PMfy/zuY5z38iaoa4CRqCh6jlxh4KOPkU7MQiAQEPmLF7ni9CNv0T1kIiDcA7aVw9QI6B0ZTdrQmgW8l7iapL3bmJkjxxQSTUEHDJtBLIQoL6gbBLvBRM7lr2n1DsO2eC6duz5BZq5lm8dUHM1XGC7oxKeqFMPaVWSdy8N9TwFpz6ziW0kYGb52FHu2cWbhBpI8bYwZbIQzJUTaBlHOGc8cYy2XDRlcNQlYnwYFnaNTw9e7nMwrO4p1TAYRhRe4HhpHm80dSTc8M+FveZf/Y7imel1cXFxc/mU5Roxov9yH6WIx3s/dg1AmQxYf+SePtQ8No993GmlCJCI3BQ6tHtXCqZiKq/9X72UbGMJS20L/ezuo+8Fb2Lr7kWcm3ujPqpichelqKU6LFY1FiF/pddwnpmOsaCR/wW1cO1PPr+u/Rvf2ZqqHpVwKzuKjQtj6eTE1EWn4e0q4M1rPwpObGSceoF3oQUPOFMSNzYw89zRHvTOJivVGvWoOna+/giY2AXFIAGW33ov56w+p849l7LafUffrV4lJD8ErORxFbhqKKWMoXX0XAZFeIJEgy0rEUVSO1FNFt0VC4bRlyLo7aZu1gNNpc4nxF7G2ZC9Hxq/i7bEb8XpsHXXZ0wjvb0ITFIGlqR3Dh9swKlXMkPXQ0T1Cf0M3QakhhMutVPxyMwErp3NOEcOShjxaVAGkRLnRL1GTWp/PuDdfoiZ3Ft/Nu5euXhOOilo+mfYAwWoBTUOjgdyGdGjVwt4aCHADJxDsKWbGS2sRHztDf3Unt6TAtS7IDoIj9XCyysyrzds5KozGMG4sPkoB5pQkomQm1gZqaJo2j9TLR1kma2PK9HDi7phJWHst5z86w7AF/C/kESw1k9hcwvyTW5gtbMMjMZyoACk+P7oHt4RwHuo4RnH36PXkd46O+N06mM+42fGc80tHt+cUAYsmUPv7EcQ/M+j8L80V+Lm4uLi4/MtxOp0YLxaheXsL4rBAVCtmIlQqcOJEIP3jhfROhwPt57tRTB+LND4S/f4zqJZMR6hS4jSZcVptf/J97Do9hlNX0bzxFZq3tyCJCaVy6a0UC/wYaOjGUt+KvV8DgEAkQjkzl56dp6h4/F3cp2TSc/8vGBK7YWvpQh3gwQthq3k6fiNf+E2jUy/gjhQbC/uvMfe2LGJKLqJ99BcoZk+gLyyG4nFzGLv7S2R3LGN7o4I5xYdR+qqRjUvF9uRL9HsFcnH53ZS6R5PXISZ3RRbDZ6/T2zyAR0sjgVt+jVClxCpVYI8IQ6aU4//+T5CEB2ISyWnzDMHt/rWEtNcjFQsZqWnBTwlvcZa98lT0Xn5sXg52oZiDrRLCHTrcMuMoih9Pj9IH38wYvCQOIhimziOMjsJGus8Uc2DmBn4bvBC3/AIMSnfMLT14a7px72on5eP3MN9+C+2+EYQXX6E7axx1XpFkF53C6YRBE3jIYHfVaBmcCSHw2myYHQ3VAxDoI2Nkw20knT2IWtvPC1NBJoITNTZuubID59xp+Gm6Oeufic5gJydRRf+QhbG95TQ43An2lyORSxAIBFy60Mn57LlEFZ5l4QsPIf3t+7TJfElzM9I7AiG6bpYHD5P27GjbPbfZ47EDizsv8UH+6PWtDtSiqqrAe+EExptbaRZ6kBQoZsgEw5a/4wPwd+QK/FxcXFxc/qXYuvrQvPkV9oEhvH90N47BIeS5aVjqW5HGhP/J1+j3nBztx9vRgyQyZDSjNDoUAGlSDJaqxhvHOozmG8GebvN+BG4KPB65FVlSNOrVc6loNjJJV0PFhgepHDcb3Vf7Gd57EqfFiiEhga6XP0PR2UZDcRvDWiOvzX+WqtnLSJiVwjvL5WxdCe/kauk3wNDpQoQqJZq3NuMYGkY2JYfvomfSUdHBkFWEqK+PXVGzSWsvJb6tElGQHz33/4LrkxcjfewuGofFXOmAjRngPSOLioNFzM3bgXJ6DiKRCHGgHz0NvYwpOIliUiZDb3yFfvdJLt/3DCnPrcPt3Y/Ii5qA7ZF7SD9/iBeHjmFp66ImIYckXzjZCI1DcP/AOcxePvS0DjF0soC+iVOQXbxMbZsBsVpJckMR3Tonv41ZyYhYzsIEEbN+spxvlJn4azqxn7yIl66XpsgUBnVW1t2Ris+CCbQN2nl+wfPMvXaA7lYNKX5wezrMiYZBIzw6bnTUrFs/2mtYa4KLg0oG1t6CadMeyut0JHnamH32G4wTxnPM4E+8r4Demi5WnfwCS2c/ypYmDPtOkRtop1MvwK7Rcb3VSkdeCfPrz+IMCSJIaMTs5k6nVUahdwLXFq/D+/H11E+YTY9QdePeKB8/j6WCJnq/3M/86wdw/9VvsPUOonl7C749rRg9ffjm8jBTI+B619/1Mfi7cQV+Li4uLi7/Mpx2O9pPd+G+YRmqpTNAJMLa0I4kOgxLRT3S1Ng/eo25rA6HRodySjaW+jZG8vJRLZtxY798TBKmoioA7INaNG9tRqCQ4fHIrXg9ug7F+HSEchnWlk66C+tI3bcVWWQwq9Yk0qgI4MLiDZQIA8m743WurHyeEaWa7oAojA4xZrGczEmR7FwNyxPBQw6Gc4Vonn+HtR1n6Xj1K0SJ0Xg9fSfDNe1sS1hEduUFsm6dwPi9X3B4zl3EddchOXsRZWszlvI6Sl96laqMyRissKsKsgJhZyU0D4uxyRUohwZQTMvBabNhrm5i2GBHVV2BpaEdxazxiGeMRydzQ7xjLyNzZxPdWMYvVTMozJ5DzI5NFPaKiFJZeTIXzraAc3iEGFMP7g4jcg8Ffvp+Qo/u55wyDtHMiaT01RKudiC0mlmhL2HYPJrwIHVYmbHnY4ZMYBsewSqS0jBtPukvb0Tb3Efl1rN8HrcUh1qNcd0aFu58hwczbbjL4DeX4Z6s0WBv2AwDRrgzHfbVjF6TZ5AnFyavYPrpb0g9tIOE+Rn8eiiBpW0XMHT0Maa3kp71t/NuwFy0D96HuV/HhoPvEtLfSn7qNKo2nSSxpgDRjIkMPfskQqeT6/c8zsJNj+PjqyDnwFa2v7CfrSd6+fa8BuOF62g/243Hp1/SaxSQ1FbOcEAQlsgIfF59AsfQMIpJWYTbh4h+902WXNmFJP8aA20anE7nP+z5+FtwBX4uLi4uLv8yLLUtyNITEPt5AWBr70EcGoBAIMDa1IEkKuSm4+0aHfoDebhvWIq1owdzaQ2qeZORRATfOEYc5Ie9sw9rWzdDH36Dx90rUEzMRCiX3TjGmF+G6VollQYloSFu2O68heONAvpG4PS1IVpPFlMZmoqH04QsMpiEe+YS3FFPu2cIjziKb/zxHzlxGWt9G0JfLwT7jiJ5fCPHwybSsu8yx/yyWJEiwru7lfpL9TjUakxKN3w/+5zA6jIUY5Lo+MlPONsp5W5zEf3fneN2bT6rDKUMF9fy6vtlRDk0WFs6EUWGoNt8AN3pfGxjx+CxYSlCDxXG01cpy5zOlMKjqNfMZUfkLAK7mzHLFMhkInRhkXiF+TJ+3yYuXu5kagTMbjpPYWQOrbX9qIqLcVeKsEjk+Ha1oCluoGP2QnQCObbsLFLe/Q2P1e3ljaM6Tn50DpEIwssKMNhF8PZL+Mwdx9Yf7uGd0yNUrLqDiYlK7kiDraEzsIqkmL/+jkN1oJSMJncAnGqCWVGjCRxmGzQOwsFaeHiRD95KsDZ1MBgVR/zV01h2HuLijFtQrlvCR5UKqvthnyIZZZgfcj8PfMK9aTtSSPyBnXh7SCmbs5Kkc4cZDI4gRTqMWiWhInoMM964G0NODvcPXUR1Oo8vSgScTZnF8UV3079uHZWPPk1G+QWKdAr054pRTMpCtXgaR1c9TGx2OAXZs5HJRNRuy8Nw0lXOxcXFxcXF5a9iLqxAMS3nxvem/DLkuWk4TGYEYhECkejGPqfdjvbz3XjctQyn3cHgLz9CtWousrQ/UV5DImbodzvwfu5eRO7fT+3ZdXqGtx/B2tqJx0O30l3lwOEbQ4fGl3gPG8nXzxNX18rViQuJLThLwuR4dmavIvbod8iNVvyCxDR8doiQ5k4Ejc1gNCNNjmYwvxqRxcJIZianzvfTfa4Fw70buPT1GdwtHvjUlVMZn86tn/6chsA40h5byEBjNz3vbme6xUpDUjKquBDCMZHkaUam76XrwAmOBMQzSe+EFz7EWVRO+ZPPkzU3CZW7k/7n3kYcEUxtVR9r3KE7IBL7nutcjhjLTxp2URAcRv+wnZDWelpXrsZvx27m/2w1uqFOJDJ3JI2NmIUCtFExGO0W3NNieD1rI9nFpwjyyyChtAKbWoX8VB5Lq6vxq69EptHQ7e1PVWASVa1uzPn2Kw5GzGTxklg0ptFuG4/lQlW/gCPjV5JTdohzXYX87OnRz9hkg3oNLEsY/Tw0Rug3wj2RQ0h+9iHFDg8GPGNx2/g8yjkL2RQ6G0V0MBoTlPfBykTYXyvmpzNS0O86zK4fv01mZhPyJ/YgCYxCn1eIfKCPy7NWsKizAO3IJGRiMYVdEJIcTNqyFfiPwN4qKDdCm3Z0CtfXS0Z7QDSZ/TUUfmdixhsbKeuFEG8J4YFyqsQifKZl8m13JhPm/F0ehb8bV+Dn4uLi4vIP5bTa/mTLtNEizT2IwwJvbLPWtqBaORtLaS3SxKibjtfvPoFiQiZCDzVD721F5OuFask0LrePjhxNixhdP2a8Uopdo0WeNZqha66oR+jjgaWsHvO1ClRr5mE4eoFqVRiikqPIl08i/POPUNTXEpIchjgiCOmpHfj2tFGHEF99HmUDQpL9AvAZl0q+PILAvh4EUeNpmzEPWXkF0d2XcbeO0PP4q6zWDdAal05WwRF8zp5E5O9Fd7+eWUVfUTtlPm24E7P1EHU501nw7GR+dl1Nbgh0GmB6GsikUPhhAUvunUWipw/Hj7qTdmovzqnTqQtNYuloy2CESjmDzX2klB1ENDGS4ld3omwzMDJ7NtIvX6H4oTdY78jj0pAbthPnacidge6xLyj3iWPCta+xqT0oi8oks+I8IfMnEXLXXH7lIST/6gBFyzcgMLeR9t02ikIymHV6Oza1msrJ8wlpq0PjE0TcmcMcmncr0xLdOdcKE0NhfgzoLZAVBOfMCeyruMQcinHvDoToUC60wtTfL9vcVw35rTbWlx8k/ZtjVM1dQEhjJeJGDREbZuCVX8qFtNW8Pxmq+qB9GFq0MCYIKh3J+AzuZpX/EK1vb8MwLgdTYhz+TdXUjsiYuSYDeYGJ0qNlJI7J4koHPJk7+r4BbqPrDSv74HcL4WA9+MudjHT2snfR3aiFNkoLRTRoYEkcVIamEtJQzjbvCdQNwoFaWBL/d3lU/i5cU70uLi4uLv8wlvpW+n/6Dk67/Y/22Tt7EYf63yjMbOsc7dQhEAgwl9cjTfl+fZ+5pAaH3ogsI4Ghd7fitmQ6Ii8PhHIZ+R1gdcDbV6FlzwUsFfV4PbWBkeOXGXzlM2qPl3J97Yvsr4Hv5t7NQVMIpY0Grr25jxCpmaE9p4kLUZC1+af4Pv8AexPnkRokJHrBGL5b8iCvxK6h/Na7CVfYiPMXcYtfP16tjcx9ej6PjxcQev4kXpH+fPvgy8QONBH/2HJif7iOnrIWbA4HA0MWfEqKMCjVdA1auZYyjZOKBJqmz+flEjXF3TBiGR35cpfB5dIhUttK8Vw0CemeA8yRdNLtG0Z5t4OwK+cwl1RjOHUVSXI0PW1agqtL0ETGcV0WRkRHHfe1H2d37i3cM1KAalIWtTUaItROBtsHSajOJzdCRJIfKKQCMhsKqMmehkImwi85lLJuBzKjgTumqJk5Lxab0UJCTSEloRkMqP0YTzcGhxBfdwnt6zfQ6HTnYht4yqBNB1Mi4EwzTI+AOdECjiXMROTjiW77YSxDeq53jxaO/izfit/l89zx+c8Yo21g1zNvoLfAwPIVCOdPx2Z3YqtpZmWonuMNcK4VViRAvA+sToKrF9pQRgQx8sv3UcucJD2wiJ3py3DrbMc6KZdQTyHKqdlo865R2efklmQQ/T4CatWOdvRwAn5uowkmCxRdzJ0WhMPXh8njAqjsg1A1BKrAbUwi3vXVLIwDuwPOt/4DHpy/IVfg5+Li4uLyDzNy8CzycWkYz1//o32mwkqkafHo951Gt+MImne34dCNoN9/BsOZq5gr6hk5cRn9/jPothxAOX8Smne3or5tIU6TGWlaLBY7yLQaZgaZuLXqIFX5rRS1WNB8dRBxoDeKe9dQPSQm/t6FLE4UsjJZiEd9DY6icvzNQ5yLn4KHrxuHZ93BB80+vFioZH59Ht5Lp9Hda+SsKJJwXynuVRWII4Kw6UYQjYwQmRXBtgIjOpMTvyuXuLr+YW5fGIRJLENaV0+OuZX4xhLMCYkMxCUj91KRcelz+sdNZPbVPaQb23gkxYjeCg/ngNEGOhOcqHfSv+kAieumMfTWFhxGM5acLIzTp3I5ZAx6lSfN1X0MvvY5mo92IW1sRBETzJXNFxANDuAntTMocqMgZSpJPTXs6PIk2tTL/uR5TL5+jLZVtzJWrsHeO8igRIXAamXVI5Oo9ovh8yIBrcWtpIwL4+D+Ojpf20T4M+tJEGnpCY5iADmVlQP0qf2oXHQL02NETA0fDagaNKOjaAJG/62WQs0gKBMjGdKYKUqbStUb35KsNLDv3Tym7v4Ej0sXuTB5BTHvPo2ooxM3gZ18rwRy1k+isqwPXWwsE2svUFHaQ5z36BrBukFo0jiZbGhgMDgca3kdoRkRqMalovZyo10vZNL40RFkPVK6/MIJ72kg3GP0nus3wI4KuDsLFBLI7xgt2my8UoJ6YjpRnqC1gFo2Goib7CCWSxAq5UQJh/GUOzFqDP+4B+hvwBX4ubi4uLj8Q1ibOhB6qlEtn4nxYhFOi/Wm/ebyOkZOXEIU6DvaaUMAbgsmIw4LRBzsjzgkAHu/hpGjFxBFBtP70C9xGIyMHL/E8PYjCNUqGr84StbxXXQseBBOX2RSji8+6+bx5fg76Jw4g4Jj1SRZe/FcvwDLlWKuf3AI9883Iff1pOCOR3hAewn1fWsIVAtxk8AkQz31g7Dl1CCvC8eR4gsh7rC88yJ6i4DT3ukUdgnoTs4krKqYTc/tx6JUsejqbsL3fIP6+YdoLO2k/9nf4hsXyNncxYQe2of72nkUdgmpjMxgZPo03LITOfeLbwg7foiGJi1ZQfDqLGg4eh2r1c7I/jwUkzIR5WZSdaSEiXQSJDAQc/kU+t9+jlaqosU7nMIHnuHjJU+j69Ux8eweIlKDqO0yk2Fq40LcJPSHzlPnF4O5uAZHWAhznl+B6fw1jHozfn3tlD70NKLCYmbPDMe05zhjju9EfPEqG3VX6Vh9K4nzM/CM8CN9oA61eYSBhBTq3MO40ilgUwlsyIA35oxOtbfr4IMCyAoY7dShksKkMGibMgvz1TI2K7JxfryVyQFWfNzF7J91F4bccWgHjUytOM32hAXMlXSyu9JJgLcEdaAXyjGJpJz6DrtWj5sUpkdCwGAHaaFi+loGkXkoEeLEoVTivFSANTOdpr2XALjcDlWpExhffxkYzSb+ohjuzhy91sVxo9eZFejE1tSBJDqURXGj9fzivSEnCLaWQYcO+hJSGfjuHKtObWKN439XIPxfhSvwc3FxcXG5wa7T4xgx/l3OrT+Yh2JiJkPvb0fk58XI4fM39lnaujAXVaFeNhPF+HQEEjHSmDCkcRE4LVbEoQEYDp9D6KbA+/kHcGr1BHz0An4v/wDV6jnYegcY+mA72kul+J4/iyQ0AKfdgaWqkejWKh6ilEa9iL6dp/COD0Kbd53SumG8WhvwcJj5aupGHqvZTcfiZRT1iykp6kY9omF6xSnCbp+HqraakYQk2nTQ3mcm8MoFgpZM4NbHp5CV6oUsyBfT6SuEXDiFQe3JSZ9Mqs7Vcd7kh6WtF02XDk12Dsn7d2BSqumaMpNPi+CHEyGyv5nu3Km8mLoRTWIKd9UeIPHwLgZL68nYv430dH++mnoXPXW9lBT3kjAmhKIWC7HddRz0zaF/7HiO5q6g2TeCaeZGAnfvJsoTPMzDNBY206AKJu7sES44g4nWtmGLi2Ldla1M+sF8REIwN3ZgMlmxhIUhs1tp+e4K3350lbTCU2RWXCLKX0KI3ELSib2UvvAlhvoOghsrCLAMobcKcI8PJdUPrHZ48BDcfxCiPEene7+tgJoB8FGCj2J09O+KM5CyPuhx86E7JIaWJg1lq+/ilCWQmZFQ88khotbNYvbV/VTtOEfIF59i7Bog6Ye30nytBd8JqSi27mR/uRWrHXJrL2Ns6MAcFMxIYAgOq40DRUaSWsqonb6AC402hruH2FEBt09SI1Er0bf08Ol1uC0VfJVQ1Q9TwkdH9aRt7YijQhAIBAgEIBfDA9lwayqsT4PhAT1pPdV4nT1DwhOrmLr2f9cO8F+FwPnvVoDmf6DT6fDw8ECr1eLu7v7PvhwXFxeXfxum/DIMp64iUClwGs0IPd2RxkcgjQtHFPz92rv/P9amDgynrqBaPReRp/r77a1dGI5fwqE3IMtOxmmzM/TWZtzmT0aWnczQZ7tRLZmOeulo/T39wbOIwwIRCAQMvPIpbnMnorplPtbfT/d63LMScaAvjhEjQx9sx9rSid/rT7Npbyvzj20i8NMXR6eJ957C1t6FYnou5+osRG3+nCtp08l3j0M3LpcFO9/FrbebhrTxiEIDsfVpSDZ3ofcPIrzkCip/D9oHbchsZtojkxm2Cph+fCsShZTyR5+hMXksyYZ2zOeuIc87i3dnC3n3/4hp3cU0hcYTfuIwcpGTFrkfYqEASUsLQoeDNx5+lwC1kNWpQoxvfsGXUzYyLnS0j+29WdBwrAjDU79Cec8aYp5ai84MZ9e9ghtWvDXdFC5ax5w7xnLk2wo8DENEdTfwcsZGFEoxw2ZYfG4rnWOnMCNejOPVDzHI3ciRDnAwbRFJ+acJHOzAb8kkzHWtDB88i0mqYCAzh16xmki1A4duhIQJUQgUCtSLp2KpbcJc20pB8QCeFSVITQZKI8eAxYJIKWfhzBDISOagOJ5N1VJ8laN9eDuHR9fBucthepiTaNsA6oZarNfKmTBYxeVFtzNlzRjOtwn4ogjW2CrxbG7Aph3GmpZKXXQ6WUVnkJZXMHWMF/uiZrCm9hinxdHQ2c2aX6+ge81TNAbGEvTDu2h8ZTOxhi4OhkxEmp3CokXRXDzbSu/ZEsqmLeHTJWBq6+HEJ+eJf2Q5QaePIQ705WPFOFYkCThaD1GnDjFzbRZmtTtnXt5DqoeF0LsWIg71x3DiMtdO1qJaNYf2fReZ88RcxGbTH5UZ+lv7W8Y2rqxeFxcXl/9wTrud4R1HwWHH69m7EIhH/zTYB7VY6lownCnA1tWHQChEHBGEcu7Em0qi/KGRYxeRj0lC++kupEnRuM2fhEAsZuTgWSSx4YycvIytux/fV55A7OPJyLELDL23DbtGByYL9kEtIm8PDHkFiHw9kcaGI0uOxm3JdHSf70Ec4o/3D+8GsQjjlVIMp64g8lTj9fjtmCRy4vZ/g8f9yxDKZQjlMjwfWIOluonuHSfoyZhHVloUJbG5PHVPEjUNOqQ2HZ6GPtQt16kTZiKenss5j0X0lzYT62OkbO5q5h36krY1t6E/c5W1B95GuGQmJVMXMUdfxYTWk7yqnsnsM28zbINAtYIxQ/XsDprMxr4LmNVytL3DjChl2HsHifFU4Gca4uHLX9Bjk2E5YSOgppwFxq9QScHuhEv+AiwXCilbcCd3BgoZeO0ztEYn4XUldGaOY8+ynyKODOFipw1lQSFDai92pcwlK0zMxVZYpmjHz0cBGeFMGwMrlv6IV5u3IQmJJPOTz3AfGUI9JxeBhxrNqQIcAgl2iYyIlip8ZG4UzVjOoqRB5FlJCNVKJJHBSCKDkc2eRPCvt3KuN5ms7go0bl6EWdp5Z96z6ML0+BVX4V23nQcdIgZjE7nomUyYtwB1Yz3R1XUE5Q/iFuLLNa84Qh64k+qTx8hO96FhSEBpD0z2HiHp2DnSIuScSc3l/vVJOJ1w9Uwzgz95gC01WgLzTnJFoSR2pIxG/yh6Hv01Fi8vusdPZXKECkeoitbjQ3i5DyBKjSbMHVYsCOftbaeIkxlwOpVsHwggx67B5+NPkS6cwlBTL8nluyhSLmNOtIT6r7ro6kuk7LWD9C1aQnO4CtUnX2Br68bj3pVMeP1uXr0oIN6/Hu0L7+L11IZ/5OP6/8wV+Lm4uLj8B7NrdKM9bieNQTEh46Z9Im8PFLnpKHLTgdEyLKb8Moyn81Etn/lH53LoDTiNJuTj0pCNTcV0uYTB175APjYVp8PByLGLSCKDEQf4YK1rAYUM44UifH/zDCP7zyBNjGL4myPY2rpx6g14vfgwDqMJ46UidF/sxX39otF1fgND6L4+iDgsEO8f3o3mrc1IYsOpOVGJzGrmE3kOwqswPnS064U0MYoTs9Yza+cHVIj8mVN7Bt1Lpwm7eBW1eZgBhwyBlzu3zfBGrGzlYGMPXg3X+XraHbh3tJGZ5oVp105mXDzCqQ1Pcy15MnLgpEc08rPnyWj/EHeLHofKG5NJzLlBFQF1F9gcEs0MtYkkJQToOhiQ2JEbDWjNTg6OW4VnuA+3S+qpL40jbM5UyvtG11+ZGi7TrA7DzcsTXVMxUpGQ/uNXkI5Jg7QUFs4I4YWzMENYRODYANq0Dg55RJDmADeJk/jLp/g6fQWfZ8FHhRDm1BHoIaTj0BXMvn7Ur1hJmKYNx7u7cLNYsMrkBH7wU5xVtRQV9hPVVEEHAVi+uUj5xocZvjr6+Sq0GhKPXGLvHb8l5PjbuGUlkn2pknmOZnJSYxk7dwLvXJ3AHbEGyk/XYDq+l2RfKPaIoWHsNMbO8KGgC641gY8YrqbNQLB7P1cWRFI34GRj/m7iVRauJsxiyszRWoyWzj46RB6sSJRQMeRL3y23ounoQJJ3iOj8s3QrZfQGRDJ9bRbm8moGo2Kxi0vQOiXcnTh63QdqgWkTMJy6wofeM0nVNBJgHEQcFY18bCoNgeDtU4/tk68IXJuFTaDjwKZr9K+4k4VhFro+3Yc4Jx63eZMwXixCHB5MTmkdQ03tCCJDkSXH/J2ezr8PV+Dn4uLi8h/KXNWIfs9JPO5ahjgk4H88XiARI44IYuTYxT8Z+BnOFt4oviwQCFBMzEQ2Jom+J1/D1j+EyFONx90rcej0aD/fg0AuxfdXjzO8/RCKCZnI0uORpcczcuQ8okBf7INaBl/7HGlCFJ6Prbsx1az9fM/3QeDQMEK1EoMVGr84jC1rAveOFSMWwtUOeD8fxCIQyNwQhwdRXidh9onNVKZNJnX+FIbOXWfXvS/yVO23SMMCMCChtspE1uql+JgVZJw8itPaRUJFM9cf/TGiMRk8FzVa5PfDa/D67RkEPrWVYpUfQZpOHG5KJncX0zFjLhQ10OOQUS+IIqf8EI7gEDrsClofe5ZnKnczMvNODr1dzb3zfbE0FdByppepoi6Mxy8T7+1BS8Z4di5dT5DQSPLxAiwtXUxaIaXE7iTB3Y720DViggQcyNyAlxMutsHHUXUcLwvAN8idovxOAneewVOmxOuZVRyttOJuMzJS34GtuQKdXzBOoDMgipPuOThaupk0WELhxEVMvHsqfgq4xQ3cJIDVSv+Lm9mWOIEVQTp6Y5PJrjiPZeFsZlflcSXEC5Pdh54R+LJOiS04iyVrpRxukxE6NhaFEU41w+pEONEA4e7w6Dg1O88LGerSElBShaqmlp4f30OXWxSRnqP3VO2ha3jPGENxz2jNvrsyIGdSCMYV91O4+RzyV9+hYVEm3VeHCMorp8wtDEXWVOJKLzNiWU+XHrpHQD0mgTHl5zi1Q8DK8B58fv04g299zYlqC0dapcyIiEW/UE7v089QEzeWlphw7q8/itvXVxCOm0LJxAVMjwRxRBD9z76JwzuWzJ/fy5nf7CKgZJgZGeo/eh7+VbkCPxcXF5f/ME6nk5HD57G1duL19J03tS77nxjP5GMurcXS3IE08vt1TU6nE3NJDW4LJt90vGNoGHFoAPZBHeJAP/T7TiMOC8BcVEXgllcQKhUMvvY56jXzbpzHdL0SSVQo9kEtkthw3O9Y8n1tvz4NQk/3G4GqqaiaprAkir8sIsBpZ9GDkxD//seZHjlaxPmlcxDfXUvx/mLk/hFUxWYTrW2j60Q7NcFJuIX5sddrDfF7T7Fl3Dpis6Ff7eTuwuOYa/PpsYo5eMePkcTGk+MGH10DhRh+M8uJ4QdvcjR9Pu5WA4EHC+gLzSbdz0GOtYbv0mIIuXAKlXGY9sBYpO5uxA820V5WRd+FYiTf5THH6aCgNhapQkpkrw7rYAddPmFEjE1iwmA9L9boUV87jMEhJu2jZxC0tGP+0aus9nYnPNqdk74ZlOoUxHpDlIeT7r3nMM5ew5QDX6KvrOD0xp9iUHvR894ZKvyyiLP2Mb/6BK3BMUQVX8buhInvP0qH3skcSTNufkJqQ33ICRotfv1ftFsOUOX0oihxIm+5tbIjNop2ixbP0FjMablEf/ktL3fdya8XyEgPgHevOvC9dJ7YYXfqE2O5KwNePDv6u/vJZDjeOFomJXzeWKL3Hie9KI+wt59luykSpWk0GzhNpmOkuJWcdfN49Cg8OxFyft+JTyGBDC8rv/vhKySHyhlbcoSG8/nMiAnhF5N/QELpRQ5eHsDq7UO8N3QNWJDrdUylkfdS7mZ+n4A632zGHDpIxpAU8ZZawpsbqMYbvcnJvKFSzo5fxpKBbiJGuvmiE8b2lKE/eZXun/6YnmtdWH6xiYHgGNT55ZAx4S98Cv95XFm9Li4uLv9BHCNGht7fhkAswuPBtX9R0Adg6+rH477VaF757KZyLJayOmTJMQiEN/9Z0R/Mw1LfituCyfj+8lFkGQnY+zS437sKS2UjTocDSUQQxrwCYLTOn7WuFWlCJJ6P3AZmCyKv7xezm/LLUOSmAaM12M4eqaU9MJo5rZcRhgcj9ve66f1LLzQy5+hXJJ4/wvHc5YiVcqpSxzNidWJ1d0chdHDn6U+5RV/EkEOMb1MNiQoDGfu24Hv2ND4KAb9Z+GMGI+MZMMA3FaPZqwMm0Lz5FUa1JwOLl5JYW4hBoSJksB2H0YQowJusA9vQOcTkq6IZmxtEQmcVpp4hAhsqOXnfT6ictwpLXCyClATMCfEULriV4imL6Xr6aZDLqRN4c+dHP8a9qQ5taAQDVe0YCysoixpDmK8YTxn0JqUxYgWdGR4XFFNs9SR5/w48BnsJGh+PXupGeH8LQf1tLKKR5HgPyh96ilDhCIO+QRi9fDnpCCXb0Iy7bgCf5+4nsvgy/Ybv8z4NeQWYlGquN1l4dF0UzcVtpIwNw3nnGl7pi6FR6EXgbXO4vXgXbhInBV0wpb+cytBUEtRW7kgw83XZ6Ghdm2406SPADcaHwFGdL9NPf8P1dQ9QoIwkwgN+mmNibf1RvHbtRrVuMW9cETAj8vugD0bXpdZdbmDinDgavMK5Hj0W96QwBuQe3F3wNdHZ4YQfPcD9Y+B6cR8Bm75izJNLiVDbud7h5FjBEEsNFXhevUKwwIAuOJxdy3+A8PaVLN72DIGLJuC5aQtnHcFcGlQStGUz2w+2s3PGXbxQ6cdl/3QsG9eR4+giQt/1Fz1D/2yuET8XFxeX/xBOm42hd7eiWj0HaVzEX/x6x4gRgVyG2+zxGPMK0H6xB48HbkEgEGDIK8B9w9Kbjrf3azAVVCCUSHC/fTEAstRYZKmx2HV6dF99h9DdDVlWEo6hYUbOFqL9bDf+v3seSUgAtq4+REF+N53TUlGPVTNM0TdX6BUoSe+pxj3PStegAa8p2aPXqJRjKatDf/QiVZZAku9eyZmXdiH2kWH29mbJia/oCItHK1LQGJmKs70Vz7Z6wiRmoi6cJKAiCGlCFPrqZgpffhNtazCN7RDhCRn+4OsGY6vO49NYwbmfvIb6mwOIe3rwiAlEGuiNrU9Dy3u7cDgFCDwgzdBGa/oSVKcKkdlMuOUkkVRbwCRbPRZ9H0fGP82hQS/erfqEerEb82pPUjBuBvIDh5G7K0mtvEq/PYXdLUo8599LeYeQx+fbaOq3ceWCgCQfGBwwYNy5ldzpE9ipS2NpkpaGa9XM7vySsf3V9NskDKVmYDcYSeyqxto7hHJYy7XpywgWQ2zeEYQZicgSo1DlJNF+ohi/5VmjWdrXq/g0eC7KQCPZwQIOtQwR6OdJWZcAP7fRlmtnWqJYtqCX85+doG3qHNaX5rNr+gbujZIirKlicmQmOytgcjjoraPlU+7z62T6lk8oSpyAVSLjdJ2N57nK0Ikq1AunIFo+j8PXBSgkcHv6zffiwJVKaoKTuC1CgM0JzZsLiFHJKVl2H2O8LfjVXWbkuQ8oCvfB53ovsc/fwheDniSGdfJa01Zqe21cWbMQkX8KBquA952ZvD2wl4xbJiOWQ8r8NERXCtD0NhMjHoHECH4Rt4BIIDMQXp4J4Abj1+Mwmf/iZ+mfyRX4ubi4uPyHGDl4DsW0nL8q6AOw1DYjTYhEIJMiDgtEFOyPfu8plNNyEIiEN5VvAdDtPIq1qYPATS/fyBT+LyJ3FdgdGC+XoJyaAzIJnYsfIeDjnyP5/TSuuaIB2R+0abP3axi0iigrHSbwkduZ1VyEpdKBpaIefZ+JEMsguq++wzFiRBIdStmitZhtSn5zTkOcUM7KCztQDPbStvpWJjw8l5q7f0lKmIFfZKxnaNhGdnMBd9krsLd0YewdYOSR+3ixIRinc7TzhJsUbkuDtPYyBg/uYNfDv0Tc3MqUg5sQZyThtXICpsIKKv2TiGzYhU9GHA3DAtzC/GjdchRfkROv2+bj3V5GzwMPUHzKA3FCLtNLjxPToae9qZWujfcTuCaY/O19TOvuIbS/FXtYAM75M5jRWcB7p9V0BMVSd7yEsl1XmYOCGJkBx7UyTJOyORE0nonHPkCREICXrovw2kYGMjJpW7WWaqMbiX7gTTfHYschnB6A0T+YEIEeY2EFwbveBCBkYS4NL32BfXoMum2H+G7a7cgOXmDiinTKO6wIpWLyuwTcPwaMVvi6HGK9IHRRLqHl32E6tJ+qsBjmeg5Bt42RE9dIfiaFAYOEpQlQWqlh8qnjHDxmoVsZSPet65j47cdEiD04kJnDw0/dg9YmYFMRTAp10mMQoJLefC+W7L3G+CfXYnVAYYeDSXWFHLrvGTwVAtLCZaiTpiP7+hhfdHojv3UBomExayWN2LorGWjVcGDpc+h7hYR4BVDVD09m2Yk6NIj49/+j4TCa8VeLODrrQeShNjp/u5WIHIj0AM/Am6/lLx01/2dzTfW6uLi4/AewdfdjbWpH/t8yd/8SlqpGpEnRAMgyExD7eOIcMTD08U4UM8bddKxdo0O/6wQe965EEuz/J88nH5eG6WIxooggBl/8EPmkMQiU8j94vwakiVHAaO/aIzvLaBm0M+OOXCZEiLA1dyCUiFFMzqJx8hxC7lqE58O34vXMXYiXzmFLg5JDtRBZcJ555cfwdBPinDaR1oXLsTR30D13EQU+ibxe+xU7k2qJ6ajlhTt+y5GIyVx0i+en4mlYHKMB33OTYYKpBcfvvmLora/we/NHyM1G0l/9OcJgf+pSxmLp01Ickk6AcRB1Wgyl9z6Gl0FDsU8C4z74AQOe/ng9dz8O7TC+l87hOHIGWVMTUVfzSK0txJqeRlDeKXY/sZWM0vP4Se0Mz5mNPSSYvppOFBuWk9JeybPVOzj15WV+l7OBwHBPBhWemCIiyNN6sOitn5Ad4MDW3oVFZ8CuUGD19aXLIKJZHcyC6cGcOtZIz6x5yAYGmDcvkr43tqBaNvNGABPkJaY5IYv+n7zL5fGL8PZR4NXVSu7UCA6f7mA4MIT7xoDDObrWTiEe/XwAqqYvJPHySRqc7vjv34dQqcDW0knZy1/z5KVPMD/8PNEfvE27KoChQQNDmdlMOr0Tq9FC9rMrMeWM4RfnBfzyHNwe0I/xF+8yJ9x2031TX9GD3d2d6BAFu6tgUcdlCPBj2D8YnRn83aBvBA7EzmBsexGrI4ykHf6Gi/vL2T9nI20p2YwzNCP+fWu5VH+Y7/j+3obvlxTMT5JQa1CQO9afBwK7+K4Wxv19S/b93bkCPxcXF5f/45xOJ7qtB1HfvuR/VYT5z7G19yAOGx3ukI9JxnS9EtXa+Zjyy0F483kHXv0Mka/XjaQNGO0Kott+mIFff8rwt8dwCgTY9SNofrsJoZcav1ceR7//DCNn8nE4HDiNZgRuSq60w+8KIaGrlliVjTqvSCw2J7bmDuwaHcOtfVjG5dxIRuh67xs+fWgraLSk5Z/glrxNeCaFUTJ9KcMTJhDkLuTc4VoqAhOYtzSesA3zGHhrCxnJXvx8sZqBOXORZyZQrRXjLXPyYXQdc45+SUpHOd1OJUP3bqR9yE7GKz9H5x2A7vGHSa26wi5rFJJZE/G+dAFt/zDqx58j8v4l6BYvQnElH1VOCocbBGiCw+l9dzspbWVISyvQW0Cs0TDDU4swMZpT0ZPpNAjRSdwwxcainzkdlaeCsg+PcCRrMQapgtjGUh7/+nm2m6O4rlMg8VDhKYOyzGkEffwCbpp+DCMWrk5ZTtDKaZiulvJ43W62nR5EadRTNqIkSqinrXEQaUMjbwQv5IMC2FsNhZ0g6umle9jJiLc/9aVdxKYGcrRegLC5jSVzw5CIoGcEpCII84BgdzjdBG5NjQxNnkLats9Q3b4ExaQszCsWUdUHLXohQxtup3LdvURZB/Aa0bDQW8PA+vX0rllLx+kSZkdDuxYCVbD9jTwE6UkITn7f4cXhhOKdV8hdN57KPrDZnaj2fEfVsluRi0ElgSN1cN8BEMyeQlxLOdr3tqCZOJnh5UsxiBUo505iRddFwj1gYtjoz6G7VIo89/v5ZFNBGbKcFBJ9YW0KuE/OxHqlBDfJaFD578w11evi4uLyf5zx/DVkSdGI/bz+x2MNeQVIE6MQB/retN1hMiOQSm4EjkKVEqfZgrmgAo+7VzKy9zSije6IAnxGu3Qcu0joyc8QCATYNTpGDp7F3juI26KpqG9dgLW+Fe0nu7DWtWLr6CXgkxcRuClwDAwxcuwiltJahv0Cee8cSETgaxqiuW0Yy+QJCE0CthztZmH3AJ53LKb1Qj1x4aMlXU7ntdHUIKMvK4X7330CsUJKV/Z4whdNIKqgipNZ92LrBnV1D4vu8yewv53hXSc4/8xLLJe2Idj8JWs8/bmU38QCtZFp8n4SvCJwe3AtWeUdlB8r51yPlNTnX0EXGYt3iDf912qQyj2oi8lg0dcfY7faaBT5EDE+Es+yEqbo7fR+tQ+/rBiKPt5Js0BIemII7mvmoJwyme+e+xbVw1NZPFiESR2Ce2E72dUXqLnvMZSnz9K1YQNxujM0V3axIe9TPGtK+O7OZ9AIFSzf+xEhA2389o5XmFJ2CtxVHLzzHeLVbjQ98wLiT7Zx8YSUqtnLCXK2MHnL2xxUprDx2jYm/3Ahui/3YZuXQqWvgnuzoEsP7aeKadY4KU9eTPy5QvrrtDTNSudgEbwib0cWnQ1Ajx6ahuDx3NF2bA8edPJQ/hk0I3bMP/4BJ944iiYlDfXV6yQ6R8h97xmEQgHXWq30nB/gzcVP8YPJUtr7YMOSGC49fYad4TMQCITkKjSkxpq5PG8OZz/6Cr2nhrRkL7r6zUTbBpFEBHOwAO43XaPCpmD+vCh+VwCH6kAtG22/NsPbjFQmpFvmhsfBo0xNDyV4QgL2yHCO7XLjPv8uTliCuCvFxvU8LXPu8wHA1jOA0F190xSuJCaM1k+Pccv9Tk40Coj1/hs9nP8ErsDPxcXF5f8wu06P6WIxXj/c+D8e6zRbMJ4rxHS1FK+n77xpXZ61rhVp/M1rA2Upsej3nsLnF4+gmJTJ0Ec7EU8dS899L+KYNZm6QQHSb79DqNUhXTANWVw4RhHYbSCJicBhMCIKC0CelYQ5v4yh327CKRUxcL0Ry4CWYbUX08e3EBCswr2zDUdvGypBOMKCw3SeL6epVU/MkUu0eCVAQS1vD8sYc2YfJkUIS/Z9yPWZyxk30kh4RQXVjSlETM2hYkBIrKGTrFAx+ReaCW48g+rh9ZjrFPiOTaItLJyuDc8z7BXL+MFKgtbOQL16Ag69AbdTZ9DET2HCZ2/RL5EzuHQpKed2Y/TyoSB2Ao9e/py2gkYs3hH497TgG5UAEjFBZ0/SN6Bl38w7SQ8R07ftGIjFSCKCURw+zpWsOcxKCKPySBk2JTTEpFNk6MM3IoDUKAV1djFNEcnE7zuE1Grm+Ji5tCn8mJy3C2GAH1cyx/OD717FaIFN8x7mlooSLo6ZhfBMPgHJ4RgLSulPXkzWrHAKdgWR3F7KxHGBmC4VI1Ap8Jw7nkAJtGpBW9fB4PkSsh6+HYkQbG9+hsomoFA1jygntHcbOVyhRCAYbcfmoxjN0LXYIbW7io7iZhofexJJYAD6NBEL3Ps5fOf9TLm+HQxGUCmJL8ijaOwkRFIph+tG+/p+WiTA6RPHHFMdc5YlUPDWebxvm8qKGAG2Hy+ia+th8v3WU7SziIy4LPZfhDVJTpp/lkdRXC5vHwODFULdYU0ShHs4afvNd/jfv5HJtn48b52HtbEd/fUarr5/mhQPG55bt/HAT+7H3tJJfW4sW0phYyYYL1xHMeXm/rsCgYBGr3Dmm1tok0TSrht9rzYdFHfDkvi/xdP6j+EK/FxcXFz+Dxvedgj1rQsQiET/47HGyyUoZ+aCUIh+32nUq+fe2GepakCWnXLT8eKwQOyDWoRuCnBT4H7bQqrXP49IKsM6bELwzK/ouHsjwzOSsdjB1jAaIFjsoC6+TlxpO82LVqJqb6UnKg5vHwPG7kEGAtwJiAEPOaiwYRsxYe7sQzUpC7fF0xHYbKi2HWJ4zS2c2XOJk6meZBs7WVB5DnlLM7ooMR8vfJw3H4zg+u2/xPvuNYR+dZD9uuVETjYw74t3iA5XE739HY6tuJeQfjnjgpyUlPTTsPsCwb5eDMi8GVy0mjnDFxk5dQVzRT0iPy8m/O51qkU+uId5EbP5UzqXL8ChNZNUmk93Qx2tY6eSdPYQ7kunoRqXiNPhoEeoxlpQz7TtHxA8PY3+lhLKQpMYe+IKPQofVMlhtAzB3sCFzD7yFalRyVQGxBN6vpoeUSQBB/ahdFh4bclz/Gz3C8RqWohqOk9/chKR5ScJEvehDw6hwjOGjcfex982wjWnG9cl0cxqucyElipu1+WzZac3Y3q6GPvKXXiGe1J2sJhgix1ZWhzjdPCrw3ruvXaI+S/djkUu4rPrMCIMJFvcRXKAgKneBjzrlcwcD04nvJcPKxOhdgC2XTMzdctHHFr1KDNTAlgUD4Op8Wwti6dbA4dUqSi3V+CVHkVYTTem2bOJ64ezLdCuA5UMVi/PwbT1OyxTglFoh+j2DSMcEAf54RnuA6UVbHSW0zb5Lqyt8PmeFpI7HIgWZPLubPj1BUj0hW+r4MGhC0xdEI/H3Bw0v92E02DCERnOloFwFsyCaImegV98iO6zXZgKKwhIjiG1spHLh92IqrqGu9oN5/AIQi93EAoZ7NYhFzoxfLKD2QkJVG0dRulvp6oTsmcmQvy/Tx0/V+Dn4uLi8n+UuaQGoVr1v2og73Q6MV0pweupOxFIJQx9tHM0izc+EgBraxeqVXNuPn9RNSI/b5wOBwKhEH3XANKWNpTpsQS8dDcCkYjQ7YeRKYdQzh5/o8afpa6FwSv7ITWEnGdnMfThN9jKj6BZsIDDu8pZHtyL4MQ5PH77QwbOFNE3Zizm65WUD7kh+uVehFIxEoOCvC4VEVOmEKPRsczRwIBUzG8eeZ/w8yfJGGml7LwVH6eBMqOKqgUbGDKKuPeNJ7GHBKGMD8PvkbWM+/oUHZ9/iiDMh66EVCYZm/jx1PswCSU8WX0cQbgS7eb92Jo7EAf4YZw4ns4GM4MeHsRU1rI7bBprtv4Ad4eJnqBwJjRc4dtpa7n/hZXIpaNr5nwLthD1yFouR+XidnAzDYnZ7B53C8oPn+Py7LHIJbCjAhQSGR2TZ/Lo5S94ecnzpGx7j+AQNcbFC/htfxRhxYUcmbCaxKYSpp/ehdPXm2PzNhAptzDGzYGqrh5RhD+WwSFmJ0goloVRmBxFB27M+WozK/RDXJq6gkqfWEqGBew2JzNxxEz7ORESp41V+d8S/OBSpO5KpMC1LkgTCUhVGNhjdBI00IYtJgyAXsPoiN/2CohTWcnZ+hHWnCyeuSeeb6ugWz86EmhzwKuzQGpNpufd7WhPVlM7bQFXOkbL0EwKG11nd2sqdOtVVLoJOP/BCTSZk/niFCyMG+3ywZiZ+L/0NrKZKWyrFpPkC+n91/GNkpC+PIjXL41OU08Oh2XqboKqm/C4+w4EAgFuC6YwdOAcW6LnMz8G4n0AVHg+tBbjhevIs1PwfnYj3nYHe7+pwF32/7F3luFxndfavodHo9GImZnJsmRbZmaIKY4dO3aY2mBTbppSmoahEHDiJI45ZjtmRlmymJlZGo1Gw/D9mMSJa6fJgZ6vPWfu69Ilae9379ljvdt7zVrveh7w8vfGqh7G0lGP3W6jQqsienQ4gqMVBCybxOEmd/zDxXR0waz/fL/U/xecgZ8TJ06c/C/EZjCiPXgWz2fu+V7jzbXNSKJCEEglAKjWLmTwjU/xfHItAokYgVB0U9bQbjRh6ehBnpOCqaoR4/UKOl7bijgsGI+HVmBp7UIxMQvPH21AfyqPwVc24bZqLlhtDO88irV3AN/Xn8M2rMPS2oXaN4iKLedI9HTFd0o6Blcxpn0n8PT3RvLHlxD6eJD6+DSq+wX0/+Q1XFRS1p96n7akTPoTJvFKt4oZY4yEVRTgE+jGlFgRXb94DdfVc2g9V07x0vt4sPEwvW4+RFr0aK+WYdRbEEzM4b3YNcy11DGv/RJdrUPU+suZlq4kLjoa/e6jGIsqEchkmFs7qcOPaKGZ0/ZwerJms/wvP8e3pxVJSACiyalcMHrRpIzh0cOOUuDyRAhpKUN479Mcr/THTRlOlMLCuvObOJo0A8WFfC7Kk/jtdDntGlCPRNB3xMK8q7txZ4RjdzxLWb+I3mErPzZdp2vNXQh/n8eAzsbZsEksUfTisXcfHXfdjbuuBHl2Gk02V0o7rGwo2ITAVUF1WAhhV7sp9k1gQOZB2c83kh+RxQpZE8cixmLVwrTLBymKHs3+hgCWy+BcvYWUyydZ2nGBDi9/Zp7djtaoRhofif56Je9el5ESJOeuyX40vLqDdqGV1juWEegGd6fCR0UwJwbC3EEuBsQyxAMDBKXEMhjhw4OhDtHpOB/Hv5FA4BjrtTSNgt9tY9Uf7sCjHGK9HH1DfyuWMtci5dhFI6oVEO1uR3juMl6/f4jfnhNwoRX+Og+i3Cx07TnI52NWEotjPao9Ppor759nTq6OeB/FjTksjQljeOth5DkOQXChSMjU4Wq2p0zHPdKL0C91w+12KLgMz4wD3Ug31s4+Zsd584cL8Lup//V79X8aZ+DnxIkTJ//m2AxGzPWtmGtbsNttuN0xA+3uEygXTP7eGmO6E1dwu3POjd+Fri64LZ+J5tMDuEzKQhITetN4/ZUS5GPTsRtN9L/wV+xuSgxiGbHP3IVy0VQGX/0YaWw44gAfFNPHIBuVyOCbmzHXt+IyNQdJaACSID8GX/+EwXV30/3zdxAnJZOS6I5tRI/bilkYi6oQhwRgLK9DpzVz6dkPUA31IdObMQZGUW2TU98jxG20mMyy6xxSTMSjtQTbwys5ZTKRavmI8v0F9CWPIaymhJFzV+lctQb3d1+h6b6HMI8YadhTwzJREVqBhGvDBi6JEllVuIvks80YIxSYKxuQpsYjFAro7R5B0dXBhSd+SeaJvYwRtaEtKcQY6I/XX1/ggiWUvlc+YcIPJ2OXQIMaXOxmOrr1HO/2J0phZkBtpH/1BkR19ZwZCCfIM5K7678gaMkddAxDUF8rZ5OmkXF8F9tz11DXJSLRG1LrCskcHcSlT7YyFBpCW3YmGX2duHYN4f3bxyndUcS4+xdiO3SSN8LuZNw0GYYpmRwu1bPh+g50QglJhg5mB7WwR+ZHcm81TcYR3IKMpB7dgmKgiykKE9d35lNsgWyplpjeelzc5ZRlz2Bh3SlsVitCPy8OXdPg1WdggaCfvifKKYsZjy1EwdzRjkjJy8WRefvzNXhxumO+2I0mbMMjiPy8KOhwlPG9XGBF0s3z0N7WhdJbSXGXnQWxAt4vBH8FeBg1SIN8meeiIyGul8P7qvH09WSHNQaBAGZHQ7QXaLYew2PWWILdVBR0QLwPfFgkYPbyCQRduwDhs256PfcHliNUKW/cRwwPc/dUL/5aAA9lgbsMmtQQ4eEIQF3GpTO86zjCyHhMVhAWlTKi1eA6e/z3vl//f+MM/Jw4ceLk3wxrvxpTTTPm2mYsXX0IpBIk0aFIEyPRn7+O/kIhtqFhZOnx3+98gxrsVisib4+btkvjIzGW1KA9cAa3OxxPcEvvIAKZBN3pPESeKgTuSux6Ax1+4bi596CcNxGBSIRq/RI0H+3B89n1CEQiLK1dCJUuuD+0nOEdx1DMGY/6/V30TZzCiTYpsxUWBC2NBP3mWQZf/xTlwimOTspf/pWB6k60ySlEu5ipDk6jOTmbkLYakldPpq7dleWH/oxdaKGn9DRbZzzAZxkC9j63G0tqIpI2NbOqjqM41c5IQgKZ57ZiltlRHdpBSdQoWDSXlQs90Q7q2XLPXzBE+ZJTk48wxJ/eiiY85FJEchk2mYwToiQmX/+C+b/9Ie6YMJqMyOIjuBwzjo/aQrk/w06dSEPfqctMtTUzHD6RPx/tY55XMOEeMKqziqspCcglcFEZjZcZ9G7RLPcr4Tc76jDFxBCVV82Quy/y0VPRjh5NkBKwWphwZT/VaTHsmbian1Zv4+DsucwLU3Pm5QMo+kbQeAegCvWhSuvIngm27SHK0M8dZW24ugvQe/niE+VFvk8SbqlhjNG24jbQTVVrHsL+di5Mv5NCqw/po8TES0eI3PsRbZ6htIcnkFx1FeRCbMA292ya4+D+VSA6dhSz0g15URHD92/gG8k0gtzAaoODNSAAPI+dRpe7CNcjhewfNZ60AIj2hHfywPqlO5zQZCTmVBODSWPI31ZLaG48F5qhaQie6MnjZOgY4uK96Xl5D15l5RQ8/EOutjuygh3D8OH2BrzLdDRHp2DXOMrnAHOiIT4nhsGXz2PT6hAqv75QrYcPagOY9GC5VI4lOhV9H8R5wi9OwawoON/quNYtpWC0ehJcM8K2QzomXj/KFqGUlAdm8++U+HMGfk6cOHHyb4Kls5ehDz5HHOiLJC4cxaxcRP7eN2nzCdzd6H3s9/h/+JvvfV7d6TxHU8dtUN4xnY5lT+F+z2LsFgtDf9mGuaMb+4gB9z88gf5sPsLwIIy1LURkxzsaPQCxnxcu4zPR7juNJDQA/aUiPH94N5rNB7HrDAxvPYx28UIOi2JYXbOVKrkPYeFutDapudClYKFVzInLvQReq8NXaCH1jxuof/4D1GoDSQmexA9qMAT54q/RIXJXYlDr0A3q+GPx+2w9KmJ00WkME8ZCfBRt7Z0oQxSE2wwM1KnRi5UMu/oR42oi8sM/0l8QzoWyYUQCMZm1eSS+8ji2wnIqeoeQ/Pg+/Arz+GXIUmLPHKb1kcfx+3AjkmE9pVOWotHZ8LJbCN25hfJP+hD0DdCR488PhJncV36cSYXnKB4/h2WJoD5RwkX/xQxVQYa/Q6/OYIG3QueRu+Nj9vmGEtlWg668CJmLhLmXdlEalkZ83inavULpm7uWJ20NBCgiCPcU0LXlKBl3juHCJxeR/mA9XXllVJrcmH10E20TplM3NMCI0I8CrzByByoI6q0hcPdO7JGJhEeIEIiEjCrPpzshjZgrp0gUwkiBGPFgH52eUkIFfTQolGgMMmoLalA21JGQOQFLWBqBmkFGegYonL0S1YuvMSbEdNO8OVoHz08GkRAEPX1Y7L3Ils7m4NVyVocPsyrXDaHAkUUTCRzB6sjhqwjX5iBNjCLopX1US+LJCoK7Ei0YXm/kp+9Mx18p4PR2NSHeAkrDw3k/C042QIZKj/umE3j+4h6ELo4GooZBGBsCYiEcqhUwcfp4anddom78DJrVjjFfZR4lQvC7Wop+7SqkQLgnTBNAeS/IRI6uXanI8dWT5M99B14n68fL+as+nrH/OSOc/2/8UwWcIyIiEAgEt3w99thjtx2/adOmW8bK5fLbjnXixImT/2to95zE/f5luN+/DMWk0YgDfG4RZDYVViKOCMKu/37+oXaLBXN1E9Kk6G8dI8tMQLP5AMO7T4BIiLmhHa9fP4qpphmkEroHLHjJ7cizb+76lY/PxHC5CO3hc6g2LEZ34jLWATUCsQhtYjINZypYr7mKsbmLnqwxyFyk5H10jt7MbLb9/ghJl48R6mLB75l1dD3xEp+MW4dbRiwhzz+PpbWb5q2nSDn6OcbyOkp7QR8VzUlBKGHNVbR6h3K1R4KytwtFZzv13uHoxHIk7q74Jwcz5qGZ1HlHEvrWjznVq4C+AYIUVtym5XD+00vojpwn5c9PUfPJCf4QsoS8DiFRhj6aW7UEJAbTkT2e2DBXFmhKmJ3ry3L3bmrSxuO+bjFeo+MJC3Ej5NGlCPR6/Jrr+OWL19laZKXNrqR1CHZUwIABQlXQ2KZDK3PlgQ9+imKgF4WLmLz7nuJ80hSyTu0hpKkCS0IcMquJ0JpiXMZlMElXR51ejtuFS3QsWExzr5na13YQblUjT49HeuYCgs4ejoxayAJtGZbli6jxi4UNq/BMjcTzybsxDI5w4uGf07T+Ada+u54Vf16P/0PLGJa6cjF1OgVpUylKn0zn3fdwbf4ahhOTGfhgNxlv/IGuR3+HVS5Ds/80Q08/jn2rwyoPoE8HFrsj6+ensCPZcxj3NfPYUgb9iaks1ZchFzuEkw/Xgh2wm8wYiqqR56QiUinxVwq4WqHB3xX86ioQZSSR3ymgtqQTgURMp0nKulj9jVKs14EDuK2YhdBFhskKfyuA+zK/lKgxOPT9ftYVx1BlM7FyPfdmwpNjHfIti+NhlscgMSEuTIyXMSYYMgPgjgTIDoLpkY734u1iR3zhKo0l7YQlBSJPj2dyOJxu+g/cxP8C/FMzfteuXcNqtd74vaysjJkzZ7JixYpvPUalUlFdXX3j9/+KyrwTJ06c/G/B3NIJYvENL9Hbob9wHVNtCx4/XMPIoXO437f0O89ruFaObHTyt/5fa25oQ5Yej91gQn/iKuLwQGSpseiPXXKseg8JpEegIFk6hCwtDktnLyPHLmEb1GCqbsLar8aoNdDz0jZM5XXUxmcx0O9N/+q1PBE/xODa52jzDSfzZ7mc3ioko+I4SnUfA7kTiRybTG/+NUzVTRQGpxGo62P0SBPSH67BagP9ZxfwKsunJXssv0nbgNHdiwkhdqLL8/DvasF3sIuRkDCOzb2Hdd5dVPhNYeqhTYiVcoxCMfKmJjaVpNPqEc9q0yV0ATEkLUxj8NNDHF32KEGvH+Lz1IUUd8l52eM6luQo5l48SaerN5EqESE5YdhHBeE6ezz6PeeI3H2Zt+SrMNocWazDeUPMkCrZMfchMnZ8iEyj5r5IDR83qRjrpSOstpiM+ir2t7swGB9HhL6GJpMCD1cxkV/sxTSk5cqYeZitdrID3QjY9imG4TpcJ2YiPHoGnSwM4ahgVLWNBOzZjUrTTWNiGlMzfdnqN54WnYQM0SBGoZh6q5K1s5M4U29lpqaNsrcOUBAxganjAojw+PKPLYIZlScx3Z/F0Vd24+Kt4mfWFo6UhhPb00TxvJWIq+rw8zTReL2CwfIRXLu7mTYjAmncHIbe34XHE3dzrF7A7C/dz4zXyrCFBvFesxc5wVCXk4DlzGbMM8Zy/r0zROu7eX/yIu7Ul+AycZRD8NsGF8JzSCi/RuT46Zg2F7DqN6vYWmsl/6+HkLn7kbF4HIojxxlcvoiwuhJEXiqkcRE3gr7pkZDs62iwEQvh/lFwrF5Ap348idcuIl864+b74DbafQDTv3wfNq2OoU170fgF0X7v/Yw+/Sm2ET2ZAS4M6L/zNvuX4p+a8fP19SUgIODG18GDB4mOjmby5MnfeoxAILjpGH9//3/mJTpx4sTJvwXavadQfrnO7u+xG00Mvb8LS1cfHo/fhTQ8CNvwCNZBzXeeV3/x9g+8rzBVNSKJCsXc0IbNYsGCAOnkbEYOn8N14WQaT5bikxqOALCN6On92+cUeMdzwRbAxaBR7Hz6Fa4/8BTW0mq0Tz1Gf5uaQ3Pu5USTgKu7CrBGhNIblUDdXw+Q4qJlpEdD0oJR9MUm0vfGZoQKFwrn38VH6SuZsHcjF2yBnD/VzB5jGH1CV3Tr7mLzmDXMKPyCbSO7WHB1J/5t9UilIvxj/LgUNprpOT609Vuo9Iigu3MY208ep2JvPtfjx+H+3gc8VLOPnt8+T+nMO5D9+mVwdWEwv5oC/2TGjwskw9+O6VIhopoGRO5KJG1tBP/qfprqBzgijOL5M/AHcS5u1ZXIPZXEesHcWOg9coXrnrFE+0kI9HfB7eFVyDdvZ/3hN5n4yq+IzDvHhSYbcpuZJYpOhs1ClHU17A8aS8OYqVQtWU1IQxk1SWMoCk6jJmM88olZ9P70DcytnSQUnKXlzR0EHD6AfO4kjk1bjfcjK8n3S0ZotyFq6+ApUx4Xw8cwIRSMifH4N1Vz2jsV6/mrrN2QdiPos+kMDH9+HO2+U4gqa8kMk3J0xeO8Er+SmZMDGTVQQ1JtAbkVZxBcLaRhwTLE2PEPUKKUOjpkZRkJ9G0/Tq8OIj0dzRKDRy6zyX8yS+KhXwdjIyQgEFDw7HskhCvIXD2e6Uc/IW/zBUbSHbooxxogeHQ0CQMNnL/USXqSB2I3F5JKL3ItII0g4QhR87OxaXVUnq4koTIP5dIZmKzw7jeCPnCUcVVf9jfNigaP0fFcO9eEZcTw9f1jt2OsbPjWrLepponBNzfjOmcCX0RPYX6sAHl2Cob8cgQC8Fbc9rB/Wf7HvHpNJhObN2/m3nvv/YdZPK1WS3h4OKGhoSxevJjy8vL/qUt04sSJk39JTPWtCN0Ut7Vcs7R3M/DKJuRj03FbPuuGVp7r3AmMHLnwD89rbupA7Of9Dzt/zbUtmEqqUUwbgzkoiMajRVz+ogrJ3UsYeHML7RYZiStyEfl6MvDWFs4qYwna9BEZZeeZOzOUH462MosmXFOj6Xl7G+WTF7B9nQuLfNXo9p/iL1GLKJ+9DFF8JKqde7Dr9Gxs9qA3r5ruEwVsXf0TXih04yFdHoroYEZVXGDG07OZ4tpPlL6b99Pu5Kw4nN41ayg0qQja9BFefkqGklIwe3rh6yqAgyfIfGwe67vPUuyXyGNHxTwffAdhu7Yxy09Lp8SD5OwQouuKka1ZROmwnOzrJ1DJ4GytmV+EtNJjFKGqr6FgUEZHRCLrjrly5mIXn/QF0KYBk9VOZ2AUc+vP0D0Cx+thdncBsXeMo6ReS5yfiNGSAYSllXh2ttLsHcb78YvZMmU92g1reTFwAZeEwdT6xbDMd5Cf3+FJRlsJVaEpjNjFVPVBetN1Oq5UIwryY8grgAGZOx9OeYDTP3wBF4sBr/FpbCmF1w4PknvwUx5uO0L70WuYoyIYHwobqxVohw2MabhK1OKxWEurMZbUoP7rNtR/2Yb+XAG+rz+HYnwGqqnZ9OoFhPvLKFBGYYuPoWzlvQROSiVqVCh3+GuIlumIlOqxGx3r+xRTsqlqGmGGphxrv5rO17ewP3oad2VKiPCAmgEIa6lEXdtBj08wmcvHII0OJXBKGvFKE19susalFjt1AzBoEJAyKYaHS7bhPWccuvY+rp9v5I4JXjT5RtGqAeWdc7G/v5XQNdMxIeLdApga8XXQdzsmRwhQzcnl4HuXsNi+muPNSGPDbolN7BYLw7tPoDt5Bc8n19LtF4ZcDL6uIMtKxphfjlWjRXc2/x/eZ/9q/I81d+zduxe1Ws369eu/dUx8fDwffvghaWlpDA0N8corr5Cbm0t5eTkhISG3PcZoNGI0fr2WRaP57k+4Tpw4cfLvxMj+06juWXzLdt35AgxXS/B4dBUiT9VN+6TxkWj3ncY2or/RcHHL8Scu4zp/0re+rt1qxdI7gMDNle62Ic76ZTBtipDAVYs5cL6T0ZUd+E+LxlJchd1koUIWwKiDW/AYk4TP84+iv3idvp+8wfDACGf80lFpu3l+ngIXCay69BkXo6IoDkknsB+CNTbkCl9U9SUsff05LCHBtAbF0DAk5Kd+9aRtPYDb6vnYLVb6fvIGI6WNHBi7io66XmYOtfLglUPIysrpTEhi+KF7CS2+xnvhc3jkzF/xmD+RT+rkLD55kf4pq2gbspNaVYAlOIhOwRAqiZG2A1eIEOv4PHA67upNpGz8KXnb68nZ/wl1Q834m810623UjJ+ATWdgor+RmYlSHlklQCCAJ99pw5SVSayhnZ/G93PF4IW0u4szHkm80nGEthNXqD9TgebJH7DPFolweJiwyus8bz6PWBjMi02B2MRSZm8Yi6yvB0VnC2Hl+ZwYv4EMpQHhqfN4n9nPYHAYHpMTOeGaysj0NI5WyIjusRNQ0MdQti/K0gYWXDmFad0yprq2cPiVPh7Q5fHK5THoLJDrMogiIxX52HT6f/s3h3fyyjmYm9oxN3UgKxcbmQABAABJREFUCQmgcfcF9sdM49Vc2FkB9o5u9mj8mSVoQezvjXVgCLdp2bhNHU3vMy9jKK7GJScVgwXyJizk7h0v03bUhX2T13DXNF98XaGh28yos19gCIS9K59i1ZWtCARfCofnlRH6xrMsOHKZrS9uJ/DBO+gwyAifn42mvxNxiD8HfvgxtuULSWs7S8qy8ezdcp3s1usMZ6RjOZ/P6ffPMMMDAtxFaHw8Efl5IfLzQuznhSjA5yYNyqzZibicv8QHl8Zz7zgZ+nMFuC64uRJpae/+UspoNG5floUPFcCyBMd+oYsMgVyKdu8p5KP+TpPmX5z/scBv48aNzJ07l6CgoG8dM27cOMaN+9r2JDc3l8TERN59911++9vf3vaYF198kRdeeOG//XqdOHHi5L8b/dUSRD6eSKNDv3vwl5hqmhD5eCLycr+xzW40MfTxfkReKofThvD2xRvFzHHojl9GuWTaLftsWh02zcg/XDNoqmnG0tZFz/330fLaNubPikfiHojMpmVmxzWuG2RkhHkxvGUfRavuJ+KN11CoZIi9PUAkxGX8KLrOl3EiMpOwc8cJSo/A9uq7DM/MxVpWjXHqnUy8tJfRdBN66TRqi5gP5/6AJyq2I9CrKYjPIvbQTiILDyOYmoWxpZuuq1WMVHTQJ/FilrGWsZ9fwl1iQRXkQu/jDzNskaFu7scQFId1cJimqDTs5xtxzW2krFHP+RnRzL1+hPBIF5LW3E/rqh+hGJeBcsd+jO/+jpKPK5jnp+KFkxZagrOYlBLH0O//QEJXNZUpuYRXF2JauRh9fRMFinC2nYfLbRBW20pfchinPNKY+NkRfHLHMyxyIWPj2xjqrmB49llqInPYUQ73Z8IHhW4sXjeZMskk2kpbGXPtM5JH2lAQgGTCKLru+xVergGs3PEyAboBxHodna5+1D36HJ+JffFyAYUYJodD0GA7ublBXLh8iYCWZg7MXEu0Xcb5nfvJv+dx5NdPEO1znmUpXkgjg9GdycfS0oVicjYipStCN1dGjlzE69n11PZYqKkd4r6HvVBK4dHRsP3dVgZ8g6n/+BhBv12NrKER3ckrKBdNxWVqDtrdJ3DJSeVSrZEZVw+hi42joqyXteOVeLo6gqi2Vw8QvXQ8F/wSyVWAokaJtV+Nub4VWUYCQpmUorTJpLo20vOXj5EuWIggOxDPx1dzbss1anyieTZHTO8zBUi7+5maksaP3NbxzmIp2yph0h2Q6A92swVrvxprzwCWngFMFQ1Yu/qw22yI/b2RRIUgiQwmaslYBFWXeLRnAvOLNMQv8iHeBkJsjHxxAXNtM+4PrbzxgapVw41s31fIspIZ+mAXqrULv/f9/K+AwG632//ZL9Lc3ExUVBS7d+9m8eJbP7X+I1asWIFYLGbr1q233X+7jF9oaChDQ0OoVKrbHuPEiRMn/9NYBzUMvbsDu82Ox0MrbtHM+zYGXvkI9wdXIPpSZNbS3s3Qpn0oF01Flhr7D4+12+0MvPgBXs+uRyCVYCiqQvfFBQRuCuwjemSjU3CdfnsZF4DeH71CS1wGDRYls5Ok6LcfQhTkhyGvjJ4lSzCU1UNJBcIAX6z+/oRXFuC6YAry9DhsBiPdhU3sVaUTmX+ewSlTWTJYgP7oBUzldQgyU/ggYTHzvNRQVYd60MiJuMmkrRiHrbSaMc8/zfFZ65hcc46ehx9iqzSV+LZyUnZ9gkxgpXf6LDwbapAYdHjaDYSHq6ip6KVownz8iwu4FjeO2b0FpP5mAx76IUofe40qiT89iWlY1cN4joplVGc52uuVyErKaPcIoMUvirG6RnomTKFBK0E52EtIRx2BbbVUpIwn+Tf30v3My7R5h5LhaebY2DsolAQxqId1Vz8j9PHlmMRSBB9sxuPsWcRDQ5yYsooMbwt7x90JONafNQzC/mqYHeMI3HZXwrxDH6A1mBkb64rh+CX0fUPoTXbMLgq6rTIiexshJYHTtmB0Kk8yQsQQ4EtldCbGi0WkaupJi/dgnz6YeGM3pqERajRi3J5az/JEO0f/sI+xRSeRZSVh61fj8+qPEIhEDL72MSIvD6SJUdRFppF/opolrp14LplyYx6ceGEX7SIVIRGeFERlsyHdjvBNx7wCaF/6JN6vPMupN48TvXoqh8Wx3Ovegv34OUcj0PVKNqctZelYFQdr4OEsMBZXYW7rxlxai8dTaxmwydhcAg+Mgo3ntEw69zmVoclE5cbR/LsPycwOwqOnHXFIAB6PrWJTsYBwFXxRD+vSIPU72gHsdjvW7n7MDW2YG9sxd3TTd/I6jQmjsKQlc9B/DOL+fmZe289IYhJ9WTm4ywV4WUeI662nqN3CJH8zKrEFu9GE3WLFUFCB8XoF3r98GJfcjO+6lf9LaDQa3N3d/1tim/+RjN9HH32En58f8+fP/w8dZ7VaKS0tZd68ed86RiaTIZN9P2V6J06cOPn/xfCWQ7itno9AJmXovV14PLX2O101jOV1iEMCbgR9pvoWtJ+fuG1p93YIBAIUk0ejPXAGS3sPIj8vPJ5ai91kpu+nbyL0dGfgWikif29kKbFIk6JvlIUNZXW01/VRvSKXWcc3IxiIwqY3ITKasBuN2P/6KdGzR9PvIaeue4SYhyYjaCoDmxV5TioND/+B877pZF3YitHDkzG2elR3TMN1ajYjZ/P5tMePlee3Ezwjk0IvH/rUA7QkjSYdGNdxnS6fYNLyT1DjGYhB5U+AwErwwX24uooJkVixeivoqBCRotfgHeaJ/LG7aXriYxJ2fILJTcXovn48QxTIjp/mtM0PaVs/Y4TN1Km7ackYi7ZPS+GlJhK85bT8/ne0bj+D73APn6YvYyRrMqFuMC3WgnnlQ8g9XBmOjaPitx8TJbJw0isa3/OfkxKfjiBQRqCfjRx9A94n92Ab1mEPc6Ozf4DiF16kv6ibjSI/Yl0cGaP5sbCxEGZEwWPZDjkTy7COeFMXvUMm2jwm4x/iT2XaJHTLlhDmK6HtZx9AdjqalSuoOVKFT0sdTRoBEbZuEr54jcC6MpQTMun2SKRNFcm6VeMwbP8CbZWBmTVHuaLPJLnoPOIQfyQRQQjT4jDXtiBLjEK5eCq9T79Mx9TXKWiFFdpS3GZ+nSHWGMHa0cv8KA37k2YSr4CPigWsyMjE5cJ1XKbmIA70ofFnf6Xv6V9RJ1HyaBbIxWHoB5KwDmroWncPUWoB27/MdgoEIEuJQfPJflwmjEIgk7H1msOv90wzTExRkj5tLcpPTzLw2M9Qjc4h8q6pjBw6h+uiqRR2CVBKYVoUTIpwdO1+n3tBHOCDOMAHl9wMRkxw3iubpeImPNZkMe3sVbpbKji6YDF+YV5siAJTezfqj/ZRGJZJ2ZCMUVEueHmLEcgkIBJhyCtFMX0sogDv776AfyH+6YGfzWbjo48+4p577kEsvvnl1q1bR3BwMC+++CIAv/nNbxg7diwxMTGo1Wpefvllmpubuf/++//Zl+nEiRMn/zQMRVUIvdyRhAUCDlHkofd24fGD1d/a7Ga32xk5dA6PR1cBjtLs8LYjeD659lvX7N1yDosFS+8Amo/34vfOz5GEOZbaGCobcJ07AeWCyTcyIcbSWoY27sZuMCLLzaRgZx6yqCiWGKswKBWMnLiKNCECY2EV/UuXonj7fbTXyulR+JI1OZT6t7diWr6CsNYq6j49wVBBNVFBOuTTx2FYvIigGAE2s5mBlzZSiTcibz+SD76G4UoR/j/7GwXhE3l4jAi7Wk3TnktYlZ6EBLrw0fSnSNm8j5kDVUQNt2JJiGN77BLGfraZyIk51PskM2pZLMce+wuqjl40cQlEJftjOXIWUbOYK+/2IdDrCJcZGRSqKL3rPh65O5YrBT1YD2yn84kf8ZeOQOLjNKQc28jyWCuiDLDYoO6dz0nWqvn4sZfRhEahNJ6hP3cyacUF+MQGYvngQxaGeNKnB6lgBEuEH0K5jPZjBQgmjiFodCyDBy5wJDeX4U7IDHSc90o7PD8JirqgutPMurzPsNa3EPGn5yj6zWasmQnIV62isQ9CPttJenYQu/ShXGjw5s0fjOc358czojES3FnHPeFDMDGZppxcmo4VcodbG5qtKi63CYh7cjmG997D7c3d+D8wD8/1ixk5dhFjWR22YR2yxCh0xy/TtngZfTtPcv8TM9CcHELs53VjDh2vMJLW14D7z59mQ6iAHRUQ5QH77BnMOfQh/sXVuEwfS9e7x9HaJTwy2iFyDOAywdEtfqXQUZbODQGPr6R5RSIs7T3IMhO5+KUzhpcLVPTCvBgQCIQE+MgoWrKCyQ9MRiSzYB0YYsTdi1PX4Ykcx2m+T9B3O3ZUwJQlKfDeVYbe2Yo0PoKIn6/nYYGAsh7Y+nkdo4rPkPrjuzje6MaGSLjaAUd0MD8EgpoqccnNQBIdiqm4BmnU91++8f+bf3rgd+LECVpaWrj33ntv2dfS0oLwG2tTBgcHeeCBB+jq6sLT05OsrCwuXbpEUtK/18JJJ06cOPkKu9HEyKFzeD5zz41t0oRILD39DG/7AtVdX1c07HY72O0IhEJMJTVIYsIQKhXY7XaGNu7GbdXc7x30GSvq0e45iWJKNh6Pr8HS2n0j8NOfuYb7A8uBmzMhrjPHYdSbuXrfy3jK7IRkR6PZcghJaACW5nYkYQGUjZ+Ny4c7cB+XQZkkgPTBOqznrhHt68neiDG0nyvH7chRWlffy/jik+wat4gfuPSh2XoN/bkCOt0DuDBpNTPTXRHJwdTRT6lvPG6p0ehe/RCPxlpcpXZ0CRG0mCFZ10qicJCAlho6klIJiIkgtrGMA+NWsLTgANZnZ/PjKh9yu9WE2Ybx7a1k8HI7Lu+8wDmDDy2Xa1m87TVGNEZ6srOYtOl19g7dhejQcXpnLsKmDERngWw/C0mTYqm+WohiQiZ53VKm79rC9cnzCM+OYsgAq0157D/tiq+vCOnYDArHzye0oQzzqHQyg00opmTT0jBIcYOYkQ1349nch3+YJ+MjxQwYHHImPz0FzWqHI0RPcSOjTh4nSDWCYGwa+g92YAoN5dT4VfwoTMDQFxdoUfgz2dzERa85RHnC0Qb443TYuqcDVW0+OXekII4OY+dQPPVT4pmQMkjlT94mMzMEt48/oiV5FLJWDQwNYzeZcZ01HmRSht7eimxUIuVaF1rmTmXuxZ0Yvjh3k6TJsBEspy7jEel/40PLyiQ4XAc+bmKK0iYTHKai0TUQT69i1nSeQyqaefM8tED7MPi4wKpv9GiaKhuQZ6cweKqAq4lBPDlW4Mj2hTkygvqrJbTX92OdsQRfVzAW1yNNiuHjErgzGSQi/tOUdoOLGGK8BZjvnINAJkEc/HWtOKoyn+C+aiofXMcblVL8lRDl5fhSfykG7fPJFRJ+dCdhAXLE/v9eGb9/upzLrFmzsNvtxMXF3bLvzJkzbNq06cbvr7/+Os3NzRiNRrq6ujh06BCZmZn/7Et04sSJk38aw3tO4jp/0i1lXcWk0QDozn0tBaE7eZXhHUcd2b7D528Yv48cPo80MQppdCh2ux2bwYh1YAhLezemmiaMxdXoLxUxcuIKw58fp/fnb6I7dgmPR+7EZXwmiinZ6M8VYNUZ6Lr3F9jM5pv8Sr/CYoNtexvxGZ9C+JwstF+cx1RchU1vQDFvIpqmHqQHjqD18OWzxU+SHu2KJDyIYaGczvp+vJvqEFlM1K+9lxkjNbTZXVl6+D0MX5xDlhbHsLc/x+ZuoANXMgIcQXHJgULaM8cybv1ELDoDyroa+pXeDI4Zh9dAF+t0Bbhr+gl+4REConyoOlhAnVZMlqmd04/8DO07m0l/7XdE9TUS89RKvJ+5hwG5ByWHSqjthRzpACZfP04nTOGwbzbaqGji33ubhP56BENadp/tJcbFyKKRMvx/+yhJASIaX9pM1i+eRRnoiXr1XRgHhll54ROGW/uQy8ScT5nG6eyF3PvwKFoCojEdPM1OQxivX4Gjfz6NcvE0lieBa3ERoxZnEODmsAZbnwG9I5AVCIGnjxFbex3vp+7GrtVhqW9F6OFG/tL1+Ov7qf/8IgH9behHj+Jku4R5qTIa1RDiBsleFqIunuDKzJXUFLShSImmuAviPe3UbDxC1h82EPGLe/H8yf10XKwi9o+PoJiaw+Bbm7GN6HGdnI0sLY7SFz5hcMZMViWDau1CNJsPIk11PKuNFthXbiHz+gncvvHhRCBwlKvD3UEbF88hXSDbyiBn+SjM10uxDg3fNKeudzre8+pUx7FfMfLFedwfWMYFvSfL8j/HZrFQ0OlwyjDVtaC/WMSBUYtYGO84SH+1lMLgVKI9HW4n/1l0Zse6wDu+7M6VRIXcCPrsdjvDO45gae/G4weryY2W8vRYuOsbhjQeclipbCc12ZszvQreLRTSLbtVZulfGadXrxMnTpz8kzC3dGLrVyPPSLjtfrdVc1G/vQWRrxfShEiM18oAMFwqQprsWG9nqmnC0tyB+yN3Ym7uYOiDzxH5eyN0kSNwdXF8V8gRKuRgt6E7l4/LuAwAhy+uyaGxZmpup3PpE9i0euwmC8ayOmQpMTeuxW6HT4ohPP8i1WtXY33tj7i09uJ6/3LMVwqw9AwyWFCPMDiUwsd+RotRwSX3eKL3nEU2OpUgsZHw0qPYXPWkHv2EweAwTImJ+KaocLtjBp2bj7InbDJVQ2JyguCtPJjRep3eYSuixGjannqZ4NpyXvvJh0wtPkraYCvpbzzCwO/excNVRKNnGPXDrdSPmUVIRgS9Whvi7fvwVAiIqbhOybQlNGbNQ/DmRt6d8yw56hpmbXmdiMFm8u58kKn6VjTJ4yjY0U1IloxIhYnyiAwmFZ5n/PkiCsdkkG4SI9ywkqRFj+Gi11Kxch0RrVX4VJbg6iWkyKLi2uw78VSJ8fnzX9hcMYH8uHE8dGwn1RYLd/r2USQ1MW9WEOU9dlQtTUQ9OR3NKZgUBsfqHZkv74tn8Sw4SHtIChOunER9tQRJWADtUi+mHvqYboU3VZGR6OYvQ3PyGpKUVLICHKXhgi5Y2HaZ2rhRZPlaOXNejqlDTJ8evMqLmZjti1u8o+xYv+0sopw0XAM8IcATgasL6rc+w+2hlZzOnEOwdzULMr/8ACAVI4oIpm7nBS5OX4lxSMe4kpN4BHkgTYi8Ze5OCIOibujVOcqtLxlSWSXqRLv1FGEPLkb0ZVppa7nDEs3jG+6rproWRH7e7OtwQz57IgGacsp/8ylj7roTe48O7c5jVCy9m0yJCIXEsWRB06PhusmTH3y7s+D3YkcFLIkH2d9FP3aTmaEPPkcaH4niGw1Pt8ssjhy9gN8dM7jHH3pGHBZ1QW7/tev6n8QZ+Dlx4sTJPwG73c7wti9ulFRvh0AgwP3B5ahf/xTrxFFIokMc4rwf78PvrZ9i1WgZ3nkMz6fWIRAI0J24gsfDK28qS32FbUTP4Buf4vPrx25p/LBqdQz8cSOm6ia8f/0ohivF6M5dw9LRg2LmOAQCAXurwaYeRtnUQOS2D7BcL6Zr3T3UjJlH6vZTSLtr6AmPp1QZjfDsFWZKR4gqvISPWYOwtBBhVAiG2maQSREG+HFy/DLuXRCA7pWNDI8d4MiFboyrZ5MggnR/iPCwc+a3xxmRKPDdswd38wjFq+4no6OUwbBoNsVl432ln2XnS6l67S16Nx5GYjYy6pdT0by7nTb3cK5lz8bv+CaGPPwIVAkoPFaEweLBs/YCtEFujFLL6SuREXvlFO2TJxN5eA+qrFCKo+9Au3Mbq3PkvO8+h7nX+whID6Lo159gKq8jVmDFKpcTmHeJ/upa7Bip7Ozj4Ni7meKlI6P5Oqb10yku7GJpwQeEpwQiPneIw4ddmfvkTIwWyN+Rx2z5CHUfHWVWiwbd4AiT5dBU28+o2qsUpo4jsbmcwdP7sWl1iKeN41pENutXx1LZK+DtazDYBKtaqwlaeTdlaoe0yrsnNRSU1iJdsoEnyOf1hETuPwh3hWqIKMjH9cH1AJhbu6gv6ST711NvzANpdCiKdYs5+/w2Qu5ZQmxRIyMXC2mJSafsbB0SSRhhdS3MafwDyqQI5FkJ6Aw9tywvsNthaxnEeTsCoxWJUN0XQMtlMyN2LUeP9mH08sFDDmq9o5P5mwztP8v+zEXEKBxdzXZ7MhebPFi+/SMGLWYUT6znUr2cp8c6xutL6riojGVNKgj/Cy6upT0gEzmu+6b7Q6Nl6G87cJ0zAVnardXJm8b2q7FbrDfKu36ujq9/J/7HnDucOHHi5P8S+lN5yDISvrP7ViiX4f7gcgZf/wT52HTHU9VqBakEzQe7Ud29AKFCjm1Ej3Vw6LZBn91uZ+j9XbjdOeeW17NpdQy9swXXOeORJUQhCfVHmhyNy4RR2Ef0aD7ay7k6M/2tg0Rt/Av+zTWIr+TjmhbD6NkJzHnlpwR4y+jNGI1SM0BkXxOzBorx6GjBVllLp38ERqMVm0iEdVCDLC6CqsQcplWewnbiAsRHs+/pzcjumI2bDExWOF5v5+hLBwlrqqQhbhR6kZQ6eSAnFXFM7i1hJVU8PdrMyn1vo164kI59F7D1D1I+aionPs1jlzSZvSETWZW3jdEuQ6gDQvAZ6GTOhy/h7y3jfKudtK0fYLLaubxoHeEeIP7LJqJ9RZwLzUYsgjmPTqFl5xnWdJymNGcGobOyiPIVE95ShUEsp8kzhEGrBGVsKIWL1lAbncmaad7MqTiKcvM2fN/5Mwlt5WgNNgar21FWVpBUdI7LVSOc/espxl45hHJCBud908h4dB6d69aTP2oGYT2N1P34p4zJ8CFo/TwQClE9sIIatZDp53Yw+ONX8cq7TGuPkUDTIKFBrlzokrAuDcaFwPTiI3wQMYepEQK6LlfRFx2Ph8yO1769tM5eSJtOjM5kJ++dL9Asmo+34utIyWiBD9p98X58FRFf7GOT7yRO7irBvO4pst59lSlRQpKfugNFqC9ud85FmhaHyP3mVJbVBh8WOcqt6f4gETosy3LDBIzKCaQuOo2F1Sd4PNvxt34k++YSb19VB+cGleSmuDM53LGtoBOik/0Ri4VgtXO+cJCZUV83bhQeKiF8Wio+/wVrNL0Zvqj7usT7FZauPtRvb8HtrnnfGfQBjBy9iOus3P/8hfwL4Mz4OXHixMl/M1b1MIZrpXg+d2tT222RiJGlxKDdeQybwYjrnAkMvbMVWUY8kvAvGzIuXsdlQtZtD9fuOoYsIwFpTNhN221a3Y2H2sihs3g8tQ7t3pNIE6MYfPUT5BnxdFd3Iv/4SeIUSkJaqvB+9Tm0e05gbe9G/fYWh/SLWYxeJOTovPt5+ukcAt1AdzafXmsf+aJgNHetxP39jzBFp2JyC8drxw6Ufq50n7Iy0Kkh026F4yKujyiZEqMkv1KDf3UJebnz8W+tQ+PrR0VIKs8VfUxYkCvySH+Mb7+L0lNK3rK7SH7kcVymj8M705vWzflIw6NYteUnuGnVVEQmUps1jv1iD6YY3BgRyVg6WMCJxQ+ScnAbk7lAx9U6ShauxbWnm4h9u5CtnMebbQGMGVITKupHMSWHqmffwXyxAK/4UC7PWInYz5vov7zJYFsPWT4VNOiGSKnNx9o3gPevHkGWnULJ7/cy3n+I4lFL0Z26QrZ7C4rXXqJL4oFndiC9Qxb81HVI5RISSgcQ7j5E77gJeB89ztD4ZGyf7EcaF07T3RsYMssJc+tGd6mI1l0neaxtHx4yO0OBQdzlfwDLgArrkBalXsOwRUR7iwZ9nx0fbxfmVVwlY3wkLw77U3wexrQXkpMeTe54jxtzQWd2+NjOiYZYbxU/Sbub8MunyVmUjs/PFtL/i7exdvdj7R5AtW4Rmk17cZmWg+QbYuMmq0OKZnSQYy3epmKY/Y3Sa8jkFOZfreJSk4yAC60IlKGMDf56f5Maij44Q+7jc/lKN9xuhzNNdtYXfo7roqnoIyKx/HYH8Uo1BGTS1GthZGCEmalfv5f/DDsqYHGcQ1LnmwxvPYzHwyu/l66mTW/E0taNdPV/TJruXw1nxs+JEydOvicjRy5gLK/ju3Tvh7cewm3VvG911Ph79Geu4bZ0Ji4TR6GYmoM4LBDD5WJcpjo0K+x2O4b8CuSjb1U4MOSVYtMZUEzJvmm7VaNl8K3PcFszH4FUDCIRsrhwXGePRxzkizwtlsEJEymKyGIgJoFgkQ7FlGzMJTXozxVgae8BoQCzvz/X3aPRunvjNyObQDeHO4JmyyE8p2cTkRqIxGwi0DhA9PrZuB45Rs1Pf8nVwHT+OO+n9CZm4OLvTl6/lJgIFYNtg5g7+3Hr7cKrvoaQhWOZ5DXCY+qzYLeza+5DGIND0J24zNm593D8eDPRblZcJ2fT9uFBaifOYXnrWYRuCrI//RlaV3c6Ro1lbetpij1jmHh6J8P1nagvFlOjd6HuTCWm1k6Cjx3E0N5LkrGToF/8koDPNqNsbmT4ajmS515g6HQ+AqGQHq2dqPzzmPedwJiaTIJoCO3h8wSnhuDx8ApEvl60Z+Wytc2dquXrkBgNhJbnI48O4cTqJzBYwN9NQE23lYuKWEaNCcY+okdy9gJ94TGoRtTExXvS88VVrGuWQUggZ7rlLIwDcbA/qhWz+WjVL7hy71PIBTaCgpV4JwQhiQhCc64QXUQkKXX5aHccxjZxDFHmPmaNVHAsYgIZ/lDYpCe5roDElbk3Mm1DRvhLvmNtW6KvI3gTKFyY/OQ8CkMzsHb24Tp/El4/vhcsVjSb9iGQiNHuOIokytGKa7DAX/Mda/uyg6B7xLEt5BsJZklMGPLWViY+PJ3Sj06S5mu/cQ35HXD0Qje50VICI75uhijvhdH5x3GJC0OekcDeJhmpv1iDpbGVti0nOH6wmlHTYvkW1aPvRXmvIzMZ73PzdnNLJ0KV8nuLqevP5OEyefR//kL+RXAGfk6cOHHyPbBpdRhLajDVNDHw4gcM7zqGpavvlnHGkhqEbkokEd9uT/lN7HY7xrI6pGlxyHNSkSZHoztxBfmETKzd/QCYqhqRxoQi+DstVHNLJ7pz+ajW3JyBsA4No35nK6q1C5GEBaI9cBbloilYbNAfFY9sVBL60DC6fvY21uBAQtbORmoxYSqvR7PzKHadDllqHLIpYzhsC8fDPELmuslEejqeviOHzyNyd8ParyZlQSbKjz7F9ssnaXlzO74BKoTBgQyuWc3Ky1vxmTeWt5f8FHl/H0NmIY9F3INVKqPL1QfJvCmM3vMx9jOX8bt2haQfLmZehJmitw5QHZHKnqtD/KR2J6W5cyl59zC1kWkkll+mr7wFw7w5/PakBalRz9Ltr9HVPIDZz4/yhDF0JaaTYO9HFRtMc9ZEPr/3VzQvXIZcYGdjwmJ2rn6O0I4apG1tuDY3otSq0bsoMdgEnMheyPuB09HNmILgyQfoCo+nJSqVpNmpdDz+Ime9ksnrECAUQLNGgNXHhxa7ktHmdpK+2Ene9OW4Cq1YY6Mo6xPgJbGhK6lhMCCM+qnzcHWVIBYJSfvgJ9QeKeSLyEksT/y6rGm0OJomxoQJ0U7IJfHpFVjaehh86zNKA5IozJhG7dT51C9ZRUNALNJd+7k0fglaM/wwRceLrbs4JIxi3ydFGIur6dM5Mn2rUyDSE042QHE3PDsOkvygshcMV0uQj0lDIBajmD4Grx/fizQ9HmN5PQKFDK0J/nIN5sVCqp/jOg/XwsK/M44RCASI/L05XGZkxqQgGi7VcKoRDtZAZR+saDmD15KbPXGL9uST7GHGdeY4WoYc29xdhOxNW8RVtQszDnyIpL8P/cVCTPWt2Eb03+u++gqDxSG/sjTx1n0jh8//Q6/qb2K3WjFcr0SenfIfev1/RZylXidOnDj5HhiLq5GPTUMxaTT2JXbM1U1o95/BNjSMPDvF8eAUCdEeOHOTZt93YSqpQZYSg0AgwG6zMfTB56juWQxmC7pjl1CtW4T+dB5uK2ffdJxNq0Pz6QGHCPQ3AkJTfatj+wPLEAf7o2kfoLnfwpVOPwytIGlqZmbFCa4qo4gOCSC+5CJeB2sQxUcgy0pi6K3PUK2ahyAhhprXdiCNGU1OnIJ8/zgyA8DSO4ipthmByhW7wYh222GCpmfy+aluJvp6UBkQSeSZI5zJXUxioJienl76E6YyLceXn9hieODUXxlTeY7X7v0TrhoNYSYZgWn++K6dg6WtB+Pbu9H06/n9tKf42dX3aBkaYn9ABBvq96JzcWN02VnKJi1AUVXH3ObDKJfOoCu/jqEHHuXJgv2cWTyHfXVmnjj4I3bP2EBISw1ZAR50tZgJ9Rbxuwtv0KYX4dtQgVbhTsmsZShamvAa6OaFNa9i8vZhajgovcEohS/skSQainlNm0iyIJik1koiw21s9x7LmvARzmwxkBDjjtEmok+m5I7ywxRMWsxIRx+LmvdwqcYdL4kU91VTSX53O5ZgbzyeuBuBi4yQ7iYOxf+QgF6I8ABsVt7f1swdpRVI8zsY98wCRJ4qXOdMoLmwhQKTB0uPb8SQncXx9iAS9r2LRSBkTflehkwCao+aiVR3oMmOokwrY3hLBbaAXu59aAI+Ckfm60QjLIwDf4cZDF4yC0NdQ7h/Q4tOIBbjOjUHaWw4fXvOsTljOXcmQ9iXdtFdWof0T/Btlq+WhaQS3VhG9NpJeLy+md+p4vCS2vhxQBPaHodlmuFaGdY+Nb0dGoItLvg8vwK73VGO9VfAx8WwKA5C/eIYGkpEMWk0lq4+jMXV6I5cwKY3gB2EKiUIBLg/sOxbhdB3VjgC1L8v8Vr71djNZsQBPrc97u8xXCtHPjr5e2fx/5VxBn5OnDhx8j0wFlfj9mVmTSAQIE2IRJoQic1gxJhfztBftmFVa1AumfadVmzfRHfmGqoNSxw/H72IfFQSkhBHA8fwrmOYu3qxmy2IfL4uj9mtVobe3Ynqrnk37NzAse6u/zd/xZYYR/Pvt9JvEaNqbUY2JpVFl3cgampB5+vHHo0vuSN1DIVFMMrWg3TpdKzd/Wi3H0GenYJ4wXQ+EKSSKjzK5IpT+L7wCj0lrXifu8rwiB6XCaMY+eIcQn9v9BcL6V15J8pTdVTIghirNJLnEsXyQ+9iX7GQw4c7eUGSz6fiZPRDVsaLe+jxCuRHUyQMPvEun2csIidMhlqYSq3ITnJPHm1ewfxu+Dh+JVdoDk/kydrd9M2ZRmZVFQaJCwFDXcTHutNVOkTf6UJCdH0M//lvdGs7ScirJbOmmnq/GMJaqlGY9NQHxKBKUzFrhooj7TKiXvwdHd6TSMoJY/BKC+L2Zk6OmofB04fXZ0G6m55P9jTTnZmAj5sIL6GJKY0XUa6ZznX/eIrPXsK78VMi5iWg6W+nW2LjyuxFLKg8ht+LTzJh51GGjCOI86+giYzmmmsUs7dtpljhTvGy9eQqJKiL6+iSe/LiVDv5J8o58kk5cRItrYNhuOZkMW7pfORyAXY75P/5C/YkLGD9wmA+L55ITFU+U7e/gjk2gtTXHsXPXYCv3U7tEx/h+8enqKtXsTwWnjuewuTiY1j3HaNzwUz2VgnwdLl5Xd6Y4QYa/aK4eXWoA7WnP5ebrayZ2UOQu9+N7QdrYMFtbKLbNXDdLYLV6nNoRVM56pXOwxc3ohdIuF5RT8LcDPjy3hF5e7C5TsXqNCEWG3xUDNV9MGcMJPuCuboRzZ6TjvV3nqobJecb94DdjqGoioHn/4xdb8DzibtvuZ7KXsf3RN9br3Xkiwu4zp14m3d9K3a7Hf25fDx+uOZ7jf9Xxxn4OXHixMl3YLfZsGl1t3Q4gqMr12XCKFwmjMKmMzj09L4n1r5BBFIxIpUSq0aLsaTmpoYQl0lZDL27C7clDu9UY2s3nScL8ZJYkY9Nu/EwtNuhvrAV7W830R2aiiklmyi5gVEjfeiaihFcvYRNIUfg4UZPVS/GhRs4kbiQZ9NG6Lv7x8iyk7D2DCCQSrF4ebFVH4W2sRFFciTi/A46n3mN8NyZKO+dijjAh+E9J7H2DGAzmjCNSqdy1xWE82YiLy3lqq8P6cNNCGsb2FowgcjV07m+8VOKgicyu3I39rIKdGEJiO9/hmsz1jJB186W4PVENw+TvfcThDYjc1T9dJfrsLkpie1toNfsj9BXh1tTA9ejRpP02EpaNUaO6mLIHKihPHYyOef3YZo8nk6FN3WucZTPW8mMIx/TMWRmmzmapW5QqIaqOg1pQwOcv/fHJObvwdTWjVauIjHGHXmMYx3bk51nSe0a4MTZEvwsGlInRSNuLsHnselkCwSQMR5zeyyla54nwm6k9MnHUW05gPsf7kQW4I70uQ24XivHMjkV64ELeFtG6K7uhLQgwnZuob4pjI5tJwhWidD/bSuZqbHU3z2LJwo9UOth0xyH4LPRAp8faCRQIME/MZhzzdA4ImH93WMoqyhl48T72K76UuD4XAEhY+M4r1GhkjkCqa3LBRzJmM07m88TWbQflyWLuDNZcJMkSkB1CSdCp3JzARY6hmFzKWx4eDqSL05A7Kob2wUCCPy7W8FkhS1l8OAoMYN5SvaeVrNiTTbBqmysfYPUvHeYT1OXcl8muEihVQNKGVT1w9lm6ByGt+aAiwT0l4sxXC35Tj9r/fHL+L3/Ar0//APyMam4jE2/sc9ggQO18MOcW4+zjeixdPXe0gz1bZhrmxGHBf6HPtD9K/Pvn7N04sSJk38y5roWJN/jIfEfCfoAdKeuopjqEIvV7jqOctnMm0pWslGJGK6WIEmNxW6z0fT+QcrO1HDtfBM9KZkUdsFHRfDOsSHUr27CLcCD6S/fw/w1GYS52zFfLMBjwx2EHHiH4M/fYOil39K67E5Ueg1+rgL0O47gtm4Bw1uPYB3SonF150CrC/oL1xm39S+E5l9AtXo+fTEJBOn6qDYoOFIHDceKMfQPY3Hz4EixHreVs1miLkQ8dQxnw8fhqh/memQWOXs/xvbOJvZ4jSbzyhEyy86hdvMmd5QvtgA/Fud9jtuImg0n3yXskw8xjxhpi0nhSEAOEc2VhE1KRuDlTteseViul2PXDBM9Jor3SsRU/nk/2Q1XidK0E65uJ/SXD7AnaS62inoizIOsvLCZZq9wxtFJmKsFlRxevQzCXfu56p9C3IVj7CaGYYU7ngmhWOqaeTLNSLxgkOYrtTzvNpP6mHRyu4uRDqkRyKQ3/W26cUHa04N86Uz8/vgnch+cxrsN7miMjlKpLDUWU3k9/Z0aUkZaqX/rLfbMvI8zE5dRdqaGEHUHkW88hecz9+A6K5dSsweuEpBLHJIp/Tp4+4qVzPwTnEydzaDe0aDwxBjwaaknYWIs02JEXGx1lP0Nl4pIWTmOEw0OQeFID1BIYFE82KdPJF8STOCuHQQpLDfeg91iwT44hEeIN53fMNxoHHQEcXelgE+EDwKpBHNLJwAHa2+f7dtR7ugYblLDaY9k1ljKb5SCtYfOEX3nJJbEO5pMekYc87ZpyGELlxUAyxNBLraj3X8aU1UDHj9Y/Q8DLcPVUqQJkUj8vfF47C6GPtyNuanjxv7dlY7r/PsSL4Du5BUU08d+67n/npGjF3GdOe57j/9Xxxn4OXHixMl3YCysQpYe/996TrvFgrm+DUl8BObmDuwm8y0ZCFNZHfLRyZiKq9EdvUi9bySevkq0Yhf+tKub3ZUwJdDEXec/JsxPite4ZHRfnGfo/c8RBfggG5WE6q55CMRibHZHBqQlOpV5QyVI6+owCCXIx6Rjqm6kpa6Pjtpesk2thFcWoDRqaXX155DGh56iJgaae7E/8BxR77yB/PIVLo+axd5rw8jCAohJCaCsZoiDw/5E1BbzYYcvL6Ws5Y/rXsLQq2ZUXzWZzdeRazUEifQYKxsY9AlEJrAiq6rE1NxB+nAjIZXXSTu5j9y8Q/THJ5Nn9eUvK36O15Ev6JZ7sn/iXZS0mfEqzMdHZGJCth/KpzdgFon52WAyyy5swa23g2RvK5Z+DfOOb0LnouQv+kMUtNsJlZtY1HkVWVggJrsQr7wrRClMSMxGstU1XFv/B1LeeQWtixsP9p9n0oU9lISmM5JfgbmpA6vaER1Zh3XUP/s2rsmRFPcICXpkGe6nz7DKVMrfCmBQZ0P9wefYXF2xa0YwL1vIhQ4xIe11TD2xBRdvNwJffQbxl+X7wi6Hf6/RAj+fAKeb4IVzkFN7hbN+6dhdFbwwBTQmR2eq/mIhwdMzeGAUXGmH3q1HUS6bwdUuEUIhjA916NaZrA6h5ZnREDJ7NFWhKex+ZitW/ZduLhUNSBOjyAmCa1/GTCXdDlcVm83RFPHKZSjKmoZ69ylaNY7g01950zSloNPRnNKpdfy8ZmUcoqoax7+Vehhb/xDS6FDCPWBGBDx40BEg/ij3aweQ3GArmk17QSBAtX4JAtG3m/HazRZ0J6/gOmcCAC4Ts5CEBaHe+DlW9TDVfY41iEm3KfHazV+61nyLm85NY202jMXVCCSS7935+++As9TrxIkTJ9+Buakd5d81V/xXMeRXIMtyyLMM7zyK+71LbxmjP5uP+8MrGfrgc2wuCq6pI5g1NZcZo4OY8MF2Buato/TlXQQNmQiL9sXWO4hiwihk6fEMbz9yU8fipVYIdgNxjxqv6AByz+/h4obHGfXq27SqAvGoqcN35ULO9irozxxNxo5K+M2P6PKJY1/MeB5qPULI2kn0/+kjLBYrhsYOEnygYtZ0tn+YT78kiWlnd9CrNhNqHkZZvJ+BgFCKpy0m/fAOVOo+enxDCdQPcz16JqEV1/kkfQX+wz20zphH8LULWFNEyEc0ZF78AvuSOQxFJjG97AySES3WtCSmPL2Yqld3MntRCn3lcGzYTvvbZ9GnpvCbfb+ms7oL9YJl7IyfhsJFzBqpEYNFgmDzbh6Ia8DQ2IlOJiLjzD7qbCpcg0NoSJvI/LvSKX11F/3NalRJ3uwat56HsiDw0y20NgxQI/PHx2bA9tALeP/8QSo+PI5SZEWv9EYSE07EsrHYLWOR7D7BqsZCrnygJsXTQqNRTv78dYy9VsZU1QDjgmzU3zmb9je2YewNw3r0IiMCGVeapOiMMiZ5ysiwSBHYZdgHhnApLUO05kEeiXXo8ElFILGYsGlHEPs6gsY75K0UtVkRSyJo64ffTIF3rkGmv6NsHeUJ1f2wJhXCJiTz4TYZmx77jDteXoXoagmui6YS4wUflziCvopeeDDLIRYtEzsCqPwOT84PKsnf2cLaRTd/OBnQw8lG8FWA2QYbMkAgkCCQy7BqtIwcPofrgkl0j8C+anARw7QIUBuhph/qBmBemBHNOzuQj03HZWzad947I8cuoZiag0AqARxrblV3L2R4xxH6/7aT/ePW8niu5LbH6i8W4jIh81ubQew2G6bqJgxXS7B29SGJj8DtrrnfeU3/TjgDPydOnDj5B1j71Qg9VN/6oPjPor9QgMejd2G8VoY0JhyRl/tN+y29gwhkEsSBvugqGtkz616yWy8xdtZ4BGIx4gWTsfzuFXJ9PWlLyaJ2y0Gkzz3E6LQ4bFod5pbOG8GqwQKX2iCyvoScvCMYvVzxcJPQfKSAyNPXEfgH4ernzmmdO53jxvJYw34E/ioCcsIZPlbMQx3l2K+XsX1Aht53DEIfC+kFF+h/+ffMiBHR+ufzePq5cTB5DimFZ9meezeD/XrS9a34FRUwpuYSYv0Ig7njUZcW43viOH9b9iz+UzJYmqLj+JMfMGgWMdnagl2ro/MHjyLavpehq80ovN0ouO8Jlp/9lGK5J5GjIxjYewYfFzvdw5Aw2Idrt4WyViOfr/4loZkRvJgFvz+spdcriMyfreLs7iISP3kXpcGIwTeUYoK4OGUZY6suEHfXZP7cANPuWkz2z3/BGXksKhmca7bzCMOIA+QMP/wYuo5OWn/6B7TLn8HmE0DQpBQKB6RMGRfIyJELmL+UGhEXVZKg1pHnPwXaWlgZ2YD18llMwfFIlKHEfL6F7thItvb6sjJVyKEiAxiNhJm0uHUZMeUZGajWs7jyAuboSLBZCVGJudrukFIxXK9Enun4wGC32/E4eozzo+7Evwcez3Gsv4v3BlepI9APUjoyf+Eejnl1310xnAmU8fkPPiYnyE6X1pvSBmgbcpSHP7nja2kZcPw8NgSa752K5PU9XMy8h9NNMCXC8TofXHcEfGn+kBX49XHyUUnoLxYy0tLD/pT5aKu+9MgVObKQP5vg6N41D2gI3LcD16XTkcbf6gn899i0OkylNXj++L6btksighD5eHJV78WCa3uQT1oB3HzP2u129JeK8Hpuw83bbxPsuc4cd1uXnP8NOEu9Tpw4cfIPMBZVI8v47y3zWtq7EXm6IxCLGDl+Gdd5t3YX6k9dRTFtDB1HrnE5YSLpddeIjFDdkG6xDY9g0+qwa3X4nzxGypN3MBgaxSfv5lO76xKus3JvBKuHa2HqUDmKsjKCnr0bbX4FtQWtjDuyhdKcmYSY1dT7RePW38O0DHeMX5xnUOHBud/s4GCJiUu5ixhau4bs0rN4LpqIj0GDT4gHC3PcqXvxU+QmA5eXbCDCW4TU35Mmq5IxWb6c9s8kqrcBjcyVyrhsAvrbqPcMRYaFFYsjWZoAf8wTM9KnIafsLF1mKbL5U1Dt3U+DayAKbzeGrUJC0WINDkTb3k9wuDvR1y9g7B4kyNVG4K8eQHy9lG2zHyJmdAR+rvDqFUjXNjEYEkmMF4TNyKA2OAF3gwZZWxunV/+AqPoSeidPo7IPfpANGcFCRFFhBJfmE6ywkKFrRWg0ErtkDFc6BGTKh/Azqrk0fjEd4ybTczyfWF8htspaxIG+qO5ZjDQxCtcZ4/C+ZwFZLYX4uthwsZpQ/vk36JOS8Hj4TsTBfuT84k6GYuJ5oSeWPS6pVMaORj02F+usqexJmENWnBueT63lavIkZucfwG63U9ztCK4cmnupAOhO5ZHnk0RWvBK91ZGdA5gaAXkdMDcGTjY55Fu+yZQpoUwf401NwzA9WjsZ/pAWAKkBNwd93+Rwr4qsDB/WujSwJhVqB+Bnp6B9GNam3Rz0GSxQ5R1F1Ru7ORQzjXGh8HAWBCjhaD3MjnG8zrqAPpZd3oZq/eLvFfQBaPecQHnH9Nt+EOuePgt5XS2hSQGMHDx7y37j9QpkaXE37iFzUweaT/Yz+KcPMVU14DpzHF4/uR+3O2b8rw36wJnxc+LEiZN/iLGsFvf7l/23nlN36iqK6WMdArIzx90oWX2F3WLBVN9Kf85YSg+fZsYf11O+7gWCFzquw1TXgv58AeLwQIz55UiiQlDNHU/Ouzsx2YW0bb/ExviXWKyxIzXqMZ4rx16WR/ikUVT99F16VSH4NdUwGBKBn1FNj1mMzsuX4KYapD/7NcMiGcNPP0FkmAr/Yj1r677AgJhTMxchr68nmw7yRs9CveGPmGWeaJ96FBMyxlWd43H3+UyLgMttsKj2ODH1RWhkSrqffBLTn/+KUQCldz/E6Fdeo3bcdCYcP0uovg0zQpQdbVQcl2I1iInK8MSzrpqdcXNYtHs7G5f/gHEfbKK+txfZ+Amkp/hySJVGzE9f4f0J9/Hkhlhy5INoCms4+kUd2hEzJePnsbMcTFY7ST21WKx2TmbMJW33p8i9lIydFnojE6bee4rK1ffh9/5GHu47w9kaIxaDGcXoRCb+aRdN585y5ZGfMa2vCNdHVtP+exNxv3d0YFsHhhh4dROGa+WIo4IRq5S0BUYTvGwy3uNiKOwCPy4y/PlxXOdNwkslwGJzNF+sS4PyL9ekXW0HeXsb9U29VI6ZzpBcgFdLH8NHLqHzHI/SNMKQUIhQqcCs1nL1QCkuP7yXuVFwrR2O1DuCPKUUfFzgbwWwNMFRIv4mVo0WD5OWmctT+KC0hQLfcF6aBm/nOzrE/z6malI7zhmwbArqv+3A89lIZkcLqO6HB0eBu8wh5VLa4ygrSy1Gsg9/jm9qGPfO80f4pceu1uRw/Ij2dGTSNR/uxuuxVd/pZ/0Vls5erEPa2waJRgvsbXfl/pmJCFwkWOpbMRSUI89K/sZ9l4fb2oWMfHEeY1E14hB/XKZkIwkLvOV8/5txBn5OnDhx8i3YzRbsZjNCV5f/vnMaTVg6+xC6KzHVNeO6eOotYwx5ZfTGJlP7zkEmPT0fmUyITS5H0NOHtV/N8PYjCGQSBBYr8nHp2AY1DG89jMjLHduJywT6SJn30nNURKbS22sgy9BEk8gLy8f7cJGL8PISYxJLCS4vQGOTMugfTFjBReRhfmisQsLWzMQzTUXx8UrGHTuPy30z+VwfibWmkSm/fQblgom49nVTbVYxLlHFAbs3C5TdnOyQYRzlwYAeplmaWH31Mxo9fbicMp3Bo0X4qEK5c7Qcl6ESBvMK8GgbpDggmUGlmGB7F2JPFfKhQT647w+svraVHdMeIrPiEn3DFia//ycM3t6I4yNonzOfvnOX8Lj8MSfCxjBb1Ib6pXwOuXtQ4ReHesI87t/4CxbNW8FvroPP+bOkDg9TMWYm+oERChfdySvzRIi/rK5b2rsxWWxcFwbwxB8fpvv+X5Hl4cXFjClMfOYVwi02fnXfyzyYK0e1uxB5STHBYj19v3sXU2EVVu0IQhc57o/eiSwuAkFwAFeKpDzzZeNo/QCMnpiGfttuVOsWATAm2LGmLj0A7AJHc0eIi5kNnYcxPbKGl0oEZAaA6+Kp1Ly2k5S4avRXBpDnpGK3WCh/ZSeC5fOYEuVI0Y0Ogj9fg26to/ki1B2O1X+dwZOKHIGmixiCvjiNLWsqHQIl4y6dwXVGOO/kg6fcIbMSdvOqAw7VOjp8hS6uSCKDMZXU8Jk1nkgP2Fvt6EAOcnOUoqf46dH9bSvKe6Zh7R1Eu+uYwzZQJOJEA8yIcgSe6r/twP2B5d876AOHrqXbituvtd1TBXOjwdNnLAMvfYjH46sYencXIl8vxEG+aA+cwVTZgHbHUVzGZ+D5o/W3OOH8X+H/5rt24sSJk++BqbIBaWL0dw/8nhR3ge5sMf4p6di3H0W1fPZNJas2DdjsMHSkkPKgZGbnBuIe6U9BYS+uaXEYi6oxldchiYtAd/Qi7g+uQH8mD1F8JCNHL2L39KDE7suEDx7Fq64e87v7kBq0XJVFoJapmDcmik6vYCgoYcxnL1A35zHkBi2NvlHI73sA3+pSoirzcZk2BvXfdtA2pGTsLzbQZJGgOVjH+GNbkQd602BwoUttJVRlp1wrY3YUFL1yjqPRU/BxgUChjkXbX8OscEUt9qU4cwq5R7biG+DK3tAJCAqPUf/zjcz/8/NE9dQT0VnHrrt/QpNOwoKqYzxUvYejbnFUeMcx1nYOW2gIej8fEq+coE9gZPRID92FDagVntw1P5S4SXGIQwOwaUYYfu8a95ZdwCwXIdq4hfAJG5h2fht6mSuv5NzPlIZL+LXU0mIbTRSOTFHb1lN8HjMTgR3+VqUgPjAB/2tXiHNzp9bNH8lTa/GpE6McUWPVjNB17y9xGZeBYs54FDPGYcwrxePRO28EEtc6IMP/68xZ+zDM6G7HpHIFmw2tVcT1Lnh0NHxW6pAdOVYPz7YfQ7VgIrsHXLk3E4q6oEEtoHDaHYw9+Cl60whev3gIzcf7KY7N4c6xX9sCCgRwZzJsLYf7MxwZ178tcMiZ2O2OdX56C4x0DWIwq9GlROAvBN/rWjzc9ASnufDmVdhVAU9/Q7mkftARECqlDueP6siJWP7yGdcWxrEqRcAUX/B1dYy1aXWo39mCcsVspNGhEBcBwODrnyBfv5z6QTcWhugZensb7usXI/bz+t73jrGyAZGH6rZOG7UDjveW6g8gQrl4KiP7zuD+0ArUr3+CQCrF3N6N1/OPIA3/flaK/5txBn5OnDhx8i0Yi6pQTB/zXz6PxebQORMKIK28lKqxUzB3iGnoCsJ90BHstQ1DuDt0VHfjN+DCbGExHvc7FqF3ni0he3oalo1bEbrIMVzah/fvf4ilsR1xaADGinrMTR10yn2w/OA+PiwWsF4/xOC1KsQenggU/oyfFMIXXRJi/7oFWXgQl574Kz52Oxa5ApnKBTd/FVzsx0VoY+TgWVxXzqGmL5Q5bvDZu6Uk5Z0kSN9P3k9+w+eaQN6eMEL52l/j2tjAxd2J6AbMxM7ypX7ATsr7byK22yj1iOKybypZbcVE+IhwWzqZxs2XEa9ajqmsEztCoptK6X/2KX64bhTHnvyQpNd/iPzSFUoMo1l5/XNiumoYv+93qE0CqlbU4FdVSneHFxWh6YS8/hxJiQ7ZG/X+rXQP2/FLHo0qNABjTCj6C4Wknv0FgvYOTv7yDZbEivnBPePZ+/hHvPSOjCmew7g3N+JbXUpkg4FxoeCqckE0LZrrHkpCw9y5kLMYzfEGfjhwkdp3W/AQDyNLi8P/vecxN7aj3XsSzx/efVP26HIb3Jfh+NlqA4l2GHt3H65zJmC4Vs4mYRp3JjmCMTsOOZeY/gYCRHqK/RKRDEFmAMR6OTpz7XYJS1fPpO/HrzFy8CxalSe22ORbNOp8XR1l1Jcvw8qkrzXsBAJHh65MDILjJ/FbPR3Jl0Yc+txM9JeK8J05jl9OggcOwucVDv2/piF48wqEuDv8fmO9ICHMhcLoaF72r8At4usyqlWjdfhDr5mP5BvBlcv4TCQRQRT+cSvT5k5m6C8XUa6c/R9aQ2e32xnZdwqPx1ffss9kdWT7Hs/+epssOQb9mWvYhx2BsrW7H+3Bc86g70uczR1OnDhxchvsdjuWjh5EQX7fPfgf0KeDN69Cgg8sdWklON6PsWVnmPnkTFL9HNkgsxWSfBySHZOarjDTe5hTo+aytVLMkAGEtY0oWpqx640M7zqGOD6SkUPn0Gw9zPCB05iKqnB76Vn6pB7MVPQybe/7FP/4PaoTsvHUqUmvvkJQST6RR/Yiz0xiUOmFsqEWvVhOl28YbjYDfZ8eItxdgO/rP8bruQ00e4QS6wVnP72K39WLJHpYuJo6lXfaA0n1g3OXOhEsmc25qXeS8O5bJLWVMXrbu9y58dekqusZEisos3pRnTwWVUUZLq5S3q5QED3UiseevSh6u7D5eSNxkVJvUvDaq/n0pozijRIXPjTGsuTkx6ilSkLuW4BWbUDz41fRK9w4lLOUqrhs6u9aj/GLM5x/+gMGK1tRrVnA6ZlrGD0+FN3xy+jGjuGtO35CWHUxYncljZ0mnssFpYuI5b+7g1C0JCd6M8ZNQ8XjT5P7xw0EPrwMaVQwlpZOElL8ybP4EPrR+6RfPYphcIShxYvQx8ahunshIwfOot15DI9H77ppjWbvCLhKHF214Pj7JhacxnXRFFwmZlG4p4CMAEeX7alGmBIOu4uMLK8/gXbxAq60OdblgSPLluLnmEOm0lpcZuYyvOMI1xImkBty61wDh+dv/aCj1Pv3WDp6sFssNwVm8tHJGAvKsduhX+8IHE80wt174KmjoJLBs+PgBzkOq7czTTBhfS7mE5ex2+2AY42j+u0tuN+z6KZzf4UtwJ/zc+/G/733ESoVt9ivfReGS0XI0uIQKhW37NtbBbOjHGXsb6JcOYfhbUdAKGTk8HmU87+fPdv/BZyBnxMnTpzcBmtXH6JA3/+SjEtRF2wqdnQ9jgp0dOpqJa5cVcXybo0Sdxn8dLydh0fbWZFg5ZEkPfHd1XglhrB+UQjZwfC3g90YDBa05wsxXCvDZWIWArsNob8XppomrK3dyHJSKSodIGxeFh3Ln0Sw+zAGi5305iLas8fjec8iKqr6cQ9wJ+rJ5dQvWoFNJMEqESPxVBFXX4TZYGZv+BSO6APo1Iu53mkn+sopas5Wkh4mpWDYjc8iZ/HbjEHSq68ifvsDevNq8G6sQWozM7BgPupBA4k9Ndh7B/gkeg6SIB8imspxGxqgtmmYu/e/hTE+nouz78LNV0WoppOLj/6UmM+3EF9TQGtSJsaaJvxOHafqjrsJGOnjUJeSfT/eRkmVmpKQVFrvvZ8BmTuTy04wEhDEqcX386x4GutOqygp6CDvV5vZGzeL3Wo/0s4dRCYVkx+XyxRdDT2vb8bS1Yc80Jvxa8exq0mOWuWLR0sjLu99xPC2wwj9vZEmRyPu6GCooYtQNwFj1k1g18JHSRlq5GLmLBRTsjGW1uD+2Kpb3FrOt8CEUEdQVzsA9fVq/I1qpLHhlGrkDLt5MsbegdECzUMOHbsJeYfxWzGNz+rkrE8H0TeezDYbKCR2es+XYFNr8HhmPbbP9hDrZb9lvhktjvV4P8qFzytvnY/Du0+gvGPGjd8H9HCxU8wleyDv7W7jbDOMDXYIQK9KhhCVI4BVGxzjzzY7dAHD/WTIMuIxXC7G0juI+q/bcb9v6bdm8S4125h6eS+eP74XSWw46jc3Y9Pqvtc9ZDeZ0Z25hmJm7i37vmoWSQ+49TixryfiqGB0J65g0+r+V3fp/kdxlnqdOHHi5DYYi6qQfx91f7sdm3rY0XHYO4g0IRJ8vdldYcc+oOZhaRv2Q+0M1LfReaYYnclO3NRRjLlYD8DwV1pjQgF2swVsNtyWOh7OsV6QW3EWt84auvK7ECWlEPzhr+ld9QzDu46hunshPr99HJ3RTu22Gnye/yMtUm8MuVkEz8jkfGg2bcMCDCHQWWVlrr2Bsj9t564oV7oFduonziGgvBCT2c7EkTq+SH2E2OYy8q4LaL1UR4fRSIzSzgm3RMyqXp5o2oek2YbaPwZRdCxvjb2Ph1qPcD74Hkz5LUQMD9Lt6kNhdDYLL++kMmEM8y59gSzMn6rVa0i9uJvCGQvpbDTx1KmPkYxJZeG9Yyi9epap9jZCPQZQ157g1B1rCRLriW2vpEHlSliMFy62Pha8sYR79gl49C/3Eut9899h+4fXkVeU0nbnGqamuiIwGdHv+phz8+5GNWc8WcVnOJU+m3nbj9Ai86FT4s+YDz/hZEwKCQtj6JyzGg+RFcPmnUhFAup6bSTkhHMpdikZY8XcVdrO9XYLpnHhdNsh8JcP3zIXrDYo7oYrbdChhewgsG85TfDsqZi64EQDPHZPLroj5yjOnkG8u5K2y/VEeInYZolhcZzDo/eb1A3CI6511FV0EPqrX1FvVuIRo0Z34AzKv2sM+qrBIdUPrnc6As/YL5fRmRvasMhcKLL5UFnisE3zcnG4W+SuykZ06gLuKSHY7Y5Ssd0O78x1BIebiuHOJMd7+8r7VjF9LAMvbYTTebg/uOKGoPTfY7Ha6d+4h9HzknAZ5dAflESFMPjmZtxWzXWsBfwHjBy5gGLGWASSW8OVS60O549vQzl/Mr3PvYr7Qyv/4Wv8X8MZ+Dlx4sTJbTBVNNzi52np6nM8QLv6sHb2YtM7UiEidzdEQX7Y9QY6Pz1IW9MQydGBBOfEIIwORZKThlDlSkuThckPT0WRebMuoN1qxdo7iObTA8jHZWDp7kfs54W5qw/3bTvwSgpDsng8bXIfDv1iF8l17ciC/WFMJhfaRGwpg6RrJehsIny1vZz2jmTmglHUNAgo7YFD14Z5UtZLjdqFuMFmhjdVI8lMx0fbR1hWBO1H+3APDWHs7o0IgtxJa2zGrX4ArQnyc2YxZudnlGRO5WzuDIKjfejccYq6iImESQ0M1nSQHzeJZ/I+QO3phyYwhHFlZzDYhXSJ3PCJTaZn+kzSh5oQZqXSr4NfV20mONAF1/Xzee+KnnClN71tfQS+8w6Djz2DqHaYpUffImjFGPLbfUjPP0jGez9iS5mAaE9uCvrsFgsDnx5G0CFh4Wtr6TUIOVoHrr94mRBXd5rmLuHnM2SYhVG4VVVycubd2MuqmTVcyfEpkzmQsoS0sdBTWov+w08ZUbrTPno88vVZLE2T0FEK9QN2PL84Sug9K2gdcsimbMi4eb50DMN71x2l3QQf+N1U+PDUIB52HWk5ofzxoiNbdmpQRfLZ6wyXaBAbDWT3t1P02HOEqBx2bN9k2AhyLNje2ojowdWc6lPSPgzzV03AumMvurP5KCaPBhzlXZ35qwYHWBgLL11yOHbU9IP8/VN0LV5CjAlmRYOf4puyLX4MDA5hMxgRymWYrRDj6egI9nN12Kz9/DS8PuvrYwRSCW4rZiPy87pFfPzG38Zup/QvhwhMDkE5/ut/MGl0KJ5P3M3QB5+jc5E7TmqxOD74WKzY+TqbKZBIcF045TbndgSiT/8Dy12BVILP75/4b+3K/9+AM/Bz4sSJk7/DpjeCSHjT2i2bwcjQuztQTB+LLD0O8ezxNz1Q7DYbtb/8gGMLH2HpaAWqwusYS2qw+Xoi8HSj78BFRP7RiF2k6E7nOTKEPQPYLRYEAiECNwXmhlZkoxLRXyxEn1/OyNFL2KRSbF29mM1mAr3c8RrW0+kbTFlwKoH7S/G8VE9uVR9jy04jiwmlZOE9RDZV0fhSO9Z+KQnh8aiuXEI/WIdHahI2jwhEVwoZMNiI9dTQ55GA7GePIvhoC7K+SmozchiJykJdlw9uUnLO7mfAwx+bny9laimDOy4xqfoqoqensqHrNNtFQh499BrnM2aRWnMNsUHPxeSpeKi7WVR+mLC5o9mbGkfM3g85ljKLMXs30aoeoNMllFMtMcSfP4J0Ri59+zV4N9cRIdAw9sKHiMJ9OKP1ZLSuEX1qKm0iD840wx++keSyDmoYen8XtQnZBE1PRSgCX5GR8Vs/ZKTkEmeWP4xJLKN2ABJnjMPz4jb2XWjhpfviGHzlIk2TVpGoH0H60geEdTbj/tgqFJOybvKJnRsDRzZfZ1FqLKPi3WirhnMtjkyYl4tD/uRQrcORwm53uF88ng1SMSysO8WLKdNor4Fnx9mJqC6kfXs+eSlTKOyXMEbRTVFkIoH795Cb4YlFOf6mrtWSbjs5Z/Yg8vIgZ3E6r191lHP9lALs9yxC+/lxhjbuxmX1Qj6vlPDQKEcAWNIF2ytgQOdY+7bepR7PHB9UU75dOsVlbDqGKyVIJmajtzjcOL6iaQjS/aGsFyaHf71dmvCPRZd1J69SZVCw7M5bG6SESgUeT9yNbWAIJGIEIpEjqycRf6/lFbUDEO11c1n8djiDvltxBn5OnDhx8neYymqRpsTetM1wuRjFjHG4jM+87TEDBbXke8by4AJ/h2Bu8GRc50/CVFLDwG//hvZaOWHZIoa2diOLjUASG4brvEkI3ZUIBAI0nx3C84m1iCOCGP7sEBhNjKSn41JcimL2FCxt3TT7hGMuqcYr0ps57sPUVvWhKR0ip6ce19w0Rox25HYLcX01nJVkIdKrGXv5r3h0tVERN5qc+ADaPzyIW3Q4PsXXcZu5jqIWI3PWR9O7U0Rd8hS89h3CnF/G1THL6IlPZVrBIUz33c2owTqWHXsXc2snnXYFGwSVNL62nRBXXz6653nu/PQPKIxajkz9IQHlRVzNmsWcIwWU94sI6mqgv7EH94aDqH39yRpspGj2MsQaDX5DnQgu9yJaswT7uXOoHniC+hmzEZ3KRzfeD9yVHPMdRcdRR9PE9gqH7I1fVxOZecfQLl3MkRF/lpjg8pkmPH7/Err6DrQe/phdlTw4CvI74a2rAuInL2HZwU8pO5CFUe7HzJIjcOQMVaMnk7DlByikt0YR7jY9weUFdC25l2gcAsmlPfB2niMbppDAsgRHZ/bTx+C12Y5uWktXH0MaE6ETA+mq6sB8+CiXQ6Mpm3cvLcMiFu16C3FaPC+5TuXheaATt2PfexK72YLr7PFI4yLo3nOW8Uobiqk5CIUC4r3hdNOXIssiEcoVc2i7VM2ZpzZhmrGETSJfgt0c+oC/nOgI3t4rsCM8fxblk6v+4ZyX56Qy+ManXA7PZl6Mw+LPZodGNfTqHJm1v+RDqMqRufwuzOphyo5X4H7fhlsEpL9CIBAg8vb47pPdhnPNsOS7V2I4uQ0C+1dtOf9L0Gg0uLu7MzQ0hEr1/YUhnThx4uQr1O/twm3pdEQ+Xz/hBl7aiOdT625x2fiKYz/9jNiHFxEZ7nbTdrvNRuddz3FiwjImb3kbxYQsbCM67DoDNr0Bu8kMNjvW7n5EYYGO735eWDt70Td1oVwzH7f0WGo3n0LW24PKXYYgJ5MyjYywwktoLSAVCemdPI3S9EmIIkKIOrqP+rIuwjQd+NlHqBwzneRQOdotBwlVQbtVTmhzFQa7kO0/fYcJF/dxOXkKkw9tYmRQx+X48Sx5YSkd7+4hfnkukvAgtPnlmK+V8lnULCblHSIp7xSF0mAK1j7K5H0b0SGhMjyNhMKzbF7+I37ctg/3ETXmK4UITBaqghKQT8pGkRqN+5UrvJS5nlXFe5ipGsBtymhE3u4MvvkZ+uIq2pUB+KRH4q2wYx8x8ud5T9Onh2fGOuzAyvdeI7y9hvycOQzXtePWWEdWVxkeBdfo9Q+jwDsBf1cwhYbQqAigJmUcCglIRCAtLmX5py/SGJ9JoKudkjvWkUcAoSqHJ61M7BA5lovBRQKKvQcxJSZwSBzDI1mO7V8Fec9PdtiQ2e3wzDGHf+2ieMfvTW9sZ5PHWFwLi4l20XM2cy4bJqtID4C3rsID/l38qd6HSB8xUyIcAZ3WBBPd1EQWXsTU0M4lgzeTw2wol0xH7OfFm1cdJVq10VGGHTY6MovDnWruq9iNy6TRfCROY0akY+2e3Q5/+7Acj/5uVv1o2i2OHH/PwKb9bHHP5tElgRyscWQvv6iHx0Y7mjx0ZnjnmsN+TSX79vM0DkLRn3biMz+XCeODv/N1/6NoTbCpyOFN/H+F/87Yxpnxc+LEiZNvYLfbsQ2obwr6zI3tiIP8vjXoK6/VIBFyS9AHoNl6GLOPD36aHuSjknBdOBlJaCDi4K/PN/ThHiSpMfT/4m2HC4SbAkFwAP0yLzzsVpp/9xFKgR3P+eNpTR9DicGN6S1XaFy6lKDPPiV815tUrv8tWxIWsODAHqSafjL6q7k4eh6avmaUy+fQceIMFRMWs9HuzdMHXqR15Wo8t+3gjo9+z77VzxIx0Ix3VSnnc1YwaaianodfQOLjxVmzP5F5DYTk5SNNjmfZ0U+IXpBN/bEjqKckk3p6HwKxmGaJL9FRKuyN7oS2VuNTVkhXrwGT0he/nhai/vRDvhjyIvXcIfznT2bKYB9BtaW43DMZm1aH/koRFi8vzt3xAOP3byL8p79Ef/oaLWdKGVV6jk8DJuJjHeHge+eZXHcJWUIEsc3HOS6LIL23EFtFNXmrHqanqh0fBUTJR0h/dBynN+cR3HgRfXYWC1vOIy45TsWs6Vjl3vTeOZ9xvhJGGsBqd6zNk4jAYHYIAuuauzH1DNM1NYauRvig0JF1NFodpd3XLkOwChoGHR3cIgHsr4HQ8muklPfgYTpG0vrpeKRFsdTkyFYaLI55sVcbgFTmCBYjPBxrBrUmuNDiwbGI+fjEmAlS2jHs/4w6ixfX8h3C0FMjHGXOp8ZAoBu8cRUem+2BfPY6Dr1yjJmqBhIy5gMSsNuY0XyZT8atJb8DsoP/8dwvjslmYkEeQsFicoLht+ccZeuvpGkUEoeDx0dFju1/X2YdMcHuKpA2NjIlWoL/hO94wf8kl1oh9x/3hDj5Bzgzfk6cOHHyDcyN7eivlqBaNffGtqEP9+A6O/e2khAmK+x87TQLpwehGnVz04axvI6h93ZS6x+L99ZtDCSmYBZJkUsEyF3EyL1VuLqIELe0IkmIYnjnEQQCITarBeuAlpGgYBqnzSO+9DLhz6yi1C+BkU17yUrzQjgtl/I7f0HcgtHY3JQc7VXiV3odU3IyNTUDjLJ3oUlOpVgZwZpYI/kXmpGvWYT4vqfQBwZj8AugXW1j1sWdtIydgqy/jyaBCtuMyQTPGU3kj35E0LwcTK3djLT1MeLpS2u/CatMjqogn7qgBHzU3ZSkTiBa3UL3vIWEFl3ls8iZPLPndxjVI+x97HeMUVeT2FpOfY8Z4/pVzK0+xd5p65j+8av0u/kiclMQF+/FYN8IX0hjmdt2Ge+xSRgr6xnad5aymCwC+tuxDQzRHpNCqr6NsOcfQBIRRO+JfLqefQ3fjEjMSxdw9VgtJaOns6rtNDGP38GlX29BuHYZSbs+Qdc5QFH6FBTaIcb8eCluUujSOuRUTjc5umddJbAyGWK8IFhpR/36x7jfsxiRjyeDethSBo99KRTcOwKvXwFvhUPmJFU0yB3GCjwrStAVVnNm1hoaE0ZxT6YIvcUR8LVpYGcFhLgBAhjUQ06wo7dBIHBk6MDxXW2A4JoSPK0j+MwdR5sGsgIh0ffrdYUhbuCjcNi1vVsAE8MgrqMS3ZGLqDYswVzXgm1Ez0aP8QwaHDIvf693d2OuWuCtPLj3/Id4PXk3RYNS3s6Dv82/1ev3SpvjGlYkfX29l9ocAdniGCs+Gz/C4werb6u791/FbodXLsNTY7+2o/u/gDPj58SJEyf/JP5exsVmMGLtV2MzmBg5ehFxkB+SmDCE/4+98w5vq7Db9q29LNvy3nvGsZM4e++9JwkhBAibMkoLLW9poaWFUqBllVH2SEjIJJPsvRPHjle895Yty9r7++OwUkLooO/XvtV9Xbkg8tHRiXUcPfmN51EJva6dlT4GGarRDppw1Xk8XT30vPAB5oYuinvDGSWREZsdTdjtC+m1i+i+XEtvWR19p89j1uoIeXcXHo0ahdWCWCHDoQxg40PPcZOsBl3TRZzJyeh/8xHD75tBYE485+/6E0FTRxB010JK5jyM+5k/4a0oo+/0Zbrm3cierm7mnN3ObfdGcfitowxeOJjmlXcT2dVEoUrHOW0w03UdhE8fjuLgESpCU2n89W+Ysvl1rgRqiFk4m9qichSXinFK5VDTjkyiwaXSIPO4iHIb6Q4KYeaJDdSkDaTfu69xOTKLCe3bMIqUHFhyG9aoaCQX9/L6pDXceeo9Yj9+iWMLbiHz/GGU7a0kRQayNWoYqngtV06fYOEMG0pdMs0VnQRdqac2IoV+g+O4VBdKalct8Uf3Ef7zmxBJJXTc9Wu6C6qR338LInsvxyvsdNy4ipyd20i5YzIfVqvIz0km6slfov7JLchbOxm27xDWpx9n45ZqQlQQOzSNnAjBEuS50yATCZW/U01gP1uMmgRUPTrSEDzt7G5BYFldoJT4kLS0EdhQxp0dDUTEBlEZk0GoUUzA008wLV5Hcee3PebOt0CHFV6eDp9Vwh35374HfT6hqmgrKiZozUJEGkHsZH2x9xEfCJEaYf5ueiq8VSB47+VFApHZyOKjML67FZ/NTsj/3MkCK2wogS3lcFPete/7w/VC9VHlzEV/opgDqsGsGSQ8Pv2vUgtHxEFdCVxohegA2FQOmaGCGHPsPYVo9MB/ieiDL5Y6dP9dou+Hxi/8/Pjx4+cbOKsb0cybgNdiw1FchWnjXnx2hxAB5XTh0fdiPXQWn92BWReORC8hLjsKkUiEz+PBWdVI35YD9Gw5jD44Eqk2lNEF+4mYMIDg5VOxfLydmEduITZ9OO5xGfRpHDjrWnCEqnCabPSmp6N22dkVNYLxB9ZjqK/l7OTZWB/ewKkJi/lzURjztx8lubmbvJcfoLu6lfbkbFo3HcEUFEd0/yDCTh0no7MKVXsdva+uQzJoPE2vbMTndLH3poeYWrSX7MqdHJq+ir3hM1l2uIQwRy/GsyW0NRgYcuJ3iOZMIfD0KcQqOUFRobSJA+gWBZNUdhGnSo3T7SM9WgbNYgJ69JydvYLPQ/J56NQbOKUyxkU6qNq9iVqPBklTC82qUExNV8h89ilULgeeuHACUuOITY7i/KubWPjUQpwfbMY8Ygg72n0MV0bB4yupUygI8kH1B58RsDSJkA930ffeNiQJ0Zy//cfMarvI9oyp5I1JprtAT6zcyeZzJkYW7SR5Ui6Spx/CduQ84nAditw0lDt3sMBoxmJx05QVww69mh4bdJqF6llBvZP7o1uxtp2h9/ZbOdYm2Lf02ITZP7sbHkrtJnzzJrbZ4qhM6cfxfpO4Y4iYqUd2ol09CuUgHUfqv70EUdMDXTah8vbCGSGj91qIRCC3W7HhQxygprJbqEJ+c1ZuQabQjn6nEIbHwsBvCExJmA7dwzfj6TYikklJCIJglZDMUWMQhNM3sbuFjd2fjACvLo9jj61j+W8GE6cVrnNswrcrhUv7wZ/PCzORq/KEDWePoQ9HcRW6R279R3/8vpfjjTA/8/uP8/Pd+IWfHz9+/HyB40odrvoWDH/6EJFMhiI3Hbxe5NmpeK02VMPzsBeUg8+HavpoPm0OYVbxeziMclqX/QSPy41BrKZdFYr2lpvJoJcqvRdlVwsRL/6M060S4oeZkG07hHbJNCyfHRKsYxxO5OkJ9JS24OvupUWkJOr20SQV7MF6poXAinJMt68kUxpAurMD9Rvb2b/qIS5dlKL95AyFI5ex+OM/kPKjRRy+YKBq9mJSdr+HSmZlR/wYhheco9Mppv2BR1h64D067G6iFowjzW1BuW8bBXkTGVhyjHnb34ABOShafMg+3oB5cD6exHhMFY20hYeCx8fZrLF8OPF2RotaWfXqo7TeeTfqdRuJPnOcG5SXiGyrQLJkJp9aw1jtOkbloDEMaT2Pdt8+Iix67GNH0qrRcWnRLXir6klf/2eiZSJ6Xl1H+LxxVG8+S9TIyVy+YOXOdAUvHbaSuWcLg5uLkU4fx/GRTzBPVEdVk43+jWXsnLqS4ZkazjRD8vF9tCFlcEcZmU+sRqwUqrKuijosn58g7PcPo//pc6gmDydyxAC0B7YzeHgezqoa+krrKbliALeb42olPcNGItLLGBoDN/YXBI7FCb/bZ6fixc08PWwZ2igdrWa4KRUSG8uQ40U5LBcQBNaSfl/fW3a3IKLuyIf8KNhQCmV6odr4TUHn0RtwXqnDfrYY1VihHHiyCWb9lUj0+uDjYhgSLaTC/DUiqRRp5NeGh/MyYGO5YPL80PCrK2YH62BSknAdR9sVBEcHE23qQBQYycw02FMNi7OvPr9UDA/+lUuLacMetDfM+KfSbq6H2SlUW8P+NcXE/xr8ws+PHz9+AFdTO70vf4xq5AACbpiJq7wW08Z9+OwOtEumfuWvphyWi9to4vIrOxl16ixYevHdu5ILC26h1qFijLSDIZdOErRqNN1P/BnKutDdOIPdtRJKOqHZOYA7q9ej3leEvKQWmceNPD2B3rW7ackZiTs8HM35CxhOF2Po7iU8NgRJgJoBKWqsDjd5r7zDkTFjUaTFc6HSymiTnVZvAJbIGFrW7efY7Ie4q+88Md5OTkmiyb14EL0LlJnJjP3jr7iYPojinz6MRuwm9zePE+OVIPO4ccbHkVJ1CdGpNnwxURhHDMV38gIdaAgckI2x1oLC3MeG+XehUUoYdeEw1Zn5pG7dxJ9m3MfSvW8x0thId0Ia+91xrIlsptviJqyuAo3bTnhSKEeXPMbovkqKhs0nNwLqiCWpLYai/Il0fLiZ8D9so92nQW3Yj2TwUEqe/YR+Db1EuYzEvfAwnnY9ykOl9DkbqNBkwLKbSAkQEaEBefkVJGcLGPDzm8icOfCr99Xn8eCqaUIxbAD6x14k7PlHsOw4jP6RF/DoDdhOFKAckkPwmDxcQyNJSNByqk/LsjPrCek3AZtbxIU2YXnD5/Ey7+xG9o6bzs0jdSzIhNWfQaSjl4atpzg0czWduwVRlqqDwC+WIrw+eOcShKmEluzZFhiXAGJ8HDjTzWhHHc4rtXj7LEhCg5FnJRN4y3wkocHY3dDnEFq7X+L1CQsWgyKF+b6/hcgA0MohJkBIEJmRJjxuc0G5HmalCTOPJV1w99JhWA+dI3DVXHLCBWHYa4dg5Xef31FWg1ijumZW7w/F6Wb/UscPgX+5w48fP36A3tfX4+npQ5GfjbOwAnlWMu62LgLmT/xqqcNV34rtRAGW+naOyJIYFWyhvFsMXXqSU0NImjEI84bP0T20CuO7W+i70kRru5WaNXdzVhTD+ETBW62h1cqsZx/BEhxCZ94Qog/swavv5c2FPyMpVsW4TW8QZDGiGZhB4Yyl5CpMlJ2pZ/jgcCreP0D3k4/TZpOi336MroAwwnrauG1yMDXPfIQ3NZGs4tO0yYNpkIWR2F5FffYQYvvauDxlIW0yHZriywytOUtZQi7jiw9QnjEYbaCSfg1FiLq6UUwaQaU3GFFXN/LWFpo1kZjtXqLdvRh0UQzIDsa88wjW1FQ6PSpU7a1EWLqwSxQY+w8go6kMh81JU2oO6U49Gq8D9ayxmGUaLq87gTk2HoVOS1ptIV1xKWiqqmiJSCSmrZbu5Awi3CYq1LG0elQMqzxF7Pj+hKVFokiKoV4bw2FnFN0eGZmhcEMOrPvkCtGvvU7Xj+9n2aK0q95X865jiBRyHAVl+NweNDPHYD9dhHbFLMRqJT3Pv4/uoVWINSr0VmErNUwFwQcP0BSRhDk5jQGRQlXN/ekuunRR/N4zmMdGQ7cd9lxxk7zuQ/oWziMyOYzsMKFyVtYlpHgszBIEi04pGCEvzxESPu7sPYmr6AoXPGFEDEwmb0wykqBvb4WfahKi4MZ+YZzs9QlWJv3ChVm7v4ceG6wrFuYYV+YKlbMdlZAQCDkR8NJZYbs4RAU9v38b3U9uQSSTUt8rLLCsHnDt8/rcbnqeffer7+O/gv/WpY4v8S93+PHjx88PiLOqAXtBOZ72blSjB6H7yWp8Hi+9f173leiz7DuFq6oBzZzxvJYUjdcLkZveJuc3a4gPFuNqaEX/y1eRJUZjPX4R6/7TNAwaTXN3B4XSGB4fDRFqHyKRiHVNvajcDkJ8JhRpoSjfquXjRQ/z6I8H0fveVnzd3ejig7joC6UjJgV5sBhdq5nO37zJ0Xt+Q3eLFJnIx6DWCo6MHMFtVadYeyiBFGUw+Qc/pzVAx/YpKxlqqcPiMRNo7aVx6FhavRrUdbXEKlysm3wHNxes53LaUIZqrdDRTGloOklSNb0nStD1S6cycyAlEQPIPbufFIWX3TNuJVVfS8ul84TeOJ9MDbzeO5AHPnkSqctFW1p/jPEZ9N11G8F/+BM5PbWospNwN3fg7exBL3IjGdSPqPMXMdmikEeEkKV1IcmLJc5iZo92GJNHhNPplpMRFYH+0/ME/vbHlMekUtcrvFdJgVBUKcS2Lc1wU/vufszHWukeP4EpM68Wfe52Pc7yWkQqhWCK3C+Vnuffx2e1IdEJH57axVMxrd9D0JpFhKkFYTUiDlrHD2bW8f2ELxPOaT12AZtXxObgwfy6P7xdKPjVrWk4QPnQfO6YEXbV9mu/cGFr9/nTQntyUrKwlbuhFJZm+3C/Xk7oY7cz2SsYIwd64a92KADBfPqOLzzDfT74oEhY8vh7RR8Igi5cIyyHbCgVRF5ltzBruLVCqEKGfKHblENzsZ8vQTVqIEnBsL8W2kyChcxfY9l7CvX4Id8SfZc7hNnAtJC//1r/mmqDf6njh8Iv/Pz48fNfi7ulA8vnJ7HsOoY4JJCIN36J7AuhZzt0FtXIgfh8PkzrdiNSyGhaspx1JSKcHvh5RC3K8UkEBAptLtOGz/E5HCiG5ND72noMIhXNx8sIXjKV56cJr9e3/nMkYTpGrt1BYVASvvgEMn75e6zp6SSunkFmgJ3qzz6jY+JoGmqbhIB7m5jiLpjbYUQtVzL0xE4GT0nn9EU9YreF35e+Q2yog6qjWwjvaKJLHYI+MpEpJzaiFbmpGzaOsHOnORKUzeDmInrFasp1SYw/vxNXWwfOuaOxFh1F1dJOstzAiYlL8OTm4GlsJd7YQmx3H2qvC70kkJTLpxmaIkfeWU/AyT6OBmVxb8Fv6IhO5tjUG1nUcx46mgn52S9IDpUQcsdi7KcuochOxTcol43yEcSFK7jljgmU3/scuueeRLJhG9obZ2H+/BQlKcvwhQstR9veE8TlZZIyMpX8L1qMbq+wyaqWwWhNL21/2MJ78nwWxXdwasKUr0QLCH6MfWt3ItaokGenoBggbAQo8rMwbz0k5CFHhiLPSsZ+rhhHcRWK3HQmJwtVtiW5Onr2WfHaHbgb27AVXGH98BtZmQMxgcKc3GdbKsjVWHFOGkhR+7d98josEB8E05LhsUNCmzIqAGJ6WrAlCwdLxXD7IMEY+dYBgjD7km4rqL8wkvb54MPLkB7yz7U7Z6UJfoRJQUL6yKw0QVQZbLDoG0kYypED6P3zJyiH9kckkzIvU5gPvHvw1efzdPfiLK9B95Nbrnq86gubHLdXSP34Z8f+jjUIc4p+/nn82tmPHz//Vfh8PuxFFXS+8CFdO07SGxSOzeWj52c/pUIWSZtJ8Oazny9BMSAT4+sbBLPledP4vEZEgAx+O9GHeMdePJ09GJ59F0dRBT6ni7CnH6LzzxspjM3jbMwAYuQOhnUWYzt5iT6rmwunmjn+4m709d2ELJiArKoGucuO9e41jFD30vrAsxg0OjJ/dxcqsZforibYtItRn7xG/PqPKR4/lwmzMmg8U0XUuZNYUlNJVdiQ5mYQpm+lLSmLwuVraA2JI7irjSZFGOoz51A47Sza/ReUuBloqmXpibWkddfSE5/CpN5SotVewqYNI/FXt6PwuvlcnIxhzHj+kL2CFjScuOvnHBx/A0E9HXQfvkR3XAoFoii0ZSWUZg7jswmrmLtyILqlUxlqriUtXcfZ517j0/CRWJJSCHv2x2zuDmP4zo9Y2XQA19Gz6H52G00PPY9q0nAsWw/ROGEas9JgXw0MttQT0t5IypKxnGj84n3zehEZ+0jqbWZ08wX0b27ik5x5REcFII6Nol/q1aUo29EL+BxOJOE61BOGfvW4q6SG0F/ciemT3fR9tANPn5mAZdMxbz+M1+YgTQeNRsHXTjksF8u+05g27Wf/2KWMShQT+0WX7UxJH0OKjxJ40xyGxwpze9+kwwI7q+CWAYLX35gEQbTWGaDsQBmK/K83P9QyuG0gvFckVAe/5MuZNt8XixzJwcJ5/hm0CmHbOD5IEJGpOkHQreh/9XFitRLVyAH0vryWnuffQ7VjD1G15dQ0Wa86zrThc7Q3zLxqoaPNBJ9VwF2DhepkQfs/d82WL5Y6vimK/fzj+Gf8/Pjx83+CNpOPi9U2XL1mPEYz7j4zPpMFqdmMzGpGajIhs1rwSqSYk1PoGz6c4I5Wkj58h66770TlshNw7DgWiQJ3Tx/y9nZEPh+m3DzEuZk0iwIxI2OSqZzo9npcDa2EPnkv0vgoDH/8ANweynrENKfnEb/uI+z9c4iJVJLw0HJOf3oBz/7jWCUK0hpKsfik1PcbTNaR3YjCdDRpItDhILChljZdLCqFGFVHGyKvh8Jx8wmtLsMtkTLo3lkEKER8eqCTkXVnEItFhLTVY2/owKrU4Jo0nnKnljhbF3uDculfc57O0Djym4tI8BhwdvTgkSn4fNwNxLiNyDKT0X3+OWSlUX7H/bSYRejPlDGh+gSHdLlI8ZIe7GXO3WN56gSMdzXQ//Xn6Ro4FEtxDYEWAx1xaST/4QGmpoLXbKXjrl8jiQol4oVHaX1tEyfTRnPME43eBu/P8WB5+UNclY1IokKpOV5BFBbkseEUhPdDolHRYJUSVnCOvLn5lDXZ8Xk8jEsAkViMOFCDWBdIr1rHes1AIoOkjNjyLkemrmDFMDXaL2LEPL0mun/1KvKcVILuXPqVKPH0GDFt2EPwPUJurbOiDvOOI8iS45AlxeIoriTolgWcbha2cMcGm+i4+TEan3qKFpnuqy1dj9HM7l9uZNojs1DECxXiNy4K83vBSkGo/PmC0KLVqeBQHZxohBW5kBgElx57m1ML15AVLiImAGK+MGKu7xWSP+4bKqSAPH9asFhZXyocMyHph/lZsbuFCuNPRsD7RTAqDjLDvvt4n9eLu7Gd3rJ6zhxtYFSIDWloEOLQIHwWO4E3zfnq2F67MMN4Vz4EKYXXeuUc/GQkiP/Bqt+BWuH7+rcusvxfxD/j58ePn/96PL0mXLVNuGqacDa0c7HRS2aiGqUuALlOiyJeg1IXgiQoAXGgBtupQpylNYgUcoLuGYPjQinWukJE4/uRHWDAWVZD0LO3I5JKMbz4ES6dm8AbZuBQqjG0m7hyoZPRYXYypvTHWaFBPWM07pZOel/9BE9vH/W33EFvcT2+Tj3RySHU6HtoHjOT3ZdlTJ4xkrraNvL66nBfNtAn1pBxaBdWhYoWTSxhbhOOyDBs3T0cfujXNHS7GVF+nKiKy4icdsShQZxKH8/JpgjClV6GquopW7iSyfWnaTX04dE4acsbit0kIcFQjcjhZLLpDMENVxhacZrOuFT0Rg8nn3mT3YYwHtz1HBZNIPGHPqcjNZ0jC+8hSyribDOkD+jH1tRMJlUeZdLJLRybeiP/s89FRpSM7OJS4n52M61vHqRXE4wkOQ6TLIhRTRfwpQymb/0evBYrEmkU5r0nUPtcpOdGs/44zIqycOjnm8icnEvyT26h4w/v0/jrJ2l/62OkSjk7Bs1DHaDgR+WfIv7pIuJHZiD2BvDyRSmpgyD3i9CURqMwn3ZTHny+uZSQvGSciq9FH0Dva+sRBwUQtGbRVZUo2+kiVKMGffV7eWYyuowknJcrsew+jkdvwFFUwdDcTP542sfAz7bjzOvP5XYfd0wThK1l1zHarrQhmTP5K9EHMDpOqNBNTRG89Zb1E0QfCFuxmaFChc3d2UNWejC5Q0Q0GYUFkJIu0H9RSOuxwc8PwtRkSAkWYt4iA3440QdC3nBehNDy1SmvL/pAEN2ypBjCk2LwpIyiRQc50l5cNU3I877uv9rdwjlvzhNE35evNTBKqIiO/AfmEn0+KOwQWut+fhj8ws+PHz//Ebjb9djPl+Cqa8HncCLWapClxqMcnMPpvMkEu+2kJoqvmRjgNVtxllaj+9kanFfq6Lz3tygG90MaHQ4OF67aZoLuWoZIJMJeWo310Fmi3vstktBglEB7FMTHweAscHcZMK3bjTw7BWV+NiKNEsftN9H68i5coaEMKztOc1wGvspSdJFapqbAR4UeJpaUYqwqpTJ7BNl5UWi378Dqgb2JAzidMIQ/vXUn+kcfJrKvg5TuVqqzB2EMiWD2rrcpGzCWG380ineLROxucZPefZw2u5Tm3WfxScQ0x2bw/uJHWbPtOXRtjbQpQpB392LKzqF41gKCP9mAaOZoalqdLM11EqtvJMQMvrHDqZmxgqpmKZW9wveqyQiT4nyMKKqj7PFfM6a3AdGG93DGxCApPMzhggwcPhmREWq25S/gxngj5k/WYlq3E1d1E4F3LEYarkP/8xcx/frnvF8ED0c3kXVsD94757DXGcPRty4RpEwmrrGSwiGjyByRzA1rNyIbnU9YQiiBc4YBwrLDmHj46DL8YargjbftCtwzBI7U+RhRcYort9781QwggKutC8fFMqI+fBqR9OqPOGdxJZoZo696TCQSoRiQiTwvA+uR83Q99iK6B29iZJmRuvgMzkWEsqjzDJatclzVjWhmjeNY2nQWZF9dvuoXLhg9d1kEgZMULDzeZRGu+5fjhN87CspQDMlBKRME11+LLrMTNpXB5U6I0AiVvsnJf8cPyt/IhCQhgm5B1vceehUz0oRqYf8RwShDg7963O0VKn0LMoU5xm8yPlHIFB4a8/cvZ1QbBAHsX+r44fB/K/348fNvj8/lxvjOFmQpcQTdsYSQR24l+O5laKaOpC8qjpp6E9lbPsb4zhauNb1i/uwwmnkTAXBcLEM1bjDujm7M248g1qoJXD0fT5+Zvg2fY3xtA8EPrkLyjQ+1o40wVqnH8Oo6jK+vRzk0h5D/uQNPrwlRZgqn/7CdpmlzUORl0txs4YI8Fn3/QZjf3kj5hmPMe+0JkorO0BkSQ8vQMbTvPE2lLBKHV4Rt+hSePfY8hpw83Gs/Q3vwEDkHtqKMDeNudRVKsRdNfhaflImwuODH8mIq5JEM//g1LDExSFvaaNWE8fMXbif9zCGqQ5JxaDQEBKtJStORcvIAg0McaAsLmbz1L0S++SaGkChKI9LZbw7BqA7mkVHg9EJWqCAEZPuPkDU7nwZZKPvtUUiCtMy+sh9Rdw9hxk6CZD5CsTP6zA6GOZsIefwuoQ0bFEDg8pko+qdjGzWM3lfXM3nPe2RdOIzuwZsIz4phRZKZeV3nmX/rYDKbSsmeO5jthjCKx80h8f13UE0cdtV7t6QftJqFbNodFXDvEEEEdB0rJH5sNgUGxVUGxsY/f4Jm7vhvbZi6GtuEWU3JXwXPfoFIJEIzcRhhv70f27GLZCQF8KInj+k04NywE1lyLLpHbsWVkYbdI/qWibBEDBaXIICGfqMl+ep5WJQtVL4AHCVVKHKu3j7+JgFywTYlPlCIQ5ua8p2H/lPIJcL84d8rqJRSGBgp5PV+ic8HHxYJIv1aG7xyiZAucrLp77/OYw1CcoifHw6/8PPjx8+/PZZdx1BPHoEiJ+2rjFwQPnC2Htcz+9RGgu6+AVliDPZThVc9193RjUdvQJ6RSO/rG5BEhKAcPQhnSTWetk6spy7Ref8zVM57mPqn38dRWo1i4NeZUEaLm9Ajh5Fs2Y126XQkEaFoV8zGVd2E9ch5tp23Un/jajr0NkR/fBOlx8nwPesJFTvoCYki/MO1mFq7KYzJ5TerXyC+4BSRQRKCFD4Mc+ey/P2nCBPZqItKQ6/W0SdR0S3TsvR3P8JeXI3D4SH82GHm+mp4dVAnKW++wvhD6wipqyT0wnm6gyOJ8ZkwieTsmHILn427EY8uhJIVa9g9ehn94hVoHrqFxohkIqzdaBvqaB8zgYzZQ7ghw8PtgfWU6QUbk6oeCGhuJMnZxRWnlsFbPsC56zABIhflXT4+W/NL3NnpaLQKLi69jaIbbiNiyUSkcZF42vUoBmVhO3qBqo8PcXLKclAp6O9oQ+T24LU58DldGN/4lMBVc7HsPIZm7gRSQ8VUG2Bk3VlCH7+Tvg8+w93Z89X3XySCmanwpzPCAoRKBsdq3AytOY9nzHDkkq9FlaujG0dZDdrlM791D9lOFKAafY1g3L9Cmd8PaWIMJdUmFh7+CFlsBIG3LkSkVCASiTjTDCOvsVVb3AGBiquFVFUPNBhh/hfdUK/ZikguRySXffsE30AkguX9Ydq1/F3+DZiQBCeahCUoEJZDkoKvnSDyJV8aV3/5nL+Fpj7heP9Sxw+Lv9Xrx4+ff2vcXQZcNU1o5k/81tcuFnaRs3crsY/fgCQkCM2ccRieex95/7SvzHDNm/cTsHgKXY+8gP3sZSTREXi6uvH2WZGlxWM5W0qbLhZtaCjWpBSUPjsda35F0D3LkCXFUvHWATInDCF49gR8Fhs+iw3w0fHwH/h05I1czhzBsA2fMKz4FKnxKkQhKfSeKkbT04y9wUS7Opi82ksof/EYn5o+pafkNJI//pLwv7zHmXYjYbX1dCTEY4mMZV/EYB5r3cH6cQtwd5bT120mIElDVns9npPHaLx3Mz6PCIVShSUiEkV3NxaZih6RmhCxj+64FB4/8AJ1unjEpUXkVpewMW8ixUdEBI2axmR9IfJhOTi629k79WYUHicxf/qQI0NXkBkoZXXLIdyHy+iVB7B/XQFylwibSMp5RSwR0xIZ1VOKNjKI07OWoZCJGRMgzKR1HbqE06ugeundyD/dhthoZYRhD/U3LCW7+jQBy6ZjfHcLnnY98lULOdUpJ7LeQPTiNM7WQHpnNRqFCPXoQSiykul941N0D69GrFJwoRWMDhgUBZvLIUoLliMXGTVtIGe/iFT7EuPLawlYOOWrqLYv8Xm9uBvakK6Y9Tfdc6XjZ2MtaWDonbM4WC9iZVY05q0HkWelXHPerLkPDtTB/UOFCt+X4mZDiSCGFF9W+4oqrvpHxX8qUrHQvj1SDzIJ+BB8Cr/vOWMTBCPov6WK2WmBT0qECq+fH5Z/acXvySefRCQSXfUrK+v6AwUbN24kKysLpVJJbm4uu3fv/ldeoh8/fv4X8NocGF5ei2nTPpxX6vC53X/zc02f7Ea7Yua38j/76tppf2sb+b9YgSQkCBDySbUrZtH30Q4AnNWNIJVg2nIA+5nLRK5/DmlkKJLgQLQ3zKDqgZ+w91cvEaNyozN2EaNvptUmIfCOJRiefZeuR/7IxcGT6Td7ECKRCNupS4ijQql76I+8038xFapobnvtESKsPQSmxtCui+PtfgspSxtETb2ZHat/TrpDT1D/FAb01cK+I8T/5i46P9qDwSkiTN+MND6KFrGWXqMDx5DBNFbpGdF6ic9HLaFaHEJLbDrupnZM24/SGRjBkfFLMenCsclUlMb1Z/29z5Chr+HyjCWsOPEBx1NG8eaqp/CEhtB4+11k3zcPe04ON+lPcv7Hj9NU0ED3ggVIpGJsMiUl42dz46YX+NmuZ4i7fI5iRQwOmQqbSMr2/rNoWHgDU6xXmOOpJrhfIpszZ9JlF7OzUlhM2FkJ9rfW47rjZvKDbIwYHMbgECcHRi1k8uJcAuZOwH7sAtKYSFyhYRz76AzBO3bSPHE6G0vhnbMOhhUd4uW4mRxvgEZxMPK5kzCt382pJrjUDncOhmExcLAejIcuMMlWiWZcPsUdkBsh3A+udj2O0moCb5jxrXvIWVaLPDvlb8qQvdQO5RYVcxZlERckwugAe3AIXpOZmnYHSUFXV/X6HMKs3JpBgggaEAlF7UIiRnQAjPyGt5/9UjnKQdnfftH/QIbFCN+rWsPV/n/XY3isEH1n/54ff6Nd2Da+fZDQ+vbzw/Ivr/jl5ORw4MCBr19Q+t0veerUKVasWMEzzzzDnDlzWLduHQsWLKCgoID+/ft/5/P8+PHz741p/W7UE4chUilwFldh3nkEkUQiGOvmpiOJibjmh7K98AqSiBCkMRFXPe6qb6Xg5d2kP3ojilChD2TZfxpPux7tytlIo8KwnbmMecdhfBYbrrYuXCoVtTf9kuhb52CPiODjiWvoHyliVc0ZLOFBqFdMx1nbTEBBI+0vfkL4ytm01xkYemATJlcWsvlTaXh7FxeC0+kmmaDmBkYe3UppzlAG1pxHKhdjUGgZ/skbnB2/kOgRcu448Box0RqUowYKvnIiEDkdJJ49SltUInKrAWuQhnh9HXaPCG17C4FGPUeJwFdZw0pLAYXh+dTHZZJcX0ZjVCr9+hqQWi2YkLNtzf9we8lmugcMJri1EX1MMh9NuoMHnZdBraQ7IZXDx/T8pHA9qeOyyAgx4Jg3mKQPn0cxIJOKi804LpUT7LMjl0VwcfhMkiJ0HNflsLU9CKUMkkvPEdlaS/dND7HDEoPPBWk6Qdxo5BBeWYzC5yGwrBBlwUnkY/LZM+82xoZLhRZsXgbGDz7DFhXDholrWBXeQVBbEwNGhlPcAfa1e5n/0ER6HApMTijqgFZHGvKyYqptNQybmMruKrjQ5GZOwW4kKjmu+1fS7RQTqOCrtIzel9eiXTr9mm1U28lLBCya8r33aYUeTjbC3UO+NhyekChUqSYOy6X48xLGLvvawdjlETJ4b+wvtHlBEDd/PCNs4trdX28j+1xufHYHYu3/jb6lSCT49AXI/3ZzZrFIqAwerIPZ6dc+xuqCty7BqlyuMuT288PxLxd+UqmUqKiov+nYl156iRkzZvDII48A8NRTT7F//35effVV3njjjX/lZfrx4+dfhKOkGkQiFLnC3/TyNGFS22t34LxSh/XQOdwtHYiDA1FPGfHV131OF5Zdx9D9ZPVV53PWNNL44T6alt3IhCRhwt5rsuAoKEM1ehA9v38baVw07Xc8AR4P8pw0HGW1tMelY5wyjo4dF3GGhXFDUBuhMjGdb2xANHkMZw1qPB1SsscN4HxhF4lmOWcdkQSnJqA/VYNsw1M4DWYSWi+SbezBo1SiClAQXHEaictFd/Zg+kwuNBIx+rETmNR0hMiGCnwhQfjMVkRxUXRLtXTsuIgqSEdjbBat2khijW003L4U3Ycfk2zvwJWexvQTn9IYlcbLmbNI7awnMm8AoppiVBIvqrISDvafisbnJMzRi/pCAd05+cjUMkr7Anm64HWSWypQjRrA0U/XM6W0iq5+iegkGiJkMtRTRoJaSc972+lSxdJ31z1cGDCUsD27qVQlMWRMDmfPC9ukOQ2FrLm8HvULP+VX1TFEqOG3k4SW66wEB7eaCun8w0v0Tp3CpawJtKrCuNQqtDofDhFmMB2F5fSGRVNda+Guu10EaSMhIxKvD7bvrGVsiBtTajqjTEIs2Nx0QUjYls/ix7s/4JA1joJmF9OPbWR/5GAUw/PYfgCmpMCwL6pp7o5unGU1hP3mvm/dfz6HE6/JjDRcd937tNEIu6uFjeFvVvRyI4Rt3TEDclFs/YTIWwXh92WSxuRkiPuGrZpCCmV6YSnlaP3X84fOshrk2f+mQ3v/IMHK7z/mr8mPEoTxxCTBuPqbOD3wVoFQQbxWNJyfH4Z/ufCrqqoiJiYGpVLJyJEjeeaZZ0hIuPaKzunTp3n44Yevemz69Ols27btO8/vcDhwOBxf/b6vr+8HuW4/fvz883htDsyfHfqWeAMQKxUoB2ahHCj0idxdBqx7T2LeehD1pGE4G1rRTB2JWKnA53Rh3nUUkUSCvbyO7eNXcufgr8sBfZv34/NBzx/exed2427vQaILRDN7LD6JhA6TD9HiWVjbrAQaDATEhuL93Us0FlRgiI6n1hVJckYkm0Py2KoMJEtajra1FFV6KtO6LmAdkYRh326kKgUil5vWmGQCZCIOZ45mnLMe46038WZzJKFFFxnYZkJTXIz21OegUWM1OTiWN4PQw4eIzI5B99b7dC1eRmO7jGk9hexKm4hj51FyLJ0EmXroWriQKkkoiXWl3FC0hQ51CGknLuKUKXDZXeydcTsjKk6wdsRKbtzyBn0u0OnbCOjtxjB/Bb27t+FZMp12uRp3QSHDRsQhzUuloMmJq9rCgNJTSPr6ODF8DjFYKcwaglwkZWfeLBYc/ICS+EjMthBml+5hVKCJK5lD+XNBDGMTYUYKVBW30/vRGQZ4ergSG4Q8IZG2W25D5ISyWsgOE/zrProMa3e1MKLgDJ9PWcXE6Ar2v3KQlslCK9Ztc6Led4C6h1bRUidU7pqNUK4XNlpzkhVsqxpPytYNTFPZCXxoHjpZFJ0WmJYCP90Pd+YLCRTOP31I4IpZ37JvAbBfLEP5jZSMa9FhEUyS7x3ytVD7EpEIhkTDuioFo6PUuLsMSMN17K6GhCDIi7z6+M+rYU6aMJ825xtVLfvFMjSzxl73Ov4bEImEpZXPa65uEXu8QvV0SoqQLOLnX8e/VPgNHz6c999/n8zMTNra2vj1r3/N2LFjKSkpQav9tpxvb28nMvLqn6LIyEja27877+WZZ57h17/+9Q9+7X78+PnnMa3fTcDCyd8atr8W3t4+AhZOFp635QDmTfsIfvAmnFUNmDbuQ54Wj2nzAeoy8xmVIP5q9sd+rhjrrqNIwkMJffbHWLYcBLkUV3UT3l4zvRv3Y49KQPbeJwzCRvDyGawbfiOKzTvIcaqJXzOHlMuXOXm0k/unOim+0ktYXyft5y8zOqwUe0oMNX/eRrRcgjkgiBPJIwnPimHLwKksL9nOe8o5lFVFoukzkHrlAkk3Tmb0Oy9iMrm4qAmnLSEdRUs3eTEaQlpqsEeFU9fQy6R0HYYOMSMK9lEnCqYmIo3ciSG0VnQwoL6IaLWHbrONjKY6OnIGcHn8fAIvnqdfXREWo507D71BpL4J52038rErndCKUsZueB/psjn8wZpJv+ICUhIjMM0YRbcVvI4+Anbt5aQolLVj78cZpCOurZoxn6zl7MwViGQKzkxcTNq6dSyWyVGMzqer0837mqHMinQwoLYYy7bL6DUh1OcNJ2lkFIFP/ALf3StJDBa85+7MFyK65BIQ9xg4s3cXH068iXtHShkRm4Px9WI0ES3IkmPZ/Ou9uBeN54aRXwv4vC/m46K08OoFWKC1oK0tQfPIrcjio8jzwp/OwoocWNpP8Hj76EAXYy7WkvrEA9e8r+znigm6ffF33ndGO3zwPfNkI+Pg9QuwatZgOg9eZHv6FKI0Qs7tN2k1CVu8d+XDr458LQp9Ph+ejm6kUd/jlPxfQv9wOFgrzEcGKr6ung6Jhpzw/99X93+ff6nwmznz65X6vLw8hg8fTmJiIp9++ilr1qz5QV7jscceu6pK2NfXR3z8P5Fg7cePnx8ER0k1iMUo+n1/e8t2qhD7+RJ8DifS+CjcrZ2Ev/RzTJ/spveVdQQsmYqjtAbRS7+m4kAT8zZ8gGPOeBzltZi3HUT301sxrd+Ddddx1NNG0fPix9hCQvGs3Y3X6aYtNIHe+XMJ37MLthWTvPtXhLfWcfgXzzEiM5qOuMF42zoJsNYwbkgY73b042j8fG6+9Cm7rYncbdhJn0ROgywUxfghHNUkkVV6jkPdWtJvyKTufA9TDq5FGxtCwIuvQU8HV1bfR4S5G9mUmQz45SMcy59CkN6OTxJHRdIAMj/6I161jj2jlhDRUodz1lR+L09ikWg/YZd20+5UU5c/GktoBGN6yzGZmkjqraLFJsMQl4CnuYZjYxdR328u0196Aq/by5HpqyiMGEZYSSFNRqhdfiPTI0VE6xtJOXkO7dQMdncEcOvlTZwaMJnGlDSUqWKG71rHsCdv5JIhGLvER5q0F3tMAJZTjfw4x0vq0V7iJuaiWnwTfymRE26HSd2FmNRilGNyeOsS3JZiIryqHMemCixWOzWFLdjuvp03xqvZUw1vXoQbls6h761PkMyZTHOngzU/v3rDNU0Hfz4PjQYPqyt3o1JJ0bz5BL0vrxWsfGQyojXwSSnMzwSvF5p/+hGy5XN5q1BCik6oJn3ZQvT0mUFybVNv+Hqe7Oa868+TtZuFyuI6RwoJJw4xd5aP2MCrB9u8PiFRZPUAkErgNxO/bhm7apuRJsde48z/nYhEMCsddlcJtjWby4Uou6H+b9H/Cv+rdi7BwcFkZGRQXV19za9HRUXR0dFx1WMdHR3XnRFUKBQoFN9fTfDjx8//Htdr8f41jqIK7BdKqVq4Ap9EQsjBAygKK3FcuoJ2xUx0P1tD3/vbcLfrOf7+SWauGYc6aQTdv/sL0uhwVGMH0/P6p3QERdJV5yBw77tInXbkNe14xXKskycQqA7gQI+ChOwRKMfkM/nhNTSNnkTg0aOU73VSkpiHaEAOGxwjcXS40bU0EF13msrWTu6p+AS9NoT6hP6MC+qjYuunpJjNBPV2oQ+LxXZiN09216FSy/CUuhEFBxI4bhDTSvYjDgsm9+MXsClFDNWX0y7R0qEIYuYHz9EcFMmmGx4h++ReymcsQud28NSuX0J1PZVx/VAFqTFExDJPVMelkCRCWxoQWy0oXAoSDE0kpgTSNnciGX9+CafTQ8/MWRjzhjGmuQpPUylHJy+nv1hEwfZLLLSVoJ08jHMHq/l8wBQSBzpYdmozbm0bz8pGMX2qmJLfrmVX+mQe7B/GYU8005/+HXF5cahWLmV9TxT3DRUyaO1uUIk89L27Bd/Ucex69QjL7HUEhakRDeqHctV8jj21idApoxlechRzsY+pw/PoTM3mneoAJuYMwfHCRswr7ifgG39128wOtu9vYUR9MymtlUjmDqW13wBa9NCXPhbvH/bRPHU2YpGQhpGmA3eHHlFnPQdyf8z9g4U28WsXhHbz5GTwnC5CNXLgNe+7L+fJFmd9O2nim/h8sLYEZGJoNIlYPj4RZWs9BF7tX3KwTrCc+VJAfnNO0FFQhnJIzvf+LPw3kREKe2uESrFM8v12MH5+OES+a9nc/4swm80kJCTw5JNP8sAD3y7L33DDDVitVnbs2PHVY6NGjSIvL+9vXu74IYOM/fjx8zU+n+9vssMAML63FeXwvO+t9jkr6+neeoStY1eSHCEjWOxE8djv6I1JoGHwaEJLChH5vAR1tVFz7wMklV0k87P1KIflEXjfcppW/g/GPic2JASGqAkwGRC73HTfthrD659Sl5nPnmmreXDTU/jEEtz33Yr4rY/pS04nNiGIRJmN53STCK0oJb2pjLhAwSB4uyeJOlU0GYXHSKkrRi724Y6I4P2ZP8ItlbFi75uUJg8i5/JxBjVcIiAujDZpINrh/ZGeL+DssJlk7NyEMSuH4KYaRColisZGOlShqF1WTDIN6294hP6XjnIgdyp3FW4gtqseBxKkQVrO3XQfA15+hrTcKCqMMgqX3cbwnnJkr7yDSt+JMlBJ/ejJNPsCGO1u4nRwJvvGLCNPaiB402aOzlpNZLCUeRX7UPucvKoYwbCTn1G4aDXZUTLGJsLayz4yLxxhiKybxunz2Lf+ErftfoWC5MEopo1loeECqlEDcV6uZPfwRYzLDaCm24etqongzz4j8tI5Lk5exNh5uUT0j8dZVo3xVDEXqm0kLRpDznhhwM1rtmI7cxlHYTlotZyLHchBSyg3JZlJ62vGVdtMX5eZSz0yMgfFkjggjo/64pBoVMRoIfaLX5IPP/1q+ef1C8JWqOqp51COGkTlyCkUtAmWKiLgYhscrocJ298l79c3o1BcXePwfBEvNi7x+q3FdrMwq1fRDS9Nh/eKYFGIHvn+IwTfueSr47osgqXLA8OuveHa8/u30f1szd/88/PfQqMRLrTCwqy/fTP4v5UfUtv8S4XfT3/6U+bOnUtiYiKtra088cQTFBYWUlZWRnh4ODfffDOxsbE888wzgGDnMn78eH7/+98ze/Zs1q9fz9NPP/132bn4hZ8fPz88ps37cVyuRDNtFMpRA6/7AeYorsJeUEbQ6vnXPaersY3at3axY9IqluUriPUY6XrkeRT5/Qi+d/lXr9G3aT+mc6U4tUFopR4CFkymfvNxHJv3YY2NJcJpRNzShgcRtshoCufciOfEBbSdzXieeoyF48KobLXjnb+G0NxERGIxgW8+RUGbiGNbL6M5cxbvyMHkujoYYKnniK4/rZWd9C89RZ08lJzGYkK8VjrD4+hzS4nqa8eqDUJhs6J1WVAmRFIenoFLp0NtM9MxbhLZezcRePIUDqUGh0SKISyWXrmWZHcPRYTTGZ3EuFg3F6ptDGwvwxMTTbihDdvMabj2HEbkdBKWm0RLZSenhsxgQEsJvrZOyiPSCfPZ6Fd2BnNYJOmJaroq2jj0xIs0edTk7/iYjikzGDcgkMZXN3NYkUrI+EEsPvwBjQuW8EmrjnuHwp4qcHnBYbSSf2IHkaePYxg+gjBDBzkZQQSGaRAHqAmYMx6P2UrTyxu56AnD2mtFmRhN//LTFOVPZPqkeKQFRcJWdWYG6xW5LB4RSELQtd9zd5cB+8lLHD7SxNhR0SjT4ihSxnO6T8stA6+/Jeo1WzG8vJaQR27F4pPy7uedzH3rKZK2v4xIJOJUkyDQbhkgiAhnaxfl60+wd8RChsYI5sEGG5xqhspuocI0+DuSJuxuwYdPb4U4LYRphDm/6h4oaIMpe94j+Ec3CotHPnjlPNyQA5HXcGpxd/Zg2X74unOGfvx8H/8xwm/58uUcO3aM7u5uwsPDGTNmDL/73e9ITRWqABMmTCApKYn333//q+ds3LiRxx9/nPr6etLT0/nDH/7ArFl/m9s6+IWfHz8/NM4rdViPnCdozSKs+0/hKK4iYMEk5Jnf7s14bQ4Mf/wA3U9WX3ehw9HezYU/bKZ2yUqWDtOgwI3+V68iUioI/eXdX4k+n9NFz3PvEXTXUnre3kqDQ4H1UgXBNiNBKpC6XPgAp1jGhdAsxLX1xLXV4lCq0cUEE7l4AspIHVtatcwo2o2vQ48sIwmT1c0VUSh6K4yLdtNX0UiNLwi9NAC1pQ+TR8qe0Uu458CfcdsclE5ZSELZBSL1LYhF4JApCGmtp3fIcFoDo4jCTLa0D09BCRj6EHs9OKOjcdY0YVdpqYvLQOe2gMtNrwPSo+R0NBtpCE1kUogJg0RNRWA8mu4uMPQS6jKjDwglvKUW1WP3Yt+6F3ewji3SLJZW7ePSrOXEt1UTuf9z1o9bTa65gckaPcEZsVhnTcf53iaOZE+gIjQV9dpPyZg9GEVOGieaoN7gI7ipjjHNF4iR2fkscBAhyWHMWf8imffOQ9M/hc5fvIIxIQVrcRXSllbE+PBYHXgDNOgy49A3dBM/ZSCBQ7JRDc+lRaZjfakQqfbXGbZ/TWkXXNHD3AxhJk4tgwWZQtbt92E/X4Krrhntshk0P/gHjsYOYfYDk74SjEfqhQWLFf3BvGkvigGZOBOT+KgYDtVBZpjguZcddu0Kk88HZ1rgeCPMTBOsXP50Bu4b+rVf4J/OwK3WCyjkYlRj8jneAFY3TP+O4rZlz3EkUWH/Z4yb/fz/4T9G+P3/wC/8/Pj54fBabBhe/Ajdj29GrBY+Xb1mK+atB/AYzWiXTLtqU/FvafF2tZm48NQnhN27jKH9gwGwHDyDZdshQn5x51XnM207hPNyJRaJnPrT1YR7zMQ9ugpvYyuO0mok/TPo+t3b9CFn/7gbGB3pxKILI/3TD7E9/hMKNEnUXeliXPFB0hqKkWYmUxiTi0ejQddQgzZMS9aoFAyvrMXebuBUaA4RVy5jU2iQa5Uk11ymL0BHbXAC0eZOlB4nErkEqdtJxJQhtIi0uKUy0kKg7VAhvcERdCek0t1lw6XV0u/ycepvXE2qTkTQp5vovtJCjKWTuqhUyqOySLZ1YhHJaBswnJC0KNJ2bybc1UfJ4ElE791FwphMjMV1tIybwoUWH6nmVqITg3EvnE3IBx9Rd7mF3qVLyJiYjfNP79KTmIJW30HX3PkMHRhGz64TnC0xcjkohZDuNsaI24iW2ukIj6e+31D29uoIkMOtWQ6S171Pt12M9HwBYhFIU+IJGd6PiLH9kUaE8HaVGtGp86QXnKD/6GQifrQcgOJOYUPz9vzvT1lwewUPt+U58GkZTE2GAX+bzetX9L62HnleOobnP0Dy0Ut8VCziniFfv/a+GjA5fOSsfZvji27H5hExJFpI1DjbAhfahLixYTFXi79GI2wpF8Th1BRhRq/RCCca4cbcr4872wLWPgd5G99DdMcK3q8P4qERgjnxteh5/j10D9z0vfm8fvxcD7/wuw5+4efHzw+Dz+fD+PoG1NNGfWWq/E3cbV2YNu5DogtElpZA38c7wOMl+IGVyLOSEYm/XcI5X21D/6ePyf/xfCLThDQOn9NF54O/Rz11JNov0hW8dgeWHUcwfrgdx2MPcfZEE2ObzqOJDUUxciCdv3yV9oxcZJdLMWqCOXLv4+QWHSUjLwbLyUtkhovwdhuxh4RSqhcTK7WxPW0SJRFZLMyCoVozZ7cVMa7qOM5zlzEMHs67yoGs3P4yF3MnMCPRg3P7AQr6j6E3Ppl+R3bhkcjoHjwUkdNJ3C/W0OZV4a1uYEzVSZp3nKJ3+nRa80fgkivo/9xv6R01mg55MKrqagJ6OtAVFeJVKmmLTMTpgZTmK4hXLSL6roV07T6N782P0WYnsmPlI3jLqljVuB/LrmOUz1yMpK2DHqsP75RxdGX1J3vbegwBITT0G8y9jvO4GlsJvn0JjYExnN9ehKyhiciWasI6mgmYOZrd9hiKZdHMmBhNg03BlW643A5un5BMkVd8jKhACbGfb0cTqiX0f+5AFve1tZbXBz87IGTSvjzUQGiIAnGAmhONQgXv1oFfV8Sux+fVcKVbqKytyvv+6uC1cPcYaZ5wC8oRA5BnJGK0Q6EtkAlDQlBG6pBEhHD2ihnqm8lZM+1b27pOj1AZLOoQFkDSQ+CzCnB4YHH21e3mjy/DhKSrDZq/FK8/TuriyNNbybp9BvEDr+1N6zVbMb67Fd0DK//+P6gfP9/AL/yug1/4+fHzw2A9dgFvTx8BCyZd/7gTBRh+/w6K/GyCbl2Ao6gSZ0UdkjAdymG5yPul4PRJ2HDJScqnHzP4rqmo0r62XDJt2of18Dk0cyfgtTtwldbgqmtGpFTQHZdKsTya0UX7UcRFYDC7ce88RO0tt5MR4Ma4+QBnbn8YY2oGizJ9FKw7yfjPP0I5IAvFwCzOHqiENTdiW78T15qVdJ2tILC4iMY+Ec7ISBI6atk0cjlRtVdYs/U5igaMY2pwL+VGGZFlhQQFSDEGhNAWFMW+OWtY8fKjaO+8gfYeJ931epzaIGz1rRyet4bEuADyg+1EPPscSdMH0VHThXnwYOJeeB5nSxdF6UOINnYgU8nxmqzUjpzEHFclPrcHr92OOD2Zt8bfgbSrm1UfPYUnKoISr47ko/sonL2ci5mj8DlczG8/TcDcCRS8f4RVb92Gr7CEhj+u5+y05SRVXiJ9RAphWTGYN+1H9/BqxCoFL5+FXgf02IQ2a6sJAuXwh6kwNNJN7+Mv4SivJeiOpWjnTfjWe/xpKZxsEubnWkzCYsWOSrC4hOrd9w3m+3zC1uvrF+GRkUKs2T86zN/72nqcVQ2E//FRRCIRPp+PyhoTJwt7WBreAz09ePW9BCycjCTsu52ArS549Tyca4FbBwj2It+8Jodb2BD+8YhvP3dXFXSYIVxkY8yRTSjys1GPG/Kt42wnCvB5vdf8mh8/fw9+4Xcd/MLPj59/Hne7nr6Pd6B7ePU1K3df4nM4MfzxQ7RrFiFRK6/yS3O367GfK6arqI4CcwD9Ha0k3j4HZf+vXW+9VjttKx9FFKDG2GXB3Wela+QYOgcMRr19D6dGzmFi2SHkZiPmuESiOhuJcPaiWzqNSwcqcPSYsP7657guXyF02zYSu2oJHN6f8N89yOZyOF5sYuWZtcQG+DjtDGNcjIvSyQsoO1mL/NxFNoxYTpRVz/KPnmb9kocZE+Ek9Y2XiWuvRRkVgk2qRG+BlINvUXjjr9ApwCZVsDdtAvVhyYy7vB/ZHSvIS1GjMvVifWcTSRIz1Sm5aLftwGd3YGnvReFyYNMGkThlIKdcYZQ6A5nfc4EgpQjFlQpEQPv//BzT2RIyzh7G22em1+bDpVQRKPdxXp1EYb8x3Do+kICsBD7bUkFubjju9FScr36Icsl0Mv70HCE/vgn1+CH0vPAh9oVzKJdGUNkNh+sgSAVZobD1ihCH9dosiNBA7182Ynz/MyJf/Z9rJlycaIT3i+CZSUKF7oUzEK6G6ADBM+/7KOkUUhp6bLA6T2il/qM4K+vp+unzRG/8I2LF1X3lwnahDXtH/ne3XUFY3CjugKONQnzYsFg4UAt1vcJcX1bY139ugDHXKOaZnYIVzP3DQIIX88Z9+JwutDfOQiT5uvRpeHUdQavn/5/J5/Xz/w+/8LsOfuHnx88/h8/txvD8+wTdvvi6FROfz4fxzU9RjRUWB779dThUD+UdXhYfeh+FQgwuN+pJw1EM7Q9uDx0/+h2Owis4Zk2jwSoj6755SA6foHnnGXrHjmNkghjzxr14LTYCb12As7gaWXoCNZtPYu004rp5GbbCCtQSLxEXTqOymJDMnMiO2XdwtlPGi9NBWXSZK4+9gTwsiKbgOPYF5hBh0qNfsIBskYGkp5/i6OzV9G8pJe3AToJyEhDXNuBMTKSjTk/cwER6KloJNnSw6cHfcz4wjaX73maAvJeE5x+m06ukubSFlnd30hYRT/qxfYjlUtS93ZikakJsvZg0wfTLCMTXoae624crNpbAjDgoLifO2YM5NgHR6YsoJD7w+jAGh3F6xb0kzBvJ2feOomxpZsakeOKWjuO1U24SP34Pw123M/XcZyQlaHFX1hGwci61b2ynwaWkNXsQyqH96RcGajmcbxGqc6+cE8TObyZCfKCPvvV7MDz3HtEbnkeRnvit9/BgLdQbBaHz4HDhseIOoS06JOb691FDL3xWKWzF9o8QFibWDPo7bsS/wmux0b7ml+h+eguqIdd2eTjVBDUGuCn36updtxUKO4S2tEQkWLgMibl6JtHsFFrRX1Y0d1bC3deIcPsubKcKsZ8uIuiupYgD1PhcbgwvfUTIT2/9x//Qfvx8wQ+pbf5XDZz9+PHz74958wFUE4ZdV/QBWD8/gTQ++pqiz+wUslpTgn3cVLwFxfhBqEYPEmb3DpzBeP8zeK1WHOdLUT95H8WH6pn05GKkSim7bMGkG3uIO7uf3q1dqOeMQzNxGD6rHV9SNPWXW7A3dRBoNyM/fZwmrwplRytB/VNoRIu9zkjWX15hys9Xo1OFUnWyHKM6GEOnhyqXmPy+Ilruf4A8Zy9DX/kD+8bPIaniEgPai4j+3Z2Y1u3CHByC7XIlileeoay+h/TTz9M8ax4Bba3cc3InGfYOlFEx+E6dR6YKxvXaZ+Qn6jA7G+iSS4isK6ckdyyRbXX0xiYS2dlI37lmGjIHcfz2Vfwo107t2kOoTH20jhoJhWXExAmlpm6Jhsspg5F7XZjsPspi+jHLY0Bp0PPCq5cJxUHS1EGsDKnD2FiJVZrEhbmrKasXM0aupV9XFUO73UiLG5FnJLLZnsgFcwCH6oUIsYdHQpzESs+z67BfKCP4vhXXFH27qoR2aJoO5FLwmiy4mjtIae5AkZsOXLt012WBbRXCzN+qXGFm7uVzQqLFP4rP68Xw3HsoctK+U/QBjIoXrnnrFcFMubAD6gxCdvDAKCFS7ruEXIAclvQDo0NIlEgI/ttFH4Bq1ECk0eEYXvqYwNXz8Xb3/k2pNX78/G/jr/j58ePnKxyl1dhOFl5lTnvN48pqsB05T9A9N3zL069CL1R6lmT5CNu1A2lMJOrJQrnIVd+KadNepAnRGF54H0mYjisZw0gcnEBwdwdnm32kSEwkLxwlvE55LZYdR1GNzcdR24K5voPuDhMBxm5CRvZnw7wfkbpvB/FjszB9soeQX9yFKi6Mumc/IsHSyclB0wj/8ENqY7PZNXwhac4uJrQXMPUnMzH+9nV2KbNQW/qYpNHjCg6hc+kSdA//AqvBjCskFK82gOaASMJFNtoTskgOl5IhNaG9cSb2ogpaHnsVd68Z7eLJdFR3IS0owqkNpDE8GZWxhwy5hT63mNiJeTjLa9mx9EE86anc2nqI7g92cCIsD1d7FwN9nRhSMygcM4dDrmhGxMGd5jMc3V3JG5kLeaDsU3ZMW81DRWvprOli3Et3UXfrE5SOnUX7hCmMiXQSu+VTVMNyUY0aiM/rpauqg/PHG7h4sgGN3cLYHDVpgxPZ2hHI3JrDYLMj1qoJeXj1Va1Ij9nGniNtaPQd5Ps6OFFoYFgsyALVSGMjkUaGYtl/Gt1DqxBrvt6cMDlgZxX02oU4tZgv4tjPt0KnRaii/aP0rd2F/XwxIb+4C2n49f9BArD/ixzYgVGQEvy/aw7sMZowvrkRr8VG8D03+PN5/fwg+Fu918Ev/Pz4+cfwmq0YXvr4ez34PN299L65keAfr0Lk9eG12vFZbPgCNOzWB9Ftg5W54Nq6F5FaRcDscTgqG9A/+jzKYbkE3bkUw4sf4jhfRvtvn0Tyy2cJlbg4N2gyA2+fSsimTQTfvxLD8++jnjQcT7eRuqOleHYcpKDfKDQKCWPSZPS4ZVzqkeNNTSJ93zYabr+btsRM9taAxOtl6NndTD26AblMxKXpywjoaGW8tYpmrwr1lQo6E9KwaYJQBKqQ2Gz0PnAvoc88B+1dWLWBWGPiia4sJsxjZv/jf2JevorALZ8RsHI2lpOFVL+5E3taOnEdNVjOluJGTG9GNvLGJs6PnMXo2tN4jSaql6wkRmLngCSFuKoipNHh5O/4GLtCw+XkfKpCk4mV2amYNA+H3c2yJCvjMhV8WiXnw32drKnaQYjSS8bqadR/egRVfR2i3j5cs6Yy4N5ZaG19GN/4lICFk5FnJVNjEGbWnB4hFaGpD342CqZEmHFVN1JyugFRdR39VozBWVaLdvFUnOW1OK/U4WrTU9SnRJkUTf6QKAy6SD436Lh10NXKyVnTiGXHUYIfWInTK2Z/LVT1COIuI/Qbx3ngxbPw0PC/bev3WliPX8R+ughZUgzaZTP+sZP8L+Nzu7GduIR6wtD/35fi5/8IfuF3HfzCz4+fvx+fz0fvq+vQzBqHPDUer92Bp70bd7seT7sed7ser9mCz+vFfrIQRX42kgANIrUCkVqFVaai6EQ9gY+uYVSSBPOOI/icLjTTR2P8aAeWHUcRKWQoBmbiKK3B09GN8skHqf7dB2SsnsrG0GFMffcPKHoNhD5xD15DH5KIULoPF3Bc149+mz+kPG8srsBAMqJllLW6abvSTkpHFbIBWehwsHfO7RS0Cbmr9+R7iXz9ddoLagmpKKUjJpmw+FCkYYE4jl5AYjRg1obikUgJk7mQBmro8CpRdOvpCQihNyGVjsXL6H9iF+FHDxM2fyzSABV4vfQdu0SjIpSwlAgkp85jtnkwhUVRnT8G7+UrGKPiGdNdQnBdNa/c/geyEtQkXDjOuX5jGfHBKyS2VKDPy2frsh8zO8GJ5O21fBacz23qWiIVLio9QZj77DR0uhDhIyXQS1zpBaRNzVzOG8egcA8BTQ1EvvY4SCT0ffAZAbcs5DLhnGgSKm0ZIbD5ChxvgLsHw4ovfOicrV3o/7yBIyG5jL28H1ViFNLoMOSZyUgyU1jbEUZ6qIixX3R+t1yB/uFXi7kvMR++QHm5nsMDZzA5WZgd/OvK2vYKiA387oSM78NZ3Yh5+2G8VjshP7kFscqfy+7nvxP/jJ8fP35+MHwezxcWGY34th4AQKSQI40MRRIVhjwnVWjValSY3t1K6K/vQzkw66vnF7QJSxwrAs+ga7iEpcqFx2hCGhNB99Nv4bM7UI3LJ/CGmXQ+9HtkOal4O7ppefRFohdOYnO/2UzYt474J++i+6nX6ft4Jz6Pl/JVd2IyKBh+8iN64uI4lzSEIXXnuWKUcsUXTG5HDRKJmD2SVHSBcrJtrQxqLGZmnAPrk+coFEeiFknpHTMOV2s3zsY2mhQhqJPT6VYHo6sqIzwrngarmE/n38/Ki5+i625GUtXIiQlz0TfbmHL5AiEzhmE/UYA7OgzDggV0n6wh5spl7FUyDKmZHBy/DFe7nomff4g7KBhFXTtaYwd7H3sOR2AK4RteoV2uJdt0iFCpE8fKpawPGEHmmSMEv7qHyMmDGRwTwDOyxUzqp2JWmmCVcrRBmFcbYa0lRO7GGR7D8AwtgXIfQY8/geGFD/Ao5BQvv53ChgDyowRvvEttsLkcSvUQogQfwoxdUFkJSXu24dSFIgrxcSUoiSm/FvJjXR54p/DrLVcQlnNqemBh5l/dLz5hdm6/YggT7Nu5V3SZgOi8b91XPTZhMWRuxj92X3p6jJg2fI48NwNJoMYv+vz4+YHwV/z8+PkvxtXQivHdrbg7ewj73QNIdIHfmcNrPXgWT58J7ULBZNnpEfzdpGLB+Fbqc9Pxo6eRJQrlHUlECJ6ObgJXzcO0cS+ummZEWjWWHYfRL12GyGCkJzqBqLorxGVGIg5Qo5owhOZfvUFZZBYp9SUEiV14UlP4xJuBpq2ZiumLGPTpW+g6W9i/4gEGBjnI37eRo3f+jLxnnkSdGkOiswe9U0JJ+lCGfb6OZ+f+grxkFWPWv4ayswPN4GwuWrV0d9sRSyUkLZ9Axp7NOC+U4rS5OLL8XhIaK8i8cIS+oFDCw9XI1HLMbQY8TW24Q8NoGzycgKZGTHoz5sAQjCjoyx9CbuMlQmUeaoPiqRw/kxXHP6RC72NfzlRm7XmX4OhgzquTiR6QQLarkzOh/fgsaCCxWvB4IUUHIWrYVCaYJT/oOMt8UQ3m5Yspf/kzJt0xFp/TRVednrLdhfT1OWmdu5A2bSSdFrB7IEghbNQGKWFRFoyO96LYugvn8QuopoxEu3QaFW/u4g3RAOZNT2B0PPylAMYlXJ2iUd4ltG/nfUP41RhgR4VwndNSQYEbw4sfoV02A1nC1WW9twoEe5S4f+CvYZ/TheGPHxCwdDrmT/ei+9lt17UV8uPn/zr+ip8fP37+KXxuN+bPDuNu6UQklxH6izuRhgR95/GOy5U4SqsJvv9GQDAAXlss5JPmfRHwYDlcgLupHZFGRcCM0TjLatE9vBrrwbP0tRhob7EQ3lIMMjm+w6dxjh2JPCmGqM5yHMWV+BRyztmD8STlMVhlwlPZh9fpoep0NZIELbsmrWLxsXVktFyh4dnf8/qMKFztXXTtNXHb5Y14hkbTMX8BhWsPclkczuyN79EUGEFymITlXaeo9EqQjh/FoWob3QoVw6RdyBoacL3biq28CEd6Ksdzp5JqaGJoTgDWZh3SedPZp81i3Pb3MJjc2CMTCbb0Enz+PLXRGSgCpUQ21xCnkNBtCKH/+Aw8x84SFCeh+9ABrvR4sKgCWLLhBcpTB7Fp3m2smazDUFBJ0ek+YpcN5ON46LLCy2ehsUqPrbacHJeEBZUFjMxW0zB2FNu2NJOpCeHdNwupNMmQ2a2cGXELwR4Lszd9ArMWkpoQjlYO5XrhfbF7YF6sBeMfP8DR3o3u/pUo8jLwWmxEWPT84p4EHvgcPikWrFr6R179np9qFpY0ANrNwqauzmJgZV8ZGqcCeWIuIqWCoDuX0vvKOnQP3vSVj2Nlt5DB+4+IPq/VjvHdLWhmjcN+qhDN/Il+0efHzw+Iv+Lnx89/Gc6aJkwbPkc9aTjuzm4kGvVXW7fXwnrkPJbdx5ClJhCweConLcEUtgv2HF/GW1kOn6P3xY8IeeYhep99B/X4oQTeMh8fIkpXPoGnrAptRjwtHhWX04cxsuE8DXfdy+ht7+C8Uo80L5Ozviiyay8ROiAFx+UKHGW1XEkdSLdDikcsQRsTQv/CI7z+oxf5023RiERgfGcLtkvl2A6eQbtqHvazl6m2qbBanVRPnE17u4XpiR4CP/iYprh0LsUPpCMxk4B+Saw4vY6Awks42rph/nTWxk0kX2tliqyV3re3YJs9FUtNG8bKFl4ZejPLCrcRa+vmyA33skeSyo17XiPcrKc8fwLOoCCmF+/FZTAhDlBRkjGc/hcOondIcMkVGMeMpW7MVKoNEGAzMe/kelpXrUai/NpITtzZRcDGrVzOGE7ayX3IslPpjU7AanPhsLqIElmJOXcC69IFpM0aQkYohKrB3dOH8fX1qG5ZxA5DGE4vjImHkvNNjNj2PuKQIEJ+shrJF8LevO0Q0sRo3DnZPHZQmAmUSeCGHIj/4q9MuxvevAir83zsO9aGvLiMfFs9gZHBKIf0w2uxYT9XjFijRjVuMCKVEsuWAwQ/dBM+sYQ/nYG7Bn9/du838bncWA+cxn7pCgFzJyAJC8a0aR+6+/1xZ378+Jc7roNf+Pnxc218ThemLQfwGowE3jQXd2c3lj0nCL5vxTXbuz6fD/PGvTirGhHrtEimjuP0Xw6hiQlh5O0TkWiF6o6jvJbuJ18j7Hf3YztyAY/ZiiwtAV9WOgW/XUd4eTHBcSH4HE66UNEQlUaVLpHFqmZUVZWEvvlrXi9XM6ftNEEXL+DpNWHecwKjQoPCYKAgYwQhU4cT3dVAgSmA0fdOJfriKeyXynEUVeDpsyAKCebywAkY2oxcCs1isbOU9g4rDo+IsJ5WOjVhtMSmkt5dx+E1j/LYh48icrkQRYXSUtHFY7e9yi2WCyQd3IXH50PV0YFVGUDR4EmcSRzCI3VbaelysGnpw3CxmKV7/0JN3kiOjVuCtruDmzqPEyRx4Sm6Ql1cJpWqaPIvHsCWmkZzfCbn8iaiVUD/MC9DP/uItqkz0AdH4nAL32tlVwcpm9axP3oIEdXlXBozh7whMcjFX2fhVvbAsO0fkv3oDd/aum5qMFL87Aai713CoP4h7PjLKfof3kHIjBFoV8zEXduC/VI57sY2pDEReJbM4a1LIm7IEdrKy3OEvNrEYJiR7OHUkQZaT5QSqO8gY0A08eP6IUtP/Fblzd1lwHbiIs7yOrw+H82iIA5OvIFxCV/PCn7vfen1YjtRgO14AeoJQ1GOGohIJMLw0sdob5jht0Px4we/8LsufuHnx8/V+Hw+nGU1mLcdQjNjNMrBOXitdgx/+vCq9txVz3G5Mb69CcRifE4X3TfcwJZqKYuzIa69FsuOI8gzkpCmJ2BauwtkMpwXSpBlJCLWajDuPE5beDxR3S3ItSq8JgueSWOo6PCgEHnIyAqlZ8M+6l99Bb0qhHxrA0HvfYg1NBwOnsDVa8ag1lEZmU6mqJfyqCzyCg7hUAegiAqlNyMbR1AwUWdP0BsZS6tZQlbdJQpickk2tXOq/wQySk4TGyxB1NwKkeGEOU3YdDoCLhcjHpxL8MKJlL+2g4rIdMIN7cjEYFdqyCk8hkuno77fEJR9vWSKe3HrQtgWOJCA6koiu1voefA+pA472lOnmDI0FLPTx5WtZyEyjM6wOKJPHKa+3xDk2SmYJ07kVDNkSwzc2bQXTb9k5MlxuBpacdU146prwVF4hcYREzgTkMpBVQb3T9QwPklo2zb3CS3TK3qY13YaSVAAivxsfDYHHouN41fsNLXbmRnYhXfbXkRR4VRdaKT/DaPAasNrsiJPT0QxKBtpYjRdzb1sPapnTriRIIeJjhYT9fVGBoW5aTBCk0lEkSqeWUtyGJQfifh6+WcIfnlH6320XKpn1M4PCQ+WEThrDNL4aKQJUYgDA8Dlxt3ahbu1E3drJ16jGVGwFq+hD2dJFcrheQTMm/iVL6CjpBpHUQWBK2f/C34i/Pj5z8Mv/K6DX/j58SPgtdqxnSjAcbEMWWocmtnjEWtUX0WtqScOQ56Z/O3nma30vr4BWVIMrpZOTk1bQatdyk25wtwWCGLSsvsYva9vQKRS4LPYCf7prbhqGilvc9MWnsCIlgI8Zy7h7TMT/OcnWXe8l4jOJkY7GnCXViFNjKG6y0t9SAJZHZU0Dx9HaE052jNn2DduGa2B0fSOHMXIvesZfWA9tal5lN77E2aNCiVEBfqHnkF0sQiTWEWAHDpRYw0LJy1Ry4HIfCbs+oCS/qPJPHeIrrBYYrsacKs1BItdPLfkCZZ9+DTlqYPozclj/OENGEKi6GeopS08gVMP/pKIvZ+jVMu4EpJMzsYP8SoUVERnkjYiFdf5y/TGJnH73Cis63bQbXTTOn8RQ2Og+MGXsSk0XBw/n/L88YTYevlp32HsPWbKWt0MihUTlBKJLCkWVEos2w/TvGQ5e/WBdFoEIfWXueD1wfOnhfixdSXw4DDw9vbR/sZm1EoJboWKMwYl0eEq+icqEWuUuJo6aP14L7YBeWTNGIBiYBaSYO0X7xmc315Ex6EChkzOIDhSi0QXiDg4kI8btIxPk5GsExZDDtTCmvzr31/NfXCgDswOmJAkxKDh9WB4/n3EgRpc9a24G9rw2h2IJBKkseHI0hJR9E9DpFJi3n4YiS4QeVYy3j4Lnh4jPqdTuAf7LIT8bM1VBtF+/Pw34xd+18Ev/Pz8t+Oqb8V68Aye3j5Uo/NRDumHSPr1Hpf1+EU8esNX27nfxN2ux/jOFpRD+2MsrGbTqBUMipcxJuFqjzZ3Swe9b2/G53ThKKwAsQhvaAgXE4cQlxBImqcbwzNvIU2MQRoVSkl4JtpzZ4kJU6AZmoOrtpkeVTB6iRqf002nBZLOH8MtkVOVnEtPYDgfjbuVRfqz5B7eTlVMJuKwECJaa6kISyOv4RL9L5+gKyKOmvnLWJc4lYzy89xy7mOMYiUh9VXYA3XIrRYKB4xHEhNBY3wmyuISwlrrGF52nAOTVhDyo+VM2fBnaO3A1mum06UgYtF49A3d9I4YRZkolNznf0v34GHUeINYQiWdQVEc8sUyqvoUKfo65CGB1DhVpOmgmQBMVxrxZaRQnjEMR3cfOTIj424fh/3AaZoTs/hIMZC0EIixdKHdvJVDE5Zzoi8QxxcbucNiIUQlCMAQFRjssCATrnRDg1HIvi3uFDZu8yKEOcsQuZukk4fQ9vVwZNgc1owL+EqkAzQafZx69zhJrm4G/3geMtnVbsp9DmEL98fDYX2p4LuXeY0Oq88nvPaRBghWwJSUrxM6vsRrseG8Uos0JgJJuA6RVIrP58PT0Y27uQN3Uzs+lwv1jDFIAgOueQ/7fL7v3C734+e/Eb/wuw5+4efnvxGf04XtTBH2M5eRRoejnjwcaUzEt45zt3XR9/FOdD9Z/a15LWdlPaZN+1FPGk7d5wXsnriSG/Nl3/pgd3d00/Hbt7CrNFj6HFikSqwOHw2jJzGtv4oQj5nux19GpJDjbO3EuGgxHQcvkBWjIP7Nx+l7ezM948Zz/lA1/T78C3K5mA5ZMEYUROibsMlVbJhwCyM8Lci7ujiWMZaQKcMYWnyEjN56Ag4fpa+1h77AEKxvv8wFi5b6M9U8duldet1SqtscGGISSG2tRNWjp+HWOwjf9zlGj4QgfRs2hYYYiZ2YsTlUVfYQ0T+ByIQgev+8AVlyLEZ5AG1okLe0IO8z4oiOpnTKAkJFDpL1ddSXtDKkqwwkUiqzhmAeOJCE/ESywyUcuONlVOFBOPVGdKZuvMsXcCh/FuEbNqAZlYduVC4hKjh3sYvYnVupXrCcak8g+dEwIBJKuuC+oeBww1PHBEHW3AdaBSTrhKqa0S5kz87PFKx0rNXN6NfuwTJiOPqsPFQyGBIjvFdWF3xW7iXosx0M7hdIxKIJ3ymoDtTCySbIj4Y56VcLfbsbTjXBxTboFw7jE/++xQ0/fvz8c/iF33XwCz8//9fx+Xx42rpw1QrzYe4OPYhEKIf0RzUiD5Hi2p/IPrcbw3PvE3TnEiShwVd9zVFUgfXgWeSTRnDx4xO0rFjJnP4KDHboMEOHRfjlau5kwKvPYhwyFGn/DCIqiglobiT8p6vRZiXgqmtB/8SrKPL7YTt7GUd5HRWaODJyI4leNQPje1txIuWiOhG3OoBBrhYsOw5jlGspyBiB3Okg1NBGUmctxpQM9udOJWrRRKRiGFN8APfe4zi6DIS3NbD7/t9gDI0md+N7ZDo7KM0aRmOPj7oFS/lZ8UcUH67iXNYo8jqukNRYjiEqDklwIGkNxTTPW4zug48Jyk3mwIzVjH3x10jsNtyaAGo1MQTbemkKTyC1pwGzNgRlZwdKYw9OsQxdsAL1wAzk/dOweSR8ViViuLgTR1EF3S4pxSmDaRoxjhVpDqLffgu53YJvxUI+TZ3GyHgRZSVdpO3ZSveNK/i0ScuaQdBsgjA1ZElNBDlNvN4WRYtFTHQAPD4WVDJwewUbHYCEIGEO07z1AJ7uXrQr51xVPfP54HQznK51MvfURuLH9UM1etB33lNOD7x9SchZ/t2kr0Vdjw0O1UGjEUYnCJVAqd9ZxY+f/3X8wu86+IWfn/9r+Fxu7OdLcNU0CSIPkEaGIUuJQ5YciyQ6/G9qi/Wt3YU8IxHl0P5XPe41W2l5YS0XBk/BsPkQNUtXEhGqRC6BSA1EaCAyAHRXSrH/9lVCfnYbqjH5dD/9Fu6GVjQzxqBdMg37hVJMW/YjjQjBXliB/XIlTb4AgrPjyfzVLZjW78F6rpTjmePoGD6axSfWUdPuwNzYhdJmwR0TRZDdxPm04UTpZAScPEVPSDTWhASmavR4Dp7kSlwOvQotwc31RMhc9Prk7M2ZggYPLrmCrsnTMTph1Us/wRAWw6Df3Ix39Y+xabSow4KIkLuFbWS5jG6xGmNNK5GlRXTHJSEPDuB8SDYpAW52Z0xihfkS9eUdBFqNaJtq6RuYT9qQRBKXjBOEs9fLpiPdDDqzh1JpJFfO1nNg1GIGD4niNxNB5nZieOlj3M3tyPqlUVZp5IQimcnuWrIfXcH6Zi0pwfDnC3BPYjf1nx4j0mfGGx6Grb4dpVLC+ImJqPsnI0uKRST5uj3rrGnCtH4PmqkjUQ7Lver9bOqDLeWQrTCTv2c92nkTUOSkfed9YXYKbd5pqUKe7plmGBUPB+uEOcPJyZCm+3Ycmx8/fv738Au/6+AXfn7+r+DzeLAdL8B28hKqkQOQZyX/zSLvr7FfKsdReIWgWxd+9ZjTIxjzdv15A3t8iWQ3ljDwVytJiVFd9SHvau7AvGU/thOXCHv6AeRpiVgOnMGy/xSumiY000bh7TUhCQ/BWdeMSK3E9NEODF45dbfewYgT2/GJRPS4ZVSHpyIVecltL6NBFUlTl4v9C+7gpu0vE99SSX1SDsqsJNztPVwmjOzOSqISdIjPFbBtyi2MqDlDXEURvQnJVGvjsPikeDUBXMkbRWdmHk0mWNh6kqzDOwgdkILyyDHwQcCssYR4LIgVctSTh9P99NuINEochZX0qbSgVnEpPBvn4rlUd3qY2nwGb1EZIUMy0Vy+zOXRM8lVmokVW5HGRqAYnkvHxRrKC1pRL5/F6xub0NpNlAwcxwNDYXKMA/G769DMHAsJcZz85cd4F81mRJCFYlkkH9ZpCVVDQl8bISeO0miEthHj+OXKGPZWQ1En3NXfSURHI84rtbjqWxCJxMhS4/FabHj7TASumnfVRrbNJZgsm5ywWNcF67YSuHo+sri/cmb+Bp0WeL9IsHNJ+MK/e3O5UF2cnCxUIf348fP/H7/wuw5+4efnPx2fz4f9bDHWg2dQjRiAavzgq5Yz/h68dgfmTfsxGyw0zV1Ei01Gm1n4YJeKIfPScUQ796Mdl0/+XdOQar/+pHdWN2LZfRyRQo6rpR3NpBF4LTZcFXWY95/GqzcQ+pv7UA3LQxIaTPczb+OqbxE82ewe6vsNIT9GhCQsmPbyNtanz6AzM5efb3gcS3ULncpgnAkJOO1uIvo6kIYG0+lVEoWVUwGpRIzJJXxIOsof/ZzS9MFEGDsJSIslqfwispkT2HveQGhnM7awSOpuvYMSdzArUm0E/M9TqOuqUTvtGJNSSUoKJPyWuUgiQjFv3IfjcgXaW+ajf/RPeK02jEFhmJAhCdJSEtOf0OxYwo4fJcFlQCb28nn/GQSPyCFjeDKtXjVj605jen8rl+xBdA4ewQ5NDiNObGfLlNv48RgpLZ02gj/+hEuDJuJJScbqgqXxZobtWYfujsVccIVy6mg9EaeO0+DRIJs+jnZNGCv6wdEmqO+Fpf2EObqr7gunC2dNE3i8KPp/XcHz+eBsCxxrFGbz0gz1mDcfIOjuZUh03/13YJ0BNpbD7YOEJRI/fvz8++IXftfBL/z8/Kfi8/mEWbs9J1AMyEA9ZSQiuez7n/gdOIoqMO88inTGeN50ZDImXvCDiwoAcZce04Y9tB+4ROVPfs78CRFfXYPzciWW/aeQRoWjnjkG88Z9+BxOJLpAFIOycVbUYd5xBNWEIajHDsFRWk3fB5/hbu3C63bjMDspyxnBkLxQgjNiMVe3sqvcRfPCZYx993nEXT10KIJJMjRSm5xLTtUFJPFRbIifSFaUlHeyFhCl8ZJjqGPcey+gCpDTrAhFZjYTY+4gYvUcHu/KIv3KeQbOGUDChRNUXdHT7+EluJ55DXFVDT61ioqf/g/B9dXUTp7DVFsl6mPHcFY3gUiIoPN6fDQPHk2ws48uTRgbhy1j/vnNyJtbkKtkOJKTOTh2KaXicG7N89Fe2oTi2ElaCaBi1FTkSjlU1TDhwCdIAlQEzpvA7Llp9L2zmYAl02gNS+CVc4KoajSCp6eXiYc3YPbJGTAwgpOZY7ljUiAVesEWRSIGmRhcHnhk1N/WWm3pg03lkB4CU8LN2HccxttnJvC2RYhViu983uUOOFwPd+Rz1favHz9+/j3xC7/r4Bd+fv4TcZTXYtl5FFlaPJrpYxCrlX/X870W21eeZx6jCdO6XYiDtAQsmsIHVxSMSRDEgae7F/O2Q/jsDsy9NnYnj+PGJAv2o+dwXq7Eo+9FFheBYlA2kohQHKXViDUqvN1Ggn92G56mdvRPvoanXU/AkqmIZFIsu47jtdiQRIbSa/NSHxxPoreP2CEp+BLjeUQ2EXlXF3e9/CD4YN2gRSQ7u8lI1BC4/wB9SWm8u/gRMi2tbPSkopSLyYuEO4o/JbG1isaf/BTDPU+S4OmlzSpmW8Y01FIf9z07h8QwKZs3V5Lz2h8RVVRjkqpBo0T06L3kDInDtHYXdrGU9ovVeEw2NB4HAWI3Rp+chp8+Su7ETMrveBqdxUBj3nC8IjEDxHoi5R6ci+fws5IIngm4RMWBEjp10WyNGE53QBjtZtApQdXbw9zindhX3cBNvjJUpWUEzJ1AsTKW441w20BhIxeETd3qehOJoRIOdKiJ1cLQWHj5nHCczwe/OSYIxR+PAOV1irx2t5C00WuHJSkO5EeO46pqJGD+xGt6M351n/jgcB3U9sItA4SoNj9+/Pz74xd+18Ev/Pz8J+Eor8W6+ziS6DAhueAaKRrXw+fzYdlxBPv5EhQDMhGHBuM4exnt8lnIkmK42Ca09BbGmbF8dhhPdy+aGWOwniuheMs5EoeloA1WYfpkD8qReajGDgarHU+fGUdRBe52PV6jCUl0OD6TFXFYMI6CcmSZKXgaW/GaLUiiQgm+bwX1dUYsWw4idTlInDIAli/gpbZYyvUw/sxnDD6wmWdXPMXSir1Euo302kBiNCJNjOWdcbdRY1EQoxU2WD8KPYv59XV4b1rC2df2Ic5MoWHMZLL/+AxmXTiXpi8jR19J6qHdqIw9KEICqQ6Ko1kbxZTOS0Q8fBO2I+exHr8IIrGQVIKSkuGT8TZ3EDsqi5E/X8xvPrcy+dB6wlpqsbf3EPGb+wn8ZD3S+VPZd6ITrdXIldRB1Cf1p8MpZXgsKCRwqB4qu+Hm8+uoHz0FU2gEvxgD2WGwq0rYhr0x99obsI1G2F0Nd+ULixhHGwST5jcuwoxUcHlhXw3cM0SoAl79fsO5VuE5s5PdJF0+i/1iGZqZY1AMzPrO+U+fDwrahQ3d/CiYlOxf1vDj5z8Jv/C7Dn7h5+ffnS8j1Cx7jiONjUQze9x3GtkCuFs7sV8oRTNn/FXee16zFePbm5H3T0eek0rPb15HJJMS9uzDiFVK+hzw5nkvd5lO4yooRZ6egKtdj7u5g7YLNYiGDiBW48V2rhhZZjISrQZJsJbAVXNxt+uxHjqLauxgHMVV+OwOjO9swd2uRxYfhbxfKqpxQ/C069E9uJKKvZex/eJ58PkIG5yO63e/YH21guON4Daa+d2f72Ddsp+S3l7Fgt4CThHDzshhjND0sU+ZQb+y04T/aAVKpZi8lmICnniGrsHD2RIylOzmUloWL2Pmbx/CHBWLUiUjtuQCFqRUxPWndtYizsQMxOYR8V7vRtRuG2KxiN63NyNSqXDIlayfeBuGUWNI3LGJ8ZYq3lr8M4r0Em5r3s+0lrM83+9GftSwE8WxE8gH9+dkzCAupg6jgEiywwQPPYNN8MdbWwyJgWArqya8upxTI+YyIxXig6BMD8NjhcWIawkrtxdePCu0WIMU8O4lmJUOJZ1CNW5aqnDc+RbhXDfnfX2eVhNsLIM0nY9xbZdwnriAevwQIdtWfG2PFZ9P8AbcVwNZYcJ1Xa+S6MePn39P/MLvOviFn59/V3w+H87SGiyfH0cWH41m1ljEWs11j7d+fgJneR3yfim4WzoJvG0hIpEIZ3Ujpg2fE3jjbCThOgwvfkTQXcvwtOux7DpK4J3LWHusl3GF+4gYlY15x1Hw+ZBGh9FT10XpwHHMuaE/hpfXYnX4KLGoMfbrT+jlAnQVpfhEYorufZicD/6CLSQMqcOOxusgrKyIoNXzCZg/CfNnhwi65wbq/rgB29rtWAN1eKMi0OcOor7DicPuwY2YIVdO4JIrKeo/hrnVhziRNJx3By5n8dkNrBuyDJdKwx8DLxLYWEd1XR8ZZw6ydfxKLgydyY2ln9EamcjIta8RaDaAz4dKLoYRg2jVxWEPDGZLziykMin9wiD/4XuQed0EuczIeg0UL7mFfYPnkqZxMrSpgNTj+9i66ud0KnX4KuuY/P7z/HH5k9yeYWPspT34AoN4JXMxTTYZZV0wLgH6RcC5FghXC61RnRJ2XPHwTOm7XJh7E8UmFXIJhKqE1I1gpbBha3ULBso21xf/dQv/PzFJaPH22mFdMcxOh51VcO+Qq8Xinuqv7VS2V0CPHRYFtCHdsgtFfjbqScMRyb5bxVV2C+dICILpqf5ZPj9+/pPxC7/r4Bd+fv7d8Pl8OIursHx+AllSDJpZ4763petu19P30Q6U+f1QTRqGSCTCevAsrpYOJBEhuCobCFqzCLFGRd+H21EOy0WeJcx2OasbqXr4JUzZ/Rj2+HKMH23HXdtC2NMP0nu5lv07K5nz8BSsr35MfaeLK/2GM1neTvQts/EYTXT/9i+4Wzqwny9BEh+FJECNLD0B/b7z+EYMITYvDq/JinLUANpe3ECdwYtHq0UkgpboNHZNuRmdCsq7IKm6iIc2/oYzU29gZOEBNuUvYmvmdPobaxnUXcm2fjNJCITZGVC8vYBZ57diyc3jw4FLsTd1MPb0Z0yoO0W4vhVnSAiu6ZPImDeElsvNVJd24hBLsXaZ2DlwNvdt+h3hDdW41BqsUiVlOSNoD44FhZwsdyeRFj2OW2/kgDydngPneFR8kQJ3CPVWOV6Pl76ZM2nyqumyQJsZskJhbgZkhAqWJ3GBQkzZ2hL4jegMZrcI7eThtJqFSuDGMhgYKZgfq2Vf/1JJv/ivDMTfEHZbyiEtBD6vgXsGfz0L+CVeLzx9AvRWITc3vaEE69HzBN+17Lr/YKgzCEIyQgMz0yDwu3c8/Pjx8x+CX/hdB7/w8/PvhEdvoO+jHUjjo9DMGPO9gs/n82E9eBZHYTmBN89HGhHy1de8Fhtdj/4RSbCW0N/ej0gkwtXUjmXnEYLvWY7P58N29AI9x4rYlzyG+e2nkYbrsJ0qJPR/7kQcpuPQYx+R8uiNKD7cyDlvOLFxgaT3NiCLj8RRXIll93EkUWHIs1Nwt3bh6dAj1mpwVjYiTYimPCqT1KJTyKNC6ezzcDRmMPktRTi8EiQKKY4JY9kzdCEV3aA3e3j/jVU0xaXTEp5IQ3As+7OnEx0It5x4j0+HLgOtBpkYWtssrDj4NhIRbF1wH7OzJIQ98RSpZefROi24Rgzl3H0/Z9VEHS4v/PwzC4H11WRfOU/CheOENdXhE4lwKtUcvfNRptpreH3MGmxuEYP15SRWXIJbbmBPDYTu2kVGnILO81UEiFzYliygKTyZk81CVS4rFBKD4e7BQkzZ9kqhmhemhuONMCvGyg2n1rFh8m0Eq8VMS4ForRCl9m6hYMOSH339+8LhhlfPC+3ecYmCuPwmbSbBaiUpCOoNPmZUHiLS00fgqrnfae3T0gc7KkEjF2xddH6LFj9+/s/gF37XwS/8/Pw74PP5sB06h/1iKYGr5iKNDv/e53i6e+l7/zPkOamop4++alDfVd9K39qdaJdOw1FchUguQzNnPL1//JCAFTMxbzuM41IZisE5fJI0hUVjdIT26em469fIkmOIeOkxLj63FUNOLiFnT1MWm8P4lgto1BIsu48jz8vArTeiuH0ZolFDsLz6MR6zFc/U8XjXbsNU20Zx5ggSyi/SZ/OgtfaxfeLNrDr6LjafGJ9YglQiojBlMDVDJ3BUk8HvdjxBSEUZL9zzEnOLd7Fx4q0YfTLSumrR1VzhwMBZjIiDWJueqA3rCBa76L39VpL7WlC98xEhly8hk4q4tOBm6pfdxD26ekynL/PZUT12hYpxWSpMm/YRVVeOJD6astgcCAokqE9PkSway/Dh3B3WjKOgjKLBkzjeLCGorBhvWjKyqmqSO+t4bskTqNVS3B6hJRqsFJY3rnQL//1/7N13mFxneffx75kzvc/23vtKWvXeJUtylXsDF0wxYEKAFBKSN4F0AgQIhGIwNrbBXbLcZMuW1XvXrna1vfed3mdOef8YY3BsCwjGtPO5rr1WmrKz1hyvfnqe575vqwEGg7CyNDPhojYL7rj4HEdyZ5HdUsWhYfjzJT9/DyUFHmsDhzEzS/fdCij2DECPLzMZ5Zr6n9+ekOCFLpiOZXr5ZYspZn6wjZfVcjbcvextc5MhM0rv+c5MIchVtZD77ouBGo3mD5QW/C5BC36a3zVp0kvo0ecxzarFumn5L520oSSShLe/hjw0gfPOa94WEmV/iMB3Hsfz6Q+gc9hQVZXwT19CGp9CmvAS7+jnXNFsAkuWEpiJUhSaoFkMYD1zBr1BxNRSR2QmwnFHDY5IgCyXgcLXd6FzO9Blu1FvuYb+7+xAdTtJFxSgD4cw9vbTs2g1UiSBZWqCEVch6w5tJ2S0M1NajZSSmDfVTlQRGSuoZLB5EXvczWzu3oswMYlO1DH34hH+32ceYnOolX5DNueK5yAIsOnlhzi09kYaKu3c7hiGh58mGFdoK51NwWgf1YkJbHsPoGR78Hz4er5cdQMbz76MVa/wQ/cK6qQZ7p14lYkn92CIR3F/9AYOxLIRZBnr3HpCP3iKRG4+NasbSB4+w+i6Tbjses4MpljPMIHpMPbGCvZNmHipYAkuM/zlskzoOzmeqcptzM6cw9PrMufshkKwtQ6qk5MEtr3OQwtvY3Eh6MXMeLP/bd8gXJiCD83NbPH+IlWFf9ybWZH78yWZ7V9Vzbz2noFMZe+c/Mw/BIL3P43tmrUkqmv4/im4d2FmlRAylcPPd2VWKa+p4x1DoUaj+ePwXmYbrb5Lo3mPqKpKbNdhkq1dOD94NfqCnEs/PpUmfvQc/v/6MdLoNOblc0F8e2O16At7cdyyBZ3DhhJPkjx1gfToJOGfvoShtoydn/4nrlme2St84gK01IL30RcZmL+GiNlBvKAQ57Ht1A0+jqM8D4vdhPUjN+C8eTNdX3uCkZ/sJW9JIyeuuYtIJM2ir34R3fL5LJOnELtOE+0dZ55ex/jsBeT+zT08PlnK/J2Pkeg+xf23/QNzAr1sy1nCio49MDJK9UQPucFJ/PXNfCn6Gt6DZxmvXknT0AQmnYI9y4Y7x8bCwZNEn36CoDdGd+NCEtMRnGMDWC6eI1RQQtnVS9iZv5BPHP8xj+SvIeqPcf2Z7cwP9+M9dIaUK5uiu6+kNWLhpbKVLF5cSN9UmmsXn0ERdAy9fIJzl91AyeAguukxlnhMTM+Zg3T1LMa+/jB7F97M1vrM6tqXD2Xm1ObboSkHvnIZqMB9L2VC2H0LwaIkCTzwAq0br2dlPhwZgfsWvfN7u6Ycih3wrROZytwCe2arPrrzAH0TKfSpMu66tQwBFxMRgSfbM9u6n1mS+T5+Vrzj+vD16AtyMAEfnJOZqXvXnEzDZ188cwbxZ6PWNBqN5lehBT+N5j3ws2IM07xGPH9x9yVX+aTxaWKvHUUan8bQXI0qyZTsf4jYS/uZvPdLmBfNwnHLZozVZciBMNKEFyUcJfCdx1HjSYwLm5CnfdivXkOPPofZ8VHybNl881imEbC1q5NcSwzZ5Me0pIyjjx+gQB/GPasMncmIvjAH0/xGjv7b04jnO6h2irySdzlX54Wx7NlGusKOaIyROHKOGW8c87L5PH7H37Ks1szXLggsGj/H3F3P8B9Xfp5rD23n0Kx1bD21DbeQonXd1fQNDnLzvoeJSgJHTgc4tOCDeF351FgSXLfnB+jMZjZ+9W70kQjdtiLa1lzBlVIv7tgQ3v4uQs4sZppbiCgemrtPclxXTN3p/RhynDQefY1ALMl0YTXMbuSHpZt4WS3ni2sgnlKpObqfWtXPxJl+TLLAUFc7eyqW4K/ZhCDqyDdDZJ+fDUYLH15hwR+HIiesKstUzRp0me3WfzsIrVPQnJuptpXHJvE/tAPrtRs5EfRwsz2zTWu8RAPkmiz42Hx46JTExuFjFPV3YLtiFT8Ku7m7YBjlxdc53BUghJEbl5SQ3VyGPl1M7HAbyVPteD5751saeRc5Mlu5j13IFG3UZr37a2s0Gs270bZ6NZrfUPzgaRLH23DeeTVijudt9yc7+pAmZhBtFuIHTqNz2rBuWIqhqoTpL3wDY1M1rg9eDYAcCBP41k+Q/SHUZAppdBJDXQW2tYswL21BMBnx//ejSCNTJP71C7x4UeGmo49zrHQBuUsbWGgNEfjuEwhWM+nBMQb02YhSmvy2U7g/cQvhJ15BNhg5NmsttTseR7baeGrNncwq1FH+7FNk9XcjXLYC9dxFuvJrsPm9PPrxf2NOgchAEOYLM8z+9L38z/V/S2WZnaXf/TLnKuYzJzqEr7iSdjGHrUee5Lu3f5GC/otccexpvnHbv7BMHmbV2Gmy0mE6syrR+/10kkVB3MuySA/JQJRwSgW7nR7Rg9ss4HfnkXC60dstrLy4l+jZHgYLqzAAus98mOy187n/NNzgGMd97gwjHeMUDHXRr7o4fdd9eBUjV7/+CEe23M7rPgeVrkxz5Msu7GK4qIZrr65iTv7Pw9vhYXitL7PVuncQJqKZKl3l+Fkaes9Q/ekb6FOdRJKZ+1aXQ+klfsQoikr0yHnCrx3jQMlCAnPmkVAEBgPwgTmwux82V8Fsd5J03wjpniHS/aPoi3Kx33DZu/bm02g0f3q0M36XoAU/zfsp1TNE9KUDuP/s9ndc5VOSKSbv/RJqPImY5yH7Hz6JPjcTDtNDY0x/5ssUPvk1BL0eacpH8tQFkhd6SQ+OocYTqKk0xvpKDNWloKrED5xG0Ovgpqv52lQ5xTkmPD0dLHj0u+RtXgzJJKqskmrrJlJbx4WaBawaPoGxqoSEYOCUrYId/hw+9KP/h8UAO//iPyibHqTgxCGyDh8g5vCgplL0LlxLUWQSzyduoctZxslxCM5EuftfP8a+pVsJrlzF5u/9E22V81hwbh/P3vpZ2mxlfPa5f8ca8HJo2VZqO0+Q0Blomu6m4r8+S+jweXaOWsgKTZMfmcYeC2Fd0EhoeIaE0UJZoZn4y4fw6ax05tTgjgdwyXHqnSmS0RRnXNXU5OkZv+8+mtwSe548i7mvD7mogKPF8yi3pVn4/a/juf+f2R1wsqUGSsKTbPvXFzF96k6e6NSzsVyh4EcPUPfPHyHXJmA1ZBoam0RIypnmzC5Tpsq2d1piw+mXSAoiRxduYc+QyFAo06D51DgsK3n368Ix0EvRwT1MV9RxpmEZ0ykD8XSmD+C8gkwz5S01l14x1Gg0mp/Rgt8laMFP836RQxEC3/opns/cgc5mQVVV5JkA0uAY6cExUu19RF8+gL68COtly5D6Rki2dmOa14DObs3M5q0tx1hfAZKMmOPBNL8RU3MNgtFA4LtPkO4ZwnbTJuTxGUL7TzNRWEnwwiDt2dXMn+6gJDiOIcuBd9CLqb8fnSShc9oQGms5Nv8yNtaKxO5/ghMbb2JqNMieNTeztWsXjU88SMyVTdSdTXlTAYaDRzGUFWCpLSXUOUKX4kYuyOOHK+5BVqE0McM19/8TM648xjduoXzvK8RjaUS9jlRtDcmeYRzpGHMvHmaqsgH0IrGUyqHL72BlsUTZV7/CVGEF+nCYXDWO6ao1nOmKUjLajefatbgDU4Qf2k5CNDNeXMVI3VyONq3imiNPYRocpLu0mdklBjoKG6mc7OV80sl401wqFlUxFddhOHOW5c88wPDdH+E5VwtLimFJUaYlyupAO3Mm2jmx4Qas3d2Eu0fJvn7dz5sqSzAZyUy4qPNk2qGcaQ9w+dFn6Jq1FG9dM0WOzGSOMmdm5JpelynqsBjAqs98tuhBGp1k+JFdDBqyODdvHXl5VhYUZrZ9RV3muWZ9pseeRqPR/Kq04HcJWvDTvJdUVSXdOYA840dNpVHTEmoyhZJIEX1xX2Y+rs2KEo6ACmKOG31ZIdLwBOmeIQSLmey/+9ibX0+a9hN64Bn0VSVEt79OwU+//M4rhZEYge89geczdzD47w8zeqyb7o98goW7nkLKzyOYFFh662KExmoO/Nn3yPVPoEgyqUkvtolx+psX0Vjrwtzby4CjiGRrN0pdNZUuFd3RkyiqwIWGRZROD5BSdESy8wjqbaixBM8svZnVg8cwDg3RWjmfssAI9qlx8iPTnJ21EtlgYumJl5kuqmTiss0UPP8slSNdGKUEB1ZeT8/qLSx77DtMF1bS3H+KwuEeoq4szPEorgX1pPJyGT3Rh6eukMJkgOSpC6ixBMyq52+u+EeqmwuYq5tm6Tf+mcnJKMM5FRAM4WtuwbRkDp1FjVw3x8CRETg1qjJ7+6PMHTzNC3f8JZOeYm5uzhRXfP0onJvM/Fp5eQ+b6vX0nx1h419eiT0v87MhJcEjrZmRaeXuTHWts7eL3IP7WfzX1+IoziEhZaZwxN+YwvHIebisOtO6JSH9fCqH7eRJ3B0XSN94DS1NHirc2jxcjUbz3tCC3yVowU/zXkkPjhF5+lXE4jwM5UUIJgOCwYBgMhB99QiG4jzMK+YjmAzoHDYEQcgUeTz8HKb5jSTOXsR117Vvbu3+jJJMMbr5XqyblpH1Nx8h3TdC4vBZ0IsIRgOC0UDiRBuGyhJ6BTfBXUdpuHwO0sPbkQ0G9m79MLffM5fkkfMce+wIRVMDuKQE9tuvIN09xPNNm5n14++iRmKIo2MYIhHkynJs/mlSOgNRh5tDi65gmcmLvroU47cfoKOokXz/OH151SScbsp7znO2YgE14534jQ5EvY6EzsC5uqVsOPE8FinBdE4xc/pPkRTNWHIcEAzBk9/n1a/vImdyiLmRPrIC0wyX1ZMOx/FEfcgWG4Z4FEEvkh2eRknLqFXlDF95LTvPRCmszqU8NErjYz8i6vQwnlvOyfxm5n7uOs77DQwEMlW3rw9AZMTLp1/6KumiIvj8J9lYI5JtyWzD3n8agkm4/yro80PrpErt809ikpLw8Tvp8WVuPz8JcwsyRRNVHhXphd3sOx1g2V9uJdv19hlnY+HMOcA7W35+mypJhB59AZ3VjP3GTdrZPI1G857T2rloNL9Fsi9I+OlXAXDevRUx2/2W+xOnLqCzWbBfu+HN21RFIfLSAdIX+3F+6DrSXQOY6ivfFvoAEkfPY6gqwbxwFhN3fgHz0jnYr9+IAKgpCTkQQtp7koPGSopefpEFt61Ab7cQW7uIJxbfzNWHnmTyuh/R2rCU8gIb9vYpsr/1BaSRSToFN/UdJ6iYV4r38VdJRxNMrFpLvyWfwrE+Sgc7EAIhbGqK9imV+du/y6g9l7rOUwznlBEyOcgL+/BaPJQERultWUbLyHksg30cXX0tW/c/SvlUP7GcPOaMjfHKgs0YlrTQ2Hkco1lEvuouNikyxus3E5wu4cRn/ooFi4v44tNeVj77I+onOjGmk6j5uQSr5xKtqGRCMZN86TyzjAplR5+hYGoQli/CXFTIcPN6Khc1c86XWVW7byE8dFal6MRhbj/5BMoHbmDlR9Zg/oWfZDNxKHdlVtu+cQwGA7C4WOD42uvxKHEqUpkxZnodfH1zptUKQGz/KfwxmL7hRrLfoUWKqmb6860p/4VrxR8i+IOnsa5bjHnRrP/jFafRaDTvH23FT6N5g5JIEn1hH+nBMRw3bMJQUQRA6mI/sj8EgoAcCBHbeQDH7VciGPSZdKEoxF47innJHCxrFqJG4/j/+ydk/fWH3jZeS5Ukxu/4W+zXbSDdPYR58Wwi215DTSSxLJ+LZf0Spl4/zStiFctm2qm+YgGmWTUEH3yWVq+IfWqcio0tPGqaw8ZXH8X4k6cJfeRuxsYj5F1spWvrrWxd6ubiP/6IPlwUSgGcgRlS0STmeBRZkvG589BZzejjcTqLGsgLTzOSW47RYqTQN4ptoI/J/HJG62ZR1HqSqrFuws4sTKkE2SToq55D2u7gkYarufr8i8wdayXs8NCXXwPJFCU2GV08Qcm/3sfBCQPR/36Y5ed3I+qgr7CWjpr51NtSdJvyMHb3YEvHMYhQGJokN+pl4OOfpODwAQZN2ZgUiYVbmvixaS6S2YJ3LETpUz9lYWqYpi9/gsJZP6+wiKfh0Va4OAPDoUzxRSCRGX82vxBumwUmfabXoVkP1zdkii3gjS34B7ezfcPdrK3UgZA59zcZzXwkpczjCu1wc3Pm16nOfsJPv4rrQ9eiL8r7rV+fGo3mT5e21XsJWvDT/LpURSG+9wTxI+ewXbEK09wGBEFADkUI/+QFdE4Hhooi1GSK0JOvYL9yNTqHDRQF1MzzzfMb31wZDD64HcuyuRgbKt/2WoEfP0f40edw33sz9mvWoSZT+L72Yywr5pI8383UwTYio15y5lZgVmX0xfmk2nuJSyptizay7ppGnmxTWXn6ZdR9R+lpXoxrsB+Ld4rWinkYq0oyff9MRuYefQWnx8xIXI/O6UDcuolg1yjOiJ/I0CRx9EiqwHDNbOzpOHOYRjxykjQCg0W1lE4P4gnNMFVUhaDKWE0iUb0Fo5xkVHCQ7x1jpqyaV6+8h15XGUXjfazvP8R8VwLFoCf0ymE8EyPodBBubMKVbWOyaxLSaWRZxh4LkzBaUPQGzCYdtrJ8dB+5DWn/CR6edyPOfDfhqATn21k2dpbRuIhxZITmcgvL//ND4HDQH4AL09Dvz4Q9ix6yrZmWLDVZ8PVj8LmlMBKCh85CpxdqszMrftEURNMQSarM3/5jji2/Cl1eDguLMj368t/o1Zdve+v0DVVVib16mFTXIK4P34DOYnpfrlONRvOnSwt+l6AFP82lKIkk8qQXaXwGeXwaaWIa2RfCvGQO1nWLEEQRVVVJHD1PbM9xnLdejqGqBFVVCX7/KSwr52OaVfOuXz/VO0Rs93HcH7vxLberikL0pQN4v/gdcr/9BaxLM4fEAt97AvPG5Qy5S3mxB7L37KYpOoz52EniJSXECooJVVajtHaycEkRB/tlKk8dQvYH0dksxAQjZrPIGbGAWUPnEUJhrPEwiqBDMBl5Yf2d6GrLubV1B+1jEp3VLTSER/BcbMXo82IrzsEYjxEPxAgKJozJBDGbg6zQDDGTlUPNazFYzRS4RXLOnIRojPNVCzDHIhxcfg0jpfWI8RiFST+3v/Q/eNQE8YRE/kgvOgFmymsZvfEWmoYv0DUt07FiM6nuQVbsehyLkMaY7aai0gXhCN51G0iOTvOT+TdTW2TiyEhme7XMlTm3N/vcPgrDk1R/6nomk3qSElR6oCkXRkOZ4LeqFHYPwIfnZUavSUpmlW5XL/T6ocYDVmMmyP0s1OUfO4jHridry9Jfev2oyRTBB59FX5SL7eq1v3Qcn0aj0bwX/mCC37//+7+zbds2Ll68iMViYfny5Xz5y1+mvr7+XZ/z0EMP8aEPfegtt5lMJhKJxK/0mlrw0/yMKkkkW7tJnr2I7A0AIBiN6AuyEQty0Bfmoi/MRWe3vvkc2R8i9Ojz6EvysV+9FkGvR02lie09gRpPYt+67t1fT5bx/+eDuD91W2ZF8A2p3mEiT+9CmvYjmAzk/uufA5A400Gio5+Ha66gbRIWZCVZ9sT3MI6OIjTVom+sQTlzAWVknNG/+At2vDbG+ld/gifmJ2qwUjTSDWYTU5ZsnHKcgDMHVyyAZXoCSzzKjCefmNGGqEj0e8pYOHCSlN1BwOomf3IIyeFEh0pAZ6W7oBaDIqHLdlGQCqCLJ3D2dWNJJ0ibTKgpCVWFgdIGcsMzGBqrKS+zs3/CSJ4dKvtbCc1EEUIRVEFg7ObbeHL+jVRfPMnyFx9h0JrLM1d8gss6XmPlnqex1RRx8Lb7aJqTj+8Hz1D/2Rspi0+zXamiMkdkMACHhjPn706NQ3rSz9azzxL/2N30BgT+bFEm9AEcHIIeP6wtg3/Yl5m2Ierg1FhmwsW5KZj/RgGH0/TWSltpbIrw4y/j/uwdvzTESVM+Qg9sw3bVGkyza3+9i1Gj0Wh+A38wxR379u3jvvvuY9GiRUiSxBe+8AU2bdpEe3s7Ntu7N7JyOp10dna++XvtX9WaX5WqqqQu9pM4fBZ5xo9pdi22y1ch5mdf8jpSVTUzgePQGRy3X4mhrDBzuyQx/fmvI42MY1o0G+n7MxjKCtCXFqIvK0B02t/8GrHXjmJe1vJm6EsPjBHZ8To6pw37LZcz81dfIfdbfwdk5rZGXjzA/cs+xNAULC9IU/HAD5DOnKPtzo8xvXApgiLTcKyPrto11HzsC9yTDqBaTHjrmqnra0W3YTEPLLuLRWd2E163goVnX2dyx0G81Y1MFZTjaj2LPp1CpyosHj1LyJWNMRGjaLSXRHYOpliUsMFCxGIj4MiiLDBKnjGFcdUyLhzpx5VXQkWZE2nbTvob5+OcW0+oc5Tuq25kocHPzgtjNHYewTk+QlxVwOGk67pbuXjFTRwflFjadoSlQyeZcOVRaBP5y5/+Hbn+CQr/6k48H7uBknCKfX/7CHV/fgN7Q24kxU22BdJyZnauSZ9pppxrUflS5AVK//kq9IUCgUSmaveuOTAQhC4v3NQEn9qZObe3tR6e68oUeAwG4S+WZsadve09lyRCDz+H62M3/dKfMclznURfOoDzIze8Y8GORqPR/KF4X7d6p6enycvLY9++faxevfodH/PQQw/xmc98hkAg8H96DW3F70+PqqpIg+PED51BGhrHUF+BZVkL+sLcX+n5sjdA6JHnMVSXYrtiFYL483EKkZ0HiGzbTd43Po/OZkEOhpGGJpCGx0kPTaCEo6ATELPdJFu7yfnXTyNPeons2IPObsV+zVrEHA++//wRCAJZf5VZzfb/cDv/rbYwbslm7fgpFvo60Z89T87ffRT7yvn44nDyuy/TFbdQ8eI2co0y4YQK+TmUhcbRx6McXH09Rd1tRCuqmD92nsneacZLaylvPUZIMNFWvYim9ASqL0DM6sCpJnEP9iGa9AQsLiI6M0PlTXxnxYf5+z3foNEQYPIzn2HnnjGWHHiWuM6EIisc33wztx5/AjWRJOHKQohEsUUC2KfGCVucTJXVcLpuKfXhEXKdImNhKEt6MQe8jAZVdLnZuCaGKTIkqLz/7zBWl2W2zr/zOOrKJXzDX8WsXDg/lanAPT4GCwrgM0th/xB8RD2PbnIKx3Ub33xfAgn40r5MU+VNVfC1I3BlLdzYBDNR+KvXMrff2PTzAo7/Lbz9NfR52VhWzLvktRXdsQdpyofr7q0Ixre3eNFoNJrftj+YFb//LRgMApCVdenp4pFIhPLychRFYf78+fzbv/0bzc3N7/jYZDJJMpl88/ehUOi9+4Y1v7eUeJJUZz+pC71IQ+PoywqxLJ+L/vYrfq0V4uSFHqLP7cV599a3BUU1lSb0wDayv3QfOpsFANHlQJzteMtWnyrLeP/1fvTF+Uzc9XegE7BtWIp5QROC3Up6eILEqQvkf+8fAZg+3snzu8fIt0fYWmmg4aoFxF4fRV02h+i8+Tx9DrjQRW5rF7M6OnDl2klYbMxpyid2oZdAQuZswTyKjuxHcNiYdeQVvME4SkqhcnwfPpuH/qo5zIsNYE7HkVwW4lYHo0ET43ULSAgGqqd6CWXn0dh+hAda9+DLL2UoK5v01x7g1guniZls+KvqMDdX8/n2x5AqsuksbCA06sMY7GHYVID/6itZbZxigVWhot7Bt0130+4TuXzgAJ62MaShaYotIgFzLk1fu4/8RXVApqii5yd7GDRU0BOrwmnKNEWOpEEAbm6Czy3LrOp9rDGO9J2jZP31PW95b/oDmW3gk2OZat4lxZmQNxSEf9ib+RpXXGI3Nt03gjw2/ZaWPG+7xqJxgg9sy8xS3rpO23nQaDR/FN63FT9FUbjmmmsIBAIcPHjwXR935MgRuru7mTNnDsFgkK9+9avs37+fCxcuUFLy9uGYX/ziF/nSl770ttu1Fb8/LkoiSbprkGR7JugJZhPG+gqMjVXoSwv+T38pR3cdJvz0LpRQFPOCRnROB4KoA1GHIIpEXz+GmkjiuH4jYl42+tIC9MV56MwmVFVFjcZRgmEiu48S33sCy9K52LeuQ8z1IA2Ok7rYR+JUO/HDZ3HcsgXbxqV0PXOEoUd20XPNzdx532KcWVbiJ9sY/4fvsfuLX0fQG6hLTCD+7b9iCPpx2vREMFI8tww1GIbZTZztCuMvrSQhCXjOniQZSVA/cB5ZpyeQU0hc1WMvySLfP8HxxpUMrbqMrNA0C3/6PYZN2ViVFN3mfGpGLxK3u3BbBSjM51gyi2v3/4TBvEp6/vU/SFoc3Hn2SZI1Vew7ME5gPEB2Okx3VQsr7QE4eprk4oW83LgRFSg4dYyaE3vInhlHUlRaZ6+i+6qbWNNsJduSmXAxHIL8ngtUjHQyctV1tE0LTMcyxRnRNHgsUJ8NExH49GKoeGUH5sWzMTVWvfm+nZuAxy+AywzryuHLh+Frm8Abg9f7QS9mKnnfjZpK4/vKg3g+/YG3nMX8RdLoJMGHduC4eTPG2vJ3fIxGo9G8X/5gijt+0Sc+8Ql27tzJwYMH3zHAvZt0Ok1jYyO33XYb//zP//y2+99pxa+0tFQLfn8EVEUhvucEidPtCHoRQ10FpqZq9OWFv9F0hNTwBP7//BHS0DhiST7ue64n8vxeHLdsQV+UC7JCenya6c9+mfwffInkiTaS57uQRqdQgmFUSc5UCIg65PEZJEEHsxrIu2o5hoZKSKZQU2mk0Skiz+1BX1OGPOnluL2GC30RGrcuYuuNDSipFL5/uZ/x0/2c//CnWLWunK5vPUvhjx9EFcCyYgGTw35KSxzEbC5cN2zktXNhsscHSY3NMJHUUxiZpqinlemSKqbK66k6sY9kUyPijJfh8nrmeiRi/SPIvSPoUEhYHfRVzQGrlWPNa/CMDiAG/Sw+v5fKiR76ShrxZRfiDk7hkJMIAhhCQaIGC6rewHBBFQ3DbahpiRNVi4kZbZSFRimZ7MeUToDFwlhlA21zVqFbuZipKMgquM2ZwgqHd5Ksl17itc13kG3XE01DbRZMxWB9Raagoz+QaaWSOzFEedtxcj9+I8tKMs8/NATfOwU3NMJVdbCjE7ItsGcgs6VbaIdN1VB6if/1Qz95EWNTFeZ5je94f/zoeeIHT+H66I2Irnc4HKjRaDTvsz+44PepT32KHTt2sH//fior397b7Je56aab0Ov1PPbYY7/0sdoZvz8OyY4+os++jnnxbCxrF77l3N3/hSrLJE+1E917nFR7HzqXHfPcBpwfuArBaMjMxv3O49iuXI2xqZqpe7+E7doN6MxGpLFpbFeuRnBYkQbGiO8/RXpglFT/KDOVdfQYcnGTRO4ZxBHyYl02l8JSF/KufejsNsKFJTzrzyavr5OWYj2zvvEp4ntPEPjGw4wsXImvoISahx/IBFFBJW5341pYz3THKK7mCg6tuhZ/TjEv+9184qG/Q0mmeWHx9VxzYhvmRASTTkAfixBzujles4xcOcqstfU0Bnpp29tNwBeleHKAmM3N8OKV9K7YxEV3BcHecW4e20/FwddwGVVGamdz9pYP0zDdQ9U3v4Y5EuL4nPW80LSZQU8pLcPnuOrks3TXL0S8bSvr+w8S/eFTpCNx5LISLCvnMzIYJDYdIqfQgep0EKyuY7S4lmHFhhCNsfG1R9F/9DZ6JAf5Nri6Dn5wJtNWZTAIHjMsLobZ2RKD//IQ7dfeziGfldEw2AyZMWxf3pgp3BgIwKt9cHcL/MehTOuWMlemlculrqvEoTO4PnLD2+6TA2GiL+4HRcFx2+Vva76t0Wg0vyt/MMFPVVX+7M/+jO3bt7N3715qa3/9FgiyLNPc3MwVV1zBf/3Xf/3Sx2vB7w+b7A0QfuLlTGHE9Rvf0mrl/0KVZaLP7yN5oQdjczWJExdQo3Ect16OZemctz42mSLwvSdRJZlURy/u+24jcfQc9hs3kzh0hlRHH4aqEnTZLhLH2ziy4hpCWfnc0pxpHwLgnY7R8z/bEXe8THf1PJ5ffwcj1lycOomPH/wB1VVubLteQ2c10XfDbViefBZLRwfTjlza56ymrPssOquJuKLHhMTFqrmYVIlkJEn9UCvGRIx/uuFf+NzOr5AqLaah7RhIEk9v/gh5vR3MCfXjduhJKDoC/gSTZg8tQ+dIebIYza/A4p0hbjQTMDux6HVYrXqSl1+Gr3sMS083ZTODZE+N0D9nCYeymph98SixrDyKpSBuKYqtJIdUXEI6f5GEwUTX5usovmU9Z0MWao7vYSqkcO8/bMBqyLTGSbZ2kWjrIeaPEQkmeHXxNXTaSqj2gD8OpycyVbwryzIrd/U5sKESIi/uR3TZsaycD0BKzkzkqMsCsyET8r5+FD6+AJ7vyvTyq3CDKIDjXfopK5EY/m8+iuezd6KzmjNFQUPjJE93kOoeROewYlk+D1PLu7eb0mg0mt+FP5jg98lPfpKf/vSn7Nix4y29+1wuFxZL5rD8nXfeSXFxMf/+7/8OwD/90z+xdOlSampqCAQCfOUrX+HZZ5/l1KlTNDU1/dLX1ILfHyY1lSby4j6kvlHsN2/GUFrwnnzd8NO70LkcmFrq8X7pO+jsVjx/cTf6vHcuMEqPTDD+gb/BOKcONRDG2FiFmOPBsqwFQ305kWdeIx1N8vSsq6gvMLC2IjPOq9cPHTMQ33OchscfpO+K60nPBFAQmFq7gS19+7AN9OOfjjCjs6EbGiFruA85rTC0aCWN//xhQh//fxQVOem+66PY9+7jwpYbeH3KivXkKTb3H8Bz4RxnapawYvgEiYZ6XIcOETI7eK1lC1ee3IEzFWWqvIZYEqLoyQtNkhP14S8o5WJuLZZ4GIuSRh+LUOAbI2m20lcxi3xiFI/3EUWPEk/RWTMPIRYnOziJiIo9GcXTUELE7mJiMo5JhOJP38LszbOIpqFjUmbs4ZcJKnpu/8Im9KKAqmYqbweDmYKL4aCKEomxpN5GKAlnJyHPmjnz9zcr3pigkcq0XZGm/YQe3oHnc3e969nNZy9CsSMTuDu9mXFslyKNTxP80XYct2xBjcZJnG5HnvCiLyvENL8RY135b7yqrNFoNL8tfzDB791+aD/44IPcfffdAKxdu5aKigoeeughAD772c+ybds2JiYm8Hg8LFiwgH/5l39h3rxL7N/8Ai34/W6oaQklGkeNxlCi8Td+HUeJJRAsJnR2688/HDYEqxlBp0NVVZInLxDddRjbZcswLZr1nlVPJjv6iO87iXXtImb+8dtY1y/B9eHr3z4/V1WRx6dJnG4n8L2nUNNpUFVsW1bi/swd6HQ60pM+pn+4nYk5i3hQP4dKd6aIQFUh78JZ8gptVB7fh3Suk5fv+TyzmnOYjEJhYJxZu54h2dqNmOVEVWA0ItDp11Ey0kV43nxUg4Hs86exz6sn9A+fZ/yVkwyMx+n3qyyP9BItKcNw8DAGj4Mch4HUlJeK88dQRBFffilF04OksnOYWL2B9NFzEI1QEJhAj8JgTQtem4c8Q5qoMws5nuJE4WxK1rdQufNZ6va8CJEoQ/lVzDhzSLqycEyPI7gcxHPz0JsMxEQz5o4OLNlOlty+hNKty98MSUokRuD7T5GY28Jow1wGgzAWyfy5uMxQ5oRyN+RbYdtFCCRheQm05MN3T2Wqbwvsb30vAt98FMetl6MvyHnH93UoCC/1wI2N8ONz8OdLQH+JI5/J1m4i215FLMhBCUYwNtdgnt/4K7f70Wg0mt+1P5jg97ugBb/3lppMIftDyN4Aij+E7Asi+0Mo/hCq9MbkekFAEEV0NguCzYLOZn3jswWd1YwST2ZCYDiKEomhRGKo0TgqKkgyhrpy7FeueU97pCmRGL7/+jH6/GyiOw+S9TcfedvWLoA0MUPwgW0YSguQpn1Ikz6QZbL/9iMkz3WR6BrkdVsd1hOn6N5yHe1qFjc3yDTpQ7jDXtIdvUSf30tsbIbhgmr8n76PTfMcvNILVr3KitbdBH7wNPLoJGKOm37FSWomSH7ci2dxI6LbSfdwlPymEs588JNs3z3G1m3fpr+wFv26ZVxM2ti8/bvUDnfgdebgifhQAacUI2ZxYAl4kUUD+lSStMFIxGjDmo5jtJlJzGuh35LHhGxBUBXKPDpGQgLFXecoFOMoOj3BBQvp3nQthm9+H4Og4l+0BM+SJtrIRhYN5NggmIA5njRnh9OMKlZWlGbO2CUHxrDteIHhTVdhry6i9I2QV2h/axBTVPjBaVhaDC1vLOTuHchs326qzvx+MgJPtMOKqfPUpadxXPfObVYkBb5xDO6Zmwl9H5gNee/eC57wc3uIvbgfQ1UJ9us2YKyr+PUvJo1Go/kd04LfJWjB7/9OVVXkSS+prgFSFwdQAiEEkxExy4Uuy4XocSJmu9B5Mr8WDL+fh98VRWHmL76KPDmDviAXz99/DH2W622Pkya9BH/wNO5P3ILsDxF55lXiR8+T9/XPY6goQpVlDv7lA2R3dZCeP4fxgMy8EhGjXocux42Y5WJsxyGeyV9GWfc5nC4z0x/5ML5hHxVdZ5l/5nUMoSCp0WnS5aUcpoDGzpPkGGVcW1ZgXj6X59VKCsf7MJw+T7dXxe2b4OJVtzA0axHG/Qe59qX7cUX8jJfXY42ESFnMVPa0oqogoICgI2qycrRqGflJP0XTQ0Td2exr2UQsN59p0c7sKjtFbh08/QL2vl7iDfWU+0cImh20ZtfhmBojcNkm6tbU409kKmt9cfjovMwW9opSiEtwYgzCycxZu8q+8ywYPI3n3ptYWGd713N1qgqPtWWKLlaWZW6bicGj5zMrdYIAh4fh+CjcVh1n4D8e4eQN93Bzi57cdwh0L3RBzhtbxBUuWFT8zq8rh6N4//5bKKEI7s/dhWlWjdaHT6PR/MHSgt8laMHvV6cqCtLAWCbodQ2iJpKI+dkY6yow1pUjZrt/198i0Z0HSY9MYr96zbtu/f2iVGc/vi8/gOwP4bzzGuzXbnjHv/ClKV8m9H38ZgSbBf/XfkzyQi+uO6/GtmUlUjBMx9/eTzSaRq4oJyEameuIoUx5UfxBJF+Yid5pZixuSlc0UrKokviBM0wfaCXmcOH0TyHLoAsGGSyuoz27imW9x7Dkuhi87Eokg4loIE5WRysmVWJf9mzMSoo5g2fpy6+hIDRJ8WAnoiIRKChh2uyhzD9CyUAHCaMZg5wmabBwomoRo8U1FAfHcYZmEGsreW3ljejTacxjo1SNXcQ9PACpFCfrl9NZ0sx1F1/m9JpraPLIOB99jPKabFxWHUfdDRy2VqHk5vDVTQI/bYNVpXBiPBPcLq8BQVUIP/EKg16JF2dfQbFbJJwCFVhYlNnCNf/Cvwde6MqEuyvfqOtSVfjv43DrrMw4tp+ekciPTLPWOEnq6DmsV6wiUFbFExcybV42Vf28cGY0lBnFtqIkM3/3jrcv4GbOir6wj9BDz+K4aTOOD16lBT6NRvMHTwt+l6AFv19OVVUSx1qJ7T6KsbYMY30lhpqyN6dT/D6Q/SH833yEdP8o5vmNxI+cQ01JmOc2YLtyNaaWurec1ZMmvUSe3oU0NkXiXCe53/g8ppp3brz7Zui79yZ0bgfBHz1LengcNRon96t/SeTZ15m6fxsX5q3BNbcOoauHWeZIZvWztIBWJYtDewdY3XWAErtK+mIfSDIJsxVJ0GFPx1FdDpLjPnrmr6RHctAy2UGwsBThbz+FLdvBeFil58e7CMl6DpQtZl73Ua45vo2OkiYKgpM0DZ5Hr0oEnNnEjVZcapLc8QGSOgN6VWbGkcuezR/ELKUw+LxUKQHixcUczZtN9WQv5WkvfeY8ipfUUbKwkuPTJo4938q6iZMkbr+Z/lMD5J85jnL3LQRFK2On+8gNTuH0T9EizDCUtpCsqqIzp4ry2hzy7AKWZIzSHU+TmD2L5ML56ARon4aJKCwvhpSSWc2zGWFxUaZqdzwKtzZnwp80McOxI2OIE5Pkhaa4MJqiPk+koCYXQ0k++vKiN4t6VBUOj8CREbihAYqdmcB4YwM81ZFZLTT+Qi2GNDpJ4sQF4kfPoYQieP7qHoyV77IcqNFoNH9gtOB3CVrwe3eqqpI83U5s1xGMc+qwbVyKYDL+rr+tt1BVlfjrx4m+chCMBrL/7l50lsw+ohyOEt15gPjuY8hTPsTifMwr56FGYqiRGFhNxF85TN7//D/0Bdlvfr103wiG0gIEowFp2k/w/qew33gZsd3HiL1yCFSQgyGMVaWk+kaQ0zJD2WVkWQVmcopoSE4gjU6R8oWJ+OMogFNNYm6uRkkk0a9fxnlTCcLxs5S2nwR/EFlSmMopRpJVXIkwitnM+QXrOVXSQqS0gqwTRymcGSE/MIY74kcngFhdRvbFNlwzY4iShIQO2WBE1IE+EUcAVEEgYndxdOnVnM+qISfqoyU5ghqNkS4oRGmoIV5Ty+5oDguKwKyD8+MS1UdeRxeNErv2Kth9EHfET8nHrqEnIBJ79TDFk/3YUzFW/fPtPDFsxy3FKJvqY3m8D2ViBtViJh2IIN54BXJpCSk5c0YvrWQmZuzuz0zmyLFmAtlUNNNnb0s1LC2BkrNH8Z3uZp+rkaK6fLqMeXxwkQn7L7n8wkl4qh3OjaS5y9xHR0+QNcUybn0aaWyadNcA0sQMOpcDQ0keqgruj9+C6NYaL2s0mj8eWvC7BC34vZ2qqqTOdxF9+SDG+kqsm1e8GaZ+n0ijk4R+8iK6HDeyN4D73puJvXwQJRzDvLAZY3P1m6t8ya4BQj/aRvJsZ6aiOBJDzHZhu3odlsWzEHOz0GW5CD24HYBU9yDp3mFSA2PorGYERUUsL8Q0ux41mSR55iKCyQAOG8caV5O3oAb/4VZmPf84ksNBX0EtPc5S1LIiWrpPEHHn4OjvYbyhhem0nqrzx8hOh+kvrsPd340nMI3Xls1QSR3p2c3U48c8OIDS2UvWxDB6KU3cbGUmqwh7Koor4kdMJBBUhZTeSFpvxGTWoyigD4fQqQpJo5mpshp+fPVnqe5vJScdIlFWRt5gN7GbrmPhvFwK7Jmih+EgtE2DyzfJuuPP092wANeSWSzZt51kYSHX3beSwckUp/5rO53mIkqvX8Vy/STt332Brhvu4Ib5Jhp+YWddicRAEC65Ktzvh2c7waKHLh+ZXn0xlZKDuxGTSTrXXoHTJDC/MNOr75ftwCrhKIkzHXTv62TMl2assIq5rgQVMwPI3gCGsgKMLQ0YKorRGfUgiog5bq3xskaj+aOjBb9L0ILfWyXbe4m+uB9DZTG2y1f9Xm3n/oyaShN5bg/S2DSmRbNIHDqNdcNSoi8dwHbFKvRFeSROXiDZ1o0SCKHGkxiba7BdvpJU5yCxVw6C1UTiaCs6iwmdw4aqKKQ7BxE9DsTcLHDakQfG0Lnt6OxWlFgCNRJDsJqJD4yTzCtg6rrredHaTOn5YzTt2oYrFWXnbZ/hcPECVDJnzhqHWym5eAZDKsXxy27Ceeggm3b8AJ3NQlQREVNJkhWV/HDFh1jQtp/1y3IxjgwzeXYA09AQukQcn8WDN6uArHSE/LE+9KqCJOhImK2kDSbs4QAiKpLRhJhI4LVnoXM6iBgsdJc047N7OFCzipbYINfq+qn9zI1sHzCRkKDLCyfHoNqtsqDjALbBAUwf2IrDoqN421MML1zBjTfVc+DYJN3f3sH0+o14WqqodmdW1lZJg2zp2UPeZz/4fyreGQ3BV49kVv6uqFGpeOV5JowuTs9aw2g4c7av7O11Nm+SJmZInGon1dGLYDYRa2hku76ej9bHiDy4HVNjFebFsxCL8rSzexqN5k+GFvwuQQt+GaneIaLP7kEszMF+9dp3HUb/u5bqHiT85CvYNi1HsJqJPLcHMcuFYDLiuHlLph1MOEpsz3GSrd0Y6yvAbCZx4CSpzgEEqxnT3AZSF/tw3XkNmIwkj7cRfVC/mGIAAGJTSURBVPkg9lu2YFm9kNjLBwn9aBuG2nKs6xZjntuAsakab+cYJ7+6jWRSZmz+MsyDA1inJ8kdG0BesYhdm+/CahZJSpkK1wIlTPSL3yRUWMrhrCbmvvBTHIP9KB4PfbZ8rIkonpoi/qXuVub0n+Xa1udgeBRJ0NNX0oBPb2csu4SLs5YjSGlueun7RErLqOw5h4wOQyqBNR7BYjOCyUAslMDvysOeYyc2GeBE8ypeWH4rYZOdOzt2oMvN4kDTOvwJgbW5Mc63zqD6AiyTRqg+vR9LthNTeQFJWWA4YeTMoo3kVOUTOHSeorZTGO66gfa0k8sq4cgoXNuQqeBNnuskduAU7k/e+mvNRPbF4Yen4d6FYFQlTv3H0wwW1rD2toUU/5L/FX+22ivmejAvaMbYWEVK0PPfx+FD4kV0ew5mZuf+HhQcaTQazftNC36X8Kce/NLDE0S270bntGHfuh7R8/v7Z5AeGif8+E7c991Gun+U4I+eQXQ7cdyyBUNdBenuQWJ7TqBGYljWL8Y0twF5fJrwk68g5mdnGkWHIyRPd2BeNhfRYXvzHKDodpIeHkfxhTLTOr74cSxzm5AnvYRHfZx45hTGc+dxOwwUrGshnldI36FOjMEgrk/czKuOZlaWwhMXoNINBkFh9pf/CZ9qQrCaqR7uwBD0E6+vpy3poMIike8f40TcQcVUP1ZBJphdwLTRxWhOGV5XHkJ+Diey6rnq6NMsO/Yiol5EF4uBAILRiKQ3YCrKIRmKE1L07J1/BRXmJDPeOD0ty2kvnYMxFODetqcJrV7Fqax6pmIwu/c01rNn8BeWUxiaoDQ6iePOa8iaVYnFouflXvBYIJ2UsD3/EklV5Nq/2cxjHXpsxsyEjQ/MBrf55+9N/PBZUh19OO+57m0ra6oKCSkzNzeYyHwOJDLTOO5qgVwxSfA7T2BZs4BoYzPbLoLVAFvreduZPlVViR84ReJ4K64P3/CW6/XHZ1UWn91NkRTEeec172mfR41Go/lDogW/S/hTDX7SpJfI9t0A2K/f+K4jyX5fyP4Qge88jufTHyDZ2o3/G49gv3Y9ti0rSRw7T/JsJ4aqEixrFqIvyEEJRwlvew01Gsd+4yaS5zpJtnWjxpNYVs9HQCB28DSxvScwz6kDvR4xxwOyTHpoDCUYJRqXGMaBNwaFpiR1q+vJ/uiNTD7yEqdPTKCKIjO33kKv4MFlhhOjcPssWJydYORD/wCBEPaFjehPnkZJSQQdWXgjCrEN6zEODdHtU1neeYipwnJ06RQJo5W+ghqaRy7gCs6gyDKOeBBzKoFssxPFgCmdYDyvnKGGeTS2HcaUiDFZP5sRewELFhfxLeNihi25CMCCcC+f872O5yPX8cR0DrlGCfeLL+FTjIzMW4rpmefYdFkFJdevRNDp8Mfhy4dgPKxiDvi47PgOjpQvIm/VbMJJODeZCWOX12SaIOsECKd+HuYSrx4iORNicP3lBJOZ5sk/Y9Zn2rG4zeB843OxAxxSjMD/PIbtmnWYGqvefHyXN9OKZW4+rK3INHhWEklCD+1AzM/CvnX9W1YXD3fFUR5+hvnrarGsX6xt62o0mj9pWvC7hD+14Cf7Q0Se3Y0SjmG/fiOGkvzf9bf0SymJJIGvP4Ljg1cSfekAkZcO4P7ErUj9w6jJNJaV8zHNrUf2Bgk9+gKxvcdRQxH0FcWIbkemmrNvBENVCcY5dehMxkwhwKl2rBuWYl48C3lihnTPMKb1Szme18yJcYFQTGberqdY5OvE7DAj5mYRP9FGa1YNfdmV7Jm3hYocPVuqMqtXV1Qr+B5+HtP2F3G7DJgEGdUXQpYVpnJLGaloxF5VhPPRx7D4polYnSQrKjBNTmJERozH0ZkMdBQ34VXNzBk8iz0e5lz1QgonB8iLedl5+2fJHuxhyf5nSVrszMxfxGvNlzF/cyP/c1pPJAWFpjT3Ro6wST+KfPv1/LjTxGpXCP83HyXuyWLan8YtxVh/XRMmk57UdIDznUFO98axG2F+IRQWOXi6bA0f2JBL6yR8/Th8dgmIAoxHYDoGspoJcS7Tz0Od67XXsNqMZG9d/Zb2Ke9E9gYIfO9JnB+4CkNF0dvuV1U4OJTZVr7KPEb+zhewX7/xLQERYKRzkrb/3sGaT2/GUv/OLXk0Go3mT4kW/C7hTyX4qYpC5NnXkfpHsV27HmN16e/6W/qVqLKM/6sPgdFI/PBZBLMRY20Zplm1WNcsRMx2I41NEXluL9LYFLHXj2O7ajW2K1ajJpJEnnmN+PHzOG+/Cp3JkBknFwyj+EN4PnsnieOtpPtGsF62jFRjAz86K+AywUwgyZZ9j1OQCmBdsxDzshb8336cH+St40TciaUoly+ugxoPPHIiRe7uV8h+4QUcZgFjfz8oKkpuDslYirgqEndlEZRECqaHMKaSnKlfjK+inpbOo+hyPBxoXMOMLYeqY3upHOvEkYygCgIJqx2DDpxmiOYUkH/mBOhFzJXFPPfBv+CH4SrsBoiGk9SPX+Q26QJrCpLYFzcz2LiAfQdGWNF5AP/OI1yoWUBgwSKuzQuQPzGAXF/H6aSLnV4Xvbi5ao6ZPLvAUCgT8K6shVf6YCQE19fD/Ldns7e/X6pK+JHn0ZcXYV2z8Oe3pyXkGT/ylA9p0os85SM9MIrrIzdcstG2qqoEXjlK2/4eOjZdx9aFdvJ/YVZv6Egr+396guV/dxNZBVpLFo1GowEt+F3S+xX8khd6SHX0Ydu84n0vnFAiMYL3P41pUTPWVQve19f+TaQnvXj/8X9QojFQVMS8LJw3XIZ5+VwEnY700DjR5/eCKKKvKsH/b/djWbUQsTAHQYD0wDhKLE7Of34O0WJGVuClpy8g7DtCZVM+6uQM08tWMl5cw/kpgaPDChvOv0JT7ynKfMNY9RCtrCJkdmK92MHxkvlMO3OZW2Zg80I3gsnA3kePkXXoIBY5iSGVQEym8OYWoZckrIEZIjY3F6vmYg74aek9SdpgoKe4Hnsyhi0V49E1d5MSDNSMddGTU4FdirO+fTfFM8N0lM3GY5CpG7uIPhgkVFCMcfUiUvn5PN9yNUf70qz0dyCd7wBAbm5AbWliRYGE7eVXic2EkGMJ+v1w8sZ7uHaumcLXdhKTdHSs2Ex/UODcRGYbdW0lNGRDqQtshsws3COjsKYMzkzCx3+Ny0ZVVYLfeyITfuMJAIQ3ttLFvCz0+dmIeVmIuZ5LtlKR/SECP3wafX4O5mUteL0J9l1M4FESLM6Ko5ua4ci0kbJ7rqC5UGvJotFoND+jBb9LeL+Cn6qqpNp6iO06jC7Hje3yVe/Lubr00Dihh5/DcdsVv/erfEokRuJMB2KWi/jeEyQv9CB7gwgmA7bLV2G/bgM6sylTgfzCfnQOG7YrV5PqG2bm8/+FZeMyjNWlJE9cQFVk5IkZ8r73j8wYXZwZkej71jO0nHodi5JmNLsUZ5EbMZkg5ItjCgcoDU0QTSmoKRlBr2M6uxg1x4NrepzzTSsYdBYzL9xHydkjWAI+jLEogqKgFwV0Akg6kYQ7C53RSCyeZvc1H2G4poVNP/gXKka7CFvsdJTMJmm2UjI5gF6R8ER8jGSVcKJqIesu7KEgOM5YbhkRk4OW4XPIoo6g1cO3b/l/bL6wm4msIiawkjszitFm4kRuE/3FDTSXmzEIKrazZyi6cJpdcy5nfucRAo4sepesI8cgsWb/00yVVHOxfjGRNERS0JwHDTnwsxNxDmOmtUquDeYXwHdOwt0tmWKPX4eqKCjhKDqn/Vc+b6eqKtLIJIlTF4i9egRpZBLzshYMZYXorBYEqxmd1Ux/ysL+KTOuHBvmkjyua/j1vjeNRqP5Y6cFv0v4XWz1pvtHie48AALYLl/9jueb/i+URBKd+eeNluPHzhPffwrXvTchOu2XeObvljQ+TfTlg6R6hlD8YWRfAEN5EfEDp7FuXo7nU7cjZrtRVZXQg88i6EVsV67ONFx+7EX8//4Aumw3Yo4bQ0MVEb2Z5Iv7iNjdqKk0icIiTD29eNyZJsehsip8MzFCSZWEyUyBfwJXfy9JnZ5YUQm2inx6ymdhGh0l+9wpYnY3Jr+P7Mg0elkmjQ5ZpyNicxGrqCJruI+IaKa3pIGcmXH8jizOz17N+pPPU9LfjqCo+O0eTlUvoWKmHwGVgeJ6OqrmEhQs3PPad6ia6CFutBJ055DlnyRhtTNeUIkuniCUnY/bN8mZmiX0FNRiqSpm3FVAT1hPjgU2V0OdJUb2jh104eGMkEfFhVP0L15Nor4eJRxhxWtP0jpnFRNltZj1mWrZa2phQTFkWzKrfv/b/sHMtI2NVW+/772iptKkOvpInOlAGptCjSVQInFsV67Cum7xu64Iygq0TcHs/EyRiUaj0Wh+Tgt+l/C7POMnTcwQ3XkAxRfCumUFxqbqX7saUfYGiB89T6qtOzMpwe3EfsNG4ruPoiZSOG6/4vdyMsGbK6C7j6IqKmoiiT4/G8v6xXi/+F0SJ9qw37SJ7L/6EDq7FYDBZw5wdEgmWliCu7Od4r2v4urrIlhZy8U7P0akuoaax36Mxz+JddUCsqdHCR5tRe4ZQlddSkTSMbjhSgquWoF9cpSzD7xKefspUqKB7g/dy0c+uRBjOknsQi+n9vah7DlMwOImNTZNTiKANRHGGo9iMomEjXYs0SC6WIyI2cGEp4icqJeE3Ykoy9iDM3gtHmyJMO2lswnmF1Ez2E5fXjVjDXOJC3oWH36eBf0nkXUGAg4P9lgYSyrGZFYRZyoWoIgijnSMsepZdK3czHBMj0HMVMt641DlgSo3zJu5SP4LOxg1ZuFP6wk1NDHnytnUlVrJCU1jeGI7WXdf8+Zc219FKAn3n4bPLX17sFJVFSQZVZJBUUBVURUFFPUXfq9COp1pfB1LZD7Hf/5ZjSeQvUFUWcZYV4EST5LuHsSyYh6WlfN+L69ZjUaj+UOhBb9L+H0o7pBDEWKvHCLdP4pgMWEoL8JQWYKhsvjN0POWxwfDmRYm57rQOWxYlrW8OZ4s2d6L9++/hWlRM1l/cffvXS8zNZUmfugM8SPnMFQWoyaSSGPT6AtziR89hzQ0hi7bQ/73//HNHnzGTSvYPypif/pZqguMGEWVVEc/6ZEJ3B+5Ac8nbyM9NM7M33+LVM8gos2K4HbQa85HGRolXllFSd8FbAUezrlrSHT0I6ugtxjpmL2SyXUbifePsaj3GJaAl3goTsNgK8PVszlWOIfFHQeoG25Hl06h6HTIOhFnLERaNDBaWo/OpKdi4AJJQU9MNHGhqImA1c1lba8yXN6EQ4nTU1DLS/OupkbysmHnj6kYakenKIyX19NVOYeSjjNcrJrLk2vvZm7XUTace5mKiV7aW1bhb5zFjDsfW0U++txsTk0INORCsHOYG/f8GHM4yIn5G9hfvIgNi7K5bXam2jbZ0Uf02ddxffzmS/ZnlBSYiqhM9U0T6Bol3jtKcmyaOXngtvDmrDSBX0iAehFBL2ZSoaADnS7zMJ0OdAKCKIIoZsbdvfGhs5oRzKbMtq3FhGC1kDh6jsSJNiyr5mNZPlcLfBqNRvMe0ILfJfw+BL9fpMQSpAfHkAZGSQ+MokTiqKk0gsmImkohjUyis1kwzm3ANLsO0WlHMBsRrGaUQJjQI89jv2Uzqi9E9NUjWNcvwbx0zu9FX7PkuU4iL+zDOKuGdM8QiSPnQC+ixpOIOW6Mc+pJnmzDtmUl+rJCRLeT3rEY3v96mIKeDhxrFmBeNZ905wCJsxexb12PoayQ5PkuUl0DSMMTqNkeRgurORJ30zR8AU+xm7gkMOlP0yO4KZ/sY8+iK2kYvkBX9TyqfIPU9Z1FUnUMOQpIiwaKIpP8aOMnqe49y4f3/gBHIkzCaCFi95AwmvF4J4haHPg9uRROD+MJe5lx5NBWMgdHPMis0TZM6SR9+TXoRRjMKsOciNI4fhFLPAIqnK5ZwovX3svyU69Q2XWW/dd8iJxcG562MzScO4Qz4iOybj2BcBrfTIy4LIAAqqJicVvJSYWpiY1xes1WOlpW0ekT+Nd10JCb+bOOHzpD4uQFXPfehM5sQlUzvfbGwzAZkvEOzJDsG8EyNorZ78VmBEtRDq66ErIbivCU56ITf/UpHL8qVVWRBseJHz2XqaZevQDzspZMUNRoNBrNe0ILfpfw+xT81LSENDKZCX5D40hTXgB0Lgc6iwlMRgzlRZlttngSJZF8Y9ssiRJPIghgv2kzojvT1kJNpYm+cohURx/2Gy77nRV3qJJE+OlXSV7sI3HoLIo3gGAzI9ismJpqMC9ozMzXPdaKvrYUJIXw6U7GT/ZgCgZwiDKm+Q2kh6eIiGYErw+hIBfDvCaMOW44fpL0yXZmymo5VDQXcyjAnLEL+Eqq8Oos5HW0MpZdQv1wGxYkRKMeV2k2JqeFvpq5bMtfhlJazHXNItHP/ye6RBzX8ACe6TF0eh2jjgIkgwFLNExaNOCzZxE32dArMo54iK6SRooiUxTNjOCI+InYXCTNVhK5+Zy2V9Aw2UVx0suUZCRutNB27R1cZx3B/NBjjBo8jNS14ClwEJ/0U2JM0VRixHfZZfwoVk2PD0YDCg34ME1OMFeZoCg0Sdpkxn/55YzKVgIJ+OKanxdgRF7YT1+vn97LrmYqoqAbn8I6PkrO9AieWAC7ScBRlounvhhzdQlifvZv9R8GqqoiDY2TOHaedO8I+vJCzEvmYKgq+b34B4lGo9H8sdGC3yX8LoOfmkyR6hok1d5LenAMQa9HX5qPobwos+L1Hv2FLAfDRJ5+FTUt4bjxssyEinf6flLpzBberzFv9ZeRZvz4//NB0sPjpDr6MDZWY1k1D9umFRiqS9/874sfPU/wR9sw1FcwGDMy3T9D7dbF6AeH6Klq4XxOPZ6ZERZ/7Z/AYSdeX0/C7iIcTuLY/TrHWzbgU82oThtFMyMcWX4NY/ZcLnv5Yaavuppg7zh1vn6We6KUyCE6ElaGplJYSnJY0WjDFphh6Im9OHQyMzEFq3eaUHY+x2qXcdm5V0inFcK5BcRlHW6HASk/F2FolJ7SJryiDcf0JBXTvfgcObhSEXw6K6Uzg7hjASRRjymVIOLJwV1fhhiPMjMZ4fkrP8rF+sV8bLaE+Ph2PBW55IwNsKdhLbr6asYjcHYisxW7pBg+swSOj8HRETDpM9W3Q0G4Zy7k299oo/LkKxxpDeG06ylIBnBYRSzl+RgqijMNrXM971vYSg9PZMJe9xD60gLMS+e85T3XaDQazW+HFvwu4f0MfmpaItUzlAl6/SMIooihrgJTUzWqqCPd0YdpVg1iUd5v5S/H9NA4kW2vIWa7sW1d92alr6qqxPedJH7oDAgC9itXY2qpf/M+eSaAPDaFNDGDYDFhnt/0jmcP3/zvVFXSfSOEn3qF+IFTmccqKpYNS7Eua3nza8MbofSpXUSe34ty5020vd5Jcs4sEksWkT5wghzilN28ltILJ/D/zdcxz2/E/Rd348sv5eCz56n86lc4uuxqBK8XjwVSTjcXV17OGTmLtSdfRF9VytPJcj63+5uoJQUcWnQlwxEdi/ydzAn2YUwlMExOoQyNYpSSiLEYBjnNdFYR2YFJnLEgU+58As4cEgYL/fVziRjtLDvzGt3ZFfTnVFHt7aNutJ1pZx6OdIyozoxFSRPKyUefiGFKp3j9uo9RPdKBc3iA9tLZWDcsxuMwcJ17mtHvbye9ZgXCviP4LtuEubaMHl/maN1gED48N7NNe2QErqqF6izwx+HxC3BNHVR6Mu1Tpr7yEBdPj5G9rJmGW1f91lfy3vG9T6WJHzlH4ug59IW5mbBXW66FPY1Go3kfacHvEt7PBs6xlw9hqC7FOKsGQ0URgl6PqqrEdh4g1TWIZeV8Uu29SGNTiHnZmFrqMDbXvKVFy/+VqijI49OZ6sm+EWKvH0VfUoC+rIDE3lMYmioxL5qFPB0guvMA0tA4xvpKdG47YrYbfVFeZgZuJEbi5AVQVUwLmjIh0GZBlWXSfSMkz3SQ6hpEnvEj5mWRHp5A57Dh+uDVJI6dx/XRGwFQwlEiL+xDHpsm7I9yNmxlSLETueZKmsostEQGcOzZi7GxktiR8yQPnML2sZvxXn4Fu3ph/Fgn13z7/3F88WYq+y6gNNbRO28FibYeTIkYs9wpAr4Ez5Sv4W+Pfp/cWy8jdOUVHPv2TmZPXCRrXg3TMYX4ywfxJQScIR9u7xgmOY1sMmOMR4maHQSd2aSNJkzRMJIgYpQlUgYj50pbMCDTMnyedFYWg1llJPUmjtYsY3nfUcQsFzkXW4nMno3X7IGJKc7NXoVtVjUVHgFfAkp6L1DddpSuFZuo27cT+wev5rBayLKSzEi07Rfhq5dl5uICxNLwXCcEkpni2dXlMCsPku29TH3xu7Q7yqn4mzuor3X/xtfLr0uamCG2+xjSyATmZS1Ylrb83hUWaTQazZ8KLfhdwu9yq1ee8RN86FnM85reNlhempjJFC2092ZaXtRXYmysQsxyoXPafulheNkXJN0zRKpvBGlkAgB9YR46mwUEATkcJbpzP/K4F8uKeZgXNyPo9ehcDvRFeWDUE31+H4JOwH7jJkSPE1WSkKd8KJEY6Wkf8b0nSR5vRfaH0Fkt6Ity0bmdpEcnMVSXkDrVjnFuA9aNywj/eAfGljpISSR7hoiE0ww0L2J8LExWRyunb/kw5vnNmPVgG+qn6YHvMFVRx1hhFRUHX+NcyxoOzd+CPwHZo/38w0Ofo7+4Douc5NSSyxFGJ5AcDqQ1KzAKMqVP/pSQyc6KiXOQ7SFlsRLtn8CciDCeV07OxBCqCmcr5mNQJDaceQmdTgCLCTWRYrSgismSGpRgGGskyHhuGaON8/BmFyIEQ1w1cZRc3zhny+cRGA1wsHkdMbuLO04+jiURw5SK4SsqZyauo3fhKpauq+aWWZlq22hSYftXXiXgjXGgZiU3ntlO5+brKKnN5fKaTLjbOwjfvTLz+P9tLAzBuErVdC/R5/cSvtDP60u3suGDiylzvZdX6KWpikLyXCfxvScQrBasG5f+3jcJ12g0mj8FWvC7hN9F8FNVlcShM8QPncF519ZLziqFN5rcdg6Q6uxHCYRRQlFURX7zfp3FjM7lQLCakUYmUZMpRLcDQ3Uphpoy9CX5b57bU1Npoi8dINUziOPmLehLC0gcPkts7wksK+djWfHWlhqJc50E738aNZVGX5YPaRnZGwBZQSzOw9RUjaG6FGl4gtTF/sw5QYuR6FOvoi/OQ8zNIj48SSKvkGlPPl0RI231S9B5nMx99Rlqe85gvWkLlmSMRFIhEE6j7+ik65a78JDEffAAg84inp91BU0dx1h/8GkKBy4SsziIWx1MOfPw1TWx6r7LyGoq4/5TKnXf+SZJ0cDsxaXYjhxlqKyB+5nD/I6DVE32YJKSDJXUE7E4md15lPLxHhSdDkFVidmcXCxsRNLpqZ7oYtpdwEBZE5LHg19vI2yyU6cPkTXYS0BnoaNsNqVblxEZnuGqF75PsnOQMVMW/UV1HJt3GXVLKvnr5ZBty7S5Ozqs0vn1p7noqqBmcTlzdj/L0Q03kV/mYc9AZmu3xAFfWgvF73A5qmmJxKl24vtPIhbm4u8c5eWmy7h+a/WbK4O/TT+ryk2cbid1sR/TnDqsaxddcutfo9FoNO8vLfhdwvsd/JRwlOCPd2Aozsd2zdrfuI2FqqqZqt5gGCWWQF+Ul6kA/l8kb4DIjteR+kexXbUG8+LZb1lhVCWJ2N6TJE5dAElGMBlQ0xKix4WxpR7SEv6vPIj18hU4P3AVosvx5nPlUIT4yXZ8R9qIne1CGpui7drbaZuzBsPQMCUXz3Jg1fUEvVEWdh1BTqXJHezBpZc4dcXthLLzcZrAY1Kp3PksXbnVGHt6UWIJavvPYRNkCsOTWP0zkJYYzi4lbXcyPXchq6tEnDPjRLtGmOmbQoxGCWfnEZ/VjNQ3ykNLPkD2+BBXnXuegpgXk8VIoKae1qiVOb0nsMVCmFJJwjYn2xfewImqRdy1/yFiRis/vuwTRLJyybKAUYTmLJkbTz6NfLoNX2klvpommgQv0ye7KTl9hLgEz628ndHFK5i25dCYA7IKIyGYikJMgrU9ByixqRQtrsHx/IvYP34rU6KDHh/EJVhWkhmb1jYNSQlqsmB2HhQGx0nsP4E0No1pfhPG5hqGvvcsu1q2cPuVpbjMv9FldOlrLC2Rau8lcbodedKLvrwI84ImDDVl72khkEaj0WjeG1rwu4T3K/gFEhA904n59X04b7sCQ1XJr/V8VZZJtfehs1vRZTkvOQNVVVXksanMKmHXAOn+UdL9o5miCkHA1FKPbdNyBMNbm+VGdx0meaYDMS87M+t2JoDodmJePIv4wTNYVswlMTKD73wvE3MWEj54DsPZ84iRMIoiYI5HEAWVU3d/CvPaJZQbE+T98IfsW3I1iQt9FB/dT8iTg0sn4ZbjKLE4/WWNjOaWM11QTs5QLw1j7VT2tqKPhnEHpokVFdOTW0NR53kSBiNxk40Di6/i3NIt1CbGWdy6l8LedgJxhYTBgpiIY4+HKZ4aQAXiFgfWigI8QoJIXOHVujXY+vtZ0nUYfSqJXpWYdBfytS1/SVbMx/XHn+Gp1R/kfP0yLAaBSg8UJX18PHoU64sv02crxOU0Ml5YxYCzGN/5XpafeoWx4lqeuOGzNFVaMelgcTGMReDoaOY83vpKmO3vJfTSAaas2RR1nGH6uhsYdRVhz3fhT+m4bRZvWbVLRRL0v3qO6UNtTNpyiC9eRO3cIhoNYca++Rj7lm3ljs35WH8LR+mUSIzk2YskznSgJlMYG6sxz29EX5j73r+YRqPRaN5TWvC7hPcr+B3Y1cOZV9o5teRynA4D1R6YVwh1WVBgf3M4wjuSpv2EHtyOsa4ic87OF0IJR+CNt0IwGhE9TgS7BWlwPLMtW5iLmJ9NsrUbfa4H+3Ub0DlsmW3mN7Z2bZctw7RoFoIgkOrsJ/b6MVwfv4WULDAShm4vjA34yH3wIZRQBFJp3N5xHCE/BilJOj8fZdVSXMkwVosB/fJ5jBzvQTrTTjSdKUbocZVhllMYpBSTGzahlBaj1+uYvf9Fpu+6i7A/hq9tgNoTe2g4vgd9Ms5kSTV+WxY7F19Hy3gbaw89y0xRGQg6xuvnUBefwDM+hOqwczK/mTMRGx7/FJLBQPPERexSgpTeQFvNQrKjPuxBL9aQj0F3KWXTA+SGJtEpKoKqYJRStJbOwUKaPAuk/vxjBFw5+BQTWW3n8B/vQJFkjH4f5xuWsjTWx+61txDrGGD16Zdpig7RVjKbVzfdydpaAw05mRm3spJpt7K1VqHCP0Ti4GnCz7yKmp3FsDWX3gWraTZFGOrzUiKHqHIpIIDO6UCfl4U06UWJRLEsm4t5YTOC0UAgAW0TCtH/fhjvlsu5cUM+xveo77ESS5Dq7Cd1oRdpdBLBZsE8twHT3AZtG1ej0Wj+wGjB7xLer+DX3e3ntX6I2D3oBAgnYSYG0XTm7JfDCHpdJizk28BtBqcZPG3nMR07Tvym6yA3G1GXeZxeAFEHogB6KYkYCCHE4xhL89Ghory0B2F6Buv1mzCW5qP2DhJ7+hVSeiMJCeJJhfTQOIlAjIG6uVg7L3Jo3U1MufNJOty4LQKFdpjXtp+i4Die8SGM0Qiu27ZgWLWQ8ZCK/79/inLwBAc338aIMZvKc0c5vv4GYjYns8/sYf7eZ0mUlnG2eSWTLYsx9PWx9OyrFPV2cK5lNX1ZFXgMErO6TlLYdpIJRx57llxDjXeA8vgkDe3H0EUi+LMLSJuseHOKEEWBnmUbaTfk4zl9koLJAWS9AZ3RgCsR4nRhM6XeEbat/iCi3cqasy+zee9PscXDoNNhScYQZYmg1YUtEaW9bDZ1+jBZQooxWw6RcIpc/ziSTs9EeT3WxgrkeIojrnqypkd4uW4D60+9iN1jobi3jY7CRsR7buETS0QODMIjrbAhN8KWeCfWrk7UWAJDRRH+Fw8yYsrm4vqrWb6hhrEItE7BrbMy7zdkVmqVYAR52pcpssnLett1FNm+G53HiXXtot/oenxbD0mzCWNDBcamGvTFv512QhqNRqN5f2jB7xLer+AX2nuS8DceRl+YQ6KklGBKIBBXiaQgrQh44wIxVUTQCaRVASSF0sEOJIOR8fJ6dAY9uOwMLFhJymIHIbOFKAiZhT9FzVRZll84QUnneToWrGG8rA6DKjH7xOvYIwHa112NZLaQVkAvqLikGJXRMRb88JuYywtwrVuAKZlA8QcASI9OkugYIGF3Eiws5cINd+E3OrCdOUPuuZN0NS9lOL+KNU/8D/pYhFfu+mucOQ7EdJJlD34D1WplwuTGlWXBXJKH3DfMuMnDmDmbnGSQ2o7jJGQo6buAySDiNqlY68pIjU6RGBxn1F1MW/lcbMkoqqrSN3sJMdHM7LaDSNEkUZONWG0NanEhxrOtlPqHqZVmmHAX0WUrRheNsOjs67gjXuJGK6ZUAkEUaa9soWSsF5NBQNdQTZe5AF0sTo1/AEdxNm0Vcylb34LhocfwDUyhi8cx2sx0lM7CGo9waPYGLj++neGWxVR+7BrSqsBjbZBOSXzYfwjHQB+BWS34q+oIxhXKv/zvDBXXMnb7HWQ5DShq5tzehspLr/T+b8mOPuL7TuC692YEQcgExUAYecaPPO3PfJ4JoPhDqKry8yf+rxcREEAvYqgpw9Rcg76sQDurp9FoNH9EtOB3Ce9X8Ov0ws4ulfxTR8np7UC56WryqnIpcoDbBJ98SaXSKmWC3MgY1a+/SNe81YwUVBGXQJYUHL5pWs7uxZtTxGD1LBypGM6wD0/Ehzviw5WOoi6cg2vdIoo8Iu7gNDy+A5YvIja3hWgqs5KYbwfzG8f7ws/tJWUwkawoJ7rtNaIxCb/Bhtw3THbrGVJ5+ehys4g73MTGvYiT00y58jlbtRBfVgHzR86h1FRiqSmh9OBuWgubqD2wi2FzDrvW3UpFeIJ1rz9O/tQgE5WNqHYrBpeDyZpmkkmJNU/8D4LRgKITmSmtwt3XjRgMMVJWh6++mbqLJwlGZcLZ+ZRN9GLxzRAVTcgmM1JpMcM55QQMdurSU8TSAqNpE0OOQjyhaTafegFrOk7aYCSRX8CMaOdA3Urm9Z4iP+FjKr+cgsgUufo09uoihpeu4UHzPNxjw2x47oe0NS3FajWQcHowd3ZSPDWAzmnDqqY5MmcD8hUbGQmBPwFLkwNsbH0V3fIFGJfPQ1KgY8dJ3C+9SOGq2TR89obf6PpRwlH8//0TXJ+8jcTe46Q6+0EvIrqdiDkexBx35nOuB53HqQU5jUaj+ROmBb9LeL+C34lReLEb7EawhP1U7nqeSFEpI4tX0R3S0+2F6xtUKs8dIWuoF+ED1+HOteMyg907RXTbq8TPdqIU5KPEEqRHJolW1zC5cCleRy6TliymZBOTEfDFVGrajlIy1MnuZdeScrqx6sFh4s0KVRVwjwxQfvYIF666FXfUR9Ou7Qh6kdy2s1inxglZXYy7iwjrLSiiHsliYbxxLixZQFVgmOJnnmDCmUtcMCIAMUFP85n9mGNh/K48VL0eayLMcGkDWEzUDLZh882gCmCU0hiTcdJmK6IokLY7SKk6BhyFDOeUMlnWwNWvPMCEs4CelhVkTwzinhimp6CWkcuuQpDSFBw9yKyB0+R6R1HQETQ76c+rwhULMGvoPNZ0nLa6xYy5CvAEpijyjZKwOckJTaK4PaTz8xhefzmtDUs5OiFSNNbLrdGzNJqjvDpnE8uef5jx8Sgpg4n0TVtJzZnFVEAiPe1Hzc+l3A1TEzFWnXmZKg8oW7fQm7LS2z6J44Wd1JTbyRWTuP/s9t9o61RVVfxffxh9QS7p/hFsW1Zgmt+kbcdqNBqN5h1pwe8S3q/gNxGBY8MKU74UiWiKZCRBXvtZCjrO8XLxMtwFTlpGW/EXldEzdyXi1DSlx/ZR2nYCZIX+xgWMNsyjtK8NZzJMeNYcHHqZ4o6zsGwBtpXzcFt1JL1h5MeeJVpWwcyylQQSAhNRmIxkpkH44pnZryVijE2vPUrHjXeQO9xD0cnDhO1ZGLq6UMNRhj2lvLLmVoy5HspdkGsHYyqJrfU85ScPYIkEOX/bxwgWFDMdkLAcPMys7hPoYzH23vVZ5lgjbHjhR7hNKoZgACkcQxocQ5eThbE4l3jnIEmDiajeTFS08MLi6+l2lHDnq99HsZjIG+zhp+vvoa98FpsOPIk1FsbnzOGlhdcy4cjDHfYxf/AUd+z7EYZ0ikl3PlGLgxI1QvbkMIZUkkBeERFFj06VcSlJzMvn4lFipE62Er7xBh5ruZ6LERMJCVaWQZkzszLbPRBmzetP0O0uY350AHdtMXnBSYZWbaTfU0ZzHszNU9nxyBmaO48z09RCKimRPTlErpAku9hN/mXziTy1C8/n7nrH9jq/KlVVCXz7pyTPdeK87Uosaxf+xi2ANBqNRvPHTQt+l/C+Bb9vPUbssZcwleZjKMpDl+Mh6XbTnbQytreV4jI3gwuW4xjsJ+fMCcRkAn99M8NrN6GUFGMQQSdAPA2RUBLLufN4Olrxm50kBAOuyREC9izssTCHV1xDPL8AhyEz+eFnH5Y32n4k0yrZDz9Mu7uKWa0HyfFPoKgqg7lVzB46h+J2EliwmPHKRo65ahlSHSTlzDnCupF25nceYbK8jrLBDoaseRRMDmAWFJIFhYTWrWP2hUOUv/oCqeIi1KZ6rHoF6+Hj9N/9YZ7zLGDRT/6HnMlhQhYXIVlkV+NG5vceZ/OZF5lyF5DUm/jupk9jj4eYM3AaSzrJaH4lflcudlLM93WREg20HHqRwZxyTs/byCLvRYr7LmCcnsIR9SOqCnGDlZjZhkVKEHFmEbc6cIQDdNXNJ5WVjdsMXr0De3EW0yYPZ2UP0ZjEmrO7cGRbKS9zcKx5NQM+FVcixGXd+3BIMbpza3AcOECO24C7uYLcxhJcjWUYasoy4+vSEv6vP4zzjqvfsf2J7AsimE3orJduvpfqHiT4wDbkQIjcr38e0fJbbNan0Wg0mj8aWvC7hPcr+A1Op2jd2Yrx4BHMU5NIdgd6s5Hp8RC5YpKCZAA5EkNorsd4x1asc+uxGnUYdJkmwNF0pqGvUcx8GHRgSCdJ7z5M9NndyN4AYn42OqsZy/qlpNevZCZlYCYGQ0Ho9kG4Z4yszjbqzh/GoEjo0ymiFgeTpTXYIwEaLhyls24B+xrWYImEqJ7qodY/SK4UxuU2Y3MaCcVVWvObmJ4MU9l1jrLgKDZRxlRTii7Pg9TRj2DQY7z3ds4Zikg+8izeqMI3Nn+G7MAU9+76FqUzQ5ysXkyBf5xRTzGzxtvJivkYzavkYlYVPfk1rO7cj6zXYzfr6J63EpdvmvPFs7CMDDOcU8rfPPb3TLoKCHpyKfSOYosEMCWjGGQJRdTTdfe9rL9tPsn+MaJxmeC6dcT2HCcUTtPVsoJzE+CNqdQbI9hCfrwjPlpGz7Ow5zhmUcFfUcuIp5QCj57KXAMDcQNn/QbsiShF00PMu2c9noV1P5+IkpaQhidI94+SON2OdeNSzPMa33INqKpKbNdhkq1doKjoywqxbV6B6HnrdSd7A4QeewnBZEQaniDrr+/RWqpoNBqN5lemBb9LeL+CX98jrxJ+5HlEUUBEQYxFkYNRgkkBR44Nn2pGUGQkvQFklaTRTNCdS8iVkznEL4AeFdVsQicK2GJhMBlIFJegOBxYpyfIO3MCQVERQkH0wSBBVy5hdzay3U522IteTuPNLcHZ3Yk9EiBhd5DQmXDMjKPoRCSLBaPNjEdIYtWD6LCgutz4kpDsHMAY8COLeiSDCYeaxF6eD+tXEBUMpJ/bjdg/SFonMuouRpQlZJ0Ovz2LyZxSGsbaKZvoRxFF9s9aT9HUMP78YrLUJGdK5yD7Qjy29FaWDp9hec9hhjdsobj9DIOih0Wte/A6c6nx9pMTmsLm9zKeW4oBFXfYS1oFczKO7HJiaa6h+D8/g7m2FDkQJvjdJ3D85YdoH0ww863HOLD1HnoCOnJsUOwAqx76Ayr3tG1Dd/QkXU2L8a7fyLImB6XOzMzc9mlYWgxLSt44H5lKkx4eJ90/htQ/guwPIej16EvzMVSWYKgsRsx2v+X9l6Z8hB7egamlAevGpW/2Toy+fAid1YztqjXoC3ORJmYIPrAN54euJbLtNWybV2CsLf+tXZcajUaj+eOjBb9LeL+Cn3//GXwvHCJlsZJUIJ2UaR9NY09GcEX9COEoFpsBMZlAH48jGQzINjuqKJI2mIlm5aLIMsZgAFWS0SfiGBJxJL0BVRUwJqOEsgsQLCYcgoQzGcEW8iKFYsRSMJFViKQ3UdN9CkXQkbDYMcppjKqMYDSQrqwgPnsWSbOV2KiXxOgk+qkZDKkktkSEkNVFV3EjBYExbIko41kl+J3ZLOg6gj0awpyOE7C4iJltlHqHEVUFdAKCCjpFQkVgNLeUgaIGGmd6cCWCBIx2TpbOpXqgHdlmx22QMBXk0JlbhbHtIimzhcKElzwxibm7G1WSCVtchF1Z5IenCZts6GNRLFYjWTdtwO/IYd/SraQNJppyIfsnP+XcvLWcMxRR/eIznK9ZBGVFXGkZZ1FOiuGpFF3nRpn10lPETTaELWupzjcgx5JcHE0SCyepdCoU2IFfqKMQRBF9aQGGymIMFcWZKtpLTFGJ7T5G8kwHzjuvQZ+f/bbHSKOTRF48gDzlRfaHyP7CR0mcuYgaT2C/eu1v5XrUaDQazR8vLfhdwvtW1fvlp5Ceeomk3UUkt4Bwdh4XE3aqckT8/gRlljQGOYUxEUNMpzGGglgnxzD7ZtAnk6AqqIBqMJFwOUlbHYjJFGIqQcJqQ5eWMYaDqCqkBZG43giKiivqR68oGOQUqiCQcHlQHA4M8RiS0YQvp4iYaCKuN+ET7UQVHROuAjqLm1Bsdq64sJPAshWUlDrJ8o4zY3IR7xuj4sVnyPJOkNbpMYoQLS3DGfIRLy9HjEWxDA+hi0YzYc2RxUR1E/mpIAavF10oxJQzD0s8QlZoBtlqRc3OZsSaQ1dWFTXT3RT7xrBHA+ijEXRSmojZwUxRBYXeUcJWF153HqUj3Uysvwy/aqKjah7eOfPJt2cmjohnW3FOjbF37mbq/AOsnjhD1sZFNB/YiVxZxsFByL1wmoqxHkY3bCYybz6qyYhsNGGyGllVZ6Kq0PQbFVLI3gChh3ZgbKrCumXlJatw08MTBO9/Cn15EYoviKDX4/7sHVrlrkaj0Wh+bVrwu4T3K/idfXQ/vkdeREil0CfixFMKst5I1OEmWVSMWQ/GZAxjIo7NO4U5HESnKMhmE5LFhphKYAr4EZIpTNEwgiwjGwwIOh2CoiDIMpIoAgI6VUGU0qiCgGIykTJZiTjd6JIp9PEYBlkikFPIxQVrmLZmYZ4Yp6rrLNUTXehVGdVoQBX1qEDUnUNap0eUJPTxKIZ4jLRoQDAZscdDqDoRVdQRmT0LYXYT5ld2E05BTGcCAUKllbStuZpkMMqSAzuouXiSqNVB3Gihv7CWbCHJpLuAEXMOpdEp5p3bg0lKkrZYIRpDJyv484tJubNIp2R8rjwqhtpJ2Jx0XHEjBcPdtK67lh59DpE0uEyQo8ZY9fKjKH92D3XZAvLXf8iYswCfL87hORvJO3ucpYEualwyWX95N86S7Pds9BlkVvni+06SON76rgUevyg9MEb4sZdw3XcrotOOEksg6EUE429hCK9Go9Fo/uhpwe8S3rdZvSdnaN15nqzgDK6pURK9w5RGpzCGgpjSCVRBQFUz/fVkg4GU0YJktoKqYI5GSJvNSEYz9sAMOkVGUFWEVAqdIpMwWojbHOgMBtI2O6ZwkIvzV9NdPRfX9AQVfecpGbiIMRElYbAQNtnQyRLWZAxFEFBFPUa9QLipCcVoJqvtHMZ4FL0iIQBR0YTP4iZlspCrRHGGvOgUBaWogLTNzpQjD6V/GHPAy3BRLbkxHxeuvIkLC9aTnA5w5c4fUXnxFKqq8uANn+dYUQtZUpRP/fQfsISDmJEwp+LYYyFSegMDhbVUDbYjqAqSTiTiyMIkJdHLaaZsueyftR5FbyBusnJ43iacNj2F9kxzaocJlr/+NNKyhVBdQf6ObdhbW0nefiMmOYX52AlcZbmYkjFcH73xN2q18k7kGT+hR1/AUF2K7crVv7SRcqp3iMiTu3D/2e1aAYdGo9Fo3hNa8LuE9yv4jX13G5EfPIkiK0QUEUkGYyqOQRRQZAVFUhB0oAo6JIORlN6IORFDEXRM5ZSgU2T0cpqLVXMJGh2YoyHG8yvID01SMdlHtn8CvZLG5Z8m7MrGkE5iSiYQZAkVUHQipnSCpNODyWUBRSWdTGOPhRHNRtTcLJQVSzBOTGAoyKIjbGaXvgqvYGHN4FHmXjyKPhJGSaWRDUaSFjsxs42A0Y4oS2C3Yasvw5ljZ+ea24gNTVHVehRDwE9uZxt+o4MZezbF3mGMyQSumJ+EzclARRPVAxdwhH2kRAMjueU0954BUUfU4Wb/3M0oSYnqkXYmXIVUxiepmBnAumo+ntoiEN4YQWYyoDMZQRAQjAYct2wh+NAOYnuOkf0vf0b4py+BrCBYTJjqK7Fdu/493UZVFeWNit1unB+4En1R3i99TqprgMizr+O+7zZ0Nst79r1oNBqN5k+bFvwu4f0KfuEdrxN5+lUwGeg+PYrNAO5CF1JDLScK53DF+iIcLjPp4Qkiz+0leeoC+uJc0OmI9Y0TtrqIJ2UsiQi28gIsokrS5SFitpEYmiLtDaJE4yTNVrLGB7FGw8g2K3qDHlNZPmIsRrKslHDPGMm4hCUdQyeKBKtr8eeWMFZSg35qinZDPkOWPIpj06wIdVEZGsHS30ccPYZ4jJjDzfHKxQxmlZEvh1jefQibHtKSwoi9gIBowahKDJY0MOXI5Yrdj3Bi3gaMMzNsOP0ikt6AZDTjiPjBkClmiZhsWFJxFEGHQU4TtruZyCphqLQeo6BQ5e2n97rbuHVtFjz7Co6bNxN5eheWdYuxLJmDqqqoyVTmI55EnvET2bEHecYPioI0Mol5xXxsG5dibKpC0Ovf0/c2PTRO+LGXMC+chWX94l8pUCYv9BDdeQD3fbe/56uOGo1Go/nTpgW/S3i/gp/vm48S/N6TJGSV0dxyhFuuwb6wmbaIidsXmjEKCvETrcRfO4qxqZqgzc30U68jjU5iiYYwxqOoBmMm3CgqqqogKCopo4mkMbOCl3A4EfQGQkXFRMxOrNMT+IwuzL5pptyF7F14JUt7jzKv6wg6QSDsycGgKliSUYRkEjktkyUkyTap6HSQTisEEyo9/7+9O4+Oqzrzvf89NU+qKs2zZEke5AHLsyxssMEGm3lqAgGC03CTkJgkhHTfkHuTQLrfjiGh+6UhBEjnbUh3ICamscEkBsxgA8bGeJAnWbIlS5YsqTSrqlTzsN8/FJQ4NmCBbVno+axVa7nO2aV6/FStVb91ztn7FE1F8/nZPuVCYpMnccklxSw4up2uLft5L2cm/fXtONpaSOtowTHQT1RnIK3PQ0rIjyd7HAFlxJyM0JY/nq233s35zz9BfkMNxkQUo0lPatjHofEz6bGnM65uF3GTmXhODplfvoTglj3EbrqOpVPMeH/1e9zfuhl9qhMVj+Nf8zoqFMF525VoJiNKKXzPvETkQD3xjh7iR47hXnkzKV9adtrDHgwu6zKw7k3inb04b73yhPX4Pk54+z5C7+7CffeX0cym016XEEKIsW1UBb/HH3+cX/ziF3g8HioqKnjssceYN2/ex45fs2YNP/7xj2lqamLChAk89NBDXH755af8fmcr+O1ft4Omt/axy2umqL+diS370EciGE0GlAYoiOkMeJzZmH39WMMBghY7AauTblcWxkQUSyTEobxy6nMn4cko4Lz2GiqadpNIdWG2GEkJ+2krmsShSXPotaViG+jn79Y9RkF/K3abAS0URmc2YZ47DZ3diu9oF90NHvzOdHKD3dj6utFnuOnNzKfxSD/H7Fkk9Qbm1n+AZVweOROyGNhzGH9HP33GFFpyx9NYPJleUwqF3S0kbVZSTUmqNr9IIhjGr7OgNI30YB8Jo5HGtGIyfZ24Qz602eeReuvl7NjUwOMlVxN0uLjz9V8ypfcI2bNKSeblcGxjNWklWYz79vUEX3kH54qrTziFGt5Ti/+5P2EoKSD4xlb06W50bgeJZg+ZT/wEgztlaGysuZ3AHzeTHAgNLc+iGQzoUuzoHDY0hw2dw4Y+w42pvOQTZ/RGahoYWPsm9mULsMyZekrfgXhHD/7VGzDkZuC4dolM3hBCCHFGjJrg9/zzz3P77bfz5JNPUllZySOPPMKaNWuoq6sjK+vEa6bef/99LrzwQlatWsWVV17Jc889x0MPPcSuXbuYNm3aKb3n2Qp+z75yjKO/f4vsYw2kh/pIGE3ELVZM4SDGSIhjaYXEjWZmHNlBwmKls6AUvysdn81NR34ZA2XjceamMtEZJ79mJ7Y9e/BWzKCvYhZRDMQSEIwpTEcamXroQ/IHOrB1d6CZjZgmjkMzGojWt2BZNIemmQt5x2Mk89B+JtftwOD3EkgaeXHqFfj3H2HZ9rWY9GC3GjAnojTMW0SLzkXRji2QSODPyCaWlo4KRxnXeggHUZyxAPFYkqg/SKs7j3fnXoHzwpnMf/LnDISThFxpzO/cjzXLScNDD/NhtxHdf/6e90urmJCquP31J0gzxkn7h6/SMGEWdU+8zHy7F2sySmDdWxgnFJH6j3+P5byJqFicaE0Doe37iNU1kgyECG2pxjS1DPsVi1DhCJbpE7HMOw+AaH0zgT++gy7Fjv3yCzDkZAx9LioaIxkIkfQHSAZCKH+AeFsX0YNHMBRkYTl/BsbSwqHTtyqRwPdfLwOQ8qXlp3RtnorGGFi/iXhzOylfvvy49xdCCCFOt1ET/CorK5k7dy6//OUvAUgmkxQWFvLtb3+b++6774TxN910E4FAgFdeeWVo2/z585kxYwZPPvnkKb3n2Qp+zbUdPPBcB96Cccwsd+INK4yb3mXO1g0YDRp5Xc1oej3Hxk8lkFdAVtSHy5jEXl6MvrwMncmA1uaBPTXoqmahq5wxuORHvw+tqwfV0g7V+zH7vGgWI0mlEe4dIJSZQ7fNTU9AEe7oRXX1UnL0ACaVoKuwjMNpJfRF9VgCXuY3bMNuhMTEMgac6USa2+nHjKu/m7RgH7p0NwGLnZA3hMPfhyMaQFmsBI0WevV2Mvxd6MwmGgsnoyJRxjUfJORKJX/ZHOjqYfecpYTzC9gbS8X24ks4rRrL031k97WjmzmZXr2D3ZXLMe+uZlGimdQVV+H99QuYp09EBUN4/+tliMXBoEdnNmEozMFQkENk/2Hcd3+Z6IEGYg0toBSu795GrOYIgdfew5Cdgf3yC9CnuYb1mcWa2wm9X028sRXjhCIs588guOE9TJNLsZ4/45T+RnhXDYEN72FftgDz7CmyLp8QQogzblQEv2g0is1m44UXXuDaa68d2r5ixQr6+/t56aWXTnhNUVER9957L/fcc8/Qtvvvv59169axZ8+ek75PJBIhEokMPff5fBQWFp7x4Petf97PxHf/REHCizPkI9PfhclhxTilDC3dDQ47Jj2kmEAzDa6Th9FAorOHeHM7oGGaWIwu1Umix0skliQUg9BAhFh7NwOaifaMIjpd2VjajpFXt5eOknIyu4+R6u9Gl+pC8w/Q60hn4+QltOpdzKl7n3ED7RQnvRgTMQ7PuoAGdxGRvgHKDn7IhNZa7PEQAbuL3oQRS2gAZzSAXZdAp9PoxUpQZ6R16mwyIl6q3aU4ujtJD/YyoaMea34GvXlFtGhu2somEw7GiXT0csGuDeTpw6ipExmIwIDdhXXAS/hrKyhxxMl47x3c99zGwAsb0aW5sF9SNdTHaEMLmtGAoTCHaE0DgfWbcN11E/o/n9KNNrQQO9pG5MP9g0uqLFuALsX+uT47pRSxQ030PfYsKhrHceUiDDkZ6NJd6FNd6NwpJyzbEvd0D57WLcjGcdViuZZPCCHEWXM6g9/pv0L+z7q7u0kkEmRnZx+3PTs7m9ra2pO+xuPxnHS8x+P52PdZtWoVP/3pTz9/wcNUp6XRWb4IezjA8thhyr59OaUzC9E0SCoIxyEYA18UBv7qEfBHCPf6iXf0YK2pwR9xEgtZSD9Sh6OjnYTRiDEWxRHxM8HTytQUM+aeXhhXSG5LNc32bLpiRlx1R/E500npPMJX9u748y3VdCggaLYTRI+76zUuiEcwRwLYB7xEjGa8VicxQ5LUfCep4ydyMK2M520VdLlzmevwM+PZpyjZ/T59znRMljw8580mv7eONpuRaCRO2KvjaHEBNVmTKavfw8y2fYQXLeTQ//oKExr3UtpQgykaxpCej677AMkjA6Tc9SWCb34AGseFPgBTWSEAwc07iOytI/V7tw+Fqkh9M/0PP40+Kx33d27FkJF6Wj47TdOINbVhu3AOjuuXEq1rItHVR6y1g2Svj2S/H6WSg2MNBjSLCRWOknLzZXJaVwghxKh2xoLf2fLDH/6Qe++9d+j5R0f8zrQnlsV48afbacwq4/+edye8ZyB1B0xIg6lWP+WeQ7g8x7CH/OjjMRIKDAkw6/SYe3zYOlrRaRplPR0YEzF0OtB0GgmHk0BhIZ2FlbTpUih/8VlMgQH0TV1oSpFl6aS5dCobp12MNRZkID2bmM7ApIMfEDTaSAA2vWJ8Sw2ung50OrBHQ3jmVBGbO4vSXAu6UIiaei/buzRsXW2Uzygi4omTumsjr069hCq7G0fIR3aGmUQsiHl/DfXTKnln4fU47QYmHtrBN/77ATItSXKurCLt27cQO9JCoK4a+zUXE9y4Bdcd1w/1KvzhfmJNbbi+dsMJfVRKMbDmNVQsgfvuW9A0jYRvAP/qDYS378d20TxMk0vwPb0OQ34W9uULh32K92+F3q8mfqwD5x3XoWka5qnjP3asisZI+gPo0lxyWlcIIcSod8aCX0ZGBnq9no6OjuO2d3R0kJOTc9LX5OTkDGs8gNlsxmw+++umZTUfpgoPxQePclX9m4QNZsJRhTEZB5uVSG4uPePy6Cgsx2A1kREP4NxbTUZ1NZrPTzIYJqEgCfTbnHRn5NOXkk5Y6Uhp6CJ192tMG+gl5EolXDaBnrQcdmRNw9zbw8T2Wi7etR4LCczhEKHUNLrHT8aUl4HucAN5u7ahMxshxYTW5yWWlkpO3E/swG72H0zlkJZOHAPJeAJfXGPu7x5nScTPsZKpTNjyB5pmno8+txzr3v3MrX+XfTeuoOTqKgqiCeKvvs2MnoOkL5qAfcl8Yg0t9P2//0Wy30/qP/493l+txvWNLw31KVrXOLjUyXduOS44xbv6iOw+SGT3Qcyzp2JfOh+VSBB4cxvhbXtJhiO4v3EjlrmDk3oss6YQPdSE73evoFnMOK64AEN+9t9+LJ8qsu8w4Z0HcH/r5lMKcprJiD7dPez3EUIIIc5FZ3xyx7x583jssceAwckdRUVF3H333R87uSMYDLJ+/fqhbeeffz7Tp08/5yZ3hHbV4F29gU1bOzHFI7gdJlxWaO2KotAYcLjQG/Rk9LSS0d2GORIkqdPjd2VwZNJMInYn1oCXrvR84glFttdDVr8HcyREVG9EHwpBIsGxtEIas0qJ2x3kuvXkuPSEwgm83QMk+vyEbA7c/d0UNB2gqL0BTSkSJhMJsxn0BuK5OcTzcvG1dKP6fPTZXCT0RowGDaPVhEEl2XvhlVi6Oqls3Eb+eQXsKl/Iu3125qVHuGRuGh9qORw54OHi7evJcg1OxHDdcR26FDvJYJien/4KzWZB0+uxzJ2GfdkCYPC6ON/T63B/7ytoRgOxw82Eq2uJN7WiT3djnjUZ09Tx6CxmorWNDKx9E8P4QqK1jbhuvxpjcd5Jex9v7WDgj++iwhHsly3ENKH4lD6zaEMLA2vfIPU7t8nSK0IIIUaNUTG5AwaXc1mxYgVPPfUU8+bN45FHHuEPf/gDtbW1ZGdnc/vtt5Ofn8+qVauAweVcFi1axIMPPsgVV1zB6tWr+dnPfnZOLucSqWkgeqCe/+gp4IDXhCEeZV5WgkW5Uf74bieu6l2UHavFFAnRmDuBLRVLGHBnMtHXTOXOjZhjYeJmC3FNh9edSUfRRNoLxuPw9pB1+ADbJl9IffE0Jqse5iQ9aO0eosc6iUST6O1mjNnpGLUkziOHKT1ag93bg3X2FFK/eg2Wqgq8MT3rauGZPXDMBwYdTLQEuaV1E45jzRyeOItweibp3k7mvbGGrJJ0Gi6/gf9uT+XC6o0U59tpXLiELmVh6eF3KPAcgXgc68JZWBfPRdM0kv4A3l+/gP2qRegyUul54Ffos9JwLF+IcdI4+h99FuuCmURrGkgOBAdn0s6YjGFc3tDRtnhXHwP/8zo6mxVDaT7h9/fguutL6J2OT/0MEr1eAhveI97SjrG0APOMcoxlhSddry/e3oXvmZdwf+dWuZ2aEEKIUWXUBD+AX/7yl0MLOM+YMYNHH32UyspKABYvXsy4ceN45plnhsavWbOGH/3oR0MLOP/85z8/Jxdwrv3gKAfWbSfhC9De0EOu30NKwItNi6PZrXSm5bNt3Bza5y+irO8o6bs/ZMLh3WS1N9KZnk/T+AoiBQXgD+BsO4qzv4d0Xye9aTm0lU7FZQVdNEqrJYNgVg75xW5KjQHUkWaizR7sVgOu/k4sNTXETRaqb7qTDwpmcbhXozMAgfjgmsbpNrhlCuSkQGP/YABsaAkyo+Y9podayLUkOLBgOU8fdVNau4PKgQbs08eTMT6b7B3bsCSj6HMySHT34frqtRhyM1HRGIHXthCtacBx/VJME4rpf/z32K9ajCEng4GX38b/3J8wV1VgzE7Hfs1F6F0pJHwDxBtbiR5pIX60HRWNoXM7sV1aRXTvIRIdPTi/ei2acXhXICiliB05RmRPHbH6ZnR2K+aKSZimT0TvdJDo89H/q9VDdwkRQgghRpNRFfzOtrMV/GrWfkD/f7+C3mJmuz6P9rIp7M8qJ2h3cm2sjhkHt6A7UEvQH6XDnkEyxUEUPc9dfAcJewpZDnCZodgQJHfzGwT6gzRPnU1hqJuCjkYykkHSdBFsPd30dfjwxTQiDiea1YzF78N55DBBi4PN51/LG3OuJJLQEUuCpkGJG748DcpSYUc7BKLgMMHGI2AzwNw8CMahsyvE4V5FwGDjG7Phiomg1xTRvYcIvrNjcHqyTsOQnY7j+qWg1xN+v5rgpg+xLZmPpfI8NE0jXF1LtLYR582XoZTC+/Ra4k1tJLr7MJ03EWJxVDiCLsWGsbQQY2kBxuK8wduyRWN4/3MtxuJcbMsXnpYJFAnfANG9h4jsPURyIIgKR3B9/UaZkSuEEGJUkuD3Cc5W8Dt4NMQbR3UkTSYih5vpf+19ZrfuxTTgp92ZQ8uUWbTOriIr2878d1+mw+wmfunFzMjXse0YfFDrJ+NQDTPa9mG9+mIqLigl3QJHumO0bthBctc+jqYV4S0upUAfIq35CImObgJKT3pfB1suv40aeyE6wGkGuwmWlMB1k6DZD+tqoScIvjDs6YRoAlKtYDWAXgOXBTJscNE4uGriYGD8W4l+P0mvH2NxHpED9QTWb8JcMQnb0qqho3IqEqX34WdI+4evoplNDGx4l+DGrdiXzsd++YXEPd3onA50Nstxf1upwYA58MpmHFcuwlwx6Yx8TiqZRIUicnpXCCHEqCXB7xOcreC3aX0t3U+uwdXfRSwlhe3Z04lWzaXL5GRJ2gDb6gJc4+5CbdrGgZIKugxOzP192Pu7KbInyc130F1Qyv84Z9IyoMcdDzKvfiul3Y30z5yNz5VG6pYtRIMR+gtLsVRMIvNoHf0HjvJixTWkOEzk2hU93ijG+OAjEogSCkYxRKKYklEsiQhGvUZJpoFpuQZyUg1kuvRkOA0YTAY0owHNaEQz6sFo/PNzw+DEB6MBTdOIt3bgX/M6+sw0HNdchM5hO77fqzdgnlyKuWIS4d0H6fv33+H8ylXYl8z/2N7FmtoY+J+NGAqzsV+5+IRQKIQQQoi/kOD3Cc5W8OvfsIXIa1sw5Kah6fUc7dfY2G7mqrIkH9T6meOtJ9jWQ7IgH6PbQdBopSthxqtM9MUG7wqRYQVHMoLzWBNRf4hmdwGYDJT2HMVkt0BZMX6dhbb+BOkNtQStDnzjykCnw5/Q4bToyEkzEtebOBI2E9BMuF1GjBYTZTkmJuSZKHSCLhFHxRIQj6Nig4/BfydQsRgqnoBYDBX98/5oDGJxkskEmsGA89YrMWSlHff/jzW3M/DSWxiy0km5aTnRxla6//Fh3N+9DdsFs0/as0RPP/4XNoKmkXLDUlkmRQghhDgFEvw+wdkKfslYjERHD/GmNmKNx4i29bBmZ4iSeC+BhI7d2VMwF2ShlMKkKez6JA5DEl0ySSyeZCCUJNHVi4EkHZULKUrVmPjhJo7pXPypdBFeq4tQHMw93Vy/ay09iy7CUzCeQBwWFPzl+r1qDyQUFLlgfCqcXwiTM05+6vZvqXgcFYmdcBpUJZOEtuwm9M5O9Olukv0+jOOLsMyZhtJpBF5+G53FjP2aizFkphLv7qPjzp+Q+v0V2C6cc2KvAiECr2wmfqwDx99d8rHLtAghhBDiRBL8PsHZCn4tHx5h/5s1HLLn4w0mGbd7C6a+PrrMbsonuNllK8JYkI3RpAe9nqjS0Ol1uGx6Ctwa+S49FpeVjhYv/S+/jQcHNZMrCRhtJGIJHMTID3Qyr7WatmuvpzbuJt0G8SQc9UJnAPQ6GJ8GF4+DOXlgGcZk2FhzO/5n/4hmt6LCEYzj8jBOm0Cyp5/QOzuxzJ2G7aJ5aEYDSinCW3bjfWYdSW8A6+I52C6uxDRpHEl/EM+K/4P727eccHo30d1H6IN9RPbUDV7HN33i6f0QhBBCiDFgVNyr94uuQ+9A197JJXWvY0t34L5qIebK87huax75szRubN3LwTd2k6qP4TZDinnw3r09vjh9Hh9t7e24Oo6RdLrQnzedzAwHs5t2kp9mYGK2AZPZwF6Tnd+V3k60x4DFAHU9g/cAnpKpWFncw6RgG2mVk4e1GLFKJgn86V1iDS24Vt6M3ukgmUgQeHkT/Q/9J1iMgwsiJ5MkOntAp2PgpbcByPh/voM+O514i4fwzgMM/M9Gwh/uw3XnDdiXzEclk8QaWojsPkjsyDF0aW4ssyZjv+9ONJ3uTH0UQgghhDhFcsTvMwq+X03sQAOOm5ejT7EDgzNVf/ubvaj3trP46im0TZvFQU+C/oNHsTY24ertwK5LkB/uwVWQRvy6yzhY04v50GHKNC+5ZZnsTh3Ps6ES+tv7yQ92YZhUCnYb+T4PSxONFPYchVAYfVYahux0Igfqcd99CzrLp9+2Lt7Vh++367DMnop18VwAItW1BF/dgmnaeGyXVKGzmEmGI8TqmgaXQwlFcFx5IYa8rKG/k/ANEHz1PWKNrVgWzkTTG4hU1w7OAC4rxDxzMsbSArm3rRBCCHEayKneT3C2gt9fC0YVTW/tw/un92iO26jrhsneJjLifnSpTuxTSsiYPg57VwddHQH2zr+Udks6k/qaKMiyUN2hqN7Xg/vIYSZ11JEX7UPvsOHLzCWrvYk0UwLHrEnYLz0f87QJx82sjdY2MvDy27hXfnnoWr1kIET04BGiNQ0kuvpQySTx5nZiLR7MM8rRpQy+XkPDWFaI7dLzT2lmbbyzl8Cf3iHR2YuxOJeYpwctmcRcMQnzjHL0aa4z02AhhBBiDJPg9wnOVvBraguybd1e7G+9Tdah/Si9nlhqGobifHakTqBt4nR+/KUsgm09NK/dQnDnQXz5RdiSEWwtzYR8YTy2dFQkht5kwJpixuByoHfaSbdppDkNOC0a5mnjMRTmkGjrJLLzIJiNWKsqME+fNLSWXqT2CN6nXsA8YxKJjh40ixnT5FLM08aj2Sz4f/cKhrws7FcvPuntzD5NrLmdgVc2k+joQWc1g06HeeZkrPOno/vz0U4hhBBCnBkS/D7B2Qp+ex7/I/Fn12KYOYW0v7+G/Jkl6PSD17E19MH/fi3BhX01FO3eSv+0CsKaHlv1XlosGeyYUEk8K4sSZ5JZ+XrKM2Fi2uDt1T6SDEdI9nqJNbURa2wl3t4JCrQUGyoQItHrw5CXiQpFIJlEs1uIHmgYvOWZQU8yECLpCxCpriXl5sswlRWe8H+Ie7qJ1Tej4onBJV4SicGlXeKJwRm/8QTR2kYSnT3o3E4sc6dirZohd8AQQgghziIJfp/grC3nMhAEw+DRs0RbF/H2LgaOduBp6KbbG2OHR4cjL41wLInW0cWh4vPwTZ9BZamZhcUwIW3wNmrDoZQi0dHz5zB4jFhDC5iN6KwWdDYrKhojvG0PjuuWoM/OQGe3YppYjGY+/o2iDS0ENryLZjBgnlGOZtSjGQyg16MiEWKNbcSOtJD0+jEU52O7eB7GskK5Zk8IIYQYARL8PsHZCn6RA/V4X3mXnoSJVlsm7QYnkWgCcziItaeLZCBCQzIF/YXzuGhRIdNzNcxnYQ513NON9/97cfDetJmpQ9uVUkT3HSaw8X0MmWnYLrsAQ2bqXyZy7K8nfsyDLsWOaUoZpqnjj3u9EEIIIUaGLOdyDqhPutgVyyel2wOxY7hSbDgn5FN4filFUxdiTBmZe8MacjJwff1GvL9eg+vO69FnphL+YB+hzTswThqH+xtfQrOaiVTX4n/uj5BMYiovwXrBLAyFOXJUTwghhPgCkyN+n1Hd4X4ON/opm5pDWY4R0/DnTJxRiT4f/b9aDYBlzlSsi+aQ7PcT3PQh8cZWzDMmYV04SyZnCCGEEOc4OdX7CUZiOZdzlYoO3oc3svsgoa3V6F0pWC+S6/WEEEKI0URO9YqTUtEYcU838bYu4u2dxFs7UcEwlnnn4V55y+BSLEIIIYQYsyT4jVJKKSK7aojWHCHR2YtSSTSDAUNuJvq8TMzTJ2K/dMHQos5CCCGEEBL8RqFkMIzvv15Cn5mG7ZIq9Jmpn2lhZiGEEEKMLRL8RplofTP+518l5cZLMU0cN9LlCCGEEGIUkeA3SqhkksD6zcSPeUi95ytyClcIIYQQwybBbxRI9HrxPr0Wy+ypuK6+WWbkCiGEEOIzkeB3jgvvPEDw9a04V1yNIS9rpMsRQgghxCgmwe8cFW/rJLDhPTSTkdR/+CqaUT4qIYQQQnw+kibOIcmBIKGte4jsPog+MxXr4rmYygpHuiwhhBBCfEFI8BthKh4nsvcQoferIZ7AUlVB6vdulyN8QgghhDjtJF2MkIRvgMDLm4gf82CumITzK1ehd6WMdFlCCCGE+AKT4HeWKaUIbd5BeOseHDdeimn8lSNdkhBCCCHGCAl+Z1G8tQPfc3/CPH0iqT+4A02nG+mShBBCCDGGSPA7C1Q0xsC6N4l39OC64zr06e6RLkkIIYQQY5AEvzMssvcQA+s3YV++gJQvLR/pcoQQQggxhknwO0MS3X34X3gdncNO6vdXoLOYR7okIYQQQoxxEvxOs7inm8D6TSTDERxXX4SxOG+kSxJCCCGEACT4nTaxFg+BVzaBpmG/cjHGguyRLkkIIYQQ4jgS/D6naEMLgT++g85uxXHdUgw5GSNdkhBCCCHESUnw+4zirR34n38NfYYb561XyExdIYQQQpzzJPh9RprNivPO6+RuG0IIIYQYNST4fUb6VOdIlyCEEEIIMSxy6wghhBBCiDFCgp8QQgghxBghwU8IIYQQYow4I8GvqamJO++8k5KSEqxWK2VlZdx///1Eo9FPfN3ixYvRNO24x1133XUmShRCCCGEGHPOyOSO2tpakskkTz31FOPHj2f//v187WtfIxAI8PDDD3/ia7/2ta/xT//0T0PPbTbbmShRCCGEEGLMOSPBb/ny5SxfvnzoeWlpKXV1dTzxxBOfGvxsNhs5OTlnoiwhhBBCiDHtrF3j5/V6SUtL+9Rxzz77LBkZGUybNo0f/vCHBIPBs1CdEEIIIcQX31lZx6++vp7HHnvsU4/23XLLLRQXF5OXl8fevXv5wQ9+QF1dHS+++OLHviYSiRCJRIae+3y+01a3EEIIIcQXiaaUUqc6+L777uOhhx76xDEHDx6kvLx86HlrayuLFi1i8eLF/OY3vxlWcW+99RZLliyhvr6esrKyk4554IEH+OlPf3rCdq/Xi9MpiywLIYQQYnTz+Xy4XK7Tkm2GFfy6urro6en5xDGlpaWYTCYA2traWLx4MfPnz+eZZ55BpxvemeVAIIDD4eDVV19l2bJlJx1zsiN+hYWFEvyEEEII8YVwOoPfsE71ZmZmkpmZeUpjW1tbueiii5g9ezZPP/30sEMfQHV1NQC5ubkfO8ZsNmM2m4f9t4UQQgghxpozMrmjtbWVxYsXU1RUxMMPP0xXVxcejwePx3PcmPLycrZv3w5AQ0MD//zP/8zOnTtpamri5Zdf5vbbb+fCCy9k+vTpZ6JMIYQQQogx5YxM7ti4cSP19fXU19dTUFBw3L6PzizHYjHq6uqGZu2aTCbeeOMNHnnkEQKBAIWFhdxwww386Ec/OhMlCiGEEEKMOcO6xm80OJ3nwYUQQgghRtrpzDZyr14hhBBCiDHirKzjdzZ9dABT1vMTQgghxBfBR5nmdJyk/cIFP7/fD0BhYeEIVyKEEEIIcfr4/X5cLtfn+htfuGv8kskkbW1tpKSkoGnaGXufj9YLbGlpkWsJPwfp4+khffz8pIenh/Tx9JA+nh5flD4qpfD7/eTl5X2m5fH+2hfuiJ9OpzthJvGZ5HQ6R/WX6VwhfTw9pI+fn/Tw9JA+nh7Sx9Pji9DHz3uk7yMyuUMIIYQQYoyQ4CeEEEIIMUZI8PuMzGYz999/v9wu7nOSPp4e0sfPT3p4ekgfTw/p4+khfTzRF25yhxBCCCGEODk54ieEEEIIMUZI8BNCCCGEGCMk+AkhhBBCjBES/IQQQgghxggJfp/B448/zrhx47BYLFRWVrJ9+/aRLumc984773DVVVeRl5eHpmmsW7fuuP1KKX7yk5+Qm5uL1Wpl6dKlHD58eGSKPUetWrWKuXPnkpKSQlZWFtdeey11dXXHjQmHw6xcuZL09HQcDgc33HADHR0dI1TxuemJJ55g+vTpQwu6VlVVsWHDhqH90sPhe/DBB9E0jXvuuWdom/Tx0z3wwANomnbco7y8fGi/9PDUtba2ctttt5Geno7VauW8885jx44dQ/vlN+YvJPgN0/PPP8+9997L/fffz65du6ioqGDZsmV0dnaOdGnntEAgQEVFBY8//vhJ9//85z/n0Ucf5cknn+SDDz7AbrezbNkywuHwWa703LV582ZWrlzJtm3b2LhxI7FYjEsvvZRAIDA05nvf+x7r169nzZo1bN68mba2Nq6//voRrPrcU1BQwIMPPsjOnTvZsWMHF198Mddccw0HDhwApIfD9eGHH/LUU08xffr047ZLH0/N1KlTaW9vH3q89957Q/ukh6emr6+PBQsWYDQa2bBhAzU1Nfzrv/4rqampQ2PkN+avKDEs8+bNUytXrhx6nkgkVF5enlq1atUIVjW6AGrt2rVDz5PJpMrJyVG/+MUvhrb19/crs9msfv/7349AhaNDZ2enAtTmzZuVUoM9MxqNas2aNUNjDh48qAC1devWkSpzVEhNTVW/+c1vpIfD5Pf71YQJE9TGjRvVokWL1He/+12llHwXT9X999+vKioqTrpPenjqfvCDH6iFCxd+7H75jTmeHPEbhmg0ys6dO1m6dOnQNp1Ox9KlS9m6desIVja6NTY24vF4juury+WisrJS+voJvF4vAGlpaQDs3LmTWCx2XB/Ly8spKiqSPn6MRCLB6tWrCQQCVFVVSQ+HaeXKlVxxxRXH9Qvkuzgchw8fJi8vj9LSUm699Vaam5sB6eFwvPzyy8yZM4cbb7yRrKwsZs6cyX/8x38M7ZffmONJ8BuG7u5uEokE2dnZx23Pzs7G4/GMUFWj30e9k76eumQyyT333MOCBQuYNm0aMNhHk8mE2+0+bqz08UT79u3D4XBgNpu56667WLt2LVOmTJEeDsPq1avZtWsXq1atOmGf9PHUVFZW8swzz/Dqq6/yxBNP0NjYyAUXXIDf75ceDsORI0d44oknmDBhAq+99hrf/OY3+c53vsNvf/tbQH5j/pZhpAsQQgzfypUr2b9//3HXA4lTN2nSJKqrq/F6vbzwwgusWLGCzZs3j3RZo0ZLSwvf/e532bhxIxaLZaTLGbUuu+yyoX9Pnz6dyspKiouL+cMf/oDVah3BykaXZDLJnDlz+NnPfgbAzJkz2b9/P08++SQrVqwY4erOPXLEbxgyMjLQ6/UnzKrq6OggJydnhKoa/T7qnfT11Nx999288sorvP322xQUFAxtz8nJIRqN0t/ff9x46eOJTCYT48ePZ/bs2axatYqKigr+/d//XXp4inbu3ElnZyezZs3CYDBgMBjYvHkzjz76KAaDgezsbOnjZ+B2u5k4cSL19fXyXRyG3NxcpkyZcty2yZMnD502l9+Y40nwGwaTycTs2bN58803h7Ylk0nefPNNqqqqRrCy0a2kpIScnJzj+urz+fjggw+kr39FKcXdd9/N2rVreeuttygpKTlu/+zZszEajcf1sa6ujubmZunjp0gmk0QiEenhKVqyZAn79u2jurp66DFnzhxuvfXWoX9LH4dvYGCAhoYGcnNz5bs4DAsWLDhhaatDhw5RXFwMyG/MCUZ6dslos3r1amU2m9Uzzzyjampq1Ne//nXldruVx+MZ6dLOaX6/X+3evVvt3r1bAerf/u3f1O7du9XRo0eVUko9+OCDyu12q5deeknt3btXXXPNNaqkpESFQqERrvzc8c1vflO5XC61adMm1d7ePvQIBoNDY+666y5VVFSk3nrrLbVjxw5VVVWlqqqqRrDqc899992nNm/erBobG9XevXvVfffdpzRNU6+//rpSSnr4Wf31rF6lpI+n4vvf/77atGmTamxsVFu2bFFLly5VGRkZqrOzUyklPTxV27dvVwaDQf3Lv/yLOnz4sHr22WeVzWZTv/vd74bGyG/MX0jw+wwee+wxVVRUpEwmk5o3b57atm3bSJd0znv77bcVcMJjxYoVSqnB6fY//vGPVXZ2tjKbzWrJkiWqrq5uZIs+x5ysf4B6+umnh8aEQiH1rW99S6Wmpiqbzaauu+461d7ePnJFn4PuuOMOVVxcrEwmk8rMzFRLliwZCn1KSQ8/q78NftLHT3fTTTep3NxcZTKZVH5+vrrppptUfX390H7p4albv369mjZtmjKbzaq8vFz9+te/Pm6//Mb8haaUUiNzrFEIIYQQQpxNco2fEEIIIcQYIcFPCCGEEGKMkOAnhBBCCDFGSPATQgghhBgjJPgJIYQQQowREvyEEEIIIcYICX5CCCGEEGOEBD8hhBBCiDFCgp8QQgghxBghwU8IIYQQYoyQ4CeEEEIIMUZI8BNCCCGEGCP+fwrw+JmBgXDDAAAAAElFTkSuQmCC\n", "text/plain": [ - "
" + "
" ] }, - "metadata": { - "needs_background": "light" - }, + "metadata": {}, "output_type": "display_data" } ], @@ -711,9 +649,9 @@ ], "metadata": { "kernelspec": { - "display_name": "jax0227", + "display_name": "py38", "language": "python", - "name": "jax0227" + "name": "py38" }, "language_info": { "codemirror_mode": { @@ -725,7 +663,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.7" + "version": "3.8.16" } }, "nbformat": 4, diff --git a/examples/nonlinear_heat_pde.ipynb b/examples/nonlinear_heat_pde.ipynb new file mode 100644 index 00000000..c61b92ee --- /dev/null +++ b/examples/nonlinear_heat_pde.ipynb @@ -0,0 +1,417 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "03617dd0-ce8d-4c5f-a9e5-edb8395c21b2", + "metadata": {}, + "source": [ + "# Nonlinear heat PDE\n", + "\n", + "Diffrax can also be used to solve some PDEs.\n", + "\n", + "(Specifically, the scope of Diffrax is \"any numerical method which iterates over timesteps\". This means that e.g. semidiscretised evolution equations are in-scope, but e.g. finite volume methods for elliptic equations are out-of-scope.)\n", + "\n", + "---\n", + "\n", + "In this example, we solve the nonlinear heat equation\n", + "\n", + "$$ \\frac{\\partial y}{\\partial t}(t, x) = (1 - y(t, x)) \\Delta y(t, x) \\qquad\\text{in}\\qquad t \\in [0, 40], x \\in [-1, 1]$$\n", + "\n", + "subject to the initial condition\n", + "$$ y(0, x) = x^2, $$\n", + "\n", + "and Dirichlet boundary conditions\n", + "$$ y(t, -1) = 1, $$\n", + "$$ y(t, 1) = 1. $$\n", + "\n", + "---\n", + "\n", + "We spatially discretise $x \\in [-1, 1]$ into points $-1 = x_0 < x_1 < \\cdots < x_{n-1} = 1$, with equal spacing $\\delta x = x_{i+1} - x_i$. The solution is then discretised into $y(t, x_i) \\approx y_i(t)$, and the Laplacian discretised into $\\Delta y(t,x_i) \\approx \\frac{y_{i+1}(t) - 2y_{i}(t) + y_{i-1}(t)}{2 \\delta x}$.\n", + "\n", + "In doing so we reduce to a system of ODEs\n", + "\n", + "$$ \\frac{\\mathrm{d}y_i}{\\mathrm{d}t}(t) = (1 - y_i(t)) \\frac{y_{i+1}(t) - 2y_{i}(t) + y_{i-1}(t)}{2 \\delta x} \\qquad\\text{for}\\qquad i \\in \\{1, ..., n-2\\},$$\n", + "\n", + "subject to the initial condition\n", + "$$ y_i(0) = {x_i}^2, $$\n", + "\n", + "for which the Dirichlet boundary conditions become\n", + "$$ \\frac{\\mathrm{d}y_0}{\\mathrm{d}t}(t) = 0, $$\n", + "$$ \\frac{\\mathrm{d}y_{n-1}}{\\mathrm{d}t}(t) = 0. $$\n", + "\n", + "---\n", + "\n", + "This example is available as a Jupyter notebook [here](https://github.com/patrick-kidger/diffrax/blob/main/examples/nonlinear_heat_pde.ipynb).\n", + "\n", + "\n", + "!!! danger \"Advanced example\"\n", + "\n", + " This is an advanced example, as it involves defining a custom solver." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "0a89f429-bab4-4a0f-800c-a0c8e1c7bf9b", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "from typing import Callable\n", + "\n", + "import diffrax\n", + "import equinox as eqx # https://github.com/patrick-kidger/equinox\n", + "import jax\n", + "import jax.lax as lax\n", + "import jax.numpy as jnp\n", + "import matplotlib.pyplot as plt\n", + "from jaxtyping import Array, Float # https://github.com/google/jaxtyping\n", + "\n", + "\n", + "jax.config.update(\"jax_enable_x64\", True)" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "16da14af-420a-4d25-aa06-515a9baa50c2", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# Represents the interval [x0, x_final] discretised into n equally-spaced points.\n", + "class SpatialDiscretisation(eqx.Module):\n", + " x0: float = eqx.static_field()\n", + " x_final: float = eqx.static_field()\n", + " vals: Float[Array, \"n\"]\n", + "\n", + " @classmethod\n", + " def discretise_fn(cls, x0: float, x_final: float, n: int, fn: Callable):\n", + " if n < 2:\n", + " raise ValueError(\"Must discretise [x0, x_final] into at least two points\")\n", + " vals = jax.vmap(fn)(jnp.linspace(x0, x_final, n))\n", + " return cls(x0, x_final, vals)\n", + "\n", + " @property\n", + " def δx(self):\n", + " return (self.x_final - self.x0) / (len(self.vals) - 1)\n", + "\n", + " def binop(self, other, fn):\n", + " if isinstance(other, SpatialDiscretisation):\n", + " if self.x0 != other.x0 or self.x_final != other.x_final:\n", + " raise ValueError(\"Mismatched spatial discretisations\")\n", + " other = other.vals\n", + " return SpatialDiscretisation(self.x0, self.x_final, fn(self.vals, other))\n", + "\n", + " def __add__(self, other):\n", + " return self.binop(other, lambda x, y: x + y)\n", + "\n", + " def __mul__(self, other):\n", + " return self.binop(other, lambda x, y: x * y)\n", + "\n", + " def __radd__(self, other):\n", + " return self.binop(other, lambda x, y: y + x)\n", + "\n", + " def __rmul__(self, other):\n", + " return self.binop(other, lambda x, y: y * x)\n", + "\n", + " def __sub__(self, other):\n", + " return self.binop(other, lambda x, y: x - y)\n", + "\n", + " def __rsub__(self, other):\n", + " return self.binop(other, lambda x, y: y - x)\n", + "\n", + "\n", + "def laplacian(y: SpatialDiscretisation) -> SpatialDiscretisation:\n", + " y_next = jnp.roll(y.vals, shift=1)\n", + " y_prev = jnp.roll(y.vals, shift=-1)\n", + " Δy = (y_next - 2 * y.vals + y_prev) / (2 * y.δx)\n", + " # Dirichlet boundary condition\n", + " Δy = Δy.at[0].set(0)\n", + " Δy = Δy.at[-1].set(0)\n", + " return SpatialDiscretisation(y.x0, y.x_final, Δy)" + ] + }, + { + "cell_type": "markdown", + "id": "7482e079-5ed1-4bc7-85a5-dcee3717ce7f", + "metadata": {}, + "source": [ + "First let's try solving this semidiscretisation directly, as a system of ODEs." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "d304a9d0-58c7-4d29-91e6-10bc65406b73", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# Problem\n", + "def vector_field(t, y, args):\n", + " return (1 - y) * laplacian(y)\n", + "\n", + "\n", + "term = diffrax.ODETerm(vector_field)\n", + "ic = lambda x: x**2\n", + "\n", + "# Spatial discretisation\n", + "x0 = -1\n", + "x_final = 1\n", + "n = 50\n", + "y0 = SpatialDiscretisation.discretise_fn(x0, x_final, n, ic)\n", + "\n", + "# Temporal discretisation\n", + "t0 = 0\n", + "t_final = 20\n", + "δt = 0.0001\n", + "saveat = diffrax.SaveAt(ts=jnp.linspace(t0, t_final, 50))\n", + "\n", + "# Tolerances\n", + "rtol = 1e-10\n", + "atol = 1e-10\n", + "stepsize_controller = diffrax.PIDController(\n", + " pcoeff=0.3, icoeff=0.4, rtol=rtol, atol=atol\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "d1ad0404-5a13-4506-bdab-cdcfaf5be609", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "solver = diffrax.Tsit5()\n", + "sol = diffrax.diffeqsolve(\n", + " term,\n", + " solver,\n", + " t0,\n", + " t_final,\n", + " δt,\n", + " y0,\n", + " saveat=saveat,\n", + " stepsize_controller=stepsize_controller,\n", + " max_steps=None,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "28185196-75f2-4465-ad59-ff45ec8c4d01", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAcEAAAGiCAYAAACf230cAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy88F64QAAAACXBIWXMAAA9hAAAPYQGoP6dpAABKOklEQVR4nO3df3QU5b0/8PfsJtmgklBLSAhGfviDHwUSDBqDWOGYGtFDjbVeoFR+FOFISQ+YWiFWCUJrWq0UvUbS0sbgrVzQe4R6hRMvRoH6JWIJzal4BAEJQWVX0CYhEfJjd75/pBlcmc+QZzPZZTPvl2fUPDvzPLOzM/vsPPOZz2i6rusgIiJyIFekV4CIiChS2AkSEZFjsRMkIiLHYidIRESOxU6QiIgci50gERE5FjtBIiJyLHaCRETkWOwEiYjIsdgJEhGRY7ETJCKiiNu1axemTp2K1NRUaJqGLVu2XHCZHTt24LrrroPH48HVV1+N8vJy5XbZCRIRUcQ1NzcjPT0dJSUlXZr/6NGjuPPOOzF58mTU1NRgyZIluP/++/HGG28otasxgTYREV1MNE3D5s2bkZeXJ86zdOlSbN26Ffv37zfKpk+fjvr6elRUVHS5rZjurCgREfUeZ8+eRWtrqy116boOTdOCyjweDzwejy31V1VVIScnJ6gsNzcXS5YsUaqHnSAREeHs2bMYMvQy+Lx+W+q77LLL0NTUFFRWVFSEFStW2FK/1+tFcnJyUFlycjIaGxtx5swZ9OnTp0v1sBMkIiK0trbC5/Xjg0ND0Dehe+EipxsD+M41tTh+/DgSEhKMcrvOAu3ETpCIiAx9E1xI6GYn2CkhISGoE7RTSkoKfD5fUJnP50NCQkKXzwIBdoJERPQ1WgDQAtqFZ7xAHT0tOzsb27ZtCyrbvn07srOzlerhLRJERHSOrtkzKWpqakJNTQ1qamoAdNwCUVNTg7q6OgBAYWEhZs2aZcz/wAMP4OOPP8bDDz+MAwcO4Pnnn8fLL7+MBx98UKlddoJERBRxe/fuxbhx4zBu3DgAQEFBAcaNG4fly5cDAE6cOGF0iAAwdOhQbN26Fdu3b0d6ejqefvpp/OlPf0Jubq5Su7xPkIiI0NjYiMTERHzyydVISHB3sy4/rrjiMBoaGnrsmqBdeE2QiIgMHdcEu19HtOBwKBERORbPBImI6JzAv6fu1hEl2AkSEZFB0zum7tYRLTgcSkREjsUzQSIiMmi6DYExUXQmyE6QiIjOCegdU3friBIcDiUiIsfimSARERmcFhjDTpCIiM5x2C0SHA4lIiLH4pkgEREZtIAOrZuBLd1dPpzYCRIR0TkcDiUiInIGngkSEZGB0aFERORcHA4lIiJyBp4JEhGRwWkP1WUnSERE5+gA9G5e1Iuia4IcDiUiIsfimSARERn4KCUiInIuRocSERE5A88EiYjIwJvliYjIuTgcSkRE5Ay2doLFxcW4/vrr0bdvXwwYMAB5eXk4ePBg0Dxnz57FokWL8O1vfxuXXXYZ7rnnHvh8Pst6dV3H8uXLMXDgQPTp0wc5OTk4dOiQnatORETAuTPB7k5RwtZOcOfOnVi0aBHeffddbN++HW1tbbjtttvQ3NxszPPggw/if//3f/HKK69g586d+Oyzz/CDH/zAst4nn3wSzz77LEpLS7Fnzx5ceumlyM3NxdmzZ+1cfSIix+u4Jqh1c4r0u+g6Tde7mxpAdvLkSQwYMAA7d+7Ed7/7XTQ0NCApKQkbNmzAD3/4QwDAgQMHMHLkSFRVVeHGG288rw5d15Gamoqf//zneOihhwAADQ0NSE5ORnl5OaZPn95Tq09E5BiNjY1ITEzEl7uHIeEyd/fqavLj8gkfo6GhAQkJCTatYc/o0cCYhoYGAMDll18OAKiurkZbWxtycnKMeUaMGIErr7xS7ASPHj0Kr9cbtExiYiKysrJQVVUldoItLS1oaWkx/g4EAvjyyy/x7W9/G5qm2fL+iIgiSdd1nD59GqmpqXC5bBrYc1hgTI91goFAAEuWLMFNN92E0aNHAwC8Xi/i4uLQr1+/oHmTk5Ph9XpN6+ksT05O7vIyQMf1yccff7wb74CIKDocP34cV1xxhT2VsRO0x6JFi7B//3688847PdWEpcLCQhQUFBh/NzQ04Morr8T+I2no2zf4F5PnS/NT/9h/yfW7GoThgoY402L9tMe03N9kXg4A7c3xQrl5G+1n1crbzsbKbbeYv9beZr7LtLcKbbfJv079Ul1+820rzR8ImLfh98ttBwLmowG6VK73/OiBJlxI0Vzm5S6hHADcbvNvIZfLvNwd225aHuP2K80PADGx5m3ExLUK8wtte9rENmLjzV+LiRfakMovlcrleAP3ZS2m5Vpf83IkmrcRSDTftgDQ9i3z8pbLg5c5fTqA0VcdR9++fcW6yFqPdIL5+fl4/fXXsWvXrqBfJykpKWhtbUV9fX3Q2aDP50NKSoppXZ3lPp8PAwcODFomIyNDXAePxwOP5/wOpm9fFxISvtEJtgmdYJv8JeNqFzpBoS5d+mL3yx9Bu/Bau1/ooHTz8japPGDecVnV1S7sMm260AlCvrbgF+pq18yXadcVO0Ghno5lensnaP4FK3WCMTHm2yomRuoE5W0bE2u+TKzwmysmzvzzi4mTt3ms8Nsx1mO+TWLipXJhe/SRj0v3JeadtnaJsE0uNS8PXCZ/fm19zd97S4Kwj9h5iUdH958CEUWBMbZGh+q6jvz8fGzevBlvvfUWhg4dGvR6ZmYmYmNjUVlZaZQdPHgQdXV1yM7ONq1z6NChSElJCVqmsbERe/bsEZchIqLQaAHNlila2NoJLlq0CH/5y1+wYcMG9O3bF16vF16vF2fOnAHQEdAyb948FBQU4O2330Z1dTXmzp2L7OzsoKCYESNGYPPmzQA6fuEsWbIEv/rVr/Daa6/h/fffx6xZs5Camoq8vDw7V5+IiBzG1uHQtWvXAgAmTZoUVP7CCy9gzpw5AIDf//73cLlcuOeee9DS0oLc3Fw8//zzQfMfPHjQiCwFgIcffhjNzc1YsGAB6uvrMXHiRFRUVCA+3vyaGRERhchhw6G2doJdueUwPj4eJSUlKCkp6XI9mqZh5cqVWLlyZbfXMVLCca1JrEu13KIusVyIBrN6f+J1OcW2pWuCVruj6uehWm5FuvYnkratRT3SNpHalretebnLat9R3BeU91uL1+w6nqzqkfad6BkAvABdA7o7nBmGa+h2Ye5QIiJyLD5FgoiIzuF9gkRE5FgOuybI4VAiInIsngkSEdE5ARsCY6LoPkF2gqGSxrztjIpSjSgVIzfV18m2KMkQIlDlaE/7IjftajsUUl1S5KaU5cX6/am1Lb1vqW1bI4tD2LbiPq3ahp1f1mIb9jURFrrW/e8xRocSERFd/HgmSEREBi3QMXW3jmjBTpCIiM5x2DVBDocSEZFj8UyQiIjOcdh9guwEiYjoHIcNh7ITDBc7Q7QVf2Wphstbvaac3Nri6e5yMm7Vts3rD6ntHk7MHIpQtq2mmT/Y1q5tbtW2y93zycm7kKv/Gwsozm91XEZR+D9dGDtBIiI6x2H3CbITJCKicxyWQJvRoURE5Fg8EyQionM4HEpERE6l61pI+Ya/WUe0YCd4EVPdkVSjQC0j7KS6bEpuHcoydiXWtrOuUA52KVG2RDm5NexLlG3r5yq0Le5rFpsppH1aoR5yDnaCRER0DodDiYjIsRgdSkRE5Aw8EyQionM4HEpERI7F3KFkCGVcW/rwpV9GVr+YxAhNtfyPYvVWEX7K+TvVo/XsqkvKYxlKBKMYNWrjQa3DvC7NZf4GVSM3AXmbKOcUDcfnqrivWVE+NqT3YXlc2pQHOIqum/Vm7ASJiOgcDocSEZFjOWw4lNGhRETkWLZ3grt27cLUqVORmpoKTdOwZcuWoNc1TTOdnnrqKbHOFStWnDf/iBEj7F51IiLSbZqihO2dYHNzM9LT01FSUmL6+okTJ4KmsrIyaJqGe+65x7Le73znO0HLvfPOO3avOhGR4+kBzZYpFCUlJRgyZAji4+ORlZWF9957z3L+NWvWYPjw4ejTpw/S0tLw4IMP4uzZs0pt2n5NcMqUKZgyZYr4ekpKStDff/3rXzF58mQMGzbMst6YmJjzlo0omy78Wu0sckSbPeWWbYvLSAuEkL9TiGC0K6+nVI9lGzZFxVoRc4cK2zYQwm9VqQ3VbSVG3gpPj+94UW3fUY0atWrDrmPAzmjgaAoSiaRNmzahoKAApaWlyMrKwpo1a5Cbm4uDBw9iwIAB582/YcMGLFu2DGVlZZgwYQI++ugjzJkzB5qmYfXq1V1uN6LXBH0+H7Zu3Yp58+ZdcN5Dhw4hNTUVw4YNw8yZM1FXV2c5f0tLCxobG4MmIiK6gM7o0O5OilavXo358+dj7ty5GDVqFEpLS3HJJZegrKzMdP7du3fjpptuwo9+9CMMGTIEt912G2bMmHHBs8dvimgnuH79evTt2xc/+MEPLOfLyspCeXk5KioqsHbtWhw9ehQ333wzTp8+LS5TXFyMxMREY0pLS7N79YmIep/O6NDuTsB5JyItLS2mTba2tqK6uho5OTlGmcvlQk5ODqqqqkyXmTBhAqqrq41O7+OPP8a2bdtwxx13KL3diHaCZWVlmDlzJuLj4y3nmzJlCu69916MHTsWubm52LZtG+rr6/Hyyy+LyxQWFqKhocGYjh8/bvfqExGRhbS0tKCTkeLiYtP5Tp06Bb/fj+Tk5KDy5ORkeL1e02V+9KMfYeXKlZg4cSJiY2Nx1VVXYdKkSXjkkUeU1jFi9wn+7W9/w8GDB7Fp0yblZfv164drr70Whw8fFufxeDzweDzdWUUiIufRYcPN8h3/OX78OBISEoxiO7+Td+zYgSeeeALPP/88srKycPjwYSxevBirVq3CY4891uV6ItYJ/vnPf0ZmZibS09OVl21qasKRI0dw33339cCaERE5mG7DzfL/7kQTEhKCOkFJ//794Xa74fP5gsp9Pp8YEPnYY4/hvvvuw/333w8AGDNmDJqbm7FgwQL88pe/hMvVtYFO2zvBpqamoDO0o0ePoqamBpdffjmuvPJKAB3jxK+88gqefvpp0zpuvfVW3H333cjPzwcAPPTQQ5g6dSoGDx6Mzz77DEVFRXC73ZgxY4Y9K23nPS2hRLqpNmHX089tfOp7WJ4sb1M0KSB/HqHUpUqqS4rodAlho7pmEXkrvA/NL+UnVYsmtfNzjeT+GQrlqO2QGrGvqmgQFxeHzMxMVFZWIi8vDwAQCARQWVlp9APf9NVXX53X0bndbgCAbpXc9hts7wT37t2LyZMnG38XFBQAAGbPno3y8nIAwMaNG6HrutiJHTlyBKdOnTL+/uSTTzBjxgx88cUXSEpKwsSJE/Huu+8iKSnJ7tUnInI0XbdOkN7VOlQVFBRg9uzZGD9+PG644QasWbMGzc3NmDt3LgBg1qxZGDRokHFdcerUqVi9ejXGjRtnDIc+9thjmDp1qtEZdoXtneCkSZMu2AsvWLAACxYsEF+vra0N+nvjxo12rBoREV1IhBJoT5s2DSdPnsTy5cvh9XqRkZGBiooKI1imrq4u6Mzv0UcfhaZpePTRR/Hpp58iKSkJU6dOxa9//WuldplAm4iILgr5+fni8OeOHTuC/o6JiUFRURGKioq61SY7QSIiOsdhT5FgJ0hERAZd17odSGRnIFJP46OUiIjIsXgmGGlWMUR23aagmBQ6lDZUb1+wbsOedbJM3q14K0QgDAm0pS0l3u4g3NYAhLKt1OqxTE4uJd2OsfH2GrsSnYvlYtO9H4dDiYjIsSIUHRopHA4lIiLH4pkgEREZnBYYw06QiIjOCUB8wLNSHVGCw6FERORYPBO0YhUhJvzS0aVfQDYmIpYXUIuAszcJslo9gBxhKEYX2hSxarWMX4waFatSJr4PISG222W+U1m9PykCVTm6VzGxdsdrUrmNSa9DSLrd021Lx74mfidYrFgkOSwwhp0gEREZ9IDW7afe2PnUnJ7G4VAiInIsngkSEdE5HA4lIiKnctotEhwOJSIix+KZ4MVM+jUl5U1UzadodfG6h/OWhlKXXdGkgHoUaCj5VyVS5CaESExpXcV6IG8TMW+pEIEajs9VjrYMoQ1pGak8is5Ywka3IXdoFG1XdoJERHSOw64JcjiUiIgci2eCRERk0PXuJ4ewM7lET2MnSERE5zjseYIcDiUiIsfimWCoVKPNpPyBIfxiUs/fKZSHFH2n+ER2yyfLm5erPvVdNQ+oZdshRJqqkurShJVyuc13Hqv3p5w7VNzmfqX5rdpQ3Xes90/VcsXPz+q4VMwPHE1nRYDz7hNkJ0hEROcwOpSIiMgZeCZIREQGpz1Fgp0gERGdo8OG4VBb1iQsOBxKRESOZXsnuGvXLkydOhWpqanQNA1btmwJen3OnDnQNC1ouv322y9Yb0lJCYYMGYL4+HhkZWXhvffes3vVe1bnxWaFqXNY4puTuIzU9L+jvVSmgN98ktYpEHCZTtZtuEwn1fmltgMBTZ4U25Ymf8ClPKl/FtL7tnh/0jaxaZtbrq+0L0j7jrSvhbDfqh5/yseY1dRLhLLdlT6Li4ztnWBzczPS09NRUlIiznP77bfjxIkTxvTf//3flnVu2rQJBQUFKCoqwr59+5Ceno7c3Fx8/vnndq8+EZGzdd4s390pSth+TXDKlCmYMmWK5TwejwcpKSldrnP16tWYP38+5s6dCwAoLS3F1q1bUVZWhmXLlnVrfYmIyLkick1wx44dGDBgAIYPH46FCxfiiy++EOdtbW1FdXU1cnJyjDKXy4WcnBxUVVWJy7W0tKCxsTFoIiIia525Q7s7RYuwd4K33347XnzxRVRWVuK3v/0tdu7ciSlTpsDvN89McerUKfj9fiQnJweVJycnw+v1iu0UFxcjMTHRmNLS0mx9H0REvZHTrgmG/RaJ6dOnG/8/ZswYjB07FldddRV27NiBW2+91bZ2CgsLUVBQYPzd2NjIjpCIiIJE/BaJYcOGoX///jh8+LDp6/3794fb7YbP5wsq9/l8ltcVPR4PEhISgiYiIroABsaE1yeffIIvvvgCAwcONH09Li4OmZmZqKysRF5eHgAgEAigsrIS+fn5PbpuWoQ/SPVE2WqJiK2GLFSTVauWW72mmkA7oPi+reqSklLbObzjF+qSkl67XeYZm63eX8Bl/prLpm0e8X1HcZ+2LbG2jSL9/SKxYzgzmoZDbT8TbGpqQk1NDWpqagAAR48eRU1NDerq6tDU1IRf/OIXePfdd1FbW4vKykrcdddduPrqq5Gbm2vUceutt+K5554z/i4oKMC6deuwfv16fPjhh1i4cCGam5uNaFEiIqJQ2H4muHfvXkyePNn4u/O63OzZs7F27Vr885//xPr161FfX4/U1FTcdtttWLVqFTwej7HMkSNHcOrUKePvadOm4eTJk1i+fDm8Xi8yMjJQUVFxXrAMERF1lx03/0fPmaDtneCkSZOgW8THvvHGGxeso7a29ryy/Pz8Hh/+JCJyOg6HEhEROUTEA2OIiOgiYkd050Ua9GOGnWCopNN9KWothAhN6XEkYlSgGCWpHvEoR24K5X7zNqT5O15TjAIV5hfXSZgfkKNAVdfJTlJ0aEh1BczrCmjm78+lm0egqm5zQH1fCC06VPhsVSOFQ4gaFetSPC4vVnZkfGHGGCIioijAM0EiIjI4LTCGnSAREZ1jx/MRo6gT5HAoERE5Fs8EiYjonIBmmZKvq3VEC3aCADTzwLjQqA4DWOW3VIw2k6LyLtr8j4rvQ8oRqvq+gVAiU83rCeXahxQF6hLyfUrr6rKIJhW3oUuIGpW2obt37DvKUaOhfInbOARo63eSIqddE+RwKBERORbPBImI6ByHBcawEyQiIgOHQ4mIiByCZ4JERGTQAx1Td+uIFuwEwyWECDixKmkZ1eg7i9yaqvlGlaP1EEqEplrbUn5Qy2X89kUwSqToUOkRZG63eT1W709qQ9yGipGpVp+rHmPPfmi5f9p0DCjXb9FGr+Gwa4IcDiUiIsfimSARERmcFhjDTpCIiAxO6wQ5HEpERI7FM0EiIjrHYYEx7AStWIX5Sq+F8gR5iV25FsUn1Fs0rfi09lCeyK4a1SlFJKo+JR6Qo0DF92fnQS3UJecClT4oi/cn1KX5zcv9Qk5RdwjRlsr7iHTMWB1/ik+EV40mtSLWJeYhFSq6SG8j0HWLnKoKdUQLDocSEZFj8UyQiIgMTguMYSdIRETn6BCHm5XqiBIcDiUiIsfimSARERk4HEpERI7FTpDCy2LsXD1JsDB7CLdtyMmO1dq2uk1BuuVBOUm30IZ0G4TlMnbe4qJIbkLahnKMvaYJ29ZlvoxqAvSA2yq5tVCuuG0t90/F/VA5sXYUXdOi7rH9muCuXbswdepUpKamQtM0bNmyxXitra0NS5cuxZgxY3DppZciNTUVs2bNwmeffWZZ54oVK6BpWtA0YsQIu1ediMjx9IBmyxSKkpISDBkyBPHx8cjKysJ7771nOX99fT0WLVqEgQMHwuPx4Nprr8W2bduU2rS9E2xubkZ6ejpKSkrOe+2rr77Cvn378Nhjj2Hfvn149dVXcfDgQXz/+9+/YL3f+c53cOLECWN655137F51IiLqzBjT3UnRpk2bUFBQgKKiIuzbtw/p6enIzc3F559/bjp/a2srvve976G2thb/8z//g4MHD2LdunUYNGiQUru2D4dOmTIFU6ZMMX0tMTER27dvDyp77rnncMMNN6Curg5XXnmlWG9MTAxSUlK6vB4tLS1oaWkx/m5sbOzyskRE1H3f/N71eDzweDym865evRrz58/H3LlzAQClpaXYunUrysrKsGzZsvPmLysrw5dffondu3cjNjYWADBkyBDldYz4LRINDQ3QNA39+vWznO/QoUNITU3FsGHDMHPmTNTV1VnOX1xcjMTERGNKS0uzca2JiHqnzsCY7k4AkJaWFvQ9XFxcbNpma2srqqurkZOTY5S5XC7k5OSgqqrKdJnXXnsN2dnZWLRoEZKTkzF69Gg88cQT8Pv9Su83ooExZ8+exdKlSzFjxgwkJCSI82VlZaG8vBzDhw/HiRMn8Pjjj+Pmm2/G/v370bdvX9NlCgsLUVBQYPzd2NjIjpCI6ALsjA49fvx40He7dBZ46tQp+P1+JCcnB5UnJyfjwIEDpst8/PHHeOuttzBz5kxs27YNhw8fxk9/+lO0tbWhqKioy+sasU6wra0N//Ef/wFd17F27VrLeb8+vDp27FhkZWVh8ODBePnllzFv3jzTZaxOu22hGlUmJgm2M0LTvui7gGoCbSmK0Co6VKxLLbG2FAVq9f5Uo0Cltu3kFiI3/UIQqKZZvD9hm7hc5u/DJTQSEOYPJepXTqxtY/SyXceMVWCH9JpiZKoTJCQkWJ7gdEcgEMCAAQPwxz/+EW63G5mZmfj000/x1FNPXfydYGcHeOzYMbz11lvKG6lfv3649tprcfjw4R5aQyIiZ9L17j8FQnX5/v37w+12w+fzBZX7fD4xFmTgwIGIjY2F2+02ykaOHAmv14vW1lbExcV1qe2wXxPs7AAPHTqEN998E9/+9reV62hqasKRI0cwcODAHlhDIiLnsvOaYFfFxcUhMzMTlZWVRlkgEEBlZSWys7NNl7nppptw+PBhBALnRjE++ugjDBw4sMsdINADnWBTUxNqampQU1MDADh69ChqampQV1eHtrY2/PCHP8TevXvx0ksvwe/3w+v1Gj13p1tvvRXPPfec8fdDDz2EnTt3ora2Frt378bdd98Nt9uNGTNm2L36REQUAQUFBVi3bh3Wr1+PDz/8EAsXLkRzc7MRLTpr1iwUFhYa8y9cuBBffvklFi9ejI8++ghbt27FE088gUWLFim1a/tw6N69ezF58mTj787glNmzZ2PFihV47bXXAAAZGRlBy7399tuYNGkSAODIkSM4deqU8donn3yCGTNm4IsvvkBSUhImTpyId999F0lJSXavPhGRswU0+bqnSh2Kpk2bhpMnT2L58uXwer3IyMhARUWFESxTV1cXdE07LS0Nb7zxBh588EGMHTsWgwYNwuLFi7F06VKldm3vBCdNmgTdYkDY6rVOtbW1QX9v3Lixu6tFRERdEMncofn5+cjPzzd9bceOHeeVZWdn49133w2prU7MHWrFzvyBinlAAcgRpYq5GeWoPKv8j2GIQFXMBSqVhxLRKbYhRgua1xPKwa5p5h+gX7g64ZLmt2haakN12waEN25r5KbifmD1mnJOUTHS02Lj2hXtyfykFwV2gkREZOBTJIiIyLGc1glGPG0aERFRpPBMkIiIvqb7Z4JA9JwJshMkIqJzQnwU0nl1RAl2gqEKIReo6fw2RtmpPz1b/cnryuVCHsmOZezJEaqak9KqLjFSUdy2YhMiKeenS6hM+piscocGhCfL+/3mbUg5Rd1SxLHl52rPvmO1f9p1DIQS1Syuk7CttO7ec0c9ip0gEREZ9IB8S5BKHdGCnSARERkYHUpEROQQPBMkIiKD084E2QkSEZGBnaATBaQQP/s+SNWo0Y6FejYKNLTIVPP5Q8lPqv4EciHiUTHKFLDKN6oaNSo2IXIpBuu6hbdh9f6k3KGaEDUaEN6I+BlJKwWrHK/m84cUoWlXJLSNUaC2CmXHopCwEyQiIkPHk+W7eyZo08qEATtBIiI6x2E3yzM6lIiIHItngkREZGBgDBEROZbTOkEOhxIRkWMpnQlOmjQJGRkZWLNmTQ+tThSRop/EUGzF+aEeUi4nsQ4hCbIUGi+UqyaeBqxC6dVueVB93x112XWLRAgh9uLtC8pVicQE2kIbMUKyR3HbhnB7jeo+Zb1/qh0Dqrf8WAZ22HXsX6SYO5SIiByLw6GCOXPmYOfOnXjmmWegaRo0TUNtbW0PrhoREVHP6vKZ4DPPPIOPPvoIo0ePxsqVKwEASUlJPbZiREQUfk47E+xyJ5iYmIi4uDhccsklSElJ6cl1IiKiCHFaJ8joUCIiciwGxlixMem1GJ1m0Yb0mriMcoJisWnlpMZyMmyL6FApClSM/FObX4r0BIB2aX3FyFv7ftkGhMTsLpf5B6KLH5QcgqcJ20QTQlClbeh2q28P1QhiO/dP8RhQPJYsj0u7knFb7lORCyl12pmgUicYFxcHv9/fU+tCREQR5rROUGk4dMiQIdizZw9qa2tx6tQpBALn/xLdtWsXpk6ditTUVGiahi1btgS9rus6li9fjoEDB6JPnz7IycnBoUOHLth2SUkJhgwZgvj4eGRlZeG9995TWXUiIqLzKHWCDz30ENxuN0aNGoWkpCTU1dWdN09zczPS09NRUlJiWseTTz6JZ599FqWlpdizZw8uvfRS5Obm4uzZs2K7mzZtQkFBAYqKirBv3z6kp6cjNzcXn3/+ucrqExHRBXSeCXZ3ihZKw6HXXnstqqqqLOeZMmUKpkyZYvqarutYs2YNHn30Udx1110AgBdffBHJycnYsmULpk+fbrrc6tWrMX/+fMydOxcAUFpaiq1bt6KsrAzLli1TeQtERGRF10KLh/hmHVEirNGhR48ehdfrRU5OjlGWmJiIrKwssXNtbW1FdXV10DIulws5OTmWHXJLSwsaGxuDJiIioq8La3So1+sFACQnJweVJycnG69906lTp+D3+02XOXDggNhWcXExHn/88W6usQXVX0qKOSk7XlSLQhMj3RRzjVq9plxulf9RsS4pd6hcrp47VMxbGpZ4MGGdhLchRXp2vGa+TTQhoDSSn6tqORBCtLXqsWR1XNp07F+sGBjTSxQWFqKhocGYjh8/HulVIiKii0xYzwQ7M834fD4MHDjQKPf5fMjIyDBdpn///nC73fD5fEHlPp/PMnONx+OBx+Pp/koTETkIzwR70NChQ5GSkoLKykqjrLGxEXv27EF2drbpMnFxccjMzAxaJhAIoLKyUlyGiIhCo+v2TNHC9jPBpqYmHD582Pj76NGjqKmpweWXX44rr7wSS5Yswa9+9Stcc801GDp0KB577DGkpqYiLy/PWObWW2/F3Xffjfz8fABAQUEBZs+ejfHjx+OGG27AmjVr0NzcbESLEhERhcL2TnDv3r2YPHmy8XdBQQEAYPbs2SgvL8fDDz+M5uZmLFiwAPX19Zg4cSIqKioQHx9vLHPkyBGcOnXK+HvatGk4efIkli9fDq/Xi4yMDFRUVJwXLENERN1kx31+UTQcansnOGnSJItchx0RbStXrjQex2TG7DmF+fn5xpnhxcwq56D5AhYvKeaxVI2ys9rRVSNKVZ/6Dqjn/JSfJq72JPpQ2pZ2aSkPqBWX8KEHhDbcwtuQIkABwCWEgUrb0C/kCBXzuEorBfV9QTWnqFVdcrliGyEM5ykf+xcpXhMkIiJyCD5FgoiIDE47E2QnSEREBqd1ghwOJSIix+KZIBERGfSA1u0gn2gKEmInaEV+cLdyLkLVfJ9AKE/iVmvbOneo2jJizk0b85NKkZtSeXu7eu5Qv+JTzkO5KVgXcn6Kayvth+1yGy6pDaHcrjyugHpkseq+1tGIPceGcg5Si7aVy62+XyKo42b37g6H2rQyYcDhUCIiciyeCRIRkcFpgTHsBImIyOC0TpDDoURE5Fg8EyQiIoPTzgTZCRIRkYGdoANpYQjnVb19weo1KXxbDE33C6HpQrnVa1LbcnJr9UTLqsmt24V1leYH5FshAkLYekD6LMQWZNL+5pJfEFZKbkPaJprQhrjNhUTZlrfXCHXJ+46N+6fqLQ8hHJfh+IIPx3cSdWAnSEREBp4JEhGRYzmtE2R0KBERORbPBImIyKDrNuQOjaIzQXaCRERkcNpwKDtBC5YRWtIvJSliTzGiE5Cj7OSoUcWIOau2FRMqi4mWpfcAOXGyHO0ptWFevxQB2tGGlEDbfP5wBOtpwvq6pcbdcl0u4Y24XObbsF2YP0YIl7X+XO3Zd6z2TzkKVLjCI0ZOqx8bdh37jAC9OLATJCIiQ8dTJLpfR7RgJ0hERIaAron3xarUES0YHUpERI7FM0EiIjIwMIaIiJzLhk7QMh3kRYadoN0UIzFDiYATo+mkvJdSzkaL/I9S22KEpmLkHwC0t6vl/JTehzy/2LQYBSpHhwpthBAAIKUC1RRjUF1W709YX7ewjJiXVfiMYoScoh11qe0jqvt5x2tSdKg0v32R0yHlAaaLFjtBIiIycDiUiIgcy2mdIKNDiYjIscLeCQ4ZMgSapp03LVq0yHT+8vLy8+aNj48P81oTETmDHtBsmUJRUlKCIUOGID4+HllZWXjvvfe6tNzGjRuhaRry8vKU2wz7cOjf//53+L+W52r//v343ve+h3vvvVdcJiEhAQcPHjT+1rToOdUmIoomkRoO3bRpEwoKClBaWoqsrCysWbMGubm5OHjwIAYMGCAuV1tbi4ceegg333xzSOsa9k4wKSkp6O/f/OY3uOqqq3DLLbeIy2iahpSUlJ5eNTXKT7AW6rH6xaQYuaYaASdF2FnVpRppavl0d8UowrZ2IYJRaKPN8snyUrkQgSrWpE5q2y1EdIo7Twi/trV283Ipp6hfCH+1+lzt2nes9k+7jgExotPyuBSKVZ9eT0FWr16N+fPnY+7cuQCA0tJSbN26FWVlZVi2bJnpMn6/HzNnzsTjjz+Ov/3tb6ivr1duN6LXBFtbW/GXv/wFP/nJTyzP7pqamjB48GCkpaXhrrvuwgcffHDBultaWtDY2Bg0ERGRtc4zwe5OAM77Dm5paTFts7W1FdXV1cjJyTHKXC4XcnJyUFVVJa7rypUrMWDAAMybNy/k9xvRTnDLli2or6/HnDlzxHmGDx+OsrIy/PWvf8Vf/vIXBAIBTJgwAZ988oll3cXFxUhMTDSmtLQ0m9eeiKj3sbMTTEtLC/oeLi4uNm3z1KlT8Pv9SE5ODipPTk6G1+s1Xeadd97Bn//8Z6xbt65b7zeit0j8+c9/xpQpU5CamirOk52djezsbOPvCRMmYOTIkfjDH/6AVatWicsVFhaioKDA+LuxsZEdIRFRGB0/fhwJCQnG3x6Px5Z6T58+jfvuuw/r1q1D//79u1VXxDrBY8eO4c0338Srr76qtFxsbCzGjRuHw4cPW87n8Xhs2+BERE4R0Lv/FIjOS8kJCQlBnaCkf//+cLvd8Pl8QeU+n880HuTIkSOora3F1KlTz7X57xRRMTExOHjwIK666qourWvEhkNfeOEFDBgwAHfeeafScn6/H++//z4GDhzYQ2tGRORcdg6HdlVcXBwyMzNRWVlplAUCAVRWVgaNBHYaMWIE3n//fdTU1BjT97//fUyePBk1NTVKo34RORMMBAJ44YUXMHv2bMTEBK/CrFmzMGjQIGPseOXKlbjxxhtx9dVXo76+Hk899RSOHTuG+++/Pwwral9UlxQ5FtLTsxWfhh3Kk7sDwtPdpXLpSeOh5H8Uc4RKkZtiudg02hWXkdJ0hvLsUMUYUDFvqVXrLuH9uaT3p/hZWH2u4r6guE9Z7p+quUAVjxnL49Ku7wUbv196g4KCAsyePRvjx4/HDTfcgDVr1qC5udmIFv163xAfH4/Ro0cHLd+vXz8AOK/8QiLSCb755puoq6vDT37yk/Neq6urCwrX/te//oX58+fD6/XiW9/6FjIzM7F7926MGjUqnKtMROQIkbpPcNq0aTh58iSWL18Or9eLjIwMVFRUGMEy3+wb7KLpuh7Kj9mo09jYiMTERBz7fDASEoI3ZJ8TbtNlYj81LwcAnLzEtLj95KWm5a1fXGZa3tJgPj8AnG00b+NMg1De3Eet/Cs58464zBnz66xnzsaZlp9tMS/veM38N1hLm/l2P9sm3D8o/Mpvs3jKwsV4Jigd3m5hgRhNbj1WqCzWbb5MfKz5O/TEmt8hGe8RbjgEEO9pNS3vEy+U9zEPm+9z6RmxjT6XnFVaRixP/Mq0PD7BvBwAPInNpuVx324yLY9JMp8fSXIbbYPMt/uZgcHljY0BDB5wDA0NDV269mal8zty/bCncInL/Pjvqq8CZzD741/Ysl49jblDiYjIsfgUCSIiMjjtKRLsBImIyBDQNRtukYieTpDDoURE5Fg8EwTkqAcLuhgpoZhEN4RQbDGsW0pErBiaDlglNVYrl8LlATkJs5QoOyBscymYRQp+sVqmTZhfzH8utiCTtrq4tmL8i/z+XEK8m1tYYWmbxwhROVafq137juX+Kd1WoXrrRChJrxWPZem7wvJcKZQdyyYcDiUiIsdyWifI4VAiInIsngkSEZHBaWeC7ASJiMig2xAdGk2dIIdDiYjIsXgmaCWECDHVhLwBiyg75Ug3xQTalkmQpWhPobxdiNaTyq1ekxJit0nRpGLCbbFpMQrUPFkV4A8pQZpaG27reMHzWb0/YZuIKduEj8nWz1Vxn7JOvm7PsSFHVIeQ2N7GyPBI0vWOqbt1RAt2gkREZNADmsWTS7peR7TgcCgRETkWzwSJiMjA6FAiInIs5g4lIiJyCJ4J2k01F6FV7lDpAa+KuRZVo0Y7XrMn0lTKDwoAfiFMsl2IzJNzhJqXSxGgACA9ErZdCLkMR+7QgNh2CL+qhRUWnqkrb3NhZa0+V/Xcoer7pxxtrda2GMUYSu7QKAoGscLoUCIiciynXRPkcCgRETkWzwSJiMjgtMAYdoJERGRw2jVBDocSEZFj8UwwVIpRoKpPiQcsnqwtRnvalztUaru9Xe0J8lJeSABokyJKFZ8UL+XilCJAAaBNCJ+U6pIjN9WJ+TuFV3SxFYsnywvl0jZ0Cz/dpc8oNiC/c2lfkPYd1YhOq9ekY0A8ZhQjqoEQnkYfZVGjTguMYSdIREQGp10T5HAoERE5Fs8EiYjIoOuAHkoWiG/UES3YCRIRkUHXbXiUEodDiYiILn5hPxNcsWIFHn/88aCy4cOH48CBA+Iyr7zyCh577DHU1tbimmuuwW9/+1vccccdtq2TJp36Ww0JKP7SUX4atUUb0lOv5RyhapFxHcvY9GT5dqsck+blUgRjm2KOUCkCFJCjQKVlpAhNO3OHyltKekV+f5qwjCbmFDWfX4oCtfpc22PseYK89f6peAxIT4oP4bhUPssRo0blRcTvpDAI6BoC3TwTZGDMBXznO9/BiRMnjOmdd94R5929ezdmzJiBefPm4R//+Afy8vKQl5eH/fv3h3GNiYgcQj93w3yoU0j3DkVIRDrBmJgYpKSkGFP//v3FeZ955hncfvvt+MUvfoGRI0di1apVuO666/Dcc8+FcY2JiKg3ikgneOjQIaSmpmLYsGGYOXMm6urqxHmrqqqQk5MTVJabm4uqqirLNlpaWtDY2Bg0ERGRtYB+7l7B0KdIv4uuC3snmJWVhfLyclRUVGDt2rU4evQobr75Zpw+fdp0fq/Xi+Tk5KCy5ORkeL1ey3aKi4uRmJhoTGlpaba9ByKi3qq7Q6F25B4Np7B3glOmTMG9996LsWPHIjc3F9u2bUN9fT1efvllW9spLCxEQ0ODMR0/ftzW+omIKPpF/D7Bfv364dprr8Xhw4dNX09JSYHP5wsq8/l8SElJsazX4/HA4/HYtp5ERE7gtPsEI94JNjU14ciRI7jvvvtMX8/OzkZlZSWWLFlilG3fvh3Z2dlhWkM1cnJdaQH1UGzVhNihJChu97uFZaRE2WrlANAmvCbdCiElxJZua7D1Fgnh3gJbb5FQHEKyals1Sbe0zWOEzyjW4nOVPnNp35H3NfUE76rHhvhlbXnrklAcZYmyJQE9tP36m3VEi7APhz700EPYuXMnamtrsXv3btx9991wu92YMWMGAGDWrFkoLCw05l+8eDEqKirw9NNP48CBA1ixYgX27t2L/Pz8cK86ERH1MmE/E/zkk08wY8YMfPHFF0hKSsLEiRPx7rvvIikpCQBQV1cHl+tc3zxhwgRs2LABjz76KB555BFcc8012LJlC0aPHh3uVSci6vV0vfu3+UVTYEzYO8GNGzdavr5jx47zyu69917ce++9PbRGRETUiRljiIiIHCLigTFERHTx4HAonWN1Sq8YBSpFjomJfaEe0aYaMSclNAYAv7Be7e2KCbQtIuakRNl+YRtKEZ1S1Gi7xaEsLiNkLg4lx7pE2uouKTxUF/YDyzbMt610wEvbXPqMLD9XMZm6sB8K+5rV/qkaCa0caW1xXCpHgIvzX5xDhk7rBDkcSkREjsUzQSIiMjgtMIadIBERGex4ElIUjYZyOJSIiJyLZ4JERGRwWto0doKhkgL5FHMRWiWa1YUINV01d6gUHSpE61kuI5VL0aRWOSZ7OEdom0UyTmkZv1AuBwSqH+2acL3FJVQVECJWrZIca0J4nnmWTsAt1CV+Rlafq2K0p537pxhRLZVLx5hl7lDFZaKoQwA69qtuJ9Du5vLhxOFQIiJyLJ4JEhGRQbdhODSa7hNkJ0hERAZGhxIRETkEzwSJiMjA6FAnsjMxpBA1J0WOSVFrVsvY9VRt6yfLSzlCpaeDq0UXAvLTzJWjQEN4srwUOSpFh/q7/bVwYW5N+JykBSy2rSa8KEeHmosR9sE4i7ZV9xFpX7N8srzqvq6YUzSU41IMIZZY7VI9v7uJOBxKRETkEDwTJCIiA4dDiYjIsTgcSkRE5BA8EyQiIkMANgyH2rEiYcJOMETy06WFaDMpOs0iokyOgJOekq0WTRrKk7ulZaRckm1WTyAXytuE8lZhkEV6Grx17lDzZcToUKGuQAgDP9JT3wO69FR7YRsK0aQdbZjX1So8pV7KHRor1G/5uQqvKecOtdo/xYhSxWNDzDVqldNXWC8x0tS8/GLNrsnhUCIiIofgmSARERk4HEpERI6lo/sJsDkcSkREFAV4JkhERAYOhxIRkWM5LTqUnaAVq4S4Uji0ankoCbQVkwS3+83TI7e3WyTQFl5rb5cSZQvJka0SaEu3PEgJsYXbFFqEcuk2iI66hNsqhGWkt2HnLRLCpkWsdNXCqmnh9gkpsXascBFIul2lXUoiDYt9QXiD8r5mleDdfJ9WTS5v53EpJta2K+G2A5SUlOCpp56C1+tFeno6/vM//xM33HCD6bzr1q3Diy++iP379wMAMjMz8cQTT4jzS8J+TbC4uBjXX389+vbtiwEDBiAvLw8HDx60XKa8vByapgVN8fHxYVpjIiLn0HFuSDTUKZQzwU2bNqGgoABFRUXYt28f0tPTkZubi88//9x0/h07dmDGjBl4++23UVVVhbS0NNx222349NNPldoNeye4c+dOLFq0CO+++y62b9+OtrY23HbbbWhubrZcLiEhASdOnDCmY8eOhWmNiYico7sdYKjXFFevXo358+dj7ty5GDVqFEpLS3HJJZegrKzMdP6XXnoJP/3pT5GRkYERI0bgT3/6EwKBACorK5XaDftwaEVFRdDf5eXlGDBgAKqrq/Hd735XXE7TNKSkpHS5nZaWFrS0tBh/NzY2qq8sERGF7Jvfux6PBx6P57z5WltbUV1djcLCQqPM5XIhJycHVVVVXWrrq6++QltbGy6//HKldYz4LRINDQ0AcMEVb2pqwuDBg5GWloa77roLH3zwgeX8xcXFSExMNKa0tDTb1pmIqLfSbZoAIC0tLeh7uLi42LTNU6dOwe/3Izk5Oag8OTkZXq+3S+u9dOlSpKamIicnR+HdRjgwJhAIYMmSJbjpppswevRocb7hw4ejrKwMY8eORUNDA373u99hwoQJ+OCDD3DFFVeYLlNYWIiCggLj78bGRnaEREQXYOctEsePH0dCQoJRbnYWaIff/OY32LhxI3bs2KEcLxLRTnDRokXYv38/3nnnHcv5srOzkZ2dbfw9YcIEjBw5En/4wx+watUq02Wk024lFhFwyhFfYkSZRROKCYfVE27LAwF+YZl2v5Qo27weKbrQ6rUWKTpULBeSRQsRoFbLtEuJtcOQQNst7SNCPVa7p7SQlChb2uYxwvyWn6uw2aV9R9rXrPZP1X1aTpQtRY2KTdt27Ft/gL1DQkJCUCco6d+/P9xuN3w+X1C5z+e74GWw3/3ud/jNb36DN998E2PHjlVex4gNh+bn5+P111/H22+/LZ7NSWJjYzFu3DgcPny4h9aOiMiZdJv+UREXF4fMzMygoJbOIJevnwB905NPPolVq1ahoqIC48ePD+n9hr0T1HUd+fn52Lx5M9566y0MHTpUuQ6/34/3338fAwcO7IE1JCJyrkhFhxYUFGDdunVYv349PvzwQyxcuBDNzc2YO3cuAGDWrFlBgTO//e1v8dhjj6GsrAxDhgyB1+uF1+tFU1OTUrthHw5dtGgRNmzYgL/+9a/o27evcdEzMTERffr0AdDxZgcNGmRcRF25ciVuvPFGXH311aivr8dTTz2FY8eO4f777w/36hMRUQ+YNm0aTp48ieXLl8Pr9SIjIwMVFRVGsExdXR1crnPnbWvXrkVrayt++MMfBtVTVFSEFStWdLndsHeCa9euBQBMmjQpqPyFF17AnDlzAJz/Zv/1r39h/vz58Hq9+Na3voXMzEzs3r0bo0aNCtdqExE5QiTTpuXn5yM/P9/0tR07dgT9XVtbG2IrwcLeCepdeEbHN9/s73//e/z+97/voTUiIqJOTKDtQELgnyVdiAST8wqa1yPlOux4TYiAU8yPKOcUtcrNKCwjRY1K+SLFFoBWKXeoENUpRXuq5gEFgBbNb962sE7+MBzWbiHfp1/6XW2x32pC4KEUmSp9EbQKn6tHqAew2BfEiGP1/dOuY0D1GAPkY1nOQ6oeBRrKdxKFhp0gEREZdOjQu9kLd2XE72LBTpCIiAxOGw6NeNo0IiKiSOGZIBERGZx2JshOkIiIvkY944tZHdGCnaAVq58zUiSYEG2mS9FmFvkDVfMg+qUnyEtP4bbIzShF5rUJkW6twj4v5aQErHKBKuYIFcqlCFDrusyXkQL87Mwd6tKlp9qbf34WAZrid5DUdpsYNWpekdXnGifs09K+I+1r1vun+TaRjgHVY8nquJSOZenYl3ONik1QGLETJCIiA4dDiYjIsUJJgG1WR7RgdCgRETkWzwSJiMjA4VAiInIsXev+8351418XP3aCdhNzh6o/PVv5yfKKUXbt7XLbbe1qOUKlJ41LuTgBOUdoi1huHrkplguRnoCch1SMDhXeh63RoULCT+lXdcAqNZXwJSZ9t0lPnI8VyqX9AJAjTcV9R9rXLPZPu44B1WMMkI9lJzwpvjdiJ0hERIaO4dDuncZxOJSIiKKS064JMjqUiIgci2eCRERkcNp9guwEiYjIwOFQIiIih+CZICD/bLEKeRaWkcKnpVBsXcrMDKtwbykRsZBAWwg1lxIXA0C70EabMMqhmgwbAFqEp1eLty9It04ItzWctUigLd0K0SosI0XL+UMY9pFuR5BunZBuhQhoQmJtQE6gLdyGESMlIRd+J8dZ3J4h7gvCsSHta5b7p7hPSwm0hW0rJcO2OC7FY1k1UXYI3y/hEIBuQ3Qoh0OJiCgKOe1meQ6HEhGRY/FMkIiIDBwOJSIiB3PWk+U5HEpERI7FM0ErFhFi4pVjMTpULToNUI9oU40abQshOrRd+IHXLtQjRXQCcoSmFO3ZKoTMSVGgZ8W1kqNAz0rrZBFpahePLkQ2ClG0oSTQduvmn7lbeN9SAu1WIcoUANqFNsR9R4pEtto/bYoCVZ0fsIgcVU2sbfX9EkFOu0+QnSARERmcdk0wYsOhJSUlGDJkCOLj45GVlYX33nvPcv5XXnkFI0aMQHx8PMaMGYNt27aFaU2JiKi3ikgnuGnTJhQUFKCoqAj79u1Deno6cnNz8fnnn5vOv3v3bsyYMQPz5s3DP/7xD+Tl5SEvLw/79+8P85oTEfVuuk1TtIhIJ7h69WrMnz8fc+fOxahRo1BaWopLLrkEZWVlpvM/88wzuP322/GLX/wCI0eOxKpVq3DdddfhueeeC/OaExH1bgFNt2WKFmG/Jtja2orq6moUFhYaZS6XCzk5OaiqqjJdpqqqCgUFBUFlubm52LJli9hOS0sLWlpajL8bGhoAAKdPn3/J1t9kfoHa/ZUcDOE/Yx50cfas+TPWv2oxL29qaxXbaGprMS1vbo81b8NvXn4mYB5EcNYisOKsEKTRIgVKiKmy5G3YJgSutAvlfiF4wy881z5g8WR56TWpXLeoyy4BKTAGQrkwPwD4hWX8urDNhbradPOvCKkcAFqFtqV9R9oPPRb7zpnAWdPyOL95eUy7eblLOMY0i+MSwrHsF479VuG7wvL7pck8tKS5Mbi88/tMtwqSIkth7wRPnToFv9+P5OTkoPLk5GQcOHDAdBmv12s6v9frFdspLi7G448/fl756KuOh7DW1COk49ahx3NzpFcgUqS+4LTFMtJrJ7q5LlHqiy++QGJioi11OS0wptdGhxYWFgadPdbX12Pw4MGoq6uzbWfp7RobG5GWlobjx48jISEh0qsTNbjd1HGbhaahoQFXXnklLr/8ctvqtOOaXvR0gRHoBPv37w+32w2fzxdU7vP5kJKSYrpMSkqK0vwA4PF44PF4zitPTEzkQaYoISGB2ywE3G7quM1C43Ix70mowr7l4uLikJmZicrKSqMsEAigsrIS2dnZpstkZ2cHzQ8A27dvF+cnIqLQdA6HdneKFhEZDi0oKMDs2bMxfvx43HDDDVizZg2am5sxd+5cAMCsWbMwaNAgFBcXAwAWL16MW265BU8//TTuvPNObNy4EXv37sUf//jHSKw+EVGvxWuCYTBt2jScPHkSy5cvh9frRUZGBioqKozgl7q6uqDT+wkTJmDDhg149NFH8cgjj+Caa67Bli1bMHr06C636fF4UFRUZDpESua4zULD7aaO2yw03G7dp+mMrSUicrzGxkYkJiZivPsZxGh9ulVXu34Ge/2L0dDQcNFf4+210aFERKROt+FRSt1/FFP4MKSIiIgci2eCRERk0G0IjImmM0F2gkREZAhoOrRu5v6MpuhQDocSEZFj9dpO8Ne//jUmTJiASy65BP369evSMrquY/ny5Rg4cCD69OmDnJwcHDp0qGdX9CLz5ZdfYubMmUhISEC/fv0wb948NDU1WS4zadIkaJoWND3wwANhWuPI4PMw1alss/Ly8vP2qfj4+DCubeTt2rULU6dORWpqKjRNs3xgQKcdO3bguuuug8fjwdVXX43y8nLldgM2TdGi13aCra2tuPfee7Fw4cIuL/Pkk0/i2WefRWlpKfbs2YNLL70Uubm5OHvWPAN9bzRz5kx88MEH2L59O15//XXs2rULCxYsuOBy8+fPx4kTJ4zpySefDMPaRgafh6lOdZsBHSnUvr5PHTt2LIxrHHnNzc1IT09HSUlJl+Y/evQo7rzzTkyePBk1NTVYsmQJ7r//frzxxhtK7TotY0yvv0+wvLwcS5YsQX19veV8uq4jNTUVP//5z/HQQw8B6EhOm5ycjPLyckyfPj0MaxtZH374IUaNGoW///3vGD9+PACgoqICd9xxBz755BOkpqaaLjdp0iRkZGRgzZo1YVzbyMnKysL1119vPM8yEAggLS0NP/vZz7Bs2bLz5p82bRqam5vx+uuvG2U33ngjMjIyUFpaGrb1jiTVbdbV49YpNE3D5s2bkZeXJ86zdOlSbN26NejH1fTp01FfX4+KiooLttF5n+DomKfh7uZ9gn79DPa3/zwq7hPstWeCqo4ePQqv14ucnByjLDExEVlZWeJzDnubqqoq9OvXz+gAASAnJwculwt79uyxXPall15C//79MXr0aBQWFuKrr77q6dWNiM7nYX59P+nK8zC/Pj/Q8TxMp+xXoWwzAGhqasLgwYORlpaGu+66Cx988EE4Vjdq2bWf6Tb9Ey0YHfpvnc8mVH1uYW/i9XoxYMCAoLKYmBhcfvnlltvgRz/6EQYPHozU1FT885//xNKlS3Hw4EG8+uqrPb3KYReu52H2JqFss+HDh6OsrAxjx45FQ0MDfve732HChAn44IMPcMUVV4RjtaOOtJ81NjbizJkz6NOna2d3AejQHJQ7NKrOBJctW3bexfJvTtJB5WQ9vd0WLFiA3NxcjBkzBjNnzsSLL76IzZs348iRIza+C3KS7OxszJo1CxkZGbjlllvw6quvIikpCX/4wx8ivWrUy0TVmeDPf/5zzJkzx3KeYcOGhVR357MJfT4fBg4caJT7fD5kZGSEVOfFoqvbLSUl5bxAhfb2dnz55ZeWz278pqysLADA4cOHcdVVVymv78UsXM/D7E1C2WbfFBsbi3HjxuHw4cM9sYq9grSfJSQkdPksEHDemWBUdYJJSUlISkrqkbqHDh2KlJQUVFZWGp1eY2Mj9uzZoxRhejHq6nbLzs5GfX09qqurkZmZCQB46623EAgEjI6tK2pqagAg6MdEb/H152F2Bil0Pg8zPz/fdJnO52EuWbLEKHPS8zBD2Wbf5Pf78f777+OOO+7owTWNbtnZ2efdehPKfua0TjCqhkNV1NXVoaamBnV1dfD7/aipqUFNTU3QPW8jRozA5s2bAXREXy1ZsgS/+tWv8Nprr+H999/HrFmzkJqaahmR1ZuMHDkSt99+O+bPn4/33nsP/+///T/k5+dj+vTpRmTop59+ihEjRhj3eB05cgSrVq1CdXU1amtr8dprr2HWrFn47ne/i7Fjx0by7fSYgoICrFu3DuvXr8eHH36IhQsXnvc8zMLCQmP+xYsXo6KiAk8//TQOHDiAFStWYO/evV3uAHoD1W22cuVK/N///R8+/vhj7Nu3Dz/+8Y9x7Ngx3H///ZF6C2HX1NRkfG8BHcF7nd9pAFBYWIhZs2YZ8z/wwAP4+OOP8fDDD+PAgQN4/vnn8fLLL+PBBx+MxOpHjag6E1SxfPlyrF+/3vh73LhxAIC3334bkyZNAgAcPHgQDQ0NxjwPP/wwmpubsWDBAtTX12PixImoqKhw1E26L730EvLz83HrrbfC5XLhnnvuwbPPPmu83tbWhoMHDxrRn3FxcXjzzTeNByOnpaXhnnvuwaOPPhqpt9DjIvE8zGinus3+9a9/Yf78+fB6vfjWt76FzMxM7N69G6NGjYrUWwi7vXv3YvLkycbfBQUFAIDZs2ejvLwcJ06cMDpEoGM0a+vWrXjwwQfxzDPP4IorrsCf/vQn5ObmKrUbAGw4E4wevf4+QSIiurDO+wSHxf4WLq17P/wD+ll83LaU9wkSERFdzHrtcCgREanrCGpxTmAMO0EiIjI4rRPkcCgRETkWzwSJiMjgtyH3ZzSdCbITJCIiA4dDiYiIHIJngkREZHDamSA7QSIiMvi1AHStezlfAlGUM4bDoURE5FjsBIlscPLkSaSkpOCJJ54wynbv3o24uDhUVlZGcM2I1Pih2zJFC3aCRDZISkpCWVmZ8YSI06dP47777jOSkRNFi4ANHWCo1wRLSkowZMgQxMfHIysry3hajeSVV17BiBEjEB8fjzFjxpz3KKmuYCdIZJM77rgD8+fPx8yZM/HAAw/g0ksvRXFxcaRXiygqbNq0CQUFBSgqKsK+ffuQnp6O3Nzc8x703Wn37t2YMWMG5s2bh3/84x/Iy8tDXl4e9u/fr9QunyJBZKMzZ85g9OjROH78OKqrqzFmzJhIrxJRl3Q+ReIyTxG0bj5FQtfPoqnlcaWnSGRlZeH666/Hc889B6DjwctpaWn42c9+hmXLlp03/7Rp09Dc3IzXX3/dKLvxxhuRkZGB0tLSLq8rzwSJbHTkyBF89tlnCAQCqK2tjfTqECnT0QJdP9u9CS0AOjrWr08tLS2mbba2tqK6uho5OTlGmcvlQk5ODqqqqkyXqaqqCpofAHJzc8X5JbxFgsgmra2t+PGPf4xp06Zh+PDhuP/++/H+++9jwIABkV41oguKi4tDSkoKvN7f2FLfZZddhrS0tKCyoqIirFix4rx5T506Bb/fbzxkuVNycjIOHDhgWr/X6zWd3+v1Kq0nO0Eim/zyl79EQ0MDnn32WVx22WXYtm0bfvKTnwQN1xBdrOLj43H06FG0trbaUp+u69A0LajM4/HYUred2AkS2WDHjh1Ys2YN3n77beMayH/9138hPT0da9euxcKFCyO8hkQXFh8fj/j47l0PDEX//v3hdrvh8/mCyn0+H1JSUkyXSUlJUZpfwmuCRDaYNGkS2traMHHiRKNsyJAhaGhoYAdIdAFxcXHIzMwMuqc2EAigsrIS2dnZpstkZ2efdw/u9u3bxfklPBMkIqKIKygowOzZszF+/HjccMMNWLNmDZqbmzF37lwAwKxZszBo0CDjtqPFixfjlltuwdNPP40777wTGzduxN69e/HHP/5RqV12gkREFHHTpk3DyZMnsXz5cni9XmRkZKCiosIIfqmrq4PLdW7wcsKECdiwYQMeffRRPPLII7jmmmuwZcsWjB49Wqld3idIRESOxWuCRETkWOwEiYjIsdgJEhGRY7ETJCIix2InSEREjsVOkIiIHIudIBERORY7QSIicix2gkRE5FjsBImIyLHYCRIRkWP9fwTknDm9uQKpAAAAAElFTkSuQmCC\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "plt.figure(figsize=(5, 5))\n", + "plt.imshow(\n", + " sol.ys.vals,\n", + " origin=\"lower\",\n", + " extent=(x0, x_final, t0, t_final),\n", + " aspect=(x_final - x0) / (t_final - t0),\n", + " cmap=\"plasma\",\n", + ")\n", + "plt.xlabel(\"x\")\n", + "plt.ylabel(\"t\", rotation=0)\n", + "plt.clim(0, 1)\n", + "plt.colorbar()" + ] + }, + { + "cell_type": "markdown", + "id": "26ba8fec-3ca9-4612-b6f9-83962333d96d", + "metadata": {}, + "source": [ + "That worked!\n", + "\n", + "However, for more complicated PDEs then we may wish to define a custom solver. So as an example, here's how to solve the same PDE using the famous [Crank–Nicolson](https://en.wikipedia.org/wiki/Crank%E2%80%93Nicolson_method) scheme.\n", + "\n", + "(See the page on [abstract solvers](https://docs.kidger.site/diffrax/api/solvers/abstract_solvers/) for more details about how to define a custom solver.)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "059fed69-c042-4fec-bf36-60e365c98de8", + "metadata": {}, + "outputs": [], + "source": [ + "class CrankNicolson(diffrax.AbstractSolver):\n", + " rtol: float\n", + " atol: float\n", + "\n", + " term_structure = diffrax.ODETerm\n", + " interpolation_cls = diffrax.ThirdOrderHermitePolynomialInterpolation\n", + "\n", + " def order(self, terms):\n", + " return 2\n", + "\n", + " def init(self, terms, t0, t1, y0, args):\n", + " f0 = terms.vf(t0, y0, args)\n", + " solver_state = f0\n", + " return solver_state\n", + "\n", + " def step(self, terms, t0, t1, y0, args, solver_state, made_jump):\n", + " del made_jump\n", + " δt = t1 - t0\n", + " f0 = solver_state\n", + "\n", + " def keep_iterating(val):\n", + " _, not_converged = val\n", + " return not_converged\n", + "\n", + " def fixed_point_iteration(val):\n", + " y1, _ = val\n", + " new_y1 = y0 + 0.5 * δt * (f0 + terms.vf(t1, y1, args))\n", + " diff = jnp.abs((new_y1 - y1).vals)\n", + " max_y1 = jnp.maximum(jnp.abs(y1.vals), jnp.abs(new_y1.vals))\n", + " scale = self.atol + self.rtol * max_y1\n", + " not_converged = jnp.any(diff > scale)\n", + " return new_y1, not_converged\n", + "\n", + " euler_y1 = y0 + δt * f0\n", + " y1, _ = lax.while_loop(keep_iterating, fixed_point_iteration, (euler_y1, False))\n", + " f1 = terms.vf(t1, y1, args)\n", + "\n", + " y_error = y1 - euler_y1\n", + " dense_info = dict(y0=y0, y1=y1, f0=f0, f1=f1)\n", + "\n", + " solver_state = f1\n", + " result = diffrax.RESULTS.successful\n", + " return y1, y_error, dense_info, solver_state, result\n", + "\n", + " def func(self, terms, t0, y0, args):\n", + " return terms.vf(t0, y0, args)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "da4511b8-f112-4839-94f5-dfc7728da8ea", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "solver = CrankNicolson(rtol=rtol, atol=atol)\n", + "sol = diffrax.diffeqsolve(\n", + " term,\n", + " solver,\n", + " t0,\n", + " t_final,\n", + " δt,\n", + " y0,\n", + " saveat=saveat,\n", + " stepsize_controller=stepsize_controller,\n", + " max_steps=None,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "6667e3c7-5b45-4740-9caf-3e0aa4b1d7a9", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAcEAAAGiCAYAAACf230cAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy88F64QAAAACXBIWXMAAA9hAAAPYQGoP6dpAABKRklEQVR4nO3dfXQU5b0H8O/sJtmAklBLSAhGXnzhpUCCQWIQKxxTI3qosdYLlBqgvBwp6QG3VIhFgtCaVitFr5G0tDF4Kxf0HqFe4cSLUaBeApTQnIJHEJAQVHYFbRISSTbZnfsHl4kr8xsym8kuy3w/nlHz7Mwzz87O7LPzzG9+o6iqqoKIiMiGHJFuABERUaSwEyQiIttiJ0hERLbFTpCIiGyLnSAREdkWO0EiIrItdoJERGRb7ASJiMi22AkSEZFtsRMkIiLbYidIREQRt2vXLkyePBmpqalQFAVbtmy57DI7duzArbfeCpfLhZtuugnl5eWm18tOkIiIIq65uRnp6ekoKSnp1PwnTpzA/fffj4kTJ6KmpgaLFi3CnDlz8Pbbb5tar8IE2kREdCVRFAWbN29GXl6eOM+SJUuwdetWHDp0SCubOnUq6uvrUVFR0el1xXSloUREdPVoaWmBz+ezpC5VVaEoSlCZy+WCy+WypP6qqirk5OQEleXm5mLRokWm6mEnSEREaGlpwcBB18Lr8VtS37XXXoumpqagsqKiIqxYscKS+j0eD5KTk4PKkpOT0djYiPPnz6NHjx6dqoedIBERwefzwevx44OjA9EroWvhIucaA/jOzbU4deoUEhIStHKrzgKtxE6QiIg0vRIcSOhiJ3hRQkJCUCdopZSUFHi93qAyr9eLhISETp8FAuwEiYjoa5QAoASUy894mTq6W3Z2NrZt2xZUtn37dmRnZ5uqh7dIEBFRB1WxZjKpqakJNTU1qKmpAXDhFoiamhrU1dUBAAoLC5Gfn6/N/+ijj+Ljjz/G448/jsOHD+Oll17Ca6+9hscee8zUetkJEhFRxO3fvx+jR4/G6NGjAQButxujR4/G8uXLAQCnT5/WOkQAGDRoELZu3Yrt27cjPT0dzz33HP70pz8hNzfX1Hp5nyAREaGxsRGJiYn45JObkJDg7GJdflx//TE0NDR02zVBq/CaIBERaS5cE+x6HdGCw6FERGRbPBMkIqIOgf+fulpHlGAnSEREGkW9MHW1jmjB4VAiIrItngkSEZFGUS0IjImiM0F2gkRE1CGgXpi6WkeU4HAoERHZFs8EiYhIY7fAGHaCRETUwWa3SHA4lIiIbItngkREpFECKpQuBrZ0dflwYidIREQdOBxKRERkDzwTJCIiDaNDiYjIvjgcSkREZA88EyQiIo3dHqrLTpCIiDqoANQuXtSLomuCHA4lIiLb4pkgERFp+CglIiKyL0aHEhER2QPPBImISMOb5YmIyL44HEpERGQPlnaCxcXFuO2229CrVy/07dsXeXl5OHLkSNA8LS0tWLBgAb797W/j2muvxUMPPQSv12tYr6qqWL58Ofr164cePXogJycHR48etbLpREQEdJwJdnWKEpZ2gjt37sSCBQuwZ88ebN++HW1tbbjnnnvQ3NyszfPYY4/hv//7v/H6669j586d+Oyzz/CDH/zAsN5nnnkGL7zwAkpLS7F3715cc801yM3NRUtLi5XNJyKyvQvXBJUuTpF+F52nqGpXUwPIzpw5g759+2Lnzp347ne/i4aGBiQlJWHDhg344Q9/CAA4fPgwhg0bhqqqKtx+++2X1KGqKlJTU/Hzn/8cixcvBgA0NDQgOTkZ5eXlmDp1anc1n4jINhobG5GYmIgvdw9GwrXOrtXV5Md14z5GQ0MDEhISLGph9+jWwJiGhgYAwHXXXQcAqK6uRltbG3JycrR5hg4dihtuuEHsBE+cOAGPxxO0TGJiIrKyslBVVSV2gq2trWhtbdX+DgQC+PLLL/Htb38biqJY8v6IiCJJVVWcO3cOqampcDgsGtizWWBMt3WCgUAAixYtwh133IERI0YAADweD+Li4tC7d++geZOTk+HxeHTruVienJzc6WWAC9cnn3rqqS68AyKi6HDq1Clcf/311lTGTtAaCxYswKFDh/D+++931yoMFRYWwu12a383NDTghhtuwKHjaejVK/gXU1yj/maIaZBHip2N+uVKk7BJm+N0i1WhHAAC5/Xr8n/l0p+/NVZ//lahHp/88fulutqEutr0h08C7fKwSru0jF8q1/+lqwb0ywNCudEyqnDwdt9Fgw7SAIUivA3FIX/TOITXpGUcTqncr1seE6tfDgCOGP3XnMIyzth2/XJXm7gOZ5y0jH65Q6jL2bNVt9zRQ78eAFCu8em/IJSr1+rX5TcYJWxP1N8ZfAnBdZ07F8CIG0+hV69ecmVkqFs6wYKCArz11lvYtWtX0K+TlJQU+Hw+1NfXB50Ner1epKSk6NZ1sdzr9aJfv35By2RkZIhtcLlccLku7Sx69XIgIeEbnaCq/6Ub45e/ZJwB/Z1UEeqCUK5K8wMICB+PX9UvDyjC/IrQoQnzA0A7hGUc0jqENjkMOkHFZCfYbq6zs7YT7P4hdEWIJrCyExTLY8LQCUodV6z+to3R/613YRmX/jJOYRlHvP62dfbQb6ujp/yrR+kp7NPXCMf4tfp1+YVyAGjvpf+h+xKEHzdWXuJR0fWnQERRYIyl0aGqqqKgoACbN2/Gu+++i0GDBgW9npmZidjYWFRWVmplR44cQV1dHbKzs3XrHDRoEFJSUoKWaWxsxN69e8VliIgoNEpAsWSKFpZ2ggsWLMBf/vIXbNiwAb169YLH44HH48H58+cBXAhomT17NtxuN9577z1UV1dj1qxZyM7ODgqKGTp0KDZv3gzgwi+cRYsW4Ve/+hXefPNNHDx4EPn5+UhNTUVeXp6VzSciIpuxdDh07dq1AIAJEyYElb/88suYOXMmAOD3v/89HA4HHnroIbS2tiI3NxcvvfRS0PxHjhzRIksB4PHHH0dzczPmzZuH+vp6jB8/HhUVFYiPj7ey+UREZLPhUEs7wc7cchgfH4+SkhKUlJR0uh5FUbBy5UqsXLmyy22kDuG4zhWOdVD4cd+5iqkK0NXhzCj67Jg7lIiIbItPkSAiog68T5CIiGzLZtcEORxKRES2xTNBIiLqELAgMCaK7hNkJ3glMxlhJUXTqVbukFZGfUnttajcqK1mM8NENFIxIGSSMcgSEo5taJqFdUn7tOnPKYqiGMNGVbq+XaJou3I4lIiIbItngkREpFECF6au1hEt2AkSEVEHm10T5HAoERHZFs8EiYiog83uE2QnSEREHWw2HMpOMNJC+cVk0a8so3ByMQTdZGi6UU51OSzf3Pyh3NZg1a0Qodw6IT081/y6jV40eyuEtA7rPler9qmQmD1mouhMhrqGnSAREXWw2X2C7ASJiKiDzRJoMzqUiIhsi2eCRETUgcOhRERkV6qqdDnfcERz7ZrETtBq0s4Twk5lekc0GeFnXJe5qMBQIvxMRxGaTPIsJck2XHcYEmhLdUlRo2I9QmLtC3WZfH/SNneE8LmKn5+0QAjHhskoV7GeUL7sLTzGKfLYCRIRUQcOhxIRkW0xOpSIiMgeeCZIREQdOBxKRES2xdyhdJFhsJ5FH3Io0WlmI/zEiDmDdZuOkjTbphDWEdHcoWH4ZWs6vaXh+5NesGjbGu23Vn1OhusQii3KQ2r0/kzvCUJdimHyVwoXdoJERNSBw6FERGRbNhsOZXQoERHZluWd4K5duzB58mSkpqZCURRs2bIl6HVFUXSnZ599VqxzxYoVl8w/dOhQq5tORESqRVOUsLwTbG5uRnp6OkpKSnRfP336dNBUVlYGRVHw0EMPGdb7ne98J2i5999/3+qmExHZnhpQLJlCUVJSgoEDByI+Ph5ZWVnYt2+f4fxr1qzBkCFD0KNHD6SlpeGxxx5DS0uLqXVafk1w0qRJmDRpkvh6SkpK0N9//etfMXHiRAwePNiw3piYmEuWvSKZjS40uoBsNmrO9BO95d9ApqNAxXUb/M4Sfi0GpGWEdQT8QrnR+xNeCyka0iKKQ9ggQpiycVIO/VedQk5RaVs5hASsxvuO0DIL9x05EtqaYyak4zKKgkGuRJs2bYLb7UZpaSmysrKwZs0a5Obm4siRI+jbt+8l82/YsAFLly5FWVkZxo0bh48++ggzZ86EoihYvXp1p9cb0WuCXq8XW7duxezZsy8779GjR5GamorBgwdj+vTpqKurM5y/tbUVjY2NQRMREV3GxejQrk4mrV69GnPnzsWsWbMwfPhwlJaWomfPnigrK9Odf/fu3bjjjjvwox/9CAMHDsQ999yDadOmXfbs8Zsi2gmuX78evXr1wg9+8APD+bKyslBeXo6KigqsXbsWJ06cwJ133olz586JyxQXFyMxMVGb0tLSrG4+EdHV52J0aFcn4JITkdbWVt1V+nw+VFdXIycnRytzOBzIyclBVVWV7jLjxo1DdXW11ul9/PHH2LZtG+677z5TbzeinWBZWRmmT5+O+Ph4w/kmTZqEhx9+GKNGjUJubi62bduG+vp6vPbaa+IyhYWFaGho0KZTp05Z3XwiIjKQlpYWdDJSXFysO9/Zs2fh9/uRnJwcVJ6cnAyPx6O7zI9+9COsXLkS48ePR2xsLG688UZMmDABTzzxhKk2Ruw+wb/97W84cuQINm3aZHrZ3r1745ZbbsGxY8fEeVwuF1wuV1eaSERkPyosuFn+wn9OnTqFhIQErdjK7+QdO3bg6aefxksvvYSsrCwcO3YMCxcuxKpVq/Dkk092up6IdYJ//vOfkZmZifT0dNPLNjU14fjx43jkkUe6oWVERDamWnCz/P93ogkJCUGdoKRPnz5wOp3wer1B5V6vVwyIfPLJJ/HII49gzpw5AICRI0eiubkZ8+bNwy9/+Us4HJ0b6LS8E2xqago6Qztx4gRqampw3XXX4YYbbgBwYZz49ddfx3PPPadbx913340HH3wQBQUFAIDFixdj8uTJGDBgAD777DMUFRXB6XRi2rRp1jT6Cn32lVURmmafEh/KOgIh5GwM+M1FaJqNAg0l+lVqk1yPqdkBAEKAJlTh/Tmcwg5q8P6kXVrx6zfYIUWNCttDjGSFvC84Irh/mp4/0q7Q76TuEhcXh8zMTFRWViIvLw8AEAgEUFlZqfUD3/TVV19d0tE5nU4AgGriwLS8E9y/fz8mTpyo/e12uwEAM2bMQHl5OQBg48aNUFVV7MSOHz+Os2fPan9/8sknmDZtGr744gskJSVh/Pjx2LNnD5KSkqxuPhGRralqaD/uvlmHWW63GzNmzMCYMWMwduxYrFmzBs3NzZg1axYAID8/H/3799euK06ePBmrV6/G6NGjteHQJ598EpMnT9Y6w86wvBOcMGHCZXvhefPmYd68eeLrtbW1QX9v3LjRiqYREdHlRCiB9pQpU3DmzBksX74cHo8HGRkZqKio0IJl6urqgs78li1bBkVRsGzZMnz66adISkrC5MmT8etf/9rUeplAm4iIrggFBQXi8OeOHTuC/o6JiUFRURGKioq6tE52gkRE1MFmT5FgJ0hERBpVVbocMHTFBhzp4KOUiIjItngmGCophFnKHSzNb/SLyextB6YTFBvdviAsI93WICbpltdhNnGyeCtECLc7iOsWYrrk+c3/4lWEhNhSuXibgjA/ADikuoRtqEj7rUNKoG3wuUr7gvA+pH3NEcK+Y/pWCLOJtSEfy9I2jLrbHTgcSkREthWh6NBI4XAoERHZFs8EiYhIY7fAGHaCRETUIYCuX8eMouugHA4lIiLb4pmgESt/zYQQoSleXDYZiSlHdIYSPWkuotMwAlWKArUoUbZh8m6TkbeRjA6Vyo2S5IsRtlKbhMTairASR0A+OOTPQ4g0NblPAeajlM1GgRrtt4pVQ31X6tmSzQJj2AkSEZFGDSjGP847WUe04HAoERHZFs8EiYioA4dDiYjIrux2iwSHQ4mIyLZ4JhgiMW2jpTkKuzeC0ejXmpy/01zkplRuvA5zUaBSbk2prUbrMGqvfptCiA51mHvstkPI32kcXihEdUo1SRGoAXM5SAFAEZeRcoSaj+617BgI4bg0eywbpHi9MqkW5A6NojNBdoJERNTBZtcEORxKRES2xTNBIiLSqKr8SDEzdUQLdoJERNTBZs8T5HAoERHZFs8ErWb2F5BRgJ8YgSoUCzk3xYvUoUSmSk9rF9pkFEUoPr3e5JPizUasGr1m9DR6PSHlDg3LWJG0YwlPlhciUKXPSMopCgBO1a9bbjrfp9G2laI9xWNAqkcoDyWvZxSd/Rix232C7ASJiKgDo0OJiIjsgWeCRESksdtTJNgJEhFRBxUWDIda0pKw4HAoERHZluWd4K5duzB58mSkpqZCURRs2bIl6PWZM2dCUZSg6d57771svSUlJRg4cCDi4+ORlZWFffv2Wd30S6kG08V7ab45iXUppidVmvwO3UmqJ+A3P6kBh/4ktCngd+hOYj0BBwLCJM+vCJN+Pf52pzhJ7RXr8jt1J2l+o8l0XUJbDd+fuH79bWjVZ6QatFfcn8V6zO+34rEkHDNSm0I5XkXSd4XR90sEiZ+TySlaWN4JNjc3Iz09HSUlJeI89957L06fPq1N//mf/2lY56ZNm+B2u1FUVIQDBw4gPT0dubm5+Pzzz61uPhGRvUmdttkpSlh+TXDSpEmYNGmS4TwulwspKSmdrnP16tWYO3cuZs2aBQAoLS3F1q1bUVZWhqVLl3apvUREZF8RuSa4Y8cO9O3bF0OGDMH8+fPxxRdfiPP6fD5UV1cjJydHK3M4HMjJyUFVVZW4XGtrKxobG4MmIiIydjF3aFenaBH2TvDee+/FK6+8gsrKSvz2t7/Fzp07MWnSJPj9+lkmzp49C7/fj+Tk5KDy5ORkeDwecT3FxcVITEzUprS0NEvfBxHR1chu1wTDfovE1KlTtf8fOXIkRo0ahRtvvBE7duzA3Xffbdl6CgsL4Xa7tb8bGxvZERIRUZCI3yIxePBg9OnTB8eOHdN9vU+fPnA6nfB6vUHlXq/X8Lqiy+VCQkJC0ERERJfBwJjw+uSTT/DFF1+gX79+uq/HxcUhMzMTlZWVyMvLAwAEAgFUVlaioKDAkjYoAQsHsMWEvELCX6OdRVpGSh5sMlG20ZCFNKYvJp4WylWDRMSqySTWYmJtk20Kpa5wXONQFHNfHIoiN0rap6VlpG0uza8G9C9fACHsI1JC+FASaJvd1y08Lq28tcHS7yST7JZA2/IzwaamJtTU1KCmpgYAcOLECdTU1KCurg5NTU34xS9+gT179qC2thaVlZV44IEHcNNNNyE3N1er4+6778aLL76o/e12u7Fu3TqsX78eH374IebPn4/m5mYtWpSIiCgUlp8J7t+/HxMnTtT+vnhdbsaMGVi7di3++c9/Yv369aivr0dqairuuecerFq1Ci6XS1vm+PHjOHv2rPb3lClTcObMGSxfvhwejwcZGRmoqKi4JFiGiIi6yoKnSCB6zgQt7wQnTJgA1WDs6O23375sHbW1tZeUFRQUWDb8SURE+jgcSkREZBMRD4whIqIriBXRnYwOtQGzp/vSCHEIEXBWRcZJ0ZmAHC0ovQ+z0ZbGy5iMAjUZTWpYVyjRghZRHNJOot9Wh0MOvTUb7SmVO0L4XB1SSLAUcSx9fiHsO2YjpE2XA+ajQKNoaBCwJuMLM8YQERFFAZ4JEhGRxm6BMewEiYiow+Wej9jZOqIEh0OJiMi2eCZIREQdAkrXA8EYHXp1UAzyXoos/PDlqE6TeRDN1gOIwxlmI/mMogilusxGgYaSO1SMQA0h0tQqihhSJ+2I5gdyFL/ZnKLC/A6DyE2Tn58zxvz+aVUuUEs/V5PHfkjfL2Fgt2uCHA4lIiLb4pkgERF1sFlgDDtBIiLScDiUiIjIJngmSEREGjVg/DDsztYRLdgJhklI0WkWPVleFfOAms+tKUeaChUZrsPc+zAbBSrVD5iPAjWdq9KA0RPh9Yh5XMWoUUAa5JFzhOrXJW1Dp8H7FlOHms1DarQOiz4/s8cYIB/L0TMAeBk2uybI4VAiIrItngkSEZHGboEx7ASJiEhjt06Qw6FERGRbPBMkIqIONguMYScYKrPRkCaj0wyZfEq9lJLSMALOZJSd2XyfhnWZjvzTrz+kdZuMig2FVJcYuSk8Qd7o/SmK39S6pbpCWXfA2f2Rt2KaVfH4E6vSF0JeXfkYN7nuCFPVy+QV7mQd0YLDoUREZFs8EyQiIo3dAmPYCRIRUQcV5oeP9eqIEhwOJSIi2+KZIBERaTgcSkREtsVOkK4YZncks7dCGIYxS3WZTawdhtswQlm3VXWFJYG20Fbp9gXA/DaU6rL0cxUTZUv1iKsIbZ82UQ/Zh+XXBHft2oXJkycjNTUViqJgy5Yt2mttbW1YsmQJRo4ciWuuuQapqanIz8/HZ599ZljnihUroChK0DR06FCrm05EZHtqQLFkCkVJSQkGDhyI+Ph4ZGVlYd++fYbz19fXY8GCBejXrx9cLhduueUWbNu2zdQ6Le8Em5ubkZ6ejpKSkkte++qrr3DgwAE8+eSTOHDgAN544w0cOXIE3//+9y9b73e+8x2cPn1am95//32rm05ERBczxnR1MmnTpk1wu90oKirCgQMHkJ6ejtzcXHz++ee68/t8Pnzve99DbW0t/uu//gtHjhzBunXr0L9/f1PrtXw4dNKkSZg0aZLua4mJidi+fXtQ2YsvvoixY8eirq4ON9xwg1hvTEwMUlJSOt2O1tZWtLa2an83NjZ2elkiIuq6b37vulwuuFwu3XlXr16NuXPnYtasWQCA0tJSbN26FWVlZVi6dOkl85eVleHLL7/E7t27ERsbCwAYOHCg6TZG/BaJhoYGKIqC3r17G8539OhRpKamYvDgwZg+fTrq6uoM5y8uLkZiYqI2paWlWdhqIqKr08XAmK5OAJCWlhb0PVxcXKy7Tp/Ph+rqauTk5GhlDocDOTk5qKqq0l3mzTffRHZ2NhYsWIDk5GSMGDECTz/9NPx+/ZSBkogGxrS0tGDJkiWYNm0aEhISxPmysrJQXl6OIUOG4PTp03jqqadw55134tChQ+jVq5fuMoWFhXC73drfjY2N7AiJiC7DyujQU6dOBX23S2eBZ8+ehd/vR3JyclB5cnIyDh8+rLvMxx9/jHfffRfTp0/Htm3bcOzYMfz0pz9FW1sbioqKOt3WiHWCbW1t+Ld/+zeoqoq1a9cazvv14dVRo0YhKysLAwYMwGuvvYbZs2frLmN02t1poSS+lS4Im026CxhEaErRd2ajSQ0i/CyqK7QIP2sSZVsZmSpHPIbwZWHR+Ith9KuwTcwm1rbycxXnl/bnEL6ITR8bZhNxG71mdl+IssTaoUhISDA8wemKQCCAvn374o9//COcTicyMzPx6aef4tlnn73yO8GLHeDJkyfx7rvvmt5IvXv3xi233IJjx451UwuJiOxJVbv+FAizy/fp0wdOpxNerzeo3Ov1irEg/fr1Q2xsLJxOp1Y2bNgweDwe+Hw+xMXFdWrdYb8meLEDPHr0KN555x18+9vfNl1HU1MTjh8/jn79+nVDC4mI7MvKa4KdFRcXh8zMTFRWVmplgUAAlZWVyM7O1l3mjjvuwLFjxxAIdJxSf/TRR+jXr1+nO0CgGzrBpqYm1NTUoKamBgBw4sQJ1NTUoK6uDm1tbfjhD3+I/fv349VXX4Xf74fH49F67ovuvvtuvPjii9rfixcvxs6dO1FbW4vdu3fjwQcfhNPpxLRp06xuPhERRYDb7ca6deuwfv16fPjhh5g/fz6am5u1aNH8/HwUFhZq88+fPx9ffvklFi5ciI8++ghbt27F008/jQULFphar+XDofv378fEiRO1vy8Gp8yYMQMrVqzAm2++CQDIyMgIWu69997DhAkTAADHjx/H2bNntdc++eQTTJs2DV988QWSkpIwfvx47NmzB0lJSVY3n4jI3gJKaA/7/mYdJk2ZMgVnzpzB8uXL4fF4kJGRgYqKCi1Ypq6uDg5Hx3lbWloa3n77bTz22GMYNWoU+vfvj4ULF2LJkiWm1mt5JzhhwgSoBgPCRq9dVFtbG/T3xo0bu9osIiLqhEjmDi0oKEBBQYHuazt27LikLDs7G3v27AlpXRcxd2ioIplz0Oy6LYxsFHfuMOQONcvS6FALIxjFqECTFyeszHtpac5Uiz5Xw/3TomMgLJif9IrGTpCIiDR8igQREdmW3TrBiKdNIyIiihSeCRIR0dd0/UwQiJ4zQXaCRETUIcRHIV1SR5RgJxhhoT58Ur+yMEQLWpjv0Lr8pBa+byujQKV1iLkn9YsVp/kcVlZtKyv3T2nfsfT6kVWRxVYel3RFYydIREQaNdD1H7tW/ljubuwEiYhIw+hQIiIim+CZIBERaex2JshOkIiINOwE6Yph1wi1SOYOtXIdEkUxF+0ZlqjYKPrSspJdjzHqwE6QiIg0F54s39UzQYsaEwbsBImIqIPNbpZndCgREdkWzwSJiEjDwBgiIrItu3WCHA4lIiLbMnUmOGHCBGRkZGDNmjXd1BzqDuKvsnAk3L5KkltH8nYEBRYm0JZuCXBaU//lXjO3kjCsgy7B3KFERGRbHA4VzJw5Ezt37sTzzz8PRVGgKApqa2u7sWlERETdq9Nngs8//zw++ugjjBgxAitXrgQAJCUldVvDiIgo/Ox2JtjpTjAxMRFxcXHo2bMnUlJSurNNREQUIXbrBBkdSkREtsXAGAAhBN8Z1GUyKs9o3SYjFc1GBBr9WgsE9H8fSeXWrsOa9y3VDwABv7n3FwjDL1uHycTaRqQk3fK2MrcNDaNDTdZl5TpMR/eK5eKqDY5lC/eRCObetNuZoKlOMC4uDn6/v7vaQkREEWa3TtDUcOjAgQOxd+9e1NbW4uzZswgELr0ZZNeuXZg8eTJSU1OhKAq2bNkS9Lqqqli+fDn69euHHj16ICcnB0ePHr3suktKSjBw4EDEx8cjKysL+/btM9N0IiKiS5jqBBcvXgyn04nhw4cjKSkJdXV1l8zT3NyM9PR0lJSU6NbxzDPP4IUXXkBpaSn27t2La665Brm5uWhpaRHXu2nTJrjdbhQVFeHAgQNIT09Hbm4uPv/8czPNJyKiy7h4JtjVKVqYGg695ZZbUFVVZTjPpEmTMGnSJN3XVFXFmjVrsGzZMjzwwAMAgFdeeQXJycnYsmULpk6dqrvc6tWrMXfuXMyaNQsAUFpaiq1bt6KsrAxLly418xaIiMiIqgBdfdhwFHWCYY0OPXHiBDweD3JycrSyxMREZGVliZ2rz+dDdXV10DIOhwM5OTmGHXJraysaGxuDJiIioq8La3Sox+MBACQnJweVJycna69909mzZ+H3+3WXOXz4sLiu4uJiPPXUU11rsFGElpAbT8yZF0JOSjHSTYhsVKVoyBCiJ81HYoYQoSlFCwrvz+/XT3AplUv1GK3bL0YqilVZRlUi9+tZiiYNOIRyw8+1+6N75Shs4fOTjplQcsWK69afXRG/E+RVRBIDY64ShYWFaGho0KZTp05FuklERHSFCeuZ4MVMM16vF/369dPKvV4vMjIydJfp06cPnE4nvF5vULnX6zXMXONyueByubreaCIiG+GZYDcaNGgQUlJSUFlZqZU1NjZi7969yM7O1l0mLi4OmZmZQcsEAgFUVlaKyxARUWhU1ZopWlh+JtjU1IRjx45pf584cQI1NTW47rrrcMMNN2DRokX41a9+hZtvvhmDBg3Ck08+idTUVOTl5WnL3H333XjwwQdRUFAAAHC73ZgxYwbGjBmDsWPHYs2aNWhubtaiRYmIiEJheSe4f/9+TJw4Ufvb7XYDAGbMmIHy8nI8/vjjaG5uxrx581BfX4/x48ejoqIC8fHx2jLHjx/H2bNntb+nTJmCM2fOYPny5fB4PMjIyEBFRcUlwTJERNRFVtznF0XDoZZ3ghMmTIBqcC6sKApWrlypPY5Jj95zCgsKCrQzwyuC2fyBUoSY0f04ZqPp/OYiN1VhfgAItEuRm+YiNP3t8iPL/SajQM1Gk7YL9VyoS9qG5vNYWkXM9ym8DYdBm6QDW1qHtM0VITrUYfC5Ohz6O7vTr18eaNdPxajGGOWdNXcMmM4danRcmowA7/I9d2HGa4JEREQ2wadIEBGRxm5nguwEiYhIY7dOkMOhRERkWzwTJCIijRpQ5AcHm6gjWrATNKCE8kGajRo1evK6mCPUXN7EgBDJF0peT78UNSrNbxChKUWOSsu0t+nvrlIUqNRWwPwT5EPKMSmQIjSlcofQVikK06hdZtsrtkmIGgUAZ0A/2lP6PJwx5vYpQN6nxdyhJqOBDUP8LYoCDen7JQwu3Oze1eFQixoTBhwOJSIi2+KZIBERaewWGMNOkIiINHbrBDkcSkREtsUzQSIi0tjtTJCdIBERadgJ2pAiR5rLTN7yoEq3O0gh2pATXEvLSOHecuJpgwTTJhNlS7cvGN2mYNWtEO1tQlsNQuzlJN3680f0FgnhbTid+rcihNouPWKb2uWDpt2h//lJt3RI+5rh/mkymbp8/Jk7xgxfC+E2KElI30kUEnaCRESk4ZkgERHZlt06QUaHEhGRbfFMkIiINKpqQe7QKDoTZCdIREQauw2HshM0YhShJb0mJbeWogsNfnGJCbHFaDopgba5ZNiAQcSlkLi4XSyXdzEpCrRNqsts1GgICbT9Zj+/MESHOoVk1X4hshEAYmL0d1DxfcTq12O2rQCgCO11Os1F90qJtQF5nxaPAWEfMZtwGzD4zKVlxO8KcRUURuwEiYhIc+EpEl2vI1qwEyQiIk1AVcRHipmpI1owOpSIiGyLZ4JERKRhYAwREdmXBZ1gKKniIoWdoBGji7smcxHK8xtFwEkRmlIUqLkITb9B5KaY11NYRozcFMoBwGdymTZh3W1CdGG7wbaV8ooGhM8vHNc4HFKeTqcUNSqHFwaEJKhqbPfmFDV6zenUb5MzRj8HqtH+2d6uv0xMe7tQl7ljxui4NHvsK2KuUXkVFD7sBImISMPhUCIisi27dYKMDiUiItsKeyc4cOBAKIpyybRgwQLd+cvLyy+ZNz4+PsytJiKyBzWgWDKFoqSkBAMHDkR8fDyysrKwb9++Ti23ceNGKIqCvLw80+sM+3Do3//+d/j9HRe1Dx06hO9973t4+OGHxWUSEhJw5MgR7W9FiZ5TbSKiaBKp4dBNmzbB7XajtLQUWVlZWLNmDXJzc3HkyBH07dtXXK62thaLFy/GnXfeGVJbw94JJiUlBf39m9/8BjfeeCPuuusucRlFUZCSktJ9jQoIYVoGTyYXI7vEJ8sLUYdCdBog5y+U8h2KT0sXc4fK65ajPc1FjfrahKSUButo9ZmMDhWjYo1yh+pvWyl3qBQdGkp6KOk3nBQd6vRLT5w3iH4VcodK70MIJhUZ/Q4Vc4e26Ud0Op36n2uMEDUKGOzTFuUINTwupQhw8cnyQkVGZ0tmP5CrwOrVqzF37lzMmjULAFBaWoqtW7eirKwMS5cu1V3G7/dj+vTpeOqpp/C3v/0N9fX1ptcb0WuCPp8Pf/nLX/CTn/zE8OyuqakJAwYMQFpaGh544AF88MEHl627tbUVjY2NQRMRERm7eCbY1QnAJd/Bra2tuuv0+Xyorq5GTk6OVuZwOJCTk4OqqiqxrStXrkTfvn0xe/bskN9vRDvBLVu2oL6+HjNnzhTnGTJkCMrKyvDXv/4Vf/nLXxAIBDBu3Dh88sknhnUXFxcjMTFRm9LS0ixuPRHR1cfKTjAtLS3oe7i4uFh3nWfPnoXf70dycnJQeXJyMjwej+4y77//Pv785z9j3bp1XXq/Eb1F4s9//jMmTZqE1NRUcZ7s7GxkZ2drf48bNw7Dhg3DH/7wB6xatUpcrrCwEG63W/u7sbGRHSERURidOnUKCQkJ2t8ul8uSes+dO4dHHnkE69atQ58+fbpUV8Q6wZMnT+Kdd97BG2+8YWq52NhYjB49GseOHTOcz+VyWbbBiYjsIqB2PUPSxTCLhISEoE5Q0qdPHzidTni93qByr9erGw9y/Phx1NbWYvLkyR3r/P/rqDExMThy5AhuvPHGTrU1YsOhL7/8Mvr27Yv777/f1HJ+vx8HDx5Ev379uqllRET2ZeVwaGfFxcUhMzMTlZWVWlkgEEBlZWXQSOBFQ4cOxcGDB1FTU6NN3//+9zFx4kTU1NSYGvWLyJlgIBDAyy+/jBkzZiAmJrgJ+fn56N+/vzZ2vHLlStx+++246aabUF9fj2effRYnT57EnDlzItH0DlLkpskcoUb300gRanK5uYg56WnwF14TokNN5w6V1yFFgUo5RaUcoVJ0aJtRXlYpOlQISIxodKiwCR1CFCYgvz9VCDo0+z6Mcoc6Hfqfn9MpRIe26zeqXcgDCgAxJiOhzR5Lhk+WN5sHOMR75uzG7XZjxowZGDNmDMaOHYs1a9agublZixb9et8QHx+PESNGBC3fu3dvALik/HIi0gm+8847qKurw09+8pNLXqurqwsK/f7Xv/6FuXPnwuPx4Fvf+hYyMzOxe/duDB8+PJxNJiKyhUjdJzhlyhScOXMGy5cvh8fjQUZGBioqKrRgmW/2DVZRVDWU37LRp7GxEYmJiTj5+QAkJARvyHiP/q/BWK+8wZXP9bPW+L+4Rre8/V89dct9DfrlANDaqP9aa1MP3fKWJv02tXxlrhwAWs7rv9baEieU619/bW2V7xNsEV7jmWCwUM4EY4UnNsQK997FxuqXx8Xqn43Fu9rEdbuE11zx+uHxrnif/jp6tIjriO+p/5pYfq1+ueva8/rlCV+J645L1H8t5lv65c5vN+uWq33l99eWrP/5taQEf06NjQEM6HsSDQ0Nnbr2ZuTid+T6wc+ip0P/O6azvgqcx4yPf2FJu7obc4cSEZFt8SkSRESksdtTJNgJEhGRJqAqFtwiET2dIIdDiYjItngmCECM9jbKYSv90hGS64oJfI2CN4SAD7+YEFv/4/SLty/IH790a0ObTwhmEW93kANjpAAYsS7hfYuBMe3yr9F2kwm0xV0khMAYh9AsqbVSAu0Yo1skYqRE4Przmx2+MgrSk26fcDj0Dyjp1on2GHn/lPZp8RgweeuE0XEpHcvSsS9+Vxh8vxjcgdLtOBxKRES2ZbdOkMOhRERkWzwTJCIijd3OBNkJEhGRRrUgOjSaOkEOhxIRkW3xTNCAYYSWFD0mpdcS5pei0wCD1F4mI+CkRNlSMmyj19pMlktJsgE5CrRViExt9em/P5+wbaUIUAAQAi7hFz4/aX4rOYXmOoV1G72/diEMVEoLZ/aXu5T6DTCIAhUSZceIydqFxkLep81GTkvHmNFxKUaOihHj+vNHMgLUiKqGlg7wm3VEC3aCRESkUQMKVPGGnc7XES04HEpERLbFM0EiItIwOpSIiGyLuUOJiIhsgmeCgJzDz+DiriosI0WOqSZzigJyhFpAioDzC1GgYu5QOQJOyhHaJuX1FMvldZiNAm0V3rdP+JzajKInpehQYX5pFwklCE5qlZQK1CksIaQHBQDEijlCrfndK+UHBQCnQ39fcApRozFO/X0tJkb/gb6AvE9Lx4B0zIi5Qw2OS/FYlo59YedRjIJHjPIWdzNGhxIRkW3Z7Zogh0OJiMi2eCZIREQauwXGsBMkIiKN3a4JcjiUiIhsi2eCRoxO6U3mCJWizaSoNQAICHVJEXDi07alqFGD3KHy09qlJ8ibywMKmI8CbRGi8nzCZ+Ez+DUqxR22C/Ge4pPl5VWIpE9c2ttipOhQg/fnl3JiisuY+z2sKPLn6hSSncbESLlD9T+NWIP906pjQDrGjI5L6Vg2nVP0Ch0ytFtgDDtBIiLS2O2aIIdDiYjItngmSEREGlWVb/A3U0e0YCdIREQaVbXgUUocDiUiIrryhf1McMWKFXjqqaeCyoYMGYLDhw+Ly7z++ut48sknUVtbi5tvvhm//e1vcd9991nWJkXMHWqwkJT3z+yT5S3NHWryyfJCJJ3Ra21CtKdPKJeiTAHzUaCtwjZvEepvM8jsKUWH+oVl5Nyh5sd9FOFXtrQnSG01Onj90n5o+he6fquc8scKX5v+NomN0V8oNkbKbWuwf0pPljcZ7SkeY4a5Q01GgUrfFQbfL+J3UhgEVAWBLp4JMjDmMr7zne/g9OnT2vT++++L8+7evRvTpk3D7Nmz8Y9//AN5eXnIy8vDoUOHwthiIiKbUDtumA91CimzfIREpBOMiYlBSkqKNvXp00ec9/nnn8e9996LX/ziFxg2bBhWrVqFW2+9FS+++GIYW0xERFejiHSCR48eRWpqKgYPHozp06ejrq5OnLeqqgo5OTlBZbm5uaiqqjJcR2trKxobG4MmIiIyFlA77hUMfYr0u+i8sHeCWVlZKC8vR0VFBdauXYsTJ07gzjvvxLlz53Tn93g8SE5ODipLTk6Gx+MxXE9xcTESExO1KS0tzbL3QER0terqUKgVuUfDKeyd4KRJk/Dwww9j1KhRyM3NxbZt21BfX4/XXnvN0vUUFhaioaFBm06dOmVp/UREFP0ifp9g7969ccstt+DYsWO6r6ekpMDr9QaVeb1epKSkGNbrcrngcrksaycRkR3Y7T7BiHeCTU1NOH78OB555BHd17Ozs1FZWYlFixZpZdu3b0d2dnb3N87og5QSZQvh/aoQJi2GW0MO0xaTB0vl0i0SwvyAnCi71WSi7JY2+f35hG1i9laIFiEUTUqGDci3T/gVoTwM4W5O4YvHKeyHsQZfVH7xFXNh/A5hdofB5+pw6O8LMcKtELGx+q012j+lfdrssSEdY0bHpXwsCxvL7C0VERZQQ0sM/806okXYh0MXL16MnTt3ora2Frt378aDDz4Ip9OJadOmAQDy8/NRWFiozb9w4UJUVFTgueeew+HDh7FixQrs378fBQUF4W46ERFdZcJ+JvjJJ59g2rRp+OKLL5CUlITx48djz549SEpKAgDU1dXB4ejom8eNG4cNGzZg2bJleOKJJ3DzzTdjy5YtGDFiRLibTkR01VPVrt/mF02BMWHvBDdu3Gj4+o4dOy4pe/jhh/Hwww93U4uIiOgiZowhIiKyiYgHxhAR0ZWDw6HUOcKHLIUGywm05WEDKemvKkWNmoyYaxciOgE5UXa70Ka2Nv330SZFzAFosSgK1CeVG2QhFqNDTUaNWkmKAnUq+uV+gyGrONXsII9+XYrwGTnl8FNxX2iPlfYd8/un2Uho6ZgRE2sbHJfSsSzeFhBFHQJgv06Qw6FERGRbPBMkIiKN3QJj2AkSEZHGiichRdFoKIdDiYjIvngmSEREGrulTWMnCMifuMGeIOUPlPIEijkKhXKjZUznTQwhd2i78D7apOhQoa1SflAA8AkHihjtKZS3CFGgPoMPsE3MEaq/jLiLhBA16hCul0h5Op3CgI3f4LqLuNmFqFFpLxTzlhp8rrHCviDtO9K+ZmnuUGF+K49LMZ+wsK0Mgpe73gt1gQoLEmh3cflw4nAoERHZFs8EiYhIo1owHBpN9wmyEyQiIg2jQ4mIiGyCZ4JERKRhdCh1MIiAk54KbTZ3aChPlpfyGvqF6DuzT9sG5ByhYtSokCPUZ3A0mY0ClXKBSlGgrUa5Q4XXxNyhVg7wiFGgQp5OYd0Bo4Ecobnik+KFqFFpDzGKDo0T9gUxCtTkfmv0mnQMSMdMSE+WN507VCg3+n6JIA6HEhER2QTPBImISMPhUCIisi0OhxIREdkEzwSJiEgTgAXDoVY0JEzYCYbKbMSXML/hE6zF6FBz5Waj8oyWaWsXniAv7PVt4hrk16SnvrcK5VIeUCkC9MI69F+TysWPO4SBH4cQBdouVBUrDdgYBBdK61CEVB5SBGqMUI/h5yrtC8K+Y+X+afbYkI8x85HhZo/9KxWHQ4mIiGyCZ4JERKThcCgREdmWiq4nwOZwKBERURTgmSAREWk4HEpERLZlt+hQdoIAhAj7kEKbxQTawjrEpLsAAkIiYrNh4AEhnNwvzH/hNSGcXSoX3odfXAPQbvKWh3bhlgeztzsAcjLudmEZSxNoC5yKsO9ICxg0SbpFQkqg3SbMHyvsuH6D+zOkfUHad6R9zWj/lPZp88eGuUT4F16Tyk1+XxjML34nXeVKSkrw7LPPwuPxID09Hf/+7/+OsWPH6s67bt06vPLKKzh06BAAIDMzE08//bQ4vyTs1wSLi4tx2223oVevXujbty/y8vJw5MgRw2XKy8uhKErQFB8fH6YWExHZh4qOIdFQp1D68E2bNsHtdqOoqAgHDhxAeno6cnNz8fnnn+vOv2PHDkybNg3vvfceqqqqkJaWhnvuuQeffvqpqfWGvRPcuXMnFixYgD179mD79u1oa2vDPffcg+bmZsPlEhIScPr0aW06efJkmFpMRGQfXe0AQ72muHr1asydOxezZs3C8OHDUVpaip49e6KsrEx3/ldffRU//elPkZGRgaFDh+JPf/oTAoEAKisrTa037MOhFRUVQX+Xl5ejb9++qK6uxne/+11xOUVRkJKS0un1tLa2orW1Vfu7sbHRfGOJiChk3/zedblccLlcl8zn8/lQXV2NwsJCrczhcCAnJwdVVVWdWtdXX32FtrY2XHfddabaGPFbJBoaGgDgsg1vamrCgAEDkJaWhgceeAAffPCB4fzFxcVITEzUprS0NMvaTER0tVItmgAgLS0t6Hu4uLhYd51nz56F3+9HcnJyUHlycjI8Hk+n2r1kyRKkpqYiJyfHxLuNcGBMIBDAokWLcMcdd2DEiBHifEOGDEFZWRlGjRqFhoYG/O53v8O4cePwwQcf4Prrr9ddprCwEG63W/u7sbGRHSER0WVYeYvEqVOnkJCQoJXrnQVa4Te/+Q02btyIHTt2mI4XiWgnuGDBAhw6dAjvv/++4XzZ2dnIzs7W/h43bhyGDRuGP/zhD1i1apXuMtJpt1VUMVmuufnFemAQaSrVJUammk/eLUXmiZF8wvuWkmEDcsSlVC5Hk5qPDjUbBSrNb2UCbXlgRn/dRrGIDiHS1ClG8Zr7LAw/V2kdJqNADZPLm9zXTR8zhontrTn27SAhISGoE5T06dMHTqcTXq83qNzr9V72Mtjvfvc7/OY3v8E777yDUaNGmW5jxIZDCwoK8NZbb+G9994Tz+YksbGxGD16NI4dO9ZNrSMisifVon/MiIuLQ2ZmZlBQy8Ugl6+fAH3TM888g1WrVqGiogJjxowJ6f2GvRNUVRUFBQXYvHkz3n33XQwaNMh0HX6/HwcPHkS/fv26oYVERPYVqehQt9uNdevWYf369fjwww8xf/58NDc3Y9asWQCA/Pz8oMCZ3/72t3jyySdRVlaGgQMHwuPxwOPxoKmpydR6wz4cumDBAmzYsAF//etf0atXL+2iZ2JiInr06AHgwpvt37+/dhF15cqVuP3223HTTTehvr4ezz77LE6ePIk5c+aEu/lERNQNpkyZgjNnzmD58uXweDzIyMhARUWFFixTV1cHh6PjvG3t2rXw+Xz44Q9/GFRPUVERVqxY0en1hr0TXLt2LQBgwoQJQeUvv/wyZs6cCeDSN/uvf/0Lc+fOhcfjwbe+9S1kZmZi9+7dGD58eLiaTURkC5FMm1ZQUICCggLd13bs2BH0d21tbYhrCRb2TlDtxDM6vvlmf//73+P3v/99N7WIiIguYgJt6mDlJynlCTTMUWhNFKgUfRdSfkRhfmlTGf3kaRfKzUYqiuUGCRjNRoG2C3WFFh0qUKWtqL+Ew+j9iTk/zUbk6jN612b3hVBycZrdp82WG+YNDiGnsK5o6imuYuwEiYhIo0KF2sUM3p0Z8btSsBMkIiKN3YZDI542jYiIKFJ4JkhERBq7nQmyEyQioq8xn/FFr45owU4wVCHkHNSd39Kn15urKxBCZGrAoqhRI9ImDJiM0DSK3DS7TCjrMEvaVpa+P3EbCoRVGH2u4r5gMgrUaP8U123RsRHScSnmFLVv7tBowE6QiIg0HA4lIiLbCiUBtl4d0YLRoUREZFs8EyQiIg2HQ4mIyLZUpeuxPKr2rysfO0GrWZW7EIAqPnFbKPebnV9ed1u7uSfI+4V6pKfBA0C7ySfCS+VSjlCjJ8v7hHWLuUPD8GT5gCJ8TlJFBquOEepqE/KTStuqXXhCfbtRXk/h/Un7jrSvGe2fVh0D0jFmnFeXUaBXE3aCRESkuTAc2rXTOA6HEhFRVLLbNUFGhxIRkW3xTJCIiDR2u0+QnSAREWk4HEpERGQTPBME5J8tRiHPwjLirRBiiLZBGLgY7i2ElAvraG8XyoX6AaBdWIdPeH+twvBHm8GwSKtwa4N0+4IUxn8e7brlLYp04wbgE27q8AnLSNFy/hCGfZzCLQTSrRNxqlO/TYp+OQDx9glFOOSlbd4q/E6OM3hyuLQvuIR9R9rXDPdPYZ8Wb52QjhlhHUbHpenbKiz8fgmHAFQLokM5HEpERFHIbjfLcziUiIhsi2eCRESk4XAoERHZmL2eLM/hUCIisi2eCRoxiBATrxyL0aHmotMA8xFtUjRdu18/irAthOjQduEHnn58phx1CMgRmq1SuVCXFAXaIrZKjgJtEdctR5paxSVGgepv9IBBhKYQaAqn8LvXoeq/v1ihIp+QWBsA2lVhP5T2HWFfM9w/hX1aOgZMR1obrFuMHDWbWNvo+yWC7HafIDtBIiLS2O2aYMSGQ0tKSjBw4EDEx8cjKysL+/btM5z/9ddfx9ChQxEfH4+RI0di27ZtYWopERFdrSLSCW7atAlutxtFRUU4cOAA0tPTkZubi88//1x3/t27d2PatGmYPXs2/vGPfyAvLw95eXk4dOhQmFtORHR1Uy2aokVEOsHVq1dj7ty5mDVrFoYPH47S0lL07NkTZWVluvM///zzuPfee/GLX/wCw4YNw6pVq3DrrbfixRdfDHPLiYiubgFFtWSKFmG/Jujz+VBdXY3CwkKtzOFwICcnB1VVVbrLVFVVwe12B5Xl5uZiy5Yt4npaW1vR2tqq/d3Q0AAAOHfu0ku2/ib9C9TOr+RgCP95IVVXS5tu+Vet+uVNbT5xHU1trbrlze2x+uvw65efD+gHEbQYBFa0CEEarVKghJQ2TQi4AIA2IXClXSj3C0Erfuhv24D4vHv5NalcNajLKgEpMAZCuTA/APiFZfyqsM2FutpU/a8IqRwAfMK6pX1H2g9dBvvO+UCLbnmcX788pl2/3CEcY4rBcQnhWPYLx75P+K4w/H5p0g8taW4MLr/4faYaBUmRobB3gmfPnoXf70dycnJQeXJyMg4fPqy7jMfj0Z3f4/GI6ykuLsZTTz11SfmIG0+F0GrqFtJxa9PjuTnSDYgUqS84Z7CM9NrpLrYlSn3xxRdITEy0pC67BcZctdGhhYWFQWeP9fX1GDBgAOrq6izbWa52jY2NSEtLw6lTp5CQkBDp5kQNbjfzuM1C09DQgBtuuAHXXXedZXVacU0verrACHSCffr0gdPphNfrDSr3er1ISUnRXSYlJcXU/ADgcrngcrkuKU9MTORBZlJCQgK3WQi43czjNguNw8G8J6EK+5aLi4tDZmYmKisrtbJAIIDKykpkZ2frLpOdnR00PwBs375dnJ+IiEJzcTi0q1O0iMhwqNvtxowZMzBmzBiMHTsWa9asQXNzM2bNmgUAyM/PR//+/VFcXAwAWLhwIe666y4899xzuP/++7Fx40bs378ff/zjHyPRfCKiqxavCYbBlClTcObMGSxfvhwejwcZGRmoqKjQgl/q6uqCTu/HjRuHDRs2YNmyZXjiiSdw8803Y8uWLRgxYkSn1+lyuVBUVKQ7REr6uM1Cw+1mHrdZaLjduk5RGVtLRGR7jY2NSExMxBjn84hRenSprnb1PPb7F6KhoeGKv8Z71UaHEhGReaoFj1Lq+qOYwochRUREZFs8EyQiIo1qQWBMNJ0JshMkIiJNQFGhdDH3ZzRFh3I4lIiIbOuq7QR//etfY9y4cejZsyd69+7dqWVUVcXy5cvRr18/9OjRAzk5OTh69Gj3NvQK8+WXX2L69OlISEhA7969MXv2bDQ1NRkuM2HCBCiKEjQ9+uijYWpxZPB5mOaZ2Wbl5eWX7FPx8fFhbG3k7dq1C5MnT0ZqaioURTF8YMBFO3bswK233gqXy4WbbroJ5eXlptcbsGiKFldtJ+jz+fDwww9j/vz5nV7mmWeewQsvvIDS0lLs3bsX11xzDXJzc9HSop+B/mo0ffp0fPDBB9i+fTveeust7Nq1C/PmzbvscnPnzsXp06e16ZlnnglDayODz8M0z+w2Ay6kUPv6PnXy5MkwtjjympubkZ6ejpKSkk7Nf+LECdx///2YOHEiampqsGjRIsyZMwdvv/22qfXaLWPMVX+fYHl5ORYtWoT6+nrD+VRVRWpqKn7+859j8eLFAC4kp01OTkZ5eTmmTp0ahtZG1ocffojhw4fj73//O8aMGQMAqKiowH333YdPPvkEqampustNmDABGRkZWLNmTRhbGzlZWVm47bbbtOdZBgIBpKWl4Wc/+xmWLl16yfxTpkxBc3Mz3nrrLa3s9ttvR0ZGBkpLS8PW7kgyu806e9zahaIo2Lx5M/Ly8sR5lixZgq1btwb9uJo6dSrq6+tRUVFx2XVcvE9wRMxzcHbxPkG/eh6H2n8eFfcJXrVngmadOHECHo8HOTk5WlliYiKysrLE5xxebaqqqtC7d2+tAwSAnJwcOBwO7N2713DZV199FX369MGIESNQWFiIr776qrubGxEXn4f59f2kM8/D/Pr8wIXnYdplvwplmwFAU1MTBgwYgLS0NDzwwAP44IMPwtHcqGXVfqZa9E+0YHTo/7v4bEKzzy28mng8HvTt2zeoLCYmBtddd53hNvjRj36EAQMGIDU1Ff/85z+xZMkSHDlyBG+88UZ3NznswvU8zKtJKNtsyJAhKCsrw6hRo9DQ0IDf/e53GDduHD744ANcf/314Wh21JH2s8bGRpw/fx49enTu7C4AFYqNcodG1Zng0qVLL7lY/s1JOqjsrLu327x585Cbm4uRI0di+vTpeOWVV7B582YcP37cwndBdpKdnY38/HxkZGTgrrvuwhtvvIGkpCT84Q9/iHTT6CoTVWeCP//5zzFz5kzDeQYPHhxS3RefTej1etGvXz+t3Ov1IiMjI6Q6rxSd3W4pKSmXBCq0t7fjyy+/NHx24zdlZWUBAI4dO4Ybb7zRdHuvZOF6HubVJJRt9k2xsbEYPXo0jh071h1NvCpI+1lCQkKnzwIB+50JRlUnmJSUhKSkpG6pe9CgQUhJSUFlZaXW6TU2NmLv3r2mIkyvRJ3dbtnZ2aivr0d1dTUyMzMBAO+++y4CgYDWsXVGTU0NAAT9mLhafP15mBeDFC4+D7OgoEB3mYvPw1y0aJFWZqfnYYayzb7J7/fj4MGDuO+++7qxpdEtOzv7kltvQtnP7NYJRtVwqBl1dXWoqalBXV0d/H4/ampqUFNTE3TP29ChQ7F582YAF6KvFi1ahF/96ld48803cfDgQeTn5yM1NdUwIutqMmzYMNx7772YO3cu9u3bh//93/9FQUEBpk6dqkWGfvrppxg6dKh2j9fx48exatUqVFdXo7a2Fm+++Sby8/Px3e9+F6NGjYrk2+k2brcb69atw/r16/Hhhx9i/vz5lzwPs7CwUJt/4cKFqKiowHPPPYfDhw9jxYoV2L9/f6c7gKuB2W22cuVK/M///A8+/vhjHDhwAD/+8Y9x8uRJzJkzJ1JvIeyampq07y3gQvDexe80ACgsLER+fr42/6OPPoqPP/4Yjz/+OA4fPoyXXnoJr732Gh577LFIND9qRNWZoBnLly/H+vXrtb9Hjx4NAHjvvfcwYcIEAMCRI0fQ0NCgzfP444+jubkZ8+bNQ319PcaPH4+Kigpb3aT76quvoqCgAHfffTccDgceeughvPDCC9rrbW1tOHLkiBb9GRcXh3feeUd7MHJaWhoeeughLFu2LFJvodtF4nmY0c7sNvvXv/6FuXPnwuPx4Fvf+hYyMzOxe/duDB8+PFJvIez279+PiRMnan+73W4AwIwZM1BeXo7Tp09rHSJwYTRr69ateOyxx/D888/j+uuvx5/+9Cfk5uaaWm8AsOBMMHpc9fcJEhHR5V28T3Bw7G/hULr2wz+gtuDjtiW8T5CIiOhKdtUOhxIRkXkXglrsExjDTpCIiDR26wQ5HEpERLbFM0EiItL4Lcj9GU1nguwEiYhIw+FQIiIim+CZIBERaex2JshOkIiINH4lAFXpWs6XQBTljOFwKBER2RY7QSILnDlzBikpKXj66ae1st27dyMuLg6VlZURbBmROX6olkzRgp0gkQWSkpJQVlamPSHi3LlzeOSRR7Rk5ETRImBBBxjqNcGSkhIMHDgQ8fHxyMrK0p5WI3n99dcxdOhQxMfHY+TIkZc8Sqoz2AkSWeS+++7D3LlzMX36dDz66KO45pprUFxcHOlmEUWFTZs2we12o6ioCAcOHEB6ejpyc3MvedD3Rbt378a0adMwe/Zs/OMf/0BeXh7y8vJw6NAhU+vlUySILHT+/HmMGDECp06dQnV1NUaOHBnpJhF1ysWnSFzrKoLSxadIqGoLmlqfMvUUiaysLNx222148cUXAVx48HJaWhp+9rOfYenSpZfMP2XKFDQ3N+Ott97Sym6//XZkZGSgtLS0023lmSCRhY4fP47PPvsMgUAAtbW1kW4OkWkqWqGqLV2b0ArgQsf69am1tVV3nT6fD9XV1cjJydHKHA4HcnJyUFVVpbtMVVVV0PwAkJubK84v4S0SRBbx+Xz48Y9/jClTpmDIkCGYM2cODh48iL59+0a6aUSXFRcXh5SUFHg8v7GkvmuvvRZpaWlBZUVFRVixYsUl8549exZ+v197yPJFycnJOHz4sG79Ho9Hd36Px2OqnewEiSzyy1/+Eg0NDXjhhRdw7bXXYtu2bfjJT34SNFxDdKWKj4/HiRMn4PP5LKlPVVUoihJU5nK5LKnbSuwEiSywY8cOrFmzBu+99552DeQ//uM/kJ6ejrVr12L+/PkRbiHR5cXHxyM+vmvXA0PRp08fOJ1OeL3eoHKv14uUlBTdZVJSUkzNL+E1QSILTJgwAW1tbRg/frxWNnDgQDQ0NLADJLqMuLg4ZGZmBt1TGwgEUFlZiezsbN1lsrOzL7kHd/v27eL8Ep4JEhFRxLndbsyYMQNjxozB2LFjsWbNGjQ3N2PWrFkAgPz8fPTv31+77WjhwoW466678Nxzz+H+++/Hxo0bsX//fvzxj380tV52gkREFHFTpkzBmTNnsHz5cng8HmRkZKCiokILfqmrq4PD0TF4OW7cOGzYsAHLli3DE088gZtvvhlbtmzBiBEjTK2X9wkSEZFt8ZogERHZFjtBIiKyLXaCRERkW+wEiYjIttgJEhGRbbETJCIi22InSEREtsVOkIiIbIudIBER2RY7QSIisi12gkREZFv/B6dynoyri+8jAAAAAElFTkSuQmCC\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "plt.figure(figsize=(5, 5))\n", + "plt.imshow(\n", + " sol.ys.vals,\n", + " origin=\"lower\",\n", + " extent=(x0, x_final, t0, t_final),\n", + " aspect=(x_final - x0) / (t_final - t0),\n", + " cmap=\"plasma\",\n", + ")\n", + "plt.xlabel(\"x\")\n", + "plt.ylabel(\"t\", rotation=0)\n", + "plt.clim(0, 1)\n", + "plt.colorbar()" + ] + }, + { + "cell_type": "markdown", + "id": "b4b4ced9-0602-4354-a1b9-277ddf70245c", + "metadata": {}, + "source": [ + "Some final notes.\n", + "\n", + "1. We wrote down the general Crank–Nicolson method, which uses a fixed point iteration to solve the implicit problem. If you know something about the structure of your problem (e.g. that it is linear) then it is often possible to more specialised solvers, which run faster. (E.g. linear solvers.)\n", + "\n", + "2. To keep this example brief, we didn't worry about doing a von Neumann stability analysis." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "py38", + "language": "python", + "name": "py38" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.16" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/mkdocs.yml b/mkdocs.yml index 18cf5380..16b89b2d 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -97,16 +97,19 @@ nav: - 'usage/manual-stepping.md' - 'usage/extending.md' - Examples: - - Basic examples: 'basic-examples.md' - - Neural ODE: 'examples/neural_ode.ipynb' - - Neural CDE: 'examples/neural_cde.ipynb' - - Neural SDE: 'examples/neural_sde.ipynb' - - Latent ODE: 'examples/latent_ode.ipynb' - - Continuous Normalising Flow: 'examples/continuous_normalising_flow.ipynb' - - Symbolic Regression: 'examples/symbolic_regression.ipynb' + - Basic ODE/SDE/CDE examples: 'other_examples/basic-examples.md' + - Coupled ODEs: 'examples/coupled_odes.ipynb' - Stiff ODE: 'examples/stiff_ode.ipynb' - - Steady State: 'examples/steady_state.ipynb' - - Kalman Filter: 'examples/kalman_filter.ipynb' + - Neural differential equations: + - Neural ODE: 'examples/neural_ode.ipynb' + - Neural CDE: 'examples/neural_cde.ipynb' + - Neural SDE: 'examples/neural_sde.ipynb' + - Latent ODE: 'examples/latent_ode.ipynb' + - Continuous normalising flow: 'examples/continuous_normalising_flow.ipynb' + - Symbolic regression: 'examples/symbolic_regression.ipynb' + - Steady state: 'examples/steady_state.ipynb' + - Kalman filter: 'examples/kalman_filter.ipynb' + - Nonlinear heat PDE: 'examples/nonlinear_heat_pde.ipynb' - Basic API: - 'api/type_terminology.md' - 'api/diffeqsolve.md' From 73f5714e2f0a4b9331b255dba670ab748d86336a Mon Sep 17 00:00:00 2001 From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com> Date: Mon, 20 Feb 2023 14:08:14 -0800 Subject: [PATCH 18/19] doc tweaks --- README.md | 2 +- diffrax/adjoint.py | 2 - diffrax/autocitation.py | 2 + docs/api/adjoints.md | 2 +- docs/api/{citation.md => autocitation.md} | 4 + docs/api/solvers/sde_solvers.md | 4 +- docs/citation.md | 5 + docs/further_details/citation.md | 5 - docs/further_details/faq.md | 2 +- docs/index.md | 16 +-- docs/other_examples/basic-examples.md | 2 +- docs/requirements.txt | 1 + docs/usage/how-to-choose-a-solver.md | 4 +- docs/usage/manual-stepping.md | 4 - examples/coupled_odes.ipynb | 13 +- examples/neural_ode.ipynb | 4 +- examples/nonlinear_heat_pde.ipynb | 73 ++++------- examples/symbolic_regression.ipynb | 148 +++++++++++----------- mkdocs.yml | 4 +- 19 files changed, 130 insertions(+), 167 deletions(-) rename docs/api/{citation.md => autocitation.md} (64%) create mode 100644 docs/citation.md delete mode 100644 docs/further_details/citation.md diff --git a/README.md b/README.md index cdb73773..48fcb2ca 100644 --- a/README.md +++ b/README.md @@ -21,7 +21,7 @@ _From a technical point of view, the internal structure of the library is pretty pip install diffrax ``` -Requires Python >=3.8 and JAX >=0.4.3. +Requires Python 3.8+, JAX 0.4.3+, and [Equinox](https://github.com/patrick-kidger/equinox) 0.10.0+. ## Documentation diff --git a/diffrax/adjoint.py b/diffrax/adjoint.py index fb4d2c36..b0000454 100644 --- a/diffrax/adjoint.py +++ b/diffrax/adjoint.py @@ -327,8 +327,6 @@ class DirectAdjoint(AbstractAdjoint): So unless you need forward-mode autodifferentiation then [`diffrax.RecursiveCheckpointAdjoint`][] should be preferred. - - This is not reverse-mode autodifferentiable if `diffeqsolve(..., max_steps=None)`. """ def loop( diff --git a/diffrax/autocitation.py b/diffrax/autocitation.py index 251ab0be..829903cb 100644 --- a/diffrax/autocitation.py +++ b/diffrax/autocitation.py @@ -44,8 +44,10 @@ def citation(*args, **kwargs): ```python from diffrax import citation, Dopri5, PIDController + citation(solver=Dopri5(), stepsize_controller=PIDController(pcoeff=0.4, rtol=1e-3, atol=1e-6)) + # % --- AUTOGENERATED REFERENCES PRODUCED USING `diffrax.citation(...)` --- # % The following references were found for the numerical techniques being used. # % This does not cover e.g. any modelling techniques being used. diff --git a/docs/api/adjoints.md b/docs/api/adjoints.md index a5870b8d..65d8713b 100644 --- a/docs/api/adjoints.md +++ b/docs/api/adjoints.md @@ -14,7 +14,7 @@ There are multiple ways to backpropagate through a differential equation (to com Alternatively we may compute $\frac{\mathrm{d}y(t_1)}{\mathrm{d}y_0}$ analytically. In doing so we obtain a backwards-in-time ODE that we must numerically solve to obtain the desired gradients. This is known as "optimise then discretise", and corresponds to [`diffrax.BacksolveAdjoint`][] below. -??? abstract "`diffrax.AbstractSolver`" +??? abstract "`diffrax.AbstractAdjoint`" ::: diffrax.AbstractAdjoint selection: diff --git a/docs/api/citation.md b/docs/api/autocitation.md similarity index 64% rename from docs/api/citation.md rename to docs/api/autocitation.md index f68aa9ce..006af82c 100644 --- a/docs/api/citation.md +++ b/docs/api/autocitation.md @@ -2,4 +2,8 @@ Diffrax can autogenerate BibTeX citations for all the numerical methods you use. +!!! warning + + This is an experimental feature that may change. + ::: diffrax.citation diff --git a/docs/api/solvers/sde_solvers.md b/docs/api/solvers/sde_solvers.md index 39b38039..1a3db677 100644 --- a/docs/api/solvers/sde_solvers.md +++ b/docs/api/solvers/sde_solvers.md @@ -14,7 +14,7 @@ See also [How to choose a solver](../../usage/how-to-choose-a-solver.md#stochast diffeqsolve(terms, solver=Euler(), ...) ``` - Some solvers are SDE-specific. For these, such as for example [`diffrax.StratonovichMilstein`][], then `terms` should be a 2-tuple `(AbstractTerm, AbstractTerm)`, representing the drift and diffusion separately. + Some solvers are SDE-specific. For these, such as for example [`diffrax.StratonovichMilstein`][], then `terms` should be a 2-tuple `(ODETerm, AbstractTerm)`, representing the drift and diffusion separately. For those SDE-specific solvers then this is documented below, and the term structure is available programmatically under `.term_structure`. @@ -60,7 +60,7 @@ These are reversible in the same way as when applied to ODEs. [See here.](./ode_ !!! info "Term structure" - For these SDE-specific solvers, the terms (given by the value of `terms` to [`diffrax.diffeqsolve`][]) must be a 2-tuple `(AbstractTerm, AbstractTerm)`, representing the drift and diffusion respectively. Typically that means `(ODETerm(...), ControlTerm(..., ...))`. + For these SDE-specific solvers, the terms (given by the value of `terms` to [`diffrax.diffeqsolve`][]) must be a 2-tuple `(ODETerm, AbstractTerm)`, representing the drift and diffusion respectively. Typically that means `(ODETerm(...), ControlTerm(..., ...))`. ::: diffrax.EulerHeun selection: diff --git a/docs/citation.md b/docs/citation.md new file mode 100644 index 00000000..baf78e0c --- /dev/null +++ b/docs/citation.md @@ -0,0 +1,5 @@ +# Citation + +--8<-- "further_details/.citation.md" + +In addition, see the [Autocitation](./api/autocitation.md) page for how to get Diffrax to autogenerate a list of BibTeX citations for the numerical methods that you use. diff --git a/docs/further_details/citation.md b/docs/further_details/citation.md deleted file mode 100644 index 3841153d..00000000 --- a/docs/further_details/citation.md +++ /dev/null @@ -1,5 +0,0 @@ -# Citation - ---8<-- "further_details/.citation.md" - -In addition, see the [Create citations](../api/citation.md) page for how to get Diffrax to autogenerate a list of BibTeX citations for the numerical methods you are using. diff --git a/docs/further_details/faq.md b/docs/further_details/faq.md index 9bf3643d..146502dd 100644 --- a/docs/further_details/faq.md +++ b/docs/further_details/faq.md @@ -5,7 +5,7 @@ - Use `scan_stages=True`, e.g. `Tsit5(scan_stages=True)`. This is supported for all Runge--Kutta methods. This will substantially reduce compile time at the expense of a slightly slower run time. - Set `dt0=`, e.g. `diffeqsolve(..., dt0=0.01)`. In contrast `dt0=None` will determine the initial step size automatically, but will increase compilation time. - Prefer `SaveAt(t0=True, t1=True)` over `SaveAt(ts=[t0, t1])`, if possible. -- It's an internal (subject-to-change) API, but you can also try adding `equinox.internal.noinline` to your vector field (s). eg. `ODETerm(noinline(...))`. This stages the vector field out into a separate compilation graph. This can greatly decrease compilation time whilst greatly increasing runtime. +- It's an internal (subject-to-change) API, but you can also try adding `equinox.internal.noinline` to your vector field (s), e.g. `ODETerm(noinline(...))`. This stages the vector field out into a separate compilation graph. This can greatly decrease compilation time whilst greatly increasing runtime. ### The solve is taking loads of steps / I'm getting NaN gradients / other weird behaviour. diff --git a/docs/index.md b/docs/index.md index 5d3ca6db..52dd17af 100644 --- a/docs/index.md +++ b/docs/index.md @@ -20,7 +20,7 @@ _From a technical point of view, the internal structure of the library is pretty pip install diffrax ``` -Requires Python >=3.7 and JAX >=0.3.4. +Requires Python 3.8+, JAX 0.4.3+, and [Equinox](https://github.com/patrick-kidger/equinox) 0.10.0+. ## Quick example @@ -43,16 +43,6 @@ Here, `Dopri5` refers to the Dormand--Prince 5(4) numerical differential equatio --8<-- "further_details/.citation.md" -## Getting started +## Next steps -If this page has caught your interest, then have a look at the [Getting Started](./usage/getting-started.md) page. - -!!! help - - Both Diffrax and its documentation are very new! If: - - - anything is unclear; - - you have any suggestions; - - you need any more features; - - then please open an issue or pull request on [GitHub](https://github.com/patrick-kidger/diffrax). +Have a look at the [Getting Started](./usage/getting-started.md) page. diff --git a/docs/other_examples/basic-examples.md b/docs/other_examples/basic-examples.md index 5611479a..cef5a5cb 100644 --- a/docs/other_examples/basic-examples.md +++ b/docs/other_examples/basic-examples.md @@ -1,5 +1,5 @@ # Basic examples -If you're just getting started then you can find basic examples on the [Getting started](./usage/getting-started.md) page. +If you're just getting started then you can find basic examples on the [Getting started](../usage/getting-started.md) page. The API page for [`diffrax.diffeqsolve`][] is also a useful reference for the possible solver configuration options. diff --git a/docs/requirements.txt b/docs/requirements.txt index b1aadcb2..beb6233c 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -9,6 +9,7 @@ mkdocs_include_exclude_files==0.0.1 # Allow for customising which files get inc jinja2==3.0.3 # Older version. After 3.1.0 seems to be incompatible with current versions of mkdocstrings. nbconvert==6.5.0 # | Older verson to avoid error nbformat==5.4.0 # | +pygments==2.14.0 # Install latest version of our dependencies jax[cpu] diff --git a/docs/usage/how-to-choose-a-solver.md b/docs/usage/how-to-choose-a-solver.md index ef086b6f..73aed4ce 100644 --- a/docs/usage/how-to-choose-a-solver.md +++ b/docs/usage/how-to-choose-a-solver.md @@ -16,7 +16,7 @@ For non-stiff problems then [`diffrax.Tsit5`][] is a good general-purpose solver For a long time the recommend default solver for many problems was [`diffrax.Dopri5`][]. This is the default solver used in [`torchdiffeq`](https://github.com/rtqichen/torchdiffeq/), and is the solver used in MATLAB's `ode45`. However `Tsit5` is now reckoned on being slightly more efficient overall. (Try both if you wish.) -If you need accurate solutions at high tolerances then try [`diffrax.Dopri8`][]. +If you need accurate solutions at tight tolerances then try [`diffrax.Dopri8`][]. If you are solving a neural differential equation, and training via discretise-then-optimise (corresponding to `diffeqsolve(..., adjoint=RecursiveCheckpointAdjoint())`, which is the default), then accurate solutions are often not needed and a low-order solver will be most efficient. For example something like [`diffrax.Heun`][]. @@ -40,7 +40,7 @@ See also the [Stiff ODE example](../examples/stiff_ode.ipynb). SDE solvers are relatively specialised depending on the type of problem. Each solver will converge to either the Itô solution or the Stratonovich solution. In addition some solvers require "commutative noise". -??? info "Commutative noise" +!!! info "Commutative noise" Consider the SDE diff --git a/docs/usage/manual-stepping.md b/docs/usage/manual-stepping.md index 4b494098..93b8f12d 100644 --- a/docs/usage/manual-stepping.md +++ b/docs/usage/manual-stepping.md @@ -1,9 +1,5 @@ # Interactively step through a solve -!!! warning - - This API should now be relatively stable, but in principle may still be subject to change. - Sometimes you might want to do perform a differential equation solve just one step at a time (or a few steps at a time), and perhaps do some other computations in between. A common example is when solving a differential equation in real time, and wanting to continually produce some output. One option is to repeatedly call `diffrax.diffeqsolve`. However if that seems inelegant/inefficient to you, then it is possible to use the solvers (and step size controllers, etc.) yourself directly. diff --git a/examples/coupled_odes.ipynb b/examples/coupled_odes.ipynb index a1a6b71e..e166c14c 100644 --- a/examples/coupled_odes.ipynb +++ b/examples/coupled_odes.ipynb @@ -60,16 +60,6 @@ "tags": [] }, "outputs": [ - { - "data": { - "text/plain": [ - "" - ] - }, - "execution_count": 2, - "metadata": {}, - "output_type": "execute_result" - }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAh8AAAGfCAYAAAD/BbCUAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy88F64QAAAACXBIWXMAAA9hAAAPYQGoP6dpAACXdElEQVR4nO2dd3yb1b3/39ree8d29l5kQAh7BEJombmlpbSMS8ulDTO0pem9pfsXoLctHZQuCu0tFEpbKJtCgEAgCdk7zk6ceMV2vK39/P44eiQ5cRLLlvQMnffr5ZccSZG+0vFzzud8z3dYFEVRkEgkEolEIkkSVq0NkEgkEolEklpI8SGRSCQSiSSpSPEhkUgkEokkqUjxIZFIJBKJJKlI8SGRSCQSiSSpSPEhkUgkEokkqUjxIZFIJBKJJKlI8SGRSCQSiSSpSPEhkUgkEokkqUjxIZFIJBKJJKnYh/KfH374YZYsWcK9997LY489BoDb7eaBBx7gueeew+PxMH/+fH79619TWlo6oNcMBoPU1dWRnZ2NxWIZinkSiUQikUiShKIodHZ2UlFRgdV6at/GoMXHmjVr+O1vf8u0adP63H///ffz2muv8cILL5Cbm8tdd93F9ddfz0cffTSg162rq6OqqmqwZkkkEolEItGQ2tpaKisrT/mcQYmPrq4ubrrpJn7/+9/zwx/+MHx/e3s7Tz75JM8++yyXXHIJAE899RQTJ05k1apVnH322ad97ezs7LDxOTk5gzFPIpFIJBJJkuno6KCqqiq8jp+KQYmPRYsW8alPfYp58+b1ER/r1q3D5/Mxb9688H0TJkygurqalStX9is+PB4PHo8n/O/Ozk4AcnJypPiQSCQSicRgDCRkImbx8dxzz7F+/XrWrFlzwmMNDQ04nU7y8vL63F9aWkpDQ0O/r7d06VK+973vxWqGRCKRSCQSgxJTtkttbS333nsvzzzzDGlpaXExYMmSJbS3t4d/amtr4/K6EolEIpFI9ElM4mPdunU0NTUxc+ZM7HY7drud5cuX84tf/AK73U5paSler5e2trY+/6+xsZGysrJ+X9PlcoWPWORRi0QikUgk5iemY5dLL72ULVu29LnvtttuY8KECTz44INUVVXhcDhYtmwZCxcuBKCmpoZDhw4xd+7c+FktkUgkEonEsMQkPrKzs5kyZUqf+zIzMyksLAzff/vtt7N48WIKCgrIycnh7rvvZu7cuQPKdJFIJBKJRGJ+hlRkrD9+9rOfYbVaWbhwYZ8iYxKJRCKRSCQAFkVRFK2NiKajo4Pc3Fza29tl/IdEIpFIJAYhlvVb9naRSCQSiUSSVKT4kEgkEolEklSk+JBIJBKJRJJUpPiQSCQSiUSSVKT4kEgkEolEklTinmorSQ1W7m3hoz3NXDCumLNGFmhtjiRGFEXhtS317Gro5OozKhhTcvoulBJ94fUHeWFdLa1dXj57VhUl2fFpeSFJHu09Pp5bcwib1cLn51ST4UydJVmm2kpi5uVNddzz1w0AWCzwxE0zuWJKucZWSWLhf9+q4Vfv7QEg3WHjn189h4nl8nozCoqicMf/rePt7Y0AVOSm8eo951OQ6dTYMslAcfsCXP2rFexq7AJg9vB8/nrH2Thsxj2QkKm2koRxrNvLd/61FRATnqLAQ//ahtsX0NgyyUDZeqSdX78vhEd5bhq9vgDf+dc2dLYPkZyClzfV8fb2Rhw2C4WZTura3fz07RqtzZLEwM+X7WZXYxc5aXYynDbWHjzGc2tSp7GqFB+SmPjDin0c6/ExoSybZQ9cxLC8dJo6Pby1rUFr0yQD5LF3dhNU4FPTynnxq+fitFn55EArOxs6tTZNMgCCQYWf/HsXAPdcMpZffn4GAH9fd5huj19L0yQD5Fi3lydX7Afgfz8znW/MHw/Anz8+oKFVyUWKD8mA8QWCPL/mMAD3XDqWdKeN62cOA+CVTfVamiYZIHVtvby7U7jqF182jrLcNC4aXwzAa5vlGBqBFXuaOdTaQ3aandvPH8ncUYVUF2Tg9gV5d2eT1uZJBsDf1x3G6w8yuSKHyyaVct3MShw2C7ubutjdmBqbACk+JAPm3Z1NNHd5KM52cdmkUgA+Pa0CgOW7muhw+7Q0TzIA/r7uMEEFzh5VwOjiLEB4QABe2yLFhxF4fq1wzS+cWUmG047FYgmP4etyDA2BOoZfOHs4FouF3HQH548NbQJSZAyl+JAMGPVo5erpFeGgqPFl2QwvzMAXUFh38JiW5kkGgDqG18+oDN936cRSbFYL+5u7qWvr1co0yQBw+wK8H/JuXDdjWPj++ZPLAPhoTzPBoIzd0TN7mrrY09SFwxYRjQDzJ4sN3Ud7mrUyLalI8SEZEP5AxKWrej1UZg8XqbZrD7Qm3S7JwDnS1su2ug6sFrh0Ykn4/iyXnUmhTJe1UkDqmpV7W+j2BijLSWPqsNzw/ZMrckh32Ohw+9lztEtDCyWnQ81Qmju6iJw0R/j+M0eIeXTT4XY8fvMH8EvxIRkQG2rbaOvxkZfhYPbw/D6PnTlC/HvtAblw6RlVPM4ank9hlqvPY7NCY7pOCkhdsywUrzNvUglWqyV8v8Nm5YyqPADpgdQ5aszV8Zu4kUWZFGQ68fqDbKvr0MK0pCLFh2RArNzbAsB5Y4qwH5eHPjskPjYdbsMXCCbdNsnAWBUawwtCZ8vRqGMoPR/6ZuUpxlAVkHIToF96vH421rYBcMHYoj6PWSwWZlarmwDzj6EUH5IBsWqfmPTmjCo84bHRxVlkuey4fUH2N3cn2zTJAFAUJTyGZ48+cQynV+YBsKuxUwpIndLU6Wbv0W4sFvqtKqx6PrbVtSfZMslAWX+wDV9AoSI3jeqCjBMen1GdB6TGGErxITktHn8g7MqdO+rESc9isTC2VGRO1MhaEbpkT1MXLd1e0hxWplXmnvB4ZX46mU4bvoDCASkgdcnqfeJIbGJZDnkZJ1YyHV8mSuTvO9otBaROCW8ARhVisVhOeHx8qRjDmkbzx+1I8SE5LZtq2/H4gxRlOcPpmcczITTx7UqRHHWjoU56s4bn47LbTnjcYrEwrkyd+OQY6pHohas/huUJAekNBDnYIgWkHll5mjFUBeTepi78JheQUnxITsvqqCOX/tQ6wLiQYpdVMvXJqtCu+eyR/U96ENl17ZJjqEsi4qP/Ro5Wq4Wx6s65wfw7Z6PR6w2wKRTvMbefo08QAjIjJCAPtPQk0brkI8WH5LSsOySOXM4acfLutWF3oVy4dMn60BjOPsUYjiuVng+90t7jY+9R4c04cyDXoRxD3bG1rh1/UKE0x0Vlfnq/z4kWkGb3IkvxITkliqKw5bAIfpoeCmjrD/WCqT3WkxI56kaiqdNNfbsbiwWm9hPvoaLG7expkrtmvbHliLgGqwsyyD9F51p1DPfKMdQdqtdjWmXeST3IAONKxBjuNnnchxQfklNypK2Xlm4vdqslHNfRH0VZTjKcNhQFDh+TVTL1hCoex4Sykk7GiMJMAGqP9coqmTpj0+E24NTiEWB4aAwPtsqYD72xOXQdTht26jEcUZQaYyjFh+SUqAvX+LJs0hwnBiqqWCyWcOrYIZOfVRoNddI73cJVnpuG3WrB6w/S0OFOhmmSARL2Pp5WfIhr8GBzD4oiBaSeUL1X007hQQZSZh6V4kNySjapaj1UB+JUqDvnAzLSXldsDu2ap59mDO02K1WhiU+Oob5Qx/B016G6cHV6/BzrkY0e9UJ7ry9cA+l0no+wgGyV4kOSwmw50gbQb22I4wlfNCZX7EZCUZTwjut0ng9InV2XkTja6aEuFLMz5TQLV5rDRllOGoBMt9URqueqqiD9lDE7AMMLxCbuaKeHHq8/4bZphRQfkpMSDCqRc8oBiQ9x0RwyuWI3EnXtbpq7RMyO2jzuVKTKrstIqBuA0aeJ2VGpDo2hvA71w6YBeq4AcjMc5KaLhnNmHkMpPiQn5WBrD51uPy67NZyGeSoing+549ILm0MR9uNKTx2zoyI9H/pjU+3ANwAAwwukB1JvDDRmRyUVvMhSfEhOinrOPKkiB4ft9H8qw/JE7npdm1sGu+mEzUfUNOmBTXqV+WLSO9wmM5b0QjhQ8TRHLirqGB6RWWe6IXz0OSxvQM+vSoExlOJDclK214u2zpMrTu+uByjLFWfNvb4AbTLYTRfsCI3hpIqBLVwVeWIM66X40A3qGE4eoPhQx7CuXY6hHmjv9XEkdD1NGuBcWh6aS+tNPIZSfEhOilqtdELZwC6YNIeNoiwRTCUnPn2gjuHEU9RoiaYi5L062uXB6zd3bwkj0N7jo75dpD2Pj3EM1f8n0Rb1GhyWlx6O5Tgd5VFeZLMixYfkpOysDy1c5QOb9ADKc81/0RiF6IVr3AAXrsJMJ067FUWBRlnrQ3N2Ngivx7C8dHLSBrZwVYQXrl55/KkD1DEcqHgEGJYC3ispPiT90tbjDReaGkiwqUrYbW/ii8YoDGbhslgsYZdvnTx60Ry1R8upqgsfjzp+Pd4AHb3mTdU0CjsbBjOGIe+ViTdxUnxI+kW9YCrz08ke4MIFkYvmiFy4NEcdw1g8VwAVqvdKCkjN2RHyPsaya05z2CjIlMefemFnfeyej/LQJq6p040vYM7jz5jExxNPPMG0adPIyckhJyeHuXPn8sYbb4Qfv+iii7BYLH1+7rzzzrgbLUk8NYNQ6xAdsGhexW4UVPERy6QHkYlPHp1pT03IezVhADVaogkHncpNgKYEgwq7Qg3iJsYwhkWZLhw2C0ETH3/GJD4qKyt5+OGHWbduHWvXruWSSy7hmmuuYdu2beHnfPnLX6a+vj788+ijj8bdaEniUV32Aw02VQm7C+WOS3MGO4YVcgx1QTCoxBwwrBKOvZJBp5pypK2XLo8fp83KyFDDuIFgtVqi5lJzjuHpy+VFcdVVV/X5949+9COeeOIJVq1axeTJkwHIyMigrKwsfhZKNGGwu+bSHNVd6Im7TZKBEwwq7Bqk96o0xwVAU4ccQy050tZLtzeA02YNdzodKOoYHjXprtkoqPPo6JKsAdVKiqYsJ41DrT3S83E8gUCA5557ju7ububOnRu+/5lnnqGoqIgpU6awZMkSenrMW6HNrETvuGJduIqzQ5OeFB+acvhYZOGKZccFUJwtBaQeUOt7jBnEwlUix1AXqPEesXquwPxzaUyeD4AtW7Ywd+5c3G43WVlZvPjii0yaNAmAz3/+8wwfPpyKigo2b97Mgw8+SE1NDf/85z9P+noejwePJ/LldnR0DOJjSOLJ4WO99Axy4SoJXTA93gBdHv+AelFI4o965DKmJAt7jAuX2Sc9ozDYDQDIMdQLg/UgQ2QMzSogY14Zxo8fz8aNG2lvb+fvf/87t9xyC8uXL2fSpEnccccd4edNnTqV8vJyLr30Uvbu3cvo0aP7fb2lS5fyve99b/CfQBJ3dgxh4cp02cl02uj2BmjqcJNVnJUIEyWnIZzeF2OmC0QE5NFOD4qiYLFY4mqbZGDEYwzNunAZhZ2DDBgG8wvImI9dnE4nY8aMYdasWSxdupTp06fz85//vN/nzpkzB4A9e/ac9PWWLFlCe3t7+Ke2tjZWkyRxpmYIkx6Y/6IxAvHYNXsDQVknQkMixankwmVE3L4A+5tFk83BHLuUmHwMh+wTDwaDfY5Notm4cSMA5eXlJ/3/LpcLl8s1VDMkcWQoCxeI8+YDLT1y16UhQ1m40hw2ctLsdLj9NHW6yc0YeJ0XSXzw+CML12CuQzXmo7nLQzCoYLVK71Wy2dPURVCBvAxHWAzGgjx2iWLJkiUsWLCA6upqOjs7efbZZ3n//fd566232Lt3L88++yxXXnklhYWFbN68mfvvv58LLriAadOmJcp+SQLY0yTy0seWDNLzkWNuxa53fIFguBX32JLBHXuV5KTR4e7iaKeHsTFUuJXEh/3N3QQVyE6zh3fAsVCY5cRiAX9Q4ViPl8IsucFLNnuPqvNo1qCOLs3uvYpJfDQ1NXHzzTdTX19Pbm4u06ZN46233uKyyy6jtraWd955h8cee4zu7m6qqqpYuHAh//M//5Mo2yUJwB8IhndcYwa5cBVnmVux652DLd34gwqZTlu41HasFGe52NPUJcdQI/Y2Ra7BwSxcDpuVggwnLd1emjo9UnxowN7QJm6w86jqvWrp9uAPBGOOv9M7MYmPJ5988qSPVVVVsXz58iEbJNGWw8d68QaCuOxWhoUaVMVKiVonotOc+el6Z09o4Ro9yIULImNo1l2X3lG9j6OHELBdnO2ipdvL0U4PE09+8i1JEHuODm0MCzKdWC0QVKC120tJzuA2EnrFXFJKMmTUSW9Ucdagz4lVxS4XLm3YO8RJD6KzJaSA1AJ14RrsrhnMHzOgd8ICcpBjaLNaKDKxF1mKD0kf4jnpSfGhDXuG6O4FuXBpTdhlP0TPB8jrUAv8gSAHmkXclRzD/pHiQ9KHeEx6Zk8R0zvx8XxI75VWBIMK+5qHtmuG6Cqn0nuVbGpDx9dpjsEfX4O5PZBSfEj6ED6nLImtsmk06gXT0u01bTtovaIoSlSg29DHUHo+ks+Rtl7cviBOm5Wq/MEvXGbeNesd9RocVTT442sw9xhK8SEJoyhKXFz2+RlO7KELrrnLfBeNnqlvd9PtDWC3WhheOHjxET52MWlTKz2jbgBGFmUOKcNBCkjtiMfxNUjxIUkRjnZ56HT7sVpgxBAWLqvVQn6mE4CWLm+8zJMMAPXIZXhhRszNyKJRUzM73H7pvUoye5uG7n0EUesDoEVuAJJOPLKVIFK2oNmE86gUH5Iw6gVTVZBBmsM2pNcqDImP1m7zXTR6Jl6TXl66A9VbfEyOYVLZE4e4K4DCTLFwyWsw+eyNk+cjvInrNp+AlOJDEiYewaYqBVJ8aEI8js0g5L3KUCc+OYbJJBwwPMQxVK/Btl4fgaAyZLskAyNex9dgbgEpxYckzN6jQ6tsGk1Bply4tCBeOy6QAlIr4uW9yg/15FEUONYjxzBZHO2MOr4uyhjSa5n5GpTiQxImXpMeRB+7mM9dqGfC1U3j6L2SAjJ5tHR5ONbjw2IZ+hjabVbyQgLEjIuXXlGDTasLMnDZh3h8HYrbOdbjI2gy75UUH5IwQ63IF02Bid2FeqW9xxfOLorHGKoTX6sMWEwaqvdxWF466c6hLVwQJSBNGLCoV4ba0yUa9egzEFRo7/UN+fX0hBQfEgC6PH4aQmmVcYn5yJKTXrJRd1zluWlkuWJq29QvZnb56pV4eh9BBn5rQTzH0Gm3kp0mrmWzeSCl+JAAEbVelOUiN+SqHQpy0ks+e+O8cKneK7NNenomnjE7EC0gpfcqWajeq3h4H8G8c6kUHxIgOktiaLUFVOSuOfnEq7CRilknPT0TrywJFSkgk0/8x9Cc16EUHxIg/juuQhmsmHT2xjFmB6JrDMgxTBbxPnYpyBReTFmrJTl0un3h4+v4jaEUHxITE+9JT1242nt9skJmkgj35SmOj/dKej6SS683wJG2XkB6PozKvtCRS3G2i9z0oR9fg3mPzqT4kADxd9nnZzixqBUyZY2BhOP2BahtDbXwjrO7V+6ak4PqfSzIdIa/+6EiBWRyiVd12mjMKiCl+JDgCwQ51BLfhctmtZCXLmsMJIsDLd0EFchJs4f7QQwVdeE61uM1XY0BPbI3zp4rMK/LXq/Eoyv48ZhVQErxIeFgSzf+oEKm00ZZTlrcXjc88cl024QTXaPFYhl8C+9o1KOzoCJKdEsSSzzrQ6jIQnHJJZ4tKlTMKiCl+JAkZOGCSF8COfElnkS4ex02KzmhGgNmO2/WI5GYnfiNYbhCZrcXRZHeq0QTOb7OjttrmrVmkhQfkkhPlzhOemBexa5H4l1bQKUwdIRjtolPj8SzwrCKeg36gwodvf64va7kRLz+IAdDx9fy2OX0SPEhScikBxHFLgNOE08iPB8gBWSy8AeCHGgOxV3FcQxddlu42q0Z27LriUOt3QSCClkue1yPr9US660m815J8SGJe5qtihpw2tYj4wUSSSCosC/O2Uoq4TGUMR8JpfZYL95AkDSHlWF56XF9bbW5nBzDxBKZRzPjenytbgC8gSC9vkDcXldrpPhIcRRFiSowFj9XIUQUe5v0fCSUurZePP4gTpuVyvx4L1zSe5UM1IVrVFEWVmv8Fi6IEh9yDBNKojZxGU4bDpv4mzhmoo2cFB8pTn27mx5vALvVwvDC+IoPddIz0wWjR9RJb2RRJnZbfC/p/NAYtssxTCjxrjAcTWQTIMcwkSQq7spisUQ2ASY6/pTiI8VRF67hhRk44rxw5UnPR1KIdy+JaCICUo5hIknkGKqVNuUmILEkcgzDmwATHZ1J8ZHi7E1Aep9KvjxrTgqJKE6lEjl2kWOYSBLlsgd5/JkMgkEloXNpXrr5jj+l+EhxkrJrNpGrUI8kKlsJIguXPHZJHH3jrhK4CZBjmDAaOqKPrzPi/vp5JhxDKT5SnMSKD7Fwdbj9+GVzuYSgKErc+/JEI49dEs/RTg+dbj9WC4woiv/ClSuDhhOOOo+OKMqM+/E1mDNoWIqPFCccJJUQV2Gkq2OHWxY4SgSt3V7aenxYLCJTIt7INM3Eoy5c1QUZuOy2uL++GeMF9EZ0mm0iMGPQsBQfKUx7j4/mLlF4KBEue7vNSnaowJHcdSUGddIblpdOujP+C1d00LCZChzpiUQeuYD0XiWDxI+h+WKvpPhIYfYc7QSgPDctXAUx3uRlms9dqCcSeeQCkV2zL6DQ7TVPgSM9kciYHYhauLrNs3DpjUQeX4M8dpGYjL1NiTtyUTGju1BPJHoM0x02nHYxTZhp4tMTiTz6hKigYXnskjASmekC5swclOIjhUn0rhlkjYFEk+gxtFgsskx+gkn4rjk0fl0eP16/DPyON209XppDjRcTJT5yUz3V9oknnmDatGnk5OSQk5PD3LlzeeONN8KPu91uFi1aRGFhIVlZWSxcuJDGxsa4Gy2JD4l294KsMZBo9iZ44QLpvUoknW4fDR1uIHELV066A7XViPR+xB/V61GRm0Zmgo6v8zPNV2k4JvFRWVnJww8/zLp161i7di2XXHIJ11xzDdu2bQPg/vvv55VXXuGFF15g+fLl1NXVcf311yfEcMnQSWRxKhVZYyBx9Hj9HGnrBRJ7dCYDFhPHvtCRS3G2K+wljDc2qyX82nITEH+Suonr9Zkm8DsmmXbVVVf1+fePfvQjnnjiCVatWkVlZSVPPvkkzz77LJdccgkATz31FBMnTmTVqlWcffbZ8bNaMmTcvgC1raEW3ok8dpE1BhKGunAVZDrDnS8TgRmD3fRC+MglgeIRxNFLW49PHn8mgETH7EDk+DoQVOhw+xMmVJPJoGM+AoEAzz33HN3d3cydO5d169bh8/mYN29e+DkTJkygurqalStXnvR1PB4PHR0dfX4kiWd/czdBBXLS7BRnuRL2PmYMlNILyVq45LFL4khG3BXIPkuJJNExOwBpDhvpDpFKb5ajl5jFx5YtW8jKysLlcnHnnXfy4osvMmnSJBoaGnA6neTl5fV5fmlpKQ0NDSd9vaVLl5Kbmxv+qaqqivlDSGInfORSkoXFEt8W3tHImI/EERnDxB2bAeTK7sQJY2+Ci1OpmLE8t15IZF+eaMx2/Bmz+Bg/fjwbN25k9erVfOUrX+GWW25h+/btgzZgyZIltLe3h39qa2sH/VqSgZOsXXN44ZI1BuJOsiY9KSATR8TzkZ3Q94nEDMgxjCduX4DaY4k/voYo75VJvMgxh+Y6nU7GjBkDwKxZs1izZg0///nP+exnP4vX66Wtra2P96OxsZGysrKTvp7L5cLlSpzbX9I/yXAVgqwxkEiSN4by6CwReP1BDrYka+GS3qtEsL+5G0URMRlFWYmLu4Lo4H1zCMgh1/kIBoN4PB5mzZqFw+Fg2bJl4cdqamo4dOgQc+fOHerbSOJMshauvHRzuQr1gj8Q5ECLCHRL9BiascaAHjjY0k0gqJDlslOak9gNmNqS3SwLl16InkcTeXwN5usSHpPnY8mSJSxYsIDq6mo6Ozt59tlnef/993nrrbfIzc3l9ttvZ/HixRQUFJCTk8Pdd9/N3LlzZaaLzggEFfY1J2fhUj0fPd4AHn8gIY2zUpGDrT34AgrpDhsVuekJfa9wYzK5a44r0SmaiV648jNlzEciSNbxNaT4sUtTUxM333wz9fX15ObmMm3aNN566y0uu+wyAH72s59htVpZuHAhHo+H+fPn8+tf/zohhksGz+FjPXj9QZx2K5X58W/hHU12mh2rBYKKWLxKcqT4iAeRhSsTqzXRC5f0fCSCZC5cudIDmRCSla0EmK7ScEzi48knnzzl42lpaTz++OM8/vjjQzJKkljUSW9UUSa2BC9c1lCBo2OhGgMlOWkJfb9UIak7rvRIS/ZgUEm42EkVkrlwyXTpxJCMCsMqZgv8lr1dUpBkVOSLJl8WGos7yZz01IyloAKdbn/C3y9V2JOkNFuQ4iMRJPP4GsyX8i7FRwqSzF0zRAKlZMZL/NiT4C6a0bjsNjKc4rhMCsj4EAwq4TotSXHZm6xGhB6obRXH1y67lWF5iY27Aun5kJiAZE56EAmUkgGL8UFRlKR6PkCmTMebI229uH1BnDYr1QWJjbuCiPjw+IO4fYGEv18qED6+Ls5KylGk2TZxUnykGIqiJC3NVkWm28aXhg433d4ANquF4YWJd9mDDFiMN6rnakRRBnZb4qfhLJc9HN9llsVLa5IZswPmq7cjxUeKcbTLQ4fbj9UCI4uStHCZ7KLRGlU8Di/MwGlPziVstl2X1iTbc2WxWOQmIM4k+/harbejBn4bHSk+Ugz1gqkqyCDNkZy010iBI7lwxYNkT3oge4PEGy3GMFeOYVxJtgdZ9T4qCnS4jT+GUnykGHs1XLjaZV+JuJDsSQ8iuy65cMWHZGecgfnqRGiJFnFXTruVzFDgtxnGUIqPFEOLhUvumuOLpmMoBeSQURQl6fECEBX4LcdwyDR1euj0iOPrEUWJDxhWMVOVUyk+UoxwimZSd81SfMSTZGcrQVShMTmGQ6a120tbjw+LJTmp0irS8xE/9objrjKT2jIiMpcaX0BK8ZFiaLNrlmma8aKtx0tzl5h4krpwyaDhuKFeg5X56UmLuwIZ+B1PkllnJxozBX5L8ZFCdLh9NHZ4AG12zWZQ61qjLlwVuWlkumLqjjAkcmVX1LgRPnJJ9sIl43bihhabODDXEbYUHymE6iosyXaRk+ZI2vuqF0y3N4DXH0za+5oRLQIVQXo+4onWC5eM+Rg6Wo2hmQK/pfhIIbS6YLLTHKgdw83gLtQSzRcuE0x6WqP1GJph4dIazcfQBAJSio8UQosIewBbqLMtyF3XUNEi2BSi+kr0+lAU4xc40pK9TVrFC5hn16wlHW4fTZ3i+DoZTQGjyTfRJkCKjxRib5PowJjsSQ9kpH280CpeQBWPgaBCl0d2th0s3R4/de1uQINdc7p5ghW1RPV6lOa4yE7i8TVExe2YYAyl+EghtNo1A+TKXdeQcfsCHD7WCyQ/5iPNYSPNIaYLOYaDR70Gi7KcYU9EspCdbeODVkcuEMlYMsMYSvGRInj8AQ62CM+HFheN7CsxdPYe7UJRxCJSmJnchQtktkQ82KPRkQtExq/HG8Djl51tB4sWVaJVzFRvR4qPFOFAcw9BBbJddkqyXUl/fzPlp2tFdD8QiyXxLbyPx0zBblqh5a45O80uA7/jgJZjKCucSgxHdIqmJguXjPkYMsnuJXE8slLt0NFy4bJGB37LMRw0WlSJVolkLHkN39lWio8UQctJD6JiPuSuedBola2kImt9DB3NxzBdjuFQcPsC1Lb2ABrFfITGL6hAl9fYgd9SfKQIu5o6ARhXqvGkJ3dcg2ZXo1i4xpZma/L+asxAu4zbGRRuX4CDLWLhGqfRGMrA76Gxp6mLYCjuqjgr+cfX0YHfRvdeSfGRIuxqEOJDs4VLxnwMCY8/wP5mETCsmYCURaqGxL6j3QSCCjlp2sRdgWx1MFR2hzdx2ZocX4N5Ar+l+EgBvP5geOEar7H4MPoFoxX7m8XClZ1mpywnTRMbZGOyoaEuXOPLNFy45CZgSKjeR602AGCewG8pPlKAAy3d+IMK2S475bkaLVzpMuZjKEQmPbnjMiq7GrX1PoI8/hwqqgdZq2MzME/gtxQfKUBN6IIZU6pNpgtIz8dQiUx62u+4ZIn8wVHTEBKQGgWbggz8Hiq7mrQXH/kmSbeV4iMF2B3acWl15AKRHVen248/IDvbxoq6a9Zy0pO75qERjhco03LhkmM4WHq8fmpbRYVhTa/DcH8XYwtIKT5SAK2zJCDiKgTocBs7RUwLdjdFjl20wkwFjpJNrzfAoVZtM11AeiCHwu7GSGn8Ag0qDKtESqwbewyl+EgBIrtm7dy9dpuV7DQ7ICPtY8XtC3CgRc100cPC5ZWdbWNkT5MojV+Y6aRIgxRNlTwZezVo9OB9BPPEXknxYXKiFy4tj10gurGVsS+aZKMuXPkZDoqytNtxqePnCyj0eGVvkFiIBJtqtwGAqIwleQ3GjB68j2Ce2CspPkzO3qOiKE5uuoNijWoLqISLVBn8okk2aqzAWA0zXQDSHTactlBnW3n0EhP62TXL8uqDpaZBHwLSLLFXUnyYHPWccrzGCxfI8+bBomZJaO25slgsUTtnKSBjQTfiIxS30+nx45OB3zGhh8B9ME+9HSk+TI5e3L1gnvz0ZLNbBzE7KnLnPDii67RoSU4o7gqgw+CLVzLpdPuoa3cD2gbug4z5kBgEvey4QDYmGyw1OihOpSLHMHa6PH6OtKkpmtoKyD6B33IMB4wqHktzXH0y97QgOubDyIHfMYmPpUuXcuaZZ5KdnU1JSQnXXnstNTU1fZ5z0UUXYbFY+vzceeedcTVaMnD0suMC2ZhsMHR7/Bw+pn1tAZVck+y6konquSrJdoWPPbREHn/Gzm4dbuKMHvgdk/hYvnw5ixYtYtWqVbz99tv4fD4uv/xyuru7+zzvy1/+MvX19eGfRx99NK5GSwZGj9dP7TG1toAOXPZy1xwze5rU2gIuTWsLqJilr0Qy2a2jDQDIwO/BoKdNnFkCv+2nf0qEN998s8+/n376aUpKSli3bh0XXHBB+P6MjAzKysriY6Fk0ETXFijUsLaAioz5iJ0aHcV7gIz5GAw1Ooq7Aun5GAx6qJWkYrFYyMtw0NTpoa3Hy7C8dK1NGhRDivlob28HoKCgoM/9zzzzDEVFRUyZMoUlS5bQ09Nz0tfweDx0dHT0+ZHEBz2pdZAVMgeDnty9IBeuwbBLJ1kSKuHrUI7hgNFT7ByY4zqMyfMRTTAY5L777uPcc89lypQp4fs///nPM3z4cCoqKti8eTMPPvggNTU1/POf/+z3dZYuXcr3vve9wZohOQV6ypIA8/QkSCZ6E5CyMVns7NZBe4NownUi5CZgQLT1eGnq9AB6GkPjC8hBi49FixaxdetWVqxY0ef+O+64I/z71KlTKS8v59JLL2Xv3r2MHj36hNdZsmQJixcvDv+7o6ODqqqqwZoliSLsstewkVU0ctKLnfCuuUwfAlI2JouN9l4fDR1qiqY+xjBP1mqJCXUDMCwvnSzXoJfMuJJrgtirQX2Td911F6+++ioffPABlZWVp3zunDlzANizZ0+/4sPlcuFyaR+PYEb0FuiWG04R8xEMKlit2hY90zsdbh/1odoCY0r0MYZm2HElE9X7WJGbRk6atimaKjL2Kjb0FO+hYoYqpzHFfCiKwl133cWLL77Iu+++y8iRI0/7fzZu3AhAeXn5oAyUDI5Oty9SW0BnC5eiQKfsbHtaVPFYlpOmeW0BFZntEht66Ch9PDL2Kjb0FncF0bU+jDuGMXk+Fi1axLPPPsu//vUvsrOzaWhoACA3N5f09HT27t3Ls88+y5VXXklhYSGbN2/m/vvv54ILLmDatGkJ+QCS/lHVemmOK+xx0Bqn3Uqm00a3N8CxHq9u7NIrOxtE8LVejs1A7ppjRR3DCToaw0jGkhSQA2FHg3r0qaMxDAcNG3cMY/J8PPHEE7S3t3PRRRdRXl4e/nn++ecBcDqdvPPOO1x++eVMmDCBBx54gIULF/LKK68kxHjJydlRLy6YieU5GlvSF7nrGjg7w2Oop0lPLFwefxC3z7gFjpLFjnohPvR0Hcp6OwNHURRdjqEZNgExeT5OV8q1qqqK5cuXD8kgSXzQ4wUD4qI50tZraMWeLNQxnKSjMcxy2bFZLQSCCm09PspybVqbpFsURQkLyAk6FJBGXriSRV27m063H7vVwuhiHcV8mEBAyt4uJmVnyFWoJ3cvmOOsMhkoihI1hvoRHxaLJSprSQrIU3GkrZdOjx+HzcKoIv0sXGqJ/A63j0DQuL1BksHO0AZgTEkWTrt+lstIqwrjzqP6+TYlcSMYVMIXjd48H3LXNTAOH+uly+PHabMyqjhTa3P6kCvHcECoR59jSrJ1tXCpLnsR+C3H8FTo1YNshsBv/VwRkrhx+Fgv3d6AWLiKdLZwyVTNAbEjasflsOnrMjVDml8yCG8AdOZ9VAO/QY7h6dihcw+ykcdPX7OaJC7sCEXYjy3Nwq63hcsEij0Z7NBhrICKGjQsG5OdGvU61PMYGjlmIBno1/Mhxs/jD9Jr0M62+lqZJHFBrxcMyMZkA0VN0dRTsKmK9HwMjJ06zTgDWeV0IPR6AxxoFh3b9SYgM5027KEijUbdyEnxYULCEfY6cxWCOaK0k4Eeg01VcuUYnpZeb4D9LaGFS4djaAa3faLZ3dRJMNQVvFgHXcGjUTvbgnHHUIoPE7JDx7vmSMyHMdV6Mujx+jkQWrj0VONDJV92RT0tNY2dKAoUZbkoztbXwgXRZfLldXgyoj3IFov+WkEYvdaHFB8mo9vj52BLD6Cvinwq0vNxemoaxMJVnO2iUGc7LpAu+4EQyTbT3zUI0ns1EHbo2IMMxo+9kuLDZKju+hKdL1wy5uPk6H3SM/qOKxmou2a9jqGM2zk94THUoQcZjD+GUnyYDDVQUY9BbhDlsu/1nbZibqqi52BTkJkSA0FN0dTrdSiL/Z2a6CJ/0nuVGKT4MBkRta7TCyak1gNBhU6P7GzbH3osyR2NbEx2akRZddXzoVPxIWM+TklDh5v2Xh92q4UxJfqpThtNnsFrJknxYTLUhUuvu+Y0h400h/izk0cvJ6IoSjhgWO+7ZqPuuBJNXbubjlA/EL0uXEbfNScadRM3ujgLl12f/Ysi3itjCkgpPkyEXvuBHI/RFXsiOdLWS6dbf/1AolHHr8cbwOM3ZoGjRKLXfiDRyHo7p0bPRf5U8mWqrUQv6LkfSDSyyunJUSe90cX6Xbiy0+yomYcyZuBE9B5sCjJu53QYYROXGxrDYwY9OtPn7CYZFNt13A8kGpktcXLUhUuvx2YAVqslPIZy53wiO3Rc2VQlOl06KDvbnsD2unZAv8GmILNdJDpi2xFxwUwZpt9JD2TMwKnYGhrDycNyNbbk1IQnPjmGJ7C1Tr0O9TuGqngMKtDllYHf0XR7/OwLlVWfXKHfMTR6xpIUHyZia53YNet50oNIzIDMljiRbeoYVuhbQObKKqf90t7rCxf5m6zjMUxz2Eh3hDrbdssxjGZHfQeKAmU5abqsTqti9Ng5u9YGDJZAIIDPZ8wvPVEcbetkWLaNySXpuN1urc0BwOFwYLP1jRY3ek+CRNHa7eVIWy8Ak3S8cEG0y1cKyGi2h8RjZX56OK5Cr+RlOOhtD9DW66WaDK3N0Q1bDeJBVjOWen0B3L4AaQ59ZuWcDMOJD0VRaGhooK2tTWtTdEUgqHDPWXlYgDRPC/v3t2ptUpi8vDzKysrC/RFkml//bAu560cWZZKd5tDYmlOTb3CXb6JQx3CKjt31KrnpDurb3XITcByqB1nPRy4A2S47Vos4Ouvo9UnxkWhU4VFSUkJGRoYuG/5oQZfbRyCzF6fNxkidZLooikJPTw9NTU0AlJeXA8Z3FyaKrUfUSU/fOy6IZEsYNdI+URhl1wwy9upkRMZQ3+JDDfw+1uOjrddHSU6a1ibFhKHERyAQCAuPwsJCrc3RFe1esNgDZGU4SUvTzx9heno6AE1NTZSUlGCz2QxfHCdRGCFQUUVmLPVPeNdsgDGUsVcn4vYF2N3UBRhFQDqF+DDgdWiogFM1xiMjQ55PHk+vVxR70qPrTR0vdfyMniKWKMLZSjp394LcNfdHj9fP3qOhhctIYyivwzA1DZ0EggqFmU7KDOBJyDVw7JWhxIeKPGo5EbdPiI90p/7Ex/HjJQscnUiH28cBA2RJqMjuxCeiZkmU5rh0nSWhImOvTkT1Pk4elmuIdcbImwBDig9JX/yBIN5AEIB0h/6HNLrAkexsK9gWivcYlpdOfqa+syQgKm5HHp2FUWN2jOD1ABl71R+RMdT/BgAiXcKNuAnQ/0olOS29Ia+Hy27FZtX/kKriwxdQ6PHK3iAQlSVhgHNmiNo1G3DSSxRGKRCnImOvTmSbgeKuIHLsYsTAb/2vVCbh1ltvxWKxYLFYcDqdjBkzhu9///v4/UOvLqiKDz3Ge/RHusOGM1T+3YjuwkSw1UDxHiAbk/XHVoMUiFORsVd98QWC4a7ghrkO5bGLZCBcccUV1NfXs3v3bh544AG++93v8uMf//iE53m9salYNdhUj/Ee/WGxWKJ2zsZT7InAKNVpVdS4nU6PH1/oyC+VcfsC7G4MLVwGGUMZ89GX3Y1deANBstPsVBWka23OgDDyJkCKjyTicrkoKytj+PDhfOUrX2HevHm8/PLL3HrrrVx77bX86Ec/oqKigvHjxwNQW1vLDTfcQF5eHgUFBVxzzTUcOHAAgA8++ACHw0FDQ0Mk2NRh47777uP888/X6iMOGCNfNPEmOktiskGOXXLSIln6HXLxYldjJ/6gQkGmk/Jc/WdJQCReQG4ABFujCsQZIdgUooP3jTeGhqrz0R+KooSPHZJNusM2pD/S9PR0WlpaAFi2bBk5OTm8/fbbgEhLnT9/PnPnzuXDDz/Ebrfzwx/+kCuuuILNmzdzwQUXMGrUKP705z+z4PN3AGAnyDPPPMOjjz469A+XYIzsLow3apZESbaLkmxjLFx2m5XsNDudbj9tvT4Ks/Sf3ZFIogvEGWfhihy7KIpiGLsThVEac0Zj5Ngrw4uPXl+ASQ+9pcl7b//+fDKcsX+FiqKwbNky3nrrLe6++26OHj1KZmYmf/jDH3A6hZL9y1/+QjAY5A9/+EN4UnjqqafIy8vj/fff5/LLL+f222/nj089zYLP34HTZuWN11/D7XZzww03xPVzJoJcGWkfJhxhbxB3vUpehkOIDzmGhioQp6Jmu/iDCt3eAFkuwy8HQ8JoR59g7LgdeeySRF599VWysrJIS0tjwYIFfPazn+W73/0uAFOnTg0LD4BNmzaxZ88esrOzycrKIisri4KCAtxuN3v37gVEEOu+vXvYvH4NaQ4bTz/9NDfccAOZmfoor34qIp4P47kL480WNUvCIIGKKuE0PzmGkUwXA41hmsOK0x4K/E7xoxd/IBhuCqj3ni7R5IWvQeOJD8NL3XSHje3fn6/Ze8fCxRdfzBNPPIHT6aSiogK7PfL1Hy8Yurq6mDVrFs8888wJr1NcXAxASUkJl85fwEt/e4YzJo3jjTfe4P3334/9g2iAjPmIsKm2DYDplXma2hEr4TS/FG/J7vYF2FEvFi4jjaHFYiEv3UFTp4e2Hh+V+VpbpB27m7ro9QXIdtkZVaT/zZuKOo92hQK/HTbj+BMMLz4sFsugjj60IDMzkzFjxgzouTNnzuT555+npKSEnJyT76au+9zNfO2r/8mE0SMYPXo05557brzMTSiytLOgy+NnTyjYdFqVcXZcICvVquyo78AXECW5K/ONkSWhkpchxIcRd87xRN0ATK3MxWo1TuxLTnqk+3V7r48iA8VeGUcmpRg33XQTRUVFXHPNNXz44Yfs37+f999/n3vuuYfDhw8DwlV41vkXk5mVzY8fXsptt92msdUDJ1d2RQVgy+F2FEVUNjVKsKlKnoH7SsSTsOeqKs9wQZtq3EeqX4ebDrcBYgyNhM1qCWeeGW0jF5P4WLp0KWeeeSbZ2dmUlJRw7bXXUlNT0+c5brebRYsWUVhYSFZWFgsXLqSxsTGuRqcCGRkZfPDBB1RXV3P99dczceJEbr/9dtxud9gT0uMLYLVauf6zNxEIBLj55ps1tnrgFIZKiLd2y0kPYLrBvB5AuAx8qo/h5sMi3sNIRy4qqgfymMEWrnizsda4Y2jU6zCm84rly5ezaNEizjzzTPx+P9/61re4/PLL2b59ezhm4f777+e1117jhRdeIDc3l7vuuovrr7+ejz76KCEfwCg8/fTTMT9WVlbGn/70p5P+P7W4WEtTA1deeSXl5eVDMTGpFBj0gok3Ro33ACkgVTYaWECqKdItXR6NLdGOXm+AXaECcWcYzPMB4jo82NJDa7exxjAm8fHmm2/2+ffTTz9NSUkJ69at44ILLqC9vZ0nn3ySZ599lksuuQQQ6aETJ05k1apVnH322fGzXEJDcysbNm7iX//4Gy+//LLW5sREUZZYuJpTeNKDvi57o6EKyJYUFh/tvT72He0GpIA0Ktvq2gkEFUpzXJQZpEBcNAWZIQFpsDEcUqRme7twVRUUFACwbt06fD4f8+bNCz9nwoQJVFdXs3Llyn7Fh8fjweOJLEAdHR1DMSllUBSFL3/hs2zZsI7bv3wHl112mdYmxYR6wXS4jRelHS+aOtzUtbuxWIxVW0ClMCQgU3nXvCV05FJdkGGIbsTHExaQXcZauOLJxtAGYJoBxSNECUiDjeGgxUcwGOS+++7j3HPPZcqUKQA0NDTgdDrJy8vr89zS0lIaGhr6fZ2lS5fyve99b7BmpCy+QJA//O0VLBYLk8uNU1tAJS/dgdUCQQWOdXspyTHejmOobAotXGNLsgxZ4KkwJCBTedds1EBFlbCANJjLPp6o16ERj1wgegyNdR0Oeru5aNEitm7dynPPPTckA5YsWUJ7e3v4p7a2dkivlyqorejTHFZDpYapWK2W8K6r2WCKPV5sVhcug+641PE71uPDn6LN5TaGY3aM57kCKSDB2HFXYNzjz0Ftt+666y5effVVPvjgAyorK8P3l5WV4fV6aWtr6+P9aGxspKysrN/XcrlcuFzGyU3WC2qwaYbDeDtmlYJMJ81d3pSd+DYaON4DID/DgcUCiiIESHF2al3HiqKEx9Dwu+YU3QC0dns51NoDiBofRqTIoEHDMXk+FEXhrrvu4sUXX+Tdd99l5MiRfR6fNWsWDoeDZcuWhe+rqanh0KFDzJ07Nz4WS4CI5yPdGVuVVT0RUezGumjigaIo4R2XURcuu80arvWRigKyocPN0U4PNqvFUCW5oynMjNT5CAQVja1JPuqx2ajizHDFXqNh1MzBmLbNixYt4tlnn+Vf//oX2dnZ4TiO3Nxc0tPTyc3N5fbbb2fx4sUUFBSQk5PD3Xffzdy5c2WmSxyJ7uSbYWDxoab5Ge2iiQcHWnrocPtx2q2ML8vW2pxBU5jl4liPL7TrMu7nGAyqeBxfmm3YTYAaJBtURLG4VOtOHN4AGPTIBSLeK6MdX8ckPp544gkALrrooj73P/XUU9x6660A/OxnP8NqtbJw4UI8Hg/z58/n17/+dVyMlQjcviBBRcFmseCyGzdLpDCFI+031h4DRCMyI2f6GPW8OR5sCB+bGdPrAeCwWclNd9De66O1O/XEh9GPPiESt3Osx0swqBgmBjAm8aEop3fLpaWl8fjjj/P4448P2ijJqenx+gFx5GK0cs7RpPLCte6gEB+zqo3dzSsiIFPv6Gx9aAxnmmAM23t9NHd5GVuqtTXJIxhUwmM4a7hxx1CdRwNBhQ63L9xzSe8Yd8uVwqjxHpn9pGfeeuutXHvttUm2aHBEjl1Sb+Fad7ANMPakBxGXb6odnXn9wXCKphxDY7L3aBcdbj/pDhsTDHz06bRbyQ71dzHS0YsUH0ni1ltvxWKxYLFYcDqdjBkzhu9///v4/f6YX6s75PmIR7zHgQMHsFgsbNy4ccivFSupeuzS6fZR0yCK6c00+MJl1OqKQ2VbXTtef5D8DAcjDdSCvT8iAYuptQlQvY9nVOVhN/DRJ0QyXowkII39jRuMK664gvr6enbv3s0DDzzAd7/7XX784x+f8Dyv9+R/QL5AEK9f1FTQW7CpzxdbcyqjRmkPlU217QQVqMxPp9TgxdVSVUCui3LXG/noEyIC0ki75nigjuHM4XnaGhIHCgx4/CnFRxJxuVyUlZUxfPhwvvKVrzBv3jxefvnl8FHJj370IyoqKhg/fjwAtbW13HDDDeTl5VFQUMA111zDjt17AUhz2EBRWLx4MXl5eRQWFvKNb3zjhLicN998k/POOy/8nE9/+tPs3bs3/LiaLj1jxgwsFks4mDgYDPL973+fyspKXC4XZ5xxRp/ePqrH5Pnnn+fCCy8kLS2NZ555Jqbvo8iglfmGyjoTnDOrpKrLfv0hdeEy/hgWpegYrjtkouvQgPFzxhcfigLebm1+BhCAeyrS09PDXo5ly5ZRU1PD22+/zauvvorP52P+/PlkZ2fz4Ycf8tFHH5GVlcX1V38an9dLhtPGT37yE55++mn++Mc/smLFClpbW3nxxRf7vEd3dzeLFy9m7dq1LFu2DKvVynXXXUcwKLwnn3zyCQDvvPMO9fX1/POf/wTg5z//OT/5yU/43//9XzZv3sz8+fO5+uqr2b17d5/X/+Y3v8m9997Ljh07mD9/fkyfX91xtff68KVQhUwzTXrhKrUp5LJXFMU0AcOQmvV2Wru94YaAM6qMP4ZGLBZn3PKYKr4e+H8V2rz3t+rAGft5r6IoLFu2jLfeeou7776bo0ePkpmZyR/+8AecTvFH9Je//IVgMMgf/vCHsFv3qaeeIjcvjzUrV/DZaz/NY489xpIlS7j++usB+M1vfsNbb73V570WLlzY599//OMfKS4uZvv27UyZMoXi4mIACgsL+1Sh/d///V8efPBBPve5zwHwyCOP8N577/HYY4/1yWS67777wu8fK6nY3yUYVNhgkiwJSM3y3Efaemns8GC3WgzbjCyawnCFzNQZww2hDcDo4kxDNgQ8HiPG7Rjf82EgXn31VbKyskhLS2PBggV89rOf5bvf/S4AU6dODQsPgE2bNrFnzx6ys7PJysoiKyuLgoICPG43hw/ux9fbRX19PXPmzAn/H7vdzuzZs/u85+7du7nxxhsZNWoUOTk5jBgxAoBDhw6d1M6Ojg7q6uo499xz+9x/7rnnsmPHjj73Hf9+sRDd38VI7sKhsLupi06PnwynsSPsVdQdV1sK9XdRvR6TK3IMW1wsmsIUjL0y09EnRDYBzQYaQ+N7PhwZwgOh1XvHwMUXX8wTTzyB0+mkoqICuz3y9Wdm9vWgdHV1MWvWrD5xFD1eP4daeiguLsY5wOJiV111FcOHD+f3v/89FRUVBINBpkyZcsqg1lg43u5YUfu7pMquy0wR9gD5Gc5wf5fWHi8l2eb3XoXre5hk4Uq1DQCYUHyocTsGmkeNLz4slkEdfWhBZmYmY8aMGdBzZ86cyfPPP09JSQk5OTkAHO30QE4vOWkO8vIyKS8vZ/Xq1VxwwQUA+P1+1q1bx8yZMwFoaWmhpqaG3//+95x//vkArFixos/7qN6WQCAQvi8nJ4eKigo++ugjLrzwwvD9H330EWedddYgP33/pNp5s9kmPZvVQn6Gk9Zu0SAwFcSHmWJ2ILJwqf1dbAapkDlYfIFguKeLacbQgMefxt96mZSbbrqJoqIirrnmGj788EP279/Psnff5eGHHqStuR6Ae++9l4cffpiXXnqJnTt38tWvfpW2trbwa+Tn51NYWMjvfvc79uzZw7vvvsvixYv7vE9JSQnp6em8+eabNDY20t4uCid9/etf55FHHuH555+npqaGb37zm2zcuJF77703rp8z1fq7mClLQqUghdJtuz1+dtR3AuZZuPJDFTFFd2Lzj+GO+g7cviC56Q5GFWVpbU5cMOImTooPnZKRkcEHH3xAdXU1119/PRMnTuQb934Vr8dDaWEBAA888ABf/OIXueWWW5g7dy7Z2dlcd9114dewWq0899xzrFu3jilTpnD//fefUFfEbrfzi1/8gt/+9rdUVFRwzTXXAHDPPfewePFiHnjgAaZOncqbb77Jyy+/zNixY+P6OYvUbAkD5acPlqOdHvY3d2OxwEwTRNirFKbQGG6sbSMQVKjITaM8N11rc+KCw2YlL0N0dE0FAbnmgBrwnWeYPiinIzpd2ijdiY1/7GIQnn766ZgfKysr409/+hMAbl+AXY2dWC0WSovEMYzdbuexxx7jscceO+lrz5s3j+3bt/e57/haIF/60pf40pe+1Oc+q9XKd77zHb7zne/0+7ojRowYUK+f01GcLTwfTR3mX7g+2d8KiC6ouRnGbN/dH+oYHu00/xiu3tcCwFkjCzS2JL4UZ7lo6/FxtNNj6C7LAyEyhoUaWxI/CjJF7FVQEQJEvSb1jPR8GIRuT6SZnNXgFRWjUWMEmlJg4fpkv5j0zh5lnkkPUmsMV4cE5ByzjWFOaBPQ6dbYksQSDCqsOaCOoXkEpN1mDXsgjTKGUnwYhO5TNJMzMsXhSS91Fi6z7ZrDC1eHMSa9weL2BdgQasE+x2xjmCICcndTF8d6fKQ7bEwdlqu1OXGl2GBjKMWHAVAUJez5yDJBXYFoSsIue3MvXMe6vexsEIGKphMf2akhIDfVtuH1BynKchm+mdzxlKTI8efqkPdx1vB8HCZIdY8mPJcaZAzN9e2bFG8giC8QxGKxkOE0l+dD3XG1dHtNXaRKdfWOKckKd6A0C6mya/5kf8Rdb/RmcscTjr0y+SYgfGxmsg0ARG8CjDGGUnwYgG6POHJJd9hME52tUpjpxGa1oCjm7qpp6kkvJzUCTtUxPNuUY2h+AakoCqv3mfPoEwh3yDbKGBpSfMQjy8JIhI9cXMY8cjnVeFmtlnCamFEU+2BQ3b1mnPTUHVd7rw+3L3CaZxsTXyAYLhBnpiwJlZIUyFja39xNc5cHp93K9Ko8rc2JO5HYK2OMoaHEh8Mh0hN7eno0tiS5dHuF+DBqsKk6Xur4HU/YbW+QiyZWOtw+ttd1AObLdAHITXeEy/2bdfHacqSdXl+A/AwHY0vMUZgqmkjMh5k3AMLrcUZVHmkOY27kToXRjl0MtZrZbDby8vJoamoCRCEus529Ho/XH8TjdmPBgjXox+02zs5SURR6enpoamoiLy8Pm63/C97sAYvrDhwjqMCIwoywa9RMWCwWirNcHGnrpanTQ1VBbD2PjEC0u95sR58QOXbp9gbo9vgNu9E5FWp9DzMem4Hxsl0M9xemtn1XBYjZ6fH6ae324bRbONRrzIUrLy8vPG79YfYaA2ZNsY2mJEeID7NmLX2y33yFqaLJctnJcNro8QZo6vQw0mTiQ1GUqOvQnGMYvYlTFEX3G3PD/YVZLBbKy8spKSnB5/NpbU7C+d+3anhjaxM3zK7iv6aN1NqcmHE4HCf1eKgYTbHHysrQjmuOSSc9MLf3yhcIhktymzFgWKUk28WBlh6aOtymSyU+1NpDfbsbh83CzOF5WpuTENSMJa8/SEevX/dVlA0nPlRsNttpFzWjoygKb+xo4UhngOkjiklLM6bn43SYucZAe4+PLaEOmueOKdLWmARi5ridzYfb6PL4yc9wMKk8R2tzEkZJdpoQHyYUkCv2NAMwozrfdOUKVNIcNnLTHbT3+mjqdOtefBgq4DTVONTaw5G2Xhw2i7ld9iYuNLZyXwtBBUYXZ1KWa07xCMYLdouFFbuF5+qc0UWmjPdQMXO14Y9C4uM8E28AwFgeSCk+dIyq1meaVa23H4aOOlPXGDD1pKco0LIXelqj4nbMO4am9FwFA9C0AzxdphWQgaDCx3uFgDTlGPrc0LgdAj5Dxc+ZcEUzD6ZeuN5/GN5fCsDoOQ8Aszja6SEYVEy1u/xor0kXrmAA/vEl2PZPsLmYctZSoMx0xy7dHj/rD4l4D9Ndh+52+PO1ULce0guYOvGngNMw5bkHyva6Dtp6fGS57EyvNFc/F1r2wp+vgfZaKBjNuIKlfIQxjj+l50On9FHrY0026R34KCw8ALJX/4QLrJvwBxWO9ZinymldWy/7jnZjtcDZo00WbLruKSE8AAIeJn6yhGpLo+k8H5/sb8UfVKgqSKe60GQpxO98VwgPgN5Wrtj+TdJxm24MVQ/y2aMKsZusnwv/WiSEB0DrXm5tXAoohhhDk42EeVDVerbLzjSTdV/k41+I25k3w5w7AVjsfAkwl9te9VxNr8ojJ03fwV8xEQzAR6ExvOJhGHkB1oCH/7K9Sku3x1Q9ekzrfexqgg3PiN9vfB5yq8lwN/Aftg8M4bKPhY/3qmNosg3AoVVwaCXYnHDbG2BPY3jnBs6y7DTEPCrFh04Jq/XRJlPrrftg15uABc69D867H2xOzqCGGZbdhrhoBoppF65db0HbQUgvgJm3wIXfBGCh7QPylA5aus3jvVph1niP9X+CgAeGzYZx8+HcewC43fYGRzt6NTYufrh9gXBDQNON4erfitvpn4Ph58D0GwH4kv11Q1SqNdGqZi5Mu3DVvCluR54PhaMhuwwmXw/AtbYVNLSbY+JTFIUVe0wa5Fbzmrid9llwZoiJr2wqaRYfC2xrqG/X/8Q3EI52etjZ0AnAXLOVxa95Q9zOvBksFjjj8yjObEZYGxnp3mGaHj3rDx7D4w9Sku1ijJnK4vu9sOcd8fvMW8Tt2V8B4BLrBrrbj2pk2MCR4kOHuH0BPjlgUrW++9/idtwVkfumCPFxhW0NdcfM0bdnV2MXzV0e0h02ZlTnaW1O/FAU2P22+H3c5eLWYgkLyCusn1DfZg4BqbrrJ5XnUJjl0tiaONLVBEfWid/HzRe3zkwYvwAIXYcmGcMVUZs4vVf8jIlDK8HTAZnFUDFT3Fc8Hm/RJOyWIJM7P9Z9A1YpPnTImgOteP1BSnNcjC42UaVBbzcc/Ej8PvbyyP2jLsJjy6TU0oajfp02tsWZD3eLnceZIwtw2U1UDK9hM3Q1giMThp8buX/i1QDMtW6nudkcrQ8+3K0euZjM67Fnmbgtny48jyEsE68CzCUgI2Nosk3cntAGYOzlYI0s49ZJ4jqcx2rdH39K8aFD3q8RC9eF44rNpdaPrIeAF3KGQeGYyP12F0dLzgOgrGWVRsbFl+gxNBWHVovb4eeAPcobUDSG5rThOCwBXIc/0sa2OBIMKuExvGh8icbWxJna0DU28sK+94+5FD92qq1Haa/blXy74kxTp5stR9oBuMCs1+HIC/rcbZ94JQDnWLdR19qRbKtiImbx8cEHH3DVVVdRUVGBxWLhpZde6vP4rbfeisVi6fNzxRVX9P9ikn55r0bsHC8226R3eI24rTxTuOqj8FQL8TG22/iej26Pn9WhRmQXjzfZpBc9hsdxtPhsAEqajS8gt9d30NzlIdNpY/aIfK3NiS+H14rbqrP63u/M5FDmZABchz5MslHx54NdwusxdVhuuO+JKfB7oH6T+P3467B0Kh2WHDItHnr2rUm+bTEQs/jo7u5m+vTpPP744yd9zhVXXEF9fX34569//euQjEwlDrZ0s+9oN3arxXz1PdRz5srZJzzkGnsJABP9O1E8Xcm0Ku58tKcZX0BheGGG6Rp0RcTHiWPYW3k+AKO71ifTooTw3k6xATh3TJG5js08ndC0Xfw+7MQxbCwUArLYBAIysokz2QagYYvIVMoohIJRfR+zWtmTeQYADp0LyJgrnC5YsIAFCxac8jkul+uULdQlJ0d19c4ekW+u2hCKcspdc1H1BI4ohQyztNC5dxXZk+Yl2cD48V5oDC8eX2KuY7PuFji2X/w+bNYJD7vGXEDwYwtVgVoR1JhlXM9deOGaYNzP0C91G0AJQm4V5JSf8HBv5blw6LcM79wgrlmD/v36A0E+2BU6NjPbGJ7CgwzQUDAHuj6g4OjqJBsWGwmJ+Xj//fcpKSlh/PjxfOUrX6GlpeWkz/V4PHR0dPT5SWVMe+TS1SQCFbFA2bQTHk5z2tlmHQ9Azz7j7roUReH90BheZLYdV+MWcVswGtLzTni4tLSM3cowAHyHPkmiYfGltdvLhto2wIRjWL9Z3Fac0e/DacNn41Vs5AbbRC0Xg7L+UBudbtGJeHplntbmxJeGreK2/Ix+H3aXi81dWfcOURBQp8RdfFxxxRX8+c9/ZtmyZTzyyCMsX76cBQsWEAj0/yUsXbqU3Nzc8E9VVVW8TTIMvd4AK0Ml1U234zq6Q9wWjBS1IfrhQPoUAKxH1ibLqrhT09hJfbubNIeVs81WG6IpNIYlE/t9uDDTyWbGAtBrYAH54e6jKApMKMumPDdda3Pii3rkUjql34fLCnLZrowAQKnVd8zAqVA3cReOK8Zmol5RADSGxEfp5H4fTh82mW7FRXqwB47WJNGw2Ii7+Pjc5z7H1VdfzdSpU7n22mt59dVXWbNmDe+//36/z1+yZAnt7e3hn9ra2nibZBhW7WvB4w8yLC+dsWYqiANRC9ekkz6lJU94RLKbNwqXrwF5b6dw9Z4zuog0h4liBSCycJ1kDC0WCwfSxGPKYeMKSDXew3QbAIgsXCcZw4q8dDYERSaa96BxvVemHcNgAI7uFL+fRHxU5GexKTha/OOwfgVkwlNtR40aRVFREXv27On3cZfLRU5OTp+fVOW9KHe9qWIFICI+iiec9Cn+kil4FDtpvmOR2AKDYdogN4Cm0KRXcvIxbMoVAjLj6CZdu3xPRiCosHxXJGbHVAT8kZ3wSRauNIeNXQ7h2Qoa1PNR397LzoZOLBY4f6zJrsPWfeB3gyMD8kf0+5TyvDQ2KEJABnW8CUi4+Dh8+DAtLS2Ul58Y3CSJoCgK7+40abwHRNT6SVz2AKUFOWwLuXzR8UVzMtp7faw7KNqvm642hKJEjeHJvVfBorF0Kuk4Aj2R5xuIjbVtHOvxkZ1mZ6aZKtPCcQvXyJM+rSlXHMm4mreKtE6DoXofz6jKoyDTqbE1cUb1XBVPAGv/ntWiTBdbGAeAX8exVzGLj66uLjZu3MjGjRsB2L9/Pxs3buTQoUN0dXXx9a9/nVWrVnHgwAGWLVvGNddcw5gxY5g/f368bTcVOxs6OXysF6fdyjlmq6ioKFG75pOLj6r8DDYERcxAOC3XQLy3s4lAUGFsSRZVBSZrv95xRJRzttpFwOlJqCzIZlMwlP6nY5fvyXh7eyMgYgVM1dARoDlUOKxobJ+qmMfjKBjJUSUHa9AXCVA1EP/e3gDApWY7cgFo3i1uT+FBtlotNOYIAeloqRHp1Tok5qtr7dq1zJgxgxkzZgCwePFiZsyYwUMPPYTNZmPz5s1cffXVjBs3jttvv51Zs2bx4Ycf4nKZqMhLAnhrm7hgLhhbRIYz5gxofdNRB552sNj6VjY9jqqCDLYGR4h/GHDSU8dw/mQTppmrx2aFY8F+8t1kdUEGW5WQ+GjYkgTD4oeiKPzbzGPYEjr6Lhx7yqdVF2WyNRjyjDQY6zrsdPv4ONTQ0dRjWHTyeRQgu2gYdUoBFhRo3JYEw2In5lXuoosuOmXDmrfeemtIBqUqb20TO67LzXjBqJkuhWP6luQ+jurCjPCxi9KwGUsweModmp5w+wLhGi2mnPTCAcMn33EBDC/MYHlwuPiHwcTHnqYu9jV347RZzZdiC1Hi49QLV1VBBtuV4VzMJsOJj/drjuINBBlVlGmuLrYqAxzD6oJ0tu8fToWtVWzkqs9OgnGxYYyZ3eTUtvawo74DqwXmTSzV2pz406yq9VPvuHLSHLSmVeNRHFi8XdB2IPG2xYkVu5vp9QWoyE1jyjATBk23hNy9ReNO+bTqggy2KUJ8KA1bDRV0qnquzhlTSLaZCvyptOwVt6dduDLYrnogDSYg1TG8bHKp+YL2FSUG8RHZyOlVQErxoQPUC+askQXmC5ACEegGUHjyWAGVYYU57FRCtV4MNPGpY3j55DLzTXoAraHso1PEewAUZ7uosw2jV3Fi8XVH/p8B+Hco3sOUniuIWrhOPYbDowVk4zaRJWMAPH6Tex97WsAtGuWdUFb9OKoLMtke9kBK8SE5Cf/eZvJJT02bPc0FAyGXr3rRGCTuwx8I8s4O9djMhJ4riBIfJ8+SAFHro7Igi51KtbijYVOCDYsPdW29bD7cjsWs3kd3B3SLbLrTiY+KvHRqKaVLScPid0dEi875eG8LXR4/JdkuzjBbVVOIjENuFThOXfyuOnR0Bogj04AvwcbFjhQfGtPc5WHNwVbApPEeEPF8DEB89HUXGsPzsebAMY71+MjPcHDWiAKtzYk/vl7oOCx+H+gYGizuQw00nT0831wdUFVaQ0cumSWQlnvKpzrtVspyM9kRFpDGGsPLJ5diNVtVUxiw5wpE/FytUkKHkg4Bry4rnUrxoTHvbG9EUUTb52F5JivlDOLM/1ioR8QpaguoVEd7PnTqLjwe9cjl0oml5kvPhMj4uXJEJ83TUF2YES7RbRTv1Vtm9z4OMN5DZXihsa7DQFAJp0mbdwwHFu8BkOWyU5jpYoei302ACWdKY/FmOLXPhK5egPbDEPSBzQU5w0779OqCDHYq1QSxQGc9dB1NgpGDJxg0eXomRI7N8kcMqMtpXwGpv0nveFq6PHxyIOR9nGTSMYxh1wyq236E+IcBxMfaA600d3nJTrMzZ6TJ6iSpxCA+4LgjbB1eh1J8aEhrt5cVu5sBuHKqSSvAqkcu+SMGlDZbVZBBD2kcUEKLgM4nvnWHjlHX7ibbZef8sUVam5MYYjg2A1VAVhHAKuIMOhsSaNzQeX1LPYGgwrTKXKoLTVYcTmWAMTsqVccfnem819LLm+oAuGJyGU67SZe1YwfE7QA8yKD/jBeTjpIxeH1LPf6gwpRhOYwqNmFOOkQtXAO7YCry0nHarWwPhs6bdVogR+XljWLSu3xymfkayanEKD5GFGXixsUBJSSo1ZLQOuWVTfUAXDWtQmNLEkjbIXGbN3xATx9ZlMlupVIIyJ4W4YXUKb5AkDe2CoF71fRUGMPqAT19RGEGO3QsIKX40BBVrV9t5gsmhkwXAJvVwqiiTGqCoXRbtbiVDvEHgry+RUzKV59h4jFsjW0MqwsysFstbFfHUMcCsq6tl08OtGKxwKenm9T7CNAe6hY+wIVrdHEWHpzsJ/R33bg9QYYNnY/2NNPa7aUoy8k5o0165OLuiKTZ5lUN6L+MLslij1IhBKS7TXcCUooPjahv72VN6Jz502becakL1wBdhSAmvl1qrY8m/S5cH+1toaXbS2Gmk3PNOulBzN4rh83K8MIMdhjAe/XqZrEBOHNEAeW5Jgz4BlGno0N8TnIHtnANL8zAaoHtAVVA6td7pW7irpxabs6Ab4iIx/R8cGUP6L+oAvIQIVHdpC8BadKR0j+vbqpHUeCsEQVUmDHLRaVNzXQZMeD/Mro4M1Jo7GiNbqtkqkcupp70gkERNAwDdtmDmPjCtT50LD7CRy5m9j52HAElADYnZA0ssD3NYaOqIIOdOheQbl8gXCfJ1GOoHrkMUDwCjCrOBGBboFLcoTPvlUlnTP3zSmjHdZWZXb0Abaq7d+AXzeiSLA4pJXhwiRbgOqySKSa9FDhn7moQ2UoWG2QP/G91dElWZOE6WgN+b4IMHDz7m7vZcqQdm9XClVNMmuUCkV1zbmVMvZJGF2dFan3oVHy8X9NEl8dPRW4as6rztTYncbTFdmwGkOG0MywvnV06PcKW4kMD9jd3s/lwaNIza5YLiFbO7jbxewyKfXRxFgpW9hBS7DpzF4JoYNXp8VOem8bs4Skw6eUMA9vA+1COLs6ijkK6LZlCvKjt3HWE6rk6d0wRhVkmLCymoo5hDNcghDyQqoBs1qeAVI9cPj29wpyFxVTaYws2VRlVnEmNTo+wpfjQgBc3HAFSaNJLzwfXwLN5RhaF3IV+/YqPFzeIo4irTD/pxe65ArFwgYVd6HPnrChKeAxNHfANUVkSsY5hFvUU0G3JgqBfCBAd0d7j450domR86oxhbOJjTEkWNUpoHtXZEbYUH0kmGFT4xzox6S2cefqiW4ZmEOeUAJkuOxW5aZGLRmcLV0uXh2WhSW/hzEqNrUkwgxzD0aF25pt96hjqK2Bx7cFjHGjpIcNpY4GZj1wgatc88JgdINSSXr8C8uXNdXj9QSaUZTO5woSdpKMZtPcqi0NKKR6L/o6wpfhIMh/vbeFIWy/ZaXbzVsRUiTG9L5rRJVnUqOfNOjurfHHDEfyholTjywYWeW5YBun5yElzUJLt0m3Q6d/WiM/1qanlZLoGfpxkSIawcAFs0qmAfGGt+Fz/MavSnJ2koxmC9yqIlX0W/R29SPGRZF5YJy6Ya86oMG9RKpVB7poBJpbnUBMMTXqte0VzMx2gKAp/D3muPjM79s9lOAa5cAFMKM/RZbZEt8fPa6H6LKkxhoNbuPIznZTm6FNA1jR0svlwO3arhetmmNyD7O2BHlEJO9aNnLo52urVX70WKT6SSHuvjzdDlfg+MysFJr1B7poBJpXncJQ8Oiw5oAR1E7C49UgHOxs6cdqtXG3m+iwqg1y4ACaWZ0eC3boaoLs5joYNnte31NPjDTCiMIMzR5g4WBhEqnSHiDEbjAdyok4FpOr1uGRCibnj5iCS6u7MhrS8mP5rQaaT8ty0SOkCHcXPSfGRRF7ZVIfHH2RcaRbTKk/d1toUDGHXPLE8B7BEvB86Ueyq52r+5DJyMxwaW5NgFCVKQMYWLwBCQPaQRr1NLbOuj8XrhSjPlend9V2NoqW6xQbZsYvlSeU57FIqRaPHrkZdNHr0BYK8tFEIqhtSynNVPaDGjscjxlCKj5RFURT++on4I/rMrBSY9GBIno9RxZk47dZIgRwdXDQ9Xn84U+kzs0weaArQ0wq+HvH7ADoSH48aBLjVr5/A4T1NnXyyvxWrBa43e8A3RBaunIqYUqVVJlUIAdmgCkgdxAz8e1sjzV1eirJcXDi+WGtzEk/74L2PIMZwp1rro3Wfbo6wpfhIEhtq29hW14HTbuU/UmHh8rnFTgkGtWt22KyMK83SlWL/18Y6Ot1+hhdmcN4Yk3awjUad9LJKwZEW838fUZiJy25lW7hEt/YL119Wic906cRS85ZTj2YIQd8gds2gLwH5f6sOAHDjWVU4zFpZOJoheJChnyPso/pImU6BkdMHf1kpyoxfNa2C/EynxtYkAfWc2ZEp6nwMgknlUYpd42MXRVH4v9AYfmHOcHPX9lAZ4qRnt1mZUJYd1eNF22yJbo8/nOb+xbNjF8SGZAhB3wDDCzPJcNrY5teHgNzd2MmqfcJzdeNZgxNUhmOQNT5UJlWII+wdwZCnTwcbOZDiIym0dHl4dbOIrr95bqpMeqGeLnlVgzqnBBH3sVut9dFZB73H4mRc7Kw/1Mb2+g5cdiufmZ0CnisY0rGZysTynEi2xNGdosmZRry08QidHj8jUsVzBUMeQ5vVwoSy7KiMF20F5F9WiXnlskml5u6JFc0Qx7AqP4Msl50dAf14kUGKj6Twt7WH8QaCTKvMZXpVntbmJIch7poBplXm0UkG9YQWCg3rffzfygOAqKSYl5ECniuI2xgeUkpwh4sc7YuTcbHRx3N1dop4rmDIu2aAqcNyIz1emrQTkF0eP/9YLzyqXzx7hCY2aMIQx9BqtTC5IiroVCfB+1J8JBh/IBhW619IFVcvxGXXPGVYDk6bNdLWWyPF3tTp5vUtIkX6i6niuYIhxwsAzByeh4KVmvDxmTY750/2t7KzoZM0hzU10txV4iAgZw7Pp1Ypppc0CHhE3R0N+Of6w3R5/IwqyuSc0YWa2JB0/F7oFHMPuUO5DvMjR9jS85EavL61gSNtvRRmOs3ffyCaOEx6LruNycNEqh+gmWL/08cH8AaCzKzOY1plniY2aMIQ4wUAxpZkk+Wyax50+rsPhMfluhmV5k+RVumTKj2Ehas6HwVrVPxV8gVkIKjwhw9FafBbzhmROp6rjsOAAvZ0yBz8UeHM6vzIPNpZLzLZNEaKjwSiKAq/+0DsEm6eO8L8FU2jicOkB+Ki2alhS+gujz/srr/jgtFJf39NicMY2qwWzqjK07RK5u7GTpbtbMJigS+fPzLp768ZPS2RVOncwccpVeanU5TlYkdQOwH55tYGDrX2kJfhSJ2YK+hb5G8I5RlmVufRRQaHFe2PsFWk+EggH+9tYeuRDtIc1tRy10NcPB+gKvaovgSKMkTDYuP5NbV0uP2MLMrkskmlSX1vTfF0RgJ8h3B0BmLi07JK5u8/FF6PyyaWMqp44N2VDY+6cGWVgX3wVUAtFgszq/MicR9JHsM+m7izh5PhNHkvnmjiNI8WZrkYUZgROf7UwdGLFB8J5LchV+8Ns6soSIX0WpWAP6qk89AumhnVeexVKvArVnC3C5dhkvAFgvxxhXD1fun8kdhSxdULkUkvLQ9cQ2ueN2N4fqS8c/shMY5JoqnDzUsb6gD4rwtHJe19dUEc4q5URMyANuJj9f5WNh1ux2W3cvM5I5L63poTJw8yiI1cjY7qJknxkSC2Hmnng11HsVrgS+el2KTXWQ9KAKwOsesaAhV56ZQW5LBfUUt0J++ieXljXTheZ+HMFHL1QlwXrlnD8+myZHFECQUJJnEM/7BiP95AkFnD85k1vCBp76sL4pDponLmiILIwtVeC71tQ37NgfLr94XXY+GsSorM3sfleIbQW+l4zhxZoKt2FVJ8JIjH3hGN0K6aXkF1YYbG1iSZcKDiMLAO/U/svDFFUYo9ObsuXyDIL97dDcDt549MrXgdiBrDoS9cOWkOpldFH70kJ2CxqdPNn0Mp0osuTrF4HYibyx5gemUuSlpuVMxAchavtQda+WDXUWxWC/91QYpt4iBqDId+HYp5VLyO0rQ96UfYxyPFRwLYWNvGOzuasFrg3kvHam1O8mmP36QHcN6Y4ohiT1Kg1Ivrj3CwpYfCTCe3zB2RlPfUFXH0fACcP6YocvSSJLf9b97fh9sX5IyqPC4eX5KU99QVcRxDu83K3FGFURkvyRnDn4U2cZ+ZVcnwwsykvKeuiKP3qqogA3/+aPyKFYunI3I0rhFSfCSAn70tLpjrZlSmVoCbSlv8zikBzhldSA1i0vPVJ37X7PVHvB53XjiaTFcKBbipxHHXDHDe2OKw50NJwsLV2OHmL6tFltLiy8alRiPH4wkvXPEJdj9/bFFSK52u3tfCR3tacNgsLLp4TMLfT3fEMXZOZc7YcvZpcITdH1J8xJm1B1pZHnIT3nNpCl4wEGlIFqeFKz/TiVI8GQBrc03CKyy+sK6Ww8d6KcpypVZhuGji7PmYUZ3HQbtIcw02boNgMC6vezJ+9e4evP4gs4fnc/7YFCmlfjwJFJCBBG8CFEXhJ6FN3A2zq6gqSLGja4hr7JzK+WOLdBN0GrP4+OCDD7jqqquoqKjAYrHw0ksv9XlcURQeeughysvLSU9PZ968eezevTte9uqaYFDhB6+JY4GUdRNC3D0fABMnTaVLScMW9EJL4v6eOt2+sOdq0cWjSXemWKyHSpwXLofNSsXoqXgUOzZfd6T3TwLY09TJs58IAbz48hT1erjbwRPKKoqTgBxRmEFH7jgAlMbtCRWQb21r5JP9rTjt1tT0ekBU3FVlXGLnAM4ZU8TukPjoPLQpLq85WGL+RN3d3UyfPp3HH3+838cfffRRfvGLX/Cb3/yG1atXk5mZyfz583G73UM2Vu+8tPEIm2rbyHTaWHz5OK3N0Y4475oBrpxWEXb5ums3xu11j+dX7+6hucvLqKJMbpqTol4Pvwe6QiWd4yggr5hWGW4UqCTQbf+DV3cQCCrMm1jKOaNT3OuRUQjO+GyCLBYLk6fOxKM4sAd6oO1AXF73eDz+AP/vdbGJu+P8UanTQO54EjCP5qQ5sJZNAcBdp22H4pjFx4IFC/jhD3/Iddddd8JjiqLw2GOP8T//8z9cc801TJs2jT//+c/U1dWd4CHRgvZjrezetjYhr93t8fPImzsBWHTJGEqy0xLyPronGIz7rhlgfGk2h10iY+Hwzk/i9rrRHGju5o8fiboe//PpiTjtKXoq2S7azuPIEItXnLhkYgm7EGKmZd+GuL1uNO/VNLF811EcNgv//amJCXkPQxCH0vj9sWBaJbsU0Zrdc3hzXF9b5amPDnCotYeSbBdfuSgFs5RUEuBBBhg5aTb1SgF7vfmaZrzEdXbdv38/DQ0NzJs3L3xfbm4uc+bMYeXKlf3+H4/HQ0dHR5+fRLBl5Vvk/nwkWX//PEoCvvBfv7+Hxg4PVQXp/Oe5KVTC+Xi6j4rmUxYr5AyL28taLBbSK6cDiZn0FEXhh69txxdQOH9sUWpmR6hER9jH8cgiJ82Br3ASAMcSID48/gA/eFWcY996zghGFqXosSckZNcMosPtIYdIeT20I/6bgMYON796dw8A37hiQmoGe6uoR5NxSLONZu6Zs7ko8Gt+X/FDvAGTiI+GBuGqLS3tW4a6tLQ0/NjxLF26lNzc3PBPVVViOk6OmDgTgHKlkfW74nvevKO+g98uF9VM//vKialXEyIadeHKLgd7fKu6jpt+DgAlPbtp7vLE9bVf39LAOzuasFstPPTpSakZJ6CSoF0zQOWE2QC4WsXRSDx5/N097DvaTVGWk7tTMcU9mjhnuqhYLBZcw6YC0H5gY1xfW1EUvv3SVro8fqZX5XH9jPhtXgxJHKubRlOSncbGhy7nyVvP1NS7q7lfecmSJbS3t4d/amtrE/I+2XnFHHOIiOGPP3ovbq/rDwR58B+b8QcV5k8uZf7k+EQlG5b2+OWlH8/ISbMJYqHY0s6rH8cvWOpYt5fvvCxiEL568RjGlg6tnLjhiWNtgeOZNed8ACqDDazYHr9NwI76jnAlzO9fM4WctBTpXHsyEiggp8wQm4Cinj3UtvbE7XXf2NrAv7c3YrdaeGTh1NTpXHsy4ljd9Hj0EEgfV/FRViYW3sbGxj73NzY2hh87HpfLRU5OTp+fRGGtmAZA+/71tHZ74/Kav/1gH5sPt5OdZucH10xJ7R0zJHThwplJV6bYyW1et4JgHHbOiqLw0MvbaO7yMrYkKzUrYR5PgnZcAGl5ZXTZC7BaFD786MO4vKbHH+BrL2wKbwAWTEnxDQAk7NgFoGyc8F6NsDTy95U74/KaTZ1uHvpXaANw0WgmlCVuHTAE0bFziZhLdUBcxcfIkSMpKytj2bJl4fs6OjpYvXo1c+fOjedbDYrcEeLoZbxygKdCgYVDYe2BVn4aSsv8zlWTKclJ0SDTaBIQbBpNRpWI+yju2sUbW/s/youF59fU8sqmOmxWC4/+xzRcdu13BJqTwB0XgLVcRNt3HdrErsbOIb/e0td3sq2ug/wMh9wAqCRy4coswp1WDMC6NR/T6fYN6eWCQYXFz2+iucvLhLJsFl2Soqm10XQ1QNAHVjtkV2htTUKIWXx0dXWxceNGNm7cCIgg040bN3Lo0CEsFgv33XcfP/zhD3n55ZfZsmULN998MxUVFVx77bVxNn0QlAvPx2TrQZ766ABtPYP3frR2e7nnrxsIBBWuPaOChTNT/HxSJZGeD8BeIc6bJ1oP8tg7u4YUN7CzoYPvviLSzb52+XhmVOfHxUbDk6B4AZWMSnEdjrcc4ufvDK1my1vbGnj64wMA/OSG6XIDAODtgZ5m8XuCNgGu0HVY5dvHUx8dGNJrPf7eHlbsaSbdYeNXn58hNwAQuQZzhoHNnEG3MYuPtWvXMmPGDGbMmAHA4sWLmTFjBg899BAA3/jGN7j77ru54447OPPMM+nq6uLNN98kLU0Hk0KZuGDGWo/g8bjDZ8Sx4vYFuOPPa6lrdzOyKJMfXjdV7rZUEiw+KBVjOMVWy+6mLl7aMLj+BE2dbm5/ei1uX5ALxhWnZtOq/vB7RWVFSOAYCs/HJOtBXttSz5bD7YN6ma1H2rnvuY0A3H7eSC6ZUHrq/5AqqNegKwfS8xLyFpZQrYgJlkP8/sN9tAwyAPz1LfXhSqbfu2YyY0pSPN5KJdHzqA6IWXxcdNFFKIpyws/TTz8NiGjo73//+zQ0NOB2u3nnnXcYN04nBbdyqyAtDwd+xloO8+SK/Ww9EtvEFwgqfP3vm1l78BjZaXZ+f/MsslI5HSwaRUlovAAQFpCjqMOFlx+9viPmia/L4+fLf1rLkbZeRhVl8ovPnSGD21Q6joASBHsaZBYn5j1KRan8qfbDgMKD/9iMLxBbtczDx3q4/U9r6PUFOG9MEd9cMCEBhhqUBHuugLCAnJlWR6fbz/deib1U97qDrdz//EZApEbfMDsxXhpDoqbZJnIMNUbzbJekYrGEF68bq9sJBBW+9sImerwD6xXiCwS5//mNvLKpDrvVwm+/MEsq9Wh6WsAXin7PrUzMe+RUQHo+VgJcVnSM1m4v33pxy4CDT9t7fHzxydVsOtxOfoaDP956JnkZ8U0JNjTRHYkT5c0rHg8WGxnBLsand7C9voNfLhv48cuB5m4++9tVNHZ4GFuSxa+/MBOHLbWmslOiLlz5iRQfQkBOstZitSi8vKmO1zbXD/i/r9rXws1PfoLHH+TSCSV8+9OTEmWpMTmmig/p+TAPZeK8eWFFK4WZTnY2dHLPXzfiP83Oq73Hxx1/XsvLIeHxixtncM6YFC3dfDLUHVdWGdhdiXkPiyW863pwph+71cJb2xr58b9rTvtfa1t7+OzvVrLhUBt5GQ7+9J9nMSKVC1H1R4KDTQHxt1EkvKHfPUvc9Yt39/DihsOn/a/rDh7jht+uFF6r4kz+7/Y5Mq32eNqSsHAVjQOrHZu3g6+dLTp3L/7bRtYfOnba//rq5jpufeoTur0Bzh1TyC8/PwOb9Dz2Rb0OEykgNSYFxYfwfGS0bud3N8/Cabfyzo5Gbnt6zUkDUN+vaeLTv/qQ92qO4rRb+d3Ns7hyankyrTYGyTqnDI1hlXsPDy8UYvKJ9/fy3y9uweMPnPD0YFDh+TWH+NQvPmRnQydFWU6eu+NsplXmJdZOI5KsMQztnOdm1fPl80VF4Af+tonff7CvXy+W2xfg5+/s5rO/XUlTp4cJZdk8f8dcynJ1EEumN44lwWVvd0LReAD+a1wPl0woweMP8sU/rObNk2ShHev28s1/bOauZzfg9gW5ZEIJT95yJhlOeWx9AikQ85F6ox5auGjYwqzqfH79+Znc/dcNfLi7mQt//D43zanmrJEFOO1Wdjd28a+NR1h/qA2AqoJ0fv35WUytzNXOfj2TrAum/AxxW7+R/7iykrYeEfvxzOpDvF9zlC+cPZxplbkEggpbjrTzj/WH2Xe0GxCt3R///MzUbVZ1OpJVW6B0Mmz9OzRuY8n1i+l0+3luTS0/en0HL244wufOqmJsSTa9Pj9rDhzj7+sOc7RTxPZ8amo5j/zHNBlrdTKSKSCbtmE7up1f3ngfX/rTWlbua+HOv6zj/LFFXHPGMCrz02nr8bJiTzMvbaijyyOOuO+8cDRfu3wcdnlcdiLBQKS/khQfJqJ4PNic4OmAYweYN2kkf//KXO5/fiO7Grv49ft7T8iCsVst3HbuCO65dCzZ0sV7chJY2KgPFSLTivrNEPDzpfNHMbIokyX/3MKRtt5wg79osl127p03llvOGSHjA05FuDJmohcucXRG4zasVgtLr5/KpIocHn2zhu31HTz0rxM7bpblpPHfn5rIp6eVy+yyU5GMmA8Q4mML0LiNTJedP99+Fj9+q4Y/rtjPh7ub+XB38wn/ZWJ5Dt+7ejJnjSxIrG1GpjO6xod5PeypJz5sDnHR1G0QPwUjmVyRyxv3XsDrW+p5e3sjOxs68AcVqgsymDOykIWzhqVul9pYSNaOq3AMOLPB2wnNNVA6mUsnlvLBN4p4Yd1hltcc5UCL8HSMLcnivNAuTO6UB0CSj11o3g0+NxZHGjfPHcFV0yp49pNDfLy3mfp2N06blYnlOVw6sYTLJ5WlbqfhgeLugN5Q3EXCxzAiIAEcNivfunIiX5gznL+uOcQn+1tp6/GS4bQzZVgun5pazrljCqVwPB3hDUAlWM1b8yQ1Z+OKmUJ4HFkHU64HwGa1cNX0Cq6abs5qckkhXN00wZOe1QoVZ8CBD8U4hhayNIeNL549nC+ebd4grYQS8ItUW0i89yqnAtLywN0mBGS5qFybn+lk0cVjWHSxrHI5KNSFK70AXAnOxFMFZIsQkDjEBq26MIMHr5Cpz4MmGQHDOiA1txHDZonbuvi39U5ZFCW5QVIVZ4jbI+sT/16pQmcdKAGwOkTGUiKJSnunPn5NAlOeZGZJZJcJkaMEw94PSRxIRp0WHZCi4kP0eKFuowjukQyd3mPiGAQSv2uGSNyHFJDxoy0qZseahKlBHcMj6xL/XqlCMnfNFktkLpVjGD9SoMAYpKr4KBoHzizwdcPR09eHkAwA9YLJLAFHEjJJ1IWrcasoCS4ZOsdCzRaT5e6tFN1ROSwXrriRjDTbaIaFxvDI2uS8XyqQAmm2kKriw2qLpGvWSbd9XGjdJ24LktQjJX+kiBkIeKEp9tLOkn5oDYmPZI2hunA1bQNvd3Le0+wke+EKC0gpPuKGFB8mJ+wulOIjLiRbfFgsUUcvcgzjQrLHMHeYSCVUguIIVDJ0wmm2I5Lzfmr8XOte6GlNznuamRSp8QFSfMizynjRekDcFoxM3nvKuI/4oh675CdxDNXFS16HQ0dR4NgB8Xuyjl0yCiJiVW4Chk77YQj6RS2q7AQHfWtMCouP0KTXuE2kiUmGRrJ3zRAlIKX4iAtajGGljBmIG11N4O0CizW5PUGGydiduNEaKnCZP9LUNT4glcVHbhVkFIlKco1btbbG+KgLV1J3zVExA57O5L2vGelpBXe7+D1ZLnuQC1c8UReu3MrENXbsDykg40dLaAwLR2trRxJIXfEh08Tih7cbukLNpJJ57JJTLs5FlSAcXpO89zUjarBpdjk4M5L3vhVnABboOCzKSksGT9hzleSFS/UiH14rjn4kg0cL76NGpK74AKgM9fM+tEpbO4yOes6clifOgJNJ1dni9tDq5L6v2TiW5EwXFVc2lEwUv8tNwNDQatdcNlXEKPS2Rv6OJINDio8UoVpduFZJxT4UwimaSfR6qFTPEbe1UkAOCS2OzVTCO2fpvRoS6rFLsj0fdlekWq08Phsa8tglRRg2S3QO7KyLdGSVxI6Wal31fBxeK3qTSAaHlgKyKiQgpQdyaLSErkMtFq6wF3ll8t/bLAT8ES9ysgWkBqS2+HBmhBtayYlvCGgpPkomgitHRPk3yf4SgyY8hhqIj+HniNsj62Tm2WBRFG2vQ3UMpfgYPO21IgHC5oKcYVpbk3BSW3wAVM8Vt/KiGTxa1IdQsdqg8kzxu4z7GDxaxXyo75lZIqrVyloRg6OzQbSLsFi16Qmiio+m7bLY2GAJH5uNTE5vJY0x/yc8HdFxH5LBoXWQlDqGMu5jcHi6oKtR/K6FgLRYIovXwY+S//5mQF248qrB7kz++2cWiZ5ZIOfSwRI++jT/kQtI8RE5b27aITqzSmLD1xvphqqV+AjHDEjPx6Bo2SNuMwohPU8bG8LiQ3ogB4VWabbRhI9ePtbOBiMTDjY1f6YLSPEBWSWhC1aRin0wtOwBFEjLFd+lFlTOBotN1IpQu3pKBk7zLnFbNF47G9SFq3a1DBweDHrIkqhWBaQUH4NCq2wljZDiA2Dk+eJ2/wfa2mFEohcui0UbG5yZkXRNOYaxc7RG3BaP086GkkngyhWBw41btLPDqOhh4VIFZP0mcZQniQ1VQKZAjQ+Q4kMw8kJxu2+5tnYYkaMh8aHlwgUwKjSG++UYxkxzSHxo6fmw2iKxO/s/1M4Oo6Jeh0VjtLMhrwpyq0VjNBl/FRs+dyTou1jD6zCJSPEBMPICcdu0DbqOamuL0dDDwgURAbn/A1kwLlaad4tbvQjIfe9raobh8Hsjno/iidraMio0l8oxjI2WPaJNRFouZJVqbU1SkOIDRKR26RTx+wHpto+J8MKlsfioPBPsaSJrQz1GkJyegD/i7i3SWnxcLG4PfizrfcRC617hbXDlQE6FtraoY7j3fU3NMBxHd4rb4gnaHV8nGSk+VOTRS+wEAxHxofXC5UiLctvLMRwwx/aLwkaODMip1NaWkomQVQb+XhF4KhkYTTvEbbGGcVcqoy4St41boKtJU1MMRbT4SBGk+FCRMQOx03YQAh7hccir1toaKSAHg+olKhqrfWEjiyWyeO17T1NTDEU4YFgHsQKZRVA2Tfwuj14GjhQfKczwc0Sfl2MHIjnzklOjTnqFY0XAoNaoAvLAhxDwaWuLUVAnPa09VyqjVbe9FB8D5qjq+dA43kNFFZByDAdOU+g6LJHiI/VwZUdKre/6t7a2GIXGreK2dJK2dqiUz4CMIvB0yHL5AyU8hpO1tUNFXbjqN0F3i6amGIaw50MnC5cqIPe9J4O/B4LfE9nw6mUMk4AUH9GMmy9ud72prR1GoSFUj0EN1tUaqxXGXi5+3/WWtrYYhQZVfEzV1g6V7LLQ35MCe5dpbY3+8XsjFWr1smuungv2dOisj4hbyclp2QNKQNS5yS7X2pqkEXfx8d3vfheLxdLnZ8IEnVwUp2NsSHwc/EgWyRkI6sJVphPxATBOio8B4+2JpGjqagyvELc1r2trhxForglluuTqpxOqIz3i/ah5Q1tbjIC6iSuZqH3AcBJJiOdj8uTJ1NfXh39WrFiRiLeJP0VjRWOtgFcGS50Ob3fEVagXzwfA6EtE7E7L7kgKqaR/mnaI2gIZRfqqLTD+SnG7+x2xs5ecnPpN4rZ8mr4WrvELxK0UkKenfrO4LZ+urR1JJiHiw263U1ZWFv4pKipKxNvEH4tFHr0MlMbtgCJaoWvV06U/0nIjsTu7ZezOKVHLmJdN0dfCVTFDiCFvJxw0yMZFK9SFS80w0QvjrgAsULcBOuq1tkbfNKjiQ2djmGASIj52795NRUUFo0aN4qabbuLQoUMnfa7H46Gjo6PPj6ZExwwEA9raomeiFy69Id32AyMc76GzMbRao8ZQuu1PiV4XrqwS0fAR5EbuVChKZAz1JiATTNzFx5w5c3j66ad58803eeKJJ9i/fz/nn38+nZ2d/T5/6dKl5Obmhn+qqqribVJsjDhf7J67m2TGxKnQ68IFMPHT4vbAClku/1SowYBlOgk2jUY9eql5Q2ZMnIxgMBIvoMeFSx69nJ62g+BuB6sjpTJdIAHiY8GCBXzmM59h2rRpzJ8/n9dff522tjb+9re/9fv8JUuW0N7eHv6pra2Nt0mxYXfChKvE79te1NYWPVO3Xtzq8Zwyf4Rw3StB2PmK1tbok4A/Kl5Ah2M46kJwZEJ7LRxZp7U1+qR1n+gCbE/TT52WaMZ/StzufQ96j2lri15Rj81KJoq1J4VIeKptXl4e48aNY8+ePf0+7nK5yMnJ6fOjOZOvFbfbX5ZHL/3hc0c8H6prVW9MulbcSgHZP0d3gK8HnNn6XLgc6TAh5P3Y8ndtbdErDSHxWDoZbHZtbemPkglQMlmU798hNwH9otdjsySQcPHR1dXF3r17KS83UP7yyAshLU8evZyMhi1iQskogrzhWlvTP6qAlEcv/aN6E4bN0Ed12v6YslDcbntRbgL643BoDMvP0NSMUzLlenG79R/a2qFXDq8RtxUztLVDA+IuPr72ta+xfPlyDhw4wMcff8x1112HzWbjxhtvjPdbJQ67EyaE4ga2/lNbW/TIkbXitnK2vrIkook+etnxL62t0R+HQ2M4TKeeK4DRl4r4q64G0elW0he1+V7VHG3tOBWqgNz/AXQ2amuL3ggGIgJSz2OYIOIuPg4fPsyNN97I+PHjueGGGygsLGTVqlUUFxfH+60Sy9TQRbP1H7K99/EYYeECmPIf4nbjX7W1Q4+ong+9HpuB2ARMvFr8vlUevfTB547E7FSdpa0tp6JgJAybJTYB21/S2hp90bRDpJM7s6BEJy0qkkjcxcdzzz1HXV0dHo+Hw4cP89xzzzF69Oh4v03iGXmhqBjobpPR2scT9nzM0taO0zHtBrDYhL1q4yYJeDojbdiH6XwMp4YE5NYXRUVWiaB+ozj6zCwRXj49o24CNslNQB9Uz1XlbP0efSYQ2dvlZFhtcMbnxe8b/qKtLXqis0F0/sUCFTO1tubUZJVE6kVslGMYpvYTQIHcatFLRc+MuADyqsHTDjte1toa/RA+cjlLv0efKtNuEKmkdRsi2R2S0HVISh65gBQfp0YVH3vfhfYj2tqiFw6EKk6WTYX0PE1NGRAzbhK3m56HgE9bW/SCOoYjztPWjoFgtcKMm8Xv6/6krS16wkgLV2ZRpPbOejmGYcKeDx0fmyUQKT5ORcEoGH4uoEjvh8r+D8TtyAu0tWOgjL0cMotF5pKstCg48KG4NYL4ACEgLVY49DEc3aW1NdoTDEQEpNpKQO/MvEXcbn5BHp8BtNXCsf3iWLjqTK2t0QQpPk7HrNvE7donZZMriFq4ztfWjoFic8CML4jfV/9WW1v0gKcTjoQKxI00yBjmVEQ6Tq97WlNTdEH9RhGL5so1TormyAtFWr6nHbbJDMJw49Jhs0RGVwoixcfpmHQNZJdDV6MsWNV+RFRVtFhhuEF2XABnflnsMA58KM+cD60CJSAWgrxqra0ZOGfeLm7X/1mUo05l9r4nbkeer8/iYv1htcLs/xS/r3xclszfFxrDURdpaoaWSPFxOuzOyMS36tepfdHsXy5uy88wllrPHRYpOpbq3g91x2UUr4fKmHmi94W3U8Z+qGNotIVr1q0irbRpO+xdprU12hEMwr7QXGq0MYwjUnwMhFm3gc0l3J2pXOxI7TA69jJt7RgMc74ibrf8TWTspCrqGI4x2BhaLDD3LvH76t+kbvCwtycSqDjqYm1tiZX0PJgZCh7++FeamqIpjVuhp1n0LqpMzXgPkOJjYGQWRTJflj+irS1a4feIrB+IpK8aiaozoepsCHjho59rbY02NO+G1r0i7XH0JVpbEzvTbhB1LTqOwJYXtLZGG/a+K/6G86qh0ID1k+bcKY5A970HdRu1tkYb1A3AyAtSrplcNFJ8DJTzF4tJe//y1PR+HPhQdNDMLtd3L4lTcdGD4nbtH1PT+6FOeiPOgzQdNHCMFbsL5n5V/L78kdT0fqgN2iZcpf/6Hv2RPzxScv29H2lri1aonbbV9OMURYqPgZJXHcmaeP9hbW3RAnXhGjdfBI8ZkVEXi7oIfndqej/UMRx/pbZ2DIWz7hDej2MHUi/93e+FXaExnHiVtrYMhYu+Kbwfu/8Nh1ZrbU1yOXZANOa02GDcAq2t0RSDriIaEe392POO1tYkj4APtr0kfp9gYLVusYiJD2DNk6FKrSlC++FIh+bxBp70nJlw/gPi9+WPplbfpQMfikyfzBJ993M5HYWjI8X/3v1BagXx73hV3I44FzILtbVFY6T4iIW8arHzAnjzW6nj9t37ngiQyigyfnT2qIvFZwh44N/f1tqa5LHlBUCB4edBXpXW1gyNWbdCTiV01sHHv9DamuSxJdRcb8KnjN8L5IJvgM0pBNXO17S2Jnls+Zu4VRsmpjBSfMTKhd+AjEJorhGxA6nA5ufE7dT/EEW7jIzFAvOXilolO16OVGw1M4oiysuDCNo0Oo40uPz74vcPfwLHDmprTzJwd0S6wqrB70YmrwrOuVv8/uaS1Kh6Wr9ZdCK2OSNxLymMFB+xkp4HF/+3+P3dH0JHnabmJJzeY7Az1NV32me1tSVelE6C2aHaLa89YH7Xfd0GOLpDpItPukZra+LD5OtFlV2/WyxeZmfbi+DrgaJx5knPPP8B4cFqPwQrfqq1NYln4zPidvyVkFGgrS06QIqPwTDrVlEW19MBr9xn7jPL9f8H/l4omWycUs4D4eJvQVYpNO+C5SYPIP7kd+J20tXGaAY4ECwWuPLHYLVDzWuRIwkzoiiRhmwzvmDMLJf+cGbCFUvF7yt+JrwCZsXbA5tD3scZX9TWFp0gxcdgsNrgml8L99nut2DTX7W2KDEE/PDJ78XvZ99pnkkPxM7j0z8Tv3/0czi8Tlt7EkVXE2z9h/hdLbRmFkomwgVfF7+/tti8XshDK+HIOuG5mn6j1tbEl4lXifiHoB/+eYd5vZAbnxFe5LzhMNpgxeEShBQfg6VkQiRz4rWvwdEabe1JBDteFi7RjEKY+hmtrYk/Ez4lPpcShL/fKiYHs/HJ70VRqsozoXKW1tbEn/MfEB45dzu8eKfo+Go21LTwMz4PWSXa2hJvLBaxCcgsgaM74e2HtLYo/gT8sDJU0XXuXcYPFo4TUnwMhXPvE1XqfN3wt5vB2621RfEjGIjUMznzy+BI19aeRHHl/4rdSNuh0OIV1Nqi+NHdIvoRQaQ0udmwOeC634EjQ6TAL/u+1hbFl/rNsOtNwBIJ0DQbmUVwTWhx/uS3keBos7D1HyKtP70gkmIskeJjSFhtsPBJyCoTqv0fXzbPzmvz30RGT3p+pKqkGUnPgxv+LFzau94UdQfMwkc/E1Vpy6aZO7WveFxk8froMfPEfygK/Pt/xO9TFhqznPpAGTc/coT2yj1wZL229sQLnzsyp8xdJOJcJIAUH0MnqySyeNW8Bm98w/gBqO6OyA7yvPuN1cF2MFScAVc9Jn5f8VNY8wctrYkPR3dFOvhe+pBxq9IOlCkL4Zx7xO8v3hlpO29kdr8tvDk2J1yaAjVpLloCYy8XGUzPfAaa92ht0dBZ/QS010LOMDjbxJu4QWDyGSlJVM+Bhb8HLGLheuc7xhYg7/5AFHAqGBUpqmZ2zvg8XPQt8ftrX4ONz2prz1BQFHj1fhHrMeYy0Y4+FZj3XZFKHPTBczfBoVVaWzR43B0iiBZgzn9B/ghNzUkKqie5fLooavh/1xq7CnHLXng/1Ij0km+DM0Nbe3SGFB/xYtI1IvUPRIDYm0uMGT+w991Ihsunf2beWI/+uPAbcOaXAAVe+qpxi8it/g0cXCHiID71E3NlKZ0Kqw2u/72oYOvrhj9fC7v+rbVVg+Pf/y12zHnD4cJvam1N8kjLgZv+AYVjxOd/cj40btPaqtgJ+MUc4u8Vf4/TP6e1RbpDio94ctaX4VOhYjmrn4C/fRE8XdraFAvtR0TcCgrMus34pdRjxWIRAahn/RcQ8h78+3+MFcdTuyYSJ3DZ90UX0VTC7oLP/TXkvu+F5240XgO69f8H6/8sfr/mcXBlaWtPsskqhltehZJJ0NUATy2A/R9qbVVsvP0Q1K4CZxZc/cvU2QDEgBQf8ebM28Xuy+aCna/CH+dD826trTo9vW3wzH8Id2fZVLjC5IW3TobFAgsegQsfFP/++Jfw189Bd7O2dg2Elr3C1qAfJl0b8uKkIM4M+NyzIo066Id/LYKX7zFGDYl9yyPHLRf/N4w8X1t7tCKnHG57XXShdrfDn68WhciM4E1e+0dY9bj4/donRE8wyQlI8ZEIpt0At74mctcbt8JvzofVv9PvhdPdAn+5Hpq2i6qfn/2L6J+RqlgsogLqf/wR7Gmi9fevz46UmdcjR2vgz9cI8Vg+XWR/pPJuS03BvehbgEVUCP3DpaJYl17Z/6EQjwGvKL51/te0tkhb0vPhiy+JwmpKEN75LjyzUN+9fNb/GV4NiccLviGqCkv6xaIo+oqM7OjoIDc3l/b2dnJycrQ2Z2h01MFLX4F974t/V54lPAp6KvbUsldElrfuhbQ8IZrKpmhtlX6o3ywqLx7dIf498Sq47AdQMFJbu6I5+DH89UZwt4mz8tveFK5riWDvu/CPL0FPi2goeNZ/hRpE6qi/xoa/iFYNQR+MvkQcHaXyBiAatbz8698Q3agdGcIzOedO/XxHwaAI1Fd71Jx1Byx4NOU2ALGs31J8JJpgMJIB4wt1bpx8HZy3GMqnaWeXosCG/4M3vimC83Kr4Qt/h+Lx2tmkV3xueO+HsPJxsQOzOUVMzDl3a9uePuCD5Y+Izq5KEIbNhs//DTILtbNJr3Q3iyBwtaW5K0ek5p71ZW373fQegze/BZtC2VWTroXrfpNagd4DpXkPvHKvCKYGkb564Tdg2ue0FSHHDsDLd0c6ZJ+3WGS3mD29vR+k+NAjHXWw7AeRSQbEDmfmzaLLod2VPFsOrQoFRK0W/x5+rkhxyylPng1GpHE7vPUt2BeqIWG1i8Vixhdg5IXJm2yCQRFP9M53hccKRMfhTz8m0/lOx55l8PZ3oHGL+LcjQ3x3Z94OpVOSt1P1e4SLfvmj0N0EWMRCeuE3U3LRGjDBoOil9d6PoOOIuC+jUDT7nHVrcuMrettE2fSVj4uNpT1d1AtK4cwWKT70TMMWETi17UWxWwVx3DHxKlHlb9RF4MqO//v2tkHNGyIY6vAn4j5HhuhPI/sNDBxFEcdoK34mCkCp5FaJMRx7uRBzdmf837uzAba9JDxpLaEg5owiuPJRUWRLMjCCQdj2T/jwp9AUlcZZNE6IyfELRFVYmz3+792yVzQZ2/CMyOQAKBwrslqq58T//cyKzy3mspWPQ8fhyP2VZ4qyB2MvF+MZbzGpKFC/UQjHzS+At1PcP/xckdVi5iq0A0CKDyPQuk+c8278qyjopWJ1iGyTytnCjV4yQRT7ikWQKIpw59ZvgsNr4dDHIpgt6Iu8xxk3imA86e0YPHUbxdHVlhdERL6KI1M0O6ucBcNmQdF4USQqFtewogixUbdBiMUDH8HhNUDocnXliOJT596bGLGaCiiKiJf55LdCmAe8kcec2TB8rojTKp0ExRPEGMYi0oNBcZ03bBJdk/e8I1oWqGSXwwVfgxk3J0aspgIBP9S8Dmt+H0rHjVrOMouFKKicLTogF0+EnIrYBEnAJ4K5GzYLj/GeZX3FTvEEkZU08aqUi+/oDyk+jEQwIM4Kd70Ju96CY/v7f15miRAK6QUiUM6RISZCi02ICncHeDrE2faxg+BpP/E1iifA5OuFezK7NKEfK6XwuUVGzO63REnsrsZ+nmQRZ9RZJWL80guEGLFYxRgGPJEx7GoS58hqjFA0lWeJFNIzbpSiI56426HmTdHJ+cCHfcWkitUB2WXiJ7NEHHE50kVafdAnjlJ8vWL8Ouugo16MazQWq/BuzrwldNwqRUfc6GyAHa+In9rVokz78djTQ2NYLmKjHBkio83mFGPl94oYuM5G6KwXP0H/ca+RJjpiz7wFRpwvj8mikOLDyBw7ILwVR9aJ5kote0T65GDIqxaLVeWZIr6keFxcTZX0QzAodreH18KRtcI70rpPiIpYsViF67jyTPEz9jKxc5MklmBApMgfWCGynZq2i93v8UJiINjToHSyOMYZdaEQHun5cTdZchx+j5g/D66Ahq2i8WfzblAGUTDQlSO80eVniHl0xLkyIPgkSPFhNtzt0Lpf7Kh6W6GnVeyKlaCYKK02cYGk5YiJLX+EKMssgw/1gaKIMTu2H7qPipTPnlaxmAWDoQwau4j9ScsVnpH8kSKORO6M9UEwIILGOxvEbrj7qPBy+HvFQmdzih97mkhzzq4QnsqcysTEjkhix+8VQarqGPa0CO+Ir1ccudlc4nqzpwvPsDqG2RXSuzFApPiQSCQSiUSSVGJZvxMm5x5//HFGjBhBWloac+bM4ZNPPknUW0kkEolEIjEQCREfzz//PIsXL+Y73/kO69evZ/r06cyfP5+mpqZEvJ1EIpFIJBIDkRDx8dOf/pQvf/nL3HbbbUyaNInf/OY3ZGRk8Mc/GrRFuUQikUgkkrgRd/Hh9XpZt24d8+bNi7yJ1cq8efNYuXLlCc/3eDx0dHT0+ZFIJBKJRGJe4i4+mpubCQQClJb2rSNRWlpKQ0PDCc9funQpubm54Z+qKg17ZUgkEolEIkk4mucPLVmyhPb29vBPbW2t1iZJJBKJRCJJIHFPQC8qKsJms9HY2LfKY2NjI2VlZSc83+Vy4XIlsamaRCKRSCQSTYm758PpdDJr1iyWLVsWvi8YDLJs2TLmzp0b77eTSCQSiURiMBJSem/x4sXccsstzJ49m7POOovHHnuM7u5ubrvttkS8nUQikUgkEgOREPHx2c9+lqNHj/LQQw/R0NDAGWecwZtvvnlCEKpEIpFIJJLUQ5ZXl0gkEolEMmR0UV5dIpFIJBKJpD+k+JBIJBKJRJJUdNfrWT0FkpVOJRKJRCIxDuq6PZBoDt2Jj87OTgBZ6VQikUgkEgPS2dlJbm7uKZ+ju4DTYDBIXV0d2dnZWCyWuL52R0cHVVVV1NbWpmQwa6p/fpDfQap/fpDfgfz8qf35IXHfgaIodHZ2UlFRgdV66qgO3Xk+rFYrlZWVCX2PnJyclP2jA/n5QX4Hqf75QX4H8vOn9ueHxHwHp/N4qMiAU4lEIpFIJElFig+JRCKRSCRJJaXEh8vl4jvf+U7KNrJL9c8P8jtI9c8P8juQnz+1Pz/o4zvQXcCpRCKRSCQSc5NSng+JRCKRSCTaI8WHRCKRSCSSpCLFh0QikUgkkqQixYdEIpFIJJKkkjLi4/HHH2fEiBGkpaUxZ84cPvnkE61NShhLly7lzDPPJDs7m5KSEq699lpqamr6PMftdrNo0SIKCwvJyspi4cKFNDY2amRxYnn44YexWCzcd9994fvM/vmPHDnCF77wBQoLC0lPT2fq1KmsXbs2/LiiKDz00EOUl5eTnp7OvHnz2L17t4YWx5dAIMC3v/1tRo4cSXp6OqNHj+YHP/hBn54TZvoOPvjgA6666ioqKiqwWCy89NJLfR4fyGdtbW3lpptuIicnh7y8PG6//Xa6urqS+CmGxqm+A5/Px4MPPsjUqVPJzMykoqKCm2++mbq6uj6vYeTv4HR/A9HceeedWCwWHnvssT73J/Pzp4T4eP7551m8eDHf+c53WL9+PdOnT2f+/Pk0NTVpbVpCWL58OYsWLWLVqlW8/fbb+Hw+Lr/8crq7u8PPuf/++3nllVd44YUXWL58OXV1dVx//fUaWp0Y1qxZw29/+1umTZvW534zf/5jx45x7rnn4nA4eOONN9i+fTs/+clPyM/PDz/n0Ucf5Re/+AW/+c1vWL16NZmZmcyfPx+3262h5fHjkUce4YknnuBXv/oVO3bs4JFHHuHRRx/ll7/8Zfg5ZvoOuru7mT59Oo8//ni/jw/ks950001s27aNt99+m1dffZUPPviAO+64I1kfYcic6jvo6elh/fr1fPvb32b9+vX885//pKamhquvvrrP84z8HZzub0DlxRdfZNWqVVRUVJzwWFI/v5ICnHXWWcqiRYvC/w4EAkpFRYWydOlSDa1KHk1NTQqgLF++XFEURWlra1McDofywgsvhJ+zY8cOBVBWrlyplZlxp7OzUxk7dqzy9ttvKxdeeKFy7733Kopi/s//4IMPKuedd95JHw8Gg0pZWZny4x//OHxfW1ub4nK5lL/+9a/JMDHhfOpTn1L+8z//s899119/vXLTTTcpimLu7wBQXnzxxfC/B/JZt2/frgDKmjVrws954403FIvFohw5ciRptseL47+D/vjkk08UQDl48KCiKOb6Dk72+Q8fPqwMGzZM2bp1qzJ8+HDlZz/7WfixZH9+03s+vF4v69atY968eeH7rFYr8+bNY+XKlRpaljza29sBKCgoAGDdunX4fL4+38mECROorq421XeyaNEiPvWpT/X5nGD+z//yyy8ze/ZsPvOZz1BSUsKMGTP4/e9/H358//79NDQ09Pn8ubm5zJkzxxSfH+Ccc85h2bJl7Nq1C4BNmzaxYsUKFixYAKTGd6AykM+6cuVK8vLymD17dvg58+bNw2q1snr16qTbnAza29uxWCzk5eUB5v8OgsEgX/ziF/n617/O5MmTT3g82Z9fd43l4k1zczOBQIDS0tI+95eWlrJz506NrEoewWCQ++67j3PPPZcpU6YA0NDQgNPpDF90KqWlpTQ0NGhgZfx57rnnWL9+PWvWrDnhMbN//n379vHEE0+wePFivvWtb7FmzRruuecenE4nt9xyS/gz9ndNmOHzA3zzm9+ko6ODCRMmYLPZCAQC/OhHP+Kmm24CSInvQGUgn7WhoYGSkpI+j9vtdgoKCkz3fYCI+XrwwQe58cYbw43VzP4dPPLII9jtdu65555+H0/25ze9+Eh1Fi1axNatW1mxYoXWpiSN2tpa7r33Xt5++23S0tK0NifpBINBZs+ezf/7f/8PgBkzZrB161Z+85vfcMstt2hsXXL429/+xjPPPMOzzz7L5MmT2bhxI/fddx8VFRUp8x1I+sfn83HDDTegKApPPPGE1uYkhXXr1vHzn/+c9evXY7FYtDYHSIGA06KiImw22wmZDI2NjZSVlWlkVXK46667ePXVV3nvvfeorKwM319WVobX66Wtra3P883ynaxbt46mpiZmzpyJ3W7HbrezfPlyfvGLX2C32yktLTX15y8vL2fSpEl97ps4cSKHDh0CCH9GM18TX//61/nmN7/J5z73OaZOncoXv/hF7r//fpYuXQqkxnegMpDPWlZWdkIAvt/vp7W11VTfhyo8Dh48yNtvv92nnbyZv4MPP/yQpqYmqqurw3PiwYMHeeCBBxgxYgSQ/M9vevHhdDqZNWsWy5YtC98XDAZZtmwZc+fO1dCyxKEoCnfddRcvvvgi7777LiNHjuzz+KxZs3A4HH2+k5qaGg4dOmSK7+TSSy9ly5YtbNy4Mfwze/ZsbrrppvDvZv7855577gmp1bt27WL48OEAjBw5krKysj6fv6Ojg9WrV5vi84PIbrBa+05vNpuNYDAIpMZ3oDKQzzp37lza2tpYt25d+DnvvvsuwWCQOXPmJN3mRKAKj927d/POO+9QWFjY53Ezfwdf/OIX2bx5c585saKigq9//eu89dZbgAafP+4hrDrkueeeU1wul/L0008r27dvV+644w4lLy9PaWho0Nq0hPCVr3xFyc3NVd5//32lvr4+/NPT0xN+zp133qlUV1cr7777rrJ27Vpl7ty5yty5czW0OrFEZ7soirk//yeffKLY7XblRz/6kbJ7927lmWeeUTIyMpS//OUv4ec8/PDDSl5envKvf/1L2bx5s3LNNdcoI0eOVHp7ezW0PH7ccsstyrBhw5RXX31V2b9/v/LPf/5TKSoqUr7xjW+En2Om76Czs1PZsGGDsmHDBgVQfvrTnyobNmwIZ3IM5LNeccUVyowZM5TVq1crK1asUMaOHavceOONWn2kmDnVd+D1epWrr75aqaysVDZu3NhnXvR4POHXMPJ3cLq/geM5PttFUZL7+VNCfCiKovzyl79UqqurFafTqZx11lnKqlWrtDYpYQD9/jz11FPh5/T29ipf/epXlfz8fCUjI0O57rrrlPr6eu2MTjDHiw+zf/5XXnlFmTJliuJyuZQJEyYov/vd7/o8HgwGlW9/+9tKaWmp4nK5lEsvvVSpqanRyNr409HRodx7771KdXW1kpaWpowaNUr57//+7z4LjZm+g/fee6/fa/6WW25RFGVgn7WlpUW58cYblaysLCUnJ0e57bbblM7OTg0+zeA41Xewf//+k86L7733Xvg1jPwdnO5v4Hj6Ex/J/PwWRYkq+SeRSCQSiUSSYEwf8yGRSCQSiURfSPEhkUgkEokkqUjxIZFIJBKJJKlI8SGRSCQSiSSpSPEhkUgkEokkqUjxIZFIJBKJJKlI8SGRSCQSiSSpSPEhkUgkEokkqUjxIZFIJBKJJKlI8SGRSCQSiSSpSPEhkUgkEokkqUjxIZFIJBKJJKn8fz9Z9IrPVEeNAAAAAElFTkSuQmCC\n", @@ -84,7 +74,8 @@ "source": [ "plt.plot(sol.ts, sol.ys[0], label=\"Prey\")\n", "plt.plot(sol.ts, sol.ys[1], label=\"Predator\")\n", - "plt.legend()" + "plt.legend()\n", + "plt.show()" ] } ], diff --git a/examples/neural_ode.ipynb b/examples/neural_ode.ipynb index 356b5aca..49ef3c5e 100644 --- a/examples/neural_ode.ipynb +++ b/examples/neural_ode.ipynb @@ -309,6 +309,7 @@ "source": [ "Some notes on speed:\n", "The hyperparameters for the above example haven't really been optimised. Try experimenting with them to see how much faster you can make this example run. There's lots of things you can try tweaking:\n", + "\n", "- The size of the neural network.\n", "- The numerical solver.\n", "- The step size controller, including both its step size and its tolerances.\n", @@ -317,8 +318,9 @@ "- ... etc.!\n", "\n", "Some notes on being Markov:\n", + "\n", "- This example has assumed that the problem is Markov. Essentially, that the data `ys` is a complete observation of the system, and that we're not missing any channels. Note how the result of our model is evolving in data space. This is unlike e.g. an RNN, which has hidden state, and a linear map from hidden state to data.\n", - "- If we wanted we could generalise this to the non-Markov case: inside `NeuralODE`, project the initial condition into some high-dimensional latent space, do the ODE solve there, then take a linear map to get the output. See `latent_ode.ipynb` for an example doing this as part of a generative model; also see [Augmented Neural ODEs](https://arxiv.org/abs/1904.01681) for a short paper on it." + "- If we wanted we could generalise this to the non-Markov case: inside `NeuralODE`, project the initial condition into some high-dimensional latent space, do the ODE solve there, then take a linear map to get the output. See the [Latent ODE example](../latent_ode) for an example doing this as part of a generative model; also see [Augmented Neural ODEs](https://arxiv.org/abs/1904.01681) for a short paper on it." ] } ], diff --git a/examples/nonlinear_heat_pde.ipynb b/examples/nonlinear_heat_pde.ipynb index c61b92ee..c7fc00cd 100644 --- a/examples/nonlinear_heat_pde.ipynb +++ b/examples/nonlinear_heat_pde.ipynb @@ -18,26 +18,28 @@ "$$ \\frac{\\partial y}{\\partial t}(t, x) = (1 - y(t, x)) \\Delta y(t, x) \\qquad\\text{in}\\qquad t \\in [0, 40], x \\in [-1, 1]$$\n", "\n", "subject to the initial condition\n", + "\n", "$$ y(0, x) = x^2, $$\n", "\n", "and Dirichlet boundary conditions\n", - "$$ y(t, -1) = 1, $$\n", - "$$ y(t, 1) = 1. $$\n", + "\n", + "$$ y(t, -1) = 1,\\qquad y(t, 1) = 1. $$\n", "\n", "---\n", "\n", - "We spatially discretise $x \\in [-1, 1]$ into points $-1 = x_0 < x_1 < \\cdots < x_{n-1} = 1$, with equal spacing $\\delta x = x_{i+1} - x_i$. The solution is then discretised into $y(t, x_i) \\approx y_i(t)$, and the Laplacian discretised into $\\Delta y(t,x_i) \\approx \\frac{y_{i+1}(t) - 2y_{i}(t) + y_{i-1}(t)}{2 \\delta x}$.\n", + "We spatially discretise $x \\in [-1, 1]$ into points $-1 = x_0 < x_1 < \\cdots < x_{n-1} = 1$, with equal spacing $\\delta x = x_{i+1} - x_i$. The solution is then discretised into $y(t, x_i) \\approx y_i(t)$, and the Laplacian discretised into $\\Delta y(t,x_i) \\approx \\frac{y_{i+1}(t) - 2y_{i}(t) + y_{i-1}(t)}{\\delta x^2}$.\n", "\n", "In doing so we reduce to a system of ODEs\n", "\n", - "$$ \\frac{\\mathrm{d}y_i}{\\mathrm{d}t}(t) = (1 - y_i(t)) \\frac{y_{i+1}(t) - 2y_{i}(t) + y_{i-1}(t)}{2 \\delta x} \\qquad\\text{for}\\qquad i \\in \\{1, ..., n-2\\},$$\n", + "$$ \\frac{\\mathrm{d}y_i}{\\mathrm{d}t}(t) = (1 - y_i(t)) \\frac{y_{i+1}(t) - 2y_{i}(t) + y_{i-1}(t)}{\\delta x^2} \\qquad\\text{for}\\qquad i \\in \\{1, ..., n-2\\},$$\n", "\n", "subject to the initial condition\n", + "\n", "$$ y_i(0) = {x_i}^2, $$\n", "\n", "for which the Dirichlet boundary conditions become\n", - "$$ \\frac{\\mathrm{d}y_0}{\\mathrm{d}t}(t) = 0, $$\n", - "$$ \\frac{\\mathrm{d}y_{n-1}}{\\mathrm{d}t}(t) = 0. $$\n", + "\n", + "$$ \\frac{\\mathrm{d}y_0}{\\mathrm{d}t}(t) = 0,\\qquad \\frac{\\mathrm{d}y_{n-1}}{\\mathrm{d}t}(t) = 0. $$\n", "\n", "---\n", "\n", @@ -127,7 +129,7 @@ "def laplacian(y: SpatialDiscretisation) -> SpatialDiscretisation:\n", " y_next = jnp.roll(y.vals, shift=1)\n", " y_prev = jnp.roll(y.vals, shift=-1)\n", - " Δy = (y_next - 2 * y.vals + y_prev) / (2 * y.δx)\n", + " Δy = (y_next - 2 * y.vals + y_prev) / (y.δx**2)\n", " # Dirichlet boundary condition\n", " Δy = Δy.at[0].set(0)\n", " Δy = Δy.at[-1].set(0)\n", @@ -167,7 +169,7 @@ "\n", "# Temporal discretisation\n", "t0 = 0\n", - "t_final = 20\n", + "t_final = 1\n", "δt = 0.0001\n", "saveat = diffrax.SaveAt(ts=jnp.linspace(t0, t_final, 50))\n", "\n", @@ -175,7 +177,7 @@ "rtol = 1e-10\n", "atol = 1e-10\n", "stepsize_controller = diffrax.PIDController(\n", - " pcoeff=0.3, icoeff=0.4, rtol=rtol, atol=atol\n", + " pcoeff=0.3, icoeff=0.4, rtol=rtol, atol=atol, dtmax=0.001\n", ")" ] }, @@ -212,17 +214,7 @@ "outputs": [ { "data": { - "text/plain": [ - "" - ] - }, - "execution_count": 5, - "metadata": {}, - "output_type": "execute_result" - }, - { - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAcEAAAGiCAYAAACf230cAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy88F64QAAAACXBIWXMAAA9hAAAPYQGoP6dpAABKOklEQVR4nO3df3QU5b0/8PfsJtmgklBLSAhGfviDHwUSDBqDWOGYGtFDjbVeoFR+FOFISQ+YWiFWCUJrWq0UvUbS0sbgrVzQe4R6hRMvRoH6JWIJzal4BAEJQWVX0CYhEfJjd75/pBlcmc+QZzPZZTPvl2fUPDvzPLOzM/vsPPOZz2i6rusgIiJyIFekV4CIiChS2AkSEZFjsRMkIiLHYidIRESOxU6QiIgci50gERE5FjtBIiJyLHaCRETkWOwEiYjIsdgJEhGRY7ETJCKiiNu1axemTp2K1NRUaJqGLVu2XHCZHTt24LrrroPH48HVV1+N8vJy5XbZCRIRUcQ1NzcjPT0dJSUlXZr/6NGjuPPOOzF58mTU1NRgyZIluP/++/HGG28otasxgTYREV1MNE3D5s2bkZeXJ86zdOlSbN26Ffv37zfKpk+fjvr6elRUVHS5rZjurCgREfUeZ8+eRWtrqy116boOTdOCyjweDzwejy31V1VVIScnJ6gsNzcXS5YsUaqHnSAREeHs2bMYMvQy+Lx+W+q77LLL0NTUFFRWVFSEFStW2FK/1+tFcnJyUFlycjIaGxtx5swZ9OnTp0v1sBMkIiK0trbC5/Xjg0ND0Dehe+EipxsD+M41tTh+/DgSEhKMcrvOAu3ETpCIiAx9E1xI6GYn2CkhISGoE7RTSkoKfD5fUJnP50NCQkKXzwIBdoJERPQ1WgDQAtqFZ7xAHT0tOzsb27ZtCyrbvn07srOzlerhLRJERHSOrtkzKWpqakJNTQ1qamoAdNwCUVNTg7q6OgBAYWEhZs2aZcz/wAMP4OOPP8bDDz+MAwcO4Pnnn8fLL7+MBx98UKlddoJERBRxe/fuxbhx4zBu3DgAQEFBAcaNG4fly5cDAE6cOGF0iAAwdOhQbN26Fdu3b0d6ejqefvpp/OlPf0Jubq5Su7xPkIiI0NjYiMTERHzyydVISHB3sy4/rrjiMBoaGnrsmqBdeE2QiIgMHdcEu19HtOBwKBERORbPBImI6JzAv6fu1hEl2AkSEZFB0zum7tYRLTgcSkREjsUzQSIiMmi6DYExUXQmyE6QiIjOCegdU3friBIcDiUiIsfimSARERmcFhjDTpCIiM5x2C0SHA4lIiLH4pkgEREZtIAOrZuBLd1dPpzYCRIR0TkcDiUiInIGngkSEZGB0aFERORcHA4lIiJyBp4JEhGRwWkP1WUnSERE5+gA9G5e1Iuia4IcDiUiIsfimSARERn4KCUiInIuRocSERE5A88EiYjIwJvliYjIuTgcSkRE5Ay2doLFxcW4/vrr0bdvXwwYMAB5eXk4ePBg0Dxnz57FokWL8O1vfxuXXXYZ7rnnHvh8Pst6dV3H8uXLMXDgQPTp0wc5OTk4dOiQnatORETAuTPB7k5RwtZOcOfOnVi0aBHeffddbN++HW1tbbjtttvQ3NxszPPggw/if//3f/HKK69g586d+Oyzz/CDH/zAst4nn3wSzz77LEpLS7Fnzx5ceumlyM3NxdmzZ+1cfSIix+u4Jqh1c4r0u+g6Tde7mxpAdvLkSQwYMAA7d+7Ed7/7XTQ0NCApKQkbNmzAD3/4QwDAgQMHMHLkSFRVVeHGG288rw5d15Gamoqf//zneOihhwAADQ0NSE5ORnl5OaZPn95Tq09E5BiNjY1ITEzEl7uHIeEyd/fqavLj8gkfo6GhAQkJCTatYc/o0cCYhoYGAMDll18OAKiurkZbWxtycnKMeUaMGIErr7xS7ASPHj0Kr9cbtExiYiKysrJQVVUldoItLS1oaWkx/g4EAvjyyy/x7W9/G5qm2fL+iIgiSdd1nD59GqmpqXC5bBrYc1hgTI91goFAAEuWLMFNN92E0aNHAwC8Xi/i4uLQr1+/oHmTk5Ph9XpN6+ksT05O7vIyQMf1yccff7wb74CIKDocP34cV1xxhT2VsRO0x6JFi7B//3688847PdWEpcLCQhQUFBh/NzQ04Morr8T+I2no2zf4F5PnS/NT/9h/yfW7GoThgoY402L9tMe03N9kXg4A7c3xQrl5G+1n1crbzsbKbbeYv9beZr7LtLcKbbfJv079Ul1+820rzR8ImLfh98ttBwLmowG6VK73/OiBJlxI0Vzm5S6hHADcbvNvIZfLvNwd225aHuP2K80PADGx5m3ExLUK8wtte9rENmLjzV+LiRfakMovlcrleAP3ZS2m5Vpf83IkmrcRSDTftgDQ9i3z8pbLg5c5fTqA0VcdR9++fcW6yFqPdIL5+fl4/fXXsWvXrqBfJykpKWhtbUV9fX3Q2aDP50NKSoppXZ3lPp8PAwcODFomIyNDXAePxwOP5/wOpm9fFxISvtEJtgmdYJv8JeNqFzpBoS5d+mL3yx9Bu/Bau1/ooHTz8japPGDecVnV1S7sMm260AlCvrbgF+pq18yXadcVO0Ghno5lensnaP4FK3WCMTHm2yomRuoE5W0bE2u+TKzwmysmzvzzi4mTt3ms8Nsx1mO+TWLipXJhe/SRj0v3JeadtnaJsE0uNS8PXCZ/fm19zd97S4Kwj9h5iUdH958CEUWBMbZGh+q6jvz8fGzevBlvvfUWhg4dGvR6ZmYmYmNjUVlZaZQdPHgQdXV1yM7ONq1z6NChSElJCVqmsbERe/bsEZchIqLQaAHNlila2NoJLlq0CH/5y1+wYcMG9O3bF16vF16vF2fOnAHQEdAyb948FBQU4O2330Z1dTXmzp2L7OzsoKCYESNGYPPmzQA6fuEsWbIEv/rVr/Daa6/h/fffx6xZs5Camoq8vDw7V5+IiBzG1uHQtWvXAgAmTZoUVP7CCy9gzpw5AIDf//73cLlcuOeee9DS0oLc3Fw8//zzQfMfPHjQiCwFgIcffhjNzc1YsGAB6uvrMXHiRFRUVCA+3vyaGRERhchhw6G2doJdueUwPj4eJSUlKCkp6XI9mqZh5cqVWLlyZbfXMVLCca1JrEu13KIusVyIBrN6f+J1OcW2pWuCVruj6uehWm5FuvYnkratRT3SNpHalretebnLat9R3BeU91uL1+w6nqzqkfad6BkAvABdA7o7nBmGa+h2Ye5QIiJyLD5FgoiIzuF9gkRE5FgOuybI4VAiInIsngkSEdE5ARsCY6LoPkF2gqGSxrztjIpSjSgVIzfV18m2KMkQIlDlaE/7IjftajsUUl1S5KaU5cX6/am1Lb1vqW1bI4tD2LbiPq3ahp1f1mIb9jURFrrW/e8xRocSERFd/HgmSEREBi3QMXW3jmjBTpCIiM5x2DVBDocSEZFj8UyQiIjOcdh9guwEiYjoHIcNh7ITDBc7Q7QVf2Wphstbvaac3Nri6e5yMm7Vts3rD6ntHk7MHIpQtq2mmT/Y1q5tbtW2y93zycm7kKv/Gwsozm91XEZR+D9dGDtBIiI6x2H3CbITJCKicxyWQJvRoURE5Fg8EyQionM4HEpERE6l61pI+Ya/WUe0YCd4EVPdkVSjQC0j7KS6bEpuHcoydiXWtrOuUA52KVG2RDm5NexLlG3r5yq0Le5rFpsppH1aoR5yDnaCRER0DodDiYjIsRgdSkRE5Aw8EyQionM4HEpERI7F3KFkCGVcW/rwpV9GVr+YxAhNtfyPYvVWEX7K+TvVo/XsqkvKYxlKBKMYNWrjQa3DvC7NZf4GVSM3AXmbKOcUDcfnqrivWVE+NqT3YXlc2pQHOIqum/Vm7ASJiOgcDocSEZFjOWw4lNGhRETkWLZ3grt27cLUqVORmpoKTdOwZcuWoNc1TTOdnnrqKbHOFStWnDf/iBEj7F51IiLSbZqihO2dYHNzM9LT01FSUmL6+okTJ4KmsrIyaJqGe+65x7Le73znO0HLvfPOO3avOhGR4+kBzZYpFCUlJRgyZAji4+ORlZWF9957z3L+NWvWYPjw4ejTpw/S0tLw4IMP4uzZs0pt2n5NcMqUKZgyZYr4ekpKStDff/3rXzF58mQMGzbMst6YmJjzlo0omy78Wu0sckSbPeWWbYvLSAuEkL9TiGC0K6+nVI9lGzZFxVoRc4cK2zYQwm9VqQ3VbSVG3gpPj+94UW3fUY0atWrDrmPAzmjgaAoSiaRNmzahoKAApaWlyMrKwpo1a5Cbm4uDBw9iwIAB582/YcMGLFu2DGVlZZgwYQI++ugjzJkzB5qmYfXq1V1uN6LXBH0+H7Zu3Yp58+ZdcN5Dhw4hNTUVw4YNw8yZM1FXV2c5f0tLCxobG4MmIiK6gM7o0O5OilavXo358+dj7ty5GDVqFEpLS3HJJZegrKzMdP7du3fjpptuwo9+9CMMGTIEt912G2bMmHHBs8dvimgnuH79evTt2xc/+MEPLOfLyspCeXk5KioqsHbtWhw9ehQ333wzTp8+LS5TXFyMxMREY0pLS7N79YmIep/O6NDuTsB5JyItLS2mTba2tqK6uho5OTlGmcvlQk5ODqqqqkyXmTBhAqqrq41O7+OPP8a2bdtwxx13KL3diHaCZWVlmDlzJuLj4y3nmzJlCu69916MHTsWubm52LZtG+rr6/Hyyy+LyxQWFqKhocGYjh8/bvfqExGRhbS0tKCTkeLiYtP5Tp06Bb/fj+Tk5KDy5ORkeL1e02V+9KMfYeXKlZg4cSJiY2Nx1VVXYdKkSXjkkUeU1jFi9wn+7W9/w8GDB7Fp0yblZfv164drr70Whw8fFufxeDzweDzdWUUiIufRYcPN8h3/OX78OBISEoxiO7+Td+zYgSeeeALPP/88srKycPjwYSxevBirVq3CY4891uV6ItYJ/vnPf0ZmZibS09OVl21qasKRI0dw33339cCaERE5mG7DzfL/7kQTEhKCOkFJ//794Xa74fP5gsp9Pp8YEPnYY4/hvvvuw/333w8AGDNmDJqbm7FgwQL88pe/hMvVtYFO2zvBpqamoDO0o0ePoqamBpdffjmuvPJKAB3jxK+88gqefvpp0zpuvfVW3H333cjPzwcAPPTQQ5g6dSoGDx6Mzz77DEVFRXC73ZgxY4Y9K23nPS2hRLqpNmHX089tfOp7WJ4sb1M0KSB/HqHUpUqqS4rodAlho7pmEXkrvA/NL+UnVYsmtfNzjeT+GQrlqO2QGrGvqmgQFxeHzMxMVFZWIi8vDwAQCARQWVlp9APf9NVXX53X0bndbgCAbpXc9hts7wT37t2LyZMnG38XFBQAAGbPno3y8nIAwMaNG6HrutiJHTlyBKdOnTL+/uSTTzBjxgx88cUXSEpKwsSJE/Huu+8iKSnJ7tUnInI0XbdOkN7VOlQVFBRg9uzZGD9+PG644QasWbMGzc3NmDt3LgBg1qxZGDRokHFdcerUqVi9ejXGjRtnDIc+9thjmDp1qtEZdoXtneCkSZMu2AsvWLAACxYsEF+vra0N+nvjxo12rBoREV1IhBJoT5s2DSdPnsTy5cvh9XqRkZGBiooKI1imrq4u6Mzv0UcfhaZpePTRR/Hpp58iKSkJU6dOxa9//WuldplAm4iILgr5+fni8OeOHTuC/o6JiUFRURGKioq61SY7QSIiOsdhT5FgJ0hERAZd17odSGRnIFJP46OUiIjIsXgmGGlWMUR23aagmBQ6lDZUb1+wbsOedbJM3q14K0QgDAm0pS0l3u4g3NYAhLKt1OqxTE4uJd2OsfH2GrsSnYvlYtO9H4dDiYjIsSIUHRopHA4lIiLH4pkgEREZnBYYw06QiIjOCUB8wLNSHVGCw6FERORYPBO0YhUhJvzS0aVfQDYmIpYXUIuAszcJslo9gBxhKEYX2hSxarWMX4waFatSJr4PISG222W+U1m9PykCVTm6VzGxdsdrUrmNSa9DSLrd021Lx74mfidYrFgkOSwwhp0gEREZ9IDW7afe2PnUnJ7G4VAiInIsngkSEdE5HA4lIiKnctotEhwOJSIix+KZ4MVM+jUl5U1UzadodfG6h/OWhlKXXdGkgHoUaCj5VyVS5CaESExpXcV6IG8TMW+pEIEajs9VjrYMoQ1pGak8is5Ywka3IXdoFG1XdoJERHSOw64JcjiUiIgci2eCRERk0PXuJ4ewM7lET2MnSERE5zjseYIcDiUiIsfimWCoVKPNpPyBIfxiUs/fKZSHFH2n+ER2yyfLm5erPvVdNQ+oZdshRJqqkurShJVyuc13Hqv3p5w7VNzmfqX5rdpQ3Xes90/VcsXPz+q4VMwPHE1nRYDz7hNkJ0hEROcwOpSIiMgZeCZIREQGpz1Fgp0gERGdo8OG4VBb1iQsOBxKRESOZXsnuGvXLkydOhWpqanQNA1btmwJen3OnDnQNC1ouv322y9Yb0lJCYYMGYL4+HhkZWXhvffes3vVe1bnxWaFqXNY4puTuIzU9L+jvVSmgN98ktYpEHCZTtZtuEwn1fmltgMBTZ4U25Ymf8ClPKl/FtL7tnh/0jaxaZtbrq+0L0j7jrSvhbDfqh5/yseY1dRLhLLdlT6Li4ztnWBzczPS09NRUlIiznP77bfjxIkTxvTf//3flnVu2rQJBQUFKCoqwr59+5Ceno7c3Fx8/vnndq8+EZGzdd4s390pSth+TXDKlCmYMmWK5TwejwcpKSldrnP16tWYP38+5s6dCwAoLS3F1q1bUVZWhmXLlnVrfYmIyLkick1wx44dGDBgAIYPH46FCxfiiy++EOdtbW1FdXU1cnJyjDKXy4WcnBxUVVWJy7W0tKCxsTFoIiIia525Q7s7RYuwd4K33347XnzxRVRWVuK3v/0tdu7ciSlTpsDvN89McerUKfj9fiQnJweVJycnw+v1iu0UFxcjMTHRmNLS0mx9H0REvZHTrgmG/RaJ6dOnG/8/ZswYjB07FldddRV27NiBW2+91bZ2CgsLUVBQYPzd2NjIjpCIiIJE/BaJYcOGoX///jh8+LDp6/3794fb7YbP5wsq9/l8ltcVPR4PEhISgiYiIroABsaE1yeffIIvvvgCAwcONH09Li4OmZmZqKysRF5eHgAgEAigsrIS+fn5PbpuWoQ/SPVE2WqJiK2GLFSTVauWW72mmkA7oPi+reqSklLbObzjF+qSkl67XeYZm63eX8Bl/prLpm0e8X1HcZ+2LbG2jSL9/SKxYzgzmoZDbT8TbGpqQk1NDWpqagAAR48eRU1NDerq6tDU1IRf/OIXePfdd1FbW4vKykrcdddduPrqq5Gbm2vUceutt+K5554z/i4oKMC6deuwfv16fPjhh1i4cCGam5uNaFEiIqJQ2H4muHfvXkyePNn4u/O63OzZs7F27Vr885//xPr161FfX4/U1FTcdtttWLVqFTwej7HMkSNHcOrUKePvadOm4eTJk1i+fDm8Xi8yMjJQUVFxXrAMERF1lx03/0fPmaDtneCkSZOgW8THvvHGGxeso7a29ryy/Pz8Hh/+JCJyOg6HEhEROUTEA2OIiOgiYkd050Ua9GOGnWCopNN9KWothAhN6XEkYlSgGCWpHvEoR24K5X7zNqT5O15TjAIV5hfXSZgfkKNAVdfJTlJ0aEh1BczrCmjm78+lm0egqm5zQH1fCC06VPhsVSOFQ4gaFetSPC4vVnZkfGHGGCIioijAM0EiIjI4LTCGnSAREZ1jx/MRo6gT5HAoERE5Fs8EiYjonIBmmZKvq3VEC3aCADTzwLjQqA4DWOW3VIw2k6LyLtr8j4rvQ8oRqvq+gVAiU83rCeXahxQF6hLyfUrr6rKIJhW3oUuIGpW2obt37DvKUaOhfInbOARo63eSIqddE+RwKBERORbPBImI6ByHBcawEyQiIgOHQ4mIiByCZ4JERGTQAx1Td+uIFuwEwyWECDixKmkZ1eg7i9yaqvlGlaP1EEqEplrbUn5Qy2X89kUwSqToUOkRZG63eT1W709qQ9yGipGpVp+rHmPPfmi5f9p0DCjXb9FGr+Gwa4IcDiUiIsfimSARERmcFhjDTpCIiAxO6wQ5HEpERI7FM0EiIjrHYYEx7AStWIX5Sq+F8gR5iV25FsUn1Fs0rfi09lCeyK4a1SlFJKo+JR6Qo0DF92fnQS3UJecClT4oi/cn1KX5zcv9Qk5RdwjRlsr7iHTMWB1/ik+EV40mtSLWJeYhFSq6SG8j0HWLnKoKdUQLDocSEZFj8UyQiIgMTguMYSdIRETn6BCHm5XqiBIcDiUiIsfimSARERk4HEpERI7FTpDCy2LsXD1JsDB7CLdtyMmO1dq2uk1BuuVBOUm30IZ0G4TlMnbe4qJIbkLahnKMvaYJ29ZlvoxqAvSA2yq5tVCuuG0t90/F/VA5sXYUXdOi7rH9muCuXbswdepUpKamQtM0bNmyxXitra0NS5cuxZgxY3DppZciNTUVs2bNwmeffWZZ54oVK6BpWtA0YsQIu1ediMjx9IBmyxSKkpISDBkyBPHx8cjKysJ7771nOX99fT0WLVqEgQMHwuPx4Nprr8W2bduU2rS9E2xubkZ6ejpKSkrOe+2rr77Cvn378Nhjj2Hfvn149dVXcfDgQXz/+9+/YL3f+c53cOLECWN655137F51IiLqzBjT3UnRpk2bUFBQgKKiIuzbtw/p6enIzc3F559/bjp/a2srvve976G2thb/8z//g4MHD2LdunUYNGiQUru2D4dOmTIFU6ZMMX0tMTER27dvDyp77rnncMMNN6Curg5XXnmlWG9MTAxSUlK6vB4tLS1oaWkx/m5sbOzyskRE1H3f/N71eDzweDym865evRrz58/H3LlzAQClpaXYunUrysrKsGzZsvPmLysrw5dffondu3cjNjYWADBkyBDldYz4LRINDQ3QNA39+vWznO/QoUNITU3FsGHDMHPmTNTV1VnOX1xcjMTERGNKS0uzca2JiHqnzsCY7k4AkJaWFvQ9XFxcbNpma2srqqurkZOTY5S5XC7k5OSgqqrKdJnXXnsN2dnZWLRoEZKTkzF69Gg88cQT8Pv9Su83ooExZ8+exdKlSzFjxgwkJCSI82VlZaG8vBzDhw/HiRMn8Pjjj+Pmm2/G/v370bdvX9NlCgsLUVBQYPzd2NjIjpCI6ALsjA49fvx40He7dBZ46tQp+P1+JCcnB5UnJyfjwIEDpst8/PHHeOuttzBz5kxs27YNhw8fxk9/+lO0tbWhqKioy+sasU6wra0N//Ef/wFd17F27VrLeb8+vDp27FhkZWVh8ODBePnllzFv3jzTZaxOu22hGlUmJgm2M0LTvui7gGoCbSmK0Co6VKxLLbG2FAVq9f5Uo0Cltu3kFiI3/UIQqKZZvD9hm7hc5u/DJTQSEOYPJepXTqxtY/SyXceMVWCH9JpiZKoTJCQkWJ7gdEcgEMCAAQPwxz/+EW63G5mZmfj000/x1FNPXfydYGcHeOzYMbz11lvKG6lfv3649tprcfjw4R5aQyIiZ9L17j8FQnX5/v37w+12w+fzBZX7fD4xFmTgwIGIjY2F2+02ykaOHAmv14vW1lbExcV1qe2wXxPs7AAPHTqEN998E9/+9reV62hqasKRI0cwcODAHlhDIiLnsvOaYFfFxcUhMzMTlZWVRlkgEEBlZSWys7NNl7nppptw+PBhBALnRjE++ugjDBw4sMsdINADnWBTUxNqampQU1MDADh69ChqampQV1eHtrY2/PCHP8TevXvx0ksvwe/3w+v1Gj13p1tvvRXPPfec8fdDDz2EnTt3ora2Frt378bdd98Nt9uNGTNm2L36REQUAQUFBVi3bh3Wr1+PDz/8EAsXLkRzc7MRLTpr1iwUFhYa8y9cuBBffvklFi9ejI8++ghbt27FE088gUWLFim1a/tw6N69ezF58mTj787glNmzZ2PFihV47bXXAAAZGRlBy7399tuYNGkSAODIkSM4deqU8donn3yCGTNm4IsvvkBSUhImTpyId999F0lJSXavPhGRswU0+bqnSh2Kpk2bhpMnT2L58uXwer3IyMhARUWFESxTV1cXdE07LS0Nb7zxBh588EGMHTsWgwYNwuLFi7F06VKldm3vBCdNmgTdYkDY6rVOtbW1QX9v3Lixu6tFRERdEMncofn5+cjPzzd9bceOHeeVZWdn49133w2prU7MHWrFzvyBinlAAcgRpYq5GeWoPKv8j2GIQFXMBSqVhxLRKbYhRgua1xPKwa5p5h+gX7g64ZLmt2haakN12waEN25r5KbifmD1mnJOUTHS02Lj2hXtyfykFwV2gkREZOBTJIiIyLGc1glGPG0aERFRpPBMkIiIvqb7Z4JA9JwJshMkIqJzQnwU0nl1RAl2gqEKIReo6fw2RtmpPz1b/cnryuVCHsmOZezJEaqak9KqLjFSUdy2YhMiKeenS6hM+piscocGhCfL+/3mbUg5Rd1SxLHl52rPvmO1f9p1DIQS1Syuk7CttO7ec0c9ip0gEREZ9IB8S5BKHdGCnSARERkYHUpEROQQPBMkIiKD084E2QkSEZGBnaATBaQQP/s+SNWo0Y6FejYKNLTIVPP5Q8lPqv4EciHiUTHKFLDKN6oaNSo2IXIpBuu6hbdh9f6k3KGaEDUaEN6I+BlJKwWrHK/m84cUoWlXJLSNUaC2CmXHopCwEyQiIkPHk+W7eyZo08qEATtBIiI6x2E3yzM6lIiIHItngkREZGBgDBEROZbTOkEOhxIRkWMpnQlOmjQJGRkZWLNmTQ+tThSRop/EUGzF+aEeUi4nsQ4hCbIUGi+UqyaeBqxC6dVueVB93x112XWLRAgh9uLtC8pVicQE2kIbMUKyR3HbhnB7jeo+Zb1/qh0Dqrf8WAZ22HXsX6SYO5SIiByLw6GCOXPmYOfOnXjmmWegaRo0TUNtbW0PrhoREVHP6vKZ4DPPPIOPPvoIo0ePxsqVKwEASUlJPbZiREQUfk47E+xyJ5iYmIi4uDhccsklSElJ6cl1IiKiCHFaJ8joUCIiciwGxlixMem1GJ1m0Yb0mriMcoJisWnlpMZyMmyL6FApClSM/FObX4r0BIB2aX3FyFv7ftkGhMTsLpf5B6KLH5QcgqcJ20QTQlClbeh2q28P1QhiO/dP8RhQPJYsj0u7knFb7lORCyl12pmgUicYFxcHv9/fU+tCREQR5rROUGk4dMiQIdizZw9qa2tx6tQpBALn/xLdtWsXpk6ditTUVGiahi1btgS9rus6li9fjoEDB6JPnz7IycnBoUOHLth2SUkJhgwZgvj4eGRlZeG9995TWXUiIqLzKHWCDz30ENxuN0aNGoWkpCTU1dWdN09zczPS09NRUlJiWseTTz6JZ599FqWlpdizZw8uvfRS5Obm4uzZs2K7mzZtQkFBAYqKirBv3z6kp6cjNzcXn3/+ucrqExHRBXSeCXZ3ihZKw6HXXnstqqqqLOeZMmUKpkyZYvqarutYs2YNHn30Udx1110AgBdffBHJycnYsmULpk+fbrrc6tWrMX/+fMydOxcAUFpaiq1bt6KsrAzLli1TeQtERGRF10KLh/hmHVEirNGhR48ehdfrRU5OjlGWmJiIrKwssXNtbW1FdXV10DIulws5OTmWHXJLSwsaGxuDJiIioq8La3So1+sFACQnJweVJycnG69906lTp+D3+02XOXDggNhWcXExHn/88W6usQXVX0qKOSk7XlSLQhMj3RRzjVq9plxulf9RsS4pd6hcrp47VMxbGpZ4MGGdhLchRXp2vGa+TTQhoDSSn6tqORBCtLXqsWR1XNp07F+sGBjTSxQWFqKhocGYjh8/HulVIiKii0xYzwQ7M834fD4MHDjQKPf5fMjIyDBdpn///nC73fD5fEHlPp/PMnONx+OBx+Pp/koTETkIzwR70NChQ5GSkoLKykqjrLGxEXv27EF2drbpMnFxccjMzAxaJhAIoLKyUlyGiIhCo+v2TNHC9jPBpqYmHD582Pj76NGjqKmpweWXX44rr7wSS5Yswa9+9Stcc801GDp0KB577DGkpqYiLy/PWObWW2/F3Xffjfz8fABAQUEBZs+ejfHjx+OGG27AmjVr0NzcbESLEhERhcL2TnDv3r2YPHmy8XdBQQEAYPbs2SgvL8fDDz+M5uZmLFiwAPX19Zg4cSIqKioQHx9vLHPkyBGcOnXK+HvatGk4efIkli9fDq/Xi4yMDFRUVJwXLENERN1kx31+UTQcansnOGnSJItchx0RbStXrjQex2TG7DmF+fn5xpnhxcwq56D5AhYvKeaxVI2ys9rRVSNKVZ/6Dqjn/JSfJq72JPpQ2pZ2aSkPqBWX8KEHhDbcwtuQIkABwCWEgUrb0C/kCBXzuEorBfV9QTWnqFVdcrliGyEM5ykf+xcpXhMkIiJyCD5FgoiIDE47E2QnSEREBqd1ghwOJSIix+KZIBERGfSA1u0gn2gKEmInaEV+cLdyLkLVfJ9AKE/iVmvbOneo2jJizk0b85NKkZtSeXu7eu5Qv+JTzkO5KVgXcn6Kayvth+1yGy6pDaHcrjyugHpkseq+1tGIPceGcg5Si7aVy62+XyKo42b37g6H2rQyYcDhUCIiciyeCRIRkcFpgTHsBImIyOC0TpDDoURE5Fg8EyQiIoPTzgTZCRIRkYGdoANpYQjnVb19weo1KXxbDE33C6HpQrnVa1LbcnJr9UTLqsmt24V1leYH5FshAkLYekD6LMQWZNL+5pJfEFZKbkPaJprQhrjNhUTZlrfXCHXJ+46N+6fqLQ8hHJfh+IIPx3cSdWAnSEREBp4JEhGRYzmtE2R0KBERORbPBImIyKDrNuQOjaIzQXaCRERkcNpwKDtBC5YRWtIvJSliTzGiE5Cj7OSoUcWIOau2FRMqi4mWpfcAOXGyHO0ptWFevxQB2tGGlEDbfP5wBOtpwvq6pcbdcl0u4Y24XObbsF2YP0YIl7X+XO3Zd6z2TzkKVLjCI0ZOqx8bdh37jAC9OLATJCIiQ8dTJLpfR7RgJ0hERIaAron3xarUES0YHUpERI7FM0EiIjIwMIaIiJzLhk7QMh3kRYadoN0UIzFDiYATo+mkvJdSzkaL/I9S22KEpmLkHwC0t6vl/JTehzy/2LQYBSpHhwpthBAAIKUC1RRjUF1W709YX7ewjJiXVfiMYoScoh11qe0jqvt5x2tSdKg0v32R0yHlAaaLFjtBIiIycDiUiIgcy2mdIKNDiYjIscLeCQ4ZMgSapp03LVq0yHT+8vLy8+aNj48P81oTETmDHtBsmUJRUlKCIUOGID4+HllZWXjvvfe6tNzGjRuhaRry8vKU2wz7cOjf//53+L+W52r//v343ve+h3vvvVdcJiEhAQcPHjT+1rToOdUmIoomkRoO3bRpEwoKClBaWoqsrCysWbMGubm5OHjwIAYMGCAuV1tbi4ceegg333xzSOsa9k4wKSkp6O/f/OY3uOqqq3DLLbeIy2iahpSUlJ5eNTXKT7AW6rH6xaQYuaYaASdF2FnVpRppavl0d8UowrZ2IYJRaKPN8snyUrkQgSrWpE5q2y1EdIo7Twi/trV283Ipp6hfCH+1+lzt2nes9k+7jgExotPyuBSKVZ9eT0FWr16N+fPnY+7cuQCA0tJSbN26FWVlZVi2bJnpMn6/HzNnzsTjjz+Ov/3tb6ivr1duN6LXBFtbW/GXv/wFP/nJTyzP7pqamjB48GCkpaXhrrvuwgcffHDBultaWtDY2Bg0ERGRtc4zwe5OAM77Dm5paTFts7W1FdXV1cjJyTHKXC4XcnJyUFVVJa7rypUrMWDAAMybNy/k9xvRTnDLli2or6/HnDlzxHmGDx+OsrIy/PWvf8Vf/vIXBAIBTJgwAZ988oll3cXFxUhMTDSmtLQ0m9eeiKj3sbMTTEtLC/oeLi4uNm3z1KlT8Pv9SE5ODipPTk6G1+s1Xeadd97Bn//8Z6xbt65b7zeit0j8+c9/xpQpU5CamirOk52djezsbOPvCRMmYOTIkfjDH/6AVatWicsVFhaioKDA+LuxsZEdIRFRGB0/fhwJCQnG3x6Px5Z6T58+jfvuuw/r1q1D//79u1VXxDrBY8eO4c0338Srr76qtFxsbCzGjRuHw4cPW87n8Xhs2+BERE4R0Lv/FIjOS8kJCQlBnaCkf//+cLvd8Pl8QeU+n880HuTIkSOora3F1KlTz7X57xRRMTExOHjwIK666qourWvEhkNfeOEFDBgwAHfeeafScn6/H++//z4GDhzYQ2tGRORcdg6HdlVcXBwyMzNRWVlplAUCAVRWVgaNBHYaMWIE3n//fdTU1BjT97//fUyePBk1NTVKo34RORMMBAJ44YUXMHv2bMTEBK/CrFmzMGjQIGPseOXKlbjxxhtx9dVXo76+Hk899RSOHTuG+++/Pwwral9UlxQ5FtLTsxWfhh3Kk7sDwtPdpXLpSeOh5H8Uc4RKkZtiudg02hWXkdJ0hvLsUMUYUDFvqVXrLuH9uaT3p/hZWH2u4r6guE9Z7p+quUAVjxnL49Ku7wUbv196g4KCAsyePRvjx4/HDTfcgDVr1qC5udmIFv163xAfH4/Ro0cHLd+vXz8AOK/8QiLSCb755puoq6vDT37yk/Neq6urCwrX/te//oX58+fD6/XiW9/6FjIzM7F7926MGjUqnKtMROQIkbpPcNq0aTh58iSWL18Or9eLjIwMVFRUGMEy3+wb7KLpuh7Kj9mo09jYiMTERBz7fDASEoI3ZJ8TbtNlYj81LwcAnLzEtLj95KWm5a1fXGZa3tJgPj8AnG00b+NMg1De3Eet/Cs58464zBnz66xnzsaZlp9tMS/veM38N1hLm/l2P9sm3D8o/Mpvs3jKwsV4Jigd3m5hgRhNbj1WqCzWbb5MfKz5O/TEmt8hGe8RbjgEEO9pNS3vEy+U9zEPm+9z6RmxjT6XnFVaRixP/Mq0PD7BvBwAPInNpuVx324yLY9JMp8fSXIbbYPMt/uZgcHljY0BDB5wDA0NDV269mal8zty/bCncInL/Pjvqq8CZzD741/Ysl49jblDiYjIsfgUCSIiMjjtKRLsBImIyBDQNRtukYieTpDDoURE5Fg8EwTkqAcLuhgpoZhEN4RQbDGsW0pErBiaDlglNVYrl8LlATkJs5QoOyBscymYRQp+sVqmTZhfzH8utiCTtrq4tmL8i/z+XEK8m1tYYWmbxwhROVafq137juX+Kd1WoXrrRChJrxWPZem7wvJcKZQdyyYcDiUiIsdyWifI4VAiInIsngkSEZHBaWeC7ASJiMig2xAdGk2dIIdDiYjIsXgmaCWECDHVhLwBiyg75Ug3xQTalkmQpWhPobxdiNaTyq1ekxJit0nRpGLCbbFpMQrUPFkV4A8pQZpaG27reMHzWb0/YZuIKduEj8nWz1Vxn7JOvm7PsSFHVIeQ2N7GyPBI0vWOqbt1RAt2gkREZNADmsWTS7peR7TgcCgRETkWzwSJiMjA6FAiInIs5g4lIiJyCJ4J2k01F6FV7lDpAa+KuRZVo0Y7XrMn0lTKDwoAfiFMsl2IzJNzhJqXSxGgACA9ErZdCLkMR+7QgNh2CL+qhRUWnqkrb3NhZa0+V/Xcoer7pxxtrda2GMUYSu7QKAoGscLoUCIiciynXRPkcCgRETkWzwSJiMjgtMAYdoJERGRw2jVBDocSEZFj8UwwVIpRoKpPiQcsnqwtRnvalztUaru9Xe0J8lJeSABokyJKFZ8UL+XilCJAAaBNCJ+U6pIjN9WJ+TuFV3SxFYsnywvl0jZ0Cz/dpc8oNiC/c2lfkPYd1YhOq9ekY0A8ZhQjqoEQnkYfZVGjTguMYSdIREQGp10T5HAoERE5Fs8EiYjIoOuAHkoWiG/UES3YCRIRkUHXbXiUEodDiYiILn5hPxNcsWIFHn/88aCy4cOH48CBA+Iyr7zyCh577DHU1tbimmuuwW9/+1vccccdtq2TJp36Ww0JKP7SUX4atUUb0lOv5RyhapFxHcvY9GT5dqsck+blUgRjm2KOUCkCFJCjQKVlpAhNO3OHyltKekV+f5qwjCbmFDWfX4oCtfpc22PseYK89f6peAxIT4oP4bhUPssRo0blRcTvpDAI6BoC3TwTZGDMBXznO9/BiRMnjOmdd94R5929ezdmzJiBefPm4R//+Afy8vKQl5eH/fv3h3GNiYgcQj93w3yoU0j3DkVIRDrBmJgYpKSkGFP//v3FeZ955hncfvvt+MUvfoGRI0di1apVuO666/Dcc8+FcY2JiKg3ikgneOjQIaSmpmLYsGGYOXMm6urqxHmrqqqQk5MTVJabm4uqqirLNlpaWtDY2Bg0ERGRtYB+7l7B0KdIv4uuC3snmJWVhfLyclRUVGDt2rU4evQobr75Zpw+fdp0fq/Xi+Tk5KCy5ORkeL1ey3aKi4uRmJhoTGlpaba9ByKi3qq7Q6F25B4Np7B3glOmTMG9996LsWPHIjc3F9u2bUN9fT1efvllW9spLCxEQ0ODMR0/ftzW+omIKPpF/D7Bfv364dprr8Xhw4dNX09JSYHP5wsq8/l8SElJsazX4/HA4/HYtp5ERE7gtPsEI94JNjU14ciRI7jvvvtMX8/OzkZlZSWWLFlilG3fvh3Z2dlhWkM1cnJdaQH1UGzVhNihJChu97uFZaRE2WrlANAmvCbdCiElxJZua7D1Fgnh3gJbb5FQHEKyals1Sbe0zWOEzyjW4nOVPnNp35H3NfUE76rHhvhlbXnrklAcZYmyJQE9tP36m3VEi7APhz700EPYuXMnamtrsXv3btx9991wu92YMWMGAGDWrFkoLCw05l+8eDEqKirw9NNP48CBA1ixYgX27t2L/Pz8cK86ERH1MmE/E/zkk08wY8YMfPHFF0hKSsLEiRPx7rvvIikpCQBQV1cHl+tc3zxhwgRs2LABjz76KB555BFcc8012LJlC0aPHh3uVSci6vV0vfu3+UVTYEzYO8GNGzdavr5jx47zyu69917ce++9PbRGRETUiRljiIiIHCLigTFERHTx4HAonWN1Sq8YBSpFjomJfaEe0aYaMSclNAYAv7Be7e2KCbQtIuakRNl+YRtKEZ1S1Gi7xaEsLiNkLg4lx7pE2uouKTxUF/YDyzbMt610wEvbXPqMLD9XMZm6sB8K+5rV/qkaCa0caW1xXCpHgIvzX5xDhk7rBDkcSkREjsUzQSIiMjgtMIadIBERGex4ElIUjYZyOJSIiJyLZ4JERGRwWto0doKhkgL5FHMRWiWa1YUINV01d6gUHSpE61kuI5VL0aRWOSZ7OEdom0UyTmkZv1AuBwSqH+2acL3FJVQVECJWrZIca0J4nnmWTsAt1CV+Rlafq2K0p537pxhRLZVLx5hl7lDFZaKoQwA69qtuJ9Du5vLhxOFQIiJyLJ4JEhGRQbdhODSa7hNkJ0hERAZGhxIRETkEzwSJiMjA6FAnsjMxpBA1J0WOSVFrVsvY9VRt6yfLSzlCpaeDq0UXAvLTzJWjQEN4srwUOSpFh/q7/bVwYW5N+JykBSy2rSa8KEeHmosR9sE4i7ZV9xFpX7N8srzqvq6YUzSU41IMIZZY7VI9v7uJOBxKRETkEDwTJCIiA4dDiYjIsTgcSkRE5BA8EyQiIkMANgyH2rEiYcJOMETy06WFaDMpOs0iokyOgJOekq0WTRrKk7ulZaRckm1WTyAXytuE8lZhkEV6Grx17lDzZcToUKGuQAgDP9JT3wO69FR7YRsK0aQdbZjX1So8pV7KHRor1G/5uQqvKecOtdo/xYhSxWNDzDVqldNXWC8x0tS8/GLNrsnhUCIiIofgmSARERk4HEpERI6lo/sJsDkcSkREFAV4JkhERAYOhxIRkWM5LTqUnaAVq4S4Uji0ankoCbQVkwS3+83TI7e3WyTQFl5rb5cSZQvJka0SaEu3PEgJsYXbFFqEcuk2iI66hNsqhGWkt2HnLRLCpkWsdNXCqmnh9gkpsXascBFIul2lXUoiDYt9QXiD8r5mleDdfJ9WTS5v53EpJta2K+G2A5SUlOCpp56C1+tFeno6/vM//xM33HCD6bzr1q3Diy++iP379wMAMjMz8cQTT4jzS8J+TbC4uBjXX389+vbtiwEDBiAvLw8HDx60XKa8vByapgVN8fHxYVpjIiLn0HFuSDTUKZQzwU2bNqGgoABFRUXYt28f0tPTkZubi88//9x0/h07dmDGjBl4++23UVVVhbS0NNx222349NNPldoNeye4c+dOLFq0CO+++y62b9+OtrY23HbbbWhubrZcLiEhASdOnDCmY8eOhWmNiYico7sdYKjXFFevXo358+dj7ty5GDVqFEpLS3HJJZegrKzMdP6XXnoJP/3pT5GRkYERI0bgT3/6EwKBACorK5XaDftwaEVFRdDf5eXlGDBgAKqrq/Hd735XXE7TNKSkpHS5nZaWFrS0tBh/NzY2qq8sERGF7Jvfux6PBx6P57z5WltbUV1djcLCQqPM5XIhJycHVVVVXWrrq6++QltbGy6//HKldYz4LRINDQ0AcMEVb2pqwuDBg5GWloa77roLH3zwgeX8xcXFSExMNKa0tDTb1pmIqLfSbZoAIC0tLeh7uLi42LTNU6dOwe/3Izk5Oag8OTkZXq+3S+u9dOlSpKamIicnR+HdRjgwJhAIYMmSJbjpppswevRocb7hw4ejrKwMY8eORUNDA373u99hwoQJ+OCDD3DFFVeYLlNYWIiCggLj78bGRnaEREQXYOctEsePH0dCQoJRbnYWaIff/OY32LhxI3bs2KEcLxLRTnDRokXYv38/3nnnHcv5srOzkZ2dbfw9YcIEjBw5En/4wx+watUq02Wk024lFhFwyhFfYkSZRROKCYfVE27LAwF+YZl2v5Qo27weKbrQ6rUWKTpULBeSRQsRoFbLtEuJtcOQQNst7SNCPVa7p7SQlChb2uYxwvyWn6uw2aV9R9rXrPZP1X1aTpQtRY2KTdt27Ft/gL1DQkJCUCco6d+/P9xuN3w+X1C5z+e74GWw3/3ud/jNb36DN998E2PHjlVex4gNh+bn5+P111/H22+/LZ7NSWJjYzFu3DgcPny4h9aOiMiZdJv+UREXF4fMzMygoJbOIJevnwB905NPPolVq1ahoqIC48ePD+n9hr0T1HUd+fn52Lx5M9566y0MHTpUuQ6/34/3338fAwcO7IE1JCJyrkhFhxYUFGDdunVYv349PvzwQyxcuBDNzc2YO3cuAGDWrFlBgTO//e1v8dhjj6GsrAxDhgyB1+uF1+tFU1OTUrthHw5dtGgRNmzYgL/+9a/o27evcdEzMTERffr0AdDxZgcNGmRcRF25ciVuvPFGXH311aivr8dTTz2FY8eO4f777w/36hMRUQ+YNm0aTp48ieXLl8Pr9SIjIwMVFRVGsExdXR1crnPnbWvXrkVrayt++MMfBtVTVFSEFStWdLndsHeCa9euBQBMmjQpqPyFF17AnDlzAJz/Zv/1r39h/vz58Hq9+Na3voXMzEzs3r0bo0aNCtdqExE5QiTTpuXn5yM/P9/0tR07dgT9XVtbG2IrwcLeCepdeEbHN9/s73//e/z+97/voTUiIqJOTKDtQELgnyVdiAST8wqa1yPlOux4TYiAU8yPKOcUtcrNKCwjRY1K+SLFFoBWKXeoENUpRXuq5gEFgBbNb962sE7+MBzWbiHfp1/6XW2x32pC4KEUmSp9EbQKn6tHqAew2BfEiGP1/dOuY0D1GAPkY1nOQ6oeBRrKdxKFhp0gEREZdOjQu9kLd2XE72LBTpCIiAxOGw6NeNo0IiKiSOGZIBERGZx2JshOkIiIvkY944tZHdGCnaAVq58zUiSYEG2mS9FmFvkDVfMg+qUnyEtP4bbIzShF5rUJkW6twj4v5aQErHKBKuYIFcqlCFDrusyXkQL87Mwd6tKlp9qbf34WAZrid5DUdpsYNWpekdXnGifs09K+I+1r1vun+TaRjgHVY8nquJSOZenYl3ONik1QGLETJCIiA4dDiYjIsUJJgG1WR7RgdCgRETkWzwSJiMjA4VAiInIsXev+8351418XP3aCdhNzh6o/PVv5yfKKUXbt7XLbbe1qOUKlJ41LuTgBOUdoi1huHrkplguRnoCch1SMDhXeh63RoULCT+lXdcAqNZXwJSZ9t0lPnI8VyqX9AJAjTcV9R9rXLPZPu44B1WMMkI9lJzwpvjdiJ0hERIaO4dDuncZxOJSIiKKS064JMjqUiIgci2eCRERkcNp9guwEiYjIwOFQIiIih+CZICD/bLEKeRaWkcKnpVBsXcrMDKtwbykRsZBAWwg1lxIXA0C70EabMMqhmgwbAFqEp1eLty9It04ItzWctUigLd0K0SosI0XL+UMY9pFuR5BunZBuhQhoQmJtQE6gLdyGESMlIRd+J8dZ3J4h7gvCsSHta5b7p7hPSwm0hW0rJcO2OC7FY1k1UXYI3y/hEIBuQ3Qoh0OJiCgKOe1meQ6HEhGRY/FMkIiIDBwOJSIiB3PWk+U5HEpERI7FM0ErFhFi4pVjMTpULToNUI9oU40abQshOrRd+IHXLtQjRXQCcoSmFO3ZKoTMSVGgZ8W1kqNAz0rrZBFpahePLkQ2ClG0oSTQduvmn7lbeN9SAu1WIcoUANqFNsR9R4pEtto/bYoCVZ0fsIgcVU2sbfX9EkFOu0+QnSARERmcdk0wYsOhJSUlGDJkCOLj45GVlYX33nvPcv5XXnkFI0aMQHx8PMaMGYNt27aFaU2JiKi3ikgnuGnTJhQUFKCoqAj79u1Deno6cnNz8fnnn5vOv3v3bsyYMQPz5s3DP/7xD+Tl5SEvLw/79+8P85oTEfVuuk1TtIhIJ7h69WrMnz8fc+fOxahRo1BaWopLLrkEZWVlpvM/88wzuP322/GLX/wCI0eOxKpVq3DdddfhueeeC/OaExH1bgFNt2WKFmG/Jtja2orq6moUFhYaZS6XCzk5OaiqqjJdpqqqCgUFBUFlubm52LJli9hOS0sLWlpajL8bGhoAAKdPn3/J1t9kfoHa/ZUcDOE/Yx50cfas+TPWv2oxL29qaxXbaGprMS1vbo81b8NvXn4mYB5EcNYisOKsEKTRIgVKiKmy5G3YJgSutAvlfiF4wy881z5g8WR56TWpXLeoyy4BKTAGQrkwPwD4hWX8urDNhbradPOvCKkcAFqFtqV9R9oPPRb7zpnAWdPyOL95eUy7eblLOMY0i+MSwrHsF479VuG7wvL7pck8tKS5Mbi88/tMtwqSIkth7wRPnToFv9+P5OTkoPLk5GQcOHDAdBmv12s6v9frFdspLi7G448/fl756KuOh7DW1COk49ahx3NzpFcgUqS+4LTFMtJrJ7q5LlHqiy++QGJioi11OS0wptdGhxYWFgadPdbX12Pw4MGoq6uzbWfp7RobG5GWlobjx48jISEh0qsTNbjd1HGbhaahoQFXXnklLr/8ctvqtOOaXvR0gRHoBPv37w+32w2fzxdU7vP5kJKSYrpMSkqK0vwA4PF44PF4zitPTEzkQaYoISGB2ywE3G7quM1C43Ix70mowr7l4uLikJmZicrKSqMsEAigsrIS2dnZpstkZ2cHzQ8A27dvF+cnIqLQdA6HdneKFhEZDi0oKMDs2bMxfvx43HDDDVizZg2am5sxd+5cAMCsWbMwaNAgFBcXAwAWL16MW265BU8//TTuvPNObNy4EXv37sUf//jHSKw+EVGvxWuCYTBt2jScPHkSy5cvh9frRUZGBioqKozgl7q6uqDT+wkTJmDDhg149NFH8cgjj+Caa67Bli1bMHr06C636fF4UFRUZDpESua4zULD7aaO2yw03G7dp+mMrSUicrzGxkYkJiZivPsZxGh9ulVXu34Ge/2L0dDQcNFf4+210aFERKROt+FRSt1/FFP4MKSIiIgci2eCRERk0G0IjImmM0F2gkREZAhoOrRu5v6MpuhQDocSEZFj9dpO8Ne//jUmTJiASy65BP369evSMrquY/ny5Rg4cCD69OmDnJwcHDp0qGdX9CLz5ZdfYubMmUhISEC/fv0wb948NDU1WS4zadIkaJoWND3wwANhWuPI4PMw1alss/Ly8vP2qfj4+DCubeTt2rULU6dORWpqKjRNs3xgQKcdO3bguuuug8fjwdVXX43y8nLldgM2TdGi13aCra2tuPfee7Fw4cIuL/Pkk0/i2WefRWlpKfbs2YNLL70Uubm5OHvWPAN9bzRz5kx88MEH2L59O15//XXs2rULCxYsuOBy8+fPx4kTJ4zpySefDMPaRgafh6lOdZsBHSnUvr5PHTt2LIxrHHnNzc1IT09HSUlJl+Y/evQo7rzzTkyePBk1NTVYsmQJ7r//frzxxhtK7TotY0yvv0+wvLwcS5YsQX19veV8uq4jNTUVP//5z/HQQw8B6EhOm5ycjPLyckyfPj0MaxtZH374IUaNGoW///3vGD9+PACgoqICd9xxBz755BOkpqaaLjdp0iRkZGRgzZo1YVzbyMnKysL1119vPM8yEAggLS0NP/vZz7Bs2bLz5p82bRqam5vx+uuvG2U33ngjMjIyUFpaGrb1jiTVbdbV49YpNE3D5s2bkZeXJ86zdOlSbN26NejH1fTp01FfX4+KiooLttF5n+DomKfh7uZ9gn79DPa3/zwq7hPstWeCqo4ePQqv14ucnByjLDExEVlZWeJzDnubqqoq9OvXz+gAASAnJwculwt79uyxXPall15C//79MXr0aBQWFuKrr77q6dWNiM7nYX59P+nK8zC/Pj/Q8TxMp+xXoWwzAGhqasLgwYORlpaGu+66Cx988EE4Vjdq2bWf6Tb9Ey0YHfpvnc8mVH1uYW/i9XoxYMCAoLKYmBhcfvnlltvgRz/6EQYPHozU1FT885//xNKlS3Hw4EG8+uqrPb3KYReu52H2JqFss+HDh6OsrAxjx45FQ0MDfve732HChAn44IMPcMUVV4RjtaOOtJ81NjbizJkz6NOna2d3AejQHJQ7NKrOBJctW3bexfJvTtJB5WQ9vd0WLFiA3NxcjBkzBjNnzsSLL76IzZs348iRIza+C3KS7OxszJo1CxkZGbjlllvw6quvIikpCX/4wx8ivWrUy0TVmeDPf/5zzJkzx3KeYcOGhVR357MJfT4fBg4caJT7fD5kZGSEVOfFoqvbLSUl5bxAhfb2dnz55ZeWz278pqysLADA4cOHcdVVVymv78UsXM/D7E1C2WbfFBsbi3HjxuHw4cM9sYq9grSfJSQkdPksEHDemWBUdYJJSUlISkrqkbqHDh2KlJQUVFZWGp1eY2Mj9uzZoxRhejHq6nbLzs5GfX09qqurkZmZCQB46623EAgEjI6tK2pqagAg6MdEb/H152F2Bil0Pg8zPz/fdJnO52EuWbLEKHPS8zBD2Wbf5Pf78f777+OOO+7owTWNbtnZ2efdehPKfua0TjCqhkNV1NXVoaamBnV1dfD7/aipqUFNTU3QPW8jRozA5s2bAXREXy1ZsgS/+tWv8Nprr+H999/HrFmzkJqaahmR1ZuMHDkSt99+O+bPn4/33nsP/+///T/k5+dj+vTpRmTop59+ihEjRhj3eB05cgSrVq1CdXU1amtr8dprr2HWrFn47ne/i7Fjx0by7fSYgoICrFu3DuvXr8eHH36IhQsXnvc8zMLCQmP+xYsXo6KiAk8//TQOHDiAFStWYO/evV3uAHoD1W22cuVK/N///R8+/vhj7Nu3Dz/+8Y9x7Ngx3H///ZF6C2HX1NRkfG8BHcF7nd9pAFBYWIhZs2YZ8z/wwAP4+OOP8fDDD+PAgQN4/vnn8fLLL+PBBx+MxOpHjag6E1SxfPlyrF+/3vh73LhxAIC3334bkyZNAgAcPHgQDQ0NxjwPP/wwmpubsWDBAtTX12PixImoqKhw1E26L730EvLz83HrrbfC5XLhnnvuwbPPPmu83tbWhoMHDxrRn3FxcXjzzTeNByOnpaXhnnvuwaOPPhqpt9DjIvE8zGinus3+9a9/Yf78+fB6vfjWt76FzMxM7N69G6NGjYrUWwi7vXv3YvLkycbfBQUFAIDZs2ejvLwcJ06cMDpEoGM0a+vWrXjwwQfxzDPP4IorrsCf/vQn5ObmKrUbAGw4E4wevf4+QSIiurDO+wSHxf4WLq17P/wD+ll83LaU9wkSERFdzHrtcCgREanrCGpxTmAMO0EiIjI4rRPkcCgRETkWzwSJiMjgtyH3ZzSdCbITJCIiA4dDiYiIHIJngkREZHDamSA7QSIiMvi1AHStezlfAlGUM4bDoURE5FjsBIlscPLkSaSkpOCJJ54wynbv3o24uDhUVlZGcM2I1Pih2zJFC3aCRDZISkpCWVmZ8YSI06dP47777jOSkRNFi4ANHWCo1wRLSkowZMgQxMfHIysry3hajeSVV17BiBEjEB8fjzFjxpz3KKmuYCdIZJM77rgD8+fPx8yZM/HAAw/g0ksvRXFxcaRXiygqbNq0CQUFBSgqKsK+ffuQnp6O3Nzc8x703Wn37t2YMWMG5s2bh3/84x/Iy8tDXl4e9u/fr9QunyJBZKMzZ85g9OjROH78OKqrqzFmzJhIrxJRl3Q+ReIyTxG0bj5FQtfPoqnlcaWnSGRlZeH666/Hc889B6DjwctpaWn42c9+hmXLlp03/7Rp09Dc3IzXX3/dKLvxxhuRkZGB0tLSLq8rzwSJbHTkyBF89tlnCAQCqK2tjfTqECnT0QJdP9u9CS0AOjrWr08tLS2mbba2tqK6uho5OTlGmcvlQk5ODqqqqkyXqaqqCpofAHJzc8X5JbxFgsgmra2t+PGPf4xp06Zh+PDhuP/++/H+++9jwIABkV41oguKi4tDSkoKvN7f2FLfZZddhrS0tKCyoqIirFix4rx5T506Bb/fbzxkuVNycjIOHDhgWr/X6zWd3+v1Kq0nO0Eim/zyl79EQ0MDnn32WVx22WXYtm0bfvKTnwQN1xBdrOLj43H06FG0trbaUp+u69A0LajM4/HYUred2AkS2WDHjh1Ys2YN3n77beMayH/9138hPT0da9euxcKFCyO8hkQXFh8fj/j47l0PDEX//v3hdrvh8/mCyn0+H1JSUkyXSUlJUZpfwmuCRDaYNGkS2traMHHiRKNsyJAhaGhoYAdIdAFxcXHIzMwMuqc2EAigsrIS2dnZpstkZ2efdw/u9u3bxfklPBMkIqKIKygowOzZszF+/HjccMMNWLNmDZqbmzF37lwAwKxZszBo0CDjtqPFixfjlltuwdNPP40777wTGzduxN69e/HHP/5RqV12gkREFHHTpk3DyZMnsXz5cni9XmRkZKCiosIIfqmrq4PLdW7wcsKECdiwYQMeffRRPPLII7jmmmuwZcsWjB49Wqld3idIRESOxWuCRETkWOwEiYjIsdgJEhGRY7ETJCIix2InSEREjsVOkIiIHIudIBERORY7QSIicix2gkRE5FjsBImIyLHYCRIRkWP9fwTknDm9uQKpAAAAAElFTkSuQmCC\n", + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAbgAAAGiCAYAAACVh9NOAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy88F64QAAAACXBIWXMAAA9hAAAPYQGoP6dpAABESUlEQVR4nO3df3RU5Z0/8PedSTIh0oS6QCIYjb+VggTDkg3agqdZ08phyx/dpcACZgXXH9kDZKuAAhE5JdYfbGyLZsVNaffIgu3xx57C4qGpWdclhSWQU7WCi4ChrBOhfkkgQCaZ+3z/SBkdcz8PeSY3mcy97xdn/uCZ+2tm7nOf3Od+ns9jKaUUiIiIPCaQ7AMgIiIaCGzgiIjIk9jAERGRJ7GBIyIiT2IDR0REnsQGjoiIPIkNHBEReRIbOCIi8iQ2cERE5Els4IiIyJPYwBER0YB6++23MXPmTIwZMwaWZeH111+/5DoNDQ247bbbEAqFcP3112Pz5s3G+2UDR0REA6qjowMTJ07Exo0b+7T80aNHMWPGDNx5551obm7G0qVLsWjRIrz55ptG+7WYbJmIiAaLZVl47bXXMGvWLHGZ5cuXY/v27XjvvfdiZd/73vdw+vRp7Ny5s8/7SuvPgRIRUWq4cOECIpGIa9tTSsGyrLiyUCiEUCjU7203NjaitLQ0rqysrAxLly412g4bOCIij7tw4QKuuSYP4XCba9scPnw4zp49G1dWVVWFxx9/vN/bDofDyM3NjSvLzc1Fe3s7zp8/j2HDhvVpO2zgiIg8LhKJIBxuw5GP/wnZ2X1rHHTa28/j2quX4fjx48jOzo6Vu3H35iY2cEREPpGdPcyVBu7z7WXHNXBuycvLQ2tra1xZa2srsrOz+3z3BrCBIyLyDaW6oVS3K9sZSCUlJdixY0dc2a5du1BSUmK0HQ4TICLyCaWirr1MnD17Fs3NzWhubgbQMwygubkZLS0tAICVK1diwYIFseXvv/9+HDlyBI888ggOHjyI559/Hq+88gqWLVtmtF82cERENKD27duHSZMmYdKkSQCAyspKTJo0CWvWrAEAfPLJJ7HGDgCuueYabN++Hbt27cLEiRPx7LPP4qWXXkJZWZnRfjkOjojI49rb25GTk4PwqWddCzLJG/mPaGtrG5BncG7hMzgiIp9IlWdwbmEXJREReRLv4IiIfKInQMSNOzizIJNkYQNHROQTyu6Gsl1o4FzYxmBgFyUREXkS7+CIiPxCdfe83NhOCmADR0TkE4yiJCIi8gDewRER+YXdDdhd7mwnBbCBIyLyiZ4uyqAr20kF7KIkIiJP4h0cEZFf2N2A3f87OHZREhHR0OKzBo5dlERE5Em8gyMi8o2oS4O0mYuSiIiGEMvuhmX3v+POYhclERFR8vAOjojIL+xuwIU7uFQJMmEDR0TkFz5r4NhFSUREnsQ7OCIin7BUNyzlQpBJiqTqYgNHROQXtg3YLoT423b/tzEI2EVJRESexDs4IiKf6BkHZ7mynVTABo6IyC/sqEtRlKmRyYRdlERE5EmuN3Bvv/02Zs6ciTFjxsCyLLz++uuXXKehoQG33XYbQqEQrr/+emzevNntwyIiIrvbvVcKcL2B6+jowMSJE7Fx48Y+LX/06FHMmDEDd955J5qbm7F06VIsWrQIb775ptuHRkTka5Ydde2VClx/Bvftb38b3/72t/u8fG1tLa655ho8++yzAIBbbrkF77zzDv7pn/4JZWVlbh8eERH5RNKDTBobG1FaWhpXVlZWhqVLl4rrdHZ2orOzM/Z/27bx2Wef4c/+7M9gWf2PECIiSjalFM6cOYMxY8YgEHCps025FGSifHoHZyocDiM3NzeuLDc3F+3t7Th//jyGDRvWa53q6mqsXbt2sA6RiChpjh8/jiuvvNKVbVm27Ur3opUiA72T3sAlYuXKlaisrIz9v62tDVdddRWOtTyH7Oz4BvFC50nHbUTP/UHcfuCc8zqBc390LA+ePe28/Pmz4j6scx3Ob1w471x+7oJzecT5Ya8SFgcAFXH+C05FnE8H1eVcbncHxX3YwjqqW9h31HlbKiosr0k3pJTzXbySxv+4MC7okgLKsdiSyi3n8p73nC8uVlAqd76gWWnOywfS5QCCQJqwLWEdK0Mqly+QVqbwRoZwucoSVsjs/ccxAKisy8R928OGO5ZHh49wXj7rz4TyUeI+glnOjVVmKH6d9vbzKLhqCb7yla+I2yK9pDdweXl5aG1tjStrbW1Fdna2490bAIRCIYRCoV7l2dnDkJ2dFVeW0el88kfTeq9/USCY7lwecP66gnC+OAcCcgNgWcIFOiBdbIXyoHAx11yzlbAPsVzYh90lNzK20KWiglIDZ1iu6WYxb+AGYbRMQGh8EmnghMO1hNPNCgr7EGp/IF3+PgLpwudId/5urQyhPCSfoGIDJ62TJZQPE86dLM0fZsJ70eHCH3mXOV8r7Mvk60tQaJAzQ1mO5a4+drGj7vxB59cgE1MlJSXYsWNHXNmuXbtQUlKSpCMiIvKmnghINzKZpEYD5/qfrmfPnkVzczOam5sB9AwDaG5uRktLC4Ce7sUFCxbElr///vtx5MgRPPLIIzh48CCef/55vPLKK1i2bJnbh0ZERD7i+h3cvn37cOedd8b+f/FZ2cKFC7F582Z88sknscYOAK655hps374dy5Ytw3PPPYcrr7wSL730EocIEBG5jV2U/TN9+nQoJT8/cMpSMn36dBw4cMDtQ0k+3UkgRSGZRidJiycS5ORmcIb0nhAcIu1DetYmPWfTbUt61qbbllssad/SD6XrW5GeGQrP7SwpxNwW6qnmd5W+W8v0O9Sdn6bnrmldSpGL80BgFyUREZEHJD3IhIiIBgm7KImIyIssW7kySNuSureHGHZREhGRJ/EOjojIL+xoYgFoTttJAZ5v4JRKYN4i4ccTI4cG48c2jq40j3AUow9Noys17xnvI4F9m0ZLarflFinw0TS6Urct088nZMjQfR+m+xAj9rTfuRTdOQj5Dweh7id0TXKLcqmBS5Fky+yiJCIiT/L8HRwREfWwlG0+ZlHYTipgA0dE5Bc+ewbHLkoiIvIk3sEREfmFbbs00JtdlERENJSwgSPX6E4C4wSxUui02SEBmslCpZB1Mexe08PtVlLlBCYptcVZwAdhOIBARYUweiFBsu7ZgTSEQIkzoZolYdbOli587+JwB2l53YkrJhA3HD7gVkJzSlls4IiIfMKybVgutO9upPsaDGzgiIj8wrZdiqJMjQaOUZRERORJvIMjIvILn93BsYEjIvILNnA+YbuZhDmBH9ulaEkpY442YtA0EXICCXPdSqosRUtKkZK6bYnRo4bb0ZGiIkUB5x9Q9/mkd1RAiJZ0K0EyYHzuSNvSn5/C53ArulJDrMummTsSub6Q6/zbwBER+Y2Kyn8QGG2Hd3BERDSE+G2YAKMoiYjIk3gHR0TkFwwyISIiT2IDR65FQLkZXSkub1gOiJGMYoSjadSldh0hZ2F30LncNOpStw/Dz5cIBSGaUIhwFHNqCtGVgOazS9+hFNkplCfyu4qTaIrRseIuzC/AxnXGxYszoyWHNDZwRER+YSt3Gng3IjEHARs4IiK/sJVLXZSp0cAxipKIiDyJd3BERH7h2oSnvIMjIqKhxLbdexnauHEjCgoKkJmZieLiYuzdu1e7fE1NDW666SYMGzYM+fn5WLZsGS5cuGC0T97BucE0Tx1gnupGGc7o7WaEozTDs3ZGb+d9mM62LS6fwIzekkGZ6VvKGSpEXeo+gfSTB4JmeS2DUqSmbkZvcaZ2sxnZ9fkuxZ3L6zjv3Gx5ILG6TJe0bds2VFZWora2FsXFxaipqUFZWRkOHTqE0aNH91p+y5YtWLFiBerq6jB16lR8+OGHuOeee2BZFjZs2NDn/fIOjojIL2zl3svAhg0bsHjxYpSXl2PcuHGora1FVlYW6urqHJffvXs3br/9dsydOxcFBQW46667MGfOnEve9X0ZGzgiIr9QtnsvAO3t7XGvzs7OXruMRCJoampCaWlprCwQCKC0tBSNjY2Ohzl16lQ0NTXFGrQjR45gx44duPvuu40+Lhs4IiJKSH5+PnJycmKv6urqXsucOnUK0WgUubm5ceW5ubkIh8OO2507dy6eeOIJ3HHHHUhPT8d1112H6dOn49FHHzU6Pj6DIyLyC+XSOLg/PQ89fvw4srOzY8WhUMiFjQMNDQ1Yv349nn/+eRQXF+Pw4cNYsmQJ1q1bh9WrV/d5O2zgiIj8wuWB3tnZ2XENnJORI0ciGAyitbU1rry1tRV5eXmO66xevRrz58/HokWLAAATJkxAR0cH7rvvPjz22GMIBPrW+cgGzoBlGGGlXd6lGb3lGZY1J4D0nuHM3Xa3JtpO3IdZtJ20nURm9DYd/yNGj2pYwmRbUrQkhByVCc3oLeSWtCyzyEe7Ww4gCAZNZ303j4KVZ9WW6obZjN66eil9ctO6T5/LyMhAUVER6uvrMWvWLACAbduor69HRUWF4zrnzp3r1YgFgz25VpVBNC0bOCIiv0hSqq7KykosXLgQkydPxpQpU1BTU4OOjg6Ul5cDABYsWICxY8fGnuHNnDkTGzZswKRJk2JdlKtXr8bMmTNjDV1fsIEjIvKJLwRA9ns7JmbPno2TJ09izZo1CIfDKCwsxM6dO2OBJy0tLXF3bKtWrYJlWVi1ahVOnDiBUaNGYebMmfjBD35gtF82cERENOAqKirELsmGhoa4/6elpaGqqgpVVVX92icbOCIiv/DZbAJs4IiI/MKGSw2cC9sYBBzoTUREnsQ7OAfGIcFiWHMiyV6lbTkXJxISL4fkC+VRIWpJk6TYPKmy8z6k5bUJkqXPISUKdjHZslLOn8OSQviF31ubx9rwO7QMhyKISZghnwsq6FxnEkm2LH0+y3A4gEi3vOG2Um74gM/u4NjAERH5hYI82M90OymAXZRERORJvIMjIvIJZVtit7HZdlw4mEHABo6IyC989gyOXZRERORJ/r2Ds7sTWMfFP1sME8qKXQJSuS4yUIxMdCcRcs9xmSVJliMDDRNDA7BtIdLPONmyeVeOFC0pJVu2As7lAcjRebbwd2kg6HwyiNGS0rFqflclRX1KYZ/Sd6iNgpX27VzuWnSljnGkZgLXl8GgLOOk487b6f8mBoN/GzgiIp/x2zO4Aemi3LhxIwoKCpCZmYni4uLYtOOSmpoa3HTTTRg2bBjy8/OxbNkyXLhwYSAOjYiIfML1O7ht27ahsrIStbW1KC4uRk1NDcrKynDo0CGMHj261/JbtmzBihUrUFdXh6lTp+LDDz/EPffcA8uysGHDBrcPj4jIv2yXuij9ege3YcMGLF68GOXl5Rg3bhxqa2uRlZWFuro6x+V3796N22+/HXPnzkVBQQHuuusuzJkz55J3fUREZEhZ7r1SgKsNXCQSQVNTE0pLSz/fQSCA0tJSNDY2Oq4zdepUNDU1xRq0I0eOYMeOHbj77rvF/XR2dqK9vT3uRURE9EWudlGeOnUK0Wg0NondRbm5uTh48KDjOnPnzsWpU6dwxx13QCmF7u5u3H///Xj00UfF/VRXV2Pt2rVuHnq/SPkE3cx5J3YJGOaVBADVLUQZdhtGUWr+ihNzSAoReuLxCuVSpKRuW2Kkppt/jQrbCggRi9LvakP+fFKEpbKEz21JkatSRGQCeUylc0eYfVmly1GGYp5K026xBOqlWJc9gkEmg6yhoQHr16/H888/j/379+PVV1/F9u3bsW7dOnGdlStXoq2tLfY6fvz4IB4xEVGKsgPuvVKAq3dwI0eORDAYRGtra1x5a2sr8vLyHNdZvXo15s+fj0WLFgEAJkyYgI6ODtx333147LHH4qYxvygUCiEUCrl56ERE5DGuNsMZGRkoKipCfX19rMy2bdTX16OkpMRxnXPnzvVqxIJ/6tZQKkVGExIRpYKLUZRuvFKA68MEKisrsXDhQkyePBlTpkxBTU0NOjo6UF5eDgBYsGABxo4di+rqagDAzJkzsWHDBkyaNAnFxcU4fPgwVq9ejZkzZ8YaOiIi6j+lLFfmP0yVew/XG7jZs2fj5MmTWLNmDcLhMAoLC7Fz585Y4ElLS0vcHduqVatgWRZWrVqFEydOYNSoUZg5cyZ+8IMfuH1oRETkIwOSqquiogIVFRWO7zU0NMQfQFoaqqqqUFVVNRCHApVITjjlvI5rM30DYs5JOVrSuVjMG6jNJyjlBzTLJyhFYwLmOSeVsHwieSWjYqSmWY+8m7kopTPHEmbPDmpCBqUIS+nTKcMZvS1L03MizNwtnTtydKzu/BSOS6p/Yp1xL0eluG/hWqGT0DXJLXbApYHeqXELx1yUREQ+oWzzpOPO20mNBi41Yj2JiIgM8Q6OiMgvXJsux6dRlERENDS5F0WZGg0cuyiJiMiTeAfnBtPoSh1pgIlhrkYpKrFnF2YRb3ZUiGTU5iw0ndHb7JikSEndvsUIThf/GpW2JUVXStGEujNKirCUck7aUed9i7OPaxINSueCJcwmLv6umvNTzkUpRf+6GPDgZl0eitxKs5UiuSjZwBER+YR7yZbZRUlERJQ0vIMjIvIJvwWZsIEjIvILnz2DYxclERF5Eu/giIh8wm9BJv5t4BJIkmocQqxNtiy8JyVVlvq8pdB3XTeENLTAsFwK+de+Z7gtaTiAFK4OaIYDSJ9vEJ4nSCH5lpAI2XLxmIKW80klJlt2mGQ49l7Q7LuVzzXN+Wk4nMMyTaqcQLJl47qfyPVlEPjtGRy7KImIyJP8ewdHROQ3PgsyYQNHROQTfnsGxy5KIiLyJN7BERH5hN+CTNjAORCnp5dIUVmapLXilO/SKuIuDBPTatZRUlJlqVyX8FioALZtlrg5kUTPckJnKTpv4DsyLCGS0RISBQeE5MWA/J0EhEhN6TsPCBGc2t9VPEeEZMu2EE2YwPkpfIWaOiPUMW29NHu4ZHytSDbl0jO41JjQm12URETkTbyDIyLyCb8FmbCBIyLyCaXceX7m5hR8A4ldlERE5Em8gyMi8guXuih1QUJDCRs4E1LEVCI574T3xAAvMVpSiBjURakZRiYmkotS2oeYc9LwmLR5MIWoQduwUibSlSPlnATMIhl1pAhL6TuUWJbwPYmfAQhIuSilc0eMwNXkapSi/MS8ls6LWy7WS+NclEOUUgFXooZVivRRsouSiIg8iXdwRER+YVvudC+yi5KIiIYSv2UyYRclERF5Eu/giIh8ggO9fcLSRnENQsSUYS5KKT+gPEu1JsJRWqdbinw0zwcpRj+KOSqlmbuFciFSsuc9w0g/w3IdceZuMU+ktCWziEgdK+C8k0ACn1vMg9ktRAWnOdezhM5P4RwR80FKdcxNwr6115ckYhQlERGRB/j2Do6IyG/YRUlERJ7EKEoiIiIP4B0cEZFP+O0Ojg2cASm3nRzFpct5Z1pumOuvO4FZmU2jDDW5D8XoR8MclVK0ZFSXB9Pwc0iRnQkRtiXNti1FV7oZpWaac1JaHgCUkAfTOBdlAuennI9V2JBYLtdLqS6LeS1TjFIuPYNLkQaOXZRERORJvIMjIvIJv42DYwNHROQTfhsmwC5KIiLyJN7BERH5BKMoiYjIk9jAeY0ahKSn4jT3mtBi6SGt1LctnVDi8uZh9HJov1k5AEQHeDiAlJwZkMP+3UyqbEoZDhOQhhUkwoo6bysqDAeQjqlnW2bngpTQWXd+ulcHpDqmG74zCMMBBuOaRAD4DI6IyDeU/XmgSf9e5vveuHEjCgoKkJmZieLiYuzdu1e7/OnTp/HQQw/hiiuuQCgUwo033ogdO3YY7dP7d3BERAQgeV2U27ZtQ2VlJWpra1FcXIyamhqUlZXh0KFDGD16dK/lI5EI/vIv/xKjR4/GL3/5S4wdOxYff/wxRowYYbRfNnBERJSQ9vb2uP+HQiGEQqFey23YsAGLFy9GeXk5AKC2thbbt29HXV0dVqxY0Wv5uro6fPbZZ9i9ezfS09MBAAUFBcbHxy5KIiKfuDjQ240XAOTn5yMnJyf2qq6u7rXPSCSCpqYmlJaWxsoCgQBKS0vR2NjoeJz//u//jpKSEjz00EPIzc3F+PHjsX79ekSjZpNR8w6OiMgnbGW5knv14jaOHz+O7OzsWLnT3dupU6cQjUaRm5sbV56bm4uDBw86bv/IkSP4zW9+g3nz5mHHjh04fPgwHnzwQXR1daGqqqrPx+nfBk5KkAxd8mTDpMq2JhJOWEWM9OsWktYKEWS2LpmtuC1hHSnRs6aiSFGOcrnwOYTlo5ooSjla0mzfiTyrEKMiA1IUpZC8OIEoSmnfluX8uS3h/NR+bum7En4P6VzTnZ8BaR/SeSscr/T59PVSqstSEmbDa4XHZGdnxzVwbrFtG6NHj8aLL76IYDCIoqIinDhxAk8//TQbOCIicuBSqi5xaIaDkSNHIhgMorW1Na68tbUVeXl5jutcccUVSE9PRzD4+R81t9xyC8LhMCKRCDIyMvq07wF5BpeMcFAiItK7GEXpxquvMjIyUFRUhPr6+liZbduor69HSUmJ4zq33347Dh8+DPsLd9Qffvghrrjiij43bsAANHAXw0Grqqqwf/9+TJw4EWVlZfj0008dl78YDnrs2DH88pe/xKFDh7Bp0yaMHTvW7UMjIqIkqKysxKZNm/Czn/0MH3zwAR544AF0dHTEoioXLFiAlStXxpZ/4IEH8Nlnn2HJkiX48MMPsX37dqxfvx4PPfSQ0X5d76JMVjgoERHpJWsc3OzZs3Hy5EmsWbMG4XAYhYWF2LlzZyzwpKWlBYHA5/db+fn5ePPNN7Fs2TLceuutGDt2LJYsWYLly5cb7dfVBu5iOOgXW2KTcNA33ngDo0aNwty5c7F8+fK4/tcv6uzsRGdnZ+z/Xx6LQUREvSUzF2VFRQUqKioc32toaOhVVlJSgt/+9rfG+/kiVxu4wQoHra6uxtq1a9089L4xzVOnzXknlHdLUWpCuZDzUczbhwRyUYp5IoV9a7Yl5qg0jJaUjqnnPbOITHE7CVRiKYekdCoEAs77kKIudXQ5JE2Wjwbk7zYQdP4gVlSIBhVzUWrOT+EcEYMjhDoj1rFE8k0ORo5Kcl3SB3p/MRy0qKgIs2fPxmOPPYba2lpxnZUrV6KtrS32On78+CAeMRFRarJVwLVXKnD1Dm6wwkGldDBERCRTyqUZvVNkuhxXm+FkhoMSERF9kev3mckKByUiIr1kjINLJteHCSQrHJSIiPQ4o7cLkhEO6irTPHIJ5KKUJgwU80FK+RWliMEEIhyV6Yzeulm1hShHaZ1uMbrSbKZvQI68FHNUwr3KKv3kFoToSuGYggnMKCnNEy1Vcvk3kvctztwdMDx3NOenlItSrgPC7y3mj0wgF6W4vD9yTqYq5qIkIvIJt2cTGOrYwBER+YTfuihTYzADERGRId7BERH5hN/u4NjAERH5BJ/B+YRlSzFnunWkaEnDckCTJ89wNmPTvH2QI9ikGa/F/JG6SEYpF6VhbknT7QAQ0whJlXIw/hoVZ/oWoivF8yMBUWmm76g0+7jmdxXyVAbTnKMJxXNNc36K57RQB8TJNxPJRSm8J9Z9ge76Yp5llBLl2waOiMhvlHLnDzqVIq00GzgiIp/w2zM4RlESEZEn8Q6OiMgnlEtBJqlyB8cGjojIJ9hFSURE5AFGd3DTp09HYWEhampqBuhwhgZLSqAqlUuJcbVJXYVyw6TKYrkmma2coNksqbIU2q9fRxqKYDYcoFuzb+mvS9NhAol05QSkkHyhXAnltuXeX8iWJfwW0r61v6vziSsmYXbx/BSHFki/kzhMQFMvxbosJW5OrWTLfruDYxclEZFP+G2gd5+7KO+55x7853/+J5577jlYlgXLsnDs2LEBPDQiIqLE9fkO7rnnnsOHH36I8ePH44knngAAjBo1asAOjIiI3MUuSkFOTg4yMjKQlZWFvLy8gTwmIiIaAOyiJCIi8gDPB5ko5Zz0VPv3h/G09VLklbyKEqLOxKTKQgJhKRpNlwhZ2reYCFlaXpdsWYyWFJIqG0ZL6pItS+9J3SpudrdIP7kURSmVBwPuZVu2hO9cTACt+V2lRMzSdx6UInO156dhHRATkTsvb2mjm12q+xrSNWkwKFhQ+qtfn7eTCowauIyMDESjqRUWS0REPfz2DM6oi7KgoAB79uzBsWPHcOrUKXFMDBERUbIZNXDf//73EQwGMW7cOIwaNQotLS0DdVxEROSyi0EmbrxSgVEX5Y033ojGxsaBOhYiIhpA7KIkIiLyAM9HUYoSySEnRksK5bppb01zS3abRaPZmlx/Uq5BaR1xeU0ko7ROt7AP02hJXRSlmAdTiPwalFyUECIWkzg1ciAg5aKUn627du7ozk8xwlgol3JUivleE4iiNI6uHJrBeDZcGgeXIlGUvIMjIiJP8u8dHBGRz/jtGRwbOCIin7BhudK9yC5KIiKiJOIdHBGRX7jURSlOMjvEsIFzIM/obRhhpQu8MpzpWJwBWcq7KEWWQRd5aZZzUoqIBIBuIT+gLeQTFHNRSrkrhe3otiUFz7k5aDUqdN3IM307byeo6QIyvUBJ+5Z+I8uST9yAkCPTNOek7vwUz2njumEYXQkY1/FUm9GbswkQERF5AO/giIh8glGURETkSTb0PbQm20kF7KIkIiJP4h0cEZFPsIvSL2zNrLpCZJQcXSmF58m7kGYhNs1RKUap6WbbliITo86ngxjJqImilCIZu6RclMY5KuUKJkVYSpFfg1FZxdmzhXJdukQVcOd4xfyYuuhYMYrS+dyRJkhO05yfxpGXYv5WqY5pIh+FuizXfak8ebN269jKnQhI3aToQwm7KImIyJP8ewdHROQzChaUC2m23NjGYGADR0TkExzoTURE5AG8gyMi8omeIBN3tpMK2MAREfkEn8F5jZvhuoaJWJVu10Iftm2YUFZa3hbC63veM0uqLC2vTbYshvcL2zIcDtCl+XzSMIHBSLYsMU22bEtvAACkMHcpebKQbNkSlo/Kf54Hg877Nj13tOenS3UgKA0L0dRLyzSheiKG6BACL/J+A0dERAD8F2TCBo6IyCeU0icRMNlOKmAUJREReRLv4IiIfELBgs0gEyIi8homW/YJ7VTzYgJVKcLKvWTLqluIOpOi1ITybinRLICo8J4YLdntfJrYQrQiYB4t2SV8DikislvYvm7f0l+uyYyiDAgJjzVfrchWziecJURLSscUDMg7l86FNCm6UjjXdOdn0PBcl+qMnGxZ3LWmLkt13zAxOw0q3zZwRER+47coygEJMtm4cSMKCgqQmZmJ4uJi7N27t0/rbd26FZZlYdasWQNxWEREvqZcfKUC1xu4bdu2obKyElVVVdi/fz8mTpyIsrIyfPrpp9r1jh07hu9///v4+te/7vYhERGRD7newG3YsAGLFy9GeXk5xo0bh9raWmRlZaGurk5cJxqNYt68eVi7di2uvfZatw+JiIjweRelG69U4GoDF4lE0NTUhNLS0s93EAigtLQUjY2N4npPPPEERo8ejXvvvbdP++ns7ER7e3vci4iI9GwXX6nA1SCTU6dOIRqNIjc3N648NzcXBw8edFznnXfewb/8y7+gubm5z/uprq7G2rVr+3OoWsb56MRy+a8cJb0n5VEU8/NJEWeaKErhPSm3pBSV2KXLRSnuwzAXpfCXojYXpfDdSlGUUshzIpVY+otRygcpRVEmMlZJCfkrA0JuyYBh7koACFjO34r0ewcN80cC8jkt1QGxzkh1TFMvTeu4eK2gISGpmUzOnDmD+fPnY9OmTRg5cmSf11u5ciXa2tpir+PHjw/gURIRecPFcXBuvFKBq3dwI0eORDAYRGtra1x5a2sr8vLyei3/0Ucf4dixY5g5c2aszP7TX0RpaWk4dOgQrrvuul7rhUIhhEIhNw+diMjzOEygHzIyMlBUVIT6+vpYmW3bqK+vR0lJSa/lb775Zrz77rtobm6Ovf7qr/4Kd955J5qbm5Gfn+/m4RERkY+43kVZWVmJTZs24Wc/+xk++OADPPDAA+jo6EB5eTkAYMGCBVi5ciUAIDMzE+PHj497jRgxAl/5ylcwfvx4ZGRkuH14RES+lcxxcMkYH+16JpPZs2fj5MmTWLNmDcLhMAoLC7Fz585Y4ElLSwsCmlRAREQ0MJLVRXlxfHRtbS2Ki4tRU1ODsrIyHDp0CKNHjxbX6+/46AFJ1VVRUYGKigrH9xoaGrTrbt682f0DcqKdblsgRUx1C/kEo5ooSim6zDQXpRBNKC0PaKIohTyDUoSjbkZvKcpRLhciNQ2jK3XvyTkqB570awQDznu3pZyIkC8u6QEhL6I0c7fwFWpn9BaON2iYo1J3fpqe01KdkSOP5XPHEuqy8YzeiVxfPOyL46MBoLa2Ftu3b0ddXR1WrFjhuM4Xx0f/13/9F06fPm28X95KERH5hNvj4L48Hrmzs7PXPgdrfLQTNnBERD7h9jCB/Px85OTkxF7V1dW99qkbHx0Ohx2P8+L46E2bNvXr83I2ASIiSsjx48eRnZ0d+78bw7cSHR/thA0cEZFPKLjzzPnik8rs7Oy4Bs7JYI2PdsIuSiIin1BwqYvSIJVcMsdHe/8Ozu5KYB1hNl5hxmQoKfJK/vtBnIXYMHefNDNyNCr/tFI0oTgLt7RvTT5IacZtqVyKlpSX10SJGuaWFHNRJhBOLc2SLeV3lCJag5p8kOlCJKPpMcnlutm2nfedJkRwiuea5vyUzuk0MSpSKBfqmDYXpVjHzWb01krkmpTiKisrsXDhQkyePBlTpkxBTU1Nr/HRY8eORXV1dWx89BeNGDECAHqVX4r3GzgiIgIA2Krn5cZ2TCRrfDQbOCIin3BrNu5EtpGM8dF8BkdERJ7EOzgiIp/w22wCbOCIiHzCrdm4U2WaV3ZREhGRJ/n2Ds7ShfcaTlsv/TmjhHBnAFDKMKmyWC4NE5D/dnErqXKXZh8RYZ1OITxcGg7QKYWZ65ItG67j6l+jwj7EZMtCqL6tGQpga4ZIODOr5pYmhCAoDCHoDpol8I6mycmIpXPatG5IdUxXL8Wwf8Nrgvb6kkRuzcbtyxm9iYho6GIXJRERkQfwDo6IyCeUkhMvmW4nFbCBIyLyCRsWbIM8krrtpAJ2URIRkSfxDs6BZRgxJc1OrzRRhsbJlqVoNCEiTIp81L0nRkuKiZDlfZgmT5aiJd1MtixGUQ5Cd0tA+INXiqIMaqLUTJMtS6SkyrpEz13ClyUlQu4OmJ+f0jkt1QHTZMv6eulcLl0TxGvFEJWsXJTJwgaOiMgvXHoG50pCy0HALkoiIvIk3sEREfmE34JM2MAREfmE34YJsIuSiIg8yft3cFJYlDYXpWk+OmnXmlyUUg7JLuefRMzDJ0QrRoWcj4Ac/RgxzBMp5ZvUvSdFS0rLdwnRhNIxAUC38NelFEU5GH+NWmIUpXN5mrQC5DRJbk1hEtDlohQiOKVzJy3oXJd056d0Tkt1QKozcnSlLheleFBCuXSt0FxfpGvSIPBbqi7vN3BERATAf8ME2EVJRESexDs4IiKfUHBnCFuK3MCxgSMi8oueLkoXhgmkSAvHLkoiIvIk397BJTSjd7ewTrfwF5HmLyUpH54UKSbOwi3lotREism5KM0iHDs1uShNoyUjYs5J5+9Qiq4EdLkonZeX/qJNJLpSCn4MSDknhRWimnyQtjRTdcCdP6stIUclAASjznUjaDmXi+ea7vwUz2nnOpAuzegt5ZzU3cFIdVmq+yk3o7e/xsH5toEjIvIbvw0TYBclERF5Eu/giIh8gl2URETkSeyiJCIi8gDewTkR884535crKQpPEykm588T8u0Zztytm21byhsozcLdaZijsmcfZtGSnUK0pGl0JSBHSyZ3Rm/nfYszemtyUUaFaElb/Lva7O/YNE0Ep3QupBnmqMyw5XyM0jkt1QGxzkh1TJcjVjhHLOkkSbEZvZVLqbrYRUlEREOK3zKZsIuSiIg8iXdwREQ+4bfZBNjAERH5hN+GCbCLkoiIPIl3cEREPuG3cXDeb+CkcGRNMlQxUaqYhFkIPxfCnQFAiaHQzj+JLYTwSwloI0I5YJ5UWQrVv6D5fPI6ZsMEuhNIttwl/ExSt0q3C9OHXIoUem8JwwHSNX0r0ucQkzAb9icFNL+rJXyOdGGd9IBzXdKdnxmGdUCqM1Id09VLMdmyaVJlbTJ3eYjEQPPbMzh2URIRkSd5/w6OiIgA+G8cHBs4IiKfYBclERGRB/AOjojIJ/w2Ds63DZx2SnkxWlJItixFcQnJXgE5EawUFdklLN8lLC8lrAXkBLhiFKVQ3qlJtixFRYrlQnRll1CRdMmWxShKYfnuQYh5TgsISXyF5aWE0QAgfFVIF6NBhQhAIYLTsuTfNShFS1rO+0gPCMmWo3L9k85pqQ5IdUZMtqypl2Ii5u4u53LT6Mok89swAXZREhGRJw1IA7dx40YUFBQgMzMTxcXF2Lt3r7jspk2b8PWvfx1f/epX8dWvfhWlpaXa5YmIKDE2Pg806dcr2R+kj1xv4LZt24bKykpUVVVh//79mDhxIsrKyvDpp586Lt/Q0IA5c+bgrbfeQmNjI/Lz83HXXXfhxIkTbh8aEZGvKRdfqcD1Bm7Dhg1YvHgxysvLMW7cONTW1iIrKwt1dXWOy7/88st48MEHUVhYiJtvvhkvvfQSbNtGfX2924dGREQ+4moDF4lE0NTUhNLS0s93EAigtLQUjY2NfdrGuXPn0NXVhcsvv1xcprOzE+3t7XEvIiLSU4l0Rzq8fBlFeerUKUSjUeTm5saV5+bm4uDBg33axvLlyzFmzJi4RvLLqqursXbt2n4dq3aq+W4hV5wQ2qaEsDa7W/77wRbWiQpRXFJUZLeQn0+KfNS9F5G2ZZhXUveeFC15QQg6k3JOSpGSuvekyEQxt6O8C5H0jUSEjQWFwEdtFKXwnUSFPJFu/h0r5dRMF86d9KjzB9ednyEp76qYv1XKOWleL6W6LP4g0rVCd31JIqVcymSSIg3ckIqifPLJJ7F161a89tpryMzMFJdbuXIl2traYq/jx48P4lESEVEqcPUObuTIkQgGg2htbY0rb21tRV5ennbdZ555Bk8++SR+/etf49Zbb9UuGwqFEAqF+n28RER+wnFw/ZCRkYGioqK4AJGLASMlJSXiek899RTWrVuHnTt3YvLkyW4eEhER/UnPMzTlwivZn6RvXM9kUllZiYULF2Ly5MmYMmUKampq0NHRgfLycgDAggULMHbsWFRXVwMAfvjDH2LNmjXYsmULCgoKEA6HAQDDhw/H8OHD3T48IiLyCdcbuNmzZ+PkyZNYs2YNwuEwCgsLsXPnzljgSUtLCwKBz28cX3jhBUQiEXz3u9+N205VVRUef/xxtw+PiMi3OF2OCyoqKlBRUeH4XkNDQ9z/jx07NhCH8Dlh9tyEclFKOQ6FKC5pRuGewxLy5wl5Ik1n7pbyTQJApxClJs3QLZVL0ZWAebTkBTFHpfPyuvyRXUKIl9StootYdIsULSmkqBRnMgeALuFrD0k7ES9HzhuSjgmQf7+gEF2ZFnD+oTI056d0Tos5J4VtSXVMVy/FXJTS+ZZILspkz+jt0nZSwZCKoiQiInKLb2cTICLyG/Wnf25sJxWwgSMi8gl2URIREXkA7+CIiHzCbwO9/dvA6aKclPPPp4TgJykqy9bN6C3lnBTKu8RZuM1m5wbMoyUvRJ3D6s5rcvqdE9YRoyiF8ojQF6LPRem8TlSKrhS2oxJIuCfNki19U0Fh+XRN30q6GA3qvC1biK6Uv0J55wHh2UvQcj53pCjKUMB8xnmpDkh1RsxRqZvRW6gDUt23hGuF9vqSREq59AwuRZJRsouSiIg8yb93cEREPsMuSiIi8iR2URIREXkA7+CIiHxCwZ3uxdS4f2MDR0TkG7ZSsF1onuwU6aL0fgMnxffqwnilaei7hXBrIexYClPu2YUQCm2cVFkK7dckWxbXkRIkmyVO1r13XvhqLwgZjzulYQKaCtYlJMCNChVbqqyJPKuwIAwTEIYDBIXl05X89CBdSMQs/ExQwvABie65RdByflfK85wuDAe4EJDPz1DU+SSR6oBUZ8TkzJp6KQ4hEOq+eK3QDkNKXrLlZNq4cSOefvpphMNhTJw4ET/+8Y8xZcoUx2U3bdqEn//853jvvfcAAEVFRVi/fr24vITP4IiIfEK5+M/Etm3bUFlZiaqqKuzfvx8TJ05EWVkZPv30U8flGxoaMGfOHLz11ltobGxEfn4+7rrrLpw4ccJov2zgiIh8wnbxZWLDhg1YvHgxysvLMW7cONTW1iIrKwt1dXWOy7/88st48MEHUVhYiJtvvhkvvfQSbNtGfX290X7ZwBERUULa29vjXp2dnb2WiUQiaGpqQmlpaawsEAigtLQUjY2NfdrPuXPn0NXVhcsvv9zo+NjAERH5hA3l2gsA8vPzkZOTE3tVV1f32uepU6cQjUaRm5sbV56bm4twONyn416+fDnGjBkT10j2hfeDTIiICID7UZTHjx9HdnZ2rDwUCvV721/25JNPYuvWrWhoaEBmZqbRup5v4CxhenhLiLQDIE5Dr4QEwkpIOhzVJHXtFt7r6kp3LJcS0HYKkWJS4mTdexeE6DwpcbJUDgAdhtGSF6LO33mn8FtElByl1iU8IeiG8zpShXc1ilKIZEyD82+RLoVEAsgQEhtHhchLW+iosYVjhVgOCMGgCAaEaFDhXMsQkjADQKYYMWxWZ6Q6pquXUl2W6r50HdFdX6RrUirKzs6Oa+CcjBw5EsFgEK2trXHlra2tyMvL0677zDPP4Mknn8Svf/1r3HrrrcbHxy5KIiKfSEYUZUZGBoqKiuICRC4GjJSUlIjrPfXUU1i3bh127tyJyZMnJ/R5PX8HR0REPb74/Ky/2zFRWVmJhQsXYvLkyZgyZQpqamrQ0dGB8vJyAMCCBQswduzY2DO8H/7wh1izZg22bNmCgoKC2LO64cOHY/jw4X3eLxs4IiIaULNnz8bJkyexZs0ahMNhFBYWYufOnbHAk5aWFgQCn3covvDCC4hEIvjud78bt52qqio8/vjjfd4vGzgiIp9I1h0cAFRUVKCiosLxvYaGhrj/Hzt2LIGj6o0NHBGRTySShUTaTirwbwOnzUXp/J4ScthJ+etsIRoMkPPkSZFiEWHfnbaUV1IXRSlERQoRZOeF5aW8kj3bcq4A56PO3+154feICJGPnegS9x2xnN+LwvmAu62Bj2pLU86/a1CoghnKOTIQAELCe13Cbx5VwnkrlAc0UZRiTk0hR2W65XweZAg5KgH5nJbqgFRnpDqmq5dSXZbqvnSt0F5faND4t4EjIvIZ5VIXJe/giIhoSLEtG5bV/xnhbFdmlRt4HAdHRESexDs4IiKfsKFgJSmKMhnYwBER+cTFVMlubCcVeL+Bk6KZdLkopSjKrgznTUkzDWty3pnO3N0pRIqdF3NRyr3P54T3pGjJDiFg8awQKQkAHUK05DkhD9859J5mAwA6Lan8grjviBVxLI8KkZdRJUdkuiUYcI58DMK5PEM5n2sAEFLOCWdDyjnRbdR2LpdrgBzhKOWpDArRlekB53MtIyCfO9I5nZVmVmekOqarl1JdVl1SFKXzuaa9vjDCctB4v4EjIiIAPX/UuNNFmRrYwBER+QSjKImIiDyAd3BERD5hw4blwt1XqtzBsYEjIvIJNnA+YXVrIueE6EDTXJRRTc47aRbizm7n8guGM3dLkZIA0CHknJRm4ZaiJc90yzkcO4TIxA4h+vGcddaxvNM671geUc7lANClnPfRbTtHZEohz0oza7jEEmbbtoSnAWkB5wjHdDhHSgJAZ2CYY3lIOZd3K+f5s6K28z5sTR5M6ZIhzegt5ajURVFmSjPOSxHGQp2R6piuXprnonT+HNrrCw0a3zZwRER+w3FwRETkSYyiJCIi8gDewRER+YSC7crdF7soiYhoSFGIQrnQcaeEiYiHGnZREhGRJ3n+Ds4Skvtqk6FGhWECQghxd8Q5HLlbCFMG5NDmTiGE+ZyQbFkaDiANBQCAc8JX0tHl/LmlxMnSUAAAaBfC/juE8vNodyzvtJ2X77LPifuWhgNEbefEuPYgJFsOWEKy5YBzUmVp+AAApAeyHMsjlvNwgG7L+QePwnl5CMMKACAgJOROE5IRpwlJmDOEJMwAkBl0fm+YUAcuMxw+oKuXUl2W6r50rdBdX8Rr0iDo6Z70T5CJ5xs4IiLq0TOPmxsNXGrMB8cuSiIi8iTewRER+URPkIlzt7HpdlIBGzgiIp/w2zM4dlESEZEnef8OTpgeXpcMVUmBl0IklS1EkHUKEVkA0ClEcl0QIsXOC5FiHUIS2I5uuRtCTp7s/F2dEaIPT1vOkY8A0B74f47l52zn8ogQLdnZfcaxPGrLyZZtJSRVFqMlB6O7RUjCLERXBiw5ijIoJFvuSnOOLI0EnMu7Al91LLd1EcZ2tmNxoNs5GjQYcP7c6UJyZgDIDDqvkxWUEo4LUZRSQnNNvZTqspyE2Xk72mTLwjVpMDAXJREReZKNKODCMzg7RZ7BDUgX5caNG1FQUIDMzEwUFxdj79692uV/8Ytf4Oabb0ZmZiYmTJiAHTt2DMRhERGRj7jewG3btg2VlZWoqqrC/v37MXHiRJSVleHTTz91XH737t2YM2cO7r33Xhw4cACzZs3CrFmz8N5777l9aEREvnaxi9KNVypwvYHbsGEDFi9ejPLycowbNw61tbXIyspCXV2d4/LPPfccvvWtb+Hhhx/GLbfcgnXr1uG2227DT37yE7cPjYjI12wVde2VClx9BheJRNDU1ISVK1fGygKBAEpLS9HY2Oi4TmNjIyorK+PKysrK8Prrr4v76ezsRGfn54EEbW1tAID29t6BB+fPOD/sjXTIf4Go885BGJ2dzuucjTj/2Gd1M14LD6HPRZ0DOs4Lsxx32s7HGrHlfnYhIxe6hJO2WzkfUxTyg3RbePouVQxp9mylpNm25UwK8num5W4S0r8Jx6o0xyR/J87fofSdS79RVJO6TDoX5HPK+byNaG4ApHP6fNSszkh1LFNTL0NCXVZC3Q8J1wpLc325IFyTutrjg4EuXs905zrpudrAnTp1CtFoFLm5uXHlubm5OHjwoOM64XDYcflwOCzup7q6GmvXru1VXnDVkgSO2kSrYTnRRdIFz/liZys516YUhNdl+Ee1czyry+SP4XF/SOC91xxL//jHPyInJ6ffRwQwijIlrFy5Mu6u7/Tp07j66qvR0tLi2ongB+3t7cjPz8fx48eRne0c/k3x+J0lht+buba2Nlx11VW4/PLLXdtmTwPX/+5FXzZwI0eORDAYRGtr/B1Na2sr8vLyHNfJy8szWh4AQqEQQqHe44RycnJYeRKQnZ3N780Qv7PE8HszF9DMvEB6rn5zGRkZKCoqQn19fazMtm3U19ejpKTEcZ2SkpK45QFg165d4vJERJQYpWzYLryk58BDjetdlJWVlVi4cCEmT56MKVOmoKamBh0dHSgvLwcALFiwAGPHjkV1dTUAYMmSJZg2bRqeffZZzJgxA1u3bsW+ffvw4osvun1oRES+1tO16EayZZ82cLNnz8bJkyexZs0ahMNhFBYWYufOnbFAkpaWlrhb7qlTp2LLli1YtWoVHn30Udxwww14/fXXMX78+D7vMxQKoaqqyrHbkmT83szxO0sMvzdz/M76z1KMQSUi8rT29nbk5OQgJ3McLMt56IYJpaJou/B7tLW1DelnqikZRUlEROZs2LB81EXJ8BwiIvIk3sEREflET/SjC3dwfo2iJCKiocmNQd5ubmegsYuSiIg8KWUbuB/84AeYOnUqsrKyMGLEiD6to5TCmjVrcMUVV2DYsGEoLS3F//7v/w7sgQ4hn332GebNm4fs7GyMGDEC9957L86edZ5J+6Lp06fDsqy41/333z9IR5wcnM/QnMl3tnnz5l7nVGZm5iAe7dDw9ttvY+bMmRgzZgwsy9ImmL+ooaEBt912G0KhEK6//nps3rzZaJ9KKag/DdTu3ys1gu9TtoGLRCL467/+azzwwAN9Xuepp57Cj370I9TW1mLPnj247LLLUFZWhgsXLgzgkQ4d8+bNw/vvv49du3bhV7/6Fd5++23cd999l1xv8eLF+OSTT2Kvp556ahCONjk4n6E50+8M6EnZ9cVz6uOPPx7EIx4aOjo6MHHiRGzcuLFPyx89ehQzZszAnXfeiebmZixduhSLFi3Cm2++2ed9+m0+OKgU99Of/lTl5ORccjnbtlVeXp56+umnY2WnT59WoVBI/du//dsAHuHQ8Pvf/14BUP/zP/8TK/uP//gPZVmWOnHihLjetGnT1JIlSwbhCIeGKVOmqIceeij2/2g0qsaMGaOqq6sdl/+bv/kbNWPGjLiy4uJi9fd///cDepxDiel31tc66ycA1GuvvaZd5pFHHlFf+9rX4spmz56tysrKLrn9trY2BUANyyhQWaFr+/0allGgAKi2trb+fOwBl7J3cKaOHj2KcDiM0tLSWFlOTg6Ki4vFueq8pLGxESNGjMDkyZNjZaWlpQgEAtizZ4923ZdffhkjR47E+PHjsXLlSpw75805UC7OZ/jFc6Qv8xl+cXmgZz5DP5xTQGLfGQCcPXsWV199NfLz8/Gd73wH77///mAcbkpz41xTKuraKxX4Jory4vxypnPPeUU4HMbo0aPjytLS0nD55ZdrP//cuXNx9dVXY8yYMfjd736H5cuX49ChQ3j11VcH+pAH3WDNZ+gliXxnN910E+rq6nDrrbeira0NzzzzDKZOnYr3338fV1555WAcdkqSzrX29nacP38ew4YNu+Q23ArvT5VhAkPqDm7FihW9Hj5/+SVVGr8a6O/svvvuQ1lZGSZMmIB58+bh5z//OV577TV89NFHLn4K8pOSkhIsWLAAhYWFmDZtGl599VWMGjUK//zP/5zsQyOPGVJ3cP/4j/+Ie+65R7vMtddem9C2L84v19raiiuuuCJW3traisLCwoS2ORT09TvLy8vr9dC/u7sbn332mXbuvS8rLi4GABw+fBjXXXed8fEOZYM1n6GXJPKdfVl6ejomTZqEw4cPD8QheoZ0rmVnZ/fp7g1wL8VWqgSZDKkGbtSoURg1atSAbPuaa65BXl4e6uvrYw1ae3s79uzZYxSJOdT09TsrKSnB6dOn0dTUhKKiIgDAb37zG9i2HWu0+qK5uRkA4v5I8Iovzmc4a9YsAJ/PZ1hRUeG4zsX5DJcuXRor89N8hol8Z18WjUbx7rvv4u677x7AI019JSUlvYagmJ5rfuuiTNkoyo8//lgdOHBArV27Vg0fPlwdOHBAHThwQJ05cya2zE033aReffXV2P+ffPJJNWLECPXGG2+o3/3ud+o73/mOuuaaa9T58+eT8REG3be+9S01adIktWfPHvXOO++oG264Qc2ZMyf2/h/+8Ad10003qT179iillDp8+LB64okn1L59+9TRo0fVG2+8oa699lr1jW98I1kfYcBt3bpVhUIhtXnzZvX73/9e3XfffWrEiBEqHA4rpZSaP3++WrFiRWz5//7v/1ZpaWnqmWeeUR988IGqqqpS6enp6t13303WRxh0pt/Z2rVr1Ztvvqk++ugj1dTUpL73ve+pzMxM9f777yfrIyTFmTNnYtctAGrDhg3qwIED6uOPP1ZKKbVixQo1f/782PJHjhxRWVlZ6uGHH1YffPCB2rhxowoGg2rnzp2X3NfFKMr0YK7KSLui36/0YG5KRFGmbAO3cOFCBaDX66233ootA0D99Kc/jf3ftm21evVqlZubq0KhkPrmN7+pDh06NPgHnyR//OMf1Zw5c9Tw4cNVdna2Ki8vj/uD4OjRo3HfYUtLi/rGN76hLr/8chUKhdT111+vHn744SF/UvfXj3/8Y3XVVVepjIwMNWXKFPXb3/429t60adPUwoUL45Z/5ZVX1I033qgyMjLU1772NbV9+/ZBPuLkM/nOli5dGls2NzdX3X333Wr//v1JOOrkeuuttxyvYRe/q4ULF6pp06b1WqewsFBlZGSoa6+9Nu76pnOxgUsLjlLpabn9fqUFR6VEA8f54IiIPO7ifHDBwOWwrP7HFiplI2p/NuTngxtSUZRERERuGVJBJkRENJAU4EoEZGp0/LGBIyLyCffmg0uNBo5dlERE5Em8gyMi8omeAdou3MGxi5KIiIYWdxq4VHkGxy5KIiLyJN7BERH5hUtBJkiRIBM2cEREPuG3Z3DsoiQiIk9iA0d0CSdPnkReXh7Wr18fK9u9ezcyMjJQX1+fxCMjMmW7+Br62MARXcKoUaNQV1eHxx9/HPv27cOZM2cwf/58VFRU4Jvf/GayD4/IgOp5ftbfVwJdlBs3bkRBQQEyMzNRXFyMvXv3apf/xS9+gZtvvhmZmZmYMGFCr6mC+oINHFEf3H333Vi8eDHmzZuH+++/H5dddhmqq6uTfVhEKWHbtm2orKxEVVUV9u/fj4kTJ6KsrKzXJMwX7d69G3PmzMG9996LAwcOYNasWZg1axbee+89o/1yNgGiPjp//jzGjx+P48ePo6mpCRMmTEj2IRH1ycXZBIAg3BsHF+3zbALFxcX48z//c/zkJz8B0DMpbn5+Pv7hH/4BK1as6LX87Nmz0dHRgV/96lexsr/4i79AYWEhamtr+3yUvIMj6qOPPvoI//d//wfbtnHs2LFkHw5RghynoTN89Whvb497dXZ29tpbJBJBU1MTSktLY2WBQAClpaVobGx0PMLGxsa45QGgrKxMXF7CBo6oDyKRCP72b/8Ws2fPxrp167Bo0SKxe4VoqMnIyEBeXh6AqGuv4cOHIz8/Hzk5ObGXU7f9qVOnEI1GkZubG1eem5uLcDjseLzhcNhoeQnHwRH1wWOPPYa2tjb86Ec/wvDhw7Fjxw783d/9XVwXCtFQlZmZiaNHjyISibi2TaUULCu+uzMUCrm2fTewgSO6hIaGBtTU1OCtt96KPW/413/9V0ycOBEvvPACHnjggSQfIdGlZWZmIjMzc9D3O3LkSASDQbS2tsaVt7a2/umusre8vDyj5SXsoiS6hOnTp6Orqwt33HFHrKygoABtbW1s3IguISMjA0VFRXFjRm3bRn19PUpKShzXKSkp6TXGdNeuXeLyEt7BERHRgKqsrMTChQsxefJkTJkyBTU1Nejo6EB5eTkAYMGCBRg7dmzsGd6SJUswbdo0PPvss5gxYwa2bt2Kffv24cUXXzTaLxs4IiIaULNnz8bJkyexZs0ahMNhFBYWYufOnbFAkpaWFgQCn3coTp06FVu2bMGqVavw6KOP4oYbbsDrr7+O8ePHG+2X4+CIiMiT+AyOiIg8iQ0cERF5Ehs4IiLyJDZwRETkSWzgiIjIk9jAERGRJ7GBIyIiT2IDR0REnsQGjoiIPIkNHBEReRIbOCIi8qT/D4VCuJc0yamLAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] @@ -238,12 +230,13 @@ " origin=\"lower\",\n", " extent=(x0, x_final, t0, t_final),\n", " aspect=(x_final - x0) / (t_final - t0),\n", - " cmap=\"plasma\",\n", + " cmap=\"inferno\",\n", ")\n", "plt.xlabel(\"x\")\n", "plt.ylabel(\"t\", rotation=0)\n", "plt.clim(0, 1)\n", - "plt.colorbar()" + "plt.colorbar()\n", + "plt.show()" ] }, { @@ -262,7 +255,9 @@ "cell_type": "code", "execution_count": 6, "id": "059fed69-c042-4fec-bf36-60e365c98de8", - "metadata": {}, + "metadata": { + "tags": [] + }, "outputs": [], "source": [ "class CrankNicolson(diffrax.AbstractSolver):\n", @@ -270,20 +265,18 @@ " atol: float\n", "\n", " term_structure = diffrax.ODETerm\n", - " interpolation_cls = diffrax.ThirdOrderHermitePolynomialInterpolation\n", + " interpolation_cls = diffrax.LocalLinearInterpolation\n", "\n", " def order(self, terms):\n", " return 2\n", "\n", " def init(self, terms, t0, t1, y0, args):\n", - " f0 = terms.vf(t0, y0, args)\n", - " solver_state = f0\n", - " return solver_state\n", + " return None\n", "\n", " def step(self, terms, t0, t1, y0, args, solver_state, made_jump):\n", - " del made_jump\n", + " del solver_state, made_jump\n", " δt = t1 - t0\n", - " f0 = solver_state\n", + " f0 = terms.vf(t0, y0, args)\n", "\n", " def keep_iterating(val):\n", " _, not_converged = val\n", @@ -300,12 +293,11 @@ "\n", " euler_y1 = y0 + δt * f0\n", " y1, _ = lax.while_loop(keep_iterating, fixed_point_iteration, (euler_y1, False))\n", - " f1 = terms.vf(t1, y1, args)\n", "\n", " y_error = y1 - euler_y1\n", - " dense_info = dict(y0=y0, y1=y1, f0=f0, f1=f1)\n", + " dense_info = dict(y0=y0, y1=y1)\n", "\n", - " solver_state = f1\n", + " solver_state = None\n", " result = diffrax.RESULTS.successful\n", " return y1, y_error, dense_info, solver_state, result\n", "\n", @@ -346,17 +338,7 @@ "outputs": [ { "data": { - "text/plain": [ - "" - ] - }, - "execution_count": 8, - "metadata": {}, - "output_type": "execute_result" - }, - { - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAcEAAAGiCAYAAACf230cAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy88F64QAAAACXBIWXMAAA9hAAAPYQGoP6dpAABKRklEQVR4nO3dfXQU5b0H8O/sJtmAklBLSAhGXnzhpUCCQWIQKxxTI3qosdYLlBqgvBwp6QG3VIhFgtCaVitFr5G0tDF4Kxf0HqFe4cSLUaBeApTQnIJHEJAQVHYFbRISSTbZnfsHl4kr8xsym8kuy3w/nlHz7Mwzz87O7LPzzG9+o6iqqoKIiMiGHJFuABERUaSwEyQiIttiJ0hERLbFTpCIiGyLnSAREdkWO0EiIrItdoJERGRb7ASJiMi22AkSEZFtsRMkIiLbYidIREQRt2vXLkyePBmpqalQFAVbtmy57DI7duzArbfeCpfLhZtuugnl5eWm18tOkIiIIq65uRnp6ekoKSnp1PwnTpzA/fffj4kTJ6KmpgaLFi3CnDlz8Pbbb5tar8IE2kREdCVRFAWbN29GXl6eOM+SJUuwdetWHDp0SCubOnUq6uvrUVFR0el1xXSloUREdPVoaWmBz+ezpC5VVaEoSlCZy+WCy+WypP6qqirk5OQEleXm5mLRokWm6mEnSEREaGlpwcBB18Lr8VtS37XXXoumpqagsqKiIqxYscKS+j0eD5KTk4PKkpOT0djYiPPnz6NHjx6dqoedIBERwefzwevx44OjA9EroWvhIucaA/jOzbU4deoUEhIStHKrzgKtxE6QiIg0vRIcSOhiJ3hRQkJCUCdopZSUFHi93qAyr9eLhISETp8FAuwEiYjoa5QAoASUy894mTq6W3Z2NrZt2xZUtn37dmRnZ5uqh7dIEBFRB1WxZjKpqakJNTU1qKmpAXDhFoiamhrU1dUBAAoLC5Gfn6/N/+ijj+Ljjz/G448/jsOHD+Oll17Ca6+9hscee8zUetkJEhFRxO3fvx+jR4/G6NGjAQButxujR4/G8uXLAQCnT5/WOkQAGDRoELZu3Yrt27cjPT0dzz33HP70pz8hNzfX1Hp5nyAREaGxsRGJiYn45JObkJDg7GJdflx//TE0NDR02zVBq/CaIBERaS5cE+x6HdGCw6FERGRbPBMkIqIOgf+fulpHlGAnSEREGkW9MHW1jmjB4VAiIrItngkSEZFGUS0IjImiM0F2gkRE1CGgXpi6WkeU4HAoERHZFs8EiYhIY7fAGHaCRETUwWa3SHA4lIiIbItngkREpFECKpQuBrZ0dflwYidIREQdOBxKRERkDzwTJCIiDaNDiYjIvjgcSkREZA88EyQiIo3dHqrLTpCIiDqoANQuXtSLomuCHA4lIiLb4pkgERFp+CglIiKyL0aHEhER2QPPBImISMOb5YmIyL44HEpERGQPlnaCxcXFuO2229CrVy/07dsXeXl5OHLkSNA8LS0tWLBgAb797W/j2muvxUMPPQSv12tYr6qqWL58Ofr164cePXogJycHR48etbLpREQEdJwJdnWKEpZ2gjt37sSCBQuwZ88ebN++HW1tbbjnnnvQ3NyszfPYY4/hv//7v/H6669j586d+Oyzz/CDH/zAsN5nnnkGL7zwAkpLS7F3715cc801yM3NRUtLi5XNJyKyvQvXBJUuTpF+F52nqGpXUwPIzpw5g759+2Lnzp347ne/i4aGBiQlJWHDhg344Q9/CAA4fPgwhg0bhqqqKtx+++2X1KGqKlJTU/Hzn/8cixcvBgA0NDQgOTkZ5eXlmDp1anc1n4jINhobG5GYmIgvdw9GwrXOrtXV5Md14z5GQ0MDEhISLGph9+jWwJiGhgYAwHXXXQcAqK6uRltbG3JycrR5hg4dihtuuEHsBE+cOAGPxxO0TGJiIrKyslBVVSV2gq2trWhtbdX+DgQC+PLLL/Htb38biqJY8v6IiCJJVVWcO3cOqampcDgsGtizWWBMt3WCgUAAixYtwh133IERI0YAADweD+Li4tC7d++geZOTk+HxeHTruVienJzc6WWAC9cnn3rqqS68AyKi6HDq1Clcf/311lTGTtAaCxYswKFDh/D+++931yoMFRYWwu12a383NDTghhtuwKHjaejVK/gXU1yj/maIaZBHip2N+uVKk7BJm+N0i1WhHAAC5/Xr8n/l0p+/NVZ//lahHp/88fulutqEutr0h08C7fKwSru0jF8q1/+lqwb0ywNCudEyqnDwdt9Fgw7SAIUivA3FIX/TOITXpGUcTqncr1seE6tfDgCOGP3XnMIyzth2/XJXm7gOZ5y0jH65Q6jL2bNVt9zRQ78eAFCu8em/IJSr1+rX5TcYJWxP1N8ZfAnBdZ07F8CIG0+hV69ecmVkqFs6wYKCArz11lvYtWtX0K+TlJQU+Hw+1NfXB50Ner1epKSk6NZ1sdzr9aJfv35By2RkZIhtcLlccLku7Sx69XIgIeEbnaCq/6Ub45e/ZJwB/Z1UEeqCUK5K8wMICB+PX9UvDyjC/IrQoQnzA0A7hGUc0jqENjkMOkHFZCfYbq6zs7YT7P4hdEWIJrCyExTLY8LQCUodV6z+to3R/613YRmX/jJOYRlHvP62dfbQb6ujp/yrR+kp7NPXCMf4tfp1+YVyAGjvpf+h+xKEHzdWXuJR0fWnQERRYIyl0aGqqqKgoACbN2/Gu+++i0GDBgW9npmZidjYWFRWVmplR44cQV1dHbKzs3XrHDRoEFJSUoKWaWxsxN69e8VliIgoNEpAsWSKFpZ2ggsWLMBf/vIXbNiwAb169YLH44HH48H58+cBXAhomT17NtxuN9577z1UV1dj1qxZyM7ODgqKGTp0KDZv3gzgwi+cRYsW4Ve/+hXefPNNHDx4EPn5+UhNTUVeXp6VzSciIpuxdDh07dq1AIAJEyYElb/88suYOXMmAOD3v/89HA4HHnroIbS2tiI3NxcvvfRS0PxHjhzRIksB4PHHH0dzczPmzZuH+vp6jB8/HhUVFYiPj7ey+UREZLPhUEs7wc7cchgfH4+SkhKUlJR0uh5FUbBy5UqsXLmyy22kDuG4zhWOdVD4cd+5iqkK0NXhzCj67Jg7lIiIbItPkSAiog68T5CIiGzLZtcEORxKRES2xTNBIiLqELAgMCaK7hNkJ3glMxlhJUXTqVbukFZGfUnttajcqK1mM8NENFIxIGSSMcgSEo5taJqFdUn7tOnPKYqiGMNGVbq+XaJou3I4lIiIbItngkREpFECF6au1hEt2AkSEVEHm10T5HAoERHZFs8EiYiog83uE2QnSEREHWw2HMpOMNJC+cVk0a8so3ByMQTdZGi6UU51OSzf3Pyh3NZg1a0Qodw6IT081/y6jV40eyuEtA7rPler9qmQmD1mouhMhrqGnSAREXWw2X2C7ASJiKiDzRJoMzqUiIhsi2eCRETUgcOhRERkV6qqdDnfcERz7ZrETtBq0s4Twk5lekc0GeFnXJe5qMBQIvxMRxGaTPIsJck2XHcYEmhLdUlRo2I9QmLtC3WZfH/SNneE8LmKn5+0QAjHhskoV7GeUL7sLTzGKfLYCRIRUQcOhxIRkW0xOpSIiMgeeCZIREQdOBxKRES2xdyhdJFhsJ5FH3Io0WlmI/zEiDmDdZuOkjTbphDWEdHcoWH4ZWs6vaXh+5NesGjbGu23Vn1OhusQii3KQ2r0/kzvCUJdimHyVwoXdoJERNSBw6FERGRbNhsOZXQoERHZluWd4K5duzB58mSkpqZCURRs2bIl6HVFUXSnZ599VqxzxYoVl8w/dOhQq5tORESqRVOUsLwTbG5uRnp6OkpKSnRfP336dNBUVlYGRVHw0EMPGdb7ne98J2i5999/3+qmExHZnhpQLJlCUVJSgoEDByI+Ph5ZWVnYt2+f4fxr1qzBkCFD0KNHD6SlpeGxxx5DS0uLqXVafk1w0qRJmDRpkvh6SkpK0N9//etfMXHiRAwePNiw3piYmEuWvSKZjS40uoBsNmrO9BO95d9ApqNAxXUb/M4Sfi0GpGWEdQT8QrnR+xNeCyka0iKKQ9ggQpiycVIO/VedQk5RaVs5hASsxvuO0DIL9x05EtqaYyak4zKKgkGuRJs2bYLb7UZpaSmysrKwZs0a5Obm4siRI+jbt+8l82/YsAFLly5FWVkZxo0bh48++ggzZ86EoihYvXp1p9cb0WuCXq8XW7duxezZsy8779GjR5GamorBgwdj+vTpqKurM5y/tbUVjY2NQRMREV3GxejQrk4mrV69GnPnzsWsWbMwfPhwlJaWomfPnigrK9Odf/fu3bjjjjvwox/9CAMHDsQ999yDadOmXfbs8Zsi2gmuX78evXr1wg9+8APD+bKyslBeXo6KigqsXbsWJ06cwJ133olz586JyxQXFyMxMVGb0tLSrG4+EdHV52J0aFcn4JITkdbWVt1V+nw+VFdXIycnRytzOBzIyclBVVWV7jLjxo1DdXW11ul9/PHH2LZtG+677z5TbzeinWBZWRmmT5+O+Ph4w/kmTZqEhx9+GKNGjUJubi62bduG+vp6vPbaa+IyhYWFaGho0KZTp05Z3XwiIjKQlpYWdDJSXFysO9/Zs2fh9/uRnJwcVJ6cnAyPx6O7zI9+9COsXLkS48ePR2xsLG688UZMmDABTzzxhKk2Ruw+wb/97W84cuQINm3aZHrZ3r1745ZbbsGxY8fEeVwuF1wuV1eaSERkPyosuFn+wn9OnTqFhIQErdjK7+QdO3bg6aefxksvvYSsrCwcO3YMCxcuxKpVq/Dkk092up6IdYJ//vOfkZmZifT0dNPLNjU14fjx43jkkUe6oWVERDamWnCz/P93ogkJCUGdoKRPnz5wOp3wer1B5V6vVwyIfPLJJ/HII49gzpw5AICRI0eiubkZ8+bNwy9/+Us4HJ0b6LS8E2xqago6Qztx4gRqampw3XXX4YYbbgBwYZz49ddfx3PPPadbx913340HH3wQBQUFAIDFixdj8uTJGDBgAD777DMUFRXB6XRi2rRp1jT6Cn32lVURmmafEh/KOgIh5GwM+M1FaJqNAg0l+lVqk1yPqdkBAEKAJlTh/Tmcwg5q8P6kXVrx6zfYIUWNCttDjGSFvC84Irh/mp4/0q7Q76TuEhcXh8zMTFRWViIvLw8AEAgEUFlZqfUD3/TVV19d0tE5nU4AgGriwLS8E9y/fz8mTpyo/e12uwEAM2bMQHl5OQBg48aNUFVV7MSOHz+Os2fPan9/8sknmDZtGr744gskJSVh/Pjx2LNnD5KSkqxuPhGRralqaD/uvlmHWW63GzNmzMCYMWMwduxYrFmzBs3NzZg1axYAID8/H/3799euK06ePBmrV6/G6NGjteHQJ598EpMnT9Y6w86wvBOcMGHCZXvhefPmYd68eeLrtbW1QX9v3LjRiqYREdHlRCiB9pQpU3DmzBksX74cHo8HGRkZqKio0IJl6urqgs78li1bBkVRsGzZMnz66adISkrC5MmT8etf/9rUeplAm4iIrggFBQXi8OeOHTuC/o6JiUFRURGKioq6tE52gkRE1MFmT5FgJ0hERBpVVbocMHTFBhzp4KOUiIjItngmGCophFnKHSzNb/SLyextB6YTFBvdviAsI93WICbpltdhNnGyeCtECLc7iOsWYrrk+c3/4lWEhNhSuXibgjA/ADikuoRtqEj7rUNKoG3wuUr7gvA+pH3NEcK+Y/pWCLOJtSEfy9I2jLrbHTgcSkREthWh6NBI4XAoERHZFs8EiYhIY7fAGHaCRETUIYCuX8eMouugHA4lIiLb4pmgESt/zYQQoSleXDYZiSlHdIYSPWkuotMwAlWKArUoUbZh8m6TkbeRjA6Vyo2S5IsRtlKbhMTairASR0A+OOTPQ4g0NblPAeajlM1GgRrtt4pVQ31X6tmSzQJj2AkSEZFGDSjGP847WUe04HAoERHZFs8EiYioA4dDiYjIrux2iwSHQ4mIyLZ4JhgiMW2jpTkKuzeC0ejXmpy/01zkplRuvA5zUaBSbk2prUbrMGqvfptCiA51mHvstkPI32kcXihEdUo1SRGoAXM5SAFAEZeRcoSaj+617BgI4bg0eywbpHi9MqkW5A6NojNBdoJERNTBZtcEORxKRES2xTNBIiLSqKr8SDEzdUQLdoJERNTBZs8T5HAoERHZFs8ErWb2F5BRgJ8YgSoUCzk3xYvUoUSmSk9rF9pkFEUoPr3e5JPizUasGr1m9DR6PSHlDg3LWJG0YwlPlhciUKXPSMopCgBO1a9bbjrfp9G2laI9xWNAqkcoDyWvZxSd/Rix232C7ASJiKgDo0OJiIjsgWeCRESksdtTJNgJEhFRBxUWDIda0pKw4HAoERHZluWd4K5duzB58mSkpqZCURRs2bIl6PWZM2dCUZSg6d57771svSUlJRg4cCDi4+ORlZWFffv2Wd30S6kG08V7ab45iXUppidVmvwO3UmqJ+A3P6kBh/4ktCngd+hOYj0BBwLCJM+vCJN+Pf52pzhJ7RXr8jt1J2l+o8l0XUJbDd+fuH79bWjVZ6QatFfcn8V6zO+34rEkHDNSm0I5XkXSd4XR90sEiZ+TySlaWN4JNjc3Iz09HSUlJeI89957L06fPq1N//mf/2lY56ZNm+B2u1FUVIQDBw4gPT0dubm5+Pzzz61uPhGRvUmdttkpSlh+TXDSpEmYNGmS4TwulwspKSmdrnP16tWYO3cuZs2aBQAoLS3F1q1bUVZWhqVLl3apvUREZF8RuSa4Y8cO9O3bF0OGDMH8+fPxxRdfiPP6fD5UV1cjJydHK3M4HMjJyUFVVZW4XGtrKxobG4MmIiIydjF3aFenaBH2TvDee+/FK6+8gsrKSvz2t7/Fzp07MWnSJPj9+lkmzp49C7/fj+Tk5KDy5ORkeDwecT3FxcVITEzUprS0NEvfBxHR1chu1wTDfovE1KlTtf8fOXIkRo0ahRtvvBE7duzA3Xffbdl6CgsL4Xa7tb8bGxvZERIRUZCI3yIxePBg9OnTB8eOHdN9vU+fPnA6nfB6vUHlXq/X8Lqiy+VCQkJC0ERERJfBwJjw+uSTT/DFF1+gX79+uq/HxcUhMzMTlZWVyMvLAwAEAgFUVlaioKDAkjYoAQsHsMWEvELCX6OdRVpGSh5sMlG20ZCFNKYvJp4WylWDRMSqySTWYmJtk20Kpa5wXONQFHNfHIoiN0rap6VlpG0uza8G9C9fACHsI1JC+FASaJvd1y08Lq28tcHS7yST7JZA2/IzwaamJtTU1KCmpgYAcOLECdTU1KCurg5NTU34xS9+gT179qC2thaVlZV44IEHcNNNNyE3N1er4+6778aLL76o/e12u7Fu3TqsX78eH374IebPn4/m5mYtWpSIiCgUlp8J7t+/HxMnTtT+vnhdbsaMGVi7di3++c9/Yv369aivr0dqairuuecerFq1Ci6XS1vm+PHjOHv2rPb3lClTcObMGSxfvhwejwcZGRmoqKi4JFiGiIi6yoKnSCB6zgQt7wQnTJgA1WDs6O23375sHbW1tZeUFRQUWDb8SURE+jgcSkREZBMRD4whIqIriBXRnYwOtQGzp/vSCHEIEXBWRcZJ0ZmAHC0ovQ+z0ZbGy5iMAjUZTWpYVyjRghZRHNJOot9Wh0MOvTUb7SmVO0L4XB1SSLAUcSx9fiHsO2YjpE2XA+ajQKNoaBCwJuMLM8YQERFFAZ4JEhGRxm6BMewEiYiow+Wej9jZOqIEh0OJiMi2eCZIREQdAkrXA8EYHXp1UAzyXoos/PDlqE6TeRDN1gOIwxlmI/mMogilusxGgYaSO1SMQA0h0tQqihhSJ+2I5gdyFL/ZnKLC/A6DyE2Tn58zxvz+aVUuUEs/V5PHfkjfL2Fgt2uCHA4lIiLb4pkgERF1sFlgDDtBIiLScDiUiIjIJngmSEREGjVg/DDsztYRLdgJhklI0WkWPVleFfOAms+tKUeaChUZrsPc+zAbBSrVD5iPAjWdq9KA0RPh9Yh5XMWoUUAa5JFzhOrXJW1Dp8H7FlOHms1DarQOiz4/s8cYIB/L0TMAeBk2uybI4VAiIrItngkSEZHGboEx7ASJiEhjt06Qw6FERGRbPBMkIqIONguMYScYKrPRkCaj0wyZfEq9lJLSMALOZJSd2XyfhnWZjvzTrz+kdZuMig2FVJcYuSk8Qd7o/SmK39S6pbpCWXfA2f2Rt2KaVfH4E6vSF0JeXfkYN7nuCFPVy+QV7mQd0YLDoUREZFs8EyQiIo3dAmPYCRIRUQcV5oeP9eqIEhwOJSIi2+KZIBERaTgcSkREtsVOkK4YZncks7dCGIYxS3WZTawdhtswQlm3VXWFJYG20Fbp9gXA/DaU6rL0cxUTZUv1iKsIbZ82UQ/Zh+XXBHft2oXJkycjNTUViqJgy5Yt2mttbW1YsmQJRo4ciWuuuQapqanIz8/HZ599ZljnihUroChK0DR06FCrm05EZHtqQLFkCkVJSQkGDhyI+Ph4ZGVlYd++fYbz19fXY8GCBejXrx9cLhduueUWbNu2zdQ6Le8Em5ubkZ6ejpKSkkte++qrr3DgwAE8+eSTOHDgAN544w0cOXIE3//+9y9b73e+8x2cPn1am95//32rm05ERBczxnR1MmnTpk1wu90oKirCgQMHkJ6ejtzcXHz++ee68/t8Pnzve99DbW0t/uu//gtHjhzBunXr0L9/f1PrtXw4dNKkSZg0aZLua4mJidi+fXtQ2YsvvoixY8eirq4ON9xwg1hvTEwMUlJSOt2O1tZWtLa2an83NjZ2elkiIuq6b37vulwuuFwu3XlXr16NuXPnYtasWQCA0tJSbN26FWVlZVi6dOkl85eVleHLL7/E7t27ERsbCwAYOHCg6TZG/BaJhoYGKIqC3r17G8539OhRpKamYvDgwZg+fTrq6uoM5y8uLkZiYqI2paWlWdhqIqKr08XAmK5OAJCWlhb0PVxcXKy7Tp/Ph+rqauTk5GhlDocDOTk5qKqq0l3mzTffRHZ2NhYsWIDk5GSMGDECTz/9NPx+/ZSBkogGxrS0tGDJkiWYNm0aEhISxPmysrJQXl6OIUOG4PTp03jqqadw55134tChQ+jVq5fuMoWFhXC73drfjY2N7AiJiC7DyujQU6dOBX23S2eBZ8+ehd/vR3JyclB5cnIyDh8+rLvMxx9/jHfffRfTp0/Htm3bcOzYMfz0pz9FW1sbioqKOt3WiHWCbW1t+Ld/+zeoqoq1a9cazvv14dVRo0YhKysLAwYMwGuvvYbZs2frLmN02t1poSS+lS4Im026CxhEaErRd2ajSQ0i/CyqK7QIP2sSZVsZmSpHPIbwZWHR+Ith9KuwTcwm1rbycxXnl/bnEL6ITR8bZhNxG71mdl+IssTaoUhISDA8wemKQCCAvn374o9//COcTicyMzPx6aef4tlnn73yO8GLHeDJkyfx7rvvmt5IvXv3xi233IJjx451UwuJiOxJVbv+FAizy/fp0wdOpxNerzeo3Ov1irEg/fr1Q2xsLJxOp1Y2bNgweDwe+Hw+xMXFdWrdYb8meLEDPHr0KN555x18+9vfNl1HU1MTjh8/jn79+nVDC4mI7MvKa4KdFRcXh8zMTFRWVmplgUAAlZWVyM7O1l3mjjvuwLFjxxAIdJxSf/TRR+jXr1+nO0CgGzrBpqYm1NTUoKamBgBw4sQJ1NTUoK6uDm1tbfjhD3+I/fv349VXX4Xf74fH49F67ovuvvtuvPjii9rfixcvxs6dO1FbW4vdu3fjwQcfhNPpxLRp06xuPhERRYDb7ca6deuwfv16fPjhh5g/fz6am5u1aNH8/HwUFhZq88+fPx9ffvklFi5ciI8++ghbt27F008/jQULFphar+XDofv378fEiRO1vy8Gp8yYMQMrVqzAm2++CQDIyMgIWu69997DhAkTAADHjx/H2bNntdc++eQTTJs2DV988QWSkpIwfvx47NmzB0lJSVY3n4jI3gJKaA/7/mYdJk2ZMgVnzpzB8uXL4fF4kJGRgYqKCi1Ypq6uDg5Hx3lbWloa3n77bTz22GMYNWoU+vfvj4ULF2LJkiWm1mt5JzhhwgSoBgPCRq9dVFtbG/T3xo0bu9osIiLqhEjmDi0oKEBBQYHuazt27LikLDs7G3v27AlpXRcxd2ioIplz0Oy6LYxsFHfuMOQONcvS6FALIxjFqECTFyeszHtpac5Uiz5Xw/3TomMgLJif9IrGTpCIiDR8igQREdmW3TrBiKdNIyIiihSeCRIR0dd0/UwQiJ4zQXaCRETUIcRHIV1SR5RgJxhhoT58Ur+yMEQLWpjv0Lr8pBa+byujQKV1iLkn9YsVp/kcVlZtKyv3T2nfsfT6kVWRxVYel3RFYydIREQaNdD1H7tW/ljubuwEiYhIw+hQIiIim+CZIBERaex2JshOkIiINOwE6Yph1wi1SOYOtXIdEkUxF+0ZlqjYKPrSspJdjzHqwE6QiIg0F54s39UzQYsaEwbsBImIqIPNbpZndCgREdkWzwSJiEjDwBgiIrItu3WCHA4lIiLbMnUmOGHCBGRkZGDNmjXd1BzqDuKvsnAk3L5KkltH8nYEBRYm0JZuCXBaU//lXjO3kjCsgy7B3KFERGRbHA4VzJw5Ezt37sTzzz8PRVGgKApqa2u7sWlERETdq9Nngs8//zw++ugjjBgxAitXrgQAJCUldVvDiIgo/Ox2JtjpTjAxMRFxcXHo2bMnUlJSurNNREQUIXbrBBkdSkREtsXAGAAhBN8Z1GUyKs9o3SYjFc1GBBr9WgsE9H8fSeXWrsOa9y3VDwABv7n3FwjDL1uHycTaRqQk3fK2MrcNDaNDTdZl5TpMR/eK5eKqDY5lC/eRCObetNuZoKlOMC4uDn6/v7vaQkREEWa3TtDUcOjAgQOxd+9e1NbW4uzZswgELr0ZZNeuXZg8eTJSU1OhKAq2bNkS9Lqqqli+fDn69euHHj16ICcnB0ePHr3suktKSjBw4EDEx8cjKysL+/btM9N0IiKiS5jqBBcvXgyn04nhw4cjKSkJdXV1l8zT3NyM9PR0lJSU6NbxzDPP4IUXXkBpaSn27t2La665Brm5uWhpaRHXu2nTJrjdbhQVFeHAgQNIT09Hbm4uPv/8czPNJyKiy7h4JtjVKVqYGg695ZZbUFVVZTjPpEmTMGnSJN3XVFXFmjVrsGzZMjzwwAMAgFdeeQXJycnYsmULpk6dqrvc6tWrMXfuXMyaNQsAUFpaiq1bt6KsrAxLly418xaIiMiIqgBdfdhwFHWCYY0OPXHiBDweD3JycrSyxMREZGVliZ2rz+dDdXV10DIOhwM5OTmGHXJraysaGxuDJiIioq8La3Sox+MBACQnJweVJycna69909mzZ+H3+3WXOXz4sLiu4uJiPPXUU11rsFGElpAbT8yZF0JOSjHSTYhsVKVoyBCiJ81HYoYQoSlFCwrvz+/XT3AplUv1GK3bL0YqilVZRlUi9+tZiiYNOIRyw8+1+6N75Shs4fOTjplQcsWK69afXRG/E+RVRBIDY64ShYWFaGho0KZTp05FuklERHSFCeuZ4MVMM16vF/369dPKvV4vMjIydJfp06cPnE4nvF5vULnX6zXMXONyueByubreaCIiG+GZYDcaNGgQUlJSUFlZqZU1NjZi7969yM7O1l0mLi4OmZmZQcsEAgFUVlaKyxARUWhU1ZopWlh+JtjU1IRjx45pf584cQI1NTW47rrrcMMNN2DRokX41a9+hZtvvhmDBg3Ck08+idTUVOTl5WnL3H333XjwwQdRUFAAAHC73ZgxYwbGjBmDsWPHYs2aNWhubtaiRYmIiEJheSe4f/9+TJw4Ufvb7XYDAGbMmIHy8nI8/vjjaG5uxrx581BfX4/x48ejoqIC8fHx2jLHjx/H2bNntb+nTJmCM2fOYPny5fB4PMjIyEBFRcUlwTJERNRFVtznF0XDoZZ3ghMmTIBqcC6sKApWrlypPY5Jj95zCgsKCrQzwyuC2fyBUoSY0f04ZqPp/OYiN1VhfgAItEuRm+YiNP3t8iPL/SajQM1Gk7YL9VyoS9qG5vNYWkXM9ym8DYdBm6QDW1qHtM0VITrUYfC5Ohz6O7vTr18eaNdPxajGGOWdNXcMmM4danRcmowA7/I9d2HGa4JEREQ2wadIEBGRxm5nguwEiYhIY7dOkMOhRERkWzwTJCIijRpQ5AcHm6gjWrATNKCE8kGajRo1evK6mCPUXN7EgBDJF0peT78UNSrNbxChKUWOSsu0t+nvrlIUqNRWwPwT5EPKMSmQIjSlcofQVikK06hdZtsrtkmIGgUAZ0A/2lP6PJwx5vYpQN6nxdyhJqOBDUP8LYoCDen7JQwu3Oze1eFQixoTBhwOJSIi2+KZIBERaewWGMNOkIiINHbrBDkcSkREtsUzQSIi0tjtTJCdIBERadgJ2pAiR5rLTN7yoEq3O0gh2pATXEvLSOHecuJpgwTTJhNlS7cvGN2mYNWtEO1tQlsNQuzlJN3680f0FgnhbTid+rcihNouPWKb2uWDpt2h//lJt3RI+5rh/mkymbp8/Jk7xgxfC+E2KElI30kUEnaCRESk4ZkgERHZlt06QUaHEhGRbfFMkIiINKpqQe7QKDoTZCdIREQauw2HshM0YhShJb0mJbeWogsNfnGJCbHFaDopgba5ZNiAQcSlkLi4XSyXdzEpCrRNqsts1GgICbT9Zj+/MESHOoVk1X4hshEAYmL0d1DxfcTq12O2rQCgCO11Os1F90qJtQF5nxaPAWEfMZtwGzD4zKVlxO8KcRUURuwEiYhIc+EpEl2vI1qwEyQiIk1AVcRHipmpI1owOpSIiGyLZ4JERKRhYAwREdmXBZ1gKKniIoWdoBGji7smcxHK8xtFwEkRmlIUqLkITb9B5KaY11NYRozcFMoBwGdymTZh3W1CdGG7wbaV8ooGhM8vHNc4HFKeTqcUNSqHFwaEJKhqbPfmFDV6zenUb5MzRj8HqtH+2d6uv0xMe7tQl7ljxui4NHvsK2KuUXkVFD7sBImISMPhUCIisi27dYKMDiUiItsKeyc4cOBAKIpyybRgwQLd+cvLyy+ZNz4+PsytJiKyBzWgWDKFoqSkBAMHDkR8fDyysrKwb9++Ti23ceNGKIqCvLw80+sM+3Do3//+d/j9HRe1Dx06hO9973t4+OGHxWUSEhJw5MgR7W9FiZ5TbSKiaBKp4dBNmzbB7XajtLQUWVlZWLNmDXJzc3HkyBH07dtXXK62thaLFy/GnXfeGVJbw94JJiUlBf39m9/8BjfeeCPuuusucRlFUZCSktJ9jQoIYVoGTyYXI7vEJ8sLUYdCdBog5y+U8h2KT0sXc4fK65ajPc1FjfrahKSUButo9ZmMDhWjYo1yh+pvWyl3qBQdGkp6KOk3nBQd6vRLT5w3iH4VcodK70MIJhUZ/Q4Vc4e26Ud0Op36n2uMEDUKGOzTFuUINTwupQhw8cnyQkVGZ0tmP5CrwOrVqzF37lzMmjULAFBaWoqtW7eirKwMS5cu1V3G7/dj+vTpeOqpp/C3v/0N9fX1ptcb0WuCPp8Pf/nLX/CTn/zE8OyuqakJAwYMQFpaGh544AF88MEHl627tbUVjY2NQRMRERm7eCbY1QnAJd/Bra2tuuv0+Xyorq5GTk6OVuZwOJCTk4OqqiqxrStXrkTfvn0xe/bskN9vRDvBLVu2oL6+HjNnzhTnGTJkCMrKyvDXv/4Vf/nLXxAIBDBu3Dh88sknhnUXFxcjMTFRm9LS0ixuPRHR1cfKTjAtLS3oe7i4uFh3nWfPnoXf70dycnJQeXJyMjwej+4y77//Pv785z9j3bp1XXq/Eb1F4s9//jMmTZqE1NRUcZ7s7GxkZ2drf48bNw7Dhg3DH/7wB6xatUpcrrCwEG63W/u7sbGRHSERURidOnUKCQkJ2t8ul8uSes+dO4dHHnkE69atQ58+fbpUV8Q6wZMnT+Kdd97BG2+8YWq52NhYjB49GseOHTOcz+VyWbbBiYjsIqB2PUPSxTCLhISEoE5Q0qdPHzidTni93qByr9erGw9y/Phx1NbWYvLkyR3r/P/rqDExMThy5AhuvPHGTrU1YsOhL7/8Mvr27Yv777/f1HJ+vx8HDx5Ev379uqllRET2ZeVwaGfFxcUhMzMTlZWVWlkgEEBlZWXQSOBFQ4cOxcGDB1FTU6NN3//+9zFx4kTU1NSYGvWLyJlgIBDAyy+/jBkzZiAmJrgJ+fn56N+/vzZ2vHLlStx+++246aabUF9fj2effRYnT57EnDlzItH0DlLkpskcoUb300gRanK5uYg56WnwF14TokNN5w6V1yFFgUo5RaUcoVJ0aJtRXlYpOlQISIxodKiwCR1CFCYgvz9VCDo0+z6Mcoc6Hfqfn9MpRIe26zeqXcgDCgAxJiOhzR5Lhk+WN5sHOMR75uzG7XZjxowZGDNmDMaOHYs1a9agublZixb9et8QHx+PESNGBC3fu3dvALik/HIi0gm+8847qKurw09+8pNLXqurqwsK/f7Xv/6FuXPnwuPx4Fvf+hYyMzOxe/duDB8+PJxNJiKyhUjdJzhlyhScOXMGy5cvh8fjQUZGBioqKrRgmW/2DVZRVDWU37LRp7GxEYmJiTj5+QAkJARvyHiP/q/BWK+8wZXP9bPW+L+4Rre8/V89dct9DfrlANDaqP9aa1MP3fKWJv02tXxlrhwAWs7rv9baEieU619/bW2V7xNsEV7jmWCwUM4EY4UnNsQK997FxuqXx8Xqn43Fu9rEdbuE11zx+uHxrnif/jp6tIjriO+p/5pYfq1+ueva8/rlCV+J645L1H8t5lv65c5vN+uWq33l99eWrP/5taQEf06NjQEM6HsSDQ0Nnbr2ZuTid+T6wc+ip0P/O6azvgqcx4yPf2FJu7obc4cSEZFt8SkSRESksdtTJNgJEhGRJqAqFtwiET2dIIdDiYjItngmCECM9jbKYSv90hGS64oJfI2CN4SAD7+YEFv/4/SLty/IH790a0ObTwhmEW93kANjpAAYsS7hfYuBMe3yr9F2kwm0xV0khMAYh9AsqbVSAu0Yo1skYqRE4Przmx2+MgrSk26fcDj0Dyjp1on2GHn/lPZp8RgweeuE0XEpHcvSsS9+Vxh8vxjcgdLtOBxKRES2ZbdOkMOhRERkWzwTJCIijd3OBNkJEhGRRrUgOjSaOkEOhxIRkW3xTNCAYYSWFD0mpdcS5pei0wCD1F4mI+CkRNlSMmyj19pMlktJsgE5CrRViExt9em/P5+wbaUIUAAQAi7hFz4/aX4rOYXmOoV1G72/diEMVEoLZ/aXu5T6DTCIAhUSZceIydqFxkLep81GTkvHmNFxKUaOihHj+vNHMgLUiKqGlg7wm3VEC3aCRESkUQMKVPGGnc7XES04HEpERLbFM0EiItIwOpSIiGyLuUOJiIhsgmeCgJzDz+DiriosI0WOqSZzigJyhFpAioDzC1GgYu5QOQJOyhHaJuX1FMvldZiNAm0V3rdP+JzajKInpehQYX5pFwklCE5qlZQK1CksIaQHBQDEijlCrfndK+UHBQCnQ39fcApRozFO/X0tJkb/gb6AvE9Lx4B0zIi5Qw2OS/FYlo59YedRjIJHjPIWdzNGhxIRkW3Z7Zogh0OJiMi2eCZIREQauwXGsBMkIiKN3a4JcjiUiIhsi2eCRoxO6U3mCJWizaSoNQAICHVJEXDi07alqFGD3KHy09qlJ8ibywMKmI8CbRGi8nzCZ+Ez+DUqxR22C/Ge4pPl5VWIpE9c2ttipOhQg/fnl3JiisuY+z2sKPLn6hSSncbESLlD9T+NWIP906pjQDrGjI5L6Vg2nVP0Ch0ytFtgDDtBIiLS2O2aIIdDiYjItngmSEREGlWVb/A3U0e0YCdIREQaVbXgUUocDiUiIrryhf1McMWKFXjqqaeCyoYMGYLDhw+Ly7z++ut48sknUVtbi5tvvhm//e1vcd9991nWJkXMHWqwkJT3z+yT5S3NHWryyfJCJJ3Ra21CtKdPKJeiTAHzUaCtwjZvEepvM8jsKUWH+oVl5Nyh5sd9FOFXtrQnSG01Onj90n5o+he6fquc8scKX5v+NomN0V8oNkbKbWuwf0pPljcZ7SkeY4a5Q01GgUrfFQbfL+J3UhgEVAWBLp4JMjDmMr7zne/g9OnT2vT++++L8+7evRvTpk3D7Nmz8Y9//AN5eXnIy8vDoUOHwthiIiKbUDtumA91CimzfIREpBOMiYlBSkqKNvXp00ec9/nnn8e9996LX/ziFxg2bBhWrVqFW2+9FS+++GIYW0xERFejiHSCR48eRWpqKgYPHozp06ejrq5OnLeqqgo5OTlBZbm5uaiqqjJcR2trKxobG4MmIiIyFlA77hUMfYr0u+i8sHeCWVlZKC8vR0VFBdauXYsTJ07gzjvvxLlz53Tn93g8SE5ODipLTk6Gx+MxXE9xcTESExO1KS0tzbL3QER0terqUKgVuUfDKeyd4KRJk/Dwww9j1KhRyM3NxbZt21BfX4/XXnvN0vUUFhaioaFBm06dOmVp/UREFP0ifp9g7969ccstt+DYsWO6r6ekpMDr9QaVeb1epKSkGNbrcrngcrksaycRkR3Y7T7BiHeCTU1NOH78OB555BHd17Ozs1FZWYlFixZpZdu3b0d2dnb3N87og5QSZQvh/aoQJi2GW0MO0xaTB0vl0i0SwvyAnCi71WSi7JY2+f35hG1i9laIFiEUTUqGDci3T/gVoTwM4W5O4YvHKeyHsQZfVH7xFXNh/A5hdofB5+pw6O8LMcKtELGx+q012j+lfdrssSEdY0bHpXwsCxvL7C0VERZQQ0sM/806okXYh0MXL16MnTt3ora2Frt378aDDz4Ip9OJadOmAQDy8/NRWFiozb9w4UJUVFTgueeew+HDh7FixQrs378fBQUF4W46ERFdZcJ+JvjJJ59g2rRp+OKLL5CUlITx48djz549SEpKAgDU1dXB4ejom8eNG4cNGzZg2bJleOKJJ3DzzTdjy5YtGDFiRLibTkR01VPVrt/mF02BMWHvBDdu3Gj4+o4dOy4pe/jhh/Hwww93U4uIiOgiZowhIiKyiYgHxhAR0ZWDw6HUOcKHLIUGywm05WEDKemvKkWNmoyYaxciOgE5UXa70Ka2Nv330SZFzAFosSgK1CeVG2QhFqNDTUaNWkmKAnUq+uV+gyGrONXsII9+XYrwGTnl8FNxX2iPlfYd8/un2Uho6ZgRE2sbHJfSsSzeFhBFHQJgv06Qw6FERGRbPBMkIiKN3QJj2AkSEZHGiichRdFoKIdDiYjIvngmSEREGrulTWMnCMifuMGeIOUPlPIEijkKhXKjZUznTQwhd2i78D7apOhQoa1SflAA8AkHihjtKZS3CFGgPoMPsE3MEaq/jLiLhBA16hCul0h5Op3CgI3f4LqLuNmFqFFpLxTzlhp8rrHCviDtO9K+ZmnuUGF+K49LMZ+wsK0Mgpe73gt1gQoLEmh3cflw4nAoERHZFs8EiYhIo1owHBpN9wmyEyQiIg2jQ4mIiGyCZ4JERKRhdCh1MIiAk54KbTZ3aChPlpfyGvqF6DuzT9sG5ByhYtSokCPUZ3A0mY0ClXKBSlGgrUa5Q4XXxNyhVg7wiFGgQp5OYd0Bo4Ecobnik+KFqFFpDzGKDo0T9gUxCtTkfmv0mnQMSMdMSE+WN507VCg3+n6JIA6HEhER2QTPBImISMPhUCIisi0OhxIREdkEzwSJiEgTgAXDoVY0JEzYCYbKbMSXML/hE6zF6FBz5Waj8oyWaWsXniAv7PVt4hrk16SnvrcK5VIeUCkC9MI69F+TysWPO4SBH4cQBdouVBUrDdgYBBdK61CEVB5SBGqMUI/h5yrtC8K+Y+X+afbYkI8x85HhZo/9KxWHQ4mIiGyCZ4JERKThcCgREdmWiq4nwOZwKBERURTgmSAREWk4HEpERLZlt+hQdoIAhAj7kEKbxQTawjrEpLsAAkIiYrNh4AEhnNwvzH/hNSGcXSoX3odfXAPQbvKWh3bhlgeztzsAcjLudmEZSxNoC5yKsO9ICxg0SbpFQkqg3SbMHyvsuH6D+zOkfUHad6R9zWj/lPZp88eGuUT4F16Tyk1+XxjML34nXeVKSkrw7LPPwuPxID09Hf/+7/+OsWPH6s67bt06vPLKKzh06BAAIDMzE08//bQ4vyTs1wSLi4tx2223oVevXujbty/y8vJw5MgRw2XKy8uhKErQFB8fH6YWExHZh4qOIdFQp1D68E2bNsHtdqOoqAgHDhxAeno6cnNz8fnnn+vOv2PHDkybNg3vvfceqqqqkJaWhnvuuQeffvqpqfWGvRPcuXMnFixYgD179mD79u1oa2vDPffcg+bmZsPlEhIScPr0aW06efJkmFpMRGQfXe0AQ72muHr1asydOxezZs3C8OHDUVpaip49e6KsrEx3/ldffRU//elPkZGRgaFDh+JPf/oTAoEAKisrTa037MOhFRUVQX+Xl5ejb9++qK6uxne/+11xOUVRkJKS0un1tLa2orW1Vfu7sbHRfGOJiChk3/zedblccLlcl8zn8/lQXV2NwsJCrczhcCAnJwdVVVWdWtdXX32FtrY2XHfddabaGPFbJBoaGgDgsg1vamrCgAEDkJaWhgceeAAffPCB4fzFxcVITEzUprS0NMvaTER0tVItmgAgLS0t6Hu4uLhYd51nz56F3+9HcnJyUHlycjI8Hk+n2r1kyRKkpqYiJyfHxLuNcGBMIBDAokWLcMcdd2DEiBHifEOGDEFZWRlGjRqFhoYG/O53v8O4cePwwQcf4Prrr9ddprCwEG63W/u7sbGRHSER0WVYeYvEqVOnkJCQoJXrnQVa4Te/+Q02btyIHTt2mI4XiWgnuGDBAhw6dAjvv/++4XzZ2dnIzs7W/h43bhyGDRuGP/zhD1i1apXuMtJpt1VUMVmuufnFemAQaSrVJUammk/eLUXmiZF8wvuWkmEDcsSlVC5Hk5qPDjUbBSrNb2UCbXlgRn/dRrGIDiHS1ClG8Zr7LAw/V2kdJqNADZPLm9zXTR8zhontrTn27SAhISGoE5T06dMHTqcTXq83qNzr9V72Mtjvfvc7/OY3v8E777yDUaNGmW5jxIZDCwoK8NZbb+G9994Tz+YksbGxGD16NI4dO9ZNrSMisifVon/MiIuLQ2ZmZlBQy8Ugl6+fAH3TM888g1WrVqGiogJjxowJ6f2GvRNUVRUFBQXYvHkz3n33XQwaNMh0HX6/HwcPHkS/fv26oYVERPYVqehQt9uNdevWYf369fjwww8xf/58NDc3Y9asWQCA/Pz8oMCZ3/72t3jyySdRVlaGgQMHwuPxwOPxoKmpydR6wz4cumDBAmzYsAF//etf0atXL+2iZ2JiInr06AHgwpvt37+/dhF15cqVuP3223HTTTehvr4ezz77LE6ePIk5c+aEu/lERNQNpkyZgjNnzmD58uXweDzIyMhARUWFFixTV1cHh6PjvG3t2rXw+Xz44Q9/GFRPUVERVqxY0en1hr0TXLt2LQBgwoQJQeUvv/wyZs6cCeDSN/uvf/0Lc+fOhcfjwbe+9S1kZmZi9+7dGD58eLiaTURkC5FMm1ZQUICCggLd13bs2BH0d21tbYhrCRb2TlDtxDM6vvlmf//73+P3v/99N7WIiIguYgJt6mDlJynlCTTMUWhNFKgUfRdSfkRhfmlTGf3kaRfKzUYqiuUGCRjNRoG2C3WFFh0qUKWtqL+Ew+j9iTk/zUbk6jN612b3hVBycZrdp82WG+YNDiGnsK5o6imuYuwEiYhIo0KF2sUM3p0Z8btSsBMkIiKN3YZDI542jYiIKFJ4JkhERBq7nQmyEyQioq8xn/FFr45owU4wVCHkHNSd39Kn15urKxBCZGrAoqhRI9ImDJiM0DSK3DS7TCjrMEvaVpa+P3EbCoRVGH2u4r5gMgrUaP8U123RsRHScSnmFLVv7tBowE6QiIg0HA4lIiLbCiUBtl4d0YLRoUREZFs8EyQiIg2HQ4mIyLZUpeuxPKr2rysfO0GrWZW7EIAqPnFbKPebnV9ed1u7uSfI+4V6pKfBA0C7ySfCS+VSjlCjJ8v7hHWLuUPD8GT5gCJ8TlJFBquOEepqE/KTStuqXXhCfbtRXk/h/Un7jrSvGe2fVh0D0jFmnFeXUaBXE3aCRESkuTAc2rXTOA6HEhFRVLLbNUFGhxIRkW3xTJCIiDR2u0+QnSAREWk4HEpERGQTPBME5J8tRiHPwjLirRBiiLZBGLgY7i2ElAvraG8XyoX6AaBdWIdPeH+twvBHm8GwSKtwa4N0+4IUxn8e7brlLYp04wbgE27q8AnLSNFy/hCGfZzCLQTSrRNxqlO/TYp+OQDx9glFOOSlbd4q/E6OM3hyuLQvuIR9R9rXDPdPYZ8Wb52QjhlhHUbHpenbKiz8fgmHAFQLokM5HEpERFHIbjfLcziUiIhsi2eCRESk4XAoERHZmL2eLM/hUCIisi2eCRoxiBATrxyL0aHmotMA8xFtUjRdu18/irAthOjQduEHnn58phx1CMgRmq1SuVCXFAXaIrZKjgJtEdctR5paxSVGgepv9IBBhKYQaAqn8LvXoeq/v1ihIp+QWBsA2lVhP5T2HWFfM9w/hX1aOgZMR1obrFuMHDWbWNvo+yWC7HafIDtBIiLS2O2aYMSGQ0tKSjBw4EDEx8cjKysL+/btM5z/9ddfx9ChQxEfH4+RI0di27ZtYWopERFdrSLSCW7atAlutxtFRUU4cOAA0tPTkZubi88//1x3/t27d2PatGmYPXs2/vGPfyAvLw95eXk4dOhQmFtORHR1Uy2aokVEOsHVq1dj7ty5mDVrFoYPH47S0lL07NkTZWVluvM///zzuPfee/GLX/wCw4YNw6pVq3DrrbfixRdfDHPLiYiubgFFtWSKFmG/Jujz+VBdXY3CwkKtzOFwICcnB1VVVbrLVFVVwe12B5Xl5uZiy5Yt4npaW1vR2tqq/d3Q0AAAOHfu0ku2/ib9C9TOr+RgCP95IVVXS5tu+Vet+uVNbT5xHU1trbrlze2x+uvw65efD+gHEbQYBFa0CEEarVKghJQ2TQi4AIA2IXClXSj3C0Erfuhv24D4vHv5NalcNajLKgEpMAZCuTA/APiFZfyqsM2FutpU/a8IqRwAfMK6pX1H2g9dBvvO+UCLbnmcX788pl2/3CEcY4rBcQnhWPYLx75P+K4w/H5p0g8taW4MLr/4faYaBUmRobB3gmfPnoXf70dycnJQeXJyMg4fPqy7jMfj0Z3f4/GI6ykuLsZTTz11SfmIG0+F0GrqFtJxa9PjuTnSDYgUqS84Z7CM9NrpLrYlSn3xxRdITEy0pC67BcZctdGhhYWFQWeP9fX1GDBgAOrq6izbWa52jY2NSEtLw6lTp5CQkBDp5kQNbjfzuM1C09DQgBtuuAHXXXedZXVacU0verrACHSCffr0gdPphNfrDSr3er1ISUnRXSYlJcXU/ADgcrngcrkuKU9MTORBZlJCQgK3WQi43czjNguNw8G8J6EK+5aLi4tDZmYmKisrtbJAIIDKykpkZ2frLpOdnR00PwBs375dnJ+IiEJzcTi0q1O0iMhwqNvtxowZMzBmzBiMHTsWa9asQXNzM2bNmgUAyM/PR//+/VFcXAwAWLhwIe666y4899xzuP/++7Fx40bs378ff/zjHyPRfCKiqxavCYbBlClTcObMGSxfvhwejwcZGRmoqKjQgl/q6uqCTu/HjRuHDRs2YNmyZXjiiSdw8803Y8uWLRgxYkSn1+lyuVBUVKQ7REr6uM1Cw+1mHrdZaLjduk5RGVtLRGR7jY2NSExMxBjn84hRenSprnb1PPb7F6KhoeGKv8Z71UaHEhGReaoFj1Lq+qOYwochRUREZFs8EyQiIo1qQWBMNJ0JshMkIiJNQFGhdDH3ZzRFh3I4lIiIbOuq7QR//etfY9y4cejZsyd69+7dqWVUVcXy5cvRr18/9OjRAzk5OTh69Gj3NvQK8+WXX2L69OlISEhA7969MXv2bDQ1NRkuM2HCBCiKEjQ9+uijYWpxZPB5mOaZ2Wbl5eWX7FPx8fFhbG3k7dq1C5MnT0ZqaioURTF8YMBFO3bswK233gqXy4WbbroJ5eXlptcbsGiKFldtJ+jz+fDwww9j/vz5nV7mmWeewQsvvIDS0lLs3bsX11xzDXJzc9HSop+B/mo0ffp0fPDBB9i+fTveeust7Nq1C/PmzbvscnPnzsXp06e16ZlnnglDayODz8M0z+w2Ay6kUPv6PnXy5MkwtjjympubkZ6ejpKSkk7Nf+LECdx///2YOHEiampqsGjRIsyZMwdvv/22qfXaLWPMVX+fYHl5ORYtWoT6+nrD+VRVRWpqKn7+859j8eLFAC4kp01OTkZ5eTmmTp0ahtZG1ocffojhw4fj73//O8aMGQMAqKiowH333YdPPvkEqampustNmDABGRkZWLNmTRhbGzlZWVm47bbbtOdZBgIBpKWl4Wc/+xmWLl16yfxTpkxBc3Mz3nrrLa3s9ttvR0ZGBkpLS8PW7kgyu806e9zahaIo2Lx5M/Ly8sR5lixZgq1btwb9uJo6dSrq6+tRUVFx2XVcvE9wRMxzcHbxPkG/eh6H2n8eFfcJXrVngmadOHECHo8HOTk5WlliYiKysrLE5xxebaqqqtC7d2+tAwSAnJwcOBwO7N2713DZV199FX369MGIESNQWFiIr776qrubGxEXn4f59f2kM8/D/Pr8wIXnYdplvwplmwFAU1MTBgwYgLS0NDzwwAP44IMPwtHcqGXVfqZa9E+0YHTo/7v4bEKzzy28mng8HvTt2zeoLCYmBtddd53hNvjRj36EAQMGIDU1Ff/85z+xZMkSHDlyBG+88UZ3NznswvU8zKtJKNtsyJAhKCsrw6hRo9DQ0IDf/e53GDduHD744ANcf/314Wh21JH2s8bGRpw/fx49enTu7C4AFYqNcodG1Zng0qVLL7lY/s1JOqjsrLu327x585Cbm4uRI0di+vTpeOWVV7B582YcP37cwndBdpKdnY38/HxkZGTgrrvuwhtvvIGkpCT84Q9/iHTT6CoTVWeCP//5zzFz5kzDeQYPHhxS3RefTej1etGvXz+t3Ov1IiMjI6Q6rxSd3W4pKSmXBCq0t7fjyy+/NHx24zdlZWUBAI4dO4Ybb7zRdHuvZOF6HubVJJRt9k2xsbEYPXo0jh071h1NvCpI+1lCQkKnzwIB+50JRlUnmJSUhKSkpG6pe9CgQUhJSUFlZaXW6TU2NmLv3r2mIkyvRJ3dbtnZ2aivr0d1dTUyMzMBAO+++y4CgYDWsXVGTU0NAAT9mLhafP15mBeDFC4+D7OgoEB3mYvPw1y0aJFWZqfnYYayzb7J7/fj4MGDuO+++7qxpdEtOzv7kltvQtnP7NYJRtVwqBl1dXWoqalBXV0d/H4/ampqUFNTE3TP29ChQ7F582YAF6KvFi1ahF/96ld48803cfDgQeTn5yM1NdUwIutqMmzYMNx7772YO3cu9u3bh//93/9FQUEBpk6dqkWGfvrppxg6dKh2j9fx48exatUqVFdXo7a2Fm+++Sby8/Px3e9+F6NGjYrk2+k2brcb69atw/r16/Hhhx9i/vz5lzwPs7CwUJt/4cKFqKiowHPPPYfDhw9jxYoV2L9/f6c7gKuB2W22cuVK/M///A8+/vhjHDhwAD/+8Y9x8uRJzJkzJ1JvIeyampq07y3gQvDexe80ACgsLER+fr42/6OPPoqPP/4Yjz/+OA4fPoyXXnoJr732Gh577LFIND9qRNWZoBnLly/H+vXrtb9Hjx4NAHjvvfcwYcIEAMCRI0fQ0NCgzfP444+jubkZ8+bNQ319PcaPH4+Kigpb3aT76quvoqCgAHfffTccDgceeughvPDCC9rrbW1tOHLkiBb9GRcXh3feeUd7MHJaWhoeeughLFu2LFJvodtF4nmY0c7sNvvXv/6FuXPnwuPx4Fvf+hYyMzOxe/duDB8+PFJvIez279+PiRMnan+73W4AwIwZM1BeXo7Tp09rHSJwYTRr69ateOyxx/D888/j+uuvx5/+9Cfk5uaaWm8AsOBMMHpc9fcJEhHR5V28T3Bw7G/hULr2wz+gtuDjtiW8T5CIiOhKdtUOhxIRkXkXglrsExjDTpCIiDR26wQ5HEpERLbFM0EiItL4Lcj9GU1nguwEiYhIw+FQIiIim+CZIBERaex2JshOkIiINH4lAFXpWs6XQBTljOFwKBER2RY7QSILnDlzBikpKXj66ae1st27dyMuLg6VlZURbBmROX6olkzRgp0gkQWSkpJQVlamPSHi3LlzeOSRR7Rk5ETRImBBBxjqNcGSkhIMHDgQ8fHxyMrK0p5WI3n99dcxdOhQxMfHY+TIkZc8Sqoz2AkSWeS+++7D3LlzMX36dDz66KO45pprUFxcHOlmEUWFTZs2we12o6ioCAcOHEB6ejpyc3MvedD3Rbt378a0adMwe/Zs/OMf/0BeXh7y8vJw6NAhU+vlUySILHT+/HmMGDECp06dQnV1NUaOHBnpJhF1ysWnSFzrKoLSxadIqGoLmlqfMvUUiaysLNx222148cUXAVx48HJaWhp+9rOfYenSpZfMP2XKFDQ3N+Ott97Sym6//XZkZGSgtLS0023lmSCRhY4fP47PPvsMgUAAtbW1kW4OkWkqWqGqLV2b0ArgQsf69am1tVV3nT6fD9XV1cjJydHKHA4HcnJyUFVVpbtMVVVV0PwAkJubK84v4S0SRBbx+Xz48Y9/jClTpmDIkCGYM2cODh48iL59+0a6aUSXFRcXh5SUFHg8v7GkvmuvvRZpaWlBZUVFRVixYsUl8549exZ+v197yPJFycnJOHz4sG79Ho9Hd36Px2OqnewEiSzyy1/+Eg0NDXjhhRdw7bXXYtu2bfjJT34SNFxDdKWKj4/HiRMn4PP5LKlPVVUoihJU5nK5LKnbSuwEiSywY8cOrFmzBu+99552DeQ//uM/kJ6ejrVr12L+/PkRbiHR5cXHxyM+vmvXA0PRp08fOJ1OeL3eoHKv14uUlBTdZVJSUkzNL+E1QSILTJgwAW1tbRg/frxWNnDgQDQ0NLADJLqMuLg4ZGZmBt1TGwgEUFlZiezsbN1lsrOzL7kHd/v27eL8Ep4JEhFRxLndbsyYMQNjxozB2LFjsWbNGjQ3N2PWrFkAgPz8fPTv31+77WjhwoW466678Nxzz+H+++/Hxo0bsX//fvzxj380tV52gkREFHFTpkzBmTNnsHz5cng8HmRkZKCiokILfqmrq4PD0TF4OW7cOGzYsAHLli3DE088gZtvvhlbtmzBiBEjTK2X9wkSEZFt8ZogERHZFjtBIiKyLXaCRERkW+wEiYjIttgJEhGRbbETJCIi22InSEREtsVOkIiIbIudIBER2RY7QSIisi12gkREZFv/B6dynoyri+8jAAAAAElFTkSuQmCC\n", + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAbgAAAGiCAYAAACVh9NOAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy88F64QAAAACXBIWXMAAA9hAAAPYQGoP6dpAABENElEQVR4nO3df3RU5Z0/8PedSTIh0oS6QCIYjb+VggTDkg3agqdZ08phyx/dpcACZgXXH9kDZKuAAhE9JdYfbGyLZsVNaffIgu3xx57C4qGpWdclhSWQU7WCi4ChrBOhfkkgQCaZ+3z/SBkZcz8PeSY3mcy97xdn/uCZ+ysz97nP3Od+ns9jKaUUiIiIPCaQ7AMgIiIaCGzgiIjIk9jAERGRJ7GBIyIiT2IDR0REnsQGjoiIPIkNHBEReRIbOCIi8iQ2cERE5Els4IiIyJPYwBER0YB65513MHPmTIwZMwaWZeGNN9645DoNDQ247bbbEAqFcP3112PTpk3G+2UDR0REA6qjowMTJ07Ehg0b+rT8kSNHMGPGDNx5551obm7G0qVLsWjRIrz11ltG+7WYbJmIiAaLZVl4/fXXMWvWLHGZ5cuXY9u2bXj//fdjZd/73vdw6tQp7Nixo8/7SuvPgRIRUWo4f/48IpGIa9tTSsGyrLiyUCiEUCjU7203NjaitLQ0rqysrAxLly412g4bOCIijzt//jyuuSYP4XCba9scPnw4zpw5E1dWVVWFxx9/vN/bDofDyM3NjSvLzc1Fe3s7zp07h2HDhvVpO2zgiIg8LhKJIBxuw+FP/gnZ2X1rHHTa28/h2quX4dixY8jOzo6Vu3H35iY2cEREPpGdPcyVBu6L7WXHNXBuycvLQ2tra1xZa2srsrOz+3z3BrCBIyLyDaW6oVS3K9sZSCUlJdi+fXtc2c6dO1FSUmK0HQ4TICLyCaWirr1MnDlzBs3NzWhubgbQMwygubkZLS0tAICVK1diwYIFseXvv/9+HD58GI888ggOHDiAF154Aa+++iqWLVtmtF82cERENKD27t2LSZMmYdKkSQCAyspKTJo0CWvWrAEAfPrpp7HGDgCuueYabNu2DTt37sTEiRPx3HPP4eWXX0ZZWZnRfjkOjojI49rb25GTk4PwyedcCzLJG/mPaGtrG5BncG7hMzgiIp9IlWdwbmEXJREReRLv4IiIfKInQMSNOzizIJNkYQNHROQTyu6Gsl1o4FzYxmBgFyUREXkS7+CIiPxCdfe83NhOCmADR0TkE4yiJCIi8gDewRER+YXdDdhd7mwnBbCBIyLyiZ4uyqAr20kF7KIkIiJP4h0cEZFf2N2A3f87OHZREhHR0OKzBo5dlERE5Em8gyMi8o2oS4O0mYuSiIiGEMvuhmX3v+POYhclERFR8vAOjojIL+xuwIU7uFQJMmEDR0TkFz5r4NhFSUREnsQ7OCIin7BUNyzlQpBJiqTqYgNHROQXtg3YLoT423b/tzEI2EVJRESexDs4IiKf6BkHZ7mynVTABo6IyC/sqEtRlKmRyYRdlERE5EmuN3DvvPMOZs6ciTFjxsCyLLzxxhuXXKehoQG33XYbQqEQrr/+emzatMntwyIiIrvbvVcKcL2B6+jowMSJE7Fhw4Y+LX/kyBHMmDEDd955J5qbm7F06VIsWrQIb731ltuHRkTka5Ydde2VClx/Bvftb38b3/72t/u8fG1tLa655ho899xzAIBbbrkF7777Lv7pn/4JZWVlbh8eERH5RNKDTBobG1FaWhpXVlZWhqVLl4rrdHZ2orOzM/Z/27bx+eef48/+7M9gWf2PECIiSjalFE6fPo0xY8YgEHCps025FGSifHoHZyocDiM3NzeuLDc3F+3t7Th37hyGDRvWa53q6mqsXbt2sA6RiChpjh07hiuvvNKVbVm27Ur3opUiA72T3sAlYuXKlaisrIz9v62tDVdddRWOtjyP7Oz4BvF85wnHbUTP/kHcfuCs8zqBs390LA+eOeW8/Lkz4j6ssx3Ob5w/51x+9rxzecT5Ya8SFgcAFXH+BacizqeD6nIut7uD4j5sYR3VLew76rwtJYzZUba8b6WkdQzv7hMZLxRQRotbwvKWJW/HCjhfoMRtBYXl05wvUoF0OYAgkCZsS1jHypDK5QuklSm8kSFcrrKEFTJ7/zgGAJV1mbhve9hwx/Lo8BHOy2f9mVA+StxHMMu5scoMxa/T3n4OBVctwVe+8hVxW6SX9AYuLy8Pra2tcWWtra3Izs52vHsDgFAohFAo1Ks8O3sYsrOz4soyOp1P/mha7/UvCATTncsDzh9XEM4X20BAvghbltBNEJAuqkJ5ULiYa67NStiHWC7sw+6SuzpsoUtFBaUGTiiXGjhhecDNBi6BrpyA2S/bhBq4oLCO2MAJ5ULtD6TLf3cg3fnvs9KdP1srQygPyd+F2MBJ62QJ5cOEcypL88NMeC86XPiRd5nztcK+TL6+BIUGOTOU5Vju6mMXO5rYDzen7aSApDdwJSUl2L59e1zZzp07UVJSkqQjIiLypp4ISDcymaRGA+f6MIEzZ86gubkZzc3NAHqGATQ3N6OlpQVAT/figgULYsvff//9OHz4MB555BEcOHAAL7zwAl599VUsW7bM7UMjIiIfcf0Obu/evbjzzjtj/7/wrGzhwoXYtGkTPv3001hjBwDXXHMNtm3bhmXLluH555/HlVdeiZdffplDBIiI3MYuyv6ZPn06lJKfHzhlKZk+fTr279/v9qEkn+4kkKKQTKOTpMUTCXIyfXalqyjSe8JcVPKzNiH4RPOQUT5eYd+6B5amhOOVnqkp6YvS9a0I+wCk4BNhY7ZQTzXfq/TZWqafoe78ND13TetSilycBwK7KImIiDwg6UEmREQ0SNhFSUREXmTZypVB2pbUvT3EsIuSiIg8iXdwRER+YUcTC0Bz2k4K8HwDp1QC8xYJX54YOTQYX7ZxdKV5hKMYfZhAZhA5xZbhPhLJSmIYLWmc4SQRQl+JJR2r7iokbcv07xMyZOg+D9N9iBF72s9ciu4chPyHg1D3E7omuUW51MClSLJldlESEZEnef4OjoiIeljKNh+zKGwnFbCBIyLyC589g2MXJREReRLv4IiI/MK2XRrozS5KIiIaStjAkWt0J4FxglgpdNrskABASRN5SiHrYti9pofbNKmyuG+zxMkAYGsmQ3Xch5vJliXC96SEiWx1f4E0hEBJk+gKiZ6lJMy671X6nsThDtLyuhNXTCBuOHzArYTmlLLYwBER+YRl27BcaN/dSPc1GNjAERH5hW27FEWZGg0coyiJiMiTeAdHROQXPruDYwNHROQXbOB8wnYzCXMCX7ZL0ZJSxhxtZKBpIuQEEua6lVRZipbURUrKSZUHPrrSEiIWpdmzrIDzF6j7+6R3VECIlnQrQTJgfO5I29Kfn8Lf4VZ0pYZYl00zdyRyfSHX+beBIyLyGxWVfxAYbYd3cERENIT4bZgAoyiJiMiTeAdHROQXDDIhIiJPYgNHrkVADcZJIEaWadaRItgMo+3ESEntOkLOwu6gc7lp1KVuH4Z/XyKk3JKWEOEo5tQUoisBzd8ufYZCZKeUozKR71WcRFOMjhV34c4FWMfNesloySGNDRwRkV/Yyp0G3o1IzEHABo6IyC9s5VIXZWo0cIyiJCIiT+IdHBGRX7g24Snv4IiIaCixbfdehjZs2ICCggJkZmaiuLgYe/bs0S5fU1ODm266CcOGDUN+fj6WLVuG8+fPG+2Td3BuMM1TB8ipbqQTRxnO6O1mhKM0w7N2Rm/nfUg5FqXIQHF5TV5J02jJZM7oLf7E1M1YLm0qaJbXMihFaupm9BZnajebkV2f71LcubC8YaJWnUTqMl3S1q1bUVlZidraWhQXF6OmpgZlZWU4ePAgRo8e3Wv5zZs3Y8WKFairq8PUqVPx0Ucf4Z577oFlWVi/fn2f98s7OCIiv7CVey8D69evx+LFi1FeXo5x48ahtrYWWVlZqKurc1x+165duP322zF37lwUFBTgrrvuwpw5cy551/dlbOCIiPxC2e69ALS3t8e9Ojs7e+0yEomgqakJpaWlsbJAIIDS0lI0NjY6HubUqVPR1NQUa9AOHz6M7du34+677zb6c9nAERFRQvLz85GTkxN7VVdX91rm5MmTiEajyM3NjSvPzc1FOBx23O7cuXPxxBNP4I477kB6ejquu+46TJ8+HY8++qjR8fEZHBGRXyiXxsH96XnosWPHkJ2dHSsOhUIubBxoaGjAunXr8MILL6C4uBiHDh3CkiVL8OSTT2L16tV93g4bOCIiv3B5oHd2dnZcA+dk5MiRCAaDaG1tjStvbW1FXl6e4zqrV6/G/PnzsWjRIgDAhAkT0NHRgfvuuw+PPfYYAoG+dT6ygTNgGUZYaZd3aUZveYZlzQkgvWc4c7fdrYm2E/dhFm0nbSeRGb1Nx/+I0aMaljDZlpSjUgmRgdLM4IBmRm9hHcsyi3y0u+V9B4Oms76bR8HKs2pLdcMsulJXL8WZ1xldmbCMjAwUFRWhvr4es2bNAgDYto36+npUVFQ4rnP27NlejVgw2JNrVaozTtjAERH5RZJSdVVWVmLhwoWYPHkypkyZgpqaGnR0dKC8vBwAsGDBAowdOzb2DG/mzJlYv349Jk2aFOuiXL16NWbOnBlr6PqCDRwRkU9cFADZ7+2YmD17Nk6cOIE1a9YgHA6jsLAQO3bsiAWetLS0xN2xrVq1CpZlYdWqVTh+/DhGjRqFmTNn4gc/+IHRftnAERHRgKuoqBC7JBsaGuL+n5aWhqqqKlRVVfVrn2zgiIj8wmezCbCBIyLyCxsuNXAubGMQcKA3ERF5Eu/g3CCGNSeS7FXalnNxIiHxcki+UB4VopY0SYpNkyq7Vd6zc2kds2NKhFLOn5UU9i+FxGvzWLv0WRknYYZ8Lqigcxh9IsmWpeO1DIcDiHTLuzHb9VDmszs4NnBERH6hIA/2M91OCmAXJREReRLv4IiIfELZlthtbLYdFw5mELCBIyLyC589g2MXJREReZJ/7+DsbvEtMbGqmxFWhgllxS4BqVwXGShG27mTCLnnuMySJNtCdJ6YVFnTzWLbQqSfcbJl864cKVpSSrZsBZzLA5CT+9ou/S4NSseq+V6VFPUphX1Kn6E2Clbat3O5a9GVOqaJmzXXl6RSlnHSceft9H8Tg8G/DRwRkc/47RncgHRRbtiwAQUFBcjMzERxcXFs2nFJTU0NbrrpJgwbNgz5+flYtmwZzp8/PxCHRkREPuH6HdzWrVtRWVmJ2tpaFBcXo6amBmVlZTh48CBGjx7da/nNmzdjxYoVqKurw9SpU/HRRx/hnnvugWVZWL9+vduHR0TkX7ZLXZR+vYNbv349Fi9ejPLycowbNw61tbXIyspCXV2d4/K7du3C7bffjrlz56KgoAB33XUX5syZc8m7PiIiMqQs914pwNUGLhKJoKmpCaWlpV/sIBBAaWkpGhsbHdeZOnUqmpqaYg3a4cOHsX37dtx9993ifjo7O9He3h73IiIiupirXZQnT55ENBqNTWJ3QW5uLg4cOOC4zty5c3Hy5EnccccdUEqhu7sb999/Px599FFxP9XV1Vi7dq2bh94vUj5BV3PeiVGUZnklAUB1C7klTXNUan7FiXkRhQg98XiFcilSUrctMSejm79GhW0FhIhF6Xu1If99UoSlsoS/25JycEoRkeZ5TMVzRzjXVLouilnah7iKsLx5vRTrskcwyGSQNTQ0YN26dXjhhRewb98+vPbaa9i2bRuefPJJcZ2VK1eira0t9jp27NggHjERUYqyA+69UoCrd3AjR45EMBhEa2trXHlrayvy8vIc11m9ejXmz5+PRYsWAQAmTJiAjo4O3HfffXjsscfipjG/IBQKIRQKuXnoRETkMa42wxkZGSgqKkJ9fX2szLZt1NfXo6SkxHGds2fP9mrEgsGebg2lUmQ0IRFRKrgQRenGKwW4PkygsrISCxcuxOTJkzFlyhTU1NSgo6MD5eXlAIAFCxZg7NixqK6uBgDMnDkT69evx6RJk1BcXIxDhw5h9erVmDlzZqyhIyKi/lPKcmX+w1S593C9gZs9ezZOnDiBNWvWIBwOo7CwEDt27IgFnrS0tMTdsa1atQqWZWHVqlU4fvw4Ro0ahZkzZ+IHP/iB24dGREQ+MiCpuioqKlBRUeH4XkNDQ/wBpKWhqqoKVVVVA3EoUInkhFPO64h55yTaKEqzqDo5P5+Uq1GXT9CdmbvFaEwkMKO3lKMygbySUTFS06xH3s1clNKZYwmzZwc1IYNShKX016mA8zFJ35FlaXpOpJm7DWf61p+fwnGJeR+FDbmYo1Lct3Ct0EnomuQWO+DSQO/UuIVjLkoiIp9QtnnSceftpEYDlxqxnkRERIZ4B0dE5BeuTZfj0yhKIiIamtyLokyNBo5dlERE5Em8g3ODaXSljjTAxDQHoC4XpeHM3dJs2/qchWYzepsekxQpqdu3GMHp4q9RaVtSdKUUTag7o6QISynnpB113rc4+7gm0aB0LlhBs5m+deennItSmh3cxYAHN+vyUORWmq0UyUXJBo6IyCfcS7bMLkoiIqKk4R0cEZFP+C3IhA0cEZFf+OwZHLsoiYjIk3gHR0TkE34LMvFvA5dAklTjEGJtsmXhPSmpstTnLSYv1iTMNRxaIA8fkDsAxPcMtyUNB5DC1QHNcADp7xuE5wlSSL4lJEK2XDymoOV8UonJlh0mGY69FzQclmKa2BsQ66Y4BMM0qXICyZaN634i15dB4LdncOyiJCIiT/LvHRwRkd/4LMiEDRwRkU/47RkcuyiJiMiTeAdHROQTfgsyYQPnQJyeXiJFZWmS1opTvkuriLswTEyrWUeMbJOSF+sSHgsVwLbNEjcnkuhZTuhsVimlRME6lhCxKEYAComCA0LyYkD+TAJCpKb0mQeECE7d9yqeC9L3ZAvRhAmcn8JHq6kzQh3T1kuzh0vG14pkUy49g0uNCb3ZRUlERN7EOzgiIp/wW5AJGzgiIp9Qyp3nZ25OwTeQ2EVJRESexDs4IiK/cKmLUhckNJSwgTMhRUwlkvNOeE8M8BKjJaUIR02UmmFkohSVqMtFKe1DzDnp0jHp1pH2YbodHcsyWycYdD6ndH+fFGFp+vdZlvAdCdGYuuOy0swic8XoSkCO8hNzpTovbrlYL41zUQ5RSgUSig7uvZ3U6KNkFyUREXkS7+CIiPzCttzpXmQXJRERDSV+y2TCLkoiIvIk3sEREfkEB3r7hKWN4hqEiCnDXJRSfkB5lmpNlKHpjN7SDNmabgox+lHMUWkWqSnlV+x5z72/w5SYc1KITJQiH6U8kYmwAs4nVcDF79X8nErk/BQiOMXo5kGI9BP2rb2+JBGjKImIiDzAt3dwRER+wy5KIiLyJEZREhEReQDv4IiIfMJvd3Bs4AxIue3kKC5dzjvTcsNcf93meSLl/JFmy/esI0VFmuWolKIlownkojSN7HSTNNu2FF3pZpSaac5JaXkAsKNCRKZQrqTyBM5POR+rsCGxXK6XUl0W81qmGKVcegaXIg0cuyiJiMiTeAdHROQTfhsHxwaOiMgn/DZMgF2URETkSbyDIyLyCUZREhGRJ7GB8xo1CElPxWnuNaHF0kNaqW9bOqHE5c3D6OXQfiHRs+Ykl5IIm25LGg4gJWcG5LD/wUi2LFGGwwSkYQU6luV8vkmfeVQYDhAIyuet6TkiJXTWnZ/u1QGpjumG7wzCcIDBuCYRAD6DIyLyDWV/EWjSv5f5vjds2ICCggJkZmaiuLgYe/bs0S5/6tQpPPTQQ7jiiisQCoVw4403Yvv27Ub79P4dHBERAUheF+XWrVtRWVmJ2tpaFBcXo6amBmVlZTh48CBGjx7da/lIJIK//Mu/xOjRo/HLX/4SY8eOxSeffIIRI0YY7ZcNHBERJaS9vT3u/6FQCKFQqNdy69evx+LFi1FeXg4AqK2txbZt21BXV4cVK1b0Wr6urg6ff/45du3ahfT0dABAQUGB8fGxi5KIyCcuDPR24wUA+fn5yMnJib2qq6t77TMSiaCpqQmlpaWxskAggNLSUjQ2Njoe57//+7+jpKQEDz30EHJzczF+/HisW7cO0ajZZNS8gyMi8glbWa7kXr2wjWPHjiE7OztW7nT3dvLkSUSjUeTm5saV5+bm4sCBA47bP3z4MH7zm99g3rx52L59Ow4dOoQHH3wQXV1dqKqq6vNx+reBkxIkQ5c82TCpsq2JhBNWESP9pGTLQrmtS2bbLSRPliITpX1oEh5LUY5SuRh1KS2viaKUoyWlYxqEZMsBKYpSSEacQBSl9BlKLOH81H2vgYB0HgrRscK5pjs/A4bnuvR9S3+fvl5KdVlKwmx4rfCY7OzsuAbOLbZtY/To0XjppZcQDAZRVFSE48eP45lnnmEDR0REDlxK1SUOzXAwcuRIBINBtLa2xpW3trYiLy/PcZ0rrrgC6enpCAa/+IF0yy23IBwOIxKJICMjo0/7HpBncMkIByUiIr0LUZRuvPoqIyMDRUVFqK+vj5XZto36+nqUlJQ4rnP77bfj0KFDsC+6o/7oo49wxRVX9LlxAwaggbsQDlpVVYV9+/Zh4sSJKCsrw2effea4/IVw0KNHj+KXv/wlDh48iI0bN2Ls2LFuHxoRESVBZWUlNm7ciJ/97Gf48MMP8cADD6CjoyMWVblgwQKsXLkytvwDDzyAzz//HEuWLMFHH32Ebdu2Yd26dXjooYeM9ut6F2WywkGJiEgvWePgZs+ejRMnTmDNmjUIh8MoLCzEjh07YoEnLS0tCAS+uN/Kz8/HW2+9hWXLluHWW2/F2LFjsWTJEixfvtxov642cBfCQS9uiU3CQd98802MGjUKc+fOxfLly+P6Xy/W2dmJzs7O2P+/PBaDiIh6S2YuyoqKClRUVDi+19DQ0KuspKQEv/3tb433czFXG7jBCgetrq7G2rVr3Tz0vjHNU6fNeSeUdwuRYlLkoxQ5pzkBjXNR2s770EXtifkPXYqWlI6p5z13clEmEk4t5ZCUJoi0LOd9SFGXOlJeS9PlowH5s5XyVFpRIRpUzEWpOT+Fc0SqA1KdEetYIvkmByNHJbku6QO9Lw4HLSoqwuzZs/HYY4+htrZWXGflypVoa2uLvY4dOzaIR0xElJpsFXDtlQpcvYMbrHBQKR0MERHJlHJpRu8UmS7H1WY4meGgREREF3P9PjNZ4aBERKSXjHFwyeT6MIFkhYMSEZEeZ/R2QTLCQV1lmkcugVyU0oSBYj5IwxyVdgIRjspwtmbtrNpClKNb0ZLSTN+AZnZwzfE6bgfmlViKlpQEA1L0YQK5KA2jKOXvyHxG74t/tPZled35KeWilOuAcD6L+SMTyEUpLu+PnJOpirkoiYh8wu3ZBIY6NnBERD7hty7K1BjMQEREZIh3cEREPuG3Ozg2cEREPsFncD5h2d0JrCNFSxqWA5o8eYaRYlLePk22AimCTZrxWswfqYlklKIcpUjGaNT5VDTdDgAxjZBUKQfj16iU91H6ngLCTN96UnV2PtetqPMxSRGRgJynMpjmHE0onmua81M+p81mnE8oF6Xwnlj3Bbrri3l8LCXKtw0cEZHfKOXOD7oERrEkBRs4IiKf8NszOEZREhGRJ/EOjojIJ5RLQSapcgfHBo6IyCfYRUlEROQBRndw06dPR2FhIWpqagbocIYGS0qgKpVLmZO1SV2FcimUvVtKKCssr0lmKydoNkuqLCXr1a8jDRMwTMKsmVFYXsdsmEAiXTkBYTiAOExASpCs++lpOILAsoTPQ9i3/nt13rmYhNnF81OqA1KdkYcJaOqlWJelxM2plWzZb3dw7KIkIvIJvw307nMX5T333IP//M//xPPPPw/LsmBZFo4ePTqAh0ZERJS4Pt/BPf/88/joo48wfvx4PPHEEwCAUaNGDdiBERGRu9hFKcjJyUFGRgaysrKQl5c3kMdEREQDgF2UREREHuD5IBOlhESzupWMp62XIq/kVZQQdaa6zRIhS9FoukTI0r7F6ENpeV2yZTFa0vl4peW7hYg+XbJl6T2pW8XN7hbpK5eiKKVy3S/kYMDseC3hM5f2HdB8r1IiZukzD0qRudrz07AOSHVG2IeljW52qe5rSNekwaBgQemvfn3eTiowauAyMjIQjaZWWCwREfXw2zM4oy7KgoIC7N69G0ePHsXJkyfFMTFERETJZtTAff/730cwGMS4ceMwatQotLS0DNRxERGRyy4EmbjxSgVGXZQ33ngjGhsbB+pYiIhoALGLkoiIyAM8H0UpSiSHnBgtKZTrpr2V8u0Z5uETIwM1UYZSrkFbisiUlk9gH93CPkyjJXVRlGIeTCHya1ByUUKIWEzi1MiBgJSLUn627ta5ozs/xXPaMB+rnO81gShK4+jKoRmMZ8OlcXApEkXJOzgiIvIk/97BERH5jN+ewbGBIyLyCRuWK92L7KIkIiJKIt7BERH5hUtdlOIks0MMGzgH8ozehhFWusArw5mOxXIhD58u1584c7dhzkkpIhIAuqXjEvIJirkopdyVCczoLQXPuTloNSp03cgzfTtvJ6jpAjK9QEn7lr4jy5JP3EDA+T3TnJPaXJRibkmzcqmO6eulWR1PtRm9OZsAERGRB/AOjojIJxhFSUREnmRD30Nrsp1UwC5KIiLyJN7BERH5BLso/cI2n1VXjq6UwvPkbUmRYqY5KqVoNGk2Y0Azq3bU+XQQIxk1UZRSJGOXlIvSMFqyW/f3CetIkV+DUVnF2bOFcm2KSuljNwzoE/Nj6qJjxShK53NHmiA5LYEoXznnpJS7Uqpjmg9KqMvG0ZIJXF8Gg63ciYDUTYo+lLCLkoiIPMm/d3BERD6jYEG5kGbLjW0MBjZwREQ+wYHeREREHsA7OCIin+gJMnFnO6mADRwRkU/wGZzXJBKu61KyZaXbtdCHbRsmlJWWt2051NsWQvilpMrS8tpky8L+xaTK4vLOn1OX5u+Twv6j0mc+CM8TpOEAQaHclrIwa6iA8zrSEIVuS0i2HJV/ngeDznXD9NzRnp8u1YGgNCxEUy8t44TqCSRbHqJDCLzI+w0cEREB8F+QCRs4IiKfUOoSSQQMtpMKGEVJRESexDs4IiKfULBgM8iEiIi8hsmWfUKbPNU4ijKBZMuGiWOlBLRSebeUaBZAVHhPjq4UlheSGgMDHy3ZLWy/57ikKEqzJMxukqIobcv5JJGW13M+by0hWlKM7AyYJ+q2pehK4VzTnZ9Bw3PdNEG5djIzsS6bRVEaJ2emAeHbBo6IyG/8FkU5IEEmGzZsQEFBATIzM1FcXIw9e/b0ab0tW7bAsizMmjVrIA6LiMjXlIuvVOB6A7d161ZUVlaiqqoK+/btw8SJE1FWVobPPvtMu97Ro0fx/e9/H1//+tfdPiQiIvIh1xu49evXY/HixSgvL8e4ceNQW1uLrKws1NXVietEo1HMmzcPa9euxbXXXuv2IREREb7oonTjlQpcbeAikQiamppQWlr6xQ4CAZSWlqKxsVFc74knnsDo0aNx77339mk/nZ2daG9vj3sREZGe7eIrFbgaZHLy5ElEo1Hk5ubGlefm5uLAgQOO67z77rv4l3/5FzQ3N/d5P9XV1Vi7dm1/DlXLPB+dVC7/ylFCFJmUb08ulyLONFGUwntSbkmpvEuXi1LclvPxmkZL6nJRShGZ0vgfKeQ5kUos/WKU8kFGhWMKBgb+KUcAQi5KTQRnQIj6DEr5IA3zRwLyOW1cN6RITU29NK3j4rWChoSkZjI5ffo05s+fj40bN2LkyJF9Xm/lypVoa2uLvY4dOzaAR0lE5A0XxsG58UoFrt7BjRw5EsFgEK2trXHlra2tyMvL67X8xx9/jKNHj2LmzJmxMvtPv4jS0tJw8OBBXHfddb3WC4VCCIVCbh46EZHncZhAP2RkZKCoqAj19fWxMtu2UV9fj5KSkl7L33zzzXjvvffQ3Nwce/3VX/0V7rzzTjQ3NyM/P9/NwyMiIh9xvYuysrISGzduxM9+9jN8+OGHeOCBB9DR0YHy8nIAwIIFC7By5UoAQGZmJsaPHx/3GjFiBL7yla9g/PjxyMjIcPvwiIh8K5nj4JIxPtr1TCazZ8/GiRMnsGbNGoTDYRQWFmLHjh2xwJOWlhYENKmAiIhoYCSri/LC+Oja2loUFxejpqYGZWVlOHjwIEaPHi2u19/x0QOSqquiogIVFRWO7zU0NGjX3bRpk/sH5EQ73bZAipjqdv49o6KaKEopuqzbLA+fNDOymLcPchSlGF0pRThqIuGkKEepPCJFakq5KzUVTHpPjJYchChKKe+jFLFoSzkRoTleZZb/UJo0XDujd8D5U0kzPKd056fpOS3VGTnqUj53LKEui3Vfksj1xcMuHh8NALW1tdi2bRvq6uqwYsUKx3UuHh/9X//1Xzh16pTxfnkrRUTkE26Pg/vyeOTOzs5e+xys8dFO2MAREfmE28ME8vPzkZOTE3tVV1f32qdufHQ4HHY8zgvjozdu3Nivv5ezCRARUUKOHTuG7Ozs2P/dGL6V6PhoJ2zgiIh8QsGdNFsXnlRmZ2fHNXBOBmt8tBN2URIR+YSCS12UQoo5J8kcH+39Ozi7K4F1hGg0Jfz2UVLkleb3gzQLsWHuPmlm5GhU/mrF2bZNc1Rq8kFKOSRNc0tKy+uiKLvFGb3NoivdJEVLBoVypckHqXvPiRTBKZfrZtt2rgPdAed10sRZ4uXzUzqnpUhNqc6IOSd1uSjFOm42o7dWItekFFdZWYmFCxdi8uTJmDJlCmpqanqNjx47diyqq6tj46MvNmLECADoVX4p3m/giIgIAGCrnpcb2zGRrPHRbOCIiHzCrdm4E9lGMsZH8xkcERF5Eu/giIh8wm+zCbCBIyLyCbdm406VaV7ZRUlERJ7k2zs4SxfeazhtvfRzRgnhzoAc9i8mlBXLpWEC8m+X7m7nr900qXKXZh9S8uROITxcHj4glOuSLRsOLRiMX6PSJyUNE4gKSY0BIN1wmIBpNbc0IQRBYQhBd1AaxuK872ianIx4oOuGrl6KYf+G1wTt9SWJ3JqN25czehMR0dDFLkoiIiIP4B0cEZFPKCUnXjLdTipgA0dE5BM2LNgGeSR120kF7KIkIiJP4h2cA8swYkqanV5pogxVt/N7YrJlKRpNiAiTIh9174nRkmIiZHkfpsmTO42TM8ufrRQtKUZRDkJ3S0D4wStFUQY1UWpKiLA0ffAvJVWWjgkAuoQPS0qELCVh1p2f0jkt1QGpzsh1TFcvncula4J4rRiikpWLMlnYwBER+YVLz+BcSWg5CNhFSUREnsQ7OCIin/BbkAkbOCIin/DbMAF2URIRkSd5/w5OCovS5qI0zUcnbUcTreVWvj0hWjEq5HwEzHNOSpGMUr5J3XtStKS0vJRzUjomAOgWfl1KUZSD8WvUEqMoncvTpBWgOd1cyg8Y0OWiFCI4xXNHyFGpOz+lc9q0bkh1TFcv5Q9XqvvStUJzfZGuSYPAb6m6vN/AERERAP8NE2AXJREReRLv4IiIfELBnSFsKXIDxwaOiMgverooXRgmkCItHLsoiYjIk3x7B5fQjN7dwjrdQnSeLued8F60S5jxWpqFW8pFqZm1WIp4kyIZ5YhIeR+m0ZIRMeek82erndFbzEXpvLz0izaR6Eop+DEg5ZwUVohq8kHaSogaDLjzs9oSclQCQDDqXDeClnO5mKNSd36K57SQj1WoM1Id09VLqS6LdT/lZvT21zg43zZwRER+47dhAuyiJCIiT+IdHBGRT7CLkoiIPIldlERERB7AOzgnYt455/tyJUXhaSLFTCO/TGfu1s22HRHyAEqzcHcKy+vyQZpGS3YK0ZLS8lJEJCBHXkrfk5S70k1SbklLiJZM1/z0jArbssXf1Wa/Y9M0EZzSuZAm5KiUzrUMW87HKJ3TUh0wjUjW1UvpHLGkgV8pNqO3cilVF7soiYhoSPFbJhN2URIRkSfxDo6IyCf8NpsAGzgiIp/w2zABdlESEZEn8Q6OiMgn/DYOzvsNnBSOrEmGKiZKFZMwC2HbQrgzACgxFNr5K7GFEH4pAW2XZt/dQli1aWj/ec0+5HXMhgl0J5RsWSiXRn9A+P4S6IYJSLl6hbizoLCClDAaANKFMH4xCbNhf1JA872KwxqEddIDznVJe34a1gGpzkh1TFcvxWTLpkmVtcnc5SESA81vz+DYRUlERJ7k/Ts4IiIC4L9xcGzgiIh8gl2UREREHsA7OCIin/DbODjfNnDaKeXFaEkh2bIUxSUkewXkRLBSBFm3kCC2S4qiFMoBOQGuGEUplHdqki1LUZFyUmWzcimhcs97zuVSnewehJjnNCFa0hL2rU22LLyXYXhMlpgAWt55UIqWFP6Q9ICQbLlbrn9dQbM6INUZMdmypl6KiZi7u5zLTaMrk8xvwwTYRUlERJ40IA3chg0bUFBQgMzMTBQXF2PPnj3ishs3bsTXv/51fPWrX8VXv/pVlJaWapcnIqLE2Pgi0KRfr2T/IX3kegO3detWVFZWoqqqCvv27cPEiRNRVlaGzz77zHH5hoYGzJkzB2+//TYaGxuRn5+Pu+66C8ePH3f70IiIfE25+EoFrjdw69evx+LFi1FeXo5x48ahtrYWWVlZqKurc1z+lVdewYMPPojCwkLcfPPNePnll2HbNurr690+NCIi8hFXG7hIJIKmpiaUlpZ+sYNAAKWlpWhsbOzTNs6ePYuuri5cfvnl4jKdnZ1ob2+PexERkZ5KpDvS4eXLKMqTJ08iGo0iNzc3rjw3NxcHDhzo0zaWL1+OMWPGxDWSX1ZdXY21a9f261i1U813C7nihCSHSghrs7vl3w+2sE7UMFqyW8jPJ0U+6t6LSNsyzCupe69TzEXpvB05ilLctfielKNSqqyJPGeQPpGIsLGgEAwqHWvPe2a5M5Umr6WpNCkXpXDupAsJQHXnZ0jYllQHpDoj1TFdvZTqsviFSNcK3fUliZRyKZNJijRwQyqK8qmnnsKWLVvw+uuvIzMzU1xu5cqVaGtri72OHTs2iEdJRESpwNU7uJEjRyIYDKK1tTWuvLW1FXl5edp1n332WTz11FP49a9/jVtvvVW7bCgUQigU6vfxEhH5CcfB9UNGRgaKioriAkQuBIyUlJSI6z399NN48sknsWPHDkyePNnNQyIioj/peYamXHgl+y/pG9czmVRWVmLhwoWYPHkypkyZgpqaGnR0dKC8vBwAsGDBAowdOxbV1dUAgB/+8IdYs2YNNm/ejIKCAoTDYQDA8OHDMXz4cLcPj4iIfML1Bm727Nk4ceIE1qxZg3A4jMLCQuzYsSMWeNLS0oJA4IsbxxdffBGRSATf/e5347ZTVVWFxx9/3O3DIyLyLU6X44KKigpUVFQ4vtfQ0BD3/6NHjw7EIXxBmD03oVyUUo5DIYpLmlG457CEiDAhT6SUby8ilQvbAYBOIUpNmqFbKpeiKwE5WvK88LGfl3JXCsvr8kcKKUMRFUK/dBGLbpGiJYNCPkjd39clBfoJO5Fn9HbekDQrOSB/f0EhujIt4PyHZGjOT+mcFnNOCtuS6piuXoq5KKXvI5FclMme0dul7aSCIRVFSURE5BbfziZAROQ36k//3NhOKmADR0TkE+yiJCIi8gDewRER+YTfBnr7t4HTRTkp569PCcFPUlSWrZvRW5ydWMjDJ87CbTY7N2AeLXk+6hxWd06T0++ssI4YRSmUR4S+EF0uStNoye5BSKyXJkRLStGHUnQlAKSLUaJCjkohulL+COXvNSA8ewlazueOFEUZCuhypQqz1IsRxs7bEnNU6mb0FuqAVPct4Vqhvb4kkVIuPYNLkWSU7KIkIiJP8u8dHBGRz7CLkoiIPIldlERERB7AOzgiIp9QcKd7MTXu39jAERH5hq0UbBeaJztFuii938BJ8b26MF5pGvpuIdxaCDuWwpR7diGEQhsnVZZC+zXJlsV1pATJZomTde+dEz7a80IMf6c0TEBTwbqEBLhRoWJLlTWRZxUWnM+RgDRMQFg+PSA/PUgXElMLXxOUMHxAontuEbSc35WSSacLwwHOB+TzMxR1PkmkOiDVGTE5s6ZeikMIhLovXiu0w5CSl2w5mTZs2IBnnnkG4XAYEydOxI9//GNMmTLFcdmNGzfi5z//Od5//30AQFFREdatWycuL+EzOCIin1Au/jOxdetWVFZWoqqqCvv27cPEiRNRVlaGzz77zHH5hoYGzJkzB2+//TYaGxuRn5+Pu+66C8ePHzfaLxs4IiKfsF18mVi/fj0WL16M8vJyjBs3DrW1tcjKykJdXZ3j8q+88goefPBBFBYW4uabb8bLL78M27ZRX19vtF82cERElJD29va4V2dnZ69lIpEImpqaUFpaGisLBAIoLS1FY2Njn/Zz9uxZdHV14fLLLzc6PjZwREQ+YUO59gKA/Px85OTkxF7V1dW99nny5ElEo1Hk5ubGlefm5iIcDvfpuJcvX44xY8bENZJ94f0gEyIiAuB+FOWxY8eQnZ0dKw+FQv3e9pc99dRT2LJlCxoaGpCZmWm0rm8bOEuItAMgTkOvhATCSkg6HNUkde0W3uvqSncslxLQdgqRYlLiZN1754XoPClxslQOyNGSZ4VoyfNR58+8U/guIkqOUusSnhB0w3kdqcK7GkUpRDKmwfm7SJdCIgFkCImNo8p5HVvsqJG+P/l7lXJABwNCNKhwrmUISZgBIFOMGDarM1Id09VLqS5LdV+6jmivLx6SnZ0d18A5GTlyJILBIFpbW+PKW1tbkZeXp1332WefxVNPPYVf//rXuPXWW42Pj12UREQ+kYwoyoyMDBQVFcUFiFwIGCkpKRHXe/rpp/Hkk09ix44dmDx5ckJ/r2/v4IiI/Obi52f93Y6JyspKLFy4EJMnT8aUKVNQU1ODjo4OlJeXAwAWLFiAsWPHxp7h/fCHP8SaNWuwefNmFBQUxJ7VDR8+HMOHD+/zftnAERHRgJo9ezZOnDiBNWvWIBwOo7CwEDt27IgFnrS0tCBwUXKDF198EZFIBN/97nfjtlNVVYXHH3+8z/tlA0dE5BPJuoMDgIqKClRUVDi+19DQEPf/o0ePJnBUvbGBIyLyiUSykEjbSQWeb+AsO5FclM7vKSGHnZS/zhaiwQA5T54UKRYR9t1pS3kldVGUQlSkEEF2TlheipQEgDPdzhXgXNT5sz0nfB8RIfKxE13iviOW83tROB9wtzXwuQHTlPP3GhSqYIZyjgwEgJDwXpfwnUeVcN4K5booSjGnppCjMt1yPg8yhByVgHxOS3VAqjNSHdPVS6kuS3Vfulbori/iNYlc5/kGjoiIeiiXuih5B0dEREOKbdmwrP6P0bNdmVVu4HEcHBEReRLv4IiIfMKGgpWkKMpkYANHROQTF1Ilu7GdVOD9Bk6KZtLlipOiKLsynDclzTSsyXlnOnN3pxApdk7MRSn3Pp8V3pOiJTuEgEUpUhIAOoRoybNCBNk5RBzLz1vnHcs7hXIAiFjO24oKkZdRJUdkuiUYcI58DMK5PEM5n2sAEFLOCWczhfKoLZy34h7kCEcpwjIoRFdKM5NnBORzRzqns9LM6oxUx3T1UqrLqkuKonQ+17TXF10EN7nK+w0cEREB6PlR404XZWpgA0dE5BOMoiQiIvIA3sEREfmEDRuWC3dfqXIHxwaOiMgn2MD5hNWtiZwTogNNc1FGNTnvpFmIO7udy88bztwtRUoCQIeQc7JDSJEnRUue7pZz6nUIkYkdQvTjOeusY/l5q8OxPKLOifvuUs776LY7HculkGelmTVcYgmzbVvC04C0QMixPB3OEZEA0BkY5lgeUZc5lnepLMfyqO28D1uTB1O6ZEgzeks5KnVRlJnSjPNShLFQZ6Q6pquX5rkonf8O7fWFBo1vGzgiIr/hODgiIvIkRlESERF5AO/giIh8QsF25e6LXZRERDSkKEShXOi4U8JExEMNuyiJiMiTPH8HJ04Pr0uGGhWGCQghxN0R53DkbiFMGZBDmzuFEOazQrJlaTiANBQAAM4KH0lHl/PfLSVOloYCAEC7dcZ5HaH8HNodyztt5+W7bOdhBYA8HCBqOyfGtQch2XLAEpItB5wTIUvDBwAgPeAc9h+xnIdOdFnZjuVRDHfegRLKAQSEhNxpQjLiNCEJc4aQhBkAMoPO7w0T6sBlhsMHdPVSqstS3ZeuFbrri3hNGgQ93ZP+CTLxfANHREQ9euZxc6OBS4354NhFSUREnsQ7OCIin+gJMnHuNjbdTipgA0dE5BN+ewbHLkoiIvIk79/BCdPD65KhKinwUoiksoUIsk4hIgsAOoVIrvNCpNg5IVKsQ0gC29Etd0PIyZOdP6vTQvThKcs58hEA2gP/z7H8rO1cHhGiJTu7TzuWR2052bKthKTKYrTkYHS3CEmYhejKgCVHUQaFZMtdac6RpZGAc3lX4KuO5bYuwth2jsgMdDtHgwYDzn93upCcGQAyg87rZAWlhONCFKWU0FxTL6W6LCdhdt6ONtmycE0aDMxFSUREnmQjCrjwDM5OkWdwA9JFuWHDBhQUFCAzMxPFxcXYs2ePdvlf/OIXuPnmm5GZmYkJEyZg+/btA3FYRETkI643cFu3bkVlZSWqqqqwb98+TJw4EWVlZfjss88cl9+1axfmzJmDe++9F/v378esWbMwa9YsvP/++24fGhGRr13oonTjlQpcb+DWr1+PxYsXo7y8HOPGjUNtbS2ysrJQV1fnuPzzzz+Pb33rW3j44Ydxyy234Mknn8Rtt92Gn/zkJ24fGhGRr9kq6torFbj6DC4SiaCpqQkrV66MlQUCAZSWlqKxsdFxncbGRlRWVsaVlZWV4Y033hD309nZic7OLwIJ2traAADt7b0DD86ddn7YG+mQf4Goc85BGJ2dzuuciTh/2Wd0M14LD6HPRp0DOs4Jsxx32s7HGrHlfnYhIxe6hJO2WzkfUxTyg3RbePouVQxp9mylpNm25UwK8num5W4S0r8Jx6o0xyR/Js6fofSZS99RVJO6TDoX5HPK+byNaG4ApHP6XNSszkh1LFNTL0NCXVZC3Q8J1wpLc305L1yTutrjg4EuXM905zrpudrAnTx5EtFoFLm5uXHlubm5OHDggOM64XDYcflwOCzup7q6GmvXru1VXnDVkgSO2kSrYTnRBdIFz/liZys516YUhNdl+KPaOZ7VZfKf4XF/SOC91x1L//jHPyInJ6ffRwQwijIlrFy5Mu6u79SpU7j66qvR0tLi2ongB+3t7cjPz8exY8eQne0c/k3x+Jklhp+buba2Nlx11VW4/PLLXdtmTwPX/+5FXzZwI0eORDAYRGtr/B1Na2sr8vLyHNfJy8szWh4AQqEQQqHe44RycnJYeRKQnZ3Nz80QP7PE8HMzF9DMvEB6rn5yGRkZKCoqQn19fazMtm3U19ejpKTEcZ2SkpK45QFg586d4vJERJQYpWzYLryk58BDjetdlJWVlVi4cCEmT56MKVOmoKamBh0dHSgvLwcALFiwAGPHjkV1dTUAYMmSJZg2bRqee+45zJgxA1u2bMHevXvx0ksvuX1oRES+1tO16EayZZ82cLNnz8aJEyewZs0ahMNhFBYWYseOHbFAkpaWlrhb7qlTp2Lz5s1YtWoVHn30Udxwww144403MH78+D7vMxQKoaqqyrHbkmT83MzxM0sMPzdz/Mz6z1KMQSUi8rT29nbk5OQgJ3McLMt56IYJpaJoO/97tLW1DelnqikZRUlEROZs2LB81EXJ8BwiIvIk3sEREflET/SjC3dwfo2iJCKiocmNQd5ubmegsYuSiIg8KWUbuB/84AeYOnUqsrKyMGLEiD6to5TCmjVrcMUVV2DYsGEoLS3F//7v/w7sgQ4hn3/+OebNm4fs7GyMGDEC9957L86ccZ5J+4Lp06fDsqy41/333z9IR5wcnM/QnMlntmnTpl7nVGZm5iAe7dDwzjvvYObMmRgzZgwsy9ImmL+goaEBt912G0KhEK6//nps2rTJaJ9KKag/DdTu3ys1gu9TtoGLRCL467/+azzwwAN9Xufpp5/Gj370I9TW1mL37t247LLLUFZWhvPnzw/gkQ4d8+bNwwcffICdO3fiV7/6Fd555x3cd999l1xv8eLF+PTTT2Ovp59+ehCONjk4n6E5088M6EnZdfE59cknnwziEQ8NHR0dmDhxIjZs2NCn5Y8cOYIZM2bgzjvvRHNzM5YuXYpFixbhrbfe6vM+/TYfHFSK++lPf6pycnIuuZxt2yovL08988wzsbJTp06pUCik/u3f/m0Aj3Bo+P3vf68AqP/5n/+Jlf3Hf/yHsixLHT9+XFxv2rRpasmSJYNwhEPDlClT1EMPPRT7fzQaVWPGjFHV1dWOy//N3/yNmjFjRlxZcXGx+vu///sBPc6hxPQz62ud9RMA6vXXX9cu88gjj6ivfe1rcWWzZ89WZWVll9x+W1ubAqCGZRSorNC1/X4NyyhQAFRbW1t//uwBl7J3cKaOHDmCcDiM0tLSWFlOTg6Ki4vFueq8pLGxESNGjMDkyZNjZaWlpQgEAti9e7d23VdeeQUjR47E+PHjsXLlSpw96805UC7MZ3jxOdKX+QwvXh7omc/QD+cUkNhnBgBnzpzB1Vdfjfz8fHznO9/BBx98MBiHm9LcONeUirr2SgW+iaK8ML+c6dxzXhEOhzF69Oi4srS0NFx++eXav3/u3Lm4+uqrMWbMGPzud7/D8uXLcfDgQbz22msDfciDbrDmM/SSRD6zm266CXV1dbj11lvR1taGZ599FlOnTsUHH3yAK6+8cjAOOyVJ51p7ezvOnTuHYcOGXXIbboX3p8owgSF1B7dixYpeD5+//JIqjV8N9Gd23333oaysDBMmTMC8efPw85//HK+//jo+/vhjF/8K8pOSkhIsWLAAhYWFmDZtGl577TWMGjUK//zP/5zsQyOPGVJ3cP/4j/+Ie+65R7vMtddem9C2L8wv19raiiuuuCJW3traisLCwoS2ORT09TPLy8vr9dC/u7sbn3/+uXbuvS8rLi4GABw6dAjXXXed8fEOZYM1n6GXJPKZfVl6ejomTZqEQ4cODcQheoZ0rmVnZ/fp7g1wL8VWqgSZDKkGbtSoURg1atSAbPuaa65BXl4e6uvrYw1ae3s7du/ebRSJOdT09TMrKSnBqVOn0NTUhKKiIgDAb37zG9i2HWu0+qK5uRkA4n4keMXF8xnOmjULwBfzGVZUVDiuc2E+w6VLl8bK/DSfYSKf2ZdFo1G89957uPvuuwfwSFNfSUlJryEopuea37ooUzaK8pNPPlH79+9Xa9euVcOHD1f79+9X+/fvV6dPn44tc9NNN6nXXnst9v+nnnpKjRgxQr355pvqd7/7nfrOd76jrrnmGnXu3Llk/AmD7lvf+paaNGmS2r17t3r33XfVDTfcoObMmRN7/w9/+IO66aab1O7du5VSSh06dEg98cQTau/everIkSPqzTffVNdee636xje+kaw/YcBt2bJFhUIhtWnTJvX73/9e3XfffWrEiBEqHA4rpZSaP3++WrFiRWz5//7v/1ZpaWnq2WefVR9++KGqqqpS6enp6r333kvWnzDoTD+ztWvXqrfeekt9/PHHqqmpSX3ve99TmZmZ6oMPPkjWn5AUp0+fjl23AKj169er/fv3q08++UQppdSKFSvU/PnzY8sfPnxYZWVlqYcfflh9+OGHasOGDSoYDKodO3Zccl8XoijTg7kqI+2Kfr/Sg7kpEUWZsg3cwoULFYBer7fffju2DAD105/+NPZ/27bV6tWrVW5urgqFQuqb3/ymOnjw4OAffJL88Y9/VHPmzFHDhw9X2dnZqry8PO4HwZEjR+I+w5aWFvWNb3xDXX755SoUCqnrr79ePfzww0P+pO6vH//4x+qqq65SGRkZasqUKeq3v/1t7L1p06aphQsXxi3/6quvqhtvvFFlZGSor33ta2rbtm2DfMTJZ/KZLV26NLZsbm6uuvvuu9W+ffuScNTJ9fbbbztewy58VgsXLlTTpk3rtU5hYaHKyMhQ1157bdz1TedCA5cWHKXS03L7/UoLjkqJBo7zwRERedyF+eCCgcthWf2PLVTKRtT+fMjPBzekoiiJiIjcMqSCTIiIaCApwJUIyNTo+GMDR0TkE+7NB5caDRy7KImIyJN4B0dE5BM9A7RduINjFyUREQ0t7jRwqfIMjl2URETkSbyDIyLyC5eCTJAiQSZs4IiIfMJvz+DYRUlERJ7EBo7oEk6cOIG8vDysW7cuVrZr1y5kZGSgvr4+iUdGZMp28TX0sYEjuoRRo0ahrq4Ojz/+OPbu3YvTp09j/vz5qKiowDe/+c1kHx6RAdXz/Ky/rwS6KDds2ICCggJkZmaiuLgYe/bs0S7/i1/8AjfffDMyMzMxYcKEXlMF9QUbOKI+uPvuu7F48WLMmzcP999/Py677DJUV1cn+7CIUsLWrVtRWVmJqqoq7Nu3DxMnTkRZWVmvSZgv2LVrF+bMmYN7770X+/fvx6xZszBr1iy8//77RvvlbAJEfXTu3DmMHz8ex44dQ1NTEyZMmJDsQyLqkwuzCQBBuDcOLtrn2QSKi4vx53/+5/jJT34CoGdS3Pz8fPzDP/wDVqxY0Wv52bNno6OjA7/61a9iZX/xF3+BwsJC1NbW9vkoeQdH1Ecff/wx/u///g+2bePo0aPJPhyiBDlOQ2f46tHe3h736uzs7LW3SCSCpqYmlJaWxsoCgQBKS0vR2NjoeISNjY1xywNAWVmZuLyEDRxRH0QiEfzt3/4tZs+ejSeffBKLFi0Su1eIhpqMjAzk5eUBiLr2Gj58OPLz85GTkxN7OXXbnzx5EtFoFLm5uXHlubm5CIfDjscbDoeNlpdwHBxRHzz22GNoa2vDj370IwwfPhzbt2/H3/3d38V1oRANVZmZmThy5AgikYhr21RKwbLiuztDoZBr23cDGziiS2hoaEBNTQ3efvvt2POGf/3Xf8XEiRPx4osv4oEHHkjyERJdWmZmJjIzMwd9vyNHjkQwGERra2tceWtr65/uKnvLy8szWl7CLkqiS5g+fTq6urpwxx13xMoKCgrQ1tbGxo3oEjIyMlBUVBQ3ZtS2bdTX16OkpMRxnZKSkl5jTHfu3CkuL+EdHBERDajKykosXLgQkydPxpQpU1BTU4OOjg6Ul5cDABYsWICxY8fGnuEtWbIE06ZNw3PPPYcZM2Zgy5Yt2Lt3L1566SWj/bKBIyKiATV79mycOHECa9asQTgcRmFhIXbs2BELJGlpaUEg8EWH4tSpU7F582asWrUKjz76KG644Qa88cYbGD9+vNF+OQ6OiIg8ic/giIjIk9jAERGRJ7GBIyIiT2IDR0REnsQGjoiIPIkNHBEReRIbOCIi8iQ2cERE5Els4IiIyJPYwBERkSexgSMiIk/6/5sex0hmMCxEAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] @@ -372,12 +354,13 @@ " origin=\"lower\",\n", " extent=(x0, x_final, t0, t_final),\n", " aspect=(x_final - x0) / (t_final - t0),\n", - " cmap=\"plasma\",\n", + " cmap=\"inferno\",\n", ")\n", "plt.xlabel(\"x\")\n", "plt.ylabel(\"t\", rotation=0)\n", "plt.clim(0, 1)\n", - "plt.colorbar()" + "plt.colorbar()\n", + "plt.show()" ] }, { diff --git a/examples/symbolic_regression.ipynb b/examples/symbolic_regression.ipynb index b5121ff3..8572424a 100644 --- a/examples/symbolic_regression.ipynb +++ b/examples/symbolic_regression.ipynb @@ -75,13 +75,13 @@ "import optax # https://github.com/deepmind/optax\n", "import pysr # https://github.com/MilesCranmer/PySR\n", "import sympy\n", + "import sympy2jax # https://github.com/google/sympy2jax\n", "\n", "\n", "# Note that PySR, which we use for symbolic regression, uses Julia as a backend.\n", "# You'll need to install a recent version of Julia if you don't have one.\n", "# (And can get funny errors if you have a too-old version of Julia already.)\n", "# You may also need to restart Python after running `pysr.install()` the first time.\n", - "pysr.silence_julia_warning()\n", "pysr.install(quiet=True)" ] }, @@ -90,7 +90,7 @@ "id": "4d26c41f-7682-4ad0-aa33-77e22b2768f8", "metadata": {}, "source": [ - "Now for a bunch of helpers. We'll use these in a moment; skip over them for now." + "Now two helpers. We'll use these in a moment; skip over them for now." ] }, { @@ -100,51 +100,23 @@ "metadata": {}, "outputs": [], "source": [ - "def quantise(expr, quantise_to):\n", - " if isinstance(expr, sympy.Float):\n", - " return expr.func(round(float(expr) / quantise_to) * quantise_to)\n", - " elif isinstance(expr, sympy.Symbol):\n", - " return expr\n", - " else:\n", - " return expr.func(*[quantise(arg, quantise_to) for arg in expr.args])\n", - "\n", - "\n", - "class SymbolicFn(eqx.Module):\n", - " fn: callable\n", - " parameters: jnp.ndarray\n", - "\n", - " def __call__(self, x):\n", - " # Dummy batch/unbatching. PySR assumes its JAX'd symbolic functions act on\n", - " # tensors with a single batch dimension.\n", - " return jnp.squeeze(self.fn(x[None], self.parameters))\n", - "\n", - "\n", "class Stack(eqx.Module):\n", " modules: List[eqx.Module]\n", "\n", " def __call__(self, x):\n", - " return jnp.stack([module(x) for module in self.modules], axis=-1)\n", - "\n", - "\n", - "def expr_size(expr):\n", - " return sum(expr_size(v) for v in expr.args) + 1\n", + " assert x.shape[-1] == 2\n", + " x0 = x[..., 0]\n", + " x1 = x[..., 1]\n", + " return jnp.stack([module(x0=x0, x1=x1) for module in self.modules], axis=-1)\n", "\n", "\n", - "def _replace_parameters(expr, parameters, i_ref):\n", + "def quantise(expr, quantise_to):\n", " if isinstance(expr, sympy.Float):\n", - " i_ref[0] += 1\n", - " return expr.func(parameters[i_ref[0]])\n", + " return expr.func(round(float(expr) / quantise_to) * quantise_to)\n", " elif isinstance(expr, sympy.Symbol):\n", " return expr\n", " else:\n", - " return expr.func(\n", - " *[_replace_parameters(arg, parameters, i_ref) for arg in expr.args]\n", - " )\n", - "\n", - "\n", - "def replace_parameters(expr, parameters):\n", - " i_ref = [-1] # Distinctly sketchy approach to making this conversion.\n", - " return _replace_parameters(expr, parameters, i_ref)" + " return expr.func(*[quantise(arg, quantise_to) for arg in expr.args])" ] }, { @@ -205,22 +177,17 @@ " niterations=symbolic_migration_steps,\n", " ncyclesperiteration=symbolic_mutation_steps,\n", " populations=symbolic_num_populations,\n", - " npop=symbolic_population_size,\n", + " population_size=symbolic_population_size,\n", " optimizer_iterations=symbolic_descent_steps,\n", " optimizer_nrestarts=1,\n", " procs=1,\n", - " verbosity=0,\n", + " model_selection=\"score\",\n", + " progress=False,\n", " tempdir=tempdir,\n", " temp_equation_file=True,\n", - " output_jax_format=True,\n", " )\n", " symbolic_regressor.fit(in_, out)\n", - " best_equations = symbolic_regressor.get_best()\n", - " expressions = [b.sympy_format for b in best_equations]\n", - " symbolic_fns = [\n", - " SymbolicFn(b.jax_format[\"callable\"], b.jax_format[\"parameters\"])\n", - " for b in best_equations\n", - " ]\n", + " best_expressions = [b.sympy_format for b in symbolic_regressor.get_best()]\n", "\n", " #\n", " # Now the constants in this expression have been optimised for regressing across\n", @@ -231,14 +198,10 @@ " # and apply gradient descent.\n", " #\n", "\n", - " print(\"Optimising symbolic expression.\")\n", + " print(\"\\nOptimising symbolic expression.\")\n", "\n", - " symbolic_fn = Stack(symbolic_fns)\n", - " flat, treedef = jax.tree_util.tree_flatten(\n", - " model, is_leaf=lambda x: x is model.func.mlp # noqa: F821\n", - " )\n", - " flat = [symbolic_fn if f is model.func.mlp else f for f in flat] # noqa: F821\n", - " symbolic_model = jax.tree_util.tree_unflatten(treedef, flat)\n", + " symbolic_fn = Stack([sympy2jax.SymbolicModule(expr) for expr in best_expressions])\n", + " symbolic_model = eqx.tree_at(lambda m: m.func.mlp, model, symbolic_fn) # noqa: F821\n", "\n", " @eqx.filter_grad\n", " def grad_loss(symbolic_model):\n", @@ -264,8 +227,8 @@ " #\n", "\n", " trained_expressions = []\n", - " for module, expression in zip(symbolic_model.func.mlp.modules, expressions):\n", - " expression = replace_parameters(expression, module.parameters.tolist())\n", + " for symbolic_module in symbolic_model.func.mlp.modules:\n", + " expression = symbolic_module.sympy()\n", " expression = quantise(expression, quantise_to)\n", " trained_expressions.append(expression)\n", "\n", @@ -276,37 +239,37 @@ "cell_type": "code", "execution_count": 4, "id": "042fd565-825a-40fb-a4da-25e3e0da106a", - "metadata": {}, + "metadata": { + "tags": [] + }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Training neural differential equation.\n", - "Step: 0, Loss: 0.1665748506784439, Computation time: 24.18653130531311\n", - "Step: 100, Loss: 0.011155527085065842, Computation time: 0.09058809280395508\n", - "Step: 200, Loss: 0.006481727119535208, Computation time: 0.0928184986114502\n", - "Step: 300, Loss: 0.001382559770718217, Computation time: 0.09850335121154785\n", - "Step: 400, Loss: 0.001073717838153243, Computation time: 0.09830045700073242\n", - "Step: 499, Loss: 0.0007992316968739033, Computation time: 0.09975647926330566\n", - "Step: 0, Loss: 0.02832634374499321, Computation time: 24.61294913291931\n", - "Step: 100, Loss: 0.005440382286906242, Computation time: 0.40324854850769043\n", - "Step: 200, Loss: 0.004360489547252655, Computation time: 0.43680524826049805\n", - "Step: 300, Loss: 0.001799552352167666, Computation time: 0.4346010684967041\n", - "Step: 400, Loss: 0.0017023109830915928, Computation time: 0.437793493270874\n", - "Step: 499, Loss: 0.0011540694395080209, Computation time: 0.42920470237731934\n" + "Step: 0, Loss: 0.16657482087612152, Computation time: 11.210124731063843\n", + "Step: 100, Loss: 0.01115578692406416, Computation time: 0.002620220184326172\n", + "Step: 200, Loss: 0.006481764372438192, Computation time: 0.0026247501373291016\n", + "Step: 300, Loss: 0.0013819701271131635, Computation time: 0.003179311752319336\n", + "Step: 400, Loss: 0.0010746140033006668, Computation time: 0.0031697750091552734\n", + "Step: 499, Loss: 0.0007994902553036809, Computation time: 0.0031609535217285156\n", + "Step: 0, Loss: 0.028307927772402763, Computation time: 11.210363626480103\n", + "Step: 100, Loss: 0.005411561578512192, Computation time: 0.020294666290283203\n", + "Step: 200, Loss: 0.004366496577858925, Computation time: 0.022084712982177734\n", + "Step: 300, Loss: 0.0018046485492959619, Computation time: 0.022309064865112305\n", + "Step: 400, Loss: 0.001767474808730185, Computation time: 0.021766185760498047\n", + "Step: 499, Loss: 0.0011962582357227802, Computation time: 0.022264480590820312\n" ] }, { "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAagAAAEYCAYAAAAJeGK1AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAACHHUlEQVR4nOydd3hUZdqH73dKMpMOCSGh9xpC71VAlA4qCoKi2Nu66zY/3bVtU3d13XXtDWyggNIEQXoNECBAQg09hBJC+vSZ9/vjDZEqJeXMTM59XblIzpw558mQc57zPuX3CCklOjo6Ojo6/oZBawN0dHR0dHQuh+6gdHR0dHT8Et1B6ejo6Oj4JbqD0tHR0dHxS3QHpaOjo6Pjl5i0NuBKxMXFyUaNGmltho6Ojo5OJbNly5YzUspaF2/3WwfVqFEjUlNTtTZDR0dHR6eSEUIcudx2PcSno6Ojo+OX6A5KR0dHR8cv0R2Ujo6Ojo5fojsoHR0dHR2/RHdQOjo6Ojp+ie6gdHR0dHT8Et1B6ejo6Oj4JbqD0tHR0dHxS4LaQXl9+qwrnRtDH5Omo6M9fqskUV7sxU52tp9AftMWhPdoR6sh7ajZtRUiNERr03T8iLN2WHIAFh+AQ/lQ5IQiF3h80KwmJMVDUi3o11D9XJ1xeyHtFGw6DhuPw45T4PaBQH3VjoA+DaBvfeheDyL0S02nnAStgyrKs3G6Uzei03eQsG0d+e9BdkQ01sfupsmTt2EIs2htoo6GbD4O/94IKVngldAw3EOHREGExUhEqLrh7s2F1Udg9m71nr4N4P72MKARGIM69nAhNjd8vRM+2gYni9W2FrFwS1MIN4MEfBIO5qn9PkuDECNMage/6gY1rFparxPICH8d+d6lSxdZEVp8Xh+kZeSRsXgnlu/m0Xn/RuwxNan123uJf2A0why0PlrnMuTa4B/rYMuKTEbtX0oX2xESzhzDcOw4+CSmhFiMdeIxN0gkfGhfwob04pQnlNm74Yud6gbdKBpeHqAcVTDj8MCHW+HTbZDngP41S7g39gTtIuxEeu1It4eQlo0wNUhECFH2ni0nYM4emLUbIszwRFe4rwNY9EtN5woIIbZIKbtcsj3YHdT55Nrg80+30/Czj2h/dDvOTh1o+fVfMMbGVOh5dPyT79I9/PjBOm7dMIv2R9LAbMLctD4hTRtgbloPjEa82afxnMjBtfsQ3pyzGCLDCR/Rn+gHbsOQ1JLFB+DNFDiQB7e1ghf6BecKYecp+PUS8Ow5yKTcDfQ7koIpbSd4vJfsa4iLwdKpDWGDexA5YRgGSygA+3Lh1XWw7BA0rwmfjoIG0VX9m+gEArqDOo8dJyXfvrqEiTNeR9aKpek3/yC0TdNKOZeO9kgJn806RP0XX6JJzkFknQTiHr6NyLuHY6wRdfn3eL3Y122jeOYSiuevRNocRE4aQexzD+GJqcH/NsO7qRAdCq8NhpubVO3vVFl4fPDOZljyXQYPrvmE5P2bAQhp24ywQd0Jbd8SQ0QYIsyKMAicGZk4t+zCkZqB+8AxjLVjiXlqIlH3jCwLo688DE/9CGYDfDQSOidq+Avq+CW6g7qIQie88vYuxr3zHDFeG3U/fJGIW3tX2vl0tMHrk0z/8zw6ffpfvGHhNPjXb4ga1Q9hNF77MQqLyfvXVAo+moUh3ErNZx8kaspY9uQa+N1SyDgNz/aBRzpBaaQrIClwwgvvZ9J9+gd0P5CCqBlNjScnEDnuFkwJcb/4XikljnXbOPuvqTjWbcNYqybxbz9H2KDugFpxTpkLJ4rh9cEwplVV/EY6gYLuoC6DywsvfpNDn1efo8Xp/dT99g3C+nWu1HPqVB3OQjvLJ/yDFptWcKJDV7p/+SfMtW+8FM+17zBnnvsP9lWphA/rS/w7f8JlCeO3S2DBfhjfFv56E5iv3ff5DdkFPr76zQxG/vARIjKc+F9NIHrKWAwRYdd9LPv6NM489xauXQep+eyDxPx6EsJgIM8Oj/ygKgD/PQRua10Jv4hOQKI7qCvgk/DaoiK6/u4J6tlyaLL4XUJaNq708+pULj6Hk7XDniUhfSsH7n+YW/8xAWEof+mdlJKCD2eR++I7mJvVJ2Ha3zE1qc+bKfD2JuhVDz4eCeEBVGK9L+M0ux/4G0kHtmIf2J/W7/0eY83yJYt8Ngc5z7xO8eyfCBvah/j/PY8xKgKnB+6bC5uz4cux0KNeBf0SOgHNlRxUNSqWvTwGAX8cGsmSP75OoQjh0Lg/4DmVq7VZOuVAuj1suesl6u5MJfXJZxn62sQKcU4AQghiHhlH4sw38ObkcXzIwzhWbeZ3PeGNm9XqYMp8Vc0WCGQsTKNk6P00O7Yb91+epe2Mv5TbOQEYwizEv/dnYv/2NLYlG8ge8yt8RSWEmuD9EdAwBh5eoEJ/OjpXoto7KFBO6qW7E/jkkdfwnMnn8Phn8ZXYtTZL5waQXi+7p/yNmuvXsvDu33DXn4ZWynnC+nam3k8fYapfm5OTnsW2bCN3tIE3hsDGLHXzdfq5kzowYxWGB35LcUQMkQs/odWjw8vKxSsCIQQxD99Bwpev4tp9kJP3PY90uYkOhc9GgcmgVlO5tgo7pU4V4sk+zZnn/oN0uirtHLqDKiXMDC883or/jH8JmbGX7Bff19oknRvgyB//R+iPS5k17BHuff22Sm2oNTdIpM73/8XcoiEnJz+HbflGxraCVwfBqiPw5CKlvuCPHH3ve3y/+jOHE5tTZ8G7NGhfv9LOFT64B/H/eRb76i2cfurvSJ+PBtHwySg4VQxPLFKhdp3AwZtXSPadv6Vw+kLcR7Ir7Ty6gzqPulHwyO97M6fbHTinfYd93TatTdK5Dop+WIN32izm9RjH7f+eRHRo5Z/TWCOKOrPfwty8ISfvVU5qfJJq5F1yEJ5d5n+6fsdf/xz3C2+S2rwn9b9/i6ZNKr85KfLOW6n550cp/m4puS+9C0DHBPU5bciCL3ZUugk6FYSvxM6JiX/EfSib9c+9iq9Jo0o7l+6gLqJzIoT+7mGO16jLkSde1UN9AYLn5BmO/+o19iW0oNHfHqtS3TzlpP5d5qQcm9O5rz38urtSU/hwa9XZcjVyP5iF458fsSx5CE2//htJDatO8ivmqbuJfugOCt77hqKZiwFV+di/IfxjLRwtqDJTdG4Q6fZw6oEXcG7ZxfxHX+RlZ0eyiyrvfLqDugyP9LEwfeKzmI9nc/KVD7U2R+cqSJ+PY4/8Fa/dyZInXmRMO3OV22CsGU2dWW9iTKzFycnP4T56gqe7w9Bm6ua7/FCVm3QJRd/+SP6f/sPaln2p8/b/0a1h1WoPCSGI/cuTWLq148yzb+HJPo0QKiRqMsDvf9JDff6MlJKcZ17HtiyFvOd+x7+j+vF4F2hSo/LOqTuoyxBqgocf6cCcrrdh/3QW9vVpWpuk8wvkvzMD3/otfHDrr3jm7gaaNcsaY2NI/Po1pMvNyUnPQnEJbw6B1rXgVz/C/rPa2AVQsmgNp556la2NOnHqlRcZ0UYbYTxhNBL/v+eRHi+nn34VKSV1IuHP/SDlOHyuh/r8luJZSyiasYjIZ+7jN9EjaRwDj11SGF6x6A7qCnRMANfTj5Adk8jRp/+J9Ph5SVY1xbXnEGf+/hGrW/Un+YkR1L+8clGVEdK8IbU/+QuufUc49dBLWIWHj0dAqBEenKcUTKoax5YMTjz4EnsTWrD4t//gtwOqIDn3C5gb1yX25cexr9xM4WdzALizDQxoCK+uhaxCTc3TuQzuoyc488d/Y+mezLQ+93G0AP42sPIFgHUH9Qs8PTCMmaOfwnT4KPlfL9LaHJ2LkFJy8rn/YjNbWTL5d9zf0T90hsL6dyHu1V9jW5ZC7l8/oG4UvD8cjhXC88urtmjCc/IM2fc+z+nwWN55+HXeGBPmF6NCoiaPxnpTN3JffhfXgWMIAX8fqEJ8b2zQ2jqd85FeL6ef+BtSSmx//xMfpBkZ2wp6V17hZxl+8Kfqv4SZYcSjfcio25aT//gMn12Dx1+dK2Jbsh73mlSm9p3Cn0bH+MWN9xzR940h6v4xFLwzg5Il6+laF37dA+btU4UTVYHP4eTkfc/jLLDx0l3/4I07Y4j2kzFoQgji//Mswmwi53f/QkpJ3SiY0gG+3wMZOVpbqHOO/P9Nx5Gynbh//IY/708kzAx/6ls15/ajS9o/ubmpYOUdjxByJofcj2ZrbY5OKdLl5tSf/8fRuIYwYQxta2lt0aXEvvIkIW2bcfqpv+M5kcMTXaBHXXhhpRruV5lIKTnzhzdxbtnF30Y8x4Tbm9Lazz4jU2Itaj77II61W7H9pJZNj3WFaIsK9elojzM9k7Ovfkz4qJvY3uMWNmTBMz0h7volGm8I3UFdBSFg3H0d2di0O7n//hJvQSXWVOpcMwWfzEYeyuKDm5/k6T7+OQnPYAml9scvIx0uTj36Cgbp5a1b1LTZJxdVrtJE4affUzR9IV/3m4zh1gHc177yzlUeoiaPxtykHrmvvIf0eIgOhSe7wuqjsPao1tZVb6SUnHn+PxiiI4j75+94Y6OgTgRMaFt1NlSIgxJC3CqE2CuEyBRCPPsL+90uhJBCiEqu/ahYetaDtLsfxlxcxOn/TNfanGqP90weuf+cxuam3WkyugeNYrS26MqENGtArdefwbE+jbw3ppEYCf8crEJYb6RUzjmd6ZmceeF/pLfpwfc3T+FfN/vvGBBhNlHzT4/g3nuYohk/AnBvMtSLhL+v1cvOtaRk4Roc69Oo+ccHWF0QxbaT8FQ3VeVcVZTbQQkhjMA7wFCgDTBBCNHmMvtFAk8DG8t7Ti24Z3wLlrUdTOGHM/GcPKO1OdWas/+ciq/Ezse3PMGvumttzdWJvOtWIu68lbx/TcWxOZ0hTWFCEny0FdJOVuy5fCV2Tj38EvbwKP58y3O8PsRAfHjFnqOiCR/Rn9CuSZx99WN8JXZCTfDbXsqJz92rtXXVE+l0kfvSu5hbNSZy0kjeTIH6UTDukjt75VIRK6huQKaU8qCU0gXMAEZfZr+/AK8Bjgo4Z5XTLh6OTH4A4XJx4t1ZWptTbfGcPEPBF/NZ1H4YAwY3JiFCa4uujVqv/hpTnVqcfvpVfA4nz/WB+HD4/dKKDfWd+fPbuDKP8sLwP3FrtxoMDoBJv0II4l56HO+pXAre/xaAMS2hdZwaYaKvoqqegk++w3P4OHEvP8FPR03sPA1Pd6/6WWcV4aDqAsfO+zmrdFsZQohOQH0p5Q+/dCAhxMNCiFQhRGpOjv+V8Twwph7rWval+It5ugSSRuS//w3S42VOv4k8HkCBYkNkOLX+/Ufc+4+Q9/pnRIWqsup9uWrEekVQPG8FRV/M56eBd3O0TRee71Mxx60KLN3aET68P3lvf4X3TB4GAY92VuM4/EGFozrhPZNH3hvTCBvUA8tN3fl3CjSOgbEaTEGu9CIJIYQBeBP47dX2lVJ+KKXsIqXsUquWn5UcoSQ9jo69k5DiInKn/6i1OdUOb14hBVPnsrLNQAYPqEtNq9YWXR9hN3UjcuJw8t+ZjmPrLgY1httawTupsKucz2Oek2fIeeZ1Clq25p/dHuQvN+E3JeXXSs3nH0KW2Cn49HsAhjeHupHwwRaNDatmnH39M3wldmJfeYIlB2D3GbV6MmlQUlcRpzwOnN+yVa902zkigSRgpRDiMNADmBdohRLnGHZnO/YktuLUOzORPp/W5lQrCj6ZDSV2ZvSeyORkra25MWJfeRJjQhynf/UPpNPFC/0gJlSF+jw3+OckpSTnd//C53TzzM1/ZnALE0ObVazdVUFI84aE3dqHgk++w2dzYDbCAx1hUzZsq+Bcnc7l8WSfpvDL+UTdM4KQFo34NA3qRcGoFtrYUxEOajPQXAjRWAgRAowH5p17UUpZIKWMk1I2klI2AlKAUVLKyp/nXgl0qSvYPOROrFnHKFlaSWVYOpfgK7aR/8EsNrXsTavezairsaTRjWKMiqDWG7/Hvfcwef/+nBpWeOUmSD994yMnimf/hG3xOn4Y8SC58fX5y4AKNblKiXl8PL6zBRR9o5Rb7moLUSHwob6KqhLy3/sGfJKYpyaSkaMmRE9uj2ZN8OU+rZTSAzwJLAZ2A99KKTOEEK8IIUaV9/j+hhDQecpN5ETW4vBb32ptTrWh8Mv5yPxCPu85iQc6am1N+Qgf3IOIO24m7+2vcR04xrBm0K+BkvjJKbm+Y3lOn+XMc/+hJCmJf7cYx7N9oHaAFI5cDkuPZEI7t6HgvW+RXi8RITApGX48AEfytbYuuPGeLaDw8/lE3D4Yc4NEPksDqwnuquLKvfOpEL8opVwopWwhpWwqpfxb6bYXpJTzLrPvgEBdPZ1jaCsTy3vfhnXzFpwZmVqbE/RIp4v8d2awu2lHQrsk0SlRa4vKT+xLT2AIDeHMs/8GJC8PAIcH/rHu+o5z5o9v4rM5+PPNz9Iu0ciEpEowtgoRQhDz+Hjch7IoWaTkJO5rD0YBH+vzQyuVgo9nI212ajw1kVwbzNsLt7fWNpepK0ncACYDJDwwCrvZwuH/zNTanKCneMEqvCfPMLXbxIBfPZ3DVDuWGs8+iH3lZkrmr6RJDXi4E8zeDZuPX/XtgKraK1mwiu3j7ictvCEvDwCDnzbkXg/hw/thalSH/HdnAGpFOKYVfLsL8gOyScX/8RXbKPhoFmFD+xDSqjHT08Hphfs6aGuX7qBukDt6RrEqeQjyh2X4iq4zLqNzXRROnUturbpkt+sakMn/KxE9ZQwhSc0586e38RXbeLIb1ImAP6+8esGEr6iEM8/9B9mmOX9sNJ4726gRMcGAMBqJefQunJvTsW9UibkpHdQK8/s92toWrBR+MQ9ffhE1np6E2wtf7FRh5+ZVOJn6cugO6gaJCgXvbcMwu5ycnr1Ca3OCFteeQzhStjMreRT3djRoUupaWQiTiVqvP4P3RA55b0wlzAwv9FdlvVcrmDj7+qd4T5/lo1G/xRJi4o+9q8bmqiJy/FAMNaLKGnfb1IL2tWF6etWOK6kOSKeL/He/wdKnE5bObfnxAJws1n71BLqDKheDxrThaGwDsqfps6Iqi8Jpc/GazCzrMJQ7q1CksqqwdE0i8u7h5L//La7Mo9zaFPrUh7c2QsEVwlnOjEwKPppNwZhRzDC1rVJ16arCEG4lcvxQSn5ciydHSb+Pbwt7c2H7KY2NCzKK5yzHe/IMNZ6eBMDUNGgUDTc10tQsQHdQ5aJtvCCt51Ci03fgPpiltTlBh6/ETuG3i1nfdgBdkmoEXGPutVLz+YcRllByX34PIeD5vso5vX0ZhQnp83Hm929giI7gTx0epnlNuCdAe8KuRtTE4eDxUjxzMQAjW6iqshkZGhsWZBRMm4u5eUOs/btwMA9ST8CEdv6Rz9QdVDlJmHQLXmHgwCf6KqqiKZ6zHFlYzKz2o7lTw1LXysYUX5MaT0/C9uNa7Gu30qYW3NEGpm2HowUX7ls0YxGOzensvv8xMtxRPNdHmw7/qiCkZWNCu7Sl8KsFSCmJDIURLVR1WYlLa+uCA2d6Js7N6UTdOwohBLN2K8ekhazR5QjSP+2qY2jvWmxt2hXX7B+RXq/W5gQVhdPmklOnESdbJtO/kdbWVC7Rj9yJqV5tzvz5f0ivl9/3VKXVr55Xdu4tKCL3lfcwdW3H89FD6VXPP8IwlUnU3cNx7zuCM1Utm8a3hRI3LNivsWFBQuHncxGWECLvuhWvD77brYojavuJAr7uoMpJjAVyhgwlPPc0Bav0Ro2Kwrl9L85tu/mm3WhuayOCdpVwDoM1lJp/fhRX+n6Kvl1M7Qh4pDP8sB9Ss9U+eW9Mw3e2kEUTfsNZl4Hn+vrvnKeKImLsIESYlcKvFgDQORGa1dTDfBWBr9hG0beLiRg9EGONKDZkwYlitXr3F4L8sq8aOkzsQ3FoBAc/Wai1KUFD4Rfz8YaGsrjdLVU+g0YrIsYOIrRTa87+7UN8JXYe6axGcvx1DbgOHKPg49kYxw3jzYLmjG2lRsAEO4aIMCJG30Tx98vxFdsQQq2itp5QSvA6N07xd0uRJXaiJqvpSLN2K1mpm/1oRIvuoCqAXs1C2dhxEGErVuEtLNbanIBHOl0Uz1nG5qT+tGgUqXkvRlUhhCD2lSfxnsol/70ZhJnhtz2UUOruP7yLCDHzcf+HAPhdT42NrUIiJw5H2uwUz1kOKAV4s0FfRZUHKSUFU+cQ0rYpoV3aUuSERZmqEMVShRNzr4buoCoAowHMtw/F7HaRNWuV1uYEPLZlKfgKivm+xZCgLo64HNbuyYQP70f+OzPw5uZzRxsYemYrEavX4npwEl+eiOW+Dkphurpg6dYOc7MGFH6txsnFhsHAxjB/H3j1gQI3hDNtD66d+4maPBohBAszVSP07X52vekOqoIYOKYNp6Jqc2K27qDKS9Gsn7BH1yCjWWdGaCTzryU1/+9BZImdvLe/wii9PLX0bU5GJ/B/9e8iIoSAGtRYEQghiJw4HOfmdFyZRwEY0RxOl8DmbI2NC1AKp85FhFmJvGMIoCS2msRAJz9TI9EdVAXRMEawq/MAamzZrIf5yoG3sJiSJetZ0WYgA5uZiArV2qKqJ6RlYyLG3ULhJ9+R/963WDIzmTvyUbbmhzKlgyrMqW5E3n4zAMVzVZhvcBPVEzV/n5ZWBSY+m4PiucuJGDMQQ2Q4RwvUWI3bW/tf0Y3uoCqQiJEDMHk9HPv+OiWpdcooWbAKnC4WtLqZ4c21tkY7av5xCtLrI++fnxLauQ0ZXQcCEFkNHTaAKbEWlu7JlMxVsmJhZhjUWOVNbnTQY3WlZPFaZImdyHFq9XTOyftL79P56A6qAuk9sg05kbU4MXul1qYELMWzf6Kwdl2ONGgT9D0+v4S5QSKhHVohbQ7O3HY76TmCelHw4Vawu7W2ThsiRt+Ea/dBXHsPASqhn2uHDcc0NizAKJ71E8Y68Vh6dQBgYaYSGvbHIaC6g6pA6scY2N25PzU2b8JXbNPanIDDc/IM9jVbWdzmZgY3FVjNWlukHd6CInUjNhg48P0GaoXBqwNV3mXaDU7eDXTCRw4AISguXUUNaAQRITBfb9q9Zry5+diWbyTytkEIg4GjBWqas79OCdAdVAUTPmIAZo+Lw3PWa21KwFH8/VKQkvmtbmaYn14wVUX+f79CFtkoGTGUDqlL+WP8Ifo2VF3+H24BWzVcRZkS4rD0bE/x3OVIKbGYYEgTFeZz6SIu10TxvBXg8RJxuwrvLSydt6o7qGpC79FJ5EbUJFsvN79uimb9RE6TVpxNaFCtw3uek2co+GgWEbcN5tVej2EPsdJ7/lQAnu6uwlpXG8cRrESMHoh73xFce1SYb0QLKHTCmqMaGxYgFM9cgrlVY0LaNgVg0X7V8N0gWmPDroDuoCqYujFG9nbqT83NKfhK7FqbEzC4DhzFtWMfP7QazKDGVOvwXt6/piI9Xg5OfpB1xdHk3n479nkrcO09RJc6ahzHh1uqZy4qfER/MBgoKa3m69sAokP1ar5rwX0kG8fmdCJvvxkhBFmFkHYKv45W6A6qEggfMYBQl4ND8zdqbUrAULJgNQCLmg6o1tV77kPHKfxyAVH3juKNrDokREDX5+9ChFnIe2MaAL/uDmfs8OVOjY3VAFN8Tay9OyileykJMcKtzeCng6rRVOfKFM9eCkBEacn+j34e3gPdQVUKPccmkxcWQ/ZMfdLutVKycDU5TVpREle7Wof38t6YiggxcXD8vWzOhkc7Q1jtGKIfuI3iOctx7T1E17rQqx58UF1XUaMH4j5wDFfGAQCGN4diF6zTq/muiJSSotlLsPRoj7m+6sZdmAlt4qBxDY2N+wV0B1UJJMaY2N++NzGbNiHd+mPd1fCcyMG5dTc/Ne3LoMb+pQVWlbgOHKVo5hKi7hvDvw/EUisMJiSp12IeH4+wWsh783MAft0DcmzwdbqGBmtExPB+YDSWNe32qKuq+ZYe1NgwP8aVnol73xEi7lCrp5PFsOUEDPXzaIXuoCqJ0EE9CXMUk7WyGsZhrpOShWsAWNK0n1+HGyqbvH9NRVhCOHzH3WzIgse6/OysjbGlq6jvl+Had5judaFnPXgvtfqFtoxxNbD0bE/Jj2sBCDVB/4aw9BD4pMbG+SnF81eCwUDEiP6AqnwE/84/ge6gKo3ksV1wG0wcnLNBa1P8npKFqyms24AT8Y3o31Bra7TBte8wxd8tI/qB2/jPgZrEWeHupAv3iXn8LrWKeusLAJ7qplZRs3ZpYLDGhN/SC/eeQ7iPKDG+wU1Uj9jOUxob5qeULFyNpVcHjLExgMo/Na+pZmv5M7qDqiSa1g9nf5NkQtamaG2KX+PNK8S+Lo21rfrRo54K1VRH8v41FWENJWvcBFYfhYc6XVrJaIyrQdS9Iyn+bhnuI9n0qqcUAN7fUv3kfsKH9AagZLHqNxzYSE0g/umQhkb5Ka7Mo7j3HlahUaDAoUR2b2mqsWHXgO6gKgkhwN67J/HZhyg8eFJrc/wW25L14PUyv2E/BjXW2hptcO4+SPGc5UQ/eDvvHoghKhQmtrv8vjGPjweDIP+dGQihlM2PFVa/Mmtzk3qYmzfEtkTpXsZYoGsdPQ91OUp+UBWy4cP6ArDqCHilGlni7+gOqhJpOLoXABkz9TDflSj+YRWOuHj2Jraqtg4q/81pqox8wngWH4B7k68sCmtKrEXkXbdS9PUPeE7lMrgJtIhVuajqln8Jv6UX9vVp+IpKABXm231GOWydnyn5YTWhnVpjqqNGMC8/DDWt0KG2tnZdC7qDqkQ69qzPyRp1sC3THdTl8JXYsa/YxLakPjSPFX7bzV6ZuDKPUjx3BdEP3MYHB6MJNcGUDr/8npgn70a6PRR8MBODUMUUe3NhWTULb4UN6Q1uD7YVmwAYXPqAo6+ifsZz/BTObbsJH6bCe14frDwMAxqqQav+TgCYGLiEmAQnuvQkMX0rHptTa3P8DtuKTUiHi+/r9yu7uVQ38t/6AmEJwTnpLr7fA3e1VRNjf4mQpvWJGDWAgs++x1tQxKgWasLuO5tBVqNVlKVrWww1osryUI1rqKS/7qB+5lyFbPgI5aC2nYQ8R2CE90B3UJVOzC09CXU72bVgm9am+B0li9bijYxka/321TK85z6STdGsn4i6dxQfH6uBT6riiGsh5ul7kMU2Cj/5HpNBNfRuOwkbsirXZn9CmEyEDeqObdkGpFepxd7cGFKOK30+HSj+YTXmlo0IadoAgOWHVDFJvwCpltUdVCXTZXQH7GYLJ3/Qw3znI30+bMtTyGzXg6gwEx0Ttbao6sl/+2swGhAPTGB6OoxuCfWvcSZPaFIzwgb3IP/Db/HZHIxrA7XClLpEdSJ8SG98uQU4UjMAlYfy+FQhQHXHm5uPY8N2Iob3L9u2/LAqJokOkMGXuoOqZGJiQjncqjPRKRuQ1Sn+chWc2/fiO5PPj3W6c1MjMFWzv0RP9mkKpy8k6u5hfHm6Fja3WgVdDzFPTcSXW0DRN4uwmGBye1h5BPaeqRyb/RHroO5gMmIrDfN1TIBYa/XLx12Okh/Xgc9HeGl5eXaRKiK5KYCiFdXstqAN4qaexJ09wbGth7U2xW+wLUtBCsHK+t2qZXgv/3/TwefD+vhEpm1XfTwt467vGJae7Qnt1JqC975Fer1MagdWk5q6W10wRkVg7dmektJyc6NBKZyvOVL9qhovpmThakz1Ewhpp/SMlpc67YGNtLPpetEdVBXQckx3APYvqGbxl1/AtjSFvGatKImsETDx8IrCeyaPwi/nE3nHEOYVJ5Jrh4evc/UEIIQg5okJuA9lUbJoLTWsqshi7l6ltVZdCBvSG/few2WqEv0aKrX3XTkaG6YhPrsT+5othN/SGyEEoMJ79aOUgkSgoDuoKqBpUgKnYuviWZuqtSl+gTc3H+fW3aQ07UmnhMCJh1cUBR/NRjpcRD01kY+2qoFxPere2LHCh/fD1KgO+e9MR0rJAx1VE+ZnaRVqsl8TNkg9ANpXqeurr6oHYHU1zkM51qch7U7Cbu6pfvYotfeBjZWIQKCgO6gqQAjI7dCFOru34XZVM2XPy2BbsQmkZEFi92q3evIV2yj49DvCh/ZhrakhB/NV5d6N3jSE0UjMo3fhTM3AsWknDaKVAOhXO6GomlSymZs1wFgnvqwfKj5cjZFYVY2n7NqWbkBYQ7H06gBASpZyUoEU3gPdQVUZUQM6E+a0sWv5Hq1N0RzbshQ8MTHsrdOq7Gm3ulD45Xx8+UXEPDWRD7dC3cjyK0pHThiGoWY0+e9MB+CRzlDkgukZFWBwACCEIGxAV+xrtpSVm/drCFuyocSlsXEaIKWkZOkGrH07Y7Co8MTqoxBqhB71NDbuOtEdVBWRNLIjPgTHf6zeYT7p9WJbvokDSd2IshhoF6+1RVWHdLnJf+9bLL06sKdeWzYeV6oRZmP5jmsIsxB9/xhsP67DlXmU5NoqZPjpNnB7K8R0vydsQFd8BcU409QDYN8G4PaplUN1w33gGJ7D2YQN7lG2bd1R6FIn8Gat6Q6qioitG8Px+s0xbarehRLOtD34zhawpF5PetcPDLmViqLou6V4s08TU5p7igqB8UlXf9+1EPXA7YgQMwXvfwuosOGJ4p/n/gQ71n6dQQhsKzcDqtfHaqqeYT7bUtVzGTZIOajTJbAnl4CMVlSj24P22Lt0psGBdAry7Fqbohm2pSlIg4FldboG5AVzo0ifj/y3vyKkbVPOdu3OwkyY0K7ixouYatUg4o6bKfr2R7xnCxjYGBrHwCfbqof8kTE2htDkFthXqghFqEmFs6pjoYRtaQrmlo0wN1Dd72uPqe196mto1A2iO6gqpPbNXTD7POz4YYfWpmiGbdlGClu1oTAsmj7VyEHZlqzHve8IMU9N5PMdAgFMTq7Yc0Q/cifS7qTw83kYBNzfAdJOqdHe1QHrgK44UtPL1M37N4RD+XC0QFu7qhJfsQ37+jTCS6v3QIX3aligbQCG03UHVYW0uTUZt9FM7vLqmYfynsnDuW0325p3p1E01Uq9PP+dGZjq1UYMvYkZ6TC0GdS9RlmjayW0dROsA7pS8Ml3SJebO1pDVCh8mlax5/FXwgZ0BY8X+zqle3luhb6mGoX57KtTwe0hbLByUFKq3793fTAEUHn5OXQHVYVYIi1kN0siakv1zEPZ1yiJg3nx3ehbjcrLHdt240jZTvQj45i930ShCx7oWDnnin54HN6TZyiet4LwEDU2flEmZFWDGUmWrkmIMAu2FSoP1bSGqpKsTrp8JUtTEBFhWLqpiZf7z8KpEgI2WqE7qCpG9uxMw+z9HD6Qp7UpVY5tdSq+iAh21GpZrfJPBe/OwBAVQcTEkXyWpvTiOlWSOG7YoO6YmzWg4P1vkVIyuT0IYOr2yjmfPyFCQ7D26oB9lXJQolS1e/2x6lHNKKXEtjSFsAFdEWZVrre2dPUYqNdbhTgoIcStQoi9QohMIcSzl3n9GSHELiHEDiHEMiFENXp+vpBGw7oAsPuH6jV+Q0qJfVUqJ5I6IYxGegZYP8aN4j56guJ5K4m6dySrzoRxKP/qAwnLgzAYiH5kHM7te3Fs3EmdSBjWHGakQ3E16Amy9u+K+8Ax3EdV4q1vA9UTtvO0xoZVAa5dB/CeyCkL74EqkGgUreaFBSLldlBCCCPwDjAUaANMEEK0uWi3bUAXKWUyMAt4vbznDVQa92mJLTS87CmvuuA5nI3n2Ek21O9ChwSVG6kOFHw4EwyC6Ifu4NNtkBCh8k+VSeS4WzDERJaVnD/QUd2kZ+6q3PP6A2E3dQUou77OSUhVhzlZ9tIS+7CB3QC1akzJCtzwHlTMCqobkCmlPCildAEzgNHn7yClXCGltJX+mAJUk+fnSzGYTZxu24FaGWnVSm3ZtloVhsyP6xyw4YbrxVtQROGXC4gYO4gDIfGsPaYq98rbmHs1DOFWou4dRcmiNbiPnqBjggorfr49+BW+zS0aYUyIw7ZK5Xljw6BVrArzBTu2lZsxt2iIKbEWoAZYlrgDN7wHFeOg6gLn//dnlW67Eg8AiyrgvAFLSPf21MnNYveu6jO4x74qFXd8PMdq1g/IfowbofCL+cgSOzGPjeez7aqLf0IFNeZejaj7x4IQFH72PQD3tYeD+cHfFySEwNqnI47128rmr/WsD6knwBnEMpg+hxNHynbC+nct27b2qKrc6xnA11uVFkkIISYBXYB/XuH1h4UQqUKI1Jyc4NXKbzykAwAHllSDzDWqSdW+ditH2nbGYhYk19baospHuj0UfDgLa99O2Js35/s9MKYl1LBWzfnN9WoTPrQPhV8uwGdzMKy5mrhbHVTOrb074c3Jw71feeOe9ZRQ6vZTGhtWiTg2pyMdLqwDupRtW3sM2tcO7GkBFeGgjgPn++h6pdsuQAgxGHgeGCWlvKzOspTyQyllFylll1q1alWAaf5J3R7NsYdacW5I09qUKsG1cz++vELW1u1C50TV5R/sFC9YifdEDtGP3sU3GeoGeV/7qrUh+qE78OUXUTx7CSFGmNROTdw9FOQFpNbeqobfvlYVIvWoqyoZgznMZ1+5GUxGrL3U725zK4fcK8CTKRXhoDYDzYUQjYUQIcB4YN75OwghOgIfoJxTNain+WWEycSZlu2IzdiO16e1NZXP+fmn6lK9V/DBTMxN6hE6sAef74DudaF1FT9zWXq2J6RtUwo+no2UkrvbgdkA04JcyMTUqA6muvFlDbvRpSoKwVwoYV+ViqVzWwwRYYBScvf4oHuAX2/ldlBSSg/wJLAY2A18K6XMEEK8IoQYVbrbP4EIYKYQIk0IMe8Kh6s2mLu3p+HpQ+zan6+1KZWOffUWnE2akBcRG3By/zeCIzUD55ZdRD88jhVHDWQVwuQqXj2BysdEP3gHrl0HcaxPIz4chjdX1XzBXHIuhMDSuyP28/NQ9WDrSbWSDTa8Zwtw7th3QXgv5TgYBXSppH67qqJCclBSyoVSyhZSyqZSyr+VbntBSjmv9PvBUsraUsoOpV+jfvmIwc+5PFTmkuB+nD2XvD3QujNWk4qJBzsFH3yLISqCyLtuZWoaJEbALU21sSXi9psx1Iii4KPZgNLnK3bBrN3a2FNVWHt1xHcmH/few4ByUC5vcOoS2tdsBSkvKJBIyYLk2hBeQWLEWqErSWhEYs9WuEwh2NenaW1KpXIuebsisQtd6kBIJZdYa40n+zTF81cROWkEB1xhrD2mcj8mja40gzWUqEkjVcn5sZN0SIAOtWFaWnCXnFv7lOahSsN83eqoFUUw5qFsqzZjiAwntGMr9XNp/ikYwum6g9IIERrCmZZJxGWk4QniPJR9VSqYjCys0aGsaTKYKfjkO5CS6AduY9p2NcW0qkrLr0TUlLEAFE6dA6hw48F8pXIdrJgaJGKqVxv7WqX/GBmqVhTBloc6p9Bi6dMRYVLVR1tPqGGNwRBO1x2Uhpi6tafxiUwyMou0NqXSsK9Pw96qNfbQsIDux7gWfDYHhZ/PI3xYXxy1E/luN4xooZpFtcRcrzbht/am8KsF+BxOhjWHmlb4PIijy0IIrL07Yt+QhvSpJ8Ce9dTKIpjGwHsOZ+M5euKC8N6GLLVa7Bzg+SfQHZSmNLy5AwYk+5fu1NqUSsFXbMO5bTf7mnckzAzJATiP5noomrkYX34R0Q+PY/Ye1cVf0TOfbpSoKbfhyy2gZN4KLCYY3xaWHoLjQaxybundEV9uAa49hwDloDw+2JytsWEViK1U0sna/+cCiY3H1WqxooZhaonuoDQkoXcbPEYTtvXB2bDr2JwOHi8r4zvQtU7lS/xoiZSSgo9nE5LUnNDuyXyxQxWEtE/Q2jKFtV9npXL+qVKWmKimMfBVcD4bAT/3QzlK81Bd6qgy+2AK89lXpWKqG4+5qQpP2N2QdpKgCafrDkpDDGEWcpu1JjYjLSjHAdjXbQOTkcXRSUFzwVwJx7ptuPccIvqh20k5Lsg8C/f4yeoJVMgr6v6xOLfswrl9L/WiYFBjmJERvBJA5gaJmBokYl+XBqBW8bVhU5CsoKTPh33dNqx9OyOEmka4JYjyT6A7KM0xdm9Ps+y97Dxsu/rOAYZjfRq2Vq1xhAR//qng49kYakYTMXYwX+yAGAuMbKG1VRcSOf5WRJhVFXIA9yZDrh1+yNTYsErE2quD6ocqzUN1qwM7TqmVRqDj2nUQX14h1j6dyraV9T/V0dCwCkR3UBrTYGAHTD4vmcuDaxaCr8SOY9tu9jfrQJgZkoJXuQr3sZOULFpL1KQRnPaGsvgA3NlGicP6E8aoCCLHDaH4+6V4zxbQpwE0jlEq58GKpXdHfHmFuEr7obrWVXmobSe1tasiOFehaOn983jmlCxoFx8c+SfQHZTmxPdti08ISlKCKxlwLv+0unZHOicGd/7pnGJ41P1jmZ6u+ov8Kbx3PlFTxiIdLoqmL8QglJ3bTgbvQD9rDyXh4UhRXrhLHaXLFwyFEvZ12zA1qoO5nup+L8s/BUl4D3QHpTnGqAjy6jUmKmNnUDVOnss//RidRNcgCTdcDp/dSeGXCwgf2hcSa/N1OgxoBA2itbbs8oS2aYqlR3sKps5B+nzc0QasJvgySEvOTQ0TMSbE4UhRv2B0KLSOU5VugYz0enGsTysrBAEl5eT2Kd3HYEF3UH6Ar2M7mh/LIDMneColHOu24WjVCntIGN2C6IK5mOLZP+HLKyT6odtZchBOl/jv6ukcUVPG4DmcjX3FZqJDYVRLmLsXCi87YyCwEUJg6Z6MPWVHmS5f17qlzawBfLm50jPxFRZfkH9KzVarw2DJP4HuoPyC2v3aEe6ykb72oNamVAjn8k+HWnTAbFDSOsGIlJKCT74jpHUTLL068MUOqBcFAxpqbdkvEzG8P4a4GApKlSXuaQd2D8wOUn0+a49kvNmn8WSpgVDd66rfNyOAR86dyz+d76A2HVerw6gAnv90MbqD8gPq9ldNKWfXp2tsScVwLv+0NrEjSfFgNWttUeXg2LQTV/p+oh+8nQN5gg1ZMDEJjH5+VYkQM1F3D8e2ZD2e46doV1v1bH25E2QQhZnPYbkoD3Uu5BzIYT77um2Ym9bHlBAHqMKPrSeDa/UEuoPyC8wNEymOiSVke3AUStjXbQOjkQWR7YI6vFf42RwMkeFE3H4zX+9UTaB3ttXaqmsjavJokJLCL+YDStA282xg37SvREjrxhiiIsryUPHhqnoxUAslpMeDfcP2C1ZPu3OUSGywXW+6g/IDhBDY27Wj6aGdZAWB9IxjfRrutq0oMoXRLcie6M7hOX2W4nkriBw/FGeIlZm7YWgziNNYd+9aMTdIJGxQdwq/XIB0exjZQoWGvgyOZ6QLEEYjlm5J2Df+XAnStY4KiQViYZJzxz5kse2CAolzzceBPv/pYnQH5SfE9G5HQsFJtm4P4MA4SjDVsW03R1t0AIIv5HCOoi8XgNtD1JSxzN+nCgzOyQcFClH3j8F7KpeSRWuwmmFca/gxE3JKtLas4rF0T8a99zDeswWAWmkUOGF/rsaG3QDnRtmf3/+0OVvlPxMjtbKqctAdlJ/QYIC6u51cFdiPsI4tGeD2kFKnAy1ilaJCsCE9HgqmzcXavwshzRrw1U5oVjPwynvDBvXAVD+hbAzHxGRVpvxNhrZ2VQZleajSVdS5/6tAlD2yr92KuWUjTPE1AZU3TM0mKKMVuoPyE6zJzXGFWGBrgDuolB0gBPMjkoLyggEoWbweb/ZpoqeMZedpSDulcjilcmgBgzAaibpnJPY1W3FlHqVpDehVD75OB2+QzSizdGyFCA3BXpqHqh8FtcNVmC+QkG4Pjo07sfb6efV0OB9ybARlv6HuoPwEYTZR1Ko19TJ3ctautTU3jmNDGr6WzThtiAjKCwaUcoSpbjxhQ3rx1U7V6Hpba62tujEiJ44As4nCaXMBmJQMx4tg1RGNDatgRGgIoR1bl62ghFBhvk3ZgVW56Ezbg7TZyyYGw8/FHsEYTtcdlB8R1q0dzU5mkpoZmMKx0u3BsWUXJ1qpTtVgqygCcGUexb4qlajJoynympi7V4nCRgdo74kpvibhw/pR9M2P+OxObm4CcVa1igo2LN3b4dy+F1+JegLsWgdOFkNWAM0Lta9PA5QI7jk2Z0MNCzSvqY1NlYnuoPyIegPaYZReDq8OzI5J5469SJuDLXXbUzcS6gRZwhZUaTlmE5GTRjJnjyrtnRRgxREXEzV5FL68QkrmryDECHe1hWWHIDuAbtzXgqVHe/B4cWxVwsznVhypAZSHcmzYjrlFQ4xxNcq2bc4u1RgMsBDztaA7KD8isocSjnVtDsw81Lk+kwVR7YMyvOezOSj6ZhERIwdgjKvB1+nQtpaaMRTIWPt0wty0PoXT5gEwPkmFvYKtWMLSLQmEKPs7bRWrVL8DxUFJrxf7xh1Ye3Yo25ZTAofygzO8B7qD8iuM0ZEU1W9M3J50HAE4RM6+YTs0qkemsWZQOqji75fhKygm6r4xbDsJu8+o0vJAf3IVQhB17ygcm3bi3H2QBtHQr6EaZugJomIJY1QEIW2a4tikHgCNBuiYoHT5AgFXeiay2IblovAeBGcFH+gOyu8QHdvS6ngG208E1p1B+nw4Nu4gt40q5w3GJ7rCqXMwt2qMpUcyX++EcDOMbqm1VRVD5F23IkJDyoolJrZT+ZnlhzQ2rIKxdEvCkZqB9Cql2C6JsCcXigJAKNe+IQ0Aa8/2ZdtSsyHUCEnxGhlVyegOys9I6NuWSEcxezYd1dqU68K19zC+/CLSG7YnMgRaxGptUcXi2LYbZ9oeou8bQ6FLMH+/ck7BMhjOGBtD+Mj+FH+7GF+JnUGNVRn2V4EZbb4ilm7tkMU2XLuV5+2cqNQkAmGAoWPDdkyN6mJK/Hn6Z+oJtQoMCdJ5a7qD8jNieyUBkJ8SWAkAxwYlxLm0Zns6JYIhwMNeF1P42RxEmJXIO2/hu93g8MDdAV4ccTFRk8fgKyqheM5yTAZVLLHqCBwLAvmtc1i6qf80R2met2Pp36q/56Gkz6f0985bPdndSpG9cxBGK86hOyg/w9y0Po7wSEJ3pgdUf4YjZTuG2nGsJ5HOQaYH5s0vonjOMiLHDUFEhPN1OiTHq9HawYSlezvMLRuVhfnGt1X5tRlBVHJuqp+AsXZsWR4qIgRaxcEWP89DufYexpdXiOU8B7X9lMoRBtv1dj66g/IzhMGAI6ktzY5mcCBPa2uuDSkl9pQdFCcnI4UIugumaMYipN1J1H1jSD0B+3IDT3fvWlDFEqNxbtuNc8c+6pbOtvp2V2AP9zsfIQSWbu3USJhSuiSqEJ8/F4Sci1Cc3/90zql2StDAoCpCd1B+SHT3tjTMOcy2fcVam3JNeI6ewHsih/1NO2AQ0CGILhgpJYXT5hLaNYnQpGZMT1dP3SNbaG1Z5RB55y0ISwiFX6oxHHe3U1OClx/W1q6KxNItCc+RE3hOngFUQU+JG/ac0diwX8C+Pg1jnXhMDX5++ttyAprWgBpWDQ2rZHQH5Yck9k3CgCR73S6tTbkm7KVPd2vik2kdFzyFA6BG17szjxJ932gKHLBgH4xpCeFB9DuejzEmkvBRAymauQRfsY2bGkFCRHAVS5TloUrDfOdW/Kl+GuaTUuLYkIa1V3tEaU+DTyoHFYzVsuejOyg/xNq5NVIIfFsDo1DCsXEHhugIlojGQRfeK5g6F0NMJOEjb+K7PeD0Bl9xxMVETx6FLLb9XCzRBlYHUbFEaLsWCGtoWZivbqRywlv8tFDCfTAL7+mzFzToHsyDfEdw559Ad1B+iSEynOKGjUnITA8I4VjHxp24O7Sj2GsIqic6z+mzlCxcTeT4oQhLKNPT1Wj0trWu/t5AJrRrEuZWjSn8XClL3FU6JThYlCWE2URoh1ZlKyghVB7KXwslHKX6e+cXSJyzVXdQOppg6pxEm+O72HrcjzO3gPdMHu79RzjWTAnEBtMFUzR9oRpKOHk0W0/C3lyYkKS1VZWPEILoe0ddWCzRSDkofy4kuB4s3drh3LEPn80BqFDZ8SI44Yf6g/aU7Rhr1cDcrEHZttRsNWutSY1feGMQoDsoPyWxb1sinMXsS/XvuQfnwiSbE5NJiFDhkmBA+nwUfj4PS59OhDRrUKYcMSpIiyMuJmLchcUSE0uLJZYFibKEpVsSeLw40/YA/p2HcmzYjqV7cln+CdQKqlNC8PUbXozuoPyUyO7qUb3Izxt2HZt2QoiZHy0t6ZwY+Lp057Cv2Izn6AmiJ6viiPlBXhxxMZcrlggmZQlL1wsLJdrUUnO9/K1h13P8FJ5jJ8smAgPk2eFAXvAXSIDuoPwWc9P6OCOiiNiVjsuPe1DsG3diSG7FIUdoUIX3CqbNwVirBuHD+vL9XlUcMSHIiyMupqxY4vtlZcoSq49AVhAUSxhrRGFu0bDMQZkM0N4PhWPPTQA+X0Fia6ksUzBdb1dCd1B+ihACd7s2tDqWQUaO1tZcHp/diTNtDzmlAwq7BMkF48k+jW3xeiInDAOzmek7lRhnsClHXI2yYonSMN+5Yolv/XtRf81Yuibh2JyO9KnEWudE2HVGSQj5C46U7Rgiwwlp27Rs25ZsMApVsBPs6A7Kj6nZM4lGZw6zY78fZm4B57bd4Paws147LCYVJgkGCr9cAFISdc8otp1UatcTq0FxxMUIIYi6ZxTOrbtx7txPvSg1huObXcFRLGHp1g5ffhHuA8cA6Jygfq8dpzU27DzsKTsI7ZqEMP6sBrvlhKoktZo1NKyK0B2UHxPbWz2ynlzvnw27jo0q/LAsJon2tcEcBIrK0uOh8MsFWAd0xdyoDtPTIcwMo4JkrMb1EjluiCqW+EKtoiYkqTEcqw5ra1dF8HPDrir06VgaAfCXfijv2QLcew5h7ZFcts3thbRT1SP/BLqD8mssHVXDrkjzVwe1E1OLRmyyRQeNHphtaQreEzlETR5NkVMVR4xqEVzqGNeDsUYU4SMHUDxrCb4SO4MbQ60w+DoIBGTNTetjqBFVpmxe0wpNYn7O8WjNuQfA8wskdp9RSvqdgiScfjV0B+XHGCLDKWnUmPoH08n2syif9PlwbE6nuF0ybl/wSP4XTpuLsXYs4UN6MXcv2D3Vo/fpl4i6Z5QawzF3OWYj3NlGafOdDAypyCsihMDSpe0FwrGdSxt2/WGSgD1lB4SYCe3YqmxbdWnQPYfuoPyc0M5taZ3tfw27rj2H8BUWk9lYhUk6BsEKyn3sJLZlG4maOAJhNjE9A9rEVY9k9C9h6ZGMuXlDlZtDFUv4ZHAoS1i6JuHedwRvnipN7JQIZ+1wpEBjw1AFEpaOrTFYQsu2bT0BiRFQJ0j6Da+G7qD8nNq91YTd/Vv8a8LuufDD2vhkGkVDXJjGBlUARaU34MhJI9h5CtJPw/ik4OntulHUGI6RODen49x1gIYx0Ke+qubz+tdz03VTlodKVd723MpEa9kjX7EN5459F8gbQWmDbjVZPYHuoPye8G6qUKJks389rjo27sSYEMcKd2JQXDDS7aHwqwWEDeqOuX4C09PBYoIxra7+3upA5J23Qoi5TJ9vQhJkFcEa/3puum5CO7QCo7EszNc8FiJDtC+UcGzZBR4v1vPyTyeLlRxTdQnvge6g/B5zswa4wiOI3p2Bw6O1NT/j2LgDX8d25NiDY0BhyZL1eE/lEjV5NCUumLsPRjSH6NCrv7c6YKwZTcS5YgmbgyFNIdYK0wO8WMIQbiU0qRnOVPWLGIQKV2tdKOFI2Q4Gg5JkKqVsQGEQXG/XSoU4KCHErUKIvUKITCHEs5d5PVQI8U3p6xuFEI0q4rzVAWEw4GnXhpZZGew4pbU1Cs/xU3iyTpHVXIVHguGCKZw2F2NiLcIG92D+Pih2qfCezs9E3TMSX0ExJfNXEmKE21vD0kNKoy+QsXRNwrFlN9KjngA7J6rhhUVO7Wyyp+wgpG1TDJHhZdu2noBQY/Cr6Z9PuR2UEMIIvAMMBdoAE4QQbS7a7QEgT0rZDPg38Fp5z1udqNEzicanD7E90z/uBPZSeZgtdZIJN0PLWI0NKifuI9nYV24matIIhMnEjAxoXjN4lDEqCkuvDpib1CsL841PUo2ts/yzC+KaCe2WhLTZce06CKgHLgls1+iBULrcOLdkXBDeA7WCahcPIUHQb3itVMQKqhuQKaU8KKV0ATOA0RftMxqYVvr9LGCQENU99Xzt1OzRBgOSUxt2a20KAI6UnYhwK8tCmtIhAYwBHigu/HIBCEHUxOHszoFtJ1WORf8LvRBVLDEKx6aduPYeomkN6FEXZmSoqr5ApUw4tjQP1SEBBNoVSjh37EPanVjOa9B1eCAjJ3jaOa6Viri11AWOnfdzVum2y+4jpfQABcAlz91CiIeFEKlCiNScHD8VoNOA0M5qQSq2Z/hFf4Zj4w7MndqSkWcK+PyTdHso+voHwgb3wFS3NtMz1BPqbXpxxGWJvOtWMJvKlCXGJ6mS7PXHrvJGP8ZUNx5jYq0yBxUVCi1itXNQ9pTtwIUNuumnweUlaBrirxW/evaVUn4opewipexSq1Y1CrReBWN0JLYGDWl0OJ2jGvdneAuLce06wNk2yfhk4OefShavw3v6LFGTR2F3w/d7YGgzqGHV2jL/xBhXg/Bh/Sj6djE+h5OhzVQhyQz/KjK9Lq7UsLvthDYrQ8eG7Zib1scUX7Ns27mijUC/3q6XinBQx4H65/1cr3TbZfcRQpiAaCC3As5dbQjtkkTr47vYckLbJZRzcwZISUaD0gKJAH+iK5w2F2OdeMIG9WBhJhQ6deWIqxF170h8eYWU/LAai0kVSyw+oBpcAxVLtyQ8R0/gOXkGUA6q0AWZZ6vWDunz4di084LwHqgCifpREB9+hTcGKRXhoDYDzYUQjYUQIcB4YN5F+8wDJpd+fwewXEp/CFYFDvG92hBtL2T/Fm1jKY5NO8FoZGV0G5rVhGiLpuaUC/fh84ojjEamp0PjGJVX0bky1j6dMDWqc0FPlMsLs/0jRXpDWLqqp5JzqyitGnZdew7hyy+6ILwnZfVr0D1HuR1UaU7pSWAxsBv4VkqZIYR4RQgxqnS3T4BYIUQm8AxwSSl6ZXB832kKzvhH5Vt5sZb2Q9hStS2Zsm/cQUhSMzbmhwV8/qnwy/lgMBA1cTj7z8LmbF054loQBgNRk0biWJ+GK/MoLWLVzXN6un9o2N0Ioe1aICwhZQMMG8Uo8diqbth1nBtQeJ6DOl6kSvkD/Xq7ESokByWlXCilbCGlbCql/FvptheklPNKv3dIKcdJKZtJKbtJKQ9WxHl/iaO7TlDSZxwp/11Y2aeqEkJaNMRtDaPGnnSKXdrYIF1unFt34WyfTL4jsMN7qjhiIWFDemGqE8/0dDAb4I7WWlsWGESOHwomY1mxxN1Jagz5Zj8ZVXG9iBAzoR1al62ghFB/31XdsOtI2Y4xsRamhj97o+rYoHsOvyqSqEgatEnkeP2WWL9fgC+Qa2BLEUYj3uS2tM7KIE2jLnfnTlX+erCJyj8F8kyakkVr8OacJeqekTg88N1uuLlJcGgKVgWm2rGE39qHohmLkE4XI5pDVEhgj+GwdE3CuWMfPrvq0O2cqJxuVeXWpJTYU3Zg7ZHM+V04W0+A1QSt46rGDn8iaB0UgO+OEdQ9eZC0JQHeSVhKjZ5taXz6IGkHbJqc/1z4Y318O2Is0KSGJmZUCIWfz8NUrzZhg7qz+ADkOdQqQOfaibp3FL6zBRQvXI3VDKNbwcL9kO/Q2rIbw9K9Hbg9alI0P/ccbauiPJTnyAm8J3IuyD+BWkG1TwBTUN+tL09Q/8rdHhmMPcTK8U/ma21KhRDTMwmj9JGToo3DdWzcialRHdY44uiUoHTLAhH3wSzsq1KJLC2O+DpdVUj1bqC1ZYGFtX8XTA0SKTovzOf0qlL9QKSsUKL0QSw5XjmFqiqU+Ln/6ecKPpsbduVUX1WToHZQkTXDONprII3XLScvV5tVR0US2rkNUghMaelV3p8hpcS+cQfGLslkng3shG3hl/PBaCTq7uEczIOULFWJFqgOVytUscQI7Gu24jpwjDa1oENtFeYLxGIJY81ozM0bljkoq1np3m2popC6I2U7hphIQlo1Ltu2/RR4ZWBfb+UhqB0UQMOHR2J129n44TKtTSk3xuhIHI0a0/RIepX3Z7gPHsN3Jp/sFir/FKiSK9LlpnD6QsKG9MSUWIsZGeopedzF6pE610TkhGFgNFL0lZqlNT4J9uVqP0/pRrF0S8KxOR3pU4OuOiXC9pPg9lb+uR0pO7B0T0YYfr4tp5YWnVTHAgmoBg6qzeA2ZCc2xjRrfkA+1V2MtVsSbbMy2FLFE3YdG9VT5dY6yRhF4E6ZLVm4Bt+ZfKLuHY3TAzN3waDG1a8BsqIwJcQRfmtvCqcvRLrcjGoB4ebAHcNh6dYOX34R7kw16KpzItg9St28MvGcPov7wLFLGnS3nFDCxTEB3G9YHoLeQQkhcI0dQeOju9mxOlNrc8pNrd5tiXAWc3DLkSo9ryNlB4YaUawyNKBNLQgzV+npK4zCz+diqp9A2E1d+emgqtDSiyPKR9S9o/Cdyadk0VrCQ2BMS1iwHwo0HFdxo1i6lwrHlj6QnQutpVbyivDchOrz+598UlXwVdfwHlQDBwXQ5bFbcBvNHPr4B61NKTfnBpg5N1ftI6ojZQeh3ZLZdtoQsBeM68Ax7Gu2EjVpZFlxRL1I6KsXR5QL64CumOonUPj5XAAmtFPq23MCsFjC3KQ+hriYsjxUnUhIjFCOojJxbNiOsIYSmtyibNuBPOXkA/V6qwiqhYOKTojmSLd+NFq5mML8AHysOw9zk/q4oqKJ35deZf0ZnlO5uA9lUdAuGbsncC+Ywi/mgdFI5N3DOJIP647BXUmBPy5Ea84pS9hXb8F9MIt28Wpu0dc7A69YQgihBhiWOihQf++V7aDsG7Zj6dIWEfJzaOKcikWg5nsrgmpzadZ9YCSRjiLWfbpKa1PKhRAC2TGJtlnpVdafcU5+ZVdDFR8PRAclnS6Kpi8i/NbemBLimJ4ORgF36sURFULk3apYovDLn0vO9+Sq2VqBhqVbO9wHs/Dk5AGqQCGrCE4WV875vAVFuDIysfTqcMH21BNQwwJNYirnvIFAtXFQbUZ0JCeuLuLbeQH3VHcxcb3a0iD3KDv3Vs3sDUeKCj+sDm9JQoQKewQaxQtW4TtbQNTk0bi8PxdHJERobVlwcEmxREtVLBGIyhLWbioP5dx8YR6qsioTHRt3gJRYe3a4YPu5/FN11oasNg7KYDRgGzOC5ge2syPlqNbmlIvI7ioPdTalaobw2FN2ENq5Lak5poC9YAqnzcXUqA7W/l346SCcsetjNSqasmKJhWuICIHRLWH+vsArlghJbgEhZuylYb42tSDUWHnCsfb1aRBiJrTTz8v5s3aVgwrEaEVFUm0cFECnx4fiMRg58GFgK0uEdmyNz2AkdGd6pfdnnBtQ6O2UTFZRYF4wrr2HcGzYTtQ9oxAGA1/thLqR0L+h1pYFF9YBXTE1SCwrlrg7KTCLJQyWUCztW+LYpJZ/IUY1Br6yKvkcG7Zj6dgagzW0bNu58H11zj9BNXNQMfVjOda5Nw2X/0hBkVtrc24YQ5gFZ4vmtDyazq6cyj2Xc3MG+HwcaKrKXwPRQRV+Ph/MJiInDONwviqOGK8XR1Q4FytLtKsNSQFaLGHpkYxz+158DrX861JHjV23VfBtw1dsw7l93yX5py0nVAN5cnzFni/QqHaXaML9I4mx5bNu6lqtTSkXkT2SaJW9my1Znko9jz1lOxiNbIhtQ6hRhTsCCZ/dSdE3i4gY3g9TrRplxRF36cURlULkhGGlYzjUMMNALZaw9EgGlxvnViUc2yVRSQ5V9CQBx6ad4PVivYyDaltLyS1VZ6qdg2oztitna9TG901gK0vE9krC6nZwJKVym48dKTsITW5BSl4YHRNUuCOQKJ67HF9BMVH3jSkrjhjcBGrrxRGVgikhjvBbfh7DMbqlaur+eufV3+tPWLolgxA4NigB186JIKj4eVf2DeoB0NKlbdk2txfSTlVfgdjzqXYOymAyUjJqOK33bmZ7VY/LrEDOSaLI1B2V5mil04Vz226MXZNJPx2Y858Kp83F3KwBll4dWHIAcvXiiEonavIofLkFFP+w+udiiQBTljDGRBLSpmmZwni0BVrGVnwln2N9GqEdWmKI+HkQWUaOyt1VV/2986l2DgqgwxPD8AoD+98L3GIJU2ItHImJND6wgyOVVG3u2LYH6XSR1SIZr4SuAeagnBmZOFMziJo8CiEEX+1UyhH9dOWISsXavwumRnUonKaKJSaWFkt8t1tjw64Ta8/2ODalI90qjN6ljnJQ3gqSwfTZnTi27cba88L5T+dWaV3rVsx5KgvpcuMtKKrUc1RLBxXTuDbHO/ag0fKF5BdVbg6nMgnp1p52R7ez+XjlLKEcpU+PmxPaIQi8J7rCqXMQlhAi7xrKgTxYn6UrR1QFwmAg6t5RONan4dp3mHa1lbjwVwFWLGHp0R5ps+PcuQ9QDqrYBXtzK+b4zi0Z4PZguaj/aXM2NIyG2n4uYFy8YBVHkm/DtfdQpZ2j2l6qtR8YRc3is6yeuk5rU26YWv3aUcOWz74txyrl+I6UHZibN2RdSQ1ax0FU6NXf4y/4im0UzVxCxOiBGGtE8fVOVRV1V9urv1en/ERNGAZmE4XTVLHExHaw/2zF53Aqk3Nh9HN5qHMh7or6Hezr00CICxTMpVTHD4RoReEn32GqHYe5eeX1a1RbB9X6tu7kxcTDN3MD6qnufMJKQwPO0pVORSK9Xhyb0wntnszWk/4fbriYopmLkSV2ou4fg8MDs3bDkCb+/1QaLBjjahAxcgBF3yzCZ3MwsgVEhqhVVKBgqh2LuWl9VciACg8nRPw8o6m82DdsJySpGcaonyt2DuSpJl1/d1DO9Ewcm3YSdf/oC+ZXVTTV1kEJkwn7mBG03buZ1E0B9Fh3HuZmDXBFx5CwZwcVPTDYlZ6Jr7CYM0kdsLn9/4I5HyklhVPnENKuOaGd2rAwE/Id6ilep+qImjwaX0ExxXOXE2aG21rDwkyqTOS4IrD0bI8jZTvS50MIVVlXEQ5KOl04U9MvKS8PlPxTwaffIayhRE4YXqnnqbYOCqDjk8PxCgOZHwRmsYQQAjon0+7YjgrvcrdvSANgW/0OQGCVvDo27cS16yDR949RxRE7oHEM9KqvtWXVC0vP9phbNLygWOJcqX+gYO3ZHl9BMa7dBwEV5ssuhuOF5TuuY+tupMN12fxTnNW/BWK9BUUUz/6JiNsGY4ypXGHOau2gIhrGk92pJ82WLySnIDCLJWr1S6ZOfjbp6RU78tO+Lg1To7qsdcdTLwoSA0ggtnDqHAyR4USMHcyeM0qi5u52YAhADcFARghB1OQxOLfswrljHy3j1Er8651qGF8gcM6BXJyHKu8DoX3tVhACa++OF2zfdFydw5/1LoumL0LaHERPua3Sz1WtHRRAnYdGUbPkLGs+C0xliYjeKg9VtK7i8lDS68WxIQ1r7w6kBkjC9hzeM3kUz1tJxLhbMESE8dVOJfQ5rrXWllVPIu+8BWENpWDqHEA9KBwugHUBotdsrp+AqV5t7KUjZ1rHKZX28hZK2NdsITS5xQUrkJPFcKzQv6836fNR+Nn3hHZNumC4YmVR7R1U8zHdya8Rj/HbeQHzVHc+oUnN8FisRKXvwF5BOmGujAMqd9CxIzk2/75gLqZw+kJwuYm+fwwlLvhuDwxvDjWsWltWPTHGRBIxdjDFs5fiLSxmWDM14+iLACqWsPRsj2NDGlJKTAbomFA+ZXOfzYEjNQNr304XbD/n9Lr5cf7JvioV98EsoqeMrZLzVXsHJYxGnLePJGn/ZtavzdLanOtGmEy4k9uSdGQ7aacq5pj29WkA7GjYAQgcByW9XgqnzcXSoz0hrRozb5/qW7lbL47QlOgpY5E2O8Xf/IjFpEr9lx6EE5Xb41lhWHu0x3v6LO6D6v7QtQ7sPgMFjhs7nmPTTnB7sPbpfMH2TceVLJQ/610WfPY9hrgYIkYOqJLzVXsHBdDhiRF4DUaOfDBXa1NuiNi+yTQ+fZCteyvmirev34apUR3WuWsTHQrNalbIYSsd27KNeI6cUDdECZ9vhzZxgVXgEYyEtm9JaKfWFHw2ByklE9upHNT0ABlmaClt53CUPrh1rweSGw/z2ddsAZMRS/cLn5xSs6FTgurX80fcR7KxLV5P1KSRiNCQKjmnn34UVUtYvThOdO9D69WLOJoTQIJhpcT0aY8ByZm15Y+bSJ8Px4btWHt1JDVbJWwDpbig8NPvMcbXJHx4P7acgF1n4J5k/044Vxei7h+Le/8RHOu20SAaBjSC6RlU+jyzisDcrAHG+JrY120DVIgv1Agpx2/sePZ127B0anOB/l6BU63K/Dm8V/DZ9yAE0fePqbJz6g6qlMaPjSHaXsD6D1Zobcp1E9qpDV6jidBtO3CV84J37TqIL78Id5cOHMwPnPCe+9BxbMs3EnXvKESImS92qMbQ0S21tkwHIGLMQAw1oij45DtAPTicLoElBzU27BoQQmDt2xn7mi1IKbGY1ADDlBvICHgLi3Fu24O1z4X5py3ZalXmr4LMvhI7RV8uIHxEf0x1qm5Ile6gSql/a2dya9cn6vu5OAOs4twQZsHVphVtD29jRznzUOeeEnc0VOWvPeuV17qqoWDqHDAaiJo8mjM21RB6RxsIr5pIhM5VMFhCiZo4nJJFa/GcyGFAQ6XM8MUOrS27Nqx9O6s81N7DAHSvq1THC68z4OJI2Q4+3yUFEqnZKrTXKaGCDK5gimf/hK+gmOgHb6/S8+oOqhQhBMa7R9PyaDorfqzcGUuVQY2bOtMyew+pe4vLdRzH+m2YGiayxl2biBA1EdXf8dkcFH39A+HD+mFKiOObDNUQOkkvjvArou4dDT4fhV/Mx2hQyh4bspRGn79j7acKGmyrtwDKQfnk9atK2NdsRYSGENrlQlHIDcehXbx/DiiUUlLw8WxCkppfkjerbHQHdR7Jjw7FZQoh95M5Wpty3dS8qRNG6eP06hvvh5I+H/bS/FPKcRXe89eE7fkUf78MX34R0VPG4vUpvbde9QKnuKO6YG5cl7CB3Sn8fB7S7eHOtmoA5pcBsIoy10/A1Kgu9tWpgBpgaDZcfx7KvmYrlm5JGCw/Ky+XuGDHKf+NVjjWbcO1+yDRD92u1GuqkAC4/VQd5ppR5N40kPYbl7D7UInW5lwXoV3a4jWHEJa69YbzUK7dB/HlFeLq3JEDef57wZyPlJKCT7/D3Koxll4dWHEYjhepHIeO/xE1ZSzeU7mULFhFXBgMa6aEfItdWlt2dcL6dcaxPg3p8WA1qxEiG68jD+XNzceVkXlJefnmbPD41EOVP1Lw8WwMNaOJGDu4ys+tO6iLaPnUGKxuO5vfW6y1KdeFwRKKMzmJdoe23nAeyr62NP/UqAMQGA7KuWUXrh37iJ4yFiEEn+9QiuVDmmptmc7lCBvUHVOjOhR8PBuA+zoo5xQIwwytfTvjKyrBmbYXUOXmO09fu3M9l9+9OP+0IUutxvyxQMJ97CQli9YSdc9IDNYL5+14fGr1V5noDuoi4nq0IadxS+rP/448e2BJS9Qc2JlmpzLZkpF/Q++3r07F1KguazwJRIb4d8PgOQo+moUhKoLIcbdwIA9WHVG5jUAITVZHhNFI9AO34di0E+eOfXSoDcnxMG2H/w8ztPZRhUP20jxUz7rgldc+Bt6+Zgsi3Epoh1YXbN+QpaoC/TH/VPDJ7CuWli8+AD0+gX0VNMDxcuiX8UUIIajx4G00OHOEn77eqrU510XsTerJLHfVtut+r3R7sK/bRtiALmzIUv0Y/n6T95zIoXjeCiInDscQEca07SqncXeS1pbp/BKRE4YhwiwUfDwbIdQqKvMsrKucuZsVhjGuBiFtm2FboxxU59Ic7bWUm0spsS3fhLVPJ4TZVLa90KlWYf4YrfAV2yj6YgERI/tjqlv7ktc/S4MYKzStUXk2+PktSBua3zOIkvBofF9+j9entTXXTmiHVrgtViK2XH8eypGagSyx4+zelUP50MMPL5iLKZw6F7w+oqfcRpETZu2CEc2hlj6U0K8xRkcSeectFH+3FO+ZPIY3h5pWmFbxczcrHGv/Ljg3p+OzOwkzq8q7aymUcB88hufoCcIGdr9g++bjqhrQHx1U0fSF+AqLiX70zkte23la5c4mJ4OxEr2I7qAug8EaimPsCDplrGHVhgoSuKsChNmEs1P7G8pD2VduBoOBrQ3VKswfL5jz8TmcFHw+l7BbemNuVIdZu6HEDfe119oynWsh+sHbkU4XhV8uwGKCCUmw9JBS8/ZnrH07I50upaeHuk52nALbVYSa7cs3AVzioNZnKVWKTn4mxyW9XvI/nEVo1yQsndte8vq0NKUbOO7SlyoU3UFdgeSnRyOAwwGmz1drYGca5B5lW1rOdb3PtmozoZ1as64gkqhQpWHnzxR/vwzfmXyiH74Dn1RP3x0ToL2fNjrqXEhIy8ZY+3VW+nweD5PagcD/G3etPZLBZCwrN+9eVxULXE3d3LZ8I+Ym9TA3urASYkMWdEwEi+kKb9QI25L1eA4fJ+aRS1dPZ2wwdx/c1gqiQy/z5gpEd1BXwNIokZzuvWi/ah77sgNHn+9cHipv1bXnz7z5RTi37SGsfxdSstRFV5nL9vIipaTgo1mYWzXG2qcTq4/AoXx99RRoRD90B97s05QsXEOdSFV5OSOdChsbUxkYIsKwdG6LfY26vrqU5qHW/0IeyudwqvzuRaunfAfsyvHP8vL897/FVD+B8OF9L3lterpqhJ9cBdebH9+GtKfJU7cTYytg7YeBo88XktQMZ3gkUVu3XHMeyr52K/h82Lp25XAB9PBjwUoAR8oOXDv3E/PQHQgh+CwNaoXBsOZaW6ZzPYTd3BNTw0QKPpwFqAeMAifM2auxYVfB2r8LzrQ9eHPziQhRK/c1vzCA0ZGyA2l3Yr3IQaVkKf09fwunO3fsw7E+jegHb0eYLlzaub1qldu3AbSIrXxbdAf1C9S+uTN5iQ1I+D5wSs6FwYC7ayeSD25l24lrs9m+ajMi3Mqm2iqg7O8FEgUfzsQQE0nEHUM4lAcrS0vLQ4xaW6ZzPQijkegHb8excQeOtD10r6taGz7Z5t8l52E39wQpsS3fCKibdfppOGu//P625RsRoSFYe3W4YPuGLMqEZ/2J/A++RYRbiZw04pLXFh+AUyVwfxVFK3QH9QsIIYh44HZaZu9m0awMrc25ZhIGd6J24Sm2bLi2ul3bylSsvTuyNttEDYsaa+2vuA9nU7JwDVH3jsIQZuHTNOWYJuq6ewFJ1MQRiIgwCt7/FiHgwY5Km++XViRaE5rcAmOtmtiWrAegX0O1ErpSmbxtxSYsPZIxhF841nlDlpIT86cHK8/xUxR/t5Sou4djjIq45PVP06BhNNzUuGrsKZeDEkLUFEL8JITYX/rvJRXxQogOQogNQogMIcQOIcRd5TlnVdPigVuxh0UgP/s2IGbXANS8pQcAJUs3XnVf95FsPIePYxnQldVH1NOgP+efCj74FowGoh+6g3wHzNwFo1tAvF5aHpAYIsOJumckxXOX4zl+SrUJhMHH19/KV2UIg4GwwT2wrdiE9HhIjoeoUFh95NJ9PcdP4d5z6JL80xkb7M31v/Be/gczQUL0Y5fepredVE3Jk9tX3Yy48t6KngWWSSmbA8tKf74YG3CvlLItcCvwlhAippznrTIMEWG4bhtJ1/RV/LT2pNbmXBPmRnUoqteQ+mkbyLX98r72Vaoa6URyV3Js6mnQX/HmF1H49UIibxuMKSGOr3eC3QMPdNTaMp3yEP3QHeCTFHzyHaEmpaO46oh/q5yH3dwTX0Exjk3pGA3Qu75a9V0cmrRdobx8dekKsW+DqrD22vDmF1H4+Twixg7EXP/SuOMn29SMtbsqubT8fMrroEYD00q/nwaMuXgHKeU+KeX+0u+zgdNAAIjo/EzSM7eDEGS9951fx8bPJ3RgD9ofSWPdvisExkuxrdyMsU48qwzqSvFnB1U4bS7SZif6sbtweWHqdnVjaB1Qf006F2Oun0D4yAEUTpuHr9jGpHaqN+gzP15FhQ3oCmYTtqUbAOjXAE4Uw4G8C/ezLd+IsU485paNLti+6jDEWv1rnE3h1DnIEjsxT9x9yWtZhbBwv+pXi6jCGWvldVC1pZTnlKhOApfqYZyHEKIbEAIcuMLrDwshUoUQqTk519fHU5mE1K9NXr/+dFs3n9TMqyxJ/IT6I3sQ4nVzZNGWK+4jPR7sa7YQ1r8LK48KWscpoVV/RLrcFHw0C+uAroS2bcYP+1Wy9kF99RQUxDx2J77CYoqmLyQ2DMa2gtl7IO+Xn680wxAZjrVne0p+Ug6qT+lK6PzcmfR4sK/eQtjAbheMqfBJtYLq17DqQmVXw+dwUvDhLKw3dSM0qdklr3+Wpv69v0OVmnV1ByWEWCqESL/M1+jz95NSSlSu8ErHSQS+AO6XUl5WQEhK+aGUsouUskutWv71WNzymXFEOorZ8u6PWptyTYT3ao/LYsW8NuWKqz7HpnR8+UUYb+rJlmzo78erp6LvluI9lUvMY3chpcpRNK0BAxppbZlORWDp3BZL1yTyP5yF9HqZ0gEcHjXby18Ju7kn7j2HcB89QYNoaBR9YR7KvmE7vsJiwgb1uOB9O0+pir8BfnS9Fc9cjDfnLDFPXbp6KnLCjAwY3hzqRFatXVd1UFLKwVLKpMt8zQVOlTqecw7o9OWOIYSIAn4AnpdSplTkL1BVxPRMIrdFG1ovnMmRPP8X6BMhZkq6diV5zwb25FzeQ5UsXgchZtKadcPt89/wnpSSgvdmENK6CdaburEpW5X1PtDRf55AdcpP9GN34Tl8nJIfVtMyToXNpm4Hp0dryy5P2M09AbCVrqL6NlS6fOf6D0t+WI2whBB2U7cL3rfyiFLN8JfrTXq95L8zg9D2LbH26XTJ699kqJEiD136UqVT3hDfPGBy6feTgUt0gYQQIcD3wOdSylnlPJ+m1HlyHPXOZrH4kw1am3JN1B7Wg9qFp9my5tAlr0kpsS1ai7VPJ1acCcNqgi5+pgd2DvuKTbh2HST60TsRQvDRVqhhUVIrOsFD+LC+mBvXI//tr5FS8khnyLHB93u0tuzyhDRtgLlxvZ8dVAOlybf1hLq+ShatxXpTt0vKy1ceUcMOa1ovd9Sqp2ThGtwHjhHzxIRLJuZ6fCq8160OJP9iAqdyKK+DehW4WQixHxhc+jNCiC5CiI9L97kT6AfcJ4RIK/3qUM7zakLdOwZQGJdA7enTr1od5w8kDlehhaLFlzpUd+ZR3IeyCL+1N6uPQK/6EOpnemDnyHvrC4x14om8Ywj7cuGng3Bvsn/Oz9G5cYTRSMxTE3Cm7cG+Zgu960PbWvDhVpW38UfCbu6Jfd1WfDYHPeuBUaj8kjNtD97s04QP63fB/vkOSDvpP+F0KSV5b36OuWl9wkcNuOT1RZmQVQQParB6gnI6KCllrpRykJSyeWko8Gzp9lQp5YOl338ppTRLKTuc95VWAbZXOcJsIuKRu0g6up3536Zrbc5VMSXW4mzj5iRu3XCJvlnJ4nUAnO3RiyMF/hNuuBj7xh04Nmwn5vHxiBAzH25R3ff3ddDaMp3KIPLOWzHWjiX/7a8RAh7trCrjlh7U2rLLEzakF9Lhwr5mC1GhpbJHR1R4D6OR8Ft6X7D/6iPK2fpL7tS2ZD2u9P3E/PoehPHCjmEp4b1UaBIDg6uoMfdi/Lgl0z9p+tBw7OFRhHz69VUl9v0B84AetDmazsY9RRdst/24jpB2zVntUut2f0rYnk/+f77EEBtN1KQRnChSOm3j2/pPeESnYhGhIUQ/eif2lZtxbt/LsOZQLwrev3IxqqZYe7bHEBVByYJVAPRvpGYlFSxYg7VXe4w1oi7Yf9URiLGoEJ/WSCnJ+9dUTA0Tibz95kteX3MUMnLgkc7aNe/rDuo6MYRbYeJYuu5ey7xFl2kd9zMaj+mBUXo5uCC1bJv3TB6OzemE39qH1UdQFUgx2tl4JZw792P7aQMxD9+JIdzKx9vU06dW4QadqiF68mgMURHk/fcrTAaVnN9yQg3I8zdEiJnwYX0pWbgG6XQxpAnUO3MEeeDIJeE9n1QOqp+fqLXYl2/CmbaHGk/fc8GU33O8m6raTsZqmOv1g48p8Gj9m9vxms2UvD/d7+WPorq3xRYehWnFmrJy85KlKeDzYRrcm/VZ6oLxR/L/+yUiIoyoB8aS74Cv02FkC6gfdfX36gQuhshwoqaMpWT+SlwHjnFnG1UU84GfrqIixgzCV1iMbcUmWsbCiCOrAVX0cT67clTRhz+E96SU5L0xFVO92kTedeslr287qbQCH+qkbW5ad1A3gDGuBvZRw+m1dQk/bjijtTm/iDAasfXvS4eMdaRnqblWth/XYUysxcboFtjccEtTjY28DK4Dxyiet5Lo+8dijI7k8x2qQurRzlpbplMVRD90ByLUTP7bXxFmVvpvPx2EfblaW3Yp1n6dMdSMpvj7ZQgBA/evYW+d1jjiLpSJWHlY/esPD4T2tVtxbE4n5qmJiJBLq43eS1XDCCckaWDceegO6gZp88fxmHxejr09028rjM7R6O5BhLts7Ph2Iz6HE9uKTYTf2ptFBwRRof45XiP/v18hQkxEP3onDg9MTYObGumyRtUFU3xNoiaNpOibH3Efyea+9mrE+P82a23ZpQiziYgR/Sn5cR2uA8eIzdzN6pb9WHVRBmDJQUiOh1oaq7VIKcn752cYE+KIvHvYJa/vP6vGakxuX7WyRpdDd1A3SGjjOhTeNIA+a+fw09ZCrc35ReIHdaQkIhrjj8twrEtD2uyEDu7NTwfh5ib+JfcP4D6YRdE3PxJ1zyhM8TX5eifk2uHxLlpbplOVxPxqIsJoJO+tL6hhVa0F8/fBwbyrv7eqiRg7CGmzk//frwDY2b4fi88TdMsqhO2nlBqD1tiXb8KxYTs1fn0PBsulM9s/SC2tlPWDCdW6gyoHbf58L+EuG/vf/NavRWSFyURB/wEkZaznxNdLMESGs71RRwqdcKsfhvfO/uszRIiJmKcn4fCoZG2PetDNzyf96lQspsRaRN0zkqIZi3AfyeahTuph6h0/XEVZerbHGF+TkoVrCGndhNbdGrDi0M+qEgsz1b9DL5W5q1Kkz0fu3z7A1DCRqHtGXvL60QL4bo8K7cWGaWDgRegOqhxYk5qS378/fVfNYlla0dXfoCGN7h6I1e3A++Mqwkf0Z9GxUMLM/tf/5Np3mOJZPxH9wG2YascyPV0lln/d/erv1Qk+Yp6epFZR//6cuDCYmKSUJY4WaG3ZhQijEWv/LvjyCwkffRO3NIVCF2zMUq8v3K+ajhvGaGomJfNW4tq5n5p/mHLZ3NPbm8BkgMf8JNerO6hy0vrF+4hwlrDnzZl+vYqqN6g9NksERpeTsNuHsOSgyulY/Ew94uzrnyHCLMQ8eTcOj0rWdq/rf4PddKoGU0IcUfeOomjGj7gPZ/NIZ3UDfTf16u+tckplggzREfSpr66tJQchu0hVxWkd3pNuD2f/8RHmVo2JuEzf05F8mL1bTaeufekwXU3QHVQ5CWvXjLy+/ei7YiYrd/jvKkoYjRAdiQTWRrQgx+Z/4T1nRiYlc5cT/fA4jLExzMhQIzX01VP1JuZXExFmtYqqHaEG5s3aBcf9KPUrfT7s67YhQkOwr9iM1azkjJYcVKsn0D68VzRjIe6DWcQ+//AlqhEAb29Wzv9RP8r16g6qAlCrqGIy3pzlt6sob14hYWdOI4Ct36YQYoSBGsmXXImzr32CISqCmMfHl62eutXRV0/VHbWKGk3RNz/iOnCMx0pvoP60irKv3Yr3+GmsA7piW74Rz8kzDGkCJ4th1m5oHQdNamhnn8/u5Ow/pxLaNYmwi+SXAA7nw3fnVk9+NBNOd1AVQHj75uT16kOf5d+yIr1Ya3MuS/G8FeD1UhgRQ4N1y+nbQPsS0vNxpGZgW7SW6MfvwhgTyTcZ6uJ+untZ5ESnGhPz9CREaAhn//oBdSLVKmpGhgpL+QNF3/yIISqC2D89Ah4vhV/MZ1BjJR67+wwM0zi8V/DeN3hP5KjV02UuqP9t8r/VE4CfZSACl9Yv3cfJIQ+y5tUZ9P/iQb+QMjmf4plLMLdsxOHGXej801x80YWAf0gySCk588L/MMbXJOaRO7G5VbK2ax010r2qcLvdZGVl4XA4qu6kfobFYqFevXqYzf4lFW+Kr0nMUxPIe+1THJvT+VX3JGbthjdS4L+XCiFUKb5iGyULVhF5xxBCWjXGOrA7hdPm0vDX99C8pok9uXBLE+3s82SfJu8/XxA+vD/W3peOoD6cryr37mvvX6sn0B1UhRHesSUFgwZy04pv+GHdGEb1jdPapDLcx07i2LiDms89xO6IXiT/OIuYhQuh/3itTQOgZM5ynJvTqfXWsxgiwvhkk6rce3941a6esrKyiIyMpFGjRpd9ygx2pJTk5uaSlZVF48Z+Fv8FYh69i8LP5pD74jvU+eFdpnQQvJeq1EXaaNjAXTxvBdLmKJMMin5gLCcnPkvJwjV45U0AHC+ClhrdEnJffg98PmJfeeKyr/9zvSrff8TPVk+gh/gqlLZ/e4gQr5tTr031qymgxbN/AiB0zM3MlM3Y3aAdtebOwefVfjKwz+Ek9y/vE5LUnMjxt3LWrpSrb2kKXepUrS0Oh4PY2Nhq6ZwAhBDExsb67QrSEBFGzWcfwLE5nZKFa3i0M0SGqhuslhR98yPmJvUI7ap0gcIG9cDUIJGcj74j8yxYjCoPpQX2Ddsp/m4pMU/ejbnBpRNJt52EBfvh4U7+t3oC3UFVKKFN6+EYN5r+mxbw3cKjWpsDqKfiopmLsXRrxwpPImftcOa224jPPU767E1am0fBBzPxHDtJ7CtPIIxG/rdJae79vpc29lRX53QOf//9IycMw9yiIWdfeY8oo4fHu8Dyw7D5uDb2uPYfwbE+jci7hpZ9dsJoJOq+0fg2ptHw9EGGN1c6ggVV7Pel18uZ/3sLU914Yp6aeOnrEv6xFuKs8LCf9D1djO6gKpg2L96HJyQE35sfUuTU2hqwr9mCe98RIu8ezjcZUCcCxjzVn7zwmuR+/J2mtnlOnyXvrS8Iu7UPYX07c7QAPt8Bd7aB5jU1NU3HTxEmE7EvPIb7YBaF0+ZyX3uID4fX1qNJBW3+O9MRlhCi7r1QlSF8/HBcphAe3TeH+zsqRYn5+6rWtsIvF+DKyCT25ScxhFkueX35Idh4HH7dw78Kps5Hd1AVjKlWDcSDE+iRsYqZX2s/dbfgg5kYa9WgYMhgVh+BcW0gJsrMgZtHUi8thcID2g3ZOfu3D5EOJ7EvPQbAmylgEPCbHpqZpDlGo5EOHTqQlJTEyJEjyc/Pv6HjTJ06lSeffLJijfMTwob0wtq3E2df/Rhz3lme7qZmRf144OrvrUg8J89QNHMJkROGY4y7sIZ8TVEMy9sMpMumH2ljKaFlbNWG+Twncjj7yvtYene87Ch3jw/+sU5Nyx3ftursul50B1UJtPzdXZRE16TWf9/haL52eR7XgaPYlqwn6r4xzDqgRCHHlf4xNn1sFBID6W/N0cQ225otFH39AzGP3UVI0wbsOAVz9sCUDpDgJ13sWmC1WklLSyM9PZ2aNWvyzjvvaG2S3yGEIO61Z/DZHOS+9A7jk6BVLPx1NTiqMPdb8OEs8HiJeeyuS177Kh1W97kNg81O0ZfzuaONyvdknq18u6SU5DzzOtLtJv7NP1w2bDtzl1It/2NvMPuZWPT56FV8lYAhIoyY5x4m/I+vMu/vC3ny9RGa2FHw4WwIMRNx7xi+/QH6Nvh52F/njvHMTOpDq/k/4Hv9AQzWS1WNKwuf3UnOM//E1KguNX4/BZ+EP62AuDB4vGuVmfGLvLxKDZirSNrUghf7X/v+PXv2ZMeOHQAcOHCAJ554gpycHMLCwvjoo49o1aoV8+fP569//Ssul4vY2Fi++uoratf2g3nilUxI84bUeGoieW9OI3L8UF4a0IXxs9VQw6erQHnEV1RC4dQ5hI8cgLnxhSrG2UWw4jA8dktrrLu6kvfWF4xeOZxXRSSzdyunUJkUTV+IbWkKsX97GnOTS7vcC5zwxgbonOifs+DOR19BVRL17xtKXttkes14j9Vp+VV+fm9+EUUzFhJ522DWO2qSXQzjzxs+JgT4JowlvKSQQ18trVLb8v75KZ7Dx4n/9x8wWEP5JkONIni+L0RVnZ/0a7xeL8uWLWPUqFEAPPzww7z99tts2bKFf/3rXzz++OMA9OnTh5SUFLZt28b48eN5/fXXtTS7Son59T2YGtXlzO/foHuckxHNldJ5VhVIIBV+Pg9fUQkxT0645LUZ6SofNr4txL7wGL78Ikwff0X/hirMV5kVvp7s0+T+6W0sPdsT/eBtl93nX+vV+JpXBvh/E7y+gqokhMFAq//9luODprDj+ffpPvfZKh2dXPjlfKTNQfQjd/JNBtS0wuCLWlsGjO/E9v80J+Hf05D33IwIrfxMqXP7XvLf/YbIicOx9ulEnh1eXacEYce0rPTTXzPXs9KpSOx2Ox06dOD48eO0bt2am2++meLiYtavX8+4cePK9nM6VQVOVlYWd911FydOnMDlcvll/1JlYbCGUuv1Zzhx52/Jf/trnn/0fpYegr+tgfeGV955pctN/vvfYu3bCUuHVhe85vHBN7vUlIAG0UC75kSMG0LBhzN5cPZY7j5cu2ycRYXbJSWnf/M60usl/j//hzBcuv7YcQq+2AH3dYCk+EuP4W/oK6hKJCKpCfaJd9J30w/M/nJHlZ1XejwUfjwbS59OnKjXjMUH4I7WXOIgEyIF2+97nIjTJzjxfuVX9Em3h9O/eQ1jzWhiX1JNg6+thyIn/GWA/z/NVQXnclBHjhxBSsk777yDz+cjJiaGtLS0sq/du1XG/amnnuLJJ59k586dfPDBB37bw1RZhN3UjYixg8h76wtijx3gia5q9tLaSuzyKJq5BO/JM8Q8efclr604rCS67j7PAdV89kGklDT78hPaxcP7qVAZLYgFH3yLfflGYv/0yCVhR1DnfG65CqU/EyCFSLqDqmSSX7mPwprx1P7nGxzLrZoMbvH3y/AcP03Mo3eWaWw92Ony+465vwspzXpQ8O9peM9W7pCd3L+8j2vnfuJefwZjTCRpJ1U45P4O2nXZ+ythYWH897//5Y033iAsLIzGjRszc+ZMQD0pb9++HYCCggLq1lU3o2nTpmlmr5bE/u1pjNGRnHroRR5s5aBBtLoR29wVfy5vYTFn//4hoR1bY72p2wWvSamcT0IEDDpvIWuun0D0Q7dT/M2P/KbGAQ4X/DzAsKKwrUol98V3CR/en6gHLh/a+2on7DwNL/QLnFC67qAqGUNEGDX//msanT7I4mem4qvkXg1fiZ3cv35ISLvmnO7ak+/2wKTkK3eJt4iFPfc/jtFm4+TrlXeDK16wioL3viFqym1EjOiPwwN/WAq1wvVxGleiY8eOJCcnM336dL766is++eQT2rdvT9u2bZk7dy4AL730EuPGjaNz587ExVVPL2+qVYP4d/+Ee98Ril/6L68PhiMF8Nq6ij9X3j8+xpuTR9zrz1xSHbfiMKSegCe7XloZV+PpezBEhdPq43doGiN5d3PF9W25j2Rz6qEXMbdoSPz/nrtsaO9UiVLc6FMfRraomPNWCVJKv/zq3LmzDCY2TPq73BfXV876bGulnif3Hx/JzLg+0rZhu3xmsZTN35byVPEvv2fXaSn/ffPrcl/CAOk6cKzCbXIdOCYPNr5FHrv5IelzOKWUUr6ySsoGb0m5/FCFn+6G2bVrl9Ym+AWB+jmcefk9mRnXRxbNWS5fXKn+vtYdrbjjO9L2yMz4fvL079+45DWvT8qhX0nZ+1MpnZ7Lvz//49kyM66PXPXcFxX2t+8ttsmj/SfLg01vla6DWZffxyflPd+re8GBs+U/Z2UApMrL+AF9BVVFdH33afJq16PeK6+wPzO/Us7hPnqC/HemE3HbYE62TC5bPcVfRWOrdS04MnkKToOZky+/X6E2+exOTj7wAhgN1P74ZURoCOuOwcfb4J5kNdVXR6ciqPl/DxLauQ05z7zOM/WyaRQNv18Kxa7yH1v6fOT84Q2MsdHUfP6hS15flAkZOSq3E3KFvqKoKWOJGDOQuh9/xC0nNvPO5vLZ5HM4OfXIy7h2H6L2Ry9fNu8EMDUNVh2BP/fTdibVjaA7qCrCGBlG/U9eItpWQMZDr+HyVHysL/eld8FgIPbFx3h7s7pQHr1Gja0pN8fyda9JuBauovCrHyrEHunxkPOrv+NK30/8O3/C3CCRAif8bonqYH++T4WcRkcHAGE2UfuDF0FKCu7/P97oXsjxQlXVV14Kv5iPc+tuYl9+AmN05AWveXyqr6h5TRj9C5WoQghqvfUsIS0b8czMlzm66yQbsm7MHl+xjZMT/4ht8TriXvsNYRflw86xO0cpRtzcBCa1u7FzaYnuoKqQxO4tyHvyUdqlr2XBC7Mr9Nj2tVspmb+SGr+axFFrPN/vUX+QV1s9naNdPJy6+27SmnYl5/f/wp5SvqpD6fVy+lf/oHjOcmq++BjhQ5T66wsrVDz8rVvA6l8jh3SCAHPDOiR8/nfcB7NI+M3veay1ja/TYd7eGz+ma/8Rzv5FyQZF3DHkkte/3wMH8uC3PbnqHDhDuJXan/2VEOnh798/z0uLndetfuHNLyL7zt9iX5dG/DvPE33fmMvu5/DAUz9CdCi8Nigwq2R1B1XF9Hh+HIc69yHp47dZ/cGKCjmmr6iEnD++ialBIhGPjuePSyHMfO2rp3P8oa+Jv97+MrmxiZy8/3ncx07ekD3S6yXn6VcpnrmEms89RI3SctzP0mDOXtXp3z7hhg6to3NVrH06Ufujl3Cm7WXi+8/Rs5aT3y9VFWzXiyf7NCfGPYMICSH+rWcvKYywueGtFPWAd+s1qjKENK1P7ff+TOPj+/jNvx/j0++uvSbetecQ2WOewrl9L7U/eYXIOy8/rVFKpYay/yy8OQRiw675FH6F7qCqGCEEvb99gaON2xL/wiukf1e+kRfS4+Hkgy/iPpBFrX//gQ93hbIpW/UV1brO+S4t4+CRgZE8M/ZV3A4PJ+/5P3zFtuuzx+ki59evUfTNj9T44xRq/OZeAJYdhFdWK2mVJ/1EzkgneAkf2pf4/zyLc80WXvvhJRIMDh6aDzkl134Mb34R2Xf9Dm9BMYkz/om50aUDyl5YoYYR/qnv9a1Qwm/pTcJXr1Kv5DT9fv0guz9chPyFsj5vYTFn/vRfjg24H8/x0yR+9RoRw/tdcf8PtsDX6fBYZ9U0HLBcrnLCH76CrYrvYk5nF8pl7e6T6YmD5dGVO2/oGD6fT57+3T9lZlwfWfD5XLnjpJRN/ivl4z9I6fPdmF0er5S3fyvlpN+kyMz4fvJo33ulI23PNb3Xvm23PNLnHpkZ10fmvvpx2fadp6Rs9Y6UI76WssR1Y3ZVBYFavVbRBNPnkP/RLJlZq6/c132SHPLnTHnbN1I63Fd/n9fmkFnDH5eZdW6SJatTL7vPrF2qUvCN9TduX+7B03Ju16dkZlwfmTXuGZn3/rfSsWOf9Hk80n0iRxYvXidzX/tEHmo9UmbW6itPP/O69JzJ+8Vjfr9b2fXkQlXBFwhwhSo+zR3Rlb6C3UFJKeXePblydYvxckf9W+XpH1Ou+/15706XmXF95JmX35U2l5Q3TZOy+8dS5tvLZ9fhPOVQXnx5gzyUNEZm1u4vc1/9WPqcl/cu3mKbPPPXD2Rm7f7yULuxsnjJz1dsdqGUXT+SsucnVy931xp/uDEDcuLEiWU/u91uGRcXJ4cPH35dx2nYsKHMycm5oX384XOoSEpWbJKHWo+S++oMlE/e8518ZL7vF51UydIUeaTH3TKzVl9Z9P2yy+6zP1ddI3fOVA915WHRHo/847jP5Y6242RmXB/1Veemn7+v1VdmjXjimh4U1xyRsul/pbxz1rU5Yn/hSg5K1+LTkBYta3Jq6r85/cAfCJv0O0oeuIuGLz98VU086XKT//bXnH3tE8JH9KfG84/wfytUovbr2yD60tlk10XDGBWyeG55D2Lf/Jw75/yXvH9NVZN5uyQR0roJIc0b4Np3BNuqVBybdoLLTeSEYcT+5cmyKqd9uXD/XChxw3fjrr1gozoTHh5Oeno6drsdq9XKTz/9VKYUoXNjhA3oSr2Vn3H6yb/x60VvcmDLHD6bO5jxfxhETHMVtvPZnbgPHOPsPz7CtmQ95ib1SJzxT8IGXtpF7vDAEwvBYoL/3Hr1woircWtLIwvuv4eR++7hzaRT3Hw2Def2vZjqJxKa3ILQds0xRFw9ibQhCx79QZWSfzjiUmmzQCQIfoXApm/PBNbP/Ygffv0uwz/5hsx1W6n/1m8J7dTmsnNcHKkZ5DzzOq7dBwkfPZDoN/+PX/9kYO5eeKIL9K5fMXbdnQRbTsC/dkVyZtzz/H7kAIqmzcGRsp3i2T+V7RfSthnRD91O+NC+WLsnl21ffQQeL72Ip98WeFJGZ57/L870/RV6zNCk5sT97VdX3W/YsGH88MMP3HHHHUyfPp0JEyawZo2qlT579ixTpkzh4MGDhIWF8eGHH5KcnExubi4TJkzg+PHj9OzZ84J8xpdffsl///tfXC4X3bt3591338Vo9OMhQJWAKb4miTP+SdFXP1D8yUKazvmQ3DkfklcvAQqK8BWp5JSICCP2pceJfugORMilZaYFDlUZtycXpo2uuNll/7oZCp3wm/Ta/O2mW5g07pZrfq+UMHU7/GU1NIpRdkUHiJTR1dAdlB/Qq3kosdN+w5v/6MqDs17l+K2PYqqfQPjwflj7dsabm4/ncDauvYcoWbgGY2ItEr58FXf/3kxeoMY2/7EXPNal4mwSQl00MRb4ZBvktujNG1/0JtQE3oIi3PuPYmqQiCn+wtnsPqnUkl9epWSUPh0FdSKvcBKdyzJ+/HheeeUVRowYwY4dO5gyZUqZg3rxxRfp2LEjc+bMYfny5dx7772kpaXx8ssv06dPH1544QV++OEHPvnkEwB2797NN998w7p16zCbzTz++ON89dVX3HvvvVr+ipogDAai7hlJ+3tGsmLdSZb8bzntzuwjqXcM9ZvFYqwdS9jA7phqx172/XvPwEML1LynVwfBgEYVZ5vFpFY9jy+E51eAR8J97a/+PodHzVKbuUv1Ov17CEQGiXMC3UH5DS3j4M8v9+HJVl8Tu34Nww6tpu0n31Pw/rdqB4MBU/3aRD98BxG/f5C1uWH8/Vs4Vghv3wqjKmFUhUHAn/tCfJhq9jtRBE90hf4NI7F0uXBOtNcHP+yHtzer0N5NjeB/QyGi8id4VArXstKpLJKTkzl8+DDTp09n2LBhF7y2du1aZs9WPXQDBw4kNzeXwsJCVq9ezXffKUX64cOHU6OGkgxYtmwZW7ZsoWtXVTppt9uJjw+AOQuVzE29EwhvdDfPLFHX0ODGah7Z5ZQWnB71t/38CvX3/M0dathfRWMxwfvD4clF8OJKWHMEHukMXetcWiFY7FKzpT7bBocL4Olu8Ose6poNJnQH5UckRMCMB6KY3Ws4r20cTt6ZEm517ad2kzjCGyUQH21ixylYMAPyHVArDL4aC90qMUUhBDzaReWP/roG7p8HiRFwe2sl25/vgDyHGm9wIE910//nFiVIWd7YfHVm1KhR/O53v2PlypXk5ube8HGklEyePJl//OMfFWhdcNCtLiy9R/Xn/W8z3PylcjzNa0LzWDAbYOVhWHtM9Tt1TlRzpq4kvFwRhBjhnaHwbqqya9ws6Jig2jOkVCurnBKYswcKXeq1lwYEr2SY7qD8DJMB7mqrhvfNyAjn210dWFkEZ0t1uywmGNIExrSCfg0uVU2uLG5rDSNawNKDMCNDTS49l+WICoWmNeDdYTC0WfA9xWnBlClTiImJoV27dqxcubJse9++ffnqq6/485//zMqVK4mLiyMqKop+/frx9ddf86c//YlFixaRl5cHwKBBgxg9ejS/+c1viI+P5+zZsxQVFdGwYSA3x1QcFpMKjd/RGt7fAttOKtWJwlL9vjoRcFsrGNhI9RNVxfVmNqpm9oc7qdDdR9vUUM9zmAwwrJkaU9OpElZy/oTuoPyUUBNMbq++QMWaT5dArBXCNQqbhRhhWHP1ddaunuiiLeqC0alY6tWrx69+dWmY8aWXXmLKlCkkJycTFhZWNgPqxRdfZMKECbRt25ZevXrRoEEDANq0acNf//pXhgwZgs/nw2w288477+gO6iJqhSsxVVB/16dtUOKCxjHaSQRZzXBveyWqbHOriITJAEYRmLJFN4I4v9rHn+jSpYtMTU3V2gydasTu3btp3bq11mZojv456FQ1QogtUspLyrz0Z18dHR0dHb9Ed1A6Ojo6On6J7qB0dM7DX0PeVUV1//11/AvdQenolGKxWMjNza22N2kpJbm5uVgs5dTK0tGpIMpVxSeEqAl8AzQCDgN3SinzrrBvFLALmCOlfLI859XRqQzq1atHVlYWOTk5WpuiGRaLhXr16mltho4OUP4y82eBZVLKV4UQz5b+/Mcr7PsXYHU5z6ejU2mYzWYaN26stRk6OjqllDfENxqYVvr9NGDM5XYSQnQGagNLynk+HR0dHZ1qQnkdVG0p5YnS70+inNAFCCEMwBvA7652MCHEw0KIVCFEanUOs+jo6OjoXEOITwixFEi4zEvPn/+DlFIKIS6XXX4cWCilzLrc+IiLjvEh8CGoRt2r2aajo6OjE7yUS0lCCLEXGCClPCGESARWSilbXrTPV0BfwAdEACHAu1LKZ69y7BzgyA0b9zNxwJkKOE4woX8ml6J/Jheifx6Xon8ml1JRn0lDKWWtizeW10H9E8g9r0iippTyD7+w/31Al6qs4hNCpF5OQqM6o38ml6J/Jheifx6Xon8ml1LZn0l5c1CvAjcLIfYDg0t/RgjRRQjxcXmN09HR0dGpvpSrzFxKmQsMusz2VODBy2yfCkwtzzl1dHR0dKoH1UFJ4kOtDfBD9M/kUvTP5EL0z+NS9M/kUir1M/HbcRs6Ojo6OtWb6rCC0tHR0dEJQHQHpaOjo6PjlwStgxJC3CqE2CuEyCwtga/WCCHqCyFWCCF2CSEyhBBPa22TvyCEMAohtgkhFmhtiz8ghIgRQswSQuwRQuwWQvTU2iatEUL8pvS6SRdCTBdCVDvJdyHEp0KI00KI9PO21RRC/CSE2F/6b42KPGdQOighhBF4BxgKtAEmCCHaaGuV5niA30op2wA9gCf0z6SMp4HdWhvhR/wH+FFK2QpoTzX/bIQQdYFfoXo4kwAjMF5bqzRhKnDrRdvOCYY3B5aV/lxhBKWDAroBmVLKg1JKFzADJWxbbZFSnpBSbi39vgh106mrrVXaI4SoBwwH9L49QAgRDfQDPgGQUrqklPmaGuUfmACrEMIEhAHZGttT5UgpVwNnL9p8TYLhN0qwOqi6wLHzfs5CvxmXIYRoBHQENmpsij/wFvAHlBSXDjQGcoDPSsOeHwshwrU2SkuklMeBfwFHgRNAgZRSn8yguKpgeHkIVgelcwWEEBHAbODXUspCre3REiHECOC0lHKL1rb4ESagE/CelLIjUEIFh20CjdK8ymiU864DhAshJmlrlf8hVc9ShfYtBauDOg7UP+/neqXbqjVCCDPKOX0lpfxOa3v8gN7AKCHEYVQYeKAQ4kttTdKcLCBLSnludT0L5bCqM4OBQ1LKHCmlG/gO6KWxTf7CqVKhcEr/PV2RBw9WB7UZaC6EaCyECEElNOdpbJOmCDXr5BNgt5TyTa3t8QeklP8npawnpWyE+htZLqWs1k/GUsqTwDEhxLmpBIOAXRqa5A8cBXoIIcJKr6NBVPPCkfOYB0wu/X4yMLciD17eke9+iZTSI4R4EliMqrj5VEqZobFZWtMbuAfYKYRIK932nJRyoXYm6fgpTwFflT7cHQTu19geTZFSbhRCzAK2oqpht1ENZY+EENOBAUCcECILeBElEP6tEOIB1HikOyv0nLrUkY6Ojo6OPxKsIT4dHR0dnQBHd1A6Ojo6On6J7qB0dHR0dPwS3UHp6Ojo6PgluoPS+f/26lgAAAAAYJC/9TR2lEQAS4ICYElQACwFpueunVT6NF8AAAAASUVORK5CYII=\n", + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAnYAAAHWCAYAAAD6oMSKAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy88F64QAAAACXBIWXMAAA9hAAAPYQGoP6dpAADucElEQVR4nOydd3hc9ZW/3zszmhlJo2bLlrtlufeGG9hUG9OMwbQAIQlJSEJCyvLLJkuy6ZuesEk22WRDQiAEQqjGNGOwMeDee+9NliVZbaTR1Pv74zv3jkaWXDW6U877PH48vnfGOpY8d84953w+R9N1XUcQBEEQBEFIeWxWByAIgiAIgiB0DJLYCYIgCIIgpAmS2AmCIAiCIKQJktgJgiAIgiCkCZLYCYIgCIIgpAmS2AmCIAiCIKQJktgJgiAIgiCkCZLYCYIgCIIgpAkOqwPoaCKRCCdOnCAvLw9N06wORxAEQRAE4ZLQdZ2GhgZ69eqFzXb2mlzaJXYnTpygb9++VochCIIgCILQoRw9epQ+ffqc9Tlpl9jl5eUB6h+fn59vcTSCIAiCIAiXRn19PX379jVznLORdomd0X7Nz8+XxE4QBEEQhLThfEbMRDwhCIIgCIKQJkhiJwiCIAiCkCZIYicIgiAIgpAmSGInCIIgCIKQJkhiJwiCIAiCkCZIYicIgiAIgpAmSGInCIIgCIKQJkhiJwiCIAiCkCZIYicIgiAIgpAmSGInCIIgCIKQJkhiJwiCIAiCkCZIYicIgiAIgpAmSGInCIIgCIKQJkhiJwiCIAiCkCZIYicIgiAIgpAmOKwOIFV5+5vPUjiiPxNuHYeryGN1OIKQUhyogS0VUOeP/mqGej+EIjCiG0zoCaO6gUuuUElLcwjWnYDVx8EXAg3QNPW7TYMBhXDtACjOsThQQcgwNF3XdauD6Ejq6+spKCigrq6O/Pz8hHyNUyfqqRt7CzZ0wjY7VYOGk3f1RMpuvoycy0aiObMS8nUFIZWpbYbX98DLO2HjyXM/32mHkdEk787hKuETrGV/DXxwCD44DKuOq+TubGion9/1ZTCzDAZ16YwoBSH9uJDcRhK7i+DEvko2fu/v5K9fR6/qY3Hn/Hn5FP/HZ+j26VvRHFJuEITlR+EfW+C9gxAIq2N2DSb2VNWcfBcUuCHfCRFUJW9DOVT74v+eWWXwlckwpqTT/wkZz7ZT8LPl8NGR+OPdc2FGP/Vz1HXQUb+HIrCuXL2uJSO6wQ+vgkm9Oy10QUgLJLFLcGJnEIrAmjUn2fH6emwr1zFm/zqKmmoBCA4ZRP9ffY3saWMTGoMgJCveAPzwQ/jX9tixEcVwx3C4taSeopAXR6/uaFln3gDpOhypgw0n4b0D8OZelTQAXN0fvjJFJYZCYjlSB79aCa/tVn922GBKb7iyv/o5DO2q2q/tUd4A7x5QSf2KoxCMqOP3jYL/uEIl9IIgnBtJ7DopsWtJcwgW7w2x+vEF3L7wL+Q3NwDgvn0WJT94GEdP6SMJmcPaE/DoIpUY9K0+wmcjW5jmO0De4YMEdh0kXFGtnmiz4ejdHUf/XmT160nWoL54bp9JVp/4stz+GvjDWpi/C8LRK9YNA+HnM6FQkoMOp8YHv1sDz2yJJWNzh8LXp0G/gov/O3+6PJbod8uB710Ftww+e3IoCIIkdpYkdgaNAfj1W7Xk/+8T3LzxdWzo6NnZ9Py/75B744xOj0cQOpNAGH6zGv64Dgae2M1Dq57msu0ftflczeVE9wfOPGGzkXP95RQ8eBvZV09Cs8XE+4dr4Q/r1JxeKAK9PPC7G2FSrwT9gzKQjSfhC2/CSa/68/S+8B9X6AypPkDzys1EGhqJNDWj+5rRm5qJNAdwDi0le/p4XGOGnHMEZfVxeGyxStZBVf5+MxuKshP8DxOEFEYSOwsTO4NlR+B/n9rNx179DSOPbUO32Sn5/WPk3TXbspgEIZFUNMKnF0BkwzYe+OhppuxfpU5oGu7Lx+EaNQjnsDKcwwfgHDoALTebcEU1oSPlBA+fIHj4BM0rNuH7aIP5d2YN6EP+g3PJf+BWbJ6YvHLrKXjkLThUp+b1Hp0KX5yk1JjCxfPPbfDdpSpBH5gf4addtjFww0c0vvURoUPHz/l6W14u7qljyJ4+gdybrySrf9sZtz+kkv8/rFNfa1hXeHaeKGgFoT0ksUuCxA6gwQ8/fD9E6S9/wQ1b3kbXNLr94lEKPnWbpXEJQkdT3QSf/+sp7vzHz5h0YK06aLPhuWMmRV97AOeQ0vP+uwJ7D1P/t/k0PP82kYZGABz9e9L9t4+RfcV483neAHx7CcyPzn9N7wv/PVsN9AsXRnNIJXT/2g7uQBPf3vUcM1a+TqTytPkczeXEfcV4HD2KseW40bLdaLluNJsN/6Zd+JZvJFLnjf2lWQ4KHrqTov/3Sez5bVtC7aqCj78KlU0wsAj+OQ9KxD1KEM5AErskSexADYH/5MMIzp//jnlrXwagy3e/QNGX77c4MkHoGOqa4Zv/vYMH//wtir3V4LCTd/cNFH3142SV9bnovzfS6MP7yrvU/PczhI6eBE2j4HN30uXbn8eW7QLU++vFHSop8YWgJBeemye2GhfC8XrVet1aEWH21oV8ZdmfcZ9WM5C2fA85s6aRe9MMcq6dElc1bY0eDhPYtg/f8o00vrOc5hWb1N9RXEjXxx4i7/6b0ez2M153sAbufQXKvVBaAP+8A3rlJeSfKggpiyR2SZTYgfrw+dEHOvpv/sLHl/8dgMKvPUCXbz2EJlPDQgrT4Idff/s97n72p7hCARhSRr9nfnxJCV1rIg2NVH3vDzQ88zoAWYP60f3338I9caT5nL2nVXKy7zQUZ6vkbmhxh4WQtuyvgXtegh67NvPl9/6HQSdU+dNR2ouu3/48uTfNuGhfzsb3VlH9nf8huE95pDhHDqL4J18l+/JxZzz3SJ1K7o7VQ598Vbm7WJGGIKQjktglWWIHKrn73gfQ/Mdn+fySPwFQ9P8+SZf/+KzFkQnCxdHkj/CPh57kurefBiByzRUMfPK7Z63qXAqN762i8ms/U4pam40u3/wMhf/2gHlzdNoH978KOyqhyA3/uB1GdU9IKGnBwRp44LkmPv7Sr5m1bREAmieHov/3SQofuhPN5bzkr6EHQ9Q9+So1v3xStWk1jeJf/j8KPjn3jOcer4f7XlFzkz098K87oH/hJYcgCGmBJHZJmNiBSu7+832of+pVvvb24wD0+OcvyZ051eLIBOHCaKpv5r27/ouRGz4AIPjp+xj6k8+12WrrSMI19VR96zd4X3oXgPzPzKP4J181lbO1zfCJ+bC5Qhkf/+M2GNsjoSGlJIdr4ZE/H+crT32LssoDYLOR//FbKPrmZ3B07/g+dri6lqrv/h7vC+8A0OU/Pkvho584o2NR4VWVu/01MLwY5t8DbvF5FwRJ7JI1sQOI6PCtJdDtl49z2/pXoaiA/u//FUdvsdMXUgM9EmHJLf9J6dqPCNqzCHz/3xnzhRs7NYa6J1+l6j/+G3Qdz7yZdP/9t02j43o/fOo1WF8OHic8NVfsUFpypA5+9JM1fPG575Pf3IDWrQs9//YjsqeMSejX1XWdmp//lZpfqwpvwUN30vW/vhxnZwMqubvxObV55L5R8NPrEhqWIKQEF5Lb2M56VuhwbBr81zWw5hOPsLvnUKipo/yh76MHz7F0URCShPXfeILStR8RsDup+d9fd3pSB1Dw6dvp/qfvgsOO95X3OPnAY0SamgFVqfv7bWpDgjegKnhbT531r8sYjtbp/P2r/+Tf//bv5Dc3YBs3gn6L/5LwpA5A0zS6/MdnKf7JVwGoe+IlTj38I/RAMO55JR747Wy1Z/a5bcqUWhCE80cSOwtw2ODxW5w8/rEf4nV5CKzdRvWP/8/qsAThnBz7xzsUPf0PADZ8+ZtMmzf+HK9IHHnzZtLzHz9Hy3bRtHgV5Xc9SrhObXzxOOHpuXB5H2gKwmcWxAx3M5XK2iAf3PUj7n3zf7HrERx330zpG//T6VtxCh66k+7/9z0zKS//+H8Q8fnjnjOjP3x5snr82JKYmbEgCOdGEjuL6J0Pj97Vi1/MeQyAuj88T+PCZRZHJQjt07RmG43//nMAFs38OHd+83qLI4Kc66bQ66X/xlbgoXnNVk7M/QrhepXBZWfB/90Cg7so8+TPvK6SvEzE3xxi5Z3fY9rGdwnZ7GR9/9/o9/tvdohA4mIwk/IcN77311D1jV/Teiroa1NgWjQx/+Kb4MvQn50gXCiS2FnITYOh351X8tLkuwCo+NKPCR4ptzgqQTiT4LEKDt//LRyhICuHTmfm/zyEK0mG2t2TR9Nrwe+xd+tCYPs+Kj79HXO0Id8FT94KXbJh2yn4t3fUnGsmoYdCfHTXDxm5+SMCDif2P/+cfl+aZ7nVUs51U+jxzE/BZqPh+bdNOxsDuw1+d4PaKburGr7/gUWBCsJZCB4px/vmh1aHEYckdhbz3Sth8Z0Ps6P3CPR6LxUP/+iMO1dBsJKIt4lDH/sPnLU17O8+ENuvv8Pg4uS6dLhGDKTnP3+BlpON74N1VP6/X5rvo34F8OebwWmHhfvhVystDrYT0cNh1tz/Ywasep+APYuaX/4Xg+ZOsTosk5wrL6PLtz8HQOVjv6F54864891z1R5ZDXh+O7wq83ZCEhFp9HHyE9+i4lPfpv6ZBVaHY5JcV+cMJDsLfntLFj+/8wf4srLxr9mK9+V3rQ5LEExOfuO/se3ex+ncIhZ+/WfcOzk5F3q6xg6l5InvqwrQP9+i9vG/m+cm9YafRdWVf1gLL+9s++9IJ/RwmB2f/RnFS94jZLOz7ds/5PKPT7M6rDMo/PJ95N58JQSCVDz4n4Sra+POT++n2rIAP/hAWdoIgtXous6pr/yUwPZ92Lt1Iee65LEtk8QuCRjeDT47pwf/mP4AAJXf/yMRb5PFUQkCNC1di+/FhUTQ+M39/8W37+lBMi9Lyb3+cop/+jUATv/sLzS8+I557o7h8KVJ6vF/LIYNaTz1oEciHPrKL3G/sZCwZuftL32fOx+ZbnVYbaJpGt1+9xhZZX0IHT9FxRd+iB4Oxz3nkckwpCvUNMPjqywKVBBaUPubZ2hc8D5kOSj6y49w9EoeN3RJ7JKET46BfbfezYnCXugVVdT87lmrQxIynEhTM+WP/gqA+ZPm8ZnPjqFrchbr4ij49O0UfOljAJz66s/wLd9onvv6NLhhIATC8OWFyvMuHTn1k78SeeFNwpqNv3/iO3z+sauTOiG353soeerHSkyxdC01v/hb3HmHDb5/lXr8zBbYVWVBkIIQpXHRck7/9C8AfPDJf+OGnWNYesjamFoiiV2SoGnwretc/HHWlwCo+cPzBA+fsDgqIZOp+fVTcPQElXnd2PeJh7h2gNURnT9dv/swuXOuhmCIk5/8FsGDxwHlI/mrWdA3X+0l/fYStREmnfC+uhjvb1Ub+k/zvsnD372O7Itb99qpuIaX0e3xbwBQ8/jTNC1eHXf+ir5w4yAlfvneB+n3cxNSg8CeQ1R8/oeg60TuvY0f9byVE17Ic1kdWQxJ7JKICT2heM4M1pdORAsEqP7+/1odkpCh+Hfsp+Z/nwfg9zd8ja9fn2txRBeGZrPR/Q//ieuykUTqvFR87vumEW6eS6kt7Ros2AOvpNFAvn/zbsq//FMAnp92L/O+dRO9k28BT7vk3TGL/E/PA6Dym78+w9/u2zPAZYdVx+CtfVZEKGQy4boGTj7wGLq3Cfe0sfz86q8Q0VUXYGJPq6OLIYldkvEf0zX+cuNXCGs2Gt/4AN+yDVaHJGQYeiRC5aO/RAuFWTZ0Bv3vvJLBHb8+NOHYsl2UPPEDbIV5+DftovrHfzbPTegJ/xaddf7O+3Co1poYO5LQqdMcf+Bb2Px+Vg+civ5vn2d6P6ujunC6fvcL2Ht1J3S4nNo/PBd3rm8+PHyZevzjj8TbTug8dF3n1Bd+SPDAMRx9Sjj+4x/x7tEs7Bp84wqro4tHErsko1cezL6xjNcn3gZA5bd+ix6SdWNC51H/1Gv412+n0ZnDX2/5mpkApSJZfUro/ruoCfj/Pk/juzGvky9eBlN7Q2MQvrJQzd2lKro/wMlPfRu9/BRHuvbjX5/5Hl+fYbc6rIvClptN8Q/USErtb/9xxkjKFyZC7zw43gB/Wm9FhEIm4n3xHZreW4XmdlLy9E/4yY4iAO4dBQOLLA6uFZLYJSFfmAhv3vBp6t15BHceoL6VcacgJIrQySqq/0utt3vymof42HXdKU4BwcTZyL1xBgWfvQOAU1/+MaHySkAZ4P73bChwweYK+O8UVVvquk7lNx7Hv3YbXpeHH9z7U352mwd3khhIXwy5c6/BPX0CenOA6u/+Pu5cdpZqyQL8cR0crbcgQCGjCJ+uoyr6/7Do6w+yOHsIG09CThZ8NXlsIU0ksUtCcp3wxVkFPHXVZwCo/ulfxf5E6BSq/vN/0Bsa2dlrOGuvvp1PW7cKtkPp8r2HcY4aTKS6TpmAR+00euXBz2eq5/xxHSw/amGQF0n93+bT8JxSwP5w3ve5b04/hnfu+tcOR9M0uv30a2C30/jWR2cIKW4aBFP7gD+sWrKCkEiqf/hHItV1ZA0bQO7nP8YvV6jjD41XJtrJhiR2Scq84XDwhrkc7dIHvaYuqVythfTEv3k3ja8tIazZePymf+cbM+wpXfVpic3touSJ76PlZNO8fCM1v3nGPHfjILhvFOjAN95LrX2y/p0HzErCn697mKyrpvDZNEnGncMGUPCQqrRWfes36P6AeU7T4AdXqY0Ub++DnZUWBSmkPb6Vm2l49k0Auv3q6zy/28HBWijOhs9NtDa29pDELkmxafCf1zh4Yeq9AFT/4QVT1ScIieD0r54CYMnI68gbO5hbh1gbT0fjHNSPbr94FICaX/wN3+ot5rn/nAF98pQFSqoY4Eaa/Zz6wg/AH2DVoKm8M+MeHr9eXTvShS7f+DT2bl0IHjhG7Z9eiDs3rBhuHqwe/+86C4IT0h49EKTy35WXZ94DcwiPH8Nvo8Xjr0wBj9PC4M6CJHZJzOTe4LtlNtWeLlBRSYOsGhMShH/zbpoWLiOs2fjHjE/xnzNIakPbiyXvnhvw3D0bIhEqv/ZzIs3KTiPXCT+6Rj3nrxth6ykLgzxPTv/o/wjsOEBNbhG/vOUxfnitRq88q6PqWGx5uXT9/sOA8rYLnYj/wXwxuknkjb3poWwWkova/32e4O5D2IoL6fqdL/Dn9VDlg9ICVeVPViSxS3I+N83Fy5PvBqD6d8+hRyIWRySkI6d/qZz+l4ycSe+x/ZjU2+KAEkjxj7+KvXsXgvuOUNNin+y1A+DWIcoA95vvQSiJ32pNi1dT9+cXAfjFnMcYPaILtw21OKgE4blrNu7Jo9Gbmqn+0Z/izo3sBteUqp/Z/4lCVuhAgodOKJN2oPiHj+DNzueJ6BKbb1wBWUksOpfELsmZ2hsO3TAXryuXyL7DNC1aYXVIQprRvGkXTe8sJ6zZeGbGJ/n8BKsjSiz2wjyKf/ZvANT+z7P4t8ecbr97pVLJbq9UlbtkJFxVw6kv/wSAVy+bx6ah0/iva9KzwgpKSFH8k68C4H1lMcEDx+LOfynqa/fSTjjp7ezohHRE13Uqv/k4enOA7Csn4rnzep7bpuZvh3VV4p1kRhK7JEfT4FMzPLw+YS6gqnaC0JHURKt1i0fNImdIP64ptTaezsAz52pyb74SQmEqv/Zz0yuyW66atwM1a3ekzsIg20DXdU597eeEK09ztHsp/3fdF/nyZOhfaHVkicU1dig5102FSISa38dfAyf1hsm9lA/hE+LnLnQATYtX41uyGpxZFP/iUYIRjb9tUucempD8N1GS2KUA15fB2ll3EbBnEVy7NW7oWxAuheaNO2latIJItFr3uRS4aHUUxT/7N2z5HvybdlH355fM43eNgGl9oDmUfLtk659+TVVXHVn8YO736Ffi4vNJqszraAr/7QEAGp5/2/QiNPhSdNbuuW1Q4+vsyIR0Qtd1an75JAAFn70D58B+vLEXKhqhWw7MSQFRmSR2KYDdBh+7pph3R88GoEaqdkIHYVTr3h11PaG+fZmbpnNabeHoUUzX738RgNM/+wvBg8cBldj+9Fq1k/TDI/DqbiujjBE8VkH199T+6P+75vMcKBnET64FZxLP+nQk2VPG4J42FoIhaqN7jA2u6q/m7ZqC8NRmiwIU0oKmxavxb9iJlu2i8JH70HX4S7QS/Mmx4EoBC6iEJ3Z/+MMfKC0txe12M2XKFNasWXNer3v++efRNI3bbrstsQGmCLcPg3evvZcIGr5FywnsOmh1SEKK07xhB03vriRss/PMjE/y6fGZkyQY5H38FrXhwOen8uu/RI+W5wYUxRzl/+tDqPef5S/pBHRdp+obv0Zv8rFvwBhennwXd42AKWkscmmLoq+pql393xcQrq41j2tarGr3t03gDZz5WkE4Fy2rdfmfvh1HtyJWHVczt24HfHy0xQGeJwlN7P71r3/x6KOP8r3vfY8NGzYwduxYZs+ezalTZ/cSOHToEF//+teZMWNGIsNLKdwOuHFmP5YNuxLgjDkTQbhQYtW6WdT36JPU8v1EoWka3X/972huJ74P19Pwz7fMc5+boHZAVvvg92stDBJofO19mt5dScTh4L9m/zuF2Ta+Nd3amKwg+5rJOMcMQW9qpu6Jl+LO3TAQygqhzg/PbrUmPiG1iavWfek+IDa3eedwKMq2MLgLIKGJ3eOPP85DDz3Egw8+yIgRI/jTn/5ETk4OTz75ZLuvCYfD3H///fzgBz+grKwskeGlHPePgdeuVP/ZGl569wxPJ0E4X/w79tP03irCNjv/mP5J7hsF+S6ro7KGrLI+FH0zur7vR38iXK+klVn2mJDiyY3W+aSFaxuo+tZvAXh++gMcKS7l2zOgS4p8yHQkmqaZVbu6v7xMpKHRPGe3wcNRhewTG8AfsiJCIVVpq1q3vwYWR5tjn0mhjS4JS+wCgQDr169n5syZsS9mszFz5kxWrlzZ7ut++MMf0r17dz7zmc8kKrSUpcAFk28cweZ+Y9HCYepbVBcE4UKof/JVAD4aeiWnivvw6XHWxmM1hZ+/m6zB/YlU1VLz66fN49eUqvmtYMS6naTVP/hfwpWnqenVn6enfpzxPVT1IFPJvflK9bOq81L31Py4c7cNgx4eqGyCd/ZbE5+QmrRVrXsyank0cwCUFVkY3AWSsMSuqqqKcDhMSUlJ3PGSkhJOnjzZ5muWLVvGX//6V5544onz/jp+v5/6+vq4X+nMp8fBwglzAKh+5i0xLBYumHC9l4YXFwHw2sTbmTsUeqbZxoILRctyUPyjLwNQ9+cXCew7oo5rqmpn12DRAVh2pHPj8i3fSMM/3gDgB7O+QdDhTNutIOeLZrNR+JX7Aaj74wtEfLEBSKcdPjZSPX5umxXRCalIW9W60z54cYc6/1CKeXsmjSq2oaGBBx54gCeeeILi4uLzft1Pf/pTCgoKzF99+/ZNYJTW08MD+XOuwuvKxXa8HN8yMW4SLoyG5xeiN/k42G0Am/uP43MpdtFKFDnXTSFn1jQIhan+7u/N40O6wgNj1OMffdh5GykizX4q/98vAVh75Vy29B3DzYPhsl6d8/WTmbw7ZuHoU0K48nTcXCTA3SPUvtyVx+BAjUUBCilFW9W6f2wFfxhGdU89kVLCErvi4mLsdjsVFRVxxysqKujRo8cZz9+/fz+HDh1izpw5OBwOHA4Hf//731mwYAEOh4P9+9uuqz/22GPU1dWZv44ePZqQf08ycecEN0tGqhZ3zT+kHSucP7quU/831YZ9beLtTO2rMez876PSnq4/egQcdpreXUnT4tXm8a9NUaMQu6rh+U6qBNX+9zME9x8lVNyVH07+PE47/McVnfO1kx0ty0HhI+oDuPYP/0QPh81zvfPh6v7q8fPbrYhOSCXaqtb5Q/D3qG3OZ8enXoU8YYmd0+lk4sSJLF682DwWiURYvHgx06ZNO+P5w4YNY+vWrWzatMn8deutt3LNNdewadOmditxLpeL/Pz8uF/pzpTesGn6LQA0vfEB4doGiyMSUgXfR+sJ7juCz5nNe6Ov556RVkeUXDgH9qPgoTsBqPrO/6AH1QR+UTb821T1nF+vUsrLRBLYf5Sa/3kWgCdu/hqN7jweHAf9ChL7dVOJvPtuxlaYR+hIOb6l6+LO3Re1pXhxh4gohLPjW7LmjGrdogNqTrMkF24ebHGAF0FCW7GPPvooTzzxBE8//TQ7d+7k4YcfprGxkQcffBCAT3ziEzz22GMAuN1uRo0aFfersLCQvLw8Ro0ahdPpTGSoKYWmwdQbhrK/exm2YADvy+9aHZKQIhiiiXdG34AjLzfpdx5aQdH/+yS24kKCew9TF/1+gfKwGtQFTvvgd6vP8hd0ANXf+wMEQ5yeNIUX+1xFl+yYT5ugsGW7yLtLmbbXP/N63LlrStWH8mkfvHvAguCElMGwzcn/xK04uimFhDFbd/fI1PT2TGhid8899/CrX/2K7373u4wbN45NmzaxcOFCU1Bx5MgRysvLExlC2nLHCI13xqmqXeXf37Q4GiEVCJ04RePbywBYcNnt3DpU+SMK8dgL8uj6rYcAqPnFk4Sr1KBWlh2+E7U/eWpz4vbINi1ZTdM7y8Fh5/uXfxk0jX+LtoKFePIeUEKyxneWEaqoNo87bJjVaBFRCO0RPHicpiXqLq3gM3cAcNILH0VFUqmqPk+4eOKRRx7h8OHD+P1+Vq9ezZQpU8xzS5cu5amnnmr3tU899RTz589PdIgpSfdcCN48i6DNgbZjD/6te60OSUhy6p9eAJEIW/qP41C3AdKGPQt5992Mc9RgIvVeTv/8r+bxq0thRj8loPjvVR3/dfVgiKrv/A8Au268g215/RnUJdZaFOJxDS/DNWkUhMI0PP923Ll7RoIGLD9qnQehkNzUPTUfdJ3sa6eQNUApJF7eCREdJveC0kJLw7tokkYVK1w4c6cWsnyoKiHURC0RBKEt9EDQbFe9OnEew4thdHeLg0piNLud4h9/BYD6Z94gsD/mc/KNy9Xvr+6C3VUd+3Xrn5pPcM9h6FLAt0d8CoBvT1cVKKFt8j+uOhf1/3gjzv6pT77yIITOE7wIqUOkqZmG51S3q+AztwOg6/BStA175wirIrt05HKRwlxVCiun3gxA/YvvEmm2eKGlkLR43/iAcOVp6vK7smzoDFXNSDGlV2eTffk4cq6/HMJhTv80VrUbUwI3DgIdJaToKMLVtWZ18KPbH6ImK4+pvdW8mNA+nrnXYsvLJXToOL7lG+POtRRRBMJtvFjIWLzzFxOpbcDRryc51yll1PpyOFAL2Y7UFE0YSGKXwjhsMHTOZVTkd8fe0EDj2xZZ4wtJjyGamD/2VuxOB7cNtTigFKHLtx4CTaPxtSX4N+82j/+/acor7Z39sLFtv/UL5vTPnyRS54VhA/mvHqoK9fXLJQE/F7bcbDx3KPunhlYiimtL1dhKlYgohBboum5eE/M/ORfNrhQSL+1U528eDJ4U1mtKYpfi3DPGzjtjbgSg8mkRUQhn4t+xn+bVW4jY7bwx4VauL0udZdZW4xo5yEwaqn/8Z/P44C5wxzD1+BfLL/3r+Hfsp/7p1wB4Zd5XCWp2ru4Pk8SM+LzIf+BWALxvfki4utY8nmVXhsUA/5R2rBDFv2EH/s270VxO8u9XXS9fEF7fo86nchsWJLFLefoVQOUNNwEQWbGe4BFRGQvxNPxrIQBrhlxBdV6xiCYukC7f+Aw47PjeXxPX6vvaVMiywYpjl7ZqTNd1qv/zdxCJoF9/FX+wqW3jXz/T7lNoB9eYIbjGDoVAkIYXFsadM0QUHx1JnJJZSC0MG6Pcuddi71oIwML94A1A3/zU2zTRGkns0oDZV/diQ+kENF2n/oVFVocjJBF6OIz3lfcAeGvkbHp54Ir03rrX4WQN6G1WhKp//Gd0XQfUcP790RmuX65Qg9cXQ9O7K/B9tAHN5eSJWV9EB24YCKNLzvlSoQWG9Un9M2+YPyNQN78z+qnHr+6yIjIhmQhX1+KdvwSAgk/fZh43vOvuHK7GLFIZSezSgNkDYcW46wE49epSa4MRkgrfik2ET1bhy/GwetBU7hoBdnnXXzBFj34CLduFf+025TEX5ZFJkJMFmyqUW/2FoofDnP4v1eIN3H8nz9f1QkPN8AkXRt68mWg5boJ7D9O8emvcubnRmVKj1SZkLvXPvgmBIK6xQ3FNUD3XY/WwIrqN9I4Ub8OCJHZpgdsBXedMJ6zZce7ZR2B/+u/LFc4P70tqK8mSIVcTdDi5S9qwF4WjR7G5auz0T54wd5N2y4VPj1PP+eUKCEfa+QvawfvyuwR2HsBW4OE3Yz8OqCRkSNeOijxzsOXl4rntOgDqn1kQd+76gWqDwN7THW9RI6QOejhM/VPzAbUXVosqk17eqVTul/dRrdhURxK7NGH2xAI2lk4AoG7BBxZHIyQDkWY/ja8vBeDd0dczLU0uWlZR+OX7sRV4COw8YLa3AT43UW2F2Hv6wipCuj/A6Z+r5ePeT97P25V52LXYTlrhwsn/RHQTxYL3CdfFdmjnu+DqqKfd6+LlnrE0vbeK0NGT2ArzzJuAiB5Tw96VBtU6kMQubZjYCzaOuxqAyleWWhqLkBw0vbuSSEMjNUXd2dpvLLcOsTqi1MZemEfhl+8H4PTP/4oeVNvlC1zwWXVPxe/Xqg+K86H+7wsIHSnHXtKVXwxS1cC7RqSu230y4JowgqyhpejNAXN9nsEt0f//r++++HlIIbWpjxr55913M7YcNwBrTyhRjccJN6TJ7mxJ7NIEmwaFt8wgrNlw7dpN8PAJq0MSLMb7smrDLhw2E5vNxuyBFgeUBhQ8dCf2bl0IHS6n4YV3zOOfGgv5TlW1e3vfuf+eiLeJ048/DUDtZx/kw1NunHb4yuRERZ4ZaJqG57ZrAWh8bUncuZkD1NjKoTrYVmlFdIKVhE/X0fTeSgDy773JPL4gak954yA1L5sOSGKXRsy8rIjN/ccBUCvt2IwmXNtA47vqIrZ41Cym9oGuORYHlQbYctwUPnIvADW/eQY9pKp2+S741Dj1nN+vOXdFqPZP/yJSVUtWWR9+3Vv5aN09AnpLq/yS8cxViV3T0rWEa2Pt2FwnXDdAPRYRRebhXfA+hMI4Rw3GOUz9RwhHlMk4wC0pvGmiNZLYpRETesKmaDv21MvvWxuMYCmNry+FQJATPcs4UDKIm9KkxZAM5H9yLrauBYQOHY+btfvMeMjNgh1V8N7B9l8frqqh9vf/BKD+4Yf48IQDhw2+cFmiI88MnIP74xw5EEJhGt/8MO6c8eH9xh5px2Ya3heVFVjenbPMY+vLobJJVdsvTyMbKEns0gibBgU3X0kEDff2nQSPdtCuIyHlaIi2Yd8cNgublj6zI8mALTebwoc/BkDN4383FbKFbvjEGPWc/zlL1a7mN8+gN/pwjR3Kb4quBuD2oSJs6Ug8t6qqnbdVO/baASr5Pt4AG+TymDEED5+gec1W0DQ882aax42xiVlR1XS6IIldmnHdlK5s7ac+XUQdm5mEjlfQHN2QsGTkTKb0hmJpw3YoBZ+Zh60on+D+o6rFE+WzE9Qc1+YK+LCNbRTBoyep+9t8ALxf/jzvHrKhAQ9P6py4M4XcudcA4PtwPeHTsXUTbgfMKlOP35B2bMbgfVlV1rOnj8fRsxugRE5GYndjmt34SmKXZozvAZvGXg3AyZeXWhqLYA0N0fbgvkFjqSjswc1pNDuSLNg8ORR8/i4gWrWLKAO74pzYNorfrT6zalfz+NMQCJI9YwK/d6ps7pYhMLCo00LPCJwD++IcPRjCYRrfjL/BnRNVx76x98J9B4XUQ9d1Gl5WbVjPndebxzdXQLlXVXCNzSTpgiR2aYZNg/w5VwGQvXUboXKRf2Uahhr2tSHXqzasqGETQsFn78CW7yG462DcLNfnJ4LLDuvKYeWx2PODR8ppeP5tAHxf+qxZMfqSzNYlBENE4X0tft54Rj8ldjnVqKwuhPQmsHUvwT2H0VxOcm+5yjxuVOuui6ql0wlJ7NKQ66Z1Y1ufUYCoYzMN/84DBLbvJ+Jw8OHwq5ncW21HEDoee0EeBQ/dAUDNr58295OW5KrF8wC/WxN7fu1v/wGhMNlXXcYfw6PRURYcw7t1cuAZgpHY+T7aQLiqxjzucmBa/4g6Nv1peElV63Kuvxx7vgdQlfS3okbV6daGBUns0pJxJbA5qo4tf2mppbEInYv31cUAbB8+lYbsfG5Ow4tWMlHw+bvRcrMJbN8Xt0P24csgy6YqdmtPqNm6+ufeBCD0xQd5Neqd9Yj41iWMrNJeuMYNg0gE7xvxN7iGWffb+yAk7di0RQ+HTeV63l2xNuz2Sjharyp1V5daFFwCkcQuDdE08My5GoDszVsIVVRbG5DQaTS+/REAC8quQUPUsInGXpRPwWfmAVDzq6fMql2vPLgzup7oT+taVOuunMiftTGEIjC9r5qJFRKHIaJo3Y69vC90yYZqH6yU1dppi2/ZRsIV1dgK88i5Lrar761oG/bq/uljStwSSezSlGsvL2FH7xFouk7t6x+e+wVCyhPYf5TgroNE7HZWD5rGlN7QXdqwCafw4XvQctz4N+/G936s9/q5CaABWzZVUBet1vHIg7ywPfpQqnUJx3OrSuyaV2yKu8F12ODGaDv2rfPYFCKkJt5oG9Yz91o0p8rgdB3ejrZhb0pTYZkkdmnK2BLYMlYNip545SOLoxE6g6aFajfmvkHj8Wbnpe1FK9mwFxeR/4BaPl/7h3+ax8uK4PqBcO+Kf6AFQ7inT+DvWWPxh2FiT5ja26qIM4esfj1xTRwBkcgZZsXXRxO7xQfFrDgdifj8Zgvec0fMlHhPNRyoVb5115ZaE1uikcQuTdE0yL/hcgBcGzcRafRZHJGQaIyl528PmIFGeg4FJysFX7gH7HZ8H67Hv3m3efzh3hXctFEtHm/+wqd4dqs6/sXL1HtUSDweox07f3Hc8al9VBuuolF2x6YjTe8sR/c24ejbA/eU0eZxQw17ZT/Ic1kUXIKRxC6NmXZlf8oLeuIIBfF+tMHqcIQEEqqsUc7qwIrBVzBZ2rCdSlafEjy3KxVm7e+fM4/3+eezZEVCbOw/nv9sHE99QHnWXTvAqkgzD7Mdu2oLoZNV5nG3I+ZftviAFZEJicTYvuOZNxPNFkt13kpTU+KWSGKXxozvqbF56BQAjixYZXE0QiJpemc56DrH+w2lsqDEtHMQOo/CL90HgHfBUoKHTxA6cYr6Z1W17u8zHjSH9B+aoPwmhc7B0bsE16RRoOumuMjgumiCvfgsu32F1CPibTLnXVuuEDtQA7ur1YylsYEkHZHELo2x2yA4QymBwh+sNhV7QvphfGC9WzYDSN/ZkWTGNWoQ2VdPgkiE2j/+i9r/eQ4CQdzTxnJy5Hh0IMcBtw+zOtLMI3f2FQA0vbsy7rjxPtlyCiq8nRyUkDCalqxG9wfIGtAH5/BYBme0Ya/oCwVui4LrBCSxS3MG3jCBgD0Lz6lygnsPWx2OkAAi3iZ8H6wD4MMhMxhQCANkRZUlFH5ZVe0ann2Dun+8ro49+ins0SutzSbVOivIuV7NG/s+Wk+kqdk83i1X+X4CLDlkQWBCQjBudHNumo7WYpj1vWhlNt07GpLYpTlXDstma/9xAJx4U9qx6UjT0rXo/gB13XtxqNsArim1OqLMJXvGRJxjhqA3B6A5gGvcMLYMnMiJBnXeG5BtB1bgHDYAR58S9OYAvmXx88bXRQs60o5ND/RAkKZFqjKbe+MM83iNDzaWq8fp3tGQxC7NKXBD+QTVjq1auNriaIREYNydfjRkBmha2l+0khlN0yj47B3mnwu+cDd/3qAqBoYZ8f+tF3uNzkbTNHJmTQOg6b34duzM6JzdR0egOdTZkQkdjW/FJiL1XuzduuC+bKR5/MMjoAPDukLPPOvi6wwkscsAimarxC5vy2Yi3iaLoxE6Ej0UomnRCkDN1+VkwWTxR7OUcHWt+fjk8QaWHFJGxT+6BnKz1PD2UpmK6HRyZkYTu3dXxs0bDy+GXh6V1K2QLRQpT+Nb0TbsDVeg2e3m8aWH1O/puEKsNZLYZQBTp/flRGEvHKEgtUvF9iSdaF61hUhtA/78Arb3GcX0vmrJuWANeiBI3Z9fMv/c8OcXsEXCzB4Io7vDvaPU8SfkbdjpZE+fgOZ2EjpWQWBnzN9E02L2M9KOTW30SITGqFF7yzZsRIcPojdTmTCqIoldBjC4q8b2Eapqd/h1mbNLJ4y7080jriBis4s/msU0vLSIcHkltu5doTCPgorjTN/9IZ+bqM4/OE6JJ5Yfhd1VZ/2rhA7GluMme/oE4Ex1bEvbE2mTpy7+zbsJl1ei5WaTPWOCeXxrhdoLnOdUW1/SHUnsMgBNA+1K5Wdn+3CV2J6kCbqum3enC/qpu9NMuBtNVvRIRFmcAEVf+hg7r58HwCc3vmh+mPTJjyny/rbZiigzG0Md2zqxu7wvZDug3As7JOFOWcw27HVTsbljayXeP6R+n94PsuxtvDDNkMQuQxh68wQCdieeqgoCuw9ZHY7QAQS27SN09CRht5t1AyYxohh6eKyOKnNpfHsZwX1HsBV4cN53K78acBtBm4MBB7bGrRn79Dj1+ys7lVJP6DxyZqnErnntNsI19eZxt0N96IO0Y1OZxrfUPuDcm6+MO27MtF7dv7MjsgZJ7DKEqYPcbC0dB8Ch16Qdmw4YatiDIycRyHJxjbRhLUPXdWp/9w8A8j89jzeO53Awq5jVY6Nrxv78ovncSb1gVHfwh+G5bZaEm7Fk9SlRhrWRCE3RzQQGZjtW1oulJIF9RwjuOQxZDnJmTjWPVzfBppPqcaZ0NCSxyxDcDqierFRhte9KYpcOGGrYt/pNB9LfmymZaV6+Ef+GnWhuJwWfvdNss2Y9eBcA3lcXm3tKNS1Wtfv7FgiGLQg4gzHVsdH3j4Exn7qpAk41dnZUwqVi3OhmXzEee36sdWHYnIwohpIM6WhIYpdBdL9RzdkVbtsiticpTriqBv8W5XT7Qf8pFLpjPmlC51Pzu2cByLvvFjYEi9hRqW6mbrp9GO5JoyAYov7p18zn3zIYuuXASS8s3G9V1JmJ6We3ZDV6OJZVl+TCmO7qsTGTJaQOjW9H1bA3zYg7bvwsM6VaB5LYZRSXT+/L8aLeOMIhTr23zupwhEug6YN1oOvUlg6ixtOVq/pjrq0SOhf/zgNq4bjNRuEX7+Fvm9TxecOg0A0Fn1NVu/qnXyPS7AeUJc39o9XzntzU+TFnMu5JI7EV5hGpqad53fa4czNlC0VKEjpZhT/6s2xpcxKOwIfGfF2pBYFZhHwUZBA982DvKDV7cOR12UKRyjQtUfNBa8omA9KGtZK6/3sBUAPbp4p68U60AvepsZjH7b26E66swfvqYvN1948Gpx02lMdmgITEozkc5Fyruhet1bHGcP2KoxCKdHZkwsXS+M5y0HVcE0fg6FFsHt9cATXNkO+ECRlgc2IgiV2G4bhGJXZZy1eL7UmKous6vqUqsXun52RsGlyVIWqvZCN06jTel94FoPDhe3hmC4R1uLwPDI1+vmhZDgo+o6xP6v78kvm+654Lc4ao50jVrnNpb73YqO5Q4IKGAGypsCIy4WIwbE5aVusgtm1iRn9wZFC2k0H/VAFg5I1jCdoc5FdXEDhUbnU4wkUQ2L6f8KnThN1utvUdzfgeUJRtdVSZSf1T89H9AVwTR6CPG8U/o529B8fFPy//gTlo2S4C2/bSvDJmYGc87829UOHtlJAFUBU7m43A9v0Ej8UyOLtNedoBLJP1YilBxNuE76P1QPvzdZlic2IgiV2GMa4sm719hgOwf6HsNUpFmpaoNvrBIeMJOpzShrWISLOfur+9CkDhF+5h/m6obYa++THrDAN7UT55d80GiFs5Nrq7sj8JReCZrZ0WesZj71KAe+II4Myq3Yyon92yI50dlXAx+JZtgGAIR2lvnINjGVxVE2w5pR5fVWpNbFYhiV2G4bDB6TFq1crpJZLYpSKG/9aSPmq+7soMuxtNFrwvvUukqhZHnxJybr6Spzap458c27aQpeChOwFlyxA8EquWG1W757aCP5TYmIUYhu2Jb+nauOPToxW7DeXQGOjsqIQLxZg3NuYmDQzRxMhuSvGcSUhil4HkXaUSO8/GDTJnl2JEvE00r94CwLL+k8l3qQuX0Lnouk7tn/4FqIRtdYWDXdVqLdXdI9t+jXPYALKvngSRCHV/fcU8Pnug2hhS7YO39nVG9AJA9lVqga9v2YY425N+BWr1WzACa05YFZ1wPui6TtMS5cuac+3kuHOZaHNiIIldBjLq+pEE7E7y6qpp3C2DJKmEb8VGCIZoKunJsS59mdZHbE6swPf+GoK7D6HlZpP38VtMi5M7hqvh+/YoeOgOABqee5OIT1mfOGxw3yh1/pktCQxaiMM1dii2vFwidV78W/eaxzUtVrWTdmxyEzxwjNDhcshykH3FePN4RI/NSGbafB1IYpeRDO7pYm8/VVbY+7a0Y1MJo+2wdehk0DSu6GtxQBlK7Z+UxUn+/bdQoXl4N7qG6pNjz/66nOum4ujXk0htA975MeuTe0epBG99OWyvTFTUQks0hwN3NBnwfbg+7pzM2aUGvuhYSvbUMdg8Oebx3VVw2gc5WTAuA43bJbHLQDQN6serC1r9hxstjka4EIwL2Ts9VdtBErvOp6UhccFDd/LPbapCMLUPDOl69tdqdjv5n5wLQP2Tr5rHu+fCjYPU439I1a7TyJ4Rbcd+FJ/YGcrYXdWyXiyZMYRk2a3m61YcU79P7gVZ9s6OynoksctQulyt5uwKNm+UObsUIXj4BMEDx9Dtdtb0m0hJLgwssjqqzMM0JL5pBvTtxT+3qeMPjD6/1+ffdxM4s/Bv2kXzxp3m8QfGqN9f3QV1/o6MWGiPnCtVYte8egu6P6aU6JIdm11dIdMqSUmk2Y9vuSpM5FwTP19n/Mwuz9AbX0nsMpQxM4fT7HCR11BD7dZDVocjnAeGGrZq6CiaXLlc3ldVX4XOI1xVYxoSF3zhbhYdgMomtff1+oHn93fYi4vwzL0GiK/aTe6lKn6+ELyys71XCx1J1tBS7N27oPv8Z6wXmy7t2KSmefVW9KZm7CVdcY6IvflCEVh9XD2WxE7IKPp0c7JvgCoR7H1r/TmeLSQDRht2bZm0Ya2i/h9vKEPiccNwTx5tih0+NkqtBztfCh68HQDv/MWET9cBKkk3qn7PbAEppCceTdNi7djWc3YtjIrlZ5F8NL2v2rA5105Ba3GHu+2U2hyS74IRxe29Or2RxC6DaZ6o5uwaP5I5u2RHD4Zoin7wvFEiiZ0V6KEQ9U/NB6Dgs3ewr0Zj5TGwaTFV6/niumwkztGD0ZsDNDz/tnl83nDIzYL9NbE5ISGxGIldU6s5u0m9wWWHci8cqLUgMOGs+Az/unbasJnsGJCh/2wBoPhaNWfXZdsm9IhsvE5mmtdtR/c2ES4sYHePIQwohF55VkeVWTQuXE7o+ClsXQvInXsNz0Y3RVw34MJ/FpqmmVW7ur/NN99/HqdK7kCsTzqL7OicnX/DTiINMaWE2wETo4vjpR2bXIROnCKw8wDYbGRfdVncuZXRG6JpfSwILEmQxC6DGX/dUJqc2Xia6jm5br/V4QhnwZivOzJqErpmk2qdBdT95WUA8h+4lWa7i5d3qOOG6OFC8cybiS3fQ+jQ8bjtB0Y7dtF+KG+4lIiF8yGrbw8cpb0hHMa3clPcObE9SU6a3lfvF9eE4di7FJjHA+GYqXQmXyMlsctguuQ7ODhQGW/tEz+7pMaYr/ugr2o7ZOpQsFX4dx6geflGsNsp+NRcFuyG+gD0L4h9+F8ottxs8j52I6CqdgZDi2FKbwjrmIpbIbEY6tjWc3aGgGLlMTWULyQHhs1J6zbsppPQHILibBjcxYrIkgNJ7DKc4CQ1ZxdYIXN2yUq4tgH/5t0AvNl9EpDZbQYrqI+uAMu9cTr2XiU8E23D3jdazdhdLPkP3gZA06IVBI+eNI8bVcDntkEw3MYLhQ6lPT+7kd3UJpGGAGypsCIyoTV6KITvA1Wxa70ftqXNSSY7Bkhil+H0nKUuaN23byISlA3kyUjzqs2g6/j79aM6r5gR3ZTPltA5hGsbaHjxHUCJJjZXKOWdyw53j7i0v9s5qJ+a8YpEqP/7AvP47IHKQqWyCd47eGlfQzg32dOjN7g7DhCqrDGP222x6vhH0o5NCvwbdxGp82IrzMM1fljcOUNwlOkdDUnsMpyxVw+iwe0hx9/I4ZV7z/0CodPxLVNt8v2D1YdPJs+OWEHD82+hNzXjHF6G+/Jx5maImwd3TIKdHxVRNDz7BnogCCjrFCNplHZs4rEXF+EcqVZ/+Ja1vV5suRgVJwXmtomrJqHZYx5DviBsKFePL8/wjoYkdhlOrtvO0SHjADj4jszZJSO+5ZsAeL9kHCCJXWeiRyLU/VWZCBd89g7q/RoL9qhzH79I0URrcmdfgb2kK+HKGhoXLjOPfyxqofLhYThS1zFfS2if7Hbm7Iz324ZylTwI1mLO110bP1+37gQEI9A7D/oVtPXKzEESOwF9iqoERVbKnF2yEa6pJ7B9HwBLuo/HYVMbCoTOoWnxakKHjmMr8OC5Yxav7AJ/GIZ1hQkdtFxcy3KQd9/NANQ/87p5vF9UmKED/9rezouFDqO9Obv+BdDDo5KGjSfbeqXQWYRP1+HfuAtoY77OaMP2yez5OpDETgD6zlJ+dj12byEckDm7ZKJ55SbQdXz9+lPj6cq4Esh1Wh1V5mBYnOTddzNaTjbPR9ui947u2A+P/PtvBk3Dt3QtwUMnzOP3Rqt2L+wQEUWiyZ42Fhx2QofLCR6O/Qw0LXYztea4RcEJAPg+2gC6jnN4GY4e8WslMn0/bEsksRMYOb2Mhuw8sgM+9i6TObtkwmjD7hkk83WdTWD/UXxLVoOmUfDpeWw8CbuqlWji9qEd+7Wy+vci+2qleK5/9g3z+KwyZd1wqhGWHOrYrynEY/Pk4J6gBhtbt2On9Fa/rznR+lVCZ+JbrsaFsqdPiDte74ctp9RjcQyQxE4AnFk2TgxSpYGjS7daHI3QEt9y1R5f0j2a2F2kZ5pw4Rgq1ZyZU8kq7WWKGG4ZAgXujv96+Z+4FYCG595EjyrUnXa4KyqieE7emgnHnLNr1Y6dHE3s1pcrE1zBGnzL1PUwe0Z8Yrf2OER0KCuEnrKRRxI7QREZr+zuQ+vk0yNZCJ+uM+frlvUYh8sO40osDipDiPj8NPzzLQDyP3UbDX54PSqauHdkYr5m7uwrsHfrQvjUaRoXrTCPGyKKDw7D0frEfG1BYVSCfCs2oeu6eXxwFyhyK/Pbbaesii6zCZ2sIrj3MGga7mnj4s4tlzVicUhiJwDQ40qV2HXbtZVIRD/Hs4XOwLdys/q9fyk1ni6MLQGXw+KgMoTG198nUlOPo08JOddN4bXd4AvBoC5wWYLEK1qWg7x71SaKlp52pYWqBS8iisTjmjACshyEK6oJHWo1Z2e0Y2XOzhKM7oVz9GDshfFluZUyXxeHJHYCACOvHU7Q5qCooZqjO8qtDkcAtcIK2Bf1r5vU28poMou6p14D1F5YzW7nn9GE6t6RiVXc5X98DqBWyAWPxN6H90Wrdv/aLqutEokt24VrnDK99a3eEnfOEFCslsTOEgw/z9bzdbXNsKNKPZ4qFTtAEjshSnaeixP9hgCw7z1pxyYDxqDwhz1UYic2J52Df9s+/Gu3gcNO3v03szW6acJph3nDE/u1swb0Jvuqy0DXaXj2TfP49QOhqyGikE0UCSV7qjIobF61Oe64IaBYdwLCklx3OuZ83RXj446vixZWB3WB4pzOjio5kcROMGkeo9qxjasksbOacHUtgR0HAFjcbRwaMKGntTFlCvV/V9W63JuuxFHS1RRN3DCwc1a55T+gRBT1z72JHoqJKO40RBSyiSKhuM3ELr5iN7wbeJxQH1DqaKHzCB6rIHToONjtypamBWujid0kufE1kcROMCm6QiV2+TsksbMaY76uecAAanOLGN4N8l0WB5UBRLxNNLwQ3Qv74G00BuA1QzQxqnNiyL1xOrbiQsInq2h6b5V5/GNR0cbSQ3BMRBQJwz1JXQeD+48SOnXaPO6wwcTozZXM2XUuRhvWNW4otrzcuHOS2J2JJHaCybDr1QWtd/lBKk40WBxNZmNcyA4OlTZsZ9Lw8rvojT6yBvXDfcV4Xt8D3gCUFnSe4k5zZpF/701AvIiirEjFoAMv7eicWDIRe1E+zuFlADSvib/JnSICCktoNubrWrVhm0OwpUI9lsQuhiR2gklRny5UFvfGhs7OxfLJYSXNKzYBsLxnNLET4UTC0XWd+r/NByD/k3PRNI3nDdHEqM5dU5R3/y0ANL23itDxCvP4PdGq3Ys7lG+XkBjc7czZtVTG6vL97xR0XTcVsa2FE1sq1Kq37rnQN9+K6JITSeyEOOpGqKpdzXJpx1pFuKqGwE41X/dm0TggcRYbQgz/+h0Etu9DczvJu+cGdlWp3aAOG9yRYNFEa5wD++K+YrwSUTy/0Dx+4yDId8KxhtgKJaHjcU9Vc1yt5+zGdFebR6p8cKDWgsAykNDhckLHKiDLgXvy6LhzLduwmb4ftiWS2Alx5ExRg0TuLZLYWYUvWq0LDCqjLqeQ/gVQknv21wiXTl20WueZey32onxeiFbrZg6AbhZ8//Pvi7Zj//kmekTJMN0OuDW6zkw87RKHoYz1b91LxNtkHnc5YHwP9VjasZ2DsQXEPWEEttx49ZLM17WNJHZCHAOvU3dEfQ/toM4bsjiazMTYD3tE5us6jXBNPY2vLQEg/8HbCIThlV3q3D0J2jRxLnJvuRpbXi6hw+Vmax5iIop39isPL6HjcfTqjqNvD4hEaF4bL0M22rHiZ9c5xNqw8fN14QisjyZ20tGIRxI7IY6e40tpzPaQHWxm24f7rA4nI/GtUBeyFb3EmLizaPjX2+j+AM5Rg3FNGMF7B6CmWVVKr+xvTUy2HDeeeTMBqH/2DfP4qO4wohj8YZi/25rYMgFzzq6VUbEhoFgriV3C0XW9XWPiPdXKeiY3C4YXWxFd8iKJnRCHZrNRNUS1Yys+lHZsZxOuqiG4SznQvl44DpCKXaLRdZ36Z14HIP+Tt6JpGi9EtUN3DlczdlaRF23HNr7xAeE6pVTXNLg7WrWTdmziyI7O2flazdlN6Kn+TxxrENuZRBPcf5RwRTWay4nrsvjSudGGndjT2vdoMiLfDuEMHJepdqy2URK7zsa3Wn3PQwMHUO0qoFuO2hUqJI7mNVsJ7jmMluMm745ZnPTCB4fVubtGWBuba/xwnMPL0JsDeF95zzx+21BlWryjUpbSJwqjYudfvx09EDSP52TB6O7q8ZoTbb1S6ChM/7rLRmJzxxt5rpU2bLtIYiecQZ+rVWLXa+9WfEHR9Hcmhm/WiSHqQ+UyUXslnIZ/qDanZ+612PJyeXmnshKZ3AsGFFkbm6ZpZtWu4bm3zONF2TB7oHosVbvEkDW4P7YuBejNAfyb43vek8XPrlPwfRRtw86YcMa5dSKcaBdJ7IQz6D9jOCGbneKGKrZtPGl1OBmFkdit76WSa2nDJpZwvRevIZr4+C3oOqYa9m6LRBOtybvzeshy4N+0C//22Nzr3dFq4vzdyqhV6Fg0TcM9Rb0Pfa3n7KLvS0nsEoeu6+a8cfYV8Ynd8Xo44VUt2HE9rIguuZHETjgDe66byv5DADjyvrRjO4uIz29WBt7KVx8oIpxILN6X30X3+ckaWopr0ijWnoBDdWog+6ZBVkensBcXkTv7CiC+aje9H/TOg3q/UsgKHU92O352l/UCDdhfA5WNFgSWAQR2HSRSVYuW48Y9Id5I0mjDjuqmWuNCPJLYCW0SHq8EFMF1kth1Fv6NOyEYQu/Wlf05PfE4Re2VaEzRxMfnKNFEtFp3yxDIdVoYWCuMTRQNL76D7g8AYNNiM4DSjk0MLZWxhpcgQIEbhnZVjzdIUyMhNEdtTtyTR6M547M3ma87O5LYCW3SfYaqGBXv3Eo4co4nCx2CYatQNXwMaBoTeojaK5H4N+8msHUvOLPIu3s23gC8sVedu9ti0URrcq6ZhL1nNyI19TQuXG4ev2uEqhwtPwpHRaHZ4bhGD0HLcROpbSCw+1DcuQk91e8byjs/rkzAUCNnTxt3xjkxJj47Cf/Y+MMf/kBpaSlut5spU6awZs2adp/7xBNPMGPGDIqKiigqKmLmzJlnfb6QOMqiRsX9Kw6w46DX4mgyA2O+bltfacN2Bka1znPLVdi7FPDGHvCFYGCRslBIJjS7nbx7bgDiPe365MMVfdXjF6Vq1+FoWQ7cE1WW33pvrJHYrZfErsPRdd38fhtzjgZ1zbC7Wj2Wil3bJDSx+9e//sWjjz7K9773PTZs2MDYsWOZPXs2p061rc9funQp9957L++//z4rV66kb9++XH/99Rw/LhOqnY2rZzE1xT2xobN/6Q6rw0l79BYO94uKVPtHhBOJI9Loo+HldwHIf2AOgOldd/eI5FQi59+r1LG+pWsJHa8wjxsij5d3KTWv0LGYe2NXx4+lGIndlgoIhDs7qvQmdLiccEU1ZDlwTYgvn6+LJtJlhVCc0/mxpQIJTewef/xxHnroIR588EFGjBjBn/70J3JycnjyySfbfP6zzz7LF7/4RcaNG8ewYcP4y1/+QiQSYfHixYkMU2iHxhHqE8O7dqfFkaQ/gd2HiNR5ISebtfkDyRK1V0LxvrYE3duEo7Q37svHse+0qrzYNbh9+LlfbwVZZX1wXz4OdJ2GF94xj88eCHlOZZYra646HnPOrlXFrqwQCt1qA8jOSgsCS2OMsRTXuGHYsuP969bJfN05SVhiFwgEWL9+PTNnzox9MZuNmTNnsnLlyvP6O5qamggGg3Tp0qXd5/j9furr6+N+CR2DZ5K6U8reIRW7RGO0Yb3DhhOxORjVXS18FxJDfdS7Lv/jt6DZbLwU/S9+dalaI5as5H3sRgAann8bXVflObcD5igRu/nvEDoO98SRYLMROn6K0IlYt0nTYHz05ksEFB2LL5pEZ7dqw4LM150PCUvsqqqqCIfDlJSUxB0vKSnh5Mnzexd885vfpFevXnHJYWt++tOfUlBQYP7q27fvJcUtxCi9UpUu+h3ZSWWj9HgSiXGHur9MVQfGS7UuYQR2HcS/dhs47OR97EbCEXhllzpn9aaJc+GZczVaTjbBA8fUvyHKHdEq41v7oDFgUXBpii03G+cI5QbdvD4+cxYBRWIw2t5GG9w8HoLN0SmEyTKD3C5Jq7n72c9+xvPPP8+rr76K2+1u93mPPfYYdXV15q+jR492YpTpTdGEwYRtdro01ohRcYIxKnYruqk7VEnsEkf9P5RoInf2FThKurLsKFQ0qrbataXWxnYubJ4cPHOuAqDhXwvN4xN7woBCaArC2/vaebFw0bgviwoo1scrVCYYFTtJ7DqMcFUNwb1qp5970qi4c1tPqXnG4mzoX2BFdKlBwhK74uJi7HY7FRUVcccrKiro0ePsn1q/+tWv+NnPfsaiRYsYM2bMWZ/rcrnIz8+P+yV0DDa3i9P9lUvrieUyZ5coQierCB0uB5uNd/LVXOOEJFNlpgt6IEjDi4sAyLvvZiDWvrx1CLhSoP1ttGO9ry4m4vMDqi14Z7Rq96K8VTsc10T1vvSvjU/sxvVQfoLHGqBCzAM6BOMmN2toKfYu8dmbUV+YKKsWz0rCEjun08nEiRPjhA+GEGLatGntvu4Xv/gFP/rRj1i4cCGXXXZZosITzpPIGHWnGtwkwzuJwmg7hAYPpD4rl245aqOA0PE0vrOcyOk67D2Kybl2ctzWhmRvwxq4Lx+Ho28PIg2NNL79kXl83nDlabfqGBypsy6+dMQ9KZrYbdmNHgiaxz3OmFGx2J50DL7o9TC7VRsWYGP0eywdjbOT0Fbso48+yhNPPMHTTz/Nzp07efjhh2lsbOTBBx8E4BOf+ASPPfaY+fyf//znfOc73+HJJ5+ktLSUkydPcvLkSbxeuRWyim5TVRmgaM9OQmJUnBCM+bryIaoNO6Gn3I0miobn3gQg754b0BwO3tyrVI2Du8Do7hYHd55oNpvpadfw/Nvm8V55as0YwCtStetQssr6YivMQ28O4N8Rv79tgggoOpT2/OsgVrGTxO7sJDSxu+eee/jVr37Fd7/7XcaNG8emTZtYuHChKag4cuQI5eWx25w//vGPBAIB7rzzTnr27Gn++tWvfpXIMIWz0PdKVcYYdGI3u0/KpvFEYLQeNvWOJnZy0UoIofJKmpYow/O8qCec0Ya9M0m969oj726V2Pk+WEeoPOa1YbRjX9opnnYdiaZpSh0L+Ne1mrMTAUWHEWn04d+yBzhTOHHSC+Ve1fpOlZswq0i4eOKRRx7h8OHD+P1+Vq9ezZQpU8xzS5cu5amnnjL/fOjQIXRdP+PX97///USHKbSDa1BffNke3CE/u1YdtDqctCPibcK/TU27v1MYFU7IfF1CaPjXQohEcE8di3NgXw7WKLNTmwa3D7M6ugsja0Bv9cEXiZzhaedxqvVia8TTrkNxGQKKVomdsaVk2ykxKr5U/Bt3QiiMvVd3HH3iHTWMat2wrsm1xzkZSVpVrJAcaDYbDUPUp17tapmz62iaN+yAcBitVwk7HCXY5W40Iei6TsNzbwGQd5+q1r0ctTiZ0S+5vevaw/S0+9dC09MuOwtuGazOvyTt2A7FfZlSaLZWxpYWQpds1dLfLkbFl0RL/zqtVQld2rDnjyR2wjlxTVD9Hfs2+aToaIw2bO1IVa0b3g1ysqyMKD1pXrWF4MFjaLnZeOZcTUSPzaGlimiiNZ6516DluAnuPYx/Q+ym687ov+fNveJp15G4JgwHTSN06AShyhrzeJxRsbRjL4n2/OughXBCOhrnRBI74Zz0uUIldn0O7uC0z+Jg0gzjQranv/jXJRJDNOG57VpsnhxWHoPjDZDvhFllFgd3kdg8OeTeEvW0ayGiuKwnlBYoT7uF+9t7tXCh2PM9ZA3pDxCXSIP42XUEeihk7stuLZwIhmFLdOmHrFo8N5LYCeekyxSV2PWvPMSm/U0WR5M+tLyQfdRN+TWKcKLjiXib8C54H4D8qHfdy9Fq3S1DUnt1m6GO9b66mEhzC0+7aNXuRZme6FDclykBRXOLrR8QE1CI5cnFE9i+H73Rhy3fg3N4/N3W7mq1dSLfCQOLLAowhZDETjgnjh7FNHTtjg2dw8t2WR1O2hDYcQC90YeWl8tixwBAjIkTgXf+EvSmZrIG9cM1aRSNgdh2hjtTtA1rkD19Ao7e3YnUeWl6e5l53BCDrDoGx2V9dodhKGNbrxYbW6JEOOVeKG+wIrLUx7dK2T65J49Cs8WnJsZ8nWEILZwdSeyE8yIwMqoI2yBzdh2Foa4Ljh6BT7dT5JY1OYnA9K677yY0TeOtfapNOaAw9Sukms2GJ2p90vBiTB3bJx+m9gEdmL/bouDSEFe0YuffsBM9HJPA5jphWLF6LH52F0fMv+7MbVMinLgwJLETzouCSaod69m1g7AYFXcIhrru+EClthvfI7W81FKBwN7Dqm1mt5N312wg1oa9Y3h6fL/z7lb/rqYlawidOm0enxet2r2yC3TxtOsQnEP6o3ly0Jt8BHbF2z8ZNwnSjr1wdF0X4UQHIomdcF70maEqdoOP7WRPtcXBpAmG0enG7up7K23Yjqfhn8riJGfmVBw9ijler9qTkHrede3hHNRPKTbDYbyvxlY43jQIXHbYd1otTxcuHc1uxz3x7H52IqC4cIIHjhGuPA3OLFzjhsadq22GA7Xq8biSM18rnIkkdsJ5kT12CBHNRreGSrZtFbOmSyV8uo7gAZVhLPKoaqi0GToWPRRSpsTENk3M363ak1N7q3ZlumBsomh4YWHsmEsZFoOq2gkdw7k2UGyvBL8s6bkgjLWK7nHDsLldcec2RduwAwqhKLuTA0tRJLETzgubJ4f6/mrA/9RKmbO7VIzha62sL7vCBWioAWyh42h6fy3hU6exdS0gd9Y0dD2W4Mwbbm1sHY3ntmvBYSewZU9ci9D4dy7YrSwjhEvHZVTsWgko+hcoo+JAGLbJve8FEWvDynxdRyCJnXDe2MeqTwl9s3goXCr+6Hxd/TB19z+0q6qwCB2HWa2bNwvNmcW2U6ot6bKrNmU6Ye9aSM7MaQBxK8Zm9INuOVDtgw8PWxVdemG0YoN7DxOujUlgNS02Z7dR2rEXhOlfN3n0GecksbtwJLETzpuSy9UFrceBndQ2WxxMimPc7e/pq76nMhTcsYRrG2h8+yMgtnrLqNZdPzA9k2hDRNHw8rumYtNhg1ujI0svSzu2Q7B3LSSrrA8A/ta2J9HkY3NFZ0eVuoRrGwjuVXcdRtJsENElsbsYJLETzhvDqHho+U42HZe+zsWiRyLmB8KKrqpiJxetjsU7fzEEgjhHlOEcPZhQRLUjIX1EE63Jvf5ybAUewidO4Vu+0Tx+R7Qd+94BqPNbFFyaYdietN4bawz3S2J3/hjXwqwBfbAXx7sPH6iBer+qsht2MsK5kcROOG+cwwYQdGWTE/Cxd90Rq8NJWYL7jhBpaETLdrHIrhzWJbHrWMw27MduRNM0PjwMVT7omg1X9rM4uAShuZx4brsOAG+LduyIYtXq94fhrb1WRZdemEbFrQQUY6KJ3eE6qJH1i+dF8zrVhnVNGnnGOUM4MaYEsuydGVVqI4mdcN5odjtNQ4cA4F0rc3YXi/FhEBo5jEbdQb4TBnWxOKg0IrDviFIs2u147rgeiLVhbx2a3h8QnrvUv9f7+gdEGlVmoWkxEcXLonvqENymUfEO9EjM2LPQrfb0Qmy3qXB2jOuh8T1tibRhLw5J7IQLInuC+oRw7twlpqcXidF6ODlQXcjGypqcDqXh+bcByLl2Mo7uXaj3w6L96ty8NG3DGrgnj8ZR2gu9yWfOGALcPlT9H1t7Ao7UWRhgmuAcUYaW4yZS5yW4L757YczZGdUmoX1ajqW4Lxt1xvkNkthdFJLYCRdEjylqErv02G6OyU7Ei8KYy9naK5rYic1Jh6GHw6Yq1BBNvL1PtSEHFsHo7lZGl3g0TTM3bBjtaIASD1zRVz1+VUQUl4zmcOAarboX/o3x31CZszt/ArsPqbGUnGycwwfEnWsKwq4q9VgSuwtDEjvhgvBMUCWPsor9bDkWtDia1CPibSKwU/mMfVCgFGCS2HUcvo82EC6vxFaYR+7sK4AW3nXD0mOF2LnIu1O1Y30frid0sso8brRjX9kpK8Y6Atd4dS1s3hjf3zbm7LZUyPf5XPiN+boJw9EcjrhzW08pVWwPD/TMsyK61EUSO+GCcAzojT/HgzMc4NDag+d+gRCHf9MuiESw9erO2rCSeUli13EYbVjP7dehuZxpuULsXGSV9cE9aRREInhfftc8fsNAyMmCQ3WwSapJl4xrvMqU/ZviK3ajuoNdg8omOCFdjbPSvDY6X9fK5gRiXoCyRuzCkcROuCA0TcM/TLUgmjbutjia1MMYFPaNHElEh5Jc1SYTLp1wvZfGtz4EYm3Y+dH/olP7QO80WiF2LjyGp92Li8xjOVmxFWPSjr103OPUnUJg2z70QKx74XbErDmkHXt2TOHEpDPn64zv3Thpw14wktgJF4wnuqTZvXs34cg5nizEYRgTHx0gbdiOpvG199F9frIG98c1fnj8CrEMqdYZeOZeC1kOAtv34d+x3zx+W9Ss+PU9smLsUnEM6I2tMA/dHyCw80DcubEyZ3dOzmZMDKqVDbHWtnD+SGInXDAlU6Nzdsd3s7/G4mBSCF3XzVVi60uUcEIuWh1Ha++6bZWxFWI3ptkKsXNhL8onZ+ZUALwvxdqx0/tBcTac9sFHYkV5SWiahitatWtu1Y41EjtpebfP2YyJq5swxXnpLnhKBJLYCRdMdrRiV3ZKBBQXQuhIOeHKGshysCRbtbNlfqRjCB48TvPqLaBp5EW93OZHP2tnlkF+Gq4QOxeGiML7yrum11rLFWPSjr10jMTO30pAYbQPt51CuhrtcDZjYsMDsKwwM9+7l4okdsIF4yjtRSDXgzMc5Oi6A+d+gQBA8wZ1h2ofMZi9TepqJRW7jqHhJTVLln3VZTh6diMcgQV71Dmj/Zhp5Fx/Oba8XELHT9G8crN53BCRLDoA3oBFwaUJ7gltCygGd1Ezjd4A0tVoh7MZE0sb9tKQxE64YDRNIzhcfVo2bRIBxfniX6cSu7phap6ktAAK3FZGlB7ouh7zrouKBlYeg1ONahPA1aUWBmchNreL3DlXA7HEF1Rra2ARNIfgnf3tvFg4LwxlbGDnQXPTB4DdBqO6qccyZ3cm5zImNr5nMoN8cUhiJ1wUeVE/u7w9u/GHLA4mRTCMiff1k/m6jsS/dhuhQ8fRcrLJvelKINZmvHkQONN4hdi5MNrSjQuWEmn2A8rLb660YzsER49i7D2KIRLBvzV+Ea+xgUISuzM5mzGxrkvF7lKRxE64KIonq0+GQeW7TXdwoX10f8C88K/sKhsnOhKjWpd7y1XYcrPxBWFhtBKVKd517eG+fBz2Xt2J1Htpem+Vedz4viw/ChWNFgWXJhhGxf5NrebsDGWsrBY7g7MZE5/0Kg9AuwYju1kRXeojiZ1wUbgNAUXFfrYclUGdc+HfuhcCQWxdC/gw0hOQxK4j0P0BvPMXA7E27HsH1WxTnzyY2MvK6KxHs9nIu2MmAN4W7dh+BTCxp3L2f12mKS4J97jonF2r1WJGtWlnFdLVaMXZjImNCueQrpCd1ZlRpQ+S2AkXhaNfTwKePLIiIY5vkA0U58IYFNbGjqSiScOmwUiR8V8yjYtWEKnzYu/Zjezp44GYGnbuMLX4PtPx3DELgMZ3VxKuqTePG1W7VyWxuyTMil2rxK5vPnTJhmBEJXdCjLMZE0sb9tKRxE64KDRNIzxCVe1aK8KEMzHsEE4NUneoQ6KqOeHSaHgxKpq4cxaa3c5pHyxVnqcZq4ZtjWvkIJwjyiAQpPH1pebxmwcr+5Ntp2BPtWXhpTyG5Unw4DHCtbEdYpomfnZtcS5jYhFOXDqS2AkXTcFE9clZuH+32Cacg+YNKrHb2Uu1beRu9NIJV9fS9O5KADx3qTbsm3shFFGzOUO6WhldcuGJetq1VMd2yYar+6vH86Vqd9HYi/JxlPYGzrzJNZKTLZLYmZzNmFjXYx52co28eCSxEy6aostUYje0fBdbT1kcTBITPl1H6NBxAJYXqLt72X946XjnL4FQGOfowbiGlwExledtGS6aaE3evJmgaTSv3EzwaGya32jHvrZLzdsJF4d7fNtGxWbFTgQUJmczJj5cB/V+tS1mqNyYXTSS2AkXjdGCKD11kK1H/RZHk7wYd/GOAX1Y5VWb6OVu9NJpeCG6QuzuGwA4Ugfry0ED5g6xMLAkxNG7BPfl4wDwvhxbMTazDDxOtb5p/QmLgksDXFGj4uZ2ErsDNSphEc5uTGy0YYcXZ7ZN0aUiiZ1w0Tj69iCQl09WJET5etlA0R7GxT48ajh10bvRYXI3ekkE9h3Bv2En2O145inVp9FOvLwvlHgsDC5JyWvRjtV1VZ5zO+CGgeq8tGMvHlc7ytiuOdAnH3SQrgatjIkntp/YyY3vpSGJnXDRaJqGPlK1Y4Nb5FOhPYyL/YkB6uI/ohtkyd3oJWF41+VcMwlH9y7oekwNm+nede2RO+cqNJeT4O5DBLbHVk4YIpM390IwbFFwKY5r9GCw2QifrCJ0Ml4CK352MYL7jypj4myXEvS0YqsIJzoESeyES8KYs+t+cBfVTRYHk4Touq4qS8CWEpVxyN3opaFHIqYnm9GG3VapdnK67LEKlBCPvSCPnJnTAPC+HBNRTOsL3XKgphk+PGJVdKmNLTcb57BS4Mw5O+P9vkUqdjRHq3WuMUPPMCYOR9T7GOQaealIYidcEnkTVbIypHyPrM5pg/CJU4QrT4Pdzoe5avBL7kYvjeZVWwgdPYktL5ecG6YD8Fq0YDyzDPJcFgaX5HiiZsUNryxGj0QAZXkyJzqT+JoU3i8aox3b3KodOzrqV7lVro/4N0QTuzZsTvadhqagsoEaWHTGaeECkMROuCRcY1XFbkDlAbaJgOIMDJuTrOFlbKpVGYfcjV4ahmVH7i1XYct2EY7AgmhCYuxAFdomZ9Y0bHm5hE+connlZvO48X1btB8axbroonC1o4wdFU3sjjVAja+zo0oujIqde0L7/nWju4NdMpNLQr59wiXh6FNCsKAARyTMyfX7z/2CDMO4yDePGI4vpBSIcjd68USa/TS+9j4QWyG25oTad5rvjPmyCW1jc7vIveUqALyvvGceH1sC/QvAF4J3RQd1UbjHRwUUm3eb4hSAfBcMKFSPM1lAEfH5CexQnxGGirglIpzoOCSxEy4JTdMgKqDQt+1GFy+sOIy2zNH+6kI2qrusuboUmt5bRaTei71Xd9O+w2gf3jgYXI72XysoPHeqFWPeBe+jB4KA2pJgiChEHXtxOIeXgTOLSE09oUPx3jFGOzaTjYoDW/dAKIy9Wxccfc7M3oykV0ZVLh1J7IRLpssk1YLodXg3FY0WB5NE6JGI6WG3sZv6HslF69IwRRN3zESz2fCHlJoTZIXY+ZJ9xXjsJV2J1DbQtGS1eXxuVE384WFECHURaM4sXKMGAWf62Y2Ovu8zuWLX3GK+TtPi724D4dg+XblGXjqS2AmXTM549Yk6+OSejL5wtSa47wi6twktx80yZykAY7pbG1MqE66pp9FYIRb1ZPvgsDJ+LcmFKb2tjC510Ox2PLdfB0DDSzGz4oFFqrIU1uGtfVZFl9qYfnab48uexvs+k6+PhjuA0bJuya4qldwVuqFvfmdHln5IYidcMq4xSlI3oPIA24/J5LWBIZxwjh7CjhrVIxwtid1F0/j6UggEcY4ciGuE8jQx2rBzhsjA9YXguUO1Y5veWUbEGyvPGSKK+bvaepVwLlxj1bWwdWI3spv6/XhD5lZDm8+iiDXn67qrsQDh0pBLoXDJOPqUEMzLwxEJU7HpoNXhJA2GMXHjsOH4w2qIul+BxUGlMA0vqjasUa3zBmKD/tKGvTBcY4eSNbAvenOAxjc/NI/PGaJWsq0rh6P11sWXqhguAYEte0w7GVAWPGWF6nEmVu3CVTWEDpcDMfVwS7aIcKJDkcROuGQ0TUMboe5Uw9v2iIAiiqGIPdwvKpzoJnejF0vwSDnNqzaDpqmF9ihrDn9YfWCOkkroBaFpmlm1a2ixO7aHB6b1UY8XiIjignEOLUVzO4k0NBI8eCzuXCbP2RnVuqzB/bHnn7nvz/ieSGLXMUhiJ3QIhRNVYtfjyB4RUAC6P4B/uxpU2tBNJXaj5aJ10RiL67Onj8fRS2Vx81t410nCfOEYO3Z9H6wjdOq0edxox4pZ8YWjORw4Rw0GwL95T9y50Rk8Z2fO17Vhc9Icgj3V6rGMqnQMktgJHULOOJXYiYBC4d+xHwJBbF0KWElPQC5aF4uu66YpsdGGrWqCZdH1V2JKfHE4B/bFNX44RCI0vrbEPH7jYHDaYXc17Ky0MMAUxTVG/Yc8Q0BhVOwy0PLEXCXWxnzdriol2OmaDT3PLOYJF4EkdkKHYAgoyir2s+1EyOJorMe4Q3WOHcbOalVOksTu4ghs3Utwz2E0t9M0131jr/owGFsCA8Tw+aJpqx1b4IJrStXjBXvaeJFwVs4moNCAE151Y5Ip6JGIOZbS1sYJoxAwSoQTHYYkdkKHkDWgD6HsbNwhPye2HLU6HMsxLmReQzjhVM7+woVjVOtyZk8353NekxViHYLntmvBZsO/fgfBg8fN48b3dcFuZGb2AjEEFP7Nu+MEFB4nlEVvQjKpqxE8eIxInRfN7cQZVbO3pGViJ3QMktgJHYJms8EwNVsSkg0UpkHpgb6xjRNyN3rh6OGwOV+Xd5dqwx6pgw3lqvpxyxALg0sDHCVdyb5yIhC/Yuy6AZCbpfabri+3KrrUxBBQ6N6mMwUUGThn5zfasKOHoGWduRrG+F5IR6PjkMRO6DAKJkQFFIczW0ARaWgkuFcNgK0rFuHEpeD7aAPhU6exdSkg55rJQEytOa2PMiYWLg3TrPiVd80dp24HzI4WV6Qde2Gcl4Aig+bsmterm9y25utEOJEYJLETOgxDQDHo5N6MuiNtjX+z6l85+vZgnV/1XuSidXGY3nVzr0FzZgGxREPasB1D7i1XobmcBPccJrB9v3n81uj39809EIq082KhTUwBxaZ4p+cxGWh54jeMidtQxO6uUv+3itzQO6+zI0tfJLETOgzXaHWXOqhiL1vKM/eTwGjDZo0dZu4/lMTuwok0NdP45gdAbMh/V5VSa2bZ4MZBVkaXPtjzPeTMnArEt2On94Uu2VDlgxUyNntBnEtAUe6FygzoakSa/abt09mEE6NlVKVDkcRO6DCcQ0qJZDnx+Bs5tv2E1eFYhqGIbRgqwolLoWnRCvRGH45+PXFPHg3E2rDXlEKB27rY0g3D08776nvmwH+WHW6KJs/iaXdhuMap7Qr+Vhsocp0wsIt6nAlVu8C2fRAMYSsuxNGv5xnnZb4uMUhiJ3QYWpaDyFA1mBPcmrkbKIz2y/4+6uI+Uu5GLwrDgsMzbyaapqHrsTbsrdKG7VByZl2O5skhdKyC5jXbzOPGqraF+9U8lHB+OIf0R8t2ZbyAwvCvc08YgdbGRVAUsYlBEjuhQ8kfr1oQJRm6gSJUWUPoWAVoGmsK1aei3I1eOOGaepoWrwIg707Vht14Uu0vzcmCmQOsjC79sGW7yL3pSiC+HTuxF/TyqL28S2QN9HmjORw4R6pyp39TK6PiDErsDNuntubr/C2FEyIu61AksRM6lNwM30BhzNRkDerHRq+SbEpid+F4F7wPwRDOUYNxDlVZnFGtm1UG2VkWBpem5N0RbccuWIIeVOU5mwZzDE87UcdeEO1toMjUil1rdldDMAKFbugjwokORRI7oUNxRgUUg0/uYcvJzOvF+jcZwomhIpy4BLwvRduw0WQjFIE3RA2bULKvnIituJBIdR2+D9ebx43v95KDUO+3KLgUxDWu7cRuRFRAcdILp9K4qxE+XUfokDK9do0/s2InwonEIYmd0KE4h5eh2+wUNtVxeHcG3JK2wmi71A9uIZwotDamVCN49CTNqzaDppEXHepfdQwqm9Td/Yx+FgeYpmgOB55brwWUp53BiGIYWAT+MCza396rhdaYGyjaEFAMygABhX+jmjXOGtgXe+GZJTkRTiQOSeyEDsXmdhEZVAqAf+vejBJQ6LpuzpTs7a0u6iO7q3aWcP4YM17uy8fh6KWu+oYq86ZBakG9kBiMdmzjmx8S8anynKbFqnaijj1/4gQUB9oRUKSxUXHzxqh/3fhhbZ4X4UTikMRO6HDyDAHFocwSUIRPVhE+dRrsdtYWqJa03I1eOOYKsah3nT8EC5UVlqhhE4xr0igc/XqiN/poWrTCPG4kdsuPZtYC+0shfgNFKwFFBhgVGxU71/gz5+v8IWVODHKNTASS2AkdTs7YmIBiSxrfkbbGsDlxDitlc50yWZOL1oXh376PwM4D4Mwi99arAVh6GOoD0MMDk3tZG1+6o2kanttUO9bboh1bWqjUnGEd3thrUXApiGtM20bFo9JcQNGye+Fuo2K3JyqcKHBB3/zOji79kcRO6HCcozNTGdtszJSMGcaOSnVMErsLw/uyasPmzpqGvUDN5RimxLcMBrtcsRKOseWj8b1VhOsazONG1e51aceeN+acXavVYiOKlYCiojE9BRSh46cIV9aAw25WLVvSsg0rwomORy6TQofjGjUIXdPo1lDJ/n01VofTaRgX79rBQ/GHIU+EExeEHomYVSIzuQjAe1H/NFHDdg6uEQNxDi+DQJDGNz40j98yRCUj68rheL118aUSZmK3dW+7Gyi2V1oRWWIx9sM6RwzElu0647wIJxKLJHZCh2Pz5BDp3xeIKsIyQECh67rZbtnbS0n7RThxYTSv2kLo+ClsebnkzJoGwKIDauPBgEL5EOhMPLdfB6gVYwY9PDClt3r8unjanRdnE1CM6qZ+T8euRqwNe6bNCcA2SewSiiR2QkLwRI2Kux3ek5athtaEjpQTOV0HWQ7W5pUBctG6UAzRRO4tV2Fzq7t8ow07Z4i0bDoTz+1KHev7aAOhimrzuCFeeU0Su/PibAIKY85uWxomds2mcOLMxC4Qhl3Gxgm5RiYESeyEhJATTeyGnNyTlheu1hj+da4RA9lS4wTkonUh6IGg2jZBrA1b44MPj6jzoobtXLJKe+GaOAIiERqjPxdQdjMOG+yohH2nLQwwhTAFFK3m7EanaWKnh8Pmv7Utq5M91Sq5y3dBv4LOji4zkMROSAjmBoryzBBQmBsnxg2TjRMXQdP7a4jUNmAv6Ur29PEAvL1PbZwYUQyDu1gcYAZiVO0aXl1sHivKjhlESzv2/GhpVNySEdFW7PEGOO3r7KgSR3DfEfRGH1pONs6hpWecN4UT3aQKnygksRMSgiuqjO1Ve4K9h70WR5N4mqMVu7pBw2gOgcepLCKE88Now3puvw7NrhyIjXafVOuswTP3GrDZ8K/dRvDwCfP4reqtzYLdZMT87KXS3gaKfBeURitW29Po5rd5g7rJdY0dYr6XW2LO15V0ZlSZhSR2QkKwF+UT7tkDgMYt6W18pUciBAzhRB/VehjZTYQT50vE20TjwmVArA1b4YXV0VnzOUOsiiyzcfQoJvuKcQB4W1Ttrh8ILjscqIVtaajo7GicQ/qjuZ1KQHGw1QaKNDQqNoQTbc3XgShiOwNJ7ISEYczZdT24J63d6oMHjhFpaERzO1mfUwrEFG/CuWl8+yN0n5+ssj5mdeONvaADE3tCHzEwtQyjHdtSHetxwnUD1OMF4ml3TjSHA+fIQcCZ7VjjOpFOCbKxcaItRWwwDLuioypyjUwcktgJCSNnrJqzG3Ryb9oNCLfE3Dgxeghbqh2A7D+8EAxTYs8ds9CiQzct1bCCdeTOuRqyHAR2HCCw66B53GiPv74HItKOPSeuMdF2bJorY3V/AP92tf+vLeHE3tOIx2cnIImdkDBcozMksTOk/eOGmWajktidH+GqGpqWrgXAM09Vhw7XwqYK1cq+5UzTeqETsRfmkXPdVAAaXolV7a4pVR/O5V5Yf6KdFwsm5pzd5lYVu+h14nAd1Pk7O6qOx799HwRD2LoW4OjX84zzxueAjKokFknshIRh3KX2rzrMzmNpcNVqB6NiVz94KE1BcDtgYJHFQaUI3gVLIRzGNXYozkFKbrkg+tl3eR/olmtdbILCMy9mVqxH1RJuB8weqM6Lp925MSxPAlv2mN9DgEJ3bNQgHQQU/g0xY2KtDcnrNrnx7RQksRMShr2kK+EuRdj1MPVb91sdTkLQQyH825Q4ZE9P1XoY0U12mp4vphr2jpnmsddFDZtU5F5/BVqOm9ChE+ZgPMTa5G/tVbY0Qvs4hw1AczmJ1HsJHYovcabTBormcwgntrXYESskDvn4ERKGpmm4o3eqBfv3UNtscUAJILDnMHpTM1puNuuzVcVJhoLPj+CRcprXbAVNw3ObqgrtroLd1ZBlgxsGWRygAIAtN5vcG6YDsXlIgCv6QpdsqPbB8qNWRZcaaFkOnCNUibP1nF06GRX7z7JxIhxRxtYgiV2ikcROSCi5aS6gMB3Wxw5lW5V6O4mM//zwRme2sqePx9FTZcNGW+/qUig4c3e4YBHG/KN3/mL0cBiALLvaRAGijj0fzA0UW9oWUGxPcWVsuN5LcO9hANxtCCf214AvBDlZUFbYycFlGJLYCQnFMCoeXJHuid0wc0ZG7kbPDyOx88xT3nW63qINK2rYpCLnmsnYCvMInzqNb8Um8/jcaLv8nf3QHLImtlTBNdZI7NoWUByoAW+gs6PqOIxKpKNfT+zFZw4ZG/N1I4plVCXRyLdXSCjGXWpZxX52lKffld/YEdswdBj1AXDaZf3V+eDfsZ/AzgPgzCL3lqsApYQ9UgfZDphZZnGAQhxai5+Tt4U69rJe0NMDDQFYesii4FKElpYnLQUUxTnqe6gTa1WmIoZwQubrrEcSOyGhOPr3JJKbizMcoHr7YavD6VD0QND0bNrbS7UehnVVLSrh7BiiidyZU7EX5gHwerRDNatMtWuE5MJoxza+8QF6IAhELWmi1VXZHXt2nMMGQJaDSG0DoSPlcefSwc8uZkx8ZhsWJLHrTBKe2P3hD3+gtLQUt9vNlClTWLNmzVmf/+KLLzJs2DDcbjejR4/mrbfeSnSIQgLRbDaco9QgTs6ePdSnketJYOcBCASxFeaxMasXIPN154MeiZgrqow2bDgCr0c3z4kaNjnJvnwc9pKuRGobaHo/dh2fG03s3jsIjSncSkw0msuJc7gqRbe3gSKVlbFnU8RG9NgMoVwjE09CE7t//etfPProo3zve99jw4YNjB07ltmzZ3PqVNv/e1esWMG9997LZz7zGTZu3Mhtt93GbbfdxrZt2xIZppBgcsfF5uxSudXQmuZNMWPirZXKs0nuRs9N89pthI6eRMvNJuf6ywFYcwJONarF6Ff2szhAoU00ux3P3GuB+HbsqO5qmX1zCBYdsCq61CBmVNyOMjZFr4+hk1WET5wCm80cv2nJ4Vo1P+iywyAZVUk4CU3sHn/8cR566CEefPBBRowYwZ/+9CdycnJ48skn23z+b3/7W2644Qb+/d//neHDh/OjH/2ICRMm8Pvf/z6RYQoJxhBQDDq5J6VbDa2JU8RKm+G8MSwzcm++Clu2kr4aqsobB4HLYVVkwrkwzIobFy4j0ugDQNPiV4wJ7WMqY9tZLbbvNDQFOzuqS8e4FmYN6Y/Nk3PGeSNhHV4MDhkASzgJ+xYHAgHWr1/PzJkx41GbzcbMmTNZuXJlm69ZuXJl3PMBZs+e3e7zAfx+P/X19XG/hOTCOcZI7PayrSJ9nEwN4UTj0GHUNKsL1pCuFgeV5OjBEN4FSwDIi5oSB8LwlhpVFDVskuOaMAJHaS/0pmYaFy03jxtmxR8eJi39KjsKs2LXagNF91zolqNaljurrIru4mk2Nk5MGNHm+a1y49upJCyxq6qqIhwOU1JSEne8pKSEkydPtvmakydPXtDzAX76059SUFBg/urbt++lBy90KM7B/Yi4nOQGmji5Mz0WS0Z8fgK7VN/JEE4M6apWLQnt0/TBOiLVddiKC8m+ciIAy46oZKBbDkzrY3GAwlnRWphJe19ZbB4f0lVVY4IReHufVdElP87hZeCwEzldR+h4rH2haaktoPCb83UinEgGUr4o+thjj1FXV2f+OnpULNCTDc3hwDFMua67du9JyVZDawI79kEojL1bERs1dbWSjRPnxvtq1Lvu1mvRHCoLNnbD3jxY/K1SAUMd27R4FeHaBvO4UW0Vs+L2sbldOIcOANpvx6aagELXdbMV625DOKHrkth1Ngm7jBYXF2O326moqIg7XlFRQY8ePdp8TY8ePS7o+QAul4v8/Py4X0Ly4RkXm7NLBwGFuTpn7FC2iXDivIg0NdP45odArA3rC8Ki6BrhOdKGTQlcw8tU5SkYovGND8zjc6JzdiuPQUWjRcGlAOcSUGxPscQudPA4kdoGpfqNrk1rydF6qPOrNYFDZVSlU0hYYud0Opk4cSKLF8fK9ZFIhMWLFzNt2rQ2XzNt2rS45wO8++677T5fSB2co2OrxVLtjrQt/C0UscbdqMj4z07TohXojT4c/XrimjQKgCWHoDEIffJgYk9r4xPOH8/t0XbsqzF1bN98GN9DGe2+tdeiwFIAI7ELtGN5sud0am3xMGxOnKMHo2WdOYtiXB+HFisDdyHxJLTx8eijj/LEE0/w9NNPs3PnTh5++GEaGxt58MEHAfjEJz7BY489Zj7/q1/9KgsXLuTXv/41u3bt4vvf/z7r1q3jkUceSWSYQidgqMEGn9zD9gr9HM9Ofpqjd9u+ocOobFJGrcOLLQ4qyWl4RZkSe26/Dk1TVU6jbTdniJozElIDz+3RiuuyjYQqqs3jxoqx16Qd2y6x1WLxGyh65UGRG0IR2J1CAgpjvs49rp35umiHRkZVOo+EJnb33HMPv/rVr/jud7/LuHHj2LRpEwsXLjQFEkeOHKG8PObAffnll/Pcc8/x5z//mbFjx/LSSy8xf/58Ro0alcgwhU7AObwM3W6nqKmWo/tS6KrVBhFvE8E9aovG3l7qk2xQF8iWbQntEq5toGnxagA8dyhT4gY/vH9InZ8jpsQpRVZpL1wTR0Akgve1983jNw1WNzkbT6r1cMKZOEcMArudcGUN4fLYXIqmxar+qdTVMFeJTZBVYslCwkeVH3nkEQ4fPozf72f16tVMmTLFPLd06VKeeuqpuOffdddd7N69G7/fz7Zt27jpppsSHaLQCdjcLmwD+wNg37knpVoNrfFv2weRCPae3disqzKd3I2encY3PoBAEOeIMlxR9/1FB8AfhoFFajG4kFoYVbuW7diSXJjaWz1+Qzzt2sSW7cI5VF0Lz9hAkWJGxXowhH+r+je0tXGipXBCRlU6D9GgCZ2GsYGirHwPu1K4aOffFL1DHTdULlrniddsw8Z8Ko123a3Shk1JPHOvAZsN/7rtBA/HbIwMs+IFkti1i2tM2wKKVFPGBnYdRG8OYMv3kFV2plfRSS9U+8CuwTC5ees0JLG7SHwNfnZ8KBPCF4IrKqAYXJHaAgrDmNg9dpgYb54HoZNV+JZtBGJD96d9yr8OZDdsquLoUUz2FeMAzN2/oLaHZNmU0e6e6nZenOG0t4HCuEHcXaWMu5Md079u3FA025nphFF5HNxFPD47E0nsLoJdK/azb/itNH/iUYKBFO4pdjKmgKJ8T4ondkoR2zx8GOVe0IAR0optF+/8xaDruCeNIqt/L0CpJsO6SojLiiwOULho2mrHFrrhStVplKpdO7S0PGkpoOibDwUuZfS8OwWS4mbTmPjs83XS0ehcJLG7CMom9Cdkz6KgsZb1r260OpyUwajYldRXcOBgak5Wh+u9BPcrE2xDOFFWBB6nlVElN8aGAkM0AbEPfPGuS21y51wNWQ4COw7g33nAPN7SrFhPfRF8h+McNRhsNsKnThM+GZtLSbUNFP5zJHbGDfxISew6FUnsLgKn28GJy68GoOrFxWd/smBiy8tF66/mMGzb9+BPwWKn4T3l6NeTzYFCQNqwZyOw/6i6+Nvt5N56DQDlDbDmuDo/Z7CFwQmXjL0wj5zrpgLgfSVWtZtVplpvh+tSZ16sM7HluHEOKwXan7NL9sQu0ugjsOsQAG5RxCYVkthdJD3vUbNCfVd/gC8ddmR1Ejnj1a38gPI9KTl/Y7Yexg6V+brzwGjRZV91GY5uquf6xl5lYjupF/SWRTEpj2de1Kx4/mKzrZjrVMkdiKdde5gCiuhoh8HoFEns/Fv3QjiMvaQrjp5nzqKcalQbSDRE9d7ZSGJ3kYy9ZQw1eV3wNHtZ98Iaq8NJGYw5uyHlu1PyTt4QTrjGDTNX/8j8SNvouo73ZZXYGSvEIGZKfKu0YdOC3OuvQMtxEzp0Av+GHeZx4+f7xl6ISDv2DFxRQ98zKnbRHGlnFQSTWEBhugO0U60zru8Du6hEX+g8JLG7SOwOO6euvBaA2lekHXu+GEPDg0+mpoDCv1ndXQdHDONYdP/5SBFOtElgyx6C+46guZ3k3nQlAAdrYMspZX9ws7Rh0wJbbja5N84AYvOUAFf1h3ynsrwwWu9CDNe4tgUU/Qshz6k8Hvedtii488AwJnaPE+FEsiGJ3SXQ/z7Vgihdv4yGOr/F0aQGrtHqNr53zXH2H2qwOJoLI3y6jtBhtSlld49oS7kQ8l0WBpXENERnrnKuvwKbJweIiSam94OuOVZFJnQ0ce3YsCozuRwwe5A6v0DasWcQt4HiROwu16bFbhaT+ea3eWN0X/Y55uskset8JLG7BIbPHEl1YQk5AR9r/rnS6nBSAntRPvRV297DO/amhFeTgdEyySrrwxZfHiDzde2hh8PmML3RhtX1WGInbdj0IufqydgK8wifOo1vecwpYG705/zWvuRuK1qBLduFc9gA4CwCiiTdQBE+XUfokCrDutrZEbtVEjvLkMTuErDZNGquVe3YpvnSjj1fcqNLsMuO704pAYXfuEMdN0zUXuegedUWwiersOV7TNXkzirVWnLZYfZAiwMUOhTNmaWsT4hXx07rC8XZUNMMHx2xKLgkxhhNad6UWhsoDMFHVlkf7IV5Z5yvakI8Pi1EErtLZNADqgVRtnklpyubLI4mNWg5Z5fsyq+WGPN1rvHD5G70HBgf7rlzrkJzqclpox13TSnkSfs67ciL+hQ2vvEBuj8AgMMGNxuedmJWfAYt5+xaYlxXdlRCONLZUZ2bcxkTG9dH8fi0BknsLpGyK4ZwqrgP7pCfNc8utzqclCBVlbHGXXVw+DCO1qtjUrE7Ez0QxLvgfSBmSqzr8LrRhpUVYmmJe+oY7D2KidR5aVqy2jxutN0X7QefOEPFEVPG7ooTUJQVQW4WNIdgf41V0bWP2b04R2In10drkMTuEtE0De8sVbULvS7t2PPBSOz6nj7K3sONFkdzfoQqqtWAs6axp0TJOfsVqPU/QjxN768hUtuAvaQr2ZePA2B9ORxrUHfv1w2wNj4hMWh2O57b1GhKy3bsxJ7QJw8ag7DkkEXBJSnO4WXgsBOpriN0rMI8btNiLcxku/nVdT2miB3f9nydCCesRRK7DmDYJ1RiN2jbasqPp5bS0wrsxUXQU73jQzv2psRQtTlTMqQ/W71KzikXrbbxvvwuAJ7br0Oz24FYG+76MlkGns4YFdrGd5YT8arRFE2LrY4TdWw8NrdLJXe0345NtnGV8IlThCtPg92Oc3TbKqit0RxVrpHWIIldB9D3sgGc7DmArEiI9X//0OpwUoKc6GxJ6fE9Se3VZGDsRHSPHy7zdWch4m2iceEyIPYhH4rAm9KGzQhcY4eSNaAPus9P4zux0RTj5/7+IagXZ6g4jJnj1hsoklVA0Ryt1jlHlGHLPrNlUd0EJ7zqsXh8WoMkdh2Ef7aq2mlvSTv2fHBFlbGpYlTc3GKmRBK79mlcuAzd5yerrI/5gbXyKFT5oMgN0/taHKCQUDRNi3naRSu3AMOLYVAXZbq7aL9V0SUn7W2gMK4v2yuTa3NHy5vctjAsWsoKRSRlFZLYdRAjP6VmSwbv3sDBZJx2TTKMPYlDyncnrVeTga7r5t10cOQwDtep43I3eiZmG/aOWWiaBsBr0WrdzYMhy25VZEJnYVRqm95fQ/i0erNoWkxEIbtj43G3SOxaCyjcDmgKwoEk+khpXq/WxrkmjGjzvNGGFeGEdUhi10GUjOzLif5DsethNv99qdXhJD2mgKL6CLuP+iyO5uyEjpQTOV0HWQ72FCsDtj75UJRtcWBJRri6lqalawHwzFOmxP4QvLNPnZc2bGbgHNwf5+jBEArjfX2peXxu9Oe/7ChUpoZmqlNwDhsAWQ4iNfWEjpSbxx22mIAiWebs9HA4ti/7HDtiJbGzDknsOhD9JtWCcC1cjJ5EpfNkxFHSFboXY9cj+LftS0qvJgNDAeYaNZittcqUSdqwZ+Jd8D6EwjjHDME5qB8ASw9DfQB6eGBSL4sDFDoNI7Fv2Y4tLYSxJaqt+OY+iwJLQjSXE9cIdcN4xgaKJFPGBnYdRG/yoeVm4xzSv83niCLWeiSx60DGRNuxgw5sYWey3GIlMdnjVNWu//E9SenVZNC86cyNE3LROhPvS+pDPO/OWeYxo+126xBl4SBkBnm3q5vc5lVbCB2P2XgYVTtpx8ZjztltakcZmyTjKuZN7vjhpuK9JTU+ZWsEUrGzEknsOpDCshKODRmDDZ2df19idThJj3tsbM4uWe5I28JvuqzLKrH2CB4pp3nNVtA0PLdFh+cD8N4BdX6utGEzCkfvEtzTxoKu450fuxbeMkStmdpQDkfqrIsv2TDEZP4tba8W234qOQQUzRvUfJ37HG3Y0gLIF+GEZUhi18FkzVEfannvLUmKN2IyY8zZJfNqMT0cxr9ZTf+HRw7nQK06LhW7eAxD2uzp43H0VP2jRfuVCnJgkQhNMhFDRNHQoh1bkgvT+qjHxiYSAVxjjYpd/AaKwV3UbuWGAByutSi4FviNxG5iO8IJo6NR0lkRCW0hiV0HM/aTVxPWbAw8tpONq45ZHU5SY9hhlFYeYuex5DS3Cuw5rGZKcrLZXaDmxnrnQRcRTsRhJHaeebE27Pxo8WHuUKWKFDILz5yrwWEnsHUvgT2HzONG9VbMimM4hw0AZxaROi+hQyfM41l2ZRUD1s/ZRbxNBHYdAs6iiJWORlIgiV0Hk9uzCydGTgRg/7PSjj0b9h7F6F2LsOthfNv3J6WAwmzDjhvKtmo1UyIXrXj82/cR2HkAnFnkzrkKUCaly46o84bNhZBZ2LsUkHPtFCB+xdiNgyDLBruqYXeVVdElF5ozC9fIQcCZAoox0erXFosTO//m3RCJYO/VHUeP4jafIx6fyYEkdgnAEx0cLl66OCXWZVmFpmnmBop+x3ZzsNbaeNrC8K9zjR8md6Pt4H1ZfWjnzpqGvSAPgDf3QliHMd1hQJGV0QlWYqhjG15+z2wxFrjhmlJ1/jVpx5q4otdC/+b4DRRGYre1ovUrOpdzzdfVNsOxevVYrpHWIoldAhjzwJWE7A76Vxxg9fsHrA4nqTHn7Mr3WH5H2hb+6MYJ9zjZONEWeiSC95WYKbHBay3asELmkjv7CrQcN6FDx83qN8Q8DRfsRqyhohim7WdU7FqsFrNybtu/Ptq9aGe+zpiT7lcABSKcsBRJ7BKAsyiP8vFTATj2vKwYOxvGnN2Qk7stvyNtje4P4N+uDLfCo4aZ7u+S2MVQdhansOXlkjNrGgBH62FduVI/zpE2bEZj8+SQe8N0IFbZBZg5AHKy1P+VjSetii65aGl5okdicykDu0C2Axot3kBhVuzaWSUmN77JgyR2CaLLnaod2/ujxTQF5Ja0PYyKXempg2w/FrA4mnj82/dBMIStawG73T3RgZ4eKM6xOrLkwZidyr3lKmxudZv+erTgMLUPlHisikxIFjx3RM2K5y9GD6vZlOwsuL5MnRdPO4VzWClatotIQyPBAzHhncMWU5Vb1dUIlVcSLq8Em828GW+NJHbJgyR2CWLkPVfgz3LT6/RxVry969wvyFAcfUqgsICsSIjGHQcJJZGAomUbdvMpJescIzJ+Ez0QVNsmAE8LU+IF0bkpacMKADlXT8ZWlE/41Gl8yzeax43/H2/uJane91ahORy4Rkf97KLVMQNTQGFRV6M5akzsHD4Am6ftO1tJ7JIHSewShN2TzakplwNQ+YK0Y9tD0zRzA0Xp0V3sO21xQC1oaUxsXFDHyEXLpOn9NURq6rGXdCX7ivGAUjnurFKqx5sGWRygkBRoziw8t14DxLdjZ/SDIjdUNsGKo1ZFl1wY+1eb1++MOz7a4sTOv14lmu3ZnNQ1xwynRThhPZLYJZBeH1MtiAErl1DbJLek7eGKzmwMLd/J5iSas2tuoYg1LqhjpWJn4n1pEaBU4MZ6IUPleHWpUj8KAsScAhrf+IBIs/KszLLDzYPVeWnHKtzRxKml0ARiN5TbK62pbsYUsWf3r+ubD4XyvrccSewSyODbptDk9tCtoZJlr26xOpykxT1eDQ0PO7HTsjvS1kS8TQT3HAYgMHw4h6J3o+Koroh4m2h8ZzkQU8Pqesx0VtqwQkvc08Zi79WdSL2XpvdWmcdvi/4/WbgfmkMWBZdEGBU7/7a96P7YzHFZEeRmqe9RZ3c19HA4ZvvUjtWJ2dGQ62NSIIldAtFcTmqmzwCgvoVBpxCPUbErrTzEriM+i6NR+DcrHwZHnxK2610A6F8gd6MGjW99iO7zkzWwrzlMveGkUjnmZCnVoyAYaDYbefOiO4Rfiq0Ym9hLbXLxBmDxQauiSx4c/Xpi61oAwRD+bXvN4zYtNrvW2Te/gT2H0Rt9aLnZOIeWtvmczRmc2EUafXFr4JIBSewSTOkDqpoxbO1SymvklrQtHD2KoaQYux4htH0v/iT4NjWbGyeGmRctacPGaHgp5l2nRfeFzY9qhK4vU6pHQWiJ547rAWh8dwXhugZAJSxGdXe+aMzQNC3Wjt3Qqh1r0QYKc75u3DBz5KI1mTyqUvmNxzk67X4aW1SirUYSuwTT9/rxNOR3ocBXx/Ln11odTtKSE23HDjq+kz3VFgdDTBHrGj/cTOykDasIVVTj+2AdAHl3qg/rYBjeiBYYbhtmVWRCMuMcOZCsYQMgEKTx9Q/M40Y79v1Dagg/0zEFFK3n7CzaQHGu+brKRjjhVb6VmaaIjTT6aHzjA4L7j2LPz7U6HBNJ7BKM5nDgvU4pwoKvvXuOZ2cuhunl0BO7kkJAYcyUuCcMNy+kmXg32hbeVxdDJILrspFkDegNwEdH4LQPumYrtaMgtEbTNPKi85gNUeENwNBiGNYVghF4a59V0SUP7vHtVOyiSdPOKgh04qpKw3rlXPN1A7uAx9lZUSUHjW99iN7kw1HaG9ekUVaHYyKJXScw9JPqYjZ6yzL2HJNb0rYwLhpDT+yyfLVYuKqG0JFy0DQaBg0170YNk9BMx/uyukExqnUA86OiiVuGKENVQWgLQ2jTvGIToROxN7pR5ZV2rFLhAwT3HyVc22Ae71cA+S7wh+m0rkak0Udgpxp+dLezSiyTR1UaXngHgLx7ZpsjKcmAXII7geLLR1BT3JPsoI91zy63OpykxFin06fmGPsO1FsaS3O0DZs1qB9bfaq8PigD70bbIrD3sKpmOux45qpKdGMAFu1X528TNaxwFrL69sA9ZQzoOg0tBGW3RlfPrToOJxraeXGGYO9SQNaAPkC87Ymmxe+N7Qz8m3dDJIK9ZzccPdu+s81U4USovLLFSMpsi6OJRxK7TkDTNEI3KU+7rLfek6XXbWAvzEMrVRcz+/Zd+ILWxWK2HsYPz+ih4LYwFI05107BXlwEwKID4Asp1fD4HlZGJ6QCxpaSlurY3vkwuZd6bGwuyWRM25N2BBSdNa5yrvk6XY+JOTLtGtnw8rug67injCGrtJfV4cQhiV0nMeLBaDt21yrW787wW9J2yJ2gqnaDT+xie6V1cTSv2w6A+7IRIpxoga7rNLwcNSVusULMaJ/NHaqqCoJwNjy3XgNZDgLb9xHYFfM4Mdqxr0k71rSAam61WqyzLU9iGyfanq871qBmax02GF7cOTElA7qu4zXasHcnV7UOJLHrNPJHDaCq30CyIiG2P7PU6nCSEmO2ZOiJnZ3WamiNHonEVolNHCnCiRb4124jdLgcLTeb3NnTAahqUsIJEDWscH7YuxSQc91UIF5EcdMgtYpuR1XnzZAlK26jYrdxZ5xHmlGx212deENnXddb3OSObPM5RoI5rBjcjsTGk0wEtu0jsPMAmstJbnRdXjIhiV0nknWrascWvPsewU5UNaUKxl3qsPJdlm2gCO4/SqTOi5btoqpPGVUZeDfaHsaHsOeWq7DlKKfmN/ZCWFezPwOLrIxOSCUMdaz3lffQI2pHVlE2XNVfnTfEOJmKc/RgcNgJV9YQOha7GPbOgy7Zaq3YrqrExhA6VkH4ZBU47OYMdGtM4USG2Zw0vLAQgJzrL8demGdxNGciiV0nMuKTynl95MGNLF+X4HdlCuIaPQTdbqe4oYrDe6zpxRp3qK6xw9hSrW5Bh3bNrLvRttADQbzzlwDguSvWejDasFKtEy6EnNlXoHlyCB09SfOabeZxsx2rFr9kLDa3C9fIQUCsHQpq1KGz2rHNa9XPxTV6iHkj15pMFE7ooRDel5XwJ++eGyyOpm0ksetE3KU9OTVsNDZ0Djy32Opwkg5bjhv7kFIAsnfupMHf+TEYwgn3ZSPMdrC0YaFpyWoiNfXYS7qSPX08AIdrYeNJtT1gzhBr4xNSC1u2C8/NVwLgfTnWjp05QO1EPVYP68qtii45aM+o2LgeJdoWyh9N7Nprw0Z02JaB18impesIV57G1rWAnGunWB1Om0hi18nkz1Pt2J5L38MbOMeTMxBDQDHkxC62WVC0a15nDAuPyMi70fYwV4jNm2muFTLaZVf0he7JY7oupAieqA+i97X30QNKBp+dBbMHqvOZ7mnnOsdqsURvoDAqdu52jHf316gdv24HDO6a2FiSCW+0DZt3+0y0rORs5Uhi18kMvf8awjY7Q07sYunSo1aHk3SYc3YnOn/OLtLoI7BDGbK1FE5kemIXrvfS9M4yIGZKrOuqXQbiXSdcHNkzJmDv3oVITT1NS1abx2+PtmPf2Nu5GxaSDVNAsWU3eiimlDC87PacJmG2UJFGH/5tag1IexsVjOvzqG6ZY0oeaWik8e2PAPAkoRrWIEN+HMmDo3sRVeMuA6Di+ffO8ezMwx0d0h1avpMtJzt3yMY04+zVnWPZ3agPgMsOQ7p0ahhJR+MbH6A3B8gaWqqGulEtmP016vtjVFgE4ULQ7HY80Q6G4eAPqgLcLQdqm2HpIYuCSwKyBvVD8+SgNzUT2HXIPF7iURXyiE7CbKH8m3ZBOKyMiXu3rYzIxI0T3gVL1bVwcP92BSXJgCR2FtDjY+piNnD5u5z0ZvCEcBs4RwxEdzrJa/ZSseNYp37t5vVRaf+E4eZFa0Q3yLJ3ahhJhzeqhs27Y5a5NufVaLVuVhnkuayKTEh18qJCnKZFKwjXKX9Pu015IgK8msHtWM1mwx21gPJvjPezM6p2iTIqbmlz0t6qrC0Z2NEw1LB5dyfXCrHWSGJnAaV3XUkgy0Xf00dZ+noGX7naQMtykDVKVYXy9uyithNX6xrqM/dlIzPWTb01wWMV+JZtBGIzUaEILIgmdrcn702rkAI4Rw8ma9gAdH+AxgXvm8fnRf1wFx+EOgtEVMmCMWfX3GrOblx0w8umk4n5uuZ83eS227CBMOyIVgsz5RoZPFJO84pNoGnmtTBZkcTOAmyeHOqmK4NXY6G6EMMQUAw7sbPT5uxamnG6Jo7MyLvRtvC+tEitzbliPFl91afJsiNQ2aT8tAzfMUG4GDRNM6t2DS/E1LEjimFIV7Xw/q29VkVnPeacXasNFMbqvo0JSOzUtfDswond1epnk++C0sKOjyEZMap12dPHk9UnuT8YJLGziAGfUBn/2HXvsaciwRbiKUZsA8WuTtuJGD5xinBFNdjtOEYNMWX8mZzY6bpOw4tnrs15JVpknjNE2tTCpZN35yzQNJpXbSZ4RHmcaBrMi1aDM7kda1TsArsOEfE2mcfHlIAGHK1X2186kuCBY0Sq69BcTlyj2/YxMm98u2fGGkFd12n4V7QN+7EbLY7m3EhiZxHdZ0+myVNAl8Yalr+w3upwkgpDGTv45B42H++cpNewOXGOHMiBZje+kPLTyuRtCv7NuwnuOYzmduKZczWg7A3eUcJh84NXEC4FR6/upjei98VY1W7uUJW8rD6ufO0yEUePYuy9ukMkosRdUfJdMCgq6urodqxpTDx2KJozq83nZJpwonnVFkKHTqh1ijdfZXU450QSO4vQshz4Z6tNFLbXFxERDYVJ1sC+6J5c3CE/1VsOdYoDvSmcmDjSvGiN7q7MdzMVb/QONfemK7HlKaO6hfvUjsqywsy5qAuJx9hm0vDiO+Zu1F55MLWPOp/JnnZGO7R59da44+MS1I71GyMpk9o2JobME040PP82AJ6512LLzbY4mnMjiZ2FDP6U2pc4cdtHrNrrszia5EGz2XCPU7K4Hgd3crQT7tZjwokRbIheKI05lkxED4ZoeDW6Nufu2Nocow17+/DMaMEInYNnztVo2S6C+4/ib7FpwRDnvLIrc1eMuSePBqB5TXxil6g5u+a16uu4J41u87wvCHuq1eNMuLmLNPrwRoU9ybpCrDWS2FlI/pSR1HfvRXbQx9bnllkdTlKRHR0aHn58BxsSvFpIDwTxb1FtDtfEEWyMfr3xPRP7dZOZpiWriVTXYe/WheyrJgJQ3gArop7aYkosdCQ2Tw65N84A4j3tbhqkvBL312DJJppkIHtKNLFbtw09EjGPG4nd5goIR9p65YUTrvcS2HkQaH+V2LZKCOvKa7CHp2O+bjLT+NaH6N4mHKW9cE8dY3U454UkdhaiaRqO25SIomDRooS5iKciRvth5LFtCVF+tcS/Yz96cwBbYR7+Pn3ZHb0bzeSKnTEo7LlzFppDrc15bTfowKRe0K/AwuCEtMRox3rnL0YPqtnaPBdcHzXAfnVne69Mb5wjB6LlZBOp8xLYfcg8PqQrZDvU3Ov+mo75Wv4NO0HXcfTviaOk7T1hW1rM12VC1d5ow+bdcwOaLTVSptSIMo0Z8qBqx07Yt5Yl6zvo3ZkGGHeLpVWH2HWgIaFfy2zDThzJ5goNHeiTn7n7T8O1DTS+sxyItWF1PdaGFdGEkAhyrr4Me7cuRKrr4leMRavDC/YoD8VMQ3M4cE9UHYyW7ViHLdYK7aib33PthwXMDsq4DLjxDR6rwPfRBiB+JCXZkcTOYlyD+nF68HDsepiD/1xidThJg724CK1UTU7///bOO7zJsu3D55OkSdO9B9BC2ZS9N8iQKYIiqKAi7oF7+36u14HzdaKAC1QQRUGWgsgG2cjee5QCpXulGc/3x52kLZYOaJs0uc/jyNH0yZPkStPcue5r/C5l+x7yq7A51tE4YWif6Fwg23nBonU5cuavgAIz+sT6GFo0BGBvitCu0mthaCMXGyjxSBSdjoAbRUNZ0XRsr7pCM/FCrtBQ9EYuV2dX2ULFTseuQ9mOXTsvKFXJ/nlJoY5nfM15wdKxcwPCRouoXb1Vf3Ihx8XGuBH+dtXzJqd2O3XlqgLTFkfErohjV3M+w5WOU6+paNOEPQ3WLwGCfV1hlcQbcI4YW7LOOWLMRys0E6EwauxtlOXYVUbETrXZCrMXl4nYJWdDUrZQC2jj4Y0TQrvOnoatAdp1RZGOnRtQd0w/rBotzc7s5a+V1Tsf1Z1xjLOpyjo768V0zMfE31zfttCx89b6OvOxM+LLQ6MhYKTYcFhsor4OZBpWUrXoWzXGp0k9+4ixlc7jjv+7JUcgywtHjBk6NAdFwXI8Ccu5i87jjszCgYuQU3B1z2E+eAJbZjaKnxF9Yv0Sz9lqj9Y1DQd//dU9n7tj2rwb89HTKH5GAq5zf+26okjHzg3QRYWR0U50Hl78SY4Yc+DYNTZL2sv2JGuVPIdjBqNPw3hOK4Gk5okuvMTIKnk6t8cxacLYqz26mAgA1p0SabBQX7imnguNk3g8iqI4I8WOyDGIWrIGoUJD8ffDrrLOdWiDApzOVtGoXXQAxAaATYVdV5nVcMicGNo1czZMXYo3pWEzHdp111+DJsDPxdZUDOnYuQl1xoru2JYb/uRgipcKNl2Cvkk9bAH++BXkkfLPsSp5DpNTmDjRuWg1jxK1ZN6GarMVjhC7+d9p2Osae+ffRVK9BI4aABoN+Rt3Yj4qoumKAjeJ/gF+9dLu2KpOx+ZvKkfjhJeUqtjyTOT8Jmrea1oaFqRj5zZEj+hJgcFInbTTrPhtj6vNcQsUrRZf+6zEqEO7SM6u/OfIs6u5Gzq2KFy0vDQN6xybU1RTzASL7SPERjZzoXESr0EXG4mxdwegcPA6CLFix4ixkxkuMs6FVLVQcWHjRMn6dSYL7LFHBT19jcxZtApbVg66+Fh8u7Yu9dxVJ2DZMffq2JaOnZugCfAjv6/I46u/LXarfxJX4t/ZXmd3qvLr7FSzxRmxM3Zp5fWNE1k//g5AwPA+zrE5iw6J9FeDUM8vlpa4D4G3iihJ1k+LnaK8sYHQI17cPscLo3YOx8608yC23Hzn8aKO3ZVO57Ccu4j5yClQFOfzXMreFDBZRYdyvZAre56aQuaMRYD4PyxLu+79v+Gu+TBzV6mnVSvSsXMjGtwlOsK67FjGmkNeWCFcAlUpVGzadRA1z4QmNAhzvbrssyvbe2PjhC07l2x7sXrgrUOcxx1pr5vkCDFJNeI/qCeaoAAsp8+Rt+4f53FH1PiXfXjdfG1dXAzamAiwWIuNXWsZBVoFzufA2SvMauSv3wGAPrEB2pDAEs9xlKq0jfHstcB87Az5a7eBohB4y5BSzz2QAjvPg49GlKq4C9KxcyMCe7UjJzyKwPxsds7629XmuAWG9omoikKt9CQOHUit1MfO37gTEDvh3SkarKoYkVOr5HXNo8lesBI1Nw+f+nWcO/bj6bApSUgb3CjTsJJqRGM0EHCDXdPOXsQOMKgBBOjhVCZsTnKVda5BKRJNK5qONfpAU9HndMWb3zy7Y2fs1uay5zgbJzx845tpz1wYr+mIT53S0xS/2P3rvgkikukuSMfOjVA0GvQ3iiaKiD+XkCGDdmiDAqBRAgDqP7sxV2JzbP4Gu2PXuWXhfFgPX7Quh3Nszi2DUezbcUe0rme8d8yElLgXjqL1nIWrsGXnAsKJGSI0s/llr6sscx1VVWeXv367ePxS6sm8oSNWtVqda2HQmKGlnmuxwVy7ruKoxKq2rGJIx87NSBgn0rEdDm9g8UY5YgwgyF5n1+jkbvZfLOPkcqKqKnn2iJ2xS2tn44Q3Onbm40nk/71dpB5Gi/8/m1pYxySbJiSuwNA+EZ8Gcai5+WTPX+E8fpP9S/T3w3jdfG3fznbHbvNuZ+0hFK5bVzKBwnoxnYJ9RwEwXsaxKypM3NqDa21zV2zGevYCmrBg/Af3KPXcVSeEDFS4Ea6pW00GlhPp2LkZhib1yGjcDJ3Nyukf/3K1OW6BU6j41G5nZO1qMR85he1iBoqvHn3Lxl6xG70cDnV1Y+8O6GqLVXvDaTidBYF6GNjAldZJvBVFUZxRu6Lp2I61IC4IsgsKO7a9BUPzhih+vtgysjEfPOE87pA82XWeCmc1HCUpPk3qoY0ILfEcpzBxhGcLE2fNWAhA4E0DUAylv1BHxHhEUzEdxZ2Qjp0bEnWb0BBLXLuYozJoh29H0X7f5OwBdpyunC26Iw1raNuMpAI9F3LFUO2WUZXy8DUG1WYrHCFWQtPEsMbgW7JWqURS5QSOHgiKQv76HZiPnQFE1Gikl2raKT46DHYJqKLp2PqhEKQXHewVzWrk/b0duHy0Dryjvs6akkbOknUABI4tPQ2blgd/2aVVR7lhRkM6dm5I7Oi+WLU6Gicf5M/FR11tjsvxqR+HNTgYvbWA1K0HK+Ux8zeIYmHfzq2ci1ZihPc5MXnr/sFyKhlNUIBTuy6noFDdX6ZhJa5EVyuqRE07x//l2pNwNssVlrkOo73OzqHBCfbZrVeYjnU4dr6lNU7YH7O9B2c0sn75E8wWDG2aYkgsPU0x7yAUWKF5JDRzwylF0rFzQ7ThIeR27wqA6ZfFWL1c005RFKfsSdj+3aTlXf1jOhZFY5fWhfNhPXjRuhxO7bob+qExGgDh1OWaISHEsxdySc3AEUnO+nmJs64sPhg61QKVwgJ2b6GsBoptFShXsWZmU7Bb7OKMXduUeI7JArsdwsQeuh6oqlqoXVdGtA4K07Du1jThQDp2bkrdcSId22Xbn6w/UTVzUmsSgZ1FOrYy9OwsySlYjp8BRcHQsXmhMLEHpxlKwpaVQ87CVUChICwUtvBL7TqJO+A/uCeaQH8sJ8869daguKbdlQrz1kQMHZuDomA5fgbL+UIJqA61xM9NFZCByd+wE1QVn4Q6ztnQl7LngohOhRmhbvDVWO6+mLbtxbz/GIrRQMCN/Us9d3+KqGX00cDwJtVkYAWRjp2bEjKoK/kBQURkX2TTL1tcbY7LKSZUfPbqVnHHTlef2ACzX4DH70YvR/Zvy1HzTPg0quus2zmZIRonFKR2ncQ9KKpplznzd+fxoY1E6cSRtMJUoTegDQpA30xIQBWN2rWLFSnZ05nlT087S1K6la++zlM3eo5onf+wa4TEVim4q3ZdUaRj56Yoeh8014mdQ9AfS8jyck07Q9tmqBotEVkpHN57/qoeq1C/rhV7LoDZJlrW44Iqw9KagyMNG3hroXadI63VPc47hZol7kmhpt1KrJlivEKgQTh3AD972Xht3y5tAMgvMpUjQC9qvqD8UTtn40Q56us8deNry8kje+4yoGztOrMVfnNT7bqiSMfOjal3p9AU67pvNb9vu8JZMR6Cxs8XEsUqbt2y66qEih36db5dCoWJ28V67m60JAoOnxRDvzUaAkcVatfNtteO3OTGi5bE+zB0aI5Pk3pC087+JQxws31e/YKDounHWzD2bAdA7tptxY53cqRjz5T9GLbsXEzbDwDge5n6OvB8YeLs+StQs3PR1atdagMJwGo31q4rinTs3Bjfds3Iia+Lr8XE0ZnLXW2OywnuIlbxRid3s/vClT2GLTu3sFi4cyvnWCJvq6/L/GEBAH79uzhra9afFqOaAvVidJNE4i4oikKQvag964eFzuOdakG9YMgxF3ZyewPGbm1AUTDvP1aszq5jbfGzPOPW8rfsAasVXVwMPnElL4Bns8T8Wa0HCxNnfjcfgKCxQ52Zi8sx256GdUftuqJIx86NURSF8NvEYtZizSIOpLjYIBfjaPNveXInG8uxIy2J/C17wGZDFx+LNjbKmbLoVLuSjKwBqAVmp3Zd0O3DnMd/sqezhjcRo5skEncicNRA8NFh2r4fk31zpiiFKbGfvCgdqw0LRt9CzFbLKxK162iP2B24COn5pT+GU+aktGidPQ3bNAL8PHBNMO05jGnLHtBpi+l4lkRqHvxlVx+7yc3rj6Vj5+ZEjx2ITaMl8cxeFi8+5mpzXIojTN7g3GF2Hsi8osco1K9ryaFU8WH11UErD92NlkTOH2uxpaSjjQ7Hr38XADLyYbE94uFIb0kk7oQ2ItSpteiYEACibECjiCiVNwm6G3u2ByBvzVbnsQg/qB8irm8pI2qXXw5h4q0enobNnC6idf5DeqGLDi/13Dn7RD12yyhIdEPtuqJUmWOXmprK2LFjCQoKIiQkhLvvvpvs7MvXiaWmpvLII4/QpEkTjEYj8fHxPProo2RkZFSViTUCXVQY+T2Epp31l98xWVxskAvRRYdjq18XDSoF67dfkb6fQ7/Ot3MrZx1KuxjQu3FYvbJxpGEDbx2CohOKzHMPgMkKzSK8b/qGpObgTMfOXoItT3SUxQRAb3u9k6NG1Bsw9hB1dnmX1Nk50rGlNVDY8kzk/yPyiqU1TjicQ0/Us7Rl55I1ewkAQeOuL/VcVYVZ9ojwLTVg41tljt3YsWPZs2cPS5cuZeHChaxevZr77rvvsucnJSWRlJTE+++/z+7du5k2bRqLFy/m7rvvrioTawx17xaLWa9/FvPXIS/27ICgXm0BaHJkGwcqODpHNVswbRWfTmPnVs50bmcvSsOaT54lb5WQzwkae53zuCONdXNz72oikdQsjL07oKsTjS0jm5zfVzuPO6LMv+wDi5cIuhu7tgatFsvxJMwnC1WJy9NAYdq2FwrMaGMi0CWUvABmmYReG0AXD1wjs+cuQ83OxSehjtNJvhzbkuFQqsjuXO+m2nVFqRLHbt++fSxevJivvvqKzp0706NHDz799FNmzZpFUlLJ24gWLVrw66+/MmzYMBo0aEDfvn158803WbBgARaLdzszgQO6kBcaRmhuOttn/e1qc1yKv/0D2Ob4PxWuszPtOoiaZ0ITGoSuUd1Cx65OJRvpxmTNWASqirF3B3zqiW+AXedh7wUwaOGGpi42UCIpBUWrJdAuSVE0HdsvQXQqns+BVcddZFw1ownww9BOFHsVjdo56oV3n4e8y4zWLjof9nINA1vOik75usEQ64HSR46micA7hqFoSneFZu0WP69rBEGGqrbs6qkSx279+vWEhITQoUMH57H+/fuj0WjYuHFjuR8nIyODoKAgdLrLD/A0mUxkZmYWu3gaik6H301iEkX80t8543kvsdwYu7cBoP6Fo+zaV7GCGqd+XaeWnMzScC5HpGDbeklHrGqxkGnXrisarXMsWgMbQIivKyyTSMpP4K1DQFHIW7MN8zGxO9MX2ZT85EXpWL8S6uzigiDaX9SDbT9X8v3y128HSp8Pu+G0+OmJGQ3TjgOYtu8HvQ9Btwwu9dwsk5DTgZqRhoUqcuySk5OJiipeqKPT6QgLCyM5uXwS4SkpKbz++uulpm8BJk6cSHBwsPMSFxd3xXa7M3XGi46dzoc3sGCd97bHaiNCsTSsD4Bp/fYKjRLKd+jXdW7pjNa1jhbhdW8gd/lGrGcvoAkLxn+IKELPt8B8IWUlmyYkNQKfOtEYr+kIQObMRc7jjv/fZcfgQo4rLKt+HHp2eWu2odoXQ0UpjNqVlI5VC8xCHYDSGycca2RXD8xoZEyfB0DAdb3RRoSWeu7CQ5BngQahhWPb3J0KOXbPP/88iqKUetm//+onMmdmZjJ06FASExN59dVXSz33hRdeICMjw3k5derUVT+/O6JvVJecFi3QqlYuzlqCzYtmI15KcG9RZ1f/4D8cSy/ffVSrlTy7SruxaxvnouVNMieZ34vUVeDNg1AMekBof2UWQJ0g6OaZeyKJBxJ0m4g4Z836A9VeqtM4XETfLTaYc/VfQzUCQ4fmKAY91nMXMR8+6TzukD0pSc8uf/Nu1DwT2sgwfBrXK/Fxswtgpz3a52kRO1tWDtm//gVA0B2lN01AYUZjdA2qP66QY/fUU0+xb9++Ui/169cnJiaG8+eLj32yWCykpqYSE1N63isrK4tBgwYRGBjI3Llz8fEpXTzHYDAQFBRU7OKp1B4vaku6bfydv096r2cXYN+ltjmxrdx1dqadB7FlZKMJ9MfQponXNU5YklPIXboeKPxSBPjJsWjZJSMkkpqA/6AeaMKDsSankLu8sLxntF3T7ue9VCiaX1PR+Brw7STmaOetKVJnZ3fstp39dzOJ4+9l7NPxsvV1W8+CVRVp3doe9pWa9etS1Nw8fBrVLXPSxP4Ukc7WaWBkDao/rpBjFxkZSdOmTUu96PV6unbtSnp6Olu3Fub9ly9fjs1mo3Pnzpd9/MzMTAYMGIBer2f+/Pn4+sqCn6KE3tgXs8GX+IsnWT1vt6vNcRnGbm1QFYV6KSfYvad8rbGOTlDfHm1JytVxOlOoqXtiG39JZP34O1it+HZqid6+Sz+eDhvOgIJ7zz2USC5F0fsQOFrUHTsi0QDDGoNRB4dTRfG/N2Ds8e86uyYRosg/xywao4qSu2ITAH72dHZJOOrrPK0bVlVVMqeJNGzQHcPKnDThUAvonwCR/lVtXeVRJTV2zZo1Y9CgQdx7771s2rSJdevWMWHCBG655RZq1RJbiTNnztC0aVM2bRL/ZA6nLicnh6+//prMzEySk5NJTk7Gar2KwaAehCbAD2VwHwCCFy4qU1ncU9GGBmFpZFddLzIEuzTyVgvHzq9XBzba0xMtosTgbE9HtdnItHcQBhaZNOEYnN6rLtTywK43iWfjiDznLl2P5azwXgINcF1jcfvMXa6yrHpx1tmt+wfVJsJzGgU62DetRevsLBfSKNh1SNyvdymOnYcqBpi27aVgz2EUg57Am0tvmsi3FKb0b2lRDcZVIlWmYzdjxgyaNm1Kv379GDJkCD169GDq1KnO281mMwcOHCA3NxeAbdu2sXHjRnbt2kXDhg2JjY11Xjy1bu5KiL9LLGa99iznt225LrbGdTjq7OL3/8PpMrqEbXkm8jeJCKexdwfnQuctadi81VuxnDiLJtCfgGHXACI945h7WFM6vSSSougb18O3a2uwWsn8foHz+FgxeZBFh8oeq+UJGNo2RQnww5aW6ZyDDUUaKIrU2eWt2gyAvkUjdFFhJT5errmwvq6Lhzl2jmid//A+aENLzzH/eUT8/8QGQK/46rCu8qgyxy4sLIyZM2eSlZVFRkYG33zzDQEBAc7b69Wrh6qqXHPNNQBcc801qKpa4qVevXpVZWaNw7dLS/Jq18GvII+jPyzzijqSknAIFbc5sa1UhXWA/E27UE0FaGMj8WkYz0YPbuMvicxv5wIQcNMANP5GQHQOns8R2l/967vSOonkygkaPwKAzO8XoJpFE0WbaEiMEJNUft3nQuOqCUWnc3a3ljQ3dktSYb1h7grh2Pn16XTZx9uaJDZ+tQNFjZ2nYE1JI3vuMgCC7xxR5vmOSROjE0Fbw4av1jBzJYqiEHmX6OTpvHZ+hUV6PQXfrq2xaTTEpZ5mz64LpZ7rqK/z69WeC7kKR9NFXVnHGtK6fjVYzpwjZ/E6AILvusF5fIY9TTU60bvGqUk8i4ChvdFGhmJNTiFnifg/VxQYY4/azdjlHU0UhbInhXV2raKF6PjFPDiSJurL8laK0idjn8unYdfbv1M8Llr3/QJUUwGGNk0xdCg9TXEyA9adstcf18CMhnTsaiARYwdj1fnQ9Ox+ljpEyLwMbXAg5saNAMhdW3qdnaO+zti7g7P9v1kEBHtBb07mdwvAZsO3e1v0TRMAOJEOq06IRcvxBSiR1EQUvY9zEkXmtN+cx0c0AT8f4dCUFdH3BJwNFOt3OCOXRcXXNydBwZ4jWM+novj5Yux0+Q++J2Y0VLOFDHsaNviekWU2TTg2vj3ja2bUUjp2NRBteAjqwGsACJk3jxQvLbVz1NnF7t122b+BNTUD004hG27s2d7Z7eUN+nVqgdlZexQ8vjBaN9PeUN2rLsQHu8IyiaTyCLrjejGJYtUWCo4ILbdAAwy3z/Sc4QVNFPrmDdCEBqHm5GH6pzD/3LGIUHGuI1rXva1Tx/JS8syww15f50nCxDm/r8GadB5NRAgBI/qWeq7JIuRyAG5rVQ3GVQHSsauhxN8n0rF9dv3FnM1eIrN+CSG9Rfqh7fF/LjvwOm/tNlBVfJrUQxcT4VWNE9mLVmG9kIo2Otw5aaLYoiWjdRIPwCc+Fr/+XYDC4niAsfZOxj8OQ2qeKyyrPhSNBmN3sdHNLZKO7WwvN1l/GvKW2x27a0qprzsrRpHFBtTMSNXlyPjqVwCC7xh+WafWwe/2/5fYADGDuCYiHbsaim/X1uTVrYvRnMfpGX965SQKY9fW2DRaaqUnsXtHyUMR81aLRc6vVwfS8mC/XfbOGyJ2mV+LpomgO65H8RFz0/4osmj1raGLlkRyKUH2YvisWX9gyzMB0DIaWkZBgRVme8H8WMeYtdxlhYLNHWqJOrvUtHxy7bOy/fpe3rHbUKS+rqZMWSgL065D5G/YATotQXcOL/P8H8SfiVtbCGHimkgNNVuiKArRd4t/0h7r5rHmhPd5dpoAP0xNRb4lt4jqelEcjl3R+rqGYRDhVy0mugzTnsNiNq5WW2xsjiMtVZMXLYnkUvz6dUYXF4MtPYvs35Y5jzukT37c7flNFP7XdgXAtGUPlgtpABh9RJNY6xPbUcxmdHWi8Wlw+dmBnlhf54jW+Q/tjS42stRz910QwtZapWbLQMmlvQYTdusgLHo9Dc4fYcU8L9iSloCjzi5y97/r7Mwnz2I+dhq0WozdCufDdvaCbtjMb38DwH9IT3QxEQAcvCgKybVK4cB0icQTUIpsYIo2UVzfWIiQH0uHv0+7xrbqQlcrCn3LRqCq5P613nm8Zzx0POrohu102caBfIsYnwWeU19nTc0ge85SAILvHVnm+T/YN74DG0B0QOnnujPSsavBaEMC0QwRhaDR83/jbJaLDXIB4X0cdXbb/hW1dETrfNs1QxPo71zYPT0Na83MJmv2nwAE332j87gjWte/PsTU4EVLIimJwLHXgY8O07Z9mHYItQB/veiQBe+YROE/sDsAuUv+dh7rWRfa2x07fa/Ly5xsOyvS1tH+UNdDmqoyf1iIml+AvmUjfEvpBAbILoC59kkTNbVpwoF07Go4cfeJdGzvvcuZs9H7PDvfzq2w6vVEZ55n97pjxW4rKnNyPqdwZmKPGqYiXlGyf1oshlw3qecccp1rLhRrvV02TUg8EF1kqHOySkaRqJ0jHbvkCB6vIODncOxWbMKWL2oNG5nOUS/lBFZFw4HGHS57340eVl+nWixkfjMHKJ/Eydz9YrZug1DoVsMjltKxq+EYOjQnr0EDDJYCzs9YjMXmaouqF42fL5ZOQsOJlX87m0hUm83ZHWbs1YE1QgWBllGeXV+nqioZ9jRs8PgbnIvZ/AOQVSB24t093LGVeC+OJorsOX9hzRAb3cRIoedmtolaO0/G0Kox2pgI1Nw88tdtByB/pZg2sb92IqvSLz8U2iEF1cVDMho5i9dhOXMeTXgwATf2L/VcVS1smhjbsuY7ttKxq+EoikLsPSJq12vDPJYd9fAK4RKIGdYNgJZ7/nZG5Qr2HsWWko7iZ8S3fSIrj4vjveu6xsbqIm/tNsyHTqD4GQkcPdB53JGGHdNSDAiXSDwR3y6t0CfWR83NJ+uHhc7jd9hTaz/sArPVRcZVA4pGg/8AsR46JnHkrhBp2C0JHVl7suT7ZRcIqRPwnPq6jKmzAQi6bRgaX0Op5245KxQTfHVwU7PqsK5qkY6dBxB68wAsvkbqpZxg5S87XG1OtRM4QHSDJZ7Zw/qd6UCRNGzX1th0Pqy2L2jX1HOBgdVIxtRfAAgcPQBNoD8gBnrvPC+U6EcnutI6iaRqURSF4HtHAaIbUrWIKQxDG0GEEZKz4c+jrrSw6vEbKBy73D/XYbNYnGvh5gad2HkO0krQ9Ft3SkQ06wVDQmh1Wls15G/bS/56IXESbJ8nXBqOaN2wxp4xkUg6dh6AJtAf/XARam74+xwOpLjYoGrGp040OfUboFVtpPwpNJwKZU7as/M8pOdDkL5wxI4nUnDkJLn2XXrwfaOcx7/ZLn4ObQRhRhcYJpFUIwE3XYsmPBjL6XPkLFoDgEFXWGv37XbX2VYdGHt2QDEasJw5T/acZdjSs9AEBaC2aIqKcOIuxZHR8JSNb/pnPwIQcOO16GpHl3ruxVwhSgyeU38sHTsPodZDopW75/7VzF5WslivJxNg36VGblpPVoaJvPXbAbHIrTouzukR79nabRlTZoOq4jegG/pGIud8LgcWiolq3NXGdbZJJNWFxtdAsL3WLn3Kz87jY1uKz//mJNh93kXGVQMaowFjL9EkkfXdfEDInHRLECLlay5Jx6oqrDgurvepV01GViHmo6fJWbgKgJAJt5Z5/k97RDdwyyho7SEbfw/+mvMuDIkNMHdsh1a1op01l/R8V1tUvdSy19l1OLKRnXM2oebmo6sdhb55A1aeEOd4cn2dNTWDrFl/ABDy4M3O4zN2ihRL+1hoVfrGVSLxGILG3wB6H0ybd5O/dQ8gdMmGNBS3T/PwihV/+0Y3f7tohQ+4rje97E1Ta04WF2s+cBHOZov6si4eUF+X/sUsscG9tiuGZvVLPddshen2NOydravBuGpCOnYeRJ2HRdRu4NYF/LTFuzw7Q7tm5AUGE5ifjflHsUv1G9SD9HzFOdTakx27zGnzUPNMQq/JPjPSZClsmpDROok3oYsOJ/CGfoA9km3nzjbi5/wDIgXnqfhdKxw7TGYUgx6//l3oXFvU2Z7JgqPphec60rBd6wjnriZjOZ9K1o/2De6EMWWe/8dhUXcZ6Sfq6zwF6dh5EP6DumOKjSU4L5Pj3//pVdInilaLpadoooja/Q8A/oN7svYU2FRoEg6xl+/0r9GopgIyvhZjc0IeusUpcbLgIKTY58IObOBKCyWS6if4/tEAZM9fieWM2N21i4FWUWCywqw9rrSuatHFRKC1T5zxaRSPJsAPow90iBW3F03HelIaNvPrOaimAgztE/HtWnoITlXha/FVwe2tRB2mpyAdOw9C0WqJvE9MGui35heWHPYu6ZM614tdqsFsgsAAjN3aOOvrPDlalzXnL6znU9HGRhIwXEwiUVX4eru4fVxr8NG6zj6JxBUYHNFrq5WMr4VQraIURu2+34lHb34dHcEohV/zvezr4Bp7eUqmCecM7Zru2Nmyc8mwCxKHTBhTpiDxtmQxQk2vLWys8RSkY+dhhN0+FIuvkYQLx1jz8zZXm1OthF7bERviw3yxRWvQ6Ty+vk5VVTIm/wSIWYiKj9h2bkoSkzZ8dXBrC1daKJG4jpAHRNQu87v52HKEzsd1dumTs9liGoUnUnD4JLaUdHH90HFseWIKRU97nd3606JhYM1JsKpi2kJ8DR8jljljEbb0LHzq18F/cI8yz//GHq0b3sTzROulY+dhaIMDMd40CIDEP37x6O6vS9EE+qPq9QCcVoLYlwIXcsGog461XGxcFZG3cjMFe4+i+BkJuv1653HHonVjUwjxAF0mieRK8BvQDV292tgyssn6aTEgUm5j7BGaadtdZ1tVkrNgJQCKQQ/5BeTZp/AkRkK4UYzO+ifZc2ROVLPFucENefhWFG3pKYozmaK+Djyz/lg6dh5IzMM3AdD14DrmLDnjYmuqD/PB42gLxM6UpGSW2UfHdovzrPqJoqR/IRazoLFD0YaIIsJTmYUirOPbuMgwicQNUDQaQu4T62HGlJ9RbSL3eptd+mRTEuy54EoLq4Zsu9yHob1QJM/9829ATJ3pHifOWX2isL6ub71qNrCSyZ63HMvpc2gjQwkoMnHnckzfKSKV3eoIZ9fTkI6dB6JvGI+le2c0qBh/nuPxg68d5Pyx1nk98eQulu/KATw3DWvad5S8FZtAoykmSPzdDtEw0iMOGoe70ECJxA0IvHUImqAAzEdPOx2cotInjgJ6T8F8IomCnQdBoymcnbtoFapZ1Nw56uz+PCIyGn4+NTujoaoq6Z/NBCD43pvKHB+WU1A4M/jutlVtnWuQjp2HUmeC2KUO+GcRszZ6h2eX84dQmc8JCsHHZiFwsxh+XdOLgi9HxuezAPAf0hOfemJlzimAWfZFyxNTDBJJRdEE+BE0TszTTvvoe1S7iNs99i/1eQcgKctV1lU+OYtWA+DbtTUB1/VGGxmKLSWd3BViKo+jzu5gqvjZo4ZnNHJ+X0PBniMo/kahX1gGv+4TTSN1g6FvQjUY6AKkY+eh+PXthCkujgBTDknf/UGe2dUWVS2W5BRM24QYZ37vngB0OryehJCaXxRcEuYTSWT98icgakoc/LIPMgvEzMc+HrpoSSQVJfiB0Si+ekxb9zrrzVrHCEFei82zonaOqQsB1/VG8dERMPJaALJ+WgJATAC0LiJWXpPr61SbjbR3vwYg5P7RznKUy2FTC0fKjW8jUtOeiHTsPBRFoyH2IRG1G7z2J37eZXGxRVVLzmKRhjV0aE78rUKYtPPh9XSM8Uw9g7SPfwCLFeM1HfHt0BwQX1BT7Y3Q49t67qIlkVQUXVQYQbcNAyDtf985jz/YXvz8cTdkeICmuyU5hfzNImTvP7QXAIGjRTNdzuK1WNNFaLJPkfKUa2pwqUrO/JUU7D2KJiiA4CITdy7HiuNCnDlID6MTq9w8lyEdOw8meMwQzMHB1Eo/y65vl2O2utqiqiPnd5GG9R/ck6hercg2BBCWk0bcAc+bHWQ+lewcHxb69J3O4wsOwulM0fV2swcvWhLJlRAy4Vbw0ZG/7h/yNoo5Ur3rQrMI0SX6/U4XG1gJONKwho4t0MWKrgBDy0bomzeAAjPZ85YDEFykU97Pp9rNrBRUq5XU974BIPjBsqN1AFNFsJbRzcFfX5XWuRbp2HkwGj9fwu06TgOXzWDhAc+MXtmycshbK0JV/oN7sCvNh1XNegMQ8dcSV5pWJaR/OgPMFow922Hs3AoQKYYvtojbx7cBYw1drCWSqkJXO5rAm0X0Kv3D7wEhWHy/PWr37XbIr+GJDUc3bIA9WufAEbVzSL7sLtIJ7Oigr2lkz12G+eAJNCGBhNinjJTG5iTYcAZ8NJ7bNOFAOnYeTti9N2Ax+lH/wlHWfre+2PBnTyH3rw1gtuDTMB59o7osOARLW4qFrMXWFZxNMbnYwsrDknSezBmLAAh96k7n8eXHxDDvAD3c4UHDrCWSyiT0kbGg0ZC7bAOmHQcAIVhcO1CM3/t1n4sNvAqsKWnk/70dAP/rehe7LeDG/qDRYNq8G9ORU079OoDfD1WfjZWFarGQ9t63gKgx1gT6l3mfz0UvHTc2g1oeOl7SgXTsPBxtcCCB9pb33kt+YPkxz/PsHPV1/oN7YFNh0UHYFd+KlNAY/Aty2fLdGhdbWHmkfTIDCsz4dm2NsbvYdqoqTLJH625rCcGld/tLJF6LT/06BNwganDTPhJROx8t3NtO3D51K1hraGIj5481YLOhb9kIn7rF9Ut0MRH49ekEwIGvlnAxD/ztUf11pyCjhu19s2b/ifnoaTThwQTfM7LM8/dcgOXHRd2xo67Sk5GOnRcQ9fBorD56mp/ezeIft7vanEpFLTCTu3Q9IOrrtp6FpGzwN2jIGiSEKjXzPCMda0lOIeuHhQCEPjPeeXxTEmw7Cwat56cYJJKrJfTx2wHRPVpwQKiY39xcTGg5ngGLa+iYscyZvwMQMKJfibcH3CzWQ9tvS1BUG9c1FjqXZhssq0HpWNVsIe2DaYCIwGoCyp4H5ojWXdcIEkKr0Dg3QTp2XoAuOhzD6MEAtJ//PZs9aBhF7rIN2LJy0EaFYWifyMKD4viA+pB4zwAAGu3bTNLRiy60snJI/2wmqqkA304tMfZo5zw+yb5ojUqEqLIzEhKJV6NvmuDsGE37+AdANBCME+WqTN5CjStZMe05jGnLHtBpCbxlcInn+A/qiRLoT2BKMi1P7mB4ExjcQNz2++FqNPYqyfrxdywnzqKNCiuXbt3RNFhkTzc/1KGKjXMTpGPnJdR6fAw2jZaORzcz99cDrjan0nDsUgNHD8SGxvkBvq4x1GoVz8l6iWhVK7u++cuFVl49lnMXyZw+DxCdsIoitEx2n4dVJ0SK4b52pT2CRCJxEPrEHQBk//oX5mNipzuuNfjqYOd5WH/aldZVnMzvFgAia6GLCivxHI3RQGbfPgAM37uYLrVhSCNx2+oTkF1QLaZeFaqpgLT/TQcg5LHb0fiVPQj78y2gAv0ToJkHjg8rCenYeQk+9Wqhua4vAA1/+YH9KS42qBKwnLvoTMMG3jqEjWfEiJxgQ6G6ummYaKIwLqrZ6dj0z2eh5hdgaJ+I8ZqOzuOT7e371zWCuiGusU0iqWkYWjfB2Lcz2GzOtF64X6G2mSMKXhOw5eaTPVusb0F3XF/quX+0EOthjz0rUPLzaRIOCSFgsooGLHcnfepsLGfOo42NJOiOYWWefyYT5u4X1x/uWPq5noR07LyIOk/fBkDPfav4ccFJF1tz9WTNXgJWK4aOLdA3rudMww5qCHqtuN5mfF/MGh21Tx/i1OaaWTxjSTpP5rdzAdEJ64jWHU/3vhSDRFJZhD13FwBZPy/BtEfkIu9rL+Qw1p6CjTWkZCV73nJsWTno6tXC2OvynQE5BTBd14qkkFh88vPI+X01ilI4M9fd07GW5BTSPhDRuvD/3FfmTFgQgu0WG3SrA+1iq9pC96EGT4iTVBRDs/pYrumObuU6ombM4NDwF2hUctTe7VFVlSx7Gjbo1iGYrfCH3W8b1qjwvJi4YDa17ErzHWs49O0S4jo+5AJrr46Lb36JmmfCt0tr/Pp3cR7/YovQr+tbz7NSDFarFbPZw2fguRC9Xo9GI/f0vu0S8b++DznzV5D6+hRiZ71HXJBopPhhF3ywHn4aKbTu3JnM70UaNui2YSilvK9Lj0KeVWFzx0EMX/otWT8tJvCmAQxuJLrqVx6HXLP7ChanvjEFNScPQ/tEAkYNKPP8Czliogh4V7QOpGPnddR97nbOrFzHtTuX8O2csbx1T7yrTboiTFv2YD50AsXPl4ARfVlzGlLzxNSFrnHFz9WMGAg71hDy51JU6/0oWq1rjL4CTDsOkP2zEBUN/+/Dzmjd0TSYvVec4ymLlqqqJCcnk56e7mpTPBqNRkNCQgJ6vQdL75eT8P/cR87vq8ldtoHcNVvx69meCR3FZ2vjGSEF0sONl0jT3iOYNu8utWnCwTx7abXfqEGw9FvyVm2h4NAJWjSsS50gMbVm5fHCujt3In/rHqe4csRbj5XqwDr4ZrtIMbeJhu5xZZ7uUUjHzsvw7dAcW++uaFetp+F3X7Lr+tdpGeVqqypO5kwh0hsw7Bo0gf4s2CCOD24Iuks+8x3HdCXpnUCCM1I4uWQbdYfUDE9IVVVSXv4MgICR1+Lbtpnztg83gNUeretQ6zIPUMNwOHVRUVH4+fk5nVhJ5WGz2UhKSuLs2bPEx8d7/d/Yp34dgu4YTuY3c0h97QuMf04lNlDD2JbCMXh/vXAK3PXPlPndfAD8B/VAFx1+2fNS82C1vfqmf+9a+A3qQe7itaR//ANRn/2HIQ1F2vKPw+7n2Kk2GykvfgxA4C2D8W1X9rzECzkwzT5N8uGO7vv+VRXSsfNC4v97P6eu2UDvfSv5asY+Jj7RrOw7uRG27Fyy5y4DIHDMUEwWWOJIwzb+9/nRYXr+6tSPLqt/49T0xTXGsctdvJb8v7ej+OoJ+7/7ncf3XID59nrCZ7q5yLhKxmq1Op268PDLf0FJrp7IyEiSkpKwWCz4+Lhp3q0aCX36TrJ++kNEx+ctJ/CG/jzYAWbuhn+ShbBtvwRXW/lvRNPEnwAEjRte6rm/HxK1Zi2ioGEY5D9xO7mL15L1y1JCn72LwQ1jmboNlh0TY9V83cgzyPp5CaZt+1AC/Iqtg6Xx6WaRVm4TDdfWr2ID3RBZaOGFGBIbiPQk0G7mZDafqVmiTdkLVqLm5OGTUAffrq1ZcxIyTULDreNlold+N4nXG7Z2Nbbs3Ooz9gpRzRYuvvYFAMH3j8anTrTztvf+Fj+vbwyJHlJb56ip8/MrW2xUcnU4UrBWq9XFlrgHushQQh4ZA0DqW1+iFpiJ8hfyJyBq7dxR1y57/gpsmdno6saW2jQBhWnY4U3ET9929u56q5X0T2fQJgZqBUCOuXCT7A7YsnJIfX0yAKFPjSs1KungRDrM2CWuP9fd+6J1IB07ryXupbux6nxod3wbv3292S0XrsvhaJoIHDMERVFYYI9eXdcItJf5j+4+rDmnw+pgKMjnxHT3lz7JnPYb5iOn0EaGEvrYbc7jm8/AiuOgVeDJLpe/f03F21OD1YH8G/+bkAduRhsVhuV4EhnThF7kA+3F2K09F9xzGoUjDVtW08SZTDGdRqF4Y5ljAkfmzN+xnUvh5ubi+Pc7q8riipP2v+lYz6fiU78OIfeNKtd93l8vopO960I3L6utcyAdOy/FJy4G/e0jAOg5ezJrjteMAYkFR06Sv2EHaDQE3jyIfIvo9gIhSnw5ogIUdlwrZgpmT5qBWuC+XZfWjCxS358GQOizdzkHXKsqvGOP1t3c3DtG40gk1YHG30jos0L+JO1/07BmZhNmhLvsI/r+t8G9ZsgWa5q4dUip5zo2vp3rQGxg4XHfbm3w7dwKCsykfz6LW1uIDePmJNh3oQqNLycFR06SPmU2AOGvP4KiL7tsYPf5wjKVZz2kTOVKkI6dFxP33B2YjX40Sj7EX5OW14ioXdaPfwDg17cTuthIftsv0gd1gqBdTOn3bfzAMFL9w/C7cI7Un/6sBmuvjLQPv8OWmoFPk3oE3Xad8/jKE2LRNWjhsc4uNFBS7dx5552MGDHC1WZ4NEFjh+LTMB7bxQzSP5kBwL3tIMgABy/CwkMuNrAImfaoYllNE1CYhh3RpPhxRVGcEzgyp88jwpTOQPuIsR92Vaq5FUa1WDj/yEQwW/Dr1wX/AeXz0t4tUqbSogY2BVYW0rHzYrThIQQ8eCsA/ed/xZL97hvFAvFhd7S8B44ZiqqKzjWAO1qVXUvRP9HA4mvE601+/3tUi6UKrb0yCg4eJ+PLXwEIf+UhFJ2oYraphYvWuNYQE+AqCyWXcuedQjRaURR8fHxISEjg2WefJT8/39WmSSqAotMR/pIozk//fBYF+48RbBDOHYhO9AI3KEs0n0omc8ZCAILuKn1W6oEU2JsiRJcHN/z37ca+nTC0boKam0/6lNncbp+XO2c/ZJkq2/Lyk/bxD5g270YT6E/Ee0+V6z5/nxLjFXUaeLprFRvo5kjHzsup9cho8kPCqJ12hi0fL8TiRumGS8n9awPW5BQ04cH4D+zOulNw4KIQ1LylRdn312kg+p7hpPsF45t0hqw5y6re6AqgWq2cf3QiFJjx69elmBjxokOw9wIE6OFBOWXC7Rg0aBBnz57l6NGjfPjhh0yZMoVXXnnF1WZJKojf4J74DewOZgvnH38b1WrlrjYQYYRj6YUbSVeS9t63UGDGt3tbjD1KHxD99Xbxs18ChJQwVlVRFEIcUbuvfqVTYBYNQkVH6Zz9lWx4Ocn/Zx9p700DIOKdJ/CJKyMVg71MZZ24PqaFHK8oHTsvRxPgR/jT4wAYuHgaP67PdrFFJaOqKmn2urOgMUNR9D58/Y+4bVSimA9bHkZ3MPJb15sBOPv+96hu1BmYMflnTFv3ogn0J/J/zziL3E0WeN8erbuvHYQZXWikpEQMBgMxMTHExcUxYsQI+vfvz9KlSwGhHTdx4kQSEhIwGo20bt2aX375xXlfq9XK3Xff7by9SZMmfPzxx656KV6NoihEvvskSoAfpq17yfhqDgF60V0J8MlGOJfjOvsKDh53Zi3C/+/+UhthzuUUzkm9txT/z39wD3yaJmDLyiHrm7nOqN33O6u/G9iWk8f5B18HqxX/4X0JuKnsCRMg9Pe2nxOb/Ec7VbGRNQDp2EmIHj+MvDpxhOWkkvHmZC64cOG6HLlL1mHacQDFz0jIQ7dwNE3oSwGMb13+xwnxBWXsjWT6BqI7doKcBauqxN6KUnDoBKkTvwJEobCuVmGByJRtcDwDIv3g7rausrB6UVURNXDF5Wq/zHbv3s3ff//tlBWZOHEi3333HZMnT2bPnj088cQT3HbbbaxaJf73bDYbderUYfbs2ezdu5eXX36ZF198kZ9//vlq/4ySK0BXK4rwV8XowdS3pmI+eZabEqFtjKjnfXut62xLnfgV2Gz4De6Bb4fmpZ47fbtIHbePLV3EXNFonJ336VN+5oa4XIw6OJRa/fNyL772uVADiI0k8r2nytXBXWAVnbAA97SFSP8qNrIG4EYyhBJXoeh9qPfJM5y78VEGb57HtG8G8MwjrVxtlhPVZiP17a8BCL53JNqIUL5dIW7rl1Dx7tCxXf35odMo7lz9Deffn47/9deUa0RNVeFIwaqmAox9OhE4prDL7WQGfLZJXH+5l0jFegN5Fmj2uWuee99DFZ+XuXDhQgICArBYLJhMJjQaDZ999hkmk4m33nqLv/76i65dReFP/fr1Wbt2LVOmTKF37974+Pjw2muvOR8rISGB9evX8/PPPzN69OjKfGmSchJ0+zCyf11K/vodXHjqPWJ//oDXeisM/0mkKMe2rP6JL/nb95OzcBUoCmEv3FvquTkF8L29AeK+0rO1AASM6Evae99iPnqagtc/4YZhzzNzt4jadalTCcaXg5yl68n89jcAoj59EW1oULnuN3UrHEkT4yTL81q9ARmxkwAQ0LMt5huHAtB+0rtsOlbgYosKyfl9DQV7DqME+BHy0C1k5BfOSb2SCFaTCDg9/CZy9H6oB46Su2Rd5RpcQTKmzMa0ZQ9KgB9RHz7r3KWqKry8Usw77B5X8lQNiXvQp08ftm/fzsaNGxk3bhzjx49n5MiRHD58mNzcXK699loCAgKcl++++44jRwrF0SZNmkT79u2JjIwkICCAqVOncvLkSRe+Iu9G0WiI/PBZFIOevJWbyf55Ca1jYLQ9SPbyyuqXP0l9cyoAAaMGYmhW+jiFn/YI0faEkPJNXlB0OiI/fA4UhawZi7jj/GpA6PdVR+rZmpLGhcfeBiD4/lH49S5fIfHxdPjEvvF9qRcElrMkx9ORETuJk0bvPMzeZeupe/EEC1+aQbvvxv9r7mp1o9pspL37DQAh941CGxbMj1vtEZ0I6HaFu8lbuwfyW4cbGfv3D1x8fzp+g3q4RLi14PBJUid+CUDEfyegq104YWLJESFG7KOB1/t4l4K6USciZ6567ori7+9Pw4ai7fCbb76hdevWfP3117RoIbp6Fi1aRO3atYvdx2AQ30KzZs3i6aef5oMPPqBr164EBgby3nvvsXHjxqt7IZKrQt8gntBnxpP6xhRS/u8TjH068Wy3MP44JESLZ+0RkbvqIHfNVvJWbgYfHWHP3VXquRYbzvrje9tdXrT9Uozd2hAy4VbSP52J72vvcs1TzVmZG86s3VUrr2TLN5F876tYL6Ti0zSh3GPDVBX+b4XY+PaI+7ecizcjI3YSJ9qQQCLffBSAAX99z+wFx11rEJCzYBUF+46iCfQn+MGbsdhE7QjAXW2u3NnplwB/9x9Nno8v5p0HyF36d2WZXG5Uq5ULj05EzS/AeE1HAoto1uUUwGv28r/720MDLxMjVhSRDnXF5WodaI1Gw4svvsj//d//kZiYiMFg4OTJkzRs2LDYJS5OyOKvW7eObt268dBDD9G2bVsaNmxYLJoncR0hD92CvkUjbOlZpDz/IeFGlSftUhrv/Q3p1aBoo6oqqW9MASDojuvxiY8t9fxFh+B0lkhNjqzgGPCw5+5G37whtosZPD7/bVBVZu6mytQSVIuF8w/8l/y121D8jURPfhmNb/nCbvMOwJqTQtfzzb7etfEtC+nYSYoRPbov6V26oreaCfjvuyRnuU7/RLVaSX1PROuCH7wZbUggiw9DUrZYtK6/ih2aVgM3dg9lXgehA3Xh6Q+wpmZUhtnlJvXNqeRv3o0S4Efkh88Vixh+skm8zjpBMKFjtZolqQRGjRqFVqtlypQpPP300zzxxBNMnz6dI0eOsG3bNj799FOmT58OQKNGjdiyZQtLlizh4MGDvPTSS2zevNnFr0ACoPjoiProOdBqyVmwkozPZ3F7K2gSDmn5Yo5sVZP7xxpM2/ah+BkJfXJcqeeqKkzZKq6Paw2+FYw+KwY90ZNfRjHoCdqwgVt2/UZytug6rWxUVeXC0++Ts2g16H2I/eFtDM1LENsrgYx8eF1ki3mkE9QLqXz7ajLSsZMUQ1EUWn7+FPkGI4knd7Hgv/NdZkv2vBWYDxxHExxA8P1iTqBDl+m2lhVftC7l5kT4uc94TobHYz17QehWVVN/f+b380n/dCYAke8/jU+dwhTswYvwlT2V8t/eYKxgIb/E9eh0OiZMmMC7777LCy+8wEsvvcTEiRNp1qwZgwYNYtGiRSQkJABw//33c+ONN3LzzTfTuXNnLl68yEMPuSgPLfkXhtZNCP/vBAAuvvYFpmXree0acdsPu2Dr2ap7blt2Lhdf+wIQtWe6qLBSz193SqSJfXVCtP1K0DdNIOylBwC4a/Ek4i6e5N2/hexSZZL6+hSyZiwCjYboqa+WqclXlLfXQUoeNAwTGQ1JcRS1ur7JqonMzEyCg4PJyMggKKh8XTWSf7Pvg1/Qv/0x2QZ/MmZPo1fXskUiKxPVauVUjzswHz5J2Av3EvrkHWxJgpGzQa+FdeMhqhLa2v+zHNYvPcQX0+5HZzET8c6TBJeh5n615K7czNlbngGrldBn7yLsmfHO22wq3PKrkBkYUB++HFalprgN+fn5HDt2jISEBHx9S1BSlVQa8m9dcVRV5cKT75L1w0KUAD/qLJ7Mc8cTmLMf4oLgjzGVX7ivqirn7n6ZnAUr0cZGErdmOtrgwFLvc8dvYvrCHa1EXe4VP7fNxtnRT5G3agtHajflgTu+4NneOh6oJCcqfdKPXHxVtL1HfvhcsdGJZbE5CW4SI2SZfRN0ql36+Z5CRXwbGbGTlEjTx2/gQuPmBJhyKLj/RZLP51Xr82f9tBjz4ZNoQoMIvnckVhu8Yq85G9Gkcpw6EGKWyXGN+KKviJBcfPkzTHurrr7JtO8o5+56CaxWAkYNIPTpO4vdPnWrcOp8dfBK7yozQyKRVABFUYh850l8u7RGzc7l7G0v8ErrDOoEwalM0SVb2WR88RM5C1aCj46Yb14v06nbd0E4dRqldEHi8qBoNER9+iKakEAanNnPWz89yzersq5a41RVVTK+net06sJefqBCTl2BFV5cLq7f3Nx7nLqKIh07SYkoWi0tZrxGVkAICWcPseH2iVis1RPcNR87Q8p/PgEg5JExaAL9+W4n7D4PQXp4tnzzoMtFdAA83gXmdhzJ1iZdUU0FnLvvVWy5lV8VbTl3keQxz2LLysG3a2uiLqmr25xUOA/2lV6ivk4ikbgHit6HmG9fRxcfi+X4GXIfepmP+lrQKELbbt6BynuuvHX/cPG/kwGIeOPRMsWIVRVes9ecDW4I8cFXb4MuNpKoL15G8fOl49HNvDvlPr759fgVP571Yjrn7nqJlGf/B0Dww7cQ+sjYCj3GG2tEqUq4EV7sccWmeDzSsZNcloB60QROeQOzVkfLbStY8uz3Vf6cqqmA5HteRs3OxbdTS0IevJlz2YXK4s91r3xl8fGtoVG4whtDXiA3NBzzgeNcfPmzSn0OW24+ybe/gOX0OXzq1yFm2psohkK14dQ8eOQPsKowvAncWo7ZtxKJpHrRRoQS891bKH5G8tZso96kT3i0o9jw/me5iN5dLZbkFM7d+6ozqh80fkSZ95m1B9afFpH+5ypx4+vfvwu1F36ONTaauNTTDH7hfvbPrrjuZ87S9ZzqNU4ILOu0hL1wL+GvVKyOdO5+mL5DXH+vf8mzbyUC6dhJSqXBgNYkP/EEAI2/+4p/ZlatmG/KK5Mo2HkQTVgw0V++iqLT8foayC6ANtEwpgp0o3y0oh4lwz+UV4b+H6qikDl9Hlmzl1TK45tPJHFm2MOY/tmHJiyY2B/fQxtWuKW2qfDkn3A2G+qHwFuydV8icVsMzRsSPfklUBQyv53Lrd+/SecIE1kF8Pjiq5MGUQvMnLvrJawXUtEn1ify/WfK1NdMzoY314jrz3SFuiFX/vwlYWjZiAbLv+RUk9b4F+SiffgFUj/6vlyNZrbsXC48/T7JY57Fej4Vn8Z1qbN4CqFP3lEh3dADKfDCMnF9QkfoVw7RZW9GOnaSMrnmuevZ1f8GNKj4PPtfUnYcq5LnyV6wksyv5wAQ9dl/0NWKYtUJWHBQ1I282Vf8rAq61oHrG8PWeh1Yfu0YAM4/9AYX//sFqvnK28Fy/lzH6X53O53V2O8n4lO/uKry1G1CiNighc+HeM/YMImkpuI/uCcR7z0lZFBmL+Htbx+hbn4KW87CpCtUqlFVlZSXPiN/8240QQHEfPsmGr/Sw1KqKiKFWQVilu34Nlf23GWhjQglYe6HLGo/HI2qkvbmVE60vIFzD/6XzJmLMJ9KFvZYLJh2HCB98k8kj3uRE+1GkTl9HiC6euv89TWG1hXTqcoywQOLhCh9z3h4skulvzyPQ3bFSspFdo6F1f2eoMmR7aRG1abtqsnoIkIq7fHNx5M43fcubFk5hDwyhvCXHyTfAgN+gBMZcHcbeLmKmwmSs6Hvd5BnsjDjwOdE/ypar3w7tyL6y1fRxUaW+7FUi4XUt78m/eMfADC0TyT6q/8WkzUB2JIEo38RKdi3+3lvClZ2alYf8m9deeSu3sK5u1/Glp6FOTycx65/i0N1Epk+HHrWLf/jWJLOc/6xt8V0CSDm+4n4Dyq7iGz+AXhksZhO8/sYaBx+pa+kfLy7Dk5O+Y0Hl32Ob0HxhjpdfCzW1AzU7Nzix+tEE/nJC/j1rHhLraoKp27xEagVAIvGQJjxql5CjUV2xUoqnQB/HXWn/5dzwTGEnT/Dnr4PYjpwvFIeWzUVcO7eV0RTQccWzgHXn28WTl1MAE6196okJkCMzrFpdExo9yiBU15HE+hP/sadnOozntwVm8r1OOZjZ0ga9ZTTqQu+ZyS153/2L6fufA5MKFJXd0vp9dESicTN8OvVgTp/folPk3r4XLzIJz88Qp+dS7h3gcrmpLLvr6oqWb/8yame48hbuRnFV0/kB8+Uy6m7mFuoFPBIp6p36gAe6ggbeo5g+JMLWP/Wx4Q8cQeGji1Aq8Vy8ixqdi6aoAD8ru1K2MsPUPuPycRvmnVFTh2IbMbiI0Li6ouh3uvUVRQZsZNUiLkLjxH1+LPEZCRj9vMn7utX8e9/5bFxW24+F558l+xfl6IJDSJuxTfoakdzJA0GzRDt7V8MgSGNKvFFlILZCoNmwuFUuKEpvNPwNOfveZmC3YdAUQgYNQDfDs0xtGqMPrEhGqMBVVUp2HOEnD/WkLNoNQV7hEy74m8k6sPnCLih37+eJzkbxsyBI2mirm7Brd6dgpVRpOpD/q0rH1tWDucefJ3cJaIG+WhkfVa1HcTNLw6gZfOSPS7rxXQuPPOBkDQBDG2bETXpP+gblS/U98gfMP8gNA0X64deWxmvpGx+2QtPLRXXPxoo1klbVg752/ahDQ1C37wBivbqjVl2FO5dKDa+b/aB265QcNlTqIhvIx07SYWZvjyNoKdfotWpHagaDRGvPEjwgzdXqBgWwLT7MOfufxXzwROgKMT88Db+A7pxJhNumQMnM+CaujBtePU2E6w/LZwumyrG8rzS2UTqS586a0WcaLXoG9fFlpuH5cTZYseNPdoSMfHxEhfpM5lw6xwRjawdCD/eWPkFzzUN6WxUH/JvXTWoNhtp735D2mc/gqkAAKtGi6ZnR6JH90MtsGA+eRbLybOYT57FvP8Ytqwc0GkJffpOQh+7DUVXvnE6S47AfQtFzfG8m6FVdNn3qSxUVYzz+nq7SAFPGw494iv3Oebuh6f+FE7dyGbwwbWyoUw6dtKxq3K+2mgm9//+x9DtCwEIvGUwEe89Va4BzqqqkvnVr1x87QtUUwHa6HCiPv8//Hp14HSmmLxwKlNoMc2+SaRIq5tf94mFRUWMrHmhO+St2ULe6q2Ydh2iYNdBrBfSnOcrvnqMfTrhP7gn/gO7F+t6LcrJDOHUnc4UivU/jhQ/vR3pbFyelStX0qdPH9LS0ggJCSnXferVq8fjjz/O448//q/b5N+6arFmZHHxl+XsnLKYhGO7Sz3Xp0k9oif9X4UaCv44DI8uFtmM+9u7Rs/NpgobFhwUmYafb4Lm5S9BLpVvt8Or9hTzDU2FtIlPNUUj3ZmK+DZXOW1T4q3c09mHL99+lk//14CHln5K1qw/yFm8loARfQkcNRBDxxYlRvAs5y5y4an3nCkLvwHdiPr4ebQRoZyyO3WnM6FuMPw00jVOHYhdoskCLywXQ7WNOniiVwf8enUAhHNqPXcR006hSmrs3g6Nf+kFIMfT4dZfISkbEkJg5o1Qq3QxeUkN4M4772T69Oncf//9TJ48udhtDz/8MJ9//jnjxo1j2rRprjFQUq1ogwOJuns47ccM59EvTtJw9WK6nthCw/pBBNaPQRcfi098LXR1YzG0aFjuKB3AzF3wnxXCsRrUAJ5yUYeoRhFRtJRckeEY9xvMGX11wsiqCh9ugI/tpczj28DLvapOCcGTkY6d5Iq5t73C1Kdu4vmIujy98G2i08+TOW0emdPmoatXm8BRA9CGBVNw6CTmQ8cpOHgC67mLACgGPeGvPkTQ3TeiKIpw6n6B01lQLxhmjYRYFzs9Y1qCySp2jx9tFOKfDwq/DkVR0MVEoIuJKNdjbT0rurvO50CDUBGpi65koWWJ64iLi2PWrFl8+OGHGI3Cwc/Pz2fmzJnEx1dynkpSIwg1wgf3xzM6/D6+Sb8PPx8xNWdc64o7K6oKn2yC/20Qv49pAW/0Aa0L2x8NOph6HYyaDfsvCufu19FX1uBgU8VYtu93it+f6gqPdJTp1ytFdsVKror72sHAOzoydsLPPD3mQ1a0GYTV14jl+BnS3vuWlBc+IvObOeSt2eZ06vQtG1F78RSC7xmJoigcSYOb7U5dQgj8dJPrnToH49vA893F9bfXiR1lTkH5738yQ3S+3vizcOqahItIpHTqPIt27doRFxfHnDlznMfmzJlDfHw8bdu2dR4zmUw8+uijREVF4evrS48ePdi8ubjw2e+//07jxo0xGo306dOH48eP/+v51q5dS8+ePTEajcTFxfHoo4+Sk3OVgzwllU6Uv4jMd6oFuWaxSRw1WzRnlRebKrpfHU7dI52EiLkrnToHQQaYPkLUCh9Nh2E/ijIWazlFmlUVVh6HG34STp2CcFgf7SSduqtB1thJKoX1p+GVlXDgIvgW5HHL2dXccnolIQbQN6qLT+O66BvF49MwHm1wIFYbrDwBP+wU4rwqojt01kgxv9Xd+N/6whRBkAHGtoA721w+VZxpEkKl32wXtTAKMCpR1MOEypb9f3Fp3ZeqqqhVMK+3PCh+vhVqBLrzzjtJT0+nd+/eLFq0iL/++guA/v37c91117Fy5UpCQkKYNm0ajz32GL/88gtfffUVdevW5d1332X+/PkcPnyYsLAwTp06RaNGjXj44Ye577772LJlC0899RTnzp1z1tgdOXKE1q1b88YbbzB06FAuXLjAhAkTaN26Nd9++y0ga+zcDZsq1rq310GOWXSwPtYJ7m0nIl8lkZYnmiR+3Qeb7NIpr/auOhHiq+FQKtw2V3T7AzQME2niwQ1LdtBUFVadEJmQf4S2Mb46eP9aGNa4+uyuScjmCenYuQSLTey6/rceMu1RrbYxojmgdiDUtv/cewF+3C0idA56xcP7A9w3kqWqMHsvfL4FjqWLYzqNWISurS8cuYu5cDFP1J2sPSXmvwJ0j4P/6wmJlVRc7Ilc6mzYcvI4Vm+AS2xJOP5nmfWSRXE4dl9++SVxcXEcOCDqLps2bcqpU6e45557CAkJYdKkSYSGhjJt2jTGjBHTTcxms9MJe+aZZ3jxxReZN28ee/bscT7+888/zzvvvON07O655x60Wi1TpkxxnrN27Vp69+5NTk4Ovr6+0rFzU85kirrdVSfE7z4aaBQOzSOgeRQ0jYAT6fD7YVh7UnSFOs77YIDQu3RX8swwfSd8sQXS7XuyllGiAUJVwaKCxSq+J1aeKO7Q3d4K7m9X+XPAPQnZPCFxCTqN2E0OawzvrIOf94oPr+MDfCnBBhHFGtNS1J25M4oCo5vDTYmw7Bh8uQ02nhFt+XP3l3yfhmEiQte3nkwreAORkZEMHTqUadOmoaoqQ4cOJSKisAbzyJEjmM1munfv7jzm4+NDp06d2LdvHwD79u2jc+fOxR63a9fi6tw7duxg586dzJgxw3lMVVVsNhvHjh2jWbNmVfHyJJVA7SCYPlysGW+thQu5YqO79wLM3vfv8xMjYGgjuL7J1TUmVAdGH3igvaj/+2obfPUP7DovLiUhHbqqQzp2kkonwg/eu1Y0GuxLgTNZYqd6OlNcDzYIJ2loI/HhrkloFBGhu7Y+7DwnUq3H0kTBcISfuIQbIS4Y+iUIZ1dScRQ/XxKO/+my575S7rrrLiZMmADApEmTKsukYmRnZ3P//ffz6KOP/us22ajh/igK3NhMRLJOZxU6dnsvwN4UsT4OaSjWxwQ33/CWRJBBTAoa11po3R1LE+ugjxa0iog+RgcIB1A6dFVDDftaldQk6oeKi6fSKloor0sqH0VRUCqQDnUXBg0aREFBAYqiMHBg8X+OBg0aoNfrWbduHXXrCuFqs9nM5s2bnSnTZs2aMX/+/GL327BhQ7Hf27Vrx969e2nYsGHVvRBJlaMookwlLggGNnC1NZVPuJ/oApZUPzKeIJFIJJWEVqtl37597N27F+0lY5X8/f158MEHeeaZZ1i8eDF79+7l3nvvJTc3l7vvvhuABx54gEOHDvHMM89w4MABZs6c+S/9u+eee46///6bCRMmsH37dg4dOsS8efOckUKJROLdSMdOIpFIKpGgoKDLFje//fbbjBw5kttvv5127dpx+PBhlixZQmioCG3Hx8fz66+/8ttvv9G6dWsmT57MW2+9VewxWrVqxapVqzh48CA9e/akbdu2vPzyy9SqVavKX5tEInF/ZFesRCJxObJTs/qQf2uJpOZREd9GRuwkEolEIpFIPATp2EkkEolEIpF4CNKxk0gkEolEIvEQpGMnkUgkEolE4iFIx04ikUgkEonEQ5COnUQicRtsNpurTfB4PEwIQSKRXEKVTZ5ITU3lkUceYcGCBWg0GkaOHMnHH39MQEBAmfdVVZUhQ4awePFi5s6dy4gRI6rKTIlE4gbo9Xo0Gg1JSUlERkai1+tR5IDdSkdVVS5cuICiKPj4+LjaHIlEUgVUmWM3duxYzp49y9KlSzGbzYwfP5777ruPmTNnlnnfjz76SC7qEokXodFoSEhI4OzZsyQlJbnaHI9GURTq1Knzr8kYEonEM6gSx27fvn0sXryYzZs306FDBwA+/fRThgwZwvvvv1+qQvr27dv54IMP2LJlC7GxsVVhnkQicUP0ej3x8fFYLBasVqurzfFYfHx8pFMnkXgwVeLYrV+/npCQEKdTB9C/f380Gg0bN27khhtuKPF+ubm5jBkzhkmTJhETE1MVpkkkEjfGkSKUaUKJRCK5MqrEsUtOTiYqKqr4E+l0hIWFkZycfNn7PfHEE3Tr1o3hw4eX+7lMJhMmk8n5e2ZmZsUNlkgkEolEIvEAKtQV+/zzz6MoSqmX/fv3X5Eh8+fPZ/ny5Xz00UcVut/EiRMJDg52XuLi4q7o+SUSiUQikUhqOhWK2D311FPceeedpZ5Tv359YmJiOH/+fLHjFouF1NTUy6ZYly9fzpEjRwgJCSl2fOTIkfTs2ZOVK1eWeL8XXniBJ5980vl7ZmamdO4kEolEIpF4JRVy7CIjI4mMjCzzvK5du5Kens7WrVtp3749IBw3m81G586dS7zP888/zz333FPsWMuWLfnwww8ZNmzYZZ/LYDBgMBicvzs0mmRKViKRSCQSiSfg8GnKpUOpVhGDBg1S27Ztq27cuFFdu3at2qhRI/XWW2913n769Gm1SZMm6saNGy/7GIA6d+7cCj3vqVOnVEBe5EVe5EVe5EVe5MWjLqdOnSrTD6oyHbsZM2YwYcIE+vXr5xQo/uSTT5y3m81mDhw4QG5ubqU+b61atTh16hSBgYFVqoXnSPmeOnWKoKCgKnseScWR7437It8b90W+N+6LfG/cl+p6b1RVJSsrq1S5OAeKPTImqSCZmZkEBweTkZEhP2huhnxv3Bf53rgv8r1xX+R7476443sjZ8VKJBKJRCKReAjSsZNIJBKJRCLxEKRjd4UYDAZeeeWVYh25EvdAvjfui3xv3Bf53rgv8r1xX9zxvZE1dhKJRCKRSCQegozYSSQSiUQikXgI0rGTSCQSiUQi8RCkYyeRSCQSiUTiIUjHTiKRSCQSicRDkI7dFTBp0iTq1auHr68vnTt3ZtOmTa42SQJMnDiRjh07EhgYSFRUFCNGjODAgQOuNktyCW+//TaKovD444+72hSJnTNnznDbbbcRHh6O0WikZcuWbNmyxdVmeT1Wq5WXXnqJhIQEjEYjDRo04PXXXy/fvFBJpbJ69WqGDRtGrVq1UBSF3377rdjtqqry8ssvExsbi9FopH///hw6dMgltkrHroL89NNPPPnkk7zyyits27aN1q1bM3DgQM6fP+9q07yeVatW8fDDD7NhwwaWLl2K2WxmwIAB5OTkuNo0iZ3NmzczZcoUWrVq5WpTJHbS0tLo3r07Pj4+/PHHH+zdu5cPPviA0NBQV5vm9bzzzjt88cUXfPbZZ+zbt4933nmHd999l08//dTVpnkdOTk5tG7dmkmTJpV4+7vvvssnn3zC5MmT2bhxI/7+/gwcOJD8/PxqtlTKnVSYzp0707FjRz777DMAbDYbcXFxPPLIIzz//PMutk5SlAsXLhAVFcWqVavo1auXq83xerKzs2nXrh2ff/45b7zxBm3atOGjjz5ytVlez/PPP8+6detYs2aNq02RXMJ1111HdHQ0X3/9tfPYyJEjMRqN/PDDDy60zLtRFIW5c+cyYsQIQETratWqxVNPPcXTTz8NQEZGBtHR0UybNo1bbrmlWu2TEbsKUFBQwNatW+nfv7/zmEajoX///qxfv96FlklKIiMjA4CwsDAXWyIBePjhhxk6dGixz4/E9cyfP58OHTowatQooqKiaNu2LV9++aWrzZIA3bp1Y9myZRw8eBCAHTt2sHbtWgYPHuxiyyRFOXbsGMnJycXWtuDgYDp37uwS30BX7c9Yg0lJScFqtRIdHV3seHR0NPv373eRVZKSsNlsPP7443Tv3p0WLVq42hyvZ9asWWzbto3Nmze72hTJJRw9epQvvviCJ598khdffJHNmzfz6KOPotfrGTdunKvN82qef/55MjMzadq0KVqtFqvVyptvvsnYsWNdbZqkCMnJyQAl+gaO26oT6dhJPJKHH36Y3bt3s3btWleb4vWcOnWKxx57jKVLl+Lr6+tqcySXYLPZ6NChA2+99RYAbdu2Zffu3UyePFk6di7m559/ZsaMGcycOZPmzZuzfft2Hn/8cWrVqiXfG8llkanYChAREYFWq+XcuXPFjp87d46YmBgXWSW5lAkTJrBw4UJWrFhBnTp1XG2O17N161bOnz9Pu3bt0Ol06HQ6Vq1axSeffIJOp8NqtbraRK8mNjaWxMTEYseaNWvGyZMnXWSRxMEzzzzD888/zy233ELLli25/fbbeeKJJ5g4caKrTZMUwfH97y6+gXTsKoBer6d9+/YsW7bMecxms7Fs2TK6du3qQsskIApYJ0yYwNy5c1m+fDkJCQmuNkkC9OvXj127drF9+3bnpUOHDowdO5bt27ej1WpdbaJX071793/JAh08eJC6deu6yCKJg9zcXDSa4l/TWq0Wm83mIoskJZGQkEBMTEwx3yAzM5ONGze6xDeQqdgK8uSTTzJu3Dg6dOhAp06d+Oijj8jJyWH8+PGuNs3refjhh5k5cybz5s0jMDDQWdsQHByM0Wh0sXXeS2Bg4L/qHP39/QkPD5f1j27AE088Qbdu3XjrrbcYPXo0mzZtYurUqUydOtXVpnk9w4YN48033yQ+Pp7mzZvzzz//8L///Y+77rrL1aZ5HdnZ2Rw+fNj5+7Fjx9i+fTthYWHEx8fz+OOP88Ybb9CoUSMSEhJ46aWXqFWrlrNztlpRJRXm008/VePj41W9Xq926tRJ3bBhg6tNkqiqCpR4+fbbb11tmuQSevfurT722GOuNkNiZ8GCBWqLFi1Ug8GgNm3aVJ06daqrTZKoqpqZmak+9thjanx8vOrr66vWr19f/c9//qOaTCZXm+Z1rFixosTvl3Hjxqmqqqo2m0196aWX1OjoaNVgMKj9+vVTDxw44BJbpY6dRCKRSCQSiYcga+wkEolEIpFIPATp2EkkEolEIpF4CNKxk0gkEolEIvEQpGMnkUgkEolE4iFIx04ikUgkEonEQ5COnUQikUgkEomHIB07iUQikUgkEg9BOnYSiUQikUgkHoJ07CQSiUQikUg8BOnYSSQSiUQikXgI0rGTSCQSiUQi8RCkYyeRSCQSiUTiIfw/slTJ6mtv5AcAAAAASUVORK5CYII=\n", "text/plain": [ - "
" + "
" ] }, - "metadata": { - "needs_background": "light" - }, + "metadata": {}, "output_type": "display_data" }, { @@ -320,6 +283,39 @@ "name": "stdout", "output_type": "stream", "text": [ + "Started!\n", + "\n", + "Cycles per second: 5.190e+03\n", + "Head worker occupation: 3.3%\n", + "Progress: 434 / 800 total iterations (54.250%)\n", + "==============================\n", + "Best equations for output 1\n", + "Hall of Fame:\n", + "-----------------------------------------\n", + "Complexity Loss Score Equation\n", + "1 4.883e-02 1.206e+00 x1\n", + "3 2.746e-02 2.877e-01 (x1 + -0.14616892)\n", + "5 6.162e-04 1.899e+00 (x1 / (x1 - -1.0118991))\n", + "7 4.476e-04 1.598e-01 ((x1 / 0.92953163) / (x1 + 1.0533974))\n", + "9 3.997e-04 5.664e-02 (((x1 * 1.0935224) + -0.008988203) / (x1 + 1.0716586))\n", + "13 3.364e-04 4.306e-02 (x1 * ((((x0 * -0.94923264) / 11.808947) - -1.087501) / (x1 + 1.0548282)))\n", + "15 3.062e-04 4.714e-02 (x1 * ((((x0 * (-1.1005011 - x1)) / 13.075972) - -1.0955853) / (x1 + 1.0604433)))\n", + "\n", + "==============================\n", + "Best equations for output 2\n", + "Hall of Fame:\n", + "-----------------------------------------\n", + "Complexity Loss Score Equation\n", + "1 1.588e-01 -1.000e-10 -0.002322703\n", + "3 2.034e-02 1.028e+00 (0.14746223 - x0)\n", + "5 1.413e-03 1.333e+00 (x0 / (-1.046938 - x0))\n", + "7 6.958e-04 3.543e-01 (x0 / ((x0 + 1.1405994) / -1.1647526))\n", + "9 2.163e-04 5.841e-01 (((x0 + -0.026584703) / (x0 + 1.2191753)) * -1.2456053)\n", + "11 2.163e-04 7.749e-06 ((x0 - 0.026545616) / (((x0 / -1.2450436) + -0.9980602) - -0.019172505))\n", + "\n", + "==============================\n", + "Press 'q' and then to stop execution early.\n", + "\n", "Optimising symbolic expression.\n", "Expressions found: [x1/(x1 + 1.0), x0/(-x0 - 1.0)]\n" ] @@ -332,9 +328,9 @@ ], "metadata": { "kernelspec": { - "display_name": "jax0227", + "display_name": "py38", "language": "python", - "name": "jax0227" + "name": "py38" }, "language_info": { "codemirror_mode": { @@ -346,7 +342,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.7" + "version": "3.8.16" } }, "nbformat": 4, diff --git a/mkdocs.yml b/mkdocs.yml index 16b89b2d..efc43a7e 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -90,7 +90,7 @@ plugins: nav: - 'index.md' - - 'further_details/citation.md' + - 'citation.md' - Usage: - 'usage/getting-started.md' - 'usage/how-to-choose-a-solver.md' @@ -120,7 +120,7 @@ nav: - 'api/saveat.md' - 'api/stepsize_controller.md' - 'api/solution.md' - - 'api/citation.md' + - 'api/autocitation.md' - Advanced API: - 'api/adjoints.md' - 'api/events.md' From 813fe0f9b623af99bf280f6a46cc2f2d2161df11 Mon Sep 17 00:00:00 2001 From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com> Date: Mon, 20 Feb 2023 16:46:42 -0800 Subject: [PATCH 19/19] Make nbqa happy --- examples/kalman_filter.ipynb | 2 -- 1 file changed, 2 deletions(-) diff --git a/examples/kalman_filter.ipynb b/examples/kalman_filter.ipynb index 5a1aea20..430dfe38 100644 --- a/examples/kalman_filter.ipynb +++ b/examples/kalman_filter.ipynb @@ -237,7 +237,6 @@ " R: jnp.ndarray\n", "\n", " def __call__(self, ts, ys, us: Optional[jnp.ndarray] = None):\n", - "\n", " A, B, C = self.sys.A, self.sys.B, self.sys.C\n", "\n", " y_t = dfx.LinearInterpolation(ts=ts, ys=ys)\n", @@ -303,7 +302,6 @@ " n_gradient_steps=0,\n", " print_every=10,\n", "):\n", - "\n", " xs, ys = simulate_lti_system(\n", " sys_true, sys_true_x0, ts, std_measurement_noise=sys_true_std_measurement_noise\n", " )\n",