import sys
from pathlib import Path

# Make the project root importable so `models` and `config` resolve.
sys.path.append(str(Path(__file__).resolve().parent.parent))

import json

import torch
import torch.nn as nn
import torch.optim as optim

from models.model import Microformer
from config import *  # EMBED_DIM, NUM_HEADS, FF_DIM, NUM_LAYERS, MAX_SEQ_LEN, ADAPTER_DIM

# Load the vocabulary produced by the data-prep step.
with open("data/vocab.json", "r") as f:
    vocab = json.load(f)
stoi = vocab["stoi"]
itos = {int(k): v for k, v in vocab["itos"].items()}
VOCAB_SIZE = len(stoi)

# Load the pre-tokenized training stream.
data = torch.load("data/train.pt")
SEQ_LEN = MAX_SEQ_LEN
BATCH_SIZE = 32

# Trim the tail so the stream divides evenly, then reshape it into
# BATCH_SIZE parallel streams: (BATCH_SIZE, tokens_per_stream).
num_batches = len(data) // (SEQ_LEN * BATCH_SIZE)
trimmed_len = num_batches * SEQ_LEN * BATCH_SIZE
data = data[:trimmed_len]
data = data.view(BATCH_SIZE, -1)


def get_batch(start_idx):
    # Next-token prediction: targets are the inputs shifted one position right.
    x = data[:, start_idx:start_idx + SEQ_LEN]
    y = data[:, start_idx + 1:start_idx + SEQ_LEN + 1]
    return x, y
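
# e.g. get_batch(0) returns x, y of shape (BATCH_SIZE, SEQ_LEN), where each
# y[b, t] is the token that follows x[b, t] in stream b.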


device = "cuda" if torch.cuda.is_available() else "cpu"

model = Microformer(
    VOCAB_SIZE,
    EMBED_DIM,
    NUM_HEADS,
    FF_DIM,
    NUM_LAYERS,
    MAX_SEQ_LEN,
    long_term_adapter_dim=ADAPTER_DIM,
    session_adapter_dim=ADAPTER_DIM,
)
model.to(device)

# Freeze the base model so that only adapter parameters (and, with
# include_output=True, the output head) remain trainable; then freeze the
# session adapters too, so this offline phase trains the long-term adapters.
model.freeze_except_adapters(session_only=False, include_output=True)

for layer in model.layers:
    if getattr(layer, "session_adapter", None) is not None:
        for param in layer.session_adapter.parameters():
            param.requires_grad = False

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-3)
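
# Optional sanity check: report how much of the model is actually trainable.
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"Trainable parameters: {trainable:,} / {total:,}")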

# Standard next-token training loop over the batched stream.
for epoch in range(6):
    for i in range(0, data.shape[1] - SEQ_LEN, SEQ_LEN):
        inputs, targets = get_batch(i)
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        out = model(inputs)
        loss = criterion(out.reshape(-1, VOCAB_SIZE), targets.reshape(-1))
        loss.backward()
        optimizer.step()

    print(f"Epoch {epoch}, Loss: {loss.item():.4f}")

torch.save(model.state_dict(), "microformer.pt")
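
# The checkpoint can be restored later with, e.g.:
#   model.load_state_dict(torch.load("microformer.pt", map_location=device))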


def online_unsupervised_update(model, tokenizer, text, optimizer, loss_fn, device, max_len=64):
    # One gradient step on a single piece of session text (next-token objective).
    ids = tokenizer.encode(text).ids + [tokenizer.token_to_id("<EOS>")]
    if len(ids) < 2:
        return None

    # Truncate to max_len + 1 tokens, shift by one for the targets, and pad
    # both sequences to a fixed length. Padded positions contribute to the
    # loss unless loss_fn is built with ignore_index set to the <PAD> id.
    ids = ids[:max_len + 1]
    input_ids = ids[:-1]
    target_ids = ids[1:]
    input_ids += [tokenizer.token_to_id("<PAD>")] * (max_len - len(input_ids))
    target_ids += [tokenizer.token_to_id("<PAD>")] * (max_len - len(target_ids))
    input_tensor = torch.tensor([input_ids], dtype=torch.long, device=device)
    target_tensor = torch.tensor([target_ids], dtype=torch.long, device=device)

    model.train()
    logits = model(input_tensor)
    logits = logits.view(-1, logits.size(-1))
    targets = target_tensor.view(-1)
    loss = loss_fn(logits, targets)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    model.eval()
    return loss.item()
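
# Minimal usage sketch, assuming a `tokenizers`-style tokenizer (with
# .encode(text).ids and .token_to_id) and a separate optimizer over the
# session-adapter parameters; the names below are illustrative only:
#
#   pad_id = tokenizer.token_to_id("<PAD>")
#   session_loss = nn.CrossEntropyLoss(ignore_index=pad_id)  # skip padded targets
#   session_optimizer = optim.Adam(
#       (p for p in model.parameters() if p.requires_grad), lr=1e-4
#   )
#   loss = online_unsupervised_update(
#       model, tokenizer, "some new session text", session_optimizer, session_loss, device
#   )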


def reset_session_adapters(model):
    # Zero every session-adapter parameter so the adapters contribute nothing
    # until the next session trains them again.
    with torch.no_grad():
        for layer in model.layers:
            adapter = getattr(layer, "session_adapter", None)
            if adapter is not None:
                for param in adapter.parameters():
                    param.zero_()
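
# Sketch of a session lifecycle. It assumes freeze_except_adapters also accepts
# session_only=True to unfreeze just the session adapters (only the
# session_only=False call appears above, so this variant is an assumption):
#
#   model.load_state_dict(torch.load("microformer.pt", map_location=device))
#   model.freeze_except_adapters(session_only=True, include_output=False)
#   for text in session_texts:  # hypothetical stream of session inputs
#       online_unsupervised_update(model, tokenizer, text,
#                                  session_optimizer, session_loss, device)
#   reset_session_adapters(model)  # discard session state when the session ends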