# Author: Cameron F. Abrams, <cfa22@drexel.edu>
"""
Defines the :class:`ByteCollector` and :class:`FileCollector` classes, and defines the :func:`banner` and :func:`my_logger` functions.
"""
import ast
import logging
import operator
import os
import re
import subprocess
import shutil
import sys
import pandas as pd
from argparse import Namespace
from collections import UserList
from io import StringIO
from pathlib import Path
from typing import Callable
logger = logging.getLogger(__name__)
import importlib.metadata
__pestifer_version__ = importlib.metadata.version("pestifer")
_banner_message="""
Pestifer v. {pestifer_version}
https://pestifer.readthedocs.io/en/latest/
Cameron F. Abrams <cfa22@drexel.edu>
Supported in part by Grants GM100472, AI154071,
and AI178833 from the NIH
CHARMM force field files ({charmmff_version}) from the
MacKerell Lab
"""
_enhanced_banner_message="""
░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░
░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░
░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░
░░░░░░░░░░░░░░░░░░░░░░░░░░░░░▒███████████▓░░░░░░░░░░░░░░░░░░░░░░░░░░░░
░░░░░░░░░░░░░░░░░░░░░░░░░░░█████████████████▒░░░░░░░░░░░░░░░░░░░░░░░░░
░░░░░░░░░░░░░░░░░░░░░░░░░▓████████████████████░░░░░░░░░░░░░░░░░░░░░░░░
░░░░░░░░░░░░░░░░░░░░░░░░▓▓▓████████▒████████▓▓█░░░░░░░░░░░░░░░░░░░░░░░
░░░░░░░░░░░░░░░░░░░░░░░▓▒▒▒░▒████▒▒█▓░████▓░░▒▒█░░░░░░░░░░░░░░░░░░░░░░
░░░░░░░░░░░░░░░░░░░░░░▒█▒░▒██▓▓▓██▒░░███▓▓███░▒█▓░░░░░░░░░░░░░░░░░░░░░
░░░░░░░░░░░░░░░░░░░░░░▓▒░░█▓░░░░░░█▓▓▓░░░░░░▓▓░▒█░░░░░░░░░░░░░░░░░░░░░
░░░░░░░░░░░░░░░░░░░░░░▓▒░█▒░░░░░░░███▒░░░░░░░█░▒▓░░░░░░░░░░░░░░░░░░░░░
░░░░░░░░░░░░░░░░░░░░░░▒░░█▒░░░░░▒█▒░░██░░░░░░█░░▓░░░░░░░░░░░░░░░░░░░░░
░░░░░░░░░░░░░░░░░░░░░░░░░██████▒▓▒░░░░█▒██████░░░░░░░░░░░░░░░░░░░░░░░░
░░░░░░░░░░░░░░░░░░░░░░░█▓▒▓██░▒██░░▒░░███░▓█▓░▓█░░░░░░░░░░░░░░░░░░░░░░
░░░░░░░░░░░░░░░░░░░░░░░░▒░░░░░░▒███▓████░░░░░░▒░░░░░░░░░░░░░░░░░░░░░░░
░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░▒░▒░▒░▒░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░
░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░▒░▒░▒░▒░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░
░░░░░░░░░█████▒░▒█████▒░▓█████░▓██████░██▒▒█████░██████░██████░░░░░░░░
░░░░░░░░░██░░██▒░██░░█▒▒█▓░░▒█░█░░██░▓░██░░██░░▓▒▓█▒░░█░██▒░▒██░░░░░░░
░░░░░░░░░██░░▒██░██░░░▒▒██░░░▒▒░░░██░░░██░░██░░░▒▓█▒░░░░██▒░░██▒░░░░░░
░░░░░░░░░██░░▒██░██░▒▓░░▓███░░░░░░██░░░██░░██░▒█░▓█▒░▒▓░██▒░░██░░░░░░░
░░░░░░░░░██░░██▒░████▓░░░░▓███░░░░██░░░██░░██▓██░▓████▓░██▒▒██▒░░░░░░░
░░░░░░░░░██▓██░░▒██░░▓░▓░░░░▒██░░░██░░░██░░██░░▓░▓█▒░░█░█████░░░░░░░░░
░░░░░░░░░██▒░░░░▒██░░░░▒█░░░░▓█▒░░██░░░██░░██░░░░▓█▒░░░░██▒▓█▒░░░░░░░░
░░░░░░░░░██░░░░░▒█▓░░▓▓░██░░▓██░░░██░░░██░░██░░░░▓█▒░░▒▒██▒▒██░░░░░░░░
░░░░░░░░░██░░░░░▒██▒▓█▒░▒████▒░░░▓▓▓▒░▒▓▓▒▓██▒░░░██▓▒▓█░▓█▒░██▒░░░░░░░
░░░░░░░░░██░░░░░▓███▓▒░░░░░░░░░░░░░░░░░░░░░░░░░░░░▒▓███░██▒░▒█▓░░░░░░░
░░░░░░░░░▓█▒░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░▒▓█░░██▒░░░░░░
░░░░░░░░▓▓▒░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░▓█░░░░░░
░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░
░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░
░░░░░░░░░░░░░ https://pestifer.readthedocs.io/en/latest/ ░░░░░░░░░░░░░
░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░
░░░░░░░░░░░░░░░░░ Cameron F. Abrams, cfa22@drexel.edu ░░░░░░░░░░░░░░░░
░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░
"""
def _discover_charmmff_version() -> str:
"""Return a human-readable string describing the available CHARMMFF version(s)."""
try:
import datetime
from .. import resources as _resources
charmmff_dir = Path(os.path.dirname(_resources.__file__)) / 'charmmff'
_month_num = {
'jan': 1, 'feb': 2, 'mar': 3, 'apr': 4, 'may': 5, 'jun': 6,
'jul': 7, 'aug': 8, 'sep': 9, 'oct': 10, 'nov': 11, 'dec': 12,
}
def _key_to_date(d: Path) -> datetime.date:
mon = _month_num.get(d.name[:3], 0)
yr = 2000 + int(d.name[3:])
return datetime.date(yr, mon, 1)
version_dirs = sorted(
[d for d in charmmff_dir.iterdir() if d.is_dir() and re.match(r'^[a-z]{3}\d{2}$', d.name)],
key=_key_to_date
)
if not version_dirs:
return 'unknown'
# Convert keys like 'jul24' -> 'July 2024', 'feb26' -> 'Feb 2026'
month_map = {
'jan': 'January', 'feb': 'February', 'mar': 'March', 'apr': 'April',
'may': 'May', 'jun': 'June', 'jul': 'July', 'aug': 'August',
'sep': 'September', 'oct': 'October', 'nov': 'November', 'dec': 'December',
}
def key_to_label(key: str) -> str:
mon = month_map.get(key[:3], key[:3].capitalize())
yr = '20' + key[3:]
return f'{mon} {yr}'
labels = [key_to_label(d.name) for d in version_dirs]
return ', '.join(labels)
except Exception:
return 'unknown'
[docs]
def banner(logf: Callable, args: Namespace):
"""
Writes a banner message to the log file
Parameters
----------
logf: file-like object
The log file to which the banner message will be written.
"""
if args.banner:
if args.kick_ass:
my_logger(_enhanced_banner_message, logf, fill=' ', just='<')
else:
charmmff_version = _discover_charmmff_version()
msg = _banner_message.format(
pestifer_version=__pestifer_version__,
charmmff_version=charmmff_version,
)
my_logger(msg, logf, fill=' ', just='<')
[docs]
class ByteCollector:
"""
A simple string manager.
Parameters
----------
comment_char: str, optional
The character used to denote comments in the string.
Defaults to '#'.
line_length: int, optional
The maximum length of a line in the string.
Defaults to 80.
"""
def __init__(self, comment_char='#', line_length=80):
self.line_length = line_length
self.comment_char = comment_char
self.byte_collector = ''
[docs]
def reset(self):
"""
Resets the string
"""
self.byte_collector = ''
[docs]
def write(self, msg):
"""
Appends msg to the string
Parameters
----------
msg: str
the message
"""
self.byte_collector += msg
[docs]
def addline(self, msg, end='\n'):
"""Appends msg to the string as a line
Parameters
----------
msg: str
the message
end: str, optional
the end-of-line byte
"""
self.byte_collector += f'{msg}{end}'
[docs]
def lastline(self, end='\n', exclude='#'):
"""
Returns last line in the string
Parameters
----------
end: str, optional
end-of-line byte
exclude: str, optional
comment byte
"""
lines = [x for x in self.byte_collector.split(end) if (len(x) > 0 and not x.startswith(exclude))]
if len(lines) > 0:
return lines[-1]
else:
return None
[docs]
def has_statement(self, statement: str, end='\n', exclude='#'):
"""
Determines if a particular statement is on at least one non-comment line
Parameters
----------
statement: str
the statement; e.g., 'exit'
end: str, optional
end-of-line byte
exclude: str, optional
comment byte
"""
lines = [x for x in self.byte_collector.split(end) if (len(x) > 0 and not x.startswith(exclude))]
if len(lines) > 0:
for l in lines:
if statement in l:
return True
return False
[docs]
def reassign(self, varname: str, varvalue: str, style: str = 'bash', end: str = '\n'):
"""
Reassigns variable varname to value varvalue if an assignment to varname already
exists in the byte_collector
Parameters
----------
varname: str
the name of the variable to reassign
varvalue: str
the value to assign to the variable
style: str, optional
the style of the assignment (e.g., ``bash``, 'SLURM')
end: str, optional
end-of-line byte
"""
lines = self.byte_collector.split(end)
logger.debug(lines)
if len(lines) > 0:
for i, l in enumerate(lines):
if len(l) > 0:
if style == 'bash':
if l.startswith('export'):
logger.debug(f'{varname} {varvalue} : export assignment {l}')
tokens = l.split()
assignment = tokens[1]
evarname, evarvalue = assignment.split('=')
if evarname == varname:
lines[i] = f'export {varname}={varvalue}'
elif l.startswith(varname) and '=' in l:
lines[i] = f'{varname}={varvalue}'
elif style == 'SLURM': # will only replace SLURM variables!
if l.startswith('#SBATCH'):
tokens = l.split()
logger.debug(f'SLURM statement {l}')
if tokens[1].startswith('--'):
assignment = tokens[1][2:]
if not '=' in assignment:
logger.debug(f'cannot parse SLURM directive {l}')
else:
svarname, svarvalue = assignment.split('=')
if svarname == varname:
lines[i] = f'#SBATCH --{svarname}={varvalue}'
elif tokens[1].startswith('-'):
svarname = tokens[1][1:]
svarvalue = tokens[2]
if svarname == varname:
lines[i] = f'#SBATCH -{svarname} {varvalue}'
self.byte_collector = end.join(lines)
logger.debug(self.byte_collector)
[docs]
def ingest_file(self, filename: Path | str):
"""
Appends contents of file ``filename`` to the string
Parameters
----------
filename: str
the name of the file
"""
with open(filename, 'r') as f:
self.byte_collector += f.read()
[docs]
def write_file(self, filename: Path | str):
"""
Writes string to ``filename``
Parameters
----------
filename: str
the name of the file
"""
with open(filename, 'w') as f:
f.write(self.byte_collector)
[docs]
def log(self, msg: str):
"""
Logs a message to the string
Parameters
----------
msg: str
the message to log"""
my_logger(msg, self.addline)
[docs]
def banner(self, msg: str):
"""
Logs a banner message to the string
Parameters
----------
msg: str
the message to log
"""
my_logger(msg, self.addline, fill='#', width=80, just='^')
def __str__(self):
"""
Returns the string representation of the byte collector
"""
return self.byte_collector
[docs]
def my_logger(msg: str | list | dict | pd.DataFrame, logf: Callable, width=None, fill='', just='<', frame='', depth=0, **kwargs):
"""
A fancy recursive logger
Parameters
----------
msg: str, list
the message to be logged, either as a single string, list, or dict
logf: function
writer; e.g., print, f.write, etc.
width: int, optional
linelength in bytes
fill: str, optional
single character used to fill blank spaces
sep: str, optional
single character used in join calls
just: str, optional
format character
"""
if width is None:
if logf is print:
ts = shutil.get_terminal_size((80, 20))
width = ts.columns
else:
width = 67
fmt = r'{'+r':'+fill+just+f'{width}'+r'}'
ll = ' ' if just in '^>' else ''
rr = ' ' if just in '^<' else ''
if frame:
ffmt = r'{'+r':'+frame+just+f'{width}'+r'}'
logf(ffmt.format(frame))
# logger.debug(f'Logging message of type {type(msg)} at depth {depth}; is list? {isinstance(msg, list)}; is dict? {isinstance(msg, dict)}; is DataFrame? {isinstance(msg, pd.DataFrame)}')
if isinstance(msg, list):
for tok in msg:
my_logger(tok, logf, width=width, fill=fill, just=just, frame=False, depth=depth, kwargs=kwargs)
elif isinstance(msg, dict):
# logger.debug(f'logging dict of length {len(msg)}')
for key, value in msg.items():
if isinstance(value, str) or not hasattr(value, "__len__"):
my_logger(f'{key}: {value}', logf, width=width, fill=fill, just=just, frame=False, depth=depth, kwargs=kwargs)
elif isinstance(value, list):
# logger.debug(f'key {key} is a list of length {len(value)}')
for tok in value:
my_logger(tok, logf, width=width, fill=fill, just=just, frame=False, depth=depth+1, kwargs=kwargs)
else:
my_logger(f'{key}:', logf, width=width, fill=fill, just=just, frame=False, depth=depth, kwargs=kwargs)
my_logger(value, logf, width=width, fill=fill, just=just, frame=False, depth=depth+1, kwargs=kwargs)
elif isinstance(msg, pd.DataFrame):
dfoutmode = kwargs.get('dfoutmode', 'value')
if dfoutmode == 'value':
my_logger([ll+x+rr for x in msg.to_string().split('\n')], logf, width=width, fill=fill, just=just, frame=False, depth=depth, kwargs=kwargs)
elif dfoutmode == 'info':
buf = StringIO()
msg.info(buf=buf)
my_logger([ll+x+rr for x in buf.getvalue().split('\n')], logf, width=width, fill=fill, just=just, frame=False, depth=depth, kwargs=kwargs)
else:
return
else:
indent = f'{" "*depth*2}' if just == '<' and not kwargs.get('no_indent', False) else ''
if isinstance(msg, str):
lns = msg.split('\n')
if len(lns) > 1:
my_logger(lns, logf, width=width, fill=fill, just=just, frame=False, depth=depth+1, kwargs=kwargs)
else:
outstr = indent + ll + f'{msg}' + rr
logf(fmt.format(outstr))
else:
outstr = indent + ll + f'{msg}' + rr
logf(fmt.format(outstr))
if frame:
logf(ffmt.format(frame))
[docs]
class FileCollector(UserList):
"""
A class for handling collections of files; inherits from :class:`~collections.UserList`.
"""
[docs]
def flush(self):
"""
Deletes all files in the collection and clears the collection
"""
logger.debug(f'Flushing file collector: {len(self)} files.')
for f in self:
if os.path.exists(f):
os.remove(f)
else:
logger.debug(f'{f}: not found.')
self.clear()
[docs]
def tarball(self, basename, preferred_extension='tgz'):
"""
Makes a tarball of the files in the collection
Parameters
----------
basename: str
basename of the resulting tarball
"""
filelist = ' '.join([x for x in self])
subprocess.run(f'tar zvcf {basename}.{preferred_extension} {filelist}',
shell=True,
executable='/bin/bash',
check=True,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE)
logger.debug(f'generated tarball {basename}.{preferred_extension}')
[docs]
def oxford(a_list: list[str], conjunction: str = 'or') -> str | None:
"""
Obey the Oxford comma when ending a list with a conjunction
Parameters
----------
a_list: list of str
The list of items to join.
conjunction: str, optional
The conjunction to use when joining the list. Defaults to 'or'.
Returns
-------
str: The Oxford-compliant string representation of the list.
"""
if not a_list: return None
if len(a_list) == 1:
return a_list[0]
elif len(a_list) == 2:
return f'{a_list[0]} {conjunction} {a_list[1]}'
else:
return ", ".join(a_list[:-1]) + f', {conjunction} {a_list[-1]}'
[docs]
def linesplit(line: str, cchar: str = '!') -> tuple[str, str]:
"""
Splits a line at the first occurrence of the character ``cchar``.
"""
if not cchar in line:
return line, ''
idx = line.index(cchar)
if idx == 0:
return '', line[1:]
return line[:idx], line[idx+1:]
[docs]
def striplist(L: list[str]) -> list[str]:
"""
Strips whitespace from each item in a list and removes empty strings.
Parameters
----------
L: list of str
The list of strings to be stripped.
Returns
-------
list of str: A new list with whitespace stripped from each item and empty strings removed.
"""
l = [x.strip() for x in L]
while '' in l:
l.remove('')
return l
[docs]
def to_latex_math(s: str) -> str:
"""
Converts a string into LaTeX math format. Written by ChatGPT 4o.
"""
def convert_token(token):
# Match base, subscript, and superscript using regex
match = re.fullmatch(r'\s*([a-zA-Z0-9]+)(?:_([a-zA-Z0-9^]+))?(?:\^([a-zA-Z0-9_]+))?\s*', token)
if not match:
return token # return as-is if it doesn't match expected pattern
base, sub, sup = match.groups()
result = base
if sub:
if '^' in sub:
# Handle cases like a_x^2 packed into sub
sub, sup = sub.split('^', 1)
result += f'$_{{{sub}}}$'
if sup:
result += f'$^{{{sup}}}$'
return result
parts = s.split(',')
converted = [convert_token(part) for part in parts]
return ', '.join(converted)
[docs]
def raise_clean(ErrorInstance: Exception):
"""
Raises an error with a clean message showing no traceback.
Parameters
----------
ErrorInstance: Exception instance
The exception instance to raise.
"""
try:
raise ErrorInstance
except ErrorInstance.__class__ as e:
print(f"Error: {e}", file=sys.stderr)
sys.exit(1)
# def special_update(dict, key, val, mode='strict'):
# """
# Special update function for dictionaries with different modes.
# Parameters
# ----------
# dict: dict
# The dictionary to update.
# key: str
# The key to update.
# val: any
# The value to set.
# mode: str, optional
# The mode of the update. Can be 'strict' or 'permissive'. Defaults to 'strict'.
# Raises
# ------
# KeyError
# If the mode is 'strict' and the key already exists.
# """
# if mode == 'strict' and key in dict:
# raise KeyError(f'Key {key} already exists in dictionary.')
# dict[key] = val
[docs]
def plu(n: int, singform: str = '', plurform: str = 's') -> str:
"""
Returns the singular or plural form of a word based on the count.
Parameters
----------
n: int
The count to determine singular or plural.
singform: str, optional
The singular form of the word/word ending. Defaults to an empty string.
plurform: str, optional
The plural form of the word/word ending. Defaults to 's'.
Returns
-------
str: The appropriate form of the word based on the count.
Examples
--------
>>> for ncats in [1, 2]:
... print(f'There {plu(ncats, "is", "are")} {ncats} cat{plu(ncats)}.')
There is 1 cat.
There are 2 cats.
>>> for nberries in [1, 2]:
... print(f'There {plu(nberries, "is", "are")} {nberries} berr{plu(nberries, "y", "ies")}.')
There is 1 berry.
There are 2 berries.
>>> for noctopi in [1, 2]:
... print(f'There {plu(noctopi, "is", "are")} {noctopi} octop{plu(noctopi, "us", "i")}.')
There is 1 octopus.
There are 2 octopi.
"""
return singform if n == 1 else plurform
# Supported operators for comparisons
comp_ops = {
ast.Eq: operator.eq,
ast.NotEq: operator.ne,
ast.Lt: operator.lt,
ast.LtE: operator.le,
ast.Gt: operator.gt,
ast.GtE: operator.ge,
ast.In: lambda a, b: a in b,
ast.NotIn: lambda a, b: a not in b,
}
# Supported boolean ops
bool_ops = {
ast.And: all,
ast.Or: any,
}
[docs]
def parse_filter_expression(expr: str):
"""
Parses a logical expression string (e.g., "segtype == 'protein' and resid > 100")
and returns a function f(obj) -> bool that evaluates the expression on obj's attributes.
"""
expr_ast = ast.parse(expr, mode='eval')
def _eval(node, obj):
if isinstance(node, ast.Expression):
return _eval(node.body, obj)
elif isinstance(node, ast.BoolOp):
op_func = bool_ops[type(node.op)]
return op_func(_eval(value, obj) for value in node.values)
elif isinstance(node, ast.Compare):
# Left side can be Name or Attribute, only support simple Names for now
if isinstance(node.left, ast.Name):
left_val = getattr(obj, node.left.id, None)
else:
raise ValueError(f"Unsupported left operand type: {type(node.left)}")
# Only support single comparator for simplicity
op = node.ops[0]
comparator = node.comparators[0]
# Evaluate right side literals
if isinstance(comparator, (ast.Constant)):
right_val = comparator.value if hasattr(comparator, 'value') else comparator.s
elif isinstance(comparator, ast.Name):
right_val = getattr(obj, comparator.id)
else:
# For lists or tuples, recursively evaluate literals
right_val = ast.literal_eval(comparator)
op_func = comp_ops.get(type(op))
if op_func is None:
raise ValueError(f"Unsupported comparison operator: {type(op)}")
return op_func(left_val, right_val)
elif isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.Not):
return not _eval(node.operand, obj)
elif isinstance(node, ast.Constant):
# True / False / None literals
return node.value
else:
raise ValueError(f"Unsupported AST node: {type(node)}")
def filter_func(obj):
return _eval(expr_ast, obj)
return filter_func