# BSD 3-Clause License
#
# Copyright (c) 2025, Jean-Pierre Morard, THALES SIX GTS France SAS
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
# 3. Neither the name of Jean-Pierre Morard nor the names of its contributors, or THALES SIX GTS France SAS, may be used to endorse or promote products derived from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""
Worker node module: ``BaseWorker`` base class and ``ArgsNamespace`` helper.
Author: Jean-Pierre Morard
"""
######################################################
# AGI framework callback functions
######################################################
# Internal Libraries:
import abc
import asyncio
from contextlib import suppress
import getpass
import inspect
import json
import os
import pickle
import shutil
import stat
import subprocess
import sys
import tempfile
import threading
import time
import uuid
import traceback
import warnings
from pathlib import Path, PureWindowsPath
from types import SimpleNamespace
from typing import Any, Callable, ClassVar, Dict, Iterable, List, Optional, Union
# External Libraries:
import numpy as np
from distutils.sysconfig import get_python_lib
import psutil
import humanize
import datetime
import logging
from copy import deepcopy
from agi_env import AgiEnv, normalize_path
from agi_env.agi_logger import AgiLogger
from . import base_worker_execution_support as execution_support
from . import base_worker_path_support as path_support
from . import base_worker_runtime_support as runtime_support
from . import base_worker_service_support as service_support
logger = AgiLogger.get_logger(__name__)
warnings.filterwarnings("ignore")
[docs]
class BaseWorker(abc.ABC):
    """
    class BaseWorker v1.0

    Base class for AGI workers. Most worker state lives in class-level
    attributes (``_worker_id``, ``_worker``, ``_insts``, ``env`` ...) that the
    static helpers of this class read through ``BaseWorker`` directly.
    """
    # Worker instances keyed by worker id (looked up by ``loop``). Shared on
    # purpose: it is a class-level registry.
    _insts = {}
    _built = None
    _pool_init = None
    _work_pool = None
    _share_path = None
    verbose = 1
    _mode = None
    _worker_id = None
    _worker = None
    _home_dir = None
    _logs = None
    _dask_home = None
    _t0 = None
    # Evaluated once at import time: a login name starting with "T0" marks a
    # managed PC.
    _is_managed_pc = getpass.getuser().startswith("T0")
    _cython_decorators = ["njit"]
    env: Optional[AgiEnv] = None
    # Argument helpers used by from_toml, to_toml, setup_args and as_dict.
    default_settings_path: ClassVar[str] = "app_settings.toml"
    default_settings_section: ClassVar[str] = "args"
    args_loader: ClassVar[Callable[..., Any] | None] = None
    args_merger: ClassVar[Callable[[Any, Optional[Any]], Any] | None] = None
    args_ensure_defaults: ClassVar[Callable[..., Any] | None] = None
    args_dumper: ClassVar[Callable[..., None] | None] = None
    args_dump_mode: ClassVar[str] = "json"
    # Managed-PC path remapping configuration.
    managed_pc_home_suffix: ClassVar[str] = "MyApp"
    managed_pc_path_fields: ClassVar[tuple[str, ...]] = ()
    # Service-loop bookkeeping keyed by worker id, guarded by ``_service_lock``
    # (a plain, non-reentrant ``threading.Lock``).
    _service_stop_events: ClassVar[Dict[int, threading.Event]] = {}
    _service_active: ClassVar[Dict[int, bool]] = {}
    _service_lock: ClassVar[threading.Lock] = threading.Lock()
    # Default ``poll_interval`` of ``loop``, in seconds.
    _service_poll_default: ClassVar[float] = 1.0
@classmethod
def _require_args_helper(cls, attr_name: str) -> Callable[..., Any]:
helper = getattr(cls, attr_name, None)
if helper is None:
raise AttributeError(
f"{cls.__name__} must define `{attr_name}` to use argument helpers"
)
return helper
@classmethod
def _remap_managed_pc_path(
    cls,
    value: Path | str,
    *,
    env: AgiEnv | None = None,
) -> Path:
    """Remap ``value`` through ``path_support.remap_managed_pc_path``.

    When ``env`` is not given (or falsy), the class-level ``env`` is used.
    """
    effective_env = env or cls.env
    return path_support.remap_managed_pc_path(
        value,
        env=effective_env,
        home_factory=Path.home,
        path_cls=Path,
        managed_pc_home_suffix=cls.managed_pc_home_suffix,
    )
@classmethod
def _apply_managed_pc_path_overrides(
    cls,
    args: Any,
    *,
    env: AgiEnv | None = None,
) -> Any:
    """Remap, in place, the path fields of ``args`` listed in ``managed_pc_path_fields``.

    ``_ensure_managed_pc_share_dir`` is always called first with ``env`` as
    given. Each listed field that exists on ``args`` is replaced with the
    result of ``_remap_managed_pc_path``. Fields whose remapping raises
    ``TypeError`` or ``ValueError`` keep their original value.

    Returns:
        The same ``args`` object.
    """
    cls._ensure_managed_pc_share_dir(env)
    field_names = cls.managed_pc_path_fields
    for field_name in field_names or ():
        if hasattr(args, field_name):
            current = getattr(args, field_name)
            try:
                remapped = cls._remap_managed_pc_path(current, env=env)
            except (TypeError, ValueError):
                # Best effort: values that cannot be remapped are left untouched.
                continue
            setattr(args, field_name, remapped)
    return args
def _apply_managed_pc_paths(self, args: Any) -> Any:
    """Instance shortcut for ``_apply_managed_pc_path_overrides`` using ``self.env``."""
    worker_cls = type(self)
    return worker_cls._apply_managed_pc_path_overrides(args, env=self.env)
@classmethod
def _ensure_managed_pc_share_dir(cls, env: AgiEnv | None) -> None:
    """Delegate to ``path_support.ensure_managed_pc_share_dir``.

    Presumably creates the managed-PC share directory; ``env`` is forwarded
    as given (no fallback to ``cls.env``).
    """
    suffix = cls.managed_pc_home_suffix
    path_support.ensure_managed_pc_share_dir(
        env,
        home_factory=Path.home,
        path_cls=Path,
        managed_pc_home_suffix=suffix,
    )
@classmethod
def _normalized_path(cls, value: Path | str) -> Path:
    """Normalise ``value`` with ``path_support.normalized_path`` and ``normalize_path``."""
    helper = path_support.normalized_path
    return helper(value, path_cls=Path, normalize_path_fn=normalize_path)
@staticmethod
def _share_root_path(env: AgiEnv | None) -> Path | None:
    """Return the share root for ``env`` (or ``None``) from ``path_support.share_root_path``."""
    root = path_support.share_root_path(env, path_cls=Path)
    return root
@classmethod
def _resolve_data_dir(
    cls,
    env: AgiEnv | None,
    data_path: Path | str | None,
) -> Path:
    """Resolve ``data_in`` style values relative to the current environment.

    Delegates to ``path_support.resolve_data_dir``. The managed-PC remapping
    it receives is bound to ``env``.
    """
    def _remap_for_env(value):
        return cls._remap_managed_pc_path(value, env=env)

    return path_support.resolve_data_dir(
        env,
        data_path,
        share_root_path_fn=cls._share_root_path,
        remap_managed_pc_path_fn=_remap_for_env,
        normalized_path_fn=cls._normalized_path,
        path_cls=Path,
        home_factory=Path.home,
    )
@staticmethod
def _relative_to_user_home(path: Path) -> Path | None:
    """Return ``path`` relative to the user home via ``path_support`` (``None`` when not applicable)."""
    relative = path_support.relative_to_user_home(path, path_cls=Path)
    return relative
@staticmethod
def _remap_user_home(path: Path, *, username: str) -> Path | None:
    """Remap ``path`` onto ``username``'s home via ``path_support.remap_user_home``."""
    remapped = path_support.remap_user_home(path, path_cls=Path, username=username)
    return remapped
