Non-Linear NCEM Example

%load_ext autoreload
%autoreload 2
import pytorch_lightning as pl
import torch
import anndata as ad
from geome import transforms
from geome.adata2data import AnnData2DataByCategory
from utils.datasets import DatasetHartmann
from utils.models.non_linear_ncem import NonLinearNCEM
from geome.datamodule import GraphAnnDataModule
# Maps each attribute name (x, edge_index, y) to the AnnData location(s) it is
# read from; handed to AnnData2DataByCategory as `fields=`.
fields = dict(
    x=['obs/Cluster_preprocessed', 'obs/donor'],
    edge_index=['uns/edge_index'],
    y=['X'],
)

from geome.transforms import Categorize, AddDesignMatrix, Compose, AddAdjMatrix, AddEdgeIndex

# Name of the category passed to AnnData2DataByCategory as `category=`.
category_to_iterate = 'point'

# Single adjacency-matrix location shared by both transforms below.
adj_matrix_loc = 'obsp/adjacency_matrix_connectivities'

# NOTE(review): presumably casts the listed `obs` columns to categorical
# dtype -- confirm against geome's Categorize.
preprocess = Categorize(
    ['donor', 'Cluster_preprocessed', 'point'],
    axis='obs',
)

# Transforms are composed in this order: adjacency matrix first, then the
# edge index derived from the same location.
transform = Compose(
    [
        AddAdjMatrix(location=adj_matrix_loc),
        AddEdgeIndex(
            adj_matrix_loc=adj_matrix_loc,
            edge_index_key='edge_index',
        ),
    ]
)

# Converter object; it is called on an AnnData further down in the notebook.
a2d = AnnData2DataByCategory(
    fields=fields,
    category=category_to_iterate,
    preprocess=preprocess,
    transform=transform,
)


# MIBI-TOF (Hartmann) example data, loaded from a local directory.
dataset = DatasetHartmann(data_path='./example_data/hartmann/')
# Collect every AnnData stored in `img_celldata` into a list.
adatas = [*dataset.img_celldata.values()]

# Concatenate the individual AnnData objects into a single one.
adata = ad.concat(adatas)

