This commit is contained in:
Dustella 2025-02-08 13:20:53 +08:00
commit e1b44ed2f9
Signed by: Dustella
GPG Key ID: 35AA0AA3DC402D5C
7 changed files with 268 additions and 0 deletions

7
.gitignore vendored Normal file
View File

@ -0,0 +1,7 @@
*.tar
*.ipynb
*.png
__pycache__
result
staged
res.md

11
backend.py Normal file
View File

@ -0,0 +1,11 @@
from flask import Flask
import src
# Flask application object for the backend.
app = Flask(__name__)

# Every route of the src.tc_module blueprint is served below the /tc prefix.
app.register_blueprint(src.tc_module, url_prefix='/tc')

if __name__ == '__main__':
    # Running this file directly starts Flask's built-in server on port 35000
    # with debug mode switched on.
    app.run(port=35000, debug=True)

0
data/.gitkeep Normal file
View File

49
src/__init__.py Normal file
View File

@ -0,0 +1,49 @@
from io import BytesIO
from flask import Blueprint, request, send_file
from src.plot import create_heatmap, find_centroids
from src.read_tar import TarGet
# Blueprint that collects the routes defined in this module; backend.py
# registers it on the Flask app under the /tc URL prefix.
tc_module = Blueprint('tc_module', __name__)
@tc_module.route("/metadata")
def get_files():
    """List the tar archives that are available to the other endpoints.

    Returns:
        The mapping produced by ``TarGet.get_tars()`` when called with its
        default directory.
    """
    available_archives = TarGet.get_tars()
    return available_archives
@tc_module.route("/render")
def render():
    """Render one slice of the global DTCP grid as a transparent PNG heatmap.

    Query parameters:
        path: Path of the tar archive to read (the values returned by the
            ``/metadata`` endpoint).
        day: Integer index along the third axis of the ``tc_global`` array.

    Returns:
        A ``image/png`` response with the heatmap, or a ``400`` response when
        a query parameter is missing or ``day`` is not an integer.
    """
    path = request.args.get('path')
    day = request.args.get('day')
    if path is None or day is None:
        return "Query parameters 'path' and 'day' are required", 400
    try:
        day_index = int(day)
    except ValueError:
        return "Query parameter 'day' must be an integer", 400

    # NOTE(review): 'path' is taken verbatim from the query string and opened
    # on the server; consider restricting it to the data directory.
    tar = TarGet(path)
    try:
        data = tar.read_dtcp()
    finally:
        # The array is fully loaded by now, so the archive handle can be
        # released; otherwise every request leaks an open file.
        tar.tar.close()

    data_this_day = data["data"][:, :, day_index]
    heatmap = create_heatmap(data_this_day, shift=(
        0, 180), colormap='hot', rotate_90=1)

    # Serialise the image into memory and rewind so send_file reads it from
    # the beginning.
    buff = BytesIO()
    heatmap.save(buff, format="PNG")
    buff.seek(0)
    return send_file(buff, mimetype='image/png')
@tc_module.route("/metadata/centroids")
def get_centroids():
    """Return the centroids of the connected regions in one grid slice.

    Query parameters:
        path: Path of the tar archive to read (the values returned by the
            ``/metadata`` endpoint).
        day: Integer index along the third axis of the ``tc_global`` array.

    Returns:
        The list of ``(y, x)`` centroids computed by ``find_centroids``, or a
        ``400`` response when a query parameter is missing or ``day`` is not
        an integer.
    """
    path = request.args.get('path')
    day = request.args.get('day')
    if path is None or day is None:
        return "Query parameters 'path' and 'day' are required", 400
    try:
        day_index = int(day)
    except ValueError:
        return "Query parameter 'day' must be an integer", 400

    # NOTE(review): 'path' is taken verbatim from the query string and opened
    # on the server; consider restricting it to the data directory.
    tar = TarGet(path)
    try:
        data = tar.read_dtcp()
    finally:
        # Release the archive handle once the array is in memory so requests
        # do not accumulate open files.
        tar.tar.close()

    data_this_day = data["data"][:, :, day_index]
    centroids = find_centroids(data_this_day)
    return centroids
@tc_module.route("/metadata/tc")
def get_tcs():
    """Return every TC track stored in an archive as a list of dicts.

    Query parameters:
        path: Path of the tar archive to read (the values returned by the
            ``/metadata`` endpoint).

    Returns:
        A list with one ``TcData.to_dict()`` mapping per track, or a ``400``
        response when ``path`` is missing.
    """
    path = request.args.get('path')
    if path is None:
        return "Query parameter 'path' is required", 400

    # NOTE(review): 'path' is taken verbatim from the query string and opened
    # on the server; consider restricting it to the data directory.
    tar = TarGet(path)
    try:
        # read_tcs() returns a fully built list, so the archive can be closed
        # before the response is assembled.
        tracks = tar.read_tcs()
    finally:
        tar.tar.close()
    return [track.to_dict() for track in tracks]

96
src/plot.py Normal file
View File

