from typing import Any, Iterator, List, Type, TypeVar, Union

import numpy as np
import yaml
from pydantic import BaseModel, ConfigDict, model_serializer
from pydantic_core.core_schema import SerializationInfo
# Type variable for classmethod constructors (``cls: Type[T]`` -> ``T``):
# it binds to the concrete BaseModel subclass the method is called on, so
# ``SomeModel.from_list(...)`` is typed as returning ``SomeModel``.
T = TypeVar("T", bound="BaseModel")
class string_with_quotes(str):
    """Marker ``str`` subclass that the YAML dumper writes double-quoted.

    It behaves exactly like ``str``. Only its type matters: this module
    registers ``quoted_presenter`` as the YAML representer for it.
    """
class flow_list(list):
    """Marker ``list`` subclass that the YAML dumper writes in flow style.

    It behaves exactly like ``list``. This module registers
    ``flow_list_rep`` as its YAML representer, so it renders as
    ``[a, b, c]`` instead of a block sequence.
    """
def flow_list_rep(dumper, data):
    """PyYAML representer that emits ``data`` as a flow-style sequence.

    Args:
        dumper: The PyYAML dumper/representer invoking this hook.
        data: The sequence to represent.

    Returns:
        The sequence node built by ``dumper.represent_sequence`` with the
        standard ``seq`` tag and ``flow_style=True``.
    """
    return dumper.represent_sequence("tag:yaml.org,2002:seq", data, flow_style=True)
def quoted_presenter(dumper, data):
    """PyYAML representer that emits ``data`` as a double-quoted string.

    Args:
        dumper: The PyYAML dumper/representer invoking this hook.
        data: The string to represent.

    Returns:
        The scalar node built by ``dumper.represent_scalar`` with the
        standard ``str`` tag and ``style='"'``.
    """
    return dumper.represent_scalar("tag:yaml.org,2002:str", data, style='"')
# Hook the marker types into PyYAML. ``Dumper=yaml.Dumper`` is written out
# explicitly but is also add_representer's default target. With this,
# ``yaml.dump`` writes string_with_quotes double-quoted and flow_list in
# flow style.
# NOTE(review): nothing is registered on yaml.SafeDumper, so yaml.safe_dump
# presumably won't use these representers — confirm if safe_dump is used.
yaml.add_representer(string_with_quotes, quoted_presenter, Dumper=yaml.Dumper)
yaml.add_representer(flow_list, flow_list_rep, Dumper=yaml.Dumper)
def convert_numpy_types(v: Any) -> Any:
    """
    Recursively convert numpy types in a data structure to native Python types.

    Conversion rules:

    - Dicts are rebuilt with converted values; keys are left unchanged.
    - Lists, tuples and numpy arrays become ``flow_list`` instances, so YAML
      output renders them in flow style.
    - A 0-d array is unwrapped to its scalar first.
    - Numpy bool, integer, floating and string scalars become ``bool``,
      ``int``, ``float`` and ``str``.
    - Anything else is returned unchanged.

    Args:
        v (Any): The input data structure which may contain numpy types.

    Returns:
        Any: The data structure with numpy types converted to native Python types.
    """
    if isinstance(v, dict):
        return {key: convert_numpy_types(value) for key, value in v.items()}
    if isinstance(v, np.ndarray) and v.ndim == 0:
        # Iterating a 0-d array raises TypeError; unwrap it to a scalar first.
        return convert_numpy_types(v.item())
    if isinstance(v, (np.ndarray, list, tuple)):
        return flow_list([convert_numpy_types(item) for item in v])
    # Use the abstract numpy scalar bases rather than a hand-picked list of
    # concrete types. This also catches np.longlong, np.uintc, np.longdouble,
    # etc., which are distinct classes on some platforms.
    if isinstance(v, np.bool_):
        return bool(v)
    if isinstance(v, np.integer):
        return int(v)
    if isinstance(v, np.floating):
        return float(v)
    if isinstance(v, np.str_):
        return str(v)
    return v
