# Source code for pestifer.tasks.validate

# Author: Cameron F. Abrams, <cfa22@drexel.edu>

"""
This module contains functions for validating PSF/PDB files generated by Pestifer.  The user specifies VMD atomselections, variables to get from those selections and values those variables should have.
"""

import logging

from pathlib import Path
from typing import ClassVar
from pestifer.core.artifacts import StateArtifacts, VMDLogFileArtifact, VMDScriptArtifact, DataArtifact
from pestifer.logparsers.logparser import VMDLogParser
from pestifer.scripters.vmd import VMDScripter
from pestifer.tasks.basetask import VMDTask

logger = logging.getLogger(__name__)

class AttributeTest:
    """
    This class represents a test for expected values of specific atom attributes in an atom selection.

    Attributes
    ----------
    name : str
        The name of the test.
    selection : str
        The atom selection string.
    attribute : str
        The atom attribute to test.
    value : str
        The expected value of the attribute.
    value_count : int
        The number of expected occurrences of the value.
    """

    def __init__(self, name: str, selection: str, attribute: str, value: str, value_count: int = 1):
        self.name = name
        self.selection = selection
        self.attribute = attribute
        self.value = value
        self.value_count = value_count

    def __repr__(self) -> str:
        # Lists of tests are interpolated into debug log messages; a readable
        # repr makes those messages show the actual test parameters.
        return (
            f'{type(self).__name__}(name={self.name!r}, selection={self.selection!r}, '
            f'attribute={self.attribute!r}, value={self.value!r}, value_count={self.value_count!r})'
        )

    def write(self, vt: VMDScripter):
        """
        Write the commands to the VMD script to execute this test.

        The emitted Tcl fetches ``attribute`` for every atom in ``selection``,
        counts the case-insensitive exact matches of ``value``, and prints a
        ``PASS`` or ``FAIL`` line via ``vmdcon`` depending on whether that
        count equals ``value_count``.

        Parameters
        ----------
        vt : VMDScripter
            The scripter to which the Tcl lines are appended via ``addline``.
        """
        target = self.value
        target_count = self.value_count
        vt.addline(f'set test_selection [atomselect top "{self.selection}"]')
        vt.addline(f'set result [$test_selection get {self.attribute}]')
        vt.addline(f'set count [llength [lsearch -nocase -exact -all $result {target}]]')
        vt.addline(f'if {{$count != {target_count}}} {{')
        vt.addline(f' vmdcon "FAIL {self.attribute} has unexpected count $count of value {target} (expected {target_count}) in selection {self.selection}"')
        vt.addline('} else {')
        vt.addline(f' vmdcon "PASS {self.attribute} has expected count $count of value {target} (expected {target_count}) in selection {self.selection}"')
        vt.addline('}')
