import numpy as np
import scipy
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.collections as mcoll
import matplotlib.patheffects as path_effects
class FixedKeysDict(dict):
    """A ``dict`` whose set of keys is frozen at construction time.

    Values stored under the initial keys may be reassigned, but no key can be
    added or removed afterwards: every operation that would do so raises
    ``KeyError``.

    Example::

        >>> d = FixedKeysDict(a=1)
        >>> d['a'] = 2      # fine: existing key
        >>> d['b'] = 3      # raises KeyError
    """

    def __init__(self, *args, **kwargs):
        # dict.__init__ fills the storage directly (it does not call
        # __setitem__), so the initial keys are accepted unconditionally.
        super().__init__(*args, **kwargs)
        self._frozen_keys = set(self.keys())  # Capture initial keys

    def _check_keys(self, keys):
        """Raise KeyError for the first key in *keys* that was not frozen."""
        for key in keys:
            if key not in self._frozen_keys:
                raise KeyError(f"Key '{key}' cannot be added.")

    def __setitem__(self, key, value):
        self._check_keys((key,))
        super().__setitem__(key, value)

    def __delitem__(self, key):
        raise KeyError(f"Key '{key}' cannot be deleted.")

    def pop(self, key, default=None):
        raise KeyError(f"Key '{key}' cannot be popped.")

    def popitem(self):
        raise KeyError("Cannot pop item from FixedKeysDict.")

    def clear(self):
        raise KeyError("Cannot clear FixedKeysDict.")

    def setdefault(self, key, default=None):
        # dict.setdefault inserts missing keys without going through
        # __setitem__, so it needs its own guard.
        self._check_keys((key,))
        return super().setdefault(key, default)

    def update(self, *args, **kwargs):
        # Materialise the arguments exactly once: *args may be a one-shot
        # iterator of (key, value) pairs, and validating it and then passing
        # it on again would hand super().update() an exhausted iterator.
        items = dict(*args, **kwargs)
        self._check_keys(items)
        super().update(items)

    def __ior__(self, other):
        # dict.__ior__ (``d |= other``) does not call the overridden update(),
        # so route it through the key check explicitly.
        self.update(other)
        return self

    def __reduce__(self):
        # Default dict-subclass pickling replays items through __setitem__
        # before instance state (and thus _frozen_keys) is restored, which
        # fails; rebuild through __init__ instead.
        return (type(self), (dict(self),))
class SetDict:
    """Recursively merge one (possibly nested) dictionary into another.

    Nested source dictionaries are merged key by key rather than assigned
    wholesale, so sub-dictionaries created in the target are new ``dict``
    objects (leaf values themselves are shared, not copied).

    Examples::

        >>> dT = {'a': 1, 'b': 2}
        >>> SetDict().set_dict(dT, {'a': 0, 'c': 0}, True)    # preserve
        >>> dT
        {'a': 1, 'b': 2, 'c': 0}

        >>> dT = {'a': 1, 'b': 2}
        >>> SetDict().set_dict(dT, {'a': 0, 'c': 0}, False)   # overwrite
        >>> dT
        {'a': 0, 'b': 2, 'c': 0}
    """

    def set_dict(self, dTarget, dSource, bPreserve):
        """Merge *dSource* into *dTarget* in place.

        Args:
            dTarget: Dictionary modified in place.
            dSource: Dictionary whose entries are copied into *dTarget*.
            bPreserve: If True, keys already present in *dTarget* keep their
                values; if False, values from *dSource* overwrite them.

        A source value that is a dict (including dict subclasses) is merged
        recursively. If the target holds a non-dict value under that key, the
        target value is kept when preserving and replaced by a merged copy of
        the source sub-dict when overwriting (previously this raised
        ``TypeError``).
        """
        for k, v in dSource.items():
            bKeyExists = k in dTarget
            if not isinstance(v, dict):
                # Leaf value: write it unless an existing entry is preserved.
                if not (bKeyExists and bPreserve):
                    dTarget[k] = v
                continue
            if not (bKeyExists and isinstance(dTarget[k], dict)):
                if bKeyExists and bPreserve:
                    # An existing non-dict leaf wins over the source sub-dict.
                    continue
                # Missing key, or a leaf being overwritten: start a fresh
                # sub-dict so the recursive merge has somewhere to write.
                dTarget[k] = {}
            self.set_dict(dTarget[k], v, bPreserve)

    def set_dict_with_preserve(self, dTarget, dSource):
        """Merge *dSource* into *dTarget*, keeping existing target values."""
        self.set_dict(dTarget, dSource, True)

    def set_dict_with_overwrite(self, dTarget, dSource):
        """Merge *dSource* into *dTarget*, overwriting existing target values."""
        self.set_dict(dTarget, dSource, False)
