"""
Firetasks for FWs.
"""
import shutil
import json
import os
import numpy as np
from monty.json import jsanitize
from spglib import standardize_cell
from pymatgen.core.structure import Structure
from pymatgen.io.vasp import Incar, Outcar
from pytopomat.irrep_caller import IrrepCaller, IrrepOutput
from pytopomat.irvsp_caller import IRVSPCaller, IRVSPOutput
from pytopomat.vasp2trace_caller import (
Vasp2TraceCaller,
Vasp2Trace2Caller,
Vasp2TraceOutput,
)
from pytopomat.z2pack_caller import Z2PackCaller
from fireworks import explicit_serialize, FiretaskBase, FWAction
from fireworks.utilities.fw_serializers import DATETIME_HANDLER
from atomate.utils.utils import env_chk, get_logger
from atomate.vasp.database import VaspCalcDb
logger = get_logger(__name__)
@explicit_serialize
class RunIrrep(FiretaskBase):
    """
    Invoke irrep on the current working directory and forward its output.

    The task reads POSCAR and OUTCAR from the working directory to collect
    the formula, structure, Fermi level and electron count. If either file
    is missing, all four values are passed on as None. The parsed
    ``outir.txt`` is placed in the spec under ``irrep_out``.
    """

    def run_task(self, fw_spec):
        """
        Call irrep, parse ``outir.txt`` and update the spec.

        Args:
            fw_spec (dict): Firework spec (not read by this task).

        Returns:
            FWAction: updates "irrep_out", "structure", "formula",
                "efermi" and "nelect" in the spec.
        """
        cwd = os.getcwd()

        IrrepCaller(cwd)

        try:
            poscar_struct = Structure.from_file(cwd + "/POSCAR")
            formula = poscar_struct.composition.formula
            structure = poscar_struct.as_dict()

            parsed_outcar = Outcar(cwd + "/OUTCAR")
            efermi = parsed_outcar.efermi
            nelect = parsed_outcar.nelect
        except FileNotFoundError:
            # A missing POSCAR or OUTCAR blanks every metadata field.
            formula = structure = efermi = nelect = None

        irrep_output = IrrepOutput(cwd + "/outir.txt", efermi=efermi)

        spec_update = {
            "irrep_out": irrep_output.as_dict(),
            "structure": structure,
            "formula": formula,
            "efermi": efermi,
            "nelect": nelect,
        }
        return FWAction(update_spec=spec_update)
@explicit_serialize
class RunIRVSP(FiretaskBase):
    """
    Invoke IRVSP on the current working directory and forward its output.

    The task reads POSCAR and OUTCAR from the working directory to collect
    the formula, structure, Fermi level and electron count. If either file
    is missing, all four values are passed on as None. The parsed
    ``outir.txt`` is placed in the spec under ``irvsp_out``.
    """

    def run_task(self, fw_spec):
        """
        Call IRVSP, parse ``outir.txt`` and update the spec.

        Args:
            fw_spec (dict): Firework spec (not read by this task).

        Returns:
            FWAction: updates "irvsp_out", "structure", "formula",
                "efermi" and "nelect" in the spec.
        """
        cwd = os.getcwd()

        IRVSPCaller(cwd)

        try:
            poscar_struct = Structure.from_file(cwd + "/POSCAR")
            formula = poscar_struct.composition.formula
            structure = poscar_struct.as_dict()

            parsed_outcar = Outcar(cwd + "/OUTCAR")
            efermi = parsed_outcar.efermi
            nelect = parsed_outcar.nelect
        except FileNotFoundError:
            # A missing POSCAR or OUTCAR blanks every metadata field.
            formula = structure = efermi = nelect = None

        irvsp_output = IRVSPOutput(cwd + "/outir.txt")

        spec_update = {
            "irvsp_out": irvsp_output.as_dict(),
            "structure": structure,
            "formula": formula,
            "efermi": efermi,
            "nelect": nelect,
        }
        return FWAction(update_spec=spec_update)
