init
This commit is contained in:
commit
e1b44ed2f9
7
.gitignore
vendored
Normal file
7
.gitignore
vendored
Normal file
@ -0,0 +1,7 @@
|
||||
*.tar
|
||||
*.ipynb
|
||||
*.png
|
||||
__pycache__
|
||||
result
|
||||
staged
|
||||
res.md
|
||||
11
backend.py
Normal file
11
backend.py
Normal file
@ -0,0 +1,11 @@
|
||||
from flask import Flask
|
||||
|
||||
import src
|
||||
|
||||
|
||||
app = Flask(__name__)
|
||||
|
||||
app.register_blueprint(src.tc_module, url_prefix='/tc')
|
||||
|
||||
if __name__ == '__main__':
|
||||
app.run(debug=True, port=35000)
|
||||
0
data/.gitkeep
Normal file
0
data/.gitkeep
Normal file
49
src/__init__.py
Normal file
49
src/__init__.py
Normal file
@ -0,0 +1,49 @@
|
||||
from io import BytesIO
|
||||
from flask import Blueprint, request, send_file
|
||||
|
||||
from src.plot import create_heatmap, find_centroids
|
||||
from src.read_tar import TarGet
|
||||
|
||||
|
||||
tc_module = Blueprint('tc_module', __name__)
|
||||
|
||||
|
||||
@tc_module.route("/metadata")
|
||||
def get_files():
|
||||
return TarGet.get_tars()
|
||||
|
||||
|
||||
@tc_module.route("/render")
|
||||
def render():
|
||||
path = request.args.get('path')
|
||||
day = request.args.get('day')
|
||||
tar = TarGet(path)
|
||||
data = tar.read_dtcp()
|
||||
data_this_day = data["data"][:, :, int(day)]
|
||||
|
||||
map = create_heatmap(data_this_day, shift=(
|
||||
0, 180), colormap='hot', rotate_90=1)
|
||||
|
||||
buff = BytesIO()
|
||||
map.save(buff, format="PNG")
|
||||
buff.seek(0)
|
||||
return send_file(buff, mimetype='image/png')
|
||||
|
||||
|
||||
@tc_module.route("/metadata/centroids")
|
||||
def get_centroids():
|
||||
path = request.args.get('path')
|
||||
day = request.args.get('day')
|
||||
tar = TarGet(path)
|
||||
data = tar.read_dtcp()
|
||||
data_this_day = data["data"][:, :, int(day)]
|
||||
|
||||
centords = find_centroids(data_this_day)
|
||||
return centords
|
||||
|
||||
|
||||
@tc_module.route("/metadata/tc")
|
||||
def get_tcs():
|
||||
path = request.args.get('path')
|
||||
tar = TarGet(path)
|
||||
return [i.to_dict() for i in tar.read_tcs()]
|
||||
96
src/plot.py
Normal file
96
src/plot.py
Normal file
@ -0,0 +1,96 @@
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from scipy import ndimage
|
||||
|
||||
|
||||
def find_centroids(image):
|
||||
"""
|
||||
找出图像中所有联通区域的重心
|
||||
|
||||
Args:
|
||||
image: 2D numpy array,值为0或非0
|
||||
|
||||
Returns:
|
||||
list of tuples: [(y1,x1), (y2,x2), ...] 表示每个联通区域的重心坐标
|
||||
"""
|
||||
# 标记联通区域
|
||||
labeled_array, num_features = ndimage.label(image > 0)
|
||||
|
||||
centroids = []
|
||||
for i in range(1, num_features + 1):
|
||||
# 获取当前区域的mask
|
||||
mask = labeled_array == i
|
||||
|
||||
# 获取当前区域的坐标点
|
||||
y_coords, x_coords = np.where(mask)
|
||||
|
||||
# 使用像素值作为权重计算重心
|
||||
weights = image[mask]
|
||||
centroid_y = np.average(y_coords, weights=weights)
|
||||
centroid_x = np.average(x_coords, weights=weights)
|
||||
|
||||
centroids.append((centroid_y, centroid_x))
|
||||
|
||||
return centroids
|
||||
|
||||
|
||||
def create_heatmap(data,
|
||||
colormap='hot',
|
||||
flip_vertical=False,
|
||||
flip_horizontal=False,
|
||||
rotate_90=0,
|
||||
shift=(0, 0)
|
||||
):
|
||||
"""
|
||||
创建热图,背景透明
|
||||
|
||||
参数:
|
||||
data: 2D numpy array,值应该在0-1之间
|
||||
colormap: 颜色映射方案,目前支持'hot'
|
||||
flip_vertical: 是否垂直翻转
|
||||
flip_horizontal: 是否水平翻转
|
||||
rotate_90: 顺时针旋转的90度次数 (0, 1, 2, 或 3)
|
||||
|
||||
返回:
|
||||
PIL.Image: RGBA格式的图像
|
||||
"""
|
||||
# 确保数据在0-1之间
|
||||
if data.max() > 1 or data.min() < 0:
|
||||
data = (data - data.min()) / (data.max() - data.min())
|
||||
if shift != (0, 0):
|
||||
data = np.roll(data, shift[1], axis=0) # 纵向平移
|
||||
data = np.roll(data, shift[0], axis=1) # 横向平移
|
||||
|
||||
# 应用翻转
|
||||
if flip_vertical:
|
||||
data = np.flipud(data)
|
||||
if flip_horizontal:
|
||||
data = np.fliplr(data)
|
||||
|
||||
# 应用旋转
|
||||
rotate_90 = rotate_90 % 4 # 确保在0-3之间
|
||||
if rotate_90 != 0:
|
||||
data = np.rot90(data, k=-rotate_90) # numpy的rot90是逆时针,所以用负号
|
||||
|
||||
# 创建RGBA图像
|
||||
height, width = data.shape
|
||||
image = Image.new('RGBA', (width, height), (0, 0, 0, 0))
|
||||
|
||||
# 创建颜色映射
|
||||
if colormap == 'hot':
|
||||
# 红色到黄色的渐变
|
||||
def get_color(value):
|
||||
if np.isnan(value):
|
||||
return (0, 0, 0, 0)
|
||||
r = int(min(255, value * 510))
|
||||
g = int(max(0, min(255, value * 510 - 255)))
|
||||
return (r, g, 0, int(value * 255))
|
||||
|
||||
# 填充像素
|
||||
pixels = image.load()
|
||||
for y in range(height):
|
||||
for x in range(width):
|
||||
pixels[x, y] = get_color(data[y, x])
|
||||
|
||||
return image
|
||||
0
src/process_data.py
Normal file
0
src/process_data.py
Normal file
105
src/read_tar.py
Normal file
105
src/read_tar.py
Normal file
@ -0,0 +1,105 @@
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta
|
||||
import glob
|
||||
import os
|
||||
import tarfile
|
||||
from typing import List, Tuple
|
||||
|
||||
import numpy as np
|
||||
import scipy.io
|
||||
|
||||
|
||||
@dataclass
|
||||
class TcData:
|
||||
genesis: datetime
|
||||
dates: List[datetime]
|
||||
positions: List[Tuple[float, float]]
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
"genesis": self.genesis.isoformat(),
|
||||
"dates": [d.isoformat() for d in self.dates],
|
||||
"positions": self.positions
|
||||
}
|
||||
|
||||
|
||||
def date_converte(num):
|
||||
matlab_datenum = int(num)
|
||||
return datetime.fromordinal(
|
||||
int(matlab_datenum)) + timedelta(days=matlab_datenum % 1) - timedelta(days=366)
|
||||
|
||||
|
||||
class TarGet:
|
||||
def __init__(self, path):
|
||||
# normalize the path, and replaces \\ with /
|
||||
self.path = path.replace('\\', '/')
|
||||
file_name = self.path.split('/')[-1]
|
||||
self.file_name = file_name
|
||||
# date is something like 20200101
|
||||
date_str = file_name.replace(".tar", "").split('_')[1]
|
||||
date = datetime.strptime(date_str, '%Y%m%d')
|
||||
self.date = date
|
||||
|
||||
tar = tarfile.open(self.path, 'r')
|
||||
self.tar = tar
|
||||
filelist = [i for i in tar.getnames() if ".mat" in i]
|
||||
self.dtcp_filelist = [f for f in filelist if 'dtcp_' in f]
|
||||
self.tc_filelist = [f for f in filelist if 'tc_' in f]
|
||||
|
||||
@classmethod
|
||||
def get_tars(self, path="./data"):
|
||||
files = [i for i in os.listdir(path) if i.endswith(".tar")]
|
||||
pathes = glob.glob(f"{path}/*.tar")
|
||||
all_days = [file.replace(".tar", "").replace("tc_", "")
|
||||
for file in files]
|
||||
return {
|
||||
"data": {
|
||||
key: value.replace("\\", "/") for key, value in zip(all_days, pathes)
|
||||
}
|
||||
}
|
||||
|
||||
def read_dtcp(self, mode=None):
|
||||
|
||||
if mode is not None:
|
||||
global_files = list(
|
||||
filter(lambda a: mode in a and "global" in a, self.dtcp_filelist))
|
||||
else:
|
||||
global_files = list(
|
||||
filter(lambda a: "global" in a, self.dtcp_filelist))
|
||||
|
||||
if global_files.__len__() != 1:
|
||||
raise Exception(f"Global file has {global_files.__len__()} files")
|
||||
global_file = global_files[0]
|
||||
|
||||
with self.tar.extractfile(global_file) as f:
|
||||
mat_file = scipy.io.loadmat(f)
|
||||
|
||||
return {
|
||||
"data": mat_file["tc_global"],
|
||||
"shape": mat_file['tc_global'].shape
|
||||
}
|
||||
|
||||
def read_tcs(self):
|
||||
tcs: List[TcData] = []
|
||||
for tc_file in self.tc_filelist:
|
||||
with self.tar.extractfile(tc_file) as f:
|
||||
mat_file = scipy.io.loadmat(f)
|
||||
tc = mat_file["tc"]
|
||||
for casts in tc[0]:
|
||||
if casts[0][0].size == 0:
|
||||
continue
|
||||
test_genesis = casts[0][0][0]
|
||||
if test_genesis == 0:
|
||||
continue
|
||||
genesis = date_converte(casts[0][0][0])
|
||||
|
||||
dates = [date_converte(d) for d in casts[1][0]]
|
||||
positions = [(float(p[0])/10, float(p[1])/10)
|
||||
for p in zip(casts[2][0], casts[3][0])]
|
||||
data = TcData(genesis, dates, positions)
|
||||
tcs.append(data)
|
||||
|
||||
return tcs
|
||||
|
||||
def __exit__(self):
|
||||
self.tar.close()
|
||||
Loading…
x
Reference in New Issue
Block a user