
[NLP Project] Bert 모델에 NER 학습시키기 (텐서플로우) - keras.saving() 해결일지

sillon 2022. 11. 11. 23:00

기존의 인터넷 상에 있던 코드들은 모두 옛날 코드인지 잘 적용이 안됐었다.

긴 시간 구현해보고 파이토치로도 구현해보았는데 제대로 코드조차 실행이 안됐다.

사실은 다른 오픈소스를 다운받아 해결할 수 있었겠지만, 그렇게 하면 공부가 절대 되지 않을 것 같아 직접 발로 뛰며 구현했다. (사실 앉아만 있었음)


버트 인풋 만들기에서 조금 더 추가한 부분이 있다.



참고 :


[CLS] 부분과 [SEP] 부분까지의 문장을 구분한다.

사실 내가 전처리한 데이터 셋에는 데이터셋 처음 부분은  [CLS], 마지막 부분은 [SEP]라서 문장 자체에서 이렇게 구분하진 않아도 된다.

그래도 일단... 인풋에 필요하다고 하니 넣어주기로 한다.


    # token_type_ids: BERT segment ids, all zeros, one integer per entry of input_ids.
    # dtype=int is required because NumPy creates float zeros by default.
    token_type_ids = np.zeros_like(input_ids, dtype=int)


아, 처음 실행했을 때는 np.zeros()가 기본적으로 float 타입의 배열을 생성해서 버트 입력에 오류가 있었다. 따라서 꼭 dtype을 int로 설정해주어야 한다.


  # Build the BERT inputs (ids, attention mask, segment ids) and the padded
  # label matrix, then hold out 10% of the samples for validation.
  tokenizer = BertTokenizer.from_pretrained('bert-base-multilingual-cased')
  tokened_data = file_open('data/token_data.pkl')
  tokenized_texts = [token_label_pair[0] for token_label_pair in tokened_data]
  labels = [token_label_pair[1] for token_label_pair in tokened_data]

  ## padding
  # print(np.quantile(np.array([len(x) for x in tokenized_texts]), 0.975))  # 97.5th percentile of sentence length (88)
  max_len = 88
  bs = 32

  # Look the [PAD] id up once; it is needed for padding and for the mask.
  pad_token_id = tokenizer.convert_tokens_to_ids("[PAD]")

  input_ids = pad_sequences([tokenizer.convert_tokens_to_ids(txt) for txt in tokenized_texts],
                            maxlen=max_len, dtype="int", value=pad_token_id,
                            truncating="post", padding="post")
  label_dict = {'PER_B': 0, 'DAT_B': 1, '-': 2, 'ORG_B': 3, 'CVL_B': 4, 'NUM_B': 5, 'LOC_B': 6, 'EVT_B': 7, 'TRM_B': 8, 'TRM_I': 9, 'EVT_I': 10, 'PER_I': 11, 'CVL_I': 12, 'NUM_I': 13, 'TIM_B': 14, 'TIM_I': 15, 'ORG_I': 16, 'DAT_I': 17, 'ANM_B': 18, 'MAT_B': 19, 'MAT_I': 20, 'AFW_B': 21, 'FLD_B': 22, 'LOC_I': 23, 'AFW_I': 24, 'PLT_B': 25, 'FLD_I': 26, 'ANM_I': 27, 'PLT_I': 28, '[PAD]': 29}
  # Labels are padded with the '[PAD]' class id so they line up with input_ids.
  tags = pad_sequences(labels, maxlen=max_len, value=label_dict["[PAD]"], padding='post',
                       dtype='int', truncating='post')

  # Attention mask: 1 where the token id is not [PAD], 0 where it is.
  attention_masks = (input_ids != pad_token_id).astype(int)
  # Segment ids: all zeros; dtype=int because NumPy zeros default to float.
  token_type_ids = np.zeros_like(input_ids, dtype=int)

  # Split every array in a single call so that all of them share one shuffle
  # (90% train / 10% validation, fixed seed).
  (tr_inputs, val_inputs,
   tr_tags, val_tags,
   tr_masks, val_masks,
   tr_token_type, val_token_type) = train_test_split(
      input_ids, tags, attention_masks, token_type_ids,
      random_state=222, test_size=0.1)

  X_train, y_train = (tr_inputs, tr_masks, tr_token_type), tr_tags
  X_test, y_test = (val_inputs, val_masks, val_token_type), val_tags

