Source code for agi_env.agi_env

from IPython.core.ultratb import FormattedTB
from IPython.core.ultratb import FormattedTB
import ast
import asyncio
import getpass
import os
import re
import shutil
import subprocess
import sys
import asyncssh
from asyncssh.process import ProcessError
from contextlib import asynccontextmanager
import traceback
from pathlib import Path, PureWindowsPath, PurePosixPath
from dotenv import dotenv_values, set_key
import tomlkit
import logging
import inspect
import errno
import astor
from pathspec import PathSpec
from pathspec.patterns import GitWildMatchPattern
from requests import packages

# Compile regex once globally
LOG_LEVEL_RE = re.compile(r'\b(INFO|ERROR|WARNING|DEBUG|CRITICAL)\b')

# Patch for IPython ≥8.37 (theme_name) vs ≤8.36 (color_scheme)
# Inspect FormattedTB.__init__ to pick whichever color keyword this
# IPython version accepts.
_sig = inspect.signature(FormattedTB.__init__).parameters
_tb_kwargs = dict(mode='Verbose', call_pdb=True)
if 'color_scheme' in _sig:
    _tb_kwargs['color_scheme'] = 'Linux'
else:
    _tb_kwargs['theme_name'] = 'Linux'

# Replace the default traceback hook with IPython's verbose, pdb-enabled one.
sys.excepthook = FormattedTB(**_tb_kwargs)

# Module-level logger for this file.
logger = logging.getLogger(__name__)

def normalize_path(path):
    """Return *path* rendered as a string in the current platform's flavor.

    Windows hosts get a ``PureWindowsPath`` rendering, everything else a
    ``PurePosixPath`` rendering.
    """
    pure = PureWindowsPath(Path(path)) if os.name == "nt" else PurePosixPath(Path(path))
    return str(pure)
class AgiEnv:
    """Environment/configuration holder for the AGILab framework.

    Class attributes act as process-wide defaults; instances populate them
    (and many more) in ``__init__``.

    Fix: the original class body re-imported ``inspect``, ``logging`` and
    ``sys`` at class scope, duplicating the module-level imports and leaking
    those modules as class attributes — removed.
    """

    install_type = None    # 0=site-packages, 1=source tree, 2=worker env (per __init__)
    apps_dir = None        # Path to the apps directory (set during __init__)
    app = None             # active app name, e.g. "flight_project"
    module = None          # active module name derived from the app
    GUI_NROW = None        # GUI row-count setting (from .env)
    GUI_SAMPLING = None    # GUI sampling setting (from .env)
    init_done = False
    has_rapids_hw = None   # truthy when RAPIDS-capable hardware is present
    _debug = False
    uv = None              # uv command prefix ("uv" or "uv -q")
    benchmark = None       # Path to benchmark.json
    verbose = None         # verbosity level shared across instances
def init_logging(self, verbosity: int = None):
    """
    Initialize logging with a level based on verbosity:
    0 = WARNING, 1 = INFO, 2 or more = DEBUG
    INFO and DEBUG levels go to stdout; WARNING and above go to stderr.
    """
    # Default uv invocation; switched to quiet mode for verbosity > 1.
    self.uv = "uv"
    if verbosity is None:
        verbosity = 0
    elif verbosity > 1:
        self.uv = "uv -q"  # NOTE(review): quiet uv at HIGH verbosity — confirm intent
    # Root logger level based on verbosity
    root_level = logging.DEBUG if verbosity >= 2 else logging.INFO if verbosity == 1 else logging.WARNING
    # Cap distributed logs at CRITICAL (silent)
    sys_level = logging.ERROR if verbosity < 2 else logging.INFO if verbosity > 3 else logging.DEBUG
    # Use root_level for your app-specific loggers as well
    app_level = root_level
    root = logging.getLogger()
    root.setLevel(root_level)
    # Set distributed logger levels explicitly to suppress debug/info noise
    logging.getLogger("distributed").setLevel(sys_level)
    logging.getLogger("distributed.worker").setLevel(sys_level)
    logging.getLogger("distributed.scheduler").setLevel(sys_level)
    logging.getLogger("distributed.comm").setLevel(sys_level)
    logging.getLogger("distributed.comm.tcp").setLevel(sys_level)
    logging.getLogger("distributed.active_memory_manager").setLevel(sys_level)
    # Set asyncssh and other custom loggers to app_level (verbosity controlled)
    logging.getLogger('asyncssh').setLevel(sys_level)
    # agilab fwk
    logging.getLogger("agi_runner").setLevel(app_level)
    logging.getLogger("agi_worker").setLevel(app_level)
    logging.getLogger("agi_manager").setLevel(app_level)
    logging.getLogger("agi_env").setLevel(app_level)
    logging.getLogger("dag_worker").setLevel(app_level)
    logging.getLogger("pandas_worker").setLevel(app_level)
    logging.getLogger("polars_worker").setLevel(app_level)
    logging.getLogger("agent_worker").setLevel(app_level)
    # Remove existing handlers to avoid duplicate logs
    for handler in root.handlers[:]:
        root.removeHandler(handler)

    class ClassNameFilter(logging.Filter):
        # Injects a %(classname)s field into every record by walking the
        # call stack back to the frame that emitted the log call.
        def filter(self, record):
            # Try to find the class name from the frame where the log call was made
            try:
                # Walk up frames starting from current to find frame matching record
                frame = sys._getframe(0)
                while frame:
                    code = frame.f_code
                    if code.co_filename == record.pathname and code.co_name == record.funcName:
                        # Found the frame of the caller
                        # Check if 'self' is in locals to get class name
                        if 'self' in frame.f_locals:
                            record.classname = frame.f_locals['self'].__class__.__name__
                        else:
                            record.classname = record.module or record.pathname
                        break
                    frame = frame.f_back
                else:
                    # Loop exhausted without a break: no matching frame found.
                    record.classname = '<no-class>'
            except Exception:
                record.classname = '<no-class>'
            return True

    fmt_std = logging.Formatter(
        "%(asctime)s %(levelname)s %(message)s", datefmt="%H:%M:%S"
    )
    fmt_err = logging.Formatter(
        "%(asctime)s %(levelname)s %(classname)s %(funcName)s %(message)s", datefmt="%H:%M:%S"
    )
    # At high verbosity the verbose (class/function) format is used everywhere.
    if verbosity > 1:
        fmt_std = fmt_err
    stdout_handler = logging.StreamHandler(sys.stdout)
    stdout_handler.setLevel(logging.DEBUG)
    stdout_handler.setFormatter(fmt_std)
    stdout_handler.addFilter(ClassNameFilter())
    stderr_handler = logging.StreamHandler(sys.stderr)
    stderr_handler.setLevel(logging.WARNING)
    stderr_handler.setFormatter(fmt_err)
    stderr_handler.addFilter(ClassNameFilter())
    root.addHandler(stdout_handler)
    root.addHandler(stderr_handler)
    root.setLevel(logging.DEBUG if verbosity and verbosity >= 2 else logging.INFO if verbosity == 1 else logging.WARNING)
    logging.debug(f"Logging initialized at level {logging.getLevelName(root.level)}")