class ModelBase(BaseModel):
    """Base Model that ignores extra fields.

    No ``extra=`` option is configured here, so extra-field handling is
    pydantic's default. On top of ``BaseModel`` this class adds:

    - an equality check that tolerates numpy values;
    - an identity-based hash.
    """

    def __eq__(self, other: Any) -> bool:
        """Equality that gracefully handles numpy arrays.

        pydantic's own comparison is tried first. It can raise, for example
        when comparing two numpy arrays yields an array with no single truth
        value. In that case the models are compared by their
        ``model_dump()`` output, with numpy values converted to native
        Python types. Private attributes are not part of ``model_dump()``,
        so this fallback ignores them.
        """
        try:
            return super().__eq__(other)
        except (ValueError, TypeError):
            if not isinstance(other, BaseModel):
                return NotImplemented
            # Convert numpy values first. Otherwise arrays held in *public*
            # fields raise the same "truth value is ambiguous" error again
            # during the dict comparison.
            return convert_numpy_types(self.model_dump()) == convert_numpy_types(
                other.model_dump()
            )

    def __hash__(self):
        # Identity hash keeps instances usable as dict keys / set members
        # even when field values are unhashable.
        # NOTE: distinct instances that compare equal get different hashes,
        # so "a == b implies hash(a) == hash(b)" does not hold here.
        return id(self)

    def base_model_dump(self, exclude_defaults: bool = False) -> dict:
        """Dump the model with numpy values converted to native Python types.

        ``None`` values are always excluded. Fields still at their default
        are also excluded when ``exclude_defaults`` is True. The dump is
        passed through ``convert_numpy_types``.
        """
        return convert_numpy_types(
            self.model_dump(exclude_none=True, exclude_defaults=exclude_defaults)
        )
