# Experiment 4

## Summary

Add a smooth L0 sparsity regularizer on the feature-transformer output during SmallNet training.

## Description

### Setup

#### nnue-pytorch

<details>

<summary>Base commit</summary>

[70781454bb9fbd5fbaf3d47281daa22a82565984](https://github.com/official-stockfish/nnue-pytorch/tree/70781454bb9fbd5fbaf3d47281daa22a82565984)

Try to catch one source of errors (#367)

</details>

<details>

<summary>Changes</summary>

```diff
diff --git a/model/config.py b/model/config.py
index c1eb5d8..aefd3ce 100644
--- a/model/config.py
+++ b/model/config.py
@@ -20,3 +20,4 @@ class LossParams:
     end_lambda: float = 1.0
     pow_exp: float = 2.5
     qp_asymmetry: float = 0.0
+    sparsity_l0: float = 0.0
diff --git a/model/lightning_module.py b/model/lightning_module.py
index 0000ba3..c50d656 100644
--- a/model/lightning_module.py
+++ b/model/lightning_module.py
@@ -70,19 +70,17 @@ class NNUE(L.LightningModule):
             layer_stack_indices,
         ) = batch
 
-        scorenet = (
-            self.model(
-                us,
-                them,
-                white_indices,
-                white_values,
-                black_indices,
-                black_values,
-                psqt_indices,
-                layer_stack_indices,
-            )
-            * self.model.quantization.nnue2score
+        scorenet_logits, ft = self.model(
+            us,
+            them,
+            white_indices,
+            white_values,
+            black_indices,
+            black_values,
+            psqt_indices,
+            layer_stack_indices
         )
+        scorenet = scorenet_logits * self.model.quantization.nnue2score
 
         p = self.loss_params
         # convert the network and search scores to an estimate match result
@@ -108,6 +106,11 @@ class NNUE(L.LightningModule):
             loss = loss * ((qf > pt) * p.qp_asymmetry + 1)
         loss = loss.mean()
 
+        if p.sparsity_l0 > 0.0:
+            alpha = 20.0
+            l0_proxy = (1.0 - torch.exp(-alpha * ft.abs())).mean()
+            loss += p.sparsity_l0 * l0_proxy
+
         self.log(loss_type, loss, prog_bar=True)
 
         return loss
diff --git a/model/model.py b/model/model.py
index 359db58..d63bf8b 100644
--- a/model/model.py
+++ b/model/model.py
@@ -344,4 +344,4 @@ class NNUEModel(nn.Module):
         # which does both the averaging and sign flip for black to move)
         x = self.layer_stacks(l0_, layer_stack_indices) + (wpsqt - bpsqt) * (us - 0.5)
 
-        return x
+        return (x, l0_)
diff --git a/train.py b/train.py
index 999b0c3..486a984 100644
--- a/train.py
+++ b/train.py
@@ -207,6 +207,13 @@ def main():
         dest="out_scaling",
         help="scaling for conversion to win on input (default=380.0)",
     )
+    parser.add_argument(
+        "--sparsity-l0",
+        default=0.0,
+        type=float,
+        dest="sparsity_l0",
+        help="L0 regularization factor"
+    )
     parser.add_argument(
         "--gamma",
         default=0.992,
@@ -377,6 +384,7 @@ def main():
         end_lambda=args.end_lambda or args.lambda_,
         pow_exp=args.pow_exp,
         qp_asymmetry=args.qp_asymmetry,
+        sparsity_l0=args.sparsity_l0
     )
     print("Loss parameters:")
     print(loss_params)
```

</details>
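The regularizer added in the `lightning_module.py` hunk is a smooth surrogate for the L0 norm of the feature-transformer output `ft`. Each activation `x` contributes `1 - exp(-alpha * |x|)`. That contribution is 0 at `x = 0` and approaches 1 once `|x|` is well above `1 / alpha`, which is 0.05 for `alpha = 20.0`.

Its derivative, `alpha * sign(x) * exp(-alpha * |x|)`, is close to `alpha` just next to zero and vanishes for large `|x|`. The term therefore mostly pushes small activations toward zero while barely affecting large ones. Because the proxy is a mean of values in `[0, 1)`, the term adds less than `sparsity_l0` to the loss.

Below is a minimal pure-Python sketch of the same expression, applied to a plain list instead of a tensor. The function names are illustrative and are not part of nnue-pytorch.

```python
import math


def smooth_l0(values, alpha=20.0):
    """Mean of 1 - exp(-alpha * |x|): a smooth stand-in for the fraction of non-zero entries."""
    return sum(1.0 - math.exp(-alpha * abs(x)) for x in values) / len(values)


def smooth_l0_grad(x, alpha=20.0):
    """Derivative of 1 - exp(-alpha * |x|) w.r.t. x (taken as 0 at x == 0, like torch.abs)."""
    if x == 0.0:
        return 0.0
    return alpha * math.copysign(1.0, x) * math.exp(-alpha * abs(x))


def exact_l0_fraction(values):
    """Fraction of entries that are exactly non-zero (the quantity being approximated)."""
    return sum(1 for x in values if x != 0.0) / len(values)


if __name__ == "__main__":
    acts = [0.0, 0.0, 0.01, 0.05, 0.3, 1.0]
    print(f"exact L0 fraction: {exact_l0_fraction(acts):.3f}")
    print(f"smooth L0 proxy:   {smooth_l0(acts):.3f}")
    # The gradient is large near zero and fades quickly as |x| grows.
    for x in (0.001, 0.05, 0.3):
        print(f"grad at {x}: {smooth_l0_grad(x):.4f}")
```

The proxy undercounts tiny non-zero activations (such as `0.01` above) and counts large ones as almost exactly 1. This is what lets the penalty be differentiable while still tracking the zero count.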

#### Training Script

{% code title="train.sh" %}

```sh
#!/bin/sh

case "$1" in
  *.ckpt)
    echo "Resuming from the checkpoint: $1"
    resume_option="--resume-from-checkpoint=$1"
    ;;
  *.pt)
    echo "Resuming from the model: $1"
    resume_option="--resume-from-model=$1"
    ;;
esac

python train.py \
  /data/linrock/dual-nnue/hse-v1/dfrc99-16tb7p-eval-filt-v2.min.high-simple-eval-1k.min-v2.binpack \
  /data/linrock/dual-nnue/hse-v1/leela96-filt-v2.min.high-simple-eval-1k.min-v2.binpack \
  /data/linrock/dual-nnue/hse-v1/test60-novdec2021-12tb7p-filter-v6-dd.min-mar2023.unmin.high-simple-eval-1k.min-v2.binpack \
  /data/linrock/dual-nnue/hse-v1/test77-nov2021-2tb7p.no-db.min.high-simple-eval-1k.min-v2.binpack \
  /data/linrock/dual-nnue/hse-v1/test77-dec2021-16tb7p.no-db.min.high-simple-eval-1k.min-v2.binpack \
  /data/linrock/dual-nnue/hse-v1/test77-jan2022-2tb7p.high-simple-eval-1k.min-v2.binpack \
  /data/linrock/dual-nnue/hse-v1/test78-jantomay2022-16tb7p-filter-v6-dd.min-mar2023.unmin.high-simple-eval-1k.min-v2.binpack \
  /data/linrock/dual-nnue/hse-v1/test78-juntosep2022-16tb7p-filter-v6-dd.min-mar2023.unmin.high-simple-eval-1k.min-v2.binpack \
  /data/linrock/dual-nnue/hse-v1/test79-apr2022-16tb7p.min.high-simple-eval-1k.min-v2.binpack \
  /data/linrock/dual-nnue/hse-v1/test79-may2022-16tb7p-filter-v6-dd.min-mar2023.unmin.high-simple-eval-1k.min-v2.binpack \
  /data/linrock/dual-nnue/hse-v1/test80-apr2022-16tb7p.min.high-simple-eval-1k.min-v2.binpack \
  /data/linrock/dual-nnue/hse-v1/test80-may2022-16tb7p.high-simple-eval-1k.min-v2.binpack \
  /data/linrock/dual-nnue/hse-v1/test80-jun2022-16tb7p-filter-v6-dd.min-mar2023.unmin.high-simple-eval-1k.min-v2.binpack \
  /data/linrock/dual-nnue/hse-v1/test80-jul2022-16tb7p.v6-dd.min.high-simple-eval-1k.min-v2.binpack \
  /data/linrock/dual-nnue/hse-v1/test80-sep2022-16tb7p-filter-v6-dd.min-mar2023.unmin.high-simple-eval-1k.min-v2.binpack \
  /data/linrock/dual-nnue/hse-v1/test80-nov2022-16tb7p-v6-dd.min.high-simple-eval-1k.min-v2.binpack \
  /data/linrock/test80-2022/test80-2022-08-aug-16tb7p.v6-dd.min.binpack \
  /data/linrock/test80-2022/test80-2022-10-oct-16tb7p.v6-dd.binpack \
  /data/linrock/test80-2022/test80-2022-12-dec-16tb7p.min.binpack \
  /data/linrock/test80-2023/test80-2023-01-jan-16tb7p.v6-sk20.min.binpack \
  /data/linrock/test80-2023/test80-2023-02-feb-16tb7p.v6-dd.min.binpack \
  /data/linrock/test80-2023/test80-2023-03-mar-2tb7p.v6-sk16.min.binpack \
  /data/linrock/test80-2023/test80-2023-04-apr-2tb7p.v6-sk16.min.binpack \
  /data/linrock/test80-2023/test80-2023-05-may-2tb7p.v6.min.binpack \
  /data/linrock/test80-2023/test80-2023-06-jun-2tb7p.min-v2.binpack \
  /data/linrock/test80-2023/test80-2023-07-jul-2tb7p.min-v2.binpack \
  /data/linrock/test80-2023/test80-2023-08-aug-2tb7p.v6.min.binpack \
  /data/linrock/test80-2023/test80-2023-09-sep-2tb7p.binpack \
  /data/linrock/test80-2023/test80-2023-10-oct-2tb7p.binpack \
  /data/linrock/test80-2023/test80-2023-11-nov-2tb7p.min-v2.binpack \
  /data/linrock/test80-2023/test80-2023-12-dec-2tb7p.min-v2.binpack \
  /data/linrock/test80-2024/test80-2024-01-jan-2tb7p.min-v2.v6.binpack \
  /data/linrock/test80-2024/test80-2024-02-feb-2tb7p.min-v2.v6.binpack \
  /data/linrock/test80-2024/test80-2024-03-mar-2tb7p.min-v2.v6.binpack \
  /data/linrock/test80-2024/test80-2024-04-apr-2tb7p.min-v2.v6.binpack \
  --threads=4                           \
  --gpus=0,                             \
  --max_epochs=800                      \
  --num-workers=64                      \
  --batch-size=16384                    \
  --features=HalfKAv2_hm^               \
  --l1=128                              \
  --no-wld-fen-skipping                 \
  --start-lambda=1.0                    \
  --end-lambda=0.6992605384421289       \
  --gamma=0.9942746303116422            \
  --lr=0.0012181558724738395            \
  --in-scaling=317.54982869522763       \
  --out-scaling=379.8040378799748       \
  --in-offset=269.50654301142134        \
  --out-offset=253.51483568025657       \
  --pow-exp=2.4829519732155125          \
  --sparsity-l0=5e-5                    \
  --random-fen-skipping=3               \
  --simple-eval-skipping=931            \
  --network-save-period=10              \
  --compile-backend=cudagraphs          \
  $resume_option
```

{% endcode %}

{% code title="train-finalize.sh" %}

```sh
#!/bin/sh

python3 serialize.py \
  --out-sha \
  --ft_optimize \
  --ft_optimize_data=/data/official-stockfish/master-binpacks/fishpack32.binpack \
  --ft_optimize_count=1000000 \
  --l1=128 \
  "$1" -
```

{% endcode %}
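The `--sparsity-l0=5e-5` option in `train.sh` reaches the loss through the `train.py` and `model/config.py` hunks. Argparse stores it under `dest="sparsity_l0"`, and it is then copied into `LossParams`. The default of `0.0` keeps the `if p.sparsity_l0 > 0.0:` branch disabled.

The self-contained sketch below reproduces only that plumbing. `LossParamsSketch` and `build_loss_params` are hypothetical names, and the real parser and `LossParams` carry many more fields that are omitted here.

```python
import argparse
from dataclasses import dataclass


@dataclass
class LossParamsSketch:
    # Hypothetical stand-in for LossParams in model/config.py, reduced to the new field.
    sparsity_l0: float = 0.0


def build_loss_params(argv):
    # Mirrors the --sparsity-l0 argument added to train.py (all other options omitted).
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--sparsity-l0",
        default=0.0,
        type=float,
        dest="sparsity_l0",
        help="L0 regularization factor",
    )
    args = parser.parse_args(argv)
    return LossParamsSketch(sparsity_l0=args.sparsity_l0)


if __name__ == "__main__":
    p = build_loss_params(["--sparsity-l0=5e-5"])
    print(p)
    print("regularizer enabled:", p.sparsity_l0 > 0.0)
```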

### Local Testing

* Fixed-length test of 10,000 games
* Openings: UHO\_Lichess\_4852\_v1
* Time control: 3+0.03
* Engine options: Threads=1, Hash=16

#### Epoch 400

```
Results of Small-FTSparse vs master (3+0.03, 1t, 16MB, UHO_Lichess_4852_v1.epd):
Elo: -1.77 +/- 3.60, nElo: -3.35 +/- 6.81
LOS: 16.75 %, DrawRatio: 50.14 %, PairsRatio: 0.95
Games: 10000, Wins: 2604, Losses: 2655, Draws: 4741, Points: 4974.5 (49.74 %)
Ptnml(0-2): [46, 1231, 2507, 1160, 56], WL/DD Ratio: 1.13
```
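The headline numbers above can be cross-checked from the `Ptnml(0-2)` counts, which are the numbers of game pairs scoring 0, 0.5, 1, 1.5 and 2 points. The sketch below makes two assumptions about how the tester computes these figures:

* Elo uses the usual logistic formula.
* LOS uses a normal approximation over pairs.

These assumptions are not taken from the tester's source, but together they reproduce Elo -1.77, LOS 16.75 %, DrawRatio 50.14 % and PairsRatio 0.95 for this run. `pentanomial_summary` is an illustrative helper, not part of any testing tool.

```python
import math


def pentanomial_summary(ptnml):
    """Summarize game-pair counts [LL, LD, DD/WL, WD, WW] (pair scores 0, 0.5, 1, 1.5, 2)."""
    n_pairs = sum(ptnml)
    # Per-game score of each pair outcome, i.e. the pair score divided by 2.
    scores = [0.0, 0.25, 0.5, 0.75, 1.0]
    mean = sum(c * s for c, s in zip(ptnml, scores)) / n_pairs
    var = sum(c * (s - mean) ** 2 for c, s in zip(ptnml, scores)) / n_pairs
    stderr = math.sqrt(var / n_pairs)

    # Logistic Elo difference implied by the mean score.
    elo = -400.0 * math.log10(1.0 / mean - 1.0)
    # Likelihood of superiority: P(true score > 0.5) under a normal approximation.
    los = 0.5 * (1.0 + math.erf((mean - 0.5) / (stderr * math.sqrt(2.0))))
    # Share of pairs that ended 1-1 (two draws, or one win and one loss).
    draw_ratio = ptnml[2] / n_pairs
    # Winning pairs divided by losing pairs.
    pairs_ratio = (ptnml[3] + ptnml[4]) / (ptnml[0] + ptnml[1])
    return {"score": mean, "elo": elo, "los": los,
            "draw_ratio": draw_ratio, "pairs_ratio": pairs_ratio}


if __name__ == "__main__":
    r = pentanomial_summary([46, 1231, 2507, 1160, 56])
    print(f"Elo: {r['elo']:.2f}, LOS: {r['los']:.2%}, "
          f"DrawRatio: {r['draw_ratio']:.2%}, PairsRatio: {r['pairs_ratio']:.2f}")
```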

## Note

* The feature-transformer outputs became sparser (master 30.29%, test 39.21%), yet a \~1% slowdown was measured. Apparently the extra overhead outweighs the savings from filtering out zero blocks and accumulating only the non-zero ones. Combined with the small loss in strength (Elo -1.77 +/- 3.60 at epoch 400), this makes the approach look unpromising, at least for SmallNet.
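A toy example helps show why higher element-level sparsity need not translate into a speedup. A sparse-input affine layer typically skips work only for input blocks that are entirely zero, and it also pays for scanning the input to find the non-zero blocks.

The sketch below counts both fractions for a toy activation vector. It assumes a block size of 4 (four 8-bit activations per 32-bit chunk) for Stockfish's sparse-input layer, which is an assumption rather than a claim about the actual implementation. `zero_fractions` is a hypothetical helper.

```python
def zero_fractions(acts, block=4):
    """Return (fraction of zero elements, fraction of all-zero blocks of `block` elements)."""
    if len(acts) % block:
        raise ValueError("length must be a multiple of the block size")
    elem_zero = sum(1 for a in acts if a == 0) / len(acts)
    blocks = [acts[i:i + block] for i in range(0, len(acts), block)]
    # Only blocks with no non-zero entry can be skipped entirely.
    block_zero = sum(1 for b in blocks if not any(b)) / len(blocks)
    return elem_zero, block_zero


if __name__ == "__main__":
    # Toy 8-bit activations: zeros scattered across blocks rarely empty a whole block.
    acts = [0, 5, 0, 0, 0, 0, 0, 0, 3, 0, 0, 7, 0, 0, 9, 0]
    elem, blk = zero_fractions(acts)
    print(f"zero elements: {elem:.0%}, all-zero blocks: {blk:.0%}")
```

In this toy vector, three quarters of the elements are zero but only one block in four can be skipped. Any extra zeros the regularizer creates help only when they complete a block.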