@staticmethod
def _strip_share_prefix(path: Path, aliases: set[str]) -> Path:
    """Strip a leading share alias from ``path`` via ``path_support.strip_share_prefix``."""
    stripped = path_support.strip_share_prefix(path, aliases, path_cls=Path)
    return stripped
@staticmethod
def _can_create_path(path: Path) -> bool:
    """Ask ``path_support.can_create_path`` whether ``path`` can be created."""
    allowed = path_support.can_create_path(path, path_cls=Path)
    return allowed
@staticmethod
def _collect_share_aliases(
    env: AgiEnv | None, share_base: Path
) -> set[str]:
    """Collect the share-path aliases for ``env``/``share_base`` via ``path_support``."""
    aliases = path_support.collect_share_aliases(env, share_base, path_cls=Path)
    return aliases
@staticmethod
def _iter_input_files(
    folder: Path,
    *,
    patterns: Iterable[str] | None = None,
) -> list[Path]:
    """List the input files of ``folder`` (optionally filtered by ``patterns``) via ``path_support``."""
    files = path_support.iter_input_files(folder, patterns=patterns)
    return files
@classmethod
def _has_min_input_files(
    cls,
    folder: Path,
    *,
    min_files: int = 1,
    patterns: Iterable[str] | None = None,
) -> bool:
    """Tell whether ``folder`` holds at least ``min_files`` input files.

    Delegates to ``path_support.has_min_input_files``, which lists files with
    this class's ``_iter_input_files``.
    """
    lister = cls._iter_input_files
    return path_support.has_min_input_files(
        folder,
        iter_input_files_fn=lister,
        patterns=patterns,
        min_files=min_files,
    )
@classmethod
def _candidate_named_dataset_roots(
    cls,
    env: AgiEnv | None,
    dataset_root: Path | str,
    *,
    namespace: str | None = None,
    parent_levels: int = 4,
) -> list[Path]:
    """List candidate locations for ``dataset_root`` via ``path_support``.

    ``namespace`` and ``parent_levels`` are forwarded unchanged; path
    normalisation and share-root lookup use this class's helpers.
    """
    return path_support.candidate_named_dataset_roots(
        env,
        dataset_root,
        path_cls=Path,
        share_root_path_fn=cls._share_root_path,
        normalized_path_fn=cls._normalized_path,
        parent_levels=parent_levels,
        namespace=namespace,
    )
