import operator
from abc import ABC
from functools import reduce
from typing import AbstractSet, FrozenSet, Optional, Sequence, Union
import dagster._check as check
from dagster._annotations import public
from dagster._core.errors import DagsterInvalidSubsetError
from dagster._core.selector.subset_selector import (
    fetch_connected,
    generate_asset_dep_graph,
    generate_asset_name_to_definition_map,
)
from .assets import AssetsDefinition
from .events import AssetKey, CoercibleToAssetKey
from .source_asset import SourceAsset
class AssetSelection(ABC):
    """
    An AssetSelection defines a query over a set of assets, normally all the assets in a repository.

    You can use the "|" and "&" operators to create unions and intersections of asset selections,
    respectively.

    AssetSelections are typically used with :py:func:`define_asset_job`.

    Examples:

        .. code-block:: python

            # Select all assets in group "marketing":
            AssetSelection.groups("marketing")

            # Select all assets in group "marketing", as well as the asset with key "promotion":
            AssetSelection.groups("marketing") | AssetSelection.keys("promotion")

            # Select all assets in group "marketing" that are downstream of asset "leads":
            AssetSelection.groups("marketing") & AssetSelection.keys("leads").downstream()
    """

    @public  # type: ignore
    @staticmethod
    def all() -> "AllAssetSelection":
        """Returns a selection that includes all assets."""
        return AllAssetSelection()

    @public  # type: ignore
    @staticmethod
    def assets(*assets_defs: AssetsDefinition) -> "KeysAssetSelection":
        """Returns a selection that includes all of the provided assets."""
        # An AssetsDefinition may produce several keys; select every one of them.
        return KeysAssetSelection(*(key for assets_def in assets_defs for key in assets_def.keys))

    @public  # type: ignore
    @staticmethod
    def keys(*asset_keys: CoercibleToAssetKey) -> "KeysAssetSelection":
        """Returns a selection that includes assets with any of the provided keys."""
        _asset_keys = [AssetKey.from_coerceable(key) for key in asset_keys]
        return KeysAssetSelection(*_asset_keys)

    @public  # type: ignore
    @staticmethod
    def groups(*group_strs: str) -> "GroupsAssetSelection":
        """Returns a selection that includes assets that belong to any of the provided groups."""
        check.tuple_param(group_strs, "group_strs", of_type=str)
        return GroupsAssetSelection(*group_strs)

    @public  # type: ignore
    def downstream(self, depth: Optional[int] = None) -> "DownstreamAssetSelection":
        """
        Returns a selection that includes all assets that are downstream of any of the assets in
        this selection, as well as all the assets in this selection.

        Args:
            depth (Optional[int]): If provided, then only include assets to the given depth. A depth
                of 2 means all assets that are children or grandchildren of the assets in this
                selection.
        """
        check.opt_int_param(depth, "depth")
        return DownstreamAssetSelection(self, depth=depth)

    @public  # type: ignore
    def upstream(self, depth: Optional[int] = None) -> "UpstreamAssetSelection":
        """
        Returns a selection that includes all assets that are upstream of any of the assets in
        this selection, as well as all the assets in this selection.

        Args:
            depth (Optional[int]): If provided, then only include assets to the given depth. A depth
                of 2 means all assets that are parents or grandparents of the assets in this
                selection.
        """
        check.opt_int_param(depth, "depth")
        return UpstreamAssetSelection(self, depth=depth)

    def __or__(self, other: "AssetSelection") -> "OrAssetSelection":
        """Returns a selection containing the assets of either ``self`` or ``other`` (union)."""
        check.inst_param(other, "other", AssetSelection)
        return OrAssetSelection(self, other)

    def __and__(self, other: "AssetSelection") -> "AndAssetSelection":
        """Returns a selection containing only assets in both ``self`` and ``other`` (intersection)."""
        check.inst_param(other, "other", AssetSelection)
        return AndAssetSelection(self, other)

    def resolve(
        self, all_assets: Sequence[Union[AssetsDefinition, SourceAsset]]
    ) -> FrozenSet[AssetKey]:
        """
        Evaluates this selection against a concrete collection of assets.

        Args:
            all_assets (Sequence[Union[AssetsDefinition, SourceAsset]]): The assets to select
                from, typically every asset in a repository.

        Returns:
            FrozenSet[AssetKey]: The keys of the selected assets.

        Raises:
            DagsterInvalidSubsetError: If a key-based selection names a key that is supplied by a
                SourceAsset, or that no AssetsDefinition supplies.
        """
        check.sequence_param(all_assets, "all_assets", (AssetsDefinition, SourceAsset))
        return Resolver(all_assets).resolve(self)
