diff --git a/plotid/publish.py b/plotid/publish.py index 67e9b94fac9779f43386e06963df3aef0fb6a097..85252ea087e31d44b3996f1e17cc638fd6e458a4 100644 --- a/plotid/publish.py +++ b/plotid/publish.py @@ -242,47 +242,49 @@ class PublishOptions: def export_imports(self, file: str) -> None: """Export all imported modules in of a python script to file.""" - with open(file, "r") as source: + with open(file, "r", encoding="utf-8") as source: tree = ast.parse(source.read()) analyzer = Analyzer() analyzer.visit(tree) - analyzer.report() + analyzer.report(self.dst_path) class Analyzer(ast.NodeVisitor): """Visit and analyze nodes of Abstract Syntax Trees (AST).""" def __init__(self) -> None: - self.stats = {"import": [], "from_module": [], "from": []} + self.stats: dict[str, list[str]] = { + "import": [], + "from_module": [], + "from": [], + } - def visit_Import(self, node) -> None: + def visit_Import(self, node: ast.Import) -> None: """Get modules that are imported with the 'import' statement.""" for alias in node.names: self.stats["import"].append(alias.name) self.generic_visit(node) - def visit_ImportFrom(self, node) -> None: + def visit_ImportFrom(self, node: ast.ImportFrom) -> None: """Get modules that are imported with the 'from X import Y' statement.""" - for alias in node.names: - self.stats["from"].append(node.module) + self.stats["from"].append(str(node.module)) self.generic_visit(node) - def report(self) -> None: + def report(self, dst_dir: str) -> None: """Create summary of imported modules.""" - # Create set of imported modules - imports_as_set = set( - # Save the first part of import statement since it references the installed - # module. - [module.split(".", 1)[0] for module in self.stats["import"]] - ) + # Save the first part of import statement since it references the installed + # module. + imports_as_set = {module.split(".", 1)[0] for module in self.stats["import"]} + imports_as_set.update( # Add modules imported with "from X import Y". - [module.split(".", 1)[0] for module in self.stats["from"]] + {module.split(".", 1)[0] for module in self.stats["from"]} ) + output_file = os.path.join(dst_dir, "required_imports.txt") # Write every item of the set to one line. - with open("required_imports.txt", "w") as output: + with open(output_file, "w", encoding="utf-8") as output: for item in imports_as_set: output.write(f"{item}\n") output.close() diff --git a/tests/test_publish.py b/tests/test_publish.py index 9106236f489c55e0ddc5f8adbe27765a8db7826e..3e2b887fe42ece2d3be3b8561f3790072df28a37 100644 --- a/tests/test_publish.py +++ b/tests/test_publish.py @@ -9,6 +9,7 @@ import os import sys import platform import shutil +from collections import Counter from subprocess import run, CalledProcessError from unittest.mock import patch import matplotlib.pyplot as plt @@ -341,6 +342,33 @@ class TestPublish(unittest.TestCase): str(publish_obj), ) + def test_export_imports(self) -> None: + """ + Test if imports of the calling script are correctly written to file. + This test only works if called from the parent directory, since the path to the + file to test the behaviour has to be specified correctly. + """ + expected_modules = [ + "shutil\n", + "unittest\n", + "subprocess\n", + "platform\n", + "matplotlib\n", + "os\n", + "sys\n", + "plotid\n", + "collections\n", + ] + folder = os.path.join("test_parent", "test_dst_folder") + os.mkdir(folder) + file_path = os.path.join(folder, "required_imports.txt") + publish_obj = PublishOptions(FIGS_AND_IDS, SRC_DIR, DST_PATH) + publish_obj.export_imports(os.path.join("tests", "test_publish.py")) + with open(file_path, "r", encoding="utf-8") as file: + modules = file.readlines() + assert Counter(modules) == Counter(expected_modules) + assert os.path.isfile(file_path) + def tearDown(self) -> None: """Delete all files created in setUp.""" shutil.rmtree(SRC_DIR)