diff --git a/manim/camera/camera.py b/manim/camera/camera.py index af5899c5c5..9da835a01c 100644 --- a/manim/camera/camera.py +++ b/manim/camera/camera.py @@ -10,17 +10,19 @@ import pathlib from collections.abc import Iterable from functools import reduce -from typing import Any, Callable +from typing import TYPE_CHECKING, Any, Callable import cairo import numpy as np from PIL import Image from scipy.spatial.distance import pdist +from typing_extensions import Self + +from manim.typing import PixelArray from .. import config, logger from ..constants import * from ..mobject.mobject import Mobject -from ..mobject.types.image_mobject import AbstractImageMobject from ..mobject.types.point_cloud_mobject import PMobject from ..mobject.types.vectorized_mobject import VMobject from ..utils.color import ManimColor, ParsableManimColor, color_to_int_rgba @@ -29,6 +31,10 @@ from ..utils.iterables import list_difference_update from ..utils.space_ops import angle_of_vector +if TYPE_CHECKING: + from ..mobject.types.image_mobject import AbstractImageMobject + + LINE_JOIN_MAP = { LineJointType.AUTO: None, # TODO: this could be improved LineJointType.ROUND: cairo.LineJoin.ROUND, @@ -84,8 +90,8 @@ def __init__( frame_rate: float | None = None, background_color: ParsableManimColor | None = None, background_opacity: float | None = None, - **kwargs, - ): + **kwargs: Any, + ) -> None: self.background_image = background_image self.frame_center = frame_center self.image_mode = image_mode @@ -116,11 +122,13 @@ def __init__( self.frame_rate = frame_rate if background_color is None: - self._background_color = ManimColor.parse(config["background_color"]) + self._background_color: ManimColor = ManimColor.parse( + config["background_color"] + ) else: self._background_color = ManimColor.parse(background_color) if background_opacity is None: - self._background_opacity = config["background_opacity"] + self._background_opacity: float = config["background_opacity"] else: self._background_opacity = background_opacity @@ -129,7 +137,7 @@ def __init__( self.max_allowable_norm = config["frame_width"] self.rgb_max_val = np.iinfo(self.pixel_array_dtype).max - self.pixel_array_to_cairo_context = {} + self.pixel_array_to_cairo_context: dict[int, cairo.Context] = {} # Contains the correct method to process a list of Mobjects of the # corresponding class. If a Mobject is not an instance of a class in @@ -140,7 +148,7 @@ def __init__( self.resize_frame_shape() self.reset() - def __deepcopy__(self, memo): + def __deepcopy__(self, memo: Any) -> Camera: # This is to address a strange bug where deepcopying # will result in a segfault, which is somehow related # to the aggdraw library @@ -148,24 +156,26 @@ def __deepcopy__(self, memo): return copy.copy(self) @property - def background_color(self): + def background_color(self) -> ManimColor: return self._background_color @background_color.setter - def background_color(self, color): + def background_color(self, color: ManimColor) -> None: self._background_color = color self.init_background() @property - def background_opacity(self): + def background_opacity(self) -> float: return self._background_opacity @background_opacity.setter - def background_opacity(self, alpha): + def background_opacity(self, alpha: float) -> None: self._background_opacity = alpha self.init_background() - def type_or_raise(self, mobject: Mobject): + def type_or_raise( + self, mobject: Mobject + ) -> type[VMobject] | type[PMobject] | type[AbstractImageMobject] | type[Mobject]: """Return the type of mobject, if it is a type that can be rendered. If `mobject` is an instance of a class that inherits from a class that @@ -192,6 +202,8 @@ def type_or_raise(self, mobject: Mobject): :exc:`TypeError` When mobject is not an instance of a class that can be rendered. """ + from ..mobject.types.image_mobject import AbstractImageMobject + self.display_funcs = { VMobject: self.display_multiple_vectorized_mobjects, PMobject: self.display_multiple_point_cloud_mobjects, @@ -206,7 +218,7 @@ def type_or_raise(self, mobject: Mobject): return _type raise TypeError(f"Displaying an object of class {_type} is not supported") - def reset_pixel_shape(self, new_height: float, new_width: float): + def reset_pixel_shape(self, new_height: float, new_width: float) -> None: """This method resets the height and width of a single pixel to the passed new_height and new_width. @@ -223,7 +235,7 @@ def reset_pixel_shape(self, new_height: float, new_width: float): self.resize_frame_shape() self.reset() - def resize_frame_shape(self, fixed_dimension: int = 0): + def resize_frame_shape(self, fixed_dimension: int = 0) -> None: """ Changes frame_shape to match the aspect ratio of the pixels, where fixed_dimension determines @@ -248,7 +260,7 @@ def resize_frame_shape(self, fixed_dimension: int = 0): self.frame_height = frame_height self.frame_width = frame_width - def init_background(self): + def init_background(self) -> None: """Initialize the background. If self.background_image is the path of an image the image is set as background; else, the default @@ -274,7 +286,9 @@ def init_background(self): ) self.background[:, :] = background_rgba - def get_image(self, pixel_array: np.ndarray | list | tuple | None = None): + def get_image( + self, pixel_array: PixelArray | list | tuple | None = None + ) -> PixelArray: """Returns an image from the passed pixel array, or from the current frame if the passed pixel array is none. @@ -294,8 +308,8 @@ def get_image(self, pixel_array: np.ndarray | list | tuple | None = None): return Image.fromarray(pixel_array, mode=self.image_mode) def convert_pixel_array( - self, pixel_array: np.ndarray | list | tuple, convert_from_floats: bool = False - ): + self, pixel_array: PixelArray | list | tuple, convert_from_floats: bool = False + ) -> PixelArray: """Converts a pixel array from values that have floats in then to proper RGB values. @@ -321,8 +335,8 @@ def convert_pixel_array( return retval def set_pixel_array( - self, pixel_array: np.ndarray | list | tuple, convert_from_floats: bool = False - ): + self, pixel_array: PixelArray | list | tuple, convert_from_floats: bool = False + ) -> None: """Sets the pixel array of the camera to the passed pixel array. Parameters @@ -332,19 +346,21 @@ def set_pixel_array( convert_from_floats Whether or not to convert float values to proper RGB values, by default False """ - converted_array = self.convert_pixel_array(pixel_array, convert_from_floats) + converted_array: PixelArray = self.convert_pixel_array( + pixel_array, convert_from_floats + ) if not ( hasattr(self, "pixel_array") and self.pixel_array.shape == converted_array.shape ): - self.pixel_array = converted_array + self.pixel_array: PixelArray = converted_array else: # Set in place self.pixel_array[:, :, :] = converted_array[:, :, :] def set_background( - self, pixel_array: np.ndarray | list | tuple, convert_from_floats: bool = False - ): + self, pixel_array: PixelArray | list | tuple, convert_from_floats: bool = False + ) -> None: """Sets the background to the passed pixel_array after converting to valid RGB values. @@ -360,7 +376,7 @@ def set_background( # TODO, this should live in utils, not as a method of Camera def make_background_from_func( self, coords_to_colors_func: Callable[[np.ndarray], np.ndarray] - ): + ) -> PixelArray: """ Makes a pixel array for the background by using coords_to_colors_func to determine each pixel's color. Each input pixel's color. Each input to coords_to_colors_func is an (x, y) pair in space (in ordinary space coordinates; not @@ -386,7 +402,7 @@ def make_background_from_func( def set_background_from_func( self, coords_to_colors_func: Callable[[np.ndarray], np.ndarray] - ): + ) -> None: """ Sets the background to a pixel array using coords_to_colors_func to determine each pixel's color. Each input pixel's color. Each input to coords_to_colors_func is an (x, y) pair in space (in ordinary space coordinates; not @@ -400,7 +416,7 @@ def set_background_from_func( """ self.set_background(self.make_background_from_func(coords_to_colors_func)) - def reset(self): + def reset(self) -> Self: """Resets the camera's pixel array to that of the background @@ -412,7 +428,7 @@ def reset(self): self.set_pixel_array(self.background) return self - def set_frame_to_background(self, background): + def set_frame_to_background(self, background: PixelArray) -> None: self.set_pixel_array(background) #### @@ -422,7 +438,7 @@ def get_mobjects_to_display( mobjects: Iterable[Mobject], include_submobjects: bool = True, excluded_mobjects: list | None = None, - ): + ) -> list[Mobject]: """Used to get the list of mobjects to display with the camera. @@ -454,7 +470,7 @@ def get_mobjects_to_display( mobjects = list_difference_update(mobjects, all_excluded) return list(mobjects) - def is_in_frame(self, mobject: Mobject): + def is_in_frame(self, mobject: Mobject) -> bool: """Checks whether the passed mobject is in frame or not. @@ -481,7 +497,7 @@ def is_in_frame(self, mobject: Mobject): ], ) - def capture_mobject(self, mobject: Mobject, **kwargs: Any): + def capture_mobject(self, mobject: Mobject, **kwargs: Any) -> None: """Capture mobjects by storing it in :attr:`pixel_array`. This is a single-mobject version of :meth:`capture_mobjects`. @@ -497,7 +513,7 @@ def capture_mobject(self, mobject: Mobject, **kwargs: Any): """ return self.capture_mobjects([mobject], **kwargs) - def capture_mobjects(self, mobjects: Iterable[Mobject], **kwargs): + def capture_mobjects(self, mobjects: Iterable[Mobject], **kwargs: Any) -> None: """Capture mobjects by printing them on :attr:`pixel_array`. This is the essential function that converts the contents of a Scene @@ -525,14 +541,16 @@ def capture_mobjects(self, mobjects: Iterable[Mobject], **kwargs): # partition while at the same time preserving order. mobjects = self.get_mobjects_to_display(mobjects, **kwargs) for group_type, group in it.groupby(mobjects, self.type_or_raise): - self.display_funcs[group_type](list(group), self.pixel_array) + # TODO + # error: Call to untyped function (unknown) in typed context [no-untyped-call] + self.display_funcs[group_type](list(group), self.pixel_array) # type: ignore[no-untyped-call] # Methods associated with svg rendering # NOTE: None of the methods below have been mentioned outside of their definitions. Their DocStrings are not as # detailed as possible. - def get_cached_cairo_context(self, pixel_array: np.ndarray): + def get_cached_cairo_context(self, pixel_array: PixelArray) -> cairo.Context: """Returns the cached cairo context of the passed pixel array if it exists, and None if it doesn't. @@ -548,7 +566,7 @@ def get_cached_cairo_context(self, pixel_array: np.ndarray): """ return self.pixel_array_to_cairo_context.get(id(pixel_array), None) - def cache_cairo_context(self, pixel_array: np.ndarray, ctx: cairo.Context): + def cache_cairo_context(self, pixel_array: PixelArray, ctx: cairo.Context) -> None: """Caches the passed Pixel array into a Cairo Context Parameters @@ -560,7 +578,7 @@ def cache_cairo_context(self, pixel_array: np.ndarray, ctx: cairo.Context): """ self.pixel_array_to_cairo_context[id(pixel_array)] = ctx - def get_cairo_context(self, pixel_array: np.ndarray): + def get_cairo_context(self, pixel_array: PixelArray) -> cairo.Context: """Returns the cairo context for a pixel array after caching it to self.pixel_array_to_cairo_context If that array has already been cached, it returns the @@ -606,8 +624,8 @@ def get_cairo_context(self, pixel_array: np.ndarray): return ctx def display_multiple_vectorized_mobjects( - self, vmobjects: list, pixel_array: np.ndarray - ): + self, vmobjects: list[VMobject], pixel_array: PixelArray + ) -> None: """Displays multiple VMobjects in the pixel_array Parameters @@ -630,8 +648,8 @@ def display_multiple_vectorized_mobjects( ) def display_multiple_non_background_colored_vmobjects( - self, vmobjects: list, pixel_array: np.ndarray - ): + self, vmobjects: Iterable[VMobject], pixel_array: PixelArray + ) -> None: """Displays multiple VMobjects in the cairo context, as long as they don't have background colors. @@ -646,7 +664,7 @@ def display_multiple_non_background_colored_vmobjects( for vmobject in vmobjects: self.display_vectorized(vmobject, ctx) - def display_vectorized(self, vmobject: VMobject, ctx: cairo.Context): + def display_vectorized(self, vmobject: VMobject, ctx: cairo.Context) -> Self: """Displays a VMobject in the cairo context Parameters @@ -667,7 +685,7 @@ def display_vectorized(self, vmobject: VMobject, ctx: cairo.Context): self.apply_stroke(ctx, vmobject) return self - def set_cairo_context_path(self, ctx: cairo.Context, vmobject: VMobject): + def set_cairo_context_path(self, ctx: cairo.Context, vmobject: VMobject) -> Self: """Sets a path for the cairo context with the vmobject passed Parameters @@ -686,7 +704,9 @@ def set_cairo_context_path(self, ctx: cairo.Context, vmobject: VMobject): # TODO, shouldn't this be handled in transform_points_pre_display? # points = points - self.get_frame_center() if len(points) == 0: - return + # TODO: + # Here the return value is modified. Is that ok? + return self ctx.new_path() subpaths = vmobject.gen_subpaths_from_points_2d(points) @@ -702,8 +722,8 @@ def set_cairo_context_path(self, ctx: cairo.Context, vmobject: VMobject): return self def set_cairo_context_color( - self, ctx: cairo.Context, rgbas: np.ndarray, vmobject: VMobject - ): + self, ctx: cairo.Context, rgbas: PixelArray, vmobject: VMobject + ) -> Self: """Sets the color of the cairo context Parameters @@ -735,7 +755,7 @@ def set_cairo_context_color( ctx.set_source(pat) return self - def apply_fill(self, ctx: cairo.Context, vmobject: VMobject): + def apply_fill(self, ctx: cairo.Context, vmobject: VMobject) -> Self: """Fills the cairo context Parameters @@ -756,7 +776,7 @@ def apply_fill(self, ctx: cairo.Context, vmobject: VMobject): def apply_stroke( self, ctx: cairo.Context, vmobject: VMobject, background: bool = False - ): + ) -> Self: """Applies a stroke to the VMobject in the cairo context. Parameters @@ -795,7 +815,9 @@ def apply_stroke( ctx.stroke_preserve() return self - def get_stroke_rgbas(self, vmobject: VMobject, background: bool = False): + def get_stroke_rgbas( + self, vmobject: VMobject, background: bool = False + ) -> PixelArray: """Gets the RGBA array for the stroke of the passed VMobject. @@ -814,7 +836,7 @@ def get_stroke_rgbas(self, vmobject: VMobject, background: bool = False): """ return vmobject.get_stroke_rgbas(background) - def get_fill_rgbas(self, vmobject: VMobject): + def get_fill_rgbas(self, vmobject: VMobject) -> PixelArray: """Returns the RGBA array of the fill of the passed VMobject Parameters @@ -829,13 +851,15 @@ def get_fill_rgbas(self, vmobject: VMobject): """ return vmobject.get_fill_rgbas() - def get_background_colored_vmobject_displayer(self): + def get_background_colored_vmobject_displayer( + self, + ) -> BackgroundColoredVMobjectDisplayer: """Returns the background_colored_vmobject_displayer if it exists or makes one and returns it if not. Returns ------- - BackGroundColoredVMobjectDisplayer + BackgroundColoredVMobjectDisplayer Object that displays VMobjects that have the same color as the background. """ @@ -843,11 +867,11 @@ def get_background_colored_vmobject_displayer(self): bcvd = "background_colored_vmobject_displayer" if not hasattr(self, bcvd): setattr(self, bcvd, BackgroundColoredVMobjectDisplayer(self)) - return getattr(self, bcvd) + return getattr(self, bcvd) # type: ignore[no-any-return] def display_multiple_background_colored_vmobjects( - self, cvmobjects: list, pixel_array: np.ndarray - ): + self, cvmobjects: Iterable[VMobject], pixel_array: PixelArray + ) -> Self: """Displays multiple vmobjects that have the same color as the background. Parameters @@ -873,8 +897,8 @@ def display_multiple_background_colored_vmobjects( # As a result, the other methods do not have as detailed docstrings as would be preferred. def display_multiple_point_cloud_mobjects( - self, pmobjects: list, pixel_array: np.ndarray - ): + self, pmobjects: list, pixel_array: PixelArray + ) -> None: """Displays multiple PMobjects by modifying the passed pixel array. Parameters @@ -899,8 +923,8 @@ def display_point_cloud( points: list, rgbas: np.ndarray, thickness: float, - pixel_array: np.ndarray, - ): + pixel_array: PixelArray, + ) -> None: """Displays a PMobject by modifying the pixel array suitably. TODO: Write a description for the rgbas argument. @@ -948,7 +972,7 @@ def display_point_cloud( def display_multiple_image_mobjects( self, image_mobjects: list, pixel_array: np.ndarray - ): + ) -> None: """Displays multiple image mobjects by modifying the passed pixel_array. Parameters @@ -963,7 +987,7 @@ def display_multiple_image_mobjects( def display_image_mobject( self, image_mobject: AbstractImageMobject, pixel_array: np.ndarray - ): + ) -> None: """Displays an ImageMobject by changing the pixel_array suitably. Parameters @@ -1020,7 +1044,9 @@ def display_image_mobject( # Paint on top of existing pixel array self.overlay_PIL_image(pixel_array, full_image) - def overlay_rgba_array(self, pixel_array: np.ndarray, new_array: np.ndarray): + def overlay_rgba_array( + self, pixel_array: np.ndarray, new_array: np.ndarray + ) -> None: """Overlays an RGBA array on top of the given Pixel array. Parameters @@ -1032,7 +1058,7 @@ def overlay_rgba_array(self, pixel_array: np.ndarray, new_array: np.ndarray): """ self.overlay_PIL_image(pixel_array, self.get_image(new_array)) - def overlay_PIL_image(self, pixel_array: np.ndarray, image: Image): + def overlay_PIL_image(self, pixel_array: np.ndarray, image: Image) -> None: """Overlays a PIL image on the passed pixel array. Parameters @@ -1047,7 +1073,7 @@ def overlay_PIL_image(self, pixel_array: np.ndarray, image: Image): dtype="uint8", ) - def adjust_out_of_range_points(self, points: np.ndarray): + def adjust_out_of_range_points(self, points: np.ndarray) -> np.ndarray: """If any of the points in the passed array are out of the viable range, they are adjusted suitably. @@ -1078,9 +1104,10 @@ def adjust_out_of_range_points(self, points: np.ndarray): def transform_points_pre_display( self, - mobject, - points, - ): # TODO: Write more detailed docstrings for this method. + mobject: Mobject, + points: np.ndarray, + ) -> np.ndarray: + # TODO: Write more detailed docstrings for this method. # NOTE: There seems to be an unused argument `mobject`. # Subclasses (like ThreeDCamera) may want to @@ -1093,9 +1120,9 @@ def transform_points_pre_display( def points_to_pixel_coords( self, - mobject, - points, - ): # TODO: Write more detailed docstrings for this method. + mobject: Mobject, + points: np.ndarray, + ) -> np.ndarray: # TODO: Write more detailed docstrings for this method. points = self.transform_points_pre_display(mobject, points) shifted_points = points - self.frame_center @@ -1115,7 +1142,7 @@ def points_to_pixel_coords( result[:, 1] = shifted_points[:, 1] * height_mult + height_add return result.astype("int") - def on_screen_pixels(self, pixel_coords: np.ndarray): + def on_screen_pixels(self, pixel_coords: np.ndarray) -> PixelArray: """Returns array of pixels that are on the screen from a given array of pixel_coordinates @@ -1154,12 +1181,12 @@ def adjusted_thickness(self, thickness: float) -> float: the camera. """ # TODO: This seems...unsystematic - big_sum = op.add(config["pixel_height"], config["pixel_width"]) - this_sum = op.add(self.pixel_height, self.pixel_width) + big_sum: float = op.add(config["pixel_height"], config["pixel_width"]) + this_sum: float = op.add(self.pixel_height, self.pixel_width) factor = big_sum / this_sum return 1 + (thickness - 1) * factor - def get_thickening_nudges(self, thickness: float): + def get_thickening_nudges(self, thickness: float) -> PixelArray: """Determine a list of vectors used to nudge two-dimensional pixel coordinates. @@ -1176,7 +1203,9 @@ def get_thickening_nudges(self, thickness: float): _range = list(range(-thickness // 2 + 1, thickness // 2 + 1)) return np.array(list(it.product(_range, _range))) - def thickened_coordinates(self, pixel_coords: np.ndarray, thickness: float): + def thickened_coordinates( + self, pixel_coords: np.ndarray, thickness: float + ) -> PixelArray: """Returns thickened coordinates for a passed array of pixel coords and a thickness to thicken by. @@ -1198,7 +1227,7 @@ def thickened_coordinates(self, pixel_coords: np.ndarray, thickness: float): return pixel_coords.reshape((size // 2, 2)) # TODO, reimplement using cairo matrix - def get_coords_of_all_pixels(self): + def get_coords_of_all_pixels(self) -> PixelArray: """Returns the cartesian coordinates of each pixel. Returns @@ -1246,20 +1275,20 @@ class BackgroundColoredVMobjectDisplayer: def __init__(self, camera: Camera): self.camera = camera - self.file_name_to_pixel_array_map = {} + self.file_name_to_pixel_array_map: dict[str, PixelArray] = {} self.pixel_array = np.array(camera.pixel_array) self.reset_pixel_array() - def reset_pixel_array(self): + def reset_pixel_array(self) -> None: self.pixel_array[:, :] = 0 def resize_background_array( self, - background_array: np.ndarray, + background_array: PixelArray, new_width: float, new_height: float, mode: str = "RGBA", - ): + ) -> PixelArray: """Resizes the pixel array representing the background. Parameters @@ -1284,8 +1313,8 @@ def resize_background_array( return np.array(resized_image) def resize_background_array_to_match( - self, background_array: np.ndarray, pixel_array: np.ndarray - ): + self, background_array: PixelArray, pixel_array: PixelArray + ) -> PixelArray: """Resizes the background array to match the passed pixel array. Parameters @@ -1304,7 +1333,9 @@ def resize_background_array_to_match( mode = "RGBA" if pixel_array.shape[2] == 4 else "RGB" return self.resize_background_array(background_array, width, height, mode) - def get_background_array(self, image: Image.Image | pathlib.Path | str): + def get_background_array( + self, image: Image.Image | pathlib.Path | str + ) -> PixelArray: """Gets the background array that has the passed file_name. Parameters @@ -1333,7 +1364,7 @@ def get_background_array(self, image: Image.Image | pathlib.Path | str): self.file_name_to_pixel_array_map[image_key] = back_array return back_array - def display(self, *cvmobjects: VMobject): + def display(self, *cvmobjects: VMobject) -> PixelArray | None: """Displays the colored VMobjects. Parameters diff --git a/manim/camera/moving_camera.py b/manim/camera/moving_camera.py index 1d01d01e22..ba89ee2c22 100644 --- a/manim/camera/moving_camera.py +++ b/manim/camera/moving_camera.py @@ -10,6 +10,9 @@ __all__ = ["MovingCamera"] +from collections.abc import Iterable +from typing import Any + import numpy as np from .. import config @@ -17,7 +20,7 @@ from ..constants import DOWN, LEFT, RIGHT, UP from ..mobject.frame import ScreenRectangle from ..mobject.mobject import Mobject -from ..utils.color import WHITE +from ..utils.color import WHITE, ManimColor class MovingCamera(Camera): @@ -33,11 +36,11 @@ class MovingCamera(Camera): def __init__( self, frame=None, - fixed_dimension=0, # width - default_frame_stroke_color=WHITE, - default_frame_stroke_width=0, - **kwargs, - ): + fixed_dimension: int = 0, # width + default_frame_stroke_color: ManimColor = WHITE, + default_frame_stroke_width: int = 0, + **kwargs: Any, + ) -> None: """ Frame is a Mobject, (should almost certainly be a rectangle) determining which region of space the camera displays @@ -123,7 +126,7 @@ def frame_center(self, frame_center: np.ndarray | list | tuple | Mobject): """ self.frame.move_to(frame_center) - def capture_mobjects(self, mobjects, **kwargs): + def capture_mobjects(self, mobjects: Iterable[Mobject], **kwargs: Any) -> None: # self.reset_frame_center() # self.realign_frame_shape() super().capture_mobjects(mobjects, **kwargs) diff --git a/manim/camera/multi_camera.py b/manim/camera/multi_camera.py index a5202135e9..fe032e4668 100644 --- a/manim/camera/multi_camera.py +++ b/manim/camera/multi_camera.py @@ -5,7 +5,13 @@ __all__ = ["MultiCamera"] -from manim.mobject.types.image_mobject import ImageMobject +from collections.abc import Iterable +from typing import Any + +from typing_extensions import Self + +from manim.mobject.mobject import Mobject +from manim.mobject.types.image_mobject import ImageMobjectFromCamera from ..camera.moving_camera import MovingCamera from ..utils.iterables import list_difference_update @@ -16,10 +22,10 @@ class MultiCamera(MovingCamera): def __init__( self, - image_mobjects_from_cameras: ImageMobject | None = None, - allow_cameras_to_capture_their_own_display=False, - **kwargs, - ): + image_mobjects_from_cameras: Iterable[ImageMobjectFromCamera] | None = None, + allow_cameras_to_capture_their_own_display: bool = False, + **kwargs: Any, + ) -> None: """Initialises the MultiCamera Parameters @@ -29,7 +35,7 @@ def __init__( kwargs Any valid keyword arguments of MovingCamera. """ - self.image_mobjects_from_cameras = [] + self.image_mobjects_from_cameras: list[ImageMobjectFromCamera] = [] if image_mobjects_from_cameras is not None: for imfc in image_mobjects_from_cameras: self.add_image_mobject_from_camera(imfc) @@ -38,7 +44,9 @@ def __init__( ) super().__init__(**kwargs) - def add_image_mobject_from_camera(self, image_mobject_from_camera: ImageMobject): + def add_image_mobject_from_camera( + self, image_mobject_from_camera: ImageMobjectFromCamera + ) -> None: """Adds an ImageMobject that's been obtained from the camera into the list ``self.image_mobject_from_cameras`` @@ -53,11 +61,13 @@ def add_image_mobject_from_camera(self, image_mobject_from_camera: ImageMobject) assert isinstance(imfc.camera, MovingCamera) self.image_mobjects_from_cameras.append(imfc) - def update_sub_cameras(self): + def update_sub_cameras(self) -> None: """Reshape sub_camera pixel_arrays""" for imfc in self.image_mobjects_from_cameras: pixel_height, pixel_width = self.pixel_array.shape[:2] - imfc.camera.frame_shape = ( + # TODO: + # error: "MovingCamera" has no attribute "frame_shape" [attr-defined] + imfc.camera.frame_shape = ( # type: ignore[attr-defined] imfc.camera.frame.height, imfc.camera.frame.width, ) @@ -66,7 +76,7 @@ def update_sub_cameras(self): int(pixel_width * imfc.width / self.frame_width), ) - def reset(self): + def reset(self) -> Self: """Resets the MultiCamera. Returns @@ -79,7 +89,7 @@ def reset(self): super().reset() return self - def capture_mobjects(self, mobjects, **kwargs): + def capture_mobjects(self, mobjects: Iterable[Mobject], **kwargs: Any) -> None: self.update_sub_cameras() for imfc in self.image_mobjects_from_cameras: to_add = list(mobjects) @@ -88,7 +98,7 @@ def capture_mobjects(self, mobjects, **kwargs): imfc.camera.capture_mobjects(to_add, **kwargs) super().capture_mobjects(mobjects, **kwargs) - def get_mobjects_indicating_movement(self): + def get_mobjects_indicating_movement(self) -> list[Mobject]: """Returns all mobjects whose movement implies that the camera should think of all other mobjects on the screen as moving diff --git a/manim/mobject/graphing/coordinate_systems.py b/manim/mobject/graphing/coordinate_systems.py index b21879b90b..24bb42e8b8 100644 --- a/manim/mobject/graphing/coordinate_systems.py +++ b/manim/mobject/graphing/coordinate_systems.py @@ -153,11 +153,12 @@ def __init__( self.x_length = x_length self.y_length = y_length self.num_sampled_graph_points_per_tick = 10 + self.x_axis: NumberLine - def coords_to_point(self, *coords: ManimFloat): + def coords_to_point(self, *coords: ManimFloat) -> Point3DLike: raise NotImplementedError() - def point_to_coords(self, point: Point3DLike): + def point_to_coords(self, point: Point3DLike) -> list[ManimFloat]: raise NotImplementedError() def polar_to_point(self, radius: float, azimuth: float) -> Point2D: @@ -213,7 +214,7 @@ def c2p( """Abbreviation for :meth:`coords_to_point`""" return self.coords_to_point(*coords) - def p2c(self, point: Point3DLike): + def p2c(self, point: Point3DLike) -> list[ManimFloat]: """Abbreviation for :meth:`point_to_coords`""" return self.point_to_coords(point) @@ -221,17 +222,18 @@ def pr2pt(self, radius: float, azimuth: float) -> np.ndarray: """Abbreviation for :meth:`polar_to_point`""" return self.polar_to_point(radius, azimuth) - def pt2pr(self, point: np.ndarray) -> tuple[float, float]: + def pt2pr(self, point: np.ndarray) -> Point2D: """Abbreviation for :meth:`point_to_polar`""" return self.point_to_polar(point) - def get_axes(self): + def get_axes(self) -> VGroup: raise NotImplementedError() - def get_axis(self, index: int) -> Mobject: - return self.get_axes()[index] + def get_axis(self, index: int) -> NumberLine: + val: NumberLine = self.get_axes()[index] + return val - def get_origin(self) -> np.ndarray: + def get_origin(self) -> Point3DLike: """Gets the origin of :class:`~.Axes`. Returns @@ -241,13 +243,13 @@ def get_origin(self) -> np.ndarray: """ return self.coords_to_point(0, 0) - def get_x_axis(self) -> Mobject: + def get_x_axis(self) -> NumberLine: return self.get_axis(0) - def get_y_axis(self) -> Mobject: + def get_y_axis(self) -> NumberLine: return self.get_axis(1) - def get_z_axis(self) -> Mobject: + def get_z_axis(self) -> NumberLine: return self.get_axis(2) def get_x_unit_size(self) -> float: @@ -258,11 +260,11 @@ def get_y_unit_size(self) -> float: def get_x_axis_label( self, - label: float | str | Mobject, + label: float | str | VMobject, edge: Sequence[float] = UR, direction: Sequence[float] = UR, buff: float = SMALL_BUFF, - **kwargs, + **kwargs: Any, ) -> Mobject: """Generate an x-axis label. @@ -301,11 +303,11 @@ def construct(self): def get_y_axis_label( self, - label: float | str | Mobject, + label: float | str | VMobject, edge: Sequence[float] = UR, direction: Sequence[float] = UP * 0.5 + RIGHT, buff: float = SMALL_BUFF, - **kwargs, + **kwargs: Any, ) -> Mobject: """Generate a y-axis label. @@ -347,7 +349,7 @@ def construct(self): def _get_axis_label( self, - label: float | str | Mobject, + label: float | str | VMobject, axis: Mobject, edge: Sequence[float], direction: Sequence[float], @@ -373,12 +375,14 @@ def _get_axis_label( :class:`~.Mobject` The positioned label along the given axis. """ - label = self.x_axis._create_label_tex(label) - label.next_to(axis.get_edge_center(edge), direction=direction, buff=buff) - label.shift_onto_screen(buff=MED_SMALL_BUFF) - return label + label_mobject: Mobject = self.x_axis._create_label_tex(label) + label_mobject.next_to( + axis.get_edge_center(edge), direction=direction, buff=buff + ) + label_mobject.shift_onto_screen(buff=MED_SMALL_BUFF) + return label_mobject - def get_axis_labels(self): + def get_axis_labels(self) -> VGroup: raise NotImplementedError() def add_coordinates( @@ -552,7 +556,7 @@ def construct(self): """ return self.get_line_from_axis_to_point(0, point, **kwargs) - def get_horizontal_line(self, point: Sequence[float], **kwargs) -> Line: + def get_horizontal_line(self, point: Sequence[float], **kwargs: Any) -> Line: """A horizontal line from the y-axis to a given point in the scene. Parameters @@ -584,7 +588,7 @@ def construct(self): """ return self.get_line_from_axis_to_point(1, point, **kwargs) - def get_lines_to_point(self, point: Sequence[float], **kwargs) -> VGroup: + def get_lines_to_point(self, point: Sequence[float], **kwargs: Any) -> VGroup: """Generate both horizontal and vertical lines from the axis to a point. Parameters @@ -1093,7 +1097,7 @@ def i2gp(self, x: float, graph: ParametricFunction) -> np.ndarray: def get_graph_label( self, graph: ParametricFunction, - label: float | str | Mobject = "f(x)", + label: float | str | VMobject = "f(x)", x_val: float | None = None, direction: Sequence[float] = RIGHT, buff: float = MED_SMALL_BUFF, @@ -1150,7 +1154,7 @@ def construct(self): dot_config = {} if color is None: color = graph.get_color() - label = self.x_axis._create_label_tex(label).set_color(color) + label_object: Mobject = self.x_axis._create_label_tex(label).set_color(color) if x_val is None: # Search from right to left @@ -1161,14 +1165,14 @@ def construct(self): else: point = self.input_to_graph_point(x_val, graph) - label.next_to(point, direction, buff=buff) - label.shift_onto_screen() + label_object.next_to(point, direction, buff=buff) + label_object.shift_onto_screen() if dot: dot = Dot(point=point, **dot_config) - label.add(dot) - label.dot = dot - return label + label_object.add(dot) + label_object.dot = dot + return label_object # calculus @@ -1176,14 +1180,14 @@ def get_riemann_rectangles( self, graph: ParametricFunction, x_range: Sequence[float] | None = None, - dx: float | None = 0.1, + dx: float = 0.1, input_sample_type: str = "left", stroke_width: float = 1, stroke_color: ParsableManimColor = BLACK, fill_opacity: float = 1, color: Iterable[ParsableManimColor] | ParsableManimColor = (BLUE, GREEN), show_signed_area: bool = True, - bounded_graph: ParametricFunction = None, + bounded_graph: ParametricFunction | None = None, blend: bool = False, width_scale_factor: float = 1.001, ) -> VGroup: @@ -1277,16 +1281,16 @@ def construct(self): x_range = [*x_range[:2], dx] rectangles = VGroup() - x_range = np.arange(*x_range) + x_range_array = np.arange(*x_range) if isinstance(color, (list, tuple)): color = [ManimColor(c) for c in color] else: color = [ManimColor(color)] - colors = color_gradient(color, len(x_range)) + colors = color_gradient(color, len(x_range_array)) - for x, color in zip(x_range, colors): + for x, color in zip(x_range_array, colors): if input_sample_type == "left": sample_input = x elif input_sample_type == "right": @@ -1341,7 +1345,7 @@ def get_area( x_range: tuple[float, float] | None = None, color: ParsableManimColor | Iterable[ParsableManimColor] = (BLUE, GREEN), opacity: float = 0.3, - bounded_graph: ParametricFunction = None, + bounded_graph: ParametricFunction | None = None, **kwargs: Any, ) -> Polygon: """Returns a :class:`~.Polygon` representing the area under the graph passed. @@ -1485,10 +1489,14 @@ def slope_of_tangent( ax.slope_of_tangent(x=-2, graph=curve) # -3.5000000259052038 """ - return np.tan(self.angle_of_tangent(x, graph, **kwargs)) + val: float = np.tan(self.angle_of_tangent(x, graph, **kwargs)) + return val def plot_derivative_graph( - self, graph: ParametricFunction, color: ParsableManimColor = GREEN, **kwargs + self, + graph: ParametricFunction, + color: ParsableManimColor = GREEN, + **kwargs: Any, ) -> ParametricFunction: """Returns the curve of the derivative of the passed graph. @@ -1526,7 +1534,7 @@ def construct(self): self.add(ax, curves, labels) """ - def deriv(x): + def deriv(x: float) -> float: return self.slope_of_tangent(x, graph) return self.plot(deriv, color=color, **kwargs) @@ -1843,14 +1851,17 @@ def construct(self): return T_label_group - def __matmul__(self, coord: Point3DLike | Mobject): + def __matmul__(self, coord: Point3DLike | Mobject) -> Point3DLike: if isinstance(coord, Mobject): coord = coord.get_center() return self.coords_to_point(*coord) - def __rmatmul__(self, point: Point3DLike): + def __rmatmul__(self, point: Point3DLike) -> Point3DLike: return self.point_to_coords(point) + @staticmethod + def _origin_shift(axis_range: Sequence[float]) -> float: ... + class Axes(VGroup, CoordinateSystem, metaclass=ConvertToOpenGL): """Creates a set of axes. @@ -1926,8 +1937,11 @@ def __init__( "include_tip": tips, "numbers_to_exclude": [0], } - self.x_axis_config = {} - self.y_axis_config = {"rotation": 90 * DEGREES, "label_direction": LEFT} + self.x_axis_config: dict[str, Any] = {} + self.y_axis_config: dict[str, Any] = { + "rotation": 90 * DEGREES, + "label_direction": LEFT, + } self._update_default_configs( (self.axis_config, self.x_axis_config, self.y_axis_config), @@ -2418,8 +2432,8 @@ def __init__( num_axis_pieces: int = 20, light_source: Sequence[float] = 9 * DOWN + 7 * LEFT + 10 * OUT, # opengl stuff (?) - depth=None, - gloss=0.5, + depth: Any = None, + gloss: float = 0.5, **kwargs: dict[str, Any], ) -> None: super().__init__( @@ -2433,7 +2447,7 @@ def __init__( self.z_range = z_range self.z_length = z_length - self.z_axis_config = {} + self.z_axis_config: dict[str, Any] = {} self._update_default_configs((self.z_axis_config,), (z_axis_config,)) self.z_axis_config = merge_dicts_recursively( self.axis_config, @@ -2500,13 +2514,13 @@ def make_func(axis): def get_y_axis_label( self, - label: float | str | Mobject, + label: float | str | VMobject, edge: Sequence[float] = UR, direction: Sequence[float] = UR, buff: float = SMALL_BUFF, rotation: float = PI / 2, rotation_axis: Vector3D = OUT, - **kwargs, + **kwargs: dict[str, Any], ) -> Mobject: """Generate a y-axis label. @@ -2550,7 +2564,7 @@ def construct(self): def get_z_axis_label( self, - label: float | str | Mobject, + label: float | str | VMobject, edge: Vector3D = OUT, direction: Vector3D = RIGHT, buff: float = SMALL_BUFF, @@ -2600,9 +2614,9 @@ def construct(self): def get_axis_labels( self, - x_label: float | str | Mobject = "x", - y_label: float | str | Mobject = "y", - z_label: float | str | Mobject = "z", + x_label: float | str | VMobject = "x", + y_label: float | str | VMobject = "y", + z_label: float | str | VMobject = "z", ) -> VGroup: """Defines labels for the x_axis and y_axis of the graph. @@ -2741,7 +2755,7 @@ def __init__( **kwargs: dict[str, Any], ): # configs - self.axis_config = { + self.axis_config: dict[str, Any] = { "stroke_width": 2, "include_ticks": False, "include_tip": False, @@ -2749,8 +2763,8 @@ def __init__( "label_direction": DR, "font_size": 24, } - self.y_axis_config = {"label_direction": DR} - self.background_line_style = { + self.y_axis_config: dict[str, Any] = {"label_direction": DR} + self.background_line_style: dict[str, Any] = { "stroke_color": BLUE_D, "stroke_width": 2, "stroke_opacity": 1, @@ -2997,7 +3011,7 @@ def __init__( size: float | None = None, radius_step: float = 1, azimuth_step: float | None = None, - azimuth_units: str | None = "PI radians", + azimuth_units: str = "PI radians", azimuth_compact_fraction: bool = True, azimuth_offset: float = 0, azimuth_direction: str = "CCW", @@ -3130,11 +3144,11 @@ def _get_lines(self) -> tuple[VGroup, VGroup]: unit_vector = self.x_axis.get_unit_vector()[0] for k, x in enumerate(rinput): - new_line = Circle(radius=x * unit_vector) + new_circle = Circle(radius=x * unit_vector) if k % ratio_faded_lines == 0: - alines1.add(new_line) + alines1.add(new_circle) else: - alines2.add(new_line) + alines2.add(new_circle) line = Line(center, self.get_x_axis().get_end()) @@ -3292,7 +3306,9 @@ def add_coordinates( self.add(self.get_coordinate_labels(r_values, a_values)) return self - def get_radian_label(self, number, font_size: float = 24, **kwargs: Any) -> MathTex: + def get_radian_label( + self, number: float, font_size: float = 24, **kwargs: Any + ) -> MathTex: constant_label = {"PI radians": r"\pi", "TAU radians": r"\tau"}[ self.azimuth_units ] diff --git a/manim/mobject/graphing/functions.py b/manim/mobject/graphing/functions.py index 83c48b1092..5cf406dd22 100644 --- a/manim/mobject/graphing/functions.py +++ b/manim/mobject/graphing/functions.py @@ -17,9 +17,12 @@ from manim.mobject.types.vectorized_mobject import VMobject if TYPE_CHECKING: + from typing import Any + from typing_extensions import Self from manim.typing import Point3D, Point3DLike + from manim.utils.color import ParsableManimColor from manim.utils.color import YELLOW @@ -111,7 +114,7 @@ def __init__( discontinuities: Iterable[float] | None = None, use_smoothing: bool = True, use_vectorized: bool = False, - **kwargs, + **kwargs: Any, ): def internal_parametric_function(t: float) -> Point3D: """Wrap ``function``'s output inside a NumPy array.""" @@ -143,13 +146,13 @@ def generate_points(self) -> Self: lambda t: self.t_min <= t <= self.t_max, self.discontinuities, ) - discontinuities = np.array(list(discontinuities)) + discontinuities_array = np.array(list(discontinuities)) boundary_times = np.array( [ self.t_min, self.t_max, - *(discontinuities - self.dt), - *(discontinuities + self.dt), + *(discontinuities_array - self.dt), + *(discontinuities_array + self.dt), ], ) boundary_times.sort() @@ -211,19 +214,29 @@ def construct(self): self.add(cos_func, sin_func_1, sin_func_2) """ - def __init__(self, function, x_range=None, color=YELLOW, **kwargs): + def __init__( + self, + function: Callable[[float], Any], + x_range: np.array | None = None, + color: ParsableManimColor = YELLOW, + **kwargs: Any, + ) -> None: if x_range is None: x_range = np.array([-config["frame_x_radius"], config["frame_x_radius"]]) self.x_range = x_range - self.parametric_function = lambda t: np.array([t, function(t), 0]) - self.function = function + self.parametric_function: Callable[[float], np.array] = lambda t: np.array( + [t, function(t), 0] + ) + # TODO: + # error: Incompatible types in assignment (expression has type "Callable[[float], Any]", variable has type "Callable[[Arg(float, 't')], Any]") [assignment] + self.function = function # type: ignore[assignment] super().__init__(self.parametric_function, self.x_range, color=color, **kwargs) - def get_function(self): + def get_function(self) -> Callable[[float], Any]: return self.function - def get_point_from_function(self, x): + def get_point_from_function(self, x: float) -> np.array: return self.parametric_function(x) @@ -236,8 +249,8 @@ def __init__( min_depth: int = 5, max_quads: int = 1500, use_smoothing: bool = True, - **kwargs, - ): + **kwargs: Any, + ) -> None: """An implicit function. Parameters @@ -295,7 +308,7 @@ def construct(self): super().__init__(**kwargs) - def generate_points(self): + def generate_points(self) -> Self: p_min, p_max = ( np.array([self.x_range[0], self.y_range[0]]), np.array([self.x_range[1], self.y_range[1]]), diff --git a/manim/mobject/graphing/number_line.py b/manim/mobject/graphing/number_line.py index 017fac5bcb..2f96a89712 100644 --- a/manim/mobject/graphing/number_line.py +++ b/manim/mobject/graphing/number_line.py @@ -12,6 +12,10 @@ from typing import TYPE_CHECKING, Callable if TYPE_CHECKING: + from typing import Any + + from typing_extensions import Self + from manim.mobject.geometry.tips import ArrowTip from manim.typing import Point3DLike @@ -158,14 +162,14 @@ def __init__( include_numbers: bool = False, font_size: float = 36, label_direction: Sequence[float] = DOWN, - label_constructor: VMobject = MathTex, + label_constructor: type[MathTex] = MathTex, scaling: _ScaleBase = LinearBase(), line_to_number_buff: float = MED_SMALL_BUFF, decimal_number_config: dict | None = None, numbers_to_exclude: Iterable[float] | None = None, numbers_to_include: Iterable[float] | None = None, - **kwargs, - ): + **kwargs: Any, + ) -> None: # avoid mutable arguments in defaults if numbers_to_exclude is None: numbers_to_exclude = [] @@ -189,6 +193,9 @@ def __init__( # turn into a NumPy array to scale by just applying the function self.x_range = np.array(x_range, dtype=float) + self.x_min: float + self.x_max: float + self.x_step: float self.x_min, self.x_max, self.x_step = scaling.function(self.x_range) self.length = length self.unit_size = unit_size @@ -250,7 +257,9 @@ def __init__( dict( zip( tick_range, - self.scaling.get_custom_labels( + # TODO: + # Argument 2 to "zip" has incompatible type "Iterable[Mobject]"; expected "Iterable[str | float | VMobject]" [arg-type] + self.scaling.get_custom_labels( # type: ignore[arg-type] tick_range, unit_decimal_places=decimal_number_config[ "num_decimal_places" @@ -267,21 +276,25 @@ def __init__( font_size=self.font_size, ) - def rotate_about_zero(self, angle: float, axis: Sequence[float] = OUT, **kwargs): + def rotate_about_zero( + self, angle: float, axis: Sequence[float] = OUT, **kwargs: Any + ) -> Self: return self.rotate_about_number(0, angle, axis, **kwargs) def rotate_about_number( - self, number: float, angle: float, axis: Sequence[float] = OUT, **kwargs - ): + self, number: float, angle: float, axis: Sequence[float] = OUT, **kwargs: Any + ) -> Self: return self.rotate(angle, axis, about_point=self.n2p(number), **kwargs) - def add_ticks(self): + def add_ticks(self) -> None: """Adds ticks to the number line. Ticks can be accessed after creation via ``self.ticks``. """ ticks = VGroup() elongated_tick_size = self.tick_size * self.longer_tick_multiple - elongated_tick_offsets = self.numbers_with_elongated_ticks - self.x_min + elongated_tick_offsets = ( + np.array(self.numbers_with_elongated_ticks) - self.x_min + ) for x in self.get_tick_range(): size = self.tick_size if np.any(np.isclose(x - self.x_min, elongated_tick_offsets)): @@ -413,19 +426,22 @@ def point_to_number(self, point: Sequence[float]) -> float: point = np.asarray(point) start, end = self.get_start_and_end() unit_vect = normalize(end - start) - proportion = np.dot(point - start, unit_vect) / np.dot(end - start, unit_vect) + proportion: float = np.dot(point - start, unit_vect) / np.dot( + end - start, unit_vect + ) return interpolate(self.x_min, self.x_max, proportion) def n2p(self, number: float | np.ndarray) -> np.ndarray: """Abbreviation for :meth:`~.NumberLine.number_to_point`.""" return self.number_to_point(number) - def p2n(self, point: Sequence[float]) -> float: + def p2n(self, point: Point3DLike) -> float: """Abbreviation for :meth:`~.NumberLine.point_to_number`.""" return self.point_to_number(point) def get_unit_size(self) -> float: - return self.get_length() / (self.x_range[1] - self.x_range[0]) + val: float = self.get_length() / (self.x_range[1] - self.x_range[0]) + return val def get_unit_vector(self) -> np.ndarray: return super().get_unit_vector() * self.unit_size @@ -436,8 +452,8 @@ def get_number_mobject( direction: Sequence[float] | None = None, buff: float | None = None, font_size: float | None = None, - label_constructor: VMobject | None = None, - **number_config, + label_constructor: type[MathTex] | None = None, + **number_config: dict[str, Any], ) -> VMobject: """Generates a positioned :class:`~.DecimalNumber` mobject generated according to ``label_constructor``. @@ -476,7 +492,14 @@ def get_number_mobject( label_constructor = self.label_constructor num_mob = DecimalNumber( - x, font_size=font_size, mob_class=label_constructor, **number_config + # TODO: + # error: Argument 4 to "DecimalNumber" has incompatible type "**dict[str, dict[str, Any]]"; expected "int" [arg-type] + x, + font_size=font_size, + # TODO + # error: Argument "mob_class" to "DecimalNumber" has incompatible type "type[MathTex]"; expected "VMobject" [arg-type] + mob_class=label_constructor, # type: ignore[arg-type] + **number_config, # type: ignore[arg-type] ) num_mob.next_to(self.number_to_point(x), direction=direction, buff=buff) @@ -485,7 +508,7 @@ def get_number_mobject( num_mob.shift(num_mob[0].width * LEFT / 2) return num_mob - def get_number_mobjects(self, *numbers, **kwargs) -> VGroup: + def get_number_mobjects(self, *numbers: float, **kwargs: Any) -> VGroup: if len(numbers) == 0: numbers = self.default_numbers_to_display() return VGroup([self.get_number_mobject(number, **kwargs) for number in numbers]) @@ -498,9 +521,9 @@ def add_numbers( x_values: Iterable[float] | None = None, excluding: Iterable[float] | None = None, font_size: float | None = None, - label_constructor: VMobject | None = None, - **kwargs, - ): + label_constructor: type[MathTex] | None = None, + **kwargs: Any, + ) -> Self: """Adds :class:`~.DecimalNumber` mobjects representing their position at each tick of the number line. The numbers can be accessed after creation via ``self.numbers``. @@ -551,11 +574,11 @@ def add_numbers( def add_labels( self, dict_values: dict[float, str | float | VMobject], - direction: Sequence[float] = None, + direction: Sequence[float] | None = None, buff: float | None = None, font_size: float | None = None, - label_constructor: VMobject | None = None, - ): + label_constructor: type[MathTex] | None = None, + ) -> Self: """Adds specifically positioned labels to the :class:`~.NumberLine` using a ``dict``. The labels can be accessed after creation via ``self.labels``. @@ -593,12 +616,17 @@ def add_labels( # this method via CoordinateSystem.add_coordinates() # must be explicitly called if isinstance(label, str) and label_constructor is MathTex: - label = Tex(label) + # TODO + # error: Call to untyped function "Tex" in typed context [no-untyped-call] + label = Tex(label) # type: ignore[no-untyped-call] else: label = self._create_label_tex(label, label_constructor) if hasattr(label, "font_size"): - label.font_size = font_size + # assert isinstance(label, MathTex) + # TODO + # error: "VMobject" has no attribute "font_size" [attr-defined] + label.font_size = font_size # type: ignore[attr-defined] else: raise AttributeError(f"{label} is not compatible with add_labels.") label.next_to(self.number_to_point(x), direction=direction, buff=buff) @@ -612,7 +640,7 @@ def _create_label_tex( self, label_tex: str | float | VMobject, label_constructor: Callable | None = None, - **kwargs, + **kwargs: Any, ) -> VMobject: """Checks if the label is a :class:`~.VMobject`, otherwise, creates a label by passing ``label_tex`` to ``label_constructor``. @@ -633,24 +661,25 @@ def _create_label_tex( :class:`~.VMobject` The label. """ - if label_constructor is None: - label_constructor = self.label_constructor if isinstance(label_tex, (VMobject, OpenGLVMobject)): return label_tex - else: + if label_constructor is None: + label_constructor = self.label_constructor + if isinstance(label_tex, str): return label_constructor(label_tex, **kwargs) + return label_constructor(str(label_tex), **kwargs) @staticmethod - def _decimal_places_from_step(step) -> int: - step = str(step) - if "." not in step: + def _decimal_places_from_step(step: float) -> int: + step_str = str(step) + if "." not in step_str: return 0 - return len(step.split(".")[-1]) + return len(step_str.split(".")[-1]) - def __matmul__(self, other: float): + def __matmul__(self, other: float) -> np.ndarray: return self.n2p(other) - def __rmatmul__(self, other: Point3DLike | Mobject): + def __rmatmul__(self, other: Point3DLike | Mobject) -> float: if isinstance(other, Mobject): other = other.get_center() return self.p2n(other) @@ -659,11 +688,11 @@ def __rmatmul__(self, other: Point3DLike | Mobject): class UnitInterval(NumberLine): def __init__( self, - unit_size=10, - numbers_with_elongated_ticks=None, - decimal_number_config=None, - **kwargs, - ): + unit_size: float = 10, + numbers_with_elongated_ticks: list[float] | None = None, + decimal_number_config: dict[str, Any] | None = None, + **kwargs: Any, + ) -> None: numbers_with_elongated_ticks = ( [0, 1] if numbers_with_elongated_ticks is None diff --git a/manim/mobject/graphing/probability.py b/manim/mobject/graphing/probability.py index 24134c0a7a..b8440596e5 100644 --- a/manim/mobject/graphing/probability.py +++ b/manim/mobject/graphing/probability.py @@ -6,6 +6,7 @@ from collections.abc import Iterable, MutableSequence, Sequence +from typing import Any import numpy as np @@ -17,7 +18,8 @@ from manim.mobject.opengl.opengl_mobject import OpenGLMobject from manim.mobject.svg.brace import Brace from manim.mobject.text.tex_mobject import MathTex, Tex -from manim.mobject.types.vectorized_mobject import VGroup, VMobject +from manim.mobject.types.vectorized_mobject import VGroup +from manim.typing import Vector3D from manim.utils.color import ( BLUE_E, DARK_GREY, @@ -54,14 +56,14 @@ def construct(self): def __init__( self, - height=3, - width=3, - fill_color=DARK_GREY, - fill_opacity=1, - stroke_width=0.5, - stroke_color=LIGHT_GREY, - default_label_scale_val=1, - ): + height: float = 3, + width: float = 3, + fill_color: ParsableManimColor = DARK_GREY, + fill_opacity: float = 1, + stroke_width: float = 0.5, + stroke_color: ParsableManimColor = LIGHT_GREY, + default_label_scale_val: float = 1, + ) -> None: super().__init__( height=height, width=width, @@ -72,32 +74,43 @@ def __init__( ) self.default_label_scale_val = default_label_scale_val - def add_title(self, title="Sample space", buff=MED_SMALL_BUFF): + def add_title( + self, title: str = "Sample space", buff: float = MED_SMALL_BUFF + ) -> None: # TODO, should this really exist in SampleSpaceScene - title_mob = Tex(title) + # TODO + # error: Call to untyped function "Tex" in typed context [no-untyped-call] + title_mob = Tex(title) # type: ignore[no-untyped-call] if title_mob.width > self.width: title_mob.width = self.width title_mob.next_to(self, UP, buff=buff) self.title = title_mob self.add(title_mob) - def add_label(self, label): + def add_label(self, label: str) -> None: self.label = label - def complete_p_list(self, p_list): + def complete_p_list(self, p_list: list[tuple]) -> list[float]: new_p_list = list(tuplify(p_list)) remainder = 1.0 - sum(new_p_list) if abs(remainder) > EPSILON: new_p_list.append(remainder) return new_p_list - def get_division_along_dimension(self, p_list, dim, colors, vect): - p_list = self.complete_p_list(p_list) - colors = color_gradient(colors, len(p_list)) + def get_division_along_dimension( + self, + p_list: list[tuple], + dim: int, + colors: Sequence[ParsableManimColor], + vect: Vector3D, + ) -> VGroup: + p_list_complete = self.complete_p_list(p_list) + colors_in_gradient = color_gradient(colors, len(p_list)) + assert isinstance(colors_in_gradient, list) last_point = self.get_edge_center(-vect) parts = VGroup() - for factor, color in zip(p_list, colors): + for factor, color in zip(p_list_complete, colors_in_gradient): part = SampleSpace() part.set_fill(color, 1) part.replace(self, stretch=True) @@ -107,28 +120,42 @@ def get_division_along_dimension(self, p_list, dim, colors, vect): parts.add(part) return parts - def get_horizontal_division(self, p_list, colors=[GREEN_E, BLUE_E], vect=DOWN): + def get_horizontal_division( + self, + p_list: list[tuple], + colors: Sequence[ParsableManimColor] = [GREEN_E, BLUE_E], + vect: Vector3D = DOWN, + ) -> VGroup: return self.get_division_along_dimension(p_list, 1, colors, vect) - def get_vertical_division(self, p_list, colors=[MAROON_B, YELLOW], vect=RIGHT): + def get_vertical_division( + self, + p_list: list[tuple], + colors: Sequence[ParsableManimColor] = [MAROON_B, YELLOW], + vect: Vector3D = RIGHT, + ) -> VGroup: return self.get_division_along_dimension(p_list, 0, colors, vect) - def divide_horizontally(self, *args, **kwargs): + # TODO: + # error: Function is missing a type annotation for one or more arguments [no-untyped-def] + def divide_horizontally(self, *args, **kwargs) -> None: # type: ignore[no-untyped-def] self.horizontal_parts = self.get_horizontal_division(*args, **kwargs) self.add(self.horizontal_parts) - def divide_vertically(self, *args, **kwargs): + # TODO: + # error: Function is missing a type annotation for one or more arguments [no-untyped-def] + def divide_vertically(self, *args, **kwargs) -> None: # type: ignore[no-untyped-def] self.vertical_parts = self.get_vertical_division(*args, **kwargs) self.add(self.vertical_parts) def get_subdivision_braces_and_labels( self, - parts, - labels, - direction, - buff=SMALL_BUFF, - min_num_quads=1, - ): + parts: VGroup, + labels: list[str | Mobject | OpenGLMobject], + direction: Vector3D, + buff: float = SMALL_BUFF, + min_num_quads: int = 1, + ) -> VGroup: label_mobs = VGroup() braces = VGroup() for label, part in zip(labels, parts): @@ -141,34 +168,47 @@ def get_subdivision_braces_and_labels( label_mob.next_to(brace, direction, buff) braces.add(brace) - label_mobs.add(label_mob) - parts.braces = braces - parts.labels = label_mobs - parts.label_kwargs = { + # TODO: + # error: Argument 1 to "add" of "VGroup" has incompatible type "Mobject | OpenGLMobject"; expected "VMobject | Iterable[VMobject]" [arg-type] + label_mobs.add(label_mob) # type: ignore[arg-type] + # TODO: + # error: "VGroup" has no attribute "braces" [attr-defined] + parts.braces = braces # type: ignore[attr-defined] + parts.labels = label_mobs # type: ignore[attr-defined] + parts.label_kwargs = { # type: ignore[attr-defined] "labels": label_mobs.copy(), "direction": direction, "buff": buff, } return VGroup(parts.braces, parts.labels) - def get_side_braces_and_labels(self, labels, direction=LEFT, **kwargs): + def get_side_braces_and_labels( + self, + labels: list[str | Mobject | OpenGLMobject], + direction: Vector3D = LEFT, + **kwargs: Any, + ) -> VGroup: assert hasattr(self, "horizontal_parts") parts = self.horizontal_parts return self.get_subdivision_braces_and_labels( parts, labels, direction, **kwargs ) - def get_top_braces_and_labels(self, labels, **kwargs): + def get_top_braces_and_labels( + self, labels: list[str | Mobject | OpenGLMobject], **kwargs: Any + ) -> VGroup: assert hasattr(self, "vertical_parts") parts = self.vertical_parts return self.get_subdivision_braces_and_labels(parts, labels, UP, **kwargs) - def get_bottom_braces_and_labels(self, labels, **kwargs): + def get_bottom_braces_and_labels( + self, labels: list[str | Mobject | OpenGLMobject], **kwargs: Any + ) -> VGroup: assert hasattr(self, "vertical_parts") parts = self.vertical_parts return self.get_subdivision_braces_and_labels(parts, labels, DOWN, **kwargs) - def add_braces_and_labels(self): + def add_braces_and_labels(self) -> None: for attr in "horizontal_parts", "vertical_parts": if not hasattr(self, attr): continue @@ -177,12 +217,16 @@ def add_braces_and_labels(self): if hasattr(parts, subattr): self.add(getattr(parts, subattr)) - def __getitem__(self, index): + def __getitem__(self, index: int) -> VGroup: if hasattr(self, "horizontal_parts"): - return self.horizontal_parts[index] + val: VGroup = self.horizontal_parts[index] + return val elif hasattr(self, "vertical_parts"): - return self.vertical_parts[index] - return self.split()[index] + val = self.vertical_parts[index] + return val + # TODO: + # error: Incompatible return value type (got "SampleSpace", expected "VGroup") [return-value] + return self.split()[index] # type: ignore[return-value] class BarChart(Axes): @@ -253,8 +297,8 @@ def __init__( bar_width: float = 0.6, bar_fill_opacity: float = 0.7, bar_stroke_width: float = 3, - **kwargs, - ): + **kwargs: Any, + ) -> None: if isinstance(bar_colors, str): logger.warning( "Passing a string to `bar_colors` has been deprecated since v0.15.2 and will be removed after v0.17.0, the parameter must be a list. " @@ -311,7 +355,7 @@ def __init__( self.y_axis.add_numbers() - def _update_colors(self): + def _update_colors(self) -> None: """Initialize the colors of the bars of the chart. Sets the color of ``self.bars`` via ``self.bar_colors``. @@ -321,13 +365,14 @@ def _update_colors(self): """ self.bars.set_color_by_gradient(*self.bar_colors) - def _add_x_axis_labels(self): + def _add_x_axis_labels(self) -> None: """Essentially :meth`:~.NumberLine.add_labels`, but differs in that the direction of the label with respect to the x_axis changes to UP or DOWN depending on the value. UP for negative values and DOWN for positive values. """ + assert isinstance(self.bar_names, list) val_range = np.arange( 0.5, len(self.bar_names), 1 ) # 0.5 shifted so that labels are centered, not on ticks @@ -338,7 +383,7 @@ def _add_x_axis_labels(self): # to accommodate negative bars, the label may need to be # below or above the x_axis depending on the value of the bar direction = UP if self.values[i] < 0 else DOWN - bar_name_label = self.x_axis.label_constructor(bar_name) + bar_name_label: MathTex = self.x_axis.label_constructor(bar_name) bar_name_label.font_size = self.x_axis.font_size bar_name_label.next_to( @@ -398,8 +443,8 @@ def get_bar_labels( color: ParsableManimColor | None = None, font_size: float = 24, buff: float = MED_SMALL_BUFF, - label_constructor: type[VMobject] = Tex, - ): + label_constructor: type[MathTex] = Tex, + ) -> VGroup: """Annotates each bar with its corresponding value. Use ``self.bar_labels`` to access the labels after creation. @@ -431,7 +476,7 @@ def construct(self): """ bar_labels = VGroup() for bar, value in zip(self.bars, self.values): - bar_lbl = label_constructor(str(value)) + bar_lbl: MathTex = label_constructor(str(value)) if color is None: bar_lbl.set_color(bar.get_fill_color()) @@ -446,7 +491,9 @@ def construct(self): return bar_labels - def change_bar_values(self, values: Iterable[float], update_colors: bool = True): + def change_bar_values( + self, values: Iterable[float], update_colors: bool = True + ) -> None: """Updates the height of the bars of the chart. Parameters @@ -512,4 +559,4 @@ def construct(self): if update_colors: self._update_colors() - self.values[: len(values)] = values + self.values[: len(list(values))] = values diff --git a/manim/mobject/graphing/scale.py b/manim/mobject/graphing/scale.py index 78ffa2308b..56a02d65b8 100644 --- a/manim/mobject/graphing/scale.py +++ b/manim/mobject/graphing/scale.py @@ -2,7 +2,7 @@ import math from collections.abc import Iterable -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, overload import numpy as np @@ -11,6 +11,8 @@ from manim.mobject.text.numbers import Integer if TYPE_CHECKING: + from typing import Callable + from manim.mobject.mobject import Mobject @@ -26,6 +28,12 @@ class _ScaleBase: def __init__(self, custom_labels: bool = False): self.custom_labels = custom_labels + @overload + def function(self, value: float) -> float: ... + + @overload + def function(self, value: np.array) -> np.array: ... + def function(self, value: float) -> float: """The function that will be used to scale the values. @@ -59,6 +67,7 @@ def inverse_function(self, value: float) -> float: def get_custom_labels( self, val_range: Iterable[float], + **kw_args: Any, ) -> Iterable[Mobject]: """Custom instructions for generating labels along an axis. @@ -139,15 +148,19 @@ def __init__(self, base: float = 10, custom_labels: bool = True): def function(self, value: float) -> float: """Scales the value to fit it to a logarithmic scale.``self.function(5)==10**5``""" - return self.base**value + val: float = self.base**value + return val def inverse_function(self, value: float) -> float: """Inverse of ``function``. The value must be greater than 0""" if isinstance(value, np.ndarray): condition = value.any() <= 0 - def func(value, base): - return np.log(value) / np.log(base) + func: Callable[[float, float], float] + + def func(value: float, base: float) -> float: + val: float = np.log(value) / np.log(base) + return val else: condition = value <= 0 func = math.log @@ -177,11 +190,13 @@ def get_custom_labels( Additional arguments to be passed to :class:`~.Integer`. """ # uses `format` syntax to control the number of decimal places. - tex_labels = [ + tex_labels: list[Mobject] = [ Integer( self.base, unit="^{%s}" % (f"{self.inverse_function(i):.{unit_decimal_places}f}"), # noqa: UP031 - **base_config, + # TODO: + # error: Argument 3 to "Integer" has incompatible type "**dict[str, dict[str, Any]]"; expected "int" [arg-type] + **base_config, # type: ignore[arg-type] ) for i in val_range ] diff --git a/manim/mobject/text/tex_mobject.py b/manim/mobject/text/tex_mobject.py index 26334a60d9..c3a03e5746 100644 --- a/manim/mobject/text/tex_mobject.py +++ b/manim/mobject/text/tex_mobject.py @@ -1,3 +1,6 @@ +# The following line is needed to avoid some strange +# mypy errors related to the code around line 366. +# mypy: disable_error_code = has-type r"""Mobjects representing text rendered using LaTeX. .. important:: diff --git a/manim/mobject/types/image_mobject.py b/manim/mobject/types/image_mobject.py index 56029f941e..f2fb9d649f 100644 --- a/manim/mobject/types/image_mobject.py +++ b/manim/mobject/types/image_mobject.py @@ -14,6 +14,7 @@ from manim.mobject.geometry.shape_matchers import SurroundingRectangle from ... import config +from ...camera.moving_camera import MovingCamera from ...constants import * from ...mobject.mobject import Mobject from ...utils.bezier import interpolate @@ -28,7 +29,9 @@ import numpy.typing as npt from typing_extensions import Self - from manim.typing import StrPath + from manim.typing import PixelArray, StrPath + + from ...camera.moving_camera import MovingCamera class AbstractImageMobject(Mobject): @@ -57,7 +60,7 @@ def __init__( self.set_resampling_algorithm(resampling_algorithm) super().__init__(**kwargs) - def get_pixel_array(self) -> None: + def get_pixel_array(self) -> PixelArray: raise NotImplementedError() def set_color(self, color, alpha=None, family=True): @@ -303,7 +306,7 @@ def get_style(self) -> dict[str, Any]: class ImageMobjectFromCamera(AbstractImageMobject): def __init__( self, - camera, + camera: MovingCamera, default_display_frame_config: dict[str, Any] | None = None, **kwargs: Any, ) -> None: diff --git a/mypy.ini b/mypy.ini index 80571869be..4d81f2c7b9 100644 --- a/mypy.ini +++ b/mypy.ini @@ -58,12 +58,30 @@ ignore_errors = True [mypy-manim.camera.*] ignore_errors = True +[mypy-manim.camera.camera.*] +ignore_errors = False + +[mypy-manim.camera.multi_camera.*] +ignore_errors = False + [mypy-manim.cli.*] ignore_errors = False [mypy-manim.cli.cfg.*] ignore_errors = False +[mypy-manim.mobject.graphing.scale.*] +ignore_errors = False + +[mypy-manim.mobject.graphing.functions.*] +ignore_errors = False + +[mypy-manim.mobject.graphing.number_line.*] +ignore_errors = False + +[mypy-manim.mobject.graphing.probability.*] +ignore_errors = False + [mypy-manim.gui.*] ignore_errors = True