Parsing Electronic Medical Records with BERT

Original project link

Project link
This project is adapted from this GitHub project, with thanks to the author.

Problem description

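We want to use the clinical notes written during a patient's hospital stay to predict whether that patient will be readmitted within the next 30 days. Such a prediction can help doctors choose a treatment plan and assess surgical risk. In clinical practice it is common for a treatment to be routine while its prognosis remains hard to control. Joint replacement, for example, is the definitive treatment for conditions such as osteoarthritis in older adults and has been highly successful, yet surgery-related complications, and the readmissions they cause, are not rare. Patient factors such as heart disease, diabetes, and obesity also raise the risk of readmission after joint replacement. As the people undergoing joint replacement become older and less healthy, complications become more frequent and the readmission risk rises.
Electronic medical records show that, for certain diseases and procedures, patients readmitted within 30 days carry markedly higher risk on every front. We therefore selected and labeled cases in which a readmission had the same cause as the previous stay and began no more than 30 days after the previous discharge (such a readmission is treated as part of the same hospitalization), and trained a model to try to solve this problem.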

Data selection and cleaning

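The data come from the Medical Information Mart for Intensive Care III (MIMIC-III), a multi-parameter critical-care database developed and maintained with NIH funding by MIT, Harvard Medical School's Beth Israel Deaconess Medical Center, and Philips Healthcare. The dataset is free for researchers, but access requires an application. For our experiments we loaded it into PostgreSQL. First, we took every row of the admission table. For each record we computed the interval until the same subject_id's next admission; if it was under 30 days the record got Label=1, otherwise Label=0. We then computed the length of each stay (discharge date minus admission date) and kept only samples whose length of stay was greater than 2 days. The HADM_IDs of all selected samples were randomly split 0.8 : 0.1 : 0.1 into training, validation, and test sets. Finally, using those HADM_ID assignments, we pulled each set's text from the noteevents table (its TEXT column). The resulting training, validation, and test sets each have three columns: TEXT (note text), ID (the HADM_ID), and Label (0 or 1).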
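A minimal pandas sketch of the labeling and splitting steps above, assuming the admission rows have already been exported from PostgreSQL into a DataFrame with MIMIC-III-style columns SUBJECT_ID, HADM_ID, ADMITTIME, and DISCHTIME (times parsed as datetimes). The function names `label_readmissions` and `split_hadm_ids` are hypothetical, and measuring the gap from this stay's discharge to the next admission is one reasonable reading of "the interval until the same subject_id next appears"; this is not the original project's code.

```python
import numpy as np
import pandas as pd


def label_readmissions(admissions: pd.DataFrame) -> pd.DataFrame:
    """Label an admission 1 if the same patient is readmitted within 30 days.

    Assumes MIMIC-III-style columns SUBJECT_ID, HADM_ID, ADMITTIME and
    DISCHTIME, with the two time columns already parsed as datetimes.
    """
    df = admissions.sort_values(["SUBJECT_ID", "ADMITTIME"]).copy()
    # Admission time of the same patient's next stay (NaT if there is none)
    df["NEXT_ADMITTIME"] = df.groupby("SUBJECT_ID")["ADMITTIME"].shift(-1)
    # Days from this discharge to the next admission (NaN if no next stay)
    gap_days = (df["NEXT_ADMITTIME"] - df["DISCHTIME"]) / pd.Timedelta(days=1)
    # NaN compares as False, so admissions without a later stay get Label 0
    df["Label"] = (gap_days < 30).astype(int)
    # Length of stay in days; keep only stays longer than 2 days
    los_days = (df["DISCHTIME"] - df["ADMITTIME"]) / pd.Timedelta(days=1)
    return df[los_days > 2].reset_index(drop=True)


def split_hadm_ids(hadm_ids, seed=1):
    """Randomly split HADM_IDs into train/val/test at 0.8/0.1/0.1."""
    rng = np.random.default_rng(seed)
    ids = rng.permutation(np.asarray(hadm_ids))
    n_train = int(0.8 * len(ids))
    n_val = int(0.1 * len(ids))
    return ids[:n_train], ids[n_train:n_train + n_val], ids[n_train + n_val:]
```

Labels are computed on all admissions before the length-of-stay filter, so a short stay can still be the readmission that makes an earlier stay positive.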

Pre-trained model

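The original project uses a pre-trained model built on BERT, a milestone model in natural language processing (NLP). On October 11, 2018, Google released the paper "BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding", which reported state-of-the-art results on 11 NLP tasks and was widely praised in the NLP community. BERT has since performed well in many areas, including text classification and text prediction.
For more on BERT, see the link.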

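BERT works in two stages:

  • First, unsupervised pre-training on a large unlabeled corpus learns general language representations.
  • Second, the pre-trained model is fine-tuned in a supervised way on a small amount of labeled data for each downstream task.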

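ClinicalBERT fine-tunes BERT on labeled clinical notes, producing a model that can be used for text analysis in the medical domain. See the original project link for details.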
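Fine-tuning for a task like this one amounts to training a small classification layer on top of BERT's [CLS] vector. The inference code later in this article loads the model with a single output and applies a sigmoid, so the sketch below shows a generic single-logit head of that kind in numpy, together with binary cross-entropy. The weights are placeholders and the function names are hypothetical; it illustrates the idea and is not ClinicalBERT's implementation.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def readmission_head(cls_vectors, w, b):
    """Single-logit head: one linear projection of each [CLS] vector, then a sigmoid.

    cls_vectors has shape (batch, hidden); w has shape (hidden,); b is a scalar.
    """
    logits = cls_vectors @ w + b   # shape (batch,)
    return sigmoid(logits)         # probability of 30-day readmission


def binary_cross_entropy(probs, labels, eps=1e-12):
    """Mean binary cross-entropy, the usual loss for a single-logit classifier."""
    probs = np.clip(probs, eps, 1 - eps)
    return float(np.mean(-(labels * np.log(probs) + (1 - labels) * np.log(1 - probs))))


# Placeholder inputs: two 4-dimensional "[CLS] vectors" and an untrained (all-zero) head
cls = np.ones((2, 4))
probs = readmission_head(cls, np.zeros(4), 0.0)
print(probs)  # [0.5 0.5]
```

During fine-tuning, gradients of this loss update both the head and the BERT weights underneath it.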

Environment setup

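# Install pytorch-pretrained-bert (here through the Tsinghua PyPI mirror)
!pip install -U pytorch-pretrained-bert -i https://pypi.tuna.tsinghua.edu.cn/simple
from IPython.core.interactiveshell import InteractiveShell
# Display the value of every expression in a cell, not only the last one
InteractiveShell.ast_node_interactivity='all'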

Inspecting the data

Let's look at the format of the data to be predicted.

import pandas as pd
sample = pd.read_csv('/home/input/MIMIC_note3519/BERT/sample.csv')
sample
TEXT ID Label
0 Nursing Progress Note 1900-0700 hours:\n** Ful... 176088 1
1 Nursing Progress Note 1900-0700 hours:\n** Ful... 135568 1
2 NPN:\n\nNeuro: Alert and oriented X2-3, Sleepi... 188180 0
3 RESPIRATORY CARE:\n\n35 yo m adm from osh for ... 110655 0
4 NEURO: A+OX3 pleasant, mae, following commands... 139362 0
5 Nursing Note\nSee Flowsheet\n\nNeuro: Propofol... 176981 0

The TEXT column holds several unstructured notes. Let's take one out and see what it says.

text = sample['TEXT'][0]
print(text)
Nursing Progress Note 1900-0700 hours:
** Full code

** allergy: nkda

** access: #18 piv to right FA, #18 piv to right FA.

** diagnosis: angioedema

In Brief: Pt is a 51yo F with pmh significant for: COPD, HTN, diabetes insipidus, hypothyroidism, OSA (on bipap at home), restrictive lung disease, pulm artery hypertension attributed to COPD/OSA, ASD with shunt, down syndrome, CHF with LVEF >60%. Also, 45pk-yr smoker (quit in [**2112**]).

Pt brought to [**Hospital1 2**] by EMS after family found with decreased LOC.  Pt presented with facial swelling and mental status changes. In [**Name (NI) **], pt with enlarged lips and with sats 99% on 2-4l.  Her pupils were pinpoint so given narcan.  She c/o LLQ abd pain and also developed a severe HA.  ABG with profound resp acidosis 7.18/108/71.  Given benadryl, nebs, solumedrol. Difficult intubation-req'd being taken to OR to have fiberoptic used.  Also found to have ARF.  On admit to ICU-denied pain in abdomen, denied HA.  Denied any pain. Pt understands basic english but also used [**Name (NI) **] interpretor to determine these findings. Head CT on [**Name6 (MD) **] [**Name8 (MD) 20**] md as pt was able to nod yes and no and follow commands.

NEURO: pt is sedate on fent at 50mcg/hr and versed at 0.5mg/hr-able to arouse on this level of sedation.  PEARL 2mm/brisk. Able to move all ext's, nod yes and no to questions.  Occasional cough.

CARDIAC: sb-nsr with hr high 50's to 70's.  Ace inhibitors  (pt takes at home) on hold right now as unclear as to what meds or other cause of angioedema.  no ectopy.  SBP >100 with MAPs > 60.

RESP: nasally intubated. #6.0 tube which is sutured in place.  Confirmed by xray for proper placement (5cm above carina). ** some resp events overnight: on 3 occasions thus far, pt noted to have vent alarm 'apnea' though on AC mode and then alarms 'pressure limited/not constant'.  At that time-pt appears comfortably sedate (not bucking vent) but dropping TV's into 100's (from 400's), MV to 3.0 and then desats to 60's and 70's with no chest rise and fall noted. Given 100% 02 first two times with immediate elevation of o2 sat to >92%.  The third time RT ambubagged to see if it was difficult-also climbed right back up to sat >93%.   Suctioned for scant sputum only.  ? as to whether tube was kinking off in trachea or occluding somehow.  RT also swapped out the vent for a new one in case [**Last Name **] problem.  Issue did occur again with new vent (so ruled out a [**Last Name **] problem). Several ABGs overnight (see carevue) which last abg stable. Current settings: 50%/ tv 400/ac 22/p5. Lungs with some rhonchi-received MDI's/nebs overnight. IVF infusing (some risk for chf) Sats have been >93% except for above events. cont to assess.

GI/GU: abd soft, distended, obese. two small bm's this shift-brown, soft, loose. Pt without FT and unlikely to have one placed [**3-3**] edema.  IVF started for ARF and [**3-3**] without nutrition. Foley in place draining clear, yellow 25-80cc/hr.

ID: initial wbc of 12. Pt spiked temp overnight to 102.1-given tylenol supp (last temp 101.3) and pan cx'd.  no abx at this time.

[**Month/Day (2) **]: fs wnl


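This is an ICU nursing note about a 51-year-old woman with many conditions, including COPD, hypertension, hypothyroidism, Down syndrome, an atrial septal defect, chronic heart failure, pulmonary artery hypertension, and obstructive sleep apnea. Her family found her with a decreased level of consciousness and she was brought in by EMS; the diagnosis is angioedema, a severe allergic-type swelling reaction. She is sedated but arousable. During treatment she required a difficult intubation, had several desaturation episodes on the ventilator, and developed acute renal failure.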

Model inference

Change the current working directory

import os
os.chdir('/home/work/clinicalBERT')

Base class definitions

The comments in each class describe what it does.

import csv
import pandas as pd


class InputExample(object):
    """A single training/test example for simple sequence classification."""

    def __init__(self, guid, text_a, text_b=None, label=None):
        """Constructs a InputExample.

        Args:
            guid: Unique id for the example.
            text_a: string. The untokenized text of the first sequence. For single
            sequence tasks, only this sequence must be specified.
            text_b: (Optional) string. The untokenized text of the second sequence.
            Only must be specified for sequence pair tasks.
            label: (Optional) string. The label of the example. This should be
            specified for train and dev examples, but not for test examples.
        """
        self.guid = guid
        self.text_a = text_a
        self.text_b = text_b
        self.label = label


class InputFeatures(object):
    """A single set of features of data."""

    def __init__(self, input_ids, input_mask, segment_ids, label_id):
        self.input_ids = input_ids
        self.input_mask = input_mask
        self.segment_ids = segment_ids
        self.label_id = label_id


class DataProcessor(object):
    """Base class for data converters for sequence classification data sets."""
    def get_labels(self):
        """Gets the list of labels for this data set."""
        raise NotImplementedError()

    @classmethod
    def _read_tsv(cls, input_file, quotechar=None):
        """Reads a tab separated value file."""
        with open(input_file, "r") as f:
            reader = csv.reader(f, delimiter="\t", quotechar=quotechar)
            lines = []
            for line in reader:
                lines.append(line)
            return lines

    @classmethod
    def _read_csv(cls, input_file):
        """Reads a comma separated value file."""
        file = pd.read_csv(input_file)
        lines = zip(file.ID, file.TEXT, file.Label)
        return lines

Define the data-reading and processing class

It inherits from the DataProcessor base class.

def create_examples(lines, set_type):
    """Creates InputExamples from (ID, TEXT, Label) rows; the guid is the row position."""
    examples = []
    for (i, line) in enumerate(lines):
        guid = "%s-%s" % (set_type, i)
        text_a = line[1]
        label = str(int(line[2]))
        examples.append(
            InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
    return examples


class ReadmissionProcessor(DataProcessor):
    def get_test_examples(self, data_dir):
        return create_examples(
            self._read_csv(os.path.join(data_dir, "sample.csv")), "test")

    def get_labels(self):
        return ["0", "1"]
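The processor reads sample.csv row by row and wraps each note as an example whose guid is only its position (test-0, test-1, and so on), so the HADM_ID is not carried along; the evaluation code matches predictions back to IDs purely by row order. A self-contained sketch of that mapping, with a hypothetical `Example` namedtuple standing in for InputExample and an in-memory CSV standing in for sample.csv:

```python
import io
from collections import namedtuple

import pandas as pd

# Hypothetical lightweight stand-in for InputExample, for illustration only
Example = namedtuple("Example", ["guid", "text_a", "label"])


def examples_from_csv(csv_source, set_type="test"):
    """Mirror _read_csv + create_examples: one Example per CSV row."""
    frame = pd.read_csv(csv_source)
    return [
        # The ID column is read but dropped; only position identifies the example
        Example(guid=f"{set_type}-{i}", text_a=text, label=str(int(label)))
        for i, (_id, text, label) in enumerate(zip(frame.ID, frame.TEXT, frame.Label))
    ]


csv_text = "TEXT,ID,Label\nNursing note one,176088,1\nNPN: stable,188180,0\n"
examples = examples_from_csv(io.StringIO(csv_text))
```

This is why the DataLoader further down must use a SequentialSampler: shuffling would break the row-order link between predictions and IDs.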

Define helper functions

  • truncate_seq_pair
  • convert_examples_to_features
  • vote_score
  • pr_curve_plot
  • vote_pr_curve

def truncate_seq_pair(tokens_a, tokens_b, max_length):
    """Truncates a sequence pair in place to the maximum length."""

    # This is a simple heuristic which will always truncate the longer sequence
    # one token at a time. This makes more sense than truncating an equal percent
    # of tokens from each, since if one sequence is very short then each token
    # that's truncated likely contains more information than a longer sequence.
    while True:
        total_length = len(tokens_a) + len(tokens_b)
        if total_length <= max_length:
            break
        if len(tokens_a) > len(tokens_b):
            tokens_a.pop()
        else:
            tokens_b.pop()
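truncate_seq_pair only matters for sentence-pair tasks. The readmission processor sets text_b=None, so here convert_examples_to_features simply keeps the first max_seq_length - 2 tokens of each note. For reference, a compact equivalent of the pair-truncation loop with a worked example (the token strings are made up):

```python
def truncate_seq_pair(tokens_a, tokens_b, max_length):
    """Pop tokens from the longer list (in place) until the pair fits in max_length."""
    while len(tokens_a) + len(tokens_b) > max_length:
        if len(tokens_a) > len(tokens_b):
            tokens_a.pop()
        else:
            tokens_b.pop()


a = ["t1", "t2", "t3", "t4", "t5"]
b = ["u1", "u2", "u3"]
truncate_seq_pair(a, b, 6)
print(a, b)  # ['t1', 't2', 't3'] ['u1', 'u2', 'u3']
```

The longer list loses two tokens while the shorter one keeps all of its tokens, since each token of a short sequence tends to carry more information.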
# Convert the loaded examples into features that can be turned into tensors
import logging
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
                    datefmt='%m/%d/%Y %H:%M:%S',
                    level=logging.INFO)
logger = logging.getLogger(__name__)


def convert_examples_to_features(examples, label_list, max_seq_length, tokenizer):
    """Loads a data file into a list of `InputBatch`s."""

    label_map = {}
    for (i, label) in enumerate(label_list):
        label_map[label] = i

    features = []
    for (ex_index, example) in enumerate(examples):
        tokens_a = tokenizer.tokenize(example.text_a)

        tokens_b = None
        if example.text_b:
            tokens_b = tokenizer.tokenize(example.text_b)

        if tokens_b:
            # Modifies `tokens_a` and `tokens_b` in place so that the total
            # length is less than the specified length.
            # Account for [CLS], [SEP], [SEP] with "- 3"
            truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3)
        else:
            # Account for [CLS] and [SEP] with "- 2"
            if len(tokens_a) > max_seq_length - 2:
                tokens_a = tokens_a[0:(max_seq_length - 2)]

        # The convention in BERT is:
        # (a) For sequence pairs:
        #  tokens:   [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
        #  type_ids: 0   0  0    0    0     0       0 0    1  1  1  1   1 1
        # (b) For single sequences:
        #  tokens:   [CLS] the dog is hairy . [SEP]
        #  type_ids: 0   0   0   0  0     0 0
        #
        # Where "type_ids" are used to indicate whether this is the first
        # sequence or the second sequence. The embedding vectors for `type=0` and
        # `type=1` were learned during pre-training and are added to the wordpiece
        # embedding vector (and position vector). This is not *strictly* necessary
        # since the [SEP] token unambiguously separates the sequences, but it makes
        # it easier for the model to learn the concept of sequences.
        #
        # For classification tasks, the first vector (corresponding to [CLS]) is
        # used as the "sentence vector". Note that this only makes sense because
        # the entire model is fine-tuned.
        tokens = []
        segment_ids = []
        tokens.append("[CLS]")
        segment_ids.append(0)
        for token in tokens_a:
            tokens.append(token)
            segment_ids.append(0)
        tokens.append("[SEP]")
        segment_ids.append(0)

        if tokens_b:
            for token in tokens_b:
                tokens.append(token)
                segment_ids.append(1)
            tokens.append("[SEP]")
            segment_ids.append(1)

        input_ids = tokenizer.convert_tokens_to_ids(tokens)

        # The mask has 1 for real tokens and 0 for padding tokens. Only real
        # tokens are attended to.
        input_mask = [1] * len(input_ids)

        # Zero-pad up to the sequence length.
        while len(input_ids) < max_seq_length:
            input_ids.append(0)
            input_mask.append(0)
            segment_ids.append(0)

        assert len(input_ids) == max_seq_length
        assert len(input_mask) == max_seq_length
        assert len(segment_ids) == max_seq_length
        #print (example.label)
        label_id = label_map[example.label]
        if ex_index < 5:
            logger.info("*** Example ***")
            logger.info("guid: %s" % (example.guid))
            logger.info("tokens: %s" % " ".join(
                    [str(x) for x in tokens]))
            logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
            logger.info("input_mask: %s" % " ".join([str(x) for x in input_mask]))
            logger.info(
                    "segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
            logger.info("label: %s (id = %d)" % (example.label, label_id))

        features.append(
                InputFeatures(input_ids=input_ids,
                              input_mask=input_mask,
                              segment_ids=segment_ids,
                              label_id=label_id))
    return features
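The layout produced for a single sequence can be reproduced without a real tokenizer. In this sketch `build_features` is a hypothetical helper and `toy_vocab` is a made-up vocabulary; it pads to a short max_seq_length of 8 so the result is easy to read:

```python
def build_features(tokens_a, vocab, max_seq_length):
    """Single-sequence BERT layout: [CLS] + tokens + [SEP], then zero padding."""
    tokens_a = tokens_a[: max_seq_length - 2]   # leave room for [CLS] and [SEP]
    tokens = ["[CLS]"] + tokens_a + ["[SEP]"]
    input_ids = [vocab[t] for t in tokens]
    input_mask = [1] * len(input_ids)           # 1 marks a real token
    segment_ids = [0] * len(input_ids)          # one sequence, so every segment id is 0
    pad = max_seq_length - len(input_ids)
    return (input_ids + [0] * pad,
            input_mask + [0] * pad,
            segment_ids + [0] * pad)


# Made-up vocabulary for illustration only
toy_vocab = {"[CLS]": 101, "[SEP]": 102, "pt": 1, "is": 2, "sedate": 3}
ids, mask, segs = build_features(["pt", "is", "sedate"], toy_vocab, 8)
print(ids)   # [101, 1, 2, 3, 102, 0, 0, 0]
print(mask)  # [1, 1, 1, 1, 1, 0, 0, 0]
```

With max_seq_length=512, any note longer than 510 word pieces is cut in the same way, so the model only ever sees the beginning of a long note.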
# Evaluation curves (ROC and precision-recall) and plotting
import numpy as np
from sklearn.metrics import roc_curve, auc
import matplotlib.pyplot as plt


def vote_score(df, score, ax):
    df['pred_score'] = score
    df_sort = df.sort_values(by=['ID'])
    # Per-ID vote: (max + sum / 2) / (1 + n / 2), n = number of rows for that ID
    grouped = df_sort.groupby(['ID'])['pred_score']
    temp = (grouped.max() + grouped.sum() / 2) / (1 + grouped.size() / 2)
    x = df_sort.groupby(['ID'])['Label'].min().values
    df_out = pd.DataFrame({'logits': temp.values, 'ID': temp.index.values, 'Label': x})

    fpr, tpr, thresholds = roc_curve(x, temp.values)
    auc_score = auc(fpr, tpr)

    ax.plot([0, 1], [0, 1], 'k--')
    ax.plot(fpr, tpr, label='Val (area = {:.3f})'.format(auc_score))
    ax.set_xlabel('False positive rate')
    ax.set_ylabel('True positive rate')
    ax.set_title('ROC curve')
    ax.legend(loc='best')
    return fpr, tpr, df_out
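vote_score collapses the probabilities of all rows sharing an ID into one score, (max + sum / 2) / (1 + n / 2), where n is the number of rows for that ID. Since sum = n · mean, this equals (max + mean · n/2) / (1 + n/2): with few rows the maximum carries more weight, with many rows the mean dominates. A small pandas sketch with made-up scores, where `vote` and its parameter `c` (the constant 2 in the original) are hypothetical names:

```python
import pandas as pd


def vote(df, c=2.0):
    """Per-ID score (max + sum / c) / (1 + n / c); vote_score uses c = 2."""
    g = df.groupby("ID")["pred_score"]
    return (g.max() + g.sum() / c) / (1 + g.size() / c)


# Made-up row scores: ID 1 has two rows, ID 2 has one
scores = pd.DataFrame({"ID": [1, 1, 2], "pred_score": [0.2, 0.8, 0.1]})
voted = vote(scores)
# ID 1: (0.8 + 1.0 / 2) / (1 + 2 / 2) = 0.65
# ID 2: (0.1 + 0.1 / 2) / (1 + 1 / 2) = 0.1
```

For a single row the formula reduces to that row's own score, as ID 2 shows.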
from sklearn.metrics import precision_recall_curve
from inspect import signature  # standard-library equivalent of the Python 2 funcsigs backport


def pr_curve_plot(y, y_score, ax):
    precision, recall, _ = precision_recall_curve(y, y_score)
    area = auc(recall, precision)
    step_kwargs = ({'step': 'post'}
                   if 'step' in signature(plt.fill_between).parameters
                   else {})

    ax.step(recall, precision, color='b', alpha=0.2,
             where='post')
    ax.fill_between(recall, precision, alpha=0.2, color='b', **step_kwargs)
    ax.set_xlabel('Recall')
    ax.set_ylabel('Precision')
    ax.set_ylim([0.0, 1.05])
    ax.set_xlim([0.0, 1.0])
    ax.set_title('Precision-Recall curve: AUC={0:0.2f}'.format(
        area))
def vote_pr_curve(df, score, ax):
    df['pred_score'] = score
    df_sort = df.sort_values(by=['ID'])
    # Same per-ID vote as in vote_score
    grouped = df_sort.groupby(['ID'])['pred_score']
    temp = (grouped.max() + grouped.sum() / 2) / (1 + grouped.size() / 2)
    y = df_sort.groupby(['ID'])['Label'].min().values

    precision, recall, thres = precision_recall_curve(y, temp)
    pr_thres = pd.DataFrame(data=list(zip(precision, recall, thres)), columns=['prec', 'recall', 'thres'])

    pr_curve_plot(y, temp, ax)

    temp = pr_thres[pr_thres.prec > 0.799999].reset_index()

    rp80 = 0
    if temp.size == 0:
        print('Test Sample too small or RP80=0')
    else:
        rp80 = temp.iloc[0].recall
        print(f'Recall at Precision of 80 is {rp80}')

    return rp80
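RP80 is the recall reached at the lowest threshold whose precision is still at least 0.8. Because precision_recall_curve returns thresholds in increasing order, recall never increases along the table, so the first row passing the filter (the one vote_pr_curve reads) is also the row with the highest recall. A standalone sketch of the same computation on made-up labels and scores; `recall_at_precision` is a hypothetical name:

```python
import pandas as pd
from sklearn.metrics import precision_recall_curve


def recall_at_precision(y_true, y_score, min_precision=0.8):
    """Highest recall among thresholds whose precision is at least min_precision."""
    precision, recall, thresholds = precision_recall_curve(y_true, y_score)
    # precision/recall carry one extra final point (1, 0) with no threshold; drop it
    table = pd.DataFrame({"prec": precision[:-1], "recall": recall[:-1], "thres": thresholds})
    # Small tolerance, mirroring the original's `prec > 0.799999`
    ok = table[table.prec > min_precision - 1e-6]
    return 0.0 if ok.empty else float(ok.recall.max())


rp80 = recall_at_precision([1, 1, 0, 1, 0, 0], [0.9, 0.8, 0.7, 0.6, 0.3, 0.2])
```

Here the two highest-scoring samples are both positive, so precision stays at 1.0 down to the threshold 0.8, where recall is 2/3; lowering the threshold to 0.7 drops precision to 2/3, so RP80 is 2/3.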

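Configure inference parameters

  • output_dir: directory for output files
  • task_name: task name
  • bert_model: directory of the fine-tuned model
  • data_dir: data directory; the input file is expected to be named sample.csv
  • max_seq_length: maximum sequence length in tokens; longer notes are truncated
  • eval_batch_size: inference batch size; larger batches use more memory
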
config = {
    "local_rank": -1,
    "no_cuda": False,
    "seed": 42,
    "output_dir": './result',
    "task_name": 'readmission',
    "bert_model": '/home/input/MIMIC_note3519/BERT/early_readmission',
    "fp16": False,
    "data_dir": '/home/input/MIMIC_note3519/BERT',
    "max_seq_length": 512,
    "eval_batch_size": 2,
}

Run inference

Inference produces a lot of log output. To hide it, select the current cell (its left edge turns blue once selected) and press the "O" key.

import random
from tqdm import tqdm
from pytorch_pretrained_bert.tokenization import BertTokenizer
from modeling_readmission import BertForSequenceClassification
from torch.utils.data import TensorDataset, SequentialSampler, DataLoader
from torch.utils.data.distributed import DistributedSampler
import torch


processors = {
    "readmission": ReadmissionProcessor
}

if config['local_rank'] == -1 or config['no_cuda']:
    device = torch.device("cuda" if torch.cuda.is_available() and not config['no_cuda'] else "cpu")
    n_gpu = torch.cuda.device_count()
else:
    device = torch.device("cuda", config['local_rank'])
    n_gpu = 1
    # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
    torch.distributed.init_process_group(backend='nccl')
logger.info("device %s n_gpu %d distributed training %r", device, n_gpu, bool(config['local_rank'] != -1))


random.seed(config['seed'])
np.random.seed(config['seed'])
torch.manual_seed(config['seed'])
if n_gpu > 0:
    torch.cuda.manual_seed_all(config['seed'])


if os.path.exists(config['output_dir']):
    pass
else:
    os.makedirs(config['output_dir'], exist_ok=True)

task_name = config['task_name'].lower()

if task_name not in processors:
    raise ValueError(f"Task not found: {task_name}")

processor = processors[task_name]()
label_list = processor.get_labels()

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Prepare model
model = BertForSequenceClassification.from_pretrained(config['bert_model'], 1)
if config['fp16']:
    model.half()
model.to(device)
if config['local_rank'] != -1:
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[config['local_rank']],
                                                      output_device=config['local_rank'])
elif n_gpu > 1:
    model = torch.nn.DataParallel(model)

eval_examples = processor.get_test_examples(config['data_dir'])
eval_features = convert_examples_to_features(
    eval_examples, label_list, config['max_seq_length'], tokenizer)
logger.info("***** Running evaluation *****")
logger.info("  Num examples = %d", len(eval_examples))
logger.info("  Batch size = %d", config['eval_batch_size'])
all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long)
all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long)
all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.long)
eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
if config['local_rank'] == -1:
    eval_sampler = SequentialSampler(eval_data)
else:
    eval_sampler = DistributedSampler(eval_data)
eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=config['eval_batch_size'])
model.eval()
eval_loss, eval_accuracy = 0, 0
nb_eval_steps, nb_eval_examples = 0, 0
true_labels = []
pred_labels = []
logits_history = []
m = torch.nn.Sigmoid()
for input_ids, input_mask, segment_ids, label_ids in tqdm(eval_dataloader):
    input_ids = input_ids.to(device)
    input_mask = input_mask.to(device)
    segment_ids = segment_ids.to(device)
    label_ids = label_ids.to(device)
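    with torch.no_grad():
        # Called with labels, the model returns (loss, logits); only the loss is kept here
        tmp_eval_loss, _ = model(input_ids, segment_ids, input_mask, label_ids)
        # Called without labels, it returns the raw logits
        logits = model(input_ids, segment_ids, input_mask)

    # Sigmoid maps each logit to a readmission probability
    logits = torch.squeeze(m(logits)).detach().cpu().numpy()
    label_ids = label_ids.to('cpu').numpy()

    # Threshold at 0.5 to get 0/1 predictions
    outputs = (logits.flatten() >= 0.5).astype(int)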
    tmp_eval_accuracy = np.sum(outputs == label_ids)

    true_labels = true_labels + label_ids.flatten().tolist()
    pred_labels = pred_labels + outputs.flatten().tolist()
    logits_history = logits_history + logits.flatten().tolist()

    eval_loss += tmp_eval_loss.mean().item()
    eval_accuracy += tmp_eval_accuracy

    nb_eval_examples += input_ids.size(0)
    nb_eval_steps += 1
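Each batch goes through the same three steps: a sigmoid maps the raw logit to a probability, a fixed 0.5 threshold turns it into a 0/1 prediction, and correct predictions are counted so that eval_accuracy can later be divided by nb_eval_examples. A numpy sketch of those steps on made-up logits (`probs_and_accuracy` is a hypothetical helper):

```python
import numpy as np


def probs_and_accuracy(logits, labels, threshold=0.5):
    """Sigmoid the raw logits, threshold them, and return the fraction correct."""
    probs = 1.0 / (1.0 + np.exp(-np.asarray(logits, dtype=float)))
    preds = (probs >= threshold).astype(int)
    n_correct = int(np.sum(preds == np.asarray(labels)))
    return probs, preds, n_correct / len(preds)


probs, preds, acc = probs_and_accuracy([2.0, -1.0, 0.0, -3.0], [1, 0, 0, 1])
print(preds.tolist(), acc)  # [1, 0, 1, 0] 0.5
```

Accuracy depends on the 0.5 cut-off and can look good when positives are rare, so the threshold-free ROC AUC and RP80 computed next are usually the more informative numbers.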

### Plot the evaluation curves

df = pd.DataFrame({'logits': logits_history, 'pred_label': pred_labels, 'label': true_labels})
df_test = pd.read_csv(os.path.join(config['data_dir'], "sample.csv"))

fig = plt.figure(1)
ax1 = fig.add_subplot(1,2,1)
ax2 = fig.add_subplot(1,2,2)
fpr, tpr, df_out = vote_score(df_test, logits_history, ax1)
rp80 = vote_pr_curve(df_test, logits_history, ax2)

output_eval_file = os.path.join(config['output_dir'], "eval_results.txt")
plt.tight_layout()
plt.show()


Save the inference results to the output directory

eval_loss = eval_loss / nb_eval_steps
eval_accuracy = eval_accuracy / nb_eval_examples
result = {'eval_loss': eval_loss,
          'eval_accuracy': eval_accuracy,
          'RP80': rp80}
with open(output_eval_file, "w") as writer:
    logger.info("***** Eval results *****")
    for key in sorted(result.keys()):
        logger.info("  %s = %s", key, str(result[key]))
        writer.write("%s = %s\n" % (key, str(result[key])))
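eval_results.txt holds one "key = value" line per metric, written in sorted key order. To compare runs, a small reader can load it back into a dict; `read_eval_results` is a hypothetical helper, and the demo writes dummy values to a temporary file rather than real results:

```python
import os
import tempfile


def read_eval_results(path):
    """Parse the 'key = value' lines of eval_results.txt into a dict of floats."""
    results = {}
    with open(path) as f:
        for line in f:
            if " = " in line:
                key, value = line.rstrip("\n").split(" = ", 1)
                results[key] = float(value)
    return results


# Demo with dummy values in a throwaway directory (not real results)
with tempfile.TemporaryDirectory() as tmp:
    demo_path = os.path.join(tmp, "eval_results.txt")
    with open(demo_path, "w") as f:
        f.write("RP80 = 0\neval_accuracy = 0.5\n")
    demo = read_eval_results(demo_path)
print(demo)  # {'RP80': 0.0, 'eval_accuracy': 0.5}
```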


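Summary

An ICU nursing note carries rich information about a patient's signs, history, and condition. This model uses such notes to predict whether the patient will be readmitted within 30 days.

The code has been pushed to GitHub.
For more, see my personal blog.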
