diff --git a/docsrc/index.rst b/docsrc/index.rst index 757acc2011..82600dce98 100644 --- a/docsrc/index.rst +++ b/docsrc/index.rst @@ -37,6 +37,7 @@ User Guide * :ref:`saving_models` * :ref:`runtime` * :ref:`using_dla` +* :ref:`mixed_precision` .. toctree:: :caption: User Guide @@ -48,6 +49,7 @@ User Guide user_guide/saving_models user_guide/runtime user_guide/using_dla + user_guide/mixed_precision tutorials/_rendered_examples/dynamo/torch_compile_advanced_usage tutorials/_rendered_examples/dynamo/vgg16_ptq tutorials/_rendered_examples/dynamo/engine_caching_example @@ -118,6 +120,8 @@ Tutorials tutorials/_rendered_examples/distributed_inference/data_parallel_gpt2 tutorials/_rendered_examples/distributed_inference/data_parallel_stable_diffusion tutorials/_rendered_examples/dynamo/mutable_torchtrt_module_example + tutorials/_rendered_examples/dynamo/torch_export_gpt2 + tutorials/_rendered_examples/dynamo/torch_export_llama2 Python API Documentation ------------------------ diff --git a/docsrc/user_guide/mixed_precision.rst b/docsrc/user_guide/mixed_precision.rst new file mode 100644 index 0000000000..dca0b033e6 --- /dev/null +++ b/docsrc/user_guide/mixed_precision.rst @@ -0,0 +1,74 @@ +.. _mixed_precision: + +Compile Mixed Precision models with Torch-TensorRT +==================================== +.. currentmodule:: torch_tensorrt.dynamo + +.. automodule:: torch_tensorrt.dynamo + :members: + :undoc-members: + :show-inheritance: + +Consider the following Pytorch model which explicitly casts intermediate layer to run in FP16. + +.. code-block:: python + + class MyModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear1 = torch.nn.Linear(10,10) + self.linear2 = torch.nn.Linear(10,30).half() + self.linear3 = torch.nn.Linear(30,40) + + def forward(self, x): + x = self.linear1(x) + x = x.to(torch.float16) + x = self.linear2(x) + x = x.to(torch.float32) + x = self.linear3(x) + return x + + +If we compile the above model using Torch-TensorRT, layer profiling logs indicate that all the layers are +run in FP32. This is because TensorRT picks the kernels for layers which result in the best performance. + +.. code-block:: python + + inputs = [torch.randn((1, 10), dtype=torch.float32).cuda()] + mod = MyModule().eval().cuda() + ep = torch.export.export(mod, tuple(inputs)) + with torch_tensorrt.logging.debug(): + trt_gm = torch_tensorrt.dynamo.compile(ep, + inputs=inputs, + debug=True) + + # Debug log info + # Layers: + # Name: __myl_MulSum_myl0_0, LayerType: kgen, Inputs: [ { Name: __mye116_dconst, Dimensions: [10,10], Format/Datatype: Float }, { Name: x, Dimensions: [10,1], Format/Datatype: Float }], Outputs: [ { Name: __myln_k_arg__bb1_2, Dimensions: [1,10], Format/Datatype: Float }], TacticName: __myl_MulSum_0xfa6c1858aea1b13b03f90165d7149ec6, StreamId: 0, Metadata: + # Name: __myl_AddResMulSum_myl0_1, LayerType: kgen, Inputs: [ { Name: __mye131_dconst, Dimensions: [10,30], Format/Datatype: Float }, { Name: __myln_k_arg__bb1_2, Dimensions: [1,10], Format/Datatype: Float }, { Name: linear1/addmm_constant_0 _ linear1/addmm_add_broadcast_to_same_shape_lhs_broadcast_constantFloat, Dimensions: [1,10], Format/Datatype: Float }], Outputs: [ { Name: __myln_k_arg__bb1_3, Dimensions: [1,30], Format/Datatype: Float }], TacticName: __myl_AddResMulSum_0xb3915d7ebfe48be45b6d49083479e12f, StreamId: 0, Metadata: + # Name: __myl_AddResMulSumAdd_myl0_2, LayerType: kgen, Inputs: [ { Name: __mye146_dconst, Dimensions: [30,40], Format/Datatype: Float }, { Name: linear3/addmm_2_constant_0 _ linear3/addmm_2_add_broadcast_to_same_shape_lhs_broadcast_constantFloat, Dimensions: [1,40], Format/Datatype: Float }, { Name: __myln_k_arg__bb1_3, Dimensions: [1,30], Format/Datatype: Float }, { Name: linear2/addmm_1_constant_0 _ linear2/addmm_1_add_broadcast_to_same_shape_lhs_broadcast_constantFloat, Dimensions: [1,30], Format/Datatype: Float }], Outputs: [ { Name: output0, Dimensions: [1,40], Format/Datatype: Float }], TacticName: __myl_AddResMulSumAdd_0xcdd0085ad25f5f45ac5fafb72acbffd6, StreamId: 0, Metadata: + + +In order to respect the types specified by the user in the model (eg: in this case, ``linear2`` layer to run in FP16), users can enable +the compilation setting ``use_explicit_typing=True``. Compiling with this option results in the following TensorRT logs + +.. note:: If you enable ``use_explicit_typing=True``, only torch.float32 is supported in the enabled_precisions. + +.. code-block:: python + + inputs = [torch.randn((1, 10), dtype=torch.float32).cuda()] + mod = MyModule().eval().cuda() + ep = torch.export.export(mod, tuple(inputs)) + with torch_tensorrt.logging.debug(): + trt_gm = torch_tensorrt.dynamo.compile(ep, + inputs=inputs, + use_explicit_typing=True + debug=True) + + # Debug log info + # Layers: + # Name: __myl_MulSumAddCas_myl0_0, LayerType: kgen, Inputs: [ { Name: linear1/addmm_constant_0 _ linear1/addmm_add_broadcast_to_same_shape_lhs_broadcast_constantFloat, Dimensions: [1,10], Format/Datatype: Float }, { Name: __mye112_dconst, Dimensions: [10,10], Format/Datatype: Float }, { Name: x, Dimensions: [10,1], Format/Datatype: Float }], Outputs: [ { Name: __myln_k_arg__bb1_2, Dimensions: [1,10], Format/Datatype: Half }], TacticName: __myl_MulSumAddCas_0xacf8f5dd9be2f3e7bb09cdddeac6c936, StreamId: 0, Metadata: + # Name: __myl_ResMulSumAddCas_myl0_1, LayerType: kgen, Inputs: [ { Name: __mye127_dconst, Dimensions: [10,30], Format/Datatype: Half }, { Name: linear2/addmm_1_constant_0 _ linear2/addmm_1_add_broadcast_to_same_shape_lhs_broadcast_constantHalf, Dimensions: [1,30], Format/Datatype: Half }, { Name: __myln_k_arg__bb1_2, Dimensions: [1,10], Format/Datatype: Half }], Outputs: [ { Name: __myln_k_arg__bb1_3, Dimensions: [1,30], Format/Datatype: Float }], TacticName: __myl_ResMulSumAddCas_0x5a3b318b5a1c97b7d5110c0291481337, StreamId: 0, Metadata: + # Name: __myl_ResMulSumAdd_myl0_2, LayerType: kgen, Inputs: [ { Name: __mye142_dconst, Dimensions: [30,40], Format/Datatype: Float }, { Name: linear3/addmm_2_constant_0 _ linear3/addmm_2_add_broadcast_to_same_shape_lhs_broadcast_constantFloat, Dimensions: [1,40], Format/Datatype: Float }, { Name: __myln_k_arg__bb1_3, Dimensions: [1,30], Format/Datatype: Float }], Outputs: [ { Name: output0, Dimensions: [1,40], Format/Datatype: Float }], TacticName: __myl_ResMulSumAdd_0x3fad91127c640fd6db771aa9cde67db0, StreamId: 0, Metadata: + +Now the ``linear2`` layer runs in FP16 as shown in the above logs. \ No newline at end of file diff --git a/examples/dynamo/README.rst b/examples/dynamo/README.rst index ff3563cffe..83655628bc 100644 --- a/examples/dynamo/README.rst +++ b/examples/dynamo/README.rst @@ -1,15 +1,24 @@ .. _torch_compile: -Dynamo / ``torch.compile`` ----------------------------- +Torch-TensorRT Examples +==================================== -Torch-TensorRT provides a backend for the new ``torch.compile`` API released in PyTorch 2.0. In the following examples we describe -a number of ways you can leverage this backend to accelerate inference. +Please refer to the following examples which demonstrate the usage of different features of Torch-TensorRT. We also provide +examples of Torch-TensorRT compilation of select computer vision and language models. -* :ref:`torch_compile_resnet`: Compiling a ResNet model using the Torch Compile Frontend for ``torch_tensorrt.compile`` -* :ref:`torch_compile_transformer`: Compiling a Transformer model using ``torch.compile`` +Dependencies +------------------------------------ + +Please install the following external dependencies (assuming you already have correct `torch`, `torch_tensorrt` and `tensorrt` libraries installed (`dependencies `_)) + +.. code-block:: python + + pip install -r requirements.txt + + +Compiler Features +------------------------------------ * :ref:`torch_compile_advanced_usage`: Advanced usage including making a custom backend to use directly with the ``torch.compile`` API -* :ref:`torch_compile_stable_diffusion`: Compiling a Stable Diffusion model using ``torch.compile`` * :ref:`torch_export_cudagraphs`: Using the Cudagraphs integration with `ir="dynamo"` * :ref:`custom_kernel_plugins`: Creating a plugin to use a custom kernel inside TensorRT engines * :ref:`refit_engine_example`: Refitting a compiled TensorRT Graph Module with updated weights @@ -17,3 +26,11 @@ a number of ways you can leverage this backend to accelerate inference. * :ref:`vgg16_fp8_ptq`: Compiling a VGG16 model with FP8 and PTQ using ``torch.compile`` * :ref:`engine_caching_example`: Utilizing engine caching to speed up compilation times * :ref:`engine_caching_bert_example`: Demonstrating engine caching on BERT + +Model Zoo +------------------------------------ +* :ref:`torch_compile_resnet`: Compiling a ResNet model using the Torch Compile Frontend for ``torch_tensorrt.compile`` +* :ref:`torch_compile_transformer`: Compiling a Transformer model using ``torch.compile`` +* :ref:`torch_compile_stable_diffusion`: Compiling a Stable Diffusion model using ``torch.compile`` +* :ref:`_torch_export_gpt2`: Compiling a GPT2 model using AOT workflow (`ir=dynamo`) +* :ref:`_torch_export_llama2`: Compiling a Llama2 model using AOT workflow (`ir=dynamo`) \ No newline at end of file diff --git a/examples/dynamo/requirements.txt b/examples/dynamo/requirements.txt index 6e53935186..59a802918c 100644 --- a/examples/dynamo/requirements.txt +++ b/examples/dynamo/requirements.txt @@ -1,4 +1,4 @@ cupy==13.1.0 -torch>=2.4.0.dev20240503+cu121 -torch-tensorrt>=2.4.0.dev20240503+cu121 triton==2.3.0 +diffusers==0.30.3 +transformers==4.44.2 \ No newline at end of file diff --git a/examples/dynamo/torch_export_gpt2.py b/examples/dynamo/torch_export_gpt2.py index a26305e4a3..f9229e420c 100644 --- a/examples/dynamo/torch_export_gpt2.py +++ b/examples/dynamo/torch_export_gpt2.py @@ -25,12 +25,16 @@ # CPU is used here so that GPU memory is reserved for TRT compilation. with torch.no_grad(): tokenizer = AutoTokenizer.from_pretrained("gpt2") - model = AutoModelForCausalLM.from_pretrained( - "gpt2", - pad_token_id=tokenizer.eos_token_id, - use_cache=False, - attn_implementation="eager", - ).eval() + model = ( + AutoModelForCausalLM.from_pretrained( + "gpt2", + pad_token_id=tokenizer.eos_token_id, + use_cache=False, + attn_implementation="eager", + ) + .eval() + .half() + ) # %% # Tokenize a sample input prompt and get pytorch model outputs @@ -48,6 +52,10 @@ # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ # Export the GPT2 model into an ExportedProgram which is input of TRT compilation +# To compile the model in FP16, we do the following +# 1) Cast the model to FP16 via model.half() +# 2) Enable use_explicit_typing=True. Certain layers are explicitly casted to FP32 within the pytorch model and this flag respects this behavior during TRT compilation +# 3) Enable use_fp32_acc=True. This ensures all the matmuls are accumulated in FP32 precision (similar to PyTorch) gpt2_ep = export_llm(model, input_ids, max_seq_len=1024) trt_model = torch_tensorrt.dynamo.compile( gpt2_ep, @@ -56,6 +64,8 @@ truncate_double=True, device=DEVICE, disable_tf32=True, + use_explicit_typing=True, + use_fp32_acc=True, ) # Auto-regressive generation loop for greedy decoding using TensorRT model @@ -81,6 +91,10 @@ # %% # The output sentences should look like # ============================= -# Pytorch model generated text: I enjoy walking with my cute dog, but I'm not sure if I'll ever be able to walk with my dog. I'm not sure if I'll ever be able to walk with my +# Pytorch model generated text: What is parallel programming ? + +# The parallel programming paradigm is a set of programming languages that are designed to be used in parallel. The main difference between parallel programming and parallel programming is that # ============================= -# TensorRT model generated text: I enjoy walking with my cute dog, but I'm not sure if I'll ever be able to walk with my dog. I'm not sure if I'll ever be able to walk with my +# TensorRT model generated text: What is parallel programming ? + +# The parallel programming paradigm is a set of programming languages that are designed to be used in parallel. The main difference between parallel programming and parallel programming is that diff --git a/examples/dynamo/torch_export_llama2.py b/examples/dynamo/torch_export_llama2.py index 195944688b..11a0c93276 100644 --- a/examples/dynamo/torch_export_llama2.py +++ b/examples/dynamo/torch_export_llama2.py @@ -24,9 +24,13 @@ # CPU is used here so that GPU memory is reserved for TRT compilation. llama_path = "meta-llama/Llama-2-7b-chat-hf" with torch.no_grad(): - model = AutoModelForCausalLM.from_pretrained( - llama_path, use_cache=False, attn_implementation="eager" - ).eval() + model = ( + AutoModelForCausalLM.from_pretrained( + llama_path, use_cache=False, attn_implementation="eager" + ) + .eval() + .half() + ) tokenizer = AutoTokenizer.from_pretrained(llama_path) @@ -45,15 +49,20 @@ # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ # Export the llama2 model into an ExportedProgram which is input of TRT compilation +# To compile the model in FP16, we do the following +# 1) Cast the model to FP16 via model.half() +# 2) Enable use_explicit_typing=True. Certain layers are explicitly casted to FP32 within the pytorch model and this flag respects this behavior during TRT compilation +# 3) Enable use_fp32_acc=True. This ensures all the matmuls are accumulated in FP32 precision (similar to PyTorch) llama2_ep = export_llm(model, input_ids, max_seq_len=64) trt_model = torch_tensorrt.dynamo.compile( llama2_ep, inputs=[input_ids], enabled_precisions={torch.float32}, - min_block_size=1, truncate_double=True, device=DEVICE, disable_tf32=True, + use_explicit_typing=True, + use_fp32_acc=True, ) # Auto-regressive generation loop for greedy decoding using TensorRT model @@ -85,6 +94,6 @@ # %% # The output sentences should look like # ============================= -# Pytorch model generated text: I enjoy walking with my cute dog, but I'm not sure if I'll ever be able to walk with my dog. I'm not sure if I'll ever be able to walk with my +# Pytorch model generated text: Dynamic programming is an algorithmic technique used to solve complex problems by breaking them down into smaller subproblems, solving each subproblem only once, and # ============================= -# TensorRT model generated text: I enjoy walking with my cute dog, but I'm not sure if I'll ever be able to walk with my dog. I'm not sure if I'll ever be able to walk with my +# TensorRT model generated text: Dynamic programming is an algorithmic technique used to solve complex problems by breaking them down into smaller subproblems, solving each subproblem only once, and diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index 97aa2ec443..fc7d1a0bc8 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -88,6 +88,8 @@ def compile( engine_cache_dir: str = _defaults.ENGINE_CACHE_DIR, engine_cache_size: int = _defaults.ENGINE_CACHE_SIZE, custom_engine_cache: Optional[BaseEngineCache] = _defaults.CUSTOM_ENGINE_CACHE, + use_explicit_typing: bool = _defaults.USE_EXPLICIT_TYPING, + use_fp32_acc: bool = _defaults.USE_FP32_ACC, **kwargs: Any, ) -> torch.fx.GraphModule: """Compile an ExportedProgram module for NVIDIA GPUs using TensorRT @@ -158,6 +160,8 @@ def compile( engine_cache_dir (Optional[str]): Directory to store the cached TRT engines engine_cache_size (Optional[int]): Maximum hard-disk space (bytes) to use for the engine cache, default is 1GB. If the cache exceeds this size, the oldest engines will be removed by default custom_engine_cache (Optional[BaseEngineCache]): Engine cache instance to use for saving and loading engines. Users can provide their own engine cache by inheriting from BaseEngineCache. If used, engine_cache_dir and engine_cache_size will be ignored. + use_explicit_typing (bool): This flag enables strong typing in TensorRT compilation which respects the precisions set in the Pytorch model. This is useful when users have mixed precision graphs. + use_fp32_acc (bool): This option inserts cast to FP32 nodes around matmul layers and TensorRT ensures the accumulation of matmul happens in FP32. Use this only when FP16 precision is configured in enabled_precisions. **kwargs: Any, Returns: torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT @@ -197,6 +201,20 @@ def compile( "\nThis feature is unimplemented in Torch-TRT Dynamo currently." ) + if use_explicit_typing: + if len(enabled_precisions) != 1 or not any( + x in enabled_precisions for x in {torch.float32, dtype.f32} + ): + raise AssertionError( + f"When use_explicit_typing is enabled, only torch.float32 is allowed in the enabled_precisions but found {enabled_precisions}" + ) + + if use_fp32_acc: + logger.debug( + "FP32 accumulation for matmul layers is enabled. This option should only be enabled if the model already has FP16 weights and has no effect if it has FP32 weights. \ + This flag inserts casts around matmul layers and ensures TensorRT executes the matmul layers in FP16 with FP32 accumulation." + ) + # Aliasing inputs to arg_inputs for better understanding if not arg_inputs and not inputs: raise AssertionError("'arg_inputs' and 'inputs' should not both be None.") @@ -232,7 +250,7 @@ def compile( logger.debug("Input graph: " + str(gm.graph)) # Apply lowering on the graph module - gm = post_lowering(gm) + gm = post_lowering(gm, use_fp32_acc=use_fp32_acc) logger.debug("Lowered Input graph: " + str(gm.graph)) engine_cache = None @@ -281,6 +299,8 @@ def compile( "lazy_engine_init": lazy_engine_init, "cache_built_engines": cache_built_engines, "reuse_cached_engines": reuse_cached_engines, + "use_explicit_typing": use_explicit_typing, + "use_fp32_acc": use_fp32_acc, } settings = CompilationSettings(**compilation_options) @@ -520,6 +540,8 @@ def convert_exported_program_to_serialized_trt_engine( calibrator: object = None, allow_shape_tensors: bool = False, timing_cache_path: str = _defaults.TIMING_CACHE_PATH, + use_explicit_typing: bool = _defaults.USE_EXPLICIT_TYPING, + use_fp32_acc: bool = _defaults.USE_FP32_ACC, **kwargs: Any, ) -> bytes: """Convert an ExportedProgram to a serialized TensorRT engine @@ -578,6 +600,8 @@ def convert_exported_program_to_serialized_trt_engine( calibrator (Union(torch_tensorrt._C.IInt8Calibrator, tensorrt.IInt8Calibrator)): Calibrator object which will provide data to the PTQ system for INT8 Calibration allow_shape_tensors: (Experimental) Allow aten::size to output shape tensors using IShapeLayer in TensorRT timing_cache_path (str): Path to the timing cache if it exists (or) where it will be saved after compilation + use_explicit_typing (bool): This flag enables strong typing in TensorRT compilation which respects the precisions set in the Pytorch model. This is useful when users have mixed precision graphs. + use_fp32_acc (bool): This option inserts cast to FP32 nodes around matmul layers and TensorRT ensures the accumulation of matmul happens in FP32. Use this only when FP16 precision is configured in enabled_precisions. Returns: bytes: Serialized TensorRT engine, can either be saved to a file or deserialized via TensorRT APIs """ @@ -651,6 +675,8 @@ def convert_exported_program_to_serialized_trt_engine( "dla_local_dram_size": dla_local_dram_size, "dla_global_dram_size": dla_global_dram_size, "timing_cache_path": timing_cache_path, + "use_explicit_typing": use_explicit_typing, + "use_fp32_acc": use_fp32_acc, } exported_program = pre_export_lowering(exported_program) diff --git a/py/torch_tensorrt/dynamo/_defaults.py b/py/torch_tensorrt/dynamo/_defaults.py index 68e446dab5..de99df71e0 100644 --- a/py/torch_tensorrt/dynamo/_defaults.py +++ b/py/torch_tensorrt/dynamo/_defaults.py @@ -40,6 +40,8 @@ ENGINE_CACHE_DIR = os.path.join(tempfile.gettempdir(), "torch_tensorrt_engine_cache") ENGINE_CACHE_SIZE = 1073741824 CUSTOM_ENGINE_CACHE = None +USE_EXPLICIT_TYPING = False +USE_FP32_ACC = False def default_device() -> Device: diff --git a/py/torch_tensorrt/dynamo/_settings.py b/py/torch_tensorrt/dynamo/_settings.py index f8886fbd67..98865c683e 100644 --- a/py/torch_tensorrt/dynamo/_settings.py +++ b/py/torch_tensorrt/dynamo/_settings.py @@ -29,7 +29,9 @@ SPARSE_WEIGHTS, TIMING_CACHE_PATH, TRUNCATE_DOUBLE, + USE_EXPLICIT_TYPING, USE_FAST_PARTITIONER, + USE_FP32_ACC, USE_PYTHON_RUNTIME, VERSION_COMPATIBLE, WORKSPACE_SIZE, @@ -78,6 +80,8 @@ class CompilationSettings: timing_cache_path (str): Path to the timing cache if it exists (or) where it will be saved after compilation cache_built_engines (bool): Whether to save the compiled TRT engines to storage reuse_cached_engines (bool): Whether to load the compiled TRT engines from storage + use_strong_typing (bool): This flag enables strong typing in TensorRT compilation which respects the precisions set in the Pytorch model. This is useful when users have mixed precision graphs. + use_fp32_acc (bool): This option inserts cast to FP32 nodes around matmul layers and TensorRT ensures the accumulation of matmul happens in FP32. Use this only when FP16 precision is configured in enabled_precisions. """ enabled_precisions: Set[dtype] = field(default_factory=lambda: ENABLED_PRECISIONS) @@ -112,6 +116,8 @@ class CompilationSettings: lazy_engine_init: bool = LAZY_ENGINE_INIT cache_built_engines: bool = CACHE_BUILT_ENGINES reuse_cached_engines: bool = REUSE_CACHED_ENGINES + use_explicit_typing: bool = USE_EXPLICIT_TYPING + use_fp32_acc: bool = USE_FP32_ACC _SETTINGS_TO_BE_ENGINE_INVARIANT = ( diff --git a/py/torch_tensorrt/dynamo/backend/backends.py b/py/torch_tensorrt/dynamo/backend/backends.py index 605d963a50..aa8766fdae 100644 --- a/py/torch_tensorrt/dynamo/backend/backends.py +++ b/py/torch_tensorrt/dynamo/backend/backends.py @@ -100,7 +100,7 @@ def _pretraced_backend( logger.debug("Post-AOT Autograd graph:\n" + str(gm.graph)) - gm = post_lowering(gm) + gm = post_lowering(gm, use_fp32_acc=settings.use_fp32_acc) logger.debug("Lowered Input graph:\n " + str(gm.graph)) diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py index aab4d521f8..19d80e70b1 100644 --- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py +++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py @@ -80,10 +80,11 @@ def __init__( self.builder = trt.Builder(self.logger) flag = 0 - - # It is deprecated to not use this flag - EXPLICIT_BATCH = 1 << (int)(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH) - flag |= EXPLICIT_BATCH + if compilation_settings.use_explicit_typing: + STRONGLY_TYPED = 1 << (int)( + trt.NetworkDefinitionCreationFlag.STRONGLY_TYPED + ) + flag |= STRONGLY_TYPED self.ctx = ConversionContext( self.builder.create_network(flag), compilation_settings diff --git a/py/torch_tensorrt/dynamo/conversion/_conversion.py b/py/torch_tensorrt/dynamo/conversion/_conversion.py index f0b65b3a6e..06fade9674 100644 --- a/py/torch_tensorrt/dynamo/conversion/_conversion.py +++ b/py/torch_tensorrt/dynamo/conversion/_conversion.py @@ -3,6 +3,7 @@ import logging from typing import Any, List, Optional, Sequence +import tensorrt as trt import torch from torch.fx.experimental.proxy_tensor import unset_fake_temporarily from torch_tensorrt._Device import Device @@ -18,8 +19,6 @@ from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule, TorchTensorRTModule from torch_tensorrt.dynamo.utils import get_model_device, get_torch_inputs -import tensorrt as trt - logger = logging.getLogger(__name__) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/activation/base.py b/py/torch_tensorrt/dynamo/conversion/impl/activation/base.py index f726a1c500..db257b9c4e 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/activation/base.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/activation/base.py @@ -37,7 +37,11 @@ def convert_activation( layer.beta = beta set_layer_name(layer, target, name, source_ir) - if input_val.dynamic_range is not None and dyn_range_fn is not None: + if ( + not ctx.net.get_flag(trt.NetworkDefinitionCreationFlag.STRONGLY_TYPED) + and input_val.dynamic_range is not None + and dyn_range_fn is not None + ): dyn_range = dyn_range_fn(input_val.dynamic_range) mark_as_int8_layer(layer, dyn_range) return layer.get_output(0) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py b/py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py index e9e80593e9..ca605c3189 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py @@ -13,13 +13,11 @@ broadcast_to_same_shape, cast_trt_tensor, get_trt_tensor, -) -from torch_tensorrt.fx.converters.converter_utils import ( broadcast, has_dynamic_shape, set_layer_name, ) -from torch_tensorrt.fx.types import TRTElementWiseOp, TRTTensor +from torch_tensorrt.dynamo.types import TRTElementWiseOp, TRTTensor def get_python_op_from_trt_elementwise_op( @@ -152,7 +150,7 @@ def convert_binary_elementwise( if has_dynamic_shape(lhs_val.shape) or has_dynamic_shape(rhs_val.shape): lhs_val, rhs_val = broadcast( - ctx.net, lhs_val, rhs_val, f"{name}_broadcast_lhs", f"{name}_broadcast_rhs" + ctx, lhs_val, rhs_val, f"{name}_broadcast_lhs", f"{name}_broadcast_rhs" ) else: lhs_val, rhs_val = broadcast_to_same_shape( diff --git a/py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py index 3f8d9667b3..348c71fd87 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py @@ -18,14 +18,11 @@ from torch_tensorrt.dynamo.conversion.impl.elementwise.base import ( convert_binary_elementwise, ) -from torch_tensorrt.dynamo.conversion.impl.shape import get_shape_with_dynamic_shape from torch_tensorrt.dynamo.conversion.impl.unary import atan, sign from torch_tensorrt.dynamo.conversion.impl.unary.base import convert_unary from torch_tensorrt.fx.converters.converter_utils import broadcast from torch_tensorrt.fx.types import TRTTensor -import tensorrt as trt - def trunc_div( ctx: ConversionContext, @@ -69,11 +66,6 @@ def trunc_div( prod_output, ) - # cast the sign_output back to int32 for trunc div - # This is required for scatter_reduce_.two(reduce='mean' where trunc_div casts it to float32 and TRTInterpreter expects int32) - if (isinstance(sign_output, TRTTensor)) and (sign_output.dtype == trt.float32): - sign_output = cast_trt_tensor(ctx, sign_output, trt.int32, name) - # Convert constant input into ITensor for UnaryOperation if not isinstance(input, trt.tensorrt.ITensor): input = get_trt_tensor(ctx, input, f"{name}_input") diff --git a/py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py index a4f0c2bc6c..c900c51b8f 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py @@ -478,10 +478,6 @@ def sign( name: str, input_val: TRTTensor, ) -> TRTTensor: - if (isinstance(input_val, TRTTensor)) and ( - input_val.dtype == trt.int8 or input_val.dtype == trt.int32 - ): - input_val = cast_trt_tensor(ctx, input_val, trt.float32, name) return convert_unary( ctx, target, source_ir, name, trt.UnaryOperation.SIGN, input_val diff --git a/py/torch_tensorrt/dynamo/lowering/_decompositions.py b/py/torch_tensorrt/dynamo/lowering/_decompositions.py index 534bc3eac5..4ffbcfdedb 100644 --- a/py/torch_tensorrt/dynamo/lowering/_decompositions.py +++ b/py/torch_tensorrt/dynamo/lowering/_decompositions.py @@ -296,7 +296,7 @@ class ReduceOperation(Enum): AMAX = ("Amax reduce operation", lambda x, y: torch.max(x, y)) AMIN = ("Amin reduce operation", lambda x, y: torch.min(x, y)) - def __new__(cls, description, func): + def __new__(cls, description: Any, func: Any) -> Any: obj = object.__new__(cls) obj._value_ = auto() obj.description = description @@ -304,8 +304,13 @@ def __new__(cls, description, func): return obj def reduce_operation_with_scatter( - self, operation_lhs, initial_tensor, dim, index_tensor, src_tensor - ): + self, + operation_lhs: Any, + initial_tensor: torch.Tensor, + dim: int, + index_tensor: torch.Tensor, + src_tensor: torch.Tensor, + ) -> Any: scatter_tensor = None if self == ReduceOperation.SUM or self == ReduceOperation.MEAN: scatter_tensor = torch.zeros_like(initial_tensor) @@ -341,7 +346,7 @@ def scatter_reduce_decomposition( scatter_count_tensor = torch.zeros_like(input_tensor) src_shape = list(src_tensor.shape) src_dim = src_shape[dim] - if include_self == False: + if not include_self: raise AssertionError("include_self False for scatter reduce not yet supported") for i in range(0, src_dim): src_slice = torch.select(src_tensor, dim, i) diff --git a/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py b/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py index b7c65f1880..b6435c0d8c 100644 --- a/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py +++ b/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py @@ -1,8 +1,9 @@ import logging -from typing import Callable, Optional, Sequence, Union +from typing import Any, Callable, Optional, Sequence, Union import torch +from .accumulate_fp32_matmul import accumulate_fp32_matmul from .constant_folding import constant_fold from .fuse_prims_broadcast import fuse_prims_broadcast from .lower_linear import lower_linear @@ -90,12 +91,16 @@ def _remove_lowering_pass(*, index: int) -> None: return -def post_lowering(gm: torch.fx.GraphModule) -> torch.fx.GraphModule: +def post_lowering(gm: torch.fx.GraphModule, **kwargs: Any) -> torch.fx.GraphModule: """Applies the lowering passes to a graph module after torch.export/ torch.compile and their decompositions, returns the modified GraphModule""" logging.debug( f"Invoking DynamoPassManager and applying lowering passes: {ATEN_POST_LOWERING_PASSES}" ) - return ATEN_POST_LOWERING_PASSES(gm) + gm = ATEN_POST_LOWERING_PASSES(gm) + if kwargs.get("use_fp32_acc", False): + gm = accumulate_fp32_matmul(gm) + + return gm def pre_export_lowering(ep: torch.export.ExportedProgram) -> torch.fx.GraphModule: diff --git a/py/torch_tensorrt/dynamo/lowering/passes/accumulate_fp32_matmul.py b/py/torch_tensorrt/dynamo/lowering/passes/accumulate_fp32_matmul.py new file mode 100644 index 0000000000..d69249088c --- /dev/null +++ b/py/torch_tensorrt/dynamo/lowering/passes/accumulate_fp32_matmul.py @@ -0,0 +1,49 @@ +import logging + +import torch +from torch_tensorrt.dynamo.lowering.passes.pass_utils import ( + clean_up_graph_after_modifications, +) + +logger = logging.getLogger(__name__) + + +def accumulate_fp32_matmul(gm: torch.fx.GraphModule) -> torch.fx.GraphModule: + """Replace a matmul layer with fp32 accumulation nodes""" + matmul_targets = [ + torch.ops.aten.mm.default, + torch.ops.aten.bmm.default, + torch.ops.aten.addmm.default, + ] + matmul_nodes = [node for node in gm.graph.nodes if node.target in matmul_targets] + for matmul_node in matmul_nodes: + # Prior to the matmul node, insert a cast to the 32-bit float32 node + node_inputs = matmul_node.all_input_nodes + + for node_input in node_inputs: + with gm.graph.inserting_before(matmul_node): + node_32bit = gm.graph.call_function( + torch.ops.aten._to_copy.default, + args=(node_input,), + kwargs={"dtype": torch.float32}, + ) + + # Replace the input to matmul node with new 32-bit cast node + matmul_node.replace_input_with(node_input, node_32bit) + + # Add a cast back to original precision + with gm.graph.inserting_after(matmul_node): + node_orig_precision = gm.graph.call_function( + torch.ops.aten._to_copy.default, + args=(matmul_node,), + kwargs={"dtype": torch.float16}, + ) + matmul_node.replace_all_uses_with(node_orig_precision, propagate_meta=False) + # This is a hack. replace_all_uses_with isn't working here. It complains node_orig_precision is already being used before created. + node_orig_precision.replace_input_with( + node_orig_precision.all_input_nodes[0], matmul_node + ) + + gm = clean_up_graph_after_modifications(gm) + logger.debug(f"Graph after changing matmuls to use FP32 accumulation:\n{gm.graph}") + return gm diff --git a/tests/py/dynamo/lowering/test_aten_lowering_passes.py b/tests/py/dynamo/lowering/test_aten_lowering_passes.py index ad3fc8fa79..86cd5e3699 100644 --- a/tests/py/dynamo/lowering/test_aten_lowering_passes.py +++ b/tests/py/dynamo/lowering/test_aten_lowering_passes.py @@ -466,9 +466,6 @@ def forward(self, input, weight, bias): ) torch._dynamo.reset() - @unittest.skip( - "This test has threshold failures. This is tracked at https://github.com/pytorch/TensorRT/issues/2715", - ) def test_lower_linear_batch(self): class Linear(torch.nn.Module): def forward(self, input, weight, bias): @@ -575,5 +572,44 @@ def forward(self, input): torch._dynamo.reset() +class TestFP32Accumulation(TestCase): + def test_fp32_acc(self): + class FP32Acc(torch.nn.Module): + def forward(self, input, weight): + out = torch.ops.aten.mm.default(input, weight) + return out + + inputs = [ + torch.rand((3, 4)).cuda(), + torch.rand((4, 5)).cuda(), + ] + + fx_graph = torch.fx.symbolic_trace(FP32Acc()) + expected_ops = {torch.ops.aten._to_copy.default, torch.ops.aten.mm.default} + unexpected_ops = {} + + unexpected_ops_seen, expected_ops_unseen = lower_graph_testing( + fx_graph, + inputs, + expected_ops=expected_ops, + unexpected_ops=unexpected_ops, + min_block_size=1, + use_fp32_acc=True, + ) + + self.assertEqual( + len(unexpected_ops_seen), + 0, + f"The following unexpected ops were encountered: {unexpected_ops_seen}", + ) + + self.assertEqual( + len(expected_ops_unseen), + 0, + f"The following expected ops were not encountered: {expected_ops_unseen}", + ) + torch._dynamo.reset() + + if __name__ == "__main__": run_tests() diff --git a/tests/py/dynamo/testing_utilities.py b/tests/py/dynamo/testing_utilities.py index 9e33aec53a..c0126fad24 100644 --- a/tests/py/dynamo/testing_utilities.py +++ b/tests/py/dynamo/testing_utilities.py @@ -24,6 +24,7 @@ def fx_dynamo_testing_backend( min_block_size: int = 3, torch_executed_ops: Sequence[str] = set(), use_fast_partitioner: bool = True, + use_fp32_acc: bool = False, ): """Helper Dynamo backend exclusively for testing""" custom_backend = partial( @@ -50,7 +51,7 @@ def fx_dynamo_testing_backend( decompositions=get_decompositions(), ) - gm = post_lowering(gm) + gm = post_lowering(gm, use_fp32_acc=use_fp32_acc) trt_compiled = custom_backend( gm, @@ -153,6 +154,7 @@ def lower_graph_testing( torch_executed_ops: Sequence[str] = set(), testing_partitioning: bool = False, use_fast_partitioner: bool = True, + use_fp32_acc: bool = False, ): """Helper function to assist with graph lowering for testing of Dynamo compile @@ -165,6 +167,7 @@ def lower_graph_testing( torch_executed_ops: Sequence of operations to run in Torch, regardless of converter coverage testing_partitioning: Whether partitioning is being tested (to analyze only TRT-supported ops) use_fast_partitioner: Whether to use the fast or global partitioner + use_fp32_acc: This option inserts cast to FP32 nodes around matmul layers and TensorRT ensures the accumulation of matmul happens in FP32. Use this only when FP16 precision is configured in enabled_precisions. Returns: If testing_partitioning: List[torch.fx.GraphModule], Set, Set: List of partitioned graph outputs, unexpected ops seen, expected ops unseen @@ -179,6 +182,7 @@ def lower_graph_testing( min_block_size=min_block_size, torch_executed_ops=torch_executed_ops, use_fast_partitioner=use_fast_partitioner, + use_fp32_acc=use_fp32_acc, ) # Invoke compilation diff --git a/uv.lock b/uv.lock index a89106e927..493873c773 100644 --- a/uv.lock +++ b/uv.lock @@ -65,22 +65,10 @@ dependencies = [ ] sdist = { url = "https://files.pythonhosted.org/packages/04/b0/46fb0d4e00372f4a86a6f8efa3cb193c9f64863615e39010b1477e010578/black-24.8.0.tar.gz", hash = "sha256:2500945420b6784c38b9ee885af039f5e7471ef284ab03fa35ecdde4688cd83f", size = 644810 } wheels = [ - { url = "https://files.pythonhosted.org/packages/47/6e/74e29edf1fba3887ed7066930a87f698ffdcd52c5dbc263eabb06061672d/black-24.8.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:09cdeb74d494ec023ded657f7092ba518e8cf78fa8386155e4a03fdcc44679e6", size = 1632092 }, - { url = "https://files.pythonhosted.org/packages/ab/49/575cb6c3faee690b05c9d11ee2e8dba8fbd6d6c134496e644c1feb1b47da/black-24.8.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:81c6742da39f33b08e791da38410f32e27d632260e599df7245cccee2064afeb", size = 1457529 }, { url = "https://files.pythonhosted.org/packages/7a/b4/d34099e95c437b53d01c4aa37cf93944b233066eb034ccf7897fa4e5f286/black-24.8.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:707a1ca89221bc8a1a64fb5e15ef39cd755633daa672a9db7498d1c19de66a42", size = 1757443 }, - { url = "https://files.pythonhosted.org/packages/87/a0/6d2e4175ef364b8c4b64f8441ba041ed65c63ea1db2720d61494ac711c15/black-24.8.0-cp310-cp310-win_amd64.whl", hash = "sha256:d6417535d99c37cee4091a2f24eb2b6d5ec42b144d50f1f2e436d9fe1916fe1a", size = 1418012 }, - { url = "https://files.pythonhosted.org/packages/08/a6/0a3aa89de9c283556146dc6dbda20cd63a9c94160a6fbdebaf0918e4a3e1/black-24.8.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:fb6e2c0b86bbd43dee042e48059c9ad7830abd5c94b0bc518c0eeec57c3eddc1", size = 1615080 }, - { url = "https://files.pythonhosted.org/packages/db/94/b803d810e14588bb297e565821a947c108390a079e21dbdcb9ab6956cd7a/black-24.8.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:837fd281f1908d0076844bc2b801ad2d369c78c45cf800cad7b61686051041af", size = 1438143 }, { url = "https://files.pythonhosted.org/packages/a5/b5/f485e1bbe31f768e2e5210f52ea3f432256201289fd1a3c0afda693776b0/black-24.8.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:62e8730977f0b77998029da7971fa896ceefa2c4c4933fcd593fa599ecbf97a4", size = 1738774 }, - { url = "https://files.pythonhosted.org/packages/a8/69/a000fc3736f89d1bdc7f4a879f8aaf516fb03613bb51a0154070383d95d9/black-24.8.0-cp311-cp311-win_amd64.whl", hash = "sha256:72901b4913cbac8972ad911dc4098d5753704d1f3c56e44ae8dce99eecb0e3af", size = 1427503 }, - { url = "https://files.pythonhosted.org/packages/a2/a8/05fb14195cfef32b7c8d4585a44b7499c2a4b205e1662c427b941ed87054/black-24.8.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:7c046c1d1eeb7aea9335da62472481d3bbf3fd986e093cffd35f4385c94ae368", size = 1646132 }, - { url = "https://files.pythonhosted.org/packages/41/77/8d9ce42673e5cb9988f6df73c1c5c1d4e9e788053cccd7f5fb14ef100982/black-24.8.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:649f6d84ccbae73ab767e206772cc2d7a393a001070a4c814a546afd0d423aed", size = 1448665 }, { url = "https://files.pythonhosted.org/packages/cc/94/eff1ddad2ce1d3cc26c162b3693043c6b6b575f538f602f26fe846dfdc75/black-24.8.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:2b59b250fdba5f9a9cd9d0ece6e6d993d91ce877d121d161e4698af3eb9c1018", size = 1762458 }, - { url = "https://files.pythonhosted.org/packages/28/ea/18b8d86a9ca19a6942e4e16759b2fa5fc02bbc0eb33c1b866fcd387640ab/black-24.8.0-cp312-cp312-win_amd64.whl", hash = "sha256:6e55d30d44bed36593c3163b9bc63bf58b3b30e4611e4d88a0c3c239930ed5b2", size = 1436109 }, - { url = "https://files.pythonhosted.org/packages/13/b2/b3f24fdbb46f0e7ef6238e131f13572ee8279b70f237f221dd168a9dba1a/black-24.8.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:eab4dd44ce80dea27dc69db40dab62d4ca96112f87996bca68cd75639aeb2e4c", size = 1631706 }, - { url = "https://files.pythonhosted.org/packages/d9/35/31010981e4a05202a84a3116423970fd1a59d2eda4ac0b3570fbb7029ddc/black-24.8.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:3c4285573d4897a7610054af5a890bde7c65cb466040c5f0c8b732812d7f0e5e", size = 1457429 }, { url = "https://files.pythonhosted.org/packages/27/25/3f706b4f044dd569a20a4835c3b733dedea38d83d2ee0beb8178a6d44945/black-24.8.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:9e84e33b37be070ba135176c123ae52a51f82306def9f7d063ee302ecab2cf47", size = 1756488 }, - { url = "https://files.pythonhosted.org/packages/63/72/79375cd8277cbf1c5670914e6bd4c1b15dea2c8f8e906dc21c448d0535f0/black-24.8.0-cp39-cp39-win_amd64.whl", hash = "sha256:73bbf84ed136e45d451a260c6b73ed674652f90a2b3211d6a35e78054563a9bb", size = 1417721 }, { url = "https://files.pythonhosted.org/packages/27/1e/83fa8a787180e1632c3d831f7e58994d7aaf23a0961320d21e84f922f919/black-24.8.0-py3-none-any.whl", hash = "sha256:972085c618ee94f402da1af548a4f218c754ea7e5dc70acb168bfaca4c2542ed", size = 206504 }, ] @@ -107,42 +95,26 @@ name = "charset-normalizer" version = "3.3.2" source = { registry = "https://download.pytorch.org/whl/nightly/cu124" } wheels = [ - { url = "https://download.pytorch.org/whl/nightly/charset_normalizer-3.3.2-cp310-cp310-macosx_10_9_universal2.whl" }, - { url = "https://download.pytorch.org/whl/nightly/charset_normalizer-3.3.2-cp310-cp310-macosx_10_9_x86_64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/charset_normalizer-3.3.2-cp310-cp310-macosx_11_0_arm64.whl" }, { url = "https://download.pytorch.org/whl/nightly/charset_normalizer-3.3.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl" }, { url = "https://download.pytorch.org/whl/nightly/charset_normalizer-3.3.2-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl" }, { url = "https://download.pytorch.org/whl/nightly/charset_normalizer-3.3.2-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl" }, { url = "https://download.pytorch.org/whl/nightly/charset_normalizer-3.3.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl" }, { url = "https://download.pytorch.org/whl/nightly/charset_normalizer-3.3.2-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl" }, - { url = "https://download.pytorch.org/whl/nightly/charset_normalizer-3.3.2-cp310-cp310-win_amd64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/charset_normalizer-3.3.2-cp311-cp311-macosx_10_9_universal2.whl" }, - { url = "https://download.pytorch.org/whl/nightly/charset_normalizer-3.3.2-cp311-cp311-macosx_10_9_x86_64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/charset_normalizer-3.3.2-cp311-cp311-macosx_11_0_arm64.whl" }, { url = "https://download.pytorch.org/whl/nightly/charset_normalizer-3.3.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl" }, { url = "https://download.pytorch.org/whl/nightly/charset_normalizer-3.3.2-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl" }, { url = "https://download.pytorch.org/whl/nightly/charset_normalizer-3.3.2-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl" }, { url = "https://download.pytorch.org/whl/nightly/charset_normalizer-3.3.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl" }, { url = "https://download.pytorch.org/whl/nightly/charset_normalizer-3.3.2-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl" }, - { url = "https://download.pytorch.org/whl/nightly/charset_normalizer-3.3.2-cp311-cp311-win_amd64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/charset_normalizer-3.3.2-cp312-cp312-macosx_10_9_universal2.whl" }, - { url = "https://download.pytorch.org/whl/nightly/charset_normalizer-3.3.2-cp312-cp312-macosx_10_9_x86_64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/charset_normalizer-3.3.2-cp312-cp312-macosx_11_0_arm64.whl" }, { url = "https://download.pytorch.org/whl/nightly/charset_normalizer-3.3.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl" }, { url = "https://download.pytorch.org/whl/nightly/charset_normalizer-3.3.2-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl" }, { url = "https://download.pytorch.org/whl/nightly/charset_normalizer-3.3.2-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl" }, { url = "https://download.pytorch.org/whl/nightly/charset_normalizer-3.3.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl" }, { url = "https://download.pytorch.org/whl/nightly/charset_normalizer-3.3.2-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl" }, - { url = "https://download.pytorch.org/whl/nightly/charset_normalizer-3.3.2-cp312-cp312-win_amd64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/charset_normalizer-3.3.2-cp39-cp39-macosx_10_9_universal2.whl" }, - { url = "https://download.pytorch.org/whl/nightly/charset_normalizer-3.3.2-cp39-cp39-macosx_10_9_x86_64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/charset_normalizer-3.3.2-cp39-cp39-macosx_11_0_arm64.whl" }, { url = "https://download.pytorch.org/whl/nightly/charset_normalizer-3.3.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl" }, { url = "https://download.pytorch.org/whl/nightly/charset_normalizer-3.3.2-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl" }, { url = "https://download.pytorch.org/whl/nightly/charset_normalizer-3.3.2-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl" }, { url = "https://download.pytorch.org/whl/nightly/charset_normalizer-3.3.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl" }, { url = "https://download.pytorch.org/whl/nightly/charset_normalizer-3.3.2-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl" }, - { url = "https://download.pytorch.org/whl/nightly/charset_normalizer-3.3.2-cp39-cp39-win_amd64.whl" }, { url = "https://download.pytorch.org/whl/nightly/charset_normalizer-3.3.2-py3-none-any.whl" }, ] @@ -152,14 +124,11 @@ version = "14.0.6" source = { registry = "https://pypi.org/simple" } sdist = { url = "https://files.pythonhosted.org/packages/0c/92/d57c1b3ea310ae0f48ab51a5aa2c87c4c732c3d79037ad2527f2eed7ca34/clang-format-14.0.6.tar.gz", hash = "sha256:d5c96b500d7f8b5d2db5b75ac035be387512850ad589cdc3019666b861382136", size = 9598 } wheels = [ - { url = "https://files.pythonhosted.org/packages/08/62/71ffc9213f66cab7dd5adc5e933b5f64323272c197fcff2905674016c03d/clang_format-14.0.6-py2.py3-none-macosx_10_9_universal2.whl", hash = "sha256:bd400c47665dd19afc03f98e747f78ed828abab99c6a1b07e137b35c1cd3cc26", size = 1016919 }, { url = "https://files.pythonhosted.org/packages/5f/de/f666633c30a4cc9e987d153db992849bfeea03ad200bf1cfa937039c64ff/clang_format-14.0.6-py2.py3-none-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:13f2d6d4a2af004a783c65f0921afa8f0384bffcdaf500b6c2cb542edeb0b4a5", size = 1259649 }, { url = "https://files.pythonhosted.org/packages/ce/27/df41404419d9116e071d0b8a5ba0a0969d9db7587af689ec81ec75c1f18a/clang_format-14.0.6-py2.py3-none-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:d7c1c5e404c58e55f0170f01b3c5611dce6c119e62b5d1020347e0ad97d5a047", size = 1147591 }, { url = "https://files.pythonhosted.org/packages/23/e4/ea55429601432913e9fe40686c3c09a79338075c830a523fabc71aa49c69/clang_format-14.0.6-py2.py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:dbfd60528eb3bb7d7cfe8576faa70845fbf93601f815ef75163d36606e87f388", size = 1205157 }, { url = "https://files.pythonhosted.org/packages/8c/67/e1faf73ea166669e1698f55f3ae366369db57d75eb3b6c04c93620ebac12/clang_format-14.0.6-py2.py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c93580945f75de7e01996f1fb3cf67e4dc424f1c864e237c85614fb99a48c7a4", size = 1949067 }, { url = "https://files.pythonhosted.org/packages/cd/3b/3e20072464e98314eafdc5bc5744454ade6e6f5e525fb29f6b4555173811/clang_format-14.0.6-py2.py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:aaf4edecc46a24f0b572b82cf5827e292ad1c137903427627c4d5f671668cc2b", size = 1187836 }, - { url = "https://files.pythonhosted.org/packages/6e/06/302903004246dd62a11965e9f672b975c58ad6966985dbcaa14c6cdb4779/clang_format-14.0.6-py2.py3-none-win32.whl", hash = "sha256:810c649ab97d208cd418c897d50ab6e958eb8d96854527edd80d0dd21a75e914", size = 833512 }, - { url = "https://files.pythonhosted.org/packages/63/7a/1f11404d5097263ad065cf9166dd00be0a8c1040c1ec4f57921ac07591eb/clang_format-14.0.6-py2.py3-none-win_amd64.whl", hash = "sha256:d780c04334bca80f2b60d25bf53c37bd0618520ee295a7888a11f25bde114ac4", size = 1007035 }, ] [[package]] @@ -447,30 +416,18 @@ version = "2.1.5" source = { registry = "https://download.pytorch.org/whl/nightly/cu124" } sdist = { url = "https://download.pytorch.org/whl/nightly/MarkupSafe-2.1.5.tar.gz" } wheels = [ - { url = "https://download.pytorch.org/whl/nightly/MarkupSafe-2.1.5-cp310-cp310-macosx_10_9_universal2.whl" }, - { url = "https://download.pytorch.org/whl/nightly/MarkupSafe-2.1.5-cp310-cp310-macosx_10_9_x86_64.whl" }, { url = "https://download.pytorch.org/whl/nightly/MarkupSafe-2.1.5-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl" }, { url = "https://download.pytorch.org/whl/nightly/MarkupSafe-2.1.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl" }, { url = "https://download.pytorch.org/whl/nightly/MarkupSafe-2.1.5-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl" }, - { url = "https://download.pytorch.org/whl/nightly/MarkupSafe-2.1.5-cp310-cp310-win_amd64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/MarkupSafe-2.1.5-cp311-cp311-macosx_10_9_universal2.whl" }, - { url = "https://download.pytorch.org/whl/nightly/MarkupSafe-2.1.5-cp311-cp311-macosx_10_9_x86_64.whl" }, { url = "https://download.pytorch.org/whl/nightly/MarkupSafe-2.1.5-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl" }, { url = "https://download.pytorch.org/whl/nightly/MarkupSafe-2.1.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl" }, { url = "https://download.pytorch.org/whl/nightly/MarkupSafe-2.1.5-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl" }, - { url = "https://download.pytorch.org/whl/nightly/MarkupSafe-2.1.5-cp311-cp311-win_amd64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/MarkupSafe-2.1.5-cp312-cp312-macosx_10_9_universal2.whl" }, - { url = "https://download.pytorch.org/whl/nightly/MarkupSafe-2.1.5-cp312-cp312-macosx_10_9_x86_64.whl" }, { url = "https://download.pytorch.org/whl/nightly/MarkupSafe-2.1.5-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl" }, { url = "https://download.pytorch.org/whl/nightly/MarkupSafe-2.1.5-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl" }, { url = "https://download.pytorch.org/whl/nightly/MarkupSafe-2.1.5-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl" }, - { url = "https://download.pytorch.org/whl/nightly/MarkupSafe-2.1.5-cp312-cp312-win_amd64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/MarkupSafe-2.1.5-cp39-cp39-macosx_10_9_universal2.whl" }, - { url = "https://download.pytorch.org/whl/nightly/MarkupSafe-2.1.5-cp39-cp39-macosx_10_9_x86_64.whl" }, { url = "https://download.pytorch.org/whl/nightly/MarkupSafe-2.1.5-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl" }, { url = "https://download.pytorch.org/whl/nightly/MarkupSafe-2.1.5-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl" }, { url = "https://download.pytorch.org/whl/nightly/MarkupSafe-2.1.5-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl" }, - { url = "https://download.pytorch.org/whl/nightly/MarkupSafe-2.1.5-cp39-cp39-win_amd64.whl" }, ] [[package]] @@ -513,36 +470,23 @@ dependencies = [ ] sdist = { url = "https://files.pythonhosted.org/packages/5c/86/5d7cbc4974fd564550b80fbb8103c05501ea11aa7835edf3351d90095896/mypy-1.11.2.tar.gz", hash = "sha256:7f9993ad3e0ffdc95c2a14b66dee63729f021968bff8ad911867579c65d13a79", size = 3078806 } wheels = [ - { url = "https://files.pythonhosted.org/packages/78/cd/815368cd83c3a31873e5e55b317551500b12f2d1d7549720632f32630333/mypy-1.11.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d42a6dd818ffce7be66cce644f1dff482f1d97c53ca70908dff0b9ddc120b77a", size = 10939401 }, - { url = "https://files.pythonhosted.org/packages/f1/27/e18c93a195d2fad75eb96e1f1cbc431842c332e8eba2e2b77eaf7313c6b7/mypy-1.11.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:801780c56d1cdb896eacd5619a83e427ce436d86a3bdf9112527f24a66618fef", size = 10111697 }, { url = "https://files.pythonhosted.org/packages/dc/08/cdc1fc6d0d5a67d354741344cc4aa7d53f7128902ebcbe699ddd4f15a61c/mypy-1.11.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:41ea707d036a5307ac674ea172875f40c9d55c5394f888b168033177fce47383", size = 12500508 }, { url = "https://files.pythonhosted.org/packages/64/12/aad3af008c92c2d5d0720ea3b6674ba94a98cdb86888d389acdb5f218c30/mypy-1.11.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:6e658bd2d20565ea86da7d91331b0eed6d2eee22dc031579e6297f3e12c758c8", size = 13020712 }, - { url = "https://files.pythonhosted.org/packages/03/e6/a7d97cc124a565be5e9b7d5c2a6ebf082379ffba99646e4863ed5bbcb3c3/mypy-1.11.2-cp310-cp310-win_amd64.whl", hash = "sha256:478db5f5036817fe45adb7332d927daa62417159d49783041338921dcf646fc7", size = 9567319 }, - { url = "https://files.pythonhosted.org/packages/e2/aa/cc56fb53ebe14c64f1fe91d32d838d6f4db948b9494e200d2f61b820b85d/mypy-1.11.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:75746e06d5fa1e91bfd5432448d00d34593b52e7e91a187d981d08d1f33d4385", size = 10859630 }, - { url = "https://files.pythonhosted.org/packages/04/c8/b19a760fab491c22c51975cf74e3d253b8c8ce2be7afaa2490fbf95a8c59/mypy-1.11.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:a976775ab2256aadc6add633d44f100a2517d2388906ec4f13231fafbb0eccca", size = 10037973 }, { url = "https://files.pythonhosted.org/packages/88/57/7e7e39f2619c8f74a22efb9a4c4eff32b09d3798335625a124436d121d89/mypy-1.11.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:cd953f221ac1379050a8a646585a29574488974f79d8082cedef62744f0a0104", size = 12416659 }, { url = "https://files.pythonhosted.org/packages/fc/a6/37f7544666b63a27e46c48f49caeee388bf3ce95f9c570eb5cfba5234405/mypy-1.11.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:57555a7715c0a34421013144a33d280e73c08df70f3a18a552938587ce9274f4", size = 12897010 }, - { url = "https://files.pythonhosted.org/packages/84/8b/459a513badc4d34acb31c736a0101c22d2bd0697b969796ad93294165cfb/mypy-1.11.2-cp311-cp311-win_amd64.whl", hash = "sha256:36383a4fcbad95f2657642a07ba22ff797de26277158f1cc7bd234821468b1b6", size = 9562873 }, - { url = "https://files.pythonhosted.org/packages/35/3a/ed7b12ecc3f6db2f664ccf85cb2e004d3e90bec928e9d7be6aa2f16b7cdf/mypy-1.11.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:e8960dbbbf36906c5c0b7f4fbf2f0c7ffb20f4898e6a879fcf56a41a08b0d318", size = 10990335 }, - { url = "https://files.pythonhosted.org/packages/04/e4/1a9051e2ef10296d206519f1df13d2cc896aea39e8683302f89bf5792a59/mypy-1.11.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:06d26c277962f3fb50e13044674aa10553981ae514288cb7d0a738f495550b36", size = 10007119 }, { url = "https://files.pythonhosted.org/packages/f3/3c/350a9da895f8a7e87ade0028b962be0252d152e0c2fbaafa6f0658b4d0d4/mypy-1.11.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:6e7184632d89d677973a14d00ae4d03214c8bc301ceefcdaf5c474866814c987", size = 12506856 }, { url = "https://files.pythonhosted.org/packages/b6/49/ee5adf6a49ff13f4202d949544d3d08abb0ea1f3e7f2a6d5b4c10ba0360a/mypy-1.11.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:3a66169b92452f72117e2da3a576087025449018afc2d8e9bfe5ffab865709ca", size = 12952066 }, - { url = "https://files.pythonhosted.org/packages/27/c0/b19d709a42b24004d720db37446a42abadf844d5c46a2c442e2a074d70d9/mypy-1.11.2-cp312-cp312-win_amd64.whl", hash = "sha256:969ea3ef09617aff826885a22ece0ddef69d95852cdad2f60c8bb06bf1f71f70", size = 9664000 }, - { url = "https://files.pythonhosted.org/packages/16/64/bb5ed751487e2bea0dfaa6f640a7e3bb88083648f522e766d5ef4a76f578/mypy-1.11.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:801ca29f43d5acce85f8e999b1e431fb479cb02d0e11deb7d2abb56bdaf24fd6", size = 10937294 }, - { url = "https://files.pythonhosted.org/packages/a9/a3/67a0069abed93c3bf3b0bebb8857e2979a02828a4a3fd82f107f8f1143e8/mypy-1.11.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:af8d155170fcf87a2afb55b35dc1a0ac21df4431e7d96717621962e4b9192e70", size = 10107707 }, { url = "https://files.pythonhosted.org/packages/2f/4d/0379daf4258b454b1f9ed589a9dabd072c17f97496daea7b72fdacf7c248/mypy-1.11.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:f7821776e5c4286b6a13138cc935e2e9b6fde05e081bdebf5cdb2bb97c9df81d", size = 12498367 }, { url = "https://files.pythonhosted.org/packages/3b/dc/3976a988c280b3571b8eb6928882dc4b723a403b21735a6d8ae6ed20e82b/mypy-1.11.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:539c570477a96a4e6fb718b8d5c3e0c0eba1f485df13f86d2970c91f0673148d", size = 13018014 }, - { url = "https://files.pythonhosted.org/packages/83/84/adffc7138fb970e7e2a167bd20b33bb78958370179853a4ebe9008139342/mypy-1.11.2-cp39-cp39-win_amd64.whl", hash = "sha256:3f14cd3d386ac4d05c5a39a51b84387403dadbd936e17cb35882134d4f8f0d24", size = 9568056 }, { url = "https://files.pythonhosted.org/packages/42/3a/bdf730640ac523229dd6578e8a581795720a9321399de494374afc437ec5/mypy-1.11.2-py3-none-any.whl", hash = "sha256:b499bc07dbdcd3de92b0a8b29fdf592c111276f6a12fe29c30f6c417dd546d12", size = 2619625 }, ] [[package]] name = "mypy-extensions" version = "1.0.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/98/a4/1ab47638b92648243faf97a5aeb6ea83059cc3624972ab6b8d2316078d3f/mypy_extensions-1.0.0.tar.gz", hash = "sha256:75dbf8955dc00442a438fc4d0666508a9a97b6bd41aa2f0ffe9d2f2725af0782", size = 4433 } +source = { registry = "https://download.pytorch.org/whl/nightly/cu124" } wheels = [ - { url = "https://files.pythonhosted.org/packages/2a/e2/5d3f6ada4297caebe1a2add3b126fe800c96f56dbe5d1988a2cbe0b267aa/mypy_extensions-1.0.0-py3-none-any.whl", hash = "sha256:4392f6c0eb8a5668a69e23d168ffa70f0be9ccfd32b5cc2d26a34ae5b844552d", size = 4695 }, + { url = "https://download.pytorch.org/whl/nightly/mypy_extensions-1.0.0-py3-none-any.whl" }, ] [[package]] @@ -559,7 +503,6 @@ version = "1.11.1.1" source = { registry = "https://pypi.org/simple" } sdist = { url = "https://files.pythonhosted.org/packages/37/2c/d717d13a413d6f7579cdaa1e28e6e2c98de95461549b08d311c8a5bf4c51/ninja-1.11.1.1.tar.gz", hash = "sha256:9d793b08dd857e38d0b6ffe9e6b7145d7c485a42dcfea04905ca0cdb6017cc3c", size = 132392 } wheels = [ - { url = "https://files.pythonhosted.org/packages/3d/6e/04ed11bb244039908f6f212cb5f3e97933e238655248e4ce307c1687ba1f/ninja-1.11.1.1-py2.py3-none-macosx_10_9_universal2.macosx_10_9_x86_64.macosx_11_0_arm64.macosx_11_0_universal2.whl", hash = "sha256:376889c76d87b95b5719fdd61dd7db193aa7fd4432e5d52d2e44e4c497bdbbee", size = 270611 }, { url = "https://files.pythonhosted.org/packages/2c/52/0e5423311eb9939b6f9354059a6d88a6211eb4fa1c7a4ef303ecee1c1fe0/ninja-1.11.1.1-py2.py3-none-manylinux1_i686.manylinux_2_5_i686.whl", hash = "sha256:ecf80cf5afd09f14dcceff28cb3f11dc90fb97c999c89307aea435889cb66877", size = 324256 }, { url = "https://files.pythonhosted.org/packages/6d/92/8d7aebd4430ab5ff65df2bfee6d5745f95c004284db2d8ca76dcbfd9de47/ninja-1.11.1.1-py2.py3-none-manylinux1_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:84502ec98f02a037a169c4b0d5d86075eaf6afc55e1879003d6cab51ced2ea4b", size = 307194 }, { url = "https://files.pythonhosted.org/packages/01/c8/96424839fd127b4492229acf50763ed9940d864ca35d17d151934aef1f6f/ninja-1.11.1.1-py2.py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:73b93c14046447c7c5cc892433d4fae65d6364bec6685411cb97a8bcf815f93a", size = 155643 }, @@ -570,9 +513,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/43/78/34af88d753389a9412438d16142c77e587e0d69152faf0bbf99701063dd8/ninja-1.11.1.1-py2.py3-none-musllinux_1_1_ppc64le.whl", hash = "sha256:9df724344202b83018abb45cb1efc22efd337a1496514e7e6b3b59655be85205", size = 419782 }, { url = "https://files.pythonhosted.org/packages/3b/74/de0633f8bced3b188942fca64a950e8f2206c60c10c97af465b356ae9b25/ninja-1.11.1.1-py2.py3-none-musllinux_1_1_s390x.whl", hash = "sha256:3e0f9be5bb20d74d58c66cc1c414c3e6aeb45c35b0d0e41e8d739c2c0d57784f", size = 415476 }, { url = "https://files.pythonhosted.org/packages/9a/f3/3e4a56ff77739d1582749b93497bdebf11e003fbc7a66363ef6c772ebd0a/ninja-1.11.1.1-py2.py3-none-musllinux_1_1_x86_64.whl", hash = "sha256:76482ba746a2618eecf89d5253c0d1e4f1da1270d41e9f54dfbd91831b0f6885", size = 379229 }, - { url = "https://files.pythonhosted.org/packages/c5/ee/53df34fcc9c0b1db62b2f2e2c848e28d9354e1c7f0dce029ee50b16ca157/ninja-1.11.1.1-py2.py3-none-win32.whl", hash = "sha256:fa2ba9d74acfdfbfbcf06fad1b8282de8a7a8c481d9dee45c859a8c93fcc1082", size = 265049 }, - { url = "https://files.pythonhosted.org/packages/b6/2f/a3bc50fa63fc4fe9348e15b53dc8c87febfd4e0c660fcf250c4b19a3aa3b/ninja-1.11.1.1-py2.py3-none-win_amd64.whl", hash = "sha256:95da904130bfa02ea74ff9c0116b4ad266174fafb1c707aa50212bc7859aebf1", size = 312958 }, - { url = "https://files.pythonhosted.org/packages/73/2a/f5b7b3b7ecd5cf4e31375580bf5c6a01a328ed1ebdfff90fab463e3f4bc7/ninja-1.11.1.1-py2.py3-none-win_arm64.whl", hash = "sha256:185e0641bde601e53841525c4196278e9aaf4463758da6dd1e752c0a0f54136a", size = 272686 }, ] [[package]] @@ -589,26 +529,14 @@ name = "numpy" version = "1.26.4" source = { registry = "https://download.pytorch.org/whl/nightly/cu124" } wheels = [ - { url = "https://download.pytorch.org/whl/nightly/numpy-1.26.4-cp310-cp310-macosx_10_9_x86_64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/numpy-1.26.4-cp310-cp310-macosx_11_0_arm64.whl" }, { url = "https://download.pytorch.org/whl/nightly/numpy-1.26.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl" }, { url = "https://download.pytorch.org/whl/nightly/numpy-1.26.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/numpy-1.26.4-cp310-cp310-win_amd64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/numpy-1.26.4-cp311-cp311-macosx_10_9_x86_64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/numpy-1.26.4-cp311-cp311-macosx_11_0_arm64.whl" }, { url = "https://download.pytorch.org/whl/nightly/numpy-1.26.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl" }, { url = "https://download.pytorch.org/whl/nightly/numpy-1.26.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/numpy-1.26.4-cp311-cp311-win_amd64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/numpy-1.26.4-cp312-cp312-macosx_10_9_x86_64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/numpy-1.26.4-cp312-cp312-macosx_11_0_arm64.whl" }, { url = "https://download.pytorch.org/whl/nightly/numpy-1.26.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl" }, { url = "https://download.pytorch.org/whl/nightly/numpy-1.26.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/numpy-1.26.4-cp312-cp312-win_amd64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/numpy-1.26.4-cp39-cp39-macosx_10_9_x86_64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/numpy-1.26.4-cp39-cp39-macosx_11_0_arm64.whl" }, { url = "https://download.pytorch.org/whl/nightly/numpy-1.26.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl" }, { url = "https://download.pytorch.org/whl/nightly/numpy-1.26.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/numpy-1.26.4-cp39-cp39-win_amd64.whl" }, ] [[package]] @@ -618,7 +546,6 @@ source = { registry = "https://download.pytorch.org/whl/nightly/cu124" } wheels = [ { url = "https://download.pytorch.org/whl/nightly/cu124/nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_aarch64.whl" }, { url = "https://download.pytorch.org/whl/nightly/cu124/nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cu124/nvidia_cublas_cu12-12.4.5.8-py3-none-win_amd64.whl" }, ] [[package]] @@ -628,7 +555,6 @@ source = { registry = "https://download.pytorch.org/whl/nightly/cu124" } wheels = [ { url = "https://download.pytorch.org/whl/nightly/cu124/nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl" }, { url = "https://download.pytorch.org/whl/nightly/cu124/nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cu124/nvidia_cuda_cupti_cu12-12.4.127-py3-none-win_amd64.whl" }, ] [[package]] @@ -638,7 +564,6 @@ source = { registry = "https://download.pytorch.org/whl/nightly/cu124" } wheels = [ { url = "https://download.pytorch.org/whl/nightly/cu124/nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl" }, { url = "https://download.pytorch.org/whl/nightly/cu124/nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cu124/nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-win_amd64.whl" }, ] [[package]] @@ -648,7 +573,6 @@ source = { registry = "https://download.pytorch.org/whl/nightly/cu124" } wheels = [ { url = "https://download.pytorch.org/whl/nightly/cu124/nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl" }, { url = "https://download.pytorch.org/whl/nightly/cu124/nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cu124/nvidia_cuda_runtime_cu12-12.4.127-py3-none-win_amd64.whl" }, ] [[package]] @@ -669,7 +593,6 @@ source = { registry = "https://download.pytorch.org/whl/nightly/cu124" } wheels = [ { url = "https://download.pytorch.org/whl/nightly/cu124/nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_aarch64.whl" }, { url = "https://download.pytorch.org/whl/nightly/cu124/nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cu124/nvidia_cufft_cu12-11.2.1.3-py3-none-win_amd64.whl" }, ] [[package]] @@ -679,7 +602,6 @@ source = { registry = "https://download.pytorch.org/whl/nightly/cu124" } wheels = [ { url = "https://download.pytorch.org/whl/nightly/cu124/nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_aarch64.whl" }, { url = "https://download.pytorch.org/whl/nightly/cu124/nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_x86_64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cu124/nvidia_curand_cu12-10.3.5.147-py3-none-win_amd64.whl" }, ] [[package]] @@ -694,7 +616,6 @@ dependencies = [ wheels = [ { url = "https://download.pytorch.org/whl/nightly/cu124/nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_aarch64.whl" }, { url = "https://download.pytorch.org/whl/nightly/cu124/nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_x86_64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cu124/nvidia_cusolver_cu12-11.6.1.9-py3-none-win_amd64.whl" }, ] [[package]] @@ -707,7 +628,6 @@ dependencies = [ wheels = [ { url = "https://download.pytorch.org/whl/nightly/cu124/nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_aarch64.whl" }, { url = "https://download.pytorch.org/whl/nightly/cu124/nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_x86_64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cu124/nvidia_cusparse_cu12-12.3.1.170-py3-none-win_amd64.whl" }, ] [[package]] @@ -728,16 +648,12 @@ dependencies = [ wheels = [ { url = "https://files.pythonhosted.org/packages/32/b2/2a688cc56d875a08e3e732af642d0ae0f4a7253dc1a00fd271e2fe1a79e9/nvidia_modelopt-0.17.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f646da43a46ddf10eb2c2ddebe49cb1a58e631348808c5640afe217f8ab223ae", size = 4699528 }, { url = "https://files.pythonhosted.org/packages/b2/b3/98bea42c27fb9ca9f6c502ec3624f2000732f9cc642652b75378f6fed1db/nvidia_modelopt-0.17.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:67cd2faad4c0084864330533112cb231fdf404526086f0803e593f65b7868f47", size = 4502680 }, - { url = "https://files.pythonhosted.org/packages/b9/5c/24111cfee820bc96169b470f60e1540fbaac85c17566c552b4bf86aba312/nvidia_modelopt-0.17.0-cp310-cp310-win_amd64.whl", hash = "sha256:1594b2492d038940e696ed7e26014e5f0eee7e0c616b6ce619c18ddc2116c433", size = 1092507 }, { url = "https://files.pythonhosted.org/packages/26/7e/beb7461b6bedf7fff043ded616a520468c992df9d005d3bfd87eab1dab71/nvidia_modelopt-0.17.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ee7480fdb7e1e6d22e17092bfbc6abf6140f070853c21662e4b6162aec228feb", size = 5108601 }, { url = "https://files.pythonhosted.org/packages/43/79/45572053e3e928f32d62fcae8704a2667a646356ac752a15490829398987/nvidia_modelopt-0.17.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:16e6a8df04e6551a9f2059cb5c68608d18b38cb14dec609ac920035ceeccbf98", size = 4966637 }, - { url = "https://files.pythonhosted.org/packages/70/d5/ef43dc8b5d9b026318ec138ac2403f63081449ecafedaa8de1be27c38d54/nvidia_modelopt-0.17.0-cp311-cp311-win_amd64.whl", hash = "sha256:f62ffbb4cdfb86a6b965ec5d20ef0709520074fcc5f6be912f540322c527873b", size = 1094628 }, { url = "https://files.pythonhosted.org/packages/7e/a9/53a82e52fd0d5c44a6a09f48db30b7e133e393ad1b7cef280ac41a4a8bca/nvidia_modelopt-0.17.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:557b92dbeeb4d9dedb58d9a6251901870a7458f6e773838efa9155693c5d21e1", size = 5145613 }, { url = "https://files.pythonhosted.org/packages/4e/59/d3d2c7d59c5b56c86c0c1c55896d6800539f5b1aae7c3ea43331c117fe95/nvidia_modelopt-0.17.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:3a07969bb8c684b3e9f9eecc018bc465d0f683d9e9a4792b5eb435b9ad85f4a5", size = 4949031 }, - { url = "https://files.pythonhosted.org/packages/bb/a3/562bebd9aa8f8c033e2127ce981c21f7221aeca62dd4692169c06283ec92/nvidia_modelopt-0.17.0-cp312-cp312-win_amd64.whl", hash = "sha256:98b0c37ba547a821170ed6ed5ac3b90cb9562cb32b736ee94cee0364d1c92263", size = 1079358 }, { url = "https://files.pythonhosted.org/packages/61/00/39c8ad3969003f7ba87a54197626aa822047bd98a0e09428045e9dd00335/nvidia_modelopt-0.17.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0da51c4464af883eb722b20679633aabc31fadff6b4bee3c33eb4eda3d920484", size = 4694964 }, { url = "https://files.pythonhosted.org/packages/41/6b/7eb1d60a4706ee1346990bb7b45bf4244be60e24c52894fd817f5ffb649a/nvidia_modelopt-0.17.0-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:fdc86a37698cad0d780e91ac07166ef4e3b073e87200116351f0b2c7dcada2bf", size = 4502582 }, - { url = "https://files.pythonhosted.org/packages/6b/f1/e10a9a525d8d34001dddadd64905d6ce6c875e6962798c65b31645b889e9/nvidia_modelopt-0.17.0-cp39-cp39-win_amd64.whl", hash = "sha256:5f71a2fb40d5fa6a59d8594fb5bffca0e327712aceda41bf79379caab58b0480", size = 1093442 }, ] [package.optional-dependencies] @@ -771,7 +687,6 @@ source = { registry = "https://download.pytorch.org/whl/nightly/cu124" } wheels = [ { url = "https://download.pytorch.org/whl/nightly/cu124/nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl" }, { url = "https://download.pytorch.org/whl/nightly/cu124/nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cu124/nvidia_nvjitlink_cu12-12.4.127-py3-none-win_amd64.whl" }, ] [[package]] @@ -781,7 +696,6 @@ source = { registry = "https://download.pytorch.org/whl/nightly/cu124" } wheels = [ { url = "https://download.pytorch.org/whl/nightly/cu124/nvidia_nvtx_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl" }, { url = "https://download.pytorch.org/whl/nightly/cu124/nvidia_nvtx_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cu124/nvidia_nvtx_cu12-12.4.127-py3-none-win_amd64.whl" }, ] [[package]] @@ -838,75 +752,44 @@ version = "10.4.0" source = { registry = "https://pypi.org/simple" } sdist = { url = "https://files.pythonhosted.org/packages/cd/74/ad3d526f3bf7b6d3f408b73fde271ec69dfac8b81341a318ce825f2b3812/pillow-10.4.0.tar.gz", hash = "sha256:166c1cd4d24309b30d61f79f4a9114b7b2313d7450912277855ff5dfd7cd4a06", size = 46555059 } wheels = [ - { url = "https://files.pythonhosted.org/packages/0e/69/a31cccd538ca0b5272be2a38347f8839b97a14be104ea08b0db92f749c74/pillow-10.4.0-cp310-cp310-macosx_10_10_x86_64.whl", hash = "sha256:4d9667937cfa347525b319ae34375c37b9ee6b525440f3ef48542fcf66f2731e", size = 3509271 }, - { url = "https://files.pythonhosted.org/packages/9a/9e/4143b907be8ea0bce215f2ae4f7480027473f8b61fcedfda9d851082a5d2/pillow-10.4.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:543f3dc61c18dafb755773efc89aae60d06b6596a63914107f75459cf984164d", size = 3375658 }, { url = "https://files.pythonhosted.org/packages/8a/25/1fc45761955f9359b1169aa75e241551e74ac01a09f487adaaf4c3472d11/pillow-10.4.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7928ecbf1ece13956b95d9cbcfc77137652b02763ba384d9ab508099a2eca856", size = 4332075 }, { url = "https://files.pythonhosted.org/packages/5e/dd/425b95d0151e1d6c951f45051112394f130df3da67363b6bc75dc4c27aba/pillow-10.4.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e4d49b85c4348ea0b31ea63bc75a9f3857869174e2bf17e7aba02945cd218e6f", size = 4444808 }, { url = "https://files.pythonhosted.org/packages/b1/84/9a15cc5726cbbfe7f9f90bfb11f5d028586595907cd093815ca6644932e3/pillow-10.4.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:6c762a5b0997f5659a5ef2266abc1d8851ad7749ad9a6a5506eb23d314e4f46b", size = 4356290 }, { url = "https://files.pythonhosted.org/packages/b5/5b/6651c288b08df3b8c1e2f8c1152201e0b25d240e22ddade0f1e242fc9fa0/pillow-10.4.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:a985e028fc183bf12a77a8bbf36318db4238a3ded7fa9df1b9a133f1cb79f8fc", size = 4525163 }, { url = "https://files.pythonhosted.org/packages/07/8b/34854bf11a83c248505c8cb0fcf8d3d0b459a2246c8809b967963b6b12ae/pillow-10.4.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:812f7342b0eee081eaec84d91423d1b4650bb9828eb53d8511bcef8ce5aecf1e", size = 4463100 }, { url = "https://files.pythonhosted.org/packages/78/63/0632aee4e82476d9cbe5200c0cdf9ba41ee04ed77887432845264d81116d/pillow-10.4.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:ac1452d2fbe4978c2eec89fb5a23b8387aba707ac72810d9490118817d9c0b46", size = 4592880 }, - { url = "https://files.pythonhosted.org/packages/df/56/b8663d7520671b4398b9d97e1ed9f583d4afcbefbda3c6188325e8c297bd/pillow-10.4.0-cp310-cp310-win32.whl", hash = "sha256:bcd5e41a859bf2e84fdc42f4edb7d9aba0a13d29a2abadccafad99de3feff984", size = 2235218 }, - { url = "https://files.pythonhosted.org/packages/f4/72/0203e94a91ddb4a9d5238434ae6c1ca10e610e8487036132ea9bf806ca2a/pillow-10.4.0-cp310-cp310-win_amd64.whl", hash = "sha256:ecd85a8d3e79cd7158dec1c9e5808e821feea088e2f69a974db5edf84dc53141", size = 2554487 }, - { url = "https://files.pythonhosted.org/packages/bd/52/7e7e93d7a6e4290543f17dc6f7d3af4bd0b3dd9926e2e8a35ac2282bc5f4/pillow-10.4.0-cp310-cp310-win_arm64.whl", hash = "sha256:ff337c552345e95702c5fde3158acb0625111017d0e5f24bf3acdb9cc16b90d1", size = 2243219 }, - { url = "https://files.pythonhosted.org/packages/a7/62/c9449f9c3043c37f73e7487ec4ef0c03eb9c9afc91a92b977a67b3c0bbc5/pillow-10.4.0-cp311-cp311-macosx_10_10_x86_64.whl", hash = "sha256:0a9ec697746f268507404647e531e92889890a087e03681a3606d9b920fbee3c", size = 3509265 }, - { url = "https://files.pythonhosted.org/packages/f4/5f/491dafc7bbf5a3cc1845dc0430872e8096eb9e2b6f8161509d124594ec2d/pillow-10.4.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:dfe91cb65544a1321e631e696759491ae04a2ea11d36715eca01ce07284738be", size = 3375655 }, { url = "https://files.pythonhosted.org/packages/73/d5/c4011a76f4207a3c151134cd22a1415741e42fa5ddecec7c0182887deb3d/pillow-10.4.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5dc6761a6efc781e6a1544206f22c80c3af4c8cf461206d46a1e6006e4429ff3", size = 4340304 }, { url = "https://files.pythonhosted.org/packages/ac/10/c67e20445a707f7a610699bba4fe050583b688d8cd2d202572b257f46600/pillow-10.4.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5e84b6cc6a4a3d76c153a6b19270b3526a5a8ed6b09501d3af891daa2a9de7d6", size = 4452804 }, { url = "https://files.pythonhosted.org/packages/a9/83/6523837906d1da2b269dee787e31df3b0acb12e3d08f024965a3e7f64665/pillow-10.4.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:bbc527b519bd3aa9d7f429d152fea69f9ad37c95f0b02aebddff592688998abe", size = 4365126 }, { url = "https://files.pythonhosted.org/packages/ba/e5/8c68ff608a4203085158cff5cc2a3c534ec384536d9438c405ed6370d080/pillow-10.4.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:76a911dfe51a36041f2e756b00f96ed84677cdeb75d25c767f296c1c1eda1319", size = 4533541 }, { url = "https://files.pythonhosted.org/packages/f4/7c/01b8dbdca5bc6785573f4cee96e2358b0918b7b2c7b60d8b6f3abf87a070/pillow-10.4.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:59291fb29317122398786c2d44427bbd1a6d7ff54017075b22be9d21aa59bd8d", size = 4471616 }, { url = "https://files.pythonhosted.org/packages/c8/57/2899b82394a35a0fbfd352e290945440e3b3785655a03365c0ca8279f351/pillow-10.4.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:416d3a5d0e8cfe4f27f574362435bc9bae57f679a7158e0096ad2beb427b8696", size = 4600802 }, - { url = "https://files.pythonhosted.org/packages/4d/d7/a44f193d4c26e58ee5d2d9db3d4854b2cfb5b5e08d360a5e03fe987c0086/pillow-10.4.0-cp311-cp311-win32.whl", hash = "sha256:7086cc1d5eebb91ad24ded9f58bec6c688e9f0ed7eb3dbbf1e4800280a896496", size = 2235213 }, - { url = "https://files.pythonhosted.org/packages/c1/d0/5866318eec2b801cdb8c82abf190c8343d8a1cd8bf5a0c17444a6f268291/pillow-10.4.0-cp311-cp311-win_amd64.whl", hash = "sha256:cbed61494057c0f83b83eb3a310f0bf774b09513307c434d4366ed64f4128a91", size = 2554498 }, - { url = "https://files.pythonhosted.org/packages/d4/c8/310ac16ac2b97e902d9eb438688de0d961660a87703ad1561fd3dfbd2aa0/pillow-10.4.0-cp311-cp311-win_arm64.whl", hash = "sha256:f5f0c3e969c8f12dd2bb7e0b15d5c468b51e5017e01e2e867335c81903046a22", size = 2243219 }, - { url = "https://files.pythonhosted.org/packages/05/cb/0353013dc30c02a8be34eb91d25e4e4cf594b59e5a55ea1128fde1e5f8ea/pillow-10.4.0-cp312-cp312-macosx_10_10_x86_64.whl", hash = "sha256:673655af3eadf4df6b5457033f086e90299fdd7a47983a13827acf7459c15d94", size = 3509350 }, - { url = "https://files.pythonhosted.org/packages/e7/cf/5c558a0f247e0bf9cec92bff9b46ae6474dd736f6d906315e60e4075f737/pillow-10.4.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:866b6942a92f56300012f5fbac71f2d610312ee65e22f1aa2609e491284e5597", size = 3374980 }, { url = "https://files.pythonhosted.org/packages/84/48/6e394b86369a4eb68b8a1382c78dc092245af517385c086c5094e3b34428/pillow-10.4.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:29dbdc4207642ea6aad70fbde1a9338753d33fb23ed6956e706936706f52dd80", size = 4343799 }, { url = "https://files.pythonhosted.org/packages/3b/f3/a8c6c11fa84b59b9df0cd5694492da8c039a24cd159f0f6918690105c3be/pillow-10.4.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bf2342ac639c4cf38799a44950bbc2dfcb685f052b9e262f446482afaf4bffca", size = 4459973 }, { url = "https://files.pythonhosted.org/packages/7d/1b/c14b4197b80150fb64453585247e6fb2e1d93761fa0fa9cf63b102fde822/pillow-10.4.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:f5b92f4d70791b4a67157321c4e8225d60b119c5cc9aee8ecf153aace4aad4ef", size = 4370054 }, { url = "https://files.pythonhosted.org/packages/55/77/40daddf677897a923d5d33329acd52a2144d54a9644f2a5422c028c6bf2d/pillow-10.4.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:86dcb5a1eb778d8b25659d5e4341269e8590ad6b4e8b44d9f4b07f8d136c414a", size = 4539484 }, { url = "https://files.pythonhosted.org/packages/40/54/90de3e4256b1207300fb2b1d7168dd912a2fb4b2401e439ba23c2b2cabde/pillow-10.4.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:780c072c2e11c9b2c7ca37f9a2ee8ba66f44367ac3e5c7832afcfe5104fd6d1b", size = 4477375 }, { url = "https://files.pythonhosted.org/packages/13/24/1bfba52f44193860918ff7c93d03d95e3f8748ca1de3ceaf11157a14cf16/pillow-10.4.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:37fb69d905be665f68f28a8bba3c6d3223c8efe1edf14cc4cfa06c241f8c81d9", size = 4608773 }, - { url = "https://files.pythonhosted.org/packages/55/04/5e6de6e6120451ec0c24516c41dbaf80cce1b6451f96561235ef2429da2e/pillow-10.4.0-cp312-cp312-win32.whl", hash = "sha256:7dfecdbad5c301d7b5bde160150b4db4c659cee2b69589705b6f8a0c509d9f42", size = 2235690 }, - { url = "https://files.pythonhosted.org/packages/74/0a/d4ce3c44bca8635bd29a2eab5aa181b654a734a29b263ca8efe013beea98/pillow-10.4.0-cp312-cp312-win_amd64.whl", hash = "sha256:1d846aea995ad352d4bdcc847535bd56e0fd88d36829d2c90be880ef1ee4668a", size = 2554951 }, - { url = "https://files.pythonhosted.org/packages/b5/ca/184349ee40f2e92439be9b3502ae6cfc43ac4b50bc4fc6b3de7957563894/pillow-10.4.0-cp312-cp312-win_arm64.whl", hash = "sha256:e553cad5179a66ba15bb18b353a19020e73a7921296a7979c4a2b7f6a5cd57f9", size = 2243427 }, - { url = "https://files.pythonhosted.org/packages/c3/00/706cebe7c2c12a6318aabe5d354836f54adff7156fd9e1bd6c89f4ba0e98/pillow-10.4.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:8bc1a764ed8c957a2e9cacf97c8b2b053b70307cf2996aafd70e91a082e70df3", size = 3525685 }, - { url = "https://files.pythonhosted.org/packages/cf/76/f658cbfa49405e5ecbfb9ba42d07074ad9792031267e782d409fd8fe7c69/pillow-10.4.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:6209bb41dc692ddfee4942517c19ee81b86c864b626dbfca272ec0f7cff5d9fb", size = 3374883 }, { url = "https://files.pythonhosted.org/packages/46/2b/99c28c4379a85e65378211971c0b430d9c7234b1ec4d59b2668f6299e011/pillow-10.4.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bee197b30783295d2eb680b311af15a20a8b24024a19c3a26431ff83eb8d1f70", size = 4339837 }, { url = "https://files.pythonhosted.org/packages/f1/74/b1ec314f624c0c43711fdf0d8076f82d9d802afd58f1d62c2a86878e8615/pillow-10.4.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1ef61f5dd14c300786318482456481463b9d6b91ebe5ef12f405afbba77ed0be", size = 4455562 }, { url = "https://files.pythonhosted.org/packages/4a/2a/4b04157cb7b9c74372fa867096a1607e6fedad93a44deeff553ccd307868/pillow-10.4.0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:297e388da6e248c98bc4a02e018966af0c5f92dfacf5a5ca22fa01cb3179bca0", size = 4366761 }, { url = "https://files.pythonhosted.org/packages/ac/7b/8f1d815c1a6a268fe90481232c98dd0e5fa8c75e341a75f060037bd5ceae/pillow-10.4.0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:e4db64794ccdf6cb83a59d73405f63adbe2a1887012e308828596100a0b2f6cc", size = 4536767 }, { url = "https://files.pythonhosted.org/packages/e5/77/05fa64d1f45d12c22c314e7b97398ffb28ef2813a485465017b7978b3ce7/pillow-10.4.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:bd2880a07482090a3bcb01f4265f1936a903d70bc740bfcb1fd4e8a2ffe5cf5a", size = 4477989 }, { url = "https://files.pythonhosted.org/packages/12/63/b0397cfc2caae05c3fb2f4ed1b4fc4fc878f0243510a7a6034ca59726494/pillow-10.4.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:4b35b21b819ac1dbd1233317adeecd63495f6babf21b7b2512d244ff6c6ce309", size = 4610255 }, - { url = "https://files.pythonhosted.org/packages/7b/f9/cfaa5082ca9bc4a6de66ffe1c12c2d90bf09c309a5f52b27759a596900e7/pillow-10.4.0-cp313-cp313-win32.whl", hash = "sha256:551d3fd6e9dc15e4c1eb6fc4ba2b39c0c7933fa113b220057a34f4bb3268a060", size = 2235603 }, - { url = "https://files.pythonhosted.org/packages/01/6a/30ff0eef6e0c0e71e55ded56a38d4859bf9d3634a94a88743897b5f96936/pillow-10.4.0-cp313-cp313-win_amd64.whl", hash = "sha256:030abdbe43ee02e0de642aee345efa443740aa4d828bfe8e2eb11922ea6a21ea", size = 2554972 }, - { url = "https://files.pythonhosted.org/packages/48/2c/2e0a52890f269435eee38b21c8218e102c621fe8d8df8b9dd06fabf879ba/pillow-10.4.0-cp313-cp313-win_arm64.whl", hash = "sha256:5b001114dd152cfd6b23befeb28d7aee43553e2402c9f159807bf55f33af8a8d", size = 2243375 }, - { url = "https://files.pythonhosted.org/packages/31/85/955fa5400fa8039921f630372cfe5056eed6e1b8e0430ee4507d7de48832/pillow-10.4.0-cp39-cp39-macosx_10_10_x86_64.whl", hash = "sha256:0ae24a547e8b711ccaaf99c9ae3cd975470e1a30caa80a6aaee9a2f19c05701d", size = 3509283 }, - { url = "https://files.pythonhosted.org/packages/23/9c/343827267eb28d41cd82b4180d33b10d868af9077abcec0af9793aa77d2d/pillow-10.4.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:298478fe4f77a4408895605f3482b6cc6222c018b2ce565c2b6b9c354ac3229b", size = 3375691 }, { url = "https://files.pythonhosted.org/packages/60/a3/7ebbeabcd341eab722896d1a5b59a3df98c4b4d26cf4b0385f8aa94296f7/pillow-10.4.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:134ace6dc392116566980ee7436477d844520a26a4b1bd4053f6f47d096997fd", size = 4328295 }, { url = "https://files.pythonhosted.org/packages/32/3f/c02268d0c6fb6b3958bdda673c17b315c821d97df29ae6969f20fb49388a/pillow-10.4.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:930044bb7679ab003b14023138b50181899da3f25de50e9dbee23b61b4de2126", size = 4440810 }, { url = "https://files.pythonhosted.org/packages/67/5d/1c93c8cc35f2fdd3d6cc7e4ad72d203902859a2867de6ad957d9b708eb8d/pillow-10.4.0-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:c76e5786951e72ed3686e122d14c5d7012f16c8303a674d18cdcd6d89557fc5b", size = 4352283 }, { url = "https://files.pythonhosted.org/packages/bc/a8/8655557c9c7202b8abbd001f61ff36711cefaf750debcaa1c24d154ef602/pillow-10.4.0-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:b2724fdb354a868ddf9a880cb84d102da914e99119211ef7ecbdc613b8c96b3c", size = 4521800 }, { url = "https://files.pythonhosted.org/packages/58/78/6f95797af64d137124f68af1bdaa13b5332da282b86031f6fa70cf368261/pillow-10.4.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:dbc6ae66518ab3c5847659e9988c3b60dc94ffb48ef9168656e0019a93dbf8a1", size = 4459177 }, { url = "https://files.pythonhosted.org/packages/8a/6d/2b3ce34f1c4266d79a78c9a51d1289a33c3c02833fe294ef0dcbb9cba4ed/pillow-10.4.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:06b2f7898047ae93fad74467ec3d28fe84f7831370e3c258afa533f81ef7f3df", size = 4589079 }, - { url = "https://files.pythonhosted.org/packages/e3/e0/456258c74da1ff5bf8ef1eab06a95ca994d8b9ed44c01d45c3f8cbd1db7e/pillow-10.4.0-cp39-cp39-win32.whl", hash = "sha256:7970285ab628a3779aecc35823296a7869f889b8329c16ad5a71e4901a3dc4ef", size = 2235247 }, - { url = "https://files.pythonhosted.org/packages/37/f8/bef952bdb32aa53741f58bf21798642209e994edc3f6598f337f23d5400a/pillow-10.4.0-cp39-cp39-win_amd64.whl", hash = "sha256:961a7293b2457b405967af9c77dcaa43cc1a8cd50d23c532e62d48ab6cdd56f5", size = 2554479 }, - { url = "https://files.pythonhosted.org/packages/bb/8e/805201619cad6651eef5fc1fdef913804baf00053461522fabbc5588ea12/pillow-10.4.0-cp39-cp39-win_arm64.whl", hash = "sha256:32cda9e3d601a52baccb2856b8ea1fc213c90b340c542dcef77140dfa3278a9e", size = 2243226 }, - { url = "https://files.pythonhosted.org/packages/38/30/095d4f55f3a053392f75e2eae45eba3228452783bab3d9a920b951ac495c/pillow-10.4.0-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:5b4815f2e65b30f5fbae9dfffa8636d992d49705723fe86a3661806e069352d4", size = 3493889 }, - { url = "https://files.pythonhosted.org/packages/f3/e8/4ff79788803a5fcd5dc35efdc9386af153569853767bff74540725b45863/pillow-10.4.0-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:8f0aef4ef59694b12cadee839e2ba6afeab89c0f39a3adc02ed51d109117b8da", size = 3346160 }, { url = "https://files.pythonhosted.org/packages/d7/ac/4184edd511b14f760c73f5bb8a5d6fd85c591c8aff7c2229677a355c4179/pillow-10.4.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9f4727572e2918acaa9077c919cbbeb73bd2b3ebcfe033b72f858fc9fbef0026", size = 3435020 }, { url = "https://files.pythonhosted.org/packages/da/21/1749cd09160149c0a246a81d646e05f35041619ce76f6493d6a96e8d1103/pillow-10.4.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ff25afb18123cea58a591ea0244b92eb1e61a1fd497bf6d6384f09bc3262ec3e", size = 3490539 }, { url = "https://files.pythonhosted.org/packages/b6/f5/f71fe1888b96083b3f6dfa0709101f61fc9e972c0c8d04e9d93ccef2a045/pillow-10.4.0-pp310-pypy310_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:dc3e2db6ba09ffd7d02ae9141cfa0ae23393ee7687248d46a7507b75d610f4f5", size = 3476125 }, { url = "https://files.pythonhosted.org/packages/96/b9/c0362c54290a31866c3526848583a2f45a535aa9d725fd31e25d318c805f/pillow-10.4.0-pp310-pypy310_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:02a2be69f9c9b8c1e97cf2713e789d4e398c751ecfd9967c18d0ce304efbf885", size = 3579373 }, - { url = "https://files.pythonhosted.org/packages/52/3b/ce7a01026a7cf46e5452afa86f97a5e88ca97f562cafa76570178ab56d8d/pillow-10.4.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:0755ffd4a0c6f267cccbae2e9903d95477ca2f77c4fcf3a3a09570001856c8a5", size = 2554661 }, - { url = "https://files.pythonhosted.org/packages/e1/1f/5a9fcd6ced51633c22481417e11b1b47d723f64fb536dfd67c015eb7f0ab/pillow-10.4.0-pp39-pypy39_pp73-macosx_10_15_x86_64.whl", hash = "sha256:a02364621fe369e06200d4a16558e056fe2805d3468350df3aef21e00d26214b", size = 3493850 }, - { url = "https://files.pythonhosted.org/packages/cb/e6/3ea4755ed5320cb62aa6be2f6de47b058c6550f752dd050e86f694c59798/pillow-10.4.0-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:1b5dea9831a90e9d0721ec417a80d4cbd7022093ac38a568db2dd78363b00908", size = 3346118 }, { url = "https://files.pythonhosted.org/packages/0a/22/492f9f61e4648422b6ca39268ec8139277a5b34648d28f400faac14e0f48/pillow-10.4.0-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9b885f89040bb8c4a1573566bbb2f44f5c505ef6e74cec7ab9068c900047f04b", size = 3434958 }, { url = "https://files.pythonhosted.org/packages/f9/19/559a48ad4045704bb0547965b9a9345f5cd461347d977a56d178db28819e/pillow-10.4.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:87dd88ded2e6d74d31e1e0a99a726a6765cda32d00ba72dc37f0651f306daaa8", size = 3490340 }, { url = "https://files.pythonhosted.org/packages/d9/de/cebaca6fb79905b3a1aa0281d238769df3fb2ede34fd7c0caa286575915a/pillow-10.4.0-pp39-pypy39_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:2db98790afc70118bd0255c2eeb465e9767ecf1f3c25f9a1abb8ffc8cfd1fe0a", size = 3476048 }, { url = "https://files.pythonhosted.org/packages/71/f0/86d5b2f04693b0116a01d75302b0a307800a90d6c351a8aa4f8ae76cd499/pillow-10.4.0-pp39-pypy39_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:f7baece4ce06bade126fb84b8af1c33439a76d8a6fd818970215e0560ca28c27", size = 3579366 }, - { url = "https://files.pythonhosted.org/packages/37/ae/2dbfc38cc4fd14aceea14bc440d5151b21f64c4c3ba3f6f4191610b7ee5d/pillow-10.4.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:cfdd747216947628af7b259d274771d84db2268ca062dd5faf373639d00113a3", size = 2554652 }, ] [[package]] @@ -961,15 +844,9 @@ version = "6.0.0" source = { registry = "https://pypi.org/simple" } sdist = { url = "https://files.pythonhosted.org/packages/18/c7/8c6872f7372eb6a6b2e4708b88419fb46b857f7a2e1892966b851cc79fc9/psutil-6.0.0.tar.gz", hash = "sha256:8faae4f310b6d969fa26ca0545338b21f73c6b15db7c4a8d934a5482faa818f2", size = 508067 } wheels = [ - { url = "https://files.pythonhosted.org/packages/c5/66/78c9c3020f573c58101dc43a44f6855d01bbbd747e24da2f0c4491200ea3/psutil-6.0.0-cp27-none-win32.whl", hash = "sha256:02b69001f44cc73c1c5279d02b30a817e339ceb258ad75997325e0e6169d8b35", size = 249766 }, - { url = "https://files.pythonhosted.org/packages/e1/3f/2403aa9558bea4d3854b0e5e567bc3dd8e9fbc1fc4453c0aa9aafeb75467/psutil-6.0.0-cp27-none-win_amd64.whl", hash = "sha256:21f1fb635deccd510f69f485b87433460a603919b45e2a324ad65b0cc74f8fb1", size = 253024 }, - { url = "https://files.pythonhosted.org/packages/0b/37/f8da2fbd29690b3557cca414c1949f92162981920699cd62095a984983bf/psutil-6.0.0-cp36-abi3-macosx_10_9_x86_64.whl", hash = "sha256:c588a7e9b1173b6e866756dde596fd4cad94f9399daf99ad8c3258b3cb2b47a0", size = 250961 }, { url = "https://files.pythonhosted.org/packages/35/56/72f86175e81c656a01c4401cd3b1c923f891b31fbcebe98985894176d7c9/psutil-6.0.0-cp36-abi3-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6ed2440ada7ef7d0d608f20ad89a04ec47d2d3ab7190896cd62ca5fc4fe08bf0", size = 287478 }, { url = "https://files.pythonhosted.org/packages/19/74/f59e7e0d392bc1070e9a70e2f9190d652487ac115bb16e2eff6b22ad1d24/psutil-6.0.0-cp36-abi3-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5fd9a97c8e94059b0ef54a7d4baf13b405011176c3b6ff257c247cae0d560ecd", size = 290455 }, { url = "https://files.pythonhosted.org/packages/cd/5f/60038e277ff0a9cc8f0c9ea3d0c5eb6ee1d2470ea3f9389d776432888e47/psutil-6.0.0-cp36-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e2e8d0054fc88153ca0544f5c4d554d42e33df2e009c4ff42284ac9ebdef4132", size = 292046 }, - { url = "https://files.pythonhosted.org/packages/8b/20/2ff69ad9c35c3df1858ac4e094f20bd2374d33c8643cf41da8fd7cdcb78b/psutil-6.0.0-cp37-abi3-win32.whl", hash = "sha256:a495580d6bae27291324fe60cea0b5a7c23fa36a7cd35035a16d93bdcf076b9d", size = 253560 }, - { url = "https://files.pythonhosted.org/packages/73/44/561092313ae925f3acfaace6f9ddc4f6a9c748704317bad9c8c8f8a36a79/psutil-6.0.0-cp37-abi3-win_amd64.whl", hash = "sha256:33ea5e1c975250a720b3a6609c490db40dae5d83a4eb315170c4fe0d8b1f34b3", size = 257399 }, - { url = "https://files.pythonhosted.org/packages/7c/06/63872a64c312a24fb9b4af123ee7007a306617da63ff13bcc1432386ead7/psutil-6.0.0-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:ffe7fc9b6b36beadc8c322f84e1caff51e8703b88eee1da46d1e3a6ae11b4fd0", size = 251988 }, ] [[package]] @@ -1022,8 +899,6 @@ dependencies = [ ] sdist = { url = "https://files.pythonhosted.org/packages/e2/aa/6b6a9b9f8537b872f552ddd46dd3da230367754b6f707b8e1e963f515ea3/pydantic_core-2.23.4.tar.gz", hash = "sha256:2584f7cf844ac4d970fba483a717dbe10c1c1c96a969bf65d61ffe94df1b2863", size = 402156 } wheels = [ - { url = "https://files.pythonhosted.org/packages/5c/8b/d3ae387f66277bd8104096d6ec0a145f4baa2966ebb2cad746c0920c9526/pydantic_core-2.23.4-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:b10bd51f823d891193d4717448fab065733958bdb6a6b351967bd349d48d5c9b", size = 1867835 }, - { url = "https://files.pythonhosted.org/packages/46/76/f68272e4c3a7df8777798282c5e47d508274917f29992d84e1898f8908c7/pydantic_core-2.23.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:4fc714bdbfb534f94034efaa6eadd74e5b93c8fa6315565a222f7b6f42ca1166", size = 1776689 }, { url = "https://files.pythonhosted.org/packages/cc/69/5f945b4416f42ea3f3bc9d2aaec66c76084a6ff4ff27555bf9415ab43189/pydantic_core-2.23.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:63e46b3169866bd62849936de036f901a9356e36376079b05efa83caeaa02ceb", size = 1800748 }, { url = "https://files.pythonhosted.org/packages/50/ab/891a7b0054bcc297fb02d44d05c50e68154e31788f2d9d41d0b72c89fdf7/pydantic_core-2.23.4-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:ed1a53de42fbe34853ba90513cea21673481cd81ed1be739f7f2efb931b24916", size = 1806469 }, { url = "https://files.pythonhosted.org/packages/31/7c/6e3fa122075d78f277a8431c4c608f061881b76c2b7faca01d317ee39b5d/pydantic_core-2.23.4-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:cfdd16ab5e59fc31b5e906d1a3f666571abc367598e3e02c83403acabc092e07", size = 2002246 }, @@ -1032,10 +907,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/91/75/984740c17f12c3ce18b5a2fcc4bdceb785cce7df1511a4ce89bca17c7e2d/pydantic_core-2.23.4-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:f09e2ff1f17c2b51f2bc76d1cc33da96298f0a036a137f5440ab3ec5360b624f", size = 1921437 }, { url = "https://files.pythonhosted.org/packages/a0/74/13c5f606b64d93f0721e7768cd3e8b2102164866c207b8cd6f90bb15d24f/pydantic_core-2.23.4-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:e38e63e6f3d1cec5a27e0afe90a085af8b6806ee208b33030e65b6516353f1a3", size = 1966129 }, { url = "https://files.pythonhosted.org/packages/18/03/9c4aa5919457c7b57a016c1ab513b1a926ed9b2bb7915bf8e506bf65c34b/pydantic_core-2.23.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:0dbd8dbed2085ed23b5c04afa29d8fd2771674223135dc9bc937f3c09284d071", size = 2110908 }, - { url = "https://files.pythonhosted.org/packages/92/2c/053d33f029c5dc65e5cf44ff03ceeefb7cce908f8f3cca9265e7f9b540c8/pydantic_core-2.23.4-cp310-none-win32.whl", hash = "sha256:6531b7ca5f951d663c339002e91aaebda765ec7d61b7d1e3991051906ddde119", size = 1735278 }, - { url = "https://files.pythonhosted.org/packages/de/81/7dfe464eca78d76d31dd661b04b5f2036ec72ea8848dd87ab7375e185c23/pydantic_core-2.23.4-cp310-none-win_amd64.whl", hash = "sha256:7c9129eb40958b3d4500fa2467e6a83356b3b61bfff1b414c7361d9220f9ae8f", size = 1917453 }, - { url = "https://files.pythonhosted.org/packages/5d/30/890a583cd3f2be27ecf32b479d5d615710bb926d92da03e3f7838ff3e58b/pydantic_core-2.23.4-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:77733e3892bb0a7fa797826361ce8a9184d25c8dffaec60b7ffe928153680ba8", size = 1865160 }, - { url = "https://files.pythonhosted.org/packages/1d/9a/b634442e1253bc6889c87afe8bb59447f106ee042140bd57680b3b113ec7/pydantic_core-2.23.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:1b84d168f6c48fabd1f2027a3d1bdfe62f92cade1fb273a5d68e621da0e44e6d", size = 1776777 }, { url = "https://files.pythonhosted.org/packages/75/9a/7816295124a6b08c24c96f9ce73085032d8bcbaf7e5a781cd41aa910c891/pydantic_core-2.23.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:df49e7a0861a8c36d089c1ed57d308623d60416dab2647a4a17fe050ba85de0e", size = 1799244 }, { url = "https://files.pythonhosted.org/packages/a9/8f/89c1405176903e567c5f99ec53387449e62f1121894aa9fc2c4fdc51a59b/pydantic_core-2.23.4-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:ff02b6d461a6de369f07ec15e465a88895f3223eb75073ffea56b84d9331f607", size = 1805307 }, { url = "https://files.pythonhosted.org/packages/d5/a5/1a194447d0da1ef492e3470680c66048fef56fc1f1a25cafbea4bc1d1c48/pydantic_core-2.23.4-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:996a38a83508c54c78a5f41456b0103c30508fed9abcad0a59b876d7398f25fd", size = 2000663 }, @@ -1044,10 +915,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/1b/aa/98e190f8745d5ec831f6d5449344c48c0627ac5fed4e5340a44b74878f8e/pydantic_core-2.23.4-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:6f783e0ec4803c787bcea93e13e9932edab72068f68ecffdf86a99fd5918878b", size = 1919967 }, { url = "https://files.pythonhosted.org/packages/ae/35/b6e00b6abb2acfee3e8f85558c02a0822e9a8b2f2d812ea8b9079b118ba0/pydantic_core-2.23.4-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:d0776dea117cf5272382634bd2a5c1b6eb16767c223c6a5317cd3e2a757c61a0", size = 1964291 }, { url = "https://files.pythonhosted.org/packages/13/46/7bee6d32b69191cd649bbbd2361af79c472d72cb29bb2024f0b6e350ba06/pydantic_core-2.23.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:d5f7a395a8cf1621939692dba2a6b6a830efa6b3cee787d82c7de1ad2930de64", size = 2109666 }, - { url = "https://files.pythonhosted.org/packages/39/ef/7b34f1b122a81b68ed0a7d0e564da9ccdc9a2924c8d6c6b5b11fa3a56970/pydantic_core-2.23.4-cp311-none-win32.whl", hash = "sha256:74b9127ffea03643e998e0c5ad9bd3811d3dac8c676e47db17b0ee7c3c3bf35f", size = 1732940 }, - { url = "https://files.pythonhosted.org/packages/2f/76/37b7e76c645843ff46c1d73e046207311ef298d3f7b2f7d8f6ac60113071/pydantic_core-2.23.4-cp311-none-win_amd64.whl", hash = "sha256:98d134c954828488b153d88ba1f34e14259284f256180ce659e8d83e9c05eaa3", size = 1916804 }, - { url = "https://files.pythonhosted.org/packages/74/7b/8e315f80666194b354966ec84b7d567da77ad927ed6323db4006cf915f3f/pydantic_core-2.23.4-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:f3e0da4ebaef65158d4dfd7d3678aad692f7666877df0002b8a522cdf088f231", size = 1856459 }, - { url = "https://files.pythonhosted.org/packages/14/de/866bdce10ed808323d437612aca1ec9971b981e1c52e5e42ad9b8e17a6f6/pydantic_core-2.23.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:f69a8e0b033b747bb3e36a44e7732f0c99f7edd5cea723d45bc0d6e95377ffee", size = 1770007 }, { url = "https://files.pythonhosted.org/packages/dc/69/8edd5c3cd48bb833a3f7ef9b81d7666ccddd3c9a635225214e044b6e8281/pydantic_core-2.23.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:723314c1d51722ab28bfcd5240d858512ffd3116449c557a1336cbe3919beb87", size = 1790245 }, { url = "https://files.pythonhosted.org/packages/80/33/9c24334e3af796ce80d2274940aae38dd4e5676298b4398eff103a79e02d/pydantic_core-2.23.4-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:bb2802e667b7051a1bebbfe93684841cc9351004e2badbd6411bf357ab8d5ac8", size = 1801260 }, { url = "https://files.pythonhosted.org/packages/a5/6f/e9567fd90104b79b101ca9d120219644d3314962caa7948dd8b965e9f83e/pydantic_core-2.23.4-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d18ca8148bebe1b0a382a27a8ee60350091a6ddaf475fa05ef50dc35b5df6327", size = 1996872 }, @@ -1056,10 +923,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/89/4d/3079d00c47f22c9a9a8220db088b309ad6e600a73d7a69473e3a8e5e3ea3/pydantic_core-2.23.4-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:68665f4c17edcceecc112dfed5dbe6f92261fb9d6054b47d01bf6371a6196126", size = 1917453 }, { url = "https://files.pythonhosted.org/packages/e9/88/9df5b7ce880a4703fcc2d76c8c2d8eb9f861f79d0c56f4b8f5f2607ccec8/pydantic_core-2.23.4-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:20152074317d9bed6b7a95ade3b7d6054845d70584216160860425f4fbd5ee9e", size = 1968793 }, { url = "https://files.pythonhosted.org/packages/e3/b9/41f7efe80f6ce2ed3ee3c2dcfe10ab7adc1172f778cc9659509a79518c43/pydantic_core-2.23.4-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:9261d3ce84fa1d38ed649c3638feefeae23d32ba9182963e465d58d62203bd24", size = 2116872 }, - { url = "https://files.pythonhosted.org/packages/63/08/b59b7a92e03dd25554b0436554bf23e7c29abae7cce4b1c459cd92746811/pydantic_core-2.23.4-cp312-none-win32.whl", hash = "sha256:4ba762ed58e8d68657fc1281e9bb72e1c3e79cc5d464be146e260c541ec12d84", size = 1738535 }, - { url = "https://files.pythonhosted.org/packages/88/8d/479293e4d39ab409747926eec4329de5b7129beaedc3786eca070605d07f/pydantic_core-2.23.4-cp312-none-win_amd64.whl", hash = "sha256:97df63000f4fea395b2824da80e169731088656d1818a11b95f3b173747b6cd9", size = 1917992 }, - { url = "https://files.pythonhosted.org/packages/ad/ef/16ee2df472bf0e419b6bc68c05bf0145c49247a1095e85cee1463c6a44a1/pydantic_core-2.23.4-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:7530e201d10d7d14abce4fb54cfe5b94a0aefc87da539d0346a484ead376c3cc", size = 1856143 }, - { url = "https://files.pythonhosted.org/packages/da/fa/bc3dbb83605669a34a93308e297ab22be82dfb9dcf88c6cf4b4f264e0a42/pydantic_core-2.23.4-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:df933278128ea1cd77772673c73954e53a1c95a4fdf41eef97c2b779271bd0bd", size = 1770063 }, { url = "https://files.pythonhosted.org/packages/4e/48/e813f3bbd257a712303ebdf55c8dc46f9589ec74b384c9f652597df3288d/pydantic_core-2.23.4-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0cb3da3fd1b6a5d0279a01877713dbda118a2a4fc6f0d821a57da2e464793f05", size = 1790013 }, { url = "https://files.pythonhosted.org/packages/b4/e0/56eda3a37929a1d297fcab1966db8c339023bcca0b64c5a84896db3fcc5c/pydantic_core-2.23.4-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:42c6dcb030aefb668a2b7009c85b27f90e51e6a3b4d5c9bc4c57631292015b0d", size = 1801077 }, { url = "https://files.pythonhosted.org/packages/04/be/5e49376769bfbf82486da6c5c1683b891809365c20d7c7e52792ce4c71f3/pydantic_core-2.23.4-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:696dd8d674d6ce621ab9d45b205df149399e4bb9aa34102c970b721554828510", size = 1996782 }, @@ -1068,10 +931,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/7c/45/bdce5779b59f468bdf262a5bc9eecbae87f271c51aef628d8c073b4b4b4c/pydantic_core-2.23.4-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:0dff76e0602ca7d4cdaacc1ac4c005e0ce0dcfe095d5b5259163a80d3a10d327", size = 1916994 }, { url = "https://files.pythonhosted.org/packages/d8/fa/c648308fe711ee1f88192cad6026ab4f925396d1293e8356de7e55be89b5/pydantic_core-2.23.4-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:7d32706badfe136888bdea71c0def994644e09fff0bfe47441deaed8e96fdbc6", size = 1968877 }, { url = "https://files.pythonhosted.org/packages/16/16/b805c74b35607d24d37103007f899abc4880923b04929547ae68d478b7f4/pydantic_core-2.23.4-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:ed541d70698978a20eb63d8c5d72f2cc6d7079d9d90f6b50bad07826f1320f5f", size = 2116814 }, - { url = "https://files.pythonhosted.org/packages/d1/58/5305e723d9fcdf1c5a655e6a4cc2a07128bf644ff4b1d98daf7a9dbf57da/pydantic_core-2.23.4-cp313-none-win32.whl", hash = "sha256:3d5639516376dce1940ea36edf408c554475369f5da2abd45d44621cb616f769", size = 1738360 }, - { url = "https://files.pythonhosted.org/packages/a5/ae/e14b0ff8b3f48e02394d8acd911376b7b66e164535687ef7dc24ea03072f/pydantic_core-2.23.4-cp313-none-win_amd64.whl", hash = "sha256:5a1504ad17ba4210df3a045132a7baeeba5a200e930f57512ee02909fc5c4cb5", size = 1919411 }, - { url = "https://files.pythonhosted.org/packages/7a/04/2580b2deaae37b3e30fc30c54298be938b973990b23612d6b61c7bdd01c7/pydantic_core-2.23.4-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:a4fa4fc04dff799089689f4fd502ce7d59de529fc2f40a2c8836886c03e0175a", size = 1868200 }, - { url = "https://files.pythonhosted.org/packages/39/6e/e311bd0751505350f0cdcee3077841eb1f9253c5a1ddbad048cd9fbf7c6e/pydantic_core-2.23.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:0a7df63886be5e270da67e0966cf4afbae86069501d35c8c1b3b6c168f42cb36", size = 1749316 }, { url = "https://files.pythonhosted.org/packages/d0/b4/95b5eb47c6dc8692508c3ca04a1f8d6f0884c9dacb34cf3357595cbe73be/pydantic_core-2.23.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:dcedcd19a557e182628afa1d553c3895a9f825b936415d0dbd3cd0bbcfd29b4b", size = 1800880 }, { url = "https://files.pythonhosted.org/packages/da/79/41c4f817acd7f42d94cd1e16526c062a7b089f66faed4bd30852314d9a66/pydantic_core-2.23.4-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:5f54b118ce5de9ac21c363d9b3caa6c800341e8c47a508787e5868c6b79c9323", size = 1807077 }, { url = "https://files.pythonhosted.org/packages/fb/53/d13d1eb0a97d5c06cf7a225935d471e9c241afd389a333f40c703f214973/pydantic_core-2.23.4-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:86d2f57d3e1379a9525c5ab067b27dbb8a0642fb5d454e17a9ac434f9ce523e3", size = 2002859 }, @@ -1080,24 +939,16 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/0f/36/d4ae869e473c3c7868e1cd1e2a1b9e13bce5cd1a7d287f6ac755a0b1575e/pydantic_core-2.23.4-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:9a6b5099eeec78827553827f4c6b8615978bb4b6a88e5d9b93eddf8bb6790f55", size = 1921680 }, { url = "https://files.pythonhosted.org/packages/0d/f8/eed5c65b80c4ac4494117e2101973b45fc655774ef647d17dde40a70f7d2/pydantic_core-2.23.4-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:e55541f756f9b3ee346b840103f32779c695a19826a4c442b7954550a0972040", size = 1966093 }, { url = "https://files.pythonhosted.org/packages/e8/c8/1d42ce51d65e571ab53d466cae83434325a126811df7ce4861d9d97bee4b/pydantic_core-2.23.4-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:a5c7ba8ffb6d6f8f2ab08743be203654bb1aaa8c9dcb09f82ddd34eadb695605", size = 2111437 }, - { url = "https://files.pythonhosted.org/packages/aa/c9/7fea9d13383c2ec6865919e09cffe44ab77e911eb281b53a4deaafd4c8e8/pydantic_core-2.23.4-cp39-none-win32.whl", hash = "sha256:37b0fe330e4a58d3c58b24d91d1eb102aeec675a3db4c292ec3928ecd892a9a6", size = 1735049 }, - { url = "https://files.pythonhosted.org/packages/98/95/dd7045c4caa2b73d0bf3b989d66b23cfbb7a0ef14ce99db15677a000a953/pydantic_core-2.23.4-cp39-none-win_amd64.whl", hash = "sha256:1498bec4c05c9c787bde9125cfdcc63a41004ff167f495063191b863399b1a29", size = 1920180 }, - { url = "https://files.pythonhosted.org/packages/13/a9/5d582eb3204464284611f636b55c0a7410d748ff338756323cb1ce721b96/pydantic_core-2.23.4-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:f455ee30a9d61d3e1a15abd5068827773d6e4dc513e795f380cdd59932c782d5", size = 1857135 }, - { url = "https://files.pythonhosted.org/packages/2c/57/faf36290933fe16717f97829eabfb1868182ac495f99cf0eda9f59687c9d/pydantic_core-2.23.4-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:1e90d2e3bd2c3863d48525d297cd143fe541be8bbf6f579504b9712cb6b643ec", size = 1740583 }, { url = "https://files.pythonhosted.org/packages/91/7c/d99e3513dc191c4fec363aef1bf4c8af9125d8fa53af7cb97e8babef4e40/pydantic_core-2.23.4-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2e203fdf807ac7e12ab59ca2bfcabb38c7cf0b33c41efeb00f8e5da1d86af480", size = 1793637 }, { url = "https://files.pythonhosted.org/packages/29/18/812222b6d18c2d13eebbb0f7cdc170a408d9ced65794fdb86147c77e1982/pydantic_core-2.23.4-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e08277a400de01bc72436a0ccd02bdf596631411f592ad985dcee21445bd0068", size = 1941963 }, { url = "https://files.pythonhosted.org/packages/0f/36/c1f3642ac3f05e6bb4aec3ffc399fa3f84895d259cf5f0ce3054b7735c29/pydantic_core-2.23.4-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:f220b0eea5965dec25480b6333c788fb72ce5f9129e8759ef876a1d805d00801", size = 1915332 }, { url = "https://files.pythonhosted.org/packages/f7/ca/9c0854829311fb446020ebb540ee22509731abad886d2859c855dd29b904/pydantic_core-2.23.4-pp310-pypy310_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:d06b0c8da4f16d1d1e352134427cb194a0a6e19ad5db9161bf32b2113409e728", size = 1957926 }, { url = "https://files.pythonhosted.org/packages/c0/1c/7836b67c42d0cd4441fcd9fafbf6a027ad4b79b6559f80cf11f89fd83648/pydantic_core-2.23.4-pp310-pypy310_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:ba1a0996f6c2773bd83e63f18914c1de3c9dd26d55f4ac302a7efe93fb8e7433", size = 2100342 }, - { url = "https://files.pythonhosted.org/packages/a9/f9/b6bcaf874f410564a78908739c80861a171788ef4d4f76f5009656672dfe/pydantic_core-2.23.4-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:9a5bce9d23aac8f0cf0836ecfc033896aa8443b501c58d0602dbfd5bd5b37753", size = 1920344 }, - { url = "https://files.pythonhosted.org/packages/32/fd/ac9cdfaaa7cf2d32590b807d900612b39acb25e5527c3c7e482f0553025b/pydantic_core-2.23.4-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:78ddaaa81421a29574a682b3179d4cf9e6d405a09b99d93ddcf7e5239c742e21", size = 1857850 }, - { url = "https://files.pythonhosted.org/packages/08/fe/038f4b2bcae325ea643c8ad353191187a4c92a9c3b913b139289a6f2ef04/pydantic_core-2.23.4-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:883a91b5dd7d26492ff2f04f40fbb652de40fcc0afe07e8129e8ae779c2110eb", size = 1740265 }, { url = "https://files.pythonhosted.org/packages/51/14/b215c9c3cbd1edaaea23014d4b3304260823f712d3fdee52549b19b25d62/pydantic_core-2.23.4-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:88ad334a15b32a791ea935af224b9de1bf99bcd62fabf745d5f3442199d86d59", size = 1793912 }, { url = "https://files.pythonhosted.org/packages/62/de/2c3ad79b63ba564878cbce325be725929ba50089cd5156f89ea5155cb9b3/pydantic_core-2.23.4-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:233710f069d251feb12a56da21e14cca67994eab08362207785cf8c598e74577", size = 1942870 }, { url = "https://files.pythonhosted.org/packages/cb/55/c222af19e4644c741b3f3fe4fd8bbb6b4cdca87d8a49258b61cf7826b19e/pydantic_core-2.23.4-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:19442362866a753485ba5e4be408964644dd6a09123d9416c54cd49171f50744", size = 1915610 }, { url = "https://files.pythonhosted.org/packages/c4/7a/9a8760692a6f76bb54bcd43f245ff3d8b603db695899bbc624099c00af80/pydantic_core-2.23.4-pp39-pypy39_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:624e278a7d29b6445e4e813af92af37820fafb6dcc55c012c834f9e26f9aaaef", size = 1958403 }, { url = "https://files.pythonhosted.org/packages/4c/91/9b03166feb914bb5698e2f6499e07c2617e2eebf69f9374d0358d7eb2009/pydantic_core-2.23.4-pp39-pypy39_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:f5ef8f42bec47f21d07668a043f077d507e5bf4e668d5c6dfe6aaba89de1a5b8", size = 2101154 }, - { url = "https://files.pythonhosted.org/packages/1d/d9/1d7ecb98318da4cb96986daaf0e20d66f1651d0aeb9e2d4435b916ce031d/pydantic_core-2.23.4-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:aea443fffa9fbe3af1a9ba721a87f926fe548d32cab71d188a6ede77d0ff244e", size = 1920855 }, ] [[package]] @@ -1158,51 +1009,31 @@ version = "6.0.2" source = { registry = "https://pypi.org/simple" } sdist = { url = "https://files.pythonhosted.org/packages/54/ed/79a089b6be93607fa5cdaedf301d7dfb23af5f25c398d5ead2525b063e17/pyyaml-6.0.2.tar.gz", hash = "sha256:d584d9ec91ad65861cc08d42e834324ef890a082e591037abe114850ff7bbc3e", size = 130631 } wheels = [ - { url = "https://files.pythonhosted.org/packages/9b/95/a3fac87cb7158e231b5a6012e438c647e1a87f09f8e0d123acec8ab8bf71/PyYAML-6.0.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:0a9a2848a5b7feac301353437eb7d5957887edbf81d56e903999a75a3d743086", size = 184199 }, - { url = "https://files.pythonhosted.org/packages/c7/7a/68bd47624dab8fd4afbfd3c48e3b79efe09098ae941de5b58abcbadff5cb/PyYAML-6.0.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:29717114e51c84ddfba879543fb232a6ed60086602313ca38cce623c1d62cfbf", size = 171758 }, { url = "https://files.pythonhosted.org/packages/49/ee/14c54df452143b9ee9f0f29074d7ca5516a36edb0b4cc40c3f280131656f/PyYAML-6.0.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8824b5a04a04a047e72eea5cec3bc266db09e35de6bdfe34c9436ac5ee27d237", size = 718463 }, { url = "https://files.pythonhosted.org/packages/4d/61/de363a97476e766574650d742205be468921a7b532aa2499fcd886b62530/PyYAML-6.0.2-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7c36280e6fb8385e520936c3cb3b8042851904eba0e58d277dca80a5cfed590b", size = 719280 }, { url = "https://files.pythonhosted.org/packages/6b/4e/1523cb902fd98355e2e9ea5e5eb237cbc5f3ad5f3075fa65087aa0ecb669/PyYAML-6.0.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ec031d5d2feb36d1d1a24380e4db6d43695f3748343d99434e6f5f9156aaa2ed", size = 751239 }, { url = "https://files.pythonhosted.org/packages/b7/33/5504b3a9a4464893c32f118a9cc045190a91637b119a9c881da1cf6b7a72/PyYAML-6.0.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:936d68689298c36b53b29f23c6dbb74de12b4ac12ca6cfe0e047bedceea56180", size = 695802 }, { url = "https://files.pythonhosted.org/packages/5c/20/8347dcabd41ef3a3cdc4f7b7a2aff3d06598c8779faa189cdbf878b626a4/PyYAML-6.0.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:23502f431948090f597378482b4812b0caae32c22213aecf3b55325e049a6c68", size = 720527 }, - { url = "https://files.pythonhosted.org/packages/be/aa/5afe99233fb360d0ff37377145a949ae258aaab831bde4792b32650a4378/PyYAML-6.0.2-cp310-cp310-win32.whl", hash = "sha256:2e99c6826ffa974fe6e27cdb5ed0021786b03fc98e5ee3c5bfe1fd5015f42b99", size = 144052 }, - { url = "https://files.pythonhosted.org/packages/b5/84/0fa4b06f6d6c958d207620fc60005e241ecedceee58931bb20138e1e5776/PyYAML-6.0.2-cp310-cp310-win_amd64.whl", hash = "sha256:a4d3091415f010369ae4ed1fc6b79def9416358877534caf6a0fdd2146c87a3e", size = 161774 }, - { url = "https://files.pythonhosted.org/packages/f8/aa/7af4e81f7acba21a4c6be026da38fd2b872ca46226673c89a758ebdc4fd2/PyYAML-6.0.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:cc1c1159b3d456576af7a3e4d1ba7e6924cb39de8f67111c735f6fc832082774", size = 184612 }, - { url = "https://files.pythonhosted.org/packages/8b/62/b9faa998fd185f65c1371643678e4d58254add437edb764a08c5a98fb986/PyYAML-6.0.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:1e2120ef853f59c7419231f3bf4e7021f1b936f6ebd222406c3b60212205d2ee", size = 172040 }, { url = "https://files.pythonhosted.org/packages/ad/0c/c804f5f922a9a6563bab712d8dcc70251e8af811fce4524d57c2c0fd49a4/PyYAML-6.0.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5d225db5a45f21e78dd9358e58a98702a0302f2659a3c6cd320564b75b86f47c", size = 736829 }, { url = "https://files.pythonhosted.org/packages/51/16/6af8d6a6b210c8e54f1406a6b9481febf9c64a3109c541567e35a49aa2e7/PyYAML-6.0.2-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5ac9328ec4831237bec75defaf839f7d4564be1e6b25ac710bd1a96321cc8317", size = 764167 }, { url = "https://files.pythonhosted.org/packages/75/e4/2c27590dfc9992f73aabbeb9241ae20220bd9452df27483b6e56d3975cc5/PyYAML-6.0.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3ad2a3decf9aaba3d29c8f537ac4b243e36bef957511b4766cb0057d32b0be85", size = 762952 }, { url = "https://files.pythonhosted.org/packages/9b/97/ecc1abf4a823f5ac61941a9c00fe501b02ac3ab0e373c3857f7d4b83e2b6/PyYAML-6.0.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:ff3824dc5261f50c9b0dfb3be22b4567a6f938ccce4587b38952d85fd9e9afe4", size = 735301 }, { url = "https://files.pythonhosted.org/packages/45/73/0f49dacd6e82c9430e46f4a027baa4ca205e8b0a9dce1397f44edc23559d/PyYAML-6.0.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:797b4f722ffa07cc8d62053e4cff1486fa6dc094105d13fea7b1de7d8bf71c9e", size = 756638 }, - { url = "https://files.pythonhosted.org/packages/22/5f/956f0f9fc65223a58fbc14459bf34b4cc48dec52e00535c79b8db361aabd/PyYAML-6.0.2-cp311-cp311-win32.whl", hash = "sha256:11d8f3dd2b9c1207dcaf2ee0bbbfd5991f571186ec9cc78427ba5bd32afae4b5", size = 143850 }, - { url = "https://files.pythonhosted.org/packages/ed/23/8da0bbe2ab9dcdd11f4f4557ccaf95c10b9811b13ecced089d43ce59c3c8/PyYAML-6.0.2-cp311-cp311-win_amd64.whl", hash = "sha256:e10ce637b18caea04431ce14fabcf5c64a1c61ec9c56b071a4b7ca131ca52d44", size = 161980 }, - { url = "https://files.pythonhosted.org/packages/86/0c/c581167fc46d6d6d7ddcfb8c843a4de25bdd27e4466938109ca68492292c/PyYAML-6.0.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:c70c95198c015b85feafc136515252a261a84561b7b1d51e3384e0655ddf25ab", size = 183873 }, - { url = "https://files.pythonhosted.org/packages/a8/0c/38374f5bb272c051e2a69281d71cba6fdb983413e6758b84482905e29a5d/PyYAML-6.0.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:ce826d6ef20b1bc864f0a68340c8b3287705cae2f8b4b1d932177dcc76721725", size = 173302 }, { url = "https://files.pythonhosted.org/packages/c3/93/9916574aa8c00aa06bbac729972eb1071d002b8e158bd0e83a3b9a20a1f7/PyYAML-6.0.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1f71ea527786de97d1a0cc0eacd1defc0985dcf6b3f17bb77dcfc8c34bec4dc5", size = 739154 }, { url = "https://files.pythonhosted.org/packages/95/0f/b8938f1cbd09739c6da569d172531567dbcc9789e0029aa070856f123984/PyYAML-6.0.2-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9b22676e8097e9e22e36d6b7bda33190d0d400f345f23d4065d48f4ca7ae0425", size = 766223 }, { url = "https://files.pythonhosted.org/packages/b9/2b/614b4752f2e127db5cc206abc23a8c19678e92b23c3db30fc86ab731d3bd/PyYAML-6.0.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:80bab7bfc629882493af4aa31a4cfa43a4c57c83813253626916b8c7ada83476", size = 767542 }, { url = "https://files.pythonhosted.org/packages/d4/00/dd137d5bcc7efea1836d6264f049359861cf548469d18da90cd8216cf05f/PyYAML-6.0.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:0833f8694549e586547b576dcfaba4a6b55b9e96098b36cdc7ebefe667dfed48", size = 731164 }, { url = "https://files.pythonhosted.org/packages/c9/1f/4f998c900485e5c0ef43838363ba4a9723ac0ad73a9dc42068b12aaba4e4/PyYAML-6.0.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8b9c7197f7cb2738065c481a0461e50ad02f18c78cd75775628afb4d7137fb3b", size = 756611 }, - { url = "https://files.pythonhosted.org/packages/df/d1/f5a275fdb252768b7a11ec63585bc38d0e87c9e05668a139fea92b80634c/PyYAML-6.0.2-cp312-cp312-win32.whl", hash = "sha256:ef6107725bd54b262d6dedcc2af448a266975032bc85ef0172c5f059da6325b4", size = 140591 }, - { url = "https://files.pythonhosted.org/packages/0c/e8/4f648c598b17c3d06e8753d7d13d57542b30d56e6c2dedf9c331ae56312e/PyYAML-6.0.2-cp312-cp312-win_amd64.whl", hash = "sha256:7e7401d0de89a9a855c839bc697c079a4af81cf878373abd7dc625847d25cbd8", size = 156338 }, - { url = "https://files.pythonhosted.org/packages/ef/e3/3af305b830494fa85d95f6d95ef7fa73f2ee1cc8ef5b495c7c3269fb835f/PyYAML-6.0.2-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:efdca5630322a10774e8e98e1af481aad470dd62c3170801852d752aa7a783ba", size = 181309 }, - { url = "https://files.pythonhosted.org/packages/45/9f/3b1c20a0b7a3200524eb0076cc027a970d320bd3a6592873c85c92a08731/PyYAML-6.0.2-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:50187695423ffe49e2deacb8cd10510bc361faac997de9efef88badc3bb9e2d1", size = 171679 }, { url = "https://files.pythonhosted.org/packages/7c/9a/337322f27005c33bcb656c655fa78325b730324c78620e8328ae28b64d0c/PyYAML-6.0.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0ffe8360bab4910ef1b9e87fb812d8bc0a308b0d0eef8c8f44e0254ab3b07133", size = 733428 }, { url = "https://files.pythonhosted.org/packages/a3/69/864fbe19e6c18ea3cc196cbe5d392175b4cf3d5d0ac1403ec3f2d237ebb5/PyYAML-6.0.2-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:17e311b6c678207928d649faa7cb0d7b4c26a0ba73d41e99c4fff6b6c3276484", size = 763361 }, { url = "https://files.pythonhosted.org/packages/04/24/b7721e4845c2f162d26f50521b825fb061bc0a5afcf9a386840f23ea19fa/PyYAML-6.0.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:70b189594dbe54f75ab3a1acec5f1e3faa7e8cf2f1e08d9b561cb41b845f69d5", size = 759523 }, { url = "https://files.pythonhosted.org/packages/2b/b2/e3234f59ba06559c6ff63c4e10baea10e5e7df868092bf9ab40e5b9c56b6/PyYAML-6.0.2-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:41e4e3953a79407c794916fa277a82531dd93aad34e29c2a514c2c0c5fe971cc", size = 726660 }, { url = "https://files.pythonhosted.org/packages/fe/0f/25911a9f080464c59fab9027482f822b86bf0608957a5fcc6eaac85aa515/PyYAML-6.0.2-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:68ccc6023a3400877818152ad9a1033e3db8625d899c72eacb5a668902e4d652", size = 751597 }, - { url = "https://files.pythonhosted.org/packages/14/0d/e2c3b43bbce3cf6bd97c840b46088a3031085179e596d4929729d8d68270/PyYAML-6.0.2-cp313-cp313-win32.whl", hash = "sha256:bc2fa7c6b47d6bc618dd7fb02ef6fdedb1090ec036abab80d4681424b84c1183", size = 140527 }, - { url = "https://files.pythonhosted.org/packages/fa/de/02b54f42487e3d3c6efb3f89428677074ca7bf43aae402517bc7cca949f3/PyYAML-6.0.2-cp313-cp313-win_amd64.whl", hash = "sha256:8388ee1976c416731879ac16da0aff3f63b286ffdd57cdeb95f3f2e085687563", size = 156446 }, - { url = "https://files.pythonhosted.org/packages/65/d8/b7a1db13636d7fb7d4ff431593c510c8b8fca920ade06ca8ef20015493c5/PyYAML-6.0.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:688ba32a1cffef67fd2e9398a2efebaea461578b0923624778664cc1c914db5d", size = 184777 }, - { url = "https://files.pythonhosted.org/packages/0a/02/6ec546cd45143fdf9840b2c6be8d875116a64076218b61d68e12548e5839/PyYAML-6.0.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:a8786accb172bd8afb8be14490a16625cbc387036876ab6ba70912730faf8e1f", size = 172318 }, { url = "https://files.pythonhosted.org/packages/0e/9a/8cc68be846c972bda34f6c2a93abb644fb2476f4dcc924d52175786932c9/PyYAML-6.0.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d8e03406cac8513435335dbab54c0d385e4a49e4945d2909a581c83647ca0290", size = 720891 }, { url = "https://files.pythonhosted.org/packages/e9/6c/6e1b7f40181bc4805e2e07f4abc10a88ce4648e7e95ff1abe4ae4014a9b2/PyYAML-6.0.2-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f753120cb8181e736c57ef7636e83f31b9c0d1722c516f7e86cf15b7aa57ff12", size = 722614 }, { url = "https://files.pythonhosted.org/packages/3d/32/e7bd8535d22ea2874cef6a81021ba019474ace0d13a4819c2a4bce79bd6a/PyYAML-6.0.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3b1fdb9dc17f5a7677423d508ab4f243a726dea51fa5e70992e59a7411c89d19", size = 737360 }, { url = "https://files.pythonhosted.org/packages/d7/12/7322c1e30b9be969670b672573d45479edef72c9a0deac3bb2868f5d7469/PyYAML-6.0.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:0b69e4ce7a131fe56b7e4d770c67429700908fc0752af059838b1cfb41960e4e", size = 699006 }, { url = "https://files.pythonhosted.org/packages/82/72/04fcad41ca56491995076630c3ec1e834be241664c0c09a64c9a2589b507/PyYAML-6.0.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:a9f8c2e67970f13b16084e04f134610fd1d374bf477b17ec1599185cf611d725", size = 723577 }, - { url = "https://files.pythonhosted.org/packages/ed/5e/46168b1f2757f1fcd442bc3029cd8767d88a98c9c05770d8b420948743bb/PyYAML-6.0.2-cp39-cp39-win32.whl", hash = "sha256:6395c297d42274772abc367baaa79683958044e5d3835486c16da75d2a694631", size = 144593 }, - { url = "https://files.pythonhosted.org/packages/19/87/5124b1c1f2412bb95c59ec481eaf936cd32f0fe2a7b16b97b81c4c017a6a/PyYAML-6.0.2-cp39-cp39-win_amd64.whl", hash = "sha256:39693e1f8320ae4f43943590b49779ffb98acb81f788220ea932a6b6c51004d8", size = 162312 }, ] [[package]] @@ -1211,9 +1042,6 @@ version = "2024.9.11" source = { registry = "https://pypi.org/simple" } sdist = { url = "https://files.pythonhosted.org/packages/f9/38/148df33b4dbca3bd069b963acab5e0fa1a9dbd6820f8c322d0dd6faeff96/regex-2024.9.11.tar.gz", hash = "sha256:6c188c307e8433bcb63dc1915022deb553b4203a70722fc542c363bf120a01fd", size = 399403 } wheels = [ - { url = "https://files.pythonhosted.org/packages/63/12/497bd6599ce8a239ade68678132296aec5ee25ebea45fc8ba91aa60fceec/regex-2024.9.11-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:1494fa8725c285a81d01dc8c06b55287a1ee5e0e382d8413adc0a9197aac6408", size = 482488 }, - { url = "https://files.pythonhosted.org/packages/c1/24/595ddb9bec2a9b151cdaf9565b0c9f3da9f0cb1dca6c158bc5175332ddf8/regex-2024.9.11-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:0e12c481ad92d129c78f13a2a3662317e46ee7ef96c94fd332e1c29131875b7d", size = 287443 }, - { url = "https://files.pythonhosted.org/packages/69/a8/b2fb45d9715b1469383a0da7968f8cacc2f83e9fbbcd6b8713752dd980a6/regex-2024.9.11-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:16e13a7929791ac1216afde26f712802e3df7bf0360b32e4914dca3ab8baeea5", size = 284561 }, { url = "https://files.pythonhosted.org/packages/88/87/1ce4a5357216b19b7055e7d3b0efc75a6e426133bf1e7d094321df514257/regex-2024.9.11-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:46989629904bad940bbec2106528140a218b4a36bb3042d8406980be1941429c", size = 783177 }, { url = "https://files.pythonhosted.org/packages/3c/65/b9f002ab32f7b68e7d1dcabb67926f3f47325b8dbc22cc50b6a043e1d07c/regex-2024.9.11-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a906ed5e47a0ce5f04b2c981af1c9acf9e8696066900bf03b9d7879a6f679fc8", size = 823193 }, { url = "https://files.pythonhosted.org/packages/22/91/8339dd3abce101204d246e31bc26cdd7ec07c9f91598472459a3a902aa41/regex-2024.9.11-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e9a091b0550b3b0207784a7d6d0f1a00d1d1c8a11699c1a4d93db3fbefc3ad35", size = 809950 }, @@ -1225,11 +1053,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/65/7b/953075723dd5ab00780043ac2f9de667306ff9e2a85332975e9f19279174/regex-2024.9.11-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:3ce4f1185db3fbde8ed8aa223fc9620f276c58de8b0d4f8cc86fd1360829edb6", size = 845373 }, { url = "https://files.pythonhosted.org/packages/40/b8/3e9484c6230b8b6e8f816ab7c9a080e631124991a4ae2c27a81631777db0/regex-2024.9.11-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:09d77559e80dcc9d24570da3745ab859a9cf91953062e4ab126ba9d5993688ca", size = 845369 }, { url = "https://files.pythonhosted.org/packages/b7/99/38434984d912edbd2e1969d116257e869578f67461bd7462b894c45ed874/regex-2024.9.11-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:7a22ccefd4db3f12b526eccb129390942fe874a3a9fdbdd24cf55773a1faab1a", size = 773935 }, - { url = "https://files.pythonhosted.org/packages/ab/67/43174d2b46fa947b7b9dfe56b6c8a8a76d44223f35b1d64645a732fd1d6f/regex-2024.9.11-cp310-cp310-win32.whl", hash = "sha256:f745ec09bc1b0bd15cfc73df6fa4f726dcc26bb16c23a03f9e3367d357eeedd0", size = 261624 }, - { url = "https://files.pythonhosted.org/packages/c4/2a/4f9c47d9395b6aff24874c761d8d620c0232f97c43ef3cf668c8b355e7a7/regex-2024.9.11-cp310-cp310-win_amd64.whl", hash = "sha256:01c2acb51f8a7d6494c8c5eafe3d8e06d76563d8a8a4643b37e9b2dd8a2ff623", size = 274020 }, - { url = "https://files.pythonhosted.org/packages/86/a1/d526b7b6095a0019aa360948c143aacfeb029919c898701ce7763bbe4c15/regex-2024.9.11-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:2cce2449e5927a0bf084d346da6cd5eb016b2beca10d0013ab50e3c226ffc0df", size = 482483 }, - { url = "https://files.pythonhosted.org/packages/32/d9/bfdd153179867c275719e381e1e8e84a97bd186740456a0dcb3e7125c205/regex-2024.9.11-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:3b37fa423beefa44919e009745ccbf353d8c981516e807995b2bd11c2c77d268", size = 287442 }, - { url = "https://files.pythonhosted.org/packages/33/c4/60f3370735135e3a8d673ddcdb2507a8560d0e759e1398d366e43d000253/regex-2024.9.11-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:64ce2799bd75039b480cc0360907c4fb2f50022f030bf9e7a8705b636e408fad", size = 284561 }, { url = "https://files.pythonhosted.org/packages/b1/51/91a5ebdff17f9ec4973cb0aa9d37635efec1c6868654bbc25d1543aca4ec/regex-2024.9.11-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a4cc92bb6db56ab0c1cbd17294e14f5e9224f0cc6521167ef388332604e92679", size = 791779 }, { url = "https://files.pythonhosted.org/packages/07/4a/022c5e6f0891a90cd7eb3d664d6c58ce2aba48bff107b00013f3d6167069/regex-2024.9.11-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d05ac6fa06959c4172eccd99a222e1fbf17b5670c4d596cb1e5cde99600674c4", size = 832605 }, { url = "https://files.pythonhosted.org/packages/ac/1c/3793990c8c83ca04e018151ddda83b83ecc41d89964f0f17749f027fc44d/regex-2024.9.11-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:040562757795eeea356394a7fb13076ad4f99d3c62ab0f8bdfb21f99a1f85664", size = 818556 }, @@ -1240,11 +1063,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/93/8d/65b9bea7df120a7be8337c415b6d256ba786cbc9107cebba3bf8ff09da99/regex-2024.9.11-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:7fb89ee5d106e4a7a51bce305ac4efb981536301895f7bdcf93ec92ae0d91c7f", size = 853744 }, { url = "https://files.pythonhosted.org/packages/96/a7/fba1eae75eb53a704475baf11bd44b3e6ccb95b316955027eb7748f24ef8/regex-2024.9.11-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:a738b937d512b30bf75995c0159c0ddf9eec0775c9d72ac0202076c72f24aa96", size = 855890 }, { url = "https://files.pythonhosted.org/packages/45/14/d864b2db80a1a3358534392373e8a281d95b28c29c87d8548aed58813910/regex-2024.9.11-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:e28f9faeb14b6f23ac55bfbbfd3643f5c7c18ede093977f1df249f73fd22c7b1", size = 781887 }, - { url = "https://files.pythonhosted.org/packages/4d/a9/bfb29b3de3eb11dc9b412603437023b8e6c02fb4e11311863d9bf62c403a/regex-2024.9.11-cp311-cp311-win32.whl", hash = "sha256:18e707ce6c92d7282dfce370cd205098384b8ee21544e7cb29b8aab955b66fa9", size = 261644 }, - { url = "https://files.pythonhosted.org/packages/c7/ab/1ad2511cf6a208fde57fafe49829cab8ca018128ab0d0b48973d8218634a/regex-2024.9.11-cp311-cp311-win_amd64.whl", hash = "sha256:313ea15e5ff2a8cbbad96ccef6be638393041b0a7863183c2d31e0c6116688cf", size = 274033 }, - { url = "https://files.pythonhosted.org/packages/6e/92/407531450762bed778eedbde04407f68cbd75d13cee96c6f8d6903d9c6c1/regex-2024.9.11-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:b0d0a6c64fcc4ef9c69bd5b3b3626cc3776520a1637d8abaa62b9edc147a58f7", size = 483590 }, - { url = "https://files.pythonhosted.org/packages/8e/a2/048acbc5ae1f615adc6cba36cc45734e679b5f1e4e58c3c77f0ed611d4e2/regex-2024.9.11-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:49b0e06786ea663f933f3710a51e9385ce0cba0ea56b67107fd841a55d56a231", size = 288175 }, - { url = "https://files.pythonhosted.org/packages/8a/ea/909d8620329ab710dfaf7b4adee41242ab7c9b95ea8d838e9bfe76244259/regex-2024.9.11-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:5b513b6997a0b2f10e4fd3a1313568e373926e8c252bd76c960f96fd039cd28d", size = 284749 }, { url = "https://files.pythonhosted.org/packages/ca/fa/521eb683b916389b4975337873e66954e0f6d8f91bd5774164a57b503185/regex-2024.9.11-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ee439691d8c23e76f9802c42a95cfeebf9d47cf4ffd06f18489122dbb0a7ad64", size = 795181 }, { url = "https://files.pythonhosted.org/packages/28/db/63047feddc3280cc242f9c74f7aeddc6ee662b1835f00046f57d5630c827/regex-2024.9.11-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a8f877c89719d759e52783f7fe6e1c67121076b87b40542966c02de5503ace42", size = 835842 }, { url = "https://files.pythonhosted.org/packages/e3/94/86adc259ff8ec26edf35fcca7e334566c1805c7493b192cb09679f9c3dee/regex-2024.9.11-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:23b30c62d0f16827f2ae9f2bb87619bc4fba2044911e2e6c2eb1af0161cdb766", size = 823533 }, @@ -1255,11 +1073,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/9c/71/eff77d3fe7ba08ab0672920059ec30d63fa7e41aa0fb61c562726e9bd721/regex-2024.9.11-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:d552c78411f60b1fdaafd117a1fca2f02e562e309223b9d44b7de8be451ec5e0", size = 860214 }, { url = "https://files.pythonhosted.org/packages/81/11/e1bdf84a72372e56f1ea4b833dd583b822a23138a616ace7ab57a0e11556/regex-2024.9.11-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:a0b2b80321c2ed3fcf0385ec9e51a12253c50f146fddb2abbb10f033fe3d049a", size = 859420 }, { url = "https://files.pythonhosted.org/packages/ea/75/9753e9dcebfa7c3645563ef5c8a58f3a47e799c872165f37c55737dadd3e/regex-2024.9.11-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:18406efb2f5a0e57e3a5881cd9354c1512d3bb4f5c45d96d110a66114d84d23a", size = 787333 }, - { url = "https://files.pythonhosted.org/packages/bc/4e/ba1cbca93141f7416624b3ae63573e785d4bc1834c8be44a8f0747919eca/regex-2024.9.11-cp312-cp312-win32.whl", hash = "sha256:e464b467f1588e2c42d26814231edecbcfe77f5ac414d92cbf4e7b55b2c2a776", size = 262058 }, - { url = "https://files.pythonhosted.org/packages/6e/16/efc5f194778bf43e5888209e5cec4b258005d37c613b67ae137df3b89c53/regex-2024.9.11-cp312-cp312-win_amd64.whl", hash = "sha256:9e8719792ca63c6b8340380352c24dcb8cd7ec49dae36e963742a275dfae6009", size = 273526 }, - { url = "https://files.pythonhosted.org/packages/93/0a/d1c6b9af1ff1e36832fe38d74d5c5bab913f2bdcbbd6bc0e7f3ce8b2f577/regex-2024.9.11-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:c157bb447303070f256e084668b702073db99bbb61d44f85d811025fcf38f784", size = 483376 }, - { url = "https://files.pythonhosted.org/packages/a4/42/5910a050c105d7f750a72dcb49c30220c3ae4e2654e54aaaa0e9bc0584cb/regex-2024.9.11-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:4db21ece84dfeefc5d8a3863f101995de646c6cb0536952c321a2650aa202c36", size = 288112 }, - { url = "https://files.pythonhosted.org/packages/8d/56/0c262aff0e9224fa7ffce47b5458d373f4d3e3ff84e99b5ff0cb15e0b5b2/regex-2024.9.11-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:220e92a30b426daf23bb67a7962900ed4613589bab80382be09b48896d211e92", size = 284608 }, { url = "https://files.pythonhosted.org/packages/b9/54/9fe8f9aec5007bbbbce28ba3d2e3eaca425f95387b7d1e84f0d137d25237/regex-2024.9.11-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:eb1ae19e64c14c7ec1995f40bd932448713d3c73509e82d8cd7744dc00e29e86", size = 795337 }, { url = "https://files.pythonhosted.org/packages/b2/e7/6b2f642c3cded271c4f16cc4daa7231be544d30fe2b168e0223724b49a61/regex-2024.9.11-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f47cd43a5bfa48f86925fe26fbdd0a488ff15b62468abb5d2a1e092a4fb10e85", size = 835848 }, { url = "https://files.pythonhosted.org/packages/cd/9e/187363bdf5d8c0e4662117b92aa32bf52f8f09620ae93abc7537d96d3311/regex-2024.9.11-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9d4a76b96f398697fe01117093613166e6aa8195d63f1b4ec3f21ab637632963", size = 823503 }, @@ -1270,11 +1083,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/19/1d/43ed03a236313639da5a45e61bc553c8d41e925bcf29b0f8ecff0c2c3f25/regex-2024.9.11-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:dd4490a33eb909ef5078ab20f5f000087afa2a4daa27b4c072ccb3cb3050ad84", size = 860435 }, { url = "https://files.pythonhosted.org/packages/34/4f/5d04da61c7c56e785058a46349f7285ae3ebc0726c6ea7c5c70600a52233/regex-2024.9.11-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:eee9130eaad130649fd73e5cd92f60e55708952260ede70da64de420cdcad554", size = 859571 }, { url = "https://files.pythonhosted.org/packages/12/7f/8398c8155a3c70703a8e91c29532558186558e1aea44144b382faa2a6f7a/regex-2024.9.11-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:6a2644a93da36c784e546de579ec1806bfd2763ef47babc1b03d765fe560c9f8", size = 787398 }, - { url = "https://files.pythonhosted.org/packages/58/3a/f5903977647a9a7e46d5535e9e96c194304aeeca7501240509bde2f9e17f/regex-2024.9.11-cp313-cp313-win32.whl", hash = "sha256:e997fd30430c57138adc06bba4c7c2968fb13d101e57dd5bb9355bf8ce3fa7e8", size = 262035 }, - { url = "https://files.pythonhosted.org/packages/ff/80/51ba3a4b7482f6011095b3a036e07374f64de180b7d870b704ed22509002/regex-2024.9.11-cp313-cp313-win_amd64.whl", hash = "sha256:042c55879cfeb21a8adacc84ea347721d3d83a159da6acdf1116859e2427c43f", size = 273510 }, - { url = "https://files.pythonhosted.org/packages/a1/aa/e31baf8482ad690ccb3cdf20d1963a01e98d137e4d9ee493dbb0fa8ba2c6/regex-2024.9.11-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:07f45f287469039ffc2c53caf6803cd506eb5f5f637f1d4acb37a738f71dd066", size = 482489 }, - { url = "https://files.pythonhosted.org/packages/a1/b5/449c2f14fc20dc42ef9729469fcff42809393470f021ed6c6fcf5f3d3297/regex-2024.9.11-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:4838e24ee015101d9f901988001038f7f0d90dc0c3b115541a1365fb439add62", size = 287440 }, - { url = "https://files.pythonhosted.org/packages/3f/36/4b60a0c2e4cc6ecb2651be828117a31f42fae55a51a484a8071729df56a6/regex-2024.9.11-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:6edd623bae6a737f10ce853ea076f56f507fd7726bee96a41ee3d68d347e4d16", size = 284566 }, { url = "https://files.pythonhosted.org/packages/b4/21/feaa5b0d3e5e3bad659cd7d640e6b76cc0719504dbd9bc8f67cfa21bde82/regex-2024.9.11-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c69ada171c2d0e97a4b5aa78fbb835e0ffbb6b13fc5da968c09811346564f0d3", size = 782747 }, { url = "https://files.pythonhosted.org/packages/bb/89/93516f0aa3e8a9366df2cf79bb0290abdc7dbe5dd27373d9bea0978b7ba6/regex-2024.9.11-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:02087ea0a03b4af1ed6ebab2c54d7118127fee8d71b26398e8e4b05b78963199", size = 822700 }, { url = "https://files.pythonhosted.org/packages/d5/e7/79c04ccb81cee2831d9d4499274919b9153c1741ce8b3421d69cb0032f1b/regex-2024.9.11-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:69dee6a020693d12a3cf892aba4808fe168d2a4cef368eb9bf74f5398bfd4ee8", size = 809327 }, @@ -1286,8 +1094,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/16/71/d964c0c9d447f04bbe6ab5eafd220208e7d52b9608e452e6fcad553b38e0/regex-2024.9.11-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:73d6d2f64f4d894c96626a75578b0bf7d9e56dcda8c3d037a2118fdfe9b1c664", size = 845014 }, { url = "https://files.pythonhosted.org/packages/83/cb/a378cdc2468782eefefa50183bbeabc3357fb588d4109d845f0a56e68713/regex-2024.9.11-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:e53b5fbab5d675aec9f0c501274c467c0f9a5d23696cfc94247e1fb56501ed89", size = 844916 }, { url = "https://files.pythonhosted.org/packages/b9/f0/82ea1565a6639270cfe96263002b3d91084a1db5048d9b6084f83bd5972d/regex-2024.9.11-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:0ffbcf9221e04502fc35e54d1ce9567541979c3fdfb93d2c554f0ca583a19b35", size = 773409 }, - { url = "https://files.pythonhosted.org/packages/97/9e/0400d742b9647b4940609a96d550de89e4e89c85f6a370796dab25b5979c/regex-2024.9.11-cp39-cp39-win32.whl", hash = "sha256:e4c22e1ac1f1ec1e09f72e6c44d8f2244173db7eb9629cc3a346a8d7ccc31142", size = 261680 }, - { url = "https://files.pythonhosted.org/packages/b6/f1/aef1112652ac7b3922d2c129f8325a4fd286b66691127dd99f380f8ede19/regex-2024.9.11-cp39-cp39-win_amd64.whl", hash = "sha256:faa3c142464efec496967359ca99696c896c591c56c53506bac1ad465f66e919", size = 274066 }, ] [[package]] @@ -1330,8 +1136,6 @@ source = { registry = "https://pypi.org/simple" } sdist = { url = "https://files.pythonhosted.org/packages/96/3f/29b2d3d90f811f6fb5b90242309f4668cd8c2482aab86ffc23099000545b/ruff-0.6.5.tar.gz", hash = "sha256:4d32d87fab433c0cf285c3683dd4dae63be05fd7a1d65b3f5bf7cdd05a6b96fb", size = 2476127 } wheels = [ { url = "https://files.pythonhosted.org/packages/64/05/cc62df44b5a0271b29f11d687aa89e85943e0d26e5bb773dbc1456d9885d/ruff-0.6.5-py3-none-linux_armv6l.whl", hash = "sha256:7e4e308f16e07c95fc7753fc1aaac690a323b2bb9f4ec5e844a97bb7fbebd748", size = 9770988 }, - { url = "https://files.pythonhosted.org/packages/09/3d/89dac56ab7053d5b7cba723c9cae1a29b7a2978174c67e2441525ee00343/ruff-0.6.5-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:932cd69eefe4daf8c7d92bd6689f7e8182571cb934ea720af218929da7bd7d69", size = 9423303 }, - { url = "https://files.pythonhosted.org/packages/70/76/dc04654d26beace866a3c9e0c87112304e3d6406e1ee8ca0d9bebbd82d91/ruff-0.6.5-py3-none-macosx_11_0_arm64.whl", hash = "sha256:3a8d42d11fff8d3143ff4da41742a98f8f233bf8890e9fe23077826818f8d680", size = 9134078 }, { url = "https://files.pythonhosted.org/packages/da/52/6a492cffcd2c6e243043937ab52811b6ebb10cb5b77a68cc98e7676ceaef/ruff-0.6.5-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a50af6e828ee692fb10ff2dfe53f05caecf077f4210fae9677e06a808275754f", size = 10105094 }, { url = "https://files.pythonhosted.org/packages/59/7c/fd76a583ae59a276537d71921d616a83ec7774027d0812049afb6af8a07f/ruff-0.6.5-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:794ada3400a0d0b89e3015f1a7e01f4c97320ac665b7bc3ade24b50b54cb2972", size = 9542751 }, { url = "https://files.pythonhosted.org/packages/56/5b/4e8928fa11412b16ecf7d7755fe45db6dfa7abce32841f6aec33bae3a7da/ruff-0.6.5-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:381413ec47f71ce1d1c614f7779d88886f406f1fd53d289c77e4e533dc6ea200", size = 10358844 }, @@ -1343,9 +1147,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/91/d0/0bacdffc234e588ec05834186ad11ec8281a6ca598d0106892497bbcfa44/ruff-0.6.5-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:9ad7dfbd138d09d9a7e6931e6a7e797651ce29becd688be8a0d4d5f8177b4b0c", size = 9625374 }, { url = "https://files.pythonhosted.org/packages/1a/ad/721003cde8abd9f50bff74acbcb21852531036451d48a1abddba4dd84025/ruff-0.6.5-py3-none-musllinux_1_2_i686.whl", hash = "sha256:005256d977021790cc52aa23d78f06bb5090dc0bfbd42de46d49c201533982ae", size = 9959661 }, { url = "https://files.pythonhosted.org/packages/37/84/8d70a3eacaacb65b4bb1461fc1a59e37ff165152b7e507692109117c877f/ruff-0.6.5-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:482c1e6bfeb615eafc5899127b805d28e387bd87db38b2c0c41d271f5e58d8cc", size = 10327408 }, - { url = "https://files.pythonhosted.org/packages/54/7e/6b0a9ab30428a9e3d9607f6dd2e4fb743594d42bd1b6ba7b7b239acda921/ruff-0.6.5-py3-none-win32.whl", hash = "sha256:cf4d3fa53644137f6a4a27a2b397381d16454a1566ae5335855c187fbf67e4f5", size = 8012512 }, - { url = "https://files.pythonhosted.org/packages/d8/88/176f50162a219e3039f21e9e4323869fc62bf8d3afb4147a390d6c744bd8/ruff-0.6.5-py3-none-win_amd64.whl", hash = "sha256:3e42a57b58e3612051a636bc1ac4e6b838679530235520e8f095f7c44f706ff9", size = 8804438 }, - { url = "https://files.pythonhosted.org/packages/67/a0/1b488bbe35a7ff8296fdea1ec1a9c2676cecc7e42bda63860f9397d59140/ruff-0.6.5-py3-none-win_arm64.whl", hash = "sha256:51935067740773afdf97493ba9b8231279e9beef0f2a8079188c4776c25688e0", size = 8179780 }, ] [[package]] @@ -1354,8 +1155,6 @@ version = "0.4.5" source = { registry = "https://pypi.org/simple" } sdist = { url = "https://files.pythonhosted.org/packages/cb/46/a1c56ed856c6ac3b1a8b37abe5be0cac53219367af1331e721b04d122577/safetensors-0.4.5.tar.gz", hash = "sha256:d73de19682deabb02524b3d5d1f8b3aaba94c72f1bbfc7911b9b9d5d391c0310", size = 65702 } wheels = [ - { url = "https://files.pythonhosted.org/packages/38/10/0798ec2c8704c2d172620d8a3725bed92cdd75516357b1a3e64d4229ea4e/safetensors-0.4.5-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:a63eaccd22243c67e4f2b1c3e258b257effc4acd78f3b9d397edc8cf8f1298a7", size = 392312 }, - { url = "https://files.pythonhosted.org/packages/2b/9e/9648d8dbb485c40a4a0212b7537626ae440b48156cc74601ca0b7a7615e0/safetensors-0.4.5-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:23fc9b4ec7b602915cbb4ec1a7c1ad96d2743c322f20ab709e2c35d1b66dad27", size = 381858 }, { url = "https://files.pythonhosted.org/packages/8b/67/49556aeacc00df353767ed31d68b492fecf38c3f664c52692e4d92aa0032/safetensors-0.4.5-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6885016f34bef80ea1085b7e99b3c1f92cb1be78a49839203060f67b40aee761", size = 441382 }, { url = "https://files.pythonhosted.org/packages/5d/ce/e9f4869a37bb11229e6cdb4e73a6ef23b4f360eee9dca5f7e40982779704/safetensors-0.4.5-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:133620f443450429322f238fda74d512c4008621227fccf2f8cf4a76206fea7c", size = 439001 }, { url = "https://files.pythonhosted.org/packages/a0/27/aee8cf031b89c34caf83194ec6b7f2eed28d053fff8b6da6d00c85c56035/safetensors-0.4.5-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4fb3e0609ec12d2a77e882f07cced530b8262027f64b75d399f1504ffec0ba56", size = 478026 }, @@ -1364,10 +1163,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/90/fa/7bc3f18086201b1e55a42c88b822ae197d0158e12c54cd45c887305f1b7e/safetensors-0.4.5-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:9e347d77e2c77eb7624400ccd09bed69d35c0332f417ce8c048d404a096c593b", size = 456273 }, { url = "https://files.pythonhosted.org/packages/3e/59/2ae50150d37a65c1c5f01aec74dc737707b8bbecdc76307e5a1a12c8a376/safetensors-0.4.5-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:9f556eea3aec1d3d955403159fe2123ddd68e880f83954ee9b4a3f2e15e716b6", size = 619669 }, { url = "https://files.pythonhosted.org/packages/fe/43/10f0bb597aef62c9c154152e265057089f3c729bdd980e6c32c3ec2407a4/safetensors-0.4.5-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:9483f42be3b6bc8ff77dd67302de8ae411c4db39f7224dec66b0eb95822e4163", size = 605212 }, - { url = "https://files.pythonhosted.org/packages/7c/75/ede6887ea0ceaba55730988bfc7668dc147a8758f907fa6db26fbb681b8e/safetensors-0.4.5-cp310-none-win32.whl", hash = "sha256:7389129c03fadd1ccc37fd1ebbc773f2b031483b04700923c3511d2a939252cc", size = 272652 }, - { url = "https://files.pythonhosted.org/packages/ba/f0/919c72a9eef843781e652d0650f2819039943e69b69d5af2d0451a23edc3/safetensors-0.4.5-cp310-none-win_amd64.whl", hash = "sha256:e98ef5524f8b6620c8cdef97220c0b6a5c1cef69852fcd2f174bb96c2bb316b1", size = 285879 }, - { url = "https://files.pythonhosted.org/packages/9a/a5/25bcf75e373412daf1fd88045ab3aa8140a0d804ef0e70712c4f2c5b94d8/safetensors-0.4.5-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:21f848d7aebd5954f92538552d6d75f7c1b4500f51664078b5b49720d180e47c", size = 392256 }, - { url = "https://files.pythonhosted.org/packages/08/8c/ece3bf8756506a890bd980eca02f47f9d98dfbf5ce16eda1368f53560f67/safetensors-0.4.5-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:bb07000b19d41e35eecef9a454f31a8b4718a185293f0d0b1c4b61d6e4487971", size = 381490 }, { url = "https://files.pythonhosted.org/packages/39/83/c4a7ce01d626e46ea2b45887f2e59b16441408031e2ce2f9fe01860c6946/safetensors-0.4.5-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:09dedf7c2fda934ee68143202acff6e9e8eb0ddeeb4cfc24182bef999efa9f42", size = 441093 }, { url = "https://files.pythonhosted.org/packages/47/26/cc52de647e71bd9a0b0d78ead0d31d9c462b35550a817aa9e0cab51d6db4/safetensors-0.4.5-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:59b77e4b7a708988d84f26de3ebead61ef1659c73dcbc9946c18f3b1786d2688", size = 438960 }, { url = "https://files.pythonhosted.org/packages/06/78/332538546775ee97e749867df2d58f2282d9c48a1681e4891eed8b94ec94/safetensors-0.4.5-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5d3bc83e14d67adc2e9387e511097f254bd1b43c3020440e708858c684cbac68", size = 478031 }, @@ -1376,10 +1171,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a2/61/f0cfce984515b86d1260f556ba3b782158e2855e6a318446ac2613786fa9/safetensors-0.4.5-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:a659467495de201e2f282063808a41170448c78bada1e62707b07a27b05e6943", size = 455984 }, { url = "https://files.pythonhosted.org/packages/e7/a9/3e3b48fcaade3eb4e347d39ebf0bd44291db21a3e4507854b42a7cb910ac/safetensors-0.4.5-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:bad5e4b2476949bcd638a89f71b6916fa9a5cae5c1ae7eede337aca2100435c0", size = 619513 }, { url = "https://files.pythonhosted.org/packages/80/23/2a7a1be24258c0e44c1d356896fd63dc0545a98d2d0184925fa09cd3ec76/safetensors-0.4.5-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:a3a315a6d0054bc6889a17f5668a73f94f7fe55121ff59e0a199e3519c08565f", size = 604841 }, - { url = "https://files.pythonhosted.org/packages/b4/5c/34d082ff1fffffd8545fb22cbae3285ab4236f1f0cfc64b7e58261c2363b/safetensors-0.4.5-cp311-none-win32.whl", hash = "sha256:a01e232e6d3d5cf8b1667bc3b657a77bdab73f0743c26c1d3c5dd7ce86bd3a92", size = 272602 }, - { url = "https://files.pythonhosted.org/packages/6d/41/948c96c8a7e9fef57c2e051f1871c108a6dbbc6d285598bdb1d89b98617c/safetensors-0.4.5-cp311-none-win_amd64.whl", hash = "sha256:cbd39cae1ad3e3ef6f63a6f07296b080c951f24cec60188378e43d3713000c04", size = 285973 }, - { url = "https://files.pythonhosted.org/packages/bf/ac/5a63082f931e99200db95fd46fb6734f050bb6e96bf02521904c6518b7aa/safetensors-0.4.5-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:473300314e026bd1043cef391bb16a8689453363381561b8a3e443870937cc1e", size = 392015 }, - { url = "https://files.pythonhosted.org/packages/73/95/ab32aa6e9bdc832ff87784cdf9da26192b93de3ef82b8d1ada8f345c5044/safetensors-0.4.5-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:801183a0f76dc647f51a2d9141ad341f9665602a7899a693207a82fb102cc53e", size = 381774 }, { url = "https://files.pythonhosted.org/packages/d6/6c/7e04b7626809fc63f3698f4c50e43aff2864b40089aa4506c918a75b8eed/safetensors-0.4.5-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1524b54246e422ad6fb6aea1ac71edeeb77666efa67230e1faf6999df9b2e27f", size = 441134 }, { url = "https://files.pythonhosted.org/packages/58/2b/ffe7c86a277e6c1595fbdf415cfe2903f253f574a5405e93fda8baaa582c/safetensors-0.4.5-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:b3139098e3e8b2ad7afbca96d30ad29157b50c90861084e69fcb80dec7430461", size = 438467 }, { url = "https://files.pythonhosted.org/packages/67/9c/f271bd804e08c7fda954d17b70ff281228a88077337a9e70feace4f4cc93/safetensors-0.4.5-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:65573dc35be9059770808e276b017256fa30058802c29e1038eb1c00028502ea", size = 476566 }, @@ -1388,10 +1179,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/52/47/d4b49b1231abf3131f7bb0bc60ebb94b27ee33e0a1f9569da05f8ac65dee/safetensors-0.4.5-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:dde2bf390d25f67908278d6f5d59e46211ef98e44108727084d4637ee70ab4f1", size = 457166 }, { url = "https://files.pythonhosted.org/packages/c3/cd/006468b03b0fa42ff82d795d47c4193e99001e96c3f08bd62ef1b5cab586/safetensors-0.4.5-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:7469d70d3de970b1698d47c11ebbf296a308702cbaae7fcb993944751cf985f4", size = 619280 }, { url = "https://files.pythonhosted.org/packages/22/4d/b6208d918e83daa84b424c0ac3191ae61b44b3191613a3a5a7b38f94b8ad/safetensors-0.4.5-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:3a6ba28118636a130ccbb968bc33d4684c48678695dba2590169d5ab03a45646", size = 605390 }, - { url = "https://files.pythonhosted.org/packages/e8/20/bf0e01825dc01ed75538021a98b9a046e60ead63c6c6700764c821a8c873/safetensors-0.4.5-cp312-none-win32.whl", hash = "sha256:c859c7ed90b0047f58ee27751c8e56951452ed36a67afee1b0a87847d065eec6", size = 273250 }, - { url = "https://files.pythonhosted.org/packages/f1/5f/ab6b6cec85b40789801f35b7d2fb579ae242d8193929974a106d5ff5c835/safetensors-0.4.5-cp312-none-win_amd64.whl", hash = "sha256:b5a8810ad6a6f933fff6c276eae92c1da217b39b4d8b1bc1c0b8af2d270dc532", size = 286307 }, - { url = "https://files.pythonhosted.org/packages/90/61/0e27b1403e311cba0be20026bee4ee822d90eda7dad372179e7f18bb99f3/safetensors-0.4.5-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:25e5f8e2e92a74f05b4ca55686234c32aac19927903792b30ee6d7bd5653d54e", size = 392062 }, - { url = "https://files.pythonhosted.org/packages/b1/9f/cc31fafc9f5d79da10a83a820ca37f069bab0717895ad8cbcacf629dd1c5/safetensors-0.4.5-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:81efb124b58af39fcd684254c645e35692fea81c51627259cdf6d67ff4458916", size = 382517 }, { url = "https://files.pythonhosted.org/packages/a4/c7/4fda8a0ebb96662550433378f4a74c677fa5fc4d0a43a7ec287d1df254a9/safetensors-0.4.5-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:585f1703a518b437f5103aa9cf70e9bd437cb78eea9c51024329e4fb8a3e3679", size = 441378 }, { url = "https://files.pythonhosted.org/packages/14/31/9abb431f6209de9c80dab83e1112ebd769f1e32e7ab7ab228a02424a4693/safetensors-0.4.5-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:4b99fbf72e3faf0b2f5f16e5e3458b93b7d0a83984fe8d5364c60aa169f2da89", size = 438831 }, { url = "https://files.pythonhosted.org/packages/37/37/99bfb195578a808b8d045159ee9264f8da58d017ac0701853dcacda14d4e/safetensors-0.4.5-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b17b299ca9966ca983ecda1c0791a3f07f9ca6ab5ded8ef3d283fff45f6bcd5f", size = 477112 }, @@ -1400,8 +1187,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/5e/6c/7a3233c08bde558d6c33a41219119866cb596139a4673cc6c24024710ffd/safetensors-0.4.5-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:d94581aab8c6b204def4d7320f07534d6ee34cd4855688004a4354e63b639a35", size = 457382 }, { url = "https://files.pythonhosted.org/packages/a0/58/0b7bcba3788ff503990cf9278d611b56c029400612ba93e772c987b5aa03/safetensors-0.4.5-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:67e1e7cb8678bb1b37ac48ec0df04faf689e2f4e9e81e566b5c63d9f23748523", size = 619301 }, { url = "https://files.pythonhosted.org/packages/82/cc/9c2cf58611daf1c83ce5d37f9de66353e23fcda36008b13fd3409a760aa3/safetensors-0.4.5-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:dbd280b07e6054ea68b0cb4b16ad9703e7d63cd6890f577cb98acc5354780142", size = 605580 }, - { url = "https://files.pythonhosted.org/packages/78/a7/47e05af6b39964a98396d593fd164723e442871dcf55fff0202dfff50b3b/safetensors-0.4.5-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:cf727bb1281d66699bef5683b04d98c894a2803442c490a8d45cd365abfbdeb2", size = 393129 }, - { url = "https://files.pythonhosted.org/packages/a4/1e/643a04fa43e070da11e11c6defdf0930fb5216aa5e734fa00e238fd09ebb/safetensors-0.4.5-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:96f1d038c827cdc552d97e71f522e1049fef0542be575421f7684756a748e457", size = 383165 }, { url = "https://files.pythonhosted.org/packages/08/94/7760694760f1e5001bd62c93155b8b7ccb652d1f4d0161d1e72b5bf9581a/safetensors-0.4.5-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:139fbee92570ecea774e6344fee908907db79646d00b12c535f66bc78bd5ea2c", size = 442391 }, { url = "https://files.pythonhosted.org/packages/03/1c/0db6e6e5cb293907b2242447b48cc09f31478aa02f08773155c2a2db22de/safetensors-0.4.5-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:c36302c1c69eebb383775a89645a32b9d266878fab619819ce660309d6176c9b", size = 440015 }, { url = "https://files.pythonhosted.org/packages/15/58/9658bf7ca3a4e77577fbd2c7afda4701c558db66b01daf7cd4d9dbd9781e/safetensors-0.4.5-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d641f5b8149ea98deb5ffcf604d764aad1de38a8285f86771ce1abf8e74c4891", size = 478099 }, @@ -1410,17 +1195,11 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/2e/ad/7880a359b0f93322689804bdbe1e9a3110652963478712933ff04a3d45c3/safetensors-0.4.5-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:788ee7d04cc0e0e7f944c52ff05f52a4415b312f5efd2ee66389fb7685ee030c", size = 456901 }, { url = "https://files.pythonhosted.org/packages/89/4f/0b61e4add7ea9dfa8141d0bb1b8357e3a08730a020c3a287f0e889c386b5/safetensors-0.4.5-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:87bc42bd04fd9ca31396d3ca0433db0be1411b6b53ac5a32b7845a85d01ffc2e", size = 620159 }, { url = "https://files.pythonhosted.org/packages/a9/60/544687daf8ce8dc9a74260992ac058d7e3f20c91eada5ca232898d005149/safetensors-0.4.5-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:4037676c86365a721a8c9510323a51861d703b399b78a6b4486a54a65a975fca", size = 605993 }, - { url = "https://files.pythonhosted.org/packages/98/9a/2889d9df45ee09a02a17b3349c5649dc5516d1d167515b520e4aa79bdc5b/safetensors-0.4.5-cp39-none-win32.whl", hash = "sha256:1500418454529d0ed5c1564bda376c4ddff43f30fce9517d9bee7bcce5a8ef50", size = 272930 }, - { url = "https://files.pythonhosted.org/packages/ce/00/a4bdf45a5f2e1db08aaf95bb97f8ca30ec9568573eda03ec0db9ce5ed5d2/safetensors-0.4.5-cp39-none-win_amd64.whl", hash = "sha256:9d1a94b9d793ed8fe35ab6d5cea28d540a46559bafc6aae98f30ee0867000cab", size = 286065 }, - { url = "https://files.pythonhosted.org/packages/cf/ff/037ae4c0ee32db496669365e66079b6329906c6814722b159aa700e67208/safetensors-0.4.5-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:fdadf66b5a22ceb645d5435a0be7a0292ce59648ca1d46b352f13cff3ea80410", size = 392951 }, - { url = "https://files.pythonhosted.org/packages/f1/d6/6621e16b35bf83ae099eaab07338f04991a26c9aa43879d05f19f35e149c/safetensors-0.4.5-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:d42ffd4c2259f31832cb17ff866c111684c87bd930892a1ba53fed28370c918c", size = 383417 }, { url = "https://files.pythonhosted.org/packages/ae/88/3068e1bb16f5e9f9068901de3cf7b3db270b9bfe6e7d51d4b55c1da0425d/safetensors-0.4.5-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:dd8a1f6d2063a92cd04145c7fd9e31a1c7d85fbec20113a14b487563fdbc0597", size = 442311 }, { url = "https://files.pythonhosted.org/packages/f7/15/a2bb77ebbaa76b61ec2e9f731fe4db7f9473fd855d881957c51b3a168892/safetensors-0.4.5-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:951d2fcf1817f4fb0ef0b48f6696688a4e852a95922a042b3f96aaa67eedc920", size = 436678 }, { url = "https://files.pythonhosted.org/packages/ec/79/9608c4546cdbfe3860dd7aa59e3562c9289113398b1a0bd89b68ce0a9d41/safetensors-0.4.5-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:6ac85d9a8c1af0e3132371d9f2d134695a06a96993c2e2f0bbe25debb9e3f67a", size = 457316 }, { url = "https://files.pythonhosted.org/packages/0f/23/b17b483f2857835962ad33e38014efd4911791187e177bc23b057d35bee8/safetensors-0.4.5-pp310-pypy310_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:e3cec4a29eb7fe8da0b1c7988bc3828183080439dd559f720414450de076fcab", size = 620565 }, { url = "https://files.pythonhosted.org/packages/19/46/5d11dc300feaad285c2f1bd784ff3f689f5e0ab6be49aaf568f3a77019eb/safetensors-0.4.5-pp310-pypy310_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:21742b391b859e67b26c0b2ac37f52c9c0944a879a25ad2f9f9f3cd61e7fda8f", size = 606660 }, - { url = "https://files.pythonhosted.org/packages/5b/f9/539335e927cfeca8effc972d47e06155c4a39989905082c02b5c72769c41/safetensors-0.4.5-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:f4beb84b6073b1247a773141a6331117e35d07134b3bb0383003f39971d414bb", size = 393986 }, - { url = "https://files.pythonhosted.org/packages/72/c6/988925bae113bb280642329fcbbfb502ba1bc9720b6be47c1f4c1fb7cc87/safetensors-0.4.5-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:68814d599d25ed2fdd045ed54d370d1d03cf35e02dce56de44c651f828fb9b7b", size = 384563 }, { url = "https://files.pythonhosted.org/packages/b3/ff/b26d78b6100a08e57a1986ab71a2f9f093ba9943626f4967cd514cd43de2/safetensors-0.4.5-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f0b6453c54c57c1781292c46593f8a37254b8b99004c68d6c3ce229688931a22", size = 442275 }, { url = "https://files.pythonhosted.org/packages/71/29/6ac541358a07ec593ec9e88636908010bc9bf56c8018e0d25b4481adb64a/safetensors-0.4.5-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:adaa9c6dead67e2dd90d634f89131e43162012479d86e25618e821a03d1eb1dc", size = 437217 }, { url = "https://files.pythonhosted.org/packages/2b/f8/258564b71fe95d0117356e6915b1c0128f1ec3031cf8522a28f9d2108b47/safetensors-0.4.5-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:73e7d408e9012cd17511b382b43547850969c7979efc2bc353f317abaf23c84c", size = 458132 }, @@ -1437,30 +1216,18 @@ dependencies = [ ] sdist = { url = "https://files.pythonhosted.org/packages/ae/00/48c2f661e2816ccf2ecd77982f6605b2950afe60f60a52b4cbbc2504aa8f/scipy-1.13.1.tar.gz", hash = "sha256:095a87a0312b08dfd6a6155cbbd310a8c51800fc931b8c0b84003014b874ed3c", size = 57210720 } wheels = [ - { url = "https://files.pythonhosted.org/packages/33/59/41b2529908c002ade869623b87eecff3e11e3ce62e996d0bdcb536984187/scipy-1.13.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:20335853b85e9a49ff7572ab453794298bcf0354d8068c5f6775a0eabf350aca", size = 39328076 }, - { url = "https://files.pythonhosted.org/packages/d5/33/f1307601f492f764062ce7dd471a14750f3360e33cd0f8c614dae208492c/scipy-1.13.1-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:d605e9c23906d1994f55ace80e0125c587f96c020037ea6aa98d01b4bd2e222f", size = 30306232 }, { url = "https://files.pythonhosted.org/packages/c0/66/9cd4f501dd5ea03e4a4572ecd874936d0da296bd04d1c45ae1a4a75d9c3a/scipy-1.13.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cfa31f1def5c819b19ecc3a8b52d28ffdcc7ed52bb20c9a7589669dd3c250989", size = 33743202 }, { url = "https://files.pythonhosted.org/packages/a3/ba/7255e5dc82a65adbe83771c72f384d99c43063648456796436c9a5585ec3/scipy-1.13.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f26264b282b9da0952a024ae34710c2aff7d27480ee91a2e82b7b7073c24722f", size = 38577335 }, { url = "https://files.pythonhosted.org/packages/49/a5/bb9ded8326e9f0cdfdc412eeda1054b914dfea952bda2097d174f8832cc0/scipy-1.13.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:eccfa1906eacc02de42d70ef4aecea45415f5be17e72b61bafcfd329bdc52e94", size = 38820728 }, - { url = "https://files.pythonhosted.org/packages/12/30/df7a8fcc08f9b4a83f5f27cfaaa7d43f9a2d2ad0b6562cced433e5b04e31/scipy-1.13.1-cp310-cp310-win_amd64.whl", hash = "sha256:2831f0dc9c5ea9edd6e51e6e769b655f08ec6db6e2e10f86ef39bd32eb11da54", size = 46210588 }, - { url = "https://files.pythonhosted.org/packages/b4/15/4a4bb1b15bbd2cd2786c4f46e76b871b28799b67891f23f455323a0cdcfb/scipy-1.13.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:27e52b09c0d3a1d5b63e1105f24177e544a222b43611aaf5bc44d4a0979e32f9", size = 39333805 }, - { url = "https://files.pythonhosted.org/packages/ba/92/42476de1af309c27710004f5cdebc27bec62c204db42e05b23a302cb0c9a/scipy-1.13.1-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:54f430b00f0133e2224c3ba42b805bfd0086fe488835effa33fa291561932326", size = 30317687 }, { url = "https://files.pythonhosted.org/packages/80/ba/8be64fe225360a4beb6840f3cbee494c107c0887f33350d0a47d55400b01/scipy-1.13.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e89369d27f9e7b0884ae559a3a956e77c02114cc60a6058b4e5011572eea9299", size = 33694638 }, { url = "https://files.pythonhosted.org/packages/36/07/035d22ff9795129c5a847c64cb43c1fa9188826b59344fee28a3ab02e283/scipy-1.13.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a78b4b3345f1b6f68a763c6e25c0c9a23a9fd0f39f5f3d200efe8feda560a5fa", size = 38569931 }, { url = "https://files.pythonhosted.org/packages/d9/10/f9b43de37e5ed91facc0cfff31d45ed0104f359e4f9a68416cbf4e790241/scipy-1.13.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:45484bee6d65633752c490404513b9ef02475b4284c4cfab0ef946def50b3f59", size = 38838145 }, - { url = "https://files.pythonhosted.org/packages/4a/48/4513a1a5623a23e95f94abd675ed91cfb19989c58e9f6f7d03990f6caf3d/scipy-1.13.1-cp311-cp311-win_amd64.whl", hash = "sha256:5713f62f781eebd8d597eb3f88b8bf9274e79eeabf63afb4a737abc6c84ad37b", size = 46196227 }, - { url = "https://files.pythonhosted.org/packages/f2/7b/fb6b46fbee30fc7051913068758414f2721003a89dd9a707ad49174e3843/scipy-1.13.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:5d72782f39716b2b3509cd7c33cdc08c96f2f4d2b06d51e52fb45a19ca0c86a1", size = 39357301 }, - { url = "https://files.pythonhosted.org/packages/dc/5a/2043a3bde1443d94014aaa41e0b50c39d046dda8360abd3b2a1d3f79907d/scipy-1.13.1-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:017367484ce5498445aade74b1d5ab377acdc65e27095155e448c88497755a5d", size = 30363348 }, { url = "https://files.pythonhosted.org/packages/e7/cb/26e4a47364bbfdb3b7fb3363be6d8a1c543bcd70a7753ab397350f5f189a/scipy-1.13.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:949ae67db5fa78a86e8fa644b9a6b07252f449dcf74247108c50e1d20d2b4627", size = 33406062 }, { url = "https://files.pythonhosted.org/packages/88/ab/6ecdc526d509d33814835447bbbeedbebdec7cca46ef495a61b00a35b4bf/scipy-1.13.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:de3ade0e53bc1f21358aa74ff4830235d716211d7d077e340c7349bc3542e884", size = 38218311 }, { url = "https://files.pythonhosted.org/packages/0b/00/9f54554f0f8318100a71515122d8f4f503b1a2c4b4cfab3b4b68c0eb08fa/scipy-1.13.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:2ac65fb503dad64218c228e2dc2d0a0193f7904747db43014645ae139c8fad16", size = 38442493 }, - { url = "https://files.pythonhosted.org/packages/3e/df/963384e90733e08eac978cd103c34df181d1fec424de383cdc443f418dd4/scipy-1.13.1-cp312-cp312-win_amd64.whl", hash = "sha256:cdd7dacfb95fea358916410ec61bbc20440f7860333aee6d882bb8046264e949", size = 45910955 }, - { url = "https://files.pythonhosted.org/packages/7f/29/c2ea58c9731b9ecb30b6738113a95d147e83922986b34c685b8f6eefde21/scipy-1.13.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:436bbb42a94a8aeef855d755ce5a465479c721e9d684de76bf61a62e7c2b81d5", size = 39352927 }, - { url = "https://files.pythonhosted.org/packages/5c/c0/e71b94b20ccf9effb38d7147c0064c08c622309fd487b1b677771a97d18c/scipy-1.13.1-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:8335549ebbca860c52bf3d02f80784e91a004b71b059e3eea9678ba994796a24", size = 30324538 }, { url = "https://files.pythonhosted.org/packages/6d/0f/aaa55b06d474817cea311e7b10aab2ea1fd5d43bc6a2861ccc9caec9f418/scipy-1.13.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d533654b7d221a6a97304ab63c41c96473ff04459e404b83275b60aa8f4b7004", size = 33732190 }, { url = "https://files.pythonhosted.org/packages/35/f5/d0ad1a96f80962ba65e2ce1de6a1e59edecd1f0a7b55990ed208848012e0/scipy-1.13.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:637e98dcf185ba7f8e663e122ebf908c4702420477ae52a04f9908707456ba4d", size = 38612244 }, { url = "https://files.pythonhosted.org/packages/8d/02/1165905f14962174e6569076bcc3315809ae1291ed14de6448cc151eedfd/scipy-1.13.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:a014c2b3697bde71724244f63de2476925596c24285c7a637364761f8710891c", size = 38845637 }, - { url = "https://files.pythonhosted.org/packages/3e/77/dab54fe647a08ee4253963bcd8f9cf17509c8ca64d6335141422fe2e2114/scipy-1.13.1-cp39-cp39-win_amd64.whl", hash = "sha256:392e4ec766654852c25ebad4f64e4e584cf19820b980bc04960bca0b0cd6eaa2", size = 46227440 }, ] [[package]] @@ -1517,13 +1284,9 @@ version = "10.3.0" source = { registry = "https://pypi.org/simple" } wheels = [ { url = "https://files.pythonhosted.org/packages/3b/68/eab45c46fdcafe08c6b21de4560fe2d3d845ce072d3e7743de4077c2d8c0/tensorrt_cu12_bindings-10.3.0-cp310-none-manylinux_2_17_x86_64.whl", hash = "sha256:1d6e4cf08ef1f54f6fd44a33cf6b253050af2fc6e9a1d92e40e1436a1d858eb0", size = 1108101 }, - { url = "https://files.pythonhosted.org/packages/a1/86/cb2c9cbd01fa1c5502899cd65df697a9d647a67725cf8c4901f174b1c6e6/tensorrt_cu12_bindings-10.3.0-cp310-none-win_amd64.whl", hash = "sha256:39aa0c2ee3dd20757f9d53e759092868a18155552a32f785844c5f66a2a6d3ba", size = 784666 }, { url = "https://files.pythonhosted.org/packages/4c/ce/47593af3fd15777ff48040da2901d539905c9bed3fc167d4368b0d4fcbf7/tensorrt_cu12_bindings-10.3.0-cp311-none-manylinux_2_17_x86_64.whl", hash = "sha256:59ace22d7f2ca1e9dcde2cb0cb5916912cb3cd5a9d72dd7852be0160d9b3a0ee", size = 1111069 }, - { url = "https://files.pythonhosted.org/packages/08/96/5e9f89e002800f04f0de4b01b4e2415dae3c8d53e84aaec6b9f1f7962fb7/tensorrt_cu12_bindings-10.3.0-cp311-none-win_amd64.whl", hash = "sha256:5582ece5578572a4a7aa3db69ba4cb2e2dcf1127570de1c334bba0182baec604", size = 784738 }, { url = "https://files.pythonhosted.org/packages/cd/1f/8215c8ff476bdc5f8d256413892ad48296df4277af077eefb9f7c0dcfeac/tensorrt_cu12_bindings-10.3.0-cp312-none-manylinux_2_17_x86_64.whl", hash = "sha256:f5c2582aeaa7f5628d2c4d4148a701ebe97be78f7ff3b46a617f0ee0cb5460f2", size = 1098829 }, - { url = "https://files.pythonhosted.org/packages/66/51/96b06a7dcb31418d31c6bb82f5d4b40ae196916c6288db8274bece9a33f9/tensorrt_cu12_bindings-10.3.0-cp312-none-win_amd64.whl", hash = "sha256:c4e8f2f5c7dd23b671fc6a456dbd3a0cdd13bca54d280f1b340dfcb8f73190f7", size = 788277 }, { url = "https://files.pythonhosted.org/packages/71/bf/32b901d844527fdfa5dbc7e57ac3ac10c48ce682254289f790a72faae162/tensorrt_cu12_bindings-10.3.0-cp39-none-manylinux_2_17_x86_64.whl", hash = "sha256:db337018c55043502eff993f165160044b4bebb935f01c8f8f93e4ee71481dc4", size = 1108759 }, - { url = "https://files.pythonhosted.org/packages/91/70/537007a74d4dbc643b9ca0b7fae7ba2dc8cf28fd7399609eccf5ca16490f/tensorrt_cu12_bindings-10.3.0-cp39-none-win_amd64.whl", hash = "sha256:ad629fd7a4c483af500d4ec4863f1a048f0fd0893b5449b48c832a11be7c72f6", size = 722940 }, ] [[package]] @@ -1544,8 +1307,6 @@ dependencies = [ ] sdist = { url = "https://files.pythonhosted.org/packages/48/04/2071c150f374aab6d5e92aaec38d0f3c368d227dd9e0469a1f0966ac68d1/tokenizers-0.19.1.tar.gz", hash = "sha256:ee59e6680ed0fdbe6b724cf38bd70400a0c1dd623b07ac729087270caeac88e3", size = 321039 } wheels = [ - { url = "https://files.pythonhosted.org/packages/c1/60/91cac8d496b304ec5a22f07606893cad35ea8e1a8406dc8909e365f97a80/tokenizers-0.19.1-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:952078130b3d101e05ecfc7fc3640282d74ed26bcf691400f872563fca15ac97", size = 2533301 }, - { url = "https://files.pythonhosted.org/packages/4c/12/9cb68762ff5fee1efd51aefe2f62cb225f26f060a68a3779e1060bbc7a59/tokenizers-0.19.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:82c8b8063de6c0468f08e82c4e198763e7b97aabfe573fd4cf7b33930ca4df77", size = 2440223 }, { url = "https://files.pythonhosted.org/packages/e4/03/b2020e6a78fb994cff1ec962adc157c23109172a46b4fe451d6d0dd33fdb/tokenizers-0.19.1-cp310-cp310-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:f03727225feaf340ceeb7e00604825addef622d551cbd46b7b775ac834c1e1c4", size = 3683779 }, { url = "https://files.pythonhosted.org/packages/50/4e/2e5549a26dc6f9e434f83bebf16c2d7dc9dc3477cc0ec8b23ede4d465b90/tokenizers-0.19.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:453e4422efdfc9c6b6bf2eae00d5e323f263fff62b29a8c9cd526c5003f3f642", size = 3569431 }, { url = "https://files.pythonhosted.org/packages/75/79/158626bd794e75551e0c6bb93f1cd3c9ba08ba14b181b98f09e95994f609/tokenizers-0.19.1-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:02e81bf089ebf0e7f4df34fa0207519f07e66d8491d963618252f2e0729e0b46", size = 3424739 }, @@ -1554,10 +1315,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/40/4f/eb78de4af3b17b589f43a369cbf0c3a7173f25c3d2cd93068852c07689aa/tokenizers-0.19.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8b01afb7193d47439f091cd8f070a1ced347ad0f9144952a30a41836902fe09e", size = 3607049 }, { url = "https://files.pythonhosted.org/packages/f5/f8/141dcb0f88e9452af8d20d14dd53aab5937222a2bb4f2c04bfed6829263c/tokenizers-0.19.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:7fb297edec6c6841ab2e4e8f357209519188e4a59b557ea4fafcf4691d1b4c98", size = 9634084 }, { url = "https://files.pythonhosted.org/packages/2e/be/debb7caa3f88ed54015170db16e07aa3a5fea2d3983d0dde92f98d888dc8/tokenizers-0.19.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:2e8a3dd055e515df7054378dc9d6fa8c8c34e1f32777fb9a01fea81496b3f9d3", size = 9949480 }, - { url = "https://files.pythonhosted.org/packages/7a/e7/26bedf5d270d293d572a90bd66b0b030012aedb95d8ee87e8bcd446b76fb/tokenizers-0.19.1-cp310-none-win32.whl", hash = "sha256:7ff898780a155ea053f5d934925f3902be2ed1f4d916461e1a93019cc7250837", size = 2041462 }, - { url = "https://files.pythonhosted.org/packages/f4/85/d999b9a05fd101d48f1a365d68be0b109277bb25c89fb37a389d669f9185/tokenizers-0.19.1-cp310-none-win_amd64.whl", hash = "sha256:bea6f9947e9419c2fda21ae6c32871e3d398cba549b93f4a65a2d369662d9403", size = 2220036 }, - { url = "https://files.pythonhosted.org/packages/c8/d6/6e1d728d765eb4102767f071bf7f6439ab10d7f4a975c9217db65715207a/tokenizers-0.19.1-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:5c88d1481f1882c2e53e6bb06491e474e420d9ac7bdff172610c4f9ad3898059", size = 2533448 }, - { url = "https://files.pythonhosted.org/packages/90/79/d17a0f491d10817cd30f1121a07aa09c8e97a81114b116e473baf1577f09/tokenizers-0.19.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:ddf672ed719b4ed82b51499100f5417d7d9f6fb05a65e232249268f35de5ed14", size = 2440254 }, { url = "https://files.pythonhosted.org/packages/c7/28/2d11c3ff94f9d42eceb2ea549a06e3f166fe391c5a025e5d96fac898a3ac/tokenizers-0.19.1-cp311-cp311-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:dadc509cc8a9fe460bd274c0e16ac4184d0958117cf026e0ea8b32b438171594", size = 3684971 }, { url = "https://files.pythonhosted.org/packages/36/c6/537f22b57e6003904d35d07962dbde2f2e9bdd791d0241da976a4c7f8194/tokenizers-0.19.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:dfedf31824ca4915b511b03441784ff640378191918264268e6923da48104acc", size = 3568894 }, { url = "https://files.pythonhosted.org/packages/af/ef/3c1deed14ec59b2c8e7e2fa27b2a53f7d101181277a43b89ab17d891ef2e/tokenizers-0.19.1-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:ac11016d0a04aa6487b1513a3a36e7bee7eec0e5d30057c9c0408067345c48d2", size = 3426873 }, @@ -1566,10 +1323,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a7/03/fb50fc03f86016b227a967c8d474f90230c885c0d18f78acdfda7a96ce56/tokenizers-0.19.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d16ff18907f4909dca9b076b9c2d899114dd6abceeb074eca0c93e2353f943aa", size = 3608228 }, { url = "https://files.pythonhosted.org/packages/5b/cd/0385e1026e1e03732fd398e964792a3a8433918b166748c82507e014d748/tokenizers-0.19.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:706a37cc5332f85f26efbe2bdc9ef8a9b372b77e4645331a405073e4b3a8c1c6", size = 9633115 }, { url = "https://files.pythonhosted.org/packages/25/50/8f8ad0bbdaf09d04b15e6502d1fa1c653754ed7e016e4ae009726aa1a4e4/tokenizers-0.19.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:16baac68651701364b0289979ecec728546133e8e8fe38f66fe48ad07996b88b", size = 9949062 }, - { url = "https://files.pythonhosted.org/packages/db/11/31be66710f1d14526f3588a441efadeb184e1e68458067007b20ead03c59/tokenizers-0.19.1-cp311-none-win32.whl", hash = "sha256:9ed240c56b4403e22b9584ee37d87b8bfa14865134e3e1c3fb4b2c42fafd3256", size = 2041039 }, - { url = "https://files.pythonhosted.org/packages/65/8e/6d7d72b28f22c422cff8beae10ac3c2e4376b9be721ef8167b7eecd1da62/tokenizers-0.19.1-cp311-none-win_amd64.whl", hash = "sha256:ad57d59341710b94a7d9dbea13f5c1e7d76fd8d9bcd944a7a6ab0b0da6e0cc66", size = 2220386 }, - { url = "https://files.pythonhosted.org/packages/63/90/2890cd096898dcdb596ee172cde40c0f54a9cf43b0736aa260a5501252af/tokenizers-0.19.1-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:621d670e1b1c281a1c9698ed89451395d318802ff88d1fc1accff0867a06f153", size = 2530580 }, - { url = "https://files.pythonhosted.org/packages/74/d1/f4e1e950adb36675dfd8f9d0f4be644f3f3aaf22a5677a4f5c81282b662e/tokenizers-0.19.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:d924204a3dbe50b75630bd16f821ebda6a5f729928df30f582fb5aade90c818a", size = 2436682 }, { url = "https://files.pythonhosted.org/packages/ed/30/89b321a16c58d233e301ec15072c0d3ed5014825e72da98604cd3ab2fba1/tokenizers-0.19.1-cp312-cp312-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:4f3fefdc0446b1a1e6d81cd4c07088ac015665d2e812f6dbba4a06267d1a2c95", size = 3693494 }, { url = "https://files.pythonhosted.org/packages/05/40/fa899f32de483500fbc78befd378fd7afba4270f17db707d1a78c0a4ddc3/tokenizers-0.19.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9620b78e0b2d52ef07b0d428323fb34e8ea1219c5eac98c2596311f20f1f9266", size = 3566541 }, { url = "https://files.pythonhosted.org/packages/67/14/e7da32ae5fb4971830f1ef335932fae3fa57e76b537e852f146c850aefdf/tokenizers-0.19.1-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:04ce49e82d100594715ac1b2ce87d1a36e61891a91de774755f743babcd0dd52", size = 3430792 }, @@ -1578,10 +1331,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/80/54/12047a69f5b382d7ee72044dc89151a2dd0d13b2c9bdcc22654883704d31/tokenizers-0.19.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7c9d5b6c0e7a1e979bec10ff960fae925e947aab95619a6fdb4c1d8ff3708ce3", size = 3610961 }, { url = "https://files.pythonhosted.org/packages/52/b7/1e8a913d18ac28feeda42d4d2d51781874398fb59cd1c1e2653a4b5742ed/tokenizers-0.19.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:a179856d1caee06577220ebcfa332af046d576fb73454b8f4d4b0ba8324423ea", size = 9631367 }, { url = "https://files.pythonhosted.org/packages/ac/3d/2284f6d99f8f21d09352b88b8cfefa24ab88468d962aeb0aa15c20d76b32/tokenizers-0.19.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:952b80dac1a6492170f8c2429bd11fcaa14377e097d12a1dbe0ef2fb2241e16c", size = 9950121 }, - { url = "https://files.pythonhosted.org/packages/2a/94/ec3369dbc9b7200c14c8c7a1a04c78b7a7398d0c001e1b7d1ffe30eb93a0/tokenizers-0.19.1-cp312-none-win32.whl", hash = "sha256:01d62812454c188306755c94755465505836fd616f75067abcae529c35edeb57", size = 2044069 }, - { url = "https://files.pythonhosted.org/packages/0c/97/80bff6937e0c67d30c0facacd4f0bcf4254e581aa4995c73cef8c8640e56/tokenizers-0.19.1-cp312-none-win_amd64.whl", hash = "sha256:b70bfbe3a82d3e3fb2a5e9b22a39f8d1740c96c68b6ace0086b39074f08ab89a", size = 2214527 }, - { url = "https://files.pythonhosted.org/packages/1a/ed/42801618bab16c79d6bd222977c212dba5770e6c935ba53728b731653a3d/tokenizers-0.19.1-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:0b9394bd204842a2a1fd37fe29935353742be4a3460b6ccbaefa93f58a8df43d", size = 2533937 }, - { url = "https://files.pythonhosted.org/packages/0a/2b/4e5718e806ff23e5e758e02bd4b34967b5218f085b0c189335fd27c14dc1/tokenizers-0.19.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:4692ab92f91b87769d950ca14dbb61f8a9ef36a62f94bad6c82cc84a51f76f6a", size = 2440312 }, { url = "https://files.pythonhosted.org/packages/c5/28/ac2a277bd23b631e1ff986182c4fcb9028ccc7ff7c07743ef906fa5389e7/tokenizers-0.19.1-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:6258c2ef6f06259f70a682491c78561d492e885adeaf9f64f5389f78aa49a051", size = 3686532 }, { url = "https://files.pythonhosted.org/packages/ba/26/139bd2371228a0e203da7b3e3eddcb02f45b2b7edd91df00e342e4b55e13/tokenizers-0.19.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c85cf76561fbd01e0d9ea2d1cbe711a65400092bc52b5242b16cfd22e51f0c58", size = 3570575 }, { url = "https://files.pythonhosted.org/packages/3b/6b/98383dff29416127c73dc196844ed23e29d790f1ad4b4ecf69d45e03841d/tokenizers-0.19.1-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:670b802d4d82bbbb832ddb0d41df7015b3e549714c0e77f9bed3e74d42400fbe", size = 3425806 }, @@ -1590,17 +1339,11 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/0f/cb/8fc733c8f251bac1e5c4ae52458c353b3faa98f41d734c226cad3783da03/tokenizers-0.19.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b4c89aa46c269e4e70c4d4f9d6bc644fcc39bb409cb2a81227923404dd6f5227", size = 3608229 }, { url = "https://files.pythonhosted.org/packages/76/05/badd3a66571ad257270b38c33b9a7470afd2ae12e409c7c74baedf16f2ef/tokenizers-0.19.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:39c1ec76ea1027438fafe16ecb0fb84795e62e9d643444c1090179e63808c69d", size = 9634933 }, { url = "https://files.pythonhosted.org/packages/d9/46/97f8e84ba6a9133e34b148631d2933fda2a6ad8e0767b6e07ad0af9d83c2/tokenizers-0.19.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:c2a0d47a89b48d7daa241e004e71fb5a50533718897a4cd6235cb846d511a478", size = 9950957 }, - { url = "https://files.pythonhosted.org/packages/81/b2/bf9a0f9136964df5e94dd9854ba071480c5425ff0db6d1ad9a6a8e683d55/tokenizers-0.19.1-cp39-none-win32.whl", hash = "sha256:61b7fe8886f2e104d4caf9218b157b106207e0f2a4905c9c7ac98890688aabeb", size = 2040628 }, - { url = "https://files.pythonhosted.org/packages/25/aa/c6992cdc0a74bcbb666e7c00ada6826f5b49fc4cbdafc50db0d1369503fe/tokenizers-0.19.1-cp39-none-win_amd64.whl", hash = "sha256:f97660f6c43efd3e0bfd3f2e3e5615bf215680bad6ee3d469df6454b8c6e8256", size = 2220919 }, - { url = "https://files.pythonhosted.org/packages/cf/7b/38fb7207cde3d1dc5272411cd18178e6437cdc1ef08cac5d0e8cfd57f38c/tokenizers-0.19.1-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:3b11853f17b54c2fe47742c56d8a33bf49ce31caf531e87ac0d7d13d327c9334", size = 2532668 }, - { url = "https://files.pythonhosted.org/packages/1d/0d/2c452fe17fc17f0cdb713acb811eebb1f714b8c21d497c4672af4f491229/tokenizers-0.19.1-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:d26194ef6c13302f446d39972aaa36a1dda6450bc8949f5eb4c27f51191375bd", size = 2438321 }, { url = "https://files.pythonhosted.org/packages/19/e0/f9e915d028b45798723eab59c253da28040aa66b9f31dcb7cfc3be88fa37/tokenizers-0.19.1-pp310-pypy310_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:e8d1ed93beda54bbd6131a2cb363a576eac746d5c26ba5b7556bc6f964425594", size = 3682304 }, { url = "https://files.pythonhosted.org/packages/ce/2b/db8a94608c392752681c2ca312487b7cd5bcc4f77e24a90daa4916138271/tokenizers-0.19.1-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ca407133536f19bdec44b3da117ef0d12e43f6d4b56ac4c765f37eca501c7bda", size = 3566208 }, { url = "https://files.pythonhosted.org/packages/d8/58/2e998462677c4c0eb5123ce386bcb488a155664d273d0283122866515f09/tokenizers-0.19.1-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ce05fde79d2bc2e46ac08aacbc142bead21614d937aac950be88dc79f9db9022", size = 3605791 }, { url = "https://files.pythonhosted.org/packages/83/ac/26bc2e2bb2a054dc2e51699628936f5474e093b68da6ccdde04b2fc39ab8/tokenizers-0.19.1-pp310-pypy310_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:35583cd46d16f07c054efd18b5d46af4a2f070a2dd0a47914e66f3ff5efb2b1e", size = 9632867 }, { url = "https://files.pythonhosted.org/packages/45/b6/36c1bb106bbe96012c9367df89ed01599cada036c0b96d38fbbdbeb75c9f/tokenizers-0.19.1-pp310-pypy310_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:43350270bfc16b06ad3f6f07eab21f089adb835544417afda0f83256a8bf8b75", size = 9945103 }, - { url = "https://files.pythonhosted.org/packages/aa/9c/deed1e549b767832cc4ee5b386d1660bde3408bbd6d1ab48352fb61c54e2/tokenizers-0.19.1-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:56ae39d4036b753994476a1b935584071093b55c7a72e3b8288e68c313ca26e7", size = 2533737 }, - { url = "https://files.pythonhosted.org/packages/c8/59/4dbebca9ef6b61d10a94cbf404d3abf509dfedb52cdcf2fe7ed1fb52460d/tokenizers-0.19.1-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:f9939ca7e58c2758c01b40324a59c034ce0cebad18e0d4563a9b1beab3018243", size = 2439981 }, { url = "https://files.pythonhosted.org/packages/72/42/e18b67ab9fd31e433171cf447d85bf5dede8009db04a46f3905bff5ca715/tokenizers-0.19.1-pp39-pypy39_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:6c330c0eb815d212893c67a032e9dc1b38a803eccb32f3e8172c19cc69fbb439", size = 3683158 }, { url = "https://files.pythonhosted.org/packages/08/5c/54419545d61c085d7adcbd54f5711815ffbb1164d6132209172c984320be/tokenizers-0.19.1-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ec11802450a2487cdf0e634b750a04cbdc1c4d066b97d94ce7dd2cb51ebb325b", size = 3568486 }, { url = "https://files.pythonhosted.org/packages/6d/61/f8b59cc2580297ca78a7b5b2cefc8996b8417dc6cb9abb6a1d303973156b/tokenizers-0.19.1-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a2b718f316b596f36e1dae097a7d5b91fc5b85e90bf08b01ff139bd8953b25af", size = 3608836 }, @@ -1619,7 +1362,7 @@ wheels = [ [[package]] name = "torch" -version = "2.5.0.dev20240912+cu124" +version = "2.6.0.dev20240924+cu124" source = { registry = "https://download.pytorch.org/whl/nightly/cu124" } dependencies = [ { name = "filelock", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, @@ -1644,24 +1387,20 @@ dependencies = [ { name = "typing-extensions", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, ] wheels = [ - { url = "https://download.pytorch.org/whl/nightly/cu124/torch-2.5.0.dev20240912%2Bcu124-cp310-cp310-linux_aarch64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cu124/torch-2.5.0.dev20240912%2Bcu124-cp310-cp310-linux_x86_64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cu124/torch-2.5.0.dev20240912%2Bcu124-cp310-cp310-win_amd64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cu124/torch-2.5.0.dev20240912%2Bcu124-cp311-cp311-linux_aarch64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cu124/torch-2.5.0.dev20240912%2Bcu124-cp311-cp311-linux_x86_64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cu124/torch-2.5.0.dev20240912%2Bcu124-cp311-cp311-win_amd64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cu124/torch-2.5.0.dev20240912%2Bcu124-cp312-cp312-linux_aarch64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cu124/torch-2.5.0.dev20240912%2Bcu124-cp312-cp312-linux_x86_64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cu124/torch-2.5.0.dev20240912%2Bcu124-cp312-cp312-win_amd64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cu124/torch-2.5.0.dev20240912%2Bcu124-cp313-cp313-linux_x86_64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cu124/torch-2.5.0.dev20240912%2Bcu124-cp39-cp39-linux_aarch64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cu124/torch-2.5.0.dev20240912%2Bcu124-cp39-cp39-linux_x86_64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cu124/torch-2.5.0.dev20240912%2Bcu124-cp39-cp39-win_amd64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/cu124/torch-2.6.0.dev20240924%2Bcu124-cp310-cp310-linux_aarch64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/cu124/torch-2.6.0.dev20240924%2Bcu124-cp310-cp310-linux_x86_64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/cu124/torch-2.6.0.dev20240924%2Bcu124-cp311-cp311-linux_aarch64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/cu124/torch-2.6.0.dev20240924%2Bcu124-cp311-cp311-linux_x86_64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/cu124/torch-2.6.0.dev20240924%2Bcu124-cp312-cp312-linux_aarch64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/cu124/torch-2.6.0.dev20240924%2Bcu124-cp312-cp312-linux_x86_64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/cu124/torch-2.6.0.dev20240924%2Bcu124-cp313-cp313-linux_x86_64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/cu124/torch-2.6.0.dev20240924%2Bcu124-cp39-cp39-linux_aarch64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/cu124/torch-2.6.0.dev20240924%2Bcu124-cp39-cp39-linux_x86_64.whl" }, ] [[package]] name = "torch-tensorrt" -version = "2.5.0.dev0+4f32e93bb" +version = "2.6.0.dev0+0de0b1651" source = { editable = "." } dependencies = [ { name = "numpy", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, @@ -1713,7 +1452,7 @@ requires-dist = [ { name = "tensorrt-cu12", specifier = "==10.3.0" }, { name = "tensorrt-cu12-bindings", specifier = "==10.3.0" }, { name = "tensorrt-cu12-libs", specifier = "==10.3.0" }, - { name = "torch", specifier = "<2.6.0,>=2.5.0.dev0" }, + { name = "torch", specifier = ">=2.6.0.dev0,<2.7.0" }, { name = "torchvision", marker = "extra == 'torchvision'" }, { name = "typing-extensions", specifier = ">=4.7.0" }, ] @@ -1759,11 +1498,8 @@ dependencies = [ ] wheels = [ { url = "https://files.pythonhosted.org/packages/48/20/380758a94be49d38798a6cfd25824f72ec1f230b00c0014efb15903777c6/torchvision-0.11.3-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:8bc8a7db80c97ca254be362ba883a202192e361ba2f6dff7ff5bb010d4bfc23a", size = 14675721 }, - { url = "https://files.pythonhosted.org/packages/59/33/eecbba97ef527f40b25f9cbdc54ddb4f057e4150698615a518f6a75dc546/torchvision-0.11.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:3380211bf061d114c380f52fb33f55d2fbe483e2fd297f6aa596803f7cbdb408", size = 1187105 }, - { url = "https://files.pythonhosted.org/packages/4d/cb/d3bf0ffa1bdf83ee2fcd360f9794e48687831655cda1247eae4c7309e099/torchvision-0.11.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:a3997b63bd8fac985323b6068e689c9617b0b36e1126616f7b380e17c501aefa", size = 585749 }, { url = "https://files.pythonhosted.org/packages/ac/b1/9702d02e233bec7ce231cc8be94489ee31084fb6d350703f0ed22086ebed/torchvision-0.11.3-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:eca0b0f7a0e462bdecf7926d89faae6dcd51da418ca0cf70e725981ed775a11b", size = 23199346 }, { url = "https://files.pythonhosted.org/packages/ac/d3/913e25d7775c74f76d174a82eba45bf68e384dc78373598f6c2b3a727fed/torchvision-0.11.3-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:25e72231be8ce03467a77806d9c3f5fd34b9cd23b9543d3e999bf57622377532", size = 14674764 }, - { url = "https://files.pythonhosted.org/packages/8b/68/5a976d601c11f527cb278dbd510521e3d6e192d7c5fd60471e64d1c84c25/torchvision-0.11.3-cp39-cp39-win_amd64.whl", hash = "sha256:5263770a9a91011206b3566b33bbba040b92932885c63cfe5ac9c720ed1fdaca", size = 947974 }, ] [[package]] @@ -1822,15 +1558,11 @@ version = "1.24.6" source = { registry = "https://pypi.org/simple" } sdist = { url = "https://files.pythonhosted.org/packages/92/44/0d3f5a153919bc757573fe89200ae77609a440b1b774d04e5f816839ee58/typos-1.24.6.tar.gz", hash = "sha256:0feda2aab59fc1c32cd1f382ea8676b4ef0921086ab172a43e69e5bb19206993", size = 1107518 } wheels = [ - { url = "https://files.pythonhosted.org/packages/ad/6b/75bf4f3de20c5edc17919b55b592dffade154ccee7c65512b2f506514082/typos-1.24.6-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:be576cd0afcbf72bd0fa4129d457b146627c837db189eae7ee83b9fc311dacef", size = 3523299 }, - { url = "https://files.pythonhosted.org/packages/b8/d0/c4f711a402c938c87ea3ee2bc173bb6ecb1a9869b662245eb97c56571f18/typos-1.24.6-py3-none-macosx_11_0_arm64.whl", hash = "sha256:49fc10b7d28a6a016678c92a5b3d091ea46a2a7e09d5d1122045e8509378f785", size = 3447219 }, { url = "https://files.pythonhosted.org/packages/b1/a9/d63c5be9eb7a0105ae6b1257c6bae98595d52198b8c53359f6b93618d9f8/typos-1.24.6-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:dfbd7c40af229d680c2b9bc90e846eea70626bde9608f77a57c4e72145a5aa5f", size = 4893439 }, { url = "https://files.pythonhosted.org/packages/38/56/3aebbf2f950a15343396cbfc4773d4de5dc632867210051516cb4faef83f/typos-1.24.6-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b8eb05826d6ff1f8747e1c7d9991a10e13b644b2eb7e2855cc79a37ebb1104f1", size = 3532218 }, { url = "https://files.pythonhosted.org/packages/3b/6f/d8ecddc82501b01a0e956173060c3b216bc9a7ee4c66d5b5c07dd02305c6/typos-1.24.6-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e2abbe9dc208f6da9fddbf9bb281a3944d66188df9b3d43ad6f2f99721713446", size = 4215931 }, { url = "https://files.pythonhosted.org/packages/86/09/e0cf7945287e4d7a61403237a1df2def419734a6a8e41f06829187f96ee2/typos-1.24.6-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:7291555c82e81e305ab3e10cb04d0f7d49ccecc1ced322c60f4619f6a14c7225", size = 3892579 }, { url = "https://files.pythonhosted.org/packages/1f/02/336eb3315e1d53a780d98511b465cd801be9151294bcbb50a603d503b5de/typos-1.24.6-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:7e27c307c26549a7986f2701f161358df29543e818bf9d6d81f0a81ca5ddeff5", size = 4224637 }, - { url = "https://files.pythonhosted.org/packages/71/b3/fafc840ea619798a7f98765c76eeb13e9227175509172f7a730cbd633396/typos-1.24.6-py3-none-win32.whl", hash = "sha256:cd725db3823c319f7e97b4e8e9fa4af143568b1c7d834f66c584bf86b9691f94", size = 2431615 }, - { url = "https://files.pythonhosted.org/packages/93/ec/2842da6226fe6c3c5a030364f65d99901c06507f6575c5f8250bc3602d9f/typos-1.24.6-py3-none-win_amd64.whl", hash = "sha256:12972e7a8be14fe5e7f0392de0b228a0098748959d1fecc35c4e8eab3efc04c0", size = 2568857 }, ] [[package]]