Linear NCEM Example

%load_ext autoreload
%autoreload 2
import pytorch_lightning as pl
import torch
import anndata as ad
from geome import transforms 
from geome.adata2data import AnnData2DataByCategory
from utils.datasets import DatasetHartmann
from utils.models.linear_ncem import LinearNCEM
from geome.datamodule import GraphAnnDataModule
# AnnData locations gathered into each graph's `x` and `y` attributes
# (node inputs: cluster, donor and design matrix; targets: the expression matrix X).
fields = dict(
    x=['obs/Cluster_preprocessed', 'obs/donor', 'obsm/design_matrix'],
    y=['X'],
)


# preprocess = [
#     lambda x,_: transforms.categorize_obs(x,['donor', 'Cluster_preprocessed', 'point']),
#     lambda x,_: transforms.add_design_matrix(x,'obs/Cluster_preprocessed','obs/donor','design_matrix'),
# ]

from geome.transforms import Categorize, AddDesignMatrix, Compose, AddAdjMatrix

# Shared key for the adjacency matrix: written by AddAdjMatrix in `preprocess`
# and consumed by AddDesignMatrix in `transform`.
adj_matrix_loc = 'obsp/adj_matrix'

# Preprocessing pipeline: make the listed obs columns categorical, then add the
# adjacency matrix at `adj_matrix_loc` (overwriting an existing one, if any).
preprocess = Compose([
    Categorize(keys=['donor', 'Cluster_preprocessed', 'point'], axis='obs'),
    AddAdjMatrix(adj_matrix_loc, overwrite=True),
])

# Transform pipeline: build 'design_matrix' from the cluster and donor obs
# columns and the adjacency matrix. NOTE(review): presumably stored in obsm,
# which is where `fields` reads 'obsm/design_matrix' from — confirm.
transform = Compose([
    AddDesignMatrix(
        'obs/Cluster_preprocessed',
        'obs/donor',
        adj_matrix_loc,
        'design_matrix',
        overwrite=True,
    ),
])


# a2d = AnnData2DataByCategory(
#     fields=fields,
#     category=category_to_iterate,
#     preprocess=preprocess,
#     yields_edge_index=True,
# )

# obs column used to split the merged AnnData into separate graphs
# (presumably one Data object per distinct 'point' value, i.e. per image — confirm).
category_to_iterate = 'point'

# Converter from AnnData to a list of graph Data objects, wired to the field
# mapping and the preprocess/transform pipelines defined earlier.
a2d = AnnData2DataByCategory(
    category=category_to_iterate,
    fields=fields,
    preprocess=preprocess,
    transform=transform,
)

# MIBI-TOF (Hartmann) dataset, loaded through the project's DatasetHartmann wrapper.
dataset = DatasetHartmann(data_path='./example_data/hartmann/')
# One AnnData per entry of `img_celldata` (presumably one per image — confirm).
adatas = list(dataset.img_celldata.values())

