[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
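(These fixes can typically be reproduced locally by running pre-commit run --all-files, assuming the repository defines its formatting hooks in a .pre-commit-config.yaml at the repo root.)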
pre-commit-ci[bot] committed Apr 23, 2024
1 parent 1575ecf commit 6d548fb
Showing 1 changed file with 72 additions and 23 deletions.
--- a/graphein/protein/graphs.py
+++ b/graphein/protein/graphs.py
@@ -96,12 +96,18 @@ def read_pdb_to_dataframe(
     :rtype: pd.DataFrame
     """
     if pdb_code is None and path is None and uniprot_id is None:
-        raise NameError("One of pdb_code, path or uniprot_id must be specified!")
+        raise NameError(
+            "One of pdb_code, path or uniprot_id must be specified!"
+        )
 
     if path is not None:
         if isinstance(path, Path):
             path = os.fsdecode(path)
-        if path.endswith(".pdb") or path.endswith(".pdb.gz") or path.endswith(".ent"):
+        if (
+            path.endswith(".pdb")
+            or path.endswith(".pdb.gz")
+            or path.endswith(".ent")
+        ):
             atomic_df = PandasPdb().read_pdb(path)
         elif path.endswith(".mmtf") or path.endswith(".mmtf.gz"):
             atomic_df = PandasMmtf().read_mmtf(path)
@@ -110,7 +116,9 @@ def read_pdb_to_dataframe(
                 f"File {path} must be either .pdb(.gz), .mmtf(.gz) or .ent, not {path.split('.')[-1]}"
             )
     elif uniprot_id is not None:
-        atomic_df = PandasPdb().fetch_pdb(uniprot_id=uniprot_id, source="alphafold2-v3")
+        atomic_df = PandasPdb().fetch_pdb(
+            uniprot_id=uniprot_id, source="alphafold2-v3"
+        )
     else:
         atomic_df = PandasPdb().fetch_pdb(pdb_code)
 
@@ -164,7 +172,11 @@ def label_node_id(
         df["node_id"] = df["node_id"] + ":" + df["atom_name"]
     elif granularity in {"rna_atom", "rna_centroid"}:
         df["node_id"] = (
-            df["node_id"] + ":" + df["atom_number"].apply(str) + ":" + df["atom_name"]
+            df["node_id"]
+            + ":"
+            + df["atom_number"].apply(str)
+            + ":"
+            + df["atom_name"]
         )
     return df
 
@@ -177,7 +189,9 @@ def deprotonate_structure(df: pd.DataFrame) -> pd.DataFrame:
     :returns: Atomic dataframe with all ``element_symbol == "H" or "D" or "T"`` removed.
     :rtype: pd.DataFrame
     """
-    log.debug("Deprotonating protein. This removes H atoms from the pdb_df dataframe")
+    log.debug(
+        "Deprotonating protein. This removes H atoms from the pdb_df dataframe"
+    )
     return filter_dataframe(
         df,
         by_column="element_symbol",
@@ -211,7 +225,9 @@ def convert_structure_to_centroids(df: pd.DataFrame) -> pd.DataFrame:
     return df
 
 
-def subset_structure_to_atom_type(df: pd.DataFrame, granularity: str) -> pd.DataFrame:
+def subset_structure_to_atom_type(
+    df: pd.DataFrame, granularity: str
+) -> pd.DataFrame:
     """
     Return a subset of atomic dataframe that contains only certain atom names.
 
@@ -225,7 +241,9 @@ def subset_structure_to_atom_type(df: pd.DataFrame, granularity: str) -> pd.DataFrame:
     )
 
 
-def remove_alt_locs(df: pd.DataFrame, keep: str = "max_occupancy") -> pd.DataFrame:
+def remove_alt_locs(
+    df: pd.DataFrame, keep: str = "max_occupancy"
+) -> pd.DataFrame:
     """
     This function removes alternatively located atoms from PDB DataFrames
     (see https://proteopedia.org/wiki/index.php/Alternate_locations). Among the
@@ -289,7 +307,9 @@ def remove_insertions(
     )
 
 
-def filter_hetatms(df: pd.DataFrame, keep_hets: List[str]) -> List[pd.DataFrame]:
+def filter_hetatms(
+    df: pd.DataFrame, keep_hets: List[str]
+) -> List[pd.DataFrame]:
     """Return hetatms of interest.
 
     :param df: Protein Structure dataframe to filter hetatoms from.
@@ -434,7 +454,9 @@ def sort_dataframe(df: pd.DataFrame) -> pd.DataFrame:
     :return: Sorted protein dataframe.
     :rtype: pd.DataFrame
     """
-    return df.sort_values(by=["chain_id", "residue_number", "atom_number", "insertion"])
+    return df.sort_values(
+        by=["chain_id", "residue_number", "atom_number", "insertion"]
+    )
 
 
 def select_chains(
@@ -536,7 +558,8 @@ def initialise_graph_with_metadata(
         elif granularity == "atom":
             sequence = (
                 protein_df.loc[
-                    (protein_df["chain_id"] == c) & (protein_df["atom_name"] == "CA")
+                    (protein_df["chain_id"] == c)
+                    & (protein_df["atom_name"] == "CA")
                 ]["residue_name"]
                 .apply(three_to_one_with_mods)
                 .str.cat()
@@ -587,9 +610,13 @@ def add_nodes_to_graph(
     # Set intrinsic node attributes
     nx.set_node_attributes(G, dict(zip(nodes, chain_id)), "chain_id")
     nx.set_node_attributes(G, dict(zip(nodes, residue_name)), "residue_name")
-    nx.set_node_attributes(G, dict(zip(nodes, residue_number)), "residue_number")
+    nx.set_node_attributes(
+        G, dict(zip(nodes, residue_number)), "residue_number"
+    )
     nx.set_node_attributes(G, dict(zip(nodes, atom_type)), "atom_type")
-    nx.set_node_attributes(G, dict(zip(nodes, element_symbol)), "element_symbol")
+    nx.set_node_attributes(
+        G, dict(zip(nodes, element_symbol)), "element_symbol"
+    )
     nx.set_node_attributes(G, dict(zip(nodes, coords)), "coords")
     nx.set_node_attributes(G, dict(zip(nodes, b_factor)), "b_factor")
 
@@ -615,7 +642,9 @@ def calculate_centroid_positions(
     :rtype: pd.DataFrame
     """
     centroids = (
-        atoms.groupby(["residue_number", "chain_id", "residue_name", "insertion"])
+        atoms.groupby(
+            ["residue_number", "chain_id", "residue_name", "insertion"]
+        )
         .mean(numeric_only=True)[["x_coord", "y_coord", "z_coord"]]
        .reset_index()
     )
@@ -873,9 +902,13 @@ def _mp_graph_constructor(
     func = partial(construct_graph, config=config)
     try:
         if source == "pdb_code":
-            return func(pdb_code=args[0], chain_selection=args[1], model_index=args[2])
+            return func(
+                pdb_code=args[0], chain_selection=args[1], model_index=args[2]
+            )
         elif source == "path":
-            return func(path=args[0], chain_selection=args[1], model_index=args[2])
+            return func(
+                path=args[0], chain_selection=args[1], model_index=args[2]
+            )
         elif source == "uniprot_id":
             return func(
                 uniprot_id=args[0],
@@ -971,7 +1004,9 @@ def construct_graphs_mp(
     )
     if out_path is not None:
         [
-            nx.write_gpickle(g, str(f"{out_path}/" + f"{g.graph['name']}.pickle"))
+            nx.write_gpickle(
+                g, str(f"{out_path}/" + f"{g.graph['name']}.pickle")
+            )
             for g in graphs
         ]
 
@@ -1035,11 +1070,15 @@ def compute_chain_graph(
 
     # Add edges
     for u, v, d in g.edges(data=True):
-        h.add_edge(g.nodes[u]["chain_id"], g.nodes[v]["chain_id"], kind=d["kind"])
+        h.add_edge(
+            g.nodes[u]["chain_id"], g.nodes[v]["chain_id"], kind=d["kind"]
+        )
     # Remove self-loops if necessary. Checks for equality between nodes in a
     # given edge.
     if remove_self_loops:
-        edges_to_remove: List[Tuple[str]] = [(u, v) for u, v in h.edges() if u == v]
+        edges_to_remove: List[Tuple[str]] = [
+            (u, v) for u, v in h.edges() if u == v
+        ]
         h.remove_edges_from(edges_to_remove)
 
     # Compute a weighted graph if required.
@@ -1142,10 +1181,16 @@ def compute_secondary_structure_graph(
     ss_list = ss_list[~ss_list.str.contains("-")]
     # Subset to only allowable SS elements if necessary
     if allowable_ss_elements:
-        ss_list = ss_list[ss_list.str.contains("|".join(allowable_ss_elements))]
+        ss_list = ss_list[
+            ss_list.str.contains("|".join(allowable_ss_elements))
+        ]
 
-    constituent_residues: Dict[str, List[str]] = ss_list.index.groupby(ss_list.values)
-    constituent_residues = {k: list(v) for k, v in constituent_residues.items()}
+    constituent_residues: Dict[str, List[str]] = ss_list.index.groupby(
+        ss_list.values
+    )
+    constituent_residues = {
+        k: list(v) for k, v in constituent_residues.items()
+    }
     residue_counts: Dict[str, int] = ss_list.groupby(ss_list).count().to_dict()
 
     # Add Nodes from secondary structure list
@@ -1164,7 +1209,9 @@
     # Iterate over edges in source graph and add SS-SS edges to new graph.
     for u, v, d in g.edges(data=True):
         try:
-            h.add_edge(ss_list[u], ss_list[v], kind=d["kind"], source=f"{u}_{v}")
+            h.add_edge(
+                ss_list[u], ss_list[v], kind=d["kind"], source=f"{u}_{v}"
+            )
         except KeyError as e:
             log.debug(
                 f"Edge {u}-{v} not added to secondary structure graph. \
@@ -1174,7 +1221,9 @@
     # Remove self-loops if necessary.
     # Checks for equality between nodes in a given edge.
     if remove_self_loops:
-        edges_to_remove: List[Tuple[str]] = [(u, v) for u, v in h.edges() if u == v]
+        edges_to_remove: List[Tuple[str]] = [
+            (u, v) for u, v in h.edges() if u == v
+        ]
         h.remove_edges_from(edges_to_remove)
 
     # Create weighted graph from h
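The hunks above are pure line-length reflows from the formatter; no behavior changes. As a quick smoke test that the touched code paths still run, here is a minimal sketch. It assumes graphein and its dependencies (biopandas, networkx) are installed and that read_pdb_to_dataframe and sort_dataframe keep the signatures shown in the diff; the PDB code "4hhb" is an arbitrary example.

import pandas as pd

from graphein.protein.graphs import read_pdb_to_dataframe, sort_dataframe

# Fetch a structure by PDB code; per the diff, one of pdb_code, path or
# uniprot_id must be supplied, and a pd.DataFrame of atomic records is returned.
df: pd.DataFrame = read_pdb_to_dataframe(pdb_code="4hhb")

# Sort atoms by chain, residue number, atom number and insertion code,
# matching the reformatted sort_values call above.
df = sort_dataframe(df)
print(df.head())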