[docs]
def prepare_output_dir(
    self,
    root: Path | str,
    *,
    subdir: str = "dataframe",
    attribute: str = "data_out",
    clean: bool = True,
) -> Path:
    """Create (and optionally reset) a deterministic output directory.

    Args:
        root: Parent of the output directory.
        subdir: Name of the output directory created under ``root``.
        attribute: Attribute of ``self`` that receives the resulting path.
        clean: When true, an existing output directory is removed first.

    Returns:
        The output directory path, also stored as ``self.<attribute>``.

    OSError/RuntimeError while cleaning and OSError while creating the
    directory are logged as warnings rather than raised.
    """
    target = Path(normalize_path(Path(root) / subdir))
    if clean and target.exists():
        try:
            # ``ignore_errors`` must stay false: with ``ignore_errors=True``
            # shutil never calls the handler, so read-only entries would be
            # silently left behind instead of being made writable and retried.
            shutil.rmtree(target, onerror=self._onerror)
        except (OSError, RuntimeError) as exc:  # pragma: no cover - defensive guard
            logger.warning(
                "Issue while cleaning output directory %s: %s", target, exc
            )
    try:
        logger.info(f"mkdir {target}")
        target.mkdir(parents=True, exist_ok=True)
    except OSError as exc:  # pragma: no cover - defensive guard
        logger.warning(
            "Issue while ensuring output directory %s exists: %s", target, exc
        )
    setattr(self, attribute, target)
    return target
[docs]
def setup_args(
    self,
    args: Any,
    *,
    env: AgiEnv | None = None,
    error: str | None = None,
    output_field: str | None = None,
    output_subdir: str = "dataframe",
    output_attr: str = "data_out",
    output_clean: bool = True,
    output_parents_up: int = 0,
) -> Any:
    """Validate, complete and store worker arguments on ``self.args``.

    Args:
        args: Argument object; must not be ``None``.
        env: Environment; defaults to ``self.env`` when present.
        error: Custom message for the ``ValueError`` raised on ``None`` args.
        output_field: Name of an attribute of the processed args holding a
            path. When set, ``prepare_output_dir`` is called on that path.
        output_subdir: Forwarded as ``subdir`` to ``prepare_output_dir``.
        output_attr: Forwarded as ``attribute`` to ``prepare_output_dir``.
        output_clean: Forwarded as ``clean`` to ``prepare_output_dir``.
        output_parents_up: Parent levels to climb from the field's path
            first (negative values count as 0).

    Returns:
        The processed arguments (also stored on ``self.args``).
    """
    env = env or getattr(self, "env", None)
    if args is None:
        message = error or f"{type(self).__name__} requires an initialized arguments object"
        raise ValueError(message)
    worker_cls = type(self)
    defaults_fn = getattr(worker_cls, "args_ensure_defaults", None)
    if defaults_fn is not None:
        args = defaults_fn(args, env=env)
    processed = worker_cls._apply_managed_pc_path_overrides(args, env=env)
    self.args = processed
    if not output_field:
        return processed
    output_root = Path(getattr(processed, output_field))
    levels = max(output_parents_up, 0)
    while levels > 0:
        output_root = output_root.parent
        levels -= 1
    self.prepare_output_dir(
        output_root,
        subdir=output_subdir,
        attribute=output_attr,
        clean=output_clean,
    )
    return processed
[docs]
@classmethod
def from_toml(
    cls,
    env: AgiEnv,
    settings_path: str | Path | None = None,
    section: str | None = None,
    **overrides: Any,
) -> "BaseWorker":
    """Build a worker from a TOML settings file.

    The arguments are built in four steps:

    1. ``args_loader`` loads ``section`` of ``settings_path``. Both default to
       the class-level ``default_settings_*`` values.
    2. ``args_merger`` merges in ``overrides``, passing ``None`` when there
       are none.
    3. ``args_ensure_defaults`` is applied when the class defines it.
    4. The managed-PC path overrides are applied.

    The worker is then created as ``cls(env, args=...)``.

    Raises:
        AttributeError: if ``args_loader`` or ``args_merger`` is not defined.
    """
    path = settings_path or cls.default_settings_path
    section_name = section or cls.default_settings_section
    load = cls._require_args_helper("args_loader")
    merge = cls._require_args_helper("args_merger")
    worker_args = merge(load(path, section=section_name), overrides or None)
    defaults_fn = getattr(cls, "args_ensure_defaults", None)
    if defaults_fn is not None:
        worker_args = defaults_fn(worker_args, env=env)
    worker_args = cls._apply_managed_pc_path_overrides(worker_args, env=env)
    return cls(env, args=worker_args)
[docs]
def to_toml(
self,
settings_path: str | Path | None = None,
section: str | None = None,
create_missing: bool = True,
) -> None:
_cls = type(self)
settings_path = settings_path or _cls.default_settings_path
section = section or _cls.default_settings_section
dumper = _cls._require_args_helper("args_dumper")
dumper(self.args, settings_path, section=section, create_missing=create_missing)
[docs]
def as_dict(self, mode: str | None = None) -> dict[str, Any]:
payload: dict[str, Any]
if hasattr(self, "args"):
dump_mode = mode or type(self).args_dump_mode
payload = self.args.model_dump(mode=dump_mode)
else:
payload = {}
return self._extend_payload(payload)
def _extend_payload(self, payload: dict[str, Any]) -> dict[str, Any]:
return payload
[docs]
@staticmethod
def start(worker_inst):
    """Invoke the concrete worker's ``start`` hook once initialised.

    The hook is only called when ``worker_inst`` exposes a truthy ``start``
    that is not this base implementation. Any exception is logged with its
    traceback and re-raised.
    """
    try:
        logger.info(
            "worker #%s: %s - mode: %s",
            BaseWorker._worker_id,
            BaseWorker._worker,
            getattr(worker_inst, "_mode", None),
        )
        hook = getattr(worker_inst, "start", None)
        if not hook or hook is BaseWorker.start:
            return
        hook()
    except Exception:  # pragma: no cover - intentional hook boundary
        logger.error("Worker start hook failed:\n%s", traceback.format_exc())
        raise
