# Source code for das.block_stratify

"""A note on semantics:

- Blocks are parts of the data (individual files in a file list, or segments of a data array).
- Each block is assigned to a specific group (train/test/val).
"""
from typing import Optional, Dict, List, Any, Tuple, TypeVar
import numpy as np
import pandas as pd

# Type variable standing in for a block (a part of the data) in annotations.
Block = TypeVar("Block")
# Type variable standing in for a group label (e.g. train/test/val) in annotations.
Group = TypeVar("Group")

def groupstats(stats: List, groups: List[Group]):
    """Average the per-block stats within each group.

    Args:
        stats: Stats for each block, shape [n, x]. Plain lists are converted
            to an array before indexing.
        groups: Group id for each block, shape [n,].

    Returns:
        Dict[group_id, x]: Stats averaged over all members of each group.
    """
    # The annotation allows a plain list, which does not support mask indexing.
    stats = np.asarray(stats)
    unique_groups, group_index = np.unique(groups, return_inverse=True)
    return {
        group: np.mean(stats[group_index == index], axis=0)
        for index, group in enumerate(unique_groups)
    }


def group_splits(
    data: np.ndarray, group_sizes: List[float]
) -> Tuple[List[np.ndarray], List[Tuple[int, int]]]:
    """Cut `data` along its first axis into one contiguous block per group.

    Args:
        data: Array to split along axis 0.
        group_sizes: Fraction of the data for each group, in order.

    Returns:
        Tuple of (blocks, split_points): the list of sub-arrays and the
        matching list of (start, end) index pairs.
    """
    # Lower edge of each group's share: cumulative fractions shifted by one.
    group_cdf = np.insert(np.cumsum(group_sizes), 0, 0)[:-1]
    split_points = (group_cdf * data.shape[0]).astype(int)
    blocks = np.array_split(data, split_points[1:])

    edges = list(split_points)
    edges.append(data.shape[0])
    return blocks, list(zip(edges[:-1], edges[1:]))


def group_blocks(nb_blocks: int, group_sizes: List[float], group_names: List[Group]):
    """Assign consecutive blocks to groups according to the group sizes.

    Block positions are spread evenly over [0, 1] and each block gets the
    last group whose lower cumulative edge it reaches.

    Args:
        nb_blocks: Number of blocks to assign.
        group_sizes: Fraction of blocks for each group, in order.
        group_names: Name of each group, same order as `group_sizes`.

    Returns:
        List with one group name per block.
    """
    group_cdf = np.insert(np.cumsum(group_sizes), 0, 0)[:-1]
    group_probs = np.linspace(0, 1, nb_blocks)
    groups = np.zeros((nb_blocks,), dtype=int)
    for cnt, lower_edge in enumerate(group_cdf):
        groups[group_probs >= lower_edge] = cnt
    return [group_names[index] for index in groups]


def score_grouping(block_stats, groups: List[Group]):
    """Score how far the per-group stats deviate from the global stats.

    Args:
        block_stats: Stats for each block, shape [n, x].
        groups: Group id for each block, shape [n,].

    Returns:
        Sum over all groups of sum(|log2(group_mean / global_mean)|).
        Lower is better.
    """
    # We want each group's stats to be as close as possible to the global stats.
    global_stats = np.mean(block_stats, axis=0)
    # Exclude stats with zero global occurrence from the score.
    non_zero_stats = np.where(global_stats > 0)[0]

    group_scores = {}
    for group_name, group_stat in groupstats(block_stats, groups).items():
        # The small offset keeps log2 finite when a group has a zero stat.
        ratio = (group_stat[non_zero_stats] + 0.0000001) / global_stats[non_zero_stats]
        group_scores[group_name] = np.sum(np.abs(np.log2(ratio)))

    return np.sum(list(group_scores.values()))


def opt_grouping(
    block_stats: np.ndarray, group_sizes: List[float], group_names: List[Group], nb_perms: int = 100
) -> np.ndarray:
    """Pick the best-scoring of `nb_perms` random block-to-group assignments.

    Args:
        block_stats: Stats for each block, shape [n, x].
        group_sizes: Fraction of blocks for each group, in order.
        group_names: Name of each group, same order as `group_sizes`.
        nb_perms: Number of random permutations to try. Defaults to 100.

    Returns:
        Array with one group name per block. If no permutation is scored
        (e.g. `nb_perms` is 0), the unshuffled assignment is returned.
    """
    nb_blocks = len(block_stats)
    grouping = np.array(group_blocks(nb_blocks, group_sizes, group_names))

    best_perm = grouping
    best_score = np.inf
    for _ in range(nb_perms):
        perm = np.random.permutation(grouping)
        # Score the candidate itself, not the unshuffled grouping.
        score = score_grouping(block_stats, perm)
        if score < best_score:
            best_score = score
            best_perm = perm
    return best_perm