class LatexStates:
    """Map symbolic state-variable names to LaTeX (matplotlib mathtext) labels.

    ``self.dict`` holds the name -> LaTeX mapping. Entries can be added or
    overridden by passing a dictionary to the constructor.
    """

    def __init__(self, dict=None):
        """Build the default mapping, optionally updated with *dict*.

        Args:
            dict: Optional mapping of extra/overriding name -> LaTeX entries,
                merged into the defaults with overwrite semantics. (The name
                shadows the builtin but is kept for keyword callers.)
        """
        self.dict = {'v_para': r'$v_{\parallel}$',
                     'v_perp': r'$v_{\perp}$',
                     'phi': r'$\phi$',
                     'phidot': r'$\dot{\phi}$',
                     'phi_dot': r'$\dot{\phi}$',
                     'phiddot': r'$\ddot{\phi}$',
                     'w': r'$w$',
                     'zeta': r'$\zeta$',
                     'I': r'$I$',
                     'm': r'$m$',
                     'C_para': r'$C_{\parallel}$',
                     'C_perp': r'$C_{\perp}$',
                     'C_phi': r'$C_{\phi}$',
                     'km1': r'$k_{m_1}$',
                     'km2': r'$k_{m_2}$',
                     'km3': r'$k_{m_3}$',
                     'km4': r'$k_{m_4}$',
                     'd': r'$d$',
                     'psi': r'$\psi$',
                     'gamma': r'$\gamma$',
                     'alpha': r'$\alpha$',
                     'of': r'$\frac{g}{d}$',
                     'gdot': r'$\dot{g}$',
                     'v_para_dot': r'$\dot{v_{\parallel}}$',
                     'v_perp_dot': r'$\dot{v_{\perp}}$',
                     'v_para_dot_ratio': r'$\frac{\Delta v_{\parallel}}{v_{\parallel}}$',
                     'x': r'$x$',
                     'y': r'$y$',
                     'v_x': r'$v_{x}$',
                     'v_y': r'$v_{y}$',
                     'v_z': r'$v_{z}$',
                     'w_x': r'$w_{x}$',
                     'w_y': r'$w_{y}$',
                     'w_z': r'$w_{z}$',
                     'a_x': r'$a_{x}$',
                     'a_y': r'$a_{y}$',
                     'vx': r'$v_x$',
                     'vy': r'$v_y$',
                     'vz': r'$v_z$',
                     'wx': r'$w_x$',
                     'wy': r'$w_y$',
                     'wz': r'$w_z$',
                     # Subscripted like vx/wx above (was '$ax$', i.e. "a times x").
                     'ax': r'$a_x$',
                     'ay': r'$a_y$',
                     # Closing '$' was missing, so the label rendered as raw text.
                     'beta': r'$\beta$',
                     'thetadot': r'$\dot{\theta}$',
                     'theta_dot': r'$\dot{\theta}$',
                     'psidot': r'$\dot{\psi}$',
                     'psi_dot': r'$\dot{\psi}$',
                     'theta': r'$\theta$',
                     'Yaw': r'$\psi$',
                     'R': r'$\phi$',
                     'P': r'$\theta$',
                     'dYaw': r'$\dot{\psi}$',
                     'dP': r'$\dot{\theta}$',
                     'dR': r'$\dot{\phi}$',
                     'acc_x': r'$\dot{v}x$',
                     'acc_y': r'$\dot{v}y$',
                     'acc_z': r'$\dot{v}z$',
                     'Psi': r'$\Psi$',
                     'Ix': r'$I_x$',
                     'Iy': r'$I_y$',
                     'Iz': r'$I_z$',
                     'Jr': r'$J_r$',
                     'Dl': r'$D_l$',
                     'Dr': r'$D_r$',
                     }
        if dict is not None:
            SetDict().set_dict_with_overwrite(self.dict, dict)

    def convert_to_latex(self, list_of_strings, remove_dollar_signs=False):
        """Replace every string that is a key of ``self.dict`` with its LaTeX label.

        Args:
            list_of_strings: A single string, or an iterable of strings (list,
                tuple, numpy array, ...). The input is never modified.
            remove_dollar_signs: If True, strip ``$`` from the converted
                labels (strings without a matching key are left unchanged).

        Returns:
            A string when a single string was given, otherwise a new list.
            A list is built (rather than copying the input container) so that
            fixed-width numpy string arrays cannot truncate the longer LaTeX
            labels and tuples are accepted.
        """
        single = isinstance(list_of_strings, str)
        strings = [list_of_strings] if single else list_of_strings
        converted = []
        for s in strings:
            if s in self.dict:
                s = self.dict[s]
                if remove_dollar_signs:
                    s = s.replace('$', '')
            converted.append(s)
        return converted[0] if single else converted
