# -*- coding: utf-8 -*-
# https://github.com/cython/cython/wiki/enhancements-compilerdirectives
# cython: infer_types=True
# cython: boundscheck=False
# cython: cdivision=True
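"""Flight demo worker built on PolarsWorker: reads flight position CSV logs,
derives haversine segment distances, and writes one dataframe per aircraft."""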
import glob
import logging
import math
import os
import re
import traceback
import warnings
from datetime import datetime as dt
from pathlib import Path
from agi_env import normalize_path
from agi_node import MutableNamespace
from agi_node.polars_worker import PolarsWorker
from agi_env.agi_logger import AgiLogger
from flight.flight_args import UNSUPPORTED_DATA_SOURCE_MESSAGE
from flight.reduction import write_reduce_artifact
logger = AgiLogger.get_logger(__name__)
warnings.filterwarnings("ignore")
import polars as pl
_EARTH_RADIUS_M = 6_371_000.0
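# Great-circle distance between consecutive fixes via the haversine formula:
#   a = sin^2(dlat / 2) + cos(lat1) * cos(lat2) * sin^2(dlon / 2)
#   d = 2 * R * asin(sqrt(a))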
def _haversine_distance_m(row) -> float:
if row["prev_lat"] is None or row["prev_long"] is None:
return 0.0
try:
lat1 = math.radians(float(row["prev_lat"]))
lon1 = math.radians(float(row["prev_long"]))
lat2 = math.radians(float(row["lat"]))
lon2 = math.radians(float(row["long"]))
except (TypeError, ValueError):
return 0.0
dlat = lat2 - lat1
dlon = lon2 - lon1
a = (
math.sin(dlat / 2.0) ** 2
+ math.cos(lat1) * math.cos(lat2) * math.sin(dlon / 2.0) ** 2
)
return 2.0 * _EARTH_RADIUS_M * math.asin(math.sqrt(min(1.0, a)))
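# Illustrative sanity check (values rounded): moving one degree in both latitude
# and longitude from the origin is roughly 157 km:
#   _haversine_distance_m({"prev_lat": 0.0, "prev_long": 0.0, "lat": 1.0, "long": 1.0})
#   -> ~157250.0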
class FlightWorker(PolarsWorker):
"""Class derived from AgiDataWorker"""
pool_vars = {}
def preprocess_df(self, df: pl.DataFrame) -> pl.DataFrame:
"""
Preprocess the DataFrame by parsing the date column and creating
previous coordinate columns. This operation is done once per file.
"""
df = df.with_columns(
[
# Convert date column from string to datetime only once
pl.col("date")
.str.strptime(pl.Datetime, format="%Y-%m-%d %H:%M:%S")
.alias("date"),
# Create shifted columns for previous latitude and longitude
pl.col("lat").shift(1).alias("prev_lat"),
pl.col("long").shift(1).alias("prev_long"),
]
)
return df
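    # Assumed input schema for this demo: "date" strings formatted as
    # "%Y-%m-%d %H:%M:%S" plus numeric "lat"/"long" columns; preprocess_df adds
    # "prev_lat"/"prev_long" so every row also carries the preceding fix.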
def calculate_speed(self, new_column_name: str, df: pl.DataFrame) -> pl.DataFrame:
"""
Compute the segment distance in meters between consecutive coordinate pairs
and add it under the legacy ``speed`` column name used by this demo.
Assumes that the previous coordinate columns are already present.
"""
df = df.with_columns(
[
pl.struct(["prev_lat", "prev_long", "lat", "long"])
.map_elements(
_haversine_distance_m,
return_dtype=pl.Float64,
)
.alias(new_column_name),
]
)
return df
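    # Note: the same distance could also be expressed with pure Polars expressions
    # (e.g. pl.col("lat").radians() and friends), avoiding the per-row Python
    # callback; map_elements is kept here so the formula stays easy to read.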
def start(self):
"""Initialize global variables and setup paths."""
global global_vars
if not isinstance(self.args, MutableNamespace):
if isinstance(self.args, dict):
payload = self.args
else:
payload = vars(self.args)
self.args = MutableNamespace(**payload)
logging.info(f"from: {__file__}")
if getattr(self, "pool_vars", None) is None:
self.pool_vars = {}
reset_target = getattr(self.args, "reset_target", False)
data_paths = self.setup_data_directories(
source_path=self.args.data_in,
target_path=self.args.data_out,
target_subdir="dataframe",
reset_target=reset_target,
)
data_in = data_paths.normalized_input
self.args.data_in = data_in
if self.verbose > 1:
logging.info(
f"Worker #{self._worker_id} dataframe root path = {self.data_out}"
)
if self.verbose > 0:
logging.info(f"start worker_id {self._worker_id}\n")
args = self.args
if not getattr(args, "data_source", None):
args.data_source = "file"
if not getattr(args, "output_format", None):
args.output_format = "parquet"
if args.data_source != "file":
raise NotImplementedError(UNSUPPORTED_DATA_SOURCE_MESSAGE)
self.pool_vars["args"] = self.args
self.pool_vars["verbose"] = self.verbose
global_vars = self.pool_vars
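        # pool_vars collects the state that pool_init() expects to receive in each
        # worker process; there it is rebound to the module-level global_vars that
        # work_pool() reads for args and verbosity.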
def work_init(self):
"""Initialize work by reading from shared space."""
global global_vars
pass
def pool_init(self, worker_vars):
"""Initialize the pool with worker variables.
Args:
worker_vars (dict): Variables specific to the worker.
"""
global global_vars
global_vars = worker_vars
def work_pool(self, file):
"""Parse IVQ log files.
Args:
file (str): The log file to parse.
Returns:
pl.DataFrame: Parsed data.
"""
global global_vars
data_source = getattr(global_vars["args"], "data_source", None)
if not data_source:
data_source = "file"
if data_source == "file":
file_path = Path(os.path.expanduser(str(file)))
if not file_path.is_absolute():
file_path = Path(os.path.expanduser(f"~/{file}"))
if os.name != "nt":
file = os.path.normpath(str(file_path)).replace("\\", "/")
else:
file = normalize_path(str(file_path))
if not Path(file).is_file():
raise FileNotFoundError(file)
else:
raise NotImplementedError(UNSUPPORTED_DATA_SOURCE_MESSAGE)
# Read the CSV file using Polars.
df = pl.read_csv(file)
# If the first column is redundant (e.g. named "Unnamed: 0"), drop it.
if df.columns and (df.columns[0].startswith("Unnamed") or df.columns[0] == ""):
df = df.drop(df.columns[0])
# Preprocess the DataFrame (date parsing, cleaning, etc.)
df = self.preprocess_df(df)
# Preserve the historical output column name expected by downstream demos.
df = self.calculate_speed("speed", df)
return df.with_columns(pl.lit(Path(file).name).alias("source_file"))
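    # The frame returned above carries the parsed "date", the "speed" distance
    # column and a "source_file" tag; work_done() later splits rows by "aircraft"
    # and sorts by "date" before writing per-plane files.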
def work_done(self, worker_df):
"""Concatenate dataframe if any and save the results.
Args:
worker_df (pl.DataFrame): Output dataframe for one plane.
"""
if worker_df.is_empty():
return
logger.info(f"mkdir {self.data_out}")
os.makedirs(self.data_out, exist_ok=True)
output_files = []
# Process each plane separately
for plane in worker_df.select(pl.col("aircraft")).unique().to_series():
plane_df = worker_df.filter(pl.col("aircraft") == plane).sort("date")
# Create (or replace) "worker_id" from "aircraft":
plane_df = plane_df.with_columns(pl.col("aircraft").alias("worker_id"))
try:
if self.args.output_format == "parquet":
filename = (Path(self.data_out) / str(plane)).with_suffix(
".parquet"
)
plane_df.write_parquet(str(filename))
output_files.append(filename)
elif self.args.output_format == "csv":
timestamp = dt.now().strftime("%Y-%m-%d_%H-%M-%S")
filename = f"{self.data_out}/{str(plane)+'_'+timestamp}.csv"
plane_df.write_csv(str(filename))
output_files.append(filename)
logging.info(
f"Saved dataframe for plane {plane} with shape {plane_df.shape} in {filename}"
)
except Exception as e:
                logging.error(traceback.format_exc())
raise RuntimeError(f"Error saving dataframe for plane {plane}") from e
write_reduce_artifact(
worker_df,
self.data_out,
worker_id=getattr(self, "_worker_id", 0),
output_files=output_files,
output_format=getattr(self.args, "output_format", ""),
)
    def stop(self):
        """Finalize the worker by listing saved dataframes."""
        try:
files = glob.glob(os.path.join(self.data_out, "**"), recursive=True)
df_files = [f for f in files if re.search(r"\.(csv|parquet|json)$", f)]
n_df = len(df_files)
if self.verbose > 0:
logging.info(f"{n_df} dataframes")
for f in df_files:
logging.info("\t" + str(Path(f)))
if not n_df:
logging.info("No dataframe created")
except Exception as err:
logging.info(f"Error while trying to find dataframes: {err}")
# call the base class stop()
super().stop()