이렇게 main.py에서 전처리한 데이터를 바탕으로 학습 데이터와 검증 데이터를 다시 묶어주었다.

import tensorflow as tf
from transformers import *

class TFBertForTokenClassification(tf.keras.Model):
    """BERT encoder followed by a dense layer applied to every token position.

    Args:
        model_name: Name or path handed to ``TFBertModel.from_pretrained``.
        num_labels: Number of output units of the classification layer.
    """

    def __init__(self, model_name, num_labels):
        super().__init__()
        self.bert = TFBertModel.from_pretrained(model_name, from_pt=True)
        # NOTE(review): the arguments after ``num_labels`` were cut off in the
        # published snippet; the initializer and layer name below are the
        # customary settings for a BERT classification head. Confirm against
        # the original file.
        self.classifier = tf.keras.layers.Dense(
            num_labels,
            kernel_initializer=tf.keras.initializers.TruncatedNormal(0.02),
            name='classifier',
        )

    def call(self, inputs):
        """Return the classifier output for each token.

        Args:
            inputs: A 3-tuple ``(input_ids, attention_mask, token_type_ids)``.

        Returns:
            The output of the dense layer applied to the first BERT output.
        """
        input_ids, attention_mask, token_type_ids = inputs
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        # First element of the BERT outputs; presumably the per-token hidden
        # states (TODO confirm against the TFBertModel docs).
        all_output = outputs[0]
        prediction = self.classifier(all_output)

        return prediction

def compute_loss(labels, logits):
    """Sparse categorical cross-entropy over positions whose label is not -100.

    Labels are flattened to one dimension and logits to two dimensions
    (rows x size of the third logits axis). Rows labelled -100 are dropped and
    the unreduced losses of the remaining rows are returned.
    """
    flat_labels = tf.reshape(labels, (-1,))
    num_classes = shape_list(logits)[2]
    flat_logits = tf.reshape(logits, (-1, num_classes))

    # Boolean mask selecting the positions that take part in the loss.
    keep = flat_labels != -100
    kept_logits = tf.boolean_mask(flat_logits, keep)
    kept_labels = tf.boolean_mask(flat_labels, keep)

    cross_entropy = tf.keras.losses.SparseCategoricalCrossentropy(
        from_logits=True, reduction=tf.keras.losses.Reduction.NONE)
    return cross_entropy(kept_labels, kept_logits)

def modeling(model_name, tag_size):
    """Build the token classifier and compile it with Adam (lr=5e-5) and ``compute_loss``.

    Args:
        model_name: Forwarded to ``TFBertForTokenClassification``.
        tag_size: Number of labels, forwarded as ``num_labels``.

    Returns:
        The compiled model.
    """
    classifier = TFBertForTokenClassification(model_name, tag_size)
    classifier.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=5e-5),
        loss=compute_loss,
    )
    return classifier


Bert 모델 구현이다.

클래스 자체를 가져와서 파이썬에 넣어주었다.

기존에 있던 모듈을 이용해서 가져오는데에는 계속 keras.saving() 오류가 생겼었다.


하이퍼파라미터 설정:

Optimizer  = Adam

Learning rate = 5e-5

epoch = 3

batch_size = 32


하이퍼파라미터는 적절히 바꿔가며 모델의 성능을 최대화할 수 있는 방향으로 바꿔나가면 된다.


NER 의 성능 평가는 F1 Score로 하는 것이 더 정확하다.

다른 모델에 적용하는 accuracy 를 적용하면 정확한 지표가 되지 못한다.


F1 Score 클래스는 다음과 같이 구현하였다.

from seqeval.metrics import precision_score, recall_score, f1_score, classification_report
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
import numpy as np
import tensorflow as tf

