|
|
import pandas as pd |
|
|
from datasets import load_dataset, DatasetDict |
|
|
|
|
|
# Split the CNX-PathLLM/TCGA-WSI-Text training split into k cross-validation
# folds according to a precomputed fold-assignment CSV, and save the result
# to disk as a DatasetDict with keys "fold_0" .. "fold_{k-1}".

# Number of folds expected in the fold-assignment CSV.
k = 10

# Cluster-specific locations (values unchanged from the original script).
CACHE_DIR = '/bask/projects/p/phwq4930-gbm/Zeyu/PathVLM/.cache'
FOLDS_CSV = './dataset_csv/indices_and_slide_ids_with_folds.csv'
OUTPUT_DIR = '/bask/projects/p/phwq4930-gbm/Zeyu/WSI_Dataset/TCGA-WSI-Text-Folds'

# Full training split; each fold is a subset of its rows selected by index.
dataset = load_dataset('CNX-PathLLM/TCGA-WSI-Text', split='train', cache_dir=CACHE_DIR)

dataset_dict = DatasetDict()

# The CSV is read for two columns: 'index' (row index into `dataset`) and
# 'fold' (the fold id that row belongs to).
df_indices = pd.read_csv(FOLDS_CSV)

# Guard against a mismatch between the hard-coded k and the CSV before
# anything is written:
#   - rows whose fold id is outside range(k) would be silently dropped from
#     every split;
#   - a fold id in range(k) with no rows would be saved as an empty split.
expected_folds = set(range(k))
found_folds = set(df_indices['fold'].unique())

unexpected_folds = found_folds - expected_folds
if unexpected_folds:
    raise ValueError(
        f"{FOLDS_CSV} contains fold ids outside range({k}): "
        f"{sorted(map(str, unexpected_folds))}; those rows would be dropped."
    )

missing_folds = expected_folds - found_folds
if missing_folds:
    raise ValueError(
        f"{FOLDS_CSV} has no rows for fold ids {sorted(missing_folds)}; "
        f"expected every fold in range({k}) to be non-empty."
    )

for i in range(k):
    # Row indices assigned to fold i, in CSV order, as a plain Python list.
    fold_indices = df_indices.loc[df_indices['fold'] == i, 'index'].tolist()

    fold_dataset = dataset.select(fold_indices)

    dataset_dict[f'fold_{i}'] = fold_dataset

print(dataset_dict)

dataset_dict.save_to_disk(OUTPUT_DIR)
|
|
|
|
|
|