Merge pull request #574 from ufownl:bugfix/vit_weights

PiperOrigin-RevId: 761948356
This commit is contained in:
Copybara-Service 2025-05-22 07:04:53 -07:00
commit eb8a463038
1 changed files with 1 additions and 1 deletions

View File

@ -224,7 +224,7 @@ struct LayerWeightsPtrs {
func(TENSOR_ARGS(vit.attn_out_w, kMustRead)); func(TENSOR_ARGS(vit.attn_out_w, kMustRead));
func(TENSOR_ARGS(vit.attn_out_b, kMustRead)); func(TENSOR_ARGS(vit.attn_out_b, kMustRead));
func(TENSOR_ARGS(vit.qkv_einsum_w, kMustRead)); func(TENSOR_ARGS(vit.qkv_einsum_w, kMustRead));
func(TENSOR_ARGS(vit.qkv_einsum_b, kMustRead)); func(TENSOR_ARGS(vit.qkv_einsum_b, kMustRead | TensorArgs::kNoPad));
// MlpBlock. // MlpBlock.
func(TENSOR_ARGS(vit.linear_0_w, kMustRead)); func(TENSOR_ARGS(vit.linear_0_w, kMustRead));
func(TENSOR_ARGS(vit.linear_0_b, kMustRead)); func(TENSOR_ARGS(vit.linear_0_b, kMustRead));