Organisations / ahrefs / ocannl / e7fe9c (lint-fmt)

(lint-fmt)

Link Copied
Code Copied

Logs

2025-12-15 21:43.22: New job: test ahrefs/ocannl https://github.com/ahrefs/ocannl.git#refs/heads/master (e7fe9c6cd69d74f0883f22c232c81a0dcbc378ab) (linux-x86_64:(lint-fmt))
Base: ocaml/opam:debian-13-ocaml-4.08@sha256:6fadef23b5069dc945f3a454c49421fd09e8c17aa57d3f9ad27d3879fce6aa44
ocamlformat version: version 0.28.1 (from opam)


To reproduce locally:


git clone --recursive "https://github.com/ahrefs/ocannl.git" -b "master" && cd "ocannl" && git reset --hard e7fe9c6c
cat > Dockerfile <<'END-OF-DOCKERFILE'
FROM ocaml/opam:debian-13-ocaml-4.08@sha256:6fadef23b5069dc945f3a454c49421fd09e8c17aa57d3f9ad27d3879fce6aa44
USER 1000:1000
RUN cd ~/opam-repository && (git cat-file -e 6c1b38620288b5bf349067f089a7b1fc91185d94 || git fetch origin master) && git reset -q --hard 6c1b38620288b5bf349067f089a7b1fc91185d94 && git log --no-decorate -n1 --oneline && opam update -u
RUN opam depext -i dune
WORKDIR /src
RUN opam depext -i ocamlformat=0.28.1
COPY --chown=1000:1000 . /src/
RUN opam exec -- dune build @fmt --ignore-promoted-rules || (echo "dune build @fmt failed"; exit 2)


END-OF-DOCKERFILE
docker build .
END-REPRO-BLOCK


2025-12-15 21:43.22: Using cache hint "ahrefs/ocannl-ocaml/opam:debian-13-ocaml-4.08@sha256:6fadef23b5069dc945f3a454c49421fd09e8c17aa57d3f9ad27d3879fce6aa44-debian-13-4.08_opam-2.4-ocamlformat-6c1b38620288b5bf349067f089a7b1fc91185d94"
2025-12-15 21:43.22: Using OBuilder spec:
((from ocaml/opam:debian-13-ocaml-4.08@sha256:6fadef23b5069dc945f3a454c49421fd09e8c17aa57d3f9ad27d3879fce6aa44)
(user (uid 1000) (gid 1000))
(run (cache (opam-archives (target /home/opam/.opam/download-cache)))
(network host)
(shell "cd ~/opam-repository && (git cat-file -e 6c1b38620288b5bf349067f089a7b1fc91185d94 || git fetch origin master) && git reset -q --hard 6c1b38620288b5bf349067f089a7b1fc91185d94 && git log --no-decorate -n1 --oneline && opam update -u"))
(run (cache (opam-archives (target /home/opam/.opam/download-cache)))
(network host)
(shell "opam depext -i dune"))
(workdir /src)
(run (cache (opam-archives (target /home/opam/.opam/download-cache)))
(network host)
(shell "opam depext -i ocamlformat=0.28.1"))
(copy (src .) (dst /src/))
(run (shell "opam exec -- dune build @fmt --ignore-promoted-rules || (echo \"dune build @fmt failed\"; exit 2)"))
)


2025-12-15 21:43.22: Waiting for resource in pool OCluster
2025-12-15 21:43.22: Waiting for worker…
2025-12-15 21:44.02: Got resource from pool OCluster
Building on laodoke.caelum.ci.dev
HEAD is now at 6b0d7d46 Document roots, embedded nodes, and params concepts
HEAD is now at e7fe9c6c Add test demonstrating padding initialization bug for max-pool


(from ocaml/opam:debian-13-ocaml-4.08@sha256:6fadef23b5069dc945f3a454c49421fd09e8c17aa57d3f9ad27d3879fce6aa44)
Unable to find image 'ocaml/opam:debian-13-ocaml-4.08@sha256:6fadef23b5069dc945f3a454c49421fd09e8c17aa57d3f9ad27d3879fce6aa44' locally
docker.io/ocaml/opam@sha256:6fadef23b5069dc945f3a454c49421fd09e8c17aa57d3f9ad27d3879fce6aa44: Pulling from ocaml/opam
2981f7e8980b: Already exists
9c63e1c4ba84: Already exists
02578b9c9f1b: Already exists
3801cb7ba5e6: Already exists
1c4cdedd39f1: Already exists
40f6006c5f5a: Already exists
c451a17216ec: Already exists
e4104b8f72ee: Already exists
7c7dbc5e7919: Already exists
aa82be714e7c: Already exists
922520f12384: Already exists
9bfea042cef8: Already exists
1244979f7c21: Already exists
7583a0e34f94: Already exists
14bef3f2665a: Already exists
3a4b42ce6cb5: Already exists
b0a08a900877: Already exists
b0a08a900877: Already exists
16ac87e68d60: Already exists
74ac6e8c9b82: Already exists
b41290a57dc5: Already exists
592e5bcb7159: Already exists
4f4fb700ef54: Already exists
1db0705661a3: Already exists
76bb8e35b9cc: Already exists
3bc9d98c3b49: Already exists
7096ef42e6c0: Already exists
cd0e70de8125: Already exists
23ffab57f98e: Already exists
b974353d8023: Already exists
c74fba566723: Already exists
5ad399846f45: Already exists
3e7203fa3980: Already exists
0b8915d2a92b: Already exists
e68c7a56c438: Already exists
9024d680a167: Already exists
3c7c73421b6e: Already exists
7134fa9b4278: Already exists
9d94fb523099: Already exists
20d06dbdae7e: Already exists
b9a45e537661: Already exists
d013a1d2f205: Pulling fs layer
b9e11c34eba5: Pulling fs layer
e0af6a95398b: Pulling fs layer
b559598791bf: Pulling fs layer
b559598791bf: Waiting
b9e11c34eba5: Verifying Checksum
b9e11c34eba5: Download complete
e0af6a95398b: Verifying Checksum
e0af6a95398b: Download complete
b559598791bf: Verifying Checksum
b559598791bf: Download complete
d013a1d2f205: Verifying Checksum
d013a1d2f205: Download complete
d013a1d2f205: Pull complete
b9e11c34eba5: Pull complete
e0af6a95398b: Pull complete
b559598791bf: Pull complete
Digest: sha256:6fadef23b5069dc945f3a454c49421fd09e8c17aa57d3f9ad27d3879fce6aa44
Status: Downloaded newer image for ocaml/opam@sha256:6fadef23b5069dc945f3a454c49421fd09e8c17aa57d3f9ad27d3879fce6aa44
2025-12-15 21:47.07 ---> using "c557567380599a9d74e9cba757661af503505628fe0a86c2e1639761275a17b1" from cache


/: (user (uid 1000) (gid 1000))


/: (run (cache (opam-archives (target /home/opam/.opam/download-cache)))
(network host)
(shell "cd ~/opam-repository && (git cat-file -e 6c1b38620288b5bf349067f089a7b1fc91185d94 || git fetch origin master) && git reset -q --hard 6c1b38620288b5bf349067f089a7b1fc91185d94 && git log --no-decorate -n1 --oneline && opam update -u"))
6c1b386202 Merge pull request #28774 from Julow/release-ocamlformat-0.28.1


