# Source code for catmap.analyze.analysis_base

import catmap
from catmap import ReactionModelWrapper
from catmap.model import ReactionModel as RM
from catmap import griddata
from copy import copy
try:
    from scipy.stats import norm
except:
    norm = None
from matplotlib.ticker import MaxNLocator
import os
import math
plt = catmap.plt
pickle= catmap.pickle
np = catmap.np
spline = catmap.spline
mtransforms = catmap.mtransforms

# Default RGB palette (components in [0, 1]) used when no colours are given.
basic_colors = [
    [0, 0, 0],          # black
    [0, 0, 1],          # blue
    [0.1, 1, 0.1],      # green
    [1, 0, 0],          # red
    [0, 1, 1],          # cyan
    [1, 0.5, 0],        # orange
    [1, 0.9, 0],        # yellow
    [1, 0, 1],          # magenta
    [0, 0.5, 0.5],      # turquoise
    [0.5, 0.25, 0.15],  # brown
    [0.5, 0.5, 0.5],    # gray
]


def get_colors(n_colors):
    """
    Get n colors from basic_colors, cycling through the palette when more
    colors are requested than it contains.

    :param n_colors: Number of colors
    :type n_colors: int
    :returns: ``n_colors`` RGB triples. Each triple is a fresh list, so callers
        may modify the result without altering ``basic_colors``. A
        non-positive ``n_colors`` gives an empty list.
    :rtype: list
    """
    n_basic = len(basic_colors)
    return [list(basic_colors[i % n_basic]) for i in range(max(n_colors, 0))]
def boltzmann_vector(energy_list, vector_list, temperature):
    """
    Create a vector which is a Boltzmann average of the vector_list weighted
    with energies in the energy_list.

    Component ``j`` of the result is ``sum_i w_i * vector_list[i][j]`` with
    ``w_i = exp(-(E_i - E_min)/(kB*T)) / sum_k exp(-(E_k - E_min)/(kB*T))``.

    :param energy_list: List of energies (eV), one per vector
    :type energy_list: list
    :param vector_list: List of vectors, one per energy
    :type vector_list: list
    :param temperature: Temperature (K)
    :type temperature: float
    :returns: Boltzmann-weighted average vector
    :rtype: list
    """
    kB = 8.613e-5  # assuming energies are in eV and T is in K
    # NOTE(review): CODATA gives kB = 8.617333e-5 eV/K; the historical value is
    # kept so that results do not change.
    kT = kB * temperature
    # Shift by the minimum energy so the largest factor is exp(0) = 1 and the
    # exponentials cannot overflow.
    e_min = min(energy_list)
    factors = [np.exp(-(e - e_min) / kT) for e in energy_list]
    partition_sum = sum(factors)
    weights = [f / partition_sum for f in factors]
    # Weights are computed once and reused for every vector component.
    return [sum(w * n for w, n in zip(weights, component))
            for component in zip(*vector_list)]