@explicit_serialize
class StandardizeCell(FiretaskBase):
    """
    Standardize primitive cell with spglib and symprec=1e-2.

    Reads POSCAR from the current working directory, standardizes it to a
    primitive cell and writes the result to CONTCAR. The standardized
    structure is also placed in the spec under "structure".
    """

    def run_task(self, fw_spec):
        """
        Standardize the POSCAR structure and write it to CONTCAR.

        Args:
            fw_spec (dict): Firework spec (not read by this task).

        Returns:
            FWAction: updates "structure" in the spec with the
                standardized Structure.

        Raises:
            RuntimeError: if spglib fails to standardize the cell.
            ValueError: if the input carries a "magmom" site property and
                the standardized cell has a different number of sites, so
                the original moments cannot be reattached.
        """
        wd = os.getcwd()

        struct = Structure.from_file(wd + "/POSCAR")

        numbers = [site.specie.number for site in struct]
        lattice = struct.lattice.matrix
        positions = struct.frac_coords

        if "magmom" in struct.site_properties:
            magmoms = struct.site_properties["magmom"]
            cell = (lattice, positions, numbers, magmoms)
        else:
            magmoms = None
            cell = (lattice, positions, numbers)

        standardized = standardize_cell(cell, to_primitive=True, symprec=1e-2)
        if standardized is None:
            # spglib signals failure with None; unpacking it would raise an
            # uninformative TypeError.
            raise RuntimeError(
                "spglib standardize_cell failed for the structure in %s/POSCAR"
                % wd
            )
        lat, pos, nums = standardized

        structure = Structure(lat, nums, pos)

        if magmoms is not None:
            if len(magmoms) != len(structure):
                # Reduction to the primitive cell changed the site count, so
                # the per-site moments of the input no longer line up.
                raise ValueError(
                    "Cannot reattach %d magmoms to standardized cell with "
                    "%d sites" % (len(magmoms), len(structure))
                )
            # NOTE(review): this assumes spglib keeps the site ordering of the
            # input cell - confirm before relying on the moments.
            structure.add_site_property("magmom", magmoms)

        structure.to(fmt="poscar", filename="CONTCAR")

        return FWAction(update_spec={"structure": structure})
@explicit_serialize
class IrrepToDb(FiretaskBase):
    """
    Persist the irrep results (parsed from outir.txt) to a database or file.

    required_params:
        irrep_out (IrrepOutput): output from irrep calculation; when falsy,
            the "irrep_out" entry of the spec is used instead.
        wf_uuid (str): unique wf id

    optional_params:
        db_file (str): path to the db file; when it resolves to a falsy
            value the document is written to "irrep.json" instead.
        additional_fields (dict): dict of additional fields to add
    """

    required_params = ["irrep_out", "wf_uuid"]
    optional_params = ["db_file", "additional_fields"]

    def run_task(self, fw_spec):
        """
        Assemble the irrep document and store it.

        Args:
            fw_spec (dict): Firework spec; must provide "formula", "efermi",
                "nelect" and "structure".

        Returns:
            FWAction: an empty action.
        """
        irrep_data = jsanitize(self["irrep_out"] or fw_spec["irrep_out"])

        # Start from a copy so the task's own parameters are not mutated.
        doc = self.get("additional_fields", {}).copy()
        doc["wf_uuid"] = self["wf_uuid"]
        for key in ("formula", "efermi", "nelect", "structure"):
            doc[key] = fw_spec[key]
        doc["irrep"] = irrep_data

        # store the results
        db_file = env_chk(self.get("db_file"), fw_spec)
        if db_file:
            calc_db = VaspCalcDb.from_db_file(db_file, admin=True)
            calc_db.collection = calc_db.db["irrep"]
            calc_db.collection.insert_one(doc)
            logger.info("Irrep calculation complete.")
        else:
            with open("irrep.json", "w") as handle:
                handle.write(json.dumps(doc, default=DATETIME_HANDLER))

        return FWAction()
