[NLP] Hugging Face API, 허깅페이스 API / Trainer, Training Arguments 구현

2023. 2. 1. 16:55·공부정리/NLP
728x90
반응형

 

 

Trainer class는 모델학습부터 평가까지 한 번에 해결할 수 있는 API를 제공한다. 다음의 사용예시를 보면 직관적으로 이해할 수 있다.

 

from transformers import Trainer

#initialize Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset
    eval_dataset=eval_dataset
    compute_metrics,
    tokenizer=tokenizer
)

#train
trainer.train()

#save
trainer.save_model()

#eval
metrics = trainer.evaluate(eval_dataset=eval_dataset)

 

 

Initialize Trainer

from transformers import Trainer

#initialize Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset
    eval_dataset=eval_dataset
    compute_metrics=compute_metrics,
)

 

  • 기본적으로 위와 같이 Trainer을 선언할 수 있다. (이외에도 더 많은 argument들이 존재한다) model은 HuggingFace 라이브러리에서 제공되는 PretrainedModel을 사용해도 되지만, torch.nn.Module을 사용할 수도 있다. 모델을 지정하는 방법은 위와 같이 model argument로 줄 수도 있고, 혹은, 아래와 같이 callable한 모델 초기화 함수를 model_init argument로 줄 수도 있다. 만약 model_init을 이용해 모델을 지정해주면, 매 train() method가 호출될 때마다 모델이 새롭게 초기화(생성)된다. 
from transformers import Trainer, AutoModelForSequenceClassification

#initialize Trainer
trainer = Trainer(
    model_init=AutoModelForSequenceClassification.from_pretrained(model_name),
    args=training_args,
    train_dataset=train_dataset
    eval_dataset=eval_dataset
    compute_metrics=compute_metrics,
)

 

  • args는 train에 필요한 파라미터들의 모음으로, TrainingArgs를 이용해 줄 수 있다. Optimizer의 종류, learning rate, epoch 수, scheduler, half precision 사용여부 등을 지정할 수 있으며, 모든 파라미터의 목록은 여기에서 확인할 수 있다. 예시는 다음과 같다.
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir='./results',          # output directory
    num_train_epochs=1,              # total number of training epochs
    per_device_train_batch_size=1,   # batch size per device during training
    per_device_eval_batch_size=10,   # batch size for evaluation
    warmup_steps=1000,               # number of warmup steps for learning rate scheduler
    weight_decay=0.01,               # strength of weight decay
    logging_dir='./logs',            # directory for storing logs
    logging_steps=200,               # How often to print logs
    do_train=True,                   # Perform training
    do_eval=True,                    # Perform evaluation
    evaluation_strategy="epoch",     # evalute after eachh epoch
    gradient_accumulation_steps=64,  # total number of steps before back propagation
    fp16=True,                       # Use mixed precision
    fp16_opt_level="02",             # mixed precision mode
    run_name="ProBert-BFD-MS",       # experiment name
    seed=3                           # Seed for experiment reproducibility 3x3
)

 

  • train_dataset과 eval_dataset은 각각 train과 validation/test에 사용되는 torch.utils.data.Dataset이다. 꼭 초기화 할 때 지정하지 않아도 된다.

 

  • compute_metrics는 evaluation에 사용할 metric을 계산하는 함수이다. 모델의 output인 EvalPrediction을 input으로 받아 metric을 dictionary 형태로 return하는 함수가 되야 한다. 예시는 다음과 같다.
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, roc_auc_score

def compute_metrics(pred):
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='binary')
    acc = accuracy_score(labels, preds)
    auc = roc_auc_score(labels, preds)
    return {
        'accuracy': acc,
        'f1': f1,
        'precision': precision,
        'recall': recall,
        'auroc': auc
    }

Official Docs

https://huggingface.co/docs/transformers/v4.19.2/en/main_classes/trainer

 

Trainer

When using gradient accumulation, one step is counted as one step with backward pass. Therefore, logging, evaluation, save will be conducted every gradient_accumulation_steps * xxx_step training examples.

huggingface.co

reference

https://bo-10000.tistory.com/154

 

[HuggingFace] Trainer 사용법

Official Docs: https://huggingface.co/docs/transformers/v4.19.2/en/main_classes/trainer Trainer When using gradient accumulation, one step is counted as one step with backward pass. Therefore, logging, evaluation, save will be conducted every gradient_accu