def make_segments(x, y):
    """Split the polyline through the points (x, y) into two-point segments.

    For 1-D inputs of length N the result has shape ``(N - 1, 2, 2)``: entry
    ``i`` holds the start point ``(x[i], y[i])`` followed by the end point
    ``(x[i + 1], y[i + 1])``, the layout ``LineCollection`` accepts.
    """
    pts = np.array([x, y]).T.reshape(-1, 1, 2)  # one (1, 2) slot per vertex
    starts, ends = pts[:-1], pts[1:]
    # For these 3-D arrays hstack joins along axis 1, pairing start with end.
    return np.hstack((starts, ends))
def colorline(x, y, z, ax=None, cmap=plt.get_cmap('copper'), norm=None, linewidth=1.5, alpha=1.0):
    """Draw the polyline through (x, y) with segments colored by *z*.

    Args:
        x, y: Vertex coordinates of the line.
        z: Values mapped to colors through *norm* and *cmap*, set as the
            ``array`` of a ``LineCollection`` built from ``make_segments(x, y)``
            (``len(x) - 1`` segments). A plain number is wrapped into a
            length-1 array.
        ax: Axes to draw on; the current axes (``plt.gca()``) if None.
        cmap: Colormap for the segments.
        norm: Color normalization; defaults to ``Normalize(min(z), max(z))``.
        linewidth, alpha: Passed to the ``LineCollection``.

    Returns:
        The ``LineCollection`` that was added to *ax*.
    """
    # atleast_1d handles plain numbers and 0-d arrays alike (a 0-d array has
    # an ``__iter__`` attribute but cannot be iterated, which the previous
    # hasattr-based check missed).
    z = np.atleast_1d(np.asarray(z))

    if norm is None:
        norm = plt.Normalize(np.min(z), np.max(z))

    segments = make_segments(x, y)
    lc = mcoll.LineCollection(segments, array=z, cmap=cmap, norm=norm,
                              linewidth=linewidth, alpha=alpha,
                              # Round caps for smoother joints between segments.
                              path_effects=[path_effects.Stroke(capstyle="round")])

    if ax is None:
        ax = plt.gca()
    ax.add_collection(lc)
    return lc
