import numpy as np
import pandas as pd # type: ignore
import matplotlib.pyplot as plt # type: ignore
from matplotlib.axes import Axes # type: ignore
from matplotlib.figure import Figure # type: ignore
import matplotlib._color_data as mcd # type: ignore
from typing import Optional, List, Union, Dict, Tuple
from collections.abc import Iterable
# Hex colour used as the fill colour of the total-DOS background in
# Doscar.plot_pdos.
TABLEAU_GREY: str = "#bab0ac"
def pdos_column_names(lmax: int, ispin: int) -> List[str]:
    """Return the column names for one atom-projected DOS block.

    The first name is always ``"energy"``, followed by one name per
    lm-projection in the order s, p (y, z, x), d and, for ``lmax == 3``, f.
    When ``ispin == 2`` every projection is expanded into an ``_up`` and a
    ``_down`` column, interleaved (``s_up, s_down, p_y_up, ...``).

    Args:
        lmax (int): Maximum angular momentum: 2 (up to d) or 3 (up to f).
        ispin (int): 2 for spin-polarised column names. Any other value
            yields the unsuffixed, non-spin-polarised names.

    Returns:
        list(str): The column names, starting with ``"energy"``.

    Raises:
        ValueError: If ``lmax`` is neither 2 nor 3.
    """
    # Single source of truth for the projection order; the lmax=2 names are
    # a prefix of the lmax=3 names.
    orbitals = (
        "s",
        "p_y",
        "p_z",
        "p_x",
        "d_xy",
        "d_yz",
        "d_z2-r2",
        "d_xz",
        "d_x2-y2",
        "f_y(3x2-y2)",
        "f_xyz",
        "f_yz2",
        "f_z3",
        "f_xz2",
        "f_z(x2-y2)",
        "f_x(x2-3y2)",
    )
    # Number of lm-projections for each supported lmax, i.e. (lmax + 1) ** 2.
    n_channels = {2: 9, 3: 16}.get(lmax)
    if n_channels is None:
        raise ValueError("lmax value not supported")
    names = list(orbitals[:n_channels])
    if ispin == 2:
        names = [f"{name}_{spin}" for name in names for spin in ("up", "down")]
    return ["energy"] + names
