# Source code for pestifer.tasks.validate
# Author: Cameron F. Abrams, <cfa22@drexel.edu>
"""
This module contains functions for validating PSF/PDB files generated by Pestifer. The user specifies VMD atomselections, variables to get from those selections and values those variables should have.
"""
import logging
from pathlib import Path
from typing import ClassVar
from pestifer.core.artifacts import StateArtifacts, VMDLogFileArtifact, VMDScriptArtifact, DataArtifact
from pestifer.logparsers.logparser import VMDLogParser
from pestifer.scripters.vmd import VMDScripter
from pestifer.tasks.basetask import VMDTask
logger = logging.getLogger(__name__)
class AttributeTest:
    """
    Test that an atom attribute takes an expected value an expected number
    of times within an atom selection.

    The test is not evaluated in Python: :meth:`write` appends Tcl commands
    to a VMD script, and the script prints one line starting with ``PASS``
    or ``FAIL`` through ``vmdcon``.

    Attributes
    ----------
    name : str
        The name of the test.
    selection : str
        The atom selection string.
    attribute : str
        The atom attribute to test.
    value : str
        The expected value of the attribute.
    value_count : int
        The number of expected occurrences of the value.
    """

    def __init__(self, name: str, selection: str, attribute: str, value: str, value_count: int = 1):
        self.name = name
        self.selection = selection
        self.attribute = attribute
        self.value = value
        self.value_count = value_count

    def __repr__(self) -> str:
        # Keeps debug dumps of lists of tests readable.
        return (
            f'{type(self).__name__}(name={self.name!r}, selection={self.selection!r}, '
            f'attribute={self.attribute!r}, value={self.value!r}, '
            f'value_count={self.value_count!r})'
        )

    def write(self, vt: "VMDScripter"):
        """
        Write the commands to the VMD script to execute this test.

        Parameters
        ----------
        vt : VMDScripter
            Script writer; only its ``addline`` method is used here.
        """
        target = self.value
        target_count = self.value_count
        vt.addline(f'set test_selection [atomselect top "{self.selection}"]')
        vt.addline(f'set result [$test_selection get {self.attribute}]')
        # Brace the target so Tcl always sees exactly one pattern word, even
        # when the expected value is empty or contains whitespace.
        vt.addline(f'set count [llength [lsearch -nocase -exact -all $result {{{target}}}]]')
        vt.addline(f'if {{$count != {target_count}}} {{')
        vt.addline(f' vmdcon "FAIL {self.attribute} has unexpected count $count of value {target} (expected {target_count}) in selection {self.selection}"')
        vt.addline('} else {')
        vt.addline(f' vmdcon "PASS {self.attribute} has expected count $count of value {target} (expected {target_count}) in selection {self.selection}"')
        vt.addline('}')
        # Release the selection so repeated tests do not accumulate them.
        vt.addline('$test_selection delete')
