import os
from collections import defaultdict
from datetime import datetime, timedelta
from typing import Any, AnyStr, Dict, List, NamedTuple, Optional, Union

import numpy as np
import tensorflow as tf
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from postprocess import extract_picks
from pydantic import BaseModel
from scipy.interpolate import interp1d

from model import UNet
| |
|
# Absolute path of the directory one level above the one holding this file.
PROJECT_ROOT = os.path.realpath(os.path.join(os.path.dirname(__file__), os.pardir))

# The model below is driven through the TF1 graph/session API, so eager
# execution is turned off; TensorFlow logging is limited to errors.
tf.compat.v1.disable_eager_execution()
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)

# Loose aliases for JSON-shaped values.
JSONObject = Dict[AnyStr, Any]
JSONArray = List[Any]
JSONStructure = Union[JSONArray, JSONObject]

app = FastAPI()

# format_data uses X_SHAPE[0] as the number of samples per trace and
# X_SHAPE[-1] as the number of channels.
X_SHAPE = [3000, 1, 3]
# Samples per second; format_data derives dt as 1 / SAMPLING_RATE.
SAMPLING_RATE = 100
| |
|
| | |
# Build the network in prediction mode. This must happen before the Saver is
# created, because the Saver is handed the global variables that exist at
# that moment.
model = UNet(mode="pred")

# Session whose GPU options have allow_growth enabled.
sess_config = tf.compat.v1.ConfigProto()
sess_config.gpu_options.allow_growth = True
sess = tf.compat.v1.Session(config=sess_config)

# Initialise every variable first, then overwrite them with the values stored
# in the newest checkpoint found under the model directory.
saver = tf.compat.v1.train.Saver(tf.compat.v1.global_variables())
init = tf.compat.v1.global_variables_initializer()
sess.run(init)

latest_check_point = tf.train.latest_checkpoint(f"{PROJECT_ROOT}/model/190703-214543")
print(f"restoring model {latest_check_point}")
saver.restore(sess, latest_check_point)
| |
|
| |
|
def normalize_batch(data, window=3000):
    """
    Remove a sliding-window mean and divide by a sliding-window standard deviation.

    data: array of shape (nsta, nt, nch); statistics are taken along axis 1.
    window: length in samples of each statistics window. Windows are centred
        every window // 2 samples.

    Returns a new array with the same shape as data; the input is not modified.
    """
    half = window // 2
    n_sta, n_t, n_ch = data.shape

    # Reflect-pad the time axis so that windows centred near the edges are
    # still `window` samples long.
    padded = np.pad(data, ((0, 0), (half, half), (0, 0)), mode="reflect")

    # One statistics node every `half` samples, plus one extra slot for the
    # trailing node at n_t that is appended below.
    nodes = np.arange(0, n_t, half, dtype="int")
    n_nodes = len(nodes)
    sigma = np.zeros([n_sta, n_nodes + 1, n_ch])
    mu = np.zeros([n_sta, n_nodes + 1, n_ch])
    for k in range(1, n_nodes):
        start = k * half
        chunk = padded[:, start : start + window, :]
        sigma[:, k, :] = np.std(chunk, axis=1)
        mu[:, k, :] = np.mean(chunk, axis=1)
    nodes = np.append(nodes, n_t)

    # The first and last nodes are not computed; they copy their neighbours.
    sigma[:, -1, :] = sigma[:, -2, :]
    mu[:, -1, :] = mu[:, -2, :]
    sigma[:, 0, :] = sigma[:, 1, :]
    mu[:, 0, :] = mu[:, 1, :]
    # A zero deviation would divide by zero below; use 1 instead.
    sigma[sigma == 0] = 1

    # Piecewise-linear interpolation of the node statistics onto every sample.
    samples = np.arange(n_t, dtype="int")
    sigma_at = interp1d(nodes, sigma, axis=1, kind="slinear")(samples)
    mu_at = interp1d(nodes, mu, axis=1, kind="slinear")(samples)

    return (data - mu_at) / sigma_at