<><> Updating package repositories ><><><><><><><><><><><><><><><><><><><><><><>
[default] Initialised
default (at git+file:///home/opam/opam-repository):
[INFO] opam 2.1 and 2.2 include many performance and security improvements over 2.0; please consider upgrading (https://opam.ocaml.org/doc/Install.html)


Everything as up-to-date as possible (run with --verbose to show unavailable upgrades).
However, you may "opam upgrade" these packages explicitly, which will ask permission to downgrade or uninstall the conflicting packages.
Nothing to do.
# Run eval $(opam env) to update the current shell environment
2025-12-15 21:49.33 ---> saved as "1258adf89d25757b915a8f4f5f0ba0eb7df8555636ae304d68c749c221ea21e4"


/: (run (cache (opam-archives (target /home/opam/.opam/download-cache)))
(network host)
(shell "opam depext -i dune"))
# Detecting depexts using vars: arch=x86_64, os=linux, os-distribution=debian, os-family=debian
# No extra OS packages requirements found.
# All required OS packages found.
# Now letting opam install the packages
The following actions will be performed:
- install dune 3.20.2


<><> Gathering sources ><><><><><><><><><><><><><><><><><><><><><><><><><><><><>
[dune.3.20.2] found in cache


<><> Processing actions <><><><><><><><><><><><><><><><><><><><><><><><><><><><>
-> installed dune.3.20.2
Done.
# Run eval $(opam env) to update the current shell environment
2025-12-15 21:51.25 ---> saved as "cf5c80eb98554726a678ce73db8f3dbc8a3b1c7af6d5a2f454d5ed3097401b40"


/: (workdir /src)


/src: (run (cache (opam-archives (target /home/opam/.opam/download-cache)))
(network host)
(shell "opam depext -i ocamlformat=0.28.1"))
# Detecting depexts using vars: arch=x86_64, os=linux, os-distribution=debian, os-family=debian
# No extra OS packages requirements found.
# All required OS packages found.
# Now letting opam install the packages
The following actions will be performed:
- install sexplib0          v0.14.0  [required by base]
- install ocamlbuild        0.16.1   [required by fpath, astring, uuseg]
- install either            1.0.0    [required by ocamlformat-lib]
- install menhirLib         20250912 [required by ocamlformat-lib]
- install cmdliner          2.1.0    [required by ocamlformat]
- install csexp             1.5.2    [required by ocamlformat]
- install camlp-streams     5.0.1    [required by ocamlformat-lib]
- install seq               base     [required by re]
- install menhirSdk         20250912 [required by ocamlformat-lib]
- install fix               20250919 [required by ocamlformat-lib]
- install menhirCST         20250912 [required by menhir]
- install ocamlfind         1.9.8    [required by ocp-indent, astring, fpath, uuseg]
- install dune-build-info   3.20.2   [required by ocamlformat-lib]
- install ocaml-version     4.0.3    [required by ocamlformat-lib]
- install dune-configurator 3.20.2   [required by base]
- install re                1.11.0   [required by ocamlformat]
- install menhir            20250912 [required by ocamlformat-lib]
- install topkg             1.1.1    [required by fpath, astring, uuseg]
- install ocp-indent        1.9.0    [required by ocamlformat-lib]
- install base              v0.14.3  [required by ocamlformat-lib]
- install uutf              1.0.4    [required by ocamlformat-lib]
- install astring           0.8.5    [required by ocamlformat-lib]
- install stdio             v0.14.0  [required by ocamlformat-lib]
- install uucp              15.0.0   [required by uuseg]
- install fpath             0.7.3    [required by ocamlformat-lib]
- install uuseg             15.0.0   [required by ocamlformat-lib]
- install ocamlformat-lib   0.28.1   [required by ocamlformat]
- install ocamlformat       0.28.1
===== 28 to install =====


<><> Gathering sources ><><><><><><><><><><><><><><><><><><><><><><><><><><><><>
[astring.0.8.5] found in cache
[base.v0.14.3] found in cache
[camlp-streams.5.0.1] found in cache
[cmdliner.2.1.0] found in cache
[csexp.1.5.2] found in cache
[dune-build-info.3.20.2] found in cache
[dune-configurator.3.20.2] found in cache
[either.1.0.0] found in cache
[fix.20250919] found in cache
[fpath.0.7.3] found in cache
[menhir.20250912] found in cache
[menhirCST.20250912] found in cache
[menhirLib.20250912] found in cache
[menhirSdk.20250912] found in cache
[ocaml-version.4.0.3] found in cache
[ocamlbuild.0.16.1] found in cache
[ocamlfind.1.9.8] found in cache
[ocamlformat.0.28.1] found in cache
[ocamlformat-lib.0.28.1] found in cache
[ocp-indent.1.9.0] found in cache
[re.1.11.0] found in cache
[sexplib0.v0.14.0] found in cache
[stdio.v0.14.0] found in cache
[topkg.1.1.1] found in cache
[uucp.15.0.0] found in cache
[uuseg.15.0.0] found in cache
[uutf.1.0.4] found in cache


<><> Processing actions <><><><><><><><><><><><><><><><><><><><><><><><><><><><>
-> installed seq.base
-> installed camlp-streams.5.0.1
-> installed csexp.1.5.2
-> installed either.1.0.0
-> installed fix.20250919
-> installed menhirCST.20250912
-> installed menhirLib.20250912
-> installed menhirSdk.20250912
-> installed ocaml-version.4.0.3
-> installed sexplib0.v0.14.0
-> installed re.1.11.0
-> installed cmdliner.2.1.0
-> installed dune-build-info.3.20.2
-> installed ocamlfind.1.9.8
-> installed dune-configurator.3.20.2
-> installed ocp-indent.1.9.0
-> installed ocamlbuild.0.16.1
-> installed base.v0.14.3
-> installed topkg.1.1.1
-> installed stdio.v0.14.0
-> installed uutf.1.0.4
-> installed astring.0.8.5
-> installed fpath.0.7.3
-> installed menhir.20250912
-> installed uucp.15.0.0
-> installed uuseg.15.0.0
-> installed ocamlformat-lib.0.28.1
-> installed ocamlformat.0.28.1
Done.


<><> ocp-indent.1.9.0 installed successfully ><><><><><><><><><><><><><><><><><>
=> This package requires additional configuration for use in editors. Install package 'user-setup', or manually:


* for Emacs, add these lines to ~/.emacs:
(add-to-list 'load-path "/home/opam/.opam/4.08/share/emacs/site-lisp")
(require 'ocp-indent)


* for Vim, add this line to ~/.vimrc:
set rtp^="/home/opam/.opam/4.08/share/ocp-indent/vim"
# Run eval $(opam env) to update the current shell environment
2025-12-15 21:53.37 ---> saved as "f13c932b5025e0eacc8744367d69770e9cfb3357c0e6c9295b596339590e56ac"


/src: (copy (src .) (dst /src/))
2025-12-15 21:53.38 ---> saved as "1a7f8d9956976f16f69371be759ea811ae6045fa0e1b33715300d3c637eaa16b"


/src: (run (shell "opam exec -- dune build @fmt --ignore-promoted-rules || (echo \"dune build @fmt failed\"; exit 2)"))
Warning: Invalid documentation comment:
File "tensor/einsum_types.ml", line 38, characters 0-0:
End of text is not allowed in '[...]' (code).
File "datasets/rand.ml", line 1, characters 0-0:
diff --git a/_build/default/datasets/rand.ml b/_build/default/datasets/.formatted/rand.ml
index 22f8f7f..84ab9a6 100644
--- a/_build/default/datasets/rand.ml
+++ b/_build/default/datasets/.formatted/rand.ml
@@ -24,6 +24,7 @@ module Random_for_tests : Random = struct
(raw /. 10000. *. (high -. low)) +. low


let char () = Char.chr @@ Int32.(to_int @@ rem (rand_int32 ()) 256l)
+
let int high =
(* Use abs to handle negative random values from xor-shift RNG *)
Int32.(to_int @@ rem (abs (rand_int32 ())) @@ of_int high)
File "tensor/einsum_types.ml", line 1, characters 0-0:
diff --git a/_build/default/tensor/einsum_types.ml b/_build/default/tensor/.formatted/einsum_types.ml
index 084d9ac..e357e66 100644
--- a/_build/default/tensor/einsum_types.ml
+++ b/_build/default/tensor/.formatted/einsum_types.ml
@@ -4,18 +4,16 @@


open Base


-(** Use_padding specification for convolutions. *)
type use_padding_spec = [ `True | `False | `Unspecified ] [@@deriving compare, sexp]
+(** Use_padding specification for convolutions. *)


-(** Convolution component for affine axis specifications.
-    Note: [dilation] is a string because it can be an identifier at parse time,
-    and is resolved to an int at runtime. *)
type conv_spec = { dilation : string; kernel_label : string; use_padding : use_padding_spec }
[@@deriving compare, sexp]
+(** Convolution component for affine axis specifications. Note: [dilation] is a string because it
+    can be an identifier at parse time, and is resolved to an int at runtime. *)


-(** Specification for individual axes in the einsum notation.
-    Note: [stride] is a string because it can be an identifier at parse time,
-    and is resolved to an int at runtime. *)
+(** Specification for individual axes in the einsum notation. Note: [stride] is a string because it
+    can be an identifier at parse time, and is resolved to an int at runtime. *)
type axis_spec =
| Label of string  (** A variable axis label. *)
| Fixed_index of int  (** A fixed index, used for projection. *)
@@ -25,8 +23,8 @@ type axis_spec =
conv : conv_spec option;  (** Optional convolution: dilation*kernel. *)
stride_offset : int;  (** Constant offset added after stride*over. *)
}
-      (** Affine axis specification: stride*over + stride_offset [+ dilation*kernel].
-          Corresponds to [Row.Affine] in shape inference. *)
+      (** Affine axis specification: stride*over + stride_offset [+ dilation*kernel]. Corresponds to
+          [Row.Affine] in shape inference. *)
[@@deriving compare, sexp]


(** An index pointing to any of a shape's axes, including the kind of the axis ([Batch, Input,
@@ -75,8 +73,8 @@ type parsed_axis_labels = {
(** The labels are strings assigned to [AxisKey] axes. Moreover the [bcast_] fields represent
whether additional leading/middle axes are allowed (corresponding to the dot-ellipsis syntax for
broadcasting). The string can be used to identify a row variable, and defaults to ["batch"],
-    ["input"], ["output"] respectively when parsing ["..."]. The [given_] fields are lists of
-    axis specs of the corresponding kind in [labels] where [from_end=true], [given_beg_] where
+    ["input"], ["output"] respectively when parsing ["..."]. The [given_] fields are lists of axis
+    specs of the corresponding kind in [labels] where [from_end=true], [given_beg_] where
[from_end=false]. *)


let axis_labels parsed = parsed.labels
File "datasets/circles.ml", line 1, characters 0-0:
diff --git a/_build/default/datasets/circles.ml b/_build/default/datasets/.formatted/circles.ml
index 1c640a3..4fae3df 100644
--- a/_build/default/datasets/circles.ml
+++ b/_build/default/datasets/.formatted/circles.ml
@@ -11,21 +11,19 @@ module Config = struct
seed : int option;  (** Optional random seed for reproducibility *)
}


-  let default =
-    { image_size = 32; max_radius = 8; min_radius = 2; max_circles = 5; seed = None }
+  let default = { image_size = 32; max_radius = 8; min_radius = 2; max_circles = 5; seed = None }
end


module Random = Rand.Random_for_tests


-(** Draw a filled circle on the image at (cx, cy) with radius r.
-    Values are clamped to [0, 1] range. *)
+(** Draw a filled circle on the image at (cx, cy) with radius r. Values are clamped to [0, 1] range.
+*)
let draw_circle ~image_size image cx cy r =
for y = 0 to image_size - 1 do
for x = 0 to image_size - 1 do
let dx = x - cx in
let dy = y - cy in
-      if (dx * dx) + (dy * dy) <= r * r then
-        Genarray.set image [| y; x; 0 |] 1.0
+      if (dx * dx) + (dy * dy) <= r * r then Genarray.set image [| y; x; 0 |] 1.0
done
done


@@ -36,7 +34,8 @@ let draw_circle ~image_size image cx cy r =
@param len Number of images to generate
@return
A tuple of (images, labels) where:
-      - images is a bigarray of shape [len; image_size; image_size; 1] (batch, height, width, channels)
+      - images is a bigarray of shape [len; image_size; image_size; 1] (batch, height, width,
+        channels)
- labels is a bigarray of shape [len; 1] (batch, output) containing the circle count *)
let generate_with_kind kind ?(config = Config.default) ~len () =
(match config.seed with Some seed -> Random.init seed | None -> ());
File "tensor/shape.mli", line 1, characters 0-0:
diff --git a/_build/default/tensor/shape.mli b/_build/default/tensor/.formatted/shape.mli
index d7533f0..b3e6eba 100644
--- a/_build/default/tensor/shape.mli
+++ b/_build/default/tensor/.formatted/shape.mli
@@ -49,8 +49,9 @@


Adding [<] after the output label (e.g., [stride*output<+kernel]) indicates no-padding mode,
where indices must stay within the input bounds. In this mode, the input dimension must satisfy:
-    [(input - effective_kernel_span) mod stride = 0], where [effective_kernel_span = 1 + (kernel - 1) * dilation].
-    Without [<], padding is applied and there is no such divisibility constraint.
+    [(input - effective_kernel_span) mod stride = 0], where
+    [effective_kernel_span = 1 + (kernel - 1) * dilation]. Without [<], padding is applied and there
+    is no such divisibility constraint.


Note: currently, OCANNL shapes always allow broadcasting. Row variables track the broadcasted
axes -- if there is no row variable, broadcasted axes are not tracked. In the notation case
@@ -242,9 +243,9 @@ val to_padding : t -> (Ir.Ops.axis_padding array * float) option
val propagate_shapes : update_step -> unit


val get_projections : update_step -> Ir.Indexing.projections
-(** Returns the projections for this update step, computing them if not already done.
-    This triggers [finish_inference] and then retrieves the projections from
-    [unsafe_projections]. Use this instead of [derive_projections] directly. *)
+(** Returns the projections for this update step, computing them if not already done. This triggers
+    [finish_inference] and then retrieves the projections from [unsafe_projections]. Use this
+    instead of [derive_projections] directly. *)


val of_spec : ?deduced:deduce_within_shape -> debug_name:string -> id:int -> string -> t
val default_display_indices : t -> int array
@@ -253,5 +254,5 @@ val to_labels : t -> string array
(** Uses the matrix convention of putting the input axes last. *)


val parse_n5_layout : string -> int array
-(** Parse a N5_layout priority string (e.g., "0,1,2") into display indices.
-    Only supports integer labels (Fixed_index). *)
+(** Parse a N5_layout priority string (e.g., "0,1,2") into display indices. Only supports integer
+    labels (Fixed_index). *)
File "test/einsum/test_padding_reset.ml", line 1, characters 0-0:
diff --git a/_build/default/test/einsum/test_padding_reset.ml b/_build/default/test/einsum/.formatted/test_padding_reset.ml
index d67aa9d..4573e51 100644
--- a/_build/default/test/einsum/test_padding_reset.ml
+++ b/_build/default/test/einsum/.formatted/test_padding_reset.ml
@@ -5,10 +5,9 @@ open Stdio


(** Test that padding margins are properly initialized and reset between operations.


-    This test demonstrates the padding behavior with use_padding=true (= marker).
-    We apply TWO DIFFERENT operations to the SAME input tensor, each requiring
-    different padding margins. The input's padding must be properly reset between
-    the two operations.
+    This test demonstrates the padding behavior with use_padding=true (= marker). We apply TWO
+    DIFFERENT operations to the SAME input tensor, each requiring different padding margins. The
+    input's padding must be properly reset between the two operations.


- Max-pool-like operation: padding should be -infinity for correct max behavior
- Conv-like operation: padding should be 0 for correct sum behavior
@@ -18,29 +17,24 @@ let test_padding_reset () =
printf "Testing padding margin initialization and reset...\n%!";
Tensor.unsafe_reinitialize ();


-  (* Create a 4x4 input with negative values: -16..-1
-     This way, if padding margins are 0 (incorrect for max-pool with negative values),
-     the max will incorrectly be 0 instead of the actual maximum negative value.
-     Proper padding for max should be -infinity or at least very negative. *)
+  (* Create a 4x4 input with negative values: -16..-1 This way, if padding margins are 0 (incorrect
+     for max-pool with negative values), the max will incorrectly be 0 instead of the actual maximum
+     negative value. Proper padding for max should be -infinity or at least very negative. *)
let%op input = TDSL.range_of_shape ~output_dims:[ 4; 4 ] () - 16. in


-  (* Max-pool-like operation on input with stride=1, window=3, use_padding=true.
-     For max-pool, padding value should be -infinity so max ignores padding positions. *)
-  let%op pooled =
-    input @^+ "oh=+wh, ow=+ww; wh, ww => oh, ow" [ "wh"; "ww" ] (0.0 + 0.0)
-  in
+  (* Max-pool-like operation on input with stride=1, window=3, use_padding=true. For max-pool,
+     padding value should be -infinity so max ignores padding positions. *)
+  let%op pooled = input @^+ "oh=+wh, ow=+ww; wh, ww => oh, ow" [ "wh"; "ww" ] (0.0 + 0.0) in
Shape.set_dim wh 3;
Shape.set_dim ww 3;


(* Conv-like operation ALSO on input (not pooled!) with stride=1, kernel=3, use_padding=true.
-     Kernel is all 1.0, so this sums 3x3 windows of input.
-     For conv/sum, padding value should be 0 so sum ignores padding positions.
-
-     KEY: Both operations use the SAME input tensor, but require DIFFERENT padding values.
-     The input's padding margins must be reset between the two operations. *)
-  let%op conv_out =
-    input +* "oh=+kh, ow=+kw; kh, kw => oh, ow" [ "kh"; "kw" ] (1.0 + 0.0)
-  in
+     Kernel is all 1.0, so this sums 3x3 windows of input. For conv/sum, padding value should be 0
+     so sum ignores padding positions.
+
+     KEY: Both operations use the SAME input tensor, but require DIFFERENT padding values. The
+     input's padding margins must be reset between the two operations. *)
+  let%op conv_out = input +* "oh=+kh, ow=+kw; kh, kw => oh, ow" [ "kh"; "kw" ] (1.0 + 0.0) in
Shape.set_dim kh 3;
Shape.set_dim kw 3;


@@ -49,8 +43,8 @@ let test_padding_reset () =
Train.set_hosted pooled.value;
Train.set_hosted conv_out.value;


-  (* Compile BOTH forward passes into a single routine using sequence.
-     This tests that input's padding is properly reset between the two operations. *)
+  (* Compile BOTH forward passes into a single routine using sequence. This tests that input's
+     padding is properly reset between the two operations. *)
let ctx = Train.init_params ctx Train.IDX.empty conv_out in
let ctx = Train.init_params ctx Train.IDX.empty pooled in
(* Get forward codes - order matters since consume_forward_code modifies state *)
@@ -81,30 +75,19 @@ let test_padding_reset () =
printf "\nConv after second pass:\n%!";
Tensor.print ~here:[%here] ~force:true ~with_code:false ~with_grad:false `Inline conv_out;


-  (* Analysis of expected values:
-     Input is 4x4 with values -16 to -1:
-       -16  -15  -14  -13
-       -12  -11  -10  -9
-       -8   -7   -6   -5
-       -4   -3   -2   -1
-
-     For MAX-POOL at (0,0) with 3x3 window:
-       - Window covers input[0..2, 0..2] = top-left 3x3
-       - Values: -16,-15,-14,-12,-11,-10,-8,-7,-6
-       - Correct max = -6
-       - But with padding=0 at corners, window includes padding positions
-       - If pad=0: max = 0 (BUG!)
-
-     For CONV at (0,0) with 3x3 window:
-       - Same window, but summing with kernel of 1s
-       - With padding=0, sum = -16-15-14-12-11-10-8-7-6 = -99
-       - Corner positions include fewer real values due to padding
-       - sum at (0,0) = 0+0+0+0+(-16)+(-15)+0+(-12)+(-11) = -54
-
-     The key test: if input's padding is not reset between max-pool and conv,
-     the results will be wrong for one or both operations.
-  *)
+  (* Analysis of expected values: Input is 4x4 with values -16 to -1: -16 -15 -14 -13 -12 -11 -10 -9
+     -8 -7 -6 -5 -4 -3 -2 -1
+
+     For MAX-POOL at (0,0) with 3x3 window: - Window covers input[0..2, 0..2] = top-left 3x3 -
+     Values: -16,-15,-14,-12,-11,-10,-8,-7,-6 - Correct max = -6 - But with padding=0 at corners,
+     window includes padding positions - If pad=0: max = 0 (BUG!)
+
+     For CONV at (0,0) with 3x3 window: - Same window, but summing with kernel of 1s - With
+     padding=0, sum = -16-15-14-12-11-10-8-7-6 = -99 - Corner positions include fewer real values
+     due to padding - sum at (0,0) = 0+0+0+0+(-16)+(-15)+0+(-12)+(-11) = -54


+     The key test: if input's padding is not reset between max-pool and conv, the results will be
+     wrong for one or both operations. *)
printf "\n=== Expected Behavior Analysis ===\n%!";
printf "Input values: -16 to -1 (all negative)\n%!";
printf "\nFor MAX-POOL (padding should be -infinity):\n%!";
File "test/einsum/test_einsum_parser.ml", line 1, characters 0-0:
diff --git a/_build/default/test/einsum/test_einsum_parser.ml b/_build/default/test/einsum/.formatted/test_einsum_parser.ml
index c305169..ac88c8a 100644
--- a/_build/default/test/einsum/test_einsum_parser.ml
+++ b/_build/default/test/einsum/.formatted/test_einsum_parser.ml
@@ -12,8 +12,7 @@ let test_single_char () =
(* Test 2: With batch and input *)
let spec2 = "b|i->o" in
let labels2 = Einsum_parser.axis_labels_of_spec spec2 in
-  printf "  'b|i->o' -> batch:%d input:%d output:%d\n"
-    (List.length labels2.given_batch)
+  printf "  'b|i->o' -> batch:%d input:%d output:%d\n" (List.length labels2.given_batch)
(List.length labels2.given_input)
(List.length labels2.given_output);


@@ -21,13 +20,9 @@ let test_single_char () =
let spec3 = "ij;jk=>ik" in
let l1, l2_opt, l3 = Einsum_parser.einsum_of_spec spec3 in
let l2 = Option.value_exn l2_opt in
-  printf "  'ij;jk=>ik' -> (%d,%d);(%d,%d)=>(%d,%d)\n"
-    (List.length l1.given_input)
-    (List.length l1.given_output)
-    (List.length l2.given_input)
-    (List.length l2.given_output)
-    (List.length l3.given_input)
-    (List.length l3.given_output);
+  printf "  'ij;jk=>ik' -> (%d,%d);(%d,%d)=>(%d,%d)\n" (List.length l1.given_input)
+    (List.length l1.given_output) (List.length l2.given_input) (List.length l2.given_output)
+    (List.length l3.given_input) (List.length l3.given_output);


printf "\n"


File "test/einsum/test_conv_syntax.ml", line 1, characters 0-0:
diff --git a/_build/default/test/einsum/test_conv_syntax.ml b/_build/default/test/einsum/.formatted/test_conv_syntax.ml
index bb97681..028dbb4 100644
--- a/_build/default/test/einsum/test_conv_syntax.ml
+++ b/_build/default/test/einsum/.formatted/test_conv_syntax.ml
@@ -8,43 +8,50 @@ let test_conv_parsing () =
let spec1 = "2*o+3*k" in
let labels1 = Einsum_parser.axis_labels_of_spec spec1 in
printf "Test 1: Parsed '%s' successfully\n%!" spec1;
-  printf "  Structure: %s\n\n%!" (Sexp.to_string_hum (Einsum_parser.sexp_of_parsed_axis_labels labels1));
+  printf "  Structure: %s\n\n%!"
+    (Sexp.to_string_hum (Einsum_parser.sexp_of_parsed_axis_labels labels1));


(* Test 2: Simple conv expression without coefficients (multichar - requires commas) *)
let spec2 = "o+k" in
let labels2 = Einsum_parser.axis_labels_of_spec spec2 in
printf "Test 2: Parsed '%s' successfully\n%!" spec2;
-  printf "  Structure: %s\n\n%!" (Sexp.to_string_hum (Einsum_parser.sexp_of_parsed_axis_labels labels2));
+  printf "  Structure: %s\n\n%!"
+    (Sexp.to_string_hum (Einsum_parser.sexp_of_parsed_axis_labels labels2));


(* Test 3: Mixed spec with comma (multichar mode) *)
let spec3 = "a, 2*b+c" in
let labels3 = Einsum_parser.axis_labels_of_spec spec3 in
printf "Test 3: Parsed '%s' successfully\n%!" spec3;
-  printf "  Structure: %s\n\n%!" (Sexp.to_string_hum (Einsum_parser.sexp_of_parsed_axis_labels labels3));
+  printf "  Structure: %s\n\n%!"
+    (Sexp.to_string_hum (Einsum_parser.sexp_of_parsed_axis_labels labels3));


(* Test 4: Conv expression with multiple identifiers (multichar - requires commas) *)
let spec4 = "i, o+k, j" in
let labels4 = Einsum_parser.axis_labels_of_spec spec4 in
printf "Test 4: Parsed '%s' successfully (multichar mode)\n%!" spec4;
-  printf "  Structure: %s\n\n%!" (Sexp.to_string_hum (Einsum_parser.sexp_of_parsed_axis_labels labels4));
+  printf "  Structure: %s\n\n%!"
+    (Sexp.to_string_hum (Einsum_parser.sexp_of_parsed_axis_labels labels4));


(* Test 5: Conv expression with multi-char identifiers (multichar) *)
let spec5 = "a+bc" in
let labels5 = Einsum_parser.axis_labels_of_spec spec5 in
printf "Test 5: Parsed '%s' successfully (multichar mode)\n%!" spec5;
-  printf "  Structure: %s\n\n%!" (Sexp.to_string_hum (Einsum_parser.sexp_of_parsed_axis_labels labels5));
+  printf "  Structure: %s\n\n%!"
+    (Sexp.to_string_hum (Einsum_parser.sexp_of_parsed_axis_labels labels5));


(* Test 6: Test in einsum notation with multichar conv *)
let spec6 = "i, j -> 2*i+j" in
let labels6 = Einsum_parser.axis_labels_of_spec spec6 in
printf "Test 6: Parsed '%s' successfully\n%!" spec6;
-  printf "  Structure: %s\n\n%!" (Sexp.to_string_hum (Einsum_parser.sexp_of_parsed_axis_labels labels6));
+  printf "  Structure: %s\n\n%!"
+    (Sexp.to_string_hum (Einsum_parser.sexp_of_parsed_axis_labels labels6));


(* Test 7: Complex batch-input-output spec with conv *)
let spec7 = "batch|input->3*output+1*kernel," in
let labels7 = Einsum_parser.axis_labels_of_spec spec7 in
printf "Test 7: Parsed '%s' successfully\n%!" spec7;
-  printf "  Structure: %s\n\n%!" (Sexp.to_string_hum (Einsum_parser.sexp_of_parsed_axis_labels labels7));
+  printf "  Structure: %s\n\n%!"
+    (Sexp.to_string_hum (Einsum_parser.sexp_of_parsed_axis_labels labels7));


printf "All conv syntax parsing tests passed!\n%!"


@@ -55,25 +62,29 @@ let test_strided_iteration_parsing () =
let spec1 = "2*output" in
let labels1 = Einsum_parser.axis_labels_of_spec spec1 in
printf "Test 1: Parsed strided iteration '%s' successfully\n%!" spec1;
-  printf "  Structure: %s\n\n%!" (Sexp.to_string_hum (Einsum_parser.sexp_of_parsed_axis_labels labels1));
+  printf "  Structure: %s\n\n%!"
+    (Sexp.to_string_hum (Einsum_parser.sexp_of_parsed_axis_labels labels1));


(* Test 2: Strided iteration with single-char identifier (multichar mode) *)
let spec2 = "3*i" in
let labels2 = Einsum_parser.axis_labels_of_spec spec2 in
printf "Test 2: Parsed strided iteration '%s' successfully\n%!" spec2;
-  printf "  Structure: %s\n\n%!" (Sexp.to_string_hum (Einsum_parser.sexp_of_parsed_axis_labels labels2));
+  printf "  Structure: %s\n\n%!"
+    (Sexp.to_string_hum (Einsum_parser.sexp_of_parsed_axis_labels labels2));


(* Test 3: Strided iteration in einsum context (multichar due to multiplication) *)
let spec3 = "input -> 2*output" in
let labels3 = Einsum_parser.axis_labels_of_spec spec3 in
printf "Test 3: Parsed einsum with strided iteration '%s' successfully\n%!" spec3;
-  printf "  Structure: %s\n\n%!" (Sexp.to_string_hum (Einsum_parser.sexp_of_parsed_axis_labels labels3));
+  printf "  Structure: %s\n\n%!"
+    (Sexp.to_string_hum (Einsum_parser.sexp_of_parsed_axis_labels labels3));


(* Test 4: Mixed regular labels and strided iteration (multichar due to comma) *)
let spec4 = "regular, 3*strided" in
let labels4 = Einsum_parser.axis_labels_of_spec spec4 in
printf "Test 4: Parsed mixed labels with strided iteration '%s' successfully\n%!" spec4;
-  printf "  Structure: %s\n\n%!" (Sexp.to_string_hum (Einsum_parser.sexp_of_parsed_axis_labels labels4));
+  printf "  Structure: %s\n\n%!"
+    (Sexp.to_string_hum (Einsum_parser.sexp_of_parsed_axis_labels labels4));


printf "\nAll strided iteration parsing tests completed!\n%!"


@@ -138,37 +149,43 @@ let test_use_padding_syntax () =
let spec1 = "o=+k" in
let labels1 = Einsum_parser.axis_labels_of_spec spec1 in
printf "Test 1: Parsed '%s' (use_padding=true)\n%!" spec1;
-  printf "  Structure: %s\n\n%!" (Sexp.to_string_hum (Einsum_parser.sexp_of_parsed_axis_labels labels1));
+  printf "  Structure: %s\n\n%!"
+    (Sexp.to_string_hum (Einsum_parser.sexp_of_parsed_axis_labels labels1));


(* Test 2: use_padding=false with < syntax *)
let spec2 = "o<+k" in
let labels2 = Einsum_parser.axis_labels_of_spec spec2 in
printf "Test 2: Parsed '%s' (use_padding=false)\n%!" spec2;
-  printf "  Structure: %s\n\n%!" (Sexp.to_string_hum (Einsum_parser.sexp_of_parsed_axis_labels labels2));
+  printf "  Structure: %s\n\n%!"
+    (Sexp.to_string_hum (Einsum_parser.sexp_of_parsed_axis_labels labels2));


(* Test 3: use_padding with stride *)
let spec3 = "2*o=+k" in
let labels3 = Einsum_parser.axis_labels_of_spec spec3 in
printf "Test 3: Parsed '%s' (stride with use_padding=true)\n%!" spec3;
-  printf "  Structure: %s\n\n%!" (Sexp.to_string_hum (Einsum_parser.sexp_of_parsed_axis_labels labels3));
+  printf "  Structure: %s\n\n%!"
+    (Sexp.to_string_hum (Einsum_parser.sexp_of_parsed_axis_labels labels3));


(* Test 4: use_padding with dilation *)
let spec4 = "o<+3*k" in
let labels4 = Einsum_parser.axis_labels_of_spec spec4 in
printf "Test 4: Parsed '%s' (dilation with use_padding=false)\n%!" spec4;
-  printf "  Structure: %s\n\n%!" (Sexp.to_string_hum (Einsum_parser.sexp_of_parsed_axis_labels labels4));
+  printf "  Structure: %s\n\n%!"
+    (Sexp.to_string_hum (Einsum_parser.sexp_of_parsed_axis_labels labels4));


(* Test 5: use_padding with stride and dilation *)
let spec5 = "2*o=+3*k" in
let labels5 = Einsum_parser.axis_labels_of_spec spec5 in
printf "Test 5: Parsed '%s' (stride, dilation, use_padding=true)\n%!" spec5;
-  printf "  Structure: %s\n\n%!" (Sexp.to_string_hum (Einsum_parser.sexp_of_parsed_axis_labels labels5));
+  printf "  Structure: %s\n\n%!"
+    (Sexp.to_string_hum (Einsum_parser.sexp_of_parsed_axis_labels labels5));


(* Test 6: unspecified use_padding (legacy syntax) *)
let spec6 = "o+k" in
let labels6 = Einsum_parser.axis_labels_of_spec spec6 in
printf "Test 6: Parsed '%s' (unspecified use_padding)\n%!" spec6;
-  printf "  Structure: %s\n\n%!" (Sexp.to_string_hum (Einsum_parser.sexp_of_parsed_axis_labels labels6));
+  printf "  Structure: %s\n\n%!"
+    (Sexp.to_string_hum (Einsum_parser.sexp_of_parsed_axis_labels labels6));


printf "All use_padding syntax tests completed!\n%!"


File "test/einsum/test_max_pool2d.ml", line 1, characters 0-0:
diff --git a/_build/default/test/einsum/test_max_pool2d.ml b/_build/default/test/einsum/.formatted/test_max_pool2d.ml
index 2e44dd4..b4415ec 100644
--- a/_build/default/test/einsum/test_max_pool2d.ml
+++ b/_build/default/test/einsum/.formatted/test_max_pool2d.ml
@@ -133,13 +133,11 @@ let test_max_pool2d_backprop () =
printf "\nTesting backprop for max_pool2d...\n%!";
Tensor.unsafe_reinitialize ();


-  (* Create a 4x4 input with 1 channel using a parameter (requires grad).
-     Design: each 2x2 window has its max in a different position:
-     Window positions: (row within window, col within window)
-     - Top-left window [0-1, 0-1]: max 9 at (0,0)
-     - Top-right window [0-1, 2-3]: max 8 at (1,1)
-     - Bottom-left window [2-3, 0-1]: max 7 at (0,1)
-     - Bottom-right window [2-3, 2-3]: max 6 at (1,0) *)
+  (* Create a 4x4 input with 1 channel using a parameter (requires grad). Design: each 2x2 window
+     has its max in a different position: Window positions: (row within window, col within window) -
+     Top-left window [0-1, 0-1]: max 9 at (0,0) - Top-right window [0-1, 2-3]: max 8 at (1,1) -
+     Bottom-left window [2-3, 0-1]: max 7 at (0,1) - Bottom-right window [2-3, 2-3]: max 6 at
+     (1,0) *)
let%op input =
{
x =
File "arrayjit/lib/indexing.ml", line 1, characters 0-0:
diff --git a/_build/default/arrayjit/lib/indexing.ml b/_build/default/arrayjit/lib/.formatted/indexing.ml
index 018c9e9..0d03161 100644
--- a/_build/default/arrayjit/lib/indexing.ml
+++ b/_build/default/arrayjit/lib/.formatted/indexing.ml
@@ -139,8 +139,8 @@ type projections = {
*)
product_iterators : symbol array;
(** The product space iterators (concatentation of the relevant batch, output, input axes) for
-          iterating over the [product_space] axes, where same axes are at same array indices.
-          These may be shared; lowering creates fresh symbols for loop indices. *)
+          iterating over the [product_space] axes, where same axes are at same array indices. These
+          may be shared; lowering creates fresh symbols for loop indices. *)
project_lhs : axis_index array;
(** A projection that takes an [product_space]-bound index and produces an index into the
result of an operation. *)
File "test/einsum/test_conv_padding.ml", line 1, characters 0-0:
diff --git a/_build/default/test/einsum/test_conv_padding.ml b/_build/default/test/einsum/.formatted/test_conv_padding.ml
index 2415a61..bbdc511 100644
--- a/_build/default/test/einsum/test_conv_padding.ml
+++ b/_build/default/test/einsum/.formatted/test_conv_padding.ml
@@ -128,16 +128,15 @@ let test_conv2d_stride_with_padding_backprop () =


(** Test conv2d with stride=2 and use_padding=false.


-    With stride=2 and use_padding=false, output dims are (input - kernel) / stride + 1.
-    IMPORTANT: For no-padding convolutions, (input - kernel) must be divisible by stride.
-    For 9x9 input, kernel_size=3, stride=2: (9-3)/2 + 1 = 4, so output should be 4x4. *)
+    With stride=2 and use_padding=false, output dims are (input - kernel) / stride + 1. IMPORTANT:
+    For no-padding convolutions, (input - kernel) must be divisible by stride. For 9x9 input,
+    kernel_size=3, stride=2: (9-3)/2 + 1 = 4, so output should be 4x4. *)
let test_conv2d_stride_without_padding () =
printf "Testing conv2d with stride=2 and use_padding=false...\n%!";
Tensor.unsafe_reinitialize ();


-  (* Create a 9x9 input with 1 channel - sized for stride=2, kernel=3 without padding.
-     For no-padding conv: (input - kernel) must be divisible by stride.
-     (9 - 3) = 6, 6 % 2 = 0 ✓ *)
+  (* Create a 9x9 input with 1 channel - sized for stride=2, kernel=3 without padding. For
+     no-padding conv: (input - kernel) must be divisible by stride. (9 - 3) = 6, 6 % 2 = 0 ✓ *)
let input = TDSL.range_of_shape ~output_dims:[ 9; 9; 1 ] () in


(* Apply conv2d with kernel_size=3, stride=2, use_padding=false, out_channels=4 *)
@@ -164,16 +163,15 @@ let test_conv2d_stride_without_padding () =
This tests that shape inference works correctly during backpropagation for strided convolutions
without padding.


-    IMPORTANT: For no-padding convolutions, (input - kernel) must be divisible by stride,
-    otherwise shape inference will fail with "incompatible stride" error. *)
+    IMPORTANT: For no-padding convolutions, (input - kernel) must be divisible by stride, otherwise
+    shape inference will fail with "incompatible stride" error. *)
let test_conv2d_stride_without_padding_backprop () =
printf "\nTesting backprop for conv2d with stride=2 and use_padding=false...\n%!";
Tensor.unsafe_reinitialize ();


-  (* Create a 9x9 input with 1 channel - sized for stride=2, kernel=3 without padding.
-     For no-padding conv: (input - kernel) must be divisible by stride.
-     (9 - 3) = 6, 6 % 2 = 0 ✓
-     Output size: (9 - 3) / 2 + 1 = 4, so 4x4 output. *)
+  (* Create a 9x9 input with 1 channel - sized for stride=2, kernel=3 without padding. For
+     no-padding conv: (input - kernel) must be divisible by stride. (9 - 3) = 6, 6 % 2 = 0 ✓ Output
+     size: (9 - 3) / 2 + 1 = 4, so 4x4 output. *)
let input = TDSL.range_of_shape ~output_dims:[ 9; 9; 1 ] () in


(* Apply conv2d with kernel_size=3, stride=2, use_padding=false, out_channels=4 *)
File "test/einsum/test_tropical_kernel.ml", line 1, characters 0-0:
diff --git a/_build/default/test/einsum/test_tropical_kernel.ml b/_build/default/test/einsum/.formatted/test_tropical_kernel.ml
index 08991bf..32bad19 100644
--- a/_build/default/test/einsum/test_tropical_kernel.ml
+++ b/_build/default/test/einsum/.formatted/test_tropical_kernel.ml
@@ -5,20 +5,21 @@ open Stdio


(** Test tropical semiring (max-plus) operations with a learnable kernel.


-    This tests backpropagation for tropical operations with both input (t1/rhs1)
-    and kernel (t2/rhs2) gradients.
+    This tests backpropagation for tropical operations with both input (t1/rhs1) and kernel
+    (t2/rhs2) gradients.


-    The implementation uses `_rhs1` suffix for both input and kernel gradient paths.
-    This gives condition tensors input shape (ih,iw) which is effectively the "outer
-    product" of output (oh,ow) and kernel (wh,ww) dimensions. This correctly tracks
-    which (input position, kernel position) pair achieved the argmax for each output. *)
+    The implementation uses `_rhs1` suffix for both input and kernel gradient paths. This gives
+    condition tensors input shape (ih,iw) which is effectively the "outer product" of output (oh,ow)
+    and kernel (wh,ww) dimensions. This correctly tracks which (input position, kernel position)
+    pair achieved the argmax for each output. *)


(** Create a tropical convolution-like operation with a learnable kernel.


-    This is similar to max_pool2d but with a non-zero learnable kernel, allowing us to
-    verify that g2 (kernel) gradients are computed correctly.
+    This is similar to max_pool2d but with a non-zero learnable kernel, allowing us to verify that
+    g2 (kernel) gradients are computed correctly.


-    The tropical operation computes: output[oh,ow] = max over (wh,ww) of (input[2*oh+wh, 2*ow+ww] + kernel[wh,ww])
+    The tropical operation computes: output[oh,ow] = max over (wh,ww) of (input[2*oh+wh, 2*ow+ww] +
+    kernel[wh,ww])


For backprop:
- g1 (input grad): gradient flows to input positions that achieved the argmax
@@ -27,8 +28,9 @@ let tropical_conv2d ?(stride = 2) ?(window_size = 2) () =
let%op op x kernel =
Shape.set_dim wh window_size;
Shape.set_dim ww window_size;
-    x @^+ "... | stride*oh< + wh, stride*ow< + ww, ..c..; wh, ww => ... | oh, ow, ..c.." [ "wh"; "ww" ]
-          kernel
+    x
+    @^+ "... | stride*oh< + wh, stride*ow< + ww, ..c..; wh, ww => ... | oh, ow, ..c.."
+          [ "wh"; "ww" ] kernel
in
op


@@ -73,19 +75,10 @@ let test_tropical_kernel_forward () =


This is the key test: verifies that gradients flow correctly to both input and kernel.


-    Input pattern (4x4, values designed so argmax varies):
-    ```
-      [[9, 0, 0, 0],
-       [0, 0, 0, 8],
-       [0, 7, 0, 0],
-       [0, 0, 6, 0]]
-    ```
+    Input pattern (4x4, values designed so argmax varies): ```
+    [[9, 0, 0, 0], [0, 0, 0, 8], [0, 7, 0, 0], [0, 0, 6, 0]] ```


-    Kernel (2x2, small values so input determines argmax):
-    ```
-      [[0, 0],
-       [0, 0]]
-    ```
+    Kernel (2x2, small values so input determines argmax): ``` [[0, 0], [0, 0]] ```


With zero kernel, this is like max_pool2d - argmax is at input max positions.
- Window [0,0]: max at (0,0)=9, argmax kernel position (0,0)
@@ -93,8 +86,8 @@ let test_tropical_kernel_forward () =
- Window [1,0]: max at (2,1)=7, argmax kernel position (0,1)
- Window [1,1]: max at (3,2)=6, argmax kernel position (1,0)


-    Expected input gradients: 1 at positions (0,0), (1,3), (2,1), (3,2); 0 elsewhere.
-    Expected kernel gradients: 1 at each position (each is argmax for exactly one output). *)
+    Expected input gradients: 1 at positions (0,0), (1,3), (2,1), (3,2); 0 elsewhere. Expected
+    kernel gradients: 1 at each position (each is argmax for exactly one output). *)
let test_tropical_kernel_backprop_zero_kernel () =
printf "Testing tropical conv backprop with zero kernel...\n%!";
Tensor.unsafe_reinitialize ();
@@ -140,27 +133,18 @@ let test_tropical_kernel_backprop_zero_kernel () =


