297 lines
13 KiB
Python
297 lines
13 KiB
Python
from dataclasses import dataclass
|
||
from typing import Optional, Tuple, List
|
||
import matplotlib.pyplot as plt
|
||
import numpy as np
|
||
from matplotlib.figure import Figure
|
||
from matplotlib.axes import Axes
|
||
import matplotlib.dates as mdates
|
||
|
||
from modules.saber.gravityw_process import SaberGravitywProcessor, WaveData, YearlyData
|
||
|
||
|
||
@dataclass
|
||
class PlotConfig:
|
||
"""绘图配置类"""
|
||
figsize: Tuple[int, int] = (16, 10)
|
||
dpi: int = 100
|
||
cmap: str = 'viridis'
|
||
title_fontsize: int = 12
|
||
label_fontsize: int = 10
|
||
legend_fontsize: int = 8
|
||
tick_fontsize: int = 8
|
||
|
||
|
||
class SaberGravitywRenderer:
|
||
def __init__(self, config: Optional[PlotConfig] = None):
|
||
self.config = config or PlotConfig()
|
||
plt.rcParams['font.family'] = 'SimHei' # 设置为黑体(需要你的环境中有该字体)
|
||
plt.rcParams['axes.unicode_minus'] = False # 解决负号'-'显示为方块的问题
|
||
|
||
def _create_figure(self, rows: int, cols: int) -> Tuple[Figure, List[Axes]]:
|
||
"""创建图形和轴对象"""
|
||
fig, axes = plt.subplots(
|
||
rows, cols, figsize=self.config.figsize, dpi=self.config.dpi)
|
||
return fig, axes.flatten() if isinstance(axes, np.ndarray) else [axes]
|
||
|
||
def _setup_subplot(self, ax: Axes, x: np.ndarray, y1: np.ndarray,
|
||
y2: Optional[np.ndarray], title: str, y_limits: Tuple[float, float]):
|
||
"""设置单个子图的样式和数据"""
|
||
ax.plot(x, y1, label='原始信号', linewidth=1.5)
|
||
if y2 is not None:
|
||
ax.plot(x, y2, label='拟合信号', linestyle='--', linewidth=1.5)
|
||
|
||
ax.set_title(title, fontsize=self.config.title_fontsize)
|
||
ax.set_xlabel('Cycles', labelpad=10,
|
||
fontsize=self.config.label_fontsize)
|
||
ax.set_ylabel('温度 (K)', labelpad=10,
|
||
fontsize=self.config.label_fontsize)
|
||
ax.legend(fontsize=self.config.legend_fontsize)
|
||
ax.tick_params(axis='both', labelsize=self.config.tick_fontsize)
|
||
ax.set_ylim(y_limits)
|
||
ax.grid(True, linestyle='--', alpha=0.7)
|
||
|
||
def plot_wave_fitting(self, wave_data: WaveData, height_no: int):
|
||
"""绘制波数拟合结果"""
|
||
fig, axes = self._create_figure(2, 3)
|
||
|
||
# 准备数据
|
||
N = len(wave_data.wn0[:, height_no])
|
||
x = np.arange(N)
|
||
|
||
data_pairs = [
|
||
(wave_data.wn0[:, height_no], wave_data.fit_wn[0][:, height_no]),
|
||
(wave_data.wn[0][:, height_no], wave_data.fit_wn[1][:, height_no]),
|
||
(wave_data.wn[1][:, height_no], wave_data.fit_wn[2][:, height_no]),
|
||
(wave_data.wn[2][:, height_no], wave_data.fit_wn[3][:, height_no]),
|
||
(wave_data.wn[3][:, height_no], wave_data.fit_wn[4][:, height_no]),
|
||
(wave_data.wn[4][:, height_no], None)
|
||
]
|
||
|
||
# 计算统一的y轴范围
|
||
all_values = [
|
||
val for pair in data_pairs for val in pair if val is not None]
|
||
y_limits = (np.min(all_values), np.max(all_values))
|
||
|
||
# 绘制子图
|
||
for i, (y1, y2) in enumerate(data_pairs):
|
||
title = f'({"abcdef"[i]})波数k={i + 1 if i < 5 else "滤波1-5后信号"}'
|
||
self._setup_subplot(axes[i], x, y1, y2, title, y_limits)
|
||
|
||
# 调整布局
|
||
plt.tight_layout()
|
||
plt.subplots_adjust(top=0.9, bottom=0.1, left=0.1, right=0.9,
|
||
hspace=0.3, wspace=0.3)
|
||
|
||
def day_fft_ifft_plot(self, cycle_no, wave_data: WaveData):
|
||
|
||
ktemp_wn5 = wave_data.wn[4]
|
||
ktemp_fft = wave_data.fft
|
||
ktemp_fft_lvbo = wave_data.fft_filtered
|
||
ktemp_ifft = wave_data.ifft
|
||
|
||
altitude_min, altitude_max = wave_data.alt_range
|
||
lamda_low, lamda_high = wave_data.lambda_range
|
||
|
||
N = len(ktemp_wn5[cycle_no, :])
|
||
# 采样时间间隔,其倒数等于采用频率,以1km为标准尺度等同于1s,假设波的速度为1km/s
|
||
dt = (altitude_max-altitude_min)/(N-1)
|
||
# 时间序列索引
|
||
n = np.arange(N)
|
||
f = n / (N * dt)
|
||
t = np.round(np.linspace(altitude_min, altitude_max, N), 2)
|
||
|
||
# 原始扰动温度
|
||
x = ktemp_wn5[cycle_no, :]
|
||
# 傅里叶变换频谱分析
|
||
y = ktemp_fft[cycle_no, :]
|
||
# 滤波后的傅里叶变换频谱分析
|
||
yy = ktemp_fft_lvbo[cycle_no, :]
|
||
# 傅里叶逆变换后的扰动温度
|
||
yyy = ktemp_ifft[cycle_no, :]
|
||
|
||
plt.figure(figsize=(15, 10)) # 调整图形大小
|
||
# 原始信号的时间序列
|
||
plt.subplot(2, 2, 1)
|
||
plt.plot(t, x)
|
||
plt.title('(a)原始信号')
|
||
plt.xlabel('高度 (km)', labelpad=10) # 增加标签间距
|
||
plt.ylabel('温度 (K)', labelpad=10) # 增加标签间距
|
||
# 原始振幅谱
|
||
plt.subplot(2, 2, 2)
|
||
plt.plot(f, np.abs(y) * 2 / N)
|
||
plt.title('(b))原始振幅谱')
|
||
plt.xlabel('频率/Hz', labelpad=10) # 增加标签间距
|
||
plt.ylabel('振幅', labelpad=10) # 增加标签间距
|
||
|
||
# 通过IFFT回到时间域
|
||
plt.subplot(2, 2, 3)
|
||
plt.plot(t, yyy)
|
||
plt.title('(c))傅里叶逆变换')
|
||
plt.xlabel('高度 (km)', labelpad=10) # 增加标签间距
|
||
plt.ylabel('温度 (K)', labelpad=10) # 增加标签间距
|
||
|
||
# 滤波后的振幅谱
|
||
plt.subplot(2, 2, 4)
|
||
plt.plot(f, np.abs(yy) * 2 / N)
|
||
plt.title(f'(d)滤除波长 < {lamda_low} km, > {lamda_high} km的波')
|
||
plt.xlabel('频率/Hz', labelpad=10) # 增加标签间距
|
||
plt.ylabel('振幅', labelpad=10) # 增加标签间距
|
||
|
||
# 调整子图之间的边距
|
||
plt.subplots_adjust(top=0.8, bottom=0.2, left=0.2,
|
||
right=0.8, hspace=0.3, wspace=0.2)
|
||
|
||
def day_cycle_power_wave_plot(self, cycle_no, wave_data: WaveData):
|
||
ktemp_Nz = wave_data.Nz
|
||
ktemp_Ptz = wave_data.Ptz
|
||
|
||
altitude_min, altitude_max = wave_data.alt_range
|
||
|
||
N = len(ktemp_Nz[cycle_no, :])
|
||
y = np.round(np.linspace(altitude_min, altitude_max, N), 2)
|
||
x1 = ktemp_Nz[cycle_no, :]
|
||
x2 = ktemp_Ptz[cycle_no, :]
|
||
|
||
plt.figure(figsize=(12, 10)) # 调整图形大小
|
||
# 原始信号的时间序列
|
||
plt.subplot(1, 2, 1)
|
||
plt.plot(x1[::-1], y, label='原始信号')
|
||
plt.title('(a)Nz')
|
||
plt.xlabel('势能', labelpad=10) # 增加标签间距
|
||
plt.ylabel('高度 (km)', labelpad=10) # 增加标签间距
|
||
|
||
# 原始信号的时间序列
|
||
plt.subplot(1, 2, 2)
|
||
plt.plot(x2[::-1], y, label='原始信号')
|
||
plt.title('(b)Ptz')
|
||
plt.xlabel('浮力频率', labelpad=10) # 增加标签间距
|
||
plt.ylabel('高度 (km)', labelpad=10) # 增加标签间距
|
||
|
||
# 调整子图之间的边距
|
||
plt.subplots_adjust(top=0.8, bottom=0.2, left=0.1,
|
||
right=0.8, hspace=0.3, wspace=0.2)
|
||
plt.tight_layout() # 调整子图参数以适应图形区域
|
||
|
||
def month_power_wave_plot(self, wave_data: List[WaveData], date_time):
|
||
ktemp_Nz_mon = [data.Nz for data in wave_data]
|
||
ktemp_Ptz_mon = [data.Ptz for data in wave_data]
|
||
|
||
altitude_min, altitude_max = wave_data[0].alt_range
|
||
|
||
if ktemp_Nz_mon and ktemp_Nz_mon[0] is not None:
|
||
nz_shape = np.array(ktemp_Nz_mon[0]).shape
|
||
else:
|
||
nz_shape = (15, 157)
|
||
if ktemp_Ptz_mon and ktemp_Ptz_mon[0] is not None:
|
||
ptz_shape = np.array(ktemp_Ptz_mon[0]).shape
|
||
else:
|
||
ptz_shape = (15, 157)
|
||
y = np.round(np.linspace(altitude_min, altitude_max, nz_shape[1]), 2)
|
||
x = np.arange(len(date_time.data))
|
||
# 处理 ktemp_Nz_mon
|
||
ktemp_Nz_plot = np.array([np.mean(day_data if day_data is not None else np.zeros(
|
||
nz_shape), axis=0) for day_data in ktemp_Nz_mon])
|
||
ktemp_Ptz_plot = np.array(
|
||
[np.mean(day_data if day_data is not None else np.zeros(nz_shape), axis=0) for day_data in ktemp_Ptz_mon])
|
||
# 处理 ktemp_Ptz_mon(以100为界剔除异常值)
|
||
# ktemp_Ptz_plot = np.array([np.mean(day_data if day_data is not None and np.all(day_data <= 100) else np.zeros(ptz_shape), axis=0) for day_data in ktemp_Ptz_mon])
|
||
# 创建一个图形,并指定两个子图
|
||
fig, axs = plt.subplots(1, 2, figsize=(15, 10))
|
||
|
||
# 第一幅图 (a) NZ
|
||
cax1 = axs[0].imshow(ktemp_Nz_plot.T[::-1], aspect='auto', cmap='rainbow', origin='lower',
|
||
extent=[x[0], x[-1], y[0], y[-1]])
|
||
fig.colorbar(cax1, ax=axs[0]) # 为第一幅图添加颜色条
|
||
axs[0].set_title('(a) NZ')
|
||
axs[0].set_xlabel('Time')
|
||
axs[0].set_ylabel('Height')
|
||
axs[0].set_yticks(np.linspace(30, 100, 8))
|
||
axs[0].set_yticklabels(np.round(np.linspace(30, 100, 8), 1))
|
||
axs[0].set_xticks(x)
|
||
axs[0].set_xticklabels(x)
|
||
|
||
# 第二幅图 (b) PTZ
|
||
cax2 = axs[1].imshow(np.log(ktemp_Ptz_plot.T[::-1]), aspect='auto',
|
||
cmap='rainbow', origin='lower', extent=[x[0], x[-1], y[0], y[-1]])
|
||
fig.colorbar(cax2, ax=axs[1]) # 为第二幅图添加颜色条
|
||
axs[1].set_title('(b) PTZ')
|
||
axs[1].set_xlabel('Time')
|
||
axs[1].set_ylabel('Height')
|
||
axs[1].set_yticks(np.linspace(30, 100, 8))
|
||
axs[1].set_yticklabels(np.round(np.linspace(30, 100, 8), 1))
|
||
axs[1].set_xticks(x)
|
||
axs[1].set_xticklabels(x)
|
||
|
||
# 调整子图之间的边距
|
||
plt.subplots_adjust(top=0.9, bottom=0.1, left=0.05,
|
||
right=0.95, hspace=0.3, wspace=0.3)
|
||
plt.tight_layout() # 调整布局以避免重叠
|
||
|
||
def year_power_wave_plot(self, year_wave: YearlyData):
|
||
# 假设我们已经从process_yearly_data函数中获取了一年的Nz和Ptz数据
|
||
results = year_wave.to_dict()
|
||
altitude_min, altitude_max = results["alt_range"][0]
|
||
|
||
ktemp_Nz_mon_list = results["ktemp_Nz_mon_list"]
|
||
ktemp_Ptz_mon_list = results["ktemp_Ptz_mon_list"]
|
||
ktemp_Ptz_mon_list.pop(0)
|
||
ktemp_Nz_mon_list.pop(0)
|
||
|
||
# 准备日期数据,这里假设date_time_list是一年中的所有日期
|
||
date_time_list = results["date_time_list"]
|
||
date_time_list.pop(0)
|
||
# 将日期转换为matplotlib可以理解的数字格式
|
||
date_nums = mdates.date2num(date_time_list)
|
||
|
||
# 获取date_time_list长度作为横坐标新的依据
|
||
x_ticks_length = len(date_time_list)
|
||
x_ticks = np.arange(0, x_ticks_length, 30)
|
||
x_labels = [date_time_list[i] if i < len(
|
||
date_time_list) else "" for i in x_ticks]
|
||
|
||
# 准备高度数据
|
||
# 假设高度数据有157个点
|
||
y = np.round(np.linspace(altitude_min, altitude_max, 157), 2)
|
||
|
||
# 创建一个图形,并指定两个子图
|
||
fig, axs = plt.subplots(1, 2, figsize=(20, 10))
|
||
|
||
# 处理 ktemp_Nz_mon
|
||
ktemp_Nz_plot = np.array(
|
||
[np.mean(day_data if day_data is not None else np.zeros((15, 157)), axis=0) for day_data in ktemp_Nz_mon_list])
|
||
# 处理 ktemp_Ptz_mon
|
||
ktemp_Ptz_plot = np.array(
|
||
[np.mean(day_data if day_data is not None else np.zeros((15, 157)), axis=0) for day_data in ktemp_Ptz_mon_list])
|
||
# ktemp_Ptz_plot = np.array(
|
||
# [np.mean(day_data if day_data is not None and np.all(day_data <= 100) else np.zeros((15, 157)), axis=0) for
|
||
# day_data in ktemp_Ptz_mon_list])
|
||
|
||
# 第一幅图 (a) NZ
|
||
cax1 = axs[0].imshow(ktemp_Nz_plot.T[::-1], aspect='auto', cmap='rainbow', origin='lower',
|
||
extent=[0, x_ticks_length - 1, y[0], y[-1]])
|
||
fig.colorbar(cax1, ax=axs[0]) # 为第一幅图添加颜色条
|
||
axs[0].set_title('(a) NZ')
|
||
axs[0].set_xlabel('Time')
|
||
axs[0].set_ylabel('Height')
|
||
axs[0].set_yticks(np.linspace(30, 100, 8))
|
||
axs[0].set_yticklabels(np.round(np.linspace(30, 100, 8), 1))
|
||
axs[0].set_xticks(x_ticks) # 设置新的横坐标刻度
|
||
axs[0].set_xticklabels(x_labels, rotation=45)
|
||
|
||
# 第二幅图 (b) PTZ
|
||
cax2 = axs[1].imshow(np.log(ktemp_Ptz_plot.T[::-1]), aspect='auto', cmap='rainbow', origin='lower',
|
||
extent=[0, x_ticks_length - 1, y[0], y[-1]])
|
||
fig.colorbar(cax2, ax=axs[1]) # 为第二幅图添加颜色条
|
||
axs[1].set_title('(b) PTZ')
|
||
axs[1].set_xlabel('Time')
|
||
axs[1].set_ylabel('Height')
|
||
axs[1].set_yticks(np.linspace(30, 100, 8))
|
||
axs[1].set_yticklabels(np.round(np.linspace(30, 100, 8), 1))
|
||
axs[1].set_xticks(x_ticks) # 设置新的横坐标刻度
|
||
axs[1].set_xticklabels(x_labels, rotation=45)
|
||
|
||
# 调整子图之间的边距
|
||
plt.subplots_adjust(top=0.9, bottom=0.1, left=0.05,
|
||
right=0.95, hspace=0.3, wspace=0.3)
|
||
plt.tight_layout() # 调整布局以避免重叠
|