47 lines
		
	
	
		
			1.4 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			47 lines
		
	
	
		
			1.4 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # Copyright 2019 Mobvoi Inc. All Rights Reserved.
 | |
| # Author: binbinzhang@mobvoi.com (Binbin Zhang)
 | |
| 
 | |
| import logging
 | |
| import os
 | |
| import re
 | |
| 
 | |
| import yaml
 | |
| import torch
 | |
| 
 | |
| 
 | |
def load_checkpoint(model: torch.nn.Module, path: str) -> dict:
    '''Load weights from ``path`` into ``model`` and return the saved infos.

    The state dict is read with ``torch.load`` (mapped to CPU when CUDA is
    not available) and applied with ``model.load_state_dict``. The infos are
    read from a sidecar YAML file next to the checkpoint: ``foo.pt`` maps to
    ``foo.yaml``; any other ``path`` maps to ``path + '.yaml'``.

    Args:
        model (torch.nn.Module): module that receives the loaded state dict.
        path (str): checkpoint file to load.

    Returns:
        dict: contents of the sidecar YAML file, or an empty dict when the
            file does not exist or is empty.
    '''
    if torch.cuda.is_available():
        logging.info('Checkpoint: loading from checkpoint %s for GPU', path)
        checkpoint = torch.load(path)
    else:
        logging.info('Checkpoint: loading from checkpoint %s for CPU', path)
        checkpoint = torch.load(path, map_location='cpu')
    model.load_state_dict(checkpoint)

    # Escape the dot so only a literal ".pt" suffix is rewritten; the
    # unescaped pattern also matched names such as "my_ckpt".
    info_path = re.sub(r'\.pt$', '.yaml', path)
    if info_path == path:
        # No ".pt" suffix: never treat the binary checkpoint itself as the
        # YAML sidecar.
        info_path = path + '.yaml'

    if not os.path.exists(info_path):
        return {}
    with open(info_path, 'r') as fin:
        # NOTE(review): FullLoader is not meant for untrusted input; consider
        # yaml.safe_load if checkpoints can come from third parties.
        configs = yaml.load(fin, Loader=yaml.FullLoader)
    # An empty YAML document parses to None; keep the declared dict return.
    return configs if configs is not None else {}
 | |
| 
 | |
| 
 | |
def save_checkpoint(model: torch.nn.Module, path: str, infos=None):
    '''Save the state dict of ``model`` to ``path`` plus a YAML infos file.

    When ``model`` is wrapped in ``DataParallel`` or
    ``DistributedDataParallel``, the state dict of the wrapped ``module`` is
    saved instead of the wrapper's. The infos are written to a sidecar YAML
    file: ``foo.pt`` maps to ``foo.yaml``; any other ``path`` maps to
    ``path + '.yaml'``.

    Args:
        model (torch.nn.Module): module (optionally wrapped) to save.
        path (str): destination checkpoint file.
        infos (dict or None): any info you want to save.
    '''
    logging.info('Checkpoint: save to checkpoint %s', path)
    parallel_wrappers = (torch.nn.DataParallel,
                         torch.nn.parallel.DistributedDataParallel)
    if isinstance(model, parallel_wrappers):
        state_dict = model.module.state_dict()
    else:
        state_dict = model.state_dict()
    torch.save(state_dict, path)

    # Escape the dot so only a literal ".pt" suffix is rewritten; the
    # unescaped pattern also matched names such as "my_ckpt".
    info_path = re.sub(r'\.pt$', '.yaml', path)
    if info_path == path:
        # No ".pt" suffix: writing the YAML to `path` would overwrite the
        # checkpoint that was just saved.
        info_path = path + '.yaml'

    if infos is None:
        infos = {}
    with open(info_path, 'w') as fout:
        fout.write(yaml.dump(infos))
 | 