(** Test tropical conv backprop with non-zero kernel that affects argmax.


-    Input (4x4, uniform low values):
-    ```
-      [[1, 1, 1, 1],
-       [1, 1, 1, 1],
-       [1, 1, 1, 1],
-       [1, 1, 1, 1]]
+    Input (4x4, uniform low values): ``` [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]]
```


-    Kernel (2x2, large value at position (1,1)):
-    ```
-      [[0, 0],
-       [0, 10]]
-    ```
+    Kernel (2x2, large value at position (1,1)): ``` [[0, 0], [0, 10]] ```


-    With this kernel, the argmax for every window is at kernel position (1,1)
-    because 1+10=11 > 1+0=1 for all other positions.
+    With this kernel, the argmax for every window is at kernel position (1,1) because 1+10=11 >
+    1+0=1 for all other positions.


-    Expected output: all 11 (value 1 + kernel 10 at position (1,1) of each window)
-    Expected input gradients: 1 at positions (1,1), (1,3), (3,1), (3,3); 0 elsewhere
-      (these are the input positions corresponding to kernel (1,1) in each window)
-    Expected kernel gradients: [[0,0],[0,4]] - only (1,1) was argmax, used 4 times *)
+    Expected output: all 11 (value 1 + kernel 10 at position (1,1) of each window) Expected input
+    gradients: 1 at positions (1,1), (1,3), (3,1), (3,3); 0 elsewhere (these are the input positions
+    corresponding to kernel (1,1) in each window) Expected kernel gradients: [[0,0],[0,4]] - only
+    (1,1) was argmax, used 4 times *)
let test_tropical_kernel_backprop_nonzero_kernel () =
printf "Testing tropical conv backprop with non-zero kernel...\n%!";
Tensor.unsafe_reinitialize ();
File "tensor/ppx_shared.ml", line 1, characters 0-0:
diff --git a/_build/default/tensor/ppx_shared.ml b/_build/default/tensor/.formatted/ppx_shared.ml
index bee7b1b..7225cf0 100644
--- a/_build/default/tensor/ppx_shared.ml
+++ b/_build/default/tensor/.formatted/ppx_shared.ml
@@ -115,12 +115,12 @@ let ndarray_constant expr =


(** Convert an einsum spec string to an OCaml expression that constructs the runtime string.


-    This function parses the einsum spec using the Einsum_parser, then reconstructs a runtime
-    string expression, handling:
-    - stride and dilation values: if they look like integer literals, emit them directly;
-      otherwise emit [Int.to_string identifier] to convert at runtime
-    - use_padding: if unspecified (legacy syntax), emit [if use_padding then "=" else "<"]
-      to read the value from [Row.use_padding] at runtime
+    This function parses the einsum spec using the Einsum_parser, then reconstructs a runtime string
+    expression, handling:
+    - stride and dilation values: if they look like integer literals, emit them directly; otherwise
+      emit [Int.to_string identifier] to convert at runtime
+    - use_padding: if unspecified (legacy syntax), emit [if use_padding then "=" else "<"] to read
+      the value from [Row.use_padding] at runtime


Example: ["stride*x=+k; y => z"] where [stride] is a variable, generates an expression that
evaluates to e.g. ["2*x=+k; y => z"] if [stride = 2]. *)
@@ -216,11 +216,12 @@ let substitute_identifiers_in_einsum_spec ~loc str_input =
let output_segments =
row_to_segments ~kind:"output" parsed.bcast_output parsed.given_beg_output parsed.given_output
in
-    let has_batch = not (List.is_empty batch_segments) || Option.is_some parsed.bcast_batch in
-    let has_input = not (List.is_empty input_segments) || Option.is_some parsed.bcast_input in
+    let has_batch = (not (List.is_empty batch_segments)) || Option.is_some parsed.bcast_batch in
+    let has_input = (not (List.is_empty input_segments)) || Option.is_some parsed.bcast_input in
let segments =
if has_batch then
-        batch_segments @ [ estring ~loc "|" ]
+        batch_segments
+        @ [ estring ~loc "|" ]
@ (if has_input then input_segments @ [ estring ~loc "->" ] else [])
@ output_segments
else if has_input then input_segments @ [ estring ~loc "->" ] @ output_segments
@@ -248,33 +249,34 @@ let substitute_identifiers_in_einsum_spec ~loc str_input =
let combined =
String.concat
(List.filter_map all_segments ~f:(fun e ->
-               match e.pexp_desc with Pexp_constant (Pconst_string (s, _, _)) -> Some s | _ -> None))
+               match e.pexp_desc with
+               | Pexp_constant (Pconst_string (s, _, _)) -> Some s
+               | _ -> None))
in
estring ~loc combined
else [%expr String.concat ~sep:"" [%e elist ~loc all_segments]]
-  with Parse_error _ ->
+  with Parse_error _ -> (
(* If parsing fails, try as axis_labels_spec *)
-    (try
-       let parsed = axis_labels_of_spec str_input in
-       let segments = parsed_to_segments parsed in
-       let all_literals =
-         List.for_all segments ~f:(fun e ->
-             match e.pexp_desc with Pexp_constant (Pconst_string _) -> true | _ -> false)
-       in
-       if all_literals then
-         let combined =
-           String.concat
-             (List.filter_map segments ~f:(fun e ->
-                  match e.pexp_desc with
-                  | Pexp_constant (Pconst_string (s, _, _)) -> Some s
-                  | _ -> None))
-         in
-         estring ~loc combined
-       else [%expr String.concat ~sep:"" [%e elist ~loc segments]]
-     with Parse_error msg ->
-       (* Fall back to returning the original string with an error note *)
-       pexp_extension ~loc
-       @@ Location.error_extensionf ~loc "Failed to parse einsum spec: %s" msg)
+    try
+      let parsed = axis_labels_of_spec str_input in
+      let segments = parsed_to_segments parsed in
+      let all_literals =
+        List.for_all segments ~f:(fun e ->
+            match e.pexp_desc with Pexp_constant (Pconst_string _) -> true | _ -> false)
+      in
+      if all_literals then
+        let combined =
+          String.concat
+            (List.filter_map segments ~f:(fun e ->
+                 match e.pexp_desc with
+                 | Pexp_constant (Pconst_string (s, _, _)) -> Some s
+                 | _ -> None))
+        in
+        estring ~loc combined
+      else [%expr String.concat ~sep:"" [%e elist ~loc segments]]
+    with Parse_error msg ->
+      (* Fall back to returning the original string with an error note *)
+      pexp_extension ~loc @@ Location.error_extensionf ~loc "Failed to parse einsum spec: %s" msg)


let string_expr ~loc s = Ast_helper.Exp.constant @@ Pconst_string (s, loc, None)


@@ -546,11 +548,7 @@ let let_opt ~loc vbs expr =
(* Check for duplicates and create nested let bindings preserving definition order *)
let seen = Hashtbl.create (module String) in
List.fold_right vbs ~init:expr ~f:(fun vb acc ->
-      let name =
-        match vb.pvb_pat.ppat_desc with
-        | Ppat_var { txt; _ } -> txt
-        | _ -> "_"
-      in
+      let name = match vb.pvb_pat.ppat_desc with Ppat_var { txt; _ } -> txt | _ -> "_" in
match Hashtbl.add seen ~key:name ~data:() with
| `Ok -> Ast_helper.Exp.let_ ~loc Nonrecursive [ vb ] acc
| `Duplicate ->
@@ -565,7 +563,6 @@ let let_opt ~loc vbs expr =
Ast_helper.Exp.let_ ~loc Nonrecursive [ { vb with pvb_expr = error_expr } ] acc)


let no_vbs = []
-
let reduce_vbss vbss = List.concat vbss


let expr_expander_with_punning translate ~loc ~path:_ payload =
File "tensor/operation.ml", line 1, characters 0-0:
diff --git a/_build/default/tensor/operation.ml b/_build/default/tensor/.formatted/operation.ml
index aa8ddae..cd89772 100644
--- a/_build/default/tensor/operation.ml
+++ b/_build/default/tensor/.formatted/operation.ml
@@ -430,10 +430,10 @@ let einmax1 ?(capture_dims = []) spec =


(** This generalizes the tropical matrix multiplication to arbitrary indices combinations.


-    LIMITATION: Backpropagation is only correct when the RHS1 (t1) index space includes
-    the RHS2 (t2) index space. This is the case for convolution-like operations where
-    the kernel indices are contracted with strided input indices. For general tropical
-    operations where RHS2 has independent indices, the g2 gradient will be incorrect. *)
+    LIMITATION: Backpropagation is only correct when the RHS1 (t1) index space includes the RHS2
+    (t2) index space. This is the case for convolution-like operations where the kernel indices are
+    contracted with strided input indices. For general tropical operations where RHS2 has
+    independent indices, the g2 gradient will be incorrect. *)
let tropical ?(capture_dims = []) spec =
let module NTDSL = struct
include Initial_NTDSL
@@ -441,8 +441,8 @@ let tropical ?(capture_dims = []) spec =
end in
let%cd op_asn ~t ~t1 ~t2 ~projections = v =:@^ v1 + v2 in
let%cd grad_asn ~t ~g ~t1 ~t2 ~projections =
-    (* Use _rhs1 suffix for both: gives input shape (ih,iw) = (oh,ow) x (wh,ww) outer product.
-       This correctly tracks which (input position, kernel position) pair achieved argmax. *)
+    (* Use _rhs1 suffix for both: gives input shape (ih,iw) = (oh,ow) x (wh,ww) outer product. This
+       correctly tracks which (input position, kernel position) pair achieved argmax. *)
{ sum_rhs1 } =:@^ add (t1, t2);
{ cond_rhs1 } =: eq (t, sum_rhs1);
g1 =+ where cond_rhs1 g 0;
File "test/operations/test_random_histograms.ml", line 1, characters 0-0:
diff --git a/_build/default/test/operations/test_random_histograms.ml b/_build/default/test/operations/.formatted/test_random_histograms.ml
index e527f4b..a91ffd4 100644
--- a/_build/default/test/operations/test_random_histograms.ml
+++ b/_build/default/test/operations/.formatted/test_random_histograms.ml
@@ -6,20 +6,18 @@ open Stdio


IMPORTANT: Understanding OCANNL's counter-based PRNG architecture:


-    The [uniform_at], [normal_at], [kaiming_at], [xavier_at] functions use a counter-based
-    PRNG (Threefry). The [counter] argument is NOT meant to determine the output shape!
-    It is a "mix-in" to bifurcate randomness across different counter values.
+    The [uniform_at], [normal_at], [kaiming_at], [xavier_at] functions use a counter-based PRNG
+    (Threefry). The [counter] argument is NOT meant to determine the output shape! It is a "mix-in"
+    to bifurcate randomness across different counter values.


-    The architecture:
-    1. [counter] should be scalar or small (dimension-1) so it broadcasts to any result shape
-    2. [Range_over_offsets] generates indices over the result shape for mixing
-    3. [uint4x32_to_prec_uniform] reshapes from the uint4x32 backbone to the target shape
-    4. The output shape is determined by shape inference from how the result is used
+    The architecture: 1. [counter] should be scalar or small (dimension-1) so it broadcasts to any
+    result shape 2. [Range_over_offsets] generates indices over the result shape for mixing 3.
+    [uint4x32_to_prec_uniform] reshapes from the uint4x32 backbone to the target shape 4. The output
+    shape is determined by shape inference from how the result is used


For [kaiming] and [xavier] operations:
- The result tensor's shape determines fan_in/fan_out through einsum dimension capture
-    - The counter is just for randomness bifurcation (e.g., different steps in training)
-*)
+    - The counter is just for randomness bifurcation (e.g., different steps in training) *)


let create_histogram values ~num_bins ~min_val ~max_val =
let bins = Array.create ~len:num_bins 0 in
@@ -42,14 +40,13 @@ let print_histogram bins ~title ~max_width =
let percentage = Float.of_int count /. Float.of_int total *. 100.0 in
printf "Bin %2d: %s %4d (%.1f%%)\n" i bar count percentage)


-(** Test uniform_at with a SCALAR counter, letting shape be inferred from usage.
-    This is the correct way to use uniform_at - counter is for randomness bifurcation,
-    not for determining the output shape. *)
+(** Test uniform_at with a SCALAR counter, letting shape be inferred from usage. This is the correct
+    way to use uniform_at - counter is for randomness bifurcation, not for determining the output
+    shape. *)
let test_uniform_at_with_shape () =
Tensor.unsafe_reinitialize ();
let ctx = Context.auto () in
let module O = TDSL.O in
-
(* Scalar counter - just for randomness bifurcation *)
let counter = NTDSL.number 44.0 in


@@ -105,14 +102,13 @@ let test_uniform_at_with_shape () =


(** Test normal_at1 which works pointwise (one output per uint4x32 input).


-    NOTE: normal_at internally uses box_muller which creates TWO uniform random tensors.
-    The non-1 variants have shape constraints from uint4x32. Use normal_at1 which works
-    pointwise, combined with a target tensor to drive shape inference. *)
+    NOTE: normal_at internally uses box_muller which creates TWO uniform random tensors. The non-1
+    variants have shape constraints from uint4x32. Use normal_at1 which works pointwise, combined
+    with a target tensor to drive shape inference. *)
let test_normal_at_with_shape () =
Tensor.unsafe_reinitialize ();
let ctx = Context.auto () in
let module O = TDSL.O in
-
(* Scalar counter for randomness bifurcation *)
let counter = NTDSL.number 123.0 in


@@ -200,8 +196,8 @@ let test_normal_at_with_shape () =


printf "\nOverall: %s\n" (if all_passed then "ALL TESTS PASSED" else "SOME TESTS FAILED")


-(** Test that different counter values produce different random sequences.
-    This demonstrates the counter's purpose: bifurcating randomness. *)
+(** Test that different counter values produce different random sequences. This demonstrates the
+    counter's purpose: bifurcating randomness. *)
let test_counter_bifurcation () =
printf "\nCounter Bifurcation Test\n";
printf "========================\n";
@@ -245,12 +241,12 @@ let test_counter_bifurcation () =
if !diff_count > 90 && !same_count = num_values then printf "\nBifurcation test: PASS\n"
else printf "\nBifurcation test: FAIL\n"


-(** Test kaiming_at with proper shape structure.
-    The result tensor needs input dimensions for kaiming to extract fan_in.
+(** Test kaiming_at with proper shape structure. The result tensor needs input dimensions for
+    kaiming to extract fan_in.


-    This test demonstrates specifying dimensions explicitly via TDSL (not TDSL.O).
-    The counter is scalar (for randomness bifurcation), and output shape is given
-    directly to uniform_at via ~input_dims and ~output_dims. *)
+    This test demonstrates specifying dimensions explicitly via TDSL (not TDSL.O). The counter is
+    scalar (for randomness bifurcation), and output shape is given directly to uniform_at via
+    ~input_dims and ~output_dims. *)
let test_kaiming_at_with_proper_shape () =
Tensor.unsafe_reinitialize ();
let ctx = Context.auto () in
@@ -262,12 +258,10 @@ let test_kaiming_at_with_proper_shape () =
(* Scalar counter for randomness bifurcation *)
let counter = NTDSL.number 45.0 in


-  (* Use TDSL.uniform_at (not TDSL.O.uniform_at) to specify dimensions explicitly.
-     This is an alternative to shape inference from a target tensor. *)
+  (* Use TDSL.uniform_at (not TDSL.O.uniform_at) to specify dimensions explicitly. This is an
+     alternative to shape inference from a target tensor. *)
let kaiming_values =
-    TDSL.kaiming_at ~input_dims:[ fan_in ] ~output_dims:[ fan_out ]
-      TDSL.O.uniform_at
-      counter ()
+    TDSL.kaiming_at ~input_dims:[ fan_in ] ~output_dims:[ fan_out ] TDSL.O.uniform_at counter ()
in
Ir.Tnode.update_prec kaiming_values.value Ir.Ops.single;


@@ -276,8 +270,8 @@ let test_kaiming_at_with_proper_shape () =
ignore (Ocannl.Train.forward_once ctx kaiming_values);
let result = Ir.Tnode.get_values kaiming_values.value in


-  (* Expected: uniform [0,1) scaled by sqrt(6/fan_in) = sqrt(6/100) ≈ 0.245
-     So values should be in [0, 0.245) with mean ≈ 0.122 *)
+  (* Expected: uniform [0,1) scaled by sqrt(6/fan_in) = sqrt(6/100) ≈ 0.245 So values should be in
+     [0, 0.245) with mean ≈ 0.122 *)
let expected_scale = Float.sqrt (6.0 /. Float.of_int fan_in) in


printf "Kaiming Initialization Test (fan_in=%d, fan_out=%d)\n" fan_in fan_out;
@@ -303,11 +297,13 @@ let test_kaiming_at_with_proper_shape () =


(* Create and print histogram *)
let num_bins = 20 in
-  let bins = create_histogram result ~num_bins ~min_val:(min_val -. 0.01) ~max_val:(max_val +. 0.01) in
+  let bins =
+    create_histogram result ~num_bins ~min_val:(min_val -. 0.01) ~max_val:(max_val +. 0.01)
+  in
print_histogram bins ~title:"Kaiming Distribution Histogram" ~max_width:40


-(** Test xavier_at with proper shape structure.
-    Xavier needs both input and output dimensions for scaling.
+(** Test xavier_at with proper shape structure. Xavier needs both input and output dimensions for
+    scaling.


Similar to kaiming test, uses TDSL with explicit dimensions. *)
let test_xavier_at_with_proper_shape () =
@@ -323,9 +319,7 @@ let test_xavier_at_with_proper_shape () =


(* Use TDSL.uniform_at with explicit dimensions *)
let xavier_values =
-    TDSL.xavier_at ~input_dims:[ fan_in ] ~output_dims:[ fan_out ]
-      TDSL.O.uniform_at
-      counter ()
+    TDSL.xavier_at ~input_dims:[ fan_in ] ~output_dims:[ fan_out ] TDSL.O.uniform_at counter ()
in
Ir.Tnode.update_prec xavier_values.value Ir.Ops.single;


@@ -360,7 +354,9 @@ let test_xavier_at_with_proper_shape () =


(* Create and print histogram *)
let num_bins = 20 in
-  let bins = create_histogram result ~num_bins ~min_val:(min_val -. 0.01) ~max_val:(max_val +. 0.01) in
+  let bins =
+    create_histogram result ~num_bins ~min_val:(min_val -. 0.01) ~max_val:(max_val +. 0.01)
+  in
print_histogram bins ~title:"Xavier Distribution Histogram" ~max_width:40


let () =
File "arrayjit/lib/tnode.ml", line 1, characters 0-0:
diff --git a/_build/default/arrayjit/lib/tnode.ml b/_build/default/arrayjit/lib/.formatted/tnode.ml
index cb92c2a..81a4eab 100644
--- a/_build/default/arrayjit/lib/tnode.ml
+++ b/_build/default/arrayjit/lib/.formatted/tnode.ml
@@ -712,8 +712,8 @@ let create_with_reshape ~id ~label ~base_ndarray ~unpadded_dims ~padding ~from_p
in
Some (Nd.apply_with_prec { f = f_reshape_with_prec } base_ndarray)
| Some _, false ->
-           (* Create new bigarray with padding and copy source into non-padding parts.
-              semantic_dims are the data area dimensions (without padding). *)
+           (* Create new bigarray with padding and copy source into non-padding parts. semantic_dims
+              are the data area dimensions (without padding). *)
let target = Nd.create_array ~debug prec_val ~dims:padded_dims ~padding:target_padding in
let source_dims = Nd.dims base_ndarray in
(* Check total elements match, allowing shape differences *)
File "tensor/tensor.ml", line 1, characters 0-0:
diff --git a/_build/default/tensor/tensor.ml b/_build/default/tensor/.formatted/tensor.ml
index 763936c..0620ae4 100644
--- a/_build/default/tensor/tensor.ml
+++ b/_build/default/tensor/.formatted/tensor.ml
@@ -115,7 +115,7 @@ let iter_embedded ~f t =
Option.iter t.diff ~f:(fun diff -> Set.iter ~f diff.backprop.embedded_nodes)


(* Global singleton for random seed, used in init_params and random number generation *)
-let random_seed : (t option) ref = ref None
+let random_seed : t option ref = ref None


let%debug7_sexp rec init_params ?skip (t : t) : Asgns.comp =
let more_embedded = ref @@ Set.empty (module Tn) in
@@ -142,9 +142,9 @@ let%debug7_sexp rec init_params ?skip (t : t) : Asgns.comp =
Set.add (Set.union acc p.forward.embedded_nodes) p.value)
in
(* Handle random_seed specially: it's a global singleton whose forward code might have been
-     "stolen" by a tensor that isn't part of params (e.g., from an untaken conditional branch).
-     If random_seed exists and was used (no longer a fwd_root) but not in embedded_nodes,
-     we need to include random_seed's own embedded_nodes. *)
+     "stolen" by a tensor that isn't part of params (e.g., from an untaken conditional branch). If
+     random_seed exists and was used (no longer a fwd_root) but not in embedded_nodes, we need to
+     include random_seed's own embedded_nodes. *)
let embedded_nodes =
match !random_seed with
| None -> embedded_nodes
@@ -316,7 +316,8 @@ let%track7_sexp op ~(label : string list) ?(ternary_op = Shape.Pointwise_tern)
let v =
match terminal_op with
| Some (Shape.Data (Asgns.Reshape data)) ->
-        Tn.create_with_reshape ~id ~label ~unpadded_dims ~padding ~from_padded:false ~base_ndarray:data ()
+        Tn.create_with_reshape ~id ~label ~unpadded_dims ~padding ~from_padded:false
+          ~base_ndarray:data ()
| Some (Shape.Data (Asgns.Keep_shape_no_padding data)) ->
Tn.create_from_padded ~id ~label ~ndarray:data ~padding:None ()
| Some (Shape.Data (Asgns.Padded { data; padding = padding_spec; padded_value })) ->
@@ -346,8 +347,8 @@ let%track7_sexp op ~(label : string list) ?(ternary_op = Shape.Pointwise_tern)
assert false
in
let local_shape_updates =
-    List.map
-      ~f:(fun logic -> Shape.{ shape; logic; id = get_update_id (); unsafe_projections = None })
+    List.map ~f:(fun logic ->
+        Shape.{ shape; logic; id = get_update_id (); unsafe_projections = None })
@@ shape_logics orig_ts
in
List.iter ~f:Shape.propagate_shapes local_shape_updates;
File "arrayjit/lib/low_level.ml", line 1, characters 0-0:
diff --git a/_build/default/arrayjit/lib/low_level.ml b/_build/default/arrayjit/lib/.formatted/low_level.ml
index 3e03331..64884c8 100644
--- a/_build/default/arrayjit/lib/low_level.ml
+++ b/_build/default/arrayjit/lib/.formatted/low_level.ml
@@ -469,7 +469,8 @@ let%diagn2_sexp check_and_store_virtual computations_table traced static_indices
| [] -> None
| [ s ] -> Some s
| _ ->
-                       (* TODO(#133): multiple non-static symbols in affine index not yet supported *)
+                       (* TODO(#133): multiple non-static symbols in affine index not yet
+                          supported *)
raise @@ Non_virtual 51))
in
let num_syms =
File "tensor/row.ml", line 1, characters 0-0:
diff --git a/_build/default/tensor/row.ml b/_build/default/tensor/.formatted/row.ml
index ba77de1..4207c51 100644
--- a/_build/default/tensor/row.ml
+++ b/_build/default/tensor/.formatted/row.ml
@@ -2977,21 +2977,21 @@ and eliminate_row_constraint ~depth stage origin ~terminal ~(lub : row option) (
| { bcast = Broadcastable; _ } -> keep_constr ()
| { bcast = Row_var { v; beg_dims }; dims; prov } -> (
let r1 = row_of_var v prov in
-      (* If lub is not provided from context, try to get it from the row environment.
-         This is critical for non-terminal shapes where LUBs are populated through
-         inequalities but wouldn't otherwise be available until Stage 6.
-         However, we only use the environment LUB if it has fully resolved dimensions
-         (no dimension variables), as partially resolved LUBs can prevent proper
-         constraint resolution. *)
+      (* If lub is not provided from context, try to get it from the row environment. This is
+         critical for non-terminal shapes where LUBs are populated through inequalities but wouldn't
+         otherwise be available until Stage 6. However, we only use the environment LUB if it has
+         fully resolved dimensions (no dimension variables), as partially resolved LUBs can prevent
+         proper constraint resolution. *)
let lub =
match lub with
| Some _ -> lub
| None -> (
match find_row env.row_env v with
-            | Some (Bounds_row { lub = Some env_lub; _ }) ->
-                (* We need to substitute environment dimensions into the LUB to see if it's resolved *)
+            | Some (Bounds_row { lub = Some env_lub; _ }) -> (
+                (* We need to substitute environment dimensions into the LUB to see if it's
+                   resolved *)
let env_lub = subst_row env env_lub in
-                (match collect_factors env_lub.dims with
+                match collect_factors env_lub.dims with
| Some (_, []) -> Some env_lub (* All dims are known constants after substitution *)
| _ -> None (* LUB has unresolved dimension variables or collect_factors failed *))
| _ -> None)
@@ -3075,17 +3075,17 @@ and eliminate_row_constraint ~depth stage origin ~terminal ~(lub : row option) (
[],
Some ({ dims = lub_dims; bcast = _; prov = lub_prov } as lub) )
when is_stage5_up stage && Utils.safe_force coeff > denom -> (
-              (* Check if coeff > denom * product of known dimensions of the LUB.
-                 The constraint is: coeff * var / denom = total_elements(row).
-                 So: var = total_elements * denom / coeff. *)
+              (* Check if coeff > denom * product of known dimensions of the LUB. The constraint is:
+                 coeff * var / denom = total_elements(row). So: var = total_elements * denom /
+                 coeff. *)
match collect_factors lub_dims with
| Some (known_product, []) ->
let coeff_val = Utils.safe_force coeff in
if coeff_val > denom * known_product then
([ Row_eq { r1; r2 = lub; origin } ], env)
else
-                    (* Equate the row variable to the dimensions of the LUB,
-                       and compute var from the total elements *)
+                    (* Equate the row variable to the dimensions of the LUB, and compute var from
+                       the total elements *)
let var_value = known_product * denom / coeff_val in
( [
Row_eq
dune build @fmt failed
"/usr/bin/env" "bash" "-c" "opam exec -- dune build @fmt --ignore-promoted-rules || (echo "dune build @fmt failed"; exit 2)" failed with exit status 2
2025-12-15 21:53.42: Job failed: Failed: Build failed