# Source code for posydon.visualization.plot2D

"""Plotting class for 2D (MESA) psygrids.

The 2D visualization plotting class allows one to plot 2D slices of PsyGrid
objects. A PsyGrid object holds a 2D/3D/4D grid of MESA runs computed with
POSYDON and post-processed by the psygrid object into an h5 file.
"""

__authors__ = [
    "Simone Bavera <>",
    "Emmanouil Zapartas <>",
    "Konstantinos Kovlakas <>",
]

import numpy as np
import matplotlib.pyplot as plt
from posydon.utils.gridutils import add_field
from posydon.visualization.plot_defaults import DEFAULT_MARKERS_COLORS_LEGENDS
from posydon.visualization.plot_defaults import PLOT_PROPERTIES
from posydon.visualization.plot_defaults import DEFAULT_LABELS
from posydon.visualization.combine_TF import combine_TF12
import copy

class plot2D(object):
    """Plotting class for 2D (MESA) grids.

    The constructor reads a PsyGrid, validates the requested variables and
    pre-computes the values of the (possibly 3D/4D-sliced) plane. Calling the
    instance draws the figure.
    """

    # Termination-flag columns accepted verbatim for a single-panel plot
    # (columns whose name contains "SN_type" or "state" are accepted too).
    _SINGLE_PANEL_FLAGS = (
        "termination_flag_1",
        "termination_flag_2",
        "termination_flag_3",
        "termination_flag_4",
        "combined_TF12",
        "debug",
        "interpolation_class",
    )

    # One panel per column when termination_flag="all".
    _ALL_PANEL_FLAGS = (
        "termination_flag_1",
        "termination_flag_2",
        "termination_flag_3",
        "termination_flag_4",
    )

    # Attributes describing the grid currently being drawn. They are saved
    # before an extra grid is overlaid and restored afterwards, so calling
    # the instance again still plots the main grid.
    _PER_GRID_STATE = (
        "psygrid",
        "initial_values",
        "final_values",
        "termination_flag",
        "termination_flag_str",
        "x_var",
        "y_var",
        "z_var",
        "x_var_oRLO",
        "y_var_oRLO",
    )

    def __init__(
        self,
        psygrid,
        x_var_str,
        y_var_str,
        z_var_str=None,
        selected_star_history_for_z_var=1,
        termination_flag="termination_flag_1",
        grid_3D=False,
        slice_3D_var_str=None,
        slice_3D_var_range=None,
        grid_4D=False,
        slice_4D_var_str=None,
        slice_4D_var_range=None,
        extra_grid=None,
        slice_at_RLO=False,
        MARKERS_COLORS_LEGENDS=None,
        verbose=False,
        **kwargs
    ):
        """Read a PsyGrid object and prepare a 2D slice of x vs y.

        Parameters
        ----------
        psygrid : object
            PsyGrid object containing a 2D/3D/4D MESA grid.
        x_var_str : str
            Name of the quantity on the x axis. It must be one of
            `psygrid.initial_values.dtype.names`. With `slice_at_RLO` it must
            also be a column of the runs' `binary_history`.
        y_var_str : str
            Same as `x_var_str`, for the y axis.
        z_var_str : str or numpy.ndarray, optional
            Quantity displayed as marker colour. Allowed strings are
            `psygrid.final_values.dtype.names`, names of the form
            "relative_change_<initial value>", `psygrid.history1.dtype.names`
            or `psygrid.history2.dtype.names` (depending on
            `selected_star_history_for_z_var`), and
            `psygrid.binary_history.dtype.names`. Without `slice_at_RLO`, an
            array with one value per run can be passed instead.
        selected_star_history_for_z_var : int
            Accepted values: 1 or 2. If `z_var_str` is a column of history1
            or history2, this selects which of the two is read.
        termination_flag : str
            Column used to choose marker styles. Allowed values are
            "termination_flag_1", "termination_flag_2", "termination_flag_3",
            "termination_flag_4", "combined_TF12", "debug",
            "interpolation_class", any column whose name contains "SN_type"
            or "state", and "all" (one panel per termination_flag_1..4).
        grid_3D : bool
            If `True`, the psygrid object is a 3D grid and needs to be sliced.
        slice_3D_var_str : str
            Variable along which the 3D space is sliced; one of
            `psygrid.initial_values.dtype.names`.
        slice_3D_var_range : tuple
            Inclusive range of `slice_3D_var_str` to keep, e.g. `(2.5, 3.)`.
        grid_4D : bool
            If `True` (together with `grid_3D`), the psygrid object is a 4D
            grid and is sliced along both slice variables.
        slice_4D_var_str : str
            Variable along which the 4D space is sliced; one of
            `psygrid.initial_values.dtype.names`.
        slice_4D_var_range : tuple
            Inclusive range of `slice_4D_var_str` to keep, e.g. `(2.5, 3.)`.
        extra_grid : object
            A single additional PsyGrid (e.g. a rerun subset of the grid or an
            extension) that is overlaid in single-panel mode. Lists are not
            supported.
        slice_at_RLO : bool
            If `True`, plot each run's track up to the onset of Roche-lobe
            overflow (RLO).
        MARKERS_COLORS_LEGENDS : dict
            For each termination-flag column, maps every flag value to a
            sequence (marker, linewidths, color, label); a `None` color means
            "colour by z". Columns missing here fall back to
            `DEFAULT_MARKERS_COLORS_LEGENDS` in
            `posydon.visualization.plot_defaults`.
        verbose : bool
            If `True`, the object reports by printing to standard output.
        **kwargs : dict
            Extra visualisation options (cf. `PLOT_PROPERTIES` in
            `posydon.visualization.plot_defaults`). "colorbar", "legend1D" and
            "legend2D" are dicts merged over their defaults.

        Raises
        ------
        ValueError
            If an option or variable name is not available, if more than one
            extra grid is given, or if `termination_flag` is not supported.
        """
        self.psygrid = psygrid
        # info 4D/3D parameter space
        self.grid_3D = grid_3D
        self.slice_3D_var_str = slice_3D_var_str
        self.slice_3D_var_range = slice_3D_var_range
        self.grid_4D = grid_4D
        self.slice_4D_var_str = slice_4D_var_str
        self.slice_4D_var_range = slice_4D_var_range
        self.slice_at_RLO = slice_at_RLO
        self.verbose = verbose
        # kept so that every panel of termination_flag="all" honours it too
        self._user_markers_colors_legends = MARKERS_COLORS_LEGENDS

        # store the extra psygrid (only a single one is supported)
        if isinstance(extra_grid, list):
            raise ValueError(
                "We support only one extra psygrid at the moment!")
        self.extra_grid = extra_grid

        # read kwargs
        self._read_plot_properties(kwargs)
        # plotting fonts (this updates matplotlib's global rcParams)
        plt.rcParams.update(self.rcParams)

        # store the initial/final values together with derived columns
        self.initial_values = self.psygrid.initial_values
        self.add_properties_to_initial_values()
        self.initial_values_str = self.initial_values.dtype.names
        self.final_values = self.psygrid.final_values
        self._prepare_final_values(termination_flag)
        self.final_values_str = self.final_values.dtype.names

        # x, y and z variables must exist
        self.x_var_str = self._check_axis_var("x_var_str", x_var_str)
        self.y_var_str = self._check_axis_var("y_var_str", y_var_str)
        if selected_star_history_for_z_var not in [1, 2]:
            raise ValueError(
                "selected_star_history_for_z_var should be either 1 or 2")
        self.selected_star_history_for_z_var = (
            selected_star_history_for_z_var)
        self._set_z_var_source(z_var_str)

        # get values to plot
        single_panel = (termination_flag in self._SINGLE_PANEL_FLAGS
                        or 'SN_type' in termination_flag
                        or 'state' in termination_flag)
        if single_panel:
            self.all_termination_flags = False
            if 'SN_type' in termination_flag:
                legend_key = 'SN_type'
            elif 'state' in termination_flag:
                legend_key = 'state'
            else:
                legend_key = termination_flag
            self.update_markers_colors_legends(legend_key,
                                               MARKERS_COLORS_LEGENDS)
            self.update_values_to_plot(termination_flag)
            self.extra_grid_termination_flag = termination_flag
        elif termination_flag == "all":
            self.all_termination_flags = True
        else:
            raise ValueError(
                'termination_flag must be "termination_flag_1", '
                '"termination_flag_2", "termination_flag_3", '
                '"termination_flag_4", "combined_TF12", "debug", '
                '"interpolation_class", a column containing "SN_type" or '
                '"state", or "all"!')

    def _read_plot_properties(self, kwargs):
        """Set every PLOT_PROPERTIES option from `kwargs` or its default."""
        for key in kwargs:
            if key not in PLOT_PROPERTIES:
                raise ValueError(key + " is not a valid parameter name!")
        for varname, default_value in PLOT_PROPERTIES.items():
            if (varname in ["colorbar", "legend1D", "legend2D"]
                    and varname in kwargs):
                # nested options: user sub-options override the defaults
                merged = {
                    sub_varname: kwargs[varname].get(sub_varname, sub_default)
                    for sub_varname, sub_default in default_value.items()
                }
                setattr(self, varname, merged)
            else:
                setattr(self, varname, kwargs.get(varname, default_value))

    def _prepare_final_values(self, termination_flag):
        """Add derived flag columns and clean termination_flag_2 if needed."""
        if termination_flag in ["combined_TF12", "debug"]:
            self.add_properties_to_final_values(termination_flag)
        if termination_flag == 'termination_flag_2':
            # remove "?" from the strings; work on a copy so that the
            # psygrid's own final_values array is left untouched
            final_values = copy.copy(self.final_values)
            final_values['termination_flag_2'] = [
                TF.replace('?', '')
                for TF in final_values['termination_flag_2']]
            self.final_values = final_values

    def _check_axis_var(self, name, var_str):
        """Return `var_str` if it can be plotted on an axis, else raise."""
        if not self.slice_at_RLO:
            if var_str not in self.initial_values_str:
                raise ValueError("{} = {} is not available in "
                                 "psygrid.initial_values".format(name,
                                                                 var_str))
        elif var_str not in self.psygrid[0].binary_history.dtype.names:
            raise ValueError("{} = {} is not available in "
                             "psygrid.binary_history".format(name, var_str))
        return var_str

    def _set_z_var_source(self, z_var_str):
        """Validate `z_var_str` and record where its values are read from.

        Sets `z_var_str`, `history` (star history) and `binary_history`
        flags; a user-supplied array is stored in `_user_z_var`.
        """
        self._user_z_var = None
        if z_var_str is None:
            self.z_var_str = None
            self.history = None
            self.binary_history = None
            return
        if not self.slice_at_RLO and isinstance(z_var_str, np.ndarray):
            # colour values supplied directly by the user
            self.z_var_str = None
            self._user_z_var = z_var_str
            self.z_var = z_var_str
            self.history = False
            self.binary_history = False
            return

        self.z_var_str = z_var_str
        self.history = False
        self.binary_history = False
        if not self.slice_at_RLO:
            if z_var_str in self.final_values_str:
                return
            if 'relative_change' in z_var_str:
                self.add_properties_to_final_values(None)
                return

        first_run = self.psygrid[0]
        if z_var_str in self._select_history(first_run).dtype.names:
            self.history = True
        elif z_var_str in first_run.binary_history.dtype.names:
            self.binary_history = True
        elif self.slice_at_RLO:
            raise ValueError(
                "z_var_str = {} is not available in psygrid.history1/2"
                " or psygrid.binary_history".format(z_var_str))
        else:
            raise ValueError(
                "z_var_str = {} is not available in "
                "psygrid.final_values or psygrid.history1/2 or "
                "psygrid.binary_history".format(z_var_str))

    def __call__(self):
        """Generate the plot when the class is called.

        Returns
        -------
        matplotlib.figure.Figure or None
            The figure, unless `close_fig` is set (then it is closed and
            `None` is returned).
        """
        fig = plt.figure(figsize=self.figsize)
        if self.all_termination_flags:
            for position, flag in enumerate(self._ALL_PANEL_FLAGS, start=1):
                ax = plt.subplot(2, 2, position)
                self.update_markers_colors_legends(
                    flag, self._user_markers_colors_legends)
                self.update_values_to_plot(flag)
                self.plot_panel(ax)
            # adjust spacing
            plt.subplots_adjust(wspace=self.wspace, hspace=self.hspace)
        else:
            ax = plt.subplot(111)
            self.plot_panel(ax)
            # add extra layer of grid on top of the plot
            if self.extra_grid is not None:
                self._plot_extra_grid(ax)
        # add title (suptitle in "all" mode, axes title otherwise)
        self.set_title(fig)
        # save figure
        if self.fname is not None:
            fig.savefig(self.path_to_file + self.fname,
                        dpi=self.dpi, bbox_inches=self.bbox_inches)
        # show figure
        if self.show_fig:
            plt.show()
        # close figure
        if self.close_fig:
            plt.close(fig)
            return None
        return fig

    def _plot_extra_grid(self, ax):
        """Overlay `extra_grid` on `ax`, then restore the main-grid state."""
        saved = {name: getattr(self, name) for name in self._PER_GRID_STATE
                 if hasattr(self, name)}
        try:
            # switch every grid-dependent input to the extra psygrid
            self.psygrid = self.extra_grid
            self.initial_values = self.extra_grid.initial_values
            self.final_values = self.extra_grid.final_values
            self.add_properties_to_initial_values()
            self._prepare_final_values(self.extra_grid_termination_flag)
            if (isinstance(self.z_var_str, str)
                    and 'relative_change' in self.z_var_str
                    and self.z_var_str not in self.final_values.dtype.names):
                self.add_properties_to_final_values(None)
            self.update_values_to_plot(self.extra_grid_termination_flag)
            self.plot_panel(ax, extra_grid_call=True)
        finally:
            for name, value in saved.items():
                setattr(self, name, value)

    def plot_panel(self, ax, extra_grid_call=False):
        """Plot the 2D panel.

        Parameters
        ----------
        ax : object
            matplotlib figure axes.
        extra_grid_call : bool
            If `True`, the panel is an overlay of the extra grid (no new
            colorbar is drawn).
        """
        scatters = []
        scatters_legend = []
        colorbar_mappable = None
        # plot figure by looping over termination_flag
        for flag in self.termination_flag_str:
            selection = self.termination_flag == flag
            style = self.MARKERS_COLORS_LEGENDS[flag]
            if style[2] is not None:
                z_values = None  # fixed colour taken from the style table
            elif self.z_var is not None:
                z_values = self.z_var[selection]
            else:
                # neither a fixed colour nor z values: nothing to draw
                continue

            if self.slice_at_RLO:
                sc = self._plot_tracks_until_RLO(ax, selection, style,
                                                 z_values)
            else:
                sc = ax.scatter(self.x_var[selection], self.y_var[selection],
                                **self._marker_kwargs(style, z_values))
            if sc is None:
                continue
            if z_values is not None:
                colorbar_mappable = sc
            # collect scatters for legend
            if style[3] not in scatters_legend:
                scatters.append(sc)
                scatters_legend.append(style[3])

        if colorbar_mappable is not None and not extra_grid_call:
            self.set_color_bar(colorbar_mappable)
        # add labels and legend
        self.set_xlabel()
        self.set_ylabel()
        self.set_xlim()
        self.set_ylim()
        self.set_legend(ax, scatters, scatters_legend)

    def _marker_kwargs(self, style, z_values):
        """Build `ax.scatter` keyword arguments for one termination flag."""
        kwargs = {"marker": style[0], "linewidths": style[1],
                  "s": self.marker_size}
        if z_values is None:
            kwargs["c"] = style[2]
        else:
            kwargs.update(c=z_values, vmin=self.zmin, vmax=self.zmax)
        return kwargs

    @staticmethod
    def _as_track(value):
        """Return a pre-RLO track as a float array, or None if unusable.

        `np.nan` placeholders (no RLO / failed run), empty tracks and tracks
        containing NaNs are all reported as None.
        """
        if isinstance(value, float):
            return None
        track = np.asarray(value, dtype=float)
        if track.size == 0 or np.isnan(track).any():
            return None
        return track

    def _plot_tracks_until_RLO(self, ax, selection, style, z_values):
        """Draw each selected run up to RLO; return the last scatter or None.

        Every run gets a black dot at its initial position. Runs with a
        usable pre-RLO track also get the black track, with the flag marker
        at its end; the other runs get the marker at the initial position.
        """
        sc = None
        x_start = self.x_var[selection]
        y_start = self.y_var[selection]
        x_tracks = self.x_var_oRLO[selection]
        y_tracks = self.y_var_oRLO[selection]
        for i in range(len(x_start)):
            kwargs = self._marker_kwargs(
                style, None if z_values is None else z_values[i])
            if z_values is not None:
                kwargs["alpha"] = 0.5
            ax.plot(x_start[i], y_start[i], marker=".", color="black")
            x_track = self._as_track(x_tracks[i])
            y_track = self._as_track(y_tracks[i])
            if x_track is not None and y_track is not None:
                ax.plot(x_track, y_track, color="black")
                sc = ax.scatter(x_track[-1], y_track[-1], **kwargs)
            else:
                sc = ax.scatter(x_start[i], y_start[i], **kwargs)
        return sc

    def add_properties_to_initial_values(self):
        """Add a "mass_ratio" (star_2_mass / star_1_mass) initial value."""
        old_initial_values = copy.copy(self.initial_values)
        mass_ratio = (old_initial_values["star_2_mass"]
                      / old_initial_values["star_1_mass"])
        new_initial_values = add_field(old_initial_values,
                                       [("mass_ratio", "<f8")])
        new_initial_values["mass_ratio"] = mass_ratio
        self.initial_values = new_initial_values

    def add_properties_to_final_values(self, termination_flag=None):
        """Add derived columns to the final values.

        Parameters
        ----------
        termination_flag : str or None
            "combined_TF12" adds the column built by `combine_TF12` from
            interpolation_class and termination_flag_2. "debug" adds a copy
            of termination_flag_1. Otherwise, if `z_var_str` is
            "relative_change_<key>", the column (final - initial) / initial
            of <key> is added. If none applies, nothing is changed.
        """
        old_final_values = copy.copy(self.final_values)
        z_var_str = getattr(self, "z_var_str", None)
        if termination_flag == "combined_TF12":
            combined_TF12 = combine_TF12(
                old_final_values['interpolation_class'],
                old_final_values['termination_flag_2'], self.verbose)
            new_final_values = add_field(old_final_values,
                                         [("combined_TF12", "<U70")])
            new_final_values["combined_TF12"] = combined_TF12
        elif termination_flag == "debug":
            new_final_values = add_field(old_final_values,
                                         [("debug", "<U70")])
            new_final_values["debug"] = old_final_values['termination_flag_1']
        elif isinstance(z_var_str, str) and 'relative_change' in z_var_str:
            key = z_var_str.split('relative_change_')[1]
            initial = self.initial_values[key]
            new_final_values = add_field(old_final_values,
                                         [(z_var_str, "<f8")])
            new_final_values[z_var_str] = (
                (old_final_values[key] - initial) / initial)
        else:
            # nothing to add
            return
        self.final_values = new_final_values

    def update_markers_colors_legends(
        self, termination_flag, MARKERS_COLORS_LEGENDS=None
    ):
        """Update markers, colors and legend.

        Parameters
        ----------
        termination_flag : string
            Key of the style table to use, e.g. "termination_flag_1", ...,
            "termination_flag_4", "SN_type" or "state".
        MARKERS_COLORS_LEGENDS : dict
            Each termination flag is associated with a marker shape,
            linewidth, color and label. Keys it does not contain fall back to
            `DEFAULT_MARKERS_COLORS_LEGENDS`.
        """
        if (MARKERS_COLORS_LEGENDS is None
                or termination_flag not in MARKERS_COLORS_LEGENDS):
            MARKERS_COLORS_LEGENDS = DEFAULT_MARKERS_COLORS_LEGENDS
        self.MARKERS_COLORS_LEGENDS = MARKERS_COLORS_LEGENDS[termination_flag]

    def update_values_to_plot(self, termination_flag):
        """Update all values to plot.

        Reads the flag column, computes x/y(/z) values, slices 3D/4D grids to
        the requested plane and fixes the colour range on first use.

        Parameters
        ----------
        termination_flag : string
            Column of the final values holding the flags to display.
        """
        # get termination flags
        self.termination_flag = self.final_values[termination_flag]
        # save values to plot
        self.get_x_var()
        self.get_y_var()
        if self.z_var_str is not None:
            self.get_z_var()
        else:
            # re-read the user array (if any) on every call, so an already
            # sliced z_var is never sliced a second time
            self.z_var = getattr(self, "_user_z_var",
                                 getattr(self, "z_var", None))

        # if 4D space: slice it to 2D
        if self.grid_4D and self.grid_3D:
            mask = (self._in_range(self.slice_4D_var_str,
                                   self.slice_4D_var_range)
                    & self._in_range(self.slice_3D_var_str,
                                     self.slice_3D_var_range))
            self._apply_mask(mask)
            if self.verbose:
                print("The 4D space was sliced along {} = {} and {} = {}.".
                      format(self.slice_4D_var_str, self.slice_4D_var_range,
                             self.slice_3D_var_str, self.slice_3D_var_range))
                print("")
                print("Total values to plot {}".
                      format(len(self.termination_flag)))
        # if 3D space: slice it to 2D
        elif self.grid_3D:
            mask = self._in_range(self.slice_3D_var_str,
                                  self.slice_3D_var_range)
            self._apply_mask(mask)
            if self.verbose:
                print(
                    "The 3D space was sliced along {} = {}.".format(
                        self.slice_3D_var_str, self.slice_3D_var_range
                    )
                )
                print("")
                print("Total values to plot {}.".
                      format(len(self.termination_flag)))

        # find all different termination flags
        self.termination_flag_str = np.unique(self.termination_flag)
        self._update_color_range()

    def _in_range(self, var_str, var_range):
        """Boolean mask of runs whose initial `var_str` is in `var_range`."""
        values = self.initial_values[var_str]
        return (values >= var_range[0]) & (values <= var_range[1])

    def _apply_mask(self, mask):
        """Keep only the runs selected by the boolean `mask`."""
        self.x_var = self.x_var[mask]
        self.y_var = self.y_var[mask]
        if self.z_var is not None:
            self.z_var = self.z_var[mask]
        if self.slice_at_RLO:
            self.x_var_oRLO = self.x_var_oRLO[mask]
            self.y_var_oRLO = self.y_var_oRLO[mask]
        self.termination_flag = self.termination_flag[mask]

    def _update_color_range(self):
        """Fill unset zmin/zmax from the finite z values (NaN/inf ignored)."""
        if self.z_var is None:
            return
        if self.zmin is not None and self.zmax is not None:
            return
        z_values = np.asarray(self.z_var, dtype=float)
        finite = z_values[np.isfinite(z_values)]
        if finite.size == 0:
            # nothing finite to scale to: let matplotlib pick the limits
            return
        if self.zmin is None:
            self.zmin = finite.min()
        if self.zmax is None:
            self.zmax = finite.max()

    @staticmethod
    def _index_onset_RLO(binary_history):
        """Return the first index of RLO in `binary_history`, or None.

        RLO onset is the first step where rl_relative_overflow_1 >= -0.05
        and lg_mtransfer_rate >= -12.
        """
        if binary_history is None:
            return None
        indices_RLO = np.where(
            (binary_history["rl_relative_overflow_1"] >= -0.05)
            & (binary_history["lg_mtransfer_rate"] >= -12)
        )[0]
        if len(indices_RLO) == 0:
            return None
        return indices_RLO[0]

    def _tracks_until_RLO(self, var_str, log10):
        """Per-run binary_history[var_str] up to RLO onset (np.nan if none).

        Returns a 1D object array with one entry per run.
        """
        values = []
        for run in self.psygrid:
            binary_history = run.binary_history
            # failed runs are stored as single values or None, not arrays
            if (binary_history is None
                    or not isinstance(binary_history[var_str], np.ndarray)):
                values.append(np.nan)
                continue
            index_oRLO = self._index_onset_RLO(binary_history)
            if index_oRLO is None:
                # no RLO along this run
                values.append(np.nan)
                continue
            track = binary_history[var_str][:index_oRLO + 1]
            values.append(np.log10(track) if log10 else track)
        # fill element by element: np.array(values, dtype=object) would build
        # a 2D array whenever all tracks happen to have the same length
        tracks = np.empty(len(values), dtype=object)
        for i, value in enumerate(values):
            tracks[i] = value
        return tracks

    def get_x_var(self):
        """Get x values to plot (and the pre-RLO tracks if slice_at_RLO)."""
        initial = self.initial_values[self.x_var_str]
        self.x_var = np.log10(initial) if self.log10_x else initial
        if self.slice_at_RLO:
            self.x_var_oRLO = self._tracks_until_RLO(self.x_var_str,
                                                     self.log10_x)

    def get_y_var(self):
        """Get y values to plot (and the pre-RLO tracks if slice_at_RLO)."""
        initial = self.initial_values[self.y_var_str]
        self.y_var = np.log10(initial) if self.log10_y else initial
        if self.slice_at_RLO:
            self.y_var_oRLO = self._tracks_until_RLO(self.y_var_str,
                                                     self.log10_y)

    def _select_history(self, run):
        """Return history1 or history2 of `run` for the selected star."""
        if self.selected_star_history_for_z_var == 1:
            return run.history1
        if self.selected_star_history_for_z_var == 2:
            return run.history2
        raise ValueError("wrong selected_star_history_for_z_var")

    def _z_value_of_run(self, run):
        """z of one run: its last value, or its value at RLO onset."""
        if self.history:
            source = self._select_history(run)
        else:
            source = run.binary_history
        # failed runs are stored as single values or None and not arrays
        if source is None or not isinstance(source[self.z_var_str],
                                            np.ndarray):
            return np.nan
        column = source[self.z_var_str]
        if not self.slice_at_RLO:
            return column[-1] if len(column) > 0 else np.nan
        # take the z_var at the onset of RLO
        index_oRLO = self._index_onset_RLO(run.binary_history)
        if index_oRLO is None or index_oRLO >= len(column):
            return np.nan
        return column[index_oRLO]

    def get_z_var(self):
        """Get z values to plot from the final values or the histories."""
        if self.history is None:
            raise ValueError("Something went wrong!")
        if self.history or self.binary_history:
            values = np.array([self._z_value_of_run(run)
                               for run in self.psygrid])
        elif not self.slice_at_RLO:
            # read final values from final_values
            values = self.final_values[self.z_var_str]
        else:
            raise ValueError("Something went wrong!")
        self.z_var = np.log10(values) if self.log10_z else values

    def set_title(self, fig):
        """Add title.

        Parameters
        ----------
        fig : object
            matplotlib figure object.
        """
        if self.title is None:
            return
        if self.all_termination_flags:
            fig.suptitle(self.title, fontdict=self.title_font_dict)
        else:
            plt.title(self.title, fontdict=self.title_font_dict,
                      loc=self.title_loc)

    @staticmethod
    def _default_label(var_str, log10):
        """Axis label from DEFAULT_LABELS, falling back to the name."""
        labels = DEFAULT_LABELS.get(var_str)
        if labels is None:
            return "log10({})".format(var_str) if log10 else var_str
        return labels[1] if log10 else labels[0]

    def set_xlabel(self):
        """Add x label (`xlabel` if set, else a DEFAULT_LABELS entry)."""
        if self.xlabel is not None:
            label = self.xlabel
        else:
            label = self._default_label(self.x_var_str, self.log10_x)
        plt.xlabel(label, **self.xlabel_kwargs)

    def set_ylabel(self):
        """Add y label (`ylabel` if set, else a DEFAULT_LABELS entry)."""
        if self.ylabel is not None:
            label = self.ylabel
        else:
            label = self._default_label(self.y_var_str, self.log10_y)
        plt.ylabel(label, **self.ylabel_kwargs)

    def set_xlim(self):
        """Set x axes limits when both xmin and xmax are given."""
        if self.xmin is not None and self.xmax is not None:
            plt.xlim(self.xmin, self.xmax)

    def set_ylim(self):
        """Set y axes limits when both ymin and ymax are given."""
        if self.ymin is not None and self.ymax is not None:
            plt.ylim(self.ymin, self.ymax)

    def set_legend(self, ax, scatters, scatters_legend):
        """Add legend.

        Parameters
        ----------
        ax : object
            matplotlib figure axes.
        scatters : list
            matplotlib scatter objects used as legend handles.
        scatters_legend : list of str
            List of strings which will be used as labels.
        """
        if self.legend2D["title"] is not None:
            # default: shrink the current axis and put the legend to the
            # right of it
            box = ax.get_position()
            ax.set_position(
                [box.x0, box.y0, box.width * self.legend2D["shrink_box"],
                 box.height])
            ax.legend(
                scatters,
                scatters_legend,
                borderaxespad=self.legend2D["borderaxespad"],
                handletextpad=self.legend2D["handletextpad"],
                columnspacing=self.legend2D["columnspacing"],
                title=self.legend2D["title"],
                title_fontsize=self.legend2D["title_font_size"],
                prop=self.legend2D["prop"],
                loc=self.legend2D["loc"],
                ncol=self.legend2D["ncol"],
                bbox_to_anchor=self.legend2D["bbox_to_anchor"],
            )

    def set_color_bar(self, scatter):
        """Add colorbar.

        Parameters
        ----------
        scatter : object
            matplotlib scatter object used as the colorbar mappable.
        """
        if self.colorbar["label"] is not None:
            label = self.colorbar["label"]
        elif isinstance(self.z_var_str, str):
            # star-specific prefixes share the label of the bare quantity
            z_var_str = self.z_var_str.replace('S1_', '').replace('S2_', '')
            if z_var_str in DEFAULT_LABELS.keys():
                if self.log10_z:
                    label = DEFAULT_LABELS[z_var_str][1]
                else:
                    label = DEFAULT_LABELS[z_var_str][0]
            else:
                label = None
        else:
            label = None
        plt.colorbar(
            mappable=scatter,
            orientation=self.colorbar["orientation"],
            fraction=self.colorbar["fraction"],
            pad=self.colorbar["pad"],
            shrink=self.colorbar["shrink"],
            aspect=self.colorbar["aspect"],
            anchor=self.colorbar["anchor"],
            panchor=self.colorbar["panchor"],
            extend=self.colorbar["extend"],
        ).set_label(label=label, size=self.colorbar["label_size"])