@explicit_serialize
class IRVSPToDb(FiretaskBase):
    """
    Persist the IRVSP results (parsed from outir.txt) to a database or file.

    required_params:
        irvsp_out (IRVSPOutput): output from IRVSP calculation; when falsy,
            the "irvsp_out" entry of the spec is used instead.
        wf_uuid (str): unique wf id

    optional_params:
        db_file (str): path to the db file; when it resolves to a falsy
            value the document is written to "irvsp.json" instead.
        additional_fields (dict): dict of additional fields to add
    """

    required_params = ["irvsp_out", "wf_uuid"]
    optional_params = ["db_file", "additional_fields"]

    def run_task(self, fw_spec):
        """
        Assemble the IRVSP document and store it.

        Args:
            fw_spec (dict): Firework spec; must provide "formula", "efermi",
                "nelect" and "structure".

        Returns:
            FWAction: an empty action.
        """
        irvsp_data = jsanitize(self["irvsp_out"] or fw_spec["irvsp_out"])

        # Start from a copy so the task's own parameters are not mutated.
        doc = self.get("additional_fields", {}).copy()
        doc["wf_uuid"] = self["wf_uuid"]
        for key in ("formula", "efermi", "nelect", "structure"):
            doc[key] = fw_spec[key]
        doc["irvsp"] = irvsp_data

        # store the results
        db_file = env_chk(self.get("db_file"), fw_spec)
        if db_file:
            calc_db = VaspCalcDb.from_db_file(db_file, admin=True)
            calc_db.collection = calc_db.db["irvsp"]
            calc_db.collection.insert_one(doc)
            logger.info("IRVSP calculation complete.")
        else:
            with open("irvsp.json", "w") as handle:
                handle.write(json.dumps(doc, default=DATETIME_HANDLER))

        return FWAction()
@explicit_serialize
class Vasp2TraceToDb(FiretaskBase):
    """
    Persist the vasp2trace results to a database or file.

    required_params:
        vasp2trace_out: output of the vasp2trace run; when falsy, the
            "vasp2trace_out" entry of the spec is used instead.

    optional_params:
        db_file (str): path to the db file; when it resolves to a falsy
            value the document is written to "vasp2trace.json" instead.
    """

    required_params = ["vasp2trace_out"]
    optional_params = ["db_file"]

    def run_task(self, fw_spec):
        """
        Assemble the vasp2trace document and store it.

        Args:
            fw_spec (dict): Firework spec; must provide "formula" and
                "structure".

        Returns:
            FWAction: an empty action.
        """
        trace_data = jsanitize(self["vasp2trace_out"] or fw_spec["vasp2trace_out"])

        doc = {}
        doc["formula"] = fw_spec["formula"]
        doc["structure"] = fw_spec["structure"]
        doc["vasp2trace"] = trace_data

        # store the results
        db_file = env_chk(self.get("db_file"), fw_spec)
        if db_file:
            calc_db = VaspCalcDb.from_db_file(db_file, admin=True)
            calc_db.collection = calc_db.db["vasp2trace"]
            calc_db.collection.insert_one(doc)
            logger.info("Vasp2trace calculation complete.")
        else:
            with open("vasp2trace.json", "w") as handle:
                handle.write(json.dumps(doc, default=DATETIME_HANDLER))

        return FWAction()
@explicit_serialize
class RunVasp2Trace(FiretaskBase):
    """
    Invoke vasp2trace on the current working directory and forward its output.

    The formula and structure are read from POSCAR; both are passed on as
    None when the file is missing. The parsed ``trace.txt`` is placed in
    the spec under ``vasp2trace_out``.
    """

    def run_task(self, fw_spec):
        """
        Call vasp2trace, parse ``trace.txt`` and update the spec.

        Args:
            fw_spec (dict): Firework spec (not read by this task).

        Returns:
            FWAction: updates "vasp2trace_out", "structure" and "formula"
                in the spec.
        """
        cwd = os.getcwd()

        Vasp2TraceCaller(cwd)

        try:
            poscar_struct = Structure.from_file(cwd + "/POSCAR")
            formula = poscar_struct.composition.formula
            structure = poscar_struct.as_dict()
        except FileNotFoundError:
            formula = structure = None

        trace_output = Vasp2TraceOutput(cwd + "/trace.txt")

        spec_update = {
            "vasp2trace_out": trace_output.as_dict(),
            "structure": structure,
            "formula": formula,
        }
        return FWAction(update_spec=spec_update)
