Linear NCEM Example
%load_ext autoreload
%autoreload 2
import pytorch_lightning as pl
import torch
import anndata as ad
from geome import transforms
from geome.adata2data import AnnData2DataByCategory
from utils.datasets import DatasetHartmann
from utils.models.linear_ncem import LinearNCEM
from geome.datamodule import GraphAnnDataModule
# Mapping from each attribute of the produced Data objects to the AnnData
# locations ("<slot>/<key>") it is assembled from.
fields = {
    # Model input: the listed obs columns plus the 'design_matrix' entry of obsm.
    'x': [
        'obs/Cluster_preprocessed',
        'obs/donor',
        'obsm/design_matrix',
    ],
    # Regression target: the AnnData X matrix.
    'y': ['X'],
}
# preprocess = [
# lambda x,_: transforms.categorize_obs(x,['donor', 'Cluster_preprocessed', 'point']),
# lambda x,_: transforms.add_design_matrix(x,'obs/Cluster_preprocessed','obs/donor','design_matrix'),
# ]
from geome.transforms import Categorize, AddDesignMatrix, Compose, AddAdjMatrix
# Single location for the adjacency matrix, shared by the step that writes it
# (AddAdjMatrix) and the step that reads it (AddDesignMatrix).
adj_matrix_loc = 'obsp/adj_matrix'

# Steps passed to the converter as `preprocess`.
# NOTE(review): presumably applied once to the whole AnnData before it is split
# by category -- confirm against the geome documentation.
preprocess = Compose(
    [
        Categorize(keys=['donor', 'Cluster_preprocessed', 'point'], axis='obs'),
        AddAdjMatrix(adj_matrix_loc, overwrite=True),
    ]
)

# Steps passed to the converter as `transform`.
# NOTE(review): presumably applied to each per-category subset -- confirm.
transform = Compose(
    [
        AddDesignMatrix(
            'obs/Cluster_preprocessed',
            'obs/donor',
            adj_matrix_loc,
            'design_matrix',
            overwrite=True,
        ),
    ]
)
# a2d = AnnData2DataByCategory(
# fields=fields,
# category=category_to_iterate,
# preprocess=preprocess,
# yields_edge_index=True,
# )
# Category handed to the converter; 'point' is one of the obs keys categorized
# in the preprocessing pipeline.
category_to_iterate = 'point'

# Converter from an AnnData object to Data objects, wired with the field
# mapping and the two transform pipelines defined above.
a2d = AnnData2DataByCategory(
    fields=fields,
    category=category_to_iterate,
    preprocess=preprocess,
    transform=transform,
)
# MIBI-TOF example data (Hartmann dataset), read from the local example folder.
dataset = DatasetHartmann(data_path='./example_data/hartmann/')

