Experiment 4

Summary

Add a smooth L0 regularization term on the feature transformer activations during SmallNet training.

Description

Setup

nnue-pytorch

Base commit

70781454bb9fbd5fbaf3d47281daa22a82565984

Try to catch one source of errors (#367)

Changes
diff --git a/model/config.py b/model/config.py
index c1eb5d8..aefd3ce 100644
--- a/model/config.py
+++ b/model/config.py
@@ -20,3 +20,4 @@ class LossParams:
     end_lambda: float = 1.0
     pow_exp: float = 2.5
     qp_asymmetry: float = 0.0
+    sparsity_l0: float = 0.0
diff --git a/model/lightning_module.py b/model/lightning_module.py
index 0000ba3..c50d656 100644
--- a/model/lightning_module.py
+++ b/model/lightning_module.py
@@ -70,19 +70,17 @@ class NNUE(L.LightningModule):
             layer_stack_indices,
         ) = batch
 
-        scorenet = (
-            self.model(
-                us,
-                them,
-                white_indices,
-                white_values,
-                black_indices,
-                black_values,
-                psqt_indices,
-                layer_stack_indices,
-            )
-            * self.model.quantization.nnue2score
+        scorenet_logits, ft = self.model(
+            us,
+            them,
+            white_indices,
+            white_values,
+            black_indices,
+            black_values,
+            psqt_indices,
+            layer_stack_indices
         )
+        scorenet = scorenet_logits * self.model.quantization.nnue2score
 
         p = self.loss_params
         # convert the network and search scores to an estimate match result
@@ -108,6 +106,11 @@ class NNUE(L.LightningModule):
             loss = loss * ((qf > pt) * p.qp_asymmetry + 1)
         loss = loss.mean()
 
+        if p.sparsity_l0 > 0.0:
+            alpha = 20.0
+            l0_proxy = (1.0 - torch.exp(-alpha * ft.abs())).mean()
+            loss += p.sparsity_l0 * l0_proxy
+
         self.log(loss_type, loss, prog_bar=True)
 
         return loss
diff --git a/model/model.py b/model/model.py
index 359db58..d63bf8b 100644
--- a/model/model.py
+++ b/model/model.py
@@ -344,4 +344,4 @@ class NNUEModel(nn.Module):
         # which does both the averaging and sign flip for black to move)
         x = self.layer_stacks(l0_, layer_stack_indices) + (wpsqt - bpsqt) * (us - 0.5)
 
-        return x
+        return (x, l0_)
diff --git a/train.py b/train.py
index 999b0c3..486a984 100644
--- a/train.py
+++ b/train.py
@@ -207,6 +207,13 @@ def main():
         dest="out_scaling",
         help="scaling for conversion to win on input (default=380.0)",
     )
+    parser.add_argument(
+        "--sparsity-l0",
+        default=0.0,
+        type=float,
+        dest="sparsity_l0",
+        help="L0 regularization factor"
+    )
     parser.add_argument(
         "--gamma",
         default=0.992,
@@ -377,6 +384,7 @@ def main():
         end_lambda=args.end_lambda or args.lambda_,
         pow_exp=args.pow_exp,
         qp_asymmetry=args.qp_asymmetry,
+        sparsity_l0=args.sparsity_l0
     )
     print("Loss parameters:")
     print(loss_params)

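For reference, the added penalty is the mean of 1 - exp(-alpha * |x|) over the feature transformer activations: the expression is 0 at x = 0 and saturates towards 1 once |x| >> 1/alpha, so its mean acts as a differentiable proxy for the fraction of non-zero activations (an "L0 norm"). A minimal standalone sketch of the term, with illustrative names rather than the ones from the training code:

import torch

def smooth_l0(ft: torch.Tensor, alpha: float = 20.0) -> torch.Tensor:
    # 1 - exp(-alpha * |x|) is 0 at x == 0 and close to 1 for |x| >> 1/alpha,
    # so the mean approximates the fraction of non-zero activations.
    return (1.0 - torch.exp(-alpha * ft.abs())).mean()

# Combine with the main loss, as the patch above does:
ft = torch.rand(16384, 128)      # illustrative batch of feature transformer outputs
main_loss = torch.tensor(0.1)    # placeholder for the evaluation loss
sparsity_l0 = 5e-5               # coefficient passed via --sparsity-l0 below
loss = main_loss + sparsity_l0 * smooth_l0(ft)
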
Training Script

train.sh
#!/bin/sh

case "$1" in
  *.ckpt)
    echo "Resuming from the checkpoint: $1"
    resume_option="--resume-from-checkpoint=$1"
    ;;
  *.pt)
    echo "Resuming from the model: $1"
    resume_option="--resume-from-model=$1"
    ;;
esac

python train.py \
  /data/linrock/dual-nnue/hse-v1/dfrc99-16tb7p-eval-filt-v2.min.high-simple-eval-1k.min-v2.binpack \
  /data/linrock/dual-nnue/hse-v1/leela96-filt-v2.min.high-simple-eval-1k.min-v2.binpack \
  /data/linrock/dual-nnue/hse-v1/test60-novdec2021-12tb7p-filter-v6-dd.min-mar2023.unmin.high-simple-eval-1k.min-v2.binpack \
  /data/linrock/dual-nnue/hse-v1/test77-nov2021-2tb7p.no-db.min.high-simple-eval-1k.min-v2.binpack \
  /data/linrock/dual-nnue/hse-v1/test77-dec2021-16tb7p.no-db.min.high-simple-eval-1k.min-v2.binpack \
  /data/linrock/dual-nnue/hse-v1/test77-jan2022-2tb7p.high-simple-eval-1k.min-v2.binpack \
  /data/linrock/dual-nnue/hse-v1/test78-jantomay2022-16tb7p-filter-v6-dd.min-mar2023.unmin.high-simple-eval-1k.min-v2.binpack \
  /data/linrock/dual-nnue/hse-v1/test78-juntosep2022-16tb7p-filter-v6-dd.min-mar2023.unmin.high-simple-eval-1k.min-v2.binpack \
  /data/linrock/dual-nnue/hse-v1/test79-apr2022-16tb7p.min.high-simple-eval-1k.min-v2.binpack \
  /data/linrock/dual-nnue/hse-v1/test79-may2022-16tb7p-filter-v6-dd.min-mar2023.unmin.high-simple-eval-1k.min-v2.binpack \
  /data/linrock/dual-nnue/hse-v1/test80-apr2022-16tb7p.min.high-simple-eval-1k.min-v2.binpack \
  /data/linrock/dual-nnue/hse-v1/test80-may2022-16tb7p.high-simple-eval-1k.min-v2.binpack \
  /data/linrock/dual-nnue/hse-v1/test80-jun2022-16tb7p-filter-v6-dd.min-mar2023.unmin.high-simple-eval-1k.min-v2.binpack \
  /data/linrock/dual-nnue/hse-v1/test80-jul2022-16tb7p.v6-dd.min.high-simple-eval-1k.min-v2.binpack \
  /data/linrock/dual-nnue/hse-v1/test80-sep2022-16tb7p-filter-v6-dd.min-mar2023.unmin.high-simple-eval-1k.min-v2.binpack \
  /data/linrock/dual-nnue/hse-v1/test80-nov2022-16tb7p-v6-dd.min.high-simple-eval-1k.min-v2.binpack \
  /data/linrock/test80-2022/test80-2022-08-aug-16tb7p.v6-dd.min.binpack \
  /data/linrock/test80-2022/test80-2022-10-oct-16tb7p.v6-dd.binpack \
  /data/linrock/test80-2022/test80-2022-12-dec-16tb7p.min.binpack \
  /data/linrock/test80-2023/test80-2023-01-jan-16tb7p.v6-sk20.min.binpack \
  /data/linrock/test80-2023/test80-2023-02-feb-16tb7p.v6-dd.min.binpack \
  /data/linrock/test80-2023/test80-2023-03-mar-2tb7p.v6-sk16.min.binpack \
  /data/linrock/test80-2023/test80-2023-04-apr-2tb7p.v6-sk16.min.binpack \
  /data/linrock/test80-2023/test80-2023-05-may-2tb7p.v6.min.binpack \
  /data/linrock/test80-2023/test80-2023-06-jun-2tb7p.min-v2.binpack \
  /data/linrock/test80-2023/test80-2023-07-jul-2tb7p.min-v2.binpack \
  /data/linrock/test80-2023/test80-2023-08-aug-2tb7p.v6.min.binpack \
  /data/linrock/test80-2023/test80-2023-09-sep-2tb7p.binpack \
  /data/linrock/test80-2023/test80-2023-10-oct-2tb7p.binpack \
  /data/linrock/test80-2023/test80-2023-11-nov-2tb7p.min-v2.binpack \
  /data/linrock/test80-2023/test80-2023-12-dec-2tb7p.min-v2.binpack \
  /data/linrock/test80-2024/test80-2024-01-jan-2tb7p.min-v2.v6.binpack \
  /data/linrock/test80-2024/test80-2024-02-feb-2tb7p.min-v2.v6.binpack \
  /data/linrock/test80-2024/test80-2024-03-mar-2tb7p.min-v2.v6.binpack \
  /data/linrock/test80-2024/test80-2024-04-apr-2tb7p.min-v2.v6.binpack \
  --threads=4                           \
  --gpus=0,                             \
  --max_epochs=800                      \
  --num-workers=64                      \
  --batch-size=16384                    \
  --features=HalfKAv2_hm^               \
  --l1=128                              \
  --no-wld-fen-skipping                 \
  --start-lambda=1.0                    \
  --end-lambda=0.6992605384421289       \
  --gamma=0.9942746303116422            \
  --lr=0.0012181558724738395            \
  --in-scaling=317.54982869522763       \
  --out-scaling=379.8040378799748       \
  --in-offset=269.50654301142134        \
  --out-offset=253.51483568025657       \
  --pow-exp=2.4829519732155125          \
  --sparsity-l0=5e-5                    \
  --random-fen-skipping=3               \
  --simple-eval-skipping=931            \
  --network-save-period=10              \
  --compile-backend=cudagraphs          \
  $resume_option

Local Testing

  • 10,000 fixed games

  • Openings: UHO_Lichess_4852_v1

  • Time control: 3+0.03

  • Engine options: Threads=1, Hash=16

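Results of such fixed-game matches are typically summarized as an Elo difference. The match runner used here is not stated; below is a minimal sketch of the standard conversion from a win/draw/loss record to an Elo estimate (the function name is illustrative):

import math

def elo_diff(wins: int, draws: int, losses: int) -> float:
    # Logistic Elo model: a score s maps to 400 * log10(s / (1 - s)).
    games = wins + draws + losses
    score = (wins + 0.5 * draws) / games
    return 400.0 * math.log10(score / (1.0 - score))

# Example: a 10,000-game match scored at 50.5% corresponds to about +3.5 Elo.
print(round(elo_diff(2550, 5000, 2450), 1))
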
Epoch 400

Note

  • Feature activations are sparser (master 30.29%, test 39.21%), yet a ~1% slowdown is measured: the overhead of filtering and accumulating only the non-zero blocks apparently exceeds the benefit. Combined with the small loss in accuracy, this approach does not look promising, at least for the SmallNet. (A sketch of how the sparsity percentage could be measured follows below.)

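The figures above presumably refer to the fraction of feature transformer outputs that are exactly zero after clipping; a minimal sketch of collecting such a percentage from a batch of activations (names are illustrative, not the measurement code actually used):

import torch

def zero_activation_pct(ft: torch.Tensor) -> float:
    # Share of clipped feature transformer outputs that are exactly zero;
    # a higher value means sparser activations.
    return 100.0 * (ft == 0).float().mean().item()

ft = torch.clamp(torch.randn(16384, 128), 0.0, 1.0)  # illustrative clipped activations
print(f"sparsity: {zero_activation_pct(ft):.2f}%")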