@explicit_serialize
class RunVasp2TraceMagnetic(FiretaskBase):
    """
    Invoke the spin-polarized variant of vasp2trace on the current directory.

    The formula and structure are read from POSCAR; both are passed on as
    None when the file is missing. ``trace_up.txt`` and ``trace_dn.txt``
    are parsed and placed in the spec under ``vasp2trace_out`` with the
    keys "up" and "down".
    """

    def run_task(self, fw_spec):
        """
        Call vasp2trace (version 2), parse both trace files, update the spec.

        Args:
            fw_spec (dict): Firework spec (not read by this task).

        Returns:
            FWAction: updates "vasp2trace_out", "structure" and "formula"
                in the spec.
        """
        cwd = os.getcwd()

        # version2 of vasp2trace for spin-polarized calcs
        Vasp2Trace2Caller(cwd)

        try:
            poscar_struct = Structure.from_file(cwd + "/POSCAR")
            formula = poscar_struct.composition.formula
            structure = poscar_struct.as_dict()
        except FileNotFoundError:
            formula = structure = None

        spin_up = Vasp2TraceOutput(cwd + "/trace_up.txt")
        spin_down = Vasp2TraceOutput(cwd + "/trace_dn.txt")

        trace_by_spin = {
            "up": spin_up.as_dict(),
            "down": spin_down.as_dict(),
        }
        spec_update = {
            "vasp2trace_out": trace_by_spin,
            "structure": structure,
            "formula": formula,
        }
        return FWAction(update_spec=spec_update)
@explicit_serialize
class SetUpZ2Pack(FiretaskBase):
    """
    Set up input files for a z2pack run.

    Rewrites INCAR in the current directory with the settings used for
    Z2Pack, then moves CHGCAR, INCAR, POSCAR, POTCAR and wannier90.win into
    an "input" subdirectory.

    required_params:
        ncl_magmoms (str): 3*natoms long array of x,y,z magmoms for each ion.
        wf_uuid (str): unique wf id, used to look up the static calculation.
        db_file (str): path to the db file.
    """

    required_params = ["ncl_magmoms", "wf_uuid", "db_file"]

    def run_task(self, fw_spec):
        """
        Prepare the "input" directory for Z2Pack.

        Args:
            fw_spec (dict): Firework spec, used to resolve db_file.

        Returns:
            FWAction: updates "structure", "formula" and "reduced_formula"
                in the spec (all None if POSCAR is missing).

        Raises:
            RuntimeError: if no static task document exists for wf_uuid.
        """
        ncl_magmoms = self["ncl_magmoms"]

        # Get num of electrons and bands from static calc
        uuid = self["wf_uuid"]
        db_file = env_chk(self.get("db_file"), fw_spec)
        db = VaspCalcDb.from_db_file(db_file, admin=True)
        db.collection = db.db["tasks"]

        task_doc = db.collection.find_one(
            {"wf_meta.wf_uuid": uuid, "task_label": "static"}, ["input.parameters"]
        )
        if task_doc is None:
            # find_one returns None when nothing matches; fail with a clear
            # message instead of a TypeError on the subscript below.
            raise RuntimeError(
                "No static task document found for wf_uuid %s" % uuid
            )

        nbands = int(task_doc["input"]["parameters"]["NBANDS"])

        incar = Incar.from_file("INCAR")

        # Modify INCAR for Z2Pack
        incar_update = {
            "PREC": "Accurate",
            "LSORBIT": ".TRUE.",
            "GGA_COMPAT": ".FALSE.",
            "LASPH": ".TRUE.",
            "ISMEAR": 0,
            "SIGMA": 0.05,
            "ISYM": -1,
            "LPEAD": ".FALSE.",
            "LWANNIER90": ".TRUE.",
            "LWRITE_MMN_AMN": ".TRUE.",
            "LWAVE": ".FALSE.",
            "ICHARG": 11,
            "MAGMOM": "%s" % ncl_magmoms,
            # The band count of the static calculation is doubled here.
            "NBANDS": "%d" % (2 * nbands),
        }

        incar.update(incar_update)
        incar.write_file("INCAR")

        try:
            struct = Structure.from_file("POSCAR")
            formula = struct.composition.formula
            reduced_formula = struct.composition.reduced_formula
            structure = struct.as_dict()
        except FileNotFoundError:
            formula = None
            structure = None
            reduced_formula = None

        files_to_copy = ["CHGCAR", "INCAR", "POSCAR", "POTCAR", "wannier90.win"]

        # exist_ok avoids a FileExistsError if the directory is already there.
        os.makedirs("input", exist_ok=True)
        for fname in files_to_copy:
            shutil.move(fname, "input")

        return FWAction(
            update_spec={
                "structure": structure,
                "formula": formula,
                "reduced_formula": reduced_formula,
            }
        )
