zephyr-backend/modules/saber/gravityw_render.py
2025-03-05 11:40:19 +08:00

297 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from dataclasses import dataclass
from typing import Optional, Tuple, List
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.figure import Figure
from matplotlib.axes import Axes
import matplotlib.dates as mdates
from modules.saber.gravityw_process import SaberGravitywProcessor, WaveData, YearlyData
@dataclass
class PlotConfig:
    """Figure-level and font-size settings shared by the SABER renderer.

    Field order defines the positional constructor signature, so new
    fields should only ever be appended at the end.
    """

    # Figure size passed to ``plt.subplots(figsize=...)``, as (width, height) in inches.
    figsize: Tuple[int, int] = (16, 10)
    # Figure resolution in dots per inch.
    dpi: int = 100
    # Colormap name. NOTE(review): not read by the renderer methods in this file.
    cmap: str = 'viridis'
    # Font sizes (points) for subplot titles, axis labels, legends and tick labels.
    title_fontsize: int = 12
    label_fontsize: int = 10
    legend_fontsize: int = 8
    tick_fontsize: int = 8
class SaberGravitywRenderer:
    """Render SABER gravity-wave diagnostics with matplotlib.

    Every plotting method draws into newly created pyplot figures and
    returns ``None``; the resulting figure is left as pyplot's current
    figure.
    """

    # Fallback (cycles, heights) shape used to zero-fill days without data.
    # NOTE(review): mirrors the previously hard-coded 15 x 157 grid — confirm
    # it matches the processor's output.
    _DEFAULT_PROFILE_SHAPE: Tuple[int, int] = (15, 157)

    def __init__(self, config: Optional[PlotConfig] = None):
        """Store the plot configuration and set global rcParams for CJK text.

        Args:
            config: Plot settings; a default ``PlotConfig`` is used when None.
        """
        self.config = config or PlotConfig()
        # Use the SimHei typeface so Chinese labels render (the font must be
        # installed in the environment). This changes global pyplot state.
        plt.rcParams['font.family'] = 'SimHei'
        # Draw the minus sign as ASCII '-' so it is not shown as a box.
        plt.rcParams['axes.unicode_minus'] = False

    def _create_figure(self, rows: int, cols: int) -> Tuple[Figure, List[Axes]]:
        """Create a ``rows`` x ``cols`` subplot grid sized from ``self.config``.

        Returns:
            The figure and its axes flattened to one dimension (a lone axes
            object is wrapped in a list).
        """
        fig, axes = plt.subplots(
            rows, cols, figsize=self.config.figsize, dpi=self.config.dpi)
        return fig, axes.flatten() if isinstance(axes, np.ndarray) else [axes]

    def _setup_subplot(self, ax: Axes, x: np.ndarray, y1: np.ndarray,
                       y2: Optional[np.ndarray], title: str, y_limits: Tuple[float, float]):
        """Draw one signal (and optionally its fit) on ``ax`` and style the panel.

        Args:
            ax: Target axes.
            x: Shared x coordinates (sample index).
            y1: Original signal, drawn solid.
            y2: Fitted signal drawn dashed, or None to omit it.
            title: Panel title.
            y_limits: (min, max) applied to the y-axis.
        """
        ax.plot(x, y1, label='原始信号', linewidth=1.5)
        if y2 is not None:
            ax.plot(x, y2, label='拟合信号', linestyle='--', linewidth=1.5)
        ax.set_title(title, fontsize=self.config.title_fontsize)
        ax.set_xlabel('Cycles', labelpad=10,
                      fontsize=self.config.label_fontsize)
        ax.set_ylabel('温度 (K)', labelpad=10,
                      fontsize=self.config.label_fontsize)
        ax.legend(fontsize=self.config.legend_fontsize)
        ax.tick_params(axis='both', labelsize=self.config.tick_fontsize)
        ax.set_ylim(y_limits)
        ax.grid(True, linestyle='--', alpha=0.7)

    def plot_wave_fitting(self, wave_data: WaveData, height_no: int):
        """Plot signals and their wavenumber fits (k=1..5) at one height level.

        Six panels on a 2x3 grid:
        - Panel 0 pairs ``wn0`` with ``fit_wn[0]``.
        - Panel i (1..4) pairs ``wn[i-1]`` with ``fit_wn[i]``.
        - Panel 5 shows ``wn[4]`` alone, titled as the signal after filtering
          wavenumbers 1-5.

        All panels share one y-range.

        Args:
            wave_data: Processed wave data; each array is sliced as
                ``[:, height_no]``.
            height_no: Column index of the height level to plot.
        """
        _, axes = self._create_figure(2, 3)
        signals = [wave_data.wn0] + [wave_data.wn[k] for k in range(5)]
        fits = [wave_data.fit_wn[k] for k in range(5)] + [None]
        data_pairs = [
            (signal[:, height_no], None if fit is None else fit[:, height_no])
            for signal, fit in zip(signals, fits)
        ]
        x = np.arange(len(data_pairs[0][0]))

        # Shared y-range so panels are directly comparable. The nan-aware
        # reductions keep a single NaN sample from producing NaN limits,
        # which set_ylim would reject.
        all_values = np.concatenate(
            [np.ravel(v) for pair in data_pairs for v in pair if v is not None])
        y_limits = (np.nanmin(all_values), np.nanmax(all_values))

        for i, (y1, y2) in enumerate(data_pairs):
            title = f'{"abcdef"[i]}波数k={i + 1 if i < 5 else "滤波1-5后信号"}'
            self._setup_subplot(axes[i], x, y1, y2, title, y_limits)

        # tight_layout first, then fixed margins (the later call wins).
        plt.tight_layout()
        plt.subplots_adjust(top=0.9, bottom=0.1, left=0.1, right=0.9,
                            hspace=0.3, wspace=0.3)

    def day_fft_ifft_plot(self, cycle_no, wave_data: WaveData):
        """Plot one cycle's profile before and after FFT band filtering.

        Panels (2x2):
        - (a) the ``wn[4]`` profile against height;
        - (b) its amplitude spectrum from ``wave_data.fft``;
        - (c) the inverse-FFT profile from ``wave_data.ifft``;
        - (d) the filtered amplitude spectrum from ``wave_data.fft_filtered``.

        The panel-(d) title reports ``wave_data.lambda_range``.

        Args:
            cycle_no: Row index into the per-cycle arrays.
            wave_data: Processed wave data for one day.
        """
        ktemp_wn5 = wave_data.wn[4]
        ktemp_fft = wave_data.fft
        ktemp_fft_lvbo = wave_data.fft_filtered  # "lvbo" = filtered
        ktemp_ifft = wave_data.ifft
        altitude_min, altitude_max = wave_data.alt_range
        lamda_low, lamda_high = wave_data.lambda_range

        N = len(ktemp_wn5[cycle_no, :])
        # Sampling interval; its reciprocal is the sampling rate. Height is
        # treated as time by taking 1 km as 1 s (a wave speed of 1 km/s is
        # assumed), so the "Hz" axis is effectively cycles per km.
        dt = (altitude_max - altitude_min) / (N - 1)
        n = np.arange(N)  # FFT bin indices
        f = n / (N * dt)  # frequency of each bin
        t = np.round(np.linspace(altitude_min, altitude_max, N), 2)  # height axis

        x = ktemp_wn5[cycle_no, :]  # original perturbation temperature
        y = ktemp_fft[cycle_no, :]  # FFT spectrum
        yy = ktemp_fft_lvbo[cycle_no, :]  # filtered FFT spectrum
        yyy = ktemp_ifft[cycle_no, :]  # temperature after the inverse FFT

        # NOTE(review): the spectrum panels plot all N bins, including the
        # mirrored upper half, while scaling by 2/N (the one-sided
        # convention). Consider plotting only f[:N // 2 + 1].
        # Also, unlike the Nz/Ptz plots, this profile is not reversed before
        # being plotted against ascending heights. Confirm the storage order.
        plt.figure(figsize=(15, 10))

        # (a) original signal versus height
        plt.subplot(2, 2, 1)
        plt.plot(t, x)
        plt.title('a原始信号')
        plt.xlabel('高度 (km)', labelpad=10)
        plt.ylabel('温度 (K)', labelpad=10)

        # (b) original amplitude spectrum
        plt.subplot(2, 2, 2)
        plt.plot(f, np.abs(y) * 2 / N)
        plt.title('b原始振幅谱')
        plt.xlabel('频率/Hz', labelpad=10)
        plt.ylabel('振幅', labelpad=10)

        # (c) back in the height domain via the inverse FFT
        plt.subplot(2, 2, 3)
        plt.plot(t, yyy)
        plt.title('c傅里叶逆变换')
        plt.xlabel('高度 (km)', labelpad=10)
        plt.ylabel('温度 (K)', labelpad=10)

        # (d) amplitude spectrum after filtering
        plt.subplot(2, 2, 4)
        plt.plot(f, np.abs(yy) * 2 / N)
        plt.title(f'd滤除波长 < {lamda_low} km, > {lamda_high} km的波')
        plt.xlabel('频率/Hz', labelpad=10)
        plt.ylabel('振幅', labelpad=10)

        plt.subplots_adjust(top=0.8, bottom=0.2, left=0.2,
                            right=0.8, hspace=0.3, wspace=0.2)

    def day_cycle_power_wave_plot(self, cycle_no, wave_data: WaveData):
        """Plot one cycle's Nz and Ptz profiles against height.

        Bug fix: the x-axis labels were swapped. The Nz panel now carries the
        buoyancy-frequency label and the Ptz panel the potential-energy label.

        Args:
            cycle_no: Row index into ``wave_data.Nz`` / ``wave_data.Ptz``.
            wave_data: Processed wave data for one day.
        """
        ktemp_Nz = wave_data.Nz
        ktemp_Ptz = wave_data.Ptz
        altitude_min, altitude_max = wave_data.alt_range

        N = len(ktemp_Nz[cycle_no, :])
        y = np.round(np.linspace(altitude_min, altitude_max, N), 2)
        x1 = ktemp_Nz[cycle_no, :]
        x2 = ktemp_Ptz[cycle_no, :]

        plt.figure(figsize=(12, 10))
        # Profiles are reversed before plotting against ascending heights.
        # NOTE(review): presumably stored from the top altitude down — confirm.

        # (a) buoyancy frequency Nz
        plt.subplot(1, 2, 1)
        plt.plot(x1[::-1], y, label='原始信号')
        plt.title('aNz')
        plt.xlabel('浮力频率', labelpad=10)
        plt.ylabel('高度 (km)', labelpad=10)

        # (b) potential energy Ptz
        plt.subplot(1, 2, 2)
        plt.plot(x2[::-1], y, label='原始信号')
        plt.title('bPtz')
        plt.xlabel('势能', labelpad=10)
        plt.ylabel('高度 (km)', labelpad=10)

        # NOTE(review): tight_layout() recomputes these margins, so this call
        # has little or no visible effect.
        plt.subplots_adjust(top=0.8, bottom=0.2, left=0.1,
                            right=0.8, hspace=0.3, wspace=0.2)
        plt.tight_layout()

    @staticmethod
    def _profile_shape(days) -> Tuple[int, ...]:
        """Return the shape of the first non-None daily array, else the default grid."""
        for day_data in days:
            if day_data is not None:
                return np.shape(day_data)
        return SaberGravitywRenderer._DEFAULT_PROFILE_SHAPE

    @staticmethod
    def _daily_mean_profiles(days, fill_shape) -> np.ndarray:
        """Average each day's array over axis 0 (cycles) into one row per day.

        Days that are None are replaced by zeros of ``fill_shape``. Their
        profile is therefore all zeros, which becomes -inf (masked) in log
        plots.
        """
        return np.array([
            np.mean(day_data if day_data is not None else np.zeros(fill_shape), axis=0)
            for day_data in days
        ])

    @staticmethod
    def _draw_power_heatmaps(fig: Figure, axs, nz_plot: np.ndarray, ptz_plot: np.ndarray,
                             extent, xticks, xticklabels, xtick_rotation=None) -> None:
        """Draw the (a) Nz and (b) log(Ptz) day-by-height maps on ``axs[0]``/``axs[1]``.

        ``nz_plot``/``ptz_plot`` hold one row per day. They are transposed so
        that height runs vertically, and the height axis is reversed before
        drawing with ``origin='lower'``.

        Args:
            fig: Figure that owns ``axs``; used to attach the colorbars.
            axs: Two axes, left then right.
            nz_plot: Per-day mean Nz profiles.
            ptz_plot: Per-day mean Ptz profiles, drawn on a natural-log scale.
            extent: ``[x0, x1, y0, y1]`` passed to ``imshow``.
            xticks: X tick positions.
            xticklabels: X tick labels.
            xtick_rotation: Optional tick-label rotation in degrees.
        """
        label_kwargs = {} if xtick_rotation is None else {'rotation': xtick_rotation}
        height_ticks = np.linspace(30, 100, 8)
        panels = (
            (axs[0], nz_plot, '(a) NZ'),
            (axs[1], np.log(ptz_plot), '(b) PTZ'),
        )
        for ax, grid, title in panels:
            image = ax.imshow(grid.T[::-1], aspect='auto', cmap='rainbow',
                              origin='lower', extent=extent)
            fig.colorbar(image, ax=ax)
            ax.set_title(title)
            ax.set_xlabel('Time')
            ax.set_ylabel('Height')
            ax.set_yticks(height_ticks)
            ax.set_yticklabels(np.round(height_ticks, 1))
            ax.set_xticks(xticks)
            ax.set_xticklabels(xticklabels, **label_kwargs)

    def month_power_wave_plot(self, wave_data: List[WaveData], date_time):
        """Plot month-long day-by-height maps of Nz and log(Ptz).

        Each day's ``Nz``/``Ptz`` is averaged over cycles into one profile.
        A day whose array is None is zero-filled using that quantity's own
        shape. Previously, missing Ptz days were filled with the Nz shape.

        Args:
            wave_data: One ``WaveData`` per day. ``alt_range`` is read from
                the first entry.
            date_time: Object exposing ``.data``. Only its length is used, to
                place one x tick per entry.
        """
        altitude_min, altitude_max = wave_data[0].alt_range
        nz_days = [day.Nz for day in wave_data]
        ptz_days = [day.Ptz for day in wave_data]
        nz_shape = self._profile_shape(nz_days)
        ptz_shape = self._profile_shape(ptz_days)

        y = np.round(np.linspace(altitude_min, altitude_max, nz_shape[-1]), 2)
        x = np.arange(len(date_time.data))

        nz_plot = self._daily_mean_profiles(nz_days, nz_shape)
        ptz_plot = self._daily_mean_profiles(ptz_days, ptz_shape)
        # NOTE: a disabled stricter variant also zero-filled any day whose Ptz
        # contained a value above 100, treating it as an outlier.

        fig, axs = plt.subplots(1, 2, figsize=(15, 10))
        self._draw_power_heatmaps(fig, axs, nz_plot, ptz_plot,
                                  extent=[x[0], x[-1], y[0], y[-1]],
                                  xticks=x, xticklabels=x)

        # NOTE(review): tight_layout() recomputes these margins.
        plt.subplots_adjust(top=0.9, bottom=0.1, left=0.05,
                            right=0.95, hspace=0.3, wspace=0.3)
        plt.tight_layout()

    def year_power_wave_plot(self, year_wave: YearlyData):
        """Plot year-long day-by-height maps of Nz and log(Ptz).

        The method reads ``alt_range``, ``ktemp_Nz_mon_list``,
        ``ktemp_Ptz_mon_list`` and ``date_time_list`` from
        ``year_wave.to_dict()``. The first entry of each list is skipped, as
        before. NOTE(review): presumably a placeholder entry — confirm.
        Those lists are no longer mutated.

        Args:
            year_wave: Yearly aggregate produced by the processor.
        """
        results = year_wave.to_dict()
        altitude_min, altitude_max = results["alt_range"][0]
        # Slice instead of pop(0). pop mutated the lists from to_dict(), so if
        # they are shared with year_wave, each call dropped one more day.
        nz_days = results["ktemp_Nz_mon_list"][1:]
        ptz_days = results["ktemp_Ptz_mon_list"][1:]
        date_time_list = results["date_time_list"][1:]

        n_days = len(date_time_list)
        x_ticks = np.arange(0, n_days, 30)  # one labelled tick every 30 days
        x_labels = [date_time_list[i] for i in x_ticks]

        # The profile width is taken from the data rather than hard-coding 157.
        nz_shape = self._profile_shape(nz_days)
        ptz_shape = self._profile_shape(ptz_days)
        y = np.round(np.linspace(altitude_min, altitude_max, nz_shape[-1]), 2)

        nz_plot = self._daily_mean_profiles(nz_days, nz_shape)
        ptz_plot = self._daily_mean_profiles(ptz_days, ptz_shape)
        # NOTE: a disabled stricter variant also zero-filled any day whose Ptz
        # contained a value above 100, treating it as an outlier.

        # Create the figure only after the data is ready, so a failure above
        # does not leave an empty figure open.
        fig, axs = plt.subplots(1, 2, figsize=(20, 10))
        self._draw_power_heatmaps(fig, axs, nz_plot, ptz_plot,
                                  extent=[0, n_days - 1, y[0], y[-1]],
                                  xticks=x_ticks, xticklabels=x_labels,
                                  xtick_rotation=45)

        # NOTE(review): tight_layout() recomputes these margins.
        plt.subplots_adjust(top=0.9, bottom=0.1, left=0.05,
                            right=0.95, hspace=0.3, wspace=0.3)
        plt.tight_layout()