Source code for agi_node.polars_worker.polars_worker

# BSD 3-Clause License
#
# [License Text Remains Unchanged]
#
# (Include the full BSD 3-Clause License text here)

"""

data_worker Framework Callback Functions Module
===============================================

This module provides the `PolarsWorker` class, which extends the foundational
functionalities of `BaseWorker` for processing data using a thread pool or
single-threaded approaches. The pool path deliberately uses threads rather
than processes: polars releases the GIL in its native kernels, so threads
parallelise IO/native work without process spawn and pickling costs.

Classes:
    PolarsWorker: Worker class for data processing tasks.

External Libraries:
    concurrent.futures.ThreadPoolExecutor
    pathlib.Path
    polars as pl
    BaseWorker from node import BaseWorker


"""

# External Libraries:
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path

from agi_node.agi_dispatcher import BaseWorker
from agi_node.agi_dispatcher import worker_pool_support

import polars as pl
import logging
logger = logging.getLogger(__name__)


def _thread_pool_factory(*, max_workers, initializer, initargs):
    """Build the thread pool for the pool execution path.

    ``ThreadPoolExecutor`` is resolved through this module's global namespace
    so tests can monkeypatch ``polars_worker.ThreadPoolExecutor``.
    """
    return ThreadPoolExecutor(
        max_workers=max_workers,
        initializer=initializer,
        initargs=initargs,
    )


def _concat_labeled(frames, labels):
    """Label each frame with its worker_id and concatenate vertically."""
    labeled = [
        frame.with_columns(pl.lit(label).alias("worker_id"))
        for frame, label in zip(frames, labels)
    ]
    return pl.concat(labeled, how="vertical")


_POLARS_POOL_HOOKS = worker_pool_support.PoolFrameHooks(
    family="PolarsWorker",
    executor_kind="thread",
    executor_factory=_thread_pool_factory,
    is_frame=lambda result: isinstance(result, pl.DataFrame),
    is_empty=lambda df: df.is_empty(),
    concat_labeled=_concat_labeled,
    empty_frame=pl.DataFrame,
)

[docs] class PolarsWorker(BaseWorker): """ PolarsWorker Class -------------------- Inherits from :class:`BaseWorker` to provide extended data processing functionalities. Attributes: verbose (int): Verbosity level for logging. data_out (str): Path to the output directory. worker_id (int): Identifier for the worker instance. args (dict): Configuration arguments for the worker. """
[docs] def work_pool(self, x: any = None) -> pl.DataFrame: # ty: ignore[invalid-type-form] """ Processes a single task. Args: x (any, optional): The task to process. Defaults to None. Returns: pl.DataFrame: A Polars DataFrame with the processed results. """ logging.info("work_pool") # Call the actual work_pool method, which should return a Polars DataFrame. # Ensure that the original _actual_work_pool method is refactored accordingly. return self._actual_work_pool(x) # ty: ignore[unresolved-attribute]
[docs] def work_done(self, df: pl.DataFrame = None) -> None: # ty: ignore[invalid-parameter-default] """ Handles the post-processing of the DataFrame after `work_pool` execution. Args: df (pl.DataFrame, optional): The Polars DataFrame to process. Defaults to None. Raises: ValueError: If an unsupported output format is specified. """ logging.info("work_done") if df is None or df.is_empty(): return # Example post-processing logic using Polars. # For instance, saving the DataFrame to disk. output_format = self.args.get("output_format") # work_done is called once per work chunk; suffix subsequent chunks so # they do not overwrite the first chunk's output file. chunk_index = getattr(self, "_work_done_chunk", 0) self._work_done_chunk = chunk_index + 1 output_filename = ( f"{self._worker_id}_output" if chunk_index == 0 else f"{self._worker_id}_output_{chunk_index}" ) if output_format == "parquet": output_path = Path(self.data_out) / f"{output_filename}.parquet" df.write_parquet(output_path) elif output_format == "csv": output_path = Path(self.data_out) / f"{output_filename}.csv" df.write_csv(output_path) else: raise ValueError("Unsupported output format")
[docs] def works(self, workers_plan: any, workers_plan_metadata: any) -> float: # ty: ignore[invalid-type-form] """ Executes worker tasks based on the distribution tree. Args: workers_plan (any): Distribution tree structure. workers_plan_metadata (any): Additional information about the workers. Returns: float: Execution time of this works() call in seconds. """ return worker_pool_support.run_works(self, workers_plan, workers_plan_metadata)
def _exec_multi_process(self, workers_plan: any, workers_plan_metadata: any) -> None: # ty: ignore[invalid-type-form] """ Executes tasks with a thread pool shared across all plan chunks. Args: workers_plan (any): Distribution tree structure. workers_plan_metadata (any): Additional information about the workers. """ worker_pool_support.exec_multi_process( self, workers_plan, workers_plan_metadata, _POLARS_POOL_HOOKS ) def _exec_mono_process(self, workers_plan: any, workers_plan_metadata: any) -> None: # ty: ignore[invalid-type-form] """ Executes tasks in single-threaded mode. Args: workers_plan (any): Distribution tree structure. workers_plan_metadata (any): Additional information about the workers. """ worker_pool_support.exec_mono_process( self, workers_plan, workers_plan_metadata, _POLARS_POOL_HOOKS )