# Collect every AnnData held in `img_celldata` and merge them into one object.
# Categorization of the obs columns is not done here; it is left to the
# `preprocess` pipeline given to the converter.
adatas = [*dataset.img_celldata.values()]
adata = ad.concat(adatas)
Loading data from raw files
registering celldata
collecting image-wise celldata
adding graph-level covariates
Loaded 58 images with complete data from 4 patients over 63747 cells with 36 cell features and 8 distinct celltypes.
# Convert the merged AnnData with the converter configured above. In the
# recorded run this gives a list of Data objects, each with an `x` of width 88
# and a `y` of width 36.
datas = a2d(adata)
# Bare expression so the notebook displays the resulting list.
datas
[Data(x=[1338, 88], y=[1338, 36]),
Data(x=[311, 88], y=[311, 36]),
Data(x=[768, 88], y=[768, 36]),
Data(x=[1020, 88], y=[1020, 36]),
Data(x=[2100, 88], y=[2100, 36]),
Data(x=[1325, 88], y=[1325, 36]),
Data(x=[1091, 88], y=[1091, 36]),
Data(x=[1046, 88], y=[1046, 36]),
Data(x=[618, 88], y=[618, 36]),
Data(x=[61, 88], y=[61, 36]),
Data(x=[1316, 88], y=[1316, 36]),
Data(x=[1540, 88], y=[1540, 36]),
Data(x=[1822, 88], y=[1822, 36]),
Data(x=[863, 88], y=[863, 36]),
Data(x=[564, 88], y=[564, 36]),
Data(x=[1023, 88], y=[1023, 36]),
Data(x=[324, 88], y=[324, 36]),
Data(x=[287, 88], y=[287, 36]),
Data(x=[636, 88], y=[636, 36]),
Data(x=[890, 88], y=[890, 36]),
Data(x=[1235, 88], y=[1235, 36]),
Data(x=[1020, 88], y=[1020, 36]),
Data(x=[1241, 88], y=[1241, 36]),
Data(x=[1438, 88], y=[1438, 36]),
Data(x=[1021, 88], y=[1021, 36]),
Data(x=[1632, 88], y=[1632, 36]),
Data(x=[780, 88], y=[780, 36]),
Data(x=[524, 88], y=[524, 36]),
Data(x=[669, 88], y=[669, 36]),
Data(x=[241, 88], y=[241, 36]),
Data(x=[935, 88], y=[935, 36]),
Data(x=[347, 88], y=[347, 36]),
Data(x=[1499, 88], y=[1499, 36]),
Data(x=[601, 88], y=[601, 36]),
Data(x=[2268, 88], y=[2268, 36]),
Data(x=[1912, 88], y=[1912, 36]),
Data(x=[1678, 88], y=[1678, 36]),
Data(x=[1025, 88], y=[1025, 36]),
Data(x=[1306, 88], y=[1306, 36]),
Data(x=[852, 88], y=[852, 36]),
Data(x=[1664, 88], y=[1664, 36]),
Data(x=[1698, 88], y=[1698, 36]),
Data(x=[1672, 88], y=[1672, 36]),
Data(x=[777, 88], y=[777, 36]),
Data(x=[556, 88], y=[556, 36]),
Data(x=[554, 88], y=[554, 36]),
Data(x=[937, 88], y=[937, 36]),
Data(x=[1524, 88], y=[1524, 36]),
Data(x=[1528, 88], y=[1528, 36]),
Data(x=[721, 88], y=[721, 36]),
Data(x=[1395, 88], y=[1395, 36]),
Data(x=[611, 88], y=[611, 36]),
Data(x=[1872, 88], y=[1872, 36]),
Data(x=[1217, 88], y=[1217, 36]),
Data(x=[1531, 88], y=[1531, 36]),
Data(x=[1927, 88], y=[1927, 36]),
Data(x=[690, 88], y=[690, 36]),
Data(x=[1706, 88], y=[1706, 36])]
# Model input/output widths, read from the second dimension of `x` and `y`
# of the first Data object.
num_features, out_channels = datas[0].x.shape[1], datas[0].y.shape[1]
# Bare expression so the notebook displays both values.
num_features, out_channels
(88, 36)
# Datamodule over the converted graphs, configured for node-level learning.
# NOTE: in the recorded run this configuration loads batches through
# torch_geometric's NeighborSampler, which raises
# "ImportError: 'NeighborSampler' requires either 'pyg-lib' or 'torch-sparse'"
# when neither package is installed. Install one of them in the environment
# before running the fit/test cells.
dm = GraphAnnDataModule(datas=datas, num_workers=12, batch_size=12, learning_type='node')

# Linear NCEM model sized from the feature/target widths computed above.
model = LinearNCEM(
    in_channels=num_features,
    out_channels=out_channels,
    lr=0.0001,
    weight_decay=0.000001,
)

