#!/usr/bin/env python
# coding: utf-8

# (wf_array_v2)= 
# # `wf_array` to `WFArray` in v2.0
# 
# This tutorial demonstrates the new `WFArray` class in PythTB v2.0, which is an extension of the `wf_array` class from previous versions.
# For a more comprehensive overview of all changes, please refer to the [release notes](release/2.0.0-notes).
# 
# The first thing to be aware of is that the class names have changed:
# - v1.x: `tb_model`, `wf_array`, `w90`
# - v2.0: `TBModel`, `WFArray`, `W90`

# In[ ]:


from pythtb import Lattice, Mesh, TBModel, WFArray
import numpy as np


# If you haven't already, we recommend you read through the [TBModel v2.0 tutorial](tb_model_v2.ipynb) and [Mesh tutorial](mesh.ipynb) first, as this tutorial will build on that knowledge.
# 
# The `WFArray` class in v2.0 is aimed at addressing a few limitations of the previous `wf_array` class:
# 
# - __Mesh awareness__: The new `WFArray` class is designed to be aware of the k-point mesh structure, allowing for more efficient storage and manipulation of wavefunction data on combined $(k, \lambda)$-meshes.
# - __Faster calculations__: The new class includes optimized methods for common operations such as overlaps and projections, which run faster than the previous implementation.
# - __Flexibility__: The `WFArray` class is decoupled from the `TBModel` class, allowing for more flexible usage.
# 
# Let's look at how to use the new `WFArray` class in practice.

# ## Initializing and populating a `WFArray`
# 
# In previous versions of PythTB, the `wf_array` class was tightly coupled to the `tb_model` class: when initialized, a `wf_array` read the lattice information off the associated `tb_model`. In v2.0, as with `TBModel`, we instead pass a `Lattice` object explicitly to the `WFArray` constructor. Besides increasing flexibility, this allows us to create `WFArray` objects that are not tied to any specific `TBModel`. For instance, one may want to store states derived from a family of `TBModel`s along an adiabatic cycle, for which no single `TBModel` describes all the states.
# 
# Compare to versions 1.x:
# 
# ```python
# # v1.x
# from pythtb import tb_model, wf_array
# model = tb_model(1, 1, lat=[[1.0]], orb=[[0.0], [1/3], [2/3]])
# wfa = wf_array(model, [20])
# wfa.solve_on_grid(0.0)
# ```
# 
# In `solve_on_grid`, the `wf_array` generates the k-point mesh and populates itself with the eigenstates of `model`. Afterwards the mesh is discarded and the `wf_array` retains only the wavefunction data. This loses the information about the k-point structure and requires the user to keep track of the k-point mesh manually if it is needed.
# 
# 
# In v2.0, one initializes a `WFArray` as follows:
# 
# 1. Create a `Lattice` object that describes the lattice structure of the system.

# In[ ]:


lattice = Lattice(
    lat_vecs=[[1.0]], orb_vecs=[[0.0], [1 / 3], [2 / 3]], periodic_dirs=[0]
)


# 2. Create a `Mesh` object that describes the k-point and parameter mesh.

# In[ ]:


mesh = Mesh(["k"])


# 3. Populate the `Mesh` with the desired k-points and parameter values.

# In[ ]:


mesh.build_grid([20])


# 4. Pass these objects to the `WFArray` constructor.

# In[ ]:


wfa = WFArray(lattice=lattice, mesh=mesh)


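# To populate the `WFArray` with the energy eigenstates of a given `TBModel`, we call the `WFArray.solve_model()` method. As an example, we define a three-site chain whose onsite energies depend on a parameter `lmbda`. Since the current mesh has only a k axis, we first fix `lmbda = 0.25` with `TBModel.with_parameters()` and then solve the resulting model.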
# 

# In[ ]:


t = -1.3
delta = 2.0

model = TBModel(lattice=lattice)

# nearest-neighbour hoppings (last hop wraps to the next cell)
model.set_hop(t, 0, 1, [0])
model.set_hop(t, 1, 2, [0])
model.set_hop(t, 2, 0, [1])

# Per-orbital onsite energies as functions of the parameter lmbda
onsite = [
    lambda lmbda: delta * -np.cos(2 * np.pi * (lmbda - 0 / 3)),
    lambda lmbda: delta * -np.cos(2 * np.pi * (lmbda - 1 / 3)),
    lambda lmbda: delta * -np.cos(2 * np.pi * (lmbda - 2 / 3)),
]

model.set_onsite(onsite)
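

# For reference, the Bloch Hamiltonian encoded by `model` can be written out by hand. The NumPy sketch below is not PythTB code: it assumes the phase convention $H_{ij}(k) = \sum_R t_{ij}(R)\, e^{2\pi i k (R + \tau_j - \tau_i)}$ with $k$ in reduced units, and the helper name `bloch_h` is hypothetical. The eigenvalues do not depend on this choice of convention; only the phases of the eigenvectors do.

```python
import numpy as np

t, delta = -1.3, 2.0
tau = np.array([0.0, 1 / 3, 2 / 3])        # orbital positions (reduced units)
hops = [(0, 1, 0), (1, 2, 0), (2, 0, 1)]   # (i, j, R): <i, 0|H|j, R> = t


def bloch_h(k, lmbda):
    """Hypothetical helper: 3x3 Bloch Hamiltonian at reduced k and parameter lmbda."""
    onsite = [-delta * np.cos(2 * np.pi * (lmbda - n / 3)) for n in range(3)]
    h = np.diag(onsite).astype(complex)
    for i, j, R in hops:
        amp = t * np.exp(2j * np.pi * k * (R + tau[j] - tau[i]))
        h[i, j] += amp
        h[j, i] += np.conj(amp)  # Hermitian-conjugate partner of each hop
    return h


h = bloch_h(0.1, 0.25)
print(np.allclose(h, h.conj().T))         # True: H is Hermitian
print(np.isclose(np.trace(h).real, 0.0))  # True: the three onsite cosines sum to zero
```

# The onsite terms are periodic in `lmbda` with period 1, which is what later allows the $\lambda$ axis to be closed into a loop.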


# In[ ]:


model_fixed = model.with_parameters(lmbda=0.25)
wfa.solve_model(model_fixed)
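

# Conceptually, filling an array of eigenstates means diagonalizing the Bloch Hamiltonian at every mesh point and storing the eigenvectors. A plain NumPy analogue for 20 k-points at $\lambda = 0.25$ is sketched below. It assumes k-points uniformly spaced over $[0, 1)$ in reduced units (the actual `Mesh` grid may differ), the same assumed phase convention $e^{2\pi i k (R + \tau_j - \tau_i)}$, a hypothetical helper `bloch_h`, and an illustrative `[k, band, orbital]` layout that is not necessarily how `WFArray` stores its data.

```python
import numpy as np

t, delta = -1.3, 2.0
tau = np.array([0.0, 1 / 3, 2 / 3])
hops = [(0, 1, 0), (1, 2, 0), (2, 0, 1)]


def bloch_h(k, lmbda):
    """Hypothetical helper: Bloch Hamiltonian of the three-site chain defined above."""
    onsite = [-delta * np.cos(2 * np.pi * (lmbda - n / 3)) for n in range(3)]
    h = np.diag(onsite).astype(complex)
    for i, j, R in hops:
        amp = t * np.exp(2j * np.pi * k * (R + tau[j] - tau[i]))
        h[i, j] += amp
        h[j, i] += np.conj(amp)
    return h


k_pts = np.linspace(0.0, 1.0, 20, endpoint=False)    # assumed k grid (reduced units)
evals = np.empty((len(k_pts), 3))
evecs = np.empty((len(k_pts), 3, 3), dtype=complex)  # [k, band, orbital]
for ik, k in enumerate(k_pts):
    e, v = np.linalg.eigh(bloch_h(k, 0.25))  # eigenvalues in ascending order
    evals[ik] = e
    evecs[ik] = v.T                          # eigh returns eigenvectors as columns

print(evecs[2].shape)  # (3, 3): all bands at the third k-point, as (band, orbital)
```

# Indexing `evecs[2]` here plays the same role as `wfa[2]` in the cell below: it selects every band at one mesh point.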


# We can access the individual wavefunctions stored in the `WFArray` using indexing, similar to how it was done in previous versions. For example, `wfa[k_index]` will return the wavefunctions at the specified mesh index.

# In[ ]:


wfa[2]


# As before, periodic boundary conditions across the BZ boundary are applied automatically and internally in Berry phase calculations and other operations.
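# 
# To see why this boundary condition matters, the sketch below computes the Berry phase of the lowest band at fixed $\lambda$ from a discretized Wilson loop, $\phi = -\operatorname{Im} \ln \prod_j \langle u_{k_j} | u_{k_{j+1}} \rangle$. Under the assumed phase convention $e^{2\pi i k (R + \tau_j - \tau_i)}$, the state that closes the loop at $k = 1$ is the state at $k = 0$ multiplied orbital by orbital by $e^{-2\pi i \tau}$. This NumPy stand-in, with the hypothetical helpers `bloch_h` and `berry_phase`, only illustrates the idea and is not the `WFArray` implementation.

```python
import numpy as np

t, delta = -1.3, 2.0
tau = np.array([0.0, 1 / 3, 2 / 3])
hops = [(0, 1, 0), (1, 2, 0), (2, 0, 1)]


def bloch_h(k, lmbda):
    """Hypothetical helper: Bloch Hamiltonian of the three-site chain."""
    onsite = [-delta * np.cos(2 * np.pi * (lmbda - n / 3)) for n in range(3)]
    h = np.diag(onsite).astype(complex)
    for i, j, R in hops:
        amp = t * np.exp(2j * np.pi * k * (R + tau[j] - tau[i]))
        h[i, j] += amp
        h[j, i] += np.conj(amp)
    return h


def berry_phase(states):
    """Hypothetical helper: Berry phase of a closed chain of normalized states."""
    links = [np.vdot(states[j], states[j + 1]) for j in range(len(states) - 1)]
    return -np.angle(np.prod(links))


nk, lmbda = 30, 0.25
k_pts = np.linspace(0.0, 1.0, nk, endpoint=False)
u = [np.linalg.eigh(bloch_h(k, lmbda))[1][:, 0] for k in k_pts]  # lowest band
u.append(u[0] * np.exp(-2j * np.pi * tau))  # periodic boundary condition at k = 1
print(berry_phase(u))
```

# Because the closing state is built from `u[0]` rather than from an independent diagonalization at $k = 1$, the result does not change if each state is multiplied by an arbitrary phase.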

# ## `WFArray` on an adiabatic cycle
# 
# We reuse the same lattice as above, but now create a `Mesh` that includes an adiabatic parameter axis $\lambda$ (named `lmbda`), which we sample at 21 points from $0$ to $1$. This parameter could represent, for example, a distortion of the lattice or a change in an external field.
# 
# Recall from the [Mesh tutorial](mesh.ipynb) that the $\lambda$ axis must be named after the parameter it varies in the `TBModel`, here `lmbda` (see the onsite functions defined above).

# In[ ]:


mesh = Mesh(["k", "l"], axis_names=["kx", "lmbda"])


# Now we build the grid: 31 k-points along the periodic direction and 21 values of $\lambda$ from $0$ to $1$, with both endpoints included.

# In[ ]:


mesh.build_grid(
    shape=(31, 21),
    gamma_centered=True,
    lambda_start=0.0,
    lambda_stop=1.0,
    lambda_endpoints=True,
)


# To impose the $\lambda$ axis as an adiabatic loop, we call `loop` on the mesh axis corresponding to $\lambda$. Both the axis index and the component index are `1` (zero-based): the combined $(k, \lambda)$-space is two-dimensional, and the $\lambda$ component comes after the single k component. This tells the mesh that traversing axis 1 winds component 1 of the mesh vector around a loop, so that the Hamiltonian at the end of the loop connects back to the Hamiltonian at the beginning.
# 
# We also mark the axis as closed by setting `closed=True`, since the endpoint at $\lambda=1$ is equivalent to the starting point at $\lambda=0$.
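# 
# The distinction between a mesh *axis* and a vector *component* can be pictured with a plain NumPy array of mesh points. In the sketch below (an illustration only, not how `Mesh` stores its points; the k range is a placeholder because the exact `gamma_centered` convention is not shown here), the points have shape `(31, 21, 2)`: the first two indices are the mesh axes and the last index is the component, with component 0 holding $k$ and component 1 holding $\lambda$. Moving along axis 1 at fixed k changes only component 1, and with endpoints included the last $\lambda$ value lies one full period after the first, which is the situation `closed=True` describes.

```python
import numpy as np

nk, nl = 31, 21
k_vals = np.linspace(0.0, 1.0, nk, endpoint=False)  # placeholder k values (reduced units)
l_vals = np.linspace(0.0, 1.0, nl, endpoint=True)   # lambda from 0 to 1, endpoints included

# points[i, j] = (k_i, lambda_j); the last index is the vector component
points = np.stack(np.meshgrid(k_vals, l_vals, indexing="ij"), axis=-1)
print(points.shape)  # (31, 21, 2)

# Traversing mesh axis 1 at a fixed k index varies only component 1 (lambda)
path = points[0, :, :]
print(path[:, 0].max() - path[:, 0].min(), path[0, 1], path[-1, 1])
```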

# In[ ]:


mesh.loop(
    axis_idx=1, component_idx=1, closed=True
)  # form the lambda axis into a closed loop
print(mesh)


# Next, we initialize the `WFArray`:

# In[ ]:


wfa = WFArray(lattice, mesh)


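# Finally, we call `solve_model` to populate the `WFArray` with the eigenstates of the `TBModel` at each point in the combined $(k, \lambda)$-mesh. Note that we now pass the parametric `model` itself rather than `model_fixed`, so that `lmbda` is taken from the mesh.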

# In[ ]:


wfa.solve_model(model=model)
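

# For comparison, filling a combined $(k, \lambda)$ array by hand, in the spirit of the v1.x workflow, means diagonalizing at every mesh point and then enforcing the periodic boundary condition along $k$ and the loop closure along $\lambda$. This is a NumPy-only sketch with hypothetical names (`bloch_h`, `wfs`) and the same assumed phase convention $e^{2\pi i k (R + \tau_j - \tau_i)}$; the `[k, lambda, band, orbital]` layout is illustrative and says nothing about `WFArray` internals.

```python
import numpy as np

t, delta = -1.3, 2.0
tau = np.array([0.0, 1 / 3, 2 / 3])
hops = [(0, 1, 0), (1, 2, 0), (2, 0, 1)]


def bloch_h(k, lmbda):
    """Hypothetical helper: Bloch Hamiltonian of the three-site chain."""
    onsite = [-delta * np.cos(2 * np.pi * (lmbda - n / 3)) for n in range(3)]
    h = np.diag(onsite).astype(complex)
    for i, j, R in hops:
        amp = t * np.exp(2j * np.pi * k * (R + tau[j] - tau[i]))
        h[i, j] += amp
        h[j, i] += np.conj(amp)
    return h


nk, nl = 31, 21
k_pts = np.linspace(0.0, 1.0, nk)  # endpoint included: k = 1 duplicates k = 0
l_pts = np.linspace(0.0, 1.0, nl)  # endpoint included: lambda = 1 duplicates lambda = 0

wfs = np.empty((nk, nl, 3, 3), dtype=complex)  # [k, lambda, band, orbital]
for ik, k in enumerate(k_pts):
    for il, lm in enumerate(l_pts):
        wfs[ik, il] = np.linalg.eigh(bloch_h(k, lm))[1].T  # rows are eigenvectors

# Periodic boundary condition along k: u(k = 1) = u(k = 0) * exp(-2 pi i tau)
wfs[-1] = wfs[0] * np.exp(-2j * np.pi * tau)
# Loop closure along lambda: the last lambda point reuses the states of the first
wfs[:, -1] = wfs[:, 0]

print(np.allclose(wfs[0, 0], wfs[0, -1]))  # True
```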


# We can verify that the $\lambda$ axis was correctly imposed as a loop by comparing the wavefunctions at the start and end of the adiabatic cycle.

# In[ ]:


np.allclose(wfa[0, 0], wfa[0, -1])  # first k-point, first and last lambda points
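

# An exact `np.allclose` comparison of raw eigenvectors is meaningful here because the loop is imposed on the stored states. In general, two independent diagonalizations of the same Hamiltonian may return eigenvectors that differ by a phase, so a gauge-independent check compares band projectors $P = \sum_n |u_n\rangle\langle u_n|$ instead. The sketch below demonstrates this with NumPy on a random Hermitian matrix and an arbitrary phase; it is an illustration and not part of the `WFArray` API.

```python
import numpy as np

rng = np.random.default_rng(1)
a = rng.normal(size=(3, 3)) + 1j * rng.normal(size=(3, 3))
h = a + a.conj().T  # random Hermitian stand-in for H(k, lambda)

_, v = np.linalg.eigh(h)
occ = v[:, :1]                          # lowest band as a column vector
occ_rephased = occ * np.exp(1j * 0.7)   # same state with a different global phase

print(np.allclose(occ, occ_rephased))   # False: raw vectors differ by a phase
proj = occ @ occ.conj().T
proj_rephased = occ_rephased @ occ_rephased.conj().T
print(np.allclose(proj, proj_rephased))  # True: the projector is gauge invariant
```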


# Compare to versions 1.x:
# 
# ```python
# # v1.x
# 
# path_steps = 21
# all_lambda = np.linspace(0, 1, path_steps, endpoint=True)
# num_kpt = 31
# 
# # set_model(t, delta, lmbd) is a user-defined helper (not shown) returning a tb_model
# my_model = set_model(t, delta, all_lambda[0])
# wf_kpt_lambda = wf_array(my_model, [num_kpt, path_steps])
# for i_lambda in range(path_steps):
#     lmbd = all_lambda[i_lambda]
#     my_model = set_model(t, delta, lmbd)
# 
#     (k_vec, k_dist, k_node) = my_model.k_path([[0], [1]], num_kpt, report=False)
#     (evals, evecs) = my_model.solve_all(k_vec, eig_vectors=True)
#     for i_kpt in range(num_kpt):
#         # Manually populate the wf_array
#         wf_kpt_lambda[i_kpt, i_lambda] = evecs[:, i_kpt, :]
# 
# # Need to manually impose PBCs since the wf_array does not know about the mesh structure
# wf_kpt_lambda.impose_pbc(0, 0)
# # Need to manually impose the loop in lambda
# wf_kpt_lambda.impose_loop(1)
# ```
# 

# ## New methods in `WFArray` v2.0
# 
# Explore the new methods available in the `WFArray` class in v2.0 in the [API documentation](https://pythtb.readthedocs.io/en/dev/generated/pythtb.WFArray.html).