def __init__(self, install_type: int = None, apps_dir: Path = None,
             active_app: Path | str = None, verbose: int = None, debug=False):
    """Build the whole AGILab environment: resolve install roots, the active
    app/module, worker-environment paths, and load .env configuration.

    install_type: 0 = site-packages install, 1 = source-tree install,
        2 = worker environment (inferred from __file__/sys.prefix when None).
    apps_dir: override for the apps directory (persisted to .env when given).
    active_app: app name or path; defaults to APP_DEFAULT from .env.
    """
    AgiEnv.verbose = verbose
    self.verbose = verbose
    self.init_logging(verbose)
    AgiEnv._debug = debug
    # Managed PCs (username starting with "T0") get a sandboxed home dir.
    self.is_managed_pc = getpass.getuser().startswith("T0")
    self.agi_resources = Path("resources/.agilab")
    home_abs = Path.home() / "MyApp" if self.is_managed_pc else Path.home()
    self.home_abs = home_abs
    self.resource_path = home_abs / self.agi_resources.name
    env_path = self.resource_path / ".env"
    self.benchmark = self.resource_path / "benchmark.json"
    self.envars = dotenv_values(dotenv_path=env_path, verbose=verbose)
    envars = self.envars
    # Infer install type from where this file lives when not given.
    if install_type is None:
        install_type = 1 if ("site-packages" not in __file__ or sys.prefix.endswith("gui/.venv")) else 0
    elif isinstance(install_type, str):
        install_type = int(install_type)
    self.install_type = install_type
    if install_type != 2:
        self.agi_root = AgiEnv.locate_agi_installation(verbose)
    else:
        self.agi_root = home_abs / "wenv" / active_app
    # Resolve env/core roots per install type.
    if install_type == 1:
        if "site-packages" in self.agi_root.parts:
            self.agi_env_root = self.agi_root.parent / 'agi_env'
            self.agi_core_root = self.agi_root.parent / 'agi_root'
            resource_path = self.agi_env_root / self.agi_resources
        else:
            self.agi_env_root = self.agi_root / "fwk/env"
            self.agi_core_root = self.agi_root / "fwk/core"
            resource_path = self.agi_env_root / "src/agi_env" / self.agi_resources
        if not self.agi_env_root.exists():
            raise RuntimeError("Your Agilab installation is not valid")
        self._init_resources(resource_path)
    elif install_type == 2:
        if AgiEnv._debug:
            self.agi_env_root = self.agi_root / "fwk/env"
            self.agi_core_root = self.agi_root / "fwk/core"
        else:
            self.agi_env_root = list(Path(sys.prefix).rglob('agi_env'))[0]
            self.agi_core_root = list(Path(sys.prefix).rglob('agi_core'))[0]
    elif install_type == 0:
        head, sep, _ = __file__.partition("site-packages")
        if not sep:
            raise ValueError("site-packages not in", __file__)
        self.agi_env_root = Path(head + sep)
        self.agi_core_root = Path(head + sep)
    # apps_dir: default, .env override, or persist an explicit value.
    if not apps_dir:
        apps_dir = 'apps'
        if install_type != 2:
            apps_dir = envars.get("APPS_DIR", apps_dir)
    else:
        set_key(dotenv_path=env_path, key_to_set="APPS_DIR", value_to_set=str(apps_dir))
    apps_dir = Path(apps_dir)
    try:
        if apps_dir.exists():
            self.apps_dir = apps_dir
        elif install_type < 2:
            self.apps_dir = self.agi_root / apps_dir
        else:
            os.makedirs(str(apps_dir), exist_ok=True)
    except FileNotFoundError:
        logging.error("apps_dir not found: %s", apps_dir)
        sys.exit(1)
    self.GUI_NROW = int(envars.get("GUI_NROW", 1000))
    self.GUI_SAMPLING = int(envars.get("GUI_SAMPLING", 20))
    if not active_app:
        active_app = envars.get("APP_DEFAULT", 'flight_project')
    src_apps = None
    if isinstance(active_app, str):
        if not active_app.endswith('_project'):
            active_app = active_app + '_project'
        self.app = active_app
        if not install_type:
            # site-packages install: seed the user apps dir from the bundled apps.
            src_apps = self.agi_root / "apps"
            apps_dir = Path(self.agi_root).parents[4] / "apps"
            if not apps_dir.exists():
                shutil.copytree(src_apps, apps_dir)
            else:
                self.copy_missing(src_apps, apps_dir)
            src_apps = apps_dir
        module = active_app.replace("_project", "").replace("-", "_")
    else:
        apps_dir = self._determine_apps_dir(active_app)
        module = apps_dir.name.replace("_project", "").replace("-", "_")
    self.module = module
    # Worker environment ("wenv") naming and locations.
    wenv_root = Path("wenv")
    target_worker = f"{module}_worker"
    self.target_worker = target_worker
    wenv_rel = wenv_root / target_worker
    target_class = "".join(x.title() for x in module.split("_"))
    self.target_class = target_class
    worker_class = target_class + "Worker"
    self.target_worker_class = worker_class
    self.wenv_rel = wenv_rel
    self.dist_rel = wenv_rel / 'dist'
    wenv_abs = home_abs / wenv_rel
    self.wenv_abs = wenv_abs
    if not self.wenv_abs.exists():
        os.makedirs(self.wenv_abs, exist_ok=True)
    dist_abs = wenv_abs / 'dist'
    dist = normalize_path(dist_abs)
    if not dist in sys.path:
        sys.path.append(dist)
    self.dist_abs = dist_abs
    self.wenv_target_worker = self.wenv_abs
    # Per-install-type app/worker/module file locations.
    if install_type == 0:
        app_abs = src_apps / active_app
        app_src = app_abs / "src"
        self.app_pyproject = app_abs / "pyproject.toml"
        self.worker_path = app_src / target_worker / f"{target_worker}.py"
        self.worker_pyproject = self.worker_path.parent / "pyproject.toml"
        self.module_path = app_src / module / f"{self.module}.py"
        worker_module_path = self.worker_path.parent
    elif install_type == 1:
        app_abs = self.agi_root / apps_dir / active_app
        app_src = app_abs / "src"
        self.app_pyproject = app_abs / "pyproject.toml"
        self.worker_path = app_src / target_worker / f"{target_worker}.py"
        self.worker_pyproject = self.worker_path.parent / "pyproject.toml"
        self.module_path = app_src / module / f"{self.module}.py"
        worker_module_path = self.worker_path.parent
    elif install_type == 2:
        app_abs = self.agi_root
        app_src = self.agi_root / "src"
        self.worker_path = self.wenv_rel / 'src' / target_worker / f"{target_worker}.py"
        self.module_path = self.wenv_rel / 'src' / module / f"{self.module}.py"
        worker_module_path = self.worker_path.parent
    self.app_abs = app_abs
    self.uvproject = app_abs / "uv_config.toml"
    self.post_install = worker_module_path / "post_install.py"
    self.pre_install = worker_module_path / "pre_install.py"
    self.post_install_rel = self.wenv_rel / 'src' / target_worker / "post_install.py"
    src_path = normalize_path(app_src)
    if not src_path in sys.path:
        sys.path.append(src_path)
    AgiEnv.apps_dir = apps_dir
    # Drop any stale distribution tree from a previous run.
    distribution_tree = self.wenv_abs / "distribution_tree.json"
    if distribution_tree.exists():
        distribution_tree.unlink()
    self.distribution_tree = distribution_tree
    # Worker environments stop here; the rest is manager-side setup.
    if install_type == 2:
        return
    self.base_worker_cls, self.base_worker_module = self.get_base_worker_cls(
        self.worker_path, worker_class
    )
    self.workers_packages_prefix = "agi_core.workers."
    if not self.worker_path.exists():
        logging.info(f"Missing {self.target_worker_class} definition; should be in {self.worker_path} but it does not exist")
        sys.exit(1)
    envars = self.envars
    # NOTE(review): "credantials" is a recurring misspelling of "credentials".
    self.credantials = envars.get("CLUSTER_CREDENTIALS", getpass.getuser())
    credantials = self.credantials.split(":")
    self.user = credantials[0]
    self.password = credantials[1] if len(credantials) > 1 else None
    self.python_version = envars.get("AGI_PYTHON_VERSION", "3.13")
    os.makedirs(AgiEnv.apps_dir, exist_ok=True)
    if "site-packages" in self.agi_root.parts:
        self.agi_core_loc = self.agi_root.parent
    else:
        self.agi_core_loc = self.agi_root / "fwk/core/src"
    if install_type != 2:
        self.resolve_packages_path_in_toml()
    agi_core = self.agi_core_loc / "agi_core"
    self.agi_core = agi_core
    self.projects = self.get_projects(self.apps_dir)
    if not self.projects:
        logging.info(f"Could not find any target project app in {self.agi_root / 'apps'}.")
    self.workers_root = agi_core / "workers"
    self.manager_root = agi_core / "managers"
    self.setup_app = app_abs / "build.py"
    self.setup_core_rel = "agi_worker/build.py"
    self.setup_core = self.workers_root / self.setup_core_rel
    if isinstance(module, Path):
        module_path = module.expanduser().resolve()
    else:
        module_path = self._determine_module_path(module)
    self.target = module_path.stem
    self.module_path = module_path
    self.AGILAB_SHARE = Path(envars.get("AGI_SHARE_DIR", "data"))
    data_rel = self.AGILAB_SHARE / self.target
    self.dataframes_path = data_rel / "dataframes"
    self.data_rel = data_rel
    self._init_projects()
    self.scheduler_ip = envars.get("AGI_SCHEDULER_IP", "127.0.0.1")
    if not self.is_valid_ip(self.scheduler_ip):
        raise ValueError(f"Invalid scheduler IP address: {self.scheduler_ip}")
    if self.install_type:
        self.help_path = str(self.agi_root / "../docs/html")
    else:
        self.help_path = "https://thalesgroup.github.io/agilab"
    # NOTE(review): AGILAB_SHARE is re-assigned here with a different default
    # (home_abs/"data" vs "data" above) — confirm which one is intended.
    self.AGILAB_SHARE = Path(envars.get("AGI_SHARE_DIR", home_abs / "data"))
    app_src.mkdir(parents=True, exist_ok=True)
    app_src_str = str(app_src)
    if app_src_str not in sys.path:
        sys.path.append(app_src_str)
    self.app_src = app_src
    self.app_abs = app_abs
    # type 3: only core install
    if AgiEnv.install_type != 3:
        self.init_envars_app(self.envars)
        self._init_apps()
    if os.name == "nt":
        self.export_local_bin = None
    else:
        self.export_local_bin = 'export PATH="$HOME/.local/bin:$PATH";'
    self._ssh_connections = {}
