Non-Linear NCEM Example
%load_ext autoreload
%autoreload 2
import pytorch_lightning as pl
import torch
import anndata as ad
from geome import transforms
from geome.adata2data import AnnData2DataByCategory
from utils.datasets import DatasetHartmann
from utils.models.non_linear_ncem import NonLinearNCEM
from geome.datamodule import GraphAnnDataModule
# Field specification handed to the AnnData -> Data converter: each key is an
# attribute name on the produced `Data` objects, each value lists the AnnData
# locations it is read from.
# NOTE(review): the recorded output shows `x` with 12 columns, which would match
# a one-hot encoding of 8 cell types + 4 donors -- confirm against geome.
fields = dict(
    x=['obs/Cluster_preprocessed', 'obs/donor'],
    edge_index=['uns/edge_index'],
    y=['X'],
)
from geome.transforms import Categorize, AddDesignMatrix, Compose, AddAdjMatrix, AddEdgeIndex
# Column in `obs` used to split the AnnData object into separate graphs.
category_to_iterate = 'point'

# AnnData location shared by the two transforms below: AddAdjMatrix is given it
# as `location`, AddEdgeIndex as `adj_matrix_loc`.
adj_matrix_loc = 'obsp/adjacency_matrix_connectivities'

# Passed as the converter's `preprocess` step.
# NOTE(review): presumably casts the listed `obs` columns to categorical dtype --
# confirm against geome.transforms.Categorize.
preprocess = Categorize(
    ['donor', 'Cluster_preprocessed', 'point'],
    axis='obs',
)

# Passed as the converter's `transform` step: adjacency matrix first, then the
# edge index derived from it (stored under the key 'edge_index').
transform = Compose(
    [
        AddAdjMatrix(location=adj_matrix_loc),
        AddEdgeIndex(
            adj_matrix_loc=adj_matrix_loc,
            edge_index_key='edge_index',
        ),
    ]
)

# Converter that turns one AnnData object into a list of `Data` objects,
# configured with the field mapping and the steps defined above.
a2d = AnnData2DataByCategory(
    fields=fields,
    category=category_to_iterate,
    preprocess=preprocess,
    transform=transform,
)
# MIBI-TOF example data (Hartmann dataset), read from the local example folder.
dataset = DatasetHartmann(data_path='./example_data/hartmann/')

# One AnnData object per image.
adatas = [*dataset.img_celldata.values()]

# Merge the per-image AnnData objects into a single one.
# NOTE(review): the string -> category conversion mentioned in the original note
# is presumably done by the converter's `preprocess` step -- confirm.
adata = ad.concat(adatas)