class AllAssetSelection(AssetSelection):
    """Leaf selection node that matches every asset (created by ``AssetSelection.all()``)."""
class AndAssetSelection(AssetSelection):
    """Selection node for the intersection of two selections (``left & right``)."""

    def __init__(self, child_1: AssetSelection, child_2: AssetSelection):
        # Stored as an ordered pair so the resolver can unpack both operands.
        self.children = tuple([child_1, child_2])
class DownstreamAssetSelection(AssetSelection):
    """Selection node that expands its single child to include downstream assets."""

    def __init__(self, child: AssetSelection, *, depth: Optional[int] = None):
        # ``None`` means no depth limit is passed on to the traversal.
        self.depth = depth
        self.children = tuple([child])
class GroupsAssetSelection(AssetSelection):
    """Leaf selection node matching assets whose group name is one of ``children``."""

    def __init__(self, *children: str):
        # ``children`` is already a tuple of group names; keep it as-is.
        self.children = tuple(children)
class KeysAssetSelection(AssetSelection):
    """Leaf selection node matching assets whose key is one of ``children``."""

    def __init__(self, *children: AssetKey):
        # ``children`` is already a tuple of AssetKeys; keep it as-is.
        self.children = tuple(children)
class OrAssetSelection(AssetSelection):
    """Selection node for the union of two selections (``left | right``)."""

    def __init__(self, child_1: AssetSelection, child_2: AssetSelection):
        # Stored as an ordered pair so the resolver can unpack both operands.
        self.children = tuple([child_1, child_2])
class UpstreamAssetSelection(AssetSelection):
    """Selection node that expands its single child to include upstream assets."""

    def __init__(self, child: AssetSelection, *, depth: Optional[int] = None):
        # ``None`` means no depth limit is passed on to the traversal.
        self.depth = depth
        self.children = tuple([child])