[docs] def active(self, target, install_type): if self.module != target: self.change_active_app(target + '_project', install_type)
[docs] def check_args(self, target_args_class, target_args): try: validated_args = target_args_class.parse_obj(target_args) validation_errors = None except Exception as e: import humanize validation_errors = self.humanize_validation_errors(e) return validation_errors
[docs] def humanize_validation_errors(self, error): formatted_errors = [] for err in error.errors(): field = ".".join(str(loc) for loc in err["loc"]) message = err["msg"] error_type = err.get("type", "unknown_error") input_value = err.get("ctx", {}).get("input_value", None) user_message = f"❌ **{field}**: {message}" if input_value is not None: user_message += f" (Received: `{input_value}`)" user_message += f"*Error Type:* `{error_type}`" formatted_errors.append(user_message) return formatted_errors
def set_env_var(self, key: str, value: str):
    """Record *key* = *value* in memory, in ``os.environ``, and in the .env file."""
    os.environ[key] = str(value)
    self.envars[key] = value
    self._update_env_file({key: value})
@staticmethod
def locate_agi_installation(verbose=False):
    """Locate the AGILab installation directory.

    Resolution order: 1) the ``.agi-path`` marker file (platform-specific
    location), 2) a path derived from an ``agi_env`` entry in
    ``sys.path_importer_cache``, 3) the current working directory.
    """
    if os.name == "nt":
        where_is_agi = Path(os.getenv("LOCALAPPDATA", "")) / "agilab/.agi-path"
    else:
        where_is_agi = Path.home() / ".local/share/agilab/.agi-path"
    if where_is_agi.exists():
        try:
            # utf-8-sig tolerates a BOM written by Windows tools.
            with where_is_agi.open("r", encoding="utf-8-sig") as f:
                install_path = f.read().strip()
                agilab_path = Path(install_path)
                if install_path and agilab_path.exists():
                    return agilab_path
                else:
                    raise ValueError("Installation path file is empty or invalid.")
        except FileNotFoundError:
            logging.error(f"File {where_is_agi} does not exist.")
        except PermissionError:
            logging.error(f"Permission denied when accessing {where_is_agi}.")
        except Exception as e:
            logging.error(f"An error occurred: {e}")
    # Fallback: scan importer cache for an agi_env path.
    for p in sys.path_importer_cache:
        if p.endswith("agi_env"):
            base_dir = os.path.dirname(p).replace('_env', 'lab')
            if verbose:
                logging.info(f"Fallback agilab path found: {base_dir}")
            if AgiEnv.install_type == 0:
                return Path(base_dir)
            else:
                # NOTE(review): '/'-based split is POSIX-specific — confirm on Windows.
                return Path(p.split('/agilab', 1)[0])
    logging.info("Falling back to current working directory")
    return Path(os.getcwd())
def copy_missing(self, src: Path, dst: Path):
    # Recursively copy files from src into dst, never overwriting existing files.
    # NOTE(review): an identical copy_missing is re-defined later in this class;
    # the duplicates should be consolidated.
    dst.mkdir(parents=True, exist_ok=True)
    for item in src.iterdir():
        src_item = item
        dst_item = dst / item.name
        if src_item.is_dir():
            self.copy_missing(src_item, dst_item)
        else:
            if not dst_item.exists():
                shutil.copy2(src_item, dst_item)

def _update_env_file(self, updates: dict):
    # Persist key/value pairs into the user's .env file, creating it if needed.
    env_file = self.resource_path / ".env"
    os.makedirs(env_file.parent, exist_ok=True)
    env_file.touch(exist_ok=True)
    for k, v in updates.items():
        set_key(str(env_file), k, str(v), quote_mode="never")

def _init_resources(self, resources_path):
    # Seed the user resource dir from the installed resources, copying only
    # files that do not already exist (user edits are preserved).
    src_env_path = resources_path / ".env"
    dest_env_file = self.resource_path / ".env"
    if not src_env_path.exists():
        msg = f"Installation issue: {src_env_path} is missing!"
        logging.info(msg)
        raise RuntimeError(msg)
    if not dest_env_file.exists():
        os.makedirs(dest_env_file.parent, exist_ok=True)
        shutil.copy(src_env_path, dest_env_file)
    for root, dirs, files in os.walk(resources_path):
        for file in files:
            src_file = Path(root) / file
            relative_path = src_file.relative_to(resources_path)
            dest_file = self.resource_path / relative_path
            dest_file.parent.mkdir(parents=True, exist_ok=True)
            if not dest_file.exists():
                shutil.copy(src_file, dest_file)

def _init_projects(self):
    # Find the project matching self.target and record its index/paths.
    # project[:-8] strips the "_project" suffix.
    self.projects = self.get_projects(self.apps_dir)
    for idx, project in enumerate(self.projects):
        if self.target == project[:-8].replace("-", "_"):
            self.app_abs = AgiEnv.apps_dir / project
            self.project_index = idx
            self.app = project
            break

def _determine_apps_dir(self, module_path):
    # Derive the apps directory from a path that contains "_project".
    path_str = str(module_path)
    index = path_str.index("_project")
    return Path(path_str[:index]).parent

def _determine_module_path(self, project_or_module_name):
    # Map a project or module name to its module file path:
    # <apps_dir>/<name>_project/src/<module>/<module>.py
    parts = project_or_module_name.rsplit("-", 1)
    suffix = parts[-1]
    name = parts[0].split(os.sep)[-1]
    module_name = name.replace("-", "_")
    if suffix.startswith("project"):
        name = name.replace("-" + suffix, "")
        project_name = name + "_project"
    else:
        project_name = name.replace("_", "-") + "_project"
    module_path = self.apps_dir / project_name / "src" / module_name / (module_name + ".py")
    return module_path.resolve()