[docs] class ConnectionTest: """ This class represents a test for expected number of bonds of a specific type among an atomselection. Attributes ---------- name : str The name of the test. selection : str The atom selection string. connection_type : str The type of connection to test (e.g., "interresidue"). connection_count : int The expected number of connections. """ connection_type_supported = {'interresidue', 'disulfide', 'glycosylation'} def __init__(self, name: str, selection: str, connection_type: str, connection_count: int = 1): self.name = name self.selection = selection self.connection_type = connection_type if not self.connection_type in self.connection_type_supported: raise NotImplementedError(f'Unsupported connection type: {self.connection_type}') self.connection_count = connection_count
[docs] def write(self, vt: VMDScripter): vt.addline(f'set test_selection [atomselect top "{self.selection}"]') match self.connection_type: case 'interresidue': vt.addline(f'set indexes [$test_selection get index]') vt.addline(f'set atomnames [$test_selection get name]') vt.addline(f'set resids [$test_selection get residue]') vt.addline(f'set bondlists [$test_selection getbonds]') vt.addline(f'set count 0') vt.addline(f'foreach index $indexes atomname $atomnames resid $resids bondlist $bondlists {{') vt.addline(f' foreach partner $bondlist {{') vt.addline(f' if {{$partner < $index}} {{ continue }}') # no double-counting! vt.addline(f' set idx [lsearch $indexes $partner]') vt.addline(f' set partner_resid [lindex $resids $idx]') vt.addline(f' if {{$partner_resid != "" && $partner_resid != $resid}} {{') vt.addline(f' set partnersel [atomselect top "index $partner"]') vt.addline(f' set partner_name [$partnersel get name]') vt.addline(f' vmdcon "BOND $index $atomname $resid <--> $partner $partner_name $partner_resid"') vt.addline(f' incr count') vt.addline(f' }}') vt.addline(f' }}') vt.addline(f'}}') vt.addline(f'if {{$count != {self.connection_count}}} {{') vt.addline(f' vmdcon "FAIL {self.name} has unexpected count $count (expected {self.connection_count}) in selection {self.selection} for interresidue bonds"') vt.addline(f'}} else {{') vt.addline(f' vmdcon "PASS {self.name} has expected count $count (expected {self.connection_count}) in selection {self.selection}"') vt.addline(f'}}') case 'disulfide': vt.addline(f'set indexes [$test_selection get index]') vt.addline(f'set atomnames [$test_selection get name]') vt.addline(f'set resids [$test_selection get residue]') vt.addline(f'set bondlists [$test_selection getbonds]') vt.addline(f'set count 0') vt.addline(f'foreach index $indexes atomname $atomnames resid $resids bondlist $bondlists {{') vt.addline(f' foreach partner $bondlist {{') vt.addline(f' if {{$partner < $index}} {{ continue }}') # no double-counting! 
vt.addline(f' set idx [lsearch $indexes $partner]') vt.addline(f' set partner_resid [lindex $resids $idx]') vt.addline(f' if {{$partner_resid != "" && $partner_resid != $resid}} {{') vt.addline(f' set partnersel [atomselect top "index $partner"]') vt.addline(f' set partner_name [$partnersel get name]') vt.addline(f' vmdcon "BOND $index $atomname $resid <--> $partner $partner_name $partner_resid"') vt.addline(f' if {{$atomname == "SG" && $partner_name == "SG"}} {{') vt.addline(f' incr count') vt.addline(f' }}') vt.addline(f' }}') vt.addline(f' }}') vt.addline(f'}}') vt.addline(f'if {{$count != {self.connection_count}}} {{') vt.addline(f' vmdcon "FAIL {self.name} has unexpected count $count (expected {self.connection_count}) in selection {self.selection} for interresidue bonds"') vt.addline(f'}} else {{') vt.addline(f' vmdcon "PASS {self.name} has expected count $count (expected {self.connection_count}) in selection {self.selection}"') vt.addline(f'}}') case 'glycosylation': vt.addline(f'set indexes [$test_selection get index]') vt.addline(f'set resids [$test_selection get residue]') vt.addline(f'set bondlists [$test_selection getbonds]') vt.addline(f'set external_partners [list]') vt.addline(f'foreach index $indexes resid $resids bondlist $bondlists {{') vt.addline(f' foreach partner $bondlist {{') vt.addline(f' set idx [lsearch $indexes $partner]') vt.addline(f' if {{$idx == -1}} {{') vt.addline(f' lappend external_partners $partner') vt.addline(f' }}') vt.addline(f' }}') vt.addline(f'}}') vt.addline(f'set extsel [atomselect top "index $external_partners"]') vt.addline(f'set names [$extsel get name]') vt.addline(f'set count [llength [lsearch -nocase -exact -all $names C1]]') vt.addline(f'if {{$count != {self.connection_count}}} {{') vt.addline(f' vmdcon "FAIL Selection {self.selection} has no atoms bound to an external C1 atom"') vt.addline(f'}} else {{') vt.addline(f' vmdcon "PASS Selection {self.selection} has $count atoms bound to an external C1 atom"') 
vt.addline(f'}}') case '_': logger.debug(f'Unsupported connection type: {self.connection_type}')
[docs] class ResidueTest: """ This class represents a test for expected properties of a specific residue in an atom selection. """ def __init__(self, name: str, selection: str, measure: str, value: int, relation: str = '=='): self.name = name self.selection = selection self.measure = measure self.relation = relation self.value = value
[docs] def write(self, vt: VMDScripter): vt.addline(f'set test_selection [atomselect top "{self.selection}"]') pass_msg = f'PASS {self.name} has expected relation {self.relation} to count $count (expected {self.value}) in selection {self.selection}' fail_msg = f'FAIL {self.name} does not have expected relation {self.relation} to count $count (expected {self.value}) in selection {self.selection}' match self.measure: case 'atom_count': vt.addline(f'set count [$test_selection num]') vt.addline(f'if {{$count {self.relation} {self.value}}} {{') vt.addline(f' vmdcon "{pass_msg}"') vt.addline(f'}} else {{') vt.addline(f' vmdcon "{fail_msg}"') vt.addline(f'}}') case 'residue_count': vt.addline(f'set resids [$test_selection get residue]') vt.addline(f'set unique_resids [lsort -unique $resids]') vt.addline(f'set count [llength $unique_resids]') vt.addline(f'if {{$count {self.relation} {self.value}}} {{') vt.addline(f' vmdcon "{pass_msg}"') vt.addline(f'}} else {{') vt.addline(f' vmdcon "{fail_msg}"') vt.addline(f'}}') case '_': logger.debug(f'Unsupported measure type: {self.measure}')
[docs] class ValidateTask(VMDTask): """ This class represents a validation task for PSF/PDB files. A validate task can be inserted anywhere in the workflow to perform validation checks. """ _yaml_header: ClassVar[str] = 'validate'
[docs] def provision(self, packet: dict = {}): super().provision(packet) self.test_specs = self.specs.get('tests', []) self.attribute_tests: list[AttributeTest] = [] self.connection_tests: list[ConnectionTest] = [] self.residue_tests: list[ResidueTest] = [] for test in self.test_specs: logger.debug(f'Processing test specification: {test}') assert isinstance(test, dict), f"Test specification {test} must be a dictionary" assert len(test) == 1, f"Test specification {test} must have a single key-value pair" test_type = list(test.keys())[0] match test_type: case 'attribute_test': specs = test['attribute_test'] assert isinstance(specs, dict), f"Attribute test specification {specs} must be a dictionary" self.attribute_tests.append(AttributeTest(**specs)) case 'connection_test': specs = test['connection_test'] assert isinstance(specs, dict), f"Connection test specification {specs} must be a dictionary" self.connection_tests.append(ConnectionTest(**specs)) case 'residue_test': specs = test['residue_test'] assert isinstance(specs, dict), f"Residue test specification {specs} must be a dictionary" self.residue_tests.append(ResidueTest(**specs)) logger.debug(f'Provisioned {len(self.attribute_tests)} Attribute tests: {self.attribute_tests}') logger.debug(f'Provisioned {len(self.connection_tests)} Connection tests: {self.connection_tests}') logger.debug(f'Provisioned {len(self.residue_tests)} Residue tests: {self.residue_tests}')
[docs] def do(self): """ Execute the validation task. """ state: StateArtifacts = self.get_current_artifact('state') psf: Path = state.psf pdb: Path = state.pdb vt: VMDScripter = self.get_scripter('vmd') self.next_basename() vt.newscript(basename=self.basename) vt.load_psf_pdb(psf.name, pdb.name) for attribute_test in self.attribute_tests: attribute_test.write(vt) for connection_test in self.connection_tests: connection_test.write(vt) for residue_test in self.residue_tests: residue_test.write(vt) # here we insert the commands to extract values of desired variables vt.writescript() self.register(self.basename, key = 'tcl', artifact_type=VMDScriptArtifact) vt.runscript() log_artifact = self.register(self.basename, key='log', artifact_type=VMDLogFileArtifact) self.log = VMDLogParser.from_file(log_artifact.name) results = self.log.collect_validation_results() npass = 0 nfail = 0 if not results: logger.debug(f'Empty validation results') self.extra_message = "No validation results found." else: npass = sum(1 for r in results if 'PASS' in r) nfail = sum(1 for r in results if 'FAIL' in r) self.register(dict(npass=npass, nfail=nfail), key='validation_results') logger.debug(f'Validation results: \x1b[32m\x1b[1m{npass} passing\x1b[0m, \x1b[31m\x1b[1m{nfail} failing\x1b[0m>') self.extra_message = f"\x1b[32m\x1b[1mpass: {npass}\x1b[0m, \x1b[31m\x1b[1mfail: {nfail}\x1b[0m" # here we would parse the resulting log file return nfail