From 664baefdae2f504816530c77097535caba7c9f32 Mon Sep 17 00:00:00 2001
From: "Mayr, Hannes" <hannes.mayr@stud.tu-darmstadt.de>
Date: Fri, 13 Jan 2023 13:06:03 +0100
Subject: [PATCH] Resolve: comment out plotID function calls in the copied
 script file

---
 plotid/publish.py     | 122 ++++++++++++++++++++++++++++++++++++------
 tests/test_publish.py |  68 ++++++++++++++++++++++-
 2 files changed, 172 insertions(+), 18 deletions(-)

diff --git a/plotid/publish.py b/plotid/publish.py
index fab060c..e0cc5de 100644
--- a/plotid/publish.py
+++ b/plotid/publish.py
@@ -15,6 +15,7 @@ import sys
 import warnings
 from importlib.metadata import version, PackageNotFoundError
 from typing import TypedDict, Any
+import re
 from plotid.save_plot import save_plot
 from plotid.plotoptions import PlotIDTransfer, validate_list
 
@@ -48,6 +49,7 @@ class PublishOptions:
         self.dst_path = os.path.abspath(dst_path)
         self.data_storage = kwargs.get("data_storage", "individual")
         self.dst_path_head, self.dst_dirname = os.path.split(self.dst_path)
+        self.dst_path_invisible = kwargs.get("dst_path_invisible", "")
         self.plot_names = kwargs.get("plot_names", self.figure_ids)
 
     def __str__(self) -> str:
@@ -123,7 +125,7 @@ class PublishOptions:
                     try:
                         # Create folder with ID as name
                         dst_path = os.path.join(self.dst_path, self.figure_ids[i])
-                        dst_path_invisible = os.path.join(
+                        self.dst_path_invisible = os.path.join(
                             self.dst_path, "." + self.figure_ids[i]
                         )
 
@@ -149,18 +151,19 @@ class PublishOptions:
                                     "overwriting."
                                 )
 
-                        self.individual_data_storage(dst_path_invisible, plot)
-                        self.export_imports(sys.argv[0], dst_path_invisible)
+                        self.individual_data_storage(self.dst_path_invisible, plot)
+                        self.export_imports(sys.argv[0], self.dst_path_invisible)
+                        find_plotid_calls(sys.argv[0], self.dst_path_invisible)
                         # If export was successful, make the directory visible
-                        os.rename(dst_path_invisible, dst_path)
+                        os.rename(self.dst_path_invisible, dst_path)
                     except FileExistsError as exc:
                         delete_dir = input(
                             "There was an error while publishing the data. Should the "
-                            f"partially copied data at {dst_path_invisible} be"
+                            f"partially copied data at {self.dst_path_invisible} be"
                             " removed? (yes/no[default])\n"
                         )
                         if delete_dir in ("yes", "y", "Yes", "YES"):
-                            shutil.rmtree(dst_path_invisible)
+                            shutil.rmtree(self.dst_path_invisible)
                         raise RuntimeError(
                             "Publishing was unsuccessful. Try re-running publish."
                         ) from exc
@@ -222,9 +225,6 @@ class PublishOptions:
             except NotADirectoryError:
                 shutil.copy2(path, destination)
 
-        # Copy script that calls this function to folder
-        shutil.copy2(sys.argv[0], destination)
-
         if os.path.isfile(pic_path):
             # Copy plot file to folder
             shutil.copy2(pic_path, destination)
@@ -249,38 +249,128 @@ class PublishOptions:
         analyzer.report(destination)
 
 