class MapPlot:
    """
    Class for generating plots using a dictionary of default plotting
    attributes.

    The following attributes can be modified:

    :param resolution_enhancement: Resolution enhancement for interpolated maps
    :type resolution_enhancement: int

    :param min: Lower limit of the plotted values; when None it is taken from
        the data in ``plot_single``.
    :type min: float or None

    :param max: Upper limit of the plotted values; when None it is taken from
        the data in ``plot_single``.
    :type max: float or None

    :param n_ticks: Number of ticks
    :type n_ticks: int

    :param descriptor_labels: Label of descriptors
    :type descriptor_labels: list

    :param default_descriptor_pt_args: Dictionary of descriptor point arguments
    :type default_descriptor_pt_args: dict

    :param default_descriptor_label_args: Dictionary of descriptor labels
    :type default_descriptor_label_args: dict

    :param descriptor_pt_args: Per-descriptor-point keyword arguments passed
        to ``ax.errorbar``
    :type descriptor_pt_args: dict

    :param include_descriptors: Include the descriptors
    :type include_descriptors: bool

    :param plot_size: Size of the plot
    :type plot_size: int

    :param aspect: Value passed to ``ax.set_aspect`` when set
    :type aspect:

    :param subplots_adjust_kwargs: Dictionary of keyword arguments for adjusting matplotlib subplots
    :type subplots_adjust_kwargs: dict

    Several attributes read by the methods are not set here, for example
    ``resolution``, ``descriptor_ranges``, ``descriptor_dict``,
    ``map_plot_labels`` and ``color_list``. Presumably they are provided by the
    class this is combined with (e.g. a ReactionModelWrapper subclass); confirm
    against the callers.

    .. todo:: Some missing descriptions
    """

    def __init__(self):
        defaults = dict(resolution_enhancement=1,
                        min=None,
                        max=None,
                        n_ticks=8,
                        plot_function=None,
                        colorbar=True,
                        colormap=plt.cm.YlGnBu_r,
                        axis_label_decimals=2,
                        log_scale=False,
                        descriptor_labels=['X_descriptor', 'Y_descriptor'],
                        default_descriptor_pt_args={'marker': 'o'},
                        default_descriptor_label_args={},
                        descriptor_pt_args={},
                        descriptor_label_args={},
                        include_descriptors=False,
                        plot_size=4,
                        aspect=None,
                        subplots_adjust_kwargs={'hspace': 0.35,
                                                'wspace': 0.35,
                                                'bottom': 0.15})

        # Only fill in attributes that are missing or None, so values set
        # before __init__ runs (e.g. by a subclass) are kept.
        for key, val in defaults.items():
            if getattr(self, key, None) is None:
                setattr(self, key, val)

    def update_descriptor_args(self):
        """
        Fill empty per-descriptor argument dicts with copies of the defaults.

        If ``descriptor_pt_args`` / ``descriptor_label_args`` are empty, every
        key of ``descriptor_dict`` gets a shallow copy of
        ``default_descriptor_pt_args`` / ``default_descriptor_label_args``.
        """
        if getattr(self, 'descriptor_dict', None):
            if self.descriptor_pt_args == {}:
                for pt in self.descriptor_dict:
                    self.descriptor_pt_args[pt] = copy(
                        self.default_descriptor_pt_args)
            if self.descriptor_label_args == {}:
                for pt in self.descriptor_dict:
                    self.descriptor_label_args[pt] = copy(
                        self.default_descriptor_label_args)

    def plot_descriptor_pts(self, mapp, idx, ax, plot_in=None):
        """
        Plot descriptor points

        Marks (``ax.errorbar``) and annotates (``ax.annotate``) every entry of
        ``self.descriptor_dict`` on ``ax``. Then it resets the axis limits to
        ``self.descriptor_ranges``.

        :param mapp: map as an iterable of ``(descriptor_point, values)`` pairs
        :type mapp: iterable
        :param idx: index of the plotted column (not used here)
        :type idx: int
        :param ax: axes object
        :param plot_in: for 1-D maps, ``[x_values, y_values]`` of the plotted
            curve; the point's y value is interpolated from it
        :type plot_in: list
        """
        if getattr(self, 'descriptor_dict', None):
            self.update_descriptor_args()
            xy, rates = zip(*list(mapp))
            dim = len(xy[0])
            for key in self.descriptor_dict:
                pt_kwargs = self.descriptor_pt_args.get(
                    key, self.default_descriptor_pt_args)
                lab_kwargs = self.descriptor_label_args.get(
                    key, self.default_descriptor_label_args)
                if dim == 1:
                    # x will be descriptor values. y will be rate/coverage/etc.
                    x, y = self.descriptor_dict[key]
                    y_sp = catmap.spline(plot_in[0], plot_in[1], k=1)
                    y = y_sp(x)
                elif dim == 2:
                    x, y = self.descriptor_dict[key]
                if None not in [x, y]:
                    if pt_kwargs is not None:
                        ax.errorbar(x, y, **pt_kwargs)
                    if lab_kwargs is not None:
                        ax.annotate(key, [x, y], **lab_kwargs)
            if dim == 1:
                ax.set_xlim(self.descriptor_ranges[0])
            elif dim == 2:
                ax.set_xlim(self.descriptor_ranges[0])
                ax.set_ylim(self.descriptor_ranges[1])

    def plot_single(self, mapp, rxn_index, ax=None, overlay_map=None,
                    alpha_range=None, **plot_args):
        """
        Plot one output column of a map on a single axes.

        Uses a line plot for 1-D descriptor spaces and ``contourf`` (by
        default) for 2-D spaces.

        :param mapp: map as an iterable of ``(descriptor_point, values)`` pairs
        :param rxn_index: Index for the reaction (column of ``values``)
        :type rxn_index: int
        :param ax: axes object; a new figure/axes is created when not given
        :param overlay_map: optional map whose values (scaled into
            ``alpha_range``) weight the plotted values
        :param alpha_range: ``(alpha_min, alpha_max)`` window of the overlay
            values mapped onto weights 0..1; defaults to the overlay's range
        :type alpha_range: tuple
        :param plot_args: forwarded to the matplotlib plotting call for 2-D
            maps, except ``title``/``title_size``, which set the axes title
        :returns: the axes that was drawn on
        """
        if not ax:
            fig = plt.figure()
            ax = fig.add_subplot(111)

        xy, rates = zip(*list(mapp))
        dim = len(xy[0])
        if dim == 1:
            x = list(zip(*xy))[0]
            descriptor_ranges = [[min(x), max(x)]]
            if not self.plot_function:
                if self.log_scale == True:
                    self.plot_function = 'semilogy'
                else:
                    self.plot_function = 'plot'
        elif dim == 2:
            x, y = zip(*xy)
            descriptor_ranges = [[min(x), max(x)], [min(y), max(y)]]
            if not self.plot_function:
                self.plot_function = 'contourf'
            if 'cmap' not in plot_args:
                plot_args['cmap'] = self.colormap

        # NOTE(review): when self.resolution is a list, '*' repeats the list
        # rather than scaling it; the code below indexes eff_res[0]/[1].
        eff_res = self.resolution * self.resolution_enhancement
        if self.min:
            minval = self.min
        else:
            minval = None
        maparray = RM.map_to_array(mapp, descriptor_ranges, eff_res,
                                   log_interpolate=self.log_scale,
                                   minval=minval)
        if self.max is None:
            self.max = maparray.T[rxn_index].max()
        if self.min is None:
            self.min = maparray.T[rxn_index].min()

        if dim == 2:
            # Tell the colorbar which ends of the data are clipped.
            below = maparray.min() <= self.min
            above = maparray.max() >= self.max
            if below and above:
                plot_args['extend'] = 'both'
            elif below:
                plot_args['extend'] = 'min'
            elif above:
                plot_args['extend'] = 'max'
            elif 'extend' not in plot_args:
                plot_args['extend'] = 'neither'

        if self.log_scale and dim == 2:
            maparray = np.log10(maparray)
            min_val = np.log10(float(self.min))
            max_val = np.log10(float(self.max))
            if min_val < -200:
                min_val = max(maparray.min(), -200)
            elif max_val == np.inf:
                max_val = min(maparray.max(), 200)
        else:
            min_val = self.min
            max_val = self.max

        maparray = np.clip(maparray, min_val, max_val)

        log_scale = self.log_scale
        if overlay_map:
            overlay_array = RM.map_to_array(overlay_map,
                                            descriptor_ranges, eff_res)
            if alpha_range:
                alpha_min, alpha_max = alpha_range
            else:
                alpha_min = overlay_array.min()
                alpha_max = overlay_array.max()
            # Map the window [alpha_min, alpha_max] onto weights in [0, 1].
            overlay_array = (overlay_array - alpha_min) / (alpha_max - alpha_min)
            overlay_array = np.clip(overlay_array, 0, 1)
            # Normalise the (already clipped) map to [0, 1], weight it by the
            # overlay, then stretch the product back onto [min_val, max_val].
            norm_array = (maparray - maparray.min()) / (maparray.max() - maparray.min())
            weighted = norm_array * overlay_array
            weighted = (weighted - weighted.min()) / (weighted.max() - weighted.min())
            maparray = weighted * (max_val - min_val) + min_val

        # 'title'/'title_size' are handled below via ax.set_title and are not
        # keyword arguments of the matplotlib plotting function.
        draw_args = {k: v for k, v in plot_args.items()
                     if k not in ('title', 'title_size')}

        if dim == 1:
            x_range = descriptor_ranges[0]
            plot_in = [np.linspace(*x_range + eff_res), maparray[:, rxn_index]]
            plot = getattr(ax, self.plot_function)(*plot_in)
        elif dim == 2:
            x_range, y_range = descriptor_ranges
            z = maparray[:, :, rxn_index]
            if self.log_scale:
                levels = range(int(min_val), int(max_val) + 1)
                if len(levels) < 3 * self.n_ticks:
                    levels = np.linspace(
                        int(min_val), int(max_val), 3 * self.n_ticks)
            else:
                # eff_res may be a plain int or a per-axis list; use the first
                # entry in the latter case (int < list fails in Python 3).
                n_levels = eff_res if isinstance(eff_res, int) else eff_res[0]
                levels = np.linspace(min_val, max_val, min(n_levels, 25))

            plot_in = [np.linspace(*x_range + [eff_res[0]]),
                       np.linspace(*y_range + [eff_res[1]]), z, levels]
            plot = getattr(ax, self.plot_function)(*plot_in, **draw_args)

        if self.aspect:
            ax.set_aspect(self.aspect)
            ax.apply_aspect()

        if dim == 1:
            ax.set_xlim(descriptor_ranges[0])
            ax.set_xlabel(self.descriptor_labels[0])
            ax.set_ylim([float(self.min), float(self.max)])
        elif dim == 2:
            if self.colorbar:
                if log_scale:
                    # take only integer tick labels
                    cbar_nums = range(int(min_val), int(max_val) + 1)
                    mod = max(int(len(cbar_nums) / self.n_ticks), 1)
                    cbar_nums = np.array(
                        [n for i, n in enumerate(cbar_nums) if not i % mod])
                else:
                    cbar_nums = np.linspace(min_val, max_val, self.n_ticks)
                formatstring = '%.' + str(self.axis_label_decimals) + 'g'
                cbar_labels = [formatstring % (s,) for s in cbar_nums]
                cbar_labels = [lab.replace('e-0', 'e-').replace('e+0', 'e')
                               for lab in cbar_labels]
                plot.set_clim(min_val, max_val)
                fig = ax.get_figure()
                # Place a thin colorbar just to the right of the axes.
                axpos = list(ax.get_position().bounds)
                xsize = axpos[2] * 0.04
                ysize = axpos[3]
                xp = axpos[0] + axpos[2] + 0.04 * axpos[2]
                yp = axpos[1]
                cbar_ax = fig.add_axes([xp, yp, xsize, ysize])
                cbar = fig.colorbar(mappable=plot, ticks=cbar_nums,
                                    cax=cbar_ax, extend=plot_args['extend'])
                cbar.ax.set_yticklabels(cbar_labels)
                if getattr(self, 'colorbar_label', None):
                    cbar_kwargs = getattr(self, 'colorbar_label_kwargs',
                                          {'rotation': -90})
                    cbar_ax.set_ylabel(self.colorbar_label, **cbar_kwargs)
            if self.descriptor_labels:
                ax.set_xlabel(self.descriptor_labels[0])
                ax.set_ylabel(self.descriptor_labels[1])
            ax.set_xlim(descriptor_ranges[0])
            ax.set_ylim(descriptor_ranges[1])

        if 'title' in plot_args and plot_args['title']:
            if 'title_size' not in plot_args:
                # Shrink long titles so they fit in the plot width (72 pt/inch).
                n_pts = self.plot_size * 72
                font_size = min([n_pts / len(plot_args['title']), 14])
            else:
                font_size = plot_args['title_size']
            ax.set_title(plot_args['title'], size=font_size)

        if getattr(self, 'n_xticks', None):
            ax.xaxis.set_major_locator(MaxNLocator(self.n_xticks))
        if getattr(self, 'n_yticks', None):
            ax.yaxis.set_major_locator(MaxNLocator(self.n_yticks))

        self.plot_descriptor_pts(mapp, rxn_index, ax=ax, plot_in=plot_in)
        return ax

    def plot_separate(self, mapp, ax_list=None, indices=None,
                      overlay_map=None, **plot_single_kwargs):
        """
        Generate separate plots, one ``plot_single`` call per selected column.

        :param mapp: map as an iterable of ``(descriptor_point, values)`` pairs
        :param ax_list: axes to draw on; a near-square grid is created if empty
        :param indices: columns to plot (default: all)
        :param overlay_map: forwarded to ``plot_single``
        :param plot_single_kwargs: forwarded to ``plot_single``; titles are
            taken from ``self.map_plot_labels`` when that attribute is set
        :returns: the figure containing the axes
        """
        pts, rates = zip(*list(mapp))
        if indices is None:
            indices = range(0, len(rates[0]))
        n_plots = len(indices)

        if not ax_list:
            # Smallest near-square grid with at least n_plots cells.
            x = int(np.sqrt(n_plots))
            if x * x < n_plots:
                y = x + 1
            else:
                y = x
            if x * y < n_plots:
                x = x + 1
            if self.colorbar:
                fig = plt.figure(
                    figsize=(y * self.plot_size * 1.25, x * self.plot_size))
            else:
                fig = plt.figure(figsize=(y * self.plot_size, x * self.plot_size))
            ax_list = [fig.add_subplot(x, y, i + 1) for i in range(n_plots)]
        else:
            fig = ax_list[0].get_figure()

        if not fig:
            fig = plt.gcf()
        fig.subplots_adjust(**self.subplots_adjust_kwargs)

        # Snapshot the attributes so every subplot starts from the same
        # settings: plot_single fills in self.min/self.max/self.plot_function
        # from the data it draws, and those values are reset before the next
        # subplot (each map is therefore scaled independently).
        base_state = copy(self.__dict__)
        labels = getattr(self, 'map_plot_labels', None)
        for plotnum, i in enumerate(indices):
            kwargs = dict(plot_single_kwargs)
            if labels:
                try:
                    kwargs['title'] = labels[i]
                except IndexError:
                    kwargs['title'] = ''
            kwargs['overlay_map'] = overlay_map
            self.__dict__.update(base_state)
            self.plot_single(mapp, i, ax=ax_list[plotnum], **kwargs)
        return fig

    def plot_weighted(self, mapp, ax=None, weighting='linear',
                      second_map=None, indices=None, **plot_args):
        """
        Generate weighted plot

        Each point's colour is a mix of ``self.color_list`` (default: palette
        from ``get_colors`` without black). Each colour is weighted by the
        point's value in the corresponding column. The colours are
        interpolated onto a grid with ``griddata`` and shown with
        ``ax.imshow``.

        :param mapp: map as an indexable sequence of ``(point, values)`` pairs
        :type mapp: list
        :param ax: axes object; a new figure/axes is created when None
        :param weighting: weighting function, 'linear' or 'dual'.
        :type weighting: str
        :param second_map: for 'dual' weighting, a map whose values multiply
            those of ``mapp``
        :param indices: columns used for the colour mixing (default: all)
        :returns: the figure
        :raises ValueError: for an unknown ``weighting``
        """
        if weighting == 'linear':
            eff_res = self.resolution * self.resolution_enhancement
        elif weighting == 'dual':
            eff_res = 300
        else:
            raise ValueError(
                "weighting must be 'linear' or 'dual', got %r" % (weighting,))

        if ax is None:
            fig = plt.figure()
            ax = fig.add_subplot(111)
        else:
            fig = ax.get_figure()

        color_list = getattr(self, 'color_list', None)
        if color_list is None:
            color_list = get_colors(len(mapp[0][-1]) + 1)
            color_list.pop(0)  # remove black

        pts, datas = zip(*list(mapp))
        if indices is None:
            indices = range(0, len(datas[0]))

        # Every zip is materialised: in Python 3 a zip is a one-shot iterator,
        # and without second_map datas2 is the same object as datas, so
        # zip(datas, datas2) would otherwise consume two points per step.
        datas = list(zip(*[col for col_id, col in enumerate(zip(*datas))
                           if col_id in indices]))
        if second_map:
            pts2, datas2 = zip(*list(second_map))
            datas2 = list(zip(*[col for col_id, col in enumerate(zip(*datas2))
                                if col_id in indices]))
        else:
            datas2 = datas

        rs, gs, bs = zip(*color_list)
        rgbs = []
        for data, data2 in zip(datas, datas2):
            # Start from white and subtract each colour's complement,
            # scaled by the (product of the) weights.
            if weighting == 'linear':
                r = 1 - sum(float((1 - ri) * di) for ri, di in zip(rs, data))
                g = 1 - sum(float((1 - gi) * di) for gi, di in zip(gs, data))
                b = 1 - sum(float((1 - bi) * di) for bi, di in zip(bs, data))
            else:
                r = 1 - sum(float((1 - ri) * di * d2i)
                            for ri, di, d2i in zip(rs, data, data2))
                g = 1 - sum(float((1 - gi) * di * d2i)
                            for gi, di, d2i in zip(gs, data, data2))
                b = 1 - sum(float((1 - bi) * di * d2i)
                            for bi, di, d2i in zip(bs, data, data2))
            rgbs.append([r, g, b])

        r, g, b = zip(*rgbs)
        x, y = zip(*pts)
        xi = np.linspace(min(x), max(x), eff_res)
        yi = np.linspace(min(y), max(y), eff_res)
        ri = griddata(x, y, r, xi, yi)
        gi = griddata(x, y, g, xi, yi)
        bi = griddata(x, y, b, xi, yi)
        rgb_array = np.zeros((eff_res, eff_res, 3))
        for i in range(0, eff_res):
            for j in range(0, eff_res):
                rgb_array[i, j, 0] = ri[i, j]
                rgb_array[i, j, 1] = gi[i, j]
                rgb_array[i, j, 2] = bi[i, j]

        xminmax, yminmax = self.descriptor_ranges
        xmin, xmax = xminmax
        ymin, ymax = yminmax
        ax.imshow(rgb_array, extent=[xmin, xmax, ymin, ymax], origin='lower')

        # 'i' is the last grid row index from the loop above (kept from the
        # original call); plot_descriptor_pts does not use its idx argument.
        self.plot_descriptor_pts(mapp, i, ax)
        if getattr(self, 'n_xticks', None):
            ax.xaxis.set_major_locator(MaxNLocator(self.n_xticks))
        if getattr(self, 'n_yticks', None):
            ax.yaxis.set_major_locator(MaxNLocator(self.n_yticks))
        ax.set_xlabel(self.descriptor_labels[0])
        ax.set_ylabel(self.descriptor_labels[1])
        if self.aspect:
            ax.set_aspect(self.aspect)
            ax.apply_aspect()
        return fig

    def save(self, fig, save=True, default_name='map_plot.pdf'):
        """
        Save ``fig`` to disk.

        :param fig: figure object
        :param save: True to save to ``self.output_file`` (or
            ``default_name`` if that attribute is missing), a filename to save
            there, or a false value to skip saving
        :type save: bool or str
        :param default_name: default name for the saved figure.
        :type default_name: str
        """
        if save == True:
            if not hasattr(self, 'output_file'):
                save = default_name
            else:
                save = self.output_file
        if save:
            fig.savefig(save)
