Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Beta #149

Open
wants to merge 28 commits into
base: main
Choose a base branch
from
Open

Beta #149

Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
cf97a17
update prep.py to accept mulitple lengths
sokrypton May 7, 2023
081f0ab
Update proteinmpnn_in_jax.ipynb
sokrypton May 8, 2023
196d64b
Update config.py
sokrypton May 10, 2023
a1e57c7
Update utils.py
sokrypton May 11, 2023
6bad4a1
clear outputs
sokrypton May 11, 2023
ed79531
cleanup
sokrypton May 11, 2023
37b3ed3
cleanup
sokrypton May 11, 2023
62d28db
Update prep.py
sokrypton May 11, 2023
abf4f9f
fixing MSA inputs and cleaning up prep
sokrypton May 15, 2023
7cead3b
Update prep.py
sokrypton May 15, 2023
886340d
use asym_id inside module.py to correct offset
sokrypton May 15, 2023
651b1aa
fix shuffle_first
sokrypton May 15, 2023
cb79f54
Update model.py
sokrypton May 15, 2023
a92f3de
adding i_pae to fixbb protocol (if more than 1 chain used)
sokrypton May 15, 2023
ff7e5f1
Update utils.py
sokrypton May 15, 2023
f9e8c2e
Update prep.py
sokrypton May 17, 2023
2fe6863
Update design.py
sokrypton May 17, 2023
127aa64
minor edit
sokrypton May 17, 2023
ff0b0b5
Update plot.py
sokrypton May 18, 2023
4f0c3ce
refactoring dgram recycling (#146)
sokrypton Jun 6, 2023
aa08462
adding ability to finetune alphafold params (#147)
sokrypton Jun 13, 2023
2aeb993
bugfix MSA input (#148)
sokrypton Jun 13, 2023
4c49940
bugfix - update to work with latest refactoring
sokrypton Jun 13, 2023
5494d7d
add option to control model_params within pre_callback
sokrypton Jun 14, 2023
b582249
Update design.py
sokrypton Jun 17, 2023
598c77c
rm nomem mpnn weights
sokrypton Jun 27, 2023
df4a24c
major bugfix: within multimer model outer_product_mean not set correctly
sokrypton Jul 25, 2023
e548f9f
update to fix error message
sokrypton Oct 9, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
288 changes: 288 additions & 0 deletions af/examples/af_single.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,288 @@
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "AlphaFold_single.ipynb",
"provenance": [],
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/github/sokrypton/ColabDesign/blob/beta/af/examples/af_single.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"source": [
"# AlphaFold - single sequence input\n",
"- WARNING - For DEMO and educational purposes only.\n",
"- For natural proteins you often need more than a single sequence to accurately predict the structure. See [ColabFold](https://colab.research.google.com/github/sokrypton/ColabFold/blob/main/AlphaFold2.ipynb) notebook if you want to predict the protein structure from a multiple-sequence-alignment. That being said, this notebook could be useful for evaluating *de novo* designed proteins and learning the idealized principles of proteins.\n",
"\n",
"### Tips and Instructions\n",
"- Patience... The first time you run the cell below it will take about 1 minute to set up; after that it should run in seconds (after each change).\n",
"- Click the little ▶ play icon to the left of each cell below.\n",
"- For 3D display, hover the mouse over an amino acid to see its name and position number.\n",
"- Use \"/\" to specify chain breaks (e.g. sequence=\"AAA/AAA\").\n"
],
"metadata": {
"id": "VpfCw7IzVHXv"
}
},
{
"cell_type": "code",
"source": [
"#@title Enter the amino acid sequence to fold ⬇️\n",
"\n",
"###############################################################################\n",
"###############################################################################\n",
"#@title Setup\n",
"# import libraries\n",
"import os,sys,re,time\n",
"\n",
"if \"SETUP_DONE\" not in dir():\n",
" from IPython.utils import io\n",
" from IPython.display import HTML\n",
" import numpy as np\n",
" import matplotlib\n",
" from matplotlib import animation\n",
" import matplotlib.pyplot as plt\n",
" import tqdm.notebook\n",
" TQDM_BAR_FORMAT = '{l_bar}{bar}| {n_fmt}/{total_fmt} [elapsed: {elapsed} remaining: {remaining}]'\n",
"\n",
" if not os.path.isdir(\"params\"):\n",
" os.system(\"wget -qnc https://raw.githubusercontent.com/sokrypton/ColabFold/main/beta/colabfold.py\")\n",
" # get code\n",
" print(\"installing ColabDesign...\")\n",
" os.system(\"(mkdir params; apt-get install aria2 -qq; \\\n",
" aria2c -q -x 16 https://storage.googleapis.com/alphafold/alphafold_params_2021-07-14.tar; \\\n",
" tar -xf alphafold_params_2021-07-14.tar -C params; \\\n",
" touch params/done.txt )&\")\n",
"\n",
" os.system(\"pip -q install git+https://github.com/sokrypton/ColabDesign.git@beta\")\n",
" os.system(\"ln -s /usr/local/lib/python3.*/dist-packages/colabdesign colabdesign\")\n",
"\n",
" # download params\n",
" if not os.path.isfile(\"params/done.txt\"):\n",
" print(\"downloading AlphaFold params...\")\n",
" while not os.path.isfile(\"params/done.txt\"):\n",
" time.sleep(5)\n",
"\n",
" # configure which device to use\n",
" import jax\n",
" # disable triton_gemm for jax versions > 0.3\n",
" if int(jax.__version__.split(\".\")[1]) > 3:\n",
" os.environ[\"XLA_FLAGS\"] = \"--xla_gpu_enable_triton_gemm=false\"\n",
" import jax.numpy as jnp\n",
" try:\n",
" # check if TPU is available\n",
" import jax.tools.colab_tpu\n",
" jax.tools.colab_tpu.setup_tpu()\n",
" print('Running on TPU')\n",
" DEVICE = \"tpu\"\n",
" except:\n",
" if jax.local_devices()[0].platform == 'cpu':\n",
" print(\"WARNING: no GPU detected, will be using CPU\")\n",
" DEVICE = \"cpu\"\n",
" else:\n",
" print('Running on GPU')\n",
" DEVICE = \"gpu\"\n",
"\n",
" # import libraries\n",
" sys.path.append('af_backprop')\n",
"\n",
" SETUP_DONE = True\n",
"\n",
"if \"LIBRARY_IMPORTED\" not in dir():\n",
" from colabdesign.af.loss import get_plddt, get_pae\n",
" from colabdesign.af.prep import prep_input_features\n",
" from colabdesign.af.inputs import update_seq, update_aatype\n",
" from colabdesign.af.alphafold.common import protein\n",
" from colabdesign.af.alphafold.model import data, config, model\n",
" from colabdesign.af.alphafold.common import residue_constants\n",
" from colabdesign.rf.utils import make_animation\n",
" import py3Dmol\n",
" import colabfold as cf\n",
"\n",
" # setup model\n",
" cfg = config.model_config(\"model_5_ptm\")\n",
" cfg.model.num_recycle = 0\n",
" cfg.model.global_config.subbatch_size = None\n",
" model_name=\"model_2_ptm\"\n",
" model_params = data.get_model_haiku_params(model_name=model_name,\n",
" data_dir=\".\",\n",
" fuse=True)\n",
" model_runner = model.RunModel(cfg, model_params)\n",
"\n",
" def setup_model(max_len):\n",
"\n",
" seq = \"A\" * max_len\n",
" length = len(seq)\n",
" inputs = prep_input_features(length)\n",
"\n",
" def runner(I):\n",
" # update sequence\n",
" inputs = I[\"inputs\"]\n",
" inputs[\"prev\"] = I[\"prev\"]\n",
"\n",
" seq_oh = jax.nn.one_hot(I[\"seq\"],20)[None]\n",
" update_seq(seq_oh, inputs)\n",
" update_aatype(seq_oh, inputs)\n",
"\n",
" # mask prediction\n",
" mask = jnp.arange(inputs[\"residue_index\"].shape[0]) < I[\"length\"]\n",
" inputs[\"seq_mask\"] = inputs[\"seq_mask\"].at[:].set(mask)\n",
" inputs[\"msa_mask\"] = inputs[\"msa_mask\"].at[:].set(mask)\n",
" inputs[\"residue_index\"] = jnp.where(mask, inputs[\"residue_index\"], 0)\n",
"\n",
" # get prediction\n",
" key = jax.random.PRNGKey(0)\n",
" outputs = model_runner.apply(I[\"params\"], key, inputs)\n",
"\n",
" aux = {\"final_atom_positions\":outputs[\"structure_module\"][\"final_atom_positions\"],\n",
" \"final_atom_mask\":outputs[\"structure_module\"][\"final_atom_mask\"],\n",
" \"plddt\":get_plddt(outputs),\"pae\":get_pae(outputs),\n",
" \"length\":I[\"length\"], \"seq\":I[\"seq\"],\n",
" \"prev\":outputs[\"prev\"],\n",
" \"residue_idx\":inputs[\"residue_index\"]}\n",
" return aux\n",
"\n",
" return jax.jit(runner), {\"inputs\":inputs, \"params\":model_params, \"length\":max_length}\n",
"\n",
" def save_pdb(outs, filename):\n",
" '''save pdb coordinates'''\n",
" p = {\"residue_index\":outs[\"residue_idx\"] + 1,\n",
" \"aatype\":outs[\"seq\"],\n",
" \"atom_positions\":outs[\"final_atom_positions\"],\n",
" \"atom_mask\":outs[\"final_atom_mask\"],\n",
" \"plddt\":outs[\"plddt\"]}\n",
" p = jax.tree_util.tree_map(lambda x:x[:outs[\"length\"]], p)\n",
" b_factors = 100 * p.pop(\"plddt\")[:,None] * p[\"atom_mask\"]\n",
" p = protein.Protein(**p,b_factors=b_factors)\n",
" pdb_lines = protein.to_pdb(p)\n",
" with open(filename, 'w') as f: f.write(pdb_lines)\n",
"\n",
" LIBRARY_IMPORTED = True\n",
"\n",
"###############################################################################\n",
"###############################################################################\n",
"\n",
"# initialize\n",
"if \"current_seq\" not in dir():\n",
" current_seq = \"\"\n",
" r = -1\n",
" max_length = -1\n",
"\n",
"# collect user inputs\n",
"sequence = 'GGGGGGGGGG' #@param {type:\"string\"}\n",
"recycles = 0 #@param [\"0\", \"1\", \"2\", \"3\", \"6\", \"12\", \"24\", \"48\"] {type:\"raw\"}\n",
"ori_sequence = re.sub(\"[^A-Z\\/\\:]\", \"\", sequence.upper())\n",
"Ls = [len(s) for s in ori_sequence.replace(\":\",\"/\").split(\"/\")]\n",
"sequence = re.sub(\"[^A-Z]\",\"\",ori_sequence)\n",
"length = len(sequence)\n",
"\n",
"# avoid recompiling if length within 10\n",
"if length > max_length or (max_length - length) > 20:\n",
" max_length = length + 10\n",
" runner, I = setup_model(max_length)\n",
"\n",
"if ori_sequence != current_seq:\n",
" outs = []\n",
" positions = []\n",
" plddts = []\n",
" paes = []\n",
" r = -1\n",
"\n",
" # pad sequence to max length\n",
" seq = np.array([residue_constants.restype_order.get(aa,0) for aa in sequence])\n",
" seq = np.pad(seq,[0,max_length-length],constant_values=-1)\n",
"\n",
" # update inputs, restart recycle\n",
" I.update({\"seq\":seq, \"length\":length,\n",
" \"prev\":{'prev_msa_first_row': np.zeros([max_length, 256]),\n",
" 'prev_pair': np.zeros([max_length, max_length, 128]),\n",
" 'prev_pos': np.zeros([max_length, 37, 3])}})\n",
"\n",
" I[\"inputs\"][\"use_dropout\"] = False\n",
" I[\"inputs\"]['residue_index'][:] = cf.chain_break(np.arange(max_length), Ls, length=32)\n",
" current_seq = ori_sequence\n",
"\n",
"# run for defined number of recycles\n",
"with tqdm.notebook.tqdm(total=(recycles+1), bar_format=TQDM_BAR_FORMAT) as pbar:\n",
" p = 0\n",
" while p < min(r+1,recycles+1):\n",
" pbar.update(1)\n",
" p += 1\n",
" while r < recycles:\n",
" O = runner(I)\n",
" O = jax.tree_util.tree_map(lambda x:np.asarray(x), O)\n",
" positions.append(O[\"final_atom_positions\"][:length])\n",
" plddts.append(O[\"plddt\"][:length])\n",
" paes.append(O[\"pae\"][:length,:length])\n",
" I[\"prev\"] = O[\"prev\"]\n",
" outs.append(O)\n",
" r += 1\n",
" pbar.update(1)\n",
"\n",
"#@markdown #### Display options\n",
"color = \"confidence\" #@param [\"chain\", \"confidence\", \"rainbow\"]\n",
"if color == \"confidence\": color = \"lDDT\"\n",
"show_sidechains = True #@param {type:\"boolean\"}\n",
"show_mainchains = False #@param {type:\"boolean\"}\n",
"\n",
"print(f\"plotting prediction at recycle={recycles}\")\n",
"save_pdb(outs[recycles], \"out.pdb\")\n",
"v = cf.show_pdb(\"out.pdb\", show_sidechains, show_mainchains, color,\n",
" color_HP=True, size=(800,480), Ls=Ls)\n",
"v.setHoverable({}, True,\n",
" '''function(atom,viewer,event,container){if(!atom.label){atom.label=viewer.addLabel(\" \"+atom.resn+\":\"+atom.resi,{position:atom,backgroundColor:'mintcream',fontColor:'black'});}}''',\n",
" '''function(atom,viewer){if(atom.label){viewer.removeLabel(atom.label);delete atom.label;}}''')\n",
"v.show()\n",
"if color == \"lDDT\":\n",
" cf.plot_plddt_legend().show()\n",
"\n",
"# add confidence plots\n",
"cf.plot_confidence(plddts[recycles]*100, paes[recycles], Ls=Ls).show()"
],
"metadata": {
"cellView": "form",
"id": "cAoC4ar8G7ZH"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"#@title Animate\n",
"#@markdown - Animate trajectory if more than 0 recycle(s)\n",
"HTML(make_animation(np.asarray(positions)[...,1,:],\n",
" np.asarray(plddts) * 100.0,\n",
" Ls=Ls,\n",
" ref=-1, align_to_ref=True,\n",
" verbose=True))"
],
"metadata": {
"cellView": "form",
"id": "tdjdC0KFPjWw"
},
"execution_count": null,
"outputs": []
}
]
}
5 changes: 3 additions & 2 deletions colabdesign/af/alphafold/model/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,8 @@ def model_config(name: str) -> ml_collections.ConfigDict:
'multimer_mode': False,
'subbatch_size': 4,
'use_remat': False,
'zero_init': True
'zero_init': True,
'use_dgram_pred': False,
},
'heads': {
'distogram': {
Expand Down Expand Up @@ -536,7 +537,7 @@ def model_config(name: str) -> ml_collections.ConfigDict:
'subbatch_size': 4,
'use_remat': False,
'zero_init': True,
'use_dgram': False
'use_dgram_pred': False,
},
'heads': {
'distogram': {
Expand Down
4 changes: 2 additions & 2 deletions colabdesign/af/alphafold/model/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def casp_model_names(data_dir: str) -> List[str]:
return [os.path.splitext(filename)[0] for filename in params]


def get_model_haiku_params(model_name: str, data_dir: str, fuse: bool = None) -> hk.Params:
def get_model_haiku_params(model_name: str, data_dir: str, fuse: bool = None, rm_templates: bool = False) -> hk.Params:
"""Get the Haiku parameters from a model name."""

path = os.path.join(data_dir, 'params', f'params_{model_name}.npz')
Expand All @@ -38,4 +38,4 @@ def get_model_haiku_params(model_name: str, data_dir: str, fuse: bool = None) ->
if os.path.isfile(path):
with open(path, 'rb') as f:
params = np.load(io.BytesIO(f.read()), allow_pickle=False)
return utils.flat_params_to_haiku(params, fuse=fuse)
return utils.flat_params_to_haiku(params, fuse=fuse, rm_templates=rm_templates)
Loading