| |
|
| |
|
def preprocess(data):
    """
    Normalize a waveform batch and return it together with an untouched copy.

    When the normalized array is three-dimensional, both arrays get a
    singleton axis inserted at position 2.

    Returns (normalized, raw).
    """
    untouched = data.copy()
    normalized = normalize_batch(data)
    if normalized.ndim != 3:
        return normalized, untouched
    return normalized[:, :, np.newaxis, :], untouched[:, :, np.newaxis, :]
| |
|
| |
|
def calc_timestamp(timestamp, sec):
    """
    Shift a timestamp string by a number of seconds.

    timestamp: string in "%Y-%m-%dT%H:%M:%S.%f" form (the fractional part is
        required by the format).
    sec: offset in seconds; may be fractional or negative.

    Returns the shifted time in the same format cut to millisecond precision:
    the last three microsecond digits are dropped, not rounded.
    """
    fmt = "%Y-%m-%dT%H:%M:%S.%f"
    shifted = datetime.strptime(timestamp, fmt) + timedelta(seconds=sec)
    text = shifted.strftime(fmt)
    return text[:-3]
| |
|
| |
|
def format_picks(picks, dt, amplitudes):
    """
    Flatten per-record P and S picks into a list of dictionaries.

    picks: iterable of records exposing fname, t0 and the nested sequences
        p_idx, p_prob, s_idx, s_prob.
    dt: seconds per sample; a pick index times dt is the offset added to t0.
    amplitudes: iterable parallel to picks, exposing p_amp and s_amp.

    Returns a list of dicts with keys id, timestamp, prob, amp and type
    ("p" or "s"). For each record all P picks come before its S picks.
    """
    formatted = []
    for pick, amplitude in zip(picks, amplitudes):
        for phase in ("p", "s"):
            groups = zip(
                getattr(pick, f"{phase}_idx"),
                getattr(pick, f"{phase}_prob"),
                getattr(amplitude, f"{phase}_amp"),
            )
            for idxs, probs, amps in groups:
                formatted.extend(
                    {
                        "id": pick.fname,
                        "timestamp": calc_timestamp(pick.t0, float(idx) * dt),
                        "prob": prob,
                        "amp": amp,
                        "type": phase,
                    }
                    for idx, prob, amp in zip(idxs, probs, amps)
                )
    return formatted
| |
|
| |
|
def format_data(data):
    """
    Merge single-channel traces into one fixed-size multi-channel array per station.

    data: object exposing parallel sequences id, timestamp and vec. Each id is
        a string whose last character is the channel code (E/N/Z or 3/2/1) and
        whose remaining prefix is the station key. Each timestamp is a
        "%Y-%m-%dT%H:%M:%S.%f" string, taken as the time of the first sample
        of the matching vec entry.

    Returns a named tuple (id, timestamp, vec, dt) with one entry per station:
    the station key, the earliest channel start time cut to milliseconds, and
    an X_SHAPE[0] x X_SHAPE[-1] nested list. Each channel has its mean removed
    and is written at its sample offset from that earliest start time; samples
    that would fall beyond X_SHAPE[0] are dropped.

    Raises KeyError for an unknown channel code and ValueError for a timestamp
    that does not match the expected format.
    """
    time_format = "%Y-%m-%dT%H:%M:%S.%f"
    chn2idx = {"E": 0, "N": 1, "Z": 2, "3": 0, "2": 1, "1": 2}
    Data = NamedTuple("data", [("id", list), ("timestamp", list), ("vec", list), ("dt", float)])
    n_samples, n_channels = X_SHAPE[0], X_SHAPE[-1]

    # Group traces by station key (the id without its trailing channel code).
    stations = defaultdict(list)
    for i, trace_id in enumerate(data.id):
        start = datetime.strptime(data.timestamp[i], time_format)
        stations[trace_id[:-1]].append((trace_id[-1], start, np.array(data.vec[i])))

    id_ = []
    timestamp_ = []
    vec_ = []
    for key, traces in stations.items():
        earliest = min(start for _, start, _ in traces)
        id_.append(key)
        # Formatting the parsed datetime directly avoids a round trip through
        # local-time epoch floats.
        timestamp_.append(earliest.strftime(time_format)[:-3])

        vec = np.zeros([n_samples, n_channels])
        for channel, start, samples in traces:
            column = chn2idx[channel]
            # Exact datetime arithmetic, rounded to the nearest sample.
            # Truncating a float difference of large epoch values could turn
            # an exact N-sample offset into N - 1.
            shift = round((start - earliest).total_seconds() * SAMPLING_RATE)
            if shift >= n_samples:
                # The trace starts beyond the end of the buffer; nothing fits.
                continue
            segment = samples[: n_samples - shift]
            if segment.size == 0:
                continue
            vec[shift : shift + len(segment), column] = segment - np.mean(segment)
        vec_.append(vec.tolist())

    return Data(id=id_, timestamp=timestamp_, vec=vec_, dt=1 / SAMPLING_RATE)
