| """ |
| Training DALL·E Mini. |
| Script adapted from run_summarization_flax.py |
| """ |
|
|
| import io |
| import logging |
| import os |
| import sys |
| import tempfile |
| import time |
| from dataclasses import asdict, dataclass, field |
| from functools import partial |
| from pathlib import Path |
| from typing import Any, Callable, NamedTuple, Optional |
|
|
| import datasets |
| import flax |
| import jax |
| import jax.numpy as jnp |
| import jaxlib |
| import numpy as np |
| import optax |
| import transformers |
| import wandb |
| from flax import core, struct, traverse_util |
| from flax.core.frozen_dict import FrozenDict, freeze, unfreeze |
| from flax.serialization import from_bytes, to_bytes |
| from flax.training.common_utils import onehot |
| from jax.experimental import PartitionSpec, maps |
| from jax.experimental.compilation_cache import compilation_cache as cc |
| from jax.experimental.pjit import pjit, with_sharding_constraint |
| from scalable_shampoo.distributed_shampoo import GraftingType, distributed_shampoo |
| from tqdm import tqdm |
| from transformers import HfArgumentParser |
|
|
| import dalle_mini |
| from dalle_mini.data import Dataset |
| from dalle_mini.model import ( |
| DalleBart, |
| DalleBartConfig, |
| DalleBartTokenizer, |
| set_partitions, |
| ) |
|
|
| try: |
| from google.cloud import storage |
| except ImportError: |
| storage = None |
|
|
| logger = logging.getLogger(__name__) |
|
|
| cc.initialize_cache("jax_cache") |
|
|
|
|
| @dataclass |
| class ModelArguments: |
| """ |
| Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch. |
| """ |
|
|
| model_name_or_path: Optional[str] = field( |
| default=None, |
| metadata={ |
| "help": "The model checkpoint for weights initialization. " |
| "Don't set if you want to train a model from scratch. " |
| "W&B artifact references are supported in addition to the sources supported by `PreTrainedModel`." |
| }, |
| ) |
| config_name: Optional[str] = field( |
| default=None, |
| metadata={ |
| "help": "Pretrained config name or path if not the same as model_name_or_path" |
| }, |
| ) |
| tokenizer_name: Optional[str] = field( |
| default=None, |
| metadata={ |
| "help": "Pretrained tokenizer name or path if not the same as model_name_or_path" |
| }, |
| ) |
| dtype: Optional[str] = field( |
| default="float32", |
| metadata={ |
| "help": "Floating-point format in which the computations will be performed (not the model weights). Choose one of `[float32, float16, bfloat16]`." |
| }, |
| ) |
| restore_state: Optional[bool] = field( |
| default=False, |
| metadata={ |
| "help": "Restore optimizer and training state. Can be True (will retrieve associated wandb artifact), a local directory or a Google bucket path." |
| }, |
| ) |
| dropout: Optional[float] = field( |
| default=None, |
| metadata={"help": "Dropout rate. Overwrites config."}, |
| ) |
| activation_dropout: Optional[float] = field( |
| default=None, |
| metadata={"help": "Activation dropout rate. Overwrites config."}, |
| ) |
| attention_dropout: Optional[float] = field( |
| default=None, |
| metadata={"help": "Attention dropout rate. Overwrites config."}, |
| ) |
|
|
| def __post_init__(self): |
| if self.tokenizer_name is None: |
| self.tokenizer_name = self.model_name_or_path |
| assert ( |
| self.tokenizer_name is not None |
| ), "Tokenizer name or model name/path needs to be specified" |
| if self.restore_state: |
| assert self.model_name_or_path is not None and ( |
| "/model-" in self.model_name_or_path |
| ), "Restoring state only available with W&B artifact reference" |
|
|
| def get_metadata(self): |
| if self.model_name_or_path is not None and ":" in self.model_name_or_path: |
| if jax.process_index() == 0: |
| artifact = wandb.run.use_artifact(self.model_name_or_path) |
| else: |
| artifact = wandb.Api().artifact(self.model_name_or_path) |
| return artifact.metadata |
| else: |
| return dict() |
|
|
| def get_opt_state(self): |
| with tempfile.TemporaryDirectory() as tmp_dir: |
| if self.restore_state is True: |
| # the state artifact is named like the model artifact, with "state-" instead of "model-" |
| state_artifact = self.model_name_or_path.replace( |
| "/model-", "/state-", 1 |
| ) |
| if jax.process_index() == 0: |
| artifact = wandb.run.use_artifact(state_artifact) |
| else: |
| artifact = wandb.Api().artifact(state_artifact) |
| if artifact.metadata.get("bucket_path"): |
| # state is stored in a bucket, read it directly from there |
| self.restore_state = artifact.metadata["bucket_path"] |
| else: |
| artifact_dir = artifact.download(tmp_dir) |
| self.restore_state = str(Path(artifact_dir) / "opt_state.msgpack") |
|
|
| if self.restore_state.startswith("gs://"): |
| bucket_path = Path(self.restore_state[5:]) / "opt_state.msgpack" |
| bucket, blob_name = str(bucket_path).split("/", 1) |
| assert ( |
| storage is not None |
| ), 'Could not find google.cloud.storage. Install with "pip install google-cloud-storage"' |
| client = storage.Client() |
| bucket = client.bucket(bucket) |
| blob = bucket.blob(blob_name) |
| return blob.download_as_bytes() |
|
|
| with Path(self.restore_state).open("rb") as f: |
| return f.read() |
|
|
|
|
| @dataclass |
| class DataTrainingArguments: |
| """ |
| Arguments pertaining to what data we are going to input our model for training and eval. |
| """ |
|
|
| text_column: Optional[str] = field( |
| default="caption", |
| metadata={ |
| "help": "The name of the column in the datasets containing the full texts (for summarization)." |
| }, |
| ) |
| encoding_column: Optional[str] = field( |
| default="encoding", |
| metadata={ |
| "help": "The name of the column in the datasets containing the image encodings." |
| }, |
| ) |
| dataset_repo_or_path: str = field( |
| default=None, |
| metadata={"help": "The dataset repository containing encoded files."}, |
| ) |
| train_file: Optional[str] = field( |
| default=None, |
| metadata={ |
| "help": "The input training data file (glob & braceexpand acceptable)." |
| }, |
| ) |
| validation_file: Optional[str] = field( |
| default=None, |
| metadata={ |
| "help": "An optional input evaluation data file (glob & braceexpand acceptable)." |
| }, |
| ) |
| # data loading should not be a bottleneck so we use "streaming" mode by default |
| streaming: Optional[bool] = field( |
| default=True, |
| metadata={"help": "Whether to stream the dataset."}, |
| ) |
| use_auth_token: Optional[bool] = field( |
| default=False, |
| metadata={ |
| "help": "Whether to use the authentication token for private datasets." |
| }, |
| ) |
| shard_by_host: Optional[bool] = field( |
| default=False, |
| metadata={ |
| "help": "Whether to shard data files by host in multi-host environments." |
| }, |
| ) |
| blank_caption_prob: Optional[float] = field( |
| default=0.0, |
| metadata={ |
| "help": "Probability of removing some captions for classifier-free guidance." |
| }, |
| ) |
| clip_score_column: Optional[str] = field( |
| default="clip_score", |
| metadata={"help": "Column that containts clip score for filtering."}, |
| ) |
| min_clip_score: Optional[float] = field( |
| default=None, |
| metadata={"help": "Minimum clip score required."}, |
| ) |
| max_clip_score: Optional[float] = field( |
| default=None, |
| metadata={"help": "Maximum clip score required."}, |
| ) |
| filter_column: Optional[str] = field( |
| default=None, |
| metadata={"help": "Column that containts classes to be filtered."}, |
| ) |
| filter_value: Optional[str] = field( |
| default=None, |
| metadata={"help": "Class value to be kept during filtering."}, |
| ) |
| multi_eval_ds: Optional[bool] = field( |
| default=False, |
| metadata={ |
| "help": "Whether to look for multiple validation datasets (local support only)." |
| }, |
| ) |
| max_train_samples: Optional[int] = field( |
| default=None, |
| metadata={ |
| "help": "For debugging purposes or quicker training, truncate the number of training examples." |
| }, |
| ) |
| max_eval_samples: Optional[int] = field( |
| default=None, |
| metadata={ |
| "help": "For debugging purposes or quicker training, truncate the number of evaluation examples." |
| }, |
| ) |
| preprocessing_num_workers: Optional[int] = field( |
| default=None, |
| metadata={ |
| "help": "The number of processes to use for the preprocessing. Not used in streaming mode." |
| }, |
| ) |
| overwrite_cache: bool = field( |
| default=False, |
| metadata={ |
| "help": "Overwrite the cached training and evaluation sets. Not used in streaming mode." |
| }, |
| ) |
| # default seed of None ensures we don't repeat the same items if the script was interrupted during an epoch |
| seed_dataset: int = field( |
| default=None, |
| metadata={ |
| "help": "Random seed for the dataset that will be set at the beginning of training." |
| }, |
| ) |
|
|
| def __post_init__(self): |
| if self.dataset_repo_or_path is None: |
| raise ValueError("Need a dataset repository or path.") |
|
|
|
|
| @dataclass |
| class TrainingArguments: |
| """ |
| Arguments pertaining to training parameters. |
| """ |
|
|
| output_dir: str = field( |
| metadata={ |
| "help": "The output directory where the model predictions and checkpoints will be written." |
| }, |
| ) |
| overwrite_output_dir: bool = field( |
| default=False, |
| metadata={ |
| "help": ( |
| "Overwrite the content of the output directory. " |
| "Use this to continue training if output_dir points to a checkpoint directory." |
| ) |
| }, |
| ) |
|
|
| do_train: bool = field(default=False, metadata={"help": "Whether to run training."}) |
| do_eval: bool = field( |
| default=False, metadata={"help": "Whether to run eval on the validation set."} |
| ) |
|
|
| per_device_train_batch_size: int = field( |
| default=8, |
| metadata={"help": "Batch size per data parallel device for training."}, |
| ) |
| per_device_eval_batch_size: Optional[int] = field( |
| default=None, |
| metadata={ |
| "help": "Batch size per data parallel device for evaluation. Same as training batch size if not set." |
| }, |
| ) |
|
|
| gradient_accumulation_steps: int = field( |
| default=1, |
| metadata={ |
| "help": "Number of updates steps to accumulate before performing an update pass." |
| }, |
| ) |
| gradient_checkpointing: bool = field( |
| default=False, metadata={"help": "Use gradient checkpointing."} |
| ) |
|
|
| learning_rate: float = field( |
| default=5e-5, metadata={"help": "The initial learning rate."} |
| ) |
| optim: str = field( |
| default="distributed_shampoo", |
| metadata={ |
| "help": 'The optimizer to use. Can be "distributed_shampoo" (default), "adam" or "adafactor"' |
| }, |
| ) |
| weight_decay: float = field( |
| default=0.0, metadata={"help": "Weight decay applied to parameters."} |
| ) |
| beta1: float = field( |
| default=0.9, |
| metadata={"help": "Beta1 for Adam & Distributed Shampoo."}, |
| ) |
| beta2: float = field( |
| default=0.999, |
| metadata={"help": "Beta2 for for Adam & Distributed Shampoo."}, |
| ) |
| adam_epsilon: float = field( |
| default=1e-8, metadata={"help": "Epsilon for AdamW optimizer."} |
| ) |
| max_grad_norm: float = field( |
| default=1.0, metadata={"help": "Max gradient norm for Adafactor."} |
| ) |
| block_size: int = field( |
| default=1024, |
| metadata={"help": "Chunked size for large layers with Distributed Shampoo."}, |
| ) |
| preconditioning_compute_steps: int = field( |
| default=10, metadata={"help": "Number of steps to update preconditioner."} |
| ) |
| skip_preconditioning_dim_size_gt: int = field( |
| default=4096, |
| metadata={"help": "Max size for preconditioning with Distributed Shampoo."}, |
| ) |
| graft_type: str = field( |
| default="rmsprop_normalized", |
| metadata={ |
| "help": "The type of grafting to use. Can be 'rmsprop_normalized' (default), 'rmsprop', 'adagrad', 'adagrad_normalized', 'sgd' or 'sqrt_n'" |
| }, |
| ) |
| nesterov: bool = field( |
| default=False, |
| metadata={"help": "Use Nesterov momentum for Distributed Shampoo."}, |
| ) |
| optim_quantized: bool = field( |
| default=False, |
| metadata={ |
| "help": "Whether to quantize optimizer (only supported with Distributed Shampoo)." |
| }, |
| ) |
| shard_shampoo_across: str = field( |
| default="dp", |
| metadata={ |
| "help": "Whether to shard the optimizer across data devices (dp), model devices (mp) or both (2d)." |
| }, |
| ) |
|
|
| num_train_epochs: int = field( |
| default=3, metadata={"help": "Total number of training epochs to perform."} |
| ) |
|
|
| warmup_steps: int = field( |
| default=0, metadata={"help": "Linear warmup over warmup_steps."} |
| ) |
| lr_decay: str = field( |
| default=None, |
| metadata={ |
| "help": "Decay to be used in the learning rate scheduler. Can be None (default), linear or exponential." |
| }, |
| ) |
| lr_transition_steps: int = field( |
| default=None, |
| metadata={ |
| "help": "Number of transition steps associated with learning rate decay when using exponential decay." |
| }, |
| ) |
| lr_decay_rate: float = field( |
| default=None, |
| metadata={ |
| "help": "Decay rate associated with learning rate when using exponential decay." |
| }, |
| ) |
| lr_staircase: bool = field( |
| default=False, |
| metadata={ |
| "help": "Whether to use staircase or continuous learning rate when using exponential decay." |
| }, |
| ) |
| lr_offset: int = field( |
| default=0, |
| metadata={"help": "Number of steps to offset learning rate and keep it at 0."}, |
| ) |
| logging_steps: int = field( |
| default=40, metadata={"help": "Log every X updates steps."} |
| ) |
| eval_steps: int = field( |
| default=400, metadata={"help": "Run an evaluation every X steps."} |
| ) |
| save_steps: int = field( |
| default=4000, metadata={"help": "Save checkpoint every X update steps."} |
| ) |
| log_model: bool = field( |
| default=False, |
| metadata={"help": "Log model to wandb at `save_steps` frequency."}, |
| ) |
| log_norm_steps: int = field( |
| default=True, |
| metadata={"help": "Log parameters and gradients norm at this frequency."}, |
| ) |
| log_histogram_steps: int = field( |
| default=False, |
| metadata={ |
| "help": "Log parameters and gradients histograms at this frequency. Slows down training." |
| }, |
| ) |
|
|
| seed_model: int = field( |
| default=42, |
| metadata={ |
| "help": "Random seed for the model that will be set at the beginning of training." |
| }, |
| ) |
|
|
| embeddings_only: bool = field( |
| default=False, metadata={"help": "Train only embedding layers."} |
| ) |
| init_embeddings: bool = field( |
| default=False, |
| metadata={"help": "When training embedding layers, initialize them."}, |
| ) |
|
|
| wandb_entity: Optional[str] = field( |
| default=None, |
| metadata={"help": "The wandb entity to use (for teams)."}, |
| ) |
| wandb_project: str = field( |
| default="dalle-mini", |
| metadata={"help": "The name of the wandb project."}, |
| ) |
| wandb_job_type: str = field( |
| default="Seq2Seq", |
| metadata={"help": "The name of the wandb job type."}, |
| ) |
|
|
| assert_TPU_available: bool = field( |
| default=False, |
| metadata={"help": "Verify that TPU is not in use."}, |
| ) |
|
|
| use_vmap_trick: bool = field( |
| default=True, |
| metadata={"help": "Verify that TPU is not in use."}, |
| ) |
|
|
| mp_devices: Optional[int] = field( |
| default=1, |
| metadata={ |
| "help": "Number of devices required for model parallelism. The other dimension of available devices is used for data parallelism." |
| }, |
| ) |
|
|
| dp_devices: int = field(init=False) |
|
|
| def __post_init__(self): |
| if self.assert_TPU_available: |
| assert ( |
| jax.local_device_count() == 8 |
| ), "TPUs in use, please check running processes" |
| if self.output_dir.startswith("gs://"): |
| assert ( |
| storage is not None |
| ), 'Could not find google.cloud.storage. Install with "pip install google-cloud-storage"' |
| assert self.optim in [ |
| "distributed_shampoo", |
| "adam", |
| "adafactor", |
| ], f"Selected optimizer not supported: {self.optim}" |
| if self.optim == "adafactor" and self.weight_decay == 0: |
| self.weight_decay = None |
| assert self.graft_type in [ |
| "rmsprop_normalized", |
| "rmsprop", |
| "adagrad", |
| "adagrad_normalized", |
| "sgd", |
| "sqrt_n", |
| ], f"Selected graft type not supported: {self.graft_type}" |
| assert self.lr_decay in [ |
| None, |
| "linear", |
| "exponential", |
| ], f"Selected learning rate decay not supported: {self.lr_decay}" |
| if self.per_device_eval_batch_size is None: |
| self.per_device_eval_batch_size = self.per_device_train_batch_size |
| if self.log_norm_steps is True: |
| self.log_norm_steps = self.logging_steps |
| if not self.do_train: |
| self.num_train_epochs = 1 |
| if ( |
| os.path.exists(self.output_dir) |
| and os.listdir(self.output_dir) |
| and self.do_train |
| and not self.overwrite_output_dir |
| ): |
| raise ValueError( |
| f"Output directory ({self.output_dir}) already exists and is not empty." |
| "Use --overwrite_output_dir to overcome." |
| ) |
| assert self.shard_shampoo_across in [ |
| "dp", |
| "mp", |
| "2d", |
| ], f"Shard shampoo across {self.shard_shampoo_across} not supported." |
| assert ( |
| self.mp_devices > 0 |
| ), "Number of devices for model parallelism must be > 0" |
| assert ( |
| jax.device_count() % self.mp_devices == 0 |
| ), f"Number of available devices ({jax.device_count()}) must be divisible by number of devices used for model parallelism ({self.mp_devices})." |
| self.dp_devices = jax.device_count() // self.mp_devices |
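| # Worked example (hypothetical setup, for intuition only): on 2 hosts with |
| # 8 local devices each, jax.device_count() == 16; with mp_devices=4 this |
| # gives dp_devices = 16 // 4 = 4, i.e. a 4 (dp) x 4 (mp) device mesh. |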
|
|
|
|
| def split_params(data): |
| """Split params between scanned and non-scanned""" |
| flat = traverse_util.flatten_dict(unfreeze(data)) |
| split = {"standard": {}, "scanned_encoder": {}, "scanned_decoder": {}} |
| for k, v in flat.items(): |
| if "FlaxBartEncoderLayers" in k: |
| split["scanned_encoder"][k] = v |
| elif "FlaxBartDecoderLayers" in k: |
| split["scanned_decoder"][k] = v |
| else: |
| split["standard"][k] = v |
| # remove empty keys |
| split = {k: v for k, v in split.items() if v} |
| for k, v in split.items(): |
| split[k] = freeze(traverse_util.unflatten_dict(v)) |
| return split |
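| # A minimal sketch of what split_params produces (toy pytree, hypothetical |
| # key names shortened for illustration): |
| # flat = {("model", "encoder", "FlaxBartEncoderLayers", "kernel"): w_enc, |
| # ("lm_head", "kernel"): w_head} |
| # split_params(freeze(traverse_util.unflatten_dict(flat))) |
| # -> {"scanned_encoder": {...: w_enc}, "standard": {...: w_head}} |
| # unsplit_params below is the inverse: unsplit_params(split_params(p)) == p. |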
|
|
|
|
| def unsplit_params(data): |
| flat = {} |
| for k in ["standard", "scanned_encoder", "scanned_decoder"]: |
| if k in data: |
| flat.update(traverse_util.flatten_dict(unfreeze(data[k]))) |
| return freeze(traverse_util.unflatten_dict(flat)) |
|
|
|
|
| def trainable_params(data, embeddings_only): |
| """Keep only trainable parameters""" |
|
|
| if not embeddings_only: |
| return data |
|
|
| data = unfreeze(data) |
| trainable = { |
| "lm_head": data["lm_head"], |
| "model": { |
| "decoder": { |
| layer: data["model"]["decoder"][layer] |
| for layer in [ |
| "embed_positions", |
| "embed_tokens", |
| "final_ln", |
| "layernorm_embedding", |
| ] |
| } |
| }, |
| } |
| return freeze(trainable) |
|
|
|
|
| def init_embeddings(model, params): |
| """Reinitialize trainable embeddings""" |
| # must match the keys selected in trainable_params() when embeddings_only is set |
| trainable_keypaths = [ |
| "lm_head.kernel", |
| "model.decoder.embed_positions.embedding", |
| "model.decoder.embed_tokens.embedding", |
| "model.decoder.final_ln.bias", |
| "model.decoder.layernorm_embedding.bias", |
| "model.decoder.layernorm_embedding.scale", |
| ] |
|
|
| # mark those keys as missing so that init_weights reinitializes them |
| init_keys = {tuple(k.split(".")) for k in trainable_keypaths} |
| model._missing_keys = init_keys |
| return model.init_weights(model.key, model.input_shape, params=params) |
|
|
|
|
| def main(): |
| # See all possible arguments by passing the --help flag to this script. |
| parser = HfArgumentParser( |
| (ModelArguments, DataTrainingArguments, TrainingArguments) |
| ) |
| if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): |
| # If we pass only one argument to the script and it's the path to a json file, |
| # let's parse it to get our arguments. |
| model_args, data_args, training_args = parser.parse_json_file( |
| json_file=os.path.abspath(sys.argv[1]) |
| ) |
| else: |
| model_args, data_args, training_args = parser.parse_args_into_dataclasses() |
|
|
| # check arguments |
| if training_args.mp_devices > jax.local_device_count(): |
| assert ( |
| data_args.seed_dataset is not None |
| ), "Seed dataset must be provided when model is split over multiple hosts" |
|
|
| # Make one log on every process with the configuration for debugging. |
| logging.basicConfig( |
| format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", |
| datefmt="%m/%d/%Y %H:%M:%S", |
| level=logging.INFO, |
| ) |
| # Setup logging, we only want one process per machine to log things on the screen. |
| logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR) |
| if jax.process_index() == 0: |
| datasets.utils.logging.set_verbosity_warning() |
| transformers.utils.logging.set_verbosity_info() |
| else: |
| datasets.utils.logging.set_verbosity_error() |
| transformers.utils.logging.set_verbosity_error() |
|
|
| # Set the verbosity to info of the Transformers logger (on main process only): |
| logger.info(f"Training/evaluation parameters {training_args}") |
|
|
| # Load dataset |
| dataset = Dataset( |
| **asdict(data_args), |
| do_train=training_args.do_train, |
| do_eval=training_args.do_eval, |
| ) |
|
|
| logger.info(f"Local TPUs: {jax.local_device_count()}") |
| logger.info(f"Global TPUs: {jax.device_count()}") |
|
|
| # Set up wandb run |
| if jax.process_index() == 0: |
| wandb.init( |
| entity=training_args.wandb_entity, |
| project=training_args.wandb_project, |
| job_type=training_args.wandb_job_type, |
| config=parser.parse_args(), |
| ) |
|
|
| # Set up our new model config |
| config_args = { |
| k: getattr(model_args, k) |
| for k in ["dropout", "activation_dropout", "attention_dropout"] |
| if getattr(model_args, k) is not None |
| } |
| config_args["gradient_checkpointing"] = training_args.gradient_checkpointing |
| if model_args.config_name: |
| config = DalleBartConfig.from_pretrained(model_args.config_name) |
| else: |
| config = None |
|
|
| # Load or create new models |
| if model_args.model_name_or_path: |
| model, params = DalleBart.from_pretrained( |
| model_args.model_name_or_path, |
| config=config, |
| seed=training_args.seed_model, |
| dtype=getattr(jnp, model_args.dtype), |
| _do_init=False, |
| ) |
| if training_args.embeddings_only and training_args.init_embeddings: |
| params = init_embeddings(model, params) |
| else: |
| model = DalleBart( |
| config, |
| seed=training_args.seed_model, |
| dtype=getattr(jnp, model_args.dtype), |
| _do_init=False, |
| ) |
| params = None |
| for k, v in config_args.items(): |
| setattr(model.config, k, v) |
| params_shape = model.params_shape_tree |
|
|
| # get model metadata |
| model_metadata = model_args.get_metadata() |
|
|
| # get PartitionSpec for model params |
| param_spec = set_partitions(params_shape, model.config.use_scan) |
| params_shape = freeze(params_shape) |
| if params is not None: |
| params = freeze(params) |
|
|
| # Load tokenizer |
| tokenizer = DalleBartTokenizer.from_pretrained( |
| model_args.tokenizer_name, use_fast=True |
| ) |
|
|
| # Preprocessing the datasets. |
| # We need to normalize and tokenize inputs and targets. |
| dataset.preprocess(tokenizer=tokenizer, config=model.config) |
|
|
| # Initialize our training |
| dropout_rng = jax.random.PRNGKey(training_args.seed_model) |
|
|
| # Store some constant |
| num_epochs = training_args.num_train_epochs |
| # batch size |
| batch_size_per_node_per_grad_step = ( |
| training_args.per_device_train_batch_size |
| * jax.local_device_count() |
| // training_args.mp_devices |
| ) |
| batch_size_per_node = ( |
| batch_size_per_node_per_grad_step * training_args.gradient_accumulation_steps |
| ) |
| batch_size_per_step = batch_size_per_node * jax.process_count() |
| eval_batch_size_per_node = ( |
| training_args.per_device_eval_batch_size |
| * jax.local_device_count() |
| // training_args.mp_devices |
| ) |
| eval_batch_size_per_step = eval_batch_size_per_node * jax.process_count() |
| len_train_dataset, len_eval_dataset = dataset.length |
| steps_per_epoch = ( |
| len_train_dataset // batch_size_per_node |
| if len_train_dataset is not None |
| else None |
| ) |
| num_train_steps = ( |
| steps_per_epoch * num_epochs if steps_per_epoch is not None else None |
| ) |
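| # Worked example with hypothetical values: per_device_train_batch_size=8, |
| # 8 local devices, mp_devices=2, gradient_accumulation_steps=4 and 2 hosts: |
| # batch_size_per_node_per_grad_step = 8 * 8 // 2 = 32 |
| # batch_size_per_node = 32 * 4 = 128 |
| # batch_size_per_step = 128 * 2 = 256 samples per optimizer update |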
| num_params = model.num_params(params_shape) |
|
|
| logger.info("***** Running training *****") |
| logger.info(f" Num examples = {len_train_dataset}") |
| logger.info(f" Num Epochs = {num_epochs}") |
| logger.info( |
| f" Batch size per dp device = {training_args.per_device_train_batch_size}" |
| ) |
| logger.info(f" Number of devices = {jax.device_count()}") |
| logger.info( |
| f" Gradient accumulation steps = {training_args.gradient_accumulation_steps}" |
| ) |
| logger.info(f" Batch size per update = {batch_size_per_step}") |
| logger.info(f" Model parameters = {num_params:,}") |
|
|
| |
| if jax.process_index() == 0: |
| # set default x-axis as 'train/step' |
| wandb.define_metric("*", step_metric="train/step") |
|
|
| # add interesting config parameters |
| wandb.config.update( |
| { |
| "len_train_dataset": len_train_dataset, |
| "len_eval_dataset": len_eval_dataset, |
| "batch_size_per_step": batch_size_per_step, |
| "num_params": num_params, |
| "model_config": model.config.to_dict(), |
| "num_devices": jax.device_count(), |
| "versions": { |
| "jax": jax.__version__, |
| "jaxlib": jaxlib.__version__, |
| "flax": flax.__version__, |
| "transformers": transformers.__version__, |
| "datasets": datasets.__version__, |
| "wandb": wandb.__version__, |
| "dalle_mini": dalle_mini.__version__, |
| }, |
| } |
| ) |
|
|
| # Create learning rate schedule |
| def create_learning_rate_fn() -> Callable[[int], jnp.ndarray]: |
| """Create the learning rate function.""" |
| warmup_fn = optax.linear_schedule( |
| init_value=0.0, |
| end_value=training_args.learning_rate, |
| transition_steps=training_args.warmup_steps + 1, |
| ) |
| last_boundary = training_args.warmup_steps |
| # offset step when resuming |
| if training_args.lr_offset: |
| warmup_fn = optax.join_schedules( |
| schedules=[optax.constant_schedule(0.0), warmup_fn], |
| boundaries=[training_args.lr_offset], |
| ) |
| last_boundary += training_args.lr_offset |
| if training_args.lr_decay is None: |
| return warmup_fn |
| elif training_args.lr_decay == "linear": |
| assert ( |
| num_train_steps is not None |
| ), "linear decay requires knowing the dataset length" |
| decay_fn = optax.linear_schedule( |
| init_value=training_args.learning_rate, |
| end_value=0, |
| transition_steps=num_train_steps - training_args.warmup_steps, |
| ) |
| elif training_args.lr_decay == "exponential": |
| decay_fn = optax.exponential_decay( |
| init_value=training_args.learning_rate, |
| transition_steps=training_args.lr_transition_steps, |
| decay_rate=training_args.lr_decay_rate, |
| staircase=training_args.lr_staircase, |
| ) |
| schedule_fn = optax.join_schedules( |
| schedules=[warmup_fn, decay_fn], |
| boundaries=[last_boundary], |
| ) |
| return schedule_fn |
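| # Shape of the resulting schedule, as a sketch with hypothetical values: |
| # lr_offset=100 and warmup_steps=1000 keep the rate at 0 for 100 steps, |
| # then ramp it linearly to learning_rate around step 1100; with |
| # lr_decay="linear" it then decays to 0 over the remaining |
| # num_train_steps - warmup_steps steps. |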
|
|
| learning_rate_fn = create_learning_rate_fn() |
|
|
| # create optimizer |
| trainable_params_shape = trainable_params( |
| params_shape, training_args.embeddings_only |
| ) |
| if training_args.optim == "distributed_shampoo": |
| # map graft string to distributed_shampoo GraftingType |
| graft_type = { |
| "sgd": GraftingType.SGD, |
| "adagrad": GraftingType.ADAGRAD, |
| "rmsprop": GraftingType.RMSPROP, |
| "rmsprop_normalized": GraftingType.RMSPROP_NORMALIZED, |
| "sqrt_n": GraftingType.SQRT_N, |
| "adagrad_normalized": GraftingType.ADAGRAD_NORMALIZED, |
| }[training_args.graft_type] |
| statistics_partition_spec = ( |
| PartitionSpec(None, training_args.shard_shampoo_across, None) |
| if training_args.shard_shampoo_across != "2d" |
| else PartitionSpec(None, "dp", "mp") |
| ) |
| opt = distributed_shampoo( |
| learning_rate_fn, |
| block_size=training_args.block_size, |
| beta1=training_args.beta1, |
| beta2=training_args.beta2, |
| diagonal_epsilon=1e-10, |
| matrix_epsilon=1e-6, |
| weight_decay=training_args.weight_decay, |
| start_preconditioning_step=max( |
| training_args.preconditioning_compute_steps + 1, 101 |
| ), |
| preconditioning_compute_steps=training_args.preconditioning_compute_steps, |
| statistics_compute_steps=1, |
| best_effort_shape_interpretation=True, |
| graft_type=graft_type, |
| nesterov=training_args.nesterov, |
| exponent_override=0, |
| statistics_partition_spec=statistics_partition_spec, |
| preconditioner_partition_spec=PartitionSpec( |
| training_args.shard_shampoo_across, None, None |
| ) |
| if training_args.shard_shampoo_across != "2d" |
| else PartitionSpec( |
| "mp" if training_args.mp_devices > training_args.dp_devices else "dp", |
| None, |
| None, |
| ), |
| num_devices_for_pjit=training_args.dp_devices, |
| shard_optimizer_states=True, |
| inverse_failure_threshold=0.1, |
| moving_average_for_momentum=True, |
| skip_preconditioning_dim_size_gt=training_args.skip_preconditioning_dim_size_gt, |
| clip_by_scaled_gradient_norm=None, |
| precision=jax.lax.Precision.HIGHEST, |
| best_effort_memory_usage_reduction=training_args.optim_quantized, |
| ) |
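| # For intuition (values above, not new settings): statistics are gathered |
| # every step (statistics_compute_steps=1), the expensive matrix inverses |
| # are refreshed every preconditioning_compute_steps steps, and full |
| # preconditioning only kicks in after max(steps + 1, 101) updates; until |
| # then the grafted optimizer (e.g. rmsprop_normalized) drives the update. |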
| # get the real optimizer and helper functions |
| update_fn = opt.update |
|
|
| optimizer = {} |
| opt_fn = {} |
| for k, p in split_params(trainable_params_shape).items(): |
| if "scanned" in k: |
| p = jax.eval_shape( |
| lambda x: jax.tree_util.tree_map(lambda y: y[0], x), p |
| ) |
| optimizer[k] = opt.init(p) |
| opt_fn[k] = NamedTuple("opt_fn", pspec_fn=Any, shape_and_dtype_fn=Any)( |
| optimizer[k].pspec_fn, optimizer[k].shape_and_dtype_fn |
| ) |
| optimizer[k] = optax.GradientTransformation(optimizer[k].init_fn, update_fn) |
|
|
| elif training_args.optim == "adam": |
| optimizer = optax.adamw( |
| learning_rate=learning_rate_fn, |
| b1=training_args.beta1, |
| b2=training_args.beta2, |
| eps=training_args.adam_epsilon, |
| weight_decay=training_args.weight_decay, |
| ) |
| optimizer = {k: optimizer for k in split_params(trainable_params_shape)} |
|
|
| elif training_args.optim == "adafactor": |
| # We use the default parameters here to initialize adafactor, |
| # see the optax documentation for details on the remaining parameters. |
| optimizer = optax.adafactor( |
| learning_rate=learning_rate_fn, |
| clipping_threshold=training_args.max_grad_norm, |
| weight_decay_rate=training_args.weight_decay, |
| ) |
| optimizer = {k: optimizer for k in split_params(trainable_params_shape)} |
|
|
| # get PartitionSpec and shape of the optimizer state |
| def get_opt_state_spec_and_shape(): |
| # get opt_state shape without actual init |
| opt_state_shape = {} |
| for k, p in split_params(trainable_params_shape).items(): |
| if "scanned" not in k: |
| opt_state_shape[k] = jax.eval_shape(optimizer[k].init, p) |
| else: |
| opt_state_shape[k] = jax.eval_shape(jax.vmap(optimizer[k].init), p) |
|
|
| if training_args.optim == "adafactor": |
| # adafactor state is not shaped like params |
| opt_state_spec = {k: None for k in split_params(trainable_params_shape)} |
|
|
| elif training_args.optim in ["adam", "distributed_shampoo"]: |
|
|
| def _opt_state_spec_per_leaf(x, spec): |
| if isinstance(x, FrozenDict): |
| # variables with same structure as params |
| return spec |
| else: |
| # other variables such as count |
| return None |
|
|
| split_spec = split_params(set_partitions(trainable_params_shape, False)) |
| opt_state_spec = {} |
| for k, p in split_params(trainable_params_shape).items(): |
| if "scanned" in k: |
| p = jax.eval_shape( |
| lambda x: jax.tree_util.tree_map(lambda y: y[0], x), p |
| ) |
| if training_args.optim == "adam": |
| opt_state_spec[k] = jax.tree_util.tree_map( |
| partial(_opt_state_spec_per_leaf, spec=split_spec[k]), |
| opt_state_shape[k], |
| # return None spec for empty elements |
| is_leaf=lambda x: isinstance(x, (FrozenDict, optax.EmptyState)), |
| ) |
| elif training_args.optim == "distributed_shampoo": |
| opt_state_spec[k] = opt_fn[k].pspec_fn( |
| p, |
| split_spec[k], |
| statistics_partition_spec, |
| ) |
| # add dimension for scanned params |
| if "scanned" in k: |
| opt_state_spec[k] = jax.tree_util.tree_map( |
| lambda x: PartitionSpec(*(None,) + x) |
| if x is not None |
| else None, |
| opt_state_spec[k], |
| is_leaf=lambda x: isinstance(x, PartitionSpec), |
| ) |
|
|
| else: |
| raise NotImplementedError |
| return freeze(opt_state_spec), freeze(opt_state_shape) |
|
|
| opt_state_spec, opt_state_shape = get_opt_state_spec_and_shape() |
|
|
| # create a mesh of (dp, mp) devices |
| mesh_shape = (training_args.dp_devices, training_args.mp_devices) |
| devices = np.asarray(jax.devices()).reshape(*mesh_shape) |
| mesh = maps.Mesh(devices, ("dp", "mp")) |
| logger.info(f" Mesh shape: {mesh_shape}") |
|
|
| |
| class TrainState(struct.PyTreeNode): |
| step: int |
| params: core.FrozenDict[str, Any] |
| opt_state: optax.OptState |
| apply_fn: Callable = struct.field(pytree_node=False) |
| tx: optax.GradientTransformation = struct.field(pytree_node=False) |
| dropout_rng: jnp.ndarray = None |
| epoch: int = 0 |
| train_time: float = 0.0 |
| train_samples: int = 0 |
|
|
| def apply_gradients(self, *, grads, **kwargs): |
| grads = split_params(trainable_params(grads, training_args.embeddings_only)) |
| params = split_params( |
| trainable_params(self.params, training_args.embeddings_only) |
| ) |
| opt_state = {} |
| # we loop over keys: "standard", "scanned_encoder", "scanned_decoder" |
| for k, param in params.items(): |
| update_fn = self.tx[k].update |
| if "scanned" in k: |
| update_fn = jax.vmap(update_fn, in_axes=(0, 0, 0), out_axes=(0, 0)) |
| updates, new_opt_state = update_fn(grads[k], self.opt_state[k], param) |
| params[k] = optax.apply_updates(param, updates) |
| opt_state[k] = new_opt_state |
| params = unsplit_params(params) |
| # merge with non-trainable params |
| params, new_params = traverse_util.flatten_dict( |
| unfreeze(self.params) |
| ), traverse_util.flatten_dict(unfreeze(params)) |
| params.update(new_params) |
| params = freeze(traverse_util.unflatten_dict(params)) |
|
|
| return self.replace( |
| step=self.step + 1, |
| params=params, |
| opt_state=freeze(opt_state), |
| **kwargs, |
| ) |
|
|
| @classmethod |
| def create(cls, *, apply_fn, params, tx, **kwargs): |
| opt_state = {} |
| for k, p in split_params( |
| trainable_params(params, training_args.embeddings_only) |
| ).items(): |
| init_fn = tx[k].init |
| if "scanned" in k: |
| init_fn = jax.vmap(init_fn) |
| opt_state[k] = init_fn(p) |
| return cls( |
| step=0, |
| apply_fn=apply_fn, |
| params=params, |
| tx=tx, |
| opt_state=freeze(opt_state), |
| **kwargs, |
| ) |
|
|
| # define the state spec used by pjit |
| state_spec = TrainState( |
| params=param_spec, |
| opt_state=opt_state_spec, |
| dropout_rng=None, |
| step=None, |
| epoch=None, |
| train_time=None, |
| train_samples=None, |
| apply_fn=model.__call__, |
| tx=optimizer, |
| ) |
|
|
| # init params if not available yet |
| def maybe_init_params(params): |
| if params is not None: |
| # params were already loaded |
| return params |
| else: |
| # params have not been initialized yet |
| return model.init_weights(model.key, model.input_shape) |
|
|
| with mesh: |
| logger.info(" Creating state") |
|
|
| # restore metadata |
| attr_state = {} |
| keys = ["train_time", "train_samples"] |
| if model_args.restore_state: |
| keys += ["step", "epoch"] |
| attr_state = {k: v for k, v in model_metadata.items() if k in keys} |
|
|
| if not model_args.restore_state: |
|
|
| def init_state(params): |
| return TrainState.create( |
| apply_fn=model.__call__, |
| tx=optimizer, |
| params=maybe_init_params(params), |
| dropout_rng=dropout_rng, |
| **attr_state, |
| ) |
|
|
| state = pjit( |
| init_state, |
| in_axis_resources=(param_spec,) |
| if model_args.model_name_or_path |
| else None, |
| out_axis_resources=state_spec, |
| donate_argnums=(0,), |
| )(params) |
|
|
| else: |
| # load opt_state |
| opt_state = from_bytes(opt_state_shape, model_args.get_opt_state()) |
|
|
| def restore_state(params, opt_state): |
| return TrainState( |
| apply_fn=model.__call__, |
| tx=optimizer, |
| params=params, |
| opt_state=opt_state, |
| dropout_rng=dropout_rng, |
| **attr_state, |
| ) |
|
|
| state = pjit( |
| restore_state, |
| in_axis_resources=( |
| param_spec, |
| opt_state_spec, |
| ), |
| out_axis_resources=state_spec, |
| donate_argnums=(0, 1), |
| )(params, opt_state) |
|
|
| # free memory from opt_state |
| del opt_state |
|
|
| # free CPU memory |
| del params, opt_state_spec, opt_state_shape |
|
|
| # define batch specs |
| batch_spec = PartitionSpec("dp") |
| grad_batch_spec = PartitionSpec(None, "dp") |
|
|
| # define loss |
| def loss_fn(logits, labels): |
| loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])) |
| loss = loss.mean() |
| return loss |
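| # Sketch of the shapes involved (hypothetical sizes): logits of shape |
| # (batch, seq_len, vocab) and integer labels of shape (batch, seq_len) |
| # yield per-token losses of shape (batch, seq_len) from |
| # optax.softmax_cross_entropy, which are then averaged into a scalar. |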
|
|
| # "vmap trick": compute loss and grads per data-parallel slice and average them |
| use_vmap_trick = training_args.use_vmap_trick |
|
|
| # make grad_param_spec for vmap |
| if use_vmap_trick: |
| grad_param_spec = jax.tree_util.tree_map( |
| lambda x: PartitionSpec(*("dp",) + (x if x is not None else (None,))), |
| param_spec, |
| ) |
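| # Sketch: a leaf spec of PartitionSpec("mp", None) becomes |
| # PartitionSpec("dp", "mp", None), matching the extra leading axis that |
| # jax.vmap adds to per-device gradients in train_step below. |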
|
|
| # Define gradient update step fn |
| def train_step(state, batch, train_time): |
| |
| def get_minibatch(batch, grad_idx): |
| return jax.tree_util.tree_map( |
| lambda x: jax.lax.dynamic_index_in_dim(x, grad_idx, keepdims=False), |
| batch, |
| ) |
|
|
| def compute_loss(params, minibatch, dropout_rng): |
| # minibatch has dim (batch_size, ...) |
| minibatch, labels = minibatch.pop("labels") |
| logits = state.apply_fn( |
| **minibatch, params=params, dropout_rng=dropout_rng, train=True |
| )[0] |
| return loss_fn(logits, labels) |
|
|
| grad_fn = jax.value_and_grad(compute_loss) |
|
|
| def loss_and_grad(grad_idx, dropout_rng): |
| # minibatch at grad_idx for gradient accumulation (None otherwise) |
| minibatch = ( |
| get_minibatch(batch, grad_idx) if grad_idx is not None else batch |
| ) |
| # ensure it is sharded properly |
| minibatch = with_sharding_constraint(minibatch, batch_spec) |
| # use a single rng per grad step |
| dropout_rng, _ = jax.random.split(dropout_rng) |
|
|
| if use_vmap_trick: |
| # "vmap trick": calculate loss and grads independently per dp device |
| loss, grads = jax.vmap( |
| grad_fn, in_axes=(None, 0, None), out_axes=(0, 0) |
| )(state.params, minibatch, dropout_rng) |
| # ensure they are sharded correctly |
| loss = with_sharding_constraint(loss, batch_spec) |
| grads = with_sharding_constraint(grads, grad_param_spec) |
| # average across all devices |
| # Note: we could also average only after gradient accumulation, right before the params update |
| loss, grads = jax.tree_util.tree_map( |
| lambda x: jnp.mean(x, axis=0), (loss, grads) |
| ) |
| else: |
| # compute loss and grads over the full minibatch |
| loss, grads = grad_fn(state.params, minibatch, dropout_rng) |
| # ensure grads are sharded |
| grads = with_sharding_constraint(grads, param_spec) |
| |
| return loss, grads, dropout_rng |
|
|
| if training_args.gradient_accumulation_steps == 1: |
| loss, grads, dropout_rng = loss_and_grad(None, state.dropout_rng) |
| else: |
| # create initial state for cumul_minibatch_step loop |
| init_minibatch_step = ( |
| 0.0, |
| with_sharding_constraint( |
| jax.tree_util.tree_map(jnp.zeros_like, state.params), param_spec |
| ), |
| state.dropout_rng, |
| ) |
|
|
| # accumulate gradients over minibatches |
| def cumul_minibatch_step(grad_idx, cumul_loss_grad_dropout): |
| cumul_loss, cumul_grads, dropout_rng = cumul_loss_grad_dropout |
| loss, grads, dropout_rng = loss_and_grad(grad_idx, dropout_rng) |
| cumul_loss, cumul_grads = jax.tree_util.tree_map( |
| jnp.add, (cumul_loss, cumul_grads), (loss, grads) |
| ) |
| cumul_grads = with_sharding_constraint(cumul_grads, param_spec) |
| return cumul_loss, cumul_grads, dropout_rng |
|
|
| # loop over gradient accumulation steps |
| loss, grads, dropout_rng = jax.lax.fori_loop( |
| 0, |
| training_args.gradient_accumulation_steps, |
| cumul_minibatch_step, |
| init_minibatch_step, |
| ) |
| grads = with_sharding_constraint(grads, param_spec) |
| # sum -> mean |
| loss, grads = jax.tree_util.tree_map( |
| lambda x: x / training_args.gradient_accumulation_steps, (loss, grads) |
| ) |
|
|
| grads = with_sharding_constraint(grads, param_spec) |
|
|
| # update state |
| state = state.apply_gradients( |
| grads=grads, |
| dropout_rng=dropout_rng, |
| train_time=train_time, |
| train_samples=state.train_samples + batch_size_per_step, |
| ) |
|
|
| metrics = { |
| "loss": loss, |
| "learning_rate": learning_rate_fn(state.step), |
| } |
|
|
| def maybe_fn(fn, val, zeros, freq): |
| """Call fn only if it is a logging step""" |
| return jax.lax.cond( |
| state.step % freq == 0, |
| fn, |
| lambda _: zeros, |
| val, |
| ) |
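| # Intent of maybe_fn, as a sketch: jax.lax.cond keeps train_step traceable |
| # while skipping the expensive norm/histogram computation on non-logging |
| # steps, e.g. maybe_fn(norm, grads, zeros_norm, 100) returns real norms |
| # only when state.step % 100 == 0 and the zero pytree otherwise. |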
|
|
| # get norm and histogram of grads and params |
| params = trainable_params(state.params, training_args.embeddings_only) |
| grads = trainable_params(grads, training_args.embeddings_only) |
| if training_args.log_norm_steps: |
| zeros_norm = jax.tree_util.tree_map(lambda _: jnp.float32(0), params) |
|
|
| def norm(val): |
| return jax.tree_util.tree_map(lambda x: jnp.linalg.norm(x), val) |
|
|
| gradients_norm = maybe_fn( |
| norm, grads, zeros_norm, training_args.log_norm_steps |
| ) |
| params_norm = maybe_fn( |
| norm, params, zeros_norm, training_args.log_norm_steps |
| ) |
|
|
| metrics.update( |
| { |
| "gradients_norm": gradients_norm, |
| "params_norm": params_norm, |
| } |
| ) |
|
|
| if training_args.log_histogram_steps: |
| zeros_hist = jax.tree_util.tree_map( |
| lambda _: jnp.histogram(jnp.zeros(1), density=True), params |
| ) |
|
|
| def histogram(val): |
| return jax.tree_util.tree_map( |
| lambda x: jnp.histogram(x, density=True), val |
| ) |
|
|
| gradients_hist = maybe_fn( |
| histogram, grads, zeros_hist, training_args.log_histogram_steps |
| ) |
| params_hist = maybe_fn( |
| histogram, params, zeros_hist, training_args.log_histogram_steps |
| ) |
|
|
| metrics.update( |
| { |
| "params_hist": params_hist, |
| "gradients_hist": gradients_hist, |
| } |
| ) |
|
|
| return state, metrics |
|
|
| # evaluation is performed in float32 regardless of training dtype |
| eval_model = ( |
| model |
| if model_args.dtype == "float32" |
| else DalleBart( |
| model.config, |
| seed=training_args.seed_model, |
| dtype=jnp.float32, |
| _do_init=False, |
| ) |
| ) |
|
|
| def eval_step(state, batch): |
| def compute_eval_loss(batch): |
| batch, labels = batch.pop("labels") |
| logits = eval_model(**batch, params=state.params, train=False)[0] |
| return loss_fn(logits, labels) |
|
|
| if use_vmap_trick: |
| loss = jax.vmap(compute_eval_loss)(batch) |
| # ensure they are sharded correctly |
| loss = with_sharding_constraint(loss, batch_spec) |
| # average across all devices |
| loss = jnp.mean(loss) |
| else: |
| loss = compute_eval_loss(batch) |
|
|
| return loss |
|
|
| # Create parallel version of the train and eval step |
| p_train_step = pjit( |
| train_step, |
| in_axis_resources=( |
| state_spec, |
| grad_batch_spec |
| if training_args.gradient_accumulation_steps > 1 |
| else batch_spec, |
| None, |
| ), |
| out_axis_resources=(state_spec, None), |
| donate_argnums=(0,), |
| ) |
| p_eval_step = pjit( |
| eval_step, |
| in_axis_resources=(state_spec, batch_spec), |
| out_axis_resources=None, |
| ) |
|
|
| |
| class MetricsLogger: |
| def __init__(self, step): |
| |
| self.state_dict = {} |
| |
| self.step = step |
| self.time = time.perf_counter() |
| self.offset_time = 0.0 |
|
|
| def update_state_metrics(self, state): |
| """Update internal state metrics (logged at each call to be used as x-axis)""" |
| self.state_dict = { |
| f'train/{k.split("_")[-1]}': state[k] |
| for k in ["step", "epoch", "train_time", "train_samples"] |
| } |
| |
| new_step = int(state["step"]) |
| new_time = time.perf_counter() |
| if new_step > self.step: |
| # remove time spent in eval & save from the step-time estimate |
| delta_time = new_time - self.time - self.offset_time |
| self.offset_time = 0 |
| time_per_step = delta_time / (new_step - self.step) |
| self.step = new_step |
| self.time = new_time |
| self.log_time("train_per_step", time_per_step, offset=False) |
| self.log_time("train_per_log", delta_time, offset=False) |
|
|
| def log_time(self, key, duration, offset=True): |
| if jax.process_index() == 0: |
| wandb.log({f"time/{key}": duration, **self.state_dict}) |
| if offset: |
| self.offset_time += duration |
|
|
| def log(self, metrics, prefix=None): |
| if jax.process_index() == 0: |
| log_metrics = {} |
| for k, v in metrics.items(): |
| if "_norm" in k: |
| if self.step % training_args.log_norm_steps == 0: |
| log_metrics[f"{k}/"] = unfreeze(v) |
| elif "_hist" in k: |
| if self.step % training_args.log_histogram_steps == 0: |
| v = jax.tree_util.tree_map( |
| lambda x: jax.device_get(x), unfreeze(v) |
| ) |
| v = jax.tree_util.tree_map( |
| lambda x: wandb.Histogram(np_histogram=x), |
| v, |
| is_leaf=lambda x: isinstance(x, tuple), |
| ) |
| log_metrics[f"{k}/"] = v |
| else: |
| if prefix is not None: |
| k = f"{prefix}/{k}" |
| log_metrics[k] = v |
| wandb.log({**log_metrics, **self.state_dict}) |
|
|
| # keep a local copy of the state |
| local_state = { |
| k: jax.device_get(getattr(state, k)).item() |
| for k in ["step", "epoch", "train_time", "train_samples"] |
| } |
| # init variables |
| start_time = time.perf_counter() - local_state["train_time"] |
| train_metrics = None |
| evaluation_ran = False |
| save_model_ran = False |
| metrics_logger = MetricsLogger(local_state["step"]) |
| epochs = tqdm( |
| range(local_state["epoch"], num_epochs), |
| desc=f"Epoch ... (1/{num_epochs})", |
| position=0, |
| disable=jax.process_index() > 0, |
| ) |
|
|
| def run_evaluation(): |
| # ======================== Evaluating ============================== |
| if training_args.do_eval: |
| start_eval_time = time.perf_counter() |
| # collect validation datasets |
| val_datasets = list( |
| dataset.other_eval_datasets.keys() |
| if hasattr(dataset, "other_eval_datasets") |
| else [] |
| ) |
| val_datasets += ["eval"] |
| for val_dataset in val_datasets: |
| eval_loader = dataset.dataloader( |
| val_dataset, |
| eval_batch_size_per_step |
| * max(1, training_args.mp_devices // jax.local_device_count()), |
| ) |
| eval_steps = ( |
| len_eval_dataset // eval_batch_size_per_step |
| if len_eval_dataset is not None |
| else None |
| ) |
| eval_loss = [] |
| for batch in tqdm( |
| eval_loader, |
| desc="Evaluating...", |
| position=2, |
| leave=False, |
| total=eval_steps, |
| disable=jax.process_index() > 0, |
| ): |
| # keep only the items relevant to this node |
| batch = jax.tree_util.tree_map( |
| lambda x: x.reshape( |
| (jax.process_count(), eval_batch_size_per_node) |
| + x.shape[1:] |
| ), |
| batch, |
| ) |
| batch = jax.tree_util.tree_map( |
| lambda x: x[jax.process_index()], batch |
| ) |
|
|
| # add dp dimension when using "vmap trick" |
| if use_vmap_trick: |
| bs_shape = ( |
| jax.local_device_count() // training_args.mp_devices, |
| training_args.per_device_eval_batch_size, |
| ) |
| batch = jax.tree_util.tree_map( |
| lambda x: x.reshape(bs_shape + x.shape[1:]), batch |
| ) |
|
|
| # freeze batch to pass safely to jax transforms |
| batch = freeze(batch) |
| # accumulate losses asynchronously |
| eval_loss.append(p_eval_step(state, batch)) |
|
|
| # get the mean of the loss |
| eval_loss = jnp.stack(eval_loss) |
| eval_loss = jnp.mean(eval_loss) |
| eval_metrics = {"loss": eval_loss} |
|
|
| # log metrics |
| metrics_logger.log(eval_metrics, prefix=val_dataset) |
|
|
| # Print metrics and update progress bar |
| desc = f"Epoch... ({epoch + 1}/{num_epochs} | {val_dataset} Loss: {eval_metrics['loss']})" |
| epochs.write(desc) |
| epochs.desc = desc |
|
|
| # log time |
| metrics_logger.log_time("eval", time.perf_counter() - start_eval_time) |
|
|
| return eval_metrics |
|
|
| def run_save_model(state, eval_metrics=None): |
| if jax.process_index() == 0: |
| start_save_time = time.perf_counter() |
| output_dir = training_args.output_dir |
| use_bucket = output_dir.startswith("gs://") |
| if use_bucket: |
| bucket_path = Path(output_dir[5:]) / wandb.run.id / f"step_{state.step}" |
| bucket, dir_path = str(bucket_path).split("/", 1) |
| tmp_dir = tempfile.TemporaryDirectory() |
| output_dir = tmp_dir.name |
|
|
| # save model locally |
| params = jax.device_get(state.params) |
| model.save_pretrained( |
| output_dir, |
| params=params, |
| ) |
|
|
| # save tokenizer |
| tokenizer.save_pretrained(output_dir) |
|
|
| # copy files to the bucket |
| if use_bucket: |
| client = storage.Client() |
| bucket = client.bucket(bucket) |
| for filename in Path(output_dir).glob("*"): |
| blob_name = str(Path(dir_path) / "model" / filename.name) |
| blob = bucket.blob(blob_name) |
| blob.upload_from_filename(str(filename)) |
| tmp_dir.cleanup() |
|
|
| # save the optimizer state |
| opt_state = jax.device_get(state.opt_state) |
| if use_bucket: |
| blob_name = str(Path(dir_path) / "state" / "opt_state.msgpack") |
| blob = bucket.blob(blob_name) |
| blob.upload_from_file(io.BytesIO(to_bytes(opt_state))) |
| else: |
| with (Path(output_dir) / "opt_state.msgpack").open("wb") as f: |
| f.write(to_bytes(opt_state)) |
|
|
| # save to W&B |
| if training_args.log_model: |
| # save some space in the artifacts cache |
| c = wandb.wandb_sdk.wandb_artifacts.get_artifacts_cache() |
| c.cleanup(wandb.util.from_human_size("20GB")) |
|
|
| metadata = { |
| k: jax.device_get(getattr(state, k)).item() |
| for k in ["step", "epoch", "train_time", "train_samples"] |
| } |
| metadata["num_params"] = num_params |
| if eval_metrics is not None: |
| metadata["eval"] = eval_metrics |
|
|
| # create model artifact |
| if use_bucket: |
| metadata["bucket_path"] = f"gs://{bucket_path}/model" |
| artifact = wandb.Artifact( |
| name=f"model-{wandb.run.id}", |
| type="DalleBart_model", |
| metadata=metadata, |
| ) |
| if use_bucket: |
| artifact.add_reference(metadata["bucket_path"]) |
| else: |
| for filename in [ |
| "config.json", |
| "flax_model.msgpack", |
| "merges.txt", |
| "special_tokens_map.json", |
| "tokenizer.json", |
| "tokenizer_config.json", |
| "vocab.json", |
| ]: |
| artifact.add_file( |
| f"{Path(training_args.output_dir) / filename}" |
| ) |
| wandb.run.log_artifact(artifact) |
|
|
| # create state artifact |
| if use_bucket: |
| metadata["bucket_path"] = f"gs://{bucket_path}/state" |
| artifact_state = wandb.Artifact( |
| name=f"state-{wandb.run.id}", |
| type="DalleBart_state", |
| metadata=metadata, |
| ) |
| if use_bucket: |
| artifact_state.add_reference(metadata["bucket_path"]) |
| else: |
| artifact_state.add_file( |
| f"{Path(training_args.output_dir) / 'opt_state.msgpack'}" |
| ) |
| wandb.run.log_artifact(artifact_state) |
| metrics_logger.log_time("save_model", time.perf_counter() - start_save_time) |
|
|
| logger.info(" Ready to start training") |
| with mesh: |
| for epoch in epochs: |
| state = state.replace(epoch=epoch) |
| local_state["epoch"] = epoch |
| |
| metrics_logger.update_state_metrics(local_state) |
| metrics_logger.log({}) |
|
|
| if training_args.do_train: |
| # nodes may be grouped when the model is sharded across multiple hosts |
| node_groups = max( |
| 1, training_args.mp_devices // jax.local_device_count() |
| ) |
| loader_bs = batch_size_per_node * node_groups |
| train_loader = dataset.dataloader( |
| "train", |
| loader_bs, |
| epoch, |
| ) |
| # train |
| for batch in tqdm( |
| train_loader, |
| desc="Training...", |
| position=1, |
| leave=False, |
| total=steps_per_epoch, |
| disable=jax.process_index() > 0, |
| ): |
| # calculate delta time (lags by one step but that's ok) |
| train_time = time.perf_counter() - start_time |
|
|
| # reset control variables |
| evaluation_ran = False |
| save_model_ran = False |
|
|
| # set correct shape to batch |
| # - add grad_step dim if gradient_accumulation_steps > 1 |
| bs_shape = ( |
| (batch_size_per_node_per_grad_step * node_groups,) |
| if not use_vmap_trick |
| else ( |
| jax.local_device_count() |
| * node_groups |
| // training_args.mp_devices, |
| training_args.per_device_train_batch_size, |
| ) |
| ) |
| if training_args.gradient_accumulation_steps > 1: |
| # reshape data into (gradient_accumulation_steps, batch_per_node, ...) |
| # to avoid any data redistribution when sharding |
| bs_shape = ( |
| training_args.gradient_accumulation_steps, |
| ) + bs_shape |
|
|
| # reshape batch |
| batch = jax.tree_util.tree_map( |
| lambda x: x.reshape(bs_shape + x.shape[1:]), |
| batch, |
| ) |
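| # Example of the final layout (hypothetical values): with the vmap trick, |
| # 8 local devices, mp_devices=1, per_device_train_batch_size=8 and |
| # gradient_accumulation_steps=2, an input_ids array of shape (128, seq) |
| # is reshaped to (2, 8, 8, seq): (grad_step, dp_device, per_device_batch, ...). |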
| # freeze batch to pass safely to jax transforms |
| batch = freeze(batch) |
|
|
| # train step |
| state, train_metrics = p_train_step(state, batch, train_time) |
| local_state["step"] += 1 |
| local_state["train_time"] = train_time |
| local_state["train_samples"] += batch_size_per_step |
|
|
| if ( |
| local_state["step"] % training_args.logging_steps == 0 |
| and jax.process_index() == 0 |
| ): |
| metrics_logger.update_state_metrics(local_state) |
| metrics_logger.log(train_metrics, prefix="train") |
|
|
| eval_metrics = None |
| if local_state["step"] % training_args.eval_steps == 0: |
| eval_metrics = run_evaluation() |
| evaluation_ran = True |
|
|
| if local_state["step"] % training_args.save_steps == 0: |
| run_save_model(state, eval_metrics) |
| save_model_ran = True |
|
|
| # log final metrics |
| if train_metrics is not None: |
| metrics_logger.update_state_metrics(local_state) |
| metrics_logger.log(train_metrics, prefix="train") |
|
|
| epochs.write( |
| f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metrics['loss']}, Learning Rate: {train_metrics['learning_rate']})" |
| ) |
|
|
| # run a final evaluation if it did not just run |
| if not evaluation_ran: |
| eval_metrics = run_evaluation() |
|
|
| # save checkpoint at the end of training if not just saved |
| if not save_model_ran: |
| run_save_model(state, eval_metrics) |
|
|
|
|
| if __name__ == "__main__": |
| main() |
|
|