# Source code for adam.model.tree

import dataclasses
from typing import Iterable, Iterator, Union

from adam.model.abc_factories import Joint, Link


@dataclasses.dataclass
class Node:
    """A node of the directed tree.

    Attributes:
        name: the node name; it is the only input of ``__hash__``.
        arcs: the joints stored on this node (``Tree.build_tree`` appends
            the joints whose parent is this node).
        children: the child nodes.
        parent: the parent link, or ``None`` when no parent has been set.
        parent_arc: the joint connecting the node to its parent, or ``None``.
        link: the link carried by this node, or ``None`` when not provided.
    """

    name: str
    arcs: list[Joint]
    children: list["Node"]
    parent: Union[Link, None] = None
    parent_arc: Union[Joint, None] = None
    # ``get_elements`` reads ``self.link`` and ``Tree.build_tree`` constructs
    # nodes with ``link=...``, so the field has to exist. It is declared last
    # and defaulted so that existing positional construction keeps working.
    link: Union[Link, None] = None

    def __hash__(self) -> int:
        """Hash the node by its name only."""
        return hash(self.name)

    def get_elements(self) -> tuple[Link, Joint, Link]:
        """Return the node link together with its parent arc and parent link.

        Returns:
            tuple[Link, Joint, Link]: the node link, the parent_arc, the
            parent link. The last two are ``None`` when they were never set.
        """
        return self.link, self.parent_arc, self.parent
@dataclasses.dataclass
class Tree(Iterable):
    """The directed tree class.

    Attributes:
        graph: maps each node name to its node.
        root: the name of the root node.

    ``ordered_nodes_list`` is computed in ``__post_init__``: it holds the node
    names in depth-first pre-order starting from ``root``. Iteration, indexing
    and ``len`` all follow that order.
    """

    graph: dict[str, Node]
    root: str

    def __post_init__(self):
        """Cache the ordered list of node names, starting from the root."""
        self.ordered_nodes_list = self.get_ordered_nodes_list(self.root)

    @staticmethod
    def build_tree(links: list[Link], joints: list[Joint]) -> "Tree":
        """Build the tree from the connectivity of the elements.

        Args:
            links (list[Link]): the links; one node is created per link
            joints (list[Joint]): the joints connecting the links

        Returns:
            Tree: the directed tree

        Raises:
            ValueError: if the number of nodes without a parent is not one
        """
        nodes: dict[str, Node] = {
            link.name: Node(
                name=link.name,
                link=link,
                arcs=[],
                children=[],
                parent=None,
                parent_arc=None,
            )
            for link in links
        }
        for joint in joints:
            # don't add the frames: a joint is skipped when its parent or its
            # child is not one of the given links
            if joint.parent not in nodes or joint.child not in nodes:
                continue
            parent_node = nodes[joint.parent]
            child_node = nodes[joint.child]
            # Register every parent -> child connection only once. The test
            # must look at the child name: the parent name is never among its
            # own children, so testing it would let duplicates through.
            if joint.child in {known.name for known in parent_node.children}:
                continue
            parent_node.children.append(child_node)
            parent_node.arcs.append(joint)
            child_node.parent = parent_node.link
            child_node.parent_arc = joint

        root_link = [name for name, node in nodes.items() if node.parent is None]
        if len(root_link) != 1:
            raise ValueError(
                f"Expected only one root, found {len(root_link)}: {root_link}"
            )
        return Tree(nodes, root_link[0])

    def print(self, root):
        """Print the tree.

        Args:
            root (str): the name of the node to start printing from
        """
        # Imported here so that pptree is only needed when printing.
        import pptree

        pptree.print_tree(self.graph[root])

    def get_ordered_nodes_list(self, start: str) -> list[str]:
        """Get the ordered list of the node names, given the connectivity.

        Args:
            start (str): the name of the start node

        Returns:
            list[str]: ``start`` followed by its descendants, depth first
        """
        ordered_list = [start]
        self.get_children(self.graph[start], ordered_list)
        return ordered_list

    @classmethod
    def get_children(cls, node: Node, list: list):
        """Recursively append the names of all the descendants of a node.

        Args:
            node (Node): the analyzed node
            list (list): the list to fill, modified in place. Each child name
                is appended right before the names of its own descendants.
        """
        # The parameter keeps its original name (shadowing the builtin) so
        # that keyword callers are not broken. A node without children simply
        # yields no iteration, so no emptiness check is needed.
        for child in node.children:
            list.append(child.name)
            cls.get_children(child, list)

    def get_idx_from_name(self, name: str) -> int:
        """Get the index of a node.

        Args:
            name (str): node name

        Returns:
            int: the index of the node in the ordered list

        Raises:
            ValueError: if the name is not in the ordered list
        """
        return self.ordered_nodes_list.index(name)

    def get_name_from_idx(self, idx: int) -> str:
        """Get the name of the node at a given index.

        Args:
            idx (int): the index in the ordered list

        Returns:
            str: the corresponding node name
        """
        return self.ordered_nodes_list[idx]

    def get_node_from_name(self, name: str) -> Node:
        """Get a node by name.

        Args:
            name (str): the node name

        Returns:
            Node: the node instance
        """
        return self.graph[name]

    def __iter__(self) -> Iterator[Node]:
        """Iterate on the nodes following the ordered list.

        Yields:
            Node: the next node instance
        """
        for name in self.ordered_nodes_list:
            yield self.graph[name]

    def __reversed__(self) -> Iterator[Node]:
        """Iterate on the nodes following the ordered list backwards.

        Yields:
            Node: the next node instance, starting from the last one
        """
        # Walk the name list directly: calling reversed(self) here would
        # dispatch back to this very method and recurse without end.
        for name in reversed(self.ordered_nodes_list):
            yield self.graph[name]

    def __getitem__(self, key) -> Union[Node, list[Node]]:
        """Get the node(s) at ``key`` in the ordered list.

        Args:
            key (Union[int, slice]): a position or a slice of the ordered list

        Returns:
            Node: the node at that position when ``key`` is an int;
            list[Node]: the nodes of the selected range when ``key`` is a slice
        """
        selected = self.ordered_nodes_list[key]
        if isinstance(key, slice):
            # A slice selects several names; look each of them up.
            return [self.graph[name] for name in selected]
        return self.graph[selected]

    def __len__(self) -> int:
        """Get the number of entries in the ordered list.

        Returns:
            int: the length of the model
        """
        return len(self.ordered_nodes_list)