Source code for dagster_polars.io_managers.delta

from collections import defaultdict
from collections.abc import Mapping, Sequence
from enum import Enum
from pprint import pformat
from typing import TYPE_CHECKING, Any, Optional, Union

import packaging.version
import polars as pl
from dagster import InputContext, MetadataValue, MultiPartitionKey, OutputContext
from dagster._annotations import experimental
from dagster._core.errors import DagsterInvariantViolationError
from dagster._core.storage.upath_io_manager import is_dict_type

from dagster_polars.io_managers.base import BasePolarsUPathIOManager

try:
    from deltalake import DeltaTable
    from deltalake.exceptions import TableNotFoundError
except ImportError as e:
    if "deltalake" in str(e):
        raise ImportError(
            "Install 'dagster-polars[deltalake]' to use DeltaLake functionality"
        ) from e
    else:
        raise e

if TYPE_CHECKING:
    from upath import UPath


SINGLE_LOADING_TYPES = (pl.DataFrame, pl.LazyFrame)


class DeltaWriteMode(str, Enum):
    error = "error"
    append = "append"
    overwrite = "overwrite"
    ignore = "ignore"


@experimental
class PolarsDeltaIOManager(BasePolarsUPathIOManager):
    """Implements writing and reading DeltaLake tables.

    Features:
     - All features provided by :py:class:`~dagster_polars.BasePolarsUPathIOManager`.
     - All read/write options can be set via corresponding metadata or config parameters
       (metadata takes precedence).
     - Supports native DeltaLake partitioning by storing different asset partitions in the same
       DeltaLake table. To enable this behavior, set the `partition_by` metadata value or config
       parameter **and** use a non-dict type annotation when loading the asset. The `partition_by`
       value will be used in `delta_write_options` of `pl.DataFrame.write_delta` and in
       `pyarrow_options` of `pl.scan_delta`. When using a one-dimensional `PartitionsDefinition`,
       it should be a single string like `"column"`. When using a `MultiPartitionsDefinition`,
       it should be a dict mapping dimensions to column names, like `{"dimension": "column"}`.

    Install `dagster-polars[delta]` to use this IOManager.

    Examples:
        .. code-block:: python

            import polars as pl
            from dagster import Definitions, asset
            from dagster_polars import PolarsDeltaIOManager

            @asset(
                io_manager_key="polars_delta_io_manager",
                key_prefix=["my_dataset"]
            )
            def my_asset() -> pl.DataFrame:
                # data will be stored at <base_dir>/my_dataset/my_asset.delta
                ...

            defs = Definitions(
                assets=[my_asset],
                resources={
                    "polars_delta_io_manager": PolarsDeltaIOManager(base_dir="s3://my-bucket/my-dir")
                }
            )

        Appending to a DeltaLake table:

        .. code-block:: python

            @asset(
                io_manager_key="polars_delta_io_manager",
                metadata={
                    "mode": "append"
                },
            )
            def my_table() -> pl.DataFrame:
                ...

        Using native DeltaLake partitioning by storing different asset partitions in the same
        DeltaLake table:

        .. code-block:: python

            from dagster import AssetExecutionContext, StaticPartitionsDefinition
            from dagster_polars import LazyFramePartitions

            @asset(
                io_manager_key="polars_delta_io_manager",
                metadata={
                    "partition_by": "partition_col"
                },
                partitions_def=StaticPartitionsDefinition(["a", "b", "c"])
            )
            def upstream(context: AssetExecutionContext) -> pl.DataFrame:
                df = ...

                # the column holding the partition_key must match the `partition_by` metadata value
                return df.with_columns(pl.lit(context.partition_key).alias("partition_col"))

            @asset
            def downstream(upstream: pl.LazyFrame) -> pl.DataFrame:
                ...

        When using a `MultiPartitionsDefinition`, the `partition_by` metadata value should be a
        dictionary mapping dimensions to column names.

        .. code-block:: python

            from dagster import (
                AssetExecutionContext,
                DailyPartitionsDefinition,
                MultiPartitionsDefinition,
                StaticPartitionsDefinition,
            )
            from dagster_polars import LazyFramePartitions

            @asset(
                io_manager_key="polars_delta_io_manager",
                metadata={
                    "partition_by": {"time": "date", "clients": "client"}  # dimension -> column mapping
                },
                partitions_def=MultiPartitionsDefinition(
                    {
                        "time": DailyPartitionsDefinition(...),
                        "clients": StaticPartitionsDefinition(...)
                    }
                )
            )
            def upstream(context: AssetExecutionContext) -> pl.DataFrame:
                df = ...

                partition_keys_by_dimension = context.partition_key.keys_by_dimension

                return df.with_columns(
                    pl.lit(partition_keys_by_dimension["time"]).alias("date"),  # time dimension -> date column
                    pl.lit(partition_keys_by_dimension["clients"]).alias("client")  # clients dimension -> client column
                )

            @asset
            def downstream(upstream: pl.LazyFrame) -> pl.DataFrame:
                ...
""" extension: str = ".delta" mode: DeltaWriteMode = DeltaWriteMode.overwrite.value # type: ignore overwrite_schema: bool = False version: Optional[int] = None def sink_df_to_path( self, context: OutputContext, df: pl.LazyFrame, path: "UPath", ): context_metadata = context.definition_metadata or {} streaming = context_metadata.get("streaming", False) return self.write_df_to_path(context, df.collect(streaming=streaming), path) def write_df_to_path( self, context: OutputContext, df: pl.DataFrame, path: "UPath", ): context_metadata = context.definition_metadata or {} delta_write_options = context_metadata.get( "delta_write_options" ) # This needs to be gone and just only key value on the metadata if context.has_asset_partitions: delta_write_options = delta_write_options or {} # this could be wrong, you could have partition_by in # delta_write_options and in the metadata partition_by = context_metadata.get("partition_by") if partition_by is not None: assert ( context.partition_key is not None ), "Cannot set `partition_by` for an asset without partitions." if isinstance(partition_by, dict) and isinstance( context.partition_key, MultiPartitionKey ): delta_write_options["partition_by"] = list(partition_by.values()) elif isinstance(partition_by, str) and isinstance(context.partition_key, str): delta_write_options["partition_by"] = partition_by else: raise ValueError( "Metadata value for `partition_by` must be a string " "for single-partitioned assets or a dictionary for " "multi-partitioned assets.\n" f"Found: `{partition_by}`" ) if (engine := delta_write_options.get("engine", "pyarrow")) == "rust": delta_write_options["predicate"] = self.get_predicate(context) elif engine == "pyarrow": delta_write_options["partition_filters"] = self.get_partition_filters(context) else: raise NotImplementedError(f"Invalid engine: {engine}") if delta_write_options is not None: context.log.debug(f"Writing with delta_write_options: {pformat(delta_write_options)}") storage_options = self.storage_options try: dt = DeltaTable(str(path), storage_options=storage_options) except TableNotFoundError: dt = str(path) df.write_delta( dt, mode=context_metadata.get("mode") or self.mode.value, overwrite_schema=context_metadata.get("overwrite_schema") or self.overwrite_schema, storage_options=storage_options, delta_write_options=delta_write_options, ) if isinstance(dt, DeltaTable): current_version = dt.version() else: current_version = DeltaTable( str(path), storage_options=storage_options, without_files=True ).version() context.add_output_metadata({"version": current_version}) def scan_df_from_path( self, path: "UPath", context: InputContext, ) -> pl.LazyFrame: """This method scans a DeltaLake table into a `polars.LazyFrame`. It can be called in 3 different situations: 1. with an unpartitioned asset 2. with a partitioned asset without native partitioning enabled - multiple times on nested .delta tables 3. with a partitioned asset and with native partitioning enabled - a single time on the .delta table. 
        In case (3) partition filters are applied so that only the mapped partitions are loaded.
        """
        assert context.upstream_output is not None
        assert context.upstream_output.definition_metadata is not None

        context_metadata = context.definition_metadata or {}
        version = self.get_delta_version_to_load(path, context)
        context.log.debug(f"Reading Delta table with version: {version}")

        pyarrow_options = context_metadata.get("pyarrow_options", {})
        partition_by = context.upstream_output.definition_metadata.get("partition_by")
        # we want to apply partition filters when loading some partitions, but not all partitions
        if (
            partition_by
            and len(context.asset_partition_keys) > 0
            and context.has_asset_key is not None
            and context.has_asset_partitions is not None
            and context.asset_partition_keys
            != set(
                context.upstream_output.asset_partitions_def.get_partition_keys(
                    dynamic_partitions_store=context.instance
                )
            )
        ):
            pyarrow_options["partitions"] = self.get_partition_filters(context)

        if pyarrow_options:
            context.log.debug(f"Reading with pyarrow_options: {pyarrow_options}")

        delta_table_options = context_metadata.get("delta_table_options")
        if delta_table_options:
            context.log.debug(f"Reading with delta_table_options: {delta_table_options}")

        return pl.scan_delta(
            str(path),
            version=version,
            delta_table_options=delta_table_options,
            storage_options=self.storage_options,
            **_get_pyarrow_options_kwargs(pyarrow_options),
        )

    def load_partitions(self, context: InputContext):
        assert context.upstream_output is not None
        # any partition would work as they are all stored in the same DeltaLake table
        path = self._get_path_without_extension(context)
        context.log.debug(
            f"Loading {len(context.asset_partition_keys)} partitions from {path} using {self.__class__.__name__}..."
        )

        if context.upstream_output.definition_metadata.get("partition_by") and not is_dict_type(
            context.dagster_type.typing_type
        ):
            # user enabled native partitioning and wants a `pl.DataFrame` or `pl.LazyFrame`
            return self.load_from_path(context, self._with_extension(path))
        else:
            # default behaviour
            return super().load_partitions(context)

    def get_path_for_partition(
        self, context: Union[InputContext, OutputContext], path: "UPath", partition: str
    ) -> "UPath":
        if isinstance(context, InputContext):
            if (
                context.upstream_output is not None
                and context.upstream_output.definition_metadata is not None
                and context.upstream_output.definition_metadata.get("partition_by") is not None
            ):
                # upstream asset has "partition_by" metadata set,
                # so partitioning for it is handled by DeltaLake itself
                return path

        if isinstance(context, OutputContext):
            if (
                context.definition_metadata is not None
                and context.definition_metadata.get("partition_by") is not None
            ):
                # this asset has "partition_by" metadata set,
                # so partitioning for it is handled by DeltaLake itself
                return path

        return path / partition  # partitioning is handled by the IOManager

    @staticmethod
    def get_partition_filters(
        context: Union[InputContext, OutputContext],
    ) -> Sequence[tuple[str, str, Any]]:
        """Create filters for `deltalake` to know which partitions are overwritten.

        Only used with the PyArrow engine, which will be removed in `deltalake>=1`.
""" if isinstance(context, OutputContext): partition_by = context.definition_metadata.get("partition_by") elif isinstance(context, InputContext) and context.upstream_output is not None: partition_by = context.upstream_output.definition_metadata.get("partition_by") else: raise DagsterInvariantViolationError(f"Invalid context type: {type(context)}") if partition_by is None or not context.has_asset_partitions: filters = [] elif isinstance(partition_by, dict): all_keys_by_dim = defaultdict(list) for partition_key in context.asset_partition_keys: assert isinstance(partition_key, MultiPartitionKey), ( f"Received dict `partition_by` metadata value {partition_by}, " f"but the partition_key is not a `MultiPartitionKey`: {partition_key}" ) for dim, key in partition_key.keys_by_dimension.items(): all_keys_by_dim[dim].append(key) filters = [(partition_by[dim], "in", keys) for dim, keys in all_keys_by_dim.items()] elif isinstance(partition_by, str): assert not isinstance(context.asset_partition_keys[0], MultiPartitionKey), ( f"Received string `partition_by` metadata value `{partition_by}`, " f"but the partition_key is not a `MultiPartitionKey`: {context.asset_partition_keys[0]}" ) filters = [(partition_by, "in", context.asset_partition_keys)] else: raise NotImplementedError(f"Unsupported `partition_by` metadata value: {partition_by}") return filters @staticmethod def get_predicate( context: Union[InputContext, OutputContext], ) -> Optional[str]: """Create a predicate for `deltalake` to select which partitions are overwritten. Returns `None` if the entire table is overwritten. See documentation here: https://delta-io.github.io/delta-rs/usage/writing/#overwriting-part-of-the-table-data-using-a-predicate """ if isinstance(context, OutputContext): partition_by = context.definition_metadata.get("partition_by") elif isinstance(context, InputContext) and context.upstream_output is not None: partition_by = context.upstream_output.definition_metadata.get("partition_by") else: raise DagsterInvariantViolationError(f"Invalid context type: {type(context)}") def key_to_predicate(key): return f"'{key}'" if partition_by is None or not context.has_asset_partitions: return elif isinstance(partition_by, dict): all_keys_by_dim = defaultdict(list) for partition_key in context.asset_partition_keys: assert isinstance(partition_key, MultiPartitionKey), ( f"Received dict `partition_by` metadata value `{partition_by}`, " f"but the partition_key is not a `MultiPartitionKey`: `{partition_key}`." 
                )
                for dim, key in partition_key.keys_by_dimension.items():
                    all_keys_by_dim[dim].append(key)

            predicate = " AND ".join(
                [
                    f"{partition_by[dim]} in ({', '.join(map(key_to_predicate, keys))})"
                    for dim, keys in all_keys_by_dim.items()
                ]
            )
        elif isinstance(partition_by, str):
            assert not isinstance(context.asset_partition_keys[0], MultiPartitionKey), (
                f"Received string `partition_by` metadata value `{partition_by}`, "
                f"but the partition_key is a `MultiPartitionKey`: {context.asset_partition_keys[0]}"
            )
            if len(context.asset_partition_keys) == 1:
                predicate = f"{partition_by} = '{context.asset_partition_keys[0]}'"
            else:
                predicate = (
                    f"{partition_by} in "
                    f"({', '.join(map(key_to_predicate, context.asset_partition_keys))})"
                )
        else:
            raise NotImplementedError(f"Unsupported `partition_by` metadata value: {partition_by}")

        return predicate

    def get_metadata(
        self, context: OutputContext, obj: Union[pl.DataFrame, pl.LazyFrame, None]
    ) -> dict[str, MetadataValue]:
        context_metadata = context.definition_metadata or {}

        metadata = super().get_metadata(context, obj)

        if context.has_asset_partitions:
            partition_by = context_metadata.get("partition_by")
            if partition_by is not None:
                metadata["partition_by"] = partition_by

        if context_metadata.get("mode") == "append":
            # modify the metadata to reflect the fact that we are appending to the table
            if context.has_asset_partitions:
                # paths = self._get_paths_for_partitions(context)
                # assert len(paths) == 1
                # path = list(paths.values())[0]
                # FIXME: what to do about row_count metadata if we are appending to a partitioned table?
                # we should not be using the full table length,
                # but it's unclear how to get the length of the partition we are appending to
                pass
            else:
                metadata["append_row_count"] = metadata["dagster/row_count"]

                path = self._get_path(context)
                # we need to get row_count from the full table
                metadata["dagster/row_count"] = MetadataValue.int(
                    DeltaTable(str(path), storage_options=self.storage_options)
                    .to_pyarrow_dataset()
                    .count_rows()
                )

        return metadata

    def get_delta_version_to_load(self, path: "UPath", context: InputContext) -> int:
        context_metadata = context.definition_metadata or {}
        version_from_metadata = context_metadata.get("version")

        version_from_config = self.version

        version: Optional[int] = None

        if version_from_metadata is not None and version_from_config is not None:
            context.log.warning(
                f"Both version from metadata ({version_from_metadata}) "
                f"and config ({version_from_config}) are set. Using version from metadata."
            )
            version = int(version_from_metadata)
        elif version_from_metadata is None and version_from_config is not None:
            version = int(version_from_config)
        elif version_from_metadata is not None and version_from_config is None:
            version = int(version_from_metadata)

        if version is None:
            return DeltaTable(
                str(path), storage_options=self.storage_options, without_files=True
            ).version()
        else:
            return version
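
# --- Illustrative note (not part of the library source) ---
# For a string `partition_by` such as "date" and asset partition keys
# ["2024-01-01", "2024-01-02"], the helpers above produce:
#   get_partition_filters -> [("date", "in", ["2024-01-01", "2024-01-02"])]
#   get_predicate         -> "date in ('2024-01-01', '2024-01-02')"
# With a single partition key, the predicate collapses to "date = '2024-01-01'".
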
# polars>=1.14.0 requires use_pyarrow=True when setting pyarrow_options
def _get_pyarrow_options_kwargs(pyarrow_options: Mapping[str, object]) -> Mapping[str, Any]:
    kwargs: dict[str, object] = {"pyarrow_options": pyarrow_options}
    if packaging.version.parse(pl.__version__) >= packaging.version.parse("1.14.0"):
        kwargs["use_pyarrow"] = True
    return kwargs
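
# --- Illustrative usage sketch (assumed names, not part of the library source) ---
# A hypothetical way to pin the Delta table version to load: set the `version`
# config field on the IO manager, or a `version` metadata entry on the input
# (per `get_delta_version_to_load` above, the metadata value takes precedence).
# The `upstream` asset is assumed to exist and to be stored by this IO manager.
#
#     import polars as pl
#     from dagster import AssetIn, Definitions, asset
#     from dagster_polars import PolarsDeltaIOManager
#
#     @asset(ins={"upstream": AssetIn(metadata={"version": 3})})
#     def downstream(upstream: pl.LazyFrame) -> pl.DataFrame:
#         return upstream.collect()
#
#     defs = Definitions(
#         assets=[downstream],
#         resources={"io_manager": PolarsDeltaIOManager(base_dir="s3://bucket/dir", version=2)},
#     )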