299 lines
14 KiB
Python
299 lines
14 KiB
Python
import os
|
||
import numpy as np
|
||
import pandas as pd
|
||
from scipy.interpolate import interp1d
|
||
from scipy.optimize import curve_fit
|
||
import netCDF4 as nc
|
||
import matplotlib.pyplot as plt
|
||
import seaborn as sns
|
||
import matplotlib.font_manager as fm
|
||
|
||
from CONSTANT import DATA_BASEPATH
|
||
from modules.cosmic.gravityw_multiday_process import process_single_file
|
||
|
||
DAY_RANGE = (0, 204)
|
||
|
||
# 设置支持中文的字体
|
||
plt.rcParams['axes.unicode_minus'] = False # 正常显示负号
|
||
fm.fontManager.addfont("./SimHei.ttf")
|
||
plt.rcParams['font.sans-serif'] = ['Simhei'] # 设置字体为微软雅黑
|
||
|
||
# 主循环,处理1到365个文件
|
||
|
||
|
||
def get_multiday_data(year=2008, day_range=DAY_RANGE, lat_range=(30, 40)):
|
||
base_folder_path = f"{DATA_BASEPATH.cosmic}/{year}"
|
||
all_mean_ktemp_Nz = []
|
||
all_mean_ktemp_Ptz = []
|
||
day_begin, day_end = day_range
|
||
for file_index in range(day_begin, day_end):
|
||
try:
|
||
mean_ktemp_Nz, mean_ktemp_Ptz = process_single_file(
|
||
base_folder_path, file_index, lat_range)
|
||
if mean_ktemp_Nz is not None and mean_ktemp_Ptz is not None:
|
||
all_mean_ktemp_Nz.append(mean_ktemp_Nz)
|
||
all_mean_ktemp_Ptz.append(mean_ktemp_Ptz)
|
||
except ValueError as e:
|
||
print(
|
||
f"Error processing file index {file_index}: {e}, skipping this file.")
|
||
continue
|
||
|
||
# 转换每个数组为二维形状
|
||
final_mean_ktemp_Nz = np.vstack([arr.reshape(1, -1)
|
||
for arr in all_mean_ktemp_Nz])
|
||
final_mean_ktemp_Ptz = np.vstack(
|
||
[arr.reshape(1, -1) for arr in all_mean_ktemp_Ptz])
|
||
# 使用条件索引替换大于50的值为NaN
|
||
final_mean_ktemp_Ptz[final_mean_ktemp_Ptz > 50] = np.nan
|
||
# heights 为每个高度的值
|
||
heights = np.linspace(0, 60, 3000)
|
||
df_final_mean_ktemp_Ptz = pd.DataFrame(final_mean_ktemp_Ptz)
|
||
df_final_mean_ktemp_Nz = pd.DataFrame(final_mean_ktemp_Nz)
|
||
# -------------------------------------------------绘制年统计图------------------------------------
|
||
# -----------绘制浮力频率年统计图-----------------------
|
||
data = df_final_mean_ktemp_Nz.T
|
||
# 对每个元素进行平方(计算N2)
|
||
data = data ** 2
|
||
data = data*10000 # (绘图好看)
|
||
# 将大于 10 的值替换为 NaN(个别异常值)
|
||
data[data > 10] = np.nan
|
||
return df_final_mean_ktemp_Nz, df_final_mean_ktemp_Ptz, data, heights
|
||
# 绘制热力图的函数
|
||
|
||
|
||
def plot_heatmap(data, heights, title):
|
||
plt.figure(figsize=(10, 8))
|
||
# 绘制热力图,数据中的行代表高度,列代表天数
|
||
sns.heatmap(data, cmap='coolwarm', xticklabels=1,
|
||
yticklabels=50, cbar_kws={'label': 'Value'})
|
||
plt.xlabel('天')
|
||
plt.ylabel('高度 (km)')
|
||
plt.title(title)
|
||
num_days = data.shape[1]
|
||
x_tick_positions = np.arange(0, num_days, 30)
|
||
x_tick_labels = np.arange(0, num_days, 30)
|
||
plt.xticks(x_tick_positions, x_tick_labels)
|
||
# 设置y轴的刻度,使其显示对应的高度
|
||
plt.yticks(np.linspace(
|
||
0, data.shape[0] - 1, 6), np.round(np.linspace(heights[0], heights[-1], 6), 2))
|
||
# 反转 y 轴,使 0 在底部
|
||
plt.gca().invert_yaxis()
|
||
plt.show()
|
||
|
||
|
||
class GravityMultidayPlot:
|
||
def __init__(self, year, day_range, lat_range=None):
|
||
|
||
self.year = year
|
||
df_final_mean_ktemp_Nz, df_final_mean_ktemp_Ptz, data, heights = get_multiday_data(
|
||
year, day_range, lat_range)
|
||
self.df_final_mean_ktemp_Nz = df_final_mean_ktemp_Nz
|
||
self.df_final_mean_ktemp_Ptz = df_final_mean_ktemp_Ptz
|
||
self.data = data
|
||
self.data1 = df_final_mean_ktemp_Ptz.T
|
||
self.heights = heights
|
||
|
||
def plot_heatmap_tempNz(self):
|
||
# 调用函数绘制热力图
|
||
data = self.data
|
||
plot_heatmap(data, self.heights,
|
||
'"最终平均温度场的热图 (单位:10^(-4)),随浮力频率Nz变化')
|
||
|
||
def plot_heatmap_tempPtz(self):
|
||
# -------------绘制重力势能年统计图------------------------------------------------
|
||
data = self.df_final_mean_ktemp_Ptz.T
|
||
plot_heatmap(data, self.heights,
|
||
'最终平均动能温度-位势温度的热图 (单位:焦耳/千克)')
|
||
|
||
def plot_monthly_tempNz(self):
|
||
# ------------------------绘制月统计图---------------------------------------------------------------------------------
|
||
# ----------绘制浮力频率月统计图-------------------------------------------------
|
||
# 获取总列数
|
||
data = self.data
|
||
num_columns = data.shape[1]
|
||
# 按30列分组计算均值
|
||
averaged_df = []
|
||
# 逐步处理每30列
|
||
for i in range(0, num_columns, 30):
|
||
# 获取当前范围内的列,并计算均值
|
||
subset = data.iloc[:, i:i+30] # 获取第i到i+29列
|
||
mean_values = subset.mean(axis=1) # 对每行计算均值
|
||
averaged_df.append(mean_values) # 将均值添加到列表
|
||
# 将结果转化为一个新的 DataFrame
|
||
averaged_df = pd.DataFrame(averaged_df).T
|
||
# 1. 每3000行取一个均值
|
||
# 获取总行数
|
||
num_rows = averaged_df.shape[0]
|
||
# 创建一个新的列表来存储每3000行的均值
|
||
averaged_by_rows_df = []
|
||
# 逐步处理每3000行
|
||
for i in range(0, num_rows, 3000):
|
||
# 获取当前范围内的行
|
||
subset = averaged_df.iloc[i:i+3000, :] # 获取第i到i+99行
|
||
mean_values = subset.mean(axis=0) # 对每列计算均值
|
||
averaged_by_rows_df.append(mean_values) # 将均值添加到列表
|
||
# 将结果转化为一个新的 DataFrame
|
||
averaged_by_rows_df = pd.DataFrame(averaged_by_rows_df)
|
||
# 绘制折线图
|
||
plt.figure(figsize=(10, 6)) # 设置图形的大小
|
||
plt.plot(averaged_by_rows_df.columns, averaged_by_rows_df.mean(
|
||
axis=0), marker='o', color='b', label='平均值')
|
||
# 添加标题和标签
|
||
plt.title('每月平均N^2的折线图')
|
||
plt.xlabel('月份')
|
||
plt.ylabel('N^2(10^-4)')
|
||
plt.legend()
|
||
# 显示图形
|
||
plt.grid(True)
|
||
plt.xticks(rotation=45) # 让x轴标签(月份)倾斜,以便更清晰显示
|
||
plt.tight_layout()
|
||
|
||
def plot_monthly_energy(self):
|
||
data1 = self.data1
|
||
|
||
# ------------重力势能的月统计-----------------------------------
|
||
# 获取总列数
|
||
num_columns = data1.shape[1]
|
||
# 按30列分组计算均值
|
||
averaged_df = []
|
||
# 逐步处理每30列
|
||
for i in range(0, num_columns, 30):
|
||
# 获取当前范围内的列,并计算均值
|
||
subset = data1.iloc[:, i:i+30] # 获取第i到i+29列
|
||
mean_values = subset.mean(axis=1) # 对每行计算均值
|
||
averaged_df.append(mean_values) # 将均值添加到列表
|
||
# 将结果转化为一个新的 DataFrame
|
||
averaged_df = pd.DataFrame(averaged_df).T
|
||
# 1. 每3000行取一个均值
|
||
# 获取总行数
|
||
num_rows = averaged_df.shape[0]
|
||
# 创建一个新的列表来存储每3000行的均值
|
||
averaged_by_rows_df = []
|
||
# 逐步处理每3000行
|
||
for i in range(0, num_rows, 3000):
|
||
# 获取当前范围内的行
|
||
subset = averaged_df.iloc[i:i+3000, :] # 获取第i到i+99行
|
||
mean_values = subset.mean(axis=0) # 对每列计算均值
|
||
averaged_by_rows_df.append(mean_values) # 将均值添加到列表
|
||
# 将结果转化为一个新的 DataFrame
|
||
averaged_by_rows_df = pd.DataFrame(averaged_by_rows_df)
|
||
# 绘制折线图
|
||
plt.figure(figsize=(10, 6)) # 设置图形的大小
|
||
plt.plot(averaged_by_rows_df.columns, averaged_by_rows_df.mean(
|
||
axis=0), marker='o', color='b', label='平均值')
|
||
# 添加标题和标签
|
||
plt.title('每月平均重力势能的折线图')
|
||
plt.xlabel('月份')
|
||
plt.ylabel('重力势能(J/Kg)')
|
||
plt.legend()
|
||
# 显示图形
|
||
plt.grid(True)
|
||
plt.xticks(rotation=45) # 让x轴标签(月份)倾斜,以便更清晰显示
|
||
plt.tight_layout()
|
||
|
||
def plot_floatage_trend(self):
|
||
data = self.data
|
||
# 获取总列数
|
||
total_columns = data.shape[1]
|
||
# 用于存储每一组30列计算得到的均值列数据(最终会构成新的DataFrame)
|
||
mean_columns = []
|
||
# 分组序号,用于生成列名时区分不同的均值列,从1开始
|
||
group_index = 1
|
||
# 按照每30列一组进行划分(不滑动)
|
||
for start_col in range(0, total_columns, 30):
|
||
end_col = start_col + 30
|
||
if end_col > total_columns:
|
||
end_col = total_columns
|
||
# 选取当前组的30列(如果不足30列,按实际剩余列数选取)
|
||
group_data = data.iloc[:, start_col:end_col]
|
||
# 按行对当前组的列数据求和
|
||
sum_per_row = group_data.sum(axis=1)
|
||
# 计算平均(每一组的平均,每行都有一个平均结果)
|
||
mean_per_row = sum_per_row / (end_col - start_col)
|
||
# 生成新的列名,格式为'Mean_分组序号',例如'Mean_1'、'Mean_2'等
|
||
new_column_name = f'Mean_{group_index}'
|
||
group_index += 1
|
||
# 将当前组计算得到的均值列添加到列表中
|
||
mean_columns.append(mean_per_row)
|
||
# 将所有的均值列合并为一个新的DataFrame(列方向合并)
|
||
new_mean_df = pd.concat(mean_columns, axis=1)
|
||
# 按行对new_mean_df所有列的数据进行求和,得到一个Series,索引与new_mean_df的索引一致,每个元素是每行的总和
|
||
row_sums = new_mean_df.sum(axis=1)
|
||
# 计算所有行总和的均值
|
||
mean_value = row_sums.mean()
|
||
# 设置中文字体为黑体,解决中文显示问题(Windows系统下),如果是其他系统或者有其他字体需求可适当调整
|
||
plt.rcParams['font.sans-serif'] = ['SimHei']
|
||
# 解决负号显示问题
|
||
plt.rcParams['axes.unicode_minus'] = False
|
||
# 提取月份作为x轴标签(假设mean_value的索引就是月份信息)
|
||
print(mean_value)
|
||
months = mean_value.index.tolist()
|
||
# 提取均值数据作为y轴数据
|
||
energy_values = mean_value.tolist()
|
||
# 创建图形和坐标轴对象
|
||
fig, ax = plt.subplots(figsize=(10, 6))
|
||
# 绘制折线图
|
||
ax.plot(months, energy_values, marker='o', linestyle='-', color='b')
|
||
# 设置坐标轴标签和标题
|
||
ax.set_xlabel('月份')
|
||
ax.set_ylabel('平均浮力频率')
|
||
ax.set_title('每月浮力频率变化趋势')
|
||
# 设置x轴刻度,让其旋转一定角度以便更好地显示所有月份标签,避免重叠
|
||
plt.xticks(rotation=45)
|
||
# 显示网格线,增强图表可读性
|
||
ax.grid(True)
|
||
# 显示图形
|
||
|
||
def plot_floatage_trend(self):
|
||
data1 = self.data1
|
||
# --------------------------------绘制重力势能月统计图------------------------------
|
||
# 获取总列数
|
||
total_columns = data1.shape[1]
|
||
# 用于存储每一组30列计算得到的均值列数据(最终会构成新的DataFrame)
|
||
mean_columns = []
|
||
# 分组序号,用于生成列名时区分不同的均值列,从1开始
|
||
group_index = 1
|
||
# 按照每30列一组进行划分(不滑动)
|
||
for start_col in range(0, total_columns, 30):
|
||
end_col = start_col + 30
|
||
if end_col > total_columns:
|
||
end_col = total_columns
|
||
# 选取当前组的30列(如果不足30列,按实际剩余列数选取)
|
||
group_data = data1.iloc[:, start_col:end_col]
|
||
# 按行对当前组的列数据求和
|
||
sum_per_row = group_data.sum(axis=1)
|
||
# 计算平均(每一组的平均,每行都有一个平均结果)
|
||
mean_per_row = sum_per_row / (end_col - start_col)
|
||
# 生成新的列名,格式为'Mean_分组序号',例如'Mean_1'、'Mean_2'等
|
||
new_column_name = f'Mean_{group_index}'
|
||
group_index += 1
|
||
# 将当前组计算得到的均值列添加到列表中
|
||
mean_columns.append(mean_per_row)
|
||
# 将所有的均值列合并为一个新的DataFrame(列方向合并)
|
||
new_mean_df = pd.concat(mean_columns, axis=1)
|
||
# 按行对new_mean_df所有列的数据进行求和,得到一个Series,索引与new_mean_df的索引一致,每个元素是每行的总和
|
||
row_sums = new_mean_df.sum(axis=1)
|
||
# 计算所有行总和的均值
|
||
mean_value = row_sums.mean()
|
||
# 设置中文字体为黑体,解决中文显示问题(Windows系统下),如果是其他系统或者有其他字体需求可适当调整
|
||
plt.rcParams['font.sans-serif'] = ['SimHei']
|
||
# 解决负号显示问题
|
||
plt.rcParams['axes.unicode_minus'] = False
|
||
# 提取月份作为x轴标签(假设mean_value的索引就是月份信息)
|
||
months = mean_value.index.tolist()
|
||
# 提取均值数据作为y轴数据
|
||
energy_values = mean_value.tolist()
|
||
# 创建图形和坐标轴对象
|
||
fig, ax = plt.subplots(figsize=(10, 6))
|
||
# 绘制折线图
|
||
ax.plot(months, energy_values, marker='o', linestyle='-', color='b')
|
||
# 设置坐标轴标签和标题
|
||
ax.set_xlabel('月份')
|
||
ax.set_ylabel('平均浮力频率')
|
||
ax.set_title('每月浮力频率变化趋势')
|
||
# 设置x轴刻度,让其旋转一定角度以便更好地显示所有月份标签,避免重叠
|
||
plt.xticks(rotation=45)
|
||
# 显示网格线,增强图表可读性
|
||
ax.grid(True)
|
||
# 显示图形
|