Priors

A hidten.prior.Prior can be attached to any hidten.generic.Module. A prior rates the matrix() of its module and returns prior scores that describe how likely the module's parameters are under the prior.
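
The following NumPy sketch shows one way a Dirichlet prior could "rate" a row-stochastic emission matrix. Each row is treated as a categorical distribution and scored by its Dirichlet log density. The helper `dirichlet_log_pdf_rows` is hypothetical and not part of hidten. It does not reproduce how hidten shapes or aggregates its prior scores.

```python
import math
import numpy as np

def dirichlet_log_pdf_rows(matrix: np.ndarray, alpha: np.ndarray) -> np.ndarray:
    """Hypothetical helper (not hidten API): Dirichlet log density of each row.

    `matrix` has shape (..., K) with rows summing to 1; `alpha` has shape (K,).
    """
    # log of the multivariate Beta function B(alpha), the Dirichlet normalizer
    log_norm = sum(math.lgamma(a) for a in alpha) - math.lgamma(float(np.sum(alpha)))
    # log Dir(p; alpha) = sum_k (alpha_k - 1) * log p_k - log B(alpha)
    return np.sum((alpha - 1.0) * np.log(matrix), axis=-1) - log_norm

alpha = np.array([100.0, 0.1, 0.1])
emissions = np.array([[0.90, 0.05, 0.05],
                      [0.05, 0.90, 0.05]])
scores = dirichlet_log_pdf_rows(emissions, alpha)
# the row that favors the first symbol receives the higher score
print(scores[0] > scores[1])  # True
```

With `alpha = [1, 1, 1]` the Dirichlet is uniform over the simplex, so every row gets the same score, log 2.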

In the following example, we train a model that contains an HMM with a categorical emitter and attach a hidten.tf.prior.dirichlet.TFDirichletPrior to the emitter. Let's assume we know a priori that the first symbol is much more likely to be observed than the others, and that this is not necessarily reflected in the training data. We use the concentration parameters of the Dirichlet prior to encode this knowledge.
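
The concentrations [100, 0.1, 0.1] are the values the model uses as prior.initializer. To see what they express, look at the Dirichlet mean, alpha_k / sum(alpha), and at a few draws. This NumPy sketch illustrates only the distribution itself. It does not use hidten.

```python
import numpy as np

alpha = np.array([100.0, 0.1, 0.1])  # same concentrations as the model's prior

# The mean of Dir(alpha) is alpha / sum(alpha); the first entry is about 0.998.
mean = alpha / alpha.sum()
print(mean.round(4))

# Draws from this prior put almost all probability mass on the first symbol.
rng = np.random.default_rng(0)
samples = rng.dirichlet(alpha, size=5)
print(samples.round(3))
```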

First, we define a model.

import numpy as np
import tensorflow as tf

from hidten import HMMMode
from hidten.tf import TFHMM
from hidten.tf.prior.dirichlet import TFDirichletPrior


class HMMModel(tf.keras.Model):

    def __init__(self, use_prior: bool=False) -> None:
        super().__init__()
        self.hmm = TFHMM(states=4)

        self.hmm.emitter[0].initializer = tf.keras.initializers.GlorotNormal()

        self._use_prior = use_prior
        if use_prior:
            prior = TFDirichletPrior()
            # the 3 concentration parameters (one per symbol) are shared by
            # all 4 states in the head
            prior.share = list(range(3)) * 4
            # a priori we expect high concentration on the first symbol
            prior.initializer = [100] + [0.1] * 2
            self.hmm.emitter[0].prior = prior

        self.hmm.transitioner.allow = [
            (0, 0, 0), (0, 1, 1),
            (0, 0, 1), (0, 1, 2),
            (0, 2, 2),
            (0, 2, 3),
            (0, 3, 3),
            (0, 3, 1),
        ]
        self.hmm.transitioner.share = [(0, 2), (2, 4)]
        self.hmm.transitioner.values = [0.6, 0.4, 0.8, 0.2, 0.1, 0.9]

        self.out = self.add_weight(
            shape=(1, 4, 5),
            initializer=tf.keras.initializers.GlorotNormal(),
        )

    def build(self, input_shape: tuple[int | None, ...]) -> None:
        self.hmm.build(input_shape)

    def call(self, x: tf.Tensor) -> tf.Tensor:
        # normalize each input vector over the 3 symbols; unnormalized inputs
        # can make the HMM outputs NaN
        x = tf.nn.softmax(x)
        hmm_out = self.hmm(x, mode=HMMMode.POSTERIOR, parallel=25)
        if self._use_prior:
            # the negative log prior density is added as an extra loss term
            prior_log_pdf = self.hmm.prior_scores()
            prior_loss = -prior_log_pdf
            self.add_loss(prior_loss)
        return tf.einsum("bthd,hdo->bto", hmm_out, self.out)
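
Two index conventions in the model may be easier to follow in plain NumPy. First, we read prior.share as assigning each of the 4 x 3 emission entries to one of the 3 concentration parameters, with entries flattened in row-major (state, symbol) order. That order is our assumption, not something taken from hidten's documentation. Second, the einsum in call() contracts the heads and states axes of the HMM output, assumed here to be per-state posterior probabilities, with the output weights. The arrays below are random stand-ins, not hidten outputs.

```python
import numpy as np

# --- prior.share: ASSUMED row-major mapping of (state, symbol) entries ---
share = list(range(3)) * 4  # [0, 1, 2, 0, 1, 2, ...], length 12
concentration = np.array([100.0, 0.1, 0.1])
per_entry = concentration[share].reshape(1, 4, 3)  # (heads, states, symbols)
# under this reading, every state row sees the same three concentrations
assert np.all(per_entry == concentration)

# --- einsum("bthd,hdo->bto"): output (b, t, h, d) times weights (h, d, o) ---
rng = np.random.default_rng(0)
posterior = rng.dirichlet(np.ones(4), size=(2, 7, 1))  # batch 2, length 7, 1 head, 4 states
weights = rng.normal(size=(1, 4, 5))                   # 1 head, 4 states, 5 classes
logits = np.einsum("bthd,hdo->bto", posterior, weights)
print(logits.shape)  # (2, 7, 5)

# the same contraction written as an explicit sum over heads and states
manual = (posterior[..., None] * weights[None, None]).sum(axis=(2, 3))
assert np.allclose(logits, manual)
```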

We train the model on random data. This data clearly does not reflect the distribution we want to learn, so a consistent shift towards the first symbol after training can be attributed to the Dirichlet prior.
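
A toy problem without an HMM shows why the prior can dominate uninformative data. Fit a single categorical distribution p = softmax(z) to flat, made-up symbol counts by minimizing the cross-entropy plus the negative Dirichlet log density, which is the kind of term the model adds with add_loss. Setting the gradient to zero gives p_k proportional to n_k + alpha_k - 1. That formula is valid here because every such term is positive, and plain gradient descent recovers it.

```python
import numpy as np

def softmax(z: np.ndarray) -> np.ndarray:
    e = np.exp(z - z.max())
    return e / e.sum()

counts = np.array([100.0, 100.0, 100.0])  # flat, uninformative "data" (made up)
alpha = np.array([100.0, 0.1, 0.1])       # prior favoring the first symbol

def fit(use_prior: bool, steps: int = 5000, lr: float = 1e-3) -> np.ndarray:
    # loss(z) = -sum_k w_k log p_k, where w = counts (+ alpha - 1 with the prior);
    # the Dirichlet normalizer is constant in z and is dropped
    w = counts + (alpha - 1.0 if use_prior else 0.0)
    z = np.zeros_like(counts)
    for _ in range(steps):
        p = softmax(z)
        grad = -w + p * w.sum()  # d loss / d z
        z -= lr * grad
    return softmax(z)

p_data = fit(use_prior=False)
p_map = fit(use_prior=True)
closed_form = (counts + alpha - 1.0) / (counts + alpha - 1.0).sum()
print(p_data.round(3))  # [0.333 0.333 0.333]
print(p_map.round(3))   # [0.501 0.249 0.249]
```

In the Keras model, the cross-entropy is averaged over tokens while the prior term is added unscaled, so the balance between data and prior differs from this toy.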

ds = tf.data.Dataset.from_tensor_slices((
        np.random.normal(size=(64, 1000, 3)),
        np.random.randint(0, 5, size=(64, 1000)),
))
ds = ds.batch(32)

model = HMMModel(use_prior=True)
model.build((None, None, 3))

model.compile(
    # a large learning rate makes the effect of the prior visible after only
    # a few epochs of training
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.1),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    jit_compile=True,
)

before_training = model.hmm.emitter[0].matrix()

model.fit(ds, epochs=3)

after_training = model.hmm.emitter[0].matrix()

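Before training, the emission matrix was as follows. Exact values differ between runs because neither the initializer nor the data is seeded.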

<tf.Tensor: shape=(1, 4, 3), dtype=float32, numpy=
array([[[0.35459444, 0.31414294, 0.33126262],
        [0.37932956, 0.4167584 , 0.20391206],
        [0.2959837 , 0.31581268, 0.38820365],
        [0.2735348 , 0.53488946, 0.19157575]]], dtype=float32)>

After training, the probability of the first symbol has increased in every state:

<tf.Tensor: shape=(1, 4, 3), dtype=float32, numpy=
array([[[0.64187217, 0.17431244, 0.18381536],
        [0.66569066, 0.22450018, 0.10980911],
        [0.5790523 , 0.18882595, 0.2321217 ],
        [0.5522496 , 0.32969874, 0.11805164]]], dtype=float32)>
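
The shift can be quantified directly from the two matrices printed above. The values below are copied from those printouts, so they describe that particular run, and the check uses only NumPy.

```python
import numpy as np

# emission matrices (single head) copied from the printouts above
before = np.array([[0.35459444, 0.31414294, 0.33126262],
                   [0.37932956, 0.4167584 , 0.20391206],
                   [0.2959837 , 0.31581268, 0.38820365],
                   [0.2735348 , 0.53488946, 0.19157575]])
after = np.array([[0.64187217, 0.17431244, 0.18381536],
                  [0.66569066, 0.22450018, 0.10980911],
                  [0.5790523 , 0.18882595, 0.2321217 ],
                  [0.5522496 , 0.32969874, 0.11805164]])

# each row is a categorical distribution over the 3 symbols
assert np.allclose(before.sum(axis=1), 1.0, atol=1e-5)
assert np.allclose(after.sum(axis=1), 1.0, atol=1e-5)

# change in the probability of the first symbol, per state
shift = after[:, 0] - before[:, 0]
print(shift.round(3))  # [0.287 0.286 0.283 0.279]
```

Every state gains roughly 0.28 probability on the first symbol, and after training the first symbol is the most likely emission in all four states.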