@@ -43,6 +43,7 @@ def apply_rotary_pos_emb_single(x, cos, sin, position_ids=None, unsqueeze_dim=1)
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
+ print(cos.shape, sin.shape, x.shape)
x_embed = (x * cos) + (rotate_half(x) * sin)
return x_embed