class MechanismPlot:
    """
    Class for generating potential energy diagrams

    :param energies: list of energies
    :type energies: list
    :param barriers: list of barriers, one per step between consecutive
        energies (defaults to all zero, i.e. straight connecting lines)
    :type barriers: list
    :param labels: list of labels
    :type labels: list
    """

    def __init__(self, energies, barriers=None, labels=None):
        self.energies = energies
        # None defaults avoid sharing one mutable list between instances.
        self.barriers = barriers if barriers is not None else []
        self.labels = labels if labels is not None else []
        self.energy_line_args = {'color': 'k', 'lw': 2}
        self.barrier_line_args = {'color': 'k', 'lw': 2}
        self.label_args = {'color': 'k', 'size': 16, 'rotation': 45}
        self.label_positions = None
        self.initial_energy = 0
        self.initial_stepnumber = 0
        self.energy_mode = 'relative'  # or 'absolute'
        self.energy_line_widths = 0.5

    def draw(self, ax=None):
        """
        Draw the potential energy diagram.

        Energy levels are drawn as horizontal segments. Consecutive levels are
        joined by straight lines or by a quadratic spline through the
        transition state. ``self.labels`` are placed according to
        ``self.label_positions``. The geometry is stored on
        ``self.energy_lines`` and ``self.barrier_lines``.

        The attributes ``energy_line_args``, ``energy_line_widths``,
        ``label_args``, ``label_positions`` (one per energy) and
        ``barrier_line_args`` (one per step) may be single values or lists of
        the required length.

        :param ax: axes object to draw on; a new figure/axes is created when
            None
        :raises ValueError: if one of those attributes is a list of the wrong
            length
        """
        if ax is None:
            fig = plt.figure()
            ax = fig.add_subplot(111)

        def attr_to_list(attrname, required_length=len(self.energies)):
            """
            Return attribute ``attrname`` as a list of ``required_length``.

            Indexable sequences (other than strings) are returned unchanged
            and must already have the required length. Anything else
            (scalars, dicts, None, strings) is replicated.

            :param attrname: Name of attribute
            :type attrname: str
            :param required_length: Required length for the list of attributes
            :type required_length: int
            :raises ValueError: for a sequence of the wrong length
            """
            value = getattr(self, attrname)
            try:
                value[0]  # Ensure that it is a list
                iter(value)  # Ensure that it is a list...
            except (TypeError, IndexError, KeyError):
                return [value] * required_length
            if isinstance(value, str):
                return [value] * required_length
            if len(value) != required_length:
                raise ValueError(attrname + ' list is of length ' +
                                 str(len(value)) +
                                 ', but needs to be of length ' +
                                 str(required_length))
            return value

        barrier_line_args = attr_to_list('barrier_line_args',
                                         len(self.energies) - 1)
        energy_line_widths = attr_to_list('energy_line_widths')
        energy_line_args = attr_to_list('energy_line_args')
        label_args = attr_to_list('label_args')
        label_positions = attr_to_list('label_positions')

        # plot energy lines
        energy_list = np.array(self.energies)
        energy_list = list(energy_list - energy_list[0])
        if self.energy_mode == 'relative':
            # Treat the (shifted) energies as step differences.
            cum_energy = [energy_list[0]]
            for i, e in enumerate(energy_list[1:]):
                cum_energy.append(cum_energy[i] + e)
            energy_list = cum_energy
        energy_list = list(np.array(energy_list) + self.initial_energy)

        energy_lines = [
            [[i + self.initial_stepnumber, i + width + self.initial_stepnumber],
             [energy_list[i]] * 2]
            for i, width in enumerate(energy_line_widths)]
        self.energy_lines = energy_lines
        for i, line in enumerate(energy_lines):
            ax.plot(*line, **energy_line_args[i])

        # create barrier lines
        barrier_lines = []
        if not self.barriers:
            self.barriers = [0] * (len(self.energies) - 1)
        for i, barrier in enumerate(self.barriers):
            xi = energy_lines[i][0][1]
            xf = energy_lines[i + 1][0][0]
            yi = energy_lines[i][1][0]
            yf = energy_lines[i + 1][1][0]
            if self.energy_mode == 'relative' and (barrier == 0 or barrier <= yf - yi):
                line = [[xi, xf], [yi, yf]]
            elif self.energy_mode == 'absolute' and (barrier <= yf or barrier <= yi):
                line = [[xi, xf], [yi, yf]]
            else:
                if self.energy_mode == 'relative':
                    yts = yi + barrier
                elif self.energy_mode == 'absolute':
                    # In absolute mode the barrier value is the TS y value.
                    yts = barrier
                    barrier = yts - yi
                barrier_rev = barrier + (yi - yf)
                if barrier > 0 and barrier_rev > 0:
                    # Place the TS closer to the side with the smaller barrier.
                    ratio = np.sqrt(barrier) / (np.sqrt(barrier) + np.sqrt(barrier_rev))
                else:
                    print('Warning: Encountered barrier less than 0')
                    ratio = 0.0001
                    yts = max(yi, yf)
                xts = xi + ratio * (xf - xi)
                f = spline([xi, xts, xf], [yi, yts, yf], k=2)
                newxs = np.linspace(xi, xf, 20)
                line = [newxs, f(newxs)]
            barrier_lines.append(line)
        self.barrier_lines = barrier_lines

        # plot barrier lines
        for i, line in enumerate(barrier_lines):
            ax.plot(*line, **barrier_line_args[i])

        # add labels
        trans = ax.get_xaxis_transform()
        for i, label in enumerate(self.labels):
            xpos = sum(energy_lines[i][0]) / len(energy_lines[i][0])
            label_position = label_positions[i]
            # Copy: label_args entries may all be the same dict (e.g.
            # self.label_args replicated), and the branches below add keys
            # that must not leak into other labels or later draw() calls.
            args = copy(label_args[i])
            if label_position in ['top', 'ymax']:
                if 'ha' not in args:
                    args['ha'] = 'left'
                if 'va' not in args:
                    args['va'] = 'bottom'
                ypos = 1
                args['transform'] = trans
                ax.text(xpos, ypos, label, **args)
            elif label_position in ['bot', 'bottom', 'ymin']:
                ax.xaxis.set_ticks([float(sum(line[0]) / len(line[0]))
                                    for line in energy_lines])
                ax.set_xticklabels(self.labels)
                for attr in args.keys():
                    try:
                        [getattr(t, 'set_' + attr)(args[attr])
                         for t in ax.xaxis.get_ticklabels()]
                    except Exception:
                        # Best effort: skip properties tick labels cannot set.
                        pass
            elif label_position in ['omit']:
                pass
            else:
                ypos = energy_lines[i][1][0]
                if 'ha' not in args:
                    args['ha'] = 'left'
                if 'va' not in args:
                    args['va'] = 'bottom'
                ax.annotate(label, [xpos, ypos], **args)