[docs] def get_projects(self, path: Path): return [p.name for p in path.glob("*project")]
[docs] def get_modules(self, target=None): pattern = "_project" modules = [ re.sub(f"^{pattern}|{pattern}$", "", project).replace("-", "_") for project in self.get_projects(AgiEnv.apps_dir) ] return modules
[docs] def get_base_worker_cls(self, module_path, class_name): base_info_list = self.get_base_classes(module_path, class_name) try: base_class, module_name = next((base, mod) for base, mod in base_info_list if base.endswith("Worker")) return base_class, module_name except StopIteration: return None, None
[docs] def get_base_classes(self, module_path, class_name): try: with open(module_path, "r", encoding="utf-8") as file: source = file.read() except (IOError, FileNotFoundError) as e: logging.error(f"Error reading module file {module_path}: {e}") return [] try: tree = ast.parse(source) except SyntaxError as e: logging.error(f"Syntax error parsing {module_path}: {e}") raise RuntimeError(f"Syntax error parsing {module_path}: {e}") import_mapping = self.get_import_mapping(source) base_classes = [] for node in ast.walk(tree): if isinstance(node, ast.ClassDef) and node.name == class_name: for base in node.bases: base_info = self.extract_base_info(base, import_mapping) if base_info: base_classes.append(base_info) break return base_classes
[docs] def get_import_mapping(self, source): mapping = {} try: tree = ast.parse(source) except SyntaxError as e: logging.error(f"Syntax error during import mapping: {e}") raise for node in ast.walk(tree): if isinstance(node, ast.Import): for alias in node.names: mapping[alias.asname or alias.name] = alias.name elif isinstance(node, ast.ImportFrom): module = node.module for alias in node.names: mapping[alias.asname or alias.name] = module return mapping
[docs] def extract_base_info(self, base, import_mapping): if isinstance(base, ast.Name): module_name = import_mapping.get(base.id) return base.id, module_name elif isinstance(base, ast.Attribute): full_name = self.get_full_attribute_name(base) parts = full_name.split(".") if len(parts) > 1: alias = parts[0] module_name = import_mapping.get(alias, alias) return parts[-1], module_name return base.attr, None return None
[docs] def get_full_attribute_name(self, node): if isinstance(node, ast.Name): return node.id elif isinstance(node, ast.Attribute): return self.get_full_attribute_name(node.value) + "." + node.attr return ""
[docs] def mode2str(self, mode): chars = ["p", "c", "d", "r"] reversed_chars = reversed(list(enumerate(chars))) if self.has_rapids_hw: mode += 8 mode_str = "".join( "_" if (mode & (1 << i)) == 0 else v for i, v in reversed_chars ) return mode_str
[docs] @staticmethod def mode2int(mode): mode_int = 0 set_rm = set(mode) for i, v in enumerate(["p", "c", "d"]): if v in set_rm: mode_int += 2 ** (len(["p", "c", "d"]) - 1 - i) return mode_int
def is_valid_ip(self, ip: str) -> bool: pattern = re.compile(r"^(?:[0-9]{1,3}\.){3}[0-9]{1,3}$") if pattern.match(ip): parts = ip.split(".") return all(0 <= int(part) <= 255 for part in parts) return False
def init_envars_app(self, envars):
    """Populate per-app directories and API keys from the .env mapping
    *envars*, creating log/export/mlflow directories as needed."""
    self.CLUSTER_CREDENTIALS = envars.get("CLUSTER_CREDENTIALS", None)
    self.OPENAI_API_KEY = envars.get("OPENAI_API_KEY", None)
    AGILAB_LOG_ABS = Path(envars.get("AGI_LOG_DIR", self.home_abs / "log"))
    if not AGILAB_LOG_ABS.exists():
        AGILAB_LOG_ABS.mkdir(parents=True)
    self.AGILAB_LOG_ABS = AGILAB_LOG_ABS
    # Snippets generated by run_agi are written into the log directory.
    self.runenv = self.AGILAB_LOG_ABS
    AGILAB_EXPORT_ABS = Path(envars.get("AGI_EXPORT_DIR", self.home_abs / "export"))
    if not AGILAB_EXPORT_ABS.exists():
        AGILAB_EXPORT_ABS.mkdir(parents=True)
    self.AGILAB_EXPORT_ABS = AGILAB_EXPORT_ABS
    self.export_apps = AGILAB_EXPORT_ABS / "apps"
    if not self.export_apps.exists():
        os.makedirs(str(self.export_apps), exist_ok=True)
    self.MLFLOW_TRACKING_DIR = Path(envars.get("MLFLOW_TRACKING_DIR", self.home_abs / ".mlflow"))
    self.AGILAB_VIEWS_ABS = Path(envars.get("AGI_VIEWS_DIR", self.agi_root / "views"))
    # NOTE(review): same env key AGI_VIEWS_DIR reused with a different
    # default for the relative variant — confirm this is intentional.
    self.AGILAB_VIEWS_REL = Path(envars.get("AGI_VIEWS_DIR", "agi/_"))
    if self.install_type == 0:
        # NOTE(review): resolving a gui file under the core location looks
        # suspicious (original comment read "WTF ?") — verify this path.
        self.copilot_file = self.agi_core_loc / "agi_gui/agi_copilot.py"
    else:
        self.copilot_file = self.agi_root / "fwk/gui/src/agi_gui/agi_copilot.py"
def resolve_packages_path_in_toml(self):
    """Rewrite the worker and app pyproject.toml files so the ``agi-core``
    dependency points at the right place for this install type:
    a plain dependency for site-packages installs, an editable local
    ``[tool.uv.sources]`` path entry for source-tree installs."""
    agi_root = self.agi_root
    for file in [self.worker_pyproject, self.app_pyproject]:
        if not file.exists():
            raise FileNotFoundError(f"{file} not found in {self.app_abs}")
        text = file.read_text(encoding="utf-8")
        doc = tomlkit.parse(text)
        try:
            uv = doc["tool"]["uv"]
        except KeyError:
            raise RuntimeError("Could not find [tool.uv] section in the TOML")
        if "sources" not in uv or not isinstance(uv["sources"], tomlkit.items.Table):
            raise RuntimeError("Could not find [tool.uv.sources] in the TOML")
        sources = uv["sources"]
        if "site-packages" in agi_root.parts:
            # Installed from site-packages: drop any local agi-core source
            # and prune now-empty tables, then depend on agi-core normally.
            if "agi-core" in sources:
                del sources["agi-core"]
                if not sources:
                    del uv["sources"]
                    if not uv:
                        del doc["tool"]["uv"]
                        if not doc["tool"]:
                            del doc["tool"]
            deps = doc["project"].get("dependencies", [])
            if not any(dep.split()[0] == "agi-core" for dep in deps):
                deps.append("agi-core")
                doc["project"]["dependencies"] = deps
        else:
            # Source tree: point agi-core at the local fwk/core checkout.
            agi_core_path = str((agi_root / "fwk" / "core").resolve())
            tbl = tomlkit.inline_table()
            tbl["path"] = agi_core_path
            tbl["editable"] = True
            sources["agi-core"] = tbl
        file.write_text(tomlkit.dumps(doc), encoding="utf-8")
[docs] def copy_missing(self, src: Path, dst: Path): dst.mkdir(parents=True, exist_ok=True) for item in src.iterdir(): src_item = item dst_item = dst / item.name if src_item.is_dir(): self.copy_missing(src_item, dst_item) else: if not dst_item.exists(): shutil.copy2(src_item, dst_item)
def _init_apps(self):
    # Create the per-app settings/snippet files and sync gui resources
    # into the user resource directory.
    app_settings_file = self.app_src / "app_settings.toml"
    app_settings_file.touch(exist_ok=True)
    self.app_settings_file = app_settings_file
    args_ui_snippet = self.app_src / "args_ui_snippet.py"
    args_ui_snippet.touch(exist_ok=True)
    self.args_ui_snippet = args_ui_snippet
    self.gitignore_file = self.app_abs / ".gitignore"
    dest = self.resource_path
    # Source-tree installs ship the gui resources under fwk/gui/src.
    if self.install_type == 1 and not "site-packages" in self.agi_root.parts:
        shutil.copytree(self.agi_root / "fwk/gui/src/agi_gui" / self.agi_resources,
                        dest, dirs_exist_ok=True)
    else:
        shutil.copytree(self.agi_root.parent / "agi_gui" / self.agi_resources,
                        dest, dirs_exist_ok=True)

@staticmethod
def _build_env(venv=None):
    """Build environment dict for subprocesses, with activated virtualenv paths."""
    proc_env = os.environ.copy()
    if venv is not None:
        venv_path = Path(venv) / ".venv"
        proc_env["VIRTUAL_ENV"] = str(venv_path)
        # Windows venvs put executables under Scripts, POSIX under bin.
        bin_path = "Scripts" if os.name == "nt" else "bin"
        venv_bin = venv_path / bin_path
        proc_env["PATH"] = str(venv_bin) + os.pathsep + proc_env.get("PATH", "")
    return proc_env
[docs] @staticmethod def log_info(line): GREEN = "\033[32m" RESET = "\033[0m" if not isinstance(line, str): line = str(line) msg = f"{GREEN}{line}{RESET}" if sys.stdout.isatty() else line logging.info(msg)
[docs] @staticmethod def log_error(line): RED = "\033[31m" RESET = "\033[0m" if not isinstance(line, str): line = str(line) msg = f"{RED}{line}{RESET}" if sys.stdout.isatty() else line logging.info(msg)
@staticmethod
async def run(cmd, venv, cwd=None, timeout=None, wait=True, log_callback=None):
    """
    Run a shell command synchronously inside a virtual environment.
    Log stdout lines as info, stderr lines as error.
    Returns full stdout string.
    """
    # NOTE(review): despite the docstring, stderr lines are also logged at
    # INFO below — confirm whether that is intended.
    if AgiEnv.verbose > 1:
        logging.info(f"Executing in {venv}: {cmd}")
    if not cwd:
        cwd = venv
    process_env = os.environ.copy()
    venv_path = Path(venv)
    # Accept either the project dir or the .venv dir itself.
    if not (venv_path / "bin").exists() and venv_path.name != ".venv":
        venv_path = venv_path / ".venv"
    process_env["VIRTUAL_ENV"] = str(venv_path)
    bin_dir = "Scripts" if sys.platform == "win32" else "bin"
    venv_bin = venv_path / bin_dir
    process_env["PATH"] = str(venv_bin) + os.pathsep + process_env.get("PATH", "")
    shell_executable = None if sys.platform == "win32" else "/bin/bash"
    if wait:
        try:
            process = subprocess.Popen(
                cmd, shell=True, cwd=str(cwd), env=process_env, text=True,
                executable=shell_executable,
                stdout=subprocess.PIPE, stderr=subprocess.PIPE,
                bufsize=1, universal_newlines=True,
            )
            result = ""
            # NOTE(review): alternating blocking readline() on two pipes can
            # stall if one stream is quiet while the other fills — consider
            # threads or selectors; left as-is here.
            while True:
                out_line = process.stdout.readline()
                err_line = process.stderr.readline()
                result += out_line  # only stdout is accumulated for the return value
                if out_line:
                    line = out_line.rstrip("\n")
                    if log_callback:
                        log_callback(line)
                    else:
                        logging.info(line)
                if err_line:
                    line = err_line.rstrip("\n")
                    if log_callback:
                        log_callback(line)
                    else:
                        logging.info(line)
                if out_line == '' and err_line == '' and process.poll() is not None:
                    break
            process.wait(timeout=timeout)
            if AgiEnv.verbose > 1 or AgiEnv._debug:
                logging.info(f"Command completed with exit code {process.returncode}")
            return result
        except subprocess.TimeoutExpired:
            process.kill()
            raise RuntimeError(f"Command timed out after {timeout} seconds: {cmd}")
        except Exception as e:
            logging.error(traceback.format_exc())
            raise RuntimeError(f"Command execution error: {e}") from e
    else:
        # Fire-and-forget: detach with output discarded.
        subprocess.Popen(
            cmd, shell=True, cwd=str(cwd), env=process_env,
            executable=shell_executable,
            stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL,
        )
        return 0