# Run the converter on the merged AnnData; the bare name on the last line
# displays the result as the cell output.
datas = a2d(adata)
datas
Loading data from raw files
registering celldata
collecting image-wise celldata
adding graph-level covariates
Loaded 58 images with complete data from 4 patients over 63747 cells with 36 cell features and 8 distinct celltypes.
[Data(x=[1338, 12], edge_index=[2, 8028], y=[1338, 36]),
 Data(x=[311, 12], edge_index=[2, 1866], y=[311, 36]),
 Data(x=[768, 12], edge_index=[2, 4608], y=[768, 36]),
 Data(x=[1020, 12], edge_index=[2, 6120], y=[1020, 36]),
 Data(x=[2100, 12], edge_index=[2, 12600], y=[2100, 36]),
 Data(x=[1325, 12], edge_index=[2, 7950], y=[1325, 36]),
 Data(x=[1091, 12], edge_index=[2, 6546], y=[1091, 36]),
 Data(x=[1046, 12], edge_index=[2, 6276], y=[1046, 36]),
 Data(x=[618, 12], edge_index=[2, 3708], y=[618, 36]),
 Data(x=[61, 12], edge_index=[2, 366], y=[61, 36]),
 Data(x=[1316, 12], edge_index=[2, 7896], y=[1316, 36]),
 Data(x=[1540, 12], edge_index=[2, 9240], y=[1540, 36]),
 Data(x=[1822, 12], edge_index=[2, 10932], y=[1822, 36]),
 Data(x=[863, 12], edge_index=[2, 5178], y=[863, 36]),
 Data(x=[564, 12], edge_index=[2, 3384], y=[564, 36]),
 Data(x=[1023, 12], edge_index=[2, 6138], y=[1023, 36]),
 Data(x=[324, 12], edge_index=[2, 1944], y=[324, 36]),
 Data(x=[287, 12], edge_index=[2, 1722], y=[287, 36]),
 Data(x=[636, 12], edge_index=[2, 3816], y=[636, 36]),
 Data(x=[890, 12], edge_index=[2, 5340], y=[890, 36]),
 Data(x=[1235, 12], edge_index=[2, 7410], y=[1235, 36]),
 Data(x=[1020, 12], edge_index=[2, 6120], y=[1020, 36]),
 Data(x=[1241, 12], edge_index=[2, 7446], y=[1241, 36]),
 Data(x=[1438, 12], edge_index=[2, 8628], y=[1438, 36]),
 Data(x=[1021, 12], edge_index=[2, 6126], y=[1021, 36]),
 Data(x=[1632, 12], edge_index=[2, 9792], y=[1632, 36]),
 Data(x=[780, 12], edge_index=[2, 4680], y=[780, 36]),
 Data(x=[524, 12], edge_index=[2, 3144], y=[524, 36]),
 Data(x=[669, 12], edge_index=[2, 4014], y=[669, 36]),
 Data(x=[241, 12], edge_index=[2, 1446], y=[241, 36]),
 Data(x=[935, 12], edge_index=[2, 5610], y=[935, 36]),
 Data(x=[347, 12], edge_index=[2, 2082], y=[347, 36]),
 Data(x=[1499, 12], edge_index=[2, 8994], y=[1499, 36]),
 Data(x=[601, 12], edge_index=[2, 3606], y=[601, 36]),
 Data(x=[2268, 12], edge_index=[2, 13608], y=[2268, 36]),
 Data(x=[1912, 12], edge_index=[2, 11472], y=[1912, 36]),
 Data(x=[1678, 12], edge_index=[2, 10068], y=[1678, 36]),
 Data(x=[1025, 12], edge_index=[2, 6150], y=[1025, 36]),
 Data(x=[1306, 12], edge_index=[2, 7836], y=[1306, 36]),
 Data(x=[852, 12], edge_index=[2, 5112], y=[852, 36]),
 Data(x=[1664, 12], edge_index=[2, 9984], y=[1664, 36]),
 Data(x=[1698, 12], edge_index=[2, 10188], y=[1698, 36]),
 Data(x=[1672, 12], edge_index=[2, 10032], y=[1672, 36]),
 Data(x=[777, 12], edge_index=[2, 4662], y=[777, 36]),
 Data(x=[556, 12], edge_index=[2, 3336], y=[556, 36]),
 Data(x=[554, 12], edge_index=[2, 3324], y=[554, 36]),
 Data(x=[937, 12], edge_index=[2, 5622], y=[937, 36]),
 Data(x=[1524, 12], edge_index=[2, 9144], y=[1524, 36]),
 Data(x=[1528, 12], edge_index=[2, 9168], y=[1528, 36]),
 Data(x=[721, 12], edge_index=[2, 4326], y=[721, 36]),
 Data(x=[1395, 12], edge_index=[2, 8370], y=[1395, 36]),
 Data(x=[611, 12], edge_index=[2, 3666], y=[611, 36]),
 Data(x=[1872, 12], edge_index=[2, 11232], y=[1872, 36]),
 Data(x=[1217, 12], edge_index=[2, 7302], y=[1217, 36]),
 Data(x=[1531, 12], edge_index=[2, 9186], y=[1531, 36]),
 Data(x=[1927, 12], edge_index=[2, 11562], y=[1927, 36]),
 Data(x=[690, 12], edge_index=[2, 4140], y=[690, 36]),
 Data(x=[1706, 12], edge_index=[2, 10236], y=[1706, 36])]
# Second-dimension sizes of `x` and `y` on the first element of `datas`;
# used below as the model's input and output channel counts.
num_features, out_channels = datas[0].x.shape[1], datas[0].y.shape[1]
num_features, out_channels
(12, 36)
# Data module built from the converted graphs.
dm = GraphAnnDataModule(
    datas=datas,
    num_workers=12,
    batch_size=100,
    learning_type='node',
)

# Channel counts come from the shapes computed above.
model = NonLinearNCEM(
    in_channels=num_features,
    out_channels=out_channels,
    encoder_hidden_dims=[16],
    decoder_hidden_dims=[16],
    latent_dim=14,
    lr=0.001,
    weight_decay=0.00001,
)

# Train on the GPU when CUDA is available, otherwise on the CPU.
# (`torch.cuda` is the public entry point; the former `torch.torch.cuda`
# only worked through an incidental self-referential attribute.)
trainer: pl.Trainer = pl.Trainer(
    accelerator='gpu' if torch.cuda.is_available() else 'cpu',
    max_epochs=100,
    log_every_n_steps=10,
)