class Doscar:
    """Contains all the data in a VASP DOSCAR file,
    and methods for manipulating this.

    Attributes:
        energy (np.ndarray): Energy values read from the total-DOS block.
        tdos (pd.DataFrame): Total DOS table with the columns ``energy``,
            ``up``, ``down``, ``int_up`` and ``int_down``.
        pdos (np.ndarray or None): Projected DOS with axes
            ``[atom, energy point, lm-projection, spin]``, or ``None`` when
            the projected DOS has not been read.
    """

    # Number of lines skipped at the top of the file before the first
    # total-DOS data row.
    number_of_header_lines: int = 6

    def __init__(
        self,
        filename: str,
        ispin: int = 2,
        lmax: int = 2,
        lorbit: int = 11,
        spin_orbit_coupling: bool = False,
        read_pdos: bool = True,
        species: Optional[List[str]] = None,
    ) -> None:
        """
        Create a Doscar object from a VASP DOSCAR file.

        Args:
            filename (str): Filename of the VASP DOSCAR file to read.
            ispin (optional:int): ISPIN flag.
                Set to 1 for non-spin-polarised
                or to 2 for spin-polarised calculations.
                Default = 2.
            lmax (optional:int): Maximum l angular momentum. (d=2, f=3). Default is 2.
            lorbit (optional:int): The VASP LORBIT flag. (Default=11).
            spin_orbit_coupling (optional:bool): Spin-orbit coupling (Default=False).
            read_pdos (optional:bool): Set to True to read the atom-projected density of states (Default=True).
            species (optional:list(str)): List of atomic species strings, e.g. `['Fe', 'Fe', 'O', 'O', 'O']`.
                Default=None.

        Raises:
            NotImplementedError: If `spin_orbit_coupling` is True.
        """
        self.filename = filename
        self.ispin = ispin
        self.lmax = lmax
        self.spin_orbit_coupling = spin_orbit_coupling
        if self.spin_orbit_coupling:
            raise NotImplementedError("Spin-orbit coupling is not yet implemented")
        self.lorbit = lorbit
        self.pdos: Optional[np.ndarray] = None
        self.species = species
        # NOTE(review): read_header() is not defined in the part of the class
        # shown here; it is presumably what sets `number_of_atoms` and
        # `number_of_data_points`, which the readers below rely on - confirm.
        self.read_header()
        self.read_total_dos()
        if read_pdos:
            self.read_projected_dos()
        # TODO: if species is set, check that it is consistent with the
        # number of entries in the projected_dos dataset.

    @property
    def number_of_channels(self) -> int:
        """Number of lm-projections per atom: 9 for lmax=2, 16 for lmax=3.

        Raises:
            NotImplementedError: If `lorbit` is not 11.
            KeyError: If `lmax` is neither 2 nor 3.
        """
        if self.lorbit == 11:
            return {2: 9, 3: 16}[self.lmax]
        raise NotImplementedError

    def read_total_dos(self) -> pd.DataFrame:
        """Read the total density of states block from the DOSCAR file.

        Sets `self.energy` (the energy column as an array) and `self.tdos`
        (the full table). The five column names assume a spin-polarised
        calculation.

        Returns:
            pd.DataFrame: The table also stored as `self.tdos`.
        """
        df: pd.DataFrame = pd.read_csv(
            self.filename,
            skiprows=Doscar.number_of_header_lines,
            nrows=self.number_of_data_points,
            # `delim_whitespace=True` is deprecated in pandas; this separator
            # is the documented equivalent.
            sep=r"\s+",
            names=["energy", "up", "down", "int_up", "int_down"],
            index_col=False,
        )
        self.energy: np.ndarray = df.energy.values
        # The `energy` column is deliberately kept in `tdos`: the previous
        # `df.drop(...)` call discarded its result, so existing callers may
        # already rely on `tdos.energy` being present.
        self.tdos = df
        return df

    def read_atomic_dos_as_df(self, atom_number: int) -> pd.DataFrame:
        """Read the projected DOS block of a single atom.

        Args:
            atom_number (int): Atom to read, counting from 1.

        Returns:
            pd.DataFrame: One column per projection (and spin channel, as
                named by `pdos_column_names`), without the energy column.

        Raises:
            AssertionError: If `atom_number` is not in
                `1..self.number_of_atoms`.
        """
        # Chained comparison: the previous `a > 0 & a <= n` form let `&` bind
        # tighter than the comparisons, so the upper bound was never checked.
        assert 0 < atom_number <= self.number_of_atoms, (
            "atom_number must be between 1 and {}".format(self.number_of_atoms)
        )
        # Each block (total DOS, then one per atom) is number_of_data_points
        # rows preceded by one header line, hence the `+ 1` stride.
        start_to_read = Doscar.number_of_header_lines + atom_number * (
            self.number_of_data_points + 1
        )
        df = pd.read_csv(
            self.filename,
            skiprows=start_to_read,
            nrows=self.number_of_data_points,
            sep=r"\s+",
            names=pdos_column_names(lmax=self.lmax, ispin=self.ispin),
            index_col=False,
        )
        return df.drop("energy", axis=1)

    def read_projected_dos(self) -> None:
        """Read the projected density of states data.

        Stores the result in `self.pdos` as an array with axes
        ``[atom, energy point, lm-projection, spin]``.
        """
        per_atom = [
            np.array(self.read_atomic_dos_as_df(atom_number))
            for atom_number in range(1, self.number_of_atoms + 1)
        ]
        # Each row holds the projections with their spin channels interleaved,
        # so the trailing axis of the reshape separates the spins.
        self.pdos = np.vstack(per_atom).reshape(
            self.number_of_atoms,
            self.number_of_data_points,
            self.number_of_channels,
            self.ispin,
        )

    def pdos_select(
        self,
        atoms: Optional[Union[int, List[int]]] = None,
        spin: Optional[str] = None,
        l: Optional[str] = None,
        m: Optional[List[str]] = None,
    ) -> np.ndarray:
        """
        Returns a subset of the projected density of states array.

        Args:
            atoms (int or list(int)): Index or indices of the atoms to include in the selection.
                These index the first axis of `self.pdos` directly, i.e. they count from 0.
                Default (None, or an empty list) is to select all atoms.
            spin (str): Select up or down, or both spin channels to include in the selection.
                Accepted options are 'up', 'down', and 'both'. Default is to select both spins.
            l (str): Select one angular momentum to include in the selection.
                Accepted options are 's', 'p', 'd', and 'f'. Default is to include all l-values.
                Setting `l` and not setting `m` will return all projections for that angular momentum value.
            m (list(str)): Select one or more m-values. Requires `l` to be set.
                The accepted values depend on the value of `l`:
                `l='s'`: Only one projection. Not set.
                `l='p'`: One or more of ['x', 'y', 'z']
                `l='d'`: One or more of ['xy', 'yz', 'z2-r2', 'xz', 'x2-y2']
                `l='f'`: One or more of ['y(3x2-y2)', 'xyz', 'yz2', 'z3', 'xz2', 'z(x2-y2)', 'x(x2-3y2)']
                Values that are not listed are ignored. The selected projections are always
                returned in column order, not in the order given in `m`.

        Returns:
            np.array: A 4-dimensional numpy array containing the selected pdos values.
                The array dimensions are [ atom_no, energy_value, lm-projection, spin ]

        Raises:
            ValueError: If `spin` or `l` is not one of the accepted options.
        """
        assert isinstance(self.pdos, np.ndarray)
        # m-values in the order of the projected-DOS columns (see
        # pdos_column_names): note that the p columns are p_y, p_z, p_x.
        m_values_in_column_order: Dict[str, List[str]] = {
            "s": [],
            "p": ["y", "z", "x"],
            "d": ["xy", "yz", "z2-r2", "xz", "x2-y2"],
            "f": ["y(3x2-y2)", "xyz", "yz2", "z3", "xz2", "z(x2-y2)", "x(x2-3y2)"],
        }
        # Index of the first projection of each l-value on the channel axis.
        first_channel: Dict[str, int] = {"s": 0, "p": 1, "d": 4, "f": 9}

        if isinstance(atoms, (int, np.integer)):
            # A single atom index; 0 is a valid index, so this is tested
            # before any truthiness check.
            atom_idx = [int(atoms)]
        elif atoms is None or len(atoms) == 0:
            atom_idx = list(range(self.number_of_atoms))
        else:
            atom_idx = list(atoms)
        to_return = self.pdos[atom_idx, :, :, :]

        if not spin or spin == "both":
            spin_idx = list(range(self.ispin))
        elif spin == "up":
            spin_idx = [0]
        elif spin == "down":
            spin_idx = [1]
        else:
            raise ValueError(
                "valid spin values are 'up', 'down', and 'both'. The default is 'both'"
            )
        to_return = to_return[:, :, :, spin_idx]

        if not l:
            channel_idx = list(range(self.number_of_channels))
        elif l == "s":
            channel_idx = [0]
        elif l in first_channel:
            offset = first_channel[l]
            options = m_values_in_column_order[l]
            if not m:
                channel_idx = list(range(offset, offset + len(options)))
            else:
                channel_idx = [
                    offset + i for i, value in enumerate(options) if value in m
                ]
        else:
            raise ValueError("valid l values are 's', 'p', 'd', and 'f'")
        return to_return[:, :, channel_idx, :]

    def pdos_sum(
        self,
        atoms: Optional[Union[int, List[int]]] = None,
        spin: Optional[str] = None,
        l: Optional[str] = None,
        m: Optional[List[str]] = None,
    ) -> np.ndarray:
        """Sum the selected projected DOS over atoms, projections and spins.

        The arguments are passed unchanged to `pdos_select`.

        Returns:
            np.ndarray: A 1-dimensional array with one value per energy point.
        """
        selection = self.pdos_select(atoms=atoms, spin=spin, l=l, m=m)
        return np.array(np.sum(selection, axis=(0, 2, 3)))

    def plot_pdos(
        self,
        ax: "Optional[Axes]" = None,
        to_plot: Optional[Dict[str, List[str]]] = None,
        colors: Optional[Iterable] = None,
        plot_total_dos: Optional[bool] = True,
        xrange: Optional[Tuple[float, float]] = None,
        ymax: Optional[float] = None,
        scaling: Optional[Dict[str, Dict[str, float]]] = None,
        split: bool = False,
        title: Optional[str] = None,
        title_loc: str = "center",
        labels: bool = True,
        title_fontsize: int = 16,
        legend_pos: str = "outside",
    ) -> "Optional[Figure]":
        """Plot the species- and l-resolved projected DOS.

        The spin-up DOS is drawn above the axis and the spin-down DOS is
        mirrored below it, so this assumes a spin-polarised calculation.

        Args:
            ax: Axes to draw on. If not given, a new 8x3 figure is created.
            to_plot: Maps a species string to the list of l-values
                ('s', 'p', 'd', 'f') to plot for it. Default is every species
                in `self.species` with s, p, d (and f when `lmax` is 3).
            colors: Iterable of line colours, one per plotted series; it is
                cycled if there are more series than colours.
                Default is `matplotlib._color_data.TABLEAU_COLORS`.
            plot_total_dos: If true, shade the total DOS behind the lines.
            xrange: `(min, max)` energy window to plot. Default is all energies.
            ymax: Upper y limit before the 10% margin. Default is 1.1 times
                the largest plotted value.
            scaling: `scaling[species][l]` is a factor applied to that series,
                which is also appended to its legend label.
            split: Currently unused.
            title: Axes title. Default is no title.
            title_loc: Title location passed to `set_title`.
            labels: If true, label the x axis.
            title_fontsize: Font size of the title.
            legend_pos: 'outside' places the legend to the right of the axes;
                any other value is passed to `legend(loc=...)`.

        Returns:
            The newly created figure, or None when `ax` was supplied.

        Raises:
            ValueError: If a species in `to_plot` does not occur in
                `self.species`.
        """
        from itertools import cycle

        if ax is None:
            fig, ax = plt.subplots(1, 1, figsize=(8.0, 3.0))
        else:
            fig = None
        if not colors:
            colors = mcd.TABLEAU_COLORS
        assert isinstance(colors, Iterable)
        # Cycle so that plotting more series than there are colours reuses
        # colours instead of raising StopIteration.
        color_iterator = cycle(colors)
        if not scaling:
            scaling = {}
        if xrange:
            e_range = (self.energy >= xrange[0]) & (self.energy <= xrange[1])
        else:
            # Select every point. (A mask built by casting the energies to
            # bool would drop a point lying at exactly 0.0.)
            e_range = slice(None)
        auto_ymax = 0.0
        if not to_plot:
            assert isinstance(self.species, Iterable)
            default_states = ["s", "p", "d"]
            if self.lmax == 3:
                default_states.append("f")
            # dict.fromkeys de-duplicates while keeping first-appearance
            # order, so colours are assigned reproducibly between runs.
            to_plot = {s: list(default_states) for s in dict.fromkeys(self.species)}
        for species, states in to_plot.items():
            assert isinstance(self.species, Iterable)
            index = [i for i, s in enumerate(self.species) if s == species]
            if not index:
                # An empty index would make pdos_select fall back to *all*
                # atoms and plot them under this species' label.
                raise ValueError(
                    "species '{}' not found in self.species".format(species)
                )
            for state in states:
                assert state in ["s", "p", "d", "f"]
                color = next(color_iterator)
                label = "{} {}".format(species, state)
                up_dos = self.pdos_sum(atoms=index, l=state, spin="up")[e_range]
                down_dos = self.pdos_sum(atoms=index, l=state, spin="down")[e_range]
                if species in scaling:
                    if state in scaling[species]:
                        up_dos *= scaling[species][state]
                        down_dos *= scaling[species][state]
                        label = r"{} {} $\times${}".format(
                            species, state, scaling[species][state]
                        )
                auto_ymax = max([auto_ymax, up_dos.max(), down_dos.max()])
                ax.plot(self.energy[e_range], up_dos, label=label, c=color)
                ax.plot(self.energy[e_range], down_dos * -1.0, c=color)
        if plot_total_dos:
            total_up = self.tdos.up.values[e_range]
            total_down = self.tdos.down.values[e_range]
            ax.fill_between(
                self.energy[e_range],
                total_up,
                total_down * -1.0,
                facecolor=TABLEAU_GREY,
                alpha=0.2,
            )
            auto_ymax = max([auto_ymax, total_up.max(), total_down.max()])
        if xrange:
            ax.set_xlim(xrange[0], xrange[1])
        if not ymax:
            ymax = 1.1 * auto_ymax
        ymax = float(ymax)
        ax.set_ylim(-ymax * 1.1, ymax * 1.1)
        if legend_pos == "outside":
            ax.legend(bbox_to_anchor=(1.01, 1.04), loc="upper left")
        else:
            ax.legend(loc=legend_pos)
        if labels:
            ax.set_xlabel("Energy [eV]")
        ax.axhline(y=0, c="lightgrey")
        ax.axes.grid(False, axis="y")
        ax.tick_params(
            axis="y",  # changes apply to the y-axis
            which="both",  # both major and minor ticks are affected
            left=False,  # ticks along the left edge are off
            right=False,  # ticks along the right edge are off
            labelleft=False,  # labels along the left edge are off
        )
        if title:
            ax.set_title(title, loc=title_loc, fontdict={"fontsize": title_fontsize})
        return fig