Photon Weave Architecture & Typing Status
High-Level Architecture
Config & RNG: photon_weave/photon_weave.py exposes Config (JIT, contraction, seeds). RNG keys come from core/rng.py; __init__.py pins JAX_ENABLE_LEGACY_RNG. Tests run with JAX_PLATFORMS=cpu.
States: state/base_state.py defines the abstract base; concrete states live in state/fock.py, state/polarization.py, state/custom_state.py. Each holds dimensions, index within a product, expansion level (Label/Vector/Matrix), and data (state array or label).
Containers: state/envelope.py pairs Fock + Polarization, manages combination/reordering, apply/measure/trace, and bridges to ops and kernels. state/composite_envelope.py combines multiple envelopes into product states.
Operations: operation/*_operation.py define enums (FockOperationType, PolarizationOperationType, CustomStateOperationType) plus operation/operation.py wrappers. They compute operators and call into state/utils/operations.py, which routes to core/adapters.py / core/kernels.py.
Kernels & JIT: core/kernels.py contains pure JAX kernels for apply/measure/trace; core/jitted.py wraps them. state/utils/shape_planning.py builds ShapePlan (dims, target indices) to feed the jitted paths.
Measurements: state/utils/measurements.py provides POVM and projective measurement helpers for vectors/matrices and feeds into state methods.
Examples & Benchmarks: examples/*.py (e.g., mach_zehnder_interferometer.py, super_dense_coding.py, time_bin_encoding.py) and benchmarks/* exercise the public surface, which makes them useful for validating behavior when refactoring.
Current mypy Failure Themes (snapshot)
The code compiles and tests pass, but mypy reports many errors due to interface drift. Main categories:
Protocol vs. concrete mismatch: BaseStateLike/EnvelopeLike signatures differ from BaseState/Envelope (measure/apply_kraus args/returns; property types like expansion_level, _num_quanta, uid).
Return types from kernels/adapters: Measurement helpers now return JAX scalars, but adapters are still annotated to return int (e.g., core/adapters.py and core/kernels.py annotations).
Enums annotated: Enum members in operation/polarization_operation.py and operation/fock_operation.py were annotated; mypy requires unannotated members.
Shape planning shims: Local helpers in state/utils/shape_planning.py lack return annotations and accept object, causing arg-type errors when calling apply/measure kernels.
Measurements utilities: Inconsistent tuple shapes (Optionals in returns), list mutation typing (list[Array] vs. Sequence), and missing annotations lead to errors in state/utils/measurements.py.
Composite/Envelope typing: Caches (_plan_cache) and plan builders expect BaseState but receive BaseStateLike; trace_out/measure_POVM arg/return types disagree.
BaseState collections: Lists/dicts typed as concrete BaseState conflict with protocol expectations (BaseStateLike) across envelope/composite paths.
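The enum theme is the easiest one to reproduce in isolation. The sketch below uses two hypothetical enums (ToyAnnotatedOp and ToyPlainOp, not the real FockOperationType or PolarizationOperationType): both behave identically at runtime, but the annotated form is the one mypy rejects, so the type information moves to the functions that consume the enum.

```python
from enum import Enum


class ToyAnnotatedOp(Enum):
    # Works at runtime, but this is the pattern mypy complains about:
    # enum members carrying their own annotations.
    CREATE: str = "create"
    ANNIHILATE: str = "annihilate"


class ToyPlainOp(Enum):
    # Preferred form: members are left unannotated.
    CREATE = "create"
    ANNIHILATE = "annihilate"


def describe(op: ToyPlainOp) -> str:
    """The annotation lives on the consumer, not on the enum member."""
    return f"{op.name}={op.value}"
```

Because runtime behavior does not change, this category can be fixed mechanically without touching tests.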
Refactor Plan (to reduce circular imports & simplify operations)
Phases are ordered to de-risk changes; run JAX_PLATFORMS=cpu poetry run pytest after each chunk and re-run a narrow mypy target to track progress.
Stabilize Protocols
Create minimal interfaces.py protocols per layer: BaseStateProto (dimensions/index/state, _num_quanta, expand, measure, apply_kraus, trace_out, resize, uid, measured), EnvelopeProto (fock, polarization, measure/trace/apply_kraus), CompositeEnvelopeProto.
Keep protocol methods permissive (*args, **kwargs) but set consistent return aliases (e.g., OutcomeMap = dict[BaseStateProto, int]).
Ensure expansion_level/uid are typed to accept current concrete types (e.g., ExpansionLevel | None, object).
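As a rough illustration of the protocol layer described above, the following sketch defines a trimmed BaseStateProto with permissive method signatures and a shared OutcomeMap alias. The ExpansionLevel enum, the reduced member list, and ToyState are stand-ins invented for this example; the real interfaces.py would carry the full member list from the first item.

```python
from enum import Enum
from typing import Any, Optional, Protocol, runtime_checkable


class ExpansionLevel(Enum):
    # Stand-in for the project's expansion levels (Label/Vector/Matrix).
    LABEL = 0
    VECTOR = 1
    MATRIX = 2


@runtime_checkable
class BaseStateProto(Protocol):
    """Trimmed structural interface; method signatures stay permissive."""

    dimensions: int
    expansion_level: Optional[ExpansionLevel]
    uid: object

    def measure(self, *args: Any, **kwargs: Any) -> "OutcomeMap": ...

    def apply_kraus(self, *args: Any, **kwargs: Any) -> None: ...


# One shared alias so every layer agrees on what a measurement returns.
OutcomeMap = dict[BaseStateProto, int]


class ToyState:
    """Hypothetical concrete state; it never subclasses the protocol."""

    def __init__(self, dimensions: int, uid: object) -> None:
        self.dimensions = dimensions
        self.expansion_level: Optional[ExpansionLevel] = ExpansionLevel.LABEL
        self.uid = uid

    def measure(self, *args: Any, **kwargs: Any) -> OutcomeMap:
        return {self: 0}  # placeholder outcome, no physics here

    def apply_kraus(self, *args: Any, **kwargs: Any) -> None:
        return None
```

ToyState satisfies the protocol structurally, so isinstance(ToyState(2, "a"), BaseStateProto) holds without any import from the concrete module into the interface module.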
Align Concrete Classes
Update BaseState.measure/apply_kraus signatures to match the protocol (allow *states: BaseStateLike and return OutcomeMap).
Make _num_quanta a read-only property with concrete implementations in Fock/Polarization/CustomState.
Normalize uid type to object (or UUID | str) in concrete classes and in OutcomeMap keys.
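The read-only property item could look roughly like this. ToyBaseState and ToyFock are hypothetical stand-ins for BaseState and Fock, and the stored photon count is an assumption made only to have something to return.

```python
from abc import ABC, abstractmethod


class ToyBaseState(ABC):
    @property
    @abstractmethod
    def _num_quanta(self) -> int:
        """Read-only: subclasses compute it, callers cannot assign it."""


class ToyFock(ToyBaseState):
    def __init__(self, photons: int) -> None:
        self._photons = photons

    @property
    def _num_quanta(self) -> int:
        return self._photons


def try_overwrite(state: ToyBaseState) -> bool:
    """Return True if assignment succeeded (it should not)."""
    try:
        state._num_quanta = 99  # type: ignore[misc]
    except AttributeError:
        return False
    return True
```

Declaring the property on the abstract base gives the protocol and every concrete class the same attribute type, which is the mismatch mypy currently reports.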
Re-layer Imports
Enforce dependency direction: interfaces → core (kernels/adapters/meta) → state → operation → envelope/composite → examples/tests.
Move any interface-only needs out of state/utils so they import protocols, not concretes. For example, shape_planning should work on Sequence[BaseStateProto] and return ShapePlan with only dims/indices, not concrete types.
Where circular imports remain, introduce lightweight “view” modules (e.g., state/typing.py) that export aliases without importing heavy implementations.
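A minimal sketch of the shape_planning example above, assuming ShapePlan only needs dims and target indices. The HasDimensions protocol, the field names, and build_plan are hypothetical; the point is that the planner depends on a one-attribute protocol rather than on concrete state classes.

```python
from dataclasses import dataclass
from typing import Protocol, Sequence


class HasDimensions(Protocol):
    # The only thing the planner needs to know about a state.
    dimensions: int


@dataclass(frozen=True)
class ShapePlan:
    dims: tuple[int, ...]
    target_indices: tuple[int, ...]


def build_plan(
    states: Sequence[HasDimensions], targets: Sequence[HasDimensions]
) -> ShapePlan:
    """Record dims in product order and the positions of the target states."""
    dims = tuple(s.dimensions for s in states)
    # Identity comparison: two distinct states may share the same dimensions.
    target_indices = tuple(
        i for i, s in enumerate(states) if any(s is t for t in targets)
    )
    return ShapePlan(dims=dims, target_indices=target_indices)


@dataclass
class ToyState:
    """Hypothetical state used only to exercise the planner."""

    dimensions: int
```

The returned plan holds plain integers only, so it can be cached or passed to jitted code without dragging state objects along.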
Harmonize Measurement API
Define a single MeasurementResult = tuple[OutcomeMap, jnp.ndarray | None] (or similar) and use it across measure_vector, measure_matrix, measure_POVM.
In measurements.py, remove Optional tails in return tuples where callers expect concrete arrays; cast JAX scalar outcomes to int only at API boundaries that require Python ints (tests already handle JAX scalars).
Ensure RNG keys are propagated and returned uniformly.
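To show the shape of a uniform return value with explicit key passing, here is a stdlib-only sketch of a projective measurement. It is not the project's measure_vector: the integer key and random.Random are stand-ins for JAX PRNG keys, and the MeasurementResult fields are assumptions chosen for the example.

```python
import random
from typing import NamedTuple, Sequence


class MeasurementResult(NamedTuple):
    outcome: int
    post_state: tuple[complex, ...]  # always present, no Optional tail
    next_key: int  # stand-in for a fresh PRNG key


def measure_vector(state: Sequence[complex], key: int) -> MeasurementResult:
    """Sample a basis outcome with Born-rule weights and collapse the state."""
    rng = random.Random(key)  # all randomness derives from the key passed in
    weights = [abs(a) ** 2 for a in state]
    outcome = rng.choices(range(len(state)), weights=weights)[0]
    post_state = tuple(1 + 0j if i == outcome else 0j for i in range(len(state)))
    next_key = rng.getrandbits(32)  # handed back so callers can chain calls
    return MeasurementResult(outcome, post_state, next_key)
```

Callers thread next_key into the following call, so the same starting key always reproduces the same sequence of outcomes.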
Adapters/Kernels Typing Cleanup
Update return annotations in core/kernels.py and core/adapters.py to use jnp.ndarray (scalar) for outcomes, not int.
Add explicit return types to small helpers (e.g., local lambdas in shape_planning) to quiet mypy without altering behavior.
Envelope/Composite Simplification
Type caches (_plan_cache) against ShapePlan keys using Tuple[Tuple[int, int], Tuple[int, int]] or a dedicated key dataclass to avoid tuple-of-tuple indexing errors.
Convert collections in envelope/composite paths to Sequence[BaseStateLike] and cast only at boundaries that require concrete states (e.g., resizing).
Reduce double-routing: ensure apply_operation uses a single path to compute meta and call apply_operation_vector/matrix with already-ordered state lists.
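The dedicated key dataclass option might be sketched as follows. PlanKey, PlanCache, and the ShapePlan fields are hypothetical names for illustration; the real _plan_cache may store more than this.

```python
from dataclasses import dataclass
from typing import Sequence


@dataclass(frozen=True)
class PlanKey:
    # Frozen dataclasses are hashable, so they can be used as dict keys.
    dims: tuple[int, ...]
    targets: tuple[int, ...]


@dataclass(frozen=True)
class ShapePlan:
    dims: tuple[int, ...]
    target_indices: tuple[int, ...]


class PlanCache:
    def __init__(self) -> None:
        self._plans: dict[PlanKey, ShapePlan] = {}
        self.builds = 0  # counts cache misses, handy in tests

    def get(self, dims: Sequence[int], targets: Sequence[int]) -> ShapePlan:
        # Normalize to tuples so lists and tuples hit the same entry.
        key = PlanKey(tuple(dims), tuple(targets))
        plan = self._plans.get(key)
        if plan is None:
            self.builds += 1
            plan = ShapePlan(key.dims, key.targets)
            self._plans[key] = plan
        return plan
```

Named fields replace positional tuple-of-tuple indexing, so a key is read as key.dims and key.targets and mypy can check both.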
Operation Enums & Factory
Keep enum members unannotated; expose typed factories that return Operation with fully-typed operator fields.
Co-locate parameter validation with operator construction to reduce cross-module imports.
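One way to read the two items above is a factory that validates its parameters and builds the operator in the same function. The sketch uses a hypothetical ToyFockOp enum and a simplified Operation record with a nested-tuple matrix; the real Operation wrapper and its JAX arrays will differ.

```python
import math
from dataclasses import dataclass
from enum import Enum

Matrix = tuple[tuple[float, ...], ...]


class ToyFockOp(Enum):
    # Members stay unannotated; types live on the factory below.
    CREATION = "creation"
    ANNIHILATION = "annihilation"


@dataclass(frozen=True)
class Operation:
    kind: ToyFockOp
    operator: Matrix


def creation(dim: int) -> Operation:
    """Truncated creation operator: entry (n + 1, n) equals sqrt(n + 1)."""
    if dim < 1:  # validation sits next to construction
        raise ValueError(f"dim must be >= 1, got {dim}")
    operator = tuple(
        tuple(math.sqrt(col + 1) if row == col + 1 else 0.0 for col in range(dim))
        for row in range(dim)
    )
    return Operation(ToyFockOp.CREATION, operator)
```

Since validation and construction share a function, no separate validator module has to import the operation module or the other way round.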
Verification Using Examples
Use existing examples as smoke tests for surface APIs:
examples/mach_zehnder_interferometer.py and examples/time_bin_encoding.py cover envelope combine/measure flows.
examples/super_dense_coding.py exercises composite envelopes and multiple operations.
Benchmarks (benchmarks/lossy_circuit/*.py, benchmarks/contraction_vs_jitted.py) ensure JIT/contraction paths stay performant.
After each phase, run a small set of examples with JAX_PLATFORMS=cpu to confirm behavior matches current outputs (add assertions in a temporary harness if needed; remove before commit).
Migration Steps
Phase 1: Fix protocols and enum annotations; adjust kernel/adapter return types; rerun mypy core/.
Phase 2: Align BaseState/Envelope signatures and measurement utilities; rerun mypy state/utils/measurements.py state/envelope.py.
Phase 3: Clean up composite_envelope/shape_planning caches and casts; rerun mypy state/composite_envelope.py state/utils/shape_planning.py.
Final: Full mypy across photon_weave/ and pytest on CPU JAX.
Quick Insights for Debugging Current Execution
RNG determinism: Seeds are set via Config.set_seed; JAX legacy RNG is enabled. Measurement outcomes in tests accept JAX scalar results.
JIT vs. dynamic dims: Config.use_jit with dynamic_dimensions raises in Envelope.apply_operation; resizing is only allowed when JIT is off.
Platform: GitHub Actions and tox set JAX_PLATFORMS=cpu; avoid GPU-specific behavior in refactors.
Use this document as the living map for the typing cleanup and architectural hardening. Update sections as refactors land and mypy noise drops.
Proposed Extensions & API Simplification
The following proposals aim to make the system more modular, interoperable, and pleasant to use, while keeping compatibility for two development rounds before fully adopting the new scheme.
Intermediate Representation (IR) for Interop
Introduce a small, framework-agnostic IR in core/ir.py: gates/ops as typed dataclasses (name, params, targets), state specs (dims, basis), and circuits as ordered lists. Keep it JAX-friendly (pytree-compatible) so it can be jitted/vmap’d.
Add import/export adapters:
Strawberry Fields: map Gaussian ops and Fock-level ops to IR; export IR back to SF circuits.
Piquasso: translate photonic ops and states via IR.
Qiskit: map polarization/Fock subsets to qubit gates where possible; use IR as the bridge.
Keep IR pure-data; no side effects. Concrete simulators (our kernels or external backends) consume IR.
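A pure-data IR along these lines could be as small as three frozen dataclasses. The field names follow the description above (name, params, targets; dims, basis), but the defaults, the then helper, and the target-range check are assumptions made for this sketch.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class StateSpec:
    dims: tuple[int, ...]
    basis: str = "fock"  # assumed default, for illustration only


@dataclass(frozen=True)
class OpSpec:
    name: str
    params: tuple[float, ...] = ()
    targets: tuple[int, ...] = ()


@dataclass(frozen=True)
class CircuitSpec:
    state: StateSpec
    ops: tuple[OpSpec, ...] = ()

    def then(self, op: OpSpec) -> "CircuitSpec":
        """Return a new circuit with op appended; the original is untouched."""
        n_modes = len(self.state.dims)
        if any(not 0 <= t < n_modes for t in op.targets):
            raise ValueError(f"targets {op.targets} out of range for {n_modes} modes")
        return CircuitSpec(self.state, self.ops + (op,))
```

Nothing here executes anything: a simulator, or an exporter for an external framework, walks circuit.ops in order and decides what each OpSpec means.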
Unique Base Class per Module
Define a single base class per major module folder:
state/base_state.py → BaseState (already present, align to protocol).
operation/base_operation.py → shared operation base, encapsulating operator caching and parameter validation.
core/base_kernel.py (optional) → typed protocol for kernels/adapters to reduce import cycles.
Within each folder, type against the local base/protocol to avoid cross-imports; export light interfaces that re-export these bases to other layers.
Math Decoupling into Core
Move or wrap mathematical primitives into core/math/:
Operator builders (creation/annihilation, Pauli, rotations, displacements).
Contraction utilities and einsum patterns.
Measurement probability helpers.
State and operation classes call these helpers; avoid inline math in higher layers. This reduces duplication and keeps JAX/pytree constraints localized.
JIT/vmap/Differentiability First
Enforce that new paths are side-effect free and functional: inputs → outputs without in-place Python mutations. Return updated states instead of mutating when in JIT mode; keep legacy mutating paths for compatibility.
Ensure all new math lives in JAX-compatible functions; use jaxtyping-style annotations (or doc hints) to flag differentiable paths. Provide a small test matrix to validate jit, vmap, and grad on representative ops.
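The "return updated states instead of mutating" rule can be shown without JAX. In this stdlib-only sketch, ToyState and apply_phase are hypothetical; the pattern (frozen record in, new frozen record out) is what a jitted path would need.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class ToyState:
    amplitudes: tuple[complex, ...]


def apply_phase(state: ToyState, phases: tuple[complex, ...]) -> ToyState:
    """Multiply each amplitude by its phase factor and return a new state."""
    new_amplitudes = tuple(a * p for a, p in zip(state.amplitudes, phases))
    # replace() builds a fresh instance; the input is never modified.
    return replace(state, amplitudes=new_amplitudes)
```

Because the input is frozen, accidental in-place updates raise immediately instead of silently diverging between JIT and non-JIT runs.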
New API Layer with Compatibility Shim
Add a thin fluent API in api/ (e.g., api/envelope.py, api/circuit.py) that:
Builds envelopes/composites declaratively: env = EnvelopeBuilder().fock(dim=2, state=1).polarization("R").combine().
Applies ops via method chaining: env.apply(FockOp.displace(alpha)).measure(key=...).
Emits IR under the hood and executes via core kernels.
Provide adapters that map the new API calls to legacy Envelope/CompositeEnvelope methods for two development cycles. Mark legacy interfaces as “compat mode” and plan a removal milestone.
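A toy version of the builder chain from the first item, assuming combine() returns a plain data record. EnvelopeSpec and the validation rules are invented for this sketch; a real builder would emit IR or construct an Envelope.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class EnvelopeSpec:
    fock_dim: int
    fock_state: int
    polarization: str


class EnvelopeBuilder:
    def __init__(self) -> None:
        self._fock: Optional[tuple[int, int]] = None
        self._polarization: Optional[str] = None

    def fock(self, dim: int, state: int) -> "EnvelopeBuilder":
        if not 0 <= state < dim:
            raise ValueError(f"state {state} does not fit in dim {dim}")
        self._fock = (dim, state)
        return self  # returning self is what enables chaining

    def polarization(self, label: str) -> "EnvelopeBuilder":
        self._polarization = label
        return self

    def combine(self) -> EnvelopeSpec:
        if self._fock is None or self._polarization is None:
            raise ValueError("both fock() and polarization() must be set")
        dim, state = self._fock
        return EnvelopeSpec(dim, state, self._polarization)
```

With these definitions, EnvelopeBuilder().fock(dim=2, state=1).polarization("R").combine() yields EnvelopeSpec(2, 1, "R"), and an incomplete chain fails at combine() rather than later in a kernel.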
Roadmap Checkpoints
Phase A (compat): Introduce IR, core math wrappers, and new API facades that delegate to existing classes. Keep mutating behavior but add pure variants for JIT/vmap.
Phase B (migration): Flip default paths to pure/IR-driven execution; keep legacy shims emitting deprecation warnings.
Phase C (cleanup): Remove legacy shims; enforce protocol alignment and mypy cleanliness; examples updated to new API.
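For Phase B, a legacy shim that keeps mutating behavior while warning callers might look like this. The decorator name legacy_shim and the two toy functions are hypothetical; only the standard warnings machinery is assumed.

```python
import functools
import warnings
from typing import Any, Callable


def legacy_shim(new_name: str) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    """Wrap a legacy entry point so each call emits a DeprecationWarning."""

    def decorate(func: Callable[..., Any]) -> Callable[..., Any]:
        @functools.wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            warnings.warn(
                f"{func.__name__} is in compat mode; use {new_name} instead",
                DeprecationWarning,
                stacklevel=2,  # point the warning at the caller
            )
            return func(*args, **kwargs)

        return wrapper

    return decorate


def apply_pure(state: tuple[int, ...], shift: int) -> tuple[int, ...]:
    """New-style path: returns a new value, mutates nothing."""
    return tuple(x + shift for x in state)


@legacy_shim("apply_pure")
def apply_legacy(state: list[int], shift: int) -> list[int]:
    """Old-style path: mutates in place by delegating to the pure variant."""
    state[:] = apply_pure(tuple(state), shift)
    return state
```

The legacy function delegates to the pure one, so Phase C can delete the shim without touching the math.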
Refactoring Direction (Envelope-First Simplicity)
Preserve the public Envelope and CompositeEnvelope interfaces exactly as described in the paper (combine, apply, measure, trace_out, metadata such as wavelength/temporal profile). They remain the primary user abstraction and signature surface.
Behind those interfaces, map calls onto the lightweight circuit/IR path (CircuitSpec + planner + kernels). The envelopes become facades that assemble IR from their state containers and forward to the runtime executor, keeping behavior while reducing internal coupling.
Keep typing minimal and concrete: prefer small data records (StateSpec, OpSpec) and runtime assertions over Protocol-heavy checks; avoid Any/object except at external boundaries. This keeps serialization easy and JAX-friendly.
Shape planning should consume only dimension/target tuples, not concrete envelope classes, so the IR layer stays pure-data and easy to serialize/compile.
Compatibility: during Phase A/B, continue to route the existing envelope/composite methods to the new runtime while tests/examples validate parity; deprecation only applies to internal helpers, not the envelope APIs.
Naming & Circuit Mapping
To keep a single, memorable user-facing name while mapping cleanly to the circuit-oriented internals:
Expose the public API under the existing package name (photon_weave) but brand the circuit DSL as Weave Circuits. Users import from photon_weave.weave (e.g., WeaveCircuit, WeaveOp, WeaveState).
Internally, WeaveCircuit is a light facade over the IR (CircuitSpec, OpSpec, StateSpec); constructors only assemble these specs and never own kernels.
Execution routes through the runtime executor (execute_circuit), which translates specs to planner/kernels. This preserves the unique naming for users while keeping internals pure-data and circuit-behavior driven.
Interop layers (interop/*) accept or emit Weave Circuits, so external frameworks only learn one name while mapping to IR remains an implementation detail.
These changes, combined with the earlier typing refactor plan, should yield a cleaner, interoperable, and JAX-friendly architecture with a gentle migration path for users and downstream frameworks.
Execution Profiling Snapshot (Mach–Zehnder Example)
Timing recorded on CPU with JAX_PLATFORMS=cpu for the core steps inside mach_zender_probabilities:
Cold start (phase=0.0): env build 0.06 ms; op build 0.015 ms; composite build 0.033 ms; first beam splitter apply ~718 ms (JAX/XLA compile cold-start + contraction setup); phase shift apply ~68 ms; second beam splitter apply ~2.7 ms; trace-outs ~0.18 ms.
Warm cache (subsequent phases): env/op/composite each <0.05 ms; first beam splitter apply ~3–4 ms; phase shift ~2 ms; second beam splitter ~2–3 ms; trace-outs ~0.17 ms.
Findings:
The dominant cost is the first apply_operation (JAX compilation and contraction path). Subsequent calls reuse compiled kernels and drop to low single-digit milliseconds.
Python-layer overhead (Operation → ShapePlan → adapters → kernels) is negligible after warm-up, but batching ops into a single IR execution could reduce Python hops further.
Practical guidance: perform a warm-up call before benchmarking; cache ShapePlans/compiled kernels aggressively; consider a pure functional “apply IR sequence” path to amortize setup.
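The warm-up advice can be illustrated with a stdlib stand-in. Here functools.lru_cache plays the role of the compiled-kernel cache and time.sleep simulates the one-off compilation cost; build_kernel, apply_operation, and timed_ms are hypothetical helpers, not the project's functions, and the numbers they produce say nothing about real JAX timings.

```python
import functools
import time
from typing import Callable


@functools.lru_cache(maxsize=None)
def build_kernel(dims: tuple[int, ...]) -> Callable[[int], int]:
    """Pretend to compile a kernel for the given dims (slow the first time)."""
    time.sleep(0.05)  # stand-in for one-off compilation cost
    size = 1
    for d in dims:
        size *= d
    return lambda x: x * size


def apply_operation(dims: tuple[int, ...], x: int) -> int:
    # The first call per dims pays for build_kernel; later calls hit the cache.
    return build_kernel(dims)(x)


def timed_ms(fn: Callable[[], int]) -> float:
    """Wall-clock time of a single call, in milliseconds."""
    start = time.perf_counter()
    fn()
    return (time.perf_counter() - start) * 1e3
```

Timing timed_ms(lambda: apply_operation((2, 2), 3)) twice shows the same cold/warm split as the snapshot: the first measurement includes the simulated build, the second does not, so only the second is a fair benchmark of the steady state.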
Progress Checkpoint (Runtime & Envelope Integration)
IR + runtime are live: StateSpec/OpSpec/MeasureSpec/TraceOutSpec plus execute_circuit route through planner/kernels without touching envelope classes.
Envelope and CompositeEnvelope now emit IR internally for apply_operation while preserving their public API and paper-described semantics.
Planner now supports dims-only planning (plan_from_dims) so core layers can operate without importing state objects.
Next up: extend IR delegation to envelope/composite measurement and trace-out flows to remove residual state↔operation coupling.
Concrete Delivery Plan (preserve Envelope interface)
Scaffold IR/runtime (Phase A): add core/ir.py (StateSpec, OpSpec, CircuitSpec) and runtime/execute_circuit that call existing kernels via the planner. Keep IR pure-data and JSON-serializable.
Planner extraction: move shape_planning helpers to accept only dims/targets; cache plans by (dims, targets). Ensure JIT paths are shape-stable.
Envelope/Composite facades: keep current public methods/metadata; internally translate to IR and dispatch to runtime. Maintain combine/apply/measure/trace_out signatures and behaviors from the paper.
Operation builders: refactor *_operation.py to produce OpSpec + operator matrices; enums stay unannotated. Add small validation helpers; remove protocol dependencies.
Measurements/RNG unification: normalize measurement returns across vector/matrix/POVM; pass/return RNG keys explicitly; avoid Optional tails and widen scalar handling to JAX arrays.
Interop adapters: add interop/ shims that map Weave Circuits (IR) to StrawberryFields/Piquasso/Qiskit and back; keep Envelope-facing API untouched.
Testing & parity: wire examples/benchmarks through the new runtime via envelopes; add snapshot tests for IR JSON round-trip and a few deterministic measurement cases on CPU. Use examples as smoke tests after each milestone.
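The IR JSON round-trip snapshot test mentioned in the last item could start from something like this. The reduced OpSpec/CircuitSpec records and the to_json/from_json helpers are assumptions for the sketch; the real specs in core/ir.py may carry more fields.

```python
import json
from dataclasses import asdict, dataclass


@dataclass(frozen=True)
class OpSpec:
    name: str
    params: tuple[float, ...]
    targets: tuple[int, ...]


@dataclass(frozen=True)
class CircuitSpec:
    dims: tuple[int, ...]
    ops: tuple[OpSpec, ...]


def to_json(circuit: CircuitSpec) -> str:
    # asdict() recurses into nested dataclasses; tuples serialize as JSON lists.
    return json.dumps(asdict(circuit), sort_keys=True)


def from_json(text: str) -> CircuitSpec:
    raw = json.loads(text)
    # JSON has no tuple type, so lists are converted back explicitly.
    ops = tuple(
        OpSpec(o["name"], tuple(o["params"]), tuple(o["targets"]))
        for o in raw["ops"]
    )
    return CircuitSpec(tuple(raw["dims"]), ops)
```

A snapshot test then asserts from_json(to_json(circuit)) == circuit and compares to_json(circuit) against a stored string; sort_keys=True keeps that string stable across runs.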