@ -0,0 +1,96 @@
import numpy as np
from PIL import Image
from scipy import ndimage
def find_centroids(image):
    """Find the weighted centroid of every connected region in an image.

    Pixels greater than zero are treated as foreground and grouped into
    connected regions with ``scipy.ndimage.label`` (default structuring
    element). Each centroid is weighted by the pixel values of its region.

    Args:
        image: 2D numpy array; entries greater than zero belong to a region.

    Returns:
        list of tuples: ``[(y1, x1), (y2, x2), ...]`` giving the centroid of
        each region, in label order.
    """
    labels, region_count = ndimage.label(image > 0)

    def weighted_center(region_id):
        # Row/column indices of the pixels that belong to this region.
        rows, cols = np.nonzero(labels == region_id)
        # The pixel values themselves act as the averaging weights.
        pixel_weights = image[rows, cols]
        return (np.average(rows, weights=pixel_weights),
                np.average(cols, weights=pixel_weights))

    return [weighted_center(region_id)
            for region_id in range(1, region_count + 1)]
def create_heatmap(data,
                   colormap='hot',
                   flip_vertical=False,
                   flip_horizontal=False,
                   rotate_90=0,
                   shift=(0, 0)
                   ):
    """Create a heatmap image with a transparent background.

    The steps are applied in this order: normalisation, shift, flips,
    rotation, colour mapping.

    Args:
        data: 2D numpy array. Values are expected in [0, 1]; if any finite
            value falls outside that range the array is min-max normalised.
            NaN entries are ignored for the normalisation and rendered fully
            transparent.
        colormap: Colour scheme; only ``'hot'`` (red to yellow) is supported.
        flip_vertical: Flip the array top to bottom.
        flip_horizontal: Flip the array left to right.
        rotate_90: Number of clockwise 90-degree rotations (taken modulo 4).
        shift: ``(horizontal, vertical)`` circular shift in pixels;
            ``shift[0]`` rolls along axis 1 and ``shift[1]`` along axis 0.

    Returns:
        PIL.Image: an RGBA image whose alpha channel scales with the value.

    Raises:
        ValueError: if ``colormap`` is not a supported scheme.
    """
    # Fail up front instead of hitting an unbound local during pixel filling.
    if colormap != 'hot':
        raise ValueError(f"Unsupported colormap: {colormap!r}")

    data = np.asarray(data, dtype=float)

    # Make sure the data lies in [0, 1]. The NaN-aware reductions keep a
    # single NaN from silently disabling the range check.
    if not np.isnan(data).all():
        low = np.nanmin(data)
        high = np.nanmax(data)
        if high > 1 or low < 0:
            span = high - low
            data = data - low
            # Constant data has zero span; leave it at 0 rather than 0/0.
            if span > 0:
                data = data / span

    if shift[1]:
        data = np.roll(data, shift[1], axis=0)  # vertical shift
    if shift[0]:
        data = np.roll(data, shift[0], axis=1)  # horizontal shift

    # Apply flips.
    if flip_vertical:
        data = np.flipud(data)
    if flip_horizontal:
        data = np.fliplr(data)

    # Apply rotation.
    rotate_90 = rotate_90 % 4  # keep within 0-3
    if rotate_90 != 0:
        # numpy's rot90 turns counter-clockwise, hence the negative count.
        data = np.rot90(data, k=-rotate_90)

    # NaN pixels become value 0, which maps to fully transparent black.
    values = np.where(np.isnan(data), 0.0, data)
    scaled = values * 510

    # Red ramps up over the lower half of the range, green over the upper
    # half; casting to uint8 truncates just like int() on the scalar value.
    red = np.clip(scaled, 0, 255).astype(np.uint8)
    green = np.clip(scaled - 255, 0, 255).astype(np.uint8)
    blue = np.zeros_like(red)
    alpha = np.clip(values * 255, 0, 255).astype(np.uint8)

    # A (height, width, 4) uint8 array is interpreted as an RGBA image.
    rgba = np.stack([red, green, blue, alpha], axis=-1)
    return Image.fromarray(rgba)

0
src/process_data.py Normal file
View File

105
src/read_tar.py Normal file
View File

@ -0,0 +1,105 @@
from dataclasses import dataclass
from datetime import datetime, timedelta
import glob
import os
import tarfile
from typing import List, Tuple
import numpy as np
import scipy.io
@dataclass
class TcData:
    """One TC track as read from a ``.mat`` file in the archive."""

    # Genesis time of the track.
    genesis: datetime
    # Time stamp of each point along the track.
    dates: List[datetime]
    # One coordinate pair per point along the track.
    # NOTE(review): presumably (lat, lon) or (lon, lat) in degrees - confirm.
    positions: List[Tuple[float, float]]

    def to_dict(self):
        """Return a plain dict with every datetime as an ISO 8601 string."""
        genesis_iso = self.genesis.isoformat()
        iso_dates = [moment.isoformat() for moment in self.dates]
        return dict(
            genesis=genesis_iso,
            dates=iso_dates,
            positions=self.positions,
        )
