
Source code for fts_examples.profiling.config

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""
MemProfiler Configuration Dataclasses
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

This module defines the configuration dataclasses for the MemProfiler.
"""
from typing import Callable
from dataclasses import dataclass, field, fields
from pathlib import Path

import torch
from lightning.fabric.utilities import rank_zero_warn


@dataclass
class MemProfilerHooks:
    pre_forward_hooks: list[str | Callable] = \
        field(default_factory=lambda: ['fts_examples.profiling.npp_hooks._hook_npp_pre_forward'])
    post_forward_hooks: list[str | Callable] = \
        field(default_factory=lambda: ['fts_examples.profiling.npp_hooks._hook_npp_post_forward'])
    # the provided reset_state_hooks will be called with the model and the `save_hook_attrs` list
    reset_state_hooks: list[str | Callable] = \
        field(default_factory=lambda: ['fts_examples.profiling.npp_hooks._reset_memory_hooks_state'])
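
# Illustrative sketch only (not part of the upstream module): hooks may be given as dotted-path
# strings (as above) or as callables. Per the comment above, reset_state_hooks receive the model
# and the `save_hook_attrs` list; `my_reset_memory_state` below is a hypothetical example.
#
# def my_reset_memory_state(model: torch.nn.Module, save_hook_attrs: list) -> None:
#     for module in model.modules():
#         for attr in save_hook_attrs:
#             if hasattr(module, attr):
#                 setattr(module, attr, 0)
#
# custom_hooks = MemProfilerHooks(reset_state_hooks=[my_reset_memory_state])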
@dataclass
class MemProfilerFuncs:  # can specify an arbitrary list of `memprofilable` decorated function names
    # funcs that will be added to all memory collection types
    default: set[str] = field(default_factory=lambda: {'training_step'})
    cpu: set[str] = field(default_factory=set)
    cuda: set[str] = field(default_factory=set)
    cuda_allocator_history: set[str] = field(default_factory=set)
    fsdp: set[str] = field(default_factory=set)
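
# Illustrative sketch only: request CPU memory collection for an additional `memprofilable`-decorated
# method (`validation_step` here is hypothetical). Note that MemProfilerCfg.__post_init__ below
# composes every non-default set with `default`, so `training_step` would be collected as well.
#
# funcs = MemProfilerFuncs(cpu={'validation_step'})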
@dataclass
class MemProfilerSchedule:
    # keeping the schedule as simple as possible for now; may expand to accommodate more flexible schedules in the future
    warmup_iters: int = 1
    max_iter: int | None = None
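
# Illustrative sketch only: max_iter may be left as None; when this schedule is attached to an
# enabled MemProfilerCfg, __post_init__ below resolves it to warmup_iters + 1.
#
# schedule = MemProfilerSchedule(warmup_iters=3)  # max_iter resolves to 4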
@dataclass
class MemProfilerCfg:
    """Configuration dataclass for the MemProfiler.

    :param enabled: Whether to enable memory profiling.
    :param collect_funcs: A MemProfilerFuncs instance specifying the functions to collect per memory collection type.
    :param cuda_allocator_history: Whether to collect CUDA memory allocator history.
    :param track_fsdp_mem: Whether to collect FSDP memory statistics.
    :param fsdp_mem_track_module_depth: The depth of FSDP modules to track.
    :param fsdp_mem_tracker_tabulate: Whether to print FSDP memory statistics in a tabular format.
    :param fsdp_mem_tracker_units: The units to use for FSDP memory statistics.
    :param fsdp_mem_tracker_root_module: The root module to use for FSDP memory statistics.
    :param dump_memorystats_pickle: Whether to dump memory statistics to a pickle file.
    :param dump_memorystats_yaml: Whether to dump memory statistics to a yaml file.
    :param schedule: A MemProfilerSchedule instance specifying the schedule for memory collection.
    :param save_dir: The directory in which to save the memory statistics.
    :param enable_memory_hooks: Whether to enable memory hooks.
    :param enable_saved_tensors_hooks: Whether to enable saved tensors hooks.
    :param memory_hooks: A MemProfilerHooks instance specifying the memory hooks.
    :param saved_tensors_funcs: A list of saved tensors functions.
    :param save_hook_attrs: A list of module state attributes to save.
    :param retain_hooks_for_funcs: A set of functions for which to retain memory hooks.
    """
    enabled: bool = False
    # specify funcs to collect per memory collection type, a default set applied to all types, or both (composed)
    collect_funcs: MemProfilerFuncs = field(default_factory=MemProfilerFuncs)
    cuda_allocator_history: bool = False
    track_fsdp_mem: bool = False
    fsdp_mem_track_module_depth: int = 2
    fsdp_mem_tracker_tabulate: bool = False
    fsdp_mem_tracker_units: str = "MiB"
    fsdp_mem_tracker_root_module: str = ""
    dump_memorystats_pickle: bool = False
    dump_memorystats_yaml: bool = True
    schedule: MemProfilerSchedule = field(default_factory=MemProfilerSchedule)
    save_dir: str | Path | None = None
    enable_memory_hooks: bool = True
    enable_saved_tensors_hooks: bool = True
    memory_hooks: MemProfilerHooks = field(default_factory=MemProfilerHooks)
    # because it's frequently used for unpacking and to ensure this dataclass remains serializable, we allow
    # specification of 'identity_lambda' which will resolve to `lambda x: x`
    saved_tensors_funcs: list = field(default_factory=lambda: list(('fts_examples.profiling.npp_hooks._npp_hook',
                                                                    'identity_lambda')))
    # if you add custom hooks, make sure to add the desired module state attributes to save to `save_hook_attrs`
    save_hook_attrs: list = field(default_factory=lambda: ["rss_pre_forward", "rss_post_forward", "rss_diff",
                                                           "npp_pre_forward", "npp_post_forward", "npp_diff"])
    # since we cannot reliably ascertain when all MemProfilerFuncs will be executed, memory hooks will
    # only be removed once the funcs in this set have reached `max_iter`
    retain_hooks_for_funcs: set[str] = field(default_factory=lambda: {'training_step'})

    def __post_init__(self) -> None:
        if not self.enabled:
            return
        if not torch.cuda.is_available() and any((self.collect_funcs.cuda_allocator_history,
                                                  self.collect_funcs.cuda, self.cuda_allocator_history)):
            rank_zero_warn("Disabling CUDA memory profiling functionality since no CUDA device detected.")
            self.collect_funcs.cuda, self.collect_funcs.cuda_allocator_history = set(), set()
            self.cuda_allocator_history = False
        has_hooks = any(getattr(self.memory_hooks, ht.name) for ht in fields(self.memory_hooks))
        if not has_hooks:
            rank_zero_warn("MemProfilerCfg is configured to enable memory hooks but MemProfilerHooks does not have"
                           " any specified.")
        if self.schedule.max_iter is None:
            self.schedule.max_iter = self.schedule.warmup_iters + 1
        # compose all non-default func sets with the default set
        default_funcs = self.collect_funcs.default
        for k in self.collect_funcs.__dataclass_fields__.keys():
            if k != 'default':
                getattr(self.collect_funcs, k).update(default_funcs)
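
# Illustrative usage sketch (not part of the upstream module): a minimal enabled configuration
# with CUDA allocator history and an explicit schedule. The values below are assumptions for
# demonstration, not recommended defaults.
#
# memprofiler_cfg = MemProfilerCfg(
#     enabled=True,
#     cuda_allocator_history=True,
#     collect_funcs=MemProfilerFuncs(cuda={'training_step'}),
#     schedule=MemProfilerSchedule(warmup_iters=2, max_iter=4),
#     save_dir="./memprofiler_stats",
# )
# # After __post_init__, every non-default func set also includes the `default` funcs,
# # e.g. memprofiler_cfg.collect_funcs.cpu == {'training_step'}.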