+def find_plotid_calls(text_file: str, dst_path: str) -> None:
+    """Copy a script to dst_path with all plotID imports and calls commented."""
+    with open(text_file, "r", encoding="utf-8") as source:
+        tree = ast.parse(source.read())
+    analyzer = Analyzer()
+    analyzer.visit(tree)
+    # Zero-based line indices of plotID imports and of publish()/tagplot() calls.
+    matched_lines = set(analyzer.get_plotid_lines())
+    # \b keeps names that merely end in the word (e.g. "unpublish(") unmatched.
+    for call_name in ("publish", "tagplot"):
+        pattern = re.compile(rf"\b{call_name}\(")
+        matched_lines.update(find_pattern(text_file, pattern))
+    # Always write the copy: the script must be archived even when nothing
+    # matched, otherwise the published folder would not contain it at all.
+    comment_lines(text_file, matched_lines, dst_path)
+
+
+def find_pattern(text_file: str, pattern: re.Pattern[str]) -> list[int]:
+    """
+    Find a RegEx pattern in a file and return the lines where it was found.
+
+    Find a Python function call via RegEx pattern and return the zero-based
+    indices of all lines that it spans, i.e. from the line of the match up to
+    the line where its parentheses are balanced. If they never balance, the
+    span ends at the last line of the file instead of raising IndexError.
+    """
+    lines: list[int] = []
+    matched_position: dict[int, int] = {}
+    with open(text_file, "r", encoding="utf-8") as file:
+        content = file.readlines()
+    last_line = len(content) - 1
+    for line_number, line in enumerate(content):
+        found_pattern = pattern.search(line)
+        if found_pattern:
+            # Only the first match of a line is tracked: store its start column.
+            matched_position[line_number] = found_pattern.start()
+
+    for line_num_matched, start_position in matched_position.items():
+        # Count opened and closed brackets in the line where the regex pattern
+        # occurred, starting at the match so earlier brackets are ignored.
+        open_brackets = content[line_num_matched].count("(", start_position)
+        close_brackets = content[line_num_matched].count(")", start_position)
+        # temp_line_number tracks where the bracket of the pattern is closed
+        temp_line_number = line_num_matched
+        while open_brackets > close_brackets and temp_line_number < last_line:
+            # Search the following lines until the brackets are balanced; stop
+            # at the end of the file if they never are (e.g. "(" in a string).
+            temp_line_number += 1
+            open_brackets += content[temp_line_number].count("(")
+            close_brackets += content[temp_line_number].count(")")
+
+        lines.extend(range(line_num_matched, temp_line_number + 1))
+    return lines
+
+
+def comment_lines(text_file: str, lines: set[int], dst_path: str) -> None:
+    """Write a copy of text_file to dst_path with the given lines commented."""
+    with open(text_file, "r", encoding="utf-8") as script:
+        source_lines = script.readlines()
+    # 'lines' holds zero-based indices; each of those lines gets a "# " prefix.
+    for index in lines:
+        source_lines[index] = f"# {source_lines[index]}"
+    header = (
+        "# This script was automatically processed by plotID to comment all\n"
+        "# function calls to plotID.\n"
+    )
+    # The copy keeps the base name of the original script.
+    target = os.path.join(dst_path, os.path.basename(text_file))
+    with open(target, "w", encoding="utf-8") as copy:
+        copy.writelines([header, *source_lines])
+
+
 class Analyzer(ast.NodeVisitor):
-    """Visit and analyze nodes of Abstract Syntax Trees (AST)."""
+    """
+    Visit and analyze nodes of Abstract Syntax Trees (AST).
+
+    Create dict that contains a list of tuples that contain the module name and
+    the line where the module is imported.
+    """
 
     def __init__(self) -> None:
-        self.stats: dict[str, list[str]] = {
+        self.stats: dict[str, list[tuple[str, int]]] = {
             "import": [],
-            "from_module": [],
             "from": [],
         }
 
     def visit_Import(self, node: ast.Import) -> None:
         """Get modules that are imported with the 'import' statement."""
         for alias in node.names:
-            self.stats["import"].append(alias.name)
+            self.stats["import"].append((alias.name, node.lineno))
         self.generic_visit(node)
 
     def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
         """Get modules that are imported with the 'from X import Y' statement."""
-        self.stats["from"].append(str(node.module))
+        self.stats["from"].append(((str(node.module)), node.lineno))
         self.generic_visit(node)
 
+    def get_plotid_lines(self) -> list[int]:
+        """Return the zero-based line indices of all plotID import statements."""
+        # Entries of self.stats are (module name, one-based line number) tuples;
+        # only the top-level package of a dotted name decides whether it counts.
+        return [
+            lineno - 1
+            for kind in ("import", "from")
+            for name, lineno in self.stats[kind]
+            if name.split(".", 1)[0] == "plotid"
+        ]
+
     def report(self, dst_dir: str) -> None:
         """Create summary of imported modules."""
         # Save the first part of import statement since it references the installed
         # module.
-        imports_as_set = {module.split(".", 1)[0] for module in self.stats["import"]}
+        imports_as_set = {module[0].split(".", 1)[0] for module in self.stats["import"]}
 
         imports_as_set.update(
             # Add modules imported with "from X import Y".
-            {module.split(".", 1)[0] for module in self.stats["from"]}
+            {module[0].split(".", 1)[0] for module in self.stats["from"]}
         )
 
+        # Remove plotid from the list
+        imports_as_set.discard("plotid")
+
         output_file = os.path.join(dst_dir, "required_imports.txt")
         # Write every item of the set to one line.
         with open(output_file, "w", encoding="utf-8") as output:
diff --git a/tests/test_publish.py b/tests/test_publish.py
index eb985fe..cff7099 100644
--- a/tests/test_publish.py
+++ b/tests/test_publish.py
@@ -13,8 +13,15 @@ from collections import Counter
 from importlib.metadata import version
 from subprocess import run, CalledProcessError
 from unittest.mock import patch
+import re
 import matplotlib.pyplot as plt
-from plotid.publish import publish, PublishOptions
+from plotid.publish import (
+    publish,
+    PublishOptions,
+    find_plotid_calls,
+    find_pattern,
+    comment_lines,
+)
 from plotid.plotoptions import PlotIDTransfer
 
 
@@ -30,6 +37,9 @@ FIGS_AS_LIST = [FIG, FIG2]
 IDS_AS_LIST = ["MR05_0x63203c6f", "MR05_0x63203c70"]
 FIGS_AND_IDS = PlotIDTransfer(FIGS_AS_LIST, IDS_AS_LIST)
 PIC_NAME_LIST = [PIC_NAME, "second_picture"]
+DST_DIR = os.path.abspath(os.path.join("tests", "test_comment"))
+TEXT_FILE = "tmp_test_find_lines.txt"
+PYTHON_FILE = "tmp_test_find_calls.py"
 
 
 class TestPublish(unittest.TestCase):
@@ -44,6 +54,21 @@ class TestPublish(unittest.TestCase):
         for file in SRC_FILES:
             open(file, "w", encoding="utf-8").close()
 
+        os.mkdir(DST_DIR)
+        with open(TEXT_FILE, "x", encoding="utf-8") as file:
+            content = (
+                "Lorem ipsum(\ndolor\nsit() amet(,\nconsectetur) adipisici )elit, sed "
+                "eiusmod tempor\nincidunt ut\nlab\nore et dolore\nmagna aliqua."
+            )
+            file.write(content)
+        with open(PYTHON_FILE, "x", encoding="utf-8") as file:
+            content = (
+                "import sys\nfrom plotid.tagplot import tagplot\nimport plotid\n"
+                "x=123\ns='abc'\npublish(\n'Lorem ipsum',\nx+1 (\n)  )"
+                "  \ntagplot()\n \n\n\ntagplot(()) \nsys.exit()"
+            )
+            file.write(content)
+
     # Skip test if tests are run from command line.
     @unittest.skipIf(
         not os.path.isfile(sys.argv[0]),
@@ -357,8 +382,8 @@ class TestPublish(unittest.TestCase):
             "platform\n",
             f"matplotlib=={mpl_version}\n",
             "os\n",
+            "re\n",
             "sys\n",
-            "plotid\n",
             "collections\n",
             "importlib\n",
         ]
@@ -379,10 +404,49 @@ class TestPublish(unittest.TestCase):
             msg=f"{modules}, {expected_modules}",
         )
 
