# Source code for adam.model.tree
import dataclasses
from typing import Iterable, Iterator, Union
from adam.model.abc_factories import Joint, Link
@dataclasses.dataclass
class Node:
    """A node of the directed tree: one link together with its connectivity.

    Attributes:
        name: the node name; the hash of the node is the hash of this name
        link: the link held by this node
        arcs: the joints going from this node to its children
        children: the child nodes
        parent: the parent link, or None when no parent has been assigned
        parent_arc: the joint connecting the parent to this node, or None
            when no parent has been assigned
    """

    # These four fields have no default and therefore must precede the
    # defaulted ones. They are required by __hash__ / get_elements below and
    # by the keyword arguments Tree.build_tree passes to the constructor.
    name: str
    link: Link
    arcs: list[Joint]
    children: list["Node"]
    parent: Union[Link, None] = None
    parent_arc: Union[Joint, None] = None

    def __hash__(self) -> int:
        """Hash the node by its name.

        Defined explicitly because a dataclass with the default ``eq=True``
        would otherwise be left unhashable.

        Returns:
            int: the hash of the node name
        """
        return hash(self.name)

    def get_elements(self) -> tuple[Link, Union[Joint, None], Union[Link, None]]:
        """Return the link of the node with its parent arc and parent link.

        Returns:
            tuple: the link of this node, the parent_arc and the parent link;
                the last two are None when no parent has been assigned
        """
        return self.link, self.parent_arc, self.parent
@dataclasses.dataclass
class Tree(Iterable):
    """The directed tree class.

    Attributes:
        graph: mapping from node name to the corresponding Node
        root: the name of the root node
        ordered_nodes_list: node names in depth-first pre-order starting from
            the root; computed in __post_init__
    """

    graph: dict[str, Node]
    root: str

    def __post_init__(self):
        """Compute the ordered list of node names once the fields are set."""
        self.ordered_nodes_list = self.get_ordered_nodes_list(self.root)

    @staticmethod
    def build_tree(links: list[Link], joints: list[Joint]) -> "Tree":
        """Build the tree from the connectivity of the elements.

        Args:
            links (list[Link]): the links; each one becomes a node
            joints (list[Joint]): the joints; each one connects a parent
                node to a child node

        Returns:
            Tree: the directed tree

        Raises:
            ValueError: if the number of nodes without a parent is not one
        """
        nodes: dict[str, Node] = {
            link.name: Node(
                name=link.name,
                link=link,
                arcs=[],
                children=[],
                parent=None,
                parent_arc=None,
            )
            for link in links
        }

        for joint in joints:
            # don't add the frames
            if joint.parent not in nodes or joint.child not in nodes:
                continue

            parent_node = nodes[joint.parent]
            child_node = nodes[joint.child]
            # Attach every child only once. The check is on the child name:
            # testing the parent name against its own children is always
            # true and would let a repeated joint duplicate the child.
            if joint.child not in {c.name for c in parent_node.children}:
                parent_node.children.append(child_node)
                parent_node.arcs.append(joint)
                child_node.parent = parent_node.link
                child_node.parent_arc = joint

        root_link = [name for name, node in nodes.items() if node.parent is None]
        if len(root_link) != 1:
            raise ValueError(
                f"Expected only one root, found {len(root_link)}: {root_link}"
            )
        return Tree(nodes, root_link[0])

    def print(self, root):
        """Print the tree.

        Args:
            root (str): the name of the node to start printing from
        """
        # Imported here so that pptree is needed only when printing.
        import pptree

        pptree.print_tree(self.graph[root])

    def get_ordered_nodes_list(self, start: str) -> list[str]:
        """Get the ordered list of the node names, given the connectivity.

        Args:
            start (str): the name of the start node

        Returns:
            list[str]: the start name followed by the names of all its
                descendants, in depth-first pre-order
        """
        ordered_list = [start]
        self.get_children(self.graph[start], ordered_list)
        return ordered_list

    @classmethod
    def get_children(cls, node: Node, list: list):
        """Recursively collect the names of all the descendants of a node.

        Args:
            node (Node): the analyzed node
            list (list): the list of names that gets filled in place
        """
        # Iterating over an empty children list is a no-op, so no explicit
        # emptiness check is needed.
        for child in node.children:
            list.append(child.name)
            cls.get_children(child, list)

    def get_idx_from_name(self, name: str) -> int:
        """Get the position of a node in the ordered list.

        Args:
            name (str): node name

        Returns:
            int: the index of the node in the ordered list

        Raises:
            ValueError: if the name is not in the ordered list
        """
        return self.ordered_nodes_list.index(name)

    def get_name_from_idx(self, idx: int) -> str:
        """Get the name of the node at a position of the ordered list.

        Args:
            idx (int): the index in the ordered list

        Returns:
            str: the corresponding node name
        """
        return self.ordered_nodes_list[idx]

    def get_node_from_name(self, name: str) -> Node:
        """Get a node by name.

        Args:
            name (str): the node name

        Returns:
            Node: the node instance
        """
        return self.graph[name]

    def __iter__(self) -> Iterator[Node]:
        """Iterate on the nodes, following the ordered list.

        Yields:
            Node: the next node instance
        """
        yield from (self.graph[name] for name in self.ordered_nodes_list)

    def __reversed__(self) -> Iterator[Node]:
        """Iterate on the nodes, following the ordered list backwards.

        Yields:
            Node: the next node instance, from the last to the first
        """
        # Reverse the list of names: calling reversed(self) here would
        # dispatch to this very method and recurse without end.
        yield from (self.graph[name] for name in reversed(self.ordered_nodes_list))

    def __getitem__(self, key) -> Union[Node, list[Node]]:
        """Get the node(s) at key in the ordered list.

        Args:
            key (Union[int, slice]): a position or a slice of the ordered list

        Returns:
            Union[Node, list[Node]]: the node for an integer key, the list of
                the selected nodes for a slice
        """
        selected = self.ordered_nodes_list[key]
        if isinstance(key, slice):
            return [self.graph[name] for name in selected]
        return self.graph[selected]

    def __len__(self) -> int:
        """Get the number of entries of the ordered list.

        Returns:
            int: the length of the model
        """
        return len(self.ordered_nodes_list)