#!/usr/bin/env python
# coding: utf-8

# (w90-nb)=
# # Wannier90 Silicon example

# In[ ]:


from pythtb import W90
import matplotlib.pyplot as plt
import numpy as np


# In[ ]:


# Read the Wannier90 output for silicon. "silicon_w90" is presumably the folder
# holding the run, and "si" is the prefix shared by its files (e.g. si.win,
# si_band.dat, si_band.kpt).
w90_dir, w90_prefix = "silicon_w90", "si"
silicon = W90(w90_dir, w90_prefix)


# In[ ]:


# Fermi level in eV, used below as the zero of energy. It is hard-coded here
# rather than read from any output file.
fermi_ev: float = 6.2285135


# In[ ]:


# Print every pair distance between the Wannier orbitals, as reported by
# W90.shells().
orbital_shells = silicon.shells()
print("Shells:\n", orbital_shells)


# We can look at how the Hamiltonian matrix elements generated by Wannier90 decay with distance. This will help us choose the cutoff to use when constructing a tight-binding model with `pythtb`. Most hopping amplitudes are larger than $1$ meV, so in this example we use $1$ meV as the cutoff energy and discard hoppings with a smaller magnitude.

# In[ ]:


# Plot the magnitude of every Wannier90 hopping term against the distance it
# spans, on a log scale, together with the cutoff energy chosen for the model.
hop_cutoff_ev = 1e-3  # cutoff energy in eV (1 meV)

(dist, ham) = silicon.dist_hop()
fig, ax = plt.subplots()
ax.scatter(dist, np.abs(ham))
# axhline spans the whole x-range, so it does not depend on the contents of dist
ax.axhline(
    hop_cutoff_ev,
    color="r",
    linestyle="dashed",
    label=f"Cutoff = {hop_cutoff_ev * 1e3:g} meV",
)
ax.legend()
ax.set_xlabel("Distance (Å)")
# the absolute value is what is plotted, so label the axis |H|
ax.set_ylabel(r"$|H|$ (eV)")
ax.set_yscale("log")


# Now, we will generate the tight-binding model for silicon using the Wannier90 output files. This will use the Hamiltonian matrix elements generated by Wannier90 to create a `TBModel` object in `pythtb`.
# 
# :::{tip}
# It is advisable to save the tight-binding model to disk with the `pickle` module (the Python 3 replacement for `cPickle`):
# ```python
# import pickle
# with open("store.pkl", "wb") as f:
#     pickle.dump(my_model, f)
# ```
# Later one can load the model from disk in a separate script with
# ```python
# import pickle
# with open("store.pkl", "rb") as f:
#     my_model = pickle.load(f)
# ```
# Only unpickle files you trust, since loading a pickle can execute arbitrary code.
# :::

# In[ ]:


# Build the pythtb TBModel from the Wannier90 data. Energies are measured from
# fermi_ev, and hopping terms with norm below 1e-3 eV (1 meV) are omitted.
model_kwargs = {"zero_energy": fermi_ev, "min_hopping_norm": 1e-3}
my_model = silicon.model(**model_kwargs)


# # Band comparison

# First, we will obtain the band structure from the Wannier90 calculation. To do this, we call the `W90.bands_w90` method, which reads the `si_band.dat`, `si_band.kpt` and `si.win` files to extract the band energies and the k-point path used in the Wannier90 calculation. Setting `return_k_dist=True` additionally returns the cumulative distance along the k-point path, which we will use for plotting. Setting `return_k_nodes=True` additionally returns the fractional coordinates of the high-symmetry k-points along the path, together with their labels.
# 
# :::{hint}
# Small discrepancies in the plot may arise due to the terms that were ignored in the `silicon.model` function call above.
# :::

# In[ ]:


# Read the Wannier90 band structure along its k-path. The call returns the
# k-points, the band energies, the cumulative path distance, and the
# high-symmetry nodes (fractional coordinates) with their labels.
w90_band_data = silicon.bands_w90(return_k_dist=True, return_k_nodes=True)
w90_kpt, w90_evals, w90_k_dist, w90_k_nodes, w90_k_labels = w90_band_data

print("k-point labels:", w90_k_labels)
print("k-point nodes (fractional):\n", w90_k_nodes)


# For plotting purposes, we also need to know the cumulative distance along the k-point path at each high-symmetry k-point node, which we compute below.

# In[ ]:


# Positions of the high-symmetry nodes on the plotting axis. Only the node
# distances from TBModel.k_path are needed; the k-points and per-point
# distances it also returns are discarded.
_, _, k_node_dist = my_model.k_path(w90_k_nodes, nk=500, report=False)

# The bands are plotted against Wannier90's w90_k_dist, whose reciprocal-space
# normalization may differ from pythtb's (e.g. by a 2*pi convention; verify).
# Map the node distances linearly onto w90_k_dist's range so that the ticks and
# x-limits line up with the plotted data. When the two already agree, this is
# a no-op.
k_node_dist = np.asarray(k_node_dist, dtype=float)
k_node_dist = w90_k_dist[0] + (k_node_dist - k_node_dist[0]) * (
    (w90_k_dist[-1] - w90_k_dist[0]) / (k_node_dist[-1] - k_node_dist[0])
)


# Next, we will solve and plot the `TBModel` on the same path as used in Wannier90. This allows us to directly compare the two band structures.

# In[ ]:


# Diagonalize the TBModel at exactly the k-points Wannier90 used, so the two
# band structures can be overlaid point by point. The result is presumably
# indexed [k-point, band], which is how the comparison plot slices it.
int_evals = my_model.solve_ham(w90_kpt)


# In[ ]:


# Overlay the Wannier90 bands (black, solid) and the TBModel bands (red,
# dashed) along the Wannier90 k-path. The Wannier90 energies are shifted by
# fermi_ev here. The TBModel energies are not shifted, because the model was
# built with zero_energy=fermi_ev.
fig, ax = plt.subplots(figsize=(10, 6))

# label only the first band of each set so the legend has one entry per source
ax.plot(w90_k_dist, w90_evals[:, 0] - fermi_ev, "k-", zorder=0, label="Wannier90")
ax.plot(w90_k_dist, w90_evals[:, 1:] - fermi_ev, "k-", zorder=0)
ax.plot(w90_k_dist, int_evals[:, 0], "r--", zorder=1, label="TBModel")
ax.plot(w90_k_dist, int_evals[:, 1:], "r--", zorder=1)

# mark the high-symmetry k-points with labelled ticks and thin vertical lines
ax.set_xticks(k_node_dist)
ax.set_xticklabels(w90_k_labels, size=12)
for x_node in k_node_dist:
    ax.axvline(x=x_node, linewidth=0.5, color="k", zorder=1)
ax.set_xlim(k_node_dist[0], k_node_dist[-1])
ax.set_ylabel(r"$E - E_F$ (eV)")
# the label= arguments above are only displayed once a legend is drawn
ax.legend()
plt.show()

