Hello — according to the BitNet paper, I would expect the weights to be quantized to -1 or 1. However:
import torch
from bitnet import BitLinearNew
# Create a random input tensor of shape (2, 10, 10)
# (the original comment said (16, 10), which does not match the call below)
x = torch.randn(2, 10, 10)
# Create an instance of BitLinearNew with in_features=10, out_features=20.
# NOTE(review): no group count is passed here, despite the original comment
# mentioning "2 groups" — only the two positional size arguments are given.
layer = BitLinearNew(
10,
20,
)
# Perform a forward pass through the BitLinearNew layer with input x
output = layer(x)
# Inspect the stored weight: this prints the latent full-precision (float32)
# parameter, not binarized {-1, +1} values — presumably quantization happens
# on the fly inside forward(); confirm against the BitLinearNew implementation.
print(layer.weight.dtype)
print(layer.weight)
torch.float32
Parameter containing:
tensor([[ 0.1634, 0.2419, -0.0605, 0.1592, 0.2348, -0.1431, -0.1634, 0.0171,
-0.1672, -0.1526],
[-0.0848, 0.0079, -0.2014, -0.0492, 0.2833, 0.1290, -0.2156, -0.1515,
-0.0473, -0.0839],
[ 0.2230, 0.1434, -0.1410, -0.0626, 0.1189, -0.1652, -0.2978, -0.0287,
0.1025, 0.2458],
[-0.1670, -0.0222, -0.0272, -0.2312, 0.1880, -0.2040, -0.0305, 0.1009,
-0.2247, 0.0124],
[ 0.1351, -0.2926, 0.1891, -0.1614, 0.2894, -0.2931, 0.0802, 0.2884,
0.0454, -0.1398],
[-0.2954, 0.2651, -0.0062, -0.1592, 0.2138, -0.2038, 0.2965, -0.2545,
0.0505, -0.0811],
[-0.3062, -0.1191, -0.1521, 0.1021, -0.1865, -0.1102, 0.2120, -0.2865,
0.1754, 0.1763],
[ 0.1375, -0.2975, 0.0399, -0.1723, -0.0526, -0.2694, 0.1838, -0.1826,
0.2806, -0.1438],
[-0.3150, 0.2163, 0.1946, -0.0244, 0.0657, -0.1531, -0.0310, 0.0071,
0.2590, 0.0985],
[ 0.0402, 0.0704, -0.1441, -0.1929, -0.2450, 0.2408, -0.0750, 0.0238,
0.3030, 0.0516],
[ 0.1537, -0.2231, -0.0092, -0.1068, 0.3131, 0.0859, -0.1692, -0.2364,
0.2257, 0.2601],
[-0.0478, -0.2978, -0.2025, -0.2411, -0.3061, -0.2566, 0.0564, -0.0906,
0.2113, 0.3118],
[-0.1048, 0.2073, -0.2126, -0.1883, 0.0463, -0.1716, -0.3052, 0.0548,
0.2079, 0.2587],
[-0.1387, 0.1778, -0.1886, 0.1239, 0.0265, -0.0421, -0.1020, 0.2481,
-0.0840, 0.1879],
[ 0.0707, -0.0534, 0.0623, 0.0803, 0.3135, 0.2192, -0.1202, 0.3139,
0.0781, -0.0512],
[ 0.2812, 0.2515, -0.0371, 0.0248, 0.0231, -0.0437, 0.0875, 0.3085,
-0.0482, -0.0092],
[ 0.1735, 0.2584, -0.0900, -0.1616, 0.1253, 0.1352, 0.1841, 0.1416,
-0.0686, -0.0269],
[-0.3121, -0.1050, 0.0265, 0.0242, 0.1973, 0.1816, -0.0084, 0.2866,
0.2559, -0.2523],
[ 0.1272, -0.2361, 0.0847, -0.0724, 0.2531, 0.0948, -0.0765, -0.1252,
-0.0459, -0.0133],
[-0.0660, 0.0650, 0.2529, -0.1763, -0.1248, -0.1073, -0.2926, 0.1837,
0.1265, -0.0090]], requires_grad=True)
Am I missing something?
Pay now to fund the work behind this issue.
Get updates on progress being made.
Maintainer is rewarded once the issue is completed.
You're funding impactful open source efforts
You want to contribute to this effort
You want to get funding like this too