diff --git a/.github/workflows/examples-ci.yaml b/.github/workflows/examples-ci.yaml index 05a7bc1..bacb651 100644 --- a/.github/workflows/examples-ci.yaml +++ b/.github/workflows/examples-ci.yaml @@ -31,7 +31,7 @@ jobs: fail-fast: false matrix: os: [macOS-12, ubuntu-latest] - python-version: ["3.10", "3.11"] + python-version: ["3.11", "3.12"] pydantic-version: ["2"] include-rdkit: [true] include-openeye: [false] diff --git a/devtools/conda-envs/examples_env.yaml b/devtools/conda-envs/examples_env.yaml index cce6a95..e0d83e0 100644 --- a/devtools/conda-envs/examples_env.yaml +++ b/devtools/conda-envs/examples_env.yaml @@ -22,6 +22,8 @@ dependencies: - openff-toolkit >=0.11.1 - openff-units - openff-recharge + - openff-qcsubmit + - psi4 - pydantic <3 - rdkit @@ -37,6 +39,9 @@ dependencies: # parallelism - dask-jobqueue + # compatibility + - apsw >=3.42 + # CI - nbval diff --git a/docs/CHANGELOG.md b/docs/CHANGELOG.md index 381a01a..45b6363 100644 --- a/docs/CHANGELOG.md +++ b/docs/CHANGELOG.md @@ -19,6 +19,9 @@ The rules for this file: ### Authors - [@lilyminium] +### Added +- General linear fit target and example (PR #131) + ### Changed - Removed unused, undocumented code paths, and updated docs (PR #132) diff --git a/examples/README.md b/examples/README.md index e2a19a1..0f56f1c 100644 --- a/examples/README.md +++ b/examples/README.md @@ -6,3 +6,4 @@ The following examples are available in [the OpenFF NAGL repository](https://git * [prepare-dataset](https://github.com/openforcefield/openff-nagl/tree/main/examples/prepare-dataset) - Prepare a dataset for training, validating or testing NAGL models from a list of SMILES and the OpenFF Toolkit * [train-gnn-notebook](https://github.com/openforcefield/openff-nagl/tree/main/examples/train-gnn-notebook) - Architect, train, and use a simple GCN partial charge model on an alkane test dataset +* [train-gnn-to-electric-field](https://github.com/openforcefield/openff-nagl/tree/main/examples/train-gnn-to-electric-field) - Prepare a dataset from QM data, set up training and validation sets, create a GNN and train to electric field data \ No newline at end of file diff --git a/examples/train-gnn-to-electric-field/.gitignore b/examples/train-gnn-to-electric-field/.gitignore new file mode 100644 index 0000000..b8f2783 --- /dev/null +++ b/examples/train-gnn-to-electric-field/.gitignore @@ -0,0 +1,3 @@ +lightning_logs +tmp.pkl +api.qcarchive.molssi.org_443 \ No newline at end of file diff --git a/examples/train-gnn-to-electric-field/input-env.yaml b/examples/train-gnn-to-electric-field/input-env.yaml new file mode 100644 index 0000000..550cbef --- /dev/null +++ b/examples/train-gnn-to-electric-field/input-env.yaml @@ -0,0 +1,16 @@ +name: train-gnn-to-electric-field + +channels: + - conda-forge + +dependencies: + - openff-qcsubmit + - openff-recharge ==0.5.2 + - psi4 + + - jupyter + - tqdm + - pip + + - pip: + - git+https://github.com/openforcefield/openff-nagl.git@main \ No newline at end of file diff --git a/examples/train-gnn-to-electric-field/train-gnn-to-electric-field.ipynb b/examples/train-gnn-to-electric-field/train-gnn-to-electric-field.ipynb new file mode 100644 index 0000000..0defa7f --- /dev/null +++ b/examples/train-gnn-to-electric-field/train-gnn-to-electric-field.ipynb @@ -0,0 +1,2145 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "94ff1817-f641-4a11-ae4a-91c7a6c73a7e", + "metadata": {}, + "source": [ + "# Train a GNN directly to an electric field\n", + "\n", + "To execute this example fully, the following packages are required.\n", +
"\n", + "* openff-nagl\n", + "* openff-recharge\n", + "* openff-qcsubmit\n", + "* psi4\n", + "\n", + "However, if you wish to just follow along the training part without first creating the training datasets yourself, you can get away with just `openff-nagl` installed and simply load the training/validation data from the provided `.parquet` files. The commands are provided at the end of the \"Generate and format training data\" section, but commented out." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "522bc606-c1aa-4cc7-89da-87440c61f8ba", + "metadata": {}, + "outputs": [], + "source": [ + "import tqdm\n", + "\n", + "from qcportal import PortalClient\n", + "from openff.units import unit\n", + "\n", + "from openff.toolkit import Molecule\n", + "from openff.qcsubmit.results import BasicResultCollection\n", + "from openff.recharge.esp.storage import MoleculeESPRecord\n", + "from openff.recharge.esp.qcresults import from_qcportal_results\n", + "from openff.recharge.grids import MSKGridSettings\n", + "from openff.recharge.utilities.geometry import compute_vector_field\n", + "\n", + "import pyarrow as pa\n", + "import pyarrow.parquet as pq\n", + "import numpy as np\n", + "\n", + "import torch" + ] + }, + { + "cell_type": "markdown", + "id": "097895ad-685b-4825-8c22-1a0f79467e9d", + "metadata": {}, + "source": [ + "## Generate and format training data\n", + "\n", + "\n", + "### Downloading from QCArchive\n", + "First, we will create training data. We'll download a smaller training set for the purposes of this example." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "964318ce-9e08-4551-9af4-dc49ea89d027", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING: This client version is newer than the server version. This may work if the versions are close, but expect exceptions and errors if attempting things the server does not support. client version: 0.56, server version: 0.55\n", + "WARNING: This client version is newer than the server version. This may work if the versions are close, but expect exceptions and errors if attempting things the server does not support. client version: 0.56, server version: 0.55\n" + ] + } + ], + "source": [ + "qc_client = PortalClient(\"https://api.qcarchive.molssi.org:443\", cache_dir=\".\")\n", + "\n", + "# download dataset from QCPortal\n", + "br_esps_collection = BasicResultCollection.from_server(\n", + " client=qc_client,\n", + " datasets=\"OpenFF multi-Br ESP Fragment Conformers v1.1\",\n", + " spec_name=\"HF/6-31G*\",\n", + ")\n", + "\n", + "records_and_molecules = br_esps_collection.to_records()" + ] + }, + { + "cell_type": "markdown", + "id": "54235869-1a60-4945-967d-602ef1f49185", + "metadata": {}, + "source": [ + "### Converting to MoleculeESPRecords\n", + "\n", + "Now we convert to OpenFF Recharge records and compute the ESPs and electric fields of each molecule." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "25f9db71-4738-4431-9e65-7c0b015b40b4", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [04:36<00:00, 5.52s/it]\n" + ] + } + ], + "source": [ + "# Create OpenFF Recharge MoleculeESPRecords\n", + "grid_settings = MSKGridSettings()\n", + "\n", + "# this can take a while; set records_and_molecules[:10]\n", + "# to only use the first 10\n", + "molecule_esp_records = [\n", + " from_qcportal_results(\n", + " qc_result=qcrecord,\n", + " qc_molecule=qcrecord.molecule,\n", + " qc_keyword_set=qcrecord.specification.keywords,\n", + " grid_settings=grid_settings,\n", + " compute_field=True\n", + " )\n", + " for qcrecord, _ in tqdm.tqdm(records_and_molecules[:50])\n", + "]" + ] + }, + { + "cell_type": "markdown", + "id": "9dc86d0a-c59e-4f8f-93c7-3b74adb6636e", + "metadata": {}, + "source": [ + "### Convert to PyArrow dataset\n", + "\n", + "NAGL reads in and trains to data from [PyArrow tables](https://arrow.apache.org/docs/python/getstarted.html#creating-arrays-and-tables). Below we do some conversion of each electric field to fit a basic GeneralLinearFit target, which fits the equation Ax = b. To avoid carrying too many data points around, we can do some postprocessing by first flattening the electric field matrix from 3 dimensions to 2, then multiplying by the transpose of A.\n", + "\n", + "$$\\mathbf{A}\\vec{x} = \\vec{b}$$\n", + "$$\\mathbf{A^{T}}\\mathbf{A}\\vec{x} = \\mathbf{A^{T}}\\vec{b}$$" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "70634680-8249-49f6-ad24-bc5178851360", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|███████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:00<00:00, 447.17it/s]\n" + ] + }, + { + "data": { + "text/plain": [ + "pyarrow.Table\n", + "mapped_smiles: string\n", + "precursor_matrix: list\n", + " child 0, item: double\n", + "prediction_vector: list\n", + " child 0, item: double\n", + "----\n", + "mapped_smiles: [[\"[H:1][C:2](=[O:3])[C:4]1=[C:9]([C:7](=[N:6][S:5]1)[Br:8])[Br:10]\",\"[H:1][C:2](=[O:3])[C:4]1=[C:5]([N:7]=[C:8]([N:10]1[H:11])[Br:9])[Br:6]\",\"[H:1][C:2](=[O:3])[C:4]1=[C:10]([N:8]([C:6](=[N:5]1)[Br:7])[H:9])[Br:11]\",\"[H:1][C:2](=[O:3])[C:4]1=[C:9]([S:8][N:7]=[C:5]1[Br:6])[Br:10]\",\"[H:1][C:2](=[O:3])[C:4]1=[N:5][S:6][C:7](=[C:9]1[Br:10])[Br:8]\",...,\"[H:4][c:3]1[c:2]([c:20]([c:18]([c:16]([c:5]1[C:6]([H:7])([C:8]([H:9])([H:10])[H:11])[N+:12]([H:13])([H:14])[H:15])[Br:17])[H:19])[Br:21])[H:1]\",\"[H:1][c:2]1[c:19]([c:9]([c:7]([c:5]([c:3]1[H:4])[Br:6])[Br:8])[C:10]([H:11])([C:15]([H:16])([H:17])[H:18])[N:12]([H:13])[H:14])[H:20]\",\"[H:1][c:2]1[c:20]([c:9]([c:7]([c:5]([c:3]1[H:4])[Br:6])[Br:8])[C:10]([H:11])([C:12]([H:13])([H:14])[H:15])[N+:16]([H:17])([H:18])[H:19])[H:21]\",\"[H:4][c:3]1[c:2]([c:22]([c:20]([c:18]([c:5]1[C:6]([H:7])([H:8])[N:9]([H:10])[C:11]([H:12])([H:13])[C:14]([H:15])([H:16])[H:17])[H:19])[Br:21])[Br:23])[H:1]\",\"[H:4][c:3]1[c:2]([c:22]([c:20]([c:18]([c:5]1[C:6]([H:7])([H:8])[N:9]([H:10])[C:11]([H:12])([H:13])[C:14]([H:15])([H:16])[H:17])[H:19])[Br:21])[Br:23])[H:1]\"]]\n", + "precursor_matrix: 
[[[0.5268278695916089,0.3663918376587872,0.30209382642805216,0.24952373167995737,0.20284281127461076,...,0.12682386926954425,0.17299639848680223,0.14835260748248824,0.22417394022165082,0.31729613836761544],[0.5459961499174945,0.3811772132544467,0.310495392830858,0.26421186900748445,0.16815883207694643,...,0.1959073125648948,0.2712204910932398,0.2018013854288873,0.3843880654806236,0.5473435022093819],...,[0.5188600228738232,0.36372757615732326,0.27049010438839094,0.2625940436429743,0.1744397734479557,...,0.10263580193044784,0.18779157315897135,0.16465862132957662,0.24504157435453747,0.34126749604041895],[0.523471022948499,0.3711542616380526,0.2768487697745459,0.26502768742612426,0.17897523760699538,...,0.10483875445263033,0.18971891584707673,0.16387162999092086,0.24867226425516173,0.3443168978222578]]]\n", + "prediction_vector: [[[0.06806112259081701,0.030716434587978025,-0.016729472006731878,0.029951025414986638,0.029496272551505647,-0.0024857716639207815,0.005087538225819736,-0.004348276082205866,0.013868185380437326,-0.0003752919418263855],[0.06562099527752072,0.022540831241458854,-0.03627159213493826,0.030697369204467136,-0.003824835345860035,-0.022952256996058556,-0.013973378103473662,0.024411442114600782,0.020922857209616338,0.06615881080418917,0.12160304322089216],...,[0.022271980401278424,0.0056816366203780245,0.016862184483462814,0.036808512643687834,0.013646974537450432,...,0.0024587377823627682,-0.015963611907693187,-0.028345562608464187,-0.014267070720150517,-0.0268968083285119],[0.020751052712825666,0.0033002036503013057,0.009596814898405207,0.01807019115329054,0.01444660212469153,...,0.02400538665376119,-0.013272983985162003,-0.026300857709583143,-0.01412185316045618,-0.026633251145107114]]]" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pyarrow_entries = []\n", + "for molecule_esp_record in tqdm.tqdm(molecule_esp_records):\n", + " electric_field = molecule_esp_record.electric_field # in atomic units\n", + " grid = molecule_esp_record.grid_coordinates * unit.angstrom\n", + " conformer = molecule_esp_record.conformer * unit.angstrom\n", + "\n", + " vector_field = compute_vector_field(\n", + " conformer.m_as(unit.bohr), # shape: M x 3\n", + " grid.m_as(unit.bohr), # shape: N x 3\n", + " ) # N x 3 x M\n", + "\n", + " # postprocess so we're not carrying around millions of floats\n", + " # firstly flatten out N x 3 -> 3N x M\n", + " vector_field_2d = np.concatenate(vector_field, axis=0)\n", + " electric_field_1d = np.concatenate(electric_field, axis=0)\n", + "\n", + " # now multiply by vector_field_2d's transpose\n", + " new_precursor_matrix = vector_field_2d.T @ vector_field_2d\n", + " new_field_vector = vector_field_2d.T @ electric_field_1d\n", + "\n", + " n_atoms = conformer.shape[0]\n", + " assert new_precursor_matrix.shape == (n_atoms, n_atoms)\n", + " assert new_field_vector.shape == (n_atoms,)\n", + "\n", + " # create entry. 
These columns are essential\n", + " # mapped_smiles in particular is required for every target\n", + " entry = {\n", + " \"mapped_smiles\": molecule_esp_record.tagged_smiles,\n", + " \"precursor_matrix\": new_precursor_matrix.flatten().tolist(),\n", + " \"prediction_vector\": new_field_vector.tolist()\n", + " }\n", + " pyarrow_entries.append(entry)\n", + "\n", + "\n", + "# arbitrarily split into training and validation datasets\n", + "training_pyarrow_entries = pyarrow_entries[:-10]\n", + "validation_pyarrow_entries = pyarrow_entries[-10:]\n", + "\n", + "training_table = pa.Table.from_pylist(training_pyarrow_entries)\n", + "validation_table = pa.Table.from_pylist(validation_pyarrow_entries)\n", + "training_table" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "72a9cdce-5411-4fe3-b2ea-46fb63e09d2c", + "metadata": {}, + "outputs": [], + "source": [ + "pq.write_table(training_table, \"training_dataset_table.parquet\")\n", + "pq.write_table(validation_table, \"validation_dataset_table.parquet\")\n", + "\n", + "# # to read back in -- note, the .parquet files provided with this example contain the full dataset, not the 50-record subset generated above\n", + "# training_table = pq.read_table(\"training_dataset_table.parquet\")\n", + "# validation_table = pq.read_table(\"validation_dataset_table.parquet\")" + ] + }, + { + "cell_type": "markdown", + "id": "4682e8cb-fba8-4cec-97b0-4d996e24ecf9", + "metadata": {}, + "source": [ + "## Set up for training a GNN" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "3d3c909f-c1b9-44e0-86e0-d4ccf6ceb996", + "metadata": {}, + "outputs": [], + "source": [ + "from openff.nagl.config import (\n", + " TrainingConfig,\n", + " OptimizerConfig,\n", + " ModelConfig,\n", + " DataConfig\n", + ")\n", + "from openff.nagl.config.model import (\n", + " ConvolutionModule, ReadoutModule,\n", + " ConvolutionLayer, ForwardLayer,\n", + ")\n", + "from openff.nagl.config.data import DatasetConfig\n", + "from openff.nagl.training.training import TrainingGNNModel\n", + "from openff.nagl.features.atoms import (\n", + " AtomicElement,\n", + " AtomConnectivity,\n", + " AtomInRingOfSize,\n", + " AtomAverageFormalCharge,\n", + ")\n", + "\n", + "from openff.nagl.training.loss import GeneralLinearFitTarget" + ] + }, + { + "cell_type": "markdown", + "id": "88be032d-0101-479e-a468-2c5db34600fc", + "metadata": {}, + "source": [ + "### Defining the training config\n", + "\n", + "#### Defining a ModelConfig\n", + "\n", + "First, we define a ModelConfig.\n", + "This can be done in Python, but in practice it is probably easier to define the model in a YAML file and load it with `ModelConfig.from_yaml`."
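As a sketch of the YAML route mentioned above: the top-level keys mirror the `ModelConfig` fields used in the next cell, but the exact spelling of the per-feature keys below is an assumption, so treat this as illustrative rather than the authoritative schema (which lives in the NAGL docs).

```python
# Hypothetical model.yaml mirroring the ModelConfig built in the next cell;
# the per-feature "name" keys are assumptions, not a verified schema.
from pathlib import Path
from openff.nagl.config import ModelConfig

Path("model.yaml").write_text("""\
version: "0.1"
atom_features:
  - name: atomic_element
    categories: ["H", "C", "N", "O", "F", "Br", "S", "P", "I"]
  - name: atom_connectivity
    categories: [1, 2, 3, 4, 5, 6]
convolution:
  architecture: SAGEConv
  layers:
    - hidden_feature_size: 512
      activation_function: ReLU
      aggregator_type: mean
readouts:
  charges:
    pooling: atoms
    postprocess: compute_partial_charges
    layers:
      - hidden_feature_size: 512
        activation_function: ReLU
""")

model_config_from_yaml = ModelConfig.from_yaml("model.yaml")
```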
+ ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "f7ddfbde-1349-42ec-a2aa-0f9ead097e99", + "metadata": {}, + "outputs": [], + "source": [ + "atom_features = [\n", + " AtomicElement(categories=[\"H\", \"C\", \"N\", \"O\", \"F\", \"Br\", \"S\", \"P\", \"I\"]),\n", + " AtomConnectivity(categories=[1, 2, 3, 4, 5, 6]),\n", + " AtomInRingOfSize(ring_size=3),\n", + " AtomInRingOfSize(ring_size=4),\n", + " AtomInRingOfSize(ring_size=5),\n", + " AtomInRingOfSize(ring_size=6),\n", + " AtomAverageFormalCharge(),\n", + "]\n", + "\n", + "# define our convolution module\n", + "convolution_module = ConvolutionModule(\n", + " architecture=\"SAGEConv\",\n", + " # construct 6 layers with dropout 0 (default),\n", + " # hidden feature size 512, and ReLU activation function\n", + " # these layers can also be individually specified,\n", + " # but we just duplicate the layer 6 times for identical layers\n", + " layers=[\n", + " ConvolutionLayer(\n", + " hidden_feature_size=512,\n", + " activation_function=\"ReLU\",\n", + " aggregator_type=\"mean\"\n", + " )\n", + " ] * 6,\n", + ")\n", + "\n", + "# define our readout module/s\n", + "# multiple are allowed but let's focus on charges\n", + "readout_modules = {\n", + " # key is the name of output property, any naming is allowed\n", + " \"charges\": ReadoutModule(\n", + " pooling=\"atoms\",\n", + " postprocess=\"compute_partial_charges\",\n", + " # 2 layers\n", + " layers=[\n", + " ForwardLayer(\n", + " hidden_feature_size=512,\n", + " activation_function=\"ReLU\",\n", + " )\n", + " ] * 2,\n", + " )\n", + "}\n", + "\n", + "# bring it all together\n", + "model_config = ModelConfig(\n", + " version=\"0.1\",\n", + " atom_features=atom_features,\n", + " convolution=convolution_module,\n", + " readouts=readout_modules,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "a68f7b57-a1f7-43f4-ba2d-9579dd9562db", + "metadata": {}, + "source": [ + "#### Defining a DataConfig\n", + "\n", + "We can then define our dataset configs. Here we also have to specify our training targets." + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "bf524fad-42e5-4020-a9d9-90a56343b7a1", + "metadata": {}, + "outputs": [], + "source": [ + "target = GeneralLinearFitTarget(\n", + " # what we're using to evaluate loss\n", + " target_label=\"prediction_vector\",\n", + " # the output of the GNN we use to evaluate loss\n", + " prediction_label=\"charges\",\n", + " # the column in the table that contains the precursor matrix\n", + " design_matrix_column=\"precursor_matrix\",\n", + " # how we want to evaluate loss, e.g. 
RMSE, MSE, ...\n", + " metric=\"rmse\",\n", + " # how much to weight this target\n", + " # helps with scaling in multi-target optimizations\n", + " weight=1,\n", + " denominator=1,\n", + ")\n", + "\n", + "training_to_electric_field = DatasetConfig(\n", + " sources=[\"training_dataset_table.parquet\"],\n", + " targets=[target],\n", + " batch_size=100,\n", + ")\n", + "validating_to_electric_field = DatasetConfig(\n", + " sources=[\"validation_dataset_table.parquet\"],\n", + " targets=[target],\n", + " batch_size=100,\n", + ")\n", + "\n", + "# bringing it together\n", + "data_config = DataConfig(\n", + " training=training_to_electric_field,\n", + " validation=validating_to_electric_field\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "cfd8bba9-241c-4c67-8f88-a89ae608ec89", + "metadata": {}, + "source": [ + "#### Defining an OptimizerConfig\n", + "\n", + "The optimizer config is relatively simple; the only moving part here currently is the learning rate." + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "1759ed0f-2e89-4d22-80ae-2c0ba04f37a0", + "metadata": {}, + "outputs": [], + "source": [ + "optimizer_config = OptimizerConfig(optimizer=\"Adam\", learning_rate=0.001)" + ] + }, + { + "cell_type": "markdown", + "id": "7c784871-e834-470e-a346-bde893326fee", + "metadata": {}, + "source": [ + "#### Creating a TrainingConfig" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "f5958818-03ef-4580-9cbf-c588cbb3141f", + "metadata": {}, + "outputs": [], + "source": [ + "training_config = TrainingConfig(\n", + " model=model_config,\n", + " data=data_config,\n", + " optimizer=optimizer_config\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "17d61258-fe16-4993-9c19-57118c9d0a1a", + "metadata": {}, + "source": [ + "### Creating a TrainingGNNModel\n", + "\n", + "Now we can create a `TrainingGNNModel`, which allows easy training of a `GNNModel`. The `GNNModel` can be accessed through `TrainingGNNModel.model`." 
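As an aside, it is worth convincing ourselves that the `precursor_matrix`/`prediction_vector` reduction from earlier loses nothing: for a tall vector field matrix A (3N x M) and field vector b (3N), the least-squares solution of Ax = b is exactly the solution of the small M x M system (A^T A)x = A^T b, which is the pair this target consumes. A quick check with random stand-in data:

```python
import numpy as np

rng = np.random.default_rng(42)
A = rng.normal(size=(300, 10))  # stand-in vector field, 3N x M
b = rng.normal(size=300)        # stand-in electric field, 3N

x_full = np.linalg.lstsq(A, b, rcond=None)[0]  # fit Ax = b directly
x_reduced = np.linalg.solve(A.T @ A, A.T @ b)  # solve the reduced system
assert np.allclose(x_full, x_reduced)
```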
+ ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "3eefcea4-2e9a-4e4a-b3fc-939f3f733fb8", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "TrainingGNNModel(\n", + " (model): GNNModel(\n", + " (convolution_module): ConvolutionModule(\n", + " (gcn_layers): SAGEConvStack(\n", + " (0): SAGEConv(\n", + " (feat_drop): Dropout(p=0.0, inplace=False)\n", + " (activation): ReLU()\n", + " (fc_neigh): Linear(in_features=20, out_features=512, bias=False)\n", + " (fc_self): Linear(in_features=20, out_features=512, bias=True)\n", + " )\n", + " (1-5): 5 x SAGEConv(\n", + " (feat_drop): Dropout(p=0.0, inplace=False)\n", + " (activation): ReLU()\n", + " (fc_neigh): Linear(in_features=512, out_features=512, bias=False)\n", + " (fc_self): Linear(in_features=512, out_features=512, bias=True)\n", + " )\n", + " )\n", + " )\n", + " (readout_modules): ModuleDict(\n", + " (charges): ReadoutModule(\n", + " (pooling_layer): PoolAtomFeatures()\n", + " (readout_layers): SequentialLayers(\n", + " (0): Linear(in_features=512, out_features=512, bias=True)\n", + " (1): ReLU()\n", + " (2): Dropout(p=0.0, inplace=False)\n", + " (3): Linear(in_features=512, out_features=512, bias=True)\n", + " (4): ReLU()\n", + " (5): Dropout(p=0.0, inplace=False)\n", + " (6): Linear(in_features=512, out_features=2, bias=True)\n", + " (7): Identity()\n", + " (8): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (postprocess_layer): ComputePartialCharges()\n", + " )\n", + " )\n", + " )\n", + ")" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "training_model = TrainingGNNModel(training_config)\n", + "training_model" + ] + }, + { + "cell_type": "markdown", + "id": "d696dae6-93ce-4f0a-ab0f-ebb7e72fa176", + "metadata": {}, + "source": [ + "We can look at the initial capabilities of the model by comparing its charges to AM1-BCC charges. They're pretty bad." + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "adb99158-fa97-412d-972b-89f6f563b84a", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([-0.0422502 , -0.05428261, 0.15142802, -0.42682151, 0.09391647,\n", + " 0.09391647, 0.09391647, 0.2050635 , 0.2050635 , -0.15997499,\n", + " -0.15997499])" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "test_molecule = Molecule.from_smiles(\"CCCBr\")\n", + "test_molecule.assign_partial_charges(\"am1bcc\")\n", + "reference_charges = test_molecule.partial_charges.m\n", + "\n", + "# switch to eval mode\n", + "training_model.model.eval()\n", + "\n", + "with torch.no_grad():\n", + " nagl_charges_1 = training_model.model.compute_properties(\n", + " test_molecule,\n", + " as_numpy=True\n", + " )[\"charges\"]\n", + "\n", + "# switch back to training mode\n", + "training_model.model.train()\n", + "\n", + "# compare charges\n", + "differences = reference_charges - nagl_charges_1\n", + "differences" + ] + }, + { + "cell_type": "markdown", + "id": "79cb1d56-3d70-4307-93f0-0d5956fdb32e", + "metadata": {}, + "source": [ + "### Training the model\n", + "\n", + "We use Pytorch Lightning to train." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "5bfcbce1-8323-4403-a607-89292681375b", + "metadata": {}, + "outputs": [], + "source": [ + "import pytorch_lightning as pl\n", + "from pytorch_lightning.callbacks import TQDMProgressBar" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "4e0174d7-a57b-49d9-89a3-5c756adc0d8c", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "GPU available: True (mps), used: False\n", + "TPU available: False, using: 0 TPU cores\n", + "HPU available: False, using: 0 HPUs\n", + "/Users/lily/micromamba/envs/train-gnn-to-electric-field/lib/python3.12/site-packages/pytorch_lightning/trainer/setup.py:177: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.\n", + "Featurizing dataset: 0it [00:00, ?it/s]\n", + "Featurizing batch: 0%| | 0/40 [00:00 typing.List[str]: return [self.target_label] @@ -231,8 +239,6 @@ def evaluate_target( readout_modules: typing.Dict[str, ReadoutModule], ) -> "torch.Tensor": return predictions[self.prediction_label] - - @@ -243,7 +249,11 @@ class HeavyAtomReadoutTarget(_BaseTarget): # name: typing.ClassVar[str] = "heavy_atom_readout" name: typing.Literal["heavy_atom_readout"] = "heavy_atom_readout" - prediction_label: str + prediction_label: str = Field( + description=( + "The predicted property to evaluate the target on." + ) + ) def get_required_columns(self) -> typing.List[str]: return [self.target_label] @@ -258,7 +268,50 @@ def evaluate_target( atomic_numbers = molecules.graph.ndata["atomic_number"] heavy_atom_mask = atomic_numbers != 1 return predictions[self.prediction_label].squeeze()[heavy_atom_mask] + + +class GeneralLinearFitTarget(_BaseTarget): + """A target that is evaluated to solve the general Ax=b equation.""" + + name: typing.Literal["general_linear_fit"] = "general_linear_fit" + prediction_label: str = Field( + description=( + "The predicted property to evaluate the target on." + ) + ) + design_matrix_column: str = Field( + description=( + "The column in the labels that contains the design matrix." 
+ ) + ) + + def get_required_columns(self) -> typing.List[str]: + return [self.target_label, self.design_matrix_column] + + def evaluate_target( + self, + molecules: "DGLMoleculeOrBatch", + labels: typing.Dict[str, "torch.Tensor"], + predictions: typing.Dict[str, "torch.Tensor"], + readout_modules: typing.Dict[str, ReadoutModule], + ) -> "torch.Tensor": + x_vectors = predictions[self.prediction_label].squeeze().float() + A_matrices = labels[self.design_matrix_column].float() + all_n_atoms = tuple(map(int, molecules.n_atoms_per_molecule)) + + result_vectors = [] + # the batch concatenates each molecule's flattened n_atoms x n_atoms + # design matrix and its n_atoms predictions; split them up by molecule + for n_atoms in all_n_atoms: + x_vector = x_vectors[:n_atoms] + A_matrix = A_matrices[:n_atoms * n_atoms].reshape(n_atoms, n_atoms) + result = torch.matmul(A_matrix, x_vector) + result_vectors.append(result) + + # advance to the next molecule's block + x_vectors = x_vectors[n_atoms:] + A_matrices = A_matrices[n_atoms * n_atoms:] + + return torch.cat(result_vectors) + class SingleDipoleTarget(_BaseTarget): """A target that is evaluated on the dipole of a molecule.""" @@ -423,4 +476,5 @@ def evaluate_target( HeavyAtomReadoutTarget, SingleDipoleTarget, MultipleESPTarget, + GeneralLinearFitTarget ] \ No newline at end of file diff --git a/openff/nagl/training/training.py b/openff/nagl/training/training.py index 8806658..001dcea 100644 --- a/openff/nagl/training/training.py +++ b/openff/nagl/training/training.py @@ -113,6 +113,9 @@ def _torch_optimizer(self): optimizer = self.optimizers() return optimizer.optimizer + def create_data_module(self, n_processes: int = 0, verbose: bool = True): + return DGLMoleculeDataModule(self.config, n_processes=n_processes, verbose=verbose) + class DGLMoleculeDataModule(pl.LightningDataModule): def __init__(
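A usage note on `evaluate_target` above: the data loader hands it the per-molecule design matrices flattened and concatenated across the batch, so the loop peels off `n_atoms * n_atoms` entries at a time. A toy reproduction of that slicing with two hypothetical molecules of 2 and 3 atoms:

```python
import torch

def apply_design_matrices(x_vectors, A_matrices, all_n_atoms):
    # mirrors the per-molecule split in GeneralLinearFitTarget.evaluate_target
    result_vectors = []
    for n_atoms in all_n_atoms:
        A = A_matrices[: n_atoms * n_atoms].reshape(n_atoms, n_atoms)
        result_vectors.append(A @ x_vectors[:n_atoms])
        # advance to the next molecule's block
        x_vectors = x_vectors[n_atoms:]
        A_matrices = A_matrices[n_atoms * n_atoms :]
    return torch.cat(result_vectors)

charges = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0])  # molecule A, then molecule B
flat = torch.cat([torch.eye(2).flatten(), (2 * torch.eye(3)).flatten()])
print(apply_design_matrices(charges, flat, (2, 3)))
# tensor([ 1.,  2.,  6.,  8., 10.])
```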