+    def test_find_plotid_calls(self) -> None:
+        """Test that plotID imports and calls are commented out in the copy."""
+        find_plotid_calls(PYTHON_FILE, DST_DIR)
+        copied_file = os.path.join(DST_DIR, PYTHON_FILE)
+        expected = (
+            "# This script was automatically processed by plotID to comment "
+            "all\n# function calls to plotID.\nimport sys\n# from plotid.tagplot import"
+            " tagplot\n# import plotid\nx=123\ns='abc'\n# publish(\n# 'Lorem ipsum',"
+            "\n# x+1 (\n# )  )  \n# tagplot()\n \n\n\n# tagplot(()) \nsys.exit()"
+        )
+        with open(copied_file, "r", encoding="utf-8") as file:
+            new_content = file.read()
+        self.assertEqual(new_content, expected)
+
+    def test_find_pattern(self) -> None:
+        """Test that find_pattern returns the zero-based line span of a match."""
+        pattern_simple = re.compile(r"incidunt")
+        pattern_bracket = re.compile(r"ipsum\(")
+        lines_simple = find_pattern(TEXT_FILE, pattern_simple)
+        self.assertEqual(lines_simple, [4])
+        lines_bracket = find_pattern(TEXT_FILE, pattern_bracket)
+        self.assertEqual(lines_bracket, [0, 1, 2, 3])
+
+    def test_comment_lines(self) -> None:
+        """Test that exactly the given zero-based lines get a '# ' prefix."""
+        expected = (
+            "# This script was automatically processed by plotID to comment "
+            "all\n# function calls to plotID.\n"
+            "Lorem ipsum(\n# dolor\nsit() amet(,\n# consectetur) adipisici )elit, sed"
+            " eiusmod tempor\nincidunt ut\nlab\n# ore et dolore\nmagna aliqua."
+        )
+        comment_lines(TEXT_FILE, {1, 3, 6}, DST_DIR)
+        with open(os.path.join(DST_DIR, TEXT_FILE), "r", encoding="utf-8") as file:
+            new_content = file.read()
+        self.assertEqual(new_content, expected)
+
     def tearDown(self) -> None:
         """Delete all files created in setUp."""
         shutil.rmtree(SRC_DIR)
         shutil.rmtree(DST_PARENT_DIR)
+        shutil.rmtree(DST_DIR)
+        os.remove(TEXT_FILE)
+        os.remove(PYTHON_FILE)
         for file in SRC_FILES:
             os.remove(file)
 
-- 
GitLab