@staticmethod
async def _run_bg(cmd, cwd=".", venv=None, timeout=None, log_callback=None):
    """
    Run *cmd* asynchronously in a shell, streaming output line by line.

    Args:
        cmd: Shell command string.
        cwd: Working directory (made absolute before use).
        venv: Forwarded to ``AgiEnv._build_env`` to shape the child environment.
        timeout: Seconds to wait for completion; the process is killed on expiry.
        log_callback: Optional callable receiving each decoded output line;
            when absent, stdout lines go to ``logging.info`` and stderr lines
            to ``logging.error``.

    Returns:
        tuple[str, str]: decoded (stdout, stderr) from the final communicate().
            NOTE(review): the reader tasks already drain both pipes, so
            communicate() presumably yields empty strings here — confirm.

    Raises:
        RuntimeError: when the timeout expires.
    """
    proc_env = AgiEnv._build_env(venv)
    # Force the child to flush immediately so lines stream live.
    proc_env["PYTHONUNBUFFERED"] = "1"
    proc = await asyncio.create_subprocess_shell(
        cmd,
        cwd=os.path.abspath(cwd),
        env=proc_env,
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.PIPE,
    )

    async def read_stream(stream, callback):
        # Forward each non-empty decoded line to the callback until EOF.
        while True:
            line = await stream.readline()
            if not line:
                break
            decoded_line = line.decode('utf-8', errors='replace').rstrip()
            if decoded_line:
                callback(decoded_line)

    tasks = []
    if proc.stdout:
        tasks.append(asyncio.create_task(
            read_stream(proc.stdout, log_callback if log_callback else logging.info)
        ))
    if proc.stderr:
        tasks.append(asyncio.create_task(
            read_stream(proc.stderr, log_callback if log_callback else logging.error)
        ))

    try:
        await asyncio.wait_for(proc.wait(), timeout=timeout)
    except asyncio.TimeoutError as err:
        proc.kill()
        raise RuntimeError(f"Timeout expired for command: {cmd}") from err

    # Let the readers finish draining before collecting any remainder.
    await asyncio.gather(*tasks)
    stdout, stderr = await proc.communicate()
    return stdout.decode(), stderr.decode()
async def run_agi(self, code, log_callback=None, venv: Path = None, type=None):
    """
    Write *code* to a snippet file and execute it inside *venv* via ``uv run``.

    The snippet file name is derived from the first ``await Agi.<name>(...)``
    (or ``await <name>(...)``) call found in *code*; when none is found a
    message is logged and ``("", "")`` is returned.
    """
    found = re.findall(r"await\s+(?:Agi\.)?([^\(]+)\(", code)
    if not found:
        message = "Could not determine snippet name from code."
        (log_callback or logging.info)(message)
        return "", ""

    snippet_file = os.path.join(self.runenv, f"{found[0]}-{self.target}.py")
    with open(snippet_file, "w") as fh:
        fh.write(code)

    cmd = f"uv -q run --project {str(venv)} python {snippet_file}"
    result = await AgiEnv._run_bg(cmd, cwd=venv, log_callback=log_callback)

    if log_callback:
        log_callback(f"Process finished with output: {result}")
    else:
        logging.info("Process finished")
    return result
@staticmethod
async def run_async(cmd, venv=None, cwd=None, timeout=None, log_callback=None):
    """
    Run a shell command asynchronously inside a virtual environment.

    stdout lines are forwarded to ``logging.info`` and stderr lines to
    ``logging.error`` (or all lines to *log_callback* when given).

    Args:
        cmd: Shell command, as a string or a list of tokens (joined with spaces).
        venv: Project directory; ``<venv>/.venv`` is assumed unconditionally.
            NOTE(review): unlike ``run``, this requires venv to be non-None
            (``Path(None)`` would raise) — confirm callers always pass it.
        cwd: Working directory (defaults to *venv*).
        timeout: Seconds to wait before killing the process.
        log_callback: Optional callable receiving each decoded line.

    Returns:
        The last non-empty line of stderr if any, else of stdout, else None.

    Raises:
        RuntimeError: when the timeout expires.
    """
    if not cwd:
        cwd = venv
    process_env = os.environ.copy()
    venv_path = Path(venv) / ".venv"
    process_env["VIRTUAL_ENV"] = str(venv_path)
    bin_dir = "Scripts" if os.name == "nt" else "bin"
    venv_bin = venv_path / bin_dir
    process_env["PATH"] = str(venv_bin) + os.pathsep + process_env.get("PATH", "")
    shell_executable = "/bin/bash" if os.name != "nt" else None

    if isinstance(cmd, list):
        cmd = " ".join(cmd)

    process = await asyncio.create_subprocess_shell(
        cmd,
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.PIPE,
        cwd=str(cwd),
        env=process_env,
        executable=shell_executable
    )

    stdout_lines = []
    stderr_lines = []

    async def read_stream(stream, lines, callback):
        # Collect every decoded line and optionally forward it.
        while True:
            line = await stream.readline()
            if not line:
                break
            decoded_line = line.decode().rstrip()
            lines.append(decoded_line)
            if callback:
                callback(decoded_line)

    stdout_task = asyncio.create_task(
        read_stream(process.stdout, stdout_lines, log_callback if log_callback else logging.info)
    )
    stderr_task = asyncio.create_task(
        read_stream(process.stderr, stderr_lines, log_callback if log_callback else logging.error)
    )

    try:
        await asyncio.wait_for(process.wait(), timeout=timeout)
    except asyncio.TimeoutError as err:
        process.kill()
        raise RuntimeError(f"Timeout expired for command: {cmd}") from err

    await asyncio.gather(stdout_task, stderr_task)

    # Find last non-empty line from stderr first (usually errors), else stdout
    last_line = None
    for line in reversed(stderr_lines):
        if line.strip():
            last_line = line
            break
    if not last_line:
        for line in reversed(stdout_lines):
            if line.strip():
                last_line = line
                break

    return last_line
def change_active_app(self, app, install_type=1):
    """Switch the active app by re-initializing the environment.

    Accepts the app as a name string or a Path (its basename is used).
    A no-op when the requested app is already active.

    Raises:
        TypeError: when *app* is neither str nor Path.
    """
    if isinstance(app, Path):
        name = app.name
    elif isinstance(app, str):
        name = app
    else:
        raise TypeError(f"Invalid app type (<str>|<Path>): {type(app)}")
    if name == self.app:
        return
    # Re-run __init__ to rebuild all derived state for the new app.
    self.__init__(active_app=name, install_type=install_type, verbose=AgiEnv.verbose)