class ScalingPlot:
    """
    :param descriptor_names: list of descriptor names
    :type descriptor_names: list

    :param descriptor_dict: dictionary of descriptors
    :type descriptor_dict: dict

    :param surface_names: list of the surface names
    :type surface_names: list

    :param parameter_dict: dictionary of parameters
    :type parameter_dict: dict

    :param scaling_function: function to project descriptors into energies.
        Should take descriptors as an argument and return a
        dictionary of {adsorbate:energy} pairs.
    :type scaling_function: function

    :param x_axis_function: function to project descriptors onto the x-axis.
        Should take descriptors as an argument and return a
        dictionary of {adsorbate:x_value} pairs.
    :type x_axis_function: function

    :param scaling_function_kwargs: keyword arguments for scaling_function.
    :type scaling_function_kwargs: dict

    :param x_axis_function_kwargs: keyword arguments for x_axis_function.
    :type x_axis_function_kwargs: dict

    ``plot`` also reads ``adsorbate_names`` and ``transition_state_names``.
    When present, it reads ``echem_transition_state_names``,
    ``include_descriptors`` and ``model_name``. None of these are set here;
    presumably a subclass/wrapper provides them, so confirm against callers.
    """

    def __init__(self, descriptor_names, descriptor_dict, surface_names,
                 parameter_dict, scaling_function, x_axis_function,
                 scaling_function_kwargs=None, x_axis_function_kwargs=None,
                 ):
        self.descriptor_names = descriptor_names
        self.surface_names = surface_names
        self.descriptor_dict = descriptor_dict
        self.parameter_dict = parameter_dict
        self.scaling_function = scaling_function
        # None defaults avoid sharing one mutable dict between instances.
        self.scaling_function_kwargs = (
            {} if scaling_function_kwargs is None else scaling_function_kwargs)
        self.x_axis_function = x_axis_function
        self.x_axis_function_kwargs = (
            {} if x_axis_function_kwargs is None else x_axis_function_kwargs)

        self.axis_label_size = 16
        self.surface_label_size = 16
        self.title_size = 18
        self.same_scale = True
        self.show_titles = True
        self.show_surface_labels = True
        self.subplots_adjust_kwargs = {'wspace': 0.4, 'hspace': 0.4}
        self.x_label_dict = {}
        self.y_label_dict = {}
        self.surface_colors = []
        self.scaling_line_args = {}
        self.label_args = {}
        self.line_args = {}
        self.include_empty = True
        self.include_error_histogram = True

    def plot(self, ax_list=None, plot_size=4.0, save=None):
        """
        Plot actual vs. scaled energies for every adsorbate, one panel each.

        For each adsorbate, a linear fit of the scaled energies vs. the x-axis
        projection is drawn together with the actual energies of every
        surface. The residuals (actual - fit) are stored in
        ``self.scaling_error`` and optionally shown as a histogram.

        :param ax_list: list of axes objects, one per adsorbate panel. If it
            contains one extra axes, that axes is used for the error
            histogram. When empty, a figure with a grid of panels is created.
        :type ax_list: [ax]
        :param plot_size: size of each panel (inches)
        :type plot_size: float
        :param save: file name to save to; None means
            ``self.model_name + '_scaling.pdf'`` and False skips saving
        :type save: str or bool
        :returns: the figure
        """
        echem_ts = getattr(self, 'echem_transition_state_names', [])
        all_ads = self.adsorbate_names + self.transition_state_names
        all_ads = [a for a in all_ads
                   if a in self.parameter_dict.keys() and a not in echem_ts]

        if self.include_empty:
            ads_names = all_ads
        else:
            ads_names = [n for n in all_ads if
                         (None in self.parameter_dict[n] or
                          sum(self.parameter_dict[n]) > 0.0)]

        if not self.surface_colors:
            self.surface_colors = get_colors(len(self.surface_names))
        if not self.scaling_line_args:
            self.scaling_line_args = [{'color': 'k'}] * len(ads_names)
        elif hasattr(self.scaling_line_args, 'update'):
            # A single dict: use it for every panel (indexed over ads_names).
            self.scaling_line_args = [self.scaling_line_args] * len(ads_names)

        if not getattr(self, 'include_descriptors', False):
            for d in self.descriptor_names:
                if d in ads_names:
                    ads_names.remove(d)

        extra = 1 if self.include_error_histogram else 0
        n_panels = len(ads_names) + extra
        if not ax_list:
            # int() because matplotlib requires integer grid dimensions.
            spx = int(round(np.sqrt(n_panels)))
            spy = spx
            if spy * spx < n_panels:
                spy += 1
            fig = plt.figure(figsize=(spy * plot_size, spx * plot_size))
            ax_list = [fig.add_subplot(spx, spy, i + 1)
                       for i in range(len(ads_names))]
            created_axes = True
        else:
            fig = ax_list[0].get_figure()
            created_axes = False

        fig.subplots_adjust(**self.subplots_adjust_kwargs)
        maxyrange = 0
        ymins = []
        all_err = []
        desc_vals = [self.descriptor_dict[s] for s in self.surface_names]
        for i, ads in enumerate(ads_names):
            actual_y_vals = self.parameter_dict[ads]
            scaled_x_vals = [self.x_axis_function(
                d, **self.x_axis_function_kwargs)[0][ads] for d in desc_vals]
            label = self.x_axis_function(
                desc_vals[0], **self.x_axis_function_kwargs)[-1][ads]
            scaled_y_vals = [self.scaling_function(
                d, **self.scaling_function_kwargs)[ads] for d in desc_vals]

            ax = ax_list[i]
            m, b = np.polyfit(scaled_x_vals, scaled_y_vals, 1)
            x_vals = np.array([round(min(scaled_x_vals), 1) - 0.1,
                               round(max(scaled_x_vals), 1) + 0.1])
            ax.plot(x_vals, m * x_vals + b, **self.scaling_line_args[i])
            all_err += [yi - (m * xi + b)
                        for xi, yi in zip(scaled_x_vals, actual_y_vals)
                        if yi is not None]
            ax.set_xlabel(label)
            ax.set_ylabel('$E_{' + ads + '}$ [eV]')

            num_y_vals = []
            for sf, col, x, y in zip(self.surface_names, self.surface_colors,
                                     scaled_x_vals, actual_y_vals):
                # None and exactly-zero energies are treated as missing.
                if y:
                    ax.plot(x, y, 'o', color=col, markersize=10, mec=None)
                    if self.show_surface_labels:
                        ax.annotate(sf, [x, y], color=col, **self.label_args)
                    num_y_vals.append(y)
            if self.show_titles:
                ax.set_title('$' + ads + '$', size=self.title_size)
            if not num_y_vals:
                num_y_vals = scaled_y_vals
            dy = max(num_y_vals) - min(num_y_vals)
            ymins.append([min(num_y_vals), max(num_y_vals)])
            if dy > maxyrange:
                maxyrange = dy
            ax.set_xlim(x_vals)

        self.scaling_error = all_err

        if self.same_scale:
            # Pad every panel to the largest y range so slopes are comparable.
            for ax, (ymin, ymax) in zip(ax_list, ymins):
                pad = maxyrange - (ymax - ymin)
                ax.set_ylim([round(ymin - pad, 1) - 0.1,
                             round(ymax + pad, 1) + 0.1])

        if self.include_error_histogram:
            if created_axes:
                err_ax = fig.add_subplot(spx, spy, len(ads_names) + 1)
                ax_list.append(err_ax)
            elif len(ax_list) > len(ads_names):
                err_ax = ax_list[len(ads_names)]
            else:
                err_ax = None  # caller supplied no axes for the histogram
            if err_ax is not None:
                err_ax.hist(all_err, bins=15)
                err_ax.set_xlabel('$E_{actual} - E_{scaled}$ [eV]')
                err_ax.set_ylabel('Counts')

        for ax in ax_list:
            if getattr(self, 'n_xticks', None):
                ax.xaxis.set_major_locator(MaxNLocator(self.n_xticks))
            if getattr(self, 'n_yticks', None):
                ax.yaxis.set_major_locator(MaxNLocator(self.n_yticks))

        if save is None:
            save = self.model_name + '_scaling.pdf'
        if save:
            fig.savefig(save)
        return fig