[docs]
def stop(self):
    """Log the stop request and break this worker's service loop if it is active.

    Returns:
        None.
    """
    logger.info(f"worker #{self._worker_id}: {self._worker} - mode: {self._mode}")
    with BaseWorker._service_lock:
        loop_running = bool(BaseWorker._service_active.get(self._worker_id))
    if not loop_running:
        return
    # Called outside the lock: break_loop acquires the same non-reentrant lock.
    try:
        BaseWorker.break_loop()
    except RuntimeError:
        logger.debug("break_loop raised", exc_info=True)
[docs]
@staticmethod
def loop(*, poll_interval: Optional[float] = None) -> Dict[str, Any]:
    """Run a long-lived service loop on this worker until signalled to stop.

    The derived worker can implement a ``loop`` method accepting either zero
    arguments or a single ``stop_event`` argument. When the method signature
    accepts ``stop_event`` (keyword ``stop_event`` or ``should_stop``), the
    worker implementation is responsible for honouring the event. Otherwise
    the base implementation repeatedly invokes the method and sleeps for the
    configured poll interval between calls. Returning ``False`` from the
    worker method requests termination of the loop. If the worker method
    returns an awaitable, it is run to completion before its result is
    inspected.

    Without a worker ``loop``, the service queue from
    ``service_support.resolve_service_queue_root`` is served when one is
    configured. Otherwise the call blocks until ``break_loop`` is invoked.

    Args:
        poll_interval: Seconds between worker ``loop`` calls (also passed as
            the queue ``poll``). ``None`` uses ``_service_poll_default``;
            negative values are clamped to ``0``.

    Returns:
        ``{"status": "stopped", "runtime": <seconds>}``. In queue mode, the
        payload of ``run_service_queue`` with ``"runtime"`` added.

    Raises:
        RuntimeError: if called before the worker is initialised. Exceptions
            from the worker loop are logged and re-raised.
    """
    worker_id = BaseWorker._worker_id
    worker_inst = BaseWorker._insts.get(worker_id)
    if worker_id is None or worker_inst is None:
        raise RuntimeError("BaseWorker.loop called before worker initialisation")
    poll = BaseWorker._service_poll_default if poll_interval is None else max(
        poll_interval, 0.0
    )
    # Only invoke a worker-defined loop implementation. If the worker
    # relies on BaseWorker.loop (default), block on stop_event instead of
    # recursively calling this method again.
    worker_loop = getattr(type(worker_inst), "loop", None)
    loop_fn = None
    if callable(worker_loop) and worker_loop is not BaseWorker.loop:
        loop_fn = getattr(worker_inst, "loop", None)
    accepts_event = False
    if callable(loop_fn):
        try:
            signature = inspect.signature(loop_fn)
            accepts_event = any(
                param.kind in (param.POSITIONAL_OR_KEYWORD, param.KEYWORD_ONLY)
                and param.name in {"stop_event", "should_stop"}
                for param in signature.parameters.values()
            )
        except (TypeError, ValueError):
            # Some builtins don't expose signatures; fall back to simple mode
            accepts_event = False
    worker_args = getattr(worker_inst, "args", None)
    queue_root = service_support.resolve_service_queue_root(
        worker_args,
        path_cls=Path,
    )

    def _write_heartbeat(_state: str) -> None:
        # No-op unless a service queue root is configured (replaced below).
        return

    if queue_root is not None:
        _write_heartbeat = service_support.make_heartbeat_writer(
            queue_root,
            worker_id=worker_id,
            worker_name=BaseWorker._worker,
            logger_obj=logger,
            path_cls=Path,
            open_fn=open,
            json_module=json,
            os_module=os,
            time_module=time,
        )

    def _await_result(awaitable: Any) -> Any:
        """Run an awaitable returned by the worker loop to completion."""
        if inspect.iscoroutine(awaitable):
            # Exceptions raised by the coroutine propagate unchanged; retrying
            # on another loop would re-await an already-consumed coroutine and
            # hide the original error.
            return asyncio.run(awaitable)
        # asyncio.run expects a coroutine; other awaitables need a loop.
        event_loop = asyncio.new_event_loop()
        try:
            return event_loop.run_until_complete(awaitable)
        finally:
            event_loop.close()

    start_time = time.time()
    logger.info(
        "worker #%s: %s entering service loop (poll %.3fs)",
        worker_id,
        BaseWorker._worker,
        poll,
    )
    # Register the loop only after all setup that may raise, so a setup
    # failure cannot leave the worker flagged active with no loop running.
    with BaseWorker._service_lock:
        stop_event = threading.Event()
        BaseWorker._service_stop_events[worker_id] = stop_event
        BaseWorker._service_active[worker_id] = True
    primary_exc: Optional[BaseException] = None
    try:
        if not callable(loop_fn):
            if queue_root is not None:
                payload = service_support.run_service_queue(
                    stop_event=stop_event,
                    queue_root=queue_root,
                    worker_id=worker_id,
                    worker_name=BaseWorker._worker,
                    poll=poll,
                    do_works_fn=BaseWorker._do_works,
                    write_heartbeat=_write_heartbeat,
                    logger_obj=logger,
                    path_cls=Path,
                    open_fn=open,
                    pickle_module=pickle,
                    os_module=os,
                    time_module=time,
                    traceback_module=traceback,
                )
                payload["runtime"] = time.time() - start_time
                return payload
            # No custom loop provided; block until break is requested.
            stop_event.wait()
            return {"status": "stopped", "runtime": time.time() - start_time}
        while not stop_event.is_set():
            _write_heartbeat("running")
            result = loop_fn(stop_event) if accepts_event else loop_fn()
            if inspect.isawaitable(result):
                result = _await_result(result)
            if result is False:
                break
            if accepts_event:
                # Worker manages its own waiting when it handles the stop event.
                continue
            if poll > 0:
                stop_event.wait(poll)
        # The final "stopped" heartbeat is written once, in ``finally``.
        return {"status": "stopped", "runtime": time.time() - start_time}
    except Exception as exc:  # pragma: no cover - intentional hook boundary
        primary_exc = exc
        logger.exception("Service loop failed: %s", exc)
        raise
    finally:
        _write_heartbeat("stopped")
        with BaseWorker._service_lock:
            BaseWorker._service_active.pop(worker_id, None)
            BaseWorker._service_stop_events.pop(worker_id, None)
        stop_hook = getattr(worker_inst, "stop", None)
        if callable(stop_hook):
            try:
                stop_hook()
            except Exception:  # pragma: no cover - intentional hook boundary
                logger.exception("Worker stop hook raised inside service loop", exc_info=True)
                # Do not mask an exception already propagating from the loop.
                if primary_exc is None:
                    raise
        logger.info(
            "worker #%s: %s leaving service loop (elapsed %.3fs)",
            worker_id,
            BaseWorker._worker,
            time.time() - start_time,
        )