# Tag string -> integer class id.
label_dict = {
    'PER_B': 0, 'DAT_B': 1, '-': 2, 'ORG_B': 3, 'CVL_B': 4,
    'NUM_B': 5, 'LOC_B': 6, 'EVT_B': 7, 'TRM_B': 8, 'TRM_I': 9,
    'EVT_I': 10, 'PER_I': 11, 'CVL_I': 12, 'NUM_I': 13, 'TIM_B': 14,
    'TIM_I': 15, 'ORG_I': 16, 'DAT_I': 17, 'ANM_B': 18, 'MAT_B': 19,
    'MAT_I': 20, 'AFW_B': 21, 'FLD_B': 22, 'LOC_I': 23, 'AFW_I': 24,
    'PLT_B': 25, 'FLD_I': 26, 'ANM_I': 27, 'PLT_I': 28, '[PAD]': 29,
}
# Reverse lookup: integer class id -> tag string.
index_to_tag = dict(zip(label_dict.values(), label_dict.keys()))

class F1score(tf.keras.callbacks.Callback):
    """Keras callback that prints the seqeval F1 score and report after each epoch.

    Args:
        X_test: Model inputs, passed unchanged to ``self.model.predict``.
        y_test: Gold label ids, one sequence per sample.
    """

    def __init__(self, X_test, y_test):
        super().__init__()
        self.X_test = X_test
        self.y_test = y_test

    def sequences_to_tags(self, label_ids, pred_ids):
        """Turn id sequences into tag-string sequences via ``index_to_tag``.

        Positions whose gold label id is -100 are skipped in both outputs.

        Args:
            label_ids: Gold label ids, one sequence per sample.
            pred_ids: Predicted label ids, aligned with ``label_ids``.

        Returns:
            ``(label_list, pred_list)``: two lists of tag-string lists.
        """
        label_list = []
        pred_list = []

        for gold_row, pred_row in zip(label_ids, pred_ids):
            label_tag = []
            pred_tag = []

            for label_index, pred_index in zip(gold_row, pred_row):
                if label_index != -100:
                    label_tag.append(index_to_tag[label_index])
                    pred_tag.append(index_to_tag[pred_index])

            label_list.append(label_tag)
            pred_list.append(pred_tag)

        return label_list, pred_list

    def on_epoch_end(self, epoch, logs=None):
        """Predict on the held-out inputs and print F1 plus the per-tag report."""
        y_predicted = self.model.predict(self.X_test)
        # Pick the highest-scoring class along the last (third) axis.
        y_predicted = np.argmax(y_predicted, axis=2)

        label_list, pred_list = self.sequences_to_tags(self.y_test, y_predicted)

        # NOTE(review): seqeval appears to expect prefix-style tags such as
        # 'B-PER'; the suffix-style tags used here ('PER_B') may be parsed
        # differently. Confirm against the seqeval documentation.
        score = f1_score(label_list, pred_list)
        print(' - f1: {:04.2f}'.format(score * 100))
        print(classification_report(label_list, pred_list))


이제 작성한 코드들을 바탕으로 학습을 진행할 것이다.


다른 파일에서 클래스, 함수 호출을 하려면 파일 이름을 임포트하면 된다.

import model
import metrics_f1
    # NOTE(review): the end of this call was cut off in the published snippet;
    # ``tag_size`` is assumed to be the number of entries in ``label_dict``.
    # Confirm against the original file.
    # After this line the name ``model`` refers to the compiled Keras model,
    # no longer to the imported ``model`` module.
    model = model.modeling(model_name='bert-base-multilingual-cased',
                           tag_size=len(label_dict))

    f1_score_report = metrics_f1.F1score(X_test, y_test)
    # Train for 3 epochs with batch size 32 and the F1 callback attached.
    model.fit(
        X_train, y_train, epochs=3, batch_size=32,
        callbacks=[f1_score_report]
    )

모델을 불러와 학습시키고, 학습한 에포크마다 f1 score을 구할 것이다.

그리고, 학습한 모델은 텐서플로우에 내장된 save() 함수로 저장한다.


순수 출력 로그만 보자면...

/home/suyeon/anaconda3/envs/py39/lib/python3.9/site-packages/seqeval/metrics/ UserWarning: - seems not to be NE tag.
  warnings.warn('{} seems not to be NE tag.'.format(chunk))
