From e1b44ed2f916da57e35fd5d9cf72a58029435ecb Mon Sep 17 00:00:00 2001
From: Dustella
Date: Sat, 8 Feb 2025 13:20:53 +0800
Subject: [PATCH] init

---
 .gitignore          |   7 +++
 backend.py          |  13 ++++
 data/.gitkeep       |   0
 src/__init__.py     |  67 ++++++++++++++++++++
 src/plot.py         | 108 ++++++++++++++++++++++++++++++++
 src/process_data.py |   0
 src/read_tar.py     | 144 ++++++++++++++++++++++++++++++++++++++++++++
 7 files changed, 339 insertions(+)
 create mode 100644 .gitignore
 create mode 100644 backend.py
 create mode 100644 data/.gitkeep
 create mode 100644 src/__init__.py
 create mode 100644 src/plot.py
 create mode 100644 src/process_data.py
 create mode 100644 src/read_tar.py

diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..c0d0230
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,7 @@
+*.tar
+*.ipynb
+*.png
+__pycache__
+result
+staged
+res.md
\ No newline at end of file
diff --git a/backend.py b/backend.py
new file mode 100644
--- /dev/null
+++ b/backend.py
@@ -0,0 +1,13 @@
+"""Development entry point: mounts the ``tc_module`` blueprint under /tc."""
+from flask import Flask
+
+import src
+
+
+app = Flask(__name__)
+
+app.register_blueprint(src.tc_module, url_prefix='/tc')
+
+if __name__ == '__main__':
+    # debug=True enables the Werkzeug reloader/debugger; development only.
+    app.run(debug=True, port=35000)
diff --git a/data/.gitkeep b/data/.gitkeep
new file mode 100644
index 0000000..e69de29
diff --git a/src/__init__.py b/src/__init__.py
new file mode 100644
--- /dev/null
+++ b/src/__init__.py
@@ -0,0 +1,67 @@
+"""Flask blueprint exposing archive metadata, heatmaps and track data."""
+from io import BytesIO
+
+from flask import Blueprint, abort, request, send_file
+
+from src.plot import create_heatmap, find_centroids
+from src.read_tar import TarGet
+
+
+tc_module = Blueprint('tc_module', __name__)
+
+
+def _load_day_slice():
+    """Return the 2D slice selected by the ``path`` and ``day`` query args.
+
+    Aborts with HTTP 400 when an argument is missing, ``day`` is not an
+    integer, or ``day`` is outside the third axis of the stored array.
+    """
+    # NOTE(security): ``path`` is client-supplied and opened as-is; restrict
+    # it to the data directory before exposing this service publicly.
+    path = request.args.get('path')
+    day = request.args.get('day', type=int)
+    if not path or day is None:
+        abort(400, description="'path' and an integer 'day' are required")
+
+    # Release the archive handle as soon as the array has been read.
+    with TarGet(path) as tar:
+        data = tar.read_dtcp()["data"]
+
+    days = data.shape[2]
+    if not -days <= day < days:
+        abort(400, description="'day' is out of range")
+    return data[:, :, day]
+
+
+@tc_module.route("/metadata")
+def get_files():
+    """List the available archives as ``{"data": {key: path}}``."""
+    return TarGet.get_tars()
+
+
+@tc_module.route("/render")
+def render():
+    """Render the requested day slice as a transparent PNG heatmap."""
+    heatmap = create_heatmap(_load_day_slice(), shift=(0, 180),
+                             colormap='hot', rotate_90=1)
+
+    buff = BytesIO()
+    heatmap.save(buff, format="PNG")
+    buff.seek(0)
+    return send_file(buff, mimetype='image/png')
+
+
+@tc_module.route("/metadata/centroids")
+def get_centroids():
+    """Return the weighted centroids of the requested day slice."""
+    return find_centroids(_load_day_slice())
+
+
+@tc_module.route("/metadata/tc")
+def get_tcs():
+    """Return every track stored in the requested archive."""
+    path = request.args.get('path')
+    if not path:
+        abort(400, description="'path' is required")
+    with TarGet(path) as tar:
+        return [i.to_dict() for i in tar.read_tcs()]
diff --git a/src/plot.py b/src/plot.py
new file mode 100644
--- /dev/null
+++ b/src/plot.py
@@ -0,0 +1,108 @@
+"""Image helpers: connected-region centroids and RGBA heatmap rendering."""
+import numpy as np
+from PIL import Image
+
+from scipy import ndimage
+
+
+def find_centroids(image):
+    """
+    Find the weighted centroid of every connected region in an image.
+
+    Args:
+        image: 2D numpy array; cells greater than 0 form the regions and
+            their values are used as weights.
+
+    Returns:
+        list of tuples: [(y1, x1), (y2, x2), ...], one centroid per region.
+    """
+    image = np.asarray(image)
+    # Label the connected regions.
+    labeled_array, num_features = ndimage.label(image > 0)
+
+    centroids = []
+    for i in range(1, num_features + 1):
+        # Mask and pixel coordinates of the current region.
+        mask = labeled_array == i
+        y_coords, x_coords = np.where(mask)
+
+        # Use the pixel values as weights for the centroid.
+        weights = image[mask]
+        centroid_y = np.average(y_coords, weights=weights)
+        centroid_x = np.average(x_coords, weights=weights)
+
+        # Plain floats keep the result JSON-serialisable.
+        centroids.append((float(centroid_y), float(centroid_x)))
+
+    return centroids
+
+
+def create_heatmap(data,
+                   colormap='hot',
+                   flip_vertical=False,
+                   flip_horizontal=False,
+                   rotate_90=0,
+                   shift=(0, 0)
+                   ):
+    """
+    Create a heatmap image with a transparent background.
+
+    Args:
+        data: 2D numpy array. If any value lies outside 0-1 the whole
+            array is min-max rescaled; NaN cells are drawn transparent.
+        colormap: colour scheme; only 'hot' is supported.
+        flip_vertical: flip the data upside down.
+        flip_horizontal: flip the data left to right.
+        rotate_90: number of clockwise 90-degree rotations (0, 1, 2 or 3).
+        shift: (horizontal, vertical) circular shift in cells, applied
+            before flipping and rotating.
+
+    Returns:
+        PIL.Image: image in RGBA mode.
+
+    Raises:
+        ValueError: if ``colormap`` is unsupported or ``data`` is not 2D.
+    """
+    if colormap != 'hot':
+        raise ValueError(f"Unsupported colormap: {colormap!r}")
+
+    data = np.asarray(data, dtype=float)
+    if data.ndim != 2:
+        raise ValueError(f"data must be 2D, got {data.ndim} dimensions")
+
+    # Rescale into 0-1 when needed. NaNs are ignored so a missing cell does
+    # not disable the rescale, and constant data maps to 0 instead of
+    # dividing by zero.
+    valid = data[~np.isnan(data)]
+    if valid.size:
+        low, high = valid.min(), valid.max()
+        if high > 1 or low < 0:
+            span = high - low
+            data = (data - low) / span if span > 0 else data * 0.0
+
+    if shift != (0, 0):
+        data = np.roll(data, shift[1], axis=0)  # vertical shift
+        data = np.roll(data, shift[0], axis=1)  # horizontal shift
+
+    # Apply flips.
+    if flip_vertical:
+        data = np.flipud(data)
+    if flip_horizontal:
+        data = np.fliplr(data)
+
+    # Apply rotation.
+    rotate_90 = rotate_90 % 4  # keep within 0-3
+    if rotate_90 != 0:
+        # np.rot90 rotates counter-clockwise, hence the negative count.
+        data = np.rot90(data, k=-rotate_90)
+
+    # 'hot' ramp: red rises over the lower half of the range, green over
+    # the upper half, and alpha follows the value. NaN cells become fully
+    # transparent. Built as one array instead of a per-pixel Python loop.
+    value = np.where(np.isnan(data), 0.0, data)
+    red = np.clip(value * 510, 0, 255)
+    green = np.clip(value * 510 - 255, 0, 255)
+    alpha = np.clip(value * 255, 0, 255)
+    rgba = np.stack([red, green, np.zeros_like(value), alpha], axis=-1)
+
+    return Image.fromarray(rgba.astype(np.uint8))
diff --git a/src/process_data.py b/src/process_data.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/read_tar.py b/src/read_tar.py
new file mode 100644
--- /dev/null
+++ b/src/read_tar.py
@@ -0,0 +1,144 @@
+"""Readers for the ``.mat`` files packed inside the data tar archives."""
+from dataclasses import dataclass
+from datetime import datetime, timedelta
+import glob
+import os
+import tarfile
+from typing import List, Tuple
+
+import numpy as np
+import scipy.io
+
+
+@dataclass
+class TcData:
+    """One track: a genesis time plus the time and position of each point."""
+    genesis: datetime
+    dates: List[datetime]
+    positions: List[Tuple[float, float]]
+
+    def to_dict(self):
+        """Return a JSON-friendly dict with ISO-8601 timestamps."""
+        return {
+            "genesis": self.genesis.isoformat(),
+            "dates": [d.isoformat() for d in self.dates],
+            "positions": self.positions
+        }
+
+
+def date_converte(num):
+    """
+    Convert a MATLAB datenum to a ``datetime``.
+
+    The integer part selects the day and the fractional part the time of
+    day; 366 days bridge the MATLAB and Python ordinal epochs.
+    """
+    # Keep the fraction: truncating to int first would drop the time of day.
+    matlab_datenum = float(num)
+    whole_days = int(matlab_datenum)
+    return (datetime.fromordinal(whole_days)
+            + timedelta(days=matlab_datenum - whole_days)
+            - timedelta(days=366))
+
+
+class TarGet:
+    """Open one data archive and read the ``.mat`` files stored in it.
+
+    Works as a context manager so the tar handle is released promptly.
+    """
+
+    def __init__(self, path):
+        # normalize the path, and replaces \\ with /
+        self.path = path.replace('\\', '/')
+        file_name = self.path.split('/')[-1]
+        self.file_name = file_name
+        # date is something like 20200101
+        date_str = file_name.replace(".tar", "").split('_')[1]
+        date = datetime.strptime(date_str, '%Y%m%d')
+        self.date = date
+
+        tar = tarfile.open(self.path, 'r')
+        self.tar = tar
+        filelist = [i for i in tar.getnames() if ".mat" in i]
+        self.dtcp_filelist = [f for f in filelist if 'dtcp_' in f]
+        self.tc_filelist = [f for f in filelist if 'tc_' in f]
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type=None, exc_value=None, traceback=None):
+        self.close()
+
+    def close(self):
+        """Close the underlying tar file."""
+        self.tar.close()
+
+    @classmethod
+    def get_tars(cls, path="./data"):
+        """
+        Map every ``.tar`` file in ``path`` to its file path.
+
+        Keys are the file names with ".tar" and "tc_" removed. Key and path
+        come from the same directory entry, so they always belong together.
+        """
+        tars = {}
+        for name in sorted(os.listdir(path)):
+            if name.endswith(".tar"):
+                key = name.replace(".tar", "").replace("tc_", "")
+                tars[key] = f"{path}/{name}".replace("\\", "/")
+        return {"data": tars}
+
+    def read_dtcp(self, mode=None):
+        """
+        Load the single "global" dtcp array from the archive.
+
+        Args:
+            mode: optional substring the member name must also contain.
+
+        Returns:
+            dict with the ``tc_global`` array under "data" and its shape
+            under "shape".
+
+        Raises:
+            Exception: if the archive does not hold exactly one match.
+        """
+        global_files = [
+            name for name in self.dtcp_filelist
+            if "global" in name and (mode is None or mode in name)
+        ]
+
+        if len(global_files) != 1:
+            raise Exception(f"Global file has {len(global_files)} files")
+
+        with self.tar.extractfile(global_files[0]) as f:
+            mat_file = scipy.io.loadmat(f)
+
+        tc_global = mat_file["tc_global"]
+        return {
+            "data": tc_global,
+            "shape": tc_global.shape
+        }
+
+    def read_tcs(self):
+        """Parse every ``tc_`` member of the archive into ``TcData`` records."""
+        tcs: List[TcData] = []
+        for tc_file in self.tc_filelist:
+            with self.tar.extractfile(tc_file) as f:
+                mat_file = scipy.io.loadmat(f)
+            tc = mat_file["tc"]
+            for casts in tc[0]:
+                # Skip entries without a genesis value, or with a zero one.
+                if casts[0][0].size == 0:
+                    continue
+                raw_genesis = casts[0][0][0]
+                if raw_genesis == 0:
+                    continue
+                genesis = date_converte(raw_genesis)
+
+                dates = [date_converte(d) for d in casts[1][0]]
+                # Stored coordinates are divided by 10 on the way out.
+                positions = [(float(a) / 10, float(b) / 10)
+                             for a, b in zip(casts[2][0], casts[3][0])]
+                tcs.append(TcData(genesis, dates, positions))
+
+        return tcs