Working bedrock generate queries

This commit is contained in:
Mahesh Kommareddi 2024-07-16 20:32:00 -04:00
parent 09f508f5fa
commit 72a7555ac6
4 changed files with 16 additions and 11 deletions

View File

@ -42,7 +42,7 @@ public class IoASystem {
@Bean @Bean
public BedrockLanguageModel bedrockLanguageModel() { public BedrockLanguageModel bedrockLanguageModel() {
return new BedrockLanguageModel("anthropic.claude-v2"); return new BedrockLanguageModel("anthropic.claude-3-5-sonnet-20240620-v1:0");
} }
@Bean @Bean

View File

@ -7,9 +7,8 @@ import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelRequest;
import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelResponse; import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelResponse;
import software.amazon.awssdk.core.SdkBytes; import software.amazon.awssdk.core.SdkBytes;
import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.node.ObjectNode;
import java.util.Map; import com.fasterxml.jackson.databind.node.ArrayNode;
import java.util.HashMap;
import org.springframework.stereotype.Component; import org.springframework.stereotype.Component;
@ -30,26 +29,32 @@ public class BedrockLanguageModel {
public String generate(String prompt) { public String generate(String prompt) {
try { try {
Map<String, Object> requestBody = new HashMap<>(); ObjectNode requestBody = objectMapper.createObjectNode();
requestBody.put("prompt", prompt); requestBody.put("anthropic_version", "bedrock-2023-05-31");
requestBody.put("max_tokens_to_sample", 500); ArrayNode messages = requestBody.putArray("messages");
ObjectNode message = messages.addObject();
message.put("role", "user");
message.put("content", prompt);
requestBody.put("max_tokens", 500);
requestBody.put("temperature", 0.7); requestBody.put("temperature", 0.7);
requestBody.put("top_p", 0.9); requestBody.put("top_p", 0.9);
String jsonPayload = objectMapper.writeValueAsString(requestBody); String jsonPayload = objectMapper.writeValueAsString(requestBody);
InvokeModelRequest invokeModelRequest = InvokeModelRequest.builder() InvokeModelRequest invokeRequest = InvokeModelRequest.builder()
.modelId(modelId) .modelId(modelId)
.contentType("application/json") .contentType("application/json")
.accept("application/json") .accept("application/json")
.body(SdkBytes.fromUtf8String(jsonPayload)) .body(SdkBytes.fromUtf8String(jsonPayload))
.build(); .build();
InvokeModelResponse response = bedrockClient.invokeModel(invokeModelRequest); InvokeModelResponse response = bedrockClient.invokeModel(invokeRequest);
String responseBody = response.body().asUtf8String(); String responseBody = response.body().asUtf8String();
Map<String, Object> responseMap = objectMapper.readValue(responseBody, Map.class); ObjectNode responseJson = (ObjectNode) objectMapper.readTree(responseBody);
return (String) responseMap.get("completion"); return responseJson.path("content").asText();
} catch (Exception e) { } catch (Exception e) {
throw new RuntimeException("Error generating text with Bedrock", e); throw new RuntimeException("Error generating text with Bedrock", e);
} }