Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions crates/core/src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,13 @@ fn array_cat(exprs: Vec<PyExpr>) -> PyExpr {
array_concat(exprs)
}

#[pyfunction]
fn make_map(keys: Vec<PyExpr>, values: Vec<PyExpr>) -> PyExpr {
let keys = keys.into_iter().map(|x| x.into()).collect();
let values = values.into_iter().map(|x| x.into()).collect();
datafusion::functions_nested::map::map(keys, values).into()
}

#[pyfunction]
#[pyo3(signature = (array, element, index=None))]
fn array_position(array: PyExpr, element: PyExpr, index: Option<i64>) -> PyExpr {
Expand Down Expand Up @@ -666,6 +673,12 @@ array_fn!(cardinality, array);
array_fn!(flatten, array);
array_fn!(range, start stop step);

// Map Functions
array_fn!(map_keys, map);
array_fn!(map_values, map);
array_fn!(map_extract, map key);
array_fn!(map_entries, map);

aggregate_function!(array_agg);
aggregate_function!(max);
aggregate_function!(min);
Expand Down Expand Up @@ -1126,6 +1139,13 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_wrapped(wrap_pyfunction!(flatten))?;
m.add_wrapped(wrap_pyfunction!(cardinality))?;

// Map Functions
m.add_wrapped(wrap_pyfunction!(make_map))?;
m.add_wrapped(wrap_pyfunction!(map_keys))?;
m.add_wrapped(wrap_pyfunction!(map_values))?;
m.add_wrapped(wrap_pyfunction!(map_extract))?;
m.add_wrapped(wrap_pyfunction!(map_entries))?;

// Window Functions
m.add_wrapped(wrap_pyfunction!(lead))?;
m.add_wrapped(wrap_pyfunction!(lag))?;
Expand Down
158 changes: 158 additions & 0 deletions python/datafusion/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@
"degrees",
"dense_rank",
"digest",
"element_at",
"empty",
"encode",
"ends_with",
Expand Down Expand Up @@ -202,7 +203,12 @@
"make_array",
"make_date",
"make_list",
"make_map",
"make_time",
"map_entries",
"map_extract",
"map_keys",
"map_values",
"max",
"md5",
"mean",
Expand Down Expand Up @@ -3374,6 +3380,158 @@ def empty(array: Expr) -> Expr:
return array_empty(array)


# map functions


def make_map(*args: Any) -> Expr:
"""Returns a map expression.

Supports three calling conventions:

- ``make_map({"a": 1, "b": 2})`` — from a Python dictionary.
- ``make_map([keys], [values])`` — from a list of keys and a list of
their associated values. Both lists must be the same length.
- ``make_map(k1, v1, k2, v2, ...)`` — from alternating keys and their
associated values.

Keys and values that are not already :py:class:`~datafusion.expr.Expr`
are automatically converted to literal expressions.

Examples:
From a dictionary:

>>> ctx = dfn.SessionContext()
>>> df = ctx.from_pydict({"a": [1]})
>>> result = df.select(
... dfn.functions.make_map({"a": 1, "b": 2}).alias("m"))
>>> result.collect_column("m")[0].as_py()
[('a', 1), ('b', 2)]

From two lists:

>>> df = ctx.from_pydict({"key": ["x", "y"], "val": [10, 20]})
>>> df = df.select(
... dfn.functions.make_map(
... [dfn.col("key")], [dfn.col("val")]
... ).alias("m"))
>>> df.collect_column("m")[0].as_py()
[('x', 10)]

From alternating keys and values:

>>> df = ctx.from_pydict({"a": [1]})
>>> result = df.select(
... dfn.functions.make_map("x", 1, "y", 2).alias("m"))
>>> result.collect_column("m")[0].as_py()
[('x', 1), ('y', 2)]
"""
if len(args) == 1 and isinstance(args[0], dict):
key_list = list(args[0].keys())
value_list = list(args[0].values())
elif (
len(args) == 2 # noqa: PLR2004
and isinstance(args[0], list)
and isinstance(args[1], list)
):
if len(args[0]) != len(args[1]):
msg = "make_map requires key and value lists to be the same length"
raise ValueError(msg)
key_list = args[0]
value_list = args[1]
elif len(args) >= 2 and len(args) % 2 == 0: # noqa: PLR2004
key_list = list(args[0::2])
value_list = list(args[1::2])
else:
msg = (
"make_map expects a dict, two lists, or an even number of "
"key-value arguments"
)
raise ValueError(msg)

key_exprs = [k if isinstance(k, Expr) else Expr.literal(k) for k in key_list]
val_exprs = [v if isinstance(v, Expr) else Expr.literal(v) for v in value_list]
return Expr(f.make_map([k.expr for k in key_exprs], [v.expr for v in val_exprs]))


def map_keys(map: Expr) -> Expr:
"""Returns a list of all keys in the map.

Examples:
>>> ctx = dfn.SessionContext()
>>> df = ctx.from_pydict({"a": [1]})
>>> df = df.select(
... dfn.functions.make_map({"x": 1, "y": 2}).alias("m"))
>>> result = df.select(
... dfn.functions.map_keys(dfn.col("m")).alias("keys"))
>>> result.collect_column("keys")[0].as_py()
['x', 'y']
"""
return Expr(f.map_keys(map.expr))


def map_values(map: Expr) -> Expr:
"""Returns a list of all values in the map.

Examples:
>>> ctx = dfn.SessionContext()
>>> df = ctx.from_pydict({"a": [1]})
>>> df = df.select(
... dfn.functions.make_map({"x": 1, "y": 2}).alias("m"))
>>> result = df.select(
... dfn.functions.map_values(dfn.col("m")).alias("vals"))
>>> result.collect_column("vals")[0].as_py()
[1, 2]
"""
return Expr(f.map_values(map.expr))


def map_extract(map: Expr, key: Expr) -> Expr:
"""Returns the value for a given key in the map.

Returns ``[None]`` if the key is absent.

Examples:
>>> ctx = dfn.SessionContext()
>>> df = ctx.from_pydict({"a": [1]})
>>> df = df.select(
... dfn.functions.make_map({"x": 1, "y": 2}).alias("m"))
>>> result = df.select(
... dfn.functions.map_extract(
... dfn.col("m"), dfn.lit("x")
... ).alias("val"))
>>> result.collect_column("val")[0].as_py()
[1]
"""
return Expr(f.map_extract(map.expr, key.expr))


def map_entries(map: Expr) -> Expr:
"""Returns a list of all entries (key-value struct pairs) in the map.

Examples:
>>> ctx = dfn.SessionContext()
>>> df = ctx.from_pydict({"a": [1]})
>>> df = df.select(
... dfn.functions.make_map({"x": 1, "y": 2}).alias("m"))
>>> result = df.select(
... dfn.functions.map_entries(dfn.col("m")).alias("entries"))
>>> result.collect_column("entries")[0].as_py()
[{'key': 'x', 'value': 1}, {'key': 'y', 'value': 2}]
"""
return Expr(f.map_entries(map.expr))


def element_at(map: Expr, key: Expr) -> Expr:
"""Returns the value for a given key in the map.

Returns ``[None]`` if the key is absent.

See Also:
This is an alias for :py:func:`map_extract`.
"""
return map_extract(map, key)


# aggregate functions
def approx_distinct(
expression: Expr,
Expand Down
100 changes: 100 additions & 0 deletions python/tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -668,6 +668,106 @@ def test_array_function_obj_tests(stmt, py_expr):
assert a == b


@pytest.mark.parametrize(
("args", "expected"),
[
pytest.param(
({"x": 1, "y": 2},),
[("x", 1), ("y", 2)],
id="dict",
),
pytest.param(
({"x": literal(1), "y": literal(2)},),
[("x", 1), ("y", 2)],
id="dict_with_exprs",
),
pytest.param(
("x", 1, "y", 2),
[("x", 1), ("y", 2)],
id="variadic_pairs",
),
pytest.param(
(literal("x"), literal(1), literal("y"), literal(2)),
[("x", 1), ("y", 2)],
id="variadic_with_exprs",
),
],
)
def test_make_map(args, expected):
ctx = SessionContext()
batch = pa.RecordBatch.from_arrays([pa.array([1])], names=["a"])
df = ctx.create_dataframe([[batch]])

result = df.select(f.make_map(*args).alias("m")).collect()[0].column(0)
assert result[0].as_py() == expected


def test_make_map_from_two_lists():
ctx = SessionContext()
batch = pa.RecordBatch.from_arrays(
[
pa.array(["k1", "k2", "k3"]),
pa.array([10, 20, 30]),
],
names=["keys", "vals"],
)
df = ctx.create_dataframe([[batch]])

m = f.make_map([column("keys")], [column("vals")])
result = df.select(f.map_keys(m).alias("k")).collect()[0].column(0)
assert result.to_pylist() == [["k1"], ["k2"], ["k3"]]

result = df.select(f.map_values(m).alias("v")).collect()[0].column(0)
assert result.to_pylist() == [[10], [20], [30]]


def test_make_map_odd_args_raises():
with pytest.raises(ValueError, match="make_map expects"):
f.make_map("x", 1, "y")


def test_make_map_mismatched_lengths():
with pytest.raises(ValueError, match="same length"):
f.make_map(["a", "b"], [1])


@pytest.mark.parametrize(
("func", "expected"),
[
pytest.param(f.map_keys, ["x", "y"], id="map_keys"),
pytest.param(f.map_values, [1, 2], id="map_values"),
pytest.param(
lambda m: f.map_extract(m, literal("x")),
[1],
id="map_extract",
),
pytest.param(
lambda m: f.map_extract(m, literal("z")),
[None],
id="map_extract_missing_key",
),
pytest.param(
f.map_entries,
[{"key": "x", "value": 1}, {"key": "y", "value": 2}],
id="map_entries",
),
pytest.param(
lambda m: f.element_at(m, literal("y")),
[2],
id="element_at",
),
],
)
def test_map_functions(func, expected):
ctx = SessionContext()
batch = pa.RecordBatch.from_arrays([pa.array([1])], names=["a"])
df = ctx.create_dataframe([[batch]])

m = f.make_map({"x": 1, "y": 2})
result = df.select(func(m).alias("out")).collect()[0].column(0)
assert result[0].as_py() == expected


@pytest.mark.parametrize(
("function", "expected_result"),
[
Expand Down
Loading