# filter_dataset.py — Llama-slideQA repo (weiheng-1009), commit cbff41a: "added code for running".
from datasets import load_dataset, DatasetDict
from PIL import Image, ImageFile, UnidentifiedImageError
import io
from tqdm import tqdm
ImageFile.LOAD_TRUNCATED_IMAGES = True
# Assuming your dataset is loaded using load_dataset
cache_dir = "/bask/projects/p/phwq4930-gbm/Zeyu/PathVLM/.cache"
dataset_name = "CNX-PathLLM/Pathcap"
dataset = load_dataset(dataset_name, split="train", cache_dir=cache_dir)
print(f"original dataset size: {len(dataset)}")
# keep valid indices
valid_indices = []
# go through and check every element
for idx in tqdm(range(len(dataset))):
try:
example = dataset[idx]
text = example["txt"]
if not isinstance(text, str):
raise ValueError(f"not a string: {text}")
valid_indices.append(idx)
except Exception as e:
print(f"Cannot recognize file {idx}: {e}")
# Select valid samples according to the indices of valid samples.
filtered_dataset = dataset.select(valid_indices)
# Filter out images that cannot be loaded.
# filtered_dataset = dataset.filter(lambda example: example["is_valid"])
# Print the size of the filtered dataset
print(f"filtered dataset size: {len(filtered_dataset)}")
if len(dataset) != len(filtered_dataset):
# convert to DatasetDict
filtered_dataset_dict = DatasetDict({"train": filtered_dataset})
# push to hub
filtered_dataset_dict.push_to_hub(dataset_name)