Skip to content
Snippets Groups Projects

Resolve: comment out plotID function calls in the copied script file

Files
2
+ 106
16
@@ -15,6 +15,7 @@ import sys
import warnings
from importlib.metadata import version, PackageNotFoundError
from typing import TypedDict, Any
import re
from plotid.save_plot import save_plot
from plotid.plotoptions import PlotIDTransfer, validate_list
@@ -48,6 +49,7 @@ class PublishOptions:
self.dst_path = os.path.abspath(dst_path)
self.data_storage = kwargs.get("data_storage", "individual")
self.dst_path_head, self.dst_dirname = os.path.split(self.dst_path)
self.dst_path_invisible = kwargs.get("dst_path_invisible", "")
self.plot_names = kwargs.get("plot_names", self.figure_ids)
def __str__(self) -> str:
@@ -123,7 +125,7 @@ class PublishOptions:
try:
# Create folder with ID as name
dst_path = os.path.join(self.dst_path, self.figure_ids[i])
dst_path_invisible = os.path.join(
self.dst_path_invisible = os.path.join(
self.dst_path, "." + self.figure_ids[i]
)
@@ -149,18 +151,19 @@ class PublishOptions:
"overwriting."
)
self.individual_data_storage(dst_path_invisible, plot)
self.export_imports(sys.argv[0], dst_path_invisible)
self.individual_data_storage(self.dst_path_invisible, plot)
self.export_imports(sys.argv[0], self.dst_path_invisible)
find_plotid_calls(sys.argv[0], self.dst_path_invisible)
# If export was successful, make the directory visible
os.rename(dst_path_invisible, dst_path)
os.rename(self.dst_path_invisible, dst_path)
except FileExistsError as exc:
delete_dir = input(
"There was an error while publishing the data. Should the "
f"partially copied data at {dst_path_invisible} be"
f"partially copied data at {self.dst_path_invisible} be"
" removed? (yes/no[default])\n"
)
if delete_dir in ("yes", "y", "Yes", "YES"):
shutil.rmtree(dst_path_invisible)
shutil.rmtree(self.dst_path_invisible)
raise RuntimeError(
"Publishing was unsuccessful. Try re-running publish."
) from exc
@@ -222,9 +225,6 @@ class PublishOptions:
except NotADirectoryError:
shutil.copy2(path, destination)
# Copy script that calls this function to folder
shutil.copy2(sys.argv[0], destination)
if os.path.isfile(pic_path):
# Copy plot file to folder
shutil.copy2(pic_path, destination)
@@ -249,38 +249,128 @@ class PublishOptions:
analyzer.report(destination)
def find_plotid_calls(text_file: str, dst_path: str) -> None:
"""Find all calls to plotID in a script and copy the script to dst_path."""
with open(text_file, "r", encoding="utf-8") as source:
tree = ast.parse(source.read())
analyzer = Analyzer()
analyzer.visit(tree)
matched_imports = analyzer.get_plotid_lines()
pattern_publish = re.compile(r"publish\(")
pattern_tagplot = re.compile(r"tagplot\(")
matched_lines_publish = find_pattern(text_file, pattern_publish)
matched_lines_tagplot = find_pattern(text_file, pattern_tagplot)
matched_lines = set(matched_imports + matched_lines_publish + matched_lines_tagplot)
if matched_lines:
comment_lines(text_file, matched_lines, dst_path)
def find_pattern(text_file: str, pattern: re.Pattern[str]) -> list[int]:
"""
Find a RegEx pattern in a file and return the lines where it was found.
Find a Python function call via RegEx pattern and return all lines that it
spans, i.e. the beginning of the function call until the parenthesis when the
call is closed.
"""
lines = []
matched_position = {}
with open(text_file, "r", encoding="utf-8") as file:
content = file.readlines()
for line_number, line in enumerate(content):
found_pattern = pattern.search(line)
if found_pattern:
start_position = (
found_pattern.start()
) # Start position of the matched pattern
matched_position[line_number] = start_position
for line_num_matched, start_position in matched_position.items():
# Count opened and closed brackets in lines,
# where the regex pattern occurred
open_brackets = content[line_num_matched].count("(", start_position)
close_brackets = content[line_num_matched].count(")", start_position)
# temp_line_number tracks where the bracket of the pattern is closed
temp_line_number = line_num_matched
while open_brackets > close_brackets:
# Search the following lines until there are as many closed as
# opened brackets.
temp_line_number += 1
open_brackets += content[temp_line_number].count("(")
close_brackets += content[temp_line_number].count(")")
lines += list(range(line_num_matched, temp_line_number + 1))
return lines
def comment_lines(text_file: str, lines: set[int], dst_path: str) -> None:
"""Copy a script to dst_path and comment all given lines with a '#'."""
with open(text_file, "r", encoding="utf-8") as file:
content = file.readlines()
for line in lines:
content[line] = "# " + content[line]
content = [
"# This script was automatically processed by plotID to comment all\n"
"# function calls to plotID.\n"
] + content
_, script_name = os.path.split(text_file)
copied_script = os.path.join(dst_path, script_name)
with open(copied_script, "w", encoding="utf-8") as file:
file.writelines(content)
class Analyzer(ast.NodeVisitor):
"""Visit and analyze nodes of Abstract Syntax Trees (AST)."""
"""
Visit and analyze nodes of Abstract Syntax Trees (AST).
Create dict that contains a list of tuples that contain the module name and
the line where the module is imported.
"""
def __init__(self) -> None:
self.stats: dict[str, list[str]] = {
self.stats: dict[str, list[tuple[str, int]]] = {
"import": [],
"from_module": [],
"from": [],
}
def visit_Import(self, node: ast.Import) -> None:
"""Get modules that are imported with the 'import' statement."""
for alias in node.names:
self.stats["import"].append(alias.name)
self.stats["import"].append((alias.name, alias.lineno))
self.generic_visit(node)
def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
"""Get modules that are imported with the 'from X import Y' statement."""
self.stats["from"].append(str(node.module))
self.stats["from"].append(((str(node.module)), node.lineno))
self.generic_visit(node)
def get_plotid_lines(self) -> list[int]:
"""Get the lines where plotID is imported."""
lines = []
for module in self.stats["import"]:
if module[0].split(".", 1)[0] == "plotid":
lines.append(module[1] - 1)
for module in self.stats["from"]:
if module[0].split(".", 1)[0] == "plotid":
lines.append(module[1] - 1)
return lines
def report(self, dst_dir: str) -> None:
"""Create summary of imported modules."""
# Save the first part of import statement since it references the installed
# module.
imports_as_set = {module.split(".", 1)[0] for module in self.stats["import"]}
imports_as_set = {module[0].split(".", 1)[0] for module in self.stats["import"]}
imports_as_set.update(
# Add modules imported with "from X import Y".
{module.split(".", 1)[0] for module in self.stats["from"]}
{module[0].split(".", 1)[0] for module in self.stats["from"]}
)
# Remove plotid from the list
imports_as_set.discard("plotid")
output_file = os.path.join(dst_dir, "required_imports.txt")
# Write every item of the set to one line.
with open(output_file, "w", encoding="utf-8") as output:
Loading