Skip to content
Snippets Groups Projects
Commit 664baefd authored by Mayr, Hannes's avatar Mayr, Hannes
Browse files

Resolve: comment out plotID function calls in the copied script file

parent 1f26e3db
No related branches found
No related tags found
2 merge requests: !66 "Resolve: comment out plotID function calls in the copied script file", !65 "Release v0.3.0"
......@@ -15,6 +15,7 @@ import sys
import warnings
from importlib.metadata import version, PackageNotFoundError
from typing import TypedDict, Any
import re
from plotid.save_plot import save_plot
from plotid.plotoptions import PlotIDTransfer, validate_list
......@@ -48,6 +49,7 @@ class PublishOptions:
self.dst_path = os.path.abspath(dst_path)
self.data_storage = kwargs.get("data_storage", "individual")
self.dst_path_head, self.dst_dirname = os.path.split(self.dst_path)
self.dst_path_invisible = kwargs.get("dst_path_invisible", "")
self.plot_names = kwargs.get("plot_names", self.figure_ids)
def __str__(self) -> str:
......@@ -123,7 +125,7 @@ class PublishOptions:
try:
# Create folder with ID as name
dst_path = os.path.join(self.dst_path, self.figure_ids[i])
dst_path_invisible = os.path.join(
self.dst_path_invisible = os.path.join(
self.dst_path, "." + self.figure_ids[i]
)
......@@ -149,18 +151,19 @@ class PublishOptions:
"overwriting."
)
self.individual_data_storage(dst_path_invisible, plot)
self.export_imports(sys.argv[0], dst_path_invisible)
self.individual_data_storage(self.dst_path_invisible, plot)
self.export_imports(sys.argv[0], self.dst_path_invisible)
find_plotid_calls(sys.argv[0], self.dst_path_invisible)
# If export was successful, make the directory visible
os.rename(dst_path_invisible, dst_path)
os.rename(self.dst_path_invisible, dst_path)
except FileExistsError as exc:
delete_dir = input(
"There was an error while publishing the data. Should the "
f"partially copied data at {dst_path_invisible} be"
f"partially copied data at {self.dst_path_invisible} be"
" removed? (yes/no[default])\n"
)
if delete_dir in ("yes", "y", "Yes", "YES"):
shutil.rmtree(dst_path_invisible)
shutil.rmtree(self.dst_path_invisible)
raise RuntimeError(
"Publishing was unsuccessful. Try re-running publish."
) from exc
......@@ -222,9 +225,6 @@ class PublishOptions:
except NotADirectoryError:
shutil.copy2(path, destination)
# Copy script that calls this function to folder
shutil.copy2(sys.argv[0], destination)
if os.path.isfile(pic_path):
# Copy plot file to folder
shutil.copy2(pic_path, destination)
......@@ -249,38 +249,128 @@ class PublishOptions:
analyzer.report(destination)
def find_plotid_calls(text_file: str, dst_path: str) -> None:
    """
    Find all calls to plotID in a script and copy the script to dst_path.

    Lines that import plotid are collected from the abstract syntax tree of the
    script, lines that belong to a call matching ``publish(`` or ``tagplot(``
    are collected via RegEx. All of these lines are commented out in the copy.
    The copy is written even if no such line was found, so that the script
    always ends up in dst_path as stated above.

    text_file is the path to the Python script, dst_path is the directory the
    processed copy is written to.
    """
    with open(text_file, "r", encoding="utf-8") as source:
        tree = ast.parse(source.read())
    analyzer = Analyzer()
    analyzer.visit(tree)
    matched_imports = analyzer.get_plotid_lines()

    pattern_publish = re.compile(r"publish\(")
    pattern_tagplot = re.compile(r"tagplot\(")
    matched_lines_publish = find_pattern(text_file, pattern_publish)
    matched_lines_tagplot = find_pattern(text_file, pattern_tagplot)
    # The set removes line numbers that were reported more than once.
    matched_lines = set(matched_imports + matched_lines_publish + matched_lines_tagplot)

    # Copy unconditionally: with an empty set nothing gets commented, but the
    # script is still written to dst_path instead of being silently left out.
    comment_lines(text_file, matched_lines, dst_path)