| | |
| |
|
| |
|
def get_prediction(data, return_preds=False):
    """
    Run the network on a batch of waveforms and extract phase picks.

    data: object exposing vec (the waveform batch), id and timestamp.
    return_preds: when true, the raw network output is returned as well.

    Returns the list of pick dictionaries, reduced to the keys station_id,
    phase_time, phase_score, phase_type and dt; or the tuple (picks, preds)
    when return_preds is set.
    """
    kept_fields = ["station_id", "phase_time", "phase_score", "phase_type", "dt"]

    normalized, raw = preprocess(np.array(data.vec))

    # Inference only: zero drop rate, training flag off.
    preds = sess.run(
        model.preds,
        feed_dict={model.X: normalized, model.drop_rate: 0, model.is_training: False},
    )

    extracted = extract_picks(preds, station_ids=data.id, begin_times=data.timestamp, waveforms=raw)
    picks = []
    for pick in extracted:
        picks.append({field: value for field, value in pick.items() if field in kept_fields})

    if not return_preds:
        return picks
    return picks, preds
| |
|
| |
|
class Data(BaseModel):
    # Request payload accepted by the prediction endpoints.

    # Passed to extract_picks as station_ids.
    id: List[List[str]]
    # Passed to extract_picks as begin_times.
    timestamp: List[Union[str, float, datetime]]
    # Waveform samples; converted to a numpy array and fed to the model.
    vec: Union[List[List[List[float]]], List[List[float]]]

    # Optional fields below are accepted but not read by the code in this file.
    dt: Optional[float] = 0.01
    stations: Optional[List[Dict[str, Union[float, str]]]] = None
    config: Optional[Dict[str, Union[List[float], List[int], List[str], float, int, str]]] = None
| |
|
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| |
|
@app.post("/predict")
def predict(data: Data):
    # Respond with the picks only; the raw network output is not returned.
    return get_prediction(data)
| |
|
| |
|
@app.post("/predict_prob")
def predict(data: Data):
    # NOTE(review): this handler reuses the function name `predict` of the
    # /predict handler, so the module-level name ends up bound to this one.
    # Both routes stay registered because the decorator runs at definition time.
    picks, preds = get_prediction(data, return_preds=True)
    # The network output is converted to nested lists so it can be serialized.
    probabilities = preds.tolist()
    return picks, probabilities
| |
|
| |
|
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """
    Serve picks over a WebSocket: one JSON request in, one JSON list of picks out.

    Every received message is validated as a Data payload and passed to
    get_prediction; the resulting picks are sent back as JSON. The loop ends
    quietly when the client disconnects.
    """
    await websocket.accept()
    try:
        while True:
            payload = await websocket.receive_json()
            picks = get_prediction(Data(**payload))
            await websocket.send_json(picks)
            print("PhaseNet Updating...")
    except WebSocketDisconnect:
        # Raised once the peer closes the connection. That is the normal end
        # of a session, so it is not propagated as a handler error.
        return
| |
|
| |
|
@app.get("/healthz")
def healthz():
    # Health probe: returns a fixed payload and does not touch the model.
    payload = {"status": "ok"}
    return payload
| |
|
| |
|
@app.get("/")
def greet_json():
    # Root route: returns a fixed greeting payload.
    greeting = {"Hello": "PhaseNet!"}
    return greeting
| |
|