保姆级教程:用Python处理METR-la交通数据集,搞定12步预测的输入输出格式
从零解析METR-la交通数据集12步预测的完整数据工程指南当你第一次打开METR-la交通数据集时那个看似简单的N×T矩阵背后隐藏着多少未解之谜作为时空预测领域的经典基准数据集METR-la的原始.h5文件就像一座未经雕琢的矿山而本文将带你用Python这把精密工具完成从原始数据到模型可消化格式的完整蜕变过程。1. 理解METR-la数据集的本质METR-la数据集记录了洛杉矶高速公路207个传感器在4个月内的交通速度数据原始格式为34272个时间点×207个传感器的二维矩阵。但现代时空图神经网络(ST-GNN)需要的是序列到序列的输入输出范式——用过去12个时间步预测未来12个时间步。这就产生了三个核心问题如何从N×T矩阵中提取出N×12的序列片段如何处理时间边界处的数据截断如何合理分割训练集、验证集和测试集关键洞察交通数据具有明显的时空双重特性——空间上传感器构成图结构时间上速度变化呈现周期性。预处理必须保留这两种特性。让我们先建立基础环境import numpy as np import pandas as pd import os from datetime import datetime # 确保可复现性 np.random.seed(42)2. 数据加载与初步探索原始数据存储在HDF5格式中Pandas提供了便捷的读取方式def load_metr_la_data(file_pathmetr-la.h5): df pd.read_hdf(file_path) print(f数据集形状{df.shape}) # 应输出(34272, 207) print(f时间范围{df.index.min()} 至 {df.index.max()}) return df执行后会看到类似输出数据集形状(34272, 207) 时间范围2012-03-01 00:00:00 至 2012-06-30 23:55:00数据特性分析时间分辨率为5分钟288个时间点/天包含3%的缺失值用NaN表示速度单位是mph(英里/小时)3. 构建时间序列偏移量体系时空预测的核心在于构建时间滑动窗口。我们需要定义两组关键偏移量# 输入序列过去12个时间步含当前时刻 x_offsets np.arange(-11, 1, 1) # [-11, -10,..., 0] # 预测目标未来12个时间步不含当前时刻 y_offsets np.arange(1, 13, 1) # [1, 2,..., 12]这种设计意味着每个样本x[i]包含t-11到t时刻的数据对应标签y[i]包含t1到t12时刻的数据技术细节偏移量选择直接影响模型学习的时间依赖范围。12步对应1小时跨度适合捕捉交通流的短时演变规律。4. 时间特征工程实战除了交通速度时间本身也是重要特征。我们需要提取两种时间信息def create_time_features(df): # 一天中的相对位置0-1之间 time_ind (df.index.values - df.index.values.astype(datetime64[D])) time_ind time_ind / np.timedelta64(1, D) # 星期几0-6 day_of_week df.index.dayofweek.values / 7 return np.stack([time_ind, day_of_week], axis-1)特征拼接后的数据维度变化原始数据(34272, 207, 1) → 添加时间后(34272, 207, 3)5. 滑动窗口生成算法这是最关键的步骤将长序列切割为训练样本def generate_sequence_samples(data, x_offsets, y_offsets): num_samples, num_nodes, _ data.shape min_t abs(min(x_offsets)) # 11 max_t num_samples - max(y_offsets) # 34272-1234260 x, y [], [] for t in range(min_t, max_t): x.append(data[t x_offsets]) # 12×207×3 y.append(data[t y_offsets]) # 12×207×3 return np.stack(x), np.stack(y)内存优化技巧 对于大型数据集建议使用生成器而非一次性加载所有样本def sequence_generator(data, batch_size64): # 实现批处理生成逻辑 while True: for i in range(0, len(data), batch_size): yield data[i:ibatch_size]6. 数据集分割策略交通数据具有强时间相关性必须按时间顺序分割数据集比例样本数时间范围训练集70%23,974前3个月验证集10%3,425第4个月前三周测试集20%6,850最后一周实现代码def split_dataset(x, y, train_ratio0.7, val_ratio0.1): num_samples x.shape[0] num_train int(num_samples * train_ratio) num_val int(num_samples * val_ratio) return { train: (x[:num_train], y[:num_train]), val: (x[num_train:num_trainnum_val], y[num_train:num_trainnum_val]), test: (x[num_trainnum_val:], y[num_trainnum_val:]) }7. 数据标准化与缺失值处理交通数据需要两种标准化方法全局标准化适合静态模型scaler StandardScaler() data_scaled scaler.fit_transform(df.values.reshape(-1, 1)).reshape(df.shape)局部标准化适合自适应模型def rolling_normalize(data, window_size288): # 1天窗口 rolling_mean data.rolling(windowwindow_size, min_periods1).mean() rolling_std data.rolling(windowwindow_size, min_periods1).std() return (data - rolling_mean) / (rolling_std 1e-8)对于缺失值推荐采用时空混合插补from sklearn.impute import KNNImputer # 空间维度插补利用相邻传感器 spatial_imputer KNNImputer(n_neighbors5) data_spatial_filled spatial_imputer.fit_transform(df.values) # 时间维度插补利用历史同期数据 temporal_imputer KNNImputer(n_neighbors3) data_final temporal_imputer.fit_transform(data_spatial_filled.T).T8. 高效存储方案设计NPZ格式比HDF5更轻量且兼容性好def save_as_npz(data_dict, output_dir): os.makedirs(output_dir, exist_okTrue) for name, (x, y) in data_dict.items(): np.savez_compressed( os.path.join(output_dir, f{name}.npz), xx.astype(np.float32), # 节省存储空间 yy.astype(np.float32), x_offsetsx_offsets, y_offsetsy_offsets )文件结构示例output_dir/ ├── train.npz ├── val.npz └── test.npz9. 数据可视化验证在预处理完成后必须进行可视化校验import matplotlib.pyplot as plt def plot_sample(x, y, node_idx0, feature_idx0): plt.figure(figsize(12, 6)) plt.plot(x[:, node_idx, feature_idx], labelInput) plt.plot(range(12, 24), y[:, node_idx, feature_idx], labelTarget) plt.legend() plt.show()典型检查点输入输出序列的时间对齐是否正确标准化后的数据是否保持原有模式缺失值插补是否自然10. 与主流框架的集成处理后的数据可直接用于流行库PyTorch数据加载from torch.utils.data import Dataset class TrafficDataset(Dataset): def __init__(self, npz_file): data np.load(npz_file) self.x data[x] self.y data[y] def __len__(self): return len(self.x) def __getitem__(self, idx): return torch.FloatTensor(self.x[idx]), torch.FloatTensor(self.y[idx])DGL图神经网络输入import dgl def build_spatial_graph(num_nodes207): # 基于距离构建图结构 src [...] # 源节点列表 dst [...] # 目标节点列表 return dgl.graph((src, dst), num_nodesnum_nodes)11. 性能优化技巧处理大规模交通数据时这些技巧可提升效率内存映射技术x np.load(train.npz, mmap_moder)[x]并行化处理from joblib import Parallel, delayed def process_chunk(chunk): return generate_sequence_samples(chunk, x_offsets, y_offsets) results Parallel(n_jobs4)(delayed(process_chunk)(chunk) for chunk in np.array_split(data, 4))数据压缩策略对比格式读取速度存储大小兼容性NPZ★★★★★★★★★★★★HDF5★★★★★★★★★★★Feather★★★★★★★★★★12. 常见问题解决方案问题1序列长度不匹配检查max_t计算是否准确max_t num_samples - max(y_offsets)问题2内存不足使用分块处理for chunk in pd.read_hdf(metr-la.h5, chunksize10000)问题3预测结果滞后验证偏移量方向y_offsets应为正数序列问题4模型无法收敛检查标准化是否反向应用了测试集必须使用训练集的均值和方差在真实项目中我发现最易出错的是时间偏移量的方向设置。曾经因为将y_offsets误设为np.arange(0, 12)导致模型学习到恒等映射而非预测能力。另一个经验是对于长期预测如预测未来1小时以上建议在时间特征中加入星期几和节假日标志这对捕捉交通模式的宏观周期非常有效。