Source code for finetuning_scheduler.dynamic_versioning.utils
#!/usr/bin/env python
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import re
from pathlib import Path
from typing import ValuesView
# -----------------------------------------------------------------------------
# Lightning Configuration
# -----------------------------------------------------------------------------
#
# These version constraints are the single source of truth for minimum versions.
# They are used by setup.py to generate dynamic dependencies at build time.
#
# For visibility, these values are also documented in pyproject.toml under
# [tool.fts.min-versions] (informational only - not used during installation).
# Shared version constraint for all Lightning packages
LIGHTNING_VERSION = ">=2.6.0,<2.6.1"

# Module-path translation table: unified ``lightning.*`` subpackage (key) ->
# standalone top-level package (value). ``use_standalone_pl`` rewrites keys to
# values; ``use_unified_pl`` rewrites values back to keys.
LIGHTNING_PACKAGE_MAPPING = {
    "lightning.pytorch": "pytorch_lightning",
    "lightning.fabric": "lightning_fabric",
}

# Distribution metadata keyed by package type ("unified" / "standalone").
# Both distributions share the single LIGHTNING_VERSION constraint.
LIGHTNING_PACKAGES = {
    "unified": dict(
        package="lightning",
        repo="Lightning-AI/lightning",
        version=LIGHTNING_VERSION,
    ),
    "standalone": dict(
        package="pytorch-lightning",
        repo="Lightning-AI/pytorch-lightning",
        version=LIGHTNING_VERSION,
    ),
}

# Core dependencies that are always installed. The Lightning requirement is not
# listed here; it is appended separately (see get_requirement_files).
#
# Note: For visibility, minimum versions are also documented in pyproject.toml
# under [tool.fts.min-versions] (informational only).
BASE_DEPENDENCIES = ["torch>=2.6.0"]

# Relative-path fragments of files the import-conversion helpers must never
# rewrite (including this module itself). ``_retrieve_files`` matches them as
# substrings of each file's "/"-normalized path relative to the scanned directory.
EXCLUDE_FILES_FROM_CONVERSION = [
    "setup_tools.py",
    "dynamic_versioning/utils.py",
    "dynamic_versioning/toggle_lightning_mode.py",
    "test_toggle_lightning_mode.py",
    "test_setup_tools.py",
    "test_dynamic_versioning_utils.py",
]
def get_base_dependencies() -> list[str]:
    """Return a fresh copy of the base (non-Lightning) dependency specifiers.

    A new list is returned so callers may append to it without mutating the
    module-level ``BASE_DEPENDENCIES`` list.

    Returns:
        List of base dependency strings (Lightning is added separately by callers).
    """
    return list(BASE_DEPENDENCIES)
def get_requirement_files(standalone: bool = False) -> list[str]:
    """Assemble install requirements: the base dependencies plus one Lightning requirement.

    Note: Lightning commit pinning is handled at install time via the UV_OVERRIDE
    environment variable rather than in package metadata, so this list carries
    only version constraints.

    Args:
        standalone: If true, require the standalone ``pytorch-lightning`` package;
            otherwise require the unified ``lightning`` package.

    Returns:
        List of requirement strings, with the Lightning requirement last.
    """
    requirements = get_base_dependencies()
    if standalone:
        flavor = "standalone"
    else:
        flavor = "unified"
    requirements.append(get_lightning_requirement(flavor))
    return requirements
def get_lightning_requirement(package_type: str = "unified") -> str:
    """Build the requirement string for the selected Lightning distribution.

    Commit pinning is not applied here; it is handled at install time via UV_OVERRIDE.

    Args:
        package_type: Key into ``LIGHTNING_PACKAGES``, either ``"unified"`` or ``"standalone"``.

    Returns:
        The package name immediately followed by its version constraint.

    Raises:
        KeyError: If ``package_type`` is not a key of ``LIGHTNING_PACKAGES``.
    """
    selected = LIGHTNING_PACKAGES[package_type]
    name, constraint = selected["package"], selected["version"]
    return f"{name}{constraint}"
