Building LSTMs with PyTorch and Lightning AI Part 3: Finishing the LSTM Cell

In the previous article, we started building the LSTM cell.

In this article, we will finish the LSTM unit, write the forward pass, and configure the optimizer.

Creating the Short-Term Memory

In this stage, we create the updated short-term memory, which also serves as the output of the LSTM unit. To do this, we first determine what percentage of the information should be passed to the output.

First, we calculate the output percentage:

output_percent = torch.sigmoid(
    (short_memory * self.wo1) +
    (input_value * self.wo2) +
    self.bo1
)

Here:

  • wo1 is the weight associated with the current short-term memory.
  • wo2 is the weight associated with the current input value.
  • bo1 is the bias term.

The sigmoid function produces a value between 0 and 1, representing the percentage of information that should be passed to the output.
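To see the gate in action, here is a minimal sketch of the same arithmetic in plain Python. It uses `math.exp` in place of `torch.sigmoid` so it runs without PyTorch, and the weight, bias, and input values are made up for illustration; in the real model, `wo1`, `wo2`, and `bo1` are trainable parameters.

```python
import math


def sigmoid(x: float) -> float:
    """Logistic sigmoid: squashes any real number into the range (0, 1)."""
    return 1.0 / (1.0 + math.exp(-x))


# Hypothetical values chosen only for illustration.
short_memory = 0.5   # incoming short-term memory
input_value = 1.0    # today's input
wo1, wo2, bo1 = 2.0, -1.0, 0.5

# Same weighted sum as output_percent above: 1.0 - 1.0 + 0.5 = 0.5
output_percent = sigmoid((short_memory * wo1) + (input_value * wo2) + bo1)
print(round(output_percent, 4))  # → 0.6225
```

Whatever the weighted sum turns out to be, the sigmoid keeps the result strictly between 0 and 1, which is what lets it act as a percentage.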

Next, we use this percentage to compute the new short-term memory.

We first apply the tanh activation function to the updated long-term memory, which squashes it into the range -1 to 1, and then multiply the result by output_percent.

updated_short_memory = torch.tanh(updated_long_memory) * output_percent
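The same step in plain Python makes the two operations easy to see separately. This is a sketch with made-up values for `updated_long_memory` and `output_percent`, using `math.tanh` in place of `torch.tanh`.

```python
import math

# Hypothetical values for illustration only.
updated_long_memory = 2.0   # long-term memory is not limited to (-1, 1)
output_percent = 0.6        # result of the output gate, always in (0, 1)

# tanh squashes the long-term memory into (-1, 1) ...
squashed = math.tanh(updated_long_memory)
# ... and the output gate then keeps only a fraction of it.
updated_short_memory = squashed * output_percent

print(round(squashed, 4), round(updated_short_memory, 4))  # → 0.964 0.5784
```

Because both factors are bounded, the new short-term memory always stays between -1 and 1, even when the long-term memory grows large.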

Finally, we return the updated long-term and short-term memory values:

return [updated_long_memory, updated_short_memory]

At this point, our lstm_unit() method is complete. Putting all the pieces together:

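def lstm_unit(self, input_value, long_memory, short_memory):

    # Forget gate: percentage of the long-term memory to remember
    long_remember_percent = torch.sigmoid(
        (short_memory * self.wlr1) +
        (input_value * self.wlr2) +
        self.blr1
    )

    # Input gate: percentage of the potential memory to remember
    potential_remember_percent = torch.sigmoid(
        (short_memory * self.wpr1) +
        (input_value * self.wpr2) +
        self.bpr1
    )

    # Potential long-term memory, squashed into (-1, 1) by tanh
    potential_memory = torch.tanh(
        (short_memory * self.wp1) +
        (input_value * self.wp2) +
        self.bp1
    )

    # New long-term memory = remembered old memory + remembered potential memory
    updated_long_memory = (
        (long_memory * long_remember_percent) +
        (potential_remember_percent * potential_memory)
    )

    # Output gate: percentage of tanh(updated_long_memory) to pass on
    output_percent = torch.sigmoid(
        (short_memory * self.wo1) +
        (input_value * self.wo2) +
        self.bo1
    )

    # New short-term memory, which is also the unit's output
    updated_short_memory = (
        torch.tanh(updated_long_memory) * output_percent
    )

    return [updated_long_memory, updated_short_memory]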
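For experimenting outside PyTorch, the method above can be mirrored with plain Python floats. This is a sketch under our own assumptions: the parameters live in a dict keyed by the same names as the `self.*` attributes instead of on a LightningModule, and the 0.5 values are arbitrary placeholders rather than trained weights.

```python
import math


def sigmoid(x: float) -> float:
    """Logistic sigmoid: maps any real number into (0, 1)."""
    return 1.0 / (1.0 + math.exp(-x))


def lstm_unit(p: dict, input_value: float,
              long_memory: float, short_memory: float) -> tuple:
    """One LSTM step on plain floats; p holds the 12 weights and biases."""
    # Forget gate: fraction of the old long-term memory to keep.
    long_remember_percent = sigmoid(
        short_memory * p["wlr1"] + input_value * p["wlr2"] + p["blr1"])
    # Input gate: fraction of the potential memory to add.
    potential_remember_percent = sigmoid(
        short_memory * p["wpr1"] + input_value * p["wpr2"] + p["bpr1"])
    # Potential long-term memory, squashed into (-1, 1).
    potential_memory = math.tanh(
        short_memory * p["wp1"] + input_value * p["wp2"] + p["bp1"])
    updated_long_memory = (long_memory * long_remember_percent
                           + potential_remember_percent * potential_memory)
    # Output gate, then the new short-term memory.
    output_percent = sigmoid(
        short_memory * p["wo1"] + input_value * p["wo2"] + p["bo1"])
    updated_short_memory = math.tanh(updated_long_memory) * output_percent
    return updated_long_memory, updated_short_memory


# Hypothetical parameter values: every weight and bias set to 0.5.
names = ["wlr1", "wlr2", "blr1", "wpr1", "wpr2", "bpr1",
         "wp1", "wp2", "bp1", "wo1", "wo2", "bo1"]
params = {name: 0.5 for name in names}

long_mem, short_mem = lstm_unit(params, input_value=1.0,
                                long_memory=0.0, short_memory=0.0)
print(long_mem, short_mem)
```

On the first step both memories are zero, so `long_memory * long_remember_percent` contributes nothing and the new long-term memory comes entirely from the input gate times the potential memory.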

Now that we have implemented the LSTM unit, the next step is to create the forward() method that performs a forward pass through the unrolled LSTM.

For this example, the input will be the stock prices from the previous four days.

First, we initialize the long-term and short-term memory values:

def forward(self, input):
    long_memory = 0
    short_memory = 0

Next, we process each day’s stock price through the LSTM unit:

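def forward(self, input):

    # Both memories start at zero before the first day is processed
    long_memory = 0
    short_memory = 0

    # input holds the stock prices for four consecutive days
    day1 = input[0]
    day2 = input[1]
    day3 = input[2]
    day4 = input[3]

    # Unroll the LSTM: feed each day through the same unit,
    # carrying both memories forward to the next call
    long_memory, short_memory = self.lstm_unit(
        day1, long_memory, short_memory
    )

    long_memory, short_memory = self.lstm_unit(
        day2, long_memory, short_memory
    )

    long_memory, short_memory = self.lstm_unit(
        day3, long_memory, short_memory
    )

    long_memory, short_memory = self.lstm_unit(
        day4, long_memory, short_memory
    )

    # The final short-term memory is the model's output
    return short_memory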

Here, the same LSTM unit is reused for each day’s input. As each value is processed, the long-term and short-term memory are updated and carried forward to the next step.

After the fourth day, we return the final short-term memory, which serves as the output of the LSTM.
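The four explicit calls in forward() can also be written as a loop, which makes the unrolling pattern explicit and works for any number of days. Below is a plain-Python sketch of that pattern; `unroll()` and `toy_step()` are hypothetical helpers introduced here, and `toy_step()` is a deliberately simple stand-in for lstm_unit() so the data flow is easy to follow.

```python
from typing import Callable, Sequence, Tuple

# A step takes (input_value, long_memory, short_memory)
# and returns the updated (long_memory, short_memory).
Step = Callable[[float, float, float], Tuple[float, float]]


def unroll(step: Step, inputs: Sequence[float]) -> float:
    """Apply the same step to every input in order and return the
    final short-term memory, like forward() does for four days."""
    long_memory, short_memory = 0.0, 0.0
    for value in inputs:
        long_memory, short_memory = step(value, long_memory, short_memory)
    return short_memory


def toy_step(x: float, long_memory: float, short_memory: float) -> Tuple[float, float]:
    """Toy stand-in for lstm_unit(): long-term memory keeps a running sum,
    short-term memory averages its previous value with the new input."""
    return long_memory + x, (short_memory + x) / 2


print(unroll(toy_step, [0.0, 0.5, 0.25, 1.0]))  # → 0.625
```

Inside the module, the same idea would be `for value in input: long_memory, short_memory = self.lstm_unit(value, long_memory, short_memory)`. Iterating over a 1-D tensor yields its elements as 0-dimensional tensors, so each call still receives one day's price, just as day1 to day4 do.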

Now that we have a forward() method capable of performing a forward pass through the unrolled LSTM, we are ready to configure the optimizer.

This is straightforward:

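from torch.optim import Adam  # at the top of the file

def configure_optimizers(self):
    # Adam with its default learning rate (0.001) over every trainable parameter
    return Adam(self.parameters())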

This tells Lightning to use the Adam optimizer to train all trainable parameters in the model.
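Calling `Adam(self.parameters())` without extra arguments uses PyTorch's defaults: `lr=0.001`, `betas=(0.9, 0.999)` and `eps=1e-8`. To give a feel for what the optimizer does with each parameter, here is a simplified sketch of the standard Adam update for a single scalar. It is not `torch.optim.Adam` itself, which works on tensors and supports further options such as `weight_decay` and `amsgrad`; the starting values are hypothetical.

```python
import math


def adam_step(param, grad, m, v, t, lr=0.001, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam update for a single scalar parameter; t counts steps from 1."""
    m = beta1 * m + (1 - beta1) * grad          # running average of gradients
    v = beta2 * v + (1 - beta2) * grad * grad   # running average of squared gradients
    m_hat = m / (1 - beta1 ** t)                # correct the bias from starting at 0
    v_hat = v / (1 - beta2 ** t)
    param = param - lr * m_hat / (math.sqrt(v_hat) + eps)
    return param, m, v


# Hypothetical starting point: parameter 1.0 with gradient 0.5 on step 1.
p, m, v = adam_step(1.0, grad=0.5, m=0.0, v=0.0, t=1)
print(p)
```

On the very first step the bias-corrected ratio `m_hat / sqrt(v_hat)` is close to ±1, so each parameter moves by roughly `lr` regardless of how large its gradient is.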

In the next article, we will explore the training_step() method, which is responsible for calculating the loss during training.