class NumpyModel(ModelBase):
    """Model using numpy arrays.

    The declared fields are treated as an ordered vector of values:

    - they can be read back as a numpy array (``array``);
    - the model can be built from a list (``from_list``) or from positional
      values (``from_values``);
    - in JSON serialisation mode the model is emitted as a plain list of
      field values rather than a dict.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)

    @model_serializer(mode="wrap")
    def ser_model(self, handler, info: SerializationInfo):
        """Serialise to a list of field values in JSON mode, else a dict."""
        if info.mode == "json":
            # pydantic cannot JSON-encode a raw ndarray ("unknown type").
            # tolist() gives a plain list of native Python values instead.
            return self.array.tolist()
        return handler(self)  # default dict for python mode

    @property
    def array(self) -> np.ndarray:
        """Return the field values, in declaration order, as a numpy array."""
        return np.array([getattr(self, name) for name in type(self).model_fields])

    @classmethod
    def from_list(cls: Type[T], vec: List[Union[float, int]]) -> T:
        """Build an instance by assigning ``vec`` to the fields in order.

        Raises:
            AssertionError: If ``len(vec)`` differs from the number of
                fields. This is checked with ``assert``, so it is skipped
                under ``python -O``.
        """
        names = list(cls.model_fields)
        assert len(vec) == len(names), (
            f"{cls.__name__} expects {len(names)} values, got {len(vec)}"
        )
        return cls(**dict(zip(names, vec)))

    @classmethod
    def from_values(cls: Type[T], *values: Union[float, int]) -> T:
        """Variadic form of ``from_list``: ``Model.from_values(x, y, z)``.

        Raises:
            AssertionError: If the number of values differs from the number
                of fields. This is checked with ``assert``, so it is skipped
                under ``python -O``.
        """
        names = list(cls.model_fields)
        assert len(values) == len(names), (
            f"{cls.__name__} expects {len(names)} values, got {len(values)}"
        )
        return cls(**dict(zip(names, values)))

    def update(self, **kwargs):
        """Write ``kwargs`` straight into the instance ``__dict__``.

        Values are assigned without pydantic validation, and keys that are
        not declared fields are stored as well. Before that, for every field
        whose annotation has an ``update`` attribute, ``annotation.update``
        is called with that field's ``FieldInfo``.
        """
        for field_info in type(self).model_fields.values():
            # NOTE(review): this calls the annotation's *unbound* ``update``
            # with the FieldInfo as its first argument. That looks
            # unintended: e.g. ``dict.update(FieldInfo)`` raises TypeError.
            # Confirm the intended behaviour before relying on it.
            if hasattr(field_info.annotation, "update"):
                field_info.annotation.update(field_info)
        self.__dict__.update(kwargs)
class NumpyVectorModel(NumpyModel):
    """Vector model using numpy arrays.

    The declared fields are the vector components, in declaration order,
    and iterating yields the component values. Equality compares
    component-wise against any iterable. Comparing against ``0``, ``0.0``
    or ``None`` instead tests whether every component equals zero.
    """

    # Defining __eq__ in this class body makes Python set __hash__ to None.
    # That would silently undo the identity hash the base model provides,
    # so the inherited implementation is restored explicitly.
    __hash__ = NumpyModel.__hash__

    def __iter__(self) -> Iterator[Any]:
        """Iterate over the field values in declaration order."""
        return iter([getattr(self, name) for name in type(self).model_fields])

    @staticmethod
    def _is_zero_like(other: Any) -> bool:
        """Return True if ``other`` is ``None`` or compares equal to zero.

        Element-wise comparisons (e.g. multi-element numpy arrays) have no
        single truth value. They are reported as "not zero-like", so the
        caller falls through to the component-wise comparison.
        """
        if other is None:
            return True
        try:
            return bool(other == 0 or other == 0.0)
        except (ValueError, TypeError):
            return False

    def __eq__(self, other: Any) -> bool:
        """Compare component-wise, or test for all-zero against 0/0.0/None."""
        if self._is_zero_like(other):
            return all(value == 0 for value in self)
        try:
            other_values = list(other)
        except TypeError:
            # Not iterable: defer to the other operand or Python's identity
            # fallback instead of raising from inside ``==``.
            return NotImplemented
        return list(self) == other_values

    def __ne__(self, other: Any) -> bool:
        """Negation of ``__eq__``; ``NotImplemented`` is passed through."""
        result = self.__eq__(other)
        if result is NotImplemented:
            return NotImplemented
        return not result

    # Backward-compatible alias for the previous, misspelled method name.
    # Python only ever calls ``__ne__`` for ``!=``.
    __neq__ = __ne__
# NOTE(review): ``IgnoreExtra`` was used as the base class here, but it is
# neither defined nor imported in this module, so the class statement
# raised NameError. ``ModelBase`` (docstring: "Base Model that ignores extra
# fields.") is used instead. The old name is kept as an alias for any code
# that imports it. Confirm this matches the intended base.
IgnoreExtra = ModelBase


class objectList(ModelBase):
    """Model that behaves like the list held in its first declared field.

    Iteration, ``str()`` and ``repr()`` all delegate to the value of the
    first field in declaration order. A model with no fields behaves like
    an empty list.
    """

    def _items(self) -> Any:
        """Return the value of the first declared field, or [] if none."""
        first = next(iter(type(self).model_fields), None)
        return [] if first is None else getattr(self, first)

    def __iter__(self) -> Iterator[Any]:
        return iter(self._items())

    def __str__(self) -> str:
        return str(list(self._items()))

    def __repr__(self) -> str:
        return repr(list(self._items()))
class DeviceList(objectList):
    """``objectList`` whose list-like behaviour comes from ``devices``."""

    devices: list = []
class Aliases(objectList):
    """``objectList`` whose list-like behaviour comes from ``aliases``."""

    aliases: list = []