diff --git a/README.md b/README.md index 3d297bc9..05c37205 100644 --- a/README.md +++ b/README.md @@ -15,7 +15,7 @@ This is the official codebase for **scGPT: Towards Building a Foundation Model f **[2023.12.31]** New tutorials about zero-shot applications are now available! Please see find them in the [tutorials/zero-shot](tutorials/zero-shot) directory. We also provide a new continual pretrained model checkpoint for cell embedding related tasks. Please see the [notebook](tutorials/zero-shot/Tutorial_ZeroShot_Integration_Continual_Pretraining.ipynb) for more details. -**[2023.11.07]** As requested by many, now we have made flash-attention an optional dependency. The pretrained weights can be loaded on pytorch CPU, GPU, and flash-attn backends using the same [load_pretrained](https://github.com/bowang-lab/scGPT/blob/f6097112fe5175cd4e221890ed2e2b1815f54010/scgpt/utils/util.py#L304) function, `load_pretrained(target_model, torch.load("path_to_ckpt.pt"))`. An example usage is also [here](https://github.com/bowang-lab/scGPT/blob/f6097112fe5175cd4e221890ed2e2b1815f54010/scgpt/tasks/cell_emb.py#L258). +**[2023.11.07]** As requested by many, now we have made flash-attention an optional dependency. The pretrained weights can be loaded on pytorch CPU, GPU, and flash-attn backends using the same [load_pretrained](https://github.com/bowang-lab/scGPT/blob/f6097112fe5175cd4e221890ed2e2b1815f54010/scgpt/utils/util.py#L304) function, `load_pretrained(target_model, torch.load("path_to_ckpt.pt", map_location=device))`. An example usage is also [here](https://github.com/bowang-lab/scGPT/blob/f6097112fe5175cd4e221890ed2e2b1815f54010/scgpt/tasks/cell_emb.py#L258). **[2023.09.05]** We have release a new feature for reference mapping samples to a custom reference dataset or to all the millions of cells collected from CellXGene! With the help of the [faiss](https://github.com/facebookresearch/faiss) library, we achieved a great time and memory efficiency. The index of over 33 millions cells only takes less than 1GB of memory and the similarity search takes less than **1 second for 10,000 query cells on GPU**. Please see the [Reference mapping tutorial](https://github.com/bowang-lab/scGPT/blob/main/tutorials/Tutorial_Reference_Mapping.ipynb) for more details. diff --git a/examples/finetune_integration.py b/examples/finetune_integration.py index badf125a..6b7b5dae 100644 --- a/examples/finetune_integration.py +++ b/examples/finetune_integration.py @@ -405,12 +405,12 @@ def prepare_dataloader( ) if config.load_model is not None: try: - model.load_state_dict(torch.load(model_file)) + model.load_state_dict(torch.load(model_file, map_location=device)) logger.info(f"Loading all model params from {model_file}") except: # only load params that are in the model and match the size model_dict = model.state_dict() - pretrained_dict = torch.load(model_file) + pretrained_dict = torch.load(model_file, map_location=device) pretrained_dict = { k: v for k, v in pretrained_dict.items() diff --git a/tutorials/Tutorial_Annotation.ipynb b/tutorials/Tutorial_Annotation.ipynb index 13d039ed..742098ad 100644 --- a/tutorials/Tutorial_Annotation.ipynb +++ b/tutorials/Tutorial_Annotation.ipynb @@ -1271,12 +1271,12 @@ ")\n", "if config.load_model is not None:\n", " try:\n", - " model.load_state_dict(torch.load(model_file))\n", + " model.load_state_dict(torch.load(model_file, map_location=device))\n", " logger.info(f\"Loading all model params from {model_file}\")\n", " except:\n", " # only load params that are in the model and match the size\n", " model_dict = model.state_dict()\n", - " pretrained_dict = torch.load(model_file)\n", + " pretrained_dict = torch.load(model_file, map_location=device)\n", " pretrained_dict = {\n", " k: v\n", " for k, v in pretrained_dict.items()\n", diff --git a/tutorials/Tutorial_Attention_GRN.ipynb b/tutorials/Tutorial_Attention_GRN.ipynb index 76c9cfc0..888511ad 100644 --- a/tutorials/Tutorial_Attention_GRN.ipynb +++ b/tutorials/Tutorial_Attention_GRN.ipynb @@ -598,12 +598,12 @@ ")\n", "\n", "try:\n", - " model.load_state_dict(torch.load(model_file))\n", + " model.load_state_dict(torch.load(model_file, map_location=device))\n", " print(f\"Loading all model params from {model_file}\")\n", "except:\n", " # only load params that are in the model and match the size\n", " model_dict = model.state_dict()\n", - " pretrained_dict = torch.load(model_file)\n", + " pretrained_dict = torch.load(model_file, map_location=device)\n", " pretrained_dict = {\n", " k: v\n", " for k, v in pretrained_dict.items()\n", diff --git a/tutorials/Tutorial_GRN.ipynb b/tutorials/Tutorial_GRN.ipynb index 61b4c69e..98b5250a 100644 --- a/tutorials/Tutorial_GRN.ipynb +++ b/tutorials/Tutorial_GRN.ipynb @@ -521,12 +521,12 @@ ")\n", "\n", "try:\n", - " model.load_state_dict(torch.load(model_file))\n", + " model.load_state_dict(torch.load(model_file, map_location=device))\n", " print(f\"Loading all model params from {model_file}\")\n", "except:\n", " # only load params that are in the model and match the size\n", " model_dict = model.state_dict()\n", - " pretrained_dict = torch.load(model_file)\n", + " pretrained_dict = torch.load(model_file, map_location=device)\n", " pretrained_dict = {\n", " k: v\n", " for k, v in pretrained_dict.items()\n", diff --git a/tutorials/Tutorial_Integration.ipynb b/tutorials/Tutorial_Integration.ipynb index 3fd9949e..aeb8980b 100644 --- a/tutorials/Tutorial_Integration.ipynb +++ b/tutorials/Tutorial_Integration.ipynb @@ -689,7 +689,7 @@ " pre_norm=config.pre_norm,\n", ")\n", "if config.load_model is not None:\n", - " load_pretrained(model, torch.load(model_file), verbose=False)\n", + " load_pretrained(model, torch.load(model_file, map_location=device), verbose=False)\n", "\n", "model.to(device)\n", "wandb.watch(model)" diff --git a/tutorials/Tutorial_Multiomics.ipynb b/tutorials/Tutorial_Multiomics.ipynb index d4d09bb1..9897e7ba 100644 --- a/tutorials/Tutorial_Multiomics.ipynb +++ b/tutorials/Tutorial_Multiomics.ipynb @@ -869,7 +869,7 @@ ], "source": [ "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", - "model_dict = torch.load(model_file)\n", + "model_dict = torch.load(model_file, map_location=device)\n", "ntokens = len(vocab) # size of vocabulary\n", "model = MultiOmicTransformerModel(\n", " ntokens,\n", diff --git a/tutorials/Tutorial_Perturbation.ipynb b/tutorials/Tutorial_Perturbation.ipynb index 038af382..e1cb30a3 100644 --- a/tutorials/Tutorial_Perturbation.ipynb +++ b/tutorials/Tutorial_Perturbation.ipynb @@ -659,7 +659,7 @@ "if load_param_prefixs is not None and load_model is not None:\n", " # only load params that start with the prefix\n", " model_dict = model.state_dict()\n", - " pretrained_dict = torch.load(model_file)\n", + " pretrained_dict = torch.load(model_file, map_location=device)\n", " pretrained_dict = {\n", " k: v\n", " for k, v in pretrained_dict.items()\n", @@ -671,12 +671,12 @@ " model.load_state_dict(model_dict)\n", "elif load_model is not None:\n", " try:\n", - " model.load_state_dict(torch.load(model_file))\n", + " model.load_state_dict(torch.load(model_file, map_location=device))\n", " logger.info(f\"Loading all model params from {model_file}\")\n", " except:\n", " # only load params that are in the model and match the size\n", " model_dict = model.state_dict()\n", - " pretrained_dict = torch.load(model_file)\n", + " pretrained_dict = torch.load(model_file, map_location=device)\n", " pretrained_dict = {\n", " k: v\n", " for k, v in pretrained_dict.items()\n",