bo-10000.tistory.com

 

728x90
반응형

'공부정리 > NLP' 카테고리의 다른 글

[NLP] KoNlPy Okt 형태소 분석기 사전에 추가하기  (0) 2023.04.04
[NLP] Hugging Face 허깅페이스에서 불러온 모델 미세조정 후 모델업로드하기 (SQuAD v1.1 Dataset)  (0) 2023.02.07
[NLP] Hugging Face 허깅페이스 오류 / 깃 설정 오류 / OSError: Tried to clone a repository in a non-empty folder that isn't a git repository. If you really want to do this, do it manually:git init && git remote add origin && git pull origin ma..  (0) 2023.02.01
[NLP] Hugging Face 오류 / 토큰 권한 변경 / HfHubHTTPError: <class 'requests.exceptions.HTTPError'> (Request ID: Root=1-63d9c8e7-7270e6f27fc51f431f1a5df3)You don't have the rights to create a model under this namespace - You don't have the rights..  (0) 2023.02.01
[NLP] 허깅페이스(Huggingface)에 로그인하여 내 모델 포팅(porting)하기 / 토큰 발급, 허깅페이스 로그인, 모델 포팅  (0) 2023.01.31
'공부정리/NLP' 카테고리의 다른 글
  • [NLP] KoNlPy Okt 형태소 분석기 사전에 추가하기
  • [NLP] Hugging Face 허깅페이스에서 불러온 모델 미세조정 후 모델업로드하기 (SQuAD v1.1 Dataset)
  • [NLP] Hugging Face 허깅페이스 오류 / 깃 설정 오류 / OSError: Tried to clone a repository in a non-empty folder that isn't a git repository. If you really want to do this, do it manually:git init && git remote add origin && git pull origin ma..
  • [NLP] Hugging Face 오류 / 토큰 권한 변경 / HfHubHTTPError: <class 'requests.exceptions.HTTPError'> (Request ID: Root=1-63d9c8e7-7270e6f27fc51f431f1a5df3)You don't have the rights to create a model under this namespace - You don't have the rights..
sillon
sillon
꾸준해지려고 합니다..
    반응형
  • sillon
    sillon coding
    sillon
  • 전체
    오늘
    어제
    • menu (614) N
      • notice (2)
      • python (68)
        • 자료구조 & 알고리즘 (23)
        • 라이브러리 (19)
        • 기초 (8)
        • 자동화 (14)
        • 보안 (1)
      • coding test - python (301)
        • Programmers (166)
        • 백준 (76)
        • Code Tree (22)
        • 기본기 문제 (37)
      • coding test - C++ (5)
        • Programmers (4)
        • 백준 (1)
        • 기본기문제 (0)
      • 공부정리 (5)
        • 신호처리 시스템 (0)
        • Deep learnig & Machine lear.. (41)
        • Data Science (18)
        • Computer Vision (17)
        • NLP (40)
        • Dacon (2)
        • 모두를 위한 딥러닝 (강의 정리) (4)
        • 모두의 딥러닝 (교재 정리) (9)
        • 통계 (2)
      • HCI (23)
        • Haptics (7)
        • Graphics (11)
        • Arduino (4)
      • Project (21)
        • Web Project (1)
        • App Project (1)
        • Paper Project (1)
        • 캡스톤디자인2 (17)
        • etc (1)
      • OS (10)
        • Ubuntu (9)
        • Rasberry pi (1)
      • App & Web (9)
        • Android (7)
        • javascript (2)
      • C++ (5)
        • 기초 (5)
      • Cloud & SERVER (8) N
        • Git (2)
        • Docker (1) N
        • DB (4)
      • Paper (7)
        • NLP Paper review (6)
      • 데이터 분석 (0)
        • GIS (0)
      • daily (2)
        • 대학원 준비 (0)
      • 영어공부 (6)
        • job interview (2)
  • 블로그 메뉴

    • 홈
    • 태그
    • 방명록
  • 링크

  • 공지사항

  • 인기 글

  • 태그

    Python
    소수
    백준
    programmers
  • 최근 댓글

  • 최근 글

  • hELLO· Designed By정상우.v4.10.3
sillon
[NLP] Hugging Face API, 허깅페이스 API / Trainer, Training Arguments 구현
상단으로

티스토리툴바