Skip to content

Add overloads for the mod function to support various input types#15577

Draft
hunterhogan wants to merge 1 commit intopython:mainfrom
hunterhogan:operator.mod
Draft

Add overloads for the mod function to support various input types#15577
hunterhogan wants to merge 1 commit intopython:mainfrom
hunterhogan:operator.mod

Conversation

@hunterhogan
Copy link
Copy Markdown
Contributor

I made a stupid script to prove to myself I had the right overloads

import contextlib
import random
from itertools import product as CartesianProduct  # noqa: N812
from operator import mod, neg
from types import EllipsisType, NotImplementedType
from typing import Any

listNonExceptions: list[tuple[Any, Any]] = []  # noqa: N816
typeMODtypeIStype: set[tuple[str, str, str]] = set()  # noqa: N816

numbers: set[int | float | complex] = {
    0,
    1,
    2,
    5,
    10,
    17,
    1024,
    0.5,
    2.5,
    3.5,
    23.301,
    1e-6,
    1e100,
    float("inf"),
    float("-inf"),
    float("nan"),
    47 - 29j,
    3 + 4j,
    0 + 0j,
    0 + 1j,
    1 + 0j,
    1 - 1j,
}

atoms: set[int | float | complex | bool | str | EllipsisType | NotImplementedType | None] = {
    *numbers,
    *map(neg, numbers),
    True,
    False,
    "inf",
    "-inf",
    "nan",
    "",
    "a",
    "mn",
    "xyz",
    None,
    ...,
    NotImplemented,
}

tupleLengths: list[int] = random.sample(range(2, len(atoms)), 5)  # noqa: N816

constructors = {
    bool,
    int,
    float,
    complex,
    str,
    set,
    frozenset,
    bytes,
    bytearray,
    list,
    tuple,
    range,
    dict,
    type,
    memoryview,
    object,
    slice,
}

test_values: list[Any] = [*atoms]

for constructor, atom in CartesianProduct(constructors, atoms):
    with contextlib.suppress(Exception):
        test_values.append(constructor(atom))  # pyright: ignore[reportCallIssue]

for constructor, tupleLength in CartesianProduct(constructors, tupleLengths):  # noqa: N816
    with contextlib.suppress(Exception):
        test_values.append(
            constructor(random.sample(list(atoms), tupleLength))  # pyright: ignore[reportArgumentType, reportCallIssue]
        )

for a, b in CartesianProduct(test_values, repeat=2):
    try:
        r = mod(a, b)
        listNonExceptions.append((a, b))
        if (
            not (isinstance(b, dict) and (len(b) == 0))  # pyright: ignore[reportUnknownArgumentType]
            and not (isinstance(b, tuple) and (len(b) == 0))  # pyright: ignore[reportUnknownArgumentType]
            and not (isinstance(b, list) and (len(b) == 0))  # pyright: ignore[reportUnknownArgumentType]
            and not (isinstance(a, bytes) and isinstance(b, (list, range)) and (a == r))
            and not (isinstance(a, bytearray) and isinstance(b, (list, range)) and (a == r))
            and not (isinstance(a, str) and isinstance(b, (list, range, bytes, bytearray)) and (a == r))
        ):
            typeMODtypeIStype.add(
                (type(a).__name__, type(b).__name__, type(r).__name__)  # pyright: ignore[reportUnknownArgumentType]
            )
    except Exception:
        continue

print(len(listNonExceptions), len(typeMODtypeIStype))

print("\n".join(f"@overload\ndef mod(a: {entry[0]}, b: {entry[1]}, /) -> {entry[2]}: ..." for entry in sorted(typeMODtypeIStype)))

@github-actions
Copy link
Copy Markdown
Contributor

Diff from mypy_primer, showing the effect of this PR on open source code:

static-frame (https://github.com/static-frame/static-frame)
+ static_frame/core/util.py:601: error: Dict entry 10 has incompatible type "str": overloaded function; expected "str": "Callable[..., ndarray[Any, Any]]"  [dict-item]

jax (https://github.com/google/jax)
+ jax/experimental/mosaic/gpu/dialect_lowering.py:1368: error: Argument "impl" to "_binary_op_lowering_rule" has incompatible type "function"; expected "Callable[[FragmentedArray, FragmentedArray], FragmentedArray]"  [arg-type]

Copy link
Copy Markdown
Collaborator

@srittau srittau left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think using concrete overloads is the correct solution here. mod(x, y) basically does x % y, so it will also work with any type that implements a custom __mod__() method. We should use a protocol here, similar to how it's done with the comparison operators in the same file.

@hunterhogan hunterhogan marked this pull request as draft April 3, 2026 03:27
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants