이전에 yolo 관련 글을 엄청나게 썼었는데 yolo를 사용하면서 접하게 되었던 트래킹용 도구이다.
모델 실험 시 metric들을 logging하거나 hyperparameter를 변경하며 반복 실험이 필요할 경우 유용하다.
pytorch 기준으로 작성한다.
1. 설치 및 initialization
pip install wandb
터미널에서 pip을 통해 간단하게 설치할 수 있다. wandb 계정을 생성한 후
wandb login
커맨드를 입력한 다음 API 키를 입력하면 해당 계정과 연결이 된다.
2. 학습 코드에 wandb 추가 예시
wandb run 생성
import wandb
wandb.init()
wandb.init() 커맨드를 사용하면 데이터를 logging하기 위한 background process가 생성되고, 이렇게 돌아가는 process들을 wandb에서는 'runs'라고 칭하는 것 같다. 이 아래부터 wandb.log()를 통해 logging되는 데이터들은 이 process에 저장된다. wandb에서는 여러 개의 project를 만들어서 관리할 수 있는데 인자에 project name을 명시하면 해당 project 안에 run이 생성된다.
import wandb
import torch
cfg = {
'batch_size':8,
'gamma':0.5
}
wandb.init(project='testrun', config=cfg)
이외에 위 코드처럼 실험 setting configuration 등을 넘겨줄 수도 있는데 이러면 metric과 함께 기록된다. 이 config 값들은 wandb.config.batch_size 등으로 접근할 수 있다.
wandb.log()를 통해 logging
학습/평가 과정 중 기록하고 싶은 metric을 wandb.log() 함수에 dictionary 형태로 넘겨주면 된다.
def train(model, input, target):
criterion = torch.nn.MSELoss()
optimizer = torch.optim.AdamW(model.parameters())
for epoch in range(10):
optimizer.zero_grad()
pred = model(input)
loss = criterion(pred, target)
# log our loss
wandb.log({"train_loss":loss})
loss.backward()
optimizer.step()
def test(model, input, target):
metric = dice_score()
model.eval()
with torch.no_grad():
pred = model(input)
acc = metric(pred, target)
# log our loss
wandb.log({"test_acc":acc})
예시로 간단한 학습 및 평가과정에서의 logging 방법을 구현해보았다.
3. wandb.sweep()을 통한 hyperparameter 최적화
모델에 대해 반복 실험이 필요한 경우 매우 유용하다. 각 hyperparameter와 최적화하고 싶은 값 (ex: validation accuracy) 간 상관관계와 model의 gradient 변화 등을 기록하고 시각화할 수 있으나 나는 사실 자동사냥 툴처럼 쓰고 있다.
sweep에서 조금 귀찮은 점은 모든 학습관련 코드를 함수 하나로 묶어야 한다는 것... 그래서 코드가 모듈별로 조각조각나있는 형태라면 수정해서 사용하기 좀 귀찮다. 우선 전체 코드 동작은 다음과 같다.
import wandb
sweep_configuration = {
'method': 'grid',
'name': 'sweep',
'metric': {'goal': 'maximize', 'name': 'valid_acc'},
'parameters':
{
'weight_decay': {'values': [0.05]},
'channel_rate': {'values': [1]},
'attn_drop': {'values': [0.1, 0.2, 0.5]},
'mlp_drop': {'values': [0.1, 0.2, 0.5]}
}
}
sweep_id = wandb.sweep(
sweep=sweep_configuration,
project='your_pj_name'
)
# 이 아래부터 train용 함수 작성
def main():
wandb.init() # 여기에 intitialization
# 현재 sweep run에서의 hyperparameter
attn_drop = wandb.config.attn_drop
mlp_drop = wandb.config.mlp_drop
# hyperparameter 필요한 곳에 넘겨주기
model = SwinTransformer(attn_drop=attn_drop, mlp_drop=mlp_drop)
# .... 필요한 코드 정리
# 실제 실행되는 부분
if __name__=="__main__":
wandb.agent(sweep_id, function=main, count=10)
sweep configuration 작성
sweep_configuration = {
'method': 'grid',
'name': 'sweep',
'metric': {'goal': 'maximize', 'name': 'valid_acc'},
'parameters':
{
'weight_decay': {'values': [0.05]},
'channel_rate': {'values': [1]},
'attn_drop': {'values': [0.1, 0.2, 0.5]},
'mlp_drop': {'values': [0.1, 0.2, 0.5]}
}
}
실험 setting에 맞게 적으면 된다.
'method' : hyperparameter 최적화를 위해 각 파라미터를 선택하는 전략으로 grid로 설정할 경우 모든 parameter 간 조합에 대해 최적화를 진행하고 random/bayes로 설정할 경우 확률분포에 기반해 parameter를 선택한다.
'name' : wandb 대시보드에 표시될 이름이다. (random 형용사-sweep-1 식으로 저장됨)
'metric' : hyperparameter tuning을 통해 최적화할 metric이다. (즉 wandb.log()를 통해 넘어가는 값이어야 함) goal은 loss를 최적화하는 경우 minimize, accuracy 등을 최적화할 경우 maximize로 상황에 맞게 선택하면 된다.
'parameters' : 가장 중요한 부분으로 실험할 파라미터 값이다. 마찬가지로 dictionary 형태로 정의되는데, 위 옵션에서 grid를 선택한 경우 key를 'values'로 설정하고 value를 list로 설정하면 list 안의 원소들을 하나씩 뽑아서 파라미터 조합을 만들어 준다. random/bayes로 설정했을 경우 그에 맞게 확률변수 등을 파라미터로 지정해줄 수 있으나 grid로만 진행해도 웬만한 실험은 다 되는 것 같다.
sweep id 선언
sweep_id = wandb.sweep(
sweep=sweep_configuration,
project='your_pj_name'
)
인자로 방금 정의한 configuration을 넘겨주면 sweep project가 생성된다. (위에서 wandb.init()으로 생성되는 run 여러개를 묶어서 취급하는 느낌)
함수 정의 및 실행
def main():
wandb.init() # 여기에 intitialization
# 현재 sweep run에서의 hyperparameter
attn_drop = wandb.config.attn_drop
mlp_drop = wandb.config.mlp_drop
# hyperparameter 필요한 곳에 넘겨주기
model = SwinTransformer(attn_drop=attn_drop, mlp_drop=mlp_drop)
# .... 필요한 코드 정리
# 실제 실행되는 부분
if __name__=="__main__":
wandb.agent(sweep_id, function=main, count=10)
sweep_id를 먼저 생성하고 wandb.init()을 실행하는 것을 유의해야함.
각 학습/평가에 맞는 코드를 함수 하나로 묶은 후 wandb.agent()를 통해 정의한 함수를 돌려준다. 이 때 count는 sweep을 진행할 최대 횟수이다. (method를 random으로 설정했거나 grid를 통해 찾을 파라미터가 무한할 경우 영원히 sweep이 돌아갈 수 있으므로 이걸 방지하기 위함)
여러 번 사용해본 결과 학습이 끝날 때 gpu 메모리가 초기화되지 않는 경우가 발생해서 학습 코드 말미에 꼭 torch.cuda.empty_cache() 및 gc.collect()를 추가해주는 편이다. (하지만 멀티 GPU로 여러 세션을 함께 돌릴 경우 조심하라...)
4. wandb.watch()를 통한 모델 가중치 (weight, bias) 및 그래디언트 확인
모델 학습 전 정의한 모델을 wandb.watch()에 넘겨주면 학습 동안 가중치 및 그래디언트 변화를 자동으로 기록해준다.
wandb.init()
model = SwinTransformer(attn_drop=attn_drop, mlp_drop=mlp_drop)
wandb.watch(model, log='all')
# ... training code
'Deep Learning' 카테고리의 다른 글
[Python] tqdm 이용해 모델 train 시 progress bar 표시하기 (1) | 2022.07.15 |
---|