@explicit_serialize
class RunZ2Pack(FiretaskBase):
    """
    Run Z2Pack on one surface using the files in the "input" directory.

    required_params:
        surface (str): TRIM surface, e.g. k_x = 0 or k_x = 1/2. The same
            string is used as the spec key under which the output is stored.
    """

    required_params = ["surface"]

    def run_task(self, fw_spec):
        """
        Execute Z2Pack and store its output in the spec.

        Args:
            fw_spec (dict): Firework spec (not read by this task).

        Returns:
            FWAction: updates the spec entry named after the surface with
                the Z2Pack output as a dict.
        """
        surface = self["surface"]

        caller = Z2PackCaller(input_dir="input", surface=surface)
        caller.run(z2_settings=None)

        return FWAction(update_spec={surface: caller.output.as_dict()})
@explicit_serialize
class Z2PackToDb(FiretaskBase):
    """
    Stores data from running Z2Pack.

    Collects whichever of the six surface entries (kx_0 ... kz_1) are
    present in the spec, together with the formula and structure, and
    stores them as one document.

    optional_params:
        db_file (str): path to the db file; when it resolves to a falsy
            value the document is written to "z2pack.json" instead.
        wf_uuid (str): unique wf id; stored as None when not supplied.
    """

    optional_params = ["db_file", "wf_uuid"]

    def run_task(self, fw_spec):
        """
        Assemble the Z2Pack document and store it.

        Args:
            fw_spec (dict): Firework spec; must provide "formula",
                "reduced_formula" and "structure".

        Returns:
            FWAction: an empty action.
        """
        # wf_uuid is declared optional, so it must not be read with
        # self["wf_uuid"], which raises KeyError when the param is omitted.
        wf_uuid = self.get("wf_uuid")

        surfaces = ["kx_0", "kx_1", "ky_0", "ky_1", "kz_0", "kz_1"]

        d = {
            "wf_uuid": wf_uuid,
            "formula": fw_spec["formula"],
            "reduced_formula": fw_spec["reduced_formula"],
            "structure": fw_spec["structure"],
        }

        for surface in surfaces:
            if surface in fw_spec:
                d[surface] = fw_spec[surface]

        d = jsanitize(d)

        # store the results
        db_file = env_chk(self.get("db_file"), fw_spec)
        if not db_file:
            with open("z2pack.json", "w") as f:
                f.write(json.dumps(d, default=DATETIME_HANDLER))
        else:
            db = VaspCalcDb.from_db_file(db_file, admin=True)
            db.collection = db.db["z2pack"]
            db.collection.insert_one(d)
            logger.info("Z2Pack surface calculation complete.")

        return FWAction()
@explicit_serialize
class WriteWannier90Win(FiretaskBase):
    """
    Write the wannier90.win input file for Z2Pack.

    NELECT and NBANDS are taken from the "static" task document of the
    workflow identified by wf_uuid.

    required_params:
        wf_uuid (str): Unique identifier
        db_file (str): path to the db file
    """

    required_params = ["wf_uuid", "db_file"]

    def run_task(self, fw_spec):
        """
        Look up the static calculation and write wannier90.win.

        Args:
            fw_spec (dict): Firework spec, used to resolve db_file.

        Returns:
            FWAction: an empty action.

        Raises:
            RuntimeError: if no static task document exists for wf_uuid.
        """
        # Get num of electrons and bands from static calc
        uuid = self["wf_uuid"]
        db_file = env_chk(self.get("db_file"), fw_spec)
        db = VaspCalcDb.from_db_file(db_file, admin=True)
        db.collection = db.db["tasks"]

        task_doc = db.collection.find_one(
            {"wf_meta.wf_uuid": uuid, "task_label": "static"}, ["input.parameters"]
        )
        if task_doc is None:
            # find_one returns None when nothing matches; fail with a clear
            # message instead of a TypeError on the subscript below.
            raise RuntimeError(
                "No static task document found for wf_uuid %s" % uuid
            )

        parameters = task_doc["input"]["parameters"]
        nelec = int(parameters["NELECT"])
        nbands = int(parameters["NBANDS"])

        w90_lines = [
            "num_wann = %d" % (nelec),
            "num_bands = %d" % (nelec),  # 1 band / elec with SOC
            "spinors=.true.",
            "num_iter 0",
            "shell_list 1",
            "exclude_bands %d-%d" % (nelec + 1, 2 * nbands),
        ]

        with open("wannier90.win", "w") as f:
            f.write("\n".join(w90_lines))

        return FWAction()
