init
This commit is contained in:
174
experiments/data_loader.py
Normal file
174
experiments/data_loader.py
Normal file
@@ -0,0 +1,174 @@
|
||||
import random
|
||||
import os
|
||||
import pandas as pd
|
||||
|
||||
from utils import *
|
||||
|
||||
|
||||
class DataLoader:
|
||||
def __init__(self, data_path):
|
||||
self.data_path = data_path
|
||||
self.files = os.listdir(data_path)
|
||||
|
||||
def load_webtable_columns_randomized(self, reference_set_amount: int, source_set_amount: int) -> tuple[list, list]:
|
||||
"""
|
||||
Get randomized reference sets and source sets of webtable columns.
|
||||
Reference sets are subsets of the source sets.
|
||||
Only columns with 4 or more different elements are considered.
|
||||
Only considering columns with non-numeric values.
|
||||
|
||||
Args:
|
||||
reference_set_amount (int): Number of reference sets to return.
|
||||
source_set_amount (int): Number of source sets to return.
|
||||
Returns:
|
||||
tuple: A tuple containing a list of reference sets and a list of source sets.
|
||||
"""
|
||||
# Basic validation of input parameters
|
||||
if reference_set_amount < 1 or source_set_amount < 2:
|
||||
raise ValueError("reference_set_amount must be at least 1 and source_set_amount must be at least 2")
|
||||
if reference_set_amount >= source_set_amount:
|
||||
raise ValueError("reference_set_amount must be smaller than source_set_amount")
|
||||
if reference_set_amount > len(self.files):
|
||||
raise ValueError("reference_set_amount must be smaller than the number of files in data_path")
|
||||
if source_set_amount > len(self.files):
|
||||
raise ValueError("source_set_amount must be smaller than the number of files in data_path")
|
||||
if len(self.files) == 0:
|
||||
raise ValueError("data_path does not contain any files")
|
||||
|
||||
|
||||
# Randomly select a reference set and source sets
|
||||
source_set_nums = random.sample(range(len(self.files)), source_set_amount)
|
||||
|
||||
# Pick source_set_amount of columns which have at least 4 different elements
|
||||
source_sets = []
|
||||
while len(source_sets) < source_set_amount:
|
||||
# Pick a random number from the source_set_nums
|
||||
source_set_num = random.choice(source_set_nums)
|
||||
file_path = os.path.join(self.data_path, self.files[source_set_num])
|
||||
|
||||
try:
|
||||
with open(file_path, 'r', encoding='utf-8') as file:
|
||||
json_data = json.load(file)
|
||||
if "relation" in json_data and isinstance(json_data["relation"], list):
|
||||
# pick random column
|
||||
col = random.randint(0, len(json_data["relation"]) - 1)
|
||||
col = json_data["relation"][col]
|
||||
|
||||
# Check if the column has at least 4 different elements and contains no numeric values
|
||||
if len(set(col)) >= 4:
|
||||
if all(not is_convertible_to_number(value) and len(value) > 0 for value in col):
|
||||
# Add the column to the source sets
|
||||
source_sets.append(col)
|
||||
print(f"Source set number {len(source_sets)} loaded")
|
||||
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error loading JSON file: {e}")
|
||||
|
||||
# Randomly select reference sets from the source sets
|
||||
reference_sets = random.sample(source_sets, reference_set_amount)
|
||||
return reference_sets, source_sets
|
||||
|
||||
def load_webtable_reference_sets_element_restriction(self, source_set: list, element_restriction: int) -> list:
|
||||
"""
|
||||
Get a reference set of webtable columns with a specific element restriction.
|
||||
Restriction is the minimal number of elements allowed in the reference set.
|
||||
|
||||
Args:
|
||||
source_set (list): The source set to use for generating the reference set.
|
||||
element_restriction (int): The number of elements in the reference set.
|
||||
Returns:
|
||||
list: A list of reference sets.
|
||||
"""
|
||||
if element_restriction < 1:
|
||||
raise ValueError("element_restriction must be at least 1")
|
||||
|
||||
reference_sets = []
|
||||
|
||||
while len(reference_sets) < 1000:
|
||||
# Randomly select a column from the source set
|
||||
col = random.choice(source_set)
|
||||
|
||||
# Check if the column has at least element_restriction different elements
|
||||
if len(col) >= element_restriction:
|
||||
reference_sets.append(col)
|
||||
print(f"Reference set number {len(reference_sets)} loaded")
|
||||
|
||||
return reference_sets
|
||||
|
||||
def load_webtable_schemas_randomized(self, set_amount: int) -> list:
|
||||
if set_amount < 2:
|
||||
raise ValueError("source_set_amount must be at least 2")
|
||||
# Random sequence of table numbers
|
||||
table_nums = random.sample(range(len(self.files)), len(self.files))
|
||||
|
||||
schema_sets = []
|
||||
|
||||
i = 0
|
||||
while len(schema_sets) < set_amount and i < len(table_nums):
|
||||
try:
|
||||
# Load the schema for the current table number
|
||||
schema = self.load_single_webtable_schema(table_nums[i])
|
||||
schema_sets.append(schema)
|
||||
print(f"Schema set number {len(schema_sets)} loaded")
|
||||
i += 1
|
||||
except ValueError as e:
|
||||
print(f"Skipping table number {table_nums[i]} due to error: {e}")
|
||||
i += 1
|
||||
|
||||
return schema_sets
|
||||
|
||||
def load_single_webtable_schema(self, reference_set_num: int) -> list:
|
||||
# Load the webtable schema for the given reference set number
|
||||
if reference_set_num < 0 or reference_set_num >= len(self.files):
|
||||
raise IndexError("reference_set_num is out of range")
|
||||
|
||||
# Get the file at the specified position
|
||||
file_path = os.path.join(self.data_path, self.files[reference_set_num])
|
||||
|
||||
# Load and return the JSON content
|
||||
try:
|
||||
with open(file_path, 'r', encoding='utf-8') as file:
|
||||
json_data = json.load(file)
|
||||
if "relation" in json_data and isinstance(json_data["relation"], list):
|
||||
schema = [relation[0] for relation in json_data["relation"]]
|
||||
if len(schema) == 0:
|
||||
raise ValueError("Schema is empty")
|
||||
|
||||
if all(not is_convertible_to_number(col) for col in schema):
|
||||
# remove "" empty strings from the schema
|
||||
schema = [col for col in schema if len(col) > 0]
|
||||
if len(schema) == 0:
|
||||
raise ValueError("Schema contains only empty strings")
|
||||
return schema
|
||||
else:
|
||||
raise ValueError("Schema contains numeric values or is empty")
|
||||
else:
|
||||
raise ValueError("JSON does not contain a valid 'relation' key or it is not a list")
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error loading JSON file: {e}")
|
||||
|
||||
|
||||
|
||||
|
||||
def load_dblp_titles(self, data_path: str) -> list:
|
||||
"""
|
||||
Load DBLP paper titles from a CSV file.
|
||||
|
||||
Args:
|
||||
data_path (str): Path to CSV file containing a column 'title'.
|
||||
|
||||
Returns:
|
||||
list: A list of title strings.
|
||||
"""
|
||||
|
||||
if not os.path.exists(data_path):
|
||||
raise FileNotFoundError(f"DBLP CSV file not found: {data_path}")
|
||||
|
||||
df = pd.read_csv(data_path)
|
||||
if "title" not in df.columns:
|
||||
raise ValueError("CSV must contain a 'title' column")
|
||||
|
||||
titles = df["title"].dropna().tolist()
|
||||
return titles
|
||||
|
||||
|
||||
Reference in New Issue
Block a user