/home/suyeon/anaconda3/envs/py39/lib/python3.9/site-packages/seqeval/metrics/ UserWarning: PER_B seems not to be NE tag.
  warnings.warn('{} seems not to be NE tag.'.format(chunk))
/home/suyeon/anaconda3/envs/py39/lib/python3.9/site-packages/seqeval/metrics/ UserWarning: CVL_B seems not to be NE tag.
  warnings.warn('{} seems not to be NE tag.'.format(chunk))
/home/suyeon/anaconda3/envs/py39/lib/python3.9/site-packages/seqeval/metrics/ UserWarning: DAT_B seems not to be NE tag.
  warnings.warn('{} seems not to be NE tag.'.format(chunk))
/home/suyeon/anaconda3/envs/py39/lib/python3.9/site-packages/seqeval/metrics/ UserWarning: ORG_B seems not to be NE tag.
  warnings.warn('{} seems not to be NE tag.'.format(chunk))
/home/suyeon/anaconda3/envs/py39/lib/python3.9/site-packages/seqeval/metrics/ UserWarning: FLD_B seems not to be NE tag.
  warnings.warn('{} seems not to be NE tag.'.format(chunk))
/home/suyeon/anaconda3/envs/py39/lib/python3.9/site-packages/seqeval/metrics/ UserWarning: [PAD] seems not to be NE tag.
  warnings.warn('{} seems not to be NE tag.'.format(chunk))
/home/suyeon/anaconda3/envs/py39/lib/python3.9/site-packages/seqeval/metrics/ UserWarning: LOC_B seems not to be NE tag.
  warnings.warn('{} seems not to be NE tag.'.format(chunk))
/home/suyeon/anaconda3/envs/py39/lib/python3.9/site-packages/seqeval/metrics/ UserWarning: EVT_B seems not to be NE tag.
  warnings.warn('{} seems not to be NE tag.'.format(chunk))
/home/suyeon/anaconda3/envs/py39/lib/python3.9/site-packages/seqeval/metrics/ UserWarning: DAT_I seems not to be NE tag.
  warnings.warn('{} seems not to be NE tag.'.format(chunk))
/home/suyeon/anaconda3/envs/py39/lib/python3.9/site-packages/seqeval/metrics/ UserWarning: AFW_B seems not to be NE tag.
  warnings.warn('{} seems not to be NE tag.'.format(chunk))
/home/suyeon/anaconda3/envs/py39/lib/python3.9/site-packages/seqeval/metrics/ UserWarning: EVT_I seems not to be NE tag.
  warnings.warn('{} seems not to be NE tag.'.format(chunk))
/home/suyeon/anaconda3/envs/py39/lib/python3.9/site-packages/seqeval/metrics/ UserWarning: NUM_B seems not to be NE tag.
  warnings.warn('{} seems not to be NE tag.'.format(chunk))
/home/suyeon/anaconda3/envs/py39/lib/python3.9/site-packages/seqeval/metrics/ UserWarning: NUM_I seems not to be NE tag.
  warnings.warn('{} seems not to be NE tag.'.format(chunk))
/home/suyeon/anaconda3/envs/py39/lib/python3.9/site-packages/seqeval/metrics/ UserWarning: TRM_B seems not to be NE tag.
  warnings.warn('{} seems not to be NE tag.'.format(chunk))
/home/suyeon/anaconda3/envs/py39/lib/python3.9/site-packages/seqeval/metrics/ UserWarning: TIM_B seems not to be NE tag.
  warnings.warn('{} seems not to be NE tag.'.format(chunk))
/home/suyeon/anaconda3/envs/py39/lib/python3.9/site-packages/seqeval/metrics/ UserWarning: PER_I seems not to be NE tag.
  warnings.warn('{} seems not to be NE tag.'.format(chunk))
/home/suyeon/anaconda3/envs/py39/lib/python3.9/site-packages/seqeval/metrics/ UserWarning: ANM_B seems not to be NE tag.
  warnings.warn('{} seems not to be NE tag.'.format(chunk))
/home/suyeon/anaconda3/envs/py39/lib/python3.9/site-packages/seqeval/metrics/ UserWarning: CVL_I seems not to be NE tag.
  warnings.warn('{} seems not to be NE tag.'.format(chunk))