[docs]
@staticmethod
def break_loop() -> bool:
    """Signal the service loop to exit on this worker.

    Returns:
        True if a registered stop event was found and set, False otherwise.
    """
    worker_id = BaseWorker._worker_id
    if worker_id is None:
        logger.warning("break_loop called without worker context")
        return False
    with BaseWorker._service_lock:
        event = BaseWorker._service_stop_events.get(worker_id)
    if event is not None:
        event.set()
        logger.info("worker #%s: service loop break requested", worker_id)
        return True
    logger.info("worker #%s: no active service loop to break", worker_id)
    return False
[docs]
@staticmethod
def expand_and_join(path1, path2):
"""
Join two paths after expanding the first path.
Args:
path1 (str): The first path to expand and join.
path2 (str): The second path to join with the expanded first path.
Returns:
str: The joined path.
"""
if os.name == "nt" and not BaseWorker._is_managed_pc:
path = Path(path1)
parts = path.parts
if "Users" in parts:
index = parts.index("Users") + 2
path = Path(*parts[index:])
net_path = normalize_path("\\\\127.0.0.1\\" + str(path))
try:
# your nfs account in order to mount it as net drive on windows
cmd = f'net use Z: "{net_path}" /user:your-name your-password'
logger.info(cmd)
subprocess.run(cmd, shell=True, check=True)
except (OSError, subprocess.CalledProcessError) as e:
logger.error(f"Mount failed: {e}")
return BaseWorker._join(BaseWorker.expand(path1), path2)
[docs]
@staticmethod
def expand(path, base_directory=None):
# Normalize Windows-style backslashes to POSIX forward slashes
"""
Expand a given path to an absolute path.
Args:
path (str): The path to expand.
base_directory (str, optional): The base directory to use for expanding the path. Defaults to None.
Returns:
str: The expanded absolute path.
Raises:
None
Note:
This method handles both Unix and Windows paths and expands '~' notation to the user's home directory.
"""
normalized_path = path.replace("\\", "/")
# Check if the path starts with `~`, expand to home directory only in that case
if normalized_path.startswith("~"):
expanded_path = Path(normalized_path).expanduser()
else:
# Use base_directory if provided; otherwise, assume current working directory
base_directory = (
Path(base_directory).expanduser()
if base_directory
else Path("~/").expanduser()
)
expanded_path = (base_directory / normalized_path).resolve()
if os.name != "nt":
return str(expanded_path)
else:
return normalize_path(expanded_path)
[docs]
@staticmethod
def normalize_dataset_path(data_path: Union[str, Path]) -> str:
"""Normalise any dataset directory input so workers rely on consistent paths."""
data_in_str = str(data_path)
if os.name == "nt" and data_in_str.startswith("\\\\"):
candidate = Path(PureWindowsPath(data_in_str))
else:
candidate = Path(data_in_str).expanduser()
if not candidate.is_absolute():
candidate = (Path.home() / candidate).expanduser()
try:
candidate = candidate.resolve(strict=False)
except OSError:
candidate = Path(os.path.normpath(str(candidate)))
if os.name == "nt":
resolved_str = os.path.normpath(str(candidate))
if not BaseWorker._is_managed_pc:
parts = Path(resolved_str).parts
if "Users" in parts:
mapped = Path(*parts[parts.index("Users") + 2 :])
else:
mapped = Path(resolved_str)
net_path = normalize_path(f"\\\\127.0.0.1\\{mapped}")
try:
cmd = f'net use Z: "{net_path}" /user:your-credentials'
logger.info(cmd)
subprocess.run(cmd, shell=True, check=True)
except (OSError, subprocess.CalledProcessError) as exc:
logger.info("Failed to map network drive: %s", exc)
return resolved_str
return candidate.as_posix()
[docs]
def setup_data_directories(
    self,
    *,
    source_path: str | Path,
    target_path: str | Path | None = None,
    target_subdir: str = "dataframe",
    reset_target: bool = False,
) -> SimpleNamespace:
    """Prepare normalised input/output dataset paths without relying on worker args.

    Returns a namespace with four entries:

    * ``input_path``: the resolved input path.
    * ``normalized_input``: the normalised string used by downstream readers.
    * ``output_path``: the output directory as a ``Path``.
    * ``normalized_output``: its normalised string representation.

    Optionally clears and recreates the output directory.

    The output directory is chosen as follows:

    * ``target_path`` is None: ``<input parent>/<target_subdir>``.
    * ``target_path`` is relative with one segment: anchored at the input's parent.
    * ``target_path`` is relative with several segments: anchored at the share
      root, or at the input's grandparent when no share root is known.

    If the output directory cannot be created, a fallback
    ``<AGI_LOCAL_SHARE | home_abs | ~>/<target>/<target_subdir>`` is used.
    Both ``output_path`` and ``normalized_output`` then point to it.

    Side effects: sets ``self.home_rel`` and ``self.data_out``.

    Raises:
        ValueError: if ``source_path`` is None.
        OSError: if neither the output directory nor the fallback can be created.
    """
    if source_path is None:
        raise ValueError("setup_data_directories requires a source_path value")
    env = self.env
    input_path = type(self)._resolve_data_dir(env, source_path)
    normalized_input = self.normalize_dataset_path(input_path)
    base_parent = input_path.parent

    def _to_output_str(path: str | Path) -> str:
        # normalize_path() output, with POSIX separators outside Windows.
        text = normalize_path(path)
        if os.name != "nt":
            text = text.replace("\\", "/")
        return text

    def _ensure_output_dir(path: str | Path) -> Path:
        # Create ``path`` with parents; any failure is re-raised as OSError so
        # the fallback handling below catches it.
        path_obj = Path(path).expanduser()
        try:
            logger.info(f"mkdir {path_obj}")
            path_obj.mkdir(parents=True, exist_ok=True)
            return path_obj
        except (OSError, TypeError, ValueError) as exc:
            raise OSError(f"Failed to create output directory {path_obj}: {exc}") from exc

    if target_path is None:
        output_path = base_parent / target_subdir
    else:
        candidate = Path(str(target_path)).expanduser()
        if not candidate.is_absolute():
            share_root = type(self)._share_root_path(env)
            has_nested_segments = len(candidate.parts) > 1
            if has_nested_segments:
                anchor = share_root or base_parent.parent or base_parent
            else:
                anchor = base_parent
            candidate = (Path(anchor) / candidate).expanduser()
        try:
            output_path = candidate.resolve(strict=False)
        except (OSError, RuntimeError):
            output_path = Path(os.path.normpath(str(candidate)))
    normalized_output = _to_output_str(output_path)
    try:
        if reset_target and os.path.exists(normalized_output):
            try:
                # No ``ignore_errors``: it would stop shutil from ever calling
                # the ``_onerror`` handler that makes read-only entries writable.
                shutil.rmtree(normalized_output, onerror=self._onerror)
            except (OSError, RuntimeError) as exc:
                logger.info("Error removing directory: %s", exc)
        output_path = _ensure_output_dir(normalized_output)
        normalized_output = _to_output_str(output_path)
    except OSError:
        unavailable_path = output_path
        if env:
            if env.AGI_LOCAL_SHARE:
                fallback_base = Path(env.AGI_LOCAL_SHARE).expanduser()
            else:
                fallback_base = Path(env.home_abs)
        else:
            fallback_base = Path.home()
        fallback_target = env.target if env else Path(normalized_output).name
        try:
            fallback = _ensure_output_dir(fallback_base / fallback_target / target_subdir)
            normalized_output = _to_output_str(fallback)
            logger.warning(
                "Output path %s unavailable; using fallback %s",
                unavailable_path,
                normalized_output,
            )
        except OSError as exc:
            logger.error("Fallback output directory failed: %s", exc)
            raise
        # Keep the returned Path consistent with normalized_output / data_out.
        output_path = fallback
    # Preserve compatibility with workers that rely on these attributes.
    self.home_rel = input_path
    self.data_out = normalized_output
    return SimpleNamespace(
        input_path=input_path,
        normalized_input=normalized_input,
        output_path=output_path,
        normalized_output=normalized_output,
    )
