# BSD 3-Clause License
#
# [License Text Remains Unchanged]
#
# (Include the full BSD 3-Clause License text here)
"""
data_worker Framework Callback Functions Module
===============================================
This module provides the `PolarsWorker` class, which extends the foundational
functionality of `BaseWorker` for processing data either concurrently through a
thread pool or in a single-threaded loop.
Classes:
PolarsWorker: Worker class for data processing tasks.
Internal Libraries:
os, warnings, logging, time
External Libraries:
concurrent.futures.ThreadPoolExecutor
pathlib.Path
polars as pl
AgiEnv, normalize_path from agi_env
BaseWorker from agi_node.agi_dispatcher
"""
# Internal Libraries:
import os
import warnings
from typing import Any
# External Libraries:
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
import time
from agi_env import AgiEnv, normalize_path
from agi_node.agi_dispatcher import BaseWorker
import polars as pl
import logging
warnings.filterwarnings("ignore")
logger = logging.getLogger(__name__)
class PolarsWorker(BaseWorker):
"""
PolarsWorker Class
--------------------
Inherits from :class:`BaseWorker` to provide extended data processing functionalities.
Attributes:
verbose (int): Verbosity level for logging.
data_out (str): Path to the output directory.
worker_id (int): Identifier for the worker instance.
args (dict): Configuration arguments for the worker.
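Example:
Illustrative sketch only; the subclass name, the CSV work items and the
``args`` values are assumptions, and it presumes the concrete worker supplies
``_actual_work_pool``::
    class MyPolarsWorker(PolarsWorker):
        def _actual_work_pool(self, path):
            # Turn one work item (here, a CSV path) into a Polars DataFrame.
            return pl.read_csv(path)
    # work_done() then persists the result, e.g. args = {"output_format": "parquet"}.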
"""
def work_pool(self, x: Any = None) -> pl.DataFrame:
"""
Processes a single task.
Args:
x (Any, optional): The task to process. Defaults to None.
Returns:
pl.DataFrame: A Polars DataFrame with the processed results.
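Example:
Hypothetical call, assuming a subclass whose ``_actual_work_pool`` reads a
CSV work item (the file name is illustrative)::
    df = worker.work_pool("part-0001.csv")
    assert isinstance(df, pl.DataFrame)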
"""
logging.info("work_pool")
# Call the actual work_pool method, which should return a Polars DataFrame.
# Ensure that the original _actual_work_pool method is refactored accordingly.
return self._actual_work_pool(x)
def work_done(self, df: pl.DataFrame = None) -> None:
"""
Handles the post-processing of the DataFrame after `work_pool` execution.
Args:
df (pl.DataFrame, optional): The Polars DataFrame to process. Defaults to None.
Raises:
ValueError: If an unsupported output format is specified.
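Example:
Illustrative values only; the worker id, output directory and format are
assumptions::
    # With self.args = {"output_format": "csv"} and self.data_out = "/tmp/out",
    # a worker with id 3 writes its results to /tmp/out/3_output.csv.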
"""
logging.info("work_done")
if df is None or df.is_empty():
return
# Example post-processing logic using Polars.
# For instance, saving the DataFrame to disk.
output_format = self.args.get("output_format")
output_filename = f"{self._worker_id}_output"
if output_format == "parquet":
output_path = Path(self.data_out) / f"{output_filename}.parquet"
df.write_parquet(output_path)
elif output_format == "csv":
output_path = Path(self.data_out) / f"{output_filename}.csv"
df.write_csv(output_path)
else:
raise ValueError(f"Unsupported output format: {output_format!r}")
def works(self, workers_plan: Any, workers_plan_metadata: Any) -> float:
"""
Executes worker tasks based on the distribution tree.
Args:
workers_plan (Any): Distribution tree structure.
workers_plan_metadata (Any): Additional information about the workers.
Returns:
float: Execution time in seconds.
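Note:
The pooled path is taken when ``self._mode & 4`` is non-zero; otherwise the
plan is executed in a single-threaded loop.
Example (illustrative call; the plan structure is framework-specific)::
    elapsed = worker.works(workers_plan, workers_plan_metadata)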
"""
if workers_plan:
if self._mode & 4:
self._exec_multi_process(workers_plan, workers_plan_metadata)
else:
self._exec_mono_process(workers_plan, workers_plan_metadata)
self.stop()
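# Fall back to the current time if BaseWorker._t0 was never initialised,
# so the elapsed time below is always defined (it is then close to zero).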
if BaseWorker._t0 is None:
BaseWorker._t0 = time.time()
return time.time() - BaseWorker._t0
def _exec_multi_process(self, workers_plan: Any, workers_plan_metadata: Any) -> None:
"""
Executes tasks concurrently through a thread pool (the ``self._mode & 4`` path).
Args:
workers_plan (Any): Distribution tree structure.
workers_plan_metadata (Any): Additional information about the workers.
"""
ncore = 1
works = []
if isinstance(workers_plan, list):
for i in workers_plan[self._worker_id]:
works += i
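# Use at most one pool slot per work item, capped at the machine's CPU count.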
ncore = max(min(len(works), int(os.cpu_count())), 1)
logger.info(
f"PolarsWorker.work - ncore {ncore} - worker_id #{self._worker_id}"
f" - work_pool x {len(works)}",
)
self.work_init()
for work_id, work in enumerate(workers_plan[self._worker_id]):
list_df = []
df = pl.DataFrame()
ncore = max(min(len(work), int(os.cpu_count())), 1)
# The "spawn" start method applies on every platform and would be used if this
# pool were switched back to a process-based executor (see the context below).
process_factory_type = "spawn"
# mp_ctx = multiprocessing.get_context(process_factory_type)
with ThreadPoolExecutor(
# mp_context=mp_ctx,
max_workers=ncore,
initializer=self.pool_init,
initargs=(self.pool_vars,),
) as executor:
# Map each work item to work_pool.
dfs = executor.map(self.work_pool, work)
# Post-pool processing: collect non-empty DataFrames.
for df_result in dfs:
# Check if the result DataFrame has columns.
if df_result.shape[1] > 0:
list_df.append(df_result)
# Additional processing per result can be done here.
if list_df:
# Add a 'worker_id' to each DataFrame.
for idx, df_result in enumerate(list_df):
list_df[idx] = df_result.with_columns(
pl.lit(str((self._worker_id, idx))).alias("worker_id")
)
# Concatenate all DataFrames into a single DataFrame.
df = pl.concat(list_df, how="vertical")
# Handle the concatenated DataFrame.
self.work_done(df if not df.is_empty() else pl.DataFrame())
def _exec_mono_process(self, workers_plan: Any, workers_plan_metadata: Any) -> None:
"""
Executes tasks in single-threaded mode.
Args:
workers_plan (Any): Distribution tree structure.
workers_plan_metadata (Any): Additional information about the workers.
"""
self.work_init()
for work_id, work in enumerate(workers_plan[self._worker_id]):
list_df = []
df = pl.DataFrame()
logger.info(
f"PolarsWorker.work - monoprocess work #{work_id} - work_pool x {len(work)}"
)
if workers_plan:
# Apply the work_pool function to each item in work.
dfs = [self.work_pool(file) for file in work]
# Ensure that the returned objects are Polars DataFrames.
if dfs and isinstance(dfs[0], pl.DataFrame):
for df_result in dfs:
if not df_result.is_empty():
list_df.append(df_result)
if list_df:
# Add a 'worker_id' to each DataFrame.
for idx, df_result in enumerate(list_df):
list_df[idx] = df_result.with_columns(
pl.lit(str((self._worker_id, 0))).alias("worker_id")
)
# Concatenate all DataFrames into a single DataFrame.
df = pl.concat(list_df, how="vertical")
# Handle the concatenated DataFrame.
self.work_done(df)