Skip to content
Snippets Groups Projects
Commit 77140c51 authored by Mayr, Hannes's avatar Mayr, Hannes
Browse files

Merge branch '18-optional-feature-to-install-missing-dependencies-via-pip-or'...

Merge branch '18-optional-feature-to-install-missing-dependencies-via-pip-or' into 'new-year-developments'

Export imported modules

See merge request !62
parents bba01f22 fa35dddf
Branches
Tags
3 merge requests!65Release v0.3.0,!64New year developments,!62Export imported modules
Pipeline #888214 passed
......@@ -79,7 +79,7 @@ FIGS_AND_IDS = tagplot(FIGS_AS_LIST, "matplotlib", prefix="XY23_", id_method="ra
### publish()
Save plot, data and measuring script. It is possible to export multiple figures at once.
Save plot, data and measuring script. Modules that are imported in the script which calls plotID are exported to the file "required_imports.txt". These can later be installed via pip. It is possible to export multiple figures at once.
`publish(figs_and_ids, src_datapath, dst_path)`
- *figs_and_ids* must be a PlotIDTransfer object. Therefore, it can be directly passed from tagplot() to publish().
......
......@@ -8,6 +8,7 @@ the plot is based on. Additionally, the script that produced the plot will be
copied to the destination directory.
"""
import ast
import os
import shutil
import sys
......@@ -239,6 +240,55 @@ class PublishOptions:
os.path.join(destination, final_file_path),
)
def export_imports(self, file: str) -> None:
"""Export all imported modules of a python script to file."""
with open(file, "r", encoding="utf-8") as source:
tree = ast.parse(source.read())
analyzer = Analyzer()
analyzer.visit(tree)
analyzer.report(self.dst_path)
class Analyzer(ast.NodeVisitor):
"""Visit and analyze nodes of Abstract Syntax Trees (AST)."""
def __init__(self) -> None:
self.stats: dict[str, list[str]] = {
"import": [],
"from_module": [],
"from": [],
}
def visit_Import(self, node: ast.Import) -> None:
"""Get modules that are imported with the 'import' statement."""
for alias in node.names:
self.stats["import"].append(alias.name)
self.generic_visit(node)
def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
"""Get modules that are imported with the 'from X import Y' statement."""
self.stats["from"].append(str(node.module))
self.generic_visit(node)
def report(self, dst_dir: str) -> None:
"""Create summary of imported modules."""
# Save the first part of import statement since it references the installed
# module.
imports_as_set = {module.split(".", 1)[0] for module in self.stats["import"]}
imports_as_set.update(
# Add modules imported with "from X import Y".
{module.split(".", 1)[0] for module in self.stats["from"]}
)
output_file = os.path.join(dst_dir, "required_imports.txt")
# Write every item of the set to one line.
with open(output_file, "w", encoding="utf-8") as output:
for item in imports_as_set:
output.write(f"{item}\n")
output.close()
kwargs_types_publish = TypedDict(
"kwargs_types_publish",
......@@ -288,3 +338,4 @@ def publish(
publish_container = PublishOptions(figs_and_ids, src_datapath, dst_path, **kwargs)
publish_container.validate_input()
publish_container.export()
publish_container.export_imports(sys.argv[0])
......@@ -21,16 +21,16 @@ def save_plot(
Parameters
----------
figure : list of/single figure object
figure :
Figure that was tagged and now should be saved as picture.
plot_name : str or list of str
plot_name :
Names of the files where the plots will be saved to.
extension : str
extension :
File extension for the plot export.
Returns
-------
plot_path : list of str
plot_path :
Names of the created pictures.
"""
# Check if figs is a valid figure or a list of valid figures
......
......@@ -9,6 +9,7 @@ import os
import sys
import platform
import shutil
from collections import Counter
from subprocess import run, CalledProcessError
from unittest.mock import patch
import matplotlib.pyplot as plt
......@@ -341,6 +342,33 @@ class TestPublish(unittest.TestCase):
str(publish_obj),
)
def test_export_imports(self) -> None:
"""
Test if imports of the calling script are correctly written to file.
This test only works if called from the parent directory, since the path to the
file to test the behaviour has to be specified correctly.
"""
expected_modules = [
"shutil\n",
"unittest\n",
"subprocess\n",
"platform\n",
"matplotlib\n",
"os\n",
"sys\n",
"plotid\n",
"collections\n",
]
folder = os.path.join("test_parent", "test_dst_folder")
os.mkdir(folder)
file_path = os.path.join(folder, "required_imports.txt")
publish_obj = PublishOptions(FIGS_AND_IDS, SRC_DIR, DST_PATH)
publish_obj.export_imports(os.path.join("tests", "test_publish.py"))
with open(file_path, "r", encoding="utf-8") as file:
modules = file.readlines()
assert Counter(modules) == Counter(expected_modules)
assert os.path.isfile(file_path)
def tearDown(self) -> None:
"""Delete all files created in setUp."""
shutil.rmtree(SRC_DIR)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment