When using the `MambaBlock` from the zeta library, I defined an MSE loss to test backpropagation and found the computation very slow: with a sequence length of 1024, the backward pass takes a long time to complete. Code shown below:
```python
import torch
import torch.nn as nn
from zeta.nn import MambaBlock

# Single Mamba block: model dim 512, depth 1
block = MambaBlock(dim=512, depth=1)

# Batch of 1, sequence length 1024, feature dim 512
x = torch.randn(1, 1024, 512)
target = torch.randn(1, 1024, 512)

loss_fn = nn.MSELoss()
y = block(x)
loss = loss_fn(y, target)
loss.backward()  # this is where the long wait occurs

print("Output shape:", y.shape)
print("Loss value:", loss.item())
```
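To help pin down where the time goes, here is a minimal sketch that times the forward and backward passes separately. `time_fwd_bwd` is a hypothetical helper written for this report, and a plain `nn.Linear` stands in for `MambaBlock` so the snippet runs without zeta installed; substitute the real block to reproduce the measurement. (On GPU you would also need `torch.cuda.synchronize()` around each timing point, since CUDA kernels launch asynchronously.)

```python
import time
import torch
import torch.nn as nn

def time_fwd_bwd(block, x, target, loss_fn):
    """Return (forward seconds, backward seconds) for one pass.

    Hypothetical helper for profiling; on CPU, perf_counter around
    each phase is enough because execution is synchronous.
    """
    t0 = time.perf_counter()
    y = block(x)                 # forward pass
    t1 = time.perf_counter()
    loss = loss_fn(y, target)
    loss.backward()              # backward pass (the suspected bottleneck)
    t2 = time.perf_counter()
    return t1 - t0, t2 - t1

# Stand-in module with the same input/output shape as the report's block.
block = nn.Linear(512, 512)
x = torch.randn(1, 1024, 512)
target = torch.randn(1, 1024, 512)

fwd_s, bwd_s = time_fwd_bwd(block, x, target, nn.MSELoss())
print(f"forward: {fwd_s:.4f}s  backward: {bwd_s:.4f}s")
```

Comparing the two numbers for the real `MambaBlock` would show whether the slowdown is specific to backpropagation or already present in the forward pass.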