@asynccontextmanager
async def get_ssh_connection(self, ip: str, timeout_sec: int = 5):
    """
    Yield an SSH connection to *ip*, reusing a cached open connection when
    available and caching any newly opened one in ``self._ssh_connections``.

    Client keys are gathered from ~/.ssh, excluding authorized_keys,
    known_hosts and ``id_*.pub`` public-key files.

    Raises:
        ValueError: when ``self.user`` is not configured.
        ConnectionError: when the host is unreachable on port 22.
        asyncio.TimeoutError, asyncssh.PermissionDenied, asyncssh.Error,
        OSError: propagated after logging.
    """
    if not self.user:
        raise ValueError("SSH username is not configured. Please set 'user' in your .env file.")

    # Fast path: reuse a still-open cached connection.
    conn = self._ssh_connections.get(ip)
    if conn and not conn.is_closed():
        yield conn
        return

    try:
        # Collect candidate private keys from ~/.ssh.
        ssh_dir = Path("~/.ssh").expanduser()
        keys = []
        for file in ssh_dir.iterdir():
            if not file.is_file():
                continue
            name = file.name
            if name.startswith('authorized_keys'):
                continue
            if name.startswith('known_hosts'):
                continue
            if name.startswith('id_') and name.endswith('.pub'):
                continue
            keys.append(str(file))
        client_keys = keys if keys else None

        # known_hosts=None disables host-key verification (trusted LAN usage).
        conn = await asyncio.wait_for(
            asyncssh.connect(
                ip,
                username=self.user,
                password=self.password,
                known_hosts=None,
                client_keys=client_keys,
            ),
            timeout=timeout_sec
        )
        self._ssh_connections[ip] = conn
        yield conn
    except asyncio.TimeoutError:
        err_msg = f"Connection to {ip} timed out after {timeout_sec} seconds."
        logging.error(err_msg)
        raise
    except asyncssh.PermissionDenied:
        err_msg = f"Authentication failed for SSH user '{self.user}' on host {ip}."
        logging.error(err_msg)
        raise
    except OSError as e:
        if e.errno == errno.EHOSTUNREACH:
            err_msg = (
                f"Unable to connect to {ip} on SSH port 22. "
                "Please check that the device is powered on, network cable connected, and SSH service running."
            )
            raise ConnectionError(err_msg)
        elif e.errno in (errno.EACCES, errno.ECONNREFUSED):
            logging.error(str(e))
        else:
            # NOTE(review): this branch duplicates the elif above — both just
            # log and re-raise; confirm whether distinct handling was intended.
            logging.error(str(e))
        raise
    except asyncssh.Error as e:
        logging.error(e.command if hasattr(e, 'command') else "No command attribute")
        logging.error(e)
        raise
    except Exception as e:
        logging.error(f"Unexpected error while connecting to {ip}: {e}")
        raise
async def exec_ssh(self, ip: str, cmd: str) -> str:
    """
    Execute *cmd* on *ip* over SSH and return its stripped stdout.

    Args:
        ip: Remote host address.
        cmd: Command line to run remotely (``check=True``: non-zero exit raises).

    Returns:
        str: stdout of the remote command, decoded and stripped.

    Raises:
        ConnectionError: propagated untouched from connection setup.
        asyncssh.process.ProcessError: when the remote command exits non-zero
            (its stderr is logged first).
        asyncssh.Error, OSError: logged and re-raised.
    """
    try:
        async with self.get_ssh_connection(ip) as conn:
            msg = f"[{ip}] {cmd}"
            if AgiEnv.verbose > 1 or AgiEnv._debug:
                logging.info(msg)
            result = await conn.run(cmd, check=True)
            stdout = result.stdout
            # asyncssh may hand back bytes depending on session encoding.
            if isinstance(stdout, bytes):
                stdout = stdout.decode('utf-8', errors='replace')
            if AgiEnv.verbose > 1 or AgiEnv._debug:
                logging.info(f"[{ip}] {stdout.strip()}")
            return stdout.strip()
    except ConnectionError:
        raise
    except ProcessError as e:
        # Surface the remote stderr before propagating the failure.
        stdout = getattr(e, 'stdout', '')
        stderr = getattr(e, 'stderr', '')
        if isinstance(stdout, bytes):
            stdout = stdout.decode('utf-8', errors='replace')
        if isinstance(stderr, bytes):
            stderr = stderr.decode('utf-8', errors='replace')
        logging.error(f"Remote command stderr: {stderr.strip()}")
        raise
    except (asyncssh.Error, OSError) as e:
        logging.error(e)
        raise
async def exec_ssh_async(self, ip: str, cmd: str) -> str:
    """
    Execute *cmd* on *ip* via SSH and return the last non-empty line of its
    stdout (or "" when the command produced no output).
    """
    async with self.get_ssh_connection(ip) as conn:
        proc = await conn.create_process(cmd)
        out = await proc.stdout.read()
        await proc.wait()
        # asyncssh streams yield text by default; keep only meaningful lines.
        meaningful = [ln.strip() for ln in out.splitlines() if ln.strip()]
        return meaningful[-1] if meaningful else ""
async def send_file(
    self,
    ip: str,
    local_path: Path,
    remote_path: Path,
    user: str = None,
    password: str = None
):
    """
    Copy *local_path* to ``user@ip:remote_path`` with ``scp``.

    When a password is available (and not on Windows), ``sshpass -p`` wraps
    the scp call; on failure to spawn, one retry is made with plain scp.

    Args:
        ip: Destination host.
        local_path: File to send.
        remote_path: Destination path on the remote host.
        user: SSH user (defaults to ``self.user``).
        password: SSH password (defaults to ``self.password``).

    Raises:
        ConnectionError: when scp exits non-zero.
    """
    if not user:
        user = self.user
    if not password:
        password = self.password

    user_at_ip = f"{user}@{ip}" if user else ip
    remote = f"{user_at_ip}:{remote_path}"

    scp_cmd = ["scp", str(local_path), remote]
    # Bug fix: the original retry path referenced `cmd_base`, which was unbound
    # whenever no password was set or on Windows (NameError), and its retry
    # command ["sshpass", "scp", ...] omitted the mandatory -p argument.
    if password and os.name != "nt":
        cmd = ["sshpass", "-p", password] + scp_cmd
    else:
        cmd = scp_cmd

    async def _scp(argv):
        # Run one scp attempt; raise ConnectionError on a non-zero exit.
        process = await asyncio.create_subprocess_exec(
            *argv,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
        )
        stdout, stderr = await process.communicate()
        if process.returncode != 0:
            err = stderr.decode().strip()
            logging.error(f"SCP failed sending {local_path} to {remote}: {err}")
            raise ConnectionError(f"SCP error: {err}")
        logging.info(f"Sent file {local_path} to {remote}")

    try:
        await _scp(cmd)
    except ConnectionError:
        # scp itself ran and reported failure: nothing to retry.
        raise
    except Exception:
        # Spawning failed (e.g. sshpass missing): retry once with plain scp.
        await _scp(scp_cmd)
async def send_files(self, ip: str, files: list[Path], remote_dir: Path, user: str = None):
    """Concurrently copy every file in *files* into *remote_dir* on *ip*."""
    await asyncio.gather(*(
        self.send_file(ip, item, f"{remote_dir / item.name}", user=user)
        for item in files
    ))
# logging.info(f"Sent {len(files)} files to {user if user else self.user}@{ip}:{remote_dir}")
def remove_dir_forcefully(self, path):
    """Delete *path* recursively, clearing read-only bits and retrying once.

    Read-only entries (common on Windows) are chmod'ed writable and retried
    via the rmtree error hook; a wholesale failure is retried once after a
    one-second pause, then re-raised.
    """
    import os
    import shutil
    import stat
    import time

    def _on_rm_error(func, victim, exc_info):
        # rmtree could not remove `victim`: make it writable and try again.
        if not os.access(victim, os.W_OK):
            os.chmod(victim, stat.S_IWUSR)
            func(victim)
        else:
            logging.info(f"{victim} not removed due to {exc_info[1]}")

    try:
        shutil.rmtree(path, onerror=_on_rm_error)
    except Exception as first_err:
        logging.error(f"Exception while deleting {path}: {first_err}")
        time.sleep(1)
        try:
            shutil.rmtree(path, onerror=_on_rm_error)
        except Exception as second_err:
            logging.error(f"Second failure deleting {path}: {second_err}")
            raise
async def close_all_connections(self):
    """Cleanly close every cached SSH connection and empty the cache.

    Call this at program shutdown (or before a restart) so no connection
    is left dangling.
    """
    for connection in self._ssh_connections.values():
        connection.close()
        await connection.wait_closed()
    self._ssh_connections.clear()
@staticmethod
def log_remote_line(ip: str, line: str):
    """
    Log a line from remote SSH output at the level embedded in the line.

    Args:
        ip (str): IP address of remote host.
        line (str): One line of output from the remote process; the first
            INFO/ERROR/WARNING/DEBUG/CRITICAL token selects the log level
            (INFO when none is found).
    """
    match = LOG_LEVEL_RE.search(line)
    if match:
        level = getattr(logging, match.group(1), logging.INFO)
    else:
        level = logging.INFO
    # Bug fix: the original called logging.info(level, msg), which logged the
    # numeric level as the message; logging.log(level, msg) emits at `level`.
    logging.log(level, f"[{ip}] {line}")