def _retrieve_files(directory: str, *ext: str, exclude_files: list[str] | None = None) -> list[str]:
"""Find all files in a directory with optional extension filtering and exclusion."""
exclude_files = exclude_files or []
all_files = []
for root, _, files in os.walk(directory):
for fname in files:
file_path = os.path.join(root, fname)
relative_path = os.path.relpath(file_path, directory)
# Normalize path separators to handle both Windows and Unix paths
norm_relative_path = relative_path.replace('\\', '/')
# Check if any excluded path is contained in the normalized file path
if any(exclude_path in norm_relative_path for exclude_path in exclude_files):
print(f"Skipping {file_path} to prevent self-modification")
continue
if not ext or any(os.path.split(fname)[1].lower().endswith(e) for e in ext):
all_files.append(file_path)
return all_files
def _replace_imports(lines: list[str], mapping: list[tuple[str, str]], lightning_by: str = "") -> list[str]:
"""Replace imports of unified packages to standalone."""
out = lines[:]
for source_import, target_import in mapping:
for i, ln in enumerate(out):
if "from" in ln and "import" in ln:
out[i] = re.sub(rf"(^|\s)from\s+{re.escape(source_import)}(\.|\s)", rf"\1from {target_import}\2", ln)
if ln.strip().startswith("import "):
out[i] = re.sub(rf"import\s+{re.escape(source_import)}(\.|\s|$|,)", rf"import {target_import}\1", ln)
if lightning_by:
out[i] = out[i].replace("from lightning import ", f"from {lightning_by} import ")
out[i] = out[i].replace("import lightning ", f"import {lightning_by} ")
return out
def _check_import_format(file_content: str, source_imports: list[str]) -> bool:
"""Check if imports in a file already match the expected format."""
for import_name in source_imports:
if re.search(rf"(^|\s)from\s+{re.escape(import_name)}(\.|\s)", file_content, re.MULTILINE) or \
re.search(rf"import\s+{re.escape(import_name)}(\.|\s|$|,)", file_content, re.MULTILINE):
return False
return True
def _process_lightning_imports(src_dirs: ValuesView, source_imports: list[str],
                               mapping_pairs: list[tuple[str, str]], target_format: str,
                               debug: bool = False) -> None:
    """Process Lightning imports in python files across directories.

    Each ``.py`` file under each directory (minus ``EXCLUDE_FILES_FROM_CONVERSION``)
    is read as UTF-8. Files that contain none of ``source_imports`` are left
    untouched. The others are rewritten in place with ``_replace_imports``.

    Args:
        src_dirs: Directories to process
        source_imports: List of import patterns to check for in files
        mapping_pairs: List of (source, target) import mapping tuples
        target_format: Format name for debug messages ("standalone" or "unified")
        debug: Whether to output debug information
    """
    for in_place_path in src_dirs:
        files_to_process = _retrieve_files(str(in_place_path), '.py', exclude_files=EXCLUDE_FILES_FROM_CONVERSION)
        for file_path in files_to_process:
            try:
                with open(file_path, encoding='utf-8') as f:
                    content = f.read()
            except UnicodeDecodeError:
                if debug:
                    print(f"Skipping binary file: {file_path}")
                continue
            if _check_import_format(content, source_imports=source_imports):
                if debug:
                    print(f"No imports needed conversion to {target_format} format in {file_path}.")
                continue
            # Build the rewritten content BEFORE opening for write. Mode 'w'
            # truncates immediately, so an exception during the rewrite would
            # otherwise leave the source file empty.
            lines = _replace_imports(content.splitlines(True), mapping_pairs)
            with open(file_path, 'w', encoding='utf-8') as f:
                f.writelines(lines)
            print(f"Updated imports in {file_path}")
def use_standalone_pl(
    src_dirs: ValuesView,
    mapping: dict[str, str] = LIGHTNING_PACKAGE_MAPPING,
    debug: bool = False
) -> None:
    """Rewrite unified Lightning imports (mapping keys) to standalone ones (mapping values).

    Args:
        src_dirs: Directories whose ``.py`` files should be converted.
        mapping: Unified-to-standalone module mapping, applied key -> value.
        debug: Enable per-file diagnostic output.
    """
    unified_names = list(mapping)
    conversions = list(mapping.items())
    _process_lightning_imports(src_dirs, unified_names, conversions, "standalone", debug)