def block_data(data: np.ndarray, block_size: int):
    """Split `data` along its first axis into blocks of `block_size`.

    The last block absorbs the overhang, so it is typically longer than
    `block_size`. Data no longer than `block_size` yields a single block.

    Args:
        data: Array to split along axis 0.
        block_size: Number of samples per block.

    Returns:
        Tuple of (blocks, split_points): the list of sub-arrays and the
        array of block edges, so block i spans
        split_points[i]:split_points[i + 1]. Empty data yields no blocks.
    """
    nb_samples = data.shape[0]
    if nb_samples == 0:
        return [], np.array([0])

    split_points = np.array(range(0, nb_samples, block_size))
    if split_points.size < 2:
        # Everything fits into one block.
        split_points = np.array([0, nb_samples])
    else:
        # Extend the last edge to the end of the data (overhang).
        split_points[-1] = nb_samples

    # array_split returns views, so this is free in terms of memory. The
    # trailing piece after the final edge is always empty and is dropped.
    blocks = np.array_split(data, split_points[1:])[:-1]
    return blocks, split_points


def blocks_from_split_points(data: np.ndarray, split_points: List[Tuple[int, int]]):
    """Slice `data` into blocks given (start, end) index pairs.

    Args:
        data: Array to slice along axis 0.
        split_points: (start, end) indices for each block.

    Returns:
        List with one slice of `data` per entry in `split_points`.
    """
    return [data[split_point[0] : split_point[1]] for split_point in split_points]
def blockstats_from_data(data: np.ndarray, block_size: int, gap: int = 0) -> Dict[Tuple[int, int], np.ndarray]:
    """Compute the mean of each block of `data`.

    The data is cut into blocks by `block_data`. Each block is averaged along
    its first axis after trimming `gap` samples from both ends.

    Args:
        data (np.ndarray): Data to split along the first axis.
        block_size (int): Block length passed to `block_data`.
        gap (int, optional): Samples to drop at the start and the end of each
            block before averaging. Defaults to 0.

    Returns:
        Dict[Tuple[int, int], np.array]: Mean of each block, keyed by the
        block's (start, end) split points.
    """
    blocks, split_points = block_data(data, block_size)
    starts, stops = split_points[:-1], split_points[1:]
    return {
        (start, stop): np.mean(chunk[gap : chunk.shape[0] - gap], axis=0)
        for start, stop, chunk in zip(starts, stops, blocks)
    }
def blockstats_from_files(file_bases: List[str], class_names: Optional[List] = None) -> Dict[str, List[int]]:
    """Count the occurrences of each class in every annotation file.

    Each file is read with `pd.read_csv`, rows containing missing values are
    dropped, and the entries of the "name" column are counted.

    Args:
        file_bases (List[str]): List of annotation files.
        class_names (Optional[List]): List of class_names. Defaults to None
            (will infer from all annotations, sorted for a reproducible order).

    Returns:
        Dict[str, List[int]]: For each file name, the count of each class,
        in the order of `class_names`.
    """
    # Read every file once; class inference and counting share the same names.
    names_by_file = {}
    for file_base in file_bases:
        df = pd.read_csv(file_base).dropna()
        names_by_file[file_base] = df["name"].tolist()

    if class_names is None:
        inferred = set()
        for names in names_by_file.values():
            inferred.update(names)
        # Set order is not stable across runs; sort so the columns are.
        class_names = sorted(inferred, key=str)

    return {
        file_base: [names.count(class_name) for class_name in class_names]
        for file_base, names in names_by_file.items()
    }
def format_by_group(groups, block_names, group_names):
    """Collect block names under the group each block was assigned to.

    Args:
        groups: Group assigned to each block, same order as `block_names`.
        block_names: Name of each block.
        group_names: All group names; each gets an entry, even if empty.

    Returns:
        Dict mapping group name to the list of block names in that group.
    """
    by_group = {}
    for group_name in group_names:
        by_group[group_name] = []
    for block_name, group in zip(block_names, groups):
        by_group[group].append(block_name)
    return by_group


def format_by_block(blocks, block_names):
    """Map each block name to its entry in `blocks`.

    Args:
        blocks: One value per block, same order as `block_names`.
        block_names: Name of each block.

    Returns:
        Dict mapping block name to the matching entry of `blocks`.
    """
    return dict(zip(block_names, blocks))


def block(
    block_names: List[np.ndarray],
    group_sizes: List[float],
    group_names: List[Group],
    block_stats: Optional[List[Block]] = None,
    shuffle: bool = True,
    seed: Optional[float] = None,
) -> Dict[Group, List[Block]]:
    """Assign each block to a group.

    Without `block_stats`, groups come from `group_blocks` and are optionally
    shuffled. With `block_stats`, the assignment is chosen by `opt_grouping`
    and `shuffle` is not consulted.

    Args:
        block_names: Name of each block.
        group_sizes: Fraction of blocks for each group, in order.
        group_names: Name of each group, same order as `group_sizes`.
        block_stats: Stats for each block. Defaults to None.
        shuffle: Randomly permute the assignment when no `block_stats` are
            given. Defaults to True.
        seed: If not None, passed to `np.random.seed` before any assignment.

    Returns:
        The mapping built by `format_by_block`: block name -> assigned group.
        NOTE(review): the return annotation describes a group -> blocks
        mapping (the shape `format_by_group` builds) - confirm which one
        callers expect.
    """
    if seed is not None:
        np.random.seed(seed)

    if block_stats is not None:
        block_stats = np.array(list(block_stats))
        assignment = opt_grouping(block_stats, group_sizes, group_names)
    else:
        assignment = group_blocks(len(block_names), group_sizes, group_names)
        if shuffle:
            assignment = np.random.permutation(assignment)

    return format_by_block(assignment, block_names)