/home/suyeon/anaconda3/envs/py39/lib/python3.9/site-packages/seqeval/metrics/ UserWarning: ORG_I seems not to be NE tag.
  warnings.warn('{} seems not to be NE tag.'.format(chunk))
/home/suyeon/anaconda3/envs/py39/lib/python3.9/site-packages/seqeval/metrics/ UserWarning: AFW_I seems not to be NE tag.
  warnings.warn('{} seems not to be NE tag.'.format(chunk))
/home/suyeon/anaconda3/envs/py39/lib/python3.9/site-packages/seqeval/metrics/ UserWarning: TRM_I seems not to be NE tag.
  warnings.warn('{} seems not to be NE tag.'.format(chunk))
/home/suyeon/anaconda3/envs/py39/lib/python3.9/site-packages/seqeval/metrics/ UserWarning: LOC_I seems not to be NE tag.
  warnings.warn('{} seems not to be NE tag.'.format(chunk))
/home/suyeon/anaconda3/envs/py39/lib/python3.9/site-packages/seqeval/metrics/ UserWarning: TIM_I seems not to be NE tag.
  warnings.warn('{} seems not to be NE tag.'.format(chunk))
/home/suyeon/anaconda3/envs/py39/lib/python3.9/site-packages/seqeval/metrics/ UserWarning: FLD_I seems not to be NE tag.
  warnings.warn('{} seems not to be NE tag.'.format(chunk))
/home/suyeon/anaconda3/envs/py39/lib/python3.9/site-packages/seqeval/metrics/ UserWarning: MAT_B seems not to be NE tag.
  warnings.warn('{} seems not to be NE tag.'.format(chunk))
/home/suyeon/anaconda3/envs/py39/lib/python3.9/site-packages/seqeval/metrics/ UserWarning: MAT_I seems not to be NE tag.
  warnings.warn('{} seems not to be NE tag.'.format(chunk))
/home/suyeon/anaconda3/envs/py39/lib/python3.9/site-packages/seqeval/metrics/ UserWarning: PLT_B seems not to be NE tag.
  warnings.warn('{} seems not to be NE tag.'.format(chunk))
/home/suyeon/anaconda3/envs/py39/lib/python3.9/site-packages/seqeval/metrics/ UserWarning: ANM_I seems not to be NE tag.
  warnings.warn('{} seems not to be NE tag.'.format(chunk))
/home/suyeon/anaconda3/envs/py39/lib/python3.9/site-packages/seqeval/metrics/ UserWarning: PLT_I seems not to be NE tag.
  warnings.warn('{} seems not to be NE tag.'.format(chunk))
 - f1: 76.68
/home/suyeon/anaconda3/envs/py39/lib/python3.9/site-packages/seqeval/metrics/ UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.
  _warn_prf(average, modifier, msg_start, len(result))
              precision    recall  f1-score   support

        AT_B       0.84      0.91      0.87      1743
        AT_I       0.84      0.83      0.83       500
        ER_B       0.78      0.77      0.78      2853
        ER_I       0.51      0.58      0.54       337
        FW_B       0.53      0.31      0.39       284
        FW_I       0.30      0.20      0.24        80
        IM_B       0.80      0.80      0.80       214
        IM_I       0.92      0.87      0.89        53
        LD_B       0.46      0.44      0.45       157
        LD_I       0.00      0.00      0.00         4
        LT_B       0.57      0.22      0.32        18
        LT_I       0.00      0.00      0.00         1
        NM_B       0.56      0.57      0.56       480
        NM_I       0.00      0.00      0.00         1
        OC_B       0.64      0.66      0.65      1290
        OC_I       0.00      0.00      0.00        10
        PAD]       1.00      1.00      1.00      6302
        RM_B       0.54      0.57      0.56      1273
        RM_I       0.22      0.11      0.15       178
        UM_B       0.87      0.91      0.89      3526
        UM_I       0.61      0.85      0.71       614
        VL_B       0.72      0.56      0.63      3804
        VL_I       0.29      0.05      0.09       184
        VT_B       0.83      0.71      0.77      3451
        VT_I       0.61      0.83      0.71      1653
           _       0.73      0.73      0.73     18351

   micro avg       0.77      0.76      0.77     47361
   macro avg       0.55      0.52      0.52     47361
