diff --git a/tools/server/tests/unit/test_embedding.py b/tools/server/tests/unit/test_embedding.py index 50601b8396..f244774903 100644 --- a/tools/server/tests/unit/test_embedding.py +++ b/tools/server/tests/unit/test_embedding.py @@ -255,3 +255,57 @@ def test_embedding_openai_library_base64(): # make sure the decoded data is the same as the original for x, y in zip(floats, vec0): assert abs(x - y) < EPSILON + + +def test_embedding_pooling_max(): + global server + server.pooling = 'max' + server.start() + res = server.make_request("POST", "/v1/embeddings", data={ + "input": "I believe the meaning of life is", + }) + assert res.status_code == 200 + assert len(res.body['data']) == 1 + assert 'embedding' in res.body['data'][0] + assert len(res.body['data'][0]['embedding']) > 1 + + # make sure embedding vector is normalized + assert abs(sum([x ** 2 for x in res.body['data'][0]['embedding']]) - 1) < EPSILON + + +def test_embedding_pooling_max_multiple(): + global server + server.pooling = 'max' + server.start() + res = server.make_request("POST", "/v1/embeddings", data={ + "input": [ + "I believe the meaning of life is", + "Write a joke about AI", + "This is a test", + "This is another test", + ], + }) + assert res.status_code == 200 + assert len(res.body['data']) == 4 + for d in res.body['data']: + assert 'embedding' in d + assert len(d['embedding']) > 1 + + +def test_embedding_pooling_max_same_prompt(): + global server + server.pooling = 'max' + server.start() + res = server.make_request("POST", "/v1/embeddings", data={ + "input": [ + "I believe the meaning of life is", + "I believe the meaning of life is", + ], + }) + assert res.status_code == 200 + assert len(res.body['data']) == 2 + # same input should give same output + v0 = res.body['data'][0]['embedding'] + v1 = res.body['data'][1]['embedding'] + for x, y in zip(v0, v1): + assert abs(x - y) < EPSILON