def set_cluster_credentials(self, credentials: str):
    """Store the cluster credentials in memory and persist them to the env file."""
    # Keep the in-memory copy in sync with the persisted value.
    self.CLUSTER_CREDENTIALS = credentials
    self.set_env_var("CLUSTER_CREDENTIALS", credentials)
def set_openai_api_key(self, api_key: str):
    """Store the OpenAI API key in memory and persist it to the env file."""
    self.OPENAI_API_KEY = api_key
    self.set_env_var("OPENAI_API_KEY", api_key)
def set_install_type(self, install_type: int):
    """Record the install type and persist it (stringified) to the env file."""
    self.install_type = install_type
    self.set_env_var("INSTALL_TYPE", str(install_type))
def set_apps_dir(self, apps_dir: Path):
    """Record the applications directory and persist it to the env file."""
    self.apps_dir = apps_dir
    # Consistency fix: the sibling set_* helpers persist strings; passing a
    # raw Path relied on implicit str() conversion downstream.
    self.set_env_var("APPS_DIR", str(apps_dir))
[docs] def has_admin_rights(): """ Check if the current process has administrative rights on Windows. Returns: bool: True if admin, False otherwise. """ try: return ctypes.windll.shell32.IsUserAnAdmin() except: return False
def create_junction_windows(source: Path, dest: Path):
    """
    Create a Windows directory junction from *dest* to *source*.

    Uses ``mklink /J``, which — unlike symlinks — needs no admin rights.
    """
    try:
        subprocess.check_call(['cmd', '/c', 'mklink', '/J', str(dest), str(source)])
        print(f"Created junction: {dest} -> {source}")
    except subprocess.CalledProcessError as err:
        print(f"Failed to create junction. Error: {err}")
def handle_venv_directory(self, source_venv: Path, dest_venv: Path):
    """
    Link the destination ``.venv`` to the source one instead of copying it.

    Args:
        source_venv (Path): Existing .venv directory.
        dest_venv (Path): Where the link should be created.
    """
    try:
        if os.name != "nt":
            os.symlink(source_venv, dest_venv, target_is_directory=True)
        else:
            create_symlink_windows(source_venv, dest_venv)
        print(f"Created symbolic link for .venv: {dest_venv} -> {source_venv}")
    except OSError as e:
        print(f"Failed to create symbolic link for .venv: {e}")
def create_rename_map(self, target_project: Path, dest_project: Path) -> dict:
    """
    Build the old→new name map used when cloning a project.

    Covers the project folder, ``src/`` trees, worker packages, bare module
    names and the capitalized class prefixes (``<X>Worker``, ``<X>Args``).
    Longer keys come first so folder-level renames win over bare names.
    """
    def camel(snake: str) -> str:
        return "".join(part.capitalize() for part in snake.split("_"))

    src_proj = target_project.name                # e.g. "flight_project"
    dst_proj = dest_project.name                  # e.g. "tata_project"
    src_mod = src_proj[:-8].replace("-", "_")     # strip "_project" suffix
    dst_mod = dst_proj[:-8].replace("-", "_")
    src_cls = camel(src_mod)                      # "Flight"
    dst_cls = camel(dst_mod)                      # "Tata"

    return {
        # project-level
        src_proj: dst_proj,
        # folder-level (longest keys first)
        f"src/{src_mod}_worker": f"src/{dst_mod}_worker",
        f"src/{src_mod}": f"src/{dst_mod}",
        # sibling-level
        f"{src_mod}_worker": f"{dst_mod}_worker",
        src_mod: dst_mod,
        # class-level
        f"{src_cls}Worker": f"{dst_cls}Worker",
        f"{src_cls}Args": f"{dst_cls}Args",
        src_cls: dst_cls,
    }
def clone_project(self, target_project: Path, dest_project: Path):
    """
    Clone a project by copying files and directories, applying renaming,
    then cleaning up any leftovers.

    Both arguments are normalized to end in ``_project``.  The clone is
    aborted (with a printed message, no exception) when the source is
    missing, the destination already exists, or the source has no
    ``.gitignore`` (which is required to know what to skip).

    Args:
        target_project: Path under self.apps_dir (e.g. Path("flight_project"))
        dest_project: Path under self.apps_dir (e.g. Path("tata_project"))
    """
    # normalize names
    if not target_project.name.endswith("_project"):
        target_project = target_project.with_name(target_project.name + "_project")
    if not dest_project.name.endswith("_project"):
        dest_project = dest_project.with_name(dest_project.name + "_project")

    rename_map = self.create_rename_map(target_project, dest_project)
    source_root = self.apps_dir / target_project
    dest_root = self.apps_dir / dest_project

    if not source_root.exists():
        print(f"Source project '{target_project}' does not exist.")
        return
    if dest_root.exists():
        print(f"Destination project '{dest_project}' already exists.")
        return

    gitignore = source_root / ".gitignore"
    if not gitignore.exists():
        print(f"No .gitignore at '{gitignore}'.")
        return
    # Files ignored by git are excluded from the clone.
    spec = PathSpec.from_lines(GitWildMatchPattern, gitignore.read_text().splitlines())

    try:
        dest_root.mkdir(parents=True, exist_ok=False)
    except Exception as e:
        print(f"Could not create '{dest_root}': {e}")
        return

    # 1) Recursive clone
    self.clone_directory(source_root, dest_root, rename_map, spec, source_root)

    # 2) Final cleanup
    self._cleanup_rename(dest_root, rename_map)

    # Register the new project first in the known-projects list.
    # NOTE(review): self.projects is initialized elsewhere — confirm it exists
    # before clone_project is reachable.
    self.projects.insert(0, dest_project)
def clone_directory(self, source_dir: Path, dest_dir: Path, rename_map: dict,
                    spec: PathSpec, source_root: Path):
    """
    Recursively copy + rename directories, files, and contents, applying
    renaming only on exact path segments.

    ``.gitignore``-matched entries are skipped; ``.venv`` directories are
    re-created as symlinks; ``.py`` files are renamed via an AST pass plus a
    whole-word textual sweep; other text formats get the textual sweep only;
    archives and binaries are copied verbatim; symlinks are re-created
    pointing at their (renamed) original target.
    """
    for item in source_dir.iterdir():
        rel = item.relative_to(source_root).as_posix()

        # Skip files/directories matched by .gitignore spec
        # (directories need a trailing "/" for gitignore semantics).
        if spec.match_file(rel + ("/" if item.is_dir() else "")):
            continue

        # Rename only full segments of the relative path
        parts = rel.split("/")
        for i, seg in enumerate(parts):
            # Sort rename_map by key length descending to avoid partial conflicts
            for old, new in sorted(rename_map.items(), key=lambda kv: -len(kv[0])):
                if seg == old:
                    parts[i] = new
                    break
        new_rel = "/".join(parts)
        dst = dest_dir / new_rel
        dst.parent.mkdir(parents=True, exist_ok=True)

        if item.is_dir():
            if item.name == ".venv":
                # Keep virtual env directory as a symlink
                os.symlink(item, dst, target_is_directory=True)
            else:
                self.clone_directory(item, dest_dir, rename_map, spec, source_root)
        elif item.is_file():
            suf = item.suffix.lower()
            base = item.stem

            # Rename file if its basename is in rename_map
            if base in rename_map:
                dst = dst.with_name(rename_map[base] + item.suffix)

            if suf in (".7z", ".zip"):
                # Archives: copy bytes untouched.
                shutil.copy2(item, dst)
            elif suf == ".py":
                src = item.read_text(encoding="utf-8")
                try:
                    # AST-based rename keeps identifier semantics intact.
                    tree = ast.parse(src)
                    renamer = ContentRenamer(rename_map)
                    new_tree = renamer.visit(tree)
                    ast.fix_missing_locations(new_tree)
                    out = astor.to_source(new_tree)
                except SyntaxError:
                    # Unparseable file: fall back to the plain-text sweep below.
                    out = src
                # Whole word replacements in Python source text
                for old, new in rename_map.items():
                    out = re.sub(rf"\b{re.escape(old)}\b", new, out)
                dst.write_text(out, encoding="utf-8")
            elif suf in (".toml", ".md", ".txt", ".json", ".yaml", ".yml"):
                txt = item.read_text(encoding="utf-8")
                for old, new in rename_map.items():
                    txt = re.sub(rf"\b{re.escape(old)}\b", new, txt)
                dst.write_text(txt, encoding="utf-8")
            else:
                shutil.copy2(item, dst)
        elif item.is_symlink():
            # Recreate the link rather than following it.
            target = os.readlink(item)
            os.symlink(target, dst, target_is_directory=item.is_dir())
