Source code for flight_worker.flight_worker

# -*- coding: utf-8 -*-
# https://github.com/cython/cython/wiki/enhancements-compilerdirectives
# cython: infer_types=True
# cython: boundscheck=False
# cython: cdivision=True

import glob
import logging
import math
import os
import re
import traceback
import warnings
from datetime import datetime as dt
from pathlib import Path

from agi_env import normalize_path
from agi_node import MutableNamespace
from agi_node.polars_worker import PolarsWorker
from agi_env.agi_logger import AgiLogger
from flight.flight_args import UNSUPPORTED_DATA_SOURCE_MESSAGE
from flight.reduction import write_reduce_artifact

logger = AgiLogger.get_logger(__name__)
warnings.filterwarnings("ignore")

import polars as pl

_EARTH_RADIUS_M = 6_371_000.0


def _haversine_distance_m(row) -> float:
    if row["prev_lat"] is None or row["prev_long"] is None:
        return 0.0
    try:
        lat1 = math.radians(float(row["prev_lat"]))
        lon1 = math.radians(float(row["prev_long"]))
        lat2 = math.radians(float(row["lat"]))
        lon2 = math.radians(float(row["long"]))
    except (TypeError, ValueError):
        return 0.0

    dlat = lat2 - lat1
    dlon = lon2 - lon1
    a = (
        math.sin(dlat / 2.0) ** 2
        + math.cos(lat1) * math.cos(lat2) * math.sin(dlon / 2.0) ** 2
    )
    return 2.0 * _EARTH_RADIUS_M * math.asin(math.sqrt(min(1.0, a)))
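
# Illustrative sanity check for the helper above (not part of the worker API):
# one degree of latitude spans R * pi / 180 meters on the sphere used here,
# so a row moving from (0, 0) to (1, 0) should yield about 111_195 m.
#
#   _haversine_distance_m(
#       {"prev_lat": 0.0, "prev_long": 0.0, "lat": 1.0, "long": 0.0}
#   )  # -> approximately 111194.93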


class FlightWorker(PolarsWorker):
    """Class derived from AgiDataWorker."""

    pool_vars = {}

    def preprocess_df(self, df: pl.DataFrame) -> pl.DataFrame:
        """Preprocess the DataFrame by parsing the date column and creating
        previous-coordinate columns. This operation is done once per file.
        """
        df = df.with_columns(
            [
                # Convert the date column from string to datetime only once.
                pl.col("date")
                .str.strptime(pl.Datetime, format="%Y-%m-%d %H:%M:%S")
                .alias("date"),
                # Create shifted columns for the previous latitude and longitude.
                pl.col("lat").shift(1).alias("prev_lat"),
                pl.col("long").shift(1).alias("prev_long"),
            ]
        )
        return df

    def calculate_speed(self, new_column_name: str, df: pl.DataFrame) -> pl.DataFrame:
        """Compute the segment distance in meters between consecutive coordinate
        pairs and add it under the legacy ``speed`` column name used by this
        demo. Assumes the previous-coordinate columns are already present.
        """
        df = df.with_columns(
            [
                pl.struct(["prev_lat", "prev_long", "lat", "long"])
                .map_elements(
                    _haversine_distance_m,
                    return_dtype=pl.Float64,
                )
                .alias(new_column_name),
            ]
        )
        return df
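    # A minimal sketch of the two steps above chained on a toy frame (column
    # names match what work_pool() expects; the values are made up):
    #
    #   df = pl.DataFrame({
    #       "date": ["2024-01-01 00:00:00", "2024-01-01 00:01:00"],
    #       "aircraft": ["A1", "A1"],
    #       "lat": [48.85, 48.86],
    #       "long": [2.35, 2.36],
    #   })
    #   df = worker.preprocess_df(df)
    #   df = worker.calculate_speed("speed", df)
    #   # "speed" is 0.0 for the first row (no previous fix), then the
    #   # haversine distance in meters between the consecutive fixes.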

    def start(self):
        """Initialize global variables and set up paths."""
        global global_vars

        if not isinstance(self.args, MutableNamespace):
            if isinstance(self.args, dict):
                payload = self.args
            else:
                payload = vars(self.args)
            self.args = MutableNamespace(**payload)

        logging.info(f"from: {__file__}")

        if getattr(self, "pool_vars", None) is None:
            self.pool_vars = {}

        reset_target = getattr(self.args, "reset_target", False)
        data_paths = self.setup_data_directories(
            source_path=self.args.data_in,
            target_path=self.args.data_out,
            target_subdir="dataframe",
            reset_target=reset_target,
        )
        self.args.data_in = data_paths.normalized_input

        if self.verbose > 1:
            logging.info(
                f"Worker #{self._worker_id} dataframe root path = {self.data_out}"
            )
        if self.verbose > 0:
            logging.info(f"start worker_id {self._worker_id}\n")

        args = self.args
        if not getattr(args, "data_source", None):
            args.data_source = "file"
        if not getattr(args, "output_format", None):
            args.output_format = "parquet"
        if args.data_source != "file":
            raise NotImplementedError(UNSUPPORTED_DATA_SOURCE_MESSAGE)

        self.pool_vars["args"] = self.args
        self.pool_vars["verbose"] = self.verbose
        global_vars = self.pool_vars

    def work_init(self):
        """Per-work initialization from the shared space (currently a no-op)."""
        global global_vars

    def pool_init(self, worker_vars):
        """Initialize the pool with worker variables.

        Args:
            worker_vars (dict): Variables specific to the worker.
        """
        global global_vars
        global_vars = worker_vars

    def work_pool(self, file):
        """Parse IVQ log files.

        Args:
            file (str): The log file to parse.

        Returns:
            pl.DataFrame: Parsed data.
        """
        global global_vars

        data_source = getattr(global_vars["args"], "data_source", None)
        if not data_source:
            data_source = "file"

        if data_source == "file":
            file_path = Path(os.path.expanduser(str(file)))
            if not file_path.is_absolute():
                file_path = Path(os.path.expanduser(f"~/{file}"))
            if os.name != "nt":
                file = os.path.normpath(str(file_path)).replace("\\", "/")
            else:
                file = normalize_path(str(file_path))
            if not Path(file).is_file():
                raise FileNotFoundError(file)
        else:
            raise NotImplementedError(UNSUPPORTED_DATA_SOURCE_MESSAGE)

        # Read the CSV file using Polars.
        df = pl.read_csv(file)

        # If the first column is redundant (e.g. named "Unnamed: 0"), drop it.
        if df.columns and (df.columns[0].startswith("Unnamed") or df.columns[0] == ""):
            df = df.drop(df.columns[0])

        # Preprocess the DataFrame (date parsing, cleaning, etc.).
        df = self.preprocess_df(df)

        # Preserve the historical output column name expected by downstream demos.
        df = self.calculate_speed("speed", df)

        return df.with_columns(pl.lit(Path(file).name).alias("source_file"))
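    # Input sketch (illustrative values, not a real sample): work_pool() itself
    # only needs "date" strings in "%Y-%m-%d %H:%M:%S" format plus "lat" and
    # "long" columns, while work_done() additionally groups on "aircraft":
    #
    #   date,aircraft,lat,long
    #   2024-01-01 00:00:00,A1,48.85,2.35
    #   2024-01-01 00:01:00,A1,48.86,2.36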

    def work_done(self, worker_df):
        """Save the results partitioned per aircraft and write the reduce artifact.

        Args:
            worker_df (pl.DataFrame): Output dataframe aggregated by this worker.
        """
        if worker_df.is_empty():
            return

        logger.info(f"mkdir {self.data_out}")
        os.makedirs(self.data_out, exist_ok=True)
        output_files = []

        # Process each plane separately.
        for plane in worker_df.select(pl.col("aircraft")).unique().to_series():
            plane_df = worker_df.filter(pl.col("aircraft") == plane).sort("date")
            # Create (or replace) "worker_id" from "aircraft".
            plane_df = plane_df.with_columns(pl.col("aircraft").alias("worker_id"))
            try:
                if self.args.output_format == "parquet":
                    filename = (Path(self.data_out) / str(plane)).with_suffix(
                        ".parquet"
                    )
                    plane_df.write_parquet(str(filename))
                    output_files.append(filename)
                elif self.args.output_format == "csv":
                    timestamp = dt.now().strftime("%Y-%m-%d_%H-%M-%S")
                    filename = f"{self.data_out}/{plane}_{timestamp}.csv"
                    plane_df.write_csv(str(filename))
                    output_files.append(filename)
                logging.info(
                    f"Saved dataframe for plane {plane} with shape {plane_df.shape} in {filename}"
                )
            except Exception as e:
                logging.info(traceback.format_exc())
                raise RuntimeError(f"Error saving dataframe for plane {plane}") from e

        write_reduce_artifact(
            worker_df,
            self.data_out,
            worker_id=getattr(self, "_worker_id", 0),
            output_files=output_files,
            output_format=getattr(self.args, "output_format", ""),
        )

    def stop(self):
        """Finalize the worker by listing saved dataframes."""
        try:
            files = glob.glob(os.path.join(self.data_out, "**"), recursive=True)
            df_files = [f for f in files if re.search(r"\.(csv|parquet|json)$", f)]
            n_df = len(df_files)
            if self.verbose > 0:
                logging.info(f"{n_df} dataframes")
                for f in df_files:
                    logging.info("\t" + str(Path(f)))
                if not n_df:
                    logging.info("No dataframe created")
        except Exception as err:
            logging.info(f"Error while trying to find dataframes: {err}")

        # Call the base class stop().
        super().stop()
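
# Typical lifecycle, as suggested by the method names and the PolarsWorker
# base class (a sketch only; the exact orchestration lives in agi_node, which
# drives these calls, and "flight1.csv" is a hypothetical file name):
#
#   worker.start()                        # resolve paths, fill pool_vars
#   worker.pool_init(worker.pool_vars)    # publish shared state to the pool
#   df = worker.work_pool("flight1.csv")  # parse one file into a DataFrame
#   worker.work_done(df)                  # write per-aircraft parquet/csv
#   worker.stop()                         # list outputs, then super().stop()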