diff --git a/functions/classes.py b/functions/classes.py index 7563deb4c7038d69925b316047c6da58d06d5246..c13ad5bc9760ccf15b876563f6a04e9c5b8e1966 100644 --- a/functions/classes.py +++ b/functions/classes.py @@ -428,41 +428,36 @@ class LegoAssembly: return clone -def print_assembly_tree(root, level=0, is_last=False): +def print_assembly_tree(root, levels=None): """ Prints the assembly tree starting from root with a visualization implemented with text characters. Args: root (LegoAssembly): The root of the assembly tree to print. - level (int): The indentation level. Defaults to 0. - is_last (bool): Determines whether the current node is the last in level. - Defaults to False. + levels (List[bool]): Internally used by recursion to know where to print vertical connection. + Defaults to an empty list. """ if not isinstance(root, LegoAssembly): raise TypeError( f"Argument should be of type {LegoAssembly.__name__}, " f"got {type(root).__name__} instead." ) - """ Print the items. """ + if levels is None: + levels = [] + connection_padding = "".join(map(lambda draw: "│ " if draw else " ", levels)) assembly_padding = "" - if level > 0: - assembly_padding += "│ " * (level - 1) - if is_last: - assembly_padding += "└── " - else: - assembly_padding += "├── " - print(f"{assembly_padding}{root}") + if len(levels) > 0: + assembly_padding = "├── " if levels[-1] else "└── " + print(f"{connection_padding[:-4]}{assembly_padding}{root}") """ Recursively print child components. """ for i, assembly in enumerate(root.assemblies): - is_last_ = i == len(root.assemblies) - 1 and len(root.components) == 0 - print_assembly_tree(assembly, level + 1, is_last_) + is_last = i == len(root.assemblies) - 1 and len(root.components) == 0 + print_assembly_tree(assembly, [*levels, not is_last]) """ Print the components. """ - for i, item in enumerate(root.components): - component_padding = "│ " * level if not is_last else " " - component_padding += "├── " if i < len(root.components) - 1 else "└── " - print(f"{component_padding}{item}") - + for i, component in enumerate(root.components): + component_padding = "├── " if i < len(root.components) - 1 else "└── " + print(f"{connection_padding}{component_padding}{component}") def correct_aggregation_hierarchy(root: LegoAssembly, strict: bool = False): """