def _cleanup_rename(self, root: Path, rename_map: dict): """ 1) Rename any leftover file/dir basenames (including .py) that exactly match a key. 2) Rewrite text files for any straggler content references. """ # build simple name→new map (no slashes) simple_map = {old: new for old, new in rename_map.items() if "/" not in old} # sort longest first sorted_simple = sorted(simple_map.items(), key=lambda kv: len(kv[0]), reverse=True) # -- step 1: rename basenames (dirs & files) bottom‑up -- for path in sorted(root.rglob("*"), key=lambda p: len(p.parts), reverse=True): old = path.name for o, n in sorted_simple: # directory exactly "flight" → "truc", or "flight_worker" → "truc_worker" if old == o or old == f"{o}_worker" or old == f"{o}_project": new_name = old.replace(o, n, 1) path.rename(path.with_name(new_name)) break # file like "flight.py" → "truc.py" if path.is_file() and old.startswith(o + "."): new_name = n + old[len(o):] path.rename(path.with_name(new_name)) break # -- step 2: rewrite any lingering text references -- exts = {".py", ".toml", ".md", ".txt", ".json", ".yaml", ".yml"} for file in root.rglob("*"): if not file.is_file() or file.suffix.lower() not in exts: continue txt = file.read_text(encoding="utf-8") new_txt = txt for old, new in rename_map.items(): new_txt = re.sub(rf"\b{re.escape(old)}\b", new, new_txt) if new_txt != txt: file.write_text(new_txt, encoding="utf-8")
def replace_content(self, txt: str, rename_map: dict) -> str:
    """Apply whole-word renames from *rename_map* to *txt*, longest keys first."""
    ordered = sorted(rename_map.items(), key=lambda item: len(item[0]), reverse=True)
    for old, new in ordered:
        # \b anchors keep substrings of longer identifiers untouched.
        txt = re.compile(rf"\b{re.escape(old)}\b").sub(new, txt)
    return txt
def read_gitignore(self, gitignore_path: Path) -> 'PathSpec':
    """Parse a ``.gitignore`` file into a PathSpec matcher."""
    from pathspec import PathSpec
    from pathspec.patterns import GitWildMatchPattern
    patterns = gitignore_path.read_text(encoding="utf-8").splitlines()
    return PathSpec.from_lines(GitWildMatchPattern, patterns)
def is_valid_ip(self, ip: str) -> bool:
    """Return True for a dotted-quad IPv4 string whose octets are all 0-255.

    Note: leading zeros are accepted (e.g. "01.2.3.4"), matching the
    original regex-based behavior.
    """
    if not re.match(r"^(?:[0-9]{1,3}\.){3}[0-9]{1,3}$", ip):
        return False
    return all(0 <= int(octet) <= 255 for octet in ip.split("."))
@property
def scheduler_ip_address(self):
    """Read-only alias for ``self.scheduler_ip``.

    NOTE(review): presumably the cluster scheduler host configured during
    environment initialization — confirm where ``scheduler_ip`` is assigned.
    """
    return self.scheduler_ip
def log_remote_line(self, ip, line):
    # NOTE(review): this instance method shadows the @staticmethod of the same
    # name defined earlier in the class (the later definition wins in the class
    # namespace), and it bypasses the log-level parsing done there — confirm
    # which of the two implementations is the intended one.
    print(f"[{ip}] {line}")  # Replace with your real remote line logger
class ContentRenamer(ast.NodeTransformer):
    """
    AST transformer that renames identifiers according to ``rename_map``.

    Handles plain names, attributes, function/class definitions, arguments,
    ``global``/``nonlocal`` statements, and both forms of import statements
    (with prefix renames for compound module names).

    Attributes:
        rename_map (dict): Mapping of old identifiers to new identifiers.
    """

    def __init__(self, rename_map):
        """
        Args:
            rename_map (dict): Mapping of old names to new names.
        """
        self.rename_map = rename_map

    def visit_Name(self, node):
        """Rename variable and function references."""
        if node.id in self.rename_map:
            print(f"Renaming Name: {node.id}{self.rename_map[node.id]}")
            node.id = self.rename_map[node.id]
        self.generic_visit(node)  # Ensure child nodes are visited
        return node

    def visit_Attribute(self, node):
        """Rename attribute accesses."""
        if node.attr in self.rename_map:
            print(f"Renaming Attribute: {node.attr}{self.rename_map[node.attr]}")
            node.attr = self.rename_map[node.attr]
        self.generic_visit(node)
        return node

    def visit_FunctionDef(self, node):
        """Rename function definitions."""
        if node.name in self.rename_map:
            print(f"Renaming Function: {node.name}{self.rename_map[node.name]}")
            node.name = self.rename_map[node.name]
        self.generic_visit(node)
        return node

    def visit_ClassDef(self, node):
        """Rename class definitions."""
        if node.name in self.rename_map:
            print(f"Renaming Class: {node.name}{self.rename_map[node.name]}")
            node.name = self.rename_map[node.name]
        self.generic_visit(node)
        return node

    def visit_arg(self, node):
        """Rename function argument names."""
        if node.arg in self.rename_map:
            print(f"Renaming Argument: {node.arg}{self.rename_map[node.arg]}")
            node.arg = self.rename_map[node.arg]
        self.generic_visit(node)
        return node

    def visit_Global(self, node):
        """Rename names listed in ``global`` statements."""
        new_names = []
        for name in node.names:
            if name in self.rename_map:
                print(f"Renaming Global Variable: {name}{self.rename_map[name]}")
                new_names.append(self.rename_map[name])
            else:
                new_names.append(name)
        node.names = new_names
        self.generic_visit(node)
        return node

    def visit_Nonlocal(self, node):
        """Rename names listed in ``nonlocal`` statements.

        Bug fix: this was previously named ``visit_nonlocal`` (lowercase),
        which ast.NodeTransformer never dispatches to — nonlocal statements
        were silently left unrenamed.
        """
        new_names = []
        for name in node.names:
            if name in self.rename_map:
                print(
                    f"Renaming Nonlocal Variable: {name}{self.rename_map[name]}"
                )
                new_names.append(self.rename_map[name])
            else:
                new_names.append(name)
        node.names = new_names
        self.generic_visit(node)
        return node

    def visit_Assign(self, node):
        """Assignments: targets are renamed by visit_Name; just recurse."""
        self.generic_visit(node)
        return node

    def visit_AnnAssign(self, node):
        """Annotated assignments: just recurse."""
        self.generic_visit(node)
        return node

    def visit_For(self, node):
        """Rename a simple ``for`` loop target variable."""
        if isinstance(node.target, ast.Name) and node.target.id in self.rename_map:
            print(
                f"Renaming For Loop Variable: {node.target.id}{self.rename_map[node.target.id]}"
            )
            node.target.id = self.rename_map[node.target.id]
        self.generic_visit(node)
        return node

    def visit_Import(self, node):
        """Rename modules in ``import module`` statements (exact or prefix match)."""
        for alias in node.names:
            original_name = alias.name
            if original_name in self.rename_map:
                print(
                    f"Renaming Import Module: {original_name}{self.rename_map[original_name]}"
                )
                alias.name = self.rename_map[original_name]
            else:
                # Handle compound module names if necessary
                for old, new in self.rename_map.items():
                    if original_name.startswith(old):
                        print(
                            f"Renaming Import Module: {original_name}{original_name.replace(old, new, 1)}"
                        )
                        alias.name = original_name.replace(old, new, 1)
                        break
        self.generic_visit(node)
        return node

    def visit_ImportFrom(self, node):
        """Rename module and imported names in ``from module import name``."""
        # Rename the module being imported from
        if node.module in self.rename_map:
            print(
                f"Renaming ImportFrom Module: {node.module}{self.rename_map[node.module]}"
            )
            node.module = self.rename_map[node.module]
        else:
            for old, new in self.rename_map.items():
                if node.module and node.module.startswith(old):
                    new_module = node.module.replace(old, new, 1)
                    print(
                        f"Renaming ImportFrom Module: {node.module}{new_module}"
                    )
                    node.module = new_module
                    break
        # Rename the imported names
        for alias in node.names:
            if alias.name in self.rename_map:
                print(
                    f"Renaming Imported Name: {alias.name}{self.rename_map[alias.name]}"
                )
                alias.name = self.rename_map[alias.name]
            else:
                for old, new in self.rename_map.items():
                    if alias.name.startswith(old):
                        print(
                            f"Renaming Imported Name: {alias.name}{alias.name.replace(old, new, 1)}"
                        )
                        alias.name = alias.name.replace(old, new, 1)
                        break
        self.generic_visit(node)
        return node


# Module-level identity used by process-management code further down the file.
import getpass, os, sys, subprocess, signal

me = getpass.getuser()
my_pid = os.getpid()