@explicit_serialize
class InvariantsToDB(FiretaskBase):
    """
    Collect per-surface Z2 and Chern numbers from stored Z2Pack documents
    and store the resulting invariants.

    required_params:
        wf_uuid (str): Unique wf identifier.
        db_file (str): path to the db file.
        structure: structure of the material; its composition and as_dict()
            are stored with the result.
        symmetry_reduction (bool): Set to False to disable symmetry reduction
            and include all 6 BZ surfaces (for magnetic systems).
        equiv_planes (dict): of the form {kx_0': ['ky_0', 'kz_0']}.
    """

    required_params = [
        "wf_uuid",
        "db_file",
        "structure",
        "symmetry_reduction",
        "equiv_planes",
    ]

    def run_task(self, fw_spec):
        """
        Gather surface invariants, fill in equivalent planes, compute the
        Z2 tuple and insert the result into the "z2pack" collection.

        Args:
            fw_spec (dict): Firework spec, used to resolve db_file.

        Returns:
            FWAction: an empty action.
        """
        all_surfaces = ["kx_0", "kx_1", "ky_0", "ky_1", "kz_0", "kz_1"]

        structure = self["structure"]
        symmetry_reduction = self["symmetry_reduction"]
        equiv_planes = self["equiv_planes"]

        # Get invariants for each surface
        uuid = self["wf_uuid"]
        db_file = env_chk(self.get("db_file"), fw_spec)
        db = VaspCalcDb.from_db_file(db_file, admin=True)
        db.collection = db.db["z2pack"]

        z2_dict = {}
        chern_dict = {}
        for doc in db.collection.find({"wf_uuid": uuid}):
            for name in all_surfaces:
                if name not in doc:
                    continue
                surface_data = doc[name]
                z2_dict[name] = surface_data["z2_invariant"]
                chern_dict[name] = surface_data["chern_number"]

        # Write invariants for equivalent planes
        if symmetry_reduction and len(z2_dict) < 6:  # some equivalent planes
            for computed, equivalents in equiv_planes.items():
                # The Z2 and Chern tables are filled by the same rule: a
                # plane without its own value inherits the computed one.
                for table in (z2_dict, chern_dict):
                    if computed not in table:
                        continue
                    for plane in equivalents:
                        if plane not in table:
                            table[plane] = table[computed]

        # Compute Z2 invariant
        needed = ("kx_1", "ky_1", "kz_0", "kz_1")
        if any(name not in z2_dict for name in needed):
            z2 = (np.nan, np.nan, np.nan, np.nan)
        else:
            v0 = (z2_dict["kz_0"] + z2_dict["kz_1"]) % 2
            z2 = (v0, z2_dict["kx_1"], z2_dict["ky_1"], z2_dict["kz_1"])

        # store the results
        composition = structure.composition
        result = jsanitize(
            {
                "wf_uuid": uuid,
                "task_label": "topological invariants",
                "formula": composition.formula,
                "reduced_formula": composition.reduced_formula,
                "structure": structure.as_dict(),
                "z2_dict": z2_dict,
                "chern_dict": chern_dict,
                "z2": z2,
                "equiv_planes": equiv_planes,
                "symmetry_reduction": symmetry_reduction,
            }
        )

        db.collection.insert_one(result)

        return FWAction()