Source code for dagster._core.definitions.asset_selection

import operator
from abc import ABC
from functools import reduce
from typing import AbstractSet, FrozenSet, Optional, Sequence, Union

import dagster._check as check
from dagster._annotations import public
from dagster._core.errors import DagsterInvalidSubsetError
from dagster._core.selector.subset_selector import (
    fetch_connected,
    generate_asset_dep_graph,
    generate_asset_name_to_definition_map,
)

from .assets import AssetsDefinition
from .events import AssetKey, CoercibleToAssetKey
from .source_asset import SourceAsset


class AssetSelection(ABC):
    """
    An AssetSelection defines a query over a set of assets, normally all the assets in a repository.

    You can use the "|" and "&" operators to create unions and intersections of asset selections,
    respectively.

    AssetSelections are typically used with :py:func:`define_asset_job`.

    Examples:

        .. code-block:: python

            # Select all assets in group "marketing":
            AssetSelection.groups("marketing")

            # Select all assets in group "marketing", as well as the asset with key "promotion":
            AssetSelection.groups("marketing") | AssetSelection.keys("promotion")

            # Select all assets in group "marketing" that are downstream of asset "leads":
            AssetSelection.groups("marketing") & AssetSelection.keys("leads").downstream()
    """

    @public  # type: ignore
    @staticmethod
    def all() -> "AllAssetSelection":
        """Returns a selection that includes all assets."""
        return AllAssetSelection()

    @public  # type: ignore
    @staticmethod
    def assets(*assets_defs: AssetsDefinition) -> "KeysAssetSelection":
        """Returns a selection that includes all of the provided assets."""
        # Flatten the keys of every provided definition into one keys-based selection.
        return KeysAssetSelection(*(key for assets_def in assets_defs for key in assets_def.keys))

    @public  # type: ignore
    @staticmethod
    def keys(*asset_keys: CoercibleToAssetKey) -> "KeysAssetSelection":
        """Returns a selection that includes assets with any of the provided keys."""
        _asset_keys = [AssetKey.from_coerceable(key) for key in asset_keys]
        return KeysAssetSelection(*_asset_keys)

    @public  # type: ignore
    @staticmethod
    def groups(*group_strs: str) -> "GroupsAssetSelection":
        """Returns a selection that includes assets that belong to any of the provided groups."""
        check.tuple_param(group_strs, "group_strs", of_type=str)
        return GroupsAssetSelection(*group_strs)

    @public  # type: ignore
    def downstream(self, depth: Optional[int] = None) -> "DownstreamAssetSelection":
        """
        Returns a selection that includes all assets that are downstream of any of the assets in
        this selection, as well as all the assets in this selection.

        Args:
            depth (Optional[int]): If provided, then only include assets to the given depth. A depth
                of 2 means all assets that are children or grandchildren of the assets in this
                selection.
        """
        check.opt_int_param(depth, "depth")
        return DownstreamAssetSelection(self, depth=depth)

    @public  # type: ignore
    def upstream(self, depth: Optional[int] = None) -> "UpstreamAssetSelection":
        """
        Returns a selection that includes all assets that are upstream of any of the assets in
        this selection, as well as all the assets in this selection.

        Args:
            depth (Optional[int]): If provided, then only include assets to the given depth. A depth
                of 2 means all assets that are parents or grandparents of the assets in this
                selection.
        """
        check.opt_int_param(depth, "depth")
        return UpstreamAssetSelection(self, depth=depth)

    def __or__(self, other: "AssetSelection") -> "OrAssetSelection":
        """Builds the union node for ``self | other``; ``other`` must be an AssetSelection."""
        check.inst_param(other, "other", AssetSelection)
        return OrAssetSelection(self, other)

    def __and__(self, other: "AssetSelection") -> "AndAssetSelection":
        """Builds the intersection node for ``self & other``; ``other`` must be an AssetSelection."""
        check.inst_param(other, "other", AssetSelection)
        return AndAssetSelection(self, other)

    def resolve(
        self, all_assets: Sequence[Union[AssetsDefinition, SourceAsset]]
    ) -> FrozenSet[AssetKey]:
        """
        Evaluates this selection against a concrete collection of assets.

        Args:
            all_assets (Sequence[Union[AssetsDefinition, SourceAsset]]): The assets to run the
                query over. The sequence is validated with ``check.sequence_param``.

        Returns:
            FrozenSet[AssetKey]: The keys produced by ``Resolver(all_assets).resolve(self)``.
        """
        check.sequence_param(all_assets, "all_assets", (AssetsDefinition, SourceAsset))
        return Resolver(all_assets).resolve(self)
class AllAssetSelection(AssetSelection):
    """Selection node that resolves to every key in the Resolver's name-to-definition map."""


class AndAssetSelection(AssetSelection):
    """Selection node that resolves to the intersection of its two child selections."""

    def __init__(self, child_1: AssetSelection, child_2: AssetSelection):
        self.children = (child_1, child_2)


class DownstreamAssetSelection(AssetSelection):
    """Selection node that resolves to its child selection plus assets downstream of it."""

    def __init__(self, child: AssetSelection, *, depth: Optional[int] = None):
        self.children = (child,)
        # Passed through unchanged to fetch_connected; None is the default.
        self.depth = depth


class GroupsAssetSelection(AssetSelection):
    """Selection node whose children are group names rather than nested selections."""

    def __init__(self, *children: str):
        self.children = children


class KeysAssetSelection(AssetSelection):
    """Selection node whose children are AssetKeys rather than nested selections."""

    def __init__(self, *children: AssetKey):
        self.children = children