def use_unified_pl(src_dirs: ValuesView,
                   mapping: dict[str, str] = LIGHTNING_PACKAGE_MAPPING, debug: bool = False) -> None:
    """Rewrite standalone Lightning imports (mapping values) back to unified ones (mapping keys).

    Args:
        src_dirs: Directories whose ``.py`` files should be converted.
        mapping: Unified-to-standalone module mapping, applied in reverse (value -> key).
        debug: Enable per-file diagnostic output.
    """
    standalone_names = list(mapping.values())
    reversed_pairs = [(standalone, unified) for unified, standalone in mapping.items()]
    _process_lightning_imports(src_dirs, standalone_names, reversed_pairs, "unified", debug)
def get_project_paths() -> tuple[Path, dict[str, Path]]:
    """Get project paths for imports conversion and package setup.

    Two layouts are distinguished by whether this module's directory path
    contains ``site-packages`` or ``dist-packages``:

    * installed: ``project_root`` is two levels above this module's directory.
      Candidates are the parent package directory (``"source"``) and a
      ``fts_examples`` directory next to it (``"examples"``).
    * otherwise (source checkout): ``project_root`` is three levels above this
      module's directory. Candidates are ``src``, ``tests`` and ``requirements``
      under it (keys ``"source"``, ``"tests"``, ``"require"``).

    Returns:
        ``(project_root, install_paths)``, where ``install_paths`` keeps only the
        candidate directories that exist on disk.
    """
    current_file_dir = Path(os.path.dirname(os.path.abspath(__file__)))
    if "site-packages" in str(current_file_dir) or "dist-packages" in str(current_file_dir):
        project_root = current_file_dir.parent.parent
        candidates = {
            "source": current_file_dir.parent,
            "examples": current_file_dir.parent.parent / "fts_examples",
        }
    else:
        # Derive from the absolute, normalized directory computed above. The old
        # code re-applied dirname to the raw ``__file__``, which may be relative
        # or non-normalized. ``Path.parent`` stops at the filesystem root, just
        # like repeated ``os.path.dirname``.
        project_root = current_file_dir.parent.parent.parent
        candidates = {
            key: project_root / subdir
            for key, subdir in (("source", "src"), ("tests", "tests"), ("require", "requirements"))
        }
    install_paths = {k: v for k, v in candidates.items() if v.exists()}
    return project_root, install_paths
def _is_package_installed(package_name: str) -> bool:
"""Check if a package is installed."""
try:
__import__(package_name)
return True
except ImportError:
return False
[docs]
def toggle_lightning_imports(mode: str = "unified", debug: bool = False) -> None:
    """Toggle between standalone and unified Lightning imports.

    If the Lightning package required by ``mode`` is not installed, a warning
    is printed and nothing is changed.

    Args:
        mode: ``"unified"`` (``pytorch_lightning`` -> ``lightning.pytorch``) or
            ``"standalone"`` (``lightning.pytorch`` -> ``pytorch_lightning``).
        debug: Forwarded to the conversion helpers for per-file diagnostics.

    Raises:
        ValueError: If ``mode`` is neither ``"unified"`` nor ``"standalone"``.
        RuntimeError: If the conversion fails. The original exception is chained as the cause.
    """
    # Validate up front. Previously any unrecognized mode silently fell through
    # to the unified conversion and then reported success for the bogus mode.
    if mode not in ("unified", "standalone"):
        raise ValueError(f"Invalid mode {mode!r}; expected 'unified' or 'standalone'.")
    try:
        if mode == "unified" and not _is_package_installed("lightning"):
            print("Warning: Cannot toggle to unified imports because the 'lightning' package is not installed.")
            print("Please install the unified Lightning package with: uv pip install lightning")
            return
        elif mode == "standalone" and not _is_package_installed("pytorch_lightning"):
            print("Warning: Cannot toggle to standalone imports because the 'pytorch-lightning' package is not "
                  "installed.")
            print("Please install the standalone Lightning package with: uv pip install pytorch-lightning")
            return
        _, install_paths = get_project_paths()
        if mode == "standalone":
            print("Converting to standalone imports (e.g. lightning.pytorch -> pytorch_lightning)...")
            use_standalone_pl(install_paths.values(), debug=debug)
        else:
            print("Converting to unified imports (e.g. pytorch_lightning -> lightning.pytorch)...")
            use_unified_pl(install_paths.values(), debug=debug)
        print(f"Successfully toggled to {mode} Lightning imports.")
    except Exception as e:
        import traceback
        traceback.print_exc()
        # Chain the cause so the underlying error stays visible to callers and tracebacks.
        raise RuntimeError(f"Failed to toggle Lightning imports: {e}") from e