def plot_heatmap_log_timeseries(data, ax=None, log_ticks=None, data_labels=None,
                                cmap='inferno_r', y_label=None,
                                aspect=0.25, interpolation=False):
    """Plot a time series as a heatmap with a logarithmic color scale.

    Rows of *data* are drawn along the y axis ("time steps") and columns along
    the x axis ("states"); a log-scaled colorbar is added to the right.

    Args:
        data: 2-D array-like indexed as ``data[time_step, state]``. Only
            positive values can be placed on the log color scale.
        ax: Axes to draw into; a new 5x4 inch, 150 dpi figure is made if None.
        log_ticks: ``(low, high)`` decade exponents of the color range, e.g.
            ``(-3, 1)`` for 1e-3 .. 1e1. Derived from the positive, finite
            values of *data* if None.
        data_labels: Column labels; defaults to ``0 .. n_states - 1``. Labels
            that are keys of ``LatexStates`` are rendered as LaTeX.
        cmap: Colormap name or object.
        y_label: Colorbar label; ``'values'`` if None.
        aspect: Aspect ratio passed to ``imshow``.
        interpolation: If truthy, upsample the time axis by this factor with
            linear interpolation in log10 space (``scipy.ndimage.zoom``).

    Returns:
        ``(cnorm, cmap, log_ticks)``: the ``LogNorm`` used, the *cmap*
        argument, and the array of colorbar tick values.

    Raises:
        ValueError: If *log_ticks* is None and *data* has no positive finite
            value to derive the color range from.
    """
    n_label = data.shape[1]

    # Color range in whole decades.
    if log_ticks is None:
        # Restrict to positive finite values: a zero or NaN would make the
        # int() of log10(min/max) below fail.
        values = np.asarray(data, dtype=float)
        positive = values[np.isfinite(values) & (values > 0)]
        if positive.size == 0:
            raise ValueError('data must contain a positive finite value '
                             'when log_ticks is not given')
        log_tick_low = int(np.floor(np.log10(positive.min())))
        log_tick_high = int(np.ceil(np.log10(positive.max())))
    else:
        log_tick_low, log_tick_high = log_ticks[0], log_ticks[1]
    log_ticks = np.logspace(log_tick_low, log_tick_high, log_tick_high - log_tick_low + 1)
    cnorm = mpl.colors.LogNorm(10 ** log_tick_low, 10 ** log_tick_high)

    if data_labels is None:
        data_labels = np.arange(0, n_label).tolist()
    data_labels = [str(x) for x in data_labels]

    if ax is None:
        fig, ax = plt.subplots(1, 1, figsize=(5, 4), dpi=150)
    else:
        # The colorbar belongs on the figure that owns *ax*, which need not be
        # the current pyplot figure.
        fig = ax.figure

    if interpolation:
        # Import the submodule explicitly; a bare `import scipy` does not make
        # scipy.ndimage available on every SciPy version.
        from scipy import ndimage
        data = 10 ** ndimage.zoom(np.log10(data), (interpolation, 1), order=1)
        aspect = aspect / interpolation
    ax.imshow(data, norm=cnorm, aspect=aspect, cmap=cmap, interpolation='none')

    ax.grid(True, axis='x')
    ax.tick_params(axis='both', which='both', labelsize=6, top=False, labeltop=True,
                   bottom=False, labelbottom=False, color='gray')

    # State labels on top: ticks at n - 0.5 (the left edge of pixel column n
    # under imshow's default extent), with left-aligned text.
    data_labels_latex = LatexStates().convert_to_latex(data_labels)
    ax.set_xticks(np.arange(0, len(data_labels)) - 0.5)
    ax.set_xticklabels(data_labels_latex)

    ax.set_ylabel('time steps', fontsize=7, fontweight='bold')
    ax.set_xlabel('states', fontsize=7, fontweight='bold')
    ax.xaxis.set_label_position('top')

    for tick in ax.get_xticklabels():
        tick.set_ha('left')
        tick.set_va('center')

    if y_label is None:
        y_label = 'values'
    cax = ax.inset_axes((1.03, 0.0, 0.04, 1.0))
    cbar = fig.colorbar(mpl.cm.ScalarMappable(norm=cnorm, cmap=cmap), cax=cax, ticks=log_ticks)
    cbar.set_label(y_label, rotation=270, fontsize=7, labelpad=8)
    cbar.ax.tick_params(labelsize=6)
    ax.spines[['bottom', 'top', 'left', 'right']].set_color('gray')

    return cnorm, cmap, log_ticks