diff --git a/docs/docstr_coverage_badge.svg b/docs/docstr_coverage_badge.svg index a457fdc8906fbe80d75cec963a3544560b29ca67..56ebe18bd7caccd4eb1f7eb72f42abd353bad76c 100644 --- a/docs/docstr_coverage_badge.svg +++ b/docs/docstr_coverage_badge.svg @@ -8,13 +8,13 @@ </clipPath> <g clip-path="url(#r)"> <rect width="99" height="20" fill="#555"/> - <rect x="99" width="43" height="20" fill="#a4a61d"/> + <rect x="99" width="43" height="20" fill="#97CA00"/> <rect width="142" height="20" fill="url(#s)"/> </g> <g fill="#fff" text-anchor="middle" font-family="Verdana,Geneva,DejaVu Sans,sans-serif" font-size="110"> <text x="505" y="150" fill="#010101" fill-opacity=".3" transform="scale(.1)" textLength="890">docstr-coverage</text> <text x="505" y="140" transform="scale(.1)" textLength="890">docstr-coverage</text> - <text x="1195" y="150" fill="#010101" fill-opacity=".3" transform="scale(.1)">83%</text> - <text x="1195" y="140" transform="scale(.1)">83%</text> + <text x="1195" y="150" fill="#010101" fill-opacity=".3" transform="scale(.1)">93%</text> + <text x="1195" y="140" transform="scale(.1)">93%</text> </g> </svg> \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index a31696c166a9c44e2b8683a80aa3886676a9704d..a7fb958ee73e2656de778a6ab37541dbe724248b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,6 +35,7 @@ classifiers = [ dependencies = [ "betterproto>=2.0.0b6", "nest-asyncio", + "scipy", ] [project.urls] @@ -132,7 +133,7 @@ detached = true [tool.hatch.envs.docs.scripts] build = "mkdocs build" serve = "mkdocs serve" -doc-cov = "docstr-coverage {args:src/vapython} --exclude=\".*_grpc.*\" -b docs --skip-file-doc" +doc-cov = "docstr-coverage {args:src/vapython} --exclude=\".*(?:_grpc|NatNetClient).*\" -b docs --skip-file-doc" [tool.black] target-version = ["py37"] @@ -188,9 +189,13 @@ unfixable = [ ] exclude = [ "*/vanet/_vanet_grpc.py", + "*/tracking/NatNetClient.py", "examples/*" ] +[tool.ruff.lint.pyupgrade] +keep-runtime-typing = true + [tool.ruff.isort] known-first-party = ["vapython"] diff --git a/scripts/build_hook.py b/scripts/build_hook.py index 1168a663f12d53497072618e69197c35df322063..994900300f5de14f500b22ea4e3e6135aa1e1cc2 100644 --- a/scripts/build_hook.py +++ b/scripts/build_hook.py @@ -33,7 +33,7 @@ class CustomBuildHook(BuildHookInterface): if branch not in ["master", "develop"]: branch = "develop" - branch = "feature/grpc-improvement" # TODO: hardcoded branch name, remove once the branch is merged in VANet + branch = "feature/interface-clean-up" # TODO: hardcoded branch name, remove once the branch is merged in VANet build_vapython(build_dir, build_dir / "va_python.py", branch) vanet_dir = src_dir / "vanet" diff --git a/scripts/build_vapython.py b/scripts/build_vapython.py index 63f101f694d13573e1f0a0b9b7d61d9a420f8397..f5995ab128c05c0942318705b7aa4a828c1e5757 100644 --- a/scripts/build_vapython.py +++ b/scripts/build_vapython.py @@ -117,13 +117,13 @@ def get_documentation(): for key, value in data["VA"].items(): docstring = "" - if "brief" in value and value["brief"]: + if value.get("brief"): docstring += value["brief"] + "\n\n" - if "detail" in value and value["detail"]: + if value.get("detail"): docstring += value["detail"] + "\n\n" - if "args" in value and value["args"]: + if value.get("args"): docstring += "Args:\n" if isinstance(value["args"], str): docstring += f" {value['args']}\n" @@ -132,7 +132,7 @@ def get_documentation(): docstring += f" {arg}\n" docstring += "\n" - if "returns" in value and value["returns"]: + if value.get("returns"): docstring += "Returns:\n" if isinstance(value["returns"], str): docstring += f" {value['returns']}\n" @@ -141,7 +141,7 @@ def get_documentation(): docstring += f" {ret}\n" docstring += "\n" - if "raises" in value and value["raises"]: + if value.get("raises"): docstring += "Raises:\n" if isinstance(value["raises"], str): docstring += f" {value['raises']}\n" @@ -150,7 +150,7 @@ def get_documentation(): docstring += f" {exc}\n" docstring += "\n" - if "post_doc" in value and value["post_doc"]: + if value.get("post_doc"): docstring += value["post_doc"] if docstring[-2:] == "\n\n": @@ -175,6 +175,11 @@ def parse_python_file(file_path: Path): "get_sound_source_i_ds": "get_sound_source_ids", } + private_methods = [ + "get_state", + "attach_event_handler", + ] + output_data = [] for node in ast.walk(all_code): @@ -268,20 +273,8 @@ def parse_python_file(file_path: Path): fixed_name = name_fix[n.name] if n.name in name_fix else n.name - placeholder_method_target = ["geometry_mesh", "acoustic_material", "sound_portal"] - is_placeholder = any(mth_target in fixed_name for mth_target in placeholder_method_target) - - if is_placeholder: - documentation[fixed_name] = "".join( - ( - documentation[fixed_name] if fixed_name in documentation else "", - """ - -Warning: - This is a placeholders for potential future functions of VA, currently it has no effect. -""", - ) - ) + if fixed_name in private_methods: + fixed_name = f"_{fixed_name}" output_data.append( { @@ -293,7 +286,6 @@ Warning: "returns": return_type, "wrapped_return_type": wrapped_return_type, "docstring": documentation[fixed_name] if fixed_name in documentation else "", - "placeholder": is_placeholder, } ) diff --git a/scripts/templates/wrapper.py.j2 b/scripts/templates/wrapper.py.j2 index be8c887a557923dda2c071e1c8f63254c0fdd8a3..c7475d3634cb9b444f8bf5b1e844572585cdc446 100644 --- a/scripts/templates/wrapper.py.j2 +++ b/scripts/templates/wrapper.py.j2 @@ -90,12 +90,6 @@ class VAInterface: """{{ method.docstring -}} """ - {% endif %} - {% if method.placeholder %} - warnings.warn( - "This is a placeholders for potential future functions of VA, currently it has no effect.", - stacklevel=2, - ) {% endif %} return_value = self.loop.run_until_complete( self.service.{{ method.org_name }}( diff --git a/src/vapython/__init__.py b/src/vapython/__init__.py index b03c5cd6f1ac1f4cc6bc9e94e9c8087e24752e43..275f4341df7109d1dee801f993127e24318034f0 100644 --- a/src/vapython/__init__.py +++ b/src/vapython/__init__.py @@ -5,12 +5,8 @@ import nest_asyncio # type: ignore from vapython.va import VA -from vapython.vanet._vanet_grpc import PlaybackActionAction, PlaybackStateState nest_asyncio.apply() -PlaybackStateState.__str__ = lambda self: self.name # type: ignore - - -__all__ = ["VA", "PlaybackActionAction"] +__all__ = ["VA"] diff --git a/src/vapython/tracking/NatNetClient.py b/src/vapython/tracking/NatNetClient.py new file mode 100644 index 0000000000000000000000000000000000000000..e41b97a6c273baeb61b9eab577d26f18e44ff5fc --- /dev/null +++ b/src/vapython/tracking/NatNetClient.py @@ -0,0 +1,490 @@ +import select +import socket +import struct +from threading import Thread + + +def trace(*args): + pass # print( "".join(map(str,args)) ) + + +# Create structs for reading various object types to speed up parsing. +Vector3 = struct.Struct("<fff") +Quaternion = struct.Struct("<ffff") +FloatValue = struct.Struct("<f") +DoubleValue = struct.Struct("<d") + + +class NatNetClient: + def __init__(self): + # Change this value to the IP address of the NatNet server. + self.serverIPAddress = "169.254.201.120" + + # This should match the multicast address listed in Motive's streaming settings. + self.multicastAddress = "239.255.42.99" + + # NatNet Command channel + self.commandPort = 1510 + + # NatNet Data channel + self.dataPort = 1511 + + # Set this to a callback method of your choice to receive per-rigid-body data at each frame. + self.rigidBodyListener = None + + # NatNet stream version. This will be updated to the actual version the server is using during initialization. + self.__natNetStreamVersion = (3, 0, 0, 0) + + self.__stop = False + self.__dataThread: Thread | None = None + self.__commandThread: Thread | None = None + + # Client/server message ids + NAT_PING = 0 + NAT_PINGRESPONSE = 1 + NAT_REQUEST = 2 + NAT_RESPONSE = 3 + NAT_REQUEST_MODELDEF = 4 + NAT_MODELDEF = 5 + NAT_REQUEST_FRAMEOFDATA = 6 + NAT_FRAMEOFDATA = 7 + NAT_MESSAGESTRING = 8 + NAT_DISCONNECT = 9 + NAT_UNRECOGNIZED_REQUEST = 100 + + # Create a data socket to attach to the NatNet stream + def __createDataSocket(self, port): + result = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, socket.IPPROTO_UDP) # Internet # UDP + result.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + result.bind(("", port)) + + mreq = struct.pack("4sl", socket.inet_aton(self.multicastAddress), socket.INADDR_ANY) + result.setsockopt(socket.IPPROTO_IP, socket.IP_ADD_MEMBERSHIP, mreq) + + result.setblocking(False) + return result + + # Create a command socket to attach to the NatNet stream + def __createCommandSocket(self): + result = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + result.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + result.bind(("", 0)) + result.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1) + result.setblocking(False) + + return result + + # Unpack a rigid body object from a data packet + def __unpackRigidBody(self, data): + offset = 0 + + # ID (4 bytes) + id = int.from_bytes(data[offset : offset + 4], byteorder="little") + offset += 4 + trace("ID:", id) + + # Position and orientation + pos = Vector3.unpack(data[offset : offset + 12]) + offset += 12 + trace("\tPosition:", pos[0], ",", pos[1], ",", pos[2]) + rot = Quaternion.unpack(data[offset : offset + 16]) + offset += 16 + trace("\tOrientation:", rot[0], ",", rot[1], ",", rot[2], ",", rot[3]) + + # Marker count (4 bytes) + markerCount = int.from_bytes(data[offset : offset + 4], byteorder="little") + offset += 4 + markerCountRange = range(0, markerCount) + trace("\tMarker Count:", markerCount) + + # Send information to any listener. + if self.rigidBodyListener is not None: + self.rigidBodyListener(id, pos, rot) + + # Marker positions + for i in markerCountRange: + pos = Vector3.unpack(data[offset : offset + 12]) + offset += 12 + trace("\tMarker", i, ":", pos[0], ",", pos[1], ",", pos[2]) + + if self.__natNetStreamVersion[0] >= 2: + # Marker ID's + for i in markerCountRange: + id = int.from_bytes(data[offset : offset + 4], byteorder="little") + offset += 4 + trace("\tMarker ID", i, ":", id) + + # Marker sizes + for i in markerCountRange: + size = FloatValue.unpack(data[offset : offset + 4]) + offset += 4 + trace("\tMarker Size", i, ":", size[0]) + + (markerError,) = FloatValue.unpack(data[offset : offset + 4]) + offset += 4 + trace("\tMarker Error:", markerError) + + # Version 2.6 and later + if ( + ((self.__natNetStreamVersion[0] == 2) and (self.__natNetStreamVersion[1] >= 6)) + or self.__natNetStreamVersion[0] > 2 + or self.__natNetStreamVersion[0] == 0 + ): + (param,) = struct.unpack("h", data[offset : offset + 2]) + trackingValid = (param & 0x01) != 0 + offset += 2 + trace("\tTracking Valid:", "True" if trackingValid else "False") + + return offset + + # Unpack a skeleton object from a data packet + def __unpackSkeleton(self, data): + offset = 0 + + id = int.from_bytes(data[offset : offset + 4], byteorder="little") + offset += 4 + trace("ID:", id) + + rigidBodyCount = int.from_bytes(data[offset : offset + 4], byteorder="little") + offset += 4 + trace("Rigid Body Count:", rigidBodyCount) + for j in range(0, rigidBodyCount): + offset += self.__unpackRigidBody(data[offset:]) + + return offset + + # Unpack data from a motion capture frame message + def __unpackMocapData(self, data): + trace("Begin MoCap Frame\n-----------------\n") + + data = memoryview(data) + offset = 0 + + # Frame number (4 bytes) + frameNumber = int.from_bytes(data[offset : offset + 4], byteorder="little") + offset += 4 + trace("Frame #:", frameNumber) + + # Marker set count (4 bytes) + markerSetCount = int.from_bytes(data[offset : offset + 4], byteorder="little") + offset += 4 + trace("Marker Set Count:", markerSetCount) + + for i in range(0, markerSetCount): + # Model name + modelName, separator, remainder = bytes(data[offset:]).partition(b"\0") + offset += len(modelName) + 1 + trace("Model Name:", modelName.decode("utf-8")) + + # Marker count (4 bytes) + markerCount = int.from_bytes(data[offset : offset + 4], byteorder="little") + offset += 4 + trace("Marker Count:", markerCount) + + for j in range(0, markerCount): + pos = Vector3.unpack(data[offset : offset + 12]) + offset += 12 + # trace( "\tMarker", j, ":", pos[0],",", pos[1],",", pos[2] ) + + # Unlabeled markers count (4 bytes) + unlabeledMarkersCount = int.from_bytes(data[offset : offset + 4], byteorder="little") + offset += 4 + trace("Unlabeled Markers Count:", unlabeledMarkersCount) + + for i in range(0, unlabeledMarkersCount): + pos = Vector3.unpack(data[offset : offset + 12]) + offset += 12 + trace("\tMarker", i, ":", pos[0], ",", pos[1], ",", pos[2]) + + # Rigid body count (4 bytes) + rigidBodyCount = int.from_bytes(data[offset : offset + 4], byteorder="little") + offset += 4 + trace("Rigid Body Count:", rigidBodyCount) + + for i in range(0, rigidBodyCount): + offset += self.__unpackRigidBody(data[offset:]) + + # Version 2.1 and later + skeletonCount = 0 + if (self.__natNetStreamVersion[0] == 2 and self.__natNetStreamVersion[1] > 0) or self.__natNetStreamVersion[ + 0 + ] > 2: + skeletonCount = int.from_bytes(data[offset : offset + 4], byteorder="little") + offset += 4 + trace("Skeleton Count:", skeletonCount) + for i in range(0, skeletonCount): + offset += self.__unpackSkeleton(data[offset:]) + + # Labeled markers (Version 2.3 and later) + labeledMarkerCount = 0 + if (self.__natNetStreamVersion[0] == 2 and self.__natNetStreamVersion[1] > 3) or self.__natNetStreamVersion[ + 0 + ] > 2: + labeledMarkerCount = int.from_bytes(data[offset : offset + 4], byteorder="little") + offset += 4 + trace("Labeled Marker Count:", labeledMarkerCount) + for i in range(0, labeledMarkerCount): + id = int.from_bytes(data[offset : offset + 4], byteorder="little") + offset += 4 + pos = Vector3.unpack(data[offset : offset + 12]) + offset += 12 + size = FloatValue.unpack(data[offset : offset + 4]) + offset += 4 + + # Version 2.6 and later + if ( + (self.__natNetStreamVersion[0] == 2 and self.__natNetStreamVersion[1] >= 6) + or self.__natNetStreamVersion[0] > 2 + or major == 0 + ): + (param,) = struct.unpack("h", data[offset : offset + 2]) + offset += 2 + occluded = (param & 0x01) != 0 + pointCloudSolved = (param & 0x02) != 0 + modelSolved = (param & 0x04) != 0 + + # Force Plate data (version 2.9 and later) + if (self.__natNetStreamVersion[0] == 2 and self.__natNetStreamVersion[1] >= 9) or self.__natNetStreamVersion[ + 0 + ] > 2: + forcePlateCount = int.from_bytes(data[offset : offset + 4], byteorder="little") + offset += 4 + trace("Force Plate Count:", forcePlateCount) + for i in range(0, forcePlateCount): + # ID + forcePlateID = int.from_bytes(data[offset : offset + 4], byteorder="little") + offset += 4 + trace("Force Plate", i, ":", forcePlateID) + + # Channel Count + forcePlateChannelCount = int.from_bytes(data[offset : offset + 4], byteorder="little") + offset += 4 + + # Channel Data + for j in range(0, forcePlateChannelCount): + trace("\tChannel", j, ":", forcePlateID) + forcePlateChannelFrameCount = int.from_bytes(data[offset : offset + 4], byteorder="little") + offset += 4 + for k in range(0, forcePlateChannelFrameCount): + forcePlateChannelVal = int.from_bytes(data[offset : offset + 4], byteorder="little") + offset += 4 + trace("\t\t", forcePlateChannelVal) + + # Latency + (latency,) = FloatValue.unpack(data[offset : offset + 4]) + offset += 4 + + # Timecode + timecode = int.from_bytes(data[offset : offset + 4], byteorder="little") + offset += 4 + timecodeSub = int.from_bytes(data[offset : offset + 4], byteorder="little") + offset += 4 + + # Timestamp (increased to double precision in 2.7 and later) + if (self.__natNetStreamVersion[0] == 2 and self.__natNetStreamVersion[1] >= 7) or self.__natNetStreamVersion[ + 0 + ] > 2: + (timestamp,) = DoubleValue.unpack(data[offset : offset + 8]) + offset += 8 + else: + (timestamp,) = FloatValue.unpack(data[offset : offset + 4]) + offset += 4 + + # Frame parameters + (param,) = struct.unpack("h", data[offset : offset + 2]) + isRecording = (param & 0x01) != 0 + trackedModelsChanged = (param & 0x02) != 0 + offset += 2 + + # Send information to any listener. + if self.newFrameListener is not None: + self.newFrameListener( + frameNumber, + markerSetCount, + unlabeledMarkersCount, + rigidBodyCount, + skeletonCount, + labeledMarkerCount, + latency, + timecode, + timecodeSub, + timestamp, + isRecording, + trackedModelsChanged, + ) + + # Unpack a marker set description packet + def __unpackMarkerSetDescription(self, data): + offset = 0 + + name, separator, remainder = bytes(data[offset:]).partition(b"\0") + offset += len(name) + 1 + trace("Markerset Name:", name.decode("utf-8")) + + markerCount = int.from_bytes(data[offset : offset + 4], byteorder="little") + offset += 4 + + for i in range(0, markerCount): + name, separator, remainder = bytes(data[offset:]).partition(b"\0") + offset += len(name) + 1 + trace("\tMarker Name:", name.decode("utf-8")) + + return offset + + # Unpack a rigid body description packet + def __unpackRigidBodyDescription(self, data): + offset = 0 + + # Version 2.0 or higher + if self.__natNetStreamVersion[0] >= 2: + name, separator, remainder = bytes(data[offset:]).partition(b"\0") + offset += len(name) + 1 + trace("\tMarker Name:", name.decode("utf-8")) + + id = int.from_bytes(data[offset : offset + 4], byteorder="little") + offset += 4 + + parentID = int.from_bytes(data[offset : offset + 4], byteorder="little") + offset += 4 + + timestamp = Vector3.unpack(data[offset : offset + 12]) + offset += 12 + + return offset + + # Unpack a skeleton description packet + def __unpackSkeletonDescription(self, data): + offset = 0 + + name, separator, remainder = bytes(data[offset:]).partition(b"\0") + offset += len(name) + 1 + trace("\tMarker Name:", name.decode("utf-8")) + + id = int.from_bytes(data[offset : offset + 4], byteorder="little") + offset += 4 + + rigidBodyCount = int.from_bytes(data[offset : offset + 4], byteorder="little") + offset += 4 + + for i in range(0, rigidBodyCount): + offset += self.__unpackRigidBodyDescription(data[offset:]) + + return offset + + # Unpack a data description packet + def __unpackDataDescriptions(self, data): + offset = 0 + datasetCount = int.from_bytes(data[offset : offset + 4], byteorder="little") + offset += 4 + + for i in range(0, datasetCount): + type = int.from_bytes(data[offset : offset + 4], byteorder="little") + offset += 4 + if type == 0: + offset += self.__unpackMarkerSetDescription(data[offset:]) + elif type == 1: + offset += self.__unpackRigidBodyDescription(data[offset:]) + elif type == 2: + offset += self.__unpackSkeletonDescription(data[offset:]) + + def __dataThreadFunction(self, socket, stop): + while not stop(): + # Block for input + ready = select.select([socket], [], [], 1.0) + if ready[0]: + data, addr = socket.recvfrom(32768) # 32k byte buffer size + if len(data) > 0: + self.__processMessage(data) + + def __processMessage(self, data): + trace("Begin Packet\n------------\n") + + messageID = int.from_bytes(data[0:2], byteorder="little") + trace("Message ID:", messageID) + + packetSize = int.from_bytes(data[2:4], byteorder="little") + trace("Packet Size:", packetSize) + + offset = 4 + if messageID == self.NAT_FRAMEOFDATA: + self.__unpackMocapData(data[offset:]) + elif messageID == self.NAT_MODELDEF: + self.__unpackDataDescriptions(data[offset:]) + elif messageID == self.NAT_PINGRESPONSE: + offset += 256 # Skip the sending app's Name field + offset += 4 # Skip the sending app's Version info + self.__natNetStreamVersion = struct.unpack("BBBB", data[offset : offset + 4]) + offset += 4 + elif messageID == self.NAT_RESPONSE: + if packetSize == 4: + commandResponse = int.from_bytes(data[offset : offset + 4], byteorder="little") + offset += 4 + else: + message, separator, remainder = bytes(data[offset:]).partition(b"\0") + offset += len(message) + 1 + trace("Command response:", message.decode("utf-8")) + elif messageID == self.NAT_UNRECOGNIZED_REQUEST: + trace("Received 'Unrecognized request' from server") + elif messageID == self.NAT_MESSAGESTRING: + message, separator, remainder = bytes(data[offset:]).partition(b"\0") + offset += len(message) + 1 + trace("Received message from server:", message.decode("utf-8")) + else: + trace("ERROR: Unrecognized packet type") + + trace("End Packet\n----------\n") + + def sendCommand(self, command, commandStr, socket, address): + # Compose the message in our known message format + if command == self.NAT_REQUEST_MODELDEF or command == self.NAT_REQUEST_FRAMEOFDATA: + packetSize = 0 + commandStr = "" + elif command == self.NAT_REQUEST: + packetSize = len(commandStr) + 1 + elif command == self.NAT_PING: + commandStr = "Ping" + packetSize = len(commandStr) + 1 + + data = command.to_bytes(2, byteorder="little") + data += packetSize.to_bytes(2, byteorder="little") + + data += commandStr.encode("utf-8") + data += b"\0" + + socket.sendto(data, address) + + def run(self): + # Create the data socket + self.dataSocket = self.__createDataSocket(self.dataPort) + if self.dataSocket is None: + print("Could not open data channel") + exit + + # Create the command socket + self.commandSocket = self.__createCommandSocket() + if self.commandSocket is None: + print("Could not open command channel") + exit + + # Create a separate thread for receiving data packets + self.__dataThread = Thread(target=self.__dataThreadFunction, args=(self.dataSocket, lambda: self.__stop)) + self.__dataThread.daemon = True + self.__dataThread.start() + + # Create a separate thread for receiving command packets + self.__commandThread = Thread(target=self.__dataThreadFunction, args=(self.commandSocket, lambda: self.__stop)) + self.__commandThread.daemon = True + self.__commandThread.start() + + self.sendCommand(self.NAT_REQUEST_MODELDEF, "", self.commandSocket, (self.serverIPAddress, self.commandPort)) + + def stop(self): + self.__stop = True + if self.__dataThread is not None: + self.__dataThread.join() + + if self.__commandThread is not None: + self.__commandThread.join() + + self.dataSocket.close() + self.commandSocket.close() diff --git a/src/vapython/tracking/__init__.py b/src/vapython/tracking/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a0f81a9ae1d3c3cb24911ed77d2d346e30cd96fa --- /dev/null +++ b/src/vapython/tracking/__init__.py @@ -0,0 +1,165 @@ +# SPDX-FileCopyrightText: 2024-present Pascal Palenda <pascal.palenda@akustik.rwth-aachen.de> +# +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import warnings +from dataclasses import dataclass +from enum import Enum +from typing import TYPE_CHECKING, List, Optional, Tuple, Union + +ihta_tracking_available = True +try: + import IHTATrackingPython as IHTATracking # type: ignore +except ImportError: + ihta_tracking_available = False + warnings.warn( + "IHTATrackingPython not found. Falling back to NatNet only.", + ImportWarning, + stacklevel=2, + ) + +if not ihta_tracking_available: + import vapython.tracking.NatNetClient as NatNetTracking + + +if TYPE_CHECKING: + import vapython._types as va_types + from vapython import VA + + +@dataclass +class BaseTrackingData: + """Base class for tracking data.""" + + position_offset: Optional[Union[va_types.VAVector, List[float], Tuple[float, float, float]]] = None + rotation_offset: Optional[Union[va_types.VAQuaternion, List[float], Tuple[float, float, float, float]]] = None + + +@dataclass +class BaseReceiverTrackingData(BaseTrackingData): + """Base class for receiver tracking data.""" + + receiver_id: int = -1 + + +@dataclass +class ReceiverTrackingData(BaseReceiverTrackingData): + """Class for receiver tracking data.""" + + pass + + +@dataclass +class ReceiverTorsoTrackingData(BaseReceiverTrackingData): + """Class for receiver torso tracking data.""" + + pass + + +@dataclass +class ReceiverRealWorldTrackingData(BaseReceiverTrackingData): + """Class for receiver real-world tracking data.""" + + pass + + +@dataclass +class ReceiverRealWorldTorsoTrackingData(BaseReceiverTrackingData): + """Class for receiver real-world torso tracking data.""" + + pass + + +@dataclass +class SourceTrackingData(BaseTrackingData): + """Class for source tracking data.""" + + source_id: int = -1 + + +class TrackingType(Enum): + """Enumeration for the tracking types.""" + + NatNet = 1 + ART = 2 + + +class Tracker: + """Wrapper to interface with different tracking systems. + + Without the IHTATrackingPython library, only NatNet tracking is available. + """ + + def __init__(self, va_instance: VA, server_ip: str, tracker: TrackingType): + """Initialize the tracker. + + Args: + va_instance: The VA instance. + server_ip: The IP address of the tracking server. + tracker: The tracking type. + """ + if tracker == TrackingType.ART and not ihta_tracking_available: + msg = "IHTATrackingPython not found. Cannot use ART tracking." + raise RuntimeError(msg) + + if ihta_tracking_available: + ihta_tracking_type = ( + IHTATracking.Tracker.Type.NATNET if tracker == TrackingType.NatNet else IHTATracking.Tracker.Type.ART + ) + self._tracker = IHTATracking.Tracker(ihta_tracking_type, server_ip) + + self._tracker.registerCallback(self.ihta_callback) + else: + self._tracker = NatNetTracking.NatNetClient() + + self._tracker.serverIPAddress = server_ip + self._tracker.newFrameListener = self.natnet_frame_callback # type: ignore + self._tracker.rigidBodyListener = self.natnet_callback # type: ignore + + self._tracker.run() + + self._va_instance = va_instance + + def __del__(self): + """Destructor.""" + if ihta_tracking_available: + pass + else: + self._tracker.stop() + del self._tracker + + def disconnect(self): + """Disconnect from the tracking.""" + if ihta_tracking_available: + pass + else: + self._tracker.stop() + + def ihta_callback(self, data: List[IHTATracking.TrackingDataPoint]): + """Callback for the IHTA tracking. + + TODO: Implement this method. + """ + pass + + def natnet_callback(self, new_id, position, rotation): + """Callback for the NatNet tracking. + + The callback is called for each rigid body, each frame. + The position is given as a tuple of 3 floats (x, y, z). + The rotation is given as a tuple of 4 floats (x, y, z, w) aka quaternion. + Args: + new_id: The new ID. + position: The position. + rotation: The rotation. + """ + self._va_instance._apply_tracking(new_id, position, rotation) + + def natnet_frame_callback(self, *_): + """Callback for the NatNet tracking for each frame. + + This is not used in the current implementation. + """ + pass diff --git a/src/vapython/va.py b/src/vapython/va.py index 1807fad8f35bae4b210733df6182c1c50b6d0ccc..20b99e051b1c7b6af6731abcab4ba3427bc98570 100644 --- a/src/vapython/va.py +++ b/src/vapython/va.py @@ -2,14 +2,19 @@ # # SPDX-License-Identifier: Apache-2.0 +from __future__ import annotations + import time import warnings from pathlib import Path -from typing import Any, Dict, List, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union + +from scipy.spatial.transform import Rotation # type: ignore import vapython._helper as helper import vapython._types as va_types import vapython.vanet._vanet_grpc as va_grpc +from vapython import tracking from vapython.vanet import VAInterface @@ -101,6 +106,18 @@ class VA(VAInterface): super().__init__() self._timer_interval = None self._timer_last_call = None + self._tracker: Optional[tracking.Tracker] = None + + self._tracker_data: Dict[ + int, + Union[ + tracking.ReceiverTrackingData, + tracking.ReceiverTorsoTrackingData, + tracking.ReceiverRealWorldTrackingData, + tracking.ReceiverRealWorldTorsoTrackingData, + tracking.SourceTrackingData, + ], + ] = {} def get_global_auralization_mode(self, *, short_form: bool = True) -> str: # type: ignore[override] """Get the global auralization mode. @@ -329,30 +346,6 @@ class VA(VAInterface): return super().create_directivity_from_parameters(name, parameters) - def create_acoustic_material_from_file(self, file_path: Union[str, Path], name: str = "") -> int: - """Create an acoustic material from a file - - Args: - file_path: The path to the file - name: The name of the acoustic material - """ - - parameters: va_types.VAStruct = {"filepath": str(file_path)} - - return super().create_acoustic_material_from_parameters(name, parameters) - - def create_geometry_mesh_from_file(self, file_path: Union[str, Path], name: str = "") -> int: - """Create a geometry mesh from a file - - Args: - file_path: The path to the file - name: The name of the geometry mesh - """ - - parameters: va_types.VAStruct = {"filepath": str(file_path)} - - return super().create_geometry_mesh_from_parameters(name, parameters) - def get_server_state(self) -> va_grpc.CoreState: """Get the state of the server @@ -365,19 +358,7 @@ class VA(VAInterface): The state of the server """ - return super().get_state() - - def get_state(self) -> None: # type: ignore[override] - """Get the state of the server - - Raises: - NotImplementedError: This method is not implemented - ie not to be used in the public interface of VAPython use - [`get_server_state`][vapython.va.VA.get_server_state] instead. - """ - - msg = "Method not implemented in VAPython use `get_server_state`" - raise NotImplementedError(msg) + return super()._get_state() def remove_sound_source_signal_source(self, sound_source_id: int) -> None: """Remove the signal source from the sound source @@ -592,3 +573,325 @@ class VA(VAInterface): pass self._timer_last_call = time.perf_counter_ns() + + def connect_tracker(self, server_ip: str, tracker: tracking.TrackingType = tracking.TrackingType.NatNet) -> None: + """Connect to a tracking system. + + This function connects to a tracking system to track sound sources or receivers. + The tracking system can be used to track the position and orientation of sound sources and receivers. + + Via the `tracker` argument, the tracking system can be selected. Without the IHTATrackingPython package, + only the NatNet tracking system is available. + + Args: + server_ip: The IP address of the tracking server. + local_ip: The local IP address to bind to. + tracker: The tracking system to use. Defaults to `NatNet`. + """ + from vapython.tracking import Tracker + + self._tracker = Tracker(self, server_ip, tracker) + + def disconnect_tracker(self) -> None: + """Disconnect from the tracking system.""" + if self._tracker is None: + return + self._tracker.disconnect() + del self._tracker + self._tracker = None + + def get_tracker_connected(self) -> bool: + """Check if the tracker is connected.""" + return True if self._tracker else False + + def get_tracker_info(self) -> va_types.VAStruct: + """Get information about the tracked objects. + + This includes what objects are tracked and their offsets. + As well as if the tracking is connected. + + Returns: + A dictionary with information about the tracked objects. + """ + tracker_info: va_types.VAStruct = {"IsConnected": True if self._tracker else False} + + for tracker_id, tracker_data in self._tracker_data.items(): + concrete_tracker_data: va_types.VAStruct = {"TrackerID": tracker_id} + + if tracker_data.rotation_offset: + concrete_tracker_data["RotationOffset"] = str(tracker_data.rotation_offset) + + if isinstance(tracker_data, tracking.SourceTrackingData): + concrete_tracker_data["SourceID"] = tracker_data.source_id + + if tracker_data.position_offset: + concrete_tracker_data["PositionOffset"] = str(tracker_data.position_offset) + + tracker_info[f"TrackedSource{tracker_data.source_id}"] = concrete_tracker_data + + elif isinstance(tracker_data, tracking.ReceiverTrackingData): + concrete_tracker_data["ReceiverID"] = tracker_data.receiver_id + + if tracker_data.position_offset: + concrete_tracker_data["PositionOffset"] = str(tracker_data.position_offset) + + tracker_info["TrackedReceiver"] = concrete_tracker_data + + elif isinstance(tracker_data, tracking.ReceiverRealWorldTrackingData): + concrete_tracker_data["ReceiverID"] = tracker_data.receiver_id + + if tracker_data.position_offset: + concrete_tracker_data["PositionOffset"] = str(tracker_data.position_offset) + + tracker_info["TrackedRealWorldReceiver"] = concrete_tracker_data + + elif isinstance(tracker_data, tracking.ReceiverTorsoTrackingData): + concrete_tracker_data["ReceiverID"] = tracker_data.receiver_id + + tracker_info["TrackedReceiverTorso"] = concrete_tracker_data + + elif isinstance(tracker_data, tracking.ReceiverRealWorldTorsoTrackingData): + concrete_tracker_data["ReceiverID"] = tracker_data.receiver_id + + tracker_info["TrackedRealWorldReceiverTorso"] = concrete_tracker_data + + return tracker_info + + def set_tracked_sound_source(self, sound_source_id: int, tracker_id: int) -> None: + """Set a sound source to be tracked by the tracking system. + + This function sets a sound source to be tracked by the tracking system. + The sound source will be tracked with the given `tracker_id`. + + Args: + sound_source_id: The ID of the sound source to track. + tracker_id: The ID of the tracker to use. + """ + self._tracker_data[tracker_id] = tracking.SourceTrackingData(source_id=sound_source_id) + + def set_tracked_sound_receiver(self, sound_receiver_id: int, tracker_id: int) -> None: + """Set a sound receiver to be tracked by the tracking system. + + This function sets a sound receiver to be tracked by the tracking system. + The sound receiver will be tracked with the given `tracker_id`. + + Args: + sound_receiver_id: The ID of the sound receiver to track. + tracker_id: The ID of the tracker to use. + """ + self._tracker_data[tracker_id] = tracking.ReceiverTrackingData(receiver_id=sound_receiver_id) + + def set_tracked_sound_receiver_torso(self, sound_receiver_id: int, tracker_id: int) -> None: + """Set a sound receiver torso to be tracked by the tracking system. + + This function sets a sound receiver torso to be tracked by the tracking system. + The sound receiver torso will be tracked with the given `tracker_id`. + + The rotation of the torso will influence the HRIR selection. + + Args: + sound_receiver_id: The ID of the sound receiver torso to track. + tracker_id: The ID of the tracker to use. + """ + self._tracker_data[tracker_id] = tracking.ReceiverTorsoTrackingData(receiver_id=sound_receiver_id) + + def set_tracked_real_world_sound_receiver(self, sound_receiver_id: int, tracker_id: int) -> None: + """Set a real-world sound receiver to be tracked by the tracking system. + + This function sets a real-world sound receiver to be tracked by the tracking system. + The real-world sound receiver will be tracked with the given `tracker_id`. + + Args: + sound_receiver_id: The ID of the real-world sound receiver to track. + tracker_id: The ID of the tracker to use. + """ + self._tracker_data[tracker_id] = tracking.ReceiverRealWorldTrackingData(receiver_id=sound_receiver_id) + + def set_tracked_real_world_sound_receiver_torso(self, sound_receiver_id: int, tracker_id: int) -> None: + """Set a real-world sound receiver torso to be tracked by the tracking system. + + This function sets a real-world sound receiver torso to be tracked by the tracking system. + The real-world sound receiver torso will be tracked with the given `tracker_id`. + + The rotation of the torso will influence the HRIR selection. + + Args: + sound_receiver_id: The ID of the real-world sound receiver torso to track + tracker_id: The ID of the tracker to use + """ + self._tracker_data[tracker_id] = tracking.ReceiverRealWorldTorsoTrackingData(receiver_id=sound_receiver_id) + + def set_tracked_sound_source_offset( + self, + sound_source_id: int, + *, + position_offset: Optional[Union[va_types.VAVector, List[float], Tuple[float, float, float]]] = None, + orientation_offset: ( + Optional[Union[va_types.VAQuaternion, List[float], Tuple[float, float, float, float]]] | None + ) = None, + ) -> None: + """Set the offset for a tracked sound source. + + The orientation offset is applied directly to the orientation of the sound source. + The position offset is first rotated by the orientation of the sound source and then applied. + + Args: + sound_source_id: The ID of the sound source to set the offset for. + position_offset: The position offset to set. + orientation_offset: The orientation offset to set. + """ + for _, data in self._tracker_data.items(): + if isinstance(data, tracking.SourceTrackingData) and data.source_id == sound_source_id: + if position_offset is not None: + data.position_offset = position_offset + if orientation_offset is not None: + data.rotation_offset = orientation_offset + break + + def set_tracked_sound_receiver_offset( + self, + sound_receiver_id: int, + *, + position_offset: Optional[Union[va_types.VAVector, List[float], Tuple[float, float, float]]] = None, + orientation_offset: ( + Optional[Union[va_types.VAQuaternion, List[float], Tuple[float, float, float, float]]] | None + ) = None, + ) -> None: + """Set the offset for a tracked sound receiver. + + The orientation offset is applied directly to the orientation of the sound receiver. + The position offset is first rotated by the orientation of the sound receiver and then applied. + + Args: + sound_receiver_id: The ID of the sound receiver to set the offset for. + position_offset: The position offset to set. + orientation_offset: The orientation offset to set. + """ + for _, data in self._tracker_data.items(): + if isinstance(data, tracking.ReceiverTrackingData) and data.receiver_id == sound_receiver_id: + if position_offset is not None: + data.position_offset = position_offset + if orientation_offset is not None: + data.rotation_offset = orientation_offset + break + + def set_tracked_sound_receiver_torso_offset( + self, + sound_receiver_id: int, + *, + orientation_offset: ( + Optional[Union[va_types.VAQuaternion, List[float], Tuple[float, float, float, float]]] | None + ) = None, + ) -> None: + """Set the offset for a tracked sound receiver torso. + + Args: + sound_receiver_id: The ID of the sound receiver torso to set the offset for. + + orientation_offset: The orientation offset to set. + """ + for _, data in self._tracker_data.items(): + if isinstance(data, tracking.ReceiverTorsoTrackingData) and data.receiver_id == sound_receiver_id: + if orientation_offset is not None: + data.rotation_offset = orientation_offset + break + + def set_tracked_real_world_sound_receiver_offset( + self, + sound_receiver_id: int, + *, + position_offset: Optional[Union[va_types.VAVector, List[float], Tuple[float, float, float]]] = None, + orientation_offset: ( + Optional[Union[va_types.VAQuaternion, List[float], Tuple[float, float, float, float]]] | None + ) = None, + ) -> None: + """Set the offset for a tracked real-world sound receiver. + + The orientation offset is applied directly to the orientation of the sound receiver. + The position offset is first rotated by the orientation of the sound receiver and then applied. + + Args: + sound_receiver_id: The ID of the real-world sound receiver to set the offset for. + position_offset: The position offset to set. + orientation_offset: The orientation offset to set. + """ + for _, data in self._tracker_data.items(): + if isinstance(data, tracking.ReceiverRealWorldTrackingData) and data.receiver_id == sound_receiver_id: + if position_offset is not None: + data.position_offset = position_offset + if orientation_offset is not None: + data.rotation_offset = orientation_offset + break + + def set_tracked_real_world_sound_receiver_torso_offset( + self, + sound_receiver_id: int, + *, + orientation_offset: ( + Optional[Union[va_types.VAQuaternion, List[float], Tuple[float, float, float, float]]] | None + ) = None, + ) -> None: + """Set the offset for a tracked real-world sound receiver torso. + + Args: + sound_receiver_id: The ID of the real-world sound receiver torso to set the offset for. + orientation_offset: The orientation offset to set. + """ + for _, data in self._tracker_data.items(): + if isinstance(data, tracking.ReceiverRealWorldTorsoTrackingData) and data.receiver_id == sound_receiver_id: + if orientation_offset is not None: + data.rotation_offset = orientation_offset + break + + def _apply_tracking( + self, tracker_id: int, position: Tuple[float, float, float], orientation: Tuple[float, float, float, float] + ): + """Internal function to apply the tracking data to the VA server. + + Args: + tracker_id: The ID of the tracker to apply the data for. + position: The position to apply. + orientation: The orientation to apply. + """ + if tracker_id not in self._tracker_data: + return + + data = self._tracker_data[tracker_id] + + if data.rotation_offset: + orientation_offset = Rotation.from_quat(data.rotation_offset, scalar_first=False) + orientation_quat = Rotation.from_quat(orientation, scalar_first=False) + + orientation = tuple((orientation_offset * orientation_quat).as_quat()) + + if data.position_offset: + orientation_quat = Rotation.from_quat(orientation, scalar_first=False) + + position_offset = orientation_quat.apply(data.position_offset) + + position = tuple([position[i] + position_offset[i] for i in range(3)]) + + def head_above_torso_orientation( + orientation: Tuple[float, float, float, float], receiver_id + ) -> Tuple[float, float, float, float]: + """Calculate the orientation of the head above the torso.""" + orientation_quat = Rotation.from_quat(orientation, scalar_first=False) + + head_quat = Rotation.from_quat(self.get_sound_receiver_orientation(receiver_id), scalar_first=False) + orientation_quat = orientation_quat.inv() * head_quat + return tuple(orientation_quat.as_quat()) + + if isinstance(data, tracking.SourceTrackingData): + self.set_sound_source_pose(data.source_id, position, orientation) + elif isinstance(data, tracking.ReceiverTrackingData): + self.set_sound_receiver_pose(data.receiver_id, position, orientation) + elif isinstance(data, tracking.ReceiverRealWorldTrackingData): + self.set_sound_receiver_real_world_pose(data.receiver_id, position, orientation) + elif isinstance(data, tracking.ReceiverTorsoTrackingData): + # TODO: check if this is correct + orientation = head_above_torso_orientation(orientation, data.receiver_id) + self.set_sound_receiver_head_above_torso_orientation(data.receiver_id, orientation) + elif isinstance(data, tracking.ReceiverRealWorldTorsoTrackingData): + # TODO: check if this is correct + orientation = head_above_torso_orientation(orientation, data.receiver_id) + self.set_sound_receiver_real_world_head_above_torso_orientation(data.receiver_id, orientation) diff --git a/tests/geometry_test.py b/tests/geometry_test.py deleted file mode 100644 index 1e33fa8d26e181961cc060e723559d2b4d25ee6c..0000000000000000000000000000000000000000 --- a/tests/geometry_test.py +++ /dev/null @@ -1,314 +0,0 @@ -# SPDX-FileCopyrightText: 2024-present Pascal Palenda <pascal.palenda@akustik.rwth-aachen.de> -# -# SPDX-License-Identifier: Apache-2.0 - -import random -from pathlib import Path - -import pytest -from betterproto.lib.google import protobuf - -import vapython.vanet._helper as va_grpc_helper -import vapython.vanet._vanet_grpc as va_grpc - -from .utils import random_grpc_struct, random_string, random_struct - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "test_input", - [(random_string(5), random.randint(0, 100)) for _ in range(5)], -) -async def test_create_geometry_mesh(mocked_connection, mocker, test_input): - va, service = mocked_connection - - method_name = "create_geometry_mesh" - message_name = "CreateGeometryMeshRequest" - - # todo: add a "mesh" - mesh = va_grpc.GeometryMesh(id=test_input[1], parameters=random_grpc_struct()) - - mocker.patch.object(service, method_name, return_value=protobuf.Int32Value(test_input[1]), autospec=True) - - function = getattr(va, method_name) - - ret_val = function(mesh, test_input[0]) - - getattr(service, method_name).assert_called_once_with(getattr(va_grpc, message_name)(mesh, test_input[0])) - assert ret_val == test_input[1] - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "test_input", - [(random_string(5), random_struct(), random.randint(0, 100)) for _ in range(5)], -) -async def test_create_geometry_mesh_from_parameters(mocked_connection, mocker, test_input): - va, service = mocked_connection - - method_name = "create_geometry_mesh_from_parameters" - message_name = "CreateGeometryMeshFromParametersRequest" - - mocker.patch.object(service, method_name, return_value=protobuf.Int32Value(test_input[2]), autospec=True) - - function = getattr(va, method_name) - - ret_val = function(test_input[0], test_input[1]) - - getattr(service, method_name).assert_called_once_with( - getattr(va_grpc, message_name)(test_input[0], va_grpc_helper.convert_struct_to_vanet(test_input[1])) - ) - assert ret_val == test_input[2] - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "test_input", - range(5), -) -@pytest.mark.parametrize( - "path_type", - [str, Path], -) -async def test_create_geometry_mesh_from_file(mocked_connection, mocker, test_input, path_type): # noqa: ARG001 - va, service = mocked_connection - - public_method_name = "create_geometry_mesh_from_file" - method_name = "create_geometry_mesh_from_parameters" - message_name = "CreateGeometryMeshFromParametersRequest" - - identifier = random.randint(0, 100) - - mocker.patch.object(service, method_name, return_value=protobuf.Int32Value(identifier), autospec=True) - - test_path = path_type(random_string(5)) - - function = getattr(va, public_method_name) - - ret_val = function(test_path) - - getattr(service, method_name).assert_called_once_with( - getattr(va_grpc, message_name)("", va_grpc_helper.convert_struct_to_vanet({"filepath": str(test_path)})) - ) - assert ret_val == identifier - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "test_input", - [(random.randint(0, 100), random.choice([True, False])) for _ in range(5)], -) -async def test_delete_geometry_mesh(mocked_connection, mocker, test_input): - va, service = mocked_connection - - method_name = "delete_geometry_mesh" - message_name = "DeleteGeometryMeshRequest" - - mocker.patch.object(service, method_name, return_value=protobuf.BoolValue(test_input[1]), autospec=True) - - function = getattr(va, method_name) - - ret_val = function(test_input[0]) - - getattr(service, method_name).assert_called_once_with(getattr(va_grpc, message_name)(test_input[0])) - assert ret_val == test_input[1] - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "test_input", - range(5), -) -async def test_get_geometry_mesh(mocked_connection, mocker, test_input): # noqa: ARG001 - va, service = mocked_connection - - method_name = "get_geometry_mesh" - message_name = "GetGeometryMeshRequest" - reply_name = "GeometryMesh" - - # todo: add a "mesh" - reply = getattr(va_grpc, reply_name)( - id=random.randint(0, 100), - enabled=random.choice([True, False]), - parameters=random_grpc_struct(), - ) - - mocker.patch.object(service, method_name, return_value=reply, autospec=True) - - function = getattr(va, method_name) - - ret_val = function(reply.id) - - getattr(service, method_name).assert_called_once_with(getattr(va_grpc, message_name)(reply.id)) - assert ret_val == reply - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "test_input", - range(5), -) -async def test_get_geometry_mesh_ids(mocked_connection, mocker, test_input): - va, service = mocked_connection - - method_name = "get_geometry_mesh_i_ds" - - reply = va_grpc.IntIdVector( - [random.randint(0, 100) for _ in range(test_input)], - ) - - mocker.patch.object(service, method_name, return_value=reply, autospec=True) - - ret_val = va.get_geometry_mesh_ids() - - assert getattr(service, method_name).called - assert ret_val == reply - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "test_input", - [(random.randint(0, 100), random_string(5)) for _ in range(5)], -) -async def test_set_geometry_mesh_name(mocked_connection, mocker, test_input): - va, service = mocked_connection - - method_name = "set_geometry_mesh_name" - message_name = "SetGeometryMeshNameRequest" - - mocker.patch.object(service, method_name, return_value=protobuf.Empty(), autospec=True) - - function = getattr(va, method_name) - - function(test_input[0], test_input[1]) - - getattr(service, method_name).assert_called_once_with(getattr(va_grpc, message_name)(test_input[0], test_input[1])) - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "test_input", - [(random.randint(0, 100), random_string(5)) for _ in range(5)], -) -async def test_get_geometry_mesh_name(mocked_connection, mocker, test_input): - va, service = mocked_connection - - method_name = "get_geometry_mesh_name" - message_name = "GetGeometryMeshNameRequest" - - mocker.patch.object(service, method_name, return_value=protobuf.StringValue(test_input[1]), autospec=True) - - function = getattr(va, method_name) - - ret_val = function(test_input[0]) - - getattr(service, method_name).assert_called_once_with(getattr(va_grpc, message_name)(test_input[0])) - assert ret_val == test_input[1] - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "test_input", - [ - ( - random.randint(0, 100), - random_struct(), - ) - for _ in range(5) - ], -) -async def test_set_geometry_mesh_parameters(mocked_connection, mocker, test_input): - va, service = mocked_connection - - method_name = "set_geometry_mesh_parameters" - message_name = "SetGeometryMeshParametersRequest" - - mocker.patch.object(service, method_name, return_value=protobuf.Empty(), autospec=True) - - function = getattr(va, method_name) - - function(test_input[0], test_input[1]) - - getattr(service, method_name).assert_called_once_with( - getattr(va_grpc, message_name)( - test_input[0], - va_grpc_helper.convert_struct_to_vanet(test_input[1]), - ) - ) - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "test_input", - [ - ( - random.randint(0, 100), - random_struct(), - random_grpc_struct(), - ) - for _ in range(5) - ], -) -async def test_get_geometry_mesh_parameters(mocked_connection, mocker, test_input): - va, service = mocked_connection - - method_name = "get_geometry_mesh_parameters" - message_name = "GetGeometryMeshParametersRequest" - - mocker.patch.object(service, method_name, return_value=test_input[2], autospec=True) - - function = getattr(va, method_name) - - ret_val = function(test_input[0], test_input[1]) - - getattr(service, method_name).assert_called_once_with( - getattr(va_grpc, message_name)(test_input[0], va_grpc_helper.convert_struct_to_vanet(test_input[1])) - ) - assert ret_val == va_grpc_helper.convert_struct_from_vanet(test_input[2]) - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "test_input", - [ - ( - random.randint(0, 100), - random.choice([True, False]), - ) - for _ in range(5) - ], -) -async def test_set_geometry_mesh_enabled(mocked_connection, mocker, test_input): - va, service = mocked_connection - - mocker.patch.object(service, "set_geometry_mesh_enabled", return_value=protobuf.Empty(), autospec=True) - - va.set_geometry_mesh_enabled(geometry_mesh_id=test_input[0], enabled=test_input[1]) - - service.set_geometry_mesh_enabled.assert_called_once_with( - va_grpc.SetGeometryMeshEnabledRequest(geometry_mesh_id=test_input[0], enabled=test_input[1]) - ) - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "test_input", - [ - ( - random.randint(0, 100), - random.choice([True, False]), - ) - for _ in range(5) - ], -) -async def test_get_geometry_mesh_enabled(mocked_connection, mocker, test_input): - va, service = mocked_connection - - mocker.patch.object( - service, "get_geometry_mesh_enabled", return_value=protobuf.BoolValue(value=test_input[1]), autospec=True - ) - - ret_val = va.get_geometry_mesh_enabled(test_input[0]) - - service.get_geometry_mesh_enabled.assert_called_once_with(va_grpc.GetGeometryMeshEnabledRequest(test_input[0])) - assert ret_val == test_input[1] diff --git a/tests/global_methods_test.py b/tests/global_methods_test.py index 079d998c1190cb5d7c6195d1014bcb79405e0956..a0581aaeddacdeb5762c7ba7a6f07353e3b1f82a 100644 --- a/tests/global_methods_test.py +++ b/tests/global_methods_test.py @@ -131,7 +131,7 @@ async def test_call_module(mocked_connection, mocker, test_input): va_grpc.VaModuleInfos( module_infos=[ va_grpc.VaModuleInfosModuleInfo( - name=random_string(5), description=random_string(10), parameter=random_grpc_struct() + name=random_string(5), description=random_string(10), id=random.randint(0, 5) ) for _ in range(random.randint(0, 5)) ] @@ -242,14 +242,6 @@ async def test_get_global_auralization_mode(mocked_connection, mocker, aura_mode assert ret_val == va_helper.convert_aura_mode_to_str(aura_mode) -@pytest.mark.asyncio -async def test_get_state(mocked_connection): - va, _ = mocked_connection - - with pytest.raises(NotImplementedError, match="Method not implemented in VAPython use `get_server_state`"): - va.get_state() - - @pytest.mark.asyncio @pytest.mark.parametrize( "core_state", [va_grpc.CoreStateState.CREATED, va_grpc.CoreStateState.READY, va_grpc.CoreStateState.FAIL] diff --git a/tests/material_test.py b/tests/material_test.py deleted file mode 100644 index b1241135368877a830d62e37237ec7130e48255b..0000000000000000000000000000000000000000 --- a/tests/material_test.py +++ /dev/null @@ -1,272 +0,0 @@ -# SPDX-FileCopyrightText: 2024-present Pascal Palenda <pascal.palenda@akustik.rwth-aachen.de> -# -# SPDX-License-Identifier: Apache-2.0 - -import random -from pathlib import Path - -import pytest -from betterproto.lib.google import protobuf - -import vapython.vanet._helper as va_grpc_helper -import vapython.vanet._vanet_grpc as va_grpc - -from .utils import random_grpc_struct, random_string, random_struct - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "test_input", - [(random_string(5), random.randint(0, 100)) for _ in range(5)], -) -async def test_create_acoustic_material(mocked_connection, mocker, test_input): - va, service = mocked_connection - - method_name = "create_acoustic_material" - message_name = "CreateAcousticMaterialRequest" - - material = va_grpc.AcousticMaterial(id=test_input[1], parameters=random_grpc_struct()) - - mocker.patch.object(service, method_name, return_value=protobuf.Int32Value(test_input[1]), autospec=True) - - function = getattr(va, method_name) - - ret_val = function(material, test_input[0]) - - getattr(service, method_name).assert_called_once_with(getattr(va_grpc, message_name)(material, test_input[0])) - assert ret_val == test_input[1] - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "test_input", - [(random_string(5), random_struct(), random.randint(0, 100)) for _ in range(5)], -) -async def test_create_acoustic_material_from_parameters(mocked_connection, mocker, test_input): - va, service = mocked_connection - - method_name = "create_acoustic_material_from_parameters" - message_name = "CreateAcousticMaterialFromParametersRequest" - - mocker.patch.object(service, method_name, return_value=protobuf.Int32Value(test_input[2]), autospec=True) - - function = getattr(va, method_name) - - ret_val = function(test_input[0], test_input[1]) - - getattr(service, method_name).assert_called_once_with( - getattr(va_grpc, message_name)(test_input[0], va_grpc_helper.convert_struct_to_vanet(test_input[1])) - ) - assert ret_val == test_input[2] - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "test_input", - range(5), -) -@pytest.mark.parametrize( - "path_type", - [str, Path], -) -async def test_create_acoustic_material_from_file(mocked_connection, mocker, test_input, path_type): # noqa: ARG001 - va, service = mocked_connection - - public_method_name = "create_acoustic_material_from_file" - method_name = "create_acoustic_material_from_parameters" - message_name = "CreateAcousticMaterialFromParametersRequest" - - identifier = random.randint(0, 100) - - mocker.patch.object(service, method_name, return_value=protobuf.Int32Value(identifier), autospec=True) - - test_path = path_type(random_string(5)) - - function = getattr(va, public_method_name) - - ret_val = function(test_path) - - getattr(service, method_name).assert_called_once_with( - getattr(va_grpc, message_name)("", va_grpc_helper.convert_struct_to_vanet({"filepath": str(test_path)})) - ) - assert ret_val == identifier - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "test_input", - [(random.randint(0, 100), random.choice([True, False])) for _ in range(5)], -) -async def test_delete_acoustic_material(mocked_connection, mocker, test_input): - va, service = mocked_connection - - method_name = "delete_acoustic_material" - message_name = "DeleteAcousticMaterialRequest" - - mocker.patch.object(service, method_name, return_value=protobuf.BoolValue(test_input[1]), autospec=True) - - function = getattr(va, method_name) - - ret_val = function(test_input[0]) - - getattr(service, method_name).assert_called_once_with(getattr(va_grpc, message_name)(test_input[0])) - assert ret_val == test_input[1] - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "test_input", - range(5), -) -async def test_get_acoustic_material_info(mocked_connection, mocker, test_input): # noqa: ARG001 - va, service = mocked_connection - - method_name = "get_acoustic_material_info" - message_name = "GetAcousticMaterialInfoRequest" - reply_name = "AcousticMaterial" - - reply = getattr(va_grpc, reply_name)( - id=random.randint(0, 100), - name=random_string(5), - parameters=random_grpc_struct(), - ) - - mocker.patch.object(service, method_name, return_value=reply, autospec=True) - - function = getattr(va, method_name) - - ret_val = function(reply.id) - - getattr(service, method_name).assert_called_once_with(getattr(va_grpc, message_name)(reply.id)) - assert ret_val == reply - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "test_input", - range(5), -) -async def test_get_acoustic_material_infos(mocked_connection, mocker, test_input): - va, service = mocked_connection - - method_name = "get_acoustic_material_infos" - - info = va_grpc.AcousticMaterial( - id=random.randint(0, 100), - name=random_string(5), - parameters=random_grpc_struct(), - ) - reply = va_grpc.AcousticMaterialInfosReply( - acoustic_material_infos=[info for _ in range(test_input)], - ) - - mocker.patch.object(service, method_name, return_value=reply, autospec=True) - - function = getattr(va, method_name) - - ret_val = function() - - assert getattr(service, method_name).called - assert ret_val == reply - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "test_input", - [(random.randint(0, 100), random_string(5)) for _ in range(5)], -) -async def test_set_acoustic_material_name(mocked_connection, mocker, test_input): - va, service = mocked_connection - - method_name = "set_acoustic_material_name" - message_name = "SetAcousticMaterialNameRequest" - - mocker.patch.object(service, method_name, return_value=protobuf.Empty(), autospec=True) - - function = getattr(va, method_name) - - function(test_input[0], test_input[1]) - - getattr(service, method_name).assert_called_once_with(getattr(va_grpc, message_name)(test_input[0], test_input[1])) - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "test_input", - [(random.randint(0, 100), random_string(5)) for _ in range(5)], -) -async def test_get_acoustic_material_name(mocked_connection, mocker, test_input): - va, service = mocked_connection - - method_name = "get_acoustic_material_name" - message_name = "GetAcousticMaterialNameRequest" - - mocker.patch.object(service, method_name, return_value=protobuf.StringValue(test_input[1]), autospec=True) - - function = getattr(va, method_name) - - ret_val = function(test_input[0]) - - getattr(service, method_name).assert_called_once_with(getattr(va_grpc, message_name)(test_input[0])) - assert ret_val == test_input[1] - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "test_input", - [ - ( - random.randint(0, 100), - random_struct(), - ) - for _ in range(5) - ], -) -async def test_set_acoustic_material_parameters(mocked_connection, mocker, test_input): - va, service = mocked_connection - - method_name = "set_acoustic_material_parameters" - message_name = "SetAcousticMaterialParametersRequest" - - mocker.patch.object(service, method_name, return_value=protobuf.Empty(), autospec=True) - - function = getattr(va, method_name) - - function(test_input[0], test_input[1]) - - getattr(service, method_name).assert_called_once_with( - getattr(va_grpc, message_name)( - test_input[0], - va_grpc_helper.convert_struct_to_vanet(test_input[1]), - ) - ) - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "test_input", - [ - ( - random.randint(0, 100), - random_struct(), - random_grpc_struct(), - ) - for _ in range(5) - ], -) -async def test_get_acoustic_material_parameters(mocked_connection, mocker, test_input): - va, service = mocked_connection - - method_name = "get_acoustic_material_parameters" - message_name = "GetAcousticMaterialParametersRequest" - - mocker.patch.object(service, method_name, return_value=test_input[2], autospec=True) - - function = getattr(va, method_name) - - ret_val = function(test_input[0], test_input[1]) - - getattr(service, method_name).assert_called_once_with( - getattr(va_grpc, message_name)(test_input[0], va_grpc_helper.convert_struct_to_vanet(test_input[1])) - ) - assert ret_val == va_grpc_helper.convert_struct_from_vanet(test_input[2]) diff --git a/tests/source_receiver_test.py b/tests/source_receiver_test.py index 18fe647d128160d6c72b5765d42fd314aad52fcd..cbda6588b7f59951d9dd0019100773c06caa6da9 100644 --- a/tests/source_receiver_test.py +++ b/tests/source_receiver_test.py @@ -835,7 +835,7 @@ async def test_get_sound_info_(mocked_connection, mocker, entity, test_input): reply = getattr(va_grpc, reply_name)( id=test_input[0], name=random_string(5), - parameters=random_grpc_struct(), + explicit_renderer_id=random_string(5), ) mocker.patch.object(service, method_name, return_value=reply, autospec=True) @@ -846,58 +846,3 @@ async def test_get_sound_info_(mocked_connection, mocker, entity, test_input): assert getattr(service, method_name).called assert ret_val == reply - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "entity", - [ - "source", - "receiver", - ], -) -@pytest.mark.parametrize( - "test_input", - [(random.randint(0, 100), random.randint(0, 100)) for _ in range(5)], -) -async def test_set_sound_geometry_mesh_(mocked_connection, mocker, entity, test_input): - va, service = mocked_connection - - method_name = f"set_sound_{entity}_geometry_mesh" - message_name = f"SetSound{entity.capitalize()}GeometryMeshRequest" - - mocker.patch.object(service, method_name, return_value=protobuf.Empty(), autospec=True) - - function = getattr(va, method_name) - - function(test_input[0], test_input[1]) - - getattr(service, method_name).assert_called_once_with(getattr(va_grpc, message_name)(test_input[0], test_input[1])) - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "entity", - [ - "source", - "receiver", - ], -) -@pytest.mark.parametrize( - "test_input", - [(random.randint(0, 100), random.randint(0, 100)) for _ in range(5)], -) -async def test_get_sound_geometry_mesh_(mocked_connection, mocker, entity, test_input): - va, service = mocked_connection - - method_name = f"get_sound_{entity}_geometry_mesh" - message_name = f"GetSound{entity.capitalize()}GeometryMeshRequest" - - mocker.patch.object(service, method_name, return_value=protobuf.Int32Value(test_input[1]), autospec=True) - - function = getattr(va, method_name) - - ret_val = function(test_input[0]) - - getattr(service, method_name).assert_called_once_with(getattr(va_grpc, message_name)(test_input[0])) - assert ret_val == test_input[1]