# Train on the GPU when CUDA is available, otherwise on the CPU.
# `torch.cuda` is the public entry point; the previous `torch.torch.cuda`
# only worked through an incidental self-referential attribute of the package.
trainer: pl.Trainer = pl.Trainer(
    accelerator='gpu' if torch.cuda.is_available() else 'cpu',
    max_epochs=100,
)
trainer.fit(model, datamodule=dm)
---------------------------------------------------------------------------
ImportError Traceback (most recent call last)
Cell In[8], line 1
----> 1 trainer.fit(model,datamodule=dm)
File ~/mambaforge/envs/gnn/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:531, in Trainer.fit(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
529 model = _maybe_unwrap_optimized(model)
530 self.strategy._lightning_module = model
--> 531 call._call_and_handle_interrupt(
532 self, self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
533 )
File ~/mambaforge/envs/gnn/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py:42, in _call_and_handle_interrupt(trainer, trainer_fn, *args, **kwargs)
40 if trainer.strategy.launcher is not None:
41 return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
---> 42 return trainer_fn(*args, **kwargs)
44 except _TunerExitException:
45 _call_teardown_hook(trainer)
File ~/mambaforge/envs/gnn/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:570, in Trainer._fit_impl(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
560 self._data_connector.attach_data(
561 model, train_dataloaders=train_dataloaders, val_dataloaders=val_dataloaders, datamodule=datamodule
562 )
564 ckpt_path = self._checkpoint_connector._select_ckpt_path(
565 self.state.fn,
566 ckpt_path,
567 model_provided=True,
568 model_connected=self.lightning_module is not None,
569 )
--> 570 self._run(model, ckpt_path=ckpt_path)
572 assert self.state.stopped
573 self.training = False
File ~/mambaforge/envs/gnn/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:975, in Trainer._run(self, model, ckpt_path)
970 self._signal_connector.register_signal_handlers()
972 # ----------------------------
973 # RUN THE TRAINER
974 # ----------------------------
--> 975 results = self._run_stage()
977 # ----------------------------
978 # POST-Training CLEAN UP
979 # ----------------------------
980 log.debug(f"{self.__class__.__name__}: trainer tearing down")
File ~/mambaforge/envs/gnn/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:1016, in Trainer._run_stage(self)
1014 if self.training:
1015 with isolate_rng():
-> 1016 self._run_sanity_check()
1017 with torch.autograd.set_detect_anomaly(self._detect_anomaly):
1018 self.fit_loop.run()
File ~/mambaforge/envs/gnn/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:1045, in Trainer._run_sanity_check(self)
1042 call._call_callback_hooks(self, "on_sanity_check_start")
1044 # run eval step
-> 1045 val_loop.run()
1047 call._call_callback_hooks(self, "on_sanity_check_end")
1049 # reset logger connector
File ~/mambaforge/envs/gnn/lib/python3.10/site-packages/pytorch_lightning/loops/utilities.py:177, in _no_grad_context.<locals>._decorator(self, *args, **kwargs)
175 context_manager = torch.no_grad
176 with context_manager():
--> 177 return loop_run(self, *args, **kwargs)
File ~/mambaforge/envs/gnn/lib/python3.10/site-packages/pytorch_lightning/loops/evaluation_loop.py:108, in _EvaluationLoop.run(self)
106 while True:
107 try:
--> 108 batch, batch_idx, dataloader_idx = next(data_fetcher)
109 self.batch_progress.is_last_batch = data_fetcher.done
110 if previous_dataloader_idx != dataloader_idx:
111 # the dataloader has changed, notify the logger connector
File ~/mambaforge/envs/gnn/lib/python3.10/site-packages/pytorch_lightning/loops/fetchers.py:136, in _PrefetchDataFetcher.__next__(self)
133 elif not self.done:
134 # this will run only when no pre-fetching was done.
135 try:
--> 136 self._fetch_next_batch(self.dataloader_iter)
137 # consume the batch we just fetched
138 batch = self.batches.pop(0)
File ~/mambaforge/envs/gnn/lib/python3.10/site-packages/pytorch_lightning/loops/fetchers.py:150, in _PrefetchDataFetcher._fetch_next_batch(self, iterator)
148 self._start_profiler()
149 try:
--> 150 batch = next(iterator)
151 finally:
152 self._stop_profiler()
File ~/mambaforge/envs/gnn/lib/python3.10/site-packages/pytorch_lightning/utilities/combined_loader.py:284, in CombinedLoader.__next__(self)
282 def __next__(self) -> Any:
283 assert self._iterator is not None
--> 284 out = next(self._iterator)
285 if isinstance(self._iterator, _Sequential):
286 return out
File ~/mambaforge/envs/gnn/lib/python3.10/site-packages/pytorch_lightning/utilities/combined_loader.py:123, in _Sequential.__next__(self)
120 raise StopIteration
122 try:
--> 123 out = next(self.iterators[0])
124 index = self._idx
125 self._idx += 1
File ~/mambaforge/envs/gnn/lib/python3.10/site-packages/torch_geometric/loader/base.py:36, in DataLoaderIterator.__next__(self)
35 def __next__(self) -> Any:
---> 36 return self.transform_fn(next(self.iterator))
File ~/mambaforge/envs/gnn/lib/python3.10/site-packages/torch/utils/data/dataloader.py:633, in _BaseDataLoaderIter.__next__(self)
630 if self._sampler_iter is None:
631 # TODO(https://github.com/pytorch/pytorch/issues/76750)
632 self._reset() # type: ignore[call-arg]
--> 633 data = self._next_data()
634 self._num_yielded += 1
635 if self._dataset_kind == _DatasetKind.Iterable and \
636 self._IterableDataset_len_called is not None and \
637 self._num_yielded > self._IterableDataset_len_called:
File ~/mambaforge/envs/gnn/lib/python3.10/site-packages/torch/utils/data/dataloader.py:1345, in _MultiProcessingDataLoaderIter._next_data(self)
1343 else:
1344 del self._task_info[idx]
-> 1345 return self._process_data(data)
File ~/mambaforge/envs/gnn/lib/python3.10/site-packages/torch/utils/data/dataloader.py:1371, in _MultiProcessingDataLoaderIter._process_data(self, data)
1369 self._try_put_index()
1370 if isinstance(data, ExceptionWrapper):
-> 1371 data.reraise()
1372 return data
File ~/mambaforge/envs/gnn/lib/python3.10/site-packages/torch/_utils.py:644, in ExceptionWrapper.reraise(self)
640 except TypeError:
641 # If the exception takes multiple arguments, don't try to
642 # instantiate since we don't know how to
643 raise RuntimeError(msg) from None
--> 644 raise exception
ImportError: Caught ImportError in DataLoader worker process 0.
Original Traceback (most recent call last):
File "/home/sel/mambaforge/envs/gnn/lib/python3.10/site-packages/torch/utils/data/_utils/worker.py", line 308, in _worker_loop
data = fetcher.fetch(index)
File "/home/sel/mambaforge/envs/gnn/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 54, in fetch
return self.collate_fn(data)
File "/home/sel/mambaforge/envs/gnn/lib/python3.10/site-packages/torch_geometric/loader/node_loader.py", line 117, in collate_fn
out = self.node_sampler.sample_from_nodes(input_data)
File "/home/sel/mambaforge/envs/gnn/lib/python3.10/site-packages/torch_geometric/sampler/neighbor_sampler.py", line 174, in sample_from_nodes
return node_sample(inputs, self._sample)
File "/home/sel/mambaforge/envs/gnn/lib/python3.10/site-packages/torch_geometric/sampler/neighbor_sampler.py", line 358, in node_sample
out = sample_fn(seed, seed_time)
File "/home/sel/mambaforge/envs/gnn/lib/python3.10/site-packages/torch_geometric/sampler/neighbor_sampler.py", line 325, in _sample
raise ImportError(f"'{self.__class__.__name__}' requires "
ImportError: 'NeighborSampler' requires either 'pyg-lib' or 'torch-sparse'
# Evaluate the model on the datamodule's test data. In the recorded run this
# fails with the same ImportError as the fit cell ('NeighborSampler' requires
# either 'pyg-lib' or 'torch-sparse').
trainer.test(model, datamodule=dm)
---------------------------------------------------------------------------
ImportError Traceback (most recent call last)
Cell In[9], line 1
----> 1 trainer.test(model, datamodule=dm)
File ~/mambaforge/envs/gnn/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:737, in Trainer.test(self, model, dataloaders, ckpt_path, verbose, datamodule)
735 model = _maybe_unwrap_optimized(model)
736 self.strategy._lightning_module = model
--> 737 return call._call_and_handle_interrupt(
738 self, self._test_impl, model, dataloaders, ckpt_path, verbose, datamodule
739 )
File ~/mambaforge/envs/gnn/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py:42, in _call_and_handle_interrupt(trainer, trainer_fn, *args, **kwargs)
40 if trainer.strategy.launcher is not None:
41 return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
---> 42 return trainer_fn(*args, **kwargs)
44 except _TunerExitException:
45 _call_teardown_hook(trainer)
File ~/mambaforge/envs/gnn/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:780, in Trainer._test_impl(self, model, dataloaders, ckpt_path, verbose, datamodule)
775 self._data_connector.attach_data(model, test_dataloaders=dataloaders, datamodule=datamodule)
777 ckpt_path = self._checkpoint_connector._select_ckpt_path(
778 self.state.fn, ckpt_path, model_provided=model_provided, model_connected=self.lightning_module is not None
779 )
--> 780 results = self._run(model, ckpt_path=ckpt_path)
781 # remove the tensors from the test results
782 results = convert_tensors_to_scalars(results)
File ~/mambaforge/envs/gnn/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:975, in Trainer._run(self, model, ckpt_path)
970 self._signal_connector.register_signal_handlers()
972 # ----------------------------
973 # RUN THE TRAINER
974 # ----------------------------
--> 975 results = self._run_stage()
977 # ----------------------------
978 # POST-Training CLEAN UP
979 # ----------------------------
980 log.debug(f"{self.__class__.__name__}: trainer tearing down")
File ~/mambaforge/envs/gnn/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:1011, in Trainer._run_stage(self)
1008 self.strategy.barrier("run-stage")
1010 if self.evaluating:
-> 1011 return self._evaluation_loop.run()
1012 if self.predicting:
1013 return self.predict_loop.run()
File ~/mambaforge/envs/gnn/lib/python3.10/site-packages/pytorch_lightning/loops/utilities.py:177, in _no_grad_context.<locals>._decorator(self, *args, **kwargs)
175 context_manager = torch.no_grad
176 with context_manager():
--> 177 return loop_run(self, *args, **kwargs)
File ~/mambaforge/envs/gnn/lib/python3.10/site-packages/pytorch_lightning/loops/evaluation_loop.py:108, in _EvaluationLoop.run(self)
106 while True:
107 try:
--> 108 batch, batch_idx, dataloader_idx = next(data_fetcher)
109 self.batch_progress.is_last_batch = data_fetcher.done
110 if previous_dataloader_idx != dataloader_idx:
111 # the dataloader has changed, notify the logger connector
File ~/mambaforge/envs/gnn/lib/python3.10/site-packages/pytorch_lightning/loops/fetchers.py:136, in _PrefetchDataFetcher.__next__(self)
133 elif not self.done:
134 # this will run only when no pre-fetching was done.
135 try:
--> 136 self._fetch_next_batch(self.dataloader_iter)
137 # consume the batch we just fetched
138 batch = self.batches.pop(0)
File ~/mambaforge/envs/gnn/lib/python3.10/site-packages/pytorch_lightning/loops/fetchers.py:150, in _PrefetchDataFetcher._fetch_next_batch(self, iterator)
148 self._start_profiler()
149 try:
--> 150 batch = next(iterator)
151 finally:
152 self._stop_profiler()
File ~/mambaforge/envs/gnn/lib/python3.10/site-packages/pytorch_lightning/utilities/combined_loader.py:284, in CombinedLoader.__next__(self)
282 def __next__(self) -> Any:
283 assert self._iterator is not None
--> 284 out = next(self._iterator)
285 if isinstance(self._iterator, _Sequential):
286 return out
File ~/mambaforge/envs/gnn/lib/python3.10/site-packages/pytorch_lightning/utilities/combined_loader.py:123, in _Sequential.__next__(self)
120 raise StopIteration
122 try:
--> 123 out = next(self.iterators[0])
124 index = self._idx
125 self._idx += 1
File ~/mambaforge/envs/gnn/lib/python3.10/site-packages/torch_geometric/loader/base.py:36, in DataLoaderIterator.__next__(self)
35 def __next__(self) -> Any:
---> 36 return self.transform_fn(next(self.iterator))
File ~/mambaforge/envs/gnn/lib/python3.10/site-packages/torch/utils/data/dataloader.py:633, in _BaseDataLoaderIter.__next__(self)
630 if self._sampler_iter is None:
631 # TODO(https://github.com/pytorch/pytorch/issues/76750)
632 self._reset() # type: ignore[call-arg]
--> 633 data = self._next_data()
634 self._num_yielded += 1
635 if self._dataset_kind == _DatasetKind.Iterable and \
636 self._IterableDataset_len_called is not None and \
637 self._num_yielded > self._IterableDataset_len_called:
File ~/mambaforge/envs/gnn/lib/python3.10/site-packages/torch/utils/data/dataloader.py:1345, in _MultiProcessingDataLoaderIter._next_data(self)
1343 else:
1344 del self._task_info[idx]
-> 1345 return self._process_data(data)
File ~/mambaforge/envs/gnn/lib/python3.10/site-packages/torch/utils/data/dataloader.py:1371, in _MultiProcessingDataLoaderIter._process_data(self, data)
1369 self._try_put_index()
1370 if isinstance(data, ExceptionWrapper):
-> 1371 data.reraise()
1372 return data
File ~/mambaforge/envs/gnn/lib/python3.10/site-packages/torch/_utils.py:644, in ExceptionWrapper.reraise(self)
640 except TypeError:
641 # If the exception takes multiple arguments, don't try to
642 # instantiate since we don't know how to
643 raise RuntimeError(msg) from None
--> 644 raise exception
ImportError: Caught ImportError in DataLoader worker process 0.
Original Traceback (most recent call last):
File "/home/sel/mambaforge/envs/gnn/lib/python3.10/site-packages/torch/utils/data/_utils/worker.py", line 308, in _worker_loop
data = fetcher.fetch(index)
File "/home/sel/mambaforge/envs/gnn/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 54, in fetch
return self.collate_fn(data)
File "/home/sel/mambaforge/envs/gnn/lib/python3.10/site-packages/torch_geometric/loader/node_loader.py", line 117, in collate_fn
out = self.node_sampler.sample_from_nodes(input_data)
File "/home/sel/mambaforge/envs/gnn/lib/python3.10/site-packages/torch_geometric/sampler/neighbor_sampler.py", line 174, in sample_from_nodes
return node_sample(inputs, self._sample)
File "/home/sel/mambaforge/envs/gnn/lib/python3.10/site-packages/torch_geometric/sampler/neighbor_sampler.py", line 358, in node_sample
out = sample_fn(seed, seed_time)
File "/home/sel/mambaforge/envs/gnn/lib/python3.10/site-packages/torch_geometric/sampler/neighbor_sampler.py", line 325, in _sample
raise ImportError(f"'{self.__class__.__name__}' requires "
ImportError: 'NeighborSampler' requires either 'pyg-lib' or 'torch-sparse'