# NOTE: in the recorded run this call raised
# "ImportError: 'NeighborSampler' requires either 'pyg-lib' or 'torch-sparse'"
# from a DataLoader worker; one of those packages must be installed in the
# environment for fitting to proceed.
trainer.fit(model, datamodule=dm)
---------------------------------------------------------------------------
ImportError                               Traceback (most recent call last)
Cell In[7], line 1
----> 1 trainer.fit(model,datamodule=dm)

File ~/mambaforge/envs/gnn/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:531, in Trainer.fit(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    529 model = _maybe_unwrap_optimized(model)
    530 self.strategy._lightning_module = model
--> 531 call._call_and_handle_interrupt(
    532     self, self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
    533 )

File ~/mambaforge/envs/gnn/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py:42, in _call_and_handle_interrupt(trainer, trainer_fn, *args, **kwargs)
     40     if trainer.strategy.launcher is not None:
     41         return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
---> 42     return trainer_fn(*args, **kwargs)
     44 except _TunerExitException:
     45     _call_teardown_hook(trainer)

File ~/mambaforge/envs/gnn/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:570, in Trainer._fit_impl(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    560 self._data_connector.attach_data(
    561     model, train_dataloaders=train_dataloaders, val_dataloaders=val_dataloaders, datamodule=datamodule
    562 )
    564 ckpt_path = self._checkpoint_connector._select_ckpt_path(
    565     self.state.fn,
    566     ckpt_path,
    567     model_provided=True,
    568     model_connected=self.lightning_module is not None,
    569 )
--> 570 self._run(model, ckpt_path=ckpt_path)
    572 assert self.state.stopped
    573 self.training = False

File ~/mambaforge/envs/gnn/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:975, in Trainer._run(self, model, ckpt_path)
    970 self._signal_connector.register_signal_handlers()
    972 # ----------------------------
    973 # RUN THE TRAINER
    974 # ----------------------------
--> 975 results = self._run_stage()
    977 # ----------------------------
    978 # POST-Training CLEAN UP
    979 # ----------------------------
    980 log.debug(f"{self.__class__.__name__}: trainer tearing down")

File ~/mambaforge/envs/gnn/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:1016, in Trainer._run_stage(self)
   1014 if self.training:
   1015     with isolate_rng():
-> 1016         self._run_sanity_check()
   1017     with torch.autograd.set_detect_anomaly(self._detect_anomaly):
   1018         self.fit_loop.run()

File ~/mambaforge/envs/gnn/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:1045, in Trainer._run_sanity_check(self)
   1042 call._call_callback_hooks(self, "on_sanity_check_start")
   1044 # run eval step
-> 1045 val_loop.run()
   1047 call._call_callback_hooks(self, "on_sanity_check_end")
   1049 # reset logger connector

File ~/mambaforge/envs/gnn/lib/python3.10/site-packages/pytorch_lightning/loops/utilities.py:177, in _no_grad_context.<locals>._decorator(self, *args, **kwargs)
    175     context_manager = torch.no_grad
    176 with context_manager():
--> 177     return loop_run(self, *args, **kwargs)

File ~/mambaforge/envs/gnn/lib/python3.10/site-packages/pytorch_lightning/loops/evaluation_loop.py:108, in _EvaluationLoop.run(self)
    106 while True:
    107     try:
--> 108         batch, batch_idx, dataloader_idx = next(data_fetcher)
    109         self.batch_progress.is_last_batch = data_fetcher.done
    110         if previous_dataloader_idx != dataloader_idx:
    111             # the dataloader has changed, notify the logger connector

File ~/mambaforge/envs/gnn/lib/python3.10/site-packages/pytorch_lightning/loops/fetchers.py:136, in _PrefetchDataFetcher.__next__(self)
    133 elif not self.done:
    134     # this will run only when no pre-fetching was done.
    135     try:
--> 136         self._fetch_next_batch(self.dataloader_iter)
    137         # consume the batch we just fetched
    138         batch = self.batches.pop(0)

File ~/mambaforge/envs/gnn/lib/python3.10/site-packages/pytorch_lightning/loops/fetchers.py:150, in _PrefetchDataFetcher._fetch_next_batch(self, iterator)
    148 self._start_profiler()
    149 try:
--> 150     batch = next(iterator)
    151 finally:
    152     self._stop_profiler()

File ~/mambaforge/envs/gnn/lib/python3.10/site-packages/pytorch_lightning/utilities/combined_loader.py:284, in CombinedLoader.__next__(self)
    282 def __next__(self) -> Any:
    283     assert self._iterator is not None
--> 284     out = next(self._iterator)
    285     if isinstance(self._iterator, _Sequential):
    286         return out

File ~/mambaforge/envs/gnn/lib/python3.10/site-packages/pytorch_lightning/utilities/combined_loader.py:123, in _Sequential.__next__(self)
    120             raise StopIteration
    122 try:
--> 123     out = next(self.iterators[0])
    124     index = self._idx
    125     self._idx += 1

File ~/mambaforge/envs/gnn/lib/python3.10/site-packages/torch_geometric/loader/base.py:36, in DataLoaderIterator.__next__(self)
     35 def __next__(self) -> Any:
---> 36     return self.transform_fn(next(self.iterator))

File ~/mambaforge/envs/gnn/lib/python3.10/site-packages/torch/utils/data/dataloader.py:633, in _BaseDataLoaderIter.__next__(self)
    630 if self._sampler_iter is None:
    631     # TODO(https://github.com/pytorch/pytorch/issues/76750)
    632     self._reset()  # type: ignore[call-arg]
--> 633 data = self._next_data()
    634 self._num_yielded += 1
    635 if self._dataset_kind == _DatasetKind.Iterable and \
    636         self._IterableDataset_len_called is not None and \
    637         self._num_yielded > self._IterableDataset_len_called:

File ~/mambaforge/envs/gnn/lib/python3.10/site-packages/torch/utils/data/dataloader.py:1345, in _MultiProcessingDataLoaderIter._next_data(self)
   1343 else:
   1344     del self._task_info[idx]
-> 1345     return self._process_data(data)

File ~/mambaforge/envs/gnn/lib/python3.10/site-packages/torch/utils/data/dataloader.py:1371, in _MultiProcessingDataLoaderIter._process_data(self, data)
   1369 self._try_put_index()
   1370 if isinstance(data, ExceptionWrapper):
-> 1371     data.reraise()
   1372 return data

File ~/mambaforge/envs/gnn/lib/python3.10/site-packages/torch/_utils.py:644, in ExceptionWrapper.reraise(self)
    640 except TypeError:
    641     # If the exception takes multiple arguments, don't try to
    642     # instantiate since we don't know how to
    643     raise RuntimeError(msg) from None
--> 644 raise exception

ImportError: Caught ImportError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/sel/mambaforge/envs/gnn/lib/python3.10/site-packages/torch/utils/data/_utils/worker.py", line 308, in _worker_loop
    data = fetcher.fetch(index)
  File "/home/sel/mambaforge/envs/gnn/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 54, in fetch
    return self.collate_fn(data)
  File "/home/sel/mambaforge/envs/gnn/lib/python3.10/site-packages/torch_geometric/loader/node_loader.py", line 117, in collate_fn
    out = self.node_sampler.sample_from_nodes(input_data)
  File "/home/sel/mambaforge/envs/gnn/lib/python3.10/site-packages/torch_geometric/sampler/neighbor_sampler.py", line 174, in sample_from_nodes
    return node_sample(inputs, self._sample)
  File "/home/sel/mambaforge/envs/gnn/lib/python3.10/site-packages/torch_geometric/sampler/neighbor_sampler.py", line 358, in node_sample
    out = sample_fn(seed, seed_time)
  File "/home/sel/mambaforge/envs/gnn/lib/python3.10/site-packages/torch_geometric/sampler/neighbor_sampler.py", line 325, in _sample
    raise ImportError(f"'{self.__class__.__name__}' requires "
ImportError: 'NeighborSampler' requires either 'pyg-lib' or 'torch-sparse'
# Run the Lightning test loop for `model` with the same data module; the
# returned list of metric dicts is shown as the cell output.
# NOTE(review): the recorded `trainer.fit` call above ended in an ImportError,
# so the metrics shown here may come from an untrained model or from a
# different execution order -- confirm before relying on them.
trainer.test(model, datamodule=dm)
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃        Test metric               DataLoader 0        ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│         test_loss              43.77671432495117     │
│       test_r2_score           -0.9155278940025104    │
└───────────────────────────┴───────────────────────────┘
[{'test_r2_score': -0.9155278940025104, 'test_loss': 43.77671432495117}]