Add optional device backend support (Accelerate, torch_xla) and flexible device selection for model loading
@@ -7,6 +7,23 @@ import os
 import re
 import logging
 import torch
+"""
+Optional device backends:
+- Accelerate: simplifies multi-GPU/TPU/CPU device placement
+- torch_xla: TPU support (only available in TPU environments)
+Both imports are optional to keep backward compatibility.
+"""
+try:
+    from accelerate import Accelerator  # type: ignore
+except Exception:
+    Accelerator = None  # accelerate is optional
+
+try:
+    import torch_xla.core.xla_model as xm  # type: ignore
+    XLA_AVAILABLE = True
+except Exception:
+    xm = None
+    XLA_AVAILABLE = False
 import lm_decoder
 from functools import lru_cache
 from transformers import AutoModelForCausalLM, AutoTokenizer
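
Note on the guarded imports: when a package is missing, `Accelerator` is left as `None` and `XLA_AVAILABLE` as `False`, so downstream code can branch on plain flags instead of wrapping its own try/except. A minimal sketch of that probing pattern (the module name `opt_lm` is hypothetical):

    # Hypothetical module name for the file patched above.
    import opt_lm

    if opt_lm.Accelerator is not None:
        print('accelerate is installed; placement can be delegated to it')
    if opt_lm.XLA_AVAILABLE:
        print('torch_xla is importable; TPU devices can be requested')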
@@ -92,7 +109,8 @@ def update_ngram_params(
 def build_opt(
     model_name='facebook/opt-6.7b',
     cache_dir=None,
-    device='cuda' if torch.cuda.is_available() else 'cpu',
+    device=None,
+    accelerator=None,
 ):
 
     '''
@@ -101,17 +119,50 @@ def build_opt(
     Put the model onto the GPU (if available).
     '''
 
+    # Resolve device automatically if not provided
+    if device is None:
+        if accelerator is not None:
+            device = accelerator.device
+        elif XLA_AVAILABLE and xm is not None:
+            # Use all available TPU cores for multi-core processing
+            device = xm.xla_device()
+            logging.info(f"TPU cores available: {xm.xrt_world_size()}")
+        elif torch.cuda.is_available():
+            device = torch.device('cuda')
+        else:
+            device = torch.device('cpu')
+
+    # Choose appropriate dtype per device
+    try:
+        device_type = device.type  # torch.device or XLA device
+    except AttributeError:
+        # Fallback for XLA device objects
+        device_type = str(device)
+
+    if XLA_AVAILABLE and (str(device).startswith('xla') or device_type == 'xla'):
+        load_dtype = torch.bfloat16  # TPU prefers bfloat16
+    elif device_type == 'cuda':
+        load_dtype = torch.float16
+    else:
+        load_dtype = torch.float32
+
     # load tokenizer and model
     tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
     model = AutoModelForCausalLM.from_pretrained(
         model_name,
         cache_dir=cache_dir,
-        torch_dtype=torch.float16,
+        torch_dtype=load_dtype,
     )
 
-    if device != 'cpu':
-        # Move the model to the GPU
+    # Device placement
+    if accelerator is not None:
+        model = accelerator.prepare(model)
+    else:
         model = model.to(device)
+        # For TPU multi-core, ensure model is replicated across all cores
+        if XLA_AVAILABLE and str(device).startswith('xla') and xm is not None:
+            # This will be handled by torch_xla internally when using xla_device()
+            pass
 
     # Set the model to evaluation mode
     model.eval()
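
With this hunk, `build_opt` resolves a missing `device` in the order accelerator → XLA/TPU → CUDA → CPU, and picks a load dtype to match (bfloat16 on TPU, float16 on CUDA, float32 otherwise). A usage sketch under those assumptions:

    import torch

    # Auto-resolve: TPU if torch_xla is usable, else CUDA, else CPU.
    lm, lm_tokenizer = build_opt(model_name='facebook/opt-6.7b')

    # Or pin the device explicitly; the load dtype follows the device type.
    lm, lm_tokenizer = build_opt(device=torch.device('cpu'))  # weights load in float32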
@@ -466,8 +517,24 @@ def main(args):
         'score_penalty_percent': float(score_penalty_percent),
     }
 
-    # pick GPU
-    device = torch.device(f"cuda:{gpu_number}" if torch.cuda.is_available() else "cpu")
+    # pick device (Accelerate -> TPU -> CUDA -> CPU)
+    accelerator = None
+    if Accelerator is not None:
+        try:
+            accelerator = Accelerator()
+        except Exception:
+            accelerator = None
+
+    if accelerator is not None:
+        device = accelerator.device
+    elif XLA_AVAILABLE and xm is not None:
+        # Use all available TPU cores for multi-core processing
+        device = xm.xla_device()
+        logging.info(f"TPU cores available: {xm.xrt_world_size()}")
+    elif torch.cuda.is_available():
+        device = torch.device(f"cuda:{gpu_number}")
+    else:
+        device = torch.device("cpu")
 logging.info(f'Using device: {device}')
 
     # initialize opt model
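
One caveat on the TPU branch: `xm.xrt_world_size()` is deprecated in recent torch_xla releases. If the TPU environment ships torch_xla 2.x, the equivalent query lives in the runtime module (hedged; verify against the version actually installed):

    # torch_xla 2.x replacement for xm.xrt_world_size(); confirm this API
    # exists in the actual TPU environment before relying on it.
    import torch_xla.runtime as xr

    print(f"TPU cores available: {xr.world_size()}")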
@@ -477,6 +544,7 @@ def main(args):
     lm, lm_tokenizer = build_opt(
         cache_dir=opt_cache_dir,
         device=device,
+        accelerator=accelerator,
     )
     logging.info(f'OPT model successfully built in {(time.time()-start_time):0.4f} seconds.')
 
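
Callers that already manage their own `Accelerator` can drive the same path explicitly; a sketch assuming accelerate is installed:

    from accelerate import Accelerator

    acc = Accelerator()
    lm, lm_tokenizer = build_opt(
        cache_dir=None,        # or a real cache path
        device=acc.device,     # the device prepare() will target
        accelerator=acc,       # build_opt calls acc.prepare(model)
    )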
@@ -6,4 +6,5 @@ tensorboard
 tensorboardX
 typeguard
 textgrid
 redis
+accelerate>=0.33.0
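
torch_xla is left out of the requirements on purpose, presumably because its wheels are environment-specific and TPU VMs typically ship with it preinstalled. A quick runtime check that the accelerate pin is satisfied (a sketch):

    # Confirm the installed accelerate meets the >=0.33.0 pin.
    import accelerate
    assert tuple(map(int, accelerate.__version__.split('.')[:2])) >= (0, 33)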