Describe the bug
After successfully installing the vision-mamba package in my environment, importing it with from vision_mamba.model import Vim
fails with a RuntimeError raised during the import. The error is a shape mismatch in a matrix multiplication inside one of the package's dependencies: zeta/nn/modules/mlp_mixer.py runs an example forward pass while it is being imported, and that call fails.
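The traceback shows that line 145 of zeta/nn/modules/mlp_mixer.py (output = mlp_mixer(example_input)) executes as module-level code. Any exception raised there propagates straight out of the import statement, which is why a RuntimeError appears instead of an ImportError. The stdlib-only sketch below illustrates that mechanism with two hypothetical demo modules rather than zeta's real code. It also shows the usual if __name__ == "__main__": guard, which keeps demo code from running on import.

```python
import importlib
import sys
import tempfile
from pathlib import Path

# Hypothetical stand-in for a module like zeta/nn/modules/mlp_mixer.py:
# the demo call at the bottom runs every time the module is imported.
UNGUARDED = """
def forward(x):
    raise RuntimeError("mat1 and mat2 shapes cannot be multiplied")

output = forward([1, 2, 3])
"""

# The same module with its demo moved under a __main__ guard, so it runs
# only when the file is executed as a script, never on import.
GUARDED = """
def forward(x):
    raise RuntimeError("mat1 and mat2 shapes cannot be multiplied")

if __name__ == "__main__":
    output = forward([1, 2, 3])
"""


def try_import(name, source, directory):
    """Write source to <directory>/<name>.py, import it, and report the outcome."""
    Path(directory, f"{name}.py").write_text(source)
    importlib.invalidate_caches()  # make the new file visible to the import system
    try:
        importlib.import_module(name)
        return "imported"
    except Exception as exc:
        return type(exc).__name__


with tempfile.TemporaryDirectory() as tmp:
    sys.path.insert(0, tmp)
    try:
        results = (
            try_import("demo_unguarded", UNGUARDED, tmp),
            try_import("demo_guarded", GUARDED, tmp),
        )
    finally:
        sys.path.remove(tmp)

print(results)  # ('RuntimeError', 'imported')
```

Only the unguarded module fails to import. If zeta's example in mlp_mixer.py were guarded this way, importing MLPMixer (and therefore Vim) would not execute the failing forward pass.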
To Reproduce
Steps to reproduce the behavior:
Run the following in a Python 3.10 session in the same environment:
from vision_mamba.model import Vim
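Exact package versions make this kind of report easier to reproduce. The short sketch below collects them with the standard importlib.metadata module. The distribution names vision-mamba and zetascale are assumptions about how the packages were installed from PyPI; adjust them if your install differs.

```python
from importlib.metadata import PackageNotFoundError, version

# Assumed PyPI distribution names for the packages in the traceback.
DISTRIBUTIONS = ("torch", "vision-mamba", "zetascale")


def installed_versions(names=DISTRIBUTIONS):
    """Map each distribution name to its installed version, or None if absent."""
    found = {}
    for name in names:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


if __name__ == "__main__":
    for name, ver in installed_versions().items():
        print(f"{name}: {ver or 'not installed'}")
```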
Additional context
Traceback (most recent call last):
  File "", line 1, in <module>
  File "/home/ss6928/.conda/envs/myenv/lib/python3.10/site-packages/vision_mamba/model.py", line 4, in <module>
    from zeta.nn.modules.ssm import SSM
  File "/home/ss6928/.conda/envs/myenv/lib/python3.10/site-packages/zeta/__init__.py", line 28, in <module>
    from zeta.nn import *
  File "/home/ss6928/.conda/envs/myenv/lib/python3.10/site-packages/zeta/nn/__init__.py", line 3, in <module>
    from zeta.nn.modules import *
  File "/home/ss6928/.conda/envs/myenv/lib/python3.10/site-packages/zeta/nn/modules/__init__.py", line 47, in <module>
    from zeta.nn.modules.mlp_mixer import MLPMixer
  File "/home/ss6928/.conda/envs/myenv/lib/python3.10/site-packages/zeta/nn/modules/mlp_mixer.py", line 145, in <module>
    output = mlp_mixer(example_input)
  File "/home/ss6928/.conda/envs/myenv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/ss6928/.conda/envs/myenv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ss6928/.conda/envs/myenv/lib/python3.10/site-packages/zeta/nn/modules/mlp_mixer.py", line 125, in forward
    x = mixer_block(x)
  File "/home/ss6928/.conda/envs/myenv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/ss6928/.conda/envs/myenv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ss6928/.conda/envs/myenv/lib/python3.10/site-packages/zeta/nn/modules/mlp_mixer.py", line 63, in forward
    y = self.tokens_mlp(y)
  File "/home/ss6928/.conda/envs/myenv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/ss6928/.conda/envs/myenv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ss6928/.conda/envs/myenv/lib/python3.10/site-packages/zeta/nn/modules/mlp_mixer.py", line 30, in forward
    y = self.dense1(x)
  File "/home/ss6928/.conda/envs/myenv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/ss6928/.conda/envs/myenv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ss6928/.conda/envs/myenv/lib/python3.10/site-packages/torch/nn/modules/linear.py", line 116, in forward
    return F.linear(input, self.weight, self.bias)
RuntimeError: mat1 and mat2 shapes cannot be multiplied (512x4 and 512x512)
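The message reports the two flattened operands of the Linear layer's matrix product. The input is collapsed to (product of leading dims, last dim), and the weight is used transposed as (in_features, out_features). The failing layer therefore expects a last input dimension of 512 but receives 4. The pure-Python sketch below reproduces that shape arithmetic without torch. The concrete shapes (1, 512, 4) and (1, 4, 512) are hypothetical: the error only reveals the flattened sizes, and we have not checked which shape zeta's example_input actually has.

```python
from math import prod


def linear_operands(input_shape, in_features, out_features):
    """Shapes of the two matrices a Linear layer multiplies (illustrative model).

    Assumption for illustration, matching how PyTorch reports Linear shape
    errors: leading dims are flattened into rows, and the weight of shape
    (out_features, in_features) appears transposed as (in_features, out_features).
    """
    mat1 = (prod(input_shape[:-1]), input_shape[-1])
    mat2 = (in_features, out_features)
    return mat1, mat2


def can_multiply(mat1, mat2):
    # A matrix product (a x b) @ (c x d) is defined only when b == c.
    return mat1[1] == mat2[0]


# Hypothetical shapes consistent with "(512x4 and 512x512)".
bad = linear_operands((1, 512, 4), in_features=512, out_features=512)
good = linear_operands((1, 4, 512), in_features=512, out_features=512)
print(bad, can_multiply(*bad))    # ((512, 4), (512, 512)) False
print(good, can_multiply(*good))  # ((4, 512), (512, 512)) True
```

Under this reading, the tensor reaching self.dense1 in tokens_mlp has its last two axes in a different order (or size) than the layer was built for.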