def find_pattern(text_file: str, pattern: re.Pattern[str]) -> list[int]:
    """
    Find a RegEx pattern in a file and return the lines where it was found.

    Find a Python function call via RegEx pattern and return all lines that it
    spans, i.e. the beginning of the function call until the parenthesis when the
    call is closed.

    Only the first match per line is considered. The returned line numbers are
    zero-based and may contain duplicates if the spans of two matches overlap.
    Brackets are counted textually, so brackets inside strings or comments count
    as well. If the brackets are not balanced before the end of the file, the
    span ends at the last line of the file.
    """
    with open(text_file, "r", encoding="utf-8") as file:
        content = file.readlines()

    last_line = len(content) - 1
    lines: list[int] = []
    for line_number, line in enumerate(content):
        found_pattern = pattern.search(line)
        if not found_pattern:
            continue
        # Brackets in front of the match do not belong to the call.
        start_position = found_pattern.start()
        open_brackets = line.count("(", start_position)
        close_brackets = line.count(")", start_position)

        # end_line tracks where the bracket of the pattern is closed.
        end_line = line_number
        # Search the following lines until there are as many closed as opened
        # brackets. Stop at the last line to not run past the end of the file
        # if a bracket is never closed.
        while open_brackets > close_brackets and end_line < last_line:
            end_line += 1
            open_brackets += content[end_line].count("(")
            close_brackets += content[end_line].count(")")
        lines.extend(range(line_number, end_line + 1))
    return lines
def comment_lines(text_file: str, lines: set[int], dst_path: str) -> None:
    """
    Copy a script to dst_path and comment all given lines with a '#'.

    The copy keeps the file name of text_file. Every zero-based line number in
    lines is prefixed with '# '. A two-line header stating that the script was
    processed by plotID is added at the top of the copy. If the script starts
    with a shebang ('#!'), the header is placed below it so that the shebang
    stays on the first line.
    """
    with open(text_file, "r", encoding="utf-8") as file:
        content = file.readlines()

    for line in lines:
        content[line] = "# " + content[line]

    header = [
        "# This script was automatically processed by plotID to comment all\n"
        "# function calls to plotID.\n"
    ]
    # A shebang only takes effect on the very first line of a file.
    header_position = 1 if content and content[0].startswith("#!") else 0
    if header_position and not content[0].endswith("\n"):
        # The shebang is the only line and lacks a newline; terminate it so the
        # header starts on a line of its own.
        content[0] += "\n"
    content[header_position:header_position] = header

    _, script_name = os.path.split(text_file)
    copied_script = os.path.join(dst_path, script_name)
    with open(copied_script, "w", encoding="utf-8") as file:
        file.writelines(content)
