Skip to content
Snippets Groups Projects
Commit e2c0d770 authored by Mayr, Hannes's avatar Mayr, Hannes
Browse files

Add export to file and unittest.

parent 26177e15
No related branches found
No related tags found
1 merge request!62Export imported modules
Pipeline #878716 passed
......@@ -242,47 +242,49 @@ class PublishOptions:
def export_imports(self, file: str) -> None:
"""Export all imported modules in of a python script to file."""
with open(file, "r") as source:
with open(file, "r", encoding="utf-8") as source:
tree = ast.parse(source.read())
analyzer = Analyzer()
analyzer.visit(tree)
analyzer.report()
analyzer.report(self.dst_path)
class Analyzer(ast.NodeVisitor):
"""Visit and analyze nodes of Abstract Syntax Trees (AST)."""
def __init__(self) -> None:
self.stats = {"import": [], "from_module": [], "from": []}
self.stats: dict[str, list[str]] = {
"import": [],
"from_module": [],
"from": [],
}
def visit_Import(self, node) -> None:
def visit_Import(self, node: ast.Import) -> None:
"""Get modules that are imported with the 'import' statement."""
for alias in node.names:
self.stats["import"].append(alias.name)
self.generic_visit(node)
def visit_ImportFrom(self, node) -> None:
def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
"""Get modules that are imported with the 'from X import Y' statement."""
for alias in node.names:
self.stats["from"].append(node.module)
self.stats["from"].append(str(node.module))
self.generic_visit(node)
def report(self) -> None:
def report(self, dst_dir: str) -> None:
"""Create summary of imported modules."""
# Create set of imported modules
imports_as_set = set(
# Save the first part of import statement since it references the installed
# module.
[module.split(".", 1)[0] for module in self.stats["import"]]
)
imports_as_set = {module.split(".", 1)[0] for module in self.stats["import"]}
imports_as_set.update(
# Add modules imported with "from X import Y".
[module.split(".", 1)[0] for module in self.stats["from"]]
{module.split(".", 1)[0] for module in self.stats["from"]}
)
output_file = os.path.join(dst_dir, "required_imports.txt")
# Write every item of the set to one line.
with open("required_imports.txt", "w") as output:
with open(output_file, "w", encoding="utf-8") as output:
for item in imports_as_set:
output.write(f"{item}\n")
output.close()
......
......@@ -9,6 +9,7 @@ import os
import sys
import platform
import shutil
from collections import Counter
from subprocess import run, CalledProcessError
from unittest.mock import patch
import matplotlib.pyplot as plt
......@@ -341,6 +342,33 @@ class TestPublish(unittest.TestCase):
str(publish_obj),
)
def test_export_imports(self) -> None:
"""
Test if imports of the calling script are correctly written to file.
This test only works if called from the parent directory, since the path to the
file to test the behaviour has to be specified correctly.
"""
expected_modules = [
"shutil\n",
"unittest\n",
"subprocess\n",
"platform\n",
"matplotlib\n",
"os\n",
"sys\n",
"plotid\n",
"collections\n",
]
folder = os.path.join("test_parent", "test_dst_folder")
os.mkdir(folder)
file_path = os.path.join(folder, "required_imports.txt")
publish_obj = PublishOptions(FIGS_AND_IDS, SRC_DIR, DST_PATH)
publish_obj.export_imports(os.path.join("tests", "test_publish.py"))
with open(file_path, "r", encoding="utf-8") as file:
modules = file.readlines()
assert Counter(modules) == Counter(expected_modules)
assert os.path.isfile(file_path)
def tearDown(self) -> None:
"""Delete all files created in setUp."""
shutil.rmtree(SRC_DIR)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment