diff --git a/plotid/publish.py b/plotid/publish.py index fab060c3f29f398333e114ee43cc68f86659ebda..e0cc5de81479ffe1c25ba5f90ef64803da0dddb8 100644 --- a/plotid/publish.py +++ b/plotid/publish.py @@ -15,6 +15,7 @@ import sys import warnings from importlib.metadata import version, PackageNotFoundError from typing import TypedDict, Any +import re from plotid.save_plot import save_plot from plotid.plotoptions import PlotIDTransfer, validate_list @@ -48,6 +49,7 @@ class PublishOptions: self.dst_path = os.path.abspath(dst_path) self.data_storage = kwargs.get("data_storage", "individual") self.dst_path_head, self.dst_dirname = os.path.split(self.dst_path) + self.dst_path_invisible = kwargs.get("dst_path_invisible", "") self.plot_names = kwargs.get("plot_names", self.figure_ids) def __str__(self) -> str: @@ -123,7 +125,7 @@ class PublishOptions: try: # Create folder with ID as name dst_path = os.path.join(self.dst_path, self.figure_ids[i]) - dst_path_invisible = os.path.join( + self.dst_path_invisible = os.path.join( self.dst_path, "." + self.figure_ids[i] ) @@ -149,18 +151,19 @@ class PublishOptions: "overwriting." ) - self.individual_data_storage(dst_path_invisible, plot) - self.export_imports(sys.argv[0], dst_path_invisible) + self.individual_data_storage(self.dst_path_invisible, plot) + self.export_imports(sys.argv[0], self.dst_path_invisible) + find_plotid_calls(sys.argv[0], self.dst_path_invisible) # If export was successful, make the directory visible - os.rename(dst_path_invisible, dst_path) + os.rename(self.dst_path_invisible, dst_path) except FileExistsError as exc: delete_dir = input( "There was an error while publishing the data. Should the " - f"partially copied data at {dst_path_invisible} be" + f"partially copied data at {self.dst_path_invisible} be" " removed? (yes/no[default])\n" ) if delete_dir in ("yes", "y", "Yes", "YES"): - shutil.rmtree(dst_path_invisible) + shutil.rmtree(self.dst_path_invisible) raise RuntimeError( "Publishing was unsuccessful. 
def find_plotid_calls(text_file: str, dst_path: str) -> None:
    """
    Copy a script to dst_path with every plotID usage commented out.

    A line counts as plotID usage if it is reported by
    ``Analyzer.get_plotid_lines()`` (imports of plotid) or if it belongs to
    a call of ``publish(`` or ``tagplot(`` as located by ``find_pattern``.

    Parameters
    ----------
    text_file : str
        Path of the Python script to analyze.
    dst_path : str
        Directory that receives the (possibly commented) copy of the script.
    """
    with open(text_file, "r", encoding="utf-8") as source:
        tree = ast.parse(source.read())
    analyzer = Analyzer()
    analyzer.visit(tree)
    matched_lines = set(analyzer.get_plotid_lines())

    # The leading \b prevents hits on names that merely end in the keyword,
    # e.g. "republish(" or "my_tagplot(".
    for keyword in ("publish", "tagplot"):
        pattern = re.compile(rf"\b{keyword}\(")
        matched_lines.update(find_pattern(text_file, pattern))

    if matched_lines:
        comment_lines(text_file, matched_lines, dst_path)
    else:
        # Nothing to comment out: still export the script unchanged so that
        # it is always part of the published data.
        shutil.copy2(text_file, os.path.join(dst_path, os.path.basename(text_file)))


def find_pattern(text_file: str, pattern: re.Pattern[str]) -> list[int]:
    """
    Find a RegEx pattern in a file and return the lines where it was found.

    Find a Python function call via RegEx pattern and return all lines that it
    spans, i.e. the beginning of the function call until the parenthesis when
    the call is closed. Brackets are only counted, not parsed, so brackets
    inside strings or comments are counted as well.

    Parameters
    ----------
    text_file : str
        Path of the file to search.
    pattern : re.Pattern[str]
        Compiled pattern; only the first match per line is considered.

    Returns
    -------
    list[int]
        Zero-based line indices of every matched span, in file order. A span
        whose brackets are never balanced extends to the last line of the
        file.
    """
    with open(text_file, "r", encoding="utf-8") as file:
        content = file.readlines()

    # Map each matching line to the column where the match starts.
    matched_position = {}
    for line_number, line in enumerate(content):
        found_pattern = pattern.search(line)
        if found_pattern:
            matched_position[line_number] = found_pattern.start()

    lines: list[int] = []
    last_line = len(content) - 1
    for first_line, start_position in matched_position.items():
        # On the matched line only brackets from the match onwards count.
        open_brackets = content[first_line].count("(", start_position)
        close_brackets = content[first_line].count(")", start_position)
        end_line = first_line
        # Extend the span until as many brackets are closed as were opened.
        # Stop at the end of the file so an unclosed call cannot index past
        # the last line.
        while open_brackets > close_brackets and end_line < last_line:
            end_line += 1
            open_brackets += content[end_line].count("(")
            close_brackets += content[end_line].count(")")

        lines.extend(range(first_line, end_line + 1))
    return lines


def comment_lines(text_file: str, lines: set[int], dst_path: str) -> None:
    """
    Copy a script to dst_path and comment all given lines with a '#'.

    The copy keeps the file name of text_file and starts with a two-line
    header stating that it was processed by plotID.

    Parameters
    ----------
    text_file : str
        Path of the script to copy.
    lines : set[int]
        Zero-based indices of the lines that get "# " prepended.
    dst_path : str
        Existing directory in which the copy is written.

    Raises
    ------
    IndexError
        If an index in lines lies beyond the end of the file.
    """
    with open(text_file, "r", encoding="utf-8") as file:
        content = file.readlines()
    for line in lines:
        content[line] = "# " + content[line]

    header = (
        "# This script was automatically processed by plotID to comment all\n"
        "# function calls to plotID.\n"
    )
    copied_script = os.path.join(dst_path, os.path.basename(text_file))
    with open(copied_script, "w", encoding="utf-8") as file:
        file.write(header)
        file.writelines(content)
+ + Create dict that contains a list of tuples that contain the module name and + the line where the module is imported. + """ def __init__(self) -> None: - self.stats: dict[str, list[str]] = { + self.stats: dict[str, list[tuple[str, int]]] = { "import": [], - "from_module": [], "from": [], } def visit_Import(self, node: ast.Import) -> None: """Get modules that are imported with the 'import' statement.""" for alias in node.names: - self.stats["import"].append(alias.name) + self.stats["import"].append((alias.name, alias.lineno)) self.generic_visit(node) def visit_ImportFrom(self, node: ast.ImportFrom) -> None: """Get modules that are imported with the 'from X import Y' statement.""" - self.stats["from"].append(str(node.module)) + self.stats["from"].append(((str(node.module)), node.lineno)) self.generic_visit(node) + def get_plotid_lines(self) -> list[int]: + """Get the lines where plotID is imported.""" + lines = [] + for module in self.stats["import"]: + if module[0].split(".", 1)[0] == "plotid": + lines.append(module[1] - 1) + for module in self.stats["from"]: + if module[0].split(".", 1)[0] == "plotid": + lines.append(module[1] - 1) + return lines + def report(self, dst_dir: str) -> None: """Create summary of imported modules.""" # Save the first part of import statement since it references the installed # module. - imports_as_set = {module.split(".", 1)[0] for module in self.stats["import"]} + imports_as_set = {module[0].split(".", 1)[0] for module in self.stats["import"]} imports_as_set.update( # Add modules imported with "from X import Y". - {module.split(".", 1)[0] for module in self.stats["from"]} + {module[0].split(".", 1)[0] for module in self.stats["from"]} ) + # Remove plotid from the list + imports_as_set.discard("plotid") + output_file = os.path.join(dst_dir, "required_imports.txt") # Write every item of the set to one line. 
with open(output_file, "w", encoding="utf-8") as output: diff --git a/tests/test_publish.py b/tests/test_publish.py index eb985fec860795ecb66cdece3ce4bcde1a2bc0de..cff7099c74fe485529258dc230f2d93983ff75f6 100644 --- a/tests/test_publish.py +++ b/tests/test_publish.py @@ -13,8 +13,15 @@ from collections import Counter from importlib.metadata import version from subprocess import run, CalledProcessError from unittest.mock import patch +import re import matplotlib.pyplot as plt -from plotid.publish import publish, PublishOptions +from plotid.publish import ( + publish, + PublishOptions, + find_plotid_calls, + find_pattern, + comment_lines, +) from plotid.plotoptions import PlotIDTransfer @@ -30,6 +37,9 @@ FIGS_AS_LIST = [FIG, FIG2] IDS_AS_LIST = ["MR05_0x63203c6f", "MR05_0x63203c70"] FIGS_AND_IDS = PlotIDTransfer(FIGS_AS_LIST, IDS_AS_LIST) PIC_NAME_LIST = [PIC_NAME, "second_picture"] +DST_DIR = os.path.abspath(os.path.join("tests", "test_comment")) +TEXT_FILE = "tmp_test_find_lines.txt" +PYTHON_FILE = "tmp_test_find_calls.py" class TestPublish(unittest.TestCase): @@ -44,6 +54,21 @@ class TestPublish(unittest.TestCase): for file in SRC_FILES: open(file, "w", encoding="utf-8").close() + os.mkdir(DST_DIR) + with open(TEXT_FILE, "x", encoding="utf-8") as file: + content = ( + "Lorem ipsum(\ndolor\nsit() amet(,\nconsectetur) adipisici )elit, sed " + "eiusmod tempor\nincidunt ut\nlab\nore et dolore\nmagna aliqua." + ) + file.write(content) + with open(PYTHON_FILE, "x", encoding="utf-8") as file: + content = ( + "import sys\nfrom plotid.tagplot import tagplot\nimport plotid\n" + "x=123\ns='abc'\npublish(\n'Lorem ipsum',\nx+1 (\n) )" + " \ntagplot()\n \n\n\ntagplot(()) \nsys.exit()" + ) + file.write(content) + # Skip test if tests are run from command line. 
@unittest.skipIf( not os.path.isfile(sys.argv[0]), @@ -357,8 +382,8 @@ class TestPublish(unittest.TestCase): "platform\n", f"matplotlib=={mpl_version}\n", "os\n", + "re\n", "sys\n", - "plotid\n", "collections\n", "importlib\n", ] @@ -379,10 +404,49 @@ class TestPublish(unittest.TestCase): msg=f"{modules}, {expected_modules}", ) + def test_find_plotid_calls(self) -> None: + """Test if all calls to plotID in a file are found.""" + find_plotid_calls(PYTHON_FILE, DST_DIR) + copied_file = os.path.join(DST_DIR, PYTHON_FILE) + expected = ( + "# This script was automatically processed by plotID to comment " + "all\n# function calls to plotID.\nimport sys\n# from plotid.tagplot import" + " tagplot\n# import plotid\nx=123\ns='abc'\n# publish(\n# 'Lorem ipsum'," + "\n# x+1 (\n# ) ) \n# tagplot()\n \n\n\n# tagplot(()) \nsys.exit()" + ) + with open(copied_file, "r", encoding="utf-8") as file: + new_content = file.read() + self.assertEqual(new_content, expected) + + def test_find_pattern(self) -> None: + """Test if RegEx pattern is found and correct list of lines is returned.""" + pattern_simple = re.compile(r"incidunt") + pattern_bracket = re.compile(r"ipsum\(") + lines_simple = find_pattern(TEXT_FILE, pattern_simple) + self.assertEqual(lines_simple, [4]) + lines_bracket = find_pattern(TEXT_FILE, pattern_bracket) + self.assertEqual(lines_bracket, [0, 1, 2, 3]) + + def test_comment_lines(self) -> None: + """Test if correct lines get commented.""" + expected = ( + "# This script was automatically processed by plotID to comment " + "all\n# function calls to plotID.\n" + "Lorem ipsum(\n# dolor\nsit() amet(,\n# consectetur) adipisici )elit, sed" + " eiusmod tempor\nincidunt ut\nlab\n# ore et dolore\nmagna aliqua." 
+ ) + comment_lines(TEXT_FILE, {1, 3, 6}, DST_DIR) + with open(os.path.join(DST_DIR, TEXT_FILE), "r", encoding="utf-8") as file: + new_content = file.read() + self.assertEqual(new_content, expected) + def tearDown(self) -> None: """Delete all files created in setUp.""" shutil.rmtree(SRC_DIR) shutil.rmtree(DST_PARENT_DIR) + shutil.rmtree(DST_DIR) + os.remove(TEXT_FILE) + os.remove(PYTHON_FILE) for file in SRC_FILES: os.remove(file)