# Source code for htpolynet.utils.checkpoint

"""Implements a simple checkpointing scheme using a wrapper.

Author: Cameron F. Abrams <cfa22@drexel.edu>
"""
import functools
import logging
import os

import yaml

# Module-level logger, named after this module.
logger: logging.Logger = logging.getLogger(__name__)

class Checkpoint:
    """In-memory record of checkpointed method calls, persisted as YAML.

    Attributes:
        default_filename: name of the checkpoint file; it is opened relative
            to the current working directory when reading or writing.
        my_abspath: absolute path of the working directory at the time the
            instance was created or read back in.
        calls: list of (method name, relative working directory) entries
            already executed.
        results: dict accumulating the results of checkpointed calls.
        narrative: list of human-readable strings, one per recorded call.
    """
    default_filename = 'checkpoint_state.yaml'

    def __init__(self, input_dict=None):
        """Create a checkpoint, optionally pre-populated from a dict.

        Args:
            input_dict: optional dict with any of the keys 'calls',
                'results' and 'narrative'; missing keys default to empty
                containers. ``None`` (the default) means an empty checkpoint.
        """
        # avoid a shared mutable default argument
        if input_dict is None:
            input_dict = {}
        self.my_abspath = os.getcwd()
        self.calls: list = input_dict.get('calls', [])
        self.results: dict = input_dict.get('results', {})
        self.narrative: list = input_dict.get('narrative', [])

    def to_yaml(self):
        """Write this instance as YAML to ``default_filename`` in the current directory."""
        with open(self.default_filename, 'w', encoding='utf-8') as f:
            f.write(yaml.dump(self))

    @classmethod
    def from_yaml(cls):
        """Read a checkpoint from ``default_filename`` in the current directory.

        Returns:
            Checkpoint: the instance stored in the file. A fresh, empty
            instance is returned if the file does not exist or is empty. If
            the file holds a plain mapping, it is used as the ``input_dict``
            of a new instance. In every case ``my_abspath`` is reset to the
            current working directory.
        """
        try:
            with open(cls.default_filename, 'r', encoding='utf-8') as f:
                yaml_string = f.read()
        except FileNotFoundError:
            return cls()
        # NOTE: yaml.Loader can construct arbitrary Python objects. It is used
        # because to_yaml() serializes the Checkpoint object itself, so only
        # read checkpoint files you trust.
        inst = yaml.load(yaml_string, Loader=yaml.Loader)
        if inst is None:
            # empty file: nothing has been checkpointed yet
            return cls()
        if isinstance(inst, dict):
            inst = cls(inst)
        # overwrite the absolute path in the file upon reading in
        inst.my_abspath = os.getcwd()
        return inst
# Process-wide checkpoint state; read_checkpoint() rebinds this global.
_CP_: Checkpoint = Checkpoint()
def enableCheckpoint(method):
    """Wrap a method so that every call is recorded in the checkpoint file.

    A call is identified by the pair (method name, current working directory
    relative to the checkpoint's ``my_abspath``). If that pair is already in
    the checkpoint's ``calls`` history, the method is NOT called again and the
    wrapper returns ``None``. Otherwise the method is called, its result is
    merged into the checkpoint's ``results``, the call is appended to
    ``calls`` and ``narrative``, and the checkpoint file is rewritten.

    Args:
        method: the function object to be wrapped; it is called as
            ``method(self, *args, **kwargs)`` and must return a dict
            (``None`` is treated as an empty dict).

    Returns:
        method: wrapped method
    """
    @functools.wraps(method)
    def wrapper_method(self, *args, **kwargs):
        # working directory expressed relative to the checkpoint root
        mywd = os.path.relpath(os.getcwd(), _CP_.my_abspath)
        call_key = (method.__name__, mywd)
        # entries read back from a file may be lists rather than tuples
        if any(tuple(entry) == call_key for entry in _CP_.calls):
            logger.info('Skipping %s in %s', method.__name__, mywd)
            return None
        result = method(self, *args, **kwargs)
        # merge results BEFORE registering the call, so a bad (non-dict)
        # result raises without leaving a half-registered call behind
        _CP_.results.update(result if result is not None else {})
        _CP_.calls.append(call_key)
        _CP_.narrative.append(f'Method {method.__name__} called in {os.path.join(_CP_.my_abspath,mywd)} gave result {result}')
        # persist the updated history
        _write_checkpoint()
        return result
    return wrapper_method
def _write_checkpoint():
    """Write the checkpoint file in its globally resolved location.

    The file is written into ``_CP_.my_abspath``. The caller's working
    directory is restored afterwards, even if writing fails.
    """
    sv = os.getcwd()
    os.chdir(_CP_.my_abspath)
    try:
        _CP_.to_yaml()
    finally:
        # never leave the process stranded in the checkpoint directory
        os.chdir(sv)
[docs] def read_checkpoint(): """Creates a new global Checkpoint object by reading from the default file. Returns: dict: current results dictionary with any pathnames resolved as absolute """ global _CP_ _CP_=Checkpoint.from_yaml() if len(_CP_.calls)>0: lwd=os.path.join(_CP_.my_abspath,_CP_.calls[-1][1]) return {c:os.path.join(lwd,x) for c,x in _CP_.results.items()}