# ########################
# ##### RESOLUTION
# ########################
class Resolver:
    """
    Evaluates an :py:class:`AssetSelection` tree against a concrete collection of assets.

    Internally, asset keys are handled as user strings (``AssetKey.to_user_string()``); they are
    converted back to ``AssetKey`` objects only by :py:meth:`resolve`.
    """

    def __init__(self, all_assets: Sequence[Union[AssetsDefinition, SourceAsset]]):
        assets_defs = []
        source_assets = []
        for asset in all_assets:
            if isinstance(asset, SourceAsset):
                source_assets.append(asset)
            elif isinstance(asset, AssetsDefinition):
                assets_defs.append(asset)
            else:
                check.failed(f"Expected SourceAsset or AssetsDefinition, got {type(asset)}")
        self.assets_defs = assets_defs
        # The dependency graph is built from both kinds of asset; only AssetsDefinitions are
        # selectable (see the KeysAssetSelection branch of ``_resolve``).
        self.asset_dep_graph = generate_asset_dep_graph(assets_defs, source_assets)
        self.all_assets_by_key_str = generate_asset_name_to_definition_map(assets_defs)
        self.source_asset_key_strs = {
            source_asset.key.to_user_string() for source_asset in source_assets
        }

    def resolve(self, root_node: AssetSelection) -> FrozenSet[AssetKey]:
        """Returns the keys of all assets selected by ``root_node``."""
        return frozenset(
            AssetKey.from_user_string(asset_name) for asset_name in self._resolve(root_node)
        )

    def _fetch_connected_with_roots(
        self, roots: AbstractSet[str], direction: str, depth: Optional[int]
    ) -> AbstractSet[str]:
        """Returns ``roots`` plus the assets ``fetch_connected`` reaches from each root."""
        # Start from the roots themselves so an empty input simply yields an empty set
        # (an un-initialized ``reduce`` would raise TypeError here).
        selected = set(roots)
        for asset_name in roots:
            selected.update(
                fetch_connected(
                    item=asset_name,
                    graph=self.asset_dep_graph,
                    direction=direction,
                    depth=depth,
                )
            )
        return selected

    def _resolve(self, node: AssetSelection) -> AbstractSet[str]:
        """
        Recursively evaluates ``node`` into a set of asset key user strings.

        Raises:
            DagsterInvalidSubsetError: For a KeysAssetSelection naming SourceAsset keys or keys
                that no AssetsDefinition supplies.
        """
        if isinstance(node, AllAssetSelection):
            return set(self.all_assets_by_key_str.keys())
        elif isinstance(node, AndAssetSelection):
            child_1, child_2 = [self._resolve(child) for child in node.children]
            return child_1 & child_2
        elif isinstance(node, DownstreamAssetSelection):
            child = self._resolve(node.children[0])
            return self._fetch_connected_with_roots(child, "downstream", node.depth)
        elif isinstance(node, GroupsAssetSelection):
            groups = set(node.children)
            # Accumulate into an empty set so that having no AssetsDefinitions is not an error.
            selected = set()
            for assets_def in self.assets_defs:
                selected.update(_match_groups(assets_def, groups))
            return selected
        elif isinstance(node, KeysAssetSelection):
            specified_key_strs = {child.to_user_string() for child in node.children}
            invalid_key_strs = specified_key_strs - set(self.all_assets_by_key_str.keys())
            selected_source_asset_key_strs = specified_key_strs & self.source_asset_key_strs
            if selected_source_asset_key_strs:
                raise DagsterInvalidSubsetError(
                    f"AssetKey(s) {selected_source_asset_key_strs} were selected, but these keys are "
                    "supplied by SourceAsset objects, not AssetsDefinition objects. You don't need "
                    "to include source assets in a selection for downstream assets to be able to "
                    "read them."
                )
            if invalid_key_strs:
                raise DagsterInvalidSubsetError(
                    f"AssetKey(s) {invalid_key_strs} were selected, but no AssetsDefinition objects supply "
                    "these keys. Make sure all keys are spelled correctly, and all AssetsDefinitions "
                    "are correctly added to the repository."
                )
            return specified_key_strs
        elif isinstance(node, OrAssetSelection):
            child_1, child_2 = [self._resolve(child) for child in node.children]
            return child_1 | child_2
        elif isinstance(node, UpstreamAssetSelection):
            child = self._resolve(node.children[0])
            connected = self._fetch_connected_with_roots(child, "upstream", node.depth)
            # Walking upstream can reach SourceAsset keys. Source assets are not selectable
            # (the KeysAssetSelection branch rejects them), so drop them from the result.
            return connected - self.source_asset_key_strs
        else:
            check.failed(f"Unknown node type: {type(node)}")
def _match_groups(assets_def: AssetsDefinition, groups: AbstractSet[str]) -> AbstractSet[str]:
    """
    Collects the keys of ``assets_def`` whose group name appears in ``groups``.

    Keys are returned as user strings (``AssetKey.to_user_string()``).
    """
    matched = set()
    for key, group_name in assets_def.group_names_by_key.items():
        if group_name in groups:
            matched.add(key.to_user_string())
    return matched