diff --git a/plotid/publish.py b/plotid/publish.py index 1c84f999b664bf98424f4717f022a9449fd6b7d6..67e9b94fac9779f43386e06963df3aef0fb6a097 100644 --- a/plotid/publish.py +++ b/plotid/publish.py @@ -8,6 +8,7 @@ the plot is based on. Additionally, the script that produced the plot will be copied to the destination directory. """ +import ast import os import shutil import sys @@ -239,6 +240,53 @@ class PublishOptions: os.path.join(destination, final_file_path), ) + def export_imports(self, file: str) -> None: + """Export all imported modules in of a python script to file.""" + with open(file, "r") as source: + tree = ast.parse(source.read()) + + analyzer = Analyzer() + analyzer.visit(tree) + analyzer.report() + + +class Analyzer(ast.NodeVisitor): + """Visit and analyze nodes of Abstract Syntax Trees (AST).""" + + def __init__(self) -> None: + self.stats = {"import": [], "from_module": [], "from": []} + + def visit_Import(self, node) -> None: + """Get modules that are imported with the 'import' statement.""" + for alias in node.names: + self.stats["import"].append(alias.name) + self.generic_visit(node) + + def visit_ImportFrom(self, node) -> None: + """Get modules that are imported with the 'from X import Y' statement.""" + for alias in node.names: + self.stats["from"].append(node.module) + self.generic_visit(node) + + def report(self) -> None: + """Create summary of imported modules.""" + # Create set of imported modules + imports_as_set = set( + # Save the first part of import statement since it references the installed + # module. + [module.split(".", 1)[0] for module in self.stats["import"]] + ) + imports_as_set.update( + # Add modules imported with "from X import Y". + [module.split(".", 1)[0] for module in self.stats["from"]] + ) + + # Write every item of the set to one line. + with open("required_imports.txt", "w") as output: + for item in imports_as_set: + output.write(f"{item}\n") + output.close() + kwargs_types_publish = TypedDict( "kwargs_types_publish", @@ -288,3 +336,4 @@ def publish( publish_container = PublishOptions(figs_and_ids, src_datapath, dst_path, **kwargs) publish_container.validate_input() publish_container.export() + publish_container.export_imports(sys.argv[0])