class OrAssetSelection(AssetSelection):
    """Selection node that resolves to the union of its two child selections."""

    def __init__(self, child_1: AssetSelection, child_2: AssetSelection):
        self.children = (child_1, child_2)


class UpstreamAssetSelection(AssetSelection):
    """Selection node that resolves to its child selection plus assets upstream of it."""

    def __init__(self, child: AssetSelection, *, depth: Optional[int] = None):
        self.children = (child,)
        # Passed through unchanged to fetch_connected; None is the default.
        self.depth = depth


# ########################
# ##### RESOLUTION
# ########################


class Resolver:
    """
    Evaluates an AssetSelection tree against a concrete collection of assets.

    Internally, assets are identified by the user-string form of their AssetKey; ``resolve``
    converts the final result back into AssetKey objects.
    """

    def __init__(self, all_assets: Sequence[Union[AssetsDefinition, SourceAsset]]):
        """
        Args:
            all_assets (Sequence[Union[AssetsDefinition, SourceAsset]]): The assets that
                selections are resolved against. Any other element type fails a check.
        """
        assets_defs = []
        source_assets = []
        for asset in all_assets:
            if isinstance(asset, SourceAsset):
                source_assets.append(asset)
            elif isinstance(asset, AssetsDefinition):
                assets_defs.append(asset)
            else:
                check.failed(f"Expected SourceAsset or AssetsDefinition, got {type(asset)}")

        self.assets_defs = assets_defs
        self.asset_dep_graph = generate_asset_dep_graph(assets_defs, source_assets)
        # Only AssetsDefinitions are selectable; source assets are tracked separately so that
        # selecting one by key can be reported with a dedicated error message.
        self.all_assets_by_key_str = generate_asset_name_to_definition_map(assets_defs)
        self.source_asset_key_strs = {
            source_asset.key.to_user_string() for source_asset in source_assets
        }

    def resolve(self, root_node: AssetSelection) -> FrozenSet[AssetKey]:
        """Resolves ``root_node`` and returns the selected keys as a frozenset of AssetKeys."""
        return frozenset(
            AssetKey.from_user_string(asset_name) for asset_name in self._resolve(root_node)
        )

    def _resolve(self, node: AssetSelection) -> AbstractSet[str]:
        """
        Recursively resolves ``node`` to a set of asset key strings.

        Raises:
            DagsterInvalidSubsetError: If a KeysAssetSelection names a key that is supplied by a
                SourceAsset, or a key that no AssetsDefinition supplies.
        """
        if isinstance(node, AllAssetSelection):
            return set(self.all_assets_by_key_str.keys())
        elif isinstance(node, AndAssetSelection):
            child_1, child_2 = [self._resolve(child) for child in node.children]
            return child_1 & child_2
        elif isinstance(node, DownstreamAssetSelection):
            return self._resolve_connected(node, direction="downstream")
        elif isinstance(node, GroupsAssetSelection):
            groups = set(node.children)
            # The explicit set() seed keeps the fold well-defined when there are no
            # AssetsDefinitions; reduce() without an initial value raises TypeError on an
            # empty sequence.
            return reduce(
                operator.or_,
                (_match_groups(assets_def, groups) for assets_def in self.assets_defs),
                set(),
            )
        elif isinstance(node, KeysAssetSelection):
            specified_key_strs = {child.to_user_string() for child in node.children}
            invalid_key_strs = specified_key_strs - set(self.all_assets_by_key_str.keys())
            selected_source_asset_key_strs = specified_key_strs & self.source_asset_key_strs
            # Source-asset keys are checked first so they get the more specific message.
            if selected_source_asset_key_strs:
                raise DagsterInvalidSubsetError(
                    f"AssetKey(s) {selected_source_asset_key_strs} were selected, but these keys are "
                    "supplied by SourceAsset objects, not AssetsDefinition objects. You don't need "
                    "to include source assets in a selection for downstream assets to be able to "
                    "read them."
                )
            if invalid_key_strs:
                raise DagsterInvalidSubsetError(
                    f"AssetKey(s) {invalid_key_strs} were selected, but no AssetsDefinition objects supply "
                    "these keys. Make sure all keys are spelled correctly, and all AssetsDefinitions "
                    "are correctly added to the repository."
                )
            return specified_key_strs
        elif isinstance(node, OrAssetSelection):
            child_1, child_2 = [self._resolve(child) for child in node.children]
            return child_1 | child_2
        elif isinstance(node, UpstreamAssetSelection):
            return self._resolve_connected(node, direction="upstream")
        else:
            check.failed(f"Unknown node type: {type(node)}")

    def _resolve_connected(
        self,
        node: Union[DownstreamAssetSelection, UpstreamAssetSelection],
        direction: str,
    ) -> AbstractSet[str]:
        """
        Resolves the single child of ``node`` and unions every selected asset with the assets
        that ``fetch_connected`` reports for it in ``direction``, honoring ``node.depth``.

        An empty child selection resolves to an empty set.
        """
        selected = self._resolve(node.children[0])
        # Seed with set() so an empty child selection yields an empty result instead of
        # reduce() raising TypeError.
        return reduce(
            operator.or_,
            (
                {asset_name}
                | fetch_connected(
                    item=asset_name,
                    graph=self.asset_dep_graph,
                    direction=direction,
                    depth=node.depth,
                )
                for asset_name in selected
            ),
            set(),
        )


def _match_groups(assets_def: AssetsDefinition, groups: AbstractSet[str]) -> AbstractSet[str]:
    """Returns the user-string keys of the assets in ``assets_def`` whose group is in ``groups``."""
    return {
        asset_key.to_user_string()
        for asset_key, group in assets_def.group_names_by_key.items()
        if group in groups
    }