# Merge the per-image AnnData objects into a single AnnData. No dtype conversion
# happens here: 'donor', 'Cluster_preprocessed' and 'point' are made categorical
# later by the `Categorize` step of the `preprocess` pipeline.
adata = ad.concat(adatas)
Loading data from raw files
registering celldata
collecting image-wise celldata
adding graph-level covariates
Loaded 58 images with complete data from 4 patients over 63747 cells with 36 cell features and 8 distinct celltypes.
# Run the configured AnnData2DataByCategory converter on the merged AnnData;
# the result is displayed as a list of torch_geometric Data objects.
datas = a2d(adata)
# Bare expression so the notebook renders the list.
datas
[Data(x=[1338, 88], y=[1338, 36]),
 Data(x=[311, 88], y=[311, 36]),
 Data(x=[768, 88], y=[768, 36]),
 Data(x=[1020, 88], y=[1020, 36]),
 Data(x=[2100, 88], y=[2100, 36]),
 Data(x=[1325, 88], y=[1325, 36]),
 Data(x=[1091, 88], y=[1091, 36]),
 Data(x=[1046, 88], y=[1046, 36]),
 Data(x=[618, 88], y=[618, 36]),
 Data(x=[61, 88], y=[61, 36]),
 Data(x=[1316, 88], y=[1316, 36]),
 Data(x=[1540, 88], y=[1540, 36]),
 Data(x=[1822, 88], y=[1822, 36]),
 Data(x=[863, 88], y=[863, 36]),
 Data(x=[564, 88], y=[564, 36]),
 Data(x=[1023, 88], y=[1023, 36]),
 Data(x=[324, 88], y=[324, 36]),
 Data(x=[287, 88], y=[287, 36]),
 Data(x=[636, 88], y=[636, 36]),
 Data(x=[890, 88], y=[890, 36]),
 Data(x=[1235, 88], y=[1235, 36]),
 Data(x=[1020, 88], y=[1020, 36]),
 Data(x=[1241, 88], y=[1241, 36]),
 Data(x=[1438, 88], y=[1438, 36]),
 Data(x=[1021, 88], y=[1021, 36]),
 Data(x=[1632, 88], y=[1632, 36]),
 Data(x=[780, 88], y=[780, 36]),
 Data(x=[524, 88], y=[524, 36]),
 Data(x=[669, 88], y=[669, 36]),
 Data(x=[241, 88], y=[241, 36]),
 Data(x=[935, 88], y=[935, 36]),
 Data(x=[347, 88], y=[347, 36]),
 Data(x=[1499, 88], y=[1499, 36]),
 Data(x=[601, 88], y=[601, 36]),
 Data(x=[2268, 88], y=[2268, 36]),
 Data(x=[1912, 88], y=[1912, 36]),
 Data(x=[1678, 88], y=[1678, 36]),
 Data(x=[1025, 88], y=[1025, 36]),
 Data(x=[1306, 88], y=[1306, 36]),
 Data(x=[852, 88], y=[852, 36]),
 Data(x=[1664, 88], y=[1664, 36]),
 Data(x=[1698, 88], y=[1698, 36]),
 Data(x=[1672, 88], y=[1672, 36]),
 Data(x=[777, 88], y=[777, 36]),
 Data(x=[556, 88], y=[556, 36]),
 Data(x=[554, 88], y=[554, 36]),
 Data(x=[937, 88], y=[937, 36]),
 Data(x=[1524, 88], y=[1524, 36]),
 Data(x=[1528, 88], y=[1528, 36]),
 Data(x=[721, 88], y=[721, 36]),
 Data(x=[1395, 88], y=[1395, 36]),
 Data(x=[611, 88], y=[611, 36]),
 Data(x=[1872, 88], y=[1872, 36]),
 Data(x=[1217, 88], y=[1217, 36]),
 Data(x=[1531, 88], y=[1531, 36]),
 Data(x=[1927, 88], y=[1927, 36]),
 Data(x=[690, 88], y=[690, 36]),
 Data(x=[1706, 88], y=[1706, 36])]
# Model widths read off the first graph: input size from node features `x`,
# output size from targets `y` (every graph above shares these widths).
num_features, out_channels = datas[0].x.shape[1], datas[0].y.shape[1]
num_features, out_channels
(88, 36)
import importlib.util

# Node-level learning draws batches through torch_geometric's NeighborSampler,
# which needs either 'pyg-lib' or 'torch-sparse'. Without one of them the error
# only surfaces inside a DataLoader worker during trainer.fit, so check up front.
if not any(importlib.util.find_spec(pkg) for pkg in ('pyg_lib', 'torch_sparse')):
    raise ImportError(
        "learning_type='node' uses torch_geometric's NeighborSampler, which "
        "requires either 'pyg-lib' or 'torch-sparse'; please install one of them."
    )

