from __future__ import annotations

from copy import copy
from typing import TYPE_CHECKING
from typing import Any
from typing import Callable
from typing import Iterator
from typing import Sequence

from narwhals._expression_parsing import is_simple_aggregation
from narwhals._expression_parsing import parse_into_exprs
from narwhals.utils import generate_temporary_column_name
from narwhals.utils import remove_prefix

if TYPE_CHECKING:
    import pyarrow as pa
    import pyarrow.compute as pc
    from typing_extensions import Self

    from narwhals._arrow.dataframe import ArrowDataFrame
    from narwhals._arrow.series import ArrowSeries
    from narwhals._arrow.typing import IntoArrowExpr
    from narwhals.typing import CompliantExpr


def polars_to_arrow_aggregations() -> (
    dict[str, tuple[str, pc.VarianceOptions | pc.CountOptions | None]]
):
    """Map polars compute functions to their pyarrow counterparts and options that help match polars behaviour."""
    import pyarrow.compute as pc

    return {
        "sum": ("sum", None),
        "mean": ("mean", None),
        "median": ("approximate_median", None),
        "max": ("max", None),
        "min": ("min", None),
        "std": ("stddev", pc.VarianceOptions(ddof=1)),
        "var": (
            "variance",
            pc.VarianceOptions(ddof=1),
        ),  # currently unused, we don't have `var` yet
        "len": ("count", pc.CountOptions(mode="all")),
        "n_unique": ("count_distinct", pc.CountOptions(mode="all")),
        "count": ("count", pc.CountOptions(mode="only_valid")),
    }


class ArrowGroupBy:
    def __init__(
        self: Self, df: ArrowDataFrame, keys: list[str], *, drop_null_keys: bool
    ) -> None:
        import pyarrow as pa

        if drop_null_keys:
            self._df = df.drop_nulls(keys)
        else:
            self._df = df
        self._keys = list(keys)
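        # pyarrow's TableGroupBy performs the actual grouping; `agg_arrow` later
        # evaluates the aggregations against it.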
        self._grouped = pa.TableGroupBy(self._df._native_frame, list(self._keys))

    def agg(
        self: Self,
        *aggs: IntoArrowExpr,
        **named_aggs: IntoArrowExpr,
    ) -> ArrowDataFrame:
        exprs = parse_into_exprs(
            *aggs,
            namespace=self._df.__narwhals_namespace__(),
            **named_aggs,
        )
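        # The group keys always come first in the output, followed by each
        # expression's output columns.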
        output_names: list[str] = copy(self._keys)
        for expr in exprs:
            if expr._output_names is None:
                msg = (
                    "Anonymous expressions are not supported in group_by.agg.\n"
                    "Instead of `nw.all()`, try using a named expression, such as "
                    "`nw.col('a', 'b')`\n"
                )
                raise ValueError(msg)
            output_names.extend(expr._output_names)

        return agg_arrow(
            self._grouped,
            exprs,
            self._keys,
            output_names,
            self._df._from_native_frame,
        )

    def __iter__(self: Self) -> Iterator[tuple[Any, ArrowDataFrame]]:
        import pyarrow as pa
        import pyarrow.compute as pc

        col_token = generate_temporary_column_name(n_bytes=8, columns=self._df.columns)
        null_token = "__null_token_value__"  # noqa: S105

        table = self._df._native_frame
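        # Build one string per row by concatenating the stringified key columns,
        # replacing nulls with a sentinel so that rows with null keys still form
        # their own groups.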
        key_values = pc.binary_join_element_wise(
            *[pc.cast(table[key], pa.string()) for key in self._keys],
            "",
            null_handling="replace",
            null_replacement=null_token,
        )
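        # Attach the composite key as a temporary helper column so that each
        # group can be recovered with a plain filter.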
        table = table.add_column(i=0, field_=col_token, column=key_values)

        for v in pc.unique(key_values):
            # Keep the rows belonging to this group and drop the helper column.
            t = self._df._from_native_frame(
                table.filter(pc.equal(table[col_token], v)).drop([col_token])
            )
            # The group's key values are read off the first row of the key columns.
            key = next(
                t.select(*self._keys).head(1).iter_rows(named=False, buffer_size=512)
            )
            yield key, t


def agg_arrow(
    grouped: pa.TableGroupBy,
    exprs: Sequence[CompliantExpr[ArrowSeries]],
    keys: list[str],
    output_names: list[str],
    from_dataframe: Callable[[Any], ArrowDataFrame],
) -> ArrowDataFrame:
    import pyarrow.compute as pc

    # pyarrow's TableGroupBy can only evaluate elementary aggregations (e.g.
    # `nw.col('a').mean()`); check that every expression maps onto one.
    all_simple_aggs = True
    for expr in exprs:
        if not (
            is_simple_aggregation(expr)
            and remove_prefix(expr._function_name, "col->")
            in polars_to_arrow_aggregations()
        ):
            all_simple_aggs = False
            break

    if all_simple_aggs:
        # Mapping from output name to
        # (aggregation_args, pyarrow_output_name)  # noqa: ERA001
        simple_aggregations: dict[str, tuple[tuple[Any, ...], str]] = {}
        for expr in exprs:
            if expr._depth == 0:
                # e.g. agg(nw.len()) # noqa: ERA001
                if (
                    expr._output_names is None or expr._function_name != "len"
                ):  # pragma: no cover
                    msg = "Safety assertion failed, please report a bug to https://github.com/narwhals-dev/narwhals/issues"
                    raise AssertionError(msg)
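                # `nw.len()` has no root column: count one of the key columns with
                # mode="all" so that null values are included in the group size.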
                simple_aggregations[expr._output_names[0]] = (
                    (keys[0], "count", pc.CountOptions(mode="all")),
                    f"{keys[0]}_count",
                )
                continue

            # e.g. agg(nw.mean('a')) # noqa: ERA001
            if (
                expr._depth != 1 or expr._root_names is None or expr._output_names is None
            ):  # pragma: no cover
                msg = "Safety assertion failed, please report a bug to https://github.com/narwhals-dev/narwhals/issues"
                raise AssertionError(msg)

            function_name = remove_prefix(expr._function_name, "col->")
            function_name, option = polars_to_arrow_aggregations().get(
                function_name, (function_name, None)
            )

            for root_name, output_name in zip(expr._root_names, expr._output_names):
                simple_aggregations[output_name] = (
                    (root_name, function_name, option),
                    f"{root_name}_{function_name}",
                )

        aggs: list[Any] = []
        name_mapping = {}
        for output_name, (
            aggregation_args,
            pyarrow_output_name,
        ) in simple_aggregations.items():
            aggs.append(aggregation_args)
            name_mapping[pyarrow_output_name] = output_name
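        # pyarrow names aggregated columns "{column}_{function}"; rename them to
        # the requested output names and restore the expected column order.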
        result_simple = grouped.aggregate(aggs)
        result_simple = result_simple.rename_columns(
            [name_mapping.get(col, col) for col in result_simple.column_names]
        ).select(output_names)
        return from_dataframe(result_simple)

    msg = (
        "Non-trivial complex aggregation found.\n\n"
        "Hint: you were probably trying to apply a non-elementary aggregation with a "
        "pyarrow table.\n"
        "Please rewrite your query such that group-by aggregations "
        "are elementary. For example, instead of:\n\n"
        "    df.group_by('a').agg(nw.col('b').round(2).mean())\n\n"
        "use:\n\n"
        "    df.with_columns(nw.col('b').round(2)).group_by('a').agg(nw.col('b').mean())\n\n"
    )
    raise ValueError(msg)