@staticmethod
def _join(path1, path2):
# path to data base on symlink Path.home()/data(symlink)
"""
Join two file paths.
Args:
path1 (str): The first file path.
path2 (str): The second file path.
Returns:
str: The combined file path.
Raises:
None
"""
path = os.path.join(BaseWorker.expand(path1), path2)
if os.name != "nt":
path = path.replace("\\", "/")
return path
@staticmethod
def _get_logs_and_result(func, *args, verbosity=logging.CRITICAL, **kwargs):
    """Delegate to ``runtime_support.capture_logs_and_result``.

    ``func``, its positional/keyword arguments and ``verbosity`` (default
    ``logging.CRITICAL``) are forwarded unchanged. Presumably returns the
    captured logs together with the call result.
    """
    capture = runtime_support.capture_logs_and_result
    return capture(func, *args, verbosity=verbosity, **kwargs)
@staticmethod
def _exec(cmd, path, worker):
    """Execute a command within a subprocess.

    Args:
        cmd: The command string.
        path: The directory in which to launch the command.
        worker: Forwarded unchanged to ``runtime_support.exec_command``.

    Returns:
        Whatever ``runtime_support.exec_command`` returns.
    """
    return runtime_support.exec_command(
        cmd,
        path,
        worker,
        logger_obj=logger,
        normalize_path_fn=normalize_path,
    )