class Analyzer(ast.NodeVisitor):
"""Visit and analyze nodes of Abstract Syntax Trees (AST)."""
"""
Visit and analyze nodes of Abstract Syntax Trees (AST).
Create dict that contains a list of tuples that contain the module name and
the line where the module is imported.
"""
def __init__(self) -> None:
self.stats: dict[str, list[str]] = {
self.stats: dict[str, list[tuple[str, int]]] = {
"import": [],
"from_module": [],
"from": [],
}
def visit_Import(self, node: ast.Import) -> None:
"""Get modules that are imported with the 'import' statement."""
for alias in node.names:
self.stats["import"].append(alias.name)
self.stats["import"].append((alias.name, alias.lineno))
self.generic_visit(node)
def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
"""Get modules that are imported with the 'from X import Y' statement."""
self.stats["from"].append(str(node.module))
self.stats["from"].append(((str(node.module)), node.lineno))
self.generic_visit(node)
def get_plotid_lines(self) -> list[int]:
    """
    Return the zero-based line numbers of all plotid imports.

    An entry counts as a plotid import if the part of its module name before
    the first dot is "plotid". Entries collected from 'import X' statements
    come first, followed by those from 'from X import Y' statements. The stored
    line numbers are reduced by one to get zero-based values.
    """
    return [
        line_number - 1
        for statement_kind in ("import", "from")
        for module_name, line_number in self.stats[statement_kind]
        if module_name.split(".", 1)[0] == "plotid"
    ]
def report(self, dst_dir: str) -> None:
"""Create summary of imported modules."""
# Save the first part of import statement since it references the installed
# module.
imports_as_set = {module.split(".", 1)[0] for module in self.stats["import"]}
imports_as_set = {module[0].split(".", 1)[0] for module in self.stats["import"]}
imports_as_set.update(
# Add modules imported with "from X import Y".
{module.split(".", 1)[0] for module in self.stats["from"]}
{module[0].split(".", 1)[0] for module in self.stats["from"]}
)
# Remove plotid from the list
imports_as_set.discard("plotid")
output_file = os.path.join(dst_dir, "required_imports.txt")
# Write every item of the set to one line.
with open(output_file, "w", encoding="utf-8") as output:
......
......@@ -13,8 +13,15 @@ from collections import Counter
from importlib.metadata import version
from subprocess import run, CalledProcessError
from unittest.mock import patch
import re
import matplotlib.pyplot as plt
from plotid.publish import publish, PublishOptions
from plotid.publish import (
publish,
PublishOptions,
find_plotid_calls,
find_pattern,
comment_lines,
)
from plotid.plotoptions import PlotIDTransfer
......@@ -30,6 +37,9 @@ FIGS_AS_LIST = [FIG, FIG2]
IDS_AS_LIST = ["MR05_0x63203c6f", "MR05_0x63203c70"]
FIGS_AND_IDS = PlotIDTransfer(FIGS_AS_LIST, IDS_AS_LIST)
PIC_NAME_LIST = [PIC_NAME, "second_picture"]
DST_DIR = os.path.abspath(os.path.join("tests", "test_comment"))
TEXT_FILE = "tmp_test_find_lines.txt"
PYTHON_FILE = "tmp_test_find_calls.py"
class TestPublish(unittest.TestCase):
......@@ -44,6 +54,21 @@ class TestPublish(unittest.TestCase):
for file in SRC_FILES:
open(file, "w", encoding="utf-8").close()
os.mkdir(DST_DIR)
with open(TEXT_FILE, "x", encoding="utf-8") as file:
content = (
"Lorem ipsum(\ndolor\nsit() amet(,\nconsectetur) adipisici )elit, sed "
"eiusmod tempor\nincidunt ut\nlab\nore et dolore\nmagna aliqua."
)
file.write(content)
with open(PYTHON_FILE, "x", encoding="utf-8") as file:
content = (
"import sys\nfrom plotid.tagplot import tagplot\nimport plotid\n"
"x=123\ns='abc'\npublish(\n'Lorem ipsum',\nx+1 (\n) )"
" \ntagplot()\n \n\n\ntagplot(()) \nsys.exit()"
)
file.write(content)
# Skip test if tests are run from command line.
@unittest.skipIf(
not os.path.isfile(sys.argv[0]),
......@@ -357,8 +382,8 @@ class TestPublish(unittest.TestCase):
"platform\n",
f"matplotlib=={mpl_version}\n",
"os\n",
"re\n",
"sys\n",
"plotid\n",
"collections\n",
"importlib\n",
]
......@@ -379,10 +404,49 @@ class TestPublish(unittest.TestCase):
msg=f"{modules}, {expected_modules}",
)
def test_find_plotid_calls(self) -> None:
    """Check that imports of and calls to plotID are commented in the copy."""
    find_plotid_calls(PYTHON_FILE, DST_DIR)

    # The processed copy keeps the name of the source file.
    with open(os.path.join(DST_DIR, PYTHON_FILE), "r", encoding="utf-8") as file:
        new_content = file.read()

    self.assertEqual(
        new_content,
        "# This script was automatically processed by plotID to comment "
        "all\n# function calls to plotID.\nimport sys\n# from plotid.tagplot import"
        " tagplot\n# import plotid\nx=123\ns='abc'\n# publish(\n# 'Lorem ipsum',"
        "\n# x+1 (\n# ) ) \n# tagplot()\n \n\n\n# tagplot(()) \nsys.exit()",
    )
def test_find_pattern(self) -> None:
    """Check the lines find_pattern returns for a plain and a bracket pattern."""
    # A pattern without brackets yields only the line it occurs in.
    self.assertEqual(find_pattern(TEXT_FILE, re.compile(r"incidunt")), [4])
    # A pattern with an opening bracket yields all lines until it is closed.
    self.assertEqual(
        find_pattern(TEXT_FILE, re.compile(r"ipsum\(")), [0, 1, 2, 3]
    )
def test_comment_lines(self) -> None:
    """Check that exactly the given lines are commented in the copied file."""
    comment_lines(TEXT_FILE, {1, 3, 6}, DST_DIR)

    commented_copy = os.path.join(DST_DIR, TEXT_FILE)
    with open(commented_copy, "r", encoding="utf-8") as file:
        new_content = file.read()

    # The copy starts with the plotID header, lines 1, 3 and 6 are commented.
    self.assertEqual(
        new_content,
        "# This script was automatically processed by plotID to comment "
        "all\n# function calls to plotID.\n"
        "Lorem ipsum(\n# dolor\nsit() amet(,\n# consectetur) adipisici )elit, sed"
        " eiusmod tempor\nincidunt ut\nlab\n# ore et dolore\nmagna aliqua.",
    )
def tearDown(self) -> None:
"""Delete all files created in setUp."""
shutil.rmtree(SRC_DIR)
shutil.rmtree(DST_PARENT_DIR)
shutil.rmtree(DST_DIR)
os.remove(TEXT_FILE)
os.remove(PYTHON_FILE)
for file in SRC_FILES:
os.remove(file)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment