Source code for pestifer.tasks.terminate

# Author: Cameron F. Abrams, <cfa22@drexel.edu>
"""
Definition of the :class:`TerminateTask` class for terminating a pestifer build.
This task is a descendant of the :class:`MDTask <pestifer.tasks.mdtask.MDTask>` class and is used to prepare the system for termination.
It handles the copying of state files, writing chain maps, and packaging the system for NAMD runs.
The task also manages the state of the simulation, including the base molecule and various file extensions such as PSF, PDB, COOR, XSC, and VEL.
The task is designed to be used in a workflow where the simulation needs to be gracefully terminated and packaged for further analysis or continuation.
It ensures that all necessary files are collected and organized, making it easy to resume or analyze the simulation later.
"""
import logging
import os
import shutil
from pathlib import Path

import numpy as np
import yaml

from .mdtask import MDTask
from ..charmmff.charmmffprm import CharmmParamFile
from ..core.artifacts import *
from ..molecule.molecule import Molecule
from ..psfutil.psfcontents import PSFContents
from ..util.colors import PestiferColors
from ..util.stringthings import my_logger
from ..util.util import cell_from_xsc

# Module-level logger used by the TerminateTask methods below.
logger = logging.getLogger(__name__)

class TerminateTask(MDTask):
    """
    TerminateTask class for terminating a pestifer build.

    This class inherits from the :class:`MDTask <pestifer.tasks.mdtask.MDTask>` class and is used
    to prepare the system for termination. It handles the copying of state files, writing chain
    maps, generating a minimal parameter file, and packaging the system for NAMD runs.
    """
    _yaml_header = 'terminate'
    """
    YAML header for the TerminateTask, used to identify the task in configuration files as part of a ``tasks`` list.
    This header is used to declare TerminateTask objects in YAML task lists.
    It is typically used at the end of a build workflow to finalize the state of the simulation and prepare for termination.
    """

    # State-artifact attributes that correspond to per-system coordinate/topology files.
    _STATE_EXTS = ('psf', 'pdb', 'coor', 'xsc', 'vel')

    def do(self) -> int:
        """
        Run the terminate task.

        Steps, in order:

        1. advance to the next basename;
        2. write the chain-map file if ``chainmapfile`` is in the specs;
        3. run the test-standard comparison (:meth:`test_standard`);
        4. copy the state files to the user ``basename`` if one is given;
        5. generate a minimal CHARMM parameter file and attach it to the current state;
        6. log a system report;
        7. build the optional production package (:meth:`make_package`);
        8. archive intermediate artifacts (:meth:`cleanup`).

        Returns
        -------
        int
            Sum of the integer results of :meth:`test_standard`, :meth:`make_package`
            and :meth:`cleanup`.
        """
        self.next_basename()
        if 'chainmapfile' in self.specs:
            self.write_chainmaps()
        self.result = self.test_standard()
        if self.specs.get('basename'):
            self.copy_state_to_basename()
        minimal_prm = self.generate_minimal_params()
        if minimal_prm:
            # generate_minimal_params() returns None when there is no state,
            # so a current state is guaranteed to exist here.
            state: StateArtifacts = self.get_current_artifact('state')
            state.minimal_prm = CharmmffParFileArtifact(data=minimal_prm, keep=True)
            # Mirror the attribute into the state's data mapping as well.
            state.data['minimal_prm'] = state.minimal_prm
        self.print_system_report()
        self.result += self.make_package()
        self.result += self.cleanup()
        return self.result

    def write_chainmaps(self):
        """
        Write the chain maps of the base molecule to the YAML file named by ``specs['chainmapfile']``.

        The base molecule is taken from the current ``base_molecule`` artifact; its
        ``get_chainmaps()`` result is dumped with :func:`yaml.dump`, and the file is
        registered under the ``chainmapfile`` key as a ``YAMLFileArtifact``.
        If no base molecule is available, nothing is written.
        """
        bm: Molecule = self.get_current_artifact_data('base_molecule')
        if not bm:
            logger.debug('No base molecule available; chain map file will not be written.')
            return
        chainmapfile = self.specs['chainmapfile']
        maps = bm.get_chainmaps()
        with open(chainmapfile, 'w') as f:
            yaml.dump(maps, f)
        # register() is given the name without its '.yaml' extension (presumably the
        # artifact type re-appends it). Strip only a trailing '.yaml' so an occurrence
        # elsewhere in the path is left intact.
        self.register(chainmapfile.removesuffix('.yaml'), key='chainmapfile', artifact_type=YAMLFileArtifact)

    def copy_state_to_basename(self):
        """
        Copy the current state files to the user-specified basename and re-register the state.

        Each ``psf``/``pdb``/``coor``/``xsc``/``vel`` artifact on the current state that
        exists on disk is copied to ``<basename>.<ext>``; a ``minimal_prm`` artifact, if
        present, is copied to ``<basename>_minimal.prm``. Every copied artifact is marked
        ``keep=True``. If at least one file was copied, the collected artifacts are
        registered as the new ``state``. Does nothing when there is no current state.
        """
        basename = self.specs.get('basename')
        state: StateArtifacts = self.get_current_artifact('state')
        if not state:
            return
        targets = {ext: f'{basename}.{ext}' for ext in self._STATE_EXTS}
        targets['minimal_prm'] = f'{basename}_minimal.prm'
        new_state = {}
        for key, dest in targets.items():
            new_fa = self._copy_state_file(getattr(state, key, None), dest)
            if new_fa is not None:
                new_state[key] = new_fa
        if new_state:
            self.register(new_state, key='state', artifact_type=StateArtifacts)

    @staticmethod
    def _copy_state_file(fa, dest):
        """Copy file artifact *fa* to *dest* and return a kept copy of it, or ``None`` if *fa* is absent or not on disk."""
        if not fa or not fa.exists():
            return None
        if fa.name != dest:
            shutil.copy(fa.name, dest)
        new_fa = fa.copy(data=dest)
        new_fa.keep = True
        return new_fa

    def make_package(self):
        """
        Create a package for a production NAMD run starting from the end of the build.

        State files are included in the tarball under their existing names (the terminate
        basename); only the tarball itself and the NAMD config script use the package
        basename (``specs['package']['basename']``). If no package basename is given,
        nothing is done.

        When the current state carries a ``minimal_prm`` artifact, the pipeline's
        parameter/stream file artifacts are replaced by that single file so the
        generated NAMD config has one ``parameters`` line.

        Returns
        -------
        int
            0 when no NAMD config is generated, otherwise the result of ``namdrun``.
        """
        package_specs = self.specs.get('package', {})
        if not package_specs or not package_specs.get('basename'):
            logger.debug('No package basename provided; packaging will not be performed.')
            return 0
        md_specs = package_specs.get('namd', {})
        pkg_basename = package_specs.get('basename')
        tarball_contents = FileArtifactList()
        state: StateArtifacts = self.get_current_artifact('state')
        for ext in self._STATE_EXTS:
            fa: FileArtifact = getattr(state, ext, None)
            if fa and fa.exists():
                tarball_contents.append(fa)
        # The minimal parameter file was already stored in state by do(); retrieve it here.
        min_artifact = getattr(state, 'minimal_prm', None) if state else None
        if min_artifact:
            self.register([min_artifact], key='charmmff_parfiles', artifact_type=CharmmffParFileArtifacts)
            self.register([], key='charmmff_streamfiles', artifact_type=CharmmffStreamFileArtifacts)
        result = 0
        if md_specs:
            logger.debug(f'Packaging for namd using basename {pkg_basename}')
            save_specs = self.specs
            # namdrun() reads its settings from self.specs. Swap in a *copy* of the
            # package's namd specs so the user's configuration is not mutated, and
            # always restore the task's own specs, even if namdrun() raises.
            self.specs = dict(md_specs)
            self.specs['basename'] = pkg_basename
            try:
                result = self.namdrun(script_only=True, skip_standard_params=(min_artifact is not None))
            finally:
                self.specs = save_specs
            tarball_contents.append(self.get_current_artifact('namd'))
            constraints = self.specs.get('constraints', {})
            if constraints:
                self.make_constraint_pdb(constraints, statekey='consref')
                tarball_contents.append(self.get_current_artifact('consref'))
        else:
            logger.debug('No NAMD configuration is included in the package.')
        if min_artifact:
            tarball_contents.append(min_artifact)
        tarball_contents.make_tarball(pkg_basename, arcname_prefix=pkg_basename, unique=True, remove=True)
        return result

    def print_system_report(self):
        """Log a summary of the final built system: file sizes, topology counts, and box dimensions."""
        state: StateArtifacts = self.get_current_artifact('state')
        if not state:
            return
        sep = '=' * 60
        logger.info(sep)
        logger.info('System Report')
        logger.info(sep)
        # File sizes
        logger.info('Output files:')
        for ext in self._STATE_EXTS:
            fa: FileArtifact = getattr(state, ext, None)
            if fa and fa.exists():
                size_kb = os.path.getsize(fa.name) / 1024
                logger.info(f' {fa.name:<40s} {size_kb:>10.1f} kB')
        # Topology counts from the PSF section headers
        psf_fa = getattr(state, 'psf', None)
        if psf_fa and psf_fa.exists():
            psf = PSFContents(psf_fa.name)
            logger.info('Topology (from PSF):')
            sections = (
                ('ATOM', 'Atoms'),
                ('BOND', 'Bonds'),
                ('THETA', 'Angles'),
                ('PHI', 'Dihedrals'),
                ('IMPHI', 'Impropers'),
                ('CRTERM', 'Cross-terms'),
            )
            for key, label in sections:
                count = psf.token_count.get(key)
                if count is not None:
                    logger.info(f' {label:<16s} {count:>10,}')
        # Box vectors from the XSC file
        xsc_fa = getattr(state, 'xsc', None)
        if xsc_fa and xsc_fa.exists():
            box, _ = cell_from_xsc(xsc_fa.name)
            if box is not None:
                logger.info('Periodic box (from XSC):')
                # One row of `box` per cell vector a, b, c.
                for label, vec in zip('abc', box):
                    length = float(np.linalg.norm(vec))
                    logger.info(f' {label} = ({vec[0]:8.3f}, {vec[1]:8.3f}, {vec[2]:8.3f}) Å |{label}| = {length:.3f} Å')
        logger.info(sep)

    def generate_minimal_params(self) -> str | None:
        """
        Generate a minimal consolidated CHARMM parameter file for the current PSF.

        Collects all parameter files (.prm and .str) registered in the pipeline that exist on
        disk, merges them in registration order, and extracts only the records needed for the
        atom types present in the current PSF. Files that fail to parse are skipped with a
        warning.

        Returns
        -------
        str or None
            Name of the written ``<basename>_minimal.prm`` file, or ``None`` if there is no
            PSF in the current state or no parameter files are available.
        """
        state: StateArtifacts = self.get_current_artifact('state')
        psf_fa = getattr(state, 'psf', None) if state else None
        if not psf_fa or not psf_fa.exists():
            logger.debug('generate_minimal_params: no PSF available, skipping')
            return None
        charmmff_parfiles: CharmmffParFileArtifacts = self.get_current_artifact('charmmff_parfiles')
        charmmff_streamfiles: CharmmffStreamFileArtifacts = self.get_current_artifact('charmmff_streamfiles')
        # Parameter files first, then stream files; order matters for merging.
        param_files = [
            fa.name
            for collection in (charmmff_parfiles, charmmff_streamfiles)
            if collection
            for fa in collection
            if os.path.exists(fa.name)
        ]
        if not param_files:
            logger.debug('generate_minimal_params: no parameter files available, skipping')
            return None
        psf = PSFContents(psf_fa.name)
        atomtypes = {a.atomtype for a in psf.atoms}
        logger.debug(f'generate_minimal_params: {len(atomtypes)} unique atom types in PSF')
        combined = CharmmParamFile()
        for fname in param_files:
            # Best-effort: an unparsable file is reported and skipped rather than aborting termination.
            try:
                combined.merge(CharmmParamFile.from_file(fname))
                logger.debug(f'generate_minimal_params: parsed {fname}')
            except Exception as exc:
                logger.warning(f'generate_minimal_params: failed to parse {fname}: {exc}')
        minimal = combined.extract_for_atomtypes(atomtypes)
        logger.debug(f'generate_minimal_params: {minimal.summary()}')
        outname = f'{self.basename}_minimal.prm'
        minimal.write(outname, title=f'Minimal CHARMM parameter file for {self.basename}')
        logger.debug(f'Wrote minimal parameter file: {outname}')
        return outname

    def cleanup(self):
        """
        Archive every pipeline file artifact not marked ``keep`` into a tarball.

        The tarball name comes from ``specs['artifacts']`` (default ``'artifacts'``) and
        ``make_tarball`` is called with ``remove=True``. Files in the working directory that
        are not archived artifacts are only listed at debug level. Skipped entirely when
        ``specs['cleanup']`` is false.

        Returns
        -------
        int
            Always 0.
        """
        if not self.specs.get('cleanup', True):
            logger.debug('Cleanup disabled; skipping cleanup step.')
            return 0
        archive_name = self.specs.get('artifacts', 'artifacts')
        all_file_artifacts: FileArtifactList = self.pipeline.get_all_file_artifacts()
        file_artifacts = FileArtifactList([fa for fa in all_file_artifacts if not fa.keep])
        file_artifacts.sort(key=lambda x: x.name)
        file_artifact_names = [fa.name for fa in file_artifacts.data]
        logger.debug(f'{len(file_artifacts)} file artifacts to be included in archive:')
        my_logger(file_artifact_names, logger.debug, depth=1)
        archived = set(file_artifact_names)
        non_artifact_files = [f for f in os.listdir('.') if f not in archived]
        if non_artifact_files:
            logger.debug('Non-artifact files in current working directory:')
            my_logger(non_artifact_files, logger.debug, depth=1)
        file_artifacts.make_tarball(archive_name, remove=True, arcname_prefix=archive_name, unique=True)
        return 0

    def test_standard(self):
        """
        Compare pytestable artifacts against a directory of reference ("standard") files.

        Controlled by ``specs['test_standards']['standard_path']``:

        - If it is missing, or there are no pytestable artifacts, nothing happens.
        - If it names an existing non-directory, or the current working directory,
          nothing happens.
        - If the directory does not exist it is created. If it contains no files,
          it is populated with copies of the current testable artifacts.
        - Otherwise each testable artifact is compared with its same-named reference
          file. The results (``"pass"``/``"fail"``; a missing reference counts as
          ``"fail"``) are registered under ``test_results`` and logged.

        Returns
        -------
        int
            Always 0.
        """
        standard_specs = self.specs.get('test_standards', {})
        if not standard_specs:
            logger.debug('No test_standards specifications provided; skipping test_standards step.')
            return 0
        artifact_file_collection = self.pipeline.get_all_file_artifacts()
        testable_file_artifacts = FileArtifactList([a for a in artifact_file_collection if a.pytestable])
        if len(testable_file_artifacts) == 0:
            logger.debug('No pytestable artifacts found; skipping test_standard step.')
            return 0
        standard_path = standard_specs.get('standard_path', None)
        if standard_path is None:
            logger.debug('No standard_path provided; skipping test_standard step.')
            return 0
        standard_path = Path(standard_path).resolve()
        if not standard_path.exists():
            standard_path.mkdir(parents=True, exist_ok=True)
        elif not standard_path.is_dir():
            logger.debug(f'standard_path {standard_path} is not a directory; skipping test_standard step.')
            return 0
        # The standards directory cannot be the CWD; resolve both sides so symlinks don't defeat the check.
        if standard_path == Path.cwd().resolve():
            logger.debug('standard_path cannot be the current working directory; skipping test_standard step.')
            return 0
        logger.debug(f'Using standard_path: {standard_path.name}')
        # From here on standard_path is guaranteed to be an existing directory.
        if not any(x.is_file() for x in standard_path.iterdir()):
            logger.debug(f'standard_path {standard_path} is empty; populating it with current testable artifacts.')
            for f in testable_file_artifacts.data:
                logger.debug(f' Populating standard with {f.name}')
                shutil.copy(f.name, standard_path / f.name)
            return 0
        logger.debug('Testable artifacts:')
        results = {}
        for f in testable_file_artifacts.data:
            logger.debug(f' {f.name}')
            reference = standard_path / f.name
            if not reference.exists():
                logger.debug(f' No standard file {reference}; marking as fail.')
                results[f.name] = "fail"
            else:
                results[f.name] = "pass" if f.compare(reference) else "fail"
        self.register(results, key='test_results')
        logger.debug('Registered all test results at "test_results"')
        my_logger('Standard Test Results:', logger.info)
        maxnamelen = max((len(n) for n in results), default=0)
        name_format = f'- {{:>{maxnamelen}s}}'
        num_items = len(results)
        colors = dict(_pass=PestiferColors['emerald'], _fail=PestiferColors['alizarin'])
        for item, (name, result) in enumerate(results.items(), start=1):
            color = colors.get(f'_{result}', PestiferColors['alizarin'])
            logger.info(f' {name_format.format(name)}: {color}{result}{color.OFF} ({item}/{num_items})')
        logger.info('*' * 70)
        return 0