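# Training entry point for remote-sensing change detection, built on PyTorch
# Lightning. It trains a model described by a config file, validates every
# epoch, and evaluates the test split with per-class OA/Precision/Recall/F1/IoU.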
import argparse
import os
import sys
import time  # used to time the inference-FPS benchmark

sys.path.append('rscd')

import numpy as np
import prettytable
import torch
import torch.nn as nn
import torchmetrics
from pytorch_lightning import LightningModule, Trainer, seed_everything
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint, TQDMProgressBar
from pytorch_lightning.loggers import TensorBoardLogger
from tqdm import tqdm

from rscd.models.build_model import build_model
from rscd.datasets import build_dataloader
from rscd.optimizers import build_optimizer
from rscd.losses import build_loss
from utils.config import Config

seed_everything(1234, workers=True)
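
# Bilinearly resize a batch of label maps (N, H, W) to `size`; used by the
# SARASNet branches below, whose prediction resolution differs from the
# ground-truth masks.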
def resize_label(label, size):
    label = np.expand_dims(label, axis=0)
    label_resized = np.zeros((1, label.shape[1], size[0], size[1]))
    interp = nn.Upsample(size=(size[0], size[1]), mode='bilinear')
    label_resized[:, :, :, :] = interp(torch.from_numpy(label).float()).data.numpy()
    label_resized = np.array(label_resized, dtype=np.int32)
    return torch.from_numpy(np.squeeze(label_resized, axis=0)).float()
def get_args():
    parser = argparse.ArgumentParser(description='Change detection of remote sensing images')
    parser.add_argument("-c", "--config", type=str, default="configs/cdlamba.py")
    return parser.parse_args()
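
# LightningModule bundling the network, loss, torchmetrics collections for
# train/val/test, and the dataloaders, all built from the config object.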
class myTrain(LightningModule):
    def __init__(self, cfg, log_dir=None):
        super().__init__()
        self.cfg = cfg
        self.log_dir = log_dir
        self.net = build_model(cfg.model_config)
        self.loss = build_loss(cfg.loss_config)
        self.loss.to('cuda:{}'.format(cfg.gpus[0]))
        metric_cfg1 = cfg.metric_cfg1
        metric_cfg2 = cfg.metric_cfg2
        self.tr_oa = torchmetrics.Accuracy(**metric_cfg1)
        self.tr_prec = torchmetrics.Precision(**metric_cfg2)
        self.tr_recall = torchmetrics.Recall(**metric_cfg2)
        self.tr_f1 = torchmetrics.F1Score(**metric_cfg2)
        self.tr_iou = torchmetrics.JaccardIndex(**metric_cfg2)
        self.val_oa = torchmetrics.Accuracy(**metric_cfg1)
        self.val_prec = torchmetrics.Precision(**metric_cfg2)
        self.val_recall = torchmetrics.Recall(**metric_cfg2)
        self.val_f1 = torchmetrics.F1Score(**metric_cfg2)
        self.val_iou = torchmetrics.JaccardIndex(**metric_cfg2)
        self.test_oa = torchmetrics.Accuracy(**metric_cfg1)
        self.test_prec = torchmetrics.Precision(**metric_cfg2)
        self.test_recall = torchmetrics.Recall(**metric_cfg2)
        self.test_f1 = torchmetrics.F1Score(**metric_cfg2)
        self.test_iou = torchmetrics.JaccardIndex(**metric_cfg2)
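        # Best change-class F1 observed so far, one slot per test setting
        # (indexed by position in cfg.monitor_test; 10 is presumably an upper
        # bound on len(cfg.monitor_test)).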
self.test_max_f1 = [0 for _ in range(10)]
self.test_loader = build_dataloader(cfg.dataset_config, mode='test')
    def forward(self, x1, x2):
        pred = self.net(x1, x2)
        return pred
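
    # Build the optimizer and LR scheduler from the config; the 'monitor' key
    # tells Lightning which logged metric drives metric-based schedulers such
    # as ReduceLROnPlateau.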
    def configure_optimizers(self):
        optimizer, scheduler = build_optimizer(self.cfg.optimizer_config, self.net)
        return {'optimizer': optimizer, 'lr_scheduler': scheduler, 'monitor': self.cfg.monitor_val}
def train_dataloader(self):
loader = build_dataloader(self.cfg.dataset_config, mode='train')
return loader
def val_dataloader(self):
loader = build_dataloader(self.cfg.dataset_config, mode='val')
return loader
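
    # Render per-class and aggregate metrics as a PrettyTable, print it for
    # val/test, and append the table to a metrics .txt file in the log directory.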
    def output(self, metrics, total_metrics, mode, test_idx=0, test_value=None):
        result_table = prettytable.PrettyTable()
        result_table.field_names = ['Class', 'OA', 'Precision', 'Recall', 'F1_Score', 'IOU']
        for i in range(len(metrics[0])):
            item = [i, '--']
            for j in range(len(metrics)):
                item.append(np.round(metrics[j][i].cpu().numpy(), 4))
            result_table.add_row(item)
        total = list(total_metrics.values())
        total = [np.round(v, 4) for v in total]
        total.insert(0, 'total')
        result_table.add_row(total)
        if mode == 'val' or mode == 'test':
            print(mode)
            print(result_table)
        if self.log_dir:
            base_dir = self.log_dir
        else:
            base_dir = os.path.join('work_dirs', self.cfg.exp_name)
        if mode == 'test':
            if self.cfg.argmax:
                file_name = os.path.join(base_dir, "test_metrics_{}.txt".format(test_idx))
                # When the change-class F1 sets a new record, write to the "max" file instead.
                if metrics[2][1] > self.test_max_f1[test_idx]:
                    self.test_max_f1[test_idx] = metrics[2][1]
                    file_name = os.path.join(base_dir, "test_max_metrics_{}.txt".format(test_idx))
            else:
                file_name = os.path.join(base_dir, "test_metrics_{}_{}.txt".format(test_idx, str(test_value)))
                if metrics[2][1] > self.test_max_f1[test_idx]:
                    self.test_max_f1[test_idx] = metrics[2][1]
                    file_name = os.path.join(base_dir, "test_max_metrics_{}_{}.txt".format(test_idx, '%.1f' % test_value))
        else:
            file_name = os.path.join(base_dir, "train_metrics.txt")
        with open(file_name, "a") as f:
            f.write('epoch:{}/{} {}\n'.format(self.current_epoch, self.cfg.epoch, mode))
            f.write(str(result_table) + '\n')
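
    # Two decoding paths: with cfg.argmax the network emits multi-class logits
    # that are argmax-ed into a label map; otherwise a single-channel score map
    # is thresholded at 0.5. For 'maskcd' the model returns a pair: element 0 is
    # the score map and element 1 feeds the loss.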
    def training_step(self, batch, batch_idx):
        imgA, imgB, mask = batch[0], batch[1], batch[2]
        preds = self(imgA, imgB)
        if self.cfg.net == 'SARASNet':
            mask = resize_label(mask.data.cpu().numpy(),
                                size=preds.data.cpu().numpy().shape[2:]).to(preds.device).long()
            param = 1  # balances precision and recall to reach a higher F1-score
            preds[:, 1, :, :] = preds[:, 1, :, :] + param
        if self.cfg.argmax:
            loss = self.loss(preds, mask)
            pred = preds.argmax(dim=1)
        else:
            if self.cfg.net == 'maskcd':
                loss = self.loss(preds[1], mask)
                pred = preds[0]
                pred = pred > 0.5
                pred.squeeze_(1)
            else:
                pred = preds.squeeze(1)
                loss = self.loss(pred, mask)
                pred = pred > 0.5
        self.tr_oa(pred, mask)
        self.tr_prec(pred, mask)
        self.tr_recall(pred, mask)
        self.tr_f1(pred, mask)
        self.tr_iou(pred, mask)
        self.log('tr_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
        return loss
    def on_train_epoch_end(self):
        metrics = [self.tr_prec.compute(),
                   self.tr_recall.compute(),
                   self.tr_f1.compute(),
                   self.tr_iou.compute()]
        log = {'tr_oa': float(self.tr_oa.compute().cpu()),
               'tr_prec': np.mean([item.cpu() for item in metrics[0]]),
               'tr_recall': np.mean([item.cpu() for item in metrics[1]]),
               'tr_f1': np.mean([item.cpu() for item in metrics[2]]),
               'tr_miou': np.mean([item.cpu() for item in metrics[3]])}
        self.output(metrics, log, 'train')
        for key, value in log.items():
            self.log(key, value, on_step=False, on_epoch=True, prog_bar=True)
        # metrics[2][1] is the F1 of the change class (class index 1).
        self.log('tr_change_f1', metrics[2][1], on_step=False, on_epoch=True, prog_bar=True)
        self.tr_oa.reset()
        self.tr_prec.reset()
        self.tr_recall.reset()
        self.tr_f1.reset()
        self.tr_iou.reset()
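
    # Mirrors training_step, but accumulates the val_* metrics and logs val_loss.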
    def validation_step(self, batch, batch_idx):
        imgA, imgB, mask = batch[0], batch[1], batch[2]
        preds = self(imgA, imgB)
        if self.cfg.net == 'SARASNet':
            mask = resize_label(mask.data.cpu().numpy(),
                                size=preds.data.cpu().numpy().shape[2:]).to(preds.device).long()
            param = 1  # balances precision and recall to reach a higher F1-score
            preds[:, 1, :, :] = preds[:, 1, :, :] + param
        if self.cfg.argmax:
            loss = self.loss(preds, mask)
            pred = preds.argmax(dim=1)
        else:
            if self.cfg.net == 'maskcd':
                loss = self.loss(preds[1], mask)
                pred = preds[0]
                pred = pred > 0.5
                pred.squeeze_(1)
            else:
                pred = preds.squeeze(1)
                loss = self.loss(pred, mask)
                pred = pred > 0.5
        self.val_oa(pred, mask)
        self.val_prec(pred, mask)
        self.val_recall(pred, mask)
        self.val_f1(pred, mask)
        self.val_iou(pred, mask)
        self.log('val_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
        return loss
    def on_validation_epoch_end(self):
        metrics = [self.val_prec.compute(),
                   self.val_recall.compute(),
                   self.val_f1.compute(),
                   self.val_iou.compute()]
        log = {'val_oa': float(self.val_oa.compute().cpu()),
               'val_prec': np.mean([item.cpu() for item in metrics[0]]),
               'val_recall': np.mean([item.cpu() for item in metrics[1]]),
               'val_f1': np.mean([item.cpu() for item in metrics[2]]),
               'val_miou': np.mean([item.cpu() for item in metrics[3]])}
        self.output(metrics, log, 'val')
        for key, value in log.items():
            self.log(key, value, on_step=False, on_epoch=True, prog_bar=True)
        self.log('val_change_f1', metrics[2][1], on_step=False, on_epoch=True, prog_bar=True)
        self.val_oa.reset()
        self.val_prec.reset()
        self.val_recall.reset()
        self.val_f1.reset()
        self.val_iou.reset()
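        # After every validation epoch, also evaluate on the test split: once per
        # monitored test metric. Without argmax, a binarization threshold of
        # 0.2 + 0.1 * idx is passed to test() (consumed by the 'maskcd' path).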
        for idx in range(len(self.cfg.monitor_test)):
            if self.cfg.argmax:
                self.log(self.cfg.monitor_test[idx], self.test(idx), on_step=False, on_epoch=True, prog_bar=True)
            else:
                t = 0.2 + 0.1 * idx
                self.log(self.cfg.monitor_test[idx], self.test(idx, t), on_step=False, on_epoch=True, prog_bar=True)
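
    # Full pass over the test loader with the current weights; returns the
    # change-class F1 so it can be logged and used by the test checkpoints.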
    def test(self, idx, value=None):
        for batch in tqdm(self.test_loader):
            raw_predictions = self(batch[0].cuda(self.cfg.gpus[0]), batch[1].cuda(self.cfg.gpus[0]))
            mask_test = batch[2].cuda(self.cfg.gpus[0])
            if self.cfg.net == 'SARASNet':
                mask_test = resize_label(mask_test.data.cpu().numpy(),
                                         size=raw_predictions.data.cpu().numpy().shape[2:]).to(raw_predictions.device).long()
                param = 1  # balances precision and recall to reach a higher F1-score
                raw_predictions[:, 1, :, :] = raw_predictions[:, 1, :, :] + param
            if self.cfg.argmax:
                pred_test = raw_predictions.argmax(dim=1)
            else:
                if self.cfg.net == 'maskcd':
                    raw_prediction = raw_predictions[0]
                    pred_test = raw_prediction > value
                    pred_test.squeeze_(1)
                else:
                    pred_test = raw_predictions.squeeze(1)
                    pred_test = pred_test > 0.5
            self.test_oa(pred_test, mask_test)
            self.test_iou(pred_test, mask_test)
            self.test_prec(pred_test, mask_test)
            self.test_f1(pred_test, mask_test)
            self.test_recall(pred_test, mask_test)
metrics_test = [self.test_prec.compute(),
self.test_recall.compute(),
self.test_f1.compute(),
self.test_iou.compute()]
log = {'test_oa': float(self.test_oa.compute().cpu()),
'test_prec': np.mean([item.cpu() for item in metrics_test[0]]),
'test_recall': np.mean([item.cpu() for item in metrics_test[1]]),
'test_f1': np.mean([item.cpu() for item in metrics_test[2]]),
'test_miou': np.mean([item.cpu() for item in metrics_test[3]])}
self.output(metrics_test, log, 'test', idx, value)
self.test_oa.reset()
self.test_prec.reset()
self.test_recall.reset()
self.test_f1.reset()
self.test_iou.reset()
return metrics_test[2][1]
if __name__ == "__main__":
args = get_args()
cfg = Config.fromfile(args.config)
    logger = TensorBoardLogger(save_dir="work_dirs",
                               sub_dir='log',
                               name=cfg.exp_name,
                               default_hp_metric=False)
log_dir = os.path.dirname(logger.log_dir)
model = myTrain(cfg, log_dir)
    # ---- Inference-FPS benchmark inserted here ---- #
device = torch.device(f'cuda:{cfg.gpus[0]}' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
model.eval()
    # Fetch one batch from the validation dataloader
val_loader = model.val_dataloader()
batch_iter = iter(val_loader)
    try:
        batch = next(batch_iter)
        imgA_batch = batch[0]
        imgB_batch = batch[1]
    except StopIteration:
        raise RuntimeError("The validation dataloader is empty; check the dataset configuration.")
    # Move the inputs to the same device as the model
imgA_batch = imgA_batch.to(device)
imgB_batch = imgB_batch.to(device)
    # Warm up with 10 inference passes
with torch.no_grad():
for _ in range(10):
_ = model(imgA_batch, imgB_batch)
    # Time N inference passes
    N = 100
    if device.type == 'cuda':
        torch.cuda.synchronize(device)
start_time = time.time()
with torch.no_grad():
for _ in range(N):
_ = model(imgA_batch, imgB_batch)
    if device.type == 'cuda':
        torch.cuda.synchronize(device)
elapsed = time.time() - start_time
fps = N / elapsed
print(f"[推理 FPS 测试] 输入分辨率 = {imgA_batch.shape[2]}×{imgA_batch.shape[3]},"
f"Batch Size = {imgA_batch.shape[0]},推理 {N} 次总耗时:{elapsed:.4f} 秒,FPS = {fps:.2f}")
    # ---- End of benchmark ---- #
    pbar = TQDMProgressBar(refresh_rate=1)
    lr_monitor = LearningRateMonitor(logging_interval=cfg.logging_interval)
    callbacks = [pbar, lr_monitor]
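    # One checkpoint callback tracks the validation monitor; an extra callback
    # is added per entry in cfg.monitor_test so the best test-side epochs are
    # kept as well.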
    ckpt_cb = ModelCheckpoint(dirpath=f'{log_dir}/ckpts/val',
                              filename='{' + cfg.monitor_val + ':.4f}' + '-{epoch:d}',
                              monitor=cfg.monitor_val,
                              mode='max',
                              save_top_k=cfg.save_top_k,
                              save_last=True)
callbacks.append(ckpt_cb)
    for m_test in cfg.monitor_test:
        ckpt_cb = ModelCheckpoint(dirpath=f'{log_dir}/ckpts/test/{m_test}',
                                  filename='{' + m_test + ':.4f}' + '-{epoch:d}',
                                  monitor=m_test,
                                  mode='max',
                                  save_top_k=cfg.save_top_k,
                                  save_last=True)
        callbacks.append(ckpt_cb)
    trainer = Trainer(max_epochs=cfg.epoch,
                      # precision='16-mixed',
                      callbacks=callbacks,
                      logger=logger,
                      enable_model_summary=True,
                      accelerator='auto',
                      devices=cfg.gpus,
                      num_sanity_val_steps=2,
                      benchmark=True)
trainer.fit(model, ckpt_path=cfg.resume_ckpt_path)