# Llama-slideQA / split.py
# Author: weiheng-1009 -- commit cbff41a ("added code for running").
# Split the TCGA-WSI-Text `train` split into k cross-validation folds.
import warnings

import pandas as pd
from datasets import load_dataset, DatasetDict
# Number of cross-validation folds expected in the fold-assignment CSV.
NUM_FOLDS = 10
# Source dataset on the Hugging Face Hub and the cluster cache it is read through.
HUB_DATASET = 'CNX-PathLLM/TCGA-WSI-Text'
CACHE_DIR = '/bask/projects/p/phwq4930-gbm/Zeyu/PathVLM/.cache'
# CSV with an `index` column (row positions in the `train` split) and a `fold` column.
FOLDS_CSV = './dataset_csv/indices_and_slide_ids_with_folds.csv'
# Where the per-fold DatasetDict is written, and the optional Hub repo for it.
OUTPUT_DIR = '/bask/projects/p/phwq4930-gbm/Zeyu/WSI_Dataset/TCGA-WSI-Text-Folds'
HUB_FOLDS_REPO = 'CNX-PathLLM/TCGA-WSI-Text-Folds'


def collect_fold_indices(df_indices, k=NUM_FOLDS):
    """Map each fold number in ``range(k)`` to the dataset rows assigned to it.

    Args:
        df_indices: DataFrame with an ``index`` column (row positions passed
            to ``Dataset.select``) and a ``fold`` column (fold number per row).
        k: Number of folds expected; folds ``0 .. k-1`` are collected.

    Returns:
        dict mapping fold number -> list of row indices, in CSV row order.

    Raises:
        ValueError: if any fold in ``range(k)`` has no rows. An empty split
            means ``k`` does not match the CSV, and it would otherwise be
            saved silently.
    """
    folds = {
        i: df_indices.loc[df_indices['fold'] == i, 'index'].tolist()
        for i in range(k)
    }

    empty = [i for i, rows in folds.items() if not rows]
    if empty:
        raise ValueError(
            f'folds {empty} have no rows; check k={k} against the fold CSV'
        )

    # Rows whose fold label falls outside range(k) were previously dropped
    # without notice; surface that instead of failing, in case it is deliberate.
    n_assigned = sum(len(rows) for rows in folds.values())
    if n_assigned != len(df_indices):
        warnings.warn(
            f'{len(df_indices) - n_assigned} of {len(df_indices)} rows have a '
            f'fold label outside range({k}) and are not in any fold',
            stacklevel=2,
        )
    return folds


def main(k=NUM_FOLDS, output_dir=OUTPUT_DIR, push_to_hub=False):
    """Build one ``fold_{i}`` split per fold, save it to disk, optionally push it.

    Args:
        k: Number of folds to extract from the fold CSV.
        output_dir: Directory passed to ``DatasetDict.save_to_disk``.
        push_to_hub: If True, also upload the result to ``HUB_FOLDS_REPO``.
    """
    # Validate the (cheap) fold assignment before the (slow) dataset load.
    folds = collect_fold_indices(pd.read_csv(FOLDS_CSV), k)

    dataset = load_dataset(HUB_DATASET, split='train', cache_dir=CACHE_DIR)
    dataset_dict = DatasetDict(
        {f'fold_{i}': dataset.select(rows) for i, rows in folds.items()}
    )
    print(dataset_dict)

    dataset_dict.save_to_disk(output_dir)
    if push_to_hub:
        dataset_dict.push_to_hub(HUB_FOLDS_REPO)


if __name__ == '__main__':
    main()