diff --git a/src/main/java/com/ioa/IoASystem.java b/src/main/java/com/ioa/IoASystem.java index f4abdd0..746501a 100644 --- a/src/main/java/com/ioa/IoASystem.java +++ b/src/main/java/com/ioa/IoASystem.java @@ -42,7 +42,7 @@ public class IoASystem { @Bean public BedrockLanguageModel bedrockLanguageModel() { - return new BedrockLanguageModel("anthropic.claude-v2"); + return new BedrockLanguageModel("anthropic.claude-3-5-sonnet-20240620-v1:0"); } @Bean diff --git a/src/main/java/com/ioa/model/BedrockLanguageModel.java b/src/main/java/com/ioa/model/BedrockLanguageModel.java index bb530cd..3b42c71 100644 --- a/src/main/java/com/ioa/model/BedrockLanguageModel.java +++ b/src/main/java/com/ioa/model/BedrockLanguageModel.java @@ -7,9 +7,8 @@ import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelRequest; import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelResponse; import software.amazon.awssdk.core.SdkBytes; import com.fasterxml.jackson.databind.ObjectMapper; - -import java.util.Map; -import java.util.HashMap; +import com.fasterxml.jackson.databind.node.ObjectNode; +import com.fasterxml.jackson.databind.node.ArrayNode; import org.springframework.stereotype.Component; @@ -30,26 +29,32 @@ public class BedrockLanguageModel { public String generate(String prompt) { try { - Map requestBody = new HashMap<>(); - requestBody.put("prompt", prompt); - requestBody.put("max_tokens_to_sample", 500); + ObjectNode requestBody = objectMapper.createObjectNode(); + requestBody.put("anthropic_version", "bedrock-2023-05-31"); + ArrayNode messages = requestBody.putArray("messages"); + ObjectNode message = messages.addObject(); + message.put("role", "user"); + message.put("content", prompt); + + requestBody.put("max_tokens", 500); requestBody.put("temperature", 0.7); requestBody.put("top_p", 0.9); String jsonPayload = objectMapper.writeValueAsString(requestBody); - InvokeModelRequest invokeModelRequest = InvokeModelRequest.builder() + InvokeModelRequest invokeRequest = InvokeModelRequest.builder() .modelId(modelId) .contentType("application/json") .accept("application/json") .body(SdkBytes.fromUtf8String(jsonPayload)) .build(); - InvokeModelResponse response = bedrockClient.invokeModel(invokeModelRequest); + InvokeModelResponse response = bedrockClient.invokeModel(invokeRequest); String responseBody = response.body().asUtf8String(); - Map responseMap = objectMapper.readValue(responseBody, Map.class); - return (String) responseMap.get("completion"); + ObjectNode responseJson = (ObjectNode) objectMapper.readTree(responseBody); + return responseJson.path("content").asText(); + } catch (Exception e) { throw new RuntimeException("Error generating text with Bedrock", e); } diff --git a/target/classes/com/ioa/IoASystem.class b/target/classes/com/ioa/IoASystem.class index a9a54ae..31d5438 100644 Binary files a/target/classes/com/ioa/IoASystem.class and b/target/classes/com/ioa/IoASystem.class differ diff --git a/target/classes/com/ioa/model/BedrockLanguageModel.class b/target/classes/com/ioa/model/BedrockLanguageModel.class index e1f9452..6077460 100644 Binary files a/target/classes/com/ioa/model/BedrockLanguageModel.class and b/target/classes/com/ioa/model/BedrockLanguageModel.class differ