[docs]
class ConnectionTest:
"""
This class represents a test for expected number of bonds of a specific type among an atomselection.
Attributes
----------
name : str
The name of the test.
selection : str
The atom selection string.
connection_type : str
The type of connection to test (e.g., "interresidue").
connection_count : int
The expected number of connections.
"""
connection_type_supported = {'interresidue', 'disulfide', 'glycosylation'}
def __init__(self, name: str, selection: str, connection_type: str, connection_count: int = 1):
self.name = name
self.selection = selection
self.connection_type = connection_type
if not self.connection_type in self.connection_type_supported:
raise NotImplementedError(f'Unsupported connection type: {self.connection_type}')
self.connection_count = connection_count
[docs]
def write(self, vt: VMDScripter):
vt.addline(f'set test_selection [atomselect top "{self.selection}"]')
match self.connection_type:
case 'interresidue':
vt.addline(f'set indexes [$test_selection get index]')
vt.addline(f'set atomnames [$test_selection get name]')
vt.addline(f'set resids [$test_selection get residue]')
vt.addline(f'set bondlists [$test_selection getbonds]')
vt.addline(f'set count 0')
vt.addline(f'foreach index $indexes atomname $atomnames resid $resids bondlist $bondlists {{')
vt.addline(f' foreach partner $bondlist {{')
vt.addline(f' if {{$partner < $index}} {{ continue }}') # no double-counting!
vt.addline(f' set idx [lsearch $indexes $partner]')
vt.addline(f' set partner_resid [lindex $resids $idx]')
vt.addline(f' if {{$partner_resid != "" && $partner_resid != $resid}} {{')
vt.addline(f' set partnersel [atomselect top "index $partner"]')
vt.addline(f' set partner_name [$partnersel get name]')
vt.addline(f' vmdcon "BOND $index $atomname $resid <--> $partner $partner_name $partner_resid"')
vt.addline(f' incr count')
vt.addline(f' }}')
vt.addline(f' }}')
vt.addline(f'}}')
vt.addline(f'if {{$count != {self.connection_count}}} {{')
vt.addline(f' vmdcon "FAIL {self.name} has unexpected count $count (expected {self.connection_count}) in selection {self.selection} for interresidue bonds"')
vt.addline(f'}} else {{')
vt.addline(f' vmdcon "PASS {self.name} has expected count $count (expected {self.connection_count}) in selection {self.selection}"')
vt.addline(f'}}')
case 'disulfide':
vt.addline(f'set indexes [$test_selection get index]')
vt.addline(f'set atomnames [$test_selection get name]')
vt.addline(f'set resids [$test_selection get residue]')
vt.addline(f'set bondlists [$test_selection getbonds]')
vt.addline(f'set count 0')
vt.addline(f'foreach index $indexes atomname $atomnames resid $resids bondlist $bondlists {{')
vt.addline(f' foreach partner $bondlist {{')
vt.addline(f' if {{$partner < $index}} {{ continue }}') # no double-counting!
vt.addline(f' set idx [lsearch $indexes $partner]')
vt.addline(f' set partner_resid [lindex $resids $idx]')
vt.addline(f' if {{$partner_resid != "" && $partner_resid != $resid}} {{')
vt.addline(f' set partnersel [atomselect top "index $partner"]')
vt.addline(f' set partner_name [$partnersel get name]')
vt.addline(f' vmdcon "BOND $index $atomname $resid <--> $partner $partner_name $partner_resid"')
vt.addline(f' if {{$atomname == "SG" && $partner_name == "SG"}} {{')
vt.addline(f' incr count')
vt.addline(f' }}')
vt.addline(f' }}')
vt.addline(f' }}')
vt.addline(f'}}')
vt.addline(f'if {{$count != {self.connection_count}}} {{')
vt.addline(f' vmdcon "FAIL {self.name} has unexpected count $count (expected {self.connection_count}) in selection {self.selection} for interresidue bonds"')
vt.addline(f'}} else {{')
vt.addline(f' vmdcon "PASS {self.name} has expected count $count (expected {self.connection_count}) in selection {self.selection}"')
vt.addline(f'}}')
case 'glycosylation':
vt.addline(f'set indexes [$test_selection get index]')
vt.addline(f'set resids [$test_selection get residue]')
vt.addline(f'set bondlists [$test_selection getbonds]')
vt.addline(f'set external_partners [list]')
vt.addline(f'foreach index $indexes resid $resids bondlist $bondlists {{')
vt.addline(f' foreach partner $bondlist {{')
vt.addline(f' set idx [lsearch $indexes $partner]')
vt.addline(f' if {{$idx == -1}} {{')
vt.addline(f' lappend external_partners $partner')
vt.addline(f' }}')
vt.addline(f' }}')
vt.addline(f'}}')
vt.addline(f'set extsel [atomselect top "index $external_partners"]')
vt.addline(f'set names [$extsel get name]')
vt.addline(f'set count [llength [lsearch -nocase -exact -all $names C1]]')
vt.addline(f'if {{$count != {self.connection_count}}} {{')
vt.addline(f' vmdcon "FAIL Selection {self.selection} has no atoms bound to an external C1 atom"')
vt.addline(f'}} else {{')
vt.addline(f' vmdcon "PASS Selection {self.selection} has $count atoms bound to an external C1 atom"')
vt.addline(f'}}')
case '_':
logger.debug(f'Unsupported connection type: {self.connection_type}')
[docs]
class ResidueTest:
"""
This class represents a test for expected properties of a specific residue in an atom selection.
"""
def __init__(self, name: str, selection: str, measure: str, value: int, relation: str = '=='):
self.name = name
self.selection = selection
self.measure = measure
self.relation = relation
self.value = value
[docs]
def write(self, vt: VMDScripter):
vt.addline(f'set test_selection [atomselect top "{self.selection}"]')
pass_msg = f'PASS {self.name} has expected relation {self.relation} to count $count (expected {self.value}) in selection {self.selection}'
fail_msg = f'FAIL {self.name} does not have expected relation {self.relation} to count $count (expected {self.value}) in selection {self.selection}'
match self.measure:
case 'atom_count':
vt.addline(f'set count [$test_selection num]')
vt.addline(f'if {{$count {self.relation} {self.value}}} {{')
vt.addline(f' vmdcon "{pass_msg}"')
vt.addline(f'}} else {{')
vt.addline(f' vmdcon "{fail_msg}"')
vt.addline(f'}}')
case 'residue_count':
vt.addline(f'set resids [$test_selection get residue]')
vt.addline(f'set unique_resids [lsort -unique $resids]')
vt.addline(f'set count [llength $unique_resids]')
vt.addline(f'if {{$count {self.relation} {self.value}}} {{')
vt.addline(f' vmdcon "{pass_msg}"')
vt.addline(f'}} else {{')
vt.addline(f' vmdcon "{fail_msg}"')
vt.addline(f'}}')
case '_':
logger.debug(f'Unsupported measure type: {self.measure}')
[docs]
class ValidateTask(VMDTask):
"""
This class represents a validation task for PSF/PDB files. A validate
task can be inserted anywhere in the workflow to perform validation checks.
"""
_yaml_header: ClassVar[str] = 'validate'
[docs]
def provision(self, packet: dict = {}):
super().provision(packet)
self.test_specs = self.specs.get('tests', [])
self.attribute_tests: list[AttributeTest] = []
self.connection_tests: list[ConnectionTest] = []
self.residue_tests: list[ResidueTest] = []
for test in self.test_specs:
logger.debug(f'Processing test specification: {test}')
assert isinstance(test, dict), f"Test specification {test} must be a dictionary"
assert len(test) == 1, f"Test specification {test} must have a single key-value pair"
test_type = list(test.keys())[0]
match test_type:
case 'attribute_test':
specs = test['attribute_test']
assert isinstance(specs, dict), f"Attribute test specification {specs} must be a dictionary"
self.attribute_tests.append(AttributeTest(**specs))
case 'connection_test':
specs = test['connection_test']
assert isinstance(specs, dict), f"Connection test specification {specs} must be a dictionary"
self.connection_tests.append(ConnectionTest(**specs))
case 'residue_test':
specs = test['residue_test']
assert isinstance(specs, dict), f"Residue test specification {specs} must be a dictionary"
self.residue_tests.append(ResidueTest(**specs))
logger.debug(f'Provisioned {len(self.attribute_tests)} Attribute tests: {self.attribute_tests}')
logger.debug(f'Provisioned {len(self.connection_tests)} Connection tests: {self.connection_tests}')
logger.debug(f'Provisioned {len(self.residue_tests)} Residue tests: {self.residue_tests}')
[docs]
def do(self):
"""
Execute the validation task.
"""
state: StateArtifacts = self.get_current_artifact('state')
psf: Path = state.psf
pdb: Path = state.pdb
vt: VMDScripter = self.get_scripter('vmd')
self.next_basename()
vt.newscript(basename=self.basename)
vt.load_psf_pdb(psf.name, pdb.name)
for attribute_test in self.attribute_tests:
attribute_test.write(vt)
for connection_test in self.connection_tests:
connection_test.write(vt)
for residue_test in self.residue_tests:
residue_test.write(vt)
# here we insert the commands to extract values of desired variables
vt.writescript()
self.register(self.basename, key = 'tcl', artifact_type=VMDScriptArtifact)
vt.runscript()
log_artifact = self.register(self.basename, key='log', artifact_type=VMDLogFileArtifact)
self.log = VMDLogParser.from_file(log_artifact.name)
results = self.log.collect_validation_results()
npass = 0
nfail = 0
if not results:
logger.debug(f'Empty validation results')
self.extra_message = "No validation results found."
else:
npass = sum(1 for r in results if 'PASS' in r)
nfail = sum(1 for r in results if 'FAIL' in r)
self.register(dict(npass=npass, nfail=nfail), key='validation_results')
logger.debug(f'Validation results: \x1b[32m\x1b[1m{npass} passing\x1b[0m, \x1b[31m\x1b[1m{nfail} failing\x1b[0m>')
self.extra_message = f"\x1b[32m\x1b[1mpass: {npass}\x1b[0m, \x1b[31m\x1b[1mfail: {nfail}\x1b[0m"
# here we would parse the resulting log file
return nfail