@staticmethod
def _log_import_error(module, target_class, target_module):
    """Report an import failure through ``runtime_support.log_import_error``.

    This file's path and the current ``sys.path`` are passed along as well.
    """
    runtime_support.log_import_error(
        module,
        target_class,
        target_module,
        sys_path=sys.path,
        file_path=__file__,
        logger_obj=logger,
    )
@staticmethod
def _load_module(module_name, module_class):
    """Thin wrapper over ``runtime_support.load_module``."""
    loaded = runtime_support.load_module(module_name, module_class)
    return loaded
@staticmethod
def _load_manager():
    """Load the manager for ``BaseWorker.env`` via ``runtime_support.load_manager``."""
    return runtime_support.load_manager(
        BaseWorker.env,
        sys_modules=sys.modules,
        load_module_fn=BaseWorker._load_module,
    )
@staticmethod
def _load_worker(mode):
    """Load the worker for ``BaseWorker.env`` and ``mode`` via ``runtime_support.load_worker``."""
    return runtime_support.load_worker(
        BaseWorker.env,
        mode,
        sys_modules=sys.modules,
        load_module_fn=BaseWorker._load_module,
    )
@staticmethod
def _is_cython_installed(env):
    """Ask ``runtime_support.is_cython_installed`` whether Cython is available for ``env``."""
    installed = runtime_support.is_cython_installed(env)
    return installed
@staticmethod
async def _run(env=None, workers=None, mode=0, verbose=None, args=None):
    """Run this worker through ``execution_support.run_worker``.

    Args:
        env: Environment to use; falls back to ``BaseWorker.env`` when falsy.
            An explicitly passed ``env`` is used for this run only; it is not
            stored on ``BaseWorker``.
        workers: Mapping of worker host to worker count. Defaults to a fresh
            ``{"127.0.0.1": 1}`` on every call.
        mode: Execution mode forwarded to ``run_worker``.
        verbose: Accepted for backward compatibility; unused here.
        args: Worker arguments forwarded to ``run_worker``.

    Returns:
        Whatever ``execution_support.run_worker`` returns.
    """
    if not env:
        env = BaseWorker.env
    if workers is None:
        # Built per call: a dict default would be shared between calls.
        workers = {"127.0.0.1": 1}

    def _load_dispatcher():
        from .agi_dispatcher import WorkDispatcher  # Local import to avoid circular dependency
        return WorkDispatcher

    return await execution_support.run_worker(
        env=env,
        workers=workers,
        mode=mode,
        args=args,
        do_works_fn=BaseWorker._do_works,
        dispatcher_loader=_load_dispatcher,
        sys_path=sys.path,
        logger_obj=logger,
        traceback_module=traceback,
        time_module=time,
        humanize_module=humanize,
        datetime_module=datetime,
        path_cls=Path,
    )
@staticmethod
def _onerror(func, path, exc_info):
"""
Error handler for `shutil.rmtree`.
If it’s a permission error, make it writable and retry.
Otherwise re-raise.
"""
exc_type, exc_value, _ = exc_info
# handle permission errors or any non-writable path
if exc_type is PermissionError or not os.access(path, os.W_OK):
try:
os.chmod(path, stat.S_IWUSR | stat.S_IREAD)
func(path)
except OSError as e:
logger.error(f"warning failed to grant write access to {path}: {e}")
else:
# not a permission problem—re-raise so you see real errors
raise exc_value
@staticmethod
def _new(
    env: Optional[AgiEnv] = None,
    app: Optional[str] = None,
    mode: int = 0,
    verbose: int = 0,
    worker_id: int = 0,
    worker: str = "localhost",
    args: Optional[dict] = None,
):
    """Create and initialise a new worker instance.

    Delegates to ``execution_support.initialize_worker``. It is given this
    class, ``AgiEnv`` as the environment factory, this class's worker loader
    and ``start`` hook, and ``ArgsNamespace`` for arguments.

    Args:
        env: Environment to use (default None).
        app: Target app name (default None).
        mode: Execution mode (default 0).
        verbose: Verbosity level (default 0).
        worker_id: Id of this worker (default 0).
        worker: Worker address (default ``"localhost"``).
        args: Worker arguments (default None).

    Returns:
        None.
    """
    execution_support.initialize_worker(
        env=env,
        app=app,
        mode=mode,
        verbose=verbose,
        worker_id=worker_id,
        worker=worker,
        args=args,
        base_worker_cls=BaseWorker,
        agi_env_factory=AgiEnv,
        ensure_managed_pc_share_dir_fn=BaseWorker._ensure_managed_pc_share_dir,
        load_worker_fn=BaseWorker._load_worker,
        start_fn=BaseWorker.start,
        args_namespace_cls=ArgsNamespace,
        logger_obj=logger,
        time_module=time,
        traceback_module=traceback,
        sys_module=sys,
        file_path=__file__,
        path_cls=Path,
    )