# Datamodule over the per-image graphs; batches are node-level.
dm = GraphAnnDataModule(datas=datas, num_workers=12, batch_size=12, learning_type='node')
# Linear NCEM mapping node inputs (num_features wide) to targets (out_channels wide).
model = LinearNCEM(in_channels=num_features, out_channels=out_channels, lr=0.0001, weight_decay=0.000001)
trainer: pl.Trainer = pl.Trainer(
    accelerator='gpu' if torch.cuda.is_available() else 'cpu',
    max_epochs=100,
)
trainer.fit(model, datamodule=dm)
---------------------------------------------------------------------------
ImportError                               Traceback (most recent call last)
Cell In[8], line 1
----> 1 trainer.fit(model,datamodule=dm)

File ~/mambaforge/envs/gnn/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:531, in Trainer.fit(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    529 model = _maybe_unwrap_optimized(model)
    530 self.strategy._lightning_module = model
--> 531 call._call_and_handle_interrupt(
    532     self, self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
    533 )

File ~/mambaforge/envs/gnn/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py:42, in _call_and_handle_interrupt(trainer, trainer_fn, *args, **kwargs)
     40     if trainer.strategy.launcher is not None:
     41         return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
---> 42     return trainer_fn(*args, **kwargs)
     44 except _TunerExitException:
     45     _call_teardown_hook(trainer)

File ~/mambaforge/envs/gnn/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:570, in Trainer._fit_impl(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    560 self._data_connector.attach_data(
    561     model, train_dataloaders=train_dataloaders, val_dataloaders=val_dataloaders, datamodule=datamodule
    562 )
    564 ckpt_path = self._checkpoint_connector._select_ckpt_path(
    565     self.state.fn,
    566     ckpt_path,
    567     model_provided=True,
    568     model_connected=self.lightning_module is not None,
    569 )
--> 570 self._run(model, ckpt_path=ckpt_path)
    572 assert self.state.stopped
    573 self.training = False

File ~/mambaforge/envs/gnn/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:975, in Trainer._run(self, model, ckpt_path)
    970 self._signal_connector.register_signal_handlers()
    972 # ----------------------------
    973 # RUN THE TRAINER
    974 # ----------------------------
--> 975 results = self._run_stage()
    977 # ----------------------------
    978 # POST-Training CLEAN UP
    979 # ----------------------------
    980 log.debug(f"{self.__class__.__name__}: trainer tearing down")

File ~/mambaforge/envs/gnn/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:1016, in Trainer._run_stage(self)
   1014 if self.training:
   1015     with isolate_rng():
-> 1016         self._run_sanity_check()
   1017     with torch.autograd.set_detect_anomaly(self._detect_anomaly):
   1018         self.fit_loop.run()

File ~/mambaforge/envs/gnn/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:1045, in Trainer._run_sanity_check(self)
   1042 call._call_callback_hooks(self, "on_sanity_check_start")
   1044 # run eval step
-> 1045 val_loop.run()
   1047 call._call_callback_hooks(self, "on_sanity_check_end")
   1049 # reset logger connector

File ~/mambaforge/envs/gnn/lib/python3.10/site-packages/pytorch_lightning/loops/utilities.py:177, in _no_grad_context.<locals>._decorator(self, *args, **kwargs)
    175     context_manager = torch.no_grad
    176 with context_manager():
--> 177     return loop_run(self, *args, **kwargs)

File ~/mambaforge/envs/gnn/lib/python3.10/site-packages/pytorch_lightning/loops/evaluation_loop.py:108, in _EvaluationLoop.run(self)
    106 while True:
    107     try:
--> 108         batch, batch_idx, dataloader_idx = next(data_fetcher)
    109         self.batch_progress.is_last_batch = data_fetcher.done
    110         if previous_dataloader_idx != dataloader_idx:
    111             # the dataloader has changed, notify the logger connector

File ~/mambaforge/envs/gnn/lib/python3.10/site-packages/pytorch_lightning/loops/fetchers.py:136, in _PrefetchDataFetcher.__next__(self)
    133 elif not self.done:
    134     # this will run only when no pre-fetching was done.
    135     try:
--> 136         self._fetch_next_batch(self.dataloader_iter)
    137         # consume the batch we just fetched
    138         batch = self.batches.pop(0)

File ~/mambaforge/envs/gnn/lib/python3.10/site-packages/pytorch_lightning/loops/fetchers.py:150, in _PrefetchDataFetcher._fetch_next_batch(self, iterator)
    148 self._start_profiler()
    149 try:
--> 150     batch = next(iterator)
    151 finally:
    152     self._stop_profiler()

File ~/mambaforge/envs/gnn/lib/python3.10/site-packages/pytorch_lightning/utilities/combined_loader.py:284, in CombinedLoader.__next__(self)
    282 def __next__(self) -> Any:
    283     assert self._iterator is not None
--> 284     out = next(self._iterator)
    285     if isinstance(self._iterator, _Sequential):
    286         return out

File ~/mambaforge/envs/gnn/lib/python3.10/site-packages/pytorch_lightning/utilities/combined_loader.py:123, in _Sequential.__next__(self)
    120             raise StopIteration
    122 try:
--> 123     out = next(self.iterators[0])
    124     index = self._idx
    125     self._idx += 1

File ~/mambaforge/envs/gnn/lib/python3.10/site-packages/torch_geometric/loader/base.py:36, in DataLoaderIterator.__next__(self)
     35 def __next__(self) -> Any:
---> 36     return self.transform_fn(next(self.iterator))

File ~/mambaforge/envs/gnn/lib/python3.10/site-packages/torch/utils/data/dataloader.py:633, in _BaseDataLoaderIter.__next__(self)
    630 if self._sampler_iter is None:
    631     # TODO(https://github.com/pytorch/pytorch/issues/76750)
    632     self._reset()  # type: ignore[call-arg]
--> 633 data = self._next_data()
    634 self._num_yielded += 1
    635 if self._dataset_kind == _DatasetKind.Iterable and \
    636         self._IterableDataset_len_called is not None and \
    637         self._num_yielded > self._IterableDataset_len_called:

File ~/mambaforge/envs/gnn/lib/python3.10/site-packages/torch/utils/data/dataloader.py:1345, in _MultiProcessingDataLoaderIter._next_data(self)
   1343 else:
   1344     del self._task_info[idx]
-> 1345     return self._process_data(data)

File ~/mambaforge/envs/gnn/lib/python3.10/site-packages/torch/utils/data/dataloader.py:1371, in _MultiProcessingDataLoaderIter._process_data(self, data)
   1369 self._try_put_index()
   1370 if isinstance(data, ExceptionWrapper):
-> 1371     data.reraise()
   1372 return data

File ~/mambaforge/envs/gnn/lib/python3.10/site-packages/torch/_utils.py:644, in ExceptionWrapper.reraise(self)
    640 except TypeError:
    641     # If the exception takes multiple arguments, don't try to
    642     # instantiate since we don't know how to
    643     raise RuntimeError(msg) from None
--> 644 raise exception

ImportError: Caught ImportError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/sel/mambaforge/envs/gnn/lib/python3.10/site-packages/torch/utils/data/_utils/worker.py", line 308, in _worker_loop
    data = fetcher.fetch(index)
  File "/home/sel/mambaforge/envs/gnn/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 54, in fetch
    return self.collate_fn(data)
  File "/home/sel/mambaforge/envs/gnn/lib/python3.10/site-packages/torch_geometric/loader/node_loader.py", line 117, in collate_fn
    out = self.node_sampler.sample_from_nodes(input_data)
  File "/home/sel/mambaforge/envs/gnn/lib/python3.10/site-packages/torch_geometric/sampler/neighbor_sampler.py", line 174, in sample_from_nodes
    return node_sample(inputs, self._sample)
  File "/home/sel/mambaforge/envs/gnn/lib/python3.10/site-packages/torch_geometric/sampler/neighbor_sampler.py", line 358, in node_sample
    out = sample_fn(seed, seed_time)
  File "/home/sel/mambaforge/envs/gnn/lib/python3.10/site-packages/torch_geometric/sampler/neighbor_sampler.py", line 325, in _sample
    raise ImportError(f"'{self.__class__.__name__}' requires "
ImportError: 'NeighborSampler' requires either 'pyg-lib' or 'torch-sparse'
# Evaluate the model on the datamodule's test dataloader.
trainer.test(model=model, datamodule=dm)
---------------------------------------------------------------------------
ImportError                               Traceback (most recent call last)
Cell In[9], line 1
----> 1 trainer.test(model, datamodule=dm)

File ~/mambaforge/envs/gnn/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:737, in Trainer.test(self, model, dataloaders, ckpt_path, verbose, datamodule)
    735     model = _maybe_unwrap_optimized(model)
    736     self.strategy._lightning_module = model
--> 737 return call._call_and_handle_interrupt(
    738     self, self._test_impl, model, dataloaders, ckpt_path, verbose, datamodule
    739 )

File ~/mambaforge/envs/gnn/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py:42, in _call_and_handle_interrupt(trainer, trainer_fn, *args, **kwargs)
     40     if trainer.strategy.launcher is not None:
     41         return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
---> 42     return trainer_fn(*args, **kwargs)
     44 except _TunerExitException:
     45     _call_teardown_hook(trainer)

File ~/mambaforge/envs/gnn/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:780, in Trainer._test_impl(self, model, dataloaders, ckpt_path, verbose, datamodule)
    775 self._data_connector.attach_data(model, test_dataloaders=dataloaders, datamodule=datamodule)
    777 ckpt_path = self._checkpoint_connector._select_ckpt_path(
    778     self.state.fn, ckpt_path, model_provided=model_provided, model_connected=self.lightning_module is not None
    779 )
--> 780 results = self._run(model, ckpt_path=ckpt_path)
    781 # remove the tensors from the test results
    782 results = convert_tensors_to_scalars(results)

File ~/mambaforge/envs/gnn/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:975, in Trainer._run(self, model, ckpt_path)
    970 self._signal_connector.register_signal_handlers()
    972 # ----------------------------
    973 # RUN THE TRAINER
    974 # ----------------------------
--> 975 results = self._run_stage()
    977 # ----------------------------
    978 # POST-Training CLEAN UP
    979 # ----------------------------
    980 log.debug(f"{self.__class__.__name__}: trainer tearing down")

File ~/mambaforge/envs/gnn/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:1011, in Trainer._run_stage(self)
   1008 self.strategy.barrier("run-stage")
   1010 if self.evaluating:
-> 1011     return self._evaluation_loop.run()
   1012 if self.predicting:
   1013     return self.predict_loop.run()

File ~/mambaforge/envs/gnn/lib/python3.10/site-packages/pytorch_lightning/loops/utilities.py:177, in _no_grad_context.<locals>._decorator(self, *args, **kwargs)
    175     context_manager = torch.no_grad
    176 with context_manager():
--> 177     return loop_run(self, *args, **kwargs)

File ~/mambaforge/envs/gnn/lib/python3.10/site-packages/pytorch_lightning/loops/evaluation_loop.py:108, in _EvaluationLoop.run(self)
    106 while True:
    107     try:
--> 108         batch, batch_idx, dataloader_idx = next(data_fetcher)
    109         self.batch_progress.is_last_batch = data_fetcher.done
    110         if previous_dataloader_idx != dataloader_idx:
    111             # the dataloader has changed, notify the logger connector

File ~/mambaforge/envs/gnn/lib/python3.10/site-packages/pytorch_lightning/loops/fetchers.py:136, in _PrefetchDataFetcher.__next__(self)
    133 elif not self.done:
    134     # this will run only when no pre-fetching was done.
    135     try:
--> 136         self._fetch_next_batch(self.dataloader_iter)
    137         # consume the batch we just fetched
    138         batch = self.batches.pop(0)

File ~/mambaforge/envs/gnn/lib/python3.10/site-packages/pytorch_lightning/loops/fetchers.py:150, in _PrefetchDataFetcher._fetch_next_batch(self, iterator)
    148 self._start_profiler()
    149 try:
--> 150     batch = next(iterator)
    151 finally:
    152     self._stop_profiler()

File ~/mambaforge/envs/gnn/lib/python3.10/site-packages/pytorch_lightning/utilities/combined_loader.py:284, in CombinedLoader.__next__(self)
    282 def __next__(self) -> Any:
    283     assert self._iterator is not None
--> 284     out = next(self._iterator)
    285     if isinstance(self._iterator, _Sequential):
    286         return out

File ~/mambaforge/envs/gnn/lib/python3.10/site-packages/pytorch_lightning/utilities/combined_loader.py:123, in _Sequential.__next__(self)
    120             raise StopIteration
    122 try:
--> 123     out = next(self.iterators[0])
    124     index = self._idx
    125     self._idx += 1

File ~/mambaforge/envs/gnn/lib/python3.10/site-packages/torch_geometric/loader/base.py:36, in DataLoaderIterator.__next__(self)
     35 def __next__(self) -> Any:
---> 36     return self.transform_fn(next(self.iterator))

File ~/mambaforge/envs/gnn/lib/python3.10/site-packages/torch/utils/data/dataloader.py:633, in _BaseDataLoaderIter.__next__(self)
    630 if self._sampler_iter is None:
    631     # TODO(https://github.com/pytorch/pytorch/issues/76750)
    632     self._reset()  # type: ignore[call-arg]
--> 633 data = self._next_data()
    634 self._num_yielded += 1
    635 if self._dataset_kind == _DatasetKind.Iterable and \
    636         self._IterableDataset_len_called is not None and \
    637         self._num_yielded > self._IterableDataset_len_called:

File ~/mambaforge/envs/gnn/lib/python3.10/site-packages/torch/utils/data/dataloader.py:1345, in _MultiProcessingDataLoaderIter._next_data(self)
   1343 else:
   1344     del self._task_info[idx]
-> 1345     return self._process_data(data)

File ~/mambaforge/envs/gnn/lib/python3.10/site-packages/torch/utils/data/dataloader.py:1371, in _MultiProcessingDataLoaderIter._process_data(self, data)
   1369 self._try_put_index()
   1370 if isinstance(data, ExceptionWrapper):
-> 1371     data.reraise()
   1372 return data

File ~/mambaforge/envs/gnn/lib/python3.10/site-packages/torch/_utils.py:644, in ExceptionWrapper.reraise(self)
    640 except TypeError:
    641     # If the exception takes multiple arguments, don't try to
    642     # instantiate since we don't know how to
    643     raise RuntimeError(msg) from None
--> 644 raise exception

ImportError: Caught ImportError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/sel/mambaforge/envs/gnn/lib/python3.10/site-packages/torch/utils/data/_utils/worker.py", line 308, in _worker_loop
    data = fetcher.fetch(index)
  File "/home/sel/mambaforge/envs/gnn/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 54, in fetch
    return self.collate_fn(data)
  File "/home/sel/mambaforge/envs/gnn/lib/python3.10/site-packages/torch_geometric/loader/node_loader.py", line 117, in collate_fn
    out = self.node_sampler.sample_from_nodes(input_data)
  File "/home/sel/mambaforge/envs/gnn/lib/python3.10/site-packages/torch_geometric/sampler/neighbor_sampler.py", line 174, in sample_from_nodes
    return node_sample(inputs, self._sample)
  File "/home/sel/mambaforge/envs/gnn/lib/python3.10/site-packages/torch_geometric/sampler/neighbor_sampler.py", line 358, in node_sample
    out = sample_fn(seed, seed_time)
  File "/home/sel/mambaforge/envs/gnn/lib/python3.10/site-packages/torch_geometric/sampler/neighbor_sampler.py", line 325, in _sample
    raise ImportError(f"'{self.__class__.__name__}' requires "
ImportError: 'NeighborSampler' requires either 'pyg-lib' or 'torch-sparse'