def date_converte(num):
    """Convert a MATLAB serial date number into a ``datetime``.

    The integer part selects the day and the fractional part the time of day.
    The value is converted with ``float`` rather than ``int`` so that the
    fraction survives; truncating first would always yield midnight.

    Args:
        num: MATLAB datenum (days, possibly fractional) as a number or a
            numpy scalar.

    Returns:
        datetime: the corresponding naive date and time.
    """
    matlab_datenum = float(num)
    whole_days = int(matlab_datenum)
    day_fraction = matlab_datenum % 1
    # MATLAB counts days from year 0 while Python ordinals start at year 1,
    # hence the 366-day offset.
    return (datetime.fromordinal(whole_days)
            + timedelta(days=day_fraction)
            - timedelta(days=366))
class TarGet:
    """Reader for one tar archive of MATLAB ``.mat`` files.

    The archive name must carry a date as its second ``_``-separated field,
    e.g. ``tc_20200101.tar``. The archive is opened on construction; release
    it with ``close()`` or by using the instance as a context manager::

        with TarGet(path) as reader:
            grid = reader.read_dtcp()
    """

    def __init__(self, path):
        """Open the archive at ``path`` and index its ``.mat`` members.

        Args:
            path: Path of the tar file; backslashes are accepted.

        Raises:
            IndexError, ValueError: if the file name holds no parsable date.
            OSError, tarfile.TarError: if the archive cannot be opened.
        """
        # Normalise Windows separators so the name can be split on '/'.
        self.path = path.replace('\\', '/')
        file_name = self.path.split('/')[-1]
        self.file_name = file_name
        # The date is something like 20200101.
        date_str = file_name.replace(".tar", "").split('_')[1]
        self.date = datetime.strptime(date_str, '%Y%m%d')
        self.tar = tarfile.open(self.path, 'r')
        filelist = [i for i in self.tar.getnames() if ".mat" in i]
        self.dtcp_filelist = [f for f in filelist if 'dtcp_' in f]
        self.tc_filelist = [f for f in filelist if 'tc_' in f]

    @classmethod
    def get_tars(cls, path="./data"):
        """Map the date key of every ``.tar`` file in ``path`` to its path.

        Keys and paths are derived from one directory listing so that they
        can never be paired with the wrong file; entries are sorted by name.

        Args:
            path: Directory to scan.

        Returns:
            ``{"data": {key: file_path}}`` where ``key`` is the file name
            without ``.tar`` and ``tc_``, and ``file_path`` uses forward
            slashes.
        """
        files = sorted(i for i in os.listdir(path) if i.endswith(".tar"))
        return {
            "data": {
                file.replace(".tar", "").replace("tc_", ""):
                    f"{path}/{file}".replace("\\", "/")
                for file in files
            }
        }

    def read_dtcp(self, mode=None):
        """Load the ``tc_global`` array from the archive's global DTCP file.

        Args:
            mode: Optional substring the member name must also contain.

        Returns:
            dict with ``"data"`` (the ``tc_global`` array) and ``"shape"``.

        Raises:
            ValueError: if the number of matching members is not exactly one.
        """
        global_files = [
            name for name in self.dtcp_filelist
            if "global" in name and (mode is None or mode in name)
        ]
        if len(global_files) != 1:
            raise ValueError(f"Global file has {len(global_files)} files")
        with self.tar.extractfile(global_files[0]) as f:
            mat_file = scipy.io.loadmat(f)
        tc_global = mat_file["tc_global"]
        return {
            "data": tc_global,
            "shape": tc_global.shape
        }

    def read_tcs(self):
        """Read every track from the ``tc_`` members of the archive.

        Entries whose genesis field is empty or zero are skipped.

        Returns:
            List[TcData]: one item per valid track, in archive order.
        """
        tcs: List[TcData] = []
        for tc_file in self.tc_filelist:
            with self.tar.extractfile(tc_file) as f:
                mat_file = scipy.io.loadmat(f)
            tc = mat_file["tc"]
            for casts in tc[0]:
                # Field 0 holds the genesis datenum; empty or zero marks an
                # unused slot.
                if casts[0][0].size == 0:
                    continue
                raw_genesis = casts[0][0][0]
                if raw_genesis == 0:
                    continue
                genesis = date_converte(raw_genesis)
                dates = [date_converte(d) for d in casts[1][0]]
                # Fields 2 and 3 store the coordinates scaled by 10.
                positions = [(float(p[0]) / 10, float(p[1]) / 10)
                             for p in zip(casts[2][0], casts[3][0])]
                tcs.append(TcData(genesis, dates, positions))
        return tcs

    def close(self):
        """Close the underlying tar file."""
        self.tar.close()

    def __enter__(self):
        """Return the reader itself for use in a ``with`` statement."""
        return self

    def __exit__(self, *exc_info):
        """Close the archive on leaving a ``with`` block.

        Also callable with no arguments, as before.
        """
        self.close()