weighted avg       0.77      0.76      0.76     47361

Epoch 2/3
1843/1843 [==============================] - 419s 227ms/step - loss: 0.1319
 - f1: 79.44
              precision    recall  f1-score   support

        AT_B       0.90      0.91      0.90      1743
        AT_I       0.84      0.87      0.85       500
        ER_B       0.78      0.82      0.80      2853
        ER_I       0.60      0.61      0.61       337
        FW_B       0.49      0.47      0.48       284
        FW_I       0.36      0.29      0.32        80
        IM_B       0.83      0.85      0.84       214
        IM_I       0.88      0.79      0.83        53
        LD_B       0.49      0.56      0.52       157
        LD_I       0.00      0.00      0.00         4
        LT_B       0.17      0.06      0.08        18
        LT_I       0.00      0.00      0.00         1
        NM_B       0.54      0.69      0.61       480
        NM_I       0.00      0.00      0.00         1
        OC_B       0.59      0.75      0.66      1290
        OC_I       0.00      0.00      0.00        10
        PAD]       1.00      1.00      1.00      6302
        RM_B       0.65      0.57      0.60      1273
        RM_I       0.44      0.18      0.25       178
        UM_B       0.91      0.94      0.92      3526
        UM_I       0.79      0.71      0.75       614
        VL_B       0.71      0.70      0.70      3804
        VL_I       0.43      0.24      0.31       184
        VT_B       0.75      0.76      0.76      3451
        VT_I       0.67      0.86      0.76      1653
           _       0.76      0.77      0.77     18351

   micro avg       0.79      0.80      0.79     47361
   macro avg       0.56      0.55      0.55     47361
weighted avg       0.79      0.80      0.79     47361

Epoch 3/3
1843/1843 [==============================] - 421s 229ms/step - loss: 0.1045
 - f1: 80.73
              precision    recall  f1-score   support

        AT_B       0.88      0.93      0.91      1743
        AT_I       0.83      0.90      0.87       500
        ER_B       0.78      0.82      0.80      2853
        ER_I       0.62      0.65      0.63       337
        FW_B       0.43      0.46      0.44       284
        FW_I       0.27      0.36      0.31        80
        IM_B       0.82      0.88      0.85       214
        IM_I       0.81      0.87      0.84        53
        LD_B       0.49      0.54      0.51       157
        LD_I       0.50      0.25      0.33         4
        LT_B       0.17      0.28      0.21        18
        LT_I       0.00      0.00      0.00         1
        NM_B       0.58      0.73      0.65       480
        NM_I       0.00      0.00      0.00         1
        OC_B       0.70      0.73      0.72      1290
        OC_I       0.00      0.00      0.00        10
        PAD]       1.00      1.00      1.00      6302
        RM_B       0.59      0.66      0.62      1273
        RM_I       0.30      0.31      0.31       178
        UM_B       0.92      0.93      0.92      3526
        UM_I       0.73      0.85      0.79       614
        VL_B       0.74      0.72      0.73      3804
        VL_I       0.37      0.18      0.24       184
        VT_B       0.81      0.76      0.79      3451
        VT_I       0.75      0.78      0.76      1653
           _       0.77      0.79      0.78     18351

   micro avg       0.80      0.82      0.81     47361
   macro avg       0.57      0.59      0.58     47361
weighted avg       0.80      0.82      0.81     47361

최종 학습 결과는 이렇게 나왔다!

micro avg 행의 f1-score 값이 위에 " - f1: "로 출력된 f1 score와 같은 값이다.

f1 score에 나타나는 태그들이 조금 잘려있다.. (seqeval은 `B-PER`처럼 접두어 형식의 태그를 기대하는데, 여기서는 `PER_B`처럼 접미어 형식을 써서 태그의 첫 글자가 접두어로 인식되어 잘린 것으로 보인다.)


이렇게 학습하고 학습한 데이터를 바탕으로 예측하려고 했는데 모델 저장이 안되었다(띠용)