# Convert to a list of graph `Data` objects and display it.
datas = a2d(adata)
datas
Loading data from raw files
registering celldata
collecting image-wise celldata
adding graph-level covariates
Loaded 58 images with complete data from 4 patients over 63747 cells with 36 cell features and 8 distinct celltypes.
[Data(x=[1338, 12], edge_index=[2, 8028], y=[1338, 36]),
Data(x=[311, 12], edge_index=[2, 1866], y=[311, 36]),
Data(x=[768, 12], edge_index=[2, 4608], y=[768, 36]),
Data(x=[1020, 12], edge_index=[2, 6120], y=[1020, 36]),
Data(x=[2100, 12], edge_index=[2, 12600], y=[2100, 36]),
Data(x=[1325, 12], edge_index=[2, 7950], y=[1325, 36]),
Data(x=[1091, 12], edge_index=[2, 6546], y=[1091, 36]),
Data(x=[1046, 12], edge_index=[2, 6276], y=[1046, 36]),
Data(x=[618, 12], edge_index=[2, 3708], y=[618, 36]),
Data(x=[61, 12], edge_index=[2, 366], y=[61, 36]),
Data(x=[1316, 12], edge_index=[2, 7896], y=[1316, 36]),
Data(x=[1540, 12], edge_index=[2, 9240], y=[1540, 36]),
Data(x=[1822, 12], edge_index=[2, 10932], y=[1822, 36]),
Data(x=[863, 12], edge_index=[2, 5178], y=[863, 36]),
Data(x=[564, 12], edge_index=[2, 3384], y=[564, 36]),
Data(x=[1023, 12], edge_index=[2, 6138], y=[1023, 36]),
Data(x=[324, 12], edge_index=[2, 1944], y=[324, 36]),
Data(x=[287, 12], edge_index=[2, 1722], y=[287, 36]),
Data(x=[636, 12], edge_index=[2, 3816], y=[636, 36]),
Data(x=[890, 12], edge_index=[2, 5340], y=[890, 36]),
Data(x=[1235, 12], edge_index=[2, 7410], y=[1235, 36]),
Data(x=[1020, 12], edge_index=[2, 6120], y=[1020, 36]),
Data(x=[1241, 12], edge_index=[2, 7446], y=[1241, 36]),
Data(x=[1438, 12], edge_index=[2, 8628], y=[1438, 36]),
Data(x=[1021, 12], edge_index=[2, 6126], y=[1021, 36]),
Data(x=[1632, 12], edge_index=[2, 9792], y=[1632, 36]),
Data(x=[780, 12], edge_index=[2, 4680], y=[780, 36]),
Data(x=[524, 12], edge_index=[2, 3144], y=[524, 36]),
Data(x=[669, 12], edge_index=[2, 4014], y=[669, 36]),
Data(x=[241, 12], edge_index=[2, 1446], y=[241, 36]),
Data(x=[935, 12], edge_index=[2, 5610], y=[935, 36]),
Data(x=[347, 12], edge_index=[2, 2082], y=[347, 36]),
Data(x=[1499, 12], edge_index=[2, 8994], y=[1499, 36]),
Data(x=[601, 12], edge_index=[2, 3606], y=[601, 36]),
Data(x=[2268, 12], edge_index=[2, 13608], y=[2268, 36]),
Data(x=[1912, 12], edge_index=[2, 11472], y=[1912, 36]),
Data(x=[1678, 12], edge_index=[2, 10068], y=[1678, 36]),
Data(x=[1025, 12], edge_index=[2, 6150], y=[1025, 36]),
Data(x=[1306, 12], edge_index=[2, 7836], y=[1306, 36]),
Data(x=[852, 12], edge_index=[2, 5112], y=[852, 36]),
Data(x=[1664, 12], edge_index=[2, 9984], y=[1664, 36]),
Data(x=[1698, 12], edge_index=[2, 10188], y=[1698, 36]),
Data(x=[1672, 12], edge_index=[2, 10032], y=[1672, 36]),
Data(x=[777, 12], edge_index=[2, 4662], y=[777, 36]),
Data(x=[556, 12], edge_index=[2, 3336], y=[556, 36]),
Data(x=[554, 12], edge_index=[2, 3324], y=[554, 36]),
Data(x=[937, 12], edge_index=[2, 5622], y=[937, 36]),
Data(x=[1524, 12], edge_index=[2, 9144], y=[1524, 36]),
Data(x=[1528, 12], edge_index=[2, 9168], y=[1528, 36]),
Data(x=[721, 12], edge_index=[2, 4326], y=[721, 36]),
Data(x=[1395, 12], edge_index=[2, 8370], y=[1395, 36]),
Data(x=[611, 12], edge_index=[2, 3666], y=[611, 36]),
Data(x=[1872, 12], edge_index=[2, 11232], y=[1872, 36]),
Data(x=[1217, 12], edge_index=[2, 7302], y=[1217, 36]),
Data(x=[1531, 12], edge_index=[2, 9186], y=[1531, 36]),
Data(x=[1927, 12], edge_index=[2, 11562], y=[1927, 36]),
Data(x=[690, 12], edge_index=[2, 4140], y=[690, 36]),
Data(x=[1706, 12], edge_index=[2, 10236], y=[1706, 36])]
# Model input/output widths, read from the column counts of `x` and `y` on the
# first graph.
num_features, out_channels = (
    datas[0].x.shape[1],
    datas[0].y.shape[1],
)
num_features, out_channels
(12, 36)
# Data module wrapping the converted graphs for node-level learning.
# NOTE: in the recorded run the resulting loaders went through PyG's
# NeighborSampler inside the DataLoader workers, which requires either
# 'pyg-lib' or 'torch-sparse' to be installed.
dm = GraphAnnDataModule(
    datas=datas,
    num_workers=12,
    batch_size=100,
    learning_type='node',
)
# Non-linear NCEM model sized to the data: input width = number of node
# features, output width = number of target columns; one hidden layer of
# width 16 on each of the encoder and decoder sides and a latent size of 14.
model = NonLinearNCEM(
    in_channels=num_features,
    out_channels=out_channels,
    encoder_hidden_dims=[16],
    decoder_hidden_dims=[16],
    latent_dim=14,
    lr=1e-3,
    weight_decay=1e-5,
)
# Lightning trainer: run on the GPU when CUDA is available, otherwise on the
# CPU; train for at most 100 epochs and log every 10 steps.
# The availability check uses the public `torch.cuda.is_available()` rather
# than going through the accidental `torch.torch` self-reference.
trainer: pl.Trainer = pl.Trainer(
    accelerator='gpu' if torch.cuda.is_available() else 'cpu',
    max_epochs=100,
    log_every_n_steps=10,
)
# Train the model on the data module.
# NOTE: in the recorded run this call failed during the sanity-check validation
# pass with "ImportError: 'NeighborSampler' requires either 'pyg-lib' or
# 'torch-sparse'" raised from a DataLoader worker.
trainer.fit(model,datamodule=dm)
---------------------------------------------------------------------------
ImportError Traceback (most recent call last)
Cell In[7], line 1
----> 1 trainer.fit(model,datamodule=dm)
File ~/mambaforge/envs/gnn/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:531, in Trainer.fit(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
529 model = _maybe_unwrap_optimized(model)
530 self.strategy._lightning_module = model
--> 531 call._call_and_handle_interrupt(
532 self, self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
533 )
File ~/mambaforge/envs/gnn/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py:42, in _call_and_handle_interrupt(trainer, trainer_fn, *args, **kwargs)
40 if trainer.strategy.launcher is not None:
41 return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
---> 42 return trainer_fn(*args, **kwargs)
44 except _TunerExitException:
45 _call_teardown_hook(trainer)
File ~/mambaforge/envs/gnn/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:570, in Trainer._fit_impl(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
560 self._data_connector.attach_data(
561 model, train_dataloaders=train_dataloaders, val_dataloaders=val_dataloaders, datamodule=datamodule
562 )
564 ckpt_path = self._checkpoint_connector._select_ckpt_path(
565 self.state.fn,
566 ckpt_path,
567 model_provided=True,
568 model_connected=self.lightning_module is not None,
569 )
--> 570 self._run(model, ckpt_path=ckpt_path)
572 assert self.state.stopped
573 self.training = False
File ~/mambaforge/envs/gnn/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:975, in Trainer._run(self, model, ckpt_path)
970 self._signal_connector.register_signal_handlers()
972 # ----------------------------
973 # RUN THE TRAINER
974 # ----------------------------
--> 975 results = self._run_stage()
977 # ----------------------------
978 # POST-Training CLEAN UP
979 # ----------------------------
980 log.debug(f"{self.__class__.__name__}: trainer tearing down")
File ~/mambaforge/envs/gnn/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:1016, in Trainer._run_stage(self)
1014 if self.training:
1015 with isolate_rng():
-> 1016 self._run_sanity_check()
1017 with torch.autograd.set_detect_anomaly(self._detect_anomaly):
1018 self.fit_loop.run()
File ~/mambaforge/envs/gnn/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:1045, in Trainer._run_sanity_check(self)
1042 call._call_callback_hooks(self, "on_sanity_check_start")
1044 # run eval step
-> 1045 val_loop.run()
1047 call._call_callback_hooks(self, "on_sanity_check_end")
1049 # reset logger connector
File ~/mambaforge/envs/gnn/lib/python3.10/site-packages/pytorch_lightning/loops/utilities.py:177, in _no_grad_context.<locals>._decorator(self, *args, **kwargs)
175 context_manager = torch.no_grad
176 with context_manager():
--> 177 return loop_run(self, *args, **kwargs)
File ~/mambaforge/envs/gnn/lib/python3.10/site-packages/pytorch_lightning/loops/evaluation_loop.py:108, in _EvaluationLoop.run(self)
106 while True:
107 try:
--> 108 batch, batch_idx, dataloader_idx = next(data_fetcher)
109 self.batch_progress.is_last_batch = data_fetcher.done
110 if previous_dataloader_idx != dataloader_idx:
111 # the dataloader has changed, notify the logger connector
File ~/mambaforge/envs/gnn/lib/python3.10/site-packages/pytorch_lightning/loops/fetchers.py:136, in _PrefetchDataFetcher.__next__(self)
133 elif not self.done:
134 # this will run only when no pre-fetching was done.
135 try:
--> 136 self._fetch_next_batch(self.dataloader_iter)
137 # consume the batch we just fetched
138 batch = self.batches.pop(0)
File ~/mambaforge/envs/gnn/lib/python3.10/site-packages/pytorch_lightning/loops/fetchers.py:150, in _PrefetchDataFetcher._fetch_next_batch(self, iterator)
148 self._start_profiler()
149 try:
--> 150 batch = next(iterator)
151 finally:
152 self._stop_profiler()
File ~/mambaforge/envs/gnn/lib/python3.10/site-packages/pytorch_lightning/utilities/combined_loader.py:284, in CombinedLoader.__next__(self)
282 def __next__(self) -> Any:
283 assert self._iterator is not None
--> 284 out = next(self._iterator)
285 if isinstance(self._iterator, _Sequential):
286 return out
File ~/mambaforge/envs/gnn/lib/python3.10/site-packages/pytorch_lightning/utilities/combined_loader.py:123, in _Sequential.__next__(self)
120 raise StopIteration
122 try:
--> 123 out = next(self.iterators[0])
124 index = self._idx
125 self._idx += 1
File ~/mambaforge/envs/gnn/lib/python3.10/site-packages/torch_geometric/loader/base.py:36, in DataLoaderIterator.__next__(self)
35 def __next__(self) -> Any:
---> 36 return self.transform_fn(next(self.iterator))
File ~/mambaforge/envs/gnn/lib/python3.10/site-packages/torch/utils/data/dataloader.py:633, in _BaseDataLoaderIter.__next__(self)
630 if self._sampler_iter is None:
631 # TODO(https://github.com/pytorch/pytorch/issues/76750)
632 self._reset() # type: ignore[call-arg]
--> 633 data = self._next_data()
634 self._num_yielded += 1
635 if self._dataset_kind == _DatasetKind.Iterable and \
636 self._IterableDataset_len_called is not None and \
637 self._num_yielded > self._IterableDataset_len_called:
File ~/mambaforge/envs/gnn/lib/python3.10/site-packages/torch/utils/data/dataloader.py:1345, in _MultiProcessingDataLoaderIter._next_data(self)
1343 else:
1344 del self._task_info[idx]
-> 1345 return self._process_data(data)
File ~/mambaforge/envs/gnn/lib/python3.10/site-packages/torch/utils/data/dataloader.py:1371, in _MultiProcessingDataLoaderIter._process_data(self, data)
1369 self._try_put_index()
1370 if isinstance(data, ExceptionWrapper):
-> 1371 data.reraise()
1372 return data
File ~/mambaforge/envs/gnn/lib/python3.10/site-packages/torch/_utils.py:644, in ExceptionWrapper.reraise(self)
640 except TypeError:
641 # If the exception takes multiple arguments, don't try to
642 # instantiate since we don't know how to
643 raise RuntimeError(msg) from None
--> 644 raise exception
ImportError: Caught ImportError in DataLoader worker process 0.
Original Traceback (most recent call last):
File "/home/sel/mambaforge/envs/gnn/lib/python3.10/site-packages/torch/utils/data/_utils/worker.py", line 308, in _worker_loop
data = fetcher.fetch(index)
File "/home/sel/mambaforge/envs/gnn/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 54, in fetch
return self.collate_fn(data)
File "/home/sel/mambaforge/envs/gnn/lib/python3.10/site-packages/torch_geometric/loader/node_loader.py", line 117, in collate_fn
out = self.node_sampler.sample_from_nodes(input_data)
File "/home/sel/mambaforge/envs/gnn/lib/python3.10/site-packages/torch_geometric/sampler/neighbor_sampler.py", line 174, in sample_from_nodes
return node_sample(inputs, self._sample)
File "/home/sel/mambaforge/envs/gnn/lib/python3.10/site-packages/torch_geometric/sampler/neighbor_sampler.py", line 358, in node_sample
out = sample_fn(seed, seed_time)
File "/home/sel/mambaforge/envs/gnn/lib/python3.10/site-packages/torch_geometric/sampler/neighbor_sampler.py", line 325, in _sample
raise ImportError(f"'{self.__class__.__name__}' requires "
ImportError: 'NeighborSampler' requires either 'pyg-lib' or 'torch-sparse'
# Evaluate the model on the data module's test split; reports test_loss and
# test_r2_score.
# NOTE(review): the recorded result has a negative R2 score; since the recorded
# fit call above raised an ImportError, these numbers presumably come from a
# model that was not trained in that session -- confirm before interpreting.
trainer.test(model, datamodule=dm)
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ ┃ Test metric ┃ DataLoader 0 ┃ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩ │ test_loss │ 43.77671432495117 │ │ test_r2_score │ -0.9155278940025104 │ └───────────────────────────┴───────────────────────────┘
[{'test_r2_score': -0.9155278940025104, 'test_loss': 43.77671432495117}]