@staticmethod
def _get_worker_info(worker_id):
    """Collect runtime information about this worker.

    Args:
        worker_id: Accepted for call compatibility; the information comes
            from class-level state (``_share_path``, ``_worker``), not from
            this value.

    Returns:
        The result of ``execution_support.collect_worker_info``.
    """
    return execution_support.collect_worker_info(
        worker=BaseWorker._worker,
        share_path=BaseWorker._share_path,
        logger_obj=logger,
        normalize_path_fn=normalize_path,
        os_module=os,
        psutil_module=psutil,
        tempfile_module=tempfile,
        time_module=time,
    )
@staticmethod
def _build(target_worker, dask_home, worker, mode=0, verbose=0):
    """
    Build the target code on a target worker.

    Delegates to ``execution_support.build_worker_artifacts``.

    Args:
        target_worker (str): Module to build.
        dask_home (str): Path to the dask home.
        worker: Current worker.
        mode: Execution mode (default 0).
        verbose: Verbosity level (default 0).
    """
    execution_support.build_worker_artifacts(
        target_worker=target_worker,
        dask_home=dask_home,
        worker=worker,
        mode=mode,
        verbose=verbose,
        base_worker_cls=BaseWorker,
        getuser_fn=getpass.getuser,
        logger_obj=logger,
        file_path=__file__,
        sys_path=sys.path,
        os_module=os,
        shutil_module=shutil,
        path_cls=Path,
    )
@staticmethod
def _expand_chunk(payload, worker_id):
"""Unwrap per-worker payload chunk back into legacy list form."""
if not isinstance(payload, dict) or not payload.get("__agi_worker_chunk__"):
return payload, None, None
chunk = payload.get("chunk", [])
total_workers = payload.get("total_workers")
worker_idx = payload.get("worker_idx", worker_id if worker_id is not None else 0)
if isinstance(total_workers, int) and total_workers > 0:
reconstructed_len = max(total_workers, worker_idx + 1)
else:
reconstructed_len = worker_idx + 1
def _placeholder():
if isinstance(chunk, list):
return []
if isinstance(chunk, dict):
return {}
return None
reconstructed = [_placeholder() for _ in range(reconstructed_len)]
reconstructed[worker_idx] = chunk
chunk_len = len(chunk) if hasattr(chunk, "__len__") else (1 if chunk else 0)
return reconstructed, chunk_len, reconstructed_len
@staticmethod
def _do_works(workers_plan, workers_plan_metadata):
    """Run this worker's share of the distribution plan.

    Delegates to ``execution_support.execute_worker_plan`` with this
    worker's id, name and instance registry. Chunked payloads are unwrapped
    with ``_expand_chunk``.

    Args:
        workers_plan: Distribution tree.
        workers_plan_metadata: Metadata accompanying the plan.

    Returns:
        logs: str, the log output from this worker.
    """
    return execution_support.execute_worker_plan(
        workers_plan=workers_plan,
        workers_plan_metadata=workers_plan_metadata,
        insts=BaseWorker._insts,
        worker_name=BaseWorker._worker,
        worker_id=BaseWorker._worker_id,
        expand_chunk_fn=BaseWorker._expand_chunk,
        traceback_module=traceback,
        logger_obj=logger,
        path_cls=Path,
        file_path=__file__,
    )
# Expose ``break_loop`` under the name ``break``. ``break`` is a reserved word,
# so ``BaseWorker.break()`` is a syntax error; call it as
# ``getattr(BaseWorker, "break")()``. Wrapping it in ``staticmethod`` keeps the
# alias argument-free when looked up on an instance as well (a plain function
# would be bound and receive ``self``).
setattr(BaseWorker, "break", staticmethod(BaseWorker.break_loop))
[docs]
class ArgsNamespace(SimpleNamespace):
    """Namespace that supports both attribute and key-style access.

    Key-style access (``ns[key]``, ``ns.get(key)``, ``key in ns``) only sees
    stored entries, never methods or other class attributes such as ``get``
    or ``to_dict``.
    """

    def __getitem__(self, key):
        """Return the stored entry ``key``; raise ``KeyError`` when absent."""
        return self.__dict__[key]

    def get(self, key, default=None):
        """Return the stored entry ``key``, or ``default`` when absent."""
        return self.__dict__.get(key, default)

    def __contains__(self, key):
        """Tell whether ``key`` is a stored entry."""
        return key in self.__dict__

    def to_dict(self):
        """Return a shallow copy of the stored entries as a plain dict."""
        return dict(self.__dict__)